# Source code for catalyst.dl.callbacks.metrics.dice

from typing import Dict

import numpy as np

from catalyst.dl.core import (
    Callback, CallbackOrder, MetricCallback, RunnerState
)
from catalyst.dl.utils import criterion
from catalyst.utils.confusion_matrix import (
    calculate_confusion_matrix_from_tensors, calculate_tp_fp_fn
)


class DiceCallback(MetricCallback):
    """
    Dice metric callback.

    A thin ``MetricCallback`` specialisation whose metric function is
    ``criterion.dice``; all configuration is handed to the base class.
    """

    def __init__(
        self,
        input_key: str = "targets",
        output_key: str = "logits",
        prefix: str = "dice",
        eps: float = 1e-7,
        threshold: float = None,
        activation: str = "Sigmoid",
    ):
        """
        Args:
            input_key (str): input key to use for dice calculation;
                specifies our `y_true`.
            output_key (str): output key to use for dice calculation;
                specifies our `y_pred`.
            prefix (str): passed to ``MetricCallback`` as ``prefix``
                (presumably the name the metric is logged under — confirm).
            eps (float): passed to ``MetricCallback`` as an extra keyword;
                presumably forwarded to ``criterion.dice`` — confirm.
            threshold (float): passed to ``MetricCallback`` as an extra
                keyword (``None`` by default); presumably forwarded to
                ``criterion.dice`` — confirm.
            activation (str): passed to ``MetricCallback`` as an extra
                keyword; presumably forwarded to ``criterion.dice`` — confirm.
        """
        # Extra keywords that are not part of MetricCallback's own key
        # configuration; grouped here so the call below stays readable.
        dice_kwargs = dict(eps=eps, threshold=threshold, activation=activation)
        super().__init__(
            prefix=prefix,
            metric_fn=criterion.dice,
            input_key=input_key,
            output_key=output_key,
            **dice_kwargs,
        )
def calculate_dice(
    true_positives: np.array,
    false_positives: np.array,
    false_negatives: np.array,
    epsilon: float = 1e-7,
) -> np.array:
    """Calculate list of Dice coefficients.

    Computed element-wise as
    ``(2 * TP + epsilon) / (2 * TP + FP + FN + epsilon)``.

    Args:
        true_positives: true-positive counts (array-like; presumably one
            entry per class).
        false_positives: false-positive counts, same shape as
            ``true_positives``.
        false_negatives: false-negative counts, same shape as
            ``true_positives``.
        epsilon: smoothing term added to numerator and denominator, so a
            class with TP = FP = FN = 0 scores 1 instead of 0/0. Should be
            positive; defaults to the previously hard-coded ``1e-7``.

    Returns:
        np.ndarray of Dice coefficients with the broadcast shape of the
        inputs.

    Raises:
        ValueError: if any coefficient is greater than 1 or not strictly
            positive. With non-negative counts and a positive ``epsilon``
            this cannot happen, so it signals invalid (e.g. negative) input.
    """
    # Coerce to arrays: with plain Python lists, ``2 * x`` would repeat the
    # list and ``+`` would concatenate instead of doing arithmetic.
    true_positives = np.asarray(true_positives)
    false_positives = np.asarray(false_positives)
    false_negatives = np.asarray(false_negatives)

    dice = (2 * true_positives + epsilon) / (
        2 * true_positives + false_positives + false_negatives + epsilon
    )

    if not np.all(dice <= 1):
        raise ValueError("Dice index should be less or equal to 1")

    if not np.all(dice > 0):
        raise ValueError("Dice index should be more than 0")

    return dice
class MulticlassDiceMetricCallback(Callback):
    """Accumulate a confusion matrix over a loader and log per-class Dice.

    On every batch, the confusion matrix of ``state.output[output_key]``
    against ``state.input[input_key]`` is added to a running total. At the
    end of the loader, Dice scores are derived from that total and written
    to ``state.metrics.epoch_values[state.loader_name]``.
    """

    def __init__(
        self,
        prefix: str = "dice",
        input_key: str = "targets",
        output_key: str = "logits",
        class_names=None,
        class_prefix="",
        **metric_params
    ):
        """
        Args:
            prefix (str): stored as ``self.prefix``; not read by this class.
            input_key (str): key in ``state.input`` holding the targets.
            output_key (str): key in ``state.output`` holding the predictions.
            class_names (dict): optional ``{class_id: class_name}`` mapping.
                When given, only classes present in it are logged
                individually; when ``None``, every class is logged under its
                numeric id.
            class_prefix (str): prefix of the per-class metric keys, which
                are written as ``f"{class_prefix}_{class_name}"``.
            **metric_params: stored as ``self.metric_params``; not read by
                this class.
        """
        super().__init__(CallbackOrder.Metric)
        self.prefix = prefix
        self.input_key = input_key
        self.output_key = output_key
        self.metric_params = metric_params
        self.confusion_matrix = None
        self.class_names = class_names  # dictionary {class_id: class_name}
        self.class_prefix = class_prefix

    def _reset_stats(self):
        """Drop the accumulated confusion matrix."""
        self.confusion_matrix = None

    def on_batch_end(self, state: RunnerState):
        """Add the current batch's confusion matrix to the running total."""
        outputs = state.output[self.output_key]
        targets = state.input[self.input_key]

        confusion_matrix = calculate_confusion_matrix_from_tensors(
            outputs, targets
        )

        if self.confusion_matrix is None:
            self.confusion_matrix = confusion_matrix
        else:
            self.confusion_matrix += confusion_matrix

    def on_loader_end(self, state: RunnerState):
        """Write per-class and mean Dice for the finished loader, then reset."""
        if self.confusion_matrix is None:
            # No batch reached on_batch_end for this loader, so there is
            # nothing to score (calculate_tp_fp_fn(None) would fail).
            return

        tp_fp_fn_dict = calculate_tp_fp_fn(self.confusion_matrix)
        # calculate_dice returns an array, not a dict, so calling .items() on
        # it directly raised AttributeError. Index it by position instead.
        # Assumes position == class id (calculate_tp_fp_fn presumably yields
        # per-class arrays) — TODO confirm.
        dice_per_class = np.atleast_1d(calculate_dice(**tp_fp_fn_dict))
        batch_metrics: Dict[int, float] = {
            class_id: float(dice_value)
            for class_id, dice_value in enumerate(dice_per_class)
        }

        epoch_values = state.metrics.epoch_values[state.loader_name]
        for metric_id, dice_value in batch_metrics.items():
            if self.class_names is None:
                # The default class_names=None previously raised TypeError
                # on the membership test; fall back to numeric ids.
                metric_name = str(metric_id)
            elif metric_id in self.class_names:
                metric_name = self.class_names[metric_id]
            else:
                continue
            epoch_values[f"{self.class_prefix}_{metric_name}"] = dice_value

        # The mean is taken over all classes, including any not listed in
        # class_names.
        epoch_values["mean"] = np.mean(list(batch_metrics.values()))

        self._reset_stats()
# Names exported by ``from ... import *`` on this module.
__all__ = [
    "DiceCallback",
    "MulticlassDiceMetricCallback",
]