Commit 061f4092 authored by Yoann Dufresne

optimization for divergence computation

parent 498df34a
......@@ -2,6 +2,7 @@ data/
real_data/
art/
**/__pycache__/
*.pyc
.pytest_cache/
.idea/
.snakemake/
......
......@@ -62,7 +62,7 @@ class D2Graph(nx.Graph):
return self.subgraph(list(self.nodes()))
def construct_from_barcodes(self, neighbor_threshold=0.25, verbose=True, clique_mode=None, threads=1):
def construct_from_barcodes(self, neighbor_threshold=0.25, min_size_clique=4, verbose=False, clique_mode=None, threads=1):
# Compute all the d-graphs
if verbose:
print("Computing the unit d-graphs..")
......@@ -70,8 +70,8 @@ class D2Graph(nx.Graph):
if clique_mode == "louvain":
dg_factory = LouvainDGFactory(self.barcode_graph)
else:
dg_factory = CliqueDGFactory(self.barcode_graph, debug=self.debug, debug_path=self.debug_path)
self.d_graphs_per_node = dg_factory.generate_all_dgraphs(threads=threads, verbose=True)
dg_factory = CliqueDGFactory(self.barcode_graph, min_size_clique=min_size_clique, debug=self.debug, debug_path=self.debug_path)
self.d_graphs_per_node = dg_factory.generate_all_dgraphs(threads=threads, verbose=verbose)
if verbose:
counts = sum(len(x) for x in self.d_graphs_per_node.values())
print(f"\t {counts} computed d-graphs")
......
......@@ -17,7 +17,7 @@ def process_node(factory, node):
sys.stdout.flush()
# udg generation
neighbors = list(factory.graph.neighbors(node))
neighbors = [x for x in factory.graph[node]]
subgraph = nx.Graph(factory.graph.subgraph(neighbors))
dgs = factory.generate_by_node(node, subgraph)
......@@ -56,18 +56,25 @@ class AbstractDGFactory:
if verbose:
print("Start parallel work")
results = None
with Pool(processes=threads) as pool:
results = pool.starmap(process_node, zip(
[factory]*nb_nodes,
self.graph.nodes()
))
# Fill the index by node
for node, dgs in results:
key = frozenset({node})
for dg in dgs:
index.add_value(key, dg)
if threads > 1:
results = None
with Pool(processes=threads) as pool:
results = pool.starmap(process_node, zip(
[factory]*nb_nodes,
self.graph.nodes()
))
# Fill the index by node
for node, dgs in results:
key = frozenset({node})
for dg in dgs:
index.add_value(key, dg)
else:
for node in self.graph.nodes():
key = frozenset({node})
_, dgs = process_node(factory, node)
for dg in dgs:
index.add_value(key, dg)
return index
......@@ -75,5 +82,3 @@ class AbstractDGFactory:
# Hook that every concrete factory has to implement.
# process_node calls it with a node and a copy of the subgraph induced by that
# node's neighbors, then iterates over the returned collection so that each
# d-graph can be registered in the per-node index.
@abstractmethod
def generate_by_node(self, node, subgraph):
pass
import networkx as nx
from collections import Counter
from deconvolution.dgraph.AbstractDGFactory import AbstractDGFactory
from deconvolution.dgraph.d_graph import Dgraph
......@@ -24,60 +25,86 @@ class CliqueDGFactory(AbstractDGFactory):
def generate_by_node(self, central_node, subgraph):
node_d_graphs = set()
node_neighbors = {node: [x for x in subgraph[node]] for node in subgraph.nodes}
# Clique computation
cliques = []
clique_neighbors_multiset = []
clique_neighbors_set = []
clq_per_node = {node: [] for node in subgraph.nodes}
idx = 0
for clique in nx.find_cliques(subgraph):
if len(clique) >= self.min_size:
cliques.append(clique)
# index cliques per node
clq_per_node = {}
for idx, clq in enumerate(cliques):
for node in clq:
if node not in clq_per_node:
clq_per_node[node] = []
clq_per_node[node].append(idx)
# clq_per_node for nei can be empty because of minimum clique size
for node in subgraph.nodes:
if node not in clq_per_node:
clq_per_node[node] = []
# Create the clique set
clique_set = frozenset(clique)
cliques.append(clique_set)
# Index neighbors of the clique to speedup the divergence computation
neighbors = []
for node in clique:
# index clique per node. Useful ?
clq_per_node[node].append(idx)
# Prepare a neighbor multiset to speedup the divergence computation
neighbors.extend(node_neighbors[node])
ms = Counter(neighbors)
clique_neighbors_multiset.append(ms)
clique_neighbors_set.append(frozenset(ms))
idx += 1
# def clique_divergence(c1, c2):
# # Observed link
# nb_links = 0
# for node in c1:
# neighbors = clique_neighbors[node]
#
# # Awaited links
# d_approx = max(len(c1), len(c2))
# awaited = d_approx * (d_approx - 1) / 2
#
# return abs(awaited - nb_links)
def clique_divergence(c1_idx, c1, c2):
    """Return the gap between awaited and observed links for a clique pair.

    c1_idx is the position of c1 in the per-clique neighbor indexes
    (clique_neighbors_set / clique_neighbors_multiset); c1 and c2 are the
    node sets of the two cliques.
    """
    # Nodes present in both cliques count as already glued links
    shared_nodes = len(c1 & c2)

    # Nodes of c2 that are adjacent to at least one node of c1
    reached_nodes = clique_neighbors_set[c1_idx] & c2
    link_counts = clique_neighbors_multiset[c1_idx]
    # Each reached node contributes once per c1 node it is adjacent to
    cross_links = sum(link_counts[node] for node in reached_nodes)

    # Awaited links: edge count of a complete graph on the larger clique size
    largest = max(len(c1), len(c2))
    awaited_links = largest * (largest - 1) / 2

    return abs(awaited_links - (shared_nodes + cross_links))
# Pair cliques
clq_pairs = set()
for idx, clq in enumerate(cliques):
for node in clq:
neighbors = list(subgraph.neighbors(node))
# Looks into the neighbors for clique pairing
for nei in neighbors:
nei_clqs = clq_per_node[nei]
# Pair useful cliques
for nei_clq in nei_clqs:
if nei_clq > idx:
clq_pairs.add((idx, nei_clq))
def enumerate_clique_pair():
    """Yield (lower clique index, higher clique index, divergence) triples.

    Two cliques are paired when a node of the first one is adjacent, in the
    subgraph, to a node belonging to the second one. Only pairs whose second
    index is strictly greater than the first are produced. Nothing is
    deduplicated here, so the same pair can be yielded several times.
    """
    for left_idx, left_clique in enumerate(cliques):
        for member in left_clique:
            # Look into the neighbors of each clique member for pairing
            for adjacent in subgraph[member]:
                for right_idx in clq_per_node[adjacent]:
                    # Keep each unordered pair in a single orientation
                    if right_idx <= left_idx:
                        continue
                    divergence = clique_divergence(left_idx, left_clique, cliques[right_idx])
                    yield left_idx, right_idx, divergence
# Create the clique graph for max weight
clq_G = nx.Graph()
# Create nodes
for idx in range(len(cliques)):
clq_G.add_node(idx)
# Create edges
max_div = 0
for idx1, idx2 in clq_pairs:
for idx1, idx2, div in enumerate_clique_pair():
# Get cliques
clq1 = cliques[idx1]
clq2 = cliques[idx2]
# Create candidate udg
d_graph = Dgraph(central_node)
d_graph.put_halves(clq1, clq2, subgraph)
# Add divergence to the clique graph
div = d_graph.get_link_divergence()
if div > max_div:
max_div = div
clq_G.add_edge(idx1, idx2, weight=div)
# Normalize the divergence
for idx1, idx2 in clq_pairs:
for idx1, idx2 in clq_G.edges():
clq_G.edges[idx1, idx2]['weight'] = max_div - clq_G.edges[idx1, idx2]['weight']
if self.debug and len(clq_G.nodes) > 0:
......@@ -93,7 +120,7 @@ class CliqueDGFactory(AbstractDGFactory):
clq2 = cliques[idx2]
# Create candidate udg
d_graph = Dgraph(central_node)
d_graph.put_halves(clq1, clq2, subgraph)
d_graph.put_halves(list(clq1), list(clq2), subgraph)
node_d_graphs.add(d_graph)
return node_d_graphs
import argparse
import time
import networkx as nx
from collections import Counter
......@@ -79,7 +80,7 @@ def analyse_clique_graph(barcode_graph):
def analyse_d_graphs(barcode_graph):
# Generate udgs
factory = CliqueDGFactory(barcode_graph, 1)
udg_per_node = factory.generate_all_dgraphs()
udg_per_node = factory.generate_all_dgraphs(threads=1)
# Remove duplicate udgs
udgs = {}
for udg_node_lst in udg_per_node.values():
......@@ -99,14 +100,18 @@ def analyse_d_graphs(barcode_graph):
def main():
args = parse_arguments()
g = nx.read_gexf(args.barcode_graph)
prev_time = time.time()
continuous, total = analyse_clique_graph(g)
print("cliques")
print("cliques", time.time() - prev_time)
print(continuous, "/", total)
prev_time = time.time()
continuous, total = analyse_d_graphs(g)
print("udgs")
print("udgs", time.time() - prev_time)
print(continuous, "/", total)
if __name__ == "__main__":
# import cProfile
# cProfile.run('main()')
main()
......@@ -4,7 +4,7 @@ from distutils.core import setup
setup(
name='10X-deconvolve',
version='0.1dev',
packages=['deconvolution.d2graph', 'deconvolution.dgraph', 'deconvolution.main', 'experiments'],
packages=['deconvolution.d2graph', 'deconvolution.dgraph', 'deconvolution.main', 'experiments', 'tests'],
license='AGPL V3',
long_description=open('README.md').read(),
)
This diff is collapsed.
......@@ -8,84 +8,65 @@ from deconvolution.dgraph import graph_manipulator as gm
class TestD2Graph(unittest.TestCase):
# def test_construction(self):
# d2 = D2Graph(complete_graph)
# d2.construct_from_barcodes(index_size=4, verbose=False)
# def test_linear_d2_construction(self):
# for d in range(2, 10):
# size = 2 * d + 3
#
# G = gm.generate_d_graph_chain(size, d)
# d2 = D2Graph(G)
# print("before", d)
# d2.construct_from_barcodes(neighbor_threshold=0, min_size_clique=d, verbose=False)
# print("after", d)
#
# # for dg in d2.all_d_graphs:
# # print(dg.score, dg.get_link_divergence(), dg)
# # print()
#
# # Test the number of d-graphs
# awaited_d_num = size - 2 * d
# self.assertEqual(awaited_d_num, len(d2.all_d_graphs))
#
# # Evaluate the number of candidate unit d_graphs generated
# for node, candidates in d2.d_graphs_per_node.items():
# if node == "C" or node == "B2":
# self.assertEqual(1, len(candidates))
# else:
# self.assertEqual(0, len(candidates))
# # Test connectivity
# # Center node names
# c1 = d
# c2 = d + 1
# c3 = d + 2
# # Connectivity matrix
# awaited_distances = {
# c1: {c2: 2, c3: 4},
# c2: {c1: 2, c3: 2},
# c3: {c1: 4, c2: 2}
# }
#
# # Evaluate the dgraph
# self.assertEqual(13, len(d2.index))
# for x, y, data in d2.edges(data=True):
# dg1 = d2.node_by_idx[x]
# dg2 = d2.node_by_idx[y]
#
# overlap_key = ('A1', 'A2', 'B0', 'B1', 'B2', 'C')
# for dmer, dg_lst in d2.index.items():
# if dmer == overlap_key:
# values = list(d2.index[dmer])
# self.assertEqual(2, len(d2.index[dmer]))
# self.assertNotEqual(values[0], values[1])
# else:
# self.assertEqual(1, len(d2.index[dmer]))
def test_linear_d2_construction(self):
for d in range(1, 10):
size = 2 * d + 3
index_k = 2 * d - 1
G = gm.generate_d_graph_chain(size, d)
d2 = D2Graph(G)
d2.construct_from_barcodes(index_size=index_k, verbose=False)
# for dg in d2.all_d_graphs:
# print(dg.score, dg.get_link_divergence(), dg)
# print()
# Test the number of d-graphs
awaited_d_num = size - 2 * d
self.assertEqual(awaited_d_num, len(d2.all_d_graphs))
# Test dgraph
awaited_index_size = comb(2 * d + 1, index_k) + (size - (2 * d + 1)) * comb(2 * d, index_k - 1)
if len(d2.index) != awaited_index_size:
dmers = [list(x) for x in d2.index]
dmers = [str(x) for x in dmers if len(x) != len(frozenset(x))]
self.assertEqual(awaited_index_size, len(d2.index))
# Test connectivity
# Center node names
c1 = d
c2 = d + 1
c3 = d + 2
# Connectivity matrix
awaited_distances = {
c1: {c2: 2, c3: 4},
c2: {c1: 2, c3: 2},
c3: {c1: 4, c2: 2}
}
for x, y, data in d2.edges(data=True):
dg1 = d2.node_by_idx[x]
dg2 = d2.node_by_idx[y]
awaited_dist = awaited_distances[dg1.center][dg2.center]
self.assertEqual(data["distance"], awaited_dist)
# awaited_dist = awaited_distances[dg1.center][dg2.center]
# self.assertEqual(data["distance"], awaited_dist)
# Determinism check: building a D2Graph several times from the same barcode
# graph is expected to give the same number of d-graphs each time.
# Only the counts are compared, not the d-graphs themselves.
def test_no_variability(self):
# Relative fixture path — presumably requires running from the tests directory (verify)
barcode_graph = nx.read_gexf("test_data/bar_1000_5_2.gexf")
# Reference construction, using the default parameters
d2 = D2Graph(barcode_graph)
d2.construct_from_barcodes()
udgs = d2.all_d_graphs
# Rebuild from scratch five more times and compare against the reference count
# NOTE(review): the assertion presumably sits inside the loop — indentation is not visible here
for _ in range(5):
d2 = D2Graph(barcode_graph)
d2.construct_from_barcodes()
self.assertEqual(len(udgs), len(d2.all_d_graphs))
def test_reloading(self):
# Parameters
d = 3
size = 2 * d + 3
index_k = 2 * d - 1
# Create a d2 graph
G = gm.generate_d_graph_chain(size, d)
d2 = D2Graph(G)
d2.construct_from_barcodes(index_size=index_k, verbose=False)
d2.construct_from_barcodes(verbose=False)
# Save and reload the d2 in a temporary file
with tempfile.NamedTemporaryFile() as fp:
......
import unittest
import networkx as nx
from d_graph_data import unit_d_graph
from deconvolution.dgraph.d_graph import Dgraph
from deconvolution.dgraph import graph_manipulator as gm
from deconvolution.dgraph.CliqueDGFactory import CliqueDGFactory
class TestDGraph(unittest.TestCase):
......@@ -60,9 +62,23 @@ class TestDGraph(unittest.TestCase):
self.assertEqual([['A0'], ['A1'], ['A2'], ['C'], ['B2'], ['B1'], ['B0']], lst)
# def test_list_dgraphs(self):
# Determinism check for CliqueDGFactory: repeated generations over the same
# barcode graph are expected to give the same number of distinct unit d-graphs.
# Only the counts are compared, not the d-graphs themselves.
def test_generation_no_variability_DGCliqueFactory(self):
# Relative fixture path — presumably requires running from the tests directory (verify)
barcode_graph = nx.read_gexf("test_data/bar_1000_5_2.gexf")
# Reference run
factory = CliqueDGFactory(barcode_graph, 1)
udg_per_node = factory.generate_all_dgraphs()
# Merge the per-node collections into one set so duplicates are counted once
all_udgs = set()
for udgs in udg_per_node.values():
all_udgs.update(udgs)
size = len(all_udgs)
# Five more runs, each with a fresh factory, compared against the reference count
# NOTE(review): the assertion presumably sits inside the outer loop — indentation is not visible here
for _ in range(5):
factory = CliqueDGFactory(barcode_graph, 1)
udg_per_node = factory.generate_all_dgraphs()
all_udgs = set()
for udgs in udg_per_node.values():
all_udgs.update(udgs)
self.assertEqual(len(all_udgs), size)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment