from typing import Any, Dict, Iterable, List, Mapping, Union  # isort:skip
from collections import OrderedDict
from torch import nn
from torch.utils.data import DataLoader
from catalyst.dl.core import Callback, Experiment
from catalyst.dl.utils import process_callbacks
from catalyst.dl.utils.torch import _Criterion, _Model, _Optimizer, _Scheduler
class BaseExperiment(Experiment):
    """
    Super-simple one-staged experiment
        you can use to declare experiment in code

    Every component (model, loaders, callbacks, criterion, optimizer,
    scheduler) is supplied once at construction time and handed back
    unchanged by the matching ``get_*`` method, whatever ``stage`` is
    passed to it.
    """

    def __init__(
        self,
        model: _Model,
        loaders: "OrderedDict[str, DataLoader]",
        callbacks: "Union[OrderedDict[str, Callback], List[Callback]]" = None,
        logdir: str = None,
        stage: str = "train",
        criterion: _Criterion = None,
        optimizer: _Optimizer = None,
        scheduler: _Scheduler = None,
        num_epochs: int = 1,
        valid_loader: str = "valid",
        main_metric: str = "loss",
        minimize_metric: bool = True,
        verbose: bool = False,
        state_kwargs: Dict = None,
        checkpoint_data: Dict = None,
        distributed_params: Dict = None,
        monitoring_params: Dict = None,
        initial_seed: int = 42,
    ):
        """
        Args:
            model: object returned by :meth:`get_model`
            loaders: object returned by :meth:`get_loaders`
            callbacks: callbacks as a mapping or a list; the value is
                passed through ``process_callbacks`` and the result is
                what :meth:`get_callbacks` returns
            logdir: value exposed by :attr:`logdir` and forwarded in the
                state params
            stage: name of the single stage listed by :attr:`stages`
            criterion: object returned by :meth:`get_criterion`
            optimizer: object returned by :meth:`get_optimizer`
            scheduler: object returned by :meth:`get_scheduler`
            num_epochs: forwarded in the state params as ``num_epochs``
            valid_loader: forwarded in the state params
                as ``valid_loader``
            main_metric: forwarded in the state params as ``main_metric``
            minimize_metric: forwarded in the state params
                as ``minimize_metric``
            verbose: forwarded in the state params as ``verbose``
            state_kwargs: extra state params merged over the defaults
                in :meth:`get_state_params`; an empty dict when falsy
            checkpoint_data: forwarded in the state params
                as ``checkpoint_data``; an empty dict when falsy
            distributed_params: value exposed by
                :attr:`distributed_params`; an empty dict when falsy
            monitoring_params: value exposed by
                :attr:`monitoring_params`; an empty dict when falsy
            initial_seed: value exposed by :attr:`initial_seed`
        """
        self._model = model
        self._loaders = loaders
        self._callbacks = process_callbacks(callbacks)

        self._criterion = criterion
        self._optimizer = optimizer
        self._scheduler = scheduler

        self._initial_seed = initial_seed
        self._logdir = logdir
        self._stage = stage
        self._num_epochs = num_epochs
        self._valid_loader = valid_loader
        self._main_metric = main_metric
        self._minimize_metric = minimize_metric
        self._verbose = verbose

        # ``x or {}`` swaps None (or any other falsy value) for a fresh
        # dict, so no mutable default is shared between instances.
        self._additional_state_kwargs = state_kwargs or {}
        # Kept public (no leading underscore) exactly as before: external
        # code may read or update it directly.
        self.checkpoint_data = checkpoint_data or {}
        self._distributed_params = distributed_params or {}
        self._monitoring_params = monitoring_params or {}

    @property
    def initial_seed(self) -> int:
        """Seed value given to the constructor."""
        return self._initial_seed

    @property
    def logdir(self):
        """Log directory given to the constructor (may be ``None``)."""
        return self._logdir

    @property
    def stages(self) -> Iterable[str]:
        """One-element list holding the configured stage name."""
        return [self._stage]

    @property
    def distributed_params(self) -> Dict:
        """Distributed params given to the constructor (``{}`` if unset)."""
        return self._distributed_params

    @property
    def monitoring_params(self) -> Dict:
        """Monitoring params given to the constructor (``{}`` if unset)."""
        return self._monitoring_params

    def get_state_params(self, stage: str) -> Mapping[str, Any]:
        """
        Build the state params as a new dict.

        ``stage`` is not used: the same values are produced for any
        stage. Entries of ``state_kwargs`` (from the constructor) are
        applied last, so they override the defaults on a key collision.
        """
        default_params = dict(
            logdir=self.logdir,
            num_epochs=self._num_epochs,
            valid_loader=self._valid_loader,
            main_metric=self._main_metric,
            verbose=self._verbose,
            minimize_metric=self._minimize_metric,
            checkpoint_data=self.checkpoint_data,
        )
        state_params = {**default_params, **self._additional_state_kwargs}
        return state_params

    def get_model(self, stage: str) -> _Model:
        """Return the model given to the constructor; ``stage`` is unused."""
        return self._model

    def get_criterion(self, stage: str) -> _Criterion:
        """Return the criterion given to the constructor; ``stage`` is unused."""
        return self._criterion

    def get_optimizer(self, stage: str, model: nn.Module) -> _Optimizer:
        """
        Return the optimizer given to the constructor.

        Both ``stage`` and ``model`` are unused.
        """
        return self._optimizer

    def get_scheduler(self, stage: str, optimizer=None) -> _Scheduler:
        """
        Return the scheduler given to the constructor.

        Both ``stage`` and ``optimizer`` are unused.
        """
        return self._scheduler

    def get_callbacks(self, stage: str) -> "OrderedDict[str, Callback]":
        """Return the processed callbacks; ``stage`` is unused."""
        return self._callbacks

    def get_loaders(self, stage: str) -> "OrderedDict[str, DataLoader]":
        """Return the loaders given to the constructor; ``stage`` is unused."""
        return self._loaders
# Names exported by ``from <this module> import *``.
__all__ = ["BaseExperiment"]