import safitty
import torch
from catalyst.contrib.schedulers import BatchScheduler, OneCycleLRWithWarmup
from catalyst.dl.core import Callback, CallbackOrder, RunnerState
from catalyst.utils import get_optimizer_momentum


class SchedulerCallback(Callback):
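    """
    Scheduler callback: calls ``scheduler.step()`` either every batch or
    every epoch (depending on ``mode``) and stores the resulting
    lr/momentum values in the runner state.
    """
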
    def __init__(
        self,
        scheduler_key: str = None,
        mode: str = None,
        reduce_metric: str = "loss"
    ):
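        """
        Args:
            scheduler_key: scheduler name, if several schedulers are used
            mode: scheduling mode, ``"epoch"`` or ``"batch"``;
                inferred from the scheduler type when ``None``
            reduce_metric: valid-loader metric name to pass to
                ``ReduceLROnPlateau`` (default: ``"loss"``)
        """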
        super().__init__(CallbackOrder.Scheduler)
        self.scheduler_key = scheduler_key
        self.mode = mode
        self.reduce_metric = reduce_metric

    def step(self, state: RunnerState):
        scheduler = state.get_key(
            key="scheduler", inner_key=self.scheduler_key
        )
        valid_metric = \
            safitty.get(state.metrics.valid_values, self.reduce_metric)
        lr, momentum = self._scheduler_step(
            scheduler=scheduler, valid_metric=valid_metric
        )
        state.set_key(lr, key="lr", inner_key=self.scheduler_key)
        state.set_key(momentum, key="momentum", inner_key=self.scheduler_key)

    def on_stage_start(self, state: RunnerState):
        scheduler = state.get_key(
            key="scheduler", inner_key=self.scheduler_key
        )
        assert scheduler is not None

        # infer the scheduling mode from the scheduler type if not set
        if self.mode is None:
            if isinstance(scheduler, BatchScheduler):
                self.mode = "batch"
            else:
                self.mode = "epoch"

        if isinstance(scheduler, OneCycleLRWithWarmup) and \
                self.mode == "batch":
            scheduler.reset()

    def on_loader_start(self, state: RunnerState):
        scheduler = state.get_key(
            key="scheduler", inner_key=self.scheduler_key
        )
        if state.loader_name.startswith("train") and \
                isinstance(scheduler, OneCycleLRWithWarmup) and \
                self.mode == "batch":
            scheduler.recalculate(
                loader_len=state.loader_len, current_step=state.stage_epoch
            )

    def on_batch_end(self, state):
        if self.mode == "batch":
            self.step(state=state)

    def on_epoch_end(self, state):
        if self.mode == "epoch":
            self.step(state=state)

    @staticmethod
    def _scheduler_step(
        scheduler,
        valid_metric=None,
    ):
        if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
            # ReduceLROnPlateau needs the monitored metric to decide on a step
            scheduler.step(valid_metric)
            lr = safitty.get(scheduler.optimizer.param_groups, 0, "lr")
        else:
            scheduler.step()
            lr = scheduler.get_lr()[0]

        momentum = get_optimizer_momentum(scheduler.optimizer)

        return lr, momentum


class LRUpdater(Callback):
"""Basic class that all Lr updaters inherit from"""

    def __init__(self, optimizer_key: str = None):
        """
        Args:
            optimizer_key: which optimizer key to use
                for learning rate scheduling
        """
        super().__init__(CallbackOrder.Scheduler)
        self.init_lr = 0
        self.optimizer_key = optimizer_key

    def calc_lr(self):
        """Return the new lr, or ``None`` to leave it unchanged."""
        return None

    def calc_momentum(self):
        """Return the new momentum, or ``None`` to leave it unchanged."""
        return None

    @staticmethod
    def _update_lr(optimizer, new_lr):
        for pg in optimizer.param_groups:
            pg["lr"] = new_lr

    @staticmethod
    def _update_momentum(optimizer, new_momentum):
        if "betas" in optimizer.param_groups[0]:
            # Adam-style optimizers keep momentum as the first beta
            for pg in optimizer.param_groups:
                pg["betas"] = (new_momentum, pg["betas"][1])
        else:
            for pg in optimizer.param_groups:
                pg["momentum"] = new_momentum

    def _update_optimizer(self, optimizer):
        new_lr = self.calc_lr()
        if new_lr is not None:
            self._update_lr(optimizer, new_lr)

        new_momentum = self.calc_momentum()
        if new_momentum is not None:
            self._update_momentum(optimizer, new_momentum)
        else:
            new_momentum = get_optimizer_momentum(optimizer)

        return new_lr, new_momentum

    def update_optimizer(self, state):
        if not state.need_backward:
            return

        optimizer = state.get_key(
            key="optimizer", inner_key=self.optimizer_key
        )
        lr, momentum = self._update_optimizer(optimizer=optimizer)
        state.set_key(lr, key="lr", inner_key=self.optimizer_key)
        state.set_key(momentum, key="momentum", inner_key=self.optimizer_key)

    def on_stage_start(self, state):
        optimizer = state.get_key(
            key="optimizer", inner_key=self.optimizer_key
        )
        self.init_lr = optimizer.defaults["lr"]

    def on_loader_start(self, state):
        if state.need_backward:
            self.update_optimizer(state=state)

    def on_batch_end(self, state):
        if state.need_backward:
            self.update_optimizer(state=state)


class LRFinder(LRUpdater):
"""
Helps you find an optimal learning rate for a model,
as per suggestion of 2015 CLR paper.
Learning rate is increased in linear or log scale, depending on user input.
https://sgugger.github.io/how-do-you-find-a-good-learning-rate.html
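
    Example (sketch; assumes the ``SupervisedRunner`` API of the same
    Catalyst version and pre-built ``model``, ``criterion``, ``optimizer``
    and ``loaders``)::

        from catalyst.dl import SupervisedRunner

        runner = SupervisedRunner()
        runner.train(
            model=model,
            criterion=criterion,
            optimizer=optimizer,
            loaders=loaders,
            callbacks=[LRFinder(final_lr=1.0, num_steps=100)],
            num_epochs=1,
        )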
"""

    def __init__(
        self, final_lr, scale="log", num_steps=None, optimizer_key=None
    ):
"""
Args:
final_lr: final learning rate to try with
scale: learning rate increasing scale ("log" or "linear")
num_steps: number of batches to try;
if None - whole loader would be used.
optimizer_key: which optimizer key to use
for learning rate scheduling
"""
        super().__init__(optimizer_key=optimizer_key)

        self.final_lr = final_lr
        self.scale = scale
        self.num_steps = num_steps
        self.multiplier = 0
        self.lr_step = 0
        self.find_iter = 0

        self._calc_lr = None
        if scale == "log":
            self._calc_lr = self._calc_lr_log
        elif scale == "linear":
            self._calc_lr = self._calc_lr_linear
        else:
            raise Exception("Not supported")

    def calc_lr(self):
        res = self._calc_lr()
        self.find_iter += 1
        return res

    def _calc_lr_log(self):
        return self.init_lr * self.multiplier**self.find_iter

    def _calc_lr_linear(self):
        return self.init_lr + self.lr_step * self.find_iter

    def on_loader_start(self, state):
        if state.need_backward:
            lr_ = self.final_lr / self.init_lr
            self.num_steps = self.num_steps or state.loader_len
            self.multiplier = lr_**(1 / self.num_steps)
            self.lr_step = (self.final_lr - self.init_lr) / self.num_steps

        super().on_loader_start(state=state)

    def on_batch_end(self, state):
        super().on_batch_end(state=state)
        if self.find_iter > self.num_steps:
            # intentionally interrupt the run once the lr sweep is finished
            raise NotImplementedError("End of LRFinder")


__all__ = ["SchedulerCallback", "LRUpdater", "LRFinder"]