import torch
import torch.nn as nn


class LatentDiscriminator(nn.Module):
    """
    MLP discriminator over flattened inputs.

    Architecture:
    FC(input_size, 256) + ReLU -> FC(256, 128) + ReLU -> FC(128, 1)
    [+ Sigmoid when ``use_sigmoid`` is true].

    With the default ``input_size`` the module expects image batches of
    shape (B, 3, 512, 512), which are flattened to (B, 786432).

    Args:
        latent_dim: Not used by the network; kept only so existing
            callers that pass it keep working.
        input_size: Number of features per sample after flattening.
            Defaults to 3 * 512 * 512.
        use_sigmoid: Whether to append a final ``nn.Sigmoid``. Defaults to
            ``True``, which bounds the output to [0, 1].
            NOTE(review): a WGAN-GP critic is normally trained on an
            unbounded score; pass ``use_sigmoid=False`` if that is the
            intended loss -- confirm against the training loop.

    Raises:
        ValueError: If ``input_size`` is not a positive integer.
    """

    def __init__(
        self,
        latent_dim: int = 256,
        *,
        input_size: int = 3 * 512 * 512,
        use_sigmoid: bool = True,
    ) -> None:
        super().__init__()
        if input_size < 1:
            raise ValueError(
                f"input_size must be a positive integer, got {input_size}"
            )
        self.input_size = input_size

        layers = [
            nn.Linear(input_size, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 1),
        ]
        # Sigmoid has no parameters, so appending it conditionally leaves the
        # state_dict keys (net.0, net.2, net.4) identical in both modes.
        if use_sigmoid:
            layers.append(nn.Sigmoid())
        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Score a batch of inputs.

        Args:
            x: Tensor whose first dimension is the batch and whose remaining
                dimensions hold ``input_size`` elements per sample.

        Returns:
            Tensor of shape (B, 1).

        Raises:
            RuntimeError: If ``x`` cannot be reshaped to (B, input_size).
        """
        # reshape (not view) so non-contiguous inputs, e.g. a permuted
        # NHWC -> NCHW batch, are accepted; naming the feature count instead
        # of -1 also keeps an empty batch (B == 0) well-defined.
        x_flat = x.reshape(x.size(0), self.input_size)
        return self.net(x_flat)


