
Source code for catalyst.dl.experiment.config

from typing import Any, Callable, Dict, List, Mapping, Union
from collections import OrderedDict
from copy import deepcopy

import torch
from torch import nn
from torch.utils.data import DataLoader  # noqa F401

from catalyst.core import _StageBasedExperiment
from catalyst.data import Augmentor, AugmentorCompose
from catalyst.dl import (
    Callback,
    CheckpointCallback,
    CheckRunCallback,
    ConsoleLogger,
    CriterionCallback,
    ExceptionCallback,
    MetricManagerCallback,
    OptimizerCallback,
    SchedulerCallback,
    TensorboardLogger,
    TimerCallback,
    utils,
    ValidationManagerCallback,
    VerboseLogger,
)
from catalyst.dl.registry import (
    CALLBACKS,
    CRITERIONS,
    MODELS,
    OPTIMIZERS,
    SCHEDULERS,
    TRANSFORMS,
)
from catalyst.tools.typing import Criterion, Model, Optimizer, Scheduler


class ConfigExperiment(_StageBasedExperiment):
    """
    Experiment created from a configuration file.
    """

    STAGE_KEYWORDS = [
        "criterion_params",
        "optimizer_params",
        "scheduler_params",
        "data_params",
        "transform_params",
        "state_params",
        "callbacks_params",
    ]
    def __init__(self, config: Dict):
        """
        Args:
            config (dict): dictionary of parameters
        """
        self._config: Dict = deepcopy(config)
        self._initial_seed: int = self._config.get("args", {}).get("seed", 42)
        self._verbose: bool = self._config.get("args", {}).get(
            "verbose", False
        )
        self._check_run: bool = self._config.get("args", {}).get(
            "check", False
        )
        self._check_time: bool = self._config.get("args", {}).get(
            "timeit", False
        )
        self.__prepare_logdir()

        self._config["stages"]["state_params"] = utils.merge_dicts(
            deepcopy(self._config["stages"].get("state_params", {})),
            deepcopy(self._config.get("args", {})),
            {"logdir": self._logdir},
        )
        self.stages_config: Dict = self._get_stages_config(
            self._config["stages"]
        )
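    # Illustrative sketch (an assumed, partial layout): the keys read by
    # ``__init__`` above come from a config dict along these lines. A real
    # Catalyst config usually carries more sections.
    #
    # config = {
    #     "args": {"seed": 42, "verbose": True, "check": False,
    #              "timeit": False, "logdir": "./logs/example"},
    #     "model_params": {...},   # consumed by get_model below
    #     "stages": {"state_params": {...}, "stage1": {...}},
    # }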
    def __prepare_logdir(self):
        EXCLUDE_TAG = "none"

        logdir = self._config.get("args", {}).get("logdir", None)
        baselogdir = self._config.get("args", {}).get("baselogdir", None)

        if logdir is not None and logdir.lower() != EXCLUDE_TAG:
            self._logdir = logdir
        elif baselogdir is not None and baselogdir.lower() != EXCLUDE_TAG:
            logdir_postfix = self._get_logdir(self._config)
            self._logdir = f"{baselogdir}/{logdir_postfix}"
        else:
            self._logdir = None

    def _get_stages_config(self, stages_config: Dict):
        stages_defaults = {}
        stages_config_out = OrderedDict()
        for key in self.STAGE_KEYWORDS:
            stages_defaults[key] = deepcopy(stages_config.get(key, {}))
        for stage in stages_config:
            if (
                stage in self.STAGE_KEYWORDS
                or stages_config.get(stage) is None
            ):
                continue
            stages_config_out[stage] = {}
            for key in self.STAGE_KEYWORDS:
                stages_config_out[stage][key] = utils.merge_dicts(
                    deepcopy(stages_defaults.get(key, {})),
                    deepcopy(stages_config[stage].get(key, {})),
                )

        return stages_config_out

    def _get_logdir(self, config: Dict) -> str:
        timestamp = utils.get_utcnow_time()
        config_hash = utils.get_short_hash(config)
        logdir = f"{timestamp}.{config_hash}"
        return logdir

    @property
    def initial_seed(self) -> int:
        """Experiment's initial seed value."""
        return self._initial_seed

    @property
    def logdir(self):
        """Path to the directory where the experiment writes its logs."""
        return self._logdir

    @property
    def stages(self) -> List[str]:
        """Experiment's stage names."""
        stages_keys = list(self.stages_config.keys())

        # @TODO: restore this feature
        # # Change the starting `stages_keys` if resume data was found
        # state_params = self.get_state_params(stages_keys[0])
        # resume, resume_dir = [
        #     state_params.get(key, None) for key in ["resume", "resume_dir"]
        # ]
        #
        # if resume_dir is not None:
        #     resume = resume_dir / str(resume)
        #
        # if resume is not None and Path(resume).is_file():
        #     checkpoint = utils.load_checkpoint(resume)
        #     start_stage = checkpoint["stage"]
        #     start_idx = stages_keys.index(start_stage)
        #     stages_keys = stages_keys[start_idx:]

        return stages_keys

    @property
    def distributed_params(self) -> Dict:
        """Dict with the parameters for distributed training and FP16."""
        return self._config.get("distributed_params", {})
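    # Illustrative sketch of the merge done by ``_get_stages_config`` above:
    # keys from ``STAGE_KEYWORDS`` placed directly under "stages" act as
    # shared defaults for every stage. Assuming ``utils.merge_dicts`` lets
    # later arguments take precedence, a fragment such as
    #
    # "stages": {
    #     "optimizer_params": {"optimizer": "Adam", "lr": 1e-3},  # default
    #     "stage1": {"optimizer_params": {"lr": 1e-2}},
    # }
    #
    # would yield stages_config["stage1"]["optimizer_params"] ==
    # {"optimizer": "Adam", "lr": 1e-2}.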
    def get_state_params(self, stage: str) -> Mapping[str, Any]:
        """Returns the state parameters for a given stage."""
        return self.stages_config[stage].get("state_params", {})
    def _preprocess_model_for_stage(self, stage: str, model: Model):
        stage_index = self.stages.index(stage)
        # @TODO: move this logic to callbacks
        if stage_index > 0:
            checkpoint_path = f"{self.logdir}/checkpoints/best.pth"
            checkpoint = utils.load_checkpoint(checkpoint_path)
            utils.unpack_checkpoint(checkpoint, model=model)
        return model

    def _postprocess_model_for_stage(self, stage: str, model: Model):
        return model

    @staticmethod
    def _get_model(**params):
        key_value_flag = params.pop("_key_value", False)

        if key_value_flag:
            model = {}
            for key, params_ in params.items():
                model[key] = ConfigExperiment._get_model(**params_)
            model = nn.ModuleDict(model)
        else:
            model = MODELS.get_from_params(**params)
        return model
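    # Illustrative sketch: with ``_key_value: True`` the "model_params"
    # section maps model names to their own parameters and ``_get_model``
    # wraps the result in an ``nn.ModuleDict``. The names and model classes
    # below are placeholders, not real registry entries.
    #
    # model_params = {
    #     "_key_value": True,
    #     "generator": {"model": "SomeGeneratorModel"},
    #     "discriminator": {"model": "SomeDiscriminatorModel"},
    # }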
    def get_model(self, stage: str):
        """Returns the model for a given stage."""
        model_params = self._config["model_params"]
        model = self._get_model(**model_params)
        model = self._preprocess_model_for_stage(stage, model)
        model = self._postprocess_model_for_stage(stage, model)
        return model
    @staticmethod
    def _get_criterion(**params):
        key_value_flag = params.pop("_key_value", False)

        if key_value_flag:
            criterion = {}
            for key, params_ in params.items():
                criterion[key] = ConfigExperiment._get_criterion(**params_)
        else:
            criterion = CRITERIONS.get_from_params(**params)
            if criterion is not None and torch.cuda.is_available():
                criterion = criterion.cuda()
        return criterion
    def get_criterion(self, stage: str) -> Criterion:
        """Returns the criterion for a given stage."""
        criterion_params = self.stages_config[stage].get(
            "criterion_params", {}
        )
        criterion = self._get_criterion(**criterion_params)
        return criterion
    def _get_optimizer(
        self, stage: str, model: Union[Model, Dict[str, Model]], **params
    ) -> Optimizer:
        # @TODO 1: refactoring; this method is too long
        # @TODO 2: load state dicts for schedulers & criterion
        layerwise_params = params.pop("layerwise_params", OrderedDict())
        no_bias_weight_decay = params.pop("no_bias_weight_decay", True)

        # linear scaling rule from https://arxiv.org/pdf/1706.02677.pdf
        lr_scaling_params = params.pop("lr_linear_scaling", None)
        if lr_scaling_params:
            data_params = dict(self.stages_config[stage]["data_params"])
            batch_size = data_params.get("batch_size")
            per_gpu_scaling = data_params.get("per_gpu_scaling", False)
            distributed_rank = utils.get_rank()
            distributed = distributed_rank > -1
            if per_gpu_scaling and not distributed:
                num_gpus = max(1, torch.cuda.device_count())
                batch_size *= num_gpus

            base_lr = lr_scaling_params.get("lr")
            base_batch_size = lr_scaling_params.get("base_batch_size", 256)
            lr_scaling = batch_size / base_batch_size
            params["lr"] = base_lr * lr_scaling  # scale default lr
        else:
            lr_scaling = 1.0

        # getting model parameters
        model_key = params.pop("_model", None)
        if model_key is None:
            assert isinstance(
                model, nn.Module
            ), "model is key-value, but optimizer has no specified model"
            model_params = utils.process_model_params(
                model, layerwise_params, no_bias_weight_decay, lr_scaling
            )
        elif isinstance(model_key, str):
            model_params = utils.process_model_params(
                model[model_key],
                layerwise_params,
                no_bias_weight_decay,
                lr_scaling,
            )
        elif isinstance(model_key, (list, tuple)):
            model_params = []
            for model_key_ in model_key:
                model_params_ = utils.process_model_params(
                    model[model_key_],
                    layerwise_params,
                    no_bias_weight_decay,
                    lr_scaling,
                )
                model_params.extend(model_params_)
        else:
            raise ValueError("unknown type of model_params")

        load_from_previous_stage = params.pop(
            "load_from_previous_stage", False
        )
        optimizer_key = params.pop("optimizer_key", None)
        optimizer = OPTIMIZERS.get_from_params(**params, params=model_params)

        if load_from_previous_stage and self.stages.index(stage) != 0:
            checkpoint_path = f"{self.logdir}/checkpoints/best_full.pth"
            checkpoint = utils.load_checkpoint(checkpoint_path)

            dict2load = optimizer
            if optimizer_key is not None:
                dict2load = {optimizer_key: optimizer}
            utils.unpack_checkpoint(checkpoint, optimizer=dict2load)

            # move optimizer to device
            device = utils.get_device()
            for param in model_params:
                param = param["params"][0]
                state = optimizer.state[param]
                for key, value in state.items():
                    state[key] = utils.any2device(value, device)

            # update optimizer params
            for key, value in params.items():
                for pg in optimizer.param_groups:
                    pg[key] = value

        return optimizer
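    # Illustrative sketch of the linear LR scaling handled above (values are
    # made up): with ``batch_size: 512`` in the stage's "data_params" and the
    # fragment below, the effective lr becomes 0.1 * (512 / 256) = 0.2.
    #
    # optimizer_params = {
    #     "optimizer": "SGD",
    #     "lr_linear_scaling": {"lr": 0.1, "base_batch_size": 256},
    #     "no_bias_weight_decay": True,
    #     "load_from_previous_stage": False,
    # }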
    def get_optimizer(
        self, stage: str, model: Union[Model, Dict[str, Model]]
    ) -> Union[Optimizer, Dict[str, Optimizer]]:
        """Returns the optimizer for a given stage.

        Args:
            stage (str): stage name
            model (Union[Model, Dict[str, Model]]): model or a dict of models
        """
        optimizer_params = self.stages_config[stage].get(
            "optimizer_params", {}
        )
        key_value_flag = optimizer_params.pop("_key_value", False)

        if key_value_flag:
            optimizer = {}
            for key, params_ in optimizer_params.items():
                # load specified optimizer from checkpoint
                optimizer_key = "optimizer_key"
                assert optimizer_key not in params_, "keyword reserved"
                params_[optimizer_key] = key

                optimizer[key] = self._get_optimizer(stage, model, **params_)
        else:
            optimizer = self._get_optimizer(stage, model, **optimizer_params)

        return optimizer
    @staticmethod
    def _get_scheduler(*, optimizer, **params):
        key_value_flag = params.pop("_key_value", False)

        if key_value_flag:
            scheduler = {}
            for key, params_ in params.items():
                scheduler[key] = ConfigExperiment._get_scheduler(
                    optimizer=optimizer, **params_
                )
        else:
            scheduler = SCHEDULERS.get_from_params(
                **params, optimizer=optimizer
            )
        return scheduler
    def get_scheduler(self, stage: str, optimizer: Optimizer) -> Scheduler:
        """Returns the scheduler for a given stage."""
        scheduler_params = self.stages_config[stage].get(
            "scheduler_params", {}
        )
        scheduler = self._get_scheduler(
            optimizer=optimizer, **scheduler_params
        )
        return scheduler
    @staticmethod
    def _get_transform(**params) -> Callable:
        key_value_flag = params.pop("_key_value", False)

        if key_value_flag:
            transforms_composition = {
                key: ConfigExperiment._get_transform(**params_)
                for key, params_ in params.items()
            }

            transform = AugmentorCompose(
                {
                    key: Augmentor(
                        dict_key=key,
                        augment_fn=transform,
                        input_key=key,
                        output_key=key,
                    )
                    for key, transform in transforms_composition.items()
                }
            )
        else:
            if "transforms" in params:
                transforms_composition = [
                    ConfigExperiment._get_transform(**transform_params)
                    for transform_params in params["transforms"]
                ]
                params.update(transforms=transforms_composition)

            transform = TRANSFORMS.get_from_params(**params)

        return transform
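    # Illustrative sketch (transform names are placeholders, not guaranteed
    # registry keys): a nested "transforms" list is resolved recursively,
    # each entry being built through the TRANSFORMS registry before the
    # outer transform receives them via ``transforms=...``.
    #
    # transform_params = {
    #     "transform": "SomeCompose",
    #     "transforms": [
    #         {"transform": "SomeFlip", "p": 0.5},
    #         {"transform": "SomeNormalize"},
    #     ],
    # }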
    def get_transforms(
        self, stage: str = None, dataset: str = None
    ) -> Callable:
        """Returns the transform for a given stage and dataset.

        Args:
            stage (str): stage name
            dataset (str): dataset name (e.g. "train", "valid"),
                used only if the value of ``_key_value`` is ``True``
        """
        transform_params = deepcopy(
            self.stages_config[stage].get("transform_params", {})
        )

        key_value_flag = transform_params.pop("_key_value", False)
        if key_value_flag:
            transform_params = transform_params.get(dataset, {})

        transform = self._get_transform(**transform_params)
        if transform is None:

            def transform(dict_):
                return dict_

        elif not isinstance(transform, AugmentorCompose):
            transform_ = transform

            def transform(dict_):
                return transform_(**dict_)

        return transform
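    # Illustrative sketch: with ``_key_value: True`` the "transform_params"
    # section is keyed by dataset name, so ``get_transforms(stage, "train")``
    # and ``get_transforms(stage, "valid")`` can return different pipelines
    # (dataset names and transforms here are assumptions).
    #
    # transform_params = {
    #     "_key_value": True,
    #     "train": {"transform": "SomeTrainTransform"},
    #     "valid": {"transform": "SomeValidTransform"},
    # }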
    def get_loaders(
        self, stage: str, epoch: int = None,
    ) -> "OrderedDict[str, DataLoader]":
        """Returns the loaders for a given stage."""
        data_params = dict(self.stages_config[stage]["data_params"])
        loaders = utils.get_loaders_from_params(
            get_datasets_fn=self.get_datasets,
            initial_seed=self.initial_seed,
            stage=stage,
            **data_params,
        )
        return loaders
    @staticmethod
    def _get_callback(**params):
        wrapper_params = params.pop("_wrapper", None)
        callback = CALLBACKS.get_from_params(**params)
        if wrapper_params is not None:
            wrapper_params["base_callback"] = callback
            return ConfigExperiment._get_callback(**wrapper_params)
        return callback
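    # Illustrative sketch (callback names are placeholders): a "_wrapper"
    # section makes ``_get_callback`` build the inner callback first and then
    # pass it to the wrapper through the ``base_callback`` keyword.
    #
    # callbacks_params = {
    #     "some_callback": {
    #         "callback": "SomeInnerCallback",
    #         "_wrapper": {"callback": "SomeWrapperCallback"},
    #     },
    # }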
    def get_callbacks(self, stage: str) -> "OrderedDict[Callback]":
        """Returns the callbacks for a given stage."""
        callbacks_params = self.stages_config[stage].get(
            "callbacks_params", {}
        )

        callbacks = OrderedDict()
        for key, callback_params in callbacks_params.items():
            callback = self._get_callback(**callback_params)
            callbacks[key] = callback

        default_callbacks = []
        if self._verbose:
            default_callbacks.append(("_verbose", VerboseLogger))
        if self._check_time:
            default_callbacks.append(("_timer", TimerCallback))
        if self._check_run:
            default_callbacks.append(("_check", CheckRunCallback))

        if not stage.startswith("infer"):
            default_callbacks.append(("_metrics", MetricManagerCallback))
            default_callbacks.append(
                ("_validation", ValidationManagerCallback)
            )
            default_callbacks.append(("_console", ConsoleLogger))

            if self.logdir is not None:
                default_callbacks.append(("_saver", CheckpointCallback))
                default_callbacks.append(("_tensorboard", TensorboardLogger))

            if self.stages_config[stage].get("criterion_params", {}):
                default_callbacks.append(("_criterion", CriterionCallback))
            if self.stages_config[stage].get("optimizer_params", {}):
                default_callbacks.append(("_optimizer", OptimizerCallback))
            if self.stages_config[stage].get("scheduler_params", {}):
                default_callbacks.append(("_scheduler", SchedulerCallback))

        default_callbacks.append(("_exception", ExceptionCallback))

        for callback_name, callback_fn in default_callbacks:
            is_already_present = False
            for x in callbacks.values():
                if isinstance(x, callback_fn):
                    is_already_present = True
                    break
            if not is_already_present:
                callbacks[callback_name] = callback_fn()

        return callbacks
__all__ = ["ConfigExperiment"]
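
A minimal usage sketch (illustrative, not part of the module above): the config fills only the keys that ``ConfigExperiment.__init__`` and the ``stages`` property actually read, "SomeRegisteredModel" is a placeholder model name, and ``get_datasets`` (which ``get_loaders`` expects) is stubbed out because this module leaves it to subclasses.

from collections import OrderedDict

from catalyst.dl.experiment.config import ConfigExperiment


class ExampleExperiment(ConfigExperiment):
    # get_datasets is not defined in the module above; a stub is enough here
    # because this sketch never calls get_loaders.
    def get_datasets(self, stage: str, **kwargs):
        return OrderedDict()


config = {
    "args": {"seed": 42, "verbose": True, "logdir": "./logs/example"},
    "model_params": {"model": "SomeRegisteredModel"},  # placeholder name
    "stages": {
        "state_params": {"num_epochs": 2},
        "stage1": {
            "data_params": {"batch_size": 32, "num_workers": 0},
            "optimizer_params": {"optimizer": "Adam", "lr": 1e-3},
        },
    },
}

experiment = ExampleExperiment(config)
assert experiment.stages == ["stage1"]  # STAGE_KEYWORDS entries are skipped
assert experiment.initial_seed == 42    # taken from config["args"]["seed"]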