Source code for catalyst.dl.experiment.supervised

from collections import OrderedDict

from torch.optim.lr_scheduler import ReduceLROnPlateau

from catalyst.dl import (
from import Criterion, Optimizer, Scheduler

from .core import Experiment

class SupervisedExperiment(Experiment):
    """
    Supervised experiment.

    The main difference from ``Experiment`` is that this class adds several
    training callbacks by default if you have not provided them yourself.

    Callbacks added by this class (only for stages whose name does not start
    with ``infer``):

    - ``CriterionCallback`` (key ``"_criterion"``): measures loss with the
      specified ``criterion``; added when ``criterion`` is a ``Criterion``.
    - ``OptimizerCallback`` (key ``"_optimizer"``): abstraction over the
      ``optimizer`` step; added when ``optimizer`` is an ``Optimizer``.
    - ``SchedulerCallback`` (key ``"_scheduler"``): performs
      ``lr_scheduler.step``; added only when you provided a scheduler
      (a ``Scheduler`` or ``ReduceLROnPlateau``) to your experiment.

    A default callback is skipped when an instance of the same callback
    class (or a subclass of it) is already present.

    Other defaults are not added by this class:

    - ``CheckpointCallback``: saves model and optimizer state each epoch.
    - ``ConsoleLogger``: translates ``state.*_metrics`` to the console and a
      text file.
    - ``TensorboardLogger``: writes ``state.*_metrics`` to TensorBoard.
    - ``RaiseExceptionCallback``: re-raises exceptions when needed.

    These presumably come from the parent ``Experiment.get_callbacks``.
    NOTE(review): confirm against ``.core.Experiment``.
    """

    def get_callbacks(self, stage: str) -> "OrderedDict[str, Callback]":
        """
        Override of ``BaseExperiment.get_callbacks`` method.

        Adds the default criterion/optimizer/scheduler callbacks when they
        are missing.

        Args:
            stage (str): name of the stage. Start it with ``infer`` if you
                don't need the default callbacks, as they are only required
                for training stages.

        Returns:
            (OrderedDict[str, Callback]): ordered dictionary of callbacks for
            the experiment. It is a new (shallow-copied) mapping; the mapping
            returned by the parent class is left unmodified.
        """
        # Copy the parent's mapping before adding defaults. If the parent
        # hands back a stored dict (e.g. the experiment's own callbacks), we
        # must not mutate it. Otherwise training defaults would leak into
        # later calls, including ``infer`` stages.
        callbacks = OrderedDict(super().get_callbacks(stage=stage) or ())

        if stage.startswith("infer"):
            return callbacks

        # ``isinstance`` is already False for ``None``, so no separate
        # ``is not None`` guard is needed.
        default_callbacks = []
        if isinstance(self._criterion, Criterion):
            default_callbacks.append(("_criterion", CriterionCallback))
        if isinstance(self._optimizer, Optimizer):
            default_callbacks.append(("_optimizer", OptimizerCallback))
        # ReduceLROnPlateau is checked explicitly, presumably because it is
        # not a ``Scheduler`` subclass in the supported torch versions.
        if isinstance(self._scheduler, (Scheduler, ReduceLROnPlateau)):
            default_callbacks.append(("_scheduler", SchedulerCallback))

        for callback_name, callback_fn in default_callbacks:
            # Respect user-provided callbacks: skip the default when any
            # callback of that class (or a subclass) is already registered.
            is_already_present = any(
                isinstance(x, callback_fn) for x in callbacks.values()
            )
            if not is_already_present:
                callbacks[callback_name] = callback_fn()
        return callbacks


__all__ = ["SupervisedExperiment"]