import networkx as nx
from node_importance import calculate_node_importance
from graph_embedding import create_node_embeddings
from label_propagation import LabelPropagation
from evaluate import evaluate_communities
import numpy as np
import os


def load_graph_from_txt(file_path):
    """Load an undirected graph from a whitespace-separated edge-list file.

    Every non-blank, non-comment line must start with two integer node ids
    ``u v``. Any further columns on the line (e.g. an edge weight) are
    ignored. Lines whose first non-whitespace character is ``#`` are comments.

    Parameters:
        file_path: Path to the edge-list text file (read as UTF-8).
    Returns:
        A ``networkx.Graph`` with one ``add_edge(u, v)`` call per data line.
    Raises:
        ValueError: If a data line has fewer than two columns or a node id
            that is not an integer. The message names the file and line number.
    """
    G = nx.Graph()
    with open(file_path, 'r', encoding='utf-8') as f:
        for line_no, line in enumerate(f, 1):
            stripped = line.strip()
            # Test for comments *after* stripping so indented "# ..." lines
            # are skipped instead of being parsed as node ids.
            if not stripped or stripped.startswith('#'):
                continue
            fields = stripped.split()
            if len(fields) < 2:
                raise ValueError(
                    f"{file_path}:{line_no}: expected two node ids, got {stripped!r}")
            try:
                n1, n2 = int(fields[0]), int(fields[1])
            except ValueError as err:
                raise ValueError(
                    f"{file_path}:{line_no}: node ids must be integers, got {stripped!r}"
                ) from err
            G.add_edge(n1, n2)
    return G


def save_communities(communities, output_file="communities.txt"):
    """Write detected communities to a UTF-8 text file, one community per line.

    Output format::

        Detected Number of Communities: <count>
        <blank line>
        community1: [sorted node ids]
        community2: [sorted node ids]
        ...

    Any existing file is overwritten. A missing parent directory is created first.

    Parameters:
        communities: Sized sequence of node collections (sets, lists, ...).
            The nodes within each community must be mutually comparable so
            they can be sorted.
        output_file: Destination path.
    """
    output_dir = os.path.dirname(output_file)
    if output_dir:
        # exist_ok avoids the race between a separate exists() check and makedirs().
        os.makedirs(output_dir, exist_ok=True)

    with open(output_file, 'w', encoding='utf-8') as f:
        f.write(f"Detected Number of Communities: {len(communities)}\n\n")
        for i, community in enumerate(communities, 1):
            # sorted() accepts any iterable; no intermediate list() is needed.
            f.write(f"community{i}: {sorted(community)}\n")


def find_overlapping_nodes(communities):
    """Map each node that appears in more than one community to its community ids.

    Community ids are the 1-based positions of the communities in
    ``communities``. Nodes that appear in only one community are left out.

    Returns:
        dict of node -> list of community ids, in the order they were encountered.
    """
    memberships = {}
    for community_id, members in enumerate(communities, start=1):
        for member in members:
            memberships.setdefault(member, []).append(community_id)

    overlapping = {}
    for member, ids in memberships.items():
        if len(ids) >= 2:
            overlapping[member] = ids
    return overlapping


def _rank_nodes_by_importance(G):
    """Return the nodes of G ordered from highest to lowest importance score."""
    node_importance = calculate_node_importance(G)
    # sorted() stays stable with reverse=True, so nodes with equal scores keep
    # the order in which calculate_node_importance returned them.
    ranked = sorted(node_importance.items(), key=lambda item: item[1], reverse=True)
    return [node for node, _ in ranked]


def _cosine_similarity_matrix(nodes, embeddings):
    """Return a dict-of-dicts of pairwise cosine similarities between node embeddings.

    ``result[i][j]`` is the cosine similarity of ``embeddings[i]`` and
    ``embeddings[j]``. The diagonal is fixed at 1.0.
    """
    vectors = {node: embeddings[node] for node in nodes}
    # Each norm is loop-invariant: compute it once per node rather than once
    # per pair. The arithmetic per pair is unchanged, so results are identical.
    norms = {node: np.linalg.norm(vec) for node, vec in vectors.items()}

    similarity_matrix = {}
    for i in nodes:
        vec_i, norm_i = vectors[i], norms[i]
        row = {}
        for j in nodes:
            if i == j:
                row[j] = 1.0
                continue
            denom = norm_i * norms[j]
            # A zero-length embedding has no direction. Treat it as dissimilar
            # rather than producing NaN/inf from a division by zero.
            row[j] = np.dot(vec_i, vectors[j]) / denom if denom else 0.0
        similarity_matrix[i] = row
    return similarity_matrix


def GELPA_OCD(G, params):
    """
    Main algorithm function: detect overlapping communities in G.

    Pipeline:
        1. Rank nodes by ``calculate_node_importance`` (descending). This
           ranking is the label-propagation update order.
        2. Embed nodes with ``create_node_embeddings``.
        3. Build the pairwise cosine-similarity matrix of the embeddings.
        4. Run ``LabelPropagation`` over that matrix.

    Parameters:
        G: NetworkX graph object.
        params: Algorithm parameters dictionary. Keys read here:
            'embedding_dim', 'walk_length', 'num_walks', 'p', 'q', 'workers'
            (passed to create_node_embeddings); 'threshold' (passed to
            LabelPropagation); 'max_iter' (passed to LabelPropagation.run).
    Returns:
        communities: Whatever ``LabelPropagation.run`` returns. Presumably a
            list of node collections, one per (possibly overlapping)
            community; TODO confirm against label_propagation.
    """
    # 1. Calculate and rank node importance
    update_sequence = _rank_nodes_by_importance(G)

    # 2. Create node embeddings
    embeddings = create_node_embeddings(
        G,
        dimensions=params['embedding_dim'],
        walk_length=params['walk_length'],
        num_walks=params['num_walks'],
        p=params['p'],
        q=params['q'],
        workers=params['workers']
    )

    # 3. Construct similarity matrix
    similarity_matrix = _cosine_similarity_matrix(list(G.nodes()), embeddings)

    # 4. Label propagation
    lp = LabelPropagation(G, similarity_matrix, threshold=params['threshold'])
    return lp.run(update_sequence, params['max_iter'])


def main():
    """Run GELPA-OCD on the configured dataset, then save and report the results.

    Steps:
        1. Load the edge list at ``params['input_file']``.
        2. Detect communities.
        3. Write the communities to ``params['output_file']``.
        4. Print summary metrics, then append the EQ value, average community
           size and overlapping-node count to that same file.
    """
    # Algorithm parameter settings
    params = {
        # Node embedding parameters (forwarded to create_node_embeddings)
        'embedding_dim': 128,
        'walk_length': 50,
        'num_walks': 100,
        'p': 1.0,
        'q': 0.5,
        'workers': 4,

        # Label propagation parameters
        'max_iter': 50,
        'threshold': 0.3,

        # Input and output settings
        # NOTE(review): the input is football.txt but the output is named
        # 1_karate.txt. Confirm which dataset this run is meant for.
        'input_file': "03Real_Datasets/football.txt",
        'output_file': "03Real_Datasets_Output/1_karate.txt",
    }

    print(f"Loading data: {params['input_file']}")
    G = load_graph_from_txt(params['input_file'])
    print(f"Network Information: Number of nodes={G.number_of_nodes()}, Number of edges={G.number_of_edges()}")

    # Was `GELPA-OCD(...)`, which Python parses as `GELPA - OCD(...)` and
    # raises NameError. The function is GELPA_OCD.
    communities = GELPA_OCD(G, params)

    # Calculate evaluation metrics
    metrics = evaluate_communities(G, communities)

    # Save results
    save_communities(communities, params['output_file'])

    # Output results
    print("\nDetection results:")
    print(f"Detected Number of Communities: {len(communities)}")
    print(f"EQ Value: {metrics['EQ']:.4f}")
    print(f"Average Community Size: {metrics['avg_community_size']:.2f}")

    overlapping_nodes = find_overlapping_nodes(communities)
    print(f"\nNumber of Overlapping Nodes: {len(overlapping_nodes)}")

    # Append to the file save_communities just wrote. Reuse the configured
    # path instead of a second hard-coded copy that could drift from params.
    with open(params['output_file'], 'a', encoding='utf-8') as f:
        f.write(f"\nEQ Value: {metrics['EQ']:.4f}")
        f.write(f"\nAverage Community Size: {metrics['avg_community_size']:.2f}")
        f.write(f"\nNumber of Overlapping Nodes: {len(overlapping_nodes)}")


if __name__ == "__main__":
    # Run the full pipeline only when this file is executed as a script,
    # not when it is imported as a module.
    main()