Commit c72944dc authored by Fabrice ALLAIN

First steps for pdbstat parallelization

parent b109bc46
......@@ -19,7 +19,7 @@ from matplotlib.lines import Line2D
from mpl_toolkits.mplot3d import Axes3D
from aria.SuperImposer import SuperImposer
from .converter import AriaEcXMLConverter
from .base import NotDisordered, Capturing
from .common import NotDisordered, Capturing
from matplotlib.colors import ListedColormap
from aria.DataContainer import DATA_SEQUENCE
from aria.StructureEnsemble import StructureEnsemble, StructureEnsembleSettings
......@@ -286,7 +286,6 @@ class EnsembleAnalysis(object):
LOG.info("Wrote %s file", out_file)
def extract_ariaensemble_matrix(self):
"""
......@@ -295,10 +294,9 @@ class EnsembleAnalysis(object):
"""
pass
# TODO: As described in Guillaume's paper, clustering on aligned CA
# coordinates depends heavily on the efficiency of the alignment method
@staticmethod
def pca_projection(
......
......@@ -9,7 +9,7 @@ import logging
import argparse as argp
from . import __doc__
from .base import format_dict, CustomLogging
from .common import format_dict, CustomLogging
from .settings import AriaEcSettings
from .maplot import AriaEcContactMap
from .converter import AriaEcBbConverter, AriaEcXMLConverter
......@@ -417,6 +417,10 @@ class AriaEcCommands(object):
parser.add_argument(
"pdbdists", action=ReadableFile,
help="PDB distance file in csv format")
parser.add_argument(
"-j", dest="njobs", default=1, metavar="N_JOBS",
help="Number of cpus used to run mixture in parallel"
)
return parser
def create_settings(self):
......
# coding=utf-8
"""
Basic tools aria_ec
Basic tools ariaec
"""
from __future__ import absolute_import, division, print_function
from abc import ABCMeta, abstractmethod
from Bio.PDB import Select
import six
import logging
import logging.config
import os
......@@ -15,10 +13,13 @@ import re
import ast
import sys
import shutil
import six
import numpy as np
import pkg_resources as pkgr
import matplotlib.artist as art
from Bio.PDB import Select
LOG = logging.getLogger(__name__)
......@@ -37,7 +38,6 @@ def addtup(tup, inc=1):
Returns
-------
"""
return tuple((val + inc for val in tup))
......@@ -58,7 +58,6 @@ def titleprint(outfile, progname='', desc=''):
Returns
-------
"""
out = '''
================================================================================
......@@ -81,14 +80,12 @@ def get_filename(path):
Returns
-------
"""
return "_".join(path.split("/")[-1].split(".")[:-1])
def reg_load(regex, filepath, sort=None):
"""
Parameters
----------
......@@ -97,12 +94,10 @@ def reg_load(regex, filepath, sort=None):
sort :
return: (Default value = None)
filepath :
Returns
-------
"""
lines_dict = {}
......@@ -130,12 +125,10 @@ def sort_2dict(unsort_dict, key, reverse=True):
unsort_dict :
return: sorted dict
key :
Returns
-------
"""
sorted_index = sorted(unsort_dict, key=lambda x: float(unsort_dict[x][key]),
reverse=reverse)
......@@ -161,7 +154,6 @@ def cart_dist(vectx, vecty):
Returns
-------
"""
return np.sqrt(sum(np.power(vectx - vecty, 2)))
......@@ -178,46 +170,42 @@ def format_str(string):
Returns
-------
"""
if re.search(r"^\s*(true)\s*$", string, re.I):
return True
elif re.search(r"^\s*(false)\s*$", string, re.I):
if re.search(r"^\s*(false)\s*$", string, re.I):
return False
elif re.search(r"^\s*\d+\s*$", string):
if re.search(r"^\s*\d+\s*$", string):
return int(string)
elif re.search(r"^[\s\d-]+\.\d+\s*$", string):
if re.search(r"^[\s\d-]+\.\d+\s*$", string):
return float(string)
elif re.search(r'^".+"$', string):
if re.search(r'^".+"$', string):
# remove " characters
return string[1:-1]
elif "," in string:
if "," in string:
return string.split(',')
elif "+" in string:
if "+" in string:
return string.split('+')
elif "/" in string and os.path.exists(string):
if "/" in string and os.path.exists(string):
return os.path.abspath(string)
elif re.search(r"[/\w]+", string):
if re.search(r"[/\w]+", string):
return string
else:
if string:
try:
ev_str = ast.literal_eval(string)
except ValueError:
LOG.error("Don't understand given string %s. Please check "
"format.", string)
return None
except SyntaxError:
LOG.error("Given string %s is not a valid expression", string)
return None
return ev_str
else:
if string:
try:
ev_str = ast.literal_eval(string)
except ValueError:
LOG.error("Don't understand given string %s. Please check "
"format.", string)
return None
except SyntaxError:
LOG.error("Given string %s is not a valid expression", string)
return None
return ev_str
return None
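A few illustrative calls showing how format_str coerces settings strings; the expected values follow the regex branches above and are not taken from the repository's tests.

# Illustrative expectations, matching the branches of format_str.
assert format_str(" true ") is True
assert format_str("42") == 42
assert format_str("3.14") == 3.14
assert format_str('"quoted"') == "quoted"
assert format_str("a,b,c") == ["a", "b", "c"]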
def format_dict(indict):
"""
Parameters
----------
......@@ -227,7 +215,6 @@ def format_dict(indict):
Returns
-------
"""
for key in indict:
if isinstance(indict[key], str):
......@@ -237,7 +224,6 @@ def format_dict(indict):
def ppdict(indict, indent=2):
"""
Parameters
----------
......@@ -249,7 +235,6 @@ def ppdict(indict, indent=2):
Returns
-------
"""
return json.dumps(indict, indent=indent)
......@@ -272,7 +257,6 @@ def tickmin(pandata, ntick=None, shift=0, offset=5):
Returns
-------
"""
yticks = [_ + shift for _ in range(0, len(pandata.index))]
xticks = [_ + shift for _ in range(0, len(pandata.columns))]
......@@ -310,14 +294,13 @@ def tickrot(axes, figure, rotype='horizontal', xaxis=True, yaxis=True):
yaxis :
return: (Default value = True)
figure :
xaxis :
(Default value = True)
Returns
-------
"""
if yaxis:
art.setp(axes.get_yticklabels(), rotation=rotype)
......@@ -352,7 +335,6 @@ class CustomLogging(object):
def update_msg(self, desc):
"""
Parameters
----------
......@@ -362,7 +344,6 @@ class CustomLogging(object):
Returns
-------
"""
if isinstance(self.msg, list):
self.msg += desc
......@@ -390,7 +371,6 @@ class CustomLogging(object):
Returns
-------
"""
outdir = os.path.join(outdir,
"log") if "log" not in outdir else outdir
......@@ -507,7 +487,6 @@ class CommandProtocol(object):
Returns
-------
"""
raise NotImplementedError
......@@ -522,12 +501,10 @@ class NotDisordered(Select):
Parameters
----------
atom:
Returns
-------
"""
return not atom.is_disordered() or atom.get_altloc() == 'A'
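A hedged usage sketch for NotDisordered: writing out a single-conformer structure with Bio.PDB, keeping only non-disordered atoms or the 'A' altloc. The file names are placeholders, not paths used by this project.

from Bio.PDB import PDBParser, PDBIO

# Parse a structure, then save it through the NotDisordered filter.
structure = PDBParser(QUIET=True).get_structure("prot", "input.pdb")
io = PDBIO()
io.set_structure(structure)
io.save("output.pdb", select=NotDisordered())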
......
......@@ -16,7 +16,7 @@ import json
import re
import textwrap
import aria.legacy.AminoAcid as AminoAcid
from .base import get_filename
from .common import get_filename
from .protein import Protein
from .reader import MapFileListReader, TblDistFile
from .protmap import (ResAtmMap, ResMap)
......
......@@ -7,7 +7,7 @@ from __future__ import absolute_import, division, print_function
import sys
import json
import logging
from .base import get_filename
from .common import get_filename
from .reader import MapFileListReader
from .protmap import MapFilter
from .protein import Protein
......
......@@ -17,7 +17,7 @@ from future.utils import iteritems
from collections import defaultdict, OrderedDict
# from .base import ppdict
from .base import Capturing
from .common import Capturing
from .reader import CulledPdbFile
from .protmap import ResAtmMap
......
......@@ -10,7 +10,7 @@ import shutil
import logging
from aria.legacy.QualityChecks import QualityChecks
from .base import CommandProtocol
from .common import CommandProtocol
LOG = logging.getLogger(__name__)
......
......@@ -4,10 +4,14 @@
"""
import os
import time
import pickle
import logging
import itertools
import numpy as np
import pandas as pd
import progressbar
import multiprocessing
import sklearn.mixture as mixture
from .protmap import SsAaAtmMap
......@@ -17,6 +21,87 @@ from aria.legacy.AminoAcid import AminoAcid
LOG = logging.getLogger(__name__)
# TODO: use pathos (http://stackoverflow.com/a/21345308)
# The code below is adapted from an answer by klaus se on Stack Overflow
# (http://stackoverflow.com/a/16071616)
def worker(f, task_queue, done_queue):
"""
Worker process. Pull items from task_queue until the (None, None)
sentinel is received. Each item is a 2-tuple (jobid, data); the result
(jobid, f(data)) is put on done_queue.
Parameters
----------
f: function object
Function applied to each piece of data
task_queue: multiprocessing.Queue object
Queue of tasks to process
done_queue: multiprocessing.Queue object
Queue collecting (jobid, result) tuples for completed tasks
Returns
-------
None
"""
while True:
# Remove and return an item from the queue. Block if necessary until an
# item is available.
i, x = task_queue.get()
if i is None:
break
done_queue.put((i, f(x)))
def parmap(f, X, nprocs=multiprocessing.cpu_count()):
"""
Map a function over the elements of X in parallel using the
multiprocessing module
Parameters
----------
f: function object
Function to run in parallel; accepts a single argument, an element of X
X: list
Input data; each element is submitted to f
nprocs: int
Number of parallel processes
Returns
-------
list
f(x) for each x in X, in the same order as X
"""
# Create queues
task_queue = multiprocessing.Queue(1)
done_queue = multiprocessing.Queue()
bar = progressbar.ProgressBar()
# Initialize the worker processes (nprocs of them, one per CPU by default)
proc = [multiprocessing.Process(target=worker, args=(f, task_queue, done_queue))
for _ in range(nprocs)]
# Start worker processes
for p in proc:
p.daemon = True
p.start()
# Send one task per element of X by putting (jobid, data) tuples on the
# task queue
LOG.debug("Submitting %d tasks", len(X))
sent = [task_queue.put((i, x)) for i, x in enumerate(X)]
# One (None, None) sentinel per process tells each worker it has
# reached the end of the sequence of items
LOG.debug("Sending stop sentinels to the workers")
[task_queue.put((None, None)) for _ in range(nprocs)]
# Get all the results
print("test3", len(sent))
res = [done_queue.get() for _ in bar(range(len(sent)))]
# Wait for all the processes to finish
LOG.debug("Waiting for the workers to finish")
[p.join() for p in proc]
return [x for i, x in sorted(res)]
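A minimal usage sketch for the parmap helper above. This is illustrative, not part of the commit: "square" is a placeholder function, and the sketch assumes a module-level function (required for pickling under the "spawn" start method; the pathos TODO above targets the same limitation for lambdas).

def square(x):
    return x * x

if __name__ == "__main__":
    # Results are sorted by job id, so they come back in input order.
    assert parmap(square, list(range(20)), nprocs=4) == [x * x for x in range(20)]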
class PDBStat(object):
"""
Analyze distance distributions (input should be extracted beforehand
with the pdbdist command)
......@@ -172,6 +257,12 @@ class PDBStat(object):
return df, best_gmm, best_logmm
@staticmethod
def test(name):
"""Dummy task used to smoke-test parmap: sleep briefly, then print."""
time.sleep(5)
print("done")
# print(name)
def run(self):
"""
Command line method
......@@ -181,12 +272,14 @@ class PDBStat(object):
"""
minsize = int(self.settings.pdbstat.config.get("sample_minsize", 20))
ncpus = int(self.settings.pdbstat.args.get("njobs", 1))
LOG.info("Reading csv matrix")
# Read pdb distance file
dists = pd.read_csv(self.settings.pdbstat.args["pdbdists"],
low_memory=False)
# dists = pd.read_csv(self.settings.pdbstat.args["pdbdists"],
# low_memory=False)
# LOG.debug(dists.head())
# Initialize target distance maps
inter_lowerbounds, inter_targetdists, inter_upperbounds = [
SsAaAtmMap() for _ in range(3)]
......@@ -203,72 +296,84 @@ class PDBStat(object):
for foo in ("lowerbounds", "targetdists", "upperbounds")]
# Iterate over the product of all groups
# Since the target map is a square matrix, we only compute half of
# the dataframe => build a set of unique, sorted index pairs covering
# that half
groups = inter_targetdists.index.get_values().tolist(), \
inter_targetdists.columns.get_values().tolist()
groups = set(itertools.product(*groups))
groups = [tuple(group) for group in map(sorted, groups)]
LOG.info("Computing gaussian mixtures")
for ss1, ssgroup1 in inter_targetdists.groupby(level='SecStruct'):
for res1, resgroup1 in ssgroup1.groupby(level='AminoAcid'):
for atm1, atmgroup1 in resgroup1.groupby(level='Atom'):
for ss2, ssgroup2 in inter_targetdists.groupby(level='SecStruct'):
for res2, resgroup2 in ssgroup2.groupby(
level='AminoAcid'):
for atm2, atmgroup2 in resgroup2.groupby(
level='Atom'):
ss_type = self.ss_type(ss1, ss2, intra=False)
subdists = self.subdist(
dists, ss_type, res1, res2, atm1,
atm2).as_matrix(columns=["dist"])
if subdists.size > minsize:
# Get mixture models related to log distribution
logmm = self.gmm_analysis(subdists)[2]
mixt_means = np.exp(logmm.means_)
inter_lowerbounds[(ss1, res1, atm1)][
(ss2, res2, atm2)] = np.min(subdists)
inter_lowerbounds[(ss2, res2, atm2)][
(ss1, res1, atm1)] = np.min(subdists)
inter_targetdists[(ss1, res1, atm1)][
(ss2, res2, atm2)] = np.min(mixt_means)
inter_targetdists[(ss2, res2, atm2)][
(ss1, res1, atm1)] = np.min(mixt_means)
inter_upperbounds[(ss1, res1, atm1)][
(ss2, res2, atm2)] = np.max(subdists)
inter_upperbounds[(ss2, res2, atm2)][
(ss1, res1, atm1)] = np.max(subdists)
if ss1 == ss2:
# Also save intra ss distance
ss_type = self.ss_type(ss1, ss2)
subdists = self.subdist(
dists, ss_type, res1, res2, atm1,
atm2).as_matrix(columns=["dist"])
if subdists.size > minsize:
# Only if we have more than minsize (default 20) samples
# Get mixture models related to log distribution
logmm = self.gmm_analysis(subdists)[2]
mixt_means = np.exp(logmm.means_)
intra_lowerbounds[(ss1, res1, atm1)][
(ss2, res2, atm2)] = np.min(subdists)
intra_lowerbounds[(ss2, res2, atm2)][
(ss1, res1, atm1)] = np.min(subdists)
intra_targetdists[(ss1, res1, atm1)][
(ss2, res2, atm2)] = np.min(mixt_means)
intra_targetdists[(ss2, res2, atm2)][
(ss1, res1, atm1)] = np.min(mixt_means)
intra_upperbounds[(ss1, res1, atm1)][
(ss2, res2, atm2)] = np.max(subdists)
intra_upperbounds[(ss2, res2, atm2)][
(ss1, res1, atm1)] = np.max(subdists)
LOG.info("Serializing target distance maps")
pickle.dump(inter_lowerbounds.to_dict(),
open(interfiles[0], "wb"))
pickle.dump(inter_targetdists.to_dict(),
open(interfiles[1], "wb"))
pickle.dump(inter_upperbounds.to_dict(),
open(interfiles[2], "wb"))
pickle.dump(intra_lowerbounds.to_dict(),
open(intrafiles[0], "wb"))
pickle.dump(intra_targetdists.to_dict(),
open(intrafiles[1], "wb"))
pickle.dump(intra_upperbounds.to_dict(),
open(intrafiles[2], "wb"))
parmap(self.test, groups[:10], nprocs=ncpus)
# self.pool.map(test, groups)
# for ss1, ssgroup1 in inter_targetdists.groupby(level='SecStruct'):
# for res1, resgroup1 in ssgroup1.groupby(level='AminoAcid'):
# for atm1, atmgroup1 in resgroup1.groupby(level='Atom'):
# for ss2, ssgroup2 in inter_targetdists.groupby(level='SecStruct'):
# for res2, resgroup2 in ssgroup2.groupby(
# level='AminoAcid'):
# for atm2, atmgroup2 in resgroup2.groupby(
# level='Atom'):
# ss_type = self.ss_type(ss1, ss2, intra=False)
# subdists = self.subdist(
# dists, ss_type, res1, res2, atm1,
# atm2).as_matrix(columns=["dist"])
# if subdists.size > minsize:
# # Get mixture models related to log distribution
# logmm = self.gmm_analysis(subdists)[2]
# mixt_means = np.exp(logmm.means_)
#
# inter_lowerbounds[(ss1, res1, atm1)][
# (ss2, res2, atm2)] = np.min(subdists)
# inter_lowerbounds[(ss2, res2, atm2)][
# (ss1, res1, atm1)] = np.min(subdists)
# inter_targetdists[(ss1, res1, atm1)][
# (ss2, res2, atm2)] = np.min(mixt_means)
# inter_targetdists[(ss2, res2, atm2)][
# (ss1, res1, atm1)] = np.min(mixt_means)
# inter_upperbounds[(ss1, res1, atm1)][
# (ss2, res2, atm2)] = np.max(subdists)
# inter_upperbounds[(ss2, res2, atm2)][
# (ss1, res1, atm1)] = np.max(subdists)
#
# if ss1 == ss2:
# # Also save intra ss distance
# ss_type = self.ss_type(ss1, ss2)
# subdists = self.subdist(
# dists, ss_type, res1, res2, atm1,
# atm2).as_matrix(columns=["dist"])
# if subdists.size > minsize:
# # IIF we have more than 20 samples
# # Get mixture models related to log distribution
# logmm = self.gmm_analysis(subdists)[2]
# mixt_means = np.exp(logmm.means_)
#
# intra_lowerbounds[(ss1, res1, atm1)][
# (ss2, res2, atm2)] = np.min(subdists)
# intra_lowerbounds[(ss2, res2, atm2)][
# (ss1, res1, atm1)] = np.min(subdists)
# intra_targetdists[(ss1, res1, atm1)][
# (ss2, res2, atm2)] = np.min(mixt_means)
# intra_targetdists[(ss2, res2, atm2)][
# (ss1, res1, atm1)] = np.min(mixt_means)
# intra_upperbounds[(ss1, res1, atm1)][
# (ss2, res2, atm2)] = np.max(subdists)
# intra_upperbounds[(ss2, res2, atm2)][
# (ss1, res1, atm1)] = np.max(subdists)
#
# LOG.info("Serializing target distance maps")
# pickle.dump(inter_lowerbounds.to_dict(),
# open(interfiles[0], "wb"))
# pickle.dump(inter_targetdists.to_dict(),
# open(interfiles[1], "wb"))
# pickle.dump(inter_upperbounds.to_dict(),
# open(interfiles[2], "wb"))
# pickle.dump(intra_lowerbounds.to_dict(),
# open(intrafiles[0], "wb"))
# pickle.dump(intra_targetdists.to_dict(),
# open(intrafiles[1], "wb"))
# pickle.dump(intra_upperbounds.to_dict(),
# open(intrafiles[2], "wb"))
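For reference, the per-pair statistic computed in the loop above (fit mixture models on log-distances, derive lower/target/upper bounds) is the natural unit of work to hand to parmap once the parallelization matures. The sketch below is an assumption-laden illustration: distance_bounds and max_components are invented names, and BIC-based model selection only approximates what gmm_analysis does.

import numpy as np
from sklearn.mixture import GaussianMixture

def distance_bounds(subdists, max_components=3):
    """Return (lower, target, upper) bounds for one distance sample."""
    logdists = np.log(subdists).reshape(-1, 1)
    # Keep the candidate mixture with the lowest BIC.
    best = min(
        (GaussianMixture(n_components=k).fit(logdists)
         for k in range(1, max_components + 1)),
        key=lambda gmm: gmm.bic(logdists),
    )
    # Target distance: smallest mixture mean, mapped back from log space.
    target = float(np.min(np.exp(best.means_)))
    return float(np.min(subdists)), target, float(np.max(subdists))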
......@@ -14,7 +14,7 @@ import pkg_resources as pkgr
import aria.legacy.SequenceList as SequenceList
import aria.legacy.AminoAcid as AmnAcd
from six import iteritems, text_type
from .base import (reg_load, ppdict)
from .common import (reg_load, ppdict)
# import skbio.Protein as skprot
# TODO: interface skbio ??
......
......@@ -29,7 +29,7 @@ import aria.legacy.AminoAcid as AminoAcid
from matplotlib import pyplot as plt
from matplotlib.lines import Line2D
from .base import (tickmin, tickrot, titleprint, addtup)
from .common import (tickmin, tickrot, titleprint, addtup)
from .ndconv import net_deconv
......
......@@ -11,7 +11,7 @@ import os.path