Skip to content
Snippets Groups Projects
Commit 6d0108fe authored by François  LAURENT's avatar François LAURENT Committed by François LAURENT
Browse files

Text labels

parent 1900dc03
No related branches found
No related tags found
2 merge requests!11Set of commits to be tagged v0.18,!10Set of commits to be tagged v0.17
......@@ -14,7 +14,7 @@ maggotuba-core = {git = "https://gitlab.pasteur.fr/nyx/MaggotUBA-core", tag = "v
torch = "^1.11.0"
numpy = "^1.19.3"
protobuf = "3.9.2"
taggingbackends = {git = "https://gitlab.pasteur.fr/nyx/TaggingBackends", tag = "v0.15.3"}
taggingbackends = {git = "https://gitlab.pasteur.fr/nyx/TaggingBackends", tag = "v0.16"}
[build-system]
requires = ["poetry-core>=1.0.0"]
......
......@@ -8,6 +8,7 @@ from taggingbackends.explorer import BackendExplorer, check_permissions
import logging
import json
import re
import os.path
"""
This model borrows the pre-trained MaggotUBA encoder, substitute a dense layer
......@@ -139,7 +140,9 @@ class MaggotTrainer:
y = self.model(x.to(self.device))
return y.cpu().numpy()
def prepare_dataset(self, dataset):
def prepare_dataset(self, dataset, training=False):
if training:
self.model.clf.config['training_dataset'] = str(dataset.path)
try:
dataset.batch_size
except AttributeError:
......@@ -203,7 +206,7 @@ class MaggotTrainer:
optimizer.step()
def train(self, dataset):
self.prepare_dataset(dataset)
self.prepare_dataset(dataset, training=True)
criterion = self.init_model_for_training(dataset)
# pre-train the classifier with static encoder weights
if self._pretrain_classifier():
......@@ -248,10 +251,26 @@ class MaggotTrainer:
assert pred.size == exp.size
predicted.append(pred)
expected.append(exp)
return np.concatenate(predicted), np.concatenate(expected)
predicted = np.concatenate(predicted)
expected = np.concatenate(expected)
predicted = [self.labels[label] for label in predicted]
expected = [dataset.labels[label] for label in expected]
return predicted, expected
def save(self):
def save(self, copy_dataset=True):
self.model.save()
if copy_dataset:
# copy the compiled training dataset into the model directory
try:
dataset = self.model.clf.config['training_dataset']
except KeyError:
pass
else:
copy = self.model.clf.path / os.path.basename(dataset)
with open(copy, 'wb') as g:
with open(dataset, 'rb') as f:
g.write(f.read())
check_permissions(copy)
@property
def root_dir(self):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment