Skip to content
Snippets Groups Projects
Commit 44bfff72 authored by François  LAURENT's avatar François LAURENT
Browse files

automatic file selection delegated to taggingbackends

parent 5d744c14
Branches
Tags
1 merge request!12Set of commits to be tagged v0.19
......@@ -5,6 +5,10 @@ import numpy as np
import logging
import os.path
tracking_data_file_extensions = ('.spine', '.outline', '.csv', '.mat', '.hdf5')
def predict_model(backend, **kwargs):
"""
This function generates predicted labels for all the input data.
......@@ -25,17 +29,16 @@ def predict_model(backend, **kwargs):
input_files = backend.list_interim_files(group_by_directories=True)
if not input_files:
input_files = backend.list_input_files(group_by_directories=True)
input_files = supported_input_files(input_files)
assert 0 < len(input_files)
# initialize output labels
input_files_and_labels = backend.prepare_labels(input_files)
input_files_and_labels = backend.prepare_labels(input_files, single_input=True,
allowed_file_extensions=tracking_data_file_extensions)
assert 0 < len(input_files_and_labels)
# load the model
model_files = backend.list_model_files()
config_files = [file
for file in model_files
config_files = [file for file in model_files
if file.name.endswith('config.json')]
if len(config_files) == 0:
raise RuntimeError(f"no config files found for tagger: {backend.model_instance}")
......@@ -200,38 +203,6 @@ def apply_filters(labels, post_filters):
return labels
def supported_input_files(files):
if isinstance(files, dict):
files_ = {}
for d, fs in files.items():
fs = supported_input_files(fs)
if fs:
files_[d] = fs
files = files_
else:
files = [f for f in files if supported_input_file(f)]
return files
def supported_input_file(file):
if isinstance(file, str):
file = os.path.basename(file)
else:
file = file.name
if file.startswith('trx') and file.endswith('.mat'):
return True
elif file.endswith('.spine') or file.endswith('.outline'):
return True
elif file.endswith('.json') or file.endswith('.label') or \
file.endswith('.labels') or file.endswith('.nyxlabel'):
return True
elif file.endswith('.csv'):
return True
elif file.endswith('.hdf5'):
return True
else:
return False
from taggingbackends.main import main
if __name__ == "__main__":
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment