# Source code for catalyst.dl.callbacks.mixup
from typing import List # isort:skip
import numpy as np
import torch
from catalyst.dl.callbacks import CriterionCallback
from catalyst.dl.core.state import RunnerState
class MixupCallback(CriterionCallback):
    """
    Callback to do mixup augmentation.

    Paper: https://arxiv.org/abs/1710.09412

    Note:
        MixupCallback inherits from CriterionCallback and computes the
        (mixed) loss itself, so do not use the two callbacks together.
    """

    def __init__(
        self,
        fields: List[str] = ("features", ),
        alpha=1.0,
        on_train_only=True,
        **kwargs
    ):
        """
        Args:
            fields (List[str]): keys of ``state.input`` to be mixed.
                The first key also determines the batch size.
            alpha (float): the shared ``a = b`` parameter of the beta
                distribution that the mixing coefficient is drawn from.
                Must be >= 0. ``alpha == 0`` disables mixing (the
                coefficient is fixed to 1); the closer alpha is to zero,
                the weaker the mixup effect.
            on_train_only (bool): apply mixup only on loaders whose name
                starts with ``"train"``. Mixup produces proxy inputs and
                therefore proxy targets, which are not meaningful for
                validation, so when True the other loaders use the
                standard CriterionCallback loss.
            **kwargs: forwarded to ``CriterionCallback.__init__``.
        """
        assert len(fields) > 0, \
            "At least one field for MixupCallback is required"
        assert alpha >= 0, "alpha must be>=0"

        super().__init__(**kwargs)

        self.on_train_only = on_train_only
        self.fields = fields
        self.alpha = alpha
        # Mixing coefficient and batch permutation of the current batch.
        self.lam = 1
        self.index = None
        # Whether mixup is active for the current loader.
        self.is_needed = True

    def on_loader_start(self, state: RunnerState):
        """Activate mixup for this loader, honouring ``on_train_only``."""
        self.is_needed = not self.on_train_only or \
            state.loader_name.startswith("train")

    def on_batch_start(self, state: RunnerState):
        """
        Draw a mixing coefficient and a batch permutation, then replace
        every field in ``self.fields`` with
        ``lam * x + (1 - lam) * x[index]``.
        """
        if not self.is_needed:
            return

        if self.alpha > 0:
            self.lam = np.random.beta(self.alpha, self.alpha)
        else:
            self.lam = 1

        batch_size = state.input[self.fields[0]].shape[0]
        # ``Tensor.to`` is not in-place: it returns a new tensor, so the
        # result has to be assigned for the index to live on state.device.
        self.index = torch.randperm(batch_size).to(state.device)

        for f in self.fields:
            state.input[f] = self.lam * state.input[f] + \
                (1 - self.lam) * state.input[f][self.index]

    def _compute_loss(self, state: RunnerState, criterion):
        """
        Return ``lam * criterion(pred, y) + (1 - lam) * criterion(pred,
        y[index])``, using the coefficient and permutation drawn in
        ``on_batch_start``. Falls back to the parent implementation when
        mixup is inactive for the current loader.

        ``self.input_key`` / ``self.output_key`` are presumably set by
        ``CriterionCallback.__init__`` (not visible here).
        """
        if not self.is_needed:
            return super()._compute_loss(state, criterion)

        pred = state.output[self.output_key]
        y_a = state.input[self.input_key]
        y_b = state.input[self.input_key][self.index]

        loss = self.lam * criterion(pred, y_a) + \
            (1 - self.lam) * criterion(pred, y_b)
        return loss
# Public API of this module.
__all__ = [
    "MixupCallback",
]