# Source code for catalyst.dl.callbacks.confusion_matrix
from typing import Dict, List
import numpy as np
from sklearn.metrics import confusion_matrix as confusion_matrix_fn
import torch
import torch.distributed
from catalyst.dl import Callback, CallbackNode, CallbackOrder, State, utils
from catalyst.tools import meters
class ConfusionMatrixCallback(Callback):
    """Accumulate a confusion matrix per loader and log it as an image.

    After every batch, model outputs are read from
    ``state.output[output_key]`` and targets from ``state.input[input_key]``.
    At the end of each loader, several steps follow:

    * If running distributed, the matrix is summed onto rank 0.
    * The matrix is drawn with ``utils.plot_confusion_matrix``.
    * The figure is written as an image through the logger that the
      tensorboard callback (found under ``tensorboard_callback_name``)
      keeps for the current loader.
    """

    def __init__(
        self,
        input_key: str = "targets",
        output_key: str = "logits",
        prefix: str = "confusion_matrix",
        version: str = "tnt",
        class_names: List[str] = None,
        num_classes: int = None,
        plot_params: Dict = None,
        tensorboard_callback_name: str = "_tensorboard",
    ):
        """
        Args:
            input_key: key in ``state.input`` holding the ground-truth
                targets.
            output_key: key in ``state.output`` holding the model outputs.
                The "sklearn" version takes ``argmax`` over axis 1 of them.
            prefix: tag prefix of the logged image (``f"{prefix}/epoch"``).
            version: accumulation backend. Use ``"tnt"`` for
                ``meters.ConfusionMeter`` or ``"sklearn"`` for
                ``sklearn.metrics.confusion_matrix``.
            class_names: axis labels for the plot. When given, its length
                overrides ``num_classes``.
            num_classes: number of classes. Required if ``class_names`` is
                not given.
            plot_params: extra keyword arguments forwarded to
                ``utils.plot_confusion_matrix``.
            tensorboard_callback_name: key of the tensorboard callback in
                ``state.callbacks`` whose per-loader loggers receive the plot.
        """
        super().__init__(CallbackOrder.Metric, CallbackNode.All)
        self.prefix = prefix
        self.output_key = output_key
        self.input_key = input_key
        self.tensorboard_callback_name = tensorboard_callback_name
        assert version in ["tnt", "sklearn"], (
            f"`version` must be 'tnt' or 'sklearn', got {version!r}"
        )
        self._version = version
        self._plot_params = plot_params or {}
        self.class_names = class_names
        # class_names, when provided, is the source of truth for the size.
        self.num_classes = (
            num_classes if class_names is None else len(class_names)
        )
        assert self.num_classes is not None, (
            "either `num_classes` or `class_names` must be specified"
        )
        self._reset_stats()

    def _reset_stats(self):
        """Drop all accumulated statistics (called at every loader start)."""
        if self._version == "tnt":
            self.confusion_matrix = meters.ConfusionMeter(self.num_classes)
        elif self._version == "sklearn":
            # Flat lists of per-sample predicted / true labels.
            self.outputs = []
            self.targets = []

    def _add_to_stats(self, outputs, targets):
        """Add one batch of outputs and targets to the running statistics.

        Args:
            outputs: model outputs tensor. For "sklearn", the predicted label
                is ``argmax`` over axis 1, which presumably means shape
                (batch, num_classes). TODO: confirm.
            targets: ground-truth label tensor.
        """
        if self._version == "tnt":
            self.confusion_matrix.add(predicted=outputs, target=targets)
        elif self._version == "sklearn":
            outputs = outputs.cpu().numpy()
            targets = targets.cpu().numpy()
            outputs = np.argmax(outputs, axis=1)
            self.outputs.extend(outputs)
            self.targets.extend(targets)

    def _compute_confusion_matrix(self):
        """Return the confusion matrix accumulated so far as a numpy array.

        Raises:
            NotImplementedError: if ``self._version`` is not a known backend.
        """
        if self._version == "tnt":
            confusion_matrix = self.confusion_matrix.value()
        elif self._version == "sklearn":
            # Pin the label set so the result is always
            # (num_classes, num_classes). Without ``labels``, sklearn uses only
            # the labels seen in this loader. Rows and columns would then drift
            # out of alignment with ``class_names``, shapes could differ
            # between ranks in the distributed reduce, and an empty loader
            # would fail instead of yielding zeros.
            confusion_matrix = confusion_matrix_fn(
                y_true=self.targets,
                y_pred=self.outputs,
                labels=list(range(self.num_classes)),
            )
        else:
            raise NotImplementedError()
        return confusion_matrix

    def _plot_confusion_matrix(
        self, logger, epoch, confusion_matrix, class_names=None
    ):
        """Render ``confusion_matrix`` and log it via ``logger.add_image``.

        The plot is normalized (``normalize=True``). The image is tagged
        ``f"{self.prefix}/epoch"`` and logged at ``global_step=epoch``.
        """
        fig = utils.plot_confusion_matrix(
            confusion_matrix,
            class_names=class_names,
            normalize=True,
            show=False,
            **self._plot_params,
        )
        fig = utils.render_figure_to_tensor(fig)
        logger.add_image(f"{self.prefix}/epoch", fig, global_step=epoch)

    def on_loader_start(self, state: State):
        """Loader start hook: reset the accumulated statistics.

        Args:
            state (State): current state
        """
        self._reset_stats()

    def on_loader_end(self, state: State):
        """Loader end hook: aggregate across workers and plot on rank <= 0.

        Args:
            state (State): current state
        """
        class_names = self.class_names or [
            str(i) for i in range(self.num_classes)
        ]
        confusion_matrix = self._compute_confusion_matrix()
        if state.distributed_rank >= 0:
            # ``torch.tensor`` always copies. The in-place reduce below
            # therefore cannot write into an array the accumulator may still
            # own. ``torch.from_numpy`` shares memory, and ``.to()`` is a
            # no-op when the device is CPU.
            confusion_matrix = torch.tensor(
                confusion_matrix, device=utils.get_device()
            )
            # Default reduce op is SUM; the result is only valid on rank 0.
            torch.distributed.reduce(confusion_matrix, 0)
            confusion_matrix = confusion_matrix.cpu().numpy()
        if state.distributed_rank <= 0:
            tb_callback = state.callbacks[self.tensorboard_callback_name]
            self._plot_confusion_matrix(
                logger=tb_callback.loggers[state.loader_name],
                epoch=state.global_epoch,
                confusion_matrix=confusion_matrix,
                class_names=class_names,
            )

    def on_batch_end(self, state: State):
        """Batch end hook: add this batch's outputs/targets to the stats.

        Args:
            state (State): current state
        """
        self._add_to_stats(
            state.output[self.output_key].detach(),
            state.input[self.input_key].detach(),
        )
# Names exported by ``from ... import *``: only the callback class.
__all__ = ["ConfusionMatrixCallback"]