#from taggingbackends.data.labels import Labels
from taggingbackends.features.skeleton import get_5point_spines
from maggotuba.features.preprocess import Preprocessor
from maggotuba.models.modules import MaggotEncoder
from maggotuba.models.trainers import new_generator
from behavior_model.models.neural_nets import device
from taggingbackends.explorer import check_permissions
from collections import defaultdict
import numpy as np
import logging
import torch
import h5py


def embed_model(backend, **kwargs):
    """
    This function projects input data into the MaggotUBA latent space (or
    embedding).

    It processes (possibly multiple) track data files, *e.g.* .spine, .mat or
    FIMTrack .csv files.

    Input files are expected in `data/interim` or `data/raw`. The computed
    latent features are saved in `data/processed`, in `embeddings.h5` files,
    following the same directory structure as in `data/interim` or `data/raw`.
    """
    if kwargs.pop('debug', False):
        logging.root.setLevel(logging.DEBUG)

    # we pick files in `data/interim` if any, otherwise in `data/raw`
    input_files = backend.list_interim_files(group_by_directories=True)
    if not input_files:
        input_files = backend.list_input_files(group_by_directories=True)
    assert 0 < len(input_files)

    # load the model
    model_files = backend.list_model_files()
    config_files = [file for file in model_files
                    if file.name.endswith('config.json')]
    if len(config_files) == 0:
        raise RuntimeError(f"no config files found for tagger: {backend.model_instance}")
    config_files = [file for file in config_files
                    if file.name != 'clf_config.json']
    if len(config_files) == 0:
        raise RuntimeError("no encoder config files found")
    elif 1 < len(config_files):
        raise RuntimeError(f"multiple encoder config files found: {config_files}")
    config_file = config_files[0]
    model = MaggotEncoder(config_file)

    # call the embedding logic on the input data files
    embed_individual_data_files(backend, model, input_files)


def embed_individual_data_files(backend, encoder, input_files):
    from taggingbackends.data.trxmat import TrxMat
    from taggingbackends.data.chore import load_spine
    import taggingbackends.data.fimtrack as fimtrack

    encoder.eval()
    encoder.to(device)

    features = defaultdict(dict)
    npoints = 0
    for files_in_dir in input_files.values():
        done = False
        for file in files_in_dir:
            # load the input data (or features)
            if done:
                logging.info(f"ignoring file: {file.name}")
                continue
            elif file.name.endswith(".outline"):
                # skip to spine file
                logging.info(f"ignoring file: {file.name}")
                continue
            elif file.name.endswith(".spine"):
                spine = load_spine(file)
                run = spine["date_time"].iloc[0]
                larvae = spine["larva_id"].values
                t = spine["time"].values
                data = spine.iloc[:, 3:].values
            elif file.name.endswith(".mat"):
                trx = TrxMat(file)
                t = trx["t"]
                data = trx["spine"]
                run, data = next(iter(data.items()))
                if run == "spine":
                    run, data = next(iter(data.items()))
                t = t[run]
            elif file.name.endswith(".csv"):
                # no label metadata is available here to read the camera frame
                # rate from; assume the 30-fps default
                logging.info("assuming 30-fps camera frame rate")
                t, data = fimtrack.read_spines(file, fps=30)
                run = "NA"
            else:
                # label files not processed; only their data dependencies are
                logging.info(f"ignoring file: {file.name}")
                continue

            # downsample the skeleton to 5 points
            if isinstance(data, dict):
                for larva in data:
                    data[larva] = get_5point_spines(data[larva])
            else:
                data = get_5point_spines(data)

            # embed the spine data into the latent space
            preprocessor = Preprocessor(encoder)
            if isinstance(data, dict):
                ref_length = np.median(np.concatenate([
                    preprocessor.body_length(spines) for spines in data.values()
                    ]))
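                # the median body length over all the tracks of the run serves
                # as the reference length for the preprocessor, presumably to
                # rescale the spine coordinates prior to windowing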
                preprocessor.average_body_length = ref_length
                logging.info(f"average body length: {ref_length}")
                for larva, spines in data.items():
                    latentfeatures = _embed(preprocessor, encoder, t[larva], spines)
                    if latentfeatures is None:
                        logging.info(f"failure to window track: {larva}")
                    else:
                        features[run][larva] = (t[larva], latentfeatures)
                        npoints += len(latentfeatures)
            else:
                ref_length = np.median(preprocessor.body_length(data))
                preprocessor.average_body_length = ref_length
                logging.info(f"average body length: {ref_length}")
                for larva in np.unique(larvae):
                    mask = larvae == larva
                    latentfeatures = _embed(preprocessor, encoder, t[mask], data[mask])
                    if latentfeatures is None:
                        logging.info(f"failure to window track: {larva}")
                    else:
                        features[run][larva] = (t[mask], latentfeatures)
                        npoints += len(latentfeatures)

            done = True

    if not features:
        raise RuntimeError("no tracks could be embedded")

    # format the latent features and related info as matrices
    run_id = []
    run_id_repeats = []
    sample_run = next(iter(features.values()))
    sample_track_id, (sample_times, sample_embedding) = next(iter(sample_run.items()))
    nfeatures = sample_embedding.shape[1]
    track_id = np.zeros(npoints, dtype=type(sample_track_id))
    t = np.zeros(npoints, dtype=sample_times.dtype)
    embedding = np.zeros((npoints, nfeatures), dtype=sample_embedding.dtype)
    i = 0
    for run, tracks in features.items():
        run_id.append(run)
        repeats = 0
        for track, (timesteps, ftr) in tracks.items():
            j = len(timesteps)
            track_id[i:i+j] = track
            t[i:i+j] = timesteps
            embedding[i:i+j] = ftr
            i += j
            repeats += j
        run_id_repeats.append(repeats)
    # expand run_id so that there is one entry per data point
    run_id = list(repeat(run_id, run_id_repeats))

    # save the vectorized data to file; the output location is derived from
    # the last processed data file
    embeddings = get_output_filepath(backend, file)
    with h5py.File(embeddings, 'w') as f:
        #f['n_runs'] = len(features)
        #for i, (run, tracks) in enumerate(features.items()):
        #    g = f.create_group(f'run_{i}')
        #    g['run_id'] = run
        #    g['n_tracks'] = len(tracks)
        #    for j, (track, (t, latent)) in enumerate(tracks.items()):
        #        h = g.create_group(f'track_{j}')
        #        h['track_id'] = track
        #        h['n_steps'] = len(t)
        #        for k, (t, latent) in enumerate(zip(t, latent)):
        #            l = h.create_group(f'step_{k}')
        #            l['time'] = t
        #            l['embedding'] = latent
        f['run_id'] = run_id
        f['track_id'] = track_id
        f['time'] = t
        f['embedding'] = embedding
    check_permissions(embeddings)


@torch.no_grad()
def _embed(preprocessor, encoder, t, data):
    rawfeatures = preprocessor(t, data)
    if rawfeatures is not None:
        rawfeatures = torch.from_numpy(rawfeatures.astype(np.float32))
        latentfeatures = encoder(rawfeatures).detach().numpy()
        assert len(t) == len(latentfeatures)
        return latentfeatures


def get_output_filepath(backend, file):
    #if file.is_relative_to(backend.interim_data_dir()): # Py>=3.9
    if str(file).startswith(str(backend.interim_data_dir())):
        subdir = file.parent.relative_to(backend.interim_data_dir())
    else:
        #assert file.is_relative_to(backend.raw_data_dir())
        assert str(file).startswith(str(backend.raw_data_dir()))
        subdir = file.parent.relative_to(backend.raw_data_dir())
    parentdir = backend.processed_data_dir() / subdir
    parentdir.mkdir(parents=True, exist_ok=True)
    target = parentdir / "embeddings.h5"
    if target.is_file():
        logging.info(f"output file already exists: {target}")
        i = 0
        while True:
            i += 1
            target = parentdir / f"embeddings-{i}.h5"
            if not target.is_file():
                break
    return target


def repeat(items, counts):
    for item, count in zip(items, counts):
        for _ in range(count):
            yield item


from taggingbackends.main import main

if __name__ == "__main__":
    main(embed_model)
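

# Illustrative sketch (not part of the original module): reading back an
# `embeddings.h5` file written by `embed_individual_data_files`. The dataset
# names (`run_id`, `track_id`, `time`, `embedding`) are the ones written
# above; the default path below is hypothetical.
def _example_load_embeddings(path="data/processed/embeddings.h5"):
    with h5py.File(path, "r") as f:
        # run identifiers are stored as variable-length strings and may read
        # back as bytes depending on the h5py version
        run_id = [s.decode() if isinstance(s, bytes) else s for s in f["run_id"][()]]
        track_id = f["track_id"][()]    # one entry per embedded time point
        time = f["time"][()]            # time stamps, one per row of `embedding`
        embedding = f["embedding"][()]  # shape: (npoints, nfeatures)
    return run_id, track_id, time, embedding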