Source code for catalyst.callbacks.batch_overfit

from typing import TYPE_CHECKING
from collections import OrderedDict

from catalyst.core.callback import Callback, CallbackOrder
from import BatchLimitLoaderWrapper

    from catalyst.core.runner import IRunner

[docs]class BatchOverfitCallback(Callback): """Callback to overfit loaders with specified number of batches. By default we use ``1`` batch for loader. Args: kwargs: loader names and their number of batches to overfit. For example, if you have ``train``, ``train_additional``, ``valid`` and ``valid_additional`` loaders and wan't to overfit ``train`` on first 1 batch, ``train_additional`` on first 2 batches, ``valid`` - on first 20% of batches and ``valid_additional`` - on 50% batches: .. code-block:: python from catalyst.dl import SupervisedRunner, BatchOverfitCallback runner = SupervisedRunner() runner.train( ... loaders={ "train": ..., "train_additional": ..., "valid": ..., "valid_additional":... } ... callbacks=[ ... BatchOverfitCallback( train_additional=2, valid=0.2, valid_additional=0.5 ), ... ] ... ) Minimal working example .. code-block:: python import torch from import DataLoader, TensorDataset from catalyst import dl # data num_samples, num_features = int(1e4), int(1e1) X, y = torch.rand(num_samples, num_features), torch.rand(num_samples) dataset = TensorDataset(X, y) loader = DataLoader(dataset, batch_size=32, num_workers=1) loaders = {"train": loader, "valid": loader} # model, criterion, optimizer, scheduler model = torch.nn.Linear(num_features, 1) criterion = torch.nn.MSELoss() optimizer = torch.optim.Adam(model.parameters()) scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [3, 6]) # model training runner = dl.SupervisedRunner() runner.train( model=model, criterion=criterion, optimizer=optimizer, scheduler=scheduler, loaders=loaders, logdir="./logdir", num_epochs=8, verbose=True, callbacks=[dl.BatchOverfitCallback(train=10, valid=0.5)] ) """ def __init__(self, **kwargs): """Init.""" super().__init__(order=CallbackOrder.internal) self.loader_batches = {} for loader, num_batches in kwargs.items(): if not isinstance(num_batches, (int, float)): raise TypeError( "Expected loader num_batches type is int/float " f"but got {type(num_batches)}" ) self.loader_batches[loader] = num_batches def on_epoch_start(self, runner: "IRunner") -> None: """Wraps loaders for current epoch. If number-of-batches for loader is not provided then the first batch from loader will be used for overfitting. Args: runner: current runner """ epoch_loaders = OrderedDict() for name, loader in runner.loaders.items(): num_batches = self.loader_batches.get(name, 1) if isinstance(num_batches, float): num_batches = int(len(loader) * num_batches) epoch_loaders[name] = BatchLimitLoaderWrapper( loader=loader, num_batches=num_batches ) runner.loaders = epoch_loaders def on_epoch_end(self, runner: "IRunner"): """Unwraps loaders for current epoch. Args: runner: current runner """ runner.loaders = { key: value.origin if isinstance(value, BatchLimitLoaderWrapper) else value for key, value in runner.loaders.items() }
__all__ = ["BatchOverfitCallback"]