Shortcuts

Source code for catalyst.contrib.dl.callbacks.cutmix_callback

from typing import List

import numpy as np

import torch

from catalyst.core.callbacks import CriterionCallback
from catalyst.core.runner import IRunner


[docs]class CutmixCallback(CriterionCallback): """ Callback to do Cutmix augmentation that has been proposed in `CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features`_. .. warning:: `catalyst.contrib.dl.callbacks.CutmixCallback` is inherited from `catalyst.dl.CriterionCallback` and does its work. You may not use them together. .. _CutMix\: Regularization Strategy to Train Strong Classifiers with Localizable Features: https://arxiv.org/abs/1905.04899 # noqa: W605, E501, W505 """
[docs] def __init__( self, fields: List[str] = ("features",), alpha=1.0, on_train_only=True, **kwargs ): """ Args: fields (List[str]): list of features which must be affected. alpha (float): beta distribution parameter. on_train_only (bool): Apply to train only. So, if on_train_only is True, use a standard output/metric for validation. """ assert ( len(fields) > 0 ), "At least one field for CutmixCallback is required" assert alpha >= 0, "alpha must be >=0" super().__init__(**kwargs) self.on_train_only = on_train_only self.fields = fields self.alpha = alpha self.lam = 1 self.index = None self.is_needed = True
def _compute_loss(self, runner: IRunner, criterion): """Computes loss. If self.is_needed is ``False`` then calls ``_compute_loss`` from ``CriterionCallback``, otherwise computes loss value. Args: runner (IRunner): current runner criterion: that is used to compute loss """ if not self.is_needed: return super()._compute_loss_value(runner, criterion) pred = runner.output[self.output_key] y_a = runner.input[self.input_key] y_b = runner.input[self.input_key][self.index] loss = self.lam * criterion(pred, y_a) + (1 - self.lam) * criterion( pred, y_b ) return loss def _rand_bbox(self, size, lam): """ Generates top-left and bottom-right coordinates of the box of the given size. Args: size: size of the box lam: lambda parameter Returns: top-left and bottom-right coordinates of the box """ w = size[2] h = size[3] cut_rat = np.sqrt(1.0 - lam) cut_w = np.int(w * cut_rat) cut_h = np.int(h * cut_rat) cx = np.random.randint(w) cy = np.random.randint(h) bbx1 = np.clip(cx - cut_w // 2, 0, w) bby1 = np.clip(cy - cut_h // 2, 0, h) bbx2 = np.clip(cx + cut_w // 2, 0, w) bby2 = np.clip(cy + cut_h // 2, 0, h) return bbx1, bby1, bbx2, bby2
[docs] def on_loader_start(self, runner: IRunner) -> None: """Checks if it is needed for the loader. Args: runner (IRunner): current runner """ self.is_needed = not self.on_train_only or runner.is_train_loader
[docs] def on_batch_start(self, runner: IRunner) -> None: """Mixes data according to Cutmix algorithm. Args: runner (IRunner): current runner """ if not self.is_needed: return if self.alpha > 0: self.lam = np.random.beta(self.alpha, self.alpha) else: self.lam = 1 self.index = torch.randperm(runner.input[self.fields[0]].shape[0]) self.index.to(runner.device) bbx1, bby1, bbx2, bby2 = self._rand_bbox( runner.input[self.fields[0]].shape, self.lam ) for f in self.fields: runner.input[f][:, :, bbx1:bbx2, bby1:bby2] = runner.input[f][ self.index, :, bbx1:bbx2, bby1:bby2 ] self.lam = 1 - ( (bbx2 - bbx1) * (bby2 - bby1) / ( runner.input[self.fields[0]].shape[-1] * runner.input[self.fields[0]].shape[-2] ) )
__all__ = ["CutmixCallback"]