Commit 0ff34ba2 authored by Yoann Dufresne

parallelism upgrade for d2 construction

parent 0f0814d4
......@@ -46,8 +46,9 @@ rule d2_generation:
barcode_graph=f"{WORKDIR}/{{file}}.gexf"
output:
d2_file=f"{WORKDIR}/{{file}}_d2_raw_{{method}}.gexf"
threads: workflow.cores
run:
shell(f"python3 deconvolution/main/to_d2_graph.py {{input.barcode_graph}} --{{wildcards.method}} -o {WORKDIR}/{{wildcards.file}}_d2_raw_{{wildcards.method}}")
shell(f"python3 deconvolution/main/to_d2_graph.py {{input.barcode_graph}} --{{wildcards.method}} -t {{threads}} -o {WORKDIR}/{{wildcards.file}}_d2_raw_{{wildcards.method}}")
rule setup_workdir:
......
import networkx as nx
from abc import abstractmethod
from multiprocessing import Pool
from multiprocessing import Pool, Value
from deconvolution.dgraph.FixedDGIndex import FixedDGIndex
counter = None
locked = False
nb_over = 0
def process_node(factory, node):
    """Compute the d-graphs for one node and report progress.

    Used as the per-node worker passed to ``Pool.starmap``.

    :param factory: object exposing ``graph``, ``generate_by_node``,
        ``verbose`` and ``nb_nodes``
    :param node: node of ``factory.graph`` to analyse
    :return: the tuple ``(node, dgs)`` where ``dgs`` is whatever
        ``factory.generate_by_node`` returned for that node
    """
    neighbors = list(factory.graph.neighbors(node))
    # Build an independent graph restricted to the neighbourhood of the node.
    subgraph = nx.Graph(factory.graph.subgraph(neighbors))
    dgs = factory.generate_by_node(node, subgraph)

    # The module-level ``counter`` is a shared multiprocessing Value created
    # by the parent process. It can still be None in a worker that did not
    # inherit it (NOTE(review): e.g. non-fork start methods -- confirm); in
    # that case skip the progress report instead of crashing on ``None.value``.
    if counter is not None:
        # ``counter.value += 1`` is a read-modify-write and is not atomic
        # across processes: hold the Value's own lock while updating it.
        with counter.get_lock():
            counter.value += 1
            done = counter.value
        if factory.verbose:
            print(f"\r{done}/{factory.nb_nodes} node analysis", end='')

    return node, dgs
......@@ -27,20 +23,22 @@ class AbstractDGFactory:
self.graph = graph
self.nb_nodes = len(self.graph.nodes())
self.verbose = False
global counter
counter = Value('i', 0)
def generate_all_dgraphs(self, verbose=False, threads=8):
index = FixedDGIndex(size=1)
factory = self
global nb_over
nb_over = 0
nb_nodes = len(self.graph.nodes())
self.verbose = verbose
global counter
counter = Value('i', 0)
results = None
with Pool(processes=threads) as pool:
results = pool.starmap(process_node, zip(
[factory]*nb_nodes,
self.graph.nodes(),
self.graph.nodes()
))
# Fill the index by node
......@@ -49,7 +47,7 @@ class AbstractDGFactory:
for dg in dgs:
index.add_value(key, dg)
index.filter_by_entry()
# index.filter_by_entry()
return index
......
......@@ -41,16 +41,13 @@ class AbstractDGIndex(dict):
pass
def _filter_entry(self, key_set):
@staticmethod
def filter_entry(dgs):
""" For one entry in the index, filter out dominated dgraphs
:param key_set: The entry to filter
:param dgs: d-graph set
:return: filtered set
"""
# Verify presence in the index
if key_set not in self:
raise KeyError("The set is not present in the index")
# n² filtering
dgs = self[key_set]
to_remove = set()
for dg1 in dgs:
......@@ -59,13 +56,12 @@ class AbstractDGIndex(dict):
to_remove.add(dg1)
break
self[key_set] = dgs - to_remove
return to_remove
return dgs - to_remove
def filter_by_entry(self):
    """Filter every entry of the index independently.

    Each stored d-graph set is replaced by the result of the static
    ``AbstractDGIndex.filter_entry`` applied to that set.
    """
    # Only values are reassigned (no key is added or removed), so writing to
    # the dict while iterating over its keys is safe.
    for key_set in self:
        self[key_set] = AbstractDGIndex.filter_entry(self[key_set])
    # TODO: remove globally ?
......
......@@ -2,6 +2,8 @@ import networkx as nx
from deconvolution.dgraph.AbstractDGFactory import AbstractDGFactory
from deconvolution.dgraph.d_graph import Dgraph
from deconvolution.dgraph import AbstractDGIndex
class CliqueDGFactory(AbstractDGFactory):
......@@ -75,4 +77,6 @@ class CliqueDGFactory(AbstractDGFactory):
d_graph.put_halves(clq1, clq2, subgraph)
node_d_graphs.add(d_graph)
# Filter dominated
node_d_graphs = AbstractDGIndex.filter_entry(node_d_graphs)
return node_d_graphs
......@@ -90,7 +90,8 @@ def parse_path_graph_frequencies(graph, barcode_graph):
node_per_barcode = {}
for node, data in graph.nodes(data=True):
origin_name = data["center"]
parsed = parse_dg_name(graph, node)
origin_name = parsed[0][1]
if origin_name not in node_per_barcode:
node_per_barcode[origin_name] = []
......
......@@ -11,7 +11,7 @@ def parse_arguments():
parser = argparse.ArgumentParser(description='Transform a 10X barcode graph into a d2 graph. The program dig for the d-graphs and then merge them into a d2-graph.')
parser.add_argument('barcode_graph', help='The barcode graph file. Must be a gefx formated file.')
parser.add_argument('--output_prefix', '-o', default="d2_graph", help="Output file prefix.")
parser.add_argument('--threads', '-t', default=1, type=int, help='Number of thread to use for dgraph computation')
parser.add_argument('--threads', '-t', default=8, type=int, help='Number of thread to use for dgraph computation')
# parser.add_argument('--debug', '-d', action='store_true', help="Debug")
parser.add_argument('--maxclq', '-c', action='store_true', help="Enable max clique community detection (default behaviour)")
parser.add_argument('--louvain', '-l', action='store_true', help="Enable Louvain community detection instead of all max-cliques")
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment