from typing import Any, Callable, Dict, Iterable, Iterator, List, Mapping, Optional, Tuple, Union
from collections import abc
import inspect
import warnings
from hydra_slayer import functional as F
from hydra_slayer.factory import Factory
# Public API of this module: only the ``Registry`` class is exported.
__all__ = ["Registry"]
# Type of the callbacks accepted by ``Registry.late_add``: a callable that
# receives the registry instance and returns nothing.
LateAddCallback = Callable[["Registry"], None]
class Registry(abc.MutableMapping):
    """
    Universal class allowing to add and access various factories by name.

    Args:
        name_key: key under which the factory name is looked up; it is
            forwarded as ``factory_key`` to the ``hydra_slayer.functional``
            helpers used by ``get_instance`` and ``get_from_params``
        var_key: key to use for aliases; aliases let you identify an item
            and then refer to that item and reuse it multiple times
        attrs_delimiter: delimiter to use for separation of alias and
            attribute of an instance to get
    """

    def __init__(
        self,
        name_key: str = F.DEFAULT_FACTORY_KEY,
        var_key: str = F.DEFAULT_VAR_KEY,
        attrs_delimiter: str = F.DEFAULT_ATTRS_DELIMITER,
    ):
        # Callbacks queued by ``late_add``; executed on the next registry query.
        self._late_add_callbacks: List[LateAddCallback] = []
        # Registered factories, keyed by their public name.
        self._factories: Dict[str, Factory] = {}

        self.name_key = name_key
        self.var_key = var_key
        self.attrs_delimiter = attrs_delimiter

        # The same dict object is handed to every ``get_from_params`` call.
        # NOTE(review): presumably the helper stores resolved aliases in it
        # - confirm against ``hydra_slayer.functional``.
        self._vars_dict = {}

    @staticmethod
    def _get_factory_name(f, provided_name: Optional[str] = None) -> str:
        """
        Resolves the name under which factory ``f`` is registered.

        Args:
            f: factory object
            provided_name: explicit name; when empty or ``None`` the
                factory's ``__name__`` attribute is used instead

        Returns:
            name to register the factory under

        Raises:
            ValueError: if no name was provided and ``f`` has no
                ``__name__``, or if the resulting name is ``"<lambda>"``
        """
        if not provided_name:
            provided_name = getattr(f, "__name__", None)
            if not provided_name:
                raise ValueError(f"Factory '{f}' has no '__name__' and no name was provided")
        if provided_name == "<lambda>":
            raise ValueError("Name for lambda factories must be provided")
        return provided_name

    def _do_late_add(self):
        """
        Runs and then forgets all callbacks queued through ``late_add``.

        The queue is detached *before* the callbacks are invoked, so a
        callback that itself queries the registry (and thereby re-enters
        this method) finds an empty queue instead of recursing endlessly.
        Callbacks queued while others are running are picked up by the
        surrounding ``while`` loop. If a callback raises, the detached
        callbacks are put back so they are retried on the next query.
        """
        while self._late_add_callbacks:
            pending, self._late_add_callbacks = self._late_add_callbacks, []
            try:
                for cb in pending:
                    cb(self)
            except BaseException:
                # Keep the callbacks queued (ahead of any newly added ones)
                # so that a failure does not silently drop registrations.
                self._late_add_callbacks = pending + self._late_add_callbacks
                raise

    def add(
        self,
        factory: Optional[Factory] = None,
        *factories: Factory,
        name: Optional[str] = None,
        **named_factories: Factory,
    ) -> Factory:
        """
        Adds factory to registry with its ``__name__`` attribute or provided
        name. Signature is flexible.

        Args:
            factory: factory instance
            factories: more instances
            name: name to use for the first factory instance,
                if a single instance is passed
            named_factories: factory and their names as keyword arguments

        Returns:
            first factory passed (``None`` if only keyword factories
            were given)

        Raises:
            ValueError: if multiple factories with a single name are provided
            LookupError: if a different factory is already registered under
                one of the names
        """
        if factories and name is not None:
            raise ValueError("Multiple factories with single name are not allowed")

        if factory is not None:
            named_factories[self._get_factory_name(factory, name)] = factory

        if factories:
            named_factories.update({self._get_factory_name(f): f for f in factories})

        if not named_factories:
            warnings.warn("No factories were provided!")

        # Validate everything first so that nothing is registered when any
        # of the names conflicts.
        for factory_name, f in named_factories.items():
            # self._factories[name] != f is a workaround for
            # https://github.com/catalyst-team/catalyst/issues/135
            if factory_name in self._factories and self._factories[factory_name] != f:
                raise LookupError(
                    f"Factory with name '{factory_name}' is already present\n"
                    f"Already registered: '{self._factories[factory_name]}'\n"
                    f"New: '{f}'"
                )

        self._factories.update(named_factories)

        return factory

    def late_add(self, cb: LateAddCallback):
        """
        Allows to prevent cycle imports by delaying some imports till next
        registry query.

        Args:
            cb: callback receives registry and must call its methods to
                register factories
        """
        self._late_add_callbacks.append(cb)

    def add_from_module(
        self,
        module,
        prefix: Union[str, List[str]] = None,
        ignore_all: bool = False,
    ) -> None:
        """
        Adds all factories (classes and functions) present in module.
        If ``__all__`` attribute is present, takes only what is mentioned
        in it; names in ``__all__`` that are not classes or functions of
        the module are skipped.

        Args:
            module: module to scan
            prefix: prefix string for all the module's factories.
                If prefix is a list, all values will be treated as aliases
            ignore_all: if ``True``, ignores ``__all__`` attribute
                of the module

        Raises:
            TypeError: if prefix is not a list or a string
        """
        factories = {
            k: v for k, v in module.__dict__.items() if inspect.isclass(v) or inspect.isfunction(v)
        }

        names_to_add = list(factories)
        if not ignore_all:
            exported = getattr(module, "__all__", None)
            if exported is not None:
                # ``__all__`` may legitimately export non-factory names
                # (constants, submodules, ...); indexing ``factories`` with
                # them would raise KeyError, so keep only real factories.
                names_to_add = [n for n in exported if n in factories]

        if prefix is None:
            prefix = [""]
        elif isinstance(prefix, str):
            prefix = [prefix]
        elif isinstance(prefix, list):
            if any((not isinstance(p, str)) for p in prefix):
                raise TypeError("All prefix in list must be strings")
        else:
            raise TypeError(f"Prefix must be a list or a string, got {type(prefix)}")

        # Every factory is registered once per prefix (prefixes act as aliases).
        to_add = {f"{p}{name}": factories[name] for p in prefix for name in names_to_add}
        self.add(**to_add)

    def get(self, name: str) -> Optional[Factory]:
        """
        Retrieves factory, without creating any objects with it.

        Pending ``late_add`` callbacks are executed first. Names that are
        not registered are passed on to ``F.get_factory``.

        Args:
            name: factory name

        Returns:
            factory by name, or ``None`` if ``name`` is ``None``
        """
        self._do_late_add()

        if name is None:
            return None

        res = self._factories.get(name, None)
        if res is None:
            # Not registered here: let hydra_slayer resolve the name.
            res = F.get_factory(name)
        return res

    def get_if_str(self, obj: Union[str, Factory]):
        """Returns object from the registry if ``obj`` type is string."""
        if isinstance(obj, str):
            return self.get(obj)
        return obj

    def get_instance(self, *args, **kwargs):
        """
        Creates instance through ``F._get_instance``, using ``self.get``
        to resolve the factory and ``self.name_key`` as the factory key.

        Args:
            *args: positional arguments to be passed into the factory
            **kwargs: keyword arguments to be passed into the factory

        Returns:
            created instance
        """
        instance = F._get_instance(
            factory_key=self.name_key, get_factory_func=self.get, args=args, kwargs=kwargs
        )
        return instance

    def get_from_params(
        self, *, shared_params: Optional[Dict[str, Any]] = None, **kwargs
    ) -> Union[Any, Tuple[Any, Mapping[str, Any]]]:
        """
        Creates instance based on configuration dict through
        ``F._recursive_get_from_params``.
        If ``config[name_key]`` is None, ``None`` is returned.

        Args:
            shared_params: params to pass on all levels in case of
                recursive creation
            **kwargs: keyword arguments to be passed into the factory

        Returns:
            first element of the pair returned by
            ``F._recursive_get_from_params`` (the created instance);
            the second element is discarded
        """
        instance, _ = F._recursive_get_from_params(
            factory_key=self.name_key,
            get_factory_func=self.get,
            params=kwargs,
            shared_params=shared_params or {},
            var_key=self.var_key,
            attrs_delimiter=self.attrs_delimiter,
            vars_dict=self._vars_dict,
        )
        return instance

    def all(self) -> Iterable[str]:
        """Returns tuple with names of all registered items."""
        self._do_late_add()
        result = tuple(self._factories.keys())

        return result

    def __str__(self) -> str:
        """Returns a string of registered items."""
        return str(self.all())

    def __repr__(self) -> str:
        """Returns a string representation of registered items."""
        return str(self.all())

    # mapping methods
    def __len__(self) -> int:
        """Returns length of registered items."""
        self._do_late_add()
        return len(self._factories)

    def __getitem__(self, name: str) -> Optional[Factory]:
        """Returns a value from the registry by name."""
        return self.get(name)

    def __iter__(self) -> Iterator[str]:
        """Iterates over all registered items."""
        self._do_late_add()
        return iter(self._factories)

    def __contains__(self, name: str):
        """Check if a particular name was registered."""
        self._do_late_add()
        return name in self._factories

    def __setitem__(self, name: str, factory: Factory) -> None:
        """Add a new factory by giving name."""
        self.add(factory, name=name)

    def __delitem__(self, name: str) -> None:
        """
        Removes a factory by giving name.

        Raises:
            KeyError: if no factory is registered under ``name``
        """
        # Flush pending callbacks first, like every other mapping method,
        # so factories that are only queued can be deleted as well.
        self._do_late_add()
        del self._factories[name]