import os
from typing import Callable, Optional, List

from PIL import Image
from torch.utils.data import Dataset


class ImageFolderDataset(Dataset):
    """
    Minimal image folder dataset.

    Collects every regular file directly inside ``root_dir`` (no recursion)
    whose name ends with one of ``extensions``, compared case-insensitively,
    and orders them by path. Each item is opened with PIL, converted to RGB,
    and passed through ``transform`` when one is given.

    Args:
        root_dir: Directory to scan. Subdirectories are skipped, not descended into.
        transform: Optional callable applied to each loaded RGB ``PIL.Image``;
            its return value is what ``__getitem__`` yields.
        extensions: File-name suffixes to accept, e.g. ``[".jpg", ".png"]``.
            Matching ignores case. ``None`` or an empty list selects the
            default set ``.jpg``, ``.jpeg``, ``.png``, ``.bmp``, ``.webp``.
    """

    def __init__(
        self,
        root_dir: str,
        transform: Optional[Callable] = None,
        extensions: Optional[List[str]] = None,
    ) -> None:
        self.root_dir = root_dir
        self.transform = transform
        # `or` (rather than `is None`) means an empty list also falls back to the defaults.
        self.extensions = extensions or [".jpg", ".jpeg", ".png", ".bmp", ".webp"]
        # The directory is scanned once, at construction time.
        self.file_paths = self._gather_files()

    def _gather_files(self) -> List[str]:
        """Return the sorted paths of matching regular files directly under ``root_dir``."""
        # File names are lower-cased before matching, so the suffixes must be too.
        # Otherwise a caller-supplied ".JPG" would never match anything.
        suffixes = tuple(ext.lower() for ext in self.extensions)
        candidates = (
            os.path.join(self.root_dir, fname)
            for fname in os.listdir(self.root_dir)
            if fname.lower().endswith(suffixes)
        )
        # Drop directories and other non-files that merely have an image-like name.
        return sorted(path for path in candidates if os.path.isfile(path))

    def __len__(self) -> int:
        """Number of image files found at construction time."""
        return len(self.file_paths)

    def __getitem__(self, idx: int):
        """Load the ``idx``-th image as RGB and apply ``transform`` if set."""
        path = self.file_paths[idx]
        # The context manager closes the underlying file handle. convert() returns
        # a new, fully loaded image, so it remains valid after the file is closed.
        with Image.open(path) as img:
            image = img.convert("RGB")
        if self.transform is not None:
            image = self.transform(image)
        return image


