Skip to content
Snippets Groups Projects
Commit a2a3c5a5 authored by François  LAURENT's avatar François LAURENT
Browse files
parent 59492c0c
No related branches found
No related tags found
No related merge requests found
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
name = "MaggotUBA-adapter" name = "MaggotUBA-adapter"
version = "0.1.0" version = "0.1.0"
description = "Interface between MaggotUBA and the Nyx tagging UI" description = "Interface between MaggotUBA and the Nyx tagging UI"
authors = ["François Laurent <francois.laurent@posteo.net>"] authors = ["François Laurent"]
license = "MIT" license = "MIT"
packages = [ packages = [
{ include = "maggotuba", from = "src" }, { include = "maggotuba", from = "src" },
...@@ -10,7 +10,7 @@ packages = [ ...@@ -10,7 +10,7 @@ packages = [
[tool.poetry.dependencies] [tool.poetry.dependencies]
python = "^3.8,<3.11" python = "^3.8,<3.11"
taggingbackends = {git = "https://gitlab.pasteur.fr/nyx/TaggingBackends", rev = "main"} taggingbackends = {git = "https://gitlab.pasteur.fr/nyx/TaggingBackends", rev = "dev"}
structured-temporal-convolution = {git = "git@gitlab.pasteur.fr:les-larves/structured-temporal-convolution.git", branch="dev-branch"} structured-temporal-convolution = {git = "git@gitlab.pasteur.fr:les-larves/structured-temporal-convolution.git", branch="dev-branch"}
torch = "^1.11.0" torch = "^1.11.0"
numpy = "^1.19.3" numpy = "^1.19.3"
......
...@@ -60,7 +60,7 @@ def predict_model(backend): ...@@ -60,7 +60,7 @@ def predict_model(backend):
config_file = [file for file in model_files if file.name.endswith("config.json")] config_file = [file for file in model_files if file.name.endswith("config.json")]
model = RandomForest(config_file[-1]).load() model = RandomForest(config_file[-1]).load()
# assign labels # assign labels
labels = Labels() labels = Labels(tracking=input_files)
if isinstance(data, dict): if isinstance(data, dict):
ref_length = np.mean(np.concatenate([ ref_length = np.mean(np.concatenate([
model.body_length(spines) for spines in data.values() model.body_length(spines) for spines in data.values()
...@@ -84,12 +84,14 @@ def predict_model(backend): ...@@ -84,12 +84,14 @@ def predict_model(backend):
labels[run, larva] = dict(zip(t[mask], predictions)) labels[run, larva] = dict(zip(t[mask], predictions))
# save the predicted labels to file # save the predicted labels to file
if metadata: if metadata:
labels[run]['metadata'] = metadata labels.metadata = metadata
else: else:
labels[run]['metadata'] = {'filename': file.name} labels.metadata = {'filename': file.name}
labels.metadata['labels'] = ["run", "bend", "stop", "hunch", "back", "roll"] labels.labelspec = {
labels.metadata['label_colors'] = ["#000000", "#ff0000", "#00ff00", "names": ["run", "bend", "stop", "hunch", "back", "roll"],
"#0000ff", "#00ffff", "#ffff00"] "colors": ["#000000", "#ff0000", "#00ff00", "#0000ff",
"#00ffff", "#ffff00"]
}
labels.dump(backend.processed_data_dir() / "predicted.labels") labels.dump(backend.processed_data_dir() / "predicted.labels")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment