Shortcuts

Source code for catalyst.callbacks.metric

from typing import Dict, Iterable, Optional, Tuple, Union
from abc import ABC, abstractmethod

import torch

from catalyst.core.callback import IMetricCallback
from catalyst.core.runner import IRunner
from catalyst.metrics._functional_metric import FunctionalBatchMetric
from catalyst.metrics._metric import ICallbackBatchMetric, ICallbackLoaderMetric, IMetric


class _IMetricCallback(IMetricCallback, ABC):
    """
    Abstract metric callback.

    Declares the three hooks a metric step is built from: resetting on loader
    start, updating on batch end and computing on loader end. Concrete
    subclasses must implement all of them.
    """

    @abstractmethod
    def on_loader_start(self, runner: "IRunner") -> None:
        """
        Action to run when a loader starts.

        Args:
            runner: current runner
        """

    @abstractmethod
    def on_batch_end(self, runner: "IRunner") -> None:
        """
        Action to run when a batch ends.

        Args:
            runner: current runner
        """

    @abstractmethod
    def on_loader_end(self, runner: "IRunner") -> None:
        """
        Action to run when a loader ends.

        Args:
            runner: current runner
        """


class _MetricCallback(_IMetricCallback):
    """
    MetricCallback is a base implementation of callback
    that updates metrics over batch or loader.

    ``input_key`` and ``target_key`` must use the same key format:

    - value mode: both are single strings; the metric update receives the
      input and target tensors from ``runner.batch`` positionally;
    - key-value mode: both are collections of keys (list, tuple, set, dict, ...);
      the metric update receives keyword arguments. A dict maps batch field
      names to metric argument names; any other collection uses each key
      as both the batch field name and the metric argument name.

    Args:
        metric: metric to calculate in callback
        input_key: keys of tensors that should be used as inputs in metric calculation
        target_key: keys of tensors that should be used as targets in metric calculation

    Raises:
        NotImplementedError: if ``input_key`` and ``target_key`` are not both
            strings or both collections of keys
    """

    def __init__(
        self,
        metric: Union[ICallbackBatchMetric, ICallbackLoaderMetric],
        input_key: Union[str, Iterable[str], Dict[str, str]],
        target_key: Union[str, Iterable[str], Dict[str, str]],
    ):
        """Init MetricCallback"""
        super().__init__()
        self.metric = metric
        assert isinstance(metric, IMetric)
        # subclasses may replace this with another update method
        # (BatchMetricCallback uses ``update_key_value``)
        self._metric_update_method = self.metric.update

        is_value_input = isinstance(input_key, str)
        is_value_targets = isinstance(target_key, str)
        is_key_value_input = self._is_keys_collection(input_key)
        is_key_value_targets = self._is_keys_collection(target_key)

        if is_value_input and is_value_targets:
            self._get_inputs = self._get_value_inputs
            self._update_metric = self._update_value_metric
        elif is_key_value_input and is_key_value_targets:
            self._get_inputs = self._get_key_value_inputs
            self._update_metric = self._update_key_value_metric
        else:
            raise NotImplementedError(
                "`input_key` and `target_key` should be either both strings "
                "or both collections of keys (list, tuple, set, dict, ...), "
                f"got {type(input_key).__name__} and {type(target_key).__name__}"
            )

        self.input_key = input_key
        self.target_key = target_key
        # batch field name -> metric argument name;
        # on duplicate batch keys the target mapping wins (merged last)
        self._keys = {
            **self._convert_keys_to_kv(input_key),
            **self._convert_keys_to_kv(target_key),
        }

    @staticmethod
    def _is_keys_collection(keys: object) -> bool:
        """
        Check whether ``keys`` should be handled in key-value mode.

        Strings and bytes are collections of characters rather than of keys,
        so they are excluded. One-shot iterators (e.g. generators) are not
        ``Collection`` instances and are rejected too, because the keys are
        stored on the callback and must stay re-iterable.

        Args:
            keys: ``input_key`` or ``target_key`` value to check

        Returns:
            True if ``keys`` is a non-string collection of keys
        """
        from collections.abc import Collection

        return isinstance(keys, Collection) and not isinstance(keys, (str, bytes))

    @staticmethod
    def _convert_keys_to_kv(
        keys: Union[str, Iterable[str], Dict[str, str]]
    ) -> Dict[str, str]:
        """
        Convert keys to key-value format

        Args:
            keys: keys to convert

        Returns:
            dict of keys like {"a": "b"} where "a" is a field name of field in batch,
                "b" is a name of the same data for metric
        """
        kv_keys = {}
        if isinstance(keys, dict):
            kv_keys.update(keys)
        elif isinstance(keys, str):
            kv_keys[keys] = keys
        else:
            for key in keys:
                kv_keys[key] = key
        return kv_keys

    def _get_value_inputs(self, runner: "IRunner") -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Get data from batch in value input case

        Args:
            runner: current runner

        Returns:
            tuple of tensor of inputs and tensor of targets
        """
        return runner.batch[self.input_key], runner.batch[self.target_key]

    def _get_key_value_inputs(self, runner: "IRunner") -> Dict[str, torch.Tensor]:
        """
        Get data from batch in key-value input case

        Args:
            runner: current runner

        Returns:
            dict of inputs and targets tensors, keyed by metric argument name
        """
        return {
            metric_name: runner.batch[batch_key]
            for batch_key, metric_name in self._keys.items()
        }

    def _update_value_metric(
        self, value_inputs: Tuple[torch.Tensor, torch.Tensor]
    ) -> Optional[Dict[str, float]]:
        """
        Update metric in value input case

        Args:
            value_inputs: tuple of input tensor and target tensor

        Returns:
            result of metric update: None or metric values
        """
        return self._metric_update_method(*value_inputs)

    def _update_key_value_metric(
        self, kv_inputs: Dict[str, torch.Tensor]
    ) -> Optional[Dict[str, float]]:
        """
        Update metric in key-value input case

        Args:
            kv_inputs: input tensors in key-value format

        Returns:
            result of metric update: None or metric values
        """
        return self._metric_update_method(**kv_inputs)


class BatchMetricCallback(_MetricCallback):
    """BatchMetricCallback implements batch-based metrics update
    and computation over loader.

    Args:
        metric: metric to calculate in callback
        input_key: keys of tensors that should be used as inputs in metric calculation
        target_key: keys of tensors that should be used as targets in metric calculation
        log_on_batch: boolean flag to log computed metrics every batch
    """

    def __init__(
        self,
        metric: ICallbackBatchMetric,
        input_key: Union[str, Iterable[str], Dict[str, str]],
        target_key: Union[str, Iterable[str], Dict[str, str]],
        log_on_batch: bool = True,
    ) -> None:
        """Init BatchMetricCallback"""
        super().__init__(metric=metric, input_key=input_key, target_key=target_key)
        assert isinstance(metric, ICallbackBatchMetric)
        self.log_on_batch = log_on_batch
        # every per-batch update goes through ``update_key_value`` so its
        # result can be pushed into ``runner.batch_metrics`` below
        self._metric_update_method = self.metric.update_key_value

    def on_loader_start(self, runner: "IRunner") -> None:
        """Reset the metric state at the start of each loader.

        Args:
            runner: current runner
        """
        self.metric.reset()

    def on_batch_end(self, runner: "IRunner") -> None:
        """Update the metric with the current batch and, if enabled,
        log its per-batch values.

        Args:
            runner: current runner
        """
        batch_values = self._update_metric(self._get_inputs(runner=runner))
        if not self.log_on_batch:
            return
        runner.batch_metrics.update(batch_values)

    def on_loader_end(self, runner: "IRunner") -> None:
        """Compute the accumulated metric values, reduce them with
        ``runner.engine.mean_reduce_ddp_metrics`` and store them
        in the runner's loader metrics.

        Args:
            runner: current runner
        """
        reduced_values = runner.engine.mean_reduce_ddp_metrics(
            self.metric.compute_key_value()
        )
        runner.loader_metrics.update(reduced_values)
class FunctionalBatchMetricCallback(BatchMetricCallback):
    """FunctionalBatchMetricCallback implements batch-based metrics update
    and computation over loader for ``FunctionalBatchMetric`` metrics.

    Args:
        metric: metric to calculate in callback
        input_key: keys of tensors that should be used as inputs in metric calculation
        target_key: keys of tensors that should be used as targets in metric calculation
        log_on_batch: boolean flag to log computed metrics every batch

    .. note::
        The main difference from BatchMetricCallback:
        FunctionalBatchMetricCallback also propagates current ``batch_size``
        to the FunctionalBatchMetric for correct metric computation.
    """

    def __init__(
        self,
        metric: FunctionalBatchMetric,
        input_key: Union[str, Iterable[str], Dict[str, str]],
        target_key: Union[str, Iterable[str], Dict[str, str]],
        log_on_batch: bool = True,
    ) -> None:
        """Init."""
        assert isinstance(metric, FunctionalBatchMetric)
        super().__init__(
            metric=metric,
            input_key=input_key,
            target_key=target_key,
            log_on_batch=log_on_batch,
        )

    def _get_value_inputs(
        self, runner: "IRunner"
    ) -> Tuple[int, torch.Tensor, torch.Tensor]:
        """Get data from batch in value input case

        Args:
            runner: current runner

        Returns:
            tuple of current batch size, tensor of inputs and tensor of targets
        """
        # batch size is prepended so it is passed positionally
        # ahead of the (input, target) tensors gathered by the parent
        return (runner.batch_size, *super()._get_value_inputs(runner))

    def _get_key_value_inputs(
        self, runner: "IRunner"
    ) -> Dict[str, Union[int, torch.Tensor]]:
        """Get data from batch in key-value input case

        Args:
            runner: current runner

        Returns:
            dict of inputs and targets tensors plus the current ``batch_size``
        """
        kv_inputs = super()._get_key_value_inputs(runner)
        # added last, so it overrides any mapped metric argument
        # that is also named "batch_size"
        kv_inputs["batch_size"] = runner.batch_size
        return kv_inputs
class LoaderMetricCallback(_MetricCallback):
    """LoaderMetricCallback implements loader-based metrics update
    and computation over loader.

    Args:
        metric: metric to calculate in callback
        input_key: keys of tensors that should be used as inputs in metric calculation
        target_key: keys of tensors that should be used as targets in metric calculation
    """

    def __init__(
        self,
        metric: ICallbackLoaderMetric,
        input_key: Union[str, Iterable[str], Dict[str, str]],
        target_key: Union[str, Iterable[str], Dict[str, str]],
    ):
        """Init LoaderMetricCallback"""
        super().__init__(metric=metric, input_key=input_key, target_key=target_key)
        assert isinstance(metric, ICallbackLoaderMetric)

    def on_loader_start(self, runner: "IRunner") -> None:
        """On loader start action: reset the ICallbackLoaderMetric state,
        passing the loader's batch and sample counts.

        Args:
            runner: current runner
        """
        batches_in_loader = runner.loader_batch_len
        samples_in_loader = runner.loader_sample_len
        self.metric.reset(num_batches=batches_in_loader, num_samples=samples_in_loader)

    def on_batch_end(self, runner: "IRunner") -> None:
        """On batch end action: feed the current batch data into the metric.

        Args:
            runner: current runner
        """
        self._update_metric(self._get_inputs(runner=runner))

    def on_loader_end(self, runner: "IRunner") -> None:
        """On loader end action: compute metric values and store them
        in the runner's loader metrics.

        Args:
            runner: current runner
        """
        runner.loader_metrics.update(self.metric.compute_key_value())
# public (and base) callback classes exported by this module
__all__ = [
    "_IMetricCallback",
    "_MetricCallback",
    "BatchMetricCallback",
    "FunctionalBatchMetricCallback",
    "LoaderMetricCallback",
]