Commit 061f4092 authored by Yoann Dufresne
Browse files

optimization for divergence computation

parent 498df34a
...@@ -2,6 +2,7 @@ data/ ...@@ -2,6 +2,7 @@ data/
real_data/ real_data/
art/ art/
**/__pycache__/ **/__pycache__/
*.pyc
.pytest_cache/ .pytest_cache/
.idea/ .idea/
.snakemake/ .snakemake/
......
...@@ -62,7 +62,7 @@ class D2Graph(nx.Graph): ...@@ -62,7 +62,7 @@ class D2Graph(nx.Graph):
return self.subgraph(list(self.nodes())) return self.subgraph(list(self.nodes()))
def construct_from_barcodes(self, neighbor_threshold=0.25, verbose=True, clique_mode=None, threads=1): def construct_from_barcodes(self, neighbor_threshold=0.25, min_size_clique=4, verbose=False, clique_mode=None, threads=1):
# Compute all the d-graphs # Compute all the d-graphs
if verbose: if verbose:
print("Computing the unit d-graphs..") print("Computing the unit d-graphs..")
...@@ -70,8 +70,8 @@ class D2Graph(nx.Graph): ...@@ -70,8 +70,8 @@ class D2Graph(nx.Graph):
if clique_mode == "louvain": if clique_mode == "louvain":
dg_factory = LouvainDGFactory(self.barcode_graph) dg_factory = LouvainDGFactory(self.barcode_graph)
else: else:
dg_factory = CliqueDGFactory(self.barcode_graph, debug=self.debug, debug_path=self.debug_path) dg_factory = CliqueDGFactory(self.barcode_graph, min_size_clique=min_size_clique, debug=self.debug, debug_path=self.debug_path)
self.d_graphs_per_node = dg_factory.generate_all_dgraphs(threads=threads, verbose=True) self.d_graphs_per_node = dg_factory.generate_all_dgraphs(threads=threads, verbose=verbose)
if verbose: if verbose:
counts = sum(len(x) for x in self.d_graphs_per_node.values()) counts = sum(len(x) for x in self.d_graphs_per_node.values())
print(f"\t {counts} computed d-graphs") print(f"\t {counts} computed d-graphs")
......
...@@ -17,7 +17,7 @@ def process_node(factory, node): ...@@ -17,7 +17,7 @@ def process_node(factory, node):
sys.stdout.flush() sys.stdout.flush()
# udg generation # udg generation
neighbors = list(factory.graph.neighbors(node)) neighbors = [x for x in factory.graph[node]]
subgraph = nx.Graph(factory.graph.subgraph(neighbors)) subgraph = nx.Graph(factory.graph.subgraph(neighbors))
dgs = factory.generate_by_node(node, subgraph) dgs = factory.generate_by_node(node, subgraph)
...@@ -56,18 +56,25 @@ class AbstractDGFactory: ...@@ -56,18 +56,25 @@ class AbstractDGFactory:
if verbose: if verbose:
print("Start parallel work") print("Start parallel work")
results = None if threads > 1:
with Pool(processes=threads) as pool: results = None
results = pool.starmap(process_node, zip( with Pool(processes=threads) as pool:
[factory]*nb_nodes, results = pool.starmap(process_node, zip(
self.graph.nodes() [factory]*nb_nodes,
)) self.graph.nodes()
))
# Fill the index by node
for node, dgs in results: # Fill the index by node
key = frozenset({node}) for node, dgs in results:
for dg in dgs: key = frozenset({node})
index.add_value(key, dg) for dg in dgs:
index.add_value(key, dg)
else:
for node in self.graph.nodes():
key = frozenset({node})
_, dgs = process_node(factory, node)
for dg in dgs:
index.add_value(key, dg)
return index return index
...@@ -75,5 +82,3 @@ class AbstractDGFactory: ...@@ -75,5 +82,3 @@ class AbstractDGFactory:
@abstractmethod @abstractmethod
def generate_by_node(self, node, subgraph): def generate_by_node(self, node, subgraph):
pass pass
import networkx as nx import networkx as nx
from collections import Counter
from deconvolution.dgraph.AbstractDGFactory import AbstractDGFactory from deconvolution.dgraph.AbstractDGFactory import AbstractDGFactory
from deconvolution.dgraph.d_graph import Dgraph from deconvolution.dgraph.d_graph import Dgraph
...@@ -24,60 +25,86 @@ class CliqueDGFactory(AbstractDGFactory): ...@@ -24,60 +25,86 @@ class CliqueDGFactory(AbstractDGFactory):
def generate_by_node(self, central_node, subgraph): def generate_by_node(self, central_node, subgraph):
node_d_graphs = set() node_d_graphs = set()
node_neighbors = {node: [x for x in subgraph[node]] for node in subgraph.nodes}
# Clique computation # Clique computation
cliques = [] cliques = []
clique_neighbors_multiset = []
clique_neighbors_set = []
clq_per_node = {node: [] for node in subgraph.nodes}
idx = 0
for clique in nx.find_cliques(subgraph): for clique in nx.find_cliques(subgraph):
if len(clique) >= self.min_size: if len(clique) >= self.min_size:
cliques.append(clique) # Create the clique set
clique_set = frozenset(clique)
# index cliques per node cliques.append(clique_set)
clq_per_node = {}
for idx, clq in enumerate(cliques): # Index neighbors of the clique to speedup the divergence computation
for node in clq: neighbors = []
if node not in clq_per_node: for node in clique:
clq_per_node[node] = [] # index clique per node. Useful ?
clq_per_node[node].append(idx) clq_per_node[node].append(idx)
# Prepare a neighbor multiset to speedup the divergence computation
# clq_per_node for nei can be empty because of minimum clique size neighbors.extend(node_neighbors[node])
for node in subgraph.nodes: ms = Counter(neighbors)
if node not in clq_per_node: clique_neighbors_multiset.append(ms)
clq_per_node[node] = [] clique_neighbors_set.append(frozenset(ms))
idx += 1
# def clique_divergence(c1, c2):
# # Observed link
# nb_links = 0
# for node in c1:
# neighbors = clique_neighbors[node]
#
# # Awaited links
# d_approx = max(len(c1), len(c2))
# awaited = d_approx * (d_approx - 1) / 2
#
# return abs(awaited - nb_links)
def clique_divergence(c1_idx, c1, c2):
observed_link = len(c1 & c2) # Intersections of the nodes are glued links
neighbor_intersection = clique_neighbors_set[c1_idx] & c2
neighbors_multiset = clique_neighbors_multiset[c1_idx]
observed_link += sum(neighbors_multiset[x] for x in neighbor_intersection) # Sum the links between the cliques
# Awaited links
d_approx = max(len(c1), len(c2))
awaited = d_approx * (d_approx - 1) / 2
return abs(awaited - observed_link)
# Pair cliques # Pair cliques
clq_pairs = set() def enumerate_clique_pair():
for idx, clq in enumerate(cliques): for clq_idx, clq in enumerate(cliques):
for node in clq: for node in clq:
neighbors = list(subgraph.neighbors(node)) # Looks into the neighbors for clique pairing
# Looks into the neighbors for clique pairing for nei in subgraph[node]:
for nei in neighbors: nei_clqs = clq_per_node[nei]
nei_clqs = clq_per_node[nei] # Pair useful cliques
# Pair useful cliques for nei_clq in nei_clqs:
for nei_clq in nei_clqs: if nei_clq > clq_idx:
if nei_clq > idx: div_clq = clique_divergence(clq_idx, clq, cliques[nei_clq])
clq_pairs.add((idx, nei_clq)) yield clq_idx, nei_clq, div_clq
# Create the clique graph for max weight # Create the clique graph for max weight
clq_G = nx.Graph() clq_G = nx.Graph()
# Create nodes # Create nodes
for idx in range(len(cliques)): for idx in range(len(cliques)):
clq_G.add_node(idx) clq_G.add_node(idx)
# Create edges # Create edges
max_div = 0 max_div = 0
for idx1, idx2 in clq_pairs: for idx1, idx2, div in enumerate_clique_pair():
# Get cliques # Get cliques
clq1 = cliques[idx1] clq1 = cliques[idx1]
clq2 = cliques[idx2] clq2 = cliques[idx2]
# Create candidate udg
d_graph = Dgraph(central_node)
d_graph.put_halves(clq1, clq2, subgraph)
# Add divergence to the clique graph
div = d_graph.get_link_divergence()
if div > max_div: if div > max_div:
max_div = div max_div = div
clq_G.add_edge(idx1, idx2, weight=div) clq_G.add_edge(idx1, idx2, weight=div)
# Normalize the divergence # Normalize the divergence
for idx1, idx2 in clq_pairs: for idx1, idx2 in clq_G.edges():
clq_G.edges[idx1, idx2]['weight'] = max_div - clq_G.edges[idx1, idx2]['weight'] clq_G.edges[idx1, idx2]['weight'] = max_div - clq_G.edges[idx1, idx2]['weight']
if self.debug and len(clq_G.nodes) > 0: if self.debug and len(clq_G.nodes) > 0:
...@@ -93,7 +120,7 @@ class CliqueDGFactory(AbstractDGFactory): ...@@ -93,7 +120,7 @@ class CliqueDGFactory(AbstractDGFactory):
clq2 = cliques[idx2] clq2 = cliques[idx2]
# Create candidate udg # Create candidate udg
d_graph = Dgraph(central_node) d_graph = Dgraph(central_node)
d_graph.put_halves(clq1, clq2, subgraph) d_graph.put_halves(list(clq1), list(clq2), subgraph)
node_d_graphs.add(d_graph) node_d_graphs.add(d_graph)
return node_d_graphs return node_d_graphs
import argparse import argparse
import time
import networkx as nx import networkx as nx
from collections import Counter from collections import Counter
...@@ -79,7 +80,7 @@ def analyse_clique_graph(barcode_graph): ...@@ -79,7 +80,7 @@ def analyse_clique_graph(barcode_graph):
def analyse_d_graphs(barcode_graph): def analyse_d_graphs(barcode_graph):
# Generate udgs # Generate udgs
factory = CliqueDGFactory(barcode_graph, 1) factory = CliqueDGFactory(barcode_graph, 1)
udg_per_node = factory.generate_all_dgraphs() udg_per_node = factory.generate_all_dgraphs(threads=1)
# Remove duplicate udgs # Remove duplicate udgs
udgs = {} udgs = {}
for udg_node_lst in udg_per_node.values(): for udg_node_lst in udg_per_node.values():
...@@ -99,14 +100,18 @@ def analyse_d_graphs(barcode_graph): ...@@ -99,14 +100,18 @@ def analyse_d_graphs(barcode_graph):
def main(): def main():
args = parse_arguments() args = parse_arguments()
g = nx.read_gexf(args.barcode_graph) g = nx.read_gexf(args.barcode_graph)
prev_time = time.time()
continuous, total = analyse_clique_graph(g) continuous, total = analyse_clique_graph(g)
print("cliques") print("cliques", time.time() - prev_time)
print(continuous, "/", total) print(continuous, "/", total)
prev_time = time.time()
continuous, total = analyse_d_graphs(g) continuous, total = analyse_d_graphs(g)
print("udgs") print("udgs", time.time() - prev_time)
print(continuous, "/", total) print(continuous, "/", total)
if __name__ == "__main__": if __name__ == "__main__":
# import cProfile
# cProfile.run('main()')
main() main()
...@@ -4,7 +4,7 @@ from distutils.core import setup ...@@ -4,7 +4,7 @@ from distutils.core import setup
setup( setup(
name='10X-deconvolve', name='10X-deconvolve',
version='0.1dev', version='0.1dev',
packages=['deconvolution.d2graph', 'deconvolution.dgraph', 'deconvolution.main', 'experiments'], packages=['deconvolution.d2graph', 'deconvolution.dgraph', 'deconvolution.main', 'experiments', 'tests'],
license='AGPL V3', license='AGPL V3',
long_description=open('README.md').read(), long_description=open('README.md').read(),
) )
This source diff could not be displayed because it is too large. You can view the blob instead.
...@@ -8,84 +8,65 @@ from deconvolution.dgraph import graph_manipulator as gm ...@@ -8,84 +8,65 @@ from deconvolution.dgraph import graph_manipulator as gm
class TestD2Graph(unittest.TestCase): class TestD2Graph(unittest.TestCase):
# def test_construction(self):
# d2 = D2Graph(complete_graph) # def test_linear_d2_construction(self):
# d2.construct_from_barcodes(index_size=4, verbose=False) # for d in range(2, 10):
# size = 2 * d + 3
#
# G = gm.generate_d_graph_chain(size, d)
# d2 = D2Graph(G)
# print("before", d)
# d2.construct_from_barcodes(neighbor_threshold=0, min_size_clique=d, verbose=False)
# print("after", d)
#
# # for dg in d2.all_d_graphs:
# # print(dg.score, dg.get_link_divergence(), dg)
# # print()
#
# # Test the number of d-graphs
# awaited_d_num = size - 2 * d
# self.assertEqual(awaited_d_num, len(d2.all_d_graphs))
# #
# # Evaluate the number of candidate unit d_graphs generated # # Test connectivity
# for node, candidates in d2.d_graphs_per_node.items(): # # Center node names
# if node == "C" or node == "B2": # c1 = d
# self.assertEqual(1, len(candidates)) # c2 = d + 1
# else: # c3 = d + 2
# self.assertEqual(0, len(candidates)) # # Connectivity matrix
# awaited_distances = {
# c1: {c2: 2, c3: 4},
# c2: {c1: 2, c3: 2},
# c3: {c1: 4, c2: 2}
# }
# #
# # Evaluate the dgraph # for x, y, data in d2.edges(data=True):
# self.assertEqual(13, len(d2.index)) # dg1 = d2.node_by_idx[x]
# dg2 = d2.node_by_idx[y]
# #
# overlap_key = ('A1', 'A2', 'B0', 'B1', 'B2', 'C') # awaited_dist = awaited_distances[dg1.center][dg2.center]
# for dmer, dg_lst in d2.index.items(): # self.assertEqual(data["distance"], awaited_dist)
# if dmer == overlap_key:
# values = list(d2.index[dmer]) def test_no_variability(self):
# self.assertEqual(2, len(d2.index[dmer])) barcode_graph = nx.read_gexf("test_data/bar_1000_5_2.gexf")
# self.assertNotEqual(values[0], values[1]) d2 = D2Graph(barcode_graph)
# else: d2.construct_from_barcodes()
# self.assertEqual(1, len(d2.index[dmer])) udgs = d2.all_d_graphs
def test_linear_d2_construction(self): for _ in range(5):
for d in range(1, 10): d2 = D2Graph(barcode_graph)
size = 2 * d + 3 d2.construct_from_barcodes()
index_k = 2 * d - 1 self.assertEqual(len(udgs), len(d2.all_d_graphs))
G = gm.generate_d_graph_chain(size, d)
d2 = D2Graph(G)
d2.construct_from_barcodes(index_size=index_k, verbose=False)
# for dg in d2.all_d_graphs:
# print(dg.score, dg.get_link_divergence(), dg)
# print()
# Test the number of d-graphs
awaited_d_num = size - 2 * d
self.assertEqual(awaited_d_num, len(d2.all_d_graphs))
# Test dgraph
awaited_index_size = comb(2 * d + 1, index_k) + (size - (2 * d + 1)) * comb(2 * d, index_k - 1)
if len(d2.index) != awaited_index_size:
dmers = [list(x) for x in d2.index]
dmers = [str(x) for x in dmers if len(x) != len(frozenset(x))]
self.assertEqual(awaited_index_size, len(d2.index))
# Test connectivity
# Center node names
c1 = d
c2 = d + 1
c3 = d + 2
# Connectivity matrix
awaited_distances = {
c1: {c2: 2, c3: 4},
c2: {c1: 2, c3: 2},
c3: {c1: 4, c2: 2}
}
for x, y, data in d2.edges(data=True):
dg1 = d2.node_by_idx[x]
dg2 = d2.node_by_idx[y]
awaited_dist = awaited_distances[dg1.center][dg2.center]
self.assertEqual(data["distance"], awaited_dist)
def test_reloading(self): def test_reloading(self):
# Parameters # Parameters
d = 3 d = 3
size = 2 * d + 3 size = 2 * d + 3
index_k = 2 * d - 1
# Create a d2 graph # Create a d2 graph
G = gm.generate_d_graph_chain(size, d) G = gm.generate_d_graph_chain(size, d)
d2 = D2Graph(G) d2 = D2Graph(G)
d2.construct_from_barcodes(index_size=index_k, verbose=False) d2.construct_from_barcodes(verbose=False)
# Save and reload the d2 in a temporary file # Save and reload the d2 in a temporary file
with tempfile.NamedTemporaryFile() as fp: with tempfile.NamedTemporaryFile() as fp:
......
import unittest import unittest
import networkx as nx
from d_graph_data import unit_d_graph from d_graph_data import unit_d_graph
from deconvolution.dgraph.d_graph import Dgraph from deconvolution.dgraph.d_graph import Dgraph
from deconvolution.dgraph import graph_manipulator as gm from deconvolution.dgraph import graph_manipulator as gm
from deconvolution.dgraph.CliqueDGFactory import CliqueDGFactory
class TestDGraph(unittest.TestCase): class TestDGraph(unittest.TestCase):
...@@ -60,9 +62,23 @@ class TestDGraph(unittest.TestCase): ...@@ -60,9 +62,23 @@ class TestDGraph(unittest.TestCase):
self.assertEqual([['A0'], ['A1'], ['A2'], ['C'], ['B2'], ['B1'], ['B0']], lst) self.assertEqual([['A0'], ['A1'], ['A2'], ['C'], ['B2'], ['B1'], ['B0']], lst)
def test_generation_no_variability_DGCliqueFactory(self):
# def test_list_dgraphs(self): barcode_graph = nx.read_gexf("test_data/bar_1000_5_2.gexf")
factory = CliqueDGFactory(barcode_graph, 1)
udg_per_node = factory.generate_all_dgraphs()
all_udgs = set()
for udgs in udg_per_node.values():
all_udgs.update(udgs)
size = len(all_udgs)
for _ in range(5):
factory = CliqueDGFactory(barcode_graph, 1)
udg_per_node = factory.generate_all_dgraphs()
all_udgs = set()
for udgs in udg_per_node.values():
all_udgs.update(udgs)
self.assertEqual(len(all_udgs), size)
......
Supports Markdown
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment