# coding=utf-8
"""
Created on 4/7/17

@author: fallain
"""
# Standard library
import re
import os
import logging

# Third party
import numpy as np

from glob import glob
from collections import OrderedDict
from Bio.PDB import PDBParser, PDBIO

# ARIA framework and local package
from aria.AriaXML import AriaXMLPickler
from aria.SuperImposer import SuperImposer
from .converter import AriaEcXMLConverter
from .base import NotDisordered, Capturing
from aria.DataContainer import DATA_SEQUENCE
from aria.StructureEnsemble import StructureEnsemble, StructureEnsembleSettings

# Module-level logger, named after the module per logging convention
LOG = logging.getLogger(__name__)
class EnsembleAnalysis(object):
    """ARIA extended ensemble analysis"""

    def __init__(self, settings):
        """Initialize the analysis with the command-line/project settings.

        Parameters
        ----------
        settings :
            Settings container; `run` reads `settings.analysis.args`
            (keys: "project", "ref", "output_directory", "listname",
            "iteration").
        """
        self.settings = settings
    @staticmethod
    def _get_pdblist(iteration_path):
        """Collect the PDB structure list(s) to analyse for one iteration.

        If a ``report.clustering`` file exists in `iteration_path`, one list
        of PDB paths is returned per ``*clust*.list`` file found there
        (lowest-energy cluster ensembles). Otherwise a single list with every
        ``*.pdb`` file in the directory is returned; ``fitted_*`` files
        (produced by a previous superimposition) are skipped.

        Parameters
        ----------
        iteration_path : str
            Path to an ARIA iteration directory (e.g. ``.../it8``).

        Returns
        -------
        list of list of str
            One inner list of PDB file paths per structure ensemble.
        """
        # Get the list of all the generated structures in iteration path if no
        # clustering. Otherwise, get the list of the lowest energy cluster
        if os.path.exists(os.path.join(iteration_path, "report.clustering")):
            # Get the list of pdb for each cluster
            # NOTE: the original message read "for""each" with no separating
            # space between the implicitly concatenated string literals.
            LOG.info("Clusters found in this iteration, compute analysis for "
                     "each generated cluster ensemble")
            list_of_pdb = []
            for clustlist in glob(os.path.join(iteration_path, '*clust*.list')):
                with open(clustlist) as cluster:
                    list_of_pdb.append(cluster.read().splitlines())
        else:
            # no clustering, pdb list correspond to all generated structures
            LOG.info("No cluster found in this iteration, compute analysis for "
                     "iteration ensemble")
            list_of_pdb = [
                [foo for foo in glob(os.path.join(iteration_path, "*.pdb"))
                 if not os.path.basename(foo).startswith('fitted_')]]

        LOG.debug("Lists of structures:\n%s", "\n".join(
            [str(_) for _ in list_of_pdb]))
        return list_of_pdb
    @staticmethod
    def violation_analysis(project, iteration_id, restraints, ensemble, out_file,
                           dists_ref=None, headerflag=True):
        """Write a per-contribution restraint-violation report as CSV.

        For every restraint in every restraint list, the r^-6-averaged
        effective distance of each contribution is computed over the
        structure ensemble (and, optionally, in a reference structure),
        violation statistics are derived, and one CSV row is written per
        contribution to `out_file`.

        Parameters
        ----------
        project :
            ARIA project object; provides the protein name and the
            iteration settings (ensemble size, violation tolerance).
        iteration_id : int
            Index of the iteration whose settings are used.
        restraints :
            Iterable of restraint lists (ARIA restraint objects).
        ensemble :
            StructureEnsemble providing ``getDistances(atom1, atom2)``.
        out_file : str
            Path of the CSV file to write.
        dists_ref : callable, optional
            ``getDistances``-like callable for a reference (native)
            structure; when None, reference columns are left empty.
            (Default value = None)
        headerflag : bool, optional
            Write the CSV header line before the first row, then reset.
            (Default value = True)

        Returns
        -------
        None
            Results are written to `out_file`.
        """
        protein_id = project.getSettings()['name']
        nbest = project.getProtocol().getIterationSettings(iteration_id)[
            "number_of_best_structures"]
        cutoff = project.getProtocol().getIterationSettings(iteration_id)[
            "violation_analyser_settings"]["violation_tolerance"]

        with open(out_file, 'w') as out:

            for restraintlist in restraints:

                for rest in restraintlist:

                    output = []
                    dd = []     # per-contribution effective distances (ensemble)
                    ddref = []  # same for the reference structure

                    for contrib in rest.getContributions():

                        dist = None
                        try:
                            dist = [ensemble.getDistances(*sp.getAtoms())
                                    for sp in contrib]
                            # Distances for contribution c in each structure of
                            # the ensemble (one or several distances per
                            # structure depending on whether the contribution
                            # is ambiguous or not)
                            dist = np.power(
                                np.sum(np.power(dist, - 6.), axis=0), -1. / 6)
                            # Effective distances per structure of the ensemble
                            # (one value per structure). If the contribution is
                            # unambiguous (a single spin pair), this list should
                            # be identical to the previous one
                            dd.append(dist)
                        # NOTE(review): broad catch; also logs the Exception
                        # class itself rather than the caught type
                        except Exception as msg:
                            LOG.warning("%s: %s" % (Exception, msg))
                            pass

                        dref = None
                        if dists_ref:
                            try:
                                dref = [dists_ref(*sp.getAtoms()) for sp in contrib]
                                dref = np.power(np.sum(np.power(dref, -6.), axis=0),
                                                -1. / 6)
                                ddref.append(dref)
                            except Exception as msg:
                                LOG.warning("%s: %s" % (Exception, msg))
                                pass

                        # One CSV row per contribution; OrderedDict preserves
                        # the column order used for the header
                        tmp = OrderedDict()
                        tmp['protein'] = protein_id
                        tmp['data'] = rest.getReferencePeak().getSpectrum().getName()
                        tmp['iteration'] = iteration_id
                        tmp['ens_size'] = nbest
                        tmp['rest_no'] = rest.getId()
                        tmp['contrib_no'] = contrib.getId()
                        # Assuming there is only one spin per contribution
                        tmp['resid_1'] = contrib[0][0].getResidue().getNumber()
                        tmp['resid_2'] = contrib[0][1].getResidue().getNumber()
                        tmp['res_1'] = contrib[0][0].getResidue().getName()[:3]
                        tmp['res_2'] = contrib[0][1].getResidue().getName()[:3]
                        tmp['atm_1'] = contrib[0][0].getName()
                        tmp['atm_2'] = contrib[0][1].getName()
                        tmp['viol_cutoff'] = cutoff
                        tmp['d_target'] = rest.getDistance()
                        tmp['lower_bound'] = rest.getLowerBound() - cutoff
                        tmp['upper_bound'] = rest.getUpperBound() + cutoff
                        tmp['rest_weight'] = rest.getWeight()
                        tmp['dc_min'] = np.min(dist) if dist is not None else None
                        tmp['dc_avg'] = np.mean(dist) if dist is not None else None  # mean distance of contribution c over the ensemble
                        tmp['dc_med'] = np.median(dist) if dist is not None else None
                        tmp['dc_ref'] = np.mean(dref) if dref is not None else None

                        output.append(tmp)

                    # Effective distances (when several spin pairs per
                    # contribution) associated with the restraint
                    dd = np.array(dd)
                    ddref = np.array(ddref)  # same for the native structure

                    dd_eff = np.power(np.sum(np.power(dd, -6), axis=0), -1. / 6)
                    ddref_eff = np.power(np.sum(np.power(ddref, -6), axis=0), -1. / 6)

                    for x in range(len(output)):
                        # Mean of the effective distances over the structure
                        # ensemble. Can be biased by structures with abnormally
                        # large distances; in that case we cannot tell whether
                        # the restraint was correctly removed in the studied
                        # ensemble.
                        output[x]['Deff_min'] = np.min(dd_eff)
                        output[x]['Deff_avg'] = np.mean(dd_eff)
                        output[x]['Deff_med'] = np.median(dd_eff)
                        output[x]['Deff_sdev'] = np.std(dd_eff)
                        output[x]['Deff_ref'] = np.mean(ddref_eff)
                        # Fraction of ensemble structures violating the
                        # tolerance-extended upper bound
                        output[x]['pc_viol'] = float(
                            np.sum(np.greater(dd_eff, rest.getUpperBound() + cutoff))) / nbest
                        # Normally d_ref holds a single element since there is
                        # only one reference structure (to be checked):
                        # native effective distance vs. minimum effective
                        # distance in the ensemble
                        output[x]['viol'] = True if \
                            output[x]['pc_viol'] >= 0.5 else False
                        # Restraint considered valid if the effective distance
                        # in the reference structure lies within the violation
                        # bounds.
                        output[x]['valid'] = True if \
                            output[x]['upper_bound'] >= output[x]['dc_ref'] >= output[x]['lower_bound'] else False
                        output[x]['contact_5'] = True if \
                            output[x]['dc_ref'] <= 5.0 else False
                        output[x]['contact_8'] = True if \
                            output[x]['dc_ref'] <= 8.0 else False

                        # Classify: true positive (VP) if a contact in the
                        # reference (<= 8 A), false positive (FP) otherwise;
                        # "viol" suffix when violated in the ensemble
                        if output[x]['contact_8'] and output[x]['viol']:
                            output[x]['group'] = 'VP viol'
                        elif output[x]['contact_8'] and not output[x]['viol']:
                            output[x]['group'] = 'VP'
                        elif not output[x]['contact_8'] and output[x]['viol']:
                            output[x]['group'] = 'FP viol'
                        else:
                            output[x]['group'] = 'FP'

                        # CSV header written once, before the first data row
                        if headerflag:
                            out.write(",".join(output[0].keys()))
                            headerflag = False

                        out.write("\n" + ",".join(["%s" % output[x][k]
                                                   for k in output[x].keys()]))

        LOG.info("Wrote %s file", out_file)
    @staticmethod
    def pca_projection(iter_dir, ensemble, molecule, atmask="CA", nbest=15):
        """Superimpose the ensemble on a backbone-atom mask and flatten the
        fitted coordinates for dimensionality-reduction methods (PCA,
        k-means, ...).

        NOTE(review): this routine looks unfinished -- the flattened fitted
        coordinates are computed but never returned (see the commented-out
        return at the end); only the ensemble settings mutation is an
        observable effect.

        Parameters
        ----------
        iter_dir :
            Iteration directory (currently unused; kept for interface
            compatibility).
        ensemble :
            StructureEnsemble to superimpose; its settings are mutated.
        molecule :
            ARIA molecule providing chains/residues/atoms for the mask.
        atmask : str, optional
            Atom name used to build the fitting mask.
            (Default value = "CA")
        nbest : int, optional
            Number of best structures kept in the ensemble settings after
            fitting. Was a hard-coded 15; now parameterized with the same
            default so existing callers are unaffected.
            (Default value = 15)

        Returns
        -------
        None
        """
        # Atom ids of every atom named `atmask` (one per residue for CA)
        mask = [a.getId() for c in molecule.get_chains() for r in c.getResidues()
                for a in r.getAtoms() if a.getName() == atmask]

        # Align all structures on ca backbone
        si = SuperImposer(ensemble, molecule)
        si.getSettings()['number_of_best_structures'] = 'all'
        # NOTE(review): relies on SuperImposer's private _fit method
        si._fit(mask)

        fitcoords = si.getFittedCoordinates()
        fitcoords = np.take(fitcoords, mask, axis=1)

        ns, na, xyz = fitcoords.shape
        # Change the shape of coords matrix in order to use pca, kmeans, ...
        # (n_structures, n_atoms * 3), reshaped in place
        fitcoords.shape = ns, na * xyz

        ensemble.getSettings()['number_of_best_structures'] = nbest

        # return fitcoords, infos
    def run(self):
        """Execute the violation analysis for the selected iteration.

        Loads the ARIA project and molecule, resolves which iteration's
        restraints/ensemble to analyse (previous iteration when not it0),
        picks the lowest-energy ensemble, optionally loads a native
        reference structure, and writes ``violations.csv`` in the output
        directory.
        """
        # Args
        project_path = self.settings.analysis.args["project"]
        # restraints_path = self.settings.analysis.args["restraints"]
        native_path = self.settings.analysis.args.get("ref")
        out_path = self.settings.analysis.args["output_directory"]
        list_name = self.settings.analysis.args["listname"]

        # Aria objects
        pickler = AriaXMLPickler()
        project = pickler.load(project_path)
        molecule_path = project.getData(DATA_SEQUENCE)[0].getLocation()[0]
        molecule = pickler.load(molecule_path)

        # If we are at the first iteration, we select the related ensemble and
        # restraints, otherwise we take the ensemble from the previous iteration
        iteration_path = self.settings.analysis.args["iteration"]
        iteration_id = int(re.search(
            '[0-9]+', os.path.basename(iteration_path)).group(0))

        if iteration_id != 0:
            iteration_path = os.path.join(
                os.path.dirname(iteration_path), "it%d" % (iteration_id - 1))
            LOG.info("Ensemble analysis will be done on restraints and "
                     "ensemble from it%d with violation criteria of it%d",
                     (iteration_id - 1), iteration_id)
            if not os.path.exists(iteration_path):
                LOG.error("Cannot find previous iteration (%s)",
                          iteration_path)

        restraints = glob(os.path.join(iteration_path, '*restraints.xml'))
        if not restraints:
            # Load tbl restraints and convert them into aria xml format
            restraints = glob(os.path.join(iteration_path, '*.tbl'))
            restraints = AriaEcXMLConverter.tbl2xml(
                iteration_path, molecule_path, restraints, list_name)
        restraints = [pickler.load(restraint).restraint
                      for restraint in restraints]

        # Setup
        protein_id = project.getSettings()['name']
        nbest = project.getProtocol().getIterationSettings(iteration_id)[
            "number_of_best_structures"]
        sort_crit = project.getProtocol().getIterationSettings(iteration_id)[
            "sort_criterion"]

        se_settings = StructureEnsembleSettings()
        se_settings['sort_criterion'] = sort_crit
        se_settings['number_of_best_structures'] = nbest

        # Get list of pdb related to structure ensemble(s)
        list_of_pdb = self._get_pdblist(iteration_path)

        # Read structure ensembles, capturing their verbose stdout
        LOG.info("Reading structure ensemble(s)")
        with Capturing() as output:
            ensembles = [
                StructureEnsemble(se_settings) for _ in list_of_pdb]
            # Plain loop instead of side-effect list comprehensions
            for pdb_list, ens in zip(list_of_pdb, ensembles):
                ens.read(pdb_list, molecule, format='cns')
                ens.sort()
            if output:
                LOG.info(output)

        # Get the lowest energy ensemble. Incomplete ensembles (fewer than
        # nbest structures) get None, which becomes NaN on the float array.
        LOG.info("Sorting structure ensemble(s) with %s criteria", sort_crit)
        # dtype=float replaces the deprecated np.float alias (removed in
        # NumPy 1.24)
        energies = np.array([
            np.mean([d['total_energy'] for d in ens.getInfo()[:, 1]][:nbest])
            if len(ens) >= nbest else None for ens in ensembles], dtype=float)
        # nanargmin skips the NaN entries that np.argmin would select
        ensemble = ensembles[np.nanargmin(energies)]

        # Get native structure
        # # Issue with several pdb files which have the same residue_number in
        #  atm and hetatm sections ... We remove them with bio pdb
        dists_ref = None
        if native_path:
            LOG.info("Reading native structure")
            logging.captureWarnings(True)
            native = PDBParser().get_structure(protein_id, native_path)
            native_path = os.path.join(out_path, protein_id + "_ordered.native.pdb")
            io = PDBIO()
            io.set_structure(native)
            io.save(native_path, select=NotDisordered())

            native = StructureEnsemble(se_settings)
            native.display_warnings = 0
            native.read([native_path], molecule, format='cns')

            dists_ref = native.getDistances

        # We get here the distance of 'number_of_best_structures' in the
        #  ensemble
        LOG.info("Violation analysis")
        out_file = os.path.join(out_path, 'violations.csv')
        self.violation_analysis(project, iteration_id, restraints, ensemble,
                                out_file, dists_ref=dists_ref)

        infos = [inf for inf in ensemble.getInfo()]
        # Debug output (was bare print calls)
        LOG.debug("%d structure info records, %d pdb files in first list",
                  len(infos), len(list_of_pdb[0]))

        # Read cluster membership lists, closing each file properly
        clustlists = []
        for listpath in glob(os.path.join(iteration_path, '*_clust*.list')):
            with open(listpath) as listfile:
                clustlists.append(listfile.readlines())

        # TODO(review): annotating `infos` with cluster labels / DSSP accuracy
        # from `clustlists` was stubbed out here (commented scaffolding
        # removed) and remains unimplemented.