Source code for catalyst.dl.callbacks.scheduler

import safitty

import torch

from catalyst.contrib.schedulers import BatchScheduler, OneCycleLRWithWarmup
from catalyst.dl.core import Callback, CallbackOrder, RunnerState
from catalyst.utils import get_optimizer_momentum

class SchedulerCallback(Callback):
    """Steps a learning-rate scheduler and records lr/momentum in the state.

    The scheduler is fetched from the runner state under the ``"scheduler"``
    key (optionally narrowed by ``scheduler_key``). After every step the
    resulting learning rate and momentum are written back to the state under
    the ``"lr"`` and ``"momentum"`` keys.
    """

    def __init__(
        self,
        scheduler_key: str = None,
        mode: str = None,
        reduce_metric: str = "loss"
    ):
        """
        Args:
            scheduler_key: inner key used to select the scheduler from
                the state (and to store lr/momentum back)
            mode: ``"batch"`` to step after every training batch,
                ``"epoch"`` to step after every epoch; if None, it is
                chosen in ``on_stage_start`` from the scheduler type
            reduce_metric: name of the validation metric passed to
                ``ReduceLROnPlateau.step``
        """
        super().__init__(CallbackOrder.Scheduler)
        self.scheduler_key = scheduler_key
        self.mode = mode
        self.reduce_metric = reduce_metric

    def step(self, state: RunnerState):
        """Perform one scheduler step and store the new lr/momentum."""
        scheduler = state.get_key(
            key="scheduler", inner_key=self.scheduler_key
        )
        valid_metric = safitty.get(
            state.metrics.valid_values, self.reduce_metric
        )

        lr, momentum = self._scheduler_step(
            scheduler=scheduler, valid_metric=valid_metric
        )

        state.set_key(lr, key="lr", inner_key=self.scheduler_key)
        state.set_key(momentum, key="momentum", inner_key=self.scheduler_key)

    def on_stage_start(self, state: RunnerState):
        """Resolve the stepping mode and reset one-cycle schedulers."""
        scheduler = state.get_key(
            key="scheduler", inner_key=self.scheduler_key
        )
        assert scheduler is not None

        # Infer the mode from the scheduler type when the user gave none.
        if self.mode is None:
            if isinstance(scheduler, BatchScheduler):
                self.mode = "batch"
            else:
                self.mode = "epoch"

        if (
            isinstance(scheduler, OneCycleLRWithWarmup)
            and self.mode == "batch"
        ):
            scheduler.reset()

    def on_loader_start(self, state: RunnerState):
        """Re-plan a batch-mode one-cycle schedule for the train loader."""
        scheduler = state.get_key(
            key="scheduler", inner_key=self.scheduler_key
        )
        if (
            state.loader_name.startswith("train")
            and isinstance(scheduler, OneCycleLRWithWarmup)
            and self.mode == "batch"
        ):
            scheduler.recalculate(
                loader_len=state.loader_len,
                current_step=state.stage_epoch
            )

    def on_batch_end(self, state):
        """Step a batch-mode scheduler after each *training* batch."""
        # Only train loaders advance the schedule: ``on_loader_start`` sizes
        # the one-cycle schedule from the train loader length alone, so
        # stepping on validation/inference batches would push the scheduler
        # past its planned position.
        if self.mode == "batch" and state.loader_name.startswith("train"):
            self.step(state=state)

    def on_epoch_end(self, state):
        """Step an epoch-mode scheduler once the epoch is over."""
        if self.mode == "epoch":
            self.step(state=state)

    @staticmethod
    def _scheduler_step(
        scheduler,
        valid_metric=None,
    ):
        """Step ``scheduler`` and return the resulting ``(lr, momentum)``.

        ``ReduceLROnPlateau`` is stepped with ``valid_metric`` and its
        learning rate is read from the first param group of its optimizer;
        any other scheduler is stepped without arguments and queried through
        ``get_lr()``.
        """
        if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
            scheduler.step(valid_metric)
            lr = safitty.get(scheduler.optimizer.param_groups, 0, "lr")
        else:
            scheduler.step()
            lr = scheduler.get_lr()[0]

        momentum = get_optimizer_momentum(scheduler.optimizer)

        return lr, momentum
class LRUpdater(Callback):
    """Base class for callbacks that set lr/momentum on an optimizer.

    Subclasses override ``calc_lr`` and/or ``calc_momentum``; a ``None``
    result means "leave the optimizer's current value alone".
    """

    def __init__(self, optimizer_key: str = None):
        """
        Args:
            optimizer_key: which optimizer key to use
                for learning rate scheduling
        """
        super().__init__(CallbackOrder.Scheduler)
        self.init_lr = 0
        self.optimizer_key = optimizer_key

    def calc_lr(self):
        """Return the next learning rate; the base class returns None."""
        return None

    def calc_momentum(self):
        """Return the next momentum; the base class returns None."""
        return None

    @staticmethod
    def _update_lr(optimizer, new_lr):
        """Write ``new_lr`` into every param group of ``optimizer``."""
        for group in optimizer.param_groups:
            group["lr"] = new_lr

    @staticmethod
    def _update_momentum(optimizer, new_momentum):
        """Write ``new_momentum`` into every param group of ``optimizer``.

        When the first param group has a ``"betas"`` entry, the first beta
        is replaced; otherwise the ``"momentum"`` entry is set.
        """
        groups = optimizer.param_groups
        has_betas = "betas" in groups[0]
        for group in groups:
            if has_betas:
                group["betas"] = (new_momentum, group["betas"][1])
            else:
                group["momentum"] = new_momentum

    def _update_optimizer(self, optimizer):
        """Apply the calculated lr/momentum and return them as a pair."""
        lr = self.calc_lr()
        if lr is not None:
            self._update_lr(optimizer, lr)

        momentum = self.calc_momentum()
        if momentum is None:
            # Nothing to set: report what the optimizer currently uses.
            momentum = get_optimizer_momentum(optimizer)
        else:
            self._update_momentum(optimizer, momentum)

        return lr, momentum

    def update_optimizer(self, state):
        """Update the optimizer and store lr/momentum in the state.

        Does nothing unless ``state.need_backward`` is truthy.
        """
        if state.need_backward:
            optimizer = state.get_key(
                key="optimizer", inner_key=self.optimizer_key
            )
            lr, momentum = self._update_optimizer(optimizer=optimizer)
            state.set_key(lr, key="lr", inner_key=self.optimizer_key)
            state.set_key(
                momentum, key="momentum", inner_key=self.optimizer_key
            )

    def on_stage_start(self, state):
        """Remember the optimizer's default lr as the starting point."""
        optimizer = state.get_key(
            key="optimizer", inner_key=self.optimizer_key
        )
        self.init_lr = optimizer.defaults["lr"]

    def on_loader_start(self, state):
        """Update the optimizer at loader start when backward is needed."""
        if not state.need_backward:
            return
        self.update_optimizer(state=state)

    def on_batch_end(self, state):
        """Update the optimizer after each batch when backward is needed."""
        if not state.need_backward:
            return
        self.update_optimizer(state=state)
class LRFinder(LRUpdater):
    """
    Helps you find an optimal learning rate for a model,
    as per suggestion of 2015 CLR paper.
    Learning rate is increased in linear or log scale, depending on user input.
    """

    def __init__(
        self,
        final_lr,
        scale="log",
        num_steps=None,
        optimizer_key=None
    ):
        """
        Args:
            final_lr: final learning rate to try with
            scale: learning rate increasing scale ("log" or "linear")
            num_steps: number of batches to try;
                if None - whole loader would be used.
            optimizer_key: which optimizer key to use
                for learning rate scheduling

        Raises:
            ValueError: if ``scale`` is neither "log" nor "linear"
        """
        super().__init__(optimizer_key=optimizer_key)

        self.final_lr = final_lr
        self.scale = scale
        self.num_steps = num_steps
        self.multiplier = 0
        self.lr_step = 0
        self.find_iter = 0

        self._calc_lr = None
        if scale == "log":
            self._calc_lr = self._calc_lr_log
        elif scale == "linear":
            self._calc_lr = self._calc_lr_linear
        else:
            # ValueError is still an ``Exception`` for existing callers,
            # but it names the actual problem.
            raise ValueError(
                "Not supported scale {!r}; expected \"log\" or \"linear\""
                .format(scale)
            )

    def calc_lr(self):
        """Return the lr for the current iteration and advance the counter."""
        res = self._calc_lr()
        self.find_iter += 1
        return res

    def _calc_lr_log(self):
        """Geometric growth: ``init_lr * multiplier ** find_iter``."""
        return self.init_lr * self.multiplier**self.find_iter

    def _calc_lr_linear(self):
        """Arithmetic growth: ``init_lr + lr_step * find_iter``."""
        return self.init_lr + self.lr_step * self.find_iter

    def on_loader_start(self, state):
        """Derive the per-step growth factors, then apply the first lr."""
        if state.need_backward:
            lr_ = self.final_lr / self.init_lr
            # Fall back to the whole loader when no step count was given.
            self.num_steps = self.num_steps or state.loader_len
            self.multiplier = lr_**(1 / self.num_steps)
            self.lr_step = (self.final_lr - self.init_lr) / self.num_steps

        super().on_loader_start(state=state)

    def on_batch_end(self, state):
        """Update the lr and stop once ``num_steps`` batches are done.

        Raises:
            NotImplementedError: used as the stop signal once ``find_iter``
                exceeds ``num_steps``
        """
        super().on_batch_end(state=state)
        # ``num_steps`` may still be None on a non-training loader (it is
        # only filled in by ``on_loader_start`` under ``need_backward``),
        # and ``int > None`` raises TypeError, so check on training only.
        if state.need_backward and self.find_iter > self.num_steps:
            raise NotImplementedError("End of LRFinder")
# Names exported by ``from <module> import *``.
__all__ = [
    "SchedulerCallback",
    "LRUpdater",
    "LRFinder",
]