# Shortcuts
#
# Source code for catalyst.utils.metrics.focal

"""
Focal losses:
    * :func:`sigmoid_focal_loss`
    * :func:`reduced_focal_loss`
"""

import torch
import torch.nn.functional as F


def sigmoid_focal_loss(
    outputs: torch.Tensor,
    targets: torch.Tensor,
    gamma: float = 2.0,
    alpha: float = 0.25,
    reduction: str = "mean",
):
    """Compute binary focal loss between target and output logits.

    The element-wise loss is ``-(1 - pt) ** gamma * log(pt)``, where
    ``log(pt)`` is the negated element-wise binary cross-entropy (for 0/1
    targets, the log-probability assigned to the target class). When
    ``alpha`` is given, positive targets are additionally weighted by
    ``alpha`` and negative targets by ``1 - alpha``.

    Args:
        outputs: tensor of raw (pre-sigmoid) logits of arbitrary shape
        targets: tensor of the same shape as ``outputs``; it is cast to
            the tensor type of ``outputs`` before use
        gamma: focusing exponent applied to ``(1 - pt)``
        alpha: weight for positive targets (``1 - alpha`` is used for
            negative targets); pass ``None`` to disable this weighting
        reduction (string, optional): specifies the reduction to apply
            to the output:
            ``"none"`` | ``"mean"`` | ``"sum"`` | ``"batchwise_mean"``.
            ``"none"``: no reduction will be applied,
            ``"mean"``: the sum of the output will be divided by the
            number of elements in the output,
            ``"sum"``: the output will be summed,
            ``"batchwise_mean"``: the output is summed over dimension 0
            only (note: despite the name this is a sum, not a mean).
            Default: ``"mean"``

    Returns:
        torch.Tensor: the focal loss, reduced according to ``reduction``

    Source: https://github.com/BloodAxe/pytorch-toolbelt
    """
    targets = targets.type(outputs.type())

    # negated element-wise BCE == log(pt)
    logpt = -F.binary_cross_entropy_with_logits(
        outputs, targets, reduction="none"
    )
    pt = torch.exp(logpt)

    # compute the loss: (1 - pt) ** gamma down-weights confident elements
    loss = -((1 - pt).pow(gamma)) * logpt

    if alpha is not None:
        loss = loss * (alpha * targets + (1 - alpha) * (1 - targets))

    if reduction == "mean":
        loss = loss.mean()
    elif reduction == "sum":
        loss = loss.sum()
    elif reduction == "batchwise_mean":
        loss = loss.sum(0)

    return loss
def reduced_focal_loss(
    outputs: torch.Tensor,
    targets: torch.Tensor,
    threshold: float = 0.5,
    gamma: float = 2.0,
    reduction: str = "mean",
):
    r"""Compute reduced focal loss between target and output logits.

    It has been proposed in `Reduced Focal Loss\: 1st Place Solution to xView
    object detection in Satellite Imagery`_ paper.

    The element-wise loss is ``-w * log(pt)``, where ``log(pt)`` is the
    negated element-wise binary cross-entropy. The weight ``w`` is ``1``
    where ``pt < threshold`` and ``((1 - pt) / threshold) ** gamma``
    everywhere else.

    Args:
        outputs: tensor of raw (pre-sigmoid) logits of arbitrary shape
        targets: tensor of the same shape as ``outputs``; it is cast to
            the tensor type of ``outputs`` before use
        threshold: elements with ``pt`` below this value keep weight 1
        gamma: exponent of the reduction factor applied to elements with
            ``pt >= threshold``
        reduction (string, optional): specifies the reduction to apply
            to the output:
            ``"none"`` | ``"mean"`` | ``"sum"`` | ``"batchwise_mean"``.
            ``"none"``: no reduction will be applied,
            ``"mean"``: the sum of the output will be divided by the
            number of elements in the output,
            ``"sum"``: the output will be summed,
            ``"batchwise_mean"``: the output is summed over dimension 0
            only (note: despite the name this is a sum, not a mean).
            Default: ``"mean"``

    Returns:
        torch.Tensor: the reduced focal loss, reduced according to
        ``reduction``

    Source: https://github.com/BloodAxe/pytorch-toolbelt

    .. _Reduced Focal Loss\: 1st Place Solution to xView object detection in Satellite Imagery: https://arxiv.org/abs/1903.01347
    """
    targets = targets.type(outputs.type())

    # negated element-wise BCE == log(pt)
    logpt = -F.binary_cross_entropy_with_logits(
        outputs, targets, reduction="none"
    )
    pt = torch.exp(logpt)

    # compute the loss: hard elements (pt < threshold) keep weight 1, the
    # rest are scaled down; out-of-place masked_fill avoids mutating a tensor
    # that is part of the autograd graph
    focal_reduction = ((1.0 - pt) / threshold).pow(gamma)
    focal_reduction = focal_reduction.masked_fill(pt < threshold, 1.0)
    loss = -focal_reduction * logpt

    if reduction == "mean":
        loss = loss.mean()
    elif reduction == "sum":
        loss = loss.sum()
    elif reduction == "batchwise_mean":
        loss = loss.sum(0)

    return loss
# Public API exported by ``from ... import *``.
__all__ = [
    "sigmoid_focal_loss",
    "reduced_focal_loss",
]