Source code for catalyst.contrib.criterion.focal
from functools import partial
from torch.nn.modules.loss import _Loss
from catalyst.dl.utils import criterion
[docs]class FocalLossBinary(_Loss):
[docs] def __init__(
self,
ignore: int = None,
reduced: bool = False,
gamma: float = 2.0,
alpha: float = 0.25,
threshold: float = 0.5,
reduction: str = "mean",
):
"""
Compute focal loss for binary classification problem.
"""
super().__init__()
self.ignore = ignore
if reduced:
self.loss_fn = partial(
criterion.reduced_focal_loss,
gamma=gamma,
threshold=threshold,
reduction=reduction
)
else:
self.loss_fn = partial(
criterion.sigmoid_focal_loss,
gamma=gamma,
alpha=alpha,
reduction=reduction
)
[docs] def forward(self, logits, targets):
"""
Args:
logits: [bs; ...]
targets: [bs; ...]
"""
targets = targets.view(-1)
logits = logits.view(-1)
if self.ignore is not None:
# Filter predictions with ignore label from loss computation
not_ignored = targets != self.ignore
logits = logits[not_ignored]
targets = targets[not_ignored]
loss = self.loss_fn(logits, targets)
return loss
[docs]class FocalLossMultiClass(FocalLossBinary):
"""
Compute focal loss for multi-class problem.
Ignores targets having -1 label
"""
[docs] def forward(self, logits, targets):
"""
Args:
logits: [bs; num_classes; ...]
targets: [bs; ...]
"""
num_classes = logits.size(1)
loss = 0
targets = targets.view(-1)
logits = logits.view(-1, num_classes)
# Filter anchors with -1 label from loss computation
if self.ignore is not None:
not_ignored = targets != self.ignore
for cls in range(num_classes):
cls_label_target = (targets == (cls + 0)).long()
cls_label_input = logits[..., cls]
if self.ignore is not None:
cls_label_target = cls_label_target[not_ignored]
cls_label_input = cls_label_input[not_ignored]
loss += self.loss_fn(cls_label_input, cls_label_target)
return loss
# @TODO: check
# class FocalLossMultiLabel(_Loss):
# """
# Compute focal loss for multi-label problem.
# Ignores targets having -1 label
# """
#
# def forward(self, logits, targets):
# """
# Args:
# logits: [bs; num_classes]
# targets: [bs; num_classes]
# """
# num_classes = logits.size(1)
# loss = 0
#
# for cls in range(num_classes):
# # Filter anchors with -1 label from loss computation
# if cls == self.ignore:
# continue
#
# cls_label_target = targets[..., cls].long()
# cls_label_input = logits[..., cls]
#
# loss += self.loss_fn(cls_label_input, cls_label_target)
#
# return loss
__all__ = ["FocalLossBinary", "FocalLossMultiClass"]