"""Wing loss criterion (catalyst.contrib.criterion.wing)."""
from functools import partial
import math
import torch
import torch.nn as nn
def wing_loss(
    outputs: torch.Tensor,
    targets: torch.Tensor,
    width: int = 5,
    curvature: float = 0.5,
    reduction: str = "mean"
) -> torch.Tensor:
    """Compute the Wing loss between ``outputs`` and ``targets``.

    For the absolute error ``x = |targets - outputs|`` the element-wise
    loss is::

        wing(x) = width * ln(1 + x / curvature)    if x < width
        wing(x) = x - C                            otherwise

    where ``C = width - width * ln(1 + width / curvature)`` makes the two
    pieces meet continuously at ``x == width``.

    https://arxiv.org/pdf/1711.06753.pdf
    Source https://github.com/BloodAxe/pytorch-toolbelt
    See :class:`~pytorch_toolbelt.losses` for details.

    Args:
        outputs: predicted values; must be broadcastable with ``targets``.
        targets: ground-truth values.
        width: threshold on the absolute error that separates the
            logarithmic region (``x < width``) from the linear region.
        curvature: controls how sharply the logarithmic region bends;
            smaller values give a steeper curve near zero.
        reduction: ``"mean"`` or ``"sum"`` reduce the element-wise loss to
            a scalar; any other value (e.g. ``"none"``) returns the
            unreduced element-wise loss.

    Returns:
        A scalar tensor for ``"mean"``/``"sum"``, otherwise a tensor with
        the broadcast shape of ``targets - outputs``.
    """
    diff_abs = (targets - outputs).abs()
    # Shift that makes the linear branch continue the log branch at x == width.
    offset = width - width * math.log1p(width / curvature)
    # torch.where instead of masked in-place writes: both branches are finite
    # for x >= 0, so the unselected branch cannot leak NaNs into gradients.
    loss = torch.where(
        diff_abs < width,
        width * torch.log1p(diff_abs / curvature),
        diff_abs - offset,
    )
    if reduction == "sum":
        return loss.sum()
    if reduction == "mean":
        return loss.mean()
    # Any other reduction value yields the element-wise loss unchanged.
    return loss
class WingLoss(nn.Module):
    """``nn.Module`` wrapper around :func:`wing_loss`.

    Args:
        width: threshold on the absolute error separating the logarithmic
            and linear regions of the loss.
        curvature: curvature of the logarithmic region.
        reduction: ``"mean"`` or ``"sum"``; any other value makes
            :meth:`forward` return the element-wise loss.
    """

    def __init__(
        self, width: int = 5, curvature: float = 0.5, reduction: str = "mean"
    ):
        super().__init__()
        # Bind the hyper-parameters once so ``forward`` only receives tensors.
        self.loss_fn = partial(
            wing_loss, width=width, curvature=curvature, reduction=reduction
        )

    def forward(
        self, outputs: torch.Tensor, targets: torch.Tensor
    ) -> torch.Tensor:
        """Compute the Wing loss of ``outputs`` against ``targets``.

        Args:
            outputs: predicted values; broadcastable with ``targets``.
            targets: ground-truth values.

        Returns:
            The loss, reduced according to the ``reduction`` given at
            construction time.
        """
        loss = self.loss_fn(outputs, targets)
        return loss