from typing import Dict, Union
from collections import OrderedDict
import os
from pathlib import Path
from catalyst.core import Callback, CallbackNode, CallbackOrder, State, utils


def _pack_state(state: State):
    checkpoint = utils.pack_checkpoint(
        model=state.model,
        criterion=state.criterion,
        optimizer=state.optimizer,
        scheduler=state.scheduler,
        epoch_metrics=dict(state.epoch_metrics),
        valid_metrics=dict(state.valid_metrics),
        stage_name=state.stage_name,
        epoch=state.epoch,
        loader_name=state.loader_name,
        loader_step=state.loader_step,
        global_epoch=state.global_epoch,
        checkpoint_data=state.checkpoint_data,
        main_metric=state.main_metric,
        minimize_metric=state.minimize_metric,
        valid_loader=state.valid_loader,
    )
    return checkpoint
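

# The dict returned above mirrors the keyword arguments passed to
# `utils.pack_checkpoint`; the exact key layout for the model weights is
# defined by that utility. A sketch of the bookkeeping part, with
# hypothetical values (these keys are read back by `_load_checkpoint`):
#
#     checkpoint["stage_name"]     # "stage1"
#     checkpoint["epoch"]          # 17
#     checkpoint["global_epoch"]   # 42
#     checkpoint["valid_metrics"]  # {"loss": 0.42, ...}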


def _load_checkpoint(*, filename, state: State, load_full: bool = True):
    if not os.path.isfile(filename):
        raise Exception(f"No checkpoint found at {filename}")

    print(f"=> loading checkpoint {filename}")
    checkpoint = utils.load_checkpoint(filename)

    if not state.stage_name.startswith("infer") and load_full:
        state.stage_name = checkpoint["stage_name"]
        state.epoch = checkpoint["epoch"]
        state.global_epoch = checkpoint["global_epoch"]
        # @TODO: should we also load
        # checkpoint_data, main_metric, minimize_metric, valid_loader,
        # epoch_metrics, valid_metrics ?

    if load_full:
        utils.unpack_checkpoint(
            checkpoint,
            model=state.model,
            criterion=state.criterion,
            optimizer=state.optimizer,
            scheduler=state.scheduler,
        )
        print(
            f"loaded state checkpoint {filename} "
            f"(global epoch {checkpoint['global_epoch']}, "
            f"epoch {checkpoint['epoch']}, "
            f"stage {checkpoint['stage_name']})"
        )
    else:
        utils.unpack_checkpoint(checkpoint, model=state.model)
        print(f"loaded model checkpoint {filename}")


class BaseCheckpointCallback(Callback):
    """Base class for all checkpoint callbacks."""

    def __init__(self, metrics_filename: str = "_metrics.json"):
        """
        Args:
            metrics_filename (str): filename to save metrics
                in checkpoint folder. Must end with ``.json`` or ``.yml``
        """
        super().__init__(order=CallbackOrder.External, node=CallbackNode.All)
        self.metrics_filename = metrics_filename
        self.metrics: dict = {}

    def get_checkpoint_suffix(self, checkpoint: dict) -> str:
        """Returns the default checkpoint suffix; overridden by subclasses."""
        return "checkpoint"

    def save_metric(self, logdir: Union[str, Path], metrics: Dict) -> None:
        """Saves metrics to ``{logdir}/checkpoints/{metrics_filename}``."""
        utils.save_config(
            metrics, f"{logdir}/checkpoints/{self.metrics_filename}"
        )

    def on_exception(self, state: State):
        """Saves an emergency checkpoint if the run raised an exception."""
        exception = state.exception
        if not utils.is_exception(exception):
            return

        try:
            checkpoint = _pack_state(state)
            suffix = self.get_checkpoint_suffix(checkpoint)
            suffix = f"{suffix}.exception_{exception.__class__.__name__}"
            utils.save_checkpoint(
                logdir=Path(f"{state.logdir}/checkpoints/"),
                checkpoint=checkpoint,
                suffix=suffix,
                is_best=False,
                is_last=False,
            )
            metrics = self.metrics
            metrics[suffix] = state.valid_metrics
            self.save_metric(state.logdir, metrics)
        except Exception:
            # best-effort save; never mask the original exception
            pass


class CheckpointCallback(BaseCheckpointCallback):
    """
    Checkpoint callback to save/restore your
    model/criterion/optimizer/metrics.
    """

    def __init__(
        self,
        save_n_best: int = 1,
        resume: str = None,
        resume_dir: str = None,
        metrics_filename: str = "_metrics.json",
        load_on_stage_end: str = None,
    ):
        """
        Args:
            save_n_best (int): number of best checkpoints to keep
            resume (str): path to checkpoint to load
                and initialize runner state from
            resume_dir (str): directory with the checkpoint;
                if specified, ``resume`` is treated as a filename inside it
            metrics_filename (str): filename to save metrics
                in checkpoint folder. Must end with ``.json`` or ``.yml``
            load_on_stage_end (str): name of the model to load
                at the end of the stage.
                You can use ``best`` to load the best model according
                to validation metrics (``best_full`` to also restore
                criterion/optimizer/scheduler state), or ``last`` to use
                just the last one (default behaviour).
        """
        super().__init__(metrics_filename)
        assert load_on_stage_end in [
            None,
            "best",
            "last",
            "best_full",
            "last_full",
        ]
        self.save_n_best = save_n_best
        self.resume = resume
        self.resume_dir = resume_dir
        self.load_on_stage_end = load_on_stage_end

        self.top_best_metrics = []
        self.metrics_history = []

        self._keys_from_state = ["resume", "resume_dir"]

    def get_checkpoint_suffix(self, checkpoint: dict) -> str:
        """Returns checkpoint suffix built from stage name and epoch."""
        result = f"{checkpoint['stage_name']}.{checkpoint['epoch']}"
        return result
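
    # e.g. a checkpoint packed in stage "stage1" at epoch 17 gets the
    # suffix "stage1.17" (hypothetical values)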

    def process_metrics(self, last_valid_metrics) -> Dict:
        """Collects ``best``/``last``/top-N/history metrics into one dict."""
        top_best_checkpoints = [
            (Path(filepath).stem, valid_metric)
            for (filepath, _, valid_metric) in self.top_best_metrics
        ]
        all_epochs_metrics = [
            (f"epoch_{order_index}", valid_metric)
            for (order_index, valid_metric) in enumerate(self.metrics_history)
        ]
        best_valid_metrics = top_best_checkpoints[0][1]
        metrics = OrderedDict(
            [("best", best_valid_metrics), ("last", last_valid_metrics)]
            + top_best_checkpoints
            + all_epochs_metrics
        )
        self.metrics = metrics
        return self.metrics
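
    # The resulting dict is ordered roughly as follows (a sketch with
    # hypothetical checkpoint names):
    #
    #     {"best": {...}, "last": {...},            # best/latest metrics
    #      "stage1.16": {...}, "stage1.17": {...},  # top-N kept checkpoints
    #      "epoch_0": {...}, "epoch_1": {...}}      # full history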

    def truncate_checkpoints(self, minimize_metric: bool) -> None:
        """Keeps the ``save_n_best`` best checkpoints, removing the rest
        (including their ``*_full`` twins) from disk."""
        self.top_best_metrics = sorted(
            self.top_best_metrics,
            key=lambda x: x[1],
            reverse=not minimize_metric,
        )
        if len(self.top_best_metrics) > self.save_n_best:
            last_item = self.top_best_metrics.pop(-1)
            last_filepath = Path(last_item[0])
            last_filepaths = last_filepath.parent.glob(
                last_filepath.name.replace(".pth", "*")
            )
            for filepath in last_filepaths:
                os.remove(filepath)

    def process_checkpoint(
        self,
        logdir: Union[str, Path],
        checkpoint: Dict,
        is_best: bool,
        main_metric: str = "loss",
        minimize_metric: bool = True,
    ):
        """Saves full and model-only checkpoints and updates saved metrics."""
        suffix = self.get_checkpoint_suffix(checkpoint)
        utils.save_checkpoint(
            logdir=Path(f"{logdir}/checkpoints/"),
            checkpoint=checkpoint,
            suffix=f"{suffix}_full",
            is_best=is_best,
            is_last=True,
            special_suffix="_full",
        )

        # drop criterion/optimizer/scheduler state for the model-only file
        exclude = ["criterion", "optimizer", "scheduler"]
        checkpoint = {
            key: value
            for key, value in checkpoint.items()
            if all(z not in key for z in exclude)
        }
        filepath = utils.save_checkpoint(
            checkpoint=checkpoint,
            logdir=Path(f"{logdir}/checkpoints/"),
            suffix=suffix,
            is_best=is_best,
            is_last=True,
        )

        valid_metrics = checkpoint["valid_metrics"]
        checkpoint_metric = valid_metrics[main_metric]
        metrics_record = (filepath, checkpoint_metric, valid_metrics)
        self.top_best_metrics.append(metrics_record)
        # store only the metrics dict: `process_metrics` enumerates the
        # history into ``epoch_{i}`` entries
        self.metrics_history.append(valid_metrics)

        self.truncate_checkpoints(minimize_metric=minimize_metric)
        metrics = self.process_metrics(valid_metrics)
        self.save_metric(logdir, metrics)
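
    # After an epoch this typically leaves files like the following (exact
    # naming is delegated to `utils.save_checkpoint`; values hypothetical):
    #
    #     {logdir}/checkpoints/stage1.17_full.pth  # + criterion/optimizer/...
    #     {logdir}/checkpoints/stage1.17.pth       # model-only checkpoint
    #     {logdir}/checkpoints/best.pth, last.pth  # refreshed via the
    #                                              # is_best/is_last flags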

    def on_stage_start(self, state: State):
        """Picks up ``resume``/``resume_dir`` from the state and, if set,
        loads the checkpoint before the stage starts."""
        for key in self._keys_from_state:
            value = getattr(state, key, None)
            if value is not None:
                setattr(self, key, value)

        if self.resume_dir is not None:
            self.resume = str(self.resume_dir) + "/" + str(self.resume)

        if self.resume is not None:
            _load_checkpoint(filename=self.resume, state=state)
            self.resume = None

    def on_epoch_end(self, state: State):
        """Packs the state and saves checkpoints at the end of each epoch."""
        if state.stage_name.startswith("infer") or state.is_distributed_worker:
            return

        checkpoint = _pack_state(state)
        self.process_checkpoint(
            logdir=state.logdir,
            checkpoint=checkpoint,
            is_best=state.is_best_valid,
            main_metric=state.main_metric,
            minimize_metric=state.minimize_metric,
        )

    def on_stage_end(self, state: State):
        """Prints the top-N checkpoints and optionally reloads the best one."""
        if state.stage_name.startswith("infer") or state.is_distributed_worker:
            return

        print("Top best models:")
        top_best_metrics_str = "\n".join(
            "{filepath}\t{metric:3.4f}".format(
                filepath=filepath, metric=checkpoint_metric
            )
            for filepath, checkpoint_metric, _ in self.top_best_metrics
        )
        print(top_best_metrics_str)

        if self.load_on_stage_end in ["best", "best_full"]:
            resume = f"{state.logdir}/checkpoints/{self.load_on_stage_end}.pth"
            print(f"Loading {self.load_on_stage_end} model from {resume}")
            _load_checkpoint(
                filename=resume,
                state=state,
                load_full=self.load_on_stage_end.endswith("full"),
            )
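

# A minimal usage sketch (assumes a Catalyst runner; ``runner``, ``model``
# and ``loaders`` are hypothetical):
#
#     runner.train(
#         model=model,
#         loaders=loaders,
#         logdir="./logs",
#         callbacks=[
#             CheckpointCallback(save_n_best=3, load_on_stage_end="best"),
#         ],
#     )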


class IterationCheckpointCallback(BaseCheckpointCallback):
    """Iteration checkpoint callback to save your model/criterion/optimizer."""

    def __init__(
        self,
        save_n_last: int = 1,
        period: int = 100,
        stage_restart: bool = True,
        metrics_filename: str = "_metrics_iter.json",
    ):
        """
        Args:
            save_n_last (int): number of last checkpoints to keep
            period (int): save a checkpoint every ``period`` iterations
            stage_restart (bool): restart the iteration counter
                at the start of every stage or not
            metrics_filename (str): filename to save metrics
                in checkpoint folder. Must end with ``.json`` or ``.yml``
        """
        super().__init__(metrics_filename)
        self.save_n_last = save_n_last
        self.period = period
        self.stage_restart = stage_restart
        self._iteration_counter = 0

        self.last_checkpoints = []
        self.metrics_history = []

    def get_checkpoint_suffix(self, checkpoint: dict) -> str:
        """Returns checkpoint suffix built from stage, epoch and iteration."""
        result = (
            f"{checkpoint['stage_name']}."
            f"epoch.{checkpoint['epoch']}."
            f"iter.{self._iteration_counter}"
        )
        return result
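
    # e.g. the 200th iteration of epoch 3 in stage "stage1" yields the
    # suffix "stage1.epoch.3.iter.200" (hypothetical values)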

    def process_metrics(self) -> Dict:
        """Collects metrics of the kept checkpoints and the full history."""
        n_last_checkpoints = [
            (Path(filepath).stem, batch_values)
            for (filepath, batch_values) in self.last_checkpoints
        ]
        # note: history entries are per saved iteration here,
        # despite the ``epoch_{i}`` key naming
        all_epochs_metrics = [
            (f"epoch_{order_index}", valid_metric)
            for (order_index, valid_metric) in enumerate(self.metrics_history)
        ]
        metrics = OrderedDict(n_last_checkpoints + all_epochs_metrics)
        self.metrics = metrics
        return self.metrics

    def truncate_checkpoints(self, **kwargs) -> None:
        """Keeps the ``save_n_last`` newest checkpoints, removing the
        oldest one from disk."""
        if len(self.last_checkpoints) > self.save_n_last:
            item = self.last_checkpoints.pop(0)
            top_filepath = item[0]
            os.remove(top_filepath)

    def process_checkpoint(
        self,
        logdir: Union[str, Path],
        checkpoint: Dict,
        batch_metrics: Dict[str, float],
    ):
        """Saves an iteration checkpoint and updates saved metrics."""
        filepath = utils.save_checkpoint(
            logdir=Path(f"{logdir}/checkpoints/"),
            checkpoint=checkpoint,
            suffix=self.get_checkpoint_suffix(checkpoint),
            is_best=False,
            is_last=False,
        )

        self.last_checkpoints.append((filepath, batch_metrics))
        self.truncate_checkpoints()

        self.metrics_history.append(batch_metrics)
        metrics = self.process_metrics()
        self.save_metric(logdir, metrics)
        print(f"\nSaved checkpoint at {filepath}")

    def on_stage_start(self, state: State):
        """Optionally resets the iteration counter at the start of a stage."""
        if self.stage_restart:
            self._iteration_counter = 0

    def on_batch_end(self, state: State):
        """Saves a checkpoint every ``period`` batches."""
        self._iteration_counter += 1
        if self._iteration_counter % self.period == 0:
            checkpoint = _pack_state(state)
            self.process_checkpoint(
                logdir=state.logdir,
                checkpoint=checkpoint,
                batch_metrics=state.batch_metrics,
            )
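

# A minimal usage sketch (``runner``, ``model`` and ``loaders`` are
# hypothetical):
#
#     runner.train(
#         model=model,
#         loaders=loaders,
#         logdir="./logs",
#         callbacks=[
#             # save every 500 iterations, keep only the two most recent
#             IterationCheckpointCallback(period=500, save_n_last=2),
#         ],
#     )
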
__all__ = ["CheckpointCallback", "IterationCheckpointCallback"]