import numpy as np
import torch
from catalyst.rl import utils
from .actor_critic import OnpolicyActorCritic


class PPO(OnpolicyActorCritic):
    def _init(
        self,
        use_value_clipping: bool = True,
        gae_lambda: float = 0.95,
        clip_eps: float = 0.2,
        entropy_regularization: float = None
    ):
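        """
        Set up PPO-specific hyperparameters and select the value-loss
        variant matching the critic's output distribution.

        Args:
            use_value_clipping: clip value estimates against the rollout
                values, mirroring the policy-ratio clipping
            gae_lambda: lambda for Generalized Advantage Estimation
            clip_eps: clipping parameter for the policy ratio and values
            entropy_regularization: coefficient of the entropy term added
                to the policy loss (disabled when ``None``)
        """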
        self.use_value_clipping = use_value_clipping
        self.gae_lambda = gae_lambda
        self.clip_eps = clip_eps
        self.entropy_regularization = entropy_regularization

        critic_distribution = self.critic.distribution
        self._value_loss_fn = self._base_value_loss
        self._num_atoms = self.critic.num_atoms
        self._num_heads = self.critic.num_heads
        self._hyperbolic_constant = self.critic.hyperbolic_constant
        self._gammas = utils.hyperbolic_gammas(
            self._gamma,
            self._hyperbolic_constant,
            self._num_heads
        )
        # 1 x num_heads x 1
        self._gammas_torch = utils.any2device(
            self._gammas, device=self._device
        )[None, :, None]

        if critic_distribution == "categorical":
            self.num_atoms = self.critic.num_atoms
            values_range = self.critic.values_range
            self.v_min, self.v_max = values_range
            self.delta_z = (self.v_max - self.v_min) / (self._num_atoms - 1)
            z = torch.linspace(
                start=self.v_min, end=self.v_max, steps=self._num_atoms
            )
            self.z = utils.any2device(z, device=self._device)
            self._value_loss_fn = self._categorical_value_loss
        elif critic_distribution == "quantile":
            assert self.critic_criterion is not None
            self.num_atoms = self.critic.num_atoms
            tau_min = 1 / (2 * self._num_atoms)
            tau_max = 1 - tau_min
            tau = torch.linspace(
                start=tau_min, end=tau_max, steps=self._num_atoms
            )
            self.tau = utils.any2device(tau, device=self._device)
            self._value_loss_fn = self._quantile_value_loss

        if not self.use_value_clipping:
            assert self.critic_criterion is not None
    def _value_loss(self, values_tp0, values_t, returns_t):
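        """
        Clipped value loss: the current value estimate is kept within
        ``clip_eps`` of the value stored during the rollout, and the
        larger of the clipped / unclipped squared errors is used,
        mirroring the clipped policy objective. Falls back to
        ``critic_criterion`` when value clipping is disabled.
        """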
        if self.use_value_clipping:
            values_clip = values_t + torch.clamp(
                values_tp0 - values_t, -self.clip_eps, self.clip_eps
            )
            value_loss_unclipped = (values_tp0 - returns_t).pow(2)
            value_loss_clipped = (values_clip - returns_t).pow(2)
            value_loss = 0.5 * torch.max(
                value_loss_unclipped, value_loss_clipped
            ).mean()
        else:
            value_loss = self.critic_criterion(values_tp0, returns_t).mean()
        return value_loss
    def _base_value_loss(
        self, states_t, values_t, returns_t, states_tp1, done_t
    ):
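        """Value loss for a plain (non-distributional) critic head."""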
        # [bs; num_heads; 1; num_atoms=1] ->
        # [bs; num_heads; num_atoms=1] -> many-heads view transform
        # [{bs * num_heads}; num_atoms=1]
        values_tp0 = self.critic(states_t).squeeze_(dim=2).view(-1, 1)
        # [bs; num_heads; num_atoms=1] -> many-heads view transform
        # [{bs * num_heads}; num_atoms=1]
        values_t = values_t.view(-1, 1)
        # [bs; num_heads; num_atoms=1] -> many-heads view transform
        # [{bs * num_heads}; num_atoms=1]
        returns_t = returns_t.view(-1, 1)

        value_loss = self._value_loss(values_tp0, values_t, returns_t)
        return value_loss
    def _categorical_value_loss(
        self, states_t, logits_t, returns_t, states_tp1, done_t
    ):
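        """
        Value loss for a categorical critic: a clipped loss on the
        distribution means plus a categorical projection loss against the
        one-step target atoms.
        """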
        # @TODO: WIP, no guarantees
        logits_tp0 = self.critic(states_t).squeeze_(dim=2)
        probs_tp0 = torch.softmax(logits_tp0, dim=-1)
        values_tp0 = torch.sum(probs_tp0 * self.z, dim=-1, keepdim=True)

        probs_t = torch.softmax(logits_t, dim=-1)
        values_t = torch.sum(probs_t * self.z, dim=-1, keepdim=True)

        value_loss = 0.5 * self._value_loss(values_tp0, values_t, returns_t)

        # B x num_heads x num_atoms
        logits_tp1 = self.critic(states_tp1).squeeze_(dim=2).detach()
        # B x num_heads x num_atoms
        atoms_target_t = returns_t + (1 - done_t) * self._gammas_torch * self.z

        value_loss += 0.5 * utils.categorical_loss(
            logits_tp0.view(-1, self.num_atoms),
            logits_tp1.view(-1, self.num_atoms),
            atoms_target_t.view(-1, self.num_atoms), self.z, self.delta_z,
            self.v_min, self.v_max
        )
        return value_loss
    def _quantile_value_loss(
        self, states_t, atoms_t, returns_t, states_tp1, done_t
    ):
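        """
        Value loss for a quantile critic: a clipped loss on the atom means
        plus a quantile regression loss against the one-step target atoms.
        """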
        # @TODO: WIP, no guarantees
        atoms_tp0 = self.critic(states_t).squeeze_(dim=2)
        values_tp0 = torch.mean(atoms_tp0, dim=-1, keepdim=True)
        values_t = torch.mean(atoms_t, dim=-1, keepdim=True)

        value_loss = 0.5 * self._value_loss(values_tp0, values_t, returns_t)

        # B x num_heads x num_atoms
        atoms_tp1 = self.critic(states_tp1).squeeze_(dim=2).detach()
        # B x num_heads x num_atoms
        atoms_target_t = returns_t \
            + (1 - done_t) * self._gammas_torch * atoms_tp1

        value_loss += 0.5 * utils.quantile_loss(
            atoms_tp0.view(-1, self.num_atoms),
            atoms_target_t.view(-1, self.num_atoms), self.tau, self.num_atoms,
            self.critic_criterion
        )
        return value_loss
    def get_rollout_spec(self):
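        """Shapes and dtypes of the extra per-transition rollout buffers."""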
        return {
            "action_logprob": {
                "shape": (),
                "dtype": np.float32
            },
            "advantage": {
                "shape": (self._num_heads, self._num_atoms),
                "dtype": np.float32
            },
            "done": {
                "shape": (),
                "dtype": bool
            },
            "return": {
                "shape": (self._num_heads, ),
                "dtype": np.float32
            },
            "value": {
                "shape": (self._num_heads, self._num_atoms),
                "dtype": np.float32
            },
        }
    @torch.no_grad()
    def get_rollout(self, states, actions, rewards, dones):
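        """
        For a finished trajectory, compute per-step action log-probabilities,
        value estimates, GAE advantages and discounted returns, one set per
        discount factor (one per critic head).
        """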
        assert len(states) == len(actions) == len(rewards) == len(dones)

        trajectory_len = \
            rewards.shape[0] if dones[-1] else rewards.shape[0] - 1
        states_len = states.shape[0]

        states = utils.any2device(states, device=self._device)
        actions = utils.any2device(actions, device=self._device)
        rewards = np.array(rewards)[:trajectory_len]

        values = torch.zeros(
            (states_len + 1, self._num_heads, self._num_atoms)
        ).to(self._device)
        values[:states_len, ...] = self.critic(states).squeeze_(dim=2)
        # Each column corresponds to a different gamma
        values = values.cpu().numpy()[:trajectory_len + 1, ...]

        _, logprobs = self.actor(states, logprob=actions)
        logprobs = logprobs.cpu().numpy().reshape(-1)[:trajectory_len]

        # len x num_heads x num_atoms
        deltas = rewards[:, None, None] \
            + self._gammas[:, None] * values[1:] - values[:-1]

        # For each gamma in the list of gammas compute the
        # advantages and returns
        # len x num_heads x num_atoms
        advantages = np.stack(
            [
                utils.geometric_cumsum(gamma * self.gae_lambda, deltas[:, i])
                for i, gamma in enumerate(self._gammas)
            ],
            axis=1
        )
        # len x num_heads
        returns = np.stack(
            [
                utils.geometric_cumsum(gamma, rewards[:, None])[:, 0]
                for gamma in self._gammas
            ],
            axis=1
        )

        # final rollout
        dones = dones[:trajectory_len]
        values = values[:trajectory_len]
        assert len(logprobs) == len(advantages) \
            == len(dones) == len(returns) == len(values)

        rollout = {
            "action_logprob": logprobs,
            "advantage": advantages,
            "done": dones,
            "return": returns,
            "value": values,
        }
        return rollout
    def postprocess_buffer(self, buffers, len):
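        """Normalize advantages in-place to zero mean and unit variance."""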
        adv = buffers["advantage"][:len]
        adv = (adv - adv.mean(axis=0)) / (adv.std(axis=0) + 1e-8)
        buffers["advantage"][:len] = adv
    def train(self, batch, **kwargs):
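        """
        Run one PPO update on a batch: the critic is trained with the
        distribution-specific value loss, the actor with the clipped
        surrogate objective, optionally regularized by an entropy term;
        returns the loss, approximate KL and clipped-fraction metrics.
        """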
        (
            states_t, actions_t, returns_t, states_tp1, done_t, values_t,
            advantages_t, action_logprobs_t
        ) = (
            batch["state"], batch["action"], batch["return"],
            batch["state_tp1"], batch["done"], batch["value"],
            batch["advantage"], batch["action_logprob"]
        )

        states_t = utils.any2device(states_t, device=self._device)
        actions_t = utils.any2device(actions_t, device=self._device)
        returns_t = utils.any2device(
            returns_t, device=self._device
        ).unsqueeze_(-1)
        states_tp1 = utils.any2device(states_tp1, device=self._device)
        done_t = utils.any2device(done_t, device=self._device)[:, None, None]
        # done_t = done_t[:, None, :]  # [bs; 1; 1]
        values_t = utils.any2device(values_t, device=self._device)
        advantages_t = utils.any2device(advantages_t, device=self._device)
        action_logprobs_t = utils.any2device(
            action_logprobs_t, device=self._device
        )

        # critic loss
        # states_t - [bs; {state_shape}]
        # values_t - [bs; num_heads; num_atoms]
        # returns_t - [bs; num_heads; 1]
        # states_tp1 - [bs; {state_shape}]
        # done_t - [bs; 1; 1]
        value_loss = self._value_loss_fn(
            states_t, values_t, returns_t, states_tp1, done_t
        )

        # actor loss
        _, action_logprobs_tp0 = self.actor(states_t, logprob=actions_t)
        ratio = torch.exp(action_logprobs_tp0 - action_logprobs_t)
        # the same ratio for each head of the critic
        ratio = ratio[:, None, None]
        policy_loss_unclipped = advantages_t * ratio
        policy_loss_clipped = advantages_t * torch.clamp(
            ratio, 1.0 - self.clip_eps, 1.0 + self.clip_eps
        )
        policy_loss = -torch.min(
            policy_loss_unclipped, policy_loss_clipped
        ).mean()

        if self.entropy_regularization is not None:
            entropy = -(
                torch.exp(action_logprobs_tp0) * action_logprobs_tp0
            ).mean()
            entropy_loss = self.entropy_regularization * entropy
            policy_loss = policy_loss + entropy_loss

        # actor update
        actor_update_metrics = self.actor_update(policy_loss) or {}

        # critic update
        critic_update_metrics = self.critic_update(value_loss) or {}

        # metrics
        kl = 0.5 * (action_logprobs_tp0 - action_logprobs_t).pow(2).mean()
        clipped_fraction = \
            (torch.abs(ratio - 1.0) > self.clip_eps).float().mean()
        metrics = {
            "loss_actor": policy_loss.item(),
            "loss_critic": value_loss.item(),
            "kl": kl.item(),
            "clipped_fraction": clipped_fraction.item()
        }
        metrics = {**metrics, **actor_update_metrics, **critic_update_metrics}

        return metrics
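

# A minimal standalone sketch (not part of the Catalyst API) of the two
# clipped objectives optimized in ``PPO.train`` above, written with plain
# PyTorch tensors; all names below (``old_logprobs``, ``rollout_values``,
# etc.) are illustrative only.
if __name__ == "__main__":
    torch.manual_seed(0)
    clip_eps = 0.2
    advantages = torch.randn(8)
    old_logprobs = torch.randn(8)
    new_logprobs = old_logprobs + 0.1 * torch.randn(8)

    # clipped surrogate policy objective
    ratio = torch.exp(new_logprobs - old_logprobs)
    policy_loss = -torch.min(
        advantages * ratio,
        advantages * torch.clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps)
    ).mean()

    # clipped value loss against rollout values and returns
    rollout_values = torch.randn(8)
    new_values = rollout_values + 0.1 * torch.randn(8)
    returns = torch.randn(8)
    values_clip = rollout_values + torch.clamp(
        new_values - rollout_values, -clip_eps, clip_eps
    )
    value_loss = 0.5 * torch.max(
        (new_values - returns).pow(2), (values_clip - returns).pow(2)
    ).mean()

    print(policy_loss.item(), value_loss.item())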