diff --git a/src/maggotuba/models/predict_model.py b/src/maggotuba/models/predict_model.py
index 78e4fd91b0d932f161b1bdde1d24ff8bbde8319d..580676a87034cc5150461f2050434789b43cc62b 100644
--- a/src/maggotuba/models/predict_model.py
+++ b/src/maggotuba/models/predict_model.py
@@ -7,12 +7,13 @@ def predict_model(backend, **kwargs):
     """
     This function generates predicted labels for all the input data.
 
-    The input files can be read from any directory.
-    All generated/modified files should be written to `data/interim` or
-    `data/processed`.
-    The predicted labels are expected in `data/processed`.
+    It currently supports only a single file at a time, either a *larva_dataset*
+    hdf5 file or a track data file of any kind.
 
-    The `predict_model.py` script is required.
+    Input files are expected in `data/interim` or `data/raw`.
+
+    The predicted labels are saved in `data/processed`, in the file
+    `predicted.label`.
     """
     # we pick files in `data/interim` if any, otherwise in `data/raw`
     input_files = backend.list_interim_files()
@@ -92,6 +93,8 @@ def predict_individual_data_files(backend, model, input_files, labels):
                 data[larva] = get_5point_spines(data[larva])
         else:
             data = get_5point_spines(data)
+        #
+        post_filters = model.clf_config.get('post_filters', None)
         # assign labels
         if isinstance(data, dict):
             ref_length = np.median(np.concatenate([
@@ -102,9 +105,10 @@ def predict_individual_data_files(backend, model, input_files, labels):
             for larva, spines in data.items():
                 predictions = model.predict((t[larva], spines))
                 if predictions is None:
-                    print(f"failure at windowing track: {larva}")
+                    print(f"failure to window track: {larva}")
                 else:
-                    labels[run, larva] = dict(zip(t[larva], predictions))
+                    predictions = apply_filters(predictions, post_filters)
+                    labels[run, larva] = dict(_zip(t[larva], predictions))
         else:
             ref_length = np.median(model.body_length(data))
             model.average_body_length = ref_length
@@ -113,9 +117,10 @@ def predict_individual_data_files(backend, model, input_files, labels):
                 mask = larvae == larva
                 predictions = model.predict((t[mask], data[mask]))
                 if predictions is None:
-                    print(f"failure at windowing track: {larva}")
+                    print(f"failure to window track: {larva}")
                 else:
-                    labels[run, larva] = dict(zip(t[mask], predictions))
+                    predictions = apply_filters(predictions, post_filters)
+                    labels[run, larva] = dict(_zip(t[mask], predictions))
     # save the predicted labels to file
     labels.dump(backend.processed_data_dir() / "predicted.label")
     #
@@ -127,6 +132,24 @@ def predict_larva_dataset(backend, model, file, labels, subset="validation"):
     dataset = LarvaDataset(file, new_generator())
     return model.predict(dataset, subset)
 
 
+def _zip(xs, ys):
+    # prevent issues similar to #2
+    assert len(xs) == len(ys)
+    return zip(xs, ys)
+
+def apply_filters(labels, post_filters):
+    if post_filters:
+        for post_filter in post_filters:
+            if post_filter == 'ABA->AAA':
+                # modify the label sequence in place, sequentially
+                for k in range(1, len(labels)-1):
+                    label = labels[k-1]
+                    if labels[k] != label and label == labels[k+1]:
+                        labels[k] = label
+            else:
+                raise ValueError(f"filter not supported: {post_filter}")
+    return labels
+
 from taggingbackends.main import main
 
diff --git a/src/maggotuba/models/trainers.py b/src/maggotuba/models/trainers.py
index 6bdc4b852a9e0441c2114652efe3e86bac8d5f33..025b95dbbb83af436d92fd9d53f6b450d92c1ab6 100644
--- a/src/maggotuba/models/trainers.py
+++ b/src/maggotuba/models/trainers.py
@@ -62,22 +62,25 @@ class MaggotTrainer:
                 win = interpolate(t, data, m, winlen, **interpolation_args)
                 if win is not None:
                     assert win.shape[0] == winlen
-                    yield win
+                    yield t[m], win
         else:
             for m in range(0, N-winlen):
                 n = m + winlen
-                yield data[m:n]
+                yield t[(m + n) // 2], data[m:n]
 
-    def pad(self, data):
-        winlen = self.config["len_traj"]
+    def pad(self, target_t, defined_t, data):
         if data.shape[0] == 1:
             return data
         else:
+            head = searchsortedfirst(target_t, defined_t[0])
+            tail = len(target_t) - (searchsortedlast(target_t, defined_t[-1]) + 1)
             ind = np.r_[
-                np.zeros(winlen // 2, dtype=int),
+                np.zeros(head, dtype=int),
                 np.arange(data.shape[0]),
-                (data.shape[1]-1) * np.ones(winlen // 2 - 1, dtype=int),
+                (data.shape[1]-1) * np.ones(tail, dtype=int),
                 ]
+            if len(ind) != len(target_t):
+                raise RuntimeError('missing time steps')
             return data[ind]
 
     def body_length(self, data):
@@ -98,12 +101,20 @@ class MaggotTrainer:
         w = np.einsum("ij,jkl", rot, np.reshape(w.T, (2, 5, -1), order='F'))
         return w
 
     def preprocess(self, t, data):
+        """
+        Preprocess a single track.
+
+        This includes applying a sliding window, resampling the track in each
+        window, normalizing the spines, etc.
+        """
+        defined_t = []
         ws = []
-        for w in self.window(t, data):
+        for t_, w in self.window(t, data):
+            defined_t.append(t_)
             ws.append(self.normalize(w))
         if ws:
-            ret = self.pad(np.stack(ws))
+            ret = self.pad(t, defined_t, np.stack(ws))
             if self.swap_head_tail:
                 ret = ret[:,:,::-1,:]
             return ret
@@ -268,3 +279,15 @@ def make_trainer(config_file, *args, **kwargs):
     model = MaggotTrainer(config_file, *args, **kwargs)
     return model
 
+
+# ports of Julia's searchsortedfirst and searchsortedlast
+def searchsortedfirst(xs, x):
+    for i, x_ in enumerate(xs):
+        if x <= x_:
+            return i
+
+def searchsortedlast(xs, x):
+    for i in range(len(xs))[::-1]:
+        x_ = xs[i]
+        if x_ <= x:
+            return i
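
For reference, a standalone sanity check of the ABA->AAA post-filter introduced in predict_model.py; the function body mirrors the one added above, and the 'run'/'cast' label values are made up for illustration only:

def apply_filters(labels, post_filters):
    # an isolated label surrounded by two identical neighbours is replaced
    # by that neighbour; the scan modifies the sequence in place
    if post_filters:
        for post_filter in post_filters:
            if post_filter == 'ABA->AAA':
                for k in range(1, len(labels) - 1):
                    label = labels[k - 1]
                    if labels[k] != label and label == labels[k + 1]:
                        labels[k] = label
            else:
                raise ValueError(f"filter not supported: {post_filter}")
    return labels

print(apply_filters(['run', 'cast', 'run', 'run', 'cast', 'cast'], ['ABA->AAA']))
# ['run', 'run', 'run', 'run', 'cast', 'cast']

Because the scan operates in place and left to right, alternating patterns such as A B A B A collapse to all A in a single pass.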
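
Similarly, a small sketch of the new padding logic in trainers.py; the time values below are invented, and the two helpers mimic Julia's searchsortedfirst/searchsortedlast as in the diff:

import numpy as np

def searchsortedfirst(xs, x):
    # index of the first element of xs that is >= x
    for i, x_ in enumerate(xs):
        if x <= x_:
            return i

def searchsortedlast(xs, x):
    # index of the last element of xs that is <= x
    for i in range(len(xs))[::-1]:
        if xs[i] <= x:
            return i

target_t  = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5]  # all time steps of a track (made up)
defined_t = [0.2, 0.3]                      # time steps for which a window was built
head = searchsortedfirst(target_t, defined_t[0])                        # -> 2
tail = len(target_t) - (searchsortedlast(target_t, defined_t[-1]) + 1)  # -> 2
ind = np.r_[np.zeros(head, dtype=int),
            np.arange(2),                  # the two defined windows
            1 * np.ones(tail, dtype=int)]  # last defined window, repeated
print(ind)  # [0 0 0 1 1 1]

This mirrors the index construction in the reworked pad() method: the first and last defined windows are duplicated so that every time step of the track receives a prediction.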