Source code for catalyst.metrics.functional

from typing import Callable, Dict, Optional, Sequence, Tuple
from functools import partial

import numpy as np

import torch
from torch import Tensor
from torch.nn import functional as F

# @TODO:
# after full classification metrics re-implementation, make a reference to
# https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/metrics
# as a baseline


def process_multiclass_components(
    outputs: torch.Tensor,
    targets: torch.Tensor,
    argmax_dim: int = -1,
    num_classes: Optional[int] = None,
) -> Tuple[torch.Tensor, torch.Tensor, int]:
    """
    Preprocess inputs for a multiclass classification task.

    Args:
        outputs: estimated targets as predicted by a model
            with shape [bs; ..., (num_classes or 1)]
        targets: ground truth (correct) target values
            with shape [bs; ..., 1]
        argmax_dim: int, that specifies dimension for argmax transformation
            in case of scores/probabilities in ``outputs``
        num_classes: int, that specifies the number of classes if it is known

    Returns:
        preprocessed outputs, targets and num_classes
    """
    # @TODO: better multiclass preprocessing, label -> class_id mapping
    if not torch.is_tensor(outputs):
        outputs = torch.from_numpy(np.array(outputs))
    if not torch.is_tensor(targets):
        targets = torch.from_numpy(np.array(targets))

    if outputs.dim() == targets.dim() + 1:
        # looks like we have scores/probabilities in our outputs
        # let's convert them to final model predictions
        num_classes = max(
            outputs.shape[argmax_dim], int(targets.max().detach().item() + 1)
        )
        outputs = torch.argmax(outputs, dim=argmax_dim)
    if num_classes is None:
        # since we expect the outputs/targets tensors to contain int64 labels,
        # we can infer the number of classes as the maximum label value + 1
        num_classes = max(
            int(outputs.max().detach().item() + 1),
            int(targets.max().detach().item() + 1),
        )

    if outputs.dim() == 1:
        outputs = outputs.view(-1, 1)
    elif outputs.dim() == 2 and outputs.size(0) == 1:
        # transpose case: [1; bs] -> [bs; 1]
        outputs = outputs.permute(1, 0)
    else:
        assert outputs.dim() == 2 and outputs.size(1) == 1, (
            "Wrong `outputs` shape, "
            "expected 1D or 2D with size 1 in the second dim, "
            "got {}".format(outputs.shape)
        )

    if targets.dim() == 1:
        targets = targets.view(-1, 1)
    elif targets.dim() == 2 and targets.size(0) == 1:
        # transpose case: [1; bs] -> [bs; 1]
        targets = targets.permute(1, 0)
    else:
        assert targets.dim() == 2 and targets.size(1) == 1, (
            "Wrong `targets` shape, "
            "expected 1D or 2D with size 1 in the second dim, "
            "got {}".format(targets.shape)
        )

    return outputs, targets, num_classes
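
# A minimal usage sketch for ``process_multiclass_components`` (added for
# illustration; the tensors below are made-up, not from the original module):
#
#   >>> logits = torch.tensor([[0.1, 0.9], [0.8, 0.2]])  # [bs; num_classes] scores
#   >>> labels = torch.tensor([1, 0])                    # [bs] class ids
#   >>> outputs, targets, num_classes = process_multiclass_components(logits, labels)
#   >>> outputs.shape, targets.shape, num_classes
#   (torch.Size([2, 1]), torch.Size([2, 1]), 2)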


def process_multilabel_components(
    outputs: torch.Tensor,
    targets: torch.Tensor,
    weights: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """General preprocessing for multilabel-based metrics.

    Args:
        outputs: NxK tensor that for each of the N examples
            indicates the probability of the example belonging to each of
            the K classes, according to the model
        targets: binary NxK tensor that encodes which of the K
            classes are associated with the N-th input
            (e.g. a row [0, 1, 0, 1] indicates that the example is
            associated with the second and fourth classes)
        weights: importance for each sample

    Returns:
        processed ``outputs`` and ``targets``
        with [batch_size; num_classes] shape
    """
    if not torch.is_tensor(outputs):
        outputs = torch.from_numpy(outputs)
    if not torch.is_tensor(targets):
        targets = torch.from_numpy(targets)
    if weights is not None:
        if not torch.is_tensor(weights):
            weights = torch.from_numpy(weights)
        weights = weights.squeeze()

    if outputs.dim() == 1:
        outputs = outputs.view(-1, 1)
    else:
        assert outputs.dim() == 2, (
            "wrong `outputs` size "
            "(should be 1D or 2D with one column per class)"
        )
    if targets.dim() == 1:
        if outputs.shape[1] > 1:
            # multiclass case: convert class ids to one-hot targets
            num_classes = outputs.shape[1]
            targets = F.one_hot(targets, num_classes).float()
        else:
            # binary case
            targets = targets.view(-1, 1)
    else:
        assert targets.dim() == 2, (
            "wrong `targets` size "
            "(should be 1D or 2D with one column per class)"
        )

    if weights is not None:
        assert weights.dim() == 1, "Weights dimension should be 1"
        assert weights.numel() == targets.size(0), (
            "Weights dimension 1 should be the same as that of target"
        )
        assert torch.min(weights) >= 0, "Weight should be non-negative only"

    assert torch.equal(
        targets ** 2, targets
    ), "targets should be binary (0 or 1)"

    return outputs, targets, weights
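
# A minimal usage sketch for ``process_multilabel_components`` (added for
# illustration; values are assumed): class-id targets are one-hot encoded
# internally to match the [bs; num_classes] outputs.
#
#   >>> scores = torch.tensor([[0.9, 0.1, 0.2], [0.2, 0.7, 0.1]])
#   >>> labels = torch.tensor([0, 1])
#   >>> outputs, targets, weights = process_multilabel_components(scores, labels)
#   >>> targets
#   tensor([[1., 0., 0.],
#           [0., 1., 0.]])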


def get_binary_statistics(
    outputs: Tensor, targets: Tensor, label: int = 1,
) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
    """
    Computes the number of true negatives, false positives, false negatives,
    true positives and support for a binary classification problem
    for a given label.

    Args:
        outputs: estimated targets as predicted by a model
            with shape [bs; ..., 1]
        targets: ground truth (correct) target values
            with shape [bs; ..., 1]
        label: integer, that specifies the label of interest
            for statistics computation

    Returns:
        Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]: stats

    Example:
        >>> y_pred = torch.tensor([[0, 0, 1, 1, 0, 1, 0, 1]])
        >>> y_true = torch.tensor([[0, 1, 0, 1, 0, 0, 1, 1]])
        >>> tn, fp, fn, tp, support = get_binary_statistics(y_pred, y_true)
        tensor(2) tensor(2) tensor(2) tensor(2) tensor(4)
    """
    tn = ((outputs != label) * (targets != label)).to(torch.long).sum()
    fp = ((outputs == label) * (targets != label)).to(torch.long).sum()
    fn = ((outputs != label) * (targets == label)).to(torch.long).sum()
    tp = ((outputs == label) * (targets == label)).to(torch.long).sum()
    support = (targets == label).to(torch.long).sum()
    return tn, fp, fn, tp, support


def get_multiclass_statistics(
    outputs: Tensor,
    targets: Tensor,
    argmax_dim: int = -1,
    num_classes: Optional[int] = None,
) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
    """
    Computes the number of true negatives, false positives, false negatives,
    true positives and support for a multiclass classification problem.

    Args:
        outputs: estimated targets as predicted by a model
            with shape [bs; ..., (num_classes or 1)]
        targets: ground truth (correct) target values
            with shape [bs; ..., 1]
        argmax_dim: int, that specifies dimension for argmax transformation
            in case of scores/probabilities in ``outputs``
        num_classes: int, that specifies the number of classes if it is known

    Returns:
        Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]: stats

    Example:
        >>> y_pred = torch.tensor([1, 2, 3, 0])
        >>> y_true = torch.tensor([1, 3, 4, 0])
        >>> tn, fp, fn, tp, support = get_multiclass_statistics(y_pred, y_true)
        tensor([3., 3., 3., 2., 3.]), tensor([0., 0., 1., 1., 0.]),
        tensor([0., 0., 0., 1., 1.]), tensor([1., 1., 0., 0., 0.]),
        tensor([1., 1., 0., 1., 1.])
    """
    outputs, targets, num_classes = process_multiclass_components(
        outputs=outputs,
        targets=targets,
        argmax_dim=argmax_dim,
        num_classes=num_classes,
    )
    tn = torch.zeros((num_classes,), device=outputs.device)
    fp = torch.zeros((num_classes,), device=outputs.device)
    fn = torch.zeros((num_classes,), device=outputs.device)
    tp = torch.zeros((num_classes,), device=outputs.device)
    support = torch.zeros((num_classes,), device=outputs.device)
    for class_index in range(num_classes):
        (
            tn[class_index],
            fp[class_index],
            fn[class_index],
            tp[class_index],
            support[class_index],
        ) = get_binary_statistics(
            outputs=outputs, targets=targets, label=class_index
        )
    return tn, fp, fn, tp, support


def get_multilabel_statistics(
    outputs: Tensor, targets: Tensor,
) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
    """
    Computes the number of true negatives, false positives, false negatives,
    true positives and support for a multilabel classification problem.

    Args:
        outputs: estimated targets as predicted by a model
            with shape [bs; ..., (num_classes or 1)]
        targets: ground truth (correct) target values
            with shape [bs; ..., (num_classes or 1)]

    Returns:
        Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]: stats

    Example:
        >>> y_pred = torch.tensor([[0, 0, 1, 1], [0, 1, 0, 1]])
        >>> y_true = torch.tensor([[0, 1, 0, 1], [0, 0, 1, 1]])
        >>> tn, fp, fn, tp, support = get_multilabel_statistics(y_pred, y_true)
        tensor([2., 0., 0., 0.]) tensor([0., 1., 1., 0.])
        tensor([0., 1., 1., 0.]) tensor([0., 0., 0., 2.])
        tensor([0., 1., 1., 2.])

        >>> y_pred = torch.tensor([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
        >>> y_true = torch.tensor([0, 1, 2])
        >>> tn, fp, fn, tp, support = get_multilabel_statistics(y_pred, y_true)
        tensor([2., 2., 2.]) tensor([0., 0., 0.]) tensor([0., 0., 0.])
        tensor([1., 1., 1.]) tensor([1., 1., 1.])

        >>> y_pred = torch.tensor([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
        >>> y_true = torch.nn.functional.one_hot(torch.tensor([0, 1, 2]))
        >>> tn, fp, fn, tp, support = get_multilabel_statistics(y_pred, y_true)
        tensor([2., 2., 2.]) tensor([0., 0., 0.]) tensor([0., 0., 0.])
        tensor([1., 1., 1.]) tensor([1., 1., 1.])
    """
    outputs, targets, _ = process_multilabel_components(
        outputs=outputs, targets=targets
    )
    assert outputs.shape == targets.shape
    num_classes = outputs.shape[-1]

    tn = torch.zeros((num_classes,), device=outputs.device)
    fp = torch.zeros((num_classes,), device=outputs.device)
    fn = torch.zeros((num_classes,), device=outputs.device)
    tp = torch.zeros((num_classes,), device=outputs.device)
    support = torch.zeros((num_classes,), device=outputs.device)
    for class_index in range(num_classes):
        class_outputs = outputs[..., class_index]
        class_targets = targets[..., class_index]
        (
            tn[class_index],
            fp[class_index],
            fn[class_index],
            tp[class_index],
            support[class_index],
        ) = get_binary_statistics(
            outputs=class_outputs, targets=class_targets, label=1
        )
    return tn, fp, fn, tp, support


def get_default_topk_args(num_classes: int) -> Sequence[int]:
    """Calculate list params for ``Accuracy@k`` and ``mAP@k``.

    Args:
        num_classes: number of classes

    Returns:
        iterable: array of accuracy arguments

    Examples:
        >>> get_default_topk_args(num_classes=4)
        [1, 3]
        >>> get_default_topk_args(num_classes=8)
        [1, 3, 5]
    """
    result = [1]

    if num_classes is None:
        return result

    if num_classes > 3:
        result.append(3)
    if num_classes > 5:
        result.append(5)
    return result


def wrap_class_metric2dict(
    metric_fn: Callable, class_args: Sequence[str] = None
) -> Callable:
    """# noqa: D202
    Logging wrapper for metrics with torch.Tensor output
    and [num_classes] shape.
    Computes the metric and syncs each element of the output tensor
    with the passed ``class_args`` names.

    Args:
        metric_fn: metric function to compute
        class_args: class names for logging;
            default: None - class indexes will be used

    Returns:
        wrapped metric function with Dict output
    """

    def class_metric_with_dict_output(*args, **kwargs):
        output = metric_fn(*args, **kwargs)
        num_classes = len(output)
        output_class_args = class_args or [
            f"/class_{i:02}" for i in range(num_classes)
        ]
        mean_stats = torch.mean(output).item()
        output = {
            key: value.item()
            for key, value in zip(output_class_args, output)
        }
        output["/mean"] = mean_stats
        return output

    return class_metric_with_dict_output
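
# A minimal usage sketch for ``wrap_class_metric2dict`` (added for
# illustration; ``per_class_iou`` is a hypothetical stand-in metric):
#
#   >>> per_class_iou = lambda outputs, targets: torch.tensor([0.5, 0.25])
#   >>> metric_fn = wrap_class_metric2dict(per_class_iou, class_args=["cat", "dog"])
#   >>> metric_fn(None, None)
#   {'cat': 0.5, 'dog': 0.25, '/mean': 0.375}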


def wrap_topk_metric2dict(
    metric_fn: Callable, topk_args: Sequence[int]
) -> Callable:
    """
    Logging wrapper for metrics with
    Sequence[Union[torch.Tensor, int, float, Dict]] output.
    Computes the metric and syncs each element of the output sequence
    with the passed `topk` argument.

    Args:
        metric_fn: metric function to compute
        topk_args: topk args to sync outputs with

    Returns:
        wrapped metric function with Dict output

    Raises:
        NotImplementedError: if the metric's returned values are outside
            the Union[torch.Tensor, int, float, Dict].
    """
    metric_fn = partial(metric_fn, topk=topk_args)

    def topk_metric_with_dict_output(*args, **kwargs):
        output: Sequence = metric_fn(*args, **kwargs)

        if isinstance(output[0], (int, float, torch.Tensor)):
            output = {
                f"{topk_key:02}": metric_value
                for topk_key, metric_value in zip(topk_args, output)
            }
        elif isinstance(output[0], Dict):
            # flatten the per-topk dicts into a single dict,
            # suffixing each metric key with its topk value
            output = {
                f"{metric_key}{topk_key:02}": metric_value
                for topk_key, metric_dict_value in zip(topk_args, output)
                for metric_key, metric_value in metric_dict_value.items()
            }
        else:
            raise NotImplementedError()

        return output

    return topk_metric_with_dict_output
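
# A minimal usage sketch for ``wrap_topk_metric2dict`` (added for
# illustration; ``topk_accuracy_stub`` is a hypothetical metric that
# returns one score per requested k):
#
#   >>> def topk_accuracy_stub(outputs, targets, topk=(1,)):
#   ...     return [torch.tensor(0.5) for _ in topk]
#   >>> metric_fn = wrap_topk_metric2dict(topk_accuracy_stub, topk_args=[1, 3, 5])
#   >>> metric_fn(None, None)
#   {'01': tensor(0.5000), '03': tensor(0.5000), '05': tensor(0.5000)}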


__all__ = [
    "process_multilabel_components",
    "get_binary_statistics",
    "get_multiclass_statistics",
    "get_multilabel_statistics",
    "get_default_topk_args",
    "wrap_topk_metric2dict",
    "wrap_class_metric2dict",
]