
"""Functional accuracy metrics (module ``catalyst.metrics.functional._accuracy``)."""

from typing import Sequence, Union

import numpy as np

import torch

from catalyst.metrics.functional import process_multilabel_components

def accuracy(
    outputs: torch.Tensor,
    targets: torch.Tensor,
    topk: Sequence[int] = (1,),
) -> Sequence[torch.Tensor]:
    """
    Computes multiclass accuracy@topk for the specified values of `topk`.

    If ``outputs`` is 1-D or has a single column, it is treated as binary:
    its values are compared directly with ``targets`` (no thresholding is
    applied here), so they are expected to already be hard predictions.
    In that case every ``k`` in ``topk`` yields the same value as ``k=1``.

    Args:
        outputs: model outputs, logits with shape [bs; num_classes];
            or hard binary predictions with shape [bs] or [bs; 1]
        targets: ground truth, labels with shape [bs] or [bs; 1]
        topk: `topk` for accuracy@topk computing; in the multiclass case
            each value must not exceed num_classes (``Tensor.topk`` raises
            otherwise)

    Returns:
        list with computed accuracy@topk: one single-element tensor per
        entry of ``topk``, in the same order

    Examples:

    .. code-block:: python

        import torch
        from catalyst import metrics
        metrics.accuracy(
            outputs=torch.tensor([
                [1, 0, 0],
                [0, 1, 0],
                [0, 0, 1],
            ]),
            targets=torch.tensor([0, 1, 2]),
        )
        # [tensor([1.])]

    .. code-block:: python

        import torch
        from catalyst import metrics
        metrics.accuracy(
            outputs=torch.tensor([
                [1, 0, 0],
                [0, 1, 0],
                [0, 1, 0],
            ]),
            targets=torch.tensor([0, 1, 2]),
        )
        # [tensor([0.6667])]

    .. code-block:: python

        import torch
        from catalyst import metrics
        metrics.accuracy(
            outputs=torch.tensor([
                [1, 0, 0],
                [0, 1, 0],
                [0, 0, 1],
            ]),
            targets=torch.tensor([0, 1, 2]),
            topk=[1, 3],
        )
        # [tensor([1.]), tensor([1.])]

    .. code-block:: python

        import torch
        from catalyst import metrics
        metrics.accuracy(
            outputs=torch.tensor([
                [1, 0, 0],
                [0, 1, 0],
                [0, 1, 0],
            ]),
            targets=torch.tensor([0, 1, 2]),
            topk=[1, 3],
        )
        # [tensor([0.6667]), tensor([1.])]
    """
    max_k = max(topk)
    batch_size = targets.size(0)

    if outputs.dim() == 1 or outputs.shape[1] == 1:
        # Binary accuracy: predictions are used as-is. Reshape to a single
        # [1; bs] row so that both [bs] and [bs; 1] inputs line up with the
        # [1; bs] targets row below (a plain ``.t()`` leaves 1-D inputs 1-D,
        # and ``expand_as`` then fails).
        pred = outputs.reshape(1, -1)
    else:
        # Multiclass accuracy: indices of the ``max_k`` highest scores per
        # sample, transposed to [max_k; bs] so row i holds every sample's
        # (i + 1)-th best guess.
        _, pred = outputs.topk(max_k, 1, True, True)
        pred = pred.t()

    # correct[i, j] is True when sample j's (i + 1)-th guess equals its label.
    correct = pred.eq(targets.long().view(1, -1).expand_as(pred))

    output = []
    for k in topk:
        # A sample counts as a hit if any of its first k guesses is correct;
        # rows beyond the available ones are silently dropped by slicing.
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
        output.append(correct_k.mul_(1.0 / batch_size))
    return output
def multilabel_accuracy(
    outputs: torch.Tensor,
    targets: torch.Tensor,
    threshold: Union[float, torch.Tensor],
) -> torch.Tensor:
    """
    Computes multilabel accuracy for the specified threshold.

    Every (example, class) cell is treated as one binary decision: ``outputs``
    are binarised with ``outputs > threshold`` and compared element-wise with
    ``targets``; the result is the fraction of cells that match.

    Args:
        outputs: NxK tensor that for each of the N examples
            indicates the probability of the example belonging to each of
            the K classes, according to the model.
        targets: binary NxK tensor that encodes which of the K
            classes are associated with the N-th input
            (eg: a row [0, 1, 0, 1] indicates that the example is
            associated with classes 2 and 4)
        threshold: threshold for the model outputs; a tensor threshold is
            broadcast against ``outputs`` by the ``>`` comparison

    Returns:
        computed multilabel accuracy as a 0-dim float tensor

    Examples:

    .. code-block:: python

        import torch
        from catalyst import metrics
        metrics.multilabel_accuracy(
            outputs=torch.tensor([
                [1, 0],
                [0, 1],
            ]),
            targets=torch.tensor([
                [1, 0],
                [0, 1],
            ]),
            threshold=0.5,
        )
        # tensor(1.)

    .. code-block:: python

        import torch
        from catalyst import metrics
        metrics.multilabel_accuracy(
            outputs=torch.tensor([
                [1.0, 0.0],
                [0.6, 1.0],
            ]),
            targets=torch.tensor([
                [1, 0],
                [0, 1],
            ]),
            threshold=0.5,
        )
        # tensor(0.7500)

    .. code-block:: python

        import torch
        from catalyst import metrics
        metrics.multilabel_accuracy(
            outputs=torch.tensor([
                [1.0, 0.0],
                [0.4, 1.0],
            ]),
            targets=torch.tensor([
                [1, 0],
                [0, 1],
            ]),
            threshold=0.5,
        )
        # tensor(1.)
    """
    outputs, targets, _, _ = process_multilabel_components(outputs=outputs, targets=targets)

    predictions = (outputs > threshold).long()
    num_matches = (targets.long() == predictions).sum().float()
    # Normalise by the total number of (example, class) cells, i.e. N * K.
    output = num_matches / targets.numel()
    return output
# Public API of this module.
__all__ = [
    "accuracy",
    "multilabel_accuracy",
]