import argparse
import os

import torch
from PIL import Image
from torchvision import transforms
from torchvision.utils import save_image

from src.models import Encoder, Decoder, LatentGenerator
from src.utils.checkpoint import load_checkpoint


def load_models(checkpoint_path: str, latent_dim: int, device: torch.device):
    """Build the three networks, restore their weights, and set eval mode.

    Args:
        checkpoint_path: Path handed to ``load_checkpoint``.
        latent_dim: Latent dimensionality passed to every model constructor.
        device: Device the models are moved to; its string form is also
            passed to ``load_checkpoint`` as ``map_location``.

    Returns:
        The ``(encoder, decoder, gen)`` triple, each in eval mode.

    The loaded checkpoint is indexed with the keys ``'encoder'``,
    ``'decoder'`` and ``'gen'``; each entry is fed to ``load_state_dict``.
    """
    # Keys double as the checkpoint entry names; insertion order is
    # encoder -> decoder -> gen and is relied on by the loops below.
    networks = {
        'encoder': Encoder(image_channels=3, latent_dim=latent_dim).to(device),
        'decoder': Decoder(image_channels=3, latent_dim=latent_dim).to(device),
        'gen': LatentGenerator(latent_dim=latent_dim).to(device),
    }

    state = load_checkpoint(checkpoint_path, map_location=str(device))

    # Restore all weights first, then flip every module to eval mode.
    for key, module in networks.items():
        module.load_state_dict(state[key])
    for module in networks.values():
        module.eval()

    return networks['encoder'], networks['decoder'], networks['gen']


def main():
    """Blend a reference image with generator samples and save the results.

    Parses the command line, loads the models from ``--checkpoint``, writes
    the resized reference image to ``<out_dir>/reference.png`` and then writes
    ``--num_samples`` files named ``sample_XX.png``. Each sample is the
    pixel-space blend ``alpha * reference + (1 - alpha) * generated``,
    clamped to ``[0, 1]``.
    """
    parser = argparse.ArgumentParser(
        description='Blend a reference image with images sampled from the generator.',
    )
    parser.add_argument('--checkpoint', type=str, required=True,
                        help='Path to the model checkpoint.')
    parser.add_argument('--reference', type=str, required=True,
                        help='Path to the reference image.')
    parser.add_argument('--alpha', type=float, default=0.75,
                        help='Weight of the reference image in the blend; '
                             'the generated image gets 1 - alpha.')
    parser.add_argument('--num_samples', type=int, default=4,
                        help='Number of blended samples to write.')
    parser.add_argument('--latent_dim', type=int, default=256,
                        help='Latent size used to build the models and sample noise.')
    parser.add_argument('--image_size', type=int, default=512,
                        help='Side length the reference image is resized to.')
    parser.add_argument('--out_dir', type=str, required=True,
                        help='Output directory (created if it does not exist).')
    args = parser.parse_args()

    os.makedirs(args.out_dir, exist_ok=True)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # NOTE(review): the encoder and decoder are loaded but never used below;
    # blending happens directly in pixel space. Confirm whether a
    # latent-space blend (encode -> mix -> decode) was intended.
    _encoder, _decoder, gen = load_models(args.checkpoint, args.latent_dim, device)

    tfm = transforms.Compose([
        transforms.Resize((args.image_size, args.image_size)),
        transforms.ToTensor(),
        # Maps [0, 1] pixel values to [-1, 1].
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])

    # Close the source file deterministically; convert() returns an
    # independent in-memory image, so it stays valid after the handle closes.
    with Image.open(args.reference) as src:
        img = src.convert('RGB')
    x = tfm(img).unsqueeze(0).to(device)

    with torch.no_grad():
        # Undo the [-1, 1] normalisation so the reference is back in [0, 1].
        ref_img = (x * 0.5 + 0.5).clamp(0, 1)
        save_image(ref_img, os.path.join(args.out_dir, 'reference.png'))

        # Loop-invariant parts of the blend, computed once.
        ref_term = args.alpha * ref_img
        gen_weight = 1.0 - args.alpha

        for i in range(args.num_samples):
            noise = torch.randn(1, args.latent_dim, device=device)
            # Assumes gen returns an image tensor in [-1, 1] whose shape is
            # compatible with the reference image -- TODO confirm.
            gen_img = gen(noise)
            gen_img = (gen_img * 0.5 + 0.5).clamp(0, 1)

            # The clamp matters when alpha lies outside [0, 1].
            blended = (ref_term + gen_weight * gen_img).clamp(0, 1)

            save_image(blended, os.path.join(args.out_dir, f'sample_{i:02d}.png'))


# Run the CLI only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()


