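'''

    1. Read the .mat annotation files and extract the action label of each clip
    2. Merge the extracted frames into videos and resample every video to a fixed number of frames

'''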

'''

1. Category 1 (class id 0)
    baseball_pitch
    baseball_swing
    bowling
    golf_swing
    tennis_forehand
    tennis_serve

2. Category 2 (class id 1)
    bench_press
    clean_and_jerk
    jumping_jacks
    pull_ups
    sit_ups
    push_ups
    squats

3. Category 3 (class id 2)
    jump_rope
    strumming_guitar

'''

# Numeric class id -> action names grouped under that class.
# load_labels compares each .mat file's 'action' string against these lists.
# NOTE(review): the names must match the strings stored in the .mat files
# character for character; confirm them against the dataset.
action_classification = {
    # Category 1: swing / throw style sports actions
    0: [
        'baseball_pitch',
        'baseball_swing',
        'bowling',
        'golf_swing',
        'tennis_forehand',
        'tennis_serve',
    ],
    # Category 2: fitness / exercise actions
    1: [
        'bench_press',
        'clean_and_jerk',
        'jumping_jacks',
        'pull_ups',
        'sit_ups',
        'push_ups',
        'squats',
    ],
    # Category 3: remaining actions
    2: [
        'jump_rope',
        'strumming_guitar',
    ],
}

import math
import os
import random
import shutil
import warnings

import cv2
import pandas as pd
from scipy.io import loadmat

video_path = r'./data/Penn_Action/video'
img_size = (360, 480)



### 1. Merge pictures into videos
def img2video(path, labels, output_dir=None, fps=30):
    """Merge each sample's extracted frames into one ``.mp4`` video.

    For every name in ``labels['name']``, the image files in
    ``<path>/<name>/`` are read in sorted file-name order. Each frame is
    resized to the module-level ``img_size`` and written to
    ``<output_dir>/<name>.mp4`` with the ``mp4v`` codec.

    Args:
        path: Root directory holding one sub-directory of frames per sample.
        labels: Object whose ``'name'`` entry iterates over sample names
            (e.g. the DataFrame returned by ``load_labels``).
        output_dir: Destination directory. Defaults to the module-level
            ``video_path`` and is created if missing.
        fps: Frame rate of the written videos (default 30).

    Raises:
        ValueError: If an image file in a frame directory cannot be decoded.
    """
    if output_dir is None:
        output_dir = video_path
    # VideoWriter does not create directories; it just fails to open.
    os.makedirs(output_dir, exist_ok=True)

    # Only decode real image files, so stray entries (e.g. .DS_Store) do not
    # make cv2.imread return None.
    image_extensions = ('.jpg', '.jpeg', '.png', '.bmp')

    for file_name in labels['name']:
        print(file_name)
        frame_dir = os.path.join(path, file_name)
        video_name = os.path.join(output_dir, f'{file_name}.mp4')

        img_name_list = sorted(
            name for name in os.listdir(frame_dir)
            if name.lower().endswith(image_extensions)
        )

        # Set the codec and frame rate of the output video file
        output_video = cv2.VideoWriter(
            video_name, cv2.VideoWriter_fourcc(*'mp4v'), fps, img_size
        )
        try:
            for img_name in img_name_list:
                img_path = os.path.join(frame_dir, img_name)
                img = cv2.imread(img_path)
                if img is None:
                    # cv2.imread signals failure by returning None instead of raising.
                    raise ValueError(f'could not read image {img_path}')
                output_video.write(cv2.resize(img, img_size))
        finally:
            # Always finalise the container, even if a frame failed to load.
            output_video.release()



### Extract tags
def load_labels(path, num_per_class, csv_path=r'./data/Penn_Action/labels.csv',
                classes=None, seed=None):
    """Read every ``.mat`` annotation's action label and draw a balanced subset.

    Each ``.mat`` file in *path* is loaded with ``scipy.io.loadmat``. Its
    ``'action'`` entry is mapped to a numeric class id through *classes*.
    Then *num_per_class* samples are drawn at random, without replacement,
    from every class that occurs. The subset is sorted by name, written to
    *csv_path* and returned.

    Args:
        path: Directory containing the per-sample ``.mat`` files.
        num_per_class: Number of samples to draw from each class.
        csv_path: Output CSV location (defaults to the original hard-coded path).
        classes: Mapping ``class_id -> list of action names``. Defaults to the
            module-level ``action_classification``. If an action is listed
            under several classes, the first class in iteration order wins.
        seed: Optional seed for reproducible sampling. With ``None``, the
            global ``random`` module state is used, as before.

    Returns:
        pandas.DataFrame with columns ``name`` (file name without extension)
        and ``label`` (int class id), sorted by ``name``.

    Raises:
        ValueError: If an action is not listed in *classes*, or a class has
            fewer than *num_per_class* samples.
    """
    if classes is None:
        classes = action_classification

    # Reverse lookup action -> class id. setdefault keeps the first class,
    # mirroring the original if/elif precedence.
    action_to_class = {}
    for class_id, actions in classes.items():
        for action in actions:
            action_to_class.setdefault(action, class_id)

    # Skip non-.mat entries that loadmat would choke on.
    mat_name_list = sorted(
        name for name in os.listdir(path) if name.lower().endswith('.mat')
    )

    names = []
    labels = []
    for mat_name in mat_name_list:
        data = loadmat(os.path.join(path, mat_name))
        action = str(data['action'][0])
        if action not in action_to_class:
            # Previously any unknown action silently became class 2.
            raise ValueError(
                f'{mat_name}: action {action!r} is not listed in any class'
            )
        # Drop the extension ('0001.mat' -> '0001') without assuming a name length.
        names.append(os.path.splitext(mat_name)[0])
        labels.append(action_to_class[action])

    df = pd.DataFrame({'name': names, 'label': labels})

    ### Resample the label
    rng = random if seed is None else random.Random(seed)
    index_list = []
    for class_id in pd.unique(df['label']):
        index = list(df.index[df['label'] == class_id])
        if len(index) < num_per_class:
            raise ValueError(
                f'class {class_id} has only {len(index)} samples, '
                f'fewer than num_per_class={num_per_class}'
            )
        index_list += rng.sample(index, num_per_class)

    # index_list holds index labels, so select by label (loc), not position.
    df = df.loc[index_list, :]
    df = df.sort_values(by=['name'], axis=0)
    df.to_csv(csv_path, index=False)
    return df


### Fill the frame numbers of all videos to be consistent
def _resample_video(input_path, output_path, frame_count, max_frames):
    """Write *max_frames* evenly spaced frames of *input_path* to *output_path*.

    Output frame ``i`` is source frame ``floor(i / max_frames * frame_count)``.
    Frames are therefore repeated when the source is shorter and skipped
    when it is longer. If the stream ends before a requested index (the
    reported frame count can be too high), the last decoded frame is
    repeated.

    Raises:
        ValueError: If not a single frame can be decoded from *input_path*.
    """
    # Source frame index for every output frame; non-decreasing.
    wanted = [math.floor(i / max_frames * frame_count) for i in range(max_frames)]

    cap = cv2.VideoCapture(input_path)
    try:
        fourcc = int(cap.get(cv2.CAP_PROP_FOURCC))
        fps = cap.get(cv2.CAP_PROP_FPS)  # keep fractional rates such as 29.97
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
        try:
            position = 0      # number of frames decoded so far
            frame = None      # most recently decoded frame (index position - 1)
            exhausted = False
            for target in wanted:
                # Decode forward sequentially rather than seeking with
                # CAP_PROP_POS_FRAMES, which is not frame-accurate for many codecs.
                while not exhausted and position <= target:
                    ok, decoded = cap.read()
                    if not ok:
                        exhausted = True
                    else:
                        frame = decoded
                        position += 1
                if frame is None:
                    raise ValueError(f'could not decode any frame from {input_path}')
                out.write(frame)
        finally:
            out.release()
    finally:
        cap.release()


def set_frame(input_folder, output_folder, max_frames=25):
    """Resample every ``.mp4`` in *input_folder* to exactly *max_frames* frames.

    Shorter videos are padded by repeating evenly spaced frames, and longer
    ones are truncated by keeping evenly spaced frames. Videos that already
    have *max_frames* frames are copied unchanged; the input is kept. Each
    result is written to *output_folder* under the same file name.

    Args:
        input_folder: Directory containing the source ``.mp4`` files.
        output_folder: Destination directory; created if it does not exist.
        max_frames: Target number of frames per video (default 25).

    Raises:
        ValueError: If no frame can be decoded from a video that needs resampling.
    """
    # cv2.VideoWriter does not create directories and fails silently when the
    # folder is missing, so create it up front.
    os.makedirs(output_folder, exist_ok=True)

    for filename in sorted(os.listdir(input_folder)):
        if not filename.endswith(".mp4"):
            continue
        print(filename)
        input_path = os.path.join(input_folder, filename)
        output_path = os.path.join(output_folder, filename)

        # Read the frame count reported by the container.
        cap = cv2.VideoCapture(input_path)
        frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        cap.release()

        if frame_count == max_frames:
            # Already the right length: copy the file (os.rename used to move it).
            shutil.copy2(input_path, output_path)
        else:
            _resample_video(input_path, output_path, frame_count, max_frames)


if __name__ == '__main__':
    # Suppresses every warning for the whole run.
    warnings.filterwarnings('ignore')
    num_per_class = 200

    ### Extract tags and draw a class-balanced subset
    label_path = r'./data/Penn_Action/labels'
    labels = load_labels(label_path, num_per_class)

    # Create video_path if it does not exist. makedirs (unlike mkdir) also
    # creates missing parent directories and tolerates an existing folder.
    os.makedirs(video_path, exist_ok=True)

    frames_path = r'./data/Penn_Action/frames'
    img2video(path=frames_path, labels=labels)

    # The resampled videos go to their own folder; create it so the writers can open.
    new_video_path = r'./data/Penn_Action/newVideo'
    os.makedirs(new_video_path, exist_ok=True)
    set_frame(input_folder=video_path, output_folder=new_video_path)
