Source code for catalyst.contrib.data.augmentor
from typing import Callable, Dict, List, Optional, Tuple, Union
class Augmentor:
    """Augmentation abstraction to use with data dictionaries."""

    def __init__(
        self,
        dict_key: str,
        augment_fn: Callable,
        input_key: Optional[str] = None,
        output_key: Optional[str] = None,
        **kwargs,
    ):
        """
        Augmentation abstraction to use with data dictionaries.

        Args:
            dict_key: key of the data dictionary whose value is transformed
            augment_fn: augmentation function to use
            input_key: if given, the value is passed to ``augment_fn`` as the
                keyword argument with this name; otherwise it is passed
                as the first positional argument
            output_key: if given, the new value is taken from
                ``augment_fn(...)[output_key]``; otherwise the whole return
                value of ``augment_fn`` is used
            **kwargs: default kwargs forwarded to every ``augment_fn`` call
        """
        self.dict_key = dict_key
        self.augment_fn = augment_fn
        self.input_key = input_key
        self.output_key = output_key
        self.kwargs = kwargs

    def __call__(self, dict_: dict) -> dict:
        """
        Applies the augmentation to ``dict_[dict_key]``.

        Note: ``dict_`` is updated in place and the same object is returned.

        Args:
            dict_: data dictionary

        Returns:
            dict: ``dict_`` with the value under ``dict_key`` replaced by
            the augmented value
        """
        value = dict_[self.dict_key]
        if self.input_key is not None:
            output = self.augment_fn(**{self.input_key: value}, **self.kwargs)
        else:
            output = self.augment_fn(value, **self.kwargs)

        if self.output_key is not None:
            # augment_fn returned a container; pick the configured entry
            output = output[self.output_key]
        dict_[self.dict_key] = output
        return dict_
class AugmentorCompose:
    """Compose augmentors, each applied to its own key of a data dictionary."""

    def __init__(self, key2augment_fn: Dict[str, Callable]):
        """
        Args:
            key2augment_fn (Dict[str, Callable]): mapping from input key
                to augmentation function to apply. Each function is called
                with a single-item dictionary ``{key: value}`` and should
                return a mapping whose items are merged into the result
        """
        self.key2augment_fn = key2augment_fn

    def __call__(self, dictionary: dict) -> dict:
        """
        Args:
            dictionary: item from dataset

        Returns:
            dict: a new dictionary containing every item of ``dictionary``,
            overridden by the items returned from the augmentation functions
        """
        results = {}
        for key, augment_fn in self.key2augment_fn.items():
            # Each function gets a fresh single-item dict, so replacing items
            # in it does not touch ``dictionary``. ``update`` avoids
            # re-copying the accumulator on every iteration.
            results.update(augment_fn({key: dictionary[key]}))
        return {**dictionary, **results}
class AugmentorKeys:
    """Augmentation abstraction to match input and augmentations keys."""

    def __init__(
        self,
        dict2fn_dict: Union[Dict[str, str], List[str], Tuple[str, ...]],
        augment_fn: Callable,
    ):
        """
        Args:
            dict2fn_dict: keys matching dict ``{input_key: augment_fn_key}``.
                For example: ``{"image": "image", "mask": "mask"}``.
                A list or tuple of keys is shorthand for the identity
                mapping, e.g. ``["image", "mask"]``.
            augment_fn: augmentation function; it is called with keyword
                arguments named by the ``augment_fn_key`` values and its
                result is indexed by those same keys
        """
        if isinstance(dict2fn_dict, (list, tuple)):
            # a plain sequence of keys means "same name on both sides"
            dict2fn_dict = {key: key for key in dict2fn_dict}
        self.dict2fn_dict = dict2fn_dict
        self.augment_fn = augment_fn

    def __call__(self, dictionary: dict) -> dict:
        """
        Args:
            dictionary: item from dataset

        Returns:
            dict: a new dictionary containing every item of ``dictionary``,
            with the mapped keys replaced by the augmented values
        """
        # link keys from dictionary with augment_fn keys
        data = {
            fn_key: dictionary[dict_key]
            for dict_key, fn_key in self.dict2fn_dict.items()
        }
        augmented = self.augment_fn(**data)
        # link keys from augment_fn back to dictionary keys
        results = {
            dict_key: augmented[fn_key]
            for dict_key, fn_key in self.dict2fn_dict.items()
        }
        return {**dictionary, **results}
# Public names exported by ``from ... import *``.
__all__ = [
    "Augmentor",
    "AugmentorCompose",
    "AugmentorKeys",
]