Shortcuts

Source code for catalyst.core.state

from typing import Any, Dict, Optional, TYPE_CHECKING, Union
from collections import defaultdict, OrderedDict
from pathlib import Path
import warnings

from torch.utils.data import DataLoader

from catalyst.core import utils
from catalyst.utils.tools.frozen_class import FrozenClass
from catalyst.utils.tools.settings import (
    LOADER_VALID_PREFIX,
    STAGE_INFER_PREFIX,
    STATE_MAIN_METRIC,
)
from catalyst.utils.tools.typing import (
    Criterion,
    Device,
    Model,
    Optimizer,
    Scheduler,
)

if TYPE_CHECKING:
    from .callback import Callback  # noqa: F401

StateModel = Union[Model, Dict[str, Model]]
StateCriterion = Union[Criterion, Dict[str, Criterion]]
StateOptimizer = Union[Optimizer, Dict[str, Optimizer]]
StateScheduler = Union[Scheduler, Dict[str, Scheduler]]


[docs]class State(FrozenClass): """ Some intermediate storage between Experiment and Runner that saves the current state of the Experiments – model, criterion, optimizer, schedulers, metrics, loggers, loaders, etc .. note:: To learn more about Catalyst Core concepts, please check out - :py:mod:`catalyst.core.experiment._Experiment` - :py:mod:`catalyst.core.runner._Runner` - :py:mod:`catalyst.core.state.State` - :py:mod:`catalyst.core.callback.Callback` **state.loaders** - ordered dictionary with torch.DataLoaders; \ for example, :: state.loaders = { "train": MnistTrainLoader(), "valid": MnistValidLoader() } .. note:: - "*train*" prefix is used for training loaders - \ metrics computations, backward pass, optimization - "*valid*" prefix is used for validation loaders - \ metrics computations only - "*infer*" prefix is used for inference loaders - \ dataset prediction **state.model** - an instance of torch.nn.Module class, \ (should implement ``forward`` method); \ for example, :: state.model = torch.nn.Linear(10, 10) **state.criterion** - an instance of torch.nn.Module class\ or torch.nn.modules.loss._Loss (should implement ``forward`` method); \ for example, :: state.criterion = torch.nn.CrossEntropyLoss() **state.optimizer** - an instance of torch.optim.optimizer.Optimizer\ (should implement ``step`` method); \ for example, :: state.optimizer = torch.optim.Adam() **state.scheduler** - an instance of torch.optim.lr_scheduler._LRScheduler\ (should implement ``step`` method); \ for example, :: state.scheduler = htorch.optim.lr_scheduler.ReduceLROnPlateau() **state.device** - an instance of torch.device (CPU, GPU, TPU); \ for example, :: state.device = torch.device("cpu") **state.callbacks** - ordered dictionary with Catalyst.Callback instances;\ for example, :: state.callbacks = { "accuracy": AccuracyCallback(), "criterion": CriterionCallback(), "optim": OptimizerCallback(), "saver": CheckpointCallback() } **state.input** - dictionary, \ containing batch of data from currents DataLoader; \ for example, :: state.input = { "images": np.ndarray(batch_size, c, h, w), "targets": np.ndarray(batch_size, 1), } **state.output** - dictionary, \ containing model output for current batch; \ for example, :: state.output = {"logits": torch.Tensor(batch_size, num_classes)} **state.batch_metrics** - dictionary, flatten storage for batch metrics; \ for example, :: state.batch_metrics = {"loss": ..., "accuracy": ..., "iou": ...} **state.loader_metrics** - dictionary with aggregated batch statistics \ for loader (mean over all batches) and global loader metrics, like AUC; \ for example, :: state.loader_metrics = {"loss": ..., "accuracy": ..., "auc": ...} **state.epoch_metrics** - dictionary with summarized metrics \ for different loaders and global epoch metrics, like lr, momentum; \ for example, :: state.epoch_metrics = { "train_loss": ..., "train_auc": ..., "valid_loss": ..., "lr": ..., "momentum": ..., } **state.is_best_valid** - bool, indicator flag - ``True`` if this training epoch is best over all epochs - ``False`` if not **state.valid_metrics** - dictionary with validation metrics\ for currect epoch; \ for example, :: state.valid_metrics = {"loss": ..., "accuracy": ..., "auc": ...} .. note:: subdictionary of epoch_metrics **state.best_valid_metrics** - dictionary with best validation metrics \ during whole training process **state.distributed_rank** - distributed rank of current worker **state.is_distributed_master** - bool, indicator flag - ``True`` if is master node (state.distributed_rank == 0) - ``False`` if is worker node (state.distributed_rank != 0) **state.is_distributed_worker** - bool, indicator flag - ``True`` if is worker node (state.distributed_rank > 0) - ``False`` if is master node (state.distributed_rank <= 0) **state.stage_name** - string, current stage name,\ for example, :: state.stage_name = "pretraining" / "training" / "finetuning" / etc **state.epoch** - int, numerical indicator for current stage epoch **state.num_epochs** - int, maximum number of epochs, \ required for this stage **state.loader_name** - string, current loader name\ for example, :: state.loader_name = "train_dataset1" / "valid_data2" / "infer_golden" **state.loader_step** - int, numerical indicator \ for batch index in current loader **state.loader_len** - int, maximum number of batches in current loaders **state.batch_size** - int, typical Deep Learning batch size parameter **state.global_step** - int, numerical indicator, counter for all batches,\ that passes through our model during training, validation and\ inference stages **state.global_epoch** - int, numerical indicator, counter for all epochs,\ that have passed during model training, validation and\ inference stages **state.main_metric** - string, containing name of metric of interest \ for optimization, validation and checkpointing during training **state.minimize_metric** - bool, indicator flag - ``True`` if we need to minimize metric during training,\ like `Cross Entropy loss` - ``False`` if we need to maximize metric during training, \ like `Accuracy` or `Intersection over Union` **state.valid_loader** - string, name of validation loader \ for metric selection, validation and model checkpoining **state.logdir** - string, path to logging directory to save\ all logs, metrics, checkpoints and artifacts **state.checkpoint_data** - dictionary\ with all extra data for experiment tracking **state.is_check_run** - bool, indicator flag - ``True`` if you want to check you pipeline and \ run only 2 batches per loader and 2 epochs per stage - ``False`` (default) if you want to just the pipeline **state.is_train_loader** - bool, indicator flag - ``True`` for training loaders - ``False`` otherwise **state.is_valid_loader** - bool, indicator flag - ``True`` for validation loaders - ``False`` otherwise **state.is_infer_loader** - bool, indicator flag - ``True`` for inference loaders - ``False`` otherwise **state.is_infer_stage** - bool, indicator flag - ``True`` for inference stages - ``False`` otherwise **state.need_early_stop** - bool, indicator flag \ used for EarlyStopping and CheckRun Callbacks - ``True`` if we need to stop the training - ``False`` (default) otherwise **state.need_exception_reraise** - bool, indicator flag - ``True`` (default) if you want to show exception \ during pipeline and stop the training process - ``False`` otherwise **state.exception** - python Exception instance to raise (or not ;) ) """
[docs] def __init__( self, *, device: Device = None, model: StateModel = None, criterion: StateCriterion = None, optimizer: StateOptimizer = None, scheduler: StateScheduler = None, callbacks: Dict[str, "Callback"] = None, logdir: str = None, stage: str = STAGE_INFER_PREFIX, num_epochs: int = 1, main_metric: str = STATE_MAIN_METRIC, minimize_metric: bool = True, valid_loader: str = LOADER_VALID_PREFIX, checkpoint_data: Dict = None, is_check_run: bool = False, **kwargs, ): """ Args: @TODO: Docs. Contribution is welcome """ # main part # data self.loaders: OrderedDict[str, DataLoader] = None # components self.model: StateModel = model self.criterion: StateCriterion = criterion self.optimizer: StateOptimizer = optimizer self.scheduler: StateScheduler = scheduler # extra components - PyTorch device self.device: Device = device # extra components - Catalyst callbacks self.callbacks: Dict[str, "Callback"] = callbacks # dataflow - model input, model output self.input = None self.output = None # metrics flow - batch, loader, epoch metrics # let's use flatten storage for batch metrics # batch_metrics = {'loss': ..., 'accuracy': ..., 'iou': ...} self.batch_metrics = defaultdict(None) # just aggregated (aka mean over all batches) # batch statistics for loader # and global loader metrics, like AUC # loader_metrics = {'loss': ..., 'accuracy': ..., `auc`: ...} self.loader_metrics = defaultdict(None) # summarized metrics for different loaders # and global epoch metrics, like lr, momentum # epoch_metrics = { # 'train_loss': ..., 'train_auc': ..., 'valid_loss': ..., # 'lr': ..., 'momentum': ..., # } self.epoch_metrics = defaultdict(None) # validation self.is_best_valid = False self.valid_metrics = defaultdict(None) self.best_valid_metrics = defaultdict(None) # pipeline info self.distributed_rank = utils.get_rank() self.is_distributed_master = ~(self.distributed_rank > 0) self.is_distributed_worker = self.distributed_rank > 0 self.stage_name: str = stage self.epoch: int = 1 self.num_epochs: int = num_epochs self.loader_name: str = None self.loader_step: int = 0 self.loader_len: int = 0 self.batch_size: int = 0 self.global_step: int = 0 self.global_epoch: int = 1 # metrics & validation self.main_metric: str = main_metric self.minimize_metric: bool = minimize_metric self.valid_loader: str = valid_loader # logging self.logdir: Path = Path(logdir) if logdir is not None else None # extra checkpoint data for saving in checkpoint files self.checkpoint_data: Dict = checkpoint_data or {} # other self.is_check_run: bool = is_check_run self.is_train_loader: bool = False self.is_valid_loader: bool = False self.is_infer_loader: bool = False self.is_infer_stage: bool = self.stage_name.startswith( STAGE_INFER_PREFIX ) self.need_early_stop: bool = False self.need_exception_reraise: bool = True self.exception: Optional[Exception] = None # kwargs for k, v in kwargs.items(): setattr(self, k, v) self._freeze()
@property def batch_in(self): """Alias for `state.input`. .. warning:: Deprecated, saved for backward compatibility. Please use `state.batch_in` instead. """ warnings.warn( "`state.batch_in` was deprecated, " "please use `state.input` instead", DeprecationWarning, ) return self.input @property def batch_out(self): """Alias for `state.output`. .. warning:: Deprecated, saved for backward compatibility. Please use `state.batch_out` instead. """ warnings.warn( "`state.batch_out` was deprecated, " "please use `state.output` instead", DeprecationWarning, ) return self.output @property def need_backward_pass(self): """Alias for `state.is_train_loader`. .. warning:: Deprecated, saved for backward compatibility. Please use `state.is_train_loader` instead. """ warnings.warn( "`need_backward_pass` was deprecated, " "please use `is_train_loader` instead", DeprecationWarning, ) return self.is_train_loader
[docs] def get_attr(self, key: str, inner_key: str = None) -> Any: """ Alias for python `getattr` method. Useful for Callbacks preparation and cases with multi-criterion, multi-optimizer setup. For example, when you would like to train multi-task classification. Used to get a named attribute from a `State` by `key` keyword; for example\ :: # example 1 state.get_attr("criterion") # is equivalent to state.criterion # example 2 state.get_attr("optimizer") # is equivalent to state.optimizer # example 3 state.get_attr("scheduler") # is equivalent to state.scheduler With `inner_key` usage, it suppose to find a dictionary under `key`\ and would get `inner_key` from this dict; for example, :: # example 1 state.get_attr("criterion", "bce") # is equivalent to state.criterion["bce"] # example 2 state.get_attr("optimizer", "adam") # is equivalent to state.optimizer["adam"] # example 3 state.get_attr("scheduler", "adam") # is equivalent to state.scheduler["adam"] Args: key (str): name for attribute of interest, like `criterion`, `optimizer`, `scheduler` inner_key (str): name of inner dictionary key """ if inner_key is None: return getattr(self, key) else: return getattr(self, key)[inner_key]