Commit 6802438f authored by Yoann Dufresne's avatar Yoann Dufresne
Browse files

bugfix: unpredictable filter counts

parent d0b9fb62
import networkx as nx import networkx as nx
import itertools import itertools
from bidict import bidict from bidict import bidict
import sys
from d_graph import Dgraph, compute_all_max_d_graphs, filter_dominated, list_domination_filter from d_graph import Dgraph, compute_all_max_d_graphs, filter_dominated, list_domination_filter
...@@ -13,7 +14,8 @@ class D2Graph(nx.Graph): ...@@ -13,7 +14,8 @@ class D2Graph(nx.Graph):
self.d_graphs_per_node = {} self.d_graphs_per_node = {}
self.node_by_idx = {} self.node_by_idx = {}
self.barcode_graph = barcode_graph self.barcode_graph = barcode_graph
self.index = None
# Number the edges from original graph # Number the edges from original graph
self.barcode_edge_idxs = {} self.barcode_edge_idxs = {}
self.nb_uniq_edge = 0 self.nb_uniq_edge = 0
...@@ -58,7 +60,13 @@ class D2Graph(nx.Graph): ...@@ -58,7 +60,13 @@ class D2Graph(nx.Graph):
if verbose: if verbose:
print("Compute the unit d-graphs") print("Compute the unit d-graphs")
self.d_graphs_per_node = compute_all_max_d_graphs(self.barcode_graph, debug=debug) self.d_graphs_per_node = compute_all_max_d_graphs(self.barcode_graph, debug=debug)
if verbose:
counts = sum(len(x) for x in self.d_graphs_per_node.values())
print(f"\t {counts} computed d-graphs")
self.d_graphs_per_node = filter_dominated(self.d_graphs_per_node) self.d_graphs_per_node = filter_dominated(self.d_graphs_per_node)
if verbose:
counts = sum(len(x) for x in self.d_graphs_per_node.values())
print(f"\t {counts} remaining d-graphs after first filter")
for d_graphs in self.d_graphs_per_node.values(): for d_graphs in self.d_graphs_per_node.values():
self.all_d_graphs.extend(d_graphs) self.all_d_graphs.extend(d_graphs)
...@@ -67,13 +75,12 @@ class D2Graph(nx.Graph): ...@@ -67,13 +75,12 @@ class D2Graph(nx.Graph):
for idx, d_graph in enumerate(self.all_d_graphs): for idx, d_graph in enumerate(self.all_d_graphs):
d_graph.idx = idx d_graph.idx = idx
self.node_by_idx[idx] = d_graph self.node_by_idx[idx] = d_graph
# self.node_by_name[str(d_graph)] = d_graph
# Index all the d-graphs
# Index all the d-graphes
if verbose: if verbose:
print("Compute the dmer index") print("Compute the dmer index")
self.index = self.create_index_from_tuples(index_size) self.index = self.create_index_from_tuples(index_size, verbose=verbose)
self.filter_dominated_in_index() self.filter_dominated_in_index(tuple_size=index_size, verbose=verbose)
# Compute node distances for pair of dgraphs that share at least 1 dmer. # Compute node distances for pair of dgraphs that share at least 1 dmer.
if verbose: if verbose:
print("Compute the graph") print("Compute the graph")
...@@ -128,22 +135,28 @@ class D2Graph(nx.Graph): ...@@ -128,22 +135,28 @@ class D2Graph(nx.Graph):
self.bidict_nodes = bidict(self.bidict_nodes) self.bidict_nodes = bidict(self.bidict_nodes)
def create_index_from_tuples(self, tuple_size=3): def create_index_from_tuples(self, tuple_size=3, verbose=True):
index = {} index = {}
perfect = 0 if verbose:
for dg in self.all_d_graphs: print("\tIndex d-graphs")
nodelist = dg.to_list() for lst_idx, dg in enumerate(self.all_d_graphs):
nodelist.sort() if verbose:
sys.stdout.write(f"\r\t{lst_idx+1}/{len(self.all_d_graphs)}")
sys.stdout.flush()
nodelist = dg.to_sorted_list()
if len(nodelist) < tuple_size: if len(nodelist) < tuple_size:
continue continue
# Generate all tuplesize-mers # Generate all tuplesize-mers
for dmer in itertools.combinations(nodelist, tuple_size): for dmer in itertools.combinations(nodelist, tuple_size):
if not dmer in index: if dmer not in index:
index[dmer] = [dg] index[dmer] = set()
else: index[dmer].add(dg)
index[dmer].append(dg)
if verbose:
print()
return index return index
...@@ -158,49 +171,18 @@ class D2Graph(nx.Graph): ...@@ -158,49 +171,18 @@ class D2Graph(nx.Graph):
# Distance computing and adding in the dist dicts # Distance computing and adding in the dist dicts
d = dg1.distance_to(dg2) d = dg1.distance_to(dg2)
data["distance"] = d data["distance"] = d
def create_index_ordered(self):
index = {}
perfect = 0
for node in self.d_graphs_per_node:
for dg in self.d_graphs_per_node[node]:
lst = dg.to_ordered_lists()
# Generate all dmers without the first node
# pull all the values
concat = [el for l in lst[1:] for el in l]
# generate dmers
for idx in range(len(lst[0])):
dmer = frozenset(concat + lst[0][:idx] + lst[0][idx+1:])
if not dmer in index:
index[dmer] = [dg]
else:
index[dmer].append(dg)
# Generate all dmers without the last node
# pull all the values
concat = [el for l in lst[:-1] for el in l]
# generate dmers
for idx in range(len(lst[-1])):
dmer = frozenset(concat + lst[-1][:idx] + lst[-1][idx+1:])
if not dmer in index:
index[dmer] = [dg]
else:
index[dmer].append(dg)
return index
def create_graph(self): def create_graph(self):
nodes = {} nodes = {}
for dmer in self.index: for dmer in self.index:
for d_idx, dg in enumerate(self.index[dmer]): dgs = list(self.index[dmer])
for d_idx, dg in enumerate(dgs):
# Create a node name # Create a node name
if not dg in nodes: if dg not in nodes:
nodes[dg] = dg.idx nodes[dg] = dg.idx
# Add the node # Add the node
self.add_node(nodes[dg]) self.add_node(nodes[dg])
# Add covering barcode edges # Add covering barcode edges
...@@ -211,48 +193,73 @@ class D2Graph(nx.Graph): ...@@ -211,48 +193,73 @@ class D2Graph(nx.Graph):
# Add the edges # Add the edges
for prev_node in self.index[dmer][:d_idx]: for prev_node in dgs[:d_idx]:
if prev_node != dg: if prev_node != dg:
self.add_edge(nodes[dg], nodes[prev_node]) self.add_edge(nodes[dg], nodes[prev_node])
return bidict(nodes) return bidict(nodes)
def filter_dominated_in_index(self): def filter_dominated_in_index(self, tuple_size=3, verbose=True):
to_remove = [] to_remove = set()
if verbose:
print("\tFilter dominated in index")
# Find dominated # Find dominated
for dmer, dg_list in self.index.items(): for dmer_idx, item in enumerate(self.index.items()):
dmer, dg_list = item
if verbose:
sys.stdout.write(f"\r\t{dmer_idx+1}/{len(self.index)}")
sys.stdout.flush()
undominated = list_domination_filter(dg_list) undominated = list_domination_filter(dg_list)
# if len(undominated) > 1:
# print(dmer)
# print("\n".join([str(x) for x in undominated]))
# print()
# Register dominated # Register dominated
if len(dg_list) != len(undominated): if len(dg_list) != len(undominated):
for dg in dg_list: for dg in dg_list:
if not dg in undominated: if dg not in undominated:
to_remove.append(dg) to_remove.add(dg)
self.index[dmer] = undominated self.index[dmer] = undominated
to_remove = frozenset(to_remove) if verbose:
# Remove dominated in global list print()
for r_dg in to_remove: print("\tDmer removal")
# # Remove dominated in global list
# for r_idx, r_dg in enumerate(to_remove):
#
# self.all_d_graphs.remove(r_dg)
# self.d_graphs_per_node[r_dg.center].remove(r_dg)
#
# Remove from index
# for idx, dmer in enumerate(itertools.combinations(r_dg.to_sorted_list(), tuple_size)):
# if dmer in self.index[dmer]:
# self.index[dmer].remove(r_dg)
# if len(self.index[dmer]) == 0:
# del self.index[dmer]
removable_dmers = set()
for r_idx, r_dg in enumerate(to_remove):
if verbose:
sys.stdout.write(f"\r\t{r_idx+1}/{len(to_remove)}")
sys.stdout.flush()
self.all_d_graphs.remove(r_dg) self.all_d_graphs.remove(r_dg)
self.d_graphs_per_node[r_dg.center].remove(r_dg) self.d_graphs_per_node[r_dg.center].remove(r_dg)
# Remove dominated in index # Remove dominated in index
removable_dmers = [] for dmer in itertools.combinations(r_dg.to_sorted_list(), tuple_size):
for dmer in self.index:
for r_dg in to_remove:
if r_dg in self.index[dmer]: if r_dg in self.index[dmer]:
self.index[dmer] = list(filter(lambda x: x!=r_dg, self.index[dmer])) self.index[dmer] = list(filter(lambda x: x!=r_dg, self.index[dmer]))
if len(self.index[dmer]) == 0: if len(self.index[dmer]) == 0:
removable_dmers.append(dmer) removable_dmers.add(dmer)
# Remove empty dmers # Remove empty dmers
for dmer in removable_dmers: for dmer in removable_dmers:
del self.index[dmer] del self.index[dmer]
if verbose:
print()
...@@ -14,9 +14,10 @@ class Dgraph(object): ...@@ -14,9 +14,10 @@ class Dgraph(object):
self.halves = [None,None] self.halves = [None,None]
self.connexity = [None,None] self.connexity = [None,None]
self.nodes = [self.center] self.nodes = [self.center]
self.node_set = set(self.center) self.node_set = set(self.nodes)
self.edges = [] self.edges = []
self.ordered_list = None self.ordered_list = None
self.sorted_list = None
""" Static method to load a dgraph from a text """ Static method to load a dgraph from a text
...@@ -98,8 +99,11 @@ class Dgraph(object): ...@@ -98,8 +99,11 @@ class Dgraph(object):
return int(max_len * (max_len - 1) / 2) return int(max_len * (max_len - 1) / 2)
def to_list(self): def to_sorted_list(self):
return self.halves[0]+ [self.center] + self.halves[1] if self.sorted_list is None:
self.sorted_list = self.halves[0]+ [self.center] + self.halves[1]
self.sorted_list.sort()
return self.sorted_list
def to_ordered_lists(self): def to_ordered_lists(self):
...@@ -119,8 +123,8 @@ class Dgraph(object): ...@@ -119,8 +123,8 @@ class Dgraph(object):
return self.ordered_list return self.ordered_list
def to_node_multiset(self): def to_node_set(self):
return frozenset(self.to_list()) return frozenset(self.to_sorted_list())
def distance_to(self, dgraph): def distance_to(self, dgraph):
...@@ -149,8 +153,8 @@ class Dgraph(object): ...@@ -149,8 +153,8 @@ class Dgraph(object):
@return True if dg1 is dominated by dg2. @return True if dg1 is dominated by dg2.
""" """
def is_dominated(self, dg): def is_dominated(self, dg):
dg1_nodes = frozenset(self.to_list()) dg1_nodes = self.to_node_set()
dg2_nodes = frozenset(dg.to_list()) dg2_nodes = dg.to_node_set()
# domination first condition: inclusion of all the nodes # domination first condition: inclusion of all the nodes
if not dg1_nodes.issubset(dg2_nodes): if not dg1_nodes.issubset(dg2_nodes):
...@@ -188,9 +192,8 @@ class Dgraph(object): ...@@ -188,9 +192,8 @@ class Dgraph(object):
def __hash__(self): def __hash__(self):
nodelist = list(self.to_list()) nodelist = self.to_sorted_list()
nodelist = [str(x) for x in nodelist] nodelist = [str(x) for x in nodelist]
nodelist.sort()
return ",".join(nodelist).__hash__() return ",".join(nodelist).__hash__()
...@@ -227,7 +230,7 @@ def compute_all_max_d_graphs(graph, debug=False): ...@@ -227,7 +230,7 @@ def compute_all_max_d_graphs(graph, debug=False):
neighbors = list(graph.neighbors(node)) neighbors = list(graph.neighbors(node))
neighbors_graph = nx.Graph(graph.subgraph(neighbors)) neighbors_graph = nx.Graph(graph.subgraph(neighbors))
node_d_graphs = [] node_d_graphs = set()
# Find all the cliques (equivalent to compute all the candidate half d-graph) # Find all the cliques (equivalent to compute all the candidate half d-graph)
cliques = list(nx.find_cliques(neighbors_graph)) cliques = list(nx.find_cliques(neighbors_graph))
...@@ -243,7 +246,7 @@ def compute_all_max_d_graphs(graph, debug=False): ...@@ -243,7 +246,7 @@ def compute_all_max_d_graphs(graph, debug=False):
if d_graph.get_link_divergence() > d_graph.get_optimal_score() / 2: if d_graph.get_link_divergence() > d_graph.get_optimal_score() / 2:
continue continue
node_d_graphs.append(d_graph) node_d_graphs.add(d_graph)
# Cut the the distribution queue # Cut the the distribution queue
...@@ -262,20 +265,24 @@ def compute_all_max_d_graphs(graph, debug=False): ...@@ -262,20 +265,24 @@ def compute_all_max_d_graphs(graph, debug=False):
""" """
def add_new_dg_regarding_domination(dg, undominated_dgs_list): def add_new_dg_regarding_domination(dg, undominated_dgs_list):
to_remove = [] to_remove = []
dominated = False
# Search for domination relations # Search for domination relations
for u_dg in undominated_dgs_list: for u_dg in undominated_dgs_list:
if len(to_remove) == 0 and dg.is_dominated(u_dg): if not dominated and dg.is_dominated(u_dg):
return undominated_dgs_list dominated = True
elif u_dg.is_dominated(dg): if u_dg.is_dominated(dg):
to_remove.append(u_dg) to_remove.append(u_dg)
# Remove dominated values # Remove dominated values
size = len(undominated_dgs_list)
for dg2 in to_remove: for dg2 in to_remove:
undominated_dgs_list.remove(dg2) undominated_dgs_list.remove(dg2)
#print(size, len(to_remove), len(undominated_dgs_list))
# Add the new dg # Add the new dg
undominated_dgs_list.append(dg) if not dominated:
undominated_dgs_list.append(dg)
return undominated_dgs_list return undominated_dgs_list
...@@ -289,9 +296,9 @@ def filter_dominated(d_graphs, overall=False, in_place=True): ...@@ -289,9 +296,9 @@ def filter_dominated(d_graphs, overall=False, in_place=True):
for dgs in d_graphs.values(): for dgs in d_graphs.values():
all_d_graphs.extend(dgs) all_d_graphs.extend(dgs)
print(len(all_d_graphs)) # print(len(all_d_graphs))
all_d_graphs = list_domination_filter(all_d_graphs) all_d_graphs = list_domination_filter(all_d_graphs)
print(len(all_d_graphs)) # print(len(all_d_graphs))
return d_graphs return d_graphs
...@@ -310,16 +317,16 @@ def local_domination_filter(d_graphs, in_place=True): ...@@ -310,16 +317,16 @@ def local_domination_filter(d_graphs, in_place=True):
# Filter node by node # Filter node by node
for node, d_graph_list in d_graphs.items(): for node, d_graph_list in d_graphs.items():
# Add the non filtered d-graph to the output # Add the non filtered d-graph to the output
filtered[node] = list_domination_filter(d_graph_list) filtered[node] = brutal_list_domination_filter(d_graph_list)
return filtered return filtered
""" Filter the input d-graphs list. In the list of d-graph centered on a node n, if a d-graph is """ Filter the input d-graphs list. In the list of d-graph centered on a node n, if a d-graph is
completly included in another and have a highest distance score to the optimal, then it is completely included in another and have a highest distance score to the optimal, then it is
filtered out. filtered out.
@param d_graphs All the d-graphs to filter. @param d_graphs All the d-graphs to filter.
@return The filtered dictionnary of d-graph per node. @return The filtered dictionary of d-graph per node.
""" """
def list_domination_filter(d_graphs): def list_domination_filter(d_graphs):
filtered = [] filtered = []
...@@ -328,4 +335,15 @@ def list_domination_filter(d_graphs): ...@@ -328,4 +335,15 @@ def list_domination_filter(d_graphs):
for dg in d_graphs: for dg in d_graphs:
add_new_dg_regarding_domination(dg, filtered) add_new_dg_regarding_domination(dg, filtered)
return filtered return set(filtered)
def brutal_list_domination_filter(d_graphs):
undominated = set(d_graphs)
for dg1 in d_graphs:
for dg2 in d_graphs:
if dg1.is_dominated(dg2):
undominated.remove(dg1)
break
return undominated
...@@ -8,17 +8,7 @@ def generate_d_graph_chain(size, d): ...@@ -8,17 +8,7 @@ def generate_d_graph_chain(size, d):
:param d The number of connection on the left and on the right for any node :param d The number of connection on the left and on the right for any node
:return The d-graph chain :return The d-graph chain
""" """
G = nx.Graph() return generate_approx_d_graph_chain(size, d, d)
for idx in range(size):
# Create the node
G.add_node(idx)
# Link the node to d previous nodes
for prev in range(max(0, idx-d), idx):
G.add_edge(prev, idx)
return G
def generate_approx_d_graph_chain(size, d_max, d_avg, size_reduction=0, rnd_seed=-1): def generate_approx_d_graph_chain(size, d_max, d_avg, size_reduction=0, rnd_seed=-1):
......
...@@ -60,7 +60,7 @@ class Solution(Path): ...@@ -60,7 +60,7 @@ class Solution(Path):
""" Only respect counts for now """ Only respect counts for now
""" """
def to_barcode_path(self): def to_barcode_path(self):
barcode_per_position = [set(udg.to_list()) for udg in self] barcode_per_position = [set(udg.to_sorted_list()) for udg in self]
compressed_barcodes = [] compressed_barcodes = []
for idx, barcodes in enumerate(barcode_per_position): for idx, barcodes in enumerate(barcode_per_position):
......
...@@ -28,8 +28,9 @@ class TestD2Graph(unittest.TestCase): ...@@ -28,8 +28,9 @@ class TestD2Graph(unittest.TestCase):
overlap_key = ('A1', 'A2', 'B0', 'B1', 'B2', 'C') overlap_key = ('A1', 'A2', 'B0', 'B1', 'B2', 'C')
for dmer, dg_lst in d2.index.items(): for dmer, dg_lst in d2.index.items():
if dmer == overlap_key: if dmer == overlap_key:
values = list(d2.index[dmer])
self.assertEqual(2, len(d2.index[dmer])) self.assertEqual(2, len(d2.index[dmer]))
self.assertNotEqual(d2.index[dmer][0], d2.index[dmer][1]) self.assertNotEqual(values[0], values[1])
else: else:
self.assertEqual(1, len(d2.index[dmer])) self.assertEqual(1, len(d2.index[dmer]))
...@@ -73,16 +74,6 @@ class TestD2Graph(unittest.TestCase): ...@@ -73,16 +74,6 @@ class TestD2Graph(unittest.TestCase):
awaited_dist = awaited_distances[dg1.center][dg2.center] awaited_dist = awaited_distances[dg1.center][dg2.center]
self.assertEqual(data["distance"], awaited_dist) self.assertEqual(data["distance"], awaited_dist)
# # distance tests
# for idx1, neighbors in d2.distances.items():
# dg1 = d2.node_by_idx[idx1]
# for idx2, dist in neighbors.items():
# dg2 = d2.node_by_idx[idx2]
# awaited_dist = awaited_distances[dg1.center][dg2.center]
# self.assertEqual(dist, awaited_dist)
def test_reloading(self): def test_reloading(self):
# Parameters # Parameters
d = 3 d = 3
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment