### Import necessary libraries
import torch.nn as nn
import torch.nn.functional as F
from spikingjelly.clock_driven import functional, encoding
import pandas as pd
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
import torch
import cv2
import os
import glob
import time
import pynvml
from s_tui.sources.rapl_power_source import RaplPowerSource
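### Note: pynvml requires an NVIDIA driver, and RaplPowerSource reads Intel
### RAPL counters, which are typically only available on Linux.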
### Define the data preprocessing pipeline
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((360, 480))
])
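### A minimal sketch of what the pipeline produces, assuming a typical HxWx3
### uint8 frame from OpenCV (shapes here are illustrative):
### >>> import numpy as np
### >>> dummy = np.zeros((720, 960, 3), dtype=np.uint8)
### >>> transform(dummy).shape  # ToTensor -> CHW float in [0, 1], then Resize
### torch.Size([3, 360, 480])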


class VideoDataset(Dataset):
    def __init__(self, video_dir, label_dict):
        self.video_dir = video_dir
        self.video_list = glob.glob(os.path.join(video_dir, "*.mp4"))
        self.label_dict = label_dict

    def __len__(self):
        return len(self.video_list)

    def __getitem__(self, idx):
        video_name = self.video_list[idx]
        cap = cv2.VideoCapture(video_name)
        ### Parse the 4-digit clip id preceding ".mp4" to look up the label
        label = self.label_dict[int(video_name[-8:-4])]
        frames = []
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frame = transform(frame)
            frames.append(frame)
        cap.release()
        video = torch.stack(frames)

        return video, label
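
### Note: the default DataLoader collate can only stack clips with identical
### frame counts. If clip lengths vary, pass a custom collate_fn; a minimal
### sketch, assuming a fixed target length n_frames (an illustrative choice,
### not part of the original pipeline):
def pad_collate(batch, n_frames=64):
    ### Truncate or zero-pad every clip to n_frames, then stack as usual
    videos, labels = zip(*batch)
    fixed = []
    for v in videos:
        if v.shape[0] >= n_frames:
            fixed.append(v[:n_frames])
        else:
            pad = torch.zeros(n_frames - v.shape[0], *v.shape[1:])
            fixed.append(torch.cat([v, pad], dim=0))
    return torch.stack(fixed), torch.tensor(labels)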


### Model testing
def model_test(device, learning_rate, T, model_pth, n_labels, test_dataloader):
    ### Load the trained model (map_location keeps CPU-only machines working)
    best_snn = torch.load(model_pth, map_location=device)
    ### Poisson encoding
    encoder = encoding.PoissonEncoder()
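    ### PoissonEncoder treats each pixel intensity in [0, 1] as a per-timestep
    ### firing probability, so a pixel of intensity p emits ~p*T spikes over T
    ### steps. A minimal sketch (illustrative values, not part of the pipeline):
    ### >>> x = torch.full((2, 2), 0.5)
    ### >>> encoder(x).float()  # Bernoulli sample: each entry is 1 w.p. 0.5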
    ### Define the loss function and move it to the selected device
    loss_fn = nn.MSELoss().to(device)
    ### The model must be in eval mode during testing
    best_snn.eval()
    ### Move the model to the selected device
    best_snn.to(device)
    ### Initialize test accuracy and spike counters
    max_test_accuracy = 0.0
    ### result_sops is reserved for a synaptic-operations count but is not updated in this script
    result_sops, result_num_spikes_1, result_num_spikes_2 = 0, 0, 0
    ### Sample the GPU power draw once before the run (NVML reports milliwatts, hence /1000 to watts)
    pynvml.nvmlInit()
    handle = pynvml.nvmlDeviceGetHandleByIndex(0)
    powerusage = pynvml.nvmlDeviceGetPowerUsage(handle) / 1000
    start_time = time.time()

    ### Sample the CPU package power via RAPL (values are per-socket, in watts)
    source = RaplPowerSource()
    source.update()
    summary = dict(source.get_sensors_summary())
    cpu_power_total = sum(float(summary[key]) for key in summary if key.startswith('package'))
    ### Gradients do not need to be computed during testing
    with torch.no_grad():
        ### Enable spike monitors on the network
        functional.set_monitor(best_snn, True)
        test_sum, correct_sum = 0, 0
        ### Iterate over all videos and their corresponding labels in the test dataloader
        for rind, (img, label) in enumerate(test_dataloader):
            print(f'batch {rind}')
            ### Move the batch to the selected device
            img = img.to(device)
            n_imgs = img.shape[0]
            ### Initialize a zero matrix of n_imgs x n_labels
            out_spikes_counter = torch.zeros(n_imgs, n_labels).to(device)
            ### Approximate number of test videos: batch size x number of batches
            denominator = n_imgs * len(test_dataloader)
            ### Run the network for T timesteps, accumulating output spikes
            for t in range(T):
                print(f'timestep {t}')
                ### Poisson-encode the frames into a binary spike tensor
                enc_img = encoder(img).float()
                out_spikes_counter += best_snn(enc_img)
                ### Track the average number of input spikes per video
                result_num_spikes_1 += torch.sum(enc_img) / denominator
            ### Track the average number of output spikes per video
            result_num_spikes_2 += torch.sum(out_spikes_counter) / denominator
            ### Firing rate: accumulated output spike count divided by T
            out_spikes_counter_frequency = out_spikes_counter / T
            ### Convert labels to long type
            label = label.long().to(device)
            ### One-hot encoding of labels
            label_one_hot = F.one_hot(label, n_labels).float()
            ### MSE loss between output firing rates and one-hot labels (only the last batch's loss is reported below)
            loss = loss_fn(out_spikes_counter_frequency, label_one_hot)
            ### correct_sum counts correct predictions; the most active output neuron is the predicted class
            correct_sum += (out_spikes_counter.max(1)[1] == label).float().sum().item()
            test_sum += label.numel()
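            ### SNN neurons are stateful (membrane potentials persist across calls), so the network must be reset between batches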
            functional.reset_net(best_snn)
        ### Calculate test accuracy
        test_accuracy = correct_sum / test_sum
        max_test_accuracy = max(max_test_accuracy, test_accuracy)
    end_time = time.time()
    execution_time = end_time - start_time
    if device.type == 'cuda':
        ### GPU power draw (single NVML sample taken before the run, in watts)
        power_consumption = powerusage
    else:
        ### CPU package power draw (single RAPL sample taken before the run, in watts)
        power_consumption = cpu_power_total
    ### Energy = power x time (joules); a rough estimate, since power is sampled only once
    energy_consumption = power_consumption * execution_time
    print('Energy (J):', energy_consumption)
    ### Print the performance of the model on the test set
    result_msg = f"test set acc: device={device}, learning_rate={learning_rate}, T={T}, max_test_accuracy={max_test_accuracy:.4f}, loss={loss.item():.4f}"
    result_msg += f", num_s1: {int(result_num_spikes_1)}, num_s2: {int(result_num_spikes_2)}"
    result_msg += f", num_s_per_node: {int(result_num_spikes_1) + int(result_num_spikes_2)}"
    print(result_msg)
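    ### A sketch of how the spike counts above could feed a synaptic-operation
    ### (SOP) estimate; the fan-out value is purely illustrative:
    ### >>> avg_fan_out = 100  # hypothetical average connections per neuron
    ### >>> sops = (result_num_spikes_1 + result_num_spikes_2) * avg_fan_out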


if __name__ == '__main__':
    ### Use GPU acceleration if available
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    ### Learning rate (reported in the result log only; no training happens here)
    learning_rate = 2e-3
    ### Number of simulation timesteps per video
    T = 10
    ### Path to model weights
    model_pth = r'./logs/best_snn.model'
    ### Number of categories
    n_labels = 3
    batch_size = 8
    ### Load the label file and build a clip-id -> label lookup
    label = pd.read_csv(r'./data/Penn_Action/labels.csv')
    label_dict = dict(zip(label['name'], label['label']))
    dataset = VideoDataset(video_dir=r'./data/Penn_Action/newVideo', label_dict=label_dict)
    ### Split the dataset 80/20 into training and test sets
    train_size = int(0.8 * len(dataset))
    test_size = len(dataset) - train_size
    tr_dataset, ts_dataset = random_split(dataset, [train_size, test_size])
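    ### For a reproducible split, random_split accepts a seeded generator, e.g.:
    ### tr_dataset, ts_dataset = random_split(
    ###     dataset, [train_size, test_size],
    ###     generator=torch.Generator().manual_seed(42))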

    ### Build the dataloader for the test set (only the test split is used in this script)
    ts_dataloader = DataLoader(dataset=ts_dataset, batch_size=batch_size, shuffle=False)

    ### Run the test
    model_test(device, learning_rate, T, model_pth, n_labels, ts_dataloader)
