"""Source code for catalyst.core.callbacks.early_stop."""
from catalyst.core import Callback, CallbackNode, CallbackOrder, State
class CheckRunCallback(Callback):
    """Request an early stop after a fixed number of batches / epochs.

    The callback only ever sets ``state.need_early_stop = True``; it never
    clears the flag. How the runner reacts to the flag is decided outside
    this file. Judging by the name, it is presumably meant for quick
    "does the pipeline run end-to-end" checks (TODO confirm).
    """

    def __init__(self, num_batch_steps: int = 3, num_epoch_steps: int = 2):
        """
        Args:
            num_batch_steps: once ``state.loader_step`` is ``>=`` this value
                at the end of a batch, a stop is requested.
            num_epoch_steps: once ``state.epoch`` is ``>=`` this value at
                the end of an epoch, a stop is requested.
        """
        super().__init__(order=CallbackOrder.External, node=CallbackNode.All)
        self.num_batch_steps = num_batch_steps
        self.num_epoch_steps = num_epoch_steps

    def on_batch_end(self, state: State):
        """Flag an early stop when the loader step limit has been reached."""
        limit_hit = state.loader_step >= self.num_batch_steps
        if not limit_hit:
            return
        state.need_early_stop = True

    def on_epoch_end(self, state: State):
        """Flag an early stop when the epoch limit has been reached."""
        limit_hit = state.epoch >= self.num_epoch_steps
        if not limit_hit:
            return
        state.need_early_stop = True
class EarlyStoppingCallback(Callback):
    """Request an early stop when a validation metric stops improving.

    At the end of every epoch, except in stages whose name starts with
    ``"infer"``, the callback:

    - reads ``state.valid_metrics[metric]``;
    - compares it with the best value seen so far;
    - sets ``state.need_early_stop = True`` once the metric has failed to
      improve by at least ``min_delta`` for ``patience`` consecutive epochs.
    """

    def __init__(
        self,
        patience: int,
        metric: str = "loss",
        minimize: bool = True,
        min_delta: float = 1e-6,
    ):
        """
        Args:
            patience: number of consecutive epochs without improvement after
                which an early stop is requested.
            metric: key looked up in ``state.valid_metrics`` each epoch.
            minimize: if ``True``, a lower metric value is better;
                otherwise a higher value is better.
            min_delta: minimum change relative to the best value seen so far
                that counts as an improvement.
        """
        super().__init__(order=CallbackOrder.External, node=CallbackNode.All)
        self.best_score = None
        self.metric = metric
        self.patience = patience
        self.num_bad_epochs = 0
        if minimize:
            self.is_better = lambda score, best: score <= (best - min_delta)
        else:
            self.is_better = lambda score, best: score >= (best + min_delta)

    def on_epoch_end(self, state: State) -> None:
        """Update the bad-epoch counter and request a stop if patience ran out.

        Stages whose name starts with ``"infer"`` are skipped entirely.
        """
        if state.stage_name.startswith("infer"):
            return

        score = state.valid_metrics[self.metric]
        # The first observed score is by definition the best so far, so it
        # must reset the counter. Comparing it against itself (as before)
        # counts it as a bad epoch whenever ``min_delta > 0``, which
        # effectively shortened the patience by one epoch.
        if self.best_score is None or self.is_better(score, self.best_score):
            self.num_bad_epochs = 0
            self.best_score = score
        else:
            self.num_bad_epochs += 1

        if self.num_bad_epochs >= self.patience:
            print(f"Early stop at {state.epoch} epoch")
            state.need_early_stop = True