import logging
import os
import warnings
from typing import Any, Dict, List, Tuple

# Kept ahead of the third-party imports, as in the original ordering.
os.environ['CMDSTAN'] = "D:/anaconda/envs/Desktop/Library/bin/cmdstan"

import ax
import kats.models.model as mm
import kats.utils.time_series_parameter_tuning as tpt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from ax.core.parameter import ChoiceParameter, FixedParameter, ParameterType
from ax.models.random.sobol import SobolGenerator
from ax.models.random.uniform import UniformGenerator
from kats.consts import ModelEnum, Params, SearchMethodEnum, TimeSeriesData

# NOTE(review): LSTMModel also uses `MinMaxScaler`, which is not imported
# anywhere in this file -- presumably `sklearn.preprocessing.MinMaxScaler`;
# confirm and add the import.
class LSTMParams(Params):
    """Parameter class for time series LSTM model

    This is the parameter class for time series LSTM model, containing
    parameters including the number of stacked layers.

    Attributes:
        hidden_size: LSTM hidden unit size
        time_window: Time series sequence length that feeds into the model
        num_epochs: Number of epochs for the training process
        num_layers: Number of stacked LSTM layers
    """

    __slots__ = ["hidden_size", "time_window", "num_epochs", "num_layers"]

    def __init__(self, hidden_size: int, time_window: int, num_epochs: int, num_layers: int = 1) -> None:
        """Store the hyper-parameters; no validation is performed here.

        Args:
            hidden_size: LSTM hidden unit size.
            time_window: Length of each input sequence.
            num_epochs: Number of training epochs.
            num_layers: Number of stacked LSTM layers (defaults to 1).
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.time_window = time_window
        self.num_epochs = num_epochs
        self.num_layers = num_layers  # number of stacked layers

        logging.debug(
            "Initialized LSTMParams instance. "
            f"hidden_size:{hidden_size}, time_window:{time_window}, "
            f"num_epochs:{num_epochs}, num_layers:{num_layers}"
        )

    def validate_params(self) -> None:
        """Validate ``num_layers``.

        Raises:
            TypeError: If ``num_layers`` is not an ``int``.
            ValueError: If ``num_layers`` is smaller than 1.
        """
        # Check the type first: comparing a non-numeric value with ``< 1``
        # would otherwise fail with an unrelated comparison error, and a
        # float below 1 would be reported as a value problem.
        if not isinstance(self.num_layers, int):
            raise TypeError("num_layers must be an integer")
        if self.num_layers < 1:
            raise ValueError("num_layers must be a positive integer")
        logging.info("Validated LSTM parameters.")
class LSTMForecast(nn.Module):
    """Torch forecast class for time series LSTM model with stacked layers support.

    The module feeds one sequence at a time (batch size 1) through an
    ``nn.LSTM`` followed by an ``nn.Linear`` layer and returns the linear
    output for the last time step.

    Attributes:
        lstm: The (possibly stacked) LSTM layer.
        linear: Linear layer mapping the LSTM hidden state to the output.
        hidden_cell: ``(h, c)`` state passed to the LSTM on the next call to
            ``forward`` and overwritten with the state it returns. ``None``
            means "start from zeros".
    """

    def __init__(self, params: "LSTMParams", input_size: int, output_size: int) -> None:
        """Build the network.

        Args:
            params: Object providing ``hidden_size`` and ``num_layers``.
            input_size: Number of features per time step.
            output_size: Number of values produced by the linear layer.
        """
        super().__init__()

        # num_layers controls how many LSTM layers are stacked.
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=params.hidden_size,
            num_layers=params.num_layers,
        )

        self.linear = nn.Linear(
            in_features=params.hidden_size,
            out_features=output_size,
        )

        # Without this attribute forward() raised AttributeError unless the
        # caller had assigned ``hidden_cell`` beforehand. ``None`` makes
        # nn.LSTM fall back to zero-initialised hidden and cell states.
        self.hidden_cell = None

    def forward(self, input_seq: torch.Tensor) -> torch.Tensor:
        """Run one sequence through the network.

        Args:
            input_seq: Sequence tensor; it is reshaped to
                ``(len(input_seq), 1, -1)`` before entering the LSTM.

        Returns:
            The linear-layer output for the last time step.
        """
        seq_len = len(input_seq)
        lstm_out, self.hidden_cell = self.lstm(
            input_seq.view(seq_len, 1, -1), self.hidden_cell
        )
        predictions = self.linear(lstm_out.view(seq_len, -1))
        return predictions[-1]
class LSTMModel(mm.Model):
    """Kats model class for time series LSTM model with stacked layers.

    Only univariate series are supported. ``fit`` trains an ``LSTMForecast``
    network on sliding windows of the min-max scaled series; ``predict``
    rolls the network forward one step at a time.

    Attributes (set by ``fit``/``predict``):
        lr: Learning rate used by the Adam optimizer.
        model: The trained ``LSTMForecast`` network.
        scaler: Scaler fitted on the training values (range ``(-1, 1)``).
        train_data_normalized: Scaled training values as a 1-D tensor.
        freq: Frequency used to build the forecast dates.
        dates, y_fcst, y_fcst_lower, y_fcst_upper, fcst_df: Forecast outputs.
    """

    def __init__(self, data: TimeSeriesData, params: LSTMParams) -> None:
        """Store data and parameters.

        Raises:
            ValueError: If ``data.value`` is not a ``pd.Series``.
        """
        super().__init__(data, params)
        if not isinstance(self.data.value, pd.Series):
            msg = (
                f"Only support univariate time series, but get {type(self.data.value)}."
            )
            logging.error(msg)
            raise ValueError(msg)

    def __setup_data(self) -> List[Tuple[torch.Tensor, torch.Tensor]]:
        """Scale the series and cut it into (window, next value) pairs.

        Returns:
            One ``(sequence, label)`` pair per window of length
            ``params.time_window``; the label is the single value following
            the window. Empty when the series is not longer than the window.
        """
        train_data = self.data.value.values.astype(float)

        self.scaler = MinMaxScaler(feature_range=(-1, 1))
        train_data_scaled = self.scaler.fit_transform(train_data.reshape(-1, 1))
        self.train_data_normalized = torch.FloatTensor(train_data_scaled).view(-1)

        window = self.params.time_window
        inout_seq = []
        for i in range(len(self.train_data_normalized) - window):
            train_seq = self.train_data_normalized[i: i + window]
            train_label = self.train_data_normalized[i + window: i + window + 1]
            inout_seq.append((train_seq, train_label))

        return inout_seq

    def fit(self, **kwargs) -> "LSTMModel":
        """Train the LSTM on the stored series.

        Args:
            **kwargs: ``lr`` (float, default 0.001) sets the Adam learning rate.

        Returns:
            ``self``, so calls can be chained.

        Raises:
            ValueError: If the series is not longer than ``time_window`` and
                therefore yields no training sequence.
        """
        logging.debug(f"Call fit() with parameters. kwargs:{kwargs}")

        self.lr = kwargs.get("lr", 0.001)

        train_inout_seq = self.__setup_data()
        if not train_inout_seq:
            # Previously this case trained nothing and then failed with an
            # UnboundLocalError in the loss logging below.
            msg = (
                "Not enough data to train the LSTM: need more than "
                f"time_window={self.params.time_window} observations, "
                f"but got {len(self.train_data_normalized)}."
            )
            logging.error(msg)
            raise ValueError(msg)

        self.model = LSTMForecast(params=self.params, input_size=1, output_size=1)

        loss_function = nn.MSELoss()
        optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)

        for i in range(self.params.num_epochs):
            for seq, labels in train_inout_seq:
                optimizer.zero_grad()

                # Reset hidden and cell state for every sequence; shape is
                # (num_layers, batch_size, hidden_size).
                self.model.hidden_cell = (
                    torch.zeros(self.params.num_layers, 1, self.params.hidden_size),
                    torch.zeros(self.params.num_layers, 1, self.params.hidden_size),
                )

                y_pred = self.model(seq)
                single_loss = loss_function(y_pred, labels)
                single_loss.backward()
                optimizer.step()

            if i % 25 == 1:
                # Loss of the last sequence processed in this epoch.
                logging.info(f"epoch: {i:3} loss: {single_loss.item():10.8f}")

        return self

    def predict(self, steps: int, **kwargs) -> pd.DataFrame:
        """Forecast ``steps`` values after the end of the training data.

        Args:
            steps: Number of future points to forecast.
            **kwargs: ``freq`` sets the frequency of the forecast dates; when
                omitted it is inferred from ``data.time``.

        Returns:
            DataFrame with columns ``time``, ``fcst``, ``fcst_lower`` and
            ``fcst_upper``. The bounds are a fixed +/-5% band around the
            forecast.
        """
        logging.debug(f"Call predict() with parameters. steps:{steps}, kwargs:{kwargs}")
        # Only infer the frequency when the caller did not supply one:
        # inference can fail on short series and is wasted work otherwise.
        if "freq" in kwargs:
            self.freq = kwargs["freq"]
        else:
            self.freq = pd.infer_freq(self.data.time)
        self.model.eval()

        test_inputs = self.train_data_normalized[-self.params.time_window:].tolist()

        for _ in range(steps):
            seq = torch.FloatTensor(test_inputs[-self.params.time_window:])
            with torch.no_grad():
                # Start every prediction from zeroed multi-layer states.
                self.model.hidden_cell = (
                    torch.zeros(self.params.num_layers, 1, self.params.hidden_size),
                    torch.zeros(self.params.num_layers, 1, self.params.hidden_size),
                )
                test_inputs.append(self.model(seq).item())

        fcst_denormalized = self.scaler.inverse_transform(
            np.array(test_inputs[self.params.time_window:]).reshape(-1, 1)
        ).flatten()
        logging.info("Generated forecast data from LSTM model.")
        logging.debug(f"Forecast data: {fcst_denormalized}")

        last_date = self.data.time.max()
        dates = pd.date_range(start=last_date, periods=steps + 1, freq=self.freq)
        self.dates = dates[dates != last_date]
        self.y_fcst = fcst_denormalized
        # For negative forecasts 0.95*x is larger than 1.05*x, so order the
        # two scaled values explicitly to keep lower <= upper. Results for
        # non-negative forecasts are unchanged.
        band_a = fcst_denormalized * 0.95
        band_b = fcst_denormalized * 1.05
        self.y_fcst_lower = np.minimum(band_a, band_b)
        self.y_fcst_upper = np.maximum(band_a, band_b)

        self.fcst_df = pd.DataFrame(
            {
                "time": self.dates,
                "fcst": self.y_fcst,
                "fcst_lower": self.y_fcst_lower,
                "fcst_upper": self.y_fcst_upper,
            }
        )

        logging.debug(f"Return forecast data: {self.fcst_df}")
        return self.fcst_df

    def plot(self):
        """Plot the stored data together with the latest forecast."""
        mm.Model.plot(self.data, self.fcst_df)

    def __str__(self):
        return "LSTM"

    @staticmethod
    def get_parameter_search_space() -> List[Dict[str, Any]]:
        """Return the default search space, including ``num_layers``."""
        return [
            {
                "name": "hidden_size",
                "type": "choice",
                "values": list(range(1, 500, 10)),
                "value_type": "int",
                "is_ordered": True,
            },
            {
                "name": "time_window",
                "type": "choice",
                "values": list(range(1, 20, 1)),
                "value_type": "int",
                "is_ordered": True,
            },
            {
                "name": "num_epochs",
                "type": "choice",
                "values": list(range(50, 2000, 50)),
                "value_type": "int",
                "is_ordered": True,
            },
            {
                "name": "num_layers",
                "type": "choice",
                "values": [1, 2, 3, 4, 5],
                "value_type": "int",
                "is_ordered": True,
            },
        ]
# Suppress every warning category for the rest of the script.
warnings.simplefilter(action="ignore")
print("step1")
# Source table; evaluation_function reads its 'time' column plus one
# numbered column per series.
air_passengers_df_1 = pd.read_csv("D:/dataset/10point.csv")

# Search space for the tuner: every hyper-parameter is an ordered integer
# choice, so the entries are generated from (name, candidate values) pairs.
parameters = [
    {
        "name": param_name,
        "type": "choice",
        "values": candidate_values,
        "value_type": "int",
        "is_ordered": True,
    }
    for param_name, candidate_values in (
        ("hidden_size", list(range(1, 64))),
        ("time_window", list(range(1, 25))),
        ("num_epochs", list(range(1, 128))),
        ("num_layers", [1, 2, 3]),
    )
]
print("step2")
def evaluation_function(params):
    """Score one hyper-parameter set by mean absolute forecast error.

    For each of the series stored in columns ``'0'`` .. ``'9'`` of
    ``air_passengers_df_1``, the first 80% of the rows train an ``LSTMModel``
    and the remaining rows are forecast and compared with the actual values.

    Args:
        params: Mapping with the keys ``hidden_size``, ``time_window``,
            ``num_epochs`` and ``num_layers``.

    Returns:
        The mean absolute error averaged over all series.
    """
    params1 = LSTMParams(
        hidden_size=params['hidden_size'],
        time_window=params['time_window'],
        num_epochs=params['num_epochs'],
        num_layers=params['num_layers'],
    )
    n_series = 10  # columns '0' .. '9' of the source table
    error = 0
    for a in range(n_series):
        # Re-applied on every iteration as in the original; presumably
        # because a library may reset the warning filters -- TODO confirm.
        warnings.simplefilter(action='ignore')
        air_passengers_df = air_passengers_df_1[['time', f'{a}']]
        air_passengers_ts = TimeSeriesData(air_passengers_df)
        split = int(0.8 * len(air_passengers_df))
        train_ts = air_passengers_ts[0:split]
        test_ts = air_passengers_ts[split:]
        # Forecast exactly as many steps as there are hold-out rows. The
        # horizon used to be hard-coded to 7, which only lines up with the
        # hold-out when it happens to contain 7 rows.
        horizon = len(air_passengers_df) - split
        model = LSTMModel(train_ts, params1)
        model.fit()
        model_pred = model.predict(steps=horizon, freq='12D', include_history=True)
        error += np.mean(np.abs(model_pred['fcst'].values - test_ts.value.values))
    return error / n_series
# Build the tuner. Despite the "_grid" suffix in the variable names, the
# selected search method is Bayesian optimisation.
parameter_tuner_grid = tpt.SearchMethodFactory.create_search_method(
    parameters=parameters,
    evaluation_function=evaluation_function,
    selected_search_method=SearchMethodEnum.BAYES_OPT,
    objective_name="evaluation_metric",
    bootstrap_size=500,
)

# Generate and score new parameter values, then collect the scores.
parameter_tuner_grid.generate_evaluate_new_parameter_values(evaluation_function=evaluation_function)
parameter_tuning_results_grid = parameter_tuner_grid.list_parameter_value_scores()

# Persist the scored parameter sets without the index column.
parameter_tuning_results_grid.to_csv('D:/LSTM_beysb.csv', index=False)