# Shortcuts
#
# Source code for catalyst.utils.swa

from typing import List, Union
from collections import OrderedDict
import glob
import os
from pathlib import Path

import torch


def _load_weights(path: str) -> dict:
    """
    Read model weights from a checkpoint file.

    The file is deserialized with ``torch.load``. If the loaded object
    contains a ``"model_state_dict"`` entry, that entry is returned;
    otherwise the loaded object itself is returned.

    Args:
        path: Path to model weights

    Returns:
        Weights
    """

    def _keep_storage(storage, location):
        # Hand the deserialized storage back untouched instead of
        # remapping it to the device recorded in the checkpoint.
        return storage

    # NOTE(review): ``torch.load`` unpickles the file; only use it on
    # checkpoints from a trusted source.
    checkpoint = torch.load(path, map_location=_keep_storage)
    if "model_state_dict" not in checkpoint:
        return checkpoint
    return checkpoint["model_state_dict"]


def average_weights(state_dicts: List[dict]) -> OrderedDict:
    """
    Averaging of input weights.

    Every state dict must expose the same parameter names in the same
    order; the result holds, for each name, the element-wise mean of the
    corresponding values across all state dicts.

    Args:
        state_dicts: Weights to average

    Raises:
        IndexError: If ``state_dicts`` is empty
        KeyError: If states do not match

    Returns:
        Averaged weights
    """
    # source https://gist.github.com/qubvel/70c3d5e4cddcde731408f478e12ef87b
    if not state_dicts:
        # Indexing the first state dict below would fail with a bare
        # "list index out of range"; raise the same exception type with
        # a message that explains what went wrong.
        raise IndexError("Cannot average weights: no state dicts were given")

    # The first state dict defines the expected parameter names and order.
    params_keys = list(state_dicts[0].keys())
    for i, state_dict in enumerate(state_dicts):
        model_params_keys = list(state_dict.keys())
        if params_keys != model_params_keys:
            raise KeyError(
                "For checkpoint {}, expected list of params: {}, "
                "but found: {}".format(i, params_keys, model_params_keys)
            )

    num_states = len(state_dicts)
    average_dict = OrderedDict()
    for key in params_keys:
        average_dict[key] = torch.div(
            sum(state_dict[key] for state_dict in state_dicts), num_states,
        )
    return average_dict


def get_averaged_weights_by_path_mask(
    path_mask: str, logdir: Union[str, Path, None] = None,
) -> OrderedDict:
    """
    Average the weights of all checkpoints matching a glob pattern.

    Args:
        path_mask: glob-like pattern for models to average
        logdir: Path to logs directory. When given, ``path_mask`` is
            looked up inside ``<logdir>/checkpoints``; otherwise
            ``path_mask`` is used as the full pattern.

    Raises:
        IndexError: If no file matches the pattern

    Returns:
        Averaged weights
    """
    if logdir is None:
        pattern = path_mask
    else:
        pattern = os.path.join(logdir, "checkpoints", path_mask)

    # glob yields matches in arbitrary order; sorting makes the order of
    # the floating-point summation (and so the result) reproducible.
    models_pathes = sorted(glob.glob(pattern))
    if not models_pathes:
        # Averaging an empty list fails with an IndexError anyway; raise
        # the same exception type with a message that names the pattern.
        raise IndexError(
            "No checkpoints found for pattern: {}".format(pattern)
        )

    all_weights = [_load_weights(path) for path in models_pathes]
    return average_weights(all_weights)


# Public API of this module: the names exported by a star-import.
__all__ = ["average_weights", "get_averaged_weights_by_path_mask"]