Skip to content
Snippets Groups Projects
Commit 1be895e4 authored by Mélanie  HENNART's avatar Mélanie HENNART
Browse files

Add new script

parent feed36a5
No related branches found
No related tags found
No related merge requests found
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
###############################################################################
# #
# Copyright (C) 2021 Melanie HENNART #
# #
# This program is free software: you can redistribute it and/or modify #
# it under the terms of the GNU General Public License as published by #
# the Free Software Foundation, either version 3 of the License, or #
# (at your option) any later version. #
# #
# This program is distributed in the hope that it will be useful, #
# but WITHOUT ANY WARRANTY; without even the implied warranty of #
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the #
# GNU General Public License for more details. #
# #
# You should have received a copy of the GNU General Public License #
# along with this program. If not, see <https://www.gnu.org/licenses/>. #
# #
# #
# Contact: #
# #
# Melanie HENNART, PhD Student #
# melanie.hennart@pasteur.fr #
# Biodiversity and Epidemiology of Bacterial Pathogens #
# Institut Pasteur #
# 25-28, Rue du Docteur Roux #
# 75015 Paris Cedex 15 #
# France #
# #
###############################################################################
import pandas as pd
import networkx as nx
import argparse
def mapping_cluster_ST (data, seuil ) :
df = pd.crosstab(data['ST'], data[seuil])
df = df.drop('NA')
B = nx.Graph()
B.add_nodes_from(df.index, bipartite=1)
B.add_nodes_from(df.columns, bipartite=0)
for i in df.index :
for j in df.columns :
if df[j][i] > 0 :
B.add_edge(i,j, weight=df[j][i])
dico = {}
ST_attribute = []
while len(B.edges) > 0 :
edges = B.edges(data=True)
max_weight_edges = max([edge[2]['weight'] for edge in edges])
edges_subgraph = [(u,v) for (u,v,d) in B.edges(data=True) if d['weight'] == max_weight_edges]
C = B.edge_subgraph(edges_subgraph).copy()
for connected_component in nx.connected_components(C) :
S = C.subgraph(connected_component).copy()
top_nodes = {n for n, d in S.nodes(data=True) if d["bipartite"] == 0}
bottom_nodes = set(S) - top_nodes
if len(top_nodes) == 1 :
u = list(top_nodes)[0]
v = min(bottom_nodes)
B.remove_nodes_from((u,v))
dico[u]= [ v , 'Inheritance from ST' ]
ST_attribute.append(v)
else :
no_strains_nodes = [(sum(d['weight'] for (u,v,d) in C.edges(top_node, data=True)), top_node) for top_node in top_nodes]
no_strains_nodes.sort()
for (no_strains_node, node) in no_strains_nodes :
u = node
degree = [(S.degree[x], x) for x in S.adj[u] if x not in ST_attribute ]
if len(degree) > 0 :
v = min(degree)[1]
B.remove_nodes_from((u,v))
dico[u]= [ v , 'Inheritance from ST' ]
ST_attribute.append(v)
top_nodes = {n for n, d in B.nodes(data=True) if d["bipartite"] == 0}
lambda_ = 10000
for u in top_nodes :
dico[u]= [lambda_ , "Arbitrary"]
lambda_ += 1
mapping = data[seuil].apply(lambda x : dico[x][0])
attribute = data[seuil].apply(lambda x : dico[x][1])
return mapping, attribute, dico
#========= MAIN PROGRAM =====================================================#
#=== Parameters
parser = argparse.ArgumentParser()
parser.add_argument("-i", dest="input_file", type=str, required=True, help="input tab-delimited file (mandatory)")
parser.add_argument("-c", dest="colunms", type=str, required=True, help="name of the clustering column(s) (mandatory)")
parser.add_argument("-o", dest="output_file", type=str, required=True, help="basename for output files (mandatory)")
args = parser.parse_args()
input_file = args.input_file
colunms = args.colunms
output_file = args.output_file
#=== Algorithms
data = pd.read_csv(input_file, sep='\t', index_col=0, dtype=str)
data = data.fillna('NA')
colunms = colunms.split(',')
for colunm in colunms :
if colunm in data.columns :
mapping, attribute, dico = mapping_cluster_ST (data, colunm)
data['Mapping_'+colunm] = mapping
data['Attribution_'+colunm] = attribute
GroupData=pd.DataFrame.from_dict(dico, orient='index', columns=['Mapping', 'Attribution'])
GroupData.to_csv(output_file+'.'+colunm+'.txt', sep="\t")
else :
print ('Error: '+ colunm + ' is not a column in the input file.')
data.to_csv(output_file + '.out', sep="\t")
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment