Source code for catalyst.callbacks.batch_overfit

from typing import TYPE_CHECKING
from collections import OrderedDict

from catalyst.core.callback import Callback, CallbackOrder
from catalyst.data.loader import BatchLimitLoaderWrapper

if TYPE_CHECKING:
    from catalyst.core.runner import IRunner


class BatchOverfitCallback(Callback):
    """Callback for overfitting loaders with a specified number of batches.
    By default, ``1`` batch per loader is used.

    For example, if you have ``train``, ``train_additional``, ``valid`` and
    ``valid_additional`` loaders and want to overfit ``train`` on the first
    1 batch, ``train_additional`` on the first 2 batches, ``valid`` on the
    first 20% of batches and ``valid_additional`` on the first 50% of
    batches:

    .. code-block:: python

        from catalyst.dl import SupervisedRunner, BatchOverfitCallback

        runner = SupervisedRunner()
        runner.train(
            ...
            loaders={
                "train": ...,
                "train_additional": ...,
                "valid": ...,
                "valid_additional": ...,
            },
            ...
            callbacks=[
                ...
                BatchOverfitCallback(
                    train_additional=2,
                    valid=0.2,
                    valid_additional=0.5,
                ),
                ...
            ],
            ...
        )

    Minimal working example:

    .. code-block:: python

        import torch
        from torch.utils.data import DataLoader, TensorDataset
        from catalyst import dl

        # data
        num_samples, num_features = int(1e4), int(1e1)
        X, y = torch.rand(num_samples, num_features), torch.rand(num_samples)
        dataset = TensorDataset(X, y)
        loader = DataLoader(dataset, batch_size=32, num_workers=1)
        loaders = {"train": loader, "valid": loader}

        # model, criterion, optimizer, scheduler
        model = torch.nn.Linear(num_features, 1)
        criterion = torch.nn.MSELoss()
        optimizer = torch.optim.Adam(model.parameters())
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [3, 6])

        # model training
        runner = dl.SupervisedRunner()
        runner.train(
            model=model,
            criterion=criterion,
            optimizer=optimizer,
            scheduler=scheduler,
            loaders=loaders,
            logdir="./logdir",
            num_epochs=8,
            verbose=True,
            callbacks=[dl.BatchOverfitCallback(train=10, valid=0.5)],
        )
    """
    def __init__(self, **kwargs):
        """
        Args:
            kwargs: loader names and their number of batches to overfit.
        """
        super().__init__(order=CallbackOrder.internal)

        self.loader_batches = {}
        for loader, num_batches in kwargs.items():
            if not isinstance(num_batches, (int, float)):
                raise TypeError(
                    "Expected loader num_batches to be int or float, "
                    f"but got {type(num_batches)}"
                )
            self.loader_batches[loader] = num_batches
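
    # Example (hypothetical loader names): ``BatchOverfitCallback(train=10,
    # valid=0.5)`` stores ``{"train": 10, "valid": 0.5}``; a non-numeric
    # value such as ``train="10"`` raises the ``TypeError`` above.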
    def on_epoch_start(self, runner: "IRunner") -> None:
        """
        Wraps loaders for the current epoch.
        If a number of batches is not specified for a loader,
        the first batch of that loader will be used for overfitting.

        Args:
            runner: current runner
        """
        epoch_loaders = OrderedDict()

        for name, loader in runner.loaders.items():
            num_batches = self.loader_batches.get(name, 1)
            if isinstance(num_batches, float):
                num_batches = int(len(loader) * num_batches)
            epoch_loaders[name] = BatchLimitLoaderWrapper(
                loader=loader, num_batches=num_batches
            )

        runner.loaders = epoch_loaders
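
    # Resolution example (hypothetical sizes): for a loader of 100 batches,
    # ``valid=0.2`` resolves to ``int(100 * 0.2) == 20`` batches, while an
    # integer value such as ``train=10`` is passed through unchanged.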
    def on_epoch_end(self, runner: "IRunner") -> None:
        """
        Unwraps loaders for the current epoch.

        Args:
            runner: current runner
        """
        runner.loaders = {
            key: value.origin
            if isinstance(value, BatchLimitLoaderWrapper)
            else value
            for key, value in runner.loaders.items()
        }
__all__ = ["BatchOverfitCallback"]
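

# A minimal standalone sketch (hypothetical loader size) of what the callback
# does around each epoch: resolve a float fraction to a batch count, wrap the
# loader, and recover the original loader afterwards via ``.origin``.
if __name__ == "__main__":
    import torch
    from torch.utils.data import DataLoader, TensorDataset

    # hypothetical data: 320 samples with batch_size=32 -> len(loader) == 10
    dataset = TensorDataset(torch.rand(320, 4), torch.rand(320))
    loader = DataLoader(dataset, batch_size=32)

    # mirror on_epoch_start: the fraction 0.2 resolves to int(10 * 0.2) == 2
    limited = BatchLimitLoaderWrapper(
        loader=loader, num_batches=int(len(loader) * 0.2)
    )

    # mirror on_epoch_end: the wrapper keeps a reference to the origin loader
    assert limited.origin is loader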