from __future__ import print_function

import os
import time
import argparse
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from loss import *
from models import model_dict

from dataset.mini_imagenet import get_mini_imagenet_dataloaders
from helper.util import accuracy, AverageMeter
from helper.loops import validate

def get_teacher_name(model_path):
    """Parse the teacher architecture name from a checkpoint path.

    The name is read from the directory that directly contains the
    checkpoint file. That directory name is split on ``'_'``; the first
    token is the architecture name, except for ``wrn`` models, whose name
    is the first three tokens joined by ``'_'`` (e.g. ``'wrn_40_2'``).

    Example:
        ``.../vgg13_mini_imagenet_lr_0.05/vgg13_best.pth`` -> ``'vgg13'``

    Args:
        model_path: Path to the checkpoint file.

    Returns:
        The architecture name as a string.

    Raises:
        ValueError: If the path has no parent directory to parse, or if a
            ``wrn`` directory name has fewer than three ``'_'`` tokens.
    """
    # os.path handles repeated separators and the platform's own separator,
    # which a plain split('/') does not.
    run_dir = os.path.basename(os.path.dirname(model_path))
    if not run_dir:
        raise ValueError(
            f"Cannot parse teacher name: no parent directory in path {model_path!r}"
        )

    segments = run_dir.split('_')
    if segments[0] != 'wrn':
        return segments[0]

    # wrn names carry depth and widen factor as the next two tokens.
    if len(segments) < 3:
        raise ValueError(
            f"Cannot parse wrn teacher name from directory {run_dir!r}"
        )
    return '_'.join(segments[:3])
def load_teacher(model_path, n_cls):
    """Build the teacher network and restore its weights from a checkpoint.

    The architecture name is derived from ``model_path`` with
    ``get_teacher_name`` and used to look up a constructor in
    ``model_dict``. The checkpoint is loaded onto the CPU, and its
    ``'model'`` entry is used as the state dict.

    Args:
        model_path: Path to the checkpoint file.
        n_cls: Number of output classes passed to the model constructor.

    Returns:
        The constructed model with the checkpoint weights loaded.
    """
    print('==> Loading teacher model')
    arch = get_teacher_name(model_path)
    teacher = model_dict[arch](num_classes=n_cls)

    # NOTE(review): torch.load may unpickle arbitrary objects depending on the
    # PyTorch version's `weights_only` default; only load trusted checkpoints.
    checkpoint = torch.load(model_path, map_location=torch.device('cpu'))

    # Drop a leading 'module.' from every key so the names match the bare
    # model (presumably the prefix comes from a DataParallel wrapper - confirm).
    prefix = 'module.'
    weights = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in checkpoint['model'].items()
    }

    teacher.load_state_dict(weights)
    print('==> Done loading teacher model')
    return teacher

# Previous version of load_teacher, kept for reference: it loads the
# checkpoint's 'model' state dict as-is, without stripping a 'module.' prefix.
# def load_teacher(model_path, n_cls):
#     print('==> Loading teacher model')
#     model_t = get_teacher_name(model_path)
#     model = model_dict[model_t](num_classes=n_cls)
#     model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu'))['model'])
#     print('==> Done loading teacher model')
#     return model

# Example checkpoint paths (machine-specific), usable as --model_path:
#   /home/heweihong/DSP-main/DSP-main/dsp/save/models/vgg13_mini_imagenet_lr_0.05_decay_0.0005_trial_0/vgg13_best.pth
#   /home/heweihong/DSP-main/DSP-main/dsp/save/models/resnet110_mini_imagenet_lr_0.05_decay_0.0005_trial_0/resnet110_best.pth
#   /home/heweihong/DSP-main/DSP-main/dsp/save/models/wrn_40_2_mini_imagenet_lr_0.05_decay_0.0005_trial_0/wrn_40_2_best.pth


def parse_option():
    """Parse command-line options for the validation script.

    Reads arguments from ``sys.argv`` via ``argparse``.

    Returns:
        An ``argparse.Namespace`` with the attributes ``batch_size``,
        ``num_workers``, ``model``, ``dataset``, ``model_path``, ``no_cuda``
        and ``log_interval``.
    """
    parser = argparse.ArgumentParser(description='Arguments for validation')

    # Default values for every option are set here.
    parser.add_argument('--batch_size', type=int, default=64, help='Batch size for validation')
    parser.add_argument('--num_workers', type=int, default=8, help='Number of workers for data loading')
    # NOTE(review): validate_model in this file derives the architecture from
    # --model_path and does not read --model directly; confirm whether
    # `validate` uses it before removing.
    parser.add_argument('--model', type=str, default='vgg13',
                        choices=['resnet8', 'resnet14', 'resnet20', 'resnet32', 'resnet44', 'resnet56', 'resnet110',
                                 'resnet8x4', 'resnet32x4', 'wrn_16_1', 'wrn_16_2', 'wrn_40_1', 'wrn_40_2',
                                 'vgg8', 'vgg11', 'vgg13', 'vgg16', 'vgg19',
                                 'MobileNetV2', 'ShuffleV1', 'ShuffleV2'],
                        help='Model architecture')
    parser.add_argument('--dataset', type=str, default='mini_imagenet', choices=['mini_imagenet', 'cifar100', 'cifar10'], help='Dataset to use')
    # NOTE(review): the default is an absolute path on one developer machine;
    # it is kept for backward compatibility, but callers should pass
    # --model_path explicitly.
    parser.add_argument('--model_path', type=str, default='/home/heweihong/DSP-main/DSP-main/dsp/save/models/vgg13_mini_imagenet_lr_0.05_decay_0.0005_trial_0/vgg13_best.pth',
                        help='Path to the trained model')
    # store_true already defaults to False; the help texts below describe
    # validation, since this script does no training.
    parser.add_argument('--no_cuda', action='store_true', help='Disable CUDA even when it is available')
    parser.add_argument('--log_interval', type=int, default=10, help='How many batches to wait before logging validation status')

    return parser.parse_args()

def validate_model(opt):
    """Evaluate the checkpoint at ``opt.model_path`` on a validation set.

    Builds the validation loader for ``opt.dataset``, loads the teacher
    model, moves it to the selected device, runs ``validate`` and prints
    the top-1 accuracy, top-5 accuracy and loss it returns.

    Args:
        opt: Parsed options. This function reads ``dataset``, ``batch_size``,
            ``num_workers``, ``model_path`` and ``no_cuda``. The whole object
            is also passed to ``validate``.
            NOTE(review): confirm that ``validate`` only reads attributes
            that ``parse_option`` defines.

    Raises:
        ValueError: If ``opt.dataset`` is not one of the supported names.
    """
    # Pick the loader factory and the class count for the requested dataset.
    # The CIFAR loaders are imported only when they are selected.
    dataset_name = opt.dataset
    if dataset_name == 'mini_imagenet':
        make_loaders, n_cls = get_mini_imagenet_dataloaders, 100
    elif dataset_name == 'cifar100':
        from dataset.cifar100 import get_cifar100_dataloaders as make_loaders
        n_cls = 100
    elif dataset_name == 'cifar10':
        from dataset.cifar10 import get_cifar10_dataloaders as make_loaders
        n_cls = 10
    else:
        raise ValueError(f"Unsupported dataset: {opt.dataset}")

    # Only the validation loader (second element) is needed.
    _, val_loader = make_loaders(batch_size=opt.batch_size, num_workers=opt.num_workers)

    # Load the teacher model.
    model = load_teacher(opt.model_path, n_cls=n_cls)

    # Choose the device: CUDA unless it is disabled or unavailable.
    cpu_only = opt.no_cuda or not torch.cuda.is_available()
    device = torch.device('cpu' if cpu_only else 'cuda')
    model = model.to(device)

    # Define the loss function.
    criterion = nn.CrossEntropyLoss()
    if not cpu_only:
        criterion = criterion.cuda()

    # Run validation with the model in eval mode.
    model.eval()
    top1_avg, top5_avg, loss_avg = validate(val_loader, model, criterion, opt)

    print(f'Validation Results - Acc@1: {top1_avg:.2f}%, Acc@5: {top5_avg:.2f}%, Loss: {loss_avg:.4f}')

def main():
    """Script entry point: parse the CLI options and run validation."""
    validate_model(parse_option())

# Run validation only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
