class LabelPropagation:
    """Overlapping community detection by weighted label propagation.

    Every node carries a dict mapping candidate labels to belonging
    coefficients. Nodes are updated in a caller-supplied order. Each update
    recomputes a node's labels from its neighbours' current labels, weighted
    by ``similarity_matrix``. Labels whose coefficient does not exceed the
    threshold are dropped, and the remaining coefficients are renormalized to
    sum to 1.
    """

    def __init__(self, G, similarity_matrix, threshold=None):
        """
        Args:
            G: NetworkX Graph object (only ``nodes()`` and ``neighbors()`` are used).
            similarity_matrix: Node similarity lookup, indexed as
                ``similarity_matrix[node][neighbor]`` for every node and each
                of its neighbours.
            threshold: Label filtering threshold. A label is kept only if its
                belonging coefficient is strictly greater than this value.
                If None, each update uses 1 / (number of candidate labels).
        """
        self.G = G
        self.similarity_matrix = similarity_matrix
        self.labels = {}  # node -> {label: belonging coefficient}
        self.threshold = threshold

    def initialize_labels(self):
        """Reset every node's labels to ``{node: 1.0}`` (its own ID, full weight).

        A fresh dict is built, so nodes removed from ``G`` since a previous
        run leave no stale entries that would later be reported as
        communities.
        """
        self.labels = {node: {node: 1.0} for node in self.G.nodes()}

    def update_node_labels(self, node):
        """
        Update the label set of a single node in place.

        The belonging coefficient of each label is the sum, over neighbours,
        of ``label_weight * similarity`` divided by the total of those products
        over all neighbours and labels. Labels not strictly above the threshold
        are dropped. If none survive, the single strongest label is kept. The
        survivors are then renormalized to sum to 1. If the total is 0 (e.g.
        an isolated node), the node's labels are left unchanged.
        """
        # Gather every (label, label_weight * similarity) contribution in one
        # pass, so the graph and similarity matrix are queried once per
        # neighbour. The denominator is accumulated in the same order as the
        # contributions, keeping the float results bit-identical to a
        # two-pass computation.
        contributions = []
        denominator = 0
        for neighbor in self.G.neighbors(node):
            weight = self.similarity_matrix[node][neighbor]
            for label, label_weight in self.labels[neighbor].items():
                contribution = label_weight * weight
                contributions.append((label, contribution))
                denominator += contribution

        # Nothing to propagate: keep the original labels unchanged.
        if denominator == 0:
            return

        # Belonging coefficient for each candidate label.
        neighbor_labels = {}
        for label, contribution in contributions:
            neighbor_labels[label] = (
                neighbor_labels.get(label, 0) + contribution / denominator
            )

        # Determine the filtering threshold.
        if self.threshold is None:
            threshold = 1.0 / len(neighbor_labels) if neighbor_labels else 0
        else:
            threshold = self.threshold

        # Filter out weak labels.
        filtered_labels = {k: v for k, v in neighbor_labels.items() if v > threshold}

        # If no label survives filtering, keep the strongest one.
        if not filtered_labels and neighbor_labels:
            max_label = max(neighbor_labels.items(), key=lambda x: x[1])
            filtered_labels = {max_label[0]: max_label[1]}

        # Normalize the surviving label weights to sum to 1.
        if filtered_labels:
            total_weight = sum(filtered_labels.values())
            filtered_labels = {k: v / total_weight for k, v in filtered_labels.items()}

        self.labels[node] = filtered_labels

    def get_dominant_label(self, node):
        """
        Get the node's label with the highest belonging coefficient.

        Returns:
            tuple: (label, weight), or (None, 0.0) if the node has no labels.
                Ties resolve to the first such label in insertion order.

        Raises:
            KeyError: if ``node`` has no entry in ``self.labels``.
        """
        if not self.labels[node]:
            return None, 0.0
        return max(self.labels[node].items(), key=lambda x: x[1])

    def run(self, update_sequence, max_iter):
        """
        Run the label propagation algorithm.

        Labels are (re)initialized, then updated node by node in
        ``update_sequence`` order. Later nodes in a pass therefore see the
        labels already updated earlier in that pass. Iteration stops when a
        full pass changes no node's labels, or after ``max_iter`` passes.

        Args:
            update_sequence: List of node update order. Nodes not listed
                keep their initial labels.
            max_iter: Maximum number of iterations.
        Returns:
            list: List of detected overlapping communities (sets of nodes).
        """
        self.initialize_labels()
        iter_count = 0

        for _ in range(max_iter):
            old_labels = {k: v.copy() for k, v in self.labels.items()}

            # Update labels according to node importance order.
            for node in update_sequence:
                self.update_node_labels(node)

            # Converged: a full pass left every node's labels unchanged.
            if all(old_labels[node] == self.labels[node] for node in self.G.nodes()):
                print(f"Label propagation converged after {iter_count} iterations")
                break

            iter_count += 1
        else:
            # Loop exhausted without the convergence check ever passing.
            print(f"Label propagation did not converge within {max_iter} iterations")

        return self._extract_communities()

    def _extract_communities(self):
        """
        Extract overlapping communities from the label distribution.

        Each distinct label becomes one community containing every node that
        carries it. Communities are returned in first-seen label order.

        Returns:
            list: List of communities, where each community is a set of nodes.
        """
        communities = {}
        for node, labels in self.labels.items():
            for label in labels:
                communities.setdefault(label, set()).add(node)
        return list(communities.values())


