import torch
import torch.nn as nn


class LatentGenerator(nn.Module):
    """
    Generator that maps a noise vector s ∈ R^d to an image x' ∈ R^(3 x S x S).

    FC → reshape to 512x4x4 → series of Deconv2D(5x5, stride 2) + BN + ReLU
    → final Deconv2D(5x5, stride 2) to 3 channels → Tanh.

    Every deconvolution exactly doubles the spatial size, so k stages give an
    output side of S = 4 * 2**k. With the default ``image_size=64`` there are
    four stages (4 → 8 → 16 → 32 → 64), i.e. the module produces 3x64x64
    images; pass ``image_size=512`` to obtain 3x512x512 images.
    Output values are in the range [-1, 1].

    Args:
        latent_dim: Dimensionality d of the input noise vector.
        image_size: Side length S of the square output image. Must be of the
            form 4 * 2**k with k >= 1 (8, 16, 32, 64, 128, ...).

    Raises:
        ValueError: If ``image_size`` is not of the form 4 * 2**k, k >= 1.
    """

    _BASE_CHANNELS = 512  # channel count of the 4x4 seed feature map
    _MIN_CHANNELS = 8  # floor on hidden channel width for very deep stacks

    def __init__(self, latent_dim: int = 256, image_size: int = 64) -> None:
        super().__init__()
        num_stages = self._num_upsampling_stages(image_size)
        self.latent_dim = latent_dim
        self.image_size = image_size

        self.fc = nn.Linear(latent_dim, self._BASE_CHANNELS * 4 * 4)

        # Layers are appended in the same order as the original hand-written
        # Sequential (ConvT, BN, ReLU, ..., ConvT, Tanh), so for the default
        # image_size the state_dict keys (net.0, net.1, ...) and the order of
        # random weight initialisation are unchanged.
        layers = []
        in_ch = self._BASE_CHANNELS
        for _ in range(num_stages - 1):
            # Halve the width each stage: 512 → 256 → 128 → 64 for the default.
            out_ch = max(in_ch // 2, self._MIN_CHANNELS)
            layers.append(self._deconv(in_ch, out_ch))
            layers.append(nn.BatchNorm2d(out_ch))
            layers.append(nn.ReLU(inplace=True))
            in_ch = out_ch
        # Final stage projects to 3 channels; Tanh squashes into [-1, 1].
        layers.append(self._deconv(in_ch, 3))
        layers.append(nn.Tanh())
        self.net = nn.Sequential(*layers)

    @staticmethod
    def _deconv(in_ch: int, out_ch: int) -> nn.ConvTranspose2d:
        """Build a 5x5 transposed conv that exactly doubles height and width."""
        # H_out = (H_in - 1) * 2 - 2 * 2 + 5 + 1 = 2 * H_in
        return nn.ConvTranspose2d(
            in_ch, out_ch, kernel_size=5, stride=2, padding=2, output_padding=1
        )

    @staticmethod
    def _num_upsampling_stages(image_size: int) -> int:
        """Return k such that image_size == 4 * 2**k, validating the size."""
        if not isinstance(image_size, int) or isinstance(image_size, bool):
            raise ValueError(f"image_size must be an int, got {image_size!r}")
        ratio, rem = divmod(image_size, 4)
        # ratio must be a power of two >= 2 (at least one upsampling stage).
        if rem or ratio < 2 or ratio & (ratio - 1):
            raise ValueError(
                f"image_size must be 4 * 2**k with k >= 1 (e.g. 64, 512), got {image_size}"
            )
        return ratio.bit_length() - 1

    def forward(self, noise: torch.Tensor) -> torch.Tensor:
        """Map a batch of noise vectors to images.

        Args:
            noise: Tensor of shape (N, latent_dim).

        Returns:
            Tensor of shape (N, 3, image_size, image_size) with values in [-1, 1].
        """
        h = self.fc(noise)
        # Reshape the flat projection into the (N, 512, 4, 4) seed feature map.
        h = h.view(h.size(0), self._BASE_CHANNELS, 4, 4)
        return self.net(h)


