Source code for catalyst.contrib.criterion.contrastive
import torch
import torch.nn as nn
import torch.nn.functional as F
class ContrastiveEmbeddingLoss(nn.Module):
    """
    Contrastive embedding loss.

    Computes the Euclidean distance between two batches of embeddings and
    penalises it according to the target ``distance_true``.

    paper: http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf
    """

    def __init__(self, margin=1.0, reduction="mean"):
        """
        Constructor method for the ContrastiveEmbeddingLoss class.

        Args:
            margin: margin parameter; predicted distances above it contribute
                nothing to the ``distance_true``-weighted term.
            reduction: criterion reduction type. ``"mean"`` and ``"sum"``
                reduce the per-pair losses; a falsy value (e.g. ``None``)
                is stored as ``"none"``. Any other value leaves the
                per-pair losses unreduced.
        """
        super().__init__()
        self.margin = margin
        self.reduction = reduction or "none"

    def forward(self, embeddings_left, embeddings_right, distance_true):
        """
        Forward propagation method for the contrastive loss.

        Args:
            embeddings_left: left objects embeddings
            embeddings_right: right objects embeddings; the squared
                difference with ``embeddings_left`` is summed over
                dimension 1
            distance_true: true distances; weights the margin term, while
                ``1 - distance_true`` weights the squared-distance term

        Returns:
            loss: ``sum / 2 / len(distance_true)`` for ``"mean"``, the plain
            sum for ``"sum"``, otherwise the unreduced per-pair losses.
        """
        # Euclidean distance
        diff = embeddings_left - embeddings_right
        squared_distance = torch.sum(torch.pow(diff, 2), 1)

        # The derivative of sqrt at 0 is infinite, so a pair of identical
        # embeddings would back-propagate 0 * inf = NaN. Evaluate sqrt on a
        # dummy value (1) at those positions and select an exact 0 instead:
        # the forward result is unchanged and the gradient there is 0.
        is_zero = squared_distance == 0
        safe_squared = torch.where(
            is_zero, torch.ones_like(squared_distance), squared_distance
        )
        distance_pred = torch.where(
            is_zero,
            torch.zeros_like(squared_distance),
            torch.sqrt(safe_squared),
        )

        bs = len(distance_true)
        # hinge: only distances smaller than the margin are penalised
        margin_distance = self.margin - distance_pred
        margin_distance_ = torch.clamp(margin_distance, min=0.0)
        loss = (
            (1 - distance_true) * torch.pow(distance_pred, 2)
            + distance_true * torch.pow(margin_distance_, 2)
        )

        if self.reduction == "mean":
            loss = torch.sum(loss) / 2.0 / bs
        elif self.reduction == "sum":
            loss = torch.sum(loss)
        return loss
class ContrastiveDistanceLoss(nn.Module):
    """
    Contrastive loss computed directly from predicted distances.
    """

    def __init__(self, margin=1.0, reduction="mean"):
        """
        Build the criterion.

        Args:
            margin: margin parameter.
            reduction: criterion reduction type; a falsy value is stored
                as ``"none"``.
        """
        super().__init__()
        self.reduction = reduction or "none"
        self.margin = margin

    def forward(self, distance_pred, distance_true):
        """
        Evaluate the contrastive loss for a batch of distances.

        Args:
            distance_pred: predicted distances
            distance_true: true distances

        Returns:
            loss: ``sum / 2 / len(distance_true)`` when the reduction is
            ``"mean"``, the plain sum when it is ``"sum"``, and the
            unreduced per-pair values otherwise.
        """
        batch_size = len(distance_true)

        # only the part of the margin not yet covered by the prediction counts
        hinge = (self.margin - distance_pred).clamp(min=0.0)

        near_term = (1 - distance_true) * torch.pow(distance_pred, 2)
        far_term = distance_true * torch.pow(hinge, 2)
        per_pair = near_term + far_term

        if self.reduction == "mean":
            return per_pair.sum() / 2.0 / batch_size
        if self.reduction == "sum":
            return per_pair.sum()
        return per_pair
class ContrastivePairwiseEmbeddingLoss(nn.Module):
    """
    ContrastivePairwiseEmbeddingLoss – proof of concept criterion.

    Still work in progress.
    """

    def __init__(self, margin=1.0, reduction="mean"):
        """
        Build the criterion.

        Args:
            margin: margin parameter (stored; not read by ``forward``).
            reduction: criterion reduction type, forwarded to
                ``F.cross_entropy``; a falsy value is stored as ``"none"``.
        """
        super().__init__()
        self.reduction = reduction or "none"
        self.margin = margin

    def forward(self, embeddings_pred, embeddings_true):
        """
        Work in progress.

        Scores every predicted embedding against every true embedding with
        a dot product and applies cross-entropy with row ``i`` labelled as
        class ``i``.

        Args:
            embeddings_pred: predicted embeddings
            embeddings_true: true embeddings

        Returns:
            loss
        """
        target_device = embeddings_pred.device

        # einsum axes:
        #   s - rows of ``embeddings_pred`` (state space)
        #   a - rows of ``embeddings_true`` (action space)
        #   e - shared embeddings space, contracted away
        logits = torch.einsum("se,ae->sa", embeddings_pred, embeddings_true)

        # the matching pair for row i sits on the diagonal
        n_rows = embeddings_pred.size(0)
        labels = torch.arange(n_rows, device=target_device)

        return F.cross_entropy(logits, labels, reduction=self.reduction)
# Public API of this module: the names exported by a star-import.
__all__ = [
    "ContrastiveEmbeddingLoss",
    "ContrastiveDistanceLoss",
    "ContrastivePairwiseEmbeddingLoss"
]