Source code for catalyst.data.dataset

from typing import Any, Callable, Dict, List, Union  # isort:skip
from pathlib import Path

import numpy as np

from torch.utils.data import Dataset

from catalyst.utils.misc import merge_dicts

# Type alias for anything this module accepts as a filesystem path:
# either a plain string or a ``pathlib.Path``.
_Path = Union[
    str,
    Path,
]


class ListDataset(Dataset):
    """General purpose dataset class with several data sources `list_data`."""

    def __init__(
        self,
        list_data: List[Dict],
        open_fn: Callable,
        dict_transform: Callable = None,
    ):
        """
        Args:
            list_data (List[Dict]): list of dicts, that stores your data
                annotations (for example path to images, labels,
                bboxes, etc.)
            open_fn (callable): function, that can open your
                annotations dict and transfer it to data, needed by
                your network (for example open image by path,
                or tokenize read string)
            dict_transform (callable): transforms to use on dict
                (for example normalize image, add blur, crop/resize/etc).
                If ``None``, the dict produced by ``open_fn`` is returned
                unchanged.
        """
        super().__init__()
        self.data = list_data
        self.open_fn = open_fn
        # A named staticmethod is used instead of ``lambda x: x`` because
        # lambdas cannot be pickled, and the dataset has to be picklable
        # to be sent to worker processes (e.g. a DataLoader whose workers
        # use the "spawn" start method).
        self.dict_transform = (
            dict_transform if dict_transform is not None else self._identity
        )

    @staticmethod
    def _identity(dict_: Any) -> Any:
        """Return ``dict_`` unchanged; the default ``dict_transform``."""
        return dict_

    def __getitem__(self, index: int) -> Any:
        """
        Gets element of the dataset.

        Args:
            index (int): index of the element in the dataset

        Returns:
            Single element by index, i.e.
            ``dict_transform(open_fn(list_data[index]))``
        """
        item = self.data[index]
        dict_ = self.open_fn(item)
        dict_ = self.dict_transform(dict_)
        return dict_

    def __len__(self) -> int:
        """
        Returns:
            int: length of the dataset
        """
        return len(self.data)
class MergeDataset(Dataset):
    """Abstraction to merge several datasets into one dataset."""

    def __init__(self, *datasets: Dataset, dict_transform: Callable = None):
        """
        Args:
            datasets (List[Dataset]): params count of datasets to merge;
                all of them must have the same length
            dict_transform (callable): transforms common for all datasets
                (for example normalize image, add blur, crop/resize/etc).
                Applied to the merged dict; skipped when ``None``.

        Raises:
            ValueError: if no datasets are given or their lengths differ
        """
        super().__init__()
        # Explicit checks instead of ``assert`` so validation still runs
        # under ``python -O``.
        if not datasets:
            raise ValueError("MergeDataset requires at least one dataset")
        self.len = len(datasets[0])
        if any(len(x) != self.len for x in datasets):
            lengths = [len(x) for x in datasets]
            raise ValueError(
                "All datasets must have the same length, got {}".format(
                    lengths
                )
            )
        self.datasets = datasets
        self.dict_transform = dict_transform

    def __getitem__(self, index: int) -> Any:
        """Get item from all datasets.

        Args:
            index (int): index to value from all datasets

        Returns:
            dict: the items at ``index`` from every dataset, combined with
            ``merge_dicts`` (NOTE(review): key-collision precedence is
            defined by ``catalyst.utils.misc.merge_dicts``; presumably later
            datasets win, so confirm), then passed through
            ``dict_transform`` if one was given
        """
        dcts = [x[index] for x in self.datasets]
        dct = merge_dicts(*dcts)
        if self.dict_transform is not None:
            dct = self.dict_transform(dct)
        return dct

    def __len__(self) -> int:
        """
        Returns:
            int: length of the dataset
        """
        return self.len
class NumpyDataset(Dataset):
    """General purpose dataset class to use with `numpy_data`."""

    def __init__(
        self,
        numpy_data: np.ndarray,
        numpy_key: str = "features",
        dict_transform: Callable = None,
    ):
        """
        Args:
            numpy_data (np.ndarray): numpy data (for example embeddings,
                features, etc.); sample ``i`` is ``numpy_data[i]``, so
                the first axis indexes samples
            numpy_key (str): key to use for output dictionary
            dict_transform (callable): transforms to use on dict
                (for example normalize vector, etc). If ``None``, the dict
                is returned unchanged.
        """
        super().__init__()
        self.numpy_data = numpy_data
        self.numpy_key = numpy_key
        # A named staticmethod is used instead of ``lambda x: x`` so the
        # dataset stays picklable (lambdas cannot be pickled).
        self.dict_transform = (
            dict_transform if dict_transform is not None else self._identity
        )

    @staticmethod
    def _identity(dict_: Any) -> Any:
        """Return ``dict_`` unchanged; the default ``dict_transform``."""
        return dict_

    def __getitem__(self, index: int) -> Any:
        """
        Gets element of the dataset.

        Args:
            index (int): index of the element in the dataset

        Returns:
            Single element by index: ``{numpy_key: copy of numpy_data[index]}``
            after ``dict_transform``
        """
        # Copy so transforms that modify the sample in place cannot corrupt
        # the underlying array.
        dict_ = {self.numpy_key: np.copy(self.numpy_data[index])}
        dict_ = self.dict_transform(dict_)
        return dict_

    def __len__(self) -> int:
        """
        Returns:
            int: length of the dataset (size of the first axis of
            ``numpy_data``)
        """
        return len(self.numpy_data)
class PathsDataset(ListDataset):
    """
    Dataset that derives features and targets from samples filesystem paths.
    """

    def __init__(
        self,
        filenames: List[_Path],
        open_fn: Callable[[dict], dict],
        label_fn: Callable[[_Path], Any],
        **list_dataset_params
    ):
        """
        Args:
            filenames (List[str]): list of file paths that store information
                about your dataset samples; it could be images, texts or
                any other files in general. Any iterable of ``str`` or
                ``pathlib.Path`` works (e.g. the generator returned by
                ``Path.glob``); it is consumed exactly once.
            open_fn (callable): function, that can open your
                annotations dict and transfer it to data, needed by
                your network (for example open image by path,
                or tokenize read string)
            label_fn (callable): function, that can extract target
                value from sample path (for example, your sample could be
                an image file like ``/path/to/your/image_1.png`` where the
                target is encoded as a part of file path)
            list_dataset_params (dict): base class initialization
                parameters.

        Each sample is stored as ``{"features": filename,
        "targets": label_fn(filename)}`` and handed to ``open_fn``.

        Examples:
            >>> label_fn = lambda x: Path(x).stem.split("_")[0]
            >>> dataset = PathsDataset(
            ...     filenames=Path("/path/to/images/").glob("*.jpg"),
            ...     label_fn=label_fn,
            ...     open_fn=open_fn,
            ... )
        """
        list_data = [
            dict(features=filename, targets=label_fn(filename))
            for filename in filenames
        ]

        super().__init__(
            list_data=list_data, open_fn=open_fn, **list_dataset_params
        )
# Public API of this module (names exported by ``from ... import *``).
__all__ = [
    "ListDataset",
    "MergeDataset",
    "NumpyDataset",
    "PathsDataset",
]