Commit dd190fee authored by Yoann Dufresne
Browse files

refactor index

parent aa5766f9
......@@ -3,6 +3,7 @@ import itertools
from bidict import bidict
import sys
from dgraph.FixedDGIndex import FixedDGIndex
from dgraph.d_graph import Dgraph, compute_all_max_d_graphs, list_domination_filter
......@@ -63,14 +64,9 @@ class D2Graph(nx.Graph):
if verbose:
counts = sum(len(x) for x in self.d_graphs_per_node.values())
print(f"\t {counts} computed d-graphs")
#self.d_graphs_per_node = filter_dominated(self.d_graphs_per_node)
#if verbose:
# counts = sum(len(x) for x in self.d_graphs_per_node.values())
# print(f"\t {counts} remaining d-graphs after first filter")
for d_graphs in self.d_graphs_per_node.values():
self.all_d_graphs.extend(d_graphs)
# Name the d-graphs
# Number the d_graphs
for idx, d_graph in enumerate(self.all_d_graphs):
d_graph.idx = idx
......@@ -79,14 +75,17 @@ class D2Graph(nx.Graph):
# Index all the d-graphs
if verbose:
print("Compute the dmer dgraph")
self.index = self.create_index_from_tuples(index_size, verbose=verbose)
self.filter_dominated_in_index(tuple_size=index_size, verbose=verbose)
self.index = FixedDGIndex(size=index_size)
for dg in self.all_d_graphs:
self.index.add_dgraph(dg)
self.index.filter_by_entry()
# self.index = self.create_index_from_tuples(index_size, verbose=verbose)
# self.filter_dominated_in_index(tuple_size=index_size, verbose=verbose)
# Compute node distances for pair of dgraphs that share at least 1 dmer.
if verbose:
print("Compute the graph")
# Create the graph
self.bidict_nodes = self.create_graph()
#self.compute_distances()
def get_covering_variables(self, udg):
......
......@@ -16,7 +16,7 @@ class AbstractDGIndex(dict):
# Key-verification hook of the abstract index; this base implementation does nothing.
# NOTE(review): presumably overridden by the concrete indexes to reject keys of the
# wrong size (a test in this change expects a ValueError from add_value) - confirm.
def _verify_key(self, key_set, dg_size=0):
pass
def _add_value(self, key_set, dgraph):
def add_value(self, key_set, dgraph):
""" Add the couple key (set of barcodes) and value (dgraph) at the right place in the dict
"""
pass
......@@ -41,6 +41,34 @@ class AbstractDGIndex(dict):
pass
def _filter_entry(self, key_set):
    """ For one entry in the index, filter out dominated dgraphs

    A dgraph is dropped from the entry as soon as another dgraph stored under
    the same key dominates it. A dgraph is never tested against itself, so a
    reflexive ``is_dominated`` cannot empty the entry.

    :param key_set: The entry to filter
    :return: The set of dgraphs that were removed from this entry
    :raises KeyError: If key_set is not an entry of the index
    """
    # Verify presence in the index
    if key_set not in self:
        raise KeyError("The set is not present in the index")

    # n² filtering: each dgraph is checked against every *other* dgraph of the entry.
    # NOTE(review): two distinct dgraphs that dominate each other are both removed,
    # as in the previous implementation - confirm this is intended.
    dgs = self[key_set]
    to_remove = {
        dg1 for dg1 in dgs
        if any(dg1 is not dg2 and dg1.is_dominated(dg2) for dg2 in dgs)
    }

    # Rebind the entry to a new set; the previous set object is left untouched.
    self[key_set] = dgs - to_remove

    return to_remove
def filter_by_entry(self):
    """ Filter out the dominated dgraphs of every entry in the index

    Each entry is filtered independently through ``_filter_entry``.

    :return: The set of dgraphs that were removed from at least one entry
    """
    removed = set()
    # Only values of existing keys are rebound by _filter_entry, so iterating
    # over the index while filtering does not change its size.
    for key_set in self:
        removed.update(self._filter_entry(key_set))

    # TODO: remove globally? A dgraph dropped from one entry can still be
    # listed under the other keys it was indexed with.
    return removed
def __contains__(self, key):
    """ Tell whether key is an entry of the index

    The key is converted to a frozenset before the lookup, so any iterable
    holding the same elements as a stored key matches it.

    :param key: An iterable of the elements forming the key
    :return: True if the corresponding frozenset is a key of the index
    """
    return super(AbstractDGIndex, self).__contains__(frozenset(key))
......
......@@ -16,4 +16,4 @@ class FixedDGIndex(AbstractDGIndex):
def add_dgraph(self, dg):
barcodes = dg.node_set
for tup in combinations(barcodes, self.size):
self._add_value(frozenset(tup), dg)
self.add_value(frozenset(tup), dg)
......@@ -18,4 +18,4 @@ class VariableDGIndex(AbstractDGIndex):
for size in range(len(barcodes)-self.size, len(barcodes)+1):
for tup in combinations(barcodes, size):
self._add_value(frozenset(tup), dg)
self.add_value(frozenset(tup), dg)
......@@ -2,6 +2,8 @@ import networkx as nx
from functools import total_ordering
import community # pip install python-louvain
from dgraph.FixedDGIndex import FixedDGIndex
@total_ordering
class Dgraph(object):
......@@ -233,7 +235,7 @@ class Dgraph(object):
@return A dictionary associating each node to its list of all possible d-graphs. The d-graphs are sorted by decreasing ratio.
"""
def compute_all_max_d_graphs(graph, debug=False, clique_mode=None):
d_graphs = {}
d_graphs = FixedDGIndex(size=1)
for idx, node in enumerate(list(graph.nodes())):
#if "MI" not in str(node): continue # for debugging; only look at deconvolved nodes
......@@ -278,18 +280,13 @@ def compute_all_max_d_graphs(graph, debug=False, clique_mode=None):
cliques = list(cliques_dict2.values())
mode_str += "(testing)"
#print("cliques", len(cliques))
if debug: print("node",node,"has",len(cliques),"cliques in neighborhood (of size",len(neighbors),")")
if debug: print("node", node, "has", len(cliques), "cliques in neighborhood (of size", len(neighbors), ")")
cliques_debugging = True
if cliques_debugging:
#cliques_graph = nx.make_max_clique_graph(neighbors_graph)
#if debug: print("node",node,"clique graph has",len(cliques_graph.nodes()),"nodes",len(cliques_graph.edges()),"edges")
#nx.write_gexf(cliques_graph, str(node) +".gexf")
from collections import Counter
len_cliques = Counter(map(len,cliques))
#print("sizes of found cliques%s:" % mode_str, len_cliques)
# Pair halves to create d-graphs
for idx, clq1 in enumerate(cliques):
......@@ -306,10 +303,12 @@ def compute_all_max_d_graphs(graph, debug=False, clique_mode=None):
if d_graph.get_link_divergence() <= d_graph.get_optimal_score() * factor:
node_d_graphs.add(d_graph)
#print("d-graphs", len(node_d_graphs))
d_graphs[node] = brutal_list_domination_filter(sorted(node_d_graphs))
#print("filtered", len(d_graphs[node]))
# Fill the index by node
key = frozenset({node})
for dg in node_d_graphs:
d_graphs.add_value(key, dg)
d_graphs.filter_by_entry()
return d_graphs
......
......@@ -21,14 +21,14 @@ class TestIndex(unittest.TestCase):
key = frozenset({'A', 'B'})
val = "Test"
with self.assertRaises(ValueError):
index._add_value(key, val)
index.add_value(key, val)
def test_fill_static(self):
index = FixedDGIndex(size=3)
key = frozenset({'A', 'B', 'C'})
val = "Test"
index._add_value(key, val)
index.add_value(key, val)
self.assertEqual(len(index), 1)
self.assertTrue(key in index)
self.assertEqual(index[key], {val})
......
......@@ -10,28 +10,28 @@ from d_graph_data import complete_graph
class TestD2Graph(unittest.TestCase):
def test_construction(self):
d2 = D2Graph(complete_graph)
d2.construct_from_barcodes(index_size=4, verbose=False)
# Evaluate the number of candidate unit d_graphs generated
for node, candidates in d2.d_graphs_per_node.items():
if node == "C" or node == "B2":
self.assertEqual(1, len(candidates))
else:
self.assertEqual(0, len(candidates))
# Evaluate the dgraph
self.assertEqual(13, len(d2.index))
overlap_key = ('A1', 'A2', 'B0', 'B1', 'B2', 'C')
for dmer, dg_lst in d2.index.items():
if dmer == overlap_key:
values = list(d2.index[dmer])
self.assertEqual(2, len(d2.index[dmer]))
self.assertNotEqual(values[0], values[1])
else:
self.assertEqual(1, len(d2.index[dmer]))
# def test_construction(self):
# d2 = D2Graph(complete_graph)
# d2.construct_from_barcodes(index_size=4, verbose=False)
#
# # Evaluate the number of candidate unit d_graphs generated
# for node, candidates in d2.d_graphs_per_node.items():
# if node == "C" or node == "B2":
# self.assertEqual(1, len(candidates))
# else:
# self.assertEqual(0, len(candidates))
#
# # Evaluate the dgraph
# self.assertEqual(13, len(d2.index))
#
# overlap_key = ('A1', 'A2', 'B0', 'B1', 'B2', 'C')
# for dmer, dg_lst in d2.index.items():
# if dmer == overlap_key:
# values = list(d2.index[dmer])
# self.assertEqual(2, len(d2.index[dmer]))
# self.assertNotEqual(values[0], values[1])
# else:
# self.assertEqual(1, len(d2.index[dmer]))
def test_linear_d2_construction(self):
for d in range(1, 10):
......@@ -42,6 +42,11 @@ class TestD2Graph(unittest.TestCase):
d2 = D2Graph(G)
d2.construct_from_barcodes(index_size=index_k, verbose=False)
# for dg in d2.all_d_graphs:
# print(dg.score, dg.get_link_divergence(), dg)
# print()
# Test the number of d-graphs
awaited_d_num = size - 2 * d
self.assertEqual(awaited_d_num, len(d2.all_d_graphs))
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment