import time
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models
from torch.utils.data import DataLoader
from data import get_mini_imagenet_dataloaders

# Global training hyper-parameters
BATCH_SIZE = 64
NUM_WORKERS = 8
# NOTE(review): IMAGE_SIZE is not passed to get_mini_imagenet_dataloaders below;
# confirm the data pipeline applies the intended input size.
IMAGE_SIZE = 84
NUM_CLASSES = 100
LR = 0.001
NUM_EPOCHS = 10

# Use the GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Data loaders (built by the project-local helper)
train_loader, val_loader = get_mini_imagenet_dataloaders(batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)

# ImageNet-pretrained ResNet-18.
# `pretrained=True` is deprecated since torchvision 0.13; `ResNet18_Weights.DEFAULT`
# selects the same ImageNet-1K weights. Older torchvision has no weights enum,
# so fall back to the legacy flag there.
if hasattr(models, "ResNet18_Weights"):
    model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
else:
    model = models.resnet18(pretrained=True)

# Replace the classification head so it outputs NUM_CLASSES logits
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, NUM_CLASSES)

# Move the model to the selected device
model = model.to(device)

# Loss function and optimizer (all parameters are fine-tuned)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LR)

# Learning-rate scheduler (optional): multiplies the LR by gamma=0.1 every
# step_size=7 calls to scheduler.step()
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

# Train the model
def _train_one_epoch(model, train_loader, criterion, optimizer, device):
    """Run one optimisation pass over ``train_loader``.

    Puts ``model`` in training mode, performs a forward/backward/step for every
    ``(images, labels)`` batch and prints running statistics every 10 batches.

    Returns:
        ``(avg_loss, accuracy_percent)`` averaged per sample; both are NaN when
        the loader yields no samples.
    """
    model.train()
    total_loss = 0.0
    correct = 0
    total = 0

    for batch_idx, (images, labels) in enumerate(train_loader):
        images, labels = images.to(device), labels.to(device)

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward pass and parameter update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Accumulate statistics. The batch loss is weighted by the batch size
        # (assumes a mean-reduced criterion such as the default
        # CrossEntropyLoss) so a shorter final batch does not skew the average.
        batch_loss = loss.item()
        batch_size = labels.size(0)
        total_loss += batch_loss * batch_size
        _, predicted = outputs.max(1)
        total += batch_size
        correct += predicted.eq(labels).sum().item()

        # Progress report every 10 batches (accuracy is the running epoch value)
        if (batch_idx + 1) % 10 == 0:
            print(f'Batch [{batch_idx + 1}/{len(train_loader)}], Loss: {batch_loss:.4f}, Accuracy: {100 * correct / total:.2f}%')

    if total == 0:
        return float("nan"), float("nan")
    return total_loss / total, 100 * correct / total


def train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, num_epochs=NUM_EPOCHS, device=None):
    """Train ``model`` for ``num_epochs`` epochs, validating after each epoch.

    Args:
        model: The network to train (updated in place).
        train_loader: Iterable with ``len()`` yielding ``(images, labels)`` batches.
        val_loader: Loader passed to ``validate`` after every epoch.
        criterion: Loss function called as ``criterion(outputs, labels)``.
        optimizer: Optimizer over ``model``'s parameters.
        scheduler: LR scheduler, stepped once at the end of every epoch.
        num_epochs: Number of epochs to run.
        device: Device each batch is moved to. Defaults to the device of the
            model's first parameter (CPU for a parameter-less model), so the
            batches always land where the model lives.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    for epoch in range(num_epochs):
        print(f"\nEpoch [{epoch+1}/{num_epochs}] Start")
        start_time = time.perf_counter()  # monotonic clock for elapsed time

        train_loss, train_acc = _train_one_epoch(model, train_loader, criterion, optimizer, device)

        # Per-epoch LR schedule: step after this epoch's optimizer updates
        scheduler.step()

        # Epoch summary (NaN values mean the training loader was empty)
        epoch_time = time.perf_counter() - start_time
        print(f"Epoch [{epoch+1}] - Training Loss: {train_loss:.4f}, Training Accuracy: {train_acc:.2f}%, Time: {epoch_time:.2f}s")

        # Evaluate on the validation set
        validate(model, val_loader, criterion)

# Validate the model
def validate(model, val_loader, criterion, device=None):
    """Evaluate ``model`` on ``val_loader`` without gradients and report metrics.

    Puts ``model`` in eval mode (it is left in eval mode on return), prints the
    batch loss every 10 batches, then prints the overall loss and accuracy.

    Args:
        model: Classifier whose output's argmax over dim 1 is the predicted class.
        val_loader: Iterable with ``len()`` yielding ``(images, labels)`` batches.
        criterion: Loss function called as ``criterion(outputs, labels)``;
            assumed to return the batch *mean* (e.g. default CrossEntropyLoss).
        device: Device each batch is moved to. Defaults to the device of the
            model's first parameter (CPU for a parameter-less model).

    Returns:
        ``(avg_loss, accuracy_percent)`` averaged per sample; both are NaN when
        the loader yields no samples.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()
    total_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for batch_idx, (images, labels) in enumerate(val_loader):
            images, labels = images.to(device), labels.to(device)

            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Accumulate statistics; weight the batch loss by batch size so a
            # shorter final batch does not skew the average.
            batch_loss = loss.item()
            batch_size = labels.size(0)
            total_loss += batch_loss * batch_size
            _, predicted = outputs.max(1)
            total += batch_size
            correct += predicted.eq(labels).sum().item()

            # Progress report every 10 batches
            if (batch_idx + 1) % 10 == 0:
                print(f'Validation Batch [{batch_idx + 1}/{len(val_loader)}], Loss: {batch_loss:.4f}')

    if total == 0:
        # Avoid ZeroDivisionError on an empty loader
        print("Validation skipped: loader yielded no samples")
        return float("nan"), float("nan")

    avg_loss = total_loss / total
    accuracy = 100 * correct / total
    print(f"Validation Loss: {avg_loss:.4f}, Validation Accuracy: {accuracy:.2f}%")
    return avg_loss, accuracy

# Start training.
# The __main__ guard matters because NUM_WORKERS > 0: on platforms whose
# multiprocessing start method is "spawn" (e.g. Windows, macOS), DataLoader
# worker processes re-import this module, and an unguarded call would try to
# start training again inside every worker.
if __name__ == "__main__":
    train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, num_epochs=NUM_EPOCHS)
