Source code for catalyst.runners.supervised

from typing import Any, List, Mapping, Tuple, Union

import torch

from catalyst.core.runner import IRunner


class ISupervisedRunner(IRunner):
    """IRunner for experiments with a supervised model.

    Args:
        input_key: key in ``runner.batch`` dict mapping for model input
        output_key: key for ``runner.batch`` to store model output
        target_key: key in ``runner.batch`` dict mapping for target
        loss_key: key for ``runner.batch_metrics`` to store criterion loss output

    This class is an abstraction; please check out the implementations
    for more details:

        - :py:mod:`catalyst.runners.runner.SupervisedRunner`
        - :py:mod:`catalyst.runners.config.SupervisedConfigRunner`
        - :py:mod:`catalyst.runners.hydra.SupervisedHydraRunner`

    .. note::
        ISupervisedRunner contains only the logic for batch handling.

    ISupervisedRunner logic pseudocode:

    .. code-block:: python

        batch = {"input_key": tensor, "target_key": tensor}
        output = model(batch["input_key"])
        batch["output_key"] = output
        loss = criterion(batch["output_key"], batch["target_key"])
        batch_metrics["loss_key"] = loss

    .. note::
        Please follow the `minimal examples`_ sections for use cases.

        .. _`minimal examples`: https://github.com/catalyst-team/catalyst#minimal-examples

    Examples:

    .. code-block:: python

        import os
        from torch import nn, optim
        from torch.utils.data import DataLoader
        from catalyst import dl, utils
        from catalyst.data.transforms import ToTensor
        from catalyst.contrib.datasets import MNIST

        model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(model.parameters(), lr=0.02)

        loaders = {
            "train": DataLoader(
                MNIST(os.getcwd(), train=True, download=True, transform=ToTensor()),
                batch_size=32,
            ),
            "valid": DataLoader(
                MNIST(os.getcwd(), train=False, download=True, transform=ToTensor()),
                batch_size=32,
            ),
        }

        runner = dl.SupervisedRunner(
            input_key="features", output_key="logits", target_key="targets", loss_key="loss"
        )
        # model training
        runner.train(
            model=model,
            criterion=criterion,
            optimizer=optimizer,
            loaders=loaders,
            num_epochs=1,
            callbacks=[
                dl.AccuracyCallback(input_key="logits", target_key="targets", topk_args=(1, 3)),
                dl.PrecisionRecallF1SupportCallback(
                    input_key="logits", target_key="targets", num_classes=10
                ),
                dl.AUCCallback(input_key="logits", target_key="targets"),
            ],
            logdir="./logs",
            valid_loader="valid",
            valid_metric="loss",
            minimize_valid_metric=True,
            verbose=True,
            load_best_on_end=True,
        )
        # model inference
        for prediction in runner.predict_loader(loader=loaders["valid"]):
            assert prediction["logits"].detach().cpu().numpy().shape[-1] == 10
    """

    def __init__(
        self,
        input_key: Any = "features",
        output_key: Any = "logits",
        target_key: str = "targets",
        loss_key: str = "loss",
    ):
        """Init."""
        IRunner.__init__(self)
        self._input_key = input_key
        self._output_key = output_key
        self._target_key = target_key
        self._loss_key = loss_key

        if isinstance(self._input_key, str):
            # when model expects a single value
            self._process_input = self._process_input_str
        elif isinstance(self._input_key, (list, tuple)):
            # when model expects several named inputs
            self._process_input = self._process_input_list
        elif self._input_key is None:
            # when model expects a dict
            self._process_input = self._process_input_none
        else:
            raise NotImplementedError()

        if isinstance(self._output_key, str):
            # when model returns a single value
            self._process_output = self._process_output_str
        elif isinstance(self._output_key, (list, tuple)):
            # when model returns a tuple
            self._process_output = self._process_output_list
        elif self._output_key is None:
            # when model returns a dict
            self._process_output = self._process_output_none
        else:
            raise NotImplementedError()
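
    # A minimal sketch of how the dispatch above maps ``input_key`` to a call
    # convention (the ``user``/``item`` key names below are hypothetical,
    # chosen only for illustration):
    #
    #     ISupervisedRunner(input_key="features")
    #     # -> _process_input_str:  model(batch["features"])
    #
    #     ISupervisedRunner(input_key=["user", "item"])
    #     # -> _process_input_list: model(user=batch["user"], item=batch["item"])
    #
    #     ISupervisedRunner(input_key=None)
    #     # -> _process_input_none: model(**batch)
    #
    # ``output_key`` mirrors this on the output side: a str wraps a single
    # returned tensor into a dict, a list/tuple zips several outputs with the
    # given keys, and None assumes the model already returns a dict.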

    def _process_batch(self, batch):
        # DataLoaders commonly yield (inputs, targets) tuples;
        # normalize them into the dict format the runner works with
        if isinstance(batch, (tuple, list)):
            assert len(batch) == 2
            batch = {self._input_key: batch[0], self._target_key: batch[1]}
        return batch

    def _process_input_str(self, batch: Mapping[str, Any], **kwargs):
        # model expects a single value: model(batch[input_key])
        output = self.model(batch[self._input_key], **kwargs)
        return output

    def _process_input_list(self, batch: Mapping[str, Any], **kwargs):
        # model expects several named inputs: model(**{key: batch[key], ...})
        input = {key: batch[key] for key in self._input_key}  # noqa: WPS125
        output = self.model(**input, **kwargs)
        return output

    def _process_input_none(self, batch: Mapping[str, Any], **kwargs):
        # model expects the whole batch dict: model(**batch)
        output = self.model(**batch, **kwargs)
        return output

    def _process_output_str(self, output: torch.Tensor):
        # model returns a single tensor: wrap it into a dict
        output = {self._output_key: output}
        return output

    def _process_output_list(self, output: Union[Tuple, List]):
        # model returns several outputs: zip them with the output keys
        output = dict(zip(self._output_key, output))
        return output

    def _process_output_none(self, output: Mapping[str, Any]):
        # model already returns a dict: pass it through unchanged
        return output
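
    # A minimal sketch of the normalization done by ``_process_batch``,
    # assuming the default ``input_key="features"`` / ``target_key="targets"``:
    # a plain (inputs, targets) tuple from a torch DataLoader becomes a dict,
    # while dict batches pass through unchanged:
    #
    #     batch = (torch.rand(32, 1, 28, 28), torch.randint(0, 10, (32,)))
    #     # after _process_batch:
    #     # {"features": <tensor 32x1x28x28>, "targets": <tensor 32>}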

    def forward(self, batch: Mapping[str, Any], **kwargs) -> Mapping[str, Any]:
        """
        Forward method for your Runner.
        Should not be called directly outside of the runner.
        If your model has a specific interface, override this method to use it.

        Args:
            batch (Mapping[str, Any]): dictionary with data batches
                from DataLoaders.
            **kwargs: additional parameters to pass to the model

        Returns:
            dict with model output batch
        """
        output = self._process_input(batch, **kwargs)
        output = self._process_output(output)
        return output
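
    # A minimal sketch of overriding ``forward`` for a model with a custom
    # interface, as the docstring suggests (``SiameseRunner``, ``left`` and
    # ``right`` are hypothetical names, not part of the catalyst API):
    #
    #     class SiameseRunner(ISupervisedRunner):
    #         def forward(self, batch, **kwargs):
    #             # the model takes two tensors positionally instead of one
    #             output = self.model(batch["left"], batch["right"], **kwargs)
    #             return {self._output_key: output}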

    def on_batch_start(self, runner: "IRunner"):
        """Event handler."""
        self.batch = self._process_batch(self.batch)
        super().on_batch_start(runner)

    def handle_batch(self, batch: Mapping[str, Any]) -> None:
        """
        Inner method to handle the specified data batch.
        Used to run a train/valid/infer stage during an Experiment run.

        Args:
            batch: dictionary with data batches from DataLoader.
        """
        self.batch = {**batch, **self.forward(batch)}
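
    # A minimal sketch of the runner state after ``handle_batch`` under the
    # default keys: the forward output is merged on top of the incoming batch,
    # so downstream callbacks (criterion, metrics) can read both by key:
    #
    #     # before: self.batch == {"features": ..., "targets": ...}
    #     # after:  self.batch == {"features": ..., "targets": ..., "logits": ...}
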
__all__ = ["ISupervisedRunner"]