# Source code for catalyst.rl.offpolicy.trainer
from typing import Dict, Optional  # isort:skip
import threading
import time

import numpy as np
import torch
from torch.utils.data import DataLoader

from catalyst.rl import utils
from catalyst.rl.core import DBSpec, TrainerSpec
def _db2buffer_loop(
    db_server: "DBSpec",
    buffer: "utils.OffpolicyReplayBuffer",
    sleep_time: float = 1.0,
):
    """Move trajectories from ``db_server`` into ``buffer`` until told to stop.

    The loop fetches one trajectory at a time via
    ``db_server.get_trajectory()`` and hands it to
    ``buffer.push_trajectory``.

    * If the buffer returns a falsy value, the trajectory is kept and the
      same one is retried after ``sleep_time`` seconds.
    * If the database has nothing to hand out, the loop returns when
      ``db_server.training_enabled`` is falsy. Otherwise it sleeps
      ``sleep_time`` seconds and polls again.
    * Any exception raised while fetching or pushing is printed together
      with the offending trajectory. That trajectory is then dropped, and
      the loop waits ``sleep_time`` seconds before continuing.

    Args:
        db_server: source of trajectories.
        buffer: replay buffer receiving the trajectories.
        sleep_time: seconds to wait between polls and retries. The default
            of ``1.0`` matches the previously hard-coded interval.
    """
    trajectory = None
    while True:
        try:
            if trajectory is None:
                trajectory = db_server.get_trajectory()
            if trajectory is not None:
                if buffer.push_trajectory(trajectory):
                    trajectory = None
                else:
                    # Buffer did not accept it: keep it and retry later.
                    time.sleep(sleep_time)
            else:
                if not db_server.training_enabled:
                    return
                time.sleep(sleep_time)
        except Exception as ex:
            print("=" * 80)
            print("Something go wrong with trajectory:")
            print(ex)
            print(trajectory)
            print("=" * 80)
            trajectory = None
            # Back off before retrying. Without this, a persistently failing
            # db_server makes the loop spin at full speed and flood stdout.
            time.sleep(sleep_time)
class OffpolicyTrainer(TrainerSpec):
    """Trainer that learns from an off-policy replay buffer.

    A background thread runs ``_db2buffer_loop`` (started by
    ``_start_db_loop``) to move trajectories from ``self.db_server`` into
    ``self.replay_buffer``. Training batches are drawn from that buffer
    through ``self.loader``. Attributes such as ``num_updates``,
    ``batch_size``, ``num_workers``, ``env_spec``, ``algorithm``,
    ``logdir``, ``db_server`` and ``min_num_transitions`` are presumably
    provided by ``TrainerSpec``. NOTE(review): confirm against the base
    class.
    """

    def _init(
        self,
        target_update_period: int = 1,
        replay_buffer_size: int = int(1e6),
        replay_buffer_mode: str = "numpy",
        epoch_len: int = int(1e2),
        max_updates_per_sample: Optional[float] = None,
        min_transitions_per_epoch: Optional[int] = None,
    ):
        """Create the replay buffer, sampler and loader, and reset counters.

        Args:
            target_update_period: target-network update period in update
                steps. It is expanded by ``utils.make_tuple`` into separate
                ``(actor, critic)`` periods.
            replay_buffer_size: capacity of the replay buffer.
            replay_buffer_mode: storage mode forwarded to
                ``utils.OffpolicyReplayBuffer``.
            epoch_len: epoch length forwarded to the replay sampler.
            max_updates_per_sample: an epoch only starts once the expected
                ``num_updates / num_transitions`` ratio is at most this value.
                ``None`` (or any falsy value) means no limit.
            min_transitions_per_epoch: minimum number of transitions that
                must be collected since the previous epoch before the next
                one starts. ``None`` (or any falsy value) means no minimum.
        """
        super()._init()
        # Target-network update periods: (actor_period, critic_period).
        self.actor_update_period, self.critic_update_period = (
            utils.make_tuple(target_update_period)
        )
        self.actor_updates = 0
        self.critic_updates = 0

        self.epoch_len = epoch_len
        # `or` deliberately maps None (and 0) to "no constraint".
        self.max_updates_per_sample = max_updates_per_sample or np.inf
        self.min_transitions_per_epoch = min_transitions_per_epoch or -np.inf
        self.last_epoch_transitions = 0

        self.replay_buffer = utils.OffpolicyReplayBuffer(
            observation_space=self.env_spec.observation_space,
            action_space=self.env_spec.action_space,
            capacity=replay_buffer_size,
            history_len=self.env_spec.history_len,
            n_step=self.algorithm.n_step,
            gamma=self.algorithm.gamma,
            mode=replay_buffer_mode,
            logdir=self.logdir
        )
        self.replay_sampler = utils.OffpolicyReplaySampler(
            buffer=self.replay_buffer,
            epoch_len=self.epoch_len,
            batch_size=self.batch_size
        )
        self.loader = DataLoader(
            dataset=self.replay_buffer,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=torch.cuda.is_available(),
            sampler=self.replay_sampler
        )
        self._db_loop_thread = None

    def _start_db_loop(self):
        """Start the background thread that fills the replay buffer."""
        self._db_loop_thread = threading.Thread(
            target=_db2buffer_loop,
            kwargs={
                "db_server": self.db_server,
                "buffer": self.replay_buffer,
            }
        )
        self._db_loop_thread.start()

    def _update_target_weights(self, update_step) -> Dict:
        """Run the target-network updates that are due at ``update_step``.

        The actor target is only updated for continuous action spaces
        (``not self.env_spec.discrete_actions``).

        Returns:
            dict with ``num_actor_updates`` and/or ``num_critic_updates``
            for the targets that were updated at this step.
        """
        output = {}
        if not self.env_spec.discrete_actions:
            if update_step % self.actor_update_period == 0:
                self.algorithm.target_actor_update()
                self.actor_updates += 1
                output["num_actor_updates"] = self.actor_updates
        if update_step % self.critic_update_period == 0:
            self.algorithm.target_critic_update()
            self.critic_updates += 1
            output["num_critic_updates"] = self.critic_updates
        return output

    def _need_more_transitions(self, min_epoch_transitions) -> bool:
        """Return True if the next epoch must wait for more data.

        Waiting is required when any of these holds:

        * the buffer holds no transitions yet. There is nothing to sample,
          and the ratio below would divide by zero;
        * one more epoch would push the expected
          ``num_updates / num_transitions`` ratio above
          ``self.max_updates_per_sample``;
        * fewer than ``min_epoch_transitions`` transitions have been
          collected so far.
        """
        num_transitions = self.replay_buffer.num_transitions
        if num_transitions <= 0:
            return True
        expected_num_updates = (
            self.num_updates + len(self.loader) * self.loader.batch_size
        )
        expected_updates_per_sample = expected_num_updates / num_transitions
        return (
            expected_updates_per_sample > self.max_updates_per_sample
            or num_transitions < min_epoch_transitions
        )

    def _run_epoch(self) -> Dict:
        """Wait until enough data is available, then run one training epoch.

        Returns:
            the metrics from ``self._run_loader``, extended with replay
            buffer statistics and the achieved ``updates_per_sample``.
        """
        min_epoch_transitions = (
            self.last_epoch_transitions + self.min_transitions_per_epoch
        )
        self.replay_buffer.recalculate_index()
        while self._need_more_transitions(min_epoch_transitions):
            time.sleep(5.0)
            self.replay_buffer.recalculate_index()

        metrics = self._run_loader(self.loader)

        # Read once so every reported number refers to the same snapshot.
        num_transitions = self.replay_buffer.num_transitions
        self.last_epoch_transitions = num_transitions
        updates_per_sample = self.num_updates / num_transitions
        metrics.update(
            {
                "num_trajectories": self.replay_buffer.num_trajectories,
                "num_transitions": num_transitions,
                "buffer_size": len(self.replay_buffer),
                "updates_per_sample": updates_per_sample
            }
        )
        return metrics

    def _fetch_initial_buffer(self):
        """Block until ``len(self.replay_buffer)`` reaches ``min_num_transitions``.

        Prints a progress line about once per second. The fps and
        updates-per-sample fields are fixed placeholders (0) at this stage.
        """
        buffer_size = len(self.replay_buffer)
        while buffer_size < self.min_num_transitions:
            self.replay_buffer.recalculate_index()
            num_trajectories = self.replay_buffer.num_trajectories
            num_transitions = self.replay_buffer.num_transitions
            buffer_size = len(self.replay_buffer)
            metrics = [
                f"fps: {0:7.1f}",
                f"updates per sample: {0:7.1f}",
                f"trajectories: {num_trajectories:09d}",
                f"transitions: {num_transitions:09d}",
                f"buffer size: "
                f"{buffer_size:09d}/{self.min_num_transitions:09d}",
            ]
            metrics = " | ".join(metrics)
            print(f"--- {metrics}")
            time.sleep(1.0)

    def _start_train_loop(self):
        """Start buffer filling, enable training/sampling, then train."""
        self._start_db_loop()
        self.db_server.push_message(self.db_server.Message.ENABLE_TRAINING)
        self.db_server.push_message(self.db_server.Message.ENABLE_SAMPLING)
        self._fetch_initial_buffer()
        self._run_train_stage()