Shortcuts

Source code for catalyst.rl.onpolicy.algorithms.actor

from typing import Dict, Union  # isort:skip

from catalyst.rl import utils
from catalyst.rl.core import (
    ActorSpec, AlgorithmSpec, CriticSpec, EnvironmentSpec
)
from catalyst.rl.registry import AGENTS


class OnpolicyActor(AlgorithmSpec):
    """Base on-policy algorithm that owns and trains a single actor.

    The constructor moves the actor to the device returned by
    ``utils.get_device()`` and obtains its criterion, optimizer, scheduler
    and gradient-clipping function from ``utils.get_trainer_components``.

    Rollout handling (``get_rollout_spec``, ``get_rollout`` and
    ``postprocess_buffer``) raises ``NotImplementedError`` here and has to be
    provided by subclasses.
    """

    def __init__(
        self,
        actor: ActorSpec,
        gamma: float,
        n_step: int,
        actor_loss_params: Dict = None,
        actor_optimizer_params: Dict = None,
        actor_scheduler_params: Dict = None,
        actor_grad_clip_params: Dict = None,
        **kwargs
    ):
        """Set up the actor and its training components.

        Args:
            actor: agent to train; it is moved to the detected device.
            gamma: stored and exposed through the ``gamma`` property
                (presumably the reward discount factor -- confirm against
                the concrete algorithm).
            n_step: stored and exposed through the ``n_step`` property
                (presumably the n-step return horizon -- confirm against
                the concrete algorithm).
            actor_loss_params: forwarded to
                ``utils.get_trainer_components`` as ``loss_params``.
            actor_optimizer_params: forwarded as ``optimizer_params``.
            actor_scheduler_params: forwarded as ``scheduler_params``.
            actor_grad_clip_params: forwarded as ``grad_clip_params``.
            **kwargs: passed unchanged to ``self._init``.
        """
        self._device = utils.get_device()
        self.actor = actor.to(self._device)

        # actor preparation
        actor_components = utils.get_trainer_components(
            agent=self.actor,
            loss_params=actor_loss_params,
            optimizer_params=actor_optimizer_params,
            scheduler_params=actor_scheduler_params,
            grad_clip_params=actor_grad_clip_params
        )
        # criterion
        self._actor_loss_params = actor_components["loss_params"]
        self.actor_criterion = actor_components["criterion"]
        # optimizer
        self._actor_optimizer_params = actor_components["optimizer_params"]
        self.actor_optimizer = actor_components["optimizer"]
        # scheduler
        self._actor_scheduler_params = actor_components["scheduler_params"]
        self.actor_scheduler = actor_components["scheduler"]
        # grad clipping
        self._actor_grad_clip_params = actor_components["grad_clip_params"]
        self.actor_grad_clip_fn = actor_components["grad_clip_fn"]

        # other hyperparameters
        self._n_step = n_step
        self._gamma = gamma

        # other init
        # NOTE(review): ``_init`` is not defined in this class; it is
        # expected to come from the base class or a subclass.
        self._init(**kwargs)

    @property
    def n_step(self) -> int:
        """Value of ``n_step`` given to the constructor."""
        return self._n_step

    @property
    def gamma(self) -> float:
        """Value of ``gamma`` given to the constructor."""
        return self._gamma

    def pack_checkpoint(self, with_optimizer: bool = True):
        """Collect state dicts of the actor and its training components.

        Args:
            with_optimizer: when true, also store the state of the actor's
                optimizer and scheduler (each only if it is not ``None``).

        Returns:
            Dict with the key ``"actor_state_dict"`` and, optionally,
            ``"actor_optimizer_state_dict"`` and
            ``"actor_scheduler_state_dict"``.
        """
        checkpoint = {}

        for key in ["actor"]:
            checkpoint[f"{key}_state_dict"] = getattr(self, key).state_dict()
            if with_optimizer:
                for key2 in ["optimizer", "scheduler"]:
                    key2 = f"{key}_{key2}"
                    value2 = getattr(self, key2, None)
                    # a missing/disabled component is simply not stored
                    if value2 is not None:
                        checkpoint[f"{key2}_state_dict"] = value2.state_dict()

        return checkpoint

    def unpack_checkpoint(self, checkpoint, with_optimizer: bool = True):
        """Restore the actor and its training components from a checkpoint.

        Components that are ``None`` (or absent) on this instance are
        skipped; for every component that is present the matching
        ``"<name>_state_dict"`` entry must exist in ``checkpoint``.

        Args:
            checkpoint: mapping in the layout produced by
                ``pack_checkpoint``.
            with_optimizer: when true, also restore the actor's optimizer
                and scheduler.

        Raises:
            KeyError: if a required ``"<name>_state_dict"`` entry is
                missing from ``checkpoint``.
        """
        for key in ["actor"]:
            value_l = getattr(self, key, None)
            if value_l is not None:
                value_r = checkpoint[f"{key}_state_dict"]
                value_l.load_state_dict(value_r)

            if with_optimizer:
                for key2 in ["optimizer", "scheduler"]:
                    key2 = f"{key}_{key2}"
                    value_l = getattr(self, key2, None)
                    if value_l is not None:
                        value_r = checkpoint[f"{key2}_state_dict"]
                        value_l.load_state_dict(value_r)

    def actor_update(self, loss):
        """Run one optimization step of the actor on ``loss``.

        Gradients of the actor and its optimizer are zeroed, ``loss`` is
        back-propagated, the optional gradient-clipping function is applied
        to the actor parameters, and the optimizer (and the scheduler, if
        any) is stepped.

        Args:
            loss: object whose ``backward()`` populates the actor gradients.

        Returns:
            ``{"lr_actor": <first value of scheduler.get_lr()>}`` when a
            scheduler is configured, otherwise an empty dict.
        """
        self.actor.zero_grad()
        self.actor_optimizer.zero_grad()
        loss.backward()
        if self.actor_grad_clip_fn is not None:
            self.actor_grad_clip_fn(self.actor.parameters())
        self.actor_optimizer.step()

        if self.actor_scheduler is None:
            # No scheduler means no learning-rate metric to report. Return an
            # empty dict so the method always yields a mapping instead of an
            # implicit ``None`` (callers using ``result or {}`` are
            # unaffected).
            return {}

        self.actor_scheduler.step()
        return {"lr_actor": self.actor_scheduler.get_lr()[0]}

    def get_rollout_spec(self) -> Dict:
        """Describe the rollout layout; must be implemented by subclasses."""
        raise NotImplementedError()

    def get_rollout(self, states, actions, rewards, dones):
        """Build a rollout from a trajectory; must be implemented by
        subclasses.
        """
        raise NotImplementedError()

    # NOTE: the parameter name ``len`` shadows the builtin; it is kept
    # because it is part of the public signature.
    def postprocess_buffer(self, buffers, len):
        """Post-process collected buffers; must be implemented by
        subclasses.
        """
        raise NotImplementedError()

    @classmethod
    def prepare_for_trainer(
        cls, env_spec: EnvironmentSpec, config: Dict
    ) -> "AlgorithmSpec":
        """Build the algorithm, together with its actor, from a config.

        Args:
            env_spec: environment spec handed to the agent factory.
            config: mapping with ``config["agents"]["actor"]`` (actor
                parameters for ``AGENTS.get_from_params``) and
                ``config["algorithm"]`` (keyword arguments for the
                algorithm constructor).

        Returns:
            A new instance of ``cls`` wrapping the created actor.
        """
        # shallow copy: the nested dicts are only read, never modified
        config_ = config.copy()
        agents_config = config_["agents"]

        actor_params = agents_config["actor"]
        actor = AGENTS.get_from_params(
            **actor_params,
            env_spec=env_spec,
        )

        algorithm = cls(
            **config_["algorithm"],
            actor=actor,
        )

        return algorithm

    @classmethod
    def prepare_for_sampler(
        cls, env_spec: EnvironmentSpec, config: Dict
    ) -> Union[ActorSpec, CriticSpec]:
        """Build only the actor described by a config.

        Args:
            env_spec: environment spec handed to the agent factory.
            config: mapping with ``config["agents"]["actor"]`` (actor
                parameters for ``AGENTS.get_from_params``).

        Returns:
            The agent created by ``AGENTS.get_from_params``.
        """
        # shallow copy: the nested dicts are only read, never modified
        config_ = config.copy()
        agents_config = config_["agents"]

        actor_params = agents_config["actor"]
        actor = AGENTS.get_from_params(
            **actor_params,
            env_spec=env_spec,
        )

        return actor