Skip to content
Snippets Groups Projects
Select Git revision
  • 02f879076fe668a57a9e43dc592acad10d88f937
  • master default protected
  • dev
  • score_test
4 results

d2_graph.py

Blame
  • d2_graph.py 7.37 KiB
    import networkx as nx
    import itertools
    from bidict import bidict
    import sys
    
    from deconvolution.dgraph.FixedDGIndex import FixedDGIndex
    from deconvolution.dgraph.VariableDGIndex import VariableDGIndex
    from deconvolution.dgraph.d_graph import Dgraph
    from deconvolution.dgraph.CliqueDGFactory import CliqueDGFactory
    from deconvolution.dgraph.LouvainDGFactory import LouvainDGFactory
    
    
    class D2Graph(nx.Graph):
        """D2Graph (read it (d-graph)²)"""
        def __init__(self, barcode_graph):
            """Set up an empty d2-graph bound to `barcode_graph`.

            Every edge of `barcode_graph` receives an integer id (its rank in
            `barcode_graph.edges()`), stored in `self.barcode_edge_idxs` under
            both orientations of the edge.
            """
            super(D2Graph, self).__init__()
            self.barcode_graph = barcode_graph
            self.index = None
            self.all_d_graphs = []
            self.d_graphs_per_node = {}
            self.node_by_idx = {}

            # Edge numbering of the original barcode graph
            self.barcode_edge_idxs = {}
            # Number of edges identical to their own reversal (self-loops)
            self.nb_uniq_edge = 0
            for edge_id, edge in enumerate(self.barcode_graph.edges()):
                mirrored = (edge[1], edge[0])
                if edge == mirrored:
                    self.nb_uniq_edge += 1
                if edge in self.barcode_edge_idxs:
                    print("Edge already present")
                # Register both orientations so lookups work either way
                self.barcode_edge_idxs[edge] = edge_id
                self.barcode_edge_idxs[mirrored] = edge_id
    
    
        """ Redefine subgraph to avoid errors type instantiation errors.
        """
        def subgraph(self, nodes):
            """Return a new D2Graph restricted to `nodes`.

            Node attributes are copied, and every edge of this graph whose two
            endpoints both belong to `nodes` is kept with its "distance".
            The result shares `barcode_graph`, `barcode_edge_idxs` and
            `node_by_idx` with this graph (same objects, not copies).
            """
            kept = frozenset(nodes)

            sub = D2Graph(self.barcode_graph)
            # Lookup tables are shared with the parent graph
            sub.barcode_edge_idxs = self.barcode_edge_idxs
            sub.node_by_idx = self.node_by_idx

            # Copy the selected nodes along with their attributes
            for name in kept:
                sub.add_node(name)
                sub.nodes[name].update(self.nodes[name])

            # Keep only the edges whose two endpoints were selected
            for left, right, attrs in self.edges(data=True):
                if left not in kept or right not in kept:
                    continue
                sub.add_edge(left, right, distance=attrs["distance"])

            return sub
    
        def clone(self):
            """Return a copy of this graph, built by `subgraph` over all its nodes."""
            every_node = list(self.nodes())
            return self.subgraph(every_node)
    
    
        def construct_from_barcodes(self, index_size=3, verbose=True, debug=False, clique_mode=None, threads=1):
            """Build the d2-graph from the barcode graph given at construction.

            Steps: generate the unit d-graphs with a factory, number them,
            index them, filter the index, then create the nodes and edges
            through `create_graph`. The mapping returned by `create_graph` is
            stored in `self.bidict_nodes`.

            Args:
                index_size: currently unused; the index is always built as
                    `VariableDGIndex(size=2)`.
                verbose: print progress messages on stdout.
                debug: forwarded to the factory's `generate_all_dgraphs`.
                clique_mode: "louvain" selects `LouvainDGFactory`; any other
                    value selects `CliqueDGFactory`.
                threads: currently unused.
            """
            # Compute all the d-graphs
            if verbose:
                print("Computing the unit d-graphs..")
            factory_cls = LouvainDGFactory if clique_mode == "louvain" else CliqueDGFactory
            dg_factory = factory_cls(self.barcode_graph)
            # Forward the caller's flag instead of forcing debug=True
            self.d_graphs_per_node = dg_factory.generate_all_dgraphs(debug=debug)
            if verbose:
                counts = sum(len(x) for x in self.d_graphs_per_node.values())
                print(f"\t {counts} computed d-graphs")
            for d_graphs in self.d_graphs_per_node.values():
                self.all_d_graphs.extend(d_graphs)

            # Number the d_graphs
            for idx, d_graph in enumerate(self.all_d_graphs):
                d_graph.idx = idx
                self.node_by_idx[idx] = d_graph

            # Index all the d-graphs
            if verbose:
                print("Compute the dmer dgraph")
                print("\tIndexing")
            # NOTE: index_size is ignored here. The disabled alternative is
            # FixedDGIndex(size=index_size).
            self.index = VariableDGIndex(size=2)
            nb_d_graphs = len(self.all_d_graphs)
            for idx, dg in enumerate(self.all_d_graphs):
                if verbose:
                    print(f"\r\t{idx+1}/{nb_d_graphs}", end='')
                self.index.add_dgraph(dg)
            if verbose:
                print()
                print("\tFilter index")
            self.index.filter_by_entry()

            # Link the d-graphs that share at least one index entry
            if verbose:
                print("Compute the graph")
            self.bidict_nodes = self.create_graph()
    
    
        def get_covering_variables(self, udg):
            """Return the frozenset of barcode-edge ids for the edges of `udg`.

            Each element of `udg.edges` is looked up in
            `self.barcode_edge_idxs`; an edge missing from that table raises
            KeyError.
            """
            return frozenset(self.barcode_edge_idxs[edge] for edge in udg.edges)
    
    
        def save(self, filename):
            """Write the graph to `filename` as plain text.

            Layout:
              * header line: "<nb d-graphs> <nb_cov_var>", where nb_cov_var is
                (len(barcode_edge_idxs) + nb_uniq_edge) / 2 truncated to int
              * one line per d-graph: its idx, then the ids of its barcode edges
              * one line per edge of this graph: "<node1> <node2> <distance>"
            """
            with open(filename, "w") as out:
                # First line nb_nodes nb_cov_var
                nb_cov_var = int((len(self.barcode_edge_idxs) + self.nb_uniq_edge) / 2)
                out.write(f"{len(self.all_d_graphs)} {nb_cov_var}\n")

                # One line of barcode-edge ids per d-graph
                for dg in self.all_d_graphs:
                    edge_ids = " ".join(str(self.barcode_edge_idxs[e]) for e in dg.edges)
                    out.write(f"{dg.idx} {edge_ids}\n")

                # One line per edge with its distance
                for left, right, attrs in self.edges(data=True):
                    out.write(f"{left} {right} {attrs['distance']}\n")
    
    
        def load(self, filename):
            """Populate this graph from the GEXF file `filename`.

            Nodes (with their attributes) and edges (with their "distance")
            are copied from the file. Each node's "udg" attribute is then
            turned back into a d-graph with `Dgraph.load`, and that d-graph is
            registered in `all_d_graphs`, `node_by_idx` and `bidict_nodes`.

            NOTE(review): `bidict_nodes` maps node -> d-graph here, whereas
            `create_graph` returns a bidict mapping d-graph -> node id;
            confirm which direction callers expect.
            """
            # Reload the graph
            loaded = nx.read_gexf(filename)
            for name, attrs in loaded.nodes(data=True):
                self.add_node(name)
                self.nodes[name].update(attrs)
            for left, right, attrs in loaded.edges(data=True):
                self.add_edge(left, right, distance=attrs["distance"])

            # Rebuild the d-graphs from their serialized "udg" attribute
            self.bidict_nodes = {}
            for name, attrs in self.nodes(data=True):
                dg = Dgraph.load(attrs["udg"], self.barcode_graph)
                self.bidict_nodes[name] = dg
                self.all_d_graphs.append(dg)
                # An idx of -1 is replaced by the integer value of the node name
                if dg.idx == -1:
                    dg.idx = int(name)
                self.node_by_idx[dg.idx] = dg
            self.bidict_nodes = bidict(self.bidict_nodes)
    
    
        def create_index_from_tuples(self, tuple_size=3, verbose=True):
            """Index the d-graphs by the `tuple_size`-combinations of their nodes.

            Combinations are drawn from each d-graph's `to_sorted_list()`.

            Returns:
                dict mapping each combination (a tuple) to the set of d-graphs
                that produced it. D-graphs with fewer than `tuple_size` nodes
                are not indexed.
            """
            dmer_index = {}
            total = len(self.all_d_graphs)

            if verbose:
                print("\tIndex d-graphs")
            for rank, dg in enumerate(self.all_d_graphs, start=1):
                if verbose:
                    sys.stdout.write(f"\r\t{rank}/{total}")
                    sys.stdout.flush()

                members = dg.to_sorted_list()
                if len(members) >= tuple_size:
                    # Generate all tuplesize-mers
                    for dmer in itertools.combinations(members, tuple_size):
                        dmer_index.setdefault(dmer, set()).add(dg)

            if verbose:
                print()

            return dmer_index
    
    
        def create_graph(self, max_distance=5):
            """Create the nodes and edges of the d2-graph from `self.index`.

            Every d-graph found in the index becomes a node named by its
            `idx` and annotated with "barcode_edges", "score" and "udg".
            Two d-graphs listed under the same index entry are linked when
            their distance is at most `max_distance`; that distance is stored
            on the edge.

            Args:
                max_distance: largest `distance_to` value for which an edge is
                    added. Defaults to 5, the threshold that used to be
                    hard-coded.

            Returns:
                bidict mapping each d-graph to its node id.
            """
            node_ids = {}

            for dmer in self.index:
                # Deduplicate the d-graphs registered under this entry
                dgs = list(set(self.index[dmer]))
                for position, dg in enumerate(dgs):
                    # First time this d-graph is met: create its node
                    if dg not in node_ids:
                        node_ids[dg] = dg.idx
                        barcode_edges = " ".join(str(self.barcode_edge_idxs[x]) for x in dg.edges)
                        self.add_node(
                            dg.idx,
                            barcode_edges=barcode_edges,
                            score=f"{dg.score}/{dg.get_optimal_score()}",
                            udg=str(dg),
                        )

                    # Compare only with the d-graphs already visited in this
                    # entry, so each pair is considered once per entry
                    for prev_dg in dgs[:position]:
                        d = dg.distance_to(prev_dg)
                        if d <= max_distance:
                            self.add_edge(node_ids[dg], node_ids[prev_dg], distance=d)

            return bidict(node_ids)
    
    
        def compute_distances(self):
            """Recompute the "distance" attribute of every edge of this graph.

            Both endpoints are resolved through `self.node_by_idx`; an edge
            whose endpoints resolve to equal d-graphs keeps its current value.
            """
            for left, right, attrs in self.edges(data=True):
                first = self.node_by_idx[left]
                second = self.node_by_idx[right]
                if first == second:
                    continue

                # Store the freshly computed distance on the edge
                attrs["distance"] = first.distance_to(second)