# coding: UTF-8
import pandas as pd 
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn import metrics
import time
import os
from Utils.utils import get_time_dif
from tensorboardX import SummaryWriter
from Train_Test.test import *

def init_network(model, method='xavier', exclude='embedding', seed=123):
    """Re-initialise the parameters of ``model`` in place.

    Every parameter returned by ``model.named_parameters()`` whose name does
    not contain ``exclude`` is visited:

    * names containing ``'weight'`` are re-drawn with Xavier-normal
      (``method='xavier'``), Kaiming-normal (``method='kaiming'``) or a
      standard normal distribution (any other ``method``);
    * names containing ``'bias'`` are set to 0;
    * all other parameters are left untouched.

    Xavier and Kaiming initialisation need at least two dimensions to compute
    fan-in/fan-out and raise ``ValueError`` otherwise, so 1-D weights (for
    example the scale vector of a normalisation layer) keep their current
    values under those two methods instead of aborting the whole call.

    Args:
        model: Module whose parameters are initialised.
        method: ``'xavier'``, ``'kaiming'``, or anything else for plain
            normal initialisation.
        exclude: Substring; parameters whose name contains it are skipped.
        seed: Accepted for backward compatibility; not used by this function.
    """
    for name, param in model.named_parameters():
        if exclude in name:
            continue

        if 'weight' in name:
            if method in ('xavier', 'kaiming') and param.dim() < 2:
                # Fan-in/fan-out is undefined for 1-D tensors; keep the
                # layer's own default initialisation.
                continue
            if method == 'xavier':
                nn.init.xavier_normal_(param)
            elif method == 'kaiming':
                nn.init.kaiming_normal_(param)
            else:
                nn.init.normal_(param)
        elif 'bias' in name:
            nn.init.constant_(param, 0)


def train(config, model, train_iter, dev_iter, test_iter, dataGnn, class_weights):
    """Train ``model`` with Adam, validating after every batch, then test it.

    For each batch the class-weighted cross-entropy loss is optimised, the
    batch accuracy is computed, and ``evaluate`` is run on ``dev_iter``.
    Whenever the validation loss improves on the best seen so far, the model
    state dict is saved to ``config.save_path``. Training stops early once
    more than ``config.require_improvement`` batches pass without such an
    improvement.

    Per-batch and per-epoch metrics are written to
    ``<save_dir>/<model_name>_<run_number>_batch_history.csv`` and
    ``..._epoch_history.csv`` at the end of every epoch; scalars are also
    logged through ``SummaryWriter`` under ``config.log_path``. After training
    a final row with ``epoch == 'test'`` holding the test loss/accuracy is
    appended to the epoch history.

    Args:
        config: Object providing ``learning_rate``, ``num_epochs``,
            ``require_improvement``, ``log_path``, ``save_path``,
            ``save_dir``, ``model_name`` and optionally ``run_number``
            (defaults to 1).
        model: Module called as ``model(batch, dataGnn)``.
        train_iter: Iterable yielding ``(inputs, labels, mask)`` per batch;
            the mask is not used here.
        dev_iter: Validation iterable handed to ``evaluate``.
        test_iter: Test iterable handed to ``test``.
        dataGnn: Extra model input; moved to the model's device.
        class_weights: Tensor passed as ``weight`` to ``nn.CrossEntropyLoss``.

    Returns:
        ``(test_acc, test_loss, test_report, test_confusion)`` exactly as
        returned by ``test``.

    Raises:
        ValueError: If ``train_iter`` yields no batches during an epoch.
    """
    batch_history = {
        'run': [],
        'epoch': [],
        'batch': [],
        'train_loss': [],
        'train_acc': [],
    }
    epoch_history = {
        'run': [],
        'epoch': [],
        'avg_train_loss': [],
        'avg_train_acc': [],
        'val_loss': [],
        'val_acc': [],
    }

    start_time = time.time()
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)

    # Follow the device the model already lives on instead of assuming CUDA,
    # so the auxiliary tensors always match the model (GPU or CPU).
    first_param = next(model.parameters(), None)
    if first_param is not None:
        device = first_param.device
    else:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    dataGnn = dataGnn.to(device)
    class_weights = class_weights.to(device)
    criterion = nn.CrossEntropyLoss(weight=class_weights)

    run_number = getattr(config, 'run_number', 1)

    # Resolve the output locations up front: they are needed after the epoch
    # loop as well, even when the loop body never runs (num_epochs == 0).
    os.makedirs(config.save_dir, exist_ok=True)
    history_prefix = config.model_name + '_' + str(run_number)
    batch_history_path = os.path.join(config.save_dir, f"{history_prefix}_batch_history.csv")
    epoch_history_path = os.path.join(config.save_dir, f"{history_prefix}_epoch_history.csv")

    total_batch = 0
    dev_best_loss = float('inf')
    last_improve = 0  # value of total_batch at the last validation improvement
    stop_early = False

    writer = SummaryWriter(log_dir=config.log_path + '/' + time.strftime('%m-%d_%H.%M', time.localtime()))
    try:
        for epoch in range(config.num_epochs):
            print('Epoch [{}/{}]'.format(epoch + 1, config.num_epochs))

            epoch_train_loss = 0.0
            epoch_train_acc = 0.0
            batch_count = 0

            for batch_idx, (trains, labels, _mask) in enumerate(train_iter):
                outputs = model(trains, dataGnn)
                model.zero_grad()
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

                train_loss = loss.item()
                true = labels.detach().cpu()
                predic = torch.max(outputs.detach(), 1)[1].cpu()
                train_acc = metrics.accuracy_score(true, predic)

                batch_history['run'].append(run_number)
                batch_history['epoch'].append(epoch + 1)
                batch_history['batch'].append(batch_idx + 1)
                batch_history['train_loss'].append(train_loss)
                batch_history['train_acc'].append(train_acc)

                epoch_train_loss += train_loss
                epoch_train_acc += train_acc
                batch_count += 1

                # evaluate() switches the model to eval mode; it is put back
                # into train mode below before the next batch.
                dev_acc, dev_loss = evaluate(config, model, dev_iter, dataGnn)
                # Work with a plain float from here on (comparison, logging,
                # CSV) rather than a possibly device-resident 0-dim tensor.
                dev_loss = float(dev_loss)

                if dev_loss < dev_best_loss:
                    dev_best_loss = dev_loss
                    torch.save(model.state_dict(), config.save_path)
                    improve = '*'
                    last_improve = total_batch
                else:
                    improve = ''

                time_dif = get_time_dif(start_time)
                msg = 'Iter: {0:>6},  Train Loss: {1:>5.2},  Train Acc: {2:>6.2%},  Val Loss: {3:>5.2},  Val Acc: {4:>6.2%},  Time: {5} {6}'
                print(msg.format(total_batch, train_loss, train_acc, dev_loss, dev_acc, time_dif, improve))

                writer.add_scalar("loss/train", train_loss, total_batch)
                writer.add_scalar("loss/dev", dev_loss, total_batch)
                writer.add_scalar("acc/train", train_acc, total_batch)
                writer.add_scalar("acc/dev", dev_acc, total_batch)

                model.train()
                total_batch += 1

                if total_batch - last_improve > config.require_improvement:
                    print("No optimization for a long time, auto-stopping...")
                    stop_early = True
                    break

            if batch_count == 0:
                # Without at least one batch there are no averages and no
                # validation result to report for this epoch.
                raise ValueError(
                    'train_iter yielded no batches in epoch {}'.format(epoch + 1)
                )

            avg_train_loss = epoch_train_loss / batch_count
            avg_train_acc = epoch_train_acc / batch_count

            # The validation figures recorded per epoch are those of the last
            # batch processed in that epoch.
            epoch_history['run'].append(run_number)
            epoch_history['epoch'].append(epoch + 1)
            epoch_history['avg_train_loss'].append(avg_train_loss)
            epoch_history['avg_train_acc'].append(avg_train_acc)
            epoch_history['val_loss'].append(dev_loss)
            epoch_history['val_acc'].append(dev_acc)

            print(f"Epoch {epoch+1} Summary - Avg Train Loss: {avg_train_loss:.4f}, Avg Train Acc: {avg_train_acc:.4f}, Val Loss: {dev_loss:.4f}, Val Acc: {dev_acc:.4f}")

            # Rewrite both CSVs every epoch so partial results survive an
            # interrupted run.
            pd.DataFrame(batch_history).to_csv(batch_history_path, index=False)
            pd.DataFrame(epoch_history).to_csv(epoch_history_path, index=False)

            if stop_early:
                break
    finally:
        # Flush and release the event file even if training raised.
        writer.close()

    # NOTE(review): `test` is not defined in this file; presumably it comes
    # from the `from Train_Test.test import *` import - confirm.
    test_acc, test_loss, test_report, test_confusion = test(config, model, test_iter, dataGnn)

    epoch_history['run'].append(run_number)
    epoch_history['epoch'].append('test')
    epoch_history['avg_train_loss'].append(None)
    epoch_history['avg_train_acc'].append(None)
    epoch_history['val_loss'].append(float(test_loss))
    epoch_history['val_acc'].append(test_acc)
    pd.DataFrame(epoch_history).to_csv(epoch_history_path, index=False)

    print("Training completed")
    return test_acc, test_loss, test_report, test_confusion


def evaluate(config, model, data_iter, dataGnn, test=False):
    """Compute accuracy and mean loss of ``model`` over ``data_iter``.

    The model is switched to eval mode and run under ``torch.no_grad()``.
    The reported loss is the unweighted ``F.cross_entropy`` of each batch,
    averaged over the number of batches (not over samples), so it does not
    use the class weights applied by the training criterion.

    Args:
        config: Only ``config.class_list`` is read, and only when ``test`` is
            true (used as the class names of the report).
        model: Module called as ``model(batch, dataGnn)``.
        data_iter: Iterable yielding ``(inputs, labels, mask)`` per batch; the
            mask is not used here.
        dataGnn: Extra model input, forwarded unchanged.
        test: When true, also build a classification report and a confusion
            matrix.

    Returns:
        ``(acc, mean_loss)``, or ``(acc, mean_loss, report, confusion)`` when
        ``test`` is true. ``mean_loss`` is a 0-dim tensor.

    Raises:
        ValueError: If ``data_iter`` yields no batches.
    """
    model.eval()
    loss_total = 0
    batch_count = 0
    # Collect per-batch arrays and join them once at the end; appending to a
    # growing array inside the loop recopies everything on every batch. The
    # leading empty int array keeps the result dtype the same as before.
    label_chunks = [np.array([], dtype=int)]
    predict_chunks = [np.array([], dtype=int)]

    with torch.no_grad():
        for texts, labels, _mask in data_iter:
            outputs = model(texts, dataGnn)
            loss_total += F.cross_entropy(outputs, labels)
            label_chunks.append(labels.detach().cpu().numpy().ravel())
            predict_chunks.append(torch.max(outputs.detach(), 1)[1].cpu().numpy().ravel())
            batch_count += 1

    if batch_count == 0:
        raise ValueError('data_iter yielded no batches; nothing to evaluate')

    labels_all = np.concatenate(label_chunks)
    predict_all = np.concatenate(predict_chunks)
    mean_loss = loss_total / batch_count

    acc = metrics.accuracy_score(labels_all, predict_all)
    if test:
        report = metrics.classification_report(labels_all, predict_all, target_names=config.class_list, digits=4)
        confusion = metrics.confusion_matrix(labels_all, predict_all)
        return acc, mean_loss, report, confusion
    return acc, mean_loss