Shortcuts

Source code for catalyst.data.sampler_inbatch

from typing import List, Tuple, Union
from abc import ABC, abstractmethod
from collections import Counter
from itertools import combinations, product
from random import sample
from sys import maxsize

import numpy as np
import torch
from torch import Tensor
from torch.nn import functional as F

from catalyst.utils.misc import convert_labels2list, find_value_ids

# Type aliases shared by the samplers in this module.
# order in the triplets: (anchor, positive, negative)
# Triplets given as three tensors.
TTriplets = Tuple[Tensor, Tensor, Tensor]
# Triplets given as three parallel lists of integer ids.
TTripletsIds = Tuple[List[int], List[int], List[int]]
# Labels may be passed either as a list of ints or as a tensor.
TLabels = Union[List[int], Tensor]


class IInbatchTripletSampler(ABC):
    """Interface of a sampler that forms triplets out of a single batch."""

    @abstractmethod
    def _check_input_labels(self, labels: List[int]) -> None:
        """
        Validate the labels of the batch for this particular sampler.

        Implementations are expected to reject batches the sampling
        strategy cannot work with. The check may be left empty,
        but doing so is strongly discouraged.

        Args:
            labels: labels of the samples in the batch;
                list or Tensor of shape (batch_size;)

        Raises:
            NotImplementedError: you should implement it
        """
        raise NotImplementedError()

    @abstractmethod
    def sample(self, features: Tensor, labels: TLabels) -> TTriplets:
        """
        Select (or sample) triplets from the given batch.

        Args:
            features: tensor of features
            labels: labels of the samples in the batch;
                list or Tensor of shape (batch_size;)

        Returns:
            the batch of triplets in the order (anchor, positive, negative)

        Raises:
            NotImplementedError: you should implement it
        """
        raise NotImplementedError()
class InBatchTripletsSampler(IInbatchTripletSampler):
    """
    Base class for triplets samplers.

    Child classes are expected to form triplets inside the batches:
    they only decide which samples of the batch make up each triplet
    (see ``_sample``), while ``sample`` validates the labels and gathers
    the corresponding features.
    (Note. It is assumed that the set of output features is a subset of
    the features of the samples inside the batch.)

    The batches must contain at least 2 samples for each class and
    at least 2 different classes. Such behaviour can be guaranteed by
    using catalyst.data.sampler.BalanceBatchSampler,
    but you are free to build the batches in any other way.
    """

    def _check_input_labels(self, labels: List[int]) -> None:
        """
        Check that the batch satisfies the conditions described
        in the class documentation.

        Args:
            labels: labels of the samples in the batch

        Raises:
            AssertionError: if some class has less than 2 samples
                or the batch contains less than 2 classes
        """
        labels_counter = Counter(labels)

        # Explicit raises instead of bare ``assert`` statements: the check
        # is not stripped under ``python -O`` and reports what is wrong,
        # while the exception type stays AssertionError.
        if any(n < 2 for n in labels_counter.values()):
            raise AssertionError(
                "Expected at least 2 samples for each class in the batch"
            )
        if len(labels_counter) < 2:
            raise AssertionError("Expected at least 2 classes in the batch")

    @abstractmethod
    def _sample(self, features: Tensor, labels: List[int]) -> TTripletsIds:
        """
        This method includes the logic of sampling/selecting triplets
        inside the batch. It can be based on information about
        the distance between the features, or the
        choice can be made randomly.

        Args:
            features: has the shape of [batch_size, feature_size]
            labels: labels of the samples in the batch

        Returns:
            indices of the batch samples to form the triplets from,
            in the order (anchors, positives, negatives)

        Raises:
            NotImplementedError: you should implement it
        """
        raise NotImplementedError

    def sample(self, features: Tensor, labels: TLabels) -> TTriplets:
        """
        Validate the labels, pick the triplets ids with ``_sample``
        and gather the corresponding features.

        Args:
            features: has the shape of [batch_size, feature_size]
            labels: labels of the samples in the batch

        Returns:
            the batch of the triplets in the order below:
            (anchor, positive, negative)
        """
        # Convert labels to list
        labels = convert_labels2list(labels)
        self._check_input_labels(labels=labels)

        ids_anchor, ids_pos, ids_neg = self._sample(features, labels=labels)

        return features[ids_anchor], features[ids_pos], features[ids_neg]
class AllTripletsSampler(InBatchTripletsSampler):
    """
    Sampler that enumerates the triplets available for the given labels.

    For every class each unordered pair of its samples is used once as
    (anchor, positive) and is combined with every sample of the other
    classes as a negative.
    """

    def __init__(self, max_output_triplets: int = maxsize):
        """
        Args:
            max_output_triplets: upper bound on the number of returned
                triplets. Enumerating all the triplets may produce
                a very large number of them, so only a random part
                of this size is kept when the bound is exceeded.
        """
        self._max_out_triplets = max_output_triplets

    def _sample(self, *_: Tensor, labels: List[int]) -> TTripletsIds:
        """
        Args:
            labels: labels of the samples in the batch
            *_: note, that we ignore features argument

        Returns:
            indices of triplets
        """
        batch_size = len(labels)

        collected = []
        for cur_label in set(labels):
            same_class = set(find_value_ids(labels, cur_label))
            other_classes = set(range(batch_size)) - same_class

            # every positive pair is matched with every negative sample
            for anchor, positive in combinations(same_class, r=2):
                for negative in other_classes:
                    collected.append((anchor, positive, negative))

        # keep a random subset if there are more triplets than allowed
        n_keep = min(len(collected), self._max_out_triplets)
        chosen = sample(collected, n_keep)

        ids_anchor, ids_pos, ids_neg = zip(*chosen)
        return list(ids_anchor), list(ids_pos), list(ids_neg)
class HardTripletsSampler(InBatchTripletsSampler):
    """
    This sampler selects hardest triplets based on distances between features:
    the hardest positive sample has the maximal distance to the anchor sample,
    the hardest negative sample has the minimal distance to the anchor sample.

    Note that a typical triplet loss chart is as follows:
    1. Falling: loss decreases to a value equal to the margin.
    2. Long plateau: the loss oscillates near the margin.
    3. Falling: loss decreases to zero.
    """

    def __init__(self, norm_required: bool = False):
        """
        Args:
            norm_required: set True if features normalisation is needed
        """
        self._norm_required = norm_required

    def _sample(self, features: Tensor, labels: List[int]) -> TTripletsIds:
        """
        This method samples the hardest triplets inside the batch.

        Args:
            features: has the shape of [batch_size, feature_size]
            labels: labels of the samples in the batch

        Returns:
            ids of the samples forming the triplets in the order below:
            (anchor, positive, negative)
        """
        assert features.shape[0] == len(
            labels
        ), "Number of features must be equal to the number of labels"

        # Mining only produces indices, so the distance matrix never needs
        # an autograd graph, regardless of the normalisation flag.
        features = features.detach()
        if self._norm_required:
            features = F.normalize(features, p=2, dim=1)

        dist_mat = torch.cdist(x1=features, x2=features, p=2)

        return self._sample_from_distmat(distmat=dist_mat, labels=labels)

    @staticmethod
    def _sample_from_distmat(distmat: Tensor, labels: List[int]) -> TTripletsIds:
        """
        This method samples the hardest triplets based on the given
        distances matrix. It chooses each sample in the batch as an
        anchor and then finds the hardest positive and negative pair.

        Args:
            distmat: matrix of distances between the features
            labels: labels of the samples in the batch

        Returns:
            ids of the samples forming the triplets in the order below:
            (anchor, positive, negative)
        """
        ids_all = set(range(len(labels)))

        # ids of the samples of each label, looked up once per label
        # instead of once per anchor
        ids_by_label = {}

        ids_anchor, ids_pos, ids_neg = [], [], []

        for i_anch, label in enumerate(labels):
            if label not in ids_by_label:
                ids_by_label[label] = set(find_value_ids(it=labels, value=label))
            ids_label = ids_by_label[label]

            ids_pos_cur = np.array(list(ids_label - {i_anch}), int)
            ids_neg_cur = np.array(list(ids_all - ids_label), int)

            # hardest positive: the farthest sample of the anchor's class
            i_pos = ids_pos_cur[int(distmat[i_anch, ids_pos_cur].argmax())]
            # hardest negative: the closest sample of any other class
            i_neg = ids_neg_cur[int(distmat[i_anch, ids_neg_cur].argmin())]

            # plain Python ints, as the return annotation promises
            ids_anchor.append(i_anch)
            ids_pos.append(int(i_pos))
            ids_neg.append(int(i_neg))

        return ids_anchor, ids_pos, ids_neg
class HardClusterSampler(IInbatchTripletSampler):
    """
    This sampler selects hardest triplets based on distance to mean vectors:
    anchor is a mean vector of features of i-th class in the batch,
    the hardest positive sample is the most distant from anchor sample of
    anchor's class, the hardest negative sample is the closest mean vector
    of another classes.

    The batch must contain k samples for p classes in it (k > 1, p > 1).
    """

    def _check_input_labels(self, labels: List[int]) -> None:
        """
        Check if the labels list is valid: contains k occurrences
        for each of p classes.

        Args:
            labels: labels in the batch

        Raises:
            ValueError: if batch is invalid (is empty, contains different
                numbers of samples for classes, contains only one class
                or only one sample for each class)
        """
        # Without this guard ``labels[0]`` below would fail with IndexError
        # instead of the documented ValueError.
        if len(labels) == 0:
            raise ValueError("Expected non-empty list of labels")

        labels_counter = Counter(labels)
        k = labels_counter[labels[0]]
        if not all(n == k for n in labels_counter.values()):
            raise ValueError("Expected equal number of samples for each class")
        if len(labels_counter) <= 1:
            raise ValueError("Expected at least 2 classes in the batch")
        if k == 1:
            raise ValueError("Expected more than one sample for each class")

    @staticmethod
    def _get_labels_mask(labels: List[int]) -> Tensor:
        """
        Generate matrix of bool of shape (n_unique_labels, batch_size),
        where n_unique_labels is a number of unique labels
        in the batch; matrix[i, j] is True if j-th element of
        the batch relates to i-th class and False otherwise.
        Rows follow the ascending order of the unique labels.

        Args:
            labels: labels of the batch, shape (batch_size)

        Returns:
            matrix of indices of classes in batch
        """
        # np.unique returns the unique labels already sorted
        unique_labels = np.unique(labels)
        labels_mask = torch.zeros(
            size=(len(unique_labels), len(labels)), dtype=torch.bool
        )
        for label_idx, label in enumerate(unique_labels):
            label_indices = find_value_ids(labels, label)
            labels_mask[label_idx, label_indices] = True
        return labels_mask

    @staticmethod
    def _count_intra_class_distances(embeddings: Tensor, mean_vectors: Tensor) -> Tensor:
        """
        Count matrix of distances from mean vector of each class to
        its samples embeddings.

        Note that the distances are *squared* euclidean ones (no square
        root is taken).

        Args:
            embeddings: tensor of shape (p, k, embed_dim) where p is a number
                of classes in the batch, k is a number of samples for each
                class
            mean_vectors: tensor of shape (p, embed_dim) -- mean vectors
                of each class in the batch

        Returns:
            tensor of shape (p, k) -- matrix of squared distances from mean
            vectors to related samples in the batch
        """
        # (p, 1, embed_dim) broadcasts against (p, k, embed_dim),
        # no explicit copy of the mean vectors is required
        diff = embeddings - mean_vectors.unsqueeze(1)
        distances = torch.pow(diff, 2).sum(2)
        return distances

    @staticmethod
    def _count_inter_class_distances(mean_vectors: Tensor) -> Tensor:
        """
        Count matrix of distances from mean vectors of classes to each other

        Args:
            mean_vectors: tensor of shape (p, embed_dim) -- mean vectors
                of classes

        Returns:
            tensor of shape (p, p) -- matrix of distances between mean vectors
        """
        distance = torch.cdist(x1=mean_vectors, x2=mean_vectors, p=2)
        return distance

    @staticmethod
    def _fill_diagonal(matrix: Tensor, value: float) -> Tensor:
        """
        Set diagonal elements with the value.
        The matrix is modified in place.

        Args:
            matrix: tensor of shape (p, p)
            value: value that diagonal should be filled with

        Returns:
            the same matrix with the value on its diagonal
        """
        # in-place fill, no mask has to be allocated
        matrix.fill_diagonal_(value)
        return matrix

    def sample(self, features: Tensor, labels: TLabels) -> TTriplets:
        """
        This method samples the hardest triplets in the batch.

        Args:
            features: tensor of shape (batch_size; embed_dim) that contains
                k samples for each of p classes
            labels: labels of the batch, list or tensor of size (batch_size)

        Returns:
            p triplets of (mean_vector, positive, negative_mean_vector)
        """
        # Convert labels to list
        labels = convert_labels2list(labels)
        self._check_input_labels(labels)

        # Get matrix of indices of labels in batch
        labels_mask = self._get_labels_mask(labels)
        p, batch_size = labels_mask.shape

        # (p, k) matrix: i-th row holds the batch positions of the samples
        # of i-th class in ascending order. Gathering the features by it
        # avoids materialising a (p, batch_size, embed_dim) copy of them.
        batch_ids = torch.arange(batch_size).expand(p, -1)
        class_ids = batch_ids[labels_mask].view((p, -1)).to(features.device)

        # Group embeddings to (p, k, embed_dim),
        # each i-th group contains embeddings of i-th class.
        features = features[class_ids]

        # Count mean vectors for each class in batch
        mean_vectors = features.mean(1)

        # The distances are needed only to pick indices,
        # so they are computed without tracking gradients.
        with torch.no_grad():
            d_intra = self._count_intra_class_distances(features, mean_vectors)
            # For each class get the sample farthest from its mean vector
            pos_indices = d_intra.max(1).indices

            # Count matrix of distances from mean vectors to each other
            d_inter = self._count_inter_class_distances(mean_vectors)
            # For each class mean vector get the closest mean vector;
            # inf on the diagonal excludes the class itself
            d_inter = self._fill_diagonal(d_inter, float("inf"))
            neg_indices = d_inter.min(1).indices

        class_range = torch.arange(p, device=features.device)
        positives = features[class_range, pos_indices]

        return mean_vectors, positives, mean_vectors[neg_indices]
# Names exported by the module: the sampler interface and its implementations.
__all__ = [
    "IInbatchTripletSampler",
    "InBatchTripletsSampler",
    "AllTripletsSampler",
    "HardTripletsSampler",
    "HardClusterSampler",
]