Source code for catalyst.dl.runner.supervised

from typing import (  # isort:skip
    Any, Callable, Dict, List, Mapping, Tuple, Union  # isort:skip
)  # isort:skip
from collections import OrderedDict
import logging
from pathlib import Path

import torch
from torch.jit import ScriptModule
from torch.utils.data import DataLoader

from catalyst.dl import (
    Callback, CheckpointCallback, InferCallback, Runner, State,
    SupervisedExperiment, utils
)
from catalyst.dl.utils import trace
from catalyst.utils.tools.typing import (
    Criterion, Device, Model, Optimizer, Scheduler
)

logger = logging.getLogger(__name__)


class SupervisedRunner(Runner):
    """
    Runner for experiments with supervised model
    """
    _experiment_fn: Callable = SupervisedExperiment

    def __init__(
        self,
        model: Model = None,
        device: Device = None,
        input_key: Any = "features",
        output_key: Any = "logits",
        input_target_key: str = "targets",
    ):
        """
        Args:
            model (Model): Torch model object
            device (Device): Torch device
            input_key (Any): key in the batch dict mapping for model input
            output_key (Any): key in the output dict under which
                the model output will be stored
            input_target_key (str): key in the batch dict mapping for target
        """
        super().__init__(model=model, device=device)

        self.input_key = input_key
        self.output_key = output_key
        self.target_key = input_target_key

        if isinstance(self.input_key, str):
            # when model expects a single value
            self._process_input = self._process_input_str
        elif isinstance(self.input_key, (list, tuple)):
            # when model expects several named values
            self._process_input = self._process_input_list
        elif self.input_key is None:
            # when model expects a dict
            self._process_input = self._process_input_none
        else:
            raise NotImplementedError()

        if isinstance(output_key, str):
            # when model returns a single value
            self._process_output = self._process_output_str
        elif isinstance(output_key, (list, tuple)):
            # when model returns a tuple
            self._process_output = self._process_output_list
        elif self.output_key is None:
            # when model returns a dict
            self._process_output = self._process_output_none
        else:
            raise NotImplementedError()

    def _init(self):
        self.experiment: SupervisedExperiment = None
        self.state: State = None

    def _batch2device(self, batch: Mapping[str, Any], device: Device):
        if isinstance(batch, (tuple, list)):
            assert len(batch) == 2
            batch = {self.input_key: batch[0], self.target_key: batch[1]}
        batch = super()._batch2device(batch, device)
        return batch

    def _process_input_str(self, batch: Mapping[str, Any], **kwargs):
        output = self.model(batch[self.input_key], **kwargs)
        return output

    def _process_input_list(self, batch: Mapping[str, Any], **kwargs):
        input = {key: batch[key] for key in self.input_key}
        output = self.model(**input, **kwargs)
        return output

    def _process_input_none(self, batch: Mapping[str, Any], **kwargs):
        output = self.model(**batch, **kwargs)
        return output

    def _process_output_str(self, output: torch.Tensor):
        output = {self.output_key: output}
        return output

    def _process_output_list(self, output: Union[Tuple, List]):
        output = dict(zip(self.output_key, output))
        return output

    def _process_output_none(self, output: Mapping[str, Any]):
        return output

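    # Illustrative sketch (not part of the original source): how the
    # ``input_key``/``output_key`` settings dispatch to the helpers above.
    # ``net`` and the batch keys are assumed names for the example.
    #
    #   runner = SupervisedRunner(model=net)  # input_key="features"
    #   # batch == {"features": x, "targets": y}
    #   # forward pass: net(batch["features"]) -> {"logits": ...}
    #
    #   runner = SupervisedRunner(model=net, input_key=["image", "mask"])
    #   # forward pass: net(image=batch["image"], mask=batch["mask"])
    #
    #   runner = SupervisedRunner(model=net, input_key=None, output_key=None)
    #   # forward pass: net(**batch); the output dict is passed through as-is
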
    def forward(self, batch, **kwargs):
        """
        Should not be called directly outside of runner.
        If your model has a specific interface,
        override this method to use it.
        """
        output = self._process_input(batch, **kwargs)
        output = self._process_output(output)
        return output

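    # Sketch of the override suggested by the docstring above
    # (hypothetical model interface; not part of the original source):
    #
    #   class Seq2SeqRunner(SupervisedRunner):
    #       def forward(self, batch, **kwargs):
    #           # model with an encoder/decoder-style signature
    #           output = self.model(batch["source"], batch["target_prefix"])
    #           return {self.output_key: output}
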
    def train(
        self,
        model: Model,
        criterion: Criterion,
        optimizer: Optimizer,
        loaders: "OrderedDict[str, DataLoader]",
        logdir: str,
        callbacks: "Union[List[Callback], OrderedDict[str, Callback]]" = None,
        scheduler: Scheduler = None,
        resume: str = None,
        num_epochs: int = 1,
        valid_loader: str = "valid",
        main_metric: str = "loss",
        minimize_metric: bool = True,
        verbose: bool = False,
        state_kwargs: Dict = None,
        checkpoint_data: Dict = None,
        fp16: Union[Dict, bool] = None,
        monitoring_params: Dict = None,
        check: bool = False,
    ) -> None:
        """
        Starts the training process of the model.

        Args:
            model (Model): model to train
            criterion (Criterion): criterion function for training
            optimizer (Optimizer): optimizer for training
            loaders (dict): dictionary containing one or several
                ``torch.utils.data.DataLoader`` for training and validation
            logdir (str): path to output directory
            callbacks (List[catalyst.dl.Callback]): list of callbacks
            scheduler (Scheduler): scheduler for training
            resume (str): path to checkpoint for model
            num_epochs (int): number of training epochs
            valid_loader (str): loader name used to calculate
                the metrics and save the checkpoints. For example,
                you can pass `train` and then
                the metrics will be taken from `train` loader.
            main_metric (str): the key to the name of the metric
                by which the checkpoints will be selected.
            minimize_metric (bool): flag to indicate whether
                the ``main_metric`` should be minimized.
            verbose (bool): if True, it displays the status of the training
                to the console.
            state_kwargs (dict): additional state params to ``State``
            checkpoint_data (dict): additional data to save in checkpoint,
                for example: ``class_names``, ``date_of_training``, etc
            fp16 (Union[Dict, bool]): If not None, then sets training to FP16.
                See https://nvidia.github.io/apex/amp.html#properties
                if fp16=True, params by default will be ``{"opt_level": "O1"}``
            monitoring_params (dict): If not None, then create monitoring
                through Alchemy or Weights&Biases.
                For example,
                ``{"token": "api_token", "experiment": "experiment_name"}``
            check (bool): if True, then only checks that pipeline is working
                (3 epochs only)
        """
        if len(loaders) == 1:
            valid_loader = list(loaders.keys())[0]
            logger.warning(
                "Attention, there is only one data loader - "
                + str(valid_loader)
            )
        if isinstance(fp16, bool) and fp16:
            fp16 = {"opt_level": "O1"}

        if model is not None:
            self.model = model

        if resume is not None:
            callbacks = utils.process_callbacks(callbacks)
            checkpoint_callback_flag = any(
                isinstance(x, CheckpointCallback) for x in callbacks.values()
            )
            if not checkpoint_callback_flag:
                callbacks["loader"] = CheckpointCallback(resume=resume)
            else:
                raise NotImplementedError("CheckpointCallback already exists")

        experiment = self._experiment_fn(
            stage="train",
            model=model,
            loaders=loaders,
            callbacks=callbacks,
            logdir=logdir,
            criterion=criterion,
            optimizer=optimizer,
            scheduler=scheduler,
            num_epochs=num_epochs,
            valid_loader=valid_loader,
            main_metric=main_metric,
            minimize_metric=minimize_metric,
            verbose=verbose,
            check_run=check,
            state_kwargs=state_kwargs,
            checkpoint_data=checkpoint_data,
            distributed_params=fp16,
            monitoring_params=monitoring_params,
        )
        self.run_experiment(experiment)

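    # Minimal end-to-end sketch for ``train`` (toy model and data are
    # assumptions for the example; not part of the original source):
    #
    #   import torch
    #   from torch.utils.data import DataLoader, TensorDataset
    #
    #   model = torch.nn.Linear(28 * 28, 10)
    #   dataset = TensorDataset(
    #       torch.randn(64, 28 * 28), torch.randint(0, 10, (64,))
    #   )
    #   loaders = OrderedDict(
    #       train=DataLoader(dataset, batch_size=32),
    #       valid=DataLoader(dataset, batch_size=32),
    #   )
    #
    #   runner = SupervisedRunner()
    #   runner.train(
    #       model=model,
    #       criterion=torch.nn.CrossEntropyLoss(),
    #       optimizer=torch.optim.Adam(model.parameters()),
    #       loaders=loaders,
    #       logdir="./logs",
    #       num_epochs=2,
    #   )
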
    def infer(
        self,
        model: Model,
        loaders: "OrderedDict[str, DataLoader]",
        callbacks: "Union[List[Callback], OrderedDict[str, Callback]]" = None,
        verbose: bool = False,
        state_kwargs: Dict = None,
        fp16: Union[Dict, bool] = None,
        check: bool = False,
    ) -> None:
        """
        Makes the inference on the model.

        Args:
            model (Model): model to infer
            loaders (dict): dictionary containing one or several
                ``torch.utils.data.DataLoader`` for inference
            callbacks (List[catalyst.dl.Callback]): list of inference
                callbacks
            verbose (bool): if True, it displays the status of the inference
                to the console.
            state_kwargs (dict): additional state params to ``State``
            fp16 (Union[Dict, bool]): If not None, then sets inference
                to FP16.
                See https://nvidia.github.io/apex/amp.html#properties
                if fp16=True, params by default will be ``{"opt_level": "O1"}``
            check (bool): if True, then only checks that pipeline is working
                (3 epochs only)
        """
        if isinstance(fp16, bool) and fp16:
            fp16 = {"opt_level": "O1"}

        if model is not None:
            self.model = model

        experiment = self._experiment_fn(
            stage="infer",
            model=model,
            loaders=loaders,
            callbacks=callbacks,
            verbose=verbose,
            check_run=check,
            state_kwargs=state_kwargs,
            distributed_params=fp16,
        )
        self.run_experiment(experiment)

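    # Usage sketch for ``infer`` (assumed names; not part of the original
    # source). An ``InferCallback`` accumulates the raw predictions:
    #
    #   callbacks = OrderedDict(inference=InferCallback())
    #   runner.infer(
    #       model=model,
    #       loaders=OrderedDict(infer=infer_loader),
    #       callbacks=callbacks,
    #   )
    #   logits = callbacks["inference"].predictions["logits"]
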
    def predict_loader(
        self,
        model: Model,
        loader: DataLoader,
        resume: str = None,
        verbose: bool = False,
        state_kwargs: Dict = None,
        fp16: Union[Dict, bool] = None,
        check: bool = False,
    ) -> Any:
        """
        Makes a prediction on the whole loader with the specified model.

        Args:
            model (Model): model to infer
            loader (DataLoader): a single ``torch.utils.data.DataLoader``
                for inference
            resume (str): path to checkpoint for model
            verbose (bool): if True, it displays the status of the inference
                to the console.
            state_kwargs (dict): additional state params to ``State``
            fp16 (Union[Dict, bool]): If not None, then sets inference
                to FP16.
                See https://nvidia.github.io/apex/amp.html#properties
                if fp16=True, params by default will be ``{"opt_level": "O1"}``
            check (bool): if True, then only checks that pipeline is working
                (3 epochs only)

        Returns:
            predictions gathered by ``InferCallback``; if ``output_key``
            is a string, only the predictions stored under that key
        """
        loaders = OrderedDict([("infer", loader)])

        callbacks = OrderedDict([("inference", InferCallback())])
        if resume is not None:
            callbacks["loader"] = CheckpointCallback(resume=resume)

        self.infer(
            model=model,
            loaders=loaders,
            callbacks=callbacks,
            verbose=verbose,
            state_kwargs=state_kwargs,
            fp16=fp16,
            check=check,
        )

        output = callbacks["inference"].predictions
        if isinstance(self.output_key, str):
            output = output[self.output_key]

        return output

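    # Usage sketch for ``predict_loader`` (assumed names; not part of the
    # original source). With the default ``output_key="logits"`` the result
    # is the stacked logits for the whole loader:
    #
    #   logits = runner.predict_loader(
    #       model=model,
    #       loader=infer_loader,
    #       resume="./logs/checkpoints/best.pth",
    #   )
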
    def trace(
        self,
        model: Model = None,
        batch=None,
        logdir: str = None,
        loader: DataLoader = None,
        method_name: str = "forward",
        mode: str = "eval",
        requires_grad: bool = False,
        fp16: Union[Dict, bool] = None,
        device: Device = "cpu",
        predict_params: dict = None,
    ) -> ScriptModule:
        """
        Traces model using Torch Jit.

        Args:
            model (Model): model to trace
            batch: batch to forward through the model to trace
            logdir (str, optional): if specified, the result will be
                written to the directory
            loader (DataLoader, optional): if batch is not specified, the
                batch will be ``next(iter(loader))``
            method_name (str): model's method name that will be traced
            mode (str): ``train`` or ``eval``
            requires_grad (bool): flag to trace with gradients
            fp16 (Union[Dict, bool]): If not None, then sets tracing params
                to FP16
            device (Device): Torch device or a string
            predict_params (dict): additional parameters for model forward

        Returns:
            (ScriptModule): traced model
        """
        if batch is None:
            if loader is None:
                raise ValueError(
                    "If batch is not provided the loader must be specified"
                )
            batch = next(iter(loader))

        if model is not None:
            self.model = model

        if isinstance(fp16, bool) and fp16:
            opt_level = "O1"
        elif isinstance(fp16, bool) and not fp16:
            opt_level = None
        elif isinstance(fp16, dict):
            opt_level = fp16["opt_level"]
        else:
            opt_level = fp16

        if opt_level is not None:
            device = "cuda"
        elif device is None:
            if self.device is None:
                self.device = utils.get_device()
            device = self.device

        result = trace.trace_model(
            model=self.model,
            runner=self,
            batch=batch,
            method_name=method_name,
            mode=mode,
            requires_grad=requires_grad,
            opt_level=opt_level,
            device=device,
            predict_params=predict_params,
        )

        if logdir is not None:
            filename = trace.get_trace_name(
                method_name=method_name,
                mode=mode,
                requires_grad=requires_grad,
                opt_level=opt_level,
            )
            logdir = Path(logdir)
            output: Path = logdir / "trace"
            output.mkdir(exist_ok=True, parents=True)
            out_model = str(output / filename)
            torch.jit.save(result, out_model)

        return result

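    # Usage sketch for ``trace`` (assumed names; not part of the original
    # source). The traced module is returned and, when ``logdir`` is given,
    # also saved under ``<logdir>/trace/`` with a name produced by
    # ``trace.get_trace_name``:
    #
    #   traced = runner.trace(model=model, loader=valid_loader,
    #                         logdir="./logs")
    #   # the saved artifact can later be restored with ``torch.jit.load``
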
__all__ = ["SupervisedRunner"]