Source code for catalyst.dl.core.callback

from typing import Callable, List, Optional  # isort:skip
from collections import defaultdict
from enum import IntEnum, IntFlag

import numpy as np

import torch

from catalyst.utils import get_activation_fn
from .state import RunnerState


class CallbackOrder(IntEnum):
    """
    Execution priorities for callbacks.

    Members compare as plain integers, so callbacks can be sorted by their
    ``order`` value (presumably ascending by the runner, which is not shown
    here — confirm).

    ``IntEnum`` is used instead of ``IntFlag`` because these values are
    ordered priorities, not combinable bit flags. With ``IntFlag``, ``0`` and
    the multi-bit values (20, 40, ...) are treated as aliases and are skipped
    when iterating the enum on Python 3.11+.
    """

    Unknown = -100
    Internal = 0
    Criterion = 20
    Optimizer = 40
    Scheduler = 60
    Metric = 80
    External = 100
    Other = 200
class Callback:
    """
    Base class for every callback (loggers included).

    Subclass it and override only the hooks you need; every hook here is a
    no-op. The hooks are triggered in the following sequence:

    .. code:: bash

        -- stage start
        ---- epoch start (one epoch - one run of every loader)
        ------ loader start
        -------- batch start
        -------- batch handler
        -------- batch end
        ------ loader end
        ---- epoch end
        -- stage end

        exception – if an Exception was raised

    Every callback carries an ``order`` value taken from ``CallbackOrder``.
    """

    def __init__(self, order: int):
        """
        Args:
            order (int): execution priority; see ``CallbackOrder``
        """
        self.order = order

    def on_stage_start(self, state: RunnerState):
        """Hook for the start of a stage; does nothing by default."""

    def on_stage_end(self, state: RunnerState):
        """Hook for the end of a stage; does nothing by default."""

    def on_epoch_start(self, state: RunnerState):
        """Hook for the start of an epoch; does nothing by default."""

    def on_epoch_end(self, state: RunnerState):
        """Hook for the end of an epoch; does nothing by default."""

    def on_loader_start(self, state: RunnerState):
        """Hook for the start of a loader run; does nothing by default."""

    def on_loader_end(self, state: RunnerState):
        """Hook for the end of a loader run; does nothing by default."""

    def on_batch_start(self, state: RunnerState):
        """Hook for the start of a batch; does nothing by default."""

    def on_batch_end(self, state: RunnerState):
        """Hook for the end of a batch; does nothing by default."""

    def on_exception(self, state: RunnerState):
        """Hook for when an exception was raised; does nothing by default."""
class MetricCallback(Callback):
    """
    A callback that computes a single metric in ``on_batch_end`` and stores
    it as a batch value under ``prefix``.
    """

    def __init__(
        self,
        prefix: str,
        metric_fn: Callable,
        input_key: str = "targets",
        output_key: str = "logits",
        **metric_params
    ):
        """
        Args:
            prefix (str): name the batch metric is stored under
            metric_fn (Callable): called as
                ``metric_fn(outputs, targets, **metric_params)``
            input_key (str): key in ``state.input`` holding the targets
            output_key (str): key in ``state.output`` holding the outputs
            **metric_params: extra keyword arguments forwarded to
                ``metric_fn``
        """
        super().__init__(CallbackOrder.Metric)
        self.metric_fn = metric_fn
        self.prefix = prefix
        self.output_key = output_key
        self.input_key = input_key
        self.metric_params = metric_params

    def on_batch_end(self, state: RunnerState):
        """Evaluate ``metric_fn`` on the current batch and record it."""
        batch_metric = self.metric_fn(
            state.output[self.output_key],
            state.input[self.input_key],
            **self.metric_params
        )
        state.metrics.add_batch_value(name=self.prefix, value=batch_metric)
class MultiMetricCallback(Callback):
    """
    A callback that computes several metrics in ``on_batch_end``, one per
    entry of ``list_args``, and stores them as batch values.
    """

    def __init__(
        self,
        prefix: str,
        metric_fn: Callable,
        list_args: List,
        input_key: str = "targets",
        output_key: str = "logits",
        **metric_params
    ):
        """
        Args:
            prefix (str): common prefix of the stored metric names
            metric_fn (Callable): called as
                ``metric_fn(outputs, targets, list_args, **metric_params)``;
                its results are paired with ``list_args`` element by element
            list_args (List): per-metric arguments; each one also determines
                the suffix of the stored metric name
            input_key (str): key in ``state.input`` holding the targets
            output_key (str): key in ``state.output`` holding the outputs
            **metric_params: extra keyword arguments forwarded to
                ``metric_fn``
        """
        super().__init__(CallbackOrder.Metric)
        self.metric_fn = metric_fn
        self.prefix = prefix
        self.list_args = list_args
        self.output_key = output_key
        self.input_key = input_key
        self.metric_params = metric_params

    def on_batch_end(self, state: RunnerState):
        """Evaluate ``metric_fn`` on the current batch and record each value."""
        outputs = state.output[self.output_key]
        targets = state.input[self.input_key]
        values = self.metric_fn(
            outputs, targets, self.list_args, **self.metric_params
        )

        def metric_key(arg) -> str:
            # int args are zero-padded and appended directly ("top01");
            # anything else is joined with an underscore ("prefix_arg")
            if not isinstance(arg, int):
                return f"{self.prefix}_{arg}"
            return f"{self.prefix}{arg:02}"

        batch_metrics = {
            metric_key(arg): value
            for arg, value in zip(self.list_args, values)
        }
        state.metrics.add_batch_value(metrics_dict=batch_metrics)
class LoggerCallback(Callback):
    """
    Base class for loggers.

    Loggers are executed on ``start`` before all callbacks,
    and on ``end`` after all callbacks.
    """

    def __init__(self, order: int = None):
        """
        Args:
            order (int): execution priority; falls back to
                ``CallbackOrder.Internal`` when omitted (``None``)
        """
        super().__init__(
            order=CallbackOrder.Internal if order is None else order
        )
class MeterMetricsCallback(Callback):
    """
    A callback that tracks metrics through per-class meters.

    In ``on_loader_end`` it stores each class's values and the per-metric mean
    in the current loader's epoch metrics.

    This callback works for both single metric and multi-metric meters.
    """

    def __init__(
        self,
        metric_names: List[str],
        meter_list: List,
        input_key: str = "targets",
        output_key: str = "logits",
        class_names: Optional[List[str]] = None,
        num_classes: int = 2,
        activation: str = "Sigmoid",
    ):
        """
        Args:
            metric_names (List[str]): names of the metrics to log. Make sure
                they are in the same order in which the meters in
                ``meter_list`` output their values.
            meter_list (list-like): list of ``meters.meter.Meter`` instances,
                one per class (``len(meter_list) == num_classes``)
            input_key (str): key in ``state.input`` used for metric
                calculation; specifies our ``y_true``
            output_key (str): key in ``state.output`` used for metric
                calculation; specifies our ``y_pred``
            class_names (List[str]): class names to display in the logs.
                If None, defaults to indices for each class, starting from 0.
            num_classes (int): number of classes; must be > 1
            activation (str): a ``torch.nn`` activation applied to the
                logits. Must be one of ['none', 'Sigmoid', 'Softmax2d']
        """
        super().__init__(CallbackOrder.Metric)
        self.metric_names = metric_names
        self.meters = meter_list
        self.input_key = input_key
        self.output_key = output_key
        self.class_names = class_names
        self.num_classes = num_classes
        self.activation = activation
        # The activation depends only on ``activation``, so resolve it once
        # here instead of rebuilding it on every batch. An invalid name now
        # fails at construction rather than at the first batch.
        self._activation_fn = get_activation_fn(activation)

    def _reset_stats(self):
        """Reset every meter so that accumulation restarts from scratch."""
        for meter in self.meters:
            meter.reset()

    def on_loader_start(self, state: RunnerState):
        """Clear the meters before a loader starts."""
        self._reset_stats()

    def on_batch_end(self, state: RunnerState):
        """
        Feed the current batch into the per-class meters.

        For each ``i`` in ``range(num_classes)``, column ``i`` of the
        activated outputs and column ``i`` of the targets go to
        ``self.meters[i]``.
        """
        logits: torch.Tensor = state.output[self.output_key].detach().float()
        targets: torch.Tensor = state.input[self.input_key].detach().float()
        probabilities: torch.Tensor = self._activation_fn(logits)

        for i in range(self.num_classes):
            self.meters[i].add(probabilities[:, i], targets[:, i])

    def on_loader_end(self, state: RunnerState):
        """
        Store the per-class values and their per-metric mean in the epoch
        metrics of the current loader, then reset the meters.

        Keys written: ``"{metric_name}/class_{class_name_or_index}"`` and
        ``"{metric_name}/_mean"``.
        """
        metrics_tracker = defaultdict(list)
        loader_values = state.metrics.epoch_values[state.loader_name]

        # Computing metrics for each class
        for i, meter in enumerate(self.meters):
            metrics = meter.value()
            postfix = (
                self.class_names[i]
                if self.class_names is not None
                else str(i)
            )
            # zip() stops at the shorter sequence: a meter returning more
            # values than there are metric_names only contributes the
            # leading, named ones.
            for prefix, metric_ in zip(self.metric_names, metrics):
                # appending the per-class values
                metrics_tracker[prefix].append(metric_)
                loader_values[f"{prefix}/class_{postfix}"] = metric_

        # averaging the per-class values for each metric
        for prefix in self.metric_names:
            mean_value = float(np.mean(metrics_tracker[prefix]))
            loader_values[f"{prefix}/_mean"] = mean_value

        self._reset_stats()
# Public API of this module: the names exported by ``from ... import *``.
__all__ = [
    "CallbackOrder",
    "Callback",
    "MetricCallback",
    "MultiMetricCallback",
    "LoggerCallback",
    "MeterMetricsCallback",
]