import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class FeatureEncoder(nn.Module):
    """Turn one row of tabular features into a sequence of token embeddings.

    Each categorical column gets its own ``nn.Embedding`` table and produces
    one token. All numerical columns are projected together into a single
    layer-normalised token. A learned CLS token is appended as the LAST
    position of the sequence.

    Args:
        num_categorical_features: Accepted for interface compatibility; it is
            not used to build any layer. The number of embedding tables is
            taken from ``categorical_cardinalities``.
        num_numerical_features: Width of the numerical input, i.e. the input
            size of the numerical projection.
        embedding_dim: Size of every token embedding.
        categorical_cardinalities: Iterable with the number of distinct values
            of each categorical column, in column order. May be empty.
    """

    def __init__(self, num_categorical_features, num_numerical_features, embedding_dim, categorical_cardinalities):
        super().__init__()
        self.embedding_dim = embedding_dim

        # One embedding table per categorical column.
        self.categorical_embeddings = nn.ModuleList([
            nn.Embedding(cardinality, embedding_dim)
            for cardinality in categorical_cardinalities
        ])

        # All numerical columns are mapped jointly to a single token.
        self.numerical_projection = nn.Linear(num_numerical_features, embedding_dim)
        self.layer_norm = nn.LayerNorm(embedding_dim)

        # Learned CLS token, broadcast over the batch in forward().
        self.cls_token = nn.Parameter(torch.randn(1, 1, embedding_dim))

    def forward(self, categorical_features, numerical_features):
        """Encode a batch of rows.

        Args:
            categorical_features: Integer index tensor indexed as
                ``[:, i]`` for the i-th embedding table, i.e.
                ``[batch_size, num_cat]``. Ignored (and may be ``None``) when
                the encoder was built without categorical columns.
            numerical_features: Tensor of shape
                ``[batch_size, num_numerical_features]``.

        Returns:
            Tensor of shape ``[batch_size, num_cat + 2, embedding_dim]``
            ordered as: categorical tokens, numerical token, CLS token.
        """
        # The numerical input is always consumed, so it is the reliable source
        # of the batch size even when there are no categorical columns.
        batch_size = numerical_features.size(0)

        tokens = []

        # torch.stack rejects an empty list, so only stack when at least one
        # categorical embedding table exists.
        if len(self.categorical_embeddings) > 0:
            cat_tokens = torch.stack(
                [
                    embedding_layer(categorical_features[:, i])
                    for i, embedding_layer in enumerate(self.categorical_embeddings)
                ],
                dim=1,
            )  # [batch_size, num_cat, embedding_dim]
            tokens.append(cat_tokens)

        num_token = self.layer_norm(self.numerical_projection(numerical_features))
        tokens.append(num_token.unsqueeze(1))  # [batch_size, 1, embedding_dim]

        # expand() creates a broadcast view; no per-sample copy of the parameter.
        tokens.append(self.cls_token.expand(batch_size, -1, -1))  # [batch_size, 1, embedding_dim]

        return torch.cat(tokens, dim=1)

class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Queries, keys and values are all projected from the same input. No
    attention mask and no attention dropout are applied.

    Args:
        d_model: Size of the input and output feature dimension.
        num_heads: Number of attention heads; must divide ``d_model``.
    """

    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model
        self.num_heads = num_heads
        # Per-head feature width.
        self.d_k = d_model // num_heads

        # Query / key / value projections followed by the output projection.
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def _split_heads(self, projected, batch_size):
        """Reshape ``[batch, seq, d_model]`` into ``[batch, heads, seq, d_k]``."""
        per_head = projected.view(batch_size, -1, self.num_heads, self.d_k)
        return per_head.transpose(1, 2)

    def forward(self, x):
        """Apply self-attention to ``x`` and return a tensor of the same shape.

        Args:
            x: Tensor of shape ``[batch_size, seq_len, d_model]``.
        """
        n_batch = x.size(0)

        queries = self._split_heads(self.W_q(x), n_batch)
        keys = self._split_heads(self.W_k(x), n_batch)
        values = self._split_heads(self.W_v(x), n_batch)

        # Attention logits are scaled by sqrt(d_k) before the softmax over keys.
        logits = torch.matmul(queries, keys.transpose(-2, -1)) / math.sqrt(self.d_k)
        weights = F.softmax(logits, dim=-1)

        context = torch.matmul(weights, values)

        # Merge the heads back: [batch, heads, seq, d_k] -> [batch, seq, d_model].
        merged = context.transpose(1, 2).contiguous().view(n_batch, -1, self.d_model)
        return self.W_o(merged)

class GatedTransformerBlock(nn.Module):
    """Post-norm transformer block whose attention output is scaled by a learned gate.

    The gate is a single sigmoid-activated scalar per token, computed from the
    attention output and broadcast across the feature dimension.

    Args:
        d_model: Feature size of the tokens.
        num_heads: Number of attention heads.
        dropout: Dropout probability used on both residual branches and
            inside the feed-forward network.
    """

    def __init__(self, d_model, num_heads, dropout=0.1):
        super().__init__()
        self.attention = MultiHeadAttention(d_model, num_heads)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

        # Maps each token's attention output to one gate logit.
        self.gate = nn.Linear(d_model, 1)

        # Position-wise feed-forward network with a 4x hidden expansion.
        hidden_dim = 4 * d_model
        self.ffn = nn.Sequential(
            nn.Linear(d_model, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, d_model),
        )

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Run gated self-attention followed by the feed-forward sub-layer.

        Returns a tensor with the same shape as ``x``.
        """
        attended = self.attention(x)

        # Sigmoid keeps the gate in (0, 1); its trailing dimension of size 1
        # broadcasts over the feature dimension of the attention output.
        gate_values = self.gate(attended).sigmoid()
        gated = gate_values * attended

        # First residual branch: normalisation is applied after the addition.
        x = self.norm1(x + self.dropout(gated))

        # Second residual branch around the feed-forward network.
        x = self.norm2(x + self.dropout(self.ffn(x)))
        return x

class TabularTransformer(nn.Module):
    """Transformer model for mixed categorical / numerical tabular data.

    Rows are encoded into a token sequence by ``FeatureEncoder``, passed
    through a stack of ``GatedTransformerBlock`` layers, and the final
    position of the sequence is fed to a small MLP prediction head.

    Args:
        num_categorical_features: Forwarded to ``FeatureEncoder``.
        num_numerical_features: Forwarded to ``FeatureEncoder``.
        categorical_cardinalities: Forwarded to ``FeatureEncoder``.
        embedding_dim: Token embedding size used throughout the model.
        num_heads: Attention heads per transformer block.
        num_layers: Number of stacked transformer blocks.
        dropout: Dropout probability for the blocks and the prediction head.
        num_classes: Output width of the prediction head.
    """

    def __init__(
        self,
        num_categorical_features,
        num_numerical_features,
        categorical_cardinalities,
        embedding_dim=64,
        num_heads=4,
        num_layers=3,
        dropout=0.1,
        num_classes=1,
    ):
        super().__init__()

        self.feature_encoder = FeatureEncoder(
            num_categorical_features=num_categorical_features,
            num_numerical_features=num_numerical_features,
            embedding_dim=embedding_dim,
            categorical_cardinalities=categorical_cardinalities,
        )

        # Identical, independently parameterised blocks applied in sequence.
        blocks = []
        for _ in range(num_layers):
            blocks.append(
                GatedTransformerBlock(
                    d_model=embedding_dim,
                    num_heads=num_heads,
                    dropout=dropout,
                )
            )
        self.transformer_blocks = nn.ModuleList(blocks)

        # Two-layer MLP head that halves the width before the output layer.
        head_hidden = embedding_dim // 2
        self.prediction_head = nn.Sequential(
            nn.Linear(embedding_dim, head_hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(head_hidden, num_classes),
        )

    def forward(self, categorical_features, numerical_features):
        """Return predictions of shape ``[batch_size, num_classes]``."""
        hidden = self.feature_encoder(categorical_features, numerical_features)

        for layer in self.transformer_blocks:
            hidden = layer(hidden)

        # The last sequence position is used as the CLS summary token.
        summary = hidden[:, -1, :]
        return self.prediction_head(summary)
