from torch.utils.data import TensorDataset, DataLoader, RandomSampler
from utils.dataset_process import CountMesh
from transformers import AutoTokenizer, BertForSequenceClassification
from tqdm import tqdm
from utils.decodemesh import Decodemesh
from utils.decodepos import Posdecode
import argparse
from sklearn.model_selection import train_test_split
import train
import pandas as pd
import yaml
import torch
import os

def encode_data(tokenizer, data, max_length):
    """Tokenize a collection of sentences into fixed-length tensors.

    Every sentence is padded up to ``max_length`` and truncated down to it, so
    all rows have the same width and can be stacked into a single batch
    tensor regardless of the individual sentence lengths.

    Args:
        tokenizer: A callable Hugging Face tokenizer (for example the object
            returned by ``AutoTokenizer.from_pretrained``).
        data: Iterable of sentences (a list, a pandas Series, ...).
        max_length: Target sequence length, in tokens, of every encoded row.

    Returns:
        A tuple ``(input_ids, attention_masks)`` of tensors, each with one
        row per sentence and ``max_length`` columns. For empty ``data`` both
        tensors have shape ``(0, max_length)``.
    """
    sentences = list(data)

    # Nothing to tokenize: return correctly shaped empty tensors instead of
    # handing an empty batch to the tokenizer.
    if not sentences:
        empty_ids = torch.empty((0, max_length), dtype=torch.long)
        return empty_ids, torch.empty((0, max_length), dtype=torch.long)

    # One batched call replaces the per-sentence loop. `padding='max_length'`
    # is the supported spelling of the deprecated `pad_to_max_length=True`,
    # and `truncation=True` keeps over-long sentences from producing rows
    # wider than `max_length`.
    encoded = tokenizer(
        sentences,
        add_special_tokens=True,
        max_length=max_length,
        padding='max_length',
        truncation=True,
        return_attention_mask=True,
        return_tensors='pt',
    )

    return encoded['input_ids'], encoded['attention_mask']

def str2bool(value):
    """Convert a command-line token into a real boolean.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so any
    non-empty value (including ``"False"``) becomes ``True``. This converter
    interprets the usual textual spellings instead.

    Args:
        value: A ``bool`` (returned unchanged) or a string such as
            ``"true"``, ``"False"``, ``"1"``, ``"no"``.

    Returns:
        The boolean the token stands for.

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value

    token = str(value).strip().lower()
    if token in ('true', 't', 'yes', 'y', '1', 'on'):
        return True
    # The empty string is kept falsy, matching what bool('') used to give.
    if token in ('false', 'f', 'no', 'n', '0', 'off', ''):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


if __name__ == '__main__':
    # Script entry point: build per-mesh features from the trip and POI CSVs,
    # optionally fine-tune a BERT classifier on them, then (by default)
    # predict a level for every mesh cell and export the result as CSV/JSON.
    parser = argparse.ArgumentParser()

    root_path = os.path.dirname(os.path.abspath(__file__))
    parser.add_argument('--config', type=str, default=os.path.join(root_path, 'config/config.yaml'))
    parser.add_argument('--model', type=str, default=os.path.join(root_path, 'model'))
    # NOTE: the second path component is absolute, so os.path.join drops
    # root_path and the default resolves to /workspace/best_model_0706.pt.
    # An in-repo alternative used previously was check_point/model_best.pth.
    parser.add_argument('--weight', type=str, default=os.path.join(root_path, '/workspace/best_model_0706.pt'))

    parser.add_argument('--trick_path', type=str, default=os.path.join(root_path, 'datasets/trip2023.csv'))
    parser.add_argument('--poi_path', type=str, default=os.path.join(root_path, 'datasets/total_poi.csv'))
    parser.add_argument('--map_path', type=str, default=os.path.join(root_path, 'utils/map.jsonl'))
    parser.add_argument('--output_path', type=str, default=os.path.join(root_path, 'result'))

    parser.add_argument('--lowpw', type=int, default=35)

    # `type=bool` would treat "--train False" as True; str2bool parses the
    # text. nargs='?' additionally allows the bare flag form ("--train").
    parser.add_argument('--train', type=str2bool, nargs='?', const=True, default=False)
    parser.add_argument('--predict', type=str2bool, nargs='?', const=True, default=True)

    parser.add_argument('--num_labels', type=int, default=10)
    parser.add_argument('--max_length', type=int, default=128)
    parser.add_argument('--epochs', type=int, default=100)
    parser.add_argument('--batch_size', type=int, default=16)
    parser.add_argument('--lr', type=float, default=2e-5)

    args = parser.parse_args()

    # NOTE(review): yaml.Loader can construct arbitrary Python objects. Only
    # point --config at trusted files, or switch to yaml.safe_load if the
    # config does not rely on Python-specific tags.
    with open(args.config, 'r') as reader:
        config = yaml.load(reader, Loader=yaml.Loader)

    # load data
    trick = pd.read_csv(args.trick_path)
    poi_csv = pd.read_csv(args.poi_path)

    # set device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    meshfeature = CountMesh(args, config)

    result = meshfeature.mesh_pos(meshfeature.df_init(trick))
    poi_df = meshfeature.mesh_pos(poi_csv)

    trans_num_df = meshfeature.Trans(result)
    lowpower_num_df = meshfeature.LowPower(result)
    poi_num_df = meshfeature.POINum(poi_df)

    # Outer-join the three per-mesh tables on the mesh coordinates; cells
    # missing from one table get 0 for that table's columns.
    mesh_keys = ['Mesh_lon', 'Mesh_lat']
    result = pd.merge(
        trans_num_df,
        pd.merge(lowpower_num_df, poi_num_df, on=mesh_keys, how='outer'),
        on=mesh_keys,
        how='outer',
    )
    input_df_simple = result.fillna(0)

    # Debug dump of the model input; create the folder so a missing debug
    # directory does not abort the run.
    debug_input_path = '/workspace/debug/temp_val/input.csv'
    os.makedirs(os.path.dirname(debug_input_path), exist_ok=True)
    input_df_simple.to_csv(debug_input_path, index=False)

    tokenizer = AutoTokenizer.from_pretrained(args.model)

    # .copy() so adding columns below does not write into a slice of
    # input_df_simple (avoids pandas' SettingWithCopyWarning).
    res_df = input_df_simple[['TrafficNum', 'LowpowerNum', 'POI']].copy()
    # NOTE: the prompt wording (including spelling and spacing) is part of the
    # model input; keep it identical to what the checkpoint was trained on.
    res_df['prompt'] = ("The mash's TrafficNum is" + res_df['TrafficNum'].astype(str) + ', LowpowerNum is' + res_df['LowpowerNum'].astype(str) + ', POI is' + res_df['POI'].astype(int).astype(str))

    if args.train:
        # Labels are not part of the three feature columns selected above, so
        # they have to be taken from the merged mesh table explicitly.
        if 'level' not in input_df_simple.columns:
            raise ValueError(
                "Training requires a 'level' column in the merged mesh features, "
                "but none was produced by the feature pipeline."
            )

        # Shift labels to start at 0, as the classification head expects.
        labels = input_df_simple['level'].astype(int) - 1
        if labels.min() < 0 or labels.max() >= args.num_labels:
            raise ValueError(
                f"Labels after the -1 shift must lie in [0, {args.num_labels - 1}], "
                f"got range [{labels.min()}, {labels.max()}]; adjust --num_labels or the data."
            )
        res_df['level'] = labels

        train_data, val_data = train_test_split(res_df, test_size=0.1, random_state=42)

        train_inputs, train_masks = encode_data(tokenizer, train_data['prompt'], args.max_length)
        # Encode the held-out split (previously the training split was encoded twice).
        val_inputs, val_masks = encode_data(tokenizer, val_data['prompt'], args.max_length)

        train_labels = torch.tensor(train_data['level'].values)
        val_labels = torch.tensor(val_data['level'].values)

        train_dataset = TensorDataset(train_inputs, train_masks, train_labels)
        train_sampler = RandomSampler(train_dataset)
        train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.batch_size)

        # Validation uses the held-out tensors; no shuffling is needed for evaluation.
        val_dataset = TensorDataset(val_inputs, val_masks, val_labels)
        val_dataloader = DataLoader(val_dataset, batch_size=args.batch_size)

        # Use the same head size as the prediction branch below, so a
        # checkpoint produced here can be loaded there.
        model = BertForSequenceClassification.from_pretrained(
            args.model,
            num_labels=args.num_labels,
            output_attentions=False,
            output_hidden_states=False,
        )
        model.to(device)

        print('<==================train start==================>')
        train.train(args, model, train_dataloader, val_dataloader, device)

    if args.predict:
        print('<==================inference start==================>')
        data_inputs, data_masks = encode_data(tokenizer, res_df['prompt'], args.max_length)

        predict_dataset = TensorDataset(data_inputs, data_masks)
        predict_dataloader = DataLoader(predict_dataset, batch_size=args.batch_size)

        model = BertForSequenceClassification.from_pretrained(args.model, num_labels=args.num_labels)

        # One loading path for CPU and GPU: map_location places the tensors
        # on the selected device wherever the checkpoint was saved.
        state_dict = torch.load(args.weight, map_location=device)
        # Drop the position_ids buffer that some checkpoints carry but the
        # current model does not expect.
        state_dict.pop("bert.embeddings.position_ids", None)

        load_result = model.load_state_dict(state_dict, strict=False)
        # strict=False tolerates mismatches; report them so a partially
        # loaded model does not go unnoticed.
        if load_result.missing_keys:
            print(f"warning: keys missing from checkpoint: {load_result.missing_keys}")
        if load_result.unexpected_keys:
            print(f"warning: unexpected keys in checkpoint: {load_result.unexpected_keys}")

        model.to(device)
        model.eval()

        predictions = []
        with torch.no_grad():
            for batch in tqdm(predict_dataloader, desc='Testing'):
                b_input_ids, b_input_mask = (tensor.to(device) for tensor in batch)

                outputs = model(b_input_ids, attention_mask=b_input_mask)
                preds = torch.argmax(outputs.logits, dim=-1)

                predictions.extend(preds.cpu().numpy())

        # NOTE(review): predictions are the raw argmax class indices; the
        # training branch shifts 'level' by -1, so confirm whether consumers
        # expect the shift to be undone here.
        input_df_simple['level'] = predictions

        debug_output_path = '/workspace/debug/temp_val/input_data.csv'
        os.makedirs(os.path.dirname(debug_output_path), exist_ok=True)
        input_df_simple.to_csv(debug_output_path)

        decoder = Decodemesh(config)
        input_df_simple['longitude'], input_df_simple['latitude'] = decoder.getPosiotion(input_df_simple['Mesh_lon'], input_df_simple['Mesh_lat'])

        pos_decode = Posdecode(args.map_path).decode(input_df_simple)
        result_df = pos_decode[['longitude', 'latitude', 'TrafficNum', 'LowpowerNum', 'POI', 'address', 'level']]

        print('<==================save==================>')
        # Make sure the output folder exists before writing, so a finished
        # inference run is not lost to a missing directory.
        os.makedirs(args.output_path, exist_ok=True)
        csv_path = os.path.join(args.output_path, 'result.csv')
        json_path = os.path.join(args.output_path, 'result.json')

        result_df.to_csv(csv_path)
        result_df.to_json(json_path, orient='records', force_ascii=False)
        print(f"successful saved! \ncsv save to: {csv_path}\njson save to: {json_path}")

                
