### Import necessary libraries
from myModel import model_lif_fc
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision import transforms
import torch
import cv2
import os
import glob
import pandas as pd
import numpy as np

### Build your own data set


### Define the frame preprocessing pipeline: convert each frame to a tensor,
### then resize it to 360x480 (recent torchvision versions accept tensor
### inputs to Resize)
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((360, 480))
])
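
### Each decoded frame thus becomes a float32 tensor of shape (3, 360, 480)
### with pixel values scaled to [0, 1] by ToTensor.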


class VideoDataset(Dataset):
    def __init__(self, video_dir, label_dict):
        self.video_dir = video_dir
        self.video_list = glob.glob(os.path.join(video_dir, "*.mp4"))
        self.label_dict = label_dict

    def __len__(self):
        return len(self.video_list)

    def __getitem__(self, idx):
        video_path = self.video_list[idx]
        cap = cv2.VideoCapture(video_path)
        ### The four digits before ".mp4" in the filename are the clip id
        ### used as the key into label_dict
        label = self.label_dict[int(video_path[-8:-4])]
        frames = []
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            ### OpenCV decodes frames as BGR; convert to RGB before transforming
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frame = transform(frame)
            frames.append(frame)
        cap.release()
        if not frames:
            raise RuntimeError(f'No frames could be decoded from {video_path}')
        ### Stack into a single (num_frames, 3, 360, 480) tensor
        video = torch.stack(frames)

        return video, label
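

### NOTE: Penn_Action clips are generally not all the same length, in which
### case the default DataLoader collate (which stacks samples with
### torch.stack) fails on batches with unequal frame counts. Below is a
### minimal padding collate_fn one could pass to the DataLoaders; the name
### pad_collate and the zero-frame padding strategy are assumptions, not part
### of the original pipeline.
def pad_collate(batch):
    videos, labels = zip(*batch)
    ### Pad every clip with zero frames up to the longest clip in the batch
    max_len = max(v.shape[0] for v in videos)
    padded = [torch.cat([v, v.new_zeros(max_len - v.shape[0], *v.shape[1:])])
              for v in videos]
    return torch.stack(padded), torch.tensor(labels)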


def snn_run():
    ### Use the GPU when available by switching to the commented line below
    # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    device = torch.device('cpu')
    ### Batch size
    batch_size = 8
    ### Learning rate
    learning_rate = 2e-3
    ### Number of simulation time steps per sample
    T = 10
    ### LIF neuron membrane time constant
    tau = 80.0
    ### Firing threshold of the membrane potential
    v_threshold = 0.2
    ### Number of training epochs
    train_epoch = 20
    ### Number of action categories
    n_labels = 3

    ### Load the label file and build a clip-id -> label mapping
    label = pd.read_csv(r'./data/Penn_Action/labels.csv')
    label_dict = dict(zip(label['name'], label['label']))
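    ### The CSV layout is assumed (not confirmed by the source) to be:
    ###     name,label
    ###     0001,0
    ###     0002,2
    ### where "name" matches the four digits in each <name>.mp4 filename.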

    dataset = VideoDataset(video_dir=r'./data/Penn_Action/newVideo',
                           label_dict=label_dict)

    ### Hold out 20% of the clips for testing, then 20% of the remainder for
    ### validation (a 64/16/20 train/val/test split overall)
    train_size = int(0.8 * len(dataset))
    test_size = len(dataset) - train_size
    tr_dataset, ts_dataset = random_split(dataset, [train_size, test_size])
    val_size = int(0.2 * len(tr_dataset))
    train_size = train_size - val_size
    tr_dataset, val_dataset = random_split(tr_dataset, [train_size, val_size])
    print(f'Training set size: {len(tr_dataset)}')
    print(f'Validation set size: {len(val_dataset)}')
    print(f'Test set size: {len(ts_dataset)}')

    ### Build the dataloaders for the training, validation, and test sets;
    ### if clip lengths differ, pass collate_fn=pad_collate (sketched above)
    tr_dataloader = DataLoader(dataset=tr_dataset, batch_size=batch_size, shuffle=True)
    val_dataloader = DataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=False)
    ts_dataloader = DataLoader(dataset=ts_dataset, batch_size=batch_size, shuffle=False)
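
    ### With the default collate, each batch is a (videos, labels) pair where
    ### videos has shape (batch_size, num_frames, 3, 360, 480); this assumes
    ### every clip in a batch has the same num_frames (see pad_collate above).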

    ### Model training; the returned score is treated as the test accuracy
    ### in __main__
    ret = model_lif_fc(device=device,
                       learning_rate=learning_rate,
                       T=T,
                       tau=tau,
                       v_threshold=v_threshold,
                       train_epoch=train_epoch,
                       n_labels=n_labels,
                       train_dataloader=tr_dataloader,
                       val_dataloader=val_dataloader,
                       test_dataloader=ts_dataloader)

    return ret


if __name__ == '__main__':
    ### Number of independent runs
    runs = 1
    ### Accuracy recorded for each run
    scores = []
    ### Run the model `runs` times
    for run in range(runs):
        score = snn_run()
        scores.append(score)
    ### Compute the mean and standard deviation of accuracy across runs
    me = np.mean(scores)
    st = np.std(scores)
    print(f'Accuracy over {runs} run(s): mean = {me}, std = {st}')
