Shortcuts

Source code for catalyst.utils.onnx

from typing import Dict, Iterable, List, Union
import os
from pathlib import Path

import torch

from catalyst.settings import SETTINGS
from catalyst.tools.forward_wrapper import ModelForwardWrapper
from catalyst.utils.torch import get_nn_from_ddp_module

if SETTINGS.onnx_required:
    import onnx
    from onnxruntime.quantization import quantize_dynamic, QuantType


def onnx_export(
    model: torch.nn.Module,
    batch: torch.Tensor,
    file: str,
    method_name: str = "forward",
    input_names: Iterable = None,
    output_names: List[str] = None,
    dynamic_axes: Union[Dict[str, int], Dict[str, Dict[str, int]]] = None,
    opset_version: int = 9,
    do_constant_folding: bool = False,
    return_model: bool = False,
    verbose: bool = False,
) -> Union[None, "onnx.ModelProto"]:
    """Exports a model to the ONNX format with ``torch.onnx.export``.

    Args:
        model (torch.nn.Module): model to export.
        batch (Tensor): example inputs passed to ``torch.onnx.export``.
        file (str): path of the ONNX file to write.
        method_name (str, optional): Forward pass method to be converted.
            Defaults to "forward".
        input_names (Iterable, optional): name of inputs in graph.
            Defaults to None.
        output_names (List[str], optional): name of outputs in graph.
            Defaults to None.
        dynamic_axes (Union[Dict[str, int], Dict[str, Dict[str, int]]], optional):
            axes with dynamic shapes. Defaults to None.
        opset_version (int, optional): Defaults to 9.
        do_constant_folding (bool, optional): If True, the constant-folding
            optimization is applied to the model during export.
            Defaults to False.
        return_model (bool, optional): If True then the exported file is
            loaded with ``onnx.load`` and returned (onnx required).
            Defaults to False.
        verbose (bool, default False): if specified, we will print out
            a debug description of the trace being exported.

    Example:
        .. code-block:: python

           import torch

           from catalyst.utils.onnx import onnx_export

           class LinModel(torch.nn.Module):
               def __init__(self):
                   super().__init__()
                   self.lin1 = torch.nn.Linear(10, 10)
                   self.lin2 = torch.nn.Linear(2, 10)

               def forward(self, inp_1, inp_2):
                   return self.lin1(inp_1), self.lin2(inp_2)

               def first_only(self, inp_1):
                   return self.lin1(inp_1)

           lin_model = LinModel()
           onnx_export(
               lin_model,
               batch=torch.randn((1, 10)),
               file="model.onnx",
               method_name="first_only",
           )

    Raises:
        ImportError: when ``return_model`` is True, but onnx is not installed.
            Raised before anything is exported or written to ``file``.

    Returns:
        Union[None, "onnx.ModelProto"]: the loaded onnx model if
        ``return_model`` is set to True, otherwise None.
    """
    # Fail fast: there is no point in tracing the model and writing the file
    # when the requested return value can never be produced.
    if return_model and not SETTINGS.onnx_required:
        raise ImportError("To use onnx model you should install it with ``pip install onnx``")

    # NOTE(review): get_nn_from_ddp_module is expected to hand back the plain
    # nn.Module when ``model`` is a distributed wrapper - confirm in its source.
    nn_model = get_nn_from_ddp_module(model)
    if method_name != "forward":
        # Export ``method_name`` instead of the default ``forward``.
        nn_model = ModelForwardWrapper(model=nn_model, method_name=method_name)

    torch.onnx.export(
        nn_model,
        batch,
        file,
        verbose=verbose,
        input_names=input_names,
        output_names=output_names,
        dynamic_axes=dynamic_axes,
        do_constant_folding=do_constant_folding,
        opset_version=opset_version,
    )

    if return_model:
        return onnx.load(file)
    return None
def quantize_onnx_model(
    onnx_model_path: Union[Path, str],
    quantized_model_path: Union[Path, str],
    qtype: str = "qint8",
    verbose: bool = False,
) -> None:
    """Takes a model exported to ONNX and applies dynamic quantization to it.

    Args:
        onnx_model_path (Union[Path, str]): path to onnx model.
        quantized_model_path (Union[Path, str]): path to quantized model.
        qtype (str, optional): Type of weights in quantized model.
            Can be `quint8` or `qint8`. Defaults to "qint8".
        verbose (bool, optional): If set to True prints model size before
            and after quantization. Defaults to False.

    Raises:
        ValueError: If qtype is not understood.
        ImportError: If onnx support is not available
            (``SETTINGS.onnx_required`` is false).
    """
    # Keyed by the public string and mapped to the QuantType attribute *name*,
    # so the argument can be validated without touching the optional import.
    quant_type_names = {"qint8": "QInt8", "quint8": "QUInt8"}
    if qtype not in quant_type_names:
        raise ValueError("type should be string one of 'quint8' or 'qint8'. Got {}".format(qtype))

    # QuantType / quantize_dynamic are only imported when onnx support is on;
    # raise a clear error instead of failing later with a NameError.
    if not SETTINGS.onnx_required:
        raise ImportError(
            "To quantize onnx model you should install it with "
            "``pip install onnx onnxruntime``"
        )

    weight_type = getattr(QuantType, quant_type_names[qtype])
    quantize_dynamic(onnx_model_path, quantized_model_path, weight_type=weight_type)

    if verbose:
        size_before_mb = os.path.getsize(onnx_model_path) / 2 ** 20
        size_after_mb = os.path.getsize(quantized_model_path) / 2 ** 20
        v_str = (
            f"Model size before quantization (MB): {size_before_mb:.2f}\n"
            f"Model size after quantization (MB): {size_after_mb:.2f}"
        )
        print("Done.")
        print(v_str)
        print(f"Quantized model saved to {quantized_model_path}.")
# Public names exported by ``from catalyst.utils.onnx import *``.
__all__ = ["onnx_export", "quantize_onnx_model"]