import numpy as np
from collections import defaultdict
import os


def load_ground_truth(file_path):
    """
    Load the true community partitioning of an LFR network (supports overlapping communities).

    File format: each non-empty line is "nodeID communityID1 communityID2 ...",
    fields separated by any whitespace (spaces or tabs). Lines whose fields are
    not all integers are reported and skipped. A line holding only a node ID
    contributes no membership.

    Parameters:
        file_path: Path to the community file (read as UTF-8 text)
    Returns:
        communities: list of node sets, one per community ID, in order of first appearance
        all_nodes: set of nodes that belong to at least one community
        overlapping_nodes: set of nodes that belong to more than one community
    Raises:
        ValueError: if no community membership could be parsed from the file
    """
    # Map community ID -> set of member nodes
    communities_dict = defaultdict(set)

    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            parts = line.split()
            if not parts:  # Skip empty lines
                continue

            # Only the integer conversion can fail; keep the try body minimal
            try:
                node = int(parts[0])
                community_ids = [int(p) for p in parts[1:]]
            except ValueError:
                print(f"Warning: Skipping malformatted line: {line.strip()}")
                continue

            # Add the node to each community it belongs to
            for comm_id in community_ids:
                communities_dict[comm_id].add(node)

    communities = list(communities_dict.values())

    if not communities:
        raise ValueError("No valid community partitioning found in the file")

    print(f"Successfully loaded community partitioning: {len(communities)} communities")

    # A node seen a second time while scanning communities is an overlapping node
    all_nodes = set()
    overlapping_nodes = set()
    for comm in communities:
        for node in comm:
            if node in all_nodes:
                overlapping_nodes.add(node)
            all_nodes.add(node)
    print(f"Total nodes: {len(all_nodes)}, Overlapping nodes: {len(overlapping_nodes)}")

    return communities, all_nodes, overlapping_nodes


def load_detected_communities(file_path):
    """
    Load detected communities from an algorithm output file.

    Expected format: a title line and an empty line, followed by lines such as
    "Community1: [1, 2, 3, 4, 5]". The "community" keyword is matched
    case-insensitively, and node IDs may be separated by "," with or without
    spaces. If reading or parsing fails, the file is re-read with
    load_ground_truth() ("nodeID communityID ..." format) as a fallback.

    Parameters:
        file_path: Path to the detected communities file (UTF-8 text)
    Returns:
        communities: list of node sets, in file order
        all_nodes: set of nodes appearing in any community
        overlapping_nodes: set of nodes appearing in more than one community
    """
    communities = []
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            lines = f.readlines()

        # Skip first two lines (title and empty line)
        for line in lines[2:]:
            # Case-insensitive: the documented format is "Community1: [...]"
            if not line.strip() or "community" not in line.lower():
                continue
            _, sep, community_str = line.partition(':')
            if not sep:
                continue
            # Remove surrounding brackets, then split on commas (spaces optional)
            nodes_str = community_str.strip().strip('[]')
            if nodes_str.strip():
                nodes = {int(token) for token in nodes_str.split(',') if token.strip()}
                communities.append(nodes)

        all_nodes = set()
        overlapping_nodes = set()
        for comm in communities:
            for node in comm:
                if node in all_nodes:
                    overlapping_nodes.add(node)
                all_nodes.add(node)

        print(f"Detected communities: {len(communities)}")
        print(f"Detected nodes: {len(all_nodes)}")
        print(f"Detected overlapping nodes: {len(overlapping_nodes)}")
    except (OSError, ValueError) as e:
        # Best-effort fallback: the file may be in "nodeID communityID ..." format
        print(f"Error reading detected communities: {e}")
        print("Trying generic loading function...")
        communities, all_nodes, overlapping_nodes = load_ground_truth(file_path)

    return communities, all_nodes, overlapping_nodes


def get_node_to_communities_map(communities):
    """
    Invert a community list into a node -> community-index lookup.

    The community index is the position of the community in `communities`.
    The result is a defaultdict(set), so looking up a node that belongs to no
    community yields an empty set instead of raising KeyError.
    """
    node_map = defaultdict(set)
    memberships = (
        (member, comm_idx)
        for comm_idx, members in enumerate(communities)
        for member in members
    )
    for member, comm_idx in memberships:
        node_map[member].add(comm_idx)
    return node_map


def calculate_omega_index(communities1, communities2, adjusted=False):
    """
    Compare two (possibly overlapping) community partitions by node-pair agreement.

    Every unordered pair of nodes drawn from the union of both partitions is
    examined; a node missing from one partition shares no community there.

    adjusted=False (default, the historical behaviour of this function):
        score = (Agree_same + Agree_different) / Total_pairs
        where a pair is "same" in a partition if its two nodes share at least
        one community. This is an unadjusted, Rand-style agreement score in
        [0, 1]; it is NOT chance-corrected.

    adjusted=True:
        The Omega index (Collins & Dent, 1988). A pair agrees when both
        partitions place its nodes together in exactly the same number of
        communities, and the observed agreement is corrected for chance:
            omega = (observed - expected) / (1 - expected)
        1 means identical partitions, about 0 means chance-level agreement,
        and negative values mean worse than chance.

    Parameters:
        communities1: First community partition, an iterable of node sets
        communities2: Second community partition, an iterable of node sets
        adjusted: If True, compute the chance-corrected Omega index
    Returns:
        The score as a float, or 0 when fewer than two distinct nodes exist.
    """
    # Materialize so tuples/generators work and can be iterated twice
    communities1 = list(communities1)
    communities2 = list(communities2)

    all_nodes = set()
    for comm in communities1 + communities2:
        all_nodes.update(comm)

    # defaultdict-backed maps: nodes absent from a partition map to an empty set
    node_to_comms1 = get_node_to_communities_map(communities1)
    node_to_comms2 = get_node_to_communities_map(communities2)

    nodes_list = list(all_nodes)
    total_pairs = len(nodes_list) * (len(nodes_list) - 1) // 2
    if total_pairs == 0:
        return 0

    agreements = 0
    # Histograms of "number of shared communities" per pair (adjusted mode only)
    shared_counts1 = defaultdict(int)
    shared_counts2 = defaultdict(int)

    for i, node1 in enumerate(nodes_list):
        comms1_of_node1 = node_to_comms1[node1]
        comms2_of_node1 = node_to_comms2[node1]
        for node2 in nodes_list[i + 1:]:
            if adjusted:
                k1 = len(comms1_of_node1 & node_to_comms1[node2])
                k2 = len(comms2_of_node1 & node_to_comms2[node2])
                shared_counts1[k1] += 1
                shared_counts2[k2] += 1
                if k1 == k2:
                    agreements += 1
            else:
                same1 = not comms1_of_node1.isdisjoint(node_to_comms1[node2])
                same2 = not comms2_of_node1.isdisjoint(node_to_comms2[node2])
                if same1 == same2:
                    agreements += 1

    observed = agreements / total_pairs
    if not adjusted:
        return observed

    # Expected agreement if pair co-membership counts were independent
    expected = sum(
        count * shared_counts2.get(k, 0) for k, count in shared_counts1.items()
    ) / (total_pairs ** 2)
    if expected >= 1.0:
        # Every pair has the same count in both partitions: perfect agreement
        return 1.0
    return (observed - expected) / (1.0 - expected)


def calculate_overlapping_f1(ground_truth, detected_communities):
    """
    Score how well overlapping nodes were identified, as a binary classification.

    A node is "overlapping" in a partition when it occurs in more than one
    community. Only nodes present in both partitions are scored; nodes that
    appear in just one of them are ignored.

    Parameters:
        ground_truth: True community partition, a list of node collections
        detected_communities: Detected community partition, a list of node collections
    Returns:
        (f1, precision, recall) as floats; each is 0.0 when undefined.
    """
    def _membership_stats(partition):
        # Return (every node seen, nodes seen in more than one community)
        occurrences = defaultdict(int)
        for community in partition:
            for member in community:
                occurrences[member] += 1
        multi_members = {member for member, hits in occurrences.items() if hits > 1}
        return set(occurrences), multi_members

    true_nodes, true_overlap = _membership_stats(ground_truth)
    found_nodes, found_overlap = _membership_stats(detected_communities)

    # Restrict scoring to node IDs known to both partitions
    scored = true_nodes & found_nodes
    true_pos_set = true_overlap & scored
    pred_pos_set = found_overlap & scored

    tp = len(true_pos_set & pred_pos_set)
    fp = len(pred_pos_set - true_pos_set)
    fn = len(true_pos_set - pred_pos_set)

    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0

    return f1, precision, recall


def evaluate_metrics(ground_truth_file, detected_communities_file, output_file=None, compute_omega=False):
    """
    Calculate advanced metrics: overlapping node F1/precision/recall and, optionally, the Omega index.

    Parameters:
        ground_truth_file: Path to true community partition file
        detected_communities_file: Path to detected community partition file
        output_file: Path to output results file, if None then only print without saving
        compute_omega: If True, also compute calculate_omega_index() on the two
            partitions and report it; off by default because it examines every node pair
    Returns:
        metrics: Dictionary with keys "overlap_f1", "overlap_precision" and
            "overlap_recall", plus "omega" when compute_omega is True
    """
    # Load true and detected community partitions
    ground_truth, _, true_overlapping_nodes = load_ground_truth(ground_truth_file)
    detected_communities, _, detected_overlapping_nodes = load_detected_communities(detected_communities_file)

    overlap_f1, overlap_precision, overlap_recall = calculate_overlapping_f1(
        ground_truth, detected_communities
    )

    metrics = {
        "overlap_f1": overlap_f1,
        "overlap_precision": overlap_precision,
        "overlap_recall": overlap_recall,
    }
    if compute_omega:
        metrics["omega"] = calculate_omega_index(ground_truth, detected_communities)

    # Build the report once so console and file output cannot drift apart
    report = ["Advanced evaluation metrics:"]
    if "omega" in metrics:
        report.append(f"Omega index: {metrics['omega']:.4f}")
    report += [
        f"Overlapping node F1 score: {metrics['overlap_f1']:.4f}",
        f"Overlapping node precision: {metrics['overlap_precision']:.4f}",
        f"Overlapping node recall: {metrics['overlap_recall']:.4f}",
        "",
        # Overlapping node counts, useful for debugging
        "Debug information:",
        f"True overlapping node count: {len(true_overlapping_nodes)}",
        f"Detected overlapping node count: {len(detected_overlapping_nodes)}",
    ]

    print("\n" + "\n".join(report))

    if output_file:
        with open(output_file, 'w', encoding='utf-8') as f:
            f.write("\n".join(report) + "\n")

        print(f"Results saved to: {output_file}")

    return metrics


if __name__ == "__main__":
    # Evaluate a single LFR network. The default paths below can be overridden
    # on the command line:
    #   python <this_script> [ground_truth_file] [detected_file] [output_file]
    import argparse

    parser = argparse.ArgumentParser(
        description="Evaluate detected communities against an LFR ground truth."
    )
    parser.add_argument(
        "ground_truth_file", nargs="?",
        default="04LFR_networks/LFR4/LFR4-3/community.dat",
        help="True community partition file",
    )
    parser.add_argument(
        "detected_file", nargs="?",
        default="04LFR_results/LFR4/LFR4-3/network_communities.txt",
        help="Detected community partition file",
    )
    parser.add_argument(
        "output_file", nargs="?",
        default="04LFR_results/LFR4/LFR4-3/advanced_metrics.txt",
        help="Evaluation results output file",
    )
    args = parser.parse_args()

    ground_truth_file = args.ground_truth_file
    detected_file = args.detected_file
    output_file = args.output_file

    # Validate inputs first; exit non-zero so callers/scripts can detect failure
    if not os.path.exists(ground_truth_file):
        print(f"Error: True community file not found: {ground_truth_file}")
        raise SystemExit(1)
    if not os.path.exists(detected_file):
        print(f"Error: Detected community file not found: {detected_file}")
        raise SystemExit(1)

    # Ensure output directory exists (dirname is "" for a bare file name)
    output_dir = os.path.dirname(output_file)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    print(f"True community file: {ground_truth_file}")
    print(f"Detected community file: {detected_file}")
    print(f"Output results file: {output_file}")

    metrics = evaluate_metrics(ground_truth_file, detected_file, output_file)

    # Display important metrics
    print("\n===== Evaluation Results =====")
    print(f"Overlapping node F1 score: {metrics['overlap_f1']:.4f}")

    print("\nEvaluation complete! Full results have been saved to the output file.")