Commit 66b8ee72 authored by Yoann Dufresne's avatar Yoann Dufresne
Browse files

evaluation script ok

parent bb20e13b
......@@ -30,27 +30,35 @@ def load_graph(filename):
exit()
def mols_from_node(node_name):
    """Return the molecule ids encoded in a node name, as a list of ints.

    The ids are read from the text located after the first ":" and before
    the following ".", where they are separated by "_".

    @param node_name Node name such as "{idx}:{mol1_id}_{mol2_id}.suffix"
    @return The molecule ids, in the order they appear in the name
    """
    # Second ":"-separated field; raises IndexError when there is no ":".
    after_colon = node_name.split(":")[1]
    # Everything from the first "." onwards is not part of the id list.
    mol_field = after_colon.split(".")[0]
    return list(map(int, mol_field.split("_")))
""" Compute appearance frequencies from node names.
All the node names must be under the format :
{idx}:{mol1_id}_{mol2_id}_...{molx_id}.other_things_here
@param graph The networkx graph representing the deconvolved graph
@param only_wrong If True, don't print correct nodes
@param file_pointer Where to print the output. If set to stdout, then pretty print. If set to None, don't print anything.
@return A tuple containing two dictionaries. The first one with theoritical frequences of each node, the second one with observed frequencies.
@return A tuple containing three dictionaries. The first one with theoretical frequencies of each node, the second one with observed frequencies, the third one with the list of graph nodes sharing each origin name.
"""
def parse_graph_frequencies(graph, only_wrong=False, file_pointer=sys.stdout):
def parse_graph_frequencies(graph):
# Compute origin nodes formated as `{idx}:{mol1_id}_{mol2_id}_...`
observed_frequences = {}
observed_frequencies = {}
origin_node_names = []
node_per_barcode = {}
for node in graph.nodes():
first_dot = node.find(".")
origin_name = node[:first_dot]
origin_name = node.split(".")[0]
if not origin_name in node_per_barcode:
node_per_barcode[origin_name] = []
node_per_barcode[origin_name].append(node)
# Count frequency
if not origin_name in observed_frequences:
observed_frequences[origin_name] = 0
if not origin_name in observed_frequencies:
observed_frequencies[origin_name] = 0
origin_node_names.append(origin_name)
observed_frequences[origin_name] += 1
observed_frequencies[origin_name] += 1
# Compute wanted frequencies
theoritical_frequencies = {}
......@@ -61,28 +69,63 @@ def parse_graph_frequencies(graph, only_wrong=False, file_pointer=sys.stdout):
# The node should be split into as many parts as it contains molecules
theoritical_frequencies[node_name] = len(mol_ids)
# Print results
if file_pointer != None:
print("--- Frequency analysis ---", file=file_pointer)
for key in theoritical_frequencies:
obs, the = observed_frequences[key], theoritical_frequencies[key]
result = f"{key}: {obs}/{the}"
return theoritical_frequencies, observed_frequencies, node_per_barcode
if file_pointer == sys.stdout:
result = colored(result, 'green' if obs==the else 'red')
if only_wrong and obs==the:
continue
""" This function aims to look for direct molecule neighbors.
If a node has more than 2 direct neighbors, it has not been split correctly
"""
def parse_graph_path(graph):
    """Collect, for every node, its neighbors that hold an adjacent molecule.

    For each molecule id `mol` of a node, a neighbor is recorded once when
    `mol - 1` is among the neighbor's molecule ids and once more when
    `mol + 1` is. The same neighbor can therefore appear several times in
    the list of a node.

    @param graph Graph object exposing `nodes()` and `neighbors(node)`; node names must be parsable by `mols_from_node`
    @return A dictionary mapping each node to the list of matching neighbors
    """
    neighborhood = {}

    for node in graph.nodes():
        molecules = mols_from_node(node)
        # Parse each neighbor name once per node instead of once per molecule.
        neighbor_mols = [
            (nei, mols_from_node(nei)) for nei in graph.neighbors(node)
        ]

        adjacent = []
        for mol in molecules:
            for nei, nei_mols in neighbor_mols:
                if mol - 1 in nei_mols:
                    adjacent.append(nei)
                if mol + 1 in nei_mols:
                    adjacent.append(nei)
        neighborhood[node] = adjacent

    return neighborhood
def print_summary(frequencies, file_pointer=sys.stdout):
def print_summary(frequencies, neighborhood, light_print=False, file_pointer=sys.stdout):
if file_pointer == None:
return
print("--- Nodes analysis ---", file=file_pointer)
theoritical_frequencies, observed_frequencies, node_per_barcode = frequencies
for key in theoritical_frequencies:
obs, the = observed_frequencies[key], theoritical_frequencies[key]
result = f"{key}: {obs}/{the}"
if file_pointer == sys.stdout:
result = colored(result, 'green' if obs==the else 'red')
# Compute neighborhood correctness
neighborhood_ok = True
for node in node_per_barcode[key]:
if len(neighborhood[node]) != 2:
neighborhood_ok = False
if light_print and (obs==the or not neighborhood_ok):
continue
print(result, file=file_pointer)
for node in node_per_barcode[key]:
text = f"\t{node}\t{' '.join(neighborhood[node])}"
if file_pointer == sys.stdout:
text = colored(text, 'green' if len(neighborhood[node]) == 2 else 'yellow')
print(text, file=file_pointer)
print("--- Global summary ---", file=file_pointer)
# --- Frequency usage ---
......@@ -108,9 +151,10 @@ def print_summary(frequencies, file_pointer=sys.stdout):
def main():
    """Run the evaluation: load the graph, analyse it and print the summary.

    The graph file name and the `light_print` flag are read from the parsed
    command line arguments.
    """
    args = parse_args()
    graph = load_graph(args.filename)

    # Both analyses are computed silently; printing is delegated to
    # print_summary, which now owns the light_print filtering.
    frequencies = parse_graph_frequencies(graph)
    neighborhood = parse_graph_path(graph)

    print_summary(frequencies, neighborhood, light_print=args.light_print)
if __name__ == "__main__":
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment