import torch
import torch.nn as nn

class ConvBlock(nn.Module):
    """Conv2d 3x3 (stride 1, padding 1) -> BatchNorm2d -> LeakyReLU(0.2).

    With a 3x3 kernel, stride 1 and padding 1 the spatial size of the input
    is preserved; only the channel count changes from ``in_ch`` to ``out_ch``.

    Args:
        in_ch: Number of input channels.
        out_ch: Number of output channels.
    """

    def __init__(self, in_ch: int, out_ch: int) -> None:
        super().__init__()
        layers = [
            nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_ch),
            # In-place is safe here: the BatchNorm output feeds only this activation.
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
        ]
        # The attribute name ``block`` and the layer order determine the
        # state_dict keys (block.0.*, block.1.*), so keep them stable.
        self.block = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply convolution, batch normalization and activation to ``x``."""
        return self.block(x)

class PsiNet(nn.Module):
    """Stack of ``ConvBlock(ch, ch)`` layers that keeps channels and spatial size.

    Args:
        ch: Number of input and output channels of every block.
        num_blocks: How many blocks to stack (keyword-only). The default of 3
            reproduces the original fixed depth, so existing checkpoints keep
            the same ``net.0`` ... ``net.2`` state_dict keys.

    Raises:
        ValueError: If ``num_blocks`` is less than 1.
    """

    def __init__(self, ch: int, *, num_blocks: int = 3) -> None:
        super().__init__()
        if num_blocks < 1:
            raise ValueError(f"num_blocks must be >= 1, got {num_blocks}")
        # Blocks are created in order, so seeded initialization matches the
        # original three explicit ConvBlock(ch, ch) calls when num_blocks == 3.
        self.net = nn.Sequential(*(ConvBlock(ch, ch) for _ in range(num_blocks)))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through every block in sequence."""
        return self.net(x)

class Generator(nn.Module):
    """Residual generator: refines ``prev_up`` with a noise-perturbed correction.

    The output is ``prev_up + psi(z + prev_up)``.

    Args:
        ch: Channel count of the internal ``PsiNet``.
        noise_ch: Stored as ``self.noise_ch``; this class does not use it to
            build any layer.
    """

    def __init__(self, ch: int, noise_ch: int) -> None:
        super().__init__()
        self.psi = PsiNet(ch)
        # NOTE(review): ``z`` is combined with ``prev_up`` by plain addition.
        # If ``z`` really has ``noise_ch`` channels (as the forward comment
        # says), ``noise_ch`` must equal ``ch`` or be 1 for it to broadcast.
        # Confirm with callers.
        self.noise_ch = noise_ch

    def forward(self, z: torch.Tensor, prev_up: torch.Tensor) -> torch.Tensor:
        """Return ``prev_up`` plus the residual ``psi(z + prev_up)``.

        Args:
            z: Noise tensor, added element-wise (with broadcasting) to ``prev_up``.
            prev_up: Previous output, expected to have ``ch`` channels for ``psi``.
        """
        # z: [B, noise_ch, H, W], prev_up: [B, ch, H, W]
        perturbed = z + prev_up
        correction = self.psi(perturbed)
        return prev_up + correction

class Discriminator(nn.Module):
    """Fully convolutional discriminator producing a 1-channel score map.

    It uses three 3x3, stride-1, padding-1 layers, so the score map has the
    same spatial size as the input.

    Args:
        ch: Number of input channels (also the hidden width).
        head_norm_act: Keyword-only.
            If True (default, the original architecture), the last layer is a
            full ``ConvBlock(ch, 1)``, so the scores pass through
            ``BatchNorm2d(1)`` and ``LeakyReLU``.
            If False, the last layer is a plain ``nn.Conv2d(ch, 1, 3, 1, 1)``
            that emits raw, unnormalized scores.
    """

    def __init__(self, ch: int, *, head_norm_act: bool = True) -> None:
        super().__init__()
        # Build the body before the head so that, with the default head, modules
        # are created (and randomly initialized) in the original order.
        body = [ConvBlock(ch, ch), ConvBlock(ch, ch)]
        if head_norm_act:
            head = ConvBlock(ch, 1)
        else:
            # In training mode BatchNorm normalizes the single score channel over
            # each batch. When real and fake batches are scored in separate
            # forward passes, that can wash out the gap between their mean
            # scores. A plain conv head avoids this. It is opt-in so checkpoints
            # saved with the default head (net.2.block.*) still load.
            head = nn.Conv2d(ch, 1, kernel_size=3, stride=1, padding=1)
        self.net = nn.Sequential(*body, head)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the per-location score map for ``x`` (1 output channel)."""
        return self.net(x)