Commit a9746f0f authored by François LAURENT

multi-layer classifier

parent c55ad3c7
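Summary: this commit replaces the single linear classification head with a configurable multi-layer perceptron. A new DeepLinear module stacks n_layers - 1 hidden Linear+ReLU blocks of constant width before the output layer; the n_layers setting is threaded through SupervisedMaggot, DenseLayer and train_model, and is persisted in clf_config.json as clf_depth = n_layers - 1. Prediction is also refactored so that a larva_dataset_*.hdf5 file can be labelled directly, with subset selection, in addition to individual track files.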
@@ -7,21 +7,43 @@ import torch.nn as nn
 from behavior_model.models.neural_nets import Encoder, device
 import behavior_model.data.utils as data_utils
 
+class DeepLinear(nn.Module):
+    def __init__(self, n_input, n_output, n_layers=1):
+        super().__init__()
+        if n_layers is None: n_layers = 1
+        self.layers = []
+        layers = []
+        for _ in range(n_layers - 1):
+            layer = nn.Linear(n_input, n_input)
+            self.layers.append(layer)
+            layers.append(layer)
+            layers.append(nn.ReLU())
+        layer = nn.Linear(n_input, n_output)
+        self.layers.append(layer)
+        layers.append(layer)
+        self.classifier = nn.Sequential(*layers)
+
+    def _init_layers(self):
+        # initialize the Linear modules only (ReLU has no parameters)
+        for layer in self.layers:
+            nn.init.xavier_uniform_(layer.weight)
+            nn.init.zeros_(layer.bias)
+
+    def forward(self, x):
+        return self.classifier.forward(x)
+
 class SupervisedMaggot(nn.Module):
     def __init__(self, n_latent_features, n_behaviors, enc_config, enc_path,
-            clf_path=None):
+            clf_path=None, n_layers=1):
         super().__init__()
         # Pretrained or trained MaggotUBA encoder
         self.encoder = encoder = Encoder(**enc_config)
         encoder.load_state_dict(torch.load(enc_path))
         # Classifier stacked atop the encoder
-        self.clf = nn.Linear(n_latent_features, n_behaviors)
+        self.clf = DeepLinear(n_latent_features, n_behaviors, n_layers)
         if clf_path:
             self.clf.load_state_dict(torch.load(clf_path))
         else:
-            nn.init.xavier_uniform_(self.clf.weight)
-            nn.init.zeros_(self.clf.bias)
+            self.clf._init_layers()
 
     def forward(self, x):
         #x = torch.flip(x, (2,))
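For orientation, a minimal sketch of how the new head composes (the feature sizes below are hypothetical placeholders, not taken from the commit). With n_layers=1 the module reduces to a single nn.Linear, so the default matches the previous head; each extra layer prepends a hidden Linear+ReLU block of the input width:

# Sketch only: sizes are made-up placeholders.
import torch
clf = DeepLinear(n_input=100, n_output=6, n_layers=3)
# -> Sequential(Linear(100, 100), ReLU(), Linear(100, 100), ReLU(), Linear(100, 6))
x = torch.randn(8, 100)   # batch of 8 latent feature vectors
logits = clf(x)           # shape (8, 6): one score per behavior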
@@ -44,15 +66,15 @@ class DenseLayer:
             config=None,
             autoencoder_config=None,
             n_behaviors=None,
+            n_layers=1,
             average_body_length=None,
             device=device):
         # MaggotUBA autoencoder config
         self._config = autoencoder_config
         self._clf_config = config
         self.prepend_log_dir = True
         self._model = None
-        if n_behaviors is not None:
-            self.n_behaviors = n_behaviors
+        self._n_behaviors = n_behaviors
+        self._n_layers = n_layers
         self.average_body_length = average_body_length
         self.device = device
@@ -66,8 +88,10 @@ class DenseLayer:
         if self._config is None:
             self._config = self.clf_config.get("autoencoder_config", None)
         if isinstance(self._config, (str, pathlib.Path)):
-            with open(self._config, "r") as f:
+            path = self._config
+            with open(path, "r") as f:
                 self._config = json.load(f)
+            self._config["config"] = str(path)
         return self._config
 
     @config.setter
@@ -117,12 +141,23 @@ class DenseLayer:
     @property
     def n_behaviors(self):
-        return self.clf_config.get("n_behaviors", None)
+        return self.clf_config.get("n_behaviors", self._n_behaviors)
 
     @n_behaviors.setter
     def n_behaviors(self, n):
         self.clf_config["n_behaviors"] = n
 
+    @property
+    def n_layers(self):
+        try:
+            return self.clf_config["clf_depth"] + 1
+        except KeyError:
+            return self._n_layers
+
+    @n_layers.setter
+    def n_layers(self, n):
+        self.clf_config["clf_depth"] = 0 if n is None else n - 1
+
     def window(self, data):
         winlen = self.config["len_traj"]
         N = data.shape[0]+1
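Note the persistence convention: the saved configuration stores clf_depth, the number of hidden layers, so a plain linear head round-trips as clf_depth=0. A quick hypothetical check, assuming a DenseLayer instance named layer:

layer.n_layers = 3                          # stored as clf_depth = 2
assert layer.clf_config["clf_depth"] == 2
assert layer.n_layers == 3                  # read back through the property
layer.n_layers = None                       # None collapses to a single layer
assert layer.clf_config["clf_depth"] == 0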
@@ -190,12 +225,15 @@ class DenseLayer:
         if train:
             return self.model(x)
         else:
-            if not isinstance(x, torch.Tensor):
+            if isinstance(x, torch.Tensor):
+                if x.dtype is not torch.float32:
+                    x = x.to(torch.float32)
+            else:
                 x = torch.from_numpy(x.astype(np.float32))
             y = self.model(x.to(self.device))
             return y.cpu().numpy()
 
-    def train(self, dataset):
+    def prepare_dataset(self, dataset):
         try:
             dataset.batch_size
         except AttributeError:
@@ -214,6 +252,9 @@ class DenseLayer:
         if not (0 <= midpoint - before and midpoint + after <= dataset.window_length):
             raise ValueError(f"the dataset can provide segments of up to {dataset.window_length} time points")
         dataset._mask = slice(midpoint - before, midpoint + after)
+
+    def train(self, dataset):
+        self.prepare_dataset(dataset)
         #
         enc_path = "best_validated_encoder.pt"
         if self.prepend_log_dir:
@@ -227,6 +268,7 @@ class DenseLayer:
             n_behaviors=self.n_behaviors,
             enc_config=self.config,
             enc_path=enc_path,
+            n_layers=self.n_layers,
         )
         model.train() # this only sets the model in training mode (enables gradients)
         model.to(self.device)
@@ -256,38 +298,53 @@ class DenseLayer:
         #
         return self
 
-    def draw(self, dataset):
-        data, expected = dataset.getsample()
+    def draw(self, dataset, subset="train"):
+        data, expected = dataset.getobs(subset)
         if isinstance(data, list):
             data = torch.stack(data)
         data = data.to(torch.float32).to(self.device)
         if isinstance(expected, list):
             expected = torch.stack(expected)
-        expected = expected.to(torch.long).to(self.device)
+        if subset.startswith("train"):
+            expected = expected.to(torch.long).to(self.device)
         return data, expected
 
     @torch.no_grad()
-    def predict(self, all_spines):
-        data = self.preprocess(all_spines)
-        if data is None:
-            return
+    def predict(self, data, subset=None):
         self.model = model = SupervisedMaggot(
             n_latent_features=self.config["dim_latent"],
             n_behaviors=self.n_behaviors,
             enc_config=self.config,
             enc_path=self.enc_path,
             clf_path=self.clf_path,
+            n_layers=self.n_layers,
         )
         model.eval()
         model.to(self.device)
-        output = self.forward(data)
-        label_ids = np.argmax(output, axis=1)
-        try:
-            self.labels
-        except AttributeError:
-            self.labels = self.clf_config["behavior_labels"]
-        labels = [self.labels[label] for label in label_ids]
-        return labels
+        if subset is None:
+            data = self.preprocess(data)
+            if data is None:
+                return
+            output = self.forward(data)
+            label_ids = np.argmax(output, axis=1)
+            try:
+                self.labels
+            except AttributeError:
+                self.labels = self.clf_config["behavior_labels"]
+            labels = [self.labels[label] for label in label_ids]
+            return labels
+        else:
+            dataset = data
+            self.prepare_dataset(dataset)
+            predicted, expected = [], []
+            for data, exp in dataset.getsample(subset, "all"):
+                output = self.forward(data)
+                pred = np.argmax(output, axis=1)
+                exp = exp.numpy()
+                assert pred.size == exp.size
+                predicted.append(pred)
+                expected.append(exp)
+            return np.concatenate(predicted), np.concatenate(expected)
 
     def save(self, config_path="clf_config.json", config_only=False):
         if self.prepend_log_dir:
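predict now has two calling conventions: with subset=None it keeps the original behavior (preprocess raw spines, return a list of behavior names); with a dataset plus a subset name it iterates over that subset and returns aligned arrays of predicted and expected label indices. A hedged usage sketch, where layer, all_spines and dataset are placeholders:

# Sketch only: placeholder objects, not part of the commit.
names = layer.predict(all_spines)                        # list of label names
pred, exp = layer.predict(dataset, subset="validation")  # index arrays
accuracy = (pred == exp).mean()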
@@ -302,8 +359,8 @@ class DenseLayer:
             clf_path=self.clf_path,
             n_behaviors=self.n_behaviors,
             behavior_labels=self.labels,
+            clf_depth=self.n_layers - 1,
             # additional information (not reused):
-            clf_depth=0,
             bias=True,
             init="xavier",
             loss="cross-entropy",
@@ -311,3 +368,5 @@ class DenseLayer:
             target=["present"],
         ), f, indent=2)
 
+def new_generator():
+    return torch.Generator(device).manual_seed(42)
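new_generator centralizes the seeded torch.Generator that train_model previously built inline, so training and dataset-level prediction construct identical RNG state, e.g.:

dataset = LarvaDataset(path_to_hdf5, new_generator())  # path_to_hdf5 is a placeholder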
@@ -3,11 +3,11 @@ from taggingbackends.data.chore import load_spine
 import taggingbackends.data.fimtrack as fimtrack
 from taggingbackends.data.labels import Labels
 from taggingbackends.features.skeleton import get_5point_spines
-from maggotuba.models.denselayer import DenseLayer
+from maggotuba.models.denselayer import DenseLayer, new_generator
 import numpy as np
 import json
 
-def predict_model(backend):
+def predict_model(backend, **kwargs):
     """
     This function generates predicted labels for all the input data.
@@ -27,10 +27,32 @@ def predict_model(backend):
     # initialize output labels
     input_files, labels = backend.prepare_labels(input_files)
     assert 0 < len(input_files)
+    # load the model
+    model_files = backend.list_model_files()
+    config_file = [file for file in model_files if file.name.endswith("config.json")]
+    if 1 < len(config_file):
+        config_file = [file for file in config_file if file.name.endswith("clf_config.json")]
+    model = DenseLayer(config_file[-1])
+    #
+    labels.labelspec = model.clf_config["behavior_labels"]
+    #
+    if len(input_files) == 1:
+        file = input_files[0]
+        if file.name.startswith("larva_dataset_") and file.name.endswith(".hdf5"):
+            ret = predict_larva_dataset(backend, model, file, labels, **kwargs)
+            return labels if ret is None else ret
+    #
+    ret = predict_individual_data_files(backend, model, input_files, labels)
+    return labels if ret is None else ret
+
+def predict_individual_data_files(backend, model, input_files, labels):
+    _break = False # for now, a single file can be labelled at a time
     for file in input_files:
         # load the input data (or features)
-        if file.name.endswith(".spine"):
+        if _break:
+            print(f"ignoring file: {file.name}")
+            continue
+        elif file.name.endswith(".spine"):
             spine = load_spine(file)
             run = spine["date_time"].iloc[0]
             larvae = spine["larva_id"].values
@@ -53,6 +75,7 @@ def predict_model(backend):
             t, data = fimtrack.read_spines(file, fps=labels.camera_framerate)
             run = "NA"
         else:
+            print(f"ignoring file: {file.name}")
             continue
         # downsample the skeleton
         if isinstance(data, dict):
@@ -60,12 +83,6 @@ def predict_model(backend):
                 data[larva] = get_5point_spines(data[larva])
         else:
             data = get_5point_spines(data)
-        # load the model
-        model_files = backend.list_model_files()
-        config_file = [file for file in model_files if file.name.endswith("config.json")]
-        if 1 < len(config_file):
-            config_file = [file for file in config_file if file.name.endswith("clf_config.json")]
-        model = DenseLayer(config_file[-1])
         # assign labels
         if isinstance(data, dict):
             ref_length = np.median(np.concatenate([
@@ -91,13 +108,14 @@ def predict_model(backend):
         else:
             labels[run, larva] = dict(zip(t[mask], predictions))
         # save the predicted labels to file
-        # labels.labelspec = {
-        #     "names": ["run", "bend", "stop", "hunch", "back", "roll"],
-        #     "colors": ["#000000", "#ff0000", "#00ff00", "#0000ff",
-        #                "#00ffff", "#ffff00"]
-        # }
-        labels.labelspec = model.clf_config["behavior_labels"]
         labels.dump(backend.processed_data_dir() / "predicted.labels")
+        #
+        _break = True
+
+def predict_larva_dataset(backend, model, file, labels, subset="validation"):
+    from taggingbackends.data.dataset import LarvaDataset
+    dataset = LarvaDataset(file, new_generator())
+    return model.predict(dataset, subset)
 
 from taggingbackends.main import main
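Since predict_larva_dataset returns aligned index arrays, the caller can score the tagger directly. A hypothetical evaluation snippet; it assumes the labelspec layout with a "names" list, as in the removed commented-out block above:

# Hypothetical snippet, not part of the commit.
predicted, expected = predict_larva_dataset(backend, model, file, labels)
n = len(model.clf_config["behavior_labels"]["names"])
confusion = np.zeros((n, n), dtype=int)
for e, p in zip(expected, predicted):
    confusion[e, p] += 1   # rows: expected label, columns: predicted label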
......
 from taggingbackends.data.labels import Labels
 from taggingbackends.data.dataset import LarvaDataset
-from maggotuba.models.denselayer import DenseLayer, device
+from maggotuba.models.denselayer import DenseLayer, new_generator
 import numpy as np
 import json
 import torch
 import os
 import glob
 
-def train_model(backend, pretrained_model_instance="default"):
+def train_model(backend, layers=1, pretrained_model_instance="default"):
     # make_dataset generated or moved the larva_dataset file into data/interim/{instance}/
     #larva_dataset_file = backend.list_interim_files("larva_dataset_*.hdf5") # recursive
     larva_dataset_file = glob.glob(str(backend.interim_data_dir() / "larva_dataset_*.hdf5")) # not recursive (faster)
     assert len(larva_dataset_file) == 1
-    dataset = LarvaDataset(larva_dataset_file[0], torch.Generator(device).manual_seed(42))
+    dataset = LarvaDataset(larva_dataset_file[0], new_generator())
     nlabels = len(dataset.labels)
     assert 0 < nlabels
     # copy the pretrained model into the model instance directory
@@ -40,7 +40,8 @@ def train_model(backend, pretrained_model_instance="default"):
         with open(str(dst), "wb") as o:
             o.write(i.read())
     # load the pretrained model
-    model = DenseLayer(autoencoder_config=config_file, n_behaviors=nlabels)
+    model = DenseLayer(autoencoder_config=config_file, n_behaviors=nlabels,
+            n_layers=layers)
     # fine-tune and save the model
     model.train(dataset)
     model.save()
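The new layers argument defaults to 1, so existing callers keep a single linear classifier; a deeper head is requested at training time, for example:

# Sketch: fine-tune with a two-layer head instead of the default single layer.
train_model(backend, layers=2, pretrained_model_instance="default")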
......