# Source code for catalyst.contrib.nn.criterion.margin
from typing import List, Union
import torch
from torch import nn
from catalyst.contrib.nn.criterion.functional import margin_loss
class MarginLoss(nn.Module):
    """Margin loss criterion.

    Thin ``nn.Module`` wrapper: it stores the loss hyperparameters at
    construction time and forwards them, together with the batch, to the
    functional ``margin_loss`` on every call.
    """

    def __init__(
        self,
        alpha: float = 0.2,
        beta: float = 1.0,
        skip_labels: Union[int, List[int]] = -1,
    ):
        """
        Margin loss constructor.

        Args:
            alpha: value passed to ``margin_loss`` as ``alpha``
                (presumably the margin width -- confirm against the
                functional implementation)
            beta: value passed to ``margin_loss`` as ``beta``
                (presumably the distance boundary between positive and
                negative pairs -- confirm against the functional
                implementation)
            skip_labels (int or List[int]): labels to skip, passed to
                ``margin_loss`` unchanged
        """
        super().__init__()
        # Plain Python attributes (not parameters/buffers): they are
        # hyperparameters and are only read back in ``forward``.
        self.alpha = alpha
        self.beta = beta
        self.skip_labels = skip_labels

    def forward(
        self, embeddings: torch.Tensor, targets: torch.Tensor
    ) -> torch.Tensor:
        """
        Forward method for the margin loss.

        Args:
            embeddings: tensor with embeddings
            targets: tensor with target labels

        Returns:
            computed loss, exactly as returned by ``margin_loss``
        """
        return margin_loss(
            embeddings,
            targets,
            alpha=self.alpha,
            beta=self.beta,
            skip_labels=self.skip_labels,
        )
# Public API of this module: only the criterion class is exported by
# ``from ... import *``.
__all__ = ["MarginLoss"]