Shortcuts

Source code for catalyst.utils.metrics.dice

"""
Dice metric.
"""

import numpy as np

import torch

from catalyst.utils.torch import get_activation_fn


def dice(
    outputs: torch.Tensor,
    targets: torch.Tensor,
    eps: float = 1e-7,
    threshold: float = None,
    activation: str = "Sigmoid",
):
    """Compute the Dice score between predictions and ground truth.

    The score is computed globally: intersection and union are summed
    over every element of the input tensors.

    Args:
        outputs (torch.Tensor): raw model predictions; the activation
            named by ``activation`` is applied to them first
        targets (torch.Tensor): ground-truth tensor, multiplied
            element-wise with the (activated) predictions
        eps (float): small constant that keeps the division well defined
        threshold (float): if not ``None``, activated predictions are
            binarized with ``prediction > threshold``
        activation (str): name of the torch.nn activation applied to the
            outputs, resolved through ``get_activation_fn``.
            Must be one of ["none", "Sigmoid", "Softmax2d"]

    Returns:
        torch.Tensor: scalar tensor holding the Dice score
    """
    predictions = get_activation_fn(activation)(outputs)

    if threshold is not None:
        # Hard Dice: turn the activated scores into a 0/1 mask.
        predictions = (predictions > threshold).float()

    overlap = torch.sum(targets * predictions)
    total = torch.sum(targets) + torch.sum(predictions)

    # `eps` enters the numerator only when `total` is exactly zero, so
    # an empty prediction against an empty target scores eps / eps == 1.
    # Whenever `total` is non-zero the correction is multiplied by zero,
    # so a zero overlap still yields a score of exactly 0.
    empty_case_correction = eps * (total == 0)

    return (2 * overlap + empty_case_correction) / (total + eps)
def calculate_dice(
    true_positives: np.ndarray,
    false_positives: np.ndarray,
    false_negatives: np.ndarray,
) -> np.ndarray:
    """
    Calculate Dice coefficients from true/false positive/negative counts.

    The score is computed element-wise as
    ``(2 * TP + eps) / (2 * TP + FP + FN + eps)`` with ``eps = 1e-7``.

    Args:
        true_positives: true positives numpy tensor
        false_positives: false positives numpy tensor
        false_negatives: false negatives numpy tensor

    Returns:
        np.ndarray: dice score, same shape as the broadcast inputs

    Raises:
        ValueError: if any `dice` value is out of (0; 1] bounds
    """
    # Smoothing term: avoids 0 / 0 when all counts are zero (score is
    # then eps / eps == 1) and keeps the score strictly positive for
    # non-negative counts.
    epsilon = 1e-7

    dice_metric = (2 * true_positives + epsilon) / (
        2 * true_positives + false_positives + false_negatives + epsilon
    )

    # Sanity checks: values outside (0; 1] can only come from invalid
    # (e.g. negative or NaN) counts.
    if not np.all(dice_metric <= 1):
        raise ValueError("Dice index should be less or equal to 1")

    if not np.all(dice_metric > 0):
        raise ValueError("Dice index should be more than 0")

    return dice_metric
# Public API of this module: the names exported by `from ... import *`.
__all__ = ["dice", "calculate_dice"]