from typing import Iterator, List, Optional
from operator import itemgetter

import numpy as np

from import DistributedSampler
from import Sampler

from import DatasetFromSampler

class BalanceClassSampler(Sampler):
    """Abstraction over data sampler.

    Allows you to create a stratified sample on unbalanced classes: on every
    pass each class contributes the same number of indices.
    """

    def __init__(self, labels: List[int], mode: str = "downsampling"):
        """
        Args:
            labels (List[int]): class label for each element in the dataset
            mode (str or int): strategy to balance classes.
                ``"downsampling"`` draws ``min(class sizes)`` indices per
                class, ``"upsampling"`` draws ``max(class sizes)`` indices per
                class, and a non-negative ``int`` draws exactly that many
                indices per class.

        Raises:
            ValueError: if ``mode`` is an unknown string or a negative int
        """
        # ``Sampler`` does not use its ``data_source`` argument; pass None
        # (as MiniEpochSampler does) instead of the labels.
        super().__init__(None)

        # Raise instead of ``assert`` so validation survives ``python -O``.
        if isinstance(mode, str) and mode not in ("downsampling", "upsampling"):
            raise ValueError(
                "mode must be 'downsampling', 'upsampling' or an int, "
                f"got {mode!r}"
            )
        if isinstance(mode, int) and mode < 0:
            raise ValueError(f"mode must be non-negative, got {mode}")

        labels = np.array(labels)
        # Dataset positions of every element, grouped by class label.
        self.lbl2idx = {
            label: np.flatnonzero(labels == label).tolist()
            for label in np.unique(labels)
        }
        class_sizes = [len(positions) for positions in self.lbl2idx.values()]

        if isinstance(mode, int):
            samples_per_class = mode
        elif mode == "upsampling":
            samples_per_class = max(class_sizes)
        else:
            # "downsampling"; also the historical fallback for modes that
            # are neither str nor int.
            samples_per_class = min(class_sizes)

        self.labels = labels
        self.samples_per_class = int(samples_per_class)
        self.length = self.samples_per_class * len(self.lbl2idx)

    def __iter__(self) -> Iterator[int]:
        """
        Yields:
            indices of a stratified sample: ``samples_per_class`` indices
            drawn from every class, shuffled together
        """
        indices = []
        for key in sorted(self.lbl2idx):
            # Sample with replacement only when the class is too small to
            # provide ``samples_per_class`` distinct elements.
            replace_ = self.samples_per_class > len(self.lbl2idx[key])
            indices += np.random.choice(
                self.lbl2idx[key], self.samples_per_class, replace=replace_
            ).tolist()
        assert len(indices) == self.length
        np.random.shuffle(indices)
        return iter(indices)

    def __len__(self) -> int:
        """
        Returns:
            length of result sample
        """
        return self.length
class MiniEpochSampler(Sampler):
    """
    Sampler iterates mini epochs from the dataset used by ``mini_epoch_len``.

    Every ``__iter__`` call yields the next consecutive chunk of
    ``mini_epoch_len`` indices; after the last chunk it wraps around to the
    first one (the start of a new real epoch).

    Example:
        >>> MiniEpochSampler(len(dataset), mini_epoch_len=100)
        >>> MiniEpochSampler(len(dataset), mini_epoch_len=100,
        >>>     drop_last=True)
        >>> MiniEpochSampler(len(dataset), mini_epoch_len=100,
        >>>     shuffle="per_epoch")
    """

    def __init__(
        self,
        data_len: int,
        mini_epoch_len: int,
        drop_last: bool = False,
        shuffle: str = None,
    ):
        """
        Args:
            data_len (int): Size of the dataset
            mini_epoch_len (int): Num samples from the dataset used in one
                mini epoch.
            drop_last (bool): If ``True``, the sampler drops the last mini
                epoch if its size would be less than ``mini_epoch_len``
            shuffle (str): one of ``"per_mini_epoch"``, ``"per_epoch"``, or
                ``None``. The sampler will shuffle indices
                > "per_mini_epoch" -- every mini epoch (every ``__iter__`` call)
                > "per_epoch" -- every real epoch
                > None -- don't shuffle

        Raises:
            ValueError: if ``shuffle`` is not one of the values above
        """
        super().__init__(None)

        if not (shuffle is None or shuffle in ["per_mini_epoch", "per_epoch"]):
            raise ValueError(
                f"Shuffle must be one of ['per_mini_epoch', 'per_epoch']. "
                f"Got {shuffle}"
            )
        self.shuffle_type = shuffle

        self.data_len = int(data_len)
        self.mini_epoch_len = int(mini_epoch_len)
        # Number of full mini epochs that fit into the dataset.
        self.steps = self.data_len // self.mini_epoch_len
        # Index of the mini epoch that the next ``__iter__`` call serves.
        self.state_i = 0

        has_remainder = self.data_len - self.steps * self.mini_epoch_len > 0
        # ``divider`` = number of mini epochs in one real epoch.
        if self.steps == 0:
            self.divider = 1
        elif has_remainder and not drop_last:
            # Keep the trailing partial chunk as an extra mini epoch.
            self.divider = self.steps + 1
        else:
            self.divider = self.steps

        self._indices = np.arange(self.data_len)
        self.indices = self._indices
        # Slice end used for the final (possibly partial) mini epoch.
        self.end_pointer = max(self.data_len, self.mini_epoch_len)

    def shuffle(self) -> None:
        """Reshuffle ``self.indices`` according to ``shuffle_type``.

        With ``"per_mini_epoch"`` this happens on every call; with
        ``"per_epoch"`` only at the start of a real epoch (``state_i == 0``).
        When the dataset is shorter than ``mini_epoch_len``,
        ``mini_epoch_len`` indices are drawn with replacement instead.
        """
        if self.shuffle_type == "per_mini_epoch" or (
            self.shuffle_type == "per_epoch" and self.state_i == 0
        ):
            if self.data_len >= self.mini_epoch_len:
                # ``permutation`` returns a shuffled copy, so ``_indices``
                # keeps its original ordering.
                self.indices = np.random.permutation(self._indices)
            else:
                self.indices = np.random.choice(
                    self._indices, self.mini_epoch_len, replace=True
                )

    def __iter__(self) -> Iterator[int]:
        """Return an iterator over the indices of the next mini epoch.

        Advances the internal mini-epoch counter, wrapping around after the
        last mini epoch of the real epoch.
        """
        self.state_i %= self.divider
        self.shuffle()

        start = self.state_i * self.mini_epoch_len
        if self.state_i == self.steps:
            # Final mini epoch: take everything up to the end of the data.
            stop = self.end_pointer
        else:
            stop = start + self.mini_epoch_len
        indices = self.indices[start:stop].tolist()

        self.state_i += 1
        return iter(indices)

    def __len__(self) -> int:
        """
        Returns:
            int: length of the mini-epoch.

        Note:
            Some mini epochs can be shorter than this value. This happens for
            the trailing partial mini epoch when ``drop_last=False``. It also
            happens for an unshuffled dataset that is shorter than
            ``mini_epoch_len``.
        """
        return self.mini_epoch_len
class DistributedSamplerWrapper(DistributedSampler):
    """
    Wrapper over `Sampler` for distributed training.
    Allows you to use any sampler in distributed mode.

    It is especially useful in conjunction with
    :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each
    process can pass a DistributedSamplerWrapper instance as a DataLoader
    sampler, and load a subset of subsampled data of the original dataset
    that is exclusive to it.

    .. note::
        Sampler is assumed to be of constant size.
    """

    def __init__(
        self,
        sampler,
        num_replicas: Optional[int] = None,
        rank: Optional[int] = None,
        shuffle: bool = True,
    ):
        """
        Args:
            sampler: Sampler used for subsampling
            num_replicas (int, optional): Number of processes participating in
                distributed training
            rank (int, optional): Rank of the current process
                within ``num_replicas``
            shuffle (bool, optional): If true (default),
                sampler will shuffle the indices
        """
        super().__init__(
            DatasetFromSampler(sampler),
            num_replicas=num_replicas,
            rank=rank,
            shuffle=shuffle,
        )
        self.sampler = sampler

    def __iter__(self) -> Iterator[int]:
        """Return an iterator over this replica's share of the sampler's indices.

        The wrapped sampler is re-wrapped on every call, so a stochastic
        sampler can produce a fresh sample each epoch. The parent
        ``DistributedSampler`` then chooses which positions of that sample
        belong to the current rank.
        """
        self.dataset = DatasetFromSampler(self.sampler)
        indexes_of_indexes = super().__iter__()
        subsampler_indexes = self.dataset
        # Look up each position individually. ``itemgetter(*idx)`` returns a
        # bare element (not a tuple) for a single index and raises for none.
        return iter([subsampler_indexes[i] for i in indexes_of_indexes])
# Names exported by ``from <this module> import *``.
__all__ = [
    "BalanceClassSampler", "MiniEpochSampler", "DistributedSamplerWrapper"
]