import argparse
import logging
import os
from typing import Tuple, List

import cv2
import numpy as np
from PIL import Image


def resize_keep_aspect(image: np.ndarray, short_side: int) -> np.ndarray:
    """Resize *image* so its shorter side equals *short_side*, keeping aspect ratio.

    Args:
        image: Array whose first two axes are (height, width).
        short_side: Target length, in pixels, of the shorter side.

    Returns:
        The resized image, or *image* itself (same object) when its shorter
        side already equals *short_side*.

    Raises:
        ValueError: If *short_side* is not positive.
    """
    if short_side <= 0:
        raise ValueError(f"short_side must be positive, got {short_side}")

    h, w = image.shape[:2]
    shorter = min(h, w)
    if shorter == short_side:
        return image

    # Multiply before dividing and round to nearest: the previous
    # int(w * (short_side / h)) truncated after float round-off, so e.g.
    # int(49 * (512 / 49)) gave 511 and a square input came out non-square.
    if h < w:
        new_h = short_side
        new_w = round(w * short_side / h)
    else:
        new_w = short_side
        new_h = round(h * short_side / w)

    # INTER_AREA is the recommended filter for shrinking; when enlarging it
    # degrades to nearest-neighbour-like output, so use bicubic there.
    interpolation = cv2.INTER_AREA if short_side < shorter else cv2.INTER_CUBIC
    return cv2.resize(image, (new_w, new_h), interpolation=interpolation)


def center_crop(image: np.ndarray, crop_size: int) -> np.ndarray:
    """Return the central ``crop_size`` x ``crop_size`` window of *image*.

    The window origin is clamped to 0 on each axis, so when the image is
    smaller than *crop_size* along an axis the result is simply cut off at
    the image border (i.e. it is smaller than requested, not padded).

    Args:
        image: Array whose first two axes are (height, width).
        crop_size: Side length of the square window, in pixels.

    Returns:
        A slice of *image* covering the centred window.
    """
    height, width = image.shape[:2]
    top, left = (
        max((extent - crop_size) // 2, 0) for extent in (height, width)
    )
    return image[top:top + crop_size, left:left + crop_size]


def sliding_window_patches(
    image: np.ndarray,
    crop_size: int,
    stride: int,
    *,
    cover_edges: bool = False,
) -> List[np.ndarray]:
    """Cut *image* into square patches on a regular grid, in row-major order.

    Window origins start at (0, 0) and advance by *stride* along each axis.
    Only full ``crop_size`` x ``crop_size`` patches are returned, so an image
    smaller than *crop_size* on either axis yields an empty list.

    Args:
        image: Array whose first two axes are (height, width).
        crop_size: Side length of each square patch, in pixels.
        stride: Step between consecutive window origins, in pixels.
        cover_edges: When False (default), any right/bottom remainder that
            does not fit the stride grid is left uncovered. When True, one
            extra window flush with the far edge is added on each axis where
            the grid falls short, so every pixel lands in at least one patch.

    Returns:
        List of patches; each is a slice of *image*, not a copy.

    Raises:
        ValueError: If *crop_size* or *stride* is not positive.
    """
    if crop_size <= 0:
        raise ValueError(f"crop_size must be positive, got {crop_size}")
    if stride <= 0:
        raise ValueError(f"stride must be positive, got {stride}")

    h, w = image.shape[:2]
    if h < crop_size or w < crop_size:
        return []

    def window_starts(extent: int) -> List[int]:
        # Origins along one axis; `last` is the final origin that still fits.
        last = extent - crop_size
        starts = list(range(0, last + 1, stride))
        if cover_edges and starts[-1] != last:
            starts.append(last)
        return starts

    xs = window_starts(w)
    return [
        image[y:y + crop_size, x:x + crop_size]
        for y in window_starts(h)
        for x in xs
    ]


def remove_text_regions(image: np.ndarray) -> np.ndarray:
    """Paint over dark, text-like strokes in a BGR image.

    This is a lightweight heuristic: it builds a mask with an inverted
    adaptive threshold plus morphology and fills the masked pixels with
    Telea inpainting. An OCR-based detector (e.g. EasyOCR boxes) would give
    stronger results.

    Args:
        image: Input picture; converted with ``COLOR_BGR2GRAY``, so it is
            expected to be in BGR channel order.

    Returns:
        The inpainted image.
    """
    grayscale = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

    # Inverted mean-adaptive threshold (block size 25, offset 15) to obtain
    # a mask of candidate text pixels.
    text_mask = cv2.adaptiveThreshold(
        grayscale,
        255,
        cv2.ADAPTIVE_THRESH_MEAN_C,
        cv2.THRESH_BINARY_INV,
        25,
        15,
    )

    # Opening drops isolated specks; the dilation then grows what is left so
    # the mask also covers stroke borders.
    structuring = np.ones((3, 3), np.uint8)
    text_mask = cv2.morphologyEx(text_mask, cv2.MORPH_OPEN, structuring, iterations=1)
    text_mask = cv2.dilate(text_mask, structuring, iterations=1)

    # Fill masked pixels from their surroundings (inpaint radius 3).
    return cv2.inpaint(image, text_mask, 3, cv2.INPAINT_TELEA)


def process_image(path: str, out_dir: str, min_short_side: int, crop_size: int, stride: int, remove_text: bool) -> None:
    """Load one image, normalise it, and write crop(s) of it into *out_dir*.

    The image is resized with ``resize_keep_aspect``, optionally passed
    through ``remove_text_regions``, and then:

    * if long side / short side < 1.5, a single centre crop is saved under
      the original file name;
    * otherwise sliding-window patches are saved as
      ``<base>_patch<NNN><ext>``.

    Files that OpenCV cannot decode are skipped with a logged warning.

    Args:
        path: Path of the source image.
        out_dir: Destination directory; created if missing.
        min_short_side: Short-side length passed to ``resize_keep_aspect``.
        crop_size: Side length of the saved square crops.
        stride: Step between sliding-window patches.
        remove_text: Whether to run ``remove_text_regions`` before cropping.
    """
    log = logging.getLogger(__name__)
    os.makedirs(out_dir, exist_ok=True)

    bgr = cv2.imread(path, cv2.IMREAD_COLOR)
    if bgr is None:
        # imread signals an unreadable/unsupported file by returning None.
        log.warning("Skipping unreadable image: %s", path)
        return

    # Stay in OpenCV's BGR order for resizing and text removal (which expects
    # BGR) and swap channels once at the end. Resizing acts on each channel
    # independently, so the pixels match the old RGB->BGR->RGB round trip.
    bgr = resize_keep_aspect(bgr, min_short_side)
    if remove_text:
        bgr = remove_text_regions(bgr)
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

    h, w = rgb.shape[:2]
    aspect = max(w, h) / max(1, min(w, h))
    file_name = os.path.basename(path)

    # If relatively not-so-wide, center crop; otherwise sliding windows
    if aspect < 1.5:
        cropped = center_crop(rgb, crop_size)
        Image.fromarray(cropped).save(os.path.join(out_dir, file_name))
        return

    patches = sliding_window_patches(rgb, crop_size, stride)
    if not patches:
        log.warning("No %dx%d patch fits in %s; nothing written", crop_size, crop_size, path)
        return
    base, ext = os.path.splitext(file_name)
    for i, patch in enumerate(patches):
        out_path = os.path.join(out_dir, f"{base}_patch{i:03d}{ext}")
        Image.fromarray(patch).save(out_path)


def main():
    """Command-line entry point: process every image file in ``--input_dir``.

    Regular files whose lower-cased name ends in .jpg, .jpeg, .png, .bmp or
    .webp are handed to ``process_image``; everything else is ignored.
    ``--remove_text`` is a string flag treated as true for "1", "true" or
    "yes" (case-insensitive).
    """
    cli = argparse.ArgumentParser()
    cli.add_argument('--input_dir', type=str, required=True)
    cli.add_argument('--output_dir', type=str, required=True)
    cli.add_argument('--min_short_side', type=int, default=512)
    cli.add_argument('--crop_size', type=int, default=512)
    cli.add_argument('--stride', type=int, default=256)
    cli.add_argument('--remove_text', type=str, default='false')
    opts = cli.parse_args()

    strip_text = str(opts.remove_text).lower() in ('1', 'true', 'yes')
    image_suffixes = (".jpg", ".jpeg", ".png", ".bmp", ".webp")

    for entry in os.listdir(opts.input_dir):
        source = os.path.join(opts.input_dir, entry)
        # Only regular files with a recognised image extension are processed.
        if os.path.isfile(source) and entry.lower().endswith(image_suffixes):
            process_image(
                source,
                opts.output_dir,
                opts.min_short_side,
                opts.crop_size,
                opts.stride,
                strip_text,
            )


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == '__main__':
    main()


