Source code for catalyst.callbacks.backward
from typing import Callable, Dict, TYPE_CHECKING, Union
from functools import partial

from catalyst.core.callback import IBackwardCallback
from catalyst.registry import REGISTRY

if TYPE_CHECKING:
    from catalyst.core.runner import IRunner


class BackwardCallback(IBackwardCallback):
    """Optimizer callback, abstraction over backward step.

    Args:
        metric_key: a key to get loss from ``runner.batch_metrics``
        grad_clip_fn: callable gradient clipping function or its name
        grad_clip_params: key-value parameters for ``grad_clip_fn``
        log_gradient: boolean flag to log gradient norm to ``runner.batch_metrics``

    .. note::
        Please follow the `minimal examples`_ sections for more use cases.

        .. _`minimal examples`: https://github.com/catalyst-team/catalyst#minimal-examples  # noqa: E501, W505
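
    A minimal usage sketch (illustrative only; the model, loader, and callback
    arguments below are assumptions, not part of this module):

    .. code-block:: python

        import torch
        from torch.utils.data import DataLoader, TensorDataset
        from catalyst import dl

        model = torch.nn.Linear(10, 1)
        loader = DataLoader(
            TensorDataset(torch.randn(64, 10), torch.randn(64, 1)), batch_size=8
        )

        runner = dl.SupervisedRunner(
            input_key="features", output_key="logits", target_key="targets", loss_key="loss"
        )
        runner.train(
            model=model,
            criterion=torch.nn.MSELoss(),
            optimizer=torch.optim.Adam(model.parameters()),
            loaders={"train": loader},
            num_epochs=1,
            callbacks=[
                dl.CriterionCallback(
                    metric_key="loss", input_key="logits", target_key="targets"
                ),
                # clip gradients by global norm and log ``gradient/loss/norm``
                dl.BackwardCallback(
                    metric_key="loss",
                    grad_clip_fn=torch.nn.utils.clip_grad_norm_,
                    grad_clip_params={"max_norm": 1.0},
                    log_gradient=True,
                ),
                dl.OptimizerCallback(metric_key="loss"),
            ],
        )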
"""

    def __init__(
        self,
        metric_key: str,
        grad_clip_fn: Union[str, Callable] = None,
        grad_clip_params: Dict = None,
        log_gradient: bool = False,
    ):
        """Init."""
        super().__init__()
        self.metric_key = metric_key
        # resolve the clipping function: a string is looked up in the registry,
        # a callable is used as-is
        if isinstance(grad_clip_fn, str):
            self.grad_clip_fn = REGISTRY.get(grad_clip_fn)
        else:
            self.grad_clip_fn = grad_clip_fn
        # bind any user-provided keyword arguments to the clipping function
        if grad_clip_params is not None:
            self.grad_clip_fn = partial(self.grad_clip_fn, **grad_clip_params)

        self._prefix_gradient = f"gradient/{metric_key}"
        self._log_gradient = log_gradient
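
    # Note: the two configurations below are intended to be equivalent
    # (assuming the clipping function is registered under that name in REGISTRY):
    #   BackwardCallback("loss", grad_clip_fn=torch.nn.utils.clip_grad_norm_,
    #                    grad_clip_params={"max_norm": 1.0})
    #   BackwardCallback("loss", grad_clip_fn="clip_grad_norm_",
    #                    grad_clip_params={"max_norm": 1.0})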

    def on_batch_end(self, runner: "IRunner"):
        """Event handler."""
        if runner.is_train_loader:
            loss = runner.batch_metrics[self.metric_key]
            runner.engine.backward(loss)

            if self.grad_clip_fn is not None:
                # unscale gradients (relevant for mixed-precision training) before clipping
                runner.engine.unscale_gradients()
                norm = self.grad_clip_fn(runner.model.parameters())
                if self._log_gradient:
                    runner.batch_metrics[f"{self._prefix_gradient}/norm"] = norm


__all__ = ["BackwardCallback"]