Shortcuts

Source code for catalyst.metrics._functional_metric

from typing import Callable, Dict

import torch

from catalyst.metrics import ICallbackBatchMetric
from catalyst.metrics._additive import AdditiveValueMetric


class BatchFunctionalMetric(ICallbackBatchMetric):
    """
    Class for a custom metric defined in a functional way.

    Note: the loader metric is calculated as an average over all batch metrics.

    Args:
        metric_fn: metric function that takes outputs and targets
            and returns a score as ``torch.Tensor``
        metric_name: metric name; passed to the base class as ``prefix``,
            and the keys returned by ``update_key_value`` / ``compute_key_value``
            are built from ``self.prefix``
    """

    def __init__(self, metric_fn: Callable, metric_name: str):
        """Init"""
        super().__init__(compute_on_call=True, prefix=metric_name)
        self.metric_fn = metric_fn
        # Accumulator for per-batch values; ``update`` feeds it each batch
        # value together with ``len(outputs)`` as the sample count.
        self.cumulative_metric = AdditiveValueMetric()

    def reset(self):
        """Reset all accumulated statistics."""
        self.cumulative_metric.reset()

    def update_key_value(
        self, outputs: torch.Tensor, targets: torch.Tensor
    ) -> Dict[str, torch.Tensor]:
        """
        Calculate the metric for a batch and update the running average.

        Args:
            outputs: tensor of logits
            targets: tensor of targets

        Returns:
            dict with a single entry mapping the prefix to the batch metric value
        """
        value = self.update(outputs, targets)
        return {f"{self.prefix}": value}

    def update(self, outputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """
        Calculate the metric for a batch and update the running average.

        Args:
            outputs: tensor of logits
            targets: tensor of targets

        Returns:
            the batch metric value exactly as returned by ``metric_fn``
        """
        value = self.metric_fn(outputs, targets)
        # The accumulator outlives the batch. Storing a tensor that still
        # requires grad would keep its autograd graph (and anything it
        # references) alive, so accumulate a detached view instead. Non-tensor
        # results are passed through unchanged. The caller still receives the
        # original, undetached value.
        accumulated = value.detach() if isinstance(value, torch.Tensor) else value
        self.cumulative_metric.update(accumulated, len(outputs))
        return value

    def compute(self) -> torch.Tensor:
        """
        Get the metric average over all examples seen since the last reset.

        Returns:
            the first element of the accumulator's ``compute()`` result
            (the average)
        """
        return self.cumulative_metric.compute()[0]

    def compute_key_value(self) -> Dict[str, torch.Tensor]:
        """
        Get the metric average over all examples seen since the last reset.

        Returns:
            dict with a single entry whose key is ``"<prefix>/mean"``
        """
        return {f"{self.prefix}/mean": self.compute()}
# Names exported by ``from <this module> import *``.
__all__ = [
    "BatchFunctionalMetric",
]