import pandas as pd
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Input, LSTM, Dense, Layer, Dropout
from tensorflow.keras.models import Model
from tensorflow.keras.callbacks import ReduceLROnPlateau
from sklearn.metrics import mean_absolute_error, mean_squared_error
import matplotlib.pyplot as plt


# ------------------- Custom Attention Layer (corresponding to Figure 1: Attention Layer) -------------------
class AttentionLayer(Layer):
    """Soft attention pooling over the time axis of a sequence.

    For an input ``x`` of shape (batch, time_steps, features), a learned weight
    ``W`` of shape (features, 1) scores every time step as ``x_t @ W``. The
    scores are softmax-normalised over the time axis and used as weights for a
    sum over the time steps.

    ``call`` returns a tuple ``(output, alpha)``:
        output: (batch, features) attention-weighted sum of the time steps.
        alpha:  (batch, time_steps, 1) attention weights (sum to 1 over time).
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def build(self, input_shape):
        # One scoring weight per feature. No bias: a constant added to every
        # score would be cancelled by the softmax.
        self.W = self.add_weight(
            shape=(input_shape[-1], 1),
            initializer="glorot_uniform",
            trainable=True,
            name="attention_weight",
        )
        super().build(input_shape)

    def call(self, x):
        # Score each time step (prioritize highly correlated features based on Table 1).
        # tensordot contracts the feature axis directly:
        # (batch, time_steps, features) x (features, 1) -> (batch, time_steps, 1).
        # Unlike flatten + reshape against x.shape[1], this also works when the
        # time dimension is not statically known (None).
        e = tf.squeeze(tf.tensordot(x, self.W, axes=1), axis=-1)  # (batch, time_steps)
        alpha = tf.nn.softmax(e, axis=1)  # normalise over time
        alpha = tf.expand_dims(alpha, axis=-1)  # (batch, time_steps, 1)
        output = tf.reduce_sum(x * alpha, axis=1)  # (batch, features)
        return output, alpha


# ------------------- Build Attention-Enhanced LSTM Model (corresponding to Section 2.2.1: Model Structure) -------------------
def build_attention_lstm(
        time_steps=16,  # Time window: 16×15 minutes = 4 hours (corresponding to Section 2.2.1)
        input_dim=6,  # Number of input features: WS_cen+Pwind_fluct+TSI+Psolar_fluct+AirT+AirH (refer to Section 2.2.1)
        lstm_units1=64,  # Number of units in the first LSTM layer (corresponding to Figure 1)
        lstm_units2=32,  # Number of units in the second LSTM layer (corresponding to Figure 1)
        dropout_rate=0.2,  # Input dropout of both LSTM layers (refer to Section 2.2.3)
        learning_rate=0.001,  # Initial Adam learning rate (refer to Section 2.2.3)
):
    """Build and compile the two-layer LSTM + attention regression model.

    Architecture: Input(time_steps, input_dim) -> LSTM(lstm_units1, sequences)
    -> LSTM(lstm_units2, sequences) -> AttentionLayer -> Dense(1, linear).

    Returns:
        A compiled Keras ``Model`` with two outputs,
        ``[prediction, attention_weights]``. The prediction has shape
        (batch, 1) and the attention weights have shape
        (batch, time_steps, 1). Only the prediction head (layer
        ``"output_layer"``) is given a loss (MAE) and metrics (MAE, MSE). The
        attention weights are exposed purely for analysis.
    """
    inputs = Input(shape=(time_steps, input_dim), name="input_layer")

    # First LSTM layer: returns the full sequence so the second layer can consume it.
    lstm1 = LSTM(
        units=lstm_units1,
        return_sequences=True,
        dropout=dropout_rate,  # Prevent overfitting
        name="lstm_layer1",
    )(inputs)

    # Second LSTM layer: returns the full sequence so attention can weight every step.
    lstm2 = LSTM(
        units=lstm_units2,
        return_sequences=True,
        dropout=dropout_rate,
        name="lstm_layer2",
    )(lstm1)

    attention_output, attention_weights = AttentionLayer(name="attention_layer")(lstm2)

    # Single linear unit for the load regression target.
    outputs = Dense(units=1, activation="linear", name="output_layer")(attention_output)

    model = Model(inputs=inputs, outputs=[outputs, attention_weights], name="Attention_LSTM")

    # Adam optimizer, MAE loss on the prediction head only (refer to Section 2.2.3).
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
        loss={"output_layer": "mean_absolute_error"},
        metrics={"output_layer": ["mae", "mse"]},
    )
    return model


# ------------------- Data Format Conversion (Time Series → Supervised Learning Samples) -------------------
def create_time_series_samples(data, time_steps=16, target_col="Lt", horizon=1):
    """Convert a time-series DataFrame into supervised-learning samples.

    Each sample pairs a window of ``time_steps`` consecutive rows of every
    column except ``target_col`` with the value of ``target_col`` located
    ``horizon`` rows after the window's last row.

    Args:
        data: DataFrame holding the feature columns and ``target_col``.
            Feature columns keep their order in ``data``.
        time_steps: Length of the input window (>= 1).
        target_col: Name of the column to predict. It is excluded from X
            regardless of its position in ``data``.
        horizon: Number of rows between the window's last row and the target
            (>= 1). The default of 1 targets the row immediately following the
            window. At 15-minute resolution that is 15 minutes ahead; use
            ``horizon=4`` for a 1-hour-ahead target.

    Returns:
        ``(X, y)`` where ``X`` has shape (n_samples, time_steps, n_features)
        and ``y`` has shape (n_samples,), with
        ``n_samples = max(0, len(data) - time_steps - horizon + 1)``.

    Raises:
        ValueError: if ``time_steps`` or ``horizon`` is less than 1.
        KeyError: if ``target_col`` is not a column of ``data``.
    """
    if time_steps < 1:
        raise ValueError(f"time_steps must be >= 1, got {time_steps}")
    if horizon < 1:
        raise ValueError(f"horizon must be >= 1, got {horizon}")

    # Select features by name (not by position) so the target never leaks
    # into X even when it is not the last column.
    features = data.drop(columns=[target_col]).to_numpy()
    target = data[target_col].to_numpy()

    n_samples = len(data) - time_steps - horizon + 1
    if n_samples <= 0:
        # Too little data for a single window: return correctly shaped empties.
        return (
            np.empty((0, time_steps, features.shape[1]), dtype=features.dtype),
            np.empty((0,), dtype=target.dtype),
        )

    # window_idx[j] lists the rows of window j: j, j+1, ..., j+time_steps-1.
    window_idx = np.arange(n_samples)[:, None] + np.arange(time_steps)[None, :]
    X = features[window_idx]  # fancy indexing -> (n_samples, time_steps, n_features)
    first_target = time_steps + horizon - 1
    # copy() so y never aliases the DataFrame's underlying buffer.
    y = target[first_target:first_target + n_samples].copy()
    return X, y


# ------------------- Model Training and Evaluation -------------------
def _evaluate_predictions(y_true, y_pred, tolerance=0.1):
    """Compute MAE, RMSE and PA (%) for paired targets and predictions.

    Both inputs are flattened to 1-D first. This ensures an (n, 1) prediction
    array compares element-wise with an (n,) target array instead of
    broadcasting to an (n, n) matrix.

    PA is the percentage of samples whose relative error
    ``|y_pred - y_true| / |y_true|`` is strictly below ``tolerance``. Samples
    with ``y_true == 0`` yield an infinite or NaN relative error and are
    therefore counted as misses.

    Returns:
        Tuple ``(mae, rmse, pa)``.
    """
    y_true = np.ravel(y_true)
    y_pred = np.ravel(y_pred)
    mae = mean_absolute_error(y_true, y_pred)
    rmse = np.sqrt(mean_squared_error(y_true, y_pred))
    # inf/NaN from zero targets compare False below, so just silence the warnings.
    with np.errstate(divide="ignore", invalid="ignore"):
        relative_error = np.abs((y_pred - y_true) / y_true)
    pa = np.sum(relative_error < tolerance) / len(y_true) * 100
    return mae, rmse, pa


def _plot_predictions(y_true, y_pred, path, n_points=200):
    """Save a measured-vs-predicted line plot of the first ``n_points`` samples to ``path``."""
    fig = plt.figure(figsize=(12, 6))
    try:
        plt.plot(y_true[:n_points], label="Measured Load (Lt)", color="blue")
        plt.plot(y_pred[:n_points], label="Predicted Load (Attention-LSTM)", color="red")
        plt.xlabel("Time Step (15 min/step)")
        # NOTE(review): the metrics are printed in "ms" while this axis labels
        # Lt in "Mbps" -- confirm the actual unit of Lt.
        plt.ylabel("SPN Communication Load (Mbps)")
        plt.title("1-Hour Ahead Communication Load Prediction (Test Set)")
        plt.legend()
        plt.savefig(path)
    finally:
        # Always release the figure, even if saving fails.
        plt.close(fig)


def train_and_evaluate_model():
    """Train the attention-LSTM on the wind CSVs, evaluate it, plot and save it.

    Reads ``wind_{train,val,test}_preprocessed.csv``, builds 16-step input
    windows, trains for 100 epochs, then prints MAE / RMSE / PA on the test
    set. It saves a prediction plot to ``attention_lstm_prediction_result.png``
    and the model to ``attention_lstm_model.h5``.

    Returns:
        Dict with the test-set metrics ``{"mae", "rmse", "pa"}``.
    """
    # 1. Load preprocessed data (including SPN communication load Lt, refer to Section 2.1: Load Model)
    train_data = pd.read_csv("wind_train_preprocessed.csv")  # Can be replaced with photovoltaic data
    val_data = pd.read_csv("wind_val_preprocessed.csv")
    test_data = pd.read_csv("wind_test_preprocessed.csv")

    # The data is assumed to contain the SPN communication load column "Lt".
    target_col = "Lt"
    feature_cols = [col for col in train_data.columns if col not in (target_col, "timestamp")]
    input_dim = len(feature_cols)
    time_steps = 16  # 4-hour time window

    # 2. Convert to windowed samples. The target is placed last so X contains
    # exactly feature_cols. NOTE(review): each target is the row immediately
    # after its 16-step window (one 15-minute step ahead), not 1 hour ahead
    # as the plot title states -- confirm the intended horizon.
    def _to_samples(frame):
        return create_time_series_samples(
            frame[feature_cols + [target_col]],
            time_steps=time_steps,
            target_col=target_col,
        )

    X_train, y_train = _to_samples(train_data)
    X_val, y_val = _to_samples(val_data)
    X_test, y_test = _to_samples(test_data)

    # 3. Build and train the model
    model = build_attention_lstm(
        time_steps=time_steps,
        input_dim=input_dim,
        lstm_units1=64,
        lstm_units2=32,
    )

    # Plateau-based decay: after 20 epochs without val_loss improvement, the LR
    # is multiplied by 0.1 (i.e. cut to 10% of its value).
    # NOTE(review): Section 2.2.3 is cited as "10% decay every 20 epochs",
    # which would be a fixed step schedule (x0.9 every 20 epochs). Confirm
    # which behaviour is intended.
    lr_scheduler = ReduceLROnPlateau(
        monitor="val_loss",
        factor=0.1,
        patience=20,
        verbose=1,
    )

    model.fit(
        X_train, y_train,
        validation_data=(X_val, y_val),
        epochs=100,  # Refer to Section 2.2.3
        batch_size=64,  # Refer to Section 2.2.3
        callbacks=[lr_scheduler],
        shuffle=False,  # Keep samples in chronological order
    )

    # 4. Model evaluation (corresponding to metrics in Table 9).
    # predict() returns [prediction (n, 1), attention weights]. Flatten the
    # prediction so it aligns element-wise with y_test (n,).
    y_pred, _ = model.predict(X_test)
    y_pred = np.ravel(y_pred)
    mae, rmse, pa = _evaluate_predictions(y_test, y_pred)

    # Print evaluation results (refer to Attention-LSTM performance in Table 9)
    print("=" * 50)
    print("Attention-Enhanced LSTM Model Evaluation Results (Test Set)")
    print(f"MAE (ms): {mae:.2f}")  # Target: 8.2 ms (Table 9)
    print(f"RMSE (ms): {rmse:.2f}")  # Target: 11.4 ms (Table 9)
    print(f"PA (%): {pa:.2f}")  # Target: 91.3% (Table 9)
    print("=" * 50)

    # 5. Visualize prediction results (corresponding to Figure 10: Time Series Prediction Plot)
    _plot_predictions(y_test, y_pred, "attention_lstm_prediction_result.png")

    # 6. Save the model. Reloading it requires
    # custom_objects={"AttentionLayer": AttentionLayer}.
    model.save("attention_lstm_model.h5")
    print("Model saved as attention_lstm_model.h5")

    return {"mae": mae, "rmse": rmse, "pa": pa}


# ------------------- Main Function Call -------------------
def _run_as_script():
    """Entry point used only when this file is executed directly (not imported)."""
    train_and_evaluate_model()


if __name__ == "__main__":
    _run_as_script()