Commit 6802438f authored by Yoann Dufresne
Browse files

bugfix: unpredictable filter counts

parent d0b9fb62
import networkx as nx
import itertools
from bidict import bidict
import sys
from d_graph import Dgraph, compute_all_max_d_graphs, filter_dominated, list_domination_filter
......@@ -13,6 +14,7 @@ class D2Graph(nx.Graph):
self.d_graphs_per_node = {}
self.node_by_idx = {}
self.barcode_graph = barcode_graph
self.index = None
# Number the edges from original graph
self.barcode_edge_idxs = {}
......@@ -58,7 +60,13 @@ class D2Graph(nx.Graph):
if verbose:
print("Compute the unit d-graphs")
self.d_graphs_per_node = compute_all_max_d_graphs(self.barcode_graph, debug=debug)
if verbose:
counts = sum(len(x) for x in self.d_graphs_per_node.values())
print(f"\t {counts} computed d-graphs")
self.d_graphs_per_node = filter_dominated(self.d_graphs_per_node)
if verbose:
counts = sum(len(x) for x in self.d_graphs_per_node.values())
print(f"\t {counts} remaining d-graphs after first filter")
for d_graphs in self.d_graphs_per_node.values():
self.all_d_graphs.extend(d_graphs)
......@@ -67,13 +75,12 @@ class D2Graph(nx.Graph):
for idx, d_graph in enumerate(self.all_d_graphs):
d_graph.idx = idx
self.node_by_idx[idx] = d_graph
# self.node_by_name[str(d_graph)] = d_graph
# Index all the d-graphes
# Index all the d-graphs
if verbose:
print("Compute the dmer index")
self.index = self.create_index_from_tuples(index_size)
self.filter_dominated_in_index()
self.index = self.create_index_from_tuples(index_size, verbose=verbose)
self.filter_dominated_in_index(tuple_size=index_size, verbose=verbose)
# Compute node distances for pair of dgraphs that share at least 1 dmer.
if verbose:
print("Compute the graph")
......@@ -128,22 +135,28 @@ class D2Graph(nx.Graph):
self.bidict_nodes = bidict(self.bidict_nodes)
def create_index_from_tuples(self, tuple_size=3, verbose=True):
    """ Index the d-graphs of self.all_d_graphs by combinations of their nodes.
        Each d-graph is registered under every combination of tuple_size nodes taken from
        its sorted node list (the "dmers"). D-graphs with fewer than tuple_size nodes are
        skipped.
        @param tuple_size Number of nodes in each dmer used as an index key.
        @param verbose If True, print a progress counter on stdout.
        @return A dict mapping each dmer (tuple of nodes) to the set of d-graphs containing it.
    """
    index = {}

    if verbose:
        print("\tIndex d-graphs")

    total = len(self.all_d_graphs)
    for lst_idx, dg in enumerate(self.all_d_graphs):
        if verbose:
            # "\r" rewrites the same terminal line to show a live counter
            sys.stdout.write(f"\r\t{lst_idx+1}/{total}")
            sys.stdout.flush()

        nodelist = dg.to_sorted_list()
        if len(nodelist) < tuple_size:
            continue

        # Generate all tuplesize-mers; a set per dmer avoids registering a d-graph twice
        for dmer in itertools.combinations(nodelist, tuple_size):
            index.setdefault(dmer, set()).add(dg)

    if verbose:
        # Terminate the progress line
        print()

    return index
......@@ -160,45 +173,14 @@ class D2Graph(nx.Graph):
data["distance"] = d
def create_index_ordered(self):
    """ Index the d-graphs of self.d_graphs_per_node by "all nodes but one" dmers.
        For each d-graph, the ordered node lists are read with to_ordered_lists(). A dmer is
        built for every node of the first list (all nodes except that one), then for every
        node of the last list.
        @return A dict mapping each dmer (frozenset of nodes) to the list of d-graphs
        registered under it.
    """
    index = {}

    for dgs in self.d_graphs_per_node.values():
        for dg in dgs:
            layers = dg.to_ordered_lists()

            # First pass drops one node of the first list, second pass one node of the last
            for border, others in ((layers[0], layers[1:]), (layers[-1], layers[:-1])):
                # Nodes of every other list are always part of the dmer
                kept = [node for layer in others for node in layer]

                for skipped in range(len(border)):
                    dmer = frozenset(kept + border[:skipped] + border[skipped+1:])
                    index.setdefault(dmer, []).append(dg)

    return index
def create_graph(self):
nodes = {}
for dmer in self.index:
for d_idx, dg in enumerate(self.index[dmer]):
dgs = list(self.index[dmer])
for d_idx, dg in enumerate(dgs):
# Create a node name
if not dg in nodes:
if dg not in nodes:
nodes[dg] = dg.idx
# Add the node
......@@ -211,48 +193,73 @@ class D2Graph(nx.Graph):
# Add the edges
for prev_node in self.index[dmer][:d_idx]:
for prev_node in dgs[:d_idx]:
if prev_node != dg:
self.add_edge(nodes[dg], nodes[prev_node])
return bidict(nodes)
def filter_dominated_in_index(self):
to_remove = []
def filter_dominated_in_index(self, tuple_size=3, verbose=True):
to_remove = set()
if verbose:
print("\tFilter dominated in index")
# Find dominated
for dmer, dg_list in self.index.items():
for dmer_idx, item in enumerate(self.index.items()):
dmer, dg_list = item
if verbose:
sys.stdout.write(f"\r\t{dmer_idx+1}/{len(self.index)}")
sys.stdout.flush()
undominated = list_domination_filter(dg_list)
# if len(undominated) > 1:
# print(dmer)
# print("\n".join([str(x) for x in undominated]))
# print()
# Register dominated
if len(dg_list) != len(undominated):
for dg in dg_list:
if not dg in undominated:
to_remove.append(dg)
if dg not in undominated:
to_remove.add(dg)
self.index[dmer] = undominated
to_remove = frozenset(to_remove)
# Remove dominated in global list
for r_dg in to_remove:
if verbose:
print()
print("\tDmer removal")
# # Remove dominated in global list
# for r_idx, r_dg in enumerate(to_remove):
#
# self.all_d_graphs.remove(r_dg)
# self.d_graphs_per_node[r_dg.center].remove(r_dg)
#
# Remove from index
# for idx, dmer in enumerate(itertools.combinations(r_dg.to_sorted_list(), tuple_size)):
# if dmer in self.index[dmer]:
# self.index[dmer].remove(r_dg)
# if len(self.index[dmer]) == 0:
# del self.index[dmer]
removable_dmers = set()
for r_idx, r_dg in enumerate(to_remove):
if verbose:
sys.stdout.write(f"\r\t{r_idx+1}/{len(to_remove)}")
sys.stdout.flush()
self.all_d_graphs.remove(r_dg)
self.d_graphs_per_node[r_dg.center].remove(r_dg)
# Remove dominated in index
removable_dmers = []
for dmer in self.index:
for r_dg in to_remove:
for dmer in itertools.combinations(r_dg.to_sorted_list(), tuple_size):
if r_dg in self.index[dmer]:
self.index[dmer] = list(filter(lambda x: x!=r_dg, self.index[dmer]))
if len(self.index[dmer]) == 0:
removable_dmers.append(dmer)
removable_dmers.add(dmer)
# Remove empty dmers
for dmer in removable_dmers:
del self.index[dmer]
if verbose:
print()
......@@ -14,9 +14,10 @@ class Dgraph(object):
self.halves = [None,None]
self.connexity = [None,None]
self.nodes = [self.center]
self.node_set = set(self.center)
self.node_set = set(self.nodes)
self.edges = []
self.ordered_list = None
self.sorted_list = None
""" Static method to load a dgraph from a text
......@@ -98,8 +99,11 @@ class Dgraph(object):
return int(max_len * (max_len - 1) / 2)
def to_list(self):
return self.halves[0]+ [self.center] + self.halves[1]
def to_sorted_list(self):
    """ Return the nodes of both halves plus the center as one sorted list.
        The list is computed on the first call and stored in self.sorted_list; later calls
        return that same stored list object.
        @return The sorted list of nodes.
    """
    if self.sorted_list is not None:
        return self.sorted_list

    self.sorted_list = sorted(self.halves[0] + [self.center] + self.halves[1])
    return self.sorted_list
def to_ordered_lists(self):
......@@ -119,8 +123,8 @@ class Dgraph(object):
return self.ordered_list
def to_node_multiset(self):
return frozenset(self.to_list())
def to_node_set(self):
return frozenset(self.to_sorted_list())
def distance_to(self, dgraph):
......@@ -149,8 +153,8 @@ class Dgraph(object):
@return True if dg1 is dominated by dg2.
"""
def is_dominated(self, dg):
dg1_nodes = frozenset(self.to_list())
dg2_nodes = frozenset(dg.to_list())
dg1_nodes = self.to_node_set()
dg2_nodes = dg.to_node_set()
# domination first condition: inclusion of all the nodes
if not dg1_nodes.issubset(dg2_nodes):
......@@ -188,9 +192,8 @@ class Dgraph(object):
def __hash__(self):
nodelist = list(self.to_list())
nodelist = self.to_sorted_list()
nodelist = [str(x) for x in nodelist]
nodelist.sort()
return ",".join(nodelist).__hash__()
......@@ -227,7 +230,7 @@ def compute_all_max_d_graphs(graph, debug=False):
neighbors = list(graph.neighbors(node))
neighbors_graph = nx.Graph(graph.subgraph(neighbors))
node_d_graphs = []
node_d_graphs = set()
# Find all the cliques (equivalent to compute all the candidate half d-graph)
cliques = list(nx.find_cliques(neighbors_graph))
......@@ -243,7 +246,7 @@ def compute_all_max_d_graphs(graph, debug=False):
if d_graph.get_link_divergence() > d_graph.get_optimal_score() / 2:
continue
node_d_graphs.append(d_graph)
node_d_graphs.add(d_graph)
# Cut the distribution queue
......@@ -262,19 +265,23 @@ def compute_all_max_d_graphs(graph, debug=False):
"""
def add_new_dg_regarding_domination(dg, undominated_dgs_list):
    """ Insert dg into a list of undominated d-graphs, keeping the list free of dominated ones.
        Every d-graph of the list that is dominated by dg is removed, and dg is appended only
        if no d-graph of the list dominates it. Both relations are evaluated against the whole
        list, so the result does not depend on the order of the list.
        @param dg The candidate d-graph.
        @param undominated_dgs_list The current list of undominated d-graphs (modified in place).
        @return The same list object, updated.
    """
    # Search for domination relations in both directions before touching the list
    dominated = any(dg.is_dominated(u_dg) for u_dg in undominated_dgs_list)
    to_remove = [u_dg for u_dg in undominated_dgs_list if u_dg.is_dominated(dg)]

    # Remove dominated values
    for dg2 in to_remove:
        undominated_dgs_list.remove(dg2)

    # Add the new dg
    if not dominated:
        undominated_dgs_list.append(dg)

    return undominated_dgs_list
......@@ -289,9 +296,9 @@ def filter_dominated(d_graphs, overall=False, in_place=True):
for dgs in d_graphs.values():
all_d_graphs.extend(dgs)
print(len(all_d_graphs))
# print(len(all_d_graphs))
all_d_graphs = list_domination_filter(all_d_graphs)
print(len(all_d_graphs))
# print(len(all_d_graphs))
return d_graphs
......@@ -310,16 +317,16 @@ def local_domination_filter(d_graphs, in_place=True):
# Filter node by node
for node, d_graph_list in d_graphs.items():
# Add the non filtered d-graph to the output
filtered[node] = list_domination_filter(d_graph_list)
filtered[node] = brutal_list_domination_filter(d_graph_list)
return filtered
""" Filter the input d-graphs list. In the list of d-graph centered on a node n, if a d-graph is
completly included in another and have a highest distance score to the optimal, then it is
completely included in another and have a highest distance score to the optimal, then it is
filtered out.
@param d_graphs All the d-graphs to filter.
@return The filtered dictionnary of d-graph per node.
@return The filtered dictionary of d-graph per node.
"""
def list_domination_filter(d_graphs):
filtered = []
......@@ -328,4 +335,15 @@ def list_domination_filter(d_graphs):
for dg in d_graphs:
add_new_dg_regarding_domination(dg, filtered)
return filtered
return set(filtered)
def brutal_list_domination_filter(d_graphs):
    """ Filter a collection of d-graphs by comparing every pair of them.
        A d-graph is dropped as soon as one d-graph of the input dominates it.
        @param d_graphs The d-graphs to filter (iterated twice, so not a one-shot iterator).
        @return The set of d-graphs that no other input d-graph dominates.
    """
    undominated = set(d_graphs)

    for dg1 in d_graphs:
        for dg2 in d_graphs:
            if dg1.is_dominated(dg2):
                # discard, not remove: the same d-graph can appear several times in the
                # input and must not raise once it has already been dropped
                undominated.discard(dg1)
                break

    return undominated
......@@ -8,17 +8,7 @@ def generate_d_graph_chain(size, d):
:param d The number of connection on the left and on the right for any node
:return The d-graph chain
"""
G = nx.Graph()
for idx in range(size):
# Create the node
G.add_node(idx)
# Link the node to d previous nodes
for prev in range(max(0, idx-d), idx):
G.add_edge(prev, idx)
return G
return generate_approx_d_graph_chain(size, d, d)
def generate_approx_d_graph_chain(size, d_max, d_avg, size_reduction=0, rnd_seed=-1):
......
......@@ -60,7 +60,7 @@ class Solution(Path):
""" Only respect counts for now
"""
def to_barcode_path(self):
barcode_per_position = [set(udg.to_list()) for udg in self]
barcode_per_position = [set(udg.to_sorted_list()) for udg in self]
compressed_barcodes = []
for idx, barcodes in enumerate(barcode_per_position):
......
......@@ -28,8 +28,9 @@ class TestD2Graph(unittest.TestCase):
overlap_key = ('A1', 'A2', 'B0', 'B1', 'B2', 'C')
for dmer, dg_lst in d2.index.items():
if dmer == overlap_key:
values = list(d2.index[dmer])
self.assertEqual(2, len(d2.index[dmer]))
self.assertNotEqual(d2.index[dmer][0], d2.index[dmer][1])
self.assertNotEqual(values[0], values[1])
else:
self.assertEqual(1, len(d2.index[dmer]))
......@@ -73,16 +74,6 @@ class TestD2Graph(unittest.TestCase):
awaited_dist = awaited_distances[dg1.center][dg2.center]
self.assertEqual(data["distance"], awaited_dist)
# # distance tests
# for idx1, neighbors in d2.distances.items():
# dg1 = d2.node_by_idx[idx1]
# for idx2, dist in neighbors.items():
# dg2 = d2.node_by_idx[idx2]
# awaited_dist = awaited_distances[dg1.center][dg2.center]
# self.assertEqual(dist, awaited_dist)
def test_reloading(self):
# Parameters
d = 3
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment