# Source code for catalyst.contrib.nn.criterion.margin
from typing import List, Union
import torch
from torch import nn
from .functional import margin_loss
class MarginLoss(nn.Module):
    """Criterion module wrapping the functional ``margin_loss``.

    The loss hyperparameters are stored at construction time. On every
    call they are forwarded, together with the batch embeddings and
    targets, to ``margin_loss`` (imported from ``.functional``).
    """

    def __init__(
        self,
        alpha: float = 0.2,
        beta: float = 1.0,
        skip_labels: Union[int, List[int]] = -1,
    ):
        """
        Args:
            alpha (float): passed to ``margin_loss`` as ``alpha``
                (presumably the margin; TODO confirm against
                ``margin_loss``).
            beta (float): passed to ``margin_loss`` as ``beta``
                (presumably the distance boundary; TODO confirm).
            skip_labels (int or List[int]): passed to ``margin_loss`` as
                ``skip_labels``; presumably the target label(s) excluded
                from the loss. Verify in ``margin_loss``.
        """
        super().__init__()
        self.alpha = alpha
        self.beta = beta
        self.skip_labels = skip_labels

    def extra_repr(self) -> str:
        """Show the hyperparameters in ``repr(module)`` and model printouts."""
        # nn.Module.__repr__ inserts this string between the parentheses,
        # e.g. "MarginLoss(alpha=0.2, beta=1.0, skip_labels=-1)".
        return (
            f"alpha={self.alpha}, beta={self.beta}, "
            f"skip_labels={self.skip_labels}"
        )

    def forward(
        self, embeddings: torch.Tensor, targets: torch.Tensor
    ) -> torch.Tensor:
        """Forward propagation method for the margin loss.

        Args:
            embeddings (torch.Tensor): batch of embeddings; presumably
                shaped ``(batch_size, embedding_dim)``. TODO confirm
                against ``margin_loss``.
            targets (torch.Tensor): labels for ``embeddings``; presumably
                one per row. TODO confirm.

        Returns:
            torch.Tensor: the value ``margin_loss`` returns for these
            inputs and the stored ``alpha``, ``beta`` and ``skip_labels``.
        """
        return margin_loss(
            embeddings,
            targets,
            alpha=self.alpha,
            beta=self.beta,
            skip_labels=self.skip_labels,
        )