Source code for catalyst.rl.onpolicy.algorithms.ppo

import numpy as np

import torch

from catalyst.rl import utils
from .actor_critic import OnpolicyActorCritic


class PPO(OnpolicyActorCritic):
    def _init(
        self,
        use_value_clipping: bool = True,
        gae_lambda: float = 0.95,
        clip_eps: float = 0.2,
        entropy_regularization: float = None
    ):
        self.use_value_clipping = use_value_clipping
        self.gae_lambda = gae_lambda
        self.clip_eps = clip_eps
        self.entropy_regularization = entropy_regularization

        critic_distribution = self.critic.distribution
        self._value_loss_fn = self._base_value_loss
        self._num_atoms = self.critic.num_atoms
        self._num_heads = self.critic.num_heads
        self._hyperbolic_constant = self.critic.hyperbolic_constant
        self._gammas = \
            utils.hyperbolic_gammas(
                self._gamma,
                self._hyperbolic_constant,
                self._num_heads
            )
        # 1 x num_heads x 1
        self._gammas_torch = utils.any2device(
            self._gammas, device=self._device
        )[None, :, None]

        if critic_distribution == "categorical":
            self.num_atoms = self.critic.num_atoms
            values_range = self.critic.values_range
            self.v_min, self.v_max = values_range
            self.delta_z = (self.v_max - self.v_min) / (self._num_atoms - 1)
            z = torch.linspace(
                start=self.v_min, end=self.v_max, steps=self._num_atoms
            )
            self.z = utils.any2device(z, device=self._device)
            self._value_loss_fn = self._categorical_value_loss
        elif critic_distribution == "quantile":
            assert self.critic_criterion is not None
            self.num_atoms = self.critic.num_atoms
            tau_min = 1 / (2 * self._num_atoms)
            tau_max = 1 - tau_min
            tau = torch.linspace(
                start=tau_min, end=tau_max, steps=self._num_atoms
            )
            self.tau = utils.any2device(tau, device=self._device)
            self._value_loss_fn = self._quantile_value_loss

        if not self.use_value_clipping:
            assert self.critic_criterion is not None

    def _value_loss(self, values_tp0, values_t, returns_t):
        if self.use_value_clipping:
            values_clip = values_t + torch.clamp(
                values_tp0 - values_t, -self.clip_eps, self.clip_eps
            )
            value_loss_unclipped = (values_tp0 - returns_t).pow(2)
            value_loss_clipped = (values_clip - returns_t).pow(2)
            value_loss = 0.5 * torch.max(
                value_loss_unclipped, value_loss_clipped
            ).mean()
        else:
            value_loss = self.critic_criterion(values_tp0, returns_t).mean()
        return value_loss

    def _base_value_loss(
        self, states_t, values_t, returns_t, states_tp1, done_t
    ):
        # [bs; num_heads; 1; num_atoms=1] ->
        # [bs; num_heads; num_atoms=1] -> many-heads view transform
        # [{bs * num_heads}; num_atoms=1]
        values_tp0 = self.critic(states_t).squeeze_(dim=2)
        # [bs; num_heads; num_atoms=1] -> many-heads view transform
        # [{bs * num_heads}; num_atoms=1]
        values_t = values_t.view(-1, 1)
        # [bs; num_heads; num_atoms=1] -> many-heads view transform
        # [{bs * num_heads}; num_atoms=1]
        returns_t = returns_t.view(-1, 1)

        value_loss = self._value_loss(values_tp0, values_t, returns_t)
        return value_loss

    def _categorical_value_loss(
        self, states_t, logits_t, returns_t, states_tp1, done_t
    ):
        # @TODO: WIP, no guarantees
        logits_tp0 = self.critic(states_t).squeeze_(dim=2)
        probs_tp0 = torch.softmax(logits_tp0, dim=-1)
        values_tp0 = torch.sum(probs_tp0 * self.z, dim=-1, keepdim=True)

        probs_t = torch.softmax(logits_t, dim=-1)
        values_t = torch.sum(probs_t * self.z, dim=-1, keepdim=True)

        value_loss = 0.5 * self._value_loss(values_tp0, values_t, returns_t)

        # B x num_heads x num_atoms
        logits_tp1 = self.critic(states_tp1).squeeze_(dim=2).detach()
        # B x num_heads x num_atoms
        atoms_target_t = returns_t + (1 - done_t) * self._gammas_torch * self.z
        value_loss += 0.5 * utils.categorical_loss(
            logits_tp0.view(-1, self.num_atoms),
            logits_tp1.view(-1, self.num_atoms),
            atoms_target_t.view(-1, self.num_atoms),
            self.z, self.delta_z,
            self.v_min, self.v_max
        )
        return value_loss

    def _quantile_value_loss(
        self, states_t, atoms_t, returns_t, states_tp1, done_t
    ):
        # @TODO: WIP, no guarantees
        atoms_tp0 = self.critic(states_t).squeeze_(dim=2)
        values_tp0 = torch.mean(atoms_tp0, dim=-1, keepdim=True)
        values_t = torch.mean(atoms_t, dim=-1, keepdim=True)

        value_loss = 0.5 * self._value_loss(values_tp0, values_t, returns_t)

        # B x num_heads x num_atoms
        atoms_tp1 = self.critic(states_tp1).squeeze_(dim=2).detach()
        # B x num_heads x num_atoms
        atoms_target_t = returns_t \
            + (1 - done_t) * self._gammas_torch * atoms_tp1
        value_loss += 0.5 * utils.quantile_loss(
            atoms_tp0.view(-1, self.num_atoms),
            atoms_target_t.view(-1, self.num_atoms),
            self.tau, self.num_atoms,
            self.critic_criterion
        )
        return value_loss
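Note: the value-clipping branch of `_value_loss` can be checked in isolation. The following is a minimal standalone sketch with made-up numbers, not part of the catalyst source; it only mirrors the clipping logic above.

# Illustration only (not part of catalyst.rl): PPO-style clipped value loss
# on a toy batch of four value estimates.
import torch

values_tp0 = torch.tensor([0.9, 0.2, -0.5, 1.4])  # current critic predictions
values_t = torch.tensor([1.0, 0.0, -0.4, 1.0])    # predictions stored at rollout time
returns_t = torch.tensor([1.2, 0.1, -0.3, 0.8])   # empirical returns
clip_eps = 0.2

# keep the new predictions within clip_eps of the old ones, then take the
# pessimistic (larger) of the two squared errors before averaging
values_clip = values_t + torch.clamp(values_tp0 - values_t, -clip_eps, clip_eps)
loss_unclipped = (values_tp0 - returns_t).pow(2)
loss_clipped = (values_clip - returns_t).pow(2)
value_loss = 0.5 * torch.max(loss_unclipped, loss_clipped).mean()

The clipping keeps the critic from moving its predictions far away from the values that were in effect when the advantages were computed.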
    def get_rollout_spec(self):
        return {
            "action_logprob": {
                "shape": (),
                "dtype": np.float32
            },
            "advantage": {
                "shape": (self._num_heads, self._num_atoms),
                "dtype": np.float32
            },
            "done": {
                "shape": (),
                "dtype": np.bool
            },
            "return": {
                "shape": (self._num_heads, ),
                "dtype": np.float32
            },
            "value": {
                "shape": (self._num_heads, self._num_atoms),
                "dtype": np.float32
            },
        }
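For orientation, one plausible way a rollout spec like this could be consumed is to preallocate per-key buffers. The snippet below is an assumption made for illustration (buffer management lives elsewhere in catalyst), with hypothetical sizes.

# Illustration only (not part of catalyst.rl): preallocating rollout buffers
# from a spec, assuming a capacity of 128 steps, 4 critic heads, 1 atom.
import numpy as np

spec = {
    "advantage": {"shape": (4, 1), "dtype": np.float32},
    "done": {"shape": (), "dtype": bool},
    "return": {"shape": (4, ), "dtype": np.float32},
}
capacity = 128
buffers = {
    key: np.zeros((capacity, *desc["shape"]), dtype=desc["dtype"])
    for key, desc in spec.items()
}
# buffers["advantage"].shape == (128, 4, 1)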
    @torch.no_grad()
    def get_rollout(self, states, actions, rewards, dones):
        assert len(states) == len(actions) == len(rewards) == len(dones)

        trajectory_len = \
            rewards.shape[0] if dones[-1] else rewards.shape[0] - 1
        states_len = states.shape[0]

        states = utils.any2device(states, device=self._device)
        actions = utils.any2device(actions, device=self._device)
        rewards = np.array(rewards)[:trajectory_len]

        values = torch.zeros(
            (states_len + 1, self._num_heads, self._num_atoms)
        ).to(self._device)
        values[:states_len, ...] = self.critic(states).squeeze_(dim=2)
        # Each column corresponds to a different gamma
        values = values.cpu().numpy()[:trajectory_len + 1, ...]

        _, logprobs = self.actor(states, logprob=actions)
        logprobs = logprobs.cpu().numpy().reshape(-1)[:trajectory_len]

        # len x num_heads
        deltas = rewards[:, None, None] \
            + self._gammas[:, None] * values[1:] - values[:-1]
        # For each gamma in the list of gammas compute the
        # advantage and returns
        # len x num_heads x num_atoms
        advantages = np.stack(
            [
                utils.geometric_cumsum(gamma * self.gae_lambda, deltas[:, i])
                for i, gamma in enumerate(self._gammas)
            ],
            axis=1
        )
        # len x num_heads
        returns = np.stack(
            [
                utils.geometric_cumsum(gamma, rewards[:, None])[:, 0]
                for gamma in self._gammas
            ],
            axis=1
        )

        # final rollout
        dones = dones[:trajectory_len]
        values = values[:trajectory_len]
        assert len(logprobs) == len(advantages) \
            == len(dones) == len(returns) == len(values)

        rollout = {
            "action_logprob": logprobs,
            "advantage": advantages,
            "done": dones,
            "return": returns,
            "value": values,
        }
        return rollout
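The advantages and returns above are both reverse discounted cumulative sums over the trajectory. The sketch below re-implements that operation for a toy single-head rollout; `discounted_cumsum` is a stand-in written for this note (it mirrors the role `utils.geometric_cumsum` plays here, not its actual implementation), and all numbers are made up.

# Illustration only (not part of catalyst.rl): GAE on a toy rollout.
import numpy as np

def discounted_cumsum(coef, x):
    # reverse discounted cumulative sum along axis 0:
    # out[t] = x[t] + coef * out[t + 1]
    out = np.zeros_like(x, dtype=np.float64)
    running = np.zeros_like(x[0], dtype=np.float64)
    for t in reversed(range(x.shape[0])):
        running = x[t] + coef * running
        out[t] = running
    return out

gamma, gae_lambda = 0.99, 0.95
rewards = np.array([1.0, 0.0, 0.5, 1.0])
values = np.array([0.8, 0.7, 0.9, 0.6, 0.0])  # one extra bootstrap value

deltas = rewards + gamma * values[1:] - values[:-1]         # TD residuals
advantages = discounted_cumsum(gamma * gae_lambda, deltas)  # GAE
returns = discounted_cumsum(gamma, rewards)                 # discounted returns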
    def postprocess_buffer(self, buffers, len):
        adv = buffers["advantage"][:len]
        adv = (adv - adv.mean(axis=0)) / (adv.std(axis=0) + 1e-8)
        buffers["advantage"][:len] = adv
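The effect of this normalization can be seen on a made-up buffer slice; the shapes below are illustrative only.

# Illustration only (not part of catalyst.rl): advantage normalization.
import numpy as np

adv = np.random.randn(256, 4, 1).astype(np.float32)  # [len; num_heads; num_atoms]
adv = (adv - adv.mean(axis=0)) / (adv.std(axis=0) + 1e-8)
# each head/atom slice now has roughly zero mean and unit variance,
# which keeps the scale of the clipped surrogate loss comparable across heads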
    def train(self, batch, **kwargs):
        (
            states_t, actions_t, returns_t, states_tp1, done_t,
            values_t, advantages_t, action_logprobs_t
        ) = (
            batch["state"], batch["action"], batch["return"],
            batch["state_tp1"], batch["done"], batch["value"],
            batch["advantage"], batch["action_logprob"]
        )

        states_t = utils.any2device(states_t, device=self._device)
        actions_t = utils.any2device(actions_t, device=self._device)
        returns_t = utils.any2device(
            returns_t, device=self._device
        ).unsqueeze_(-1)
        states_tp1 = utils.any2device(states_tp1, device=self._device)
        done_t = utils.any2device(done_t, device=self._device)[:, None, None]
        # done_t = done_t[:, None, :]  # [bs; 1; 1]
        values_t = utils.any2device(values_t, device=self._device)
        advantages_t = utils.any2device(advantages_t, device=self._device)
        action_logprobs_t = utils.any2device(
            action_logprobs_t, device=self._device
        )

        # critic loss
        # states_t - [bs; {state_shape}]
        # values_t - [bs; num_heads; num_atoms]
        # returns_t - [bs; num_heads; 1]
        # states_tp1 - [bs; {state_shape}]
        # done_t - [bs; 1; 1]
        value_loss = self._value_loss_fn(
            states_t, values_t, returns_t, states_tp1, done_t
        )

        # actor loss
        _, action_logprobs_tp0 = self.actor(states_t, logprob=actions_t)
        ratio = torch.exp(action_logprobs_tp0 - action_logprobs_t)
        # The same ratio for each head of the critic
        ratio = ratio[:, None, None]
        policy_loss_unclipped = advantages_t * ratio
        policy_loss_clipped = advantages_t * torch.clamp(
            ratio, 1.0 - self.clip_eps, 1.0 + self.clip_eps
        )
        policy_loss = -torch.min(
            policy_loss_unclipped, policy_loss_clipped
        ).mean()

        if self.entropy_regularization is not None:
            entropy = -(
                torch.exp(action_logprobs_tp0) * action_logprobs_tp0
            ).mean()
            entropy_loss = self.entropy_regularization * entropy
            policy_loss = policy_loss + entropy_loss

        # actor update
        actor_update_metrics = self.actor_update(policy_loss) or {}

        # critic update
        critic_update_metrics = self.critic_update(value_loss) or {}

        # metrics
        kl = 0.5 * (action_logprobs_tp0 - action_logprobs_t).pow(2).mean()
        clipped_fraction = \
            (torch.abs(ratio - 1.0) > self.clip_eps).float().mean()
        metrics = {
            "loss_actor": policy_loss.item(),
            "loss_critic": value_loss.item(),
            "kl": kl.item(),
            "clipped_fraction": clipped_fraction.item()
        }
        metrics = {**metrics, **actor_update_metrics, **critic_update_metrics}
        return metrics
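Finally, the clipped surrogate objective in `train` can be followed on a toy batch. The snippet below is an illustration with made-up numbers, not the catalyst training loop.

# Illustration only (not part of catalyst.rl): PPO clipped surrogate loss.
import torch

clip_eps = 0.2
logprobs_old = torch.tensor([-1.0, -0.5, -2.0])  # stored at rollout time
logprobs_new = torch.tensor([-0.8, -0.7, -1.5])  # under the current policy
advantages = torch.tensor([1.0, -0.5, 2.0])

ratio = torch.exp(logprobs_new - logprobs_old)   # pi_new / pi_old
surr_unclipped = advantages * ratio
surr_clipped = advantages * torch.clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps)
policy_loss = -torch.min(surr_unclipped, surr_clipped).mean()

# fraction of samples whose ratio left the trust region, mirroring the
# "clipped_fraction" metric reported above
clipped_fraction = (torch.abs(ratio - 1.0) > clip_eps).float().mean()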