diff --git a/src/LarvaDatasets.jl b/src/LarvaDatasets.jl index d733130288f36d5b96d39c0cfbf2192e1fb0f429..50b0647da82d5b6d25ae04ea1dcadfa963387b97 100644 --- a/src/LarvaDatasets.jl +++ b/src/LarvaDatasets.jl @@ -250,6 +250,7 @@ end function write_larva_dataset_hdf5(path, counts, files, refs, nsteps_before, nsteps_after; fixmwt=false, frameinterval=nothing, ) + fixmwt && @warn "`fixmwt=true` is no longer supported" # this method mutates argument `refs` refs′= Tuple{Int, Int, Int, eltype(keys(counts))}[] for (label, count) in pairs(counts) @@ -278,6 +279,9 @@ function write_larva_dataset_hdf5(path, counts, files, refs, nsteps_before, nste # extension h5["labels"] = collect(keys(counts)) #h5["files"] = [f.source for f in files] + if !isnothing(frameinterval) + attributes(h5["samples"])["frame_interval"] = frameinterval + end end end diff --git a/src/taggingbackends/explorer.py b/src/taggingbackends/explorer.py index f6ecdaf64ad6fbacfd20b7b5890671e0d17f400b..07cd6592c494c28dd42444b68b98e22d108d5d67 100644 --- a/src/taggingbackends/explorer.py +++ b/src/taggingbackends/explorer.py @@ -415,7 +415,8 @@ run `poetry add {pkg}` from directory: \n return input_files, labels def generate_dataset(self, input_files, - labels=None, window_length=20, sample_size=None): + labels=None, window_length=20, sample_size=None, + frame_interval=None): """ Generate a *larva_dataset hdf5* file in data/interim/{instance}/ """ @@ -424,7 +425,8 @@ run `poetry add {pkg}` from directory: \n input_files if isinstance(input_files, list) else str(input_files), window_length, labels=labels, - sample_size=sample_size) + sample_size=sample_size, + frameinterval=frame_interval) def compile_trxmat_database(self, input_dir, labels=None, window_length=20, sample_size=None, reuse_h5files=False): diff --git a/src/taggingbackends/features/skeleton.py b/src/taggingbackends/features/skeleton.py index d52f458910605b3d919fa3128149f6c331bb41f2..959094ae95f9089d0edbc0d97597e8de6aaee599 100644 --- 
a/src/taggingbackends/features/skeleton.py +++ b/src/taggingbackends/features/skeleton.py @@ -19,3 +19,46 @@ def get_5point_spines(spine): return spine else: raise NotImplementedError(spine.shape) + +def interpolate(times, spines, anchor, window_length, + spine_interpolation='linear', frame_interval=0.1, **kwargs): + """ + Interpolate spine series around anchor time `times[anchor]`, with about + `window_length // 2` time steps before and after, evenly spaced by + `frame_interval`. + + Only linear interpolation is supported for now. + """ + # m = anchor + # n = m + window_length + # if n <= spines.shape[0]: + # return spines[m:n,:] + # else: + # return + assert spine_interpolation == 'linear' + tstart, anchor, tstop = times[0], times[anchor], times[-1] + istart = np.trunc((tstart - anchor) / frame_interval).astype(int) + istop = np.trunc((tstop - anchor) / frame_interval).astype(int) + nframes_before = window_length // 2 + nframes_after = window_length - 1 - nframes_before + istart = max(-nframes_before, istart) + istop = min(nframes_after, istop) + if istop - istart + 1 < window_length: + return + grid = range(istart, istop+1) + series = [] + for i in grid: + t = round((anchor + i * frame_interval) * 1e4) * 1e-4 + inext = np.flatnonzero(t <= times)[0] + tnext, xnext = times[inext], spines[inext] + if tnext == t: + x = xnext + else: + assert 0 < inext + tprev, xprev = times[inext-1], spines[inext-1] + x = interp(xprev, xnext, (t - tprev) / (tnext - tprev)) + series.append(x) + return np.stack(series, axis=0) + +def interp(x0, x1, alpha): + return (1 - alpha) * x0 + alpha * x1 diff --git a/src/taggingbackends/main.py b/src/taggingbackends/main.py index cabbc61bc7d811f0f3ac8dc74bb16678a069501d..304636906148520a3faa3a55ce17c8943b310f1a 100644 --- a/src/taggingbackends/main.py +++ b/src/taggingbackends/main.py @@ -8,7 +8,7 @@ def help(_print=False): Usage: tagging-backend [train|predict] --model-instance <name> tagging-backend train ... 
--labels <comma-separated-list> tagging-backend train ... --sample-size <N> --window-length <T> - tagging-backend train ... --trxmat-only --reuse-h5files + tagging-backend train ... --frame-interval <I> tagging-backend train ... --pretrained-model-instance <name> tagging-backend predict ... --skip-make-dataset @@ -21,6 +21,10 @@ spines. If option `--sample-size` is passed, <N> time segments are sampled from the raw database. The total length 3*<T> of time segments is 60 per default (20 *past* points, 20 *present* points and 20 *future* points). +If frame interval <I> is specified (in seconds), spine series are resampled and +interpolated around each time segment anchor (center). + +**Deprecated**: Option `--trxmat-only` is suitable for large databases made of trx.mat files only. Intermediate HDF5 files are generated prior to counting the various behavior labels and sampling time segments in the database. These intermediate @@ -48,7 +52,7 @@ def main(fun=None): train_or_predict = sys.argv[1] project_dir = model_instance = None input_files, labels = [], [] - sample_size = window_length = None + sample_size = window_length = frame_interval = None trxmat_only = reuse_h5files = False skip_make_dataset = skip_build_features = False pretrained_model_instance = None @@ -72,6 +76,9 @@ def main(fun=None): elif sys.argv[k] == "--window-length": k = k + 1 window_length = sys.argv[k] + elif sys.argv[k] == "--frame-interval": + k = k + 1 + frame_interval = sys.argv[k] elif sys.argv[k] == "--trxmat-only": trxmat_only = True elif sys.argv[k] == "--reuse-h5files": @@ -98,6 +105,8 @@ def main(fun=None): make_dataset_kwargs["sample_size"] = sample_size if window_length: make_dataset_kwargs["window_length"] = window_length + if frame_interval: + make_dataset_kwargs["frame_interval"] = frame_interval if trxmat_only: make_dataset_kwargs["trxmat_only"] = True if reuse_h5files: @@ -125,6 +134,9 @@ def main(fun=None): elif key in ("sample_size", "window_length"): if isinstance(val, str): 
val = int(val) + elif key in ("frame_interval",): + if isinstance(val, str): + val = float(val) elif key == "labels": if isinstance(val, str): val = val.split(',')