Source code for catalyst.callbacks.mixup

from typing import List, Union

from catalyst.core.callback import Callback, CallbackOrder
from catalyst.core.runner import IRunner
from catalyst.utils.torch import mixup_batch

class MixupCallback(Callback):
    """Callback to do mixup augmentation.

    More details about mixup can be found in the paper
    `mixup: Beyond Empirical Risk Minimization`: https://arxiv.org/abs/1710.09412 .

    Args:
        keys: batch key (or list of batch keys) to which you want
            to apply the augmentation
        alpha: beta distribution a=b parameters. Must be >=0.
            The closer alpha is to zero, the less effect the mixup has.
        mode: mode determines the method of use. Must be in ["replace", "add"].
            If "replace", replaces the batch with a mixed one,
            while the batch size is not changed.
            If "add", concatenates mixed examples to the current ones,
            the batch size increases by 2 times.
        on_train_only: apply to train only.
            As the mixup uses proxy inputs, the targets are also proxy.
            We are not interested in them, are we?
            So, if ``on_train_only`` is ``True``, use a standard
            output/metric for validation.

    Examples:

    .. code-block:: python

        from typing import Any, Dict
        import os

        import numpy as np
        import torch
        from torch import nn
        from torch.utils.data import DataLoader

        from catalyst import dl
        from catalyst.callbacks import MixupCallback
        from catalyst.contrib.datasets import MNIST


        class SimpleNet(nn.Module):
            def __init__(self, in_channels, in_hw, out_features):
                super().__init__()
                self.encoder = nn.Sequential(
                    nn.Conv2d(in_channels, in_channels, 3, 1, 1), nn.Tanh()
                )
                self.clf = nn.Linear(in_channels * in_hw * in_hw, out_features)

            def forward(self, x):
                features = self.encoder(x)
                features = features.view(features.size(0), -1)
                logits = self.clf(features)
                return logits


        class SimpleDataset(torch.utils.data.Dataset):
            def __init__(self, train: bool = False):
                self.mnist = MNIST(os.getcwd(), train=train)

            def __len__(self) -> int:
                return len(self.mnist)

            def __getitem__(self, idx: int) -> Dict[str, Any]:
                x, y = self.mnist.__getitem__(idx)
                y_one_hot = np.zeros(10)
                y_one_hot[y] = 1
                return {
                    "image": x,
                    "clf_targets": y,
                    "clf_targets_one_hot": torch.Tensor(y_one_hot),
                }


        model = SimpleNet(1, 28, 10)
        criterion = torch.nn.BCEWithLogitsLoss()
        optimizer = torch.optim.Adam(model.parameters(), lr=0.02)

        loaders = {
            "train": DataLoader(SimpleDataset(train=True), batch_size=32),
            "valid": DataLoader(SimpleDataset(train=False), batch_size=32),
        }


        class CustomRunner(dl.Runner):
            def handle_batch(self, batch):
                image = batch["image"]
                clf_logits = self.model(image)
                self.batch["clf_logits"] = clf_logits


        runner = CustomRunner()
        runner.train(
            loaders=loaders,
            model=model,
            criterion=criterion,
            optimizer=optimizer,
            logdir="./logdir14",
            num_epochs=2,
            verbose=True,
            valid_loader="valid",
            valid_metric="loss",
            minimize_valid_metric=True,
            callbacks={
                "mixup": MixupCallback(keys=["image", "clf_targets_one_hot"]),
                "criterion": dl.CriterionCallback(
                    metric_key="loss",
                    input_key="clf_logits",
                    target_key="clf_targets_one_hot",
                ),
                "backward": dl.BackwardCallback(metric_key="loss"),
                "optimizer": dl.OptimizerCallback(metric_key="loss"),
                "classification": dl.ControlFlowCallback(
                    dl.PrecisionRecallF1SupportCallback(
                        input_key="clf_logits",
                        target_key="clf_targets",
                        num_classes=10,
                    ),
                    ignore_loaders="train",
                ),
            },
        )

    .. note::
        With running this callback, many metrics (accuracy, etc) become
        undefined, so use ControlFlowCallback in order to evaluate the model
        (see example).
    """

    def __init__(
        self,
        keys: Union[str, List[str]],
        alpha: float = 0.2,
        mode: str = "replace",
        on_train_only: bool = True,
    ):
        """Validate the arguments and store the callback configuration.

        Args:
            keys: a single batch key or a list/tuple of batch keys
            alpha: beta distribution parameter, must be >= 0
            mode: either ``"replace"`` or ``"add"``
            on_train_only: if ``True``, the callback is active only when
                the runner reports a train loader

        Raises:
            AssertionError: if ``keys`` is not a str/list/tuple,
                ``alpha`` is negative, or ``mode`` is not supported.
        """
        assert isinstance(
            keys, (str, list, tuple)
        ), f"keys must be str or list[str], get: {type(keys)}"
        assert alpha >= 0, "alpha must be>=0"
        assert mode in (
            "add",
            "replace",
        ), f"mode must be in 'add', 'replace', get: {mode}"
        super().__init__(order=CallbackOrder.Internal)
        # Normalize a single key to a one-element list so the rest of the
        # callback can always iterate over ``self.keys``.
        if isinstance(keys, str):
            keys = [keys]
        self.keys = keys
        self.on_train_only = on_train_only
        self.alpha = alpha
        self.mode = mode
        # Recomputed at every loader start; ``True`` until then.
        self._is_required = True

    def on_loader_start(self, runner: "IRunner") -> None:
        """Decide whether mixup should be applied for the starting loader.

        Args:
            runner: current runner; its ``is_train_loader`` flag is read
        """
        self._is_required = not self.on_train_only or runner.is_train_loader

    def on_batch_start(self, runner: "IRunner") -> None:
        """Replace the selected batch entries with their mixed versions.

        Args:
            runner: current runner; ``runner.batch`` is updated in place
        """
        if not self._is_required:
            return
        selected = [runner.batch[key] for key in self.keys]
        mixuped_batch = mixup_batch(selected, alpha=self.alpha, mode=self.mode)
        # Write each result back under the key it was read from.
        for key, mixuped_value in zip(self.keys, mixuped_batch):
            runner.batch[key] = mixuped_value


# Public names exported by this module.
__all__ = ["MixupCallback"]