"""Source code for ``catalyst.metrics._topk_metric``: base class for `topk` metrics."""
from typing import Any, Callable, Dict, Iterable, List
import torch
from catalyst.metrics._additive import AdditiveMetric
from catalyst.metrics._metric import ICallbackBatchMetric
class TopKMetric(ICallbackBatchMetric):
    """
    Base class for `topk` metrics.

    Args:
        metric_name: name of the metric
        metric_function: metric calculation function; called as
            ``metric_function(logits, targets, topk=topk_args)`` and expected to
            return one value per ``k`` in ``topk_args`` order (each value must
            support ``.item()``, e.g. a 0-d ``torch.Tensor``)
        topk_args: list of `topk` for metric@topk computing. Any iterable is
            accepted (it is materialized into a tuple once); ``None`` or an
            empty iterable falls back to ``(1,)``
        compute_on_call: if True, computes and returns metric value during metric call
        prefix: metric prefix
        suffix: metric suffix
    """

    def __init__(
        self,
        metric_name: str,
        metric_function: Callable,
        topk_args: Iterable[int] = None,
        compute_on_call: bool = True,
        prefix: str = None,
        suffix: str = None,
    ):
        """Init TopKMetric"""
        super().__init__(compute_on_call=compute_on_call, prefix=prefix, suffix=suffix)
        self.metric_name = metric_name
        self.metric_function = metric_function
        # Materialize once: ``topk_args`` is typed as a generic Iterable, and a
        # one-shot iterator (e.g. a generator) would otherwise break ``len()``,
        # be exhausted after the first ``zip`` in the key-value methods, and
        # (being truthy even when empty) bypass the ``(1,)`` default.
        topk_args = tuple(topk_args) if topk_args is not None else ()
        self.topk_args = topk_args or (1,)
        # One accumulator per ``k``, aligned by index with ``self.topk_args``.
        self.metrics: List[AdditiveMetric] = [AdditiveMetric() for _ in self.topk_args]

    def _get_key(self, topk: int, postfix: str = "") -> str:
        """Build the output key for metric@``topk`` (``k`` zero-padded to two digits)."""
        return f"{self.prefix}{self.metric_name}{topk:02d}{self.suffix}{postfix}"

    def reset(self) -> None:
        """Reset all fields"""
        for metric in self.metrics:
            metric.reset()

    def update(self, logits: torch.Tensor, targets: torch.Tensor) -> List[float]:
        """
        Update metric value with value for new data and return intermediate metrics values.

        Args:
            logits (torch.Tensor): tensor of logits
            targets (torch.Tensor): tensor of targets

        Returns:
            list of metric@k values, in ``topk_args`` order
        """
        values = self.metric_function(logits, targets, topk=self.topk_args)
        # Convert each per-k result to a plain Python number.
        values = [v.item() for v in values]
        for value, metric in zip(values, self.metrics):
            # Each batch is weighted by ``len(targets)`` (size of the first
            # dimension — presumably the batch size).
            metric.update(value, len(targets))
        return values

    def update_key_value(self, logits: torch.Tensor, targets: torch.Tensor) -> Dict[str, float]:
        """
        Update metric value with value for new data and return intermediate metrics
        values in key-value format.

        Args:
            logits (torch.Tensor): tensor of logits
            targets (torch.Tensor): tensor of targets

        Returns:
            dict of metric@k values
        """
        values = self.update(logits=logits, targets=targets)
        return {self._get_key(k): value for k, value in zip(self.topk_args, values)}

    def compute(self) -> Any:
        """
        Compute metric for all data.

        Returns:
            pair ``(means, stds)`` of tuples, one entry per ``k`` in
            ``topk_args`` order
        """
        # ``self.metrics`` is never empty (``topk_args`` defaults to ``(1,)``),
        # so unpacking the transposed (mean, std) pairs is safe.
        means, stds = zip(*(metric.compute() for metric in self.metrics))
        return means, stds

    def compute_key_value(self) -> Dict[str, float]:
        """
        Compute metric for all data and return results in key-value format.

        Returns:
            dict of metrics: mean values under ``<key>`` and std values under
            ``<key>/std``
        """
        means, stds = self.compute()
        output_mean = {self._get_key(k): value for k, value in zip(self.topk_args, means)}
        output_std = {self._get_key(k, "/std"): value for k, value in zip(self.topk_args, stds)}
        return {**output_mean, **output_std}
# Public API of this module.
__all__ = [
    "TopKMetric",
]