# Source code for catalyst.metrics.accuracy
"""
Various accuracy metrics:
    * :func:`accuracy`
    * :func:`multi_label_accuracy`
"""
from typing import Optional, Sequence, Union

import numpy as np

import torch

from catalyst.metrics.functional import process_multilabel_components
from catalyst.utils.torch import get_activation_fn

def accuracy(
    outputs: torch.Tensor,
    targets: torch.Tensor,
    topk: Sequence[int] = (1,),
    activation: Optional[str] = None,
) -> Sequence[torch.Tensor]:
    """
    Computes multi-class accuracy@topk for the specified values of `topk`.

    Args:
        outputs: model outputs, logits with shape [bs; num_classes].
            If ``outputs`` has shape [bs] or [bs; 1], its values are
            compared element-wise to ``targets`` (binary accuracy), so
            they should already be class labels for the result to be
            meaningful.
        targets: ground truth, labels with shape [bs; 1]
        topk: `topk` for accuracy@topk computing
        activation: activation to use for model output

    Returns:
        list with computed accuracy@topk: one single-element tensor per
        entry of ``topk``, in the same order
    """
    activation_fn = get_activation_fn(activation)
    outputs = activation_fn(outputs)

    max_k = max(topk)
    batch_size = targets.size(0)

    if len(outputs.shape) == 1 or outputs.shape[1] == 1:
        # binary accuracy: compare outputs to targets directly.
        # Reshape to a single row [1; bs] so it lines up with the [1; bs]
        # targets view below; ``.t()`` is a no-op on 1-D tensors, which
        # made ``expand_as`` fail for outputs of shape [bs].
        pred = outputs.reshape(1, -1)
    else:
        # multi-class accuracy: indices of the top ``max_k`` scores per
        # sample, transposed to [max_k; bs] so row i holds the i-th guess.
        _, pred = outputs.topk(max_k, 1, True, True)  # noqa: WPS425
        pred = pred.t()
    correct = pred.eq(targets.long().view(1, -1).expand_as(pred))

    output = []
    for k in topk:
        # count matches among the first k rows of guesses
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
        output.append(correct_k.mul_(1.0 / batch_size))
    return output
def multi_label_accuracy(
    outputs: torch.Tensor,
    targets: torch.Tensor,
    threshold: Union[float, torch.Tensor],
    activation: Optional[str] = None,
) -> torch.Tensor:
    """
    Computes multi-label accuracy for the specified activation and threshold.

    The result is the fraction of all entries of ``targets`` whose
    thresholded prediction equals the target value.

    Args:
        outputs: NxK tensor that for each of the N examples
            indicates the probability of the example belonging to each of
            the K classes, according to the model.
        targets: binary NxK tensor that encodes which of the K
            classes are associated with the N-th input
            (eg: a row [0, 1, 0, 1] indicates that the example is
            associated with classes 2 and 4)
        threshold: threshold for model output
        activation: activation to use for model output

    Returns:
        computed multi-label accuracy
    """
    outputs, targets, _ = process_multilabel_components(
        outputs=outputs, targets=targets
    )
    activation_fn = get_activation_fn(activation)
    outputs = activation_fn(outputs)

    # binarize: entries strictly above the threshold are predicted positive
    predictions = (outputs > threshold).long()
    # matches divided by the total number of entries (product of
    # targets.shape); the extracted source had a garbled denominator here
    output = (targets.long() == predictions).sum().float() / targets.numel()
    return output
# Names exported by ``from catalyst.metrics.accuracy import *``.
__all__ = [
    "accuracy",
    "multi_label_accuracy",
]