Source code for catalyst.dl.core.state
from typing import Dict, Optional  # isort:skip
from collections import defaultdict, OrderedDict
from pathlib import Path

from torch.optim.optimizer import Optimizer

from catalyst.utils.frozen import FrozenClass

from .metric_manager import MetricManager, TimerManager

# TODO: deep refactoring
#  - lr/loss/momentum bypass (how to deal with multiple optimizers?)
class RunnerState(FrozenClass):
    """
    An object used to pass internal state during train/valid/infer.
    """
    def __init__(
        self,
        *,
        device=None,
        model=None,
        criterion=None,
        optimizer: Optional[Optimizer] = None,
        scheduler=None,
        logdir: Optional[str] = None,
        stage: str = "infer",
        num_epochs: int = 1,
        main_metric: str = "loss",
        minimize_metric: bool = True,
        valid_loader: str = "valid",
        verbose: bool = False,
        checkpoint_data: Optional[Dict] = None,
        batch_consistant_metrics: bool = True,
        **kwargs
    ):
        self.logdir = Path(logdir) if logdir is not None else None
        self.model = model
        self.criterion = criterion
        self.optimizer = optimizer
        self.scheduler = scheduler

        # special info
        self.stage = stage
        self.device = device
        self.loader_name = None
        self.phase = None

        # data pipeline
        self.input = None
        self.output = None

        # counters
        self.loader_len = 0
        self.batch_size = 0
        self.step = 0
        self.epoch = 0
        self.stage_epoch = 0
        self.num_epochs = num_epochs

        # metrics & logging
        self.main_metric = main_metric
        self.minimize_metric = minimize_metric
        self.valid_loader = valid_loader
        self.metrics = MetricManager(
            valid_loader=valid_loader,
            main_metric=main_metric,
            minimize=minimize_metric,
            batch_consistant_metrics=batch_consistant_metrics
        )
        self.verbose: bool = verbose
        self.loggers = OrderedDict()
        self.timer = TimerManager()

        # base metrics
        # with a single optimizer, lr/momentum are scalar slots;
        # otherwise they are defaultdicts keyed per optimizer
        single_optimizer = isinstance(optimizer, Optimizer)
        self.lr = None if single_optimizer else defaultdict(lambda: None)
        self.momentum = None if single_optimizer else defaultdict(lambda: None)
        self.loss = None

        # extra checkpoint data for saving in checkpoint files
        self.checkpoint_data = checkpoint_data or {}

        # other
        self.need_backward = False
        self.early_stop = False
        for k, v in kwargs.items():
            setattr(self, k, v)

        self.exception: Optional[Exception] = None
        self.need_reraise_exception: bool = True

        self._freeze()
    def get_key(self, key, inner_key=None):
        """Read ``state.<key>`` (or ``state.<key>[inner_key]`` when given)."""
        if inner_key is None:
            return getattr(self, key)
        else:
            return getattr(self, key)[inner_key]

    def set_key(self, value, key, inner_key=None):
        """Write ``state.<key>`` (or ``state.<key>[inner_key]`` when given)."""
        if inner_key is None:
            setattr(self, key, value)
        else:
            getattr(self, key)[inner_key] = value
    def _handle_runner_metrics(self):
        values = {}
        for key, value in zip(
            ["_base/lr", "_base/momentum"], [self.lr, self.momentum]
        ):
            if value is not None:
                if isinstance(value, dict):
                    # multiple optimizers: log one value per optimizer key
                    for k, v in value.items():
                        values[f"{key}/{k}"] = v
                else:
                    values[key] = value

        values.update(self.timer.elapsed)
        # throughput: samples processed per second in the last batch
        values["_timers/_fps"] = \
            self.batch_size / self.timer.elapsed["_timers/batch_time"]
        self.metrics.add_batch_value(metrics_dict=values)
    def on_stage_start_pre(self):
        pass

    def on_stage_start_post(self):
        pass

    def on_stage_end_pre(self):
        pass

    def on_stage_end_post(self):
        pass

    def on_epoch_start_pre(self):
        self.metrics.begin_epoch()

    def on_epoch_start_post(self):
        pass

    def on_epoch_end_pre(self):
        if not self.stage.startswith("infer"):
            self.metrics.end_epoch_train()

    def on_epoch_end_post(self):
        pass

    def on_loader_start_pre(self):
        self.metrics.begin_loader(self.loader_name)

    def on_loader_start_post(self):
        pass

    def on_loader_end_pre(self):
        pass

    def on_loader_end_post(self):
        self.metrics.end_loader()

    def on_batch_start_pre(self):
        self.metrics.begin_batch()

    def on_batch_start_post(self):
        pass

    def on_batch_end_pre(self):
        pass

    def on_batch_end_post(self):
        self._handle_runner_metrics()
        self.metrics.end_batch()

    def on_exception_pre(self):
        pass

    def on_exception_post(self):
        pass
    @property
    def stage_epoch_log(self):
        """1-based epoch counter within the current stage, for logging."""
        return self.stage_epoch + 1

    @property
    def epoch_log(self):
        """1-based global epoch counter, for logging."""
        return self.epoch + 1


__all__ = ["RunnerState"]
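
Below is a minimal usage sketch, not part of the module above. It assumes torch and this module are importable and that FrozenClass only blocks the creation of new attributes after _freeze(); the model and optimizer are illustrative stand-ins.

# --- usage sketch (illustrative, not library source) ---
import torch

from catalyst.dl.core.state import RunnerState

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

state = RunnerState(
    device=torch.device("cpu"),
    model=model,
    optimizer=optimizer,
    stage="train",
    num_epochs=2,
)

# callbacks read and write state generically through get_key/set_key
assert state.get_key("num_epochs") == 2
state.set_key(0.5, "loss")

# with a single Optimizer, state.lr is a scalar slot; with several
# optimizers it is a defaultdict written per optimizer via inner_key,
# e.g. state.set_key(1e-3, "lr", inner_key="generator")
state.set_key(1e-3, "lr")

print(state.epoch_log)  # 1-based counter for human-readable logging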