Source code for catalyst.dl.callbacks.optimizer

from typing import Callable, Dict, List  # isort:skip
import logging

import safitty

import torch

from catalyst.dl.core import Callback, CallbackOrder, RunnerState
from catalyst.dl.registry import GRAD_CLIPPERS
from catalyst.dl.utils import get_optimizer_momentum, maybe_recursive_call
from catalyst.dl.utils.torch import _Optimizer

# Module-level logger named after this module (not referenced by the code
# visible in this file).
logger = logging.getLogger(__name__)


class OptimizerCallback(Callback):
    """
    Optimizer callback, abstraction over optimizer step.

    On every batch that needs a backward pass it back-propagates the loss
    (through NVIDIA Apex AMP loss scaling when the optimizer looks
    AMP-wrapped). Gradients are accumulated over ``accumulation_steps``
    batches. Then it applies (optionally decoupled) weight decay and
    gradient clipping, calls ``optimizer.step()`` and zeroes the model
    gradients.
    """

    def __init__(
        self,
        grad_clip_params: Dict = None,
        accumulation_steps: int = 1,
        optimizer_key: str = None,
        loss_key: str = "loss",
        decouple_weight_decay: bool = True
    ):
        """
        Args:
            grad_clip_params (dict): params for gradient clipping, passed to
                ``GRAD_CLIPPERS.get_from_params``
            accumulation_steps (int): number of batches whose gradients are
                accumulated before each ``optimizer.step()`` /
                ``model.zero_grad()``; must be >= 1
            optimizer_key (str): A key to take a optimizer in case
                there are several of them and they are in a dictionary format.
            loss_key (str): key to get loss from ``state.loss``
            decouple_weight_decay (bool): If True - decouple weight decay
                regularization, i.e. disable the optimizer's built-in weight
                decay for the epoch and apply it directly to the parameters
                in :meth:`grad_step`.

        Raises:
            ValueError: if ``accumulation_steps`` is less than 1.
        """
        super().__init__(CallbackOrder.Optimizer)
        if accumulation_steps < 1:
            raise ValueError(
                f"accumulation_steps must be >= 1, got {accumulation_steps}"
            )
        grad_clip_params: dict = grad_clip_params or {}
        self.grad_clip_fn = GRAD_CLIPPERS.get_from_params(**grad_clip_params)

        self.accumulation_steps: int = accumulation_steps
        self.optimizer_key: str = optimizer_key
        self.loss_key: str = loss_key
        self.decouple_weight_decay = decouple_weight_decay

        # Per-param-group weight decay values captured in `on_epoch_start`.
        self._optimizer_wd: List[float] = [0.0]
        # Number of backward passes since the last optimizer step.
        self._accumulation_counter: int = 0

    @staticmethod
    def grad_step(
        *,
        optimizer: _Optimizer,
        optimizer_wds: List[float] = 0,
        grad_clip_fn: Callable = None
    ):
        """
        Apply weight decay and gradient clipping, then step the optimizer.

        Args:
            optimizer: optimizer whose ``param_groups`` are updated
            optimizer_wds: weight decay per param group, in
                ``optimizer.param_groups`` order. A single number is applied
                to every group. ``None`` means no weight decay. Groups
                beyond the end of the list get no weight decay, but they
                are still clipped and stepped.
            grad_clip_fn: optional callable applied to each group's
                ``params`` before the step
        """
        param_groups = optimizer.param_groups
        if optimizer_wds is None:
            optimizer_wds = []
        elif isinstance(optimizer_wds, (int, float)):
            # The historical default is the scalar ``0``; broadcast scalars
            # instead of failing on ``zip`` over a non-iterable.
            optimizer_wds = [optimizer_wds] * len(param_groups)

        for index, group in enumerate(param_groups):
            wd = optimizer_wds[index] if index < len(optimizer_wds) else 0.0
            if wd > 0:
                lr = group["lr"]
                for param in group["params"]:
                    # Decoupled decay: p <- p - wd * lr * p
                    param.data = param.data.add(param.data, alpha=-wd * lr)
            if grad_clip_fn is not None:
                grad_clip_fn(group["params"])
        optimizer.step()

    def on_stage_start(self, state: RunnerState):
        """
        Publish the optimizer's initial learning rate and momentum to state.

        ``lr`` is read from ``optimizer.defaults``. ``momentum`` is whatever
        ``get_optimizer_momentum`` returns for this optimizer.
        """
        optimizer = state.get_key(
            key="optimizer", inner_key=self.optimizer_key
        )
        assert optimizer is not None
        lr = optimizer.defaults["lr"]
        momentum = get_optimizer_momentum(optimizer)
        state.set_key(lr, "lr", inner_key=self.optimizer_key)
        state.set_key(momentum, "momentum", inner_key=self.optimizer_key)

    def on_epoch_start(self, state):
        """
        Capture per-group weight decay for this epoch.

        With ``decouple_weight_decay`` the groups' ``weight_decay`` values
        are remembered and set to 0.0 in the optimizer. This way decay is
        applied only by :meth:`grad_step`; :meth:`on_epoch_end` restores the
        values. Otherwise no decoupled decay is applied.
        """
        optimizer = state.get_key(
            key="optimizer", inner_key=self.optimizer_key
        )
        if self.decouple_weight_decay:
            self._optimizer_wd = [
                group.get("weight_decay", 0.0)
                for group in optimizer.param_groups
            ]
            for group in optimizer.param_groups:
                group["weight_decay"] = 0.0
        else:
            self._optimizer_wd = [0.0] * len(optimizer.param_groups)

    def _get_loss(self, state) -> torch.Tensor:
        """
        Fetch the loss to back-propagate from ``state``.

        Raises:
            ValueError: if the loss is a list, or a dict (i.e. losses were
                not aggregated into a single value).
        """
        loss = state.get_key(key="loss", inner_key=self.loss_key)

        if isinstance(loss, list):
            raise ValueError(
                "Loss is a list. "
                "Only the last value will be used for `backward`. "
                "To aggregate losses into "
                "one value use `CriterionAggregatorCallback`"
            )
        if isinstance(loss, dict):
            error = f"Loss is a dict: {list(loss.keys())}, " \
                "to aggregate losses into " \
                "one value use `CriterionAggregatorCallback`."
            if self.loss_key is None:
                error = error + " Or try to pass `loss_key` " \
                    "in the OptimizerCallback init"
            raise ValueError(error)
        return loss

    def on_batch_start(self, state):
        """Reset the loss stored on the state before the batch runs."""
        state.loss = None

    def on_batch_end(self, state):
        """
        Back-propagate the batch loss and step every N-th batch.

        Does nothing when ``state.need_backward`` is falsy. The loss is not
        divided by ``accumulation_steps``, so accumulated gradients are the
        sum over the accumulated batches.
        """
        if not state.need_backward:
            return

        loss = self._get_loss(state)
        self._accumulation_counter += 1
        model = state.model
        optimizer = state.get_key(
            key="optimizer", inner_key=self.optimizer_key
        )

        # This is very hacky check whether we have AMP optimizer and this may
        # change in future.
        # But alternative solution is to have AmpOptimizerCallback.
        # or expose another c'tor argument.
        if hasattr(optimizer, "_amp_stash"):
            from apex import amp
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

        # The counter was already incremented for this batch and is reset
        # after every step, so this fires exactly once per
        # `accumulation_steps` backward passes.
        if self._accumulation_counter % self.accumulation_steps == 0:
            self.grad_step(
                optimizer=optimizer,
                optimizer_wds=self._optimizer_wd,
                grad_clip_fn=self.grad_clip_fn
            )
            maybe_recursive_call(model, "zero_grad")
            self._accumulation_counter = 0

    def on_epoch_end(self, state):
        """Restore the weight decay values captured in ``on_epoch_start``."""
        if self.decouple_weight_decay:
            optimizer = state.get_key(
                key="optimizer", inner_key=self.optimizer_key
            )
            for group, wd in zip(optimizer.param_groups, self._optimizer_wd):
                group["weight_decay"] = wd
# Public API of this module.
__all__ = ["OptimizerCallback"]