Source code for catalyst.rl.offpolicy.algorithms.actor_critic

from typing import Dict, Union  # isort:skip
import copy

from gym.spaces import Box

from catalyst.rl import utils
from catalyst.rl.core import (
    ActorSpec, AlgorithmSpec, CriticSpec, EnvironmentSpec
)
from catalyst.rl.registry import AGENTS


class OffpolicyActorCritic(AlgorithmSpec):
    def __init__(
        self,
        actor: ActorSpec,
        critic: CriticSpec,
        gamma: float,
        n_step: int,
        actor_loss_params: Dict = None,
        critic_loss_params: Dict = None,
        actor_optimizer_params: Dict = None,
        critic_optimizer_params: Dict = None,
        actor_scheduler_params: Dict = None,
        critic_scheduler_params: Dict = None,
        actor_grad_clip_params: Dict = None,
        critic_grad_clip_params: Dict = None,
        actor_tau: float = 1.0,
        critic_tau: float = 1.0,
        action_boundaries: tuple = None,
        **kwargs
    ):
        self._device = utils.get_device()
        self.actor = actor.to(self._device)
        self.critic = critic.to(self._device)

        self.target_actor = copy.deepcopy(actor).to(self._device)
        self.target_critic = copy.deepcopy(critic).to(self._device)

        # actor preparation
        actor_components = utils.get_trainer_components(
            agent=self.actor,
            loss_params=actor_loss_params,
            optimizer_params=actor_optimizer_params,
            scheduler_params=actor_scheduler_params,
            grad_clip_params=actor_grad_clip_params
        )
        # criterion
        self._actor_loss_params = actor_components["loss_params"]
        self.actor_criterion = actor_components["criterion"]
        # optimizer
        self._actor_optimizer_params = actor_components["optimizer_params"]
        self.actor_optimizer = actor_components["optimizer"]
        # scheduler
        self._actor_scheduler_params = actor_components["scheduler_params"]
        self.actor_scheduler = actor_components["scheduler"]
        # grad clipping
        self._actor_grad_clip_params = actor_components["grad_clip_params"]
        self.actor_grad_clip_fn = actor_components["grad_clip_fn"]

        # critic preparation
        critic_components = utils.get_trainer_components(
            agent=self.critic,
            loss_params=critic_loss_params,
            optimizer_params=critic_optimizer_params,
            scheduler_params=critic_scheduler_params,
            grad_clip_params=critic_grad_clip_params
        )
        # criterion
        self._critic_loss_params = critic_components["loss_params"]
        self.critic_criterion = critic_components["criterion"]
        # optimizer
        self._critic_optimizer_params = critic_components["optimizer_params"]
        self.critic_optimizer = critic_components["optimizer"]
        # scheduler
        self._critic_scheduler_params = critic_components["scheduler_params"]
        self.critic_scheduler = critic_components["scheduler"]
        # grad clipping
        self._critic_grad_clip_params = critic_components["grad_clip_params"]
        self.critic_grad_clip_fn = critic_components["grad_clip_fn"]

        # other hyperparameters
        self._n_step = n_step
        self._gamma = gamma
        self._actor_tau = actor_tau
        self._critic_tau = critic_tau

        if action_boundaries is not None:
            assert len(action_boundaries) == 2, \
                "Should be min and max action boundaries"
        self._action_boundaries = action_boundaries

        # other init
        self._init(**kwargs)

    def _init(self, **kwargs):
        assert len(kwargs) == 0

    @property
    def n_step(self) -> int:
        return self._n_step

    @property
    def gamma(self) -> float:
        return self._gamma

    def pack_checkpoint(self, with_optimizer: bool = True):
        checkpoint = {}

        for key in ["actor", "critic"]:
            checkpoint[f"{key}_state_dict"] = getattr(self, key).state_dict()
            if with_optimizer:
                for key2 in ["optimizer", "scheduler"]:
                    key2 = f"{key}_{key2}"
                    value2 = getattr(self, key2, None)
                    if value2 is not None:
                        checkpoint[f"{key2}_state_dict"] = value2.state_dict()

        return checkpoint

    def unpack_checkpoint(self, checkpoint, with_optimizer: bool = True):
        for key in ["actor", "critic"]:
            value_l = getattr(self, key, None)
            if value_l is not None:
                value_r = checkpoint[f"{key}_state_dict"]
                value_l.load_state_dict(value_r)

            if with_optimizer:
                for key2 in ["optimizer", "scheduler"]:
                    key2 = f"{key}_{key2}"
                    value_l = getattr(self, key2, None)
                    if value_l is not None:
                        value_r = checkpoint[f"{key2}_state_dict"]
                        value_l.load_state_dict(value_r)

    def actor_update(self, loss):
        self.actor.zero_grad()
        self.actor_optimizer.zero_grad()
        loss.backward()
        if self.actor_grad_clip_fn is not None:
            self.actor_grad_clip_fn(self.actor.parameters())
        self.actor_optimizer.step()
        if self.actor_scheduler is not None:
            self.actor_scheduler.step()
            return {"lr_actor": self.actor_scheduler.get_lr()[0]}

    def critic_update(self, loss):
        self.critic.zero_grad()
        self.critic_optimizer.zero_grad()
        loss.backward()
        if self.critic_grad_clip_fn is not None:
            self.critic_grad_clip_fn(self.critic.parameters())
        self.critic_optimizer.step()
        if self.critic_scheduler is not None:
            self.critic_scheduler.step()
            return {"lr_critic": self.critic_scheduler.get_lr()[0]}

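    # Note (illustrative comment, not in the original source): the two target
    # network updates below delegate to `utils.soft_update`, assumed here to
    # apply the usual Polyak rule target <- tau * online + (1 - tau) * target,
    # so the default tau of 1.0 amounts to a hard copy of the online weights.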
    def target_actor_update(self):
        utils.soft_update(self.target_actor, self.actor, self._actor_tau)

    def target_critic_update(self):
        utils.soft_update(self.target_critic, self.critic, self._critic_tau)

    def update_step(
        self, policy_loss, value_loss, actor_update=True, critic_update=True
    ):
        """
        Updates parameters of neural networks and returns learning metrics.

        Args:
            policy_loss: computed loss for the actor (policy) network
            value_loss: computed loss for the critic (value) network
            actor_update: whether to update the actor on this step
            critic_update: whether to update the critic on this step

        Returns:
            dict of learning metrics (e.g. losses and learning rates)
        """
        raise NotImplementedError

    def train(self, batch, actor_update=True, critic_update=True):
        states_t, actions_t, rewards_t, states_tp1, done_t = \
            batch["state"], batch["action"], batch["reward"], \
            batch["next_state"], batch["done"]

        states_t = utils.any2device(states_t, device=self._device)
        actions_t = utils.any2device(actions_t, device=self._device)
        rewards_t = utils.any2device(
            rewards_t, device=self._device
        ).unsqueeze(1)
        states_tp1 = utils.any2device(states_tp1, device=self._device)
        done_t = utils.any2device(done_t, device=self._device).unsqueeze(1)
        """
        states_t: [bs; history_len; observation_len]
        actions_t: [bs; action_len]
        rewards_t: [bs; 1]
        states_tp1: [bs; history_len; observation_len]
        done_t: [bs; 1]
        """

        policy_loss, value_loss = self._loss_fn(
            states_t, actions_t, rewards_t, states_tp1, done_t
        )

        metrics = self.update_step(
            policy_loss=policy_loss,
            value_loss=value_loss,
            actor_update=actor_update,
            critic_update=critic_update
        )

        return metrics

    @classmethod
    def prepare_for_trainer(
        cls, env_spec: EnvironmentSpec, config: Dict
    ) -> "AlgorithmSpec":
        config_ = config.copy()
        agents_config = config_["agents"]

        actor_params = agents_config["actor"]
        actor = AGENTS.get_from_params(
            **actor_params,
            env_spec=env_spec,
        )

        critic_params = agents_config["critic"]
        critic = AGENTS.get_from_params(
            **critic_params,
            env_spec=env_spec,
        )

        action_boundaries = config_["algorithm"].pop("action_boundaries", None)
        if action_boundaries is None:
            action_space = env_spec.action_space
            assert isinstance(action_space, Box)
            action_boundaries = [action_space.low[0], action_space.high[0]]

        algorithm = cls(
            **config_["algorithm"],
            action_boundaries=action_boundaries,
            actor=actor,
            critic=critic,
        )

        return algorithm

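    # Illustrative note (not part of the original source): judging from the
    # lookups in `prepare_for_trainer` above and `prepare_for_sampler` below,
    # the expected config layout is roughly:
    #
    #   {
    #       "agents": {
    #           "actor":  {...},  # kwargs for AGENTS.get_from_params
    #           "critic": {...},
    #       },
    #       "algorithm": {
    #           # kwargs for cls(...), e.g. gamma, n_step, tau values;
    #           # "action_boundaries" is optional and otherwise derived
    #           # from env_spec.action_space (a gym.spaces.Box)
    #       },
    #   }
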
    @classmethod
    def prepare_for_sampler(
        cls, env_spec: EnvironmentSpec, config: Dict
    ) -> Union[ActorSpec, CriticSpec]:
        config_ = config.copy()
        agents_config = config_["agents"]

        actor_params = agents_config["actor"]
        actor = AGENTS.get_from_params(
            **actor_params,
            env_spec=env_spec,
        )

        return actor
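

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of this module): a minimal concrete subclass
# showing how `_loss_fn` and `update_step` are expected to plug into
# `OffpolicyActorCritic`. It assumes a deterministic actor callable as
# actor(state), a critic callable as critic(state, action), and a configured
# `critic_criterion` (e.g. MSE); the DPG-style losses below are assumptions
# for illustration, not the exact losses of any catalyst algorithm.
# ---------------------------------------------------------------------------
import torch


class MyActorCritic(OffpolicyActorCritic):
    def _loss_fn(self, states_t, actions_t, rewards_t, states_tp1, done_t):
        # critic target: r + gamma^n * (1 - done) * Q'(s', pi'(s'))
        with torch.no_grad():
            actions_tp1 = self.target_actor(states_tp1)
            q_tp1 = self.target_critic(states_tp1, actions_tp1)
            q_target = rewards_t + (self._gamma ** self._n_step) \
                * (1.0 - done_t.float()) * q_tp1
        q_t = self.critic(states_t, actions_t)
        value_loss = self.critic_criterion(q_t, q_target)
        # actor loss: ascend on Q(s, pi(s)) by minimizing its negation
        policy_loss = -self.critic(states_t, self.actor(states_t)).mean()
        return policy_loss, value_loss

    def update_step(
        self, policy_loss, value_loss, actor_update=True, critic_update=True
    ):
        metrics = {
            "loss_actor": policy_loss.item(),
            "loss_critic": value_loss.item(),
        }
        if critic_update:
            # critic_update returns lr metrics only when a scheduler is set
            metrics.update(self.critic_update(value_loss) or {})
        if actor_update:
            metrics.update(self.actor_update(policy_loss) or {})
        self.target_actor_update()
        self.target_critic_update()
        return metrics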