Source code for catalyst.contrib.callbacks.optuna_callback

from typing import TYPE_CHECKING

import optuna

from catalyst.core.callback import Callback, CallbackOrder

    from catalyst.core.runner import IRunner

[docs]class OptunaPruningCallback(Callback): """ Optuna callback for pruning unpromising runs .. code-block:: python import optuna from catalyst.dl import SupervisedRunner, OptunaPruningCallback # some python code ... def objective(trial: optuna.Trial): # standard optuna code for model and/or optimizer suggestion ... runner = SupervisedRunner() runner.train( model=model, loaders=loaders, criterion=criterion, optimizer=optimizer, callbacks=[ OptunaPruningCallback(trial) # some other callbacks ... ], num_epochs=num_epochs, ) return runner.best_valid_metrics[runner.main_metric] study = optuna.create_study() study.optimize(objective, n_trials=100, timeout=600) Config API is supported through `catalyst-dl tune` command. """
[docs] def __init__(self, trial: optuna.Trial = None): """ This callback can be used for early stopping (pruning) unpromising runs. Args: trial: Optuna.Trial for experiment. """ super().__init__(CallbackOrder.External) self.trial = trial
[docs] def on_stage_start(self, runner: "IRunner"): """ On stage start hook. Takes ``optuna_trial`` from ``Experiment`` for future usage if needed. Args: runner: runner for current experiment Raises: NotImplementedError: if no Optuna trial was found on stage start. """ trial = runner.experiment.trial if ( self.trial is None and trial is not None and isinstance(trial, optuna.Trial) ): self.trial = trial if self.trial is None: raise NotImplementedError("No Optuna trial found for logging")
[docs] def on_epoch_end(self, runner: "IRunner"): """ On epoch end hook. Considering prune or not to prune current run at current epoch. Args: runner: runner for current experiment Raises: TrialPruned: if current run should be pruned """ metric_value = runner.valid_metrics[runner.main_metric], step=runner.epoch) if self.trial.should_prune(): message = "Trial was pruned at epoch {}.".format(runner.epoch) raise optuna.TrialPruned(message)
OptunaCallback = OptunaPruningCallback __all__ = ["OptunaCallback", "OptunaPruningCallback"]