From 38d01ab856cbef13a7d0f2d09201263737c05341 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Fran=C3=A7ois=20Laurent?= <francois.laurent@posteo.net>
Date: Wed, 3 Jan 2024 16:16:28 +0100
Subject: [PATCH] simpler format for embeddings.h5
---
README.md | 30 ++++++++-----
src/maggotuba/models/embed_model.py | 67 +++++++++++++++++++++++------
2 files changed, 72 insertions(+), 25 deletions(-)
diff --git a/README.md b/README.md
index 6590557..93e9121 100644
--- a/README.md
+++ b/README.md
@@ -174,17 +174,25 @@ Similarly to the other commands, input data files are expected in the data/raw d
The above command produces an `embeddings.h5` file in the data/processed/20230311 directory.
-The `embeddings.h5` file is an HDF5 file structured as follows:
+The `embeddings.h5` file is an HDF5 file containing several arrays that all have one element (or one row) per
+embedded/projected data point. This file is structured as follows:
```
-├── n_runs <- integer; number of files or assays or runs.
-├── run_0 <- dataset named run_<i> with <i> ranging from 0 to n_runs - 1.
-... ├── run_id <- string; run id, typically date and time in the yyyymmdd_HHMMSS format.
- ├── n_tracks <- integer; number of tracks or larvae.
- ├── track_0 <- dataset named track_<j> with <j> ranging from 0 to n_tracks - 1.
- ... ├── track_id <- integer; track id as referred to as in the original tracking data file.
- ├── n_steps <- integer; number of projected time steps or segments.
- ├── step_0 <- dataset named step_<k> with <k> ranging from 0 to n_steps - 1.
- ... ├── time <- float; timestamp of the step or segment center.
- └── embedding <- float array.
+├── run_id <- 1D array, typically of strings; id of the tracking data file or assay or run.
+├── track_id <- 1D array of integers; id of the track or larva.
+├── time <- 1D array of floats; timestamp of the time step or time segment center.
+└── embedding <- 2D array of floats; coordinates in the latent space.
```
This format is not compatible with the `clustering.cache` file used by [MaggotUBA's ToMATo UI](https://github.com/DecBayComp/Detecting_subtle_behavioural_changes/blob/ee73f0dd294a991322a0eec8f6ce69488c7a1f9a/maggotuba/src/maggotuba/cli/cli_model_clustering.py#L129-L164).
+
+Track ids are unique only within a run, not across runs. Similarly, timestamps do not share a common time origin across runs.
+
+To visualize the embeddings, the `embedding` matrix can be loaded and transformed with methods like UMAP:
+```python
+import h5py
+import umap
+
+with h5py.File('embeddings.h5', 'r') as f:
+ embedding = f['embedding'][...]
+
+embedding2d = umap.UMAP().fit_transform(embedding)
+```
diff --git a/src/maggotuba/models/embed_model.py b/src/maggotuba/models/embed_model.py
index 58eaa49..bc814b7 100644
--- a/src/maggotuba/models/embed_model.py
+++ b/src/maggotuba/models/embed_model.py
@@ -4,6 +4,7 @@ from maggotuba.features.preprocess import Preprocessor
from maggotuba.models.modules import MaggotEncoder
from maggotuba.models.trainers import new_generator
from behavior_model.models.neural_nets import device
+from taggingbackends.explorer import check_permissions
from collections import defaultdict
import numpy as np
import logging
@@ -61,6 +62,7 @@ def embed_individual_data_files(backend, encoder, input_files):
encoder.to(device)
features = defaultdict(dict)
+ npoints = 0
for input_files in input_files.values():
done = False
@@ -122,6 +124,7 @@ def embed_individual_data_files(backend, encoder, input_files):
logging.info(f"failure to window track: {larva}")
else:
features[run][larva] = (t[larva], latentfeatures)
+ npoints += len(latentfeatures)
else:
ref_length = np.median(preprocessor.body_length(data))
preprocessor.average_body_length = ref_length
@@ -133,24 +136,54 @@ def embed_individual_data_files(backend, encoder, input_files):
logging.info(f"failure to window track: {larva}")
else:
features[run][larva] = (t[mask], latentfeatures)
+ npoints += len(latentfeatures)
done = True
+ # format the latent features and related info as matrices
+ run_id = []
+ run_id_repeats = []
+ sample_run = next(iter(features.values()))
+ sample_track_id, (sample_times, sample_embedding) = next(iter(sample_run.items()))
+ nfeatures = sample_embedding.shape[1]
+ track_id = np.zeros(npoints, dtype=type(sample_track_id))
+ t = np.zeros(npoints, dtype=sample_times.dtype)
+ embedding = np.zeros((npoints, nfeatures), dtype=sample_embedding.dtype)
+ i = 0
+ for run, tracks in features.items():
+ run_id.append(run)
+ repeats = 0
+ for track, (timesteps, ftr) in tracks.items():
+ j = len(timesteps)
+ track_id[i:i+j] = track
+ t[i:i+j] = timesteps
+ embedding[i:i+j] = ftr
+ i += j
+ repeats += j
+ run_id_repeats.append(repeats)
+ run_id = list(repeat(run_id, run_id_repeats))
+
# save the vectorized data to file
- with h5py.File(get_output_filepath(backend, file), 'w') as f:
- f['n_runs'] = len(features)
- for i, (run, tracks) in enumerate(features.items()):
- g = f.create_group(f'run_{i}')
- g['run_id'] = run
- g['n_tracks'] = len(tracks)
- for j, (track, (t, latent)) in enumerate(tracks.items()):
- h = g.create_group(f'track_{j}')
- h['track_id'] = track
- h['n_steps'] = len(t)
- for k, (t, latent) in enumerate(zip(t, latent)):
- l = h.create_group(f'step_{k}')
- l['time'] = t
- l['embedding'] = latent
+ embeddings = get_output_filepath(backend, file)
+ with h5py.File(embeddings, 'w') as f:
+ # f['n_runs'] = len(features)
+ # for i, (run, tracks) in enumerate(features.items()):
+ # g = f.create_group(f'run_{i}')
+ # g['run_id'] = run
+ # g['n_tracks'] = len(tracks)
+ # for j, (track, (t, latent)) in enumerate(tracks.items()):
+ # h = g.create_group(f'track_{j}')
+ # h['track_id'] = track
+ # h['n_steps'] = len(t)
+ # for k, (t, latent) in enumerate(zip(t, latent)):
+ # l = h.create_group(f'step_{k}')
+ # l['time'] = t
+ # l['embedding'] = latent
+ f['run_id'] = run_id
+ f['track_id'] = track_id
+ f['time'] = t
+ f['embedding'] = embedding
+ check_permissions(embeddings)
@torch.no_grad()
def _embed(preprocessor, encoder, t, data):
@@ -182,6 +215,12 @@ def get_output_filepath(backend, file):
return target
+def repeat(items, n):
+ for item, n in zip(items, n):
+ for _ in range(n):
+ yield item
+
+
from taggingbackends.main import main
if __name__ == "__main__":
--
GitLab