Commit fcefbbac authored by François LAURENT

Merge branch 'hdf5labels' into 'dev'

Text labels

See merge request !9
parents 5d46423e 5a30d4da
Related merge requests: !11 "Set of commits to be tagged v0.18", !10 "Set of commits to be tagged v0.17", !9 "Text labels"
pyproject.toml

@@ -14,7 +14,7 @@
 maggotuba-core = {git = "https://gitlab.pasteur.fr/nyx/MaggotUBA-core", tag = "v…
 torch = "^1.11.0"
 numpy = "^1.19.3"
 protobuf = "3.9.2"
-taggingbackends = {git = "https://gitlab.pasteur.fr/nyx/TaggingBackends", tag = "v0.15.3"}
+taggingbackends = {git = "https://gitlab.pasteur.fr/nyx/TaggingBackends", tag = "v0.16"}
 [build-system]
 requires = ["poetry-core>=1.0.0"]
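The only change in this file is the TaggingBackends pin moving from v0.15.3 to v0.16, presumably to match the text-label support used below. A minimal sketch, not part of the commit, to confirm which version ends up installed after `poetry install`; it assumes the distribution is registered under the name "taggingbackends" (the actual name comes from the TaggingBackends pyproject and may differ in case):

# Hypothetical check: print the installed TaggingBackends version.
# Assumes the distribution name "taggingbackends".
from importlib.metadata import version, PackageNotFoundError

try:
    print(version("taggingbackends"))  # expected: 0.16 after this bump
except PackageNotFoundError:
    print("TaggingBackends is not installed in this environment")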
@@ -8,6 +8,7 @@ from taggingbackends.explorer import BackendExplorer, check_permissions
 import logging
 import json
 import re
+import os.path

 """
 This model borrows the pre-trained MaggotUBA encoder, substitute a dense layer
@@ -139,7 +140,9 @@ class MaggotTrainer:
         y = self.model(x.to(self.device))
         return y.cpu().numpy()

-    def prepare_dataset(self, dataset):
+    def prepare_dataset(self, dataset, training=False):
+        if training:
+            self.model.clf.config['training_dataset'] = str(dataset.path)
         try:
             dataset.batch_size
         except AttributeError:
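The hunk above records, at training time, where the compiled dataset lives: prepare_dataset gains a training flag, and when it is set the dataset path is stored in the classifier's config dict. A minimal sketch of the pattern, assuming clf.config is a plain dict and dataset.path a filesystem path; TrainerSketch and its attributes are illustrative stand-ins, not the actual MaggotUBA API:

# Illustrative sketch only; TrainerSketch stands in for MaggotTrainer.
class TrainerSketch:
    def __init__(self, model):
        # model.clf.config is assumed to be a JSON-serializable dict
        self.model = model

    def prepare_dataset(self, dataset, training=False):
        if training:
            # record provenance: the path of the compiled training dataset
            self.model.clf.config['training_dataset'] = str(dataset.path)
        # ... batch-size handling would follow, as in the real code ...

    def train(self, dataset):
        # training=True so the dataset path is recorded before fitting
        self.prepare_dataset(dataset, training=True)
        # ... optimization loop ...

Storing the path in the classifier's config, rather than on the trainer object, presumably keeps it alongside the other persisted settings, which is what lets save() retrieve it later (see below).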
@@ -203,7 +206,7 @@ class MaggotTrainer:
             optimizer.step()

     def train(self, dataset):
-        self.prepare_dataset(dataset)
+        self.prepare_dataset(dataset, training=True)
         criterion = self.init_model_for_training(dataset)
         # pre-train the classifier with static encoder weights
         if self._pretrain_classifier():
@@ -248,10 +251,26 @@ class MaggotTrainer:
             assert pred.size == exp.size
             predicted.append(pred)
             expected.append(exp)
-        return np.concatenate(predicted), np.concatenate(expected)
+        predicted = np.concatenate(predicted)
+        expected = np.concatenate(expected)
+        predicted = [self.labels[label] for label in predicted]
+        expected = [dataset.labels[label] for label in expected]
+        return predicted, expected

-    def save(self):
+    def save(self, copy_dataset=True):
         self.model.save()
+        if copy_dataset:
+            # copy the compiled training dataset into the model directory
+            try:
+                dataset = self.model.clf.config['training_dataset']
+            except KeyError:
+                pass
+            else:
+                copy = self.model.clf.path / os.path.basename(dataset)
+                with open(copy, 'wb') as g:
+                    with open(dataset, 'rb') as f:
+                        g.write(f.read())
+                check_permissions(copy)

     @property
     def root_dir(self):
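Two behaviours land in this last hunk: evaluation now returns text labels instead of integer class indices (model-side indices are mapped through self.labels, dataset-side indices through dataset.labels), and save() copies the recorded training dataset into the model directory. A self-contained sketch of both, using shutil.copyfile where the real code copies bytes by hand and then calls check_permissions; all names here are illustrative stand-ins:

# Illustrative sketch; TrainerSketch is not the actual MaggotTrainer.
import os.path
import shutil
from pathlib import Path

import numpy as np

class TrainerSketch:
    def __init__(self, labels, model_dir, config=None):
        self.labels = labels              # model-side vocabulary, e.g. ["run", "cast"]
        self.model_dir = Path(model_dir)  # where the model is saved
        self.config = {} if config is None else config

    def decode(self, predicted, expected, dataset_labels):
        # concatenate per-batch index arrays, then map indices to text labels;
        # the model and the dataset may order or name their labels differently
        predicted = np.concatenate(predicted)
        expected = np.concatenate(expected)
        predicted = [self.labels[i] for i in predicted]
        expected = [dataset_labels[i] for i in expected]
        return predicted, expected

    def save(self, copy_dataset=True):
        # ... model weights would be saved here ...
        if copy_dataset:
            try:
                dataset = self.config['training_dataset']
            except KeyError:
                pass  # nothing recorded, e.g. prepare_dataset ran without training=True
            else:
                # keep a copy of the compiled training dataset next to the model
                shutil.copyfile(dataset, self.model_dir / os.path.basename(dataset))

Returning text labels pushes the index-to-name translation into the trainer, so downstream consumers of the predictions no longer need to know the integer label encoding.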