# coding: UTF-8
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn import metrics
import time
import os
from Utils.utils import get_time_dif
from tensorboardX import SummaryWriter
import tqdm

def test(config, model, test_iter, dataGnn):
    """Load the saved checkpoint, evaluate it on ``test_iter`` and report the results.

    The metrics are printed to stdout and also written to
    ``<config.save_dir>/<config.model_name>_test_report.txt``.

    Args:
        config: Configuration object. Reads ``save_path`` (checkpoint file),
            ``save_dir`` and ``model_name`` (report location); it is also
            forwarded to ``evaluate``.
        model: Module whose ``state_dict`` is restored from ``config.save_path``.
            It is not moved to a device here.
            NOTE(review): presumably the caller has already placed it on the
            same device as the batches in ``test_iter`` -- confirm.
        test_iter: Iterable of batches, forwarded to ``evaluate``.
        dataGnn: Object exposing ``.to(device)``; moved to the selected device
            and forwarded to ``evaluate``.

    Returns:
        Tuple ``(test_acc, test_loss, test_report, test_confusion)`` exactly as
        returned by ``evaluate(..., test=True)``.
    """
    # Prefer CUDA, but fall back to CPU so the function also runs on hosts
    # without a GPU instead of failing on the very first line.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    dataGnn = dataGnn.to(device)

    # map_location lets a checkpoint saved on a GPU be restored on a CPU-only host.
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    model.load_state_dict(torch.load(config.save_path, map_location=device))
    model.eval()
    start_time = time.time()
    test_acc, test_loss, test_report, test_confusion = evaluate(config, model, test_iter, dataGnn, test=True)

    # float() accepts both a 0-dim tensor and a plain Python number, so the
    # reporting below does not depend on the exact type evaluate() returns.
    loss_value = float(test_loss)

    msg = 'Test Loss: {0:>5.2},  Test Acc: {1:>6.2%}'
    print(msg.format(loss_value, test_acc))
    print("Precision, Recall and F1-Score...")
    print(test_report)
    print("Confusion Matrix...")
    print(test_confusion)

    time_dif = get_time_dif(start_time)
    print("Time usage:", time_dif)

    # Make sure the target directory exists before writing the report.
    if config.save_dir:
        os.makedirs(config.save_dir, exist_ok=True)
    report_path = os.path.join(config.save_dir, f"{config.model_name}_test_report.txt")
    # Explicit UTF-8: the report embeds the class names, which may be non-ASCII.
    with open(report_path, "w", encoding="utf-8") as f:
        f.write(f"Test Loss: {loss_value:.4f}\n")
        f.write(f"Test Accuracy: {test_acc:.4f}\n\n")
        f.write("Classification Report:\n")
        f.write(test_report)
        f.write("\n\nConfusion Matrix:\n")
        f.write(str(test_confusion))

    return test_acc, test_loss, test_report, test_confusion

def evaluate(config, model, data_iter, dataGnn, test=False):
    """Score ``model`` on every batch of ``data_iter`` with gradients disabled.

    The model is switched to eval mode and is left in that mode on return.

    Args:
        config: Configuration object. Only ``config.class_list`` is read, and
            only when ``test`` is true (it supplies the report's class names).
        model: Module called as ``model(texts, dataGnn)``. Its output is passed
            to ``F.cross_entropy`` and arg-maxed over dim 1 to get predictions.
        data_iter: Iterable yielding ``(texts, labels, mask)`` batches; the
            mask is not used here.
        dataGnn: Forwarded unchanged as the second argument of every model call.
        test: When true, additionally build a classification report and a
            confusion matrix.

    Returns:
        ``(acc, loss)``, or ``(acc, loss, report, confusion)`` when ``test`` is
        true. ``loss`` is the unweighted mean of the per-batch cross-entropy
        values (each batch counts equally regardless of its size) and is
        returned as a tensor, so callers may keep using ``loss.item()``.

    Raises:
        ValueError: If ``data_iter`` yields no batches.
    """
    model.eval()
    loss_total = 0
    label_chunks = []
    pred_chunks = []
    batch_count = 0

    with torch.no_grad():
        for texts, labels, _mask in data_iter:
            outputs = model(texts, dataGnn)
            loss_total += F.cross_entropy(outputs, labels)
            # Collect per-batch arrays and join them once after the loop;
            # np.append inside the loop would re-copy everything each batch.
            label_chunks.append(labels.detach().cpu().numpy().ravel())
            pred_chunks.append(outputs.argmax(dim=1).cpu().numpy().ravel())
            batch_count += 1

    if batch_count == 0:
        # Without this guard the function would end in an opaque 0 / 0 error.
        raise ValueError("evaluate() received an empty data_iter; nothing to score")

    # The leading empty int array keeps the same dtype promotion as starting
    # from np.array([], dtype=int) and appending batch by batch.
    empty = np.array([], dtype=int)
    labels_all = np.concatenate([empty] + label_chunks)
    predict_all = np.concatenate([empty] + pred_chunks)

    acc = metrics.accuracy_score(labels_all, predict_all)
    mean_loss = loss_total / batch_count
    if not test:
        return acc, mean_loss

    # Pass explicit class ids so the report does not raise (and the confusion
    # matrix keeps one row/column per class) when some class never occurs in
    # this split. This assumes ids 0..len(class_list)-1; if an observed id
    # falls outside that range, fall back to sklearn's inferred label set.
    class_ids = np.arange(len(config.class_list))
    observed = np.union1d(labels_all, predict_all)
    report_labels = class_ids if np.isin(observed, class_ids).all() else None

    report = metrics.classification_report(
        labels_all,
        predict_all,
        labels=report_labels,
        target_names=config.class_list,
        digits=4,
    )
    confusion = metrics.confusion_matrix(labels_all, predict_all, labels=report_labels)
    return acc, mean_loss, report, confusion