### Import necessary libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
from spikingjelly.clock_driven import neuron, encoding, functional

from torch.optim.lr_scheduler import LambdaLR

### Model training
def _build_lif_fc_net(n_labels, tau, v_threshold):
    """Build the Conv3d feature extractor + FC head ending in a LIF spiking layer.

    The first Conv3d takes 25 input channels; presumably inputs are
    (batch, 25, depth, height, width) -- TODO confirm against the data pipeline.
    Layer order (and therefore state_dict keys) matches the original inline
    nn.Sequential definition.
    """
    layers = []
    # Four conv stages; pooling uses a depth kernel of 1, so only H and W shrink.
    for c_in, c_out in ((25, 32), (32, 64), (64, 128), (128, 256)):
        layers += [
            nn.Conv3d(c_in, c_out, kernel_size=(3, 5, 5), padding=(1, 2, 2)),
            nn.ReLU(),
            nn.MaxPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2)),
        ]
    layers += [
        # Collapse whatever volume remains so the FC head always sees 256 features.
        nn.AdaptiveAvgPool3d((1, 1, 1)),
        nn.Flatten(),
        nn.Linear(256, 512),
        nn.ReLU(),
        nn.Dropout(),
        nn.Linear(512, n_labels),
        # Spiking output layer: one LIF neuron per class.
        neuron.LIFNode(tau=tau, v_threshold=v_threshold),
    ]
    return nn.Sequential(*layers)


def _simulate(net, encoder, img, T):
    """Run ``net`` for ``T`` time steps, each on a fresh encoding of ``img``.

    Returns:
        (out_spikes, in_spikes): ``out_spikes`` is the sum of the network
        outputs over the T steps (shape (batch, n_labels)); ``in_spikes`` is a
        0-dim tensor holding the total of all encoder outputs over the T steps.
    The caller must call ``functional.reset_net(net)`` afterwards, because the
    LIF layer keeps state across the T steps. Requires ``T >= 1``.
    """
    out_spikes = 0
    in_spikes = 0
    for _ in range(T):
        enc_img = encoder(img).float()
        out_spikes = out_spikes + net(enc_img)
        in_spikes = in_spikes + enc_img.sum()
    return out_spikes, in_spikes


def _evaluate(net, encoder, dataloader, T, n_labels, device, loss_fn):
    """Evaluate ``net`` on ``dataloader`` in eval mode without gradients.

    Returns:
        (accuracy, mean_loss, in_spikes_per_sample, out_spikes_per_sample),
        each averaged over all samples. An empty dataloader yields all zeros
        instead of dividing by zero.
    """
    net.eval()
    n_samples = n_correct = 0
    loss_sum = in_spikes_sum = out_spikes_sum = 0.0
    with torch.no_grad():
        for img, label in dataloader:
            img = img.to(device)
            label = label.long().to(device)
            out_spikes, in_spikes = _simulate(net, encoder, img, T)
            functional.reset_net(net)

            batch_size = label.numel()
            label_one_hot = F.one_hot(label, n_labels).float()
            # MSELoss averages over the batch; re-weight so the final value is a
            # per-sample mean over the whole loader, not just the last batch.
            loss_sum += loss_fn(out_spikes / T, label_one_hot).item() * batch_size
            n_correct += (out_spikes.max(1)[1] == label).float().sum().item()
            in_spikes_sum += in_spikes.item()
            out_spikes_sum += out_spikes.sum().item()
            n_samples += batch_size
    if n_samples == 0:
        return 0.0, 0.0, 0.0, 0.0
    return (n_correct / n_samples, loss_sum / n_samples,
            in_spikes_sum / n_samples, out_spikes_sum / n_samples)


def model_lif_fc(device, learning_rate, T, tau,
                 v_threshold, train_epoch, n_labels,
                 train_dataloader, val_dataloader, test_dataloader):
    """Train a Conv3d + LIF spiking classifier and return its test accuracy.

    Each sample is presented for ``T`` time steps, each step using a fresh
    Poisson encoding. The summed output spikes divided by ``T`` are regressed
    onto one-hot targets with MSE. After each epoch the model is checked on the
    validation set. The weights with the best validation accuracy are kept in
    memory and also pickled to ``./logs/best_snn.model``. They are then
    evaluated on the test set.

    Args:
        device: torch device for the model and every batch.
        learning_rate: initial Adam learning rate. NOTE: it is overridden to
            1e-3 at epoch index 18 and 1e-4 at epoch index 28.
        T: number of simulation time steps per sample (must be >= 1).
        tau, v_threshold: passed to the output ``neuron.LIFNode``.
        train_epoch: number of training epochs.
        n_labels: number of classes (output neurons / one-hot width).
        train_dataloader, val_dataloader, test_dataloader: iterables of
            ``(img, label)`` batches. ``img`` presumably holds intensities in
            [0, 1] as Poisson encoding expects -- TODO confirm.

    Returns:
        Test accuracy (float) of the best-validation weights.

    Raises:
        ValueError: if ``T < 1``.
    """
    import os  # local import: only needed to create the checkpoint directory

    if T < 1:
        raise ValueError(f'T must be >= 1, got {T}')

    net = _build_lif_fc_net(n_labels, tau, v_threshold).to(device)
    optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate, weight_decay=1e-4)
    encoder = encoding.PoissonEncoder()
    loss_fn = nn.MSELoss().to(device)

    model_pth = r'./logs/best_snn.model'
    os.makedirs(os.path.dirname(model_pth), exist_ok=True)

    # Start below any real accuracy so the first epoch always yields a checkpoint.
    max_val_accuracy = -1.0
    best_state = None
    # Hard-coded step decay: epoch index -> learning rate to switch to.
    lr_steps = {18: 0.001, 28: 0.0001}

    print('******************训练开始******************')
    for epoch in range(train_epoch):
        print('***********************{}/{}***********************'.format(epoch + 1, train_epoch))
        net.train()
        new_lr = lr_steps.get(epoch)
        if new_lr is not None:
            for param_group in optimizer.param_groups:
                param_group['lr'] = new_lr

        for img, label in train_dataloader:
            img = img.to(device)
            label = label.long().to(device)
            label_one_hot = F.one_hot(label, n_labels).float()

            optimizer.zero_grad()
            out_spikes, _ = _simulate(net, encoder, img, T)
            loss = loss_fn(out_spikes / T, label_one_hot)
            loss.backward()
            optimizer.step()
            # The LIF layer is stateful; reset it before the next batch.
            functional.reset_net(net)

        val_accuracy, _, _, _ = _evaluate(
            net, encoder, val_dataloader, T, n_labels, device, loss_fn)
        if val_accuracy > max_val_accuracy:
            max_val_accuracy = val_accuracy
            # Keep an in-memory copy so testing never depends on a stale or
            # missing file (or on torch.load's pickle policy).
            best_state = {k: v.detach().clone() for k, v in net.state_dict().items()}
            torch.save(net, model_pth)

    if best_state is not None:
        net.load_state_dict(best_state)
    best_snn = net
    functional.set_monitor(best_snn, True)
    test_accuracy, test_loss, num_s1, num_s2 = _evaluate(
        best_snn, encoder, test_dataloader, T, n_labels, device, loss_fn)
    max_test_accuracy = test_accuracy

    # num_s1 / num_s2: mean input / output spike totals per test sample.
    result_msg = f'testset\'acc: device={device}, learning_rate={learning_rate}, T={T}, max_test_accuracy={max_test_accuracy:.4f}, loss = {test_loss:.4f}'
    result_msg += f", num_s1: {int(num_s1)}, num_s2: {int(num_s2)}"
    result_msg += f", num_s_per_node: {int(num_s1) + int(num_s2)}"
    print(result_msg)

    return max_test_accuracy
