from typing import Any, Callable, Dict, List, Union # isort:skip
from pathlib import Path
import numpy as np
from torch.utils.data import Dataset, Sampler
from catalyst.utils import merge_dicts
# Type alias for filesystem paths accepted in this module (used by
# ``PathsDataset``): either a plain string or a ``pathlib.Path``.
_Path = Union[str, Path]
class ListDataset(Dataset):
    """
    General purpose dataset class with several data sources `list_data`.

    Item ``index`` is ``dict_transform(open_fn(list_data[index]))``.
    """

    def __init__(
        self,
        list_data: List[Dict],
        open_fn: Callable,
        dict_transform: Callable = None,
    ):
        """
        Args:
            list_data (List[Dict]): list of dicts, that stores
                your data annotations,
                (for example path to images, labels, bboxes, etc.)
            open_fn (callable): function, that can open your
                annotations dict and
                transfer it to data, needed by your network
                (for example open image by path, or tokenize read string.)
            dict_transform (callable): transforms to use on dict.
                (for example normalize image, add blur, crop/resize/etc);
                when ``None`` an identity function is used
        """
        # Initialise the ``Dataset`` base so cooperative multiple
        # inheritance (mixins with their own ``__init__``) keeps working.
        super().__init__()
        self.data = list_data
        self.open_fn = open_fn
        self.dict_transform = (
            dict_transform if dict_transform is not None else lambda x: x
        )

    def __getitem__(self, index: int) -> Any:
        """
        Gets element of the dataset

        Args:
            index (int): index of the element in the dataset

        Returns:
            Single element by index: the annotation at ``index`` opened
            with ``open_fn`` and then passed through ``dict_transform``
        """
        item = self.data[index]
        dict_ = self.open_fn(item)
        dict_ = self.dict_transform(dict_)
        return dict_

    def __len__(self) -> int:
        """
        Returns:
            int: length of the dataset (number of annotations)
        """
        return len(self.data)
class MergeDataset(Dataset):
    """
    Abstraction to merge several datasets into one dataset.

    All wrapped datasets must have the same length. Item ``index`` is the
    ``merge_dicts`` combination of item ``index`` of every wrapped dataset,
    optionally passed through ``dict_transform``.
    """

    def __init__(self, *datasets: Dataset, dict_transform: Callable = None):
        """
        Args:
            datasets (List[Dataset]): params count of datasets to merge;
                all of them must have the same length
            dict_transform (callable): transforms common for all datasets.
                (for example normalize image, add blur, crop/resize/etc)

        Raises:
            ValueError: if no datasets are given or their lengths differ
        """
        # Explicit checks instead of ``assert``: asserts are stripped under
        # ``python -O``, which would silently accept mismatched datasets.
        if not datasets:
            raise ValueError("MergeDataset requires at least one dataset")
        lengths = [len(x) for x in datasets]
        self.len = lengths[0]
        if any(length != self.len for length in lengths):
            raise ValueError(
                "All datasets must have the same length, got {}".format(
                    lengths
                )
            )
        self.datasets = datasets
        self.dict_transform = dict_transform

    def __getitem__(self, index: int) -> Any:
        """Get item from all datasets

        Args:
            index (int): index to value from all datasets

        Returns:
            the items at ``index`` of every wrapped dataset combined with
            ``merge_dicts`` (in dataset order), then passed through
            ``dict_transform`` when one was given
        """
        dcts = [x[index] for x in self.datasets]
        dct = merge_dicts(*dcts)
        if self.dict_transform is not None:
            dct = self.dict_transform(dct)
        return dct

    def __len__(self) -> int:
        """
        Returns:
            int: length of the dataset (shared by all wrapped datasets)
        """
        return self.len
class NumpyDataset(Dataset):
    """
    Dataset that serves the items of a numpy array as single-key dicts.

    Item ``index`` is ``{numpy_key: <copy of numpy_data[index]>}``, passed
    through ``dict_transform``.
    """

    def __init__(
        self,
        numpy_data: np.ndarray,
        numpy_key: str = "features",
        dict_transform: Callable = None,
    ):
        """
        Args:
            numpy_data (np.ndarray): array whose items, obtained by indexing
                with an integer, form the dataset
                (for example precomputed embeddings, features, etc.)
            numpy_key (str): key under which each item is stored in the
                output dictionary
            dict_transform (callable): transform applied to every output
                dict (for example normalize vector, etc); identity when
                ``None``
        """
        super().__init__()
        if dict_transform is None:
            dict_transform = lambda sample: sample  # noqa: E731
        self.data = numpy_data
        self.key = numpy_key
        self.dict_transform = dict_transform

    def __getitem__(self, index: int) -> Any:
        """
        Fetches one sample.

        Args:
            index (int): position of the sample in ``numpy_data``

        Returns:
            the transformed dict holding a copy of the selected item, so
            callers may mutate it without touching the backing array
        """
        row = np.copy(self.data[index])
        return self.dict_transform({self.key: row})

    def __len__(self) -> int:
        """
        Returns:
            int: number of items in ``numpy_data``
        """
        n_items = len(self.data)
        return n_items
class PathsDataset(ListDataset):
    """
    Dataset that derives features and targets from samples filesystem paths.

    Each path is turned into ``{"features": path, "targets": label_fn(path)}``
    and the resulting list is handed to ``ListDataset``.
    """

    def __init__(
        self,
        filenames: List[_Path],
        open_fn: Callable[[dict], dict],
        label_fn: Callable[[_Path], Any],
        **list_dataset_params
    ):
        """
        Args:
            filenames (List[str]): list of file paths that store information
                about your dataset samples; it could be images, texts or
                any other files in general. It is iterated exactly once, so
                a generator (e.g. the result of ``Path.glob``) also works.
            open_fn (callable): function, that can open your
                annotations dict and
                transfer it to data, needed by your network
                (for example open image by path, or tokenize read string)
            label_fn (callable): function, that can extract target
                value from sample path
                (for example, your sample could be an image file like
                ``/path/to/your/image_1.png`` where the target is encoded as
                a part of file path)
            list_dataset_params (dict): base class initialization
                parameters.

        Examples:
            >>> # ``Path.glob`` yields ``pathlib.Path`` objects, so the label
            >>> # is taken from the file stem rather than via ``str.split``
            >>> label_fn = lambda x: x.stem.split("_")[0]
            >>> dataset = PathsDataset(
            ...     filenames=Path("/path/to/images/").glob("*.jpg"),
            ...     label_fn=label_fn,
            ...     open_fn=open_fn,
            ... )
        """
        list_data = [
            dict(features=filename, targets=label_fn(filename))
            for filename in filenames
        ]
        super().__init__(
            list_data=list_data, open_fn=open_fn, **list_dataset_params
        )
class DatasetFromSampler(Dataset):
    """
    Dataset of indexes from `Sampler`.

    Item ``index`` is the ``index``-th value produced by iterating the
    wrapped sampler. The sampler is iterated only once, on the first item
    access, and the resulting sequence is cached for later lookups.
    """

    def __init__(self, sampler: Sampler):
        """
        Args:
            sampler (Sampler): sampler whose produced indexes are exposed
        """
        self.sampler = sampler
        # Filled lazily by ``__getitem__`` with ``list(self.sampler)``.
        self.sampler_list = None

    def __getitem__(self, index: int):
        """
        Args:
            index (int): position in the sampler's (cached) output

        Returns:
            the sampler value at that position
        """
        cached = self.sampler_list
        if cached is None:
            cached = list(self.sampler)
            self.sampler_list = cached
        return cached[index]

    def __len__(self) -> int:
        """
        Returns:
            int: ``len`` of the wrapped sampler
        """
        return len(self.sampler)
# Public API of this module (names exported by ``from ... import *``).
__all__ = [
    "ListDataset",
    "MergeDataset",
    "NumpyDataset",
    "PathsDataset",
    "DatasetFromSampler",
]