# Source code for catalyst.dl.core.runner

from typing import Any, Mapping, Optional, Tuple  # isort:skip
from abc import ABC, abstractmethod
from collections import OrderedDict
import os
from pathlib import Path

import torch
from torch import nn
from torch.utils.data import DistributedSampler

from catalyst.dl import utils
from catalyst.dl.utils.scripts import dump_base_experiment_code
from catalyst.dl.utils.torch import _Criterion, _Model, _Optimizer, _Scheduler
from .callback import Callback, LoggerCallback
from .experiment import Experiment
from .state import RunnerState


class Runner(ABC):
    """
    Abstract base class that drives an ``Experiment`` stage by stage.

    For every stage the runner:

    - builds the model, criterion, optimizer, scheduler and device;
    - creates a new ``RunnerState``;
    - iterates over epochs, loaders and batches, dispatching ``on_*``
      events to the state, the logger callbacks and the regular callbacks.

    Subclasses must implement :meth:`forward`.
    """

    def __init__(
        self,
        model: nn.Module = None,
        device=None,
    ):
        """
        Args:
            model: initial model. It is replaced by the experiment's model
                when a stage is prepared.
            device: requested device. It is forwarded to
                ``utils.process_components``, and the device returned from
                there replaces it.
        """
        # main
        self.model: nn.Module = model
        self.device = device
        self.experiment: Experiment = None
        self.state: RunnerState = None
        self.callbacks: OrderedDict[str, Callback] = None
        self.loggers: OrderedDict[str, LoggerCallback] = None
        # additional
        self._check_run = False

    def _batch2device(self, batch: Mapping[str, Any], device):
        """Move the items of ``batch`` to ``device`` via ``utils.any2device``."""
        res = utils.any2device(batch, device)
        return res

    def _get_experiment_components(
        self, stage: str = None
    ) -> Tuple[_Model, _Criterion, _Optimizer, _Scheduler, torch.device]:
        """
        Inner method for children's classes for model specific initialization.
        As baseline, checks device support and puts model on it.

        Args:
            stage: name of the stage whose components are requested.

        Returns:
            tuple ``(model, criterion, optimizer, scheduler, device)`` as
            returned by ``utils.process_components``.
        """
        utils.set_global_seed(self.experiment.initial_seed)
        model = self.experiment.get_model(stage)
        criterion, optimizer, scheduler = \
            self.experiment.get_experiment_components(model, stage)

        model, criterion, optimizer, scheduler, device = \
            utils.process_components(
                model=model,
                criterion=criterion,
                optimizer=optimizer,
                scheduler=scheduler,
                distributed_params=self.experiment.distributed_params,
                device=self.device
            )

        return model, criterion, optimizer, scheduler, device

    def _prepare_for_stage(self, stage: str):
        """
        Build the components for ``stage`` and create a new ``RunnerState``.

        The ``step`` and ``epoch`` counters of the previous state, if any,
        are passed into the new state so they continue across stages.
        """
        utils.set_global_seed(self.experiment.initial_seed)
        migrating_params = {}
        if self.state is not None:
            migrating_params.update(
                {
                    "step": self.state.step,
                    "epoch": self.state.epoch
                }
            )

        self.model, criterion, optimizer, scheduler, self.device = \
            self._get_experiment_components(stage)

        self.state = RunnerState(
            stage=stage,
            model=self.model,
            device=self.device,
            criterion=criterion,
            optimizer=optimizer,
            scheduler=scheduler,
            **self.experiment.get_state_params(stage),
            **migrating_params
        )
        utils.set_global_seed(self.experiment.initial_seed)

    def _run_event(self, event: str, moment: Optional[str]):
        """
        Dispatch ``on_{event}`` (or ``on_{event}_{moment}``) handlers.

        Call order:

        1. ``state.<fn>_pre()``
        2. loggers, only when ``moment == "start"``
        3. regular callbacks
        4. loggers, when ``moment == "end"`` or ``moment`` is None
           (the ``on_exception`` case)
        5. ``state.<fn>_post()``

        State hooks are skipped while ``self.state`` is None. Loggers and
        callbacks are skipped while they are None.
        """
        fn_name = f"on_{event}"
        if moment is not None:
            fn_name = f"{fn_name}_{moment}"

        # before callbacks
        if self.state is not None:
            getattr(self.state, f"{fn_name}_pre")()

        if self.loggers is not None and moment == "start":
            for logger in self.loggers.values():
                getattr(logger, fn_name)(self.state)

        # running callbacks
        if self.callbacks is not None:
            for callback in self.callbacks.values():
                getattr(callback, fn_name)(self.state)

        # after callbacks
        if self.loggers is not None and \
                (moment == "end" or moment is None):  # for on_exception case
            for logger in self.loggers.values():
                getattr(logger, fn_name)(self.state)

        if self.state is not None:
            getattr(self.state, f"{fn_name}_post")()

    @abstractmethod
    def forward(self, batch: Mapping[str, Any]) -> Mapping[str, Any]:
        """
        Run the model on a batch.

        Both callers in this class (``_run_batch`` and ``predict_batch``)
        move the batch to ``self.device`` before calling this method.

        Args:
            batch: Key-value batch items

        Returns:
            model output key-value
        """
        pass

    def predict_batch(self, batch: Mapping[str, Any]) -> Mapping[str, Any]:
        """
        Run model for a batch of elements
        WARN: You should not override this method. If you need specific model
        call, override forward() method

        Args:
            batch: Key-value batch items

        Returns:
            model output key-value
        """
        batch = self._batch2device(batch, self.device)
        output = self.forward(batch)
        return output

    def _run_batch(self, batch):
        """
        Process one batch.

        Advances ``state.step`` by the batch size, moves the batch to the
        device and stores it in ``state.input``. Then calls ``forward`` and
        stores the result in ``state.output``, firing batch start/end events
        and updating the data/model/batch timers along the way.
        """
        self.state.step += self.state.batch_size
        batch = self._batch2device(batch, self.device)
        self.state.input = batch
        self.state.timer.stop("_timers/data_time")

        self._run_event("batch", moment="start")
        self.state.timer.start("_timers/model_time")
        self.state.output = self.forward(batch)
        self.state.timer.stop("_timers/model_time")
        self.state.timer.stop("_timers/batch_time")
        self._run_event("batch", moment="end")

    def _run_loader(self, loader):
        """
        Iterate over every batch of ``loader``.

        In check mode (``self._check_run``) iteration stops after the batch
        with index 3.
        """
        # The batch size comes from the batch sampler when one is set.
        self.state.batch_size = (
            loader.batch_sampler.batch_size
            if loader.batch_sampler is not None else loader.batch_size
        )
        # Keep an existing non-zero step; otherwise derive one from the epoch.
        self.state.step = (
            self.state.step
            or self.state.epoch * len(loader) * self.state.batch_size
        )
        # @TODO: remove time usage, use it under the hood
        self.state.timer.reset()

        self.state.timer.start("_timers/batch_time")
        self.state.timer.start("_timers/data_time")

        for i, batch in enumerate(loader):
            self._run_batch(batch)

            self.state.timer.reset()
            if self._check_run and i >= 3:
                break

            self.state.timer.start("_timers/batch_time")
            self.state.timer.start("_timers/data_time")

    def _run_epoch(self, loaders):
        """
        Run one epoch over every loader in ``loaders``.

        Only loaders whose name starts with ``"train"`` enable gradients and
        the model's train mode. For non-inference stages ``loaders`` must
        contain ``state.valid_loader``. For stages whose name starts with
        ``"infer"``, no loader name may start with ``"train"``.
        """
        # @TODO: better solution with train/inference handling ?
        if not self.state.stage.startswith("infer"):
            assert self.state.valid_loader in loaders.keys(), \
                f"'{self.state.valid_loader}' " \
                f"should be in provided loaders: {list(loaders.keys())}"
        else:
            assert not any(x.startswith("train") for x in loaders.keys()), \
                "for inference no train loader should be passed"

        for loader_name, loader in loaders.items():
            self.state.loader_name = loader_name
            self.state.loader_len = len(loader)
            self.state.need_backward = loader_name.startswith("train")
            utils.maybe_recursive_call(
                self.model, "train", mode=self.state.need_backward
            )

            # DistributedSampler reshuffles per epoch only via set_epoch().
            if isinstance(loader.sampler, DistributedSampler) \
                    and loader_name.startswith("train"):
                loader.sampler.set_epoch(self.state.stage_epoch)

            utils.set_global_seed(
                self.experiment.initial_seed + self.state.epoch + 1
            )
            self._run_event("loader", moment="start")
            with torch.set_grad_enabled(self.state.need_backward):
                self._run_loader(loader)
            self._run_event("loader", moment="end")

    def _run_stage(self, stage: str):
        """
        Prepare ``stage`` and run its epochs.

        The experiment's callbacks are split into ``LoggerCallback``
        instances (``self.loggers``) and all others (``self.callbacks``).
        The loop ends after ``state.num_epochs`` epochs, on
        ``state.early_stop`` (which is then reset), or in check mode once
        ``state.epoch >= 3``.
        """
        self._prepare_for_stage(stage)
        loaders = self.experiment.get_loaders(stage)
        callbacks = self.experiment.get_callbacks(stage)

        loggers = utils.process_callbacks(
            OrderedDict([
                (k, v) for k, v in callbacks.items()
                if isinstance(v, LoggerCallback)
            ])
        )
        callbacks = utils.process_callbacks(
            OrderedDict([
                (k, v) for k, v in callbacks.items()
                if not isinstance(v, LoggerCallback)
            ])
        )

        self.state.loggers = loggers
        self.loggers = loggers
        self.callbacks = callbacks

        self._run_event("stage", moment="start")
        for epoch in range(self.state.num_epochs):
            self.state.stage_epoch = epoch

            self._run_event("epoch", moment="start")
            self._run_epoch(loaders)
            self._run_event("epoch", moment="end")

            if self._check_run and self.state.epoch >= 3:
                break
            if self.state.early_stop:
                self.state.early_stop = False
                break

            self.state.epoch += 1
        self._run_event("stage", moment="end")

    def run_experiment(self, experiment: Experiment, check: bool = False):
        """
        Run every stage of ``experiment`` in order.

        Args:
            experiment: experiment to run.
            check: if True, run in check mode. Loader and epoch loops exit
                early (see ``_run_loader`` / ``_run_stage``).

        Returns:
            self

        Raises:
            Exception / KeyboardInterrupt: re-raised unchanged only if it
                occurs before any ``RunnerState`` exists. Otherwise it is
                stored in ``state.exception`` and dispatched as the
                ``on_exception`` event instead of being raised here.
        """
        self._check_run = check
        self.experiment = experiment

        # jupyter source code logging hack
        # + hack to prevent cycle imports
        from catalyst.dl.experiment import BaseExperiment
        if isinstance(self.experiment, BaseExperiment) \
                and self.experiment.logdir is not None:
            expdir = Path(os.getcwd())
            logdir = Path(self.experiment.logdir)
            dump_base_experiment_code(expdir, logdir)

        try:
            for stage in self.experiment.stages:
                self._run_stage(stage)
        except (Exception, KeyboardInterrupt) as ex:
            if self.state is None:
                # The failure happened before any RunnerState was created
                # (e.g. while building the first stage's components), so no
                # callback can handle it. Re-raise the original error instead
                # of masking it with an AttributeError on `None.exception`.
                raise
            self.state.exception = ex
            self._run_event("exception", moment=None)

        return self
# Public API of this module: only the abstract Runner base class is exported.
__all__ = ["Runner"]