Commit 0bffd4ec authored by François LAURENT

interpolation support finalized + swap_head_tail config option

parent 3d3e5b90
 {
-    "project_dir": "models",
+    "project_dir": "subset2_interp10_len20",
     "seed": 100,
-    "exp_name": "20220517",
+    "exp_name": "latent-100",
-    "data_dir": "structured-temporal-convolution/20220425/larva_dataset_2022_04_25_20_20_100000.hdf5",
+    "data_dir": "subset2_interp10_len20/larva_dataset_2022_10_05_20_20_100000.hdf5",
-    "raw_data_dir": "structured-temporal-convolution/larva_dataset/t5_t15_point_dynamics",
+    "raw_data_dir": "larva_dataset/t5-t15-subset2",
-    "log_dir": "models",
+    "log_dir": "subset2_interp10_len20/training_log/latent-100",
-    "exp_folder": "models",
+    "exp_folder": "subset2_interp10_len20/training_log/latent-100",
-    "config": "models/autoencoder_config.json",
+    "config": "subset2_interp10_len20/config-100.json",
     "num_workers": 4,
     "n_features": 10,
     "len_traj": 20,
     "len_pred": 20,
-    "dim_latent": 10,
+    "dim_latent": 100,
     "activation": "relu",
     "enc_filters": [
         128,
@@ -101,5 +101,8 @@
         "past",
         "present",
         "future"
-    ]
+    ],
+    "spine_interpolation": "linear",
+    "frame_interval": 0.1,
+    "swap_head_tail": false
 }
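The three new keys are the ones consumed by the updated MaggotTrainer further down: spine_interpolation and frame_interval are forwarded to the spine interpolation routine, while swap_head_tail toggles the head/tail reversal applied during preprocessing. A minimal sketch of how the options are read, following the logic of the diff below (the config file name is illustrative):

    import json

    # Illustrative path; in practice this is the per-experiment config above.
    with open("config-100.json") as f:
        config = json.load(f)

    # Interpolation settings, forwarded as keyword arguments (see window() below).
    interpolation_args = {k: config[k]
                          for k in ("spine_interpolation", "frame_interval")
                          if k in config}
    # -> {'spine_interpolation': 'linear', 'frame_interval': 0.1}

    # Head/tail swap; the trainer falls back to True when the key is absent,
    # preserving the behavior of configs generated before this commit.
    swap_head_tail = config.get("swap_head_tail", True)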
@@ -14,7 +14,7 @@ def make_dataset(backend, labels_expected=False, trxmat_only=False, **kwargs):
     else:
         print("generating a larva_dataset file...")
         # generate a larva_dataset_*.hdf5 file in data/interim/{instance}/
-        if trxmat_only:
+        if False:#trxmat_only:
             out = backend.compile_trxmat_database(backend.raw_data_dir(), **kwargs)
         else:
             out = backend.generate_dataset(backend.raw_data_dir(), **kwargs)
@@ -99,7 +99,7 @@ def predict_individual_data_files(backend, model, input_files, labels):
         model.average_body_length = ref_length
         print(f"average body length: {ref_length}")
         for larva, spines in data.items():
-            predictions = model.predict(spines)
+            predictions = model.predict((t[larva], spines))
             if predictions is None:
                 print(f"failure at windowing track: {larva}")
             else:
@@ -110,7 +110,7 @@ def predict_individual_data_files(backend, model, input_files, labels):
         print(f"average body length: {ref_length}")
         for larva in np.unique(larvae):
             mask = larvae == larva
-            predictions = model.predict(data[mask])
+            predictions = model.predict((t[mask], data[mask]))
             if predictions is None:
                 print(f"failure at windowing track: {larva}")
             else:
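In the first hunk, data maps each larva to its spine array, so the matching timestamps are looked up per track (t[larva]); in the second, the records of all larvae are flat arrays and the same boolean mask selects timestamps and spines together. A toy sketch of the two layouts, with illustrative values (t, data and larvae are built earlier in this function):

    import numpy as np

    # Dict-of-tracks layout (first hunk): one spine array per larva.
    t = {1: np.linspace(0.0, 5.0, 50), 2: np.linspace(0.0, 8.0, 80)}
    data = {1: np.random.rand(50, 10), 2: np.random.rand(80, 10)}
    # for larva, spines in data.items():
    #     predictions = model.predict((t[larva], spines))

    # Flat-array layout (second hunk): records of all larvae concatenated.
    larvae = np.array([1, 1, 2, 2, 2])
    t = np.array([0.0, 0.1, 0.0, 0.1, 0.2])
    data = np.random.rand(5, 10)
    mask = larvae == 2
    # predictions = model.predict((t[mask], data[mask]))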
@@ -4,6 +4,7 @@ import torch.nn as nn
 from behavior_model.models.neural_nets import device
 #import behavior_model.data.utils as data_utils
 from maggotuba.models.modules import SupervisedMaggot, MultiscaleSupervisedMaggot
+from taggingbackends.features.skeleton import interpolate
 
 """
 This model borrows the pre-trained MaggotUBA encoder, substitute a dense layer
@@ -40,10 +41,29 @@ class MaggotTrainer:
     def labels(self, labels):
         self.model.clf.behavior_labels = labels
 
+    @property
+    def swap_head_tail(self):
+        return self.config.get('swap_head_tail', True)
+
+    @swap_head_tail.setter
+    def swap_head_tail(self, b):
+        self.config['swap_head_tail'] = b
+
     ### TODO: move parts of the below code in a features module
-    def window(self, data):
+    # all the code in this section is called by `predict` only
+    def window(self, t, data):
+        interpolation_args = {k: self.config[k]
+                for k in ('spine_interpolation', 'frame_interval')
+                if k in self.config}
         winlen = self.config["len_traj"]
         N = data.shape[0]+1
-        for m in range(0, N-winlen):
-            n = m + winlen
-            yield data[m:n]
+        if interpolation_args:
+            for m in range(0, N-1):
+                win = interpolate(t, data, m, winlen, **interpolation_args)
+                if win is not None:
+                    assert win.shape[0] == winlen
+                    yield win
+        else:
+            for m in range(0, N-winlen):
+                n = m + winlen
+                yield data[m:n]
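interpolate comes from taggingbackends.features.skeleton and, as used here, is expected to return a window of winlen spine vectors resampled at the requested frame_interval starting from index m, or None when the track does not cover enough time around that index. The sketch below only illustrates linear resampling onto a regular time grid under those assumptions; it is not the implementation of the imported helper:

    import numpy as np

    def interpolate_sketch(t, data, m, winlen, spine_interpolation="linear",
                           frame_interval=0.1):
        """Illustrative stand-in mirroring the call made in window() above."""
        # Regular time grid anchored at the observed timestamp t[m].
        grid = t[m] + frame_interval * np.arange(winlen)
        if grid[-1] > t[-1]:
            return None  # not enough trailing data to fill the window
        if spine_interpolation == "linear":
            # Interpolate each spine coordinate (column) independently.
            win = np.stack([np.interp(grid, t, data[:, j])
                            for j in range(data.shape[1])], axis=1)
            return win  # shape (winlen, data.shape[1])
        raise NotImplementedError(spine_interpolation)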
@@ -78,12 +98,15 @@ class MaggotTrainer:
         w = np.einsum("ij,jkl", rot, np.reshape(w.T, (2, 5, -1), order='F'))
         return w
 
-    def preprocess(self, data):
+    def preprocess(self, t, data):
         ws = []
-        for w in self.window(data):
+        for w in self.window(t, data):
             ws.append(self.normalize(w))
         if ws:
-            return self.pad(np.stack(ws))[:,:,::-1,:] # swap head and tail
+            ret = self.pad(np.stack(ws))
+            if self.swap_head_tail:
+                ret = ret[:,:,::-1,:]
+            return ret
 
     ###
 
     def forward(self, x, train=False):
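Windows stacked by preprocess() keep the five spine points along one axis (the (2, 5, -1) reshape in normalize suggests a coordinates x points x time layout), so reversing that axis swaps head and tail; the reversal is now gated by the swap_head_tail option instead of being applied unconditionally. A toy illustration of the reversal, assuming a (windows, 2, 5, time) layout (the exact axis order is an assumption here):

    import numpy as np

    # Toy batch: (windows, xy coordinates, spine points head -> tail, time steps)
    ret = np.arange(2 * 2 * 5 * 3).reshape(2, 2, 5, 3)

    swapped = ret[:, :, ::-1, :]   # reverse the point axis: head <-> tail

    # The first point of every window/coordinate/time step now holds the former last one.
    assert np.array_equal(swapped[:, :, 0, :], ret[:, :, 4, :])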
@@ -167,7 +190,8 @@ class MaggotTrainer:
         model.eval()
         model.to(self.device)
         if subset is None:
-            data = self.preprocess(data)
+            # data is a (times, spines) couple
+            data = self.preprocess(*data)
             if data is None:
                 return
             output = self.forward(data)
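predict() now treats its data argument as the (times, spines) couple assembled by the callers shown earlier and unpacks it into preprocess(); passing a bare spine array, as before this commit, no longer matches the new signature. A minimal usage sketch with illustrative values (trainer stands for a loaded MaggotTrainer):

    import numpy as np

    times = np.linspace(0.0, 2.0, 21)        # frame timestamps in seconds
    spines = np.random.rand(21, 10)          # 5 (x, y) spine points per frame

    data = (times, spines)
    # predictions = trainer.predict(data)
    # Internally, self.preprocess(*data) is equivalent to
    # self.preprocess(times, spines).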