Skip to content
Snippets Groups Projects
Commit adf8233a authored by François  LAURENT's avatar François LAURENT
Browse files

fixes #2 and implements proposal 1 in larvatagger.jl#62

parent 9eb347f9
No related branches found
No related tags found
No related merge requests found
......@@ -7,12 +7,13 @@ def predict_model(backend, **kwargs):
"""
This function generates predicted labels for all the input data.
The input files can be read from any directory.
All generated/modified files should be written to `data/interim` or
`data/processed`.
The predicted labels are expected in `data/processed`.
It currently supports single files at a time only, either *larva_dataset*
hdf5 files or track data files of any time.
The `predict_model.py` script is required.
Input files are expected in `data/interim` or `data/raw`.
The predicted labels are saved in `data/processed`, as file
`predicted.label`.
"""
# we pick files in `data/interim` if any, otherwise in `data/raw`
input_files = backend.list_interim_files()
......@@ -92,6 +93,8 @@ def predict_individual_data_files(backend, model, input_files, labels):
data[larva] = get_5point_spines(data[larva])
else:
data = get_5point_spines(data)
#
post_filters = model.clf_config.get('post_filters', None)
# assign labels
if isinstance(data, dict):
ref_length = np.median(np.concatenate([
......@@ -102,9 +105,10 @@ def predict_individual_data_files(backend, model, input_files, labels):
for larva, spines in data.items():
predictions = model.predict((t[larva], spines))
if predictions is None:
print(f"failure at windowing track: {larva}")
print(f"failure to window track: {larva}")
else:
labels[run, larva] = dict(zip(t[larva], predictions))
predictions = apply_filters(predictions, post_filters)
labels[run, larva] = dict(_zip(t[larva], predictions))
else:
ref_length = np.median(model.body_length(data))
model.average_body_length = ref_length
......@@ -113,9 +117,10 @@ def predict_individual_data_files(backend, model, input_files, labels):
mask = larvae == larva
predictions = model.predict((t[mask], data[mask]))
if predictions is None:
print(f"failure at windowing track: {larva}")
print(f"failure to window track: {larva}")
else:
labels[run, larva] = dict(zip(t[mask], predictions))
predictions = apply_filters(predictions, post_filters)
labels[run, larva] = dict(_zip(t[mask], predictions))
# save the predicted labels to file
labels.dump(backend.processed_data_dir() / "predicted.label")
#
......@@ -127,6 +132,24 @@ def predict_larva_dataset(backend, model, file, labels, subset="validation"):
dataset = LarvaDataset(file, new_generator())
return model.predict(dataset, subset)
def _zip(xs, ys):
# prevent issues similar to #2
assert len(xs) == len(ys)
return zip(xs, ys)
def apply_filters(labels, post_filters):
    """
    Apply post-prediction filters to a sequence of labels, in place.

    Args:
        labels: mutable sequence of predicted labels.
        post_filters: iterable of filter names, or None/empty for no-op.
            Currently only 'ABA->AAA' is supported: any single label that
            disagrees with two identical neighbours is replaced by the
            neighbours' label.

    Returns:
        The (possibly modified) `labels` sequence.

    Raises:
        ValueError: if an unsupported filter name is given.
    """
    if post_filters:
        for post_filter in post_filters:
            if post_filter == 'ABA->AAA':
                # modify sequentially, so that fixes can cascade forward
                for k in range(1, len(labels) - 1):
                    label = labels[k - 1]
                    # BUG FIX: the original compared labels[k-1] to itself
                    # (always false), so the filter never modified anything;
                    # the middle element labels[k] is the one to test
                    if labels[k] != label and label == labels[k + 1]:
                        labels[k] = label
            else:
                raise ValueError(f"filter not supported: {post_filter}")
    return labels
from taggingbackends.main import main
......
......@@ -62,22 +62,25 @@ class MaggotTrainer:
win = interpolate(t, data, m, winlen, **interpolation_args)
if win is not None:
assert win.shape[0] == winlen
yield win
yield t[m], win
else:
for m in range(0, N-winlen):
n = m + winlen
yield data[m:n]
yield t[(m + n) // 2], data[m:n]
def pad(self, data):
winlen = self.config["len_traj"]
def pad(self, target_t, defined_t, data):
if data.shape[0] == 1:
return data
else:
head = searchsortedfirst(target_t, defined_t[0])
tail = len(target_t) - (searchsortedlast(target_t, defined_t[-1]) + 1)
ind = np.r_[
np.zeros(winlen // 2, dtype=int),
np.zeros(head, dtype=int),
np.arange(data.shape[0]),
(data.shape[1]-1) * np.ones(winlen // 2 - 1, dtype=int),
(data.shape[1]-1) * np.ones(tail, dtype=int),
]
if len(ind) != len(target_t):
raise RuntimeError('missing time steps')
return data[ind]
def body_length(self, data):
......@@ -98,12 +101,20 @@ class MaggotTrainer:
w = np.einsum("ij,jkl", rot, np.reshape(w.T, (2, 5, -1), order='F'))
return w
"""
Preprocess a single track.
This includes running a sliding window, resampling the track in each window,
normalizing the spines, etc.
"""
def preprocess(self, t, data):
defined_t = []
ws = []
for w in self.window(t, data):
for t_, w in self.window(t, data):
defined_t.append(t_)
ws.append(self.normalize(w))
if ws:
ret = self.pad(np.stack(ws))
ret = self.pad(t, defined_t, np.stack(ws))
if self.swap_head_tail:
ret = ret[:,:,::-1,:]
return ret
......@@ -268,3 +279,15 @@ def make_trainer(config_file, *args, **kwargs):
model = MaggotTrainer(config_file, *args, **kwargs)
return model
# Julia functions
def searchsortedfirst(xs, x):
    # Index of the first element of `xs` that is not smaller than `x`,
    # or None when every element is smaller (Python counterpart of
    # Julia's `searchsortedfirst`, hence the name).
    return next((i for i, value in enumerate(xs) if value >= x), None)
def searchsortedlast(xs, x):
    # Index of the last element of `xs` that does not exceed `x`,
    # or None when every element is greater (Python counterpart of
    # Julia's `searchsortedlast`).
    return next((i for i in reversed(range(len(xs))) if xs[i] <= x), None)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment