# coding: UTF-8
import argparse
import os
import random
import time
from importlib import import_module

import numpy as np
import pandas as pd
import torch

from Train_Test.train import train, init_network
from Utils.utils import generate_dataset, build_iterator, get_time_dif
from makedata import makedata

def save_results(model_name, results, results_dir):
    """Write a summary of per-run test accuracies to a text file.

    The file is ``<results_dir>/<model_name>_results.txt`` and contains the raw
    list of results, the mean ± standard deviation (``np.std`` default, i.e.
    population std with ddof=0) and the number of runs.

    Args:
        model_name: Prefix used for the file name and for each summary line.
        results: Sequence of per-run accuracies (list or 1-D array).
        results_dir: Existing directory the summary file is written into.
    """
    filename = os.path.join(results_dir, f"{model_name}_results.txt")
    # len() rather than truthiness so a NumPy array is accepted as well.
    if len(results) > 0:
        summary = f"{np.mean(results):.4f} ± {np.std(results):.4f}"
    else:
        # np.mean/np.std of an empty sequence warn and return nan; say so plainly.
        summary = "N/A"
    # Explicit UTF-8: the summary contains the non-ASCII "±", which the
    # platform default encoding may not represent (or may encode differently).
    with open(filename, "w", encoding="utf-8") as f:
        f.write(f"{model_name} Results: {results}\n")
        f.write(f"{model_name} Mean Accuracy: {summary}\n")
        f.write(f"run: {len(results)}\n\n")

parser = argparse.ArgumentParser(description='GeoDFnet')

# Defaults mirror what the script effectively did before these flags were
# actually parsed: 10 runs, results written to the model config's save_dir.
parser.add_argument('--runs', type=int, default=10, help='Number of runs for statistical significance testing')
parser.add_argument('--results_dir', type=str, default=None,
                    help="Directory to save results (default: the model config's save_dir)")


def _set_seed(seed):
    """Seed every RNG the pipeline may use so that a run is reproducible.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices), and
    puts cuDNN in deterministic mode with autotuning disabled; benchmark mode
    may select different kernels between runs, defeating the fixed seed.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def _combine_histories(history_files, output_path):
    """Concatenate per-run history CSVs into a single CSV file.

    Args:
        history_files: CSV paths ordered by run; position ``i`` (0-based) holds
            run ``i + 1``.
        output_path: Destination of the combined CSV.

    Returns:
        The combined DataFrame, or ``None`` when none of the files exist (in
        which case nothing is written).

    Missing files are skipped with a warning instead of aborting the summary
    after all training has finished. A ``run`` column is prepended (unless the
    CSVs already have one) so rows stay attributable after concatenation.
    """
    frames = []
    for run_number, path in enumerate(history_files, start=1):
        if not os.path.exists(path):
            print(f"Warning: history file not found, skipping: {path}")
            continue
        frame = pd.read_csv(path)
        if 'run' not in frame.columns:
            frame.insert(0, 'run', run_number)
        frames.append(frame)
    if not frames:
        return None
    combined = pd.concat(frames, ignore_index=True)
    combined.to_csv(output_path, index=False)
    return combined


if __name__ == '__main__':
    args = parser.parse_args()
    if args.runs < 1:
        parser.error('--runs must be at least 1')

    embedding = 'embedding_SougouNews.npz'
    model_name = 'GeoDFnet'
    num_runs = args.runs
    makedata()
    x = import_module('Models.' + model_name)
    config = x.Config(embedding)

    results_dir = config.save_dir
    if args.results_dir is not None:
        # Write the override back so train() and this script agree on one
        # directory: the history CSVs combined below are expected to be written
        # by train() under config.save_dir. NOTE(review): other paths the Config
        # derived from save_dir at construction time are not updated — confirm.
        results_dir = args.results_dir
        config.save_dir = results_dir
    config.model_name = model_name

    os.makedirs(results_dir, exist_ok=True)

    all_accuracies = []
    batch_histories = []
    epoch_histories = []

    for run in range(num_runs):
        run_number = run + 1
        print(f"\n{'='*40}")
        print(f"Run {run_number}/{num_runs}")
        print(f"{'='*40}")

        # A distinct but fixed seed per run: runs differ, yet each is repeatable.
        _set_seed(42 + run)
        config.run_number = run_number

        start_time = time.time()
        print("---------LOADING Datas---------")
        (vocab, train_data, val_data, test_data,
         train_mask, dev_mask, test_mask, dataGnn, class_weights) = generate_dataset(config)
        train_iter = build_iterator(train_data, train_mask, config)
        dev_iter = build_iterator(val_data, dev_mask, config)
        test_iter = build_iterator(test_data, test_mask, config)
        time_dif = get_time_dif(start_time)
        print("Time usage:", time_dif)

        # train
        config.n_vocab = len(vocab)
        model = x.Model(config).to(config.device)
        if model_name != 'GeoDFnet':
            init_network(model)
        print(model.parameters)

        test_acc, test_loss, test_report, test_confusion = train(
            config, model, train_iter, dev_iter, test_iter, dataGnn, class_weights)

        all_accuracies.append(test_acc)

        report_filename = os.path.join(results_dir, f"{model_name}_run{run_number}_report.txt")
        # UTF-8 explicitly: the report may contain non-ASCII class names.
        with open(report_filename, "w", encoding="utf-8") as f:
            f.write(test_report)

        batch_histories.append(os.path.join(results_dir, f"{model_name}_{run_number}_batch_history.csv"))
        epoch_histories.append(os.path.join(results_dir, f"{model_name}_{run_number}_epoch_history.csv"))

    save_results(model_name, all_accuracies, results_dir)

    mean_acc = np.mean(all_accuracies)
    std_acc = np.std(all_accuracies)

    print(f"\n{'='*50}")
    print(f"{model_name} Final Results after {num_runs} runs")
    print(f"{'='*50}")
    print(f"Mean Accuracy: {mean_acc:.4f} ± {std_acc:.4f}")

    _combine_histories(batch_histories,
                       os.path.join(results_dir, f"{model_name}_combined_batch_history.csv"))
    _combine_histories(epoch_histories,
                       os.path.join(results_dir, f"{model_name}_combined_epoch_history.csv"))

    print(f"Training history is saved to: {results_dir}")