
Source code for catalyst.data.dataset.torch

from typing import Any, Callable, Dict, List, Optional, Union
from pathlib import Path

import numpy as np
from torch.utils.data import Dataset, Sampler

from catalyst.utils.misc import merge_dicts

_Path = Union[str, Path]


class DatasetFromSampler(Dataset):
    """Dataset to create indexes from `Sampler`.

    Args:
        sampler: PyTorch sampler
    """

    def __init__(self, sampler: Sampler):
        """Initialisation for DatasetFromSampler."""
        self.sampler = sampler
        self.sampler_list = None

    def __getitem__(self, index: int):
        """Gets element of the dataset.

        Args:
            index: index of the element in the dataset

        Returns:
            Single element by index
        """
        if self.sampler_list is None:
            self.sampler_list = list(self.sampler)
        return self.sampler_list[index]

    def __len__(self) -> int:
        """
        Returns:
            int: length of the dataset
        """
        return len(self.sampler)
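
# Usage sketch (illustrative, not part of the library source): turning the
# index stream of any Sampler into an indexable Dataset. The sampler is
# materialized lazily, on first access; the ten-element range below is an
# arbitrary stand-in for real data.
from torch.utils.data import RandomSampler

sampler = RandomSampler(range(10))
indexes = DatasetFromSampler(sampler)
assert len(indexes) == 10
first = indexes[0]  # triggers list(self.sampler) exactly once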
class ListDataset(Dataset):
    """General purpose dataset class with several data sources `list_data`."""

    def __init__(
        self,
        list_data: List[Dict],
        open_fn: Callable,
        dict_transform: Optional[Callable] = None,
    ):
        """
        Args:
            list_data: list of dicts that store your data annotations
                (for example paths to images, labels, bboxes, etc.)
            open_fn: function that can open your annotations dict and
                transform it to the data needed by your network
                (for example, open an image by path or tokenize a read string)
            dict_transform: transforms to use on dict
                (for example normalize image, add blur, crop/resize/etc.)
        """
        self.data = list_data
        self.open_fn = open_fn
        self.dict_transform = dict_transform if dict_transform is not None else lambda x: x

    def __getitem__(self, index: int) -> Any:
        """Gets element of the dataset.

        Args:
            index: index of the element in the dataset

        Returns:
            Single element by index
        """
        item = self.data[index]
        dict_ = self.open_fn(item)
        dict_ = self.dict_transform(dict_)
        return dict_

    def __len__(self) -> int:
        """
        Returns:
            int: length of the dataset
        """
        return len(self.data)
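
# Usage sketch (illustrative): a minimal ListDataset over in-memory
# annotations. The echo-style `open_fn` below is an assumption; in practice
# it would load an image or tokenize a text from the stored path.
annotations = [
    {"image": "img_0.jpg", "label": 0},
    {"image": "img_1.jpg", "label": 1},
]
echo_dataset = ListDataset(list_data=annotations, open_fn=lambda row: dict(row))
assert echo_dataset[1]["label"] == 1
assert len(echo_dataset) == 2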
class MergeDataset(Dataset):
    """Abstraction to merge several datasets into one dataset."""

    def __init__(self, *datasets: Dataset, dict_transform: Optional[Callable] = None):
        """
        Args:
            datasets: variable number of datasets to merge
            dict_transform: transforms common for all datasets
                (for example normalize image, add blur, crop/resize/etc.)
        """
        self.length = len(datasets[0])
        assert all(len(x) == self.length for x in datasets), "all datasets must have the same length"
        self.datasets = datasets
        self.dict_transform = dict_transform

    def __getitem__(self, index: int) -> Any:
        """Gets items from all datasets.

        Args:
            index: common index to a value in every dataset

        Returns:
            dict: merged dict of values from every dataset
        """
        dcts = [x[index] for x in self.datasets]
        dct = merge_dicts(*dcts)
        if self.dict_transform is not None:
            dct = self.dict_transform(dct)
        return dct

    def __len__(self) -> int:
        """
        Returns:
            int: length of the dataset
        """
        return self.length
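
# Usage sketch (illustrative): zipping two equally sized datasets so each
# item is the union of their per-index dicts, combined via merge_dicts.
# The feature/target split below is an assumption for demonstration.
features = ListDataset([{"features": i} for i in range(4)], open_fn=dict)
targets = ListDataset([{"targets": i % 2} for i in range(4)], open_fn=dict)
merged = MergeDataset(features, targets)
assert set(merged[0]) == {"features", "targets"}
assert len(merged) == 4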
class NumpyDataset(Dataset):
    """General purpose dataset class to use with `numpy_data`."""

    def __init__(
        self,
        numpy_data: np.ndarray,
        numpy_key: str = "features",
        dict_transform: Optional[Callable] = None,
    ):
        """
        Args:
            numpy_data: numpy data (for example embeddings, features, etc.)
            numpy_key: key to use for the output dictionary
            dict_transform: transforms to use on dict
                (for example normalize vector, etc.)
        """
        super().__init__()
        self.data = numpy_data
        self.key = numpy_key
        self.dict_transform = dict_transform if dict_transform is not None else lambda x: x

    def __getitem__(self, index: int) -> Any:
        """Gets element of the dataset.

        Args:
            index: index of the element in the dataset

        Returns:
            Single element by index
        """
        dict_ = {self.key: np.copy(self.data[index])}
        dict_ = self.dict_transform(dict_)
        return dict_

    def __len__(self) -> int:
        """
        Returns:
            int: length of the dataset
        """
        return len(self.data)
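
# Usage sketch (illustrative): serving rows of a numpy array as dicts keyed
# by `numpy_key`. Each access returns a defensive copy of the row, so
# downstream transforms cannot mutate the underlying array.
embeddings = np.zeros((8, 16), dtype=np.float32)
numpy_dataset = NumpyDataset(embeddings, numpy_key="embeddings")
assert numpy_dataset[0]["embeddings"].shape == (16,)
assert len(numpy_dataset) == 8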
class PathsDataset(ListDataset):
    """Dataset that derives features and targets from samples' filesystem paths.

    Examples:
        >>> label_fn = lambda x: x.stem.split("_")[0]
        >>> dataset = PathsDataset(
        ...     filenames=Path("/path/to/images/").glob("*.jpg"),
        ...     label_fn=label_fn,
        ...     open_fn=open_fn,
        ... )
    """

    def __init__(
        self,
        filenames: List[_Path],
        open_fn: Callable[[dict], dict],
        label_fn: Callable[[_Path], Any],
        features_key: str = "features",
        target_key: str = "targets",
        **list_dataset_params,
    ):
        """
        Args:
            filenames: list of file paths that store information about your
                dataset samples; they could be images, texts or any other
                files in general
            open_fn: function that can open your annotations dict and
                transform it to the data needed by your network
                (for example, open an image by path or tokenize a read string)
            label_fn: function that can extract the target value from a
                sample path (for example, your sample could be an image file
                like ``/path/to/your/image_1.png`` where the target is
                encoded as part of the file path)
            features_key: key to use to store sample features
            target_key: key to use to store the target label
            list_dataset_params: base class initialization parameters
        """
        list_data = [
            {features_key: filename, target_key: label_fn(filename)}
            for filename in filenames
        ]
        super().__init__(list_data=list_data, open_fn=open_fn, **list_dataset_params)
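
# Usage sketch (illustrative): deriving targets from file names. The image
# directory, glob pattern, and `open_image` loader are assumptions; any
# callable mapping {"features": path, "targets": label} to network-ready
# data would work as `open_fn`.
def open_image(row: dict) -> dict:
    # a real loader would decode the image here, e.g. with PIL or cv2
    return {"features": str(row["features"]), "targets": row["targets"]}

paths_dataset = PathsDataset(
    filenames=sorted(Path("/path/to/images/").glob("*.jpg")),
    open_fn=open_image,
    label_fn=lambda path: path.stem.split("_")[0],  # "image_1.png" -> "image"
)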
__all__ = [
    "DatasetFromSampler",
    "ListDataset",
    "MergeDataset",
    "NumpyDataset",
    "PathsDataset",
]