# Source code for catalyst.metrics._auc

from typing import Dict, Tuple

import torch

from catalyst.metrics._metric import ICallbackLoaderMetric
from catalyst.metrics.functional._auc import auc, binary_auc
from catalyst.metrics.functional._misc import process_multilabel_components
from catalyst.utils.distributed import all_gather, get_rank


class AUCMetric(ICallbackLoaderMetric):
    """AUC metric.

    Accumulates scores/targets over a loader and computes per-class,
    micro, macro, and weighted AUC on demand.

    Args:
        compute_on_call: if True, computes and returns metric value during metric call
        prefix: metric prefix
        suffix: metric suffix

    .. warning::
        This metric is under API improvement.
    """

    def __init__(self, compute_on_call: bool = True, prefix: str = None, suffix: str = None):
        """Init."""
        super().__init__(compute_on_call=compute_on_call, prefix=prefix, suffix=suffix)
        self.metric_name = f"{self.prefix}auc{self.suffix}"
        # Accumulated per-batch tensors, concatenated at compute() time.
        self.scores = []
        self.targets = []
        # get_rank() > -1 means we are running under distributed data parallel.
        self._is_ddp = get_rank() > -1

    def reset(self, num_batches: int, num_samples: int) -> None:
        """Resets all fields.

        Args:
            num_batches: number of expected batches (unused, kept for the
                ``ICallbackLoaderMetric`` interface)
            num_samples: number of expected samples (unused, kept for the
                ``ICallbackLoaderMetric`` interface)
        """
        # Re-check DDP status on every reset: the distributed context may be
        # initialized after __init__.
        self._is_ddp = get_rank() > -1
        self.scores = []
        self.targets = []

    def update(self, scores: torch.Tensor, targets: torch.Tensor) -> None:
        """Updates metric value with statistics for new data.

        Args:
            scores: tensor with scores
            targets: tensor with targets
        """
        # Detach and move to CPU so accumulation does not hold GPU memory
        # or autograd graphs.
        self.scores.append(scores.cpu().detach())
        self.targets.append(targets.cpu().detach())

    def compute(self) -> Tuple[torch.Tensor, float, float, float]:
        """Computes the AUC metric based on saved statistics.

        Returns:
            tuple of (per-class AUC tensor, micro AUC, macro AUC, weighted AUC)
        """
        targets = torch.cat(self.targets)
        scores = torch.cat(self.scores)
        # @TODO: ddp hotfix, could be done better
        # Gather statistics from all workers before computing the metric.
        if self._is_ddp:
            scores = torch.cat(all_gather(scores))
            targets = torch.cat(all_gather(targets))

        scores, targets, _ = process_multilabel_components(outputs=scores, targets=targets)
        per_class = auc(scores=scores, targets=targets)
        # Micro AUC: treat all (sample, class) pairs as one binary problem.
        micro = binary_auc(scores=scores.view(-1), targets=targets.view(-1))[0]
        # Macro AUC: unweighted mean over classes.
        macro = per_class.mean().item()
        # Weighted AUC: classes weighted by their positive-label frequency.
        weights = targets.sum(dim=0) / len(targets)
        weighted = (per_class * weights).sum().item()
        return per_class, micro, macro, weighted

    def compute_key_value(self) -> Dict[str, float]:
        """Computes the AUC metric based on saved statistics and returns key-value results.

        Returns:
            dict mapping ``<metric_name>/class_XX``, ``<metric_name>/_micro``,
            ``<metric_name>`` (macro), ``<metric_name>/_macro`` and
            ``<metric_name>/_weighted`` to their values
        """
        per_class_auc, micro_auc, macro_auc, weighted_auc = self.compute()
        output = {
            f"{self.metric_name}/class_{i:02d}": value.item()
            for i, value in enumerate(per_class_auc)
        }
        output[f"{self.metric_name}/_micro"] = micro_auc
        # Plain metric name defaults to the macro average.
        output[self.metric_name] = macro_auc
        output[f"{self.metric_name}/_macro"] = macro_auc
        output[f"{self.metric_name}/_weighted"] = weighted_auc
        return output
__all__ = ["AUCMetric"]