diff --git a/deconvolution/d2_graph.py b/deconvolution/d2_graph.py index d89b4a34e13176ec88cc310701d87d1dbcbc0106..26eb5a4bc0945c3e95032d6a1daab95d67ff4ba8 100644 --- a/deconvolution/d2_graph.py +++ b/deconvolution/d2_graph.py @@ -6,7 +6,7 @@ from d_graph import compute_all_max_d_graphs class D2Graph(object): """D2Graph (read it (d-graph)²)""" - def __init__(self, graph): + def __init__(self, graph, index_size=10): super(D2Graph, self).__init__() self.graph = graph @@ -14,19 +14,23 @@ class D2Graph(object): self.d_graphs = compute_all_max_d_graphs(self.graph) # Index all the d-graphes - self.index = self.create_index() + self.index = self.create_index_from_tuples(index_size) - def create_index(self): + + def create_index_from_tuples(self, tuple_size=3): index = {} perfect = 0 for node in self.d_graphs: for dg in self.d_graphs[node]: - nodeset = dg.to_node_set() - # Generate all dmers without one node - for el in nodeset: - dmer = nodeset.difference(frozenset([el])) + nodelist = dg.to_list() + nodelist.sort() + if len(nodelist) < tuple_size: + continue + + # Generate all tuplesize-mers + for dmer in itertools.combinations(nodelist, tuple_size): if not dmer in index: index[dmer] = [dg] else: diff --git a/deconvolution/d_graph.py b/deconvolution/d_graph.py index 16d7be1d0c4610ec6bba7312692a59fb9d934e77..85e241c4511a047e4f9da86ab380bb10172c7a57 100644 --- a/deconvolution/d_graph.py +++ b/deconvolution/d_graph.py @@ -51,6 +51,10 @@ class Dgraph(object): return max_len * (max_len - 1) / 2 + def to_list(self): + return self.halves[0]+ [self.center] + self.halves[1] + + def to_ordered_lists(self): hands = [[],[]] for idx in range(2): @@ -66,14 +70,12 @@ class Dgraph(object): return hands[0][::-1] + [[self.center]] + hands[1] - def to_node_set(self): - return frozenset(self.halves[0] + self.halves[1] + [self.center]) + def to_node_multiset(self): + return frozenset(self.to_list()) def __eq__(self, other): - my_tuple = (self.get_link_divergence(), self.get_optimal_score()) - other_tuple = (other.get_link_divergence(), other.get_optimal_score()) - return (my_tuple == other_tuple) + return self.to_ordered_lists() == other.to_ordered_lists() def __ne__(self, other): return not (self == other) @@ -85,7 +87,7 @@ class Dgraph(object): def __hash__(self): - nodelist = list(self.to_node_set()) + nodelist = list(self.to_list()) nodelist.sort() return ",".join(nodelist).__hash__() diff --git a/tests/d2_graph_test.py b/tests/d2_graph_test.py index 8e8b1eff74c03d098c07c2a3f60a8e7debb319a9..f38f3b3eb10787cd37004a78ec07f4044e60edd5 100644 --- a/tests/d2_graph_test.py +++ b/tests/d2_graph_test.py @@ -8,7 +8,7 @@ from tests.d_graph_data import unit_d_graph, unit_overlapp_d_graph, complete_gra class TestD2Graph(unittest.TestCase): def test_construction(self): - d2 = D2Graph(complete_graph) + d2 = D2Graph(complete_graph, 6) # Evaluate the number of candidate unit d_graphs generated for node, candidates in d2.d_graphs.items(): @@ -17,28 +17,20 @@ class TestD2Graph(unittest.TestCase): else: self.assertEquals(0, len(candidates)) - # Evaluate the hashes - self.assertEquals(3, len(d2.index)) + # Evaluate the index + self.assertEquals(13, len(d2.index)) - udg = Dgraph(unit_d_graph[0]) - udg.put_halves(unit_d_graph[1], unit_d_graph[2], unit_d_graph[3]) - uodg = Dgraph(unit_overlapp_d_graph[0]) - uodg.put_halves(unit_overlapp_d_graph[1], unit_overlapp_d_graph[2], unit_overlapp_d_graph[3]) - - key = frozenset({'A2', 'A1', 'B1', 'C', 'B0', 'B2'}) - self.assertEquals(2, len(d2.index[key])) - self.assertTrue(udg in d2.index[key]) - self.assertTrue(uodg in d2.index[key]) - key = frozenset({'A0', 'A2', 'A1', 'B1', 'C', 'B2'}) - self.assertEquals(1, len(d2.index[key])) - self.assertEquals(udg, d2.index[key][0]) - key = frozenset({'A2', 'B-1', 'B1', 'C', 'B2', 'B0'}) - self.assertEquals(1, len(d2.index[key])) - self.assertEquals(uodg, d2.index[key][0]) + overlap_key = ('A1', 'A2', 'B0', 'B1', 'B2', 'C') + for dmer, dg_lst in d2.index.items(): + if dmer == overlap_key: + self.assertEquals(2, len(d2.index[dmer])) + self.assertNotEquals(d2.index[dmer][0], d2.index[dmer][1]) + else: + self.assertEquals(1, len(d2.index[dmer])) def test_to_nx_graph(self): - d2 = D2Graph(complete_graph) + d2 = D2Graph(complete_graph, 6) d2G, node_names = d2.to_nx_graph() nodes = list(d2G.nodes()) self.assertEquals(2, len(nodes))