import torch
from PIL import Image
from torchvision import transforms
from lpips import LPIPS


# Shared LPIPS metric instance (net='alex'), built once at import time.
# eval() hands back the module itself, so construction and the switch to
# evaluation mode can be chained in a single expression.
lpips_model = LPIPS(net='alex').eval()

# Image preprocessing function
def load_image_to_tensor(path) -> torch.Tensor:
    """
    Load an image file and convert it to a normalized PyTorch tensor.

    Args:
        path: Location of the image file; passed straight to ``Image.open``.

    Returns:
        A tensor of shape [1, 3, H, W] (a batch of one RGB image) whose
        values are scaled to the range [-1, 1].
    """
    # Open inside a context manager so the underlying file is closed
    # deterministically, even if decoding fails. convert() loads the pixel
    # data and returns an independent image that stays valid afterwards.
    with Image.open(path) as img:
        rgb_img = img.convert('RGB')  # Ensure it is an RGB image

    transform = transforms.Compose([
        transforms.ToTensor(),           # [0, 1]
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # [-1, 1]
    ])
    return transform(rgb_img).unsqueeze(0)  # add batch : [1, 3, H, W]

# Input files for the comparison (edit both paths to match your local setup)
clean_path = r'F:\LW\balujia\coco\secretkk\00000.png'     # Original clean image
stego_path = r'F:\LW\balujia\coco\secret-revkk\00000.png'     # Steg image

# Load both files with the same preprocessing: clean image first, then stego
img_clean, img_stego = map(load_image_to_tensor, (clean_path, stego_path))

# Ensure consistent image dimensions before computing any metric.
# An explicit raise is used instead of `assert` because assertions are
# stripped when Python runs with -O, which would silently skip this check.
if img_clean.shape != img_stego.shape:
    raise ValueError("The image dimensions do not match, please ensure that both images are of the same size.")

# compute LPIPS
def compute_lpips(img1, img2):
    """
    Return the LPIPS distance between two image tensors as a Python float.

    The module-level ``lpips_model`` is called on the pair with gradient
    tracking disabled, and its output is averaged into a single scalar.
    """
    # Only the forward pass needs to run without autograd bookkeeping.
    with torch.no_grad():
        distance = lpips_model(img1, img2)
    return distance.mean().item()

# compute NCC
def compute_ncc(img1: torch.Tensor, img2: torch.Tensor) -> float:
    """
    Compute the normalized cross-correlation (NCC) between two image batches.

    Every sample is flattened to a vector and mean-centred; the NCC of a
    sample pair is the dot product of the centred vectors divided by the
    product of their L2 norms.

    Args:
        img1: Floating-point tensor whose first dimension is the batch.
        img2: Tensor expected to have the same shape as ``img1``.

    Returns:
        The NCC averaged over the batch, as a Python float. If either
        sample is constant (zero variance) the numerator is zero and the
        epsilon in the denominator makes the result 0 instead of NaN.
    """
    # Flatten each sample. reshape (unlike view) also accepts non-contiguous
    # inputs such as permuted or sliced tensors, copying only when required.
    img1_flat = img1.reshape(img1.shape[0], -1)
    img2_flat = img2.reshape(img2.shape[0], -1)

    # Remove the per-sample mean
    img1_flat = img1_flat - img1_flat.mean(dim=1, keepdim=True)
    img2_flat = img2_flat - img2_flat.mean(dim=1, keepdim=True)

    # Numerator: per-sample dot product of the centred vectors
    numerator = torch.sum(img1_flat * img2_flat, dim=1)

    # Denominator: product of the L2 norms of the two centred vectors
    denominator = torch.norm(img1_flat, p=2, dim=1) * torch.norm(img2_flat, p=2, dim=1)

    # Prevent division by zero
    epsilon = 1e-8
    ncc = numerator / (denominator + epsilon)

    return ncc.mean().item()


# Evaluate both similarity metrics on the clean/stego pair
lpips_score = compute_lpips(img_clean, img_stego)
ncc_score = compute_ncc(img_clean, img_stego)

# Report each metric on its own line, formatted to four decimal places
print(
    f"LPIPS Score: {lpips_score:.4f}",
    f"NCC Score: {ncc_score:.4f}",
    sep="\n",
)