Shortcuts

Source code for catalyst.contrib.datasets.mnist

from typing import Any, Callable, Dict, List, Optional
import os

import torch
from torch.utils.data import Dataset

from catalyst.contrib.datasets.functional import (
    download_and_extract_archive,
    read_sn3_pascalvincent_tensor,
)
from catalyst.data.dataset.metric_learning import MetricLearningTrainDataset, QueryGalleryDataset


def _read_label_file(path):
    with open(path, "rb") as f:
        x = read_sn3_pascalvincent_tensor(f, strict=False)
    assert x.dtype == torch.uint8
    assert x.ndimension() == 1
    return x.long()


def _read_image_file(path):
    with open(path, "rb") as f:
        x = read_sn3_pascalvincent_tensor(f, strict=False)
    assert x.dtype == torch.uint8
    assert x.ndimension() == 3
    return x


[docs]class MNIST(Dataset): """`MNIST <http://yann.lecun.com/exdb/mnist/>`_ Dataset.""" _repr_indent = 4 # CVDF mirror of http://yann.lecun.com/exdb/mnist/ resources = [ ( "https://storage.googleapis.com/cvdf-datasets/mnist/train-images-idx3-ubyte.gz", "f68b3c2dcbeaaa9fbdd348bbdeb94873", ), ( "https://storage.googleapis.com/cvdf-datasets/mnist/train-labels-idx1-ubyte.gz", "d53e105ee54ea40749a09fcbcd1e9432", ), ( "https://storage.googleapis.com/cvdf-datasets/mnist/t10k-images-idx3-ubyte.gz", "9fb629c4189551a2d022fa330f9573f3", ), ( "https://storage.googleapis.com/cvdf-datasets/mnist/t10k-labels-idx1-ubyte.gz", "ec29112dd5afa0611ce80d1b7f02629c", ), ] training_file = "training.pt" test_file = "test.pt" classes = [ "0 - zero", "1 - one", "2 - two", "3 - three", "4 - four", "5 - five", "6 - six", "7 - seven", "8 - eight", "9 - nine", ]
[docs] def __init__( self, root, train=True, transform=None, target_transform=None, download=False, ): """ Args: root: Root directory of dataset where ``MNIST/processed/training.pt`` and ``MNIST/processed/test.pt`` exist. train (bool, optional): If True, creates dataset from ``training.pt``, otherwise from ``test.pt``. download (bool, optional): If true, downloads the dataset from the internet and puts it in root directory. If dataset is already downloaded, it is not downloaded again. transform (callable, optional): A function/transform that takes in an image and returns a transformed version. target_transform (callable, optional): A function/transform that takes in the target and transforms it. """ if isinstance(root, torch._six.string_classes): # noqa: WPS437 root = os.path.expanduser(root) self.root = root self.train = train # training set or test set self.transform = transform self.target_transform = target_transform if download: self.download() if not self._check_exists(): raise RuntimeError("Dataset not found. You can use download=True to download it") if self.train: data_file = self.training_file else: data_file = self.test_file self.data, self.targets = torch.load(os.path.join(self.processed_folder, data_file))
def __getitem__(self, index): """ Args: index: Index Returns: tuple: (image, target) where target is index of the target class. """ img, target = self.data[index].numpy(), int(self.targets[index]) if self.transform is not None: img = self.transform(img) if self.target_transform is not None: target = self.target_transform(target) return img, target def __len__(self): """@TODO: Docs. Contribution is welcome.""" return len(self.data) def __repr__(self): """@TODO: Docs. Contribution is welcome.""" head = "Dataset " + self.__class__.__name__ body = ["Number of datapoints: {}".format(self.__len__())] if self.root is not None: body.append("Root location: {}".format(self.root)) body += self.extra_repr().splitlines() if hasattr(self, "transforms") and self.transforms is not None: body += [repr(self.transforms)] lines = [head] + [" " * self._repr_indent + line for line in body] return "\n".join(lines) @property def raw_folder(self): """@TODO: Docs. Contribution is welcome.""" return os.path.join(self.root, self.__class__.__name__, "raw") @property def processed_folder(self): """@TODO: Docs. Contribution is welcome.""" return os.path.join(self.root, self.__class__.__name__, "processed") @property def class_to_idx(self): """@TODO: Docs. Contribution is welcome.""" return {_class: i for i, _class in enumerate(self.classes)} def _check_exists(self): return os.path.exists( os.path.join(self.processed_folder, self.training_file) ) and os.path.exists(os.path.join(self.processed_folder, self.test_file)) def download(self): """Download the MNIST data if it doesn't exist in processed_folder.""" if self._check_exists(): return os.makedirs(self.raw_folder, exist_ok=True) os.makedirs(self.processed_folder, exist_ok=True) # download files for url, md5 in self.resources: filename = url.rpartition("/")[2] download_and_extract_archive( url, download_root=self.raw_folder, filename=filename, md5=md5 ) # process and save as torch files print("Processing...") training_set = ( _read_image_file(os.path.join(self.raw_folder, "train-images-idx3-ubyte")), _read_label_file(os.path.join(self.raw_folder, "train-labels-idx1-ubyte")), ) test_set = ( _read_image_file(os.path.join(self.raw_folder, "t10k-images-idx3-ubyte")), _read_label_file(os.path.join(self.raw_folder, "t10k-labels-idx1-ubyte")), ) with open(os.path.join(self.processed_folder, self.training_file), "wb") as f: torch.save(training_set, f) with open(os.path.join(self.processed_folder, self.test_file), "wb") as f: torch.save(test_set, f) print("Done!") def extra_repr(self): """@TODO: Docs. Contribution is welcome.""" return "Split: {}".format("Train" if self.train is True else "Test")
class MnistMLDataset(MetricLearningTrainDataset, MNIST): """ Simple wrapper for MNIST dataset for metric learning train stage. This dataset can be used only for training. For test stage use MnistQGDataset. For this dataset we use only training part of the MNIST and only those images that are labeled as 0, 1, 2, 3, 4. """ _split = 5 classes = [ "0 - zero", "1 - one", "2 - two", "3 - three", "4 - four", ] def __init__(self, **kwargs): """ Raises: ValueError: if train argument is False (MnistMLDataset should be used only for training) """ if "train" in kwargs: if kwargs["train"] is False: raise ValueError("MnistMLDataset can be used only for training stage.") else: kwargs["train"] = True super(MnistMLDataset, self).__init__(**kwargs) self._filter() def get_labels(self) -> List[int]: """ Returns: labels of digits """ return self.targets.tolist() def _filter(self) -> None: """Filter MNIST dataset: select images of 0, 1, 2, 3, 4 classes.""" mask = self.targets < self._split self.data = self.data[mask] self.targets = self.targets[mask] class MnistQGDataset(QueryGalleryDataset): """ MNIST for metric learning with query and gallery split. MnistQGDataset should be used for test stage. For this dataset we used only test part of the MNIST and only those images that are labeled as 5, 6, 7, 8, 9. """ _split = 5 classes = [ "5 - five", "6 - six", "7 - seven", "8 - eight", "9 - nine", ] def __init__( self, root: str, transform: Optional[Callable] = None, gallery_fraq: Optional[float] = 0.2, ) -> None: """ Args: root: root directory for storing dataset transform: transform gallery_fraq: gallery size """ self._mnist = MNIST(root, train=False, download=True, transform=transform) self._filter() self._gallery_size = int(gallery_fraq * len(self._mnist)) self._query_size = len(self._mnist) - self._gallery_size self._is_query = torch.zeros(len(self._mnist)).type(torch.bool) self._is_query[: self._query_size] = True def _filter(self) -> None: """Filter MNIST dataset: select images of 5, 6, 7, 8, 9 classes.""" mask = self._mnist.targets >= self._split self._mnist.data = self._mnist.data[mask] self._mnist.targets = self._mnist.targets[mask] def __getitem__(self, idx: int) -> Dict[str, Any]: """ Get item method for dataset Args: idx: index of the object Returns: Dict with features, targets and is_query flag """ image, label = self._mnist[idx] return { "features": image, "targets": label, "is_query": self._is_query[idx], } def __len__(self) -> int: """Length""" return len(self._mnist) def __repr__(self) -> None: """Print info about the dataset""" return self._mnist.__repr__() @property def gallery_size(self) -> int: """Query Gallery dataset should have gallery_size property""" return self._gallery_size @property def query_size(self) -> int: """Query Gallery dataset should have query_size property""" return self._query_size @property def data(self) -> torch.Tensor: """Images from MNIST""" return self._mnist.data @property def targets(self) -> torch.Tensor: """Labels of digits""" return self._mnist.targets __all__ = ["MNIST", "MnistMLDataset", "MnistQGDataset"]