Shortcuts

Source code for catalyst.dl.experiment.gan

from collections import OrderedDict
from typing import Dict, List

from catalyst.dl import (
    Callback, CheckpointCallback, ConsoleLogger, ExceptionCallback,
    PhaseBatchWrapperCallback, PhaseManagerCallback, VerboseLogger
)
from .base import BaseExperiment


class GanExperiment(BaseExperiment):
    """
    One-staged GAN experiment
    """

    def __init__(
        self,
        *,
        phase2callbacks: Dict[str, List[str]] = None,
        **kwargs,
    ):
        """
        Args:
            model (Model or Dict[str, Model]): models,
                usually generator and discriminator
            loaders (dict): dictionary containing one or several
                ``torch.utils.data.DataLoader`` for training and validation
            callbacks (List[catalyst.dl.Callback]): list of callbacks
            logdir (str): path to output directory
            stage (str): current stage
            criterion (Criterion): criterion function
            optimizer (Optimizer): optimizer
            scheduler (Scheduler): scheduler
            num_epochs (int): number of experiment's epochs
            valid_loader (str): loader name used to calculate
                the metrics and save the checkpoints. For example,
                you can pass `train` and then
                the metrics will be taken from `train` loader.
            main_metric (str): the key to the name of the metric
                by which the checkpoints will be selected.
            minimize_metric (bool): flag to indicate whether
                the ``main_metric`` should be minimized.
            verbose (bool): if True, it displays the status of the training
                to the console.
            state_kwargs (dict): additional state params to ``RunnerState``.
                ``wrap_callbacks`` reads ``discriminator_train_phase``,
                ``discriminator_train_num``, ``generator_train_phase`` and
                ``generator_train_num`` from the experiment's additional
                state kwargs (presumably populated from this argument by
                the base class - verify), so they must be provided.
            checkpoint_data (dict): additional data to save in checkpoint,
                for example: ``class_names``, ``date_of_training``, etc
            distributed_params (dict): dictionary with the parameters
                for distributed and FP16 method
            monitoring_params (dict): dict with the parameters
                for monitoring services
            initial_seed (int): experiment's initial seed value
            phase2callbacks (dict): dictionary mapping a phase name to the
                names of callbacks that should be active only in that phase,
                for example: ``{"generator_train": ["loss_g", "optim_g"]}``
                means that the "loss_g" and "optim_g" callbacks from the
                callbacks dict will be wrapped for the "generator_train"
                phase in the ``wrap_callbacks`` method. A single callback
                name may also be given as a plain string.
        """
        super().__init__(**kwargs)
        self.wrap_callbacks(phase2callbacks or {})

    @staticmethod
    def _as_name_list(callback_names) -> List[str]:
        """Normalize a single callback name or an iterable of names to a list."""
        # A bare string would otherwise be iterated character by character.
        if isinstance(callback_names, str):
            return [callback_names]
        return list(callback_names)

    def wrap_callbacks(self, phase2callbacks) -> None:
        """
        Phase wrapping procedure for callbacks.

        Registers a ``PhaseManagerCallback`` under the ``"phase_manager"``
        key. It is configured with the discriminator and generator train
        phases (phase name -> number, read from the additional state kwargs)
        and ``valid_mode="all"``. Then every callback listed in
        ``phase2callbacks`` is replaced by a ``PhaseBatchWrapperCallback``
        whose only active phase is the one it is listed under.

        Args:
            phase2callbacks (dict): mapping ``phase name -> callback names``;
                a value may be a list of names or a single name string.

        Raises:
            KeyError: if one of the required phase settings is missing from
                the additional state kwargs, or if a listed callback name is
                not registered in the experiment callbacks.
        """
        state_kwargs = self._additional_state_kwargs
        discriminator_phase_name = state_kwargs["discriminator_train_phase"]
        discriminator_phase_num = state_kwargs["discriminator_train_num"]
        generator_phase_name = state_kwargs["generator_train_phase"]
        generator_phase_num = state_kwargs["generator_train_num"]

        self._callbacks["phase_manager"] = PhaseManagerCallback(
            train_phases=OrderedDict(
                [
                    (discriminator_phase_name, discriminator_phase_num),
                    (generator_phase_name, generator_phase_num),
                ]
            ),
            valid_mode="all",
        )

        for phase_name, callback_names in phase2callbacks.items():
            # TODO: Check for phase in state_params?
            for callback_name in self._as_name_list(callback_names):
                try:
                    callback = self._callbacks.pop(callback_name)
                except KeyError:
                    raise KeyError(
                        f"Callback {callback_name!r} listed for phase "
                        f"{phase_name!r} is not among the experiment callbacks"
                    ) from None
                # pop + re-insert (rather than in-place assignment) moves the
                # entry to the end of the mapping, i.e. after "phase_manager"
                # registered above (assuming ``_callbacks`` is an
                # insertion-ordered dict).
                self._callbacks[callback_name] = PhaseBatchWrapperCallback(
                    base_callback=callback,
                    active_phases=[phase_name],
                )

    def get_callbacks(self, stage: str) -> "OrderedDict[str, Callback]":
        """
        Return the callbacks for ``stage``, completed with default ones.

        A default callback is added only when no callback of the same class
        is already present: ``VerboseLogger`` (when ``verbose`` is set),
        ``CheckpointCallback`` (for stages not starting with ``"infer"``),
        ``ConsoleLogger`` and ``ExceptionCallback``.

        Args:
            stage (str): stage name

        Returns:
            OrderedDict[str, Callback]: callbacks for the stage
        """
        callbacks = super().get_callbacks(stage=stage)
        # NOTE(review): the mapping returned by the base class is updated in
        # place; if it is the experiment's own callbacks dict, the defaults
        # added here persist across calls - confirm this is intended.
        default_callbacks = []
        if self._verbose:
            default_callbacks.append(("verbose", VerboseLogger))
        if not stage.startswith("infer"):
            default_callbacks.append(("saver", CheckpointCallback))
        default_callbacks.append(("console", ConsoleLogger))
        default_callbacks.append(("exception", ExceptionCallback))

        # Check for absent callbacks and add them
        for callback_name, callback_fn in default_callbacks:
            is_already_present = any(
                isinstance(x, callback_fn) for x in callbacks.values()
            )
            if not is_already_present:
                callbacks[callback_name] = callback_fn()
        return callbacks
# Public API of this module: names exported by ``from ... import *``.
__all__ = ["GanExperiment"]