from typing import Dict # isort:skip
from collections import OrderedDict
import os
from pathlib import Path
import safitty
from catalyst.dl import utils
from catalyst.dl.core import Callback, CallbackOrder, RunnerState
from catalyst.utils import is_exception


class BaseCheckpointCallback(Callback):
    """
    Base class for all checkpoint callbacks
    """

    def __init__(self, metric_filename: str = "_metrics.json"):
        """
        Args:
            metric_filename (str): filename to save metrics
                in checkpoint folder. Must end with ``.json`` or ``.yml``
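
        The file is written to ``{logdir}/checkpoints/{metric_filename}``
        by ``save_metric``.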
"""
super().__init__(CallbackOrder.External)
self.metric_filename = metric_filename
self.metrics: dict = {}

    def get_metric(self, **kwargs) -> Dict:
        return self.metrics

    def save_metric(self, logdir: str, metrics: Dict) -> None:
        safitty.save(metrics, f"{logdir}/checkpoints/{self.metric_filename}")

    def truncate_checkpoints(self, **kwargs) -> None:
        pass

    def process_checkpoint(self, **kwargs) -> None:
        pass

    def get_checkpoint_suffix(self, checkpoint: dict) -> str:
        pass

    def on_exception(self, state: RunnerState):
exception = state.exception
if not is_exception(exception):
return
try:
valid_metrics = state.metrics.valid_values
epoch_metrics = state.metrics.epoch_values
checkpoint = utils.pack_checkpoint(
model=state.model,
criterion=state.criterion,
optimizer=state.optimizer,
scheduler=state.scheduler,
epoch_metrics=epoch_metrics,
valid_metrics=valid_metrics,
stage=state.stage,
epoch=state.epoch_log,
checkpoint_data=state.checkpoint_data
)
suffix = self.get_checkpoint_suffix(checkpoint)
suffix = f"{suffix}.exception_{exception.__class__.__name__}"
utils.save_checkpoint(
logdir=Path(f"{state.logdir}/checkpoints/"),
checkpoint=checkpoint,
suffix=suffix,
is_best=False,
is_last=False
)
metrics = self.metrics
metrics[suffix] = valid_metrics
self.save_metric(state.logdir, metrics)
except Exception:
pass


class CheckpointCallback(BaseCheckpointCallback):
    """
    Checkpoint callback to save/restore your model/criterion/optimizer/metrics.
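
    A minimal usage sketch; the ``SupervisedRunner``/``runner.train`` API
    below is an assumption for illustration and may differ between
    catalyst versions::

        from catalyst.dl import SupervisedRunner

        runner = SupervisedRunner()
        runner.train(
            model=model,
            criterion=criterion,
            optimizer=optimizer,
            loaders=loaders,
            logdir="./logs",
            num_epochs=10,
            # keep the 3 best checkpoints under ./logs/checkpoints
            callbacks=[CheckpointCallback(save_n_best=3)],
        )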
"""

    def __init__(
self,
save_n_best: int = 1,
resume: str = None,
resume_dir: str = None,
metric_filename: str = "_metrics.json"
):
"""
Args:
            save_n_best (int): number of best checkpoints to keep
            resume (str): path to checkpoint to load
                and initialize the runner state from
            resume_dir (str): directory with the checkpoint to load;
                if given, ``resume`` is treated as a filename inside it
            metric_filename (str): filename to save metrics
                in checkpoint folder. Must end with ``.json`` or ``.yml``
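
        For example, to resume training from a previously saved full
        checkpoint (the path below is illustrative)::

            CheckpointCallback(
                save_n_best=3,
                resume="./logs/checkpoints/best_full.pth",
            )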
"""
super().__init__(metric_filename)
self.save_n_best = save_n_best
self.resume = resume
self.resume_dir = resume_dir
self.top_best_metrics = []
self.epochs_metrics = []
self._keys_from_state = ["resume", "resume_dir"]

    def get_checkpoint_suffix(self, checkpoint: dict) -> str:
result = f"{checkpoint['stage']}.{checkpoint['epoch']}"
return result

    @staticmethod
    def load_checkpoint(*, filename, state: RunnerState):
if os.path.isfile(filename):
print(f"=> loading checkpoint {filename}")
checkpoint = utils.load_checkpoint(filename)
state.epoch = checkpoint["epoch"]
utils.unpack_checkpoint(
checkpoint,
model=state.model,
criterion=state.criterion,
optimizer=state.optimizer,
scheduler=state.scheduler
)
print(
f"loaded checkpoint {filename} (epoch {checkpoint['epoch']})"
)
else:
raise Exception(f"No checkpoint found at {filename}")

    def get_metric(self, last_valid_metrics) -> Dict:
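        """
        Assemble the metrics dict that is dumped to ``metric_filename``:
        ``best`` and ``last`` valid metrics, one entry per kept checkpoint
        (keyed by its file stem) and an ``epoch_{i}`` entry per epoch.
        """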
top_best_checkpoints = [
(Path(filepath).stem, valid_metric)
for (filepath, _, valid_metric) in self.top_best_metrics
]
all_epochs_metrics = [
(f"epoch_{order_index}", valid_metric)
for (order_index, valid_metric) in enumerate(self.epochs_metrics)
]
best_valid_metrics = top_best_checkpoints[0][1]
metrics = OrderedDict(
[("best", best_valid_metrics)] +
[("last", last_valid_metrics)] +
top_best_checkpoints +
all_epochs_metrics
)
self.metrics = metrics
return self.metrics

    def truncate_checkpoints(self, minimize_metric: bool) -> None:
self.top_best_metrics = sorted(
self.top_best_metrics,
key=lambda x: x[1],
reverse=not minimize_metric
)
if len(self.top_best_metrics) > self.save_n_best:
last_item = self.top_best_metrics.pop(-1)
last_filepath = Path(last_item[0])
last_filepaths = last_filepath.parent.glob(
last_filepath.name.replace(".pth", "*"))
for filepath in last_filepaths:
os.remove(filepath)

    def process_checkpoint(
self,
logdir: str,
checkpoint: Dict,
is_best: bool,
main_metric: str = "loss",
minimize_metric: bool = True
):
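        """
        Save two checkpoints for the epoch: a ``{suffix}_full`` one that
        also carries criterion/optimizer/scheduler state and a weights-only
        one, then prune to ``save_n_best`` and refresh the metrics file.
        """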
suffix = self.get_checkpoint_suffix(checkpoint)
utils.save_checkpoint(
logdir=Path(f"{logdir}/checkpoints/"),
checkpoint=checkpoint,
suffix=f"{suffix}_full",
is_best=is_best,
is_last=True,
special_suffix="_full"
)
exclude = ["criterion", "optimizer", "scheduler"]
checkpoint = {
key: value
for key, value in checkpoint.items()
if all(z not in key for z in exclude)
}
filepath = utils.save_checkpoint(
checkpoint=checkpoint,
logdir=Path(f"{logdir}/checkpoints/"),
suffix=suffix,
is_best=is_best,
is_last=True
)
valid_metrics = checkpoint["valid_metrics"]
checkpoint_metric = valid_metrics[main_metric]
metrics_record = (filepath, checkpoint_metric, valid_metrics)
self.top_best_metrics.append(metrics_record)
self.epochs_metrics.append(metrics_record)
self.truncate_checkpoints(minimize_metric=minimize_metric)
metrics = self.get_metric(valid_metrics)
self.save_metric(logdir, metrics)

    def on_stage_start(self, state: RunnerState):
for key in self._keys_from_state:
value = getattr(state, key, None)
if value is not None:
setattr(self, key, value)
if self.resume_dir is not None:
self.resume = str(self.resume_dir) + "/" + str(self.resume)
if self.resume is not None:
self.load_checkpoint(filename=self.resume, state=state)

    def on_epoch_end(self, state: RunnerState):
if state.stage.startswith("infer"):
return
valid_metrics = dict(state.metrics.valid_values)
epoch_metrics = dict(state.metrics.epoch_values)
checkpoint = utils.pack_checkpoint(
model=state.model,
criterion=state.criterion,
optimizer=state.optimizer,
scheduler=state.scheduler,
epoch_metrics=epoch_metrics,
valid_metrics=valid_metrics,
stage=state.stage,
epoch=state.epoch_log,
checkpoint_data=state.checkpoint_data
)
self.process_checkpoint(
logdir=state.logdir,
checkpoint=checkpoint,
is_best=state.metrics.is_best,
main_metric=state.main_metric,
minimize_metric=state.minimize_metric
)

    def on_stage_end(self, state: RunnerState):
print("Top best models:")
top_best_metrics_str = "\n".join(
[
"{filepath}\t{metric:3.4f}".format(
filepath=filepath, metric=checkpoint_metric
) for filepath, checkpoint_metric, _ in self.top_best_metrics
]
)
print(top_best_metrics_str)


class IterationCheckpointCallback(BaseCheckpointCallback):
    """
    Iteration checkpoint callback to save your model/criterion/optimizer
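
    A minimal usage sketch; ``runner.train`` below is an assumption
    for illustration::

        runner.train(
            model=model,
            criterion=criterion,
            optimizer=optimizer,
            loaders=loaders,
            logdir="./logs",
            callbacks=[
                # snapshot every 100 iterations, keep the last 3 snapshots
                IterationCheckpointCallback(num_iters=100, save_n_last=3),
            ],
        )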
"""

    def __init__(
self,
save_n_last: int = 3,
num_iters: int = 100,
stage_restart: bool = True,
metric_filename: str = "_metrics_iter.json"
):
"""
Args:
            save_n_last (int): number of last checkpoints to keep
            num_iters (int): save the checkpoint every ``num_iters`` iterations
            stage_restart (bool): whether to restart the iteration counter
                at the start of each stage
            metric_filename (str): filename to save metrics
                in checkpoint folder. Must end with ``.json`` or ``.yml``
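
        For example, ``IterationCheckpointCallback(num_iters=500, save_n_last=5)``
        writes a checkpoint every 500 batches and keeps only the 5 most
        recent snapshots on disk.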
"""
super().__init__(metric_filename)
self.save_n_last = save_n_last
self.num_iters = num_iters
self.stage_restart = stage_restart
self._iteration_counter = 0
self.last_checkpoints = []
self.epochs_metrics = []

    def get_checkpoint_suffix(self, checkpoint: dict) -> str:
result = f"{checkpoint['stage']}." \
f"epoch.{checkpoint['epoch']}." \
f"iter.{self._iteration_counter}"
return result

    def get_metric(self, **kwargs) -> Dict:
n_last_checkpoints = [
(Path(filepath).stem, batch_values)
for (filepath, batch_values) in self.last_checkpoints
]
all_epochs_metrics = [
(f"epoch_{order_index}", valid_metric)
for (order_index, valid_metric) in enumerate(self.epochs_metrics)
]
metrics = OrderedDict(
n_last_checkpoints +
all_epochs_metrics
)
self.metrics = metrics
return self.metrics

    def truncate_checkpoints(self, **kwargs) -> None:
if len(self.last_checkpoints) > self.save_n_last:
item = self.last_checkpoints.pop(0)
top_filepath = item[0]
os.remove(top_filepath)

    def process_checkpoint(
self,
logdir: str,
checkpoint: Dict,
batch_values: Dict[str, float]
):
filepath = utils.save_checkpoint(
logdir=Path(f"{logdir}/checkpoints/"),
checkpoint=checkpoint,
suffix=self.get_checkpoint_suffix(checkpoint),
is_best=False,
is_last=False
)
self.last_checkpoints.append((filepath, batch_values))
self.truncate_checkpoints()
self.epochs_metrics.append(batch_values)
metrics = self.get_metric()
self.save_metric(logdir, metrics)
print(f"\nSaved checkpoint at {filepath}")

    def on_stage_start(self, state):
if self.stage_restart:
self._iteration_counter = 0

    def on_batch_end(self, state):
self._iteration_counter += 1
if self._iteration_counter % self.num_iters == 0:
checkpoint = utils.pack_checkpoint(
model=state.model,
criterion=state.criterion,
optimizer=state.optimizer,
scheduler=state.scheduler,
epoch_metrics=None,
valid_metrics=None,
stage=state.stage,
epoch=state.epoch_log
)
self.process_checkpoint(
logdir=state.logdir,
checkpoint=checkpoint,
batch_values=state.metrics.batch_values
)


__all__ = ["CheckpointCallback", "IterationCheckpointCallback"]