Source code for catalyst.dl.utils.trace
from typing import Type # isort:skip
import torch
from torch import nn
from torch.jit import ScriptModule
from catalyst import utils
from catalyst.dl.core import Experiment, Runner
class _ForwardOverrideModel(nn.Module):
"""
Model that calls specified method instead of forward
(Workaround, single method tracing is not supported)
"""
def __init__(self, model, method_name):
super().__init__()
self.model = model
self.method = method_name
def forward(self, *args, **kwargs):
return getattr(self.model, self.method)(*args, **kwargs)
class _TracingModelWrapper(nn.Module):
"""
Wrapper that traces model with batch instead of calling it
(Workaround, to use native model batch handler)
"""
def __init__(self, model, method_name):
super().__init__()
self.method_name = method_name
self.model = model
self.tracing_result: ScriptModule
def __call__(self, *args, **kwargs):
method_model = _ForwardOverrideModel(self.model, self.method_name)
self.tracing_result = \
torch.jit.trace(
method_model,
*args, **kwargs
)
def _get_native_batch(experiment: Experiment, stage: str):
"""Returns dataset from first loader provided by experiment"""
loaders = experiment.get_loaders(stage)
assert loaders, \
"Experiment must have at least one loader to support tracing"
# Take first loader
loader = next(iter(loaders.values()))
dataset = loader.dataset
collate_fn = loader.collate_fn
sample = collate_fn([dataset[0]])
return sample
[docs]def trace_model(
model: nn.Module,
experiment: Experiment,
runner_type: Type[Runner],
method_name: str = "forward",
mode: str = "eval",
requires_grad: bool = False,
) -> ScriptModule:
"""
Traces model using it's native experiment and runner.
Args:
model: Model to trace
experiment: Native experiment that was used to train model
runner_type: Model's native runner that was used to train model
method_name (str): Model's method name that will be
used as entrypoint during tracing
mode (str): Mode for model to trace (``train`` or ``eval``)
requires_grad (bool): Flag to use grads
Returns:
Traced model ScriptModule
"""
if mode not in ["train", "eval"]:
raise ValueError(f"Unknown mode '{mode}'. Must be 'eval' or 'train'")
getattr(model, mode)()
utils.set_requires_grad(model, requires_grad=requires_grad)
tracer = _TracingModelWrapper(model, method_name)
runner: Runner = runner_type(tracer.cpu(), torch.device("cpu"))
stage = list(experiment.stages)[0]
batch = _get_native_batch(experiment, stage)
batch = runner._batch2device(batch, device=runner.device)
runner.predict_batch(batch)
return tracer.tracing_result
__all__ = ["trace_model"]