Source code for catalyst.callbacks.early_stop
from typing import TYPE_CHECKING
from catalyst.core.callback import Callback, CallbackNode, CallbackOrder
if TYPE_CHECKING:
from catalyst.core.runner import IRunner
class CheckRunCallback(Callback):
    """Runs only a part of the pipeline from the ``Experiment``.

    The callback raises ``runner.need_early_stop`` once a configured number
    of batch steps or epochs has been reached.

    Minimal working example (Notebook API):

    .. code-block:: python

        import torch
        from torch.utils.data import DataLoader, TensorDataset
        from catalyst import dl

        # data
        num_samples, num_features = int(1e4), int(1e1)
        X, y = torch.rand(num_samples, num_features), torch.rand(num_samples)
        dataset = TensorDataset(X, y)
        loader = DataLoader(dataset, batch_size=32, num_workers=1)
        loaders = {"train": loader, "valid": loader}

        # model, criterion, optimizer, scheduler
        model = torch.nn.Linear(num_features, 1)
        criterion = torch.nn.MSELoss()
        optimizer = torch.optim.Adam(model.parameters())
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [3, 6])

        # model training
        runner = dl.SupervisedRunner()
        runner.train(
            model=model,
            criterion=criterion,
            optimizer=optimizer,
            scheduler=scheduler,
            loaders=loaders,
            logdir="./logdir",
            num_epochs=8,
            verbose=True,
            callbacks=[
                dl.CheckRunCallback(num_batch_steps=3, num_epoch_steps=3)
            ]
        )
    """

    def __init__(self, num_batch_steps: int = 3, num_epoch_steps: int = 2):
        """
        Args:
            num_batch_steps: batch-step threshold; a stop is requested at
                the end of a batch once ``runner.loader_batch_step``
                reaches this value
            num_epoch_steps: epoch threshold; a stop is requested at
                the end of an epoch once ``runner.epoch`` reaches this value
        """
        super().__init__(order=CallbackOrder.external, node=CallbackNode.all)
        self.num_epoch_steps = num_epoch_steps
        self.num_batch_steps = num_batch_steps

    def on_epoch_end(self, runner: "IRunner"):
        """Request a stop when the epoch threshold has been reached.

        Args:
            runner: current runner
        """
        epoch_limit_reached = runner.epoch >= self.num_epoch_steps
        if epoch_limit_reached:
            runner.need_early_stop = True

    def on_batch_end(self, runner: "IRunner"):
        """Request a stop when the batch-step threshold has been reached.

        Args:
            runner: current runner
        """
        batch_limit_reached = runner.loader_batch_step >= self.num_batch_steps
        if batch_limit_reached:
            runner.need_early_stop = True
class EarlyStoppingCallback(Callback):
    """Early exit based on metric.

    Minimal working example (Notebook API):

    .. code-block:: python

        import torch
        from torch.utils.data import DataLoader, TensorDataset
        from catalyst import dl

        # data
        num_samples, num_features = int(1e4), int(1e1)
        X, y = torch.rand(num_samples, num_features), torch.rand(num_samples)
        dataset = TensorDataset(X, y)
        loader = DataLoader(dataset, batch_size=32, num_workers=1)
        loaders = {"train": loader, "valid": loader}

        # model, criterion, optimizer, scheduler
        model = torch.nn.Linear(num_features, 1)
        criterion = torch.nn.MSELoss()
        optimizer = torch.optim.Adam(model.parameters())
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [3, 6])

        # model training
        runner = dl.SupervisedRunner()
        runner.train(
            model=model,
            criterion=criterion,
            optimizer=optimizer,
            scheduler=scheduler,
            loaders=loaders,
            logdir="./logdir",
            num_epochs=8,
            verbose=True,
            callbacks=[
                dl.EarlyStoppingCallback(patience=2, metric="loss", minimize=True)
            ]
        )

    Example of usage in config API:

    .. code-block:: yaml

        stages:
          ...
          stage_N:
            ...
            callbacks_params:
              ...
              early_stopping:
                callback: EarlyStoppingCallback
                # arguments for EarlyStoppingCallback
                patience: 5
                metric: my_metric
                minimize: true
          ...
    """

    def __init__(
        self,
        patience: int,
        metric: str = "loss",
        minimize: bool = True,
        min_delta: float = 1e-6,
    ):
        """
        Args:
            patience: number of epochs with no improvement
                after which training will be stopped.
            metric: metric name to use for early stopping, default
                is ``"loss"``.
            minimize: if ``True`` then expected that metric should
                decrease and early stopping will be performed only when metric
                stops decreasing. If ``False`` then expected
                that metric should increase. Default value ``True``.
            min_delta: minimum change in the monitored metric
                to qualify as an improvement, i.e. an absolute change
                of less than min_delta, will count as no improvement,
                default value is ``1e-6``.
        """
        super().__init__(order=CallbackOrder.external, node=CallbackNode.all)
        self.best_score = None
        self.metric = metric
        self.patience = patience
        self.num_bad_epochs = 0
        # The comparison settings are kept as plain attributes and evaluated
        # by the ``is_better`` method. Storing lambda closures on the instance
        # (as was done before) makes the callback impossible to pickle.
        self.minimize = minimize
        self.min_delta = min_delta

    def is_better(self, score, best):
        """Tell whether ``score`` improves on ``best`` by at least ``min_delta``.

        Args:
            score: metric value of the current epoch
            best: best metric value seen so far

        Returns:
            result of ``score <= best - min_delta`` when minimizing,
            otherwise of ``score >= best + min_delta``
        """
        if self.minimize:
            return score <= (best - self.min_delta)
        return score >= (best + self.min_delta)

    def on_epoch_end(self, runner: "IRunner") -> None:
        """Check if should be performed early stopping.

        Stages whose name starts with ``"infer"`` are ignored. Otherwise the
        monitored value is read from ``runner.valid_metrics``; the counter of
        bad epochs is reset on the first epoch or on improvement and
        incremented in any other case. Once the counter reaches ``patience``
        a message is printed and ``runner.need_early_stop`` is set.

        Args:
            runner: current runner
        """
        if runner.stage.startswith("infer"):
            return

        score = runner.valid_metrics[self.metric]
        if self.best_score is None or self.is_better(score, self.best_score):
            self.num_bad_epochs = 0
            self.best_score = score
        else:
            self.num_bad_epochs += 1

        if self.num_bad_epochs >= self.patience:
            print(f"Early stop at {runner.epoch} epoch")
            runner.need_early_stop = True
# Public names of this module (what ``from <module> import *`` exposes).
__all__ = ["CheckRunCallback", "EarlyStoppingCallback"]