from typing import List

import numpy as np

import torch

from catalyst.core import IRunner
from catalyst.dl import CriterionCallback

class MixupCallback(CriterionCallback):
    r"""Callback to do mixup augmentation.

    More details about mixup can be found in the paper
    `mixup: Beyond Empirical Risk Minimization`_.

    On every batch of a loader where mixup is active, a mixing coefficient
    ``lam`` is drawn from ``Beta(alpha, alpha)`` and a random permutation of
    the batch is generated. Every tensor listed in ``fields`` is replaced in
    ``runner.input`` by ``lam * x + (1 - lam) * x[permutation]``, and the
    loss is computed as
    ``lam * criterion(pred, y) + (1 - lam) * criterion(pred, y[permutation])``.

    .. warning::
        :class:`catalyst.dl.callbacks.MixupCallback` is inherited from
        :class:`catalyst.dl.CriterionCallback` and does its work.
        You may not use them together.

    .. _mixup\: Beyond Empirical Risk Minimization:
        https://arxiv.org/abs/1710.09412
    """

    def __init__(
        self,
        input_key: str = "targets",
        output_key: str = "logits",
        fields: List[str] = ("features",),
        alpha: float = 1.0,
        on_train_only: bool = True,
        **kwargs,
    ):
        """
        Args:
            input_key (str): key in ``runner.input`` holding the targets
                that are mixed for the loss computation.
            output_key (str): key in ``runner.output`` holding the model
                predictions.
            fields (List[str]): list of ``runner.input`` keys whose tensors
                must be mixed. A single string is treated as a one-element
                list.
            alpha (float): ``a = b`` parameter of the beta distribution the
                mixing coefficient is sampled from. Must be >= 0.
                The closer alpha is to zero, the weaker the mixup effect;
                ``alpha == 0`` disables mixing (the coefficient is fixed
                at 1).
            on_train_only (bool): apply mixup to train loaders only.
                Because mixup feeds proxy inputs, the targets are proxies
                too and are usually not of interest outside training; when
                this is True, non-train loaders use the standard criterion.
            **kwargs: forwarded to :class:`catalyst.dl.CriterionCallback`.
        """
        assert isinstance(input_key, str) and isinstance(output_key, str)
        # A bare string would otherwise be iterated character by character.
        if isinstance(fields, str):
            fields = [fields]
        assert (
            len(fields) > 0
        ), "At least one field for MixupCallback is required"
        assert alpha >= 0, "alpha must be>=0"

        super().__init__(input_key=input_key, output_key=output_key, **kwargs)

        self.on_train_only = on_train_only
        self.fields = fields
        self.alpha = alpha
        # Mixing coefficient and batch permutation for the current batch;
        # both are refreshed in ``on_batch_start``.
        self.lam = 1
        self.index = None
        # Whether mixup is active for the current loader
        # (set in ``on_loader_start``).
        self.is_needed = True

    def _compute_loss_value(self, runner: IRunner, criterion):
        """Compute the (mixup) loss value for the current batch.

        When mixup is inactive for the current loader, the parent
        implementation is used unchanged.

        Args:
            runner (IRunner): current runner
            criterion: loss function, called as ``criterion(pred, target)``

        Returns:
            ``lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)``
            where ``y_b`` are the targets permuted the same way as the
            mixed inputs.
        """
        if not self.is_needed:
            return super()._compute_loss_value(runner, criterion)

        pred = runner.output[self.output_key]
        y_a = runner.input[self.input_key]
        y_b = y_a[self.index]
        loss = self.lam * criterion(pred, y_a) + (1 - self.lam) * criterion(
            pred, y_b
        )
        return loss

    def on_loader_start(self, runner: IRunner):
        """Loader start hook.

        Enables mixup for this loader unless ``on_train_only`` is set and
        the loader is not a train loader.

        Args:
            runner (IRunner): current runner
        """
        self.is_needed = not self.on_train_only or runner.is_train_loader

    def on_batch_start(self, runner: IRunner):
        """Batch start hook.

        Samples the mixing coefficient and batch permutation, then mixes
        every tensor in ``self.fields`` in place in ``runner.input``.

        Args:
            runner (IRunner): current runner
        """
        if not self.is_needed:
            return

        if self.alpha > 0:
            # Store a plain Python float rather than a numpy scalar, so the
            # tensor/loss arithmetic below does not depend on numpy<->torch
            # scalar interop.
            self.lam = float(np.random.beta(self.alpha, self.alpha))
        else:
            self.lam = 1

        # The permutation is shared by all fields and by the targets in
        # ``_compute_loss_value``, so mixed samples stay aligned.
        batch_size = runner.input[self.fields[0]].shape[0]
        self.index = torch.randperm(batch_size)
        for f in self.fields:
            runner.input[f] = (
                self.lam * runner.input[f]
                + (1 - self.lam) * runner.input[f][self.index]
            )
# Names exported by ``from <module> import *``.
__all__ = [
    "MixupCallback",
]