Commit d7e6c8cf authored by Yoann Dufresne's avatar Yoann Dufresne
Browse files

reload d2 nx graph and start reconstruct dg graphs + tests

parent 4e175f2c
......@@ -2,7 +2,7 @@ import networkx as nx
import itertools
from bidict import bidict
from d_graph import compute_all_max_d_graphs, filter_dominated, list_domination_filter
from d_graph import Dgraph, compute_all_max_d_graphs, filter_dominated, list_domination_filter
class D2Graph(object):
......@@ -69,8 +69,14 @@ class D2Graph(object):
fp.write(f"{d_graph.idx} {neighbor_idx} {dist}\n")
def load_from_gexf(self, filename):
pass
def load(self, filename):
# Reload the graph
self.nx_graph = nx.read_gexf(filename)
# Extract d-graphs from nx graph
self.all_d_graphs = []
for node in self.nx_graph.nodes():
self.all_d_graphs.append(Dgraph.load(node, self.graph))
def create_index_from_tuples(self, tuple_size=3):
......
......@@ -15,6 +15,25 @@ class Dgraph(object):
self.connexity = [None,None]
self.nodes = [self.center]
self.edges = []
""" Static method to load a dgraph from a text
@param text the saved d-graph
@param barcode_graph Barcode graph from which the d-graph is extracted
@return a new d-graph object corresponding to the test
"""
def load(text, barcode_graph):
# basic split
text = text.replace(']', '')
head, h1, h2 = text.split('[')
# Head parsing
head = head.split(' ')
dg = Dgraph(head[-3])
if len(head) == 4:
dg.idx = int(head[0])
return dg
""" Compute the d-graph quality (score) according to the connectivity between the two halves.
......
import unittest
import tempfile
import networkx as nx
from scipy.special import comb
from d2_graph import D2Graph
......@@ -79,5 +81,38 @@ class TestD2Graph(unittest.TestCase):
self.assertEquals(dist, awaited_dist)
def test_reloading(self):
# Parameters
d = 3
size = 2 * d + 3
index_k = 2 * d - 1
# Create a d2 graph
G = gm.generate_d_graph_chain(size, d)
d2 = D2Graph(G)
d2.construct_from_barcodes(index_size=index_k, verbose=False)
# Save and reload the d2 in a temporary file
with tempfile.NamedTemporaryFile() as fp:
# Save
nx.write_gexf(d2.nx_graph, fp.name)
# Reload
d2_reloaded = D2Graph(G)
d2_reloaded.load(fp.name)
# Test the nx graph
self.assertNotEquals(d2_reloaded.nx_graph, None)
self.assertEquals(len(d2_reloaded.nx_graph.nodes()), len(d2.nx_graph.nodes()))
self.assertEquals(len(d2_reloaded.nx_graph.edges()), len(d2.nx_graph.edges()))
# Test all_d_graphs
self.assertEquals(len(d2_reloaded.all_d_graphs), len(d2.all_d_graphs))
# Verify dg idxs
reloaded_idxs = [dg.idx for dg in d2_reloaded.all_d_graphs]
for dg in d2.all_d_graphs:
self.assertTrue(dg.idx in reloaded_idxs)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment