from typing import Any, Mapping, Optional, Tuple, Dict, Union  # isort:skip
from abc import ABC, abstractmethod
from collections import OrderedDict
import os
from pathlib import Path
import torch
from torch import nn
from torch.utils.data import DistributedSampler
from catalyst.dl import utils
from catalyst.utils.typing import (
    Criterion, Device, Model, Optimizer, Scheduler
)
from .callback import Callback, LoggerCallback
from .experiment import Experiment
from .state import RunnerState
class Runner(ABC):
    """
    Abstract base class that all runners inherit from.
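
    The runner fires events in the following order
    (see ``_run_stage``, ``_run_epoch``, ``_run_loader``, ``_run_batch``)::

        on_stage_start
            on_epoch_start
                on_loader_start
                    on_batch_start
                    on_batch_end
                on_loader_end
            on_epoch_end
        on_stage_end

    If an exception (or KeyboardInterrupt) is raised during the run,
    a single ``on_exception`` event is fired.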
    """
    def __init__(
        self,
        model: Model = None,
        device: Device = None,
    ):
        """
        Args:
            model (Model): Torch model object
            device (Device): Torch device
        """
        # main
        self._model: Model = model
        self._device: Device = device
        self.experiment: Experiment = None
        self.state: RunnerState = None
        self.callbacks: OrderedDict[str, Callback] = None
        self.loggers: OrderedDict[str, LoggerCallback] = None
        # additional
        self._check_run = False
    def _batch2device(self, batch: Mapping[str, Any], device: Device):
        res = utils.any2device(batch, device)
        return res
    def _get_experiment_components(
        self, stage: str = None
    ) -> Tuple[Model, Criterion, Optimizer, Scheduler, Device]:
        """
        Inner method for children's classes for model specific initialization.
        As baseline, checks device support and puts model on it.
        :return:
        """
        utils.set_global_seed(self.experiment.initial_seed)
        model = self.experiment.get_model(stage)
        criterion, optimizer, scheduler = \
            self.experiment.get_experiment_components(model, stage)
        model, criterion, optimizer, scheduler, device = \
            utils.process_components(
                model=model,
                criterion=criterion,
                optimizer=optimizer,
                scheduler=scheduler,
                distributed_params=self.experiment.distributed_params,
                device=self.device
            )
        return model, criterion, optimizer, scheduler, device
    def _prepare_for_stage(self, stage: str):
        utils.set_global_seed(self.experiment.initial_seed)
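        # carry the step/epoch counters over from the previous stage (if any),
        # so they keep counting across stage boundaries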
        migrating_params = {}
        if self.state is not None:
            migrating_params.update(
                {
                    "step": self.state.step,
                    "epoch": self.state.epoch
                }
            )
        self.model, criterion, optimizer, scheduler, self.device = \
            self._get_experiment_components(stage)
        self.state = RunnerState(
            stage=stage,
            model=self.model,
            device=self.device,
            criterion=criterion,
            optimizer=optimizer,
            scheduler=scheduler,
            **self.experiment.get_state_params(stage),
            **migrating_params
        )
        utils.set_global_seed(self.experiment.initial_seed)
    def _run_event(self, event: str, moment: Optional[str]):
        fn_name = f"on_{event}"
        if moment is not None:
            fn_name = f"{fn_name}_{moment}"
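        # e.g. event="epoch", moment="start" resolves to "on_epoch_start";
        # event="exception" is fired with moment=None and stays "on_exception"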
        # before callbacks
        if self.state is not None:
            getattr(self.state, f"{fn_name}_pre")()
        if self.loggers is not None and moment == "start":
            for logger in self.loggers.values():
                getattr(logger, fn_name)(self.state)
        # running callbacks
        if self.callbacks is not None:
            for callback in self.callbacks.values():
                getattr(callback, fn_name)(self.state)
        # after callbacks
        if self.loggers is not None and \
                (moment == "end" or moment is None):  # for on_exception case
            for logger in self.loggers.values():
                getattr(logger, fn_name)(self.state)
        if self.state is not None:
            getattr(self.state, f"{fn_name}_post")()
    @property
    def model(self) -> Model:
        """
        Returns the runner's model instance
        """
        return self._model
    @model.setter
    def model(self, value: Union[Model, Dict[str, Model]]):
        """
        Setter for the runner's model: accepts a single ``nn.Module``
        or a dict of named models and moves it to the runner's device
        (if one is set).
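
        Example (a minimal sketch; ``encoder`` and ``head`` are hypothetical
        modules used only for illustration)::

            runner.model = nn.Linear(10, 2)                    # single model
            runner.model = {"encoder": encoder, "head": head}  # dict of models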
        """
        if isinstance(value, nn.Module):
            model = value
        elif isinstance(value, dict):
            values_are_models = all([
                isinstance(v, nn.Module) for v in value.values()
            ])
            if not values_are_models:
                raise TypeError(
                    "Invalid dict value type, must be `torch.nn.Module`"
                )
            model = value
        else:
            raise TypeError(
                f"Invalid value type: "
                f"must be `torch.nn.Module` or `Dict[str, torch.nn.Module]`, "
                f"got '{type(value)}'"
            )
        if self._device is not None:
            model: Model = utils.maybe_recursive_call(
                model, "to", device=self._device
            )
        self._model = model
    @property
    def device(self) -> Device:
        """
        Returns the runner's device instance
        """
        return self._device
    @device.setter
    def device(self, value: Device):
        """
        Setter for the runner's device: also moves the current model
        (if any) to the new device.
        """
        if isinstance(value, (str, torch.device)):
            self._device = value
        else:
            raise TypeError(
                f"Invalid value type: "
                f"must be `str` or `torch.device`, "
                f"got '{type(value)}'"
            )
        if self._model is not None:
            self._model = utils.maybe_recursive_call(
                self._model, "to", device=self._device
            )
    @abstractmethod
    def forward(self, batch: Mapping[str, Any], **kwargs) -> Mapping[str, Any]:
        """
        Forward method for your Runner
        Args:
            batch: Key-value batch items
            **kwargs: kwargs to pass to the model
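
        Example (a minimal sketch of an override; the ``"features"`` and
        ``"logits"`` keys are illustrative assumptions, not part of the API)::

            def forward(self, batch, **kwargs):
                return {"logits": self.model(batch["features"], **kwargs)}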
        """
        pass
    def predict_batch(
        self,
        batch: Mapping[str, Any],
        **kwargs
    ) -> Mapping[str, Any]:
        """
        Runs the model on a batch of elements.
        WARN: you should not override this method; if you need a model-specific
        call, override the forward() method instead.
        Args:
            batch: Key-value batch items
            **kwargs: kwargs to pass to the model
        Returns:
            model output key-value
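
        Example (a minimal sketch; assumes a concrete Runner subclass whose
        ``forward`` consumes a ``"features"`` key)::

            batch = {"features": torch.rand(4, 10)}
            output = runner.predict_batch(batch)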
        """
        batch = self._batch2device(batch, self.device)
        output = self.forward(batch, **kwargs)
        return output
    def _run_batch(self, batch):
        self.state.step += self.state.batch_size
        batch = self._batch2device(batch, self.device)
        self.state.input = batch
        self.state.timer.stop("_timers/data_time")
        self._run_event("batch", moment="start")
        self.state.timer.start("_timers/model_time")
        self.state.output = self.forward(batch)
        self.state.timer.stop("_timers/model_time")
        self.state.timer.stop("_timers/batch_time")
        self._run_event("batch", moment="end")
    def _run_loader(self, loader):
        self.state.batch_size = (
            loader.batch_sampler.batch_size
            if loader.batch_sampler is not None else loader.batch_size
        )
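        # bootstrap the global step counter: if it is still unset, estimate
        # the number of samples seen so far from the completed epochs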
        self.state.step = (
            self.state.step
            or self.state.epoch * len(loader) * self.state.batch_size
        )
        # @TODO: remove time usage, use it under the hood
        self.state.timer.reset()
        self.state.timer.start("_timers/batch_time")
        self.state.timer.start("_timers/data_time")
        for i, batch in enumerate(loader):
            self._run_batch(batch)
            self.state.timer.reset()
            if self._check_run and i >= 3:
                break
            self.state.timer.start("_timers/batch_time")
            self.state.timer.start("_timers/data_time")
    def _run_epoch(self, loaders):
        # @TODO: better solution with train/inference handling ?
        if not self.state.stage.startswith("infer"):
            assert self.state.valid_loader in loaders.keys(), \
                f"'{self.state.valid_loader}' " \
                f"should be in provided loaders: {list(loaders.keys())}"
        else:
            assert not any(x.startswith("train") for x in loaders.keys()), \
                "for inference no train loader should be passed"
        for loader_name, loader in loaders.items():
            self.state.loader_name = loader_name
            self.state.loader_len = len(loader)
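            # gradients and train mode are enabled only for loaders whose
            # name starts with "train"; everything else runs in eval mode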
            self.state.need_backward = loader_name.startswith("train")
            utils.maybe_recursive_call(
                self.model,
                "train",
                mode=self.state.need_backward
            )
            if isinstance(loader.sampler, DistributedSampler) \
                    and loader_name.startswith("train"):
                loader.sampler.set_epoch(self.state.stage_epoch)
            utils.set_global_seed(
                self.experiment.initial_seed + self.state.epoch + 1
            )
            self._run_event("loader", moment="start")
            with torch.set_grad_enabled(self.state.need_backward):
                self._run_loader(loader)
            self._run_event("loader", moment="end")
    def _run_stage(self, stage: str):
        self._prepare_for_stage(stage)
        loaders = self.experiment.get_loaders(stage)
        callbacks = self.experiment.get_callbacks(stage)
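        # split the stage callbacks into loggers and regular callbacks:
        # _run_event calls loggers first on "start" events and last on "end" events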
        loggers = utils.process_callbacks(
            OrderedDict([
                (k, v) for k, v in callbacks.items()
                if isinstance(v, LoggerCallback)
            ])
        )
        callbacks = utils.process_callbacks(
            OrderedDict([
                (k, v) for k, v in callbacks.items()
                if not isinstance(v, LoggerCallback)
            ])
        )
        self.state.loggers = loggers
        self.loggers = loggers
        self.callbacks = callbacks
        self._run_event("stage", moment="start")
        for epoch in range(self.state.num_epochs):
            self.state.stage_epoch = epoch
            self._run_event("epoch", moment="start")
            self._run_epoch(loaders)
            self._run_event("epoch", moment="end")
            if self._check_run and self.state.epoch >= 3:
                break
            if self.state.early_stop:
                self.state.early_stop = False
                break
            self.state.epoch += 1
        self._run_event("stage", moment="end")
    def run_experiment(self, experiment: Experiment, check: bool = False):
        """
        Starts the experiment
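
        Example (a hedged sketch; ``MyExperiment`` and ``MyRunner`` stand in
        for your own concrete subclasses)::

            runner = MyRunner()
            runner.run_experiment(MyExperiment(), check=True)  # short debug run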
        """
        self._check_run = check
        self.experiment = experiment
        # jupyter source code logging hack
        # + hack to prevent cycle imports
        from catalyst.dl.experiment import BaseExperiment
        if isinstance(self.experiment, BaseExperiment) \
                and self.experiment.logdir is not None:
            expdir = Path(os.getcwd())
            logdir = Path(self.experiment.logdir)
            utils.dump_base_experiment_code(expdir, logdir)
        try:
            for stage in self.experiment.stages:
                self._run_stage(stage)
        except (Exception, KeyboardInterrupt) as ex:
            self.state.exception = ex
            self._run_event("exception", moment=None)
        return self
__all__ = ["Runner"]