Source code for catalyst.metrics._auc

from typing import Dict

import torch

from catalyst.metrics._metric import ICallbackLoaderMetric
from catalyst.metrics.functional._auc import auc
from catalyst.utils.distributed import all_gather, get_rank

class AUCMetric(ICallbackLoaderMetric):
    """AUC metric accumulated over a whole loader.

    Scores and targets passed to ``update`` are stored batch by batch
    (moved to CPU and detached) and are only combined when ``compute`` is called.

    Args:
        compute_on_call: if True, computes and returns metric value during metric call
        prefix: metric prefix
        suffix: metric suffix
    """

    def __init__(self, compute_on_call: bool = True, prefix: str = None, suffix: str = None):
        """Init."""
        super().__init__(compute_on_call=compute_on_call, prefix=prefix, suffix=suffix)
        self.metric_name = f"{self.prefix}auc{self.suffix}"
        # Per-batch tensors collected by ``update``; concatenated in ``compute``.
        self.scores = []
        self.targets = []
        # Re-evaluated on every ``reset`` call.
        self._is_ddp = False

    def reset(self, num_batches, num_samples) -> None:
        """Resets all fields.

        Args:
            num_batches: not used by this metric
            num_samples: not used by this metric
        """
        # A rank greater than -1 is treated as "running in distributed mode".
        self._is_ddp = get_rank() > -1
        self.scores = []
        self.targets = []

    def update(self, scores: torch.Tensor, targets: torch.Tensor) -> None:
        """Updates metric value with statistics for new data.

        Args:
            scores: tensor with scores
            targets: tensor with targets
        """
        # Store CPU copies detached from the autograd graph.
        self.scores.append(scores.cpu().detach())
        self.targets.append(targets.cpu().detach())

    def compute(self) -> torch.Tensor:
        """Computes the AUC metric based on saved statistics.

        Returns:
            The value returned by ``auc`` for all accumulated scores and targets.
        """
        # Join the stored batches along the first dimension.
        targets = torch.cat(self.targets)
        scores = torch.cat(self.scores)

        # @TODO: ddp hotfix, could be done better
        if self._is_ddp:
            # NOTE(review): assumes ``all_gather`` returns a sequence with one
            # tensor per process - confirm against catalyst.utils.distributed.
            scores = torch.cat(all_gather(scores))
            targets = torch.cat(all_gather(targets))

        score = auc(outputs=scores, targets=targets)
        return score

    def compute_key_value(self) -> Dict[str, float]:
        """Computes the AUC metric based on saved statistics and returns key-value results.

        Returns:
            Dict with one ``"{metric_name}/class_{i:02d}"`` entry per element of the
            computed tensor, plus a ``"{metric_name}"`` entry holding their mean.
        """
        per_class_auc = self.compute()
        output = {
            f"{self.metric_name}/class_{i:02d}": value.item()
            for i, value in enumerate(per_class_auc)
        }
        output[self.metric_name] = per_class_auc.mean().item()
        return output
# Public names exported by this module (used by ``from module import *``).
__all__ = ["AUCMetric"]