Source code for catalyst.dl.experiment.gan
from collections import OrderedDict
from typing import Dict, List

from catalyst.dl import (
    Callback, CheckpointCallback, ConsoleLogger, ExceptionCallback,
    PhaseBatchWrapperCallback, PhaseManagerCallback, VerboseLogger
)

from .base import BaseExperiment


class GanExperiment(BaseExperiment):
    """
    One-staged GAN experiment.
    """

    def __init__(
        self,
        *,
        phase2callbacks: Dict[str, List[str]] = None,
        **kwargs,
    ):
"""
Args:
model (Model or Dict[str, Model]): models,
usually generator and discriminator
loaders (dict): dictionary containing one or several
``torch.utils.data.DataLoader`` for training and validation
callbacks (List[catalyst.dl.Callback]): list of callbacks
logdir (str): path to output directory
stage (str): current stage
criterion (Criterion): criterion function
optimizer (Optimizer): optimizer
scheduler (Scheduler): scheduler
num_epochs (int): number of experiment's epochs
valid_loader (str): loader name used to calculate
the metrics and save the checkpoints. For example,
you can pass `train` and then
the metrics will be taken from `train` loader.
main_metric (str): the key to the name of the metric
by which the checkpoints will be selected.
minimize_metric (bool): flag to indicate whether
the ``main_metric`` should be minimized.
verbose (bool): ff true, it displays the status of the training
to the console.
state_kwargs (dict): additional state params to ``RunnerState``
checkpoint_data (dict): additional data to save in checkpoint,
for example: ``class_names``, ``date_of_training``, etc
distributed_params (dict): dictionary with the parameters
for distributed and FP16 method
monitoring_params (dict): dict with the parameters
for monitoring services
initial_seed (int): experiment's initial seed value
phase2callbacks (dict): dictionary with lists of callback names
which should be wrapped for appropriate phase
for example: {"generator_train": "loss_g", "optim_g"}
"loss_g" and "optim_g" callbacks from callbacks dict
will be wrapped for "generator_train" phase
in wrap_callbacks method
"""
        super().__init__(**kwargs)
        self.wrap_callbacks(phase2callbacks or {})

    def wrap_callbacks(self, phase2callbacks) -> None:
        """Phase wrapping procedure for callbacks"""
        discriminator_phase_name = self._additional_state_kwargs[
            "discriminator_train_phase"]
        discriminator_phase_num = self._additional_state_kwargs[
            "discriminator_train_num"]
        generator_phase_name = self._additional_state_kwargs[
            "generator_train_phase"]
        generator_phase_num = self._additional_state_kwargs[
            "generator_train_num"]

        # alternate between discriminator and generator training,
        # running each phase for its configured number of batches
        self._callbacks["phase_manager"] = PhaseManagerCallback(
            train_phases=OrderedDict(
                [
                    (discriminator_phase_name, discriminator_phase_num),
                    (generator_phase_name, generator_phase_num),
                ]
            ),
            valid_mode="all",
        )

        for phase_name, callback_name_list in phase2callbacks.items():
            # TODO: Check for phase in state_params?
            for callback_name in callback_name_list:
                # re-register the callback wrapped so that it is only
                # active during the given phase
                callback = self._callbacks.pop(callback_name)
                self._callbacks[callback_name] = PhaseBatchWrapperCallback(
                    base_callback=callback,
                    active_phases=[phase_name],
                )
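
    # Note (illustrative): with hypothetical state_kwargs such as
    # {"discriminator_train_phase": "discriminator_train",
    #  "discriminator_train_num": 5,
    #  "generator_train_phase": "generator_train",
    #  "generator_train_num": 1},
    # training batches would cycle through 5 discriminator_train batches,
    # then 1 generator_train batch, and repeat; callbacks wrapped for
    # "generator_train" fire only on those generator batches.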

    def get_callbacks(self, stage: str) -> "OrderedDict[str, Callback]":
        callbacks = super().get_callbacks(stage=stage)

        default_callbacks = []
        if self._verbose:
            default_callbacks.append(("verbose", VerboseLogger))
        if not stage.startswith("infer"):
            default_callbacks.append(("saver", CheckpointCallback))
        default_callbacks.append(("console", ConsoleLogger))
        default_callbacks.append(("exception", ExceptionCallback))

        # check for absent default callbacks and add them
        for callback_name, callback_fn in default_callbacks:
            is_already_present = any(
                isinstance(x, callback_fn) for x in callbacks.values()
            )
            if not is_already_present:
                callbacks[callback_name] = callback_fn()
        return callbacks
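
    # Note: the defaults above are added only when missing; for example,
    # if the provided callbacks already include a CheckpointCallback
    # instance under any name, no extra "saver" callback is registered.
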
__all__ = ["GanExperiment"]
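
# Usage sketch (illustrative, not part of the module): the model, loader,
# and callback objects below are hypothetical placeholders; the
# state_kwargs keys mirror the ones read in ``wrap_callbacks`` above.
#
#     experiment = GanExperiment(
#         model={"generator": generator, "discriminator": discriminator},
#         loaders={"train": train_loader},
#         callbacks={
#             "loss_g": loss_g_callback,
#             "optim_g": optim_g_callback,
#             "loss_d": loss_d_callback,
#             "optim_d": optim_d_callback,
#         },
#         state_kwargs={
#             "discriminator_train_phase": "discriminator_train",
#             "discriminator_train_num": 5,
#             "generator_train_phase": "generator_train",
#             "generator_train_num": 1,
#         },
#         phase2callbacks={
#             "generator_train": ["loss_g", "optim_g"],
#             "discriminator_train": ["loss_d", "optim_d"],
#         },
#     )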