Source code for catalyst.contrib.data.cv.transforms.kornia
from typing import Any, Dict, Iterable, Optional, Tuple, Union
import copy
import random
import numpy as np
from kornia.augmentation import AugmentationBase2D, AugmentationBase3D
import torch
from torch import nn
class OneOfPerBatch(nn.Module):
    """Select one of tensor transforms and apply it batch-wise.

    On every call exactly one transform is drawn, with probability
    proportional to the ``p`` attribute it had at construction time,
    and applied to the whole batch.
    """

    def __init__(
        self,
        transforms: Iterable[
            Union["AugmentationBase2D", "AugmentationBase3D"]
        ],
    ) -> None:
        """Constructor method for the :class:`OneOfPerBatch` transform.

        Args:
            transforms: kornia transformations to choose from.
                Actually, any ``nn.Module`` with defined ``p`` (probability
                of selecting transform) and ``p_batch`` attributes
                is allowed.

        Raises:
            ValueError: if ``transforms`` is empty or the ``p`` values
                do not sum to a positive number.
        """
        super().__init__()
        # ``transforms`` may be a one-shot iterator and is traversed twice
        # below, so materialize it first.
        transforms = list(transforms)
        weights = [transform.p for transform in transforms]
        total = sum(weights)
        if not transforms or total <= 0:
            raise ValueError(
                "`transforms` must be non-empty and their `p` values "
                "must sum to a positive number"
            )
        # normalized selection probabilities, same order as the transforms
        self.probs = [weight / total for weight in weights]
        # Work on copies so the caller's transforms keep their own ``p``.
        self.transforms = [copy.deepcopy(t) for t in transforms]
        for t in self.transforms:
            # selection already happened here, so the chosen transform
            # must always fire
            t.p = 1
            t.p_batch = 1

    def forward(
        self,
        input: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
        params: Optional[Dict[str, torch.Tensor]] = None,
        return_transform: Optional[bool] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Apply transform.

        Args:
            input: input batch
            params: transform params, please check kornia documentation
            return_transform: if ``True`` return the matrix describing
                the geometric transformation applied to each input tensor,
                please check kornia documentation

        Returns:
            augmented batch and, optionally, the transformation matrix
        """
        # The numpy generator is seeded from the stdlib ``random`` module,
        # so the selection follows ``random.seed``.
        random_state = np.random.RandomState(random.randint(0, 2 ** 32 - 1))
        # Draw an index instead of the module itself: ``choice`` would
        # otherwise convert the list of modules to an array, which fails
        # for sequence-like modules such as ``nn.Sequential``.
        index = int(random_state.choice(len(self.transforms), p=self.probs))
        # apply kornia transform
        return self.transforms[index](input, params, return_transform)
class OneOfPerSample(nn.Module):
    """Select one of tensor transforms to apply sample-wise.

    For every sample of the batch one transform is drawn, with probability
    proportional to the ``p`` attribute it had at construction time.
    Samples that drew the same transform are processed together and the
    result is written back into the input batch.
    """

    def __init__(
        self,
        transforms: Iterable[
            Union["AugmentationBase2D", "AugmentationBase3D"]
        ],
    ) -> None:
        """Constructor method for the :class:`OneOfPerSample` transform.

        Args:
            transforms: kornia transformations to choose from.
                Actually, any ``nn.Module`` with defined ``p`` (probability
                of selecting transform) and ``p_batch`` attributes
                is allowed.

        Raises:
            ValueError: if ``transforms`` is empty or the ``p`` values
                do not sum to a positive number.
        """
        super().__init__()
        # ``transforms`` may be a one-shot iterator and is traversed twice
        # below, so materialize it first.
        transforms = list(transforms)
        weights = [transform.p for transform in transforms]
        total = sum(weights)
        if not transforms or total <= 0:
            raise ValueError(
                "`transforms` must be non-empty and their `p` values "
                "must sum to a positive number"
            )
        # categorical distribution over transform indices
        self.choice_transform = torch.distributions.Categorical(
            torch.tensor([weight / total for weight in weights])
        )
        # Work on copies so the caller's transforms keep their own ``p``.
        self.transforms = [copy.deepcopy(t) for t in transforms]
        for t in self.transforms:
            # selection already happened here, so the chosen transform
            # must always fire
            t.p = 1
            t.p_batch = 1

    def forward(
        self,
        input: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
        params: Optional[Dict[str, torch.Tensor]] = None,
        return_transform: Optional[bool] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Apply transform.

        The input tensors are modified in place and returned.

        Args:
            input: input batch
            params: transform params, please check kornia documentation
            return_transform: if ``True`` return the matrix describing
                the geometric transformation applied to each input tensor,
                please check kornia documentation

        Returns:
            augmented batch (the transformation matrix is not returned yet,
            see the TODO below)
        """
        # select transform for each element
        first = input[0] if isinstance(input, tuple) else input
        batch_size = first.shape[0]
        chosen = self.choice_transform.sample([batch_size])

        # apply transforms
        for idx, transform in enumerate(self.transforms):
            to_apply = chosen == idx
            if not to_apply.any():
                continue
            # TODO: return transform matrix if `return_transform` == True
            self._apply_transform(
                transform,
                batch=input,
                mask=to_apply,
                params=params,
                return_transform=return_transform,
            )
        return input

    @staticmethod
    def _apply_transform(
        transform: nn.Module,
        batch: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
        mask: torch.Tensor,
        *args: Any,
        return_transform: Optional[bool] = None,
        **kwargs: Any,
    ) -> Optional[torch.Tensor]:
        """Apply ``transform`` inplace to the samples selected by ``mask``.

        Args:
            transform: module to call on the selected samples
            batch: tensor or pair of tensors, updated in place
            mask: boolean mask over the first dimension of the batch
            *args: extra positional arguments passed to ``transform``
            return_transform: forwarded to ``transform``; if truthy,
                ``transform`` is expected to return
                an ``(output, matrix)`` pair
            **kwargs: extra keyword arguments passed to ``transform``

        Returns:
            the transformation matrix if ``return_transform`` is truthy,
            ``None`` otherwise
        """
        # process input
        is_pair = isinstance(batch, tuple)
        selected = (batch[0][mask], batch[1][mask]) if is_pair else batch[mask]
        output = transform(
            selected, *args, return_transform=return_transform, **kwargs
        )

        # process output
        transform_matrix = None
        if return_transform:
            output, transform_matrix = output
        if is_pair:
            batch[0][mask] = output[0]
            batch[1][mask] = output[1]
        else:
            # a plain tensor batch has to be written back as well,
            # otherwise the augmentation result is silently dropped
            batch[mask] = output
        return transform_matrix
# Names exported by ``from <module> import *``.
__all__ = ["OneOfPerBatch", "OneOfPerSample"]