Source code for catalyst.rl.offpolicy.trainer

from typing import Dict  # isort:skip
import threading
import time

import numpy as np

import torch
from torch.utils.data import DataLoader

from catalyst.rl import utils
from catalyst.rl.core import DBSpec, TrainerSpec


def _db2buffer_loop(
    db_server: DBSpec,
    buffer: utils.OffpolicyReplayBuffer,
):
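    """
    Background loop that moves trajectories from the database server
    into the replay buffer.

    Pulls one trajectory at a time from ``db_server`` and tries to push
    it into ``buffer``; if the buffer cannot accept it yet, the push is
    retried after a short sleep. The loop exits once the database has no
    trajectories left and training is disabled. Trajectories that raise
    an exception are logged and dropped.
    """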
    trajectory = None
    while True:
        try:
            if trajectory is None:
                trajectory = db_server.get_trajectory()

            if trajectory is not None:
                if buffer.push_trajectory(trajectory):
                    trajectory = None
                else:
                    time.sleep(1.0)
            else:
                if not db_server.training_enabled:
                    return
                time.sleep(1.0)
        except Exception as ex:
            print("=" * 80)
            print("Something go wrong with trajectory:")
            print(ex)
            print(trajectory)
            print("=" * 80)
            trajectory = None


class OffpolicyTrainer(TrainerSpec):
    def _init(
        self,
        target_update_period: int = 1,
        replay_buffer_size: int = int(1e6),
        replay_buffer_mode: str = "numpy",
        epoch_len: int = int(1e2),
        max_updates_per_sample: int = None,
        min_transitions_per_epoch: int = None,
    ):
        super()._init()

        # updates configuration
        # (actor_period, critic_period)
        self.actor_update_period, self.critic_update_period = \
            utils.make_tuple(target_update_period)
        self.actor_updates = 0
        self.critic_updates = 0

        self.epoch_len = epoch_len
        self.max_updates_per_sample = max_updates_per_sample or np.inf
        self.min_transitions_per_epoch = min_transitions_per_epoch or -np.inf
        self.last_epoch_transitions = 0

        self.replay_buffer = utils.OffpolicyReplayBuffer(
            observation_space=self.env_spec.observation_space,
            action_space=self.env_spec.action_space,
            capacity=replay_buffer_size,
            history_len=self.env_spec.history_len,
            n_step=self.algorithm.n_step,
            gamma=self.algorithm.gamma,
            mode=replay_buffer_mode,
            logdir=self.logdir
        )

        self.replay_sampler = utils.OffpolicyReplaySampler(
            buffer=self.replay_buffer,
            epoch_len=self.epoch_len,
            batch_size=self.batch_size
        )

        self.loader = DataLoader(
            dataset=self.replay_buffer,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=torch.cuda.is_available(),
            sampler=self.replay_sampler
        )

        self._db_loop_thread = None

    def _start_db_loop(self):
        self._db_loop_thread = threading.Thread(
            target=_db2buffer_loop,
            kwargs={
                "db_server": self.db_server,
                "buffer": self.replay_buffer,
            }
        )
        self._db_loop_thread.start()

    def _update_target_weights(self, update_step) -> Dict:
        output = {}

        if not self.env_spec.discrete_actions:
            if update_step % self.actor_update_period == 0:
                self.algorithm.target_actor_update()
                self.actor_updates += 1
                output["num_actor_updates"] = self.actor_updates

        if update_step % self.critic_update_period == 0:
            self.algorithm.target_critic_update()
            self.critic_updates += 1
            output["num_critic_updates"] = self.critic_updates

        return output

    def _run_epoch(self) -> Dict:
        self.replay_buffer.recalculate_index()

        expected_num_updates = (
            self.num_updates + len(self.loader) * self.loader.batch_size
        )
        expected_updates_per_sample = (
            expected_num_updates / self.replay_buffer.num_transitions
        )
        min_epoch_transitions = \
            self.last_epoch_transitions + self.min_transitions_per_epoch

        while expected_updates_per_sample > self.max_updates_per_sample \
                or self.replay_buffer.num_transitions < min_epoch_transitions:
            time.sleep(5.0)
            self.replay_buffer.recalculate_index()
            expected_num_updates = (
                self.num_updates + len(self.loader) * self.loader.batch_size
            )
            expected_updates_per_sample = (
                expected_num_updates / self.replay_buffer.num_transitions
            )

        metrics = self._run_loader(self.loader)
        self.last_epoch_transitions = self.replay_buffer.num_transitions

        updates_per_sample = (
            self.num_updates / self.replay_buffer.num_transitions
        )
        metrics.update(
            {
                "num_trajectories": self.replay_buffer.num_trajectories,
                "num_transitions": self.replay_buffer.num_transitions,
                "buffer_size": len(self.replay_buffer),
                "updates_per_sample": updates_per_sample
            }
        )

        return metrics

    def _fetch_initial_buffer(self):
        buffer_size = len(self.replay_buffer)

        while buffer_size < self.min_num_transitions:
            self.replay_buffer.recalculate_index()

            num_trajectories = self.replay_buffer.num_trajectories
            num_transitions = self.replay_buffer.num_transitions
            buffer_size = len(self.replay_buffer)

            metrics = [
                f"fps: {0:7.1f}",
                f"updates per sample: {0:7.1f}",
                f"trajectories: {num_trajectories:09d}",
                f"transitions: {num_transitions:09d}",
                f"buffer size: "
                f"{buffer_size:09d}/{self.min_num_transitions:09d}",
            ]
            metrics = " | ".join(metrics)
            print(f"--- {metrics}")

            time.sleep(1.0)

    def _start_train_loop(self):
        self._start_db_loop()
        self.db_server.push_message(self.db_server.Message.ENABLE_TRAINING)
        self.db_server.push_message(self.db_server.Message.ENABLE_SAMPLING)
        self._fetch_initial_buffer()
        self._run_train_stage()