Shortcuts

Source code for catalyst.contrib.losses.ntxent

from math import e

import torch
from torch import nn


class NTXentLoss(nn.Module):
    """A Contrastive embedding loss.

    It has been proposed in `A Simple Framework
    for Contrastive Learning of Visual Representations`_.

    Example:

    .. code-block:: python

        import torch
        from torch.nn import functional as F
        from catalyst.contrib import NTXentLoss

        embeddings_left = F.normalize(torch.rand(256, 64, requires_grad=True))
        embeddings_right = F.normalize(torch.rand(256, 64, requires_grad=True))
        criterion = NTXentLoss(tau = 0.1)
        criterion(embeddings_left, embeddings_right)

    .. _`A Simple Framework for Contrastive Learning of Visual Representations`:
        https://arxiv.org/abs/2002.05709
    """

    def __init__(self, tau: float, reduction: str = "mean") -> None:
        """
        Args:
            tau: temperature; must be strictly positive
            reduction (string, optional): specifies the reduction to apply
                to the output: ``"none"`` | ``"mean"`` | ``"sum"``.
                ``"none"``: no reduction will be applied,
                ``"mean"``: the sum of the output will be divided by the
                number of positive pairs in the output (``2 * bs``, since
                every pair is counted in both directions),
                ``"sum"``: the output will be summed.

        Raises:
            ValueError: if tau is not positive, or if reduction is not
                mean, sum or none
        """
        super().__init__()
        if tau <= 0:
            # tau divides every logit; zero/negative values give inf/NaN
            # or a loss that rewards pushing positives apart.
            raise ValueError(f"Temperature tau should be positive. But got - {tau}!")
        self.tau = tau
        # Retained so code that accesses ``criterion.cosine_sim`` keeps
        # working; ``forward`` reads positives from the similarity matrix.
        self.cosine_sim = nn.CosineSimilarity()
        self.reduction = reduction
        if self.reduction not in ["none", "mean", "sum"]:
            raise ValueError(
                f"Reduction should be: mean, sum, none. But got - {self.reduction}!"
            )

    def forward(self, features1: torch.Tensor, features2: torch.Tensor) -> torch.Tensor:
        """
        Args:
            features1: batch with samples features of shape [bs; feature_len]
            features2: batch with samples features of shape [bs; feature_len]

        Returns:
            torch.Tensor: NTXent loss; a scalar for ``"mean"``/``"sum"``
            reduction, a tensor of shape [2 * bs] (one value per anchor,
            ``features1`` rows first) for ``"none"``

        Raises:
            AssertionError: if the two feature batches differ in shape
        """
        assert (
            features1.shape == features2.shape
        ), f"Invalid shape of input features: {features1.shape} and {features2.shape}"

        batch_size = features1.shape[0]
        feature_matrix = torch.cat([features1, features2])
        feature_matrix = torch.nn.functional.normalize(feature_matrix)
        num_anchors = feature_matrix.shape[0]  # 2 * bs

        # Rows are unit vectors, so the Gram matrix holds exact pairwise
        # cosine similarities (no cdist round-off on the diagonal).
        logits = feature_matrix @ feature_matrix.t() / self.tau

        # Drop self-similarity from the denominator by masking the diagonal
        # with -inf (exp(-inf) == 0) instead of subtracting exp(1 / tau):
        # that constant overflows float32 for tau < ~0.0113 and cancels
        # catastrophically against the row sum it dominates.
        self_mask = torch.eye(num_anchors, dtype=torch.bool, device=logits.device)
        logits = logits.masked_fill(self_mask, float("-inf"))
        # Numerically stable log(sum_{j != i} exp(sim_ij / tau)).
        neg_loss = torch.logsumexp(logits, dim=1)

        # The positive for anchor i is its other view: i + bs in the first
        # half, i - bs in the second, i.e. pairs (i, j) and (j, i).
        rows = torch.arange(num_anchors, device=logits.device)
        pos_cols = torch.cat(
            [
                torch.arange(batch_size, num_anchors, device=logits.device),
                torch.arange(0, batch_size, device=logits.device),
            ]
        )
        pos_loss = logits[rows, pos_cols]

        loss = neg_loss - pos_loss
        if self.reduction == "mean":
            loss = loss.mean()
        elif self.reduction == "sum":
            loss = loss.sum()
        return loss