from typing import Sequence, Union
import numpy as np
import torch
from catalyst.metrics.functional import process_multilabel_components
def accuracy(
    outputs: torch.Tensor, targets: torch.Tensor, topk: Sequence[int] = (1,)
) -> Sequence[torch.Tensor]:
    """
    Computes multiclass accuracy@topk for the specified values of `topk`.

    If ``outputs`` is 1-D or has a single column, it is treated as already
    predicted labels (binary accuracy) and compared to ``targets`` directly;
    otherwise the ``max(topk)`` highest-scoring class indices per sample are
    compared to ``targets``.

    Args:
        outputs: model outputs, logits
            with shape [bs; num_classes]
            (or predicted labels with shape [bs] / [bs; 1] in the binary case)
        targets: ground truth, labels
            with shape [bs] or [bs; 1]
        topk: `topk` for accuracy@topk computing

    Returns:
        list with computed accuracy@topk, one tensor of shape [1] per value
        in `topk`

    Examples:

    .. code-block:: python

        import torch
        from catalyst import metrics
        metrics.accuracy(
            outputs=torch.tensor([
                [1, 0, 0],
                [0, 1, 0],
                [0, 0, 1],
            ]),
            targets=torch.tensor([0, 1, 2]),
        )
        # [tensor([1.])]

    .. code-block:: python

        import torch
        from catalyst import metrics
        metrics.accuracy(
            outputs=torch.tensor([
                [1, 0, 0],
                [0, 1, 0],
                [0, 1, 0],
            ]),
            targets=torch.tensor([0, 1, 2]),
        )
        # [tensor([0.6667])]

    .. code-block:: python

        import torch
        from catalyst import metrics
        metrics.accuracy(
            outputs=torch.tensor([
                [1, 0, 0],
                [0, 1, 0],
                [0, 0, 1],
            ]),
            targets=torch.tensor([0, 1, 2]),
            topk=[1, 3],
        )
        # [tensor([1.]), tensor([1.])]

    .. code-block:: python

        import torch
        from catalyst import metrics
        metrics.accuracy(
            outputs=torch.tensor([
                [1, 0, 0],
                [0, 1, 0],
                [0, 1, 0],
            ]),
            targets=torch.tensor([0, 1, 2]),
            topk=[1, 3],
        )
        # [tensor([0.6667]), tensor([1.])]
    """
    max_k = max(topk)
    batch_size = targets.size(0)

    if outputs.dim() == 1 or outputs.shape[1] == 1:
        # binary accuracy
        # ``.t()`` leaves a 1-D tensor 1-D, which then breaks ``expand_as``
        # below; reshaping gives a [1; bs] row for both [bs] and [bs; 1].
        pred = outputs.reshape(1, -1)
    else:
        # multiclass accuracy: indices of the max_k largest scores per sample,
        # transposed to [max_k; bs] so that row i holds the (i+1)-th best guess
        _, pred = outputs.topk(max_k, 1, True, True)
        pred = pred.t()
    correct = pred.eq(targets.long().reshape(1, -1).expand_as(pred))

    output = []
    for k in topk:
        # a sample counts as correct@k if any of its first k guesses matches
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
        output.append(correct_k.mul_(1.0 / batch_size))
    return output
def multilabel_accuracy(
    outputs: torch.Tensor, targets: torch.Tensor, threshold: Union[float, torch.Tensor]
) -> torch.Tensor:
    """
    Computes multilabel accuracy for the specified activation and threshold.

    The accuracy is the fraction of all (example, class) entries for which
    the thresholded output equals the target.

    Args:
        outputs: NxK tensor that for each of the N examples
            indicates the probability of the example belonging to each of
            the K classes, according to the model.
        targets: binary NxK tensor that encodes which of the K
            classes are associated with the N-th input
            (eg: a row [0, 1, 0, 1] indicates that the example is
            associated with classes 2 and 4)
        threshold: threshold for model output; an entry is predicted
            positive when its output is strictly greater than the threshold

    Returns:
        computed multilabel accuracy (scalar tensor)

    Examples:

    .. code-block:: python

        import torch
        from catalyst import metrics
        metrics.multilabel_accuracy(
            outputs=torch.tensor([
                [1, 0],
                [0, 1],
            ]),
            targets=torch.tensor([
                [1, 0],
                [0, 1],
            ]),
            threshold=0.5,
        )
        # tensor(1.)

    .. code-block:: python

        import torch
        from catalyst import metrics
        metrics.multilabel_accuracy(
            outputs=torch.tensor([
                [1.0, 0.0],
                [0.6, 1.0],
            ]),
            targets=torch.tensor([
                [1, 0],
                [0, 1],
            ]),
            threshold=0.5,
        )
        # tensor(0.7500)

    .. code-block:: python

        import torch
        from catalyst import metrics
        metrics.multilabel_accuracy(
            outputs=torch.tensor([
                [1.0, 0.0],
                [0.4, 1.0],
            ]),
            targets=torch.tensor([
                [1, 0],
                [0, 1],
            ]),
            threshold=0.5,
        )
        # tensor(1.)
    """
    # NOTE(review): presumably validates/normalizes the inputs — confirm
    # against process_multilabel_components; its third return value is unused.
    outputs, targets, _ = process_multilabel_components(outputs=outputs, targets=targets)
    predictions = (outputs > threshold).long()
    # matches over all entries divided by the total number of entries
    return (targets.long() == predictions).sum().float() / targets.numel()
# Public names exported by a star-import of this module.
__all__ = ["accuracy", "multilabel_accuracy"]