# Source code for catalyst.core.callbacks.optimizer

from typing import Callable, Dict, List, Union
import logging
import warnings

from catalyst.core import (
    Callback,
    CallbackNode,
    CallbackOrder,
    registry,
    State,
    utils,
)
from catalyst.utils.tools.typing import Optimizer

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


class OptimizerCallback(Callback):
    """Optimizer callback, abstraction over optimizer step.

    On every train batch it back-propagates the loss stored in
    ``state.batch_metrics[metric_key]`` (through ``apex.amp`` when the
    optimizer looks like an AMP optimizer). Every ``accumulation_steps``
    batches it makes an optimizer step (optionally with decoupled weight
    decay and gradient clipping) and zeroes the gradients.
    """

    def __init__(
        self,
        metric_key: str = None,
        optimizer_key: str = None,
        accumulation_steps: int = 1,
        grad_clip_params: Dict = None,
        decouple_weight_decay: bool = True,
        loss_key: str = None,
    ):
        """
        Args:
            metric_key (str): key to get the loss from
                ``state.batch_metrics``; defaults to ``"loss"``
            optimizer_key (str): a key to select the optimizer when there
                are several of them stored in a dictionary
            accumulation_steps (int): number of batches to accumulate
                gradients over before each optimizer step and
                ``zero_grad`` call; must be >= 1
            grad_clip_params (dict): params for gradient clipping, passed to
                ``registry.GRAD_CLIPPERS.get_from_params``
            decouple_weight_decay (bool): if True, decouple weight decay
                regularization from the optimizer update (the optimizer's
                own ``weight_decay`` is zeroed during each epoch and
                :meth:`grad_step` applies the decay instead)
            loss_key (str): deprecated alias for ``metric_key``
        """
        super().__init__(order=CallbackOrder.Optimizer, node=CallbackNode.All)
        assert metric_key is None or loss_key is None, (
            "OptimizerCallback: pass only one of `metric_key` and `loss_key`"
        )
        # ``need_gradient_step`` is computed with ``% accumulation_steps``,
        # so 0 would only fail later with ZeroDivisionError on a batch.
        assert accumulation_steps >= 1, (
            "OptimizerCallback: `accumulation_steps` must be >= 1, "
            f"got {accumulation_steps!r}"
        )
        if loss_key is not None:
            warnings.warn(
                "OptimizerCallback: "
                "`loss_key` is now deprecated in favor `metric_key`",
                stacklevel=2,
            )
        self.metric_key: str = metric_key or loss_key or "loss"
        self.optimizer_key: str = optimizer_key

        self.accumulation_steps: int = accumulation_steps
        self._accumulation_counter: int = 0

        grad_clip_params: dict = grad_clip_params or {}
        self.grad_clip_fn = registry.GRAD_CLIPPERS.get_from_params(
            **grad_clip_params
        )

        self.decouple_weight_decay = decouple_weight_decay
        # Per-param-group weight decay values; refreshed in on_epoch_start.
        self._optimizer_wd: List[float] = [0.0]

    @staticmethod
    def grad_step(
        *,
        optimizer: Optimizer,
        optimizer_wds: Union[float, List[float]] = 0,
        grad_clip_fn: Callable = None,
    ) -> None:
        """Makes a gradient step for a given optimizer.

        Args:
            optimizer (Optimizer): the optimizer
            optimizer_wds (float or List[float]): weight decay for each
                param group. A single number (or None) is used for every
                group; groups without an entry get no weight decay.
            grad_clip_fn (Callable): function for gradient clipping, applied
                to the params of every param group
        """
        param_groups = optimizer.param_groups
        try:
            wds = list(optimizer_wds)
        except TypeError:
            # A single number (the default ``0``) or None: broadcast it.
            wds = [optimizer_wds or 0.0] * len(param_groups)
        # Pad with zeros so that every group is still visited below;
        # otherwise gradient clipping would be skipped for trailing groups.
        wds += [0.0] * (len(param_groups) - len(wds))

        for group, wd in zip(param_groups, wds):
            if wd > 0:
                # Decoupled weight decay: p <- p - lr * wd * p, done in
                # place (same update as torch.optim.AdamW).
                scale = 1 - wd * group["lr"]
                for param in group["params"]:
                    param.data.mul_(scale)
            if grad_clip_fn is not None:
                grad_clip_fn(group["params"])
        optimizer.step()

    def _metric_name(self, prefix: str) -> str:
        """Returns ``prefix``, suffixed with ``/<optimizer_key>`` if set."""
        if self.optimizer_key is None:
            return prefix
        return f"{prefix}/{self.optimizer_key}"

    def on_stage_start(self, state: State) -> None:
        """Checks that the current stage has correct optimizer."""
        self._optimizer = state.get_attr(
            key="optimizer", inner_key=self.optimizer_key
        )
        assert self._optimizer is not None, (
            "OptimizerCallback: no optimizer found "
            f"(optimizer_key={self.optimizer_key!r})"
        )

    def on_epoch_start(self, state: State) -> None:
        """On epoch start event.

        With ``decouple_weight_decay`` the per-group ``weight_decay`` values
        are saved and set to 0 in the optimizer, so that only
        :meth:`grad_step` applies them during the epoch.

        Args:
            state (State): current state
        """
        param_groups = self._optimizer.param_groups
        if self.decouple_weight_decay:
            self._optimizer_wd = [
                group.get("weight_decay", 0.0) for group in param_groups
            ]
            for group in param_groups:
                group["weight_decay"] = 0.0
        else:
            self._optimizer_wd = [0.0] * len(param_groups)

    def on_epoch_end(self, state: State) -> None:
        """On epoch end event.

        Restores the weight decay values zeroed in :meth:`on_epoch_start`
        and logs the learning rate (and momentum, if available) of the
        first param group into ``state.epoch_metrics``.

        Args:
            state (State): current state
        """
        if self.decouple_weight_decay:
            for i, wd in enumerate(self._optimizer_wd):
                self._optimizer.param_groups[i]["weight_decay"] = wd

        lr = self._optimizer.param_groups[0]["lr"]
        state.epoch_metrics[self._metric_name("lr")] = lr

        momentum = utils.get_optimizer_momentum(self._optimizer)
        if momentum is not None:
            state.epoch_metrics[self._metric_name("momentum")] = momentum

    def on_batch_end(self, state: State) -> None:
        """On batch end event: backward pass and, possibly, optimizer step.

        Does nothing outside the train loader. Otherwise it back-propagates
        ``state.batch_metrics[self.metric_key]``. Once every
        ``accumulation_steps`` batches it calls :meth:`grad_step` and then
        zeroes the gradients.

        Note: the loss is not divided by ``accumulation_steps``, so
        gradients are summed over an accumulation window. The counter is
        not reset between epochs, so a partially filled window carries its
        gradients into the next epoch (unless they are zeroed elsewhere).

        Args:
            state (State): current state
        """
        if not state.is_train_loader:
            return

        loss = state.batch_metrics[self.metric_key]

        self._accumulation_counter += 1
        need_gradient_step = (
            self._accumulation_counter % self.accumulation_steps == 0
        )

        # This is very hacky check whether we have AMP optimizer and this may
        # change in future.
        # But alternative solution is to have AmpOptimizerCallback.
        # or expose another c'tor argument.
        if hasattr(self._optimizer, "_amp_stash"):
            from apex import amp

            # Need to set ``delay_unscale``
            # according to
            # https://nvidia.github.io/apex/advanced.html#gradient-accumulation-across-iterations
            delay_unscale = not need_gradient_step
            with amp.scale_loss(
                loss, self._optimizer, delay_unscale=delay_unscale
            ) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

        if need_gradient_step:
            self.grad_step(
                optimizer=self._optimizer,
                optimizer_wds=self._optimizer_wd,
                grad_clip_fn=self.grad_clip_fn,
            )
            utils.maybe_recursive_call(self._optimizer, "zero_grad")
            self._accumulation_counter = 0
# Public API of this module.
__all__ = ["OptimizerCallback"]