Source code for catalyst.dl.utils.criterion.focal

import torch
import torch.nn.functional as F


[docs]def sigmoid_focal_loss( outputs: torch.Tensor, targets: torch.Tensor, gamma: float = 2.0, alpha: float = 0.25, reduction: str = "mean" ): """ Compute binary focal loss between target and output logits. Source https://github.com/BloodAxe/pytorch-toolbelt See :class:`~pytorch_toolbelt.losses` for details. Args: outputs: Tensor of arbitrary shape targets: Tensor of the same shape as input reduction (string, optional): Specifies the reduction to apply to the output: "none" | "mean" | "sum" | "batchwise_mean". "none": no reduction will be applied, "mean": the sum of the output will be divided by the number of elements in the output, "sum": the output will be summed. See https://github.com/open-mmlab/mmdetection/blob/master/mmdet/core/loss/losses.py # noqa: E501 """ targets = targets.type(outputs.type()) logpt = -F.binary_cross_entropy_with_logits( outputs, targets, reduction="none" ) pt = torch.exp(logpt) # compute the loss loss = -((1 - pt).pow(gamma)) * logpt if alpha is not None: loss = loss * (alpha * targets + (1 - alpha) * (1 - targets)) if reduction == "mean": loss = loss.mean() if reduction == "sum": loss = loss.sum() if reduction == "batchwise_mean": loss = loss.sum(0) return loss
[docs]def reduced_focal_loss( outputs: torch.Tensor, targets: torch.Tensor, threshold: float = 0.5, gamma: float = 2.0, reduction="mean" ): """ Compute reduced focal loss between target and output logits. Source https://github.com/BloodAxe/pytorch-toolbelt See :class:`~pytorch_toolbelt.losses` for details. Args: outputs: Tensor of arbitrary shape targets: Tensor of the same shape as input reduction (string, optional): Specifies the reduction to apply to the output: "none" | "mean" | "sum" | "batchwise_mean". "none": no reduction will be applied, "mean": the sum of the output will be divided by the number of elements in the output, "sum": the output will be summed. Note: :attr:`size_average` and :attr:`reduce` are in the process of being deprecated, and in the meantime, specifying either of those two args will override :attr:`reduction`. "batchwise_mean" computes mean loss per sample in batch. Default: "mean" See https://arxiv.org/abs/1903.01347 """ targets = targets.type(outputs.type()) logpt = -F.binary_cross_entropy_with_logits( outputs, targets, reduction="none" ) pt = torch.exp(logpt) # compute the loss focal_reduction = ((1. - pt) / threshold).pow(gamma) focal_reduction[pt < threshold] = 1 loss = -focal_reduction * logpt if reduction == "mean": loss = loss.mean() if reduction == "sum": loss = loss.sum() if reduction == "batchwise_mean": loss = loss.sum(0) return loss
__all__ = ["sigmoid_focal_loss", "reduced_focal_loss"]