# Source code for catalyst.contrib.nn.criterion.huber

# flake8: noqa
# @TODO: code formatting issue for 20.07 release
import torch
from torch import nn


class HuberLoss(nn.Module):
    """Huber loss with a configurable quadratic-to-linear transition point.

    For every element the absolute error ``|y_true - y_pred|`` is split into a
    quadratic part (clipped at ``clip_delta``) and a linear remainder::

        q    = min(|e|, clip_delta)
        loss = 0.5 * q ** 2 + clip_delta * (|e| - q)

    Element losses are optionally multiplied by ``weights``, averaged over
    ``dim=1`` (when the input has more than one dimension) to give one value
    per row, and finally reduced according to ``reduction``.
    """

    # Accepted values for ``reduction`` (after mapping falsy values to "none").
    _REDUCTIONS = ("none", "mean", "sum")

    def __init__(self, clip_delta=1.0, reduction="mean"):
        """
        Args:
            clip_delta: absolute error at which the loss switches from
                quadratic to linear growth.
            reduction: ``"mean"``, ``"sum"`` or ``"none"``. A falsy value
                (e.g. ``None``) is treated as ``"none"``.

        Raises:
            ValueError: if ``reduction`` is not one of the supported values.
        """
        super().__init__()
        self.clip_delta = clip_delta
        self.reduction = reduction or "none"
        # Previously an unknown value silently fell through to "none";
        # fail loudly instead so typos are caught at construction time.
        if self.reduction not in self._REDUCTIONS:
            raise ValueError(
                f"reduction must be one of {self._REDUCTIONS} or None, "
                f"got {reduction!r}"
            )

    def forward(
        self, y_pred: torch.Tensor, y_true: torch.Tensor, weights=None
    ) -> torch.Tensor:
        """Compute the Huber loss between ``y_pred`` and ``y_true``.

        Args:
            y_pred: predicted values.
            y_true: target values (must broadcast against ``y_pred``).
            weights: optional tensor multiplied element-wise into the
                per-element loss before averaging.

        Returns:
            A scalar for ``"mean"`` / ``"sum"`` reduction. For ``"none"``,
            the per-row loss: inputs with more than one dimension are averaged
            over ``dim=1``; 1-D/0-D inputs are returned element-wise.
        """
        td_error = y_true - y_pred
        td_error_abs = torch.abs(td_error)
        quadratic_part = torch.clamp(td_error_abs, max=self.clip_delta)
        linear_part = td_error_abs - quadratic_part
        loss = 0.5 * quadratic_part ** 2 + self.clip_delta * linear_part

        if weights is not None:
            loss = loss * weights
        # Average over the feature axis to get one loss per row. 1-D and 0-D
        # inputs have no such axis; each element already is one sample.
        if loss.dim() > 1:
            loss = torch.mean(loss, dim=1)

        if self.reduction == "mean":
            loss = torch.mean(loss)
        elif self.reduction == "sum":
            loss = torch.sum(loss)
        return loss
# Names exported by ``from catalyst.contrib.nn.criterion.huber import *``.
__all__ = ["HuberLoss"]