import numpy as np
from scipy.special import expit
from sklearn.linear_model import Ridge

class I_ICA_Plus:
    """Incrementally grown network of random sigmoid nodes with ridge output weights.

    Hidden nodes are added one at a time. For each new node a pool of
    ``T_max`` random candidates is drawn and the one with the largest
    ``geometric_constraint`` score that passes a threshold is kept; if none
    passes, a plain Gaussian-random node is added instead. After every
    addition the output weights are refit with ``Ridge`` on all hidden
    outputs, using a regularization strength that decays geometrically.

    Parameters
    ----------
    L_max : maximum number of hidden nodes to add.
    T_max : number of random candidates drawn per node search.
    tol : training stops once the residual norm is not greater than this.
    σ : coefficient used in the per-node acceptance threshold.
    verbose : print progress messages during ``fit``.
    reg_init : Ridge ``alpha`` used for the first node.
    reg_decay : factor multiplied into ``alpha`` after each node is added.
    """

    def __init__(
        self,
        L_max=200,
        T_max=1000,
        tol=0.01,
        σ=0.01,
        verbose=True,
        reg_init=1e-3,
        reg_decay=0.99
    ):
        self.L_max = L_max
        self.T_max = T_max
        self.tol = tol
        self.σ = σ
        self.verbose = verbose

        self.reg_init = reg_init
        self.reg_decay = reg_decay

        # Learned state, (re)initialised by fit().
        self.L = 0        # number of hidden nodes currently in the model
        self.W = []       # per-node input weight vectors
        self.b = []       # per-node biases (each an array of shape (1,))
        self.Beta = None  # output weights, shape (L, n_targets)
        self.e = None     # current training residual, shape (n_samples, n_targets)

    def geometric_constraint(self, g, e):
        """Return the squared-cosine-like score between node output and residual.

        Computes ``sum_j <e[:, j], g[:, 0]>^2 / (||e||^2 * ||g||^2 + 1e-10)``
        where ``e`` is 2-D and only the first column of ``g`` enters the
        numerator. The small constant guards against division by zero.
        """
        # One matrix-vector product gives all per-target inner products.
        projections = e.T @ g[:, 0]
        numerator = np.dot(projections, projections)
        denominator = (np.linalg.norm(e) ** 2) * (np.linalg.norm(g) ** 2) + 1e-10
        return numerator / denominator

    def node_pool_search(self, X, e):
        """Draw ``T_max`` random sigmoid nodes and return the best admissible one.

        Parameters
        ----------
        X : 2-D input matrix, ``(n_samples, n_features)``.
        e : 2-D residual matrix, ``(n_samples, n_targets)``.

        Returns
        -------
        ``(w, b_val, g)`` for the accepted candidate with the highest score,
        where ``g`` has shape ``(n_samples, 1)``; ``None`` if no candidate
        met the threshold.

        Side effect: ``self.σ`` is updated once per call.
        """
        n_samples, d = X.shape[0], X.shape[1]
        res_norm = np.linalg.norm(e)

        # Adaptive scaling: the sampling range follows the residual norm.
        λ_min = 0.05 * res_norm
        λ_max = 1.0 * res_norm

        # Neither self.L nor self.σ changes inside the loop, so the
        # acceptance threshold is computed once instead of T_max times.
        γ_L = 1 - (1 + self.σ * self.L) / (1 + self.L + 1)
        # NOTE(review): the score is a normalised ratio while this threshold
        # is scaled by ||e||^2 -- confirm against the algorithm's reference
        # that both sides are meant to be on these scales.
        threshold = γ_L * res_norm ** 2

        best_g = None
        best_cos2 = -np.inf
        for _ in range(self.T_max):
            λ = np.random.uniform(λ_min, λ_max)
            w = λ * (2 * np.random.rand(d) - 1)
            b_val = λ * (2 * np.random.rand(1) - 1)
            g = expit(X @ w + b_val).reshape(n_samples, 1)

            cos2 = self.geometric_constraint(g, e)
            if cos2 >= threshold and cos2 > best_cos2:
                best_cos2 = cos2
                # g is freshly allocated each iteration, so no copy is needed.
                best_g = (w, b_val, g)

        # NOTE(review): σ is increased after every search, whether or not a
        # candidate was found, by a draw from (1 - σ, 1) -- verify this is
        # the intended update rule.
        self.σ = self.σ + np.random.uniform(1 - self.σ, 1)
        return best_g

    def fit(self, X, y):
        """Grow the network on ``(X, y)`` until ``L_max`` nodes or ``tol`` is reached.

        Parameters
        ----------
        X : 2-D input matrix, ``(n_samples, n_features)``.
        y : targets; 1-D input is treated as a single target column, and a
            2-D input whose first axis is not ``n_samples`` is transposed.

        Returns
        -------
        ``self``, so calls can be chained.

        Raises
        ------
        ValueError
            If ``y`` cannot be aligned with the number of rows in ``X``.
        """
        n_samples = X.shape[0]
        y = np.atleast_2d(y)
        if y.shape[0] != n_samples:
            y = y.T
        if y.shape[0] != n_samples:
            raise ValueError(
                f"y has {y.shape[1]} samples but X has {n_samples} rows."
            )

        self.e = y.copy()
        H = np.zeros((n_samples, 0))

        self.W = []
        self.b = []
        self.L = 0

        reg_value = self.reg_init

        if self.verbose:
            print(f"✅ [I-ICA] Start training: L_max={self.L_max}, T_max={self.T_max}, tol={self.tol}")

        # node_pool_search() mutates self.σ on every call. Remember the
        # configured value and restore it afterwards so that refitting the
        # same instance starts from the same hyper-parameter.
        σ_at_entry = self.σ
        try:
            while self.L < self.L_max and np.linalg.norm(self.e) > self.tol:
                result = self.node_pool_search(X, self.e)

                if result is None:
                    if self.verbose:
                        print("⚠️ Node pool search failed. Adding random node.")
                    # Fallback: an unconstrained Gaussian-random node.
                    w = np.random.randn(X.shape[1])
                    b_val = np.random.randn(1)
                    g_new = expit(X @ w + b_val).reshape(n_samples, 1)
                else:
                    w, b_val, g_new = result

                self.W.append(w)
                self.b.append(b_val)
                self.L += 1

                H = np.hstack((H, g_new))

                # Refit all output weights on the enlarged hidden layer.
                reg_model = Ridge(alpha=reg_value, fit_intercept=False)
                reg_model.fit(H, y)
                self.Beta = reg_model.coef_.T

                self.e = y - H @ self.Beta

                # Decay regularization for the next node.
                reg_value *= self.reg_decay

                if self.verbose:
                    print(f"✅ [Node {self.L:3d}] Residual norm: {np.linalg.norm(self.e):.6f}, reg: {reg_value:.6e}")
        finally:
            self.σ = σ_at_entry

        if self.verbose:
            print(f"🎯 Training complete. Hidden nodes: {self.L}, Final residual norm: {np.linalg.norm(self.e):.6f}")

        return self

    def predict(self, X):
        """Return predictions for ``X`` as an array of shape ``(n_samples, n_targets)``.

        Raises
        ------
        ValueError
            If the model has no hidden nodes or no output weights yet.
        """
        if self.Beta is None or self.L == 0:
            raise ValueError("❌ Cannot predict: model not trained.")
        # Stack per-node parameters so every hidden output comes from a
        # single matrix product instead of a Python loop over nodes.
        W = np.asarray(self.W[:self.L], dtype=float)               # (L, n_features)
        b = np.asarray(self.b[:self.L], dtype=float).reshape(-1)   # (L,)
        H_test = expit(X @ W.T + b)
        return H_test @ self.Beta

    def rmse(self, y_true, y_pred):
        """Return the root-mean-square error between ``y_true`` and ``y_pred``.

        ``predict`` always returns a 2-D array, so a 1-D ``y_true`` paired
        with an ``(n, 1)`` prediction would otherwise broadcast to ``(n, n)``
        and give a wrong value. When the two inputs have different ranks but
        the same number of elements, both are flattened first.
        """
        y_true = np.asarray(y_true)
        y_pred = np.asarray(y_pred)
        if y_true.ndim != y_pred.ndim and y_true.size == y_pred.size:
            y_true = y_true.ravel()
            y_pred = y_pred.ravel()
        return np.sqrt(np.mean((y_true - y_pred) ** 2))
