# Source code for catalyst.contrib.nn.criterion.contrastive
import torch
from torch import nn
from torch.nn import functional as F
class ContrastiveEmbeddingLoss(nn.Module):
    """The Contrastive embedding loss.

    It has been proposed in `Dimensionality Reduction
    by Learning an Invariant Mapping`_.

    .. _Dimensionality Reduction by Learning an Invariant Mapping:
        http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf
    """

    def __init__(self, margin=1.0, reduction="mean"):
        """
        Args:
            margin: margin parameter; pairs weighted by ``distance_true``
                are penalized while their distance is below this value
            reduction: criterion reduction type. ``"mean"`` returns
                ``sum(loss) / 2 / len(distance_true)``, ``"sum"`` returns
                ``sum(loss)``; any other value leaves the per-pair loss
                unreduced. A falsy value (e.g. ``None``) is stored
                as ``"none"``.
        """
        super().__init__()
        self.margin = margin
        self.reduction = reduction or "none"

    def forward(
        self,
        embeddings_left: torch.Tensor,
        embeddings_right: torch.Tensor,
        distance_true,
    ) -> torch.Tensor:
        """Forward propagation method for the contrastive loss.

        The per-pair loss is
        ``(1 - distance_true) * d ** 2
        + distance_true * clamp(margin - d, min=0) ** 2``,
        where ``d`` is the euclidean distance between the two embeddings.

        Args:
            embeddings_left (torch.Tensor): left objects embeddings;
                squared differences are summed over dimension 1
            embeddings_right (torch.Tensor): right objects embeddings
            distance_true: true distances, used as the per-pair weight
                between the two terms of the loss
                (presumably 0/1 labels, one per pair - TODO confirm)

        Returns:
            torch.Tensor: loss
        """
        # euclidean distance
        diff = embeddings_left - embeddings_right
        squared_distance = torch.sum(torch.pow(diff, 2), 1)

        # The derivative of sqrt is infinite at 0, so a pair of identical
        # embeddings would turn the whole gradient of that row into NaN
        # (0 * inf). Take the root of a harmless 1 for those pairs and mask
        # the result back to 0: the forward value is unchanged and the
        # gradient is 0 there. NaN inputs still propagate (NaN == 0 is False).
        is_zero = squared_distance == 0
        safe_squared = torch.where(
            is_zero, torch.ones_like(squared_distance), squared_distance
        )
        distance_pred = torch.where(
            is_zero,
            torch.zeros_like(squared_distance),
            torch.sqrt(safe_squared),
        )

        bs = len(distance_true)
        margin_distance = torch.clamp(self.margin - distance_pred, min=0.0)
        loss = (1 - distance_true) * torch.pow(
            distance_pred, 2
        ) + distance_true * torch.pow(margin_distance, 2)

        if self.reduction == "mean":
            # the 1/2 factor is applied for "mean" only
            loss = torch.sum(loss) / 2.0 / bs
        elif self.reduction == "sum":
            loss = torch.sum(loss)
        return loss
class ContrastiveDistanceLoss(nn.Module):
    """The Contrastive distance loss.

    Same criterion as the contrastive embedding loss, but it takes
    already computed distances instead of pairs of embeddings.
    """

    def __init__(self, margin=1.0, reduction="mean"):
        """
        Args:
            margin: margin parameter
            reduction (str): criterion reduction type; a falsy value
                is stored as ``"none"``
        """
        super().__init__()
        self.margin = margin
        self.reduction = reduction or "none"

    def forward(self, distance_pred, distance_true) -> torch.Tensor:
        """Forward propagation method for the contrastive loss.

        Args:
            distance_pred: predicted distances
            distance_true: true distances, used as the per-pair weight
                between the two terms of the loss

        Returns:
            torch.Tensor: loss; a scalar for the ``"mean"`` and ``"sum"``
            reductions, the per-pair values otherwise
        """
        batch_size = len(distance_true)

        # how far each predicted distance falls short of the margin
        shortfall = torch.clamp(self.margin - distance_pred, min=0.0)

        attract = (1 - distance_true) * torch.pow(distance_pred, 2)
        repel = distance_true * torch.pow(shortfall, 2)
        per_pair = attract + repel

        if self.reduction == "mean":
            return torch.sum(per_pair) / 2.0 / batch_size
        if self.reduction == "sum":
            return torch.sum(per_pair)
        return per_pair
class ContrastivePairwiseEmbeddingLoss(nn.Module):
    """ContrastivePairwiseEmbeddingLoss – proof of concept criterion.

    Still work in progress.

    Scores every predicted embedding against every true embedding with a
    dot product and applies cross-entropy with the matching row index
    as the target class.
    """

    def __init__(self, margin=1.0, reduction="mean"):
        """
        Args:
            margin: margin parameter (stored, not used by ``forward``)
            reduction: criterion reduction type, forwarded to
                ``F.cross_entropy``; a falsy value is stored as ``"none"``
        """
        super().__init__()
        self.margin = margin
        self.reduction = reduction or "none"

    def forward(self, embeddings_pred, embeddings_true) -> torch.Tensor:
        """Forward propagation method for the contrastive loss.

        Work in progress.

        Args:
            embeddings_pred: predicted embeddings
            embeddings_true: true embeddings

        Returns:
            torch.Tensor: loss
        """
        target_device = embeddings_pred.device

        # einsum axes: s - state space, e - embeddings space, a - action space
        similarity_logits = torch.einsum(
            "se,ae->sa", embeddings_pred, embeddings_true
        )

        # row i is expected to score highest against column i
        n_rows = embeddings_pred.shape[0]
        targets = torch.arange(n_rows, device=target_device)

        return F.cross_entropy(
            similarity_logits, targets, reduction=self.reduction
        )
# Names exported by ``from <module> import *``: the three criteria above.
__all__ = [
"ContrastiveEmbeddingLoss",
"ContrastiveDistanceLoss",
"ContrastivePairwiseEmbeddingLoss",
]