Source code for catalyst.dl.callbacks.metrics.dice
from typing import Dict
import numpy as np
from catalyst import utils
from catalyst.dl.core import Callback, CallbackOrder, MetricCallback, State
from catalyst.utils import criterion
from .functional import calculate_dice
class DiceCallback(MetricCallback):
    """
    Metric callback that reports the Dice score computed by
    ``criterion.dice``.
    """

    def __init__(
        self,
        input_key: str = "targets",
        output_key: str = "logits",
        prefix: str = "dice",
        eps: float = 1e-7,
        threshold: float = None,
        activation: str = "Sigmoid"
    ):
        """
        Args:
            input_key (str): input key to use for dice calculation;
                specifies our `y_true`.
            output_key (str): output key to use for dice calculation;
                specifies our `y_pred`.
            prefix (str): prefix handed to ``MetricCallback``
                (presumably the name the metric is logged under).
            eps (float): forwarded to ``criterion.dice`` through
                ``MetricCallback``; presumably a numerical-stability term.
            threshold (float): forwarded to ``criterion.dice`` through
                ``MetricCallback``; ``None`` by default.
            activation (str): forwarded to ``criterion.dice`` through
                ``MetricCallback``; ``"Sigmoid"`` by default.
        """
        # Everything that is not a key/prefix is an extra argument for the
        # metric function itself, so group it before delegating.
        dice_params = dict(eps=eps, threshold=threshold, activation=activation)
        super().__init__(
            prefix=prefix,
            metric_fn=criterion.dice,
            input_key=input_key,
            output_key=output_key,
            **dice_params
        )
class MulticlassDiceMetricCallback(Callback):
    """
    Loader-level multiclass Dice metric callback.

    A confusion matrix is accumulated over every batch of a loader. When the
    loader ends, it is converted into per-class Dice values and their mean,
    which are written to ``state.metric_manager.epoch_values``.
    """

    def __init__(
        self,
        prefix: str = "dice",
        input_key: str = "targets",
        output_key: str = "logits",
        class_names=None,
        class_prefix="",
        **metric_params
    ):
        """
        Args:
            prefix (str): stored on the callback as ``self.prefix``; it is
                not used when naming the reported metrics.
            input_key (str): key in ``state.input`` holding the targets.
            output_key (str): key in ``state.output`` holding the model
                outputs.
            class_names: dictionary ``{class_id: class_name}``. When given,
                only classes present in it get a per-class metric entry.
                When ``None``, every class is reported under ``str(class_id)``.
            class_prefix (str): prefix of the per-class metric names, which
                are built as ``f"{class_prefix}_{class_name}"``.
            **metric_params: stored on the callback as
                ``self.metric_params``; not used by the methods of this class.
        """
        super().__init__(CallbackOrder.Metric)
        self.prefix = prefix
        self.input_key = input_key
        self.output_key = output_key
        self.metric_params = metric_params
        self.confusion_matrix = None
        self.class_names = class_names  # dictionary {class_id: class_name}
        self.class_prefix = class_prefix

    def _reset_stats(self):
        """Drop the accumulated confusion matrix."""
        self.confusion_matrix = None

    def on_batch_end(self, state: State):
        """Add the confusion matrix of the current batch to the running total.

        Args:
            state (State): current state; ``state.output[self.output_key]``
                and ``state.input[self.input_key]`` are read.
        """
        outputs = state.output[self.output_key]
        targets = state.input[self.input_key]

        confusion_matrix = utils.calculate_confusion_matrix_from_tensors(
            outputs, targets
        )

        if self.confusion_matrix is None:
            self.confusion_matrix = confusion_matrix
        else:
            self.confusion_matrix += confusion_matrix

    def on_loader_end(self, state: State):
        """Turn the accumulated confusion matrix into Dice metrics.

        Writes one ``f"{class_prefix}_{class_name}"`` entry per reported
        class and a ``"mean"`` entry (mean over all classes returned by
        ``calculate_dice``) into
        ``state.metric_manager.epoch_values[state.loader_name]``, then resets
        the accumulated statistics.

        Args:
            state (State): current state; ``state.loader_name`` selects the
                metrics dictionary that is updated.
        """
        if self.confusion_matrix is None:
            # No batch was processed for this loader, so there is nothing
            # to convert into metrics.
            return

        tp_fp_fn_dict = utils.calculate_tp_fp_fn(self.confusion_matrix)
        batch_metrics: Dict = calculate_dice(**tp_fp_fn_dict)

        epoch_values = state.metric_manager.epoch_values[state.loader_name]
        for metric_id, dice_value in batch_metrics.items():
            if self.class_names is None:
                # No explicit mapping was provided: report every class and
                # name it after its id instead of failing on `in None`.
                metric_name = str(metric_id)
            elif metric_id in self.class_names:
                metric_name = self.class_names[metric_id]
            else:
                # Classes missing from the mapping get no per-class entry
                # (they still contribute to the mean below).
                continue
            epoch_values[f"{self.class_prefix}_{metric_name}"] = dice_value

        epoch_values["mean"] = np.mean(list(batch_metrics.values()))

        self._reset_stats()
# Names exported by `from <this module> import *`.
__all__ = ["DiceCallback", "MulticlassDiceMetricCallback"]