Source code for catalyst.dl.callbacks.checkpoint

from typing import Dict  # isort:skip
from collections import OrderedDict
import os
from pathlib import Path

import safitty

from catalyst.dl import utils
from catalyst.dl.core import Callback, CallbackOrder, RunnerState
from catalyst.utils import is_exception


class BaseCheckpointCallback(Callback):
    """
    Base class for all checkpoint callbacks
    """
    def __init__(self, metric_filename: str = "_metrics.json"):
        """
        Args:
            metric_filename (str): filename to save metrics
                in checkpoint folder. Must end with ``.json`` or ``.yml``
        """
        super().__init__(CallbackOrder.External)
        self.metric_filename = metric_filename
        self.metrics: dict = {}

    def get_metric(self, **kwargs) -> Dict:
        return self.metrics

    def save_metric(self, logdir: str, metrics: Dict) -> None:
        safitty.save(metrics, f"{logdir}/checkpoints/{self.metric_filename}")

    def truncate_checkpoints(self, **kwargs) -> None:
        # overridden in subclasses to remove stale checkpoint files
        pass

    def process_checkpoint(self, **kwargs) -> None:
        # overridden in subclasses to save a checkpoint and update metrics
        pass

    def get_checkpoint_suffix(self, checkpoint: dict) -> str:
        # overridden in subclasses to build the checkpoint filename suffix
        pass

    def on_exception(self, state: RunnerState):
        exception = state.exception
        if not is_exception(exception):
            return

        try:
            valid_metrics = state.metrics.valid_values
            epoch_metrics = state.metrics.epoch_values
            checkpoint = utils.pack_checkpoint(
                model=state.model,
                criterion=state.criterion,
                optimizer=state.optimizer,
                scheduler=state.scheduler,
                epoch_metrics=epoch_metrics,
                valid_metrics=valid_metrics,
                stage=state.stage,
                epoch=state.epoch_log,
                checkpoint_data=state.checkpoint_data
            )
            suffix = self.get_checkpoint_suffix(checkpoint)
            suffix = f"{suffix}.exception_{exception.__class__.__name__}"
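            # for CheckpointCallback this produces a suffix like
            # "stage1.3.exception_KeyboardInterrupt"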
            utils.save_checkpoint(
                logdir=Path(f"{state.logdir}/checkpoints/"),
                checkpoint=checkpoint,
                suffix=suffix,
                is_best=False,
                is_last=False
            )
            metrics = self.metrics
            metrics[suffix] = valid_metrics
            self.save_metric(state.logdir, metrics)
        except Exception:
            pass
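

# A minimal sketch, not part of the original module: subclasses customize
# checkpointing by overriding ``get_checkpoint_suffix`` and
# ``process_checkpoint``, exactly as ``CheckpointCallback`` and
# ``IterationCheckpointCallback`` do below. The class name and behavior
# here are hypothetical.
class _EveryEpochCheckpointCallback(BaseCheckpointCallback):
    """Hypothetical example: save every epoch, never truncate."""
    def get_checkpoint_suffix(self, checkpoint: dict) -> str:
        return f"{checkpoint['stage']}.{checkpoint['epoch']}"

    def process_checkpoint(self, logdir: str, checkpoint: Dict, **kwargs):
        utils.save_checkpoint(
            logdir=Path(f"{logdir}/checkpoints/"),
            checkpoint=checkpoint,
            suffix=self.get_checkpoint_suffix(checkpoint),
            is_best=False,
            is_last=True
        )

    def on_epoch_end(self, state: RunnerState):
        checkpoint = utils.pack_checkpoint(
            model=state.model, stage=state.stage, epoch=state.epoch_log
        )
        self.process_checkpoint(logdir=state.logdir, checkpoint=checkpoint)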


class CheckpointCallback(BaseCheckpointCallback):
    """
    Checkpoint callback to save/restore your model/criterion/optimizer/metrics.
    """

    def __init__(
        self,
        save_n_best: int = 1,
        resume: str = None,
        resume_dir: str = None,
        metric_filename: str = "_metrics.json"
    ):
        """
        Args:
            save_n_best (int): number of best checkpoints to keep
            resume (str): path to checkpoint to load
                and initialize runner state
            resume_dir (str): directory with the checkpoint named in
                ``resume``; if given, it is prepended to ``resume``
            metric_filename (str): filename to save metrics
                in checkpoint folder. Must end with ``.json`` or ``.yml``
        """
        super().__init__(metric_filename)
        self.save_n_best = save_n_best
        self.resume = resume
        self.resume_dir = resume_dir
        self.top_best_metrics = []
        self.epochs_metrics = []
        self._keys_from_state = ["resume", "resume_dir"]

    def get_checkpoint_suffix(self, checkpoint: dict) -> str:
        result = f"{checkpoint['stage']}.{checkpoint['epoch']}"
        return result

    @staticmethod
    def load_checkpoint(*, filename, state: RunnerState):
        if os.path.isfile(filename):
            print(f"=> loading checkpoint {filename}")
            checkpoint = utils.load_checkpoint(filename)

            state.epoch = checkpoint["epoch"]
            utils.unpack_checkpoint(
                checkpoint,
                model=state.model,
                criterion=state.criterion,
                optimizer=state.optimizer,
                scheduler=state.scheduler
            )
            print(
                f"loaded checkpoint {filename} (epoch {checkpoint['epoch']})"
            )
        else:
            raise Exception(f"No checkpoint found at {filename}")

    def get_metric(self, last_valid_metrics) -> Dict:
        top_best_checkpoints = [
            (Path(filepath).stem, valid_metric)
            for (filepath, _, valid_metric) in self.top_best_metrics
        ]
        # each record in ``epochs_metrics`` is a
        # (filepath, checkpoint_metric, valid_metrics) tuple,
        # so unpack it to keep only the metrics dict
        all_epochs_metrics = [
            (f"epoch_{order_index}", valid_metric)
            for (order_index, (_, _, valid_metric))
            in enumerate(self.epochs_metrics)
        ]
        best_valid_metrics = top_best_checkpoints[0][1]
        metrics = OrderedDict(
            [("best", best_valid_metrics)]
            + [("last", last_valid_metrics)]
            + top_best_checkpoints
            + all_epochs_metrics
        )
        self.metrics = metrics
        return self.metrics

    def truncate_checkpoints(self, minimize_metric: bool) -> None:
        self.top_best_metrics = sorted(
            self.top_best_metrics,
            key=lambda x: x[1],
            reverse=not minimize_metric
        )
        if len(self.top_best_metrics) > self.save_n_best:
            last_item = self.top_best_metrics.pop(-1)
            last_filepath = Path(last_item[0])
            last_filepaths = last_filepath.parent.glob(
                last_filepath.name.replace(".pth", "*")
            )
            for filepath in last_filepaths:
                os.remove(filepath)

    def process_checkpoint(
        self,
        logdir: str,
        checkpoint: Dict,
        is_best: bool,
        main_metric: str = "loss",
        minimize_metric: bool = True
    ):
        suffix = self.get_checkpoint_suffix(checkpoint)
        # the "_full" checkpoint keeps criterion/optimizer/scheduler state
        utils.save_checkpoint(
            logdir=Path(f"{logdir}/checkpoints/"),
            checkpoint=checkpoint,
            suffix=f"{suffix}_full",
            is_best=is_best,
            is_last=True,
            special_suffix="_full"
        )

        # a second, lighter checkpoint drops criterion/optimizer/scheduler
        exclude = ["criterion", "optimizer", "scheduler"]
        checkpoint = {
            key: value
            for key, value in checkpoint.items()
            if all(z not in key for z in exclude)
        }
        filepath = utils.save_checkpoint(
            checkpoint=checkpoint,
            logdir=Path(f"{logdir}/checkpoints/"),
            suffix=suffix,
            is_best=is_best,
            is_last=True
        )

        valid_metrics = checkpoint["valid_metrics"]
        checkpoint_metric = valid_metrics[main_metric]
        metrics_record = (filepath, checkpoint_metric, valid_metrics)
        self.top_best_metrics.append(metrics_record)
        self.epochs_metrics.append(metrics_record)
        self.truncate_checkpoints(minimize_metric=minimize_metric)
        metrics = self.get_metric(valid_metrics)
        self.save_metric(logdir, metrics)

    def on_stage_start(self, state: RunnerState):
        for key in self._keys_from_state:
            value = getattr(state, key, None)
            if value is not None:
                setattr(self, key, value)

        if self.resume_dir is not None:
            self.resume = str(self.resume_dir) + "/" + str(self.resume)

        if self.resume is not None:
            self.load_checkpoint(filename=self.resume, state=state)

    def on_epoch_end(self, state: RunnerState):
        if state.stage.startswith("infer"):
            return

        valid_metrics = dict(state.metrics.valid_values)
        epoch_metrics = dict(state.metrics.epoch_values)
        checkpoint = utils.pack_checkpoint(
            model=state.model,
            criterion=state.criterion,
            optimizer=state.optimizer,
            scheduler=state.scheduler,
            epoch_metrics=epoch_metrics,
            valid_metrics=valid_metrics,
            stage=state.stage,
            epoch=state.epoch_log,
            checkpoint_data=state.checkpoint_data
        )
        self.process_checkpoint(
            logdir=state.logdir,
            checkpoint=checkpoint,
            is_best=state.metrics.is_best,
            main_metric=state.main_metric,
            minimize_metric=state.minimize_metric
        )

    def on_stage_end(self, state: RunnerState):
        print("Top best models:")
        top_best_metrics_str = "\n".join(
            [
                "{filepath}\t{metric:3.4f}".format(
                    filepath=filepath, metric=checkpoint_metric
                )
                for filepath, checkpoint_metric, _ in self.top_best_metrics
            ]
        )
        print(top_best_metrics_str)
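

# Usage sketch, not part of the original module: keep the three best
# checkpoints ranked by the runner's main metric. ``SupervisedRunner`` and
# the exact ``train`` signature are assumptions about the catalyst API of
# this era; ``model``, ``criterion``, ``optimizer`` and ``loaders`` are
# hypothetical placeholders supplied by the caller.
def _example_checkpoint_usage(model, criterion, optimizer, loaders):
    from catalyst.dl import SupervisedRunner  # assumed import path

    runner = SupervisedRunner()
    runner.train(
        model=model,
        criterion=criterion,
        optimizer=optimizer,
        loaders=loaders,
        logdir="./logs",
        num_epochs=10,
        callbacks=[CheckpointCallback(save_n_best=3)]
    )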


class IterationCheckpointCallback(BaseCheckpointCallback):
    """
    Iteration checkpoint callback to save your model/criterion/optimizer
    """

    def __init__(
        self,
        save_n_last: int = 3,
        num_iters: int = 100,
        stage_restart: bool = True,
        metric_filename: str = "_metrics_iter.json"
    ):
        """
        Args:
            save_n_last (int): number of last checkpoints to keep
            num_iters (int): save the checkpoint every ``num_iters`` iterations
            stage_restart (bool): restart the counter at each stage or not
            metric_filename (str): filename to save metrics
                in checkpoint folder. Must end with ``.json`` or ``.yml``
        """
        super().__init__(metric_filename)
        self.save_n_last = save_n_last
        self.num_iters = num_iters
        self.stage_restart = stage_restart
        self._iteration_counter = 0
        self.last_checkpoints = []
        self.epochs_metrics = []

    def get_checkpoint_suffix(self, checkpoint: dict) -> str:
        result = (
            f"{checkpoint['stage']}."
            f"epoch.{checkpoint['epoch']}."
            f"iter.{self._iteration_counter}"
        )
        return result

    def get_metric(self, **kwargs) -> Dict:
        n_last_checkpoints = [
            (Path(filepath).stem, batch_values)
            for (filepath, batch_values) in self.last_checkpoints
        ]
        all_epochs_metrics = [
            (f"epoch_{order_index}", valid_metric)
            for (order_index, valid_metric) in enumerate(self.epochs_metrics)
        ]
        metrics = OrderedDict(n_last_checkpoints + all_epochs_metrics)
        self.metrics = metrics
        return self.metrics

    def truncate_checkpoints(self, **kwargs) -> None:
        if len(self.last_checkpoints) > self.save_n_last:
            item = self.last_checkpoints.pop(0)
            top_filepath = item[0]
            os.remove(top_filepath)

    def process_checkpoint(
        self,
        logdir: str,
        checkpoint: Dict,
        batch_values: Dict[str, float]
    ):
        filepath = utils.save_checkpoint(
            logdir=Path(f"{logdir}/checkpoints/"),
            checkpoint=checkpoint,
            suffix=self.get_checkpoint_suffix(checkpoint),
            is_best=False,
            is_last=False
        )
        self.last_checkpoints.append((filepath, batch_values))
        self.truncate_checkpoints()
        self.epochs_metrics.append(batch_values)
        metrics = self.get_metric()
        self.save_metric(logdir, metrics)
        print(f"\nSaved checkpoint at {filepath}")

    def on_stage_start(self, state):
        if self.stage_restart:
            self._iteration_counter = 0

    def on_batch_end(self, state):
        self._iteration_counter += 1
        if self._iteration_counter % self.num_iters == 0:
            checkpoint = utils.pack_checkpoint(
                model=state.model,
                criterion=state.criterion,
                optimizer=state.optimizer,
                scheduler=state.scheduler,
                epoch_metrics=None,
                valid_metrics=None,
                stage=state.stage,
                epoch=state.epoch_log
            )
            self.process_checkpoint(
                logdir=state.logdir,
                checkpoint=checkpoint,
                batch_values=state.metrics.batch_values
            )
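

# Usage sketch, not part of the original module: snapshot every 500
# iterations, keeping only the two most recent snapshots, alongside the
# epoch-level ``CheckpointCallback``. Same API assumptions as the sketch
# above; all arguments are hypothetical placeholders.
def _example_iteration_checkpoint_usage(model, criterion, optimizer, loaders):
    from catalyst.dl import SupervisedRunner  # assumed import path

    runner = SupervisedRunner()
    runner.train(
        model=model,
        criterion=criterion,
        optimizer=optimizer,
        loaders=loaders,
        logdir="./logs",
        num_epochs=10,
        callbacks=[
            CheckpointCallback(save_n_best=1),
            IterationCheckpointCallback(save_n_last=2, num_iters=500)
        ]
    )
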
__all__ = ["CheckpointCallback", "IterationCheckpointCallback"]