from typing import Any, Callable, Dict, List, Union # isort:skip
from pathlib import Path
import random
from torch.utils.data import Dataset
from catalyst.utils.misc import merge_dicts
# Type alias for a filesystem path given either as a plain string or as a
# ``pathlib.Path`` object.
_Path = Union[str, Path]
class ListDataset(Dataset):
    """
    General purpose dataset class with several data sources `list_data`
    """

    def __init__(
        self,
        list_data: List[Dict],
        open_fn: Callable,
        dict_transform: Union[Callable, None] = None,
        cache_prob: float = -1,
        cache_transforms: bool = False
    ):
        """
        Args:
            list_data (List[Dict]): list of dicts that store
                your data annotations
                (for example paths to images, labels, bboxes, etc.)
            open_fn (callable): function that receives one annotation dict
                and turns it into the data needed by your network
                (for example opens an image by path, or tokenizes a string)
            dict_transform (callable): transforms to apply to the opened
                dict (for example normalize image, add blur,
                crop/resize/etc); ``None`` disables transformations
            cache_prob (float): probability, on every access, of trying to
                serve the item from the RAM cache instead of opening it
                again. Any value ``> 0`` also makes every freshly opened
                item get stored in the cache; values ``<= 0``
                (the default) disable caching
            cache_transforms (bool): if ``True``, ``dict_transform`` is
                applied once, right after ``open_fn``, so the cached item
                is the transformed one; if ``False``, ``dict_transform``
                is applied on every access, after the cache lookup
        """
        self.data = list_data
        self.open_fn = open_fn
        self.dict_transform = dict_transform
        self.cache_prob = cache_prob
        self.cache_transforms = cache_transforms
        # index -> opened (and, with ``cache_transforms``, transformed) item
        self.cache = dict()

    def prepare_new_item(self, index: int):
        """Opens the annotation at ``index`` with ``open_fn``.

        Args:
            index (int): index of the element in ``list_data``

        Returns:
            result of ``open_fn``, additionally passed through
            ``dict_transform`` when ``cache_transforms`` is enabled
            and a transform is set
        """
        item = self.open_fn(self.data[index])

        # Transform here only in "cache transformed samples" mode; otherwise
        # ``__getitem__`` applies the transform after the cache lookup.
        if self.cache_transforms and self.dict_transform is not None:
            item = self.dict_transform(item)

        return item

    def prepare_item_from_cache(self, index: int):
        """Looks up a previously stored item.

        Args:
            index (int): index of the element in the dataset

        Returns:
            cached item, or ``None`` if nothing is stored for ``index``
        """
        return self.cache.get(index, None)

    def __getitem__(self, index: int) -> Any:
        """Gets element of the dataset

        Args:
            index (int): index of the element in the dataset

        Returns:
            Single element by index
        """
        item = None

        # With probability ``cache_prob`` try the cache first;
        # a miss leaves ``item`` as ``None``.
        if random.random() < self.cache_prob:
            item = self.prepare_item_from_cache(index)

        if item is None:
            item = self.prepare_new_item(index)
            # Freshly opened items always (re)populate the cache
            # when caching is enabled.
            if self.cache_prob > 0:
                self.cache[index] = item

        # NOTE: in this mode the object stored in the cache is the very same
        # object handed to ``dict_transform``, so a transform that mutates
        # its input in place also alters the cached item.
        if not self.cache_transforms and self.dict_transform is not None:
            item = self.dict_transform(item)

        return item

    def __len__(self) -> int:
        """
        Returns:
            int: length of the dataset
        """
        return len(self.data)
class MergeDataset(Dataset):
    """
    Abstraction to merge several datasets into one dataset.
    """

    def __init__(
        self, *datasets: Dataset, dict_transform: Union[Callable, None] = None
    ):
        """
        Args:
            datasets (List[Dataset]): datasets to merge, passed as
                positional arguments; all of them must have the same length
            dict_transform (callable): transforms common for all datasets.
                (for example normalize image, add blur, crop/resize/etc)

        Raises:
            IndexError: if no dataset is passed
            AssertionError: if the datasets differ in length
        """
        if not datasets:
            # ``datasets[0]`` would raise a bare IndexError here; keep the
            # exception type but say what is actually wrong.
            raise IndexError("MergeDataset requires at least one dataset")

        self.len = len(datasets[0])
        if any(len(x) != self.len for x in datasets):
            # Raised explicitly instead of via ``assert`` so the check is
            # not stripped under ``python -O``; the exception type is
            # unchanged for existing callers.
            raise AssertionError(
                "all datasets passed to MergeDataset must have the same length"
            )

        self.datasets = datasets
        self.dict_transform = dict_transform

    def __getitem__(self, index: int) -> Any:
        """Get item from all datasets

        Args:
            index (int): index to value from all datasets

        Returns:
            the items found at ``index`` in every dataset, combined into
            one object by ``merge_dicts`` and then passed through
            ``dict_transform`` if one is set
        """
        per_dataset_items = [dataset[index] for dataset in self.datasets]
        merged = merge_dicts(*per_dataset_items)

        if self.dict_transform is not None:
            merged = self.dict_transform(merged)

        return merged

    def __len__(self) -> int:
        """
        Returns:
            int: common length of the merged datasets
        """
        return self.len
class PathsDataset(ListDataset):
    """
    Dataset that derives features and targets from samples filesystem paths.
    """

    def __init__(
        self, filenames: List[_Path], open_fn: Callable[[dict], dict],
        label_fn: Callable[[_Path], Any], **list_dataset_params
    ):
        """
        Args:
            filenames (List[str]): list of file paths that store information
                about your dataset samples; it could be images, texts or
                any other files in general.
            open_fn (callable): function that receives one annotation dict
                (with keys ``"features"`` and ``"targets"``) and turns it
                into the data needed by your network
                (for example opens an image by path, or tokenizes a string)
            label_fn (callable): function that extracts the target
                value from a sample path
                (for example, your sample could be an image file like
                ``/path/to/your/image_1.png`` where the target is encoded as
                a part of file path)
            list_dataset_params (dict): base class initialization
                parameters.

        Examples:
            >>> # works for both ``str`` and ``Path`` file names
            >>> label_fn = lambda x: Path(x).name.split("_")[0]
            >>> dataset = PathsDataset(
            >>>     filenames=Path("/path/to/images/").glob("*.jpg"),
            >>>     label_fn=label_fn,
            >>>     open_fn=open_fn,
            >>> )
        """
        # One annotation dict per file: the path itself is the feature,
        # the target is whatever ``label_fn`` derives from that path.
        list_data = [
            {"features": filename, "targets": label_fn(filename)}
            for filename in filenames
        ]

        super().__init__(
            list_data=list_data, open_fn=open_fn, **list_dataset_params
        )
# Names exported by ``from <this module> import *``; note that the
# underscore-prefixed ``_Path`` alias is exported explicitly as well.
__all__ = ["_Path", "ListDataset", "MergeDataset", "PathsDataset"]