Source code for catalyst.callbacks.optimizer
from typing import Callable, Dict, List, TYPE_CHECKING
import logging
import warnings
import torch
from catalyst import registry
from catalyst.core.callback import Callback, CallbackNode, CallbackOrder
from catalyst.typing import Optimizer
from catalyst.utils.misc import get_attr, maybe_recursive_call
from catalyst.utils.torch import get_optimizer_momentum
if TYPE_CHECKING:
from catalyst.core.runner import IRunner
logger = logging.getLogger(__name__)
try:
import torch_xla.core.xla_model as xm
except ModuleNotFoundError:
pass
def zero_grad(optimizer: Optimizer) -> None:
"""Perform an hacky way to zero gradients.
Args:
optimizer: optimizer with model parameters.
"""
for group in optimizer.param_groups:
for p in group["params"]:
p.grad = None
[docs]class IOptimizerCallback(Callback):
"""Optimizer callback interface, abstraction over optimizer step."""
pass
[docs]class OptimizerCallback(IOptimizerCallback):
"""Optimizer callback, abstraction over optimizer step."""
[docs] def __init__(
self,
metric_key: str = None,
optimizer_key: str = None,
accumulation_steps: int = 1,
grad_clip_params: Dict = None,
decouple_weight_decay: bool = False,
loss_key: str = None,
use_fast_zero_grad: bool = False,
xla_barrier: bool = True,
use_amp: bool = None,
use_apex: bool = None,
):
"""
Args:
loss_key: key to get loss from ``runner.batch_metrics``
optimizer_key: A key to take a optimizer in case
there are several of them and they are in a dictionary format.
accumulation_steps: number of steps before ``model.zero_grad()``
grad_clip_params: params for gradient clipping, example:
``{'func': 'clip_grad_norm_', 'max_norm': 1, norm_type': 2}``
decouple_weight_decay: If ``True`` - decouple weight decay
regularization, default is ``False``.
use_fast_zero_grad: boost ``optimizer.zero_grad()``,
default is ``False``.
xla_barrier: barrier option for xla. Here you can find
more about usage of `barrier flag
<https://pytorch.org/xla/release/1.5/index.html?
highlight=optimizer_step#torch_xla.core.xla_model.optimizer_step>`_
and `examples
<https://pytorch.org/xla/release/1.5/index.html#
running-on-a-single-xla-device>`_.
Default is ``True``.
use_amp: whether to use native pytorch AMP, if None will be set
based on runner.experiment.distributed_params on stage start
use_apex: whether to use apex, if None will be set
based on runner.experiment.distributed_params on stage start
"""
super().__init__(order=CallbackOrder.optimizer, node=CallbackNode.all)
assert metric_key is None or loss_key is None
if loss_key is not None:
warnings.warn(
"OptimizerCallback: "
"`loss_key` is now deprecated in favor `metric_key`",
stacklevel=2,
)
self.metric_key: str = metric_key or loss_key or "loss"
self.optimizer_key: str = optimizer_key
self.accumulation_steps: int = accumulation_steps
self._accumulation_counter: int = 0
if use_apex and use_amp:
raise ValueError(
"OptimizerCallback: ``use_amp==True`` and ``use_apex==True`` "
"You must choose only one mixed precision backend"
)
self.use_amp = use_amp
self.use_apex = use_apex
# If use_amp==True scaler is initialized at on_stage_start()
self.scaler = None
grad_clip_params: dict = grad_clip_params or {}
self.grad_clip_fn = registry.GRAD_CLIPPER.get_from_params(
**grad_clip_params
)
self.decouple_weight_decay = decouple_weight_decay
self._optimizer_wds: List = None
self._optimizer_step_fn: Callable = None
self.is_xla = False
self.use_fast_zero_grad = use_fast_zero_grad
self.use_xla_barrier = xla_barrier
def _optimizer_step(self) -> None:
"""CPU and GPU optimization step.
"""
self._optimizer.step()
def _optimizer_step_amp(self) -> None:
"""Optimization step with pytorch native amp
"""
self.scaler.step(self._optimizer)
self.scaler.update()
def _optimizer_step_tpu(self) -> None:
"""TPU optimization step.
"""
if self.use_xla_barrier:
xm.optimizer_step(self._optimizer, barrier=True)
else:
xm.optimizer_step(self._optimizer)
[docs] def grad_step(
self, *, optimizer: Optimizer, grad_clip_fn: Callable = None
) -> None:
"""Makes a gradient step for a given optimizer.
Args:
optimizer: the optimizer
grad_clip_fn: function for gradient clipping
"""
if self.decouple_weight_decay:
for group, wd in zip(optimizer.param_groups, self._optimizer_wds):
for param in group["params"]:
param.data = param.data.add(
other=param.data, alpha=-wd * group["lr"]
)
if grad_clip_fn is not None:
if self.use_amp:
self.scaler.unscale_(optimizer)
for group in optimizer.param_groups:
grad_clip_fn(group["params"])
# optimize parameters
self._optimizer_step_fn()
[docs] def on_stage_start(self, runner: "IRunner") -> None:
"""Resolve amp/apex settings, prepare optimizer and scaler
Args:
runner(IRunner): current runner
"""
if self.use_amp is None and runner.experiment is not None:
self.use_amp = runner.experiment.distributed_params.get(
"amp", False
)
else:
self.use_amp = False
if self.use_apex is None and runner.experiment is not None:
self.use_apex = runner.experiment.distributed_params.get(
"apex", False
)
else:
self.use_apex = False
self._optimizer = get_attr(
runner, key="optimizer", inner_key=self.optimizer_key
)
# device based optimization step
if runner.device.type == "xla":
self._optimizer_step_fn = self._optimizer_step_tpu
elif self.use_amp:
self._optimizer_step_fn = self._optimizer_step_amp
else:
self._optimizer_step_fn = self._optimizer_step
if hasattr(self._optimizer, "_amp_stash") and not self.use_apex:
warnings.warn(
"`_amp_stash` is found in `self._optimizer`:, "
"but `use_apex` is False",
stacklevel=2,
)
assert self._optimizer is not None
if self.use_amp:
from torch.cuda.amp import GradScaler
self.scaler = GradScaler()
[docs] def on_epoch_start(self, runner: "IRunner") -> None:
"""On epoch start event.
Args:
runner: current runner
"""
if self.decouple_weight_decay:
self._optimizer_wds = [
group.get("weight_decay", 0.0)
for group in self._optimizer.param_groups
]
for i in range(len(self._optimizer.param_groups)):
self._optimizer.param_groups[i]["weight_decay"] = 0.0
[docs] def on_batch_start(self, runner: "IRunner") -> None:
"""On batch start event
Args:
runner: current runner
"""
if self.use_amp:
self.prev_autocast_state = torch.is_autocast_enabled()
torch.set_autocast_enabled(True)
torch.autocast_increment_nesting()
[docs] def on_batch_end(self, runner: "IRunner") -> None:
"""On batch end event
Args:
runner: current runner
"""
if self.use_amp:
# Drop the cache when we exit to a nesting level
# that's outside any instance of autocast.
if torch.autocast_decrement_nesting() == 0:
torch.clear_autocast_cache()
torch.set_autocast_enabled(self.prev_autocast_state)
if not runner.is_train_loader:
return
loss = runner.batch_metrics[self.metric_key]
self._accumulation_counter += 1
need_gradient_step = (
self._accumulation_counter % self.accumulation_steps == 0
)
# @TODO: speedup with re-definition ``on_stage_start``
if self.use_apex:
from apex import amp
# Need to set ``delay_unscale``
# according to
# https://nvidia.github.io/apex/advanced.html#gradient-accumulation-across-iterations
delay_unscale = not need_gradient_step
with amp.scale_loss(
loss, self._optimizer, delay_unscale=delay_unscale
) as scaled_loss:
scaled_loss.backward()
elif self.use_amp:
self.scaler.scale(loss).backward()
else:
loss.backward()
if need_gradient_step:
self.grad_step(
optimizer=self._optimizer, grad_clip_fn=self.grad_clip_fn,
)
if not self.use_fast_zero_grad:
maybe_recursive_call(self._optimizer, "zero_grad")
else:
maybe_recursive_call(self._optimizer, zero_grad)
self._accumulation_counter = 0
[docs] def on_epoch_end(self, runner: "IRunner") -> None:
"""On epoch end event.
Args:
runner: current runner
"""
if self.decouple_weight_decay:
for i, wd in enumerate(self._optimizer_wds):
self._optimizer.param_groups[i]["weight_decay"] = wd
lr = self._optimizer.param_groups[0]["lr"]
lr_name = (
f"lr/{self.optimizer_key}"
if self.optimizer_key is not None
else "lr"
)
runner.epoch_metrics[lr_name] = lr
momentum = get_optimizer_momentum(self._optimizer)
if momentum is not None:
momentum_name = (
f"momentum/{self.optimizer_key}"
if self.optimizer_key is not None
else "momentum"
)
runner.epoch_metrics[momentum_name] = momentum
[docs] def on_stage_end(self, runner: "IRunner") -> None:
"""On stage end event.
Args:
runner: current runner
"""
if self.use_amp:
self.scaler = None
__all__ = [
"IOptimizerCallback",
"OptimizerCallback",
]