Shortcuts

Source code for catalyst.utils.pruning

from typing import Callable, List, Optional, Union

from torch.nn import Module
from torch.nn.utils import prune

from catalyst.utils.torch import get_nn_from_ddp_module

# Registry of built-in pruning functions that can be selected by name.
# Each key is exactly the attribute name of the function in
# ``torch.nn.utils.prune``, so the mapping is derived from the names.
PRUNING_FN = {  # noqa: WPS407
    fn_name: getattr(prune, fn_name)
    for fn_name in (
        "l1_unstructured",
        "random_unstructured",
        "ln_structured",
        "random_structured",
    )
}


def _wrap_pruning_fn(pruning_fn, *args, **kwargs):
    return lambda module, name, amount: pruning_fn(module, name, amount, *args, **kwargs)


def get_pruning_fn(
    pruning_fn: Union[str, Callable],
    dim: Optional[int] = None,
    l_norm: Optional[int] = None,
) -> Callable:
    """Resolve ``pruning_fn`` to a callable usable as ``fn(module, name, amount)``.

    Args:
        pruning_fn (Union[str, Callable]): function from
            ``torch.nn.utils.prune`` module or your own one based on
            ``BasePruningMethod``. Non-string values are returned unchanged.
            Can also be a string, e.g. ``"l1_unstructured"``; supported names
            are the keys of ``PRUNING_FN``. See pytorch docs for more details.
        dim (int, optional): dimension to prune along; required for
            structured pruning methods (names not containing
            ``"unstructured"``). Defaults to None.
        l_norm (int, optional): norm order passed as ``n``; required for
            ``"ln_structured"``. Defaults to None.

    Raises:
        ValueError: If ``pruning_fn`` is an unknown string, or if ``dim`` or
            ``l_norm`` is not defined when it's required.

    Returns:
        Callable: pruning_fn
    """
    if not isinstance(pruning_fn, str):
        # Already a callable: the caller is responsible for its signature.
        return pruning_fn

    if pruning_fn not in PRUNING_FN:
        raise ValueError(
            f"Pruning function should be in {PRUNING_FN.keys()}, "
            "global pruning is not currently supported."
        )

    if "unstructured" in pruning_fn:
        # Unstructured methods already have the (module, name, amount) API.
        return PRUNING_FN[pruning_fn]

    # Structured methods additionally need the dimension to prune along.
    if dim is None:
        raise ValueError(
            "If you are using structured pruning you "
            "need to specify dim in args"
        )
    if pruning_fn == "ln_structured":
        if l_norm is None:
            raise ValueError(
                "If you are using ln_structured you "
                "need to specify l_norm in args"
            )
        return _wrap_pruning_fn(PRUNING_FN[pruning_fn], dim=dim, n=l_norm)
    return _wrap_pruning_fn(PRUNING_FN[pruning_fn], dim=dim)
def prune_model(
    model: Module,
    pruning_fn: Union[Callable, str],
    amount: Union[float, int],
    keys_to_prune: Optional[List[str]] = None,
    layers_to_prune: Optional[List[str]] = None,
    dim: Optional[int] = None,
    l_norm: Optional[int] = None,
) -> None:
    """
    Prune model function can be used for pruning certain tensors in model layers.

    The model is modified in place; the function returns None.

    Args:
        model: Model to be pruned. It is first passed through
            ``get_nn_from_ddp_module`` (presumably unwrapping a DDP wrapper).
        pruning_fn: Pruning function with API same as in
            ``torch.nn.utils.prune``: ``pruning_fn(module, name, amount)``,
            or its name as a string (resolved via ``get_pruning_fn``).
        amount: quantity of parameters to prune.
            If float, should be between 0.0 and 1.0 and
            represent the fraction of parameters to prune.
            If int, it represents the absolute number
            of parameters to prune.
        keys_to_prune: list of strings. Determines
            which tensor in modules will be pruned.
            Defaults to ``["weight"]`` (also used when an empty list is given).
        layers_to_prune: list of strings - module names to be pruned.
            If None provided then will try to prune every module in
            model, skipping tensors that raise ``AttributeError``.
        dim (int, optional):
            if you are using structured pruning method you need
            to specify dimension. Defaults to None.
        l_norm (int, optional):
            if you are using ln_structured you need to specify l_norm.
            Defaults to None.

    Example:
        .. code-block:: python

            prune_model(model, pruning_fn="l1_unstructured", amount=0.5)

    Raises:
        AttributeError: If ``layers_to_prune`` is not None and pruning one
            of the selected layers raises ``AttributeError`` (e.g. the layer
            has no tensor named like one of ``keys_to_prune``).
        ValueError: if no tensor was pruned at all (no module has the
            specified keys, or no module name matched ``layers_to_prune``).
    """
    nn_model = get_nn_from_ddp_module(model)
    pruning_fn = get_pruning_fn(pruning_fn, l_norm=l_norm, dim=dim)
    keys_to_prune = keys_to_prune or ["weight"]
    pruned_modules = 0
    for name, module in nn_model.named_modules():
        if layers_to_prune is not None and name not in layers_to_prune:
            continue
        for key in keys_to_prune:
            # Per-key try: a missing tensor must not stop the remaining keys
            # of the same module from being pruned (order-independent).
            try:
                pruning_fn(module, name=key, amount=amount)
            except AttributeError:
                # Explicitly requested layers must be prunable; re-raise.
                if layers_to_prune is not None:
                    raise
            else:
                pruned_modules += 1
    if pruned_modules == 0:
        raise ValueError(f"There is no {keys_to_prune} key in your model")
def remove_reparametrization(
    model: Module,
    keys_to_prune: List[str],
    layers_to_prune: Optional[List[str]] = None,
) -> None:
    """
    Removes pre-hooks and pruning masks from the model.

    Calls ``torch.nn.utils.prune.remove`` for every selected
    (module, key) pair. Pairs for which ``prune.remove`` raises
    ``ValueError`` (e.g. that tensor was not pruned) are skipped
    individually, so the order of ``keys_to_prune`` does not matter.

    Args:
        model: model to remove reparametrization from. It is first passed
            through ``get_nn_from_ddp_module`` (presumably unwrapping DDP).
        keys_to_prune: list of strings. Determines
            which tensor in modules have already been pruned.
        layers_to_prune: list of strings - module names
            have already been pruned.
            If None provided then will try to remove reparametrization
            from every module in model.
    """
    nn_model = get_nn_from_ddp_module(model)
    for name, module in nn_model.named_modules():
        if layers_to_prune is not None and name not in layers_to_prune:
            continue
        for key in keys_to_prune:
            # Best effort per key: a non-pruned tensor must not prevent
            # removing the reparametrization of the module's other keys.
            try:
                prune.remove(module, key)
            except ValueError:
                continue
# Public API of this module.
__all__ = [
    "prune_model",
    "remove_reparametrization",
    "get_pruning_fn",
]