# Source code for catalyst.utils.misc

from typing import Any, Callable, Dict, Iterable, List, Optional, Union
import argparse
from base64 import urlsafe_b64encode
import collections
import collections.abc
import copy
from datetime import datetime, timezone
from hashlib import sha256
import inspect
from itertools import tee
from pathlib import Path
import random
import shutil

import numpy as np
from packaging.version import parse, Version

def boolean_flag(
    parser: argparse.ArgumentParser,
    name: str,
    default: Optional[bool] = False,
    help: str = None,  # noqa: WPS125
    shorthand: str = None,
) -> None:
    """Register a paired ``--<name>`` / ``--no-<name>`` switch on ``parser``.

    Both forms write into the same destination attribute (``name`` with
    dashes turned into underscores): the positive form stores ``True`` and
    the negative form stores ``False``.

    Examples:
        >>> parser = argparse.ArgumentParser()
        >>> boolean_flag(
        >>>     parser, "flag", default=False, help="some flag", shorthand="f"
        >>> )

    Args:
        parser: parser to add the flag to (modified in place)
        name: argument name; ``--<name>`` enables the flag and
            ``--no-<name>`` disables it
        default: default value of the flag
        help: help string for the positive form of the flag
        shorthand: optional single-dash alias (``-<shorthand>``) for the
            positive form
    """
    destination = name.replace("-", "_")
    positive_forms = [f"--{name}"]
    if shorthand is not None:
        positive_forms.append(f"-{shorthand}")
    parser.add_argument(
        *positive_forms,
        action="store_true",
        default=default,
        dest=destination,
        help=help,
    )
    parser.add_argument(f"--no-{name}", action="store_false", dest=destination)
[docs]def set_global_seed(seed: int) -> None: """Sets random seed into Numpy and Random, PyTorch and TensorFlow. Args: seed: random seed """ random.seed(seed) np.random.seed(seed) try: import torch except ImportError: pass else: torch.manual_seed(seed) if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed) try: import tensorflow as tf except ImportError: pass else: if parse(tf.__version__) >= Version("2.0.0"): tf.random.set_seed(seed) elif parse(tf.__version__) <= Version("1.13.2"): tf.set_random_seed(seed) else: tf.compat.v1.set_random_seed(seed)
def maybe_recursive_call(
    object_or_dict,
    method: Union[str, Callable],
    recursive_args=None,
    recursive_kwargs=None,
    **kwargs,
):
    """Calls the ``method`` for ``object_or_dict``, recursing into dicts.

    If ``object_or_dict`` is a ``dict`` (or subclass), the call is applied to
    every value, and a new mapping of the same type holding the per-key
    results is returned. Otherwise ``method`` is invoked once:

    * if ``method`` is a string, ``object_or_dict.<method>(...)`` is called;
    * otherwise ``method(object_or_dict, ...)`` is called.

    Args:
        object_or_dict: some object or a dictionary of objects
        method: method name or callable to apply
        recursive_args: positional argument(s) for ``method``. A single value
            that is not a list/tuple is wrapped into a one-element list
            (falsy values such as ``0`` are forwarded as well). When
            ``object_or_dict`` is a dict, this must be ``None`` or a dict
            indexed by the same keys.
        recursive_kwargs: keyword arguments for ``method``. When
            ``object_or_dict`` is a dict, this must be ``None`` or a dict of
            per-key kwargs indexed by the same keys.
        **kwargs: extra keyword arguments passed unchanged to every call

    Returns:
        result of ``method`` call, or a mapping of per-key results
    """
    if isinstance(object_or_dict, dict):
        result = type(object_or_dict)()
        for key, value in object_or_dict.items():
            r_args = None if recursive_args is None else recursive_args[key]
            r_kwargs = None if recursive_kwargs is None else recursive_kwargs[key]
            result[key] = maybe_recursive_call(
                value,
                method,
                recursive_args=r_args,
                recursive_kwargs=r_kwargs,
                **kwargs,
            )
        return result

    # Compare against ``None`` explicitly instead of ``recursive_args or []``:
    # a falsy single argument (``0``, ``False``, ``""``) must still be passed,
    # and truth-testing an array/tensor argument would raise.
    if recursive_args is None:
        r_args = []
    elif isinstance(recursive_args, (list, tuple)):
        r_args = recursive_args
    else:
        r_args = [recursive_args]
    r_kwargs = recursive_kwargs or {}

    if isinstance(method, str):
        return getattr(object_or_dict, method)(*r_args, **r_kwargs, **kwargs)
    return method(object_or_dict, *r_args, **r_kwargs, **kwargs)
def is_exception(ex: Any) -> bool:
    """Check whether ``ex`` is an exception *instance*.

    Any instance of ``BaseException`` counts. ``None``, exception classes and
    all other objects do not.
    """
    return isinstance(ex, BaseException)
def copy_directory(input_dir: Path, output_dir: Path) -> None:
    """Recursively copies the contents of ``input_dir`` into ``output_dir``.

    ``output_dir`` (and any missing parents) is created if needed.
    Sub-directories are recreated under the same name. Files are copied with
    ``shutil.copy2``, which also tries to preserve file metadata.

    Args:
        input_dir: input directory (``Path`` or path string)
        output_dir: output directory (``Path`` or path string)
    """
    input_dir = Path(input_dir)
    output_dir = Path(output_dir)
    output_dir.mkdir(exist_ok=True, parents=True)
    for path in input_dir.iterdir():
        if path.is_dir():
            # Mirror the sub-directory under its own name in the destination.
            copy_directory(path, output_dir / path.name)
        else:
            shutil.copy2(path, output_dir)
def get_utcnow_time(format: str = None) -> str:
    """Return string with current UTC time in chosen format.

    Args:
        format: ``strftime`` format string; if ``None``, ``"%y%m%d.%H%M%S"``
            is used.

    Returns:
        str: formatted UTC time string
    """
    if format is None:
        format = "%y%m%d.%H%M%S"
    # ``datetime.utcnow()`` is deprecated since Python 3.12. Build the same
    # naive UTC timestamp from an aware one; dropping tzinfo keeps ``%z``/``%Z``
    # rendering exactly as before (empty strings).
    now = datetime.now(timezone.utc).replace(tzinfo=None)
    return now.strftime(format)
def format_metric(name: str, value: float) -> str:
    """Format metric as ``"<name>=<value>"``.

    The value is printed with 4 decimal places. Scientific notation (3
    decimals) is used when 4 decimal places are not enough, i.e. when the
    magnitude of the value is below ``1e-4`` (this includes zero). The
    comparison uses the absolute value, so negative metrics such as ``-0.5``
    are formatted like positive ones.

    Args:
        name: metric name
        value: value of metric

    Returns:
        str: formatted metric
    """
    if abs(value) < 1e-4:
        return f"{name}={value:1.3e}"
    return f"{name}={value:.4f}"
def get_fn_default_params(fn: Callable[..., Any], exclude: List[str] = None):
    """Return default values of the positional-or-keyword parameters of ``fn``.

    Only parameters reported in ``inspect.getfullargspec(fn).args`` are
    considered; keyword-only defaults are not included.

    Args:
        fn (Callable[..., Any]): target Callable
        exclude: parameter names to leave out of the result

    Returns:
        dict: maps parameter name to its default value (empty if ``fn`` has
        no such defaults)
    """
    argspec = inspect.getfullargspec(fn)
    # ``defaults`` is ``None`` (not an empty tuple) when no parameter has a
    # default; ``len(None)`` would raise TypeError.
    defaults = argspec.defaults or ()
    names_with_defaults = argspec.args[len(argspec.args) - len(defaults) :]
    default_params = zip(names_with_defaults, defaults)
    if exclude is not None:
        default_params = ((k, v) for k, v in default_params if k not in exclude)
    return dict(default_params)
def get_fn_argsnames(fn: Callable[..., Any], exclude: List[str] = None):
    """List the named parameters of ``fn``.

    Names from ``inspect.getfullargspec(fn).args`` come first, followed by the
    keyword-only ones. ``*args`` and ``**kwargs`` names are not included.

    Args:
        fn (Callable[..., Any]): target Callable
        exclude: parameter names to leave out of the result

    Returns:
        list: parameter names of ``fn``
    """
    spec = inspect.getfullargspec(fn)
    names = [*spec.args, *spec.kwonlyargs]
    if exclude is None:
        return names
    return [name for name in names if name not in exclude]
def get_attr(obj: Any, key: str, inner_key: str = None) -> Any:
    """Alias for Python ``getattr`` with optional dictionary indexing.

    Useful for Callbacks preparation and cases with multi-criterion or
    multi-optimizer setups (for example, multi-task classification). Used to
    get a named attribute from an ``IRunner`` by ``key``::

        get_attr(runner, "criterion")  # is equivalent to runner.criterion
        get_attr(runner, "optimizer")  # is equivalent to runner.optimizer
        get_attr(runner, "scheduler")  # is equivalent to runner.scheduler

    With ``inner_key``, the attribute under ``key`` is indexed with
    ``inner_key`` (typically it is a dictionary)::

        get_attr(runner, "criterion", "bce")  # runner.criterion["bce"]
        get_attr(runner, "optimizer", "adam")  # runner.optimizer["adam"]
        get_attr(runner, "scheduler", "adam")  # runner.scheduler["adam"]

    Args:
        obj: object of interest
        key: name for attribute of interest,
            like `criterion`, `optimizer`, `scheduler`
        inner_key: name of inner dictionary key

    Returns:
        inner attribute
    """
    attribute = getattr(obj, key)
    return attribute if inner_key is None else attribute[inner_key]
def _get_key_str(
    dictionary: dict,
    key: Optional[Union[str, List[str]]],
) -> Any:
    """Return the single value stored under ``key``."""
    value = dictionary[key]
    return value


def _get_key_list(
    dictionary: dict,
    key: Optional[Union[str, List[str]]],
) -> Dict:
    """Return a new dict with only the entries whose names are listed in ``key``."""
    subset = {}
    for name in key:
        subset[name] = dictionary[name]
    return subset


def _get_key_dict(
    dictionary: dict,
    key: Optional[Union[str, List[str]]],
) -> Dict:
    """Return a new dict mapping ``key[k_in]`` (output name) to ``dictionary[k_in]``."""
    renamed = {}
    for source_name, target_name in key.items():
        renamed[target_name] = dictionary[source_name]
    return renamed


def _get_key_none(
    dictionary: dict,
    key: Optional[Union[str, List[str]]],
) -> Dict:
    """Ignore the input and return a fresh empty dict."""
    return dict()


def _get_key_all(
    dictionary: dict,
    key: Optional[Union[str, List[str]]],
) -> Dict:
    """Return ``dictionary`` itself (the same object, not a copy)."""
    return dictionary
def get_dictkey_auto_fn(key: Optional[Union[str, List[str]]]) -> Callable:
    """Pick the sub-dict extractor that matches the type of ``key``.

    * ``None`` -> ``_get_key_none`` (always an empty dict)
    * ``"__all__"`` -> ``_get_key_all`` (the whole dict, unchanged)
    * any other ``str`` -> ``_get_key_str`` (the single value under ``key``)
    * ``list`` / ``tuple`` -> ``_get_key_list`` (sub-dict with those keys)
    * ``dict`` -> ``_get_key_dict`` (sub-dict with keys renamed
      ``{input_key: output_key}``)

    Args:
        key: keys

    Returns:
        function

    Raises:
        NotImplementedError: if key is out of `str`, `tuple`, `list`, `dict`, `None`
    """
    if key is None:
        return _get_key_none
    if isinstance(key, str):
        return _get_key_all if key == "__all__" else _get_key_str
    if isinstance(key, (list, tuple)):
        return _get_key_list
    if isinstance(key, dict):
        return _get_key_dict
    raise NotImplementedError()
def merge_dicts(*dicts: dict) -> dict:
    """Recursive dict merge.

    Instead of updating only top-level keys, ``merge_dicts`` recurses down
    into dicts nested to an arbitrary depth, updating keys. Later dictionaries
    take priority over earlier ones. ``None`` entries after the first are
    treated as empty. The first dict is deep-copied, so none of the inputs is
    mutated. Non-mapping values taken from later dicts are inserted by
    reference (they are not copied).

    Args:
        *dicts: several dictionaries to merge (at least two)

    Returns:
        dict: deep-merged dictionary
    """
    assert len(dicts) > 1
    merged = copy.deepcopy(dicts[0])
    for merge_dict in dicts[1:]:
        merge_dict = merge_dict or {}
        for key in merge_dict:
            # ``collections.Mapping`` was removed in Python 3.10; the ABC
            # lives in ``collections.abc``.
            if (
                key in merged
                and isinstance(merged[key], dict)
                and isinstance(merge_dict[key], collections.abc.Mapping)
            ):
                merged[key] = merge_dicts(merged[key], merge_dict[key])
            else:
                merged[key] = merge_dict[key]
    return merged
def flatten_dict(
    dictionary: Dict[str, Any], parent_key: str = "", separator: str = "/"
) -> "collections.OrderedDict":
    """Make the given dictionary flatten.

    Nested mutable mappings are expanded recursively, and their keys are
    joined with ``separator``. A nested empty mapping contributes no entries.

    Args:
        dictionary: giving dictionary
        parent_key (str, optional): prefix nested keys with string ``parent_key``
        separator (str, optional): delimiter between
            ``parent_key`` and ``key`` to use

    Returns:
        collections.OrderedDict: ordered dictionary with flatten keys
    """
    items = []
    for key, value in dictionary.items():
        new_key = parent_key + separator + key if parent_key else key
        # ``collections.MutableMapping`` was removed in Python 3.10.
        if isinstance(value, collections.abc.MutableMapping):
            items.extend(flatten_dict(value, new_key, separator=separator).items())
        else:
            items.append((new_key, value))
    return collections.OrderedDict(items)
def split_dict_to_subdicts(dct: Dict, prefixes: List, extra_key: str) -> Dict:
    """Splits dict into subdicts with specified ``prefixes``.

    For every prefix, the entries whose keys start with it are collected under
    ``subdicts[prefix]``, and a leading ``"<prefix>_"`` is stripped from their
    keys. Keys that start with none of the prefixes go to ``extra_key``; that
    entry is only created when at least one such key exists.

    Examples:
        >>> dct = {"train_v1": 1, "train_v2": 2, "not_train": 3}
        >>> split_dict_to_subdicts(dct, prefixes=["train"], extra_key="_extra")
        >>> {"train": {"v1": 1, "v2": 2}, "_extra": {"not_train": 3}}

    Args:
        dct: dictionary with keys with prefixes
        prefixes: prefixes of interest, which we would like to reveal
        extra_key: extra key to store everything else

    Returns:
        dictionary with subdictionaries with
        ``prefixes`` and ``extra_key`` keys
    """
    subdicts = {}
    extra_subdict = {
        k: v
        for k, v in dct.items()
        if all(not k.startswith(prefix) for prefix in prefixes)
    }
    if len(extra_subdict) > 0:
        subdicts[extra_key] = extra_subdict
    for prefix in prefixes:
        head = f"{prefix}_"
        # Strip only the *leading* "<prefix>_": ``str.replace`` would also
        # remove later occurrences (e.g. "train_loss_train_x" -> "loss_x").
        subdicts[prefix] = {
            (k[len(head) :] if k.startswith(head) else k): v
            for k, v in dct.items()
            if k.startswith(prefix)
        }
    return subdicts
def _make_hashable(o):
    """Convert ``o`` recursively into a hashable, order-stable tuple form.

    * list/tuple -> tuple of ``(type_name, item)`` pairs (order kept)
    * dict -> sorted tuple of ``(type_name, key, value)`` triples
    * set/frozenset -> sorted tuple of ``(type_name, item)`` pairs
    * anything else -> returned unchanged

    The container's type name is kept in the output, so e.g. ``[1]`` and
    ``(1,)`` map to different results.
    """
    tag = type(o).__name__
    if isinstance(o, (tuple, list)):
        return tuple((tag, _make_hashable(item)) for item in o)
    if isinstance(o, dict):
        triples = [(tag, k, _make_hashable(v)) for k, v in o.items()]
        return tuple(sorted(triples))
    if isinstance(o, (set, frozenset)):
        pairs = [(tag, _make_hashable(item)) for item in o]
        return tuple(sorted(pairs))
    return o
def get_hash(obj: Any) -> str:
    """Build a URL-safe hash string of ``obj``.

    Steps:
        1. normalise ``obj`` recursively with ``_make_hashable``
        2. take ``repr`` of the result and encode it to bytes
        3. hash those bytes with SHA-256
        4. encode the digest with URL-safe base64

    The hash is only as stable as ``repr``. Objects whose repr embeds a memory
    address will not hash consistently across runs.

    Args:
        obj: object to hash

    Returns:
        base64-encoded string
    """
    canonical = _make_hashable(obj)
    digest = sha256(repr(canonical).encode()).digest()
    encoded = urlsafe_b64encode(digest)
    return encoded.decode()
def get_short_hash(obj) -> str:
    """Return a short hash of ``obj``: the first 6 characters of ``get_hash``.

    Args:
        obj: object to hash

    Returns:
        short base64-encoded string (6 chars)
    """
    full_hash = get_hash(obj)
    return full_hash[:6]
def pairwise(iterable: Iterable[Any]) -> Iterable[Any]:
    """Iterate over consecutive, overlapping pairs of ``iterable``.

    Examples:
        >>> for i in pairwise([1, 2, 5, -3]):
        >>>     print(i)
        (1, 2)
        (2, 5)
        (5, -3)

    Args:
        iterable: Any iterable sequence

    Returns:
        pairwise iterator
    """
    leading, trailing = tee(iterable)
    # Advance the second copy by one element so the two streams are offset;
    # the default avoids StopIteration on an empty input.
    next(trailing, None)
    return zip(leading, trailing)
def make_tuple(tuple_like):
    """Return lists/tuples unchanged; turn any other value ``x`` into ``(x, x)``.

    Args:
        tuple_like: tuple like object - list or tuple, or a single value

    Returns:
        tuple or list
    """
    if isinstance(tuple_like, (list, tuple)):
        return tuple_like
    return (tuple_like, tuple_like)
def args_are_not_none(*args: Optional[Any]) -> bool:
    """Check that all arguments are not ``None``.

    Only ``None`` is rejected; other falsy values (``0``, ``""``, ``[]``) are
    accepted. Calling with no arguments returns ``True``.

    Args:
        *args: values # noqa: RST213

    Returns:
        bool: True if all value were not None, False otherwise
    """
    # ``*args`` is always a tuple, so no separate ``args is None`` check is needed.
    return all(arg is not None for arg in args)
def find_value_ids(it: Iterable[Any], value: Any) -> List[int]:
    """Find the positions of all elements of ``it`` equal to ``value``.

    For NumPy arrays a vectorised comparison is used. The indices come from
    ``np.where(...)[0]``, so for multi-dimensional arrays only first-axis
    indices are returned. Other iterables are scanned element by element
    with ``==``.

    Args:
        it: list of any (or NumPy array)
        value: query element

    Returns:
        indices (plain Python ``int``) of all elements equal to ``value``
    """
    if isinstance(it, np.ndarray):
        # ``tolist`` converts NumPy integer scalars into Python ints, so both
        # branches honour the ``List[int]`` annotation.
        return np.where(it == value)[0].tolist()
    # Pure-Python fallback: linear scan, could be very slow on large inputs.
    return [i for i, el in enumerate(it) if el == value]
# Public API of this module: the names exported by ``from ... import *``.
__all__ = [
    "boolean_flag",
    "copy_directory",
    "format_metric",
    "get_fn_default_params",
    "get_fn_argsnames",
    "get_utcnow_time",
    "is_exception",
    "maybe_recursive_call",
    "get_attr",
    "set_global_seed",
    "get_dictkey_auto_fn",
    "merge_dicts",
    "flatten_dict",
    "split_dict_to_subdicts",
    "get_hash",
    "get_short_hash",
    "args_are_not_none",
    "make_tuple",
    "pairwise",
    "find_value_ids",
]