Commit 018c318c authored by François Laurent

Merge branch 'dev' into 'main'

Set of commits to be tagged v0.17

See merge request !10
parents dc04b9af 0a79cb60
 [tool.poetry]
 name = "MaggotUBA-adapter"
-version = "0.16.4"
+version = "0.17"
 description = "Interface between MaggotUBA and the Nyx tagging UI"
 authors = ["François Laurent"]
 license = "MIT"
@@ -14,7 +14,7 @@ maggotuba-core = {git = "https://gitlab.pasteur.fr/nyx/MaggotUBA-core", tag = "v
 torch = "^1.11.0"
 numpy = "^1.19.3"
 protobuf = "3.9.2"
-taggingbackends = {git = "https://gitlab.pasteur.fr/nyx/TaggingBackends", tag = "v0.15.3"}
+taggingbackends = {git = "https://gitlab.pasteur.fr/nyx/TaggingBackends", tag = "v0.16"}

 [build-system]
 requires = ["poetry-core>=1.0.0"]
@@ -8,6 +8,7 @@ from taggingbackends.explorer import BackendExplorer, check_permissions
 import logging
 import json
 import re
+import os.path

 """
 This model borrows the pre-trained MaggotUBA encoder, substitutes a dense layer
@@ -139,7 +140,9 @@ class MaggotTrainer:
         y = self.model(x.to(self.device))
         return y.cpu().numpy()

-    def prepare_dataset(self, dataset):
+    def prepare_dataset(self, dataset, training=False):
+        if training:
+            self.model.clf.config['training_dataset'] = str(dataset.path)
         try:
             dataset.batch_size
         except AttributeError:
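A self-contained sketch of the pattern this hunk introduces (the TinyTrainer class and the file path below are illustrative stand-ins, not MaggotUBA-adapter API): when preparing data for training, the trainer records where the compiled dataset lives in the classifier config, as a plain string so the recorded value stays easy to serialize, and a later save step can look it up.

import json
from pathlib import Path

class TinyTrainer:
    """Minimal stand-in: remember where the compiled training dataset
    lives so a later save step can copy it next to the model files."""
    def __init__(self):
        self.config = {}

    def prepare_dataset(self, dataset_path, training=False):
        # Record the dataset location only for training runs; str() keeps
        # the value plain (and e.g. JSON-serializable) rather than a Path.
        if training:
            self.config['training_dataset'] = str(dataset_path)

trainer = TinyTrainer()
trainer.prepare_dataset(Path("compiled_dataset.hdf5"), training=True)
print(json.dumps(trainer.config))  # {"training_dataset": "compiled_dataset.hdf5"}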
@@ -203,7 +206,7 @@ class MaggotTrainer:
             optimizer.step()

     def train(self, dataset):
-        self.prepare_dataset(dataset)
+        self.prepare_dataset(dataset, training=True)
         criterion = self.init_model_for_training(dataset)
         # pre-train the classifier with static encoder weights
         if self._pretrain_classifier():
@@ -248,10 +251,26 @@ class MaggotTrainer:
             assert pred.size == exp.size
             predicted.append(pred)
             expected.append(exp)
-        return np.concatenate(predicted), np.concatenate(expected)
+        predicted = np.concatenate(predicted)
+        expected = np.concatenate(expected)
+        predicted = [self.labels[label] for label in predicted]
+        expected = [dataset.labels[label] for label in expected]
+        return predicted, expected

-    def save(self):
+    def save(self, copy_dataset=True):
         self.model.save()
+        if copy_dataset:
+            # copy the compiled training dataset into the model directory
+            try:
+                dataset = self.model.clf.config['training_dataset']
+            except KeyError:
+                pass
+            else:
+                copy = self.model.clf.path / os.path.basename(dataset)
+                with open(copy, 'wb') as g:
+                    with open(dataset, 'rb') as f:
+                        g.write(f.read())
+                check_permissions(copy)

     @property
     def root_dir(self):
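The change to the evaluation path above returns label names instead of integer class indices: predictions are mapped through the model's own label list and ground truth through the dataset's, since the two orderings need not agree. A minimal illustration (the label names below are made up for the example):

import numpy as np

# Hypothetical label orderings; in the diff, `self.labels` belongs to the
# model and `dataset.labels` to the compiled dataset.
model_labels = ["run", "cast", "stop"]
dataset_labels = ["run", "cast", "stop"]

predicted = np.array([0, 1, 2, 1])   # integer class indices from the model
expected = np.array([0, 1, 1, 1])    # integer class indices from the dataset

predicted = [model_labels[i] for i in predicted]
expected = [dataset_labels[i] for i in expected]
print(predicted)  # ['run', 'cast', 'stop', 'cast']
print(expected)   # ['run', 'cast', 'cast', 'cast']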
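The new copy_dataset branch in save() duplicates the compiled training dataset into the model directory by reading and rewriting the file by hand. For reference, a shutil-based sketch of the same step (the function name and arguments are illustrative; check_permissions is the helper imported from taggingbackends.explorer in this file and is only hinted at here):

import os.path
import shutil
from pathlib import Path

def copy_training_dataset(config: dict, model_dir: Path) -> None:
    """Copy the recorded training dataset next to the saved model, if any."""
    dataset = config.get('training_dataset')
    if dataset is None:
        return  # nothing recorded: training=False or an older model
    destination = model_dir / os.path.basename(dataset)
    # shutil.copyfile has the same effect as the explicit read/write in the diff
    shutil.copyfile(dataset, destination)
    # check_permissions(destination)  # as done in the diff after the copy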