import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from typing import Dict, Optional, Tuple, List
import numpy as np
from tqdm import tqdm
import wandb
from ..models.excelformer import ExcelFormer
from .metrics import compute_metrics

class Trainer:
    """Trainer class for ExcelFormer model.

    Bundles a model, an optimizer and a loss function and provides the
    epoch loop, evaluation, early stopping with best-checkpoint saving,
    checkpoint loading and prediction. Batches produced by the data
    loaders are indexed with the keys ``'numerical'``, ``'categorical'``
    and ``'target'``.
    """

    def __init__(
        self,
        model: "ExcelFormer",
        optimizer: torch.optim.Optimizer,
        criterion: nn.Module,
        device: str = 'cuda' if torch.cuda.is_available() else 'cpu',
        use_wandb: bool = True
    ):
        """Initialize trainer.

        Args:
            model: ExcelFormer model; it is moved to ``device``.
            optimizer: Optimizer
            criterion: Loss function, called as ``criterion(output, target)``
            device: Device to train on. The default is decided once, when
                the class is defined.
            use_wandb: Whether to use Weights & Biases for logging
        """
        self.model = model.to(device)
        self.optimizer = optimizer
        self.criterion = criterion
        self.device = device
        self.use_wandb = use_wandb

    @staticmethod
    def _squeeze_output(output: torch.Tensor) -> torch.Tensor:
        """Drop trailing singleton dimensions while keeping the batch axis.

        A bare ``squeeze()`` would also remove the batch axis when a batch
        holds a single sample, producing a 0-d tensor that can neither be
        iterated nor matched against a ``(1,)`` target.
        """
        while output.dim() > 1 and output.size(-1) == 1:
            output = output.squeeze(-1)
        return output

    @staticmethod
    def _to_numpy(tensor: torch.Tensor) -> np.ndarray:
        """Detach a tensor and return it as a NumPy array of at least 1-d."""
        return np.atleast_1d(tensor.detach().cpu().numpy())

    @staticmethod
    def _gather(chunks: List[np.ndarray]) -> np.ndarray:
        """Join per-batch arrays along the first axis (empty list -> empty array)."""
        if not chunks:
            return np.array([])
        return np.concatenate(chunks)

    def _forward_batch(
        self,
        batch: Dict[str, torch.Tensor]
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Move one batch to the device and run model and loss on it.

        Returns:
            Tuple of (output, target, loss), where output has had its
            trailing singleton dimensions removed.
        """
        numerical = batch['numerical'].to(self.device)
        categorical = batch['categorical'].to(self.device)
        target = batch['target'].to(self.device)

        output = self._squeeze_output(self.model(numerical, categorical))
        loss = self.criterion(output, target)
        return output, target, loss

    def train_epoch(
        self,
        train_loader: DataLoader,
        epoch: int
    ) -> Dict[str, float]:
        """Train for one epoch.

        Args:
            train_loader: Training data loader
            epoch: Current epoch number (used for the progress bar label
                and as the logging step)

        Returns:
            Dictionary of training metrics from ``compute_metrics`` plus
            ``'loss'``, the mean of the per-batch losses.
        """
        self.model.train()
        total_loss = 0
        predictions: List[np.ndarray] = []
        targets: List[np.ndarray] = []

        pbar = tqdm(train_loader, desc=f'Epoch {epoch}')
        for batch in pbar:
            # Forward pass
            self.optimizer.zero_grad()
            output, target, loss = self._forward_batch(batch)

            # Backward pass
            loss.backward()
            self.optimizer.step()

            # Update metrics; read the scalar once to avoid a second sync
            batch_loss = loss.item()
            total_loss += batch_loss
            predictions.append(self._to_numpy(output))
            targets.append(self._to_numpy(target))

            # Update progress bar
            pbar.set_postfix({'loss': batch_loss})

        # Compute metrics
        metrics = compute_metrics(
            self._gather(predictions),
            self._gather(targets)
        )
        metrics['loss'] = total_loss / len(train_loader)

        # Log metrics
        if self.use_wandb:
            wandb.log({
                f'train/{k}': v
                for k, v in metrics.items()
            }, step=epoch)

        return metrics

    @torch.no_grad()
    def evaluate(
        self,
        data_loader: DataLoader,
        epoch: Optional[int] = None,
        prefix: str = 'val'
    ) -> Dict[str, float]:
        """Evaluate the model.

        Runs with gradients disabled and leaves the model in eval mode.

        Args:
            data_loader: Data loader for evaluation
            epoch: Current epoch number (for logging); nothing is logged
                when it is None
            prefix: Prefix for metric names

        Returns:
            Dictionary of evaluation metrics from ``compute_metrics`` plus
            ``'loss'``, the mean of the per-batch losses.
        """
        self.model.eval()
        total_loss = 0
        predictions: List[np.ndarray] = []
        targets: List[np.ndarray] = []

        for batch in data_loader:
            output, target, loss = self._forward_batch(batch)

            # Update metrics
            total_loss += loss.item()
            predictions.append(self._to_numpy(output))
            targets.append(self._to_numpy(target))

        # Compute metrics
        metrics = compute_metrics(
            self._gather(predictions),
            self._gather(targets)
        )
        metrics['loss'] = total_loss / len(data_loader)

        # Log metrics
        if self.use_wandb and epoch is not None:
            wandb.log({
                f'{prefix}/{k}': v
                for k, v in metrics.items()
            }, step=epoch)

        return metrics

    def train(
        self,
        train_loader: DataLoader,
        val_loader: DataLoader,
        n_epochs: int,
        early_stopping_patience: int = 5,
        model_save_path: Optional[str] = None
    ) -> Tuple[List[Dict[str, float]], List[Dict[str, float]]]:
        """Train the model.

        Each epoch trains on ``train_loader`` and then evaluates on
        ``val_loader``. Whenever the validation loss strictly improves the
        patience counter is reset and, if ``model_save_path`` is given, a
        checkpoint is written there. Training stops early once the
        validation loss has failed to improve for
        ``early_stopping_patience`` consecutive epochs.

        Args:
            train_loader: Training data loader
            val_loader: Validation data loader
            n_epochs: Number of epochs to train for
            early_stopping_patience: Number of epochs to wait for improvement
            model_save_path: Path to save best model

        Returns:
            Tuple of (train_metrics, val_metrics) lists, one entry per
            completed epoch
        """
        train_metrics = []
        val_metrics = []
        best_val_loss = float('inf')
        patience_counter = 0

        for epoch in range(n_epochs):
            # Train epoch
            train_metrics.append(self.train_epoch(train_loader, epoch))

            # Evaluate
            val_epoch_metrics = self.evaluate(val_loader, epoch)
            val_metrics.append(val_epoch_metrics)

            # Early stopping
            if val_epoch_metrics['loss'] < best_val_loss:
                best_val_loss = val_epoch_metrics['loss']
                patience_counter = 0

                # Save best model
                if model_save_path:
                    torch.save({
                        'epoch': epoch,
                        'model_state_dict': self.model.state_dict(),
                        'optimizer_state_dict': self.optimizer.state_dict(),
                        'val_loss': best_val_loss,
                    }, model_save_path)
                continue

            patience_counter += 1
            if patience_counter >= early_stopping_patience:
                print(f'Early stopping at epoch {epoch}')
                break

        return train_metrics, val_metrics

    def load_checkpoint(self, checkpoint_path: str) -> None:
        """Load model checkpoint.

        Restores both the model and the optimizer state from a file with
        the layout written by ``train``.

        Args:
            checkpoint_path: Path to checkpoint file
        """
        # Remap stored tensors onto this trainer's device so a checkpoint
        # saved on a GPU can still be restored on a CPU-only host.
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    def predict(
        self,
        numerical_features: torch.Tensor,
        categorical_features: torch.Tensor
    ) -> np.ndarray:
        """Make predictions.

        Switches the model to eval mode and runs it with gradients
        disabled.

        Args:
            numerical_features: Numerical features tensor
            categorical_features: Categorical features tensor

        Returns:
            Predictions as numpy array. Trailing singleton dimensions are
            removed but the batch axis is kept, so a single-row input
            yields an array of shape ``(1,)`` rather than a 0-d array.
        """
        self.model.eval()
        with torch.no_grad():
            numerical = numerical_features.to(self.device)
            categorical = categorical_features.to(self.device)
            predictions = self.model(numerical, categorical)
            return self._squeeze_output(predictions).cpu().numpy()