Source code for catalyst.contrib.criterion.wing
from functools import partial
import math
import torch
import torch.nn as nn
def wing_loss(
    outputs: torch.Tensor,
    targets: torch.Tensor,
    width: int = 5,
    curvature: float = 0.5,
    reduction: str = "mean"
):
    """Compute the Wing loss between ``outputs`` and ``targets``.

    With ``d = |targets - outputs|`` the element-wise loss is::

        width * log(1 + d / curvature)    if d < width
        d - C                             otherwise

    where ``C = width - width * log(1 + width / curvature)`` makes the two
    pieces meet at ``d == width``.

    Paper: https://arxiv.org/pdf/1711.06753.pdf
    Source: https://github.com/BloodAxe/pytorch-toolbelt
    See :class:`~pytorch_toolbelt.losses` for details.

    Args:
        outputs: predicted values.
        targets: ground-truth values; subtracted from ``outputs``
            element-wise, so the two must be broadcast-compatible.
        width: threshold on the absolute error that separates the
            logarithmic piece from the linear piece.
        curvature: divisor applied to the absolute error inside the
            logarithm.
        reduction: ``"sum"`` or ``"mean"`` reduce the loss to a scalar;
            any other value returns the element-wise loss unreduced.

    Returns:
        The reduced (scalar) or element-wise loss tensor.
    """
    diff_abs = (targets - outputs).abs()

    # Constant that shifts the linear piece so both pieces agree at
    # ``diff_abs == width``.
    offset = width - width * math.log1p(width / curvature)

    # ``log1p`` keeps precision for small residuals; ``torch.where`` selects
    # the piece per element without in-place masked writes.
    loss = torch.where(
        diff_abs < width,
        width * torch.log1p(diff_abs / curvature),
        diff_abs - offset,
    )

    if reduction == "sum":
        return loss.sum()
    if reduction == "mean":
        return loss.mean()
    return loss
class WingLoss(nn.Module):
    """Module wrapper around :func:`wing_loss`.

    The constructor arguments are bound once to :func:`wing_loss` and the
    resulting callable is stored as ``self.loss_fn``.

    Args:
        width: forwarded to :func:`wing_loss` as ``width``.
        curvature: forwarded to :func:`wing_loss` as ``curvature``.
        reduction: forwarded to :func:`wing_loss` as ``reduction``.
    """

    def __init__(
        self, width: int = 5, curvature: float = 0.5, reduction: str = "mean"
    ):
        super().__init__()
        # Pre-bind the hyper-parameters so ``forward`` only passes tensors.
        self.loss_fn = partial(
            wing_loss, width=width, curvature=curvature, reduction=reduction
        )

    def forward(
        self, outputs: torch.Tensor, targets: torch.Tensor
    ) -> torch.Tensor:
        """Return ``self.loss_fn(outputs, targets)``.

        Args:
            outputs: predicted values.
            targets: ground-truth values.

        Returns:
            Whatever the bound loss function returns for the given tensors.
        """
        return self.loss_fn(outputs, targets)