Source code for catalyst.callbacks.scheduler

from typing import List, Optional, Tuple, TYPE_CHECKING, Union
from abc import ABC, abstractmethod

import torch

from catalyst.contrib.nn.schedulers import BatchScheduler, OneCycleLRWithWarmup
from catalyst.core.callback import Callback, CallbackNode, CallbackOrder
from catalyst.utils.misc import get_attr
from catalyst.utils.torch import (

    from catalyst.core.runner import IRunner

class ISchedulerCallback(Callback):
    """Scheduler callback interface, an abstraction over the scheduler step.

    The class declares no members of its own; it only provides a common
    base type for callbacks that step a scheduler.
    """
class SchedulerCallback(ISchedulerCallback):
    """Callback for wrapping schedulers.

    Notebook API example:

    .. code-block:: python

        import torch
        from torch.utils.data import DataLoader, TensorDataset
        from catalyst.dl import (
            SupervisedRunner, AccuracyCallback,
            CriterionCallback, SchedulerCallback,
        )

        num_samples, num_features = 10_000, 10
        n_classes = 10
        X = torch.rand(num_samples, num_features)
        y = torch.randint(0, n_classes, [num_samples])
        loader = DataLoader(TensorDataset(X, y), batch_size=32, num_workers=1)
        loaders = {"train": loader, "valid": loader}

        model = torch.nn.Linear(num_features, n_classes)
        criterion = torch.nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam(model.parameters())
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [3, 6])

        runner = SupervisedRunner()
        runner.train(
            model=model,
            criterion=criterion,
            optimizer=optimizer,
            scheduler=scheduler,
            loaders=loaders,
            logdir="./logdir",
            num_epochs=5,
            verbose=False,
            main_metric="accuracy03",
            minimize_metric=False,
            callbacks=[
                AccuracyCallback(
                    accuracy_args=[1, 3, 5]
                ),
                SchedulerCallback(reduced_metric="loss")
            ]
        )

    Config API usage example:

    .. code-block:: yaml

        stages:
          ...
          scheduler_params:
            scheduler: MultiStepLR
            milestones: [1]
            gamma: 0.3
          ...
          stage_N:
            ...
            callbacks_params:
              ...
              scheduler:
                callback: SchedulerCallback
                # arguments for SchedulerCallback
                reduced_metric: loss
          ...
    """

    def __init__(
        self,
        scheduler_key: Optional[str] = None,
        mode: Optional[str] = None,
        reduced_metric: Optional[str] = None,
    ):
        """
        Args:
            scheduler_key: key used as ``inner_key`` when the scheduler
                is fetched from the runner at stage start; it is also
                appended to the logged metric names
                (``lr/<key>``, ``momentum/<key>``). Default is ``None``.
            mode: scheduler mode, should be one of
                ``"epoch"`` or ``"batch"``, default is ``None``.
                If ``None``, the mode is picked at stage start:
                ``"batch"`` when the scheduler is an instance of
                ``BatchScheduler``, otherwise ``"epoch"``.
            reduced_metric: metric name to forward to scheduler
                object, if ``None`` then will be used main metric
                specified in experiment.
        """
        super().__init__(order=CallbackOrder.scheduler, node=CallbackNode.all)
        self.scheduler_key = scheduler_key
        self.mode = mode
        self.reduced_metric = reduced_metric

    @staticmethod
    def _scheduler_step(
        scheduler, reduced_metric=None,
    ):
        """Step the scheduler and read back the optimizer state.

        Args:
            scheduler: scheduler to step
            reduced_metric: value passed to ``scheduler.step`` when the
                scheduler is a ``ReduceLROnPlateau``; ignored otherwise

        Returns:
            tuple ``(lr_list, momentum_list)`` with one entry per
            optimizer param group
        """
        if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
            scheduler.step(reduced_metric)
        else:
            scheduler.step()

        lr_list = [
            param_group["lr"]
            for param_group in scheduler.optimizer.param_groups
        ]
        momentum_list = get_optimizer_momentum_list(scheduler.optimizer)
        return lr_list, momentum_list

    def _update_lr_and_momentum_in_metrics_dict(
        self,
        metrics_dict: dict,
        lr_list: List[float],
        momentum_list: List[Union[float, None]],
    ):
        """Update learning rate and momentum in metrics_dict
        (consider only 0-th param group)

        Args:
            metrics_dict: batch_metrics or epoch_metrics
            lr_list: lr for each param group
            momentum_list: momentum for each param group
        """
        # todo: consider saving lr and momentum for all param groups ?
        lr = lr_list[0]
        momentum = momentum_list[0]

        lr_key = (
            f"lr/{self.scheduler_key}"
            if self.scheduler_key is not None
            else "lr"
        )
        metrics_dict[lr_key] = lr

        # momentum is logged only when the optimizer reports one
        if momentum is not None:
            momentum_key = (
                f"momentum/{self.scheduler_key}"
                if self.scheduler_key is not None
                else "momentum"
            )
            metrics_dict[momentum_key] = momentum

    def step_batch(self, runner: "IRunner") -> None:
        """Perform scheduler step and update batch metrics

        Args:
            runner: current runner
        """
        lr_list, momentum_list = self._scheduler_step(
            scheduler=self._scheduler
        )
        self._update_lr_and_momentum_in_metrics_dict(
            runner.batch_metrics, lr_list, momentum_list
        )

    def step_epoch(self, runner: "IRunner") -> None:
        """Perform scheduler step and update epoch metrics

        Args:
            runner: current runner
        """
        # Only ReduceLROnPlateau consumes the metric, so it is looked up
        # only for that scheduler type. Other schedulers must not fail
        # with KeyError just because the metric was not computed.
        if isinstance(
            self._scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau
        ):
            reduced_metric = runner.valid_metrics[self.reduced_metric]
        else:
            reduced_metric = None

        lr_list, momentum_list = self._scheduler_step(
            scheduler=self._scheduler, reduced_metric=reduced_metric
        )
        self._update_lr_and_momentum_in_metrics_dict(
            runner.epoch_metrics, lr_list, momentum_list
        )

    def on_stage_start(self, runner: "IRunner") -> None:
        """Stage start hook.

        Resolves the reduced metric name, fetches the scheduler from the
        runner and picks the stepping mode if none was given.

        Args:
            runner: current runner
        """
        self.reduced_metric = self.reduced_metric or runner.main_metric

        scheduler = get_attr(
            runner, key="scheduler", inner_key=self.scheduler_key
        )
        assert scheduler is not None
        self._scheduler = scheduler

        if self.mode is None:
            if isinstance(scheduler, BatchScheduler):
                self.mode = "batch"
            else:
                self.mode = "epoch"

        if (
            isinstance(scheduler, OneCycleLRWithWarmup)
            and self.mode == "batch"
        ):
            scheduler.reset()
        assert self.mode is not None

    def on_loader_start(self, runner: "IRunner") -> None:
        """Loader start hook.

        For a batch-mode ``OneCycleLRWithWarmup`` on the train loader,
        recalculates the schedule with the loader length and
        ``runner.epoch - 1`` as the current step.

        Args:
            runner: current runner
        """
        if (
            runner.is_train_loader
            and isinstance(self._scheduler, OneCycleLRWithWarmup)
            and self.mode == "batch"
        ):
            self._scheduler.recalculate(
                loader_len=runner.loader_len, current_step=runner.epoch - 1
            )

    def on_batch_end(self, runner: "IRunner") -> None:
        """Batch end hook.

        Steps the scheduler in ``"batch"`` mode, on the train loader only.

        Args:
            runner: current runner
        """
        if runner.is_train_loader and self.mode == "batch":
            self.step_batch(runner=runner)

    def on_epoch_end(self, runner: "IRunner") -> None:
        """Epoch end hook.

        Steps the scheduler in ``"epoch"`` mode.

        Args:
            runner: current runner
        """
        if self.mode == "epoch":
            self.step_epoch(runner=runner)
class ILRUpdater(ABC, Callback):
    """Base class for callbacks that set learning rate and momentum directly.

    Subclasses supply ``calc_lr`` and ``calc_momentum``; this class applies
    the returned values to the optimizer and logs them to batch metrics.
    """

    def __init__(self, optimizer_key: str = None):
        """
        Args:
            optimizer_key: which optimizer key to use
                for learning rate scheduling
        """
        super().__init__(order=CallbackOrder.scheduler, node=CallbackNode.all)
        self.init_lr = 0
        self.optimizer_key = optimizer_key

    @abstractmethod
    def calc_lr(self):
        """Interface for calculating learning rate."""

    @abstractmethod
    def calc_momentum(self):
        """Interface for calculating momentum"""

    @staticmethod
    def _update_lr(optimizer, new_lr) -> None:
        """Write ``new_lr`` into every param group of ``optimizer``."""
        for group in optimizer.param_groups:
            group["lr"] = new_lr

    @staticmethod
    def _update_momentum(optimizer, new_momentum) -> None:
        """Write ``new_momentum`` into every param group of ``optimizer``.

        When the first param group has a ``"betas"`` entry, the first
        element of each group's ``betas`` is replaced; otherwise the
        ``"momentum"`` entry is set.
        """
        groups = optimizer.param_groups
        # the first group decides which field holds the momentum
        uses_betas = "betas" in groups[0]
        for group in groups:
            if uses_betas:
                group["betas"] = (new_momentum, group["betas"][1])
            else:
                group["momentum"] = new_momentum

    def _update_optimizer(self, optimizer) -> Tuple[float, float]:
        """Compute new lr/momentum, apply them, and return what was used.

        A ``None`` learning rate leaves the optimizer's lr untouched.
        A ``None`` momentum leaves the optimizer untouched and the
        optimizer's current momentum is returned instead.
        """
        lr = self.calc_lr()
        if lr is not None:
            self._update_lr(optimizer, lr)

        momentum = self.calc_momentum()
        if momentum is None:
            momentum = get_optimizer_momentum(optimizer)
        else:
            self._update_momentum(optimizer, momentum)

        return lr, momentum

    def update_optimizer(self, runner: "IRunner") -> None:
        """Update learning rate and momentum in runner.

        Args:
            runner: current runner
        """
        if self.optimizer_key is None:
            lr_key, momentum_key = "lr", "momentum"
        else:
            lr_key = f"lr_{self.optimizer_key}"
            momentum_key = f"momentum_{self.optimizer_key}"

        lr, momentum = self._update_optimizer(optimizer=self._optimizer)
        runner.batch_metrics[lr_key] = lr
        runner.batch_metrics[momentum_key] = momentum

    def on_stage_start(self, runner: "IRunner") -> None:
        """Stage start hook.

        Fetches the optimizer from the runner and remembers its default
        learning rate as ``init_lr``.

        Args:
            runner: current runner
        """
        found = get_attr(
            runner, key="optimizer", inner_key=self.optimizer_key
        )
        assert found is not None
        self._optimizer = found
        self.init_lr = found.defaults["lr"]

    def on_loader_start(self, runner: "IRunner") -> None:
        """Loader start hook.

        Args:
            runner: current runner
        """
        if not runner.is_train_loader:
            return
        self.update_optimizer(runner=runner)

    def on_batch_end(self, runner: "IRunner") -> None:
        """Batch end hook.

        Args:
            runner: current runner
        """
        if not runner.is_train_loader:
            return
        self.update_optimizer(runner=runner)
class LRFinder(ILRUpdater):
    """
    Helps you find an optimal learning rate for a model, as per suggestion of
    `Cyclical Learning Rates for Training Neural Networks`_ paper.
    Learning rate is increased in linear or log scale, depending on user input.

    See `How Do You Find A Good Learning Rate`_ article for details.

    .. _Cyclical Learning Rates for Training Neural Networks:
        https://arxiv.org/abs/1506.01186
    .. _How Do You Find A Good Learning Rate:
        https://sgugger.github.io/how-do-you-find-a-good-learning-rate.html
    """

    def __init__(
        self,
        final_lr,
        scale: str = "log",
        num_steps: Optional[int] = None,
        optimizer_key: str = None,
    ):
        """
        Args:
            final_lr: final learning rate to try with
            scale: learning rate increasing scale ("log" or "linear")
            num_steps: number of batches to try;
                if None - whole loader would be used.
            optimizer_key: which optimizer key to use
                for learning rate scheduling

        Raises:
            ValueError: if ``scale`` is neither "log" nor "linear"
        """
        super().__init__(optimizer_key=optimizer_key)

        self.final_lr = final_lr
        self.scale = scale
        self.num_steps = num_steps
        # multiplier / lr_step are filled in at train-loader start, once
        # the initial lr and the number of steps are known
        self.multiplier = 0
        self.lr_step = 0
        self.find_iter = 0

        self._calc_lr = None
        if scale == "log":
            self._calc_lr = self._calc_lr_log
        elif scale == "linear":
            self._calc_lr = self._calc_lr_linear
        else:
            # ValueError (a subclass of Exception) names the bad argument
            # instead of raising an anonymous bare Exception.
            raise ValueError(
                f"Not supported scale: {scale!r}, "
                'expected "log" or "linear"'
            )

    def _calc_lr_log(self):
        """Geometric growth: ``init_lr * multiplier ** find_iter``."""
        return self.init_lr * self.multiplier ** self.find_iter

    def _calc_lr_linear(self):
        """Arithmetic growth: ``init_lr + lr_step * find_iter``."""
        return self.init_lr + self.lr_step * self.find_iter

    def calc_lr(self):
        """Calculates learning rate.

        Each call also advances the internal iteration counter.

        Returns:
            learning rate.
        """
        res = self._calc_lr()
        self.find_iter += 1
        return res

    def calc_momentum(self):
        """Calculates new momentum.

        Always returns ``None``, i.e. the finder never overrides momentum.
        """
        pass

    def on_loader_start(self, runner: "IRunner"):
        """Loader start hook. Updates scheduler statistics.

        On the train loader, derives the per-step multiplier (log scale)
        and increment (linear scale) from ``init_lr``, ``final_lr`` and
        the number of steps, defaulting the latter to the loader length.

        Args:
            runner: current runner
        """
        if runner.is_train_loader:
            lr_step = self.final_lr / self.init_lr
            self.num_steps = self.num_steps or runner.loader_len
            self.multiplier = lr_step ** (1 / self.num_steps)
            self.lr_step = (self.final_lr - self.init_lr) / self.num_steps

        super().on_loader_start(runner=runner)

    def on_batch_end(self, runner: "IRunner"):
        """Batch end hook. Make scheduler step and stops iterating if needed.

        Args:
            runner: current runner

        Raises:
            NotImplementedError: at the end of LRFinder
        """
        super().on_batch_end(runner=runner)
        # the exception is the stop signal once all steps have been tried
        if self.find_iter > self.num_steps:
            raise NotImplementedError("End of LRFinder")
# Public names exported by this module.
__all__ = [
    "ISchedulerCallback",
    "SchedulerCallback",
    "ILRUpdater",
    "LRFinder",
]