import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
from sklearn.preprocessing import StandardScaler, LabelEncoder
import os
from tqdm import tqdm
import logging
from model import TabularTransformer

# Configure root logging for this training script (timestamped INFO-level
# messages) and create the module-level logger used by the functions below.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(name=__name__)

class LoanDataset(Dataset):
    """Map-style dataset over pre-encoded loan features.

    Each item is a dict with keys ``'categorical'`` (int64 codes),
    ``'numerical'`` (float32 features) and ``'target'`` (float32 label) for one
    row.

    Args:
        categorical_features: Array-like of integer category codes, one row
            per sample.
        numerical_features: Array-like of numeric features, one row per sample.
        targets: Array-like of labels, one per sample.

    Raises:
        ValueError: If the three inputs do not have the same number of rows.
    """

    def __init__(self, categorical_features, numerical_features, targets):
        # torch.tensor with an explicit dtype replaces the legacy typed
        # constructors (torch.LongTensor / torch.FloatTensor). np.asarray first
        # avoids torch's slow path for lists of ndarrays.
        self.categorical_features = torch.tensor(
            np.asarray(categorical_features), dtype=torch.long
        )
        self.numerical_features = torch.tensor(
            np.asarray(numerical_features), dtype=torch.float32
        )
        self.targets = torch.tensor(np.asarray(targets), dtype=torch.float32)

        # Fail fast: a length mismatch would otherwise only surface as an
        # IndexError inside a DataLoader worker.
        n_rows = len(self.targets)
        if (len(self.categorical_features) != n_rows
                or len(self.numerical_features) != n_rows):
            raise ValueError(
                'categorical_features, numerical_features and targets must have '
                f'the same number of rows, got {len(self.categorical_features)}, '
                f'{len(self.numerical_features)} and {n_rows}'
            )

    def __len__(self):
        return len(self.targets)

    def __getitem__(self, idx):
        return {
            'categorical': self.categorical_features[idx],
            'numerical': self.numerical_features[idx],
            'target': self.targets[idx]
        }

class DataProcessor:
    """Encode the loan feature columns into model-ready arrays.

    Each categorical column gets its own ``LabelEncoder``. All numerical
    columns share one ``StandardScaler``. Fit on the training split with
    :meth:`fit_transform`, then apply the same encoding to other splits with
    :meth:`transform`.
    """

    # Placeholder for NaN categorical values and for categories not seen while
    # fitting. Every encoder is fitted with it, so unseen values in val/test
    # still map to a valid index (LabelEncoder.transform would otherwise raise).
    MISSING_TOKEN = 'missing'

    def __init__(self):
        self.categorical_encoders = {}
        self.numerical_scaler = StandardScaler()
        self.categorical_columns = [
            'term', 'grade', 'sub_grade', 'emp_length', 'home_ownership',
            'verification_status', 'purpose', 'initial_list_status'
        ]
        self.numerical_columns = [
            'loan_amnt', 'int_rate', 'annual_inc', 'dti', 'open_acc',
            'pub_rec', 'revol_bal', 'revol_util', 'total_acc',
            'mort_acc', 'pub_rec_bankruptcies'
        ]

    def _categorical_values(self, df, col):
        """Return column ``col`` as strings, with NaN replaced by the missing token."""
        # astype(str) keeps LabelEncoder from trying to sort a mix of str and
        # float when a column holds non-string values next to the placeholder.
        return df[col].fillna(self.MISSING_TOKEN).astype(str)

    def _numerical_values(self, df):
        """Return the numerical columns with NaN replaced by 0 (before scaling)."""
        # NOTE(review): 0 is in raw units (e.g. annual_inc == 0), not the train
        # mean; consider imputing with a train-set statistic instead.
        return df[self.numerical_columns].fillna(0)

    def _encode_categoricals(self, df):
        """Label-encode every categorical column with the fitted encoders.

        Values missing from an encoder's vocabulary are mapped to the missing
        token instead of raising.
        """
        encoded_columns = []
        for col in self.categorical_columns:
            encoder = self.categorical_encoders[col]
            values = self._categorical_values(df, col)
            values = values.where(values.isin(encoder.classes_), self.MISSING_TOKEN)
            encoded_columns.append(encoder.transform(values))
        return np.column_stack(encoded_columns)

    def fit_transform(self, df):
        """Fit the encoders and scaler on ``df`` and return its encoded features.

        Always refits, so calling it again on new data replaces previous
        vocabularies and scaling statistics.

        Returns:
            Tuple ``(categorical, numerical)``. ``categorical`` has one integer
            column per entry of ``categorical_columns``. ``numerical`` holds the
            standardized ``numerical_columns``.
        """
        for col in self.categorical_columns:
            values = self._categorical_values(df, col)
            encoder = LabelEncoder()
            # Reserve the missing token so transform() can always fall back to it.
            encoder.fit(
                pd.concat([values, pd.Series([self.MISSING_TOKEN])], ignore_index=True)
            )
            self.categorical_encoders[col] = encoder

        numerical_features = self.numerical_scaler.fit_transform(
            self._numerical_values(df)
        )
        return self._encode_categoricals(df), numerical_features

    def transform(self, df):
        """Encode ``df`` with the encoders and scaler fitted by :meth:`fit_transform`.

        Returns:
            Tuple ``(categorical, numerical)`` laid out as in :meth:`fit_transform`.
        """
        numerical_features = self.numerical_scaler.transform(
            self._numerical_values(df)
        )
        return self._encode_categoricals(df), numerical_features

def load_data(data_dir):
    """Read the three dataset splits stored as CSV files in ``data_dir``.

    Args:
        data_dir: Directory containing ``train.csv``, ``val.csv`` and
            ``test.csv``.

    Returns:
        A ``(train_df, val_df, test_df)`` tuple of DataFrames, in that order.
    """
    split_files = ('train.csv', 'val.csv', 'test.csv')
    return tuple(
        pd.read_csv(os.path.join(data_dir, file_name)) for file_name in split_files
    )

def train_epoch(model, train_loader, criterion, optimizer, device):
    """Run one training epoch over ``train_loader``.

    Args:
        model: Module called as ``model(categorical, numerical)``.
        train_loader: Iterable of dict batches with ``'categorical'``,
            ``'numerical'`` and ``'target'`` tensors.
        criterion: Loss taking ``(outputs, targets)`` of equal shape.
        optimizer: Optimizer stepping ``model``'s parameters.
        device: Device the batch tensors are moved to.

    Returns:
        Tuple ``(mean_batch_loss, accuracy_percent)``. The loss is the mean of
        the per-batch losses.
    """
    model.train()
    total_loss = 0.0
    correct = 0
    total = 0

    progress_bar = tqdm(train_loader, desc='Training')
    for batches_seen, batch in enumerate(progress_bar, start=1):
        categorical = batch['categorical'].to(device)
        numerical = batch['numerical'].to(device)
        targets = batch['target'].to(device)

        optimizer.zero_grad()
        # reshape(-1) rather than squeeze(): squeeze() turns a size-1 batch into
        # a 0-d tensor whose shape no longer matches ``targets``.
        outputs = model(categorical, numerical).reshape(-1)
        loss = criterion(outputs, targets)

        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        # NOTE(review): the 0.5 threshold assumes the model outputs
        # probabilities (e.g. a final sigmoid); with logits use 0 — confirm.
        predicted = (outputs > 0.5).float()
        total += targets.size(0)
        correct += (predicted == targets).sum().item()

        # Count batches ourselves: tqdm only refreshes ``progress_bar.n`` on
        # display updates, so it can lag behind the batches processed so far.
        progress_bar.set_postfix({
            'loss': total_loss / batches_seen,
            'acc': 100. * correct / total
        })

    return total_loss / len(train_loader), 100. * correct / total

def validate(model, val_loader, criterion, device):
    """Evaluate ``model`` on ``val_loader`` without tracking gradients.

    Args:
        model: Module called as ``model(categorical, numerical)``.
        val_loader: Iterable of dict batches with ``'categorical'``,
            ``'numerical'`` and ``'target'`` tensors.
        criterion: Loss taking ``(outputs, targets)`` of equal shape.
        device: Device the batch tensors are moved to.

    Returns:
        Tuple ``(mean_batch_loss, accuracy_percent)``.
    """
    model.eval()
    total_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for batch in tqdm(val_loader, desc='Validation'):
            categorical = batch['categorical'].to(device)
            numerical = batch['numerical'].to(device)
            targets = batch['target'].to(device)

            # reshape(-1) keeps a size-1 batch 1-D so it matches ``targets``;
            # squeeze() would produce a 0-d tensor there.
            outputs = model(categorical, numerical).reshape(-1)
            loss = criterion(outputs, targets)

            total_loss += loss.item()
            # NOTE(review): 0.5 threshold assumes probability outputs — confirm.
            predicted = (outputs > 0.5).float()
            total += targets.size(0)
            correct += (predicted == targets).sum().item()

    return total_loss / len(val_loader), 100. * correct / total

def train_model(
    model,
    train_loader,
    val_loader,
    criterion,
    optimizer,
    num_epochs,
    device,
    save_path
):
    """Train for ``num_epochs`` and checkpoint whenever validation accuracy improves.

    The checkpoint written to ``save_path`` is a dict with ``'epoch'``,
    ``'model_state_dict'``, ``'optimizer_state_dict'`` and ``'val_acc'``.

    Returns:
        The best validation accuracy seen (``-inf`` if ``num_epochs`` is 0).
    """
    # Start below any achievable accuracy so the first epoch always writes a
    # checkpoint, even if its accuracy is 0%.
    best_val_acc = float('-inf')

    for epoch in range(num_epochs):
        logger.info(f'Epoch {epoch+1}/{num_epochs}')

        # Training
        train_loss, train_acc = train_epoch(
            model, train_loader, criterion, optimizer, device
        )

        # Validation
        val_loss, val_acc = validate(model, val_loader, criterion, device)

        logger.info(
            f'Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%, '
            f'Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%'
        )

        # Save best model
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_acc': val_acc,
            }, save_path)
            logger.info(f'Saved best model with validation accuracy: {val_acc:.2f}%')

    return best_val_acc

# Binary target: 1 for loans that were repaid, 0 for charged-off loans.
_TARGET_MAPPING = {'Fully Paid': 1, 'Charged Off': 0}


def _split_targets(df, split_name):
    """Keep only rows whose ``loan_status`` is in ``_TARGET_MAPPING`` and encode it.

    ``Series.map`` turns any other status into NaN. Such rows would otherwise
    become NaN training targets, so they are dropped with a warning.

    Returns:
        Tuple ``(filtered_df, targets)`` where ``targets`` is an int64 array
        aligned with ``filtered_df``.

    Raises:
        ValueError: If no row of the split has a recognised ``loan_status``.
    """
    mapped = df['loan_status'].map(_TARGET_MAPPING)
    keep = mapped.notna().to_numpy()
    dropped = int((~keep).sum())
    if dropped:
        logger.warning(
            '%s split: dropping %d rows with unrecognised loan_status',
            split_name, dropped,
        )
    if not keep.any():
        raise ValueError(
            f'{split_name} split has no rows with loan_status in '
            f'{sorted(_TARGET_MAPPING)}'
        )
    return df[keep], mapped.to_numpy()[keep].astype(np.int64)


def prepare_data(data_dir, batch_size=32, num_workers=4):
    """Load, encode and wrap the train/val/test splits in DataLoaders.

    Encoders and scaler are fitted on the training split only and reused for
    val/test.

    Args:
        data_dir: Directory holding ``train.csv``, ``val.csv`` and ``test.csv``.
        batch_size: Batch size for all three loaders.
        num_workers: DataLoader worker processes per loader.

    Returns:
        Tuple ``(train_loader, val_loader, test_loader, categorical_cardinalities)``.
        ``categorical_cardinalities`` lists the vocabulary size of each
        categorical column, in column order.
    """
    train_df, val_df, test_df = load_data(data_dir)

    # Drop unlabeled rows before encoding so features and targets stay aligned.
    train_df, train_targets = _split_targets(train_df, 'train')
    val_df, val_targets = _split_targets(val_df, 'val')
    test_df, test_targets = _split_targets(test_df, 'test')

    processor = DataProcessor()
    train_cat, train_num = processor.fit_transform(train_df)
    val_cat, val_num = processor.transform(val_df)
    test_cat, test_num = processor.transform(test_df)

    def make_loader(cat, num, targets, shuffle):
        """Wrap one encoded split in a DataLoader with the shared settings."""
        return DataLoader(
            LoanDataset(cat, num, targets),
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
        )

    train_loader = make_loader(train_cat, train_num, train_targets, shuffle=True)
    val_loader = make_loader(val_cat, val_num, val_targets, shuffle=False)
    test_loader = make_loader(test_cat, test_num, test_targets, shuffle=False)

    categorical_cardinalities = [
        len(processor.categorical_encoders[col].classes_)
        for col in processor.categorical_columns
    ]

    return train_loader, val_loader, test_loader, categorical_cardinalities
