# -*- coding: utf-8 -*-
import torch
import torch.nn as nn

from dataEnhance import sequence_reversal
from modules import (
    ConvLayer,
    BiLSTMLayer,
    ReconstructionModel, MConvLayer
)


class MFAM_AD(nn.Module):
    """Conv -> BiLSTM -> reconstruction model.

    The input batch is passed through ``MConvLayer`` and then ``BiLSTMLayer``.
    The second value returned by the BiLSTM layer is flattened to one row per
    sample and decoded by ``ReconstructionModel``.

    Args:
        n_features: number of input features; used to build both the conv
            layer and the BiLSTM layer.
        window_size: forwarded to ``ReconstructionModel``.
        out_dim: forwarded to ``ReconstructionModel`` as its output size.
        kernel_size: currently unused. It was consumed by the earlier
            ``ConvLayer`` front end; ``MConvLayer`` is built from
            ``n_features`` only. Kept so existing call sites keep working.
        lstm_n_layers: forwarded to ``BiLSTMLayer``.
        lstm_hid_dim: forwarded to ``BiLSTMLayer`` and also to
            ``ReconstructionModel`` (presumably as its input size -- confirm
            against ``modules.ReconstructionModel``).
        recon_n_layers: forwarded to ``ReconstructionModel``.
        recon_hid_dim: forwarded to ``ReconstructionModel``.
        dropout: dropout value shared by the BiLSTM and reconstruction model.
    """

    def __init__(
            self,
            n_features: int,
            window_size: int,
            out_dim: int,
            kernel_size: int = 7,
            lstm_n_layers: int = 1,
            lstm_hid_dim: int = 150,
            recon_n_layers: int = 1,
            recon_hid_dim: int = 150,
            dropout: float = 0.1,
    ):
        super().__init__()
        # NOTE: ``kernel_size`` is intentionally ignored; see class docstring.
        self.conv = MConvLayer(n_features)
        self.bils = BiLSTMLayer(n_features, lstm_hid_dim, lstm_n_layers, dropout)
        self.recon_model = ReconstructionModel(
            window_size,
            lstm_hid_dim,
            recon_hid_dim,
            out_dim,
            recon_n_layers,
            dropout,
        )

    def forward(self, x):
        """Run ``x`` through the conv / BiLSTM / reconstruction stack.

        Args:
            x: input batch with the batch dimension first (assumed to be
                ``(batch, window_size, n_features)`` -- TODO confirm).

        Returns:
            Whatever ``ReconstructionModel`` produces for the flattened
            BiLSTM state of each sample.
        """
        # Record the batch size from the original input so the flatten below
        # yields exactly one row per input sample.
        batch_size = x.shape[0]
        features = self.conv(x)
        # Only the second output of the BiLSTM layer is used; the first is
        # discarded.
        _, h_end = self.bils(features)
        h_end = h_end.view(batch_size, -1)
        return self.recon_model(h_end)
