import torch
import torch.optim as optim
from model import Generator, Discriminator
from pyramid import create_pyramid
from utils import save_tensor_img
from PIL import Image
import os
from torch.autograd import grad

# --- WGAN-GP gradient penalty term ---
def calc_gradient_penalty(D, real_data, fake_data, device, lambda_gp=10.0):
    """Return the WGAN-GP gradient penalty of critic ``D``.

    The critic is evaluated on random per-sample interpolations between
    ``real_data`` and ``fake_data``; the penalty is the mean squared deviation
    of the per-sample gradient L2 norm from 1, scaled by ``lambda_gp``.

    Args:
        D: callable mapping a batch tensor to critic scores.
        real_data: batch of real samples. The mixing coefficients are created
            with shape ``(batch, 1, 1, 1)``, so 4-D inputs are expected.
        fake_data: batch of generated samples, broadcast-compatible with
            ``real_data``.
        device: device on which the mixing coefficients are created; it has
            to match the device of the two data tensors.
        lambda_gp: weight applied to the penalty.

    Returns:
        Scalar tensor. The gradient is taken with ``create_graph=True``, so
        the penalty can be back-propagated into ``D``'s parameters.
    """
    batch_size = real_data.size(0)
    # One mixing coefficient per sample, broadcast over the remaining dims.
    alpha = torch.rand(batch_size, 1, 1, 1, device=device)
    interpolates = (alpha * real_data + (1 - alpha) * fake_data).requires_grad_(True)
    d_interpolates = D(interpolates)
    gradients = grad(
        outputs=d_interpolates,
        inputs=interpolates,
        # Seed of ones: differentiates the sum of the critic outputs.
        grad_outputs=torch.ones_like(d_interpolates),
        create_graph=True,
        retain_graph=True,
    )[0]
    # reshape (not view): autograd may hand back a non-contiguous gradient,
    # e.g. when the critic's backward ends in a permute, and view() would raise.
    gradients = gradients.reshape(batch_size, -1)
    return ((gradients.norm(2, dim=1) - 1) ** 2).mean() * lambda_gp


def train_single_image(img_path, N=5, r=0.75, ch=32, noise_ch=3, alpha=10, iters=2000, save_dir='output'):
    """Train one Generator/Discriminator pair per pyramid scale on one image.

    Scales are visited from index ``N`` down to ``0``. At every scale a fresh
    generator and critic are trained with a WGAN-GP critic loss and an
    adversarial + MSE-reconstruction generator loss. The last generator output
    of a scale is bilinearly resized and fed to the next scale's generator as
    its second argument.

    Args:
        img_path: path of the training image; opened with PIL as RGB.
        N: highest pyramid index; ``N + 1`` scales are trained.
        r: forwarded to ``create_pyramid`` (presumably the per-scale resize
            ratio -- confirm against pyramid.py).
        ch: forwarded to the ``Generator`` and ``Discriminator`` constructors.
        noise_ch: forwarded to the ``Generator`` constructor.
        alpha: weight of the reconstruction term in the generator loss.
        iters: optimisation steps per scale.
        save_dir: directory (created if missing) receiving sample images,
            one checkpoint per scale and the final ``model_pyramid.pth``.

    Returns:
        ``(Gs, Ds)``: the trained generators and critics, moved to CPU, in
        training order (scale ``N`` first, scale ``0`` last).
    """
    # convert() returns an independent image, so the file handle can be
    # released right away instead of lingering until garbage collection.
    with Image.open(img_path) as src:
        img = src.convert('RGB')
    pyramid = create_pyramid(img, N, r)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    interpolate = torch.nn.functional.interpolate
    mse_loss = torch.nn.functional.mse_loss
    Gs, Ds = [], []
    prev_up = None
    os.makedirs(save_dir, exist_ok=True)

    for n in reversed(range(N + 1)):
        x_n = pyramid[n].to(device)
        G = Generator(ch, noise_ch).to(device)
        D = Discriminator(ch).to(device)
        opt_G = optim.Adam(G.parameters(), lr=2e-4, betas=(0.5, 0.999))
        opt_D = optim.Adam(D.parameters(), lr=2e-4, betas=(0.5, 0.999))

        # Both generator inputs are constant for the whole scale, so they are
        # built once here rather than on every iteration. Scale N (the first
        # one trained) uses one fixed Gaussian noise map; every later scale
        # uses zero noise.
        z = torch.randn_like(x_n) if n == N else torch.zeros_like(x_n)
        if prev_up is None:
            prev = torch.zeros_like(x_n)
        else:
            # prev_up stays on `device`; a CPU copy here would clash with the
            # generator's device when training on CUDA.
            prev = interpolate(prev_up.to(device), size=x_n.shape[-2:], mode='bilinear', align_corners=False)

        G.train()
        D.train()
        fake = None

        for it in range(iters):
            # Critic update (WGAN-GP). The generator output is only used
            # detached here, so no autograd graph is needed for it.
            with torch.no_grad():
                fake_for_d = G(z, prev)
            opt_D.zero_grad()
            real_out = D(x_n)
            fake_out = D(fake_for_d)
            d_loss = fake_out.mean() - real_out.mean()
            gp = calc_gradient_penalty(D, x_n, fake_for_d, device)
            d_loss_total = d_loss + gp
            d_loss_total.backward()
            opt_D.step()

            # Generator update (adversarial + reconstruction loss).
            opt_G.zero_grad()
            fake = G(z, prev)
            adv_loss = -D(fake).mean()
            # The reconstruction target uses the same (noise, previous) pair
            # at every scale. It is kept as its own forward pass so results
            # are unchanged should the generator contain stochastic layers.
            rec_loss = mse_loss(G(z, prev), x_n)
            g_loss = adv_loss + alpha * rec_loss
            g_loss.backward()
            opt_G.step()

            # Periodic progress report and sample image.
            if it % 500 == 0 or it == iters - 1:
                print(f"Scale {n} Iter {it}: D_loss={d_loss_total.item():.4f}, G_loss={g_loss.item():.4f}, Rec_loss={rec_loss.item():.4f}")
                save_tensor_img(fake, os.path.join(save_dir, f'scale{n}_iter{it}.png'))

        if fake is None:
            # iters <= 0: no training step ran; still produce an output so the
            # next scale has something to build on.
            with torch.no_grad():
                fake = G(z, prev)
        prev_up = fake.detach()

        # Move the finished networks to CPU, keep them, and checkpoint the scale.
        Gs.append(G.cpu())
        Ds.append(D.cpu())
        torch.save({'G': G.state_dict(), 'D': D.state_dict()}, os.path.join(save_dir, f'checkpoint_scale{n}.pth'))

    # Save final model list
    torch.save({'Gs': [g.state_dict() for g in Gs], 'Ds': [d.state_dict() for d in Ds]}, os.path.join(save_dir, 'model_pyramid.pth'))
    return Gs, Ds

# Example usage:
# Gs, Ds = train_single_image('All-Paintings/Harvard/harvard_0.jpg') 