Source code for catalyst.dl.core.state
from typing import Dict, Optional  # isort:skip
from collections import defaultdict, OrderedDict
from pathlib import Path
from torch.optim.optimizer import Optimizer
from catalyst.utils.frozen import FrozenClass
from .metric_manager import MetricManager, TimerManager
# TODO Deep refactoring
#  - lr/loss/momentum bypass (how to deal when multiple optimizers?)
class RunnerState(FrozenClass):
    """
    An object that is used to pass internal state during train/valid/infer.
    """
    def __init__(
        self,
        *,
        device=None,
        model=None,
        criterion=None,
        optimizer: Optional[Optimizer] = None,
        scheduler=None,
        logdir: Optional[str] = None,
        stage: str = "infer",
        num_epochs: int = 1,
        main_metric: str = "loss",
        minimize_metric: bool = True,
        valid_loader: str = "valid",
        verbose: bool = False,
        checkpoint_data: Optional[Dict] = None,
        batch_consistant_metrics: bool = True,
        **kwargs
    ):
        self.logdir = Path(logdir) if logdir is not None else None
        self.model = model
        self.criterion = criterion
        self.optimizer = optimizer
        self.scheduler = scheduler
        # special info
        self.stage = stage
        self.device = device
        self.loader_name = None
        self.phase = None
        # data pipeline
        self.input = None
        self.output = None
        # counters
        self.loader_len = 0
        self.batch_size = 0
        self.step = 0
        self.epoch = 0
        self.stage_epoch = 0
        self.num_epochs = num_epochs
        # metrics & logging
        self.main_metric = main_metric
        self.minimize_metric = minimize_metric
        self.valid_loader = valid_loader
        self.metrics = MetricManager(
            valid_loader=valid_loader,
            main_metric=main_metric,
            minimize=minimize_metric,
            batch_consistant_metrics=batch_consistant_metrics
        )
        self.verbose: bool = verbose
        self.loggers = OrderedDict()
        self.timer = TimerManager()
        # base metrics
        # with a single optimizer, ``lr``/``momentum`` are plain scalars;
        # with multiple optimizers they become per-key dicts (see TODO above)
        single_optimizer = isinstance(optimizer, Optimizer)
        self.lr = None if single_optimizer else defaultdict(lambda: None)
        self.momentum = None if single_optimizer else defaultdict(lambda: None)
        self.loss = None
        # extra checkpoint data for saving in checkpoint files
        self.checkpoint_data = checkpoint_data or {}
        # other
        self.need_backward = False
        self.early_stop = False
        for k, v in kwargs.items():
            setattr(self, k, v)
        self.exception: Optional[Exception] = None
        self.need_reraise_exception: bool = True
        self._freeze()
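    # --- usage sketch (illustrative only; ``some_model`` and
    # ``some_optimizer`` are hypothetical placeholders for any
    # ``torch.nn.Module`` / ``torch.optim.Optimizer`` instances) ---
    #
    #     state = RunnerState(
    #         device="cpu",
    #         model=some_model,
    #         optimizer=some_optimizer,
    #         logdir="./logs",
    #         stage="train",
    #         num_epochs=10,
    #     )
    #
    # After ``_freeze()``, assigning an attribute that was not created in
    # ``__init__`` (or via ``**kwargs``) is rejected by ``FrozenClass``.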
    def get_key(self, key, inner_key=None):
        """Return ``state.<key>``, or ``state.<key>[inner_key]`` when
        ``inner_key`` is given (e.g. one of several optimizers)."""
        if inner_key is None:
            return getattr(self, key)
        return getattr(self, key)[inner_key]
    def set_key(self, value, key, inner_key=None):
        """Set ``state.<key> = value``, or ``state.<key>[inner_key] = value``
        when ``inner_key`` is given."""
        if inner_key is None:
            setattr(self, key, value)
        else:
            getattr(self, key)[inner_key] = value
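    # Example (sketch): with several optimizers ``lr`` is a dict, so a
    # callback can address one of them via ``inner_key`` (``"generator"``
    # below is a hypothetical optimizer key):
    #
    #     state.set_key(1e-3, "lr", inner_key="generator")
    #     state.get_key("lr", inner_key="generator")  # -> 0.001
    #     state.get_key("num_epochs")                 # plain attribute access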
    def _handle_runner_metrics(self):
        """Collect per-batch service metrics (lr, momentum, timings, fps)
        and push them into the metric manager."""
        values = {}
        for key, value in zip(
            ["_base/lr", "_base/momentum"], [self.lr, self.momentum]
        ):
            if value is not None:
                if isinstance(value, dict):
                    for k, v in value.items():
                        values[f"{key}/{k}"] = v
                else:
                    values[key] = value
        values.update(self.timer.elapsed)
        values["_timers/_fps"] = \
            
self.batch_size / self.timer.elapsed["_timers/batch_time"]
        self.metrics.add_batch_value(metrics_dict=values)
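    # Worked example: with ``batch_size = 32`` and a measured
    # ``_timers/batch_time`` of 0.05 s, the logged throughput is
    # ``_timers/_fps`` = 32 / 0.05 = 640 samples per second.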
    def on_stage_start_pre(self):
        pass
    def on_stage_start_post(self):
        pass
    def on_stage_end_pre(self):
        pass
    def on_stage_end_post(self):
        pass
    def on_epoch_start_pre(self):
        self.metrics.begin_epoch()
    def on_epoch_start_post(self):
        pass
    def on_epoch_end_pre(self):
        if not self.stage.startswith("infer"):
            self.metrics.end_epoch_train()
    def on_epoch_end_post(self):
        pass
    def on_loader_start_pre(self):
        self.metrics.begin_loader(self.loader_name)
    def on_loader_start_post(self):
        pass
    def on_loader_end_pre(self):
        pass
    def on_loader_end_post(self):
        self.metrics.end_loader()
    def on_batch_start_pre(self):
        self.metrics.begin_batch()
    def on_batch_start_post(self):
        pass
    def on_batch_end_pre(self):
        pass
    def on_batch_end_post(self):
        self._handle_runner_metrics()
        self.metrics.end_batch()
    def on_exception_pre(self):
        pass
    def on_exception_post(self):
        pass
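    # The hooks above fire in the runner's nested event order:
    #
    #     on_stage_start_* -> on_epoch_start_* -> on_loader_start_*
    #         -> on_batch_start_* -> on_batch_end_*   (repeated per batch)
    #     -> on_loader_end_* -> on_epoch_end_* -> on_stage_end_*
    #
    # with ``*_pre`` presumably running before the callbacks of that event
    # and ``*_post`` after them.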
    @property
    def stage_epoch_log(self):
        """One-based epoch number within the current stage, for logging."""
        return self.stage_epoch + 1
    @property
    def epoch_log(self):
        """One-based global epoch number, for logging."""
        return self.epoch + 1
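    # Example: counters are zero-based internally, while the ``*_log``
    # properties are one-based for human-readable logs:
    #
    #     state.epoch = 4
    #     state.epoch_log        # -> 5
    #     state.stage_epoch = 0
    #     state.stage_epoch_log  # -> 1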
__all__ = ["RunnerState"]