Source code for catalyst.contrib.nn.criterion.dice
from functools import partial
import torch.nn as nn
from catalyst.utils import criterion
[docs]class DiceLoss(nn.Module):
    def __init__(
        self,
        eps: float = 1e-7,
        threshold: float = None,
        activation: str = "Sigmoid"
    ):
        super().__init__()
        self.loss_fn = partial(
            criterion.dice,
            eps=eps,
            threshold=threshold,
            activation=activation
        )
[docs]    def forward(self, logits, targets):
        dice = self.loss_fn(logits, targets)
        return 1 - dice  
[docs]class BCEDiceLoss(nn.Module):
    def __init__(
        self,
        eps: float = 1e-7,
        threshold: float = None,
        activation: str = "Sigmoid",
        bce_weight: float = 0.5,
        dice_weight: float = 0.5,
    ):
        super().__init__()
        if bce_weight == 0 and dice_weight == 0:
            raise ValueError(
                "Both bce_wight and dice_weight cannot be "
                "equal to 0 at the same time."
            )
        self.bce_weight = bce_weight
        self.dice_weight = dice_weight
        if self.bce_weight != 0:
            self.bce_loss = nn.BCEWithLogitsLoss()
        if self.dice_weight != 0:
            self.dice_loss = DiceLoss(
                eps=eps, threshold=threshold, activation=activation
            )
[docs]    def forward(self, outputs, targets):
        if self.bce_weight == 0:
            return self.dice_weight * self.dice_loss(outputs, targets)
        if self.dice_weight == 0:
            return self.bce_weight * self.bce_loss(outputs, targets)
        dice = self.dice_weight * self.dice_loss(outputs, targets)
        bce = self.bce_weight * self.bce_loss(outputs, targets)
        return dice + bce  
__all__ = ["BCEDiceLoss", "DiceLoss"]