
Source code for catalyst.dl.callbacks.gan

from typing import Dict, List, Optional

from catalyst.core import CriterionCallback, OptimizerCallback
from catalyst.dl import MetricCallback, State

# MetricCallbacks alternatives for input/output keys


class WassersteinDistanceCallback(MetricCallback):
    """Callback to compute Wasserstein distance metric."""

    def __init__(
        self,
        prefix: str = "wasserstein_distance",
        real_validity_output_key: str = "real_validity",
        fake_validity_output_key: str = "fake_validity",
        multiplier: float = 1.0,
    ):
        """
        Args:
            prefix (str): key to store the computed metric
                in ``state.metrics``
            real_validity_output_key (str): key to real validity
                in ``state.output``
            fake_validity_output_key (str): key to fake validity
                in ``state.output``
            multiplier (float): scale factor for the metric value
        """
        super().__init__(
            prefix,
            metric_fn=self.get_wasserstein_distance,
            input_key={},
            output_key={
                real_validity_output_key: "real_validity",
                fake_validity_output_key: "fake_validity",
            },
            multiplier=multiplier,
        )

    def get_wasserstein_distance(self, real_validity, fake_validity):
        """Computes Wasserstein distance."""
        return real_validity.mean() - fake_validity.mean()
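
The metric is simply the difference between the mean critic scores on real and generated samples. A self-contained sketch of the computation (the tensor values below are illustrative only):

import torch

real_validity = torch.tensor([0.9, 0.8, 0.7])  # critic scores on real samples
fake_validity = torch.tensor([0.1, 0.2, 0.3])  # critic scores on fake samples

# same computation as get_wasserstein_distance
distance = real_validity.mean() - fake_validity.mean()  # tensor(0.6000)
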
# CriterionCallback extended
class GradientPenaltyCallback(CriterionCallback):
    """Criterion Callback to compute Gradient Penalty."""

    def __init__(
        self,
        real_input_key: str = "data",
        fake_output_key: str = "fake_data",
        condition_keys: Optional[List[str]] = None,
        critic_model_key: str = "critic",
        critic_criterion_key: str = "critic",
        real_data_criterion_key: str = "real_data",
        fake_data_criterion_key: str = "fake_data",
        condition_args_criterion_key: str = "critic_condition_args",
        prefix: str = "loss",
        criterion_key: Optional[str] = None,
        multiplier: float = 1.0,
    ):
        """
        Args:
            real_input_key (str): real data key in ``state.input``
            fake_output_key (str): fake data key in ``state.output``
            condition_keys (List[str], optional): all condition keys
                in ``state.input`` for critic
            critic_model_key (str): key for critic model in ``state.model``
            critic_criterion_key (str): key for critic model in criterion
            real_data_criterion_key (str): key for real data in criterion
            fake_data_criterion_key (str): key for fake data in criterion
            condition_args_criterion_key (str): key for all condition args
                in criterion
            prefix (str): key to store the computed loss
                in ``state.metrics``
            criterion_key (str): key of the criterion to use
                in ``state.criterion`` (if there are several criterions)
            multiplier (float): scale factor for the computed loss
        """
        super().__init__(
            input_key=real_input_key,
            output_key=fake_output_key,
            prefix=prefix,
            criterion_key=criterion_key,
            multiplier=multiplier,
        )
        self.condition_keys = condition_keys or []
        self.critic_model_key = critic_model_key
        self.critic_criterion_key = critic_criterion_key
        self.real_data_criterion_key = real_data_criterion_key
        self.fake_data_criterion_key = fake_data_criterion_key
        self.condition_args_criterion_key = condition_args_criterion_key

    def _compute_metric(self, state: State):
        criterion_kwargs = {
            self.real_data_criterion_key: state.batch_in[self.input_key],
            self.fake_data_criterion_key: state.batch_out[self.output_key],
            self.critic_criterion_key: state.model[self.critic_model_key],
            self.condition_args_criterion_key: [
                state.batch_in[key] for key in self.condition_keys
            ],
        }
        criterion = state.get_attr("criterion", self.criterion_key)
        return criterion(**criterion_kwargs)
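
The callback only assembles keyword arguments and delegates the penalty computation to the configured criterion. As an illustration, a WGAN-GP-style criterion compatible with the default criterion keys could look like the following sketch (this class is hypothetical and not part of this module; check catalyst's shipped criterions for the actual implementation):

import torch
import torch.nn as nn

class HypotheticalGradientPenaltyLoss(nn.Module):
    """Sketch of a criterion whose signature matches the kwargs built
    by GradientPenaltyCallback._compute_metric."""

    def forward(self, real_data, fake_data, critic, critic_condition_args):
        batch_size = real_data.size(0)
        # per-sample interpolation coefficient, broadcastable to data shape
        alpha_shape = [batch_size] + [1] * (real_data.dim() - 1)
        alpha = torch.rand(alpha_shape, device=real_data.device)
        interpolates = (
            alpha * real_data + (1 - alpha) * fake_data.detach()
        ).requires_grad_(True)
        validity = critic(interpolates, *critic_condition_args)
        grads = torch.autograd.grad(
            outputs=validity,
            inputs=interpolates,
            grad_outputs=torch.ones_like(validity),
            create_graph=True,
        )[0]
        # penalize deviation of the gradient norm from 1
        return ((grads.flatten(1).norm(2, dim=1) - 1) ** 2).mean()
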
# Optimizer Callback with weights clamp after update
class WeightClampingOptimizerCallback(OptimizerCallback):
    """Optimizer callback + weights clipping after step is finished."""

    def __init__(
        self,
        grad_clip_params: Optional[Dict] = None,
        accumulation_steps: int = 1,
        optimizer_key: Optional[str] = None,
        loss_key: str = "loss",
        decouple_weight_decay: bool = True,
        weight_clamp_value: float = 0.01,
    ):
        """
        Args:
            grad_clip_params (dict, optional): params for gradient clipping
            accumulation_steps (int): number of batches over which
                gradients are accumulated before the optimizer step
            optimizer_key (str, optional): key to optimizer
                in ``state.optimizer`` (if there are several optimizers)
            loss_key (str): key to the loss to optimize
            decouple_weight_decay (bool): whether to apply weight decay
                decoupled from the gradient update
            weight_clamp_value (float): value to clamp weights after
                each optimization iteration

            .. note::
                ``weight_clamp_value`` will clamp WEIGHTS, not GRADIENTS
        """
        super().__init__(
            grad_clip_params=grad_clip_params,
            accumulation_steps=accumulation_steps,
            optimizer_key=optimizer_key,
            loss_key=loss_key,
            decouple_weight_decay=decouple_weight_decay,
        )
        self.weight_clamp_value = weight_clamp_value

    def on_batch_end(self, state: State) -> None:
        """On batch end event.

        Args:
            state (State): current state
        """
        super().on_batch_end(state)
        if not state.is_train_loader:
            return

        optimizer = state.get_attr(
            key="optimizer", inner_key=self.optimizer_key
        )

        need_gradient_step = (
            self._accumulation_counter % self.accumulation_steps == 0
        )

        if need_gradient_step:
            for group in optimizer.param_groups:
                for param in group["params"]:
                    param.data.clamp_(
                        min=-self.weight_clamp_value,
                        max=self.weight_clamp_value,
                    )
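
Outside of a runner, the post-step behavior reduces to an in-place projection of every parameter onto ``[-weight_clamp_value, weight_clamp_value]``, as in the original WGAN weight-clipping scheme. A standalone sketch (the model and optimizer below are placeholders, not part of this module):

import torch
import torch.nn as nn

critic = nn.Linear(10, 1)
optimizer = torch.optim.RMSprop(critic.parameters(), lr=5e-5)

# ... forward, backward, and optimizer.step() would happen here ...

# equivalent of the callback's post-step clamp
weight_clamp_value = 0.01
for group in optimizer.param_groups:
    for param in group["params"]:
        param.data.clamp_(min=-weight_clamp_value, max=weight_clamp_value)
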
__all__ = [
    "WassersteinDistanceCallback",
    "GradientPenaltyCallback",
    "WeightClampingOptimizerCallback",
]