Source code for hydra_slayer.functional

from typing import Any, Callable, Dict, Iterable, Optional, Tuple, TypeVar, Union
import copy
import inspect
import pydoc

from hydra_slayer.factory import default_meta_factory, Factory, MetaFactory

__all__ = ["get_factory", "get_instance", "get_from_params"]

T = TypeVar("T")

DEFAULT_FACTORY_KEY = "_target_"


def _extract_factory_name_arg(
    factory_key: str = DEFAULT_FACTORY_KEY, args: Iterable = None, kwargs: Dict = None
) -> Tuple[Optional[str], Iterable, Dict]:
    """Pull the factory name out of the call arguments.

    The name is looked up under ``factory_key`` in ``kwargs`` first; when it is
    absent (or ``None``) and positional arguments were given, the first
    positional argument is used instead.

    Note:
        A non-empty ``kwargs`` dict is modified in place: ``factory_key`` is
        removed from it.

    Args:
        factory_key: keyword under which the factory name may be stored
        args: positional arguments, ``None`` is treated as no arguments
        kwargs: keyword arguments, ``None`` is treated as no arguments

    Returns:
        tuple of factory name (``None`` if it was not found), the remaining
        positional arguments and the remaining keyword arguments
    """
    positional = args if args else ()
    keyword = kwargs if kwargs else {}

    name = keyword.pop(factory_key, None)
    if name is not None or not positional:
        # either found by keyword, or there is nothing positional to fall back on
        return name, positional, keyword

    # fall back to the leading positional argument; the rest stays positional
    name, *remaining = positional
    return name, remaining, keyword


def _extract_positional_keyword_vars(func: Callable, kwargs: Dict) -> Tuple[Iterable, Dict]:
    """Split ``kwargs`` into positional and keyword arguments for ``func``.

    If ``func`` declares a ``**var_keyword`` parameter and ``kwargs`` holds an
    entry under that parameter's name, the entry is unpacked into the keyword
    arguments. If ``func`` declares a ``*var_positional`` parameter and
    ``kwargs`` holds an entry under that name, the entry becomes the positional
    arguments.

    Args:
        func: callable whose signature is inspected
        kwargs: named parameters; the passed mapping itself is not modified

    Returns:
        tuple of positional arguments (``()`` when there are none) and
        a shallow copy of ``kwargs`` with the variadic entries resolved
    """
    # work on a shallow copy since we don't want to modify the caller's mapping
    kwargs = copy.copy(kwargs)

    try:
        parameters = inspect.signature(func).parameters.values()
    except (TypeError, ValueError):
        # no introspectable signature (e.g. some callables implemented in C):
        # there is nothing to unpack, so pass everything through by keyword
        return (), kwargs

    var_positional = var_keyword = None
    for param in parameters:
        if param.kind == inspect.Parameter.VAR_POSITIONAL:
            var_positional = param.name
        elif param.kind == inspect.Parameter.VAR_KEYWORD:
            var_keyword = param.name

    # only touch ``kwargs`` for parameters that really exist, otherwise an
    # unrelated entry (e.g. stored under the ``None`` key) would be consumed
    if var_keyword is not None:
        kwargs.update(kwargs.pop(var_keyword, {}))

    args = kwargs.pop(var_positional, ()) if var_positional is not None else ()

    return args, kwargs


def get_factory(name_or_object: Union[str, T]) -> Union[Factory, T]:
    """Retrieves factory, without creating any objects with it.

    Args:
        name_or_object: factory name (dotted path resolved with
            :py:func:`pydoc.locate`) or any valid python object,
            which is returned as is

    Returns:
        factory

    Raises:
        LookupError: if no factory with provided name was registered

    Examples:
        >>> to_int = get_factory("int")
        >>> to_int("42")
        42
    """
    if not isinstance(name_or_object, str):
        # already an object, nothing to look up
        return name_or_object

    factory = pydoc.locate(name_or_object)
    # ``pydoc.locate`` reports a failed lookup with ``None``; compare by identity
    # so that a located object which happens to be falsy is not rejected
    if factory is None:
        raise LookupError(f"No factory with name '{name_or_object}' was registered")
    return factory
def _meta_factory_call(
    factory: Factory, meta_factory: MetaFactory, args: Iterable, kwargs: Dict
) -> Any:
    """Calls ``factory`` with the given arguments.

    A factory that exposes a ``get_from_params`` attribute is built through it;
    any other factory is handed to ``meta_factory`` together with the arguments.

    Args:
        factory: object to create an instance with
        meta_factory: function that calls factory in the right way
        args: positional arguments for the factory
        kwargs: named parameters for the factory

    Returns:
        whatever the chosen call returns
    """
    if hasattr(factory, "get_from_params"):
        return factory.get_from_params(*args, **kwargs)
    return meta_factory(factory, args, kwargs)


def _get_instance(
    factory_key: str = DEFAULT_FACTORY_KEY,
    get_factory_func: Callable = None,
    meta_factory: Optional[MetaFactory] = None,
    args: Optional[Iterable] = None,
    kwargs: Optional[Dict] = None,
) -> Any:
    r"""Creates instance by calling specified factory with ``meta_factory``.

    Note:
        The name of the factory to use must be provided as the first argument
        or directly by ``'_target_'`` keyword.

    Args:
        factory_key: key to extract factory name from
        get_factory_func: function that returns factory by its name.
            Default: :py:func:`.functional.get_factory`
        meta_factory: function that calls factory in the right way.
            Default: :py:func:`.factory.default_meta_factory`
        args: \*args to pass to the factory
        kwargs: \*\*kwargs to pass to the factory

    Returns:
        created instance

    Raises:
        TypeError: if factory name argument is missing
        RuntimeError: if could not create object instance
    """
    get_factory_func = get_factory_func or get_factory
    meta_factory = meta_factory or default_meta_factory
    args, kwargs = args or (), kwargs or {}

    # assume that name of the factory can be provided as first argument
    # or directly by keyword; rebind both ``args`` and ``kwargs`` so the factory
    # receives the arguments with the name stripped, whatever the helper returns
    name, args, kwargs = _extract_factory_name_arg(
        factory_key=factory_key, args=args, kwargs=kwargs
    )
    if name is None:
        raise TypeError(f"get_instance() missing at least 1 required argument: '{factory_key}'")

    factory = get_factory_func(name)

    try:
        return _meta_factory_call(
            factory=factory, meta_factory=meta_factory, args=args, kwargs=kwargs
        )
    except Exception as e:
        # surface which factory and arguments failed, keeping the original cause
        raise RuntimeError(f"Factory '{name}' call failed: args={args} kwargs={kwargs}") from e
def get_instance(*args, meta_factory: Optional[MetaFactory] = None, **kwargs) -> Any:
    """Builds an object by calling the requested factory.

    Note:
        The factory to use is taken from the first positional argument
        or from the ``'_target_'`` keyword.

    Args:
        *args: positional arguments forwarded to the factory
        meta_factory: function that calls factory in the right way.
            Default: :py:func:`.factory.default_meta_factory`
        **kwargs: named parameters forwarded to the factory

    Returns:
        created instance

    Examples:
        >>> get_instance(int, "42", base=10)
        42
    """
    # all the work is done by the private helper, with default key and lookup
    return _get_instance(meta_factory=meta_factory, args=args, kwargs=kwargs)
def _recursive_get_from_params(
    factory_key: str,
    get_factory_func: Callable,
    params: Union[Dict[str, Any], Any],
    shared_params: Dict[str, Any],
) -> Any:
    """Resolves a (possibly nested) configuration into objects, depth first.

    Dicts and lists are walked recursively and every nested value is resolved
    before its container. A dict is turned into an instance when
    ``factory_key`` is present either in the dict itself or in
    ``shared_params``; every other container is returned as a resolved copy.

    Args:
        factory_key: key that holds the factory name
        get_factory_func: function that returns factory by its name
        params: configuration node; anything that is not a dict or a list
            is returned unchanged
        shared_params: params merged into every factory call, entries of the
            node itself take precedence

    Returns:
        created instance, resolved copy of the container, or ``params`` itself
    """
    # leaves are returned untouched
    if not isinstance(params, (dict, list)):
        return params

    is_mapping = isinstance(params, dict)

    # shallow copy: nested values are replaced on the copy, never on the input
    resolved = copy.copy(params)
    slots = list(resolved) if is_mapping else range(len(resolved))
    for slot in slots:
        resolved[slot] = _recursive_get_from_params(
            factory_key=factory_key,
            get_factory_func=get_factory_func,
            params=resolved[slot],
            shared_params=shared_params,
        )

    describes_factory = is_mapping and (
        factory_key in resolved or factory_key in shared_params
    )
    if not describes_factory:
        return resolved

    # merge into a single dict up front to avoid
    # 'multiple values for keyword argument' errors
    merged = {**shared_params, **resolved}
    factory_name, _, merged = _extract_factory_name_arg(
        factory_key=factory_key, args=(), kwargs=merged
    )
    factory = get_factory_func(factory_name)

    call_args, call_kwargs = _extract_positional_keyword_vars(factory, kwargs=merged)
    meta_factory = call_kwargs.pop("meta_factory", default_meta_factory)

    return _meta_factory_call(
        factory=factory, meta_factory=meta_factory, args=call_args, kwargs=call_kwargs
    )
def get_from_params(*, shared_params: Optional[Dict[str, Any]] = None, **kwargs) -> Any:
    """Creates instance based on a configuration dict.

    Note:
        The name of the factory to use should be provided
        by ``'_target_'`` keyword.

    Args:
        shared_params: params to pass on all levels in case of
            recursive creation
        **kwargs: named parameters for factory

    Returns:
        result of resolving ``kwargs`` recursively: the created instance,
        or the resolved dict when no factory name is given

    Examples:
        >>> get_from_params(_target_="torch.nn.Linear", in_features=20, out_features=30)
        Linear(in_features=20, out_features=30, bias=True)
    """
    shared = shared_params or {}
    return _recursive_get_from_params(
        factory_key=DEFAULT_FACTORY_KEY,
        get_factory_func=get_factory,
        params=kwargs,
        shared_params=shared,
    )