from typing import Any, Dict, Iterable, List, Mapping, Tuple, Union
from collections import OrderedDict
import warnings

from torch import nn
from torch.utils.data import DataLoader, Dataset

from catalyst.core import _Experiment
from catalyst.dl import (
    Callback,
    CheckpointCallback,
    CheckRunCallback,
    ConsoleLogger,
    ExceptionCallback,
    MetricManagerCallback,
    TensorboardLogger,
    TimerCallback,
    utils,
    ValidationManagerCallback,
    VerboseLogger,
)
from catalyst.tools import settings
from catalyst.tools.typing import Criterion, Model, Optimizer, Scheduler


class Experiment(_Experiment):
    """
    Super-simple one-staged experiment,
    you can use to declare experiment in code.
    """

    def __init__(
        self,
        model: Model,
        datasets: "OrderedDict[str, Union[Dataset, Dict, Any]]" = None,
        loaders: "OrderedDict[str, DataLoader]" = None,
        callbacks: "Union[OrderedDict[str, Callback], List[Callback]]" = None,
        logdir: str = None,
        stage: str = "train",
        criterion: Criterion = None,
        optimizer: Optimizer = None,
        scheduler: Scheduler = None,
        num_epochs: int = 1,
        valid_loader: str = "valid",
        main_metric: str = "loss",
        minimize_metric: bool = True,
        verbose: bool = False,
        check_time: bool = False,
        check_run: bool = False,
        state_kwargs: Dict = None,
        checkpoint_data: Dict = None,
        distributed_params: Dict = None,
        initial_seed: int = 42,
    ):
        """
        Args:
            model (Model): model
            datasets (OrderedDict[str, Union[Dataset, Dict, Any]]): dictionary
                with one or several ``torch.utils.data.Dataset`` instances
                for training, validation, or inference,
                used for automatic loader creation;
                this is the preferred way to set up distributed training
            loaders (OrderedDict[str, DataLoader]): dictionary
                with one or several ``torch.utils.data.DataLoader`` instances
                for training, validation, or inference
            callbacks (Union[List[Callback], OrderedDict[str, Callback]]):
                list or dictionary with Catalyst callbacks
            logdir (str): path to the output directory
            stage (str): current stage
            criterion (Criterion): criterion function
            optimizer (Optimizer): optimizer
            scheduler (Scheduler): scheduler
            num_epochs (int): number of experiment epochs
            valid_loader (str): loader name used to calculate
                the metrics and save the checkpoints. For example,
                you can pass ``train``, and then
                the metrics will be taken from the ``train`` loader.
            main_metric (str): the key to the name of the metric
                by which the checkpoints will be selected
            minimize_metric (bool): flag to indicate whether
                the ``main_metric`` should be minimized
            verbose (bool): if True, displays the status of the training
                in the console
            check_time (bool): if True, measures the execution time
                of the training process and displays it in the console
            check_run (bool): if True, runs only 3 batches per loader
                and 3 epochs per stage to check pipeline correctness
            state_kwargs (dict): additional state params for the ``State``
            checkpoint_data (dict): additional data to save in the checkpoint,
                for example: ``class_names``, ``date_of_training``, etc.
            distributed_params (dict): dictionary with the parameters
                for distributed and FP16 methods
            initial_seed (int): experiment's initial seed value
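
        Example:
            A minimal sketch (the toy model, data, and hyperparameters
            below are illustrative placeholders, not required by the API):

            .. code-block:: python

                import torch
                from torch import nn
                from torch.utils.data import DataLoader, TensorDataset

                from catalyst.dl import Experiment

                # toy regression data; any Dataset/DataLoader pair works
                dataset = TensorDataset(
                    torch.randn(64, 10), torch.randn(64, 1)
                )
                loaders = {
                    "train": DataLoader(dataset, batch_size=8),
                    "valid": DataLoader(dataset, batch_size=8),
                }

                model = nn.Linear(10, 1)
                experiment = Experiment(
                    model=model,
                    loaders=loaders,
                    criterion=nn.MSELoss(),
                    optimizer=torch.optim.Adam(model.parameters()),
                    logdir="./logs",
                    num_epochs=2,
                    verbose=True,
                )

            The experiment is then typically executed by a runner
            (e.g. ``SupervisedRunner``) rather than used on its own.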
        """
        assert (
            datasets is not None or loaders is not None
        ), "Please specify the data sources"
        self._model = model
        self._loaders, self._valid_loader = self.process_loaders(
            loaders=loaders,
            datasets=datasets,
            stage=stage,
            valid_loader=valid_loader,
            initial_seed=initial_seed,
        )
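        # normalize the callbacks (list or dict) into an OrderedDict
        # sorted by the callbacks' internal order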
        self._callbacks = utils.sort_callbacks_by_order(callbacks)
        self._criterion = criterion
        self._optimizer = optimizer
        self._scheduler = scheduler
        self._initial_seed = initial_seed
        self._logdir = logdir
        self._stage = stage
        self._num_epochs = num_epochs
        self._main_metric = main_metric
        self._minimize_metric = minimize_metric
        self._verbose = verbose
        self._check_time = check_time
        self._check_run = check_run
        self._state_kwargs = state_kwargs or {}
        self._checkpoint_data = checkpoint_data or {}
        self._distributed_params = distributed_params or {}

    @property
    def initial_seed(self) -> int:
        """Experiment's initial seed value."""
        return self._initial_seed

    @property
    def logdir(self):
        """Path to the directory where the experiment logs are stored."""
        return self._logdir

    @property
    def stages(self) -> Iterable[str]:
        """Experiment's stage names (an array with a single value)."""
        return [self._stage]

    @property
    def distributed_params(self) -> Dict:
        """Dict with the parameters for distributed and FP16 methods."""
        return self._distributed_params

    @staticmethod
    def process_loaders(
        loaders: "OrderedDict[str, DataLoader]",
        datasets: Dict,
        stage: str,
        valid_loader: str,
        initial_seed: int,
    ) -> "Tuple[OrderedDict[str, DataLoader], str]":
        """Prepares loaders for a given stage."""
        if datasets is not None:
            loaders = utils.get_loaders_from_params(
                initial_seed=initial_seed, **datasets,
            )
        if not stage.startswith(settings.stage_infer_prefix):  # train stage
            if len(loaders) == 1:
                valid_loader = list(loaders.keys())[0]
                warnings.warn(
                    "Attention: there is only one dataloader - "
                    + str(valid_loader)
                )
            assert valid_loader in loaders, (
                "The validation loader must be present "
                "in the loaders used during the experiment."
            )
        return loaders, valid_loader

    def get_state_params(self, stage: str) -> Mapping[str, Any]:
        """Returns the state parameters for a given stage."""
        default_params = {
            "logdir": self.logdir,
            "num_epochs": self._num_epochs,
            "valid_loader": self._valid_loader,
            "main_metric": self._main_metric,
            "verbose": self._verbose,
            "minimize_metric": self._minimize_metric,
            "checkpoint_data": self._checkpoint_data,
        }
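        # user-provided ``state_kwargs`` take precedence over the defaults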
        state_params = {**default_params, **self._state_kwargs}
        return state_params

    def get_model(self, stage: str) -> Model:
        """Returns the model for a given stage."""
        return self._model

    def get_criterion(self, stage: str) -> Criterion:
        """Returns the criterion for a given stage."""
        return self._criterion

    def get_optimizer(self, stage: str, model: nn.Module) -> Optimizer:
        """Returns the optimizer for a given stage."""
        return self._optimizer

    def get_scheduler(self, stage: str, optimizer=None) -> Scheduler:
        """Returns the scheduler for a given stage."""
        return self._scheduler

    def get_loaders(
        self, stage: str, epoch: int = None,
    ) -> "OrderedDict[str, DataLoader]":
        """Returns the loaders for a given stage."""
        return self._loaders

    def get_callbacks(self, stage: str) -> "OrderedDict[str, Callback]":
        """
        Returns the callbacks for a given stage.
        """
        callbacks = self._callbacks or OrderedDict()
        default_callbacks = []
        if self._verbose:
            default_callbacks.append(("_verbose", VerboseLogger))
        if self._check_time:
            default_callbacks.append(("_timer", TimerCallback))
        if self._check_run:
            default_callbacks.append(("_check", CheckRunCallback))
        if not stage.startswith("infer"):
            default_callbacks.append(("_metrics", MetricManagerCallback))
            default_callbacks.append(
                ("_validation", ValidationManagerCallback)
            )
            default_callbacks.append(("_console", ConsoleLogger))
            if self.logdir is not None:
                default_callbacks.append(("_saver", CheckpointCallback))
                default_callbacks.append(("_tensorboard", TensorboardLogger))
        default_callbacks.append(("_exception", ExceptionCallback))
        for callback_name, callback_fn in default_callbacks:
            is_already_present = any(
                isinstance(x, callback_fn) for x in callbacks.values()
            )
            if not is_already_present:
                callbacks[callback_name] = callback_fn()
        return callbacks


__all__ = ["Experiment"]