from typing import Dict, Union
from collections import OrderedDict
import os
from pathlib import Path
import shutil
from catalyst.core.callback import Callback, CallbackNode, CallbackOrder
from catalyst.core.runner import IRunner
from catalyst.tools.metric_handler import MetricHandler
from catalyst.utils.config import save_config
def _save_checkpoint(
checkpoint: Dict,
runner: "IRunner",
logdir: Union[Path, str],
suffix: str,
is_best: bool = False,
is_last: bool = False,
extra_suffix: str = "",
) -> Union[Path, str]:
"""Saving checkpoint to a file.
Args:
checkpoint: data to save.
runner: current runner
logdir: directory where checkpoint should be stored.
suffix: checkpoint file name.
is_best: if ``True`` then also will be generated best checkpoint file.
is_last: if ``True`` then also will be generated last checkpoint file.
extra_suffix: suffix to use for saving best/last checkpoints.
Returns:
path to saved checkpoint
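
    Example (a sketch; the paths and the prepared ``checkpoint``/``runner``
    objects are assumptions for illustration):

    .. code-block:: python

        filename = _save_checkpoint(
            checkpoint=checkpoint,
            runner=runner,
            logdir="logs/checkpoints",
            suffix="train.3",
            is_best=True,
            is_last=True,
        )
        # writes logs/checkpoints/train.3.pth and copies it to
        # logs/checkpoints/best.pth and logs/checkpoints/last.pth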
"""
os.makedirs(logdir, exist_ok=True)
filename = f"{logdir}/{suffix}.pth"
runner.engine.save_checkpoint(checkpoint, filename)
if is_best:
shutil.copyfile(filename, f"{logdir}/best{extra_suffix}.pth")
if is_last:
shutil.copyfile(filename, f"{logdir}/last{extra_suffix}.pth")
return filename
def _load_checkpoint(*, filename, runner: "IRunner", load_full: bool = True) -> None:
"""
    Load a checkpoint from a file.

    Args:
        filename: path to checkpoint
        runner: current runner
        load_full: if ``True`` (default), the criterion, optimizer and
            scheduler states are loaded as well.
            The file should contain the keys required for
            loading the model (``'model_state_dict'``),
            criterion (``'criterion_state_dict'``, full load only),
            optimizer (``'optimizer_state_dict'``)
            and scheduler (``'scheduler_state_dict'``).

    Raises:
        FileNotFoundError: when the file specified in ``filename``
            does not exist.
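
    Example (a sketch; assumes a full checkpoint file exists on disk):

    .. code-block:: python

        _load_checkpoint(
            filename="logs/checkpoints/last_full.pth",
            runner=runner,
            load_full=True,  # also restores criterion/optimizer/scheduler and step counters
        )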
"""
is_master_process = runner.engine.is_master_process
if not os.path.isfile(filename):
raise FileNotFoundError(f"No checkpoint found at {filename}!")
if is_master_process:
print(f"=> Loading checkpoint {filename}")
checkpoint = runner.engine.load_checkpoint(filename)
if not runner.stage_key.startswith("infer") and load_full:
runner.global_epoch_step = checkpoint["global_epoch_step"]
runner.global_batch_step = checkpoint["global_batch_step"]
runner.global_sample_step = checkpoint["global_sample_step"]
if load_full:
runner.engine.unpack_checkpoint(
checkpoint,
model=runner.model,
criterion=runner.criterion,
optimizer=runner.optimizer,
scheduler=runner.scheduler,
)
if is_master_process:
print(
f"full checkpoint {filename} loaded "
f"(global epoch {checkpoint['global_epoch_step']}, "
f"stage {checkpoint['stage_key']}, "
f"epoch {checkpoint['stage_epoch_step']})"
)
else:
runner.engine.unpack_checkpoint(checkpoint, model=runner.model)
if is_master_process:
print(
f"model checkpoint {filename} loaded "
f"(global epoch {checkpoint['global_epoch_step']}, "
f"stage {checkpoint['stage_key']}, "
f"epoch {checkpoint['stage_epoch_step']})"
)
def _get_required_files(logdir: str, load_map: Dict[str, str]) -> Dict[str, str]:
"""
    Generate the mapping of files required to load the model, criterion,
    scheduler and optimizer parts specified in ``load_map``.

    ``load_map`` is expected to contain the keys
    ``"model"``, ``"criterion"``, ``"optimizer"``, ``"scheduler"``;
    otherwise an empty dict will be generated.

    Args:
        logdir: directory with logs
        load_map (Dict[str, str]): dict specifying what should be loaded

    Returns:
        Mapping from checkpoint file to the parts to load from it.
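
    Example (illustrative paths):

    .. code-block:: python

        _get_required_files("logs", {"model": "best", "optimizer": "last_full"})
        # OrderedDict([
        #     ("logs/best_full.pth", ["model"]),
        #     ("logs/last_full.pth", ["optimizer"]),
        # ])
        # ``model`` is redirected to the full checkpoint because
        # another part is requested alongside it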
"""
if load_map is None:
return OrderedDict()
default_states = {"best", "best_full", "last", "last_full"}
required_full_checkpoint = ["criterion", "optimizer", "scheduler"]
steps = ["global_epoch_step", "global_batch_step", "global_sample_step"]
experiment_parts = ["model"] + required_full_checkpoint + steps
# keep required parts
experiment_parts = list(filter(lambda part: part in load_map, experiment_parts))
# avoid unnecessary loading
if "model" in experiment_parts and len(experiment_parts) > 1:
required_full_checkpoint.append("model")
# mapping - <filename>: <list of parts to load from this file>
required_files = OrderedDict()
for part in experiment_parts:
fname = load_map[part]
required_full = fname.endswith("_full")
# specified default state
if fname in default_states:
if part in required_full_checkpoint and not required_full:
fname = fname + "_full"
fname = f"{logdir}/{fname}.pth"
# in other case specified path to checkpoint
required_files[fname] = required_files.get(fname, []) + [part]
return required_files
def _load_states_from_file_map(
*, logdir: str, runner: "IRunner", load_map: Dict[str, str]
) -> None:
"""
    Load the states of the model, criterion, optimizer and scheduler
    from the files specified in ``load_map``.

    Args:
        logdir: directory with logs
        runner: current runner
        load_map (Dict[str, str]): dict with mappings to load.
            Expected keys - ``'model'``, ``'criterion'``,
            ``'optimizer'``, ``'scheduler'``; other keys will be
            ignored.
            Values are expected to be states (``'best'``,
            ``"best_full"``, ``"last"``, ``"last_full"``) or
            a path to a checkpoint file.

            **NOTE:** loading criterion, optimizer and scheduler states
            requires a full checkpoint.

    Raises:
        FileNotFoundError: when a file/state specified in ``load_map``
            does not exist.
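
    Example (a sketch; assumes the referenced checkpoints exist in ``logdir``):

    .. code-block:: python

        _load_states_from_file_map(
            logdir="logs/checkpoints",
            runner=runner,
            load_map={
                "model": "best",           # model weights can come from any checkpoint
                "optimizer": "last_full",  # optimizer state requires a full checkpoint
            },
        )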
"""
required_files = _get_required_files(logdir, load_map)
for filename in required_files.keys():
if not os.path.isfile(filename):
raise FileNotFoundError(f"No checkpoint found at {filename}!")
# extracting parts from files
for filename, parts_to_load in required_files.items():
print(f"=> Loading {', '.join(parts_to_load)} from {filename}")
checkpoint = runner.engine.load_checkpoint(filename)
to_unpack = {part: getattr(runner, part) for part in parts_to_load}
runner.engine.unpack_checkpoint(checkpoint, **to_unpack)
# hotfix
if "global_epoch_step" in to_unpack:
runner.global_epoch_step = checkpoint["global_epoch_step"]
if "global_batch_step" in to_unpack:
runner.global_batch_step = checkpoint["global_batch_step"]
if "global_sample_step" in to_unpack:
runner.global_sample_step = checkpoint["global_sample_step"]
print(f" loaded: {', '.join(parts_to_load)}")
def _load_runner(
logdir: str, runner: "IRunner", mapping: Union[str, Dict[str, str]], load_full: bool = False
) -> None:
"""
    Selects a loading method based on the type of ``mapping``.

    Args:
        logdir: logdir with checkpoints
        runner: current runner
        mapping: mapping to use for loading
        load_full: whether to load the full checkpoint (criterion, optimizer
            and scheduler states included); used only when ``mapping`` is a string
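
    Example (a sketch; assumes checkpoints exist in ``logdir``):

    .. code-block:: python

        # string mapping: a known state name (or a path) is resolved to a checkpoint file
        _load_runner(logdir="logs", runner=runner, mapping="best_full", load_full=True)
        # dict mapping: delegates to ``_load_states_from_file_map``
        _load_runner(logdir="logs", runner=runner, mapping={"model": "best"})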
"""
if isinstance(mapping, str):
if mapping in {"best", "best_full", "last", "last_full"}:
checkpoint = f"{logdir}/{mapping}.pth"
else:
checkpoint = mapping
_load_checkpoint(filename=checkpoint, runner=runner, load_full=load_full)
elif isinstance(mapping, dict):
_load_states_from_file_map(logdir=logdir, runner=runner, load_map=mapping)
class ICheckpointCallback(Callback):
"""Criterion callback interface, abstraction over checkpoint step."""
pass
class CheckpointCallback(ICheckpointCallback):
"""Checkpoint callback to save/restore your model/criterion/optimizer/scheduler.
Args:
logdir: directory to store chekpoints
loader_key: loader key for best model selection (based on metric score over the dataset)
metric_key: metric key for best model selection (based on metric score over the dataset)
minimize: boolean flag to minimize the required metric
min_delta: minimal delta for metric improve
save_n_best: number of best checkpoint to keep,
if ``0`` then store only last state of model and
``load_on_stage_end`` should be one of
``last`` or ``last_full``.
load_on_stage_start (str or Dict[str, str]): load specified
state/model at stage start.
If passed **string** then will be performed initialization from
specified state (``best``/``best_full``/``last``/``last_full``)
or checkpoint file.
If passed **dict** then will be performed initialization only
for specified parts - model, criterion, optimizer, scheduler.
Example:
>>> # possible checkpoints to use:
>>> # "best"/"best_full"/"last"/"last_full"
>>> # or path to specific checkpoint
>>> to_load = {
>>> "model": "path/to/checkpoint.pth",
>>> "criterion": "best",
>>> "optimizer": "last_full",
>>> "scheduler": "best_full",
>>> }
>>> CheckpointCallback(load_on_stage_start=to_load)
            All keys other than ``"model"``, ``"criterion"``,
            ``"optimizer"`` and ``"scheduler"`` will be ignored.
            If ``None``, an empty dict, or a dict without the keys
            mentioned above is passed, no action is taken at stage start and:

            - Config API - the best model state will be used
            - Notebook API - no action will be performed (the last state will be used)

            **NOTE:** Loading is performed on all stages except the first.

            **NOTE:** Criterion, optimizer and scheduler are optional keys
            and must be loaded from a full checkpoint.
            Model state can be loaded from any checkpoint.
            When the dict contains keys for the model and some other part
            (for example ``{"model": "last", "optimizer": "last"}``)
            and the state names share a prefix (such as ``"best"`` and
            ``"best_full"``), the full checkpoint will be loaded
            because it contains the required states.
        load_on_stage_end (str or Dict[str, str]): load the specified
            state/model at stage end.
            If a **string** is passed, the runner is initialized from the
            specified state (``best``/``best_full``/``last``/``last_full``)
            or checkpoint file.
            If a **dict** is passed, only the specified parts
            (model, criterion, optimizer, scheduler) are initialized.
            The dict logic is the same as for ``load_on_stage_start``.
            If ``None``, no action is taken at stage end
            and the last runner state is kept.

            **NOTE:** Loading is always performed at stage end.
        metrics_filename: filename for saving metrics
            in the checkpoint folder;
            must end with ``.json`` or ``.yml``
        mode: checkpointing mode, one of ``all``, ``full``, ``model``
        use_logdir_postfix: boolean flag to append a ``checkpoints`` postfix to the logdir
        use_runner_logdir: boolean flag to use ``runner._logdir`` as the logdir
Examples:
.. code-block:: python
import torch
from torch.utils.data import DataLoader, TensorDataset
from catalyst import dl
# sample data
num_users, num_features, num_items = int(1e4), int(1e1), 10
X = torch.rand(num_users, num_features)
y = (torch.rand(num_users, num_items) > 0.5).to(torch.float32)
# pytorch loaders
dataset = TensorDataset(X, y)
loader = DataLoader(dataset, batch_size=32, num_workers=1)
loaders = {"train": loader, "valid": loader}
# model, criterion, optimizer, scheduler
model = torch.nn.Linear(num_features, num_items)
criterion = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters())
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [2])
# model training
runner = dl.SupervisedRunner(
input_key="features", output_key="logits", target_key="targets", loss_key="loss"
)
runner.train(
model=model,
criterion=criterion,
optimizer=optimizer,
scheduler=scheduler,
loaders=loaders,
num_epochs=3,
verbose=True,
callbacks=[
dl.BatchTransformCallback(
transform=torch.sigmoid,
scope="on_batch_end",
input_key="logits",
output_key="scores"
),
dl.CriterionCallback(
input_key="logits", target_key="targets", metric_key="loss"
),
dl.AUCCallback(input_key="scores", target_key="targets"),
dl.HitrateCallback(
input_key="scores", target_key="targets", topk_args=(1, 3, 5)
),
dl.MRRCallback(input_key="scores", target_key="targets", topk_args=(1, 3, 5)),
dl.MAPCallback(input_key="scores", target_key="targets", topk_args=(1, 3, 5)),
dl.NDCGCallback(input_key="scores", target_key="targets", topk_args=(1, 3, 5)),
dl.OptimizerCallback(metric_key="loss"),
dl.SchedulerCallback(),
dl.CheckpointCallback(
logdir="./logs", loader_key="valid", metric_key="loss", minimize=True
),
]
)
.. note::
Please follow the `minimal examples`_ sections for more use cases.
.. _`minimal examples`: https://github.com/catalyst-team/catalyst#minimal-examples
"""
    def __init__(
self,
logdir: str = None,
# model selection info
loader_key: str = None,
metric_key: str = None,
minimize: bool = None,
min_delta: float = 1e-6,
save_n_best: int = 1,
# loading info
load_on_stage_start: Union[str, Dict[str, str]] = None,
load_on_stage_end: Union[str, Dict[str, str]] = None,
# resume: str = None,
# resume_dir: str = None,
# checkpointer info
metrics_filename: str = "_metrics.json",
mode: str = "all",
use_logdir_postfix: bool = False,
use_runner_logdir: bool = False,
):
"""Init."""
super().__init__(order=CallbackOrder.external, node=CallbackNode.all)
possible_states = {
None,
"best",
"last",
"best_full",
"last_full",
}
assert save_n_best >= 0
if save_n_best == 0:
assert load_on_stage_end in (None, "last", "last_full")
if isinstance(load_on_stage_start, str):
assert load_on_stage_start in possible_states
if isinstance(load_on_stage_end, str):
assert load_on_stage_end in possible_states
# if resume_dir is not None:
# assert resume is not None
if loader_key is not None or metric_key is not None:
assert loader_key is not None and metric_key is not None, (
"For checkpoint selection `CheckpointCallback` "
"requires both `loader_key` and `metric_key` specified."
)
self._use_model_selection = True
self.minimize = minimize if minimize is not None else True # loss-oriented selection
else:
self._use_model_selection = False
self.minimize = False # epoch-num-oriented selection
assert mode in (
"all",
"full",
"model",
), "`CheckpointCallback` could work only in `all`, `full` or `model` modes."
# checkpointer info
self.logdir = logdir
self.mode = mode
self.metrics_filename = metrics_filename
self.use_logdir_postfix = use_logdir_postfix
self.use_runner_logdir = use_runner_logdir
assert (
self.logdir is not None or self.use_runner_logdir
), "CheckpointCallback requires specified `logdir`"
# model selection info
self.loader_key = loader_key
self.metric_key = metric_key
        self.is_better = MetricHandler(minimize=self.minimize, min_delta=min_delta)
self.save_n_best = save_n_best
# list with topN metrics [(score, filepath, stage_key, stage_epoch_step, epoch metrics)]
self.top_best_metrics = []
self.best_score = None
# loading info
self.load_on_stage_start = load_on_stage_start
self.load_on_stage_end = load_on_stage_end
# self.resume = resume
# self.resume_dir = resume_dir
def _pack_checkpoint(self, runner: "IRunner"):
checkpoint = runner.engine.pack_checkpoint(
model=runner.model,
criterion=runner.criterion,
optimizer=runner.optimizer,
scheduler=runner.scheduler,
# experiment info
run_key=runner.run_key,
global_epoch_step=runner.global_epoch_step,
global_batch_step=runner.global_batch_step,
global_sample_step=runner.global_sample_step,
# stage info
stage_key=runner.stage_key,
stage_epoch_step=runner.stage_epoch_step,
stage_batch_step=runner.stage_batch_step,
stage_sample_step=runner.stage_sample_step,
# epoch info
epoch_metrics={k: dict(v) for k, v in runner.epoch_metrics.items()},
# loader info
loader_key=runner.loader_key,
loader_batch_step=runner.loader_batch_step,
loader_sample_step=runner.loader_sample_step,
# checkpointer info
checkpointer_loader_key=self.loader_key,
checkpointer_metric_key=self.metric_key,
checkpointer_minimize=self.minimize,
)
return checkpoint
def _save_checkpoint(
self, runner: IRunner, checkpoint: Dict, is_best: bool, is_last: bool
) -> str:
"""
        Saves two kinds of checkpoints, depending on ``self.mode``: a full one
        (model/criterion/optimizer/scheduler) and a truncated one (model only).

        Args:
            runner: current runner.
            checkpoint: data to save.
            is_best: if ``True``, the checkpoint is also copied to the best checkpoint file.
            is_last: if ``True``, the checkpoint is also copied to the last checkpoint file.

        Returns:
            path to the saved checkpoint
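
        Example of produced files (a sketch for ``mode="all"``, stage ``"train"``,
        ``stage_epoch_step=3``, ``is_best=True`` and ``is_last=True``):
        ``train.3_full.pth``, ``best_full.pth``, ``last_full.pth`` (full checkpoints) and
        ``train.3.pth``, ``best.pth``, ``last.pth`` (model-only checkpoints).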
"""
logdir = Path(f"{self.logdir}/")
suffix = f"{runner.stage_key}.{runner.stage_epoch_step}"
checkpoint_path = None
if self.mode in ("all", "full"):
checkpoint_path = _save_checkpoint(
runner=runner,
logdir=logdir,
checkpoint=checkpoint,
suffix=f"{suffix}_full",
is_best=is_best,
is_last=is_last,
extra_suffix="_full",
)
if self.mode in ("all", "model"):
exclude = ["criterion", "optimizer", "scheduler"]
checkpoint_path = _save_checkpoint(
runner=runner,
checkpoint={
key: value
for key, value in checkpoint.items()
if all(z not in key for z in exclude)
},
logdir=logdir,
suffix=suffix,
is_best=is_best,
is_last=is_last,
)
return checkpoint_path
def _truncate_checkpoints(self) -> None:
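        """Keeps at most ``save_n_best`` tracked checkpoints: sorts them by score
        (best first) and removes the files of the worst one once the limit is exceeded."""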
self.top_best_metrics = sorted(
self.top_best_metrics, key=lambda x: x[0], reverse=not self.minimize
)
if len(self.top_best_metrics) > self.save_n_best:
last_item = self.top_best_metrics.pop(-1)
last_filepath = Path(last_item[1])
last_filepaths = last_filepath.parent.glob(last_filepath.name.replace(".pth", "*"))
for filepath in last_filepaths:
os.remove(filepath)
def _prepare_metrics_log(self, last_epoch_score: float, last_epoch_metrics: Dict) -> Dict:
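        """Builds the metrics mapping saved to ``metrics_filename``: a ``last`` entry plus,
        when ``save_n_best > 0``, ``best`` and per-checkpoint entries, each with a ``_score_``."""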
top_best_checkpoints = [
(Path(filepath).stem, {**epoch_metrics, **{"_score_": score}})
for (score, filepath, _, _, epoch_metrics) in self.top_best_metrics
]
if self.save_n_best > 0:
            best_epoch_score = self.top_best_metrics[0][0]
            best_epoch_metrics = top_best_checkpoints[0][1]
metrics = [
("best", {**best_epoch_metrics, **{"_score_": best_epoch_score}}),
("last", {**last_epoch_metrics, **{"_score_": last_epoch_score}}),
] + top_best_checkpoints
else:
metrics = [("last", {**last_epoch_metrics, **{"_score_": last_epoch_score}})]
return OrderedDict(metrics)
def on_stage_start(self, runner: "IRunner") -> None:
"""Setup model for stage.
.. note::
If CheckpointCallback initialized with
``resume`` (as path to checkpoint file)
or ``resume`` (as filename)
and ``resume_dir`` (as directory with file)
then will be performed loading checkpoint.
Raises:
FileNotFoundError: if specified load_on_stage_start
but checkpoint file is missing.
Args:
runner: current runner
"""
if runner.is_infer_stage:
return
# @TODO: very tricky hack, should be removed
if self.logdir is None and self.use_runner_logdir:
self.logdir = getattr(runner, "_logdir", None)
if self.use_logdir_postfix:
self.logdir = os.path.join(self.logdir, "checkpoints")
# @TODO:
# # Use a barrier() to make sure that process 1 loads the model after process 0 saves it.
# dist.barrier()
# # configure map_location properly
# map_location = {"cuda:%d" % 0: "cuda:%d" % rank}
# ddp_model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=map_location))
# Use a barrier() to make sure that all processes have finished reading the checkpoint
# dist.barrier()
is_first_stage = list(runner.stages).index(runner.stage_key) == 0
if self.load_on_stage_start is not None and not is_first_stage:
need_full = False
file_exists = False
if isinstance(self.load_on_stage_start, str):
need_full = self.load_on_stage_start.endswith("full")
use_file = os.path.join(self.logdir, f"{self.load_on_stage_start}.pth")
file_exists = os.path.isfile(use_file)
if not file_exists:
raise FileNotFoundError(f"Missing file '{use_file}'!") # noqa: F821
elif isinstance(self.load_on_stage_start, dict):
required_files = _get_required_files(self.logdir, self.load_on_stage_start).keys()
file_exists = True
for use_file in required_files:
if not os.path.isfile(use_file):
file_exists = False
raise FileNotFoundError(f"Missing file '{use_file}'!")
if self.load_on_stage_start is not None and file_exists:
_load_runner(
logdir=self.logdir,
runner=runner,
mapping=self.load_on_stage_start,
load_full=need_full,
)
# if getattr(runner, "resume", None) is not None:
# self.resume = runner.resume
# runner.resume = None
# elif getattr(runner, "autoresume", None) is not None:
# self.resume_dir = runner.logdir / "checkpoints"
# self.resume = f"{runner.autoresume}_full.pth"
# runner.autoresume = None
#
# for key in self._keys_from_runner:
# value = getattr(runner, key, None)
# if value is not None:
# setattr(self, key, value)
#
# if self.resume_dir is not None:
# self.resume = str(self.resume_dir) + "/" + str(self.resume)
#
# if self.resume is not None:
# _load_runner(logdir=self.logdir, runner=runner, mapping=self.resume, load_full=True)
# self.resume = None
# else:
# checkpoint_exists = False
# need_load_full = False
# if isinstance(self.load_on_stage_start, str):
# checkpoint_exists =
# os.path.isfile(f"{self.logdir}/{self.load_on_stage_start}.pth")
# need_load_full = self.load_on_stage_start.endswith("full")
# elif isinstance(self.load_on_stage_start, dict):
# required_files =
# _get_required_files(self.logdir, self.load_on_stage_start).keys()
# checkpoint_exists = all(os.path.isfile(file) for file in required_files)
#
# if self.load_on_stage_start is not None and checkpoint_exists:
# _load_runner(
# logdir=self.logdir,
# runner=runner,
# mapping=self.load_on_stage_start,
# load_full=need_load_full,
# )
def on_epoch_end(self, runner: "IRunner") -> None:
"""
        Collects and saves a checkpoint at the end of each epoch.
Args:
runner: current runner
"""
if runner.is_infer_stage:
return
if runner.engine.is_ddp and not runner.engine.is_master_process:
return
if self._use_model_selection:
# score model based on the specified metric
score = runner.epoch_metrics[self.loader_key][self.metric_key]
else:
# score model based on epoch number
score = runner.global_epoch_step
is_best = False
if self.best_score is None or self.is_better(score, self.best_score):
self.best_score = score
is_best = True
if self.save_n_best > 0:
# pack checkpoint
checkpoint = self._pack_checkpoint(runner)
# save checkpoint
checkpoint_path = self._save_checkpoint(
runner=runner, checkpoint=checkpoint, is_best=is_best, is_last=True
)
# add metrics to records
metrics_record = (
float(score),
checkpoint_path,
runner.stage_key,
runner.stage_epoch_step,
dict(runner.epoch_metrics),
)
self.top_best_metrics.append(metrics_record)
# truncate checkpoints
self._truncate_checkpoints()
# save checkpoint metrics
metrics_log = self._prepare_metrics_log(float(score), dict(runner.epoch_metrics))
save_config(metrics_log, f"{self.logdir}/{self.metrics_filename}")
def on_stage_end(self, runner: "IRunner") -> None:
"""
        Shows information about the best checkpoints found during the stage and
        loads the model specified in ``load_on_stage_end``.
Args:
runner: current runner
"""
if runner.is_infer_stage:
return
if runner.engine.is_ddp and not runner.engine.is_master_process:
# worker sync
runner.engine.barrier()
return
# let's log Top-N base metrics
log_message = "Top best models:\n"
# store latest state
if self.save_n_best == 0:
score = runner.epoch_metrics[self.loader_key][self.metric_key]
# pack checkpoint
checkpoint = self._pack_checkpoint(runner)
# save checkpoint
checkpoint_path = self._save_checkpoint(
runner=runner,
checkpoint=checkpoint,
is_best=True, # will duplicate current (last) as best
is_last=False, # don't need that because current state is last
)
# add metrics to records
# save checkpoint metrics
metrics_log = self._prepare_metrics_log(float(score), dict(runner.epoch_metrics))
save_config(metrics_log, f"{self.logdir}/{self.metrics_filename}")
log_message += f"{checkpoint_path}\t{score:3.4f}"
else:
log_message += "\n".join(
[f"{filepath}\t{score:3.4f}" for score, filepath, _, _, _ in self.top_best_metrics]
)
print(log_message)
# let's load runner state (model, criterion, optimizer, scheduler) if required
not_required_load_states = {"last", "last_full"}
if (
isinstance(self.load_on_stage_end, str)
and self.load_on_stage_end not in not_required_load_states
and self.save_n_best > 0
):
need_load_full = (
self.load_on_stage_end.endswith("full")
if isinstance(self.load_on_stage_end, str)
else False
)
_load_runner(
logdir=self.logdir,
runner=runner,
mapping=self.load_on_stage_end,
load_full=need_load_full,
)
elif isinstance(self.load_on_stage_end, dict) and self.save_n_best > 0:
to_load = {
k: v
for k, v in self.load_on_stage_end.items()
if v not in not_required_load_states
}
_load_runner(logdir=self.logdir, runner=runner, mapping=to_load)
if runner.engine.is_ddp and runner.engine.is_master_process:
# worker sync
runner.engine.barrier()
__all__ = ["ICheckpointCallback", "CheckpointCallback"]