# Source code for catalyst.contrib.nn.criterion.huber
import torch
import torch.nn as nn
class HuberLoss(nn.Module):
    """Huber loss with a configurable clipping threshold.

    For each element the error ``e = y_true - y_pred`` is penalised as::

        0.5 * e ** 2                              if |e| <= clip_delta
        clip_delta * (|e| - 0.5 * clip_delta)     otherwise

    i.e. quadratically for small errors and linearly (slope ``clip_delta``)
    for large ones.

    Args:
        clip_delta: threshold at which the penalty switches from quadratic
            to linear.
        reduction: ``"mean"`` or ``"sum"`` to reduce the per-sample losses
            to a scalar; ``"none"`` (or any falsy value such as ``None``)
            to return the per-sample losses unreduced.

    Raises:
        ValueError: if ``reduction`` is not one of the values above.
    """

    _REDUCTIONS = ("none", "mean", "sum")

    def __init__(self, clip_delta: float = 1.0, reduction="mean"):
        super().__init__()
        self.clip_delta = clip_delta
        # Falsy values (e.g. None) are accepted as an alias for "none".
        self.reduction = reduction or "none"
        if self.reduction not in self._REDUCTIONS:
            # An unknown value used to fall through to "none" silently,
            # hiding typos such as "avg"; fail loudly instead.
            raise ValueError(
                f"reduction must be one of {self._REDUCTIONS} or None, "
                f"got {reduction!r}"
            )

    def forward(
        self, y_pred: torch.Tensor, y_true: torch.Tensor, weights=None
    ) -> torch.Tensor:
        """Compute the Huber loss between predictions and targets.

        Args:
            y_pred: predicted values.
            y_true: target values; must broadcast against ``y_pred``.
            weights: optional tensor multiplied element-wise into the loss
                before the per-sample averaging (must broadcast against the
                element-wise loss).

        Returns:
            A scalar tensor for ``"mean"``/``"sum"`` reduction, otherwise the
            per-sample losses. When the element-wise loss has two or more
            dimensions, the per-sample loss is its mean over dim 1 (dim 0 is
            presumably the batch axis -- confirm with callers). A 1-D or 0-D
            loss is already per-sample and is used as-is.
        """
        td_error_abs = torch.abs(y_true - y_pred)
        # Split |e| into the part up to clip_delta (penalised quadratically)
        # and the excess above it (penalised linearly). Summing both terms
        # yields the piecewise Huber formula without an explicit branch.
        quadratic_part = torch.clamp(td_error_abs, max=self.clip_delta)
        linear_part = td_error_abs - quadratic_part
        loss = 0.5 * quadratic_part ** 2 + self.clip_delta * linear_part

        if weights is not None:
            loss = loss * weights

        # A 1-D (or 0-D) loss has no axis 1 to average over; treat each
        # element as its own sample instead of raising IndexError.
        if loss.dim() > 1:
            loss = torch.mean(loss, dim=1)

        if self.reduction == "mean":
            loss = torch.mean(loss)
        elif self.reduction == "sum":
            loss = torch.sum(loss)
        return loss
# Names exported by `from ... import *`.
__all__ = [
    "HuberLoss",
]