# Source code for catalyst.contrib.nn.optimizers.lookahead

from typing import Callable, Dict, Optional
from collections import defaultdict

import torch
from torch.optim import Optimizer

class Lookahead(Optimizer):
    r"""Implements Lookahead algorithm.

    Wraps another ("fast") optimizer. The wrapped optimizer performs the
    actual gradient steps; every ``k`` steps the "slow" weights are moved
    towards the fast weights by a factor ``alpha`` and the fast weights are
    reset to the slow ones.

    It has been proposed in `Lookahead Optimizer: k steps forward,
    1 step back`_.

    The implementation follows an MIT-licensed reference implementation
    (the link to it is not preserved in this copy of the file).

    .. _`Lookahead Optimizer: k steps forward, 1 step back`:
        https://arxiv.org/abs/1907.08610
    """

    def __init__(self, optimizer: Optimizer, k: int = 5, alpha: float = 0.5):
        """Wrap ``optimizer`` with the Lookahead update rule.

        Args:
            optimizer: the inner (fast) optimizer that performs the
                gradient steps.
            k: number of fast steps between two slow-weight
                synchronizations.
            alpha: interpolation factor used when the slow weights are
                moved towards the fast weights.
        """
        # NOTE: ``Optimizer.__init__`` is intentionally not called; the
        # wrapper shares ``param_groups`` and ``defaults`` with the inner
        # optimizer instead of registering the parameters a second time.
        self.optimizer = optimizer
        self.k = k
        self.alpha = alpha
        self.param_groups = self.optimizer.param_groups
        self.defaults = self.optimizer.defaults
        # Slow weights, keyed by parameter tensor.
        self.state = defaultdict(dict)
        # State of the inner optimizer (momentum buffers and the like).
        self.fast_state = self.optimizer.state
        for group in self.param_groups:
            group["counter"] = 0

    def update(self, group):
        """Synchronize slow and fast weights for one parameter group.

        For every parameter the slow copy is created on first use (as a
        copy of the current fast weights), moved towards the fast weights
        by ``alpha`` and then written back into the fast parameter.

        Args:
            group: a parameter-group dict with a ``"params"`` entry.
        """
        for fast in group["params"]:
            param_state = self.state[fast]
            if "slow_param" not in param_state:
                param_state["slow_param"] = fast.detach().clone()
            slow = param_state["slow_param"]
            slow += (fast.data - slow) * self.alpha
            # Without this write-back the slow weights would never
            # influence the model.
            fast.data.copy_(slow)

    def update_lookahead(self):
        """Synchronize slow and fast weights for every parameter group."""
        for group in self.param_groups:
            self.update(group)

    def step(self, closure: Optional[Callable] = None):
        """Makes optimizer step.

        The inner optimizer always steps. The slow/fast synchronization
        runs on the calls where the group's counter is ``0``, i.e. on the
        first call and then every ``k`` calls.

        Args:
            closure (callable, optional): A closure that reevaluates
                the model and returns the loss.

        Returns:
            whatever the inner optimizer's ``step`` returns.
        """
        loss = self.optimizer.step(closure)
        for group in self.param_groups:
            if group["counter"] == 0:
                self.update(group)
            group["counter"] += 1
            if group["counter"] >= self.k:
                group["counter"] = 0
        return loss

    def zero_grad(self, *args, **kwargs):
        """Clear gradients through the inner optimizer.

        Both optimizers hold the same parameters. Delegating avoids the
        inherited implementation, which may rely on attributes that only
        ``Optimizer.__init__`` sets (it is not called by this wrapper).
        """
        return self.optimizer.zero_grad(*args, **kwargs)

    def state_dict(self):
        """Return the state of both the fast and the slow optimizer.

        Returns:
            dict with keys ``"fast_state"`` and ``"param_groups"`` (taken
            from the inner optimizer's ``state_dict``) and
            ``"slow_state"`` (the slow weights).
        """
        fast_state_dict = self.optimizer.state_dict()
        param_groups = fast_state_dict["param_groups"]

        # Key the slow state exactly like the inner optimizer keys its
        # own state, so it can be matched to parameters again on load.
        packed_keys = {}
        for live_group, packed_group in zip(
            self.optimizer.param_groups, param_groups
        ):
            for param, key in zip(live_group["params"], packed_group["params"]):
                packed_keys[id(param)] = key

        slow_state = {}
        for key, value in self.state.items():
            if isinstance(key, torch.Tensor):
                key = packed_keys.get(id(key), id(key))
            slow_state[key] = value

        return {
            "fast_state": fast_state_dict["state"],
            "slow_state": slow_state,
            "param_groups": param_groups,
        }

    def load_state_dict(self, state_dict):
        """Restore the state produced by :meth:`state_dict`.

        Args:
            state_dict: dict with keys ``"fast_state"``, ``"slow_state"``
                and ``"param_groups"``.
        """
        saved_groups = state_dict["param_groups"]
        self.optimizer.load_state_dict(
            {"state": state_dict["fast_state"], "param_groups": saved_groups}
        )

        # Loading may replace the group dicts of the inner optimizer;
        # re-share them so both optimizers keep seeing the same
        # hyper-parameters and counters.
        self.param_groups = self.optimizer.param_groups
        self.fast_state = self.optimizer.state
        for group in self.param_groups:
            # Checkpoints written without Lookahead carry no counter.
            group.setdefault("counter", 0)

        key_to_param = {}
        for live_group, saved_group in zip(self.param_groups, saved_groups):
            for param, key in zip(live_group["params"], saved_group["params"]):
                key_to_param[key] = param

        slow_state = defaultdict(dict)
        for key, value in state_dict["slow_state"].items():
            param = key_to_param.get(key)
            if param is None:
                # Unknown key: keep the entry untouched.
                slow_state[key] = value
                continue
            # Copy so later in-place updates do not modify the checkpoint,
            # and place the tensors next to the parameter they belong to.
            slow_state[param] = {
                name: (
                    item.detach().clone().to(
                        device=param.device, dtype=param.dtype
                    )
                    if isinstance(item, torch.Tensor)
                    else item
                )
                for name, item in value.items()
            }
        self.state = slow_state

    def add_param_group(self, param_group):
        """Add a parameter group to the inner optimizer.

        The group gets its own step counter, starting at ``0``.

        Args:
            param_group: parameter-group dict, as accepted by the inner
                optimizer's ``add_param_group``.
        """
        param_group["counter"] = 0
        self.optimizer.add_param_group(param_group)

    @classmethod
    def get_from_params(
        cls, params: Dict, base_optimizer_params: Optional[Dict] = None, **kwargs,
    ) -> "Lookahead":
        """Build the inner optimizer from the registry and wrap it.

        Args:
            params: passed as ``params`` to ``OPTIMIZERS.get_from_params``.
            base_optimizer_params: keyword arguments for
                ``OPTIMIZERS.get_from_params`` describing the inner
                optimizer; required despite the ``None`` default.
            **kwargs: forwarded to the ``Lookahead`` constructor.

        Returns:
            the ``Lookahead`` instance wrapping the created optimizer.

        Raises:
            TypeError: if ``base_optimizer_params`` is ``None``.
        """
        if base_optimizer_params is None:
            raise TypeError(
                "base_optimizer_params must be a dict describing "
                "the base optimizer, got None"
            )

        from catalyst.dl.registry import OPTIMIZERS

        base_optimizer = OPTIMIZERS.get_from_params(
            params=params, **base_optimizer_params
        )
        optimizer = cls(optimizer=base_optimizer, **kwargs)
        return optimizer