
# Source code for catalyst.core.state

# flake8: noqa
from typing import Dict, Optional, Union, TYPE_CHECKING  # isort:skip
from collections import defaultdict, OrderedDict
from pathlib import Path

import numpy as np

from import DataLoader

from catalyst import utils
from import FrozenClass
from import (
    Criterion, Device, Model, Optimizer, Scheduler

    from .callback import Callback  # noqa: F401

# Type aliases for the experiment components held by ``State``.
# Each component may be given either as a single object or as a dictionary
# mapping string keys to objects of that kind.
StateModel = Union[Model, Dict[str, Model]]
StateCriterion = Union[Criterion, Dict[str, Criterion]]
StateOptimizer = Union[Optimizer, Dict[str, Optimizer]]
StateScheduler = Union[Scheduler, Dict[str, Scheduler]]

class State(FrozenClass):
    r"""
    Object containing all information about current state of the experiment.

    state.loaders - ordered dictionary with torch.DataLoaders

        - "train" prefix is used for training loaders \
          (metrics computations, backward pass, optimization)
        - "valid" prefix is used for validation loaders - metrics only
        - "infer" prefix is used for inference loaders - dataset prediction

        ::

            state.loaders = {
                "train": MnistTrainLoader(),
                "valid": MnistValidLoader()
            }

    state.model - an instance of torch.nn.Module class
    should implement ``forward`` method

        ::

            state.model = torch.nn.Linear(10, 10)

    state.criterion - an instance of torch.nn.Module class
    or torch.nn.modules.loss._Loss
    should implement ``forward`` method

        ::

            state.criterion = torch.nn.CrossEntropyLoss()

    state.optimizer - an instance of torch.optim.optimizer.Optimizer
    should implement ``step`` method

        ::

            state.optimizer = torch.optim.Adam()

    state.scheduler - an instance of torch.optim.lr_scheduler._LRScheduler
    should implement ``step`` method

        ::

            state.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau()

    state.device - an instance of torch.device (CPU, GPU, TPU)

        ::

            state.device = torch.device("cpu")

    state.callbacks - ordered dictionary with Catalyst.Callback instances

        ::

            state.callbacks = {
                "accuracy": AccuracyCallback(),
                "criterion": CriterionCallback(),
                "optim": OptimizerCallback(),
                "saver": CheckpointCallback()
            }

    state.batch_in - dictionary, containing current batch of data
    from DataLoader

        ::

            state.batch_in = {
                "images": np.ndarray(batch_size, c, h, w),
                "targets": np.ndarray(batch_size, 1),
            }

    state.batch_out - dictionary, containing model output
    based on current batch

        ::

            state.batch_out = {"logits": torch.Tensor(batch_size, num_classes)}

    state.batch_metrics - dictionary, flatten storage for batch metrics

        ::

            state.batch_metrics = {"loss": ..., "accuracy": ..., "iou": ...}

    state.loader_metrics - dictionary with aggregated batch statistics
    for loader (mean over all batches) and global loader metrics, like AUC

        ::

            state.loader_metrics = {"loss": ..., "accuracy": ..., "auc": ...}

    state.epoch_metrics - dictionary with summarized metrics
    for different loaders and global epoch metrics, like lr, momentum

        ::

            state.epoch_metrics = {
                "train_loss": ..., "train_auc": ..., "valid_loss": ...,
                "lr": ..., "momentum": ...,
            }

    state.is_best_valid - bool, indicator flag

        - ``True`` if this training epoch is best over all epochs
        - ``False`` if not

    state.valid_metrics - dictionary with validation metrics
    for current epoch
    just a subdictionary of epoch_metrics

        ::

            state.valid_metrics = {"loss": ..., "accuracy": ..., "auc": ...}

    state.best_valid_metrics - dictionary with best validation metrics
    during whole training process

    state.distributed_rank

    state.is_distributed_worker

    state.stage_name

    state.epoch

    state.num_epochs

    state.loader_name

    state.loader_step

    state.loader_len

    state.batch_size

    state.global_step

    state.global_epoch

    state.main_metric

    state.minimize_metric

    state.valid_loader

    state.logdir - path to logging directory to save
    all logs, metrics, checkpoints and artifacts

    state.checkpoint_data - dictionary
    with all extra data for experiment tracking

    state.is_check_run - bool, indicator flag

        - ``True`` if you want to check your pipeline and
          run only 2 batches per loader and 2 epochs per stage
        - ``False`` (default) if you want to just run the pipeline

    state.need_backward_pass - bool, indicator flag

        - ``True`` for training loaders
        - ``False`` otherwise

    state.need_early_stop - bool, indicator flag
    used for EarlyStopping and CheckRun Callbacks

        - ``True`` if we need to stop the training
        - ``False`` (default) otherwise

    state.need_exception_reraise - bool, indicator flag

        - ``True`` (default) if you want to show exception
          during pipeline and stop the training process
        - ``False`` otherwise

    state.exception - python Exception instance to raise (or not ;) )
    """

    def __init__(
        self,
        *,
        device: Optional[Device] = None,
        model: Optional[StateModel] = None,
        criterion: Optional[StateCriterion] = None,
        optimizer: Optional[StateOptimizer] = None,
        scheduler: Optional[StateScheduler] = None,
        callbacks: Optional[Dict[str, "Callback"]] = None,
        logdir: Optional[str] = None,
        stage: str = "infer",
        num_epochs: Optional[int] = None,
        main_metric: str = "loss",
        minimize_metric: bool = True,
        valid_loader: str = "valid",
        checkpoint_data: Optional[Dict] = None,
        is_check_run: bool = False,
        **kwargs,
    ):
        """
        Initialise the state; every argument is keyword-only.

        Args:
            device: stored as ``state.device``
            model: stored as ``state.model``
            criterion: stored as ``state.criterion``
            optimizer: stored as ``state.optimizer``
            scheduler: stored as ``state.scheduler``
            callbacks: stored as ``state.callbacks``
            logdir: logging directory; converted to ``pathlib.Path``
                unless it is ``None``
            stage: stored as ``state.stage_name``
            num_epochs: stored as ``state.num_epochs``; any falsy value
                (``None`` or ``0``) is replaced by the largest int32
            main_metric: stored as ``state.main_metric``
            minimize_metric: stored as ``state.minimize_metric``
            valid_loader: stored as ``state.valid_loader``
            checkpoint_data: stored as ``state.checkpoint_data``;
                a falsy value is replaced by a new empty dict
            is_check_run: stored as ``state.is_check_run``
            **kwargs: every extra keyword is set as an attribute
                of the same name
        """
        # main part
        # data
        self.loaders: Optional[OrderedDict[str, DataLoader]] = None
        # components
        self.model: StateModel = model
        self.criterion: StateCriterion = criterion
        self.optimizer: StateOptimizer = optimizer
        self.scheduler: StateScheduler = scheduler
        # extra components - PyTorch device
        self.device: Device = device
        # extra components - Catalyst callbacks
        self.callbacks: Dict[str, "Callback"] = callbacks

        # dataflow - model input, model output, metrics
        self.batch_in = None
        self.batch_out = None
        # NOTE: ``defaultdict(None)`` has no default factory, so looking up
        # a missing key in any of the metric storages raises ``KeyError``.
        # let's use flatten storage for batch metrics
        # batch_metrics = {'loss': ..., 'accuracy': ..., 'iou': ...}
        self.batch_metrics = defaultdict(None)
        # just aggregated (aka mean over all batches)
        # batch statistics for loader
        # and global loader metrics, like AUC
        # loader_metrics = {'loss': ..., 'accuracy': ..., 'auc': ...}
        self.loader_metrics = defaultdict(None)
        # summarized metrics for different loaders
        # and global epoch metrics, like lr, momentum
        # epoch_metrics = {
        #     'train_loss': ..., 'train_auc': ..., 'valid_loss': ...,
        #     'lr': ..., 'momentum': ...,
        # }
        self.epoch_metrics = defaultdict(None)

        # validation
        self.is_best_valid = False
        self.valid_metrics = defaultdict(None)
        self.best_valid_metrics = defaultdict(None)

        # pipeline info
        self.distributed_rank = utils.get_rank()
        # any rank above zero is flagged as a distributed worker
        self.is_distributed_worker = self.distributed_rank > 0

        self.stage_name: str = stage
        self.epoch: int = 1
        # falsy ``num_epochs`` falls back to the int32 maximum
        self.num_epochs: int = num_epochs or np.iinfo(np.int32).max

        self.loader_name: Optional[str] = None
        self.loader_step: int = 0
        self.loader_len: int = 0

        self.batch_size: int = 0

        self.global_step: int = 0
        self.global_epoch: int = 1

        # metrics & validation
        self.main_metric: str = main_metric
        self.minimize_metric: bool = minimize_metric
        self.valid_loader: str = valid_loader

        # logging
        self.logdir: Optional[Path] = (
            Path(logdir) if logdir is not None else None
        )
        # extra checkpoint data for saving in checkpoint files
        self.checkpoint_data: Dict = checkpoint_data or {}

        # other
        self.is_check_run: bool = is_check_run
        self.need_backward_pass: bool = False
        self.need_early_stop: bool = False
        self.need_exception_reraise: bool = True
        self.exception: Optional[Exception] = None

        # kwargs
        for k, v in kwargs.items():
            setattr(self, k, v)

        # ``_freeze`` is inherited from FrozenClass (not visible here);
        # presumably it forbids creating new attributes from now on, which
        # is why it has to be the last call - TODO confirm.
        self._freeze()

    @property
    def input(self):
        """Alias of ``batch_in``, kept for backward compatibility."""
        return self.batch_in

    @property
    def output(self):
        """Alias of ``batch_out``, kept for backward compatibility."""
        return self.batch_out

    def get_attr(self, key, inner_key=None):
        """
        Read an attribute of the state, optionally indexing into it.

        Args:
            key: name of the attribute to read
            inner_key: if not ``None``, the attribute is indexed
                with this key and the item is returned

        Returns:
            ``getattr(self, key)`` when ``inner_key`` is ``None``,
            otherwise ``getattr(self, key)[inner_key]``
        """
        if inner_key is None:
            return getattr(self, key)
        else:
            return getattr(self, key)[inner_key]

    def set_attr(self, value, key, inner_key=None):
        """
        Write an attribute of the state, or one item inside it.

        Note that ``value`` comes first in the argument order.

        Args:
            value: the object to store
            key: name of the attribute to write
            inner_key: if not ``None``, ``value`` is stored under this key
                inside the existing attribute instead of replacing it
        """
        if inner_key is None:
            setattr(self, key, value)
        else:
            getattr(self, key)[inner_key] = value