# Source code for catalyst.metrics._confusion_matrix

from typing import Any, List

import numpy as np

import torch

from catalyst.metrics._metric import IMetric
from catalyst.settings import SETTINGS
from catalyst.utils import get_device
from catalyst.utils.distributed import all_gather, get_backend

if SETTINGS.xla_required:
    import torch_xla.core.xla_model as xm


class ConfusionMatrixMetric(IMetric):
    """Constructs a confusion matrix for a multiclass classification problem.

    Rows of the computed matrix correspond to ground-truth classes and
    columns correspond to predicted classes.

    Args:
        num_classes: number of classes in the classification problem
        normalize: if True, ``compute`` returns the matrix with every row
            divided by its sum (rows with no samples stay all-zero);
            otherwise raw integer counts are returned
        compute_on_call: Boolean flag to compute and return the confusion
            matrix during ``__call__``. default: True

    Examples:

    .. code-block:: python

        import torch
        from torch.utils.data import DataLoader, TensorDataset
        from catalyst import dl

        # sample data
        num_samples, num_features, num_classes = int(1e4), int(1e1), 4
        X = torch.rand(num_samples, num_features)
        y = (torch.rand(num_samples,) * num_classes).to(torch.int64)

        # pytorch loaders
        dataset = TensorDataset(X, y)
        loader = DataLoader(dataset, batch_size=32, num_workers=1)
        loaders = {"train": loader, "valid": loader}

        # model, criterion, optimizer, scheduler
        model = torch.nn.Linear(num_features, num_classes)
        criterion = torch.nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam(model.parameters())
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [2])

        # model training
        runner = dl.SupervisedRunner(
            input_key="features", output_key="logits", target_key="targets", loss_key="loss"
        )
        runner.train(
            model=model,
            criterion=criterion,
            optimizer=optimizer,
            scheduler=scheduler,
            loaders=loaders,
            logdir="./logdir",
            num_epochs=3,
            valid_loader="valid",
            valid_metric="accuracy03",
            minimize_valid_metric=False,
            verbose=True,
            callbacks=[
                dl.AccuracyCallback(
                    input_key="logits", target_key="targets", num_classes=num_classes
                ),
                dl.PrecisionRecallF1SupportCallback(
                    input_key="logits", target_key="targets", num_classes=num_classes
                ),
                dl.AUCCallback(input_key="logits", target_key="targets"),
                dl.ConfusionMatrixCallback(
                    input_key="logits", target_key="targets", num_classes=num_classes
                ),
            ],
        )

    .. note::
        Please follow the `minimal examples`_ sections for more use cases.

        .. _`minimal examples`: https://github.com/catalyst-team/catalyst#minimal-examples  # noqa: E501, W505
    """

    def __init__(
        self, num_classes: int, normalize: bool = False, compute_on_call: bool = True
    ):
        """Constructs a confusion matrix for a multiclass classification problem."""
        super().__init__(compute_on_call=compute_on_call)
        self.num_classes = num_classes
        self.normalize = normalize
        # Accumulated (target, prediction) counts; rows = targets, cols = predictions.
        self.conf = np.zeros((num_classes, num_classes), dtype=np.int32)
        # Distributed backend name, refreshed on every ``reset``.
        self._ddp_backend = None
        self.reset()

    def reset(self) -> None:
        """Reset confusion matrix, filling it with zeros."""
        self.conf.fill(0)
        self._ddp_backend = get_backend()

    def _in_range(self, labels: np.ndarray) -> bool:
        """Return True if every label lies in ``[0, num_classes - 1]``.

        An empty array is vacuously in range, so empty batches do not hit
        ``min``/``max`` reductions on a zero-size array (which raise).
        """
        if labels.size == 0:
            return True
        return bool(labels.min() >= 0 and labels.max() < self.num_classes)

    def update(self, predictions: torch.Tensor, targets: torch.Tensor) -> None:
        """Accumulate one batch into the ``K x K`` confusion matrix (``K`` = num classes).

        Args:
            predictions: either an ``N x K`` tensor of per-class scores (the
                argmax over dim 1 is used as the predicted class) or an
                ``N``-tensor of integer class indices in ``[0, K-1]``
            targets: either an ``N``-tensor of integer class indices in
                ``[0, K-1]`` or an ``N x K`` tensor of one-hot vectors

        NOTE(review): input validation is done with ``assert`` and is therefore
        skipped when Python runs with ``-O``.
        """
        predictions = predictions.cpu().numpy()
        targets = targets.cpu().numpy()

        assert (
            predictions.shape[0] == targets.shape[0]
        ), "number of targets and predicted outputs do not match"

        if np.ndim(predictions) != 1:
            assert (
                predictions.shape[1] == self.num_classes
            ), "number of predictions does not match size of confusion matrix"
            predictions = np.argmax(predictions, 1)
        else:
            assert self._in_range(
                predictions
            ), "predicted values are not between 0 and k-1"

        onehot_target = np.ndim(targets) != 1
        if onehot_target:
            assert (
                targets.shape[1] == self.num_classes
            ), "Onehot target does not match size of confusion matrix"
            assert (targets >= 0).all() and (
                targets <= 1
            ).all(), "in one-hot encoding, target values should be 0 or 1"
            assert (targets.sum(1) == 1).all(), "multilabel setting is not supported"
            targets = np.argmax(targets, 1)
        else:
            # Validate the *targets* here; an out-of-range target would corrupt
            # the flattened index below.
            assert self._in_range(targets), "target values are not between 0 and k-1"

        # Encode each (target, prediction) pair as ``target * K + prediction`` so
        # a single bincount fills the flattened, row-major ``K x K`` matrix.
        flat_index = predictions + self.num_classes * targets
        bincount_2d = np.bincount(
            flat_index.astype(np.int32), minlength=self.num_classes ** 2
        )
        assert bincount_2d.size == self.num_classes ** 2
        batch_conf = bincount_2d.reshape((self.num_classes, self.num_classes))

        self.conf += batch_conf

    def compute(self) -> np.ndarray:
        """Return the confusion matrix accumulated since the last ``reset``.

        Under a distributed backend (``"xla"`` or ``"ddp"``), the per-process
        matrices are gathered and summed. The result is returned without being
        written back to ``self.conf``. Because of that, calling ``compute``
        several times before ``reset`` does not re-sum already aggregated
        counts.

        Returns:
            Confusion matrix of K rows and K columns, where rows correspond
            to ground-truth targets and columns correspond to predicted
            targets. If ``normalize`` is set, each row is divided by its sum
            (float32); otherwise raw counts are returned.
        """
        conf = self.conf
        # ddp hotfix, could be done better
        # but metric must handle DDP on its own
        if self._ddp_backend == "xla":
            # if you have "RuntimeError: Aborted: Session XXX is not found" here
            # please, ask Google for a more powerful TPU setup ;)
            device = get_device()
            # Shape (1, K, K) so all_gather stacks processes along dim 0.
            value = torch.as_tensor(self.conf[None], device=device)
            conf = xm.all_gather(value).sum(0).cpu().detach().numpy()
        elif self._ddp_backend == "ddp":
            gathered: List[np.ndarray] = all_gather(self.conf)
            conf = np.sum(np.stack(gathered, axis=0), axis=0)

        if self.normalize:
            conf = conf.astype(np.float32)
            # clip avoids division by zero for classes with no ground-truth samples
            return conf / conf.sum(1).clip(min=1e-12)[:, None]
        return conf


__all__ = ["ConfusionMatrixMetric"]