Source code for catalyst.data.collate_fn

import collections
import collections.abc

from torch.utils.data.dataloader import default_collate


class FilteringCollateFn:
    """
    Callable object doing job of ``collate_fn`` like ``default_collate``,
    but does not cast batch items with specified keys to :class:`torch.Tensor`.

    For key-value (mapping) samples, values stored under one of the filtered
    keys are only gathered into a plain list. All other values are collated
    with ``default_collate``. Batches whose samples are not mappings are
    passed to ``default_collate`` unchanged.
    """

    def __init__(self, *keys):
        """
        Args:
            keys: Keys for values that will not be converted to tensor
                and stacked, but only collected into a list
        """
        self.keys = keys

    def __call__(self, batch):
        """
        Collate a list of samples into a batch.

        Args:
            batch: current batch, a non-empty sequence of samples.
                If the first sample is a mapping, every sample is expected
                to contain (at least) the keys of the first sample.

        Returns:
            For mapping samples: a dict with the same keys as ``batch[0]``.
            Values under ``self.keys`` are lists of the raw per-sample values.
            All other values are the result of ``default_collate``.
            For non-mapping samples: ``default_collate(batch)``.
        """
        # ``collections.Mapping`` was removed in Python 3.10;
        # the ABC lives in ``collections.abc``.
        if not isinstance(batch[0], collections.abc.Mapping):
            return default_collate(batch)

        result = {}
        # The set of output keys is taken from the first sample only.
        for key in batch[0]:
            items = [sample[key] for sample in batch]
            if key not in self.keys:
                items = default_collate(items)
            result[key] = items
        return result
# Public API of this module.
__all__ = [
    "FilteringCollateFn",
]