Source code for catalyst.core.callbacks.optimizer
from typing import Callable, Dict, List
import logging
import warnings

from catalyst.core import (
    Callback,
    CallbackNode,
    CallbackOrder,
    registry,
    State,
    utils,
)
from catalyst.tools.typing import Optimizer

logger = logging.getLogger(__name__)


class OptimizerCallback(Callback):
    """Optimizer callback, abstraction over optimizer step."""

    def __init__(
        self,
        metric_key: str = None,
        optimizer_key: str = None,
        accumulation_steps: int = 1,
        grad_clip_params: Dict = None,
        decouple_weight_decay: bool = True,
        loss_key: str = None,
    ):
"""
Args:
loss_key (str): key to get loss from ``state.batch_metrics``
optimizer_key (str): A key to take a optimizer in case
there are several of them and they are in a dictionary format.
accumulation_steps (int): number of steps before
``model.zero_grad()``
grad_clip_params (dict): params for gradient clipping
decouple_weight_decay (bool): If True - decouple weight decay
regularization.
"""
        super().__init__(order=CallbackOrder.Optimizer, node=CallbackNode.All)
        assert metric_key is None or loss_key is None
        if loss_key is not None:
            warnings.warn(
                "OptimizerCallback: "
                "`loss_key` is now deprecated in favor of `metric_key`",
                stacklevel=2,
            )
        self.metric_key: str = metric_key or loss_key or "loss"
        self.optimizer_key: str = optimizer_key
        self.accumulation_steps: int = accumulation_steps
        self._accumulation_counter: int = 0

        grad_clip_params: dict = grad_clip_params or {}
        self.grad_clip_fn = registry.GRAD_CLIPPERS.get_from_params(
            **grad_clip_params
        )
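        # Note (assumption, for illustration): ``registry.GRAD_CLIPPERS`` is
        # expected to expose the gradient-clipping helpers from
        # ``torch.nn.utils``, so ``grad_clip_params`` names one of them plus
        # its kwargs, e.g. something like
        # ``{"func": "clip_grad_norm_", "max_norm": 1.0}``; the exact key used
        # to select the function depends on how the registry is declared.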
        self.decouple_weight_decay = decouple_weight_decay
        self._optimizer_wd: List[float] = [0.0]

    @staticmethod
    def grad_step(
        *,
        optimizer: Optimizer,
        optimizer_wds: List[float] = None,
        grad_clip_fn: Callable = None,
    ) -> None:
"""Makes a gradient step for a given optimizer.
Args:
optimizer (Optimizer): the optimizer
optimizer_wds (List[float]): list of weight decay parameters
for each param group
grad_clip_fn (Callable): function for gradient clipping
"""
        optimizer_wds = optimizer_wds or [0.0] * len(optimizer.param_groups)
        for group, wd in zip(optimizer.param_groups, optimizer_wds):
            if wd > 0:
                for param in group["params"]:
                    param.data = param.data.add(
                        param.data, alpha=-wd * group["lr"]
                    )
            if grad_clip_fn is not None:
                grad_clip_fn(group["params"])
        optimizer.step()

    def on_stage_start(self, state: State) -> None:
        """Checks that the current stage has a correct optimizer."""
        self._optimizer = state.get_attr(
            key="optimizer", inner_key=self.optimizer_key
        )
        assert self._optimizer is not None

    def on_epoch_start(self, state: State) -> None:
        """On epoch start event.

        Args:
            state (State): current state
        """
        if self.decouple_weight_decay:
            self._optimizer_wd = [
                group.get("weight_decay", 0.0)
                for group in self._optimizer.param_groups
            ]
            for i in range(len(self._optimizer.param_groups)):
                self._optimizer.param_groups[i]["weight_decay"] = 0.0
        else:
            self._optimizer_wd = [0.0] * len(self._optimizer.param_groups)

    def on_epoch_end(self, state: State) -> None:
        """On epoch end event.

        Args:
            state (State): current state
        """
        if self.decouple_weight_decay:
            for i, wd in enumerate(self._optimizer_wd):
                self._optimizer.param_groups[i]["weight_decay"] = wd

        lr = self._optimizer.param_groups[0]["lr"]
        lr_name = (
            f"lr/{self.optimizer_key}"
            if self.optimizer_key is not None
            else "lr"
        )
        state.epoch_metrics[lr_name] = lr

        momentum = utils.get_optimizer_momentum(self._optimizer)
        if momentum is not None:
            momentum_name = (
                f"momentum/{self.optimizer_key}"
                if self.optimizer_key is not None
                else "momentum"
            )
            state.epoch_metrics[momentum_name] = momentum

    def on_batch_end(self, state: State) -> None:
        """On batch end event.

        Args:
            state (State): current state
        """
        if not state.is_train_loader:
            return

        loss = state.batch_metrics[self.metric_key]

        self._accumulation_counter += 1
        need_gradient_step = (
            self._accumulation_counter % self.accumulation_steps == 0
        )
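        # Gradient accumulation: ``backward()`` runs on every training batch,
        # while the actual optimizer step (and ``zero_grad``) happens only
        # every ``accumulation_steps`` batches, emulating a larger effective
        # batch size.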
        # This is a very hacky check for whether we have an AMP optimizer,
        # and it may change in the future.
        # The alternative solution is to have an AmpOptimizerCallback
        # or to expose another constructor argument.
        if hasattr(self._optimizer, "_amp_stash"):
            from apex import amp

            # ``delay_unscale`` has to be set according to
            # https://nvidia.github.io/apex/advanced.html#gradient-accumulation-across-iterations
            delay_unscale = not need_gradient_step
            with amp.scale_loss(
                loss, self._optimizer, delay_unscale=delay_unscale
            ) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

        if need_gradient_step:
            self.grad_step(
                optimizer=self._optimizer,
                optimizer_wds=self._optimizer_wd,
                grad_clip_fn=self.grad_clip_fn,
            )
            utils.maybe_recursive_call(self._optimizer, "zero_grad")
            self._accumulation_counter = 0


__all__ = ["OptimizerCallback"]
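

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustration only, not part of the original module).
# It assumes catalyst's ``SupervisedRunner`` training loop (catalyst ~20.x)
# and uses random data; with ``batch_size=8`` and ``accumulation_steps=4``
# the effective batch size is 32.
if __name__ == "__main__":
    import torch
    from torch.utils.data import DataLoader, TensorDataset

    from catalyst.dl import SupervisedRunner

    model = torch.nn.Linear(16, 2)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    criterion = torch.nn.CrossEntropyLoss()
    loaders = {
        "train": DataLoader(
            TensorDataset(torch.randn(64, 16), torch.randint(0, 2, (64,))),
            batch_size=8,
        )
    }

    runner = SupervisedRunner()
    runner.train(
        model=model,
        criterion=criterion,
        optimizer=optimizer,
        loaders=loaders,
        # configure gradient accumulation via an explicit OptimizerCallback
        callbacks=[OptimizerCallback(accumulation_steps=4)],
        num_epochs=1,
        verbose=False,
    )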