"""Source code for catalyst.runners.multi_supervised."""

from typing import Any, Callable, List, Mapping, Tuple, Union
import logging

import torch

from catalyst.experiments.auto import AutoCallbackExperiment
from catalyst.runners.runner import Runner
from catalyst.typing import Device, RunnerModel

# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)


class MultiSupervisedRunner(Runner):
    """Runner for experiments with several supervised models.

    ``self.model`` is iterated over and indexed by model name; each model
    gets its own input/output/target keys, configured via ``models_keys``.
    """

    def __init__(
        self,
        model: RunnerModel = None,
        device: Device = None,
        models_keys: Mapping[str, Any] = None,
        experiment_fn: Callable = AutoCallbackExperiment,
    ):
        """
        Args:
            model: (RunnerModel) Torch model object. ``forward`` iterates
                over it and indexes it by model name, so it should be a
                name -> model mapping.
            device: (Device) Torch device
            models_keys: (Mapping[str, Any]) per-model key configuration:
                maps each model name to a mapping with the entries
                ``"input_key"``, ``"output_key"`` and ``"target_key"``
                (keys in the batch dict for model input, output, target).
                A missing or ``None`` entry falls back to ``"features"``,
                ``"logits"`` and ``"targets"`` respectively. ``None`` for
                the whole argument is treated as an empty mapping.
            experiment_fn: callable function, which defines default
                experiment type to use during ``.train`` and ``.infer``
                methods.

        Raises:
            NotImplementedError: if a resolved input or output key is
                neither a ``str`` nor a ``list``/``tuple``.
        """
        super().__init__(
            model=model, device=device, experiment_fn=experiment_fn,
        )

        self.input_key = {}
        self.output_key = {}
        self.target_key = {}
        self._process_input = {}
        self._process_output = {}

        # ``models_keys`` defaults to None; iterate over nothing instead of
        # crashing with AttributeError on ``None.items()``.
        for model_name, model_keys in (models_keys or {}).items():
            input_key = self._resolve_key(model_keys, "input_key", "features")
            output_key = self._resolve_key(model_keys, "output_key", "logits")
            self.input_key[model_name] = input_key
            self.output_key[model_name] = output_key
            self.target_key[model_name] = self._resolve_key(
                model_keys, "target_key", "targets"
            )

            # NOTE(review): None is replaced by a default above, so the
            # dict-passing helpers ``_process_input_none`` and
            # ``_process_output_none`` are never selected by this method.
            if isinstance(input_key, str):
                # model expects a single value
                self._process_input[model_name] = self._process_input_str
            elif isinstance(input_key, (list, tuple)):
                # model expects several keyword arguments
                self._process_input[model_name] = self._process_input_list
            else:
                raise NotImplementedError(
                    f"Unsupported input_key for model {model_name!r}: "
                    f"{input_key!r}"
                )

            if isinstance(output_key, str):
                # model returns a single value
                self._process_output[model_name] = self._process_output_str
            elif isinstance(output_key, (list, tuple)):
                # model returns a tuple/list of values
                self._process_output[model_name] = self._process_output_list
            else:
                raise NotImplementedError(
                    f"Unsupported output_key for model {model_name!r}: "
                    f"{output_key!r}"
                )

    @staticmethod
    def _resolve_key(
        model_keys: Mapping[str, Any], name: str, default: Any
    ) -> Any:
        """Return ``model_keys[name]``, or ``default`` if missing or None."""
        value = model_keys.get(name)
        return default if value is None else value

    def _prepare_inner_state(self, *args, **kwargs):
        """Forward to the base runner, adding the experiment's logdir.

        ``logdir`` is None when no experiment is attached.
        """
        logdir = None if self.experiment is None else self.experiment.logdir
        super()._prepare_inner_state(*args, logdir=logdir, **kwargs)

    def _handle_device(self, batch: Mapping[str, Any]):
        """Normalize ``batch`` to a dict, then let the base runner move it.

        A ``(inputs, targets)`` tuple/list batch is expanded into a dict in
        which every model's input key maps to ``batch[0]`` and every
        target key maps to ``batch[1]`` (the first assignment of a key wins).
        """
        if isinstance(batch, (tuple, list)):
            assert len(batch) == 2
            batch_dict = {}
            # NOTE(review): a list-valued input key is unhashable and would
            # raise TypeError here; only str keys make sense for this path.
            for input_key in self.input_key.values():
                batch_dict.setdefault(input_key, batch[0])
            for target_key in self.target_key.values():
                batch_dict.setdefault(target_key, batch[1])
            batch = batch_dict
        batch = super()._handle_device(batch)
        return batch

    def _process_input_str(
        self, model_name: str, batch: Mapping[str, Any], **kwargs
    ):
        """Call the model with the single batch entry under its input key."""
        output = self.model[model_name](
            batch[self.input_key[model_name]], **kwargs
        )
        return output

    def _process_input_list(
        self, model_name: str, batch: Mapping[str, Any], **kwargs
    ):
        """Call the model with the listed batch entries as keyword args."""
        model_inputs = {key: batch[key] for key in self.input_key[model_name]}
        return self.model[model_name](**model_inputs, **kwargs)

    def _process_input_none(
        self, model_name: str, batch: Mapping[str, Any], **kwargs
    ):
        """Call the model with the whole batch dict as keyword args."""
        output = self.model[model_name](**batch, **kwargs)
        return output

    def _process_output_str(self, model_name: str, output: torch.Tensor):
        """Wrap a single model output in a dict under its output key."""
        return {self.output_key[model_name]: output}

    def _process_output_list(
        self, model_name: str, output: Union[Tuple, List]
    ):
        """Pair a sequence of model outputs with the model's output keys.

        Like ``zip``, extra outputs or extra keys are silently dropped.
        """
        return dict(zip(self.output_key[model_name], output))

    def _process_output_none(self, model_name: str, output: Mapping[str, Any]):
        """Return a dict output from the model unchanged."""
        return output

    def forward(self, batch: Mapping[str, Any], **kwargs) -> Mapping[str, Any]:
        """
        Forward method for your Runner.
        Should not be called directly outside of runner.
        If your model has specific interface, override this method to use it

        Args:
            batch (Mapping[str, Any]): dictionary with data batches
                from DataLoaders.
            **kwargs: additional parameters to pass to every model

        Returns:
            dict merging the output dicts of all models; if two models
            produce the same key, the later one (in ``self.model``
            iteration order) wins.
        """
        output = {}
        for model_name in self.model:
            # Keep each model's result separate from the accumulator; reusing
            # one name here used to discard every model but the last.
            model_output = self._process_input[model_name](
                model_name, batch, **kwargs
            )
            model_output = self._process_output[model_name](
                model_name, model_output
            )
            output.update(model_output)
        return output

    def _handle_batch(self, batch: Mapping[str, Any]) -> None:
        """
        Inner method to handle specified data batch.
        Used to make a train/valid/infer stage during Experiment run.

        Args:
            batch (Mapping[str, Any]): dictionary with data batches
                from DataLoader.
        """
        self.output = self.forward(batch)

    @torch.no_grad()
    def predict_batch(
        self, batch: Mapping[str, Any], **kwargs
    ) -> Mapping[str, Any]:
        """
        Run model inference on specified data batch.

        .. warning::
            You should not override this method. If you need specific model
            call, override forward() method

        Args:
            batch (Mapping[str, Any]): dictionary with data batches
                from DataLoader.
            **kwargs: additional kwargs to pass to the model

        Returns:
            Mapping[str, Any]: model output dictionary
        """
        batch = self._handle_device(batch)
        output = self.forward(batch, **kwargs)
        return output
# Public API of this module: only the runner class is exported.
__all__ = [
    "MultiSupervisedRunner",
]