# Source code for catalyst.contrib.nn.criterion.lovasz

# Lovasz-Softmax and Jaccard hinge loss in PyTorch
# Maxim Berman 2018 ESAT-PSI KU Leuven (MIT License)

from itertools import filterfalse as ifilterfalse

import torch
import torch.nn.functional as F
from torch.nn.modules.loss import _Loss

# --------------------------- HELPER FUNCTIONS ---------------------------

def isnan(x):
    """Return ``x != x`` — true exactly for NaN values.

    Works for plain Python floats and elementwise for tensors.
    """
    # NaN is the only value that does not compare equal to itself.
    return x != x

def mean(values, ignore_nan=False, empty=0):
    """Average an iterable in a single pass (works with one-shot generators).

    Args:
        values: iterable of numbers or tensors to average.
        ignore_nan: when True, elements for which ``isnan`` is true are
            skipped before averaging.
        empty: value returned when nothing is left to average; pass the
            string ``"raise"`` to get a ``ValueError`` instead.

    Returns:
        The arithmetic mean. A single element is returned as-is, without
        any division.

    Raises:
        ValueError: if the input is empty and ``empty == "raise"``.
    """
    stream = iter(values)
    if ignore_nan:
        stream = ifilterfalse(isnan, stream)

    try:
        total = next(stream)
    except StopIteration:
        if empty == "raise":
            raise ValueError("Empty mean")
        return empty

    count = 1
    for item in stream:
        total += item
        count += 1
    return total if count == 1 else total / count

def _lovasz_grad(gt_sorted):
    """Compute the gradient of the Lovasz extension w.r.t. sorted errors.

    See Alg. 1 in the paper (https://arxiv.org/abs/1705.08790).

    Args:
        gt_sorted: [P] ground-truth indicators (0/1) ordered like the errors
            sorted in descending order. Bool, integer and floating tensors
            are accepted.

    Returns:
        [P] floating-point tensor ``g`` with ``g[0] = J[0]`` and
        ``g[i] = J[i] - J[i - 1]``, where
        ``J = 1 - (gts - cumsum(gt)) / (gts + cumsum(1 - gt))`` and
        ``gts`` is the number of positives.
    """
    p = len(gt_sorted)
    # Cast once up front: PyTorch rejects ``1 - bool_tensor``, so bool masks
    # (e.g. ``targets > 0``) used to crash in the ``union`` computation.
    gt_float = gt_sorted.float()
    gts = gt_sorted.sum()
    intersection = gts - gt_float.cumsum(0)
    union = gts + (1.0 - gt_float).cumsum(0)
    jaccard = 1.0 - intersection / union
    if p > 1:  # cover 1-pixel case
        # RHS is evaluated into a new tensor before the slice assignment.
        jaccard[1:p] = jaccard[1:p] - jaccard[0:-1]
    return jaccard

# ---------------------------- BINARY LOSSES -----------------------------

def _flatten_binary_scores(logits, targets, ignore=None):
    """Flatten binary predictions and targets over the whole batch.

    Args:
        logits: tensor of any shape with one score per element.
        targets: tensor holding the matching ground-truth values.
        ignore: optional target value; positions whose target equals it are
            dropped from both outputs.

    Returns:
        tuple ``(logits, targets)`` of 1-D tensors.
    """
    flat_logits, flat_targets = logits.reshape(-1), targets.reshape(-1)
    if ignore is not None:
        keep = flat_targets != ignore
        flat_logits, flat_targets = flat_logits[keep], flat_targets[keep]
    return flat_logits, flat_targets

def _lovasz_hinge_flat(logits, targets):
    """Binary Lovasz hinge loss on already-flattened inputs.

    Args:
        logits: [P] tensor of scores at each prediction
            (between -infinity and +infinity).
        targets: [P] tensor of binary ground truth targets (0 or 1).

    Returns:
        0-d tensor with the loss.
    """
    if len(targets) == 0:
        # Only void pixels: return a zero that still depends on ``logits``
        # so the gradients exist and are 0.
        return logits.sum() * 0.0

    # Map {0, 1} targets to {-1, +1} signs and build hinge margins.
    margins = 1.0 - logits * (targets.float() * 2.0 - 1.0)
    margins_sorted, order = torch.sort(margins, dim=0, descending=True)
    grad = _lovasz_grad(targets[order])
    return torch.dot(F.relu(margins_sorted), grad)

def _lovasz_hinge(logits, targets, per_image=True, ignore=None):
    """Binary Lovasz hinge loss for a batch.

    Args:
        logits: [B, H, W] scores at each pixel
            (between -infinity and +infinity).
        targets: [B, H, W] binary ground truth masks (0 or 1).
        per_image: compute the loss per image and average, instead of
            computing it once over the whole batch.
        ignore: void class id; pixels with this target are excluded.

    Returns:
        The loss value.
    """
    if not per_image:
        flat_logits, flat_targets = _flatten_binary_scores(
            logits, targets, ignore
        )
        return _lovasz_hinge_flat(flat_logits, flat_targets)

    def _image_loss(logit, target):
        # Re-add the batch axis so each image is flattened on its own.
        flat_logit, flat_target = _flatten_binary_scores(
            logit.unsqueeze(0), target.unsqueeze(0), ignore
        )
        return _lovasz_hinge_flat(flat_logit, flat_target)

    return mean(
        _image_loss(logit, target) for logit, target in zip(logits, targets)
    )

# --------------------------- MULTICLASS LOSSES ---------------------------

def _flatten_probabilities(probabilities, targets, ignore=None):
    """Flatten class probabilities and targets over the batch.

    Args:
        probabilities: [B, C, H, W] class probabilities, or [B, H, W]
            (treated as single-channel, e.g. sigmoid, output).
        targets: [B, H, W] ground-truth class indices.
        ignore: optional target value; pixels with this target are dropped
            from both outputs.

    Returns:
        tuple ``(probabilities, targets)`` of shapes [P, C] and [P], where
        ``P`` is the number of kept pixels.
    """
    if probabilities.dim() == 3:
        # assumes output of a sigmoid layer: add a singleton class axis.
        # ``unsqueeze`` (unlike ``view``) also works on non-contiguous input.
        probabilities = probabilities.unsqueeze(1)
    _, num_classes, _, _ = probabilities.size()
    # [B, C, H, W] -> [B, H, W, C] -> [B * H * W, C] = [P, C]
    probabilities = probabilities.permute(0, 2, 3, 1).reshape(-1, num_classes)
    targets = targets.reshape(-1)
    if ignore is None:
        return probabilities, targets
    # Boolean-mask indexing always keeps the [P, C] layout. The former
    # ``valid.nonzero().squeeze()`` became a 0-d index when exactly one pixel
    # was valid and returned a [C] vector instead of [1, C].
    valid = targets != ignore
    return probabilities[valid], targets[valid]

def _lovasz_softmax_flat(probabilities, targets, classes="present"):
    """Multi-class Lovasz-Softmax loss on already-flattened inputs.

    Args:
        probabilities: [P, C]
            class probabilities at each prediction (between 0 and 1)
        targets: [P] ground truth targets (between 0 and C - 1)
        classes: "all" for all,
            "present" for classes present in targets,
            or a list of classes to average.

    Returns:
        Mean of the per-class losses. If there are no predictions, or no
        class contributes a loss, a 0-d zero tensor that stays connected to
        ``probabilities`` is returned instead, so ``.backward()`` still works.

    Raises:
        ValueError: if ``probabilities`` has a single channel and a list with
            more than one class is requested.
    """
    if probabilities.numel() == 0:
        # only void pixels, the gradients should be 0. Reduce to a 0-d tensor
        # (as ``_lovasz_hinge_flat`` does); ``probabilities * 0.0`` returned
        # an empty [0, C] tensor rather than a scalar loss.
        return probabilities.sum() * 0.0
    C = probabilities.size(1)
    class_to_sum = list(range(C)) if classes in ["all", "present"] else classes
    losses = []
    for c in class_to_sum:
        fg = (targets == c).float()  # foreground for class c
        if classes == "present" and fg.sum() == 0:
            continue
        if C == 1:
            if len(class_to_sum) > 1:
                raise ValueError("Sigmoid output possible only with 1 class")
            # NOTE(review): with a single channel ``c`` is 0 for "all" /
            # "present", so ``fg`` marks pixels whose target is 0 — confirm
            # this matches what the sigmoid probability is meant to predict.
            class_pred = probabilities[:, 0]
        else:
            class_pred = probabilities[:, c]
        errors = (fg - class_pred).abs()
        errors_sorted, perm = torch.sort(errors, 0, descending=True)
        losses.append(torch.dot(errors_sorted, _lovasz_grad(fg[perm])))
    if not losses:
        # No class contributed (e.g. none of the requested classes is
        # present); return a tensor zero instead of ``mean([])``'s int 0.
        return probabilities.sum() * 0.0
    return mean(losses)

def _lovasz_softmax(
    probabilities, targets, classes="present", per_image=False, ignore=None
):
    """Multi-class Lovasz-Softmax loss for a batch.

    Args:
        probabilities: [B, C, H, W] class probabilities at each prediction
            (between 0 and 1). A [B, H, W] input is interpreted as binary
            (sigmoid) output.
        targets: [B, H, W] ground truth targets (between 0 and C - 1).
        classes: "all" for all, "present" for classes present in targets,
            or a list of classes to average.
        per_image: compute the loss per image and average, instead of
            computing it once over the whole batch.
        ignore: void class target value, excluded from the loss.

    Returns:
        The loss value.
    """
    if not per_image:
        flat_probs, flat_targets = _flatten_probabilities(
            probabilities, targets, ignore
        )
        return _lovasz_softmax_flat(flat_probs, flat_targets, classes=classes)

    def _image_loss(prob, lab):
        # Re-add the batch axis so each image is flattened on its own.
        flat_prob, flat_lab = _flatten_probabilities(
            prob.unsqueeze(0), lab.unsqueeze(0), ignore
        )
        return _lovasz_softmax_flat(flat_prob, flat_lab, classes=classes)

    return mean(
        _image_loss(prob, lab) for prob, lab in zip(probabilities, targets)
    )

# ------------------------------ CRITERION -------------------------------

class LovaszLossBinary(_Loss):
    r"""Creates a criterion that optimizes a binary Lovasz loss.

    It has been proposed in The Lovasz-Softmax loss: A tractable surrogate
    for the optimization of the intersection-over-union measure
    in neural networks_.

    .. _The Lovasz-Softmax loss\: A tractable surrogate for the optimization
        of the intersection-over-union measure in neural networks:
        https://arxiv.org/abs/1705.08790
    """

    def __init__(self, per_image=False, ignore=None):
        """
        Args:
            per_image: if True, the loss is computed for every image
                separately and averaged; otherwise it is computed once over
                the whole batch.
            ignore: target value marking void pixels that are excluded from
                the loss (``None`` keeps every pixel).
        """
        super().__init__()
        self.ignore = ignore
        self.per_image = per_image

    def forward(self, logits, targets):
        """Compute the binary Lovasz hinge loss.

        Args:
            logits: [bs; ...] scores (between -infinity and +infinity).
            targets: [bs; ...] binary ground truth (0 or 1), same shape as
                ``logits``.

        Returns:
            The loss value.
        """
        loss = _lovasz_hinge(
            logits, targets, per_image=self.per_image, ignore=self.ignore
        )
        return loss

class LovaszLossMultiClass(_Loss):
    r"""Creates a criterion that optimizes a multi-class Lovasz loss.

    It has been proposed in The Lovasz-Softmax loss: A tractable surrogate
    for the optimization of the intersection-over-union measure
    in neural networks_.

    .. _The Lovasz-Softmax loss\: A tractable surrogate for the optimization
        of the intersection-over-union measure in neural networks:
        https://arxiv.org/abs/1705.08790
    """

    def __init__(self, per_image=False, ignore=None, classes="present"):
        """
        Args:
            per_image: if True, the loss is computed for every image
                separately and averaged; otherwise it is computed once over
                the whole batch.
            ignore: target value marking void pixels that are excluded from
                the loss (``None`` keeps every pixel).
            classes: "all" for all classes, "present" for classes present
                in the targets, or a list of classes to average.
        """
        super().__init__()
        self.ignore = ignore
        self.per_image = per_image
        self.classes = classes

    def forward(self, logits, targets):
        """Compute the multi-class Lovasz-Softmax loss.

        Args:
            logits: [bs; num_classes; H; W] tensor passed unchanged to
                ``_lovasz_softmax``, which treats it as class probabilities
                in [0, 1] — apply a softmax beforehand.
            targets: [bs; H; W] ground-truth class indices.

        Returns:
            The loss value.
        """
        loss = _lovasz_softmax(
            logits,
            targets,
            classes=self.classes,
            per_image=self.per_image,
            ignore=self.ignore,
        )
        return loss

class LovaszLossMultiLabel(_Loss):
    r"""Creates a criterion that optimizes a multi-label Lovasz loss.

    It has been proposed in The Lovasz-Softmax loss: A tractable surrogate
    for the optimization of the intersection-over-union measure
    in neural networks_.

    .. _The Lovasz-Softmax loss\: A tractable surrogate for the optimization
        of the intersection-over-union measure in neural networks:
        https://arxiv.org/abs/1705.08790
    """

    def __init__(self, per_image=False, ignore=None):
        """
        Args:
            per_image: if True, each channel's loss is computed for every
                image separately and averaged; otherwise once over the batch.
            ignore: target value marking void pixels that are excluded from
                the loss (``None`` keeps every pixel).
        """
        super().__init__()
        self.ignore = ignore
        self.per_image = per_image

    def forward(self, logits, targets):
        """Compute the multi-label Lovasz hinge loss.

        Every channel ``i`` is scored independently with the binary Lovasz
        hinge loss of ``logits[:, i]`` against ``targets[:, i]``; the
        per-channel losses are then averaged.

        Args:
            logits: [bs; num_classes; ...] scores
                (between -infinity and +infinity).
            targets: [bs; num_classes; ...] binary ground truth (0 or 1).

        Returns:
            0-d tensor with the mean loss over channels.
        """
        losses = [
            _lovasz_hinge(
                logits[:, i, ...],
                targets[:, i, ...],
                per_image=self.per_image,
                ignore=self.ignore,
            )
            for i in range(logits.shape[1])
        ]
        loss = torch.mean(torch.stack(losses))
        return loss

# Public criteria exported by this module; the helpers above stay private.
__all__ = [
    "LovaszLossBinary",
    "LovaszLossMultiClass",
    "LovaszLossMultiLabel",
]