Skip to content
Snippets Groups Projects
Select Git revision
  • 628a7fb4622d7d75c9d9afd0f625c30a2ce6c2ed
  • master default protected
2 results

setup.py

Blame
  • embed_model.py 8.76 KiB
    #from taggingbackends.data.labels import Labels
    from taggingbackends.features.skeleton import get_5point_spines
    from maggotuba.features.preprocess import Preprocessor
    from maggotuba.models.modules import MaggotEncoder
    from maggotuba.models.trainers import new_generator
    from behavior_model.models.neural_nets import device
    from taggingbackends.explorer import check_permissions
    from collections import defaultdict
    import numpy as np
    import logging
    import torch
    import h5py
    
    
    def embed_model(backend, **kwargs):
        """
        This function projects input data into MaggotUBA latent space (or embedding).

        It supports single *larva_dataset* hdf5 files, or (possibly multiple) track
        data files.

        Input files are expected in `data/interim` or `data/raw`.

        The latent features are saved in `data/processed`, in `embeddings.h5`
        files, following the same directory structure as in `data/interim` or
        `data/raw`.
        """
        if kwargs.pop('debug', False):
            logging.root.setLevel(logging.DEBUG)

        # we pick files in `data/interim` if any, otherwise in `data/raw`
        input_files = backend.list_interim_files(group_by_directories=True)
        if not input_files:
            input_files = backend.list_input_files(group_by_directories=True)
        if not input_files:
            # raise instead of assert: input validation must survive `python -O`
            raise RuntimeError("no input files found")

        # load the model; the encoder config is any *config.json file other
        # than the classifier's own `clf_config.json`
        model_files = backend.list_model_files()
        config_files = [file
                        for file in model_files
                        if file.name.endswith('config.json')]
        if not config_files:
            raise RuntimeError(f"no config files found for tagger: {backend.model_instance}")
        config_files = [file
                        for file in config_files
                        if file.name != 'clf_config.json']
        if not config_files:
            raise RuntimeError(f"no encoder config files found")
        elif 1 < len(config_files):
            # bug fix: the original silently left `model` undefined in this
            # case, and failed later with a NameError
            raise RuntimeError(f"multiple encoder config files found: {config_files}")
        model = MaggotEncoder(config_files[0])

        # call the `predict` logic on the input data files
        embed_individual_data_files(backend, model, input_files)
    
    def embed_individual_data_files(backend, encoder, input_files):
        """
        Embed track data files into the encoder's latent space.

        `input_files` maps directories to lists of file paths, as returned by
        `backend.list_interim_files(group_by_directories=True)`. At most one
        data file per directory is processed (the first recognized one).

        The latent features of all runs/larvae are concatenated and saved as a
        single hdf5 file; see `get_output_filepath` for its location.
        """
        from taggingbackends.data.trxmat import TrxMat
        from taggingbackends.data.chore import load_spine
        import taggingbackends.data.fimtrack as fimtrack

        encoder.eval()
        encoder.to(device)

        features = defaultdict(dict)  # run id -> larva id -> (times, latent features)
        npoints = 0  # total number of time points across all runs and larvae

        # NOTE: the inner loop variable used to shadow the `input_files`
        # argument; renamed to `files` for clarity
        for files in input_files.values():
            done = False
            for file in files:

                # load the input data (or features)
                if done:
                    logging.info(f"ignoring file: {file.name}")
                    continue
                elif file.name.endswith(".outline"):
                    # skip to spine file
                    logging.info(f"ignoring file: {file.name}")
                    continue
                elif file.name.endswith(".spine"):
                    spine = load_spine(file)
                    run = spine["date_time"].iloc[0]
                    larvae = spine["larva_id"].values
                    t = spine["time"].values
                    data = spine.iloc[:,3:].values
                elif file.name.endswith(".mat"):
                    trx = TrxMat(file)
                    t = trx["t"]
                    data = trx["spine"]
                    # unwrap the outer run-id level; a second level may appear
                    # under the literal key "spine"
                    run, data = next(iter(data.items()))
                    if run == "spine":
                        run, data = next(iter(data.items()))
                    t = t[run]
                elif file.name.endswith(".csv"):
                    # bug fix: this branch referenced an undefined `labels`
                    # object (copied from the predict logic) and raised a
                    # NameError; fall back to the 30-fps default directly
                    logging.info("assuming 30-fps camera frame rate")
                    t, data = fimtrack.read_spines(file, fps=30)
                    run = "NA"
                else:
                    # label files not processed; only their data dependencies are
                    logging.info(f"ignoring file: {file.name}")
                    continue

                # downsample the skeleton to 5 points per spine
                if isinstance(data, dict):
                    for larva in data:
                        data[larva] = get_5point_spines(data[larva])
                else:
                    data = get_5point_spines(data)

                # embed each track, normalizing sizes with the assay-wide
                # median body length
                preprocessor = Preprocessor(encoder)
                if isinstance(data, dict):
                    # .mat input: one array per larva
                    ref_length = np.median(np.concatenate([
                        preprocessor.body_length(spines) for spines in data.values()
                        ]))
                    preprocessor.average_body_length = ref_length
                    logging.info(f"average body length: {ref_length}")
                    for larva, spines in data.items():
                        latentfeatures = _embed(preprocessor, encoder, t[larva], spines)
                        if latentfeatures is None:
                            logging.info(f"failure to window track: {larva}")
                        else:
                            features[run][larva] = (t[larva], latentfeatures)
                            npoints += len(latentfeatures)
                else:
                    # .spine/.csv input: a single array with a larva-id column
                    ref_length = np.median(preprocessor.body_length(data))
                    preprocessor.average_body_length = ref_length
                    logging.info(f"average body length: {ref_length}")
                    for larva in np.unique(larvae):
                        mask = larvae == larva
                        latentfeatures = _embed(preprocessor, encoder, t[mask], data[mask])
                        if latentfeatures is None:
                            logging.info(f"failure to window track: {larva}")
                        else:
                            features[run][larva] = (t[mask], latentfeatures)
                            npoints += len(latentfeatures)

                done = True

        if not features:
            # robustness: the original raised a bare StopIteration below
            raise RuntimeError("no supported data file could be processed")

        # format the latent features and related info as matrices
        run_id = []
        run_id_repeats = []
        sample_run = next(iter(features.values()))
        sample_track_id, (sample_times, sample_embedding) = next(iter(sample_run.items()))
        nfeatures = sample_embedding.shape[1]
        track_id = np.zeros(npoints, dtype=type(sample_track_id))
        t = np.zeros(npoints, dtype=sample_times.dtype)
        embedding = np.zeros((npoints, nfeatures), dtype=sample_embedding.dtype)
        i = 0
        for run, tracks in features.items():
            run_id.append(run)
            repeats = 0
            for track, (timesteps, ftr) in tracks.items():
                j = len(timesteps)
                track_id[i:i+j] = track
                t[i:i+j] = timesteps
                embedding[i:i+j] = ftr
                i += j
                repeats += j
            run_id_repeats.append(repeats)
        # expand run ids so that `run_id` aligns with the rows of `embedding`
        run_id = list(repeat(run_id, run_id_repeats))

        # save the vectorized data to file; `file` is the last data file
        # processed and determines the output directory
        embeddings = get_output_filepath(backend, file)
        with h5py.File(embeddings, 'w') as f:
            f['run_id'] = run_id
            f['track_id'] = track_id
            f['time'] = t
            f['embedding'] = embedding
        check_permissions(embeddings)
    
    @torch.no_grad()
    def _embed(preprocessor, encoder, t, data):
        rawfeatures = preprocessor(t, data)
        if rawfeatures is not None:
            rawfeatures = torch.from_numpy(rawfeatures.astype(np.float32))
            latentfeatures = encoder(rawfeatures).detach().numpy()
            assert len(t) == len(latentfeatures)
            return latentfeatures
    
    def get_output_filepath(backend, file):
        """
        Derive the output hdf5 file path for input data file `file`.

        The output file lives in `data/processed`, under the same subdirectory
        as `file` relative to `data/interim` (preferred) or `data/raw`. If
        `embeddings.h5` already exists there, the first available numbered
        variant (`embeddings-1.h5`, `embeddings-2.h5`, ...) is returned
        instead, so that no existing file is overwritten.
        """
        #if file.is_relative_to(backend.interim_data_dir()): # Py>=3.9
        if str(file).startswith(str(backend.interim_data_dir())):
            subdir = file.parent.relative_to(backend.interim_data_dir())
        else:
            #assert file.is_relative_to(backend.raw_data_dir())
            assert str(file).startswith(str(backend.raw_data_dir()))
            subdir = file.parent.relative_to(backend.raw_data_dir())
        parentdir = backend.processed_data_dir() / subdir
        parentdir.mkdir(parents=True, exist_ok=True)
        target = parentdir / "embeddings.h5"
        if target.is_file():
            # typo fix: "ouput" -> "output"
            logging.info(f"output file already exists: {target}")
            i = 1
            target = parentdir / f"embeddings-{i}.h5"
            while target.is_file():
                i += 1
                target = parentdir / f"embeddings-{i}.h5"
        return target
    
    
    def repeat(items, n):
        """
        Yield each element of `items` repeated as many times as its matching
        count in `n`; stops with the shorter of the two sequences.
        """
        for element, count in zip(items, n):
            yield from (element,) * count
    
    
    # late import: `main` provides the command-line entry point shared by all
    # taggingbackends models
    from taggingbackends.main import main

    if __name__ == "__main__":
        # dispatch command-line invocations to `embed_model`
        main(embed_model)