# Source code for catalyst.data.augmentor
from typing import Callable, Dict # isort:skip
class Augmentor:
    """
    Augmentation abstraction to use with data dictionaries.

    Applies ``augment_fn`` to the value stored under a single key of a
    data dictionary and writes the result back under that same key.
    """

    def __init__(
        self, dict_key: str, augment_fn: Callable, default_kwargs: Dict = None
    ):
        """
        Args:
            dict_key: key to transform
            augment_fn: augmentation function to use; it is called as
                ``augment_fn(value, **default_kwargs)``
            default_kwargs: default kwargs for augmentations function;
                ``None`` (or any empty mapping) means "no extra kwargs"
        """
        self.dict_key = dict_key
        self.augment_fn = augment_fn
        # Fall back to an empty dict so ``**self.default_kwargs`` is
        # always valid in ``__call__``.
        self.default_kwargs = default_kwargs or {}

    def __call__(self, dict_: Dict) -> Dict:
        """
        Augment the value stored under ``dict_key``.

        Args:
            dict_: data dictionary; it must contain ``dict_key``

        Returns:
            the same ``dict_`` object, modified in place, with the
            value under ``dict_key`` replaced by the augmented one

        Raises:
            KeyError: if ``dict_key`` is missing from ``dict_``
        """
        dict_[self.dict_key] = self.augment_fn(
            dict_[self.dict_key], **self.default_kwargs
        )
        return dict_
class AugmentorKeys:
    """Augmentation abstraction to match input and augmentations keys"""

    def __init__(self, dict2fn_dict: Dict[str, str], augment_fn: Callable):
        """
        Args:
            dict2fn_dict: keys matching dict ``{input_key: augment_fn_key}``,
                e.g. ``{"image": "image", "mask": "mask"}``
            augment_fn: augmentation function; it is called with keyword
                arguments named after the ``augment_fn_key`` values and
                its result is indexed by those same keys
        """
        self.dict2fn_dict = dict2fn_dict
        self.augment_fn = augment_fn

    def __call__(self, dict_: Dict) -> Dict:
        """
        Augment the mapped entries of a dataset item.

        Args:
            dict_: item from dataset; it must contain every input key
                listed in ``dict2fn_dict``

        Returns:
            a new dict holding all entries of ``dict_`` with the mapped
            keys replaced by their augmented values; ``dict_`` itself
            is left unmodified

        Raises:
            KeyError: if an input key is missing from ``dict_`` or an
                augmentation key is missing from the ``augment_fn`` result
        """
        # link keys from dict_ with augment_fn keys
        data = {
            fn_key: dict_[dict_key]
            for dict_key, fn_key in self.dict2fn_dict.items()
        }

        augmented = self.augment_fn(**data)

        # link keys from augment_fn back to dict_ keys
        results = {
            dict_key: augmented[fn_key]
            for dict_key, fn_key in self.dict2fn_dict.items()
        }

        # Augmented values override the originals; untouched keys pass through.
        return {**dict_, **results}
# Names exported by ``from <module> import *``.
__all__ = ["Augmentor", "AugmentorKeys"]