Source code for catalyst.contrib.nn.criterion.dice
from functools import partial
import torch.nn as nn
from catalyst.utils import criterion
class DiceLoss(nn.Module):
    """Loss defined as ``1 - criterion.dice(logits, targets, ...)``.

    The dice score itself is delegated to ``catalyst.utils.criterion.dice``;
    this module only binds its keyword arguments once and inverts the score.

    Args:
        eps: passed through to ``criterion.dice`` as ``eps``.
        threshold: passed through to ``criterion.dice`` as ``threshold``
            (``None`` by default; its exact meaning is defined by
            ``criterion.dice``).
        activation: passed through to ``criterion.dice`` as ``activation``
            (``"Sigmoid"`` by default).
    """

    def __init__(
        self,
        eps: float = 1e-7,
        threshold: float = None,
        activation: str = "Sigmoid"
    ):
        super().__init__()
        dice_kwargs = dict(
            eps=eps,
            threshold=threshold,
            activation=activation,
        )
        # Freeze the hyper-parameters into a callable so forward() only
        # has to supply the two tensors.
        self.loss_fn = partial(criterion.dice, **dice_kwargs)

    def forward(self, logits, targets):
        """Return ``1 - dice`` for the given ``logits`` and ``targets``."""
        return 1 - self.loss_fn(logits, targets)
class BCEDiceLoss(nn.Module):
    """Weighted sum of ``nn.BCEWithLogitsLoss`` and ``DiceLoss``.

    The returned value is ``bce_weight * bce + dice_weight * dice_loss``.
    A term whose weight is exactly 0 is neither constructed in ``__init__``
    nor evaluated in ``forward``.

    Args:
        eps: passed through to ``DiceLoss``.
        threshold: passed through to ``DiceLoss``.
        activation: passed through to ``DiceLoss`` only. The BCE term always
            uses ``nn.BCEWithLogitsLoss``, i.e. it treats ``outputs`` as raw
            logits whatever ``activation`` is set to.
        bce_weight: multiplier for the BCE term.
        dice_weight: multiplier for the Dice term.

    Raises:
        ValueError: if ``bce_weight`` and ``dice_weight`` are both 0.
    """

    def __init__(
        self,
        eps: float = 1e-7,
        threshold: float = None,
        activation: str = "Sigmoid",
        bce_weight: float = 0.5,
        dice_weight: float = 0.5,
    ):
        super().__init__()
        if bce_weight == 0 and dice_weight == 0:
            # Message previously misspelled the parameter as "bce_wight".
            raise ValueError(
                "Both bce_weight and dice_weight cannot be "
                "equal to 0 at the same time."
            )
        self.bce_weight = bce_weight
        self.dice_weight = dice_weight
        # Only build the sub-losses that actually contribute; forward()
        # checks the same weights before touching either attribute.
        if self.bce_weight != 0:
            self.bce_loss = nn.BCEWithLogitsLoss()
        if self.dice_weight != 0:
            self.dice_loss = DiceLoss(
                eps=eps, threshold=threshold, activation=activation
            )

    def forward(self, outputs, targets):
        """Return the weighted BCE + Dice loss of ``outputs`` vs ``targets``.

        ``outputs`` are fed unchanged to ``BCEWithLogitsLoss`` (as logits)
        and to ``DiceLoss``.
        """
        if self.bce_weight == 0:
            return self.dice_weight * self.dice_loss(outputs, targets)
        if self.dice_weight == 0:
            return self.bce_weight * self.bce_loss(outputs, targets)
        dice = self.dice_weight * self.dice_loss(outputs, targets)
        bce = self.bce_weight * self.bce_loss(outputs, targets)
        return dice + bce
# Public API of this module: the two loss classes defined above.
__all__ = [
    "BCEDiceLoss",
    "DiceLoss",
]