Source code for catalyst.utils.misc
from typing import Any, Callable, List, Union
from datetime import datetime
import inspect
from pathlib import Path
import random
import shutil
import numpy as np
from packaging.version import parse, Version
[docs]def set_global_seed(seed: int) -> None:
"""Sets random seed into Numpy and Random, PyTorch and TensorFlow.
Args:
seed: random seed
"""
random.seed(seed)
np.random.seed(seed)
try:
import torch
except ImportError:
pass
else:
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
try:
import tensorflow as tf
except ImportError:
pass
else:
if parse(tf.__version__) >= Version("2.0.0"):
tf.random.set_seed(seed)
elif parse(tf.__version__) <= Version("1.13.2"):
tf.set_random_seed(seed)
else:
tf.compat.v1.set_random_seed(seed)
[docs]def maybe_recursive_call(
object_or_dict,
method: Union[str, Callable],
recursive_args=None,
recursive_kwargs=None,
**kwargs,
):
"""Calls the ``method`` recursively for the ``object_or_dict``.
Args:
object_or_dict: some object or a dictionary of objects
method: method name to call
recursive_args: list of arguments to pass to the ``method``
recursive_kwargs: list of key-arguments to pass to the ``method``
**kwargs: Arbitrary keyword arguments
Returns:
result of `method` call
"""
if isinstance(object_or_dict, dict):
result = type(object_or_dict)()
for k, v in object_or_dict.items():
r_args = None if recursive_args is None else recursive_args[k]
r_kwargs = (
None if recursive_kwargs is None else recursive_kwargs[k]
)
result[k] = maybe_recursive_call(
v,
method,
recursive_args=r_args,
recursive_kwargs=r_kwargs,
**kwargs,
)
return result
r_args = recursive_args or []
if not isinstance(r_args, (list, tuple)):
r_args = [r_args]
r_kwargs = recursive_kwargs or {}
if isinstance(method, str):
return getattr(object_or_dict, method)(*r_args, **r_kwargs, **kwargs)
else:
return method(object_or_dict, *r_args, **r_kwargs, **kwargs)
[docs]def is_exception(ex: Any) -> bool:
"""Check if the argument is of ``Exception`` type."""
result = (ex is not None) and isinstance(ex, BaseException)
return result
[docs]def copy_directory(input_dir: Path, output_dir: Path) -> None:
"""Recursively copies the input directory.
Args:
input_dir: input directory
output_dir: output directory
"""
output_dir.mkdir(exist_ok=True, parents=True)
for path in input_dir.iterdir():
if path.is_dir():
path_name = path.name
copy_directory(path, output_dir / path_name)
else:
shutil.copy2(path, output_dir)
[docs]def get_utcnow_time(format: str = None) -> str:
"""Return string with current utc time in chosen format.
Args:
format: format string. if None "%y%m%d.%H%M%S" will be used.
Returns:
str: formatted utc time string
"""
if format is None:
format = "%y%m%d.%H%M%S"
result = datetime.utcnow().strftime(format)
return result
[docs]def get_fn_default_params(fn: Callable[..., Any], exclude: List[str] = None):
"""Return default parameters of Callable.
Args:
fn (Callable[..., Any]): target Callable
exclude: exclude list of parameters
Returns:
dict: contains default parameters of `fn`
"""
argspec = inspect.getfullargspec(fn)
default_params = zip(
argspec.args[-len(argspec.defaults) :], argspec.defaults
)
if exclude is not None:
default_params = filter(lambda x: x[0] not in exclude, default_params)
default_params = dict(default_params)
return default_params
[docs]def get_fn_argsnames(fn: Callable[..., Any], exclude: List[str] = None):
"""Return parameter names of Callable.
Args:
fn (Callable[..., Any]): target Callable
exclude: exclude list of parameters
Returns:
list: contains parameter names of `fn`
"""
argspec = inspect.getfullargspec(fn)
params = argspec.args + argspec.kwonlyargs
if exclude is not None:
params = list(filter(lambda x: x not in exclude, params))
return params
[docs]def get_attr(obj: Any, key: str, inner_key: str = None) -> Any:
"""
Alias for python `getattr` method. Useful for Callbacks preparation
and cases with multi-criterion, multi-optimizer setup.
For example, when you would like to train multi-task classification.
Used to get a named attribute from a `IRunner` by `key` keyword;
for example\
::
# example 1
runner.get_attr("criterion")
# is equivalent to
runner.criterion
# example 2
runner.get_attr("optimizer")
# is equivalent to
runner.optimizer
# example 3
runner.get_attr("scheduler")
# is equivalent to
runner.scheduler
With `inner_key` usage, it suppose to find a dictionary under `key`\
and would get `inner_key` from this dict; for example,
::
# example 1
runner.get_attr("criterion", "bce")
# is equivalent to
runner.criterion["bce"]
# example 2
runner.get_attr("optimizer", "adam")
# is equivalent to
runner.optimizer["adam"]
# example 3
runner.get_attr("scheduler", "adam")
# is equivalent to
runner.scheduler["adam"]
Args:
obj: object of interest
key: name for attribute of interest,
like `criterion`, `optimizer`, `scheduler`
inner_key: name of inner dictionary key
Returns:
inner attribute
"""
if inner_key is None:
return getattr(obj, key)
else:
return getattr(obj, key)[inner_key]
__all__ = [
"copy_directory",
"format_metric",
"get_fn_default_params",
"get_fn_argsnames",
"get_utcnow_time",
"is_exception",
"maybe_recursive_call",
"get_attr",
"set_global_seed",
]