from typing import Callable, Dict, List, Union  # isort:skip
import logging

import safitty
import torch

from catalyst.dl.core import Callback, CallbackOrder, RunnerState
from catalyst.dl.registry import GRAD_CLIPPERS
from catalyst.dl.utils import get_optimizer_momentum, maybe_recursive_call
from catalyst.utils.typing import Optimizer
# Module-level logger, named after this module's import path.
logger = logging.getLogger(name=__name__)
class OptimizerCallback(Callback):
    """
    Optimizer callback, abstraction over optimizer step.

    For every batch with ``state.need_backward`` set, it back-propagates the
    loss. Every ``accumulation_steps`` such batches, it applies the
    (optional) decoupled weight decay and gradient clipping, calls
    ``optimizer.step()`` and zeroes the model gradients.
    """

    def __init__(
        self,
        grad_clip_params: Dict = None,
        accumulation_steps: int = 1,
        optimizer_key: str = None,
        loss_key: str = "loss",
        decouple_weight_decay: bool = True
    ):
        """
        Args:
            grad_clip_params (dict): params for gradient clipping, passed as
                keyword arguments to ``GRAD_CLIPPERS.get_from_params``
            accumulation_steps (int): number of batches whose gradients are
                accumulated before each optimizer step and
                ``model.zero_grad()``; must be >= 1
            optimizer_key (str): A key to take an optimizer in case
                there are several of them and they are in a dictionary format.
            loss_key (str): key to get loss from ``state.loss``
            decouple_weight_decay (bool): If True - decouple weight decay
                regularization: the optimizer's own ``weight_decay`` is set
                to 0 at epoch start (and restored at epoch end), and the
                decay is applied by ``grad_step`` instead.

        Raises:
            ValueError: if ``accumulation_steps`` is less than 1
        """
        super().__init__(CallbackOrder.Optimizer)
        if accumulation_steps < 1:
            # 0 would otherwise fail later with ZeroDivisionError in ``%``
            raise ValueError(
                f"accumulation_steps must be >= 1, got {accumulation_steps}"
            )
        grad_clip_params = grad_clip_params or {}
        self.grad_clip_fn = GRAD_CLIPPERS.get_from_params(**grad_clip_params)
        self.accumulation_steps: int = accumulation_steps
        self.optimizer_key: str = optimizer_key
        self.loss_key: str = loss_key
        self.decouple_weight_decay = decouple_weight_decay
        # per-param-group weight decay, filled in by ``on_epoch_start``
        self._optimizer_wd: List[float] = [0.0]
        # number of backward passes since the last optimizer step
        self._accumulation_counter: int = 0

    def _get_optimizer(self, state: RunnerState):
        """Fetch this callback's optimizer from ``state``."""
        return state.get_key(key="optimizer", inner_key=self.optimizer_key)

    @staticmethod
    def grad_step(
        *,
        optimizer: Optimizer,
        optimizer_wds: Union[float, List[float]] = 0,
        grad_clip_fn: Callable = None
    ):
        """
        Makes a gradient step for a given optimizer.

        For each param group with a positive weight decay ``wd``, every
        parameter is updated in place as ``param <- param - wd * lr * param``.
        Then ``grad_clip_fn`` (if any) is called on the group's params.
        Finally ``optimizer.step()`` is called.

        Args:
            optimizer (Optimizer): the optimizer
            optimizer_wds (float or List[float]): weight decay for each
                param group. A single number (including the default ``0``)
                is used for every group. ``None`` or a list shorter than the
                number of groups means no weight decay for the missing
                groups. Gradient clipping is still applied to them.
            grad_clip_fn (Callable): function for gradient clipping
        """
        param_groups = optimizer.param_groups
        if optimizer_wds is None:
            optimizer_wds = []
        elif isinstance(optimizer_wds, (int, float)):
            # broadcast a scalar; previously the scalar default crashed ``zip``
            optimizer_wds = [optimizer_wds] * len(param_groups)

        for idx, group in enumerate(param_groups):
            wd = optimizer_wds[idx] if idx < len(optimizer_wds) else 0.0
            if wd > 0:
                alpha = -wd * group["lr"]
                for param in group["params"]:
                    # in-place, so params sharing storage (tied weights,
                    # flattened buffers) stay shared; same arithmetic as
                    # ``param + alpha * param``
                    param.data.add_(param.data, alpha=alpha)
            if grad_clip_fn is not None:
                grad_clip_fn(group["params"])
        optimizer.step()

    def on_stage_start(self, state: RunnerState):
        """On stage start event: store the optimizer's default lr and
        its momentum in ``state`` under ``"lr"`` / ``"momentum"``."""
        optimizer = self._get_optimizer(state)
        assert optimizer is not None
        lr = optimizer.defaults["lr"]
        momentum = get_optimizer_momentum(optimizer)
        state.set_key(lr, "lr", inner_key=self.optimizer_key)
        state.set_key(momentum, "momentum", inner_key=self.optimizer_key)

    def on_epoch_start(self, state: RunnerState):
        """On epoch start event.

        With ``decouple_weight_decay``, it remembers each param group's
        ``weight_decay`` and sets the group's value to 0.0, so the optimizer
        itself does not apply it. Otherwise it records zero decay for
        every group.
        """
        optimizer = self._get_optimizer(state)
        if self.decouple_weight_decay:
            self._optimizer_wd = [
                group.get("weight_decay", 0.0)
                for group in optimizer.param_groups
            ]
            for i in range(len(optimizer.param_groups)):
                safitty.set(
                    optimizer.param_groups, i, "weight_decay", value=0.0)
        else:
            self._optimizer_wd = [0.0] * len(optimizer.param_groups)

    def _get_loss(self, state) -> torch.Tensor:
        """Fetch the loss to back-propagate from ``state``.

        Raises:
            ValueError: if the loss is a list or a dict, i.e. several losses
                were not aggregated into a single value
        """
        loss = state.get_key(key="loss", inner_key=self.loss_key)
        if isinstance(loss, list):
            raise ValueError(
                "Loss is a list. "
                "To aggregate losses into "
                "one value use `CriterionAggregatorCallback`"
            )
        if isinstance(loss, dict):
            error = (
                f"Loss is a dict: {list(loss.keys())}, "
                "to aggregate losses into "
                "one value use `CriterionAggregatorCallback`."
            )
            if self.loss_key is None:
                error += (
                    " Or try to pass `loss_key` "
                    "in the OptimizerCallback init"
                )
            raise ValueError(error)
        return loss

    def on_batch_start(self, state: RunnerState):
        """On batch start event: reset ``state.loss``."""
        state.loss = None

    def on_batch_end(self, state: RunnerState):
        """On batch end event.

        It back-propagates the loss (through apex AMP when the optimizer
        looks like an AMP one). Every ``accumulation_steps`` batches it
        makes the optimizer step and zeroes the model gradients.
        """
        if not state.need_backward:
            return
        loss = self._get_loss(state)
        self._accumulation_counter += 1
        model = state.model
        optimizer = self._get_optimizer(state)
        # the counter already includes this batch, so a step is due exactly
        # when it reaches a multiple of ``accumulation_steps``
        need_gradient_step = (
            self._accumulation_counter % self.accumulation_steps == 0
        )
        # This is very hacky check whether we have AMP optimizer and this may
        # change in future.
        # But alternative solution is to have AmpOptimizerCallback.
        # or expose another c'tor argument.
        if hasattr(optimizer, "_amp_stash"):
            from apex import amp
            # Need to set ``delay_unscale``
            # according to
            # https://nvidia.github.io/apex/advanced.html#gradient-accumulation-across-iterations
            delay_unscale = not need_gradient_step
            with amp.scale_loss(
                loss,
                optimizer,
                delay_unscale=delay_unscale
            ) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        if need_gradient_step:
            self.grad_step(
                optimizer=optimizer,
                optimizer_wds=self._optimizer_wd,
                grad_clip_fn=self.grad_clip_fn
            )
            maybe_recursive_call(model, "zero_grad")
            self._accumulation_counter = 0

    def on_epoch_end(self, state: RunnerState):
        """On epoch end event: restore the per-group ``weight_decay``
        values that ``on_epoch_start`` zeroed (only when decoupling)."""
        if self.decouple_weight_decay:
            optimizer = self._get_optimizer(state)
            for i, wd in enumerate(self._optimizer_wd):
                safitty.set(
                    optimizer.param_groups, i, "weight_decay", value=wd)
# Public API of this module.
__all__ = [
    "OptimizerCallback",
]