from sklearn.metrics import precision_score, recall_score, f1_score
from math import sqrt

# Function to compute mean ± SE
def summarize_metric(values):
    """Format a collection of scores as "mean ± standard error" for LaTeX.

    Args:
        values: Sized collection of numeric scores (here: one score per
            cross-validation fold).

    Returns:
        A string with the mean and the standard error of the mean, each
        rendered with three decimals and separated by the LaTeX ``pm``
        symbol in inline math mode.
    """
    n_obs = len(values)
    center = np.mean(values)
    # ddof=1 -> sample standard deviation (divides by n - 1).
    sample_sd = np.std(values, ddof=1)
    std_err = sample_sd / sqrt(n_obs)
    return f"{center:.3f} $\\pm$ {std_err:.3f}"

# Cross-validated metric summaries, keyed by the min_comments threshold.
all_results = {}

# Columns of the filtered frame that form the feature matrix.
emotion_features = [
    "anger", "anticipation", "disgust", "fear",
    "joy", "sadness", "surprise", "trust",
]

# Base scorers, in the order their results are stored.
scorers = {
    "f1": f1_score,
    "precision": precision_score,
    "recall": recall_score,
}

# Result-key suffix -> keyword arguments passed to every scorer.
#   ""     : micro average over all test samples
#   " MAp" : class 0 only (the "majority" scores)
#   " MIp" : class 1 only (the "minority" scores)
# NOTE(review): for single-label predictions, micro-averaged precision,
# recall and F1 are expected to coincide with accuracy, so the three
# unsuffixed entries should carry identical numbers -- confirm this is intended.
score_variants = {
    "": {"average": "micro"},
    " MAp": {"labels": [0], "average": "macro", "zero_division": 0},
    " MIp": {"labels": [1], "average": "macro", "zero_division": 0},
}

for min_comments in [8, 5, 10]:
    print(f"\n=== Processing {min_comments} min comments ===")

    # Keep rows meeting the comment threshold that have n_emo >= 1.
    df = ur_df.query("n_comments >= @min_comments and n_emo >= 1")
    # Binarize the target: any positive is_questionable value becomes 1.
    Y = np.where(df["is_questionable"].astype(int) > 0, 1, 0)
    X = df[emotion_features]

    skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=1518)

    # One list of per-fold scores for every (variant, scorer) combination.
    fold_scores = {
        metric + suffix: []
        for suffix in score_variants
        for metric in scorers
    }

    for train_idx, test_idx in skf.split(X, Y):
        X_train, y_train = X.iloc[train_idx], Y[train_idx]
        X_test, y_test = X.iloc[test_idx], Y[test_idx]

        # Oversample the training split only; the test split is left as-is.
        X_train_res, y_train_res = SMOTE(random_state=1518).fit_resample(X_train, y_train)

        model = RandomForestClassifier(random_state=1518)
        model.fit(X_train_res, y_train_res)

        y_pred = model.predict(X_test)

        for suffix, score_kwargs in score_variants.items():
            for metric, scorer in scorers.items():
                fold_scores[metric + suffix].append(
                    scorer(y_test, y_pred, **score_kwargs)
                )

    # Collapse each list of fold scores into a formatted "mean ± SE" string.
    all_results[min_comments] = {
        key: summarize_metric(scores) for key, scores in fold_scores.items()
    }

# --- Build LaTeX Table ---
# Metric rows, top to bottom. Each name is looked up in all_results[threshold].
rows = [
    "f1", "precision", "recall",
    "f1 MAp", "precision MAp", "recall MAp",
    "f1 MIp", "precision MIp", "recall MIp",
]

# Table columns, left to right. The column spec, the header row and the body
# cells are all derived from this one list so they cannot drift apart.
table_min_comments = [5, 8, 10]

header_cells = " & ".join(str(mc) for mc in table_min_comments)

latex = []
latex.append("\\begin{table}[H]")
latex.append("\\small\\sf\\centering")
# One left-aligned label column plus one centered column per threshold.
latex.append("\\begin{tabular}{l" + "c" * len(table_min_comments) + "}")
latex.append("\\toprule")
latex.append("\\rowcolor{gray!10}min comments & " + header_cells + " \\\\")
latex.append("\\midrule")

for i, metric in enumerate(rows):
    cells = [metric] + [str(all_results[mc][metric]) for mc in table_min_comments]
    row = " & ".join(cells) + " \\\\"
    # Zebra striping: the header is shaded, so shade every odd-indexed body row.
    if i % 2 == 0:
        latex.append(row)
    else:
        latex.append("\\rowcolor{gray!10}" + row)

latex.append("\\bottomrule")
latex.append("\\end{tabular}")
latex.append("\\caption{\\rev{\\textbf{Average Evaluation Metrics Across 5-Fold Cross-Validation.} "
             "The table reports the mean performance of five SMOTE Random Forest models, "
             "each trained and tested on a distinct fold. "
             "While overall accuracy remains relatively stable, the performance on the minority class (MIp) "
             "is consistently poor, with low recall and precision. "
             "This highlights the difficulty of learning meaningful decision boundaries under severe class imbalance, "
             "despite the application of oversampling techniques.}}")
latex.append("\\label{tab:model_performance}")
latex.append("\\end{table}")

# Join once and reuse the same text for the file and the console echo.
table_text = "\n".join(latex)

# Save (explicit encoding so the output does not depend on the locale default)
with open("cv_metrics_table.tex", "w", encoding="utf-8") as f:
    f.write(table_text)

print(table_text)
