from typing import Any, Dict, List, Mapping, Union  # isort:skip
from collections import OrderedDict
import logging
from pathlib import Path
import torch
from torch.jit import ScriptModule
from torch.utils.data import DataLoader
from catalyst.dl import utils
from catalyst.dl.callbacks import CheckpointCallback, InferCallback
from catalyst.dl.core import Callback, Runner
from catalyst.dl.experiment import SupervisedExperiment
from catalyst.utils.typing import (
    Criterion, Device, Model, Optimizer, Scheduler
)
logger = logging.getLogger(__name__)
class SupervisedRunner(Runner):
    """
    Runner for experiments with a supervised model
    """
    _default_experiment = SupervisedExperiment
    def __init__(
        self,
        model: Model = None,
        device: Device = None,
        input_key: str = "features",
        output_key: str = "logits",
        input_target_key: str = "targets",
    ):
        """
        Args:
            model (Model): Torch model object
            device (Device): Torch device
            input_key (str): key in the batch dict that holds the model input
            output_key (str): key under which the model output
                will be stored in the output dict
            input_target_key (str): key in the batch dict that holds the target
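
        Example (a minimal sketch; ``MyModel`` is assumed to be defined
        by the user)::

            runner = SupervisedRunner(
                model=MyModel(),
                input_key="features",
                output_key="logits",
                input_target_key="targets",
            )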
        """
        super().__init__(model=model, device=device)
        self.input_key = input_key
        self.output_key = output_key
        self.target_key = input_target_key
        if isinstance(self.input_key, str):
            # a single batch key is passed to the model as one argument
            self._process_input = self._process_input_str
        elif isinstance(self.input_key, (list, tuple)):
            # the listed batch keys are passed to the model as kwargs
            self._process_input = self._process_input_list
        else:
            # the whole batch dict is unpacked into the model call
            self._process_input = self._process_input_none
        if isinstance(output_key, str):
            self._process_output = self._process_output_str
        elif isinstance(output_key, (list, tuple)):
            self._process_output = self._process_output_list
        else:
            self._process_output = self._process_output_none
    def _batch2device(self, batch: Mapping[str, Any], device: Device):
        if isinstance(batch, (tuple, list)):
            # a (features, targets) pair is converted to the dict format
            assert len(batch) == 2
            batch = {self.input_key: batch[0], self.target_key: batch[1]}
        batch = super()._batch2device(batch, device)
        return batch
    def _process_input_str(self, batch: Mapping[str, Any], **kwargs):
        output = self.model(batch[self.input_key], **kwargs)
        return output
    def _process_input_list(self, batch: Mapping[str, Any], **kwargs):
        # gather the listed keys from the batch and pass them as kwargs
        input_ = {key: batch[key] for key in self.input_key}
        output = self.model(**input_, **kwargs)
        return output
    def _process_input_none(self, batch: Mapping[str, Any], **kwargs):
        output = self.model(**batch, **kwargs)
        return output
    def _process_output_str(self, output: Mapping[str, Any]):
        output = {self.output_key: output}
        return output
    def _process_output_list(self, output: Mapping[str, Any]):
        # map each output key to the corresponding model output
        output = dict(zip(self.output_key, output))
        return output
    def _process_output_none(self, output: Mapping[str, Any]):
        return output
    def forward(self, batch, **kwargs):
        """
        Should not be called directly outside of the runner.
        If your model has a specific interface, override this method to use it.
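
        Example of an override (a minimal sketch; a ``self.model`` that
        takes two inputs is assumed)::

            class SiameseRunner(SupervisedRunner):
                def forward(self, batch, **kwargs):
                    # call the model with two explicit inputs from the batch
                    output = self.model(batch["left"], batch["right"])
                    return {self.output_key: output}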
        """
        output = self._process_input(batch, **kwargs)
        output = self._process_output(output)
        return output 
    def train(
        self,
        model: Model,
        criterion: Criterion,
        optimizer: Optimizer,
        loaders: "OrderedDict[str, DataLoader]",
        logdir: str,
        callbacks: "Union[List[Callback], OrderedDict[str, Callback]]" = None,
        scheduler: Scheduler = None,
        resume: str = None,
        num_epochs: int = 1,
        valid_loader: str = "valid",
        main_metric: str = "loss",
        minimize_metric: bool = True,
        verbose: bool = False,
        state_kwargs: Dict = None,
        checkpoint_data: Dict = None,
        fp16: Union[Dict, bool] = None,
        monitoring_params: Dict = None,
        check: bool = False,
    ) -> None:
        """
        Starts the training process of the model.

        Args:
            model (Model): model to train
            criterion (Criterion): criterion function for training
            optimizer (Optimizer): optimizer for training
            loaders (dict): dictionary containing one or several
                ``torch.utils.data.DataLoader`` for training and validation
            logdir (str): path to output directory
            callbacks (List[catalyst.dl.Callback]): list of callbacks
            scheduler (Scheduler): scheduler for training
            resume (str): path to checkpoint for model
            num_epochs (int): number of training epochs
            valid_loader (str): loader name used to calculate
                the metrics and save the checkpoints. For example,
                you can pass `train` and then
                the metrics will be taken from the `train` loader.
            main_metric (str): the key to the name of the metric
                by which the checkpoints will be selected.
            minimize_metric (bool): flag to indicate whether
                the ``main_metric`` should be minimized.
            verbose (bool): if True, displays the status of the training
                in the console.
            state_kwargs (dict): additional state params to ``RunnerState``
            checkpoint_data (dict): additional data to save in checkpoint,
                for example: ``class_names``, ``date_of_training``, etc
            fp16 (Union[Dict, bool]): if not None, sets training to FP16.
                See https://nvidia.github.io/apex/amp.html#properties;
                if ``fp16=True``, the params default to ``{"opt_level": "O1"}``
            monitoring_params (dict): If not None, then create monitoring
                through Weights&Biases. This params is used for ``wandb.init``
                see https://docs.wandb.com/wandb/init
            check (bool): if True, only checks that the pipeline is working
                (3 epochs only)
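
        Example (a minimal sketch; ``model`` and ``loaders`` are assumed
        to be defined by the user)::

            runner = SupervisedRunner()
            runner.train(
                model=model,
                criterion=torch.nn.CrossEntropyLoss(),
                optimizer=torch.optim.Adam(model.parameters()),
                loaders=loaders,  # OrderedDict with "train"/"valid" loaders
                logdir="./logdir",
                num_epochs=8,
                verbose=True,
            )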
        """
        if len(loaders) == 1:
            valid_loader = list(loaders.keys())[0]
            logger.warning(
                "Attention, there is only one data loader - "
                + str(valid_loader)
            )
        if isinstance(fp16, bool) and fp16:
            fp16 = {"opt_level": "O1"}
        if model is not None:
            self.model = model
        if resume is not None:
            callbacks = callbacks or OrderedDict()
            checkpoint_callback_flag = any(
                isinstance(x, CheckpointCallback)
                for x in callbacks.values()
            )
            if not checkpoint_callback_flag:
                callbacks["loader"] = CheckpointCallback(resume=resume)
            else:
                raise NotImplementedError("CheckpointCallback already exists")
        experiment = self._default_experiment(
            stage="train",
            model=model,
            loaders=loaders,
            callbacks=callbacks,
            logdir=logdir,
            criterion=criterion,
            optimizer=optimizer,
            scheduler=scheduler,
            num_epochs=num_epochs,
            valid_loader=valid_loader,
            main_metric=main_metric,
            minimize_metric=minimize_metric,
            verbose=verbose,
            state_kwargs=state_kwargs,
            checkpoint_data=checkpoint_data,
            distributed_params=fp16,
            monitoring_params=monitoring_params
        )
        self.run_experiment(experiment, check=check) 
    def infer(
        self,
        model: Model,
        loaders: "OrderedDict[str, DataLoader]",
        callbacks: "Union[List[Callback], OrderedDict[str, Callback]]" = None,
        verbose: bool = False,
        state_kwargs: Dict = None,
        fp16: Union[Dict, bool] = None,
        check: bool = False,
    ) -> None:
        """
        Runs inference with the model.

        Args:
            model (Model): model to infer
            loaders (dict): dictionary containing one or several
                ``torch.utils.data.DataLoader`` for inference
            callbacks (List[catalyst.dl.Callback]): list of inference callbacks
            verbose (bool): if True, displays the status of the inference
                in the console.
            state_kwargs (dict): additional state params to ``RunnerState``
            fp16 (Union[Dict, bool]): if not None, sets inference to FP16.
                See https://nvidia.github.io/apex/amp.html#properties;
                if ``fp16=True``, the params default to ``{"opt_level": "O1"}``
            check (bool): if True, only checks that the pipeline is working
                (3 epochs only)
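
        Example (a minimal sketch; ``model`` and ``infer_loader`` are
        assumed to be defined by the user)::

            runner = SupervisedRunner()
            runner.infer(
                model=model,
                loaders=OrderedDict([("infer", infer_loader)]),
                callbacks=OrderedDict([("inference", InferCallback())]),
            )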
        """
        if isinstance(fp16, bool) and fp16:
            fp16 = {"opt_level": "O1"}
        if model is not None:
            self.model = model
        experiment = self._default_experiment(
            stage="infer",
            model=model,
            loaders=loaders,
            callbacks=callbacks,
            verbose=verbose,
            state_kwargs=state_kwargs,
            distributed_params=fp16
        )
        self.run_experiment(experiment, check=check) 
    def predict_loader(
        self,
        model: Model,
        loader: DataLoader,
        resume: str = None,
        verbose: bool = False,
        state_kwargs: Dict = None,
        fp16: Union[Dict, bool] = None,
        check: bool = False,
    ) -> Any:
        """
        Makes a prediction on the whole loader with the specified model.

        Args:
            model (Model): model to infer
            loader (DataLoader): ``torch.utils.data.DataLoader``
                with the data for inference
            resume (str): path to checkpoint for model
            verbose (bool): if True, displays the status of the inference
                in the console.
            state_kwargs (dict): additional state params to ``RunnerState``
            fp16 (Union[Dict, bool]): if not None, sets inference to FP16.
                See https://nvidia.github.io/apex/amp.html#properties;
                if ``fp16=True``, the params default to ``{"opt_level": "O1"}``
            check (bool): if True, only checks that the pipeline is working
                (3 epochs only)

        Returns:
            predictions gathered by ``InferCallback``; if ``output_key``
            is a string, only the predictions stored under that key
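
        Example (a minimal sketch; ``model`` and ``infer_loader`` are
        assumed, and the checkpoint path is hypothetical)::

            runner = SupervisedRunner()
            logits = runner.predict_loader(
                model=model,
                loader=infer_loader,
                resume="./logdir/checkpoints/best.pth",
            )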
        """
        loaders = OrderedDict([("infer", loader)])
        callbacks = OrderedDict([("inference", InferCallback())])
        if resume is not None:
            callbacks["loader"] = CheckpointCallback(resume=resume)
        self.infer(
            model=model,
            loaders=loaders,
            callbacks=callbacks,
            verbose=verbose,
            state_kwargs=state_kwargs,
            fp16=fp16,
            check=check
        )
        output = callbacks["inference"].predictions
        if isinstance(self.output_key, str):
            output = output[self.output_key]
        return output 
    def trace(
        self,
        model: Model = None,
        batch=None,
        logdir: str = None,
        loader: DataLoader = None,
        method_name: str = "forward",
        mode: str = "eval",
        requires_grad: bool = False,
        fp16: Union[Dict, bool] = None,
        device: Device = "cpu",
        predict_params: dict = None,
    ) -> ScriptModule:
        """
        Traces the model using Torch JIT.

        Args:
            model (Model): model to trace
            batch: batch to forward through the model to trace
            logdir (str, optional): if specified, the result
                will be written to this directory
            loader (DataLoader, optional): if batch is not specified, the batch
                will be ``next(iter(loader))``
            method_name (str): model's method name that will be traced
            mode (str): ``train`` or ``eval``
            requires_grad (bool): flag to trace with gradients
            fp16 (Union[Dict, bool]): If not None, then sets
                tracing params to FP16
            device (Device): Torch device or a string
            predict_params (dict): additional parameters for model forward

        Returns:
            ScriptModule: traced model
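
        Example (a minimal sketch; ``model`` and a sample ``batch`` are
        assumed to be defined by the user)::

            runner = SupervisedRunner()
            traced = runner.trace(
                model=model,
                batch=batch,
                logdir="./logdir",  # the trace is saved to ./logdir/trace
            )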
        """
        if batch is None:
            if loader is None:
                raise ValueError(
                    "If batch is not provided the loader must be specified"
                )
            batch = next(iter(loader))
        if model is not None:
            self.model = model
        # normalize the ``fp16`` argument to an apex ``opt_level`` string
        if isinstance(fp16, bool) and fp16:
            opt_level = "O1"
        elif isinstance(fp16, bool) and not fp16:
            opt_level = None
        elif isinstance(fp16, dict):
            opt_level = fp16["opt_level"]
        else:
            opt_level = fp16
        if opt_level is not None:
            # mixed-precision tracing requires a CUDA device
            device = "cuda"
        elif device is None:
            if self.device is None:
                self.device = utils.get_device()
            device = self.device
        result = utils.trace_model(
            model=self.model,
            runner=self,
            batch=batch,
            method_name=method_name,
            mode=mode,
            requires_grad=requires_grad,
            opt_level=opt_level,
            device=device,
            predict_params=predict_params
        )
        if logdir is not None:
            filename = utils.get_trace_name(
                method_name=method_name,
                mode=mode,
                requires_grad=requires_grad,
                opt_level=opt_level
            )
            logdir = Path(logdir)
            output: Path = logdir / "trace"
            output.mkdir(exist_ok=True, parents=True)
            out_model = str(output / filename)
            torch.jit.save(result, out_model)
        return result  
__all__ = ["SupervisedRunner"]