Source code for catalyst.dl.callbacks.mixup

from typing import List  # isort:skip

import numpy as np

import torch

from catalyst.dl.callbacks import CriterionCallback
from catalyst.dl.core.state import RunnerState


class MixupCallback(CriterionCallback):
    """
    Callback to do mixup augmentation.

    Paper: https://arxiv.org/abs/1710.09412

    Note:
        MixupCallback is inherited from CriterionCallback and
        does its work. You may not use them together.
    """

    def __init__(
        self,
        fields: List[str] = ("features", ),
        alpha: float = 1.0,
        on_train_only: bool = True,
        **kwargs
    ):
        """
        Args:
            fields (List[str]): keys of ``state.input`` whose tensors are
                mixed. At least one key is required. The first key also
                determines the batch size.
            alpha (float): the ``a = b`` parameter of the Beta distribution
                that the mixing coefficient is drawn from. Must be >= 0.
                The closer alpha is to zero, the weaker the mixup effect.
                With ``alpha == 0`` the coefficient is fixed to 1, so
                inputs and loss are effectively left unmixed.
            on_train_only (bool): apply mixup only to loaders whose name
                starts with "train". Mixup produces proxy inputs, so the
                targets are proxies too and are usually not of interest
                for validation. When True, other loaders fall back to the
                standard CriterionCallback loss computation.
            **kwargs: forwarded to ``CriterionCallback.__init__``.
        """
        assert len(fields) > 0, \
            "At least one field for MixupCallback is required"
        assert alpha >= 0, "alpha must be>=0"

        super().__init__(**kwargs)

        self.on_train_only = on_train_only
        self.fields = fields
        self.alpha = alpha
        # Mixing coefficient and batch permutation of the current batch.
        # Both are refreshed in ``on_batch_start``.
        self.lam = 1
        self.index = None
        # Whether mixup is active for the loader that is currently running.
        self.is_needed = True

    def on_loader_start(self, state: RunnerState):
        """Decide whether mixup is applied for the loader that is starting.

        Mixup is always applied when ``on_train_only`` is False. Otherwise
        it is applied only when ``state.loader_name`` starts with "train".
        """
        self.is_needed = not self.on_train_only or \
            state.loader_name.startswith("train")

    def on_batch_start(self, state: RunnerState):
        """Sample a mixing coefficient and mix the configured input fields.

        Each field ``x`` in ``state.input`` is replaced by
        ``lam * x + (1 - lam) * x[index]``. Here ``index`` is a random
        permutation of the batch, and the same permutation is reused for
        the targets in ``_compute_loss``.
        """
        if not self.is_needed:
            return

        if self.alpha > 0:
            # Cast to a Python float. With a numpy scalar on the left of
            # ``*`` and a tensor on the right, numpy rather than torch may
            # handle the product on some numpy/torch versions.
            self.lam = float(np.random.beta(self.alpha, self.alpha))
        else:
            self.lam = 1

        batch_size = state.input[self.fields[0]].shape[0]
        # ``Tensor.to`` is not in-place; the moved tensor must be assigned.
        # Previously the result was discarded and the permutation stayed
        # on the CPU.
        self.index = torch.randperm(batch_size).to(state.device)

        for f in self.fields:
            state.input[f] = self.lam * state.input[f] + \
                (1 - self.lam) * state.input[f][self.index]

    def _compute_loss(self, state: RunnerState, criterion):
        """Return ``lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)``.

        ``y_a`` are the batch targets, and ``y_b`` are the same targets
        reordered by the permutation drawn in ``on_batch_start``. When
        mixup is inactive for the current loader, the parent
        implementation is used instead.
        """
        if not self.is_needed:
            return super()._compute_loss(state, criterion)

        pred = state.output[self.output_key]
        # Targets of the original samples and of their mixing partners.
        y_a = state.input[self.input_key]
        y_b = state.input[self.input_key][self.index]

        loss = self.lam * criterion(pred, y_a) + \
            (1 - self.lam) * criterion(pred, y_b)
        return loss
# Public API of this module: the only name exported by ``import *``.
__all__ = [
    "MixupCallback",
]