import numpy as np
import pandas as pd
from sklearn.preprocessing import QuantileTransformer
from catboost import CatBoostEncoder
from typing import Tuple, List, Dict, Optional

class DataPreprocessor:
    """Split a tabular dataset into normalized numerical and encoded categorical features.

    Numerical columns are mapped through a ``QuantileTransformer`` with a normal
    output distribution; all other feature columns are encoded with a
    ``CatBoostEncoder`` fitted against the first target column.
    """

    def __init__(
        self,
        n_quantiles: int = 1000,
        random_state: int = 42,
        target_columns: Optional[List[str]] = None,
        exclude_columns: Optional[List[str]] = None,
    ):
        """Initialize the data preprocessor.

        Args:
            n_quantiles: Number of quantiles for numerical feature normalization
            random_state: Random seed for reproducibility
            target_columns: Target column names. They are never used as features,
                and the first entry is the target the categorical encoder is
                fitted against. Defaults to the four ESG score columns.
            exclude_columns: Identifier/metadata columns that are never used as
                features. Defaults to ``['CompanyID', 'CompanyName', 'Year']``.

        Raises:
            ValueError: If ``target_columns`` is given but empty.
        """
        self.n_quantiles = n_quantiles
        self.random_state = random_state
        self.numerical_transformer = QuantileTransformer(
            n_quantiles=n_quantiles,
            output_distribution='normal',
            random_state=random_state,
        )
        # NOTE(review): the fit(X, y) / transform(X, y=None) usage in this class
        # matches category_encoders.CatBoostEncoder; confirm the module-level
        # import actually resolves to that class.
        self.categorical_encoder = CatBoostEncoder(random_state=random_state)
        self.feature_names: Dict[str, List[str]] = {
            'numerical': [],
            'categorical': [],
        }
        # Copy caller-supplied lists so later external mutation cannot leak in.
        self.target_columns: List[str] = (
            list(target_columns)
            if target_columns is not None
            else ['ESG_Overall', 'ESG_Environmental', 'ESG_Social', 'ESG_Governance']
        )
        if not self.target_columns:
            raise ValueError("target_columns must contain at least one column")
        self.exclude_columns: List[str] = (
            list(exclude_columns)
            if exclude_columns is not None
            else ['CompanyID', 'CompanyName', 'Year']
        )

    @staticmethod
    def _is_numerical(dtype: object) -> bool:
        """Return True for numeric (non-boolean) dtypes of any bit width."""
        # Booleans count as numeric for pandas but were always treated as
        # categorical here, so keep them on the categorical side.
        return (
            pd.api.types.is_numeric_dtype(dtype)
            and not pd.api.types.is_bool_dtype(dtype)
        )

    def identify_features(self, df: pd.DataFrame) -> None:
        """Identify numerical and categorical features from the dataset.

        The previous classification is replaced rather than extended, so calling
        this (or ``fit``) repeatedly never duplicates feature columns.

        Args:
            df: Input DataFrame
        """
        excluded = set(self.target_columns) | set(self.exclude_columns)
        numerical: List[str] = []
        categorical: List[str] = []

        for col in df.columns:
            if col in excluded:
                continue
            if self._is_numerical(df[col].dtype):
                numerical.append(col)
            else:
                categorical.append(col)

        self.feature_names = {'numerical': numerical, 'categorical': categorical}

    def fit(self, df: pd.DataFrame) -> None:
        """Fit the preprocessors on the training data.

        Args:
            df: Training DataFrame. It must contain the first target column
                whenever categorical features are present.
        """
        self.identify_features(df)

        numerical_cols = self.feature_names['numerical']
        categorical_cols = self.feature_names['categorical']

        if numerical_cols:
            self.numerical_transformer.fit(df[numerical_cols])

        if categorical_cols:
            self.categorical_encoder.fit(
                df[categorical_cols],
                df[self.target_columns[0]],  # Use overall ESG score as target
            )

    def _transform(
        self, df: pd.DataFrame, target: Optional[pd.Series] = None
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Shared implementation of ``transform`` / ``fit_transform``.

        Args:
            df: Input DataFrame
            target: Training target for the categorical encoder. It is passed
                only when transforming the data the encoder was fitted on, so
                the encoder can apply its training-time (ordered) encoding.

        Returns:
            Tuple of (numerical_features, categorical_features) as numpy arrays;
            a kind with no columns yields an empty ``np.array([])``.
        """
        numerical_features = np.array([])
        categorical_features = np.array([])

        numerical_cols = self.feature_names['numerical']
        categorical_cols = self.feature_names['categorical']

        if numerical_cols:
            numerical_features = np.asarray(
                self.numerical_transformer.transform(df[numerical_cols])
            )

        if categorical_cols:
            if target is None:
                encoded = self.categorical_encoder.transform(df[categorical_cols])
            else:
                encoded = self.categorical_encoder.transform(
                    df[categorical_cols], target
                )
            # The encoder may return a DataFrame; honour the ndarray contract.
            categorical_features = np.asarray(encoded)

        return numerical_features, categorical_features

    def transform(self, df: pd.DataFrame) -> Tuple[np.ndarray, np.ndarray]:
        """Transform the input data.

        Args:
            df: Input DataFrame

        Returns:
            Tuple of (numerical_features, categorical_features) as numpy arrays;
            a kind with no columns yields an empty ``np.array([])``.
        """
        return self._transform(df)

    def fit_transform(self, df: pd.DataFrame) -> Tuple[np.ndarray, np.ndarray]:
        """Fit and transform the data in one step.

        The training target is forwarded to the categorical encoder's
        ``transform`` so the training rows are encoded without seeing their
        own target value.

        Args:
            df: Input DataFrame

        Returns:
            Tuple of (numerical_features, categorical_features) as numpy arrays
        """
        self.fit(df)
        target = (
            df[self.target_columns[0]]
            if self.feature_names['categorical']
            else None
        )
        return self._transform(df, target)

    def get_feature_names(self) -> Dict[str, List[str]]:
        """Get the names of numerical and categorical features.

        Returns:
            Dictionary containing independent copies of the feature-name lists,
            so callers cannot mutate the preprocessor's internal state.
        """
        return {kind: list(names) for kind, names in self.feature_names.items()}

    def get_feature_dims(self) -> Tuple[int, int]:
        """Get the dimensions of numerical and categorical features.

        Returns:
            Tuple of (n_numerical_features, n_categorical_features)
        """
        return (
            len(self.feature_names['numerical']),
            len(self.feature_names['categorical']),
        )