from typing import Any, Dict, Iterable, List, Mapping, Optional, Union  # isort:skip
from collections import OrderedDict

from torch import nn
from torch.utils.data import DataLoader

from catalyst.dl.core import Callback, Experiment
from catalyst.dl.utils import process_callbacks
from catalyst.dl.utils.torch import _Criterion, _Model, _Optimizer, _Scheduler
class BaseExperiment(Experiment):
    """
    Super-simple one-staged experiment
    you can use to declare experiment in code.

    The experiment has exactly one stage (``stages`` returns ``[stage]``).
    Every ``get_*`` accessor therefore ignores its ``stage`` argument and
    returns the object that was handed to the constructor.
    """

    def __init__(
        self,
        model: _Model,
        loaders: "OrderedDict[str, DataLoader]",
        # quoted: ``collections.OrderedDict`` is not subscriptable at
        # runtime before Python 3.9
        callbacks: "Optional[Union[OrderedDict[str, Callback], List[Callback]]]" = None,
        logdir: Optional[str] = None,
        stage: str = "train",
        criterion: Optional[_Criterion] = None,
        optimizer: Optional[_Optimizer] = None,
        scheduler: Optional[_Scheduler] = None,
        num_epochs: int = 1,
        valid_loader: str = "valid",
        main_metric: str = "loss",
        minimize_metric: bool = True,
        verbose: bool = False,
        state_kwargs: Optional[Dict] = None,
        checkpoint_data: Optional[Dict] = None,
        distributed_params: Optional[Dict] = None,
        monitoring_params: Optional[Dict] = None,
        initial_seed: int = 42,
    ):
        """
        Args:
            model: model returned as-is by ``get_model``.
            loaders: loader name -> ``DataLoader`` mapping, returned as-is
                by ``get_loaders``.
            callbacks: callbacks given either as an ordered mapping or a
                plain list. They are passed through ``process_callbacks``
                and the result is what ``get_callbacks`` returns
                (presumably an ``OrderedDict``, per that method's
                annotation; verify against ``process_callbacks``).
            logdir: log directory forwarded into the state params;
                may be ``None``.
            stage: name of the single stage of this experiment.
            criterion: returned as-is by ``get_criterion``.
            optimizer: returned as-is by ``get_optimizer``.
            scheduler: returned as-is by ``get_scheduler``.
            num_epochs: forwarded into the state params.
            valid_loader: forwarded into the state params.
            main_metric: forwarded into the state params.
            minimize_metric: forwarded into the state params.
            verbose: forwarded into the state params.
            state_kwargs: extra state params; they override the defaults
                built by ``get_state_params`` on key collisions.
            checkpoint_data: forwarded into the state params.
            distributed_params: exposed via ``distributed_params``.
            monitoring_params: exposed via ``monitoring_params``.
            initial_seed: exposed via ``initial_seed``.
        """
        self._model = model
        self._loaders = loaders
        self._callbacks = process_callbacks(callbacks)

        self._criterion = criterion
        self._optimizer = optimizer
        self._scheduler = scheduler

        self._initial_seed = initial_seed
        self._logdir = logdir
        self._stage = stage
        self._num_epochs = num_epochs
        self._valid_loader = valid_loader
        self._main_metric = main_metric
        self._minimize_metric = minimize_metric
        self._verbose = verbose

        # ``None`` (or an empty dict) is replaced with a fresh empty dict
        # so downstream code can always treat these as mappings.
        self._additional_state_kwargs = state_kwargs or {}
        # Public attribute (unlike its siblings); kept public for
        # backward compatibility with existing callers.
        self.checkpoint_data = checkpoint_data or {}
        self._distributed_params = distributed_params or {}
        self._monitoring_params = monitoring_params or {}

    @property
    def initial_seed(self) -> int:
        """Seed given to the constructor (``42`` by default)."""
        return self._initial_seed

    @property
    def logdir(self) -> Optional[str]:
        """Log directory given to the constructor; may be ``None``."""
        return self._logdir

    @property
    def stages(self) -> Iterable[str]:
        """Single-element list holding the configured stage name."""
        return [self._stage]

    @property
    def distributed_params(self) -> Dict:
        """Distributed params; an empty dict when none were given."""
        return self._distributed_params

    @property
    def monitoring_params(self) -> Dict:
        """Monitoring params; an empty dict when none were given."""
        return self._monitoring_params

    def get_state_params(self, stage: str) -> Mapping[str, Any]:
        """
        Build the state params for ``stage`` (the argument is ignored).

        Returns:
            A new dict with ``logdir``, ``num_epochs``, ``valid_loader``,
            ``main_metric``, ``verbose``, ``minimize_metric`` and
            ``checkpoint_data``, updated with ``state_kwargs``. Entries
            from ``state_kwargs`` win on key collisions.
        """
        default_params = dict(
            logdir=self.logdir,
            num_epochs=self._num_epochs,
            valid_loader=self._valid_loader,
            main_metric=self._main_metric,
            verbose=self._verbose,
            minimize_metric=self._minimize_metric,
            checkpoint_data=self.checkpoint_data,
        )
        # Later unpacking overrides earlier keys, so user kwargs take
        # precedence over the defaults above.
        state_params = {**default_params, **self._additional_state_kwargs}
        return state_params

    def get_model(self, stage: str) -> _Model:
        """Return the constructor's ``model``; ``stage`` is ignored."""
        return self._model

    def get_criterion(self, stage: str) -> Optional[_Criterion]:
        """Return the constructor's ``criterion`` (may be ``None``)."""
        return self._criterion

    def get_optimizer(
        self, stage: str, model: nn.Module
    ) -> Optional[_Optimizer]:
        """
        Return the constructor's ``optimizer`` (may be ``None``).

        ``model`` is not used here. NOTE(review): the optimizer is
        presumably already built over this model's parameters by the
        caller; confirm.
        """
        return self._optimizer

    def get_scheduler(
        self, stage: str, optimizer: Optional[_Optimizer] = None
    ) -> Optional[_Scheduler]:
        """Return the constructor's ``scheduler``; ``optimizer`` is unused."""
        return self._scheduler

    def get_callbacks(self, stage: str) -> "OrderedDict[str, Callback]":
        """Return the callbacks produced by ``process_callbacks``."""
        return self._callbacks

    def get_loaders(self, stage: str) -> "OrderedDict[str, DataLoader]":
        """Return the constructor's ``loaders`` mapping unchanged."""
        return self._loaders
# Names exported by ``from <module> import *``.
__all__ = [
    "BaseExperiment",
]