Commit e11080d5 authored by Fabrice  ALLAIN's avatar Fabrice ALLAIN
Browse files

PCA projection done in ensemble analysis part

parent 84f96cf6
......@@ -28,6 +28,65 @@ from aria.StructureEnsemble import StructureEnsemble, StructureEnsembleSettings
LOG = logging.getLogger(__name__)
def colscatter(X, axe, colors, ndim=2, axtitle="", xlabel="x",
               ylabel="y", zlabel="z", legend_prefix="", others=False):
    """
    Scatter plot with palette colors

    Plot the first ``ndim`` columns of ``X`` on ``axe``, colour every point
    according to its label in ``colors`` and add one legend entry per
    distinct label.

    Parameters
    ----------
    X :
        Array indexed as ``X[:, i]``; column ``i`` holds the coordinates
        plotted along the i-th axis
    axe :
        Axes object receiving the scatter plot, legend, title and axis
        labels (must provide ``set_zlabel`` when ``ndim`` is 3)
    colors :
        Sequence of hashable, sortable labels, one per row of ``X``. Labels
        do not need to be contiguous nor to start at 0
    ndim :
        Number of leading columns of ``X`` to plot (Default value = 2)
    axtitle :
        Title of the axes (Default value = "")
    xlabel :
        Label of the x axis (Default value = "x")
    ylabel :
        Label of the y axis (Default value = "y")
    zlabel :
        Label of the z axis, only used when ``ndim`` is 3
        (Default value = "z")
    legend_prefix :
        Text prepended to each label in the legend (Default value = "")
    others :
        If True, the legend entry of the highest label is named "Others"
        instead of ``legend_prefix`` + label (Default value = False)

    Returns
    -------
    None
    """
    # Sort the distinct labels so that every label owns a stable palette slot
    # whatever its value; indexing the palette with the raw label only works
    # when labels are exactly 0..k-1.
    labels = sorted(set(colors))
    palette = sns.color_palette("hls", len(labels))
    label_color = dict(zip(labels, palette))

    dims = [X[:, i] for i in range(ndim)]
    # Explicit per-point colours keep the points and the legend handles in
    # sync (no dependence on colormap normalisation of the label values).
    axe.scatter(*dims, c=[label_color[label] for label in colors])
    axe.invert_xaxis()

    # Highest label is the catch-all group when ``others`` is requested
    others_label = labels[-1] if (others and labels) else None
    handles = [
        Line2D([0], [0], linestyle="none", c=label_color[label], marker="o")
        for label in labels]
    names = [
        "Others" if (others and label == others_label)
        else legend_prefix + str(label)
        for label in labels]
    axe.legend(handles, names, numpoints=1)

    axe.set_title(axtitle)
    axe.set_xlabel(xlabel)
    axe.set_ylabel(ylabel)
    if ndim == 3:
        axe.set_zlabel(zlabel)
class EnsembleAnalysis(object):
"""ARIA extended ensemble analysis"""
......@@ -37,6 +96,7 @@ class EnsembleAnalysis(object):
@staticmethod
def _get_ensemble_paths(iteration_path):
"""
Parameters
----------
......@@ -61,7 +121,7 @@ class EnsembleAnalysis(object):
else:
# no clustering, pdb list corresponds to all generated structures
LOG.info("No cluster found in this iteration, compute analysis for"
"iteration ensemble")
" iteration ensemble")
list_of_pdb.append(
[foo for foo in glob(os.path.join(iteration_path, "*.pdb"))
if not os.path.basename(foo).startswith('fitted_')])
......@@ -74,6 +134,7 @@ class EnsembleAnalysis(object):
def violation_analysis(project, iteration_id, restraints, ensemble, out_file,
dists_ref=None, headerflag=True):
"""
Parameters
----------
......@@ -228,14 +289,16 @@ class EnsembleAnalysis(object):
LOG.info("Wrote %s file", out_file)
@staticmethod
def pca_projection(iter_dir, ensemble, molecule, atmask="CA"):
def pca_projection(
ensemble, molecule, infos, atmask="CA",
title="3D PCA projection on backbone coordinates", outfile=None):
"""
PCA projection of ensemble coordinates
Parameters
----------
iter_dir :
title
infos
ensemble :
molecule :
......@@ -248,6 +311,7 @@ class EnsembleAnalysis(object):
"""
sns.set_style('ticks')
pca = PCA(n_components=3)
mask = [a.getId() for c in molecule.get_chains() for r in c.getResidues()
for a in r.getAtoms() if a.getName() == atmask]
......@@ -257,9 +321,9 @@ class EnsembleAnalysis(object):
si.getSettings()['number_of_best_structures'] = 'all'
si._fit(mask)
# Get matrix of coordinates
fitcoords = si.getFittedCoordinates()
fitcoords = np.take(fitcoords, mask, axis=1)
ns, na, xyz = fitcoords.shape
# Change the shape of coords matrix in order to use pca, kmeans, ...
fitcoords.shape = ns, na * xyz
......@@ -269,50 +333,16 @@ class EnsembleAnalysis(object):
fitcoords_reduced = pca.fit_transform(fitcoords)
#
fig = plt.figure(figsize=(30, 20))
ax = fig.add_subplot(221, projection='3d', elev=-150, azim=110)
def colscatter(X, ax, colors, ndim=2, title="", xlabel="x", ylabel="y",
zlabel="z", legend_prefix="", others=False):
"""
Scatter plot with palette colors
"""
palette = sns.color_palette("hls", len(set(colors)))
dims = [X[:, i] for i in range(0, ndim)]
ax.scatter(*dims, c=colors,
cmap=ListedColormap(palette))
ax.set_title(title)
ax.set_xlabel(xlabel)
ax.set_ylabel(ylabel)
if ndim == 3:
ax.set_zlabel(zlabel)
ax.invert_xaxis()
if others:
ax.legend([Line2D([0], [0], linestyle="none",
c=palette[colidx],
marker="o") for colidx in
set(colors)],
[legend_prefix + str(colidx) if colidx != max(
colors) else "Others" for colidx in set(colors)],
numpoints=1)
else:
ax.legend([Line2D([0], [0], linestyle="none",
c=palette[colidx],
marker="o") for colidx in
set(colors)],
[legend_prefix + str(colidx) for colidx in
set(colors)], numpoints=1)
colscatter(fitcoords_reduced, ax,
aria_km2[prot][data][it]['clust_infos'], ndim=3,
title="3D PCA projection on %s C$\\alpha$ coordinates \n%s contacts at iteration %s" % (
prot, data, str(it)),
xlabel="1st eigenvector", ylabel="2nd eigenvector",
zlabel="3rd eigenvector", legend_prefix="Clust ")
# return fitcoords, infos
fig = plt.figure()
ax = Axes3D(fig)
colscatter(
fitcoords_reduced, ax, infos, ndim=3, axtitle=title,
xlabel="Principal component 1", ylabel="Principal component 2",
zlabel="Principal component 3", legend_prefix="Clust ")
if outfile:
plt.savefig(outfile)
def run(self):
"""Execute Ensemble analysis"""
......@@ -373,8 +403,6 @@ class EnsembleAnalysis(object):
# Get list of pdb related to structure ensemble(s)
ens_paths = self._get_ensemble_paths(iteration_path)
# clustlists = [open(listpath).readlines() for listpath in
# glob(os.path.join(iteration_path, '*_clust*.list'))]
with Capturing() as output:
# We define here as many structure ensemble object as number of
......@@ -424,16 +452,25 @@ class EnsembleAnalysis(object):
out_file, dists_ref=dists_ref)
infos = [inf for inf in ensemble.getInfo()]
print(len(infos))
print(len(ens_paths[0]))
# [info[1].update({'clust': idx}) for info in infos for idx, clustlist in
# enumerate(clustlists) if filter(re.compile(info[0]).match, clustlist)]
[info[1].update({'clust': idx}) for info in infos for idx, clustlist in
enumerate(ens_paths) if filter(re.compile(info[0]).match, clustlist)]
# Compute pca projection of cluster labels on all the
# generated structures. Add later an option to visualize extra
# information like RMSD, quality, ... with a csv file
self.pca_projection()
print([info[1].get('clust') for info in infos])
print(len([info[1].get('clust') for info in infos]))
self.pca_projection(
ensemble, molecule, [info[1].get('clust', 0) for info in infos],
atmask=self.settings.analysis.config["atmask"],
title="3D PCA projection on %s backbone coordinates \n%s contacts"
" at iteration %s" % (protein_id, list_name,
str(iteration_id)),
outfile=os.path.join(out_path, "%s_%s_it%s_clusts.3dpca.png" %
(protein_id, list_name, str(iteration_id)))
)
# plt.show()
# with open(os.path.join(iter_dir,
# "analysis/pyfit/accuracydssp/RMSD.dat")) as rmsdfile:
# accdssp = {key: float(value) for key, value in
......
......@@ -24,7 +24,8 @@ LOG = logging.getLogger(__name__)
def addtup(tup, inc=1):
"""Increment all values by 1 in a tuple
"""
Increment all values by 1 in a tuple
Parameters
----------
......@@ -42,7 +43,8 @@ def addtup(tup, inc=1):
def titleprint(outfile, progname='', desc=''):
"""Init log file
"""
Init log file
Parameters
----------
......@@ -68,7 +70,8 @@ def titleprint(outfile, progname='', desc=''):
def get_filename(path):
"""Search filename in the given path
"""
Search filename in the given path
Parameters
----------
......@@ -85,6 +88,7 @@ def get_filename(path):
def reg_load(regex, filepath, sort=None):
"""
Parameters
----------
......@@ -116,7 +120,8 @@ def reg_load(regex, filepath, sort=None):
def sort_2dict(unsort_dict, key, reverse=True):
"""Sort 2d dict by key
"""
Sort 2d dict by key
Parameters
----------
......@@ -143,7 +148,8 @@ def sort_2dict(unsort_dict, key, reverse=True):
def cart_dist(vectx, vecty):
"""Evaluate cartesian distance beetween 2 points x, vecty
"""
Evaluate cartesian distance between 2 points vectx, vecty
Parameters
----------
......@@ -161,7 +167,8 @@ def cart_dist(vectx, vecty):
def format_str(string):
"""Convert str in bool, float, int or str
"""
Convert str in bool, float, int or str
Parameters
----------
......@@ -210,6 +217,7 @@ def format_str(string):
def format_dict(indict):
"""
Parameters
----------
......@@ -229,6 +237,7 @@ def format_dict(indict):
def ppdict(indict, indent=2):
"""
Parameters
----------
......@@ -246,7 +255,8 @@ def ppdict(indict, indent=2):
def tickmin(pandata, ntick=None, shift=0, offset=5):
"""Minimise number of ticks labels for matplotlib or seaborn plot
"""
Minimise number of ticks labels for matplotlib or seaborn plot
Parameters
----------
......@@ -288,7 +298,8 @@ def tickmin(pandata, ntick=None, shift=0, offset=5):
def tickrot(axes, figure, rotype='horizontal', xaxis=True, yaxis=True):
"""Matplot rotation of ticks labels
"""
Matplot rotation of ticks labels
Parameters
----------
......@@ -317,12 +328,11 @@ def tickrot(axes, figure, rotype='horizontal', xaxis=True, yaxis=True):
# TODO: Add another level when we use verbose options instead of displaying debug messages
class CustomLogging(object):
""" """
# default_file = os.path.join(os.path.abspath(os.path.dirname(__file__)),
# "conf/logging.json")
"""
custom logging
Customized python logging config
"""
# default_file = os.path.join(os.path.abspath(os.path.dirname(__file__)),
# "conf/logging.json")
default_file = "conf/logging.json"
def __init__(self, level=logging.INFO, desc=__doc__):
......@@ -342,6 +352,7 @@ class CustomLogging(object):
def update_msg(self, desc):
"""
Parameters
----------
......@@ -368,7 +379,8 @@ class CustomLogging(object):
return config
def set_outdir(self, outdir):
"""Create log directory and change log files location
"""
Create log directory and change log files location
Parameters
----------
......@@ -400,7 +412,12 @@ class CustomLogging(object):
logging.config.dictConfig(self.config)
def welcome(self):
""":return:"""
"""
Returns
-------
"""
desc = '''
================================================================================
......@@ -421,7 +438,9 @@ class Capturing(list):
def __enter__(self):
"""
:return:
Returns
-------
"""
# Stock default stdout and redirect current stdout to this class
self._stdout = sys.stdout
......@@ -434,7 +453,14 @@ class Capturing(list):
def __exit__(self, *args):
"""
:return:
Parameters
----------
args
Returns
-------
"""
self.extend("\n".join(self._stringio.getvalue().splitlines()))
self._stringio.truncate(0)
......@@ -472,8 +498,8 @@ class CommandProtocol(object):
@abstractmethod
def run(self):
"""main method to launch protocol
:return:
"""
Main method to launch protocol
Parameters
----------
......@@ -489,12 +515,13 @@ class CommandProtocol(object):
class NotDisordered(Select):
"""Define an atom as disordered or not in pdb selection"""
def accept_atom(self, atom):
"""Accept or not the atom if it does not correspond to an alternative
"""
Accept or not the atom if it does not correspond to an alternative
location
Parameters
----------
atom :
atom:
Returns
......
......@@ -24,6 +24,7 @@ LOG = logging.getLogger(__name__)
def check_file(prospective_file):
"""
Parameters
----------
......@@ -92,6 +93,7 @@ class AriaEcCommands(object):
def _update_logger(self, log):
"""
Parameters
----------
......@@ -133,7 +135,8 @@ class AriaEcCommands(object):
return parser
def _create_subparsers(self, parser):
"""Generate subcommands
"""
Generate subcommands
Parameters
----------
......@@ -152,7 +155,8 @@ class AriaEcCommands(object):
parser.add_parser(command, parents=[subcommand])
def _setup_argparser(self, desc=None):
"""setup opt & args
"""
setup opt & args
Parameters
----------
......@@ -198,7 +202,8 @@ class AriaEcCommands(object):
return parser
def _bbconv_argparser(self, desc=None):
"""bbconv opt & args
"""
bbconv opt & args
Parameters
----------
......@@ -228,7 +233,8 @@ class AriaEcCommands(object):
return parser
def _maplot_argparser(self, desc=None):
"""maplot opt & args
"""
maplot opt & args
Parameters
----------
......@@ -274,6 +280,7 @@ class AriaEcCommands(object):
@staticmethod
def _pdbqual_argparser(desc=None):
"""
Parameters
----------
......@@ -296,6 +303,7 @@ class AriaEcCommands(object):
@staticmethod
def _analysis_argparser(desc=None):
"""
Parameters
----------
......@@ -324,6 +332,7 @@ class AriaEcCommands(object):
@staticmethod
def _tbl2xml_argparser(desc=None):
"""
Parameters
----------
......@@ -350,6 +359,7 @@ class AriaEcCommands(object):
@staticmethod
def _pdbdist_argparser(desc=None):
"""
Parameters
----------
......
......@@ -272,6 +272,7 @@ remove_pdbs: False
pair_list: min
[analysis]
atmask: CA
violation_treshold: 0.5
nbest_structures: 15
sort_criterion: total_energy
......@@ -92,7 +92,8 @@ class AriaEcBbConverter(object):
@staticmethod
def compute_diversityvalue(msa, seqlen):
"""Compute bbcontacts diversity value
"""
Compute bbcontacts diversity value
Parameters
----------
......@@ -103,16 +104,16 @@ class AriaEcBbConverter(object):
Returns
-------
diversity_score: `float`
diversity_score : `float`
Notes
-----
Diversity score corresponds to the square root of the multiple
sequence alignment length (:math:`m`) over the length of the protein
sequence (:math:`l`)
.. math:: Divscore = \frac{\sqrt{m}}{l}
.. math:: Divscore = \frac{\sqrt{m}}{l}
"""
msa_reg = re.compile(r"^>[A-Za-z0-9]+_[A-Za-z0-9]+")
msalen = 0
......@@ -141,7 +142,8 @@ class AriaXMLConverter(Converter, object):
@property
def molecule(self):
"""aria.Molecule.Molecule object or None. If a structure has been loaded,
"""
aria.Molecule.Molecule object or None. If a structure has been loaded,
use it to update the molecule
Parameters
......@@ -162,7 +164,8 @@ class AriaXMLConverter(Converter, object):
@staticmethod
def upd_mol(molecule, structure):
"""Update molecule object according to pdb structure
"""
Update molecule object according to pdb structure
Parameters
----------
......@@ -183,7 +186,8 @@ class AriaXMLConverter(Converter, object):
pass
def read_seq(self, seqpath):
"""Load aria Molecule object from seq file
"""
Load aria Molecule object from seq file
Parameters
----------
......@@ -240,6 +244,7 @@ class AriaXMLConverter(Converter, object):
def read_pdb(self, pdbpath):
"""
Parameters
----------
......@@ -269,7 +274,8 @@ class AriaXMLConverter(Converter, object):
@staticmethod
def deff(atm_dists, dpow=6):
"""Compute aria effective distance from input distances
"""
Compute aria effective distance from input distances
Parameters
----------
......@@ -290,7 +296,8 @@ class AriaXMLConverter(Converter, object):
# TODO: Use mako
@staticmethod
def write_dist_xml(dist_restraints, outfile):
"""Write aria distance restraint xml file
"""
Write aria distance restraint xml file
Parameters
----------
......@@ -361,6 +368,7 @@ reliable="{reliable}" list_name="{list_name}">
@staticmethod
def _write_helix_hb_tbl(sec_struct, outfile, dminus, dplus):
"""
Parameters
----------
......@@ -393,7 +401,8 @@ assign (resid {res1} and name o) (resid {res2} and name hn) 1.8 {dminus} {dplus
@staticmethod
def _write_hbmap_tbl(hbmap, outfile, dminus, dplus, n_hb=None,
hb_type="main", topo=None):
"""Build hbdond distance restraints from a res-res contact map. Tbl
"""
Build hbond distance restraints from a res-res contact map. Tbl
restraints generated use pseudoatoms since we assume we don't know
which donor/acceptor is involved.
......@@ -411,13 +420,14 @@ assign (resid {res1} and name o) (resid {res2} and name hn) 1.8 {dminus} {dplus
dminus :
n_hb :
(Default value = None)
(Default value = None)
topo :
(Default value = None)
(Default value = None)
Returns
-------
"""
# TODO: test with several deviation since these restraints should
# contain noise !!
......@@ -476,6 +486,7 @@ assign (resid {res1} and name o) (resid {res2} and name hn) 1.8 {dminus} {dplus
def write_hb_tbl(self, protein, outfile, hbmap=None, dminus=0.0,
dplus=0.5, n_hb=None, lr_type='main'):
"""
Parameters
----------
......@@ -490,13 +501,14 @@ assign (resid {res1} and name o) (resid {res2} and name hn) 1.8 {dminus} {dplus
outfile :
dminus :
(Default value = 0.0)
(Default value = 0.0)
n_hb :
(Default value = None)
(Default value = None)