Commit 1cf6e97f authored by fabrice's avatar fabrice

Bug fixe: wrong atom list when pair_list set to all (add terminal atoms)

parent 8d077988
......@@ -15,12 +15,15 @@ import json
import re
import pkg_resources as pkgr
import aria.legacy.AminoAcid as AminoAcid
import aria.ConversionTable as ConversionTable
import aria.conversion
from .base import (Capturing, get_filename)
from .base import get_filename, Capturing
from .protein import Protein
from .reader import ProtFileListReader
from .protmap import (ResAtmMap, ResMap)
from aria.Molecule import Molecule
from aria.tools import string_to_segid
from aria.AriaXML import AriaXMLPickler
from aria.conversion import Converter, SequenceList, MoleculeSettings
logger = logging.getLogger(__name__)
......@@ -90,10 +93,52 @@ class AriaEcBbConverter(object):
return math.sqrt(n) / float(l)
class AriaXMLConverter(object):
def __init__(self):
# TODO: stock here pickle file path for cns donor/ acceptor dict
pass
class AriaXMLConverter(Converter, object):
def __init__(self, settings):
Converter.__init__(self)
self._mol_set = MoleculeSettings()
self._pickler = AriaXMLPickler()
self.outprefix = ""
self.settings = settings
self.molecule = None
def load_molecule(self, seqpath):
self._mol_set['format'] = 'seq'
self._mol_set['input'] = seqpath
self._mol_set['output'] = os.path.join(self.settings.infra["xml"],
self.outprefix + ".xml")
self._mol_set['type'] = 'PROTEIN'
self._mol_set['segid'] = ' '
self._mol_set['first_residue_number'] = 1
self._mol_set['naming_convention'] = ''
self._mol_set['name'] = self.outprefix
segids = self._mol_set['segid']
segids = segids.split('/')
segids = [string_to_segid(segid) for segid in segids]
# Recup molecule
chain_types = {}
for s in segids:
chain_types[s] = self._mol_set['type']
sequence = SequenceList(chain_types, self._mol_set['first_residue_number'])
with Capturing() as output:
sequence.parse(self._mol_set['input'], self._mol_set['format'],
self._mol_set['naming_convention'])
logger.info("\n" + "".join(output))
factory = self.create_factory()
factory.reset()
factory.unfreeze()
chains = sequence.create_chains(factory)
self.molecule = Molecule(self._mol_set['name'])
[self.molecule.add_chain(chains[seg]) for seg in segids]
@staticmethod
def deff(distance_list, dpow=6):
......@@ -350,11 +395,35 @@ assign (resid {res1} and name n) (resid {res1} and name ca) (resid {res1} and na
class AriaEcXMLConverter(AriaXMLConverter):
def __init__(self, settings):
def __init__(self, *args, **kwargs):
self.restraint_list = []
self.settings = settings
self.protname = ""
super(AriaEcXMLConverter, self).__init__()
super(AriaEcXMLConverter, self).__init__(*args, **kwargs)
def atm_product(self, idx1, res1, idx2, res2, prod_type="min"):
def resname(res):
return AminoAcid.AminoAcid(res)[0]
def min_atms(aa1, aa2, atms):
# Function to minimize atom pair list between aa1 & aa2
return [
atmpair for atmpair in atms if atmpair in (
('CA', 'CA'),
('CB', 'CB'),
self.settings.scsc_min[resname(aa1)][resname(aa2)])]
atms1 = self.molecule.get_chains()[0].getResidues()[idx1].atoms.keys()
atms2 = self.molecule.get_chains()[0].getResidues()[idx2].atoms.keys()
if prod_type == "min":
return min_atms(res1, res2, list(itertools.product(atms1, atms2)))
elif prod_type == "heavy":
atms1 = filter(ResAtmMap.heavy_reg.match, atms1)
atms2 = filter(ResAtmMap.heavy_reg.match, atms2)
return list(itertools.product(atms1, atms2))
elif prod_type == "all":
return list(itertools.product(atms1, atms2))
else:
logger.error("Wrong pair_list option. Pair_list set to min")
return min_atms(res1, res2, list(itertools.product(atms1, atms2)))
def targetdistmap(self, distype, sequence, distfile=None, groupby=None):
# TODO: valeur par defaut de distfile au fichier contenant les infos
......@@ -451,79 +520,14 @@ class AriaEcXMLConverter(AriaXMLConverter):
pair_flag = self.settings.setup.config["pair_list"]
target_dist = self.settings.setup.config["restraint_distance"]
conv_table = ConversionTable.ConversionTable().table['AMINO_ACID'][
'iupac']
# TODO: atm_pair method
if pair_flag == "min":
def min_atms(aa1, aa2, atms):
# Function to minimize atom pair list between aa1 & aa2
return [
atmpair for atmpair in atms if atmpair in (
('CA', 'CA'),
('CB', 'CB'),
self.settings.scsc_min[aa1][aa2])]
aa_atm = dict((AminoAcid.AminoAcid(aa)[0],
filter(contactmap.heavy_reg.match, atms.keys()))
for aa, atms in conv_table.items())
aa_atm_pair = \
{
(aa1, aa2): min_atms(aa1, aa2, list(itertools.product(atms1,
atms2)))
for aa1, atms1 in aa_atm.items()
for aa2, atms2 in aa_atm.items()
}
elif pair_flag == "heavy":
# Dict giving all atms foreach aa
aa_heav_atm = dict((AminoAcid.AminoAcid(aa)[0],
filter(contactmap.heavy_reg.match, atms.keys()))
for aa, atms in conv_table.items())
# Dict giving atoms product for each res-res pair
aa_atm_pair = {(aa1, aa2): list(itertools.product(atms1, atms2))
for aa1, atms1 in aa_heav_atm.items()
for aa2, atms2 in aa_heav_atm.items()
}
elif pair_flag == "all":
# Dict giving all atms foreach aa
aa_atm = dict((AminoAcid.AminoAcid(aa)[0],
filter(contactmap.all_reg.match, atms.keys()))
for aa, atms in conv_table.items())
# Dict giving atoms product for each res-res pair
aa_atm_pair = {(aa1, aa2): list(itertools.product(atms1, atms2))
for aa1, atms1 in aa_atm.items()
for aa2, atms2 in aa_atm.items()
}
else:
logger.error("Wrong pair_list option. Pair_list set to min")
def min_atms(aa1, aa2, atms):
# Function to minimize atom pair list between aa1 & aa2
return [
atmpair for atmpair in atms if atmpair in (
('CA', 'CA'),
('CB', 'CB'),
self.settings.scsc_min[aa1][aa2])]
aa_atm = dict((AminoAcid.AminoAcid(aa)[0],
filter(contactmap.heavy_reg.match, atms.keys()))
for aa, atms in conv_table.items())
aa_atm_pair = \
{
(aa1, aa2): min_atms(aa1, aa2, list(itertools.product(atms1,
atms2)))
for aa1, atms1 in aa_atm.items()
for aa2, atms2 in aa_atm.items()
}
def min_ind(ind):
return ind if ind >= 0 else 0
def max_ind(ind, max_idx):
return ind if ind <= max_idx else max_idx
def resname(residx):
return AminoAcid.AminoAcid(contactmap.index.values[residx][-3:])[0]
def resname_3l(residx):
return AminoAcid.AminoAcid(contactmap.index.values[residx][-3:])[1]
max_seqidx = len(contactmap.sequence)
restraint_dict = collections.OrderedDict()
......@@ -531,18 +535,19 @@ class AriaEcXMLConverter(AriaXMLConverter):
contrib_id = 0
for contactidx, contact in enumerate(pair_list):
# /!\ humanidx in contact must start at 1 !n_factor
# Add neighbors if neigh_flag
logger.debug("Contact %s" % str(contact))
# Add neighbors if neigh_flag
resx_idx = range(min_ind(contact[0] - 1),
max_ind(contact[0] + 2, max_seqidx)) if \
neigh_flag else [contact[0]]
resy_idx = range(min_ind(contact[1] - 1),
max_ind(contact[1] + 2, max_seqidx)) if \
neigh_flag else [contact[1]]
contactweight = weight_list[contactidx]
contactweight = weight_list[contactidx]
dist_list = []
if adr_flag:
rest_id += 1
# TODO: Autre dist_list if target != ResMap
......@@ -576,22 +581,27 @@ class AriaEcXMLConverter(AriaXMLConverter):
target_dist = self.settings.setup.confn_factorig[
"restraint_distance"]
for resx in resx_idx:
for resy in resy_idx:
resxn = targetdist.index.levels[0][resx]
resyn = targetdist.index.levels[0][resy]
for atm_pair in aa_atm_pair[(resname(resx),
resname(resy))]:
for idx_x in resx_idx:
for idx_y in resy_idx:
mapidx_x = targetdist.index.levels[0][idx_x]
mapidx_y = targetdist.index.levels[0][idx_y]
res_x = resname_3l(idx_x)
res_y = resname_3l(idx_y)
atm_pairs = self.atm_product(idx_x, res_x, idx_y, res_y,
pair_flag)
for atm_pair in atm_pairs:
if adr_flag:
contrib_id += 1
else:
if len(targetdist.index.levels) == 2:
target_dist = targetdist.loc[resxn,
atm_pair[0]][resyn,
target_dist = targetdist.loc[mapidx_x,
atm_pair[0]][mapidx_y,
atm_pair[1]]
else:
target_dist = "%.2f" % targetdist.iat[resx,
resy]
target_dist = "%.2f" % targetdist.iat[idx_x,
idx_y]
if target_dist is None:
# In case missing distance values
......@@ -600,7 +610,7 @@ class AriaEcXMLConverter(AriaXMLConverter):
logger.warning(
"Target distance is missing for restraint "
"%s-%s (%s). Using default distance (%s)"
% (resx + 1, resy + 1, atm_pair, target_dist))
% (idx_x + 1, idx_y + 1, atm_pair, target_dist))
rest_id += 1
contrib_id = 1
......@@ -636,12 +646,12 @@ class AriaEcXMLConverter(AriaXMLConverter):
"weight": 1.0
},
"spin_pair": {
resx + 1: atm_pair[0],
resy + 1: atm_pair[1]
idx_x + 1: atm_pair[0],
idx_y + 1: atm_pair[1]
}
}
xml_file = self.settings.infra["xml"] + "/" + "_".join((
self.protname, listname)) + ".xml"
self.outprefix, listname)) + ".xml"
self.write_dist_xml(restraint_dict, xml_file)
return xml_file, pair_list
......@@ -667,11 +677,11 @@ class AriaEcXMLConverter(AriaXMLConverter):
:param hbmap: Extra hbond map (eg: metapsicov hbonds)
:return:
"""
dihed_file = os.path.join(self.settings.infra["tbl"], self.protname +
dihed_file = os.path.join(self.settings.infra["tbl"], self.outprefix +
"_dihed.tbl")
hb_file = os.path.join(self.settings.infra["tbl"], self.protname +
hb_file = os.path.join(self.settings.infra["tbl"], self.outprefix +
"_hbond.tbl")
ssdist_file = os.path.join(self.settings.infra["tbl"], self.protname +
ssdist_file = os.path.join(self.settings.infra["tbl"], self.outprefix +
"_ssdist.tbl")
self.write_dihedral_tbl(protein.sec_struct.ss_matrix, dihed_file)
self.write_hb_tbl(protein, hb_file,
......@@ -684,39 +694,21 @@ class AriaEcXMLConverter(AriaXMLConverter):
ssdist_file)
return {'hbond': hb_file, 'dihed': dihed_file, 'ssdist': ssdist_file}
def write_xmlseq(self, seqpath):
xml_file = os.path.join(self.settings.infra["xml"], self.protname +
".xml")
m = aria.conversion.MoleculeSettings()
m['format'] = 'seq'
m['input'] = seqpath
m['output'] = xml_file
m['type'] = 'PROTEIN'
m['segid'] = ' '
m['first_residue_number'] = 1
m['naming_convention'] = ''
m['name'] = self.protname
c = aria.conversion.ConverterSettings()
c.reset()
c['molecule'] = m
c['project_name'] = self.protname
converter = aria.conversion.Converter()
converter.setSettings(c)
# TODO: generate xml in order to use convert method ??
# converter.convert()
with Capturing() as output:
converter._convert_sequence()
def write_xmlseq(self):
logger.info("\n" + "".join(output))
try:
self._pickler.dump(self.molecule, self._mol_set[
'output'])
except Exception, msg:
logger.error("Error writing xml seq file : %s" % msg)
return xml_file
return self._mol_set['output']
def write_project(self, aria_template, seqfile, dist_files, tbl_files,
desc=""):
def write_ariaproject(self, aria_template, seqfile, dist_files, tbl_files,
desc=""):
if aria_template:
template = os.path.abspath(aria_template)
......@@ -755,21 +747,21 @@ class AriaEcXMLConverter(AriaXMLConverter):
for direct in (work_dir, temp_root):
if not os.path.exists(os.path.join(direct, self.protname)):
os.makedirs(os.path.join(direct, self.protname))
if not os.path.exists(os.path.join(direct, self.outprefix)):
os.makedirs(os.path.join(direct, self.outprefix))
if not os.path.exists(os.path.join(direct, self.protname, desc)):
os.makedirs(os.path.join(direct, self.protname, desc))
if not os.path.exists(os.path.join(direct, self.outprefix, desc)):
os.makedirs(os.path.join(direct, self.outprefix, desc))
work_dir = os.path.join(work_dir, self.protname, desc)
temp_root = os.path.join(temp_root, self.protname, desc)
work_dir = os.path.join(work_dir, self.outprefix, desc)
temp_root = os.path.join(temp_root, self.outprefix, desc)
aria_project_dict['working_directory'] = work_dir
aria_project_dict['temp_root'] = temp_root
project = {'project_name': "_".join((self.protname, desc)),
project = {'project_name': "_".join((self.outprefix, desc)),
'date': datetime.date.today().isoformat(),
'file_root': "_".join((self.protname, desc))}
'file_root': "_".join((self.outprefix, desc))}
aria_project_dict.update(project)
data_molecule = {'molecule_file': seqfile}
......
No preview for this file type
......@@ -80,7 +80,7 @@ class AriaEcContactMap(object):
# Use only position filter
# self.filter(fo.mapdict, fo.filetype, fo.contactlist,
# self.protein, clashlist=fo.clashlist,
# protname=self.protname,
# outprefix=self.outprefix,
# outdir=self.settings.outdir, mapfilters="pos")
self.allresmap[(fo.filename, fo.filetype)] = fo.mapdict
......
No preview for this file type
No preview for this file type
......@@ -54,7 +54,7 @@ class AriaEcSetup:
# -------------------------------------------------------------------- #
self.outprefix = get_filename(self.settings.setup.args.get("seq",
None))
self.converter.protname = self.outprefix
self.converter.outprefix = self.outprefix
# ------------------------- Load sequence ---------------------------- #
self.protein.set_aa_sequence(self.settings.setup.args.get("seq", None))
# -------------- Load secondary structure prediction ----------------- #
......@@ -69,7 +69,7 @@ class AriaEcSetup:
# ---------------------------- Processing ---------------------------- #
# -------------------------------------------------------------------- #
# TODO: write submatrix in a file
# TODO: change read method in reader to __call__ ?
# TODO: change read method in reader to __call__
# -------------------------- contact maps ---------------------------- #
self.reader.read(self.settings.setup.args.get("infiles"),
filetypelist=self.settings.setup.args.get("contact_types"),
......@@ -142,7 +142,8 @@ class AriaEcSetup:
# ----------------------------- SEQ file ----------------------------- #
self.protein.write_seq(os.path.join(self.settings.infra.get("others", ''),
self.outprefix + ".seq"))
# Load aria molecule object from generated seq file
self.converter.load_molecule(self.protein.seqfile_path)
# --------------------------- TBL restraints ------------------------- #
# Setting contact number limit for hbmap
n_hb = int(len(self.protein.aa_sequence.sequence) *
......@@ -159,16 +160,16 @@ class AriaEcSetup:
self.allresmap, self.targetmap)
# --------------------------- XML SEQ file --------------------------- #
seq_file = self.converter.write_xmlseq(self.protein.seqfile_path)
xmlseq_file = self.converter.write_xmlseq()
# ---------------------- ARIA XML project file ----------------------- #
aria_template = self.settings.main.config["ariaproject_template"] if \
self.settings.main.config["ariaproject_template"] and \
os.path.exists(self.settings.main.config["ariaproject_template"])\
else None
self.converter.write_project(aria_template,
seq_file, dist_files, tbl_files,
desc="_".join(sorted(self.allresmap.keys())))
self.converter.write_ariaproject(aria_template,
xmlseq_file, dist_files, tbl_files,
desc="_".join(sorted(self.allresmap.keys())))
# ------------------------------ others ------------------------------ #
self.write_optional_files()
......
No preview for this file type
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment