Source code for catalyst.contrib.datasets.cv.misc
from typing import Iterable, Tuple
import os
from catalyst.contrib.data.cv.dataset import ImageFolderDataset
from catalyst.contrib.datasets.functional import download_and_extract_archive
class ImageClassificationDataset(ImageFolderDataset):
    """
    Base class for datasets with the following structure:

    .. code-block:: bash

        path/to/dataset/
        |-- train/
        |   |-- class1/  # folder of N images
        |   |   |-- train_image11
        |   |   |-- train_image12
        |   |   ...
        |   |   `-- train_image1N
        |   ...
        |   `-- classM/  # folder of K images
        |       |-- train_imageM1
        |       |-- train_imageM2
        |       ...
        |       `-- train_imageMK
        `-- val/
            |-- class1/  # folder of P images
            |   |-- val_image11
            |   |-- val_image12
            |   ...
            |   `-- val_image1P
            ...
            `-- classM/  # folder of T images
                |-- val_imageM1
                |-- val_imageM2
                ...
                `-- val_imageMT

    Subclasses are expected to define ``name`` and, if they support
    ``download=True``, ``resources``.
    """

    # name of dataset folder (joined onto ``root`` to locate the data)
    name: str

    # list of (url, md5 hash) tuples representing files to download;
    # ``None`` means the subclass declares nothing to download
    resources: Iterable[Tuple[str, str]] = None

    def __init__(self, root: str, train: bool = True, download: bool = False, **kwargs):
        """Constructor method for the ``ImageClassificationDataset`` class.

        Args:
            root: root directory of dataset
            train: if ``True``, creates dataset from ``train/``
                subfolder, otherwise from ``val/``
            download: if ``True``, downloads the dataset from
                the internet and puts it in root directory. If
                ``<root>/<name>`` already exists, nothing is downloaded
            **kwargs: extra keyword arguments forwarded unchanged to
                ``ImageFolderDataset.__init__``

        Raises:
            TypeError: if a download is required but the class does not
                define ``resources``
        """
        # download dataset if needed
        if download and not os.path.exists(os.path.join(root, self.name)):
            # Fail fast, before touching the filesystem, with an actionable
            # message instead of "'NoneType' object is not iterable".
            if self.resources is None:
                raise TypeError(
                    f"{type(self).__name__}.resources must be an iterable of "
                    "(url, md5) tuples to use download=True, got None"
                )

            os.makedirs(root, exist_ok=True)

            # download files
            for url, md5 in self.resources:
                # archive is stored under the last path component of its URL
                filename = url.rpartition("/")[2]
                download_and_extract_archive(url, download_root=root, filename=filename, md5=md5)

        rootpath = os.path.join(root, self.name, "train" if train else "val")
        super().__init__(rootpath=rootpath, **kwargs)
# Names exported by ``from <this module> import *``.
__all__ = ["ImageClassificationDataset"]