Skip to content
Snippets Groups Projects
d2_graph.py 8.85 KiB
import networkx as nx
import itertools
from bidict import bidict
import sys

# from deconvolution.dgraph.FixedDGIndex import FixedDGIndex
from deconvolution.dgraph.VariableDGIndex import VariableDGIndex
from deconvolution.dgraph.d_graph import Dgraph
from deconvolution.dgraph.CliqueDGFactory import CliqueDGFactory
from deconvolution.dgraph.LouvainDGFactory import LouvainDGFactory


class D2Graph(nx.Graph):
    """D2Graph (read it (d-graph)²)"""
    def __init__(self, barcode_graph, debug=False, debug_path='.'):
        """Create an empty d²-graph attached to ``barcode_graph``.

        Every edge of ``barcode_graph`` receives an integer index (its rank in
        ``barcode_graph.edges()``), stored in ``barcode_edge_idxs`` under both
        orientations ``(u, v)`` and ``(v, u)``.

        :param barcode_graph: graph whose edges are numbered here.
        :param debug: debug flag, forwarded to CliqueDGFactory by
            construct_from_barcodes().
        :param debug_path: debug output location, forwarded alongside ``debug``.
        """
        super().__init__()
        self.barcode_graph = barcode_graph
        self.debug = debug
        self.debug_path = debug_path

        # d-graph containers, filled by construct_from_barcodes() or load()
        self.all_d_graphs = []
        self.d_graphs_per_node = {}
        self.node_by_idx = {}
        self.index = None
        # Memo used by get_covering_variables(): lcp -> barcode edge indices
        self.variables_per_lcp = {}

        # Number the barcode edges. A self-loop yields a single dict key (both
        # orientations are the same tuple), so nb_uniq_edge counts self-loops;
        # save() uses it to recover the number of distinct edges.
        self.barcode_edge_idxs = {}
        self.nb_uniq_edge = 0
        for edge_number, (u, v) in enumerate(self.barcode_graph.edges()):
            forward, backward = (u, v), (v, u)
            if forward == backward:
                self.nb_uniq_edge += 1
            if forward in self.barcode_edge_idxs:
                print("Edge already present")
            self.barcode_edge_idxs[forward] = edge_number
            self.barcode_edge_idxs[backward] = edge_number

    # Redefine subgraph to avoid type instantiation errors.
    def subgraph(self, nodes):
        """Return a new D2Graph induced by ``nodes``.

        Redefined so the result is a concrete D2Graph: the constructor requires
        ``barcode_graph``, which a generic no-argument instantiation cannot
        supply.

        The result:
          * keeps this graph's ``debug`` / ``debug_path`` settings;
          * shares (does not copy) ``barcode_edge_idxs`` and ``node_by_idx``;
          * copies the attributes of every kept node and of every edge whose
            two endpoints are both kept;
          * starts with empty ``all_d_graphs`` / ``d_graphs_per_node``, as set
            by the constructor.

        :param nodes: iterable of node ids, each of which must exist in this graph
            (a missing id raises KeyError).
        :return: the induced D2Graph.
        """
        nodes = frozenset(nodes)

        # Propagate the debug configuration; it was previously dropped, so
        # clones silently lost it.
        G = D2Graph(self.barcode_graph, debug=self.debug, debug_path=self.debug_path)
        # Share the numbering instead of relying on the freshly recomputed one
        G.barcode_edge_idxs = self.barcode_edge_idxs

        # Sub-nodes with all their attributes
        for node in nodes:
            G.add_node(node)
            G.nodes[node].update(self.nodes[node])

        # Induced edges with all their attributes (including "distance")
        for node1, node2, data in self.edges(data=True):
            if node1 in nodes and node2 in nodes:
                G.add_edge(node1, node2)
                G.edges[node1, node2].update(data)

        G.node_by_idx = self.node_by_idx

        return G

    def clone(self):
        """Return a copy of this graph: the subgraph induced by every node."""
        every_node = self.nodes
        return self.subgraph(every_node)


    def construct_from_barcodes(self, neighbor_threshold=0.25, min_size_clique=4, verbose=False, clique_mode=None, threads=1):
        """Generate the d-graphs of the barcode graph and build the d²-graph.

        :param neighbor_threshold: passed to create_graph_from_node_neighborhoods().
        :param min_size_clique: minimum clique size for the clique factory.
        :param verbose: print progress messages.
        :param clique_mode: "louvain" selects LouvainDGFactory; any other value
            selects CliqueDGFactory.
        :param threads: forwarded to the factory's generate_all_dgraphs().

        The generated d-graphs are appended to ``all_d_graphs``. Each one is
        numbered by its position there (``idx``) and registered in
        ``node_by_idx``. The dg -> node-id bidict is stored in ``bidict_nodes``.
        """
        if verbose:
            print("Computing the unit d-graphs..")

        if clique_mode == "louvain":
            factory = LouvainDGFactory(self.barcode_graph)
        else:
            factory = CliqueDGFactory(
                self.barcode_graph,
                min_size_clique=min_size_clique,
                debug=self.debug,
                debug_path=self.debug_path,
            )
        self.d_graphs_per_node = factory.generate_all_dgraphs(threads=threads, verbose=verbose)

        if verbose:
            total = sum(map(len, self.d_graphs_per_node.values()))
            print(f"\t {total} computed d-graphs")

        for group in self.d_graphs_per_node.values():
            self.all_d_graphs.extend(group)

        # A d-graph's id is its position in all_d_graphs
        for position, d_graph in enumerate(self.all_d_graphs):
            d_graph.idx = position
            self.node_by_idx[position] = d_graph

        if verbose:
            print("Compute the graph")
        self.bidict_nodes = self.create_graph_from_node_neighborhoods(neighbor_threshold)


    def get_lcp(self, obj):
        """Resolve ``obj`` to a d-graph (lcp).

        Strings are converted with ``int()`` and, like ints, used as a key of
        ``node_by_idx``. Any other object is returned unchanged (assumed to
        already be a d-graph).

        :raises ValueError: if a string is not an integer literal.
        :raises KeyError: if the index is not in ``node_by_idx``.
        """
        if isinstance(obj, str):
            obj = int(obj)
        # bool is an int subclass; never treat True/False as index 1/0.
        if isinstance(obj, int) and not isinstance(obj, bool):
            return self.node_by_idx[obj]
        return obj

    def get_covering_variables(self, obj):
        """Return the barcode-edge indices covered by a d-graph.

        :param obj: anything accepted by get_lcp() (d-graph, int or str index).
        :return: list of indices from ``barcode_edge_idxs``, one per d-graph
            edge. The list is memoised per lcp in ``variables_per_lcp``, and the
            same list object is returned on later calls.
        """
        lcp = self.get_lcp(obj)
        cached = self.variables_per_lcp.get(lcp)
        if cached is None:
            cached = [self.barcode_edge_idxs[edge] for edge in lcp.edges]
            self.variables_per_lcp[lcp] = cached
        return cached


    def save(self, filename):
        """Write the d²-graph to ``filename`` as plain text.

        Layout:
          * first line: ``<number of d-graphs> <number of covering variables>``;
          * one line per d-graph: its ``idx`` followed by the indices of the
            barcode edges it covers;
          * one line per d²-graph edge: ``<node> <node> <distance>``.
        """
        # barcode_edge_idxs holds two keys per ordinary edge and one per
        # self-loop (nb_uniq_edge counts self-loops). Adding the self-loops back
        # and halving gives the number of barcode edges. Use exact integer
        # division rather than a float round-trip.
        nb_variables = (len(self.barcode_edge_idxs) + self.nb_uniq_edge) // 2

        with open(filename, "w", encoding="utf-8") as fp:
            fp.write(f"{len(self.all_d_graphs)} {nb_variables}\n")

            # Covering variables of each d-graph
            for d_graph in self.all_d_graphs:
                covered = " ".join(str(self.barcode_edge_idxs[e]) for e in d_graph.edges)
                fp.write(f"{d_graph.idx} {covered}\n")

            # Edge distances
            for x, y, data in self.edges(data=True):
                fp.write(f"{x} {y} {data['distance']}\n")


    def load(self, filename):
        """Fill this graph from a GEXF file and rebuild its d-graphs.

        Node attributes are copied verbatim; edges keep only ``distance``. Each
        node's ``udg`` attribute is parsed with ``Dgraph.load``. Each resulting
        d-graph:
          * is appended to ``all_d_graphs``;
          * is registered in ``node_by_idx`` under its ``idx`` (the node id
            converted with ``int()`` when the parsed idx is -1);
          * is mapped from its node id in the ``bidict_nodes`` bidict.
        """
        source = nx.read_gexf(filename)
        for node_id, attributes in source.nodes(data=True):
            self.add_node(node_id)
            self.nodes[node_id].update(attributes)
        for u, v, attributes in source.edges(data=True):
            self.add_edge(u, v, distance=attributes["distance"])

        # Rebuild the d-graph objects from their serialised "udg" attribute
        self.bidict_nodes = {}
        for node_id, attributes in self.nodes(data=True):
            dg = Dgraph.load(attributes["udg"], self.barcode_graph)
            self.bidict_nodes[node_id] = dg
            self.all_d_graphs.append(dg)
            if dg.idx == -1:
                # No idx stored: fall back on the node id
                dg.idx = int(node_id)
            self.node_by_idx[dg.idx] = dg
        self.bidict_nodes = bidict(self.bidict_nodes)


    def create_index_from_tuples(self, tuple_size=3, verbose=True):
        """Index d-graphs by every ``tuple_size``-combination of their nodes.

        :param tuple_size: length of the node tuples used as keys.
        :param verbose: print a progress counter on stdout.
        :return: dict mapping each tuple (taken from ``dg.to_sorted_list()``) to
            the set of d-graphs containing it. D-graphs with fewer than
            ``tuple_size`` nodes are skipped.
        """
        index = {}
        total = len(self.all_d_graphs)

        if verbose:
            print("\tIndex d-graphs")
        for rank, dg in enumerate(self.all_d_graphs, start=1):
            if verbose:
                sys.stdout.write(f"\r\t{rank}/{total}")
                sys.stdout.flush()

            sorted_nodes = dg.to_sorted_list()
            if len(sorted_nodes) < tuple_size:
                continue

            for key in itertools.combinations(sorted_nodes, tuple_size):
                index.setdefault(key, set()).add(dg)

        if verbose:
            print()

        return index


    def create_graph_from_node_neighborhoods(self, neighborhood_threshold=0.25):
        """Build the d²-graph from d-graph neighbourhoods.

        Nodes: one per d-graph of ``all_d_graphs``, with id ``dg.idx`` and the
        attributes ``barcode_edges``, ``score``, ``udg`` and
        ``central_node_barcode`` (the prefix of ``str(dg)`` up to the first ']').

        Edges: for every d-graph ``dg`` and every node of ``dg`` other than its
        center, each d-graph listed in ``d_graphs_per_node[frozenset({node})]``
        is compared with ``dg``. The two are linked, with attribute
        ``distance``, when distance / (len(dg.nodes) + len(other.nodes)) is at
        most ``neighborhood_threshold``.

        Nodes left without any neighbour are removed from the graph. They stay
        in ``all_d_graphs``.

        :return: bidict d-graph -> node id for the nodes that were kept.
        """
        node_of = {}

        # One node per d-graph
        for dg in self.all_d_graphs:
            node_id = dg.idx
            node_of[dg] = node_id
            self.add_node(node_id)
            attrs = self.nodes[node_id]
            attrs["barcode_edges"] = " ".join(str(self.barcode_edge_idxs[e]) for e in dg.edges)
            attrs["score"] = f"{dg.score}/{dg.get_optimal_score()}"
            label = str(dg)
            attrs["udg"] = label
            attrs["central_node_barcode"] = label.split(']')[0] + ']'

        # Link d-graphs centered on the (non-center) nodes of each d-graph
        for dg in self.all_d_graphs:
            source = node_of[dg]
            for member in dg.to_node_set():
                if member == dg.center:
                    continue
                candidates = self.d_graphs_per_node.get(frozenset({member}))
                if candidates is None:
                    continue
                for other in candidates:
                    distance = dg.distance_to(other)
                    ratio = distance / (len(dg.nodes) + len(other.nodes))
                    if ratio <= neighborhood_threshold:
                        self.add_edge(source, node_of[other], distance=distance)

        # Drop isolated nodes (removing one never changes another's degree)
        isolated = [dg for dg, node_id in node_of.items() if self.degree(node_id) == 0]
        for dg in isolated:
            self.remove_node(node_of[dg])
            del node_of[dg]

        return bidict(node_of)

    def create_graph_from_index(self):
        """Build the d²-graph from ``self.index``.

        ``self.index`` must already be set. It is presumably the mapping
        returned by create_index_from_tuples(): tuple -> collection of d-graphs
        (TODO confirm with the caller). The constructor leaves it at None.

        Every d-graph found in the index becomes a node, with id ``dg.idx`` and
        the attributes ``barcode_edges``, ``score`` and ``udg``. Two d-graphs
        sharing an index key are linked, with attribute ``distance``, when
        their distance is at most half the node count of the smaller one.

        :return: bidict d-graph -> node id.
        """
        node_of = {}

        for key in self.index:
            members = list(set(self.index[key]))
            for position, dg in enumerate(members):
                if dg not in node_of:
                    node_id = dg.idx
                    node_of[dg] = node_id
                    self.add_node(node_id)
                    attrs = self.nodes[node_id]
                    attrs["barcode_edges"] = " ".join(str(self.barcode_edge_idxs[e]) for e in dg.edges)
                    attrs["score"] = f"{dg.score}/{dg.get_optimal_score()}"
                    attrs["udg"] = str(dg)

                # Compare with every d-graph listed before it for this key
                for earlier in members[:position]:
                    d = dg.distance_to(earlier)
                    if d <= min(len(dg.node_set), len(earlier.node_set)) / 2:
                        self.add_edge(node_of[dg], node_of[earlier], distance=d)

        return bidict(node_of)


    def compute_distances(self):
        """Recompute the ``distance`` attribute of every d²-graph edge.

        Endpoints are resolved through get_lcp(), so both integer node ids and
        string node ids work. String ids are what load() sees: it converts them
        with ``int()``, whereas ``node_by_idx`` is keyed by the integer
        ``dg.idx``. A direct ``node_by_idx[x]`` lookup would raise KeyError for
        such ids.

        Edges whose two endpoints resolve to the same d-graph are left
        untouched.
        """
        for x, y, data in self.edges(data=True):
            dg1 = self.get_lcp(x)
            dg2 = self.get_lcp(y)
            if dg1 == dg2:
                continue

            # Only the attribute value changes; the edge set is not mutated
            # during iteration.
            data["distance"] = dg1.distance_to(dg2)