"""Model quantization utilities (``catalyst.utils.quantization``)."""

from typing import Dict, Optional, Union

import torch
from torch import quantization

from catalyst.typing import TorchModel
from catalyst.utils.distributed import get_nn_from_ddp_module


def quantize_model(
    model: TorchModel,
    qconfig_spec: Optional[Dict] = None,
    dtype: Union[str, Optional[torch.dtype]] = "qint8",
) -> TorchModel:
    """Function to quantize model weights.

    Args:
        model: model to be quantized
        qconfig_spec (Dict, optional): quantization config in PyTorch format.
            Defaults to None.
        dtype: type of weights after quantization — either a string alias
            ("qint8" or "quint8") or the corresponding ``torch.dtype``.
            Defaults to "qint8".

    Returns:
        Model: quantized model

    Raises:
        KeyError: if ``dtype`` is a string other than "qint8" or "quint8".
    """
    # Unwrap a (possibly) DDP-wrapped model to get the plain nn.Module.
    nn_model = get_nn_from_ddp_module(model)
    if isinstance(dtype, str):
        # Resolve the string alias to the actual torch dtype here, so that
        # callers passing a torch.dtype directly are also supported.
        # (Previously `type_mapping[dtype]` was evaluated unconditionally,
        # raising NameError for non-string dtypes.)
        type_mapping = {"qint8": torch.qint8, "quint8": torch.quint8}
        dtype = type_mapping[dtype]
    try:
        quantized_model = quantization.quantize_dynamic(
            nn_model.cpu(), qconfig_spec=qconfig_spec, dtype=dtype
        )
    except RuntimeError:
        # Fall back to the qnnpack engine when the default quantized
        # backend (e.g. fbgemm) is unavailable on this platform.
        torch.backends.quantized.engine = "qnnpack"
        quantized_model = quantization.quantize_dynamic(
            nn_model.cpu(), qconfig_spec=qconfig_spec, dtype=dtype
        )
    return quantized_model
# Explicit public API of this module.
__all__ = ["quantize_model"]