# Source code for catalyst.callbacks.metrics.confusion_matrix

from typing import Dict, List, TYPE_CHECKING

from catalyst.core.callback import Callback, CallbackNode, CallbackOrder
from catalyst.metrics._confusion_matrix import ConfusionMatrixMetric
from catalyst.settings import SETTINGS

if SETTINGS.ml_required:
    from catalyst.contrib.utils.visualization import plot_confusion_matrix, render_figure_to_array

    from catalyst.core.runner import IRunner

class ConfusionMatrixCallback(Callback):
    """Callback to plot your confusion matrix to the loggers.

    Args:
        input_key: key to use from ``runner.batch``, specifies our ``y_pred``
        target_key: key to use from ``runner.batch``, specifies our ``y_true``
        prefix: plot name for monitoring tools
            (defaults to ``"confusion_matrix"``)
        class_names: list with class names; when non-empty it defines
            the number of classes and ``num_classes`` is ignored
        num_classes: number of classes; used to generate class names
            (``class_00``, ``class_01``, ...) when ``class_names``
            is not given or is empty
        normalized: boolean flag for confusion matrix normalization
        plot_params: extra params for plt.figure rendering

    Raises:
        ValueError: if neither ``num_classes`` nor a non-empty
            ``class_names`` is provided

    .. note::
        catalyst[ml] required for this callback

    Examples:

    .. code-block:: python

        import torch
        from torch.utils.data import DataLoader, TensorDataset
        from catalyst import dl

        # sample data
        num_samples, num_features, num_classes = int(1e4), int(1e1), 4
        X = torch.rand(num_samples, num_features)
        y = (torch.rand(num_samples,) * num_classes).to(torch.int64)

        # pytorch loaders
        dataset = TensorDataset(X, y)
        loader = DataLoader(dataset, batch_size=32, num_workers=1)
        loaders = {"train": loader, "valid": loader}

        # model, criterion, optimizer, scheduler
        model = torch.nn.Linear(num_features, num_classes)
        criterion = torch.nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam(model.parameters())
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [2])

        # model training
        runner = dl.SupervisedRunner(
            input_key="features", output_key="logits", target_key="targets", loss_key="loss"
        )
        runner.train(
            model=model,
            criterion=criterion,
            optimizer=optimizer,
            scheduler=scheduler,
            loaders=loaders,
            logdir="./logdir",
            num_epochs=3,
            valid_loader="valid",
            valid_metric="accuracy03",
            minimize_valid_metric=False,
            verbose=True,
            callbacks=[
                dl.AccuracyCallback(
                    input_key="logits", target_key="targets", num_classes=num_classes
                ),
                dl.PrecisionRecallF1SupportCallback(
                    input_key="logits", target_key="targets", num_classes=num_classes
                ),
                dl.AUCCallback(input_key="logits", target_key="targets"),
                dl.ConfusionMatrixCallback(
                    input_key="logits", target_key="targets", num_classes=num_classes
                ),
            ],
        )

    .. note::
        Please follow the `minimal examples`_ sections for more use cases.

        .. _`minimal examples`: https://github.com/catalyst-team/catalyst#minimal-examples
    """

    def __init__(
        self,
        input_key: str,
        target_key: str,
        prefix: str = None,
        class_names: List[str] = None,
        num_classes: int = None,
        normalized: bool = False,
        plot_params: Dict = None,
    ):
        """Callback initialisation."""
        super().__init__(CallbackOrder.metric, CallbackNode.all)
        # Resolve ``class_names`` / ``num_classes`` together so they always agree.
        # An empty ``class_names`` is treated like ``None``. Previously it
        # produced generated names for ``num_classes`` classes while setting
        # ``self.num_classes`` to 0.
        if class_names:
            num_classes = len(class_names)
        elif num_classes is not None:
            class_names = [f"class_{i:02d}" for i in range(num_classes)]
        else:
            # explicit raise instead of ``assert``: asserts are stripped under ``-O``
            raise ValueError(
                "ConfusionMatrixCallback requires `num_classes` or non-empty `class_names`"
            )

        self.prefix = prefix or "confusion_matrix"
        self.input_key = input_key
        self.target_key = target_key
        self._plot_params = plot_params or {}
        self.class_names = class_names
        self.num_classes = num_classes
        self.normalized = normalized
        self.confusion_matrix = ConfusionMatrixMetric(
            num_classes=self.num_classes, normalized=self.normalized
        )

    def on_loader_start(self, runner: "IRunner"):
        """Loader start hook: reset the accumulated confusion matrix.

        Args:
            runner: current runner
        """
        self.confusion_matrix.reset()

    def on_batch_end(self, runner: "IRunner"):
        """Batch end hook: accumulate the current batch into the confusion matrix.

        Args:
            runner: current runner
        """
        # detach so the metric state is not tied to the autograd graph
        inputs = runner.batch[self.input_key].detach()
        targets = runner.batch[self.target_key].detach()
        self.confusion_matrix.update(predictions=inputs, targets=targets)

    def on_loader_end(self, runner: "IRunner"):
        """Loader end hook: render the confusion matrix and log it as an image.

        Args:
            runner: current runner
        """
        confusion_matrix = self.confusion_matrix.compute()
        # ``plot_confusion_matrix`` / ``render_figure_to_array`` are only imported
        # when ``SETTINGS.ml_required`` is truthy (catalyst[ml] extras).
        fig = plot_confusion_matrix(
            confusion_matrix,
            class_names=self.class_names,
            normalize=self.normalized,
            show=False,
            **self._plot_params,
        )
        image = render_figure_to_array(fig)
        runner.log_image(tag=self.prefix, image=image, scope="loader")
# Public API of this module.
__all__ = [
    "ConfusionMatrixCallback",
]