Source code for catalyst.core.callbacks.scheduler
from typing import Tuple
from abc import ABC, abstractmethod
import torch
from catalyst.contrib.nn.schedulers import BatchScheduler, OneCycleLRWithWarmup
from catalyst.core import utils
from catalyst.core.callback import Callback, CallbackNode, CallbackOrder
from catalyst.core.runner import IRunner


class SchedulerCallback(Callback):
"""Callback for wrapping schedulers.
Notebook API example:
.. code-block:: python
import torch
from torch.utils.data import DataLoader, TensorDataset
from catalyst.dl import (
SupervisedRunner, AccuracyCallback,
CriterionCallback, SchedulerCallback,
)
num_samples, num_features = 10_000, 10
n_classes = 10
X = torch.rand(num_samples, num_features)
y = torch.randint(0, n_classes, [num_samples])
loader = DataLoader(TensorDataset(X, y), batch_size=32, num_workers=1)
loaders = {"train": loader, "valid": loader}
model = torch.nn.Linear(num_features, n_classes)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters())
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [3, 6])
runner = SupervisedRunner()
runner.train(
model=model,
criterion=criterion,
optimizer=optimizer,
scheduler=scheduler,
loaders=loaders,
logdir="./logdir",
num_epochs=5,
verbose=False,
main_metric="accuracy03",
minimize_metric=False,
callbacks=[
AccuracyCallback(
accuracy_args=[1, 3, 5]
),
SchedulerCallback(reduced_metric="loss")
]
)
Config API usage example:
.. code-block:: yaml
stages:
...
scheduler_params:
scheduler: MultiStepLR
milestones: [1]
gamma: 0.3
...
stage_N:
...
callbacks_params:
...
scheduler:
callback: SchedulerCallback
# arguments for SchedulerCallback
reduced_metric: loss
...
"""

    def __init__(
self,
scheduler_key: str = None,
mode: str = None,
reduced_metric: str = None,
):
"""
Args:
scheduler_key (str): scheduler name, if ``None``,
default is ``None``.
mode (str): scheduler mode, should be one of
``"epoch"`` or ``"batch"``, default is ``None``.
If ``None`` and object is instance of ``BatchScheduler``
or ``OneCycleLRWithWarmup`` then will be used ``"batch"``
otherwise - ``"epoch"``.
reduced_metric (str): metric name to forward to scheduler
object, if ``None`` then will be used main metric
specified in experiment.
"""
super().__init__(order=CallbackOrder.scheduler, node=CallbackNode.all)
self.scheduler_key = scheduler_key
self.mode = mode
self.reduced_metric = reduced_metric
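        # note: when ``mode`` is ``None`` it is resolved later, in
        # ``on_stage_start``, from the scheduler type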

    @staticmethod
def _scheduler_step(
scheduler, reduced_metric=None,
):
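        # ReduceLROnPlateau expects the monitored metric value and does not
        # implement ``get_lr``, so its updated lr is read back from the
        # optimizer; every other scheduler steps unconditionally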
if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
scheduler.step(reduced_metric)
lr = scheduler.optimizer.param_groups[0]["lr"]
else:
scheduler.step()
lr = scheduler.get_lr()[0]
momentum = utils.get_optimizer_momentum(scheduler.optimizer)
return lr, momentum

    def step_batch(self, runner: IRunner) -> None:
"""Update learning rate and momentum in runner.
Args:
runner (IRunner): current runner
"""
lr, momentum = self._scheduler_step(scheduler=self._scheduler)
if self.scheduler_key is not None:
runner.batch_metrics[f"lr/{self.scheduler_key}"] = lr
if momentum is not None:
runner.batch_metrics[
f"momentum/{self.scheduler_key}"
] = momentum
else:
runner.batch_metrics["lr"] = lr
if momentum is not None:
runner.batch_metrics["momentum"] = momentum

    def step_epoch(self, runner: IRunner) -> None:
        """Update learning rate and momentum in runner.
Args:
runner (IRunner): current runner
"""
reduced_metric = runner.valid_metrics[self.reduced_metric]
lr, momentum = self._scheduler_step(
scheduler=self._scheduler, reduced_metric=reduced_metric
)
if self.scheduler_key is not None:
runner.epoch_metrics[f"lr/{self.scheduler_key}"] = lr
if momentum is not None:
runner.epoch_metrics[
f"momentum/{self.scheduler_key}"
] = momentum
else:
runner.epoch_metrics["lr"] = lr
if momentum is not None:
runner.epoch_metrics["momentum"] = momentum

    def on_stage_start(self, runner: IRunner) -> None:
"""Stage start hook.
Args:
runner (IRunner): current runner
"""
self.reduced_metric = self.reduced_metric or runner.main_metric
scheduler = runner.get_attr(
key="scheduler", inner_key=self.scheduler_key
)
assert scheduler is not None
self._scheduler = scheduler
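        # infer the stepping mode from the scheduler type when it was not set
        # explicitly: ``BatchScheduler`` subclasses step every batch,
        # everything else steps once per epoch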
if self.mode is None:
if isinstance(scheduler, BatchScheduler):
self.mode = "batch"
else:
self.mode = "epoch"
if (
isinstance(scheduler, OneCycleLRWithWarmup)
and self.mode == "batch"
):
scheduler.reset()
assert self.mode is not None

    def on_loader_start(self, runner: IRunner) -> None:
"""Loader start hook.
Args:
runner (IRunner): current runner
"""
if (
runner.is_train_loader
and isinstance(self._scheduler, OneCycleLRWithWarmup)
and self.mode == "batch"
):
self._scheduler.recalculate(
loader_len=runner.loader_len, current_step=runner.epoch - 1
)

    def on_batch_end(self, runner: IRunner) -> None:
"""Batch end hook.
Args:
runner (IRunner): current runner
"""
if runner.is_train_loader and self.mode == "batch":
self.step_batch(runner=runner)

    def on_epoch_end(self, runner: IRunner) -> None:
"""Epoch end hook.
Args:
runner (IRunner): current runner
"""
if self.mode == "epoch":
self.step_epoch(runner=runner)


class LRUpdater(ABC, Callback):
    """Basic class that all learning rate updaters inherit from."""

    def __init__(self, optimizer_key: str = None):
"""
Args:
optimizer_key (str): which optimizer key to use
for learning rate scheduling
"""
super().__init__(order=CallbackOrder.scheduler, node=CallbackNode.all)
self.init_lr = 0
self.optimizer_key = optimizer_key

    @abstractmethod
    def calc_lr(self):
        """Interface for calculating learning rate."""
        pass

    @abstractmethod
    def calc_momentum(self):
        """Interface for calculating momentum."""
        pass

    @staticmethod
def _update_lr(optimizer, new_lr) -> None:
for pg in optimizer.param_groups:
pg["lr"] = new_lr

    @staticmethod
def _update_momentum(optimizer, new_momentum) -> None:
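        # Adam-like optimizers store the momentum term as the first element
        # of ``betas``; SGD-like optimizers expose a plain ``momentum`` field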
if "betas" in optimizer.param_groups[0]:
for pg in optimizer.param_groups:
pg["betas"] = (new_momentum, pg["betas"][1])
else:
for pg in optimizer.param_groups:
pg["momentum"] = new_momentum

    def _update_optimizer(self, optimizer) -> Tuple[float, float]:
new_lr = self.calc_lr()
if new_lr is not None:
self._update_lr(optimizer, new_lr)
new_momentum = self.calc_momentum()
if new_momentum is not None:
self._update_momentum(optimizer, new_momentum)
else:
new_momentum = utils.get_optimizer_momentum(optimizer)
return new_lr, new_momentum

    def update_optimizer(self, runner: IRunner) -> None:
"""Update learning rate and momentum in runner.
Args:
runner (IRunner): current runner
"""
lr, momentum = self._update_optimizer(optimizer=self._optimizer)
if self.optimizer_key is not None:
runner.batch_metrics[f"lr_{self.optimizer_key}"] = lr
runner.batch_metrics[f"momentum_{self.optimizer_key}"] = momentum
else:
runner.batch_metrics["lr"] = lr
runner.batch_metrics["momentum"] = momentum

    def on_stage_start(self, runner: IRunner) -> None:
"""Stage start hook.
Args:
runner (IRunner): current runner
"""
optimizer = runner.get_attr(
key="optimizer", inner_key=self.optimizer_key
)
assert optimizer is not None
self._optimizer = optimizer
self.init_lr = optimizer.defaults["lr"]

    def on_loader_start(self, runner: IRunner) -> None:
"""Loader start hook.
Args:
runner (IRunner): current runner
"""
if runner.is_train_loader:
self.update_optimizer(runner=runner)

    def on_batch_end(self, runner: IRunner) -> None:
"""Batch end hook.
Args:
runner (IRunner): current runner
"""
if runner.is_train_loader:
self.update_optimizer(runner=runner)


__all__ = ["SchedulerCallback", "LRUpdater"]
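

# A minimal sketch (not part of catalyst) of how ``LRUpdater`` can be
# subclassed: implement ``calc_lr``/``calc_momentum`` and the base class takes
# care of writing the values into the optimizer and logging them.
# ``LinearLRDecay`` and its ``num_steps`` argument are hypothetical names used
# purely for illustration.
class LinearLRDecay(LRUpdater):
    """Linearly decays the initial lr to zero over ``num_steps`` updates."""

    def __init__(self, num_steps: int, optimizer_key: str = None):
        super().__init__(optimizer_key=optimizer_key)
        self.num_steps = num_steps
        self._step = 0

    def calc_lr(self):
        # fraction of the schedule already completed, clipped to [0, 1]
        progress = min(self._step / self.num_steps, 1.0)
        self._step += 1
        # ``init_lr`` is filled in by ``LRUpdater.on_stage_start``
        return self.init_lr * (1.0 - progress)

    def calc_momentum(self):
        # returning ``None`` keeps the optimizer's current momentum
        return None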