Source code for catalyst.contrib.nn.criterion.margin

from typing import List, Union

import torch.nn as nn

from .functional import margin_loss


class MarginLoss(nn.Module):
    """Criterion module wrapping the functional ``margin_loss``.

    The loss hyper-parameters are stored at construction time and forwarded,
    together with the embeddings and targets, to ``margin_loss`` on every
    ``forward`` call.
    """

    def __init__(
        self,
        alpha: float = 0.2,
        beta: float = 1.0,
        skip_labels: Union[int, List[int]] = -1,
    ):
        """Store the hyper-parameters later passed to ``margin_loss``.

        Args:
            alpha: value passed as ``alpha`` to ``margin_loss``
                (default ``0.2``). NOTE(review): presumably the margin
                width -- confirm against ``margin_loss``.
            beta: value passed as ``beta`` to ``margin_loss``
                (default ``1.0``). NOTE(review): presumably the
                positive/negative distance boundary -- confirm.
            skip_labels: a single label or a list of labels passed as
                ``skip_labels`` to ``margin_loss`` (default ``-1``).
                NOTE(review): presumably samples whose target is one of
                these labels are excluded from the loss -- confirm.
        """
        super().__init__()
        self.alpha = alpha
        self.beta = beta
        self.skip_labels = skip_labels

    def forward(self, embeddings, targets):
        """Compute the margin loss for a batch.

        Args:
            embeddings: embedding tensor. NOTE(review): presumably shaped
                ``(batch, dim)`` -- confirm against ``margin_loss``.
            targets: labels matching ``embeddings``. NOTE(review):
                presumably shaped ``(batch,)`` -- confirm.

        Returns:
            The result of ``margin_loss`` for these inputs and the stored
            ``alpha``, ``beta`` and ``skip_labels``.
        """
        return margin_loss(
            embeddings,
            targets,
            alpha=self.alpha,
            beta=self.beta,
            skip_labels=self.skip_labels,
        )