Source code for catalyst.metrics._confusion_matrix

from typing import Any, List

import numpy as np
import torch

from catalyst.metrics._metric import IMetric
from catalyst.utils.distributed import all_gather, get_rank

class ConfusionMatrixMetric(IMetric):
    """Constructs a confusion matrix for a multiclass classification problem.

    Args:
        num_classes: number of classes in the classification problem
        normalized: determines whether or not the confusion matrix is normalized
        compute_on_call: Boolean flag to compute and return the confusion matrix
            during ``__call__``. default: True

    Examples:

    .. code-block:: python

        import torch
        from torch.utils.data import DataLoader, TensorDataset
        from catalyst import dl

        # sample data
        num_samples, num_features, num_classes = int(1e4), int(1e1), 4
        X = torch.rand(num_samples, num_features)
        y = (torch.rand(num_samples,) * num_classes).to(torch.int64)

        # pytorch loaders
        dataset = TensorDataset(X, y)
        loader = DataLoader(dataset, batch_size=32, num_workers=1)
        loaders = {"train": loader, "valid": loader}

        # model, criterion, optimizer, scheduler
        model = torch.nn.Linear(num_features, num_classes)
        criterion = torch.nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam(model.parameters())
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [2])

        # model training
        runner = dl.SupervisedRunner(
            input_key="features",
            output_key="logits",
            target_key="targets",
            loss_key="loss"
        )
        runner.train(
            model=model,
            criterion=criterion,
            optimizer=optimizer,
            scheduler=scheduler,
            loaders=loaders,
            logdir="./logdir",
            num_epochs=3,
            valid_loader="valid",
            valid_metric="accuracy03",
            minimize_valid_metric=False,
            verbose=True,
            callbacks=[
                dl.AccuracyCallback(
                    input_key="logits",
                    target_key="targets",
                    num_classes=num_classes
                ),
                dl.PrecisionRecallF1SupportCallback(
                    input_key="logits",
                    target_key="targets",
                    num_classes=num_classes
                ),
                dl.AUCCallback(input_key="logits", target_key="targets"),
                dl.ConfusionMatrixCallback(
                    input_key="logits",
                    target_key="targets",
                    num_classes=num_classes
                ),
            ],
        )

    .. note::
        Please follow the `minimal examples`_ sections for more use cases.

        .. _`minimal examples`: https://github.com/catalyst-team/catalyst#minimal-examples
    """

    def __init__(
        self,
        num_classes: int,
        normalized: bool = False,
        compute_on_call: bool = True,
    ):
        """Constructs a confusion matrix for a multiclass classification problem."""
        super().__init__(compute_on_call=compute_on_call)
        self.num_classes = num_classes
        self.normalized = normalized
        # Zero-initialised accumulator; rows index targets, columns index
        # predictions (see ``update``).
        self.conf = np.zeros((num_classes, num_classes), dtype=np.int32)
        self._is_ddp = False
        self.reset()

    def reset(self) -> None:
        """Reset confusion matrix, filling it with zeros."""
        self.conf.fill(0)
        # The DDP flag is refreshed on every reset from the current rank.
        self._is_ddp = get_rank() > -1

    def update(self, predictions: torch.Tensor, targets: torch.Tensor) -> None:
        """Accumulate a batch into the ``K x K`` confusion matrix (``K`` classes).

        Args:
            predictions: Can be an N x K tensor of predicted scores
                obtained from the model for N examples and K classes
                or an N-tensor of integer values between 0 and K-1
            targets: Can be an N-tensor of integer values between 0 and K-1
                or an N x K tensor, where targets are assumed to be provided
                as one-hot vectors

        Raises:
            AssertionError: if shapes or value ranges of the inputs do not
                match the configured number of classes.
        """
        # ``detach`` so that tensors which still require grad can be converted.
        predictions = predictions.detach().cpu().numpy()
        targets = targets.detach().cpu().numpy()

        assert (
            predictions.shape[0] == targets.shape[0]
        ), "number of targets and predicted outputs do not match"

        # An empty batch contributes nothing; returning here also avoids
        # calling ``max``/``min`` on zero-size arrays below.
        if predictions.shape[0] == 0:
            return

        if np.ndim(predictions) != 1:
            assert (
                predictions.shape[1] == self.num_classes
            ), "number of predictions does not match size of confusion matrix"
            predictions = np.argmax(predictions, 1)
        else:
            assert (predictions.max() < self.num_classes) and (
                predictions.min() >= 0
            ), "predicted values are not between 0 and k-1"

        onehot_target = np.ndim(targets) != 1
        if onehot_target:
            assert (
                targets.shape[1] == self.num_classes
            ), "Onehot target does not match size of confusion matrix"
            assert (targets >= 0).all() and (
                targets <= 1
            ).all(), "in one-hot encoding, target values should be 0 or 1"
            assert (targets.sum(1) == 1).all(), "multilabel setting is not supported"
            targets = np.argmax(targets, 1)
        else:
            # Validate the targets themselves (not the predictions).
            assert (targets.max() < self.num_classes) and (
                targets.min() >= 0
            ), "target values are not between 0 and k-1"

        # hack for bincounting 2 arrays together: encode each
        # (target, prediction) pair as one flat index target * K + prediction.
        x = predictions + self.num_classes * targets
        # ``np.intp`` is the index dtype ``bincount`` works with and is wide
        # enough for the flat index on the running platform.
        bincount_2d = np.bincount(
            x.astype(np.intp), minlength=self.num_classes ** 2
        )  # noqa: WPS114
        assert bincount_2d.size == self.num_classes ** 2
        conf = bincount_2d.reshape((self.num_classes, self.num_classes))

        self.conf += conf

    def compute(self) -> Any:
        """Return the accumulated confusion matrix.

        Returns:
            Confusion matrix of K rows and K columns, where rows corresponds
            to ground-truth targets and columns corresponds to predicted
            targets. When ``normalized`` is set, every row is divided by its
            sum (floored at 1e-12) and the result is float32.
        """
        if self._is_ddp:
            # Sum the per-process matrices element-wise.
            gathered: List[np.ndarray] = all_gather(self.conf)
            total: np.ndarray = np.sum(np.stack(gathered, axis=0), axis=0)
            # NOTE(review): the gathered sum replaces the local accumulator, so
            # calling ``compute`` again without ``reset`` would presumably
            # gather already-summed counts - confirm callers always reset.
            self.conf = total
        if self.normalized:
            conf = self.conf.astype(np.float32)
            # Floor the row sums so that empty rows do not divide by zero.
            return conf / conf.sum(1).clip(min=1e-12)[:, None]
        else:
            return self.conf
# Names exported by a star-import of this module.
__all__ = ["ConfusionMatrixMetric"]