import torch
import torch.nn as nn
import torch.nn.functional as F


# Single-scale feature extraction
class ConvLayer(nn.Module):
    """Single-scale 1-D convolution over the sequence axis, followed by ReLU.

    Input and output both have shape (batch, seq_len, n_features). Zero padding is
    chosen so that seq_len is preserved for any kernel_size (odd or even).

    Args:
        n_features: Number of channels per time step (in and out).
        kernel_size: Width of the convolution kernel.
    """

    def __init__(self, n_features, kernel_size=3):
        super(ConvLayer, self).__init__()
        # A stride-1 convolution consumes kernel_size - 1 positions, so pad exactly that
        # many. For even kernels the extra position goes on the right; for odd kernels
        # this is the usual symmetric (k - 1) // 2 on each side.
        left = (kernel_size - 1) // 2
        right = kernel_size - 1 - left
        self.padding = nn.ConstantPad1d((left, right), 0.0)
        self.conv = nn.Conv1d(in_channels=n_features, out_channels=n_features, kernel_size=kernel_size)
        self.relu = nn.ReLU()

    def forward(self, x):
        # Conv1d expects (batch, channels, seq_len).
        x = x.permute(0, 2, 1)
        x = self.padding(x)
        x = self.conv(x)
        x = self.relu(x)
        # Back to (batch, seq_len, n_features).
        return x.permute(0, 2, 1)

# Attention mechanism layer
class Attention(nn.Module):
    """Learned soft weighting of the positions along the input's second-to-last axis.

    A learned ``(feature_dim, 1)`` vector projects each position's feature vector to
    one score. The scores are softmax-normalised across positions and used to rescale
    the input, so the output has the same shape as the input.

    Args:
        feature_dim: Size of the input's last dimension.
    """

    def __init__(self, feature_dim):
        super().__init__()
        self.feature_dim = feature_dim
        self.weight = nn.Parameter(torch.zeros(feature_dim, 1))
        # Overwrite the zero placeholder with Xavier-uniform values.
        nn.init.xavier_uniform_(self.weight)

    def forward(self, x):
        # One scalar score per position: (..., positions, feature_dim) @ (feature_dim, 1).
        scores = (x @ self.weight)[..., 0]
        # Normalise the scores across positions.
        weights = torch.softmax(scores, dim=-1)
        # Broadcast each position's weight over all of that position's features.
        return x * weights[..., None]

# Multi-scale feature extraction and fusion layer
class MConvLayer(nn.Module):
    """Multi-scale 1-D convolution whose branches are fused by attention.

    Three parallel Conv1d branches (kernel sizes 3, 5 and 7) each have ReLU applied.
    Each branch is zero-padded so the sequence length is preserved. The branches are
    concatenated along the feature axis and passed through ``Attention``.

    Args:
        n_features: Channels per time step; every branch maps n_features -> n_features.

    Input ``x``: (batch, seq_len, n_features).
    Output: (batch, seq_len, 3 * n_features).
    """

    def __init__(self, n_features):
        super(MConvLayer, self).__init__()
        self.kernel_num = 3
        # Padding (k - 1) / 2 on each side keeps seq_len unchanged for odd kernel k.
        self.padding3 = nn.ConstantPad1d(1, 0.0)
        self.padding5 = nn.ConstantPad1d(2, 0.0)
        self.padding7 = nn.ConstantPad1d(3, 0.0)
        self.conv3 = nn.Conv1d(in_channels=n_features, out_channels=n_features, kernel_size=3)
        self.conv5 = nn.Conv1d(in_channels=n_features, out_channels=n_features, kernel_size=5)
        self.conv7 = nn.Conv1d(in_channels=n_features, out_channels=n_features, kernel_size=7)
        self.relu = nn.ReLU()

        # Scalar per-branch weights. forward() does not use them; they stay registered
        # so that state_dicts containing them keep loading. Give them a defined value:
        # an uninitialised tensor can hold arbitrary bits, including NaN/inf.
        self.weight1 = nn.Parameter(torch.ones(1))
        self.weight2 = nn.Parameter(torch.ones(1))
        self.weight3 = nn.Parameter(torch.ones(1))
        self.attention = Attention(n_features * self.kernel_num)

    def forward(self, x):
        # Conv1d expects (batch, channels, seq_len).
        x = x.permute(0, 2, 1)
        x3 = self.relu(self.conv3(self.padding3(x)))
        x5 = self.relu(self.conv5(self.padding5(x)))
        x7 = self.relu(self.conv7(self.padding7(x)))
        # Stack the three scales along the channel axis -> 3 * n_features channels.
        x = torch.cat((x3, x5, x7), dim=1)
        # Back to (batch, seq_len, 3 * n_features) for the attention layer.
        x = x.permute(0, 2, 1)
        return self.attention(x)

# BiLSTM encoding layer
class BiLSTMLayer(nn.Module):
    """Bidirectional LSTM encoder that summarises each sequence into one vector.

    Args:
        in_dim: Per-branch feature count. The LSTM input size is ``in_dim * 3``
            (presumably the three concatenated branches of a multi-scale conv
            layer -- confirm against caller).
        hidden_dim: Hidden size per direction.
        n_layers: Number of stacked LSTM layers.
        dropout: Dropout between stacked layers (forced to 0.0 for one layer,
            since nn.LSTM only applies it between layers).
    """

    def __init__(self, in_dim, hidden_dim, n_layers, dropout):
        super(BiLSTMLayer, self).__init__()
        self.hidden_dim = hidden_dim
        self.n_layers = n_layers
        self.dropout = 0.0 if n_layers == 1 else dropout
        self.lstm = nn.LSTM(in_dim * 3, hidden_dim, num_layers=n_layers, batch_first=True, dropout=self.dropout,
                            bidirectional=True)

    def forward(self, x):
        """Encode a batch-first sequence and return the final bidirectional summary.

        Args:
            x: Tensor of shape (batch, seq_len, in_dim * 3).

        Returns:
            ``(out, h)``, each of shape (batch, 2 * hidden_dim). Each is the
            concatenation [backward, forward] of the top layer's final states:
            ``out`` is read from the per-step outputs, ``h`` from ``h_n``. For the
            top layer the two hold the same values.
        """
        # Hidden and cell states start at zero (nn.LSTM's default when none are given).
        out, (h_n, _) = self.lstm(x)
        hid = self.hidden_dim

        # batch_first=True: out is (batch, seq_len, 2 * hidden_dim), with the forward
        # direction in the first hidden_dim channels and the backward direction after.
        # The forward direction finishes at the last step; the backward one at step 0.
        last_forward_output = out[:, -1, :hid]
        last_backward_output = out[:, 0, hid:]

        # h_n is (n_layers * 2, batch, hidden_dim), ordered layer by layer with the
        # forward direction first, so the top layer occupies the last two entries.
        last_forward_h = h_n[-2]
        last_backward_h = h_n[-1]

        # Backward-then-forward order is kept deliberately: downstream layers may have
        # been trained against this feature layout.
        out = torch.cat((last_backward_output, last_forward_output), dim=-1)
        h = torch.cat((last_backward_h, last_forward_h), dim=-1)
        return out, h

# BiLSTM decoding layer
class BiLSDecoder(nn.Module):
    """Bidirectional LSTM decoder that returns the full output sequence.

    Args:
        in_dim: Features per input time step.
        hid_dim: Hidden size per direction; outputs carry 2 * hid_dim features.
        n_layers: Number of stacked LSTM layers.
        dropout: Dropout between stacked layers (forced to 0.0 for a single layer).
    """

    def __init__(self, in_dim, hid_dim, n_layers, dropout):
        super().__init__()
        self.in_dim = in_dim
        # nn.LSTM only applies dropout between stacked layers, so drop it for one layer.
        self.dropout = dropout if n_layers != 1 else 0.0
        self.lstm = nn.LSTM(
            input_size=in_dim,
            hidden_size=hid_dim,
            num_layers=n_layers,
            batch_first=True,
            dropout=self.dropout,
            bidirectional=True,
        )

    def forward(self, x):
        """Decode batch-first ``x`` into per-step outputs of width 2 * hid_dim."""
        sequence_out = self.lstm(x)[0]
        return sequence_out


class ReconstructionModel(nn.Module):
    """Reconstruct a window of ``window_size`` steps from one summary vector per sample.

    The summary vector is tiled across ``window_size`` time steps, decoded with a
    bidirectional LSTM (BiLSDecoder), and projected to ``out_dim`` features per step.

    Args:
        window_size: Number of time steps to reconstruct.
        in_dim: Half the width of the summary vector; the decoder input size is
            ``in_dim * 2``.
        hid_dim: Decoder hidden size per direction.
        out_dim: Features per reconstructed time step.
        n_layers: Number of stacked decoder layers.
        dropout: Dropout between stacked decoder layers.
    """

    def __init__(self, window_size, in_dim, hid_dim, out_dim, n_layers, dropout):
        super(ReconstructionModel, self).__init__()
        self.window_size = window_size
        self.decoder = BiLSDecoder(in_dim * 2, hid_dim, n_layers, dropout)
        # The bidirectional decoder emits 2 * hid_dim features per step, so size the
        # projection from hid_dim (identical to in_dim * 2 when hid_dim == in_dim).
        self.fc = nn.Linear(hid_dim * 2, out_dim)

    def forward(self, x):
        """Reconstruct the window from summary tensor ``x``.

        Args:
            x: Summary tensor with batch on dim 0 and ``in_dim * 2`` features per
               sample (presumably (batch, in_dim * 2) -- confirm against caller).
               Any extra trailing dims are flattened per sample.

        Returns:
            Tensor of shape (batch, window_size, out_dim).
        """
        h_end = x.reshape(x.size(0), -1)
        # Copy the whole vector to every time step: (batch, F) -> (batch, window_size, F).
        # (repeat_interleave + view would repeat each element in place and spread one
        # copy's features across several time steps.)
        h_end_rep = h_end.unsqueeze(1).repeat(1, self.window_size, 1)

        decoder_out = self.decoder(h_end_rep)
        return self.fc(decoder_out)



