import pandas as pd
import numpy as np
from sklearn.preprocessing import MinMaxScaler
from datetime import datetime


def load_state_grid_dataset(wind_file_path, pv_file_path):
    """Read the wind-farm and PV-plant Excel workbooks (dataset of Table 5).

    Both workbooks are expected to contain a ``Time`` column, which is
    parsed as datetimes, renamed to ``timestamp`` and used to sort the rows
    in ascending order. The index of each returned frame is reset.

    Args:
        wind_file_path: Path of the wind-farm workbook (original notes:
            6 sites, fields such as WS_cen and Power_output).
        pv_file_path: Path of the PV-plant workbook (original notes:
            8 sites, fields such as TSI and Power).

    Returns:
        Tuple ``(wind_data, pv_data)`` of chronologically sorted DataFrames.
    """

    def _read_sorted(path):
        # Same three steps for both sources: parse, unify the timestamp
        # column name, order chronologically with a fresh 0..n-1 index.
        frame = pd.read_excel(path, parse_dates=["Time"])
        frame = frame.rename(columns={"Time": "timestamp"})
        return frame.sort_values("timestamp").reset_index(drop=True)

    wind_data = _read_sorted(wind_file_path)
    pv_data = _read_sorted(pv_file_path)
    return wind_data, pv_data


def handle_missing_values(data, max_gap=3):
    """Interpolate short gaps linearly and delete rows belonging to long gaps.

    A "gap" is a run of consecutive missing values in one numeric column.
    Runs of at most ``max_gap`` values are filled by linear interpolation;
    longer runs are left untouched so that every row they cover is removed
    by the final ``dropna``. (Calling ``interpolate(limit=max_gap)`` alone
    would also fill the first ``max_gap`` values of a long gap, which
    contradicts the "delete long gaps" rule.)

    Only numeric columns are interpolated. Missing values in any other
    column (timestamps, labels, ...) simply cause the row to be dropped.

    Args:
        data: Input DataFrame; it is not modified.
        max_gap: Longest run of consecutive missing values that is still
            interpolated. Defaults to 3.

    Returns:
        A new DataFrame without missing values. The original index labels
        of the surviving rows are kept.
    """
    data_interp = data.copy()
    numeric_cols = data_interp.select_dtypes(include="number").columns

    for col in numeric_cols:
        series = data_interp[col]
        is_missing = series.isna()
        if not is_missing.any():
            continue

        # Label each run of equal "missing / present" flags, then measure
        # the run length so short and long gaps can be told apart.
        run_id = is_missing.ne(is_missing.shift(fill_value=False)).cumsum()
        run_len = is_missing.groupby(run_id).transform("size")
        short_gap = is_missing & (run_len <= max_gap)

        # Interpolate everything, but only accept the result inside short
        # gaps. Leading gaps stay NaN (nothing to interpolate from); short
        # trailing gaps receive the last valid value, as pandas does by
        # default for forward linear interpolation.
        filled = series.interpolate(method="linear")
        data_interp[col] = series.mask(short_gap, filled)

    # Whatever is still missing belongs to a long gap (or a non-numeric
    # column) and is deleted.
    data_clean = data_interp.dropna()
    return data_clean


def remove_outliers(data, feature_cols):
    """Drop rows violating the 3-sigma rule (Manuscript.doc Section 4.1.3).

    For every column in ``feature_cols`` the mean and standard deviation
    are computed on the *input* data, and a row is kept only if its value
    lies inside ``[mean - 3*std, mean + 3*std]`` for all listed columns.
    Because the statistics never depend on rows removed for another
    column, the result does not depend on the order of ``feature_cols``.

    Rows with a missing value in any listed column are removed as well.
    When the standard deviation is undefined (fewer than two values) or
    zero (constant column) there is nothing to flag, so only the
    missing-value check is applied for that column.

    Args:
        data: Input DataFrame; it is not modified.
        feature_cols: Names of the numeric columns to screen.

    Returns:
        A new DataFrame containing the retained rows, with the index reset.
    """
    # Positional boolean mask, so duplicate index labels cannot misalign it.
    keep = np.ones(len(data), dtype=bool)

    for col in feature_cols:
        values = data[col]
        mean_val = values.mean()
        std_val = values.std()

        col_keep = values.notna().to_numpy()
        # NaN > 0 is False, so an undefined std skips the band test too.
        if std_val > 0:
            in_band = values.between(mean_val - 3 * std_val, mean_val + 3 * std_val)
            col_keep &= in_band.to_numpy()
        keep &= col_keep

    return data.loc[keep].reset_index(drop=True)


def extract_fluctuation_features(data, data_type="wind"):
    """Add first-difference (fluctuation) features, Manuscript.doc Section 2.1.1.

    * ``data_type == "wind"``: adds ``P_wind_fluct`` = Power_output(t) -
      Power_output(t-1).
    * ``data_type == "pv"``: adds ``P_solar_fluct`` = Power(t) - Power(t-1)
      and ``TSI_fluct`` = TSI(t) - TSI(t-1). The irradiance column is read
      from ``TSI`` (the name used by the rest of this file); if that column
      is absent the legacy name ``Total_solar_irradiance`` is used instead.

    Differences are taken between consecutive rows, so the first row (which
    has no predecessor) and any row whose new feature is NaN are dropped.
    Any other ``data_type`` returns an unchanged copy of ``data``.

    Args:
        data: Input DataFrame; it is not modified.
        data_type: ``"wind"`` or ``"pv"``.

    Returns:
        A new DataFrame with the fluctuation column(s) appended.

    Raises:
        KeyError: If a required source column is missing.
    """
    data_with_features = data.copy()
    if data_type == "wind":
        # Wind power fluctuation between consecutive samples
        # (15-minute interval according to Table 5).
        data_with_features["P_wind_fluct"] = data_with_features["Power_output"].diff(periods=1)
        # Retain rows with non-null feature values
        data_with_features = data_with_features.dropna(subset=["P_wind_fluct"])
    elif data_type == "pv":
        # The pipeline in this file names the irradiance column "TSI";
        # keep accepting the longer legacy name for older inputs.
        if "TSI" in data_with_features.columns:
            tsi_col = "TSI"
        else:
            tsi_col = "Total_solar_irradiance"
        # PV power fluctuation between consecutive samples
        data_with_features["P_solar_fluct"] = data_with_features["Power"].diff(periods=1)
        # Irradiance fluctuation between consecutive samples
        data_with_features["TSI_fluct"] = data_with_features[tsi_col].diff(periods=1)
        # Retain rows with non-null feature values
        data_with_features = data_with_features.dropna(subset=["P_solar_fluct", "TSI_fluct"])
    return data_with_features


def normalize_features(data, feature_cols):
    """Rescale the given columns to [0, 1] (Manuscript.doc Section 2.3.1).

    A fresh ``MinMaxScaler`` is fitted on ``data[feature_cols]`` and the
    scaled values replace those columns in a copy of ``data``; the input
    frame itself is left untouched.

    Args:
        data: Input DataFrame.
        feature_cols: Names of the columns to scale.

    Returns:
        Tuple ``(scaled_frame, fitted_scaler)``. The scaler is returned so
        the same transformation can be re-applied or inverted later.
    """
    scaled_frame = data.copy()
    fitted_scaler = MinMaxScaler(feature_range=(0, 1))
    # Learn the per-column min/max, then apply the mapping.
    fitted_scaler.fit(scaled_frame[feature_cols])
    scaled_values = fitted_scaler.transform(scaled_frame[feature_cols])
    scaled_frame[feature_cols] = scaled_values
    return scaled_frame, fitted_scaler


def split_dataset(data, data_type="wind"):
    """Split the data chronologically into train/validation/test (Table 8).

    * Training:   2019-01-01 .. 2019-11-30
    * Validation: 2019-12-01 .. 2019-12-31
    * Test:       2020-01-01 .. 2020-12-31

    Boundaries are inclusive and compared on the calendar date of the
    ``timestamp`` column; rows outside all three ranges are discarded.
    The input frame is not modified (no helper column is added to it), and
    the three results are independent copies.

    Args:
        data: DataFrame with a datetime ``timestamp`` column.
        data_type: Unused; kept so existing callers keep working.

    Returns:
        Tuple ``(train_data, val_data, test_data)``.
    """
    # Keep the calendar dates in a local Series instead of writing a
    # temporary "date" column into the caller's DataFrame.
    dates = data["timestamp"].dt.date

    def _between(first_day, last_day):
        # Inclusive date-range selection, returned as an independent copy.
        mask = (dates >= first_day) & (dates <= last_day)
        return data[mask].copy()

    train_data = _between(datetime(2019, 1, 1).date(), datetime(2019, 11, 30).date())
    val_data = _between(datetime(2019, 12, 1).date(), datetime(2019, 12, 31).date())
    test_data = _between(datetime(2020, 1, 1).date(), datetime(2020, 12, 31).date())

    return train_data, val_data, test_data


# ------------------- Main Process Call Example -------------------
# Pipeline per source: missing values -> outliers -> fluctuation features
# -> chronological split -> min-max scaling -> CSV export.
#
# The scaler is fitted on the TRAINING split only and then re-used for the
# validation and test splits. Fitting it on the full dataset before the
# split would leak the min/max of the evaluation periods into the training
# features. As a consequence, validation/test values may fall slightly
# outside [0, 1].
if __name__ == "__main__":
    # 1. Load data (replace with actual file paths)
    wind_data, pv_data = load_state_grid_dataset(
        wind_file_path="Manuscript_WindFarm_Data.xlsx",
        pv_file_path="Manuscript_PVPlant_Data.xlsx"
    )

    # 2. Wind farm data preprocessing
    wind_feature_cols = ["WS_cen", "Power_output"]  # Core fields in Table 5
    wind_scaled_cols = ["WS_cen", "Power_output", "P_wind_fluct"]
    wind_data_clean = handle_missing_values(wind_data)
    wind_data_clean = remove_outliers(wind_data_clean, wind_feature_cols)
    wind_data_with_features = extract_fluctuation_features(wind_data_clean, data_type="wind")
    wind_train, wind_val, wind_test = split_dataset(wind_data_with_features, data_type="wind")
    wind_train, wind_scaler = normalize_features(wind_train, feature_cols=wind_scaled_cols)
    # Work on copies so the scaled values never write into a slice.
    wind_val, wind_test = wind_val.copy(), wind_test.copy()
    for subset in (wind_val, wind_test):
        # An empty split has nothing to transform.
        if not subset.empty:
            subset[wind_scaled_cols] = wind_scaler.transform(subset[wind_scaled_cols])

    # 3. PV plant data preprocessing
    pv_feature_cols = ["TSI", "Power"]  # Core fields in Table 5
    pv_scaled_cols = ["TSI", "Power", "P_solar_fluct", "TSI_fluct"]
    pv_data_clean = handle_missing_values(pv_data)
    pv_data_clean = remove_outliers(pv_data_clean, pv_feature_cols)
    pv_data_with_features = extract_fluctuation_features(pv_data_clean, data_type="pv")
    pv_train, pv_val, pv_test = split_dataset(pv_data_with_features, data_type="pv")
    pv_train, pv_scaler = normalize_features(pv_train, feature_cols=pv_scaled_cols)
    pv_val, pv_test = pv_val.copy(), pv_test.copy()
    for subset in (pv_val, pv_test):
        if not subset.empty:
            subset[pv_scaled_cols] = pv_scaler.transform(subset[pv_scaled_cols])

    # 4. Save preprocessed data (for subsequent model use)
    wind_train.to_csv("wind_train_preprocessed.csv", index=False)
    wind_val.to_csv("wind_val_preprocessed.csv", index=False)
    wind_test.to_csv("wind_test_preprocessed.csv", index=False)
    pv_train.to_csv("pv_train_preprocessed.csv", index=False)
    pv_val.to_csv("pv_val_preprocessed.csv", index=False)
    pv_test.to_csv("pv_test_preprocessed.csv", index=False)

    print("Data preprocessing completed, training/validation/test sets saved")