import os
import torch
from torchvision import datasets, transforms

# Fix PyTorch's global RNG seed at import time so torch-driven randomness
# (e.g. the random crop/flip augmentations in get_data) repeats across runs.
# NOTE(review): this only seeds torch — Python's `random` and NumPy are
# left unseeded.
torch.manual_seed(seed=0)
def get_data(data_dir='./Data', *, verbose=True):
    """Load the train, validation and test image datasets.

    Each split is read with ``torchvision.datasets.ImageFolder`` from the
    ``train``, ``val`` and ``test`` subdirectories of ``data_dir``.

    Args:
        data_dir: Root directory (str or path-like) containing the
            ``train``, ``val`` and ``test`` split folders. Defaults to
            ``'./Data'``, the path this function previously hard-coded.
        verbose: When true (the default, matching the original behavior),
            print the number of samples in each split.

    Returns:
        Tuple ``(train_set, val_set, test_set)`` of ``ImageFolder`` datasets.
        The training split uses random crop + horizontal flip augmentation.
        Validation and test share one deterministic resize + center-crop
        pipeline.

    Raises:
        Whatever ``ImageFolder`` raises for a missing or empty split folder
        (e.g. ``FileNotFoundError``).
    """
    # Per-channel RGB mean/std. These are the standard ImageNet statistics
    # used by torchvision's pretrained models. They are defined once so the
    # train and eval pipelines cannot drift apart.
    normalize = transforms.Normalize([0.485, 0.456, 0.406],
                                     [0.229, 0.224, 0.225])

    train_transforms = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    # Deterministic eval pipeline. Resize the shorter side to 256, then take
    # the central 224x224 crop so inputs match the training crop size.
    eval_transforms = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])

    train_set = datasets.ImageFolder(
        os.path.join(data_dir, 'train'), train_transforms)
    val_set = datasets.ImageFolder(
        os.path.join(data_dir, 'val'), eval_transforms)
    test_set = datasets.ImageFolder(
        os.path.join(data_dir, 'test'), eval_transforms)

    if verbose:
        print(len(train_set), len(val_set), len(test_set))
    return train_set, val_set, test_set
