Source code for catalyst.utils.pruning
from typing import Callable, List, Optional, Union
from torch.nn import Module
from torch.nn.utils import prune
from catalyst.utils.initialization import reset_weights_if_possible
def prune_model(
    model: Module,
    pruning_fn: Callable,
    keys_to_prune: List[str],
    amount: Union[float, int],
    layers_to_prune: Optional[List[str]] = None,
    reinitialize_after_pruning: bool = False,
) -> None:
    """
    Prune model function can be used for pruning certain
    tensors in model layers.

    Every key in ``keys_to_prune`` is tried independently on every selected
    module, so the outcome does not depend on the order of the keys.

    Args:
        model: Model to be pruned.
        pruning_fn: Pruning function with API same as in
            ``torch.nn.utils.prune``; it is called as
            ``pruning_fn(module, name=key, amount=amount)``.
        keys_to_prune: list of strings. Determines
            which tensor in modules will be pruned.
        amount: quantity of parameters to prune.
            If float, should be between 0.0 and 1.0 and
            represent the fraction of parameters to prune.
            If int, it represents the absolute number
            of parameters to prune.
        layers_to_prune: list of strings - module names to be pruned.
            If None provided then will try to prune every module in
            model, skipping keys a module does not have.
        reinitialize_after_pruning: if True then will reinitialize model
            after pruning. (Lottery Ticket Hypothesis check e.g.)

    Raises:
        AttributeError: If layers_to_prune is not None and a selected
            layer has no attribute named by one of the keys (propagated
            from ``pruning_fn``).
        Exception: If no selected module has any of the specified keys.
    """
    # Materialise once: an iterator would otherwise be exhausted by the
    # first module and silently skip all the others.
    keys = list(keys_to_prune)
    # Set gives O(1) membership tests per module name.
    selected = None if layers_to_prune is None else set(layers_to_prune)

    pruned_modules = 0
    for name, module in model.named_modules():
        if selected is not None and name not in selected:
            continue
        missing = 0
        for key in keys:
            try:
                pruning_fn(module, name=key, amount=amount)
            except AttributeError:
                # Explicitly requested layers must own every key; when
                # sweeping the whole model, modules lacking a key are skipped
                # for that key only, so the other keys are still pruned.
                if selected is not None:
                    raise
                missing += 1
        # Count the module unless it lacked every requested key (an empty key
        # list leaves nothing missing, so the module still counts).
        if missing < len(keys) or not keys:
            pruned_modules += 1

    if pruned_modules == 0:
        raise Exception(f"There is no {keys_to_prune} key in your model")
    if reinitialize_after_pruning:
        model.apply(reset_weights_if_possible)
def remove_reparametrization(
    model: Module,
    keys_to_prune: List[str],
    layers_to_prune: Optional[List[str]] = None,
) -> None:
    """
    Removes pre-hooks and pruning masks from the model.

    Each selected module/key pair is passed to
    ``torch.nn.utils.prune.remove``. Keys are handled independently, so a
    key that was never pruned on a module does not prevent the remaining
    keys of that module from being processed.

    Args:
        model: model to remove reparametrization.
        keys_to_prune: list of strings. Determines
            which tensor in modules have already been pruned.
        layers_to_prune: list of strings - module names
            have already been pruned.
            If None provided then will try to remove reparametrization
            from every module in model.
    """
    # Materialise once so an iterator is not exhausted by the first module.
    keys = list(keys_to_prune)
    selected = None if layers_to_prune is None else set(layers_to_prune)

    for name, module in model.named_modules():
        if selected is not None and name not in selected:
            continue
        for key in keys:
            try:
                prune.remove(module, key)
            except ValueError:
                # ``prune.remove`` signals "this key was not pruned on this
                # module" with ValueError; skip just this key (best effort).
                continue
# Public API of this module: the only names exported by ``import *``.
__all__ = ["prune_model", "remove_reparametrization"]