Commit 9a61b76f authored by Yoann Dufresne's avatar Yoann Dufresne
Browse files

opti: divide the d2 computation time by several order of magnitude

parent a7343c1f
......@@ -14,7 +14,9 @@ class Dgraph(object):
self.halves = [None,None]
self.connexity = [None,None]
self.nodes = [self.center]
self.node_set = set(self.center)
self.edges = []
self.ordered_list = None
""" Static method to load a dgraph from a text
......@@ -51,7 +53,12 @@ class Dgraph(object):
def put_halves(self, h1, h2, graph):
self.score = 0
self.halves[0] = h1
for node in h1:
self.node_set.add(node)
self.halves[1] = h2
for node in h2:
self.node_set.add(node)
self.nodes = sorted([self.center] + self.halves[0] + self.halves[1])
self.connexity[0] = {key: 0 for key in self.halves[0]}
self.connexity[1] = {key: 0 for key in self.halves[1]}
......@@ -96,6 +103,7 @@ class Dgraph(object):
def to_ordered_lists(self):
if self.ordered_list is None:
hands = [[],[]]
for idx in range(2):
prev_connectivity = -1
......@@ -107,7 +115,8 @@ class Dgraph(object):
prev_connectivity = value
hands[idx][-1].append(node)
return hands[0][::-1] + [[self.center]] + hands[1]
self.ordered_list = hands[0][::-1] + [[self.center]] + hands[1]
return self.ordered_list
def to_node_multiset(self):
......@@ -119,7 +128,7 @@ class Dgraph(object):
dist = 0
idx1, idx2 = 0, 0
while(idx1 != len(self.nodes) and idx2 != len(other_nodes)):
while idx1 != len(self.nodes) and idx2 != len(other_nodes):
if self.nodes[idx1] == other_nodes[idx2]:
idx1 += 1
idx2 += 1
......@@ -158,7 +167,13 @@ class Dgraph(object):
def __eq__(self, other):
if other == None:
if other is None:
return False
if self.idx == other.idx:
return True
if self.node_set != other.node_set:
return False
return self.to_ordered_lists() == other.to_ordered_lists()
......
......@@ -9,8 +9,9 @@ import graph_manipulator as gm
def parse_arguments():
parser = argparse.ArgumentParser(description='Transform a 10X molecule graph into a 10X barcode graph.')
parser.add_argument('--merging_depth', '-m', type=int, required=True, help='Number of nodes to merge together')
parser.add_argument('--input_graph', '-i', required=True, help='A 10X molecule graph gexf formated.')
parser.add_argument('--merging_depth', '-m', type=int, required=True, help='Average number of nodes to merge together.')
parser.add_argument('--deviation', '-d', type=float, default=0.0, help='Standard deviation for the number of node to merge.')
parser.add_argument('--input_graph', '-i', required=True, help='A 10X molecule graph gexf formatted.')
parser.add_argument('--output', '-o', help="Output filename")
parser.add_argument('--random_seed', '-s', type=int, help="If you want to fix the random seed for reproducibility")
......@@ -18,21 +19,24 @@ def parse_arguments():
return args
""" Take a molecule d-graph chain and merge the nodes uniformly to obtain a barcode graph.
@param G A molecule graph
@param merging_depth The number of nodes to merge from the original graph to obtain one node of the barcode graph
@return The merged barcode graph
"""
def fusion_graph(G, merging_depth):
def fusion_graph(G, merging_depth, std_dev=0):
""" Take a molecule d-graph chain and merge the nodes to obtain a barcode graph.
:param G A molecule graph
:param merging_depth The average number of nodes to merge from the original graph to obtain one node of the barcode graph
:param std_dev the standard deviation to apply (0 = uniform merging)
:return The merged barcode graph
"""
nodes = list(G.nodes())
random.shuffle(nodes)
label = 0
bijective_labels = {}
idx=0
for idx in range(0, len(nodes), merging_depth):
while idx < len(nodes):
merging_size = max(1, min(round(random.gauss(merging_depth, std_dev)), len(nodes) - idx))
# Extract values to merge
sublist = nodes[idx : idx+merging_depth]
sublist = nodes[idx: idx+merging_size]
# Merge nodes
merged = sublist[0]
for sub_idx in range(1, len(sublist)):
......@@ -41,6 +45,7 @@ def fusion_graph(G, merging_depth):
# Label the node
bijective_labels[merged] = f"{label}:{merged}"
label += 1
idx += merging_size
# Relabel all the nodes
G = nx.relabel_nodes(G, bijective_labels)
......@@ -48,8 +53,8 @@ def fusion_graph(G, merging_depth):
return G
def save_graph(G, outfile):
nx.write_gexf(G, outfile)
def save_graph(graph, filename):
nx.write_gexf(graph, filename)
if __name__ == "__main__":
......@@ -59,7 +64,7 @@ if __name__ == "__main__":
random.seed(args.random_seed)
G = nx.read_gexf(args.input_graph)
G = fusion_graph(G, args.merging_depth)
G = fusion_graph(G, args.merging_depth, args.deviation)
outfile = f"simulated_barcodes_{args.merging_depth}.gexf"
if args.output:
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment