diff --git a/pyproject.toml b/pyproject.toml
index 59cc1eddb90b910e9013ad0bea169ce2f4762a98..a9a5786cf72ab925e8aaf210491cf6972e873723 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "MaggotUBA-adapter"
-version = "0.19"
+version = "0.20"
 description = "Interface between MaggotUBA and the Nyx tagging UI"
 authors = ["François Laurent"]
 license = "MIT"
diff --git a/src/maggotuba/models/predict_model.py b/src/maggotuba/models/predict_model.py
index cc0d50ba38f240f674b3140481e67209908b1778..7157dfdcab5320f1078f485731e2e33810d5fae1 100644
--- a/src/maggotuba/models/predict_model.py
+++ b/src/maggotuba/models/predict_model.py
@@ -4,6 +4,7 @@ from maggotuba.models.trainers import MaggotTrainer, MultiscaleMaggotTrainer, Ma
 import numpy as np
 import logging
 import os.path
+import re
 
 tracking_data_file_extensions = ('.spine', '.outline', '.csv', '.mat', '.hdf5')
 
@@ -133,8 +134,10 @@ def predict_individual_data_files(backend, model, input_files_and_labels):
             if predictions is None:
                 logging.info(f"failure to window track: {larva}")
             else:
-                predictions = apply_filters(predictions, post_filters)
-                labels[run, larva] = dict(_zip(t[larva], predictions))
+                t_ = t[larva]
+                predictions = apply_filters(t_, predictions,
+                        post_filters, labels.decoding_label_list, larva)
+                labels[run, larva] = dict(_zip(t_, predictions))
     else:
         ref_length = np.median(model.body_length(data))
         model.average_body_length = ref_length
@@ -145,8 +148,10 @@ def predict_individual_data_files(backend, model, input_files_and_labels):
             if predictions is None:
                 logging.info(f"failure to window track: {larva}")
             else:
-                predictions = apply_filters(predictions, post_filters)
-                labels[run, larva] = dict(_zip(t[mask], predictions))
+                t_ = t[mask]
+                predictions = apply_filters(t_, predictions,
+                        post_filters, labels.decoding_label_list, larva)
+                labels[run, larva] = dict(_zip(t_, predictions))
 
     # save the predicted labels to file
     labels.dump(get_output_filepath(backend, file))
@@ -184,23 +189,52 @@ def get_output_filepath(backend, file):
         if not target.is_file():
             break
     return target
 
-def apply_filters(labels, post_filters):
+def apply_filters(t, actions, post_filters, labels, larva_id):
+    # `labels` can be used to map labels to class indices
+    # `larva_id` is for logging only; no filters should rely on this datum
     if post_filters:
         for post_filter in post_filters:
             if post_filter == 'ABA->AAA':
                 # modify sequentially
-                for k in range(1, len(labels)-1):
-                    label = labels[k-1]
-                    if labels[k] != label and label == labels[k+1]:
-                        labels[k] = label
+                for k in range(1, len(actions)-1):
+                    action = actions[k-1]
+                    if actions[k] != action and action == actions[k+1]:
+                        actions[k] = action
             elif post_filter == 'ABC->AAC':
                 # modify sequentially
-                for k in range(1, len(labels)-1):
-                    if labels[k-1] != labels[k] and labels[k] != labels[k+1]:
-                        labels[k] = labels[k-1]
+                for k in range(1, len(actions)-1):
+                    if actions[k-1] != actions[k] and actions[k] != actions[k+1]:
+                        actions[k] = actions[k-1]
+            elif ' with duration' in post_filter and '->' in post_filter:
+                condition, replacement = post_filter.split('->')
+                replacement = replacement.strip()
+                assert replacement in labels
+                # the predicted actions are (decoded) labels, not class indices
+                #replacement = labels.index(replacement)
+                label, constraint = condition.split(' with duration')
+                label = label.strip()
+                assert label in labels
+                assert re.match(r'^ *(<|<=|>|>=|==) *[0-9.]+ ?s? *$', constraint) is not None
+                constraint = constraint.rstrip('s ')
+                constraint = f'duration{constraint}'
+                k = 0
+                while k < len(actions):
+                    if actions[k] == label:
+                        i = k
+                        k += 1
+                        while k < len(actions) and actions[k] == label:
+                            k += 1
+                        j = k - 1
+                        duration = t[j] - t[i]
+                        if eval(constraint):
+                            logging.info(f'larva {larva_id}: replacing labels at {k-i} steps ({round(duration, 3)}s): {actions[i]}->{replacement}')
+                            for j in range(i, k):
+                                actions[j] = replacement
+                    else:
+                        k += 1
             else:
                 raise ValueError(f"filter not supported: {post_filter}")
-    return labels
+    return actions
 
 from taggingbackends.main import main
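For reference, a minimal standalone sketch of the new duration-based post-filter introduced in apply_filters above; the behaviour names ('hunch', 'crawl'), the timestamps and the filter string are illustrative assumptions, not values taken from the repository:

    import re

    def apply_duration_filter(t, actions, post_filter, label_list):
        # filter format: "<label> with duration <op> <seconds>s-><replacement>"
        condition, replacement = post_filter.split('->')
        replacement = replacement.strip()
        assert replacement in label_list
        label, constraint = condition.split(' with duration')
        label = label.strip()
        assert label in label_list
        assert re.match(r'^ *(<|<=|>|>=|==) *[0-9.]+ ?s? *$', constraint)
        constraint = f"duration{constraint.rstrip('s ')}"
        k = 0
        while k < len(actions):
            if actions[k] == label:
                i = k
                # advance past the run of identical labels
                while k < len(actions) and actions[k] == label:
                    k += 1
                duration = t[k - 1] - t[i]
                if eval(constraint):  # e.g. "duration < 0.5"
                    actions[i:k] = [replacement] * (k - i)
            else:
                k += 1
        return actions

    # hypothetical example: relabel short "hunch" bouts as "crawl"
    t = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5]
    actions = ['crawl', 'hunch', 'hunch', 'crawl', 'hunch', 'hunch']
    print(apply_duration_filter(t, actions,
                                'hunch with duration < 0.5s->crawl',
                                ['crawl', 'hunch']))
    # ['crawl', 'crawl', 'crawl', 'crawl', 'crawl', 'crawl']

As in the diff, the duration constraint is turned into a small Python expression and evaluated against the elapsed time of each run of identical labels.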