# embed_model.py
#from taggingbackends.data.labels import Labels
from taggingbackends.features.skeleton import get_5point_spines
from maggotuba.features.preprocess import Preprocessor
from maggotuba.models.modules import MaggotEncoder
from maggotuba.models.trainers import new_generator
from behavior_model.models.neural_nets import device
from taggingbackends.explorer import check_permissions
from collections import defaultdict
import numpy as np
import logging
import torch
import h5py

def embed_model(backend, **kwargs):
    """
    This function projects the input data into the MaggotUBA latent space
    (or embedding).

    It supports single *larva_dataset* hdf5 files, or (possibly multiple)
    track data files.

    Input files are expected in `data/interim` or `data/raw`.

    The resulting embeddings are saved in `data/processed`, in `embeddings.h5`
    files, following the same directory structure as in `data/interim` or
    `data/raw`.
    """
    if kwargs.pop('debug', False):
        logging.root.setLevel(logging.DEBUG)
    # we pick files in `data/interim` if any, otherwise in `data/raw`
    input_files = backend.list_interim_files(group_by_directories=True)
    if not input_files:
        input_files = backend.list_input_files(group_by_directories=True)
    assert 0 < len(input_files)
    # load the model
    model_files = backend.list_model_files()
    config_files = [file
                    for file in model_files
                    if file.name.endswith('config.json')]
    if len(config_files) == 0:
        raise RuntimeError(f"no config files found for tagger: {backend.model_instance}")
    config_files = [file
                    for file in config_files
                    if file.name != 'clf_config.json']
    if len(config_files) == 0:
        raise RuntimeError("no encoder config files found")
    elif len(config_files) == 1:
        config_file = config_files[0]
    else:
        # refuse to guess which encoder config to load
        raise RuntimeError(f"multiple encoder config files found for tagger: {backend.model_instance}")
    model = MaggotEncoder(config_file)
    # call the embedding logic on the input data files
    embed_individual_data_files(backend, model, input_files)
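
# For illustration, a hedged sketch of the grouped `input_files` argument
# passed to `embed_individual_data_files` below; paths are hypothetical and
# the exact key/value types are determined by TaggingBackends:
#
#   input_files = {
#       Path("data/interim/run_1"): [Path("data/interim/run_1/trx.mat")],
#       Path("data/interim/run_2"): [Path("data/interim/run_2/larva.outline"),
#                                    Path("data/interim/run_2/larva.spine")],
#   }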

def embed_individual_data_files(backend, encoder, input_files):
    from taggingbackends.data.trxmat import TrxMat
    from taggingbackends.data.chore import load_spine
    import taggingbackends.data.fimtrack as fimtrack
    encoder.eval()
    encoder.to(device)
    features = defaultdict(dict)
    npoints = 0
    for files in input_files.values():
        done = False
        for file in files:
            # load the input data (or features)
            if done:
                logging.info(f"ignoring file: {file.name}")
                continue
            elif file.name.endswith(".outline"):
                # skip to the spine file
                logging.info(f"ignoring file: {file.name}")
                continue
            elif file.name.endswith(".spine"):
                spine = load_spine(file)
                run = spine["date_time"].iloc[0]
                larvae = spine["larva_id"].values
                t = spine["time"].values
                data = spine.iloc[:, 3:].values
            elif file.name.endswith(".mat"):
                trx = TrxMat(file)
                t = trx["t"]
                data = trx["spine"]
                run, data = next(iter(data.items()))
                if run == "spine":
                    run, data = next(iter(data.items()))
                t = t[run]
            elif file.name.endswith(".csv"):
                # no label file is loaded in the embedding pipeline, so the
                # camera frame rate is not known; assume 30 fps
                logging.info("assuming 30-fps camera frame rate")
                t, data = fimtrack.read_spines(file, fps=30)
                run = "NA"
            else:
                # label files are not processed; only their data dependencies are
                logging.info(f"ignoring file: {file.name}")
                continue
            # downsample the skeleton
            if isinstance(data, dict):
                for larva in data:
                    data[larva] = get_5point_spines(data[larva])
            else:
                data = get_5point_spines(data)
            # embed the spine series into the latent space
            preprocessor = Preprocessor(encoder)
            if isinstance(data, dict):
                ref_length = np.median(np.concatenate([
                    preprocessor.body_length(spines) for spines in data.values()
                ]))
                preprocessor.average_body_length = ref_length
                logging.info(f"average body length: {ref_length}")
                for larva, spines in data.items():
                    latentfeatures = _embed(preprocessor, encoder, t[larva], spines)
                    if latentfeatures is None:
                        logging.info(f"failure to window track: {larva}")
                    else:
                        features[run][larva] = (t[larva], latentfeatures)
                        npoints += len(latentfeatures)
            else:
                ref_length = np.median(preprocessor.body_length(data))
                preprocessor.average_body_length = ref_length
                logging.info(f"average body length: {ref_length}")
                for larva in np.unique(larvae):
                    mask = larvae == larva
                    latentfeatures = _embed(preprocessor, encoder, t[mask], data[mask])
                    if latentfeatures is None:
                        logging.info(f"failure to window track: {larva}")
                    else:
                        features[run][larva] = (t[mask], latentfeatures)
                        npoints += len(latentfeatures)
            done = True
    # format the latent features and related info as matrices
    run_id = []
    run_id_repeats = []
    sample_run = next(iter(features.values()))
    sample_track_id, (sample_times, sample_embedding) = next(iter(sample_run.items()))
    nfeatures = sample_embedding.shape[1]
    track_id = np.zeros(npoints, dtype=type(sample_track_id))
    t = np.zeros(npoints, dtype=sample_times.dtype)
    embedding = np.zeros((npoints, nfeatures), dtype=sample_embedding.dtype)
    i = 0
    for run, tracks in features.items():
        run_id.append(run)
        repeats = 0
        for track, (timesteps, ftr) in tracks.items():
            j = len(timesteps)
            track_id[i:i+j] = track
            t[i:i+j] = timesteps
            embedding[i:i+j] = ftr
            i += j
            repeats += j
        run_id_repeats.append(repeats)
    run_id = list(repeat(run_id, run_id_repeats))
    # save the vectorized data to file
    embeddings = get_output_filepath(backend, file)
    with h5py.File(embeddings, 'w') as f:
        # f['n_runs'] = len(features)
        # for i, (run, tracks) in enumerate(features.items()):
        #     g = f.create_group(f'run_{i}')
        #     g['run_id'] = run
        #     g['n_tracks'] = len(tracks)
        #     for j, (track, (t, latent)) in enumerate(tracks.items()):
        #         h = g.create_group(f'track_{j}')
        #         h['track_id'] = track
        #         h['n_steps'] = len(t)
        #         for k, (t, latent) in enumerate(zip(t, latent)):
        #             l = h.create_group(f'step_{k}')
        #             l['time'] = t
        #             l['embedding'] = latent
        f['run_id'] = run_id
        f['track_id'] = track_id
        f['time'] = t
        f['embedding'] = embedding
    check_permissions(embeddings)
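
# For reference, a minimal sketch (not part of this module) of how the flat
# layout written above can be read back with h5py; the file path is
# hypothetical:
#
#   import h5py
#   with h5py.File("data/processed/embeddings.h5", "r") as f:
#       run_id = f["run_id"][()]        # one run identifier per time point
#       track_id = f["track_id"][()]    # one track identifier per time point
#       time = f["time"][()]            # time stamps, one per point
#       embedding = f["embedding"][()]  # (npoints, nfeatures) latent vectors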

@torch.no_grad()
def _embed(preprocessor, encoder, t, data):
    # window and normalize the spine series; the preprocessor returns None
    # when the track cannot be windowed
    rawfeatures = preprocessor(t, data)
    if rawfeatures is not None:
        rawfeatures = torch.from_numpy(rawfeatures.astype(np.float32))
        latentfeatures = encoder(rawfeatures).detach().numpy()
        assert len(t) == len(latentfeatures)
        return latentfeatures

def get_output_filepath(backend, file):
    #if file.is_relative_to(backend.interim_data_dir()): # Py>=3.9
    if str(file).startswith(str(backend.interim_data_dir())):
        subdir = file.parent.relative_to(backend.interim_data_dir())
    else:
        #assert file.is_relative_to(backend.raw_data_dir())
        assert str(file).startswith(str(backend.raw_data_dir()))
        subdir = file.parent.relative_to(backend.raw_data_dir())
    parentdir = backend.processed_data_dir() / subdir
    parentdir.mkdir(parents=True, exist_ok=True)
    target = parentdir / "embeddings.h5"
    if target.is_file():
        logging.info(f"output file already exists: {target}")
        i = 0
        while True:
            i += 1
            target = parentdir / f"embeddings-{i}.h5"
            if not target.is_file():
                break
    return target

def repeat(items, n):
    for item, count in zip(items, n):
        for _ in range(count):
            yield item
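
# For illustration (hypothetical values), `repeat` expands one item per run
# into one item per time point:
#   list(repeat(["run_A", "run_B"], [2, 3]))
#   # -> ["run_A", "run_A", "run_B", "run_B", "run_B"]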

from taggingbackends.main import main

if __name__ == "__main__":
    main(embed_model)