From f9e9aff9340dc853fc93bc90c42f2c64637c44f7 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Fran=C3=A7ois=20LAURENT?= <francois.laurent@pasteur.fr>
Date: Wed, 12 Oct 2022 20:30:44 +0200
Subject: [PATCH] interpolation support at commandline and predict_model levels
---
src/LarvaDatasets.jl | 4 +++
src/taggingbackends/explorer.py | 6 ++--
src/taggingbackends/features/skeleton.py | 43 ++++++++++++++++++++++++
src/taggingbackends/main.py | 16 +++++++--
4 files changed, 65 insertions(+), 4 deletions(-)
diff --git a/src/LarvaDatasets.jl b/src/LarvaDatasets.jl
index d733130..50b0647 100644
--- a/src/LarvaDatasets.jl
+++ b/src/LarvaDatasets.jl
@@ -250,6 +250,7 @@ end
function write_larva_dataset_hdf5(path, counts, files, refs, nsteps_before, nsteps_after;
fixmwt=false, frameinterval=nothing,
)
+ fixmwt && @warn "`fixmwt=true` is no longer supported"
# this method mutates argument `refs`
refsā²= Tuple{Int, Int, Int, eltype(keys(counts))}[]
for (label, count) in pairs(counts)
@@ -278,6 +279,9 @@ function write_larva_dataset_hdf5(path, counts, files, refs, nsteps_before, nste
# extension
h5["labels"] = collect(keys(counts))
#h5["files"] = [f.source for f in files]
+ if !isnothing(frameinterval)
+ attributes(h5["samples"])["frame_interval"] = frameinterval
+ end
end
end
diff --git a/src/taggingbackends/explorer.py b/src/taggingbackends/explorer.py
index f6ecdaf..07cd659 100644
--- a/src/taggingbackends/explorer.py
+++ b/src/taggingbackends/explorer.py
@@ -415,7 +415,8 @@ run `poetry add {pkg}` from directory: \n
return input_files, labels
def generate_dataset(self, input_files,
- labels=None, window_length=20, sample_size=None):
+ labels=None, window_length=20, sample_size=None,
+ frame_interval=None):
"""
Generate a *larva_dataset hdf5* file in data/interim/{instance}/
"""
@@ -424,7 +425,8 @@ run `poetry add {pkg}` from directory: \n
input_files if isinstance(input_files, list) else str(input_files),
window_length,
labels=labels,
- sample_size=sample_size)
+ sample_size=sample_size,
+ frameinterval=frame_interval)
def compile_trxmat_database(self, input_dir,
labels=None, window_length=20, sample_size=None, reuse_h5files=False):
diff --git a/src/taggingbackends/features/skeleton.py b/src/taggingbackends/features/skeleton.py
index d52f458..959094a 100644
--- a/src/taggingbackends/features/skeleton.py
+++ b/src/taggingbackends/features/skeleton.py
@@ -19,3 +19,46 @@ def get_5point_spines(spine):
return spine
else:
raise NotImplementedError(spine.shape)
+
+def interpolate(times, spines, anchor, window_length,
+ spine_interpolation='linear', frame_interval=0.1, **kwargs):
+ """
+ Interpolate spine series around anchor time `times[anchor]`, with about
+ `window_length // 2` time steps before and after, evenly spaced by
+ `frame_interval`.
+
+ Only linear interpolation is supported for now.
+ """
+ # m = anchor
+ # n = m + window_length
+ # if n <= spines.shape[0]:
+ # return spines[m:n,:]
+ # else:
+ # return
+ assert spine_interpolation == 'linear'
+ tstart, anchor, tstop = times[0], times[anchor], times[-1]
+ istart = np.trunc((tstart - anchor) / frame_interval).astype(int)
+ istop = np.trunc((tstop - anchor) / frame_interval).astype(int)
+ nframes_before = window_length // 2
+ nframes_after = window_length - 1 - nframes_before
+ istart = max(-nframes_before, istart)
+ istop = min(nframes_after, istop)
+ if istop - istart + 1 < window_length:
+ return
+ grid = range(istart, istop+1)
+ series = []
+ for i in grid:
+ t = round((anchor + i * frame_interval) * 1e4) * 1e-4
+ inext = np.flatnonzero(t <= times)[0]
+ tnext, xnext = times[inext], spines[inext]
+ if tnext == t:
+ x = xnext
+ else:
+ assert 0 < inext
+ tprev, xprev = times[inext-1], spines[inext-1]
+ x = interp(xprev, xnext, (t - tprev) / (tnext - tprev))
+ series.append(x)
+ return np.stack(series, axis=0)
+
+def interp(x0, x1, alpha):
+ return (1 - alpha) * x0 + alpha * x1
diff --git a/src/taggingbackends/main.py b/src/taggingbackends/main.py
index cabbc61..3046369 100644
--- a/src/taggingbackends/main.py
+++ b/src/taggingbackends/main.py
@@ -8,7 +8,7 @@ def help(_print=False):
Usage: tagging-backend [train|predict] --model-instance <name>
tagging-backend train ... --labels <comma-separated-list>
tagging-backend train ... --sample-size <N> --window-length <T>
- tagging-backend train ... --trxmat-only --reuse-h5files
+ tagging-backend train ... --frame-interval <I>
tagging-backend train ... --pretrained-model-instance <name>
tagging-backend predict ... --skip-make-dataset
@@ -21,6 +21,10 @@ spines. If option `--sample-size` is passed, <N> time segments are sampled from
the raw database. The total length 3*<T> of time segments is 60 per default
(20 *past* points, 20 *present* points and 20 *future* points).
+If frame interval <I> is specified (in seconds), spine series are resampled and
+interpolated around each time segment anchor (center).
+
+**Deprecated**:
Option `--trxmat-only` is suitable for large databases made of trx.mat files
only. Intermediate HDF5 files are generated prior to counting the various
behavior labels and sampling time segments in the database. These intermediate
@@ -48,7 +52,7 @@ def main(fun=None):
train_or_predict = sys.argv[1]
project_dir = model_instance = None
input_files, labels = [], []
- sample_size = window_length = None
+ sample_size = window_length = frame_interval = None
trxmat_only = reuse_h5files = False
skip_make_dataset = skip_build_features = False
pretrained_model_instance = None
@@ -72,6 +76,9 @@ def main(fun=None):
elif sys.argv[k] == "--window-length":
k = k + 1
window_length = sys.argv[k]
+ elif sys.argv[k] == "--frame-interval":
+ k = k + 1
+ frame_interval = sys.argv[k]
elif sys.argv[k] == "--trxmat-only":
trxmat_only = True
elif sys.argv[k] == "--reuse-h5files":
@@ -98,6 +105,8 @@ def main(fun=None):
make_dataset_kwargs["sample_size"] = sample_size
if window_length:
make_dataset_kwargs["window_length"] = window_length
+ if frame_interval:
+ make_dataset_kwargs["frame_interval"] = frame_interval
if trxmat_only:
make_dataset_kwargs["trxmat_only"] = True
if reuse_h5files:
@@ -125,6 +134,9 @@ def main(fun=None):
elif key in ("sample_size", "window_length"):
if isinstance(val, str):
val = int(val)
+ elif key in ("frame_interval",):
+ if isinstance(val, str):
+ val = float(val)
elif key == "labels":
if isinstance(val, str):
val = val.split(',')
--
GitLab