Commit 3b054745 authored by François Laurent

Merge branch 'dev' into main

parents d98fc5a0 d48b1fe4
pyproject.toml

[tool.poetry]
name = "MaggotUBA-adapter"
version = "0.9"
version = "0.10"
description = "Interface between MaggotUBA and the Nyx tagging UI"
authors = ["François Laurent"]
license = "MIT"
...
 import glob
 import pathlib

-def make_dataset(backend, labels_expected=False, trxmat_only=False, **kwargs):
+def make_dataset(backend, labels_expected=False, trxmat_only=False, balancing_strategy='maggotuba', **kwargs):
     if labels_expected:
         larva_dataset_file = glob.glob(str(backend.raw_data_dir() / "larva_dataset_*.hdf5"))
         if larva_dataset_file:
@@ -17,7 +17,9 @@ def make_dataset(backend, labels_expected=False, trxmat_only=False, **kwargs):
     if False:#trxmat_only:
         out = backend.compile_trxmat_database(backend.raw_data_dir(), **kwargs)
     else:
-        out = backend.generate_dataset(backend.raw_data_dir(), **kwargs)
+        out = backend.generate_dataset(backend.raw_data_dir(),
+            balance=isinstance(balancing_strategy, str) and balancing_strategy.lower() == 'maggotuba',
+            **kwargs)
     print(f"larva_dataset file generated: {out}")
...
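Note on the new balancing_strategy parameter: only the string 'maggotuba' (case-insensitive) enables dataset-level balancing; any other value, or a non-string, disables it. A minimal usage sketch, assuming a backend object as provided by the taggingbackends package (argument values are illustrative, not taken from the diff):

    # dataset-level balancing on (balance=True is passed to generate_dataset)
    make_dataset(backend, labels_expected=True, balancing_strategy='maggotuba')
    # dataset-level balancing off (balance=False)
    make_dataset(backend, labels_expected=True, balancing_strategy='none')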
@@ -6,7 +6,7 @@ import json
 import glob

 def train_model(backend, layers=1, pretrained_model_instance="default",
-                subsets=(1, 0, 0), rng_seed=None, **kwargs):
+                subsets=(1, 0, 0), rng_seed=None, balancing_strategy='maggotuba', **kwargs):
     # make_dataset generated or moved the larva_dataset file into data/interim/{instance}/
     #larva_dataset_file = backend.list_interim_files("larva_dataset_*.hdf5") # recursive
     larva_dataset_file = glob.glob(str(backend.interim_data_dir() / "larva_dataset_*.hdf5")) # not recursive (faster)
@@ -14,6 +14,7 @@ def train_model(backend, layers=1, pretrained_model_instance="default",
     # subsets=(1, 0, 0) => all data are training data; no validation or test subsets
     dataset = LarvaDataset(larva_dataset_file[0], new_generator(rng_seed),
                            subsets=subsets, **kwargs)
+    dataset.weight_classes = isinstance(balancing_strategy, str) and (balancing_strategy.lower() == 'auto')
     labels = dataset.labels
     assert 0 < len(labels)
     labels = labels if isinstance(labels[0], str) else [s.decode() for s in labels]
...
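The same parameter drives a second, independent mechanism here: 'maggotuba' balances the dataset at generation time (above), while 'auto' leaves the dataset as-is and enables per-class weighting of the loss instead. A standalone sketch mirroring the two boolean tests in the diffs (the function name is illustrative):

    def balancing_flags(balancing_strategy):
        # mirrors the checks in make_dataset and train_model
        s = balancing_strategy.lower() if isinstance(balancing_strategy, str) else ''
        balance = (s == 'maggotuba')    # dataset-level balancing
        weight_classes = (s == 'auto')  # loss-level class weights
        return balance, weight_classes

    assert balancing_flags('maggotuba') == (True, False)
    assert balancing_flags('auto') == (False, True)
    assert balancing_flags(None) == (False, False)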
@@ -5,6 +5,7 @@ from behavior_model.models.neural_nets import device
 #import behavior_model.data.utils as data_utils
 from maggotuba.models.modules import SupervisedMaggot, MultiscaleSupervisedMaggot, MaggotBag
 from taggingbackends.features.skeleton import interpolate
+import logging

 """
 This model borrows the pre-trained MaggotUBA encoder, substitutes a dense layer
@@ -94,7 +95,11 @@ class MaggotTrainer:
         w = w - np.tile(wc, (1, 5))
         # rotate
         v = np.mean(w[:,8:10] - w[:,0:2], axis=0)
-        v = v / np.sqrt(np.dot(v, v))
+        vnorm = np.sqrt(np.dot(v, v))
+        if vnorm == 0:
+            logging.warning('null distance between head and tail')
+        else:
+            v = v / vnorm
         c, s = v / self.average_body_length # scale using the rotation matrix
         rot = np.array([[ c, s],
                         [-s, c]]) # clockwise rotation
@@ -154,10 +159,13 @@ class MaggotTrainer:
     def train(self, dataset):
         self.prepare_dataset(dataset)
+        kwargs = {}
+        if dataset.class_weights is not None:
+            kwargs['weight'] = torch.from_numpy(dataset.class_weights.astype(np.float32)).to(self.device)
         model = self.model
         model.train() # this only sets the model in training mode (enables gradients)
         model.to(self.device)
-        criterion = nn.CrossEntropyLoss()
+        criterion = nn.CrossEntropyLoss(**kwargs)
         nsteps = self.config['optim_iter']
         grad_clip = self.config['grad_clip']
         # pre-train the classifier with static encoder weights
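When class weights are available on the dataset, they are forwarded to PyTorch's cross-entropy loss, which rescales each class's contribution to the loss (this is what the 'auto' strategy above feeds; without it, rare behaviours would be drowned out by frequent ones). A self-contained sketch with made-up weights:

    import numpy as np
    import torch
    import torch.nn as nn

    class_weights = np.array([0.2, 0.5, 0.3])  # hypothetical per-class weights
    criterion = nn.CrossEntropyLoss(weight=torch.from_numpy(class_weights.astype(np.float32)))

    logits = torch.randn(4, 3)             # batch of 4 samples, 3 classes
    targets = torch.tensor([0, 2, 1, 1])
    loss = criterion(logits, targets)      # errors on class 1 now weigh most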
@@ -190,7 +198,7 @@ class MaggotTrainer:
         return self

     def draw(self, dataset, subset="train"):
-        data, expected = dataset.getobs(subset)
+        data, expected = dataset.getbatch(subset)
         if isinstance(data, list):
             data = torch.stack(data)
         data = data.to(torch.float32).to(self.device)
...