import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
from perlin_noise import PerlinNoise
import json
import glob

# Function to generate a Perlin noise map
def generate_perlin_noise(shape, scale=10, octaves=6, seed=1):
    """Sample a Perlin noise map of the given shape.

    Element ``[i, j]`` is the noise value at coordinate ``(i / scale, j / scale)``,
    so a larger *scale* makes neighbouring pixels sample closer points and the
    map varies more slowly.

    Args:
        shape: Shape of the output array. Only ``shape[0]`` and ``shape[1]`` are
            iterated; if extra dimensions are present, every ``[i, j]`` sub-array
            is filled with the same value.
        scale: Divisor applied to the pixel coordinates before sampling.
        octaves: Passed to ``PerlinNoise`` (default 6, as before).
        seed: Passed to ``PerlinNoise`` (default 1, as before); a fixed seed
            presumably makes the map reproducible between calls.

    Returns:
        ``float32`` array of *shape*. NOTE(review): callers rescale values with
        ``(noise + 1) / 2``, i.e. they assume roughly [-1, 1] — confirm against
        the perlin_noise package.
    """
    noise = PerlinNoise(octaves=octaves, seed=seed)
    noise_map = np.zeros(shape, dtype=np.float32)
    # Pure-Python per-pixel sampling: cost grows with shape[0] * shape[1].
    for i in range(shape[0]):
        for j in range(shape[1]):
            noise_map[i, j] = noise([i / scale, j / scale])
    return noise_map

def load_state(save_path):
    """Restore labelling progress from a JSON state file.

    Args:
        save_path: Path of the state file.

    Returns:
        ``(current_image_idx, selected_areas)`` read from the file's
        ``"current_image_idx"`` and ``"selected_areas"`` keys, or ``(0, [])``
        when no file exists at *save_path*.
    """
    if not os.path.exists(save_path):
        return 0, []

    with open(save_path, 'r') as state_file:
        saved = json.load(state_file)
    return saved["current_image_idx"], saved["selected_areas"]

# Function to map mask ratio to intensity range
def map_mask_ratio_to_intensity(mask_ratio, min_ratio=0.00025, max_ratio=0.006, min_intensity=0.25, max_intensity=0.5):
    """Map a mask-area ratio to a darkening intensity (small masks -> strong effect).

    The ratio is normalised to [0, 1] between *min_ratio* and *max_ratio*
    (clipped outside that range), eased with a square root, and inverted:
    ratios at or below *min_ratio* give *max_intensity*, ratios at or above
    *max_ratio* give *min_intensity*. The square root makes the intensity drop
    fastest for ratios just above *min_ratio*.

    Args:
        mask_ratio: Fraction of the image covered by the mask (scalar or array).
        min_ratio, max_ratio: Ratio range mapped onto the intensity range.
        min_intensity, max_intensity: Output intensity bounds.

    Returns:
        Intensity in ``[min_intensity, max_intensity]`` (NumPy scalar or array).
    """
    ratio_span = max_ratio - min_ratio
    intensity_span = max_intensity - min_intensity
    position = np.clip((mask_ratio - min_ratio) / ratio_span, 0, 1)
    return min_intensity + intensity_span * (1 - np.sqrt(position))

# Function to create a defined light border around the lesion area
def add_light_border(image, mask, border_intensity):
    """Brighten a soft, noisy band along the outline of *mask*.

    The outline is found with Canny edge detection (thresholds 100/200),
    thickened with one 5x5 dilation and feathered with a 21x21 Gaussian blur.
    A Perlin noise map (scale 20), rescaled with ``(noise + 1) / 2`` and
    multiplied by *border_intensity*, modulates the brightening. Each pixel of
    the first three channels is multiplied by ``1 + modulation * band / 255``
    and clipped to [0, 255].

    Args:
        image: Colour image indexed ``[row, col, channel]``; channels 0-2 are
            modified, any further channels are copied unchanged.
        mask: Single-channel mask passed to ``cv2.Canny``.
        border_intensity: Peak relative brightening of the band.

    Returns:
        A new image; *image* itself is not modified.
    """
    edges = cv2.Canny(mask, 100, 200)
    thick_edges = cv2.dilate(edges, np.ones((5, 5), np.uint8), iterations=1)
    halo = cv2.GaussianBlur(thick_edges, (21, 21), 0)

    noise = generate_perlin_noise(halo.shape, scale=20)
    modulation = (noise + 1) / 2 * border_intensity
    gain = 1 + modulation * (halo / 255.0)

    result = image.copy()
    for channel in range(3):
        brightened = image[:, :, channel] * gain
        result[:, :, channel] = np.clip(brightened, 0, 255).astype(np.uint8)
    return result

# Function to apply a darkening effect with natural blending and average it with original values
def apply_lesion(image, mask, base_min_intensity=0.2, base_max_intensity=0.4, noise_strength=0.0):
    """Darken the masked region of *image* to simulate a lesion.

    The darkening intensity comes from the mask's area relative to the image
    via ``map_mask_ratio_to_intensity`` (smaller masks are darkened more). The
    mask is feathered with a 21x21 Gaussian blur, and the darkened pixels are
    blended back over the original with that soft weight. Masks covering at
    most 0.1% of the image additionally get a light border from
    ``add_light_border``.

    Args:
        image: Colour image indexed ``[row, col, channel]`` (presumably BGR
            uint8 from ``cv2.imread`` — confirm). Only channels 0-2 are
            modified.
        mask: uint8 mask with the same height/width as *image*; non-zero
            pixels mark the lesion.
        base_min_intensity, base_max_intensity: Intensity bounds passed to
            ``map_mask_ratio_to_intensity``.
        noise_strength: Relative amplitude of Perlin-noise variation around
            the mapped intensity (expected in [0, 1]). ``0`` (default) gives
            uniform darkening, which is what the previous implementation
            produced.

    Returns:
        A new uint8 image; *image* itself is not modified.
    """
    mask_area = np.count_nonzero(mask > 0)
    image_area = mask.shape[0] * mask.shape[1]
    mask_ratio = mask_area / image_area

    intensity = map_mask_ratio_to_intensity(
        mask_ratio, min_intensity=base_min_intensity, max_intensity=base_max_intensity
    )

    if noise_strength > 0:
        # Spread the intensity into a band around the mapped value; noise is
        # rescaled with (noise + 1) / 2, assuming it lies roughly in [-1, 1].
        noise_map = generate_perlin_noise(mask.shape[:2], scale=50)
        low = intensity * (1 - noise_strength)
        high = intensity * (1 + noise_strength)
        intensity_map = low + (high - low) * (noise_map + 1) / 2
    else:
        # Uniform darkening. The old code derived min and max intensity from
        # identical arguments, so its Perlin term was multiplied by zero;
        # skipping the per-pixel noise generation yields the same effect.
        intensity_map = intensity

    # Soft per-pixel weight in [0, 1] that fades out at the mask edge.
    normalized_mask = cv2.GaussianBlur(mask, (21, 21), 0) / 255.0
    darken_factor = (1 - intensity_map * normalized_mask)[:, :, np.newaxis]
    weight = normalized_mask[:, :, np.newaxis]

    # Work in float so the darkened pixels are not truncated to uint8 before
    # blending (the result is only quantised once, below).
    original = image[:, :, :3].astype(np.float64)
    darkened = original * darken_factor
    # Blending with the same weight again gives an effective darkening of
    # intensity * weight**2, which softens the lesion boundary further.
    blended = darkened * weight + original * (1 - weight)

    final_image = image.copy()
    final_image[:, :, :3] = blended.astype(np.uint8)

    if mask_ratio <= 0.001:
        border_intensity = 0.4 * (1 - mask_ratio)
        final_image = add_light_border(final_image, mask, border_intensity)

    return final_image

# Function to crop the image around the lesion mask
def crop_around_mask(image, mask, crop_size=(200, 200)):
    """Crop *image* and *mask* to a window centred on the mask's bounding box.

    The window is shifted (not shrunk) to stay inside the image. It is
    therefore exactly *crop_size* whenever the image is at least that large,
    and clipped to the image only when the image itself is smaller.

    Args:
        image: Array indexed ``[row, col, ...]``.
        mask: 2-D array with the same height/width as *image*; values > 0 mark
            the lesion.
        crop_size: ``(width, height)`` of the window.

    Returns:
        ``(image_crop, mask_crop)``, or ``None`` if *mask* has no positive
        pixels.
    """
    rows, cols = np.nonzero(mask > 0)
    if rows.size == 0:
        return None  # No lesion area found

    crop_w, crop_h = crop_size
    img_h, img_w = image.shape[:2]

    center_x = int((cols.min() + cols.max()) // 2)
    center_y = int((rows.min() + rows.max()) // 2)

    # Clamp the top-left corner so the window slides inward at the right and
    # bottom edges instead of being truncated there.
    x1 = max(0, min(center_x - crop_w // 2, img_w - crop_w))
    y1 = max(0, min(center_y - crop_h // 2, img_h - crop_h))
    x2 = min(img_w, x1 + crop_w)
    y2 = min(img_h, y1 + crop_h)

    return image[y1:y2, x1:x2], mask[y1:y2, x1:x2]

# Function to visualize cropped images and save them
def visualize_and_save_crops(original_img, lesion_mask, synthetic_img, final_img, image_name, crop_size=(200, 200)):
    """Crop every processing stage around the lesion, save the crops and plot them.

    Five PNGs are written, relative to the current working directory:
    ``{image_name}_original``, ``_with_mask`` (red mask overlay),
    ``_darkened_noise``, ``_final_with_border`` and ``_final_with_box``
    (bounding box around the mask). The same five crops are then shown side by
    side with matplotlib.

    Args:
        original_img, synthetic_img, final_img: BGR images of the same size.
        lesion_mask: uint8 mask; non-zero pixels mark the lesion.
        image_name: Prefix for the output file names.
        crop_size: ``(width, height)`` passed to ``crop_around_mask``.
    """
    # Check for an empty mask *before* unpacking: crop_around_mask returns None then.
    crop = crop_around_mask(original_img, lesion_mask, crop_size)
    if crop is None:
        print("No lesion area found in the mask. Skipping this image.")
        return
    cropped_original, cropped_mask = crop
    cropped_synthetic = crop_around_mask(synthetic_img, lesion_mask, crop_size)[0]
    cropped_final = crop_around_mask(final_img, lesion_mask, crop_size)[0]

    # Paint the mask in red (OpenCV images are BGR) on the original crop.
    masked_image = cropped_original.copy()
    masked_image[cropped_mask > 0] = (0, 0, 255)

    # Rectangular bounding box around the mask on the final crop; (0, 0, 255) is red in BGR.
    x, y, w, h = cv2.boundingRect(cropped_mask)
    boxed_image = cropped_final.copy()
    cv2.rectangle(boxed_image, (x, y), (x + w, y + h), (0, 0, 255), 5)

    cv2.imwrite(f'{image_name}_original.png', cropped_original)
    cv2.imwrite(f'{image_name}_with_mask.png', masked_image)
    cv2.imwrite(f'{image_name}_darkened_noise.png', cropped_synthetic)
    cv2.imwrite(f'{image_name}_final_with_border.png', cropped_final)
    cv2.imwrite(f'{image_name}_final_with_box.png', boxed_image)

    panels = [
        ('Cropped Original Image', cropped_original),
        ('Original with Mask', masked_image),
        ('Darkened & Noise Applied', cropped_synthetic),
        ('Noise & Border Applied', cropped_final),
        ('Final with Bounding Box', boxed_image),
    ]
    plt.figure(figsize=(15, 5))
    for position, (title, panel) in enumerate(panels, start=1):
        plt.subplot(1, len(panels), position)
        plt.title(title)
        plt.imshow(cv2.cvtColor(panel, cv2.COLOR_BGR2RGB))
    plt.show()

def select_lesion_area(image, brush_radius=5):
    """Let the user paint a lesion mask over *image* in an OpenCV window.

    Hold the left mouse button and drag to paint filled circles of
    *brush_radius*; releasing the button paints one more. Painted pixels are
    previewed in red. Press Esc, or close the window, to finish.

    Args:
        image: BGR image to paint on (not modified).
        brush_radius: Radius in pixels of the painting brush (default 5).

    Returns:
        uint8 mask with *image*'s height/width: 255 where painted, 0 elsewhere.
    """
    window_name = 'Select Lesion Area'
    selected_mask = np.zeros(image.shape[:2], dtype=np.uint8)
    drawing = False

    def paint_lesion(event, x, y, flags, param):
        nonlocal drawing

        if event == cv2.EVENT_LBUTTONDOWN:
            drawing = True

        elif event == cv2.EVENT_MOUSEMOVE:
            if drawing:
                cv2.circle(selected_mask, (x, y), brush_radius, 255, -1)

        elif event == cv2.EVENT_LBUTTONUP:
            drawing = False
            cv2.circle(selected_mask, (x, y), brush_radius, 255, -1)

    cv2.namedWindow(window_name)
    cv2.setMouseCallback(window_name, paint_lesion)

    try:
        while True:
            preview_image = image.copy()
            preview_image[selected_mask > 0] = (0, 0, 255)

            cv2.imshow(window_name, preview_image)
            key = cv2.waitKey(1) & 0xFF

            if key == 27:  # Escape key to exit
                break
            # If the user closed the window via the title bar, finish too;
            # otherwise imshow would recreate it, presumably without the mouse
            # callback, and painting would silently stop working.
            if cv2.getWindowProperty(window_name, cv2.WND_PROP_VISIBLE) < 1:
                break
    finally:
        cv2.destroyAllWindows()

    return selected_mask
    
def save_state(current_image_idx, selected_areas, save_path):
    """Persist labelling progress to *save_path* as JSON.

    Writes the ``"current_image_idx"`` and ``"selected_areas"`` keys that
    ``load_state`` reads back. The data goes to a temporary sibling file that
    is then moved over *save_path*, so a crash mid-write does not truncate the
    existing state file.
    """
    state = {
        "current_image_idx": int(current_image_idx),
        "selected_areas": list(selected_areas),
    }
    tmp_path = f"{save_path}.tmp"
    with open(tmp_path, 'w') as f:
        json.dump(state, f)
    os.replace(tmp_path, save_path)


def plot_histogram(selected_areas):
    """Show a histogram of the lesion areas saved so far."""
    if not selected_areas:
        print("No lesion areas saved yet.")
        return
    plt.figure(figsize=(8, 5))
    plt.hist(selected_areas, bins='auto')
    plt.title('Saved lesion areas')
    plt.xlabel('Area (sum of mask values)')
    plt.ylabel('Count')
    plt.show()


# Main function
def main():
    """Interactively paint a synthetic lesion on each input image and record the decision.

    Progress (index of the next image and the areas of saved lesions) is kept
    in ``state.json``, so an interrupted session resumes where it stopped.
    """
    input_dir = 'input_images'  # Directory containing input images
    state_save_path = 'state.json'  # Path to save current state

    # glob order is arbitrary; sort so the persisted index always refers to
    # the same file across runs.
    image_paths = sorted(glob.glob(os.path.join(input_dir, '*.png')))
    current_image_idx, selected_areas = load_state(state_save_path)

    while current_image_idx < len(image_paths):
        image_path = image_paths[current_image_idx]
        non_lesion_img = cv2.imread(image_path)
        if non_lesion_img is None:
            print(f"Could not read {image_path}. Skipping this image.")
            current_image_idx += 1
            save_state(current_image_idx, selected_areas, state_save_path)
            continue
        image_name = os.path.splitext(os.path.basename(image_path))[0]

        lesion_mask = select_lesion_area(non_lesion_img)

        # Nothing painted: skip this image.
        if not np.any(lesion_mask):
            current_image_idx += 1
            save_state(current_image_idx, selected_areas, state_save_path)
            continue

        synthetic_lesion_img = apply_lesion(non_lesion_img, lesion_mask)
        # NOTE(review): apply_lesion already adds a light border for very small
        # masks, so those receive a second border here — confirm intended.
        final_img_with_border = add_light_border(synthetic_lesion_img, lesion_mask, 0.5)

        visualize_and_save_crops(non_lesion_img, lesion_mask, synthetic_lesion_img, final_img_with_border, image_name)

        # Keep prompting until the user saves or discards; showing the
        # histogram no longer forces the lesion to be repainted.
        while True:
            key = input("Press 's' to save, 'd' to discard, or 'h' to show histogram: ").strip().lower()
            if key == 'h':
                plot_histogram(selected_areas)
            elif key in ('s', 'd'):
                break
            else:
                print(f"Unrecognised option {key!r}.")

        if key == 's':
            # Area stored as the sum of mask values (255 per painted pixel);
            # unit kept unchanged so existing state files stay comparable.
            selected_areas.append(int(np.sum(lesion_mask)))  # Convert numpy int to Python int
        current_image_idx += 1
        save_state(current_image_idx, selected_areas, state_save_path)

    print("All images processed.")


if __name__ == '__main__':
    main()

