import codecs
import os
import numpy as np
import torch
from torch.utils.data import Dataset
from catalyst.contrib.datasets.utils import download_and_extract_archive
class MNIST(Dataset):
    """`MNIST <http://yann.lecun.com/exdb/mnist/>`_ Dataset.

    Images are served as raw ``uint8`` arrays (28x28 for the official
    archives); targets are plain ``int`` class indices 0-9.
    """

    # Indent width used when pretty-printing the dataset in __repr__.
    _repr_indent = 4

    # (url, md5 checksum) for each of the four original archive files.
    resources = [
        (
            "http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz",
            "f68b3c2dcbeaaa9fbdd348bbdeb94873",
        ),
        (
            "http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz",
            "d53e105ee54ea40749a09fcbcd1e9432",
        ),
        (
            "http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz",
            "9fb629c4189551a2d022fa330f9573f3",
        ),
        (
            "http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz",
            "ec29112dd5afa0611ce80d1b7f02629c",
        ),
    ]
    training_file = "training.pt"
    test_file = "test.pt"
    classes = [
        "0 - zero",
        "1 - one",
        "2 - two",
        "3 - three",
        "4 - four",
        "5 - five",
        "6 - six",
        "7 - seven",
        "8 - eight",
        "9 - nine",
    ]

    def __init__(
        self,
        root,
        train=True,
        transform=None,
        target_transform=None,
        download=False,
    ):
        """
        Args:
            root (string): Root directory of dataset where
                ``MNIST/processed/training.pt``
                and ``MNIST/processed/test.pt`` exist.
            train (bool, optional): If True, creates dataset from
                ``training.pt``, otherwise from ``test.pt``.
            download (bool, optional): If true, downloads the dataset from
                the internet and puts it in root directory. If dataset
                is already downloaded, it is not downloaded again.
            transform (callable, optional): A function/transform that
                takes in an image and returns a transformed version.
            target_transform (callable, optional): A function/transform
                that takes in the target and transforms it.

        Raises:
            RuntimeError: if the processed files are absent and
                ``download`` is False.
        """
        # ``torch._six.string_classes`` was removed from PyTorch
        # (>=1.13); on Python 3 it was just ``str`` anyway.
        if isinstance(root, str):
            root = os.path.expanduser(root)
        self.root = root
        self.train = train  # training set or test set
        self.transform = transform
        self.target_transform = target_transform

        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError(
                "Dataset not found. You can use download=True to download it"
            )

        if self.train:
            data_file = self.training_file
        else:
            data_file = self.test_file
        self.data, self.targets = torch.load(
            os.path.join(self.processed_folder, data_file)
        )

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        # Image is handed to the transform as a numpy uint8 array,
        # not a PIL image.
        img, target = self.data[index].numpy(), int(self.targets[index])

        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target

    def __len__(self):
        """Number of samples in the selected (train or test) split."""
        return len(self.data)

    def __repr__(self):
        """Multi-line, human-readable summary of the dataset."""
        head = "Dataset " + self.__class__.__name__
        body = ["Number of datapoints: {}".format(self.__len__())]
        if self.root is not None:
            body.append("Root location: {}".format(self.root))
        body += self.extra_repr().splitlines()
        if hasattr(self, "transforms") and self.transforms is not None:
            body += [repr(self.transforms)]
        lines = [head] + [" " * self._repr_indent + line for line in body]
        return "\n".join(lines)

    def extra_repr(self):
        """Extra ``__repr__`` line describing which split this is.

        Fixes an AttributeError: ``__repr__`` calls ``self.extra_repr()``
        but neither this class nor ``torch.utils.data.Dataset`` defined it.
        """
        return "Split: {}".format("Train" if self.train else "Test")

    @property
    def raw_folder(self):
        """Directory holding the downloaded, extracted archive files."""
        return os.path.join(self.root, self.__class__.__name__, "raw")

    @property
    def processed_folder(self):
        """Directory holding the serialized ``.pt`` tensor files."""
        return os.path.join(self.root, self.__class__.__name__, "processed")

    @property
    def class_to_idx(self):
        """Mapping from class-name string to integer label."""
        return {_class: i for i, _class in enumerate(self.classes)}

    def _check_exists(self):
        # Both processed splits must be present for the dataset to load.
        return os.path.exists(
            os.path.join(self.processed_folder, self.training_file)
        ) and os.path.exists(
            os.path.join(self.processed_folder, self.test_file)
        )

    def download(self):
        """Download the MNIST data if it doesn't exist in processed_folder."""
        if self._check_exists():
            return

        os.makedirs(self.raw_folder, exist_ok=True)
        os.makedirs(self.processed_folder, exist_ok=True)

        # download files
        for url, md5 in self.resources:
            filename = url.rpartition("/")[2]
            download_and_extract_archive(
                url, download_root=self.raw_folder, filename=filename, md5=md5
            )

        # process and save as torch files
        print("Processing...")

        training_set = (
            read_image_file(
                os.path.join(self.raw_folder, "train-images-idx3-ubyte")
            ),
            read_label_file(
                os.path.join(self.raw_folder, "train-labels-idx1-ubyte")
            ),
        )
        test_set = (
            read_image_file(
                os.path.join(self.raw_folder, "t10k-images-idx3-ubyte")
            ),
            read_label_file(
                os.path.join(self.raw_folder, "t10k-labels-idx1-ubyte")
            ),
        )
        with open(
            os.path.join(self.processed_folder, self.training_file), "wb"
        ) as f:
            torch.save(training_set, f)
        with open(
            os.path.join(self.processed_folder, self.test_file), "wb"
        ) as f:
            torch.save(test_set, f)

        print("Done!")
def get_int(b):
    """Interpret ``b`` (a ``bytes`` slice) as a big-endian unsigned int.

    Used to decode the magic number and dimension sizes in the idx
    ("SN3 / Pascal Vincent") file headers, which are big-endian int32.
    """
    # int.from_bytes replaces the roundabout hex-encode-then-parse of
    # ``int(codecs.encode(b, "hex"), 16)``.
    return int.from_bytes(b, byteorder="big")
def open_maybe_compressed_file(path):
    """Return a file object that possibly decompresses 'path' on the fly.

    Decompression occurs when argument `path` is a string
    and ends with '.gz' or '.xz'. Anything that is not a string
    (e.g. an already-open file object) is returned unchanged.
    """
    # ``torch._six.string_classes`` was removed from PyTorch (>=1.13);
    # on Python 3 it was just ``str``.
    if not isinstance(path, str):
        return path
    if path.endswith(".gz"):
        import gzip

        return gzip.open(path, "rb")
    if path.endswith(".xz"):
        import lzma

        return lzma.open(path, "rb")
    return open(path, "rb")
def read_sn3_pascalvincent_tensor(path, strict=True):
    """Read a SN3 file in "Pascal Vincent" (idx) format.

    Argument may be a filename, compressed filename, or file object.

    Args:
        path: source to read; strings ending in ``.gz``/``.xz`` are
            decompressed transparently.
        strict (bool): if True, require the element count implied by the
            header dimensions to exactly match the payload size.

    Returns:
        torch.Tensor with the dtype and shape encoded in the header.
    """
    # Lazily build and cache the idx type-code table on the function
    # itself: code -> (torch dtype, big-endian numpy dtype for parsing,
    # native numpy dtype for the returned tensor).
    if not hasattr(read_sn3_pascalvincent_tensor, "typemap"):
        read_sn3_pascalvincent_tensor.typemap = {
            8: (torch.uint8, np.uint8, np.uint8),
            9: (torch.int8, np.int8, np.int8),
            11: (torch.int16, np.dtype(">i2"), "i2"),
            12: (torch.int32, np.dtype(">i4"), "i4"),
            13: (torch.float32, np.dtype(">f4"), "f4"),
            14: (torch.float64, np.dtype(">f8"), "f8"),
        }
    # read the whole (possibly decompressed) file into memory
    with open_maybe_compressed_file(path) as f:
        data = f.read()
    # parse: magic = type-code * 256 + number of dimensions
    magic = get_int(data[0:4])
    nd = magic % 256
    ty = magic // 256
    assert nd >= 1 and nd <= 3
    assert ty >= 8 and ty <= 14
    m = read_sn3_pascalvincent_tensor.typemap[ty]
    # one big-endian int32 size per dimension follows the magic number
    s = [get_int(data[4 * (i + 1) : 4 * (i + 2)]) for i in range(nd)]
    # payload starts right after the header
    parsed = np.frombuffer(data, dtype=m[1], offset=(4 * (nd + 1)))
    assert parsed.shape[0] == np.prod(s) or not strict
    # NOTE(review): ``parsed`` is a read-only view over ``bytes``;
    # newer torch emits a UserWarning when ``from_numpy`` wraps a
    # non-writable array without copying (uint8 case) — confirm whether
    # an explicit copy is wanted here.
    return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)
def read_label_file(path):
    """Load an idx1-ubyte MNIST label file.

    Args:
        path (str): path to the uncompressed label file.

    Returns:
        1-D ``torch.int64`` tensor of class labels.
    """
    with open(path, "rb") as f:
        # strict=False: tolerate trailing bytes beyond the declared count
        x = read_sn3_pascalvincent_tensor(f, strict=False)
    assert x.dtype == torch.uint8
    assert x.ndimension() == 1
    # widen uint8 labels to int64 as expected by loss functions
    return x.long()
def read_image_file(path):
    """Load an idx3-ubyte MNIST image file.

    Args:
        path (str): path to the uncompressed image file.

    Returns:
        3-D ``torch.uint8`` tensor of shape (num_images, rows, cols).
    """
    with open(path, "rb") as f:
        # strict=False: tolerate trailing bytes beyond the declared count
        x = read_sn3_pascalvincent_tensor(f, strict=False)
    assert x.dtype == torch.uint8
    assert x.ndimension() == 3
    return x