from typing import Any, Dict, Iterable, List, Mapping, Tuple, Union
from collections import OrderedDict
import warnings
from torch import nn
from torch.utils.data import DataLoader, Dataset
from catalyst.core import _Experiment
from catalyst.dl import (
Callback,
CheckpointCallback,
CheckRunCallback,
ConsoleLogger,
ExceptionCallback,
MetricManagerCallback,
TensorboardLogger,
TimerCallback,
utils,
ValidationManagerCallback,
VerboseLogger,
)
from catalyst.tools import settings
from catalyst.tools.typing import Criterion, Model, Optimizer, Scheduler
class Experiment(_Experiment):
    """
    Super-simple one-staged experiment,
    you can use to declare experiment in code.

    Every component (model, criterion, optimizer, scheduler, loaders)
    is stored at construction time and handed back unchanged by the
    corresponding ``get_*`` method; the ``stage`` argument of those
    getters is accepted for interface compatibility but not used.
    """

    def __init__(
        self,
        model: Model,
        datasets: "OrderedDict[str, Union[Dataset, Dict, Any]]" = None,
        loaders: "OrderedDict[str, DataLoader]" = None,
        callbacks: "Union[OrderedDict[str, Callback], List[Callback]]" = None,
        logdir: str = None,
        stage: str = "train",
        criterion: Criterion = None,
        optimizer: Optimizer = None,
        scheduler: Scheduler = None,
        num_epochs: int = 1,
        valid_loader: str = "valid",
        main_metric: str = "loss",
        minimize_metric: bool = True,
        verbose: bool = False,
        check_time: bool = False,
        check_run: bool = False,
        state_kwargs: Dict = None,
        checkpoint_data: Dict = None,
        distributed_params: Dict = None,
        initial_seed: int = 42,
    ):
        """
        Args:
            model (Model): model
            datasets (OrderedDict[str, Union[Dataset, Dict, Any]]): dictionary
                with one or several  ``torch.utils.data.Dataset``
                for training, validation or inference
                used for Loaders automatic creation
                preferred way for distributed training setup
            loaders (OrderedDict[str, DataLoader]): dictionary
                with one or several ``torch.utils.data.DataLoader``
                for training, validation or inference
            callbacks (Union[List[Callback], OrderedDict[str, Callback]]):
                list or dictionary with Catalyst callbacks
            logdir (str): path to output directory
            stage (str): current stage
            criterion (Criterion): criterion function
            optimizer (Optimizer): optimizer
            scheduler (Scheduler): scheduler
            num_epochs (int): number of experiment's epochs
            valid_loader (str): loader name used to calculate
                the metrics and save the checkpoints. For example,
                you can pass `train` and then
                the metrics will be taken from `train` loader.
            main_metric (str): the key to the name of the metric
                by which the checkpoints will be selected.
            minimize_metric (bool): flag to indicate whether
                the ``main_metric`` should be minimized.
            verbose (bool): if True, it displays the status of the training
                to the console.
            check_time (bool): if True, computes the execution time
                of training process and displays it to the console.
            check_run (bool): if True, we run only 3 batches per loader
                and 3 epochs per stage to check pipeline correctness
            state_kwargs (dict): additional state params to ``State``
            checkpoint_data (dict): additional data to save in checkpoint,
                for example: ``class_names``, ``date_of_training``, etc
            distributed_params (dict): dictionary with the parameters
                for distributed and FP16 method
            initial_seed (int): experiment's initial seed value

        Raises:
            AssertionError: if neither ``datasets`` nor ``loaders`` is given,
                or if ``valid_loader`` is missing from the loaders
                of a non-inference stage.
        """
        assert (
            datasets is not None or loaders is not None
        ), "Please specify the data sources"

        self._model = model
        # ``valid_loader`` may be replaced here when only one loader exists.
        self._loaders, self._valid_loader = self.process_loaders(
            loaders=loaders,
            datasets=datasets,
            stage=stage,
            valid_loader=valid_loader,
            initial_seed=initial_seed,
        )
        self._callbacks = utils.sort_callbacks_by_order(callbacks)

        self._criterion = criterion
        self._optimizer = optimizer
        self._scheduler = scheduler

        self._initial_seed = initial_seed
        self._logdir = logdir
        self._stage = stage
        self._num_epochs = num_epochs
        self._main_metric = main_metric
        self._minimize_metric = minimize_metric
        self._verbose = verbose
        self._check_time = check_time
        self._check_run = check_run
        # ``None`` is normalised to a fresh dict so the attributes are
        # always safe to unpack / read.
        self._state_kwargs = state_kwargs or {}
        self._checkpoint_data = checkpoint_data or {}
        self._distributed_params = distributed_params or {}

    @property
    def initial_seed(self) -> int:
        """Experiment's initial seed value."""
        return self._initial_seed

    @property
    def logdir(self) -> str:
        """Path to the directory where the experiment logs."""
        return self._logdir

    @property
    def stages(self) -> Iterable[str]:
        """Experiment's stage names (array with one value)."""
        return [self._stage]

    @property
    def distributed_params(self) -> Dict:
        """Dict with the parameters for distributed and FP16 method."""
        return self._distributed_params

    @staticmethod
    def process_loaders(
        loaders: "OrderedDict[str, DataLoader]",
        datasets: Dict,
        stage: str,
        valid_loader: str,
        initial_seed: int,
    ) -> "Tuple[OrderedDict[str, DataLoader], str]":
        """Prepares loaders for a given stage.

        Args:
            loaders (OrderedDict[str, DataLoader]): ready-to-use loaders;
                ignored when ``datasets`` is given
            datasets (Dict): keyword parameters forwarded to
                ``utils.get_loaders_from_params``; when not ``None``,
                the loaders it returns replace ``loaders``
            stage (str): stage name; stages whose name starts with
                ``settings.stage_infer_prefix`` skip the validation-loader
                checks
            valid_loader (str): name of the loader used for validation
            initial_seed (int): seed forwarded to
                ``utils.get_loaders_from_params``

        Returns:
            Tuple[OrderedDict[str, DataLoader], str]: the loaders and
            the (possibly replaced) validation loader name

        Raises:
            AssertionError: if, for a non-inference stage,
                ``valid_loader`` is not a key of the loaders.
        """
        if datasets is not None:
            loaders = utils.get_loaders_from_params(
                initial_seed=initial_seed, **datasets,
            )
        if not stage.startswith(settings.stage_infer_prefix):  # train stage
            if len(loaders) == 1:
                # With a single loader there is nothing to choose from:
                # validate on it, whatever name was requested.
                valid_loader = list(loaders.keys())[0]
                warnings.warn(
                    "Attention, there is only one dataloader - "
                    + str(valid_loader)
                )
            assert valid_loader in loaders, (
                "The validation loader must be present "
                "in the loaders used during experiment."
            )
        return loaders, valid_loader

    def get_state_params(self, stage: str) -> Mapping[str, Any]:
        """Returns the state parameters for a given stage.

        The defaults collected from the constructor arguments are
        overridden by any key present in ``state_kwargs``.
        """
        default_params = {
            "logdir": self.logdir,
            "num_epochs": self._num_epochs,
            "valid_loader": self._valid_loader,
            "main_metric": self._main_metric,
            "verbose": self._verbose,
            "minimize_metric": self._minimize_metric,
            "checkpoint_data": self._checkpoint_data,
        }
        state_params = {**default_params, **self._state_kwargs}
        return state_params

    def get_model(self, stage: str) -> Model:
        """Returns the model for a given stage."""
        return self._model

    def get_criterion(self, stage: str) -> Criterion:
        """Returns the criterion for a given stage."""
        return self._criterion

    def get_optimizer(self, stage: str, model: nn.Module) -> Optimizer:
        """Returns the optimizer for a given stage."""
        return self._optimizer

    def get_scheduler(self, stage: str, optimizer=None) -> Scheduler:
        """Returns the scheduler for a given stage."""
        return self._scheduler

    def get_loaders(
        self, stage: str, epoch: int = None,
    ) -> "OrderedDict[str, DataLoader]":
        """Returns the loaders for a given stage."""
        return self._loaders

    def get_callbacks(self, stage: str) -> "OrderedDict[str, Callback]":
        """
        Returns the callbacks for a given stage.

        The user-supplied callbacks are complemented with default ones
        (loggers, metric/validation managers, checkpoint saver, exception
        handler) depending on the experiment flags, ``logdir`` and whether
        ``stage`` is an inference stage. A default is skipped when
        an instance of the same callback class is already present.

        Args:
            stage (str): stage name

        Returns:
            OrderedDict[str, Callback]: a new dictionary on every call;
            the callbacks stored on the experiment are not modified.
        """
        # Work on a copy: inserting the defaults straight into
        # ``self._callbacks`` would leak them into every later call.
        callbacks = OrderedDict(self._callbacks or ())
        # Same infer-stage detection as in ``process_loaders``.
        is_infer_stage = stage.startswith(settings.stage_infer_prefix)

        default_callbacks = []
        if self._verbose:
            default_callbacks.append(("_verbose", VerboseLogger))
        if self._check_time:
            default_callbacks.append(("_timer", TimerCallback))
        if self._check_run:
            default_callbacks.append(("_check", CheckRunCallback))

        if not is_infer_stage:
            default_callbacks.append(("_metrics", MetricManagerCallback))
            default_callbacks.append(
                ("_validation", ValidationManagerCallback)
            )
            default_callbacks.append(("_console", ConsoleLogger))

            if self.logdir is not None:
                default_callbacks.append(("_saver", CheckpointCallback))
                default_callbacks.append(("_tensorboard", TensorboardLogger))
        default_callbacks.append(("_exception", ExceptionCallback))

        for callback_name, callback_fn in default_callbacks:
            is_already_present = any(
                isinstance(x, callback_fn) for x in callbacks.values()
            )
            if not is_already_present:
                callbacks[callback_name] = callback_fn()
        return callbacks
# Public API of this module: only ``Experiment`` is exported by ``import *``.
__all__ = ["Experiment"]