From 475f4319b011d5324a4907d54ee010ca8e82ca8f Mon Sep 17 00:00:00 2001 From: David BIKARD <david.bikard@pasteur.fr> Date: Thu, 7 Nov 2019 11:03:42 +0100 Subject: [PATCH] Initial commit --- model/predict.py | 61 ++++++++++++++++++++++++++++++++++++++++++ model/reg_coef.pkl | Bin 0 -> 1254 bytes tests/__init__.py | 0 tests/predict_test.py | 30 +++++++++++++++++++++ 4 files changed, 91 insertions(+) create mode 100644 model/predict.py create mode 100644 model/reg_coef.pkl create mode 100644 tests/__init__.py create mode 100644 tests/predict_test.py diff --git a/model/predict.py b/model/predict.py new file mode 100644 index 0000000..9d0b838 --- /dev/null +++ b/model/predict.py @@ -0,0 +1,61 @@ +import numpy as np +import gffpandas.gffpandas as gffpd +from typing import Tuple +import re + +with open("on_target/model/reg_coef.pkl", "br") as handle: + coef = np.load(handle, allow_pickle=True) + +bases = ["A", "T", "G", "C"] + + +def encode(seq): + '''One-hot encoding of a sequence (only non-ambiguous bases (ATGC) accepted)''' + return np.array([[int(b == p) for b in seq] for p in bases]) + + +# Quartiles: q1 > 0.4 > q2 > -0.08 > q3 > -0.59 > q4 +def predict(X): + return [np.sum(x * coef) for x in X] + + +def rev_comp(seq): + comp = str.maketrans("ATGC", "TACG") + return seq.translate(comp)[::-1] + + +def find_targets(seq): + repam = "[ATGC]GG" + L = len(seq) + seq_revcomp = rev_comp(seq) + alltargets = [ + dict( + [ + ("target", m.group(1)), + ("guide", m.group(1)[:20]), + ("start", L - m.start() - 20), + ("stop", L - m.start()), + ("pam", L - m.start() - 22), + ("ori", "-"), + ] + ) + for m in re.finditer("(?=([ATGC]{6}" + repam + "[ATGC]{16}))", seq_revcomp) + ] + return alltargets + + +def on_target_predict(seq): + seq = seq.upper() # make uppercase + seq = re.sub(r"\s", "", seq) # removes white space + alltargets = find_targets(seq) + if alltargets: + X = np.array( + [encode(tar["target"][:7] + tar["target"][9:]) for tar in alltargets] #encore and remove GG of PAM + ) + X = X.reshape(X.shape[0], -1) + preds = predict(X) + for i, target in enumerate(alltargets): + target.update({"pred": preds[i]}) + return alltargets + else: + return [] diff --git a/model/reg_coef.pkl b/model/reg_coef.pkl new file mode 100644 index 0000000000000000000000000000000000000000..5509c33e49403cd4672187d245c848ae0aa5cc3b GIT binary patch literal 1254 zcmYjR3r|!>7`@A@bQJ`7#8T3hNaT@9B-mgTGuT=oE<uXaP$jjnP?i>Uxyz!ZrMd{q z`~7kkC?EorC6t0-DU00XoFCSH6n5L2+?l!aedo;i&NsR4Fsr?_sktr3>TuTOG_^Lm z8mgVn>NfL3XPwnyZ*e(WtuC{!9cP*CwZ~E2WYIi&nDwF6QCnwiG1u9T7G^!X(xhox zjUG`MC39o7tHGYDM|SH`<@!54dUUz|Zl=X#sq*ON6T(_oTeIL!Rfb9acUSb&7R_Sn z(PPT>*i6eYTY-M2w6yg0-}3xBSTvU&SJP(;uS|{*(D4KJMNBy`!@(vPiUAB-3UU)L zq%zpi;^5jx>`{1A!g0lcO8l<sf1>vm*Yi0F1Fz~af@iWYI4xuEW6rgQaS;L82oz$| zg$;~iSH4!;KYYwFjLK`^5c6EepXr!Aisdt4{vN;DZo;_B@zXcs=)VRdBoAF!<ytX^ z@Dc%K!i)=V;cLW*O0Rk#%#%D6alvxVHeZjAs>uHzJ{X~!L)`g>o*MS5a8BDee&s&h zA!INwICGwTvGARH^q&|F&pCZThN&ic6?ZuN<W36vu_rQoPI}pwD#dRpC;q^G1s9^= zQx22Vc#s+{+H0|cZu-*hi;cOrl9MlU0oyQA&qxE<!95%(ye;@D9jjQuq~b32N-u<E z3o@wbi!ZTVpPr4?(2n9dSCTkv!=}oN;b{6LZem&V^ows3KVTn+>=nzqqvNp8mr3{w zhP_OSy{?Xq4)UJNaEBH!;`!og20iQ)vggvp4eqM<IU!sNx$u<~7f7iF2e|UdCG3`A zUvZn>i}a*$pDQrZ!ZD1+NA#jQbPn7csV#=lWUCUIDV^i3(v?vxD=yxx%*3XKjZkI3 zaOnd~i7#VE{)U7#z#nhE!FjnYCwto0Wr>`VW2W#bgvP<W40dWR>?q7*`6@km7~wNj z5J^8g2xQ@n;#k5BjunaCK^R%@ar&G%Hz(d3So#cpS4Av4U|gp6fs{RQTV#e*WIdJc zcZG3^GYVU2dAUzdQ99<0q|e0H4vwglMiV_hbDD!el!i6*^0z0^N$@N5bHD|kXj{Sy zg~g(_&(U>FG7S0)>A#098b`fKklQ_**vh36iH!Zou~-~_%1IdCe(c07SEWJXZ}E8y z6k@c9BQP9t2N_`JdAt@#PZjpKhQW4h2X(;hkoXoB#E_mm7!gwCAv`hM%Z8!H+vKlv Nwp>qe>4`Nt{{TAQFn$04 literal 0 HcmV?d00001 diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/predict_test.py b/tests/predict_test.py new file mode 100644 index 0000000..e26d77f --- /dev/null +++ b/tests/predict_test.py @@ -0,0 +1,30 @@ +import pytest +from on_target.model.predict import on_target_predict + + +def test_on_target_predict_empty(): + # Empty sequence + predicted_target = on_target_predict("") + assert len(predicted_target) == 0, "the list is non empty" + + +def test_on_target_predict_size_guide(): + size_guide = 20 + predicted_targets = on_target_predict( + "TGCCTGTTTACGCGCCGATTGTTGCGAGATTTGGACGGACGTTGACGGGGTCTATACCTGCGACCCGCGTCAGGTGCCCGATGCGAGGTTGTTGAAGTCGATGTCCTACCAGGAAGCGATGGAGCTTTCCTACTTCGGCG" + ) + guides = (predicted_target["guide"] for predicted_target in predicted_targets) + for guide in guides: + assert len(guide) == size_guide, ( + "the guide do not have a length of " + size_guide + ) + pams = ( + (predicted_target["pam"], predicted_target["start"]) + for predicted_target in predicted_targets + ) + for pam in pams: + (pam_val, start_val) = pam + assert ( + start_val - pam_val == 2 + ), "the difference between start and pam position is different than 2" + -- GitLab