# Source code for catalyst.contrib.nn.schedulers.onecycle
from typing import List
import numpy as np
from torch.optim import Optimizer
from catalyst.contrib.nn.schedulers.base import BatchScheduler
from catalyst.utils.torch import get_optimizer_momentum
class OneCycleLRWithWarmup(BatchScheduler):
    """OneCycle scheduler with warm-up & lr decay stages.

    The schedule is built from up to three consecutive linear stages:

    1. ``warmup`` (``warmup_steps`` steps): lr goes from ``init_lr`` to
       ``max_lr`` while momentum goes from ``init_momentum`` to
       ``min_momentum``.
    2. ``annealing`` (all remaining steps): lr goes from ``max_lr`` to
       ``min_lr`` while momentum goes from ``min_momentum`` to
       ``max_momentum``.
    3. ``decay`` (``decay_steps`` steps, optional): lr goes from ``min_lr``
       to ``final_lr`` while momentum goes from ``max_momentum`` to
       ``final_momentum``.

    Once the precomputed schedule is exhausted, ``final_lr`` and
    ``final_momentum`` are returned for every further step.
    """

    def __init__(
        self,
        optimizer: Optimizer,
        num_steps: int,
        lr_range=(1.0, 0.005),
        init_lr: float = None,
        warmup_steps: int = 0,
        warmup_fraction: float = None,
        decay_steps: int = 0,
        decay_fraction: float = None,
        momentum_range=(0.8, 0.99, 0.999),
        init_momentum: float = None,
    ):
        """
        Args:
            optimizer: PyTorch optimizer
            num_steps: total number of steps
            lr_range: tuple with two or three elements
                (max_lr, min_lr, [final_lr]);
                ``final_lr`` defaults to ``min_lr``
            init_lr (float, optional): initial lr;
                defaults to ``optimizer.defaults["lr"]``
            warmup_steps: count of steps for warm-up stage
            warmup_fraction (float, optional): fraction in [0; 1) to calculate
                number of warmup steps.
                Cannot be set together with ``warmup_steps``
            decay_steps: count of steps for lr decay stage
            decay_fraction (float, optional): fraction in [0; 1) to calculate
                number of decay steps.
                Cannot be set together with ``decay_steps``
            momentum_range: tuple with two or three elements
                (min_momentum, max_momentum, [final_momentum]);
                ``final_momentum`` defaults to ``max_momentum``
            init_momentum (float, optional): initial momentum;
                defaults to ``get_optimizer_momentum(optimizer)``

        Raises:
            ValueError: if ``lr_range`` or ``momentum_range`` does not have
                two or three elements, or if the warm-up and decay stages
                together are longer than ``num_steps``
            AssertionError: if a fraction is outside [0; 1) or is passed
                together with the matching ``*_steps`` argument
        """
        max_lr, min_lr, final_lr = self._expand_range(lr_range, "lr_range")
        min_momentum, max_momentum, final_momentum = self._expand_range(
            momentum_range, "momentum_range"
        )

        if init_lr is None:
            init_lr = optimizer.defaults["lr"]
        if init_momentum is None:
            init_momentum = get_optimizer_momentum(optimizer)

        warmup_steps = self._calculate_warmup(
            num_steps, warmup_steps, warmup_fraction
        )
        decay_steps = self._calculate_decay(
            num_steps, decay_steps, decay_fraction
        )

        # Whatever is left after warm-up and decay is the annealing stage.
        lr_annealing_steps = num_steps - (warmup_steps + decay_steps)
        if lr_annealing_steps < 0:
            # Without this check np.linspace fails later with a message
            # that says nothing about which argument was wrong.
            raise ValueError(
                "warmup_steps + decay_steps ({0}) "
                "should not exceed num_steps ({1})".format(
                    warmup_steps + decay_steps, num_steps
                )
            )

        self.warmup_steps = warmup_steps
        self.lr_annealing_steps = lr_annealing_steps
        self.decay_steps = decay_steps
        self.num_steps = warmup_steps + lr_annealing_steps + decay_steps

        self.lr_range = init_lr, max_lr, min_lr, final_lr
        self.momentum_range = (
            init_momentum,
            min_momentum,
            max_momentum,
            final_momentum,
        )

        self._calculate_lr_momentum(
            warmup_steps, lr_annealing_steps, decay_steps
        )

        self.total_groups = len(optimizer.param_groups)
        # The schedules and ``total_groups`` are prepared before the base
        # constructor runs.
        # NOTE(review): presumably the base scheduler calls ``get_lr()``
        # during its own ``__init__`` - confirm against BatchScheduler.
        super().__init__(optimizer)

    @staticmethod
    def _expand_range(values, name: str):
        """Normalizes a two- or three-element range to three elements.

        Args:
            values: sequence ``(first, second)`` or
                ``(first, second, final)``
            name: argument name used in the error message

        Returns:
            tuple ``(first, second, final)``; for a two-element input
            ``final`` is equal to ``second``

        Raises:
            ValueError: if ``values`` has neither two nor three elements
        """
        size = len(values)
        if size == 2:
            first, second = values
            return first, second, second
        if size == 3:
            first, second, final = values
            return first, second, final
        raise ValueError(
            "{0} should have two or three elements, got {1}".format(
                name, size
            )
        )

    def _calculate_warmup(
        self, num_steps: int, warmup_steps: int, warmup_fraction: float
    ):
        """Resolves the length of the warm-up stage.

        Stores the result in ``self.warmup_steps`` and sets
        ``self.has_warmup``.

        Args:
            num_steps: total number of steps
            warmup_steps: explicit count of warm-up steps
            warmup_fraction: fraction of ``num_steps`` used for warm-up,
                or ``None`` to use ``warmup_steps`` as is

        Returns:
            number of warm-up steps

        Raises:
            AssertionError: if ``warmup_fraction`` is outside [0; 1)
                or is passed together with a non-zero ``warmup_steps``
        """
        if warmup_fraction is not None:
            # Raised explicitly (instead of ``assert``) so the validation
            # is not stripped under ``python -O``; the exception type and
            # message stay the same for existing callers.
            if not (0.0 <= warmup_fraction < 1.0 and warmup_steps == 0):
                raise AssertionError(
                    "You should pass either warmup_steps or "
                    "warmup_fraction in range [0; 1) "
                )
            warmup_steps = int(num_steps * warmup_fraction)

        self.warmup_steps = warmup_steps
        self.has_warmup = warmup_steps != 0
        return self.warmup_steps

    def _calculate_decay(
        self, num_steps: int, decay_steps: int, decay_fraction: float
    ):
        """Resolves the length of the lr decay stage.

        Stores the result in ``self.decay_steps`` and sets
        ``self.has_decay``.

        Args:
            num_steps: total number of steps
            decay_steps: explicit count of decay steps
            decay_fraction: fraction of ``num_steps`` used for decay,
                or ``None`` to use ``decay_steps`` as is

        Returns:
            number of decay steps

        Raises:
            AssertionError: if ``decay_fraction`` is outside [0; 1)
                or is passed together with a non-zero ``decay_steps``
        """
        if decay_fraction is not None:
            # Explicit raise for the same reason as in ``_calculate_warmup``:
            # ``assert`` statements disappear under ``python -O``.
            if not (0.0 <= decay_fraction < 1.0 and decay_steps == 0):
                raise AssertionError(
                    "You should pass either decay_steps or "
                    "decay_fraction in range [0; 1) "
                )
            decay_steps = int(num_steps * decay_fraction)

        self.decay_steps = decay_steps
        self.has_decay = decay_steps != 0
        return self.decay_steps

    def _calculate_lr_momentum(
        self, warmup_steps: int, lr_annealing_steps: int, decay_steps: int
    ):
        """Precomputes per-step lr and momentum values for all stages.

        Each stage is a linear ramp built with ``np.linspace``; the stages
        are concatenated into ``self.learning_rates`` and
        ``self.momentums``.

        Args:
            warmup_steps: number of steps in the warm-up stage
            lr_annealing_steps: number of steps in the annealing stage
            decay_steps: number of steps in the decay stage
        """
        init_lr, max_lr, min_lr, final_lr = self.lr_range
        (
            init_momentum,
            min_momentum,
            max_momentum,
            final_momentum,
        ) = self.momentum_range

        self.learning_rates = np.concatenate(
            (
                np.linspace(init_lr, max_lr, warmup_steps),
                np.linspace(max_lr, min_lr, lr_annealing_steps),
                np.linspace(min_lr, final_lr, decay_steps),
            )
        )

        # Momentum uses the same stage lengths as lr: it goes down during
        # warm-up, up during annealing and to its final value during decay.
        self.momentums = np.concatenate(
            (
                np.linspace(init_momentum, min_momentum, warmup_steps),
                np.linspace(min_momentum, max_momentum, lr_annealing_steps),
                np.linspace(max_momentum, final_momentum, decay_steps),
            )
        )

    def _get_steps_lr_momentum(self, step_num: int):
        """Looks up lr and momentum for the given step.

        Args:
            step_num: index of the step in the precomputed schedules

        Returns:
            tuple ``(lr, momentum)``; past the end of a schedule the
            corresponding final value is returned
        """
        if step_num < len(self.learning_rates):
            lr = self.learning_rates[step_num]
        else:
            lr = self.lr_range[-1]  # final_lr

        if step_num < len(self.momentums):
            momentum = self.momentums[step_num]
        else:
            momentum = self.momentum_range[-1]  # final_momentum

        return lr, momentum

    def get_lr(self) -> List[float]:
        """Function that returns the new lr for optimizer.

        The value is looked up at index ``self.last_epoch``.

        Returns:
            List[float]: calculated lr for every param groups
        """
        lr, _ = self._get_steps_lr_momentum(self.last_epoch)
        return [lr] * self.total_groups

    def get_momentum(self) -> List[float]:
        """Function that returns the new momentum for optimizer.

        The value is looked up at index ``self.last_epoch``.

        Returns:
            List[float]: calculated momentum for every param groups
        """
        _, momentum = self._get_steps_lr_momentum(self.last_epoch)
        return [momentum] * self.total_groups

    def reset(self):
        """Rebuilds the schedules and rewinds the scheduler to step 0.

        The lr and momentum schedules are recomputed from the stored
        ``warmup_steps``, ``lr_annealing_steps`` and ``decay_steps``,
        and ``last_epoch`` is set to ``0``.
        """
        self._calculate_lr_momentum(
            self.warmup_steps, self.lr_annealing_steps, self.decay_steps
        )
        self.last_epoch = 0

    def recalculate(self, loader_len: int, current_step: int) -> None:
        """Recalculates total num_steps for ``batch`` mode.

        Every stage length is multiplied by ``loader_len`` and the
        schedules are rebuilt; the stored stage lengths themselves are
        left unchanged. ``last_epoch`` is set to
        ``current_step * loader_len``.

        Args:
            loader_len: total count of batches in an epoch
            current_step: current step
        """
        warmup_steps = self.warmup_steps * loader_len
        lr_annealing_steps = self.lr_annealing_steps * loader_len
        decay_steps = self.decay_steps * loader_len

        self._calculate_lr_momentum(
            warmup_steps, lr_annealing_steps, decay_steps
        )
        self.last_epoch = current_step * loader_len
# Names exported by a star-import of this module.
__all__ = ["OneCycleLRWithWarmup"]