# Source code for catalyst.data.sampler_inbatch
from typing import List, Tuple
from abc import ABC, abstractmethod
from collections import Counter
from itertools import combinations, product
from random import sample
from sys import maxsize
import numpy as np
import torch
from torch import Tensor
from catalyst.contrib.utils.misc import find_value_ids
from catalyst.utils.torch import normalize
# Order of the elements in every triplet: (anchor, positive, negative).
# TTriplets: the feature rows selected for anchors, positives and negatives.
TTriplets = Tuple[Tensor, Tensor, Tensor]
# TTripletsIds: the matching lists of row indices into the batch features.
TTripletsIds = Tuple[List[int], List[int], List[int]]
class InBatchTripletsSampler(ABC):
    """
    Base class for triplet samplers.

    Child classes are meant to form triplets inside a batch.
    The batch must contain at least 2 samples for each class and
    at least 2 different classes. Such batches can be guaranteed by
    ``catalyst.data.sampler.BalanceBatchSampler``, but any other way
    of building them is fine as well.
    """

    @staticmethod
    def _check_input_labels(labels: List[int]) -> None:
        """
        Check that the labels satisfy the conditions described in
        the class documentation.

        Args:
            labels: labels of the samples in the batch

        Raises:
            AssertionError: if some class has fewer than 2 samples or
                the batch contains fewer than 2 different classes
        """
        labels_counter = Counter(labels)
        # Explicit ``raise`` instead of ``assert`` so that the validation
        # is not stripped under ``python -O``; AssertionError is kept so
        # that existing callers/tests expecting it keep working.
        if not all(n > 1 for n in labels_counter.values()):
            raise AssertionError(
                "Each class in the batch must have at least 2 samples, "
                f"got counts: {dict(labels_counter)}"
            )
        if len(labels_counter) < 2:
            raise AssertionError(
                "The batch must contain at least 2 different classes, "
                f"got {len(labels_counter)}"
            )

    @abstractmethod
    def _sample(self, features: Tensor, labels: List[int]) -> TTripletsIds:
        """
        Select triplets inside the batch.

        The selection may be based on the distances between the features,
        or it may be random.

        Args:
            features: has the shape of [batch_size, feature_size]
            labels: labels of the samples in the batch

        Returns:
            indices of the batch samples forming the triplets, in the order
            (anchor ids, positive ids, negative ids)
        """
        raise NotImplementedError

    def sample(self, features: Tensor, labels: List[int]) -> TTriplets:
        """
        Validate the labels and build the batch of triplets.

        Args:
            features: has the shape of [batch_size, feature_size]
            labels: labels of the samples in the batch

        Returns:
            the batch of triplets in the order below:
            (anchor, positive, negative)

        Raises:
            AssertionError: if ``labels`` violate the conditions described
                in the class documentation
        """
        self._check_input_labels(labels=labels)
        # ``features`` is passed positionally: samplers that ignore it may
        # declare it as ``*_``, which cannot receive it as a keyword.
        ids_anchor, ids_pos, ids_neg = self._sample(features, labels=labels)
        return features[ids_anchor], features[ids_pos], features[ids_neg]
class AllTripletsSampler(InBatchTripletsSampler):
    """
    This sampler selects all the possible triplets for the given labels
    (optionally limited to a random subset of them).
    """

    def __init__(self, max_output_triplets: int = maxsize):
        """
        Args:
            max_output_triplets: with the strategy of choosing all
                the triplets, their number in the batch can be very large,
                so only a random subset of at most this many triplets
                is returned.
        """
        self._max_out_triplets = max_output_triplets

    def _sample(
        self, features: Tensor = None, *, labels: List[int]
    ) -> TTripletsIds:
        """
        Build every (anchor, positive, negative) index triplet and return
        a random subset of at most ``max_output_triplets`` of them.

        Args:
            features: ignored; accepted (positionally or by keyword) only
                to match the interface of ``InBatchTripletsSampler``
            labels: labels of the samples in the batch

        Returns:
            indices of the triplets in the order
            (anchor ids, positive ids, negative ids); three empty lists
            if no triplet was selected
        """
        num_labels = len(labels)
        ids_all = set(range(num_labels))
        triplets = []
        for label in set(labels):
            ids_pos_cur = set(find_value_ids(labels, label))
            ids_neg_cur = ids_all - ids_pos_cur
            # ``combinations`` yields each unordered same-class pair once;
            # its first element is used as the anchor.
            pos_pairs = list(combinations(ids_pos_cur, r=2))
            triplets.extend(
                (a, p, n) for (a, p), n in product(pos_pairs, ids_neg_cur)
            )

        triplets = sample(triplets, min(len(triplets), self._max_out_triplets))
        if not triplets:
            # ``zip(*[])`` cannot be unpacked into three sequences.
            return [], [], []

        ids_anchor, ids_pos, ids_neg = zip(*triplets)
        return list(ids_anchor), list(ids_pos), list(ids_neg)
class HardTripletsSampler(InBatchTripletsSampler):
    """
    This sampler selects the hardest triplets based on the distances
    between features: the hardest positive sample has the maximal
    distance to the anchor sample, the hardest negative sample has
    the minimal distance to the anchor sample.
    """

    def __init__(self, need_norm: bool = False):
        """
        Args:
            need_norm: set True to normalize the features (with
                ``catalyst.utils.torch.normalize``) before computing
                the distances
        """
        self._need_norm = need_norm

    def _sample(self, features: Tensor, labels: List[int]) -> TTripletsIds:
        """
        Sample the hardest triplet for every sample of the batch.

        Args:
            features: has the shape of [batch_size, feature_size]
            labels: labels of the samples in the batch

        Returns:
            indices of the triplets in the order
            (anchor ids, positive ids, negative ids)

        Raises:
            AssertionError: if the number of feature rows differs from
                the number of labels
        """
        if features.shape[0] != len(labels):
            raise AssertionError(
                f"Got {features.shape[0]} feature rows "
                f"but {len(labels)} labels"
            )
        # Only indices are produced here, so there is no need to build
        # an autograd graph for the normalization / distance matrix.
        with torch.no_grad():
            if self._need_norm:
                features = normalize(samples=features)
            dist_mat = torch.cdist(x1=features, x2=features, p=2)
        return self._sample_from_distmat(distmat=dist_mat, labels=labels)

    @staticmethod
    def _sample_from_distmat(
        distmat: Tensor, labels: List[int]
    ) -> TTripletsIds:
        """
        Sample the hardest triplets from the given distance matrix.

        Every sample of the batch is used as an anchor; its positive is
        the farthest sample with the same label (excluding itself) and its
        negative is the closest sample with a different label.

        Args:
            distmat: matrix of distances between the features,
                ``distmat[i, j]`` being the distance from sample i to j
            labels: labels of the samples in the batch

        Returns:
            indices of the triplets in the order
            (anchor ids, positive ids, negative ids)
        """
        ids_all = set(range(len(labels)))
        # Group the batch indices by label once instead of once per anchor.
        ids_by_label = {
            label: set(find_value_ids(it=labels, value=label))
            for label in set(labels)
        }

        ids_anchor, ids_pos, ids_neg = [], [], []
        for i_anch, label in enumerate(labels):
            ids_label = ids_by_label[label]
            ids_pos_cur = np.array(list(ids_label - {i_anch}), int)
            ids_neg_cur = np.array(list(ids_all - ids_label), int)

            i_pos = ids_pos_cur[distmat[i_anch, ids_pos_cur].argmax().item()]
            i_neg = ids_neg_cur[distmat[i_anch, ids_neg_cur].argmin().item()]

            ids_anchor.append(i_anch)
            # Cast numpy integers to plain ints to honour ``List[int]``.
            ids_pos.append(int(i_pos))
            ids_neg.append(int(i_neg))
        return ids_anchor, ids_pos, ids_neg