# Source code for catalyst.rl.db.redis

from redis import Redis

from catalyst.rl import utils
from catalyst.rl.core import DBSpec


class RedisDB(DBSpec):
    """Redis-backed implementation of ``DBSpec``.

    Stores control flags, packed trajectories and the latest packed
    checkpoint on a Redis server.

    Key layout:
        * ``{prefix}_training_flag`` / ``{prefix}_sampling_flag`` -- flags
        * ``{prefix}_checkpoint`` -- latest packed checkpoint
        * ``trajectories`` / ``raw_trajectories`` -- Redis lists of packed
          trajectories. NOTE(review): these list keys are NOT prefixed, so
          instances with different prefixes share them -- confirm intended.
    """

    def __init__(
        self, host="127.0.0.1", port=12000, prefix=None, sync_epoch=False
    ):
        """
        Args:
            host: Redis server host.
            port: Redis server port.
            prefix: namespace prepended to flag and checkpoint keys;
                ``None`` is treated as an empty prefix.
            sync_epoch: if True, ``get_trajectory`` discards trajectories
                whose recorded epoch differs from the current epoch.
        """
        self._server = Redis(host=host, port=port)
        self._prefix = "" if prefix is None else prefix
        # Read cursor into the "trajectories" list used by get_trajectory.
        self._index = 0
        # Epoch of the checkpoint most recently put/fetched by this instance.
        self._epoch = 0
        self._sync_epoch = sync_epoch

    def _key(self, name):
        """Return ``name`` namespaced with this instance's prefix."""
        return f"{self._prefix}_{name}"

    def _set_flag(self, key, value):
        self._server.set(self._key(key), value)

    def _get_flag(self, key, default=None):
        flag = self._server.get(self._key(key))
        return flag if flag is not None else default

    @property
    def training_enabled(self) -> bool:
        """Whether training is enabled (True if the flag was never set)."""
        flag = self._get_flag("training_flag", 1)  # enabled by default
        return int(flag) == 1

    @property
    def sampling_enabled(self) -> bool:
        """Whether sampling is enabled (False if the flag was never set)."""
        flag = self._get_flag("sampling_flag", -1)  # disabled by default
        return int(flag) == 1

    @property
    def epoch(self) -> int:
        """Epoch of the last checkpoint put or fetched (0 before any)."""
        return self._epoch

    @property
    def num_trajectories(self) -> int:
        """Length of the ``trajectories`` list minus one.

        NOTE(review): because of the ``- 1`` this is -1 for an empty list;
        presumably callers rely on it (e.g. as the last index) -- confirm
        before changing.
        """
        return self._server.llen("trajectories") - 1

    def push_message(self, message: DBSpec.Message):
        """Update the control flags according to ``message``.

        ``DISABLE_TRAINING`` clears the sampling flag as well as the
        training flag.

        Raises:
            NotImplementedError: if ``message`` is not recognized.
        """
        if message == DBSpec.Message.ENABLE_SAMPLING:
            self._set_flag("sampling_flag", 1)
        elif message == DBSpec.Message.DISABLE_SAMPLING:
            self._set_flag("sampling_flag", 0)
        elif message == DBSpec.Message.DISABLE_TRAINING:
            self._set_flag("sampling_flag", 0)
            self._set_flag("training_flag", 0)
        elif message == DBSpec.Message.ENABLE_TRAINING:
            self._set_flag("training_flag", 1)
        else:
            raise NotImplementedError("unknown message", message)

    def put_trajectory(self, trajectory, raw=False):
        """Pack ``trajectory`` with the current epoch and append it.

        Args:
            trajectory: value accepted by ``utils.structed2dict_trajectory``.
            raw: append to ``raw_trajectories`` instead of ``trajectories``.
        """
        payload = {
            "trajectory": utils.structed2dict_trajectory(trajectory),
            "epoch": self._epoch,
        }
        name = "raw_trajectories" if raw else "trajectories"
        self._server.rpush(name, utils.pack(payload))

    def get_trajectory(self, index=None):
        """Fetch a trajectory from the ``trajectories`` list.

        Args:
            index: list position to read; defaults to the internal cursor.

        Returns:
            The trajectory converted by ``utils.dict2structed_trajectory``,
            or None when nothing is stored at ``index`` or, with
            ``sync_epoch`` enabled, its recorded epoch differs from the
            current one. Whenever an item exists at ``index`` the cursor
            moves past it, even if it is then discarded.
        """
        index = index if index is not None else self._index
        packed = self._server.lindex("trajectories", index)
        if packed is None:
            return None
        self._index = index + 1
        unpacked = utils.unpack(packed)
        trajectory, trajectory_epoch = \
            unpacked["trajectory"], unpacked["epoch"]
        if self._sync_epoch and self._epoch != trajectory_epoch:
            return None
        return utils.dict2structed_trajectory(trajectory)

    def del_trajectory(self):
        """Delete the ``trajectories`` list and reset the read cursor.

        ``raw_trajectories`` is left untouched.
        """
        self._server.delete("trajectories")
        self._index = 0

    def put_checkpoint(self, checkpoint, epoch):
        """Store ``checkpoint`` tagged with ``epoch`` and adopt that epoch."""
        self._epoch = epoch
        packed = utils.pack({"checkpoint": checkpoint, "epoch": self._epoch})
        self._server.set(self._key("checkpoint"), packed)

    def get_checkpoint(self):
        """Return the stored checkpoint, or None if none is stored.

        Also sets the current epoch to the checkpoint's recorded epoch.
        """
        packed = self._server.get(self._key("checkpoint"))
        if packed is None:
            return None
        checkpoint = utils.unpack(packed)
        self._epoch = checkpoint.get("epoch")
        return checkpoint["checkpoint"]

    def del_checkpoint(self):
        """Delete the checkpoint stored by ``put_checkpoint``."""
        # Must target the same key as put_checkpoint/get_checkpoint; the
        # previous "{prefix}_weights" key is never written by this class,
        # so the real checkpoint was never removed.
        self._server.delete(self._key("checkpoint"))
# Public API of this module: names exported by `from ... import *`.
__all__ = ["RedisDB"]