Shortcuts

Source code for catalyst.data.loader

from typing import Union

from torch.utils.data import DataLoader


class _MissingAttributeError(AttributeError, NotImplementedError):
    """Signals that neither the wrapper nor the wrapped loader has an attribute.

    Inherits from ``NotImplementedError`` because that is the documented
    contract of ``BatchLimitLoaderWrapper.__getattr__``, and from
    ``AttributeError`` so that ``hasattr`` and three-argument ``getattr``
    behave normally on the wrapper.
    """


class BatchLimitLoaderWrapper:
    """
    Loader wrapper. Limits number of batches used per each iteration.

    One pass over the wrapper still produces ``len(loader)`` batches, but only
    the first ``num_batches`` batches of the wrapped loader are used: once
    ``num_batches`` batches have been served, the wrapped loader is restarted.

    For example, if you have some loader and want to use only first 5 batches:

    .. code-block:: python

        import torch
        from torch.utils.data import DataLoader, TensorDataset
        from catalyst.data.loader import BatchLimitLoaderWrapper

        num_samples, num_features = int(1e4), int(1e1)
        X, y = torch.rand(num_samples, num_features), torch.rand(num_samples)
        dataset = TensorDataset(X, y)
        loader = DataLoader(dataset, batch_size=32, num_workers=1)
        loader = BatchLimitLoaderWrapper(loader, num_batches=5)

    or if you would like to use only some portion of Dataloader
    (we use 30% in the example below):

    .. code-block:: python

        import torch
        from torch.utils.data import DataLoader, TensorDataset
        from catalyst.data.loader import BatchLimitLoaderWrapper

        num_samples, num_features = int(1e4), int(1e1)
        X, y = torch.rand(num_samples, num_features), torch.rand(num_samples)
        dataset = TensorDataset(X, y)
        loader = DataLoader(dataset, batch_size=32, num_workers=1)
        loader = BatchLimitLoaderWrapper(loader, num_batches=0.3)

    .. note::
        Generally speaking, this wrapper could be used with any
        iterator-like object. No ``DataLoader``-specific code used.
    """

    def __init__(self, loader: DataLoader, num_batches: Union[int, float]):
        """
        Loader wrapper. Limits number of batches used per each iteration.

        Args:
            loader (DataLoader): torch dataloader.
            num_batches (Union[int, float]): number of batches to use (int),
                or portion of iterator (float, should be in [0;1] range).
                A value that resolves to ``0`` makes the wrapper empty.

        Raises:
            AssertionError: if ``num_batches`` is neither int nor float,
                or is a float outside of the [0; 1] range.
        """
        assert isinstance(num_batches, (int, float)), (
            "Expected ``num_batches`` type is int/float "
            f"but got {type(num_batches)}"
        )
        if isinstance(num_batches, float):
            assert 0.0 <= num_batches <= 1, (
                "Expected ``num_batches`` to be in range [0; 1] "
                f"but got {num_batches}"
            )
            # A fraction is converted to an absolute number of batches.
            num_batches = int(len(loader) * num_batches)

        self.origin = loader
        # Created eagerly so that ``next(wrapper)`` works without ``iter()``.
        self.iterator = iter(self.origin)
        self.iteration_index = 0
        self.num_batches = num_batches

    def __getattr__(self, key):
        """
        Gets attribute by ``key`` from the wrapped loader.

        Python calls this hook only after the regular lookup on the wrapper
        itself has failed, so the only remaining place to look is ``origin``.

        Args:
            key: attribute key

        Returns:
            attribute value (``None`` included, if that is the stored value)

        Raises:
            NotImplementedError: if could not find attribute in ``origin``
                or ``wrapper``; the raised error is an ``AttributeError``
                as well
        """
        try:
            # Bypass ``__getattr__`` itself: ``origin`` may not be set yet
            # (e.g. on an instance created without ``__init__``), and
            # ``self.origin`` would then recurse forever.
            wrapped = object.__getattribute__(self, "origin")
        except AttributeError:
            raise _MissingAttributeError(
                f"{type(self).__name__!r} object has no attribute {key!r}"
            ) from None

        try:
            return getattr(wrapped, key)
        except AttributeError:
            raise _MissingAttributeError(
                f"{type(self).__name__!r} object has no attribute {key!r}"
            ) from None

    def __iter__(self):
        """Iterator.

        Restarts both the batch counter and the wrapped loader.

        Returns:
            iterator object
        """
        self.iteration_index = 0
        self.iterator = iter(self.origin)
        return self

    def __next__(self):
        """Next batch.

        Returns:
            next batch

        Raises:
            StopIteration: once ``len(self)`` batches have been served
        """
        if self.iteration_index >= len(self):
            raise StopIteration()

        # Restart the wrapped loader after every full group of
        # ``num_batches`` batches. The check runs before the counter is
        # advanced so that the first group is not one batch short.
        served = self.iteration_index
        if served and served % self.num_batches == 0:
            self.iterator = iter(self.origin)

        self.iteration_index += 1
        return next(self.iterator)

    def __len__(self) -> int:
        """Returns length of the wrapper loader.

        Returns:
            int: length of the wrapped loader, or ``0`` when ``num_batches``
            is ``0`` (there is no batch to serve in that case)
        """
        if self.num_batches == 0:
            return 0
        return len(self.origin)
# Names exported by ``from <this module> import *``.
__all__ = ["BatchLimitLoaderWrapper"]