diff --git a/pyproject.toml b/pyproject.toml index 150798187a736fbf4804ea16808db5119c4da730..8ebdc6fb9dfde3863678fa53e1fe4c48f874db3d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -2,7 +2,7 @@ name = "MaggotUBA-adapter" version = "0.1.0" description = "Interface between MaggotUBA and the Nyx tagging UI" -authors = ["François Laurent <francois.laurent@posteo.net>"] +authors = ["François Laurent"] license = "MIT" packages = [ { include = "maggotuba", from = "src" }, @@ -10,7 +10,7 @@ packages = [ [tool.poetry.dependencies] python = "^3.8,<3.11" -taggingbackends = {git = "https://gitlab.pasteur.fr/nyx/TaggingBackends", rev = "main"} +taggingbackends = {git = "https://gitlab.pasteur.fr/nyx/TaggingBackends", rev = "dev"} structured-temporal-convolution = {git = "git@gitlab.pasteur.fr:les-larves/structured-temporal-convolution.git", branch="dev-branch"} torch = "^1.11.0" numpy = "^1.19.3" diff --git a/src/maggotuba/models/predict_model.py b/src/maggotuba/models/predict_model.py index 3c636da41b6cd4e8d37ebebd1eba5ac0ed056685..a6f0742ebb6d215c9a9ba99096171d10e6d0937e 100644 --- a/src/maggotuba/models/predict_model.py +++ b/src/maggotuba/models/predict_model.py @@ -60,7 +60,7 @@ def predict_model(backend): config_file = [file for file in model_files if file.name.endswith("config.json")] model = RandomForest(config_file[-1]).load() # assign labels - labels = Labels() + labels = Labels(tracking=input_files) if isinstance(data, dict): ref_length = np.mean(np.concatenate([ model.body_length(spines) for spines in data.values() @@ -84,12 +84,14 @@ def predict_model(backend): labels[run, larva] = dict(zip(t[mask], predictions)) # save the predicted labels to file if metadata: - labels[run]['metadata'] = metadata + labels.metadata = metadata else: - labels[run]['metadata'] = {'filename': file.name} - labels.metadata['labels'] = ["run", "bend", "stop", "hunch", "back", "roll"] - labels.metadata['label_colors'] = ["#000000", "#ff0000", "#00ff00", - "#0000ff", "#00ffff", "#ffff00"] + labels.metadata = {'filename': file.name} + labels.labelspec = { + "names": ["run", "bend", "stop", "hunch", "back", "roll"], + "colors": ["#000000", "#ff0000", "#00ff00", "#0000ff", + "#00ffff", "#ffff00"] + } labels.dump(backend.processed_data_dir() / "predicted.labels")