Source code for catalyst.dl.experiment.supervised

from collections import OrderedDict

from catalyst.dl.callbacks import (
    CheckpointCallback, ConsoleLogger, CriterionCallback, OptimizerCallback,
    RaiseExceptionCallback, SchedulerCallback, TensorboardLogger,
    VerboseLogger
)
from catalyst.dl.core import Callback
from .base import BaseExperiment


class SupervisedExperiment(BaseExperiment):
    """Experiment that completes user callbacks with supervised defaults.

    On top of the callbacks given to the experiment, ``get_callbacks`` adds
    the standard criterion/optimizer/scheduler, checkpoint, logging and
    exception-raising callbacks, unless a callback of the same class is
    already present.
    """

    def get_callbacks(self, stage: str) -> "OrderedDict[str, Callback]":
        """Return the callbacks to run for ``stage``.

        The result starts from the user-provided callbacks
        (``self._callbacks``) and is extended with these defaults, in order:

        * ``"verbose"`` (``VerboseLogger``) when ``self._verbose`` is truthy;
        * for stages whose name does not start with ``"infer"``:
          ``"_criterion"``, ``"_optimizer"`` and ``"_scheduler"`` (each only
          when ``self._criterion`` / ``self._optimizer`` /
          ``self._scheduler`` is not ``None``), then ``"_saver"``
          (``CheckpointCallback``), ``"console"`` (``ConsoleLogger``) and
          ``"tensorboard"`` (``TensorboardLogger``);
        * ``"exception"`` (``RaiseExceptionCallback``) for every stage.

        A default is skipped when any callback already in the result is an
        instance of its class, so user-supplied callbacks take precedence.

        Args:
            stage: name of the stage the callbacks are requested for.

        Returns:
            A new ordered mapping ``name -> callback``. ``self._callbacks``
            itself is not modified.
        """
        # Work on a shallow copy. This method is called once per stage, and
        # writing into ``self._callbacks`` would leak training-only defaults
        # (criterion/optimizer/scheduler callbacks) into later stages such
        # as inference, and would reuse the same default instances.
        callbacks = OrderedDict(self._callbacks or ())

        # Each entry: (callback name, callback class, whether to add it).
        defaults = []
        if self._verbose:
            defaults.append(("verbose", VerboseLogger, True))
        if not stage.startswith("infer"):
            defaults.extend([
                ("_criterion", CriterionCallback,
                 self._criterion is not None),
                ("_optimizer", OptimizerCallback,
                 self._optimizer is not None),
                ("_scheduler", SchedulerCallback,
                 self._scheduler is not None),
                ("_saver", CheckpointCallback, True),
                ("console", ConsoleLogger, True),
                ("tensorboard", TensorboardLogger, True),
            ])
        defaults.append(("exception", RaiseExceptionCallback, True))

        for name, callback_cls, enabled in defaults:
            if not enabled:
                continue
            # Check against the growing result so defaults added earlier in
            # this loop are also taken into account.
            already_present = any(
                isinstance(existing, callback_cls)
                for existing in callbacks.values()
            )
            if not already_present:
                callbacks[name] = callback_cls()
        return callbacks
# Public API of this module: only the experiment class is exported.
__all__ = ["SupervisedExperiment"]