Data¶
The data subpackage contains data preprocessors and dataloader abstractions.
Dataset¶
DatasetFromSampler¶
SelfSupervisedDatasetWrapper¶
class catalyst.data.dataset.SelfSupervisedDatasetWrapper(dataset: torch.utils.data.dataset.Dataset, transforms: Callable = None, transform_left: Callable = None, transform_right: Callable = None, transform_original: Callable = None, is_target: bool = True)

Bases: torch.utils.data.dataset.Dataset

The Self-Supervised Dataset wrapper.

The class implements the contrastive logic (see Figure 2 from A Simple Framework for Contrastive Learning of Visual Representations).
Parameters

dataset – the original dataset to augment
transforms – transforms applied to the original sample to produce both the left and right output samples
transform_left – transform applied only to the left sample
transform_right – transform applied only to the right sample
transform_original – transform applied to the original sample kept in the batch
is_target – flag indicating whether the wrapped dataset returns (sample, target) pairs or only samples
Example:
import torchvision
from torchvision.datasets import CIFAR10

from catalyst.data.dataset import SelfSupervisedDatasetWrapper

transforms = torchvision.transforms.Compose(
    [
        torchvision.transforms.RandomResizedCrop(32),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(
            [0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]
        ),
        torchvision.transforms.ColorJitter(0.8, 0.8, 0.8, 0.2),
    ]
)

cifar_dataset = CIFAR10(root="./data", download=True, transform=None)
cifar_contrastive = SelfSupervisedDatasetWrapper(
    cifar_dataset, transforms=transforms
)

for transformed_sample, aug_1, aug_2 in cifar_contrastive:
    pass
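When the two views need different augmentation pipelines, the transform_left, transform_right, and transform_original arguments from the signature above can be used instead of a single transforms. A minimal sketch reusing the cifar_dataset and transforms objects from the example above (the plain ToTensor pipeline for the original sample is an illustrative choice, not a library requirement):

# a sketch: separate transforms for the original, left, and right samples
to_tensor = torchvision.transforms.ToTensor()

cifar_contrastive = SelfSupervisedDatasetWrapper(
    cifar_dataset,
    transform_original=to_tensor,  # applied to the sample kept as "original"
    transform_left=transforms,     # augmentation for the left view
    transform_right=transforms,    # augmentation for the right view
)

# unpacking mirrors the example above
for transformed_sample, aug_1, aug_2 in cifar_contrastive:
    pass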
Loader¶
BatchLimitLoaderWrapper¶
class catalyst.data.loader.BatchLimitLoaderWrapper(loader: torch.utils.data.dataloader.DataLoader, num_batches: Union[int, float])

Loader wrapper that limits the number of batches used on each iteration.
For example, if you have some loader and want to use only the first 5 batches:
import torch
from torch.utils.data import DataLoader, TensorDataset

from catalyst.data.loader import BatchLimitLoaderWrapper

num_samples, num_features = int(1e4), int(1e1)
X, y = torch.rand(num_samples, num_features), torch.rand(num_samples)
dataset = TensorDataset(X, y)
loader = DataLoader(dataset, batch_size=32, num_workers=1)
loader = BatchLimitLoaderWrapper(loader, num_batches=5)
or, if you would like to use only a portion of the DataLoader (30% in the example below):
import torch
from torch.utils.data import DataLoader, TensorDataset

from catalyst.data.loader import BatchLimitLoaderWrapper

num_samples, num_features = int(1e4), int(1e1)
X, y = torch.rand(num_samples, num_features), torch.rand(num_samples)
dataset = TensorDataset(X, y)
loader = DataLoader(dataset, batch_size=32, num_workers=1)
loader = BatchLimitLoaderWrapper(loader, num_batches=0.3)
Note

Generally speaking, this wrapper can be used with any iterator-like object; no DataLoader-specific code is used.
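As a quick sanity check, one pass over the wrapped loader should stop after num_batches batches. A minimal sketch; the expected count is an assumption based on the description above, not a guaranteed API contract:

import torch
from torch.utils.data import DataLoader, TensorDataset

from catalyst.data.loader import BatchLimitLoaderWrapper

dataset = TensorDataset(torch.rand(1000, 10), torch.rand(1000))
loader = BatchLimitLoaderWrapper(DataLoader(dataset, batch_size=32), num_batches=5)

# one iteration over the wrapper should yield at most num_batches batches
print(sum(1 for _ in loader))  # expected: 5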
BatchPrefetchLoaderWrapper¶
class catalyst.data.loader.BatchPrefetchLoaderWrapper(loader: torch.utils.data.dataloader.DataLoader, num_prefetches: int = None)

Loader wrapper that prefetches the specified number of batches on the GPU.
Basic usage:
import torch
from torch.utils.data import DataLoader, TensorDataset

from catalyst.data import BatchPrefetchLoaderWrapper

num_samples, num_features = int(1e4), int(1e1)
X, y = torch.rand(num_samples, num_features), torch.rand(num_samples)
dataset = TensorDataset(X, y)
loader = DataLoader(dataset, batch_size=32, num_workers=1)
loader = BatchPrefetchLoaderWrapper(loader)
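The num_prefetches argument from the signature above caps how many batches are queued ahead of the training loop (prefetched on the GPU, as noted above). A minimal sketch; the value 3 is an arbitrary illustrative choice:

import torch
from torch.utils.data import DataLoader, TensorDataset

from catalyst.data import BatchPrefetchLoaderWrapper

X, y = torch.rand(1000, 10), torch.rand(1000)
loader = DataLoader(TensorDataset(X, y), batch_size=32)
# keep up to 3 batches prefetched ahead of consumption
loader = BatchPrefetchLoaderWrapper(loader, num_prefetches=3)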
Minimal working example:
import os

import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader

from catalyst import dl, metrics
from catalyst.data.cv import ToTensor
from catalyst.contrib.datasets import MNIST
from catalyst.data import BatchPrefetchLoaderWrapper


class CustomRunner(dl.Runner):
    def handle_batch(self, batch):
        # model train/valid step
        x, y = batch
        y_hat = self.model(x.view(x.size(0), -1))

        loss = F.cross_entropy(y_hat, y)
        accuracy01 = metrics.accuracy(y_hat, y, topk=(1,))
        self.batch_metrics.update(
            {"loss": loss, "accuracy01": accuracy01}
        )

        if self.is_train_loader:
            self.engine.backward(loss)
            self.optimizer.step()
            self.optimizer.zero_grad()


model = torch.nn.Linear(28 * 28, 10)
optimizer = torch.optim.Adam(model.parameters(), lr=0.02)

batch_size = 32
loaders = {
    "train": DataLoader(
        MNIST(os.getcwd(), train=True, download=True, transform=ToTensor()),
        batch_size=batch_size,
    ),
    "valid": DataLoader(
        MNIST(os.getcwd(), train=False, download=True, transform=ToTensor()),
        batch_size=batch_size,
    ),
}
loaders = {
    k: BatchPrefetchLoaderWrapper(v) for k, v in loaders.items()
}

runner = CustomRunner()
# model training
runner.train(
    model=model,
    optimizer=optimizer,
    loaders=loaders,
    logdir="./logs",
    num_epochs=5,
    verbose=True,
    load_best_on_end=True,
)
Samplers¶
BalanceClassSampler¶
class catalyst.data.sampler.BalanceClassSampler(labels: List[int], mode: Union[str, int] = 'downsampling')

Allows you to create a stratified sample from unbalanced classes.
Parameters

labels – list of class labels for each element in the dataset
mode – strategy to balance classes; must be one of "downsampling" or "upsampling" (an integer value sets the number of samples drawn per class, as in the example below)
Python API examples:
import os

from torch import nn, optim
from torch.utils.data import DataLoader

from catalyst import dl
from catalyst.data import ToTensor, BalanceClassSampler
from catalyst.contrib.datasets import MNIST

train_data = MNIST(os.getcwd(), train=True, download=True, transform=ToTensor())
train_labels = train_data.targets.cpu().numpy().tolist()
train_sampler = BalanceClassSampler(train_labels, mode=5000)
valid_data = MNIST(os.getcwd(), train=False)

loaders = {
    "train": DataLoader(train_data, sampler=train_sampler, batch_size=32),
    "valid": DataLoader(valid_data, batch_size=32),
}

model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.02)

runner = dl.SupervisedRunner()
# model training
runner.train(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    loaders=loaders,
    num_epochs=1,
    logdir="./logs",
    valid_loader="valid",
    valid_metric="loss",
    minimize_valid_metric=True,
    verbose=True,
)
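The balancing effect can be inspected directly by counting the labels of the sampled indices. A minimal sketch with toy labels; the printed counts are the expected outcome for "downsampling", where every class is drawn roughly min-class-size times:

from collections import Counter

from catalyst.data.sampler import BalanceClassSampler

labels = [0] * 900 + [1] * 100  # imbalanced toy labels
sampler = BalanceClassSampler(labels, mode="downsampling")

# iterating the sampler yields dataset indices; count their classes
print(Counter(labels[i] for i in sampler))  # expected: Counter({0: 100, 1: 100})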
BatchBalanceClassSampler¶
class catalyst.data.sampler.BatchBalanceClassSampler(labels: Union[List[int], numpy.ndarray], num_classes: int, num_samples: int, num_batches: int = None)

This kind of sampler can be used for both metric learning and classification tasks.
BatchSampler with the following strategy for a dataset with C unique classes:

- select num_classes of the C classes for each batch
- select num_samples instances of each class in the batch

The epoch ends after num_batches, so the batch size is num_classes * num_samples.
One of the purposes of this sampler is to form triplets and positive/negative pairs inside the batch. To guarantee the existence of these pairs in the batch, num_classes and num_samples should be > 1.
This type of sampling can be found in the classical Person Re-Id paper, where P (num_classes) equals 32 and K (num_samples) equals 4: In Defense of the Triplet Loss for Person Re-Identification.
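For example, that Re-Id setting (P = 32, K = 4) yields batches of 32 * 4 = 128 samples, while num_classes=10 and num_samples=4, as in the example below, yields batches of 40 samples.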
Parameters

labels – list of class labels for each element in the dataset
num_classes – number of classes in a batch, should be > 1
num_samples – number of instances of each class in a batch, should be > 1
num_batches – number of batches in an epoch (default = len(labels) // (num_classes * num_samples))
Python API examples:
import os

from torch import nn, optim
from torch.utils.data import DataLoader

from catalyst import dl
from catalyst.data import ToTensor, BatchBalanceClassSampler
from catalyst.contrib.datasets import MNIST

train_data = MNIST(os.getcwd(), train=True, download=True)
train_labels = train_data.targets.cpu().numpy().tolist()
train_sampler = BatchBalanceClassSampler(
    train_labels, num_classes=10, num_samples=4)
valid_data = MNIST(os.getcwd(), train=False)

loaders = {
    "train": DataLoader(train_data, batch_sampler=train_sampler),
    "valid": DataLoader(valid_data, batch_size=32),
}

model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.02)

runner = dl.SupervisedRunner()
# model training
runner.train(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    loaders=loaders,
    num_epochs=1,
    logdir="./logs",
    valid_loader="valid",
    valid_metric="loss",
    minimize_valid_metric=True,
    verbose=True,
)
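In this example the default num_batches applies: with 60000 MNIST training labels, num_classes=10, and num_samples=4, the documented default gives 60000 // (10 * 4) = 1500 batches of 40 samples per epoch.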
DistributedSamplerWrapper¶
class catalyst.data.sampler.DistributedSamplerWrapper(sampler, num_replicas: Optional[int] = None, rank: Optional[int] = None, shuffle: bool = True)

Wrapper over Sampler for distributed training. Allows you to use any sampler in distributed mode.
It is especially useful in conjunction with torch.nn.parallel.DistributedDataParallel. In such a case, each process can pass a DistributedSamplerWrapper instance as a DataLoader sampler and load the subset of the subsampled original dataset that is exclusive to it.
Note
Sampler is assumed to be of constant size.
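A minimal sketch of wrapping a non-default sampler for distributed use; num_replicas and rank are passed explicitly so the snippet runs outside an initialized process group (inside a real DistributedDataParallel setup they can presumably be left as None and inferred from the process group):

import torch
from torch.utils.data import DataLoader, TensorDataset

from catalyst.data.sampler import BalanceClassSampler, DistributedSamplerWrapper

# toy imbalanced dataset
X = torch.rand(1000, 10)
y = (torch.rand(1000) > 0.8).long()
dataset = TensorDataset(X, y)

# any sampler can be wrapped; here a BalanceClassSampler
base_sampler = BalanceClassSampler(y.tolist(), mode="upsampling")
sampler = DistributedSamplerWrapper(base_sampler, num_replicas=2, rank=0)
loader = DataLoader(dataset, sampler=sampler, batch_size=32)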
DynamicBalanceClassSampler¶
class catalyst.data.sampler.DynamicBalanceClassSampler(labels: List[Union[str, int]], exp_lambda: float = 0.9, start_epoch: int = 0, max_d: Optional[int] = None, mode: Union[str, int] = 'downsampling', ignore_warning: bool = False)

This kind of sampler can be used for classification tasks with significant class imbalance.
The idea of this sampler is that we start with the original class distribution and gradually move toward a uniform class distribution, as with downsampling.
Let's define $D_i = \#C_i / \#C_{min}$, where $\#C_i$ is the size of class $i$ and $\#C_{min}$ is the size of the rarest class, so $D_i$ defines the class distribution. Also let $g(n_{epoch})$ be an exponential scheduler. On each epoch the current $D_i$ is computed as $D_i^{current} = D_i^{\,g(n_{epoch})}$; after that, data is sampled according to this distribution.
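For intuition, a small worked calculation; it assumes the scheduler is g(n_epoch) = exp_lambda ** n_epoch, which is suggested by the exp_lambda parameter but not stated explicitly above:

# toy setting: majority class of 900 samples, minority class of 100 samples
exp_lambda = 0.9
d_majority = 900 / 100  # D_i = 9 for the majority class

for n_epoch in (0, 5, 10, 30):
    g = exp_lambda ** n_epoch      # assumed exponential scheduler
    current_d = d_majority ** g    # current D_i = D_i ** g(n_epoch)
    print(f"epoch {n_epoch:2d}: majority/minority sampling ratio ~ {current_d:.2f}")

# the ratio decays from 9.0 toward 1.0, i.e. toward the uniform (downsampled) distribution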
Notes
At the end of training, epochs will contain only min_size_class * n_classes examples, so it may not be necessary to run validation on every epoch; for this reason, use ControlFlowCallback.
Examples
>>> import torch
>>> import numpy as np
>>>
>>> from catalyst.data import DynamicBalanceClassSampler
>>> from torch.utils import data
>>>
>>> features = torch.Tensor(np.random.random((200, 100)))
>>> labels = np.random.randint(0, 4, size=(200,))
>>> sampler = DynamicBalanceClassSampler(labels)
>>> labels = torch.LongTensor(labels)
>>> dataset = data.TensorDataset(features, labels)
>>> loader = data.dataloader.DataLoader(dataset, batch_size=8, sampler=sampler)
>>>
>>> for batch in loader:
...     b_features, b_labels = batch
The sampler was inspired by https://arxiv.org/abs/1901.06783.
MiniEpochSampler¶
class catalyst.data.sampler.MiniEpochSampler(data_len: int, mini_epoch_len: int, drop_last: bool = False, shuffle: str = None)

Sampler that iterates over the dataset in mini epochs of mini_epoch_len samples.

Parameters

data_len – size of the dataset
mini_epoch_len – number of samples from the dataset used in one mini epoch
drop_last – if True, the sampler will drop the last batch if its size would be less than batches_per_epoch
shuffle – one of "per_mini_epoch", "per_epoch", or None. The sampler will shuffle indices: "per_mini_epoch" – every mini epoch (every __iter__ call); "per_epoch" – every real epoch; None – don't shuffle
Example
>>> MiniEpochSampler(len(dataset), mini_epoch_len=100)
>>> MiniEpochSampler(len(dataset), mini_epoch_len=100, drop_last=True)
>>> MiniEpochSampler(len(dataset), mini_epoch_len=100,
...     shuffle="per_epoch")
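Since MiniEpochSampler is a regular torch Sampler, it plugs into a DataLoader through the sampler argument. A minimal sketch with a toy dataset:

import torch
from torch.utils.data import DataLoader, TensorDataset

from catalyst.data.sampler import MiniEpochSampler

dataset = TensorDataset(torch.rand(1000, 10), torch.rand(1000))
sampler = MiniEpochSampler(len(dataset), mini_epoch_len=100, shuffle="per_epoch")
loader = DataLoader(dataset, sampler=sampler, batch_size=32)

# each pass over the loader covers one mini epoch of 100 samples (~4 batches of 32)
for batch in loader:
    features, targets = batch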