from typing import Any, Callable, Iterable, Optional, Union
from itertools import tee
import queue
import sys
import threading
import numpy as np
import torch
from torch.utils.data import DataLoader
class ILoaderWrapper:
"""Loader wrapper interface.
Args:
loader: torch dataloader.
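
    A minimal sketch of the attribute proxying (a toy ``TensorDataset`` loader
    and direct instantiation of the interface, just for illustration):

    .. code-block:: python

        import torch
        from torch.utils.data import DataLoader, TensorDataset
        from catalyst.data.loader import ILoaderWrapper

        loader = DataLoader(TensorDataset(torch.rand(8, 2)), batch_size=4)
        wrapped = ILoaderWrapper(loader)
        wrapped.batch_size  # 4, resolved from the wrapped ``origin`` loader
        len(wrapped)        # 2, delegated to ``len(origin)``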
"""
def __init__(self, loader: DataLoader):
"""Init"""
self.origin = loader
def __getattr__(self, key):
"""
Gets attribute by ``key``.
Firstly, looks at the ``origin`` for the appropriate ``key``.
If none founds - looks at the wrappers attributes.
If could not found anything - raises ``NotImplementedError``.
Args:
key: attribute key
Returns:
attribute value
Raises:
NotImplementedError: if could not find attribute in ``origin`` or ``wrapper``
"""
some_default_value = "_no_attr_found_"
value = self.origin.__dict__.get(key, some_default_value)
# value = getattr(self.origin, key, None)
if value != some_default_value:
return value
value = self.__dict__.get(key, some_default_value)
# value = getattr(self, key, None)
if value != some_default_value:
return value
raise NotImplementedError()
def __len__(self) -> int:
"""Returns length of the wrapper loader.
Returns:
int: length of the wrapper loader
"""
return len(self.origin)
class BatchLimitLoaderWrapper(ILoaderWrapper):
    """Loader wrapper. Limits the number of batches used per iteration.

    For example, if you have some loader and want to use only its first 5 batches:
.. code-block:: python
import torch
from torch.utils.data import DataLoader, TensorDataset
from catalyst.data.loader import BatchLimitLoaderWrapper
num_samples, num_features = int(1e4), int(1e1)
X, y = torch.rand(num_samples, num_features), torch.rand(num_samples)
dataset = TensorDataset(X, y)
loader = DataLoader(dataset, batch_size=32, num_workers=1)
loader = BatchLimitLoaderWrapper(loader, num_batches=5)
    or, if you would like to use only a portion of the ``DataLoader``
    (30% in the example below):
.. code-block:: python
import torch
from torch.utils.data import DataLoader, TensorDataset
from catalyst.data.loader import BatchLimitLoaderWrapper
num_samples, num_features = int(1e4), int(1e1)
X, y = torch.rand(num_samples, num_features), torch.rand(num_samples)
dataset = TensorDataset(X, y)
loader = DataLoader(dataset, batch_size=32, num_workers=1)
loader = BatchLimitLoaderWrapper(loader, num_batches=0.3)
.. note::
Generally speaking, this wrapper could be used with any iterator-like
object. No ``DataLoader``-specific code used.
"""
    def __init__(self, loader: DataLoader, num_batches: Union[int, float]):
        """Loader wrapper. Limits the number of batches used per iteration.

        Args:
            loader: torch dataloader.
            num_batches (Union[int, float]): number of batches to use (int),
                or fraction of the loader to use (float, must be in the [0; 1] range)
        """
super().__init__(loader)
        assert isinstance(num_batches, (int, float)), (
            "Expected ``num_batches`` to be int or float, "
            f"but got {type(num_batches)}"
        )
        if isinstance(num_batches, float):
            assert 0.0 <= num_batches <= 1, (
                "Expected ``num_batches`` to be in the [0; 1] range, "
                f"but got {num_batches}"
            )
            num_batches = int(len(loader) * num_batches)
self._iterator = iter(self.origin)
self.iteration_index = 0
self.num_batches = num_batches
def __iter__(self):
"""Iterator.
Returns:
iterator object
"""
self.iteration_index = 0
self._iterator, self.iterator = tee(self._iterator)
return self
def __next__(self):
"""Next batch.
Returns:
next batch
"""
if self.iteration_index >= len(self.origin):
raise StopIteration()
self.iteration_index += 1
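        # every ``num_batches`` steps, rewind the consuming iterator back to
        # the last ``tee`` checkpoint so the same first batches repeat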
if self.iteration_index % self.num_batches == 0:
self._iterator, self.iterator = tee(self._iterator)
batch = next(self.iterator)
return batch
def _any2cuda_non_blocking(value: Any):
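    """Recursively moves tensors (and numpy arrays) in ``value`` to CUDA
    with ``non_blocking=True``, preserving the dict/list structure.

    A minimal sketch of the behavior (requires a CUDA device; the ``batch``
    dict below is hypothetical):

    .. code-block:: python

        batch = {"features": torch.rand(4, 3), "targets": [np.zeros(4)]}
        cuda_batch = _any2cuda_non_blocking(batch)
        # every leaf is now a CUDA tensor; unsupported types map to ``None``

    """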
# based on catalyst.utils.torch.any2device
# but with cuda non_blocking trick
if isinstance(value, dict):
return {k: _any2cuda_non_blocking(v) for k, v in value.items()}
elif isinstance(value, (tuple, list)):
return [_any2cuda_non_blocking(v) for v in value]
elif torch.is_tensor(value):
return value.cuda(non_blocking=True)
elif isinstance(value, (np.ndarray, np.void)) and value.dtype.fields is not None:
return {k: _any2cuda_non_blocking(value[k]) for k in value.dtype.fields.keys()}
elif isinstance(value, np.ndarray):
return torch.tensor(value).cuda(non_blocking=True)
def _map_loop(
func: Callable,
iterable: Iterable,
result_queue: queue.Queue,
error_queue: queue.Queue,
done_event: threading.Event,
):
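    """Producer loop executed in a background thread: applies ``func`` to each
    item of ``iterable`` and pushes results into ``result_queue``; any exception
    is forwarded through ``error_queue``, and ``done_event`` is always set at the end."""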
try:
for x in iterable:
result = func(x)
result_queue.put(result)
except BaseException:
error_queue.put(sys.exc_info())
finally:
done_event.set()
def _prefetch_map(
func: Callable, iterable: Iterable, num_prefetches: int = 1, timeout: int = 2
) -> Iterable:
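    """Applies ``func`` to ``iterable`` in a background thread, buffering up to
    ``num_prefetches`` results ahead of the consumer.

    A standalone usage sketch (``slow_square`` is a hypothetical helper):

    .. code-block:: python

        import time

        def slow_square(x):
            time.sleep(0.1)
            return x * x

        for value in _prefetch_map(slow_square, range(10), num_prefetches=2):
            print(value)  # up to 2 results are precomputed in the background

    """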
result_queue = queue.Queue(num_prefetches)
error_queue = queue.Queue(1)
done_event = threading.Event()
map_thread = threading.Thread(
target=_map_loop, args=(func, iterable, result_queue, error_queue, done_event)
)
map_thread.daemon = True
map_thread.start()
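    # drain results until the producer thread has finished and its queue is empty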
while not (done_event.is_set() and result_queue.empty()):
try:
result = result_queue.get(timeout=timeout)
except queue.Empty:
continue
yield result
if error_queue.full():
raise error_queue.get()[1]
def _prefetch_loader(loader: DataLoader, num_prefetches: int) -> Iterable:
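    """Returns a GPU-prefetching iterator over ``loader`` when CUDA is available,
    otherwise falls back to a plain ``iter(loader)``."""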
if torch.cuda.is_available():
return _prefetch_map(_any2cuda_non_blocking, loader, num_prefetches=num_prefetches)
else:
return iter(loader)
class BatchPrefetchLoaderWrapper(ILoaderWrapper):
    """Loader wrapper. Prefetches the specified number of batches to the GPU.

    Basic usage:
.. code-block:: python
import torch
from torch.utils.data import DataLoader, TensorDataset
from catalyst.data import BatchPrefetchLoaderWrapper
num_samples, num_features = int(1e4), int(1e1)
X, y = torch.rand(num_samples, num_features), torch.rand(num_samples)
dataset = TensorDataset(X, y)
loader = DataLoader(dataset, batch_size=32, num_workers=1)
loader = BatchPrefetchLoaderWrapper(loader)
Minimal working example:
.. code-block:: python
import os
import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader
from catalyst import dl, metrics
from catalyst.data.cv import ToTensor
from catalyst.contrib.datasets import MNIST
from catalyst.data import BatchPrefetchLoaderWrapper
class CustomRunner(dl.Runner):
def handle_batch(self, batch):
# model train/valid step
x, y = batch
y_hat = self.model(x.view(x.size(0), -1))
loss = F.cross_entropy(y_hat, y)
accuracy01 = metrics.accuracy(y_hat, y, topk=(1, ))
self.batch_metrics.update(
{"loss": loss, "accuracy01": accuracy01}
)
if self.is_train_loader:
loss.backward()
self.optimizer.step()
self.optimizer.zero_grad()
model = torch.nn.Linear(28 * 28, 10)
optimizer = torch.optim.Adam(model.parameters(), lr=0.02)
batch_size=32
loaders = {
"train": DataLoader(
MNIST(
os.getcwd(),
train=True,
download=True,
transform=ToTensor()
),
batch_size=batch_size),
"valid": DataLoader(
MNIST(
os.getcwd(),
train=False,
download=True,
transform=ToTensor()
),
batch_size=batch_size),
}
loaders = {
k: BatchPrefetchLoaderWrapper(v) for k, v in loaders.items()
}
runner = CustomRunner()
# model training
runner.train(
model=model,
optimizer=optimizer,
loaders=loaders,
logdir="./logs",
num_epochs=5,
verbose=True,
load_best_on_end=True,
)
"""
    def __init__(self, loader: DataLoader, num_prefetches: Optional[int] = None):
        """Loader wrapper. Prefetches the specified number of batches to the GPU.

        If CUDA is not available, the wrapper falls back to plain iteration
        over ``loader`` without prefetching.

        Args:
            loader: torch dataloader.
            num_prefetches: number of batches to prefetch to the GPU (default: 1).
        """
super().__init__(loader)
self.num_prefetches = num_prefetches or 1
def __iter__(self):
"""Iterator.
Returns:
iterator object
"""
return _prefetch_loader(self.origin, self.num_prefetches)
__all__ = ["ILoaderWrapper", "BatchLimitLoaderWrapper", "BatchPrefetchLoaderWrapper"]