Skip to content
Snippets Groups Projects
CliqueDGFactory.py 4.96 KiB
import networkx as nx
from collections import Counter

from deconvolution.dgraph.AbstractDGFactory import AbstractDGFactory
from deconvolution.dgraph.d_graph import Dgraph
from deconvolution.dgraph import AbstractDGIndex


class CliqueDGFactory(AbstractDGFactory):
    """Build Dgraph objects by pairing maximal cliques of a subgraph.

    For a central node, the maximal cliques of ``subgraph`` with at least
    ``min_size_clique`` nodes are collected. Every pair of cliques that touch
    each other (share a node or are linked by an edge) is scored with a
    divergence. A maximum weight matching over the resulting clique graph
    then selects the pairs used as the two halves of each Dgraph.
    """

    def __init__(self, graph, min_size_clique=4, debug=False, debug_path="."):
        """
        :param graph: graph forwarded to ``AbstractDGFactory``.
        :param min_size_clique: minimum number of nodes a clique must have to
            be considered.
        :param debug: when truthy, per-node clique graphs (gexf) and matching
            results are written into ``<debug_path>/mwm``.
        :param debug_path: root directory for debug output.
        """
        super(CliqueDGFactory, self).__init__(graph, debug=debug)
        self.min_size = min_size_clique
        self.debug_path = debug_path

        if debug:
            import os
            import shutil

            # (Re)create an empty mwm debug dir. os.rmdir only removes empty
            # directories, so output left by a previous run would make it
            # raise; rmtree clears it whatever it contains.
            self.mwm_dir = f"{self.debug_path}/mwm"
            if os.path.isdir(self.mwm_dir):
                shutil.rmtree(self.mwm_dir)
            os.makedirs(self.mwm_dir)

    def generate_by_node(self, central_node, subgraph):
        """Compute the Dgraphs of ``central_node`` from the cliques of ``subgraph``.

        :param central_node: node passed to every created ``Dgraph``; its
            string form (with '/' replaced by '-') names the debug files.
        :param subgraph: networkx graph whose maximal cliques are paired
            (presumably the neighborhood of ``central_node`` — confirm
            against the caller).
        :return: set of ``Dgraph``, one per clique pair selected by the
            maximum weight matching.
        """
        node_d_graphs = set()
        node_neighbors = {node: list(subgraph[node]) for node in subgraph.nodes}

        # --- Clique computation and indexing ---
        cliques = []                    # clique index -> frozenset of nodes
        clique_neighbors_multiset = []  # clique index -> Counter of neighbor occurrences
        clique_neighbors_set = []       # clique index -> frozenset of those neighbors
        clq_per_node = {node: [] for node in subgraph.nodes}  # node -> indices of its cliques
        for clique in nx.find_cliques(subgraph):
            if len(clique) < self.min_size:
                continue
            idx = len(cliques)
            cliques.append(frozenset(clique))

            # Count, for every node, how many members of the clique it is
            # adjacent to; this lets the divergence be computed without
            # re-walking the adjacency lists for each pair.
            neighbors = []
            for node in clique:
                clq_per_node[node].append(idx)
                neighbors.extend(node_neighbors[node])
            ms = Counter(neighbors)
            clique_neighbors_multiset.append(ms)
            clique_neighbors_set.append(frozenset(ms))

        # Readable clique names are only needed for debug output.
        clique_names = []
        if self.debug:
            for clique in cliques:
                names = sorted(str(n) for n in clique)
                clique_names.append(f"[{','.join(names)}]")

        def clique_divergence(c1_idx, c1, c2):
            """Return |awaited - observed| links between cliques c1 and c2.

            Observed links are the shared nodes plus, for each node of c2
            adjacent to c1, the number of c1 members it is adjacent to.
            Awaited links are those of a complete graph of size
            max(|c1|, |c2|).
            """
            observed_link = len(c1 & c2)  # Intersections of the nodes are glued links
            neighbor_intersection = clique_neighbors_set[c1_idx] & c2
            neighbors_multiset = clique_neighbors_multiset[c1_idx]
            for x in neighbor_intersection:
                observed_link += neighbors_multiset[x]  # Sum the links between the cliques

            d_approx = max(len(c1), len(c2))
            awaited = d_approx * (d_approx - 1) / 2

            return abs(awaited - observed_link)

        def enumerate_clique_pair():
            """Yield (idx1, idx2, divergence) once per touching clique pair, idx1 < idx2."""
            for clq_idx, clq in enumerate(cliques):
                # The same partner is reachable through many node/neighbor
                # combinations; score each pair only once.
                paired = set()
                for node in clq:
                    for nei in subgraph[node]:
                        for nei_clq in clq_per_node[nei]:
                            if nei_clq > clq_idx and nei_clq not in paired:
                                paired.add(nei_clq)
                                div_clq = clique_divergence(clq_idx, clq, cliques[nei_clq])
                                yield clq_idx, nei_clq, div_clq

        # --- Clique graph for the max weight matching ---
        clq_G = nx.Graph()
        clq_G.add_nodes_from(range(len(cliques)))

        max_div = 0
        for idx1, idx2, div in enumerate_clique_pair():
            if div > max_div:
                max_div = div
            clq_G.add_edge(idx1, idx2, weight=div)

        # Invert divergences into weights: the least divergent pair gets the
        # highest weight, the most divergent one gets 0.
        for _, _, data in clq_G.edges(data=True):
            data['weight'] = max_div - data['weight']

        # --- Dgraph computation from the max weight matching ---
        mwm = nx.algorithms.max_weight_matching(clq_G)
        mwm_results = []
        for idx1, idx2 in mwm:
            d_graph = Dgraph(central_node)
            d_graph.put_halves(list(cliques[idx1]), list(cliques[idx2]), subgraph)
            node_d_graphs.add(d_graph)

            if self.debug:
                mwm_results.append(" <-> ".join(sorted([clique_names[idx1], clique_names[idx2]])))

        if self.debug and cliques:
            # str() so non-string node identifiers can also name the files.
            file_prefix = f"{self.mwm_dir}/{str(central_node).replace('/', '-')}"
            name_mapping = {idx: clique_names[idx] for idx in clq_G.nodes}
            clq_G = nx.relabel_nodes(clq_G, name_mapping)
            nx.write_gexf(clq_G, f"{file_prefix}.gexf")
            with open(f"{file_prefix}_matching.txt", "w", encoding="utf-8") as matching:
                matching.writelines(f"{result}\n" for result in mwm_results)

        return node_d_graphs