import random
import numpy as np
import torch as th

def get_edge_index(matrix):
    """Return the coordinates of the non-zero entries of a 2-D matrix.

    Args:
        matrix: 2-D NumPy array or torch tensor (a tensor on any device is
            moved to host memory for the search); other array-likes accepted
            by ``np.asarray`` also work.

    Returns:
        CPU ``torch.long`` tensor of shape ``(2, nnz)``. Row 0 holds the row
        indices and row 1 the column indices. Entries are ordered row-major,
        the same order a nested row/column loop would visit them.
    """
    if th.is_tensor(matrix):
        # Compare on the tensor's own device, then bring the boolean mask to
        # host memory so NumPy can locate the non-zeros.
        mask = (matrix != 0).cpu().numpy()
    else:
        mask = np.asarray(matrix) != 0
    # np.nonzero returns indices in row-major (C) order.
    rows, cols = np.nonzero(mask)
    return th.as_tensor(np.stack((rows, cols)), dtype=th.long)


def make_adj(edges, size):
    """Build a dense integer count matrix from a collection of index pairs.

    Args:
        edges: sequence or array of ``(row, col)`` pairs. An empty sequence is
            accepted. So is a single pair given as a flat length-2 array, which
            is what ``np.loadtxt`` returns for a one-line file.
        size: ``(n_rows, n_cols)`` of the result.

    Returns:
        ``torch.long`` tensor of shape ``size``. Entry ``(i, j)`` counts how
        many times the pair ``(i, j)`` occurs in ``edges``; duplicates add up.
    """
    # reshape(-1, 2) normalises the empty and single-pair inputs to (k, 2);
    # the old `.t()` on a 1-D tensor left the indices with the wrong rank.
    pairs = th.as_tensor(np.asarray(edges), dtype=th.long).reshape(-1, 2)
    values = th.ones(pairs.shape[0], dtype=th.long)
    # sparse_coo_tensor replaces the deprecated th.sparse.LongTensor legacy
    # constructor; to_dense() sums duplicate coordinates.
    adj = th.sparse_coo_tensor(pairs.t(), values, size).to_dense().long()
    return adj

def data_processing(data, args):
    """Build the labelled training-sample table from the known associations.

    The adjacency is built from ``data['md']`` with shape
    ``(args.miRNA_number, args.disease_number)``. Every cell with a count
    >= 1 is a positive; every other cell is a negative candidate.

    Both lists are shuffled after seeding the global ``random`` module with
    ``args.random_seed``. If ``args.negative_rate == -1`` all negatives are
    kept. Otherwise only the first ``int(negative_rate * n_positive)``
    shuffled negatives are kept, and the rest become ``unsamples``.

    Writes into ``data`` (returns None):
        train_samples: int array ``(n, 3)`` of ``[row, col, label]`` rows,
            positives first.
        train_md: int array ``(n_pos, 2)`` of the positive pairs.
        unsamples: array of the negative pairs not selected (empty when
            ``negative_rate == -1``).
    """
    md_matrix = make_adj(data['md'], (args.miRNA_number, args.disease_number))
    adj = np.asarray(md_matrix)
    positive = adj >= 1
    # np.argwhere walks the matrix in row-major order, exactly like a nested
    # i/j loop. The lists therefore match the loop version element for
    # element, and the seeded shuffles below yield the same permutation.
    one_index = np.argwhere(positive).tolist()
    zero_index = np.argwhere(~positive).tolist()

    random.seed(args.random_seed)
    random.shuffle(one_index)
    random.shuffle(zero_index)

    unsamples = []
    if args.negative_rate != -1:
        n_negative = int(args.negative_rate * len(one_index))
        unsamples = zero_index[n_negative:]
        zero_index = zero_index[:n_negative]

    # reshape keeps a (0, 2) shape when there are no pairs at all, so the
    # concatenate below cannot fail on a 1-D empty array.
    index = np.array(one_index + zero_index, dtype=int).reshape(-1, 2)
    label = np.array([1] * len(one_index) + [0] * len(zero_index), dtype=int)
    samples = np.concatenate((index, label[:, None]), axis=1)

    data['train_samples'] = samples
    data['train_md'] = samples[samples[:, 2] == 1, :2]
    data['unsamples'] = np.array(unsamples)

def get_data(args):
    """Load similarity matrices, embeddings, entity names and known associations.

    Each file name is appended directly to ``args.data_dir``, so the directory
    string must already end with a path separator.

    Returns:
        dict with keys ``miRNA_number``, ``disease_number``, ``ms``, ``ds``,
        ``emb_mm_number``, ``emb_dd_number``, ``emb_mm``, ``emb_dd``,
        ``d_num``, ``m_num`` and ``md``. ``md`` is the association file's
        integer table with 1 subtracted from every entry.
    """
    root = args.data_dir
    ms = np.loadtxt(root + 'ms.txt', dtype=float)
    ds = np.loadtxt(root + 'ds.txt', dtype=float)
    data = {
        'miRNA_number': int(ms.shape[0]),
        'disease_number': int(ds.shape[0]),
        'ms': ms,
        'ds': ds,
    }

    # NOTE(review): the embeddings are truncated to fixed row counts (901 and
    # 877) rather than to the ms/ds sizes; presumably these equal the dataset's
    # miRNA/disease counts. Confirm before switching datasets.
    emb_mm = np.loadtxt(root + 'miRNA_embedding.txt', dtype=float)[:901]
    emb_dd = np.loadtxt(root + 'disease_embedding.txt', dtype=float)[:877]
    data.update({
        'emb_mm_number': int(emb_mm.shape[0]),
        'emb_dd_number': int(emb_dd.shape[0]),
        'emb_mm': emb_mm,
        'emb_dd': emb_dd,
    })

    def _second_column(file_name):
        # Tab-separated text table; keep only column 1 as strings.
        return np.loadtxt(root + file_name, delimiter='\t', dtype=str)[:, 1]

    data['d_num'] = _second_column('disease number.txt')
    data['m_num'] = _second_column('miRNA number.txt')
    # Subtract 1 from every id; presumably converts 1-based ids to 0-based
    # indices (verify against the data files).
    data['md'] = np.loadtxt(root + 'known disease-miRNA association number.txt', dtype=int) - 1
    return data

def mask_validation_edges(Y_full, val_pos_pairs):
    """Return a copy of ``Y_full`` with every ``(row, col)`` in ``val_pos_pairs`` zeroed.

    ``Y_full`` itself is not modified; ``.copy()`` is taken first.
    """
    masked = Y_full.copy()
    for pair in val_pos_pairs:
        row, col = pair
        masked[row, col] = 0
    return masked

def _compute_gamma(X):
    """Default GIP bandwidth: the reciprocal of the mean squared row norm of ``X``.

    Falls back to 1.0 whenever that mean is not positive, for example for an
    all-zero ``X``, or an empty one whose mean is NaN.
    """
    mean_sq_norm = np.mean(np.square(X).sum(axis=1))
    if not mean_sq_norm > 0:
        return 1.0
    return 1.0 / mean_sq_norm

def compute_gip_similarity(Y, axis='miRNA', gamma=None):
    """Gaussian interaction-profile (GIP) kernel over the rows or columns of ``Y``.

    Args:
        Y: 2-D association matrix (array-like). Non-float input, including
            bool, is cast to float so that dot products are real sums.
        axis: ``'miRNA'`` uses the rows of ``Y`` as profiles; any other value
            uses its columns (``Y.T``).
        gamma: kernel bandwidth. If None, it is taken from ``_compute_gamma``
            on the profiles.

    Returns:
        ``(n, n)`` array with ``S[i, j] = exp(-gamma * ||x_i - x_j||^2)``.
        A profile that is entirely zero gets similarity 0 to every other
        profile and 1 to itself.
    """
    Y = np.asarray(Y)
    if not np.issubdtype(Y.dtype, np.floating):
        # bool matmul saturates at True and would give wrong distances.
        Y = Y.astype(float)
    X = Y if axis == 'miRNA' else Y.T  # (n, d): one profile per row
    if gamma is None:
        gamma = _compute_gamma(X)
    sq = np.sum(X * X, axis=1, keepdims=True)  # (n, 1) squared norms
    # ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b; clip tiny negatives from rounding.
    dist2 = np.maximum(sq + sq.T - 2.0 * (X @ X.T), 0.0)  # (n, n)
    # A profile's distance to itself is exactly 0; don't let rounding in the
    # expansion above push the diagonal below 1.
    np.fill_diagonal(dist2, 0.0)
    S = np.exp(-gamma * dist2)

    # Empty profiles are those with zero norm. A zero *sum* is not enough,
    # because signed entries can cancel out.
    zero_rows = np.flatnonzero(sq[:, 0] == 0)
    if zero_rows.size:
        S[zero_rows, :] = 0.0
        S[:, zero_rows] = 0.0
        S[zero_rows, zero_rows] = 1.0  # paired fancy index -> diagonal cells
    return S

def fuse_with_base_similarity(base_sim, gip_sim):
    """Take ``base_sim`` where it is non-zero, otherwise fall back to ``gip_sim``.

    Args:
        base_sim: base similarity matrix (array-like).
        gip_sim: GIP similarity matrix, broadcast-compatible with ``base_sim``.

    Returns:
        Element-wise fused array. Its dtype is the common float type of the
        inputs, at least float32, matching what the previous arithmetic blend
        produced.
    """
    base_sim = np.asarray(base_sim)
    gip_sim = np.asarray(gip_sim)
    out_dtype = np.result_type(base_sim, gip_sim, np.float32)
    # Select rather than blend arithmetically. With `base*m + gip*(1-m)` a NaN
    # or inf in the unused matrix leaks through, since nan * 0 is still nan.
    return np.where(base_sim != 0, base_sim, gip_sim).astype(out_dtype, copy=False)

def build_fold_similarities(Y_full, base_ms, base_ds, val_pos_pairs, leakage_free=True):
    """Build the fused miRNA and disease similarity matrices for one CV fold.

    When ``leakage_free`` is true, the ``val_pos_pairs`` entries of ``Y_full``
    are zeroed before the GIP kernels are computed. Each base similarity is
    then fused with its GIP kernel (base values win where non-zero),
    symmetrised as ``(S + S.T) / 2`` and given a diagonal of exactly 1.

    Returns:
        ``(S_m, S_d)``: the miRNA and disease similarity matrices.
    """
    if leakage_free:
        Y_for_gip = mask_validation_edges(Y_full, val_pos_pairs)
    else:
        Y_for_gip = Y_full

    gip_kernels = [compute_gip_similarity(Y_for_gip, axis=side) for side in ('miRNA', 'disease')]
    fused = [
        fuse_with_base_similarity(base, kernel)
        for base, kernel in zip((base_ms, base_ds), gip_kernels)
    ]

    S_m, S_d = [(sim + sim.T) / 2.0 for sim in fused]
    for sim in (S_m, S_d):
        np.fill_diagonal(sim, 1.0)
    return S_m, S_d


