import argparse
import json
import os
from typing import Dict, Tuple

import joblib
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.utils.class_weight import compute_class_weight
from torch.utils.data import DataLoader, TensorDataset

from tabnsa_model import create_tabnsa_model


# Name of the label column in train.csv / val.csv; main() treats every other
# column as a model input feature and passes this one to the label encoder.
TARGET_COLUMN = "Adaptivity_Level"
# Seed that main() applies to the torch (CPU and, when available, CUDA) and
# NumPy random number generators.
RANDOM_STATE = 42


def train_epoch(model: nn.Module, dataloader: DataLoader, criterion: nn.Module,
				optimizer: optim.Optimizer, device: torch.device) -> Tuple[float, float]:
	"""Run one optimisation pass over ``dataloader`` and report epoch metrics.

	The model is switched to training mode, and for every batch the gradients
	are reset, the loss is back-propagated and the optimizer takes one step.

	Args:
		model: Network to train; called as ``model(inputs)``.
		dataloader: Iterable yielding ``(inputs, targets)`` batches.
		criterion: Loss callable invoked as ``criterion(outputs, targets)``.
			Assumed to return the batch-mean loss (the default reduction of
			``nn.CrossEntropyLoss``) -- confirm if a different reduction is used.
		optimizer: Optimizer that updates the parameters of ``model``.
		device: Device every batch is moved to before the forward pass.

	Returns:
		``(epoch_loss, epoch_acc)``: the per-sample average loss and the
		fraction of samples whose arg-max prediction equals the target.

	Raises:
		ZeroDivisionError: If ``dataloader`` yields no samples.
	"""
	model.train()
	loss_sum = 0.0  # accumulates batch-mean loss * batch size
	correct = 0
	total = 0

	for inputs, targets in dataloader:
		inputs, targets = inputs.to(device), targets.to(device)
		batch_size = targets.size(0)

		optimizer.zero_grad()
		outputs = model(inputs)
		loss = criterion(outputs, targets)
		loss.backward()
		optimizer.step()

		# Weight each batch by its sample count so that a smaller final batch
		# does not contribute as much as a full one to the epoch average.
		loss_sum += loss.item() * batch_size
		predicted = outputs.argmax(dim=1)
		total += batch_size
		correct += (predicted == targets).sum().item()

	epoch_loss = loss_sum / total
	epoch_acc = correct / total
	return epoch_loss, epoch_acc


def validate_epoch(model: nn.Module, dataloader: DataLoader, criterion: nn.Module,
				  device: torch.device) -> Tuple[float, float]:
	"""Evaluate ``model`` on ``dataloader`` without updating its parameters.

	The model is switched to evaluation mode and every forward pass runs under
	``torch.no_grad()``, so no gradients are tracked.

	Args:
		model: Network to evaluate; called as ``model(inputs)``.
		dataloader: Iterable yielding ``(inputs, targets)`` batches.
		criterion: Loss callable invoked as ``criterion(outputs, targets)``.
			Assumed to return the batch-mean loss (the default reduction of
			``nn.CrossEntropyLoss``) -- confirm if a different reduction is used.
		device: Device every batch is moved to before the forward pass.

	Returns:
		``(epoch_loss, epoch_acc)``: the per-sample average loss and the
		fraction of samples whose arg-max prediction equals the target.

	Raises:
		ZeroDivisionError: If ``dataloader`` yields no samples.
	"""
	model.eval()
	loss_sum = 0.0  # accumulates batch-mean loss * batch size
	correct = 0
	total = 0

	with torch.no_grad():
		for inputs, targets in dataloader:
			inputs, targets = inputs.to(device), targets.to(device)
			batch_size = targets.size(0)

			outputs = model(inputs)
			loss = criterion(outputs, targets)

			# Weight each batch by its sample count so that a smaller final
			# batch does not skew the epoch average.
			loss_sum += loss.item() * batch_size
			predicted = outputs.argmax(dim=1)
			total += batch_size
			correct += (predicted == targets).sum().item()

	epoch_loss = loss_sum / total
	epoch_acc = correct / total
	return epoch_loss, epoch_acc


def _to_dense_float32(matrix) -> np.ndarray:
	"""Return ``matrix`` as a dense float32 ndarray, densifying sparse input."""
	# Objects exposing ``toarray`` (e.g. SciPy sparse matrices, which a
	# preprocessor may emit) cannot be converted to a torch tensor directly.
	if hasattr(matrix, "toarray"):
		matrix = matrix.toarray()
	return np.asarray(matrix, dtype=np.float32)


def _class_weight_tensor(y_train: np.ndarray, num_classes: int,
						 device: torch.device) -> torch.Tensor:
	"""Build balanced class weights with exactly one entry per encoded class."""
	present = np.unique(y_train)
	balanced = compute_class_weight(class_weight="balanced", classes=present, y=y_train)
	# CrossEntropyLoss needs one weight per output class. Classes missing from
	# the training split get a neutral weight of 1.0; when every class is
	# present this is identical to the plain balanced weights.
	weights = np.ones(num_classes, dtype=np.float64)
	weights[present] = balanced
	return torch.FloatTensor(weights).to(device)


def main():
	"""Train a TabNSA classifier from preprocessed artifacts and save the results.

	Reads ``preprocessor.joblib``, ``label_encoder.joblib``, ``train.csv`` and
	``val.csv`` from ``--artifacts_dir``, trains with class-weighted
	cross-entropy, AdamW and ReduceLROnPlateau, and stops early once the
	validation loss has not improved for ``--patience`` consecutive epochs.

	Writes to ``--output_dir``:
		* ``best_model.pth`` -- state dict at the lowest validation loss.
		* ``model.pth`` -- the final state dict (the best checkpoint when one
		  was written during this run, otherwise the last-epoch weights).
		* ``history.json`` -- per-epoch metrics and the model configuration.
	"""
	parser = argparse.ArgumentParser(description="Train TabNSA model on preprocessed data")
	parser.add_argument("--artifacts_dir", type=str, default="./preprocessed_data")
	parser.add_argument("--output_dir", type=str, default="./output")
	parser.add_argument("--epochs", type=int, default=100)
	parser.add_argument("--batch_size", type=int, default=32)
	parser.add_argument("--learning_rate", type=float, default=1e-3)
	parser.add_argument("--embed_dim", type=int, default=128)
	parser.add_argument("--num_heads", type=int, default=8)
	parser.add_argument("--num_layers", type=int, default=3)
	parser.add_argument("--window_size", type=int, default=8)
	parser.add_argument("--block_size", type=int, default=4)
	parser.add_argument("--mlp_ratio", type=float, default=4.0)
	parser.add_argument("--dropout", type=float, default=0.1)
	parser.add_argument("--patience", type=int, default=10,
						help="Epochs without validation-loss improvement before early stopping")
	args = parser.parse_args()

	# Create output directory
	os.makedirs(args.output_dir, exist_ok=True)

	# Set device
	device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
	print(f"Using device: {device}")

	# Set random seeds
	torch.manual_seed(RANDOM_STATE)
	np.random.seed(RANDOM_STATE)
	if torch.cuda.is_available():
		torch.cuda.manual_seed(RANDOM_STATE)

	# NOTE(review): joblib.load unpickles arbitrary objects; only point
	# --artifacts_dir at artifacts produced by a trusted preprocessing run.
	art_dir = args.artifacts_dir
	preprocessor = joblib.load(os.path.join(art_dir, "preprocessor.joblib"))
	label_encoder = joblib.load(os.path.join(art_dir, "label_encoder.joblib"))

	train_df = pd.read_csv(os.path.join(art_dir, "train.csv"))
	val_df = pd.read_csv(os.path.join(art_dir, "val.csv"))

	feature_cols = [c for c in train_df.columns if c != TARGET_COLUMN]

	X_train = _to_dense_float32(preprocessor.transform(train_df[feature_cols]))
	y_train = label_encoder.transform(train_df[TARGET_COLUMN])

	X_val = _to_dense_float32(preprocessor.transform(val_df[feature_cols]))
	y_val = label_encoder.transform(val_df[TARGET_COLUMN])

	# Convert to PyTorch tensors
	X_train_tensor = torch.FloatTensor(X_train)
	y_train_tensor = torch.LongTensor(y_train)
	X_val_tensor = torch.FloatTensor(X_val)
	y_val_tensor = torch.LongTensor(y_val)

	# Create data loaders
	train_dataset = TensorDataset(X_train_tensor, y_train_tensor)
	val_dataset = TensorDataset(X_val_tensor, y_val_tensor)

	train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
	val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False)

	num_features = X_train.shape[1]
	num_classes = len(label_encoder.classes_)

	# Class weights to handle imbalance (one entry per encoder class)
	class_weights = _class_weight_tensor(y_train, num_classes, device)

	# Create TabNSA model
	model = create_tabnsa_model(
		num_features=num_features,
		num_classes=num_classes,
		embed_dim=args.embed_dim,
		num_heads=args.num_heads,
		num_layers=args.num_layers,
		window_size=args.window_size,
		block_size=args.block_size,
		mlp_ratio=args.mlp_ratio,
		dropout=args.dropout
	).to(device)

	print(f"TabNSA Model created with {sum(p.numel() for p in model.parameters()):,} parameters")
	print(f"Input features: {num_features}, Output classes: {num_classes}")

	# Initialize loss, optimizer, and scheduler
	criterion = nn.CrossEntropyLoss(weight=class_weights)
	optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate, weight_decay=1e-4)
	scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5, min_lr=1e-5)

	# Training loop
	best_model_path = os.path.join(args.output_dir, "best_model.pth")
	best_val_loss = float('inf')
	checkpoint_saved = False  # True once this run has written best_model.pth
	patience_counter = 0
	patience = args.patience

	train_losses, train_accs = [], []
	val_losses, val_accs = [], []

	for epoch in range(args.epochs):
		# Training
		train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)

		# Validation
		val_loss, val_acc = validate_epoch(model, val_loader, criterion, device)

		# Learning rate scheduling
		scheduler.step(val_loss)

		# Store metrics
		train_losses.append(train_loss)
		train_accs.append(train_acc)
		val_losses.append(val_loss)
		val_accs.append(val_acc)

		print(f"Epoch {epoch+1}/{args.epochs}:")
		print(f"  Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")
		print(f"  Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")
		print(f"  Learning Rate: {optimizer.param_groups[0]['lr']:.6f}")

		# Early stopping
		if val_loss < best_val_loss:
			best_val_loss = val_loss
			patience_counter = 0
			# Save best model
			torch.save(model.state_dict(), best_model_path)
			checkpoint_saved = True
		else:
			patience_counter += 1
			if patience_counter >= patience:
				print(f"Early stopping at epoch {epoch+1}")
				break

	# Load best model. Only reload a checkpoint written by this run: with zero
	# epochs or a validation loss that never improves (e.g. NaN) none exists,
	# and a file left by an earlier run could belong to a different model.
	if checkpoint_saved:
		model.load_state_dict(torch.load(best_model_path, map_location=device))
	else:
		print("Warning: no best checkpoint was written during this run; keeping the current model weights")

	# Save final model
	torch.save(model.state_dict(), os.path.join(args.output_dir, "model.pth"))

	# Save training history
	history = {
		"train_loss": train_losses,
		"train_acc": train_accs,
		"val_loss": val_losses,
		"val_acc": val_accs,
		"model_config": {
			"embed_dim": args.embed_dim,
			"num_heads": args.num_heads,
			"num_layers": args.num_layers,
			"window_size": args.window_size,
			"block_size": args.block_size,
			"mlp_ratio": args.mlp_ratio,
			"dropout": args.dropout,
			"num_features": num_features,
			"num_classes": num_classes
		}
	}

	history_path = os.path.join(args.output_dir, "history.json")
	with open(history_path, "w", encoding="utf-8") as f:
		json.dump(history, f, indent=2)

	print("Saved model and history to", args.output_dir)


# Run the training pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
	main()
