# Source code for catalyst.callbacks.batch_transform
from typing import Any, Callable, Dict, List, Union
from functools import partial

from catalyst.core import Callback, CallbackOrder, IRunner
from catalyst.registry import REGISTRY


class _TupleWrapper:
    """Function wrapper for tuple output"""
    def __init__(self, transform: Callable) -> None:
        """Init."""
        self.transform = transform

    def __call__(self, *inputs) -> Any:
        """Call."""
        output = self.transform(*inputs)
        return (output,)
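
# A sketch of how _TupleWrapper is used internally (not part of the public
# API): wrapping a single-output transform makes its result a 1-tuple, so the
# caller can uniformly zip outputs with a list of output keys, e.g.
#
#   wrapped = _TupleWrapper(torch.sigmoid)
#   (scores,) = wrapped(logits)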


class BatchTransformCallback(Callback):
    """
    Preprocess your batch with specified function.
    Args:
        transform: Function to apply. If string will get function from registry.
        scope: ``"on_batch_end"`` (post-processing model output) or
            ``"on_batch_start"`` (pre-processing model input).
        input_key: Keys in batch dict to apply function. Defaults to ``None``.
        output_key: Keys for output.
            If None then will apply function inplace to ``keys_to_apply``.
            Defaults to ``None``.
        transform_kwargs: Kwargs for transform.
    Raises:
        TypeError: When keys is not str or a list.
            When ``scope`` is not in ``["on_batch_end", "on_batch_start"]``.
    Examples:
        .. code-block:: python

            import torch
            from torch.utils.data import DataLoader, TensorDataset
            from catalyst import dl

            # sample data
            num_users, num_features, num_items = int(1e4), int(1e1), 10
            X = torch.rand(num_users, num_features)
            y = (torch.rand(num_users, num_items) > 0.5).to(torch.float32)

            # pytorch loaders
            dataset = TensorDataset(X, y)
            loader = DataLoader(dataset, batch_size=32, num_workers=1)
            loaders = {"train": loader, "valid": loader}

            # model, criterion, optimizer, scheduler
            model = torch.nn.Linear(num_features, num_items)
            criterion = torch.nn.BCEWithLogitsLoss()
            optimizer = torch.optim.Adam(model.parameters())
            scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [2])

            # model training
            runner = dl.SupervisedRunner()
            runner.train(
                model=model,
                criterion=criterion,
                optimizer=optimizer,
                scheduler=scheduler,
                loaders=loaders,
                num_epochs=3,
                verbose=True,
                callbacks=[
                    dl.BatchTransformCallback(
                        input_key="logits", output_key="scores", transform="F.sigmoid"
                    ),
                    dl.CriterionCallback(
                        input_key="logits", target_key="targets", metric_key="loss"
                    ),
                    dl.OptimizerCallback(metric_key="loss"),
                    dl.SchedulerCallback(),
                    dl.CheckpointCallback(
                        logdir="./logs",
                        loader_key="valid",
                        metric_key="map01",
                        minimize=False
                    ),
                ]
            )
        .. code-block:: python

            # imports assumed by this example (MNIST from catalyst.contrib,
            # augmentation from kornia)
            import os

            import torch
            from torch.nn import functional as F
            from torch.utils.data import DataLoader

            from catalyst import dl, metrics
            from catalyst.contrib.datasets import MNIST
            from kornia import augmentation

            class CustomRunner(dl.Runner):
                def handle_batch(self, batch):
                    logits = self.model(
                        batch["features"].view(batch["features"].size(0), -1)
                    )
                    loss = F.cross_entropy(logits, batch["targets"])
                    accuracy01, accuracy03 = metrics.accuracy(
                        logits, batch["targets"], topk=(1, 3)
                    )
                    self.batch_metrics.update({
                        "loss": loss,
                        "accuracy01":accuracy01,
                        "accuracy03": accuracy03
                    })
                    if self.is_train_loader:
                        self.engine.backward(loss)
                        self.optimizer.step()
                        self.optimizer.zero_grad()
            class MnistDataset(torch.utils.data.Dataset):
                def __init__(self, dataset):
                    self.dataset = dataset
                def __getitem__(self, item):
                    return {
                        "features": self.dataset[item][0],
                        "targets": self.dataset[item][1]
                    }
                def __len__(self):
                    return len(self.dataset)
            model = torch.nn.Linear(28 * 28, 10)
            optimizer = torch.optim.Adam(model.parameters(), lr=0.02)
            loaders = {
                "train": DataLoader(
                    MnistDataset(
                        MNIST(os.getcwd(), train=False)
                    ),
                    batch_size=32,
                ),
                "valid": DataLoader(
                    MnistDataset(
                        MNIST(os.getcwd(), train=False)
                    ),
                    batch_size=32,
                ),
            }

            transforms = torch.nn.Sequential(
                augmentation.RandomAffine(degrees=(-15, 20), scale=(0.75, 1.25)),
            )
            runner = CustomRunner()
            # model training
            runner.train(
                model=model,
                optimizer=optimizer,
                loaders=loaders,
                logdir="./logs",
                num_epochs=5,
                verbose=False,
                load_best_on_end=True,
                check=True,
                callbacks=[
                    dl.BatchTransformCallback(
                        transform=transforms,
                        scope="on_batch_start",
                        input_key="features"
                    )
                ],
            )
        .. code-block:: yaml

            ...
            callbacks:
                transform:
                    _target_: BatchTransformCallback
                    transform: catalyst.ToTensor
                    scope: on_batch_start
                    input_key: features
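
        One more minimal sketch (an assumption, not from the original docs):
        extra keyword arguments can be forwarded to the transform via
        ``transform_kwargs``; ``torch.softmax`` is only an illustrative choice.

        .. code-block:: python

            import torch
            from catalyst import dl

            dl.BatchTransformCallback(
                input_key="logits",
                output_key="probs",
                transform=torch.softmax,
                transform_kwargs={"dim": -1},
            )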
    """

    def __init__(
        self,
        transform: Union[Callable, str],
        scope: str,
        input_key: Union[List[str], str] = None,
        output_key: Union[List[str], str] = None,
        transform_kwargs: Dict[str, Any] = None,
    ):
        """
        Preprocess your batch with specified function.
        Args:
            transform: Function to apply.
                If string will get function from registry.
            scope: ``"on_batch_end"`` (post-processing model output) or
                ``"on_batch_start"`` (pre-processing model input).
            input_key: Keys in batch dict to apply function. Defaults to ``None``.
            output_key: Keys for output.
                If None then will apply function inplace to ``keys_to_apply``.
                Defaults to ``None``.
            transform_kwargs: Kwargs for transform.
        Raises:
            TypeError: When keys is not str or a list.
                When ``scope`` is not in ``["on_batch_end", "on_batch_start"]``.
        """
        super().__init__(order=CallbackOrder.Internal)
        if isinstance(transform, str):
            transform = REGISTRY.get(transform)
        if transform_kwargs is not None:
            transform = partial(transform, **transform_kwargs)
        if input_key is not None:
            if not isinstance(input_key, (list, str)):
                raise TypeError("input key should be str or a list of str.")
            elif isinstance(input_key, str):
                input_key = [input_key]
            self._handle_batch = self._handle_value
        else:
            self._handle_batch = self._handle_key_value
        output_key = output_key or input_key
        if output_key is not None:
            if input_key is None:
                raise TypeError(
                    "You should define input_key if output_key is not None"
                )
            if not isinstance(output_key, (list, str)):
                raise TypeError("output key should be str or a list of str.")
            if isinstance(output_key, str):
                output_key = [output_key]
                transform = _TupleWrapper(transform)
        if isinstance(scope, str) and scope in ["on_batch_end", "on_batch_start"]:
            self.scope = scope
        else:
            raise TypeError(
                'Expected scope to be one of ["on_batch_end", "on_batch_start"]'
            )
        self.input_key = input_key
        self.output_key = output_key
        self.transform = transform
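
    # Dispatch summary: when input_key is given, _handle_value applies the
    # transform to the selected keys; otherwise _handle_key_value passes the
    # whole batch dict through the transform. A string output_key is
    # normalized to a one-element list and the transform is wrapped in
    # _TupleWrapper, so the zip over outputs also works for single-output
    # transforms.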

    def _handle_value(self, runner):
        batch_in = [runner.batch[key] for key in self.input_key]
        batch_out = self.transform(*batch_in)
        runner.batch.update(dict(zip(self.output_key, batch_out)))
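    # For example, with input_key=["logits"] and output_key=["scores"],
    # _handle_value reads runner.batch["logits"], applies the transform, and
    # writes the result back to runner.batch["scores"].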

    def _handle_key_value(self, runner):
        runner.batch = self.transform(runner.batch)
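    # With input_key=None the whole batch dict goes through the transform; a
    # hypothetical user-defined example (not part of the library):
    #
    #   def add_noise(batch):
    #       batch["features"] += 0.01 * torch.randn_like(batch["features"])
    #       return batch
    #
    #   BatchTransformCallback(transform=add_noise, scope="on_batch_start")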

    def on_batch_start(self, runner: "IRunner") -> None:
        """Event handler."""
        if self.scope == "on_batch_start":
            self._handle_batch(runner)

    def on_batch_end(self, runner: "IRunner") -> None:
        """Event handler."""
        if self.scope == "on_batch_end":
            self._handle_batch(runner)


__all__ = ["BatchTransformCallback"]