from typing import Any, Callable, Dict, Mapping, Tuple, Union # isort:skip
from abc import ABC, abstractmethod
from collections import OrderedDict
import torch
from torch import nn
from torch.utils.data import DataLoader, DistributedSampler
from catalyst.core import utils
from catalyst.utils.tools.typing import (
Criterion, Device, Model, Optimizer, Scheduler
)
from .callback import Callback, CallbackNode, CallbackScope
from .callbacks import ExceptionCallback
from .experiment import _Experiment
from .state import State
class _Runner(ABC):
    """
    Abstract class that all runners inherit from.
    """
_experiment_fn: Callable = _Experiment
_state_fn: Callable = State
    def __init__(
self,
model: Model = None,
device: Device = None,
):
"""
Args:
model (Model): Torch model object
device (Device): Torch device
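
        Example (a minimal sketch; ``MyRunner`` is a hypothetical
        concrete subclass that implements ``forward``)::

            runner = MyRunner(model=nn.Linear(10, 2), device="cpu")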
"""
self._model: Model = model
self._device: Device = device
self._init()
@property
def model(self) -> Model:
"""
Returns the runner's model instance
"""
return self._model
@model.setter
def model(self, value: Union[Model, Dict[str, Model]]):
"""
        Setter for the runner's model; accepts a single ``nn.Module``
        or a dict of named ``nn.Module`` objects.
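
        Example (illustrative only; ``encoder`` and ``decoder`` are
        hypothetical ``nn.Module`` instances)::

            runner.model = nn.Linear(10, 2)
            runner.model = {"encoder": encoder, "decoder": decoder}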
"""
if isinstance(value, nn.Module):
model = value
elif isinstance(value, dict):
            values_are_models = all(
                isinstance(v, nn.Module) for v in value.values()
            )
if not values_are_models:
raise TypeError(
"Invalid dict value type, must be `torch.nn.Module`"
)
model = value
else:
raise TypeError(
f"Invalid value type "
f"must be `torch.nn.Module` or `Dict[str, torch.nn.Module]` "
f"got '{type(value)}'"
)
if self._device is not None:
model: Model = utils.maybe_recursive_call(
model, "to", device=self._device
)
self._model = model
@property
def device(self) -> Device:
"""
Returns the runner's device instance
"""
return self._device
@device.setter
def device(self, value: Device):
"""
        Setter for the runner's device; accepts a ``str`` or a
        ``torch.device`` and moves the current model (if any) onto it.
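
        Example (a minimal sketch)::

            runner.device = "cuda:0"
            runner.device = torch.device("cpu")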
"""
if isinstance(value, (str, torch.device)):
self._device = value
else:
raise TypeError(
f"Invalid value type "
f"must be `str` or `torch.device` "
f"got '{type(value)}'"
)
if self._model is not None:
self._model = utils.maybe_recursive_call(
self._model, "to", device=self._device
)
def _init(self):
self.experiment: _Experiment = None
self.state: State = None
    @abstractmethod
def forward(self, batch: Mapping[str, Any], **kwargs) -> Mapping[str, Any]:
"""
        Forward method for your Runner; should return the model output
        for the given batch.
Args:
batch: Key-value batch items
**kwargs: kwargs to pass to the model
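        Returns:
            Mapping[str, Any]: model output key-value storage

        Example (a minimal sketch of a possible override; the
        ``"features"`` and ``"logits"`` keys are illustrative
        assumptions, not part of the base API)::

            def forward(self, batch, **kwargs):
                return {"logits": self.model(batch["features"], **kwargs)}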
"""
pass
def _get_experiment_components(
self, stage: str = None
) -> Tuple[Model, Criterion, Optimizer, Scheduler, Device]:
"""
        Inner method for child classes, used for model-specific
        initialization. As a baseline, it checks device support and
        moves the model onto that device.
        Returns:
            tuple of model, criterion, optimizer, scheduler and device
"""
utils.set_global_seed(self.experiment.initial_seed)
model = self.experiment.get_model(stage)
criterion, optimizer, scheduler = \
self.experiment.get_experiment_components(model, stage)
model, criterion, optimizer, scheduler, device = \
utils.process_components(
model=model,
criterion=criterion,
optimizer=optimizer,
scheduler=scheduler,
distributed_params=self.experiment.distributed_params,
device=self.device
)
return model, criterion, optimizer, scheduler, device
def _get_state(
self,
stage: str,
model: Model,
criterion: Criterion,
optimizer: Optimizer,
scheduler: Scheduler,
device: Device,
callbacks: Dict[str, Callback],
):
migrating_params = dict(**self.experiment.get_state_params(stage))
migrate_from_previous_stage = \
migrating_params.get("migrate_from_previous_stage", True)
        # keep experiment-scoped callbacks alive between stages
        if migrate_from_previous_stage \
                and self.state is not None \
                and self.state.callbacks is not None:
            for key, value in self.state.callbacks.items():
                if value.scope == CallbackScope.Experiment:
                    callbacks[key] = value
callbacks = utils.process_callbacks(callbacks)
        # carry step/epoch counters (and resume checkpoint, if any)
        # over from the previous stage
        if self.state is not None and migrate_from_previous_stage:
            migrating_params.update(
                {
                    "global_step": self.state.global_step,
                    "global_epoch": self.state.global_epoch,
                    "resume": getattr(self.state, "resume", None),
                }
            )
state = self._state_fn(
stage=stage,
model=model,
device=device,
criterion=criterion,
optimizer=optimizer,
scheduler=scheduler,
callbacks=callbacks,
**migrating_params
)
return state
def _get_callbacks(self, stage: str):
callbacks = self.experiment.get_callbacks(stage)
# distributed run setting
rank = utils.get_rank()
if rank == 0: # master node
# remove worker-only callbacks on master node
for k in list(
filter(
lambda c: callbacks[c].node == CallbackNode.Worker,
callbacks
)
):
del callbacks[k]
elif rank > 0: # worker node
# remove master-only callbacks on worker nodes
for k in list(
filter(
lambda c: callbacks[c].node == CallbackNode.Master,
callbacks
)
):
del callbacks[k]
callbacks = utils.process_callbacks(callbacks)
return callbacks
def _prepare_for_stage(self, stage: str):
utils.set_global_seed(self.experiment.initial_seed)
self.model, criterion, optimizer, scheduler, self.device = \
self._get_experiment_components(stage=stage)
utils.set_global_seed(self.experiment.initial_seed)
callbacks = self._get_callbacks(stage)
utils.set_global_seed(self.experiment.initial_seed)
self.state = self._get_state(
stage=stage,
model=self.model,
criterion=criterion,
optimizer=optimizer,
scheduler=scheduler,
device=self.device,
callbacks=callbacks,
)
def _prepare_for_epoch(self, stage: str, epoch: int):
pass
    def _run_event(self, event: str):
        # dispatch the event (e.g. "on_batch_start")
        # to every registered callback
        for callback in self.state.callbacks.values():
            getattr(callback, event)(self.state)
def _batch2device(self, batch: Mapping[str, Any], device: Device):
output = utils.any2device(batch, device)
return output
def _run_batch_train_step(self, batch: Mapping[str, Any]):
self.state.batch_out = self.forward(batch)
    @torch.no_grad()
def predict_batch(
self, batch: Mapping[str, Any], **kwargs
) -> Mapping[str, Any]:
"""
        Run the model on a batch of elements
        WARN: You should not override this method. If you need a custom
        model call, override the ``forward()`` method instead.
Args:
batch: Key-value batch items
**kwargs: kwargs to pass to the model
Returns:
model output key-value
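
        Example (assuming a runner whose ``forward`` reads a
        ``"features"`` key, as sketched in ``forward`` above)::

            batch = {"features": torch.randn(4, 10)}
            output = runner.predict_batch(batch)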
"""
batch = self._batch2device(batch, self.device)
output = self.forward(batch, **kwargs)
return output
def _run_batch(self, batch: Mapping[str, Any]):
self.state.global_step += self.state.batch_size
batch = self._batch2device(batch, self.device)
self.state.batch_in = batch
self._run_event("on_batch_start")
self._run_batch_train_step(batch=batch)
self._run_event("on_batch_end")
def _run_loader(self, loader: DataLoader):
self.state.batch_size = (
loader.batch_sampler.batch_size
if loader.batch_sampler is not None else loader.batch_size
)
        # if global_step was not migrated from a previous stage,
        # estimate it from the number of samples seen in earlier epochs
        self.state.global_step = (
            self.state.global_step
            or self.state.global_epoch * len(loader) * self.state.batch_size
        )
for i, batch in enumerate(loader):
self.state.loader_step = i + 1
self._run_batch(batch)
if self.state.need_early_stop:
self.state.need_early_stop = False
break
def _run_epoch(self, stage: str, epoch: int):
self._prepare_for_epoch(stage=stage, epoch=epoch)
state: State = self.state
assert state.loaders is not None
loaders = state.loaders
# @TODO: better solution with train/inference handling ?
state.is_infer_stage = state.stage_name.startswith("infer")
if not state.is_infer_stage:
assert state.valid_loader in loaders.keys(), \
f"'{state.valid_loader}' " \
f"should be in provided loaders: {list(loaders.keys())}"
else:
# @TODO: add check for non distributed run for inference
assert not any(x.startswith("train") for x in loaders.keys()), \
"for inference no train loader should be passed"
for loader_name, loader in loaders.items():
state.loader_name = loader_name
state.loader_len = len(loader)
state.is_train_loader = loader_name.startswith("train")
self.model.train(state.is_train_loader)
if isinstance(loader.sampler, DistributedSampler) \
and not state.is_infer_stage:
loader.sampler.set_epoch(state.epoch)
utils.set_global_seed(
self.experiment.initial_seed + state.global_epoch + 1
)
self._run_event("on_loader_start")
with torch.set_grad_enabled(state.is_train_loader):
self._run_loader(loader)
self._run_event("on_loader_end")
def _run_stage(self, stage: str):
self._prepare_for_stage(stage)
state: State = self.state
self._run_event("on_stage_start")
while state.epoch < state.num_epochs + 1:
utils.set_global_seed(
self.experiment.initial_seed + state.global_epoch + 1
)
self._run_event("on_epoch_start")
self._run_epoch(stage=stage, epoch=state.epoch)
self._run_event("on_epoch_end")
if state.need_early_stop:
state.need_early_stop = False
break
state.global_epoch += 1
state.epoch += 1
self._run_event("on_stage_end")
    def run_experiment(self, experiment: _Experiment):
"""
        Starts the experiment.
        Args:
            experiment (_Experiment): experiment to run
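
        Example (a minimal sketch; ``MyExperiment`` and ``MyRunner``
        are hypothetical concrete subclasses of ``_Experiment`` and
        ``_Runner``)::

            experiment = MyExperiment(...)
            runner = MyRunner()
            runner.run_experiment(experiment)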
"""
self.experiment = experiment
try:
for stage in self.experiment.stages:
self._run_stage(stage)
except (Exception, KeyboardInterrupt) as ex:
            def _exception_handler_check(callbacks: OrderedDict):
                return callbacks is not None and any(
                    isinstance(x, ExceptionCallback)
                    for x in callbacks.values()
                )
if self.state is not None and \
_exception_handler_check(self.state.callbacks):
self.state.exception = ex
self._run_event("on_exception")
else:
raise ex
return self
__all__ = ["_Runner"]