# Source code for catalyst.dl.core.experiment
from typing import Any, Dict, Iterable, Mapping, Tuple # isort:skip
from abc import ABC, abstractmethod
from collections import OrderedDict
from torch import nn
from torch.utils.data import DataLoader, Dataset
from catalyst.dl.utils.torch import _Criterion, _Model, _Optimizer, _Scheduler
from .callback import Callback
class Experiment(ABC):
    """
    Object containing all information required to run the experiment.

    Abstract, look for implementations. Every member decorated with
    ``@abstractmethod`` below must be overridden before a subclass can be
    instantiated; ``get_datasets`` is an optional hook whose base version
    raises ``NotImplementedError``.
    """

    @property
    @abstractmethod
    def initial_seed(self) -> int:
        """Initial seed of the experiment."""
        pass

    @property
    @abstractmethod
    def logdir(self) -> str:
        """Log directory of the experiment, as a string."""
        pass

    @property
    @abstractmethod
    def stages(self) -> Iterable[str]:
        """
        Names of the experiment stages.

        Each name is presumably the value passed as ``stage`` to the
        per-stage ``get_*`` methods of this class — confirm against the
        runner that drives the experiment.
        """
        pass

    @property
    @abstractmethod
    def distributed_params(self) -> Dict:
        """Parameters for distributed execution (schema defined by subclasses)."""
        pass

    @property
    @abstractmethod
    def monitoring_params(self) -> Dict:
        """Parameters for monitoring (schema defined by subclasses)."""
        pass

    @abstractmethod
    def get_state_params(self, stage: str) -> Mapping[str, Any]:
        """
        Return the state parameters for ``stage``.

        Args:
            stage: name of the stage.

        Returns:
            Mapping from parameter name to value; which keys are expected
            is decided by the consumer of this mapping, not by this class.
        """
        pass

    @abstractmethod
    def get_model(self, stage: str) -> _Model:
        """
        Return the model to use for ``stage``.

        Args:
            stage: name of the stage.
        """
        pass

    @abstractmethod
    def get_criterion(self, stage: str) -> _Criterion:
        """
        Return the criterion to use for ``stage``.

        Args:
            stage: name of the stage.
        """
        pass

    @abstractmethod
    def get_optimizer(self, stage: str, model: nn.Module) -> _Optimizer:
        """
        Return the optimizer to use for ``stage``.

        Args:
            stage: name of the stage.
            model: the model the optimizer is being built for.
        """
        pass

    @abstractmethod
    def get_scheduler(self, stage: str, optimizer: _Optimizer) -> _Scheduler:
        """
        Return the scheduler to use for ``stage``.

        Args:
            stage: name of the stage.
            optimizer: the optimizer the scheduler is being built for.
        """
        pass

    def get_experiment_components(
        self, model: nn.Module, stage: str
    ) -> Tuple[_Criterion, _Optimizer, _Scheduler]:
        """
        Build the criterion, optimizer and scheduler for ``stage``.

        Delegates to ``get_criterion``, ``get_optimizer`` and
        ``get_scheduler``; the scheduler receives the optimizer created in
        this same call, so the returned components belong together.

        Args:
            model: model forwarded to ``get_optimizer``.
            stage: name of the stage.

        Returns:
            ``(criterion, optimizer, scheduler)`` tuple.
        """
        criterion = self.get_criterion(stage)
        optimizer = self.get_optimizer(stage, model)
        # The scheduler must be attached to the optimizer built just above.
        scheduler = self.get_scheduler(stage, optimizer)
        return criterion, optimizer, scheduler

    @abstractmethod
    def get_callbacks(self, stage: str) -> "OrderedDict[str, Callback]":
        """
        Return the callbacks for ``stage``, keyed by name, in order.

        Args:
            stage: name of the stage.
        """
        pass

    def get_datasets(
        self,
        stage: str,
        **kwargs,
    ) -> "OrderedDict[str, Dataset]":
        """
        Optional hook returning the datasets for ``stage``, keyed by name.

        Not abstract: subclasses that build loaders from datasets may
        override it; the base implementation always raises.

        Args:
            stage: name of the stage.
            **kwargs: implementation-specific options (unused here).

        Raises:
            NotImplementedError: unless overridden by a subclass.
        """
        raise NotImplementedError

    @abstractmethod
    def get_loaders(self, stage: str) -> "OrderedDict[str, DataLoader]":
        """
        Return the data loaders for ``stage``, keyed by name, in order.

        Args:
            stage: name of the stage.

        Raises:
            NotImplementedError: if a subclass calls this base version.
        """
        raise NotImplementedError
# Public API of this module: ``from <module> import *`` exports only the
# abstract ``Experiment`` base class.
__all__ = ["Experiment"]