
# Source code for catalyst.callbacks.metric

from typing import Dict, Iterable, Optional, Tuple, Union
from abc import ABC, abstractmethod

import torch

from catalyst.core.callback import Callback, CallbackNode, CallbackOrder
from catalyst.core.runner import IRunner
from catalyst.metrics._functional_metric import FunctionalBatchMetric
from catalyst.metrics._metric import ICallbackBatchMetric, ICallbackLoaderMetric, IMetric

class IMetricCallback(Callback, ABC):
    """Metric callback interface, abstraction over metric step.

    Concrete metric callbacks must implement the three loader/batch hooks below.
    """

    @abstractmethod
    def on_loader_start(self, runner: "IRunner") -> None:
        """On loader start action

        Args:
            runner: current runner
        """
        pass

    @abstractmethod
    def on_batch_end(self, runner: "IRunner") -> None:
        """On batch end action

        Args:
            runner: current runner
        """
        pass

    @abstractmethod
    def on_loader_end(self, runner: "IRunner") -> None:
        """On loader end action

        Args:
            runner: current runner
        """
        pass

class MetricCallback(IMetricCallback):
    """MetricCallback is a base implementation of callback that updates metrics over batch or loader.

    Args:
        metric: metric to calculate in callback
        input_key: keys of tensors that should be used as inputs in metric calculation
        target_key: keys of tensors that should be used as targets in metric calculation

    Raises:
        NotImplementedError: if ``input_key`` and ``target_key`` are not both ``str``
            or both one of ``dict``/``list``/``tuple``
    """

    def __init__(
        self,
        metric: Union[ICallbackBatchMetric, ICallbackLoaderMetric],
        input_key: Union[str, Iterable[str], Dict[str, str]],
        target_key: Union[str, Iterable[str], Dict[str, str]],
    ) -> None:
        """Init MetricCallback"""
        super().__init__(order=CallbackOrder.metric, node=CallbackNode.all)
        self.metric = metric
        assert isinstance(metric, IMetric)
        # Looked up at call time by the ``_update_*_metric`` helpers, so subclasses
        # may rebind it after ``super().__init__`` (e.g. to ``update_key_value``).
        self._metric_update_method = self.metric.update

        kv_types = (dict, list, tuple)

        is_value_input = isinstance(input_key, str)
        is_value_targets = isinstance(target_key, str)
        is_key_value_input = isinstance(input_key, kv_types)
        is_key_value_targets = isinstance(target_key, kv_types)

        # Pick the batch-extraction / metric-update strategy once, at construction:
        # plain string keys -> positional (input, target) call;
        # collections of keys -> keyword call built from ``self._keys``.
        if is_value_input and is_value_targets:
            self._get_inputs = self._get_value_inputs
            self._update_metric = self._update_value_metric
        elif is_key_value_input and is_key_value_targets:
            self._get_inputs = self._get_key_value_inputs
            self._update_metric = self._update_key_value_metric
        else:
            raise NotImplementedError(
                "input_key and target_key should both be str or both be dict/list/tuple, "
                f"got {type(input_key).__name__} and {type(target_key).__name__}"
            )

        self.input_key = input_key
        self.target_key = target_key
        # Mapping "batch field name" -> "metric argument name" for the key-value path.
        self._keys = {
            **self._convert_keys_to_kv(keys=input_key),
            **self._convert_keys_to_kv(keys=target_key),
        }

    @staticmethod
    def _convert_keys_to_kv(keys: Union[str, Iterable[str], Dict[str, str]]) -> Dict[str, str]:
        """Convert keys to key-value format

        Args:
            keys: keys to convert

        Returns:
            dict of keys like {"a": "b"} where "a" is a field name of field in batch,
                "b" is a name of the same data for metric
        """
        if isinstance(keys, dict):
            # copy so later mutation of the caller's dict does not leak in
            return dict(keys)
        if isinstance(keys, str):
            return {keys: keys}
        # any other iterable of names: each batch field maps to itself
        return {key: key for key in keys}

    def _get_value_inputs(self, runner: "IRunner") -> Tuple[torch.Tensor, torch.Tensor]:
        """Get data from batch in value input case

        Args:
            runner: current runner

        Returns:
            tuple of tensor of inputs and tensor of targets
        """
        return runner.batch[self.input_key], runner.batch[self.target_key]

    def _get_key_value_inputs(self, runner: "IRunner") -> Dict[str, torch.Tensor]:
        """Get data from batch in key-value input case

        Args:
            runner: current runner

        Returns:
            dict of inputs and targets tensors, keyed by metric argument name
        """
        return {
            metric_key: runner.batch[batch_key] for batch_key, metric_key in self._keys.items()
        }

    def _update_value_metric(
        self, value_inputs: Tuple[torch.Tensor, torch.Tensor]
    ) -> Optional[Dict[str, float]]:
        """Update metric in value input case

        Args:
            value_inputs: tuple of input tensor and target tensor

        Returns:
            result of metric update: None or metric values
        """
        return self._metric_update_method(*value_inputs)

    def _update_key_value_metric(
        self, kv_inputs: Dict[str, torch.Tensor]
    ) -> Optional[Dict[str, float]]:
        """Update metric in key-value input case

        Args:
            kv_inputs: input tensors in key-value format

        Returns:
            result of metric update: None or metric values
        """
        return self._metric_update_method(**kv_inputs)

class BatchMetricCallback(MetricCallback):
    """BatchMetricCallback implements batch-based metrics update and computation over loader

    Args:
        metric: metric to calculate in callback
        input_key: keys of tensors that should be used as inputs in metric calculation
        target_key: keys of tensors that should be used as targets in metric calculation
        log_on_batch: boolean flag to log computed metrics every batch
    """

    def __init__(
        self,
        metric: ICallbackBatchMetric,
        input_key: Union[str, Iterable[str], Dict[str, str]],
        target_key: Union[str, Iterable[str], Dict[str, str]],
        log_on_batch: bool = True,
    ) -> None:
        """Init BatchMetricCallback"""
        super().__init__(metric=metric, input_key=input_key, target_key=target_key)
        assert isinstance(metric, ICallbackBatchMetric)
        self.log_on_batch = log_on_batch
        # Use the key-value update so each batch update yields values that can be
        # merged into ``runner.batch_metrics`` in ``on_batch_end``.
        self._metric_update_method = self.metric.update_key_value

    def on_loader_start(self, runner: "IRunner") -> None:
        """On loader start action: reset metric values

        Args:
            runner: current runner
        """
        self.metric.reset()

    def on_batch_end(self, runner: "IRunner") -> None:
        """On batch end action: update metric with new batch data and log its value if necessary

        Args:
            runner: current runner
        """
        metrics_inputs = self._get_inputs(runner=runner)
        metrics = self._update_metric(metrics_inputs)
        if self.log_on_batch:
            runner.batch_metrics.update(metrics)

    def on_loader_end(self, runner: "IRunner") -> None:
        """On loader end action: compute metric values and update runner's loader metrics with it

        Args:
            runner: current runner
        """
        metrics = self.metric.compute_key_value()
        # aggregate the loader-level values through the runner's engine before logging
        metrics = runner.engine.sync_metrics(metrics)
        runner.loader_metrics.update(metrics)


class FunctionalBatchMetricCallback(BatchMetricCallback):
    """FunctionalBatchMetricCallback implements batch-based metrics update and computation
    over loader for ``FunctionalBatchMetric`` metrics.

    Args:
        metric: metric to calculate in callback
        input_key: keys of tensors that should be used as inputs in metric calculation
        target_key: keys of tensors that should be used as targets in metric calculation
        log_on_batch: boolean flag to log computed metrics every batch

    .. note::
        The main difference from BatchMetricCallback:
        FunctionalBatchMetricCallback also propagates current ``batch_size``
        to the FunctionalBatchMetric for correct metric computation.
    """

    def __init__(
        self,
        metric: FunctionalBatchMetric,
        input_key: Union[str, Iterable[str], Dict[str, str]],
        target_key: Union[str, Iterable[str], Dict[str, str]],
        log_on_batch: bool = True,
    ) -> None:
        """Init."""
        # validate before any state is initialised by the parent constructors
        assert isinstance(metric, FunctionalBatchMetric)
        super().__init__(
            metric=metric, input_key=input_key, target_key=target_key, log_on_batch=log_on_batch
        )

    def _get_value_inputs(self, runner: "IRunner") -> Tuple[float, torch.Tensor, torch.Tensor]:
        """Get data from batch in value input case

        Args:
            runner: current runner

        Returns:
            tuple of current batch size, tensor of inputs and tensor of targets
        """
        return runner.batch_size, runner.batch[self.input_key], runner.batch[self.target_key]

    def _get_key_value_inputs(self, runner: "IRunner") -> Dict[str, torch.Tensor]:
        """Get data from batch in key-value input case

        Args:
            runner: current runner

        Returns:
            dict of inputs and targets tensors plus the current ``batch_size``
        """
        kv_inputs = super()._get_key_value_inputs(runner=runner)
        kv_inputs["batch_size"] = runner.batch_size
        return kv_inputs


class LoaderMetricCallback(MetricCallback):
    """LoaderMetricCallback implements loader-based metrics update and computation over loader

    Args:
        metric: metric to calculate in callback
        input_key: keys of tensors that should be used as inputs in metric calculation
        target_key: keys of tensors that should be used as targets in metric calculation
    """

    def __init__(
        self,
        metric: ICallbackLoaderMetric,
        input_key: Union[str, Iterable[str], Dict[str, str]],
        target_key: Union[str, Iterable[str], Dict[str, str]],
    ) -> None:
        """Init LoaderMetricCallback"""
        super().__init__(metric=metric, input_key=input_key, target_key=target_key)
        assert isinstance(metric, ICallbackLoaderMetric)

    def on_loader_start(self, runner: "IRunner") -> None:
        """On loader start action: reset metric values in case of ICallbackLoaderMetric metric

        Args:
            runner: current runner
        """
        # the metric is told the loader size up front
        self.metric.reset(
            num_batches=runner.loader_batch_len,
            num_samples=runner.loader_sample_len,
        )

    def on_batch_end(self, runner: "IRunner") -> None:
        """On batch end action: get data from runner's batch and update metrics with it

        Args:
            runner: current runner
        """
        metrics_inputs = self._get_inputs(runner=runner)
        # loader metrics only accumulate here; the update result is not logged per batch
        self._update_metric(metrics_inputs)

    def on_loader_end(self, runner: "IRunner") -> None:
        """On loader end action: compute metric values and update runner's loader metrics with it

        Args:
            runner: current runner
        """
        metrics = self.metric.compute_key_value()
        runner.loader_metrics.update(metrics)


# Public API of this module (names exported by ``from ... import *``).
__all__ = [
    "IMetricCallback",
    "BatchMetricCallback",
    "FunctionalBatchMetricCallback",
    "LoaderMetricCallback",
]