# Source code for catalyst.contrib.criterion.contrastive
import torch
import torch.nn as nn
class ContrastiveEmbeddingLoss(nn.Module):
    """
    Contrastive embedding loss
    paper: http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf

    For each pair ``(x0[i], x1[i])`` with Euclidean distance ``d`` (taken
    over dimension 1) the per-pair loss is::

        (1 - y) * d ** 2 + y * max(0, margin - d) ** 2

    so pairs labelled ``y == 0`` are pulled together, and pairs labelled
    ``y == 1`` are pushed apart until they are at least ``margin`` away.

    Args:
        margin: distance beyond which ``y == 1`` pairs contribute no loss.
        reduction: ``"elementwise_mean"`` (or its alias ``"mean"``) returns
            ``sum(loss) / 2 / len(y)``; ``"sum"`` returns ``sum(loss)``
            (without the 1/2 factor, kept for backward compatibility);
            ``"none"`` or ``None`` returns the per-pair losses.

    Raises:
        ValueError: if ``reduction`` is not one of the values above.
    """

    _REDUCTIONS = ("elementwise_mean", "mean", "sum", "none")

    def __init__(self, margin=1.0, reduction="elementwise_mean"):
        super().__init__()
        self.margin = margin
        self.reduction = reduction or "none"
        # fail fast instead of silently returning unreduced losses
        if self.reduction not in self._REDUCTIONS:
            raise ValueError(
                "unknown reduction {!r}; expected one of {}".format(
                    reduction, self._REDUCTIONS
                )
            )

    def forward(self, x0, x1, y):
        """
        Compute the contrastive loss between two batches of embeddings.

        Args:
            x0: first embedding of each pair (presumably ``(batch, dim)`` --
                the distance is summed over dimension 1).
            x1: second embedding of each pair, broadcastable against ``x0``.
            y: pair labels; 0 penalises the distance, 1 penalises being
                closer than ``margin``. ``len(y)`` is used as batch size.

        Returns:
            Reduced loss tensor, or per-pair losses for ``"none"``.
        """
        diff = x0 - x1
        dist_sq = torch.sum(torch.pow(diff, 2), 1)

        # d/ds sqrt(s) is infinite at s == 0, so sqrt(dist_sq) turns the
        # backward pass into NaN whenever x0[i] == x1[i]. Feed sqrt a
        # harmless 1 on those rows and select 0 afterwards: the forward value
        # stays exact and the gradient there is 0 instead of NaN.
        nonzero = dist_sq > 0
        safe_sq = torch.where(nonzero, dist_sq, torch.ones_like(dist_sq))
        dist = torch.where(
            nonzero, torch.sqrt(safe_sq), torch.zeros_like(dist_sq)
        )

        margin_dist = torch.clamp(self.margin - dist, min=0.0)
        # similar-pair term uses dist_sq directly (no sqrt/square round trip)
        loss = (1 - y) * dist_sq + y * torch.pow(margin_dist, 2)

        if self.reduction in ("elementwise_mean", "mean"):
            loss = torch.sum(loss) / 2.0 / len(y)
        elif self.reduction == "sum":
            loss = torch.sum(loss)
        return loss
class ContrastiveDistanceLoss(nn.Module):
    """
    Contrastive distance loss

    Contrastive loss computed from precomputed pair distances. For a pair
    distance ``d`` the per-pair loss is::

        (1 - y) * d ** 2 + y * max(0, margin - d) ** 2

    Args:
        margin: distance beyond which ``y == 1`` pairs contribute no loss.
        reduction: ``"elementwise_mean"`` (default, or its alias ``"mean"``)
            returns ``sum(loss) / 2 / len(y)``; ``"sum"`` returns
            ``sum(loss)``; ``"none"`` or ``None`` returns per-pair losses.

    Raises:
        ValueError: if ``reduction`` is not one of the values above.
    """

    _REDUCTIONS = ("elementwise_mean", "mean", "sum", "none")

    def __init__(self, margin=1.0, reduction="elementwise_mean"):
        super().__init__()
        self.margin = margin
        self.reduction = reduction or "none"
        # fail fast instead of silently returning unreduced losses
        if self.reduction not in self._REDUCTIONS:
            raise ValueError(
                "unknown reduction {!r}; expected one of {}".format(
                    reduction, self._REDUCTIONS
                )
            )

    def forward(self, dist, y):
        """
        Compute the contrastive loss from pair distances.

        Args:
            dist: per-pair distances.
            y: pair labels; 0 penalises the distance, 1 penalises being
                closer than ``margin``. ``len(y)`` is used as batch size.

        Returns:
            Reduced loss tensor, or per-pair losses for ``"none"``.
        """
        margin_dist = torch.clamp(self.margin - dist, min=0.0)
        loss = (1 - y) * torch.pow(dist, 2) + y * torch.pow(margin_dist, 2)

        if self.reduction in ("elementwise_mean", "mean"):
            loss = torch.sum(loss) / 2.0 / len(y)
        elif self.reduction == "sum":
            loss = torch.sum(loss)
        return loss