import torch
import torch.nn as nn
import torch.nn.functional as F
import math


class NativeSparseAttention(nn.Module):
    """
    Native Sparse Attention (NSA) module for efficient feature interaction modeling.

    Every query attends through three branches whose outputs are averaged:

    * sliding-window attention: each token attends to at most ``window_size``
      neighbouring tokens centred on its own position;
    * compressed attention: each token attends to one compressed key/value per
      block of ``block_size`` consecutive tokens;
    * selected attention: a learned score ranks the blocks and each token
      attends to the individual tokens of the top 50% of blocks (at least one).

    The sparsity patterns are applied as masks on dense ``N x N`` score
    matrices. That is cheap for tabular feature counts, but it does not realise
    the memory savings of a kernel-level sparse implementation. The reference
    NSA design combines branches with learned gates; equal weights are used
    here so the parameter set stays unchanged.

    Args:
        embed_dim: Model width; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        window_size: Width of the sliding window (key positions per query).
        block_size: Number of consecutive tokens per block for compression
            and selection.
        dropout: Dropout probability applied to attention weights.

    ``forward`` takes ``x`` of shape ``[B, N, embed_dim]`` and returns a tensor
    of the same shape.

    Raises:
        ValueError: If ``num_heads`` does not divide ``embed_dim``, or if
            ``window_size`` / ``block_size`` is smaller than 1.
    """
    def __init__(self, embed_dim, num_heads=8, window_size=8, block_size=4, dropout=0.1):
        super().__init__()
        if num_heads < 1 or embed_dim % num_heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})"
            )
        if window_size < 1 or block_size < 1:
            raise ValueError("window_size and block_size must be >= 1")
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.window_size = window_size
        self.block_size = block_size
        self.scale = self.head_dim ** -0.5

        # Linear projections for Q, K, V
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

        # Compresses block_size full-width token vectors into one block vector;
        # shared between keys and values.
        self.block_compression = nn.Linear(block_size * embed_dim, embed_dim)
        # Scores each compressed block for the selection branch.
        self.block_selection = nn.Linear(embed_dim, 1)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B, N, D = x.shape
        if N == 0:
            # Nothing to attend over; still return a [B, 0, D] tensor.
            return self.out_proj(x)

        k_full = self.k_proj(x)
        v_full = self.v_proj(x)
        q = self._split_heads(self.q_proj(x))
        k = self._split_heads(k_full)
        v = self._split_heads(v_full)

        # Branch 1: local neighbourhood.
        local_out = self._sliding_window_attention(q, k, v)

        # Branch 2: coarse, global view over compressed blocks.
        k_blocks = self._block_compression(k_full)  # [B, num_blocks, D]
        v_blocks = self._block_compression(v_full)
        compressed_out = self._attend(
            q, self._split_heads(k_blocks), self._split_heads(v_blocks)
        )

        # Branch 3: fine-grained attention restricted to the top-ranked blocks.
        selected_out = self._block_selection(q, k, v, k_blocks)

        out = (local_out + compressed_out + selected_out) / 3.0
        # [B, H, N, head_dim] -> [B, N, D]; reshape copies the non-contiguous view.
        out = out.transpose(1, 2).reshape(B, N, D)
        return self.out_proj(out)

    def _split_heads(self, t):
        """Reshape ``[B, L, embed_dim]`` into ``[B, num_heads, L, head_dim]``."""
        batch, length, _ = t.shape
        return t.reshape(batch, length, self.num_heads, self.head_dim).transpose(1, 2)

    def _attend(self, q, k, v, mask=None, bias=None):
        """Scaled dot-product attention.

        ``mask`` (bool, broadcastable to the score matrix) marks allowed keys.
        ``bias`` is added to the logits before masking. Callers guarantee that
        every query row keeps at least one key, so the softmax never sees an
        all ``-inf`` row.
        """
        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        if bias is not None:
            scores = scores + bias
        if mask is not None:
            scores = scores.masked_fill(~mask, float("-inf"))
        weights = self.dropout(F.softmax(scores, dim=-1))
        return torch.matmul(weights, v)

    def _sliding_window_attention(self, q, k, v):
        """Attend within a window of ``window_size`` positions around each query.

        Query ``i`` sees keys ``j`` with
        ``-((window_size - 1) // 2) <= j - i <= window_size // 2``.
        The window therefore slides with the query, so neighbours are never
        split by a chunk boundary. Inputs and output are
        ``[B, H, N, head_dim]``.
        """
        n = q.shape[-2]
        back = (self.window_size - 1) // 2
        ahead = self.window_size // 2
        pos = torch.arange(n, device=q.device)
        offset = pos.unsqueeze(0) - pos.unsqueeze(1)  # offset[i, j] == j - i
        allowed = (offset >= -back) & (offset <= ahead)
        return self._attend(q, k, v, mask=allowed)

    def _block_compression(self, x):
        """Compress ``[B, N, D]`` into ``[B, ceil(N / block_size), D]``.

        The sequence is zero-padded up to a multiple of ``block_size``. Each
        block's tokens are then concatenated and projected back to ``D``.
        """
        batch, n, dim = x.shape
        pad = (-n) % self.block_size
        if pad:
            x = F.pad(x, (0, 0, 0, pad))
        num_blocks = (n + pad) // self.block_size
        return self.block_compression(x.reshape(batch, num_blocks, self.block_size * dim))

    def _block_selection(self, q, k, v, block_repr):
        """Attend only to tokens inside the highest-scoring blocks.

        ``block_repr`` is ``[B, num_blocks, D]``. The top ``max(1, num_blocks // 2)``
        blocks per sample are kept. ``topk`` itself is not differentiable, so
        each kept key's logit is also biased by its block score. That reweights
        attention mass across the kept blocks and lets ``block_selection``
        receive gradients. Every kept block contains at least one real token,
        so no query row is fully masked.
        """
        n = q.shape[-2]
        num_blocks = block_repr.shape[1]
        block_scores = self.block_selection(block_repr).squeeze(-1)  # [B, num_blocks]
        num_selected = max(1, num_blocks // 2)
        top_blocks = torch.topk(block_scores, num_selected, dim=-1).indices
        keep_block = torch.zeros_like(block_scores).scatter_(1, top_blocks, 1.0).bool()

        token_block = torch.arange(n, device=q.device) // self.block_size  # block id per token
        key_mask = keep_block[:, token_block][:, None, None, :]   # [B, 1, 1, N]
        key_bias = block_scores[:, token_block][:, None, None, :]
        return self._attend(q, k, v, mask=key_mask, bias=key_bias)


class TabMixer(nn.Module):
    """
    TabMixer module for feature mixing via channel-wise and token-wise MLPs.
    Inspired by MLP-Mixer architecture for tabular data.
    """
    def __init__(self, embed_dim, mlp_ratio=4, dropout=0.1):
        super().__init__()
        mlp_hidden_dim = int(embed_dim * mlp_ratio)
        
        # Channel-wise MLP (processes each feature independently)
        self.channel_mlp = nn.Sequential(
            nn.Linear(embed_dim, mlp_hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_hidden_dim, embed_dim),
            nn.Dropout(dropout)
        )
        
        # Token-wise MLP (processes across features)
        self.token_mlp = nn.Sequential(
            nn.Linear(embed_dim, mlp_hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_hidden_dim, embed_dim),
            nn.Dropout(dropout)
        )
        
        # Layer normalization
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        
    def forward(self, x):
        # Channel-wise mixing
        x = x + self.channel_mlp(self.norm1(x))
        
        # Token-wise mixing (transpose to apply MLP across features)
        x = x + self.token_mlp(self.norm2(x.transpose(-2, -1))).transpose(-2, -1)
        
        return x


class TabNSA(nn.Module):
    """
    TabNSA: Hybrid deep learning framework combining Native Sparse Attention and TabMixer.
    Designed for tabular data with efficient feature interaction modeling.
    """
    def __init__(self, 
                 num_features,
                 embed_dim=128,
                 num_heads=8,
                 num_layers=3,
                 window_size=8,
                 block_size=4,
                 mlp_ratio=4,
                 num_classes=3,
                 dropout=0.1):
        super().__init__()
        
        self.embed_dim = embed_dim
        self.num_features = num_features
        
        # Feature embedding layers
        self.feature_embedding = nn.Linear(1, embed_dim)  # Shared embedding for all features
        
        # NSA and TabMixer layers
        self.nsa_layers = nn.ModuleList([
            NativeSparseAttention(embed_dim, num_heads, window_size, block_size, dropout)
            for _ in range(num_layers)
        ])
        
        self.tabmixer_layers = nn.ModuleList([
            TabMixer(embed_dim, mlp_ratio, dropout)
            for _ in range(num_layers)
        ])
        
        # Layer normalization
        self.layer_norms = nn.ModuleList([
            nn.LayerNorm(embed_dim) for _ in range(num_layers)
        ])
        
        # Classification head
        self.classifier = nn.Sequential(
            nn.Linear(embed_dim, embed_dim // 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(embed_dim // 2, num_classes)
        )
        
        # Initialize weights
        self.apply(self._init_weights)
    
    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.LayerNorm):
            torch.nn.init.ones_(module.weight)
            torch.nn.init.zeros_(module.bias)
    
    def forward(self, x):
        # x shape: [B, N] where N is number of features
        
        # Feature embedding: project each feature to D-dimensional space
        # Reshape to [B, N, 1] and project to [B, N, D]
        x = x.unsqueeze(-1)  # [B, N, 1]
        x = self.feature_embedding(x)  # [B, N, D]
        
        # Apply NSA and TabMixer layers
        for i in range(len(self.nsa_layers)):
            # NSA processing
            nsa_out = self.nsa_layers[i](x)
            
            # TabMixer processing
            mixer_out = self.tabmixer_layers[i](x)
            
            # Fusion: element-wise summation
            x = nsa_out + mixer_out
            
            # Layer normalization
            x = self.layer_norms[i](x)
        
        # Global pooling: mean across feature dimension
        x = torch.mean(x, dim=1)  # [B, D]
        
        # Classification head
        logits = self.classifier(x)
        
        return logits


def create_tabnsa_model(num_features, num_classes=3, **kwargs):
    """
    Build a ``TabNSA`` model, filling in default hyperparameters.

    Any keyword argument overrides the matching default below. Unknown keys are
    forwarded to ``TabNSA`` unchanged, so they raise ``TypeError`` there.

    Args:
        num_features: Number of input features.
        num_classes: Number of output classes.
        **kwargs: Hyperparameter overrides (``embed_dim``, ``num_heads``,
            ``num_layers``, ``window_size``, ``block_size``, ``mlp_ratio``,
            ``dropout``, ...).

    Returns:
        A freshly constructed ``TabNSA`` instance.
    """
    defaults = dict(
        embed_dim=128,
        num_heads=8,
        num_layers=3,
        window_size=8,
        block_size=4,
        mlp_ratio=4,
        dropout=0.1,
    )
    # Caller-supplied values win over defaults.
    config = {**defaults, **kwargs}
    return TabNSA(num_features=num_features, num_classes=num_classes, **config)
