import networkx as nx
import numpy as np
from collections import defaultdict
import os
import math
from main import ocd_gempa, load_graph_from_txt, save_communities, evaluate_communities
def load_ground_truth(file_path):
    """
    Read the ground-truth (possibly overlapping) communities of an LFR network.

    Every non-empty line is expected to contain whitespace-separated integers:
    a node ID followed by the ID of each community the node belongs to.
    Lines that cannot be parsed as integers are reported and skipped.

    Returns: list of sets of node IDs, one set per community, ordered by the
             first appearance of each community ID in the file
    Raises: ValueError if no community could be read from the file
    """
    # community ID -> set of member node IDs
    members_by_community = defaultdict(set)

    with open(file_path, 'r') as handle:
        for raw_line in handle:
            tokens = raw_line.split()
            if len(tokens) == 0:
                # Blank line
                continue

            try:
                # First integer is the node, the rest are its community IDs
                node, *community_ids = map(int, tokens)
            except (ValueError, IndexError):
                print(f"Warning: Skip malformed lines: {raw_line.strip()}")
                continue

            for community_id in community_ids:
                members_by_community[community_id].add(node)

    communities = list(members_by_community.values())
    if len(communities) == 0:
        raise ValueError("Failed to read valid community partitions from the file")

    print(f"Successfully loaded community partitions: {len(communities)} communities in total")

    # A node is overlapping when it is a member of more than one community
    membership_count = defaultdict(int)
    for members in communities:
        for node in members:
            membership_count[node] += 1
    n_overlapping = sum(1 for count in membership_count.values() if count > 1)
    print(f"Total number of nodes: {len(membership_count)}, Number of overlapping nodes: {n_overlapping}")

    return communities


def calculate_entropy(communities, n_nodes):
    """
    Calculate the entropy H(X) of community partition X

    Each community contributes -p * log2(p + eps), where
    p = len(community) / n_nodes. Empty communities contribute nothing.

    communities: iterable of sized collections of nodes
    n_nodes: total number of nodes used to normalise community sizes
    Returns: the entropy as a float; 0.0 when n_nodes is not positive
    """
    # Without any nodes there is no probability mass; returning 0.0 here
    # avoids a ZeroDivisionError when every community is empty.
    if n_nodes <= 0:
        return 0.0

    # eps keeps the log argument strictly positive; it is retained so that
    # previously computed scores stay numerically unchanged.
    eps = 1e-10
    entropy = 0.0

    for comm in communities:
        p = len(comm) / n_nodes
        if p > 0:
            entropy -= p * np.log2(p + eps)

    return entropy


def calculate_conditional_entropy(communities_x, communities_y, n_nodes):
    """
    Calculate the conditional entropy H(X|Y)

    For every pair (x, y) with a non-empty intersection, accumulates
    -p_xy * log2((p_xy + eps) / (p_y + eps)), where
    p_xy = |x ∩ y| / n_nodes and p_y = |y| / n_nodes.

    communities_x: iterable of sets of nodes (partition X)
    communities_y: iterable of collections of nodes (partition Y)
    n_nodes: total number of nodes used to normalise the counts
    Returns: the conditional entropy as a float; 0.0 when n_nodes is not
             positive
    """
    # Without any nodes there is nothing to condition on; returning 0.0 here
    # avoids a ZeroDivisionError when every community is empty.
    if n_nodes <= 0:
        return 0.0

    # eps guards the ratio inside the logarithm against a zero denominator
    eps = 1e-10
    conditional_entropy = 0.0

    for comm_y in communities_y:
        p_y = len(comm_y) / n_nodes  # P(y)
        if p_y <= 0:
            # An empty community in Y carries no probability mass
            continue

        for comm_x in communities_x:
            intersection = len(comm_x.intersection(comm_y))
            if intersection > 0:
                p_xy = intersection / n_nodes  # P(x, y)
                conditional_entropy -= p_xy * np.log2((p_xy + eps) / (p_y + eps))

    return conditional_entropy


def calculate_nmi_for_overlapping_communities(communities1, communities2):
    """
    Compute a Normalized Mutual Information (NMI) score between two
    (possibly overlapping) community lists.

    The score is
        (H(X) - H(X|Y) + H(Y) - H(Y|X)) / (2 * max(H(X), H(Y)))
    clamped to the range [0, 1]. When the larger of the two entropies is
    exactly zero the score is defined as 0.0.
    """
    # The node universe is every node that appears in either community list
    node_universe = set().union(*(communities1 + communities2))
    n_nodes = len(node_universe)

    # Calculate the marginal entropies H(X) and H(Y)
    entropy_1 = calculate_entropy(communities1, n_nodes)
    entropy_2 = calculate_entropy(communities2, n_nodes)

    # Calculate the conditional entropies H(X|Y) and H(Y|X)
    cond_1_given_2 = calculate_conditional_entropy(communities1, communities2, n_nodes)
    cond_2_given_1 = calculate_conditional_entropy(communities2, communities1, n_nodes)

    # Calculate NMI; a zero normaliser would make the ratio undefined
    larger_entropy = max(entropy_1, entropy_2)
    if larger_entropy == 0:
        return 0.0

    numerator = entropy_1 - cond_1_given_2 + entropy_2 - cond_2_given_1
    score = numerator / (2 * larger_entropy)

    return max(0.0, min(1.0, score))


def evaluate_lfr_network(params=None):
    """Evaluate the community detection results on the LFR benchmark networks

    Loads the network and its ground-truth communities, runs ``ocd_gempa``,
    computes the evaluation metrics and the NMI against the ground truth,
    saves the detected communities, then prints the summary and writes it
    to ``network_metrics.txt`` in the output directory.

    Args:
        params: Optional dict whose entries override the built-in default
            settings (node2vec parameters, label propagation parameters,
            input/output paths). Calling with no argument uses the defaults
            unchanged.

    Returns:
        None
    """
    # Default algorithm parameter settings
    defaults = {
        # node2vec parameters
        'embedding_dim': 128,
        'walk_length': 80,
        'num_walks': 100,
        'p': 1.0,
        'q': 0.5,
        'workers': 4,

        # Label propagation parameters
        'max_iter': 200,
        'threshold': 0.3,

        # Input and output settings
        'network_file': "04LFR_networks/D/LFR19/network.dat",  # LFR network files
        # NOTE(review): unlike network_file and output_dir, this path has no
        # "D/" segment -- confirm that this is the intended ground-truth file.
        'community_file': "04LFR_networks/LFR19/community.dat",  # Ground truth community partition files
        'output_dir': "04LFR_results/D/LFR19"  # Output directory for results
    }
    # Caller-supplied entries take precedence over the defaults
    settings = {**defaults, **(params or {})}

    os.makedirs(settings['output_dir'], exist_ok=True)

    print(f"Loading network data: {settings['network_file']}")
    G = load_graph_from_txt(settings['network_file'])
    print(f"Network information: number of nodes={G.number_of_nodes()}, number of edges={G.number_of_edges()}")

    # Loading ground truth community partitions
    print(f"Loading real community divisions: {settings['community_file']}")
    ground_truth = load_ground_truth(settings['community_file'])

    print("\nStarting to run community detection algorithm...")
    detected_communities = ocd_gempa(G, settings)

    # Calculating evaluation metrics
    metrics = evaluate_communities(G, detected_communities)

    # Calculating NMI against the ground truth
    nmi = calculate_nmi_for_overlapping_communities(ground_truth, detected_communities)

    # Saving the detected communities
    output_file = os.path.join(settings['output_dir'], 'network_communities.txt')
    save_communities(detected_communities, output_file)

    # The same summary lines are printed and written to the metrics file,
    # so they are built once to keep the two outputs in sync.
    report_lines = [
        f"Detected Number of Communities: {len(detected_communities)}",
        f"True Number of Communities: {len(ground_truth)}",
        f"NMI Value: {nmi:.4f}",
        f"Extended Modularity (EQ): {metrics['EQ']:.4f}",
    ]

    # Outputting results
    print("\nDetection results:")
    for report_line in report_lines:
        print(report_line)

    # Saving evaluation metrics
    metrics_file = os.path.join(settings['output_dir'], 'network_metrics.txt')
    with open(metrics_file, 'w', encoding='utf-8') as f:
        f.writelines(f"{report_line}\n" for report_line in report_lines)

# Script entry point: run the LFR evaluation only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    evaluate_lfr_network()
