from typing import Any, Callable, Dict, List, Mapping, Union
from collections import OrderedDict
from copy import deepcopy
import torch
from torch import nn
from torch.utils.data import DataLoader
from catalyst.contrib.data.augmentor import Augmentor, AugmentorCompose
from catalyst.core.callback import Callback
from catalyst.core.experiment import IExperiment
from catalyst.experiments.functional import (
add_default_callbacks,
do_lr_linear_scaling,
get_model_parameters,
load_optimizer_from_checkpoint,
process_callbacks,
)
from catalyst.registry import (
CALLBACKS,
CRITERIONS,
MODELS,
OPTIMIZERS,
SCHEDULERS,
TRANSFORMS,
)
from catalyst.typing import Criterion, Model, Optimizer, Scheduler
from catalyst.utils.loaders import get_loaders_from_params
from catalyst.utils.misc import get_short_hash, get_utcnow_time, merge_dicts
class ConfigExperiment(IExperiment):
"""
Experiment created from a configuration file.
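
    The config follows the layout consumed below: top-level ``args`` and
    ``model_params`` plus a ``stages`` section whose per-stage sub-dicts use
    the keys from ``STAGE_KEYWORDS``. A minimal, purely illustrative sketch
    (registry names such as ``SimpleNet``/``Adam`` are assumptions, and real
    configs typically also need ``data_params``):

    Example::

        >>> config = {
        ...     "args": {"logdir": "./logs", "seed": 42},
        ...     "model_params": {"model": "SimpleNet"},
        ...     "stages": {
        ...         "stage_params": {"num_epochs": 2},
        ...         "train": {
        ...             "optimizer_params": {"optimizer": "Adam", "lr": 1e-3},
        ...         },
        ...     },
        ... }
        >>> experiment = ConfigExperiment(config)
        >>> experiment.stages
        ['train']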
"""
STAGE_KEYWORDS = [ # noqa: WPS115
"criterion_params",
"optimizer_params",
"scheduler_params",
"data_params",
"transform_params",
"stage_params",
"callbacks_params",
]
    def __init__(self, config: Dict):
"""
Args:
config: dictionary with parameters
"""
self._config: Dict = deepcopy(config)
self._trial = None
self._initial_seed: int = self._config.get("args", {}).get("seed", 42)
self._verbose: bool = self._config.get("args", {}).get(
"verbose", False
)
self._check_time: bool = self._config.get("args", {}).get(
"timeit", False
)
self._check_run: bool = self._config.get("args", {}).get(
"check", False
)
self._overfit: bool = self._config.get("args", {}).get(
"overfit", False
)
self._prepare_logdir()
self._config["stages"]["stage_params"] = merge_dicts(
deepcopy(
self._config["stages"].get("state_params", {})
), # saved for backward compatibility
deepcopy(self._config["stages"].get("stage_params", {})),
deepcopy(self._config.get("args", {})),
{"logdir": self._logdir},
)
self.stages_config: Dict = self._get_stages_config(
self._config["stages"]
)
def _get_logdir(self, config: Dict) -> str:
timestamp = get_utcnow_time()
config_hash = get_short_hash(config)
logdir = f"{timestamp}.{config_hash}"
return logdir
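    # The resulting name is ``f"{utc_timestamp}.{short_config_hash}"``; it is
    # used by ``_prepare_logdir`` below as the per-run subdirectory under
    # ``baselogdir``.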
def _prepare_logdir(self): # noqa: WPS112
exclude_tag = "none"
logdir = self._config.get("args", {}).get("logdir", None)
baselogdir = self._config.get("args", {}).get("baselogdir", None)
if logdir is not None and logdir.lower() != exclude_tag:
self._logdir = logdir
elif baselogdir is not None and baselogdir.lower() != exclude_tag:
logdir_postfix = self._get_logdir(self._config)
self._logdir = f"{baselogdir}/{logdir_postfix}"
else:
self._logdir = None
def _get_stages_config(self, stages_config: Dict):
stages_defaults = {}
stages_config_out = OrderedDict()
for key in self.STAGE_KEYWORDS:
if key == "stage_params":
# backward compatibility
stages_defaults[key] = merge_dicts(
deepcopy(stages_config.get("state_params", {})),
deepcopy(stages_config.get(key, {})),
)
else:
stages_defaults[key] = deepcopy(stages_config.get(key, {}))
for stage in stages_config:
if (
stage in self.STAGE_KEYWORDS
or stage == "state_params"
or stages_config.get(stage) is None
):
continue
stages_config_out[stage] = {}
for key2 in self.STAGE_KEYWORDS:
if key2 == "stage_params":
# backward compatibility
stages_config_out[stage][key2] = merge_dicts(
deepcopy(stages_defaults.get("state_params", {})),
deepcopy(stages_defaults.get(key2, {})),
deepcopy(stages_config[stage].get("state_params", {})),
deepcopy(stages_config[stage].get(key2, {})),
)
else:
stages_config_out[stage][key2] = merge_dicts(
deepcopy(stages_defaults.get(key2, {})),
deepcopy(stages_config[stage].get(key2, {})),
)
return stages_config_out
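    # Merge order above: later dicts win, so per-stage values override the
    # top-level defaults. An assumed example:
    #   stages: {"optimizer_params": {"lr": 0.001},
    #            "train": {"optimizer_params": {"optimizer": "Adam"}}}
    # gives stages_config["train"]["optimizer_params"]
    #   == {"lr": 0.001, "optimizer": "Adam"}.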
@property
def initial_seed(self) -> int:
"""Experiment's initial seed value."""
return self._initial_seed
@property
def logdir(self):
"""Path to the directory where the experiment logs."""
return self._logdir
@property
def hparams(self) -> OrderedDict:
"""Returns hyperparameters"""
return OrderedDict(self._config)
@property
def trial(self) -> Any:
"""
Returns hyperparameter trial for current experiment.
        Could be useful for Optuna/HyperOpt/Ray.tune
        hyperparameter optimizers.
Returns:
trial
Example::
>>> experiment.trial
optuna.trial._trial.Trial # Optuna variant
"""
return self._trial
@property
def distributed_params(self) -> Dict:
"""Dict with the parameters for distributed and FP16 methond."""
return self._config.get("distributed_params", {})
@property
def stages(self) -> List[str]:
"""Experiment's stage names."""
stages_keys = list(self.stages_config.keys())
return stages_keys
    def get_stage_params(self, stage: str) -> Mapping[str, Any]:
"""Returns the state parameters for a given stage."""
return self.stages_config[stage].get("stage_params", {})
@staticmethod
def _get_model(**params):
key_value_flag = params.pop("_key_value", False)
if key_value_flag:
model = {}
for model_key, model_params in params.items():
model[model_key] = ConfigExperiment._get_model( # noqa: WPS437
**model_params
)
model = nn.ModuleDict(model)
else:
model = MODELS.get_from_params(**params)
return model
    def get_model(self, stage: str):
"""Returns the model for a given stage."""
model_params = self._config["model_params"]
model = self._get_model(**model_params)
return model
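    # Illustrative ``model_params`` fragment (key/class names are assumptions
    # about what is registered in MODELS):
    #   model_params: {model: SimpleNet, num_classes: 10}
    # With ``_key_value: True`` a dict of such fragments is accepted instead,
    # and ``_get_model`` wraps the result into an ``nn.ModuleDict``.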
@staticmethod
def _get_criterion(**params):
key_value_flag = params.pop("_key_value", False)
if key_value_flag:
criterion = {}
for key, key_params in params.items():
criterion[
key
] = ConfigExperiment._get_criterion( # noqa: WPS437
**key_params
)
else:
criterion = CRITERIONS.get_from_params(**params)
if criterion is not None and torch.cuda.is_available():
criterion = criterion.cuda()
return criterion
    def get_criterion(self, stage: str) -> Criterion:
"""Returns the criterion for a given stage."""
criterion_params = self.stages_config[stage].get(
"criterion_params", {}
)
criterion = self._get_criterion(**criterion_params)
return criterion
def _get_optimizer(
self, stage: str, model: Union[Model, Dict[str, Model]], **params
) -> Optimizer:
# @TODO 1: refactoring; this method is too long
# @TODO 2: load state dicts for schedulers & criterion
# lr linear scaling
lr_scaling_params = params.pop("lr_linear_scaling", None)
if lr_scaling_params:
data_params = dict(self.stages_config[stage]["data_params"])
batch_size = data_params.get("batch_size")
per_gpu_scaling = data_params.get("per_gpu_scaling", False)
lr, lr_scaling = do_lr_linear_scaling(
lr_scaling_params=lr_scaling_params,
batch_size=batch_size,
per_gpu_scaling=per_gpu_scaling,
)
params["lr"] = lr
else:
lr_scaling = 1.0
# getting layer-wise parameters
layerwise_params = params.pop("layerwise_params", OrderedDict())
no_bias_weight_decay = params.pop("no_bias_weight_decay", True)
# getting model parameters
model_key = params.pop("_model", None)
model_params = get_model_parameters(
models=model,
models_keys=model_key,
layerwise_params=layerwise_params,
no_bias_weight_decay=no_bias_weight_decay,
lr_scaling=lr_scaling,
)
# getting load-from-previous-stage flag
load_from_previous_stage = params.pop(
"load_from_previous_stage", False
)
# instantiate optimizer
optimizer_key = params.pop("optimizer_key", None)
optimizer = OPTIMIZERS.get_from_params(**params, params=model_params)
# load from previous stage
if load_from_previous_stage and self.stages.index(stage) != 0:
checkpoint_path = f"{self.logdir}/checkpoints/best_full.pth"
optimizer = load_optimizer_from_checkpoint(
optimizer,
checkpoint_path=checkpoint_path,
checkpoint_optimizer_key=optimizer_key,
model_parameters=model_params,
optimizer_params=params,
)
return optimizer
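    # Keys consumed by ``_get_optimizer`` above, sketched as an assumed
    # ``optimizer_params`` fragment (the optimizer name is a registry
    # assumption):
    #   optimizer_params:
    #     optimizer: Adam
    #     lr: 0.001
    #     lr_linear_scaling: {...}        # optional: rescale lr by batch size
    #     layerwise_params: {...}         # optional: per-layer overrides
    #     no_bias_weight_decay: True      # default: True
    #     load_from_previous_stage: True  # restore state from
    #                                     # ``{logdir}/checkpoints/best_full.pth``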
    def get_optimizer(
self, stage: str, model: Union[Model, Dict[str, Model]]
) -> Union[Optimizer, Dict[str, Optimizer]]:
"""
Returns the optimizer for a given stage.
Args:
stage: stage name
model (Union[Model, Dict[str, Model]]): model or a dict of models
Returns:
optimizer for selected stage
"""
optimizer_params = self.stages_config[stage].get(
"optimizer_params", {}
)
key_value_flag = optimizer_params.pop("_key_value", False)
if key_value_flag:
optimizer = {}
for key, params in optimizer_params.items():
                # pass the key so the matching optimizer state can be restored from checkpoint
optimizer_key = "optimizer_key"
assert optimizer_key not in params, "keyword reserved"
params[optimizer_key] = key
optimizer[key] = self._get_optimizer(stage, model, **params)
else:
optimizer = self._get_optimizer(stage, model, **optimizer_params)
return optimizer
@staticmethod
def _get_scheduler(
*, optimizer: Union[Optimizer, Dict[str, Optimizer]], **params: Any
) -> Union[Scheduler, Dict[str, Scheduler]]:
optimizer_key = params.pop("_optimizer", None)
optimizer = optimizer[optimizer_key] if optimizer_key else optimizer
scheduler = SCHEDULERS.get_from_params(**params, optimizer=optimizer)
return scheduler
    def get_scheduler(
self, stage: str, optimizer: Union[Optimizer, Dict[str, Optimizer]]
) -> Union[Scheduler, Dict[str, Scheduler]]:
"""Returns the scheduler for a given stage."""
params = self.stages_config[stage].get("scheduler_params", {})
key_value_flag = params.pop("_key_value", False)
if key_value_flag:
scheduler: Dict[str, Scheduler] = {}
for key, scheduler_params in params.items():
scheduler[key] = self._get_scheduler(
optimizer=optimizer, **scheduler_params
)
else:
scheduler = self._get_scheduler(optimizer=optimizer, **params)
return scheduler
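    # Illustrative ``scheduler_params`` fragment (scheduler name and kwargs
    # are assumptions about what is registered in SCHEDULERS):
    #   scheduler_params: {scheduler: MultiStepLR, milestones: [10], gamma: 0.3}
    # With ``_key_value: True`` each sub-dict may also set ``_optimizer`` to
    # choose which optimizer (by key) it is attached to.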
@staticmethod
def _get_transform(**params) -> Callable:
key_value_flag = params.pop("_key_value", False)
if key_value_flag:
transforms_composition = {
transform_key: ConfigExperiment._get_transform( # noqa: WPS437
**transform_params
)
for transform_key, transform_params in params.items()
}
transform = AugmentorCompose(
{
key: Augmentor(
dict_key=key,
augment_fn=transform,
input_key=key,
output_key=key,
)
for key, transform in transforms_composition.items()
}
)
else:
if "transforms" in params:
transforms_composition = [
ConfigExperiment._get_transform( # noqa: WPS437
**transform_params
)
for transform_params in params["transforms"]
]
params.update(transforms=transforms_composition)
transform = TRANSFORMS.get_from_params(**params)
return transform
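    # Illustrative ``transform_params`` usage (the ``transform`` key and the
    # transform names are assumptions about the TRANSFORMS registry):
    #   _key_value: True
    #   image:
    #     transform: Compose
    #     transforms: [{transform: Normalize}, {transform: ToTensor}]
    # With ``_key_value: True`` the per-key transforms are wrapped by
    # ``AugmentorCompose``/``Augmentor`` so each is applied to the matching
    # sample key.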
    def get_loaders(
self, stage: str, epoch: int = None,
) -> "OrderedDict[str, DataLoader]":
"""Returns the loaders for a given stage."""
data_params = dict(self.stages_config[stage]["data_params"])
loaders = get_loaders_from_params(
get_datasets_fn=self.get_datasets,
initial_seed=self.initial_seed,
stage=stage,
**data_params,
)
return loaders
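    # Illustrative ``data_params`` fragment (exact keys are defined by
    # ``get_loaders_from_params``; the values here are assumptions):
    #   data_params: {batch_size: 64, num_workers: 4, per_gpu_scaling: True}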
@staticmethod
def _get_callback(**params):
wrapper_params = params.pop("_wrapper", None)
callback = CALLBACKS.get_from_params(**params)
if wrapper_params is not None:
wrapper_params["base_callback"] = callback
callback = ConfigExperiment._get_callback( # noqa: WPS437
**wrapper_params
)
return callback
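    # Illustrative ``callbacks_params`` entry using the ``_wrapper`` hook
    # (callback names and kwargs are assumptions):
    #   loss:
    #     callback: CriterionCallback
    #     _wrapper: {callback: ControlFlowCallback, loaders: train}
    # The wrapper is instantiated around the inner callback via its
    # ``base_callback`` argument.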
    def get_callbacks(self, stage: str) -> "OrderedDict[Callback]":
"""Returns the callbacks for a given stage."""
callbacks_params = self.stages_config[stage].get(
"callbacks_params", {}
)
callbacks = OrderedDict()
for key, callback_params in callbacks_params.items():
callback = self._get_callback(**callback_params)
callbacks[key] = callback
callbacks = add_default_callbacks(
callbacks,
verbose=self._verbose,
check_time=self._check_time,
check_run=self._check_run,
overfit=self._overfit,
is_infer=stage.startswith("infer"),
is_logger=self.logdir is not None,
is_criterion=self.stages_config[stage].get("criterion_params", {}),
is_optimizer=self.stages_config[stage].get("optimizer_params", {}),
is_scheduler=self.stages_config[stage].get("scheduler_params", {}),
)
        # NOTE: stage should be in self._config.stages,
        # otherwise a ValueError will be raised
stage_index = list(self.stages_config.keys()).index(stage)
process_callbacks(callbacks, stage_index)
return callbacks
__all__ = ["ConfigExperiment"]