import torch
import torch.nn as nn
from .functional import triplet_loss
[docs]class TripletLoss(nn.Module):
"""
Triplet loss with hard positive/negative mining.
Reference:
Code imported from https://github.com/NegatioN/OnlineMiningTripletLoss.
Args:
margin (float): margin for triplet.
"""
[docs] def __init__(self, margin=0.3):
"""
Constructor method for the TripletLoss class.
Args:
margin: margin parameter.
"""
super().__init__()
self.margin = margin
self.ranking_loss = nn.MarginRankingLoss(margin=margin)
def _pairwise_distances(self, embeddings, squared=False):
"""
Compute the 2D matrix of distances between all the embeddings.
Args:
embeddings: tensor of shape (batch_size, embed_dim)
squared: Boolean. If true, output is the pairwise
squared euclidean distance matrix. If false, output
is the pairwise euclidean distance matrix.
Returns:
pairwise_distances: tensor of shape (batch_size, batch_size)
"""
# Get squared L2 norm for each embedding.
# We can just take the diagonal of `dot_product`.
# This also provides more numerical stability
# (the diagonal of the result will be exactly 0).
# shape (batch_size,)
square = torch.mm(embeddings, embeddings.t())
diag = torch.diag(square)
# Compute the pairwise distance matrix as we have:
# ||a - b||^2 = ||a||^2 - 2 <a, b> + ||b||^2
# shape (batch_size, batch_size)
distances = diag.view(-1, 1) - 2.0 * square + diag.view(1, -1)
# Because of computation errors, some distances
# might be negative so we put everything >= 0.0
distances[distances < 0] = 0
if not squared:
# Because the gradient of sqrt is infinite
# when distances == 0.0 (ex: on the diagonal)
# we need to add a small epsilon where distances == 0.0
mask = distances.eq(0).float()
distances = distances + mask * 1e-16
distances = (1.0 - mask) * torch.sqrt(distances)
return distances
def _get_anchor_positive_triplet_mask(self, labels):
"""
Return a 2D mask where mask[a, p] is True
if a and p are distinct and have same label.
Args:
labels: tf.int32 `Tensor` with shape [batch_size]
Returns:
mask: tf.bool `Tensor` with shape [batch_size, batch_size]
"""
indices_equal = torch.eye(labels.size(0)).bool()
# labels and indices should be on
# the same device, otherwise - exception
indices_equal = indices_equal.to("cuda" if labels.is_cuda else "cpu")
# Check that i and j are distinct
indices_not_equal = ~indices_equal
# Check if labels[i] == labels[j]
# Uses broadcasting where the 1st argument
# has shape (1, batch_size) and the 2nd (batch_size, 1)
labels_equal = labels.unsqueeze(0) == labels.unsqueeze(1)
return labels_equal & indices_not_equal
def _get_anchor_negative_triplet_mask(self, labels):
"""
Return 2D mask where mask[a, n] is True if a and n have same label.
Args:
labels: tf.int32 `Tensor` with shape [batch_size]
Returns:
mask: tf.bool `Tensor` with shape [batch_size, batch_size]
"""
# Check if labels[i] != labels[k]
# Uses broadcasting where the 1st argument
# has shape (1, batch_size) and the 2nd (batch_size, 1)
return ~(labels.unsqueeze(0) == labels.unsqueeze(1))
def _batch_hard_triplet_loss(
self,
embeddings,
labels,
margin,
squared=True,
):
"""
Build the triplet loss over a batch of embeddings.
For each anchor, we get the hardest positive and
hardest negative to form a triplet.
Args:
labels: labels of the batch, of size (batch_size,)
embeddings: tensor of shape (batch_size, embed_dim)
margin: margin for triplet loss
squared: Boolean. If true, output is the pairwise squared
euclidean distance matrix. If false, output is the
pairwise euclidean distance matrix.
Returns:
triplet_loss: scalar tensor containing the triplet loss
"""
# Get the pairwise distance matrix
pairwise_dist = self._pairwise_distances(embeddings, squared=squared)
# For each anchor, get the hardest positive
# First, we need to get a mask for every valid
# positive (they should have same label)
mask_anchor_positive = self._get_anchor_positive_triplet_mask(
labels).float()
# We put to 0 any element where (a, p) is not valid
# (valid if a != p and label(a) == label(p))
anchor_positive_dist = mask_anchor_positive * pairwise_dist
# shape (batch_size, 1)
hardest_positive_dist, _ = anchor_positive_dist.max(1, keepdim=True)
# For each anchor, get the hardest negative
# First, we need to get a mask for every valid negative
# (they should have different labels)
mask_anchor_negative = \
self._get_anchor_negative_triplet_mask(labels).float()
# We add the maximum value in each row
# to the invalid negatives (label(a) == label(n))
max_anchor_negative_dist, _ = pairwise_dist.max(1, keepdim=True)
anchor_negative_dist = pairwise_dist + max_anchor_negative_dist * \
(1.0 - mask_anchor_negative)
# shape (batch_size,)
hardest_negative_dist, _ = anchor_negative_dist.min(1, keepdim=True)
# Combine biggest d(a, p) and smallest d(a, n) into final triplet loss
tl = hardest_positive_dist - hardest_negative_dist + margin
tl[tl < 0] = 0
triplet_loss = tl.mean()
return triplet_loss
[docs] def forward(self, embeddings, targets):
"""
Forward propagation method for the triplet loss.
Args:
embeddings: tensor of shape (batch_size, embed_dim)
targets: labels of the batch, of size (batch_size,)
Returns:
triplet_loss: scalar tensor containing the triplet loss
"""
return self._batch_hard_triplet_loss(embeddings, targets, self.margin)
class TripletLossV2(nn.Module):
"""
Args:
margin (float): margin for triplet.
"""
def __init__(self, margin=0.3):
"""
Constructor method for the TripletLoss class.
Args:
margin: margin parameter.
"""
super().__init__()
self.margin = margin
def forward(self, embeddings, targets):
return triplet_loss(
embeddings,
targets,
margin=self.margin,
)
[docs]class TripletPairwiseEmbeddingLoss(nn.Module):
"""
TripletPairwiseEmbeddingLoss – proof of concept criterion.
Still work in progress.
"""
[docs] def __init__(self, margin=0.3, reduction="mean"):
"""
Constructor method for the TripletPairwiseEmbeddingLoss class.
Args:
margin: margin parameter.
reduction: criterion reduction type.
"""
super().__init__()
self.margin = margin
self.reduction = reduction or "none"
[docs] def forward(self, embeddings_pred, embeddings_true):
"""
Work in progress.
Args:
embeddings_pred: predicted embeddings
with shape [batch_size, embedding_size]
embeddings_true: true embeddings
with shape [batch_size, embedding_size]
Returns:
torch.Tensor: loss
"""
device = embeddings_pred.device
# s - state space
# d - embeddings space
# a - action space
# [batch_size, embedding_size] x [batch_size, embedding_size]
# -> [batch_size, batch_size]
pairwise_similarity = torch.einsum(
"se,ae->sa",
embeddings_pred,
embeddings_true
)
bs = embeddings_pred.shape[0]
batch_idx = torch.arange(bs, device=device)
negative_similarity = (
pairwise_similarity
+ torch.diag(torch.full([bs], -10 ** 9, device=device))
)
# TODO argsort, take k worst
hard_negative_ids = negative_similarity.argmax(dim=-1)
negative_similarities = \
pairwise_similarity[batch_idx, hard_negative_ids]
positive_similarities = pairwise_similarity[batch_idx, batch_idx]
loss = torch.relu(
self.margin - positive_similarities + negative_similarities
)
if self.reduction == "mean":
loss = torch.sum(loss) / bs
elif self.reduction == "sum":
loss = torch.sum(loss)
return loss
__all__ = ["TripletLoss", "TripletPairwiseEmbeddingLoss"]