import torch
import torch.nn as nn
from typing import Dict, Optional, Tuple
from .layers import GatedLinearUnit, ExcelFormerBlock

class ExcelFormer(nn.Module):
    """ExcelFormer model for tabular data.

    Numerical and categorical inputs are embedded by two separate
    ``GatedLinearUnit`` modules, concatenated along the last dimension,
    passed through a stack of ``ExcelFormerBlock`` modules as a length-1
    sequence, and finally mapped to a single output by an MLP head.
    """

    def __init__(
        self,
        n_numerical_features: int,
        n_categorical_features: int,
        embedding_dim: int = 64,
        n_blocks: int = 4,
        n_heads: int = 32,
        hidden_dim: int = 256,
        dropout: float = 0.1,
        gamma: float = 1e-4
    ):
        """Initialize ExcelFormer model.

        Args:
            n_numerical_features: Number of numerical features
            n_categorical_features: Number of categorical features
            embedding_dim: Dimension of feature embeddings
            n_blocks: Number of ExcelFormer blocks
            n_heads: Number of attention heads
            hidden_dim: Hidden dimension for feedforward networks
            dropout: Dropout rate
            gamma: Initialization scale for attention weights
        """
        super().__init__()

        # Feature embeddings
        self.numerical_embedding = GatedLinearUnit(
            n_numerical_features,
            embedding_dim,
            dropout
        )

        self.categorical_embedding = GatedLinearUnit(
            n_categorical_features,
            embedding_dim,
            dropout
        )

        # ExcelFormer blocks
        self.blocks = nn.ModuleList([
            ExcelFormerBlock(
                embedding_dim * 2,  # Combined numerical and categorical embeddings
                n_heads=n_heads,
                hidden_dim=hidden_dim,
                dropout=dropout,
                gamma=gamma
            )
            for _ in range(n_blocks)
        ])

        # Prediction head
        self.prediction_head = nn.Sequential(
            nn.Linear(embedding_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, 1)
        )

    def _predict_from_embeddings(
        self,
        numerical_emb: torch.Tensor,
        categorical_emb: torch.Tensor,
        mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Run the blocks and the prediction head on precomputed embeddings.

        Shared by ``forward`` and ``compute_feature_importance`` so both
        follow exactly the same computation path.
        """
        # Combine embeddings
        x = torch.cat([numerical_emb, categorical_emb], dim=-1)

        # Add sequence dimension for transformer blocks
        x = x.unsqueeze(1)  # (batch_size, 1, embedding_dim * 2)

        # Process through ExcelFormer blocks
        for block in self.blocks:
            x = block(x, mask)

        # Remove sequence dimension
        x = x.squeeze(1)  # (batch_size, embedding_dim * 2)

        return self.prediction_head(x)

    def forward(
        self,
        numerical_features: torch.Tensor,
        categorical_features: torch.Tensor,
        mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Forward pass.

        Args:
            numerical_features: Numerical features tensor of shape (batch_size, n_numerical_features)
            categorical_features: Categorical features tensor of shape (batch_size, n_categorical_features)
            mask: Optional attention mask, forwarded unchanged to every block

        Returns:
            Predictions tensor of shape (batch_size, 1)
        """
        # Embed features
        numerical_emb = self.numerical_embedding(numerical_features)
        categorical_emb = self.categorical_embedding(categorical_features)

        return self._predict_from_embeddings(numerical_emb, categorical_emb, mask)

    def compute_feature_importance(
        self,
        numerical_features: torch.Tensor,
        categorical_features: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute feature importance scores.

        Importance is the absolute gradient of the summed predictions with
        respect to each embedding, averaged over the batch dimension. The
        scores are therefore indexed by embedding dimension, not by raw
        input column.

        The call has no side effects on the model: gradients are taken only
        with respect to the two embedding tensors, so no parameter ``.grad``
        is created or modified. It also works when invoked inside a
        ``torch.no_grad()`` context. The module's current train/eval mode is
        left as is.

        Args:
            numerical_features: Numerical features tensor
            categorical_features: Categorical features tensor

        Returns:
            Tuple of (numerical_importance, categorical_importance) tensors
        """
        # Gradients are required even if the caller disabled them globally.
        with torch.enable_grad():
            # Detach so the embeddings become leaf tensors: module outputs are
            # normally non-leaf, and non-leaf tensors never populate ``.grad``.
            numerical_emb = (
                self.numerical_embedding(numerical_features)
                .detach()
                .requires_grad_(True)
            )
            categorical_emb = (
                self.categorical_embedding(categorical_features)
                .detach()
                .requires_grad_(True)
            )

            # Same path as ``forward`` with no attention mask.
            predictions = self._predict_from_embeddings(
                numerical_emb, categorical_emb, None
            )

            # autograd.grad (rather than backward) keeps parameter grads clean.
            numerical_grad, categorical_grad = torch.autograd.grad(
                outputs=predictions,
                inputs=[numerical_emb, categorical_emb],
                grad_outputs=torch.ones_like(predictions),
            )

        # Get importance scores
        numerical_importance = torch.abs(numerical_grad).mean(dim=0)
        categorical_importance = torch.abs(categorical_grad).mean(dim=0)

        return numerical_importance, categorical_importance