# Source code for catalyst.runners.supervised
from typing import Any, List, Mapping, Tuple, Union
import torch
from catalyst.core.runner import IRunner
class ISupervisedRunner(IRunner):
    """IRunner for experiments with supervised model.

    Args:
        input_key: key in ``runner.batch`` dict mapping for model input.
            A ``str`` passes ``batch[input_key]`` as the single positional
            model argument; a ``list``/``tuple`` of keys passes the matching
            batch entries to the model as keyword arguments; ``None`` passes
            the whole batch to the model as keyword arguments.
        output_key: key for ``runner.batch`` to store model output.
            A ``str`` stores the model output under that key; a
            ``list``/``tuple`` of keys is zipped with the model's (sequence)
            output; ``None`` uses the model output as-is (expected to be a
            mapping).
        target_key: key in ``runner.batch`` dict mapping for target
        loss_key: key for ``runner.batch_metrics`` to store criterion loss output

    Raises:
        NotImplementedError: if ``input_key`` or ``output_key`` is not a
            ``str``, ``list``, ``tuple`` or ``None``.
    """

    def __init__(
        self,
        input_key: Any = "features",
        output_key: Any = "logits",
        target_key: str = "targets",
        loss_key: str = "loss",
    ):
        """Init."""
        # Explicit base-class init (not ``super()``), kept as in the original.
        IRunner.__init__(self)
        self._input_key = input_key
        self._output_key = output_key
        self._target_key = target_key
        self._loss_key = loss_key

        # Pick the model-input strategy once, based on the type of input_key.
        if isinstance(self._input_key, str):
            # when model expects value: model(batch[input_key])
            self._process_input = self._process_input_str
        elif isinstance(self._input_key, (list, tuple)):
            # when model expects several named inputs: model(**{key: batch[key]})
            self._process_input = self._process_input_list
        elif self._input_key is None:
            # when model expects dict: model(**batch)
            self._process_input = self._process_input_none
        else:
            raise NotImplementedError(
                "`input_key` must be a str, list, tuple or None, got {}".format(
                    type(self._input_key).__name__
                )
            )

        # Pick the model-output strategy once, based on the type of output_key.
        if isinstance(self._output_key, str):
            # when model returns value
            self._process_output = self._process_output_str
        elif isinstance(self._output_key, (list, tuple)):
            # when model returns tuple
            self._process_output = self._process_output_list
        elif self._output_key is None:
            # when model returns dict
            self._process_output = self._process_output_none
        else:
            raise NotImplementedError(
                "`output_key` must be a str, list, tuple or None, got {}".format(
                    type(self._output_key).__name__
                )
            )

    def _process_batch(self, batch):
        """Turn an ``(input, target)`` pair into a dict; pass other batches through.

        Args:
            batch: batch from the loader; either a 2-element ``tuple``/``list``
                or anything else (returned unchanged).

        Returns:
            ``{input_key: batch[0], target_key: batch[1]}`` for a tuple/list
            batch, otherwise ``batch`` itself.

        Raises:
            ValueError: if a tuple/list batch does not have exactly 2 elements.
        """
        if isinstance(batch, (tuple, list)):
            # Explicit check instead of ``assert`` so it is not stripped by ``python -O``.
            if len(batch) != 2:
                raise ValueError(
                    "tuple/list batch must have exactly 2 elements "
                    "(input, target), got {}".format(len(batch))
                )
            # NOTE(review): the dict is keyed by ``input_key`` itself, so only a
            # ``str`` input_key produces a batch the default input handlers can use.
            batch = {self._input_key: batch[0], self._target_key: batch[1]}
        return batch

    def _process_input_str(self, batch: Mapping[str, Any], **kwargs):
        """Call the model with ``batch[input_key]`` as its positional argument."""
        output = self.model(batch[self._input_key], **kwargs)
        return output

    def _process_input_list(self, batch: Mapping[str, Any], **kwargs):
        """Call the model with the batch entries named by ``input_key`` as kwargs."""
        model_kwargs = {key: batch[key] for key in self._input_key}
        output = self.model(**model_kwargs, **kwargs)
        return output

    def _process_input_none(self, batch: Mapping[str, Any], **kwargs):
        """Call the model with the whole batch unpacked as keyword arguments."""
        output = self.model(**batch, **kwargs)
        return output

    def _process_output_str(self, output: torch.Tensor):
        """Wrap a single model output as ``{output_key: output}``."""
        output = {self._output_key: output}
        return output

    def _process_output_list(self, output: Union[Tuple, List]):
        """Pair each key in ``output_key`` with the matching model output.

        ``zip`` stops at the shorter of the two sequences, so extra outputs
        (or extra keys) are silently dropped.
        """
        output = dict(zip(self._output_key, output))
        return output

    def _process_output_none(self, output: Mapping[str, Any]):
        """Return the model output unchanged (expected to be a mapping)."""
        return output

    def forward(self, batch: Mapping[str, Any], **kwargs) -> Mapping[str, Any]:
        """
        Forward method for your Runner.
        Should not be called directly outside of runner.
        If your model has specific interface, override this method to use it

        Args:
            batch (Mapping[str, Any]): dictionary with data batches
                from DataLoaders.
            **kwargs: additional parameters to pass to the model

        Returns:
            dict with model output batch
        """
        output = self._process_input(batch, **kwargs)
        output = self._process_output(output)
        return output

    def on_batch_start(self, runner: "IRunner"):
        """Event handler.

        Normalizes ``self.batch`` with ``_process_batch`` (tuple/list pairs
        become a dict) and then delegates to the parent handler.
        """
        self.batch = self._process_batch(self.batch)
        super().on_batch_start(runner)

    def handle_batch(self, batch: Mapping[str, Any]) -> None:
        """
        Inner method to handle specified data batch.
        Used to make a train/valid/infer stage during Experiment run.

        Args:
            batch: dictionary with data batches from DataLoader.
        """
        # Model outputs are merged over the input batch; on key clashes the
        # model output wins.
        self.batch = {**batch, **self.forward(batch)}
# Public names exported by this module on star-import.
__all__ = [
    "ISupervisedRunner",
]