# Source code for catalyst.callbacks.optuna

from typing import TYPE_CHECKING

from catalyst.core.callback import Callback, CallbackOrder
from catalyst.extras.metric_handler import MetricHandler
from catalyst.settings import SETTINGS

if SETTINGS.optuna_required:
    import optuna

    from catalyst.core.runner import IRunner

class OptunaPruningCallback(Callback):
    """Optuna callback for pruning unpromising runs.

    This callback can be used for early stopping (pruning) unpromising runs.

    Args:
        trial: Optuna.Trial for the experiment.
        loader_key: loader key for best model selection
            (based on metric score over the dataset)
        metric_key: metric key for best model selection
            (based on metric score over the dataset)
        minimize: boolean flag to minimize the required metric
        min_delta: minimal delta for metric improve

    .. code-block:: python

        import optuna

        from catalyst.dl import SupervisedRunner, OptunaPruningCallback

        # some python code ...

        def objective(trial: optuna.Trial):
            # standard optuna code for model and/or optimizer suggestion ...
            runner = SupervisedRunner()
            runner.train(
                model=model,
                loaders=loaders,
                criterion=criterion,
                optimizer=optimizer,
                callbacks=[
                    OptunaPruningCallback(trial)
                    # some other callbacks ...
                ],
                num_epochs=num_epochs,
            )
            return runner.best_valid_metrics[runner.valid_metric]

        study = optuna.create_study()
        study.optimize(objective, n_trials=100, timeout=600)
    """

    def __init__(
        self,
        trial: "optuna.Trial",
        loader_key: str,
        metric_key: str,
        minimize: bool,
        min_delta: float = 1e-6,
    ):
        """Init."""
        super().__init__(CallbackOrder.External)
        self.trial = trial
        self.loader_key = loader_key
        self.metric_key = metric_key
        self.minimize = minimize
        # MetricHandler compares two scores with the requested direction/delta
        self.is_better = MetricHandler(minimize=minimize, min_delta=min_delta)
        # best score observed so far; ``None`` until the first epoch finishes
        self.best_score = None

    def on_epoch_end(self, runner: "IRunner"):
        """Considering prune or not to prune current run at the end of the epoch.

        Args:
            runner: runner for current experiment

        Raises:
            TrialPruned: if current run should be pruned
        """
        score = runner.epoch_metrics[self.loader_key][self.metric_key]
        # track the best score seen so far according to the metric direction
        if self.best_score is None or self.is_better(score, self.best_score):
            self.best_score = score
        # report the best-so-far score so Optuna's pruner judges the trial
        # on its best achieved value rather than a noisy per-epoch one
        self.trial.report(self.best_score, step=runner.epoch_step)
        # @TODO: hack — expose best score on the trial for later inspection
        self.trial.best_score = self.best_score
        if self.trial.should_prune():
            message = "Trial was pruned at epoch {}.".format(runner.epoch_step)
            raise optuna.TrialPruned(message)
# Public API of this module.
__all__ = ["OptunaPruningCallback"]