Source code for catalyst.callbacks.onnx
from pathlib import Path
from typing import Dict, Iterable, List, Optional, TYPE_CHECKING, Union

from catalyst.core import Callback, CallbackNode, CallbackOrder
from catalyst.utils import onnx_export

if TYPE_CHECKING:
    from catalyst.core import IRunner
class OnnxCallback(Callback):
    """
    Callback for converting model to onnx runtime.

    At the end of every stage the callback takes ``runner.batch[input_key]``
    as the sample input and passes it, together with ``runner.model``, to
    ``catalyst.utils.onnx_export``, which writes the result to
    ``logdir / filename`` (or to ``filename`` alone when ``logdir`` is
    ``None``). Missing parent directories of that path are created first.

    Args:
        input_key: input key from ``runner.batch`` to use for onnx export
        logdir: path to folder for saving. If ``None``, ``filename`` is used
            as the full output path.
        filename: output filename. Defaults to ``"onnx.py"``; note that this
            default carries a ``.py`` extension although the file is produced
            by the ONNX exporter — consider passing e.g. ``"model.onnx"``.
        method_name (str, optional): Forward pass method to be converted. Defaults to "forward".
        input_names (Iterable, optional): name of inputs in graph. Defaults to None.
        output_names (List[str], optional): name of outputs in graph. Defaults to None.
        dynamic_axes (Union[Dict[str, int], Dict[str, Dict[str, int]]], optional): axes
            with dynamic shapes. Defaults to None.
        opset_version (int, optional): Defaults to 9.
        do_constant_folding (bool, optional): If True, the constant-folding optimization
            is applied to the model during export. Defaults to False.
        verbose (bool, default False): if specified, we will print out a debug
            description of the trace being exported.

    Example:

        .. code-block:: python

            import os

            import torch
            from torch import nn
            from torch.utils.data import DataLoader

            from catalyst import dl
            from catalyst.data import ToTensor
            from catalyst.contrib.datasets import MNIST
            from catalyst.contrib.nn.modules import Flatten

            loaders = {
                "train": DataLoader(
                    MNIST(
                        os.getcwd(), train=True, download=True, transform=ToTensor()
                    ),
                    batch_size=32,
                ),
                "valid": DataLoader(
                    MNIST(
                        os.getcwd(), train=False, download=True, transform=ToTensor()
                    ),
                    batch_size=32,
                ),
            }

            model = nn.Sequential(Flatten(), nn.Linear(784, 512), nn.ReLU(), nn.Linear(512, 10))
            criterion = nn.CrossEntropyLoss()
            optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
            runner = dl.SupervisedRunner()
            runner.train(
                model=model,
                callbacks=[dl.OnnxCallback(input_key="features", logdir="./logs")],
                loaders=loaders,
                criterion=criterion,
                optimizer=optimizer,
                num_epochs=1,
                logdir="./logs",
            )
    """

    def __init__(
        self,
        input_key: str,
        logdir: Optional[Union[str, Path]] = None,
        filename: str = "onnx.py",
        method_name: str = "forward",
        input_names: Optional[Iterable] = None,
        output_names: Optional[List[str]] = None,
        dynamic_axes: Optional[
            Union[Dict[str, int], Dict[str, Dict[str, int]]]
        ] = None,
        opset_version: int = 9,
        do_constant_folding: bool = False,
        verbose: bool = False,
    ):
        """Init."""
        # NOTE(review): CallbackNode.Master presumably restricts this callback
        # to the master process in distributed runs, so only one file is written.
        super().__init__(order=CallbackOrder.ExternalExtra, node=CallbackNode.Master)
        if logdir is not None:
            self.filename = str(Path(logdir) / filename)
        else:
            self.filename = filename
        self.input_key = input_key
        self.method_name = method_name
        self.input_names = input_names
        self.output_names = output_names
        self.dynamic_axes = dynamic_axes
        self.opset_version = opset_version
        self.do_constant_folding = do_constant_folding
        self.verbose = verbose

    def on_stage_end(self, runner: "IRunner") -> None:
        """
        Export ``runner.model`` to ``self.filename`` at the end of a stage.

        Args:
            runner: runner for experiment; its ``model``, ``batch`` and
                ``engine`` attributes are read.
        """
        model = runner.model
        # Sample input for tracing: the ``input_key`` entry of the batch the
        # runner currently holds (presumably the last batch of the stage),
        # passed through the engine — presumably to place it on the model's device.
        batch = runner.engine.sync_device(runner.batch[self.input_key])
        # Create missing parent directories so that a ``logdir`` which does not
        # exist yet does not make writing the exported file fail.
        Path(self.filename).parent.mkdir(parents=True, exist_ok=True)
        onnx_export(
            model=model,
            file=self.filename,
            batch=batch,
            method_name=self.method_name,
            input_names=self.input_names,
            output_names=self.output_names,
            dynamic_axes=self.dynamic_axes,
            opset_version=self.opset_version,
            do_constant_folding=self.do_constant_folding,
            verbose=self.verbose,
        )
# Public names exported by ``from catalyst.callbacks.onnx import *``.
__all__ = [
    "OnnxCallback",
]