
Source code for catalyst.runners.runner

from typing import Any, Callable, Dict, Generator, List, Mapping, Union
from collections import OrderedDict

import torch
from torch.jit import ScriptModule
from torch.utils.data import DataLoader, Dataset

from catalyst.callbacks.checkpoint import CheckpointCallback
from catalyst.core.callback import Callback
from catalyst.core.functional import sort_callbacks_by_order
from catalyst.core.runner import IStageBasedRunner
from catalyst.experiments.experiment import Experiment
from catalyst.typing import (
    Criterion,
    Device,
    Model,
    Optimizer,
    RunnerModel,
    Scheduler,
)
from catalyst.utils import check_amp_available
from catalyst.utils.checkpoint import load_checkpoint, unpack_checkpoint
from catalyst.utils.components import process_components
from catalyst.utils.misc import maybe_recursive_call, set_global_seed
from catalyst.utils.scripts import distributed_cmd_run
from catalyst.utils.torch import (
    get_device,
    get_requires_grad,
    set_requires_grad,
)
from catalyst.utils.tracing import save_traced_model, trace_model


def _resolve_bool_fp16(fp16: Union[Dict, bool]) -> Dict:
    """Resolves fp16/distributed params usage.

    Args:
        fp16: fp16 params as a bool flag or as a dict
            with pytorch native amp / apex settings

    Returns:
        Dict: resolved version of fp16; ``True`` expands to
        ``{"amp": True}`` if native amp is available and to
        ``{"apex": True, "opt_level": "O1"}`` otherwise;
        ``False`` becomes ``{}``; a dict is passed through unchanged
    """
    if isinstance(fp16, bool):
        if fp16:
            return (
                {"amp": True}
                if check_amp_available()
                else {"apex": True, "opt_level": "O1"}
            )
        else:
            return {}
    else:
        return fp16

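# Illustrative behavior of ``_resolve_bool_fp16`` (the dict branch is a
# pass-through; what ``True`` resolves to depends on whether native amp
# is available in the installed torch version):
#
# >>> _resolve_bool_fp16(False)
# {}
# >>> _resolve_bool_fp16({"apex": True, "opt_level": "O2"})
# {'apex': True, 'opt_level': 'O2'}
# >>> _resolve_bool_fp16(True)  # on torch>=1.6.0 with native amp
# {'amp': True}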

class Runner(IStageBasedRunner):
    """Deep Learning Runner for supervised, unsupervised, GAN, etc. runs."""

    def __init__(
        self,
        model: RunnerModel = None,
        device: Device = None,
        experiment_fn: Callable = Experiment,
    ):
        """
        Args:
            model: Torch model object
            device: Torch device
            experiment_fn: callable function, which defines default
                experiment type to use during ``.train`` and ``.infer``
                methods
        """
        super().__init__(model=model, device=device)
        self._experiment_fn = experiment_fn

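    # Example: constructing a runner (an illustrative sketch; the linear
    # model is a placeholder for any ``torch.nn.Module``, and ``device`` is
    # assumed to accept either a ``torch.device`` or a string):
    #
    # >>> import torch
    # >>> runner = Runner(model=torch.nn.Linear(8, 2), device="cpu")
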
    def train(
        self,
        *,
        model: Model,
        criterion: Criterion = None,
        optimizer: Optimizer = None,
        scheduler: Scheduler = None,
        datasets: "OrderedDict[str, Union[Dataset, Dict, Any]]" = None,
        loaders: "OrderedDict[str, DataLoader]" = None,
        callbacks: "Union[List[Callback], OrderedDict[str, Callback]]" = None,
        logdir: str = None,
        resume: str = None,
        num_epochs: int = 1,
        valid_loader: str = "valid",
        main_metric: str = "loss",
        minimize_metric: bool = True,
        verbose: bool = False,
        stage_kwargs: Dict = None,
        checkpoint_data: Dict = None,
        fp16: Union[Dict, bool] = None,
        distributed: bool = False,
        check: bool = False,
        overfit: bool = False,
        timeit: bool = False,
        load_best_on_end: bool = False,
        initial_seed: int = 42,
        state_kwargs: Dict = None,
    ) -> None:
        """
        Starts the train stage of the model.

        Args:
            model: model to train
            criterion: criterion function for training
            optimizer: optimizer for training
            scheduler: scheduler for training
            datasets (OrderedDict[str, Union[Dataset, Dict, Any]]):
                dictionary with one or several ``torch.utils.data.Dataset``
                for training, validation or inference,
                used for automatic Loaders creation;
                preferred way for distributed training setup
            loaders (OrderedDict[str, DataLoader]):
                dictionary with one or several
                ``torch.utils.data.DataLoader``
                for training, validation or inference
            callbacks (Union[List[Callback], OrderedDict[str, Callback]]):
                list or dictionary with Catalyst callbacks
            logdir: path to output directory
            resume: path to checkpoint for model
            num_epochs: number of training epochs
            valid_loader: loader name used to calculate
                the metrics and save the checkpoints. For example,
                you can pass `train` and then
                the metrics will be taken from `train` loader.
            main_metric: the key to the name of the metric
                by which the checkpoints will be selected.
            minimize_metric: flag to indicate whether
                the ``main_metric`` should be minimized.
            verbose: if `True`, it displays the status of the training
                to the console.
            stage_kwargs: additional params for stage
            checkpoint_data: additional data to save in checkpoint,
                for example: ``class_names``, ``date_of_training``, etc.
            fp16: parameters for fp16/distributed training.
                To use pytorch native amp - ``{"amp": True}``.
                To use apex - ``{"apex": True, "opt_level": "O1", ...}``.
                If fp16=True, params by default will be:
                ``{"amp": True}`` if torch>=1.6.0,
                ``{"apex": True, "opt_level": "O1", ...}`` if torch<1.6.0.
                See https://nvidia.github.io/apex/amp.html#properties
                for more params.
            distributed: if `True`, starts training in distributed mode.
                Note: works only with python scripts; no jupyter support.
            check: if True, then only checks that the pipeline is working
                (3 epochs only with 3 batches per loader)
            overfit: if True, then takes only one batch per loader
                for model overfitting; for advanced usage please check
                ``BatchOverfitCallback``
            timeit: if True, computes the execution time
                of the training process and displays it to the console.
            load_best_on_end: if True, Runner will load
                the best checkpoint state (model, optimizer, etc.)
                according to validation metrics. Requires specified
                ``logdir``.
            initial_seed: experiment's initial seed value
            state_kwargs: deprecated, use `stage_kwargs` instead

        Raises:
            NotImplementedError: if ``resume`` or ``load_best_on_end``
                is specified and a ``CheckpointCallback`` already exists
                in ``callbacks``
        """
        assert state_kwargs is None or stage_kwargs is None

        fp16 = _resolve_bool_fp16(fp16)

        if resume is not None or load_best_on_end:
            load_on_stage_end = None
            if load_best_on_end:
                load_on_stage_end = "best_full"
                assert logdir is not None, (
                    "For ``load_best_on_end`` feature "
                    "you need to specify ``logdir``"
                )
            callbacks = sort_callbacks_by_order(callbacks)
            checkpoint_callback_flag = any(
                isinstance(x, CheckpointCallback) for x in callbacks.values()
            )
            if not checkpoint_callback_flag:
                callbacks["_loader"] = CheckpointCallback(
                    resume=resume, load_on_stage_end=load_on_stage_end,
                )
            else:
                raise NotImplementedError("CheckpointCallback already exists")

        experiment = self._experiment_fn(
            stage="train",
            model=model,
            datasets=datasets,
            loaders=loaders,
            callbacks=callbacks,
            logdir=logdir,
            criterion=criterion,
            optimizer=optimizer,
            scheduler=scheduler,
            num_epochs=num_epochs,
            valid_loader=valid_loader,
            main_metric=main_metric,
            minimize_metric=minimize_metric,
            verbose=verbose,
            check_time=timeit,
            check_run=check,
            overfit=overfit,
            stage_kwargs=stage_kwargs or state_kwargs,
            checkpoint_data=checkpoint_data,
            distributed_params=fp16,
            initial_seed=initial_seed,
        )
        self.experiment = experiment
        distributed_cmd_run(self.run_experiment, distributed)

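    # Example: a minimal training run (an illustrative sketch, not part of
    # the original module; it assumes a subclass that implements the
    # per-batch logic via the ``_handle_batch`` hook, which the base
    # ``Runner`` leaves to the user):
    #
    # >>> import torch
    # >>> from torch.utils.data import DataLoader, TensorDataset
    # >>> class CustomRunner(Runner):
    # ...     def _handle_batch(self, batch):
    # ...         x, y = batch
    # ...         logits = self.model(x)
    # ...         loss = torch.nn.functional.cross_entropy(logits, y)
    # ...         self.batch_metrics.update({"loss": loss})
    # ...         if self.is_train_loader:
    # ...             loss.backward()
    # ...             self.optimizer.step()
    # ...             self.optimizer.zero_grad()
    # >>> model = torch.nn.Linear(8, 2)
    # >>> dataset = TensorDataset(
    # ...     torch.rand(64, 8), torch.randint(0, 2, (64,))
    # ... )
    # >>> loader = DataLoader(dataset, batch_size=16)
    # >>> runner = CustomRunner()
    # >>> runner.train(
    # ...     model=model,
    # ...     optimizer=torch.optim.Adam(model.parameters()),
    # ...     loaders={"train": loader, "valid": loader},
    # ...     logdir="./logs",
    # ...     num_epochs=1,
    # ...     verbose=True,
    # ... )
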
    def infer(
        self,
        *,
        model: Model,
        datasets: "OrderedDict[str, Union[Dataset, Dict, Any]]" = None,
        loaders: "OrderedDict[str, DataLoader]" = None,
        callbacks: "Union[List[Callback], OrderedDict[str, Callback]]" = None,
        logdir: str = None,
        resume: str = None,
        verbose: bool = False,
        stage_kwargs: Dict = None,
        fp16: Union[Dict, bool] = None,
        check: bool = False,
        timeit: bool = False,
        initial_seed: int = 42,
        state_kwargs: Dict = None,
    ) -> None:
        """
        Starts the inference stage of the model.

        Args:
            model: model for inference
            datasets (OrderedDict[str, Union[Dataset, Dict, Any]]):
                dictionary with one or several ``torch.utils.data.Dataset``
                for training, validation or inference,
                used for automatic Loaders creation;
                preferred way for distributed training setup
            loaders (OrderedDict[str, DataLoader]):
                dictionary with one or several
                ``torch.utils.data.DataLoader``
                for training, validation or inference
            callbacks (Union[List[Callback], OrderedDict[str, Callback]]):
                list or dictionary with Catalyst callbacks
            logdir: path to output directory
            resume: path to checkpoint to use for resume
            verbose: if `True`, it displays the status of the run
                to the console.
            stage_kwargs: additional stage params
            fp16 (Union[Dict, bool]): fp16 settings (same as in `train`)
            check: if True, then only checks that the pipeline is working
                (3 epochs only)
            timeit: if True, computes the execution time
                of the run and displays it to the console.
            initial_seed: experiment's initial seed value
            state_kwargs: deprecated, use `stage_kwargs` instead

        Raises:
            NotImplementedError: if ``resume`` is specified and
                a ``CheckpointCallback`` already exists in ``callbacks``
        """
        assert state_kwargs is None or stage_kwargs is None

        fp16 = _resolve_bool_fp16(fp16)

        if resume is not None:
            callbacks = sort_callbacks_by_order(callbacks)
            checkpoint_callback_flag = any(
                isinstance(x, CheckpointCallback) for x in callbacks.values()
            )
            if not checkpoint_callback_flag:
                callbacks["loader"] = CheckpointCallback(resume=resume)
            else:
                raise NotImplementedError("CheckpointCallback already exists")

        experiment = self._experiment_fn(
            stage="infer",
            model=model,
            datasets=datasets,
            loaders=loaders,
            callbacks=callbacks,
            logdir=logdir,
            verbose=verbose,
            check_time=timeit,
            check_run=check,
            stage_kwargs=stage_kwargs or state_kwargs,
            distributed_params=fp16,
            initial_seed=initial_seed,
        )
        self.run_experiment(experiment)

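    # Example: running the inference stage from a saved checkpoint (an
    # illustrative sketch; ``CustomRunner``, ``model`` and ``loader`` are
    # the hypothetical objects from the ``train`` example above, and the
    # checkpoint path assumes the layout that ``CheckpointCallback``
    # writes under ``logdir``):
    #
    # >>> runner = CustomRunner()
    # >>> runner.infer(
    # ...     model=model,
    # ...     loaders={"infer": loader},
    # ...     resume="./logs/checkpoints/best.pth",
    # ...     verbose=True,
    # ... )
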
    @torch.no_grad()
    def predict_batch(
        self, batch: Mapping[str, Any], **kwargs
    ) -> Mapping[str, Any]:
        """
        Run model inference on specified data batch.

        Args:
            batch (Mapping[str, Any]): dictionary with data batches
                from DataLoader.
            **kwargs: additional kwargs to pass to the model

        # noqa: DAR202
        Returns:
            Mapping[str, Any]: model output dictionary

        Raises:
            NotImplementedError: if not implemented yet
        """
        raise NotImplementedError(
            "Please implement `runner.predict_batch` method"
        )

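    # Example: a typical override in a subclass (an illustrative sketch;
    # the base method must be implemented before ``predict_loader`` can be
    # used, and the tuple unpacking below assumes loaders that yield
    # ``(inputs, targets)`` pairs):
    #
    # >>> class CustomRunner(Runner):
    # ...     def predict_batch(self, batch, **kwargs):
    # ...         x, _ = batch
    # ...         return {"logits": self.model(x.to(self.device))}
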
    @torch.no_grad()
    def predict_loader(
        self,
        *,
        loader: DataLoader,
        model: Model = None,
        resume: str = None,
        fp16: Union[Dict, bool] = None,
        initial_seed: int = 42,
    ) -> Generator:
        """
        Runs model inference on a PyTorch DataLoader and returns a
        python generator with model predictions from
        `runner.predict_batch`.
        Cleans up the experiment info to avoid possible collisions.
        Sets `is_train_loader` and `is_valid_loader` to `False` while
        keeping `is_infer_loader` as `True`. Moves model to evaluation
        mode.

        Args:
            loader: loader to predict
            model: model to use for prediction
            resume: path to checkpoint to resume
            fp16 (Union[Dict, bool]): fp16 settings (same as in `train`)
            initial_seed: seed to use before prediction

        Yields:
            batches with model predictions
        """
        fp16 = _resolve_bool_fp16(fp16)

        if model is not None:
            self.model = model
        assert self.model is not None

        if resume is not None:
            checkpoint = load_checkpoint(resume)
            unpack_checkpoint(checkpoint, model=self.model)

        self.experiment = None
        set_global_seed(initial_seed)
        (model, _, _, _, device) = process_components(  # noqa: WPS122
            model=self.model, distributed_params=fp16, device=self.device,
        )
        self._prepare_inner_state(
            stage="infer",
            model=model,
            device=device,
            is_train_loader=False,
            is_valid_loader=False,
            is_infer_loader=True,
        )
        maybe_recursive_call(self.model, "train", mode=False)

        set_global_seed(initial_seed)
        for batch in loader:
            yield self.predict_batch(batch)

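    # Example: streaming predictions over a loader (an illustrative sketch;
    # reuses the hypothetical ``CustomRunner``, ``model`` and ``loader``
    # from the examples above, and assumes ``predict_batch`` returns a dict
    # with a ``"logits"`` key):
    #
    # >>> runner = CustomRunner(model=model)
    # >>> for output in runner.predict_loader(loader=loader):
    # ...     logits = output["logits"]
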
    def trace(
        self,
        *,
        model: Model = None,
        batch: Any = None,
        logdir: str = None,
        loader: DataLoader = None,
        method_name: str = "forward",
        mode: str = "eval",
        requires_grad: bool = False,
        fp16: Union[Dict, bool] = None,
        device: Device = "cpu",
        predict_params: dict = None,
    ) -> ScriptModule:
        """
        Traces model using Torch Jit.

        Args:
            model: model to trace
            batch: batch to forward through the model to trace
            logdir (str, optional): if specified,
                the result will be written to the directory
            loader (DataLoader, optional): if batch is not specified,
                the batch will be ``next(iter(loader))``
            method_name: model's method name that will be traced
            mode: ``train`` or ``eval``
            requires_grad: flag to trace with gradients
            fp16 (Union[Dict, bool]): fp16 settings (same as in `train`)
            device: Torch device or a string
            predict_params: additional parameters for model forward

        Returns:
            ScriptModule: traced model

        Raises:
            ValueError: if both ``batch`` and ``loader`` are ``None``
        """
        # @TODO: refactor for easy use
        # @TODO: also add quantize, prune, onnx-convert
        if batch is None:
            if loader is None:
                raise ValueError(
                    "If batch is not provided the loader must be specified"
                )
            batch = next(iter(loader))

        if model is not None:
            self.model = model
        assert self.model is not None

        fp16 = _resolve_bool_fp16(fp16)
        opt_level = None
        if fp16:
            opt_level = fp16.get("opt_level", None)

        if opt_level is not None:
            device = "cuda"
        elif device is None:
            if self.device is None:
                self.device = get_device()
            device = self.device

        # dump the current state of the model, so it can be restored later
        device_dump, is_training_dump, requires_grad_dump = (
            self.device,
            self.model.training,
            get_requires_grad(self.model),
        )

        self.model.to(device)

        # function to run prediction on batch
        def predict_fn(model, inputs, **kwargs):  # noqa: WPS442
            model_dump = self.model
            self.model = model
            result = self.predict_batch(inputs, **kwargs)
            self.model = model_dump
            return result

        traced_model = trace_model(
            model=self.model,
            predict_fn=predict_fn,
            batch=batch,
            method_name=method_name,
            mode=mode,
            requires_grad=requires_grad,
            opt_level=opt_level,
            device=device,
            predict_params=predict_params,
        )

        if logdir is not None:
            save_traced_model(
                model=traced_model,
                logdir=logdir,
                method_name=method_name,
                mode=mode,
                requires_grad=requires_grad,
                opt_level=opt_level,
            )

        # restore the previous state of the model
        getattr(self.model, "train" if is_training_dump else "eval")()
        set_requires_grad(self.model, requires_grad_dump)
        self.model.to(device_dump)

        return traced_model

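    # Example: tracing the model to TorchScript (an illustrative sketch;
    # ``runner``, ``model`` and ``loader`` are the hypothetical objects from
    # the examples above, and whether a given ``predict_batch`` output
    # traces cleanly depends on the model; with ``logdir`` set, the traced
    # model is also saved to disk):
    #
    # >>> traced = runner.trace(model=model, loader=loader, logdir="./logs")
    # >>> type(traced)  # a torch.jit.ScriptModule

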
__all__ = ["Runner"]