import os
from torch.utils.data import Dataset, DataLoader
import cv2
from torchvision import transforms
from PIL import Image

class CUB(Dataset):
    """
    CUB-200-2011 dataset.

    Reads the dataset's index files under ``root`` and yields ``(image, label)``
    pairs for either the official training split or the test split.

    Args:
        root: Directory containing ``images.txt``, ``image_class_labels.txt``,
            ``train_test_split.txt`` and the ``images/`` folder.
        train: If truthy, keep the images flagged ``1`` in
            ``train_test_split.txt``; otherwise keep those flagged ``0``.
        transform: Optional callable applied to the RGB ``PIL.Image`` before
            it is returned.
    """

    def __init__(self, root, train=True, transform=None):
        self.root = root
        self.is_train = train
        self.transform = transform

        # image_id -> image path relative to <root>/images
        self.images_path = {
            image_id: path
            for image_id, path in self._read_pairs('images.txt')
        }

        # image_id -> class label; the file's 1-based class ids are shifted to 0-based
        self.class_ids = {
            image_id: int(class_id) - 1
            for image_id, class_id in self._read_pairs('image_class_labels.txt')
        }

        # Keep, in file order, the ids whose split flag matches the requested split.
        want_train = bool(self.is_train)
        self.data_id = [
            image_id
            for image_id, is_train in self._read_pairs('train_test_split.txt')
            if bool(int(is_train)) == want_train
        ]

    def _read_pairs(self, filename):
        """Yield ``(key, value)`` string pairs from a whitespace-separated index file.

        Blank lines (e.g. a trailing empty line) are skipped. The value is
        everything after the first whitespace run, so it may contain spaces.
        """
        with open(os.path.join(self.root, filename), encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                key, value = line.split(maxsplit=1)
                yield key, value

    def __len__(self):
        """Return the number of images in the selected split."""
        return len(self.data_id)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, class_id)``.

        ``image`` is an RGB ``PIL.Image``, or whatever ``self.transform``
        returns for it. ``class_id`` is the 0-based integer label.

        Raises:
            OSError: if OpenCV cannot read or decode the image file.
        """
        image_id = self.data_id[idx]
        class_id = self.class_ids[image_id]
        image_file = os.path.join(self.root, 'images', self.images_path[image_id])

        image = cv2.imread(image_file)
        # cv2.imread reports a missing or undecodable file by returning None
        # rather than raising; fail here with a clear message instead of an
        # opaque error from cvtColor.
        if image is None:
            raise OSError(f"cv2.imread could not read or decode image: {image_file}")

        # OpenCV decodes to BGR channel order; PIL/torchvision expect RGB.
        image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))

        if self.transform is not None:
            image = self.transform(image)

        return image, class_id


def get_cub_dataloaders(batch_size=64, num_workers=8, image_size=448, data_dir='/home/heweihong/CUB/CUB_200_2011'):
    """
    Build the CUB-200-2011 training and validation DataLoaders.

    Every image is resized to a fixed ``(image_size, image_size)`` square
    (aspect ratio is not preserved), so all samples in a batch share one
    shape and default batch collation works.

    Args:
        batch_size: Samples per batch for both loaders.
        num_workers: Worker processes used by each loader.
        image_size: Side length, in pixels, of the square resize.
        data_dir: Root directory of the extracted CUB_200_2011 dataset.

    Returns:
        ``(train_loader, val_loader)``; only the training loader shuffles.
    """
    square = (image_size, image_size)
    # ImageNet per-channel statistics, shared by both pipelines.
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]

    # Training: fixed resize, then flip / colour / rotation augmentation.
    train_steps = [
        transforms.Resize(square),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),
        transforms.RandomRotation(15),
        transforms.ToTensor(),
        transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
    ]
    # Validation: deterministic resize and normalisation only.
    val_steps = [
        transforms.Resize(square),
        transforms.ToTensor(),
        transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
    ]

    loaders = []
    for is_train, steps in ((True, train_steps), (False, val_steps)):
        split = CUB(root=data_dir, train=is_train, transform=transforms.Compose(steps))
        loaders.append(
            DataLoader(split, batch_size=batch_size, shuffle=is_train, num_workers=num_workers)
        )

    train_loader, val_loader = loaders
    return train_loader, val_loader


# Smoke test for the loaders.
if __name__ == '__main__':
    # Example 1: default arguments. Example 2: custom batch size and worker count.
    for loader_kwargs in ({}, {'batch_size': 32, 'num_workers': 4}):
        train_loader, val_loader = get_cub_dataloaders(**loader_kwargs)

        print("训练集批次数量:", len(train_loader))
        print("验证集批次数量:", len(val_loader))

        # Inspect only the first training batch. Every image is resized to a
        # fixed (image_size, image_size) square, so the expected image shape is
        # torch.Size([batch_size, 3, image_size, image_size]), e.g. [64, 3, 448, 448].
        images, labels = next(iter(train_loader))
        print("图像批次大小:", images.shape)
        print("标签批次大小:", labels.shape)  # expected torch.Size([batch_size])
