Source code for catalyst.contrib.losses.contrastive
import torch
from torch import nn
from torch.nn import functional as F
class ContrastiveEmbeddingLoss(nn.Module):
    """The Contrastive embedding loss.

    It has been proposed in `Dimensionality Reduction
    by Learning an Invariant Mapping`_.

    For each pair ``i`` with Euclidean distance ``d_i`` between the left and
    right embeddings and label ``y_i`` (taken from ``distance_true``), the
    per-pair loss is::

        (1 - y_i) * d_i ** 2 + y_i * max(0, margin - d_i) ** 2

    so pairs labelled ``0`` are pulled together and pairs labelled ``1`` are
    pushed at least ``margin`` apart.

    .. _Dimensionality Reduction by Learning an Invariant Mapping:
        http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf
    """

    def __init__(self, margin=1.0, reduction="mean"):
        """
        Args:
            margin: margin parameter; pairs labelled ``1`` that are closer
                than this are penalised
            reduction: criterion reduction type. ``"mean"`` returns
                ``sum / 2 / batch_size``; ``"sum"`` returns the plain
                (not halved) sum; any other value, including ``None`` and
                ``"none"``, returns the per-pair losses unreduced
        """
        super().__init__()
        self.margin = margin
        self.reduction = reduction or "none"

    @staticmethod
    def _euclidean_distance(
        embeddings_left: torch.Tensor, embeddings_right: torch.Tensor
    ) -> torch.Tensor:
        """Row-wise L2 distance whose gradient stays finite at distance 0.

        A plain ``sqrt`` has an infinite derivative at 0, and autograd turns
        ``inf * 0`` into NaN, so identical embeddings would poison every
        gradient. The "double where" trick below feeds ``sqrt`` a harmless
        value (1) at those rows and selects an exact 0 instead. Forward
        values are identical to ``sqrt(sum(diff ** 2, 1))``, and NaN inputs
        still propagate as NaN (``NaN == 0`` is False).
        """
        diff = embeddings_left - embeddings_right
        squared_distance = torch.sum(torch.pow(diff, 2), 1)
        is_zero = squared_distance == 0
        safe_squared = torch.where(
            is_zero, torch.ones_like(squared_distance), squared_distance
        )
        return torch.where(
            is_zero, torch.zeros_like(squared_distance), torch.sqrt(safe_squared)
        )

    def forward(
        self,
        embeddings_left: torch.Tensor,
        embeddings_right: torch.Tensor,
        distance_true,
    ) -> torch.Tensor:
        """Forward propagation method for the contrastive loss.

        Args:
            embeddings_left: left objects embeddings; the distance is summed
                over dim 1, so a 2D ``(batch_size, features_dim)`` tensor is
                expected
            embeddings_right: right objects embeddings, same shape as
                ``embeddings_left``
            distance_true: true distances / pair labels (``0`` = similar,
                ``1`` = dissimilar). Expected to be a tensor of shape
                ``(batch_size,)``; a ``(batch_size, 1)`` tensor would
                broadcast against the ``(batch_size,)`` predicted distances
                into a ``(batch_size, batch_size)`` matrix.

        Returns:
            torch.Tensor: loss (scalar for ``"mean"``/``"sum"``, per-pair
            otherwise)
        """
        # euclidean distance, with a NaN-free gradient for identical pairs
        distance_pred = self._euclidean_distance(embeddings_left, embeddings_right)
        bs = len(distance_true)
        # hinge: only dissimilar pairs closer than the margin contribute
        margin_distance = torch.clamp(self.margin - distance_pred, min=0.0)
        loss = (1 - distance_true) * torch.pow(
            distance_pred, 2
        ) + distance_true * torch.pow(margin_distance, 2)
        if self.reduction == "mean":
            # the 1/2 factor follows the paper's formulation
            loss = torch.sum(loss) / 2.0 / bs
        elif self.reduction == "sum":
            loss = torch.sum(loss)
        return loss
class ContrastiveDistanceLoss(nn.Module):
    """The Contrastive distance loss.

    Same criterion as the contrastive embedding loss, but operating on
    already-computed distances. For a predicted distance ``d`` and label
    ``y`` the per-element loss is::

        (1 - y) * d ** 2 + y * max(0, margin - d) ** 2
    """

    def __init__(self, margin=1.0, reduction="mean"):
        """
        Args:
            margin: margin parameter; elements labelled ``1`` whose predicted
                distance is below it are penalised
            reduction: criterion reduction type. ``"mean"`` gives
                ``sum / 2 / batch_size``, ``"sum"`` gives the plain sum, and
                any other value (``None`` is mapped to ``"none"``) returns
                the element-wise losses
        """
        super().__init__()
        self.margin = margin
        self.reduction = "none" if not reduction else reduction

    def forward(self, distance_pred, distance_true) -> torch.Tensor:
        """Forward propagation method for the contrastive loss.

        Args:
            distance_pred: predicted distances
            distance_true: true distances / labels (``0`` pulls together,
                ``1`` pushes beyond the margin)

        Returns:
            torch.Tensor: loss
        """
        batch_size = len(distance_true)
        hinge = (self.margin - distance_pred).clamp(min=0.0)
        similar_term = (1 - distance_true) * distance_pred.pow(2)
        dissimilar_term = distance_true * hinge.pow(2)
        per_element = similar_term + dissimilar_term

        if self.reduction == "sum":
            return per_element.sum()
        if self.reduction == "mean":
            return per_element.sum() / 2.0 / batch_size
        return per_element
class ContrastivePairwiseEmbeddingLoss(nn.Module):
    """ContrastivePairwiseEmbeddingLoss – proof of concept criterion.

    Still work in progress. Builds the matrix of dot products between every
    predicted and every true embedding and applies cross-entropy, treating
    row ``i``'s matching column ``i`` as the correct class.
    """

    def __init__(self, margin=1.0, reduction="mean"):
        """
        Args:
            margin: margin parameter (stored, not used by ``forward``)
            reduction: criterion reduction type, forwarded to
                ``F.cross_entropy`` (``None`` is mapped to ``"none"``)
        """
        super().__init__()
        self.margin = margin
        self.reduction = "none" if not reduction else reduction

    def forward(self, embeddings_pred, embeddings_true) -> torch.Tensor:
        """Forward propagation method for the contrastive loss.

        Work in progress.

        Args:
            embeddings_pred: predicted embeddings, one row per sample
            embeddings_true: true embeddings, one row per sample, same
                feature size as ``embeddings_pred``

        Returns:
            torch.Tensor: loss
        """
        # logits[i, j] = <embeddings_pred[i], embeddings_true[j]>
        logits = torch.einsum("ik,jk->ij", embeddings_pred, embeddings_true)
        n_rows = embeddings_pred.shape[0]
        # the positive for row i is column i
        targets = torch.arange(n_rows, device=embeddings_pred.device)
        return F.cross_entropy(logits, targets, reduction=self.reduction)
class BarlowTwinsLoss(nn.Module):
    """The Barlow Twins loss.

    It has been proposed in `Barlow Twins:
    Self-Supervised Learning via Redundancy Reduction`_.

    The loss drives the diagonal of the cross-correlation matrix between the
    (batch-standardised) left and right embeddings towards 1 and, weighted by
    ``offdiag_lambda``, its off-diagonal entries towards 0.

    Example:

    .. code-block:: python

        import torch
        from torch.nn import functional as F
        from catalyst.contrib import BarlowTwinsLoss

        embeddings_left = F.normalize(torch.rand(256, 64, requires_grad=True))
        embeddings_right = F.normalize(torch.rand(256, 64, requires_grad=True))
        criterion = BarlowTwinsLoss(offdiag_lambda = 1)
        criterion(embeddings_left, embeddings_right)

    .. _`Barlow Twins: Self-Supervised Learning via Redundancy Reduction`:
        https://arxiv.org/abs/2103.03230
    """

    def __init__(self, offdiag_lambda=1.0, eps=1e-12):
        """
        Args:
            offdiag_lambda: trade-off parameter weighting the off-diagonal
                (redundancy-reduction) term
            eps: shift for the variance (var + eps) to avoid division by zero
        """
        super().__init__()
        self.offdiag_lambda = offdiag_lambda
        self.eps = eps

    def _standardize(self, embeddings: torch.Tensor) -> torch.Tensor:
        """Centre each feature over the batch and scale it to unit variance.

        The population (biased) variance is used so that, for every feature,
        ``sum_b z[b, i] ** 2 == batch_size`` (up to ``eps``). Dividing the
        cross-correlation by ``batch_size`` then yields exactly 1 on the
        diagonal for perfectly correlated features, as in the paper. With the
        unbiased estimator the diagonal would top out at
        ``(batch_size - 1) / batch_size``, leaving a constant residual loss.
        """
        centered = embeddings - embeddings.mean(dim=0)
        return centered / (embeddings.var(dim=0, unbiased=False) + self.eps).sqrt()

    def forward(
        self, embeddings_left: torch.Tensor, embeddings_right: torch.Tensor
    ) -> torch.Tensor:
        """Forward propagation method for the Barlow Twins loss.

        Args:
            embeddings_left: left objects embeddings [batch_size, features_dim]
            embeddings_right: right objects embeddings [batch_size, features_dim]

        Raises:
            ValueError: if the batch size is 1
            ValueError: if embeddings_left and embeddings_right shapes are different
            ValueError: if embeddings shapes are not in a form (batch_size, features_dim)

        Returns:
            torch.Tensor: loss
        """
        shape_left, shape_right = embeddings_left.shape, embeddings_right.shape
        if len(shape_left) != 2:
            raise ValueError(
                "Left shape should be (batch_size, feature_dim),"
                f"but got - {shape_left}!"
            )
        elif len(shape_right) != 2:
            raise ValueError(
                "Right shape should be (batch_size, feature_dim),"
                f"but got - {shape_right}!"
            )
        if shape_left[0] == 1:
            raise ValueError(f"Batch size should be >= 2, but got - {shape_left[0]}!")
        if shape_left != shape_right:
            raise ValueError(
                f"Shapes should be equall, but got - {shape_left} and {shape_right}!"
            )

        # normalization along the batch dimension
        z_left = self._standardize(embeddings_left)
        z_right = self._standardize(embeddings_right)

        # cross-correlation matrix [features_dim, features_dim]
        batch_size = z_left.shape[0]
        cross_correlation = torch.matmul(z_left.T, z_right) / batch_size

        # selection of diagonal and off-diagonal elements; out-of-place ops
        # so that cross_correlation (and the diagonal view) are never mutated
        on_diag = torch.diagonal(cross_correlation)
        off_diag = cross_correlation.clone().fill_diagonal_(0)

        # the loss described in the original Barlow Twins paper:
        # encourage on_diag to be one and off_diag to be zero
        on_diag_loss = (on_diag - 1).pow(2).sum()
        off_diag_loss = off_diag.pow(2).sum()
        loss = on_diag_loss + self.offdiag_lambda * off_diag_loss
        return loss
# Public API of this module (names exported by ``from ... import *``).
__all__ = [
    "ContrastiveEmbeddingLoss", "ContrastiveDistanceLoss",
    "ContrastivePairwiseEmbeddingLoss", "BarlowTwinsLoss",
]