Shortcuts

Source code for catalyst.dl.callbacks.confusion_matrix

from typing import Dict, List

import numpy as np
from sklearn.metrics import confusion_matrix as confusion_matrix_fn

import torch
import torch.distributed

from catalyst.dl import Callback, CallbackNode, CallbackOrder, State, utils
from catalyst.tools import meters


class ConfusionMatrixCallback(Callback):
    """Accumulate a confusion matrix per loader and log it as an image.

    On every batch, model outputs are read from ``state.output[output_key]``
    and targets from ``state.input[input_key]``. At the end of each loader the
    accumulated matrix is summed onto rank 0 (in distributed mode). It is then
    plotted to the logger that the tensorboard callback registered under
    ``tensorboard_callback_name`` keeps for the current loader.
    """

    def __init__(
        self,
        input_key: str = "targets",
        output_key: str = "logits",
        prefix: str = "confusion_matrix",
        version: str = "tnt",
        class_names: List[str] = None,
        num_classes: int = None,
        plot_params: Dict = None,
        tensorboard_callback_name: str = "_tensorboard",
    ):
        """
        Args:
            input_key: key in ``state.input`` holding the ground-truth labels.
            output_key: key in ``state.output`` holding the model outputs.
            prefix: tag prefix of the logged image (``f"{prefix}/epoch"``).
            version: accumulation backend. ``"tnt"`` uses
                ``meters.ConfusionMeter``; ``"sklearn"`` collects argmax
                predictions and calls ``sklearn.metrics.confusion_matrix``.
            class_names: class names used as plot labels. When given, its
                length overrides ``num_classes``.
            num_classes: number of classes. Required when ``class_names``
                is None.
            plot_params: extra keyword arguments forwarded to
                ``utils.plot_confusion_matrix``.
            tensorboard_callback_name: key of the tensorboard callback in
                ``state.callbacks`` whose ``loggers`` receive the plot.
        """
        super().__init__(CallbackOrder.Metric, CallbackNode.All)
        self.prefix = prefix
        self.output_key = output_key
        self.input_key = input_key
        self.tensorboard_callback_name = tensorboard_callback_name

        assert version in ["tnt", "sklearn"], (
            f"version must be 'tnt' or 'sklearn', got {version!r}"
        )
        self._version = version
        self._plot_params = plot_params or {}

        self.class_names = class_names
        self.num_classes = (
            num_classes if class_names is None else len(class_names)
        )
        assert self.num_classes is not None, (
            "either num_classes or class_names must be provided"
        )
        self._reset_stats()

    def _reset_stats(self):
        """Drop all accumulated statistics (called at every loader start)."""
        if self._version == "tnt":
            self.confusion_matrix = meters.ConfusionMeter(self.num_classes)
        elif self._version == "sklearn":
            self.outputs = []
            self.targets = []

    def _add_to_stats(self, outputs, targets):
        """Add one batch of outputs/targets to the running statistics."""
        if self._version == "tnt":
            self.confusion_matrix.add(predicted=outputs, target=targets)
        elif self._version == "sklearn":
            outputs = outputs.cpu().numpy()
            targets = targets.cpu().numpy()
            # Predicted class = argmax over axis 1. This assumes outputs are
            # per-class scores of shape (batch, num_classes); TODO confirm.
            outputs = np.argmax(outputs, axis=1)
            self.outputs.extend(outputs)
            self.targets.extend(targets)

    def _compute_confusion_matrix(self):
        """Return the accumulated confusion matrix as a numpy array.

        Raises:
            NotImplementedError: if the configured version is unknown.
        """
        if self._version == "tnt":
            confusion_matrix = self.confusion_matrix.value()
        elif self._version == "sklearn":
            # Pin the label set so the matrix is always
            # num_classes x num_classes. Without it, sklearn infers labels
            # from the data seen: classes absent from this loader would shrink
            # the matrix, misalign it with class_names, and give ranks
            # mismatched shapes for torch.distributed.reduce.
            confusion_matrix = confusion_matrix_fn(
                y_true=self.targets,
                y_pred=self.outputs,
                labels=list(range(self.num_classes)),
            )
        else:
            raise NotImplementedError()
        return confusion_matrix

    def _plot_confusion_matrix(
        self, logger, epoch, confusion_matrix, class_names=None
    ):
        """Render the (normalized) matrix and log it as an image.

        The image is logged under ``f"{self.prefix}/epoch"`` at step
        ``epoch``.
        """
        fig = utils.plot_confusion_matrix(
            confusion_matrix,
            class_names=class_names,
            normalize=True,
            show=False,
            **self._plot_params,
        )
        fig = utils.render_figure_to_tensor(fig)
        logger.add_image(f"{self.prefix}/epoch", fig, global_step=epoch)

    def on_loader_start(self, state: State):
        """Loader start hook: reset the accumulated statistics.

        Args:
            state (State): current state
        """
        self._reset_stats()

    def on_batch_end(self, state: State):
        """Batch end hook: accumulate the batch's outputs and targets.

        Args:
            state (State): current state
        """
        self._add_to_stats(
            state.output[self.output_key].detach(),
            state.input[self.input_key].detach(),
        )

    def on_loader_end(self, state: State):
        """Loader end hook: aggregate the matrix and plot it on rank 0.

        Args:
            state (State): current state
        """
        class_names = self.class_names or [
            str(i) for i in range(self.num_classes)
        ]
        confusion_matrix = self._compute_confusion_matrix()

        # A negative rank is treated as a non-distributed run.
        if state.distributed_rank >= 0:
            # Sum the per-rank matrices onto rank 0. Every rank takes part in
            # this collective call, and all tensors must share one shape.
            confusion_matrix = torch.from_numpy(confusion_matrix)
            confusion_matrix = confusion_matrix.to(utils.get_device())
            torch.distributed.reduce(confusion_matrix, 0)
            confusion_matrix = confusion_matrix.cpu().numpy()

        if state.distributed_rank <= 0:
            tb_callback = state.callbacks[self.tensorboard_callback_name]
            self._plot_confusion_matrix(
                logger=tb_callback.loggers[state.loader_name],
                epoch=state.global_epoch,
                confusion_matrix=confusion_matrix,
                class_names=class_names,
            )
# Names exported by ``from <this module> import *``.
__all__ = [
    "ConfusionMatrixCallback",
]