Source code for catalyst.rl.db.redis
from redis import Redis
from catalyst.rl import utils
from catalyst.rl.core import DBSpec
class RedisDB(DBSpec):
    """Redis-backed implementation of ``DBSpec``.

    Control flags, trajectories and the latest checkpoint are exchanged
    through a single Redis server.

    Key layout used by this class:

    * flags and the checkpoint live under ``"{prefix}_{name}"`` string keys;
    * trajectories live in the Redis lists ``"trajectories"`` and
      ``"raw_trajectories"``.

    NOTE(review): the trajectory list names are *not* prefixed, so instances
    created with different prefixes on the same server share them -- confirm
    this is intended.
    """

    def __init__(
        self, host="127.0.0.1", port=12000, prefix=None, sync_epoch=False
    ):
        """
        Args:
            host: address of the Redis server.
            port: port of the Redis server.
            prefix: prefix for flag/checkpoint keys; ``None`` means ``""``.
            sync_epoch: if True, ``get_trajectory`` discards trajectories
                whose stored epoch differs from the current epoch.
        """
        self._server = Redis(host=host, port=port)
        self._prefix = "" if prefix is None else prefix
        # Position of the next trajectory to read from the list.
        self._index = 0
        # Epoch of the last checkpoint written or read by this instance.
        self._epoch = 0
        self._sync_epoch = sync_epoch

    def _key(self, name):
        """Return the prefixed Redis key for ``name``."""
        return f"{self._prefix}_{name}"

    def _set_flag(self, key, value):
        """Store ``value`` under the prefixed ``key``."""
        self._server.set(self._key(key), value)

    def _get_flag(self, key, default=None):
        """Return the value under the prefixed ``key``, or ``default``
        when the key does not exist."""
        flag = self._server.get(self._key(key))
        return default if flag is None else flag

    @property
    def training_enabled(self) -> bool:
        """True when the training flag equals 1 (enabled by default)."""
        return int(self._get_flag("training_flag", 1)) == 1

    @property
    def sampling_enabled(self) -> bool:
        """True when the sampling flag equals 1 (disabled by default)."""
        return int(self._get_flag("sampling_flag", -1)) == 1

    @property
    def epoch(self) -> int:
        """Epoch of the last checkpoint written or read by this instance."""
        return self._epoch

    @property
    def num_trajectories(self) -> int:
        """Length of the ``"trajectories"`` list minus one.

        NOTE(review): the reason for the ``- 1`` is not visible in this
        file (an empty list yields -1); kept as-is -- confirm with callers.
        """
        return self._server.llen("trajectories") - 1

    def push_message(self, message: DBSpec.Message):
        """Translate a control message into flag updates.

        Disabling training also disables sampling.

        Raises:
            NotImplementedError: if ``message`` is not a known message.
        """
        if message == DBSpec.Message.ENABLE_SAMPLING:
            self._set_flag("sampling_flag", 1)
        elif message == DBSpec.Message.DISABLE_SAMPLING:
            self._set_flag("sampling_flag", 0)
        elif message == DBSpec.Message.DISABLE_TRAINING:
            self._set_flag("sampling_flag", 0)
            self._set_flag("training_flag", 0)
        elif message == DBSpec.Message.ENABLE_TRAINING:
            self._set_flag("training_flag", 1)
        else:
            raise NotImplementedError("unknown message", message)

    def put_trajectory(self, trajectory, raw=False):
        """Append a trajectory, tagged with the current epoch, to a list.

        Args:
            trajectory: trajectory accepted by
                ``utils.structed2dict_trajectory``.
            raw: if True push to ``"raw_trajectories"`` instead of
                ``"trajectories"``.
        """
        trajectory = utils.structed2dict_trajectory(trajectory)
        payload = utils.pack(
            {"trajectory": trajectory, "epoch": self._epoch}
        )
        name = "raw_trajectories" if raw else "trajectories"
        self._server.rpush(name, payload)

    def get_trajectory(self, index=None):
        """Read one trajectory from the ``"trajectories"`` list.

        Args:
            index: list position to read; defaults to the internal cursor.

        Returns:
            The trajectory converted by ``utils.dict2structed_trajectory``,
            or ``None`` when nothing is stored at ``index`` or when
            ``sync_epoch`` is set and the trajectory's epoch differs from
            the current one.
        """
        index = self._index if index is None else index
        packed = self._server.lindex("trajectories", index)
        if packed is None:
            return None

        # Advance the cursor even if the item is rejected below, so a
        # stale trajectory is not read again.
        self._index = index + 1
        record = utils.unpack(packed)
        trajectory, trajectory_epoch = record["trajectory"], record["epoch"]
        if self._sync_epoch and self._epoch != trajectory_epoch:
            return None
        return utils.dict2structed_trajectory(trajectory)

    def del_trajectory(self):
        """Delete the ``"trajectories"`` list and reset the read cursor."""
        self._server.delete("trajectories")
        self._index = 0

    def put_checkpoint(self, checkpoint, epoch):
        """Store ``checkpoint`` together with ``epoch`` and make ``epoch``
        the current epoch of this instance."""
        self._epoch = epoch
        payload = utils.pack(
            {"checkpoint": checkpoint, "epoch": self._epoch}
        )
        self._server.set(self._key("checkpoint"), payload)

    def get_checkpoint(self):
        """Load the stored checkpoint and adopt its epoch.

        Returns:
            The stored checkpoint, or ``None`` when none is stored.
        """
        packed = self._server.get(self._key("checkpoint"))
        if packed is None:
            return None
        record = utils.unpack(packed)
        self._epoch = record.get("epoch")
        return record["checkpoint"]

    def del_checkpoint(self):
        """Delete the stored checkpoint.

        Uses the same key as ``put_checkpoint``/``get_checkpoint``; the
        previous code deleted ``"{prefix}_weights"``, a key this class never
        writes, so the checkpoint was never actually removed.
        """
        self._server.delete(self._key("checkpoint"))
# Public API of this module: only RedisDB is exported by a star-import.
__all__ = ["RedisDB"]