from matplotlib import pyplot as plt
from sklearn.metrics.pairwise import cosine_similarity
from model import DCCM_MSIF
from torch import optim, nn
from tqdm import trange

from utils import make_adj, build_fold_similarities
import dgl
import copy
import numpy as np
import pandas as pd
import torch as th
from sklearn.metrics import (
    roc_auc_score, precision_recall_curve, auc,
    accuracy_score, precision_score, recall_score, f1_score, roc_curve
)
from sklearn.model_selection import KFold
import torch.nn.functional as F

# Run on the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = th.device("cuda:0" if th.cuda.is_available() else "cpu")
# Number of cross-validation folds used by train().
kfolds = 5


def print_met(list_):
    """Print six evaluation metrics on one line, each with four decimals.

    ``list_`` is read at positions 0-5, which are labelled (in order) AUC,
    AUPR, Accuracy, precision, recall and f1_score. The output is followed
    by one blank line.
    """
    labels = ('AUC', 'AUPR', 'Accuracy', 'precision', 'recall', 'f1_score')
    fields = ['%s ：%.4f ' % (label, list_[pos]) for pos, label in enumerate(labels)]
    # The last field carries its own newline, so print() leaves a blank line after it.
    fields[-1] += '\n'
    print(*fields)


def loss_contrastive_m(m1, m2):
    """Contrastive loss between two views (``m1`` and ``m2``) of the miRNA embeddings.

    Each view is divided by the norm of the whole matrix (not per row).
    The positive score of a row is its dot product with the same row of the
    other view; the negative score is the mean over the off-diagonal entries
    of both views' Gram matrices for that row. Returns the mean of
    ``softplus(negative - positive)`` over all rows as a scalar tensor.
    """
    view_a = m1 / th.norm(m1)
    view_b = m2 / th.norm(m2)

    # Positive score: similarity of the same row across the two views.
    positive = (view_a * view_b).sum(dim=1)

    # Negative scores: within-view similarities with self-similarity zeroed out.
    gram_a = view_a @ view_a.t()
    gram_b = view_b @ view_b.t()
    gram_a = gram_a - th.diag(th.diag(gram_a))
    gram_b = gram_b - th.diag(th.diag(gram_b))
    negative = th.cat((gram_a, gram_b), dim=1).mean(dim=1)

    return F.softplus(negative - positive).mean()


def loss_contrastive_d(d1, d2):
    """Contrastive loss between two views (``d1`` and ``d2``) of the disease embeddings.

    Both views are scaled by the norm of their whole matrix. Matching rows
    across the views give the positive scores, while the off-diagonal
    similarities inside each view give the negative scores. The result is
    the row-averaged ``softplus(negative - positive)`` as a scalar tensor.
    """

    def _off_diagonal_gram(view):
        # Pairwise row similarities within one view, with the diagonal removed.
        gram = th.matmul(view, view.t())
        return gram - th.diag_embed(th.diag(gram))

    scaled_1 = d1 / th.norm(d1)
    scaled_2 = d2 / th.norm(d2)

    cross_view = th.sum(scaled_1 * scaled_2, dim=1)
    within_view = th.cat(
        [_off_diagonal_gram(scaled_1), _off_diagonal_gram(scaled_2)], dim=1
    )
    margin = th.mean(within_view, dim=1) - cross_view
    return th.mean(F.softplus(margin))


def build_graphs_and_features(S_m, S_d, data, args, em_cos_m, em_cos_d):
    """Build the miRNA-disease graph and the feature tensors for one setting.

    The graph is built from ``data['train_md']``: column 0 is used as the
    miRNA node id and column 1, shifted by ``args.miRNA_number``, as the
    disease node id. Every pair is inserted in both directions.

    Returns:
        ``(md_graph, miRNA_th, disease_th, emb_miRNA_th, emb_disease_th)``
        where the first two tensors wrap ``S_m`` / ``S_d`` and the last two
        wrap ``S_m`` / ``S_d`` concatenated column-wise with ``em_cos_m`` /
        ``em_cos_d``.
    """
    n_mirna = args.miRNA_number

    # Work on a copy so data['train_md'] itself is not modified by the shift.
    pairs = copy.deepcopy(data['train_md'])
    pairs[:, 1] = pairs[:, 1] + n_mirna
    mirna_ids, disease_ids = pairs[:, 0], pairs[:, 1]

    # Each association contributes one edge per direction.
    edge_src = np.concatenate((mirna_ids, disease_ids))
    edge_dst = np.concatenate((disease_ids, mirna_ids))
    md_graph = dgl.graph(
        (edge_src, edge_dst),
        num_nodes=n_mirna + args.disease_number,
    )

    miRNA_th = th.Tensor(S_m)
    disease_th = th.Tensor(S_d)
    emb_miRNA_th = th.Tensor(np.concatenate((S_m, em_cos_m), axis=1))
    emb_disease_th = th.Tensor(np.concatenate((S_d, em_cos_d), axis=1))
    return md_graph, miRNA_th, disease_th, emb_miRNA_th, emb_disease_th


def train_eval_one_setting(model, optimizer, cross_entropy, md_graph,
                           miRNA_th, disease_th, emb_miRNA_th, emb_disease_th,
                           train_samples, valid_samples, epochs):
    """Train ``model`` for ``epochs`` full-batch steps, then score the validation samples.

    Args:
        model: called as ``model(md_graph, miRNA_th, disease_th, emb_miRNA_th,
            emb_disease_th, samples)`` and expected to return
            ``(score, m1, m2, d1, d2)``.
        optimizer: optimizer stepped once per epoch.
        cross_entropy: loss callable applied to the flattened training scores
            and the training labels.
        md_graph, miRNA_th, disease_th, emb_miRNA_th, emb_disease_th: model
            inputs, forwarded unchanged.
        train_samples: 2-D array whose column 2 holds the training labels.
        valid_samples: 2-D array whose column 2 holds the validation labels.
        epochs: number of training iterations.

    Returns:
        A 12-tuple ``(scores, auc, auprc, fpr, tpr, recalls, precisions,
        threshold, accuracy, precision, recall, f1)`` computed on
        ``valid_samples``.
    """
    # The labels do not change between epochs, so convert them and move them
    # to the device once instead of on every iteration.
    train_labels = th.Tensor(train_samples).float()[:, 2].to(device)

    for _ in trange(epochs, desc='train'):
        model.train()
        optimizer.zero_grad()

        train_score, m1, m2, d1, d2 = model(
            md_graph, miRNA_th, disease_th, emb_miRNA_th, emb_disease_th, train_samples
        )
        train_m_loss = loss_contrastive_m(m1, m2)
        train_d_loss = loss_contrastive_d(d1, d2)
        train_cross_loss = cross_entropy(th.flatten(train_score), train_labels)
        train_loss = train_cross_loss + train_d_loss + train_m_loss
        train_loss.backward()
        optimizer.step()

    model.eval()
    with th.no_grad():
        scoree, _, _, _, _ = model(
            md_graph, miRNA_th, disease_th, emb_miRNA_th, emb_disease_th, valid_samples
        )
    scoree = scoree.cpu().detach().numpy().ravel()

    sc_true = valid_samples[:, 2]
    fpr, tpr, thresholds = roc_curve(sc_true, scoree)
    aucc = roc_auc_score(sc_true, scoree)
    precisions, recalls, _ = precision_recall_curve(sc_true, scoree)
    auprc = auc(recalls, precisions)

    # The decision threshold maximises (tpr - fpr) on the validation ROC curve.
    # NOTE(review): it is chosen on the same labels the metrics below are
    # computed on, so those thresholded metrics are tuned to this fold.
    optimal_idx = np.argmax(tpr - fpr)
    optimal_threshold = thresholds[optimal_idx]
    pred = (scoree >= optimal_threshold).astype(int)
    accuracy = accuracy_score(sc_true, pred)
    precision1 = precision_score(sc_true, pred)
    recall1 = recall_score(sc_true, pred)
    f1 = f1_score(sc_true, pred)

    return (scoree, aucc, auprc, fpr, tpr, recalls, precisions,
            optimal_threshold, accuracy, precision1, recall1, f1)


def _fold_result_frame(valid_samples, scores, threshold, fold_name, setting):
    """Tabulate one setting's validation predictions for a single fold."""
    valid_samples = np.asarray(valid_samples)
    # Build the frame per column so numeric columns keep their dtype instead
    # of being coerced to strings alongside the text columns.
    return pd.DataFrame({
        'miRNA': valid_samples[:, 0],
        'Disease': valid_samples[:, 1],
        'True_Label': valid_samples[:, 2],
        'Pred_Score': scores,
        'Fold': fold_name,
        'Pred_label': (scores >= threshold).astype(int),
        'Setting': setting,
    })


def train(data, args):
    """Run k-fold cross-validation comparing two similarity settings per fold.

    For every fold two fresh ``DCCM_MSIF`` models are trained and evaluated on
    the same train/validation split:

    * setting A uses the similarities returned by
      ``build_fold_similarities(..., leakage_free=False)``;
    * setting B uses the similarities returned by
      ``build_fold_similarities(..., leakage_free=True)``.

    Side effects: prints per-fold and summary metrics, shows both figures and
    writes ``roc_curve.png``, ``pr.png``, ``combined_results.xlsx`` and
    ``leakage_control_results.xlsx`` to the working directory.

    Args:
        data: mapping read here at the keys ``'train_samples'``, ``'md'``,
            ``'ms'``, ``'ds'``, ``'emb_mm'`` and ``'emb_dd'``; it is also
            passed on to ``build_graphs_and_features``.
        args: settings object; ``miRNA_number``, ``disease_number``, ``wd``,
            ``lr`` and ``epochs`` are read here and the object is passed to
            ``DCCM_MSIF``.

    Returns:
        The validation scores of setting B from the last fold.
    """
    all_score = []
    fold_frames = []
    kf = KFold(n_splits=kfolds, shuffle=True, random_state=123)
    # Materialise the splits up front so they are fixed before any training.
    splits = list(kf.split(data['train_samples']))

    fig_roc, ax_roc = plt.subplots(figsize=(8, 6))
    ax_roc.set_title("ROC Curve")
    ax_roc.set_xlabel("False Positive Rate")
    ax_roc.set_ylabel("True Positive Rate")
    ax_roc.grid()

    fig_pr, ax_pr = plt.subplots(figsize=(8, 6))
    ax_pr.set_title("Precision-Recall Curve")
    ax_pr.set_xlabel("Recall")
    ax_pr.set_ylabel("Precision")
    ax_pr.grid()

    # Inputs that are identical for every fold.
    Y_full = make_adj(data['md'], (args.miRNA_number, args.disease_number)).numpy()
    base_ms = np.asarray(data['ms'])
    base_ds = np.asarray(data['ds'])
    em_cos_m = cosine_similarity(data['emb_mm'])
    em_cos_m = (em_cos_m + em_cos_m.T) / 2
    em_cos_d = cosine_similarity(data['emb_dd'])
    em_cos_d = (em_cos_d + em_cos_d.T) / 2

    # Per-fold metric differences (setting A minus setting B).
    delta_auc_list, delta_auprc_list = [], []

    cross_entropy = nn.BCELoss()
    scoree = None

    for i, (train_index, valid_index) in enumerate(splits):
        fold_train = data['train_samples'][train_index]
        fold_valid = data['train_samples'][valid_index]
        fold_name = f'Fold {i + 1}'
        print(f'################ Fold {i + 1} of {kfolds} ################')

        # Positive (label == 1) pairs of the validation fold, as int tuples.
        val_pos_pairs = [tuple(map(int, xy)) for xy in fold_valid[fold_valid[:, 2] == 1][:, :2]]

        # Setting A.
        S_m_A, S_d_A = build_fold_similarities(Y_full, base_ms, base_ds, val_pos_pairs, leakage_free=False)
        md_graph, miRNA_A, disease_A, emb_miRNA_A, emb_disease_A = \
            build_graphs_and_features(S_m_A, S_d_A, data, args, em_cos_m, em_cos_d)

        # Setting B. The graph returned here is discarded: it is derived only
        # from data['train_md'] and args, so setting A's graph is reused.
        S_m_B, S_d_B = build_fold_similarities(Y_full, base_ms, base_ds, val_pos_pairs, leakage_free=True)
        _, miRNA_B, disease_B, emb_miRNA_B, emb_disease_B = \
            build_graphs_and_features(S_m_B, S_d_B, data, args, em_cos_m, em_cos_d)

        model_A = DCCM_MSIF(args).to(device)
        optimizer_A = optim.AdamW(model_A.parameters(), weight_decay=args.wd, lr=args.lr)
        (scoree_A, aucc_A, auprc_A, fpr_A, tpr_A, recalls_A, precisions_A,
         thr_A, acc_A, pre_A, rec_A, f1_A) = train_eval_one_setting(
            model_A, optimizer_A, cross_entropy,
            md_graph, miRNA_A, disease_A, emb_miRNA_A, emb_disease_A,
            fold_train, fold_valid, args.epochs
        )

        model_B = DCCM_MSIF(args).to(device)
        optimizer_B = optim.AdamW(model_B.parameters(), weight_decay=args.wd, lr=args.lr)
        (scoree_B, aucc_B, auprc_B, fpr_B, tpr_B, recalls_B, precisions_B,
         thr_B, acc_B, pre_B, rec_B, f1_B) = train_eval_one_setting(
            model_B, optimizer_B, cross_entropy,
            md_graph, miRNA_B, disease_B, emb_miRNA_B, emb_disease_B,
            fold_train, fold_valid, args.epochs
        )

        print(f"[Fold {i + 1}] "
              f"AUC_A={aucc_A:.6f} AUPRC_A={auprc_A:.6f} | "
              f"AUC_B={aucc_B:.6f} AUPRC_B={auprc_B:.6f} | "
              f"ΔAUC={aucc_A - aucc_B:+.6f} ΔAUPRC={auprc_A - auprc_B:+.6f}")

        ax_roc.plot(fpr_A, tpr_A, label=f'Fold {i + 1} A (AUC={aucc_A:.4f})')
        ax_roc.plot(fpr_B, tpr_B, linestyle='--', label=f'Fold {i + 1} B (AUC={aucc_B:.4f})')
        ax_pr.plot(recalls_A, precisions_A, label=f'Fold {i + 1} A (AUPRC={auprc_A:.4f})')
        ax_pr.plot(recalls_B, precisions_B, linestyle='--', label=f'Fold {i + 1} B (AUPRC={auprc_B:.4f})')

        # Collect the per-sample predictions; they are concatenated once after the loop.
        fold_frames.append(_fold_result_frame(fold_valid, scoree_A, thr_A, fold_name, 'A'))
        fold_frames.append(_fold_result_frame(fold_valid, scoree_B, thr_B, fold_name, 'B'))

        all_score.append([aucc_A, auprc_A, acc_A, pre_A, rec_A, f1_A])
        all_score.append([aucc_B, auprc_B, acc_B, pre_B, rec_B, f1_B])

        delta_auc_list.append(aucc_A - aucc_B)
        delta_auprc_list.append(auprc_A - auprc_B)

        # The caller receives setting B's scores from the most recent fold.
        scoree = scoree_B

    ax_roc.legend(loc="lower right")
    fig_roc.tight_layout()
    fig_roc.savefig("roc_curve.png", dpi=600)
    fig_roc.show()

    ax_pr.legend(loc="best")
    fig_pr.tight_layout()
    fig_pr.savefig("pr.png", dpi=600)
    fig_pr.show()

    all_folds_results = pd.concat(fold_frames, ignore_index=True)
    all_folds_results.to_excel('combined_results.xlsx', index=False)

    # Column-wise mean over every (fold, setting) row, i.e. A and B pooled.
    cv_metric = np.mean(all_score, axis=0)
    print(f'################ {kfolds}-Fold Result (A & B together mean) ################')
    print_met(cv_metric)

    df_delta = pd.DataFrame({
        'Fold': [f'Fold {i + 1}' for i in range(kfolds)],
        'Delta_AUC (Orig - MaskedRecomp)': delta_auc_list,
        'Delta_AUPRC (Orig - MaskedRecomp)': delta_auprc_list
    })
    # Append summary rows below the per-fold rows.
    df_delta.loc[len(df_delta.index)] = ['Mean', np.mean(delta_auc_list), np.mean(delta_auprc_list)]
    df_delta.loc[len(df_delta.index)] = ['Std', np.std(delta_auc_list), np.std(delta_auprc_list)]
    df_delta.to_excel('leakage_control_results.xlsx', index=False)
    print("\n=== Leakage-free control summary ===")
    print(df_delta)

    return scoree
