Source code for catalyst.utils.checkpoint
from typing import Dict, Union # isort:skip
from collections import OrderedDict
import os
from pathlib import Path
import shutil
import torch
from .ddp import get_real_module
from .misc import maybe_recursive_call
def pack_checkpoint(
    model=None, criterion=None, optimizer=None, scheduler=None, **kwargs
):
    """Collect the state dicts of the training components into one dict.

    The ``kwargs`` dict itself is used as the checkpoint, so any extra
    keyword arguments end up stored next to the state dicts.

    Keys written:

    * ``model_state_dict`` - result of calling ``state_dict`` on the model
      through ``maybe_recursive_call`` (after ``get_real_module``);
    * ``<name>_state_dict`` - for a single criterion/optimizer/scheduler;
    * ``<name>_<key>_state_dict`` - for every non-``None`` value when the
      component is passed as a dict.

    Args:
        model: model to pack; an ``OrderedDict`` is not supported.
            NOTE(review): ``None`` is the default but is still forwarded to
            ``get_real_module`` / ``maybe_recursive_call`` - confirm how
            those helpers treat it.
        criterion: object with ``state_dict()``, a dict of such objects,
            or ``None`` to skip.
        optimizer: same convention as ``criterion``.
        scheduler: same convention as ``criterion``.
        **kwargs: additional entries to store in the checkpoint.

    Returns:
        dict: the packed checkpoint.

    Raises:
        NotImplementedError: if ``model`` is an ``OrderedDict``.
    """
    if isinstance(model, OrderedDict):
        raise NotImplementedError()

    packed = kwargs
    packed["model_state_dict"] = maybe_recursive_call(
        get_real_module(model), "state_dict"
    )

    components = (
        ("criterion", criterion),
        ("optimizer", optimizer),
        ("scheduler", scheduler),
    )
    # @TODO refactor with maybe_recursive_call
    for prefix, component in components:
        if component is None:
            continue

        if not isinstance(component, dict):
            packed[f"{prefix}_state_dict"] = component.state_dict()
            continue

        # A dict of components: one entry per non-None value.
        for key, item in component.items():
            if item is None:
                continue
            packed[f"{prefix}_{key!s}_state_dict"] = item.state_dict()

    return packed
def save_checkpoint(
    checkpoint: Dict,
    logdir: Union[Path, str],
    suffix: str,
    is_best: bool = False,
    is_last: bool = False,
    special_suffix: str = ""
):
    """Serialize ``checkpoint`` into ``logdir`` and refresh best/last copies.

    The checkpoint is written with ``torch.save`` to
    ``{logdir}/{suffix}.pth``; ``logdir`` is created if it does not exist.
    Depending on the flags the file is then copied to
    ``{logdir}/best{special_suffix}.pth`` and/or
    ``{logdir}/last{special_suffix}.pth``.

    Args:
        checkpoint: dict to serialize.
        logdir: target directory.
        suffix: file name (without extension) of the checkpoint.
        is_best: also store the checkpoint as the ``best`` copy.
        is_last: also store the checkpoint as the ``last`` copy.
        special_suffix: extra text appended to the ``best``/``last`` names.

    Returns:
        str: path of the written checkpoint file.
    """
    os.makedirs(logdir, exist_ok=True)
    filename = f"{logdir}/{suffix}.pth"
    torch.save(checkpoint, filename)

    copies = []
    if is_best:
        copies.append(f"{logdir}/best{special_suffix}.pth")
    if is_last:
        copies.append(f"{logdir}/last{special_suffix}.pth")

    for target in copies:
        try:
            shutil.copyfile(filename, target)
        except shutil.SameFileError:
            # ``suffix`` already names the best/last file (e.g. suffix="last"
            # with is_last=True): the checkpoint is in place, nothing to copy.
            pass

    return filename
def unpack_checkpoint(
    checkpoint, model=None, criterion=None, optimizer=None, scheduler=None
):
    """Restore component states from a checkpoint dict, in place.

    Components passed as ``None`` are skipped. The keys read from
    ``checkpoint`` are:

    * ``model_state_dict`` - forwarded as ``recursive_args`` to
      ``maybe_recursive_call(..., "load_state_dict")`` on the model
      returned by ``get_real_module``;
    * ``<name>_state_dict`` - for a single criterion/optimizer/scheduler;
    * ``<name>_<key>_state_dict`` - for every non-``None`` value when the
      component is passed as a dict.

    Args:
        checkpoint: mapping holding the saved state dicts.
        model: model to restore, or ``None`` to skip.
        criterion: object with ``load_state_dict``, a dict of such
            objects, or ``None`` to skip.
        optimizer: same convention as ``criterion``.
        scheduler: same convention as ``criterion``.
    """
    if model is not None:
        maybe_recursive_call(
            get_real_module(model),
            "load_state_dict",
            recursive_args=checkpoint["model_state_dict"],
        )

    components = {
        "criterion": criterion,
        "optimizer": optimizer,
        "scheduler": scheduler,
    }
    for prefix, component in components.items():
        if component is None:
            continue

        if not isinstance(component, dict):
            state_key = f"{prefix}_state_dict"
            component.load_state_dict(checkpoint[state_key])
            continue

        # A dict of components: each non-None value has its own entry.
        for key, item in component.items():
            if item is None:
                continue
            state_key = f"{prefix}_{key}_state_dict"
            item.load_state_dict(checkpoint[state_key])
def load_checkpoint(filepath):
    """Read a checkpoint written by ``torch.save`` and return it.

    Args:
        filepath: location of the checkpoint, passed straight to
            ``torch.load``.

    Returns:
        The deserialized checkpoint object.
    """

    def _keep_storage(storage, location):
        # Hand the storage back unchanged instead of moving it to the
        # device tag recorded in the file.
        return storage

    # NOTE(review): torch.load is pickle-based - only use it on checkpoint
    # files that come from a trusted source.
    return torch.load(filepath, map_location=_keep_storage)