From f9e9aff9340dc853fc93bc90c42f2c64637c44f7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20LAURENT?= <francois.laurent@pasteur.fr> Date: Wed, 12 Oct 2022 20:30:44 +0200 Subject: [PATCH] interpolation support at commandline and predict_model levels --- src/LarvaDatasets.jl | 4 +++ src/taggingbackends/explorer.py | 6 ++-- src/taggingbackends/features/skeleton.py | 43 ++++++++++++++++++++++++ src/taggingbackends/main.py | 16 +++++++-- 4 files changed, 65 insertions(+), 4 deletions(-) diff --git a/src/LarvaDatasets.jl b/src/LarvaDatasets.jl index d733130..50b0647 100644 --- a/src/LarvaDatasets.jl +++ b/src/LarvaDatasets.jl @@ -250,6 +250,7 @@ end function write_larva_dataset_hdf5(path, counts, files, refs, nsteps_before, nsteps_after; fixmwt=false, frameinterval=nothing, ) + fixmwt && @warn "`fixmwt=true` is no longer supported" # this method mutates argument `refs` refsā²= Tuple{Int, Int, Int, eltype(keys(counts))}[] for (label, count) in pairs(counts) @@ -278,6 +279,9 @@ function write_larva_dataset_hdf5(path, counts, files, refs, nsteps_before, nste # extension h5["labels"] = collect(keys(counts)) #h5["files"] = [f.source for f in files] + if !isnothing(frameinterval) + attributes(h5["samples"])["frame_interval"] = frameinterval + end end end diff --git a/src/taggingbackends/explorer.py b/src/taggingbackends/explorer.py index f6ecdaf..07cd659 100644 --- a/src/taggingbackends/explorer.py +++ b/src/taggingbackends/explorer.py @@ -415,7 +415,8 @@ run `poetry add {pkg}` from directory: \n return input_files, labels def generate_dataset(self, input_files, - labels=None, window_length=20, sample_size=None): + labels=None, window_length=20, sample_size=None, + frame_interval=None): """ Generate a *larva_dataset hdf5* file in data/interim/{instance}/ """ @@ -424,7 +425,8 @@ run `poetry add {pkg}` from directory: \n input_files if isinstance(input_files, list) else str(input_files), window_length, labels=labels, - sample_size=sample_size) + sample_size=sample_size, + frameinterval=frame_interval) def compile_trxmat_database(self, input_dir, labels=None, window_length=20, sample_size=None, reuse_h5files=False): diff --git a/src/taggingbackends/features/skeleton.py b/src/taggingbackends/features/skeleton.py index d52f458..959094a 100644 --- a/src/taggingbackends/features/skeleton.py +++ b/src/taggingbackends/features/skeleton.py @@ -19,3 +19,46 @@ def get_5point_spines(spine): return spine else: raise NotImplementedError(spine.shape) + +def interpolate(times, spines, anchor, window_length, + spine_interpolation='linear', frame_interval=0.1, **kwargs): + """ + Interpolate spine series around anchor time `times[anchor]`, with about + `window_length // 2` time steps before and after, evenly spaced by + `frame_interval`. + + Only linear interpolation is supported for now. + """ + # m = anchor + # n = m + window_length + # if n <= spines.shape[0]: + # return spines[m:n,:] + # else: + # return + assert spine_interpolation == 'linear' + tstart, anchor, tstop = times[0], times[anchor], times[-1] + istart = np.trunc((tstart - anchor) / frame_interval).astype(int) + istop = np.trunc((tstop - anchor) / frame_interval).astype(int) + nframes_before = window_length // 2 + nframes_after = window_length - 1 - nframes_before + istart = max(-nframes_before, istart) + istop = min(nframes_after, istop) + if istop - istart + 1 < window_length: + return + grid = range(istart, istop+1) + series = [] + for i in grid: + t = round((anchor + i * frame_interval) * 1e4) * 1e-4 + inext = np.flatnonzero(t <= times)[0] + tnext, xnext = times[inext], spines[inext] + if tnext == t: + x = xnext + else: + assert 0 < inext + tprev, xprev = times[inext-1], spines[inext-1] + x = interp(xprev, xnext, (t - tprev) / (tnext - tprev)) + series.append(x) + return np.stack(series, axis=0) + +def interp(x0, x1, alpha): + return (1 - alpha) * x0 + alpha * x1 diff --git a/src/taggingbackends/main.py b/src/taggingbackends/main.py index cabbc61..3046369 100644 --- a/src/taggingbackends/main.py +++ b/src/taggingbackends/main.py @@ -8,7 +8,7 @@ def help(_print=False): Usage: tagging-backend [train|predict] --model-instance <name> tagging-backend train ... --labels <comma-separated-list> tagging-backend train ... --sample-size <N> --window-length <T> - tagging-backend train ... --trxmat-only --reuse-h5files + tagging-backend train ... --frame-interval <I> tagging-backend train ... --pretrained-model-instance <name> tagging-backend predict ... --skip-make-dataset @@ -21,6 +21,10 @@ spines. If option `--sample-size` is passed, <N> time segments are sampled from the raw database. The total length 3*<T> of time segments is 60 per default (20 *past* points, 20 *present* points and 20 *future* points). +If frame interval <I> is specified (in seconds), spine series are resampled and +interpolated around each time segment anchor (center). + +**Deprecated**: Option `--trxmat-only` is suitable for large databases made of trx.mat files only. Intermediate HDF5 files are generated prior to counting the various behavior labels and sampling time segments in the database. These intermediate @@ -48,7 +52,7 @@ def main(fun=None): train_or_predict = sys.argv[1] project_dir = model_instance = None input_files, labels = [], [] - sample_size = window_length = None + sample_size = window_length = frame_interval = None trxmat_only = reuse_h5files = False skip_make_dataset = skip_build_features = False pretrained_model_instance = None @@ -72,6 +76,9 @@ def main(fun=None): elif sys.argv[k] == "--window-length": k = k + 1 window_length = sys.argv[k] + elif sys.argv[k] == "--frame-interval": + k = k + 1 + frame_interval = sys.argv[k] elif sys.argv[k] == "--trxmat-only": trxmat_only = True elif sys.argv[k] == "--reuse-h5files": @@ -98,6 +105,8 @@ def main(fun=None): make_dataset_kwargs["sample_size"] = sample_size if window_length: make_dataset_kwargs["window_length"] = window_length + if frame_interval: + make_dataset_kwargs["frame_interval"] = frame_interval if trxmat_only: make_dataset_kwargs["trxmat_only"] = True if reuse_h5files: @@ -125,6 +134,9 @@ def main(fun=None): elif key in ("sample_size", "window_length"): if isinstance(val, str): val = int(val) + elif key in ("frame_interval",): + if isinstance(val, str): + val = float(val) elif key == "labels": if isinstance(val, str): val = val.split(',') -- GitLab