Source code for catalyst.callbacks.batch_transform

from typing import Optional, Sequence, TYPE_CHECKING, Union

from torch import nn

from catalyst.core.callback import Callback, CallbackNode, CallbackOrder
from catalyst.registry import REGISTRY

    from catalyst.core.runner import IRunner

class BatchTransformCallback(Callback):
    """Callback to perform data augmentations on GPU using kornia library.

    On every batch start the callback reads ``runner.batch[input_key]``,
    passes it through the configured transforms (chained with
    ``nn.Sequential``) and stores the result in ``runner.batch[output_key]``.

    Args:
        transform: define augmentations to apply on a batch

            If a sequence of transforms passed, then each element
            should be either ``kornia.augmentation.AugmentationBase2D``,
            ``kornia.augmentation.AugmentationBase3D``, or ``nn.Module``
            compatible with kornia interface.

            If a sequence of params (``dict``) passed, then each
            element of the sequence must contain ``'transform'`` key with
            an augmentation name as a value. Please note that in this case
            to use custom augmentation you should add it to the
            `REGISTRY` registry first.
        input_key (Union[str, int]): key in batch dict mapping to transform,
            e.g. `'image'`
        output_key: key to use to store the result of the transform,
            defaults to `input_key` if not provided (i.e. if ``None``);
            falsy keys such as ``0`` are respected

    Look at `Kornia: an Open Source Differentiable Computer Vision
    Library for PyTorch`_ for details.

    Usage example for notebook API:

    .. code-block:: python

        import os

        from kornia import augmentation
        import torch
        from torch.nn import functional as F
        from torch.utils.data import DataLoader

        from catalyst import dl
        # NOTE: the location of ``ToTensor`` differs between catalyst
        # versions; adjust the import to the installed release.
        from catalyst.data.cv import ToTensor
        from catalyst.contrib.datasets import MNIST
        from catalyst.callbacks.batch_transform import (
            BatchTransformCallback
        )
        from catalyst import metrics


        class CustomRunner(dl.Runner):
            def predict_batch(self, batch):
                # model inference step
                return self.model(
                    batch[0].to(self.device).view(batch[0].size(0), -1)
                )

            def handle_batch(self, batch):
                # model train/valid step
                x, y = batch
                y_hat = self.model(x.view(x.size(0), -1))

                loss = F.cross_entropy(y_hat, y)
                accuracy01, *_ = metrics.accuracy(y_hat, y)
                self.batch_metrics.update(
                    {"loss": loss, "accuracy01": accuracy01}
                )

                if self.is_train_loader:
                    loss.backward()
                    self.optimizer.step()
                    self.optimizer.zero_grad()


        model = torch.nn.Linear(28 * 28, 10)
        optimizer = torch.optim.Adam(model.parameters(), lr=0.02)

        loaders = {
            "train": DataLoader(
                MNIST(os.getcwd(), train=True, transform=ToTensor()),
                batch_size=32,
            ),
            "valid": DataLoader(
                MNIST(os.getcwd(), train=False, transform=ToTensor()),
                batch_size=32,
            ),
        }
        transforms = [
            augmentation.RandomAffine(degrees=(-15, 20), scale=(0.75, 1.25)),
        ]

        runner = CustomRunner()

        # model training
        runner.train(
            model=model,
            optimizer=optimizer,
            loaders=loaders,
            logdir="./logs",
            num_epochs=5,
            verbose=True,
            callbacks=[BatchTransformCallback(transforms, input_key=0)],
        )

    To apply augmentations only during specific loader e.g. only during
    training
    :class:`catalyst.core.callbacks.control_flow.ControlFlowCallback`
    callback can be used. For config API it can look like this:

    .. code-block:: yaml

        callbacks_params:
          ...
          train_transforms:
            _wrapper:
              name: ControlFlowCallback
              loaders: train
            name: BatchTransformCallback
            transform:
              - transform: kornia.RandomAffine
                degrees: [-15, 20]
                scale: [0.75, 1.25]
                return_transform: true
              - transform: kornia.ColorJitter
                brightness: 0.1
                contrast: 0.1
                saturation: 0.1
                return_transform: false
            input_key: image
          ...

    .. _`Kornia: an Open Source Differentiable Computer Vision Library
        for PyTorch`: https://arxiv.org/pdf/1910.02190.pdf
    """

    def __init__(
        self,
        transform: Sequence[Union[dict, nn.Module]],
        input_key: Union[str, int] = "image",
        output_key: Optional[Union[str, int]] = None,
    ) -> None:
        """Build the transform pipeline and remember the batch keys.

        Args:
            transform: sequence of ready ``nn.Module`` instances and/or
                ``dict`` params that are resolved through
                ``REGISTRY.get_from_params``
            input_key: key of ``runner.batch`` to read the input from
            output_key: key of ``runner.batch`` to write the result to;
                ``None`` means "reuse ``input_key``"
        """
        super().__init__(order=CallbackOrder.Internal, node=CallbackNode.all)

        self.input_key = input_key
        # Compare with ``None`` explicitly: ``output_key or input_key`` would
        # silently discard falsy-but-valid keys such as ``0`` (integer keys
        # are supported) and fall back to ``input_key``.
        self.output_key = input_key if output_key is None else output_key

        # Modules are used as-is; anything else is treated as a params dict
        # and instantiated through the registry.
        transforms: Sequence[nn.Module] = [
            item
            if isinstance(item, nn.Module)
            else REGISTRY.get_from_params(**item)
            for item in transform
        ]
        # The registry may return arbitrary objects, so verify the result
        # before handing it to ``nn.Sequential``.
        assert all(
            isinstance(t, nn.Module) for t in transforms
        ), "`nn.Module` should be a base class for transforms"

        self.transform = nn.Sequential(*transforms)

    def on_batch_start(self, runner: "IRunner") -> None:
        """Apply transforms.

        Reads ``runner.batch[self.input_key]``, runs it through
        ``self.transform`` and stores the result under
        ``runner.batch[self.output_key]``.

        Args:
            runner: current runner
        """
        input_batch = runner.batch[self.input_key]
        output_batch = self.transform(input_batch)
        runner.batch[self.output_key] = output_batch
# Names exported by ``from <this module> import *``: only the callback class.
__all__ = ["BatchTransformCallback"]