Source code for catalyst.contrib.dl.callbacks.optuna_callback

import optuna

from catalyst.core import Callback, CallbackOrder, IRunner

class OptunaCallback(Callback):
    """Optuna callback for pruning unpromising runs.

    After every epoch the callback reports the runner's main validation
    metric to the wrapped ``optuna.Trial`` and raises
    ``optuna.TrialPruned`` when the trial says the run should be pruned.

    .. code-block:: python

        import optuna

        from catalyst.dl import SupervisedRunner
        from catalyst.dl.callbacks import OptunaCallback

        # some python code ...

        def objective(trial: optuna.Trial):
            # standard optuna code for model and/or optimizer suggestion ...
            runner = SupervisedRunner()
            runner.train(
                model=model,
                loaders=loaders,
                criterion=criterion,
                optimizer=optimizer,
                callbacks=[
                    OptunaCallback(trial)
                    # some other callbacks ...
                ],
                num_epochs=num_epochs,
            )
            return runner.best_valid_metrics[runner.main_metric]

        study = optuna.create_study()
        study.optimize(objective, n_trials=100, timeout=600)

    Config API is not supported.
    """

    def __init__(self, trial: optuna.Trial):
        """
        This callback can be used for early stopping (pruning)
        unpromising runs.

        Args:
            trial: Optuna.Trial for experiment.
        """
        super().__init__(CallbackOrder.External)
        self.trial = trial

    def on_epoch_end(self, runner: "IRunner"):
        """
        On epoch end action.

        Reports the current value of the main validation metric to the
        trial, then asks the trial whether the current run should be
        pruned at the current epoch.

        Args:
            runner: runner for current experiment

        Raises:
            TrialPruned: if current run should be pruned
        """
        metric_value = runner.valid_metrics[runner.main_metric]
        # The intermediate value must be reported before ``should_prune``
        # is queried; otherwise the trial has nothing to base its pruning
        # decision on for this step.
        self.trial.report(metric_value, step=runner.epoch)
        if self.trial.should_prune():
            message = "Trial was pruned at epoch {}.".format(runner.epoch)
            raise optuna.TrialPruned(message)