From 1ff783871fe4d9e38b039702e9b6073504d9e665 Mon Sep 17 00:00:00 2001
From: Yoann Dufresne <yoann.dufresne0@gmail.com>
Date: Tue, 5 May 2020 12:43:19 +0200
Subject: [PATCH] Cliques vs udgs experiments

---
 Snakefile_clique_experiments            | 56 ++++++++++++++++++---
 deconvolution/dgraph/CliqueDGFactory.py | 34 +++++++------
 deconvolution/dgraph/d_graph.py         | 60 +++++++++--------------
 experiments/clique_graph_eval.py        | 65 +++++++++++++++++++++----
 4 files changed, 145 insertions(+), 70 deletions(-)

diff --git a/Snakefile_clique_experiments b/Snakefile_clique_experiments
index 2c72059..86fde6d 100644
--- a/Snakefile_clique_experiments
+++ b/Snakefile_clique_experiments
@@ -1,22 +1,62 @@
+from progressbar import ProgressBar
+pbar = ProgressBar()
 
-OUTDIR="snake_exec" if "outdir" not in config else config["outdir"]
-N=[10000] if "n" not in config else config["n"] # Number of molecule to simulate
-D=[5] if "d" not in config else config["d"] # Average coverage of each molecule
-M=[2] if "m" not in config else config["m"] # Average number of molecule per barcode
-M_DEV=[0] if "m_dev" not in config else config["m_dev"] # Std deviation for merging number
+
+OUTDIR="snake_experiments" if "outdir" not in config else config["outdir"]
+trials=1
+N=[1000] if "n" not in config else config["n"] # Number of molecules to simulate
+D=[10] if "d" not in config else config["d"] # Average coverage of each molecule
+M=[2, 3] if "m" not in config else config["m"] # Average number of molecules per barcode
+# M_DEV=[0, 0.5, 1] if "m_dev" not in config else config["m_dev"] # Std deviation for merging number
 
 
 rule all:
     input:
-        expand(f"{OUTDIR}/simu_bar_n{{n}}_d{{d}}_m{{m}}-dev{{md}}.gexf", n=N, m=M, d=D, md=M_DEV)
+        f"{OUTDIR}/results.tsv"
+
+rule generate_tsv:
+    input:
+        expand(f"{OUTDIR}/simu_{{exp}}_bar_n{{n}}_d{{d}}_m{{m}}_results.txt", exp=list(range(trials)), n=N, m=M, d=D)  #, md=M_DEV)
+    output:
+        f"{OUTDIR}/results.tsv"
+    run:
+        with open(str(output), "w") as out:
+            out.write("nb_mols\tfusion\trun\ttp_cliques\tcliques\ttp_udgs\tudgs\n")
+            for file in input:
+                # Values extraction
+                tp_clqs=0
+                clqs=0
+                tp_udgs=0
+                udgs=0
+                with open(file) as inp:
+                    lines = inp.readlines()
+                    tp_clqs, clqs = [int(x) for x in lines[1].strip().split(" / ")]
+                    tp_udgs, udgs = [int(x) for x in lines[3].strip().split(" / ")]
+
+                # Get the important values from the filename
+                file = file.split("/")[-1]
+                names = file.split("_")
+                idx, nb_mols, fusion = int(names[1]), int(names[3][1:]), int(names[5][1:].split(".")[0])
+
+                out.write(f"{nb_mols}\t{fusion}\t{idx}\t{tp_clqs}\t{clqs}\t{tp_udgs}\t{udgs}\n")
+
+
+rule mesure_quality:
+    input:
+        f"{OUTDIR}/simu_{{exp}}_bar_n{{n}}_d{{d}}_m{{m}}.gexf"
+    output:
+        f"{OUTDIR}/simu_{{exp}}_bar_n{{n}}_d{{d}}_m{{m}}_results.txt"
+    shell:
+        "python3 experiments/clique_graph_eval.py {input} > {output}"
+
 
 rule generate_barcodes:
     input:
         "{path}/simu_mol_{params}.gexf"
     output:
-        "{path}/simu_bar_{params}_m{m}-dev{md}.gexf"
+        "{path}/simu_{idx}_bar_{params}_m{m}.gexf"
     shell:
-        "python3 deconvolution/main/generate_fake_barcode_graph.py --merging_depth {wildcards.m} --deviation {wildcards.md} --input_graph {input} --output {output}"
+        "python3 deconvolution/main/generate_fake_barcode_graph.py --merging_depth {wildcards.m} --input_graph {input} --output {output}"
 
 rule generate_molecules:
     output:
diff --git a/deconvolution/dgraph/CliqueDGFactory.py b/deconvolution/dgraph/CliqueDGFactory.py
index 9fd0564..a90135a 100644
--- a/deconvolution/dgraph/CliqueDGFactory.py
+++ b/deconvolution/dgraph/CliqueDGFactory.py
@@ -29,6 +29,7 @@ class CliqueDGFactory(AbstractDGFactory):
 
         # Clique computation
         cliques = []
+        clique_names = []
         clique_neighbors_multiset = []
         clique_neighbors_set = []
         clq_per_node = {node: [] for node in subgraph.nodes}
@@ -51,17 +52,11 @@ class CliqueDGFactory(AbstractDGFactory):
                 clique_neighbors_set.append(frozenset(ms))
                 idx += 1
 
-        # def clique_divergence(c1, c2):
-        #     # Observed link
-        #     nb_links = 0
-        #     for node in c1:
-        #         neighbors = clique_neighbors[node]
-        #
-        #     # Awaited links
-        #     d_approx = max(len(c1), len(c2))
-        #     awaited = d_approx * (d_approx - 1) / 2
-        #
-        #     return abs(awaited - nb_links)
+        if self.debug is not None:
+            for clique in cliques:
+                names = [str(n) for n in clique]
+                names.sort()
+                clique_names.append(f"[{','.join(names)}]")
 
         def clique_divergence(c1_idx, c1, c2):
             observed_link = len(c1 & c2)  # Intersections of the nodes are glued links
@@ -108,13 +103,9 @@ class CliqueDGFactory(AbstractDGFactory):
         for idx1, idx2 in clq_G.edges():
             clq_G.edges[idx1, idx2]['weight'] = max_div - clq_G.edges[idx1, idx2]['weight']
 
-        if self.debug and len(clq_G.nodes) > 0:
-            import os
-
-            nx.write_gexf(clq_G, f"{self.mwm_dir}/{central_node.replace('/', '-')}.gexf")
-
         # d-graph computation regarding max weight matching
         mwm = nx.algorithms.max_weight_matching(clq_G)
+        mwm_results = []
         for idx1, idx2 in mwm:
             # Get cliques
             clq1 = cliques[idx1]
@@ -124,4 +115,15 @@ class CliqueDGFactory(AbstractDGFactory):
             d_graph.put_halves(list(clq1), list(clq2), subgraph)
             node_d_graphs.add(d_graph)
 
+            if self.debug is not None:
+                mwm_results.append(" <-> ".join(sorted([clique_names[idx1], clique_names[idx2]])))
+
+        if self.debug and len(clq_G.nodes) > 0:
+            name_mapping = {idx:clique_names[idx] for idx in clq_G.nodes}
+            clq_G = nx.relabel_nodes(clq_G, name_mapping)
+            nx.write_gexf(clq_G, f"{self.mwm_dir}/{central_node.replace('/', '-')}.gexf")
+            with open(f"{self.mwm_dir}/{central_node.replace('/', '-')}_matching.txt", "w") as matching:
+                for result in mwm_results:
+                    matching.write(f"{result}\n")
+
         return node_d_graphs
diff --git a/deconvolution/dgraph/d_graph.py b/deconvolution/dgraph/d_graph.py
index 17ce191..865ac63 100644
--- a/deconvolution/dgraph/d_graph.py
+++ b/deconvolution/dgraph/d_graph.py
@@ -10,8 +10,8 @@ class Dgraph(object):
         self.idx = -1
         self.center = center
         self.score = 0
-        self.halves = [None,None]
-        self.connexity = [None,None]
+        self.halves = [[], []]
+        self.connexity = [[], []]
         self.nodes = [self.center]
         self.node_set = set(self.nodes)
         self.edges = []
@@ -20,7 +20,6 @@ class Dgraph(object):
 
         self.marked = False
 
-
     """ Static method to load a dgraph from a text
         @param text the saved d-graph
         @param barcode_graph Barcode graph from which the d-graph is extracted
@@ -46,7 +45,6 @@ class Dgraph(object):
 
         return dg
 
-
     """ Compute the d-graph quality (score) according to the connectivity between the two halves.
         @param h1 First half of the d-graph
         @param h2 Second half of the d-graph
@@ -92,22 +90,18 @@ class Dgraph(object):
         self.halves[0].sort(reverse=True, key=lambda v: connex[0][v])
         self.halves[1].sort(reverse=True, key=lambda v: connex[1][v])
 
-
     def get_link_divergence(self):
         return int(abs(self.score - self.get_optimal_score()))
 
-
     def get_optimal_score(self):
         max_len = max(len(self.halves[0]), len(self.halves[1]))
         return int(max_len * (max_len - 1) / 2)
 
-
     def to_sorted_list(self):
         if self.sorted_list is None:
             self.sorted_list = sorted(self.nodes)
         return self.sorted_list
 
-
     def to_ordered_lists(self):
         if self.ordered_list is None:
             hands = [[],[]]
@@ -124,10 +118,24 @@ class Dgraph(object):
             self.ordered_list = hands[0][::-1] + [[self.center]] + hands[1]
         return self.ordered_list
 
-
     def to_node_set(self):
         return frozenset(self.to_sorted_list())
 
+    def to_uniq_triplet(self):
+        """ Return the triplet (center, left_clique, right_clique) where the cliques are sorted by name.
+            The left clique is the one whose comma-joined node names compare lexicographically smaller or equal.
+        """
+        left = sorted(self.halves[0])
+        left_repr = ",".join(str(node) for node in left)
+        right = sorted(self.halves[1])
+        right_repr = ",".join(str(node) for node in right)
+
+        if left_repr > right_repr:
+            save = left
+            left = right
+            right = save
+
+        return self.center, left, right
 
     def distance_to(self, dgraph):
         nodes_1 = self.to_sorted_list()
@@ -149,7 +157,6 @@ class Dgraph(object):
 
         return dist
 
-
     """ Verify if dg1 is dominated by dg2. The domination is determined by two points: All the nodes
     of dg1 are part of dg2 and the divergeance of dg1 is greater than dg2.
     @param dg1 (resp dg2) A d_graph object.
@@ -172,7 +179,6 @@ class Dgraph(object):
 
         return False
 
-
     def __eq__(self, other):
         if other is None:
             return False
@@ -183,7 +189,7 @@ class Dgraph(object):
         if self.node_set != other.node_set:
             return False
 
-        return self.to_ordered_lists() == other.to_ordered_lists()
+        return self.to_uniq_triplet() == other.to_uniq_triplet()
 
     def __ne__(self, other):
         return not (self == other)
@@ -193,36 +199,16 @@ class Dgraph(object):
         other_tuple = (other.get_link_divergence(), other.get_optimal_score())
         return my_tuple < other_tuple
 
-
     def __hash__(self):
-        nodelist = self.to_sorted_list()
-        nodelist = [str(x) for x in nodelist]
-        return ",".join(nodelist).__hash__()
-
-
-    def __ordered_hash__(self):
-        lst = self.to_ordered_lists()
-
-        fwd_uniq_lst = [sorted(l) for l in lst]
-        fwd_str = ",".join([f"[{'-'.join(l)}]" for l in fwd_uniq_lst])
-        fwd_hash = fwd_str.__hash__()
-
-        rev_uniq_lst = [sorted(l) for l in lst[::-1]]
-        rev_str = ",".join([f"[{'-'.join(l)}]" for l in rev_uniq_lst])
-        rev_hash = rev_str.__hash__()
+        return str(self).__hash__()
 
-        return int(min(fwd_hash, rev_hash))
-
-
-    def __repr__(self):
+    def __full_repr__(self):
         # print(self.halves)
         representation = str(self.center) + " "
         representation += "[" + ", ".join([f"{node} {self.connexity[0][node]}" for node in self.halves[0]]) + "]"
         representation += "[" + ", ".join([f"{node} {self.connexity[1][node]}" for node in self.halves[1]]) + "]"
         return representation
 
-    def _to_str_nodes(self):
-        str_nodes = [str(x) for x in self.nodes]
-        str_nodes.sort()
-        return str(str_nodes)
-
+    def __repr__(self):
+        c, left, right = self.to_uniq_triplet()
+        return f"[{c}][{','.join(str(x) for x in left)}][{','.join(str(x) for x in right)}]"
diff --git a/experiments/clique_graph_eval.py b/experiments/clique_graph_eval.py
index 3b85627..3789a55 100644
--- a/experiments/clique_graph_eval.py
+++ b/experiments/clique_graph_eval.py
@@ -10,8 +10,9 @@ from deconvolution.dgraph.CliqueDGFactory import CliqueDGFactory
 def parse_arguments():
     parser = argparse.ArgumentParser(description="Tests on graph barcode")
     parser.add_argument('barcode_graph', help='The barcode graph file. Must be a gexf formatted file.')
-    parser.add_argument('--threads', '-t', type=int, help="Number of threads to use (Set 1 for profiling)")
+    parser.add_argument('--threads', '-t', type=int, default=1, help="Number of threads to use (Set 1 for profiling)")
     parser.add_argument('--verbose', '-v', action='store_true', help="Set the verbose flag")
+    parser.add_argument('--debug', '-d', action="store_true", help="Debug flag. Write the clique graph and the matching if true")
 
     args = parser.parse_args()
     return args
@@ -102,18 +103,64 @@ def analyse_d_graphs(barcode_graph, threads=8, verbose=False):
 def main():
     args = parse_arguments()
     g = nx.read_gexf(args.barcode_graph)
-    # prev_time = time.time()
-    # continuous, total = analyse_clique_graph(g)
-    # print("cliques", time.time() - prev_time)
-    # print(continuous, "/", total)
-    prev_time = time.time()
+    continuous, total = analyse_clique_graph(g)
+    print("cliques")
+    print(continuous, "/", total)
     continuous, total = analyse_d_graphs(g, threads=args.threads, verbose=args.verbose)
-    print("udgs", time.time() - prev_time)
+    print("udgs")
     print(continuous, "/", total)
 
 
+def main2():
+    args = parse_arguments()
+    g = nx.read_gexf(args.barcode_graph)
+    # Generate udgs
+    debug_path = f"/tmp/debug_{time.time()}"
+    if args.debug:
+        import os
+        os.mkdir(debug_path)
+    factory = CliqueDGFactory(g, 1, debug=args.debug, debug_path=debug_path)
+    udg_per_node = factory.generate_all_dgraphs(threads=args.threads, verbose=args.verbose)
+    # Remove duplicate udgs
+    udgs = set()
+    for udg_node_lst in udg_per_node.values():
+        for udg in udg_node_lst:
+            udgs.add(udg)
+
+    continuous = 0
+    for udg in udgs:
+        if is_continuous(iterable_to_barcode_multiset(udg.to_sorted_list())):
+            continuous += 1
+
+    # # Save the udgs
+    # with open("saved_udgs.txt", "w") as f:
+    #     for udg in udgs:
+    #         print(udg, file=f)
+    # Reload the previous udgs
+    prev_udgs = set()
+    udgs = set([str(x) for x in udgs])
+    udg_lst = []
+    with open("saved_udgs.txt") as f:
+        for line in f:
+            prev_udgs.add(line.strip())
+            udg_lst.append(line.strip())
+
+    prev_disapeared = prev_udgs - udgs
+    unknown_udgs = udgs - prev_udgs
+
+    first = None
+    for udg in udg_lst:
+        if udg in prev_disapeared:
+            first = udg
+            break
+
+    print("First missing", first)
+    print(len(prev_disapeared), prev_disapeared)
+    print()
+    print(len(unknown_udgs), unknown_udgs)
+
+    print(continuous, len(udgs))
+
 
 if __name__ == "__main__":
-    # import cProfile
-    # cProfile.run('main()')
     main()
-- 
GitLab