

import numpy as np

import torch
import torch.nn as nn
import torch.nn.utils.rnn as rnn

import dgl

# from module.GAT import GAT, GAT_ffn
from module.Encoder import sentEncoder
from module.GAT import WSWGAT
from module.PositionEmbedding import get_sinusoid_encoding_table
from GAMLPmodel import R_GAMLP, JK_GAMLP, NARS_JK_GAMLP, NARS_R_GAMLP, R_GAMLP_RLU, JK_GAMLP_RLU, NARS_JK_GAMLP_RLU, NARS_R_GAMLP_RLU
from layer import *  # star import; FeedForwardNetII used below is expected to come from here
from module.GATLayer import PositionwiseFeedForward


class danGraph(nn.Module):
    """Without sent-to-sent edges; residual connections are added.
    Adds a second supernode level (dtype=2), updated through the sent2T / T2sent passes."""
    def __init__(self, hps, embed):
        """

        :param hps: 
        :param embed: word embedding
        """
        super().__init__()

        self._hps = hps
        self._n_iter = hps.n_iter
        self._embed = embed
        self.embed_size = hps.word_emb_dim


        # sent node feature
        self._init_sn_param()
        self._TFembed = nn.Embedding(10, hps.feat_embed_size)   # box=10
        self.n_feature_proj = nn.Linear(hps.n_feature_size * 2, hps.hidden_size, bias=False)

        # word -> sent
        embed_size = hps.word_emb_dim
        self.word2sent = WSWGAT(in_dim=embed_size,
                                out_dim=hps.hidden_size,
                                num_heads=hps.n_head,
                                attn_drop_out=hps.atten_dropout_prob,
                                ffn_inner_hidden_size=hps.ffn_inner_hidden_size,
                                ffn_drop_out=hps.ffn_dropout_prob,
                                feat_embed_size=hps.feat_embed_size,
                                layerType="W2S"
                                )
        self.wdffn = PositionwiseFeedForward(hps.hidden_size, hps.ffn_inner_hidden_size, hps.ffn_dropout_prob)
        # sent -> word
        self.sent2word = WSWGAT(in_dim=hps.hidden_size,
                                out_dim=embed_size,
                                num_heads=6,
                                attn_drop_out=hps.atten_dropout_prob,
                                ffn_inner_hidden_size=hps.ffn_inner_hidden_size,
                                ffn_drop_out=hps.ffn_dropout_prob,
                                feat_embed_size=hps.feat_embed_size,
                                layerType="S2W"
                                )
        
        self.dwffn = PositionwiseFeedForward(embed_size, hps.ffn_inner_hidden_size, hps.ffn_dropout_prob)
        
        self.sent2T = WSWGAT(in_dim=hps.hidden_size,
                                out_dim=hps.hidden_size,
                                num_heads=8,
                                attn_drop_out=hps.atten_dropout_prob,
                                ffn_inner_hidden_size=hps.ffn_inner_hidden_size,
                                ffn_drop_out=hps.ffn_dropout_prob,
                                feat_embed_size=hps.feat_embed_size,
                                layerType="S2T"
                                )
        self.stffn = PositionwiseFeedForward(hps.hidden_size, hps.ffn_inner_hidden_size, hps.ffn_dropout_prob)        
        self.T2sent = WSWGAT(in_dim=hps.hidden_size,
                                out_dim=hps.hidden_size,
                                num_heads=8,
                                attn_drop_out=hps.atten_dropout_prob,
                                ffn_inner_hidden_size=hps.ffn_inner_hidden_size,
                                ffn_drop_out=hps.ffn_dropout_prob,
                                feat_embed_size=hps.feat_embed_size,
                                layerType="T2S"
                                )
        self.tsffn = PositionwiseFeedForward(hps.hidden_size, hps.ffn_inner_hidden_size, hps.ffn_dropout_prob) 
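        # Summary of the four WSWGAT passes wired above (each followed by its
        # own position-wise FFN):
        #   word2sent (W2S): word embeddings -> sentence hidden states
        #   sent2word (S2W): sentence hidden states -> word states
        #   sent2T    (S2T): sentence hidden states -> dtype=2 supernode states
        #   T2sent    (T2S): dtype=2 supernode states -> sentence hidden states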
        # node classification
        self.n_feature = hps.hidden_size
        self.wh = nn.Linear(self.n_feature, 2)
        
        # GAMLP-style components (cf. JK_GAMLP); the full GAMLP model itself is
        # not instantiated here, only the jumping-knowledge style layers below.
        # word-level GAMLP layers (dimension 300 = word_emb_dim assumed)
        self.lr_jk_ref1 = FeedForwardNetII(300, 300, 300, 4, 0.5, 0.5, False)
        self.lr_att1 = nn.Linear(300 * 2, 300)
        self.lr_output1 = FeedForwardNetII(300, 300, 300, 4, 0.5, 0.5, False)
        self.res_fc1 = nn.Linear(300, 300)

        # sentence-level GAMLP layers (dimension 64 = hidden_size assumed)
        self.lr_jk_ref = FeedForwardNetII(64, 64, 64, 4, 0.5, 0.5, False)
        self.num_hops = 1
        self.prelu = nn.PReLU()
        self.lr_att = nn.Linear(128, 64)
        self.lr_output = FeedForwardNetII(64, 64, 64, 4, 0.5, 0.5, False)
        self.dropout = nn.Dropout(0.5)
        self.input_drop = nn.Dropout(0)
        self.att_drop = nn.Dropout(0.5)
        self.pre_process = False
        self.res_fc = nn.Linear(64, 64)
        self.act = torch.nn.LeakyReLU(0.2)

    def forward(self, graph):
        """
        :param graph: [batch_size] * DGLGraph
            node:
                word: unit=0, dtype=0, id=(int)wordid in vocab
                sentence: unit=1, dtype=1, words=tensor, position=int, label=tensor
            edge:
                word2sent, sent2word:  tffrac=int, type=0
        :return: result: [sentnum, 2]
        """

        # word node init
        word_feature = self.set_wnfeature(graph)    # [wnode, embed_size]
        word_state = word_feature
        
        sent_raw, T_raw = self.set_snfeature(graph)
        sent_feature = self.n_feature_proj(sent_raw)    # [snode, hidden_size]
        T_feature = self.n_feature_proj(T_raw)          # [Tnode, hidden_size]

        # initial word -> sent pass
        sent_state = self.GTMLPS(graph, word_feature, sent_feature)
        
        # (an earlier in-line version of the GAMLP hop attention lives on in the
        #  GTMLP* helpers below)

        # iterative message passing
        for i in range(self._n_iter):
            # sent -> word

            word_state = self.GTMLPW(graph, word_state, sent_state)
            # word -> sent
            sent_state = self.GTMLPS(graph, word_state, sent_state)
            
            # sent -> supernode (dtype=2)
            T_feature = self.GTMLPSTT(graph, T_feature, sent_state)
            # supernode -> sent
            sent_state = self.GTMLPTTS(graph, T_feature, sent_state)
            

        result = self.wh(sent_state)
        return result

    def GTMLPS(self, graph, word_feature, sent_feature):
        """Word -> sentence update: GAMLP-style reference projection, WSWGAT
        pass, attention scoring against the input features, position-wise FFN.
        hidden_size=64 is assumed by the reshape."""
        ref = self.lr_jk_ref(sent_feature)                       # jumping-knowledge reference
        sent_state1 = self.word2sent(graph, word_feature, ref)   # W2S GAT pass
        sent_state = self.dropout(self.prelu(sent_state1))
        hops = sent_feature.reshape(1, -1, 64)                   # single "hop" of input features
        scores = [self.act(self.lr_att(torch.cat((sent_state, x), dim=1))) for x in hops]
        W = torch.cat(scores, dim=1)
        sent_state = self.wdffn(W.unsqueeze(0)).squeeze(0)       # position-wise FFN
        return sent_state
    def GTMLPW(self, graph, word_state, sent_state):
        """Sentence -> word update; word_emb_dim=300 is assumed by the reshape."""
        ref = self.lr_jk_ref1(word_state)                        # jumping-knowledge reference
        word_state1 = self.sent2word(graph, ref, sent_state)     # S2W GAT pass
        word_state = self.dropout(self.prelu(word_state1))
        hops = word_state.reshape(1, -1, 300)
        scores = [self.act(self.lr_att1(torch.cat((word_state, x), dim=1))) for x in hops]
        W = torch.cat(scores, dim=1)
        word_state = self.dwffn(W.unsqueeze(0)).squeeze(0)       # position-wise FFN
        return word_state
    
    
    def GTMLPTTS(self, graph, T_feature, sent_feature):
        """Supernode (dtype=2) -> sentence update; hidden_size=64 is assumed."""
        ref = self.lr_jk_ref(sent_feature)                       # jumping-knowledge reference
        sent_state1 = self.T2sent(graph, T_feature, ref)         # T2S GAT pass
        sent_state = self.dropout(self.prelu(sent_state1))
        hops = sent_feature.reshape(1, -1, 64)
        scores = [self.act(self.lr_att(torch.cat((sent_state, x), dim=1))) for x in hops]
        W = torch.cat(scores, dim=1)
        sent_state = self.tsffn(W.unsqueeze(0)).squeeze(0)       # position-wise FFN
        return sent_state
    
    
    def GTMLPSTT(self, graph, T_feature, sent_feature):
        """Sentence -> supernode (dtype=2) update; hidden_size=64 is assumed."""
        ref = self.lr_jk_ref(T_feature)                          # jumping-knowledge reference
        T_state1 = self.sent2T(graph, ref, sent_feature)         # S2T GAT pass
        T_state = self.dropout(self.prelu(T_state1))
        hops = T_feature.reshape(1, -1, 64)
        scores = [self.act(self.lr_att(torch.cat((T_state, x), dim=1))) for x in hops]
        W = torch.cat(scores, dim=1)
        T_state = self.stffn(W.unsqueeze(0)).squeeze(0)          # position-wise FFN
        return T_state
    

    def _init_sn_param(self):
        self.sent_pos_embed = nn.Embedding.from_pretrained(
            get_sinusoid_encoding_table(self._hps.doc_max_timesteps + 1, self.embed_size, padding_idx=0),
            freeze=True)
        self.cnn_proj = nn.Linear(self.embed_size, self._hps.n_feature_size)
        self.lstm_hidden_state = self._hps.lstm_hidden_state
        self.lstm = nn.LSTM(self.embed_size, self.lstm_hidden_state, num_layers=self._hps.lstm_layers, dropout=0.1,
                            batch_first=True, bidirectional=self._hps.bidirectional)
        if self._hps.bidirectional:
            self.lstm_proj = nn.Linear(self.lstm_hidden_state * 2, self._hps.n_feature_size)
        else:
            self.lstm_proj = nn.Linear(self.lstm_hidden_state, self._hps.n_feature_size)

        self.ngram_enc = sentEncoder(self._hps, self._embed)

    def _sent_cnn_feature(self, graph, snode_id):
        ngram_feature = self.ngram_enc.forward(graph.nodes[snode_id].data["words"])  # [snode, embed_size]
        graph.nodes[snode_id].data["sent_embedding"] = ngram_feature
        snode_pos = graph.nodes[snode_id].data["position"].view(-1)  # [n_nodes]
        position_embedding = self.sent_pos_embed(snode_pos)
        cnn_feature = self.cnn_proj(ngram_feature + position_embedding)
        return cnn_feature

    def _sent_lstm_feature(self, features, glen):
        pad_seq = rnn.pad_sequence(features, batch_first=True)
        # note: glen is in batch order; recent PyTorch versions may require
        # enforce_sorted=False here if the lengths are not sorted descending
        lstm_input = rnn.pack_padded_sequence(pad_seq, glen, batch_first=True)
        lstm_output, _ = self.lstm(lstm_input)
        unpacked, unpacked_len = rnn.pad_packed_sequence(lstm_output, batch_first=True)
        lstm_embedding = [unpacked[i][:unpacked_len[i]] for i in range(len(unpacked))]
        lstm_feature = self.lstm_proj(torch.cat(lstm_embedding, dim=0))  # [n_nodes, n_feature_size]
        return lstm_feature

    def set_wnfeature(self, graph):
        wnode_id = graph.filter_nodes(lambda nodes: nodes.data["unit"]==0)
        wsedge_id = graph.filter_edges(lambda edges: edges.data["dtype"] == 0)   # for word to supernode(sent&doc)
        wid = graph.nodes[wnode_id].data["id"]  # [n_wnodes]
        w_embed = self._embed(wid)  # [n_wnodes, D]
        graph.nodes[wnode_id].data["embed"] = w_embed
        etf = graph.edges[wsedge_id].data["tffrac"]
        graph.edges[wsedge_id].data["tfidfembed"] = self._TFembed(etf)
        return w_embed

    def set_snfeature(self, graph):
        # sentence (dtype=1) and dtype=2 supernode features: CNN + LSTM encodings
        snode_id = graph.filter_nodes(lambda nodes: nodes.data["dtype"] == 1)
        snode_id1 = graph.filter_nodes(lambda nodes: nodes.data["dtype"] == 2)
        cnn_feature = self._sent_cnn_feature(graph, snode_id)
        cnn_feature1 = self._sent_cnn_feature(graph, snode_id1)

        features, glen, features1, glen1 = get_snode_feat(graph, feat="sent_embedding")
        lstm_feature = self._sent_lstm_feature(features, glen)
        lstm_feature1 = self._sent_lstm_feature(features1, glen1)

        node_feature = torch.cat([cnn_feature, lstm_feature], dim=1)    # [n_snodes, n_feature_size * 2]
        topi_feature = torch.cat([cnn_feature1, lstm_feature1], dim=1)  # [n_Tnodes, n_feature_size * 2]
        return node_feature, topi_feature




# NOTE: HSumGraph (the base word/sentence model) is not defined in this file;
# it is assumed to be importable from the surrounding project.
class duoGraph(HSumGraph):
    """
        Without sent-to-sent edges; residual connections are added.
        Adds document nodes.
    """

    def __init__(self, hps, embed):
        super().__init__(hps, embed)
        self.dn_feature_proj = nn.Linear(hps.hidden_size, hps.hidden_size, bias=False)
        self.wh = nn.Linear(self.n_feature * 2, 2)

    def forward(self, graph):
        """
        :param graph: [batch_size] * DGLGraph
            node:
                word: unit=0, dtype=0, id=(int)wordid in vocab
                sentence: unit=1, dtype=1, words=tensor, position=int, label=tensor
                document: unit=1, dtype=2
            edge:
                word2sent, sent2word: tffrac=int, type=0
                word2doc, doc2word: tffrac=int, type=0
                sent2doc: type=2
        :return: result: [sentnum, 2]
        """

        snode_id = graph.filter_nodes(lambda nodes: nodes.data["dtype"] == 1)
        dnode_id = graph.filter_nodes(lambda nodes: nodes.data["dtype"] == 2)
        supernode_id = graph.filter_nodes(lambda nodes: nodes.data["unit"] == 1)

        # word node init
        word_feature = self.set_wnfeature(graph)    # [wnode, embed_size]
        sent_feature = self.n_feature_proj(self.set_snfeature(graph))    # [snode, n_feature_size]

        # sent and doc node init
        graph.nodes[snode_id].data["init_feature"] = sent_feature
        doc_feature, snid2dnid = self.set_dnfeature(graph)
        doc_feature = self.dn_feature_proj(doc_feature)
        graph.nodes[dnode_id].data["init_feature"] = doc_feature

        # the start state
        word_state = word_feature
        sent_state = graph.nodes[supernode_id].data["init_feature"]
        sent_state = self.word2sent(graph, word_state, sent_state)

        for i in range(self._n_iter):
            # sent -> word
            word_state = self.sent2word(graph, word_state, sent_state)
            # word -> sent
            sent_state = self.word2sent(graph, word_state, sent_state)

        graph.nodes[supernode_id].data["hidden_state"] = sent_state

        # extract sentence nodes
        s_state_list = []
        for snid in snode_id:
            d_state = graph.nodes[snid2dnid[int(snid)]].data["hidden_state"]
            s_state = graph.nodes[snid].data["hidden_state"]
            s_state = torch.cat([s_state, d_state], dim=-1)
            s_state_list.append(s_state)

        s_state = torch.cat(s_state_list, dim=0)
        result = self.wh(s_state)
        return result


    def set_dnfeature(self, graph):
        """ init doc node by mean pooling on the its sent node (connected by the edges with type=1) """
        dnode_id = graph.filter_nodes(lambda nodes: nodes.data["dtype"] == 2)
        node_feature_list = []
        snid2dnid = {}
        for dnode in dnode_id:
            snodes = [nid for nid in graph.predecessors(dnode) if graph.nodes[nid].data["dtype"]==1]
            doc_feature = graph.nodes[snodes].data["init_feature"].mean(dim=0)
            assert not torch.any(torch.isnan(doc_feature)), "doc_feature_element"
            node_feature_list.append(doc_feature)
            for s in snodes:
                snid2dnid[int(s)] = dnode
        node_feature = torch.stack(node_feature_list)
        return node_feature, snid2dnid



def get_snode_feat(G, feat):
    """Collect `feat` (and per-graph lengths) separately for dtype=1 and dtype=2 nodes of each graph in the batch."""
    glist = dgl.unbatch(G)
    feature, glen = [], []
    feature1, glen1 = [], []
    for g in glist:
        snode_id = g.filter_nodes(lambda nodes: nodes.data["dtype"] == 1)
        snode_id1 = g.filter_nodes(lambda nodes: nodes.data["dtype"] == 2)
        feature.append(g.nodes[snode_id].data[feat])
        glen.append(len(snode_id))
        feature1.append(g.nodes[snode_id1].data[feat])
        glen1.append(len(snode_id1))
    return feature, glen, feature1, glen1
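

# ---------------------------------------------------------------------------
# Illustrative sketch only (not used by the classes above): the GAMLP-style
# jumping-knowledge hop attention that the lr_jk_ref / lr_att layers are
# borrowed from, written with plain torch.nn modules so it runs standalone.
# The shapes, the `hidden` size and the helper name are assumptions made for
# this example, not part of the model definition.
# ---------------------------------------------------------------------------
def _jk_hop_attention_sketch(hop_features, hidden=64):
    """Fuse a list of per-hop node features, each [N, hidden], into one
    [N, hidden] tensor using attention weights computed against a
    jumping-knowledge reference vector (cf. GAMLP)."""
    num_hops = len(hop_features)
    n_nodes = hop_features[0].shape[0]

    lr_jk_ref = nn.Linear(num_hops * hidden, hidden)  # reference projection
    lr_att = nn.Linear(hidden * 2, 1)                 # per-(node, hop) score
    act = nn.LeakyReLU(0.2)

    # Jumping-knowledge reference: concatenation of all hops per node.
    ref = act(lr_jk_ref(torch.cat(hop_features, dim=1)))            # [N, hidden]

    # One scalar score per (node, hop), softmax over the hop dimension.
    scores = [act(lr_att(torch.cat((ref, h), dim=1))).view(n_nodes, 1)
              for h in hop_features]                                 # num_hops x [N, 1]
    W = torch.softmax(torch.cat(scores, dim=1), dim=1)               # [N, num_hops]

    # Attention-weighted sum of the hop features.
    out = sum(torch.mul(h, W[:, i].unsqueeze(1))
              for i, h in enumerate(hop_features))                   # [N, hidden]
    return out


if __name__ == "__main__":
    # Smoke test of the sketch only; running the full model requires the
    # project's hps config, vocabulary embedding and batched DGL graphs.
    feats = [torch.randn(5, 64) for _ in range(3)]
    print(_jk_hop_attention_sketch(feats).shape)  # torch.Size([5, 64])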