
Source code for catalyst.dl.runner.core

from typing import Any, Callable, Dict, List, Union
from collections import OrderedDict

from torch.utils.data import DataLoader, Dataset

from catalyst.core import Callback, CheckpointCallback, StageBasedRunner, State
from catalyst.dl import Experiment, utils
from catalyst.utils.tools.typing import Criterion, Model, Optimizer, Scheduler


class Runner(StageBasedRunner):
    """
    Deep Learning Runner for supervised, unsupervised, GAN, etc. runs.
    """

    _experiment_fn: Callable = Experiment
    _state_fn: Callable = State

    def _init(self):
        self.experiment: Experiment = None
        self.state: State = None

    def train(
        self,
        *,
        model: Model,
        criterion: Criterion = None,
        optimizer: Optimizer = None,
        scheduler: Scheduler = None,
        datasets: "OrderedDict[str, Union[Dataset, Dict, Any]]" = None,
        loaders: "OrderedDict[str, DataLoader]" = None,
        callbacks: "Union[List[Callback], OrderedDict[str, Callback]]" = None,
        logdir: str = None,
        resume: str = None,
        num_epochs: int = 1,
        valid_loader: str = "valid",
        main_metric: str = "loss",
        minimize_metric: bool = True,
        verbose: bool = False,
        state_kwargs: Dict = None,
        checkpoint_data: Dict = None,
        fp16: Union[Dict, bool] = None,
        distributed: bool = False,
        monitoring_params: Dict = None,
        check: bool = False,
        timeit: bool = False,
    ) -> None:
        """
        Starts the training process of the model.

        Args:
            model (Model): model to train
            criterion (Criterion): criterion function for training
            optimizer (Optimizer): optimizer for training
            scheduler (Scheduler): scheduler for training
            datasets (OrderedDict[str, Union[Dataset, Dict, Any]]): dictionary
                with one or several ``torch.utils.data.Dataset`` for training,
                validation or inference, used for automatic loader creation;
                the preferred way for a distributed training setup
            loaders (OrderedDict[str, DataLoader]): dictionary with one or
                several ``torch.utils.data.DataLoader`` for training,
                validation or inference
            callbacks (Union[List[Callback], OrderedDict[str, Callback]]):
                list or dictionary with Catalyst callbacks
            logdir (str): path to output directory
            resume (str): path to a model checkpoint to resume from
            num_epochs (int): number of training epochs
            valid_loader (str): loader name used to calculate the metrics and
                save the checkpoints. For example, you can pass ``train`` and
                then the metrics will be taken from the ``train`` loader.
            main_metric (str): key to the name of the metric by which the
                checkpoints will be selected
            minimize_metric (bool): flag to indicate whether the
                ``main_metric`` should be minimized
            verbose (bool): if ``True``, displays the status of the training
                in the console
            state_kwargs (dict): additional state params for ``State``
            checkpoint_data (dict): additional data to save in the checkpoint,
                for example: ``class_names``, ``date_of_training``, etc.
            fp16 (Union[Dict, bool]): if not ``None``, enables FP16 training.
                See https://nvidia.github.io/apex/amp.html#properties
                If ``fp16=True``, the params default to ``{"opt_level": "O1"}``.
            distributed (bool): if ``True``, starts training in distributed
                mode. Note: works only with python scripts; no Jupyter support.
            monitoring_params (dict): if not ``None``, creates monitoring
                through Alchemy or other tools, for example
                ``{"token": "api_token", "experiment": "experiment_name"}``
            check (bool): if ``True``, only checks that the pipeline is
                working (3 epochs only)
            timeit (bool): if ``True``, measures the execution time of the
                training process and displays it in the console
        """
        if isinstance(fp16, bool) and fp16:
            fp16 = {"opt_level": "O1"}

        if resume is not None:
            callbacks = utils.sort_callbacks_by_order(callbacks)
            checkpoint_callback_flag = any(
                isinstance(x, CheckpointCallback) for x in callbacks.values()
            )
            if not checkpoint_callback_flag:
                callbacks["loader"] = CheckpointCallback(resume=resume)
            else:
                raise NotImplementedError("CheckpointCallback already exist")

        experiment = self._experiment_fn(
            stage="train",
            model=model,
            loaders=loaders,
            datasets=datasets,
            callbacks=callbacks,
            logdir=logdir,
            criterion=criterion,
            optimizer=optimizer,
            scheduler=scheduler,
            num_epochs=num_epochs,
            valid_loader=valid_loader,
            main_metric=main_metric,
            minimize_metric=minimize_metric,
            verbose=verbose,
            check_run=check,
            check_time=timeit,
            state_kwargs=state_kwargs,
            checkpoint_data=checkpoint_data,
            distributed_params=fp16,
            monitoring_params=monitoring_params,
        )
        self.experiment = experiment
        utils.distributed_cmd_run(self.run_experiment, distributed)
__all__ = ["Runner"]