diff --git a/src/maggotuba/models/denselayer.py b/src/maggotuba/models/denselayer.py
index 9ae12b8913ce0eb35cbf31f94545a618b2770196..8e89c50810a9fb061da829c24e3d8716db547899 100644
--- a/src/maggotuba/models/denselayer.py
+++ b/src/maggotuba/models/denselayer.py
@@ -7,21 +7,43 @@ import torch.nn as nn
 from behavior_model.models.neural_nets import Encoder, device
 import behavior_model.data.utils as data_utils
 
+class DeepLinear(nn.Module):
+    def __init__(self, n_input, n_output, n_layers=1):
+        super().__init__()
+        if n_layers is None: n_layers = 1
+        self.layers = []
+        layers = []
+        for _ in range(n_layers - 1):
+            layer = nn.Linear(n_input, n_input)
+            self.layers.append(layer)
+            layers.append(layer)
+            layers.append(nn.ReLU())
+        layer = nn.Linear(n_input, n_output)
+        self.layers.append(layer)
+        layers.append(layer)
+        self.classifier = nn.Sequential(*layers)
+
+    def _init_layers(self):
+        for layer in self.layers:
+            nn.init.xavier_uniform_(layer.weight)
+            nn.init.zeros_(layer.bias)
+
+    def forward(self, x):
+        return self.classifier.forward(x)
 
 class SupervisedMaggot(nn.Module):
     def __init__(self, n_latent_features, n_behaviors, enc_config, enc_path,
-            clf_path=None):
+            clf_path=None, n_layers=1):
         super().__init__()
         # Pretrained or trained MaggotUBA encoder
         self.encoder = encoder = Encoder(**enc_config)
         encoder.load_state_dict(torch.load(enc_path))
         # Classifier stacked atop the encoder
-        self.clf = nn.Linear(n_latent_features, n_behaviors)
+        self.clf = DeepLinear(n_latent_features, n_behaviors, n_layers)
         if clf_path:
             self.clf.load_state_dict(torch.load(clf_path))
         else:
-            nn.init.xavier_uniform_(self.clf.weight)
-            nn.init.zeros_(self.clf.bias)
+            self.clf._init_layers()
 
     def forward(self, x):
         #x = torch.flip(x, (2,))
@@ -44,15 +66,15 @@ class DenseLayer:
             config=None,
             autoencoder_config=None,
             n_behaviors=None,
+            n_layers=1,
             average_body_length=None,
             device=device):
         # MaggotUBA autoencoder config
         self._config = autoencoder_config
         self._clf_config = config
         self.prepend_log_dir = True
-        self._model = None
-        if n_behaviors is not None:
-            self.n_behaviors = n_behaviors
+        self._n_behaviors = n_behaviors
+        self._n_layers = n_layers
         self.average_body_length = average_body_length
         self.device = device
 
@@ -66,8 +88,10 @@
         if self._config is None:
             self._config = self.clf_config.get("autoencoder_config", None)
         if isinstance(self._config, (str, pathlib.Path)):
-            with open(self._config, "r") as f:
+            path = self._config
+            with open(path, "r") as f:
                 self._config = json.load(f)
+            self._config["config"] = str(path)
         return self._config
 
     @config.setter
@@ -117,12 +141,23 @@
 
     @property
     def n_behaviors(self):
-        return self.clf_config.get("n_behaviors", None)
+        return self.clf_config.get("n_behaviors", self._n_behaviors)
 
     @n_behaviors.setter
     def n_behaviors(self, n):
         self.clf_config["n_behaviors"] = n
 
+    @property
+    def n_layers(self):
+        try:
+            return self.clf_config["clf_depth"] + 1
+        except KeyError:
+            return self._n_layers
+
+    @n_layers.setter
+    def n_layers(self, n):
+        self.clf_config["clf_depth"] = 0 if n is None else n - 1
+
     def window(self, data):
         winlen = self.config["len_traj"]
         N = data.shape[0]+1
@@ -190,12 +225,15 @@
         if train:
             return self.model(x)
         else:
-            if not isinstance(x, torch.Tensor):
+            if isinstance(x, torch.Tensor):
+                if x.dtype is not torch.float32:
+                    x = x.to(torch.float32)
+            else:
                 x = torch.from_numpy(x.astype(np.float32))
             y = self.model(x.to(self.device))
             return y.cpu().numpy()
 
-    def train(self, dataset):
+    def prepare_dataset(self, dataset):
         try:
             dataset.batch_size
         except AttributeError:
@@ -214,6 +252,9 @@
         if not (0 <= midpoint - before and midpoint + after <= dataset.window_length):
             raise ValueError(f"the dataset can provide segments of up to {dataset.window_length} time points")
         dataset._mask = slice(midpoint - before, midpoint + after)
+
+    def train(self, dataset):
+        self.prepare_dataset(dataset)
         #
         enc_path = "best_validated_encoder.pt"
         if self.prepend_log_dir:
@@ -227,6 +268,7 @@
             n_behaviors=self.n_behaviors,
             enc_config=self.config,
             enc_path=enc_path,
+            n_layers=self.n_layers,
         )
         model.train() # this only sets the model in training mode (enables gradients)
         model.to(self.device)
@@ -256,38 +298,53 @@
         #
         return self
 
-    def draw(self, dataset):
-        data, expected = dataset.getsample()
+    def draw(self, dataset, subset="train"):
+        data, expected = dataset.getobs(subset)
         if isinstance(data, list):
             data = torch.stack(data)
         data = data.to(torch.float32).to(self.device)
         if isinstance(expected, list):
             expected = torch.stack(expected)
-        expected = expected.to(torch.long).to(self.device)
+        if subset.startswith("train"):
+            expected = expected.to(torch.long).to(self.device)
         return data, expected
 
     @torch.no_grad()
-    def predict(self, all_spines):
-        data = self.preprocess(all_spines)
-        if data is None:
-            return
+    def predict(self, data, subset=None):
         self.model = model = SupervisedMaggot(
             n_latent_features=self.config["dim_latent"],
             n_behaviors=self.n_behaviors,
             enc_config=self.config,
             enc_path=self.enc_path,
             clf_path=self.clf_path,
+            n_layers=self.n_layers,
         )
         model.eval()
         model.to(self.device)
-        output = self.forward(data)
-        label_ids = np.argmax(output, axis=1)
-        try:
-            self.labels
-        except AttributeError:
-            self.labels = self.clf_config["behavior_labels"]
-        labels = [self.labels[label] for label in label_ids]
-        return labels
+        if subset is None:
+            data = self.preprocess(data)
+            if data is None:
+                return
+            output = self.forward(data)
+            label_ids = np.argmax(output, axis=1)
+            try:
+                self.labels
+            except AttributeError:
+                self.labels = self.clf_config["behavior_labels"]
+            labels = [self.labels[label] for label in label_ids]
+            return labels
+        else:
+            dataset = data
+            self.prepare_dataset(dataset)
+            predicted, expected = [], []
+            for data, exp in dataset.getsample(subset, "all"):
+                output = self.forward(data)
+                pred = np.argmax(output, axis=1)
+                exp = exp.numpy()
+                assert pred.size == exp.size
+                predicted.append(pred)
+                expected.append(exp)
+            return np.concatenate(predicted), np.concatenate(expected)
 
     def save(self, config_path="clf_config.json", config_only=False):
         if self.prepend_log_dir:
@@ -302,8 +359,8 @@
             clf_path=self.clf_path,
             n_behaviors=self.n_behaviors,
             behavior_labels=self.labels,
+            clf_depth=self.n_layers - 1,
             # additional information (not reused):
-            clf_depth=0,
             bias=True,
             init="xavier",
             loss="cross-entropy",
@@ -311,3 +368,5 @@
             optimizer="adam",
             target=["present"],
             ), f, indent=2)
+def new_generator():
+    return torch.Generator(device).manual_seed(42)
diff --git a/src/maggotuba/models/predict_model.py b/src/maggotuba/models/predict_model.py
index 9ed347fa2ef2fdd50920d52c708b4f4826f16da9..1580c11b89e6a37fe8dc3195d5655247a2777ff3 100644
--- a/src/maggotuba/models/predict_model.py
+++ b/src/maggotuba/models/predict_model.py
@@ -3,11 +3,11 @@
 from taggingbackends.data.chore import load_spine
 import taggingbackends.data.fimtrack as fimtrack
 from taggingbackends.data.labels import Labels
 from taggingbackends.features.skeleton import get_5point_spines
-from maggotuba.models.denselayer import DenseLayer
+from maggotuba.models.denselayer import DenseLayer, new_generator
 import numpy as np
 import json
 
-def predict_model(backend):
+def predict_model(backend, **kwargs):
     """
     This function generates predicted labels for all the input data.
@@ -27,10 +27,32 @@
     # initialize output labels
     input_files, labels = backend.prepare_labels(input_files)
     assert 0 < len(input_files)
+    # load the model
+    model_files = backend.list_model_files()
+    config_file = [file for file in model_files if file.name.endswith("config.json")]
+    if 1 < len(config_file):
+        config_file = [file for file in config_file if file.name.endswith("clf_config.json")]
+    model = DenseLayer(config_file[-1]) #
+    labels.labelspec = model.clf_config["behavior_labels"]
+    #
+    if len(input_files) == 1:
+        file = input_files[0]
+        if file.name.startswith("larva_dataset_") and file.name.endswith(".hdf5"):
+            ret = predict_larva_dataset(backend, model, file, labels, **kwargs)
+            return labels if ret is None else ret
+    #
+    ret = predict_individual_data_files(backend, model, input_files, labels)
+    return labels if ret is None else ret
+
+def predict_individual_data_files(backend, model, input_files, labels):
+    _break = False
     # for now, a single file can be labelled at a time
     for file in input_files:
         # load the input data (or features)
-        if file.name.endswith(".spine"):
+        if _break:
+            print(f"ignoring file: {file.name}")
+            continue
+        elif file.name.endswith(".spine"):
             spine = load_spine(file)
             run = spine["date_time"].iloc[0]
             larvae = spine["larva_id"].values
@@ -53,6 +75,7 @@
             t, data = fimtrack.read_spines(file, fps=labels.camera_framerate)
             run = "NA"
         else:
+            print(f"ignoring file: {file.name}")
             continue
         # downsample the skeleton
         if isinstance(data, dict):
@@ -60,12 +83,6 @@
                 data[larva] = get_5point_spines(data[larva])
         else:
             data = get_5point_spines(data)
-        # load the model
-        model_files = backend.list_model_files()
-        config_file = [file for file in model_files if file.name.endswith("config.json")]
-        if 1 < len(config_file):
-            config_file = [file for file in config_file if file.name.endswith("clf_config.json")]
-        model = DenseLayer(config_file[-1])
         # assign labels
         if isinstance(data, dict):
             ref_length = np.median(np.concatenate([
@@ -91,13 +108,14 @@
             else:
                 labels[run, larva] = dict(zip(t[mask], predictions))
         # save the predicted labels to file
-        # labels.labelspec = {
-        #     "names": ["run", "bend", "stop", "hunch", "back", "roll"],
-        #     "colors": ["#000000", "#ff0000", "#00ff00", "#0000ff",
-        #         "#00ffff", "#ffff00"]
-        #     }
-        labels.labelspec = model.clf_config["behavior_labels"]
         labels.dump(backend.processed_data_dir() / "predicted.labels")
+        #
+        _break = True
+
+def predict_larva_dataset(backend, model, file, labels, subset="validation"):
+    from taggingbackends.data.dataset import LarvaDataset
+    dataset = LarvaDataset(file, new_generator())
+    return model.predict(dataset, subset)
 
 from taggingbackends.main import main
diff --git a/src/maggotuba/models/train_model.py b/src/maggotuba/models/train_model.py
index c121ddcacc4bc3f697bf8d3b2a9c82873ac3c0ba..a19f121173f0e23c6058daeb811fd0e01b4ee427 100644
--- a/src/maggotuba/models/train_model.py
+++ b/src/maggotuba/models/train_model.py
@@ -1,18 +1,18 @@
 from taggingbackends.data.labels import Labels
 from taggingbackends.data.dataset import LarvaDataset
-from maggotuba.models.denselayer import DenseLayer, device
+from maggotuba.models.denselayer import DenseLayer, new_generator
 import numpy as np
 import json
 import torch
 import os
 import glob
 
-def train_model(backend, pretrained_model_instance="default"):
+def train_model(backend, layers=1, pretrained_model_instance="default"):
     # make_dataset generated or moved the larva_dataset file into data/interim/{instance}/
     #larva_dataset_file = backend.list_interim_files("larva_dataset_*.hdf5") # recursive
     larva_dataset_file = glob.glob(str(backend.interim_data_dir() / "larva_dataset_*.hdf5")) # not recursive (faster)
     assert len(larva_dataset_file) == 1
-    dataset = LarvaDataset(larva_dataset_file[0], torch.Generator(device).manual_seed(42))
+    dataset = LarvaDataset(larva_dataset_file[0], new_generator())
     nlabels = len(dataset.labels)
     assert 0 < nlabels
     # copy the pretrained model into the model instance directory
@@ -40,7 +40,8 @@
         with open(str(dst), "wb") as o:
             o.write(i.read())
     # load the pretrained model
-    model = DenseLayer(autoencoder_config=config_file, n_behaviors=nlabels)
+    model = DenseLayer(autoencoder_config=config_file, n_behaviors=nlabels,
+            n_layers=layers)
     # fine-tune and save the model
     model.train(dataset)
     model.save()
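
Note on the new classifier head: the snippet below is a minimal sketch, not part of the patch, of the stack that DeepLinear assembles. With n_layers=1 it reduces to the single nn.Linear it replaces; each additional layer prepends a constant-width Linear+ReLU block. The feature and class counts, and the smoke test at the end, are illustrative assumptions.

import torch
import torch.nn as nn

# Illustrative dimensions: 100 latent features, 6 behavior classes, depth 3.
n_input, n_output, n_layers = 100, 6, 3

layers = []
for _ in range(n_layers - 1):
    layers.append(nn.Linear(n_input, n_input))  # hidden block keeps the width
    layers.append(nn.ReLU())
layers.append(nn.Linear(n_input, n_output))     # final projection to classes
classifier = nn.Sequential(*layers)

# Same initialization scheme as DeepLinear._init_layers applies.
for layer in classifier:
    if isinstance(layer, nn.Linear):
        nn.init.xavier_uniform_(layer.weight)
        nn.init.zeros_(layer.bias)

x = torch.randn(8, n_input)                     # a batch of 8 latent vectors
assert classifier(x).shape == (8, n_output)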