Source code for catalyst.contrib.nn.criterion.focal
from functools import partial
from torch.nn.modules.loss import _Loss
from catalyst.utils import metrics
class FocalLossBinary(_Loss):
    """Compute focal loss for binary classification problem.

    It has been proposed in `Focal Loss for Dense Object Detection`_ paper.

    The computation itself is delegated to ``metrics.sigmoid_focal_loss``
    or, when ``reduced=True``, to ``metrics.reduced_focal_loss``; this
    module only flattens the inputs and optionally drops ignored targets.

    @TODO: Docs (add `Example`). Contribution is welcome.

    .. _Focal Loss for Dense Object Detection:
        https://arxiv.org/abs/1708.02002
    """

    def __init__(
        self,
        ignore: int = None,
        reduced: bool = False,
        gamma: float = 2.0,
        alpha: float = 0.25,
        threshold: float = 0.5,
        reduction: str = "mean",
    ):
        """
        Args:
            ignore: target value whose elements are excluded from the loss;
                ``None`` (default) keeps every element
            reduced: if ``True`` use ``metrics.reduced_focal_loss``,
                otherwise ``metrics.sigmoid_focal_loss``
            gamma: forwarded to the selected loss function
            alpha: forwarded to ``sigmoid_focal_loss`` only
                (not used when ``reduced=True``)
            threshold: forwarded to ``reduced_focal_loss`` only
                (not used when ``reduced=False``)
            reduction: forwarded to the selected loss function
        """
        # Hand ``reduction`` to the base class as well, so that
        # ``self.reduction`` matches what the loss function really applies.
        super().__init__(reduction=reduction)
        self.ignore = ignore

        if reduced:
            self.loss_fn = partial(
                metrics.reduced_focal_loss,
                gamma=gamma,
                threshold=threshold,
                reduction=reduction,
            )
        else:
            self.loss_fn = partial(
                metrics.sigmoid_focal_loss,
                gamma=gamma,
                alpha=alpha,
                reduction=reduction,
            )

    def forward(self, logits, targets):
        """Flatten the inputs, drop ignored targets and compute the loss.

        Args:
            logits: [bs; ...]
            targets: [bs; ...]

        Returns:
            whatever the configured loss function returns for the
            flattened (and possibly filtered) ``logits`` and ``targets``
        """
        # ``reshape`` rather than ``view``: it also accepts non-contiguous
        # tensors (e.g. transposed or sliced inputs) instead of raising.
        targets = targets.reshape(-1)
        logits = logits.reshape(-1)

        if self.ignore is not None:
            # Filter predictions with ignore label from loss computation
            not_ignored = targets != self.ignore
            logits = logits[not_ignored]
            targets = targets[not_ignored]

        return self.loss_fn(logits, targets)
class FocalLossMultiClass(FocalLossBinary):
    """Compute focal loss for multi-class problem.

    The loss is the sum of one-vs-rest binary focal losses, one per class.
    Targets equal to ``ignore`` (if it is not ``None``) are excluded.

    It has been proposed in `Focal Loss for Dense Object Detection`_ paper.

    @TODO: Docs (add `Example`). Contribution is welcome.

    .. _Focal Loss for Dense Object Detection:
        https://arxiv.org/abs/1708.02002
    """

    def forward(self, logits, targets):
        """Sum the per-class binary focal losses.

        Args:
            logits: [bs; num_classes; ...]
            targets: [bs; ...]

        Returns:
            sum over classes of the configured loss function applied to
            that class' logits and its one-vs-rest binary targets
        """
        num_classes = logits.size(1)

        if logits.dim() > 2:
            # Move the class axis last before flattening; otherwise the
            # flattened rows would mix class scores with spatial positions
            # and no longer line up with the flattened targets.
            logits = logits.permute(0, *range(2, logits.dim()), 1)
        logits = logits.reshape(-1, num_classes)
        targets = targets.reshape(-1)

        # Filter elements with the ignore label from loss computation
        # (done once here instead of once per class).
        if self.ignore is not None:
            not_ignored = targets != self.ignore
            logits = logits[not_ignored]
            targets = targets[not_ignored]

        loss = 0
        for cls in range(num_classes):
            # One-vs-rest binary target for the current class.
            cls_label_target = (targets == cls).long()
            cls_label_input = logits[..., cls]
            loss += self.loss_fn(cls_label_input, cls_label_target)

        return loss
# @TODO: check
# class FocalLossMultiLabel(_Loss):
#     """Compute focal loss for multi-label problem.
#     Ignores targets having -1 label.
#
#     It has been proposed in `Focal Loss for Dense Object Detection`_ paper.
#
#     @TODO: Docs (add `Example`). Contribution is welcome.
#
#     .. _Focal Loss for Dense Object Detection:
#         https://arxiv.org/abs/1708.02002
#     """
#
#     def forward(self, logits, targets):
#         """
#         Args:
#             logits: [bs; num_classes]
#             targets: [bs; num_classes]
#         """
#         num_classes = logits.size(1)
#         loss = 0
#
#         for cls in range(num_classes):
#             # Filter anchors with -1 label from loss computation
#             if cls == self.ignore:
#                 continue
#
#             cls_label_target = targets[..., cls].long()
#             cls_label_input = logits[..., cls]
#
#             loss += self.loss_fn(cls_label_input, cls_label_target)
#
#         return loss
# Public API of this module: the names exported by ``from ... import *``.
__all__ = ["FocalLossBinary", "FocalLossMultiClass"]