import argparse
import json
import os

import joblib
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
from matplotlib import pyplot as plt
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix, classification_report

from tabnsa_model import create_tabnsa_model


# Name of the label column in test.csv; main() treats every other column as a
# feature and passes this column through the saved label encoder.
TARGET_COLUMN = "Adaptivity_Level"


def main():
	"""Evaluate a trained TabNSA model on the test split and save the results.

	Command-line arguments:
		--artifacts_dir: directory containing ``preprocessor.joblib``,
			``label_encoder.joblib`` and ``test.csv``
			(default ``./preprocessed_data``).
		--model_dir: directory containing ``history.json`` and ``model.pth``;
			every output file is written here as well (default ``./output``).

	Files written to ``model_dir``:
		test_metrics.json: accuracy, macro/weighted precision, recall and F1,
			plus the model configuration read from ``history.json``.
		confusion_matrix.png: heatmap of the confusion matrix.
		classification_report.txt: per-class report with 4 decimal digits.

	Raises:
		FileNotFoundError: if any of the expected input files is missing.
		KeyError: if ``history.json`` lacks ``model_config`` or one of the
			hyper-parameter keys read below.
	"""
	parser = argparse.ArgumentParser(description="Evaluate trained TabNSA model on test set")
	parser.add_argument("--artifacts_dir", type=str, default="./preprocessed_data")
	parser.add_argument("--model_dir", type=str, default="./output")
	args = parser.parse_args()

	# Set device
	device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
	print(f"Using device: {device}")

	art_dir = args.artifacts_dir
	model_dir = args.model_dir

	# NOTE(review): joblib.load unpickles arbitrary objects, so --artifacts_dir
	# must only ever point at trusted artifacts.
	preprocessor = joblib.load(os.path.join(art_dir, "preprocessor.joblib"))
	label_encoder = joblib.load(os.path.join(art_dir, "label_encoder.joblib"))

	# Load model configuration from history
	with open(os.path.join(model_dir, "history.json"), "r", encoding="utf-8") as f:
		history = json.load(f)

	model_config = history["model_config"]

	# Create TabNSA model with same configuration
	model = create_tabnsa_model(
		num_features=model_config["num_features"],
		num_classes=model_config["num_classes"],
		embed_dim=model_config["embed_dim"],
		num_heads=model_config["num_heads"],
		num_layers=model_config["num_layers"],
		window_size=model_config["window_size"],
		block_size=model_config["block_size"],
		mlp_ratio=model_config["mlp_ratio"],
		dropout=model_config["dropout"],
	).to(device)

	# Load trained weights. map_location remaps the stored tensors onto the
	# device chosen above, so a checkpoint saved on a GPU machine still loads
	# when only a CPU is available.
	state_dict = torch.load(os.path.join(model_dir, "model.pth"), map_location=device)
	model.load_state_dict(state_dict)
	model.eval()

	test_df = pd.read_csv(os.path.join(art_dir, "test.csv"))
	feature_cols = [c for c in test_df.columns if c != TARGET_COLUMN]

	X_test = preprocessor.transform(test_df[feature_cols])
	# Some preprocessors return a SciPy sparse matrix, which cannot be turned
	# into a dense tensor directly; densify it first.
	if hasattr(X_test, "toarray"):
		X_test = X_test.toarray()
	X_test = np.ascontiguousarray(X_test, dtype=np.float32)
	y_true = label_encoder.transform(test_df[TARGET_COLUMN])

	# Convert to PyTorch tensor
	X_test_tensor = torch.as_tensor(X_test, dtype=torch.float32).to(device)

	# Get predictions. The argmax is taken on the raw outputs; softmax is
	# monotonic, so it would not change the predicted class.
	with torch.no_grad():
		outputs = model(X_test_tensor)
		y_pred = torch.argmax(outputs, dim=1).cpu().numpy()

	acc = accuracy_score(y_true, y_pred)
	p_macro, r_macro, f1_macro, _ = precision_recall_fscore_support(y_true, y_pred, average="macro", zero_division=0)
	p_weighted, r_weighted, f1_weighted, _ = precision_recall_fscore_support(y_true, y_pred, average="weighted", zero_division=0)

	# Cast to built-in floats so the JSON output never depends on how the
	# json module treats NumPy scalar types.
	metrics = {
		"accuracy": float(acc),
		"precision_macro": float(p_macro),
		"recall_macro": float(r_macro),
		"f1_macro": float(f1_macro),
		"precision_weighted": float(p_weighted),
		"recall_weighted": float(r_weighted),
		"f1_weighted": float(f1_weighted),
		"model_config": model_config,
	}

	with open(os.path.join(model_dir, "test_metrics.json"), "w", encoding="utf-8") as f:
		json.dump(metrics, f, indent=2)

	# LabelEncoder maps classes_[i] to the integer i, so the complete set of
	# encoded labels is 0..len(classes_)-1. Passing it explicitly keeps the
	# matrix and the report aligned with the class names even when a class is
	# absent from both the test labels and the predictions.
	# NOTE(review): assumes the model's output indices follow the same
	# encoding as label_encoder — confirm against the training script.
	labels = [str(c) for c in label_encoder.classes_]
	class_ids = np.arange(len(labels))

	# Confusion matrix
	cm = confusion_matrix(y_true, y_pred, labels=class_ids)
	plt.figure(figsize=(6, 5))
	sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", xticklabels=labels, yticklabels=labels)
	plt.xlabel("Predicted")
	plt.ylabel("True")
	plt.title("TabNSA Model - Confusion Matrix")
	plt.tight_layout()
	plt.savefig(os.path.join(model_dir, "confusion_matrix.png"), dpi=200)
	plt.close()

	# Classification report
	report = classification_report(
		y_true,
		y_pred,
		labels=class_ids,
		target_names=labels,
		digits=4,
		zero_division=0,
	)
	with open(os.path.join(model_dir, "classification_report.txt"), "w", encoding="utf-8") as f:
		f.write("TabNSA Model - Classification Report\n")
		f.write("=" * 50 + "\n\n")
		f.write(report)

	print("Saved metrics and plots to", model_dir)
	print(f"Test Accuracy: {acc:.4f}")
	print(f"Test F1-Macro: {f1_macro:.4f}")


# Run the evaluation only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
	main()
