import numpy as np
import torch


class AE(torch.nn.Module):
    """Fully connected autoencoder over a flattened window of feature vectors.

    Each sample (``window_size`` rows of ``in_feautres`` values) is flattened,
    compressed by the encoder to ``out_features`` values, expanded back by the
    decoder, and reshaped to ``(window_size, in_feautres)``.

    Args:
        in_feautres: Number of features per window step. The misspelled name
            is kept for backward compatibility with existing callers.
        out_features: Size of the latent code produced by the encoder.
        window_size: Number of steps in each input window.
        hidden_features: Width of the single hidden layer in both the encoder
            and the decoder. Defaults to 64, the previously hard-coded width.
    """

    def __init__(self, in_feautres: int, out_features: int, window_size: int,
                 hidden_features: int = 64):
        super().__init__()
        self.window_size = window_size
        self.in_feautres = in_feautres
        self.out_features = out_features
        self.hidden_features = hidden_features
        flat_features = in_feautres * window_size
        # Module order inside each Sequential (0: Linear, 1: ReLU, 2: Linear)
        # is preserved so state_dict keys match previously saved checkpoints.
        self.encoder = torch.nn.Sequential(
            torch.nn.Linear(in_features=flat_features, out_features=hidden_features),
            torch.nn.ReLU(),
            torch.nn.Linear(in_features=hidden_features, out_features=out_features),
        )
        self.decoder = torch.nn.Sequential(
            torch.nn.Linear(in_features=out_features, out_features=hidden_features),
            torch.nn.ReLU(),
            torch.nn.Linear(in_features=hidden_features, out_features=flat_features),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode and decode ``x``.

        Args:
            x: Tensor whose element count is a multiple of
                ``window_size * in_feautres``; presumably shaped
                ``(batch, window_size, in_feautres)`` — TODO confirm with callers.

        Returns:
            Tensor of shape ``(N, window_size, in_feautres)``, where
            ``N = x.numel() // (window_size * in_feautres)``.
        """
        # reshape (not view) so non-contiguous inputs, e.g. after transpose or
        # permute, are accepted; it still returns a view for contiguous input.
        flat = x.reshape(-1, self.window_size * self.in_feautres)
        latent = self.encoder(flat)
        decoded = self.decoder(latent)
        # Decoder output is a fresh contiguous tensor, so view is safe here.
        return decoded.view(-1, self.window_size, self.in_feautres)

