# Shortcuts
#
# Source code for catalyst.contrib.nn.criterion.gan

# flake8: noqa
# TODO: refactor code, examples and docs
import torch
from torch import nn


class MeanOutputLoss(nn.Module):
    """Criterion that reduces the output to its plain mean.

    The target is ignored completely (may be useful e.g. for WGAN
    real/fake validity averaging).
    """

    def forward(self, output, target):
        """Compute the criterion.

        Args:
            output: object exposing ``.mean()`` (presumably a tensor of
                critic scores -- TODO confirm against callers)
            target: accepted for interface compatibility, never read

        Returns:
            the result of ``output.mean()``
        """
        averaged = output.mean()
        return averaged
class GradientPenaltyLoss(nn.Module):
    """Criterion to compute gradient penalty.

    WARN: SHOULD NOT BE RUN WITH CriterionCallback,
    use special GradientPenaltyCallback instead
    """

    def forward(self, fake_data, real_data, critic, critic_condition_args):
        """Compute gradient penalty.

        The critic is evaluated on random per-sample interpolations between
        ``real_data`` and ``fake_data``; the penalty is the batch mean of
        ``(||d critic / d interpolates||_2 - 1) ** 2``.

        Args:
            fake_data: generated samples; must broadcast against
                ``real_data``
            real_data: real samples with the batch along dim 0; any rank
                >= 1 is supported
            critic: callable invoked as
                ``critic(interpolates, *critic_condition_args)``
            critic_condition_args: extra positional arguments forwarded to
                ``critic``; ``None`` is treated as "no extra arguments"

        Returns:
            scalar tensor with the gradient penalty
        """
        if critic_condition_args is None:
            critic_condition_args = ()

        device = real_data.device
        batch_size = real_data.size(0)

        # One random interpolation weight per sample, shaped (B, 1, ..., 1)
        # so it broadcasts over inputs of any rank, not only 4D images.
        alpha_shape = (batch_size,) + (1,) * (real_data.dim() - 1)
        alpha = torch.rand(alpha_shape, device=device)

        # Random interpolation between real and fake samples, cut off from
        # the graphs that produced them and turned into a fresh leaf.
        interpolates = (alpha * real_data + ((1 - alpha) * fake_data)).detach()
        interpolates.requires_grad_(True)

        # Force grad mode on so the critic graph exists in validation mode.
        with torch.set_grad_enabled(True):
            d_interpolates = critic(interpolates, *critic_condition_args)

        # Match the critic output's shape/dtype/device instead of assuming
        # a (B, 1) float32 output.
        grad_outputs = torch.ones_like(d_interpolates)

        # Gradient of the critic output w.r.t. the interpolated samples;
        # create_graph=True keeps the penalty itself differentiable.
        gradients = torch.autograd.grad(
            outputs=d_interpolates,
            inputs=interpolates,
            grad_outputs=grad_outputs,
            create_graph=True,
            retain_graph=True,
        )[0]

        # reshape (not view): gradients are not guaranteed to be contiguous.
        gradients = gradients.reshape(gradients.size(0), -1)
        gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
        return gradient_penalty
# Names exported by ``from <module> import *``.
__all__ = [
    "MeanOutputLoss",
    "GradientPenaltyLoss",
]