# Source code for catalyst.dl.callbacks.gan

from typing import Any, Callable, Dict, List, Optional, Union  # isort:skip

from catalyst.core import CriterionCallback, OptimizerCallback
from catalyst.dl import Callback, CallbackOrder, State


"""
MetricCallbacks alternatives for input/output keys
"""


class MultiKeyMetricCallback(Callback):
    """
    Metric callback that evaluates ``metric_fn`` at the end of every batch
    and stores the result in ``state.batch_metrics`` under ``prefix``.

    Both the input and the output selector may be a single key, a list of
    keys, or ``None`` (meaning "pass the whole dictionary").
    """

    # TODO:
    #  merge it with MetricCallback in catalyst.core
    #  maybe after the changes with CriterionCallback will be finalized
    #  in the main repo
    def __init__(
        self,
        prefix: str,
        metric_fn: Callable,
        input_key: Optional[Union[str, List[str]]] = "targets",
        output_key: Optional[Union[str, List[str]]] = "logits",
        **metric_params
    ):
        """
        :param prefix: key under which the metric value is written to
            ``state.batch_metrics``
        :param metric_fn: callable invoked as
            ``metric_fn(outputs, targets, **metric_params)``
        :param input_key: selector applied to ``state.batch_in``
            (single key, list of keys, or ``None`` for the whole dict)
        :param output_key: selector applied to ``state.batch_out``
            (single key, list of keys, or ``None`` for the whole dict)
        :param metric_params: extra keyword arguments forwarded to
            ``metric_fn`` on every call
        """
        super().__init__(CallbackOrder.Metric)
        self.prefix, self.metric_fn = prefix, metric_fn
        self.input_key, self.output_key = input_key, output_key
        self.metric_params = metric_params

    @staticmethod
    def _get(dictionary: dict, keys: Optional[Union[str, List[str]]]) -> Any:
        """
        Select values from ``dictionary`` according to ``keys``.

        :param dictionary: mapping to read from
        :param keys: ``None`` to get the mapping itself, a list to get a
            sub-dictionary restricted to those keys, anything else to get
            the single value stored under that key
        :return: the selected object (see ``keys``)
        """
        if keys is None:
            return dictionary
        if isinstance(keys, list):
            return {name: dictionary[name] for name in keys}
        return dictionary[keys]

    def on_batch_end(self, state: State):
        """Compute the metric for the finished batch and record it."""
        # Outputs are resolved before targets, matching the argument order
        # expected by ``metric_fn``.
        outputs = self._get(state.batch_out, self.output_key)
        targets = self._get(state.batch_in, self.input_key)
        state.batch_metrics[self.prefix] = self.metric_fn(
            outputs, targets, **self.metric_params
        )


class WassersteinDistanceCallback(MultiKeyMetricCallback):
    """
    Callback to compute Wasserstein distance metric
    """

    def __init__(
        self,
        prefix: str = "wasserstein_distance",
        real_validity_output_key: str = "real_validity",
        fake_validity_output_key: str = "fake_validity"
    ):
        """
        :param prefix: metric name passed to the parent callback
        :param real_validity_output_key: key of the validity values for
            real samples in the batch outputs
        :param fake_validity_output_key: key of the validity values for
            fake samples in the batch outputs
        """
        # The bound method is only stored here; it is called later, after
        # the two key attributes below have been assigned.
        super().__init__(
            prefix,
            metric_fn=self.get_wasserstein_distance,
            input_key=None,
            output_key=[real_validity_output_key, fake_validity_output_key]
        )
        self.real_validity_key = real_validity_output_key
        self.fake_validity_key = fake_validity_output_key

    def get_wasserstein_distance(self, outputs, targets):
        """
        Computes Wasserstein distance as the mean validity of real samples
        minus the mean validity of fake samples.

        :param outputs: mapping that contains both validity entries under
            ``self.real_validity_key`` and ``self.fake_validity_key``;
            each entry must provide a ``.mean()`` method
        :param targets: unused; accepted only to match the
            ``metric_fn(outputs, targets)`` calling convention
        :return: ``real_validity.mean() - fake_validity.mean()``
        """
        real_validity = outputs[self.real_validity_key]
        fake_validity = outputs[self.fake_validity_key]
        return real_validity.mean() - fake_validity.mean()
""" CriterionCallback extended """
class GradientPenaltyCallback(CriterionCallback):
    """
    Criterion Callback to compute Gradient Penalty
    """

    def __init__(
        self,
        real_input_key: str = "data",
        fake_output_key: str = "fake_data",
        condition_keys: Optional[List[str]] = None,
        critic_model_key: str = "critic",
        critic_criterion_key: str = "critic",
        real_data_criterion_key: str = "real_data",
        fake_data_criterion_key: str = "fake_data",
        condition_args_criterion_key: str = "critic_condition_args",
        loss_key: str = "loss",
        criterion_key: Optional[str] = None,
        multiplier: float = 1.0,
    ):
        """
        :param real_input_key: real data key in ``state.batch_in``
        :param fake_output_key: fake data key in ``state.batch_out``
        :param condition_keys: all condition keys in ``state.batch_in``
            for critic; ``None`` is treated as an empty list
        :param critic_model_key: key for critic model in ``state.model``
        :param critic_criterion_key: keyword name under which the critic
            model is passed to the criterion
        :param real_data_criterion_key: keyword name under which the real
            data is passed to the criterion
        :param fake_data_criterion_key: keyword name under which the fake
            data is passed to the criterion
        :param condition_args_criterion_key: keyword name under which the
            list of condition values is passed to the criterion
        :param loss_key: forwarded to ``CriterionCallback`` as ``prefix``
        :param criterion_key: forwarded to ``CriterionCallback``
        :param multiplier: forwarded to ``CriterionCallback``
        """
        super().__init__(
            input_key=real_input_key,
            output_key=fake_output_key,
            prefix=loss_key,
            criterion_key=criterion_key,
            multiplier=multiplier
        )
        self.condition_keys = condition_keys or []
        self.critic_model_key = critic_model_key
        self.critic_criterion_key = critic_criterion_key
        self.real_data_criterion_key = real_data_criterion_key
        self.fake_data_criterion_key = fake_data_criterion_key
        self.condition_args_criterion_key = condition_args_criterion_key

    def _compute_loss(self, state: State, criterion):
        """
        Call ``criterion`` with real data, fake data, the critic model and
        the condition values, each under its configured keyword name.

        :param state: current run state; ``batch_in``, ``batch_out`` and
            ``model`` are read from it
        :param criterion: callable accepting the keyword arguments built
            below
        :return: whatever ``criterion`` returns
        """
        # self.input_key / self.output_key are presumably set by
        # CriterionCallback from the arguments given in __init__
        # -- TODO confirm against the parent class.
        criterion_kwargs = {
            self.real_data_criterion_key: state.batch_in[self.input_key],
            self.fake_data_criterion_key: state.batch_out[self.output_key],
            self.critic_criterion_key: state.model[self.critic_model_key],
            self.condition_args_criterion_key: [
                state.batch_in[key] for key in self.condition_keys
            ]
        }
        return criterion(**criterion_kwargs)
""" Optimizer Callback with weights clamp after update """
class WeightClampingOptimizerCallback(OptimizerCallback):
    """
    Optimizer callback + weights clipping after step is finished
    """

    def __init__(
        self,
        grad_clip_params: Optional[Dict] = None,
        accumulation_steps: int = 1,
        optimizer_key: Optional[str] = None,
        loss_key: str = "loss",
        decouple_weight_decay: bool = True,
        weight_clamp_value: float = 0.01
    ):
        """
        :param grad_clip_params: forwarded to ``OptimizerCallback``
        :param accumulation_steps: forwarded to ``OptimizerCallback``
        :param optimizer_key: forwarded to ``OptimizerCallback``
        :param loss_key: forwarded to ``OptimizerCallback``
        :param decouple_weight_decay: forwarded to ``OptimizerCallback``
        :param weight_clamp_value: value to clamp weights after each
            optimization iteration; every parameter is clamped to
            ``[-weight_clamp_value, weight_clamp_value]``

            Attention: will clamp WEIGHTS, not GRADIENTS
        """
        super().__init__(
            grad_clip_params=grad_clip_params,
            accumulation_steps=accumulation_steps,
            optimizer_key=optimizer_key,
            loss_key=loss_key,
            decouple_weight_decay=decouple_weight_decay
        )
        self.weight_clamp_value = weight_clamp_value

    def on_batch_end(self, state: State):
        """
        On batch end event: run the parent optimizer logic, then clamp
        all parameters of the optimizer in place.

        :param state: current run state
        """
        super().on_batch_end(state)
        if not state.need_backward_pass:
            return

        optimizer = state.get_attr(
            key="optimizer", inner_key=self.optimizer_key
        )

        # NOTE(review): this reads the counter *after* the parent call;
        # presumably the parent resets ``_accumulation_counter`` to 0
        # once it has performed an optimizer step -- confirm against
        # OptimizerCallback.on_batch_end.
        need_gradient_step = (
            self._accumulation_counter % self.accumulation_steps == 0
        )
        if need_gradient_step:
            for group in optimizer.param_groups:
                for param in group["params"]:
                    # In-place clamp of the weights themselves.
                    param.data.clamp_(
                        min=-self.weight_clamp_value,
                        max=self.weight_clamp_value
                    )
# Names exported by ``from <module> import *``.
__all__ = [
    "WassersteinDistanceCallback",
    "GradientPenaltyCallback",
    "WeightClampingOptimizerCallback",
]