Skip to content
Snippets Groups Projects
embed_model.py 8.76 KiB
#from taggingbackends.data.labels import Labels
from taggingbackends.features.skeleton import get_5point_spines
from maggotuba.features.preprocess import Preprocessor
from maggotuba.models.modules import MaggotEncoder
from maggotuba.models.trainers import new_generator
from behavior_model.models.neural_nets import device
from taggingbackends.explorer import check_permissions
from collections import defaultdict
import numpy as np
import logging
import torch
import h5py


def embed_model(backend, **kwargs):
    """
    Project input track data into the MaggotUBA latent space (or embedding).

    Input files are looked up in `data/interim` first, and in `data/raw` if
    `data/interim` holds no files. Supported track data files end with
    `.spine`, `.mat` or `.csv`; other files are ignored.

    The embeddings are saved in `data/processed`, in an `embeddings.h5` file
    (or `embeddings-<i>.h5` if that name is taken), in the subdirectory that
    mirrors the input file's location in `data/interim` or `data/raw`.

    Keyword arguments:
        debug: if true, set the root logger level to DEBUG.

    Raises:
        RuntimeError: if no input files are found, or if the model directory
            does not contain exactly one encoder config file.
    """
    if kwargs.pop('debug', False):
        logging.root.setLevel(logging.DEBUG)

    # we pick files in `data/interim` if any, otherwise in `data/raw`
    input_files = backend.list_interim_files(group_by_directories=True)
    if not input_files:
        input_files = backend.list_input_files(group_by_directories=True)
    if not input_files:
        # explicit error rather than `assert`, which is stripped under -O
        raise RuntimeError("no input files found in data/interim or data/raw")

    # load the model; `clf_config.json` belongs to the classifier, not to the
    # encoder, so it is excluded
    model_files = backend.list_model_files()
    config_files = [file
                    for file in model_files
                    if file.name.endswith('config.json')]
    if not config_files:
        raise RuntimeError(f"no config files found for tagger: {backend.model_instance}")
    encoder_config_files = [file
                            for file in config_files
                            if file.name != 'clf_config.json']
    if not encoder_config_files:
        raise RuntimeError("no encoder config files found")
    if len(encoder_config_files) > 1:
        # previously fell through with `model` unbound (NameError)
        found = ", ".join(str(file) for file in encoder_config_files)
        raise RuntimeError(f"multiple encoder config files found: {found}")
    model = MaggotEncoder(encoder_config_files[0])

    # project the input data files into the latent space
    embed_individual_data_files(backend, model, input_files)

def embed_individual_data_files(backend, encoder, input_files):
    """
    Project track data files into the encoder's latent space and save the
    result to a single HDF5 file.

    Parameters:
        backend: tagging backend; used to locate the output directory.
        encoder: MaggotUBA encoder; set to eval mode and moved to `device`.
        input_files: mapping whose values are lists of input files, grouped
            by directory (as passed by `embed_model`).

    Within each group, only the first supported file (`.spine`, `.mat` or
    `.csv`) is embedded; the other files are ignored. The embeddings of all
    groups are saved together in one file, located by `get_output_filepath`
    from the last embedded file. It holds four datasets with one row per
    embedded time step: `run_id`, `track_id`, `time` and `embedding`.

    Raises:
        RuntimeError: if no track could be embedded at all.
    """
    encoder.eval()
    encoder.to(device)

    # features[run][larva] = (time steps, latent features)
    features = defaultdict(dict)
    reference_file = None  # last embedded file; determines the output path

    for files in input_files.values():
        done = False
        for file in files:
            if done:
                logging.info(f"ignoring file: {file.name}")
                continue
            loaded = _load_track_data(file)
            if loaded is None:
                logging.info(f"ignoring file: {file.name}")
                continue
            run, t, data, larvae = loaded
            for larva, track in _embed_tracks(encoder, t, data, larvae).items():
                if larva in features[run]:
                    logging.warning(f"duplicate track {larva} in run {run}; overwriting")
                features[run][larva] = track
            reference_file = file
            done = True

    if not features:
        raise RuntimeError("no track could be embedded")

    # format the latent features and related info as matrices; sizes are
    # derived from the stored tracks, so overwritten duplicates cannot leave
    # zero-filled rows behind
    run_id, track_id, times, latents = [], [], [], []
    for run, tracks in features.items():
        for track, (timesteps, latent) in tracks.items():
            nsteps = len(timesteps)
            run_id.extend([run] * nsteps)
            track_id.extend([track] * nsteps)
            times.append(np.asarray(timesteps))
            latents.append(latent)

    # save the vectorized data to file
    embeddings = get_output_filepath(backend, reference_file)
    with h5py.File(embeddings, 'w') as f:
        f['run_id'] = _as_h5_data(run_id)
        f['track_id'] = _as_h5_data(track_id)
        f['time'] = np.concatenate(times)
        f['embedding'] = np.concatenate(latents, axis=0)
    check_permissions(embeddings)


def _load_track_data(file, camera_framerate=None):
    """
    Load the spines of one track data file.

    Returns a `(run, t, data, larvae)` tuple, or None if the file type is not
    handled. `larvae` holds one larva id per row of `data` for `.spine` files
    and is None otherwise; for `.mat` files `data` and `t` are presumably
    dicts keyed by larva (TODO confirm for `.csv`). `camera_framerate` is
    only used for `.csv` files and defaults to 30 fps.
    """
    from taggingbackends.data.trxmat import TrxMat
    from taggingbackends.data.chore import load_spine
    import taggingbackends.data.fimtrack as fimtrack

    name = file.name
    if name.endswith(".spine"):
        spine = load_spine(file)
        run = spine["date_time"].iloc[0]
        larvae = spine["larva_id"].values
        t = spine["time"].values
        data = spine.iloc[:, 3:].values
        return run, t, data, larvae
    if name.endswith(".mat"):
        trx = TrxMat(file)
        t = trx["t"]
        run, data = next(iter(trx["spine"].items()))
        if run == "spine":
            run, data = next(iter(data.items()))
        return run, t[run], data, None
    if name.endswith(".csv"):
        # no label file is loaded here, so the frame rate must be given or
        # assumed (the former code read an undefined `labels` object)
        if camera_framerate:
            logging.info(f"camera frame rate: {camera_framerate}fps")
        else:
            logging.info("assuming 30-fps camera frame rate")
            camera_framerate = 30
        t, data = fimtrack.read_spines(file, fps=camera_framerate)
        return "NA", t, data, None
    # `.outline` files (the `.spine` file is used instead), and label files:
    # label files are not processed; only their data dependencies are
    return None


def _embed_tracks(encoder, t, data, larvae):
    """
    Embed all the tracks of one loaded file.

    Returns a dict mapping each larva id to `(time steps, latent features)`.
    Tracks that fail to be windowed are logged and left out.
    """
    # downsample the skeleton
    if isinstance(data, dict):
        if not data:
            return {}
        data = {larva: get_5point_spines(spines) for larva, spines in data.items()}
    else:
        data = get_5point_spines(data)

    preprocessor = Preprocessor(encoder)
    if isinstance(data, dict):
        ref_length = np.median(np.concatenate([
            preprocessor.body_length(spines) for spines in data.values()
            ]))
        tracks = [(larva, t[larva], spines) for larva, spines in data.items()]
    else:
        if larvae is None:
            raise RuntimeError("larva ids are missing for array-form track data")
        ref_length = np.median(preprocessor.body_length(data))
        tracks = []
        for larva in np.unique(larvae):
            mask = larvae == larva
            tracks.append((larva, t[mask], data[mask]))
    preprocessor.average_body_length = ref_length
    logging.info(f"average body length: {ref_length}")

    embedded = {}
    for larva, timesteps, spines in tracks:
        latentfeatures = _embed(preprocessor, encoder, timesteps, spines)
        if latentfeatures is None:
            logging.info(f"failure to window track: {larva}")
        else:
            embedded[larva] = (timesteps, latentfeatures)
    return embedded


def _as_h5_data(values):
    """
    Turn a list of identifiers into an array that h5py can store.

    h5py does not support NumPy's fixed-width unicode ('U') dtype, so text
    (or other object) identifiers are stored as variable-length UTF-8
    strings; numeric identifiers are kept as is.
    """
    array = np.asarray(values)
    if array.dtype.kind in ('U', 'O'):
        return np.array([str(value) for value in values], dtype=h5py.string_dtype())
    return array

@torch.no_grad()
def _embed(preprocessor, encoder, t, data):
    rawfeatures = preprocessor(t, data)
    if rawfeatures is not None:
        rawfeatures = torch.from_numpy(rawfeatures.astype(np.float32))
        latentfeatures = encoder(rawfeatures).detach().numpy()
        assert len(t) == len(latentfeatures)
        return latentfeatures

def get_output_filepath(backend, file):
    """
    Return an unused path for the embeddings file derived from `file`.

    The path mirrors the location of `file` relative to the interim data
    directory (checked first) or the raw data directory, under the processed
    data directory, which is created if needed. The file is named
    `embeddings.h5`, or `embeddings-<i>.h5` with the smallest positive `i`
    that is not already taken.

    Raises:
        ValueError: if `file` is neither in the interim nor in the raw data
            directory.
    """
    parent = file.parent
    # Path.relative_to instead of a string-prefix test, so that a sibling such
    # as `data/interim_raw` is not mistaken for a subdirectory of `data/interim`
    for datadir in (backend.interim_data_dir(), backend.raw_data_dir()):
        try:
            subdir = parent.relative_to(datadir)
        except ValueError:
            continue
        break
    else:
        raise ValueError(f"file not in the interim or raw data directory: {file}")
    parentdir = backend.processed_data_dir() / subdir
    parentdir.mkdir(parents=True, exist_ok=True)
    target = parentdir / "embeddings.h5"
    if target.is_file():
        logging.info(f"output file already exists: {target}")
        i = 0
        while True:
            i += 1
            target = parentdir / f"embeddings-{i}.h5"
            if not target.is_file():
                break
    return target


def repeat(items, n):
    """
    Yield each element of `items` as many times as the matching count in `n`.

    Pairs are formed positionally and iteration stops at the shorter of the
    two sequences; a count of zero (or less) yields nothing for that element.
    Example: repeat("ab", [2, 1]) yields "a", "a", "b".
    """
    pairs = zip(items, n)
    for element, count in pairs:
        yield from (element for _ in range(count))


from taggingbackends.main import main

if __name__ == "__main__":
    # Hand `embed_model` to taggingbackends' `main`, which presumably builds
    # the backend from the command line and calls it — confirm in
    # taggingbackends.main.
    main(embed_model)