Shortcuts

Source code for catalyst.callbacks.pruning

from typing import Callable, List, Optional, TYPE_CHECKING, Union
import warnings

from torch.nn.utils import prune

from catalyst.core.callback import Callback, CallbackOrder
from catalyst.utils.pruning import prune_model, remove_reparametrization

if TYPE_CHECKING:
    from catalyst.core.runner import IRunner

PRUNING_FN = {  # noqa: WPS407
    "l1_unstructured": prune.l1_unstructured,
    "random_unstructured": prune.random_unstructured,
    "ln_structured": prune.ln_structured,
    "random_structured": prune.random_structured,
}


def _wrap_pruning_fn(pruning_fn, *args, **kwargs):
    return lambda module, name, amount: pruning_fn(
        module, name, amount, *args, **kwargs
    )


[docs]class PruningCallback(Callback): """ Pruning Callback This callback is designed to prune network parameters during and/or after training. """
[docs] def __init__( self, pruning_fn: Union[Callable, str], keys_to_prune: Optional[List[str]] = None, amount: Optional[Union[int, float]] = 0.5, prune_on_epoch_end: Optional[bool] = False, prune_on_stage_end: Optional[bool] = True, remove_reparametrization_on_stage_end: Optional[bool] = True, reinitialize_after_pruning: Optional[bool] = False, layers_to_prune: Optional[List[str]] = None, dim: Optional[int] = None, l_norm: Optional[int] = None, ) -> None: """Init method for pruning callback Args: pruning_fn: function from torch.nn.utils.prune module or your based on BasePruningMethod. Can be string e.g. `"l1_unstructured"`. See pytorch docs for more details. keys_to_prune: list of strings. Determines which tensor in modules will be pruned. amount: quantity of parameters to prune. If float, should be between 0.0 and 1.0 and represent the fraction of parameters to prune. If int, it represents the absolute number of parameters to prune. prune_on_epoch_end: bool flag determines call or not to call pruning_fn on epoch end. prune_on_stage_end: bool flag determines call or not to call pruning_fn on stage end. remove_reparametrization_on_stage_end: if True then all reparametrization pre-hooks and tensors with mask will be removed on stage end. reinitialize_after_pruning: if True then will reinitialize model after pruning. (Lottery Ticket Hypothesis) layers_to_prune: list of strings - module names to be pruned. If None provided then will try to prune every module in model. dim: if you are using structured pruning method you need to specify dimension. l_norm: if you are using ln_structured you need to specify l_norm. """ super().__init__(CallbackOrder.External) if isinstance(pruning_fn, str): if pruning_fn not in PRUNING_FN.keys(): raise Exception( f"Pruning function should be in {PRUNING_FN.keys()}, " "global pruning is not currently support." ) if "unstructured" not in pruning_fn: if dim is None: raise Exception( "If you are using structured pruning you" "need to specify dim in callback args" ) if pruning_fn == "ln_structured": if l_norm is None: raise Exception( "If you are using ln_unstructured you" "need to specify n in callback args" ) self.pruning_fn = _wrap_pruning_fn( prune.ln_structured, dim=dim, n=l_norm ) else: self.pruning_fn = _wrap_pruning_fn( PRUNING_FN[pruning_fn], dim=dim ) else: # unstructured self.pruning_fn = PRUNING_FN[pruning_fn] else: self.pruning_fn = pruning_fn if keys_to_prune is None: keys_to_prune = ["weight"] self.prune_on_epoch_end = prune_on_epoch_end self.prune_on_stage_end = prune_on_stage_end if not (prune_on_stage_end or prune_on_epoch_end): warnings.warn( "Warning!" "You disabled pruning pruning both on epoch and stage end." "Model won't be pruned by this callback." ) self.remove_reparametrization_on_stage_end = ( remove_reparametrization_on_stage_end ) self.keys_to_prune = keys_to_prune self.amount = amount self.reinitialize_after_pruning = reinitialize_after_pruning self.layers_to_prune = layers_to_prune
[docs] def on_epoch_end(self, runner: "IRunner") -> None: """ On epoch end action. Active if prune_on_epoch_end is True. Args: runner: runner for your experiment """ if self.prune_on_epoch_end and runner.num_epochs != runner.epoch: prune_model( model=runner.model, pruning_fn=self.pruning_fn, keys_to_prune=self.keys_to_prune, amount=self.amount, layers_to_prune=self.layers_to_prune, reinitialize_after_pruning=self.reinitialize_after_pruning, )
[docs] def on_stage_end(self, runner: "IRunner") -> None: """ On stage end action. Active if prune_on_stage_end or remove_reparametrization is True. Args: runner: runner for your experiment """ if self.prune_on_stage_end: prune_model( model=runner.model, pruning_fn=self.pruning_fn, keys_to_prune=self.keys_to_prune, amount=self.amount, layers_to_prune=self.layers_to_prune, reinitialize_after_pruning=self.reinitialize_after_pruning, ) if self.remove_reparametrization_on_stage_end: remove_reparametrization( model=runner.model, keys_to_prune=self.keys_to_prune, layers_to_prune=self.layers_to_prune, )
__all__ = ["PruningCallback"]