Source code for catalyst.rl.onpolicy.algorithms.reinforce

import numpy as np

import torch

from catalyst.rl import utils
from .actor import OnpolicyActor


class REINFORCE(OnpolicyActor):
    def _init(self, entropy_regularization: float = None):
        self.entropy_regularization = entropy_regularization
    def get_rollout_spec(self):
        return {
            "return": {"shape": (), "dtype": np.float32},
            "action_logprob": {"shape": (), "dtype": np.float32},
        }
    @torch.no_grad()
    def get_rollout(self, states, actions, rewards, dones):
        assert len(states) == len(actions) == len(rewards) == len(dones)

        # if the last transition is not terminal, drop the trailing step
        trajectory_len = \
            rewards.shape[0] if dones[-1] else rewards.shape[0] - 1
        states = utils.any2device(states, device=self._device)
        actions = utils.any2device(actions, device=self._device)
        rewards = np.array(rewards)[:trajectory_len]

        # log-probabilities of the taken actions under the current policy
        _, logprobs = self.actor(states, logprob=actions)
        logprobs = logprobs.cpu().numpy().reshape(-1)[:trajectory_len]

        # discounted returns G_t = sum_k gamma^k * r_{t+k}
        returns = utils.geometric_cumsum(self.gamma, rewards[:, None])[:, 0]

        assert len(returns) == len(logprobs)
        rollout = {"return": returns, "action_logprob": logprobs}

        return rollout
    def postprocess_buffer(self, buffers, len):
        pass
    def train(self, batch, **kwargs):
        states, actions, returns, action_logprobs = \
            batch["state"], batch["action"], batch["return"], \
            batch["action_logprob"]

        states = utils.any2device(states, device=self._device)
        actions = utils.any2device(actions, device=self._device)
        returns = utils.any2device(returns, device=self._device)
        old_logprobs = utils.any2device(action_logprobs, device=self._device)

        # actor loss
        _, logprobs = self.actor(states, logprob=actions)

        # REINFORCE objective function: -E[log pi(a|s) * G]
        policy_loss = -torch.mean(logprobs * returns)

        if self.entropy_regularization is not None:
            entropy = -(torch.exp(logprobs) * logprobs).mean()
            entropy_loss = self.entropy_regularization * entropy
            policy_loss = policy_loss + entropy_loss

        # actor update
        actor_update_metrics = self.actor_update(policy_loss) or {}

        # metrics
        # quadratic approximation of the KL divergence between the rollout
        # policy and the current policy on the sampled actions
        kl = 0.5 * (logprobs - old_logprobs).pow(2).mean()
        metrics = {
            "loss_actor": policy_loss.item(),
            "kl": kl.item(),
        }
        metrics = {**metrics, **actor_update_metrics}

        return metrics
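
For reference, the discounted return used in get_rollout via utils.geometric_cumsum is
G_t = sum_k gamma^k * r_{t+k}. Below is a minimal NumPy sketch of that computation; the
function name discounted_returns and the example values are illustrative, not part of
catalyst, and this is not the library's own implementation.

import numpy as np

def discounted_returns(gamma, rewards):
    # rewards: 1D array of per-step rewards r_0 .. r_{T-1}
    # output:  1D array where output[t] = sum_k gamma**k * rewards[t + k]
    returns = np.zeros(len(rewards), dtype=np.float32)
    running = 0.0
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running
        returns[t] = running
    return returns

# example: gamma=0.9, rewards [1, 1, 1] -> [2.71, 1.9, 1.0]
print(discounted_returns(0.9, np.array([1.0, 1.0, 1.0])))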
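
The core of train is the REINFORCE objective: minimize -E[log pi(a|s) * G]. The sketch
below reproduces that loss on hypothetical toy tensors, without an actor network or the
catalyst buffers; the numeric values are made up, and in practice the gradients flow back
through the actor rather than into a raw tensor.

import torch

# hypothetical per-sample log-probabilities and returns
logprobs = torch.tensor([-0.5, -1.2, -0.3], requires_grad=True)
returns = torch.tensor([1.0, 0.2, -0.5])

# REINFORCE objective: maximize E[logprob * return] <=> minimize its negation
policy_loss = -torch.mean(logprobs * returns)

# optional entropy term, mirroring the entropy_regularization branch above
entropy_regularization = 0.01
entropy = -(torch.exp(logprobs) * logprobs).mean()
policy_loss = policy_loss + entropy_regularization * entropy

policy_loss.backward()
print(policy_loss.item(), logprobs.grad)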