# Source code for catalyst.dl.core.runner

from typing import Any, Mapping, Optional, Tuple, Dict, Union  # isort:skip
from abc import ABC, abstractmethod
from collections import OrderedDict
import os
from pathlib import Path

import torch
from torch import nn
from torch.utils.data import DistributedSampler

from catalyst.dl import utils
from catalyst.utils.typing import (
    Criterion, Device, Model, Optimizer, Scheduler
)
from .callback import Callback, LoggerCallback
from .experiment import Experiment
from .state import RunnerState


class Runner(ABC):
    """
    Abstract base class that all runners inherit from.

    A runner owns a model and a device, and drives an ``Experiment``
    stage by stage, epoch by epoch and loader by loader, firing
    ``on_<event>_<moment>`` hooks on the state, the loggers and the
    callbacks around every step. Subclasses must implement ``forward``.
    """

    def __init__(
        self,
        model: Model = None,
        device: Device = None,
    ):
        """
        Args:
            model (Model): Torch model object
            device (Device): Torch device
        """
        # main
        # NOTE: stored as-is; the validation in the ``model``/``device``
        # property setters is not applied to constructor arguments
        self._model: Model = model
        self._device: Device = device

        self.experiment: Experiment = None
        self.state: RunnerState = None

        # both are filled in by ``_run_stage``
        self.callbacks: OrderedDict[str, Callback] = None
        self.loggers: OrderedDict[str, LoggerCallback] = None

        # additional
        # when True, loader and epoch loops are cut short
        # (see ``_run_loader`` and ``_run_stage``)
        self._check_run = False

    def _batch2device(self, batch: Mapping[str, Any], device: Device):
        """
        Passes ``batch`` and ``device`` to ``utils.any2device``
        and returns its result.
        """
        res = utils.any2device(batch, device)
        return res

    def _get_experiment_components(
        self, stage: str = None
    ) -> Tuple[Model, Criterion, Optimizer, Scheduler, Device]:
        """
        Inner method for children's classes
        for model specific initialization.

        Resets the global seed, asks the experiment for the model,
        criterion, optimizer and scheduler of ``stage`` and passes them
        through ``utils.process_components`` together with the
        experiment's distributed params and the runner's device.

        Args:
            stage (str): name of the stage to get the components for

        Returns:
            tuple of (model, criterion, optimizer, scheduler, device)
            as returned by ``utils.process_components``
        """
        utils.set_global_seed(self.experiment.initial_seed)
        model = self.experiment.get_model(stage)
        criterion, optimizer, scheduler = \
            self.experiment.get_experiment_components(model, stage)

        model, criterion, optimizer, scheduler, device = \
            utils.process_components(
                model=model,
                criterion=criterion,
                optimizer=optimizer,
                scheduler=scheduler,
                distributed_params=self.experiment.distributed_params,
                device=self.device
            )

        return model, criterion, optimizer, scheduler, device

    def _prepare_for_stage(self, stage: str):
        """
        Builds a fresh ``RunnerState`` for ``stage``.

        The ``step`` and ``epoch`` counters of the previous state
        (if there is one) are carried over into the new state.
        """
        utils.set_global_seed(self.experiment.initial_seed)
        migrating_params = {}
        if self.state is not None:
            migrating_params.update(
                {
                    "step": self.state.step,
                    "epoch": self.state.epoch
                }
            )

        # assigning through the properties validates both values
        self.model, criterion, optimizer, scheduler, self.device = \
            self._get_experiment_components(stage)

        self.state = RunnerState(
            stage=stage,
            model=self.model,
            device=self.device,
            criterion=criterion,
            optimizer=optimizer,
            scheduler=scheduler,
            **self.experiment.get_state_params(stage),
            **migrating_params
        )
        utils.set_global_seed(self.experiment.initial_seed)

    def _run_event(self, event: str, moment: Optional[str]):
        """
        Fires the ``on_<event>[_<moment>]`` hook.

        Call order: ``<hook>_pre`` on the state, loggers (only for the
        "start" moment), callbacks, loggers (for the "end" moment or
        when ``moment`` is None), ``<hook>_post`` on the state.

        Args:
            event (str): event name, e.g. "stage", "epoch", "loader",
                "batch" or "exception"
            moment (str, optional): "start", "end" or None
        """
        fn_name = f"on_{event}"
        if moment is not None:
            fn_name = f"{fn_name}_{moment}"

        # before callbacks
        if self.state is not None:
            getattr(self.state, f"{fn_name}_pre")()

        if self.loggers is not None and moment == "start":
            for logger in self.loggers.values():
                getattr(logger, fn_name)(self.state)

        # running callbacks
        if self.callbacks is not None:
            for callback in self.callbacks.values():
                getattr(callback, fn_name)(self.state)

        # after callbacks
        if self.loggers is not None and \
                (moment == "end" or moment is None):  # for on_exception case
            for logger in self.loggers.values():
                getattr(logger, fn_name)(self.state)

        if self.state is not None:
            getattr(self.state, f"{fn_name}_post")()

    @property
    def model(self) -> Model:
        """
        Returns the runner's model instance
        """
        return self._model

    @model.setter
    def model(self, value: Union[Model, Dict[str, Model]]):
        """
        Setter for the runner's model

        Accepts a ``torch.nn.Module`` or a dict whose values are all
        ``torch.nn.Module`` instances. If the runner already has
        a device, ``to(device=...)`` is applied to the new model
        through ``utils.maybe_recursive_call``.

        Raises:
            TypeError: if ``value`` is neither a module
                nor a dict of modules
        """
        if isinstance(value, nn.Module):
            model = value
        elif isinstance(value, dict):
            values_are_models = all(
                isinstance(v, nn.Module) for v in value.values()
            )
            if not values_are_models:
                raise TypeError(
                    "Invalid dict value type, must be `torch.nn.Module`"
                )

            model = value
        else:
            raise TypeError(
                f"Invalid value type "
                f"must be `torch.nn.Module` or `Dict[str, torch.nn.Module]` "
                f"got '{type(value)}'"
            )

        if self._device is not None:
            model = utils.maybe_recursive_call(
                model, "to", device=self._device
            )

        self._model = model

    @property
    def device(self) -> Device:
        """
        Returns the runner's device instance
        """
        return self._device

    @device.setter
    def device(self, value: Device):
        """
        Setter for the runner's device

        Accepts a ``str`` or a ``torch.device``. If the runner already
        has a model, ``to(device=...)`` is applied to it through
        ``utils.maybe_recursive_call``.

        Raises:
            TypeError: if ``value`` is neither ``str``
                nor ``torch.device``
        """
        if isinstance(value, (str, torch.device)):
            self._device = value
        else:
            raise TypeError(
                f"Invalid value type "
                f"must be `str` or `torch.device` "
                f"got '{type(value)}'"
            )

        if self._model is not None:
            self._model = utils.maybe_recursive_call(
                self._model, "to", device=self._device
            )

    @abstractmethod
    def forward(self, batch: Mapping[str, Any], **kwargs) -> Mapping[str, Any]:
        """
        Forward method for your Runner

        Args:
            batch: Key-value batch items
            **kwargs: kwargs to pass to the model
        """
        pass

    def predict_batch(
        self, batch: Mapping[str, Any], **kwargs
    ) -> Mapping[str, Any]:
        """
        Run model for a batch of elements
        WARN: You should not override this method. If you need specific model
        call, override forward() method

        Args:
            batch: Key-value batch items
            **kwargs: kwargs to pass to the model

        Returns:
            model output key-value
        """
        batch = self._batch2device(batch, self.device)
        output = self.forward(batch, **kwargs)
        return output

    def _run_batch(self, batch):
        """
        Processes one batch: advances the step counter by the batch
        size, moves the batch to the device, runs ``forward`` and fires
        the "batch" start/end events, stopping the data/model/batch
        timers along the way.
        """
        self.state.step += self.state.batch_size
        batch = self._batch2device(batch, self.device)
        self.state.input = batch
        self.state.timer.stop("_timers/data_time")

        self._run_event("batch", moment="start")

        self.state.timer.start("_timers/model_time")
        self.state.output = self.forward(batch)
        self.state.timer.stop("_timers/model_time")

        self.state.timer.stop("_timers/batch_time")
        self._run_event("batch", moment="end")

    def _run_loader(self, loader):
        """
        Iterates over ``loader`` and runs every batch.

        With ``_check_run`` enabled the loop stops after the batch
        with index 3.
        """
        self.state.batch_size = (
            loader.batch_sampler.batch_size
            if loader.batch_sampler is not None else loader.batch_size
        )
        # keep a non-zero step; otherwise derive it from the epoch number
        self.state.step = (
            self.state.step
            or self.state.epoch * len(loader) * self.state.batch_size
        )
        # @TODO: remove time usage, use it under the hood
        self.state.timer.reset()

        self.state.timer.start("_timers/batch_time")
        self.state.timer.start("_timers/data_time")

        for i, batch in enumerate(loader):
            self._run_batch(batch)

            self.state.timer.reset()
            if self._check_run and i >= 3:
                break

            self.state.timer.start("_timers/batch_time")
            self.state.timer.start("_timers/data_time")

    def _run_epoch(self, loaders):
        """
        Runs every loader of ``loaders`` once.

        Loaders whose name starts with "train" are run with gradients
        enabled and the model in train mode; all other loaders are run
        with gradients disabled and ``train(mode=False)``.

        Args:
            loaders: mapping from loader name to loader
        """
        # @TODO: better solution with train/inference handling ?
        if not self.state.stage.startswith("infer"):
            assert self.state.valid_loader in loaders.keys(), \
                f"'{self.state.valid_loader}' " \
                f"should be in provided loaders: {list(loaders.keys())}"
        else:
            assert not any(x.startswith("train") for x in loaders.keys()), \
                "for inference no train loader should be passed"

        for loader_name, loader in loaders.items():
            self.state.loader_name = loader_name
            self.state.loader_len = len(loader)
            self.state.need_backward = loader_name.startswith("train")
            utils.maybe_recursive_call(
                self.model, "train", mode=self.state.need_backward
            )

            if isinstance(loader.sampler, DistributedSampler) \
                    and loader_name.startswith("train"):
                loader.sampler.set_epoch(self.state.stage_epoch)

            # the seed depends on the epoch number
            utils.set_global_seed(
                self.experiment.initial_seed + self.state.epoch + 1
            )
            self._run_event("loader", moment="start")
            with torch.set_grad_enabled(self.state.need_backward):
                self._run_loader(loader)
            self._run_event("loader", moment="end")

    def _run_stage(self, stage: str):
        """
        Runs one stage: prepares the state, splits the experiment's
        callbacks into loggers (``LoggerCallback`` instances) and
        regular callbacks, then runs the epochs.

        The epoch loop stops early when ``state.early_stop`` is set,
        or, with ``_check_run`` enabled, once ``state.epoch`` is 3
        or more.
        """
        self._prepare_for_stage(stage)
        loaders = self.experiment.get_loaders(stage)
        callbacks = self.experiment.get_callbacks(stage)

        loggers = utils.process_callbacks(
            OrderedDict(
                [
                    (k, v) for k, v in callbacks.items()
                    if isinstance(v, LoggerCallback)
                ]
            )
        )
        callbacks = utils.process_callbacks(
            OrderedDict(
                [
                    (k, v) for k, v in callbacks.items()
                    if not isinstance(v, LoggerCallback)
                ]
            )
        )

        self.state.loggers = loggers
        self.loggers = loggers
        self.callbacks = callbacks

        self._run_event("stage", moment="start")
        for epoch in range(self.state.num_epochs):
            self.state.stage_epoch = epoch

            self._run_event("epoch", moment="start")
            self._run_epoch(loaders)
            self._run_event("epoch", moment="end")

            if self._check_run and self.state.epoch >= 3:
                break
            if self.state.early_stop:
                self.state.early_stop = False
                break

            self.state.epoch += 1
        self._run_event("stage", moment="end")

    def run_experiment(self, experiment: Experiment, check: bool = False):
        """
        Starts the experiment

        Runs every stage of ``experiment``. An exception raised while
        the stages are running is stored in ``state.exception`` and
        reported through the "exception" event. An exception raised
        before the state, loggers and callbacks exist is re-raised
        unchanged, because nothing could handle it yet.

        Args:
            experiment (Experiment): experiment to run
            check (bool): if True, loader and epoch loops are cut short

        Returns:
            the runner itself
        """
        self._check_run = check
        self.experiment = experiment

        # jupyter source code logging hack
        # + hack to prevent cycle imports
        from catalyst.dl.experiment import BaseExperiment
        if isinstance(self.experiment, BaseExperiment) \
                and self.experiment.logdir is not None:
            expdir = Path(os.getcwd())
            logdir = Path(self.experiment.logdir)
            utils.dump_base_experiment_code(expdir, logdir)

        try:
            for stage in self.experiment.stages:
                self._run_stage(stage)
        except (Exception, KeyboardInterrupt) as ex:
            # Without a state there is nowhere to store the exception,
            # and without loggers/callbacks the "exception" event has
            # no handlers: re-raise instead of masking or swallowing it
            if self.state is None \
                    or self.loggers is None or self.callbacks is None:
                raise
            self.state.exception = ex
            self._run_event("exception", moment=None)

        return self
# Public API of this module: the only name exported by a star-import.
__all__ = ["Runner"]