Source code for catalyst.callbacks.metrics.dice
from typing import TYPE_CHECKING
import numpy as np
from catalyst.callbacks.metric import BatchMetricCallback
from catalyst.contrib.utils.confusion_matrix import (
calculate_confusion_matrix_from_tensors,
calculate_tp_fp_fn,
)
from catalyst.core.callback import Callback, CallbackOrder
from catalyst.metrics.dice import calculate_dice, dice
if TYPE_CHECKING:
from catalyst.core.runner import IRunner
class DiceCallback(BatchMetricCallback):
    """Batch-wise Dice metric callback.

    Wraps the ``dice`` metric function in a ``BatchMetricCallback``.

    Args:
        input_key: input key to use for dice calculation;
            specifies our ``y_true``
        output_key: output key to use for dice calculation;
            specifies our ``y_pred``
        prefix: key to store in logs
        eps: epsilon to avoid zero division
        threshold: threshold for outputs binarization
        activation: An torch.nn activation applied to the outputs.
            Must be one of ``'none'``, ``'Sigmoid'``, ``'Softmax2d'``
    """

    def __init__(
        self,
        input_key: str = "targets",
        output_key: str = "logits",
        prefix: str = "dice",
        eps: float = 1e-7,
        threshold: float = None,
        activation: str = "Sigmoid",
    ):
        """Configure the callback and hand everything to the base class.

        Args:
            input_key: input key to use for dice calculation;
                specifies our ``y_true``
            output_key: output key to use for dice calculation;
                specifies our ``y_pred``
            prefix: key to store in logs
            eps: epsilon to avoid zero division
            threshold: threshold for outputs binarization
            activation: An torch.nn activation applied to the outputs.
                Must be one of ``'none'``, ``'Sigmoid'``, ``'Softmax2d'``
        """
        # Extra keyword arguments that are forwarded, unchanged, to the
        # base callback alongside the metric function itself.
        metric_params = {
            "eps": eps,
            "threshold": threshold,
            "activation": activation,
        }
        super().__init__(
            prefix=prefix,
            metric_fn=dice,
            input_key=input_key,
            output_key=output_key,
            **metric_params,
        )
class MultiClassDiceMetricCallback(Callback):
    """
    Global Multi-Class Dice Metric Callback: calculates the exact
    dice score across multiple batches. This callback is good for getting
    the dice score with small batch sizes where the batchwise dice is noisier.

    A confusion matrix is accumulated over every batch of a loader; the
    per-class dice scores and their mean are written to
    ``runner.loader_metrics`` when the loader ends.
    """

    def __init__(
        self,
        input_key: str = "targets",
        output_key: str = "logits",
        prefix: str = "dice",
        class_names=None,
    ):
        """
        Args:
            input_key: input key to use for dice calculation;
                specifies our `y_true`
            output_key: output key to use for dice calculation;
                specifies our `y_pred`
            prefix: prefix for printing the metric
            class_names: if dictionary, should be:
                {class_id: class_name, ...} where class_id is an integer
                This allows you to ignore class indices.
                if list, make sure it corresponds to the number of classes
        """
        super().__init__(CallbackOrder.metric)
        self.input_key = input_key
        self.output_key = output_key
        self.prefix = prefix
        # Running confusion matrix for the current loader; ``None`` until
        # the first batch has been seen.
        self.confusion_matrix = None
        self.class_names = class_names

    def _reset_stats(self):
        """Resets the confusion matrix holding the epoch-wise stats."""
        self.confusion_matrix = None

    def on_batch_end(self, runner: "IRunner"):
        """Records the confusion matrix at the end of each batch.

        Args:
            runner: current runner
        """
        outputs = runner.output[self.output_key]
        targets = runner.input[self.input_key]

        confusion_matrix = calculate_confusion_matrix_from_tensors(
            outputs, targets
        )

        if self.confusion_matrix is None:
            self.confusion_matrix = confusion_matrix
        else:
            self.confusion_matrix += confusion_matrix

    def on_loader_end(self, runner: "IRunner"):
        """Logs per-class dice scores and their mean, then resets the stats.

        The accumulated confusion matrix is converted with
        ``calculate_tp_fp_fn`` and the result is passed to
        ``calculate_dice``. Each score is stored in
        ``runner.loader_metrics`` under ``"{prefix}_{class name or index}"``
        and the mean of the stored scores under ``"{prefix}_mean"``.

        Args:
            runner: current runner
        """
        tp_fp_fn_dict = calculate_tp_fp_fn(self.confusion_matrix)
        dice_scores: np.ndarray = calculate_dice(**tp_fp_fn_dict)

        # Scores actually logged by this callback. The mean is computed from
        # this list rather than by scanning ``runner.loader_metrics`` for the
        # prefix, so unrelated metrics that merely share the prefix (or an
        # already present ``{prefix}_mean``) cannot leak into the average.
        logged_scores = []
        for i, dice_score in enumerate(dice_scores):
            # A dict of class names doubles as a filter: indices missing
            # from it are ignored entirely.
            if isinstance(self.class_names, dict) and i not in self.class_names:
                continue
            postfix = (
                self.class_names[i] if self.class_names is not None else str(i)
            )
            runner.loader_metrics[f"{self.prefix}_{postfix}"] = dice_score
            logged_scores.append(dice_score)

        # Averages only the classes selected through `class_names`.
        runner.loader_metrics[f"{self.prefix}_mean"] = np.mean(logged_scores)

        self._reset_stats()
# Backward compatibility: keep the older spelling ``MulticlassDiceMetricCallback``
# bound to the same class as ``MultiClassDiceMetricCallback``.
MulticlassDiceMetricCallback = MultiClassDiceMetricCallback
# Public names of this module, including the backward-compatible alias.
__all__ = [
    "DiceCallback",
    "MultiClassDiceMetricCallback",
    "MulticlassDiceMetricCallback",
]