Source code for catalyst.contrib.nn.criterion.ce
# flake8: noqa
# TODO: update docs and shapes
import torch
from torch import nn
from torch.nn import functional as F
class NaiveCrossEntropyLoss(nn.Module):
    """Cross-entropy loss for soft targets.

    Expects ``target`` to be a probability distribution over classes
    with the same shape as ``input_``.
    """

    def __init__(self, size_average: bool = True):
        """
        Args:
            size_average: if ``True``, the loss is averaged over the batch,
                otherwise it is summed
        """
        super().__init__()
        self.size_average = size_average
    def forward(
        self, input_: torch.Tensor, target: torch.Tensor
    ) -> torch.Tensor:
        """Calculates loss between ``input_`` and ``target`` tensors.

        Args:
            input_: logits tensor of shape (batch_size, num_classes)
            target: soft target tensor of the same shape as ``input_``

        Returns:
            torch.Tensor: computed loss
        """
        assert input_.size() == target.size()
        input_ = F.log_softmax(input_, dim=1)
        loss = -torch.sum(input_ * target)
        loss = loss / input_.size(0) if self.size_average else loss
        return loss
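
# Illustrative usage sketch (not part of the original module): it shows how
# NaiveCrossEntropyLoss can be fed logits together with soft targets; the
# function name and tensor values are arbitrary placeholders.
def _naive_cross_entropy_example():
    batch_size, num_classes = 4, 3
    logits = torch.randn(batch_size, num_classes)
    # soft targets: a probability distribution over classes for each sample
    soft_targets = torch.softmax(torch.randn(batch_size, num_classes), dim=1)
    criterion = NaiveCrossEntropyLoss(size_average=True)
    return criterion(logits, soft_targets)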

class SymmetricCrossEntropyLoss(nn.Module):
    """The Symmetric Cross Entropy loss.

    It has been proposed in `Symmetric Cross Entropy for Robust Learning
    with Noisy Labels`_.

    .. _Symmetric Cross Entropy for Robust Learning with Noisy Labels:
        https://arxiv.org/abs/1908.06112
    """
    def __init__(self, alpha: float = 1.0, beta: float = 1.0):
        """
        Args:
            alpha: weight of the cross-entropy term,
                which addresses the overfitting issue of CE
            beta: weight of the reverse cross-entropy term,
                which controls the flexible exploration
                on the robustness of RCE
        """
        super().__init__()
        self.alpha = alpha
        self.beta = beta
    def forward(
        self, input_: torch.Tensor, target: torch.Tensor
    ) -> torch.Tensor:
        """Calculates loss between ``input_`` and ``target`` tensors.

        Args:
            input_: input tensor of size (batch_size, num_classes);
                class probabilities are expected, since values are
                clamped to ``[1e-7, 1]`` before taking the logarithm
            target: target tensor of size (batch_size), where
                values of a vector correspond to class index

        Returns:
            torch.Tensor: computed loss
        """
        num_classes = input_.shape[1]
        target_one_hot = F.one_hot(target, num_classes).float()
        assert target_one_hot.shape == input_.shape

        # clamp to avoid log(0) in both cross-entropy terms
        input_ = torch.clamp(input_, min=1e-7, max=1.0)
        target_one_hot = torch.clamp(target_one_hot, min=1e-4, max=1.0)

        cross_entropy = (
            -torch.sum(target_one_hot * torch.log(input_), dim=1)
        ).mean()
        reverse_cross_entropy = (
            -torch.sum(input_ * torch.log(target_one_hot), dim=1)
        ).mean()
        loss = self.alpha * cross_entropy + self.beta * reverse_cross_entropy
        return loss
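
# Illustrative usage sketch (not part of the original module): since the
# forward pass clamps ``input_`` to [1e-7, 1] before taking the logarithm,
# the example assumes class probabilities (softmax outputs) as input; the
# function name and tensor values are arbitrary placeholders.
def _symmetric_cross_entropy_example():
    batch_size, num_classes = 4, 3
    probabilities = torch.softmax(torch.randn(batch_size, num_classes), dim=1)
    targets = torch.randint(0, num_classes, (batch_size,))
    criterion = SymmetricCrossEntropyLoss(alpha=1.0, beta=1.0)
    return criterion(probabilities, targets)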

class MaskCrossEntropyLoss(nn.Module):
    """Cross-entropy loss computed only over elements where ``mask == 1``."""

    def __init__(self, *args, **kwargs):
        """
        Args:
            *args: positional arguments passed to ``nn.CrossEntropyLoss``
            **kwargs: keyword arguments passed to ``nn.CrossEntropyLoss``
        """
        super().__init__()
        self.ce_loss = nn.CrossEntropyLoss(*args, **kwargs, reduction="none")
    def forward(
        self, logits: torch.Tensor, target: torch.Tensor, mask: torch.Tensor,
    ) -> torch.Tensor:
        """Calculates loss between ``logits`` and ``target`` tensors.

        Args:
            logits: model logits
            target: true targets
            mask: targets mask; only positions where ``mask == 1``
                contribute to the loss

        Returns:
            torch.Tensor: computed loss
        """
        loss = self.ce_loss(logits, target)
        loss = torch.mean(loss[mask == 1])
        return loss
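
# Illustrative usage sketch (not part of the original module): it shows how
# the mask restricts the averaged cross-entropy to selected samples; the
# function name and tensor values are arbitrary placeholders.
def _mask_cross_entropy_example():
    batch_size, num_classes = 4, 3
    logits = torch.randn(batch_size, num_classes)
    targets = torch.randint(0, num_classes, (batch_size,))
    # only the first two samples contribute to the loss
    mask = torch.tensor([1, 1, 0, 0])
    criterion = MaskCrossEntropyLoss()
    return criterion(logits, targets, mask)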

__all__ = [
    "MaskCrossEntropyLoss",
    "SymmetricCrossEntropyLoss",
    "NaiveCrossEntropyLoss",
]