Shortcuts

Source code for catalyst.experiments.supervised

from collections import OrderedDict

from torch.optim.lr_scheduler import ReduceLROnPlateau

from catalyst.callbacks.criterion import CriterionCallback
from catalyst.callbacks.optimizer import (
    AMPOptimizerCallback,
    IOptimizerCallback,
    OptimizerCallback,
)
from catalyst.callbacks.scheduler import ISchedulerCallback, SchedulerCallback
from catalyst.core.callback import Callback
from catalyst.core.functional import check_callback_isinstance
from catalyst.experiments.experiment import Experiment
from catalyst.typing import Criterion, Optimizer, Scheduler
from catalyst.utils.distributed import check_amp_available


class SupervisedExperiment(Experiment):
    """
    Supervised experiment.

    The main difference from ``Experiment`` is that it adds several
    callbacks by default if you have not provided them yourself.

    Callbacks added by this class (only for stages whose name does not
    start with ``infer``):

        CriterionCallback: measures loss with the specified ``criterion``.
        OptimizerCallback: abstraction over the ``optimizer`` step
            (``AMPOptimizerCallback`` is used instead when ``"amp"`` is set
            in ``distributed_params`` and AMP is available).
        SchedulerCallback: only if you provided a scheduler to your
            experiment; performs ``lr_scheduler.step``.

    Other defaults documented for this experiment are not added by this
    class's ``get_callbacks``; presumably they are provided by the parent
    ``Experiment``:

        CheckpointCallback: saves model and optimizer state each epoch
            (callback to save/restore your model/criterion/optimizer/metrics).
        ConsoleLogger: translates ``runner.*_metrics`` to console and text file.
        TensorboardLogger: writes ``runner.*_metrics`` to tensorboard.
        RaiseExceptionCallback: will raise exception if needed.
    """

    def get_callbacks(self, stage: str) -> "OrderedDict[str, Callback]":
        """
        Override of ``Experiment.get_callbacks`` method.

        Adds several callbacks by default in case they are missing.

        Args:
            stage: name of stage. It should start with `infer` if you
                don't need default callbacks, as they are required only
                for training stages.

        Returns:
            OrderedDict[str, Callback]: Ordered dictionary of callbacks
                for experiment
        """
        # Shallow-copy the parent's mapping so the defaults inserted below
        # do not leak into whatever object the parent returned. That object
        # may be shared experiment state reused by later stages (including
        # `infer` ones). NOTE(review): confirm against Experiment.get_callbacks.
        callbacks = OrderedDict(super().get_callbacks(stage=stage) or ())

        if stage.startswith("infer"):
            # Inference stages get no criterion/optimizer/scheduler defaults.
            return callbacks

        is_amp_enabled = (
            self.distributed_params.get("amp", False) and check_amp_available()
        )
        optimizer_cls = (
            AMPOptimizerCallback if is_amp_enabled else OptimizerCallback
        )

        # default_callbacks = [(name, interface_cls, factory)]
        # When ``interface_cls`` is None, the factory class itself is the
        # type looked for among the already-configured callbacks.
        default_callbacks = []
        if isinstance(self._criterion, Criterion):
            default_callbacks.append(("_criterion", None, CriterionCallback))
        if isinstance(self._optimizer, Optimizer):
            default_callbacks.append(
                ("_optimizer", IOptimizerCallback, optimizer_cls)
            )
        # ReduceLROnPlateau is listed explicitly alongside ``Scheduler``.
        if isinstance(self._scheduler, (Scheduler, ReduceLROnPlateau)):
            default_callbacks.append(
                ("_scheduler", ISchedulerCallback, SchedulerCallback)
            )

        for callback_name, callback_interface, callback_fn in default_callbacks:
            callback_interface = callback_interface or callback_fn
            # Only add the default if no user-provided callback already
            # satisfies the same interface.
            is_already_present = any(
                check_callback_isinstance(x, callback_interface)
                for x in callbacks.values()
            )
            if not is_already_present:
                callbacks[callback_name] = callback_fn()
        return callbacks
# Names exported by ``from catalyst.experiments.supervised import *``.
__all__ = [
    "SupervisedExperiment",
]