Commit db2ebce9 authored by Yoann Dufresne's avatar Yoann Dufresne
Browse files

fix bugs induced by the d2g heritage simplification in all the mains

parent ba61e7d2
...@@ -7,47 +7,49 @@ from d2_path import Path, Unitig ...@@ -7,47 +7,49 @@ from d2_path import Path, Unitig
""" Remove unnecessary transitions """ Remove unnecessary transitions
""" """
def transitive_reduction(d2g): def transitive_reduction(d2g):
nxg = d2g.nx_graph edges = list(d2g.edges(data=True))
edges = list(nxg.edges())
# Remove self edges # Remove self edges
for edge in edges: for edge in edges:
if edge[0] == edge[1]: if edge[0] == edge[1]:
nxg.remove_edge(*edge) d2g.remove_edge(*edge)
edges = list(d2g.edges(data=True))
nb_removed = 0 nb_removed = 0
for edge in edges: while len(edges) > 0:
dg1_name, dg2_name, data = edges.pop(0)
# Extract dgs # Extract dgs
dg1 = d2g.nodes[edge[0]] dg1 = d2g.node_by_name[dg1_name]
dg2 = d2g.nodes[edge[1]] dg2 = d2g.node_by_name[dg2_name]
# Extract common neighbors # Extract common neighbors
nei1 = frozenset(nxg.neighbors(d2g.nodes.inverse[dg1])) nei1 = frozenset(d2g.neighbors(dg1_name))
nei2 = frozenset(nxg.neighbors(d2g.nodes.inverse[dg2])) nei2 = frozenset(d2g.neighbors(dg2_name))
common = nei1.intersection(nei2) common = nei1.intersection(nei2)
# Look for all the common neighbors, if edge must be remove or not # Look for all the common neighbors, if edge must be remove or not
current_dist = d2g.distances[dg1.idx][dg2.idx] current_dist = d2g[dg1_name][dg2_name]["distance"]
for node in common: for node in common:
com_dg = d2g.nodes[node] com_dg = d2g.node_by_name[node]
extern_dist = d2g.distances[dg1.idx][com_dg.idx] + d2g.distances[com_dg.idx][dg2.idx] extern_dist = d2g[dg1_name][node]["distance"] + d2g[node][dg2_name]["distance"]
# If better path, remove the edge # If better path, remove the edge
if extern_dist <= current_dist: if extern_dist <= current_dist:
# Remove from graph # Remove from graph
nxg.remove_edge(*edge) d2g.remove_edge(dg1_name, dg2_name)
# Remove in distances # Remove from edge list
del d2g.distances[dg1.idx][dg2.idx] edge = (dg1_name, dg2_name, data)
if len(d2g.distances[dg1.idx]) == 0: if edge in edges:
del d2g.distances[dg1.idx] edges.remove(edge)
del d2g.distances[dg2.idx][dg1.idx]
if len(d2g.distances[dg2.idx]) == 0: # Mark as removed
del d2g.distances[dg2.idx]
nb_removed += 1 nb_removed += 1
break break
print(f"{nb_removed} edge removed") print(f"{nb_removed} edge removed")
print(f"{len(nxg.edges())} remaining") print(f"{len(d2g.edges())} remaining")
""" For each node of the d2 graph, construct a node in the reduced graph. """ For each node of the d2 graph, construct a node in the reduced graph.
......
...@@ -44,9 +44,11 @@ class D2Graph(nx.Graph): ...@@ -44,9 +44,11 @@ class D2Graph(nx.Graph):
# Name the d-graphs # Name the d-graphs
# Number the d_graphs # Number the d_graphs
self.node_by_idx = {} self.node_by_idx = {}
self.node_by_name = {}
for idx, d_graph in enumerate(self.all_d_graphs): for idx, d_graph in enumerate(self.all_d_graphs):
d_graph.idx = idx d_graph.idx = idx
self.node_by_idx[idx] = d_graph self.node_by_idx[idx] = d_graph
self.node_by_name[str(d_graph)] = d_graph
# Index all the d-graphes # Index all the d-graphes
if verbose: if verbose:
...@@ -55,11 +57,10 @@ class D2Graph(nx.Graph): ...@@ -55,11 +57,10 @@ class D2Graph(nx.Graph):
self.filter_dominated_in_index() self.filter_dominated_in_index()
# Compute node distances for pair of dgraphs that share at least 1 dmer. # Compute node distances for pair of dgraphs that share at least 1 dmer.
if verbose: if verbose:
print("Compute a subset of distances") print("Compute the graph")
self.distances = self.compute_distances()
# Create the graph # Create the graph
self.bidict_nodes = self.create_graph() self.bidict_nodes = self.create_graph()
self.compute_distances()
def save(self, filename): def save(self, filename):
...@@ -72,22 +73,24 @@ class D2Graph(nx.Graph): ...@@ -72,22 +73,24 @@ class D2Graph(nx.Graph):
fp.write(f"{d_graph.idx} {' '.join([str(self.barcode_edge_idxs[e]) for e in d_graph.edges])}\n") fp.write(f"{d_graph.idx} {' '.join([str(self.barcode_edge_idxs[e]) for e in d_graph.edges])}\n")
# Write the distances # Write the distances
for d_graph in self.all_d_graphs: for x, y, data in self.edges(data=True):
for neighbor_idx, dist in self.distances[d_graph.idx].items(): dg1 = self.node_by_name[x]
fp.write(f"{d_graph.idx} {neighbor_idx} {dist}\n") dg2 = self.node_by_name[y]
fp.write(f"{dg1.idx} {dg2.idx} {data['distance']}\n")
def load(self, filename): def load(self, filename):
# Reload the graph # Reload the graph
G = nx.read_gexf(filename) G = nx.read_gexf(filename)
for node in G.nodes(): for node, attrs in G.nodes(data=True):
self.add_node(node) self.add_node(node, attr_dict=attrs)
for edge in G.edges(): for edge in G.edges(data=True):
self.add_edge(*edge) self.add_edge(edge[0], edge[1], distance=edge[2]["distance"])
# Extract d-graphs from nx graph # Extract d-graphs from nx graph
self.all_d_graphs = [] self.all_d_graphs = []
self.node_by_idx = {} self.node_by_idx = {}
self.node_by_name = {}
self.bidict_nodes = {} self.bidict_nodes = {}
for idx, node in enumerate(self.nodes()): for idx, node in enumerate(self.nodes()):
dg = Dgraph.load(node, self.barcode_graph) dg = Dgraph.load(node, self.barcode_graph)
...@@ -96,26 +99,9 @@ class D2Graph(nx.Graph): ...@@ -96,26 +99,9 @@ class D2Graph(nx.Graph):
if dg.idx == -1: if dg.idx == -1:
dg.idx = idx dg.idx = idx
self.node_by_idx[dg.idx] = dg self.node_by_idx[dg.idx] = dg
self.node_by_name[node] = dg
self.bidict_nodes = bidict(self.bidict_nodes) self.bidict_nodes = bidict(self.bidict_nodes)
# Extract edges and re-compute distances
self.distances = {}
for edge in self.edges():
# Get the dg pair
idx1 = int(edge[0].split(' ')[0])
idx2 = int(edge[1].split(' ')[0])
dg1 = self.node_by_idx[idx1]
dg2 = self.node_by_idx[idx2]
# Compute and save the distance
dist = dg1.distance_to(dg2)
if not idx1 in self.distances:
self.distances[idx1] = {}
self.distances[idx1][idx2] = dist
if not idx2 in self.distances:
self.distances[idx2] = {}
self.distances[idx2][idx1] = dist
def create_index_from_tuples(self, tuple_size=3): def create_index_from_tuples(self, tuple_size=3):
index = {} index = {}
...@@ -138,22 +124,16 @@ class D2Graph(nx.Graph): ...@@ -138,22 +124,16 @@ class D2Graph(nx.Graph):
def compute_distances(self): def compute_distances(self):
distances = {dg.idx:{} for dg in self.all_d_graphs} for x, y, data in self.edges(data=True):
dg1 = self.node_by_name[x]
for dmer, dgraphs in self.index.items(): dg2 = self.node_by_name[y]
for idx1, dg1 in enumerate(dgraphs): if dg1 == dg2:
for idx2 in range(idx1+1, len(dgraphs)): continue
dg2 = dgraphs[idx2]
if dg1 == dg2:
continue
# Distance computing and adding in the dist dicts
d = dg1.distance_to(dg2)
distances[dg1.idx][dg2.idx] = d
distances[dg2.idx][dg1.idx] = d
return distances
# Distance computing and adding in the dist dicts
d = dg1.distance_to(dg2)
data["distance"] = d
def create_index_ordered(self): def create_index_ordered(self):
index = {} index = {}
...@@ -188,7 +168,6 @@ class D2Graph(nx.Graph): ...@@ -188,7 +168,6 @@ class D2Graph(nx.Graph):
def create_graph(self): def create_graph(self):
# next_idx = 0
nodes = {} nodes = {}
for dmer in self.index: for dmer in self.index:
...@@ -196,14 +175,14 @@ class D2Graph(nx.Graph): ...@@ -196,14 +175,14 @@ class D2Graph(nx.Graph):
# Create a node name # Create a node name
if not dg in nodes: if not dg in nodes:
nodes[dg] = str(dg) nodes[dg] = str(dg)
# next_idx += 1
# Add the node # Add the node
self.add_node(nodes[dg]) self.add_node(nodes[dg])
# Add the edges # Add the edges
for prev_node in self.index[dmer][:d_idx]: for prev_node in self.index[dmer][:d_idx]:
self.add_edge(nodes[dg], nodes[prev_node]) if prev_node != dg:
self.add_edge(nodes[dg], nodes[prev_node])
return bidict(nodes) return bidict(nodes)
......
import networkx as nx
G = nx.path_graph(3)
print(G.edges(data=True))
edge_data = G[0][1]
edge_data["test"] = 2
# nx.set_edge_attributes(G, 0, "test")
# edge = list(G.edges(data=True))[0]
print(G.edges(data=True))
...@@ -68,15 +68,22 @@ class TestD2Graph(unittest.TestCase): ...@@ -68,15 +68,22 @@ class TestD2Graph(unittest.TestCase):
c3:{c1:4, c2:2} c3:{c1:4, c2:2}
} }
# distance tests for x, y, data in d2.edges(data=True):
for idx1, neighbors in d2.distances.items(): dg1 = d2.node_by_idx[int(x.split(" ")[0])]
dg1 = d2.node_by_idx[idx1] dg2 = d2.node_by_idx[int(y.split(" ")[0])]
for idx2, dist in neighbors.items(): awaited_dist = awaited_distances[dg1.center][dg2.center]
dg2 = d2.node_by_idx[idx2] self.assertEquals(data["distance"], awaited_dist)
# # distance tests
# for idx1, neighbors in d2.distances.items():
# dg1 = d2.node_by_idx[idx1]
# for idx2, dist in neighbors.items():
# dg2 = d2.node_by_idx[idx2]
awaited_dist = awaited_distances[dg1.center][dg2.center] # awaited_dist = awaited_distances[dg1.center][dg2.center]
self.assertEquals(dist, awaited_dist) # self.assertEquals(dist, awaited_dist)
def test_reloading(self): def test_reloading(self):
...@@ -104,6 +111,8 @@ class TestD2Graph(unittest.TestCase): ...@@ -104,6 +111,8 @@ class TestD2Graph(unittest.TestCase):
self.assertEquals(len(d2_reloaded.nodes()), len(d2.nodes())) self.assertEquals(len(d2_reloaded.nodes()), len(d2.nodes()))
self.assertEquals(len(d2_reloaded.edges()), len(d2.edges())) self.assertEquals(len(d2_reloaded.edges()), len(d2.edges()))
# TODO: Verify distances
# Test all_d_graphs # Test all_d_graphs
self.assertEquals(len(d2_reloaded.all_d_graphs), len(d2.all_d_graphs)) self.assertEquals(len(d2_reloaded.all_d_graphs), len(d2.all_d_graphs))
# Verify dg idxs # Verify dg idxs
...@@ -111,15 +120,6 @@ class TestD2Graph(unittest.TestCase): ...@@ -111,15 +120,6 @@ class TestD2Graph(unittest.TestCase):
for dg in d2.all_d_graphs: for dg in d2.all_d_graphs:
self.assertTrue(dg.idx in reloaded_idxs) self.assertTrue(dg.idx in reloaded_idxs)
# Verify distances
self.assertEquals(len(d2.distances), len(d2_reloaded.distances))
for idx1 in d2.distances:
self.assertTrue(idx1 in d2_reloaded.distances)
self.assertEquals(len(d2.distances[idx1]), len(d2_reloaded.distances[idx1]))
for idx2 in d2.distances[idx1]:
self.assertTrue(idx2 in d2_reloaded.distances[idx1])
self.assertEquals(d2.distances[idx1][idx2], d2_reloaded.distances[idx1][idx2])
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment