# Source code for catalyst.data.dataset
from typing import Any, Callable, Dict, Optional, Tuple

from torch.utils.data import Dataset, Sampler
class DatasetFromSampler(Dataset):
    """Dataset that exposes the items yielded by a `Sampler` by position.

    The sampler is iterated once, on the first ``__getitem__`` call, and the
    resulting list is reused for every later lookup.

    Args:
        sampler: PyTorch sampler
    """

    def __init__(self, sampler: Sampler):
        """Store the sampler; its items are not materialized yet."""
        self.sampler = sampler
        # Filled lazily by ``__getitem__`` on first access.
        self.sampler_list = None

    def __getitem__(self, index: int) -> Any:
        """Return the item the sampler produced at position ``index``.

        Args:
            index: index of the element in the dataset

        Returns:
            Single element by index
        """
        # Materialize once so repeated indexing does not re-iterate the sampler.
        if self.sampler_list is None:
            self.sampler_list = list(self.sampler)
        return self.sampler_list[index]

    def __len__(self) -> int:
        """Return the length reported by the wrapped sampler.

        Returns:
            int: length of the dataset
        """
        return len(self.sampler)
class SelfSupervisedDatasetWrapper(Dataset):
    """The Self Supervised Dataset.

    The class implements contrastive logic (see Figure 2 from `A Simple Framework
    for Contrastive Learning of Visual Representations`_.): every item of the
    wrapped dataset is returned together with two transformed views of it.

    Args:
        dataset: original dataset for augmentation
        transforms: transforms which will be applied to original batch to get both
            left and right output batch.
        transform_left: transform only for left batch
        transform_right: transform only for right batch
        transform_original: transforms which will be applied to save original in batch
        is_target: the flag for selection does dataset return (sample, target) or only sample

    Example:

    .. code-block:: python

        import torchvision
        from torchvision.datasets import CIFAR10
        from catalyst.contrib.data.dataset import SelfSupervisedDatasetWrapper

        transforms = torchvision.transforms.Compose(
            [
                torchvision.transforms.RandomResizedCrop(32),
                torchvision.transforms.ToTensor(),
                torchvision.transforms.Normalize(
                    [0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]
                ),
                torchvision.transforms.ColorJitter(0.8, 0.8, 0.8, 0.2),
            ]
        )

        cifar_dataset = CIFAR10(root="./data", download=True, transform=None)
        cifar_contrastive = SelfSupervisedDatasetWrapper(cifar_dataset, transforms=transforms)

    .. _`A Simple Framework for Contrastive Learning of Visual Representations`:
        https://arxiv.org/abs/2002.05709
    """

    def __init__(
        self,
        dataset: Dataset,
        transforms: Optional[Callable] = None,
        transform_left: Optional[Callable] = None,
        transform_right: Optional[Callable] = None,
        transform_original: Optional[Callable] = None,
        is_target: bool = True,
    ) -> None:
        """
        Args:
            dataset: original dataset for augmentation
            transforms: transforms which will be applied to original batch to get both
                left and right output batch.
            transform_left: transform only for left batch
            transform_right: transform only for right batch
            transform_original: transforms which will be applied to save original in batch
            is_target: the flag for selection does dataset return (sample, target) or only sample

        Raises:
            ValueError: should be specified transform_left and transform_right simultaneously
                or only transforms
        """
        super().__init__()

        # An explicit left/right pair takes precedence over the shared ``transforms``.
        if transform_right is not None and transform_left is not None:
            self.transform_right = transform_right
            self.transform_left = transform_left
        elif transforms is not None:
            self.transform_right = transforms
            self.transform_left = transforms
        else:
            raise ValueError(
                "Specify transform_left and transform_right simultaneously or only transforms."
            )

        self.transform_original = transform_original
        self.dataset = dataset
        self.is_target = is_target

    def __len__(self) -> int:
        """Return the length of the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx) -> Tuple[Any, ...]:
        """Get the original sample and its two augmented views.

        Args:
            idx: index of the object

        Returns:
            Tuple ``(original, aug_1, aug_2, target)`` when ``is_target`` is set,
            otherwise ``(original, aug_1, aug_2)``. ``original`` is the sample passed
            through ``transform_original`` when one was given, else the untouched
            sample; ``aug_1`` / ``aug_2`` are the left / right augmentations.
        """
        if self.is_target:
            sample, target = self.dataset[idx]
        else:
            sample = self.dataset[idx]

        # Compare against None (as the constructor does) so that a callable which
        # happens to be falsy, e.g. one defining ``__len__``, is still applied.
        if self.transform_original is not None:
            transformed_sample = self.transform_original(sample)
        else:
            transformed_sample = sample

        # Both views are computed from the raw sample, not from ``transformed_sample``.
        aug_1 = self.transform_left(sample)
        aug_2 = self.transform_right(sample)

        if self.is_target:
            return transformed_sample, aug_1, aug_2, target
        return transformed_sample, aug_1, aug_2
# Public names of this module.
__all__ = [
    "DatasetFromSampler",
    "SelfSupervisedDatasetWrapper",
]