Source code for catalyst.rl.offpolicy.algorithms.ddpg

import torch

from catalyst.rl import utils
from .actor_critic import OffpolicyActorCritic


class DDPG(OffpolicyActorCritic):
    """
    Swiss Army knife DDPG algorithm.
    """

    def _init(self):
        # value distribution approximation
        critic_distribution = self.critic.distribution
        self._loss_fn = self._base_loss
        self._num_heads = self.critic.num_heads
        self._hyperbolic_constant = self.critic.hyperbolic_constant
        self._gammas = \
            utils.hyperbolic_gammas(
                self._gamma,
                self._hyperbolic_constant,
                self._num_heads
            )
        self._gammas = utils.any2device(self._gammas, device=self._device)
        assert critic_distribution in [None, "categorical", "quantile"]

        if critic_distribution == "categorical":
            self.num_atoms = self.critic.num_atoms
            values_range = self.critic.values_range
            self.v_min, self.v_max = values_range
            self.delta_z = (self.v_max - self.v_min) / (self.num_atoms - 1)
            z = torch.linspace(
                start=self.v_min, end=self.v_max, steps=self.num_atoms
            )
            self.z = utils.any2device(z, device=self._device)
            self._loss_fn = self._categorical_loss
        elif critic_distribution == "quantile":
            self.num_atoms = self.critic.num_atoms
            tau_min = 1 / (2 * self.num_atoms)
            tau_max = 1 - tau_min
            tau = torch.linspace(
                start=tau_min, end=tau_max, steps=self.num_atoms
            )
            self.tau = utils.any2device(tau, device=self._device)
            self._loss_fn = self._quantile_loss
        else:
            assert self.critic_criterion is not None

    def _process_components(self, done_t, rewards_t):
        # Array of size [num_heads,]
        gammas = self._gammas ** self._n_step
        gammas = gammas[None, :, None]  # [1; num_heads; 1]
        # We use the same done_t, rewards_t, actions_t for each head
        done_t = done_t[:, None, :]  # [bs; 1; 1]
        rewards_t = rewards_t[:, None, :]  # [bs; 1; 1]

        return gammas, done_t, rewards_t

    def _base_loss(self, states_t, actions_t, rewards_t, states_tp1, done_t):
        gammas, done_t, rewards_t = \
            self._process_components(done_t, rewards_t)

        # actor loss
        # For now we have the same actor for all heads of the critic
        policy_loss = -torch.mean(self.critic(states_t, self.actor(states_t)))

        # critic loss
        # [bs; num_heads; 1] -> many-heads view transform
        # [{bs * num_heads}; 1]
        q_values_t = (
            self.critic(states_t, actions_t).squeeze_(dim=2).view(-1, 1)
        )
        # [bs; num_heads; 1]
        q_values_tp1 = self.target_critic(
            states_tp1, self.target_actor(states_tp1)
        ).squeeze_(dim=2)
        # [bs; num_heads; 1] -> many-heads view transform
        # [{bs * num_heads}; 1]
        q_target_t = (
            rewards_t + (1 - done_t) * gammas * q_values_tp1
        ).view(-1, 1).detach()
        value_loss = self.critic_criterion(q_values_t, q_target_t).mean()

        return policy_loss, value_loss

    def _categorical_loss(
        self, states_t, actions_t, rewards_t, states_tp1, done_t
    ):
        gammas, done_t, rewards_t = \
            self._process_components(done_t, rewards_t)

        # actor loss
        # For now we have the same actor for all heads of the critic
        # [bs; num_heads; num_atoms] -> many-heads view transform
        # [{bs * num_heads}; num_atoms]
        logits_tp0 = (
            self.critic(states_t, self.actor(states_t))
            .squeeze_(dim=2)
            .view(-1, self.num_atoms)
        )
        # [{bs * num_heads}; num_atoms]
        probs_tp0 = torch.softmax(logits_tp0, dim=-1)
        # [{bs * num_heads}; 1]
        q_values_tp0 = torch.sum(probs_tp0 * self.z, dim=-1)
        policy_loss = -torch.mean(q_values_tp0)

        # critic loss (kl-divergence between categorical distributions)
        # [bs; num_heads; num_atoms] -> many-heads view transform
        # [{bs * num_heads}; num_atoms]
        logits_t = (
            self.critic(states_t, actions_t)
            .squeeze_(dim=2)
            .view(-1, self.num_atoms)
        )
        # [bs; action_size]
        actions_tp1 = self.target_actor(states_tp1)
        # [bs; num_heads; num_atoms] -> many-heads view transform
        # [{bs * num_heads}; num_atoms]
        logits_tp1 = (
            self.target_critic(states_tp1, actions_tp1)
            .squeeze_(dim=2)
            .view(-1, self.num_atoms)
        ).detach()
        # [bs; num_heads; num_atoms] -> many-heads view transform
        # [{bs * num_heads}; num_atoms]
        atoms_target_t = (
            rewards_t + (1 - done_t) * gammas * self.z
        ).view(-1, self.num_atoms)

        value_loss = utils.categorical_loss(
            # [{bs * num_heads}; num_atoms]
            logits_t,
            # [{bs * num_heads}; num_atoms]
            logits_tp1,
            # [{bs * num_heads}; num_atoms]
            atoms_target_t,
            self.z,
            self.delta_z,
            self.v_min,
            self.v_max
        )

        return policy_loss, value_loss

    def _quantile_loss(
        self, states_t, actions_t, rewards_t, states_tp1, done_t
    ):
        gammas, done_t, rewards_t = \
            self._process_components(done_t, rewards_t)

        # actor loss
        policy_loss = -torch.mean(self.critic(states_t, self.actor(states_t)))

        # critic loss (quantile regression)
        # [bs; num_heads; num_atoms]
        atoms_t = self.critic(states_t, actions_t).squeeze_(dim=2)
        # [bs; num_heads; num_atoms]
        atoms_tp1 = self.target_critic(
            states_tp1, self.target_actor(states_tp1)
        ).squeeze_(dim=2).detach()
        # [bs; num_heads; num_atoms]
        atoms_target_t = rewards_t + (1 - done_t) * gammas * atoms_tp1

        value_loss = utils.quantile_loss(
            # [{bs * num_heads}; num_atoms]
            atoms_t.view(-1, self.num_atoms),
            # [{bs * num_heads}; num_atoms]
            atoms_target_t.view(-1, self.num_atoms),
            self.tau,
            self.num_atoms,
            self.critic_criterion
        )

        return policy_loss, value_loss
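
The many-heads bookkeeping above hinges on broadcasting a per-head discount vector against per-sample tensors. The following standalone sketch is not part of the library; batch size, head count, and gamma values are assumed for illustration only. It reproduces the shape arithmetic of _process_components and the Q-target computed in _base_loss:

    import torch

    bs, num_heads, n_step = 4, 3, 1                      # assumed sizes
    gammas = torch.tensor([0.9, 0.95, 0.99]) ** n_step   # [num_heads]
    gammas = gammas[None, :, None]                       # [1, num_heads, 1]

    rewards_t = torch.randn(bs, 1)[:, None, :]           # [bs, 1, 1]
    done_t = torch.zeros(bs, 1)[:, None, :]              # [bs, 1, 1]
    q_values_tp1 = torch.randn(bs, num_heads, 1)         # [bs, num_heads, 1]

    # broadcasting yields one target per sample per head
    q_target_t = rewards_t + (1 - done_t) * gammas * q_values_tp1
    print(q_target_t.shape)          # torch.Size([4, 3, 1])
    print(q_target_t.view(-1, 1).shape)  # torch.Size([12, 1]) == [bs * num_heads, 1]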
    def update_step(
        self, policy_loss, value_loss, actor_update=True, critic_update=True
    ):
        # actor update
        actor_update_metrics = {}
        if actor_update:
            actor_update_metrics = self.actor_update(policy_loss) or {}

        # critic update
        critic_update_metrics = {}
        if critic_update:
            critic_update_metrics = self.critic_update(value_loss) or {}

        loss = 0
        loss += value_loss
        loss += policy_loss

        metrics = {
            "loss": loss.item(),
            "loss_critic": value_loss.item(),
            "loss_actor": policy_loss.item()
        }
        metrics = {**metrics, **actor_update_metrics, **critic_update_metrics}

        return metrics
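
For reference, the categorical support z and the quantile midpoints tau built in _init can be rebuilt in isolation. In this sketch the atom count and value range are assumed placeholders rather than library defaults; the last lines replay the shape of the categorical target atoms from _categorical_loss under the same assumed sizes as before:

    import torch

    num_atoms, v_min, v_max = 51, -10.0, 10.0            # assumed, not library defaults
    delta_z = (v_max - v_min) / (num_atoms - 1)
    z = torch.linspace(start=v_min, end=v_max, steps=num_atoms)        # categorical atom values

    tau_min = 1 / (2 * num_atoms)
    tau_max = 1 - tau_min
    tau = torch.linspace(start=tau_min, end=tau_max, steps=num_atoms)  # quantile midpoints

    # categorical target support, projected per sample and per head
    bs, num_heads = 4, 3
    rewards_t = torch.randn(bs, 1)[:, None, :]                         # [bs, 1, 1]
    done_t = torch.zeros(bs, 1)[:, None, :]                            # [bs, 1, 1]
    gammas = torch.tensor([0.9, 0.95, 0.99])[None, :, None]            # [1, num_heads, 1]
    atoms_target_t = (rewards_t + (1 - done_t) * gammas * z).view(-1, num_atoms)
    print(atoms_target_t.shape)  # torch.Size([12, 51]) == [bs * num_heads, num_atoms]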