Source code for catalyst.runners.supervised
from typing import Any, Callable, List, Mapping, Tuple, Union
import logging

import torch

from catalyst.experiments.auto import AutoCallbackExperiment
from catalyst.runners.runner import Runner
from catalyst.typing import Device, RunnerModel

logger = logging.getLogger(__name__)

class SupervisedRunner(Runner):
    """Runner for experiments with supervised model."""

    def __init__(
        self,
        model: RunnerModel = None,
        device: Device = None,
        input_key: Any = "features",
        output_key: Any = "logits",
        input_target_key: str = "targets",
        experiment_fn: Callable = AutoCallbackExperiment,
    ):
        """
        Args:
            model: Torch model object
            device: Torch device
            input_key: Key in batch dict mapping for model input
            output_key: Key under which the model output
                will be stored in the output dict
            input_target_key: Key in batch dict mapping for target
            experiment_fn: Callable that defines the default experiment
                type to use during the ``.train`` and ``.infer`` methods.
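
        Example:
            An illustrative sketch, not part of the original docstring;
            ``net`` is an assumed user-defined ``torch.nn.Module``::

                # str key: the tensor under ``input_key`` is passed
                # to the model positionally
                runner = SupervisedRunner(model=net, input_key="features")

                # list/tuple key: the listed batch entries are passed as
                # keyword arguments, so ``net.forward`` must accept
                # ``features`` and ``lengths`` parameters
                runner = SupervisedRunner(
                    model=net, input_key=["features", "lengths"]
                )

                # None: the whole batch dict is unpacked into the model
                runner = SupervisedRunner(model=net, input_key=None)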
        """
        super().__init__(
            model=model, device=device, experiment_fn=experiment_fn
        )
        self.input_key = input_key
        self.output_key = output_key
        self.target_key = input_target_key
        if isinstance(self.input_key, str):
            # when model expects a single value
            self._process_input = self._process_input_str
        elif isinstance(self.input_key, (list, tuple)):
            # when model expects several named values
            self._process_input = self._process_input_list
        elif self.input_key is None:
            # when model expects a dict
            self._process_input = self._process_input_none
        else:
            raise NotImplementedError(
                "input_key must be str, list/tuple, or None"
            )

        if isinstance(self.output_key, str):
            # when model returns a single value
            self._process_output = self._process_output_str
        elif isinstance(self.output_key, (list, tuple)):
            # when model returns a tuple
            self._process_output = self._process_output_list
        elif self.output_key is None:
            # when model returns a dict
            self._process_output = self._process_output_none
        else:
            raise NotImplementedError(
                "output_key must be str, list/tuple, or None"
            )

    def _process_input_str(self, batch: Mapping[str, Any], **kwargs):
        output = self.model(batch[self.input_key], **kwargs)
        return output

    def _process_input_list(self, batch: Mapping[str, Any], **kwargs):
        input = {key: batch[key] for key in self.input_key}  # noqa: WPS125
        output = self.model(**input, **kwargs)
        return output

    def _process_input_none(self, batch: Mapping[str, Any], **kwargs):
        output = self.model(**batch, **kwargs)
        return output

    def _process_output_str(self, output: torch.Tensor):
        output = {self.output_key: output}
        return output

    def _process_output_list(self, output: Union[Tuple, List]):
        output = dict(zip(self.output_key, output))
        return output

    def _process_output_none(self, output: Mapping[str, Any]):
        return output
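
    # Illustrative note (not part of the original source): with an assumed
    # two-headed model, the dispatch above works as follows:
    #
    #     runner = SupervisedRunner(
    #         model=net,
    #         input_key="features",
    #         output_key=["logits", "embeddings"],
    #     )
    #     # net(batch["features"]) returns (logits, embeddings);
    #     # ``_process_output_list`` repacks the tuple as
    #     # {"logits": logits, "embeddings": embeddings}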

    def forward(self, batch: Mapping[str, Any], **kwargs) -> Mapping[str, Any]:
        """
        Forward method for your Runner.
        Should not be called directly outside of runner.
        If your model has specific interface, override this method to use it
        Args:
            batch (Mapping[str, Any]): dictionary with data batches
                from DataLoaders.
            **kwargs: additional parameters to pass to the model
        Returns:
            dict with model output batch
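
        Example:
            An illustrative override, a sketch rather than library code;
            assumes a hypothetical model called as ``net(images, masks)``::

                class TwoInputRunner(SupervisedRunner):
                    def forward(self, batch, **kwargs):
                        output = self.model(
                            batch["images"], batch["masks"], **kwargs
                        )
                        return {"logits": output}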
        """
        output = self._process_input(batch, **kwargs)
        output = self._process_output(output)
        return output 

    def _handle_device(self, batch: Mapping[str, Any]):
        if isinstance(batch, (tuple, list)):
            assert len(batch) == 2, "tuple batches must be (input, target)"
            batch = {self.input_key: batch[0], self.target_key: batch[1]}
        batch = super()._handle_device(batch)
        return batch
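
    # Illustrative note (not part of the original source): tuple batches,
    # e.g. from a plain ``TensorDataset`` loader, are first converted to
    # the dict format the runner works with:
    #
    #     loader = DataLoader(TensorDataset(X, y), batch_size=32)
    #     # each (X_batch, y_batch) tuple becomes
    #     # {self.input_key: X_batch, self.target_key: y_batch}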

    def _handle_batch(self, batch: Mapping[str, Any]) -> None:
        """
        Inner method to handle the specified data batch.
        Used to run the train/valid/infer steps during an Experiment run.

        Args:
            batch (Mapping[str, Any]): dictionary with data batches
                from DataLoader.
        """
        self.output = self.forward(batch)

    @torch.no_grad()
    def predict_batch(
        self, batch: Mapping[str, Any], **kwargs
    ) -> Mapping[str, Any]:
        """
        Run model inference on specified data batch.
        .. warning::
            You should not override this method. If you need specific model
            call, override forward() method
        Args:
            batch (Mapping[str, Any]): dictionary with data batches
                from DataLoader.
            **kwargs: additional kwargs to pass to the model
        Returns:
            Mapping[str, Any]: model output dictionary
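
        Example:
            Illustrative usage, a sketch; ``net`` is an assumed model with
            a ten-dimensional input::

                runner = SupervisedRunner(model=net)
                batch = {"features": torch.randn(4, 10)}
                output = runner.predict_batch(batch)
                logits = output["logits"]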
        """
        batch = self._handle_device(batch)
        output = self.forward(batch, **kwargs)
        return output


__all__ = ["SupervisedRunner"]
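
# A minimal end-to-end sketch (illustrative, not part of the module;
# ``net`` and the tensors are assumed user-defined):
#
#     import torch
#     from torch.utils.data import DataLoader, TensorDataset
#
#     loaders = {
#         "train": DataLoader(TensorDataset(X_train, y_train), batch_size=32),
#         "valid": DataLoader(TensorDataset(X_valid, y_valid), batch_size=32),
#     }
#     runner = SupervisedRunner(model=net)
#     runner.train(
#         model=net,
#         criterion=torch.nn.CrossEntropyLoss(),
#         optimizer=torch.optim.Adam(net.parameters()),
#         loaders=loaders,
#         num_epochs=3,
#     )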