Source code for catalyst.callbacks.checkpoint

from typing import Any, List
from collections import namedtuple
import json
import os
import shutil

import torch

from catalyst.core.callback import ICheckpointCallback
from catalyst.core.runner import IRunner
from catalyst.utils import (
    load_checkpoint,
    pack_checkpoint,
    save_checkpoint,
    unpack_checkpoint,
)

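# Internal record for one saved checkpoint: the saved object (model or state
# dicts), the path it was written to, and the metric value used for ranking.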
Checkpoint = namedtuple("Checkpoint", field_names=["obj", "logpath", "metric"])


class CheckpointCallback(ICheckpointCallback):
    """Checkpoint callback to save/restore your model/runner.

    Args:
        logdir: directory to store checkpoints
        loader_key: loader key for best model selection
            (based on the metric score over the dataset)
        metric_key: metric key for best model selection
            (based on the metric score over the dataset)
        minimize: boolean flag to minimize the required metric
        topk: number of best checkpoints to keep
        mode: checkpoint type to save, ``model`` or ``runner``
            (default: ``model``)
        save_last: boolean flag to save an extra last checkpoint as ``{mode}.last.pth``
        save_best: boolean flag to save an extra best checkpoint as ``{mode}.best.pth``
        resume_model: path to model checkpoint to load on experiment start
        resume_runner: path to runner checkpoint to load on experiment start
        load_best_on_end: boolean flag to load the best model on experiment end
    """

    def __init__(
        self,
        logdir: str,
        loader_key: str = None,
        metric_key: str = None,
        minimize: bool = None,
        topk: int = 1,
        mode: str = "model",
        save_last: bool = True,
        save_best: bool = True,
        resume_model: str = None,
        resume_runner: str = None,
        load_best_on_end: bool = False,
    ):
        """Init."""
        super().__init__()
        assert topk >= 1
        assert mode in (
            "model",
            "runner",
        ), "`CheckpointCallback` works only in `model` or `runner` mode."
        if minimize is not None:
            assert metric_key is not None, "please define the metric to track"
            self._minimize = minimize
            # track the given metric and keep the best checkpoints
            self.on_epoch_end = self.on_epoch_end_best
        else:
            self._minimize = False
            # no metric given -- rank checkpoints by epoch number instead
            self.on_epoch_end = self.on_epoch_end_last
        self.logdir = logdir
        self.loader_key = loader_key
        self.metric_key = metric_key
        self.topk = topk
        self._storage: List[Checkpoint] = []
        self.save_last = save_last
        self.save_best = save_best
        self.mode = mode
        self._resume_model = resume_model
        self._resume_runner = resume_runner
        self.load_best_on_end = load_best_on_end
        os.makedirs(self.logdir, exist_ok=True)

    def _save(self, runner: "IRunner", obj: Any, logprefix: str) -> str:
        logpath = f"{logprefix}.pth"
        if self.mode == "model":
            if issubclass(obj.__class__, torch.nn.Module):
                # sync all processes and unwrap the (possibly distributed)
                # model before saving its state dict
                runner.engine.wait_for_everyone()
                obj = runner.engine.unwrap_model(obj)
                runner.engine.save(obj.state_dict(), logpath)
            elif isinstance(obj, dict):
                checkpoint = pack_checkpoint(model=obj)
                save_checkpoint(checkpoint, logpath)
            else:
                raise NotImplementedError()
        else:
            # ``runner`` mode: pack model, criterion, optimizer, scheduler
            # and step counters into a single checkpoint
            checkpoint = pack_checkpoint(**obj)
            save_checkpoint(checkpoint, logpath)
        return logpath

    def _load(
        self,
        runner: "IRunner",
        resume_logpath: Any = None,
        resume_model: str = None,
        resume_runner: str = None,
    ):
        if resume_logpath is not None:
            runner.engine.wait_for_everyone()
            if self.mode == "model":
                try:
                    unwrapped_model = runner.engine.unwrap_model(runner.model)
                    unwrapped_model.load_state_dict(load_checkpoint(resume_logpath))
                except BaseException:
                    checkpoint = load_checkpoint(resume_logpath)
                    unpack_checkpoint(checkpoint=checkpoint, model=runner.model)
            else:
                checkpoint = load_checkpoint(resume_logpath)
                unpack_checkpoint(checkpoint=checkpoint, model=runner.model)

        if resume_runner is not None:
            runner.engine.wait_for_everyone()
            checkpoint = load_checkpoint(resume_runner)
            unpack_checkpoint(
                checkpoint=checkpoint,
                model=runner.model,
                criterion=runner.criterion,
                optimizer=runner.optimizer,
                scheduler=runner.scheduler,
            )
            # runner checkpoints also restore the experiment step counters
            runner.epoch_step = checkpoint["epoch_step"]
            runner.batch_step = checkpoint["batch_step"]
            runner.sample_step = checkpoint["sample_step"]

        if resume_model is not None:
            runner.engine.wait_for_everyone()
            unwrapped_model = runner.engine.unwrap_model(runner.model)
            unwrapped_model.load_state_dict(load_checkpoint(resume_model))

    def _handle_epoch(self, runner: "IRunner", score: float):
        if self.mode == "model":
            obj = runner.model
        else:
            obj = dict(  # noqa: C408
                model=runner.model,
                criterion=runner.criterion,
                optimizer=runner.optimizer,
                scheduler=runner.scheduler,
                epoch_step=runner.epoch_step,
                batch_step=runner.batch_step,
                sample_step=runner.sample_step,
            )

        if self.save_last:
            logprefix = f"{self.logdir}/{self.mode}.last"
            self._save(runner, obj, logprefix)

        logprefix = f"{self.logdir}/{self.mode}.{runner.epoch_step:04d}"
        logpath = self._save(runner, obj, logprefix)
        # keep the storage sorted by metric, best checkpoint first
        self._storage.append(Checkpoint(obj=obj, logpath=logpath, metric=score))
        self._storage = sorted(
            self._storage, key=lambda x: x.metric, reverse=not self._minimize
        )
        # once there are more than ``topk`` checkpoints, drop the worst one
        if len(self._storage) > self.topk:
            last_item = self._storage.pop(-1)
            if os.path.isfile(last_item.logpath):
                try:
                    os.remove(last_item.logpath)
                except OSError:
                    pass
            elif os.path.isdir(last_item.logpath):
                shutil.rmtree(last_item.logpath, ignore_errors=True)

        # dump a human-readable summary of the kept checkpoints
        with open(f"{self.logdir}/{self.mode}.storage.json", "w") as fout:
            stats = {
                "logdir": str(self.logdir),
                "topk": self.topk,
                "loader_key": self.loader_key,
                "metric_key": self.metric_key,
                "minimize": self._minimize,
            }
            storage = [
                {"logpath": str(x.logpath), "metric": x.metric}
                for x in self._storage
            ]
            stats["storage"] = storage
            json.dump(stats, fout, indent=2, ensure_ascii=False)

    def on_experiment_start(self, runner: "IRunner") -> None:
        """Event handler."""
        self._storage: List[Checkpoint] = []
        self._load(
            runner=runner,
            resume_runner=self._resume_runner,
            resume_model=self._resume_model,
        )
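
    # Note: ``__init__`` binds exactly one of the two handlers below to
    # ``on_epoch_end``: the metric-based ``on_epoch_end_best`` when
    # ``minimize`` is given, ``on_epoch_end_last`` otherwise.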
    def on_epoch_end_best(self, runner: "IRunner") -> None:
        """Event handler."""
        if self.loader_key is not None:
            score = runner.epoch_metrics[self.loader_key][self.metric_key]
        else:
            score = runner.epoch_metrics[self.metric_key]

        self._handle_epoch(runner=runner, score=score)

        if self.save_best:
            best_logprefix = f"{self.logdir}/{self.mode}.best"
            self._save(runner, self._storage[0].obj, best_logprefix)

    def on_epoch_end_last(self, runner: "IRunner") -> None:
        """Event handler."""
        self._handle_epoch(runner=runner, score=runner.epoch_step)

    def on_experiment_end(self, runner: "IRunner") -> None:
        """Event handler."""
        # report the kept checkpoints from the main process only
        if runner.engine.process_index == 0:
            log_message = "Top models:\n"
            log_message += "\n".join(
                [
                    f"{checkpoint.logpath}\t{checkpoint.metric:3.4f}"
                    for checkpoint in self._storage
                ]
            )
            print(log_message)
        if self.load_best_on_end:
            self._load(runner=runner, resume_logpath=self._storage[0].logpath)

__all__ = ["CheckpointCallback"]
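
A minimal usage sketch (not part of the module source above): it wires
``CheckpointCallback`` into a ``catalyst.dl.SupervisedRunner`` run. The toy
model, loaders, and the ``loader_key``/``metric_key`` values are illustrative
assumptions; only the callback arguments come from this file.

import torch
from torch.utils.data import DataLoader, TensorDataset

from catalyst import dl
from catalyst.callbacks.checkpoint import CheckpointCallback

# toy data and model -- placeholders for a real pipeline
X, y = torch.randn(64, 8), torch.randint(0, 2, (64,))
loaders = {
    "train": DataLoader(TensorDataset(X, y), batch_size=16),
    "valid": DataLoader(TensorDataset(X, y), batch_size=16),
}
model = torch.nn.Linear(8, 2)

runner = dl.SupervisedRunner()
runner.train(
    model=model,
    criterion=torch.nn.CrossEntropyLoss(),
    optimizer=torch.optim.Adam(model.parameters(), lr=1e-3),
    loaders=loaders,
    num_epochs=3,
    callbacks=[
        # keep the 3 best checkpoints by validation loss;
        # also writes model.last.pth and model.best.pth
        CheckpointCallback(
            logdir="./logs/checkpoints",
            loader_key="valid",
            metric_key="loss",
            minimize=True,
            topk=3,
        )
    ],
)

Each epoch the callback also rewrites ``{logdir}/model.storage.json`` with the
ranked list of kept checkpoints, so the best checkpoint path can be recovered
without parsing filenames.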