import os
import pandas as pd
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torch
import json

# 自定义数据集类
class MiniImageNetDataset(Dataset):
    def __init__(self, csv_file, root_dir, class_index_file, transform=None):
        """
        :param csv_file: CSV 文件路径 (包含图片名和对应标签)
        :param root_dir: 图片根目录
        :param class_index_file: json 文件路径 (包含标签与类名映射)
        :param transform: 图像的变换操作
        """
        self.data = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform

        # 读取 class index 文件
        with open(class_index_file, 'r') as f:
            self.class_idx_dict = json.load(f)
        
        # 将类名转换为整数索引
        self.class_to_idx = {v[0]: int(k) for k, v in self.class_idx_dict.items()}

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img_name = os.path.join(self.root_dir, self.data.iloc[idx, 0])  # 获取图片路径
        image = Image.open(img_name).convert('RGB')  # 打开图片
        label_str = self.data.iloc[idx, 1]  # 获取字符串标签

        if self.transform:
            image = self.transform(image)

        # 将标签从字符串转换为整数索引
        label = torch.tensor(self.class_to_idx[label_str])

        return image, label

def get_mini_imagenet_dataloaders(batch_size=64, num_workers=8, image_size=84, data_dir='/home/heweihong/mini-imagenet'):
    """
    加载 Mini-ImageNet 数据集
    :param batch_size: 每个批次的图像数量
    :param num_workers: 加载数据时使用的工作线程数量
    :param image_size: 图像的大小（默认为 84x84）
    :param data_dir: 数据集所在的根目录
    :return: 返回训练和验证集的数据加载器
    """
    
    # 定义图像的标准变换
    transform_standard = transforms.Compose([
        transforms.Resize((image_size, image_size)),  # 调整图像大小
        transforms.ToTensor(),  # 转换为 Tensor
        transforms.Normalize(mean=[0.485, 0.456, 0.406],  # 标准化
                             std=[0.229, 0.224, 0.225])
    ])
    
    # 加载新的训练集和验证集
    train_dataset = MiniImageNetDataset(
        csv_file=os.path.join(data_dir, 'new_train.csv'),
        root_dir=os.path.join(data_dir, 'images'),
        class_index_file=os.path.join(data_dir, 'imagenet_class_index_new.json'),  # 使用新生成的映射文件
        transform=transform_standard
    )

    val_dataset = MiniImageNetDataset(
        csv_file=os.path.join(data_dir, 'new_val.csv'),
        root_dir=os.path.join(data_dir, 'images'),
        class_index_file=os.path.join(data_dir, 'imagenet_class_index_new.json'),  # 使用新生成的映射文件
        transform=transform_standard
    )


    # 创建数据加载器
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    
    return train_loader, val_loader


# 测试加载器并检查标签范围
if __name__ == '__main__':
    # 使用去掉数据增强的标准转换
    train_loader, val_loader = get_mini_imagenet_dataloaders()

    # 获取训练集和验证集的数据集大小
    train_dataset_size = len(train_loader.dataset)
    val_dataset_size = len(val_loader.dataset)

    # 打印训练集和验证集的图像数量
    print(f"训练集总图像数量: {train_dataset_size}")
    print(f"验证集总图像数量: {val_dataset_size}")

    print("训练集批次数量:", len(train_loader))
    print("验证集批次数量:", len(val_loader))

    # 检查标签范围是否在正确的 [0, 99] 范围内
    for images, labels in train_loader:
        print("图像批次大小:", images.shape)  # 预期 torch.Size([64, 3, 84, 84])
        print("标签批次大小:", labels.shape)  # 预期 torch.Size([64])
        
        # 打印标签的最大值和最小值，确保它们在 [0, 99] 范围内
        print("标签最小值:", labels.min().item())
        print("标签最大值:", labels.max().item())

        # 如果标签超出 [0, 99]，提醒
        if labels.min().item() < 0 or labels.max().item() >= 100:
            print("标签值超出范围，请检查映射！")
        break  # 只打印第一个批次
