from typing import Dict, Tuple, Union
from collections import OrderedDict
import os
from pathlib import Path
from catalyst.core import Callback, CallbackNode, CallbackOrder, State, utils
def _pack_state(state: State):
    """Pack the full runner state (components, metrics, counters) into a checkpoint dict."""
    return utils.pack_checkpoint(
        model=state.model,
        criterion=state.criterion,
        optimizer=state.optimizer,
        scheduler=state.scheduler,
        # copy the metric mappings so later state mutations don't leak into the snapshot
        epoch_metrics=dict(state.epoch_metrics),
        valid_metrics=dict(state.valid_metrics),
        stage_name=state.stage_name,
        epoch=state.epoch,
        loader_name=state.loader_name,
        loader_step=state.loader_step,
        global_epoch=state.global_epoch,
        checkpoint_data=state.checkpoint_data,
        main_metric=state.main_metric,
        minimize_metric=state.minimize_metric,
        valid_loader=state.valid_loader,
    )
def _load_checkpoint(*, filename, state: State, load_full: bool = True) -> None:
    """
    Load a checkpoint file into the runner state.

    Args:
        filename: path to the checkpoint file
        state (State): runner state to restore into
        load_full (bool): if ``True``, also restore criterion/optimizer/
            scheduler and (outside of inference stages) the stage name and
            epoch counters; otherwise only model weights are loaded

    Raises:
        Exception: if ``filename`` does not point to an existing file
    """
    if not os.path.isfile(filename):
        raise Exception(f"No checkpoint found at {filename}")
    print(f"=> loading checkpoint {filename}")
    checkpoint = utils.load_checkpoint(filename)

    if not state.stage_name.startswith("infer") and load_full:
        state.stage_name = checkpoint["stage_name"]
        state.epoch = checkpoint["epoch"]
        state.global_epoch = checkpoint["global_epoch"]
        # @TODO: should we also load,
        # checkpoint_data, main_metric, minimize_metric, valid_loader ?
        # epoch_metrics, valid_metrics ?

    if load_full:
        utils.unpack_checkpoint(
            checkpoint,
            model=state.model,
            criterion=state.criterion,
            optimizer=state.optimizer,
            scheduler=state.scheduler,
        )
        print(
            f"loaded state checkpoint {filename} "
            f"(global epoch {checkpoint['global_epoch']}, "
            f"epoch {checkpoint['epoch']}, "
            f"stage {checkpoint['stage_name']})"
        )
    else:
        utils.unpack_checkpoint(
            checkpoint, model=state.model,
        )
        print(f"loaded model checkpoint {filename}")
class BaseCheckpointCallback(Callback):
    """Base class for all checkpoint callbacks."""

    def __init__(self, metrics_filename: str = "_metrics.json"):
        """
        Args:
            metrics_filename (str): filename to save metrics
                in checkpoint folder. Must ends on ``.json`` or ``.yml``
        """
        super().__init__(order=CallbackOrder.External, node=CallbackNode.All)
        self.metrics_filename = metrics_filename
        self.metrics: dict = {}

    def get_checkpoint_suffix(self, checkpoint: dict) -> str:
        """Return the filename suffix for a checkpoint; subclasses override this."""
        return "checkpoint"

    def save_metric(self, logdir: Union[str, Path], metrics: Dict) -> None:
        """Write ``metrics`` to ``<logdir>/checkpoints/<metrics_filename>``."""
        utils.save_config(
            metrics, f"{logdir}/checkpoints/{self.metrics_filename}"
        )

    def on_exception(self, state: State):
        """Best-effort: snapshot the current state when training raises."""
        exception = state.exception
        if not utils.is_exception(exception):
            return
        try:
            checkpoint = _pack_state(state)
            base_suffix = self.get_checkpoint_suffix(checkpoint)
            suffix = f"{base_suffix}.exception_{exception.__class__.__name__}"
            utils.save_checkpoint(
                logdir=Path(f"{state.logdir}/checkpoints/"),
                checkpoint=checkpoint,
                suffix=suffix,
                is_best=False,
                is_last=False,
            )
            metrics = self.metrics
            metrics[suffix] = state.valid_metrics
            self.save_metric(state.logdir, metrics)
        except Exception:
            # deliberate swallow: saving here is best-effort and must never
            # mask the original training exception
            pass
class CheckpointCallback(BaseCheckpointCallback):
    """
    Checkpoint callback to save/restore your model/criterion/optimizer/metrics.
    """

    def __init__(
        self,
        save_n_best: int = 1,
        resume: str = None,
        resume_dir: str = None,
        metrics_filename: str = "_metrics.json",
        load_on_stage_end: str = None,
    ):
        """
        Args:
            save_n_best (int): number of best checkpoint to keep,
                if ``0`` then store only last state of model and
                ``load_on_stage_end`` should be one of
                ``last`` or ``last_full``.
            resume (str): path to checkpoint to load
                and initialize runner state
            resume_dir (str): directory with checkpoints,
                if specified in combination with ``resume``
                than resume checkpoint will be loaded from ``resume_dir``
            metrics_filename (str): filename to save metrics
                in checkpoint folder. Must ends on ``.json`` or ``.yml``
            load_on_stage_end (str): name of the model to load
                at the end of the stage.
                You can use ``best``, ``best_full``
                to load the best model according to validation metrics,
                or ``last`` ``last_full`` (default behaviour)
                to use just the last one.
                If None then no action is required at stage end
                and will be used last state.
        """
        super().__init__(metrics_filename)
        assert load_on_stage_end in [
            None,
            "best",
            "last",
            "best_full",
            "last_full",
        ]
        assert save_n_best >= 0
        if save_n_best == 0:
            # with no "best" checkpoints stored there is nothing to load as "best"
            assert load_on_stage_end in (None, "last", "last_full")
        if resume_dir is not None:
            assert resume is not None

        self.save_n_best = save_n_best
        self.resume = resume
        self.resume_dir = resume_dir
        self.load_on_stage_end = load_on_stage_end

        # (filepath, main-metric value, full valid metrics) per saved epoch,
        # truncated to the best ``save_n_best`` entries
        self.top_best_metrics = []
        # every epoch's metrics record, never truncated
        self.metrics_history = []

        # state attributes that may override the constructor arguments
        self._keys_from_state = ["resume", "resume_dir"]

    def get_checkpoint_suffix(self, checkpoint: dict) -> str:
        """
        Create checkpoint filename suffix based on checkpoint data.

        Args:
            checkpoint (dict): checkpoint dict,
                should contain ``stage_name`` and ``epoch`` keys.
        """
        result = f"{checkpoint['stage_name']}.{checkpoint['epoch']}"
        return result

    def process_metrics(self, last_valid_metrics: Dict[str, float]) -> Dict:
        """
        Add last validation metrics to list of previous validation metrics
        and keep ``save_n_best`` metrics.

        Args:
            last_valid_metrics (dict): dict with metrics
                from last validation step.
        """
        top_best_checkpoints = [
            (Path(filepath).stem, valid_metric)
            for (filepath, _, valid_metric) in self.top_best_metrics
        ]
        all_epochs_metrics = [
            (f"epoch_{order_index}", valid_metric)
            for (order_index, valid_metric) in enumerate(self.metrics_history)
        ]
        if self.save_n_best > 0:
            # top_best_metrics is sorted, so entry 0 is the current best
            best_valid_metrics = top_best_checkpoints[0][1]
            metrics = (
                [("best", best_valid_metrics), ("last", last_valid_metrics)]
                + top_best_checkpoints
                + all_epochs_metrics
            )
        else:
            metrics = [("last", last_valid_metrics)]
        self.metrics = OrderedDict(metrics)
        return self.metrics

    def truncate_checkpoints(self, minimize_metric: bool) -> None:
        """
        Keep ``save_n_best`` checkpoints based on main metric.

        Args:
            minimize_metric (bool): if ``True`` then keep
                ``save_n_best`` checkpoints with the lowest/highest values
                of the main metric.
        """
        self.top_best_metrics = sorted(
            self.top_best_metrics,
            key=lambda x: x[1],
            reverse=not minimize_metric,
        )
        if len(self.top_best_metrics) > self.save_n_best:
            # drop the single worst entry; called once per epoch, so the list
            # never grows by more than one between truncations
            last_item = self.top_best_metrics.pop(-1)
            last_filepath = Path(last_item[0])
            # remove the checkpoint and its "_full" sibling
            last_filepaths = last_filepath.parent.glob(
                last_filepath.name.replace(".pth", "*")
            )
            for filepath in last_filepaths:
                os.remove(filepath)

    def _save_checkpoint(
        self,
        logdir: Union[str, Path],
        suffix: str,
        checkpoint: Dict,
        is_best: bool,
        is_last: bool,
    ) -> Tuple[str, str]:
        """
        Save checkpoint (simple and full).

        Args:
            logdir (str or Path object): directory for storing checkpoints
            suffix (str): checkpoint suffix
            checkpoint (dict): dict with checkpoint data
            is_best (bool): indicator to save best checkpoint,
                if true then will be saved two additional checkpoints -
                ``best`` and ``best_full``.
            is_last (bool): indicator to save the last checkpoint,
                if true then will be saved two additional checkpoints -
                ``last`` and ``last_full``.
        """
        full_checkpoint_path = utils.save_checkpoint(
            logdir=Path(f"{logdir}/checkpoints/"),
            checkpoint=checkpoint,
            suffix=f"{suffix}_full",
            is_best=is_best,
            is_last=is_last,
            special_suffix="_full",
        )
        # the "simple" checkpoint keeps only model-related entries
        exclude = ["criterion", "optimizer", "scheduler"]
        checkpoint_path = utils.save_checkpoint(
            checkpoint={
                key: value
                for key, value in checkpoint.items()
                if all(z not in key for z in exclude)
            },
            logdir=Path(f"{logdir}/checkpoints/"),
            suffix=suffix,
            is_best=is_best,
            is_last=is_last,
        )
        return (full_checkpoint_path, checkpoint_path)

    def process_checkpoint(
        self,
        logdir: Union[str, Path],
        checkpoint: Dict,
        is_best: bool,
        main_metric: str = "loss",
        minimize_metric: bool = True,
    ) -> None:
        """
        Save checkpoint and metrics.

        Args:
            logdir (str or Path object): directory for storing checkpoints
            checkpoint (dict): dict with checkpoint data
            is_best (bool): indicator to save best checkpoint,
                if true then will be saved two additional checkpoints -
                ``best`` and ``best_full``.
            main_metric (str): metric to use for selecting the best model
            minimize_metric (bool): indicator for selecting best metric,
                if true then best metric will be the metric with
                the lowest value, otherwise with the greatest value.
        """
        _, filepath = self._save_checkpoint(
            logdir=logdir,
            checkpoint=checkpoint,
            suffix=self.get_checkpoint_suffix(checkpoint),
            is_best=is_best,
            is_last=True,
        )
        valid_metrics = checkpoint["valid_metrics"]
        checkpoint_metric = valid_metrics[main_metric]
        metrics_record = (filepath, checkpoint_metric, valid_metrics)
        self.top_best_metrics.append(metrics_record)
        self.metrics_history.append(metrics_record)
        self.truncate_checkpoints(minimize_metric=minimize_metric)
        metrics = self.process_metrics(valid_metrics)
        self.save_metric(logdir, metrics)

    def on_stage_start(self, state: State) -> None:
        """
        Setup model for stage.

        NOTE: If CheckpointCallback initialized with ``resume``
        (as path to checkpoint file) or ``resume`` (as filename)
        and ``resume_dir`` (as directory with file)
        then will be performed loading checkpoint.

        Args:
            state (State): training state
        """
        # state-provided values override constructor arguments
        for key in self._keys_from_state:
            value = getattr(state, key, None)
            if value is not None:
                setattr(self, key, value)

        if self.resume_dir is not None:
            self.resume = str(self.resume_dir) + "/" + str(self.resume)

        if self.resume is not None:
            _load_checkpoint(filename=self.resume, state=state)
            # load only once per stage start
            self.resume = None

    def on_epoch_end(self, state: State) -> None:
        """
        Collect and save checkpoint after epoch.

        Args:
            state (State): training state
        """
        # only the master process of a non-inference stage saves checkpoints
        if state.stage_name.startswith("infer") or state.is_distributed_worker:
            return

        if self.save_n_best > 0:
            checkpoint = _pack_state(state)
            self.process_checkpoint(
                logdir=state.logdir,
                checkpoint=checkpoint,
                is_best=state.is_best_valid,
                main_metric=state.main_metric,
                minimize_metric=state.minimize_metric,
            )

    def on_stage_end(self, state: State) -> None:
        """
        Show information about best checkpoints during the stage and
        load model specified in ``load_on_stage_end``.

        Args:
            state (State): training state
        """
        if state.stage_name.startswith("infer") or state.is_distributed_worker:
            return
        log_message = "Top best models:\n"
        # store latest state
        if self.save_n_best == 0:
            checkpoint = _pack_state(state)
            _, filepath = self._save_checkpoint(
                logdir=state.logdir,
                checkpoint=checkpoint,
                suffix="last",
                is_best=True,  # will duplicate current (last) as best
                is_last=False,  # don't need that because current state is last
            )
            metrics = self.process_metrics(checkpoint["valid_metrics"])
            self.save_metric(state.logdir, metrics)
            main_metric_value = metrics["last"][state.main_metric]
            log_message += "{filepath}\t{metric:3.4f}".format(
                filepath=filepath, metric=main_metric_value
            )
        else:
            log_message += "\n".join(
                [
                    "{filepath}\t{metric:3.4f}".format(
                        filepath=filepath, metric=checkpoint_metric
                    )
                    for filepath, checkpoint_metric, _ in self.top_best_metrics
                ]
            )
        print(log_message)

        if (
            self.load_on_stage_end in ["best", "best_full"]
            and self.save_n_best > 0
        ):
            resume = f"{state.logdir}/checkpoints/{self.load_on_stage_end}.pth"
            print(f"Loading {self.load_on_stage_end} model from {resume}")
            _load_checkpoint(
                filename=resume,
                state=state,
                load_full=self.load_on_stage_end.endswith("full"),
            )
class IterationCheckpointCallback(BaseCheckpointCallback):
    """Iteration checkpoint callback to save your model/criterion/optimizer."""

    def __init__(
        self,
        save_n_last: int = 1,
        period: int = 100,
        stage_restart: bool = True,
        metrics_filename: str = "_metrics_iter.json",
        load_on_stage_end: str = "best_full",
    ):
        """
        Args:
            save_n_last (int): number of last checkpoint to keep
            period (int): save the checkpoint every `period`
            stage_restart (bool): restart counter every stage or not
            metrics_filename (str): filename to save metrics
                in checkpoint folder. Must ends on ``.json`` or ``.yml``
            load_on_stage_end (str): name of the model to load
                at the end of the stage.
                You can use ``best``, ``best_full`` (default)
                to load the best model according to validation metrics,
                or ``last`` ``last_full`` to use just the last one.
        """
        super().__init__(metrics_filename)
        self.save_n_last = save_n_last
        self.period = period
        self.stage_restart = stage_restart
        self._iteration_counter = 0
        # (filepath, batch_metrics) pairs for the last ``save_n_last`` saves
        self.last_checkpoints = []
        # batch metrics for every save, never truncated
        self.metrics_history = []
        self.load_on_stage_end = load_on_stage_end

    def get_checkpoint_suffix(self, checkpoint: dict) -> str:
        """
        Create checkpoint filename suffix based on checkpoint data.

        Args:
            checkpoint (dict): checkpoint dict,
                should contain ``stage_name`` and ``epoch`` keys.
        """
        result = (
            f"{checkpoint['stage_name']}."
            f"epoch.{checkpoint['epoch']}."
            f"iter.{self._iteration_counter}"
        )
        return result

    def process_metrics(self) -> Dict:
        """
        Update metrics with last ``save_n_last`` checkpoints.
        """
        n_last_checkpoints = [
            (Path(filepath).stem, batch_values)
            for (filepath, batch_values) in self.last_checkpoints
        ]
        all_epochs_metrics = [
            (f"epoch_{order_index}", valid_metric)
            for (order_index, valid_metric) in enumerate(self.metrics_history)
        ]
        metrics = OrderedDict(n_last_checkpoints + all_epochs_metrics)
        self.metrics = metrics
        return self.metrics

    def truncate_checkpoints(self, **kwargs) -> None:
        """
        Keep only the last ``save_n_last`` checkpoints on disk.
        """
        if len(self.last_checkpoints) > self.save_n_last:
            # drop the oldest checkpoint; called after every save, so the
            # list never grows by more than one between truncations
            item = self.last_checkpoints.pop(0)
            top_filepath = item[0]
            os.remove(top_filepath)

    def process_checkpoint(
        self,
        logdir: Union[str, Path],
        checkpoint: Dict,
        batch_metrics: Dict[str, float],
    ):
        """
        Save checkpoint and metrics.

        Args:
            logdir (str or Path object): directory for storing checkpoints
            checkpoint (dict): dict with checkpoint data
            batch_metrics (dict): dict with metrics based on a few batches
        """
        filepath = utils.save_checkpoint(
            logdir=Path(f"{logdir}/checkpoints/"),
            checkpoint=checkpoint,
            suffix=self.get_checkpoint_suffix(checkpoint),
            is_best=False,
            is_last=False,
        )
        self.last_checkpoints.append((filepath, batch_metrics))
        self.truncate_checkpoints()
        self.metrics_history.append(batch_metrics)
        metrics = self.process_metrics()
        self.save_metric(logdir, metrics)
        print(f"\nSaved checkpoint at {filepath}")

    def on_stage_start(self, state: State):
        """
        Reset iterations counter.

        Args:
            state (State): training state
        """
        if self.stage_restart:
            self._iteration_counter = 0

    def on_batch_end(self, state: State):
        """
        Save checkpoint based on batches count.

        Args:
            state (State): training state
        """
        self._iteration_counter += 1
        if self._iteration_counter % self.period == 0:
            checkpoint = _pack_state(state)
            self.process_checkpoint(
                logdir=state.logdir,
                checkpoint=checkpoint,
                batch_metrics=state.batch_metrics,
            )

    def on_stage_end(self, state: State):
        """
        Load model specified in ``load_on_stage_end``.

        Args:
            state (State): training state
        """
        if self.load_on_stage_end in ["best", "best_full"]:
            resume = f"{state.logdir}/checkpoints/{self.load_on_stage_end}.pth"
            print(f"Loading {self.load_on_stage_end} model from {resume}")
            _load_checkpoint(
                filename=resume,
                state=state,
                load_full=self.load_on_stage_end.endswith("full"),
            )
# Explicit public API of this module (the base class is not listed here).
__all__ = ["CheckpointCallback", "IterationCheckpointCallback"]