Source code for catalyst.dl.experiment.supervised

from collections import OrderedDict

from catalyst.dl.callbacks import (
    CheckpointCallback, ConsoleLogger, CriterionCallback, OptimizerCallback,
    RaiseExceptionCallback, SchedulerCallback, TensorboardLogger,
    VerboseLogger
)
from catalyst.dl.core import Callback
from .base import BaseExperiment


class SupervisedExperiment(BaseExperiment):
    """
    Supervised experiment, used mostly in the Notebook API.

    The main difference from ``BaseExperiment`` is that it adds several
    callbacks by default if you have not provided callbacks of those types
    yourself. The default callbacks are:

    VerboseLogger:
        added only when ``self._verbose`` is truthy.
    CriterionCallback:
        measures loss with the specified ``criterion``.
    OptimizerCallback:
        abstraction over the ``optimizer`` step.
    SchedulerCallback:
        added only if a scheduler was provided to the experiment;
        performs ``lr_scheduler.step``.
    CheckpointCallback:
        saves model and optimizer state each epoch; used to
        save/restore your model/criterion/optimizer/metrics.
    ConsoleLogger:
        standard Catalyst logger, writes ``state.metrics``
        to the console and to a text file.
    TensorboardLogger:
        writes ``state.metrics`` to tensorboard.
    RaiseExceptionCallback:
        raises an exception if needed.

    For stages whose name starts with ``infer`` only ``VerboseLogger``
    (when verbose) and ``RaiseExceptionCallback`` are added.
    """

    def get_callbacks(self, stage: str) -> "OrderedDict[str, Callback]":
        """
        Override of the ``BaseExperiment.get_callbacks`` method.

        Adds the default callbacks that are missing from the user-provided
        ones. A default callback counts as present if any provided callback
        is an instance of its class; in that case it is not added again.

        Args:
            stage (str): name of the stage. If it starts with ``infer``,
                the training-only defaults (criterion, optimizer,
                scheduler, checkpoint, console and tensorboard loggers)
                are skipped, because they are needed only for training.

        Returns:
            OrderedDict[str, Callback]: a new ordered mapping from callback
            name to callback for this stage. The experiment's own callback
            container is left unmodified.
        """
        # Work on a shallow copy. Inserting defaults directly into
        # ``self._callbacks`` would leak training callbacks into later
        # ``infer`` stages and mutate the dict the user passed in.
        callbacks = OrderedDict(self._callbacks or {})

        default_callbacks = []
        if self._verbose:
            default_callbacks.append(("verbose", VerboseLogger))
        if not stage.startswith("infer"):
            default_callbacks.append(("_criterion", CriterionCallback))
            default_callbacks.append(("_optimizer", OptimizerCallback))
            if self._scheduler is not None:
                default_callbacks.append(("_scheduler", SchedulerCallback))
            default_callbacks.append(("_saver", CheckpointCallback))
            default_callbacks.append(("console", ConsoleLogger))
            default_callbacks.append(("tensorboard", TensorboardLogger))
        default_callbacks.append(("exception", RaiseExceptionCallback))

        for callback_name, callback_fn in default_callbacks:
            # Respect user-provided callbacks of the same type.
            is_already_present = any(
                isinstance(x, callback_fn) for x in callbacks.values()
            )
            if not is_already_present:
                callbacks[callback_name] = callback_fn()
        return callbacks
# Public names exported by ``from <this module> import *``.
__all__ = [
    "SupervisedExperiment",
]