# Source code for catalyst.contrib.utils.confusion_matrix

# flake8: noqa
# @TODO: code formatting issue for 20.07 release
from typing import Dict

import numpy as np

import torch

def calculate_tp_fp_fn(confusion_matrix: np.ndarray) -> Dict[str, np.ndarray]:
    """Compute per-class true positives, false positives and false negatives.

    The matrix is expected in the layout produced by
    ``calculate_confusion_matrix_from_arrays``: rows index the ground-truth
    class and columns index the predicted class.

    Args:
        confusion_matrix (np.ndarray): square ``(num_classes, num_classes)``
            confusion matrix (anything accepted by ``np.asarray`` works)

    Returns:
        Dict[str, np.ndarray]: dictionary with keys ``"true_positives"``,
        ``"false_positives"`` and ``"false_negatives"``; each value is a
        1-D array of length ``num_classes``

    Raises:
        ValueError: if ``confusion_matrix`` is not a square 2-D matrix
    """
    confusion_matrix = np.asarray(confusion_matrix)
    # np.diag on a 1-D input would *build* a matrix instead of extracting a
    # diagonal, and a non-square input makes row/column sums mismatch, so
    # reject anything that is not square 2-D up front.
    if (
        confusion_matrix.ndim != 2
        or confusion_matrix.shape[0] != confusion_matrix.shape[1]
    ):
        raise ValueError(
            "confusion_matrix must be a square 2-D matrix, "
            f"got shape {confusion_matrix.shape}"
        )

    true_positives = np.diag(confusion_matrix)
    # column k sums everything predicted as class k; removing the diagonal
    # leaves the samples wrongly predicted as k
    false_positives = confusion_matrix.sum(axis=0) - true_positives
    # row k sums everything whose true class is k; removing the diagonal
    # leaves the samples of class k that were missed
    false_negatives = confusion_matrix.sum(axis=1) - true_positives
    return {
        "true_positives": true_positives,
        "false_positives": false_positives,
        "false_negatives": false_negatives,
    }
def calculate_confusion_matrix_from_arrays(
    predictions: np.ndarray, labels: np.ndarray, num_classes: int
) -> np.ndarray:
    """Calculate confusion matrix for a given set of classes.

    Rows of the result index the ground-truth label and columns index the
    prediction, so ``confusion_matrix[i, j]`` counts samples of class ``i``
    that were predicted as class ``j``.

    A (label, prediction) pair is excluded if either value lies outside
    of ``[0, num_classes)``.

    Args:
        predictions (np.ndarray): predicted class indices, any shape
            (flattened before counting)
        labels (np.ndarray): ground truth class indices with the same number
            of elements as ``predictions``
        num_classes (int): number of classes

    Returns:
        np.ndarray: ``(num_classes, num_classes)`` confusion matrix with
        dtype ``np.uint64``

    Raises:
        ValueError: if ``predictions`` and ``labels`` differ in size
    """
    # @TODO: add `num_class`=None handling
    labels = np.asarray(labels).ravel()
    predictions = np.asarray(predictions).ravel()
    if labels.size != predictions.size:
        raise ValueError(
            "predictions and labels must have the same number of elements, "
            f"got {predictions.size} and {labels.size}"
        )

    # Keep only pairs where BOTH values are valid class indices. Checking
    # just the labels' upper bound is not enough: e.g. a prediction equal to
    # num_classes would otherwise be tallied into the last class.
    valid = (
        (labels >= 0)
        & (labels < num_classes)
        & (predictions >= 0)
        & (predictions < num_classes)
    )
    # casting the (non-negative) values truncates them, i.e. each value in
    # [k, k + 1) is assigned to class k
    rows = labels[valid].astype(np.int64)
    cols = predictions[valid].astype(np.int64)

    # encode each (row, col) pair as a single flat index and count them
    counts = np.bincount(
        rows * num_classes + cols, minlength=num_classes * num_classes
    )
    return counts.reshape(num_classes, num_classes).astype(np.uint64)
def calculate_confusion_matrix_from_tensors(
    y_pred_logits: torch.Tensor, y_true: torch.Tensor
) -> np.ndarray:
    """Build a confusion matrix from model logits and target labels.

    The predicted class is the argmax of ``y_pred_logits`` along dimension 1,
    and the number of classes is taken from the size of that dimension.
    Both tensors are moved to CPU and converted to numpy before counting.

    Args:
        y_pred_logits: model logits with class scores along dimension 1
        y_true: true labels

    Returns:
        np.ndarray: confusion matrix
    """
    n_classes = y_pred_logits.shape[1]
    predicted_labels = y_pred_logits.argmax(dim=1).cpu().numpy()
    true_labels = y_true.cpu().numpy()
    return calculate_confusion_matrix_from_arrays(
        predictions=predicted_labels,
        labels=true_labels,
        num_classes=n_classes,
    )
# Names exported by ``from ... import *``.
__all__ = [
    "calculate_tp_fp_fn",
    "calculate_confusion_matrix_from_arrays",
    "calculate_confusion_matrix_from_tensors",
]