
Set of commits to be tagged v0.17

Merged: François LAURENT requested to merge dev into main
2 files changed, +25 -6
@@ -8,6 +8,7 @@ from taggingbackends.explorer import BackendExplorer, check_permissions
 import logging
 import json
 import re
+import os.path

 """
 This model borrows the pre-trained MaggotUBA encoder, substitute a dense layer
@@ -139,7 +140,9 @@ class MaggotTrainer:
         y = self.model(x.to(self.device))
         return y.cpu().numpy()

-    def prepare_dataset(self, dataset):
+    def prepare_dataset(self, dataset, training=False):
+        if training:
+            self.model.clf.config['training_dataset'] = str(dataset.path)
         try:
             dataset.batch_size
         except AttributeError:
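A usage-level sketch of the new keyword (trainer and dataset stand for a MaggotTrainer instance and a compiled dataset exposing a path attribute; only the training_dataset key and str(dataset.path) come from the hunk above):

    trainer.prepare_dataset(dataset)                 # default training=False leaves the config untouched
    trainer.prepare_dataset(dataset, training=True)  # records the dataset location
    trainer.model.clf.config['training_dataset']     # -> str(dataset.path)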
@@ -203,7 +206,7 @@ class MaggotTrainer:
             optimizer.step()

     def train(self, dataset):
-        self.prepare_dataset(dataset)
+        self.prepare_dataset(dataset, training=True)
         criterion = self.init_model_for_training(dataset)
         # pre-train the classifier with static encoder weights
         if self._pretrain_classifier():
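With train now forwarding training=True, the dataset path is stamped into the classifier config before fitting, so a later save can locate the file. A hedged end-to-end sketch (trainer and dataset are assumed names; the config key comes from the diff):

    trainer.train(dataset)   # prepare_dataset(..., training=True) records config['training_dataset']
    trainer.save()           # can then copy that file next to the saved model (see the last hunk)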
@@ -248,10 +251,26 @@ class MaggotTrainer:
             assert pred.size == exp.size
             predicted.append(pred)
             expected.append(exp)
-        return np.concatenate(predicted), np.concatenate(expected)
+        predicted = np.concatenate(predicted)
+        expected = np.concatenate(expected)
+        predicted = [self.labels[label] for label in predicted]
+        expected = [dataset.labels[label] for label in expected]
+        return predicted, expected

-    def save(self):
+    def save(self, copy_dataset=True):
         self.model.save()
+        if copy_dataset:
+            # copy the compiled training dataset into the model directory
+            try:
+                dataset = self.model.clf.config['training_dataset']
+            except KeyError:
+                pass
+            else:
+                copy = self.model.clf.path / os.path.basename(dataset)
+                with open(copy, 'wb') as g:
+                    with open(dataset, 'rb') as f:
+                        g.write(f.read())
+                check_permissions(copy)

     @property
     def root_dir(self):
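The copy step of save in isolation, as a minimal standalone sketch: model_dir and config stand in for self.model.clf.path and self.model.clf.config, and the check_permissions call is left out.

    import os.path
    from pathlib import Path

    def copy_training_dataset(model_dir: Path, config: dict) -> None:
        # nothing to copy unless prepare_dataset(training=True) recorded a path
        try:
            dataset = config['training_dataset']
        except KeyError:
            return
        # place the compiled dataset next to the saved model, keeping its base name
        copy = model_dir / os.path.basename(dataset)
        with open(copy, 'wb') as g, open(dataset, 'rb') as f:
            g.write(f.read())

Passing copy_dataset=False to save skips this step and keeps the previous behaviour of writing the model files only.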