Source code for catalyst.metrics.dice
from functools import partial
import numpy as np
import torch
def dice(
    outputs: torch.Tensor,
    targets: torch.Tensor,
    class_dim: int = 1,
    threshold: float = None,
    eps: float = 1e-7,
) -> torch.Tensor:
    """Computes the per-class dice score.

    The score for each class is ``2 * I / U`` where ``I`` is the sum of
    ``targets * outputs`` and ``U`` is ``sum(targets) + sum(outputs)``,
    both reduced over every dimension except ``class_dim``.

    Args:
        outputs: [N; K; ...] tensor that for each of the N examples
            indicates the probability of the example belonging to each of
            the K classes, according to the model.
        targets: binary [N; K; ...] tensor that encodes which of the K
            classes are associated with the N-th input
        class_dim: indicates class dimension (K) for
            ``outputs`` and ``targets`` tensors (default = 1);
            negative values index from the end
        threshold: threshold for outputs binarization;
            if ``None`` the outputs are used as is
        eps: epsilon to avoid zero division

    Returns:
        Dice score, one value per entry along ``class_dim``

    Raises:
        IndexError: if ``class_dim`` does not address a dimension
            of ``outputs``

    Examples:
        >>> size = 4
        >>> half_size = size // 2
        >>> shape = (1, 1, size, size)
        >>> empty = torch.zeros(shape)
        >>> full = torch.ones(shape)
        >>> left = torch.ones(shape)
        >>> left[:, :, :, half_size:] = 0
        >>> right = torch.ones(shape)
        >>> right[:, :, :, :half_size] = 0
        >>> top_left = torch.zeros(shape)
        >>> top_left[:, :, :half_size, :half_size] = 1
        >>> pred = torch.cat([empty, left, empty, full, left, top_left], dim=1)
        >>> targets = torch.cat([full, right, empty, full, left, left], dim=1)
        >>> dice(
        ...     outputs=pred,
        ...     targets=targets,
        ...     class_dim=1,
        ...     threshold=0.5,
        ... )
        tensor([0.0000, 0.0000, 1.0000, 1.0000, 1.0000, 0.6667])
    """
    if threshold is not None:
        outputs = (outputs > threshold).float()

    num_dims = len(outputs.shape)
    assert num_dims > 2, "shape mismatch, please check the docs for more info"
    assert (
        outputs.shape == targets.shape
    ), "shape mismatch, please check the docs for more info"

    # Reject indices outside [-num_dims, num_dims): a too-negative value
    # would otherwise stay negative after normalisation and make the
    # reduction silently skip the wrong axis.
    if not -num_dims <= class_dim < num_dims:
        raise IndexError(
            f"class_dim={class_dim} is out of range "
            f"for tensors with {num_dims} dimensions"
        )
    # support negative index
    if class_dim < 0:
        class_dim = num_dims + class_dim

    # reduce over every dimension except the class one
    dims = [dim for dim in range(num_dims) if dim != class_dim]

    intersection = torch.sum(targets * outputs, dim=dims)
    union = torch.sum(targets, dim=dims) + torch.sum(outputs, dim=dims)

    # this looks a bit awkward but `eps * (union == 0)` term
    # makes sure that if I and U are both 0, then Dice == 1
    # and if U != 0 and I == 0 the eps term in numerator is zeroed out
    # i.e. (0 + eps) / (U - 0 + eps) doesn't happen
    dice_score = (2 * intersection + eps * (union == 0).float()) / (
        union + eps
    )
    return dice_score
# @TODO: remove
def calculate_dice(
    true_positives: np.ndarray,
    false_positives: np.ndarray,
    false_negatives: np.ndarray,
) -> np.ndarray:
    """
    Calculate list of Dice coefficients.

    Each score is ``(2 * TP + eps) / (2 * TP + FP + FN + eps)`` with
    ``eps = 1e-7``, computed element-wise.

    Args:
        true_positives: true positives numpy tensor
        false_positives: false positives numpy tensor
        false_negatives: false negatives numpy tensor

    Returns:
        np.ndarray: dice score

    Raises:
        ValueError: if any `dice` value is out of (0; 1] bounds
    """
    epsilon = 1e-7

    dice_metric = (2 * true_positives + epsilon) / (
        2 * true_positives + false_positives + false_negatives + epsilon
    )

    if not np.all(dice_metric <= 1):
        raise ValueError("Dice index should be less or equal to 1")

    # epsilon in the numerator keeps the score strictly positive for
    # non-negative counts, so this only fires on invalid input
    if not np.all(dice_metric > 0):
        raise ValueError("Dice index should be more than 0")

    return dice_metric
# Public names exported by a star-import of this module.
__all__ = ["dice", "calculate_dice"]