Source code for catalyst.contrib.nn.criterion.margin
from typing import List, Union
import torch.nn as nn
from .functional import margin_loss
class MarginLoss(nn.Module):
    """Module wrapper around the functional ``margin_loss``.

    The hyper-parameters are stored on the instance at construction time
    and handed, unchanged, to ``margin_loss`` on every ``forward`` call.
    The loss computation itself lives entirely in ``margin_loss``.
    """

    def __init__(
        self,
        alpha: float = 0.2,
        beta: float = 1.0,
        skip_labels: Union[int, List[int]] = -1,
    ):
        """
        Constructor method for the MarginLoss class.

        Args:
            alpha: value forwarded to ``margin_loss`` as its ``alpha``
                keyword. NOTE(review): presumably the margin width --
                confirm against ``margin_loss``.
            beta: value forwarded to ``margin_loss`` as its ``beta``
                keyword. NOTE(review): presumably the distance boundary
                between positive and negative pairs -- confirm against
                ``margin_loss``.
            skip_labels: a single label or a list of labels forwarded to
                ``margin_loss`` as its ``skip_labels`` keyword.
                NOTE(review): presumably labels excluded from the loss --
                confirm against ``margin_loss``.
        """
        super().__init__()
        # Kept as plain attributes (not buffers/parameters): they are only
        # read back in ``forward`` and passed through as keyword arguments.
        self.alpha = alpha
        self.beta = beta
        self.skip_labels = skip_labels

    def forward(self, embeddings, targets):
        """
        Compute the loss by delegating to ``margin_loss``.

        Args:
            embeddings: passed through as the first positional argument of
                ``margin_loss``.
            targets: passed through as the second positional argument of
                ``margin_loss``.

        Returns:
            Whatever ``margin_loss`` returns for the given inputs and the
            hyper-parameters stored on this instance.
        """
        return margin_loss(
            embeddings,
            targets,
            alpha=self.alpha,
            beta=self.beta,
            skip_labels=self.skip_labels,
        )