
Source code for catalyst.core.callbacks.scheduler

from typing import Tuple
from abc import ABC, abstractmethod

import torch

from catalyst.contrib.nn.schedulers import BatchScheduler, OneCycleLRWithWarmup
from catalyst.core import utils
from catalyst.core.callback import Callback, CallbackNode, CallbackOrder
from catalyst.core.runner import IRunner


class SchedulerCallback(Callback):
    """Callback for the scheduler step.

    Wraps a scheduler registered in the runner, steps it per batch or
    per epoch, and logs the current lr/momentum to the runner metrics.
    """

    def __init__(
        self,
        scheduler_key: str = None,
        mode: str = None,
        reduced_metric: str = None,
    ):
        """
        Args:
            scheduler_key (str): scheduler name, if there are several
                schedulers registered in the runner
            mode (str): scheduler mode, ``"epoch"`` or ``"batch"``;
                inferred from the scheduler type if ``None``
            reduced_metric (str): metric name to pass to the scheduler
                on the epoch step (used by ``ReduceLROnPlateau``);
                defaults to ``runner.main_metric``
        """
        super().__init__(order=CallbackOrder.scheduler, node=CallbackNode.all)
        self.scheduler_key = scheduler_key
        self.mode = mode
        self.reduced_metric = reduced_metric

    @staticmethod
    def _scheduler_step(
        scheduler,
        reduced_metric=None,
    ):
        if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
            scheduler.step(reduced_metric)
            lr = scheduler.optimizer.param_groups[0]["lr"]
        else:
            scheduler.step()
            lr = scheduler.get_lr()[0]

        momentum = utils.get_optimizer_momentum(scheduler.optimizer)

        return lr, momentum

    def step_batch(self, runner: IRunner) -> None:
        """Perform scheduler step and log lr/momentum to batch metrics.

        Args:
            runner (IRunner): current runner
        """
        lr, momentum = self._scheduler_step(scheduler=self._scheduler)

        if self.scheduler_key is not None:
            runner.batch_metrics[f"lr/{self.scheduler_key}"] = lr
            if momentum is not None:
                runner.batch_metrics[
                    f"momentum/{self.scheduler_key}"
                ] = momentum
        else:
            runner.batch_metrics["lr"] = lr
            if momentum is not None:
                runner.batch_metrics["momentum"] = momentum

    def step_epoch(self, runner: IRunner) -> None:
        """Perform scheduler step and log lr/momentum to epoch metrics.

        Args:
            runner (IRunner): current runner
        """
        reduced_metric = runner.valid_metrics[self.reduced_metric]
        lr, momentum = self._scheduler_step(
            scheduler=self._scheduler, reduced_metric=reduced_metric
        )

        if self.scheduler_key is not None:
            runner.epoch_metrics[f"lr/{self.scheduler_key}"] = lr
            if momentum is not None:
                runner.epoch_metrics[
                    f"momentum/{self.scheduler_key}"
                ] = momentum
        else:
            runner.epoch_metrics["lr"] = lr
            if momentum is not None:
                runner.epoch_metrics["momentum"] = momentum

    def on_stage_start(self, runner: IRunner) -> None:
        """Stage start hook.

        Args:
            runner (IRunner): current runner
        """
        self.reduced_metric = self.reduced_metric or runner.main_metric

        scheduler = runner.get_attr(
            key="scheduler", inner_key=self.scheduler_key
        )
        assert scheduler is not None
        self._scheduler = scheduler

        if self.mode is None:
            if isinstance(scheduler, BatchScheduler):
                self.mode = "batch"
            else:
                self.mode = "epoch"

        if (
            isinstance(scheduler, OneCycleLRWithWarmup)
            and self.mode == "batch"
        ):
            scheduler.reset()
        assert self.mode is not None

    def on_loader_start(self, runner: IRunner) -> None:
        """Loader start hook.

        Args:
            runner (IRunner): current runner
        """
        if (
            runner.is_train_loader
            and isinstance(self._scheduler, OneCycleLRWithWarmup)
            and self.mode == "batch"
        ):
            self._scheduler.recalculate(
                loader_len=runner.loader_len, current_step=runner.epoch - 1
            )

    def on_batch_end(self, runner: IRunner) -> None:
        """Batch end hook.

        Args:
            runner (IRunner): current runner
        """
        if runner.is_train_loader and self.mode == "batch":
            self.step_batch(runner=runner)

    def on_epoch_end(self, runner: IRunner) -> None:
        """Epoch end hook.

        Args:
            runner (IRunner): current runner
        """
        if self.mode == "epoch":
            self.step_epoch(runner=runner)
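For orientation, a minimal usage sketch (not part of this module): attaching SchedulerCallback to a Catalyst experiment. The toy model, data, epoch count and logdir are placeholders, and the SupervisedRunner.train call is assumed to follow the Catalyst 20.x Notebook API.

import torch
from torch import nn, optim
from torch.utils.data import DataLoader, TensorDataset

from catalyst.dl import SupervisedRunner
from catalyst.core.callbacks.scheduler import SchedulerCallback

# toy data and model, just to keep the sketch self-contained
X, y = torch.randn(256, 32), torch.randint(0, 2, (256,))
loaders = {
    "train": DataLoader(TensorDataset(X, y), batch_size=32),
    "valid": DataLoader(TensorDataset(X, y), batch_size=32),
}
model = nn.Linear(32, 2)
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5)

runner = SupervisedRunner()
runner.train(
    model=model,
    criterion=nn.CrossEntropyLoss(),
    optimizer=optimizer,
    scheduler=scheduler,
    loaders=loaders,
    # mode="epoch": step the scheduler once per epoch with the reduced
    # validation metric (runner.main_metric, "loss" by default)
    callbacks=[SchedulerCallback(mode="epoch")],
    num_epochs=4,
    logdir="./logs",  # placeholder
)

With mode="batch" (inferred automatically for BatchScheduler subclasses such as OneCycleLRWithWarmup), the callback instead steps the scheduler after every training batch.
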
class LRUpdater(ABC, Callback):
    """Basic class that all lr updaters inherit from."""

    def __init__(self, optimizer_key: str = None):
        """
        Args:
            optimizer_key (str): which optimizer key to use
                for learning rate scheduling
        """
        super().__init__(order=CallbackOrder.scheduler, node=CallbackNode.all)
        self.init_lr = 0
        self.optimizer_key = optimizer_key

    @abstractmethod
    def calc_lr(self):
        """Calculate the new learning rate;
        return ``None`` to keep the current one."""
        pass

    @abstractmethod
    def calc_momentum(self):
        """Calculate the new momentum;
        return ``None`` to keep the current one."""
        pass

    @staticmethod
    def _update_lr(optimizer, new_lr) -> None:
        for pg in optimizer.param_groups:
            pg["lr"] = new_lr

    @staticmethod
    def _update_momentum(optimizer, new_momentum) -> None:
        if "betas" in optimizer.param_groups[0]:
            for pg in optimizer.param_groups:
                pg["betas"] = (new_momentum, pg["betas"][1])
        else:
            for pg in optimizer.param_groups:
                pg["momentum"] = new_momentum

    def _update_optimizer(self, optimizer) -> Tuple[float, float]:
        new_lr = self.calc_lr()
        if new_lr is not None:
            self._update_lr(optimizer, new_lr)

        new_momentum = self.calc_momentum()
        if new_momentum is not None:
            self._update_momentum(optimizer, new_momentum)
        else:
            new_momentum = utils.get_optimizer_momentum(optimizer)

        return new_lr, new_momentum

    def update_optimizer(self, runner: IRunner) -> None:
        """Update the optimizer's lr/momentum and log them to batch metrics.

        Args:
            runner (IRunner): current runner
        """
        lr, momentum = self._update_optimizer(optimizer=self._optimizer)

        if self.optimizer_key is not None:
            runner.batch_metrics[f"lr_{self.optimizer_key}"] = lr
            runner.batch_metrics[f"momentum_{self.optimizer_key}"] = momentum
        else:
            runner.batch_metrics["lr"] = lr
            runner.batch_metrics["momentum"] = momentum

    def on_stage_start(self, runner: IRunner) -> None:
        """Stage start hook.

        Args:
            runner (IRunner): current runner
        """
        optimizer = runner.get_attr(
            key="optimizer", inner_key=self.optimizer_key
        )
        assert optimizer is not None
        self._optimizer = optimizer
        self.init_lr = optimizer.defaults["lr"]

    def on_loader_start(self, runner: IRunner) -> None:
        """Loader start hook.

        Args:
            runner (IRunner): current runner
        """
        if runner.is_train_loader:
            self.update_optimizer(runner=runner)

    def on_batch_end(self, runner: IRunner) -> None:
        """Batch end hook.

        Args:
            runner (IRunner): current runner
        """
        if runner.is_train_loader:
            self.update_optimizer(runner=runner)
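A hypothetical subclass sketch (not part of this module) showing the contract a concrete updater has to fulfil: calc_lr/calc_momentum return the new values to apply, or ``None`` to leave the optimizer untouched. The linear-warmup policy and the class name are illustrative only.

class LinearWarmupUpdater(LRUpdater):
    """Illustrative only: linearly warm the lr up to its initial value."""

    def __init__(self, warmup_steps: int = 500, optimizer_key: str = None):
        super().__init__(optimizer_key=optimizer_key)
        self.warmup_steps = warmup_steps
        self._step = 0

    def calc_lr(self):
        # self.init_lr is filled in by LRUpdater.on_stage_start
        # from optimizer.defaults["lr"]
        self._step = min(self._step + 1, self.warmup_steps)
        return self.init_lr * self._step / self.warmup_steps

    def calc_momentum(self):
        # None -> keep the optimizer's current momentum/betas
        return None
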
__all__ = ["SchedulerCallback", "LRUpdater"]