Commit db2ebce9 authored by Yoann Dufresne's avatar Yoann Dufresne
Browse files

fix bugs induced by the d2g heritage simplification in all the mains

parent ba61e7d2
......@@ -7,47 +7,49 @@ from d2_path import Path, Unitig
""" Remove unnecessary transitions
"""
def transitive_reduction(d2g):
nxg = d2g.nx_graph
edges = list(nxg.edges())
edges = list(d2g.edges(data=True))
# Remove self edges
for edge in edges:
if edge[0] == edge[1]:
nxg.remove_edge(*edge)
d2g.remove_edge(*edge)
edges = list(d2g.edges(data=True))
nb_removed = 0
for edge in edges:
while len(edges) > 0:
dg1_name, dg2_name, data = edges.pop(0)
# Extract dgs
dg1 = d2g.nodes[edge[0]]
dg2 = d2g.nodes[edge[1]]
dg1 = d2g.node_by_name[dg1_name]
dg2 = d2g.node_by_name[dg2_name]
# Extract common neighbors
nei1 = frozenset(nxg.neighbors(d2g.nodes.inverse[dg1]))
nei2 = frozenset(nxg.neighbors(d2g.nodes.inverse[dg2]))
nei1 = frozenset(d2g.neighbors(dg1_name))
nei2 = frozenset(d2g.neighbors(dg2_name))
common = nei1.intersection(nei2)
# Look for all the common neighbors, if edge must be remove or not
current_dist = d2g.distances[dg1.idx][dg2.idx]
current_dist = d2g[dg1_name][dg2_name]["distance"]
for node in common:
com_dg = d2g.nodes[node]
extern_dist = d2g.distances[dg1.idx][com_dg.idx] + d2g.distances[com_dg.idx][dg2.idx]
com_dg = d2g.node_by_name[node]
extern_dist = d2g[dg1_name][node]["distance"] + d2g[node][dg2_name]["distance"]
# If better path, remove the edge
if extern_dist <= current_dist:
# Remove from graph
nxg.remove_edge(*edge)
# Remove in distances
del d2g.distances[dg1.idx][dg2.idx]
if len(d2g.distances[dg1.idx]) == 0:
del d2g.distances[dg1.idx]
del d2g.distances[dg2.idx][dg1.idx]
if len(d2g.distances[dg2.idx]) == 0:
del d2g.distances[dg2.idx]
d2g.remove_edge(dg1_name, dg2_name)
# Remove from edge list
edge = (dg1_name, dg2_name, data)
if edge in edges:
edges.remove(edge)
# Mark as removed
nb_removed += 1
break
print(f"{nb_removed} edge removed")
print(f"{len(nxg.edges())} remaining")
print(f"{len(d2g.edges())} remaining")
""" For each node of the d2 graph, construct a node in the reduced graph.
......
......@@ -44,9 +44,11 @@ class D2Graph(nx.Graph):
# Name the d-graphs
# Number the d_graphs
self.node_by_idx = {}
self.node_by_name = {}
for idx, d_graph in enumerate(self.all_d_graphs):
d_graph.idx = idx
self.node_by_idx[idx] = d_graph
self.node_by_name[str(d_graph)] = d_graph
# Index all the d-graphes
if verbose:
......@@ -55,11 +57,10 @@ class D2Graph(nx.Graph):
self.filter_dominated_in_index()
# Compute node distances for pair of dgraphs that share at least 1 dmer.
if verbose:
print("Compute a subset of distances")
self.distances = self.compute_distances()
print("Compute the graph")
# Create the graph
self.bidict_nodes = self.create_graph()
self.compute_distances()
def save(self, filename):
......@@ -72,22 +73,24 @@ class D2Graph(nx.Graph):
fp.write(f"{d_graph.idx} {' '.join([str(self.barcode_edge_idxs[e]) for e in d_graph.edges])}\n")
# Write the distances
for d_graph in self.all_d_graphs:
for neighbor_idx, dist in self.distances[d_graph.idx].items():
fp.write(f"{d_graph.idx} {neighbor_idx} {dist}\n")
for x, y, data in self.edges(data=True):
dg1 = self.node_by_name[x]
dg2 = self.node_by_name[y]
fp.write(f"{dg1.idx} {dg2.idx} {data['distance']}\n")
def load(self, filename):
# Reload the graph
G = nx.read_gexf(filename)
for node in G.nodes():
self.add_node(node)
for edge in G.edges():
self.add_edge(*edge)
for node, attrs in G.nodes(data=True):
self.add_node(node, attr_dict=attrs)
for edge in G.edges(data=True):
self.add_edge(edge[0], edge[1], distance=edge[2]["distance"])
# Extract d-graphs from nx graph
self.all_d_graphs = []
self.node_by_idx = {}
self.node_by_name = {}
self.bidict_nodes = {}
for idx, node in enumerate(self.nodes()):
dg = Dgraph.load(node, self.barcode_graph)
......@@ -96,26 +99,9 @@ class D2Graph(nx.Graph):
if dg.idx == -1:
dg.idx = idx
self.node_by_idx[dg.idx] = dg
self.node_by_name[node] = dg
self.bidict_nodes = bidict(self.bidict_nodes)
# Extract edges and re-compute distances
self.distances = {}
for edge in self.edges():
# Get the dg pair
idx1 = int(edge[0].split(' ')[0])
idx2 = int(edge[1].split(' ')[0])
dg1 = self.node_by_idx[idx1]
dg2 = self.node_by_idx[idx2]
# Compute and save the distance
dist = dg1.distance_to(dg2)
if not idx1 in self.distances:
self.distances[idx1] = {}
self.distances[idx1][idx2] = dist
if not idx2 in self.distances:
self.distances[idx2] = {}
self.distances[idx2][idx1] = dist
def create_index_from_tuples(self, tuple_size=3):
index = {}
......@@ -138,22 +124,16 @@ class D2Graph(nx.Graph):
def compute_distances(self):
distances = {dg.idx:{} for dg in self.all_d_graphs}
for dmer, dgraphs in self.index.items():
for idx1, dg1 in enumerate(dgraphs):
for idx2 in range(idx1+1, len(dgraphs)):
dg2 = dgraphs[idx2]
if dg1 == dg2:
continue
# Distance computing and adding in the dist dicts
d = dg1.distance_to(dg2)
distances[dg1.idx][dg2.idx] = d
distances[dg2.idx][dg1.idx] = d
return distances
for x, y, data in self.edges(data=True):
dg1 = self.node_by_name[x]
dg2 = self.node_by_name[y]
if dg1 == dg2:
continue
# Distance computing and adding in the dist dicts
d = dg1.distance_to(dg2)
data["distance"] = d
def create_index_ordered(self):
index = {}
......@@ -188,7 +168,6 @@ class D2Graph(nx.Graph):
def create_graph(self):
# next_idx = 0
nodes = {}
for dmer in self.index:
......@@ -196,14 +175,14 @@ class D2Graph(nx.Graph):
# Create a node name
if not dg in nodes:
nodes[dg] = str(dg)
# next_idx += 1
# Add the node
self.add_node(nodes[dg])
# Add the edges
for prev_node in self.index[dmer][:d_idx]:
self.add_edge(nodes[dg], nodes[prev_node])
if prev_node != dg:
self.add_edge(nodes[dg], nodes[prev_node])
return bidict(nodes)
......
import networkx as nx
G = nx.path_graph(3)
print(G.edges(data=True))
edge_data = G[0][1]
edge_data["test"] = 2
# nx.set_edge_attributes(G, 0, "test")
# edge = list(G.edges(data=True))[0]
print(G.edges(data=True))
......@@ -68,15 +68,22 @@ class TestD2Graph(unittest.TestCase):
c3:{c1:4, c2:2}
}
# distance tests
for idx1, neighbors in d2.distances.items():
dg1 = d2.node_by_idx[idx1]
for x, y, data in d2.edges(data=True):
dg1 = d2.node_by_idx[int(x.split(" ")[0])]
dg2 = d2.node_by_idx[int(y.split(" ")[0])]
for idx2, dist in neighbors.items():
dg2 = d2.node_by_idx[idx2]
awaited_dist = awaited_distances[dg1.center][dg2.center]
self.assertEquals(data["distance"], awaited_dist)
# # distance tests
# for idx1, neighbors in d2.distances.items():
# dg1 = d2.node_by_idx[idx1]
# for idx2, dist in neighbors.items():
# dg2 = d2.node_by_idx[idx2]
awaited_dist = awaited_distances[dg1.center][dg2.center]
self.assertEquals(dist, awaited_dist)
# awaited_dist = awaited_distances[dg1.center][dg2.center]
# self.assertEquals(dist, awaited_dist)
def test_reloading(self):
......@@ -104,6 +111,8 @@ class TestD2Graph(unittest.TestCase):
self.assertEquals(len(d2_reloaded.nodes()), len(d2.nodes()))
self.assertEquals(len(d2_reloaded.edges()), len(d2.edges()))
# TODO: Verify distances
# Test all_d_graphs
self.assertEquals(len(d2_reloaded.all_d_graphs), len(d2.all_d_graphs))
# Verify dg idxs
......@@ -111,15 +120,6 @@ class TestD2Graph(unittest.TestCase):
for dg in d2.all_d_graphs:
self.assertTrue(dg.idx in reloaded_idxs)
# Verify distances
self.assertEquals(len(d2.distances), len(d2_reloaded.distances))
for idx1 in d2.distances:
self.assertTrue(idx1 in d2_reloaded.distances)
self.assertEquals(len(d2.distances[idx1]), len(d2_reloaded.distances[idx1]))
for idx2 in d2.distances[idx1]:
self.assertTrue(idx2 in d2_reloaded.distances[idx1])
self.assertEquals(d2.distances[idx1][idx2], d2_reloaded.distances[idx1][idx2])
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment