Source code for catalyst.contrib.nn.schedulers.onecycle

from typing import List  # isort:skip

import numpy as np

from torch.optim import Optimizer

from catalyst.utils import get_optimizer_momentum
from .base import BatchScheduler


class OneCycleLRWithWarmup(BatchScheduler):
    """OneCycle scheduler with warm-up and lr decay stages.

    The first stage, ``warmup``, increases lr from ``init_lr`` to ``max_lr``
    and decreases momentum from ``init_momentum`` to ``min_momentum``.
    It takes ``warmup_steps`` steps.

    The second stage, ``annealing``, decreases lr from ``max_lr`` to
    ``min_lr`` and increases momentum from ``min_momentum`` to
    ``max_momentum``.

    The third, optional, stage decays lr from ``min_lr`` to ``final_lr``.
    """

    def __init__(
        self,
        optimizer: Optimizer,
        num_steps: int,
        lr_range=(1.0, 0.005),
        init_lr: float = None,
        warmup_steps: int = 0,
        warmup_fraction: float = None,
        decay_steps: int = 0,
        decay_fraction: float = None,
        momentum_range=(0.8, 0.99, 0.999),
        init_momentum: float = None,
    ):
        """
        Args:
            optimizer: PyTorch optimizer
            num_steps (int): total number of steps
            lr_range: tuple with two or three elements
                (max_lr, min_lr, [final_lr])
            init_lr (float, optional): initial lr
            warmup_steps (int): count of steps for warm-up stage
            warmup_fraction (float, optional): fraction in [0; 1) to calculate
                the number of warmup steps.
                Cannot be set together with ``warmup_steps``
            decay_steps (int): count of steps for lr decay stage
            decay_fraction (float, optional): fraction in [0; 1) to calculate
                the number of decay steps.
                Cannot be set together with ``decay_steps``
            momentum_range: tuple with two or three elements
                (min_momentum, max_momentum, [final_momentum])
            init_momentum (float, optional): initial momentum
        """
        if len(lr_range) == 2:
            max_lr, min_lr = lr_range
            final_lr = min_lr
        elif len(lr_range) == 3:
            max_lr, min_lr, final_lr = lr_range

        if len(momentum_range) == 2:
            min_momentum, max_momentum = momentum_range
            final_momentum = max_momentum
        elif len(momentum_range) == 3:
            min_momentum, max_momentum, final_momentum = momentum_range

        if init_lr is None:
            init_lr = optimizer.defaults["lr"]
        if init_momentum is None:
            init_momentum = get_optimizer_momentum(optimizer)

        warmup_steps = self._calculate_warmup(
            num_steps, warmup_steps, warmup_fraction
        )
        decay_steps = self._calculate_decay(
            num_steps, decay_steps, decay_fraction
        )
        lr_annealing_steps = num_steps - (warmup_steps + decay_steps)

        self.warmup_steps = warmup_steps
        self.lr_annealing_steps = lr_annealing_steps
        self.decay_steps = decay_steps
        self.num_steps = warmup_steps + lr_annealing_steps + decay_steps

        self.lr_range = init_lr, max_lr, min_lr, final_lr
        self.momentum_range = \
            init_momentum, min_momentum, max_momentum, final_momentum

        self._calculate_lr_momentum(
            warmup_steps, lr_annealing_steps, decay_steps
        )

        self.total_groups = len(optimizer.param_groups)
        super().__init__(optimizer)

    def _calculate_warmup(
        self, num_steps: int, warmup_steps: int, warmup_fraction: float
    ):
        if warmup_fraction is not None:
            assert 0.0 <= warmup_fraction < 1.0 and warmup_steps == 0, \
                "You should pass either warmup_steps or " \
                "warmup_fraction in range [0; 1) "
            warmup_steps = int(num_steps * warmup_fraction)

        self.warmup_steps = warmup_steps
        self.has_warmup = warmup_steps != 0
        return self.warmup_steps

    def _calculate_decay(
        self, num_steps: int, decay_steps: int, decay_fraction: float
    ):
        if decay_fraction is not None:
            assert 0.0 <= decay_fraction < 1.0 and decay_steps == 0, \
                "You should pass either decay_steps or " \
                "decay_fraction in range [0; 1) "
            decay_steps = int(num_steps * decay_fraction)

        self.decay_steps = decay_steps
        self.has_decay = decay_steps != 0
        return self.decay_steps

    def _calculate_lr_momentum(
        self, warmup_steps: int, lr_annealing_steps: int, decay_steps: int
    ):
        init_lr, max_lr, min_lr, final_lr = self.lr_range
        init_momentum, min_momentum, max_momentum, final_momentum = \
            self.momentum_range

        lr_warmup = np.linspace(init_lr, max_lr, warmup_steps)
        lr_annealing = np.linspace(max_lr, min_lr, lr_annealing_steps)
        lr_decay = np.linspace(min_lr, final_lr, decay_steps)

        self.learning_rates = np.concatenate(
            (lr_warmup, lr_annealing, lr_decay)
        )

        # momentum moves opposite to lr: it decays during the lr warmup stage
        # and rises back during the lr annealing stage
        momentum_decay = np.linspace(init_momentum, min_momentum, warmup_steps)
        momentum_annealing = np.linspace(
            min_momentum, max_momentum, lr_annealing_steps
        )
        momentum_warmup = np.linspace(
            max_momentum, final_momentum, decay_steps
        )

        self.momentums = np.concatenate(
            (momentum_decay, momentum_annealing, momentum_warmup)
        )

    def _get_steps_lr_momentum(self, step_num: int):
        if step_num < len(self.learning_rates):
            lr = self.learning_rates[step_num]
        else:
            _, _, _, final_lr = self.lr_range
            lr = final_lr

        if step_num < len(self.momentums):
            momentum = self.momentums[step_num]
        else:
            _, _, _, final_momentum = self.momentum_range
            momentum = final_momentum

        return lr, momentum

    def get_lr(self) -> List[float]:
        """Calculates the new lr for the optimizer.

        Returns:
            List[float]: calculated lr for every param group
        """
        lr, _ = self._get_steps_lr_momentum(self.last_epoch)
        return [lr] * self.total_groups

    def get_momentum(self) -> List[float]:
        """Calculates the new momentum for the optimizer.

        Returns:
            List[float]: calculated momentum for every param group
        """
        _, momentum = self._get_steps_lr_momentum(self.last_epoch)
        return [momentum] * self.total_groups

    def reset(self):
        """Resets the scheduler to the start of the schedule."""
        self._calculate_lr_momentum(
            self.warmup_steps, self.lr_annealing_steps, self.decay_steps
        )
        self.last_epoch = 0

    def recalculate(self, loader_len: int, current_step: int) -> None:
        """Recalculates total num_steps for ``batch`` mode.

        Args:
            loader_len (int): total count of batches in an epoch
            current_step (int): current step
        """
        warmup_steps = self.warmup_steps * loader_len
        lr_annealing_steps = self.lr_annealing_steps * loader_len
        decay_steps = self.decay_steps * loader_len

        self._calculate_lr_momentum(
            warmup_steps, lr_annealing_steps, decay_steps
        )
        self.last_epoch = current_step * loader_len
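

# Usage sketch (not part of the original module): a minimal, self-contained
# example of stepping the scheduler once per batch. The toy model, data, and
# hyperparameters are purely illustrative, and it is assumed that the
# ``step()`` method inherited from ``BatchScheduler`` applies the scheduled
# values to the optimizer; in a full catalyst pipeline the runner's scheduler
# callback normally drives these calls.
if __name__ == "__main__":
    import torch

    model = torch.nn.Linear(10, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

    num_epochs, batches_per_epoch = 3, 20
    scheduler = OneCycleLRWithWarmup(
        optimizer,
        num_steps=num_epochs * batches_per_epoch,
        lr_range=(1.0, 0.005),       # (max_lr, min_lr); final_lr defaults to min_lr
        warmup_fraction=0.1,         # first 10% of steps: init_lr -> max_lr
        decay_fraction=0.1,          # last 10% of steps: min_lr -> final_lr
        momentum_range=(0.8, 0.99),  # (min_momentum, max_momentum)
    )

    for _ in range(num_epochs * batches_per_epoch):
        x, y = torch.randn(8, 10), torch.randn(8, 1)
        loss = torch.nn.functional.mse_loss(model(x), y)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        scheduler.step()             # advance warmup -> annealing -> decay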