Shortcuts

Source code for catalyst.core.runner

from typing import Any, Dict, Mapping, Optional, Tuple, Union
from abc import ABC, abstractmethod
from collections import defaultdict, OrderedDict
from functools import lru_cache
from pathlib import Path

import torch
from torch import nn
from torch.utils.data import DataLoader, DistributedSampler

from catalyst.core.callback import Callback, CallbackScope, ICallback
from catalyst.core.experiment import IExperiment
from catalyst.core.functional import (
    filter_callbacks_by_node,
    sort_callbacks_by_order,
)
from catalyst.core.legacy import IRunnerLegacy
from catalyst.settings import SETTINGS
from catalyst.typing import (
    Device,
    Model,
    RunnerCriterion,
    RunnerModel,
    RunnerOptimizer,
    RunnerScheduler,
)
from catalyst.utils.components import process_components
from catalyst.utils.distributed import get_rank
from catalyst.utils.loaders import validate_loaders
from catalyst.utils.misc import maybe_recursive_call, set_global_seed
from catalyst.utils.torch import any2device


@lru_cache(maxsize=42)
def _is_substring(origin_string: str, strings: Tuple):
    return any(x in origin_string for x in strings)


[docs]class RunnerException(Exception): """Exception class for all runner errors."""
[docs] def __init__(self, message: str): """ Args: message: exception message """ super().__init__(message)
[docs]class IRunner(ABC, ICallback, IRunnerLegacy): """ An abstraction that knows how to run an experiment. It contains all the logic of **how** to run the experiment, stages, epoch and batches. .. note:: To learn more about Catalyst Core concepts, please check out - :py:mod:`catalyst.core.experiment.IExperiment` - :py:mod:`catalyst.core.runner.IRunner` - :py:mod:`catalyst.core.callback.Callback` Abstraction, please check out the implementations: - :py:mod:`catalyst.runners.runner.Runner` - :py:mod:`catalyst.runners.supervised.SupervisedRunner` Runner also contains full information about experiment runner. Runner section **runner.model** - an instance of torch.nn.Module class, \ (should implement ``forward`` method); \ for example, :: runner.model = torch.nn.Linear(10, 10) **runner.device** - an instance of torch.device (CPU, GPU, TPU); \ for example, :: runner.device = torch.device("cpu") Experiment section **runner.criterion** - an instance of torch.nn.Module class\ or torch.nn.modules.loss._Loss (should implement ``forward`` method); \ for example, :: runner.criterion = torch.nn.CrossEntropyLoss() **runner.optimizer** - an instance of torch.optim.optimizer.Optimizer\ (should implement ``step`` method); \ for example, :: runner.optimizer = torch.optim.Adam() **runner.scheduler** - an instance of torch.optim.lr_scheduler._LRScheduler\ (should implement ``step`` method); \ for example, :: runner.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau() **runner.callbacks** - ordered dictionary with Catalyst.Callback instances;\ for example, :: runner.callbacks = { "accuracy": AccuracyCallback(), "criterion": CriterionCallback(), "optim": OptimizerCallback(), "saver": CheckpointCallback() } Dataflow section **runner.loaders** - ordered dictionary with torch.DataLoaders; \ for example, :: runner.loaders = { "train": MnistTrainLoader(), "valid": MnistValidLoader() } .. note:: - "*train*" prefix is used for training loaders - \ metrics computations, backward pass, optimization - "*valid*" prefix is used for validation loaders - \ metrics computations only - "*infer*" prefix is used for inference loaders - \ dataset prediction **runner.input** - dictionary, \ containing batch of data from currents DataLoader; \ for example, :: runner.input = { "images": np.ndarray(batch_size, c, h, w), "targets": np.ndarray(batch_size, 1), } **runner.output** - dictionary, \ containing model output for current batch; \ for example, :: runner.output = {"logits": torch.Tensor(batch_size, num_classes)} Metrics section **runner.batch_metrics** - dictionary, flatten storage for batch metrics; \ for example, :: runner.batch_metrics = {"loss": ..., "accuracy": ..., "iou": ...} **runner.loader_metrics** - dictionary with aggregated batch statistics \ for loader (mean over all batches) and global loader metrics, like AUC; \ for example, :: runner.loader_metrics = {"loss": ..., "accuracy": ..., "auc": ...} **runner.epoch_metrics** - dictionary with summarized metrics \ for different loaders and global epoch metrics, like lr, momentum; \ for example, :: runner.epoch_metrics = { "train_loss": ..., "train_auc": ..., "valid_loss": ..., "lr": ..., "momentum": ..., } Validation metrics section **runner.main_metric** - string, containing name of metric of interest \ for optimization, validation and checkpointing during training **runner.minimize_metric** - bool, indicator flag - ``True`` if we need to minimize metric during training,\ like `Cross Entropy loss` - ``False`` if we need to maximize metric during training, \ like `Accuracy` or `Intersection over Union` Validation section **runner.valid_loader** - string, name of validation loader \ for metric selection, validation and model checkpoining **runner.valid_metrics** - dictionary with validation metrics\ for currect epoch; \ for example, :: runner.valid_metrics = {"loss": ..., "accuracy": ..., "auc": ...} .. note:: subdictionary of epoch_metrics **runner.is_best_valid** - bool, indicator flag - ``True`` if this training epoch is best over all epochs - ``False`` if not **runner.best_valid_metrics** - dictionary with best validation metrics \ during whole training process Distributed section **runner.distributed_rank** - distributed rank of current worker **runner.is_distributed_master** - bool, indicator flag - ``True`` if is master node (runner.distributed_rank == 0) - ``False`` if is worker node (runner.distributed_rank != 0) **runner.is_distributed_worker** - bool, indicator flag - ``True`` if is worker node (runner.distributed_rank > 0) - ``False`` if is master node (runner.distributed_rank <= 0) Experiment info section **runner.global_sample_step** - int, numerical indicator, counter for all\ individual samples, that passes through our model during training,\ validation and inference stages **runner.global_batch_step** - int, numerical indicator, counter for all batches, that passes through our model during training, validation and\ inference stages **runner.global_epoch** - int, numerical indicator, counter for all epochs,\ that have passed during model training, validation and\ inference stages **runner.verbose** - bool, indicator flag **runner.is_check_run** - bool, indicator flag - ``True`` if you want to check you pipeline and \ run only 2 batches per loader and 2 epochs per stage - ``False`` (default) if you want to just the pipeline **runner.need_early_stop** - bool, indicator flag \ used for EarlyStopping and CheckRun Callbacks - ``True`` if we need to stop the training - ``False`` (default) otherwise **runner.need_exception_reraise** - bool, indicator flag - ``True`` (default) if you want to show exception \ during pipeline and stop the training process - ``False`` otherwise Stage info section **runner.stage** - string, current stage name,\ for example, :: runner.stage = "pretraining" / "training" / "finetuning" / etc **runner.num_epochs** - int, maximum number of epochs, \ required for this stage **runner.is_infer_stage** - bool, indicator flag - ``True`` for inference stages - ``False`` otherwise Epoch info section **runner.epoch** - int, numerical indicator for current stage epoch Loader info section **runner.loader_sample_step** - int, numerical indicator \ for number of samples passed through our model in current loader **runner.loader_batch_step** - int, numerical indicator \ for batch index in current loader **runner.loader_name** - string, current loader name\ for example, :: runner.loader_name = "train_dataset1" / "valid_data2" / "infer_golden" **runner.loader_len** - int, maximum number of batches in current loader **runner.loader_batch_size** - int, batch size parameter in current loader **runner.is_train_loader** - bool, indicator flag - ``True`` for training loaders - ``False`` otherwise **runner.is_valid_loader** - bool, indicator flag - ``True`` for validation loaders - ``False`` otherwise **runner.is_infer_loader** - bool, indicator flag - ``True`` for inference loaders - ``False`` otherwise Batch info section **runner.batch_size** - int, length of the current batch Logging section **runner.logdir** - string, path to logging directory to save\ all logs, metrics, checkpoints and artifacts Extra section **runner.exception** - python Exception instance to raise (or not ;) ) """
[docs] def __init__( self, model: RunnerModel = None, device: Device = None, ): """ Args: model: Torch model object device: Torch device """ self._device = None self._model = None self.experiment: IExperiment = None self._prepare_inner_state(model=model, device=device)
def _prepare_inner_state( self, stage: str = SETTINGS.stage_infer_prefix, device: Device = None, model: RunnerModel = None, criterion: RunnerCriterion = None, optimizer: RunnerOptimizer = None, scheduler: RunnerScheduler = None, callbacks: Dict[str, "Callback"] = None, loaders: Dict[str, "DataLoader"] = None, logdir: str = None, num_epochs: int = 1, main_metric: str = "loss", minimize_metric: bool = True, valid_loader: str = SETTINGS.loader_valid_prefix, checkpoint_data: Dict = None, is_check_run: bool = False, verbose: bool = False, **kwargs, ): # @TODO: move/split this method to callbacks group # here should be only a small part of it # main runner components: model and device to run self.device: Device = device self.model: RunnerModel = model # experiment components, # use `catalyst.core.IExperiment` to setup them self.criterion: RunnerCriterion = criterion self.optimizer: RunnerOptimizer = optimizer self.scheduler: RunnerScheduler = scheduler # and callbacks self.callbacks: Dict[str, "Callback"] = callbacks or {} # the data self.loader = None self.loaders: OrderedDict[str, DataLoader] = loaders # and the dataflow - model input, model output self.input = None self.output = None # metrics flow - batch, loader, epoch metrics # let's use flatten storage for batch metrics # batch_metrics = {'loss': ..., 'accuracy': ..., 'iou': ...} self.batch_metrics: Dict = defaultdict(None) # just aggregated (aka mean over all batches) # batch statistics for loader # and global loader metrics, like AUC # loader_metrics = {'loss': ..., 'accuracy': ..., `auc`: ...} self.loader_metrics: Dict = defaultdict(None) # summarized metrics for different loaders # and global epoch metrics, like lr, momentum # epoch_metrics = { # 'train_loss': ..., 'train_auc': ..., 'valid_loss': ..., # 'lr': ..., 'momentum': ..., # } self.epoch_metrics: Dict = defaultdict(None) # metrics & validation # @TODO: mote to validation callback self.main_metric: str = main_metric self.minimize_metric: bool = minimize_metric # validation # @TODO: mote to validation callback self.valid_loader: str = valid_loader self.valid_metrics: Dict = defaultdict(None) self.is_best_valid: bool = False self.best_valid_metrics: Dict = defaultdict(None) # distributed info # @TODO: move to Engine self.distributed_rank: int = get_rank() self.is_distributed_master: bool = ~(self.distributed_rank > 0) self.is_distributed_worker: bool = self.distributed_rank > 0 # experiment info self.global_sample_step: int = 0 self.global_batch_step: int = 0 self.global_epoch: int = 1 self.verbose: bool = verbose self.is_check_run: bool = is_check_run self.need_early_stop: bool = False self.need_exception_reraise: bool = True # stage info self.num_epochs: int = num_epochs self.stage: str = stage self.is_infer_stage: bool = self.stage.startswith( SETTINGS.stage_infer_prefix ) # epoch info self.epoch: int = 1 # loader info self.loader_sample_step: int = 0 self.loader_batch_step: int = 0 self.loader_key: str = None self.loader_len: int = 0 self.loader_batch_size = 0 self.is_train_loader: bool = False self.is_valid_loader: bool = False self.is_infer_loader: bool = True # batch info self.batch_size: int = 0 # logging # @TODO: mote to checkpoint callback # self.expdir: Path = None self.logdir: Path = Path(logdir) if logdir is not None else None # extra checkpoint data for saving in checkpoint files self.checkpoint_data: Dict = checkpoint_data or {} # extra self.exception: Optional[Exception] = None # kwargs for key, value in kwargs.items(): setattr(self, key, value) @property def model(self) -> Model: """Returns the runner's model instance.""" return self._model @model.setter def model(self, value: Union[Model, Dict[str, Model]]): """ Setter for the runner's model, useful for experiment tracing. Args: value (Union[Model, Dict[str, Model]]): new model. Raises: TypeError: if value is out of `torch.nn.Module` or `Dict[str, torch.nn.Module]` """ if isinstance(value, nn.Module): model = value elif isinstance(value, dict): values_are_models = all( isinstance(v, nn.Module) for v in value.values() ) if not values_are_models: raise TypeError( "Invalid dict value type, must be `torch.nn.Module`" ) model = value elif isinstance(value, type(None)): model = None else: raise TypeError( f"Invalid value type " f"must be `torch.nn.Module` or `Dict[str, torch.nn.Module]` " f"got '{type(value)}'" ) if model is not None and self._device is not None: model: Model = maybe_recursive_call( model, "to", device=self._device ) self._model = model @property def device(self) -> Device: """Returns the runner's device instance.""" return self._device @device.setter def device(self, value: Device): """ Setter for the runner's device. Args: value: new torch device. Raises: TypeError: if `value` is out of `torch.device`, `str` or `None` """ if isinstance(value, torch.device): self._device = value elif isinstance(value, str): self._device = torch.device(value) elif isinstance(value, type(None)): self._device = None else: raise TypeError( f"Invalid value type " f"must be `str` or `torch.device` " f"got '{type(value)}'" ) if self._model is not None: self._model = maybe_recursive_call( self._model, "to", device=self._device or "cpu" )
[docs] def on_experiment_start(self, runner: "IRunner"): """Event handler for experiment start. Args: runner: IRunner instance. .. note:: This event work only on IRunner. """ assert self.experiment is not None set_global_seed(self.experiment.initial_seed + self.global_epoch + 1)
[docs] def on_stage_start(self, runner: "IRunner"): """Event handler for stage start. Args: runner: IRunner instance. """ assert self.stage is not None set_global_seed(self.experiment.initial_seed + self.global_epoch + 1)
[docs] def on_epoch_start(self, runner: "IRunner"): """Event handler for epoch start. Args: runner: IRunner instance. Raises: RunnerException: if current DataLoader is empty. """ assert self.loaders is not None for loader_key, loader in self.loaders.items(): if len(loader) == 0: raise RunnerException( f"DataLoader with name {loader_key} is empty." ) if not self.is_infer_stage: assert self.valid_loader in self.loaders.keys(), ( f"'{self.valid_loader}' " f"should be in provided loaders: {list(self.loaders.keys())}" ) else: assert not any( x.startswith(SETTINGS.loader_train_prefix) for x in self.loaders.keys() ), "for inference no train loader should be passed" set_global_seed(self.experiment.initial_seed + self.global_epoch + 1)
[docs] def on_loader_start(self, runner: "IRunner"): """Event handler for loader start. Args: runner: IRunner instance. Raises: RunnerException: if current DataLoader is empty. """ assert self.loader is not None self.loader_len = len(self.loader) if self.loader_len == 0: raise RunnerException( f"DataLoader with name {self.loader_key} is empty." ) self.loader_batch_size = ( self.loader.batch_sampler.batch_size if self.loader.batch_sampler is not None else self.loader.batch_size ) self.loader_sample_step = 0 self.is_train_loader = self.loader_key.startswith( SETTINGS.loader_train_prefix ) self.is_valid_loader = self.loader_key.startswith( SETTINGS.loader_valid_prefix ) self.is_infer_loader = self.loader_key.startswith( SETTINGS.loader_infer_prefix ) maybe_recursive_call(self.model, "train", mode=self.is_train_loader) if isinstance(self.loader.sampler, DistributedSampler): self.loader.sampler.set_epoch(self.epoch) set_global_seed(self.experiment.initial_seed + self.global_epoch + 1)
[docs] def on_batch_start(self, runner: "IRunner"): """Event handler for batch start. Args: runner: IRunner instance. """ if isinstance(self.input, dict): self.batch_size = len(next(iter(self.input.values()))) else: self.batch_size = len(self.input[0]) self.global_batch_step += 1 # self.loader_batch_step += 1 self.global_sample_step += self.batch_size self.loader_sample_step += self.batch_size
[docs] def on_batch_end(self, runner: "IRunner"): """Event handler for batch end. Args: runner: IRunner instance. """ pass
[docs] def on_loader_end(self, runner: "IRunner"): """Event handler for loader end. Args: runner: IRunner instance. """ pass
[docs] def on_epoch_end(self, runner: "IRunner"): """Event handler for epoch end. Args: runner: IRunner instance. """ self.global_epoch += 1 self.epoch += 1
[docs] def on_stage_end(self, runner: "IRunner"): """Event handler for stage end. Args: runner: IRunner instance. """ pass
[docs] def on_experiment_end(self, runner: "IRunner"): """Event handler for experiment end. Args: runner: IRunner instance. .. note:: This event work only on IRunner. """ pass
[docs] def on_exception(self, runner: "IRunner"): """Event handler for exception case. Args: runner: IRunner instance. Raises: exception: if during pipeline exception, no handler we found into callbacks """ from catalyst.callbacks.exception import ExceptionCallback def _exception_handler_check(callbacks: Union[OrderedDict, Dict]): return callbacks is not None and any( issubclass(x.__class__, ExceptionCallback) for x in callbacks.values() ) if not _exception_handler_check(getattr(self, "callbacks", None)): raise self.exception
def _run_event(self, event: str) -> None: """Inner method to run specified event on Runners' callbacks. Args: event(str): event name to run on callbacks. .. note:: To learn more about Catalyst Callbacks mechanism, please follow :py:mod:`catalyst.core.callback.Callback` documentation. """ # @TODO: how to remove self duplication? and does it really matter? if _is_substring(event, ("start", "exception")): getattr(self, event)(self) for callback in self.callbacks.values(): getattr(callback, event)(self) if _is_substring(event, ("end",)): getattr(self, event)(self) def _handle_device(self, batch: Mapping[str, Any]): return any2device(batch, self.device) @abstractmethod def _handle_batch(self, batch: Mapping[str, Any]) -> None: """ Inner method to handle specified data batch. Used to make a train/valid/infer stage during Experiment run. Args: batch (Mapping[str, Any]): dictionary with data batches from DataLoader. """ pass def _run_batch(self) -> None: self.input = self._handle_device(batch=self.input) self._run_event("on_batch_start") self._handle_batch(batch=self.input) self._run_event("on_batch_end") def _run_loader(self) -> None: self._run_event("on_loader_start") with torch.set_grad_enabled(self.is_train_loader): for self.loader_batch_step, self.input in enumerate(self.loader): self._run_batch() if self.need_early_stop: self.need_early_stop = False break self._run_event("on_loader_end") def _run_epoch(self) -> None: self._run_event("on_epoch_start") for self.loader_key, self.loader in self.loaders.items(): self._run_loader() self._run_event("on_epoch_end") def _run_stage(self) -> None: self._run_event("on_stage_start") while self.epoch < self.num_epochs + 1: self._run_epoch() if self.need_early_stop: self.need_early_stop = False break self._run_event("on_stage_end") def _run_experiment(self) -> None: self._run_event("on_experiment_start") for self.stage in self.experiment.stages: self._run_stage() self._run_event("on_experiment_end")
[docs] def run_experiment(self, experiment: IExperiment = None) -> "IRunner": """ Starts the experiment. Args: experiment: Experiment instance to use for Runner. Returns: self, `IRunner` instance after the experiment """ self.experiment = experiment or self.experiment try: self._run_experiment() except (Exception, KeyboardInterrupt) as ex: self.exception = ex self._run_event("on_exception") return self
[docs]class IStageBasedRunner(IRunner): """ Runner abstraction that suppose to have constant datasources per stage. """
[docs] def on_stage_start(self, runner: "IRunner") -> None: """Event handler for stage start. For the `IStageBasedRunner` case: - prepares loaders - our datasources - prepares model components - model, criterion, optimizer, scheduler - prepares callbacks for the current stage Args: runner: IRunner instance. """ super().on_stage_start(runner) set_global_seed(self.experiment.initial_seed) loaders = self.experiment.get_loaders(stage=self.stage) loaders = validate_loaders(loaders) # self.loaders = loaders set_global_seed(self.experiment.initial_seed) model = self.experiment.get_model(self.stage) criterion = self.experiment.get_criterion(self.stage) optimizer = self.experiment.get_optimizer(self.stage, model) scheduler = self.experiment.get_scheduler(self.stage, optimizer) model, criterion, optimizer, scheduler, device = process_components( model=model, criterion=criterion, optimizer=optimizer, scheduler=scheduler, distributed_params=self.experiment.distributed_params, device=self.device, ) set_global_seed(self.experiment.initial_seed) callbacks = self.experiment.get_callbacks(self.stage) callbacks = filter_callbacks_by_node(callbacks) callbacks = sort_callbacks_by_order(callbacks) migrating_params = dict(**self.experiment.get_stage_params(self.stage)) migrate_from_previous_stage = migrating_params.get( "migrate_from_previous_stage", True ) if ( migrate_from_previous_stage and getattr(self, "callbacks", None) is not None ): for key, value in self.callbacks.items(): if value.scope == CallbackScope.experiment: callbacks[key] = value callbacks = sort_callbacks_by_order(callbacks) if migrate_from_previous_stage: migrating_params.update( { "global_epoch": getattr(self, "global_epoch", 1), "global_batch_step": getattr(self, "global_batch_step", 0), "global_sample_step": getattr( self, "global_sample_step", 0 ), "resume": getattr(self, "resume", None), } ) self._prepare_inner_state( stage=self.stage, model=model, device=device, criterion=criterion, optimizer=optimizer, scheduler=scheduler, callbacks=callbacks, loaders=loaders, **migrating_params, )
__all__ = ["IRunner", "IStageBasedRunner", "RunnerException"]