Source code for catalyst.metrics._accuracy
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import torch

from catalyst.metrics._additive import AdditiveValueMetric
from catalyst.metrics._metric import ICallbackBatchMetric
from catalyst.metrics.functional._accuracy import accuracy, multilabel_accuracy
from catalyst.metrics.functional._misc import get_default_topk_args


class AccuracyMetric(ICallbackBatchMetric):
"""
This metric computes accuracy for multiclass classification case.
It computes mean value of accuracy and it's approximate std value
(note that it's not a real accuracy std but std of accuracy over batch mean values).
Args:
topk_args: list of `topk` for accuracy@topk computing
num_classes: number of classes
compute_on_call: if True, computes and returns metric value during metric call
prefix: metric prefix
suffix: metric suffix
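
    Example:
        A minimal illustrative sketch (shapes, values, and the public import path
        ``catalyst.metrics.AccuracyMetric`` are assumptions, not taken from these docs)::

            import torch
            from catalyst.metrics import AccuracyMetric

            metric = AccuracyMetric(num_classes=4, topk_args=[1, 3])
            metric.reset()
            logits = torch.rand(8, 4)                 # scores for 8 samples, 4 classes
            targets = torch.randint(0, 4, (8,))       # class indices
            metric.update_key_value(logits, targets)  # {"accuracy01": ..., "accuracy03": ..., "accuracy": ...}
            metric.compute_key_value()                # adds "accuracy01/std", "accuracy03/std", "accuracy/std"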
"""

    def __init__(
self,
        topk_args: Optional[List[int]] = None,
        num_classes: Optional[int] = None,
compute_on_call: bool = True,
        prefix: Optional[str] = None,
        suffix: Optional[str] = None,
):
"""Init AccuracyMetric"""
super().__init__(compute_on_call=compute_on_call, prefix=prefix, suffix=suffix)
self.metric_name_mean = f"{self.prefix}accuracy{self.suffix}"
self.metric_name_std = f"{self.prefix}accuracy{self.suffix}/std"
        # fall back to default top-k values inferred from num_classes
        self.topk_args: List[int] = topk_args or get_default_topk_args(num_classes)
        # one streaming mean/std accumulator per requested top-k
        self.additive_metrics: List[AdditiveValueMetric] = [
            AdditiveValueMetric() for _ in range(len(self.topk_args))
        ]

    def reset(self) -> None:
"""Reset all fields"""
for metric in self.additive_metrics:
metric.reset()

    def update(self, logits: torch.Tensor, targets: torch.Tensor) -> List[float]:
"""
        Updates the metric value with accuracy for new data and returns intermediate metric values.
Args:
logits: tensor of logits
targets: tensor of targets
Returns:
list of accuracy@k values
"""
values = accuracy(logits, targets, topk=self.topk_args)
values = [v.item() for v in values]
for value, metric in zip(values, self.additive_metrics):
metric.update(value, len(targets))
return values

    def update_key_value(self, logits: torch.Tensor, targets: torch.Tensor) -> Dict[str, float]:
"""
        Update metric value with accuracy for new data and return intermediate metric
values in key-value format.
Args:
logits: tensor of logits
targets: tensor of targets
Returns:
dict of accuracy@k values
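
        Example:
            Key format only, for illustration (values depend on the data); with
            ``topk_args=[1, 3]`` the top-k keys are zero-padded::

                {"accuracy01": ..., "accuracy03": ..., "accuracy": ...}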
"""
values = self.update(logits=logits, targets=targets)
output = {
f"{self.prefix}accuracy{key:02d}{self.suffix}": value
for key, value in zip(self.topk_args, values)
}
        # the plain "accuracy" key mirrors accuracy@1 (assumes 1 is among topk_args)
        output[self.metric_name_mean] = output[f"{self.prefix}accuracy01{self.suffix}"]
return output

    def compute(self) -> Tuple[List[float], List[float]]:
"""
Compute accuracy for all data
Returns:
list of mean values, list of std values
"""
means, stds = zip(*(metric.compute() for metric in self.additive_metrics))
return means, stds

    def compute_key_value(self) -> Dict[str, float]:
"""
Compute accuracy for all data and return results in key-value format
Returns:
dict of metrics
"""
means, stds = self.compute()
output_mean = {
f"{self.prefix}accuracy{key:02d}{self.suffix}": value
for key, value in zip(self.topk_args, means)
}
output_std = {
f"{self.prefix}accuracy{key:02d}{self.suffix}/std": value
for key, value in zip(self.topk_args, stds)
}
        # mirror accuracy@1 under the plain "accuracy" and "accuracy/std" keys
        output_mean[self.metric_name_mean] = output_mean[f"{self.prefix}accuracy01{self.suffix}"]
        output_std[self.metric_name_std] = output_std[f"{self.prefix}accuracy01{self.suffix}/std"]
return {**output_mean, **output_std}


class MultilabelAccuracyMetric(AdditiveValueMetric, ICallbackBatchMetric):
"""
This metric computes accuracy for multilabel classification case.
It computes mean value of accuracy and it's approximate std value
(note that it's not a real accuracy std but std of accuracy over batch mean values).
Args:
compute_on_call: if True, computes and returns metric value during metric call
prefix: metric prefix
suffix: metric suffix
threshold: thresholds for model scores
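
    Example:
        A minimal illustrative sketch (shapes, values, and the public import path
        ``catalyst.metrics.MultilabelAccuracyMetric`` are assumptions)::

            import torch
            from catalyst.metrics import MultilabelAccuracyMetric

            metric = MultilabelAccuracyMetric(threshold=0.5)
            metric.reset()
            outputs = torch.rand(8, 4)                  # scores for 8 samples, 4 labels
            targets = (torch.rand(8, 4) > 0.5).float()  # binary multilabel targets
            metric.update_key_value(outputs, targets)   # {"accuracy": ...}
            metric.compute_key_value()                  # {"accuracy": ..., "accuracy/std": ...}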
"""

    def __init__(
self,
threshold: Union[float, torch.Tensor] = 0.5,
compute_on_call: bool = True,
prefix: Optional[str] = None,
suffix: Optional[str] = None,
):
"""Init MultilabelAccuracyMetric"""
        super().__init__(compute_on_call=compute_on_call)
        # AdditiveValueMetric.__init__ does not accept prefix/suffix, so store them here
        self.prefix = prefix or ""
        self.suffix = suffix or ""
self.metric_name_mean = f"{self.prefix}accuracy{self.suffix}"
self.metric_name_std = f"{self.prefix}accuracy{self.suffix}/std"
self.threshold = threshold

    def update(self, outputs: torch.Tensor, targets: torch.Tensor) -> float:
"""
Update metric value with accuracy for new data and return intermediate metric value.
Args:
outputs: tensor of outputs
            targets: tensor of ground-truth labels
Returns:
accuracy metric for outputs and targets
"""
metric = multilabel_accuracy(
outputs=outputs, targets=targets, threshold=self.threshold
).item()
        # weight the running mean by the number of elements (samples x labels)
        super().update(value=metric, num_samples=np.prod(targets.shape))
return metric

    def update_key_value(self, outputs: torch.Tensor, targets: torch.Tensor) -> Dict[str, float]:
"""
Update metric value with accuracy for new data and return intermediate metric
value in key-value format.
Args:
outputs: tensor of outputs
            targets: tensor of ground-truth labels
Returns:
accuracy metric for outputs and targets
"""
metric = self.update(outputs=outputs, targets=targets)
return {self.metric_name_mean: metric}

    def compute_key_value(self) -> Dict[str, float]:
"""
Compute accuracy for all data and return results in key-value format
Returns:
dict of metrics
"""
metric_mean, metric_std = self.compute()
return {
self.metric_name_mean: metric_mean,
self.metric_name_std: metric_std,
}


__all__ = ["AccuracyMetric", "MultilabelAccuracyMetric"]