# Source code for catalyst.utils.tracing
from typing import Tuple, Union
import logging
import torch
from torch import jit
from catalyst.tools.forward_wrapper import ModelForwardWrapper
from catalyst.typing import Model
from catalyst.utils.torch import get_nn_from_ddp_module
# Logger scoped to this module, named after ``__name__``.
logger = logging.getLogger(name=__name__)
def trace_model(
    model: Model,
    batch: Union[Tuple[torch.Tensor, ...], torch.Tensor],
    method_name: str = "forward",
) -> jit.ScriptModule:
    """Trace a model's method with TorchScript using an example batch.

    The model is passed through ``get_nn_from_ddp_module`` and wrapped in
    ``ModelForwardWrapper`` so that ``method_name`` (not necessarily
    ``forward``) is used as the tracing entrypoint. The wrapped module is
    then traced with ``torch.jit.trace``.

    Args:
        model: Model to trace.
        batch: Example input used for tracing. Either a single tensor, or a
            tuple of tensors. ``torch.jit.trace`` unpacks a tuple into
            positional arguments of the traced method.
        method_name: Model's method name that will be used as entrypoint
            during tracing. Defaults to ``"forward"``.

    Returns:
        jit.ScriptModule: Traced model.

    Example:
        .. code-block:: python

            import torch
            from catalyst.utils import trace_model

            class LinModel(torch.nn.Module):
                def __init__(self):
                    super().__init__()
                    self.lin1 = torch.nn.Linear(10, 10)
                    self.lin2 = torch.nn.Linear(2, 10)

                def forward(self, inp_1, inp_2):
                    return self.lin1(inp_1), self.lin2(inp_2)

                def first_only(self, inp_1):
                    return self.lin1(inp_1)

            lin_model = LinModel()

            # Trace a non-``forward`` method with a single-tensor batch.
            traced_first = trace_model(
                lin_model, batch=torch.randn(1, 10), method_name="first_only"
            )

            # Trace ``forward`` with a tuple batch (one tensor per argument).
            traced_forward = trace_model(
                lin_model, batch=(torch.randn(1, 10), torch.randn(1, 2))
            )
    """
    # NOTE(review): presumably strips a DistributedDataParallel/DataParallel
    # container so the underlying network is traced -- confirm against
    # catalyst.utils.torch.get_nn_from_ddp_module.
    nn_model = get_nn_from_ddp_module(model)
    # torch.jit.trace traces an nn.Module's ``forward``; the wrapper is what
    # lets ``method_name`` act as the entrypoint instead.
    wrapped_model = ModelForwardWrapper(model=nn_model, method_name=method_name)
    return jit.trace(wrapped_model, example_inputs=batch)
# Public API of this module: ``from catalyst.utils.tracing import *``
# exports only ``trace_model``.
__all__ = ["trace_model"]