// Cleaned and restored version of PROAnt.java
// Author: Feras Al-Obeidat
// Enhanced ACO Optimization with Weights and Thresholds

import java.io.IOException;
import java.util.*;
import org.jfree.chart.ChartFactory;
import org.jfree.chart.ChartPanel;
import org.jfree.chart.JFreeChart;
import org.jfree.data.category.DefaultCategoryDataset;
import javax.swing.*;

/**
 * Ant-colony-style optimizer for classifier thresholds and attribute weights.
 *
 * <p>Thresholds are indexed as {@code [class][prototype][attribute][D]} and weights as
 * {@code [class][prototype][attribute]}. Candidate parameter sets are scored through
 * {@code new Individual(thresholds, weights, data).fitness}; a larger fitness is treated as better
 * everywhere in this class.
 *
 * <p>Most state is held in static fields, so the class is not safe for concurrent optimization
 * runs.
 */
class PROAnt {
	double resultOfClassificationTest, resultOfClassificationTrain, resultOfClassificationTestAfter,
			resultOfClassificationTrainAfter, resultOfClassificationTestAdaptive;
	static int numOfPrototpes_ = 4;
	static Random rnd = new Random();
	static int numOfClasses, numOfAttributes;

	static double[][][][] Thresholds;
	static double[][][] Weights;
	static double[][][] bestWeightParams;
	double[][] TestRealDE;
	static double[][] TrainRealDE;

	private static final int NUM_ANTS = 40;
	private static final int NUM_ITERATIONS = 60;
	private static final double BASE_PHEROMONE_INCREMENT = 0.2;
	private static boolean isVisualization = false;
	// Highest fitness seen by calculatePheromoneIncrement during the current optimization run.
	private static double bestFitnessSoFar = Double.NEGATIVE_INFINITY;

	/**
	 * Stores the class and attribute counts used to size the search bounds in {@link #start()}.
	 *
	 * @param nClasses    number of classes
	 * @param nAttributes number of attributes; {@code start()} uses {@code nAttributes - 1} of them
	 */
	public void sendNumOfClassesAndAttributes(int nClasses, int nAttributes) {
		numOfClasses = nClasses;
		numOfAttributes = nAttributes;
	}

	/**
	 * Placeholder for external control of the search; currently ignores both arguments.
	 *
	 * @param genNumber  unused
	 * @param population unused
	 */
	public void setParameters(int genNumber, int population) {
		// You may implement this if needed for external control
	}

	/**
	 * Stores the data sets and the initial thresholds and weights. The arrays are kept by reference,
	 * not copied.
	 *
	 * @param train   training data
	 * @param thresh  initial thresholds, {@code [class][prototype][attribute][D]}
	 * @param weights initial weights, {@code [class][prototype][attribute]}
	 * @param test    test data
	 */
	public void sendTrainingTestingAndThresholds(double[][] train, double[][][][] thresh, double[][][] weights,
			double[][] test) {
		TrainRealDE = train;
		Thresholds = thresh;
		setWeights(weights);
		TestRealDE = test;
	}

	/**
	 * Runs the optimization; equivalent to {@link #start()}.
	 *
	 * @return the test fitness reported by {@code start()}
	 * @throws IOException never thrown by this implementation; kept in the signature for callers
	 */
	public double run() throws IOException {
		return start();
	}

	/**
	 * Scores the initial parameters, optimizes thresholds and weights on the training data, and
	 * reports the resulting training and test fitness.
	 *
	 * @return the larger of the test fitness with the optimized weights and the test fitness with the
	 *         test-time adjusted weights
	 */
	public double start() {
		System.out.println("****** Before optimization ");
		Individual popTest = new Individual(Thresholds, getWeights(), TestRealDE);
		resultOfClassificationTest = popTest.fitness;
		System.out.println("Fitness on testing = " + resultOfClassificationTest);

		Individual popTrain = new Individual(Thresholds, getWeights(), TrainRealDE);
		resultOfClassificationTrain = popTrain.fitness;
		System.out.println("Fitness on training = " + resultOfClassificationTrain);

		System.out.println("numOfClasses: " + numOfClasses);
		int D = 4;
		double[][][][] L = new double[numOfClasses][numOfPrototpes_][numOfAttributes - 1][D];
		double[][][][] H = new double[numOfClasses][numOfPrototpes_][numOfAttributes - 1][D];

		// Derive per-parameter lower (L) and upper (H) sampling bounds from the initial thresholds.
		for (int Cl = 0; Cl < numOfClasses; Cl++) {
			for (int p = 0; p < numOfPrototpes_; p++) {
				for (int At = 0; At < numOfAttributes - 1; At++) {
					double[] t = Thresholds[Cl][p][At];
					double[] lo = L[Cl][p][At];
					double[] hi = H[Cl][p][At];
					lo[0] = 0;
					hi[0] = t[0];
					lo[1] = Math.max(0, t[1] - t[0]);
					hi[1] = t[1] + t[0];
					lo[2] = Math.max(0, hi[1]);
					hi[2] = t[2] + t[3];
					lo[3] = 0;
					hi[3] = t[3];
				}
			}
		}

		double[][][][] bestParams = optimizeParameters(TrainRealDE, L, H);

		if (bestParams == null || bestWeightParams == null) {
			System.out.println("\u26A0\uFE0F No optimized parameters found. Using original thresholds and weights.");
			bestParams = Thresholds;
			bestWeightParams = Weights;
		}

		resultOfClassificationTrainAfter = evaluateFitness(bestParams, bestWeightParams, TrainRealDE);
		double testFitnessOptimizedWeights = evaluateFitness(bestParams, bestWeightParams, TestRealDE);
		resultOfClassificationTestAfter = testFitnessOptimizedWeights;

		System.out.printf("\n\t* Training *= %.4f\t* Testing *= %.4f", resultOfClassificationTrainAfter,
				resultOfClassificationTestAfter);

		// Adaptive weight tuning at test time.
		// NOTE(review): this tunes the weights on the test data itself, so the value reported below
		// is an optimistic estimate rather than a held-out score - confirm this is intended.
		double[][][] adjustedTestWeights = adjustWeightsAtTestTime(bestParams, bestWeightParams, TestRealDE);
		// Assign the instance field (a same-named local previously shadowed it and left it unset).
		resultOfClassificationTestAdaptive = evaluateFitness(bestParams, adjustedTestWeights, TestRealDE);

		// Use the better of the two
		resultOfClassificationTestAfter = Math.max(testFitnessOptimizedWeights, resultOfClassificationTestAdaptive);

		System.out.printf("\n\t* Training *= %.4f\t* Testing *= %.4f", resultOfClassificationTrainAfter,
				resultOfClassificationTestAfter);
		return resultOfClassificationTestAfter;
	}

	/** Builds an {@code Individual} from the given parameters and data and returns its fitness. */
	private static double evaluateFitness(double[][][][] thresholds, double[][][] weights, double[][] data) {
		return new Individual(thresholds, weights, data).fitness;
	}

	/** @return the current shared weight array (by reference) */
	public static double[][][] getWeights() {
		return Weights;
	}

	/** @param weights weight array to store (by reference) as the shared weights */
	public static void setWeights(double[][][] weights) {
		Weights = weights;
	}

	/**
	 * Searches for thresholds and weights that maximize fitness on {@code data}.
	 *
	 * <p>Each iteration samples {@code NUM_ANTS} candidates, refines the best one with a local search
	 * on thresholds and then on weights, updates pheromones, and stops early once the best fitness
	 * has improved by less than {@code MIN_IMPROVEMENT} for {@code EARLY_STOP_LIMIT} consecutive
	 * iterations. The best weights are published through {@link #bestWeightParams}.
	 *
	 * @param data data set used for scoring candidates
	 * @param L    lower sampling bounds, same shape as the thresholds
	 * @param H    upper sampling bounds, same shape as the thresholds
	 * @return the best thresholds found (never {@code null})
	 */
	private static double[][][][] optimizeParameters(double[][] data, double[][][][] L, double[][][][] H) {
		int c = Thresholds.length, p = Thresholds[0].length, a = Thresholds[0][0].length + 1,
				D = Thresholds[0][0][0].length;
		double[][][][] pheromones = initializePheromones(c, p, a, D);
		double[][][] weightPheromones = new double[c][p][a - 1];
		for (double[][] classRow : weightPheromones) {
			for (double[] prototypeRow : classRow) {
				Arrays.fill(prototypeRow, 1.0);
			}
		}

		// Start every run with a clean reference so a previous run cannot skew the pheromone rewards.
		bestFitnessSoFar = Double.NEGATIVE_INFINITY;

		double[][][][] bestParams = null;
		double[][][] bestWeights = null;
		double bestFitness = Double.NEGATIVE_INFINITY;

		final int EARLY_STOP_LIMIT = 15;
		final double MIN_IMPROVEMENT = 0.0005;
		int stagnantCounter = 0;
		double lastBestFitness = Double.NEGATIVE_INFINITY;

		for (int iter = 0; iter < NUM_ITERATIONS; iter++) {
			Params[] antParams = new Params[NUM_ANTS];
			double[] antFitnesses = new double[NUM_ANTS];

			for (int ant = 0; ant < NUM_ANTS; ant++) {
				antParams[ant] = generateRandomParams(c, p, a, D, L, H, pheromones, weightPheromones);
				antFitnesses[ant] = evaluateFitness(antParams[ant].thresholds, antParams[ant].weights, data);
			}

			int bestIndex = 0;
			for (int ant = 1; ant < NUM_ANTS; ant++) {
				if (antFitnesses[ant] > antFitnesses[bestIndex]) {
					bestIndex = ant;
				}
			}

			double[][][][] bestAnt = antParams[bestIndex].thresholds;
			double[][][] bestAntWeights = antParams[bestIndex].weights;

			// Step 1: local search on thresholds
			double[][][][] improvedThresholds = localSearch(bestAnt, bestAntWeights, data, L, H, iter);
			double improvedFitness = evaluateFitness(improvedThresholds, bestAntWeights, data);

			// Step 2: greedy local search on weights; accepted only if fitness improves
			double[][][] improvedWeights = localSearchWeights(bestAntWeights, improvedThresholds, data);
			double improvedWeightFitness = evaluateFitness(improvedThresholds, improvedWeights, data);
			if (improvedWeightFitness > improvedFitness) {
				bestAntWeights = improvedWeights;
				improvedFitness = improvedWeightFitness;
			}

			if (improvedFitness > bestFitness) {
				bestFitness = improvedFitness;
				bestParams = improvedThresholds;
				bestWeights = bestAntWeights;
			}

			// Single stagnation update per iteration, so EARLY_STOP_LIMIT counts iterations rather
			// than being reached twice as fast. On the first iteration the difference is infinite
			// (or NaN), which resets the counter.
			if (Math.abs(bestFitness - lastBestFitness) < MIN_IMPROVEMENT) {
				stagnantCounter++;
			} else {
				stagnantCounter = 0;
			}
			lastBestFitness = bestFitness;

			if (stagnantCounter >= EARLY_STOP_LIMIT) {
				System.out.println("⏹️ Adaptive early stopping at iteration " + iter);
				break;
			}

			// Top-k deposit: rank ants by descending fitness
			int topK = Math.max(1, NUM_ANTS / 10);
			Integer[] indices = new Integer[NUM_ANTS];
			for (int i = 0; i < NUM_ANTS; i++) {
				indices[i] = i;
			}
			Arrays.sort(indices, (left, right) -> Double.compare(antFitnesses[right], antFitnesses[left]));

			for (int rank = 0; rank < topK; rank++) {
				int ant = indices[rank];
				depositPheromones(pheromones, antParams[ant].thresholds, antFitnesses[ant]);
			}

			// Only deposit weight pheromones for the accepted best weights
			depositWeightPheromones(weightPheromones, bestAntWeights, improvedFitness);

			evaporatePheromones(pheromones, iter);
			evaporateWeightPheromones(weightPheromones, iter);
		}

		if (bestParams == null || bestWeights == null) {
			System.out.println("Optimization failed — using fallback values from first ant.");
			// Take thresholds and weights from the same sample so the fallback pair is consistent.
			Params fallback = generateRandomParams(c, p, a, D, L, H, pheromones, weightPheromones);
			bestParams = fallback.thresholds;
			bestWeights = fallback.weights;
		} else {
			System.out.println(" Optimization completed with fitness: " + bestFitness);
		}

		bestWeightParams = bestWeights;
		return bestParams;
	}

	/**
	 * Adds the fitness-derived increment to every cell of the threshold pheromone matrix.
	 *
	 * <p>NOTE(review): {@code thresholds} is not consulted, so the deposit is uniform across cells,
	 * and {@link #generateRandomParams} never reads this matrix; threshold pheromones therefore do
	 * not currently steer sampling.
	 *
	 * @param pheromones threshold pheromone matrix, updated in place
	 * @param thresholds thresholds of the depositing ant (unused)
	 * @param fitness    fitness of the depositing ant
	 */
	private static void depositPheromones(double[][][][] pheromones, double[][][][] thresholds, double fitness) {
		double increment = calculatePheromoneIncrement(fitness);
		for (double[][][] classBlock : pheromones) {
			for (double[][] prototypeBlock : classBlock) {
				for (double[] attributeRow : prototypeBlock) {
					for (int d = 0; d < attributeRow.length; d++) {
						attributeRow[d] += increment;
					}
				}
			}
		}
	}

	/**
	 * Adds {@code increment * weight} to each weight pheromone cell, so larger weights of the
	 * depositing solution are reinforced more.
	 *
	 * @param weightPheromones weight pheromone matrix, updated in place
	 * @param weights          weights of the depositing solution
	 * @param fitness          fitness of the depositing solution
	 */
	private static void depositWeightPheromones(double[][][] weightPheromones, double[][][] weights, double fitness) {
		double increment = calculatePheromoneIncrement(fitness);
		for (int i = 0; i < weightPheromones.length; i++) {
			for (int j = 0; j < weightPheromones[i].length; j++) {
				for (int k = 0; k < weightPheromones[i][j].length; k++) {
					weightPheromones[i][j][k] += increment * weights[i][j][k];
				}
			}
		}
	}

	/**
	 * Converts a fitness value into a non-negative pheromone reward.
	 *
	 * <p>The reward is {@code BASE_PHEROMONE_INCREMENT * fitness / best}, where {@code best} is the
	 * highest fitness seen so far in this run (floored at {@code 1e-6}); the best solution so far
	 * therefore earns the full base increment and weaker ones proportionally less. The previous
	 * formula used {@code fitness - best}, which can never be positive and so never reinforced
	 * anything.
	 *
	 * @param fitness fitness of the depositing solution
	 * @return reward {@code >= 0}; {@code 0} for non-finite or negative fitness
	 */
	private static double calculatePheromoneIncrement(double fitness) {
		if (!Double.isFinite(fitness)) {
			return 0.0;
		}
		if (fitness > bestFitnessSoFar) {
			bestFitnessSoFar = fitness;
		}
		double reference = Math.max(bestFitnessSoFar, 1e-6);
		return Math.max(0.0, BASE_PHEROMONE_INCREMENT * fitness / reference);
	}

	/**
	 * Creates a {@code [c][p][a - 1][D]} pheromone matrix with every cell set to {@code 1.0}.
	 *
	 * @param c number of classes
	 * @param p number of prototypes per class
	 * @param a number of attributes including the class attribute, which is excluded
	 * @param D number of threshold values per attribute
	 */
	private static double[][][][] initializePheromones(int c, int p, int a, int D) {
		double[][][][] pheromones = new double[c][p][a - 1][D]; // exclude class attribute
		for (double[][][] classBlock : pheromones) {
			for (double[][] prototypeBlock : classBlock) {
				for (double[] attributeRow : prototypeBlock) {
					Arrays.fill(attributeRow, 1.0);
				}
			}
		}
		return pheromones;
	}

	/**
	 * Samples one candidate solution.
	 *
	 * <p>Each threshold is drawn uniformly between its {@code L} and {@code H} bound. Each weight is
	 * {@code min(1, (weightPheromone + 1) * random)}, after which the weights of every prototype are
	 * rescaled to sum to 1 (when their sum is positive). {@code pheromones} is accepted but not read.
	 *
	 * @return the sampled thresholds and weights
	 */
	private static Params generateRandomParams(int c, int p, int a, int D, double[][][][] L, double[][][][] H,
			double[][][][] pheromones, double[][][] weightPheromones) {
		double[][][][] thresholds = new double[c][p][a - 1][D];
		double[][][] weights = new double[c][p][a - 1];

		for (int cl = 0; cl < c; cl++) {
			for (int pr = 0; pr < p; pr++) {
				for (int at = 0; at < a - 1; at++) {
					for (int th = 0; th < D; th++) {
						double lower = L[cl][pr][at][th];
						double upper = H[cl][pr][at][th];
						thresholds[cl][pr][at][th] = lower + Math.random() * (upper - lower);
					}
					weights[cl][pr][at] = Math.min(1.0, (weightPheromones[cl][pr][at] + 1.0) * Math.random());
				}
				// Normalize weights so they sum to 1 per prototype
				normalizeWeights(weights[cl][pr]);
			}
		}

		return new Params(thresholds, weights);
	}

	/**
	 * Greedy coordinate search on the thresholds: each value is nudged up and then down by 2% of its
	 * {@code [L, H]} range (clamped to the bounds), and a move is kept only if fitness strictly
	 * improves. The input array is not modified.
	 *
	 * @param currentThresholds starting thresholds
	 * @param currentWeights    weights held fixed during the search
	 * @param data              data set used for scoring
	 * @param L                 lower bounds
	 * @param H                 upper bounds
	 * @param iter              current outer iteration (currently unused; the step size is constant)
	 * @return an improved copy of {@code currentThresholds}
	 */
	private static double[][][][] localSearch(double[][][][] currentThresholds, double[][][] currentWeights,
			double[][] data, double[][][][] L, double[][][][] H, int iter) {

		double[][][][] bestParams = deepCopy(currentThresholds);
		double bestFitness = evaluateFitness(bestParams, currentWeights, data);

		// Fixed relative step size (fraction of each parameter's range).
		final double epsilon = 0.02;
		for (int cl = 0; cl < bestParams.length; cl++) {
			for (int p = 0; p < bestParams[0].length; p++) {
				for (int att = 0; att < bestParams[0][0].length; att++) {
					for (int d = 0; d < bestParams[0][0][0].length; d++) {
						double lower = L[cl][p][att][d];
						double upper = H[cl][p][att][d];
						double accepted = bestParams[cl][p][att][d];
						double step = epsilon * (upper - lower);

						double[] candidates = { Math.min(accepted + step, upper), Math.max(accepted - step, lower) };

						for (double candidate : candidates) {
							bestParams[cl][p][att][d] = candidate;
							double newFitness = evaluateFitness(bestParams, currentWeights, data);

							if (newFitness > bestFitness) {
								bestFitness = newFitness;
								accepted = candidate;
							} else {
								bestParams[cl][p][att][d] = accepted; // revert
							}
						}
					}
				}
			}
		}

		return bestParams;
	}

	/**
	 * Evaporation rate for the given iteration: falls linearly from {@code 0.3} at iteration 0
	 * towards {@code 0.1} as the iteration approaches {@code NUM_ITERATIONS}.
	 */
	private static double evaporationRate(int iteration) {
		return 0.1 + 0.2 * (1.0 - ((double) iteration / NUM_ITERATIONS));
	}

	/** Multiplies every threshold pheromone by {@code 1 - evaporationRate(iteration)}, in place. */
	private static void evaporatePheromones(double[][][][] pheromones, int iteration) {
		double keep = 1.0 - evaporationRate(iteration);
		for (double[][][] classBlock : pheromones) {
			for (double[][] prototypeBlock : classBlock) {
				for (double[] attributeRow : prototypeBlock) {
					for (int d = 0; d < attributeRow.length; d++) {
						attributeRow[d] *= keep;
					}
				}
			}
		}
	}

	/** Multiplies every weight pheromone by {@code 1 - evaporationRate(iteration)}, in place. */
	private static void evaporateWeightPheromones(double[][][] weightPheromones, int iteration) {
		double keep = 1.0 - evaporationRate(iteration);
		for (double[][] classBlock : weightPheromones) {
			for (double[] prototypeRow : classBlock) {
				for (int k = 0; k < prototypeRow.length; k++) {
					prototypeRow[k] *= keep;
				}
			}
		}
	}

	/** Returns an independent copy of a rectangular 4-D array (dimensions taken from index 0). */
	private static double[][][][] deepCopy(double[][][][] original) {
		int a = original.length;
		int b = original[0].length;
		int c = original[0][0].length;
		int d = original[0][0][0].length;

		double[][][][] copy = new double[a][b][c][d];
		for (int i = 0; i < a; i++) {
			for (int j = 0; j < b; j++) {
				for (int k = 0; k < c; k++) {
					System.arraycopy(original[i][j][k], 0, copy[i][j][k], 0, d);
				}
			}
		}

		return copy;
	}

	/**
	 * Greedy coordinate search on the weights: each weight is nudged by {@code +0.05} and then
	 * {@code -0.05} (clamped to {@code [0, 1]}), the prototype's weights are re-normalized, and the
	 * move is kept only if fitness strictly improves.
	 *
	 * <p>A rejected move restores the prototype's whole weight vector from a snapshot, because
	 * normalization rescales every entry and restoring only the edited one would leave the vector in
	 * a state that was never evaluated. The input array is not modified.
	 *
	 * @param weights    starting weights
	 * @param thresholds thresholds held fixed during the search
	 * @param data       data set used for scoring
	 * @return an improved copy of {@code weights} whose fitness is at least that of the input
	 */
	private static double[][][] localSearchWeights(double[][][] weights, double[][][][] thresholds, double[][] data) {
		double[][][] bestWeights = deepCopyWeights(weights);
		double bestFitness = evaluateFitness(thresholds, bestWeights, data);

		final double epsilon = 0.05; // you can reduce this to 0.02 if tuning gets too aggressive

		for (double[][] classWeights : bestWeights) {
			for (double[] prototypeWeights : classWeights) {
				for (int at = 0; at < prototypeWeights.length; at++) {
					double base = prototypeWeights[at];

					// Slightly higher and slightly lower values, both relative to the starting value
					double[] candidates = { Math.min(1.0, base + epsilon), Math.max(0.0, base - epsilon) };

					for (double candidate : candidates) {
						double[] snapshot = prototypeWeights.clone();
						prototypeWeights[at] = candidate;
						normalizeWeights(prototypeWeights);

						double newFitness = evaluateFitness(thresholds, bestWeights, data);
						if (newFitness > bestFitness) {
							bestFitness = newFitness;
						} else {
							System.arraycopy(snapshot, 0, prototypeWeights, 0, snapshot.length);
						}
					}
				}
			}
		}

		return bestWeights;
	}

	/** Returns an independent copy of a rectangular 3-D array (dimensions taken from index 0). */
	private static double[][][] deepCopyWeights(double[][][] original) {
		int a = original.length, b = original[0].length, c = original[0][0].length;
		double[][][] copy = new double[a][b][c];
		for (int i = 0; i < a; i++) {
			for (int j = 0; j < b; j++) {
				System.arraycopy(original[i][j], 0, copy[i][j], 0, c);
			}
		}
		return copy;
	}

	/** Rescales {@code weights} in place to sum to 1; left untouched when the sum is not positive. */
	private static void normalizeWeights(double[] weights) {
		double sum = 0.0;
		for (double w : weights) {
			sum += w;
		}

		if (sum > 0) {
			for (int i = 0; i < weights.length; i++) {
				weights[i] /= sum;
			}
		}
	}

	/**
	 * Tunes the weights against {@code testData} using the same greedy search as
	 * {@link #localSearchWeights}.
	 *
	 * <p>NOTE(review): because the search is scored on the data passed in, calling this with the
	 * test set fits the weights to that set.
	 *
	 * @param thresholds      thresholds held fixed during the search
	 * @param originalWeights starting weights (not modified)
	 * @param testData        data set used for scoring
	 * @return an adjusted copy of {@code originalWeights}
	 */
	private static double[][][] adjustWeightsAtTestTime(double[][][][] thresholds, double[][][] originalWeights,
			double[][] testData) {
		return localSearchWeights(originalWeights, thresholds, testData);
	}

}
