Shortcuts

Source code for catalyst.utils.mixup

from typing import List

import numpy as np
import torch


[docs]def mixup_batch( batch: List[torch.Tensor], alpha: float = 0.2, mode: str = "replace" ) -> List[torch.Tensor]: """ Args: batch: batch to which you want to apply augmentation alpha: beta distribution a=b parameters. Must be >=0. The closer alpha to zero the less effect of the mixup. mode: algorithm used for muxup: ``"replace"`` | ``"add"``. If "replace" then replaces the batch with a mixed one, while the batch size is not changed If "add", concatenates mixed examples to the current ones, the batch size increases by 2 times. Returns: augmented batch """ assert alpha >= 0, "alpha must be>=0" assert mode in ("add", "replace"), f"mode must be in 'add', 'replace', get: {mode}" batch_size = batch[0].shape[0] beta = np.random.beta(alpha, alpha, batch_size).astype(np.float32) indexes = np.arange(batch_size) # index shift by 1 indexes_2 = (indexes + 1) % batch_size for idx, targets in enumerate(batch): device = targets.device targets_shape = [batch_size] + [1] * len(targets.shape[1:]) key_beta = torch.as_tensor(beta.reshape(targets_shape), device=device) targets = targets * key_beta + targets[indexes_2] * (1 - key_beta) if mode == "replace": batch[idx] = targets else: # mode == 'add' batch[idx] = torch.cat([batch[idx], targets]) return batch