Commit 59ef6b96 authored by Yoann Dufresne

bugfix: fake d2 generation due to variable collisions

parent 02f87907
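The collision named in the commit title is suggested by the rename of `node` to `central_node` in CliqueDGFactory.generate_by_node below: the clique-pairing loops reuse `node` as a loop variable, which shadows a same-named parameter and makes later `Dgraph(node)` calls build d-graphs around the wrong central node. A minimal sketch of that shadowing pattern, with hypothetical names rather than the project's code:

# Hypothetical illustration of the shadowing bug; not the project's code.
def generate(node, cliques):
    members = []
    for clique in cliques:
        for node in clique:      # overwrites the parameter `node`
            members.append(node)
    return node, members         # `node` is now the last clique member, not the centre

def generate_fixed(central_node, cliques):
    members = []
    for clique in cliques:
        for node in clique:      # distinct name: the parameter stays intact
            members.append(node)
    return central_node, members

print(generate("central", [["a", "b"]]))        # ('b', ['a', 'b'])  -- wrong centre
print(generate_fixed("central", [["a", "b"]]))  # ('central', ['a', 'b'])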
@@ -3,7 +3,7 @@ import itertools
from bidict import bidict
import sys
from deconvolution.dgraph.FixedDGIndex import FixedDGIndex
# from deconvolution.dgraph.FixedDGIndex import FixedDGIndex
from deconvolution.dgraph.VariableDGIndex import VariableDGIndex
from deconvolution.dgraph.d_graph import Dgraph
from deconvolution.dgraph.CliqueDGFactory import CliqueDGFactory
@@ -59,7 +59,7 @@ class D2Graph(nx.Graph):
return self.subgraph(list(self.nodes()))
def construct_from_barcodes(self, index_size=3, verbose=True, debug=False, clique_mode=None, threads=1):
def construct_from_barcodes(self, index_size=3, verbose=True, clique_mode=None, threads=1):
# Compute all the d-graphs
if verbose:
print("Computing the unit d-graphs..")
@@ -68,7 +68,7 @@ class D2Graph(nx.Graph):
dg_factory = LouvainDGFactory(self.barcode_graph)
else:
dg_factory = CliqueDGFactory(self.barcode_graph)
self.d_graphs_per_node = dg_factory.generate_all_dgraphs(debug=True)
self.d_graphs_per_node = dg_factory.generate_all_dgraphs(threads=threads, verbose=True)
if verbose:
counts = sum(len(x) for x in self.d_graphs_per_node.values())
print(f"\t {counts} computed d-graphs")
@@ -85,7 +85,7 @@ class D2Graph(nx.Graph):
print("Compute the dmer dgraph")
print("\tIndexing")
# self.index = FixedDGIndex(size=index_size)
self.index = VariableDGIndex(size=2)
self.index = VariableDGIndex(size=index_size)
for idx, dg in enumerate(self.all_d_graphs):
if verbose:
print(f"\r\t{idx+1}/{len(self.all_d_graphs)}", end='')
......
import networkx as nx
from abc import abstractmethod
from multiprocessing import Pool
from deconvolution.dgraph.FixedDGIndex import FixedDGIndex
locked = False
nb_over = 0
def process_node(factory, node):
neighbors = list(factory.graph.neighbors(node))
subgraph = nx.Graph(factory.graph.subgraph(neighbors))
dgs = factory.generate_by_node(node, subgraph)
global nb_over, locked
nb_over += 1
if factory.verbose:
if not locked:
locked = True
print(f"\r{nb_over}/{factory.nb_nodes} node analysis", end='')
locked = False
return node, dgs
class AbstractDGFactory:
def __init__(self, graph):
self.graph = graph
self.nb_nodes = len(self.graph.nodes())
self.verbose = False
def generate_all_dgraphs(self, debug=False):
def generate_all_dgraphs(self, verbose=False, threads=8):
index = FixedDGIndex(size=1)
factory = self
global nb_over
nb_over = 0
nb_nodes = len(self.graph.nodes())
for idx, node in enumerate(self.graph.nodes()):
if debug: print(f"\r{idx+1}/{nb_nodes} node analysis", end='')
neighbors = list(self.graph.neighbors(node))
subgraph = nx.Graph(self.graph.subgraph(neighbors))
self.verbose = verbose
results = None
with Pool(processes=threads) as pool:
results = pool.starmap(process_node, zip(
[factory]*nb_nodes,
self.graph.nodes(),
))
# Fill the index by node
# Fill the index by node
for node, dgs in results:
key = frozenset({node})
for dg in self.generate_by_node(node, subgraph):
for dg in dgs:
index.add_value(key, dg)
index.filter_by_entry()
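The parallelised generate_all_dgraphs above fans every node out to a worker through Pool.starmap, passing the factory alongside each node so that process_node can live at module level (which multiprocessing needs for pickling). A minimal self-contained sketch of that fan-out pattern, with a hypothetical stand-in for the per-node d-graph computation:

from multiprocessing import Pool

def process_node(factory, node):
    # Hypothetical stand-in for the per-node d-graph computation.
    neighbors = factory.graph.get(node, [])
    return node, [f"dg({node}-{n})" for n in neighbors]

class ToyFactory:
    def __init__(self, graph):
        self.graph = graph  # plain dict so the object pickles cleanly

if __name__ == "__main__":
    factory = ToyFactory({"A": ["B"], "B": ["A", "C"], "C": ["B"]})
    nodes = list(factory.graph)
    with Pool(processes=2) as pool:
        results = pool.starmap(process_node, zip([factory] * len(nodes), nodes))
    for node, dgs in results:
        print(node, dgs)

Note that module-level counters such as nb_over are copied into each worker process, so the printed progress reflects each worker's own count rather than a shared total.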
@@ -32,84 +58,3 @@ class AbstractDGFactory:
pass
# def compute_all_max_d_graphs(graph, debug=False, clique_mode=None, threads=1):
# d_graphs = FixedDGIndex(size=1)
#
# nds = list(graph.nodes())
# for idx, node in enumerate(nds):
# # print(idx+1, '/', len(nds))
# #if "MI" not in str(node): continue # for debugging; only look at deconvolved nodes
# #print(f"\r{idx+1}/{len(graph.nodes())}")
# neighbors = list(graph.neighbors(node))
# neighbors_graph = nx.Graph(graph.subgraph(neighbors))
#
# node_d_graphs = set()
#
# mode_str = " "
# if clique_mode is None:
# # Find all the cliques (equivalent to compute all the candidate half d-graph)
# cliques = []
# for clique in nx.find_cliques(neighbors_graph):
# if len(clique) > 3:
# cliques.append(clique)
# mode_str += "(max-cliques)"
# elif clique_mode == "louvain":
# louvain = community.best_partition(neighbors_graph) # louvain
# # high resolution seems to work better
# communities = [[c for c,i in louvain.items() if i == clique_id] for clique_id in set(louvain.values())]
# mode_str += "(louvain)"
# cliques = []
# for comm in communities:
# # further decompose! into necessarily 2 communities
# community_as_graph = nx.Graph(graph.subgraph(comm))
# if len(community_as_graph.nodes()) <= 2:
# cliques += [community_as_graph.nodes()]
# else:
# cliques += map(list,nx.community.asyn_fluidc(community_as_graph,2))
#
# elif clique_mode == "testing":
# # k-clique communities
# #from networkx.algorithms.community import k_clique_communities
# #cliques = k_clique_communities(neighbors_graph, 3) # related to the d-graph d parameter
# from cdlib import algorithms
# cliques_dict = algorithms.node_perception(neighbors_graph, threshold=0.75, overlap_threshold=0.75) #typical output: Sizes of found cliques (testing): Counter({6: 4, 5: 3, 4: 2, 2: 1})
# #cliques_dict = algorithms.gdmp2(neighbors_graph, min_threshold=0.9) #typical output: sizes of found cliques (testing): Counter({3: 2, 5: 1})
# #cliques_dict = algorithms.angel(neighbors_graph, threshold=0.90) # very sensitive parameters: 0.84 and 0.88 don't work at all but 0.86 does sort of
# from collections import defaultdict
# cliques_dict2 = defaultdict(list)
# for (node, values) in cliques_dict.to_node_community_map().items():
# for value in values:
# cliques_dict2[value] += [node]
# cliques = list(cliques_dict2.values())
# mode_str += "(testing)"
#
# if debug: print("node", node, "has", len(cliques), "cliques in neighborhood (of size", len(neighbors), ")")
#
# cliques_debugging = True
# if cliques_debugging:
#
# from collections import Counter
# len_cliques = Counter(map(len,cliques))
#
#         # Pair halves to create d-graphs
# for idx, clq1 in enumerate(cliques):
# for clq2_idx in range(idx+1, len(cliques)):
# clq2 = cliques[clq2_idx]
#
# # Check for d-graph candidates
# d_graph = Dgraph(node)
# d_graph.put_halves(clq1, clq2, neighbors_graph)
#
# factor = 0.5
# #if clique_mode == "testing": factor = 1 # still allows louvain's big communities
# #print("link div:",d_graph.get_link_divergence(),"opt:",d_graph.get_optimal_score(), "good d graph?",d_graph.get_link_divergence() <= d_graph.get_optimal_score() *factor)
# if d_graph.get_link_divergence() <= d_graph.get_optimal_score() * factor:
# node_d_graphs.add(d_graph)
#
# # Fill the index by node
# key = frozenset({node})
# for dg in node_d_graphs:
# d_graphs.add_value(key, dg)
#
# d_graphs.filter_by_entry()
# return d_graphs
@@ -11,7 +11,7 @@ class CliqueDGFactory(AbstractDGFactory):
self.dg_max_divergence_factor = dg_max_divergence_factor
def generate_by_node(self, node, subgraph):
def generate_by_node(self, central_node, subgraph):
node_d_graphs = set()
# Clique computation
@@ -20,18 +20,59 @@ class CliqueDGFactory(AbstractDGFactory):
if len(clique) >= self.min_size:
cliques.append(clique)
# TODO: Index cliques to pair them faster
# index cliques per node
clq_per_node = {}
for idx, clq in enumerate(cliques):
for node in clq:
if node not in clq_per_node:
clq_per_node[node] = []
clq_per_node[node].append(idx)
# d-graph computation
for idx, clq1 in enumerate(cliques):
for clq2_idx in range(idx+1, len(cliques)):
clq2 = cliques[clq2_idx]
# Pair cliques
clq_pairs = set()
for idx, clq in enumerate(cliques):
for node in clq:
neighbors = list(subgraph.neighbors(node))
# Looks into the neighbors for clique pairing
for nei in neighbors:
nei_clqs = clq_per_node[nei]
# Pair useful cliques
for nei_clq in nei_clqs:
if nei_clq > idx:
clq_pairs.add((idx, nei_clq))
# Check for d-graph candidates
d_graph = Dgraph(node)
d_graph.put_halves(clq1, clq2, subgraph)
# Create the clique graph for max weight
clq_G = nx.Graph()
# Create nodes
for idx in range(len(cliques)):
clq_G.add_node(idx)
# Create edges
max_div = 0
for idx1, idx2 in clq_pairs:
# Get cliques
clq1 = cliques[idx1]
clq2 = cliques[idx2]
# Create candidate udg
d_graph = Dgraph(central_node)
d_graph.put_halves(clq1, clq2, subgraph)
# Add divergence to the clique graph
div = d_graph.get_link_divergence()
if div > max_div:
max_div = div
clq_G.add_edge(idx1, idx2, weight=div)
# Normalize the divergence
for idx1, idx2 in clq_pairs:
clq_G.edges[idx1, idx2]['weight'] = max_div - clq_G.edges[idx1, idx2]['weight']
if d_graph.get_link_divergence() <= d_graph.get_optimal_score() * self.dg_max_divergence_factor:
node_d_graphs.add(d_graph)
# d-graph computation regarding max weight matching
mwm = nx.algorithms.max_weight_matching(clq_G)
for idx1, idx2 in mwm:
# Get cliques
clq1 = cliques[idx1]
clq2 = cliques[idx2]
# Create candidate udg
d_graph = Dgraph(central_node)
d_graph.put_halves(clq1, clq2, subgraph)
node_d_graphs.add(d_graph)
return node_d_graphs
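In the pairing above, low link divergence is turned into high edge weight (max_div - div) so that networkx's maximum-weight matching favours low-divergence clique pairs while using each clique at most once. A small sketch of that inversion, with hypothetical divergence values:

import networkx as nx

# Hypothetical divergences between candidate clique pairs (lower is better).
divergences = {(0, 1): 5.0, (0, 2): 1.0, (1, 2): 2.0, (2, 3): 0.5}
max_div = max(divergences.values())

clq_G = nx.Graph()
for (idx1, idx2), div in divergences.items():
    # Invert the weights: maximising total weight now favours low total divergence.
    clq_G.add_edge(idx1, idx2, weight=max_div - div)

mwm = nx.algorithms.max_weight_matching(clq_G)
print(mwm)  # a set of (idx1, idx2) pairs, each clique index used at most once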
@@ -11,7 +11,7 @@ def parse_args():
parser = argparse.ArgumentParser(description='Process some integers.')
parser.add_argument('filename', type=str,
help='The file to evaluate')
parser.add_argument('--type', '-t', choices=["d2", "path", "d2-2annotate"], default="path", required=True,
parser.add_argument('--type', '-t', choices=["d2", "path", "d2-2annotate", "dgraphs"], default="path", required=True,
help="Define the data type to evaluate. Must be 'd2' or 'path' or 'd2-2annotate' (Rayan's hack).")
parser.add_argument('--light-print', '-l', action='store_true',
help='Print only wrong nodes and paths')
@@ -50,15 +50,39 @@ def mols_from_node(node_name):
return [int(idx) for idx in node_name.split(":")[1].split(".")[0].split("_")]
""" Compute appearance frequencies from node names.
All the node names must follow the format:
{idx}:{mol1_id}_{mol2_id}_...{molx_id}.other_things_here
@param graph The networkx graph representing the deconvolved graph
@param only_wrong If True, don't print correct nodes
@param file_pointer Where to print the output. If set to stdout, then pretty print. If set to None, don't print anything.
@return A tuple containing two dictionaries. The first one with theoretical frequencies of each node, the second one with observed frequencies.
"""
def parse_udg_qualities(graph):
""" Compute the quality for the best udgs present in the graph.
All the node names must follow the format:
{idx}:{mol1_id}_{mol2_id}_...{molx_id}.other_things_here
:param graph: The networkx graph representing the deconvolved graph
:return: A dictionary mapping each central node to the list of udg strings attached to it.
"""
dg_per_node = {}
for node, data in graph.nodes(data=True):
str_udg = data["udg"]
central, h1, h2 = str_to_udg_lists(str_udg)
if central not in dg_per_node:
dg_per_node[central] = []
dg_per_node[central].append(data["udg"])
for node in dg_per_node:
print(node, dg_per_node[node])
print(len(dg_per_node))
return dg_per_node
def parse_path_graph_frequencies(graph, barcode_graph):
""" Compute appearance frequencies from node names.
All the node names must follow the format:
{idx}:{mol1_id}_{mol2_id}_...{molx_id}.other_things_here
:param graph: The networkx graph representing the deconvolved graph
:param only_wrong: If True, don't print correct nodes
:param file_pointer: Where to print the output. If set to stdout, then pretty print. If set to None, don't print anything.
:return: A tuple containing two dictionaries. The first one with theoretical frequencies of each node, the second one with observed frequencies.
"""
# Compute origin nodes formatted as `{idx}:{mol1_id}_{mol2_id}_...`
observed_frequencies = {}
real_frequencies = {}
@@ -147,12 +171,19 @@ def print_path_summary(frequencies, light_print=False, file_pointer=sys.stdout):
print(f"Under/Over splitting: {under_split} - {over_split}")
def print_dgraphs_summary(frequencies, light_print=False):
pass
# ------------- D2 Graph -------------
def str_to_udg_lists(s):
udg = s.replace("]", "").replace(' [', '[')
return udg.split('[')
def parse_dg_name(gr, name):
udg = nx.get_node_attributes(gr, 'udg')[name]
udg = udg.replace("]", "").replace(' [', '[')
central, h1, h2 = udg.split('[')
central, h1, h2 = str_to_udg_lists(udg)
idx = name
score = nx.get_node_attributes(gr, 'score')[name]
@@ -451,6 +482,9 @@ def main():
frequencies = parse_path_graph_frequencies(graph, barcode_graph)
print_path_summary(frequencies, light_print=args.light_print)
elif args.type == "dgraphs":
udg_per_node = parse_udg_qualities(graph)
# print(udg_per_node)
elif args.type == "d2":
components = list(nx.connected_components(graph))
components.sort(key=lambda x: -len(x))
......
@@ -11,8 +11,8 @@ def parse_arguments():
parser = argparse.ArgumentParser(description='Transform a 10X barcode graph into a d2 graph. The program digs for the d-graphs and then merges them into a d2-graph.')
parser.add_argument('barcode_graph', help='The barcode graph file. Must be a gexf formatted file.')
parser.add_argument('--output_prefix', '-o', default="d2_graph", help="Output file prefix.")
parser.add_argument('--threads', '-t', default=8, type=int, help='Number of threads to use for dgraph computation')
parser.add_argument('--debug', '-d', action='store_true', help="Debug")
parser.add_argument('--threads', '-t', default=1, type=int, help='Number of threads to use for dgraph computation')
# parser.add_argument('--debug', '-d', action='store_true', help="Debug")
parser.add_argument('--maxclq', '-c', action='store_true', help="Enable max clique community detection (default behaviour)")
parser.add_argument('--louvain', '-l', action='store_true', help="Enable Louvain community detection instead of all max-cliques")
parser.add_argument('--comtest', '-k', action='store_true', help="Enable [placeholder] community detection algorithm instead of max-cliques")
@@ -24,13 +24,13 @@ def main():
# Parsing the input file
args = parse_arguments()
debug = args.debug
# debug = args.debug
filename = args.barcode_graph
def dprint(s):
from datetime import datetime
t = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
if debug: print(t,"[debug]",s)
# if debug: print(t,"[debug]",s)
dprint("loading barcode graph")
if filename.endswith('.gexf'):
@@ -54,11 +54,11 @@ def main():
d2g = d2.D2Graph(G)
dprint("D2 graph object created")
dprint("constructing d2 graph from barcode graph")
index_size = 8 #if clique_mode is None else 3
d2g.construct_from_barcodes(index_size=index_size, debug=debug, clique_mode=clique_mode, threads=args.threads)
index_size = 4 #if clique_mode is None else 3
d2g.construct_from_barcodes(index_size=index_size, clique_mode=clique_mode, threads=args.threads)
dprint("[debug] d2 graph constructed")
d2g.save(f"{args.output_prefix}.tsv")
# d2g.save(f"{args.output_prefix}.tsv")
nx.write_gexf(d2g, f"{args.output_prefix}.gexf")
......