Shortcuts

Source code for catalyst.callbacks.batch_transform

from typing import Any, Callable, Dict, List, Union
from functools import partial

from catalyst.core import Callback, CallbackOrder, IRunner
from catalyst.registry import REGISTRY


class _TupleWrapper:
    """Function wrapper for tuple output"""

    def __init__(self, transform: Callable) -> None:
        """Init."""
        self.transform = transform

    def __call__(self, *inputs) -> Any:
        """Call."""
        output = self.transform(*inputs)
        return (output,)


[docs]class BatchTransformCallback(Callback): """ Preprocess your batch with specified function. Args: transform: Function to apply. If string will get function from registry. scope: ``"on_batch_end"`` (post-processing model output) or ``"on_batch_start"`` (pre-processing model input). input_key: Keys in batch dict to apply function. Defaults to ``None``. output_key: Keys for output. If None then will apply function inplace to ``keys_to_apply``. Defaults to ``None``. transform_kwargs: Kwargs for transform. Raises: TypeError: When keys is not str or a list. When ``scope`` is not in ``["on_batch_end", "on_batch_start"]``. Examples: .. code-block:: python import torch from torch.utils.data import DataLoader, TensorDataset from catalyst import dl # sample data num_users, num_features, num_items = int(1e4), int(1e1), 10 X = torch.rand(num_users, num_features) y = (torch.rand(num_users, num_items) > 0.5).to(torch.float32) # pytorch loaders dataset = TensorDataset(X, y) loader = DataLoader(dataset, batch_size=32, num_workers=1) loaders = {"train": loader, "valid": loader} # model, criterion, optimizer, scheduler model = torch.nn.Linear(num_features, num_items) criterion = torch.nn.BCEWithLogitsLoss() optimizer = torch.optim.Adam(model.parameters()) scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [2]) # model training runner = SupervisedRunner() runner.train( model=model, criterion=criterion, optimizer=optimizer, scheduler=scheduler, loaders=loaders, num_epochs=3, verbose=True, callbacks=[ dl.BatchTransformCallback( input_key="logits", output_key="scores", transform="F.sigmoid" ), dl.CriterionCallback( input_key="logits", target_key="targets", metric_key="loss" ), dl.OptimizerCallback(metric_key="loss"), dl.SchedulerCallback(), dl.CheckpointCallback( logdir="./logs", loader_key="valid", metric_key="map01", minimize=False ), ] ) .. code-block:: python class CustomRunner(dl.Runner): def handle_batch(self, batch): logits = self.model( batch["features"].view(batch["features"].size(0), -1) ) loss = F.cross_entropy(logits, batch["targets"]) accuracy01, accuracy03 = metrics.accuracy( logits, batch["targets"], topk=(1, 3) ) self.batch_metrics.update({ "loss": loss, "accuracy01":accuracy01, "accuracy03": accuracy03 }) if self.is_train_loader: self.engine.backward(loss) self.optimizer.step() self.optimizer.zero_grad() class MnistDataset(torch.utils.data.Dataset): def __init__(self, dataset): self.dataset = dataset def __getitem__(self, item): return { "features": self.dataset[item][0], "targets": self.dataset[item][1] } def __len__(self): return len(self.dataset) model = torch.nn.Linear(28 * 28, 10) optimizer = torch.optim.Adam(model.parameters(), lr=0.02) loaders = { "train": DataLoader( MnistDataset( MNIST(os.getcwd(), train=False) ), batch_size=32, ), "valid": DataLoader( MnistDataset( MNIST(os.getcwd(), train=False) ), batch_size=32, ), } transrorms = [ augmentation.RandomAffine(degrees=(-15, 20), scale=(0.75, 1.25)), ] runner = CustomRunner() # model training runner.train( model=model, optimizer=optimizer, loaders=loaders, logdir="./logs", num_epochs=5, verbose=False, load_best_on_end=True, check=True, callbacks=[ BatchTransformCallback( transform=transrorms, scope="on_batch_start", input_key="features" ) ], ) .. code-block:: yaml ... callbacks: transform: _target_: BatchTransformCallback transform: catalyst.ToTensor scope: on_batch_start input_key: features """ def __init__( self, transform: Union[Callable, str], scope: str, input_key: Union[List[str], str] = None, output_key: Union[List[str], str] = None, transform_kwargs: Dict[str, Any] = None, ): """ Preprocess your batch with specified function. Args: transform: Function to apply. If string will get function from registry. scope: ``"on_batch_end"`` (post-processing model output) or ``"on_batch_start"`` (pre-processing model input). input_key: Keys in batch dict to apply function. Defaults to ``None``. output_key: Keys for output. If None then will apply function inplace to ``keys_to_apply``. Defaults to ``None``. transform_kwargs: Kwargs for transform. Raises: TypeError: When keys is not str or a list. When ``scope`` is not in ``["on_batch_end", "on_batch_start"]``. """ super().__init__(order=CallbackOrder.Internal) if isinstance(transform, str): transform = REGISTRY.get(transform) if transform_kwargs is not None: transform = partial(transform, **transform_kwargs) if input_key is not None: if not isinstance(input_key, (list, str)): raise TypeError("input key should be str or a list of str.") elif isinstance(input_key, str): input_key = [input_key] self._handle_batch = self._handle_value else: self._handle_batch = self._handle_key_value output_key = output_key or input_key if output_key is not None: if input_key is None: raise TypeError( "You should define input_key in " "case if output_key is not None" ) if not isinstance(output_key, (list, str)): raise TypeError("output key should be str or a list of str.") if isinstance(output_key, str): output_key = [output_key] transform = _TupleWrapper(transform) if isinstance(scope, str) and scope in ["on_batch_end", "on_batch_start"]: self.scope = scope else: raise TypeError( 'Expected scope to be on of the ["on_batch_end", "on_batch_start"]' ) self.input_key = input_key self.output_key = output_key self.transform = transform def _handle_value(self, runner): batch_in = [runner.batch[key] for key in self.input_key] batch_out = self.transform(*batch_in) runner.batch.update( **{key: value for key, value in zip(self.output_key, batch_out)} ) def _handle_key_value(self, runner): runner.batch = self.transform(runner.batch) def on_batch_start(self, runner: "IRunner") -> None: """Event handler.""" if self.scope == "on_batch_start": self._handle_batch(runner) def on_batch_end(self, runner: "IRunner") -> None: """Event handler.""" if self.scope == "on_batch_end": self._handle_batch(runner)
__all__ = ["BatchTransformCallback"]