Commit 9a61b76f authored by Yoann Dufresne
Browse files

opti: divide the d2 computation time by several orders of magnitude

parent a7343c1f
...@@ -14,7 +14,9 @@ class Dgraph(object): ...@@ -14,7 +14,9 @@ class Dgraph(object):
self.halves = [None,None] self.halves = [None,None]
self.connexity = [None,None] self.connexity = [None,None]
self.nodes = [self.center] self.nodes = [self.center]
self.node_set = set(self.center)
self.edges = [] self.edges = []
self.ordered_list = None
""" Static method to load a dgraph from a text """ Static method to load a dgraph from a text
...@@ -51,7 +53,12 @@ class Dgraph(object): ...@@ -51,7 +53,12 @@ class Dgraph(object):
def put_halves(self, h1, h2, graph): def put_halves(self, h1, h2, graph):
self.score = 0 self.score = 0
self.halves[0] = h1 self.halves[0] = h1
for node in h1:
self.node_set.add(node)
self.halves[1] = h2 self.halves[1] = h2
for node in h2:
self.node_set.add(node)
self.nodes = sorted([self.center] + self.halves[0] + self.halves[1]) self.nodes = sorted([self.center] + self.halves[0] + self.halves[1])
self.connexity[0] = {key: 0 for key in self.halves[0]} self.connexity[0] = {key: 0 for key in self.halves[0]}
self.connexity[1] = {key: 0 for key in self.halves[1]} self.connexity[1] = {key: 0 for key in self.halves[1]}
...@@ -96,18 +103,20 @@ class Dgraph(object): ...@@ -96,18 +103,20 @@ class Dgraph(object):
def to_ordered_lists(self): def to_ordered_lists(self):
hands = [[],[]] if self.ordered_list is None:
for idx in range(2): hands = [[],[]]
prev_connectivity = -1 for idx in range(2):
for node in self.halves[idx]: prev_connectivity = -1
# group nodes by similar connectivity for node in self.halves[idx]:
value = self.connexity[idx][node] # group nodes by similar connectivity
if value != prev_connectivity: value = self.connexity[idx][node]
hands[idx].append([]) if value != prev_connectivity:
prev_connectivity = value hands[idx].append([])
hands[idx][-1].append(node) prev_connectivity = value
hands[idx][-1].append(node)
return hands[0][::-1] + [[self.center]] + hands[1]
self.ordered_list = hands[0][::-1] + [[self.center]] + hands[1]
return self.ordered_list
def to_node_multiset(self): def to_node_multiset(self):
...@@ -119,7 +128,7 @@ class Dgraph(object): ...@@ -119,7 +128,7 @@ class Dgraph(object):
dist = 0 dist = 0
idx1, idx2 = 0, 0 idx1, idx2 = 0, 0
while(idx1 != len(self.nodes) and idx2 != len(other_nodes)): while idx1 != len(self.nodes) and idx2 != len(other_nodes):
if self.nodes[idx1] == other_nodes[idx2]: if self.nodes[idx1] == other_nodes[idx2]:
idx1 += 1 idx1 += 1
idx2 += 1 idx2 += 1
...@@ -158,7 +167,13 @@ class Dgraph(object): ...@@ -158,7 +167,13 @@ class Dgraph(object):
def __eq__(self, other): def __eq__(self, other):
if other == None: if other is None:
return False
if self.idx == other.idx:
return True
if self.node_set != other.node_set:
return False return False
return self.to_ordered_lists() == other.to_ordered_lists() return self.to_ordered_lists() == other.to_ordered_lists()
......
...@@ -9,8 +9,9 @@ import graph_manipulator as gm ...@@ -9,8 +9,9 @@ import graph_manipulator as gm
def parse_arguments(): def parse_arguments():
parser = argparse.ArgumentParser(description='Transform a 10X molecule graph into a 10X barcode graph.') parser = argparse.ArgumentParser(description='Transform a 10X molecule graph into a 10X barcode graph.')
parser.add_argument('--merging_depth', '-m', type=int, required=True, help='Number of nodes to merge together') parser.add_argument('--merging_depth', '-m', type=int, required=True, help='Average number of nodes to merge together.')
parser.add_argument('--input_graph', '-i', required=True, help='A 10X molecule graph gexf formated.') parser.add_argument('--deviation', '-d', type=float, default=0.0, help='Standard deviation for the number of node to merge.')
parser.add_argument('--input_graph', '-i', required=True, help='A 10X molecule graph gexf formatted.')
parser.add_argument('--output', '-o', help="Output filename") parser.add_argument('--output', '-o', help="Output filename")
parser.add_argument('--random_seed', '-s', type=int, help="If you want to fix the random seed for reproducibility") parser.add_argument('--random_seed', '-s', type=int, help="If you want to fix the random seed for reproducibility")
...@@ -18,21 +19,24 @@ def parse_arguments(): ...@@ -18,21 +19,24 @@ def parse_arguments():
return args return args
""" Take a molecule d-graph chain and merge the nodes uniformly to obtain a barcode graph. def fusion_graph(G, merging_depth, std_dev=0):
@param G A molecule graph """ Take a molecule d-graph chain and merge the nodes to obtain a barcode graph.
@param merging_depth The number of nodes to merge from the original graph to obtain one node of the barcode graph :param G A molecule graph
@return The merged barcode graph :param merging_depth The average number of nodes to merge from the original graph to obtain one node of the barcode graph
""" :param std_dev the standard deviation to apply (0 = uniform merging)
def fusion_graph(G, merging_depth): :return The merged barcode graph
"""
nodes = list(G.nodes()) nodes = list(G.nodes())
random.shuffle(nodes) random.shuffle(nodes)
label = 0 label = 0
bijective_labels = {} bijective_labels = {}
idx=0
for idx in range(0, len(nodes), merging_depth): while idx < len(nodes):
merging_size = max(1, min(round(random.gauss(merging_depth, std_dev)), len(nodes) - idx))
# Extract values to merge # Extract values to merge
sublist = nodes[idx : idx+merging_depth] sublist = nodes[idx: idx+merging_size]
# Merge nodes # Merge nodes
merged = sublist[0] merged = sublist[0]
for sub_idx in range(1, len(sublist)): for sub_idx in range(1, len(sublist)):
...@@ -41,6 +45,7 @@ def fusion_graph(G, merging_depth): ...@@ -41,6 +45,7 @@ def fusion_graph(G, merging_depth):
# Label the node # Label the node
bijective_labels[merged] = f"{label}:{merged}" bijective_labels[merged] = f"{label}:{merged}"
label += 1 label += 1
idx += merging_size
# Relabel all the nodes # Relabel all the nodes
G = nx.relabel_nodes(G, bijective_labels) G = nx.relabel_nodes(G, bijective_labels)
...@@ -48,8 +53,8 @@ def fusion_graph(G, merging_depth): ...@@ -48,8 +53,8 @@ def fusion_graph(G, merging_depth):
return G return G
def save_graph(G, outfile): def save_graph(graph, filename):
nx.write_gexf(G, outfile) nx.write_gexf(graph, filename)
if __name__ == "__main__": if __name__ == "__main__":
...@@ -59,7 +64,7 @@ if __name__ == "__main__": ...@@ -59,7 +64,7 @@ if __name__ == "__main__":
random.seed(args.random_seed) random.seed(args.random_seed)
G = nx.read_gexf(args.input_graph) G = nx.read_gexf(args.input_graph)
G = fusion_graph(G, args.merging_depth) G = fusion_graph(G, args.merging_depth, args.deviation)
outfile = f"simulated_barcodes_{args.merging_depth}.gexf" outfile = f"simulated_barcodes_{args.merging_depth}.gexf"
if args.output: if args.output:
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment