Shortcuts

Source code for catalyst.core.runner

from typing import Any, Callable, Dict, Mapping, Tuple, Union  # isort:skip

from abc import ABC, abstractmethod
from collections import OrderedDict

import torch
from torch import nn
from torch.utils.data import DataLoader, DistributedSampler

from catalyst.core import utils
from catalyst.utils.tools.typing import (
    Criterion, Device, Model, Optimizer, Scheduler
)
from .callback import Callback, CallbackNode, CallbackScope
from .callbacks import ExceptionCallback
from .experiment import _Experiment
from .state import State


[docs]class _Runner(ABC): """ Abstract class for all runners inherited from """ _experiment_fn: Callable = _Experiment _state_fn: Callable = State
[docs] def __init__( self, model: Model = None, device: Device = None, ): """ Args: model (Model): Torch model object device (Device): Torch device """ self._model: Model = model self._device: Device = device self._init()
@property def model(self) -> Model: """ Returns the runner's model instance """ return self._model @model.setter def model(self, value: Union[Model, Dict[str, Model]]): """ Setter for the runner's model' """ if isinstance(value, nn.Module): model = value elif isinstance(value, dict): values_are_models = all( [isinstance(v, nn.Module) for v in value.values()] ) if not values_are_models: raise TypeError( "Invalid dict value type, must be `torch.nn.Module`" ) model = value else: raise TypeError( f"Invalid value type " f"must be `torch.nn.Module` or `Dict[str, torch.nn.Module]` " f"got '{type(value)}'" ) if self._device is not None: model: Model = utils.maybe_recursive_call( model, "to", device=self._device ) self._model = model @property def device(self) -> Device: """ Returns the runner's device instance """ return self._device @device.setter def device(self, value: Device): """ Setter for the runner's device' """ if isinstance(value, (str, torch.device)): self._device = value else: raise TypeError( f"Invalid value type " f"must be `str` or `torch.device` " f"got '{type(value)}'" ) if self._model is not None: self._model = utils.maybe_recursive_call( self._model, "to", device=self._device ) def _init(self): self.experiment: _Experiment = None self.state: State = None
[docs] @abstractmethod def forward(self, batch: Mapping[str, Any], **kwargs) -> Mapping[str, Any]: """ Forward method for your Runner Args: batch: Key-value batch items **kwargs: kwargs to pass to the model """ pass
def _get_experiment_components( self, stage: str = None ) -> Tuple[Model, Criterion, Optimizer, Scheduler, Device]: """ Inner method for children's classes for model specific initialization. As baseline, checks device support and puts model on it. :return: """ utils.set_global_seed(self.experiment.initial_seed) model = self.experiment.get_model(stage) criterion, optimizer, scheduler = \ self.experiment.get_experiment_components(model, stage) model, criterion, optimizer, scheduler, device = \ utils.process_components( model=model, criterion=criterion, optimizer=optimizer, scheduler=scheduler, distributed_params=self.experiment.distributed_params, device=self.device ) return model, criterion, optimizer, scheduler, device def _get_state( self, stage: str, model: Model, criterion: Criterion, optimizer: Optimizer, scheduler: Scheduler, device: Device, callbacks: Dict[str, Callback], ): migrating_params = dict(**self.experiment.get_state_params(stage)) migrate_from_previous_stage = \ migrating_params.get("migrate_from_previous_stage", True) if migrate_from_previous_stage \ and self.state is not None \ and self.state.callbacks is not None: for key, value in self.state.callbacks.items(): if value.scope == CallbackScope.Experiment: callbacks[key] = value callbacks = utils.process_callbacks(callbacks) if self.state is not None and migrate_from_previous_stage: migrating_params.update( { "global_step": self.state.global_step, "global_epoch": self.state.global_epoch, "resume": getattr(self.state, "resume", None), } ) state = self._state_fn( stage=stage, model=model, device=device, criterion=criterion, optimizer=optimizer, scheduler=scheduler, callbacks=callbacks, **migrating_params ) return state def _get_callbacks(self, stage: str): callbacks = self.experiment.get_callbacks(stage) # distributed run setting rank = utils.get_rank() if rank == 0: # master node # remove worker-only callbacks on master node for k in list( filter( lambda c: callbacks[c].node == CallbackNode.Worker, callbacks ) ): del callbacks[k] elif rank > 0: # worker node # remove master-only callbacks on worker nodes for k in list( filter( lambda c: callbacks[c].node == CallbackNode.Master, callbacks ) ): del callbacks[k] callbacks = utils.process_callbacks(callbacks) return callbacks def _prepare_for_stage(self, stage: str): utils.set_global_seed(self.experiment.initial_seed) self.model, criterion, optimizer, scheduler, self.device = \ self._get_experiment_components(stage=stage) utils.set_global_seed(self.experiment.initial_seed) callbacks = self._get_callbacks(stage) utils.set_global_seed(self.experiment.initial_seed) self.state = self._get_state( stage=stage, model=self.model, criterion=criterion, optimizer=optimizer, scheduler=scheduler, device=self.device, callbacks=callbacks, ) def _prepare_for_epoch(self, stage: str, epoch: int): pass def _run_event(self, event: str): for callback in self.state.callbacks.values(): getattr(callback, event)(self.state) def _batch2device(self, batch: Mapping[str, Any], device: Device): output = utils.any2device(batch, device) return output def _run_batch_train_step(self, batch: Mapping[str, Any]): self.state.batch_out = self.forward(batch)
[docs] @torch.no_grad() def predict_batch( self, batch: Mapping[str, Any], **kwargs ) -> Mapping[str, Any]: """ Run model for a batch of elements WARN: You should not override this method. If you need specific model call, override forward() method Args: batch: Key-value batch items **kwargs: kwargs to pass to the model Returns: model output key-value """ batch = self._batch2device(batch, self.device) output = self.forward(batch, **kwargs) return output
def _run_batch(self, batch: Mapping[str, Any]): self.state.global_step += self.state.batch_size batch = self._batch2device(batch, self.device) self.state.batch_in = batch self._run_event("on_batch_start") self._run_batch_train_step(batch=batch) self._run_event("on_batch_end") def _run_loader(self, loader: DataLoader): self.state.batch_size = ( loader.batch_sampler.batch_size if loader.batch_sampler is not None else loader.batch_size ) self.state.global_step = ( self.state.global_step or self.state.global_epoch * len(loader) * self.state.batch_size ) for i, batch in enumerate(loader): self.state.loader_step = i + 1 self._run_batch(batch) if self.state.need_early_stop: self.state.need_early_stop = False break def _run_epoch(self, stage: str, epoch: int): self._prepare_for_epoch(stage=stage, epoch=epoch) state: State = self.state assert state.loaders is not None loaders = state.loaders # @TODO: better solution with train/inference handling ? state.is_infer_stage = state.stage_name.startswith("infer") if not state.is_infer_stage: assert state.valid_loader in loaders.keys(), \ f"'{state.valid_loader}' " \ f"should be in provided loaders: {list(loaders.keys())}" else: # @TODO: add check for non distributed run for inference assert not any(x.startswith("train") for x in loaders.keys()), \ "for inference no train loader should be passed" for loader_name, loader in loaders.items(): state.loader_name = loader_name state.loader_len = len(loader) state.is_train_loader = loader_name.startswith("train") self.model.train(state.is_train_loader) if isinstance(loader.sampler, DistributedSampler) \ and not state.is_infer_stage: loader.sampler.set_epoch(state.epoch) utils.set_global_seed( self.experiment.initial_seed + state.global_epoch + 1 ) self._run_event("on_loader_start") with torch.set_grad_enabled(state.is_train_loader): self._run_loader(loader) self._run_event("on_loader_end") def _run_stage(self, stage: str): self._prepare_for_stage(stage) state: State = self.state self._run_event("on_stage_start") while state.epoch < state.num_epochs + 1: utils.set_global_seed( self.experiment.initial_seed + state.global_epoch + 1 ) self._run_event("on_epoch_start") self._run_epoch(stage=stage, epoch=state.epoch) self._run_event("on_epoch_end") if state.need_early_stop: state.need_early_stop = False break state.global_epoch += 1 state.epoch += 1 self._run_event("on_stage_end")
[docs] def run_experiment(self, experiment: _Experiment): """ Starts the experiment """ self.experiment = experiment try: for stage in self.experiment.stages: self._run_stage(stage) except (Exception, KeyboardInterrupt) as ex: def _exception_handler_check(callbacks: OrderedDict): return ( callbacks is not None and any( issubclass(x.__class__, ExceptionCallback) for x in callbacks.values() ) ) if self.state is not None and \ _exception_handler_check(self.state.callbacks): self.state.exception = ex self._run_event("on_exception") else: raise ex return self
__all__ = ["_Runner"]