import torch
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
from typing import Tuple, Dict, Optional
from .preprocessing import DataPreprocessor

class ESGDataset(Dataset):
    """Map-style dataset exposing preprocessed features and a target column.

    All rows are transformed once at construction time; indexing afterwards
    only slices the cached tensors.
    """

    def __init__(
        self,
        df: pd.DataFrame,
        preprocessor: DataPreprocessor,
        target_column: str = 'ESG_Overall',
        is_training: bool = True
    ):
        """Build the feature and target tensors for ``df``.

        Args:
            df: Rows to serve. Must contain ``target_column``.
            preprocessor: Object whose ``transform(df)`` returns a
                ``(numerical, categorical)`` pair of array-likes.
            target_column: Column of ``df`` used as the prediction target.
            is_training: Stored on the instance as ``self.is_training``;
                it is not consulted anywhere else in this class.
        """
        # Keep the constructor arguments reachable from the instance.
        self.df = df
        self.preprocessor = preprocessor
        self.target_column = target_column
        self.is_training = is_training

        # Feature transformation runs first, then the target lookup.
        numerical_raw, categorical_raw = preprocessor.transform(df)
        target_raw = df[target_column].values

        # Everything is stored as float32 CPU tensors.
        # NOTE(review): categorical features are stored as floats too; if the
        # model feeds them to an embedding layer they would need to be integer
        # tensors -- confirm against the consumer.
        self.numerical_features = torch.FloatTensor(numerical_raw)
        self.categorical_features = torch.FloatTensor(categorical_raw)
        self.targets = torch.FloatTensor(target_raw)

    def __len__(self) -> int:
        """Return the number of rows in the wrapped DataFrame."""
        return len(self.df)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Fetch the sample stored at position ``idx``.

        Args:
            idx: Positional index into the cached tensors.

        Returns:
            Dict with keys ``'numerical'``, ``'categorical'`` and ``'target'``
            holding the corresponding tensor entries for that position.
        """
        sample = {}
        sample['numerical'] = self.numerical_features[idx]
        sample['categorical'] = self.categorical_features[idx]
        sample['target'] = self.targets[idx]
        return sample

def create_data_loaders(
    df: pd.DataFrame,
    preprocessor: DataPreprocessor,
    target_column: str = 'ESG_Overall',
    train_ratio: float = 0.8,
    val_ratio: float = 0.1,
    batch_size: int = 32,
    random_state: int = 42
) -> Tuple[DataLoader, DataLoader, DataLoader]:
    """Create train, validation and test data loaders.

    Rows of ``df`` are shuffled with a generator seeded by ``random_state``
    and cut into three disjoint, consecutive slices: the first
    ``int(n * train_ratio)`` rows for training, the next
    ``int(n * val_ratio)`` rows for validation, and whatever remains for
    testing.

    The shuffle uses a private ``numpy.random.RandomState`` so the
    process-wide NumPy random state is left untouched. The resulting
    permutation is the same one produced by seeding the global generator
    with ``random_state``.

    Args:
        df: Input DataFrame
        preprocessor: Fitted DataPreprocessor instance
        target_column: Name of the target column to predict
        train_ratio: Ratio of training data (must be >= 0)
        val_ratio: Ratio of validation data (must be >= 0)
        batch_size: Batch size for data loaders
        random_state: Random seed for reproducibility of the split

    Returns:
        Tuple of (train_loader, val_loader, test_loader). Only the training
        loader shuffles its samples.

    Raises:
        ValueError: If a ratio is negative or the two ratios sum to more
            than 1. A negative ratio would make the slices overlap, and a
            sum above 1 would silently shrink the validation/test splits.
    """
    # Small tolerance so sums such as 0.7 + 0.3 are not rejected because of
    # floating-point rounding.
    if train_ratio < 0 or val_ratio < 0 or train_ratio + val_ratio > 1 + 1e-9:
        raise ValueError(
            "train_ratio and val_ratio must be non-negative and sum to at "
            f"most 1 (got train_ratio={train_ratio}, val_ratio={val_ratio})"
        )

    # Local generator: same permutation as np.random.seed(random_state)
    # followed by np.random.permutation, without resetting global state.
    n_samples = len(df)
    rng = np.random.RandomState(random_state)
    indices = rng.permutation(n_samples)

    # Consecutive, non-overlapping slices of the shuffled indices.
    train_end = int(n_samples * train_ratio)
    val_end = train_end + int(n_samples * val_ratio)
    split_indices = (
        indices[:train_end],
        indices[train_end:val_end],
        indices[val_end:],
    )
    # Only the first (training) split is flagged as training and shuffled.
    training_flags = (True, False, False)

    # Build all datasets first, then wrap each one in a loader.
    datasets = [
        ESGDataset(
            df.iloc[split],
            preprocessor,
            target_column,
            is_training=flag
        )
        for split, flag in zip(split_indices, training_flags)
    ]

    train_loader, val_loader, test_loader = (
        DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=flag,
            num_workers=0
        )
        for dataset, flag in zip(datasets, training_flags)
    )

    return train_loader, val_loader, test_loader