"""Target search and scoring (``model/predict.py``).

The module scans the reverse complement of a DNA sequence for ``NGG`` PAM
sites, one-hot encodes the bases around each site and scores the encoding
with the coefficients loaded from ``reg_coef.pkl``.
"""
import os
import re
from typing import Tuple

import gffpandas.gffpandas as gffpd
import numpy as np

# The coefficient file is added next to this module, so resolve it from
# ``__file__`` to keep the import working whatever the current directory is.
# The historical working-directory-relative path is kept as a fallback.
_coef_path = os.path.join(
    os.path.dirname(os.path.abspath(__file__)), "reg_coef.pkl"
)
if not os.path.exists(_coef_path):
    _coef_path = "on_target/model/reg_coef.pkl"

# NOTE(security): ``allow_pickle=True`` lets NumPy unpickle the file, which
# can execute arbitrary code. Only load a coefficient file from a trusted
# source.
with open(_coef_path, "rb") as handle:
    coef = np.load(handle, allow_pickle=True)

bases = ["A", "T", "G", "C"]

# Built once instead of on every ``rev_comp`` call.
_COMPLEMENT = str.maketrans("ATGC", "TACG")

# Zero-width lookahead so that overlapping sites are all reported. The
# captured group is 25 characters: 6 bases, the NGG PAM, then 16 bases.
_TARGET_RE = re.compile(r"(?=([ATGC]{6}[ATGC]GG[ATGC]{16}))")


def encode(seq):
    """One-hot encode ``seq`` (only non-ambiguous bases ATGC are accepted).

    Returns an integer array with one row per entry of ``bases`` and one
    column per character of ``seq``. A character that is not in ``bases``
    produces an all-zero column.
    """
    return np.array([[int(b == p) for b in seq] for p in bases])


# Quartiles: q1 > 0.4 > q2 > -0.08 > q3 > -0.59 > q4
def predict(X):
    """Return one score per row of ``X``.

    Each score is the sum of the element-wise product of the row with the
    module-level ``coef`` array.
    """
    return [np.sum(x * coef) for x in X]


def rev_comp(seq):
    """Return the reverse complement of ``seq``.

    Only the characters A, T, G and C are complemented; any other character
    is kept unchanged (but still reversed).
    """
    return seq.translate(_COMPLEMENT)[::-1]


def find_targets(seq):
    """Find every PAM site on the reverse complement of ``seq``.

    Returns a list of dicts, one per match, with the keys ``target`` (the
    25-character match), ``guide`` (its first 20 characters), ``start``,
    ``stop`` and ``pam`` (positions derived from the match offset and the
    length of ``seq``) and ``ori`` (always ``"-"``).
    """
    seq_len = len(seq)
    # NOTE(review): the match is 6 bases + PAM + 16 bases, so ``target[:20]``
    # spans the PAM itself; confirm that ``guide`` and the ``- 20`` / ``- 22``
    # offsets below are what is intended.
    return [
        {
            "target": m.group(1),
            "guide": m.group(1)[:20],
            "start": seq_len - m.start() - 20,
            "stop": seq_len - m.start(),
            "pam": seq_len - m.start() - 22,
            "ori": "-",
        }
        for m in _TARGET_RE.finditer(rev_comp(seq))
    ]


def on_target_predict(seq):
    """Find the targets of ``seq`` and attach a predicted score to each.

    ``seq`` is upper-cased and stripped of all whitespace before the search.
    Returns the list produced by ``find_targets`` with an extra ``pred`` key
    on every dict, or an empty list when no target is found.
    """
    seq = seq.upper()  # make uppercase
    seq = re.sub(r"\s", "", seq)  # remove white space
    alltargets = find_targets(seq)
    if not alltargets:
        return []

    # Encode each target without the GG of its PAM (indices 7 and 8).
    X = np.array(
        [encode(tar["target"][:7] + tar["target"][9:]) for tar in alltargets]
    )
    # Flatten every one-hot matrix into a single feature row.
    X = X.reshape(X.shape[0], -1)
    for target, pred in zip(alltargets, predict(X)):
        target["pred"] = pred
    return alltargets
"""Tests for ``on_target_predict`` (``tests/predict_test.py``)."""
import pytest

from on_target.model.predict import on_target_predict


def test_on_target_predict_empty():
    """An empty sequence yields no predicted target."""
    predicted_target = on_target_predict("")
    assert len(predicted_target) == 0, "the list is non empty"


def test_on_target_predict_size_guide():
    """Every guide has length 20 and ``start`` minus ``pam`` equals 2."""
    size_guide = 20
    predicted_targets = on_target_predict(
        "TGCCTGTTTACGCGCCGATTGTTGCGAGATTTGGACGGACGTTGACGGGGTCTATACCTGCGACCCGCGTCAGGTGCCCGATGCGAGGTTGTTGAAGTCGATGTCCTACCAGGAAGCGATGGAGCTTTCCTACTTCGGCG"
    )
    # Without this check the per-target assertions below would pass
    # vacuously if the search returned nothing.
    assert predicted_targets, "no target was found in the test sequence"

    for predicted_target in predicted_targets:
        guide = predicted_target["guide"]
        # An f-string is required here: concatenating the int ``size_guide``
        # to a str raises TypeError exactly when the assertion fails, hiding
        # the real failure.
        assert (
            len(guide) == size_guide
        ), f"the guide do not have a length of {size_guide}"

        pam_val = predicted_target["pam"]
        start_val = predicted_target["start"]
        assert (
            start_val - pam_val == 2
        ), "the difference between start and pam position is different than 2"