import torch
from torch.utils.data import Dataset
from utils.prompt_struct import PromStruct
import pandas as pd


class CountMesh:
    """Bin point records onto a regular lon/lat grid and count them per cell.

    The grid spans ``start_*`` to ``stop_*`` on each axis and is cut into
    ``132 // scale`` longitude cells and ``155 // scale`` latitude cells.
    Cell indices are 1-based.
    """

    def __init__(self, args, config):
        """Build the grid geometry.

        Args:
            args: object exposing ``lowpw``, the inclusive ``soc`` threshold
                used by :meth:`LowPower`.
            config: mapping with ``start_longitude``, ``stop_longitude``,
                ``start_latitude``, ``stop_latitude`` and ``scale`` (coarsening
                factor applied to the base 132 x 155 cell grid).

        Raises:
            ZeroDivisionError: if ``scale`` is large enough that an axis ends
                up with zero cells.
        """
        self.start_lon = config['start_longitude']
        self.start_lat = config['start_latitude']

        # Number of cells per axis (attribute spelling kept for compatibility).
        self.longitude_lenth = 132 // config['scale']
        self.latitude_lenth = 155 // config['scale']

        # Width of a single cell, in the same units as the coordinates.
        self.step_longitude = (config['stop_longitude'] - config['start_longitude']) / self.longitude_lenth
        self.step_latitude = (config['stop_latitude'] - config['start_latitude']) / self.latitude_lenth

        self.lowpw = args.lowpw

    def df_init(self, df):
        """Stack the start and stop points of each record into one point table.

        Args:
            df: frame with ``start_time``, ``start_soc``, ``start_longitude``,
                ``start_latitude`` and the matching ``stop_*`` columns.

        Returns:
            DataFrame with columns ``time``, ``soc``, ``longitude``,
            ``latitude``. All start points come first, then all stop points.
            The original index labels are kept, so they repeat.
        """
        df_1 = df[['start_time', 'start_soc', 'start_longitude', 'start_latitude']]
        df_2 = df[['stop_time', 'stop_soc', 'stop_longitude', 'stop_latitude']]

        # Same column names on both halves so concat stacks them row-wise.
        df_1.columns = ['time', 'soc', 'longitude', 'latitude']
        df_2.columns = ['time', 'soc', 'longitude', 'latitude']

        result_df = pd.concat([df_1, df_2], axis=0)
        # NOTE(review): progress/debug marker; kept to preserve existing output.
        print('over')
        return result_df

    def mesh_pos(self, df):
        """Assign each point a 1-based grid cell and drop points off the grid.

        Cell ``i`` on an axis covers ``[start + (i - 1) * step, start + i * step)``.
        Points whose index falls outside ``1..<axis length>`` are dropped. A
        point exactly on the stop edge normally lands at index ``length + 1``
        and is dropped as well. Rows with NaN coordinates are dropped because
        NaN fails every comparison.

        Args:
            df: frame with ``longitude`` and ``latitude`` columns.

        Returns:
            A new frame (the caller's ``df`` is not modified) with extra
            float columns ``Mesh_lon`` / ``Mesh_lat``, restricted to in-grid rows.
        """
        # assign() works on a copy, so the caller's frame is never mutated.
        # It also fails on a missing column before anything is written.
        meshed = df.assign(
            Mesh_lon=((df['longitude'] - self.start_lon) // self.step_longitude) + 1,
            Mesh_lat=((df['latitude'] - self.start_lat) // self.step_latitude) + 1,
        )

        mask_lon = (meshed['Mesh_lon'] > 0) & (meshed['Mesh_lon'] <= self.longitude_lenth)
        mask_lat = (meshed['Mesh_lat'] > 0) & (meshed['Mesh_lat'] <= self.latitude_lenth)

        return meshed[mask_lon & mask_lat]

    def MeshGroup(self, df):
        """Group rows of a :meth:`mesh_pos` result by ``(Mesh_lon, Mesh_lat)``."""
        return df.groupby(['Mesh_lon', 'Mesh_lat'])

    def _count_per_cell(self, df, name):
        """Return per-cell row counts of ``df`` as a frame whose count column is ``name``."""
        return self.MeshGroup(df).size().reset_index(name=name)

    def Trans(self, df):
        """Count all points per cell (count column ``TrafficNum``)."""
        return self._count_per_cell(df, 'TrafficNum')

    def LowPower(self, df):
        """Count points with ``soc <= lowpw`` per cell (count column ``LowpowerNum``)."""
        low = df[df['soc'] <= self.lowpw]
        return self._count_per_cell(low, 'LowpowerNum')

    def POINum(self, df):
        """Count points per cell (count column ``POI``)."""
        return self._count_per_cell(df, 'POI')


class CustomDataset(Dataset):
    """Map-style dataset over pre-tokenized prompt samples.

    ``data`` is expected to be the dict returned by :meth:`preprocess_dataset`:
    indexable ``input_ids``, ``attention_masks``, ``mask_positions`` and
    ``labels`` of equal length.
    """

    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data["input_ids"])

    def __getitem__(self, idx):
        """Return sample ``idx``.

        The keys are singular (``attention_mask``, ``mask_position``,
        ``label``), unlike the plural keys of ``self.data``.
        """
        return {
            'input_ids': self.data["input_ids"][idx],
            'attention_mask': self.data["attention_masks"][idx],
            'mask_position': self.data['mask_positions'][idx],
            'label': self.data['labels'][idx]
        }

    @classmethod
    def preprocess_dataset(cls, data_content, tokenizer, method, max_length=64):
        """Tokenize ``text<TAB>label`` lines into model-ready tensors.

        Args:
            data_content: newline-separated samples, each ``text\\tlabel`` with
                an integer label. Blank lines, such as the one left by a
                trailing newline, are skipped. The label is whatever follows
                the last tab.
            tokenizer: callable tokenizer accepting ``padding``/``truncation``/
                ``max_length``/``return_tensors`` and exposing ``mask_token_id``
                (presumably a Hugging Face tokenizer; confirm against callers).
            method: ``0`` uses ``PromStruct.add_prompt_with_mask``; ``1`` uses
                ``PromStruct.constract_prompt_with_mask``.
            max_length: pad/truncate length per prompt (default 64).

        Returns:
            dict with ``input_ids`` and ``attention_masks`` (stacked tensors,
            one row per sample), ``mask_positions`` (list of int) and
            ``labels`` (tensor of int labels).

        Raises:
            ValueError: unknown ``method``; no non-blank lines; a line with no
                tab or a non-integer label; or a prompt whose mask token did
                not survive tokenization (e.g. it was truncated away).
        """
        if method not in (0, 1):
            raise ValueError(f"method must be 0 or 1, got {method!r}")

        # splitlines() also strips '\r' from Windows line endings. Blank lines
        # are skipped so a trailing newline no longer breaks unpacking.
        # The original 1-based line numbers are kept for error messages.
        samples = [(line_no, line)
                   for line_no, line in enumerate(data_content.splitlines(), start=1)
                   if line.strip()]
        if not samples:
            raise ValueError("data_content contains no samples")

        build_prompt = (PromStruct.add_prompt_with_mask if method == 0
                        else PromStruct.constract_prompt_with_mask)

        input_ls, attention_masks, mask_positions, labels = [], [], [], []

        for line_no, line in samples:
            # rsplit keeps any tab inside the text as part of the text.
            input_text, label = line.rsplit('\t', 1)

            inputs = tokenizer(build_prompt(input_text), padding='max_length',
                               truncation=True, max_length=max_length,
                               return_tensors="pt")
            ids = inputs['input_ids'].squeeze(0)

            try:
                mask_position = ids.tolist().index(tokenizer.mask_token_id)
            except ValueError as err:
                raise ValueError(
                    f"line {line_no}: mask token not found after tokenization "
                    f"(prompt truncated to {max_length} tokens?)"
                ) from err

            input_ls.append(ids)
            attention_masks.append(inputs['attention_mask'].squeeze(0))
            mask_positions.append(mask_position)
            labels.append(int(label))

        return {
            "input_ids": torch.stack(input_ls),
            "attention_masks": torch.stack(attention_masks),
            "mask_positions": mask_positions,
            "labels": torch.tensor(labels)
        }

