Shortcuts

Source code for catalyst.utils.quantization

from typing import Dict, Optional, Union

import torch
from torch import quantization

from catalyst.typing import Model
from catalyst.utils.torch import get_nn_from_ddp_module


def quantize_model(
    model: Model,
    qconfig_spec: Dict = None,
    dtype: Union[str, Optional[torch.dtype]] = "qint8",
) -> Model:
    """Function to quantize model weights.

    The model is unwrapped with ``get_nn_from_ddp_module``, moved to CPU and
    passed to ``torch.quantization.quantize_dynamic``. If that call raises
    ``RuntimeError``, the quantized engine is switched to ``"qnnpack"`` and
    the quantization is retried once.

    Args:
        model: model to be quantized
        qconfig_spec (Dict, optional): quantization config in PyTorch format.
            Defaults to None.
        dtype (Union[str, Optional[torch.dtype]], optional): Type of weights
            after quantization. Either one of the strings ``"qint8"`` /
            ``"quint8"``, a ``torch.dtype`` (forwarded unchanged), or ``None``
            (treated as ``torch.qint8``). Defaults to "qint8".

    Returns:
        Model: quantized model

    Raises:
        KeyError: if ``dtype`` is a string other than ``"qint8"`` or
            ``"quint8"``.
    """
    nn_model = get_nn_from_ddp_module(model)

    # Resolve the target dtype up front so that every accepted form of
    # ``dtype`` (string alias, torch.dtype, None) yields a usable value.
    # Previously only the string form worked; anything else hit an
    # unbound ``type_mapping`` name.
    if dtype is None:
        target_dtype = torch.qint8
    elif isinstance(dtype, str):
        type_mapping = {"qint8": torch.qint8, "quint8": torch.quint8}
        target_dtype = type_mapping[dtype]
    else:
        target_dtype = dtype

    try:
        quantized_model = quantization.quantize_dynamic(
            nn_model.cpu(), qconfig_spec=qconfig_spec, dtype=target_dtype,
        )
    except RuntimeError:
        # Presumably the currently selected quantized engine is not usable
        # on this platform; fall back to qnnpack and retry once.
        # NOTE: this assigns a torch-level setting, so it stays in effect
        # after this function returns.
        torch.backends.quantized.engine = "qnnpack"
        quantized_model = quantization.quantize_dynamic(
            nn_model.cpu(), qconfig_spec=qconfig_spec, dtype=target_dtype,
        )

    return quantized_model
# Public API of this module: only `quantize_model` is exported by `import *`.
__all__ = ["quantize_model"]