Skip to content
Snippets Groups Projects

Set of commits to be tagged v0.20

Merged François LAURENT requested to merge dev into main
2 files
+ 48
14
Compare changes
  • Side-by-side
  • Inline
Files
2
@@ -4,6 +4,7 @@ from maggotuba.models.trainers import MaggotTrainer, MultiscaleMaggotTrainer, Ma
@@ -4,6 +4,7 @@ from maggotuba.models.trainers import MaggotTrainer, MultiscaleMaggotTrainer, Ma
import numpy as np
import numpy as np
import logging
import logging
import os.path
import os.path
 
import re
tracking_data_file_extensions = ('.spine', '.outline', '.csv', '.mat', '.hdf5')
tracking_data_file_extensions = ('.spine', '.outline', '.csv', '.mat', '.hdf5')
@@ -133,8 +134,10 @@ def predict_individual_data_files(backend, model, input_files_and_labels):
@@ -133,8 +134,10 @@ def predict_individual_data_files(backend, model, input_files_and_labels):
if predictions is None:
if predictions is None:
logging.info(f"failure to window track: {larva}")
logging.info(f"failure to window track: {larva}")
else:
else:
predictions = apply_filters(predictions, post_filters)
t_ = t[larva]
labels[run, larva] = dict(_zip(t[larva], predictions))
predictions = apply_filters(t_, predictions,
 
post_filters, labels.decoding_label_list, larva)
 
labels[run, larva] = dict(_zip(t_, predictions))
else:
else:
ref_length = np.median(model.body_length(data))
ref_length = np.median(model.body_length(data))
model.average_body_length = ref_length
model.average_body_length = ref_length
@@ -145,8 +148,10 @@ def predict_individual_data_files(backend, model, input_files_and_labels):
@@ -145,8 +148,10 @@ def predict_individual_data_files(backend, model, input_files_and_labels):
if predictions is None:
if predictions is None:
logging.info(f"failure to window track: {larva}")
logging.info(f"failure to window track: {larva}")
else:
else:
predictions = apply_filters(predictions, post_filters)
t_ = t[mask]
labels[run, larva] = dict(_zip(t[mask], predictions))
predictions = apply_filters(t_, predictions,
 
post_filters, labels.decoding_label_list, larva)
 
labels[run, larva] = dict(_zip(t_, predictions))
# save the predicted labels to file
# save the predicted labels to file
labels.dump(get_output_filepath(backend, file))
labels.dump(get_output_filepath(backend, file))
@@ -184,23 +189,52 @@ def get_output_filepath(backend, file):
@@ -184,23 +189,52 @@ def get_output_filepath(backend, file):
if not target.is_file(): break
if not target.is_file(): break
return target
return target
def apply_filters(labels, post_filters):
def apply_filters(t, actions, post_filters, labels, larva_id):
 
# `labels` can be used to map labels to class indices
 
# `larva_id` is for logging only; no filters should rely on this datum
if post_filters:
if post_filters:
for post_filter in post_filters:
for post_filter in post_filters:
if post_filter == 'ABA->AAA':
if post_filter == 'ABA->AAA':
# modify sequentially
# modify sequentially
for k in range(1, len(labels)-1):
for k in range(1, len(actions)-1):
label = labels[k-1]
action = actions[k-1]
if labels[k] != label and label == labels[k+1]:
if actions[k] != action and action == actions[k+1]:
labels[k] = label
actions[k] = action
elif post_filter == 'ABC->AAC':
elif post_filter == 'ABC->AAC':
# modify sequentially
# modify sequentially
for k in range(1, len(labels)-1):
for k in range(1, len(actions)-1):
if labels[k-1] != labels[k] and labels[k] != labels[k+1]:
if actions[k-1] != actions[k] and actions[k] != actions[k+1]:
labels[k] = labels[k-1]
actions[k] = actions[k-1]
 
elif ' with duration' in post_filter and '->' in post_filter:
 
condition, replacement = post_filter.split('->')
 
replacement = replacement.strip()
 
assert replacement in labels
 
# the predicted actions are (decoded) labels, not class indices
 
#replacement = labels.index(replacement)
 
label, constraint = condition.split(' with duration')
 
label = label.strip()
 
assert label in labels
 
assert re.match(r'^ *(<|<=|>|>=|==) *[0-9.]+ ?s? *$', constraint) is not None
 
constraint = constraint.rstrip('s ')
 
constraint = f'duration{constraint}'
 
k = 0
 
while k < len(actions):
 
if actions[k] == label:
 
i = k
 
k += 1
 
while k < len(actions) and actions[k] == label:
 
k += 1
 
j = k - 1
 
duration = t[j] - t[i]
 
if eval(constraint):
 
logging.info(f'larva {larva_id}: replacing labels at {k-i} steps ({round(duration, 3)}s): {actions[i]}->{replacement}')
 
for j in range(i, k):
 
actions[j] = replacement
 
else:
 
k += 1
else:
else:
raise ValueError(f"filter not supported: {post_filter}")
raise ValueError(f"filter not supported: {post_filter}")
return labels
return actions
from taggingbackends.main import main
from taggingbackends.main import main
Loading