import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple

class GatedLinearUnit(nn.Module):
    """Linear projection followed by a tanh-gated elementwise product.

    One linear layer emits ``2 * output_dim`` features. After dropout the
    features are split in half along the last axis; the first half is the
    value and the second half, passed through ``tanh``, is the gate.
    """

    def __init__(self, input_dim: int, output_dim: int, dropout: float = 0.1):
        """Create the fused value/gate projection and the dropout module.

        Args:
            input_dim: Size of the last axis of the input
            output_dim: Size of the last axis of the output
            dropout: Dropout probability applied to the projected features
        """
        super().__init__()
        # Value and gate come out of a single matmul, hence the doubled width.
        self.fc = nn.Linear(input_dim, 2 * output_dim)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the gated projection.

        Args:
            x: Input tensor whose last axis has size ``input_dim``

        Returns:
            Tensor with the same leading axes as ``x`` and a last axis of
            size ``output_dim``
        """
        # Dropout is applied to the full projection, before the split.
        projected = self.dropout(self.fc(x))
        value, gate = torch.chunk(projected, 2, dim=-1)
        return value * gate.tanh()

class SemiPermeableAttention(nn.Module):
    """Semi-Permeable Attention (SPA) layer.

    Multi-head scaled dot-product self-attention whose four projection
    weight matrices are initialised from a zero-mean normal distribution
    with standard deviation ``gamma``. Any restriction on which positions
    may attend to which is supplied by the caller through ``mask``.
    """

    def __init__(
        self,
        input_dim: int,
        n_heads: int = 32,
        dropout: float = 0.1,
        gamma: float = 1e-4
    ):
        """Initialize SPA layer.

        Args:
            input_dim: Input dimension; must be divisible by ``n_heads``
            n_heads: Number of attention heads
            dropout: Dropout rate applied to the attention weights
            gamma: Standard deviation used to initialise the projection
                weights

        Raises:
            ValueError: If ``n_heads`` is not positive or does not evenly
                divide ``input_dim``.
        """
        super().__init__()
        # Fail early with a clear message instead of a ZeroDivisionError on
        # ``scale`` or a reshape error deep inside ``forward``.
        if n_heads <= 0:
            raise ValueError(f"n_heads must be positive, got {n_heads}")
        if input_dim % n_heads != 0:
            raise ValueError(
                f"input_dim ({input_dim}) must be divisible by "
                f"n_heads ({n_heads})"
            )

        self.input_dim = input_dim
        self.n_heads = n_heads
        self.head_dim = input_dim // n_heads

        self.q_proj = nn.Linear(input_dim, input_dim)
        self.k_proj = nn.Linear(input_dim, input_dim)
        self.v_proj = nn.Linear(input_dim, input_dim)
        self.out_proj = nn.Linear(input_dim, input_dim)

        # Overwrite the default weight init with small values (biases keep
        # the nn.Linear default).
        nn.init.normal_(self.q_proj.weight, std=gamma)
        nn.init.normal_(self.k_proj.weight, std=gamma)
        nn.init.normal_(self.v_proj.weight, std=gamma)
        nn.init.normal_(self.out_proj.weight, std=gamma)

        self.dropout = nn.Dropout(dropout)
        self.scale = self.head_dim ** -0.5

    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Forward pass.

        Args:
            x: Input tensor of shape (batch_size, seq_len, input_dim)
            mask: Optional attention mask; positions where the mask equals 0
                are not attended to. A 3-D mask of shape
                (batch_size, seq_len, seq_len) is applied to every head.
                Masks of any other rank are broadcast as given against the
                score tensor of shape
                (batch_size, n_heads, seq_len, seq_len), so a per-head mask
                should be passed as 4-D.

        Returns:
            Output tensor of shape (batch_size, seq_len, input_dim). Query
            positions whose keys are all masked receive zero attention
            weights rather than NaN.
        """
        batch_size, seq_len, _ = x.shape

        # Project queries, keys and values
        q = self.q_proj(x).view(batch_size, seq_len, self.n_heads, self.head_dim)
        k = self.k_proj(x).view(batch_size, seq_len, self.n_heads, self.head_dim)
        v = self.v_proj(x).view(batch_size, seq_len, self.n_heads, self.head_dim)

        # Transpose for attention computation
        q = q.transpose(1, 2)  # (batch_size, n_heads, seq_len, head_dim)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        # Compute attention scores
        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale

        blocked = None
        if mask is not None:
            blocked = mask == 0
            if blocked.dim() == 3:
                # (batch, query, key) -> (batch, 1, query, key): without the
                # head axis the batch axis would be aligned with the heads.
                blocked = blocked.unsqueeze(1)
            scores = scores.masked_fill(blocked, float('-inf'))

        # Apply softmax and dropout
        attn_weights = F.softmax(scores, dim=-1)
        if blocked is not None:
            # A row of all -inf makes softmax return NaN; zero those rows so
            # the NaN does not propagate through the rest of the model.
            fully_blocked = blocked.all(dim=-1, keepdim=True)
            attn_weights = attn_weights.masked_fill(fully_blocked, 0.0)
        attn_weights = self.dropout(attn_weights)

        # Compute output
        out = torch.matmul(attn_weights, v)  # (batch_size, n_heads, seq_len, head_dim)
        out = out.transpose(1, 2).contiguous()  # (batch_size, seq_len, n_heads, head_dim)
        out = out.view(batch_size, seq_len, self.input_dim)

        return self.out_proj(out)

class GLUFeedforward(nn.Module):
    """Feedforward sub-network made of two stacked gated linear units.

    The first unit maps ``input_dim`` features to ``hidden_dim``; the second
    maps them back to ``input_dim``.
    """

    def __init__(self, input_dim: int, hidden_dim: int, dropout: float = 0.1):
        """Build the expanding and contracting gated units.

        Args:
            input_dim: Feature size of the input and of the output
            hidden_dim: Feature size between the two units
            dropout: Dropout rate handed to both units
        """
        super().__init__()
        self.glu1 = GatedLinearUnit(
            input_dim=input_dim, output_dim=hidden_dim, dropout=dropout
        )
        self.glu2 = GatedLinearUnit(
            input_dim=hidden_dim, output_dim=input_dim, dropout=dropout
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the input through both gated units in sequence.

        Args:
            x: Input tensor of shape (batch_size, seq_len, input_dim)

        Returns:
            Output tensor of shape (batch_size, seq_len, input_dim)
        """
        hidden = self.glu1(x)
        return self.glu2(hidden)

class ExcelFormerBlock(nn.Module):
    """One ExcelFormer layer: attention then GLU feedforward.

    Each sub-layer is preceded by its own LayerNorm and followed by dropout
    and a residual connection.
    """

    def __init__(
        self,
        input_dim: int,
        n_heads: int = 32,
        hidden_dim: int = 256,
        dropout: float = 0.1,
        gamma: float = 1e-4
    ):
        """Assemble the attention, feedforward, normalisation and dropout modules.

        Args:
            input_dim: Feature size of the block's input and output
            n_heads: Number of attention heads
            hidden_dim: Hidden size of the feedforward network
            dropout: Dropout rate used by the sub-layers and on their outputs
            gamma: Initialization scale forwarded to the attention layer
        """
        super().__init__()
        self.attention = SemiPermeableAttention(
            input_dim, n_heads=n_heads, dropout=dropout, gamma=gamma
        )
        self.feedforward = GLUFeedforward(input_dim, hidden_dim, dropout=dropout)
        self.norm1 = nn.LayerNorm(input_dim)
        self.norm2 = nn.LayerNorm(input_dim)
        # A single dropout module is reused after both sub-layers.
        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Apply the attention and feedforward sub-layers with residuals.

        Args:
            x: Input tensor of shape (batch_size, seq_len, input_dim)
            mask: Optional attention mask, passed through to the attention
                layer unchanged

        Returns:
            Output tensor of shape (batch_size, seq_len, input_dim)
        """
        # Attention sub-layer: normalise, attend, drop, add back the input.
        attended = self.attention(self.norm1(x), mask)
        x = x + self.dropout(attended)

        # Feedforward sub-layer follows the same pattern.
        transformed = self.feedforward(self.norm2(x))
        return x + self.dropout(transformed)