Source code for catalyst.core.callbacks.scheduler
from typing import Tuple
import torch
from catalyst.contrib.nn.schedulers import BatchScheduler, OneCycleLRWithWarmup
from catalyst.core import utils
from catalyst.core.callback import Callback, CallbackNode, CallbackOrder
from catalyst.core.runner import _Runner
class SchedulerCallback(Callback):
    """Callback that steps a learning-rate scheduler and logs lr/momentum.

    The scheduler is fetched from the runner at stage start and is stepped
    either after every train batch (``mode="batch"``) or once at the end of
    every epoch (``mode="epoch"``). After each step the learning rate of
    the first parameter group, and the momentum when one is reported, are
    written into the runner's batch or epoch metrics.
    """

    def __init__(
        self,
        scheduler_key: str = None,
        mode: str = None,
        reduced_metric: str = None,
    ):
        """
        Args:
            scheduler_key (str): key used to fetch the scheduler from the
                runner (passed as ``inner_key`` to ``runner.get_attr``);
                when set, it is also appended to the logged metric names
                as ``lr/<key>`` and ``momentum/<key>``
            mode (str): ``"batch"`` or ``"epoch"``; when ``None`` it is
                chosen at stage start: ``"batch"`` for ``BatchScheduler``
                instances, ``"epoch"`` otherwise
            reduced_metric (str): name of the validation metric handed to
                ``ReduceLROnPlateau.step``; when not set,
                ``runner.main_metric`` is used
        """
        super().__init__(order=CallbackOrder.Scheduler, node=CallbackNode.All)
        self.scheduler_key = scheduler_key
        self.mode = mode
        self.reduced_metric = reduced_metric

    @staticmethod
    def _scheduler_step(
        scheduler, reduced_metric=None,
    ):
        """Step ``scheduler`` once and report the resulting lr and momentum.

        Args:
            scheduler: scheduler to step; must expose ``optimizer``
            reduced_metric: metric value forwarded to ``step`` when the
                scheduler is a ``ReduceLROnPlateau``; ignored otherwise

        Returns:
            tuple ``(lr, momentum)`` where ``lr`` is the learning rate of
            the optimizer's first parameter group after the step and
            ``momentum`` is whatever ``utils.get_optimizer_momentum``
            reports (callers treat ``None`` as "nothing to log")
        """
        if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
            scheduler.step(reduced_metric)
        else:
            scheduler.step()

        # Read the value the step actually applied to the optimizer instead
        # of calling ``scheduler.get_lr()``: outside of ``step()`` that
        # method can return a value different from the applied one for
        # chainable PyTorch schedulers (and it emits a warning).
        lr = scheduler.optimizer.param_groups[0]["lr"]
        momentum = utils.get_optimizer_momentum(scheduler.optimizer)
        return lr, momentum

    def _store_values(self, metrics, lr, momentum) -> None:
        """Write ``lr`` (and ``momentum`` unless ``None``) into ``metrics``."""
        if self.scheduler_key is None:
            suffix = ""
        else:
            suffix = f"/{self.scheduler_key}"

        metrics[f"lr{suffix}"] = lr
        if momentum is not None:
            metrics[f"momentum{suffix}"] = momentum

    def step_batch(self, runner: _Runner) -> None:
        """Step the scheduler and log the result into ``batch_metrics``.

        Args:
            runner (_Runner): current runner
        """
        lr, momentum = self._scheduler_step(scheduler=self._scheduler)
        self._store_values(runner.batch_metrics, lr, momentum)

    def step_epoch(self, runner: _Runner) -> None:
        """Step the scheduler and log the result into ``epoch_metrics``.

        Args:
            runner (_Runner): current runner
        """
        # Only ReduceLROnPlateau consumes the metric, so do not require it
        # to be present in ``valid_metrics`` for any other scheduler.
        if isinstance(
            self._scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau
        ):
            reduced_metric = runner.valid_metrics[self.reduced_metric]
        else:
            reduced_metric = None

        lr, momentum = self._scheduler_step(
            scheduler=self._scheduler, reduced_metric=reduced_metric
        )
        self._store_values(runner.epoch_metrics, lr, momentum)

    def on_stage_start(self, runner: _Runner) -> None:
        """Stage start hook: fetch the scheduler and resolve the mode.

        Args:
            runner (_Runner): current runner
        """
        self.reduced_metric = self.reduced_metric or runner.main_metric

        scheduler = runner.get_attr(
            key="scheduler", inner_key=self.scheduler_key
        )
        assert scheduler is not None
        self._scheduler = scheduler

        if self.mode is None:
            if isinstance(scheduler, BatchScheduler):
                self.mode = "batch"
            else:
                self.mode = "epoch"

        if (
            isinstance(scheduler, OneCycleLRWithWarmup)
            and self.mode == "batch"
        ):
            scheduler.reset()
        assert self.mode is not None

    def on_loader_start(self, runner: _Runner) -> None:
        """Loader start hook.

        For a ``OneCycleLRWithWarmup`` scheduler in batch mode, calls its
        ``recalculate`` with the train loader length and the current epoch
        whenever a train loader starts.

        Args:
            runner (_Runner): current runner
        """
        if (
            runner.is_train_loader
            and isinstance(self._scheduler, OneCycleLRWithWarmup)
            and self.mode == "batch"
        ):
            self._scheduler.recalculate(
                loader_len=runner.loader_len, current_step=runner.epoch
            )

    def on_batch_end(self, runner: _Runner) -> None:
        """Batch end hook: step in batch mode on train loaders only.

        Args:
            runner (_Runner): current runner
        """
        if runner.is_train_loader and self.mode == "batch":
            self.step_batch(runner=runner)

    def on_epoch_end(self, runner: _Runner) -> None:
        """Epoch end hook: step once per epoch in epoch mode.

        Args:
            runner (_Runner): current runner
        """
        if self.mode == "epoch":
            self.step_epoch(runner=runner)
class LRUpdater(Callback):
    """Basic class that all Lr updaters inherit from.

    Subclasses override ``calc_lr`` and/or ``calc_momentum``. At the start
    of every train loader and after every train batch the computed values
    are applied to all parameter groups of the optimizer and logged into
    ``runner.batch_metrics``.
    """

    def __init__(self, optimizer_key: str = None):
        """
        Args:
            optimizer_key (str): which optimizer key to use
                for learning rate scheduling
        """
        super().__init__(order=CallbackOrder.Scheduler, node=CallbackNode.All)
        # Replaced with ``optimizer.defaults["lr"]`` in ``on_stage_start``.
        self.init_lr = 0
        self.optimizer_key = optimizer_key

    def calc_lr(self):
        """Compute the next learning rate.

        Returns:
            the new learning rate, or ``None`` to leave the optimizer's
            learning rate untouched (the base implementation)
        """
        return None

    def calc_momentum(self):
        """Compute the next momentum.

        Returns:
            the new momentum, or ``None`` to leave the optimizer's
            momentum untouched (the base implementation)
        """
        return None

    @staticmethod
    def _update_lr(optimizer, new_lr) -> None:
        """Set ``new_lr`` on every parameter group of ``optimizer``."""
        for pg in optimizer.param_groups:
            pg["lr"] = new_lr

    @staticmethod
    def _update_momentum(optimizer, new_momentum) -> None:
        """Set ``new_momentum`` on every parameter group of ``optimizer``.

        When the first parameter group has ``betas``, the first beta is
        replaced and the second one is kept; otherwise the ``momentum``
        entry is set.
        """
        if "betas" in optimizer.param_groups[0]:
            for pg in optimizer.param_groups:
                pg["betas"] = (new_momentum, pg["betas"][1])
        else:
            for pg in optimizer.param_groups:
                pg["momentum"] = new_momentum

    def _update_optimizer(self, optimizer) -> Tuple[float, float]:
        """Apply the computed lr/momentum and return the effective values.

        Args:
            optimizer: optimizer to update

        Returns:
            tuple ``(lr, momentum)``; when ``calc_lr``/``calc_momentum``
            return ``None`` the optimizer is left untouched for that value
            and its current setting is reported instead
        """
        new_lr = self.calc_lr()
        if new_lr is not None:
            self._update_lr(optimizer, new_lr)
        else:
            # Mirror the momentum fallback below: report the learning rate
            # currently in effect rather than ``None``.
            new_lr = optimizer.param_groups[0]["lr"]

        new_momentum = self.calc_momentum()
        if new_momentum is not None:
            self._update_momentum(optimizer, new_momentum)
        else:
            new_momentum = utils.get_optimizer_momentum(optimizer)
        return new_lr, new_momentum

    def update_optimizer(self, runner: _Runner) -> None:
        """Update the optimizer and log lr/momentum into ``batch_metrics``.

        Values that are ``None`` (e.g. no momentum reported for the
        optimizer) are not written to the metrics.

        Args:
            runner (_Runner): current runner
        """
        lr, momentum = self._update_optimizer(optimizer=self._optimizer)

        if self.optimizer_key is not None:
            suffix = f"_{self.optimizer_key}"
        else:
            suffix = ""

        for name, value in (("lr", lr), ("momentum", momentum)):
            if value is not None:
                runner.batch_metrics[f"{name}{suffix}"] = value

    def on_stage_start(self, runner: _Runner) -> None:
        """Stage start hook: fetch the optimizer and remember its base lr.

        Args:
            runner (_Runner): current runner
        """
        optimizer = runner.get_attr(
            key="optimizer", inner_key=self.optimizer_key
        )
        assert optimizer is not None
        self._optimizer = optimizer
        self.init_lr = optimizer.defaults["lr"]

    def on_loader_start(self, runner: _Runner) -> None:
        """Loader start hook: update the optimizer on train loaders.

        Args:
            runner (_Runner): current runner
        """
        if runner.is_train_loader:
            self.update_optimizer(runner=runner)

    def on_batch_end(self, runner: _Runner) -> None:
        """Batch end hook: update the optimizer on train loaders.

        Args:
            runner (_Runner): current runner
        """
        if runner.is_train_loader:
            self.update_optimizer(runner=runner)
# Names exported by ``from <this module> import *``.
__all__ = [
    "SchedulerCallback",
    "LRUpdater",
]