diff --git a/pyproject.toml b/pyproject.toml
index 4132fd5c6691d30e016a25c72df17c0c1771472b..01e9c3d3c407fff7624add08df7e33820c8bc4fd 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "MaggotUBA-adapter"
-version = "0.16.4"
+version = "0.17"
description = "Interface between MaggotUBA and the Nyx tagging UI"
authors = ["François Laurent"]
license = "MIT"
@@ -14,7 +14,7 @@ maggotuba-core = {git = "https://gitlab.pasteur.fr/nyx/MaggotUBA-core", tag = "v
torch = "^1.11.0"
numpy = "^1.19.3"
protobuf = "3.9.2"
-taggingbackends = {git = "https://gitlab.pasteur.fr/nyx/TaggingBackends", tag = "v0.15.3"}
+taggingbackends = {git = "https://gitlab.pasteur.fr/nyx/TaggingBackends", tag = "v0.16"}
[build-system]
requires = ["poetry-core>=1.0.0"]
diff --git a/src/maggotuba/models/trainers.py b/src/maggotuba/models/trainers.py
index 9143828147944b6cc6d26daa1d0e74407a094feb..2603a8490dc3ec052ae15962592a7b85ff5bcab8 100644
--- a/src/maggotuba/models/trainers.py
+++ b/src/maggotuba/models/trainers.py
@@ -8,6 +8,7 @@ from taggingbackends.explorer import BackendExplorer, check_permissions
import logging
import json
import re
+import os.path
"""
This model borrows the pre-trained MaggotUBA encoder, substitute a dense layer
@@ -139,7 +140,9 @@ class MaggotTrainer:
y = self.model(x.to(self.device))
return y.cpu().numpy()
- def prepare_dataset(self, dataset):
+ def prepare_dataset(self, dataset, training=False):
+ if training:
+ self.model.clf.config['training_dataset'] = str(dataset.path)
try:
dataset.batch_size
except AttributeError:
@@ -203,7 +206,7 @@ class MaggotTrainer:
optimizer.step()
def train(self, dataset):
- self.prepare_dataset(dataset)
+ self.prepare_dataset(dataset, training=True)
criterion = self.init_model_for_training(dataset)
# pre-train the classifier with static encoder weights
if self._pretrain_classifier():
@@ -248,10 +251,26 @@ class MaggotTrainer:
assert pred.size == exp.size
predicted.append(pred)
expected.append(exp)
- return np.concatenate(predicted), np.concatenate(expected)
+ predicted = np.concatenate(predicted)
+ expected = np.concatenate(expected)
+ predicted = [self.labels[label] for label in predicted]
+ expected = [dataset.labels[label] for label in expected]
+ return predicted, expected
-    def save(self):
+    def save(self, copy_dataset=True):
+        """
+        Save the model and, optionally, keep a copy of its training dataset.
+
+        If `copy_dataset` is true and the classifier config has a
+        'training_dataset' entry (recorded by `prepare_dataset` when called
+        with `training=True`), the file at that path is copied into
+        `self.model.clf.path` under its base name, and `check_permissions`
+        is applied to the copy.
+
+        If no 'training_dataset' entry is found, only the model is saved.
+        """
         self.model.save()
+        if not copy_dataset:
+            return
+        try:
+            dataset = self.model.clf.config['training_dataset']
+        except KeyError:
+            # no training dataset was recorded; nothing to copy
+            return
+        import shutil
+        # copy the compiled training dataset into the model directory
+        copy = self.model.clf.path / os.path.basename(dataset)
+        # Opening the destination in 'wb' mode truncates it; if the recorded
+        # dataset already is the destination file, copying it onto itself
+        # would wipe its content before it is read.
+        if not (os.path.exists(copy) and os.path.samefile(dataset, copy)):
+            # open the source first so that a missing dataset file does not
+            # leave an empty copy behind
+            with open(dataset, 'rb') as f, open(copy, 'wb') as g:
+                # stream the content instead of loading the whole file
+                shutil.copyfileobj(f, g)
+        check_permissions(copy)
@property
def root_dir(self):