Shortcuts

Source code for catalyst.utils.tools.registry

from typing import (  # isort:skip
    Any, Callable, Dict, Iterator, List, Mapping, Optional, Tuple, Type, Union
)
import collections
import inspect
import warnings

Factory = Union[Type, Callable[..., Any]]
LateAddCallbak = Callable[["Registry"], None]
MetaFactory = Callable[[Factory, Tuple, Mapping], Any]


def _default_meta_factory(factory: Factory, args: Tuple, kwargs: Mapping):
    return factory(*args, **kwargs)


[docs]class RegistryException(Exception): """Exception class for all registry errors"""
[docs] def __init__(self, message): """Init""" super().__init__(message)
[docs]class Registry(collections.MutableMapping): """ Universal class allowing to add and access various factories by name """
[docs] def __init__( self, default_name_key: str, default_meta_factory: MetaFactory = _default_meta_factory ): """ Args: default_name_key (str): Default key containing factory name when creating from config default_meta_factory (MetaFactory): default object that calls factory. Optional. Default just calls factory. """ self.meta_factory = default_meta_factory self._name_key = default_name_key self._factories: Dict[str, Factory] = {} self._late_add_callbacks: List[LateAddCallbak] = []
@staticmethod def _get_factory_name(f, provided_name=None) -> str: if not provided_name: provided_name = getattr(f, "__name__", None) if not provided_name: raise RegistryException( f"Factory {f} has no __name__ and no " f"name was provided" ) if provided_name == "<lambda>": raise RegistryException( "Name for lambda factories must be provided" ) return provided_name def _do_late_add(self): if self._late_add_callbacks: for cb in self._late_add_callbacks: cb(self) self._late_add_callbacks = []
[docs] def add( self, factory: Factory = None, *factories: Factory, name: str = None, **named_factories: Factory ) -> Factory: """ Adds factory to registry with it's ``__name__`` attribute or provided name. Signature is flexible. Args: factory: Factory instance factories: More instances name: Provided name for first instance. Use only when pass single instance. named_factories: Factory and their names as kwargs Returns: (Factory): First factory passed """ if len(factories) > 0 and name is not None: raise RegistryException( "Multiple factories with single name are not allowed" ) if factory is not None: named_factories[self._get_factory_name(factory, name)] = factory if len(factories) > 0: new = {self._get_factory_name(f): f for f in factories} named_factories.update(new) if len(named_factories) == 0: warnings.warn("No factories were provided!") for name, f in named_factories.items(): # self._factories[name] != f is a workaround for # https://github.com/catalyst-team/catalyst/issues/135 if name in self._factories and self._factories[name] != f: raise RegistryException( f"Factory with name '{name}' is already present\n" f"Already registered: '{self._factories[name]}'\n" f"New: '{f}'" ) self._factories.update(named_factories) return factory
[docs] def late_add(self, cb: LateAddCallbak): """ Allows to prevent cycle imports by delaying some imports till next registry query Args: cb: Callback receives registry and must call it's methods to register factories """ self._late_add_callbacks.append(cb)
[docs] def add_from_module( self, module, prefix: Union[str, List[str]] = None ) -> None: """ Adds all factories present in module. If ``__all__`` attribute is present, takes ony what mentioned in it Args: module: module to scan prefix (Union[str, List[str]]): prefix string for all the module's factories. If prefix is a list, all values will be treated as aliases. """ factories = { k: v for k, v in module.__dict__.items() if inspect.isclass(v) or inspect.isfunction(v) } # Filter by __all__ if present names_to_add = getattr(module, "__all__", list(factories.keys())) if prefix is None: prefix = [""] elif isinstance(prefix, str): prefix = [prefix] elif isinstance(prefix, list): if any([(not isinstance(p, str)) for p in prefix]): raise TypeError(f"All prefix in list must be strings.") else: raise TypeError( f"Prefix must be a list or a string, got {type(prefix)}." ) to_add = { f"{p}{name}": factories[name] for p in prefix for name in names_to_add } self.add(**to_add)
[docs] def get(self, name: str) -> Optional[Factory]: """ Retrieves factory, without creating any objects with it or raises error Args: name: factory name Returns: Factory: factory by name """ self._do_late_add() if name is None: return None res = self._factories.get(name, None) if not res: raise RegistryException( f"No factory with name '{name}' was registered" ) return res
[docs] def get_if_str(self, obj: Union[str, Factory]): """ Returns object from the registry if ``obj`` type is string """ if type(obj) is str: return self.get(obj) return obj
[docs] def get_instance(self, name: str, *args, meta_factory=None, **kwargs): """ Creates instance by calling specified factory with instantiate_fn Args: name: factory name meta_factory: Function that calls factory the right way. If not provided, default is used args: args to pass to the factory **kwargs: kwargs to pass to the factory Returns: created instance """ meta_factory = meta_factory or self.meta_factory f = self.get(name) try: if hasattr(f, "get_from_params"): return f.get_from_params(*args, **kwargs) return meta_factory(f, args, kwargs) except Exception as e: raise RegistryException( f"Factory '{name}' call failed: args={args} kwargs={kwargs}" ) from e
[docs] def get_from_params( self, *, meta_factory=None, **kwargs ) -> Union[Any, Tuple[Any, Mapping[str, Any]]]: """ Creates instance based in configuration dict with ``instantiation_fn``. If ``config[name_key]`` is None, None is returned. Args: meta_factory: Function that calls factory the right way. If not provided, default is used. **kwargs: additional kwargs for factory Returns: result of calling ``instantiate_fn(factory, **config)`` """ name = kwargs.pop(self._name_key, None) if name: return self.get_instance(name, meta_factory=meta_factory, **kwargs)
[docs] def all(self) -> List[str]: """ Returns: list of names of registered items """ self._do_late_add() result = list(self._factories.keys()) return result
[docs] def len(self) -> int: """ Returns: length of registered items """ return len(self._factories)
def __str__(self) -> str: """Returns a string of registered items""" return self.all().__str__() def __repr__(self) -> str: """Returns a string representation of registered items""" return self.all().__str__() # mapping methods def __len__(self) -> int: """Returns length of registered items""" self._do_late_add() return self.len() def __getitem__(self, name: str) -> Optional[Factory]: """Returns a value from the registry by name""" return self.get(name) def __iter__(self) -> Iterator[str]: """Iterates over all registered items""" self._do_late_add() return self._factories.__iter__() def __contains__(self, name: str): """Check if a particular name was registered""" self._do_late_add() return self._factories.__contains__(name) def __setitem__(self, name: str, factory: Factory) -> None: """Add a new factory by giving name""" self.add(factory, name=name) def __delitem__(self, name: str) -> None: """Removes a factory by giving name""" self._factories.pop(name)
__all__ = ["Registry", "RegistryException"]