Commit c1d36427 authored by Rayan  CHIKHI's avatar Rayan CHIKHI
Browse files

merge

parents 9e8cd8ae 1c653eb8
......@@ -25,6 +25,24 @@ class PartialOrder:
# This score must be updated when the order is modified
self.score = 0
self.debug_stack = []
def copy(self):
copy = PartialOrder()
for ms in self.barcode_order:
copy.barcode_order.append(ms.copy())
for udg in self.udg_order:
copy.udg_order.append(udg)
for udg_set in self.udg_per_set:
copy.udg_per_set.append(udg_set.copy())
copy.len_barcodes = self.len_barcodes
copy.len_sets = self.len_sets
copy.len_udgs = self.len_udgs
copy.score = self.score
return copy
def _get_right_overlaps(self, udg):
"""" Get the overlap of an udg with the right part of the multiset partial order.
:param udg: The udg to overlap
......@@ -36,25 +54,24 @@ class PartialOrder:
# Will look for full overlaps from right to left
while current_set_idx >= 0:
ms = self.barcode_order[current_set_idx]
if len(ms - remaining_barcodes) == 0 and ms != remaining_barcodes:
current_set_idx -= 1
remaining_barcodes -= ms
remaining_barcodes = remaining_barcodes - ms
elif len(ms & remaining_barcodes) == 0:
return current_set_idx+1, Counter(), remaining_barcodes
else:
# leftmost multiset , leftmost non overlapping, non overlapping barcodes
return current_set_idx, ms - remaining_barcodes, remaining_barcodes - ms
return -1, Counter(), remaining_barcodes
return 0, Counter(), remaining_barcodes
def add_right(self, udg):
save = self.copy()
self.udg_order.append(udg)
self.len_udgs += 1
# Empty case
if len(self) == 0:
self.barcode_order.append(Counter(udg.nodes))
self.udg_per_set.append({udg})
return
scores = [0, 0, 0]
# Step 1 - Determine overlapping multisets from right to left
leftmost_idx, left_non_overlap, new_multiset = self._get_right_overlaps(udg)
......@@ -70,20 +87,64 @@ class PartialOrder:
# Copy the previous overlapping udg set for the new multiset
self.udg_per_set.insert(leftmost_idx, self.udg_per_set[leftmost_idx-1].copy())
self.len_sets += 1
self.score += len(self.udg_per_set[leftmost_idx])
scores[0] += len(self.udg_per_set[leftmost_idx])
# Step 3 - Add the udg as covering the right multisets
for idx in range(max(0, leftmost_idx), self.len_sets):
self.udg_per_set[idx].add(udg)
self.score += 1
scores[1] += 1
# Step 3 - Add a new multiset on the right for the remaining barcodes
# Step 4 - Add a new multiset on the right for the remaining barcodes
if len(new_multiset) > 0:
self.barcode_order.append(new_multiset)
self.udg_per_set.append(set())
self.len_sets += 1
self.len_barcodes += sum(new_multiset.values())
self.len_sets += 1
self.score += 1
scores[2] += 1
self.udg_per_set.append({udg})
# Step 4 - Add the udg as covering the right multisets
for idx in range(max(0, leftmost_idx), self.len_sets):
self.udg_per_set[idx].add(udg)
self.debug_stack.append((udg, scores))
# TODO: Step 5 - Modify score
def add_right2(self, udg):
left_idx, leftmost_overlap, rightmost_overlap = self._get_right_overlaps2(udg)
def remove_right(self):
save = self.copy()
# Step 1 - Remove the udg
last_udg = self.udg_order.pop()
scores = [0]*3
last_debug, last_scores = self.debug_stack.pop()
self.len_udgs -= 1
# Step 2 - Remove the last multiset if only cover by last_udg
if len(self.udg_per_set[-1]) == 1:
self.udg_per_set.pop()
ms = self.barcode_order.pop()
self.len_barcodes -= sum(ms.values())
self.len_sets -= 1
scores[2] = -1
self.score -= 1
# Step 3 - Remove last_udg from coverings from right to left
rightmost_covered_idx = len(self.barcode_order) - 1
while rightmost_covered_idx >= 0 and last_udg in self.udg_per_set[rightmost_covered_idx]:
self.udg_per_set[rightmost_covered_idx].remove(last_udg)
self.score -= 1
scores[1] -= 1
rightmost_covered_idx -= 1
# Step 4 - Merge the two left sets of interest if they are identical
left_interest = rightmost_covered_idx
if 0 <= left_interest < len(self.udg_per_set) - 1:
# Check set similarity
if self.udg_per_set[left_interest] == self.udg_per_set[left_interest+1]:
sets = self.udg_per_set.pop(left_interest)
self.score -= len(sets)
scores[0] -= len(sets)
ms = self.barcode_order.pop(left_interest)
self.barcode_order[left_interest] = self.barcode_order[left_interest] + ms
self.len_sets -= 1
return last_udg
def get_add_score(self, udg):
score = 0
......@@ -100,27 +161,109 @@ class PartialOrder:
# covering number points for the new udg
score += self.len_sets - leftmost_idx
# Negative points for redundant elements
# shift one left
remaining_size = sum(remaining_right.values()) - sum(left_non_overlap.values())
leftmost_idx -= 1
return score
def reverse_order(self):
self.barcode_order = self.barcode_order[::-1]
self.udg_order = self.udg_order[::-1]
self.udg_per_set = self.udg_per_set[::-1]
# Search for non overlapped common barcodes
while remaining_size > 0 and leftmost_idx >= 0:
ms = self.barcode_order[leftmost_idx]
def __len__(self):
return self.len_barcodes
# Compute intersection
common = ms & remaining_right
candidate_negative = sum(common.values())
score -= min(candidate_negative, remaining_size)
# Update structures
remaining_right -= common
remaining_size -= sum(ms.values())
leftmost_idx -= 1
_predicted_score = 0
_saved_neighbors = {}
def _next_node(d2g, partial_order, node, used):
node = str(node)
# create neighborhood on the first call
if node not in _saved_neighbors:
_saved_neighbors[node] = {str(x) for x in d2g[node] if not used[str(x)]}
return score
neighbors = _saved_neighbors[node]
# Return None is no usable neighbor
if len(neighbors) == 0:
return None
max_score = 0
max_neighbor_name = None
for neighbor_name in neighbors:
neighbor_udg = d2g.node_by_idx[int(neighbor_name)]
neighbor_score = partial_order.get_add_score(neighbor_udg)
if neighbor_score > max_score:
max_score = neighbor_score
max_neighbor_name = neighbor_name
if max_neighbor_name is None:
return None
else:
neighbors.discard(max_neighbor_name)
global _predicted_score
_predicted_score = max_score
return max_neighbor_name
def greedy_partial_order(d2g, node):
used_nodes = {str(n): False for n in d2g.nodes()}
used_nodes[str(node)] = True
current_node = node
current_udg = d2g.node_by_idx[int(node)]
po = PartialOrder()
po.add_right(current_udg)
forward = True
reverse = True
while forward or reverse:
next_node_name = _next_node(d2g, po, str(current_node), used_nodes)
if next_node_name is not None:
next_udg = d2g.node_by_idx[int(next_node_name)]
po.add_right(next_udg)
used_nodes[next_node_name] = True
current_node = next_node_name
else:
if forward:
forward = False
po.reverse_order()
current_node = str(po.udg_order[-1].idx)
else:
reverse = False
return po
def bb_partial_order(d2g, node):
used_nodes = {str(n): False for n in d2g.nodes()}
used_nodes[str(node)] = True
current_node_name = str(node)
current_udg = d2g.node_by_idx[int(node)]
po = PartialOrder()
po.add_right(current_udg)
can_continue = True
while can_continue:
next_node_name = _next_node(d2g, po, current_node_name, used_nodes)
# We found a new deeper solution
if next_node_name is not None:
next_udg = d2g.node_by_idx[int(next_node_name)]
_score = po.score
global _predicted_score
po.add_right(next_udg)
_score = po.score - _score
used_nodes[next_node_name] = True
current_node_name = next_node_name
# All the possible solutions have been explored
elif len(po) == 0:
can_continue = False
# We are in a dead end, must go back one step
else:
yield po.copy()
back_udg = po.remove_right()
used_nodes[current_node_name] = False
del _saved_neighbors[current_node_name]
current_node_name = str(back_udg.idx)
def __len__(self):
return self.len_barcodes
#!/usr/bin/env python3
import networkx as nx
import argparse
import sys
import random
from deconvolution.d2graph import d2_graph as d2
from barcodes.partialorder import greedy_partial_order, bb_partial_order
def parse_arguments():
parser = argparse.ArgumentParser(description='Greedy construction of a path through the d2 graph.')
parser.add_argument('barcode_graph', help='The barcode graph file. Must be a gefx formatted file.')
parser.add_argument('d2_graph', help='d2 graph to reduce. Must be a gexf formatted file.')
parser.add_argument('--out_prefix', '-o', default="", help="Output file prefix.")
args = parser.parse_args()
if args.out_prefix == "":
args.out_prefix = '.'.join(args.d2_graph.split('.')[:-1])
return args
def main():
# Parsing the arguments and validate them
args = parse_arguments()
barcode_file = args.barcode_graph
d2_file = args.d2_graph
if (not barcode_file.endswith('.gexf')) or (not d2_file.endswith(".gexf")):
print("Inputs file must be gexf formatted", file=sys.stderr)
exit(1)
# Loading
G = nx.read_gexf(barcode_file)
d2g = d2.D2Graph(G)
d2g.load(d2_file)
# Take the principal component
largest_component_nodes = max(nx.connected_components(d2g), key=len)
largest_component = d2g.subgraph(largest_component_nodes)
all_nodes = list(largest_component.nodes())
rnd_node = all_nodes[random.randint(0, len(all_nodes)-1)]
# po = greedy_partial_order(largest_component, rnd_node)
for po in bb_partial_order(largest_component, rnd_node):
print("barcodes", len(po), "sets", po.len_sets, "udgs", po.len_udgs, "score", po.score)
if __name__ == "__main__":
main()
......@@ -2,10 +2,10 @@
import sys
sys.setrecursionlimit(10000)
import argparse
from termcolor import colored
import networkx as nx
sys.setrecursionlimit(10000)
def parse_args():
......@@ -16,6 +16,7 @@ def parse_args():
help="Define the data type to evaluate. Must be 'd2' or 'path' or 'd2-2annotate' (Rayan's hack).")
parser.add_argument('--light-print', '-l', action='store_true',
help='Print only wrong nodes and paths')
parser.add_argument('--max_gap', '-g', type=int, default=0, help="Allow to jump over max_gap nodes during the increasing path search")
parser.add_argument('--barcode_graph', '-b', help="Path to the barcode graph corresponding to the d2_graph to analyse.")
parser.add_argument('--optimization_file', '-o',
help="If the main file is a d2, a file formatted for optimization can be set. This file will be used to compute the coverage of the longest path on the barcode graph.")
......@@ -55,8 +56,8 @@ def parse_udg_qualities(graph):
""" Compute the quality for the best udgs present in the graph.
All the node names must be under the format :
{idx}:{mol1_id}_{mol2_id}_...{molx_id}.other_things_here
:param graph: The networkx graph representinf the deconvolved graph
:return: A tuple containing two dictionaries. The first one with theoritical frequencies of each node, the second one with observed frequencies.
:param graph: The networkx graph representing the deconvolved graph
:return: A tuple containing two dictionaries. The first one with theoretical frequencies of each node, the second one with observed frequencies.
"""
dg_per_node = {}
......@@ -80,8 +81,7 @@ def parse_path_graph_frequencies(graph, barcode_graph):
All the node names must be under the format :
{idx}:{mol1_id}_{mol2_id}_...{molx_id}.other_things_here
:param graph: The networkx graph representing the deconvolved graph
:param only_wong: If True, don't print correct nodes
:param file_pointer: Where to print the output. If set to stdout, then pretty print. If set to None, don't print anything.
:param barcode_graph: The barcode graph
:return: A tuple containing two dictionaries. The first one with theoretical frequencies of each node, the second one with observed frequencies.
"""
# Compute origin nodes formatted as `{idx}:{mol1_id}_{mol2_id}_...`
......@@ -110,10 +110,10 @@ def parse_path_graph_frequencies(graph, barcode_graph):
return real_frequencies, observed_frequencies, node_per_barcode
""" This function aims to look for direct molecule neighbors.
If a node has more than 2 direct neighbors, it's not rightly splitted
"""
def parse_graph_path(graph):
""" This function aims to look for direct molecule neighbors.
If a node has more than 2 direct neighbors, it's not rightly split
"""
neighborhood = {}
for node in graph.nodes():
......@@ -250,10 +250,25 @@ def print_d2_summary(connected_components, longest_path, coverage_vars=(0, 0), l
print(f"Number of usable coverage variables: {len(coverage_vars[1])}")
print(f"Coverage: {len(coverage_vars[0])}/{len(coverage_vars[1])}")
print(f"Missing coverage variables:\n{coverage_vars[1]-coverage_vars[0]}")
if not light_print:
print(f"Missing coverage variables:\n{coverage_vars[1]-coverage_vars[0]}")
def _get_distant_neighbors(graph, node, dist):
neighbors = set()
to_compute = [node]
for _ in range(dist):
next_compute = []
for node in to_compute:
for neighbor in graph[node]:
if neighbor not in neighbors:
neighbors.add(neighbor)
next_compute.append(neighbor)
to_compute = next_compute
return neighbors
def compute_next_nodes(d2_component):
def compute_next_nodes(d2_component, max_jumps=0):
# First parse dg names
dg_names = {}
for node in d2_component.nodes():
......@@ -272,7 +287,8 @@ def compute_next_nodes(d2_component):
for mol_idx in molecule_idxs:
nexts = []
for neighbor in d2_component[node]:
# for neighbor in d2_component[node]:
for neighbor in _get_distant_neighbors(d2_component, node, max_jumps+1):
# nei_head: central node of the neighbor of 'node'
nei_head, _, _ = dg_names[neighbor]
nei_mols = mols_from_node(nei_head[1])
......@@ -292,15 +308,15 @@ def compute_next_nodes(d2_component):
return next_nodes
def compute_longest_increasing_paths(d2_component):
next_nodes = compute_next_nodes(d2_component)
def compute_longest_increasing_paths(d2_component, max_gap=0):
next_nodes = compute_next_nodes(d2_component, max_jumps=max_gap)
# Compute the longest path for each node
longest_paths = {}
for idx, start_node in enumerate(next_nodes):
# print(f"{idx}/{len(next_nodes)}")
for mol_idx in next_nodes[start_node]:
recursive_longest_path(start_node, mol_idx , next_nodes, longest_paths)
recursive_longest_path(start_node, mol_idx, next_nodes, longest_paths)
test_node = '5339'
for mol in longest_paths[test_node]:
......@@ -323,7 +339,7 @@ def compute_longest_increasing_paths(d2_component):
def backtrack_longest_path(node, molecule, longest_paths, path=[]):
if node == None:
if node is None:
return path
path.append((molecule, node))
......@@ -360,6 +376,40 @@ def recursive_longest_path(current_node, current_molecule, next_nodes, longest_p
return longest_paths[current_node][current_molecule]
# def longest_common_subsequence(barcode_true_path, barcoded_graph):
# """ Assume that the two graphs have an attribute barcode for each node and a unique node name"""
# path_nodes = []
# path_nodes_barcodes = []
# for node, data in barcode_true_path.nodes(data=True):
# path_nodes.append(node)
# path_nodes_barcodes.append(data["barcode"])
# path_nodes_to_idx = {n: idx for idx, n in enumerate(path_nodes)}
#
# graph_nodes = []
# graph_nodes_barcodes = []
# for node, data in barcoded_graph.nodes(data=True):
# graph_nodes.append(node)
# graph_nodes_barcodes.append(data["barcode"])
# graph_nodes_to_idx = {n: idx for idx, n in enumerate(graph_nodes)}
#
# dynamic_array = [[0 for _ in range(len(graph_nodes)+1)] for _ in range(len(path_nodes)+1)]
# for row in range(1, len(path_nodes)+1):
# path_node = path_nodes[row-1]
# path_barcode = path_nodes_barcodes[row-1]
#
# for column in range(1, len(graph_nodes)):
# graph_node = graph_nodes[column-1]
# graph_barcode = graph_nodes_barcodes[column-1]
#
# prev_scores = [dynamic_array[row-1][column]]
# for neighbor_node in barcoded_graph[graph_node]:
# neighbor_idx = graph_nodes_to_idx[neighbor_node]
# prev_scores.append(dynamic_array[row-1][neighbor_idx])
#
# match_point = 1 if path_barcode == graph_barcode else 0
# dynamic_array[row][column] = max(prev_scores) + match_point
def compute_covered_variables(graph, path):
path_nodes = set()
for mol, node_name in path:
......@@ -375,7 +425,7 @@ def compute_covered_variables(graph, path):
return used_vars, total_vars
# returns True iff there exist x in mol1 such that there exists y in mol2 and |x-y| <= some_value
# returns True iff there exist x in mol1 such that there exists y in mol2 and |x-y| <= some_value
def nearby_udg_molecules(mols1, mols2):
for x in mols1:
for y in mols2:
......@@ -390,19 +440,19 @@ def verify_graph_edges(d2_component):
head, c1, c2 = parse_dg_name(d2_component,node)
# Construct the molecule(s) that this udg really 'reflects'
# i.e. the udg has a central node and two cliques
# that central node is the result of merging of several molecules
# ideally, only one of those molecules is connected to the molecules of the cliques
# (there could be more than one though; in that case the udg is 'ambiguous')
# udg_molecules aims to reflect the molecule(s) underlying this udg
# i.e. the udg has a central node and two cliques
# that central node is the result of merging of several molecules
# ideally, only one of those molecules is connected to the molecules of the cliques
# (there could be more than one though; in that case the udg is 'ambiguous')
# udg_molecules aims to reflect the molecule(s) underlying this udg
udg_molecules = set()
# Get the current molecule idxs of central node
molecule_idxs = mols_from_node(head[1])
#print("mol idxs", molecule_idxs)
# print("mol idxs", molecule_idxs)
# Examine molecule idx's of cliques to see which are close to the central node
# rationale: c1/c2 contain nearby molecule id's
# rationale: c1/c2 contain nearby molecule id's
for mol_idx in molecule_idxs:
nexts = []
for c in [c1,c2]:
......@@ -417,9 +467,9 @@ def verify_graph_edges(d2_component):
nexts.sort(key=lambda x: x)
quality = sum([1.0/x if mol_idx+x in nexts else 0 for x in range(1,6)]) / sum([1.0/x for x in range(1,6)])
if quality > 0.6: eyeballed but still arbitrary
if quality > 0.6: # eyeballed but still arbitrary
udg_molecules.add(mol_idx)
#print("mol",mol_idx,molecule_idxs,"quality",quality,"nexts",nexts)
# print("mol",mol_idx,molecule_idxs,"quality",quality,"nexts",nexts)
udg_molecules_dict[head[0]]=udg_molecules
......@@ -437,7 +487,7 @@ def verify_graph_edges(d2_component):
else:
color = 'red'
data['color'] = color
#print("edge",node_udg_molecules,neighbor_udg_molecules,color)
# print("edge",node_udg_molecules,neighbor_udg_molecules,color)
# also, annotate nodes by their putative molecule found
for n, data in d2_component.nodes(data=True):
......@@ -458,7 +508,7 @@ def verify_graph_edges(d2_component):
if "_" in data['udg_molecule'] or data['udg_molecule'] == '':
if "_" in data['udg_molecule']:
m1, m2 = list(map(int,data['udg_molecule'].split("_")))
if abs(m2-m1) < 30: continue don't remove that kind of nodes
if abs(m2-m1) < 30: continue # don't remove that kind of nodes
nodes_to_remove += [n]
d2_component.remove_nodes_from(nodes_to_remove)
print("removed",len(nodes_to_remove),"bad nodes")
......@@ -505,7 +555,7 @@ def main():
components.sort(key=lambda x: -len(x))
component = graph.subgraph(components[0])
longest_path = compute_longest_increasing_paths(component)
longest_path = compute_longest_increasing_paths(component, max_gap=args.max_gap)
vars = compute_covered_variables(graph, longest_path)
print_d2_summary(components, longest_path, coverage_vars=vars, light_print=args.light_print)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment