import argparse
import os
from typing import Tuple

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.utils import save_image
from tqdm import tqdm

from src.data.dataset import ImageFolderDataset
from src.models import Encoder, Decoder, LatentGenerator, LatentDiscriminator
from src.utils.losses import gradient_penalty
from src.utils.checkpoint import save_checkpoint


def make_dataloader(data_dir: str, image_size: int, batch_size: int, num_workers: int) -> DataLoader:
    """Build the shuffled training ``DataLoader``.

    Every sample is resized to ``image_size`` x ``image_size``, converted to a
    tensor and normalised per channel with mean 0.5 and std 0.5 (three
    channels are assumed by the normalisation constants).

    Args:
        data_dir: Location handed to ``ImageFolderDataset``.
        image_size: Target height and width after resizing.
        batch_size: Number of samples per batch.
        num_workers: Number of worker processes used by the loader.

    Returns:
        A ``DataLoader`` with shuffling and pinned memory enabled.
    """
    target_hw = (image_size, image_size)
    channel_stats = [0.5, 0.5, 0.5]

    preprocessing = transforms.Compose(
        [
            transforms.Resize(target_hw),
            transforms.ToTensor(),
            # (x - 0.5) / 0.5 per channel.
            transforms.Normalize(mean=list(channel_stats), std=list(channel_stats)),
        ]
    )

    dataset = ImageFolderDataset(data_dir, transform=preprocessing)
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True,
    )
    return loader


def _parse_args() -> argparse.Namespace:
    """Parse and validate the command-line options of the training script."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_dir', type=str, required=True)
    parser.add_argument('--save_dir', type=str, required=True)
    parser.add_argument('--image_size', type=int, default=512)
    parser.add_argument('--latent_dim', type=int, default=256)
    parser.add_argument('--batch_size', type=int, default=4)
    parser.add_argument('--lr', type=float, default=2e-4)
    parser.add_argument('--critic_iters', type=int, default=5)
    parser.add_argument('--gp_lambda', type=float, default=10.0)
    parser.add_argument('--max_steps', type=int, default=10000)
    parser.add_argument('--num_workers', type=int, default=2)
    parser.add_argument('--amp', action='store_true')
    args = parser.parse_args()

    # With zero critic iterations no discriminator loss is ever computed, so
    # the progress line below would fail on an unbound ``d_loss``.
    if args.critic_iters < 1:
        parser.error('--critic_iters must be >= 1')
    return args


def _save_samples(save_dir: str, step: int, images, recon, fake_images) -> None:
    """Write 2x2 preview grids of real, reconstructed and generated batches."""
    with torch.no_grad():
        for prefix, batch in (('real', images), ('recon', recon), ('generated', fake_images)):
            save_image(
                # Undo the mean=0.5 / std=0.5 normalisation applied in
                # make_dataloader, then clamp into the displayable range.
                (batch[:4] * 0.5 + 0.5).clamp(0, 1),
                os.path.join(save_dir, f"{prefix}_{step:06d}.png"),
                nrow=2,
            )


def _save_latest(args: argparse.Namespace, step: int, encoder, decoder, gen, disc) -> None:
    """Overwrite ``latest.pt`` in the save directory with the current weights."""
    ckpt_path = os.path.join(args.save_dir, 'latest.pt')
    save_checkpoint(
        {
            'step': step,
            'encoder': encoder.state_dict(),
            'decoder': decoder.state_dict(),
            'gen': gen.state_dict(),
            'disc': disc.state_dict(),
            'args': vars(args),
        },
        ckpt_path,
    )


def main():
    """Train the autoencoder and the adversarial generator/critic pair.

    Each training step performs, on one batch:
      1. ``critic_iters`` discriminator updates (real vs. generated score
         difference plus ``gp_lambda`` times the gradient-penalty term),
      2. one generator update that maximises the discriminator score,
      3. one encoder/decoder update on an L1 reconstruction loss.

    Preview images are written every 100 steps and ``latest.pt`` every 1000
    steps, plus once more when training finishes.

    Raises:
        RuntimeError: If the dataloader yields no batches.
    """
    args = _parse_args()

    os.makedirs(args.save_dir, exist_ok=True)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    dl = make_dataloader(args.data_dir, args.image_size, args.batch_size, args.num_workers)

    encoder = Encoder(image_channels=3, latent_dim=args.latent_dim).to(device)
    decoder = Decoder(image_channels=3, latent_dim=args.latent_dim).to(device)
    gen = LatentGenerator(latent_dim=args.latent_dim).to(device)
    disc = LatentDiscriminator(latent_dim=args.latent_dim).to(device)
    # NOTE(review): the class names suggest gen/disc operate on latent codes,
    # yet below they are used on image-space tensors (disc(images), and
    # gen(noise) saved as images). Confirm against src.models that this is
    # the intended wiring.

    recon_loss = nn.L1Loss()

    opt_ae = optim.Adam(list(encoder.parameters()) + list(decoder.parameters()), lr=args.lr, betas=(0.5, 0.999))
    opt_g = optim.Adam(gen.parameters(), lr=args.lr, betas=(0.5, 0.999))
    opt_d = optim.Adam(disc.parameters(), lr=args.lr, betas=(0.5, 0.999))

    # One scaler per optimizer, each updated right after its own step. A
    # single shared scaler cannot be used here: GradScaler.step() raises if it
    # is called twice for the same optimizer without an update() in between,
    # which is exactly what the repeated critic updates do.
    scaler_d = torch.cuda.amp.GradScaler(enabled=args.amp)
    scaler_g = torch.cuda.amp.GradScaler(enabled=args.amp)
    scaler_ae = torch.cuda.amp.GradScaler(enabled=args.amp)

    step = 0
    # Context manager guarantees the bar is closed even if training raises.
    with tqdm(total=args.max_steps) as pbar:
        while step < args.max_steps:
            saw_batch = False
            for images in dl:
                saw_batch = True
                # The loader pins memory, so the copy can be asynchronous.
                images = images.to(device, non_blocking=True)

                # 1) Update discriminator on images multiple times
                for _ in range(args.critic_iters):
                    with torch.cuda.amp.autocast(enabled=args.amp):
                        noise = torch.randn(images.size(0), args.latent_dim, device=device)
                        # Detached: the generator is not trained in this phase.
                        fake_images = gen(noise).detach()

                        d_real = disc(images).mean()
                        d_fake = disc(fake_images).mean()
                        # NOTE(review): gradient_penalty is evaluated inside
                        # autocast on unscaled outputs; verify the helper is
                        # numerically safe under --amp.
                        gp = gradient_penalty(disc, images, fake_images)
                        d_loss = -(d_real - d_fake) + args.gp_lambda * gp

                    opt_d.zero_grad(set_to_none=True)
                    scaler_d.scale(d_loss).backward()
                    scaler_d.step(opt_d)
                    scaler_d.update()

                # 2) Update generator to fool the discriminator on images
                with torch.cuda.amp.autocast(enabled=args.amp):
                    noise = torch.randn(images.size(0), args.latent_dim, device=device)
                    fake_images = gen(noise)
                    g_loss = -disc(fake_images).mean()

                opt_g.zero_grad(set_to_none=True)
                # This also deposits gradients in disc's parameters; they are
                # cleared by opt_d.zero_grad() before the next critic update.
                scaler_g.scale(g_loss).backward()
                scaler_g.step(opt_g)
                scaler_g.update()

                # 3) Update autoencoder for reconstruction
                with torch.cuda.amp.autocast(enabled=args.amp):
                    z_enc = encoder(images)
                    recon = decoder(z_enc)
                    ae_loss = recon_loss(recon, images)

                opt_ae.zero_grad(set_to_none=True)
                scaler_ae.scale(ae_loss).backward()
                scaler_ae.step(opt_ae)
                scaler_ae.update()

                if step % 100 == 0:
                    _save_samples(args.save_dir, step, images, recon, fake_images)

                if step % 1000 == 0:
                    _save_latest(args, step, encoder, decoder, gen, disc)

                step += 1
                pbar.update(1)
                pbar.set_description(f"step {step} | D:{d_loss.item():.3f} G:{g_loss.item():.3f} AE:{ae_loss.item():.3f}")
                if step >= args.max_steps:
                    break

            # An empty loader would otherwise make the outer loop spin forever.
            if not saw_batch:
                raise RuntimeError(f"no training batches were produced from {args.data_dir!r}")

    # Periodic checkpoints only fire on multiples of 1000, so persist the
    # final weights explicitly once training has actually run.
    if step > 0:
        _save_latest(args, step, encoder, decoder, gen, disc)


# Run training only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()


