Source code for catalyst.utils.checkpoint

from typing import Dict, Union  # isort:skip
from collections import OrderedDict
import os
from pathlib import Path
import shutil

import torch

from .ddp import get_real_module
from .misc import maybe_recursive_call


def pack_checkpoint(
    model=None, criterion=None, optimizer=None, scheduler=None, **kwargs
):
    """
    Collect the state dicts of training components into one checkpoint dict.

    Args:
        model: model whose (unwrapped, via ``get_real_module``) state dict is
            stored under ``"model_state_dict"``; skipped when ``None``.
            An ``OrderedDict`` is rejected.
        criterion: object with ``state_dict()``, or a dict of such objects.
        optimizer: object with ``state_dict()``, or a dict of such objects.
        scheduler: object with ``state_dict()``, or a dict of such objects.
            A single object is stored under ``"<name>_state_dict"``; each
            non-``None`` value of a dict is stored under
            ``"<name>_<str(key)>_state_dict"``. ``None`` components are skipped.
        **kwargs: extra entries placed verbatim into the checkpoint.

    Returns:
        dict: the checkpoint (the ``kwargs`` dict, extended with state dicts).

    Raises:
        NotImplementedError: if ``model`` is an ``OrderedDict``.
    """
    checkpoint = kwargs

    if isinstance(model, OrderedDict):
        raise NotImplementedError()
    if model is not None:
        # Mirror unpack_checkpoint: only touch the model when one is given,
        # instead of handing ``None`` to get_real_module/maybe_recursive_call.
        model_ = get_real_module(model)
        checkpoint["model_state_dict"] = maybe_recursive_call(
            model_, "state_dict"
        )

    components = (
        ("criterion", criterion),
        ("optimizer", optimizer),
        ("scheduler", scheduler),
    )
    # @TODO refactor with maybe_recursive_call
    for name, component in components:
        if component is None:
            continue
        if isinstance(component, dict):
            for key, value in component.items():
                if value is not None:
                    # str(key) defines the on-disk key naming for dict entries
                    checkpoint[f"{name}_{key!s}_state_dict"] = value.state_dict()
        else:
            checkpoint[f"{name}_state_dict"] = component.state_dict()

    return checkpoint
def save_checkpoint(
    checkpoint: Dict,
    logdir: Union[Path, str],
    suffix: str,
    is_best: bool = False,
    is_last: bool = False,
    special_suffix: str = ""
):
    """
    Serialize ``checkpoint`` to ``<logdir>/<suffix>.pth`` and optionally copy it.

    Args:
        checkpoint: object passed to ``torch.save``.
        logdir: target directory; created (with parents) if missing.
        suffix: base file name, without the ``.pth`` extension.
        is_best: also copy the file to ``<logdir>/best<special_suffix>.pth``.
        is_last: also copy the file to ``<logdir>/last<special_suffix>.pth``.
        special_suffix: string appended to the ``best``/``last`` copy names.

    Returns:
        str: path of the primary checkpoint file.
    """
    os.makedirs(logdir, exist_ok=True)

    checkpoint_path = f"{logdir}/{suffix}.pth"
    torch.save(checkpoint, checkpoint_path)

    # Copies are made from the freshly written file; "best" first, then "last".
    copy_targets = []
    if is_best:
        copy_targets.append(f"{logdir}/best{special_suffix}.pth")
    if is_last:
        copy_targets.append(f"{logdir}/last{special_suffix}.pth")
    for target in copy_targets:
        shutil.copyfile(checkpoint_path, target)

    return checkpoint_path
def unpack_checkpoint(
    checkpoint, model=None, criterion=None, optimizer=None, scheduler=None
):
    """
    Load state dicts from ``checkpoint`` back into the given components.

    Args:
        checkpoint: mapping of entry name to state dict, as built by
            ``pack_checkpoint``.
        model: model (unwrapped via ``get_real_module``) that receives
            ``checkpoint["model_state_dict"]``; skipped when ``None``.
        criterion: object with ``load_state_dict()``, or a dict of such objects.
        optimizer: object with ``load_state_dict()``, or a dict of such objects.
        scheduler: object with ``load_state_dict()``, or a dict of such objects.
            A single object reads ``"<name>_state_dict"``; each non-``None``
            value of a dict reads ``"<name>_<str(key)>_state_dict"``.
            ``None`` components are skipped.

    Raises:
        KeyError: (for a plain-dict checkpoint) if an expected entry is missing.
    """
    if model is not None:
        model = get_real_module(model)
        maybe_recursive_call(
            model, "load_state_dict", recursive_args=checkpoint["model_state_dict"]
        )

    components = (
        ("criterion", criterion),
        ("optimizer", optimizer),
        ("scheduler", scheduler),
    )
    for name, component in components:
        if component is None:
            continue
        if isinstance(component, dict):
            for key, value in component.items():
                if value is not None:
                    # Use str(key) (not format(key)) so the name matches the
                    # key that pack_checkpoint writes for dict entries.
                    value.load_state_dict(checkpoint[f"{name}_{key!s}_state_dict"])
        else:
            component.load_state_dict(checkpoint[f"{name}_state_dict"])
def load_checkpoint(filepath):
    """
    Deserialize a checkpoint written with ``torch.save``.

    Args:
        filepath: path to the checkpoint file.

    Returns:
        The deserialized object, with storages kept on CPU.
    """
    def _keep_storage(storage, location):
        # Return the storage untouched instead of moving it to its saved device.
        return storage

    return torch.load(filepath, map_location=_keep_storage)