
Source code for catalyst.dl.runner.gan

from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union

from catalyst.dl import GanExperiment, GanState
from catalyst.utils.tools.typing import Criterion, Device, Model, Optimizer

from .core import Runner


class MultiPhaseRunner(Runner):
    """Base Runner with multiple phases."""
    def __init__(
        self,
        model: Union[Model, Dict[str, Model]] = None,
        device: Device = None,
        input_batch_keys: List[str] = None,
        registered_phases: Tuple[Tuple[str, Union[str, Callable]], ...] = None,
    ):
        """
        Args:
            model: GAN models
            device: runner's device
            input_batch_keys: list of string keys for batch elements,
                e.g. if ``input_batch_keys = ["features", "targets"]`` and
                your DataLoader returns 2 tensors (images and targets),
                then ``state.input`` will be
                ``{"features": batch[0], "targets": batch[1]}``
            registered_phases: tuple of pairs
                (phase_name, phase_forward_function);
                phase_forward_function may also be a str, in which case the
                Runner should have a method with the same name, which will
                be called
        """
        super().__init__(model, device)
        self.input_batch_keys = input_batch_keys or []

        self.registered_phases = {}
        for phase_name, phase_batch_forward_fn in registered_phases:
            if not (isinstance(phase_name, str) or phase_name is None):
                raise ValueError(
                    f"phase '{phase_name}' of type '{type(phase_name)}' "
                    f"not supported, must be str or None"
                )
            if phase_name in self.registered_phases:
                raise ValueError(f"phase '{phase_name}' already registered")
            if isinstance(phase_batch_forward_fn, str):
                assert hasattr(self, phase_batch_forward_fn)
                phase_batch_forward_fn = getattr(self, phase_batch_forward_fn)
            assert isinstance(
                phase_batch_forward_fn, Callable
            ), "must be callable"
            self.registered_phases[phase_name] = phase_batch_forward_fn
    def _batch2device(self, batch: Mapping[str, Any], device):
        if isinstance(batch, (list, tuple)):
            assert len(batch) >= len(self.input_batch_keys)
            batch = {
                key: value
                for key, value in zip(self.input_batch_keys, batch)
            }
        return super()._batch2device(batch, device)
    def forward(self, batch, **kwargs):
        """Forward call.

        @TODO: Docs (add `batch` shapes). Contribution is welcome.
        """
        if self.state.phase not in self.registered_phases:
            raise ValueError(f"Unknown phase: '{self.state.phase}'")

        return self.registered_phases[self.state.phase]()
    def _handle_batch(self, batch: Mapping[str, Any]) -> None:
        """
        Inner method to handle specified data batch.
        Used to make a train/valid/infer step during Experiment run.

        Args:
            batch (Mapping[str, Any]): dictionary with data batches
                from DataLoader.
        """
        self.state.batch_out = self.forward(batch)
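# The following is a minimal usage sketch (not part of this module): it shows
# how MultiPhaseRunner might be subclassed with two named phases.
# "MyTwoPhaseRunner", "phase_a"/"phase_b" and "_phase_a_forward"/
# "_phase_b_forward" are hypothetical names used for illustration only.
#
#     class MyTwoPhaseRunner(MultiPhaseRunner):
#         def __init__(self, model=None, device=None):
#             super().__init__(
#                 model=model,
#                 device=device,
#                 # batch[0] -> "features", batch[1] -> "targets"
#                 input_batch_keys=["features", "targets"],
#                 registered_phases=(
#                     ("phase_a", "_phase_a_forward"),
#                     ("phase_b", "_phase_b_forward"),
#                 ),
#             )
#
#         def _phase_a_forward(self):
#             # return a dict of outputs for phase "a"
#             return {"logits": self.model(self.state.batch_in["features"])}
#
#         def _phase_b_forward(self):
#             # return a dict of outputs for phase "b"
#             return {"logits": self.model(self.state.batch_in["features"])}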
class GanRunner(MultiPhaseRunner):
    """
    Runner with logic for single-generator single-discriminator GAN training.

    Various conditioning types, penalties, and regularizations
    (such as WGAN-GP) can easily be derived from this class.
    """

    experiment: GanExperiment
    state: GanState
    _experiment_fn: Callable = GanExperiment
    _state_fn: callable = GanState
    def __init__(
        self,
        model: Union[Model, Dict[str, Model]] = None,
        device: Device = None,
        input_batch_keys: Optional[List[str]] = None,
        # input keys
        data_input_key: str = "data",
        class_input_key: str = "class_targets",
        noise_input_key: str = "noise",
        # output keys
        fake_logits_output_key: str = "fake_logits",
        real_logits_output_key: str = "real_logits",
        fake_data_output_key: str = "fake_data",
        # condition keys
        fake_condition_keys: List[str] = None,
        real_condition_keys: List[str] = None,
        # phases
        generator_train_phase: str = "generator_train",
        discriminator_train_phase: str = "discriminator_train",
        # model keys
        generator_model_key: str = "generator",
        discriminator_model_key: str = "discriminator",
    ):
        """
        Args:
            model: Model
            device: Device
            input_batch_keys: list of string keys for batch elements,
                e.g. if ``input_batch_keys = ["features", "targets"]`` and
                your DataLoader returns 2 tensors (images and targets),
                then ``state.input`` will be
                ``{"features": batch[0], "targets": batch[1]}``
            data_input_key: real distribution to fit
            class_input_key: labels for real distribution
            noise_input_key: noise
            fake_logits_output_key: prediction scores of discriminator
                for fake data
            real_logits_output_key: prediction scores of discriminator
                for real data
            fake_data_output_key: generated data
            fake_condition_keys: list of all conditional inputs of
                discriminator (fake data conditions),
                in the same order as in the generator model forward() call
            real_condition_keys: list of all conditional inputs of
                discriminator (real data conditions),
                in the same order as in the generator model forward() call
            generator_train_phase (str): name of the generator training phase
            discriminator_train_phase (str): name of the discriminator
                training phase
            generator_model_key: name of the generator model,
                e.g. "generator"
            discriminator_model_key: name of the discriminator model,
                e.g. "discriminator"

        .. note::
            THIS RUNNER SUPPORTS ONLY EQUALLY CONDITIONED generator and
            discriminator (i.e. if the generator is conditioned on 3
            variables, the discriminator must be conditioned on the same
            3 variables).
        """
        input_batch_keys = input_batch_keys or [data_input_key]
        registered_phases = (
            (generator_train_phase, "_generator_train_phase"),
            (discriminator_train_phase, "_discriminator_train_phase"),
            (None, "_discriminator_train_phase"),
        )
        super().__init__(model, device, input_batch_keys, registered_phases)

        # input keys
        self.data_input_key = data_input_key
        self.class_input_key = class_input_key
        self.noise_input_key = noise_input_key
        # output keys
        self.fake_logits_output_key = fake_logits_output_key
        self.real_logits_output_key = real_logits_output_key
        self.fake_data_output_key = fake_data_output_key
        # condition keys
        self.fake_condition_keys = fake_condition_keys or []
        self.real_condition_keys = real_condition_keys or []

        # check that the discriminator will have the
        # same number of arguments for real/fake data
        assert len(self.fake_condition_keys) == len(
            self.real_condition_keys
        ), "Number of real/fake conditions should be the same"

        # Note: this runner supports only
        # EQUALLY CONDITIONED generator (G) and discriminator (D).
        # Below are some thoughts on why:
        #
        # 1. G is more conditioned than D.
        #
        # It would be strange if G is conditioned on something
        # and D is NOT conditioned on the same variable,
        # which would most probably lead to interpreting that variable
        # as additional noise.
        #
        # 2. D is more conditioned than G.
        #
        # Imagine D to have an additional condition 'cond_var' which is not
        # a condition of G. Now you have:
        #   fake_data = G(z, *other_conditions)
        #   fake_score = D(fake_data, cond_var, *other_conditions)
        # In the above example fake_data and cond_var are ~independent.
        # If they are not independent (e.g. cond_var represents a class
        # condition which is fixed to the single "cat" class, which may be
        # used for fine-tuning a pretrained GAN for a specific class),
        # such a configuration may make some sense.
        # So case #2 may make some sense; however, for simplicity
        # it is not implemented in this Runner.

        # model keys
        self.generator_key = generator_model_key
        self.discriminator_key = discriminator_model_key
    def _prepare_for_stage(self, stage: str):
        super()._prepare_for_stage(stage)
        self.generator = self.model[self.generator_key]
        self.discriminator = self.model[self.discriminator_key]

    # Common utility functions

    def _get_noise_and_conditions(self):
        """Returns generator inputs."""
        z = self.state.batch_in[self.noise_input_key]
        conditions = [
            self.state.batch_in[key] for key in self.fake_condition_keys
        ]
        return z, conditions

    def _get_real_data_conditions(self):
        """Returns discriminator conditions (for real data)."""
        return [self.state.batch_in[key] for key in self.real_condition_keys]

    def _get_fake_data_conditions(self):
        """Returns discriminator conditions (for fake data)."""
        return [self.state.batch_in[key] for key in self.fake_condition_keys]

    # Concrete phase methods

    def _generator_train_phase(self):
        """Forward call on generator training phase."""
        z, g_conditions = self._get_noise_and_conditions()
        d_fake_conditions = self._get_fake_data_conditions()

        fake_data = self.generator(z, *g_conditions)
        fake_logits = self.discriminator(fake_data, *d_fake_conditions)
        return {
            self.fake_data_output_key: fake_data,
            self.fake_logits_output_key: fake_logits,
        }

    def _discriminator_train_phase(self):
        """Forward call on discriminator training phase."""
        z, g_conditions = self._get_noise_and_conditions()
        d_fake_conditions = self._get_fake_data_conditions()
        d_real_conditions = self._get_real_data_conditions()

        fake_data = self.generator(z, *g_conditions)
        fake_logits = self.discriminator(
            fake_data.detach(), *d_fake_conditions
        )
        real_logits = self.discriminator(
            self.state.batch_in[self.data_input_key], *d_real_conditions
        )
        return {
            self.fake_data_output_key: fake_data,
            self.fake_logits_output_key: fake_logits,
            self.real_logits_output_key: real_logits,
        }
    def train(
        self,
        model: Model,
        loaders: "OrderedDict[str, DataLoader]",  # noqa: F821
        callbacks: "OrderedDict[str, Callback]" = None,  # noqa: F821
        phase2callbacks: Dict[str, List[str]] = None,
        criterion: Criterion = None,
        optimizer: Optimizer = None,
        num_epochs: int = 1,
        main_metric: str = "loss",
        minimize_metric: bool = True,
        state_kwargs: Dict = None,
        checkpoint_data: Dict = None,
        distributed_params: Dict = None,
        monitoring_params: Dict = None,
        logdir: str = None,
        verbose: bool = False,
        initial_seed: int = 42,
        check: bool = False,
    ) -> None:
        """Starts the training process of the model.

        Args:
            model: models, usually generator and discriminator
            loaders: dictionary containing one or several
                ``torch.utils.data.DataLoader`` for training and validation
            callbacks: list of callbacks
            phase2callbacks: dictionary with lists of callback names
                which should be wrapped for the appropriate phase,
                for example ``{"generator_train": ["loss_g", "optim_g"]}``;
                the "loss_g" and "optim_g" callbacks from the callbacks dict
                will be wrapped for the "generator_train" phase
                in the ``wrap_callbacks`` method
            criterion: criterion function
            optimizer: optimizer
            num_epochs: number of experiment's epochs
            main_metric: the key to the name of the metric by which
                the checkpoints will be selected
            minimize_metric: flag to indicate whether
                the ``main_metric`` should be minimized
            state_kwargs: additional state params to ``RunnerState``
            checkpoint_data: additional data to save in checkpoint,
                for example: ``class_names``, ``date_of_training``, etc.
            distributed_params: dictionary with the parameters
                for distributed and FP16 methods
            monitoring_params: dict with the parameters
                for monitoring services
            logdir: path to output directory where the metrics
                and checkpoints will be saved
            verbose: if True, displays the status of the training
                in the console
            initial_seed: experiment's initial seed value
            check: if True, only checks that the pipeline is working
                (3 epochs only)
        """
        # Initialize and run experiment
        self.experiment = self.experiment_fn(
            model=model,
            loaders=loaders,
            callbacks=callbacks,
            logdir=logdir,
            criterion=criterion,
            optimizer=optimizer,
            num_epochs=num_epochs,
            main_metric=main_metric,
            minimize_metric=minimize_metric,
            verbose=verbose,
            state_kwargs=state_kwargs,
            checkpoint_data=checkpoint_data,
            distributed_params=distributed_params,
            monitoring_params=monitoring_params,
            initial_seed=initial_seed,
            phase2callbacks=phase2callbacks,
        )
        self.run_experiment(experiment=self.experiment)
__all__ = ["MultiPhaseRunner", "GanRunner"]
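
Below is a minimal usage sketch (not part of the module) of how ``GanRunner`` might be wired with the default key and phase names. ``MyGenerator``, ``MyDiscriminator``, ``loaders`` and the ``callbacks`` dict (loss/optimizer/phase-management callbacks registered under names such as "loss_g", "optim_g", "loss_d", "optim_d") are assumed to be defined elsewhere; providing the "noise" batch input is also assumed to be handled by the dataset or one of the callbacks.

runner = GanRunner()  # default data/noise/logits keys and phase names
runner.train(
    model={
        "generator": MyGenerator(),          # hypothetical generator module
        "discriminator": MyDiscriminator(),  # hypothetical discriminator module
    },
    loaders=loaders,      # e.g. {"train": DataLoader(...)}; batch[0] -> "data"
    callbacks=callbacks,  # assumed dict of loss/optimizer/phase callbacks
    phase2callbacks={
        "generator_train": ["loss_g", "optim_g"],
        "discriminator_train": ["loss_d", "optim_d"],
    },
    num_epochs=10,
    logdir="./logs/gan",
    verbose=True,
)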