Source code for catalyst.dl.experiment.base

from typing import Any, Dict, Iterable, List, Mapping, Union  # isort:skip
from collections import OrderedDict

from torch import nn
from torch.utils.data import DataLoader

from catalyst.dl.core import Callback, Experiment
from catalyst.dl.utils import process_callbacks
from catalyst.utils.typing import Criterion, Model, Optimizer, Scheduler


class BaseExperiment(Experiment):
    """
    Minimal single-stage experiment, meant for declaring an experiment
    directly in code.

    Each ``get_*`` accessor ignores its ``stage`` argument and returns the
    object given to the constructor. Every stage therefore shares the same
    model, criterion, optimizer, scheduler, callbacks and loaders. Only one
    stage exists anyway; see ``stages``.
    """

    def __init__(
        self,
        model: Model,
        loaders: "OrderedDict[str, DataLoader]",
        callbacks: "Union[OrderedDict[str, Callback], List[Callback]]" = None,
        logdir: str = None,
        stage: str = "train",
        criterion: Criterion = None,
        optimizer: Optimizer = None,
        scheduler: Scheduler = None,
        num_epochs: int = 1,
        valid_loader: str = "valid",
        main_metric: str = "loss",
        minimize_metric: bool = True,
        verbose: bool = False,
        state_kwargs: Dict = None,
        checkpoint_data: Dict = None,
        distributed_params: Dict = None,
        monitoring_params: Dict = None,
        initial_seed: int = 42,
    ):
        """
        Args:
            model (Model): the model to train
            loaders (OrderedDict[str, DataLoader]): named
                ``torch.utils.data.DataLoader`` objects used for training
                and validation
            callbacks (OrderedDict[str, Callback] or List[Callback]):
                experiment callbacks; stored after being passed through
                ``process_callbacks``
            logdir (str): path to the output directory
            stage (str): name of the single stage
            criterion (Criterion): criterion function
            optimizer (Optimizer): optimizer
            scheduler (Scheduler): scheduler
            num_epochs (int): number of epochs to run
            valid_loader (str): name of the loader whose metrics are used
                to select and save checkpoints. For example, passing
                ``train`` takes the metrics from the ``train`` loader.
            main_metric (str): name of the metric used to select checkpoints
            minimize_metric (bool): whether ``main_metric`` should be
                minimized (``True``) or maximized (``False``)
            verbose (bool): if ``True``, training status is printed to
                the console
            state_kwargs (dict): extra state params for ``RunnerState``.
                In ``get_state_params`` they are merged over the defaults
                and win on key clashes.
            checkpoint_data (dict): extra data to save in the checkpoint,
                e.g. ``class_names``, ``date_of_training``
            distributed_params (dict): parameters for distributed and FP16
                methods
            monitoring_params (dict): parameters for monitoring services
            initial_seed (int): the experiment's initial seed value
        """
        # Training components, returned unchanged for every stage.
        self._model = model
        self._criterion = criterion
        self._optimizer = optimizer
        self._scheduler = scheduler
        self._loaders = loaders
        self._callbacks = process_callbacks(callbacks)

        # General run configuration.
        self._stage = stage
        self._logdir = logdir
        self._initial_seed = initial_seed
        self._num_epochs = num_epochs
        self._verbose = verbose

        # Settings that decide which checkpoint counts as "best".
        self._valid_loader = valid_loader
        self._main_metric = main_metric
        self._minimize_metric = minimize_metric

        # Optional dicts: any falsy argument (None or empty) becomes a new {}.
        self._additional_state_kwargs = state_kwargs if state_kwargs else {}
        self.checkpoint_data = checkpoint_data if checkpoint_data else {}
        self._distributed_params = (
            distributed_params if distributed_params else {}
        )
        self._monitoring_params = monitoring_params if monitoring_params else {}

    @property
    def initial_seed(self) -> int:
        """Seed value the experiment was constructed with."""
        return self._initial_seed

    @property
    def logdir(self):
        """Output directory for experiment logs, as given to the constructor."""
        return self._logdir

    @property
    def stages(self) -> Iterable[str]:
        """Stage names: a one-element list holding the configured stage."""
        return [self._stage]

    @property
    def distributed_params(self) -> Dict:
        """Parameters for distributed and FP16 methods."""
        return self._distributed_params

    @property
    def monitoring_params(self) -> Dict:
        """Parameters for monitoring services."""
        return self._monitoring_params

    def get_state_params(self, stage: str) -> Mapping[str, Any]:
        """
        Build the state parameters. ``stage`` is not used.

        The defaults come from the constructor arguments. Entries from
        ``state_kwargs`` are layered on top and override them on key clashes.
        """
        defaults = {
            "logdir": self.logdir,
            "num_epochs": self._num_epochs,
            "valid_loader": self._valid_loader,
            "main_metric": self._main_metric,
            "verbose": self._verbose,
            "minimize_metric": self._minimize_metric,
            "checkpoint_data": self.checkpoint_data,
        }
        return {**defaults, **self._additional_state_kwargs}

    def get_model(self, stage: str) -> Model:
        """Return the constructor's model; ``stage`` is ignored."""
        return self._model

    def get_criterion(self, stage: str) -> Criterion:
        """Return the constructor's criterion; ``stage`` is ignored."""
        return self._criterion

    def get_optimizer(self, stage: str, model: nn.Module) -> Optimizer:
        """
        Return the constructor's optimizer. Both ``stage`` and ``model``
        are ignored.
        """
        return self._optimizer

    def get_scheduler(self, stage: str, optimizer=None) -> Scheduler:
        """
        Return the constructor's scheduler. Both ``stage`` and ``optimizer``
        are ignored.
        """
        return self._scheduler

    def get_callbacks(self, stage: str) -> "OrderedDict[str, Callback]":
        """Return the processed callbacks; ``stage`` is ignored."""
        return self._callbacks

    def get_loaders(self, stage: str) -> "OrderedDict[str, DataLoader]":
        """Return the constructor's loaders; ``stage`` is ignored."""
        return self._loaders
# Public API of this module: only the experiment class is exported.
__all__ = [
    "BaseExperiment",
]