import argparse
import torch
import logging
from model import TabularTransformer
from training_utils import prepare_data, validate, DataProcessor
import pandas as pd
import os

# Configure the root logger at import time: INFO and above, each record
# rendered as "<timestamp> - <level> - <message>".
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# Module-level logger for this script.
logger = logging.getLogger(__name__)

def parse_args():
    """Build the command-line interface and parse ``sys.argv``.

    Returns:
        argparse.Namespace with ``mode`` (mandatory, 'train' or 'predict'),
        the data/model paths, the model and optimisation hyper-parameters
        (each with a default), and ``input_file`` / ``output_file`` which
        default to ``None`` when not supplied.
    """
    cli = argparse.ArgumentParser(description='Loan Repayment Prediction')

    # The run mode is the only flag argparse itself enforces.
    cli.add_argument(
        '--mode',
        type=str,
        required=True,
        choices=['train', 'predict'],
        help='Mode to run the model: train or predict',
    )

    # (flag, type, default, help) for every option that carries a default.
    # Registration order is kept so the generated --help text is unchanged.
    defaulted_options = (
        ('--data_dir', str, './data',
         'Directory containing the data files'),
        ('--model_path', str, './checkpoints/best_model.pth',
         'Path to save/load the model'),
        ('--batch_size', int, 256,
         'Batch size for training and inference'),
        ('--embedding_dim', int, 64,
         'Dimension of embeddings'),
        ('--num_heads', int, 2,
         'Number of attention heads'),
        ('--num_layers', int, 1,
         'Number of transformer layers'),
        ('--dropout', float, 0.1,
         'Dropout rate'),
        ('--learning_rate', float, 1e-4,
         'Learning rate'),
        ('--num_epochs', int, 50,
         'Number of training epochs'),
    )
    for flag, value_type, fallback, description in defaulted_options:
        cli.add_argument(flag, type=value_type, default=fallback, help=description)

    # Prediction-only paths: no default here, so they parse to None when absent.
    prediction_paths = (
        ('--input_file',
         'Path to input file for prediction (required in predict mode)'),
        ('--output_file',
         'Path to save predictions (required in predict mode)'),
    )
    for flag, description in prediction_paths:
        cli.add_argument(flag, type=str, help=description)

    return cli.parse_args()

def train(args):
    """Train a TabularTransformer and log its loss/accuracy on the test split.

    Builds the data loaders from ``args.data_dir``, constructs the model from
    the architecture flags, runs ``train_model`` with Adam and
    ``BCEWithLogitsLoss``, then reloads the checkpoint stored at
    ``args.model_path`` and evaluates it on the test loader.

    Args:
        args: Parsed command-line namespace. Reads ``data_dir``,
            ``batch_size``, ``embedding_dim``, ``num_heads``, ``num_layers``,
            ``dropout``, ``learning_rate``, ``num_epochs`` and ``model_path``.
    """
    from training_utils import train_model
    import torch.nn as nn
    import torch.optim as optim

    # Set device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logger.info(f'Using device: {device}')

    # Prepare data
    train_loader, val_loader, test_loader, categorical_cardinalities = prepare_data(
        args.data_dir, args.batch_size
    )

    # Initialize model
    model = TabularTransformer(
        num_categorical_features=len(categorical_cardinalities),
        num_numerical_features=len(DataProcessor().numerical_columns),
        categorical_cardinalities=categorical_cardinalities,
        embedding_dim=args.embedding_dim,
        num_heads=args.num_heads,
        num_layers=args.num_layers,
        dropout=args.dropout,
        num_classes=1
    ).to(device)

    # Initialize optimizer and loss function
    optimizer = optim.Adam(model.parameters(), lr=args.learning_rate)
    criterion = nn.BCEWithLogitsLoss()

    # Create the checkpoint directory. A bare filename such as "model.pth"
    # has an empty dirname, and os.makedirs('') raises FileNotFoundError, so
    # only create a directory when the path actually names one.
    checkpoint_dir = os.path.dirname(args.model_path)
    if checkpoint_dir:
        os.makedirs(checkpoint_dir, exist_ok=True)

    # Train model
    train_model(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        criterion=criterion,
        optimizer=optimizer,
        num_epochs=args.num_epochs,
        device=device,
        save_path=args.model_path
    )

    # Evaluate on test set using the checkpoint at model_path.
    # NOTE(review): presumably train_model writes its best weights to
    # save_path under the 'model_state_dict' key -- confirm in training_utils.
    # map_location keeps tensor placement explicit, matching predict().
    checkpoint = torch.load(args.model_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    test_loss, test_acc = validate(model, test_loader, criterion, device)
    logger.info(f'Test Loss: {test_loss:.4f}, Test Accuracy: {test_acc:.2f}%')

def predict(args):
    """Score the rows of a CSV file with a saved model and write the results.

    Reads ``args.input_file``, transforms it with a ``DataProcessor``, rebuilds
    the ``TabularTransformer`` from the architecture flags, loads the weights
    from ``args.model_path`` and writes the input rows plus a ``prediction``
    column (sigmoid of the model output) to ``args.output_file``.

    Args:
        args: Parsed command-line namespace. Reads ``input_file``,
            ``output_file``, ``model_path``, ``batch_size``,
            ``embedding_dim``, ``num_heads``, ``num_layers`` and ``dropout``.

    Raises:
        ValueError: If ``input_file`` or ``output_file`` is missing or empty.
    """
    if not args.input_file or not args.output_file:
        raise ValueError("input_file and output_file are required in predict mode")

    # Set device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logger.info(f'Using device: {device}')

    # Load data processor
    # NOTE(review): a freshly constructed DataProcessor is used for both
    # transform() and the encoder cardinalities below; this assumes it
    # restores its fitted encoders on construction -- confirm in
    # training_utils.
    processor = DataProcessor()

    # Load and process input data
    input_df = pd.read_csv(args.input_file)
    cat_features, num_features = processor.transform(input_df)

    # Load model. torch.load unpickles the file, so only point --model_path
    # at checkpoints from a trusted source.
    checkpoint = torch.load(args.model_path, map_location=device)
    model = TabularTransformer(
        num_categorical_features=len(processor.categorical_columns),
        num_numerical_features=len(processor.numerical_columns),
        categorical_cardinalities=[len(processor.categorical_encoders[col].classes_)
                                 for col in processor.categorical_columns],
        embedding_dim=args.embedding_dim,
        num_heads=args.num_heads,
        num_layers=args.num_layers,
        dropout=args.dropout,
        num_classes=1
    ).to(device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    # Make predictions
    predictions = []
    with torch.no_grad():
        for start in range(0, len(input_df), args.batch_size):
            stop = start + args.batch_size
            batch_cat = torch.LongTensor(cat_features[start:stop]).to(device)
            batch_num = torch.FloatTensor(num_features[start:stop]).to(device)

            outputs = model(batch_cat, batch_num)
            # Flatten to 1-D before collecting: extending with the rows of a
            # (batch, 1) array would store one-element arrays in the column
            # and serialise them as "[0.42]". reshape(-1) is a no-op for an
            # output that is already (batch,).
            probs = torch.sigmoid(outputs).cpu().numpy().reshape(-1)
            predictions.extend(probs)

    # Save predictions
    input_df['prediction'] = predictions
    input_df.to_csv(args.output_file, index=False)
    logger.info(f'Predictions saved to {args.output_file}')

def main():
    """Parse the command line and run the handler for the selected mode."""
    cli_args = parse_args()

    # Any mode other than 'train' is handled as prediction.
    handler = train if cli_args.mode == 'train' else predict
    handler(cli_args)

# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == '__main__':
    main()
