Source code for catalyst.rl.onpolicy.trainer

from typing import Dict  # isort:skip

import time

import numpy as np

import torch
from torch.utils.data import DataLoader

from catalyst.rl import utils
from catalyst.rl.core import TrainerSpec


def _get_states_from_observations(observations: np.ndarray, history_len=1):
    """
    DB stores observations but not states.
    This function creates states from observations
    by adding new dimension of size (history_len).
    """
    states = np.concatenate(
        [np.expand_dims(np.zeros_like(observations), 1)] * history_len, axis=1
    )
    for i in range(history_len - 1):
        pivot = history_len - i - 1
        states[pivot:, i, ...] = observations[:-pivot, ...]
    states[:, -1, ...] = observations

    # structured numpy array: rebuild the dtype with the extra history axis
    if observations.dtype.fields is not None:
        states_dtype = []
        for key, value in observations.dtype.fields.items():
            states_dtype.append(
                (key, value[0].base, (history_len, ) + tuple(value[0].shape))
            )
        states_dtype = np.dtype(states_dtype)
        states_ = np.empty(len(observations), dtype=states_dtype)

        for key in observations.dtype.fields.keys():
            states_[key] = states[key]

        states = states_

    return states

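# A minimal usage sketch, not part of the upstream catalyst module: the helper
# name ``_demo_states_from_observations`` is hypothetical and only illustrates
# how a history of length 3 is stacked over a plain (unstructured) observation
# array, with zero-padding for the earliest steps.
def _demo_states_from_observations():
    observations = np.array([[1.], [2.], [3.], [4.]], dtype=np.float32)
    states = _get_states_from_observations(observations, history_len=3)
    assert states.shape == (4, 3, 1)
    # the first state is zero-padded and ends with the first observation
    assert np.allclose(states[0].ravel(), [0., 0., 1.])
    # the last state holds the three most recent observations
    assert np.allclose(states[-1].ravel(), [2., 3., 4.])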

class OnpolicyTrainer(TrainerSpec):
    def _init(
        self,
        num_mini_epochs: int = 1,
        min_num_trajectories: int = 100,
        rollout_batch_size: int = None
    ):
        super()._init()
        self.num_mini_epochs = num_mini_epochs
        self.min_num_trajectories = min_num_trajectories
        self.max_num_transitions = self.min_num_transitions * 3
        self.rollout_batch_size = rollout_batch_size

    def _get_rollout_in_batches(self, states, actions, rewards, dones):
        if self.rollout_batch_size is None:
            return self.algorithm.get_rollout(states, actions, rewards, dones)

        indices = np.arange(
            0,
            len(states) + self.rollout_batch_size - 1,
            self.rollout_batch_size
        )
        rollout = None
        for i_from, i_to in utils.pairwise(indices):
            states_batch = states[i_from:i_to + 1]
            actions_batch = actions[i_from:i_to + 1]
            rewards_batch = rewards[i_from:i_to + 1]
            dones_batch = dones[i_from:i_to + 1]
            rollout_batch = self.algorithm.get_rollout(
                states_batch, actions_batch, rewards_batch, dones_batch
            )
            rollout = rollout_batch \
                if rollout is None \
                else utils.append_dict(rollout, rollout_batch)
        return rollout

    def _fetch_trajectories(self):
        # cleanup trajectories
        self.db_server.del_trajectory()

        num_trajectories = 0
        num_transitions = 0

        del self.replay_buffer
        rollout_spec = self.algorithm.get_rollout_spec()
        self.replay_buffer = utils.OnpolicyRolloutBuffer(
            state_space=self.env_spec.state_space,
            action_space=self.env_spec.action_space,
            capacity=self.max_num_transitions,
            **rollout_spec
        )

        # start samplers
        self.db_server.push_message(self.db_server.Message.ENABLE_SAMPLING)

        start_time = time.time()
        while num_trajectories < self.min_num_trajectories \
                and num_transitions < self.min_num_transitions:
            trajectories_percentage = \
                100 * num_trajectories / self.min_num_trajectories
            trajectories_stats = \
                f"{num_trajectories:09d} / " \
                f"{self.min_num_trajectories:09d} " \
                f"({trajectories_percentage:5.2f}%)"
            transitions_percentage = \
                100 * num_transitions / self.min_num_transitions
            transitions_stats = \
                f"{num_transitions:09d} / " \
                f"{self.min_num_transitions:09d} " \
                f"({transitions_percentage:5.2f}%)"
            print(
                f"trajectories: {trajectories_stats}\t"
                f"transitions: {transitions_stats}\t"
            )

            try:
                trajectory = self.db_server.get_trajectory()
                assert trajectory is not None
            except AssertionError:
                time.sleep(1.0)
                continue

            num_trajectories += 1
            num_transitions += len(trajectory[-1])

            observations, actions, rewards, dones = trajectory
            states = _get_states_from_observations(
                observations, self.env_spec.history_len
            )
            rollout = self._get_rollout_in_batches(
                states, actions, rewards, dones
            )
            self.replay_buffer.push_rollout(
                state=states,
                action=actions,
                reward=rewards,
                **rollout,
            )

        # stop samplers
        self.db_server.push_message(self.db_server.Message.DISABLE_SAMPLING)

        self._num_trajectories += num_trajectories
        self._num_transitions += num_transitions

        # @TODO: refactor
        self.algorithm.postprocess_buffer(
            self.replay_buffer.buffers,
            len(self.replay_buffer)
        )

        elapsed_time = time.time() - start_time
        self.logger.add_scalar("fetch time", elapsed_time, self.epoch)

    def _run_epoch(self) -> Dict:
        sampler = utils.OnpolicyRolloutSampler(
            buffer=self.replay_buffer,
            num_mini_epochs=self.num_mini_epochs
        )
        loader = DataLoader(
            dataset=self.replay_buffer,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=torch.cuda.is_available(),
            sampler=sampler
        )
        metrics = self._run_loader(loader)

        updates_per_sample = self.num_updates / self._num_transitions
        metrics.update({
            "num_trajectories": self._num_trajectories,
            "num_transitions": self._num_transitions,
            "buffer_size": len(self.replay_buffer),
            "updates_per_sample": updates_per_sample
        })
        return metrics

    def _run_train_stage(self):
        self.db_server.push_message(self.db_server.Message.ENABLE_TRAINING)

        epoch_limit = self._epoch_limit or np.iinfo(np.int32).max
        while self.epoch < epoch_limit:
            try:
                # get trajectories
                self._fetch_trajectories()
                # train & update
                self._run_epoch_loop()
            except Exception as ex:
                self.db_server.push_message(
                    self.db_server.Message.DISABLE_TRAINING
                )
                raise ex

        self.db_server.push_message(self.db_server.Message.DISABLE_TRAINING)
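

# A minimal sketch, likewise hypothetical and not part of the upstream module,
# of the slicing pattern used by ``OnpolicyTrainer._get_rollout_in_batches``:
# consecutive rollout batches overlap by one transition because each slice
# ends at ``i_to + 1``.
def _demo_rollout_batch_slices(num_transitions: int = 10, batch_size: int = 4):
    indices = np.arange(0, num_transitions + batch_size - 1, batch_size)
    slices = [(i_from, i_to + 1) for i_from, i_to in utils.pairwise(indices)]
    # with the defaults this yields the index ranges (0, 5), (4, 9), (8, 13)
    return slices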