from __future__ import print_function
import argparse
import copy
import hashlib
import logging
import os
import random
import time
import warnings

import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.nn.functional as F
import torch.optim as optim
import torch.nn as nn 

import dsp.sparselearning
from dsp.models import cifar_resnet, initializers, vgg
from dsp.sparselearning.core import CosineDecay, Masking
from dsp.sparselearning.utils import (get_cifar10_dataloaders,
                                      get_cifar100_dataloaders,
                                      get_mnist_dataloaders,
                                      plot_class_feature_histograms)
from dsp.sparselearning.core import add_sparse_args
from dsp.distiller_zoo import (FSP, KDSVD, PKT, ABLoss, Attention, Correlation,
                               DistillKL, FactorTransfer, HintLoss, NSTLoss,
                               RKDLoss, Similarity, VIDLoss)
from dsp.models.util import (Connector, ConvReg, LinearEmbed, Paraphraser,
                             Translator)
from dsp.models import model_dict
from dsp.helper.pretrain import init
from dsp.helper.loops import validate
from dsp.helper.loops import train_sparse_distill as train
from loss import *

from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import os

from dsp.dataset.imagenet import get_imagenet_dataloaders
from dsp.dataset.mini_imagenet import get_mini_imagenet_dataloaders

# Suppress UserWarning-category warnings for the whole process.
warnings.filterwarnings('ignore', category=UserWarning)
# Import-time cuDNN defaults.
# NOTE(review): benchmark=True (autotuned kernels) and deterministic=True pull
# in opposite directions; confirm which one is actually wanted here.
cudnn.benchmark = True
cudnn.deterministic = True

# if not os.path.exists('./models'): os.mkdir('./models')
# Create the log directory up front. exist_ok avoids the check-then-create
# race when several runs are launched from the same working directory.
os.makedirs('./logs', exist_ok=True)
# Module-wide logger handle; stays None until setup_logger() configures it.
logger = None


def save_checkpoint(state, filename='checkpoint.pth.tar'):
    """Announce the save on stdout, then serialize ``state`` with ``torch.save``.

    Args:
        state: Object to persist (any object ``torch.save`` accepts).
        filename: Destination path of the checkpoint file.
    """
    print('SAVING')
    torch.save(obj=state, f=filename)

def get_teacher_name(model_path):
    """Derive the teacher architecture name from a checkpoint path.

    The name is taken from the directory that directly contains the
    checkpoint file: that directory name is split on ``'_'`` and the first
    token is returned. When the first token is ``'wrn'`` the first three
    tokens are kept and joined with ``'_'`` (e.g. ``wrn_40_2``).

    Args:
        model_path: ``'/'``-separated path with at least one parent directory.

    Returns:
        The architecture name as a string.

    Raises:
        IndexError: If the path has no parent directory component, or a
            ``wrn`` directory name has fewer than three ``'_'`` tokens.
    """
    parent_dir = model_path.split('/')[-2]
    tokens = parent_dir.split('_')
    if tokens[0] == 'wrn':
        return '_'.join((tokens[0], tokens[1], tokens[2]))
    return tokens[0]


def load_teacher(model_path, n_cls):
    """Build the teacher network and restore its weights from a checkpoint.

    The architecture is looked up in ``model_dict`` under the name that
    ``get_teacher_name`` extracts from ``model_path``. The checkpoint is
    read onto the CPU and its ``'model'`` entry is loaded as the state dict.

    Args:
        model_path: Path to the teacher checkpoint file.
        n_cls: Number of classes, forwarded as ``num_classes``.

    Returns:
        The teacher model with the checkpoint weights loaded.
    """
    print('==> loading teacher model')
    arch_name = get_teacher_name(model_path)
    teacher = model_dict[arch_name](num_classes=n_cls)

    checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
    teacher.load_state_dict(checkpoint['model'])
    print('==> done')
    return teacher


def setup_logger(args):
    """Attach a run-specific file handler to the module-level logger.

    On the first call the root logger is adopted as the module logger. On
    later calls the handlers currently attached are detached and closed, so
    the previous log file is released before a new one is opened.

    The log file is ``./logs/<model>_<density>_<hash>.log``, where ``<hash>``
    is the first 8 hex digits of the MD5 of ``str(args)`` after ``iters``,
    ``verbose``, ``log_interval`` and ``seed`` are normalized on a copy.
    Runs that differ only in those flags therefore share a log file.

    Args:
        args: Parsed arguments; ``args.model`` and ``args.density`` are read
            for the file name. ``args`` itself is not modified.
    """
    global logger
    if logger is None:
        logger = logging.getLogger()
    else:
        for handler in logger.handlers[:]:  # iterate over a copy of the list
            logger.removeHandler(handler)
            # Removing alone leaves the handler's file open; close it too.
            handler.close()

    # Work on a copy so the hash ignores flags that do not change results.
    args_copy = copy.deepcopy(args)
    args_copy.iters = 1
    args_copy.verbose = False
    args_copy.log_interval = 1
    args_copy.seed = 0

    log_path = './logs/{0}_{1}_{2}.log'.format(
        args.model, args.density,
        hashlib.md5(str(args_copy).encode('utf-8')).hexdigest()[:8])
    # Do not rely on the import-time directory creation (cwd may differ).
    os.makedirs(os.path.dirname(log_path), exist_ok=True)

    logger.setLevel(logging.INFO)
    formatter = logging.Formatter(
        fmt='%(asctime)s: %(message)s', datefmt='%H:%M:%S')

    fh = logging.FileHandler(log_path)
    fh.setFormatter(formatter)
    logger.addHandler(fh)


def print_and_log(msg):
    """Echo ``msg`` to stdout and record it at INFO level on the module logger.

    ``setup_logger`` must have been called first; otherwise the module-level
    ``logger`` is still ``None`` and the ``.info`` call fails.
    """
    print(msg)
    # Read-only access to the module-level logger needs no ``global`` statement.
    logger.info(msg)


# def evaluate(args, model, device, test_loader, is_test_set=False):
#     model.eval()
#     test_loss = 0
#     correct = 0
#     n = 0
#     with torch.no_grad():
#         for data, target in test_loader:
#             data, target = data.to(device), target.to(device)
#             if args.fp16:
#                 data = data.half()
#             model.t = target
#             output = model(data)
#             test_loss += F.nll_loss(
#                 output, target, reduction='sum').item()  # sum up batch loss
#             pred = output.argmax(
#                 dim=1,
#                 keepdim=True)  # get the index of the max log-probability
#             correct += pred.eq(target.view_as(pred)).sum().item()
#             n += target.shape[0]

#     test_loss /= float(n)

#     print_and_log(
#         '\n{}: Average loss: {:.4f}, Accuracy: {}/{} ({:.3f}%)\n'.format(
#             'Test evaluation' if is_test_set else 'Evaluation', test_loss,
#             correct, n, 100. * correct / float(n)))
#     return correct / float(n)

def evaluate(args, model, device, test_loader, is_test_set=False):
    """Run ``model`` over ``test_loader`` and report loss, Top-1 and Top-5.

    Args:
        args: Parsed arguments; only ``args.fp16`` is read (inputs are cast
            to half precision when it is set).
        model: Network to evaluate; it is switched to eval mode.
        device: Device the batches are moved to.
        test_loader: Iterable yielding ``(data, target)`` batches.
        is_test_set: Selects the label used in the logged summary line.

    Returns:
        Tuple ``(top1_accuracy, top5_accuracy)``, both in percent.

    Raises:
        ZeroDivisionError: If ``test_loader`` yields no samples.
    """
    model.eval()
    test_loss = 0
    correct_top1 = 0  # number of correct Top-1 predictions
    correct_top5 = 0  # number of samples whose label is in the Top-k set
    n = 0  # total number of evaluated samples

    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            if args.fp16:
                data = data.half()
            output = model(data)

            # Accumulate the summed batch loss.
            # NOTE(review): nll_loss expects log-probabilities; confirm the
            # model output is log-softmax rather than raw logits.
            test_loss += F.nll_loss(output, target, reduction='sum').item()

            # Top-1: index of the highest score per sample.
            pred_top1 = output.argmax(dim=1, keepdim=True)
            correct_top1 += pred_top1.eq(target.view_as(pred_top1)).sum().item()

            # Top-5: clamp k so models with fewer than 5 classes do not make
            # topk() raise.
            k = min(5, output.size(1))
            _, pred_top5 = output.topk(k, dim=1, largest=True, sorted=True)
            correct_top5 += pred_top5.eq(target.view(-1, 1)).sum().item()

            n += target.shape[0]

    # Convert the running sums into a mean loss and percentage accuracies.
    test_loss /= float(n)
    top1_accuracy = 100. * correct_top1 / float(n)
    top5_accuracy = 100. * correct_top5 / float(n)

    # Print and log the summary line.
    print_and_log(
        '\n{}: 平均损失: {:.4f}, Top-1 准确率: {}/{} ({:.3f}%), Top-5 准确率: ({:.3f}%)\n'.format(
            '测试评估' if is_test_set else '评估', test_loss,
            correct_top1, n, top1_accuracy, top5_accuracy))

    return top1_accuracy, top5_accuracy



def parse_args():
    """Define the command-line interface of the training script and parse argv.

    Options cover the optimisation schedule, dataset and model selection,
    run bookkeeping and the knowledge-distillation settings.
    ``add_sparse_args`` is then called on the parser; presumably it registers
    the sparse-training options that ``main`` reads (``density``, ``death``,
    ``growth``, ...) -- verify against ``dsp.sparselearning.core``.

    Returns:
        argparse.Namespace with the parsed options.
    """
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')

    # ---- optimisation / schedule ------------------------------------------
    parser.add_argument(
        '--batch-size',
        type=int,
        default=128,
        metavar='N',
        help='input batch size for training (default: 128)')
    parser.add_argument(
        '--test-batch-size',
        type=int,
        default=100,
        metavar='N',
        help='input batch size for testing (default: 100)')
    parser.add_argument(
        '--multiplier',
        type=int,
        default=1,
        metavar='N',
        help='extend training time by multiplier times')
    parser.add_argument(
        '--epochs',
        type=int,
        default=240,
        metavar='N',
        help='number of epochs to train (default: 240)')
    parser.add_argument(
        '--lr',
        type=float,
        default=0.1,
        metavar='LR',
        help='learning rate (default: 0.1)')
    parser.add_argument(
        '--momentum',
        type=float,
        default=0.9,
        metavar='M',
        help='SGD momentum (default: 0.9)')
    parser.add_argument(
        '--seed',
        type=int,
        default=17,
        metavar='S',
        help='random seed (default: 17)')
    parser.add_argument(
        '--log-interval',
        type=int,
        default=100,
        metavar='N',
        help='how many batches to wait before logging training status')
    parser.add_argument(
        '--optimizer',
        type=str,
        default='sgd',
        help='The optimizer to use. Default: sgd. Options: sgd, adam.')

    # Default save name: the current timestamp with the decimal point removed.
    randomhash = ''.join(str(time.time()).split('.'))
    parser.add_argument(
        '--save',
        type=str,
        default=randomhash + '.pt',
        help='path to save the final model')

    # ---- data / model ------------------------------------------------------
    parser.add_argument(
        '--data',
        type=str,
        default='cifar100',
        choices=['cifar100', 'cifar10', 'imagenet', 'mini_imagenet'],
        help='Dataset to use for training and validation. Options: cifar100, cifar10, imagenet, mini_imagenet')
    parser.add_argument('--decay_frequency', type=int, default=25000)
    parser.add_argument('--l1', type=float, default=0.0)
    parser.add_argument(
        '--fp16', action='store_true', help='Run in fp16 mode.')
    parser.add_argument('--valid_split', type=float, default=0.1)
    parser.add_argument('--resume', type=str)
    parser.add_argument('--start-epoch', type=int, default=1)
    # argparse does not validate defaults against ``choices``. The default
    # name is listed explicitly so that passing it on the command line is
    # accepted as well.
    model_choices = [
        'cifar_resnet_20',
        'resnet8', 'resnet14', 'resnet20', 'resnet32', 'resnet44',
        'resnet56', 'resnet110', 'resnet8x4', 'resnet32x4',
        'wrn_16_1', 'wrn_16_2', 'wrn_40_1', 'wrn_40_2',
        'vgg8', 'vgg11', 'vgg13', 'vgg16', 'vgg19',
        'ResNet50', 'ResNet18', 'ResNet34',
        'MobileNetV2', 'ShuffleV1', 'ShuffleV2',
        'wrn_cifar10', 'resnet_feat_at_110', 'resnet_feat_at_20',
        'resnet_feat_at_14',
    ]
    parser.add_argument(
        '--model', type=str, default='cifar_resnet_20', choices=model_choices)
    parser.add_argument('--l2', type=float, default=5.0e-4)
    parser.add_argument(
        '--iters',
        type=int,
        default=1,
        help=
        'How many times the model should be run after each other. Default=1')
    parser.add_argument(
        '--save-features',
        action='store_true',
        help=
        'Resumes a saved model and saves its feature data to disk for plotting.'
    )
    parser.add_argument(
        '--bench',
        action='store_true',
        help='Enables the benchmarking of layers and estimates sparse speedups'
    )
    parser.add_argument(
        '--max-threads',
        type=int,
        default=10,
        help='How many threads to use for data loading.')
    parser.add_argument(
        '--scaled',
        action='store_true',
        help='scale the initialization by 1/density')

    # ---- knowledge distillation -------------------------------------------
    parser.add_argument(
        '--kd_T',
        type=float,
        default=4,
        help='temperature for KD distillation')
    parser.add_argument(
        '--distill',
        type=str,
        default='kd',
        choices=[
            'kd', 'hint', 'attention', 'similarity', 'correlation', 'vid',
            'crd', 'kdsvd', 'fsp', 'rkd', 'pkt', 'abound', 'factor', 'nst'
        ])
    # The teacher architecture is inferred from the checkpoint's parent
    # directory name (see get_teacher_name).
    parser.add_argument(
        '--path_t',
        type=str,
        default='./dsp/save/models/resnet110_mini_imagenet_lr_0.05_decay_0.0005_trial_0/resnet110_best.pth',
        help='teacher model snapshot')
    parser.add_argument(
        '-r',
        '--gamma',
        type=float,
        default=1,
        help='weight for classification')
    parser.add_argument(
        '-a',
        '--alpha',
        type=float,
        default=0.1,
        help='weight balance for KD')
    parser.add_argument(
        '-b',
        '--beta',
        type=float,
        default=0.1,
        help='weight balance for other losses')
    parser.add_argument(
        '--hint_layer', default=3, type=int, choices=[0, 1, 2, 3, 4])

    # ITOP settings
    add_sparse_args(parser)

    return parser.parse_args()

# Image side length requested from the mini-ImageNet loaders. Also used for
# the shape-probing dummy batch so both always agree.
_MINI_IMAGENET_IMAGE_SIZE = 84


def _unwrap_model(model):
    """Return the wrapped module when ``model`` is an ``nn.DataParallel``."""
    return model.module if isinstance(model, nn.DataParallel) else model


def _build_dataloaders(args):
    """Create the data loaders for ``args.data``.

    Returns:
        Tuple ``(train_loader, valid_loader, test_loader, n_classes)``.

    Raises:
        ValueError: If ``args.data`` names a dataset that is not handled.
    """
    if args.data == 'mnist':
        train_loader, valid_loader, test_loader = get_mnist_dataloaders(
            args, validation_split=args.valid_split)
        return train_loader, valid_loader, test_loader, 10
    if args.data == 'cifar10':
        train_loader, valid_loader, test_loader = get_cifar10_dataloaders(
            args, args.valid_split, max_threads=args.max_threads)
        return train_loader, valid_loader, test_loader, 10
    if args.data == 'cifar100':
        train_loader, valid_loader, test_loader = get_cifar100_dataloaders(
            args, args.valid_split, max_threads=args.max_threads)
        return train_loader, valid_loader, test_loader, 100
    if args.data == 'mini_imagenet':
        # NOTE(review): the dataset directory is hard-coded to one machine.
        train_loader, valid_loader = get_mini_imagenet_dataloaders(
            batch_size=args.batch_size,
            num_workers=args.max_threads,
            image_size=_MINI_IMAGENET_IMAGE_SIZE,
            data_dir='/home/heweihong/mini-imagenet')
        # No separate test split: the validation loader doubles as test set.
        return train_loader, valid_loader, valid_loader, 100
    if args.data == 'imagenet':
        train_loader, test_loader = get_imagenet_dataloaders(
            dataset='imagenet',
            batch_size=args.batch_size,
            num_workers=args.max_threads)
        # No separate validation split: the test loader doubles as one.
        return train_loader, test_loader, test_loader, 1000
    raise ValueError('Unsupported dataset: {0}'.format(args.data))


def main():
    """Train a (optionally sparse) student network with knowledge distillation.

    For each of ``args.iters`` runs this loads the data and the teacher,
    builds the student, the distillation criteria, the optimizer and (when
    ``args.sparse`` is set) the sparsity mask, trains for
    ``args.epochs * args.multiplier`` epochs, checkpoints the best Top-1
    validation model and finally evaluates that checkpoint on the test
    loader.
    """
    top5_accuracies = []

    args = parse_args()
    setup_logger(args)
    print_and_log(args)

    if args.fp16:
        try:
            from apex.fp16_utils import FP16_Optimizer
        except Exception:
            # Best effort: a missing or broken apex install just disables
            # fp16. Ctrl-C / SystemExit are no longer swallowed.
            print('WARNING: apex not installed, ignoring --fp16 option')
            args.fp16 = False

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    print_and_log('\n\n')
    print_and_log('=' * 80)

    # Fix the random seeds for reproducibility.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    for i in range(args.iters):
        print_and_log('\nIteration start: {0}/{1}\n'.format(i + 1, args.iters))

        train_loader, valid_loader, test_loader, output = \
            _build_dataloaders(args)

        init_type = 'scaled_kaiming_normal' if args.scaled else 'kaiming_normal'

        # Teacher model.
        model_t = load_teacher(args.path_t, n_cls=output).to(device)
        if torch.cuda.device_count() > 1:
            print(f"使用 {torch.cuda.device_count()} 张 GPU 进行训练教师模型")
            model_t = nn.DataParallel(model_t)

        # Student model.
        if 'vgg' in args.model:
            model = vgg.VGG(
                depth=int(args.model[3:]), dataset=args.data,
                batchnorm=True).to(device)
        else:
            model = cifar_resnet.Model.get_model_from_name(
                args.model,
                initializers.initializations(init_type, args.density),
                outputs=output).to(device)

        if torch.cuda.device_count() > 1:
            print(f"使用 {torch.cuda.device_count()} 张 GPU 进行训练学生模型")
            model = nn.DataParallel(model)

        print_and_log(model)
        print_and_log('=' * 60)
        print_and_log(args.model)
        print_and_log('=' * 60)

        print_and_log('=' * 60)
        print_and_log('Prune mode: {0}'.format(args.death))
        print_and_log('Growth mode: {0}'.format(args.growth))
        print_and_log('Redistribution mode: {0}'.format(args.redistribution))
        print_and_log('=' * 60)

        best_acc = 0.0

        # One dummy forward pass through both networks. Only the resulting
        # feature shapes are used, to size the distillation modules, so the
        # dummy batch uses the same spatial size as the real loaders.
        if args.data == 'imagenet':
            side = 224
        elif args.data == 'mini_imagenet':
            side = _MINI_IMAGENET_IMAGE_SIZE
        else:
            side = 32  # CIFAR-sized input
        data = torch.randn(1, 3, side, side).to(device)
        model_t.eval()
        model.eval()
        with torch.no_grad():  # shapes only, no graph needed
            feat_t, _ = model_t(data, is_feat=True)
            feat, _ = model(data, is_feat=True)

        # get_feat_modules() lives on the bare student, not on DataParallel.
        student_core = _unwrap_model(model)

        module_list = nn.ModuleList([])
        module_list.append(model)
        trainable_list = nn.ModuleList([])
        trainable_list.append(model)

        criterion_div = DistillKL(args.kd_T)
        criterion_cls = nn.CrossEntropyLoss()
        if args.distill == 'kd':
            criterion_kd = DistillKL(args.kd_T)
        elif args.distill == 'hint':
            criterion_kd = HintLoss()
            regress_s = ConvReg(feat[args.hint_layer].shape,
                                feat_t[args.hint_layer].shape)
            module_list.append(regress_s)
            trainable_list.append(regress_s)
        elif args.distill == 'attention':
            criterion_kd = Attention()
        elif args.distill == 'nst':
            criterion_kd = NSTLoss()
        elif args.distill == 'similarity':
            criterion_kd = Similarity()
        elif args.distill == 'rkd':
            criterion_kd = RKDLoss()
        elif args.distill == 'pkt':
            criterion_kd = PKT()
        elif args.distill == 'kdsvd':
            criterion_kd = KDSVD()
        elif args.distill == 'correlation':
            criterion_kd = Correlation()
            # parse_args() defines no --feat_dim; fall back to 128 unless
            # the sparse-args helper provides one.
            # NOTE(review): 128 is an assumed default -- confirm.
            feat_dim = getattr(args, 'feat_dim', 128)
            embed_s = LinearEmbed(feat[-1].shape[1], feat_dim)
            embed_t = LinearEmbed(feat_t[-1].shape[1], feat_dim)
            module_list.append(embed_s)
            module_list.append(embed_t)
            trainable_list.append(embed_s)
            trainable_list.append(embed_t)
        elif args.distill == 'vid':
            s_n = [f.shape[1] for f in feat[1:-1]]
            t_n = [f.shape[1] for f in feat_t[1:-1]]
            criterion_kd = nn.ModuleList(
                [VIDLoss(s, t, t) for s, t in zip(s_n, t_n)])
            # VIDLoss holds parameters that need to be updated as well.
            trainable_list.append(criterion_kd)
        elif args.distill == 'abound':
            s_shapes = [f.shape for f in feat[1:-1]]
            t_shapes = [f.shape for f in feat_t[1:-1]]
            connector = Connector(s_shapes, t_shapes)
            # Init-stage training.
            init_trainable_list = nn.ModuleList([])
            init_trainable_list.append(connector)
            init_trainable_list.append(student_core.get_feat_modules())
            criterion_kd = ABLoss(len(feat[1:-1]))
            init(model, model_t, init_trainable_list, criterion_kd,
                 train_loader, logger, args)
            # Classification stage.
            module_list.append(connector)
        elif args.distill == 'factor':
            s_shape = feat[-2].shape
            t_shape = feat_t[-2].shape
            paraphraser = Paraphraser(t_shape)
            translator = Translator(s_shape, t_shape)
            # Init-stage training.
            init_trainable_list = nn.ModuleList([])
            init_trainable_list.append(paraphraser)
            criterion_init = nn.MSELoss()
            init(model, model_t, init_trainable_list, criterion_init,
                 train_loader, logger, args)
            # Classification stage.
            criterion_kd = FactorTransfer()
            module_list.append(translator)
            module_list.append(paraphraser)
            trainable_list.append(translator)
        elif args.distill == 'fsp':
            s_shapes = [s.shape for s in feat[:-1]]
            t_shapes = [t.shape for t in feat_t[:-1]]
            criterion_kd = FSP(s_shapes, t_shapes)
            # Init-stage training.
            init_trainable_list = nn.ModuleList([])
            init_trainable_list.append(student_core.get_feat_modules())
            init(model, model_t, init_trainable_list, criterion_kd,
                 train_loader, logger, args)
        else:
            raise NotImplementedError(args.distill)

        # Criterion list, in the order it is handed to the training loop.
        criterion_list = nn.ModuleList([])
        criterion_list.append(criterion_cls)  # classification loss
        criterion_list.append(criterion_div)  # KL divergence (vanilla KD)
        criterion_list.append(criterion_kd)  # selected distillation loss
        criterion_list.append(SimMaxLoss(metric='cos', alpha=0.25))
        criterion_list.append(SimMinLoss(metric='cos'))
        criterion_list.append(SimMaxLoss(metric='cos', alpha=0.25))

        # Optimizer.
        optimizer = None
        if args.optimizer == 'sgd':
            optimizer = optim.SGD(
                trainable_list.parameters(),
                lr=args.lr,
                momentum=args.momentum,
                weight_decay=args.l2,
                nesterov=True)
        elif args.optimizer == 'adam':
            optimizer = optim.Adam(
                trainable_list.parameters(), lr=args.lr, weight_decay=args.l2)
        else:
            print('Unknown optimizer: {0}'.format(args.optimizer))
            raise Exception('Unknown optimizer.')

        # LR scheduler. Built on the plain torch optimizer, before the
        # optional FP16 wrapper replaces it.
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer,
            milestones=[
                int(args.epochs / 2) * args.multiplier,
                int(args.epochs * 3 / 4) * args.multiplier
            ],
            last_epoch=-1)

        # Resume and fp16 wrapping both need the optimizer, so they run
        # after it exists (doing this earlier raised UnboundLocalError).
        if args.resume:
            if os.path.isfile(args.resume):
                print_and_log("=> loading checkpoint '{}'".format(args.resume))
                checkpoint = torch.load(args.resume)
                args.start_epoch = checkpoint['epoch']
                model.load_state_dict(checkpoint['state_dict'])
                optimizer.load_state_dict(checkpoint['optimizer'])
                print_and_log("=> loaded checkpoint '{}' (epoch {})".format(
                    args.resume, checkpoint['epoch']))
                print_and_log('Testing...')
                evaluate(args, model, device, test_loader)
                model.feats = []
                model.densities = []
                plot_class_feature_histograms(args, model, device,
                                              train_loader, optimizer)
            else:
                print_and_log("=> no checkpoint found at '{}'".format(
                    args.resume))

        if args.fp16:
            print('FP16')
            optimizer = FP16_Optimizer(
                optimizer,
                static_loss_scale=None,
                dynamic_loss_scale=True,
                dynamic_loss_args={'init_scale': 2**16})
            model = model.half()

        # Mask for sparse training.
        mask = None
        if args.sparse:
            decay = CosineDecay(
                args.death_rate,
                len(train_loader) * (args.epochs * args.multiplier))
            mask = Masking(
                optimizer,
                death_rate=args.death_rate,
                death_mode=args.death,
                death_rate_decay=decay,
                growth_mode=args.growth,
                redistribution_mode=args.redistribution,
                args=args,
                train_loader=train_loader)
            mask.add_module(
                model, sparse_init=args.sparse_init, density=args.density)

        # Append the teacher only after the optimizer was built, so its
        # weights are not optimized (and get no weight decay).
        module_list.append(model_t)

        if torch.cuda.is_available():
            module_list.cuda()
            criterion_list.cuda()
            # NOTE(review): this re-enables benchmark mode and overrides the
            # reproducibility setting made right before seeding.
            cudnn.benchmark = True

        # Validate the teacher.
        teacher_acc, _, _ = validate(valid_loader, model_t, criterion_cls, args)
        print_and_log('teacher accuracy: ' + str(teacher_acc))

        # Output directory for this configuration.
        save_path = os.path.join('./save', str(args.model), str(args.data),
                                 str(args.sparse_init), str(args.seed))
        if args.sparse:
            save_subfolder = os.path.join(save_path,
                                          'sparsity' + str(1 - args.density))
        else:
            save_subfolder = os.path.join(save_path, 'dense')
        os.makedirs(save_subfolder, exist_ok=True)
        checkpoint_path = os.path.join(save_subfolder, 'model_final.pth')

        for epoch in range(1, args.epochs * args.multiplier + 1):

            t0 = time.time()
            train(args, model, device, train_loader, optimizer, epoch,
                  module_list, criterion_list, mask, logger)
            lr_scheduler.step()

            if args.valid_split > 0.0:
                val_acc_top1, val_acc_top5 = evaluate(
                    args, model, device, valid_loader)
                top5_accuracies.append(val_acc_top5)  # keep for the ranking
                is_best = val_acc_top1 > best_acc
                if is_best:
                    best_acc = val_acc_top1
            else:
                # Without validation there is nothing to rank by; keep the
                # most recent weights so the final test has a checkpoint.
                is_best = True

            if is_best:
                print('Saving model')
                save_checkpoint(
                    {
                        'epoch': epoch + 1,
                        'state_dict': model.state_dict(),
                        'optimizer': optimizer.state_dict(),
                    },
                    filename=checkpoint_path)

            print_and_log(
                'Current learning rate: {0}. Time taken for epoch: {1:.2f} seconds.\n'
                .format(optimizer.param_groups[0]['lr'],
                        time.time() - t0))

        print('Testing model')
        if os.path.isfile(checkpoint_path):
            model.load_state_dict(torch.load(checkpoint_path)['state_dict'])
        else:
            # No epoch produced a checkpoint; test the current weights
            # instead of crashing on the missing file.
            print_and_log('No checkpoint found at {0}; testing current weights.'
                          .format(checkpoint_path))
        test_acc_top1, test_acc_top5 = evaluate(
            args, model, device, test_loader, is_test_set=True)

        print_and_log('\nIteration end: {0}/{1}\n'.format(i + 1, args.iters))

        # Final Top-1 / Top-5 accuracy on the test loader.
        print_and_log(f'最终 Top-1 准确率: {test_acc_top1:.3f}%')
        print_and_log(f'最终 Top-5 准确率: {test_acc_top5:.3f}%')

        # Five best validation Top-5 accuracies recorded so far.
        top5_accuracies.sort(reverse=True)
        print("前五个最高的Top-5准确率：")
        for rank, acc in enumerate(top5_accuracies[:5]):
            print(f"第 {rank + 1} 名的 Top-5 准确率: {acc:.3f}%")


# Script entry point: run main() only when this file is executed directly,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
