Shortcuts

Source code for catalyst.callbacks.pruning

from typing import Callable, List, Optional, TYPE_CHECKING, Union
import warnings

from catalyst.core.callback import Callback, CallbackOrder
from catalyst.utils.pruning import get_pruning_fn, prune_model, remove_reparametrization

if TYPE_CHECKING:
    from catalyst.core.runner import IRunner


class PruningCallback(Callback):
    """This callback prunes network parameters during and/or after training.

    Args:
        pruning_fn: function from torch.nn.utils.prune module
            or your own function based on BasePruningMethod.
            Can be a string, e.g. `"l1_unstructured"`.
            See pytorch docs for more details.
        amount: quantity of parameters to prune.
            If float, should be between 0.0 and 1.0 and
            represent the fraction of parameters to prune.
            If int, it represents the absolute number
            of parameters to prune.
        keys_to_prune: list of strings. Determines
            which tensor in modules will be pruned.
            Defaults to ``["weight"]`` when None.
        prune_on_epoch_end: bool flag determines call or not to call
            pruning_fn on epoch end. If ``prune_on_stage_end`` is also
            enabled, epoch-end pruning is skipped on the final epoch of
            the stage so the model is not pruned twice at the end.
        prune_on_stage_end: bool flag determines call or not to call
            pruning_fn on stage end.
        remove_reparametrization_on_stage_end: if True then all
            reparametrization pre-hooks and tensors with mask
            will be removed on stage end.
        layers_to_prune: list of strings - module names to be pruned.
            If None provided then will try to prune every module in
            model.
        dim: if you are using structured pruning method you need
            to specify dimension.
        l_norm: if you are using ln_structured you need to specify l_norm.
    """

    def __init__(
        self,
        pruning_fn: Union[Callable, str],
        amount: Union[int, float],
        keys_to_prune: Optional[List[str]] = None,
        prune_on_epoch_end: Optional[bool] = False,
        prune_on_stage_end: Optional[bool] = True,
        remove_reparametrization_on_stage_end: Optional[bool] = True,
        layers_to_prune: Optional[List[str]] = None,
        dim: Optional[int] = None,
        l_norm: Optional[int] = None,
    ) -> None:
        """Init method for pruning callback"""
        super().__init__(CallbackOrder.ExternalExtra)
        # Resolve the pruning function (possibly given by name) once, up front.
        self.pruning_fn = get_pruning_fn(pruning_fn=pruning_fn, dim=dim, l_norm=l_norm)
        if keys_to_prune is None:
            keys_to_prune = ["weight"]
        self.prune_on_epoch_end = prune_on_epoch_end
        self.prune_on_stage_end = prune_on_stage_end
        if not (prune_on_stage_end or prune_on_epoch_end):
            # Each fragment ends with a space so the concatenated message reads correctly.
            warnings.warn(
                "Warning! "
                "You disabled pruning both on epoch and stage end. "
                "Model won't be pruned by this callback.",
                stacklevel=2,
            )
        self.remove_reparametrization_on_stage_end = remove_reparametrization_on_stage_end
        self.keys_to_prune = keys_to_prune
        self.amount = amount
        self.layers_to_prune = layers_to_prune

    def _prune(self, runner: "IRunner") -> None:
        """Apply ``self.pruning_fn`` to ``runner.model`` using this callback's settings.

        Args:
            runner: runner for your experiment
        """
        prune_model(
            model=runner.model,
            pruning_fn=self.pruning_fn,
            keys_to_prune=self.keys_to_prune,
            amount=self.amount,
            layers_to_prune=self.layers_to_prune,
        )

    def on_epoch_end(self, runner: "IRunner") -> None:
        """Event handler. Active if prune_on_epoch_end is True.

        Args:
            runner: runner for your experiment
        """
        if not self.prune_on_epoch_end:
            return
        is_last_epoch = runner.stage_epoch_step == runner.stage_epoch_len
        # Skip the final epoch only when stage-end pruning will also run
        # (assumes on_stage_end follows the last on_epoch_end, per the Catalyst
        # callback lifecycle); otherwise the last epoch must still be pruned.
        if is_last_epoch and self.prune_on_stage_end:
            return
        self._prune(runner)

    def on_stage_end(self, runner: "IRunner") -> None:
        """Event handler.

        Active if prune_on_stage_end or
        remove_reparametrization_on_stage_end is True.

        Args:
            runner: runner for your experiment
        """
        if self.prune_on_stage_end:
            self._prune(runner)
        if self.remove_reparametrization_on_stage_end:
            # Make pruning permanent: drop pre-hooks and mask buffers.
            remove_reparametrization(
                model=runner.model,
                keys_to_prune=self.keys_to_prune,
                layers_to_prune=self.layers_to_prune,
            )
# Public API of this module (the names exported by ``import *``).
__all__ = [
    "PruningCallback",
]