# Source code for catalyst.contrib.nn.criterion.huber
import torch
import torch.nn as nn
class HuberLoss(nn.Module):
    """Huber loss with a configurable clipping threshold.

    For an element-wise error ``e = y_true - y_pred`` the loss is quadratic
    (``0.5 * e ** 2``) while ``|e| <= clip_delta`` and grows linearly with
    slope ``clip_delta`` beyond that point. The element-wise loss is first
    averaged over dimension 1 and the resulting values are then combined
    according to ``reduction``.
    """

    def __init__(self, clip_delta: float = 1.0, reduction="mean"):
        """Store the loss configuration.

        Args:
            clip_delta: absolute error at which the loss switches from
                quadratic to linear.
            reduction: ``"mean"`` or ``"sum"`` to collapse the values
                obtained after averaging over dimension 1 into a scalar.
                Any other value leaves them unreduced; a falsy value
                (e.g. ``None``) is stored as ``"none"``.
        """
        super().__init__()
        self.clip_delta = clip_delta
        self.reduction = reduction or "none"

    def forward(
        self, y_pred: torch.Tensor, y_true: torch.Tensor, weights=None
    ) -> torch.Tensor:
        """Compute the Huber loss between predictions and targets.

        Args:
            y_pred: predicted values; must have at least two dimensions
                because the loss is averaged over dimension 1.
            y_true: target values, broadcastable against ``y_pred``.
            weights: optional tensor multiplied element-wise with the
                loss before the mean over dimension 1 is taken.

        Returns:
            A scalar tensor when ``reduction`` is ``"mean"`` or ``"sum"``;
            otherwise the loss with dimension 1 averaged out.
        """
        td_error = y_true - y_pred
        td_error_abs = torch.abs(td_error)

        # Split |error| into the part below the threshold (penalised
        # quadratically) and the excess above it (penalised linearly).
        quadratic_part = torch.clamp(td_error_abs, max=self.clip_delta)
        linear_part = td_error_abs - quadratic_part
        loss = 0.5 * quadratic_part ** 2 + self.clip_delta * linear_part

        if weights is not None:
            loss = torch.mean(loss * weights, dim=1)
        else:
            loss = torch.mean(loss, dim=1)

        if self.reduction == "mean":
            loss = torch.mean(loss)
        elif self.reduction == "sum":
            loss = torch.sum(loss)

        return loss
# Names exported by ``from <module> import *``.
__all__ = ["HuberLoss"]