# Source code for catalyst.tools.meters.confusionmeter
"""
Maintains a confusion matrix for a given classification problem.
"""
import numpy as np
import torch
from catalyst.tools.meters import meter
class ConfusionMeter(meter.Meter):
    """
    ConfusionMeter constructs a confusion matrix for multiclass
    classification problems. It does not support multilabel, multiclass
    problems: for such problems, please use MultiLabelConfusionMeter.
    """

    def __init__(self, k: int, normalized: bool = False):
        """
        Args:
            k: number of classes in the classification problem
            normalized: determines whether or not the confusion
                matrix returned by ``value()`` is row-normalized
        """
        super().__init__()
        # K x K integer counts: rows index ground-truth classes,
        # columns index predicted classes (see ``add``).
        self.conf = np.zeros((k, k), dtype=np.int32)
        self.normalized = normalized
        self.k = k
        self.reset()

    def reset(self) -> None:
        """Reset confusion matrix, filling it with zeros."""
        self.conf.fill(0)

    def _in_range(self, values: np.ndarray) -> bool:
        """Return True when every entry of ``values`` lies in ``[0, k)``."""
        # An empty batch trivially passes; ``max``/``min`` would raise on it.
        if values.size == 0:
            return True
        return bool(values.max() < self.k and values.min() >= 0)

    def add(self, predicted: torch.Tensor, target: torch.Tensor) -> None:
        """Accumulate a batch into the ``K x K`` confusion matrix,
        where ``K`` is the number of classes.

        Args:
            predicted: Can be an N x K tensor of predicted scores
                obtained from the model for N examples and K classes
                (the argmax over dim 1 is taken as the prediction)
                or an N-tensor of integer values between 0 and K-1
            target: Can be an N-tensor of integer values between
                0 and K-1 or an N x K tensor, where targets are assumed
                to be provided as one-hot vectors

        Raises:
            AssertionError: if the batch sizes differ, the class dimension
                does not match ``k``, label values fall outside
                ``[0, k-1]``, or a one-hot target is not exactly one-hot.
        """
        predicted = predicted.cpu().numpy()
        target = target.cpu().numpy()

        assert (
            predicted.shape[0] == target.shape[0]
        ), "number of targets and predicted outputs do not match"

        if np.ndim(predicted) != 1:
            assert (
                predicted.shape[1] == self.k
            ), "number of predictions does not match size of confusion matrix"
            predicted = np.argmax(predicted, 1)
        else:
            assert self._in_range(
                predicted
            ), "predicted values are not between 0 and k-1"

        if np.ndim(target) != 1:
            assert (
                target.shape[1] == self.k
            ), "Onehot target does not match size of confusion matrix"
            assert (target >= 0).all() and (
                target <= 1
            ).all(), "in one-hot encoding, target values should be 0 or 1"
            assert (
                target.sum(1) == 1
            ).all(), "multilabel setting is not supported"
            target = np.argmax(target, 1)
        else:
            # Validate the *targets* here: an out-of-range target would
            # otherwise shift counts into the wrong row of the matrix.
            assert self._in_range(
                target
            ), "target values are not between 0 and k-1"

        # Encode each (target, predicted) pair as a single flat index
        # ``target * k + predicted`` so one bincount builds the 2-D histogram.
        x = predicted + self.k * target
        bincount_2d = np.bincount(  # noqa: WPS114
            x.astype(np.int32), minlength=self.k ** 2
        )
        assert bincount_2d.size == self.k ** 2
        # Row-major reshape: row = target class, column = predicted class.
        conf = bincount_2d.reshape((self.k, self.k))
        self.conf += conf

    def value(self) -> np.ndarray:
        """
        Returns:
            Confusion matrix of K rows and K columns, where rows correspond
            to ground-truth targets and columns correspond to predicted
            targets. When ``normalized`` is set, each row is divided by its
            sum (float32; all-zero rows stay zero). Otherwise the internal
            integer count matrix itself is returned (not a copy).
        """
        if self.normalized:
            conf = self.conf.astype(np.float32)
            # clip avoids division by zero for classes never seen as target
            return conf / conf.sum(1).clip(min=1e-12)[:, None]
        return self.conf
# Names exported by ``from ... import *``.
__all__ = [
    "ConfusionMeter",
]