Source code for catalyst.callbacks.optimizer
from typing import Callable, Dict, List, TYPE_CHECKING
import logging
import warnings
import torch
from catalyst import registry
from catalyst.core.callback import Callback, CallbackNode, CallbackOrder
from catalyst.typing import Optimizer
from catalyst.utils.misc import get_attr, maybe_recursive_call
from catalyst.utils.torch import get_optimizer_momentum
if TYPE_CHECKING:
from catalyst.core.runner import IRunner
logger = logging.getLogger(__name__)
try:
import torch_xla.core.xla_model as xm
except ModuleNotFoundError:
pass
def zero_grad(optimizer: Optimizer) -> None:
"""Perform an hacky way to zero gradients.
Args:
optimizer: optimizer with model parameters.
"""
for group in optimizer.param_groups:
for p in group["params"]:
p.grad = None
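# Note: assigning ``p.grad = None`` skips the in-place zero-fill performed by
# ``optimizer.zero_grad()``; the next ``backward()`` then allocates fresh
# gradient buffers instead of accumulating into zeroed ones. Recent PyTorch
# versions expose the same idea as ``optimizer.zero_grad(set_to_none=True)``.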
class IOptimizerCallback(Callback):
"""Optimizer callback interface, abstraction over optimizer step."""
pass
class OptimizerCallback(IOptimizerCallback):
"""Optimizer callback, abstraction over optimizer step."""
def __init__(
self,
metric_key: str = None,
optimizer_key: str = None,
accumulation_steps: int = 1,
grad_clip_params: Dict = None,
decouple_weight_decay: bool = True,
loss_key: str = None,
use_fast_zero_grad: bool = False,
xla_barrier: bool = True,
):
"""
Args:
metric_key: key to get the metric (loss) from ``runner.batch_metrics``
loss_key: deprecated, use ``metric_key`` instead
optimizer_key: a key to take an optimizer in case
there are several of them and they are in a dictionary format.
accumulation_steps: number of steps before
``model.zero_grad()``
grad_clip_params: params for gradient clipping
decouple_weight_decay: if ``True``, decouple weight decay
regularization.
use_fast_zero_grad: boost ``optimizer.zero_grad()``,
default is ``False``.
xla_barrier: barrier option for XLA. You can find more
about the usage of the `barrier flag
<https://pytorch.org/xla/release/1.5/index.html?
highlight=optimizer_step#torch_xla.core.xla_model.optimizer_step>`_
and `examples
<https://pytorch.org/xla/release/1.5/index.html#
running-on-a-single-xla-device>`_.
Default is ``True``.
"""
super().__init__(order=CallbackOrder.optimizer, node=CallbackNode.all)
assert metric_key is None or loss_key is None
if loss_key is not None:
warnings.warn(
"OptimizerCallback: "
"`loss_key` is now deprecated in favor `metric_key`",
stacklevel=2,
)
self.metric_key: str = metric_key or loss_key or "loss"
self.optimizer_key: str = optimizer_key
self.accumulation_steps: int = accumulation_steps
self._accumulation_counter: int = 0
grad_clip_params: dict = grad_clip_params or {}
self.grad_clip_fn = registry.GRAD_CLIPPER.get_from_params(
**grad_clip_params
)
self.decouple_weight_decay = decouple_weight_decay
self._optimizer_wd: List[float] = [0.0]
self._optimizer_step_fn: Callable = None
self.is_xla = False
self.use_fast_zero_grad = use_fast_zero_grad
self.use_xla_barrier = xla_barrier
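# A minimal usage sketch (illustrative, not taken from the library docs),
# assuming the usual ``catalyst.dl`` training setup: accumulate gradients
# over 4 batches before each optimizer step.
#
#   from catalyst import dl
#
#   runner = dl.SupervisedRunner()
#   runner.train(
#       model=model,
#       criterion=criterion,
#       optimizer=optimizer,
#       loaders=loaders,
#       callbacks=[OptimizerCallback(metric_key="loss", accumulation_steps=4)],
#   )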
def _optimizer_step(self, optimizer: Optimizer) -> None:
"""CPU and GPU optimization step.
Args:
optimizer: optimizer object
"""
optimizer.step()
def _optimizer_step_tpu(self, optimizer: Optimizer) -> None:
"""TPU optimization step.
Args:
optimizer: optimizer object
"""
if self.use_xla_barrier:
xm.optimizer_step(optimizer, barrier=True)
else:
xm.optimizer_step(optimizer)
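# ``xm.optimizer_step`` reduces the gradients across TPU cores before calling
# ``optimizer.step()``; ``barrier=True`` additionally issues the XLA step
# barrier, which is needed when the loop is not driven by a parallel loader
# (see the links in the ``xla_barrier`` docstring above).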
def grad_step(
self,
*,
optimizer: Optimizer,
optimizer_wds: List[float] = None,
grad_clip_fn: Callable = None,
) -> None:
"""Makes a gradient step for a given optimizer.
Args:
optimizer: the optimizer
optimizer_wds: list of weight decay parameters
for each param group
grad_clip_fn: function for gradient clipping
"""
# fall back to zero weight decay for every param group
optimizer_wds = optimizer_wds or [0.0] * len(optimizer.param_groups)
for group, wd in zip(optimizer.param_groups, optimizer_wds):
if wd > 0:
for param in group["params"]:
param.data = param.data.add(
other=param.data, alpha=-wd * group["lr"]
)
if grad_clip_fn is not None:
grad_clip_fn(group["params"])
# optimize parameters
self._optimizer_step_fn(optimizer)
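# The loop above applies decoupled weight decay directly to the weights,
# i.e. ``param <- param - lr * wd * param``, instead of folding the decay
# into the gradients (AdamW-style decoupling); clipping and the actual
# optimizer step therefore see gradients that are free of the decay term.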
def on_stage_start(self, runner: "IRunner") -> None:
"""Checks that the current stage has the correct optimizer.
Args:
runner: current runner
"""
self._optimizer = get_attr(
runner, key="optimizer", inner_key=self.optimizer_key
)
# device based optimization step
if runner.device.type == "xla":
self._optimizer_step_fn = self._optimizer_step_tpu
else:
self._optimizer_step_fn = self._optimizer_step
assert self._optimizer is not None
def on_epoch_start(self, runner: "IRunner") -> None:
"""On epoch start event.
Args:
runner: current runner
"""
if self.decouple_weight_decay:
self._optimizer_wd = [
group.get("weight_decay", 0.0)
for group in self._optimizer.param_groups
]
for i in range(len(self._optimizer.param_groups)):
self._optimizer.param_groups[i]["weight_decay"] = 0.0
else:
self._optimizer_wd = [0.0] * len(self._optimizer.param_groups)
def on_batch_end(self, runner: "IRunner") -> None:
"""On batch end event.
Args:
runner: current runner
"""
if not runner.is_train_loader:
return
loss = runner.batch_metrics[self.metric_key]
self._accumulation_counter += 1
need_gradient_step = (
self._accumulation_counter % self.accumulation_steps == 0
)
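# With ``accumulation_steps > 1`` gradients from several consecutive batches
# are summed before ``grad_step`` is called, emulating a larger batch at the
# memory cost of a single one; note that the loss is not rescaled here, so
# the accumulated gradient is a sum over the batches rather than a mean.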
# This is a very hacky check for whether we have an AMP optimizer,
# and it may change in the future.
# An alternative solution is to have an AmpOptimizerCallback
# or to expose another constructor argument.
# @TODO: speedup with re-definition ``on_stage_start``
if hasattr(self._optimizer, "_amp_stash"):
from apex import amp
# Need to set ``delay_unscale``
# according to
# https://nvidia.github.io/apex/advanced.html#gradient-accumulation-across-iterations
delay_unscale = not need_gradient_step
with amp.scale_loss(
loss, self._optimizer, delay_unscale=delay_unscale
) as scaled_loss:
scaled_loss.backward()
else:
loss.backward()
if need_gradient_step:
self.grad_step(
optimizer=self._optimizer,
optimizer_wds=self._optimizer_wd,
grad_clip_fn=self.grad_clip_fn,
)
if not self.use_fast_zero_grad:
maybe_recursive_call(self._optimizer, "zero_grad")
else:
maybe_recursive_call(self._optimizer, zero_grad)
self._accumulation_counter = 0
def on_epoch_end(self, runner: "IRunner") -> None:
"""On epoch end event.
Args:
runner: current runner
"""
if self.decouple_weight_decay:
for i, wd in enumerate(self._optimizer_wd):
self._optimizer.param_groups[i]["weight_decay"] = wd
lr = self._optimizer.param_groups[0]["lr"]
lr_name = (
f"lr/{self.optimizer_key}"
if self.optimizer_key is not None
else "lr"
)
runner.epoch_metrics[lr_name] = lr
momentum = get_optimizer_momentum(self._optimizer)
if momentum is not None:
momentum_name = (
f"momentum/{self.optimizer_key}"
if self.optimizer_key is not None
else "momentum"
)
runner.epoch_metrics[momentum_name] = momentum
class AMPOptimizerCallback(IOptimizerCallback):
"""
Optimizer callback with native torch amp support.
"""
def __init__(
self,
metric_key: str = None,
optimizer_key: str = None,
accumulation_steps: int = 1,
grad_clip_params: Dict = None,
loss_key: str = None,
):
"""
Args:
metric_key: key to get the metric (loss) from ``runner.batch_metrics``
loss_key: deprecated, use ``metric_key`` instead
optimizer_key: a key to take an optimizer in case
there are several of them and they are in a dictionary format.
accumulation_steps: number of steps before
``model.zero_grad()``
grad_clip_params: params for gradient clipping
"""
super().__init__(order=CallbackOrder.optimizer, node=CallbackNode.all)
assert metric_key is None or loss_key is None
if loss_key is not None:
warnings.warn(
"OptimizerCallback: "
"`loss_key` is now deprecated in favor `metric_key`",
stacklevel=2,
)
self.metric_key: str = metric_key or loss_key or "loss"
self.optimizer_key: str = optimizer_key
self.accumulation_steps: int = accumulation_steps
self._accumulation_counter: int = 0
grad_clip_params: dict = grad_clip_params or {}
self.grad_clip_fn = registry.GRAD_CLIPPER.get_from_params(
**grad_clip_params
)
# Initialized at on_stage_start()
self.scaler = None
def grad_step(
self, *, optimizer: Optimizer, grad_clip_fn: Callable = None,
) -> None:
"""Makes a gradient step for a given optimizer.
Args:
optimizer: the optimizer
grad_clip_fn: function for gradient clipping
"""
if grad_clip_fn is not None:
# Unscales the gradients of
# optimizer's assigned params in-place
self.scaler.unscale_(optimizer)
for group in optimizer.param_groups:
# Since the gradients of optimizer's
# assigned params are unscaled, clips as usual:
grad_clip_fn(group["params"])
self.scaler.step(optimizer)
self.scaler.update()
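# ``scaler.step`` unscales the gradients (unless ``unscale_`` was already
# called, as above when clipping) and skips ``optimizer.step()`` if infs or
# NaNs are found; ``scaler.update`` then adjusts the loss scale accordingly.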
def on_stage_start(self, runner: "IRunner") -> None:
"""Checks that the current stage has the correct optimizer.
Args:
runner: current runner
"""
from torch.cuda.amp import GradScaler
self._optimizer = get_attr(
runner, key="optimizer", inner_key=self.optimizer_key
)
self.scaler = GradScaler()
assert self._optimizer is not None
def on_batch_start(self, runner: "IRunner") -> None:
"""On batch start event.
Args:
runner: current runner
"""
self.prev_autocast_state = torch.is_autocast_enabled()
torch.set_autocast_enabled(True)
torch.autocast_increment_nesting()
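# The three calls above are the manual equivalent of entering a
# ``torch.cuda.amp.autocast()`` context, so the forward pass of the batch
# (which runs after this callback handler) executes under autocast.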
def on_batch_end(self, runner: "IRunner") -> None:
"""On batch end event.
Args:
runner: current runner
"""
# Drop the cache when we exit to a nesting level
# that's outside any instance of autocast.
if torch.autocast_decrement_nesting() == 0:
torch.clear_autocast_cache()
torch.set_autocast_enabled(self.prev_autocast_state)
if not runner.is_train_loader:
return
loss = runner.batch_metrics[self.metric_key]
self._accumulation_counter += 1
need_gradient_step = (
self._accumulation_counter % self.accumulation_steps == 0
)
self.scaler.scale(loss).backward()
if need_gradient_step:
self.grad_step(
optimizer=self._optimizer, grad_clip_fn=self.grad_clip_fn,
)
maybe_recursive_call(self._optimizer, "zero_grad")
self._accumulation_counter = 0
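# For reference, a sketch of the plain PyTorch AMP step this callback
# automates (assuming no gradient accumulation or clipping):
#
#   scaler = torch.cuda.amp.GradScaler()
#   with torch.cuda.amp.autocast():
#       loss = criterion(model(x), y)
#   scaler.scale(loss).backward()
#   scaler.step(optimizer)
#   scaler.update()
#   optimizer.zero_grad()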
def on_epoch_end(self, runner: "IRunner") -> None:
"""On epoch end event.
Args:
runner: current runner
"""
lr = self._optimizer.param_groups[0]["lr"]
lr_name = (
f"lr/{self.optimizer_key}"
if self.optimizer_key is not None
else "lr"
)
runner.epoch_metrics[lr_name] = lr
momentum = get_optimizer_momentum(self._optimizer)
if momentum is not None:
momentum_name = (
f"momentum/{self.optimizer_key}"
if self.optimizer_key is not None
else "momentum"
)
runner.epoch_metrics[momentum_name] = momentum
def on_stage_end(self, runner: "IRunner") -> None:
"""On stage end event.
Args:
runner: current runner
"""
self.scaler = None
# @TODO: add OptimizerCallback autocreation
# def OptimizerCallback(*args, **kwargs):
# """
# Optimizer callback factory-wrapper to select required OptimizerCallback
# automatically.
# """
# is_amp_enabled = (
# os.getenv("USE_AMP", "0") == "1" and utils.check_amp_available()
# )
#
# optimizer_callback = AMPOptimizerCallback(*args, **kwargs) \
# if is_amp_enabled \
# else OptimizerCallback(*args, **kwargs)
# return optimizer_callback
__all__ = [
"IOptimizerCallback",
"AMPOptimizerCallback",
"OptimizerCallback",
]