# Source code for catalyst.contrib.nn.criterion.margin
from typing import List, Union
import torch
from torch import nn
from .functional import margin_loss
class MarginLoss(nn.Module):
    """``nn.Module`` wrapper around the functional ``margin_loss`` criterion.

    The loss hyper-parameters are stored at construction time and forwarded
    unchanged to ``margin_loss`` on every call. This lets the criterion be
    used like any other module-style loss (e.g. ``loss = criterion(x, y)``).

    NOTE(review): the exact meaning of ``alpha``, ``beta`` and
    ``skip_labels`` is defined by ``.functional.margin_loss``, which is not
    visible here. They presumably are the margin, the boundary and the
    ignored labels of the margin-based loss from Wu et al., "Sampling
    Matters in Deep Embedding Learning" (2017). Confirm against the
    functional implementation.
    """

    def __init__(
        self,
        alpha: float = 0.2,
        beta: float = 1.0,
        skip_labels: Union[int, List[int]] = -1,
    ):
        """Store the hyper-parameters passed to ``margin_loss``.

        Args:
            alpha: value forwarded as ``alpha`` to ``margin_loss``
                (default ``0.2``).
            beta: value forwarded as ``beta`` to ``margin_loss``
                (default ``1.0``).
            skip_labels: a single label or a list of labels, forwarded as
                ``skip_labels`` to ``margin_loss`` (default ``-1``).
                Judging by the name, targets equal to these labels are
                presumably excluded from the loss; confirm in
                ``.functional``. A list is stored by reference, not copied.
        """
        super().__init__()
        self.alpha = alpha
        self.beta = beta
        self.skip_labels = skip_labels

    def extra_repr(self) -> str:
        """Return the hyper-parameters shown inside ``repr(module)``.

        ``nn.Module.__repr__`` calls this hook, so printing a model that
        contains this criterion reveals how it was configured.
        """
        return (
            f"alpha={self.alpha}, beta={self.beta}, "
            f"skip_labels={self.skip_labels}"
        )

    def forward(
        self, embeddings: torch.Tensor, targets: torch.Tensor
    ) -> torch.Tensor:
        """Compute the margin loss for a batch.

        All computation is delegated to ``margin_loss``, together with the
        hyper-parameters stored in ``__init__``.

        Args:
            embeddings: tensor passed as the first argument to
                ``margin_loss``. Presumably shaped
                ``(batch_size, embedding_dim)``; TODO confirm.
            targets: tensor passed as the second argument to
                ``margin_loss``. Presumably the per-sample labels, shaped
                ``(batch_size,)``; TODO confirm.

        Returns:
            The tensor returned by ``margin_loss``, presumably a scalar loss.
        """
        return margin_loss(
            embeddings,
            targets,
            alpha=self.alpha,
            beta=self.beta,
            skip_labels=self.skip_labels,
        )