Source code for catalyst.rl.core.trajectory_sampler

from typing import Union  # isort:skip
from ctypes import c_bool
import multiprocessing as mp

import numpy as np

import torch

from catalyst.rl import utils
from catalyst.rl.utils.buffer import get_buffer
from catalyst.utils import tools
from .agent import ActorSpec, CriticSpec
from .environment import EnvironmentSpec
from .policy_handler import PolicyHandler


class TrajectorySampler:
    """Samples trajectories from an environment with a given agent."""

    def __init__(
        self,
        env: EnvironmentSpec,
        agent: Union[ActorSpec, CriticSpec],
        device,
        deterministic: bool = False,
        initial_capacity: int = int(1e3),
        sampling_flag: mp.Value = None
    ):
        self.env = env
        self.agent = agent
        self._device = device
        self._deterministic = deterministic
        self._initial_capacity = initial_capacity
        self._policy_handler = PolicyHandler(
            env=self.env, agent=self.agent, device=device
        )
        self._sampling_flag = sampling_flag or mp.Value(c_bool, True)

        self._init_buffers()

    def _init_buffers(self):
        # small probe buffer, used only to infer dtypes for the dynamic arrays
        sample_size = 3
        observations_, observations_dtype = get_buffer(
            capacity=sample_size,
            space=self.env.observation_space,
            mode="numpy"
        )
        observations_shape = (None, ) \
            if observations_.dtype.fields is not None \
            else (None, ) + tuple(self.env.observation_space.shape)
        self.observations = tools.DynamicArray(
            array_or_shape=observations_shape,
            capacity=int(self._initial_capacity),
            dtype=observations_dtype
        )

        actions_, actions_dtype = get_buffer(
            capacity=sample_size,
            space=self.env.action_space,
            mode="numpy"
        )
        actions_shape = (None, ) \
            if actions_.dtype.fields is not None \
            else (None, ) + tuple(self.env.action_space.shape)
        self.actions = tools.DynamicArray(
            array_or_shape=actions_shape,
            capacity=int(self._initial_capacity),
            dtype=actions_dtype
        )

        self.rewards = tools.DynamicArray(
            array_or_shape=(None, ),
            dtype=np.float32,
            capacity=int(self._initial_capacity)
        )
        self.dones = tools.DynamicArray(
            array_or_shape=(None, ),
            dtype=bool,
            capacity=int(self._initial_capacity)
        )

    def _init_with_observation(self, observation):
        self.observations.append(observation)

    def _put_transition(self, transition):
        """transition = [o_tp1, a_t, r_t, d_t]"""
        o_tp1, a_t, r_t, d_t = transition
        self.observations.append(o_tp1)
        self.actions.append(a_t)
        self.rewards.append(r_t)
        self.dones.append(d_t)

    def _get_states_history(self, history_len=None):
        history_len = history_len or self.env.history_len
        states = [
            self.get_state(history_len=history_len, index=i)
            for i in range(len(self.observations))
        ]
        states = np.array(states)
        return states
    def get_state(self, index=None, history_len=None):
        index = index if index is not None else len(self.dones)
        history_len = history_len \
            if history_len is not None \
            else self.env.history_len

        state = np.zeros(
            (history_len, ) + tuple(self.observations.shape[1:]),
            dtype=self.observations.dtype
        )
        # take the (up to) `history_len` most recent observations ending
        # at `index`; earlier slots stay zero-padded
        indices = np.arange(max(0, index - history_len + 1), index + 1)
        state[-len(indices):] = self.observations[indices]
        return state
    def get_trajectory(self):
        # the buffer holds one more observation (the final o_tp1) than there
        # are transitions, so the last one is dropped to align all lengths
        trajectory = (
            np.array(self.observations[:-1]),
            np.array(self.actions),
            np.array(self.rewards),
            np.array(self.dones)
        )
        return trajectory
    @torch.no_grad()
    def reset(self, exploration_strategy=None):
        if not self._deterministic:
            from catalyst.rl.exploration import \
                ParameterSpaceNoise, OrnsteinUhlenbeckProcess

            if isinstance(exploration_strategy, OrnsteinUhlenbeckProcess):
                exploration_strategy.reset_state(
                    self.env.action_space.shape[0]
                )

            if isinstance(exploration_strategy, ParameterSpaceNoise) \
                    and len(self.observations) > 1:
                states = self._get_states_history()
                states = utils.any2device(states, device=self._device)
                exploration_strategy.update_actor(self.agent, states)

        observation = self.env.reset()
        self._init_buffers()
        self._init_with_observation(observation)
    def sample(self, exploration_strategy=None):
        reward, raw_reward, num_steps, done_t = 0, 0, 0, False
        info = {}

        while not done_t and self._sampling_flag.value:
            state_t = self.get_state()
            action_t = self._policy_handler.action_fn(
                agent=self.agent,
                state=state_t,
                device=self._device,
                exploration_strategy=exploration_strategy,
                deterministic=self._deterministic
            )

            observation_tp1, reward_t, done_t, info = \
                self.env.step(action_t)
            reward += reward_t
            raw_reward += info.get("raw_reward", reward_t)

            transition = [observation_tp1, action_t, reward_t, done_t]
            self._put_transition(transition)
            num_steps += 1

        if not self._sampling_flag.value:
            return None, None

        trajectory = self.get_trajectory()
        trajectory_info = {"reward": reward, "num_steps": num_steps}
        if info and "raw_trajectory" in info:
            raw_trajectory = info["raw_trajectory"]
            trajectory_info["raw_trajectory"] = raw_trajectory
            # with a frame-skip wrapper the raw trajectory carries the
            # per-frame rewards, so report their sum as the episode reward
            trajectory_info["reward"] = np.sum(raw_trajectory[2])
            # raw_num_steps may differ from num_steps when a frame-skip
            # wrapper is used
            raw_num_steps = len(raw_trajectory[0])
            assert all(len(x) == raw_num_steps for x in raw_trajectory)
        trajectory_info["raw_reward"] = raw_reward

        assert all(len(x) == num_steps for x in trajectory)
        return trajectory, trajectory_info
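
For orientation, here is a minimal usage sketch, not part of the module itself. It assumes a concrete `env` (an `EnvironmentSpec` implementation) and `actor` (an `ActorSpec`) have already been constructed; both names are placeholders.

# Hypothetical usage sketch; `env` and `actor` are assumed to exist.
sampler = TrajectorySampler(
    env=env,
    agent=actor,
    device=torch.device("cpu"),
    deterministic=False
)
sampler.reset()  # resets the environment and reinitializes the buffers
trajectory, trajectory_info = sampler.sample()
# sample() returns (None, None) if the sampling flag was switched off
if trajectory is not None:
    observations, actions, rewards, dones = trajectory
    print(trajectory_info["reward"], trajectory_info["num_steps"])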
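
The zero-padded frame stacking inside get_state can also be illustrated standalone; the toy observations below are made up for the example.

import numpy as np

observations = np.array([[1.0], [2.0]], dtype=np.float32)  # o_0, o_1
history_len, index = 4, 1
state = np.zeros((history_len, 1), dtype=np.float32)
# same indexing as get_state(): a window of at most `history_len`
# observations ending at `index`
indices = np.arange(max(0, index - history_len + 1), index + 1)  # [0, 1]
state[-len(indices):] = observations[indices]
# state == [[0.], [0.], [1.], [2.]]; older slots stay zero-padded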