
# Source code for catalyst.runners.supervised

from typing import Any, List, Mapping, Tuple, Union
from collections import OrderedDict

import torch

from catalyst.callbacks.backward import BackwardCallback
from catalyst.callbacks.criterion import CriterionCallback
from catalyst.callbacks.optimizer import OptimizerCallback
from catalyst.callbacks.scheduler import SchedulerCallback
from catalyst.core.callback import (
    Callback,
    IBackwardCallback,
    ICriterionCallback,
    IOptimizerCallback,
    ISchedulerCallback,
)
from catalyst.core.engine import Engine
from catalyst.core.misc import callback_isinstance, sort_callbacks_by_order
from catalyst.core.runner import IRunner
from catalyst.runners.runner import Runner
from catalyst.typing import RunnerModel, TorchCriterion, TorchOptimizer, TorchScheduler

class ISupervisedRunner(Runner):
    """IRunner for experiments with supervised model.

    Args:
        input_key: key in ``runner.batch`` dict mapping for model input
        output_key: key for ``runner.batch`` to store model output
        target_key: key in ``runner.batch`` dict mapping for target
        loss_key: key for ``runner.batch_metrics`` to store criterion loss output

    Abstraction, please check out implementations for more details:

        - :py:mod:`catalyst.runners.runner.SupervisedRunner`

    .. note::
        ISupervisedRunner contains only the logic with batch handling.

    ISupervisedRunner logic pseudocode:

    .. code-block:: python

        batch = {"input_key": tensor, "target_key": tensor}
        output = model(batch["input_key"])
        batch["output_key"] = output
        loss = criterion(batch["output_key"], batch["target_key"])
        batch_metrics["loss_key"] = loss

    .. note::
        Please follow the `minimal examples`_ sections for use cases.

        .. _`minimal examples`: https://github.com/catalyst-team/catalyst#minimal-examples  # noqa: E501, W505
    """

    def __init__(
        self,
        input_key: Any = "features",
        output_key: Any = "logits",
        target_key: str = "targets",
        loss_key: str = "loss",
    ):
        """Store the batch keys and pick the input/output adapters.

        Args:
            input_key: ``str`` (model gets a single value), ``list``/``tuple``
                (model gets the selected keys as keyword arguments) or
                ``None`` (model gets the whole batch as keyword arguments)
            output_key: ``str`` (model returns a single value),
                ``list``/``tuple`` (model returns a sequence, zipped with
                these keys) or ``None`` (model output is used as is)
            target_key: key in the batch dict that holds the target
            loss_key: key used for the criterion loss

        Raises:
            NotImplementedError: if ``input_key`` or ``output_key`` is not a
                ``str``, ``list``, ``tuple`` or ``None``
        """
        # The IRunner base is initialised directly rather than via super().
        IRunner.__init__(self)

        self._input_key = input_key
        self._output_key = output_key
        self._target_key = target_key
        self._loss_key = loss_key

        if isinstance(self._input_key, str):
            # when model expects value
            self._process_input = self._process_input_str
        elif isinstance(self._input_key, (list, tuple)):
            # when model expects tuple
            self._process_input = self._process_input_list
        elif self._input_key is None:
            # when model expects dict
            self._process_input = self._process_input_none
        else:
            raise NotImplementedError(
                "input_key must be a str, list, tuple or None, "
                f"got {type(self._input_key).__name__}"
            )

        if isinstance(self._output_key, str):
            # when model returns value
            self._process_output = self._process_output_str
        elif isinstance(self._output_key, (list, tuple)):
            # when model returns tuple
            self._process_output = self._process_output_list
        elif self._output_key is None:
            # when model returns dict
            self._process_output = self._process_output_none
        else:
            raise NotImplementedError(
                "output_key must be a str, list, tuple or None, "
                f"got {type(self._output_key).__name__}"
            )

    def _process_batch(self, batch):
        """Turn a two-element tuple/list batch into a dict batch.

        Any other batch is returned unchanged.
        """
        if isinstance(batch, (tuple, list)):
            assert len(batch) == 2, (
                "tuple/list batch must contain exactly 2 elements (input, target), "
                f"got {len(batch)}"
            )
            batch = {self._input_key: batch[0], self._target_key: batch[1]}
        return batch

    def _process_input_str(self, batch: Mapping[str, Any], **kwargs):
        """Call the model with the single value stored under the input key."""
        return self.model(batch[self._input_key], **kwargs)

    def _process_input_list(self, batch: Mapping[str, Any], **kwargs):
        """Call the model with the selected batch entries as keyword arguments."""
        model_inputs = {key: batch[key] for key in self._input_key}
        return self.model(**model_inputs, **kwargs)

    def _process_input_none(self, batch: Mapping[str, Any], **kwargs):
        """Call the model with the whole batch as keyword arguments."""
        return self.model(**batch, **kwargs)

    def _process_output_str(self, output: torch.Tensor):
        """Wrap a single model output into a dict under the output key."""
        return {self._output_key: output}

    def _process_output_list(self, output: Union[Tuple, List]):
        """Pair the output keys with the model outputs, position by position."""
        return dict(zip(self._output_key, output))

    def _process_output_none(self, output: Mapping[str, Any]):
        """Return the model output unchanged."""
        return output

    def forward(self, batch: Mapping[str, Any], **kwargs) -> Mapping[str, Any]:
        """
        Forward method for your Runner.
        Should not be called directly outside of runner.
        If your model has specific interface, override this method to use it

        Args:
            batch (Mapping[str, Any]): dictionary with data batches
                from DataLoaders.
            **kwargs: additional parameters to pass to the model

        Returns:
            dict with model output batch
        """
        output = self._process_input(batch, **kwargs)
        return self._process_output(output)

    def on_batch_start(self, runner: "IRunner"):
        """Event handler."""
        # Normalise the batch before the parent handler runs.
        self.batch = self._process_batch(self.batch)
        super().on_batch_start(runner)

    def handle_batch(self, batch: Mapping[str, Any]) -> None:
        """
        Inner method to handle specified data batch.
        Used to make a train/valid/infer step during Experiment run.

        Args:
            batch: dictionary with data batches from DataLoader.
        """
        # Model outputs are merged last, so they win on key collisions.
        self.batch = {**batch, **self.forward(batch)}
class SupervisedRunner(ISupervisedRunner, Runner):
    """Runner for experiments with supervised model.

    Args:
        model: Torch model instance
        engine: Engine instance
        input_key: key in ``runner.batch`` dict mapping for model input
        output_key: key for ``runner.batch`` to store model output
        target_key: key in ``runner.batch`` dict mapping for target
        loss_key: key for ``runner.batch_metrics`` to store criterion loss output

    SupervisedRunner logic pseudocode:

    .. code-block:: python

        batch = {"input_key": tensor, "target_key": tensor}
        output = model(batch["input_key"])
        batch["output_key"] = output
        loss = criterion(batch["output_key"], batch["target_key"])
        batch_metrics["loss_key"] = loss

    .. note::
        Please follow the `minimal examples`_ sections for use cases.

        .. _`minimal examples`: https://github.com/catalyst-team/catalyst#minimal-examples  # noqa: E501, W505
    """

    def __init__(
        self,
        model: RunnerModel = None,
        engine: Engine = None,
        input_key: Any = "features",
        output_key: Any = "logits",
        target_key: str = "targets",
        loss_key: str = "loss",
    ):
        """Init."""
        # Both bases are initialised explicitly, each with its own arguments.
        ISupervisedRunner.__init__(
            self,
            input_key=input_key,
            output_key=output_key,
            target_key=target_key,
            loss_key=loss_key,
        )
        Runner.__init__(self, model=model, engine=engine)

    @torch.no_grad()
    def predict_batch(self, batch: Mapping[str, Any], **kwargs) -> Mapping[str, Any]:
        """
        Run model inference on specified data batch.

        .. warning::
            You should not override this method.
            If you need specific model call,
            override runner.forward() method.

        Args:
            batch: dictionary with data batch from DataLoader.
            **kwargs: additional kwargs to pass to the model

        Returns:
            Mapping[str, Any]: model output dictionary
        """
        batch = self._process_batch(batch)
        return self.forward(batch, **kwargs)

    def get_callbacks(self) -> "OrderedDict[str, Callback]":
        """Returns the callbacks for the experiment.

        On top of the parent's callbacks, a default criterion, backward,
        optimizer or scheduler callback is added when the matching component
        is set on the runner and no callback of that kind is present yet.
        """
        callbacks = sort_callbacks_by_order(super().get_callbacks())

        def callback_exists(callback_fn) -> bool:
            # Reads ``callbacks`` live, so defaults added below are seen too.
            return any(callback_isinstance(x, callback_fn) for x in callbacks.values())

        if isinstance(self._criterion, TorchCriterion) and not callback_exists(
            ICriterionCallback
        ):
            callbacks["_criterion"] = CriterionCallback(
                input_key=self._output_key,
                target_key=self._target_key,
                metric_key=self._loss_key,
            )

        has_optimizer = isinstance(self._optimizer, TorchOptimizer)
        if has_optimizer and not callback_exists(IBackwardCallback):
            callbacks["_backward"] = BackwardCallback(metric_key=self._loss_key)
        if has_optimizer and not callback_exists(IOptimizerCallback):
            callbacks["_optimizer"] = OptimizerCallback(metric_key=self._loss_key)

        if isinstance(self._scheduler, TorchScheduler) and not callback_exists(
            ISchedulerCallback
        ):
            callbacks["_scheduler"] = SchedulerCallback(
                loader_key=self._valid_loader, metric_key=self._valid_metric
            )
        return callbacks
# Names exported by ``from catalyst.runners.supervised import *``.
__all__ = ["ISupervisedRunner", "SupervisedRunner"]