Source code for catalyst.callbacks.criterion
import torch
from catalyst.core.callback import Callback, CallbackNode, CallbackOrder
from catalyst.core.runner import IRunner
from catalyst.metrics._additive import AdditiveValueMetric
from catalyst.utils.misc import get_attr
class ICriterionCallback(Callback):
"""Criterion callback interface, abstraction over criterion step."""
pass
# @TODO: add KV support
[docs]class CriterionCallback(ICriterionCallback):
"""Criterion callback, abstraction over criterion step.
Args:
input_key:
target_key:
metric_key: prefix for metrics and output key for loss
in ``runner.batch_metrics`` dictionary
criterion_key: A key to take a criterion in case
there are several of them and they are in a dictionary format.
"""
def __init__(
self, input_key: str, target_key: str, metric_key: str, criterion_key: str = None,
):
"""Init."""
super().__init__(order=CallbackOrder.metric, node=CallbackNode.all)
self.input_key = input_key
self.target_key = target_key
self.metric_key = metric_key
self.criterion_key = criterion_key
self.additive_metric = AdditiveValueMetric()
self.criterion = None
def on_stage_start(self, runner: "IRunner"):
"""Checks that the current stage has correct criterion.
Args:
runner: current runner
"""
self.criterion = get_attr(runner, key="criterion", inner_key=self.criterion_key)
assert self.criterion is not None
def on_loader_start(self, runner: "IRunner") -> None:
"""Event handler."""
self.additive_metric.reset()
def on_batch_end(self, runner: "IRunner"):
"""Event handler."""
inputs, targets = runner.batch[self.input_key], runner.batch[self.target_key]
# NOTE: similar to amp guides in docs
# https://pytorch.org/docs/stable/notes/amp_examples.html
# with runner.engine.autocast():
loss = self.criterion(inputs, targets)
runner.batch_metrics.update({self.metric_key: loss})
self.additive_metric.update(loss.detach().cpu(), len(targets))
def on_loader_end(self, runner: "IRunner") -> None:
"""Event handler."""
mean, std = self.additive_metric.compute()
metrics = {self.metric_key: mean, f"{self.metric_key}/std": std}
metrics = {
k: runner.engine.sync_tensor(torch.tensor(v, device=runner.device), "mean")
for k, v in metrics.items()
}
runner.loader_metrics.update(metrics)
__all__ = ["ICriterionCallback", "CriterionCallback"]