import os
import numpy as np
import torch
from PIL import Image
import torchvision.transforms as T
from pytorch_fid import fid_score
import lpips

def compute_fid(real_dir, fake_dir, batch_size=32, dims=2048, device=None):
    """Compute the Frechet Inception Distance between two image directories.

    Thin wrapper around ``pytorch_fid.fid_score.calculate_fid_given_paths``.

    Args:
        real_dir: Directory of reference images.
        fake_dir: Directory of generated images.
        batch_size: Number of images per forward pass of the feature network.
        dims: Inception feature dimensionality passed through to pytorch_fid.
        device: Torch device to run on; defaults to CUDA when available,
            otherwise CPU.

    Returns:
        The FID value returned by pytorch_fid.
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    return fid_score.calculate_fid_given_paths(
        [real_dir, fake_dir],
        batch_size=batch_size,
        device=device,
        dims=dims,
    )

# File extensions treated as images when scanning a directory (lower-case, no dot).
_IMAGE_EXTENSIONS = {'bmp', 'jpg', 'jpeg', 'pgm', 'png', 'ppm', 'tif', 'tiff', 'webp'}


def _load_rgb_tensors(image_dir, limit):
    """Load up to ``limit`` images from ``image_dir`` as (1, 3, H, W) tensors in [0, 1].

    Files are taken in sorted filename order so results are reproducible;
    entries that are not regular files with an image extension are skipped.
    """
    names = sorted(
        name for name in os.listdir(image_dir)
        if os.path.isfile(os.path.join(image_dir, name))
        and os.path.splitext(name)[1].lower().lstrip('.') in _IMAGE_EXTENSIONS
    )
    to_tensor = T.ToTensor()
    tensors = []
    for name in names[:limit]:
        # Context manager closes the underlying file handle PIL keeps open.
        with Image.open(os.path.join(image_dir, name)) as img:
            tensors.append(to_tensor(img.convert('RGB')).unsqueeze(0))
    return tensors


def compute_diversity_score(fake_dir, n_samples=100, device=None):
    """Estimate sample diversity as the mean LPIPS distance between consecutive images.

    Up to ``n_samples`` images from ``fake_dir`` (sorted by filename) are
    loaded as RGB, and the LPIPS distance (AlexNet backbone) is computed
    between each image and the next one in that order.

    Args:
        fake_dir: Directory containing generated images.
        n_samples: Maximum number of images to use.
        device: Torch device for the LPIPS network; defaults to CUDA when
            available, otherwise CPU.

    Returns:
        Mean LPIPS distance over the ``k - 1`` consecutive pairs, where ``k``
        is the number of images actually loaded.

    Raises:
        ValueError: If fewer than two images are available.
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    imgs = _load_rgb_tensors(fake_dir, n_samples)
    if len(imgs) < 2:
        raise ValueError(
            f'need at least 2 images in {fake_dir!r} to compute diversity, found {len(imgs)}'
        )

    loss_fn = lpips.LPIPS(net='alex').to(device)
    scores = []
    # Inference only: avoid building an autograd graph for every pair.
    with torch.no_grad():
        for first, second in zip(imgs, imgs[1:]):
            # ToTensor yields [0, 1]; normalize=True rescales to the [-1, 1]
            # range LPIPS expects.
            d = loss_fn(first.to(device), second.to(device), normalize=True)
            scores.append(d.item())
    return np.mean(scores)