Commit 1ff78387 authored by Yoann Dufresne
Browse files

Cliques vs udgs experiments

parent a304859c
from progressbar import ProgressBar
pbar = ProgressBar()
OUTDIR="snake_exec" if "outdir" not in config else config["outdir"]
N=[10000] if "n" not in config else config["n"] # Number of molecule to simulate OUTDIR="snake_experiments" if "outdir" not in config else config["outdir"]
D=[5] if "d" not in config else config["d"] # Average coverage of each molecule trials=1
M=[2] if "m" not in config else config["m"] # Average number of molecule per barcode N=[1000] if "n" not in config else config["n"] # Number of molecule to simulate
M_DEV=[0] if "m_dev" not in config else config["m_dev"] # Std deviation for merging number D=[10] if "d" not in config else config["d"] # Average coverage of each molecule
M=[2, 3] if "m" not in config else config["m"] # Average number of molecule per barcode
# M_DEV=[0, 0.5, 1] if "m_dev" not in config else config["m_dev"] # Std deviation for merging number
rule all: rule all:
input: input:
expand(f"{OUTDIR}/simu_bar_n{{n}}_d{{d}}_m{{m}}-dev{{md}}.gexf", n=N, m=M, d=D, md=M_DEV) f"{OUTDIR}/results.tsv"
# Gather every per-experiment evaluation file into a single tab-separated
# table with one row per (number of molecules, fusion level, run index).
rule generate_tsv:
    input:
        expand(f"{OUTDIR}/simu_{{exp}}_bar_n{{n}}_d{{d}}_m{{m}}_results.txt", exp=list(range(trials)), n=N, m=M, d=D) #, md=M_DEV)
    output:
        f"{OUTDIR}/results.tsv"
    run:
        with open(str(output), "w") as out:
            out.write("nb_mols\tfusion\trun\ttp_cliques\tcliques\ttp_udgs\tudgs\n")
            for result_path in input:
                # Values extraction: the 2nd and 4th lines of a result file
                # are parsed as "<true positives> / <total>" integer pairs,
                # the first pair for the cliques and the second for the udgs.
                with open(result_path) as inp:
                    lines = inp.readlines()
                tp_clqs, clqs = [int(x) for x in lines[1].strip().split(" / ")]
                tp_udgs, udgs = [int(x) for x in lines[3].strip().split(" / ")]

                # Get the important values from the file name, which follows
                # the pattern simu_{exp}_bar_n{n}_d{d}_m{m}_results.txt
                names = result_path.split("/")[-1].split("_")
                idx, nb_mols, fusion = int(names[1]), int(names[3][1:]), int(names[5][1:].split(".")[0])

                out.write(f"{nb_mols}\t{fusion}\t{idx}\t{tp_clqs}\t{clqs}\t{tp_udgs}\t{udgs}\n")
# Evaluate one simulated barcode graph: run experiments/clique_graph_eval.py
# on the .gexf file and redirect its standard output to the results file.
rule mesure_quality:
    input:
        f"{OUTDIR}/simu_{{exp}}_bar_n{{n}}_d{{d}}_m{{m}}.gexf"
    output:
        f"{OUTDIR}/simu_{{exp}}_bar_n{{n}}_d{{d}}_m{{m}}_results.txt"
    shell:
        "python3 experiments/clique_graph_eval.py {input} > {output}"
rule generate_barcodes: rule generate_barcodes:
input: input:
"{path}/simu_mol_{params}.gexf" "{path}/simu_mol_{params}.gexf"
output: output:
"{path}/simu_bar_{params}_m{m}-dev{md}.gexf" "{path}/simu_{idx}_bar_{params}_m{m}.gexf"
shell: shell:
"python3 deconvolution/main/generate_fake_barcode_graph.py --merging_depth {wildcards.m} --deviation {wildcards.md} --input_graph {input} --output {output}" "python3 deconvolution/main/generate_fake_barcode_graph.py --merging_depth {wildcards.m} --input_graph {input} --output {output}"
rule generate_molecules: rule generate_molecules:
output: output:
......
...@@ -29,6 +29,7 @@ class CliqueDGFactory(AbstractDGFactory): ...@@ -29,6 +29,7 @@ class CliqueDGFactory(AbstractDGFactory):
# Clique computation # Clique computation
cliques = [] cliques = []
clique_names = []
clique_neighbors_multiset = [] clique_neighbors_multiset = []
clique_neighbors_set = [] clique_neighbors_set = []
clq_per_node = {node: [] for node in subgraph.nodes} clq_per_node = {node: [] for node in subgraph.nodes}
...@@ -51,17 +52,11 @@ class CliqueDGFactory(AbstractDGFactory): ...@@ -51,17 +52,11 @@ class CliqueDGFactory(AbstractDGFactory):
clique_neighbors_set.append(frozenset(ms)) clique_neighbors_set.append(frozenset(ms))
idx += 1 idx += 1
# def clique_divergence(c1, c2): if self.debug is not None:
# # Observed link for clique in cliques:
# nb_links = 0 names = [str(n) for n in clique]
# for node in c1: names.sort()
# neighbors = clique_neighbors[node] clique_names.append(f"[{','.join(names)}]")
#
# # Awaited links
# d_approx = max(len(c1), len(c2))
# awaited = d_approx * (d_approx - 1) / 2
#
# return abs(awaited - nb_links)
def clique_divergence(c1_idx, c1, c2): def clique_divergence(c1_idx, c1, c2):
observed_link = len(c1 & c2) # Intersections of the nodes are glued links observed_link = len(c1 & c2) # Intersections of the nodes are glued links
...@@ -108,13 +103,9 @@ class CliqueDGFactory(AbstractDGFactory): ...@@ -108,13 +103,9 @@ class CliqueDGFactory(AbstractDGFactory):
for idx1, idx2 in clq_G.edges(): for idx1, idx2 in clq_G.edges():
clq_G.edges[idx1, idx2]['weight'] = max_div - clq_G.edges[idx1, idx2]['weight'] clq_G.edges[idx1, idx2]['weight'] = max_div - clq_G.edges[idx1, idx2]['weight']
if self.debug and len(clq_G.nodes) > 0:
import os
nx.write_gexf(clq_G, f"{self.mwm_dir}/{central_node.replace('/', '-')}.gexf")
# d-graph computation regarding max weight matching # d-graph computation regarding max weight matching
mwm = nx.algorithms.max_weight_matching(clq_G) mwm = nx.algorithms.max_weight_matching(clq_G)
mwm_results = []
for idx1, idx2 in mwm: for idx1, idx2 in mwm:
# Get cliques # Get cliques
clq1 = cliques[idx1] clq1 = cliques[idx1]
...@@ -124,4 +115,15 @@ class CliqueDGFactory(AbstractDGFactory): ...@@ -124,4 +115,15 @@ class CliqueDGFactory(AbstractDGFactory):
d_graph.put_halves(list(clq1), list(clq2), subgraph) d_graph.put_halves(list(clq1), list(clq2), subgraph)
node_d_graphs.add(d_graph) node_d_graphs.add(d_graph)
if self.debug is not None:
mwm_results.append(" <-> ".join(sorted([clique_names[idx1], clique_names[idx2]])))
if self.debug and len(clq_G.nodes) > 0:
name_mapping = {idx:clique_names[idx] for idx in clq_G.nodes}
clq_G = nx.relabel_nodes(clq_G, name_mapping)
nx.write_gexf(clq_G, f"{self.mwm_dir}/{central_node.replace('/', '-')}.gexf")
with open(f"{self.mwm_dir}/{central_node.replace('/', '-')}_matching.txt", "w") as matching:
for result in mwm_results:
matching.write(f"{result}\n")
return node_d_graphs return node_d_graphs
...@@ -10,8 +10,8 @@ class Dgraph(object): ...@@ -10,8 +10,8 @@ class Dgraph(object):
self.idx = -1 self.idx = -1
self.center = center self.center = center
self.score = 0 self.score = 0
self.halves = [None,None] self.halves = [[], []]
self.connexity = [None,None] self.connexity = [[], []]
self.nodes = [self.center] self.nodes = [self.center]
self.node_set = set(self.nodes) self.node_set = set(self.nodes)
self.edges = [] self.edges = []
...@@ -20,7 +20,6 @@ class Dgraph(object): ...@@ -20,7 +20,6 @@ class Dgraph(object):
self.marked = False self.marked = False
""" Static method to load a dgraph from a text """ Static method to load a dgraph from a text
@param text the saved d-graph @param text the saved d-graph
@param barcode_graph Barcode graph from which the d-graph is extracted @param barcode_graph Barcode graph from which the d-graph is extracted
...@@ -46,7 +45,6 @@ class Dgraph(object): ...@@ -46,7 +45,6 @@ class Dgraph(object):
return dg return dg
""" Compute the d-graph quality (score) according to the connectivity between the two halves. """ Compute the d-graph quality (score) according to the connectivity between the two halves.
@param h1 First half of the d-graph @param h1 First half of the d-graph
@param h2 Second half of the d-graph @param h2 Second half of the d-graph
...@@ -92,22 +90,18 @@ class Dgraph(object): ...@@ -92,22 +90,18 @@ class Dgraph(object):
self.halves[0].sort(reverse=True, key=lambda v: connex[0][v]) self.halves[0].sort(reverse=True, key=lambda v: connex[0][v])
self.halves[1].sort(reverse=True, key=lambda v: connex[1][v]) self.halves[1].sort(reverse=True, key=lambda v: connex[1][v])
def get_link_divergence(self): def get_link_divergence(self):
return int(abs(self.score - self.get_optimal_score())) return int(abs(self.score - self.get_optimal_score()))
def get_optimal_score(self): def get_optimal_score(self):
max_len = max(len(self.halves[0]), len(self.halves[1])) max_len = max(len(self.halves[0]), len(self.halves[1]))
return int(max_len * (max_len - 1) / 2) return int(max_len * (max_len - 1) / 2)
def to_sorted_list(self): def to_sorted_list(self):
if self.sorted_list is None: if self.sorted_list is None:
self.sorted_list = sorted(self.nodes) self.sorted_list = sorted(self.nodes)
return self.sorted_list return self.sorted_list
def to_ordered_lists(self): def to_ordered_lists(self):
if self.ordered_list is None: if self.ordered_list is None:
hands = [[],[]] hands = [[],[]]
...@@ -124,10 +118,24 @@ class Dgraph(object): ...@@ -124,10 +118,24 @@ class Dgraph(object):
self.ordered_list = hands[0][::-1] + [[self.center]] + hands[1] self.ordered_list = hands[0][::-1] + [[self.center]] + hands[1]
return self.ordered_list return self.ordered_list
def to_node_set(self): def to_node_set(self):
return frozenset(self.to_sorted_list()) return frozenset(self.to_sorted_list())
def to_uniq_triplet(self):
    """ Return the triplet (center, left_clique, right_clique) where each clique is sorted by name.
        The left clique is the one whose comma-joined representation is lexicographically
        smaller than (or equal to) the right one, so the result does not depend on the
        order in which the two halves are stored.
        @return A tuple (center, left, right); left and right are new sorted lists,
        self.halves is left untouched.
    """
    left = sorted(self.halves[0])
    right = sorted(self.halves[1])

    # The halves are ordered by their textual representation, not by list comparison.
    left_repr = ",".join(str(node) for node in left)
    right_repr = ",".join(str(node) for node in right)
    if left_repr > right_repr:
        left, right = right, left

    return self.center, left, right
def distance_to(self, dgraph): def distance_to(self, dgraph):
nodes_1 = self.to_sorted_list() nodes_1 = self.to_sorted_list()
...@@ -149,7 +157,6 @@ class Dgraph(object): ...@@ -149,7 +157,6 @@ class Dgraph(object):
return dist return dist
""" Verify if dg1 is dominated by dg2. The domination is determined by two points: All the nodes """ Verify if dg1 is dominated by dg2. The domination is determined by two points: All the nodes
of dg1 are part of dg2 and the divergeance of dg1 is greater than dg2. of dg1 are part of dg2 and the divergeance of dg1 is greater than dg2.
@param dg1 (resp dg2) A d_graph object. @param dg1 (resp dg2) A d_graph object.
...@@ -172,7 +179,6 @@ class Dgraph(object): ...@@ -172,7 +179,6 @@ class Dgraph(object):
return False return False
def __eq__(self, other): def __eq__(self, other):
if other is None: if other is None:
return False return False
...@@ -183,7 +189,7 @@ class Dgraph(object): ...@@ -183,7 +189,7 @@ class Dgraph(object):
if self.node_set != other.node_set: if self.node_set != other.node_set:
return False return False
return self.to_ordered_lists() == other.to_ordered_lists() return self.to_uniq_triplet() == other.to_uniq_triplet()
def __ne__(self, other): def __ne__(self, other):
return not (self == other) return not (self == other)
...@@ -193,36 +199,16 @@ class Dgraph(object): ...@@ -193,36 +199,16 @@ class Dgraph(object):
other_tuple = (other.get_link_divergence(), other.get_optimal_score()) other_tuple = (other.get_link_divergence(), other.get_optimal_score())
return my_tuple < other_tuple return my_tuple < other_tuple
def __hash__(self): def __hash__(self):
nodelist = self.to_sorted_list() return str(self).__hash__()
nodelist = [str(x) for x in nodelist]
return ",".join(nodelist).__hash__()
def __ordered_hash__(self):
lst = self.to_ordered_lists()
fwd_uniq_lst = [sorted(l) for l in lst]
fwd_str = ",".join([f"[{'-'.join(l)}]" for l in fwd_uniq_lst])
fwd_hash = fwd_str.__hash__()
rev_uniq_lst = [sorted(l) for l in lst[::-1]]
rev_str = ",".join([f"[{'-'.join(l)}]" for l in rev_uniq_lst])
rev_hash = rev_str.__hash__()
return int(min(fwd_hash, rev_hash)) def __full_repr__(self):
def __repr__(self):
# print(self.halves) # print(self.halves)
representation = str(self.center) + " " representation = str(self.center) + " "
representation += "[" + ", ".join([f"{node} {self.connexity[0][node]}" for node in self.halves[0]]) + "]" representation += "[" + ", ".join([f"{node} {self.connexity[0][node]}" for node in self.halves[0]]) + "]"
representation += "[" + ", ".join([f"{node} {self.connexity[1][node]}" for node in self.halves[1]]) + "]" representation += "[" + ", ".join([f"{node} {self.connexity[1][node]}" for node in self.halves[1]]) + "]"
return representation return representation
def _to_str_nodes(self): def __repr__(self):
str_nodes = [str(x) for x in self.nodes] c, left, right = self.to_uniq_triplet()
str_nodes.sort() return f"[{c}][{','.join(str(x) for x in left)}][{','.join(str(x) for x in right)}]"
return str(str_nodes)
...@@ -10,8 +10,9 @@ from deconvolution.dgraph.CliqueDGFactory import CliqueDGFactory ...@@ -10,8 +10,9 @@ from deconvolution.dgraph.CliqueDGFactory import CliqueDGFactory
def parse_arguments(): def parse_arguments():
parser = argparse.ArgumentParser(description="Tests on graph barcode") parser = argparse.ArgumentParser(description="Tests on graph barcode")
parser.add_argument('barcode_graph', help='The barcode graph file. Must be a gexf formatted file.') parser.add_argument('barcode_graph', help='The barcode graph file. Must be a gexf formatted file.')
parser.add_argument('--threads', '-t', type=int, help="Number of threads to use (Set 1 for profiling)") parser.add_argument('--threads', '-t', type=int, default=1, help="Number of threads to use (Set 1 for profiling)")
parser.add_argument('--verbose', '-v', action='store_true', help="Set the verbose flag") parser.add_argument('--verbose', '-v', action='store_true', help="Set the verbose flag")
parser.add_argument('--debug', '-d', action="store_true", help="Debug flag. Write the clique graph and the matching if true")
args = parser.parse_args() args = parser.parse_args()
return args return args
...@@ -102,18 +103,64 @@ def analyse_d_graphs(barcode_graph, threads=8, verbose=False): ...@@ -102,18 +103,64 @@ def analyse_d_graphs(barcode_graph, threads=8, verbose=False):
def main(): def main():
args = parse_arguments() args = parse_arguments()
g = nx.read_gexf(args.barcode_graph) g = nx.read_gexf(args.barcode_graph)
# prev_time = time.time() continuous, total = analyse_clique_graph(g)
# continuous, total = analyse_clique_graph(g) print("cliques")
# print("cliques", time.time() - prev_time) print(continuous, "/", total)
# print(continuous, "/", total)
prev_time = time.time()
continuous, total = analyse_d_graphs(g, threads=args.threads, verbose=args.verbose) continuous, total = analyse_d_graphs(g, threads=args.threads, verbose=args.verbose)
print("udgs", time.time() - prev_time) print("udgs")
print(continuous, "/", total) print(continuous, "/", total)
def main2():
    """ Debugging entry point: compute the udgs of a barcode graph and compare them
        with the udgs saved by a previous run.
        The previous udgs are read from "saved_udgs.txt" (one textual udg per line) in the
        current working directory. The function prints the first saved udg that is now
        missing, the saved udgs that disappeared, the udgs that were not saved before,
        and finally the number of continuous udgs next to the number of distinct udgs.
    """
    args = parse_arguments()
    g = nx.read_gexf(args.barcode_graph)

    # Generate udgs
    debug_path = f"/tmp/debug_{time.time()}"
    if args.debug:
        import os
        os.mkdir(debug_path)
    factory = CliqueDGFactory(g, 1, debug=args.debug, debug_path=debug_path)
    udg_per_node = factory.generate_all_dgraphs(threads=args.threads, verbose=args.verbose)

    # Remove duplicate udgs
    udgs = set()
    for udg_node_lst in udg_per_node.values():
        udgs.update(udg_node_lst)

    continuous = sum(
        1 for udg in udgs
        if is_continuous(iterable_to_barcode_multiset(udg.to_sorted_list()))
    )

    # Uncomment to (re)generate the reference file that is read below.
    # # Save the udgs
    # with open("saved_udgs.txt", "w") as f:
    #     for udg in udgs:
    #         print(udg, file=f)

    # The comparison is done on the textual form of the udgs
    udg_reprs = {str(x) for x in udgs}

    # Reload the previous udgs; the list keeps the file order, the set gives fast lookups
    with open("saved_udgs.txt") as f:
        udg_lst = [line.strip() for line in f]
    prev_udgs = set(udg_lst)

    prev_disapeared = prev_udgs - udg_reprs
    unknown_udgs = udg_reprs - prev_udgs

    # First saved udg, in file order, that is no longer generated (None if all are still there)
    first = next((udg for udg in udg_lst if udg in prev_disapeared), None)

    print("First missing", first)
    print(len(prev_disapeared), prev_disapeared)
    print()
    print(len(unknown_udgs), unknown_udgs)

    print(continuous, len(udg_reprs))
if __name__ == "__main__": if __name__ == "__main__":
# import cProfile
# cProfile.run('main()')
main() main()
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment