# Source code for catalyst.rl.core.exploration
from typing import List # isort:skip
from copy import deepcopy
import numpy as np
from catalyst.rl.core import EnvironmentSpec
from catalyst.rl.registry import EXPLORATION
class ExplorationStrategy:
    """
    Base class for working with various exploration strategies.

    In discrete case must contain method get_action(q_values).
    In continuous case must contain method get_action(action).
    """

    def __init__(self, power=1.0):
        """
        Args:
            power: initial value stored as the strategy's power.
                It is not range-checked here, unlike in ``set_power``.
        """
        self._power = power

    def set_power(self, value):
        """
        Replace the stored power.

        Args:
            value: new power; must lie in the closed interval [0, 1].

        Raises:
            AssertionError: if ``value`` is outside [0, 1] (NaN included).
        """
        # Raised explicitly rather than through ``assert`` so the check is
        # still enforced when Python runs with -O; the exception type is
        # kept as AssertionError so existing callers see no difference.
        if not 0.0 <= value <= 1.0:
            raise AssertionError(f"power must be in [0, 1], got {value!r}")
        self._power = value
class ExplorationHandler:
    """
    Holds several exploration strategies and samples one of them at random.

    Each positional argument describes one strategy. It must provide the
    keys ``"exploration"`` (name looked up in the ``EXPLORATION`` registry)
    and ``"probability"`` (weight used when sampling). Every remaining item
    is forwarded as a keyword argument to the object the registry returns.
    """

    def __init__(self, *exploration_params, env: EnvironmentSpec):
        """
        Args:
            *exploration_params: one mapping per strategy, see class docs.
                The mappings are deep-copied first, so the caller's
                objects are not modified by the ``pop`` calls below.
            env: environment specification. Required keyword argument;
                it is not used by this constructor.

        Raises:
            AssertionError: if the probabilities do not sum to 1
                (this includes the case of no strategies at all).
        """
        params = deepcopy(exploration_params)
        self.exploration_strategies: List[ExplorationStrategy] = []
        self.probs = []

        for params_ in params:
            exploration_name = params_.pop("exploration")
            probability = params_.pop("probability")
            strategy_fn = EXPLORATION.get(exploration_name)
            # Whatever is left after the two pops configures the strategy.
            strategy = strategy_fn(**params_)
            self.exploration_strategies.append(strategy)
            self.probs.append(probability)

        self.num_strategies = len(self.probs)
        # Explicit raise instead of ``assert`` so the check survives
        # ``python -O``; AssertionError is kept for backward compatibility.
        if not np.isclose(np.sum(self.probs), 1.0):
            raise AssertionError(
                f"strategy probabilities must sum to 1, got {self.probs!r}"
            )

    def set_power(self, value):
        """
        Set the same power on every held strategy.

        Args:
            value: new power; must lie in the closed interval [0, 1].

        Raises:
            AssertionError: if ``value`` is outside [0, 1].
        """
        # Validated up front so no strategy is touched on a bad value,
        # and enforced even under ``python -O``.
        if not 0.0 <= value <= 1.0:
            raise AssertionError(f"power must be in [0, 1], got {value!r}")
        for exploration in self.exploration_strategies:
            exploration.set_power(value=value)

    def get_exploration_strategy(self):
        """
        Draw one strategy at random, weighted by ``self.probs``.

        Returns:
            The selected strategy object.
        """
        strategy_idx = np.random.choice(self.num_strategies, p=self.probs)
        strategy = self.exploration_strategies[strategy_idx]
        return strategy
# Names exported by ``from <module> import *``.
__all__ = [
    "ExplorationStrategy",
    "ExplorationHandler",
]