# Source code for catalyst.tools.forward_wrapper
from torch import nn
class ModelForwardWrapper(nn.Module):
    """Model that calls the specified method instead of ``forward``.

    The wrapper's ``forward`` delegates to an arbitrary method of the
    wrapped model. This is a workaround for tracing, since single method
    tracing is not supported: trace the wrapper, and its ``forward``
    dispatches to ``method_name``.

    Args:
        model: object whose method should be exposed as ``forward``.
            When it is an ``nn.Module``, assigning it to ``self.model``
            registers it as a submodule, so its parameters and buffers
            are reachable from the wrapper.
        method_name: name of the attribute of ``model`` to call from
            ``forward``. It is not validated here; it is looked up on
            every call to ``forward``.
    """

    def __init__(self, model, method_name: str) -> None:
        """Store the wrapped model and the name of the method to call.

        Args:
            model: object to wrap (typically an ``nn.Module``).
            method_name: name of the method of ``model`` to call.
        """
        super().__init__()
        self.model = model
        self.method_name = method_name

    def forward(self, *args, **kwargs):
        """Forward pass: call ``model.<method_name>(*args, **kwargs)``.

        Args:
            *args: positional arguments passed through unchanged.
            **kwargs: keyword arguments passed through unchanged.

        Returns:
            Whatever the specified method returns.

        Raises:
            AttributeError: if ``model`` has no attribute ``method_name``.
        """
        # Looked up on each call rather than bound once in ``__init__``.
        return getattr(self.model, self.method_name)(*args, **kwargs)

    def extra_repr(self) -> str:
        """Show the forwarded method name when the module is printed."""
        return "method_name={!r}".format(self.method_name)
# Public names exported by ``from <module> import *``.
__all__ = [
    "ModelForwardWrapper",
]