Source code for hydra_slayer.registry

from typing import Any, Callable, Dict, Iterator, List, Mapping, Optional, Tuple, Union
import collections
import collections.abc
import copy
import inspect
import pydoc
import warnings

from hydra_slayer.exceptions import RegistryException
from hydra_slayer.factory import default_meta_factory, Factory

# Public API of this module.
__all__ = ["Registry"]

# Callback queued with ``Registry.late_add``; it is invoked with the registry itself.
LateAddCallback = Callable[["Registry"], None]
# Callable invoked as ``meta_factory(factory, args, kwargs)`` to build an instance.
MetaFactory = Callable[[Factory, Tuple, Mapping], Any]


class Registry(collections.abc.MutableMapping):
    """
    Universal class allowing to add and access various factories by name.

    The registry behaves like a mutable mapping from names to factories.
    Callbacks queued with :meth:`late_add` are executed lazily, right before
    the registry is queried.

    Args:
        meta_factory: object that calls a factory, invoked as
            ``meta_factory(factory, args, kwargs)``. Optional;
            ``default_meta_factory`` is used when omitted.
        name_key: key of a config mapping that stores the factory name.
    """

    def __init__(self, meta_factory: Optional[MetaFactory] = None, name_key: str = "_target_"):
        """Init."""
        self.meta_factory = meta_factory if meta_factory is not None else default_meta_factory
        self._factories: Dict[str, Factory] = {}
        self._late_add_callbacks: List[LateAddCallback] = []
        self.name_key = name_key

    @staticmethod
    def _get_factory_name(f, provided_name=None) -> str:
        """
        Returns the name to register ``f`` under.

        Args:
            f: factory to be named
            provided_name: explicit name; when empty, ``f.__name__`` is used

        Returns:
            str: name for the factory

        Raises:
            RegistryException: if no name was provided and ``f`` has no
                ``__name__`` or is a lambda
        """
        if not provided_name:
            provided_name = getattr(f, "__name__", None)
            if not provided_name:
                raise RegistryException(f"Factory {f} has no __name__ and no name was provided")
            if provided_name == "<lambda>":
                raise RegistryException("Name for lambda factories must be provided")
        return provided_name

    def _do_late_add(self):
        """Runs all queued late-add callbacks and clears the queue afterwards."""
        if self._late_add_callbacks:
            for cb in self._late_add_callbacks:
                cb(self)
            self._late_add_callbacks = []

    def add(
        self,
        factory: Optional[Factory] = None,
        *factories: Factory,
        name: Optional[str] = None,
        **named_factories: Factory,
    ) -> Optional[Factory]:
        """
        Adds factory to registry with its ``__name__`` attribute or provided name.
        Signature is flexible.

        Args:
            factory: Factory instance
            factories: More instances
            name: Provided name for first instance. Use only when pass
                single instance.
            named_factories: Factory and their names as kwargs

        Returns:
            Factory: First factory passed (``None`` if it was not passed)

        Raises:
            RegistryException: if factory with provided name is already present
        """
        if factories and name is not None:
            raise RegistryException("Multiple factories with single name are not allowed")

        if factory is not None:
            named_factories[self._get_factory_name(factory, name)] = factory

        if factories:
            named_factories.update({self._get_factory_name(f): f for f in factories})

        if not named_factories:
            warnings.warn("No factories were provided!")

        # Validate everything first so that a conflict leaves the registry untouched.
        for factory_name, f in named_factories.items():
            # self._factories[name] != f is a workaround for
            # https://github.com/catalyst-team/catalyst/issues/135
            if factory_name in self._factories and self._factories[factory_name] != f:
                raise RegistryException(
                    f"Factory with name '{factory_name}' is already present\n"
                    f"Already registered: '{self._factories[factory_name]}'\n"
                    f"New: '{f}'"
                )

        self._factories.update(named_factories)

        # Returning the first factory makes ``add`` usable as a decorator.
        return factory

    def late_add(self, cb: LateAddCallback):
        """
        Allows to prevent cycle imports by delaying some imports till next
        registry query.

        Args:
            cb: Callback receives registry and must call its methods to
                register factories
        """
        self._late_add_callbacks.append(cb)

    def add_from_module(
        self, module, prefix: Union[str, List[str]] = None, ignore_all: bool = False
    ) -> None:
        """
        Adds all factories (classes and functions) present in module.
        If ``__all__`` attribute is present, takes only what is mentioned in it;
        names listed there that are not classes or functions are skipped.

        Args:
            module: module to scan
            prefix (Union[str, List[str]]): prefix string for all the module's
                factories. If prefix is a list, all values will be treated
                as aliases.
            ignore_all: flag, ignore `__all__` or not

        Raises:
            TypeError: if prefix is not a list or a string
        """
        factories = {
            k: v
            for k, v in module.__dict__.items()
            if inspect.isclass(v) or inspect.isfunction(v)
        }

        exported = None if ignore_all else getattr(module, "__all__", None)
        if exported is None:
            names_to_add = list(factories)
        else:
            # ``__all__`` may export constants or submodules as well;
            # only classes and functions are registrable factories.
            names_to_add = [name for name in exported if name in factories]

        if prefix is None:
            prefix = [""]
        elif isinstance(prefix, str):
            prefix = [prefix]
        elif isinstance(prefix, list):
            if any((not isinstance(p, str)) for p in prefix):
                raise TypeError("All prefix in list must be strings.")
        else:
            raise TypeError(f"Prefix must be a list or a string, got {type(prefix)}.")

        to_add = {f"{p}{name}": factories[name] for p in prefix for name in names_to_add}
        self.add(**to_add)

    def get(self, name: str) -> Optional[Factory]:
        """
        Retrieves factory, without creating any objects with it
        or raises error.

        Registered names take precedence; otherwise ``name`` is resolved
        with ``pydoc.locate``.

        Args:
            name: factory name

        Returns:
            Factory: factory by name (``None`` if ``name`` is ``None``)

        Raises:
            RegistryException: if no factory with provided name was registered
        """
        self._do_late_add()

        if name is None:
            return None

        res = self._factories.get(name)
        if res is None:
            # Only try to locate the name on a registry miss,
            # so that registered names never trigger an import attempt.
            res = pydoc.locate(name)

        if res is None:
            raise RegistryException(f"No factory with name '{name}' was registered")

        return res

    def get_if_str(self, obj: Union[str, Factory]):
        """
        Returns object from the registry if ``obj`` type is string,
        otherwise returns ``obj`` unchanged.
        """
        if isinstance(obj, str):
            return self.get(obj)
        return obj

    def get_instance(self, name: str, *args, meta_factory: Optional[MetaFactory] = None, **kwargs):
        """
        Creates instance by calling specified factory.

        If the factory has a ``get_from_params`` attribute, it is called with
        ``*args, **kwargs``; otherwise ``meta_factory(factory, args, kwargs)``
        is called.

        Args:
            name: factory name
            meta_factory: Function that calls factory the right way.
                If not provided, the registry's ``meta_factory`` is used
            args: args to pass to the factory
            **kwargs: kwargs to pass to the factory

        Returns:
            created instance

        Raises:
            RegistryException: if could not create object instance
        """
        meta_factory = meta_factory or self.meta_factory
        f = self.get(name)

        try:
            if hasattr(f, "get_from_params"):
                return f.get_from_params(*args, **kwargs)
            return meta_factory(f, args, kwargs)
        except Exception as e:
            # Wrap any failure so that callers get the factory name and arguments.
            raise RegistryException(
                f"Factory '{name}' call failed: args={args} kwargs={kwargs}"
            ) from e

    def _recursive_get_from_params(
        self, params: Union[Dict[str, Any], Any], shared_params: Optional[Dict[str, Any]] = None
    ) -> Any:
        """
        Recursively replaces every dict that contains ``self.name_key``
        with the instance created by :meth:`get_instance`.

        Args:
            params: config to process; anything but ``dict``/``list``
                is returned as is
            shared_params: extra kwargs passed to every created instance;
                values from the config override them

        Returns:
            created instance if ``params`` is a dict with ``self.name_key``,
            otherwise a processed copy of ``params``
        """
        mapping_types = (dict,)
        # check for `list` but not all iterables to skip processing of generators
        iterable_types = (list,)

        if not isinstance(params, (*mapping_types, *iterable_types)):
            return params

        # make a copy of params since we don't want to modify them directly
        params = copy.deepcopy(params)

        params_iter = params.items() if isinstance(params, mapping_types) else enumerate(params)
        for key, param in params_iter:
            # note: it is assumed that `params` is mutable
            if isinstance(param, mapping_types):
                params[key] = self._recursive_get_from_params(
                    params=param, shared_params=shared_params
                )
            elif isinstance(param, iterable_types):
                params[key] = [
                    self._recursive_get_from_params(params=value, shared_params=shared_params)
                    for value in param
                ]

        if isinstance(params, mapping_types) and self.name_key in params:
            name = params.pop(self.name_key)
            shared_params = shared_params or {}
            # use additional dict to handle 'multiple values for keyword argument'
            instance = self.get_instance(name=name, **{**shared_params, **params})
            return instance
        return params

    def get_from_params(
        self,
        *,
        shared_params: Optional[Dict[str, Any]] = None,
        **kwargs,
    ) -> Union[Any, Tuple[Any, Mapping[str, Any]]]:
        """
        Creates instance based on configuration passed as keyword arguments.

        Nested dicts (also inside lists) that contain ``name_key`` are
        instantiated first and passed to their parent as arguments.

        Args:
            shared_params: params to pass on all levels in case of
                recursive creation
            **kwargs: configuration; the factory name is expected
                under ``name_key``

        Returns:
            created instance if ``name_key`` is present in ``kwargs``,
            otherwise the processed copy of ``kwargs``
        """
        instance = self._recursive_get_from_params(params=kwargs, shared_params=shared_params)
        return instance

    def all(self) -> List[str]:
        """
        Returns:
            list of names of registered items
        """
        self._do_late_add()
        result = list(self._factories.keys())
        return result

    def len(self) -> int:
        """
        Returns:
            length of registered items
        """
        # Run pending callbacks first to stay consistent with ``all`` and ``__len__``.
        self._do_late_add()
        return len(self._factories)

    def __str__(self) -> str:
        """Returns a string of registered items."""
        return str(self.all())

    def __repr__(self) -> str:
        """Returns a string representation of registered items."""
        return str(self.all())

    # mapping methods
    def __len__(self) -> int:
        """Returns length of registered items."""
        return self.len()

    def __getitem__(self, name: str) -> Optional[Factory]:
        """Returns a value from the registry by name."""
        return self.get(name)

    def __iter__(self) -> Iterator[str]:
        """Iterates over all registered items."""
        self._do_late_add()
        return iter(self._factories)

    def __contains__(self, name: str) -> bool:
        """Check if a particular name was registered."""
        self._do_late_add()
        return name in self._factories

    def __setitem__(self, name: str, factory: Factory) -> None:
        """Add a new factory by giving name."""
        self.add(factory, name=name)

    def __delitem__(self, name: str) -> None:
        """Removes a factory by giving name."""
        self._factories.pop(name)