# flake8: noqa
from typing import List, TYPE_CHECKING, Union
import torch
from torch import nn, Tensor
from torch.nn import TripletMarginLoss
from catalyst.contrib.nn.criterion.functional import triplet_loss
from catalyst.utils.misc import convert_labels2list
if TYPE_CHECKING:
from catalyst.data.sampler_inbatch import IInbatchTripletSampler
# Type used to cast boolean masks. Detect ``torch.bool`` directly instead of
# comparing version strings: a plain string comparison is lexicographic and
# can misjudge suffixed versions. Falls back to the legacy ``torch.ByteTensor``
# tensor type when the installed torch has no ``torch.bool``.
TORCH_BOOL = getattr(torch, "bool", torch.ByteTensor)
class TripletLoss(nn.Module):
    """Triplet loss with hard positive/negative mining.

    For every anchor in the batch, the farthest sample sharing its label
    (hardest positive) and the closest sample with a different label
    (hardest negative) are selected; the hinge
    ``max(d(a, p) - d(a, n) + margin, 0)`` is then averaged over anchors.

    Adapted from: https://github.com/NegatioN/OnlineMiningTripletLoss
    """

    def __init__(self, margin: float = 0.3):
        """
        Args:
            margin: margin for triplet
        """
        super().__init__()
        self.margin = margin
        # Not used by ``forward``; kept so the module keeps exposing the
        # same attributes as before.
        self.ranking_loss = nn.MarginRankingLoss(margin=margin)

    def _pairwise_distances(self, embeddings: Tensor, squared: bool = False) -> Tensor:
        """Compute the 2D matrix of distances between all the embeddings.

        Args:
            embeddings: tensor of shape (batch_size, embed_dim)
            squared: if true, output is the pairwise squared euclidean
                distance matrix. If false, output is the pairwise euclidean
                distance matrix

        Returns:
            torch.Tensor: pairwise matrix of size (batch_size, batch_size)
        """
        # Gram matrix; its diagonal holds the squared L2 norm of each
        # embedding, shape (batch_size,).
        dot_product = torch.mm(embeddings, embeddings.t())
        square_norm = torch.diag(dot_product)

        # ||a - b||^2 = ||a||^2 - 2 <a, b> + ||b||^2
        # shape (batch_size, batch_size)
        distances = (
            square_norm.view(-1, 1) - 2.0 * dot_product + square_norm.view(1, -1)
        )

        # Rounding errors can produce slightly negative values; clip them
        # to zero without mutating the tensor in place.
        distances = torch.clamp(distances, min=0.0)

        if not squared:
            # The gradient of sqrt is infinite at 0 (e.g. on the diagonal),
            # so add a small epsilon there before the sqrt and zero those
            # entries again afterwards.
            zero_mask = distances.eq(0).float()
            distances = distances + zero_mask * 1e-16
            distances = (1.0 - zero_mask) * torch.sqrt(distances)

        return distances

    def _get_anchor_positive_triplet_mask(self, labels: Tensor) -> Tensor:
        """
        Return a 2D mask where mask[a, p] is True
        if a and p are distinct and have same label.

        Args:
            labels: integer `Tensor` with shape [batch_size]

        Returns:
            torch.Tensor: boolean mask with shape [batch_size, batch_size]
        """
        # Build the index grid directly on the labels' device so the mask
        # matches it exactly (including non-default GPUs such as "cuda:1").
        positions = torch.arange(labels.size(0), device=labels.device)
        indices_not_equal = positions.unsqueeze(0) != positions.unsqueeze(1)

        # labels[i] == labels[j], by broadcasting (1, batch_size) against
        # (batch_size, 1)
        labels_equal = labels.unsqueeze(0) == labels.unsqueeze(1)

        return labels_equal & indices_not_equal

    def _get_anchor_negative_triplet_mask(self, labels: Tensor) -> Tensor:
        """Return 2D mask where mask[a, n] is True if a and n have
        different labels.

        Args:
            labels: integer `Tensor` with shape [batch_size]

        Returns:
            torch.Tensor: boolean mask with shape [batch_size, batch_size]
        """
        # labels[i] != labels[k], by broadcasting (1, batch_size) against
        # (batch_size, 1)
        return ~(labels.unsqueeze(0) == labels.unsqueeze(1))

    def _batch_hard_triplet_loss(
        self, embeddings, labels, margin, squared=True,
    ):
        """
        Build the triplet loss over a batch of embeddings.
        For each anchor, we get the hardest positive and
        hardest negative to form a triplet.

        Args:
            labels: labels of the batch, of size (batch_size)
            embeddings: tensor of shape (batch_size, embed_dim)
            margin: margin for triplet loss
            squared: Boolean. If true, output is the pairwise squared
                euclidean distance matrix. If false, output is the
                pairwise euclidean distance matrix.

        Returns:
            torch.Tensor: scalar tensor containing the triplet loss
        """
        pairwise_dist = self._pairwise_distances(embeddings, squared=squared)

        # Hardest positive: zero out every pair that is not a valid
        # (anchor, positive) pair, then take the row-wise maximum.
        mask_anchor_positive = self._get_anchor_positive_triplet_mask(labels).float()
        anchor_positive_dist = mask_anchor_positive * pairwise_dist
        # shape (batch_size, 1)
        hardest_positive_dist, _ = anchor_positive_dist.max(1, keepdim=True)

        # Hardest negative: push invalid negatives (same label as the
        # anchor) up by the row maximum so they cannot win the minimum.
        mask_anchor_negative = self._get_anchor_negative_triplet_mask(labels).float()
        max_anchor_negative_dist, _ = pairwise_dist.max(1, keepdim=True)
        anchor_negative_dist = pairwise_dist + max_anchor_negative_dist * (
            1.0 - mask_anchor_negative
        )
        # shape (batch_size, 1)
        hardest_negative_dist, _ = anchor_negative_dist.min(1, keepdim=True)

        # Combine biggest d(a, p) and smallest d(a, n) into the hinge loss
        tl = hardest_positive_dist - hardest_negative_dist + margin
        tl = torch.clamp(tl, min=0.0)
        return tl.mean()

    def forward(self, embeddings, targets):
        """Forward propagation method for the triplet loss.

        Args:
            embeddings: tensor of shape (batch_size, embed_dim)
            targets: labels of the batch, of size (batch_size)

        Returns:
            torch.Tensor: scalar tensor containing the triplet loss
        """
        return self._batch_hard_triplet_loss(embeddings, targets, self.margin)
class TripletLossV2(nn.Module):
    """Module wrapper around the functional ``triplet_loss``.

    Stores the margin at construction time and passes it on every call.
    """

    def __init__(self, margin=0.3):
        """
        Args:
            margin: margin for triplet.
        """
        super().__init__()
        self.margin = margin

    def forward(self, embeddings, targets):
        """Compute the loss by delegating to ``triplet_loss``.

        Args:
            embeddings: batch of embeddings, forwarded unchanged
                (presumably of shape (batch_size, embed_dim) - confirm
                against ``triplet_loss``)
            targets: labels for the batch, forwarded unchanged

        Returns:
            whatever ``triplet_loss`` returns for the stored margin
        """
        loss = triplet_loss(embeddings, targets, margin=self.margin)
        return loss
class TripletPairwiseEmbeddingLoss(nn.Module):
    """TripletPairwiseEmbeddingLoss - proof of concept criterion.

    Row ``i`` of the predicted embeddings is matched against row ``i`` of
    the true embeddings (the positive); the most similar *other* row of
    the true embeddings is used as the hard negative.

    Still work in progress.
    """

    def __init__(self, margin: float = 0.3, reduction: str = "mean"):
        """
        Args:
            margin: margin parameter
            reduction: criterion reduction type; ``"mean"`` and ``"sum"``
                reduce the per-sample loss, any other value (including
                ``None``, stored as ``"none"``) leaves it unreduced
        """
        super().__init__()
        self.margin = margin
        self.reduction = reduction or "none"

    def forward(self, embeddings_pred, embeddings_true):
        """
        Work in progress.

        Args:
            embeddings_pred: predicted embeddings
                with shape [batch_size, embedding_size]
            embeddings_true: true embeddings
                with shape [batch_size, embedding_size]

        Returns:
            torch.Tensor: loss
        """
        device = embeddings_pred.device

        # s - state space
        # e - embeddings space
        # a - action space
        # [batch_size, embedding_size] x [batch_size, embedding_size]
        # -> [batch_size, batch_size]
        pairwise_similarity = torch.einsum("se,ae->sa", embeddings_pred, embeddings_true)
        bs = embeddings_pred.shape[0]
        batch_idx = torch.arange(bs, device=device)

        # Exclude the positives (the diagonal) from the negative search.
        # A boolean mask with -inf works for every floating dtype and does
        # not depend on an integer-valued ``torch.full`` or on a finite
        # sentinel.
        is_positive = batch_idx.unsqueeze(0) == batch_idx.unsqueeze(1)
        negative_similarity = pairwise_similarity.masked_fill(
            is_positive, float("-inf")
        )

        # TODO argsort, take k worst
        hard_negative_ids = negative_similarity.argmax(dim=-1)

        negative_similarities = pairwise_similarity[batch_idx, hard_negative_ids]
        positive_similarities = pairwise_similarity[batch_idx, batch_idx]

        loss = torch.relu(self.margin - positive_similarities + negative_similarities)
        if self.reduction == "mean":
            loss = torch.sum(loss) / bs
        elif self.reduction == "sum":
            loss = torch.sum(loss)
        return loss
class TripletMarginLossWithSampler(nn.Module):
    """
    This class combines in-batch sampling of triplets and
    default TripletMarginLoss from PyTorch.
    """

    def __init__(self, margin: float, sampler_inbatch: "IInbatchTripletSampler"):
        """
        Args:
            margin: margin value
            sampler_inbatch: sampler for forming triplets inside the batch
        """
        super().__init__()
        self._sampler_inbatch = sampler_inbatch
        self._triplet_margin_loss = TripletMarginLoss(margin=margin)

    def forward(self, features: Tensor, labels: Union[Tensor, List[int]]) -> Tensor:
        """
        Args:
            features: features with shape [batch_size, features_dim]
            labels: labels of samples having batch_size elements

        Returns: loss value
        """
        # The sampler is called with labels converted by
        # ``convert_labels2list``.
        labels_list = convert_labels2list(labels)

        features_anchor, features_positive, features_negative = self._sampler_inbatch.sample(
            features=features, labels=labels_list
        )

        loss = self._triplet_margin_loss(
            anchor=features_anchor, positive=features_positive, negative=features_negative,
        )
        return loss
# Names exported by ``from <this module> import *``: the four criterion
# classes defined above.
__all__ = [
    "TripletLoss",
    "TripletLossV2",
    "TripletPairwiseEmbeddingLoss",
    "TripletMarginLossWithSampler",
]