// Dr. Feras Al-Obeidat
// March 2025

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Random;

import weka.core.Instance;
import weka.core.Instances;
import weka.core.converters.ConverterUtils.DataSource;

import weka.filters.Filter;
import weka.filters.supervised.attribute.AttributeSelection; // Filter version
import weka.filters.supervised.attribute.NominalToBinary;
import weka.attributeSelection.CfsSubsetEval;
import weka.attributeSelection.InfoGainAttributeEval;
import weka.attributeSelection.Ranker;
import weka.attributeSelection.BestFirst;
import weka.filters.unsupervised.attribute.Standardize;
import weka.filters.unsupervised.attribute.NumericToNominal;
import weka.filters.Filter;
import weka.filters.supervised.instance.Resample;
import weka.filters.Filter;

//import weka.filters.unsupervised.attribute.NominalToBinary;

/**
 * Command-line driver that loads an ARFF dataset with Weka, cleans and
 * re-encodes it, and hands a 75/25 train/test split to {@code PROAnt}.
 *
 * <p>Pipeline executed by {@link #main(String[])}: load, remove duplicates,
 * remove instances with missing values, remove IQR outliers, convert the last
 * attribute to nominal and use it as the class, report the class distribution,
 * subsample, one-hot encode attributes 1 and 5, then run
 * {@code performOptimization}.
 *
 * <p>All state is kept in static fields, so the class is not intended for
 * concurrent use.
 */
public class Feras_PROANT {
	// Constants
	/** Passed as the last argument to {@code ThresholdUtil.compute}; its meaning is defined there. */
	static final int D = 4;
	/** Not referenced anywhere inside this class. */
	static final int folds = 10;
	/** Second argument handed to {@code PROAnt.setParameters}. */
	static int population = 100;
	/** First argument handed to {@code PROAnt.setParameters}. */
	static int generation = 200;
	// Dataset-related
	/** Class labels in the order reported by the class attribute. */
	static List<String> classes;
	static int numOfClasses;
	/** Number of attributes of the working dataset, class attribute included. */
	static int numOfAttributes;
	static int numOfPrototypes = 4;
	/** When true, CFS/BestFirst attribute selection runs before the split. */
	static boolean useAttributeSelection = true;
	// Results
	static List<Double> testingResults = new ArrayList<>();
	/** Train/test matrices produced by {@code to2DArray}; the last column is the class label index. */
	static double[][] TrainReal, TestReal;

	/**
	 * Runs the full preprocessing pipeline on {@code diabetes_prediction_dataset.arff}
	 * and starts the optimisation.
	 *
	 * @param args ignored
	 * @throws Exception if loading the file or any Weka filter fails
	 */
	public static void main(String[] args) throws Exception {
		// Load and configure data
		Instances data = loadAndPrepareData("diabetes_prediction_dataset.arff");
		System.out.println(
				" Loaded data: " + data.numInstances() + " instances, " + data.numAttributes() + " attributes");

		// Steps 1-3: duplicates, missing values, IQR outliers
		data = cleanData(data);
		System.out.println(
				" Cleaned data: " + data.numInstances() + " instances, " + data.numAttributes() + " attributes");

		// Convert the last attribute (diabetes) to nominal so it can serve as the class
		NumericToNominal convert = new NumericToNominal();
		convert.setAttributeIndices("last");
		convert.setInputFormat(data);
		data = Filter.useFilter(data, convert);
		ensureClassIndex(data);

		// Assign the static field directly; a local of the same name would hide it.
		classes = getClassLabels(data);
		System.out.println("✅ Classes: " + classes);

		printClassDistribution(data);

		// 4. Resample: keep 35% of the cleaned data, sampled without replacement
		Resample resample = new Resample();
		resample.setNoReplacement(true);
		resample.setSampleSizePercent(35);
		resample.setInputFormat(data);
		data = Filter.useFilter(data, resample);
		System.out.println("Resampled Data Size: " + data.numInstances());

		// One-hot encode the nominal attributes at positions 1 and 5
		// (gender and smoking_history according to the original author's note).
		weka.filters.unsupervised.attribute.NominalToBinary encoder = new weka.filters.unsupervised.attribute.NominalToBinary();
		encoder.setAttributeIndices("1,5");
		encoder.setInputFormat(data);
		data = Filter.useFilter(data, encoder);
		ensureClassIndex(data);

		// Refresh the dataset-wide statics after encoding changed the attribute count
		classes = getClassLabels(data);
		numOfClasses = data.numClasses();
		numOfAttributes = data.numAttributes();

		System.out.println(" Number of Attributes: " + numOfAttributes);
		System.out.println("  Number of Classes: " + numOfClasses);
		System.out.println(" Number of Prototypes: " + numOfPrototypes);

		// Header only: the per-instance dump that used to follow is disabled.
		System.out.println("🔍 First 5 Instances with Attribute Names:\n");

		performOptimization(data);
	}

	/**
	 * Reads a dataset and marks its last attribute as the class.
	 *
	 * @param path location understood by Weka's {@code DataSource.read}
	 * @return the loaded instances with the class index set to the last attribute
	 * @throws Exception if the source cannot be read
	 */
	private static Instances loadAndPrepareData(String path) throws Exception {
		Instances data = DataSource.read(path);
		data.setClassIndex(data.numAttributes() - 1);
		return data;
	}

	/** Sets the last attribute as the class when no class index is set yet. */
	private static void ensureClassIndex(Instances data) {
		if (data.classIndex() == -1) {
			data.setClassIndex(data.numAttributes() - 1);
		}
	}

	/**
	 * Removes duplicate instances, instances with a missing value in any
	 * attribute, and IQR outliers, in that order.
	 *
	 * @param data dataset to clean; not modified
	 * @return the cleaned dataset
	 * @throws Exception if a Weka filter fails
	 */
	private static Instances cleanData(Instances data) throws Exception {
		// 1. Remove duplicates
		weka.filters.unsupervised.instance.RemoveDuplicates duplicateRemover = new weka.filters.unsupervised.instance.RemoveDuplicates();
		duplicateRemover.setInputFormat(data);
		Instances cleaned = Filter.useFilter(data, duplicateRemover);

		// 2. Remove instances with missing values. A default-configured
		// RemoveWithValues filter does not do this, so delete per attribute instead.
		for (int att = 0; att < cleaned.numAttributes(); att++) {
			cleaned.deleteWithMissing(att);
		}

		// 3. Remove outliers using the IQR rule
		return removeOutliersUsingIQR(cleaned);
	}

	/** Prints how many instances fall into each class value. */
	private static void printClassDistribution(Instances data) {
		ensureClassIndex(data);

		int numClasses = data.numClasses();
		int[] classCounts = new int[numClasses];
		for (int i = 0; i < data.numInstances(); i++) {
			classCounts[(int) data.instance(i).classValue()]++;
		}

		System.out.println("🔍 Class Distribution:");
		for (int i = 0; i < numClasses; i++) {
			System.out.println("Class " + data.classAttribute().value(i) + ": " + classCounts[i] + " instances");
		}
	}

	/**
	 * Lists the values of the class attribute in declaration order.
	 *
	 * @param data dataset whose class index is already set
	 * @return a new mutable list of class labels
	 * @throws IllegalArgumentException if the class attribute is not nominal
	 */
	public static List<String> getClassLabels(Instances data) {
		if (!data.classAttribute().isNominal()) {
			throw new IllegalArgumentException("Class attribute must be nominal.");
		}

		int labelCount = data.classAttribute().numValues();
		List<String> labels = new ArrayList<>(labelCount);
		for (int i = 0; i < labelCount; i++) {
			labels.add(data.classAttribute().value(i));
		}
		return labels;
	}

	/**
	 * Shuffles the data with a fixed seed, optionally applies attribute
	 * selection, splits 75/25 into train and test, and runs {@code PROAnt}.
	 *
	 * <p>Updates the static fields {@code numOfAttributes}, {@code numOfClasses},
	 * {@code classes}, {@code TrainReal} and {@code TestReal}.
	 *
	 * @param data prepared dataset with a nominal class; not modified
	 * @throws Exception if a Weka filter fails
	 */
	private static void performOptimization(Instances data) throws Exception {
		Random rand = new Random(4);
		Instances randData = new Instances(data);
		randData.randomize(rand);

		// The supervised attribute-selection filter needs a class index
		ensureClassIndex(randData);

		if (useAttributeSelection) {
			randData = selectAttributes(randData);
		}

		// Derive every dataset-wide static from the dataset actually used below,
		// so this method does not depend on what the caller set beforehand.
		numOfAttributes = randData.numAttributes();
		numOfClasses = randData.numClasses();
		classes = getClassLabels(randData);

		PROAnt proAnt = new PROAnt();
		proAnt.setParameters(generation, population);
		proAnt.sendNumOfClassesAndAttributes(numOfClasses, numOfAttributes);

		// Split 75% for training, 25% for testing
		int trainSize = (int) Math.round(randData.numInstances() * 0.75);
		int testSize = randData.numInstances() - trainSize;
		Instances train = new Instances(randData, 0, trainSize);
		Instances test = new Instances(randData, trainSize, testSize);

		// to2DArray rejects an uninitialised label list before any conversion
		TrainReal = to2DArray(train);
		TestReal = to2DArray(test);

		double[][][][] thresholds = computeThresholds(TrainReal);
		double[][][] weights = randomWeights();
		proAnt.sendTrainingTestingAndThresholds(TrainReal, thresholds, weights, TestReal);

		double result = proAnt.run();
		System.out.println("  ** the results: " + result);
	}

	/**
	 * Applies CFS subset evaluation with best-first search and prints the
	 * attributes that survive.
	 *
	 * @param data dataset with its class index set
	 * @return the reduced dataset
	 * @throws Exception if the filter fails
	 */
	private static Instances selectAttributes(Instances data) throws Exception {
		AttributeSelection filter = new AttributeSelection();
		filter.setEvaluator(new CfsSubsetEval());
		filter.setSearch(new BestFirst());
		filter.setInputFormat(data);
		Instances reduced = Filter.useFilter(data, filter);

		System.out.println("Selected attributes: " + reduced.numAttributes());
		System.out.println("Selected attributes:");
		for (int i = 0; i < reduced.numAttributes(); i++) {
			System.out.println("- " + reduced.attribute(i).name());
		}
		return reduced;
	}

	/**
	 * Builds the initial weight tensor indexed as [class][prototype][attribute],
	 * each entry drawn uniformly from [0, 1).
	 *
	 * <p>The generator is deliberately left unseeded, as in the original code,
	 * so weights differ from run to run.
	 */
	private static double[][][] randomWeights() {
		double[][][] weights = new double[numOfClasses][numOfPrototypes][numOfAttributes];
		Random weightRandom = new Random();
		for (int cl = 0; cl < numOfClasses; cl++) {
			for (int p = 0; p < numOfPrototypes; p++) {
				for (int att = 0; att < numOfAttributes; att++) {
					weights[cl][p][att] = weightRandom.nextDouble();
				}
			}
		}
		return weights;
	}

	/**
	 * Converts instances to a dense matrix with {@code numOfAttributes} columns:
	 * the first {@code numOfAttributes - 1} columns hold the attribute values at
	 * the same positions, the last column holds the index of the instance's
	 * class label within {@code classes}.
	 *
	 * @param data instances to convert
	 * @return one row per instance
	 * @throws IllegalStateException if {@code classes} has not been initialised
	 */
	private static double[][] to2DArray(Instances data) {
		if (classes == null || classes.isEmpty()) {
			throw new IllegalStateException("Class labels have not been initialized. Call getClassLabels() first.");
		}

		int labelColumn = numOfAttributes - 1;
		double[][] result = new double[data.numInstances()][numOfAttributes];
		for (int i = 0; i < data.numInstances(); i++) {
			Instance inst = data.instance(i);
			for (int j = 0; j < labelColumn; j++) {
				result[i][j] = inst.value(j);
			}
			result[i][labelColumn] = classes.indexOf(inst.stringValue(inst.classAttribute()));
		}
		return result;
	}

	/** Delegates to {@code ThresholdUtil.compute} with the current static configuration. */
	private static double[][][][] computeThresholds(double[][] trainData) {
		return ThresholdUtil.compute(trainData, numOfClasses, numOfPrototypes, numOfAttributes, D);
	}

	/**
	 * Drops instances whose value for any of the attributes {@code age},
	 * {@code bmi}, {@code HbA1c_level} or {@code blood_glucose_level} lies
	 * outside {@code [Q1 - 1.5*IQR, Q3 + 1.5*IQR]}.
	 *
	 * <p>Attributes are processed one after another, so the quartiles of a later
	 * attribute are computed on the rows that survived the earlier ones.
	 * Quartiles are computed from observed values only; a row whose value for an
	 * inspected attribute is missing is dropped.
	 *
	 * @param data dataset to filter; not modified
	 * @return a new dataset containing only the retained instances
	 * @throws IllegalArgumentException if one of the expected attributes is absent
	 * @throws Exception declared for interface compatibility
	 */
	public static Instances removeOutliersUsingIQR(Instances data) throws Exception {
		Instances filteredData = new Instances(data);

		// List of attributes the IQR rule is applied to
		List<String> numericAttributes = Arrays.asList("age", "bmi", "HbA1c_level", "blood_glucose_level");

		for (String attrName : numericAttributes) {
			weka.core.Attribute attribute = filteredData.attribute(attrName);
			if (attribute == null) {
				throw new IllegalArgumentException("Attribute not found in dataset: " + attrName);
			}
			int attrIndex = attribute.index();

			// Collect observed values; missing ones would distort the quartiles
			List<Double> values = new ArrayList<>(filteredData.numInstances());
			for (int i = 0; i < filteredData.numInstances(); i++) {
				Instance inst = filteredData.instance(i);
				if (!inst.isMissing(attrIndex)) {
					values.add(inst.value(attrIndex));
				}
			}
			if (values.isEmpty()) {
				// No observed value at all: no row can lie inside the bounds
				return new Instances(filteredData, 0);
			}

			// Sort and calculate IQR bounds
			Collections.sort(values);
			double q1 = percentile(values, 25);
			double q3 = percentile(values, 75);
			double iqr = q3 - q1;
			double lowerBound = q1 - 1.5 * iqr;
			double upperBound = q3 + 1.5 * iqr;

			// Keep instances inside [lowerBound, upperBound]
			Instances kept = new Instances(filteredData, filteredData.numInstances());
			for (int i = 0; i < filteredData.numInstances(); i++) {
				Instance inst = filteredData.instance(i);
				if (inst.isMissing(attrIndex)) {
					continue;
				}
				double val = inst.value(attrIndex);
				if (val >= lowerBound && val <= upperBound) {
					kept.add(inst);
				}
			}
			filteredData = kept;
		}

		return filteredData;
	}

	/**
	 * Returns the nearest-rank percentile of an ascending list: the element at
	 * position {@code ceil(percentile / 100 * size) - 1}, clamped to the valid
	 * index range.
	 *
	 * @param sortedValues values in ascending order; must not be null or empty
	 * @param percentile   requested percentile; values outside 0..100 are clamped
	 *                     to the first or last element
	 * @return the selected element
	 * @throws IllegalArgumentException if {@code sortedValues} is null or empty
	 */
	public static double percentile(List<Double> sortedValues, double percentile) {
		if (sortedValues == null || sortedValues.isEmpty()) {
			throw new IllegalArgumentException("Cannot compute a percentile of an empty list.");
		}
		int lastIndex = sortedValues.size() - 1;
		int index = (int) Math.ceil(percentile / 100.0 * sortedValues.size()) - 1;
		return sortedValues.get(Math.max(0, Math.min(index, lastIndex)));
	}

}
