Source code for catalyst.dl.callbacks.misc

from typing import Dict, List  # isort:skip

import numpy as np
from sklearn.metrics import confusion_matrix as confusion_matrix_fn

from catalyst.dl import utils
from catalyst.dl.core import (
    Callback, CallbackOrder, LoggerCallback, RunnerState
)
from catalyst.dl.meters import ConfusionMeter


class EarlyStoppingCallback(Callback):
    """Request an early stop when a validation metric stops improving.

    After every non-inference epoch the callback reads
    ``state.metrics.valid_values[metric]`` and compares it with the best
    value seen so far. Once ``patience`` consecutive epochs pass without an
    improvement of at least ``min_delta``, ``state.early_stop`` is set to
    ``True``.

    Args:
        patience: number of consecutive non-improving epochs after which
            the stop flag is raised.
        metric: key looked up in ``state.metrics.valid_values``.
        minimize: if ``True`` a lower score is better, otherwise a higher
            score is better.
        min_delta: minimum change of the score (in the "better" direction)
            that counts as an improvement.
    """

    def __init__(
        self,
        patience: int,
        metric: str = "loss",
        minimize: bool = True,
        min_delta: float = 1e-6
    ):
        super().__init__(CallbackOrder.External)
        self.best_score = None
        self.metric = metric
        self.patience = patience
        self.num_bad_epochs = 0
        self.is_better = None

        if minimize:
            # Improvement means dropping below the best by at least min_delta.
            self.is_better = lambda score, best: score <= (best - min_delta)
        else:
            # Improvement means rising above the best by at least min_delta.
            # (Using ``best - min_delta`` here would accept slightly worse
            # scores as improvements and let ``best_score`` drift downward.)
            self.is_better = lambda score, best: score >= (best + min_delta)

    def on_epoch_end(self, state: RunnerState) -> None:
        """Update the bad-epoch counter and raise the stop flag if needed.

        Stages whose name starts with ``"infer"`` are ignored.
        """
        if state.stage.startswith("infer"):
            return

        score = state.metrics.valid_values[self.metric]

        # The very first score only establishes the baseline; comparing it
        # against itself would wrongly count the first epoch as a bad one.
        if self.best_score is None or self.is_better(score, self.best_score):
            self.num_bad_epochs = 0
            self.best_score = score
        else:
            self.num_bad_epochs += 1

        if self.num_bad_epochs >= self.patience:
            print(f"Early stop at {state.stage_epoch} epoch")
            state.early_stop = True
class ConfusionMatrixCallback(Callback):
    """Accumulate a confusion matrix per loader and log it as an image.

    Statistics are reset at the start of every loader, updated after every
    batch and, at the end of the loader, rendered with
    ``utils.plot_confusion_matrix`` and written to the tensorboard logger
    of the current loader under the tag ``"{prefix}/epoch"``.

    Args:
        input_key: key in ``state.input`` holding the targets.
        output_key: key in ``state.output`` holding the model outputs.
        prefix: tag prefix used when logging the image.
        version: backend used to build the matrix, ``"tnt"``
            (``ConfusionMeter``) or ``"sklearn"``.
        class_names: optional class labels for the plot; when given, the
            number of classes is ``len(class_names)``.
        num_classes: number of classes; only used when ``class_names`` is
            ``None``.
        plot_params: extra keyword arguments forwarded to
            ``utils.plot_confusion_matrix``.

    Raises:
        AssertionError: if ``version`` is not supported, or if neither
            ``class_names`` nor ``num_classes`` is provided.
    """

    def __init__(
        self,
        input_key: str = "targets",
        output_key: str = "logits",
        prefix: str = "confusion_matrix",
        version: str = "tnt",
        class_names: List[str] = None,
        num_classes: int = None,
        plot_params: Dict = None
    ):
        super().__init__(CallbackOrder.Metric)
        self.prefix = prefix
        self.output_key = output_key
        self.input_key = input_key

        assert version in ["tnt", "sklearn"]
        self._version = version
        self._plot_params = plot_params or {}

        self.class_names = class_names
        # Explicit class names take precedence over ``num_classes``.
        self.num_classes = num_classes \
            if class_names is None \
            else len(class_names)

        assert self.num_classes is not None
        self._reset_stats()

    def _reset_stats(self):
        """Drop everything accumulated so far for the active backend."""
        if self._version == "tnt":
            self.confusion_matrix = ConfusionMeter(self.num_classes)
        elif self._version == "sklearn":
            self.outputs = []
            self.targets = []

    def _add_to_stats(self, outputs, targets):
        """Add one batch of model outputs and targets to the statistics."""
        if self._version == "tnt":
            self.confusion_matrix.add(predicted=outputs, target=targets)
        elif self._version == "sklearn":
            outputs = outputs.cpu().numpy()
            targets = targets.cpu().numpy()

            # Per-class scores are reduced to a predicted class index
            # along axis 1.
            outputs = np.argmax(outputs, axis=1)

            self.outputs.extend(outputs)
            self.targets.extend(targets)

    def _compute_confusion_matrix(self):
        """Return the confusion matrix for everything accumulated so far."""
        if self._version == "tnt":
            confusion_matrix = self.confusion_matrix.value()
        elif self._version == "sklearn":
            # Pin the label set: without ``labels`` sklearn infers it from
            # the data, so a loader missing some classes would produce a
            # matrix smaller than ``num_classes`` that no longer lines up
            # with the class names used for plotting.
            confusion_matrix = confusion_matrix_fn(
                y_true=self.targets,
                y_pred=self.outputs,
                labels=list(range(self.num_classes))
            )
        else:
            raise NotImplementedError()
        return confusion_matrix

    def _plot_confusion_matrix(
        self, logger, epoch, confusion_matrix, class_names=None
    ):
        """Render the (normalized) matrix and log it as an image."""
        fig = utils.plot_confusion_matrix(
            confusion_matrix,
            class_names=class_names,
            normalize=True,
            show=False,
            **self._plot_params
        )
        fig = utils.render_figure_to_tensor(fig)
        logger.add_image(f"{self.prefix}/epoch", fig, global_step=epoch)

    def on_loader_start(self, state: RunnerState):
        """Start every loader with empty statistics."""
        self._reset_stats()

    def on_batch_end(self, state: RunnerState):
        """Accumulate the detached outputs and targets of the batch."""
        self._add_to_stats(
            state.output[self.output_key].detach(),
            state.input[self.input_key].detach()
        )

    def on_loader_end(self, state: RunnerState):
        """Compute the matrix for the loader and log it to tensorboard."""
        # Fall back to stringified class indices when no names were given.
        class_names = \
            self.class_names or \
            [str(i) for i in range(self.num_classes)]
        confusion_matrix = self._compute_confusion_matrix()
        self._plot_confusion_matrix(
            logger=state.loggers["tensorboard"].loggers[state.loader_name],
            epoch=state.epoch,
            confusion_matrix=confusion_matrix,
            class_names=class_names
        )
class RaiseExceptionCallback(LoggerCallback):
    """Re-raise the exception stored on the runner state when requested."""

    def __init__(self):
        # Registered one step past ``CallbackOrder.Other``.
        # NOTE(review): presumably so it runs after the other callbacks
        # have handled the exception - confirm against the runner.
        super().__init__(order=CallbackOrder.Other + 1)

    def on_exception(self, state: RunnerState):
        """Raise ``state.exception`` if it is a real exception and the
        state asks for it to be re-raised; otherwise do nothing."""
        caught = state.exception
        must_raise = (
            utils.is_exception(caught) and state.need_reraise_exception
        )
        if must_raise:
            raise caught
# Public API of this module.
__all__ = [
    "EarlyStoppingCallback",
    "ConfusionMatrixCallback",
    "RaiseExceptionCallback",
]