from typing import Any, Dict, Iterable, Mapping, Optional, Tuple  # isort:skip
from abc import ABC, abstractmethod
from collections import OrderedDict

from torch import nn
from torch.utils.data import DataLoader, Dataset

from catalyst.dl.utils.torch import _Criterion, _Model, _Optimizer, _Scheduler
from .callback import Callback

class Experiment(ABC):
    """
    Object containing all information required to run the experiment.

    Abstract base class: concrete experiments must implement every
    ``@abstractmethod`` member below. Look for implementations elsewhere
    in the package.
    """

    @property
    @abstractmethod
    def initial_seed(self) -> int:
        """Integer seed for the experiment (presumably used to seed RNGs)."""
        pass

    @property
    @abstractmethod
    def logdir(self) -> str:
        """Path of the directory the experiment logs to (as a string)."""
        pass

    @property
    @abstractmethod
    def stages(self) -> Iterable[str]:
        """Names of the experiment stages.

        Each name is the ``stage`` key accepted by the ``get_*`` methods.
        """
        pass

    @property
    @abstractmethod
    def distributed_params(self) -> Dict:
        """Parameters for distributed runs (schema defined by implementations)."""
        pass

    @property
    @abstractmethod
    def monitoring_params(self) -> Dict:
        """Parameters for monitoring (schema defined by implementations)."""
        pass

    @abstractmethod
    def get_state_params(self, stage: str) -> Mapping[str, Any]:
        """Return the state parameters for ``stage``."""
        pass

    @abstractmethod
    def get_model(self, stage: str) -> _Model:
        """Return the model to use for ``stage``."""
        pass

    @abstractmethod
    def get_criterion(self, stage: str) -> _Criterion:
        """Return the criterion (loss) to use for ``stage``."""
        pass

    @abstractmethod
    def get_optimizer(self, stage: str, model: nn.Module) -> _Optimizer:
        """Return the optimizer for ``model`` in ``stage``."""
        pass

    @abstractmethod
    def get_scheduler(self, stage: str, optimizer: _Optimizer) -> _Scheduler:
        """Return the scheduler wrapping ``optimizer`` in ``stage``."""
        pass

    def get_experiment_components(
        self, model: nn.Module, stage: str
    ) -> Tuple[_Criterion, _Optimizer, _Scheduler]:
        """
        Build the criterion, optimizer and scheduler for ``stage``.

        The optimizer is created for ``model``, and the scheduler is then
        created from that optimizer. This ordering is required because each
        factory consumes the previous result.

        Args:
            model: model whose parameters the optimizer will manage.
            stage: name of the stage to build components for.

        Returns:
            ``(criterion, optimizer, scheduler)`` tuple.
        """
        criterion = self.get_criterion(stage)
        optimizer = self.get_optimizer(stage, model)
        scheduler = self.get_scheduler(stage, optimizer)
        return criterion, optimizer, scheduler

    @abstractmethod
    def get_callbacks(self, stage: str) -> "OrderedDict[str, Callback]":
        """Return the callbacks for ``stage``, keyed by name."""
        pass

    def get_datasets(
        self,
        stage: str,
        **kwargs,
    ) -> "OrderedDict[str, Dataset]":
        """
        Return the datasets for ``stage``, keyed by name.

        Optional hook, not abstract. The base implementation always raises.

        Raises:
            NotImplementedError: unless overridden by a subclass.
        """
        raise NotImplementedError

    @abstractmethod
    def get_loaders(self, stage: str) -> "OrderedDict[str, DataLoader]":
        """
        Return the data loaders for ``stage``, keyed by name.

        Raises:
            NotImplementedError: if a subclass delegates here via ``super()``.
        """
        raise NotImplementedError

    @staticmethod
    def get_transforms(
        stage: Optional[str] = None, mode: Optional[str] = None
    ):
        """
        Return data transforms for the given ``stage`` and ``mode``.

        Optional hook, not abstract. The base implementation always raises.

        Raises:
            NotImplementedError: unless overridden by a subclass.
        """
        raise NotImplementedError
# Public API of this module: only the abstract Experiment base class.
__all__ = ["Experiment"]