
Source code for catalyst.dl.callbacks.mixup

from typing import List

import numpy as np
import torch

from catalyst.dl import CriterionCallback, State


class MixupCallback(CriterionCallback):
    """Callback to do mixup augmentation.

    More details about mixup can be found in the paper
    `mixup: Beyond Empirical Risk Minimization`_.

    .. warning::
        :class:`catalyst.dl.callbacks.MixupCallback` is inherited from
        :class:`catalyst.dl.CriterionCallback` and performs its work as well,
        so do not use the two callbacks together.

    .. _mixup\: Beyond Empirical Risk Minimization:
        https://arxiv.org/abs/1710.09412
    """
    def __init__(
        self,
        input_key: str = "targets",
        output_key: str = "logits",
        fields: List[str] = ("features",),
        alpha=1.0,
        on_train_only=True,
        **kwargs
    ):
        """
        Args:
            fields (List[str]): list of features which must be affected.
            alpha (float): beta distribution a=b parameters.
                Must be >= 0. The closer alpha is to zero,
                the weaker the mixup effect.
            on_train_only (bool): apply mixup on the train loader only.
                Since mixup feeds the model proxy inputs, the targets become
                proxies as well and are not interesting by themselves;
                if ``on_train_only`` is True, the standard output/metric
                is used for validation.
        """
        assert isinstance(input_key, str) and isinstance(output_key, str)
        assert (
            len(fields) > 0
        ), "At least one field for MixupCallback is required"
        assert alpha >= 0, "alpha must be >= 0"

        super().__init__(input_key=input_key, output_key=output_key, **kwargs)

        self.on_train_only = on_train_only
        self.fields = fields
        self.alpha = alpha
        self.lam = 1
        self.index = None
        self.is_needed = True
    def _compute_loss_value(self, state: State, criterion):
        if not self.is_needed:
            return super()._compute_loss_value(state, criterion)

        pred = state.output[self.output_key]
        y_a = state.input[self.input_key]
        y_b = state.input[self.input_key][self.index]

        loss = self.lam * criterion(pred, y_a) + (1 - self.lam) * criterion(
            pred, y_b
        )
        return loss
    def on_loader_start(self, state: State):
        """Loader start hook.

        Args:
            state (State): current state
        """
        self.is_needed = not self.on_train_only or state.is_train_loader
    def on_batch_start(self, state: State):
        """Batch start hook.

        Args:
            state (State): current state
        """
        if not self.is_needed:
            return

        if self.alpha > 0:
            self.lam = np.random.beta(self.alpha, self.alpha)
        else:
            self.lam = 1

        self.index = torch.randperm(state.input[self.fields[0]].shape[0])
        # ``Tensor.to`` is not in-place, so the result must be assigned back
        self.index = self.index.to(state.device)

        for f in self.fields:
            state.input[f] = (
                self.lam * state.input[f]
                + (1 - self.lam) * state.input[f][self.index]
            )
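
Taken together, ``on_batch_start`` and ``_compute_loss_value`` implement the standard mixup recipe: sample a mixing coefficient from Beta(alpha, alpha), mix each requested input field with a permuted copy of the batch, and take the same convex combination of the losses computed against the original and the permuted targets. Below is a minimal standalone sketch of that recipe in plain PyTorch; it is an illustration only, and the tensor and model names are made up, not part of Catalyst.

import numpy as np
import torch
import torch.nn.functional as F

alpha = 1.0
features = torch.randn(16, 10)        # one batch of inputs (the "features" field)
targets = torch.randint(0, 3, (16,))  # integer class targets
model = torch.nn.Linear(10, 3)        # stand-in for the real model

lam = np.random.beta(alpha, alpha)         # mixing coefficient
index = torch.randperm(features.shape[0])  # random pairing within the batch

mixed = lam * features + (1 - lam) * features[index]  # mixed inputs
logits = model(mixed)

# the loss is the same convex combination of the two per-target losses
loss = lam * F.cross_entropy(logits, targets) + (1 - lam) * F.cross_entropy(
    logits, targets[index]
)
loss.backward()
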
__all__ = ["MixupCallback"]
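
A usage sketch, assuming the ``SupervisedRunner`` API of the same Catalyst version (which feeds batches as ``{"features": ..., "targets": ...}``): the callback is passed instead of a plain ``CriterionCallback``, per the warning in the class docstring. The data, model and hyperparameters below are toy placeholders, and the exact wiring of criterion callbacks in ``runner.train`` can differ between Catalyst releases, so check against your installed version.

import torch
from torch.utils.data import DataLoader, TensorDataset

from catalyst.dl import SupervisedRunner
from catalyst.dl.callbacks import MixupCallback

# toy data: 64 samples, 10 features, 3 classes (illustrative only)
X = torch.randn(64, 10)
y = torch.randint(0, 3, (64,))
loaders = {
    "train": DataLoader(TensorDataset(X, y), batch_size=16, shuffle=True),
    "valid": DataLoader(TensorDataset(X, y), batch_size=16),
}

model = torch.nn.Linear(10, 3)

runner = SupervisedRunner()
runner.train(
    model=model,
    criterion=torch.nn.CrossEntropyLoss(),
    optimizer=torch.optim.Adam(model.parameters(), lr=1e-3),
    loaders=loaders,
    callbacks=[
        # used in place of CriterionCallback: mixes "features" and computes the mixed loss
        MixupCallback(fields=["features"], alpha=0.2, on_train_only=True),
    ],
    num_epochs=2,
    logdir="./logs_mixup",
)
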