Source code for catalyst.contrib.nn.schedulers.base

from typing import List, Optional  # isort:skip
from abc import ABC

from torch.optim.lr_scheduler import _LRScheduler

from catalyst.utils import set_optimizer_momentum


class BaseScheduler(_LRScheduler, ABC):
    """
    Common parent for schedulers that update the optimizer momentum
    in addition to the learning rate.

    Subclasses are expected to override :meth:`get_momentum`.
    """

    def get_momentum(self) -> List[float]:
        """
        Compute the momentum values to apply to the optimizer.

        Returns:
            List[float]: one momentum value per param group

        Raises:
            NotImplementedError: always, in this base implementation
        """
        raise NotImplementedError

    def step(self, epoch: Optional[int] = None) -> None:
        """
        Perform a single scheduler step.

        The parent scheduler step runs first; afterwards every value
        returned by :meth:`get_momentum` is written to the optimizer.

        Args:
            epoch (int, optional): number of the current epoch
        """
        super().step(epoch)
        # the position in the returned list is forwarded as ``index``
        for group_index, group_momentum in enumerate(self.get_momentum()):
            set_optimizer_momentum(
                self.optimizer, group_momentum, index=group_index
            )
class BatchScheduler(BaseScheduler, ABC):
    """
    Marker subclass of ``BaseScheduler`` that adds no behaviour of its own.

    NOTE(review): the name suggests schedulers stepped once per batch -
    confirm against the callers that check for this type.
    """