# Source code for hydra_slayer.functional

from typing import Any, Callable, Dict, Iterable, Optional, Tuple, TypeVar, Union
import copy
import inspect
import pydoc
import warnings

from hydra_slayer.factory import Factory, metafactory_factory

# public API of this module
__all__ = ["get_factory", "get_instance", "get_from_params"]

# type of a non-string object that ``get_factory`` returns unchanged
T = TypeVar("T")

# keyword holding the name of the factory to call
DEFAULT_FACTORY_KEY = "_target_"
# keyword naming a variable: the built object is registered under that name,
# or a previously registered object is retrieved by it
DEFAULT_VAR_KEY = "_var_"
# separates the variable name from the attribute to access on it ("<var>.<attr>")
DEFAULT_ATTRS_DELIMITER = "."


def _extract_factory_name_arg(
    factory_key: str = DEFAULT_FACTORY_KEY,
    args: Optional[Iterable] = None,
    kwargs: Optional[Dict] = None,
) -> Tuple[Optional[str], Iterable, Dict]:
    """Split the factory name (or factory object) off the call arguments.

    The name is looked up under ``factory_key`` in ``kwargs`` first; if it is
    missing (or ``None``), the first positional argument is used instead.

    Args:
        factory_key: keyword under which the factory name may be passed
        args: positional call arguments
        kwargs: keyword call arguments; the caller's dict is not modified

    Returns:
        tuple ``(factory_name, remaining_args, remaining_kwargs)``;
        ``factory_name`` is ``None`` if it was found neither in ``kwargs``
        nor in ``args``
    """
    args = args or ()
    # work on a copy so that popping the factory key does not mutate the caller's dict
    kwargs = copy.copy(kwargs) if kwargs else {}

    factory_name = kwargs.pop(factory_key, None)
    if factory_name is None and args:
        # fall back to the first positional argument
        factory_name, *args = args

    return factory_name, args, kwargs


def _extract_positional_keyword_vars(func: Callable, kwargs: Dict) -> Tuple[Iterable, Dict]:
    # make a copy of kwargs since we don't want to modify them directly
    kwargs = copy.copy(kwargs)

    try:
        signature = inspect.signature(func)
        type2param = {p.kind: name for name, p in signature.parameters.items()}
    except ValueError:
        type2param = {}
        warnings.warn(
            f"No signature found for `{func}`, *args and **kwargs arguments cannot be extracted"
        )

    var_kwarg = kwargs.pop(type2param.get(inspect.Parameter.VAR_KEYWORD), {})
    kwargs.update(var_kwarg)

    args = kwargs.pop(type2param.get(inspect.Parameter.VAR_POSITIONAL), ())

    return args, kwargs


def get_factory(name_or_object: Union[str, T]) -> Union[Factory, T]:
    """Retrieves factory, without creating any objects with it.

    Args:
        name_or_object: factory name (dotted path resolved with
            ``pydoc.locate``) or any valid python object

    Returns:
        the located object, or ``name_or_object`` unchanged if it is not a string

    Raises:
        LookupError: if no factory with provided name was registered

    Examples:
        >>> to_int = get_factory("int")
        >>> to_int("42")
        42
    """
    if not isinstance(name_or_object, str):
        return name_or_object

    factory = pydoc.locate(name_or_object)
    # ``pydoc.locate`` signals "not found" with ``None``; a located object that
    # happens to be falsy (e.g. ``0`` or an empty container) is still valid
    if factory is None:
        raise LookupError(f"No factory with name '{name_or_object}' was registered")
    return factory
def _get_instance(
    factory_key: str = DEFAULT_FACTORY_KEY,
    get_factory_func: Callable = None,
    args: Optional[Iterable] = None,
    kwargs: Optional[Dict] = None,
) -> Any:
    """Resolve a factory and call it through ``metafactory_factory``.

    Note:
        The factory must be given either under the ``factory_key`` keyword
        (``'_target_'`` by default) or as the first positional argument.

    Args:
        factory_key: keyword under which the factory name is looked up
        get_factory_func: callable mapping a name to a factory;
            falls back to :py:func:`.functional.get_factory` when not given
        args: positional arguments to be passed into the factory
        kwargs: keyword arguments to be passed into the factory

    Returns:
        created instance

    Raises:
        TypeError: if factory name argument is missing
        RuntimeError: if the factory call raised an exception
    """
    resolve_factory = get_factory_func or get_factory

    # the factory name comes from ``kwargs[factory_key]`` or, failing that,
    # from the first positional argument
    name, args, kwargs = _extract_factory_name_arg(
        factory_key=factory_key, args=args or (), kwargs=kwargs or {}
    )
    if name is None:
        raise TypeError(f"get_instance() missing at least 1 required argument: '{factory_key}'")

    factory = resolve_factory(name)
    # values of ``*args`` / ``**kwargs`` given by parameter name are unpacked here
    var_args, kwargs = _extract_positional_keyword_vars(factory, kwargs=kwargs)
    args = (*args, *var_args)

    try:
        return metafactory_factory(factory=factory, args=args, kwargs=kwargs)
    except Exception as err:
        raise RuntimeError(f"Factory '{name}' call failed: args={args} kwargs={kwargs}") from err
def get_instance(*args, **kwargs) -> Any:
    """Build an object by calling a factory resolved from the arguments.

    Note:
        The factory (its dotted name or the object itself) should be given
        as the first positional argument or under the ``'_target_'`` keyword.

    Args:
        *args: positional arguments; the first one names the factory unless
            ``'_target_'`` is given
        **kwargs: named parameters to pass to the factory

    Returns:
        created instance

    Examples:
        >>> get_instance(int, "42", base=10)
        42
    """
    return _get_instance(args=args, kwargs=kwargs)
def _get_from_params( factory_key: str, get_factory_func: Callable, params: Dict[str, Any], shared_params: Dict[str, Any], var_key: str, attrs_delimiter: str, vars_dict: Dict[str, Any], ) -> Tuple[Any, Dict[str, Any]]: # use additional dict to handle 'multiple values for keyword argument' kwargs = {**shared_params, **params} params.pop(var_key, None) alias = kwargs.pop(var_key, "") alias, attribute_name = ( alias.split(attrs_delimiter) if attrs_delimiter in alias else (alias, None) ) if alias and alias in vars_dict: if factory_key in kwargs: raise ValueError( f"`{factory_key}` and `{var_key}` (in get mode) keywords are exclusive" ) obj = vars_dict[alias] if attribute_name is not None: obj_or_callable = getattr(obj, attribute_name) if callable(obj_or_callable): args, kwargs = _extract_positional_keyword_vars(obj_or_callable, kwargs=kwargs) obj = obj_or_callable(*args, **kwargs) else: obj = obj_or_callable elif factory_key in kwargs: obj = _get_instance( factory_key=factory_key, get_factory_func=get_factory_func, args=(), kwargs=kwargs, ) else: obj = params if alias and alias not in vars_dict: vars_dict[alias] = obj return obj, vars_dict def _recursive_get_from_params( factory_key: str, get_factory_func: Callable, params: Union[Dict[str, Any], Any], shared_params: Dict[str, Any], var_key: str, attrs_delimiter: str, vars_dict: Dict[str, Any], ) -> Tuple[Any, Dict[str, Any]]: if not isinstance(params, (dict, list)): return params, vars_dict # make a copy of params since we don't want to modify them directly params = copy.copy(params) common_params = { "factory_key": factory_key, "get_factory_func": get_factory_func, "shared_params": shared_params, "var_key": var_key, "attrs_delimiter": attrs_delimiter, } view = params.items() if isinstance(params, dict) else enumerate(params) for key, param in view: params[key], vars_dict = _recursive_get_from_params( params=param, vars_dict=vars_dict, **common_params ) if isinstance(params, dict): instance, vars_dict = 
_get_from_params(params=params, vars_dict=vars_dict, **common_params) return instance, vars_dict return params, vars_dict
def get_from_params(*, shared_params: Optional[Dict[str, Any]] = None, **kwargs) -> Any:
    """Build an object from a configuration given as keyword arguments.

    Note:
        The factory to call is named by the ``'_target_'`` keyword. Nested dicts
        and lists are processed recursively (innermost first), so nested dicts
        with their own ``'_target_'`` are built before being passed on.
        A dict carrying ``'_var_'`` is registered under that name, or, if the
        name is already registered, replaced by the registered object.

    Args:
        shared_params: params merged into every dict on all levels of the
            recursive creation (explicitly given params take precedence)
        **kwargs: the configuration: factory name and its named parameters

    Returns:
        the built object

    Examples:
        >>> get_from_params(_target_="torch.nn.Linear", in_features=20, out_features=30)
        Linear(in_features=20, out_features=30, bias=True)
    """
    settings = dict(
        factory_key=DEFAULT_FACTORY_KEY,
        get_factory_func=get_factory,
        var_key=DEFAULT_VAR_KEY,
        attrs_delimiter=DEFAULT_ATTRS_DELIMITER,
    )
    result, _registered = _recursive_get_from_params(
        params=kwargs,
        shared_params=shared_params or {},
        vars_dict={},
        **settings,
    )
    return result