Source code for catalyst.rl.db.mongo
import datetime
import time
import gridfs
import pymongo
import safitty
from catalyst.rl import utils
from catalyst.rl.core import DBSpec
class MongoDB(DBSpec):
    """MongoDB-backed implementation of ``DBSpec``.

    Trajectories are stored in the ``shared`` database (collections
    ``trajectories`` and ``raw_trajectories``); checkpoints (GridFS,
    collection ``checkpoints``) and control flags (collection ``messages``)
    are stored in the ``agent_<prefix>`` database.

    Operations that are wrapped in ``_with_reconnect`` are retried forever,
    sleeping ``reconnect_timeout`` seconds between attempts, whenever
    ``pymongo.errors.AutoReconnect`` is raised.
    """

    def __init__(
        self,
        host: str = "127.0.0.1",
        port: int = 12000,
        prefix: str = None,
        sync_epoch: bool = False,
        reconnect_timeout: int = 3,
    ):
        """
        Args:
            host: host passed to ``pymongo.MongoClient``
            port: port passed to ``pymongo.MongoClient``
            prefix: suffix of the agent database name
                (``agent_<prefix>``); ``None`` is treated as ``""``
            sync_epoch: if True, ``get_trajectory`` returns ``None`` for
                trajectories whose stored epoch differs from ``self.epoch``
            reconnect_timeout: seconds to sleep before retrying
                after ``pymongo.errors.AutoReconnect``
        """
        self._server = pymongo.MongoClient(host=host, port=port)
        self._prefix = "" if prefix is None else prefix
        self._reconnect_timeout = reconnect_timeout

        self._shared_db = self._server["shared"]
        self._agent_db = self._server[f"agent_{self._prefix}"]

        self._trajectory_collection = self._shared_db["trajectories"]
        self._raw_trajectory_collection = self._shared_db["raw_trajectories"]
        self._checkpoint_collection = gridfs.GridFS(
            self._agent_db, collection="checkpoints"
        )
        self._message_collection = self._agent_db["messages"]

        # Only trajectories dated strictly after this are returned
        # by ``get_trajectory``.
        self._last_datetime = datetime.datetime.min
        self._epoch = 0
        self._sync_epoch = sync_epoch

    def _with_reconnect(self, operation, *args, **kwargs):
        """Call ``operation`` until it stops raising ``AutoReconnect``.

        A loop is used instead of recursion so that a long outage cannot
        exhaust the interpreter stack.

        Args:
            operation: callable performing the database work
            *args: positional arguments forwarded to ``operation``
            **kwargs: keyword arguments forwarded to ``operation``

        Returns:
            whatever ``operation`` returns on the first successful attempt
        """
        while True:
            try:
                return operation(*args, **kwargs)
            except pymongo.errors.AutoReconnect:
                time.sleep(self._reconnect_timeout)

    def _set_flag(self, key, value):
        """Create or overwrite the ``messages`` document stored for ``key``."""
        self._with_reconnect(
            self._message_collection.replace_one,
            {"key": key},
            {"key": key, "value": value},
            upsert=True,
        )

    def _get_flag(self, key, default=None):
        """Return the value stored for ``key``, or ``default`` if absent."""
        flag_obj = self._with_reconnect(
            self._message_collection.find_one, {"key": {"$eq": key}}
        )
        return safitty.get(flag_obj, "value", default=default)

    @property
    def training_enabled(self) -> bool:
        """True when ``training_flag`` equals 1 (enabled by default)."""
        flag = self._get_flag("training_flag", 1)  # enabled by default
        return int(flag) == 1

    @property
    def sampling_enabled(self) -> bool:
        """True when ``sampling_flag`` equals 1 (disabled by default)."""
        flag = self._get_flag("sampling_flag", -1)  # disabled by default
        return int(flag) == 1

    @property
    def epoch(self) -> int:
        """Epoch of the last checkpoint written or read by this object."""
        return self._epoch

    @property
    def num_trajectories(self) -> int:
        """Document count of the ``trajectories`` collection minus one."""
        # ``Collection.count()`` is deprecated since PyMongo 3.7 and removed
        # in 4.0; ``estimated_document_count()`` is its unfiltered
        # replacement.
        count = self._with_reconnect(
            self._trajectory_collection.estimated_document_count
        )
        # NOTE(review): the ``- 1`` offset is kept from the original code;
        # its reason is not visible in this file - confirm with callers.
        return count - 1

    def push_message(self, message: DBSpec.Message):
        """Translate a control message into flag updates.

        Args:
            message: one of the ``DBSpec.Message`` values

        Raises:
            NotImplementedError: if ``message`` is not a handled value
        """
        if message == DBSpec.Message.ENABLE_SAMPLING:
            self._set_flag("sampling_flag", 1)
        elif message == DBSpec.Message.DISABLE_SAMPLING:
            self._set_flag("sampling_flag", 0)
        elif message == DBSpec.Message.DISABLE_TRAINING:
            # Disabling training also disables sampling.
            self._set_flag("sampling_flag", 0)
            self._set_flag("training_flag", 0)
        elif message == DBSpec.Message.ENABLE_TRAINING:
            self._set_flag("training_flag", 1)
        else:
            raise NotImplementedError("unknown message", message)

    def put_trajectory(self, trajectory, raw=False):
        """Store a trajectory together with a UTC timestamp and the epoch.

        Args:
            trajectory: trajectory accepted by
                ``utils.structed2dict_trajectory``
            raw: if True, store into ``raw_trajectories``
                instead of ``trajectories``
        """
        # Serialize once; only the insert itself needs to be retried.
        packed = utils.pack(utils.structed2dict_trajectory(trajectory))
        collection = (
            self._raw_trajectory_collection
            if raw
            else self._trajectory_collection
        )

        def _insert():
            # A fresh document per attempt: ``insert_one`` adds ``_id`` to
            # the dict it is given, and the timestamp should reflect the
            # attempt that is actually made.
            collection.insert_one(
                {
                    "trajectory": packed,
                    "date": datetime.datetime.utcnow(),
                    "epoch": self._epoch,
                }
            )

        self._with_reconnect(_insert)

    def get_trajectory(self, index=None):
        """Fetch a trajectory dated after the previously returned one.

        Args:
            index: must be ``None``; indexed access is not supported

        Returns:
            the trajectory converted by ``utils.dict2structed_trajectory``,
            or ``None`` when nothing newer is stored or when ``sync_epoch``
            is set and the trajectory's epoch differs from ``self.epoch``
        """
        assert index is None, "indexed trajectory access is not supported"

        trajectory_obj = self._with_reconnect(
            self._trajectory_collection.find_one,
            {"date": {"$gt": self._last_datetime}},
        )
        if trajectory_obj is None:
            return None

        # Advance the cursor even if the trajectory is rejected below.
        self._last_datetime = trajectory_obj["date"]

        # Check the epoch first so that stale payloads are never unpacked.
        if self._sync_epoch and self._epoch != trajectory_obj["epoch"]:
            return None

        trajectory = utils.unpack(trajectory_obj["trajectory"])
        return utils.dict2structed_trajectory(trajectory)

    def del_trajectory(self):
        """Drop the whole ``trajectories`` collection."""
        self._with_reconnect(self._trajectory_collection.drop)

    def put_checkpoint(self, checkpoint, epoch):
        """Replace the stored checkpoint and remember its epoch.

        Args:
            checkpoint: object serialized with ``utils.pack``
            epoch: epoch stored alongside the checkpoint;
                also becomes ``self.epoch``
        """
        self._epoch = epoch
        checkpoint_ = utils.pack(checkpoint)

        def _replace():
            if self._checkpoint_collection.exists({"filename": "checkpoint"}):
                self.del_checkpoint()
            self._checkpoint_collection.put(
                checkpoint_,
                encoding="ascii",
                filename="checkpoint",
                epoch=self._epoch,
            )

        self._with_reconnect(_replace)

    def get_checkpoint(self):
        """Load the stored checkpoint and adopt its epoch.

        Returns:
            the checkpoint restored by ``utils.unpack``,
            or ``None`` if no checkpoint is stored
        """
        checkpoint_obj = self._with_reconnect(
            self._checkpoint_collection.find_one, {"filename": "checkpoint"}
        )
        if checkpoint_obj is None:
            return None

        checkpoint = checkpoint_obj.read().decode("ascii")
        self._epoch = checkpoint_obj.epoch
        return utils.unpack(checkpoint)

    def del_checkpoint(self):
        """Delete the stored checkpoint; do nothing if there is none."""

        def _delete():
            checkpoint_obj = self._checkpoint_collection.find_one(
                {"filename": "checkpoint"}
            )
            # ``find_one`` returns None when no file matches.
            if checkpoint_obj is None:
                return
            self._checkpoint_collection.delete(checkpoint_obj._id)

        self._with_reconnect(_delete)
# Names exported by ``from <this module> import *``.
__all__ = ["MongoDB"]