import argparse
import json
import os
from typing import List, Optional, Tuple

import joblib
import numpy as np
import pandas as pd
from sklearn.compose import ColumnTransformer
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OneHotEncoder, LabelEncoder


# Name of the label column as written in the raw CSV header. main() rewrites
# this global to the underscore form after the column names are standardized.
TARGET_COLUMN = "Adaptivity Level"
# Default seed used by stratified_split for its shuffle splits.
RANDOM_STATE = 42


def standardize_column_names(columns: List[str]) -> List[str]:
	"""Return the names with outer whitespace removed and spaces turned into underscores."""
	normalized: List[str] = []
	for raw_name in columns:
		trimmed = raw_name.strip()
		normalized.append(trimmed.replace(" ", "_"))
	return normalized


def clean_dataframe(df: pd.DataFrame, target_col: Optional[str] = None) -> pd.DataFrame:
	"""Return a cleaned copy of ``df``.

	Steps, in order: strip surrounding whitespace from text cells, drop exact
	duplicate rows, keep only rows whose target is "Low", "Moderate" or "High",
	then drop every row that still has a missing value.

	Args:
		df: Raw data frame. It is left unmodified.
		target_col: Name of the label column. When omitted, the module-level
			``TARGET_COLUMN`` is read at call time.

	Returns:
		A new data frame with a fresh 0..n-1 index.

	Raises:
		KeyError: If the target column is not present in ``df``.
	"""
	if target_col is None:
		target_col = TARGET_COLUMN
	if target_col not in df.columns:
		raise KeyError(f"Target column {target_col!r} not found in data frame")

	# Work on a copy so the caller's frame is never mutated.
	df = df.copy()

	# Trim whitespace for all text cells.
	for col in df.columns:
		series = df[col]
		if pd.api.types.is_object_dtype(series) or pd.api.types.is_string_dtype(series):
			# astype(str) turns NaN into the literal "nan", which would hide the
			# missing value from the dropna() below; restore missing cells afterwards.
			df[col] = series.astype(str).str.strip().where(series.notna())

	# Drop fully duplicate rows (after stripping, so " Low" and "Low" match).
	df = df.drop_duplicates().reset_index(drop=True)

	# Remove rows with missing target.
	df = df.dropna(subset=[target_col])

	# Keep only valid target values.
	valid_targets = {"Low", "Moderate", "High"}
	df = df[df[target_col].isin(valid_targets)].copy()

	# Handle any other missing values by a simple strategy: drop rows with NA.
	df = df.dropna().reset_index(drop=True)
	return df


def stratified_split(
	df: pd.DataFrame,
	target_col: str,
	test_size: float = 0.1,
	val_size: float = 0.1,
	random_state: int = RANDOM_STATE,
) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
	"""Split ``df`` into train, validation and test frames, stratified on ``target_col``.

	The test part is held out first. The validation part is then held out of
	what remains, with ``val_size`` rescaled so it is a fraction of the
	remaining rows. Both splits use the same ``random_state``.

	Returns:
		``(train_df, val_df, test_df)``, each with a reset index.
	"""

	def _hold_out(frame: pd.DataFrame, fraction: float) -> Tuple[pd.DataFrame, pd.DataFrame]:
		# One stratified shuffle split; returns (kept rows, held-out rows).
		splitter = StratifiedShuffleSplit(n_splits=1, test_size=fraction, random_state=random_state)
		kept_idx, held_idx = next(splitter.split(frame, frame[target_col]))
		kept = frame.iloc[kept_idx].reset_index(drop=True)
		held = frame.iloc[held_idx].reset_index(drop=True)
		return kept, held

	# Hold out the test set first.
	remaining_part, test_part = _hold_out(df, test_size)

	# val_size is relative to the full data, so rescale it to the remainder.
	train_part, val_part = _hold_out(remaining_part, val_size / (1.0 - test_size))
	return train_part, val_part, test_part


def build_preprocessor(df: pd.DataFrame, target_col: str) -> Tuple[ColumnTransformer, List[str]]:
	"""Build an unfitted one-hot preprocessor for every non-target column.

	Args:
		df: Frame whose columns define the feature set; its values are not read.
		target_col: Label column to exclude from the features.

	Returns:
		``(preprocessor, feature_cols)``: the unfitted ``ColumnTransformer`` and
		the feature column names in frame order. The caller is responsible for
		fitting the transformer on training features.
	"""
	feature_cols = [c for c in df.columns if c != target_col]
	categorical_features = feature_cols  # All features are treated as categorical here

	# The ``sparse`` keyword was renamed to ``sparse_output`` in scikit-learn 1.2
	# and removed in 1.4; older releases reject the new name with a TypeError.
	try:
		categorical_transformer = OneHotEncoder(handle_unknown="ignore", sparse_output=False)
	except TypeError:
		categorical_transformer = OneHotEncoder(handle_unknown="ignore", sparse=False)

	preprocessor = ColumnTransformer(
		transformers=[
			("cat", categorical_transformer, categorical_features),
		],
		remainder="drop",
	)

	return preprocessor, feature_cols


def extract_feature_names(preprocessor: ColumnTransformer, feature_cols: List[str]) -> List[str]:
	"""Return the expanded one-hot column names from the fitted ``"cat"`` transformer."""
	# The preprocessor holds a single transformer, registered under the name "cat".
	fitted_encoder: OneHotEncoder = preprocessor.named_transformers_["cat"]
	# get_feature_names_out is available in sklearn>=1.0
	expanded_names = fitted_encoder.get_feature_names_out(feature_cols)
	return list(expanded_names)


def main():
	"""Command-line entry point: clean the CSV, split it, and save preprocessing artifacts.

	Reads ``--data_csv``, standardizes the column names, cleans the rows, makes
	an 80/10/10 stratified train/validation/test split and writes into
	``--out_dir``: ``train.csv``, ``val.csv``, ``test.csv``,
	``label_encoder.joblib``, ``preprocessor.joblib``, ``feature_names.json``
	and ``preprocessing_info.json``.

	Raises:
		KeyError: If the standardized target column is missing from the CSV.
	"""
	# clean_dataframe reads the module-level target name, so it has to be
	# aligned with the standardized column names before cleaning.
	global TARGET_COLUMN

	parser = argparse.ArgumentParser(description="Preprocess dataset and create stratified splits")
	parser.add_argument("--data_csv", type=str, required=True, help="Path to the raw CSV file")
	parser.add_argument("--out_dir", type=str, default="./preprocessed_data", help="Directory to save artifacts")
	args = parser.parse_args()

	os.makedirs(args.out_dir, exist_ok=True)

	df = pd.read_csv(args.data_csv)
	df.columns = standardize_column_names(list(df.columns))

	# Record the shape before cleaning; reading it afterwards would make
	# "original_shape" identical to "cleaned_shape" in the report.
	original_shape = df.shape

	# Normalize the target name with the same helper used for the columns.
	TARGET_COLUMN = standardize_column_names([TARGET_COLUMN])[0]
	if TARGET_COLUMN not in df.columns:
		raise KeyError(
			f"Target column {TARGET_COLUMN!r} not found; available columns: {list(df.columns)}"
		)

	df = clean_dataframe(df)

	# Report shape and class distribution
	print(f"Cleaned dataset shape: {df.shape}")
	print("Class distribution:")
	class_counts = df[TARGET_COLUMN].value_counts().sort_index()
	for class_name, count in class_counts.items():
		percentage = (count / len(df)) * 100
		print(f"  {class_name}: {count} ({percentage:.1f}%)")

	# Stratified splits 80/10/10
	train_df, val_df, test_df = stratified_split(df, TARGET_COLUMN, test_size=0.1, val_size=0.1)

	# Report split sizes
	print("\nSplit sizes:")
	print(f"  Train: {len(train_df)} samples")
	print(f"  Validation: {len(val_df)} samples")
	print(f"  Test: {len(test_df)} samples")

	# Persist cleaned splits
	train_csv = os.path.join(args.out_dir, "train.csv")
	val_csv = os.path.join(args.out_dir, "val.csv")
	test_csv = os.path.join(args.out_dir, "test.csv")
	train_df.to_csv(train_csv, index=False)
	val_df.to_csv(val_csv, index=False)
	test_df.to_csv(test_csv, index=False)

	# Label encode target using only training labels
	label_encoder = LabelEncoder()
	label_encoder.fit(train_df[TARGET_COLUMN])
	joblib.dump(label_encoder, os.path.join(args.out_dir, "label_encoder.joblib"))

	# Build and fit preprocessor on training features only
	preprocessor, feature_cols = build_preprocessor(train_df, TARGET_COLUMN)

	X_train = train_df[feature_cols]
	preprocessor_pipeline = Pipeline(steps=[("pre", preprocessor)])
	preprocessor_pipeline.fit(X_train)
	joblib.dump(preprocessor_pipeline, os.path.join(args.out_dir, "preprocessor.joblib"))

	# Save expanded feature names
	expanded_feature_names = extract_feature_names(preprocessor, feature_cols)
	with open(os.path.join(args.out_dir, "feature_names.json"), "w", encoding="utf-8") as f:
		json.dump(expanded_feature_names, f, indent=2)

	# Save preprocessing info
	preprocessing_info = {
		"original_shape": list(original_shape),
		"cleaned_shape": list(df.shape),
		"num_features": len(feature_cols),
		"expanded_features": len(expanded_feature_names),
		"target_classes": list(label_encoder.classes_),
		"feature_columns": feature_cols,
	}

	with open(os.path.join(args.out_dir, "preprocessing_info.json"), "w", encoding="utf-8") as f:
		json.dump(preprocessing_info, f, indent=2)

	print(f"\nArtifacts saved to: {args.out_dir}")
	print(f"Number of features after encoding: {len(expanded_feature_names)}")


# Run the preprocessing pipeline only when executed as a script, not on import.
if __name__ == "__main__":
	main()
