# Source code for catalyst.contrib.nn.criterion.huber

import torch
import torch.nn as nn


class HuberLoss(nn.Module):
    """Huber loss with a configurable transition point ``clip_delta``.

    For an absolute error ``a = |y_true - y_pred|`` the element-wise loss is
    ``0.5 * a ** 2`` when ``a <= clip_delta`` and
    ``clip_delta * (a - 0.5 * clip_delta)`` otherwise: quadratic near zero,
    linear in the tails.

    Args:
        clip_delta: absolute error at which the loss switches from the
            quadratic to the linear regime.
        reduction: ``"mean"`` or ``"sum"`` reduce the per-sample losses to a
            scalar; any other value (``"none"``, ``None``, ...) returns the
            per-sample losses unreduced.
    """

    def __init__(self, clip_delta=1.0, reduction="mean"):
        super().__init__()
        self.clip_delta = clip_delta
        # Normalize ``None`` to "none" so both spellings mean "no reduction".
        self.reduction = reduction or "none"

    def forward(self, y_pred, y_true, weights=None):
        """Compute the Huber loss between ``y_pred`` and ``y_true``.

        Args:
            y_pred: predictions; must broadcast against ``y_true``.
            y_true: targets.
            weights: optional multiplier applied element-wise to the loss
                before per-sample averaging; must broadcast against it.

        Returns:
            A scalar tensor for ``"mean"`` / ``"sum"`` reduction, otherwise
            the per-sample losses. When the element-wise loss has two or more
            dimensions, the per-sample loss is its mean over dim 1 (presumably
            the feature axis -- confirm with callers); for 0-D/1-D inputs each
            element is treated as one sample.
        """
        abs_error = torch.abs(y_true - y_pred)
        # Split |error| into the portion inside [0, clip_delta] (penalized
        # quadratically) and the excess beyond it (penalized linearly).
        quadratic_part = torch.clamp(abs_error, max=self.clip_delta)
        linear_part = abs_error - quadratic_part
        loss = 0.5 * quadratic_part ** 2 + self.clip_delta * linear_part

        if weights is not None:
            loss = loss * weights
        # Average over dim 1 only when it exists, so inputs holding a single
        # value per sample (0-D/1-D) work instead of raising on ``dim=1``.
        if loss.dim() >= 2:
            loss = torch.mean(loss, dim=1)

        if self.reduction == "mean":
            loss = torch.mean(loss)
        elif self.reduction == "sum":
            loss = torch.sum(loss)
        return loss
# Public API of this module: only the loss class is exported by ``import *``.
__all__ = [
    "HuberLoss",
]