Commit 3474f04e authored by Yoann Dufresne
Browse files

transform d2g to inherit from networkx Graph instead of being composed of one

parent 367747d2
......@@ -5,29 +5,29 @@ from bidict import bidict
from d_graph import Dgraph, compute_all_max_d_graphs, filter_dominated, list_domination_filter
class D2Graph(object):
class D2Graph(nx.Graph):
"""D2Graph (read it (d-graph)²)"""
def __init__(self, graph):
def __init__(self, barcode_graph):
super(D2Graph, self).__init__()
self.graph = graph
self.barcode_graph = barcode_graph
# Number the edges from original graph
self.edge_idxs = {}
self.barcode_edge_idxs = {}
self.nb_uniq_edge = 0
for idx, edge in enumerate(self.graph.edges()):
for idx, edge in enumerate(self.barcode_graph.edges()):
if edge == (edge[1], edge[0]):
self.nb_uniq_edge += 1
if edge in self.edge_idxs:
if edge in self.barcode_edge_idxs:
print("Edge already present")
self.edge_idxs[edge] = idx
self.edge_idxs[(edge[1], edge[0])] = idx
self.barcode_edge_idxs[edge] = idx
self.barcode_edge_idxs[(edge[1], edge[0])] = idx
def construct_from_barcodes(self, index_size=3, verbose=True, debug=False):
# Compute all the d-graphs
if verbose:
print("Compute the unit d-graphs")
self.d_graphs_per_node = compute_all_max_d_graphs(self.graph, debug=debug)
self.d_graphs_per_node = compute_all_max_d_graphs(self.barcode_graph, debug=debug)
self.d_graphs_per_node = filter_dominated(self.d_graphs_per_node)
self.all_d_graphs = []
for d_graphs in self.d_graphs_per_node.values():
......@@ -51,17 +51,17 @@ class D2Graph(object):
self.distances = self.compute_distances()
# Create the graph
self.nx_graph, self.nodes = self.to_nx_graph()
self.bidict_nodes = self.create_graph()
def save(self, filename):
with open(filename, "w") as fp:
# First line nb_nodes nb_cov_var
fp.write(f"{len(self.all_d_graphs)} {int((len(self.edge_idxs)+self.nb_uniq_edge)/2)}\n")
fp.write(f"{len(self.all_d_graphs)} {int((len(self.barcode_edge_idxs)+self.nb_uniq_edge)/2)}\n")
# Write the edges per d_graph
for d_graph in self.all_d_graphs:
fp.write(f"{d_graph.idx} {' '.join([str(self.edge_idxs[e]) for e in d_graph.edges])}\n")
fp.write(f"{d_graph.idx} {' '.join([str(self.barcode_edge_idxs[e]) for e in d_graph.edges])}\n")
# Write the distances
for d_graph in self.all_d_graphs:
......@@ -71,24 +71,28 @@ class D2Graph(object):
def load(self, filename):
# Reload the graph
self.nx_graph = nx.read_gexf(filename)
G = nx.read_gexf(filename)
for node in G.nodes():
self.add_node(node)
for edge in G.edges():
self.add_edge(*edge)
# Extract d-graphs from nx graph
self.all_d_graphs = []
self.node_by_idx = {}
self.nodes = {}
for idx, node in enumerate(self.nx_graph.nodes()):
dg = Dgraph.load(node, self.graph)
self.nodes[node] = dg
self.bidict_nodes = {}
for idx, node in enumerate(self.nodes()):
dg = Dgraph.load(node, self.barcode_graph)
self.bidict_nodes[node] = dg
self.all_d_graphs.append(dg)
if dg.idx == -1:
dg.idx = idx
self.node_by_idx[dg.idx] = dg
self.nodes = bidict(self.nodes)
self.bidict_nodes = bidict(self.bidict_nodes)
# Extract edges and re-compute distances
self.distances = {}
for edge in self.nx_graph.edges():
for edge in self.edges():
# Get the dg pair
idx1 = int(edge[0].split(' ')[0])
idx2 = int(edge[1].split(' ')[0])
......@@ -175,10 +179,9 @@ class D2Graph(object):
return index
def to_nx_graph(self):
def create_graph(self):
# next_idx = 0
nodes = {}
G = nx.Graph()
for dmer in self.index:
for d_idx, dg in enumerate(self.index[dmer]):
......@@ -188,13 +191,13 @@ class D2Graph(object):
# next_idx += 1
# Add the node
G.add_node(nodes[dg])
self.add_node(nodes[dg])
# Add the edges
for prev_node in self.index[dmer][:d_idx]:
G.add_edge(nodes[dg], nodes[prev_node])
self.add_edge(nodes[dg], nodes[prev_node])
return G, bidict(nodes)
return bidict(nodes)
def filter_dominated_in_index(self):
......
......@@ -20,7 +20,8 @@ class Path(list):
self.nodes = [x for x in self.nodes[::-1]]
def get_score(self):
def get_score(self, d2g):
return 0
......@@ -43,7 +44,3 @@ class Unitig(Path):
......@@ -56,8 +56,6 @@ class TestD2Graph(unittest.TestCase):
self.assertEquals(awaited_index_size, len(d2.index))
d2_nx = d2.nx_graph
# Test connectivity
# Center node names
c1 = d
......@@ -95,16 +93,16 @@ class TestD2Graph(unittest.TestCase):
# Save and reload the d2 in a temporary file
with tempfile.NamedTemporaryFile() as fp:
# Save
nx.write_gexf(d2.nx_graph, fp.name)
nx.write_gexf(d2, fp.name)
# Reload
d2_reloaded = D2Graph(G)
d2_reloaded.load(fp.name)
# Test the nx graph
self.assertNotEquals(d2_reloaded.nx_graph, None)
self.assertEquals(len(d2_reloaded.nx_graph.nodes()), len(d2.nx_graph.nodes()))
self.assertEquals(len(d2_reloaded.nx_graph.edges()), len(d2.nx_graph.edges()))
self.assertNotEquals(d2_reloaded, None)
self.assertEquals(len(d2_reloaded.nodes()), len(d2.nodes()))
self.assertEquals(len(d2_reloaded.edges()), len(d2.edges()))
# Test all_d_graphs
self.assertEquals(len(d2_reloaded.all_d_graphs), len(d2.all_d_graphs))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment