from typing import Any, Dict, List, Union # isort:skip
from collections import OrderedDict
import copy
import json
from logging import getLogger
import os
from pathlib import Path
import platform
import re
import shutil
import subprocess
import sys
import safitty
import yaml
from catalyst import utils
from catalyst.utils.tools.tensorboard import SummaryWriter
LOG = getLogger(__name__)
[docs]def load_ordered_yaml(
stream, Loader=yaml.Loader, object_pairs_hook=OrderedDict
):
"""
Loads `yaml` config into OrderedDict
Args:
stream: opened file with yaml
Loader: base class for yaml Loader
object_pairs_hook: type of mapping
Returns:
dict: configuration
"""
class OrderedLoader(Loader):
pass
def construct_mapping(loader, node):
loader.flatten_mapping(node)
return object_pairs_hook(loader.construct_pairs(node))
OrderedLoader.add_constructor(
yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG, construct_mapping
)
OrderedLoader.add_implicit_resolver(
"tag:yaml.org,2002:float",
re.compile(
u"""^(?:
[-+]?(?:[0-9][0-9_]*)\\.[0-9_]*(?:[eE][-+]?[0-9]+)?
|[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+)
|\\.[0-9_]+(?:[eE][-+][0-9]+)?
|[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\.[0-9_]*
|[-+]?\\.(?:inf|Inf|INF)
|\\.(?:nan|NaN|NAN))$""", re.X
), list(u"-+0123456789.")
)
return yaml.load(stream, OrderedLoader)
def _decode_dict(dictionary: Dict[str, Union[bytes, str]]) -> Dict[str, str]:
"""
Decode bytes values in the dictionary to UTF-8
Args:
dictionary: a dict
Returns:
dict: decoded dict
"""
result = {
k: v.decode("UTF-8") if type(v) == bytes else v
for k, v in dictionary.items()
}
return result
[docs]def get_environment_vars() -> Dict[str, Any]:
"""
Creates a dictionary with environment variables
Returns:
dict: environment variables
"""
result = {
"python_version": sys.version,
"conda_environment": os.environ.get("CONDA_DEFAULT_ENV", ""),
"creation_time": utils.get_utcnow_time(),
"sysname": platform.uname()[0],
"nodename": platform.uname()[1],
"release": platform.uname()[2],
"version": platform.uname()[3],
"architecture": platform.uname()[4],
"user": os.environ.get("USER", ""),
"path": os.environ.get("PWD", ""),
}
with open(os.devnull, "w") as devnull:
try:
git_branch = subprocess.check_output(
"git rev-parse --abbrev-ref HEAD".split(),
shell=True,
stderr=devnull
).strip().decode("UTF-8")
git_local_commit = subprocess.check_output(
"git rev-parse HEAD".split(), shell=True, stderr=devnull
)
git_origin_commit = subprocess.check_output(
f"git rev-parse origin/{git_branch}".split(),
shell=True,
stderr=devnull
)
git = dict(
branch=git_branch,
local_commit=git_local_commit,
origin_commit=git_origin_commit
)
result["git"] = _decode_dict(git)
except (subprocess.CalledProcessError, FileNotFoundError):
pass
result = _decode_dict(result)
return result
def list_pip_packages() -> str:
result = ""
with open(os.devnull, "w") as devnull:
try:
result = subprocess.check_output(
"pip freeze".split(), stderr=devnull
).strip().decode("UTF-8")
except FileNotFoundError:
pass
except subprocess.CalledProcessError as e:
raise Exception("Failed to list packages") from e
return result
def list_conda_packages() -> str:
result = ""
conda_meta_path = Path(sys.prefix) / "conda-meta"
if conda_meta_path.exists():
# We are currently in conda virtual env
with open(os.devnull, "w") as devnull:
try:
result = subprocess.check_output(
"conda list --export".split(), stderr=devnull
).strip().decode("UTF-8")
except FileNotFoundError:
pass
except subprocess.CalledProcessError as e:
raise Exception(
f"Running from conda env, "
f"but failed to list conda packages. "
f"Conda Output: {e.output}"
) from e
return result
[docs]def dump_environment(
experiment_config: Dict,
logdir: str,
configs_path: List[str] = None,
) -> None:
"""
Saves config, environment variables and package list in JSON into logdir
Args:
experiment_config (dict): experiment config
logdir (str): path to logdir
configs_path: path(s) to config
"""
configs_path = configs_path or []
configs_path = [
Path(path) for path in configs_path if isinstance(path, str)
]
config_dir = Path(logdir) / "configs"
config_dir.mkdir(exist_ok=True, parents=True)
environment = get_environment_vars()
safitty.save(experiment_config, config_dir / "_config.json")
safitty.save(environment, config_dir / "_environment.json")
pip_pkg = list_pip_packages()
(config_dir / "pip-packages.txt").write_text(pip_pkg)
conda_pkg = list_conda_packages()
if conda_pkg:
(config_dir / "conda-packages.txt").write_text(conda_pkg)
for path in configs_path:
name: str = path.name
outpath = config_dir / name
shutil.copyfile(path, outpath)
config_str = json.dumps(experiment_config, indent=2, ensure_ascii=False)
config_str = config_str.replace("\n", "\n\n")
environment_str = json.dumps(environment, indent=2, ensure_ascii=False)
environment_str = environment_str.replace("\n", "\n\n")
pip_pkg = pip_pkg.replace("\n", "\n\n")
conda_pkg = conda_pkg.replace("\n", "\n\n")
with SummaryWriter(config_dir) as writer:
writer.add_text("_config", config_str, 0)
writer.add_text("_environment", environment_str, 0)
writer.add_text("pip-packages", pip_pkg, 0)
if conda_pkg:
writer.add_text("conda-packages", conda_pkg, 0)
[docs]def parse_config_args(*, config, args, unknown_args):
for arg in unknown_args:
arg_name, value = arg.split("=")
arg_name = arg_name.lstrip("-").strip("/")
value_content, value_type = value.rsplit(":", 1)
if "/" in arg_name:
arg_names = arg_name.split("/")
if value_type == "str":
arg_value = value_content
if arg_value.lower() == "none":
arg_value = None
else:
arg_value = eval("%s(%s)" % (value_type, value_content))
config_ = config
for arg_name in arg_names[:-1]:
if arg_name not in config_:
config_[arg_name] = {}
config_ = config_[arg_name]
config_[arg_names[-1]] = arg_value
else:
if value_type == "str":
arg_value = value_content
else:
arg_value = eval("%s(%s)" % (value_type, value_content))
args.__setattr__(arg_name, arg_value)
args_exists_ = config.get("args")
if args_exists_ is None:
config["args"] = dict()
for key, value in args._get_kwargs():
if value is not None:
if key in ["logdir", "baselogdir"] and value == "":
continue
config["args"][key] = value
return config, args
[docs]def parse_args_uargs(args, unknown_args):
"""
Function for parsing configuration files
Args:
args: recognized arguments
unknown_args: unrecognized arguments
Returns:
tuple: updated arguments, dict with config
"""
args_ = copy.deepcopy(args)
# load params
config = {}
for config_path in args_.configs:
with open(config_path, "r") as fin:
if config_path.endswith("json"):
config_ = json.load(fin, object_pairs_hook=OrderedDict)
elif config_path.endswith("yml"):
config_ = load_ordered_yaml(fin)
else:
raise Exception("Unknown file format")
config = utils.merge_dicts(config, config_)
config, args_ = parse_config_args(
config=config, args=args_, unknown_args=unknown_args
)
# hack with argparse in config
config_args = config.get("args", None)
if config_args is not None:
for key, value in config_args.items():
arg_value = getattr(args_, key, None)
if arg_value is None or \
(key in ["logdir", "baselogdir"] and arg_value == ""):
arg_value = value
setattr(args_, key, arg_value)
return args_, config
__all__ = [
"load_ordered_yaml", "get_environment_vars", "dump_environment",
"parse_config_args", "parse_args_uargs"
]