Skip to content
Snippets Groups Projects
Commit 5a30d4da authored by François  LAURENT's avatar François LAURENT
Browse files

Text labels

parent 5d46423e
No related branches found
No related tags found
3 merge requests!11Set of commits to be tagged v0.18,!10Set of commits to be tagged v0.17,!9Text labels
......@@ -14,7 +14,7 @@ maggotuba-core = {git = "https://gitlab.pasteur.fr/nyx/MaggotUBA-core", tag = "v
torch = "^1.11.0"
numpy = "^1.19.3"
protobuf = "3.9.2"
taggingbackends = {git = "https://gitlab.pasteur.fr/nyx/TaggingBackends", tag = "v0.15.3"}
taggingbackends = {git = "https://gitlab.pasteur.fr/nyx/TaggingBackends", tag = "v0.16"}
[build-system]
requires = ["poetry-core>=1.0.0"]
......
......@@ -8,6 +8,7 @@ from taggingbackends.explorer import BackendExplorer, check_permissions
import logging
import json
import re
import os.path
"""
This model borrows the pre-trained MaggotUBA encoder, substitute a dense layer
......@@ -139,7 +140,9 @@ class MaggotTrainer:
y = self.model(x.to(self.device))
return y.cpu().numpy()
def prepare_dataset(self, dataset):
def prepare_dataset(self, dataset, training=False):
if training:
self.model.clf.config['training_dataset'] = str(dataset.path)
try:
dataset.batch_size
except AttributeError:
......@@ -203,7 +206,7 @@ class MaggotTrainer:
optimizer.step()
def train(self, dataset):
self.prepare_dataset(dataset)
self.prepare_dataset(dataset, training=True)
criterion = self.init_model_for_training(dataset)
# pre-train the classifier with static encoder weights
if self._pretrain_classifier():
......@@ -248,10 +251,26 @@ class MaggotTrainer:
assert pred.size == exp.size
predicted.append(pred)
expected.append(exp)
return np.concatenate(predicted), np.concatenate(expected)
predicted = np.concatenate(predicted)
expected = np.concatenate(expected)
predicted = [self.labels[label] for label in predicted]
expected = [dataset.labels[label] for label in expected]
return predicted, expected
def save(self):
def save(self, copy_dataset=True):
self.model.save()
if copy_dataset:
# copy the compiled training dataset into the model directory
try:
dataset = self.model.clf.config['training_dataset']
except KeyError:
pass
else:
copy = self.model.clf.path / os.path.basename(dataset)
with open(copy, 'wb') as g:
with open(dataset, 'rb') as f:
g.write(f.read())
check_permissions(copy)
@property
def root_dir(self):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment