# Source code for catalyst.dl.callbacks.metrics.dice
import numpy as np
from catalyst.core import Callback, CallbackOrder, MetricCallback, State
from catalyst.dl import utils
from catalyst.utils import metrics
from .functional import calculate_dice
class DiceCallback(MetricCallback):
    """Batch-wise Dice metric callback backed by ``metrics.dice``."""

    def __init__(
        self,
        input_key: str = "targets",
        output_key: str = "logits",
        prefix: str = "dice",
        eps: float = 1e-7,
        threshold: float = None,
        activation: str = "Sigmoid",
    ):
        """
        Args:
            input_key (str): input key to use for dice calculation;
                specifies our `y_true`
            output_key (str): output key to use for dice calculation;
                specifies our `y_pred`
            prefix (str): name handed to ``MetricCallback`` as ``prefix``
                (presumably the key the metric is logged under - verify)
            eps (float): forwarded to ``MetricCallback`` as ``eps``
                (presumably passed on to ``metrics.dice`` - verify)
            threshold (float): forwarded to ``MetricCallback`` as
                ``threshold``; ``None`` by default
            activation (str): forwarded to ``MetricCallback`` as
                ``activation``; ``"Sigmoid"`` by default
        """
        # Options that only configure the dice computation itself; they are
        # passed last so the keyword order matches the original call.
        dice_options = dict(
            eps=eps,
            threshold=threshold,
            activation=activation,
        )
        super().__init__(
            prefix=prefix,
            metric_fn=metrics.dice,
            input_key=input_key,
            output_key=output_key,
            **dice_options,
        )
class MulticlassDiceMetricCallback(Callback):
    """
    Global Multi-Class Dice Metric Callback: calculates the exact
    dice score across multiple batches. This callback is good for getting
    the dice score with small batch sizes where the batchwise dice is noisier.

    A confusion matrix is accumulated over every batch of a loader; the
    per-class dice scores and their mean are computed from it once, when
    the loader ends.
    """

    def __init__(
        self,
        input_key: str = "targets",
        output_key: str = "logits",
        prefix: str = "dice",
        class_names=None,
    ):
        """
        Args:
            input_key (str): input key to use for dice calculation;
                specifies our `y_true`
            output_key (str): output key to use for dice calculation;
                specifies our `y_pred`
            prefix (str): prefix for printing the metric; scores are logged
                as ``f"{prefix}_{class}"`` and ``f"{prefix}_mean"``
            class_names (dict/List): if dictionary, should be:
                {class_id: class_name, ...} where class_id is an integer
                This allows you to ignore class indices.
                if list, make sure it corresponds to the number of classes
                (a shorter list raises IndexError at loader end)
        """
        super().__init__(CallbackOrder.Metric)
        self.input_key = input_key
        self.output_key = output_key
        self.prefix = prefix
        self.confusion_matrix = None
        self.class_names = class_names

    def _reset_stats(self):
        """Resets the confusion matrix holding the epoch-wise stats."""
        self.confusion_matrix = None

    def on_loader_start(self, state: State):
        """Clears stats left over from a loader that did not finish cleanly.

        Args:
            state (State): current state
        """
        self._reset_stats()

    def on_batch_end(self, state: State):
        """Records the confusion matrix at the end of each batch.

        Args:
            state (State): current state
        """
        outputs = state.output[self.output_key]
        targets = state.input[self.input_key]
        confusion_matrix = utils.calculate_confusion_matrix_from_tensors(
            outputs, targets
        )
        if self.confusion_matrix is None:
            self.confusion_matrix = confusion_matrix
        else:
            # Running sum over the loader; the first batch's matrix is
            # reused as the accumulator.
            self.confusion_matrix += confusion_matrix

    def on_loader_end(self, state: State):
        """Computes and logs the loader-wide dice scores.

        Per-class scores are written to ``state.loader_metrics`` under
        ``f"{prefix}_{class_name_or_index}"`` (classes missing from a
        ``class_names`` dict are skipped), and their mean under
        ``f"{prefix}_mean"``. The accumulated stats are then reset.

        Args:
            state (State): current state
        """
        if self.confusion_matrix is None:
            # No batch was recorded for this loader: nothing to score.
            return

        tp_fp_fn_dict = utils.calculate_tp_fp_fn(self.confusion_matrix)
        dice_scores: np.ndarray = calculate_dice(**tp_fp_fn_dict)

        # Only the scores logged here are averaged, so unrelated metrics
        # that happen to share the prefix cannot leak into the mean.
        logged_scores = []
        for i, dice in enumerate(dice_scores):
            if isinstance(self.class_names, dict) and i not in self.class_names:
                continue
            postfix = (
                self.class_names[i] if self.class_names is not None else str(i)
            )
            state.loader_metrics[f"{self.prefix}_{postfix}"] = dice
            logged_scores.append(dice)

        # Avoid np.mean([]) RuntimeWarning when every class was filtered out.
        state.loader_metrics[f"{self.prefix}_mean"] = (
            np.mean(logged_scores) if logged_scores else float("nan")
        )
        self._reset_stats()
# Public API of this module.
__all__ = [
    "DiceCallback",
    "MulticlassDiceMetricCallback",
]