# Source code for catalyst.dl.callbacks.metrics.dice

from typing import Dict, Optional

import numpy as np

from catalyst import utils
from catalyst.dl.core import Callback, CallbackOrder, MetricCallback, State
from catalyst.utils import criterion
from .functional import calculate_dice


class DiceCallback(MetricCallback):
    """Dice metric callback.

    A thin ``MetricCallback`` subclass whose metric function is
    ``criterion.dice``.
    """

    def __init__(
        self,
        input_key: str = "targets",
        output_key: str = "logits",
        prefix: str = "dice",
        eps: float = 1e-7,
        threshold: Optional[float] = None,
        activation: str = "Sigmoid"
    ):
        """
        Args:
            input_key (str): input key to use for dice calculation;
                specifies our `y_true`.
            output_key (str): output key to use for dice calculation;
                specifies our `y_pred`.
            prefix (str): metric prefix, passed to ``MetricCallback``.
            eps (float): forwarded to ``MetricCallback`` as a metric keyword
                argument (presumably handed on to ``criterion.dice`` as a
                smoothing term -- TODO confirm).
            threshold (float, optional): forwarded to ``MetricCallback`` as a
                metric keyword argument; ``None`` by default (presumably
                disables binarization in ``criterion.dice`` -- confirm).
            activation (str): forwarded to ``MetricCallback`` as a metric
                keyword argument; defaults to ``"Sigmoid"``.
        """
        super().__init__(
            prefix=prefix,
            metric_fn=criterion.dice,
            input_key=input_key,
            output_key=output_key,
            eps=eps,
            threshold=threshold,
            activation=activation
        )
class MulticlassDiceMetricCallback(Callback):
    """Multiclass Dice metric callback.

    Sums the confusion matrix of every batch of a loader. When the loader
    ends, it writes one Dice value per (selected) class plus their mean into
    ``state.metric_manager.epoch_values[state.loader_name]``.
    """

    def __init__(
        self,
        prefix: str = "dice",
        input_key: str = "targets",
        output_key: str = "logits",
        class_names: Optional[Dict] = None,
        class_prefix: str = "",
        **metric_params
    ):
        """
        Args:
            prefix (str): metric prefix stored on the instance.
                NOTE(review): not used when building the logged keys.
            input_key (str): key of ``state.input`` holding the targets.
            output_key (str): key of ``state.output`` holding the outputs.
            class_names (dict, optional): mapping ``{class_id: class_name}``.
                Only class ids present in the mapping get a per-class entry.
                When ``None``, every class is logged under ``str(class_id)``.
            class_prefix (str): prepended, joined with ``"_"``, to every
                per-class key (``f"{class_prefix}_{class_name}"``).
            **metric_params: stored on the instance; not used by this class.
        """
        super().__init__(CallbackOrder.Metric)
        self.prefix = prefix
        self.input_key = input_key
        self.output_key = output_key
        self.metric_params = metric_params
        self.confusion_matrix = None
        self.class_names = class_names  # dictionary {class_id: class_name}
        self.class_prefix = class_prefix

    def _reset_stats(self):
        """Drop the confusion matrix accumulated for the current loader."""
        self.confusion_matrix = None

    def on_batch_end(self, state: State):
        """Add this batch's confusion matrix to the running total."""
        outputs = state.output[self.output_key]
        targets = state.input[self.input_key]

        confusion_matrix = utils.calculate_confusion_matrix_from_tensors(
            outputs, targets
        )

        if self.confusion_matrix is None:
            self.confusion_matrix = confusion_matrix
        else:
            self.confusion_matrix += confusion_matrix

    def on_loader_end(self, state: State):
        """Log per-class and mean Dice computed from the accumulated matrix."""
        if self.confusion_matrix is None:
            # No batch ended in this loader. Skip it rather than pass None to
            # utils.calculate_tp_fp_fn.
            return

        tp_fp_fn_dict = utils.calculate_tp_fp_fn(self.confusion_matrix)
        batch_metrics: Dict = calculate_dice(**tp_fp_fn_dict)
        loader_values = state.metric_manager.epoch_values[state.loader_name]

        for metric_id, dice_value in batch_metrics.items():
            if self.class_names is None:
                # No names were given: log every class under its id.
                metric_name = str(metric_id)
            elif metric_id in self.class_names:
                metric_name = self.class_names[metric_id]
            else:
                # Class not selected via class_names.
                continue
            loader_values[f"{self.class_prefix}_{metric_name}"] = dice_value

        # The mean covers every class returned by calculate_dice, including
        # classes filtered out by class_names above.
        loader_values["mean"] = np.mean(list(batch_metrics.values()))

        self._reset_stats()
# Names exported by ``from <this module> import *``.
__all__ = [
    "DiceCallback",
    "MulticlassDiceMetricCallback",
]