Skip to content
Snippets Groups Projects
Commit 475f4319 authored by David BIKARD
Browse files

Initial commit

parent e973a38d
No related branches found
No related tags found
No related merge requests found
import numpy as np
import gffpandas.gffpandas as gffpd
from typing import Tuple
import re
# Regression coefficients used by predict(), read once when the module is
# imported.
# NOTE(review): the path is relative, so it is resolved against the current
# working directory. allow_pickle=True lets numpy unpickle the file content,
# so the file must come from a trusted source.
with open("on_target/model/reg_coef.pkl", "br") as coef_file:
    coef = np.load(coef_file, allow_pickle=True)
# Row order of the one-hot matrix built by encode().
bases = ["A", "T", "G", "C"]


def encode(seq):
    """One-hot encode a nucleotide sequence.

    One row is built per entry of ``bases`` (A, T, G, C, in that order) and
    one column per character of ``seq``. A cell holds 1 when the character
    equals the base of its row and 0 otherwise, so a character outside
    A/T/G/C produces an all-zero column rather than an error.

    Returns a 2-D numpy array with ``len(bases)`` rows.
    """
    rows = []
    for base in bases:
        rows.append([int(symbol == base) for symbol in seq])
    return np.array(rows)
# Quartiles: q1 > 0.4 > q2 > -0.08 > q3 > -0.59 > q4
def predict(X):
    """Score every element of ``X`` against the model coefficients.

    Each element of ``X`` is multiplied element-wise by the module-level
    ``coef`` array and the product is summed into a single value.

    Returns a list holding one score per element of ``X``, in input order.
    """
    scores = []
    for features in X:
        scores.append(np.sum(features * coef))
    return scores
def rev_comp(seq):
    """Return the reverse complement of a DNA sequence.

    A/T and G/C are swapped and the sequence is read backwards. Any other
    character (lowercase letters, N, ...) is kept unchanged.
    """
    complement_table = str.maketrans("ATGC", "TACG")
    reversed_seq = seq[::-1]
    return reversed_seq.translate(complement_table)
def find_targets(seq):
    """Find every NGG PAM site on the reverse complement of ``seq``.

    The reverse complement of ``seq`` is scanned for 25-character windows
    made of 6 bases, an NGG PAM and 16 more bases. The pattern is wrapped in
    a zero-width lookahead so that overlapping windows are all reported.

    Args:
        seq: sequence to scan; only uppercase A, T, G and C can match.

    Returns:
        A list with one dict per window (empty when nothing matches). Keys:
          * ``target``: the 25-character window, as read on the reverse
            complement.
          * ``guide``: the first 20 characters of ``target``.
          * ``start``: ``len(seq) - offset - 20``
          * ``stop``: ``len(seq) - offset``
          * ``pam``: ``len(seq) - offset - 22``
          * ``ori``: always ``"-"``.
        ``offset`` is the position of the window in the reverse complement.
    """
    repam = "[ATGC]GG"
    target_re = re.compile("(?=([ATGC]{6}" + repam + "[ATGC]{16}))")
    seq_len = len(seq)
    seq_revcomp = rev_comp(seq)

    # NOTE(review): "guide" is the first 20 characters of the 25-character
    # window, so it spans the 6 leading bases, the 3-base PAM and 11 trailing
    # bases -- confirm this is the intended guide definition.
    return [
        {
            "target": m.group(1),
            "guide": m.group(1)[:20],
            "start": seq_len - m.start() - 20,
            "stop": seq_len - m.start(),
            "pam": seq_len - m.start() - 22,
            "ori": "-",
        }
        for m in target_re.finditer(seq_revcomp)
    ]
def on_target_predict(seq):
    """Find the targets of ``seq`` and attach a predicted score to each.

    The sequence is upper-cased and stripped of all whitespace, then passed
    to ``find_targets``. Each target is one-hot encoded (without the GG of
    its PAM), flattened and scored by ``predict``.

    Returns the list of target dicts from ``find_targets``, each extended
    with a ``"pred"`` key, or an empty list when no target is found.
    """
    cleaned = re.sub(r"\s", "", seq.upper())  # uppercase, no whitespace
    alltargets = find_targets(cleaned)
    if not alltargets:
        return []

    # Encode each target after dropping characters 7 and 8 (the GG of the PAM).
    encoded = [encode(tar["target"][:7] + tar["target"][9:]) for tar in alltargets]
    X = np.array(encoded)
    # Flatten every encoded target into a single feature row.
    X = X.reshape(X.shape[0], -1)
    preds = predict(X)
    for target, pred in zip(alltargets, preds):
        target.update({"pred": pred})
    return alltargets
File added
import pytest
from on_target.model.predict import on_target_predict
def test_on_target_predict_empty():
    """An empty input sequence must yield no predicted target."""
    targets_for_empty = on_target_predict("")
    n_targets = len(targets_for_empty)
    assert n_targets == 0, "the list is non empty"
def test_on_target_predict_size_guide():
    """Check guide length and PAM/start spacing on a known sequence.

    Every predicted target must carry a 20-character guide, and its
    ``start`` must lie exactly 2 positions after its ``pam``.
    """
    size_guide = 20
    predicted_targets = on_target_predict(
        "TGCCTGTTTACGCGCCGATTGTTGCGAGATTTGGACGGACGTTGACGGGGTCTATACCTGCGACCCGCGTCAGGTGCCCGATGCGAGGTTGTTGAAGTCGATGTCCTACCAGGAAGCGATGGAGCTTTCCTACTTCGGCG"
    )
    # Without this guard the loop below would pass vacuously on an empty list.
    assert predicted_targets, "no target was found in the test sequence"
    for predicted_target in predicted_targets:
        guide = predicted_target["guide"]
        # size_guide is an int: convert it explicitly, otherwise building the
        # failure message raises TypeError instead of reporting the assertion.
        assert len(guide) == size_guide, (
            "the guide do not have a length of " + str(size_guide)
        )
        pam_val = predicted_target["pam"]
        start_val = predicted_target["start"]
        assert (
            start_val - pam_val == 2
        ), "the difference between start and pam position is different than 2"
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment