from typing import Tuple

import torch
import torch.nn.functional as F


def gradient_penalty(critic, real_samples: torch.Tensor, fake_samples: torch.Tensor) -> torch.Tensor:
    """
    WGAN-GP gradient penalty computed on random interpolations between real and fake samples.

    For every sample ``i`` a mixing weight ``eps_i ~ U[0, 1)`` is drawn and the critic is
    evaluated at ``x_hat_i = eps_i * real_i + (1 - eps_i) * fake_i``. The penalty is the
    batch mean of ``(||d critic(x_hat) / d x_hat_i||_2 - 1) ** 2``.

    Inputs of any rank ``(batch, *features)`` are supported, e.g. images ``(B, C, H, W)``
    or flat feature vectors ``(B, D)``.

    Args:
        critic: Callable mapping a batch of samples to critic scores. Its output must be
            differentiable w.r.t. its input; every output element is back-propagated with
            weight 1.
        real_samples: Batch of real samples, shape ``(batch, *features)``.
        fake_samples: Batch of generated samples with exactly the same shape as
            ``real_samples``.

    Returns:
        Scalar tensor holding the penalty. The autograd graph is kept
        (``create_graph=True``), so the penalty can be added to the critic loss and
        back-propagated into the critic's parameters.

    Raises:
        ValueError: If ``real_samples`` and ``fake_samples`` have different shapes.
    """
    # Without this check a fake batch of size 1 would silently broadcast against the
    # real batch instead of forming per-sample interpolations.
    if real_samples.shape != fake_samples.shape:
        raise ValueError(
            "real_samples and fake_samples must have the same shape, got "
            f"{tuple(real_samples.shape)} and {tuple(fake_samples.shape)}"
        )

    batch_size = real_samples.size(0)
    # One mixing weight per sample, shaped (batch, 1, ..., 1) so it broadcasts over all
    # feature dims whatever the input rank. Matching the input dtype avoids silently
    # promoting half/bfloat16 inputs to float32 before they reach the critic.
    eps_shape = (batch_size,) + (1,) * (real_samples.dim() - 1)
    epsilon = torch.rand(eps_shape, device=real_samples.device, dtype=real_samples.dtype)
    interpolates = epsilon * real_samples + (1.0 - epsilon) * fake_samples
    interpolates.requires_grad_(True)

    critic_interpolates = critic(interpolates)
    # create_graph=True makes the penalty itself differentiable (and implies retain_graph).
    (gradients,) = torch.autograd.grad(
        outputs=critic_interpolates,
        inputs=interpolates,
        grad_outputs=torch.ones_like(critic_interpolates),
        create_graph=True,
    )

    # reshape rather than view: autograd may return a non-contiguous gradient
    # (e.g. for channels_last inputs), on which .view(batch, -1) raises.
    gradients = gradients.reshape(batch_size, -1)
    grad_norm = torch.linalg.vector_norm(gradients, ord=2, dim=1)
    return ((grad_norm - 1.0) ** 2).mean()


