from typing import Any, Dict, List, Mapping, Union # isort:skip
from collections import OrderedDict
from copy import deepcopy
import safitty
import torch
from torch import nn
from torch.utils.data import (  # noqa: F401
DataLoader, Dataset, DistributedSampler
)
from catalyst.dl import utils
from catalyst.dl.callbacks import (
ConsoleLogger, RaiseExceptionCallback, TensorboardLogger, VerboseLogger
)
from catalyst.dl.core import Callback, Experiment
from catalyst.dl.registry import (
CALLBACKS, CRITERIONS, MODELS, OPTIMIZERS, SCHEDULERS
)
from catalyst.dl.utils.torch import _Criterion, _Model, _Optimizer, _Scheduler
class ConfigExperiment(Experiment):
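    """
    ``Experiment`` built from a config dictionary (typically a parsed YAML).

    Every key in ``STAGE_KEYWORDS`` can be set directly under ``stages``
    as a default for all stages and/or inside an individual stage,
    where the per-stage values override the defaults.

    Illustrative config sketch; the concrete model/criterion/optimizer
    names and their registry name-keys follow the usual Catalyst config
    conventions and are placeholders, not something this class enforces::

        {
            "args": {"logdir": "./logs", "seed": 42, "verbose": True},
            "model_params": {"model": "SomeRegisteredModel"},
            "stages": {
                "data_params": {"batch_size": 32, "num_workers": 1},
                "stage1": {
                    "criterion_params": {"criterion": "CrossEntropyLoss"},
                    "optimizer_params": {"optimizer": "Adam", "lr": 0.001},
                },
            },
        }
    """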
STAGE_KEYWORDS = [
"criterion_params",
"optimizer_params",
"scheduler_params",
"data_params",
"state_params",
"callbacks_params",
]
def __init__(self, config: Dict):
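        """
        Args:
            config: dictionary with the experiment configuration;
                it is deep-copied, so later mutations of the original
                dict do not affect the experiment
        """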
self._config = deepcopy(config)
self._initial_seed = self._config.get("args", {}).get("seed", 42)
self._verbose = safitty.get(
self._config, "args", "verbose", default=False
)
self.__prepare_logdir()
self._config["stages"]["state_params"] = utils.merge_dicts(
deepcopy(self._config["stages"].get("state_params", {})),
deepcopy(self._config.get("args", {})), {"logdir": self._logdir}
)
self.stages_config = self._get_stages_config(self._config["stages"])
def __prepare_logdir(self):
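        # resolve the logdir: an explicit ``logdir`` wins, otherwise a
        # timestamped subdirectory of ``baselogdir`` is generated;
        # the literal value "none" disables the logdir entirely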
EXCLUDE_TAG = "none"
logdir = self._config.get("args", {}).get("logdir", None)
baselogdir = self._config.get("args", {}).get("baselogdir", None)
if logdir is not None and logdir.lower() != EXCLUDE_TAG:
self._logdir = logdir
elif baselogdir is not None and baselogdir.lower() != EXCLUDE_TAG:
logdir_postfix = self._get_logdir(self._config)
self._logdir = f"{baselogdir}/{logdir_postfix}"
else:
self._logdir = None
def _get_stages_config(self, stages_config):
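        # merge the stage-level defaults (STAGE_KEYWORDS defined directly
        # under "stages") into every concrete stage; per-stage values win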
stages_defaults = {}
stages_config_out = OrderedDict()
for key in self.STAGE_KEYWORDS:
stages_defaults[key] = deepcopy(stages_config.get(key, {}))
for stage in stages_config:
if stage in self.STAGE_KEYWORDS \
or stages_config.get(stage) is None:
continue
stages_config_out[stage] = {}
for key in self.STAGE_KEYWORDS:
stages_config_out[stage][key] = utils.merge_dicts(
deepcopy(stages_defaults.get(key, {})),
deepcopy(stages_config[stage].get(key, {})),
)
return stages_config_out
def _get_logdir(self, config: Dict) -> str:
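        # logdir name: <utc-timestamp>.<short-config-hash>[.rank{NN}]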
timestamp = utils.get_utcnow_time()
config_hash = utils.get_short_hash(config)
logdir = f"{timestamp}.{config_hash}"
distributed_rank = self.distributed_params.get("rank", -1)
if distributed_rank > -1:
logdir = f"{logdir}.rank{distributed_rank:02d}"
return logdir
@property
def initial_seed(self) -> int:
return self._initial_seed
@property
def logdir(self):
return self._logdir
@property
def stages(self) -> List[str]:
stages_keys = list(self.stages_config.keys())
return stages_keys
@property
def distributed_params(self) -> Dict:
return self._config.get("distributed_params", {})
@property
def monitoring_params(self) -> Dict:
return self._config.get("monitoring_params", {})
    def get_state_params(self, stage: str) -> Mapping[str, Any]:
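        """Returns the state parameters for a given stage."""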
return self.stages_config[stage].get("state_params", {})
def _preprocess_model_for_stage(self, stage: str, model: _Model):
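        # for every stage after the first one, warm-start the model
        # from the best checkpoint produced by the previous stages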
stage_index = self.stages.index(stage)
if stage_index > 0:
checkpoint_path = \
f"{self.logdir}/checkpoints/best.pth"
checkpoint = utils.load_checkpoint(checkpoint_path)
utils.unpack_checkpoint(checkpoint, model=model)
return model
def _postprocess_model_for_stage(self, stage: str, model: _Model):
return model
@staticmethod
def _get_model(**params):
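        # ``_key_value: true`` turns the section into a dict of model
        # configs, producing a dict of models instead of a single one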
key_value_flag = params.pop("_key_value", False)
if key_value_flag:
model = {}
for key, params_ in params.items():
model[key] = ConfigExperiment._get_model(**params_)
else:
model = MODELS.get_from_params(**params)
return model
    def get_model(self, stage: str):
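        """Returns the model for a given stage."""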
model_params = self._config["model_params"]
model = self._get_model(**model_params)
model = self._preprocess_model_for_stage(stage, model)
model = self._postprocess_model_for_stage(stage, model)
return model
@staticmethod
def _get_criterion(**params):
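        # same ``_key_value`` convention as ``_get_model``; criteria are
        # moved to GPU when CUDA is available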
key_value_flag = params.pop("_key_value", False)
if key_value_flag:
criterion = {}
for key, params_ in params.items():
criterion[key] = ConfigExperiment._get_criterion(**params_)
else:
criterion = CRITERIONS.get_from_params(**params)
if criterion is not None and torch.cuda.is_available():
criterion = criterion.cuda()
return criterion
    def get_criterion(self, stage: str) -> _Criterion:
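        """Returns the criterion for a given stage."""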
criterion_params = \
self.stages_config[stage].get("criterion_params", {})
criterion = self._get_criterion(**criterion_params)
return criterion
def _get_optimizer(
self,
stage: str,
model: Union[_Model, Dict[str, _Model]],
**params
) -> _Optimizer:
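        """
        Builds the optimizer for ``model``: handles layer-wise parameters,
        optional linear LR scaling, and, when ``load_from_previous_stage``
        is set, restores the optimizer state from the previous stage's
        ``best_full.pth`` checkpoint.
        """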
# @TODO 1: refactoring; this method is too long
# @TODO 2: load state dicts for schedulers & criteria
layerwise_params = \
params.pop("layerwise_params", OrderedDict())
no_bias_weight_decay = \
params.pop("no_bias_weight_decay", True)
# linear scaling rule from https://arxiv.org/pdf/1706.02677.pdf
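        # i.e. lr = lr_linear_scaling.lr * (batch_size / base_batch_size);
        # e.g. a base lr of 0.1 tuned for batch size 256 becomes 0.2
        # when the effective batch size is 512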
lr_scaling_params = params.pop("lr_linear_scaling", None)
if lr_scaling_params:
data_params = dict(self.stages_config[stage]["data_params"])
batch_size = data_params.get("batch_size")
per_gpu_scaling = data_params.get("per_gpu_scaling", False)
distributed_rank = self.distributed_params.get("rank", -1)
distributed = distributed_rank > -1
if per_gpu_scaling and not distributed:
num_gpus = max(1, torch.cuda.device_count())
batch_size *= num_gpus
base_lr = lr_scaling_params.get("lr")
base_batch_size = lr_scaling_params.get("base_batch_size", 256)
lr_scaling = batch_size / base_batch_size
params["lr"] = base_lr * lr_scaling # scale default lr
else:
lr_scaling = 1.0
# getting model parameters
model_key = params.pop("_model", None)
if model_key is None:
            assert isinstance(model, nn.Module), \
                "model is key-value, but optimizer params " \
                "do not specify `_model`"
model_params = utils.process_model_params(
model, layerwise_params, no_bias_weight_decay, lr_scaling
)
elif isinstance(model_key, str):
model_params = utils.process_model_params(
model[model_key], layerwise_params, no_bias_weight_decay,
lr_scaling
)
elif isinstance(model_key, (list, tuple)):
model_params = []
for model_key_ in model_key:
model_params_ = utils.process_model_params(
model[model_key_], layerwise_params, no_bias_weight_decay,
lr_scaling
)
model_params.extend(model_params_)
else:
raise ValueError("unknown type of model_params")
load_from_previous_stage = \
params.pop("load_from_previous_stage", False)
optimizer_key = params.pop("optimizer_key", None)
optimizer = OPTIMIZERS.get_from_params(**params, params=model_params)
if load_from_previous_stage:
checkpoint_path = f"{self.logdir}/checkpoints/best_full.pth"
checkpoint = utils.load_checkpoint(checkpoint_path)
dict2load = optimizer
if optimizer_key is not None:
dict2load = {optimizer_key: optimizer}
utils.unpack_checkpoint(checkpoint, optimizer=dict2load)
# move optimizer to device
device = utils.get_device()
for param in model_params:
param = param["params"][0]
state = optimizer.state[param]
for key, value in state.items():
state[key] = utils.any2device(value, device)
# update optimizer params
for key, value in params.items():
for pg in optimizer.param_groups:
pg[key] = value
return optimizer
    def get_optimizer(
self,
stage: str,
model: Union[_Model, Dict[str, _Model]]
) -> Union[_Optimizer, Dict[str, _Optimizer]]:
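        """
        Returns the optimizer for a given stage; with ``_key_value: true``
        a dict of optimizers is created, one per config key.
        """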
optimizer_params = \
self.stages_config[stage].get("optimizer_params", {})
key_value_flag = optimizer_params.pop("_key_value", False)
if key_value_flag:
optimizer = {}
for key, params_ in optimizer_params.items():
# load specified optimizer from checkpoint
optimizer_key = "optimizer_key"
assert optimizer_key not in params_, "keyword reserved"
params_[optimizer_key] = key
optimizer[key] = self._get_optimizer(stage, model, **params_)
else:
optimizer = self._get_optimizer(stage, model, **optimizer_params)
return optimizer
@staticmethod
def _get_scheduler(*, optimizer, **params):
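        # same ``_key_value`` convention: either a single scheduler
        # or a dict of schedulers sharing the given optimizer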
key_value_flag = params.pop("_key_value", False)
if key_value_flag:
scheduler = {}
for key, params_ in params.items():
scheduler[key] = ConfigExperiment._get_scheduler(
optimizer=optimizer, **params_
)
else:
scheduler = SCHEDULERS.get_from_params(
**params, optimizer=optimizer
)
return scheduler
    def get_scheduler(self, stage: str, optimizer) -> _Scheduler:
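        """Returns the scheduler for a given stage and optimizer."""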
scheduler_params = \
self.stages_config[stage].get("scheduler_params", {})
scheduler = self._get_scheduler(
optimizer=optimizer, **scheduler_params
)
return scheduler
    def get_loaders(self, stage: str) -> "OrderedDict[str, DataLoader]":
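        """
        Builds a ``DataLoader`` for every dataset returned by
        ``get_datasets``; supports per-loader overrides via
        ``loaders_params``, per-GPU batch-size scaling and
        distributed samplers.
        """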
data_params = dict(self.stages_config[stage]["data_params"])
batch_size = data_params.pop("batch_size", 1)
num_workers = data_params.pop("num_workers")
drop_last = data_params.pop("drop_last", False)
per_gpu_scaling = data_params.pop("per_gpu_scaling", False)
distributed_rank = self.distributed_params.get("rank", -1)
distributed = distributed_rank > -1
        # pop per-loader overrides before the remaining ``data_params``
        # are forwarded to ``get_datasets``
        overridden_loaders_params = data_params.pop("loaders_params", {})
        assert isinstance(overridden_loaders_params, dict), \
            f"{overridden_loaders_params} should be Dict"
        datasets = self.get_datasets(stage=stage, **data_params)
loaders = OrderedDict()
for name, ds_ in datasets.items():
assert isinstance(ds_, (Dataset, dict)), \
f"{ds_} should be Dataset or Dict"
overridden_loader_params = overridden_loaders_params.pop(name, {})
assert isinstance(overridden_loader_params, dict), \
f"{overridden_loader_params} should be Dict"
batch_size = overridden_loader_params.pop("batch_size", batch_size)
num_workers = overridden_loader_params.\
pop("num_workers", num_workers)
if per_gpu_scaling and not distributed:
num_gpus = max(1, torch.cuda.device_count())
batch_size *= num_gpus
num_workers *= num_gpus
loader_params = {
"batch_size": batch_size,
"num_workers": num_workers,
"pin_memory": torch.cuda.is_available(),
"drop_last": drop_last,
**overridden_loader_params
}
if isinstance(ds_, Dataset):
loader_params["dataset"] = ds_
elif isinstance(ds_, dict):
assert "dataset" in ds_, \
"You need to specify dataset for dataloader"
loader_params = utils.merge_dicts(ds_, loader_params)
else:
raise NotImplementedError
if distributed:
sampler = loader_params.get("sampler")
if sampler is not None:
assert isinstance(sampler, DistributedSampler)
else:
loader_params["sampler"] = DistributedSampler(
dataset=loader_params["dataset"]
)
loader_params["shuffle"] = (
name.startswith("train")
and loader_params.get("sampler") is None
)
if "batch_sampler" in loader_params:
if distributed:
raise ValueError(
"batch_sampler option is mutually "
"exclusive with distributed"
)
for k in ("batch_size", "shuffle", "sampler", "drop_last"):
loader_params.pop(k, None)
if "worker_init_fn" not in loader_params:
loader_params["worker_init_fn"] = \
lambda x: utils.set_global_seed(self.initial_seed + x)
loaders[name] = DataLoader(**loader_params)
return loaders
@staticmethod
def _get_callback(**params):
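        # ``_wrapper`` lets a callback config wrap another callback:
        # the inner callback is built first and passed to the wrapper
        # as ``base_callback`` (wrappers may be nested)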
wrapper_params = params.pop("_wrapper", None)
callback = CALLBACKS.get_from_params(**params)
if wrapper_params is not None:
wrapper_params["base_callback"] = callback
return ConfigExperiment._get_callback(**wrapper_params)
return callback
    def get_callbacks(self, stage: str) -> "OrderedDict[str, Callback]":
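        """
        Returns the callbacks for a given stage, adding the default loggers
        (verbose/console/tensorboard) and the exception callback when they
        are not already configured.
        """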
callbacks_params = (
self.stages_config[stage].get("callbacks_params", {})
)
callbacks = OrderedDict()
for key, callback_params in callbacks_params.items():
callback = self._get_callback(**callback_params)
callbacks[key] = callback
# ! For compatibility with previous versions.
default_callbacks = []
if self._verbose:
default_callbacks.append(("verbose", VerboseLogger))
if not stage.startswith("infer"):
default_callbacks.append(("console", ConsoleLogger))
default_callbacks.append(("tensorboard", TensorboardLogger))
default_callbacks.append(("exception", RaiseExceptionCallback))
for callback_name, callback_fn in default_callbacks:
is_already_present = any(
isinstance(x, callback_fn) for x in callbacks.values()
)
if not is_already_present:
callbacks[callback_name] = callback_fn()
return callbacks
__all__ = ["ConfigExperiment"]