import argparse
import pandas as pd
import torch
import torch.nn as nn
from torch.optim import AdamW
import wandb
from pathlib import Path

from src.data.preprocessing import DataPreprocessor
from src.data.dataset import create_data_loaders
from src.models.excelformer import ExcelFormer
from src.training.trainer import Trainer
from src.utils.config import Config, get_default_config

def parse_args():
    """Read the command-line options of the ExcelFormer training script.

    Returns:
        argparse.Namespace with three string attributes:
        ``config`` (optional path to a YAML config, ``None`` if omitted),
        ``data`` (required path to the dataset CSV) and
        ``output`` (optional output directory, ``None`` if omitted).
    """
    cli = argparse.ArgumentParser(
        description='Train ExcelFormer model for ESG prediction'
    )
    # (flag, required?, help text) -- declared in the same order as the help output.
    option_specs = (
        ('--config', False, 'Path to config YAML file'),
        ('--data', True, 'Path to dataset CSV file'),
        ('--output', False, 'Path to save model and results'),
    )
    for flag, is_required, help_text in option_specs:
        cli.add_argument(flag, type=str, required=is_required, help=help_text)
    return cli.parse_args()

def _prepare_output_dir(output):
    """Create (if needed) and return the directory that receives all run artifacts.

    Args:
        output: Value of ``--output``; when falsy, ``outputs`` is used.
    """
    output_dir = Path(output) if output else Path('outputs')
    output_dir.mkdir(parents=True, exist_ok=True)
    return output_dir


def _init_wandb(config):
    """Start a wandb run in project 'esg-prediction', logging every config section."""
    wandb.init(
        project='esg-prediction',
        config={
            'model': config.model.__dict__,
            'training': config.training.__dict__,
            'data': config.data.__dict__,
        },
    )


def _build_model(config, n_numerical, n_categorical):
    """Instantiate ExcelFormer from the ``model`` section of *config*."""
    return ExcelFormer(
        n_numerical_features=n_numerical,
        n_categorical_features=n_categorical,
        embedding_dim=config.model.embedding_dim,
        n_blocks=config.model.n_blocks,
        n_heads=config.model.n_heads,
        hidden_dim=config.model.hidden_dim,
        dropout=config.model.dropout,
        gamma=config.model.gamma,
    )


def _save_feature_importance(model, test_loader, preprocessor, output_dir):
    """Compute feature importance on one test batch and write ``feature_importance.csv``.

    Raises:
        ValueError: if the number of feature names does not match the number
            of importance values returned by the model.
    """
    # Draw exactly ONE batch and take both inputs from it. The original code
    # called next(iter(test_loader)) twice, which builds two iterators. If the
    # loader shuffles, the numerical and categorical tensors would then come
    # from different rows.
    try:
        batch = next(iter(test_loader))
    except StopIteration:
        print('Test loader is empty; skipping feature importance.')
        return

    # NOTE(review): the batch tensors are passed exactly as the loader yields
    # them. If Trainer moved the model to another device, they may need
    # .to(device) first -- confirm.
    numerical_importance, categorical_importance = model.compute_feature_importance(
        batch['numerical'],
        batch['categorical'],
    )

    feature_names = preprocessor.get_feature_names()
    names = feature_names['numerical'] + feature_names['categorical']
    # detach() so .numpy() also works when the importances carry autograd history.
    importances = torch.cat(
        [numerical_importance, categorical_importance]
    ).detach().cpu().numpy()
    if len(names) != len(importances):
        raise ValueError(
            f'Feature name count ({len(names)}) does not match '
            f'importance count ({len(importances)})'
        )

    pd.DataFrame({'feature': names, 'importance': importances}).to_csv(
        output_dir / 'feature_importance.csv',
        index=False,
    )


def main():
    """Train ExcelFormer on a CSV dataset, evaluate it and export the results.

    Pipeline:
      1. Parse CLI args.
      2. Load the config from ``--config`` or fall back to defaults.
      3. Resolve the output directory (``--output`` or ``outputs/``).
      4. Optionally start a wandb run.
      5. Fit the preprocessor and build the train/val/test loaders.
      6. Train with AdamW + MSE loss, then evaluate on the test split.

    Written into the output directory:
      - ``config.yaml``
      - ``test_metrics.csv``
      - ``feature_importance.csv``
      - ``best_model.pt``, which is given to the trainer as its save path.
    """
    args = parse_args()

    config = Config.from_yaml(args.config) if args.config else get_default_config()

    output_dir = _prepare_output_dir(args.output)
    config.training.model_save_path = str(output_dir / 'best_model.pt')
    # Save the resolved config (including the save path) next to the results.
    config.to_yaml(str(output_dir / 'config.yaml'))

    use_wandb = config.training.use_wandb
    if use_wandb:
        _init_wandb(config)

    try:
        print('Loading data...')
        df = pd.read_csv(args.data)

        preprocessor = DataPreprocessor(
            n_quantiles=config.data.n_quantiles,
            random_state=config.training.random_state,
        )

        print('Preprocessing data...')
        # NOTE(review): the preprocessor is fit on the FULL dataframe, including
        # rows that create_data_loaders later assigns to val/test. If it learns
        # data statistics (n_quantiles suggests a quantile transform), that
        # leaks evaluation data into preprocessing. Consider fitting on the
        # training split only -- requires checking create_data_loaders.
        preprocessor.fit(df)

        train_loader, val_loader, test_loader = create_data_loaders(
            df,
            preprocessor,
            target_column=config.data.target_column,
            train_ratio=config.training.train_ratio,
            val_ratio=config.training.val_ratio,
            batch_size=config.training.batch_size,
            random_state=config.training.random_state,
        )

        n_numerical, n_categorical = preprocessor.get_feature_dims()

        print('Initializing model...')
        model = _build_model(config, n_numerical, n_categorical)

        optimizer = AdamW(
            model.parameters(),
            lr=config.training.learning_rate,
            weight_decay=config.training.weight_decay,
        )
        criterion = nn.MSELoss()

        trainer = Trainer(
            model=model,
            optimizer=optimizer,
            criterion=criterion,
            use_wandb=use_wandb,
        )

        print('Training model...')
        # The per-epoch metric histories returned by train() are not used here.
        trainer.train(
            train_loader=train_loader,
            val_loader=val_loader,
            n_epochs=config.training.n_epochs,
            early_stopping_patience=config.training.early_stopping_patience,
            model_save_path=config.training.model_save_path,
        )

        print('Evaluating model...')
        test_metrics = trainer.evaluate(test_loader, prefix='test')
        # Always write next to config.yaml/best_model.pt. Previously this was
        # skipped when --output was omitted, even though output_dir exists.
        pd.DataFrame([test_metrics]).to_csv(
            output_dir / 'test_metrics.csv',
            index=False,
        )

        print('Computing feature importance...')
        _save_feature_importance(model, test_loader, preprocessor, output_dir)
    finally:
        # Close the wandb run even if a step above raises.
        if use_wandb:
            wandb.finish()

    print('Done!')


if __name__ == '__main__':
    main()