Skip to content
Snippets Groups Projects
Commit 4a660112 authored by François  LAURENT's avatar François LAURENT
Browse files

implements bagging (#larvatagger.jl#63, prediction only)

parent becb7f82
No related branches found
No related tags found
No related merge requests found
...@@ -380,6 +380,9 @@ class SupervisedMaggot(nn.Module): ...@@ -380,6 +380,9 @@ class SupervisedMaggot(nn.Module):
def forward(self, x): def forward(self, x):
return self.clf(self.encoder(x)) return self.clf(self.encoder(x))
def mask_forward(self, x):
    """
    Run `x` through the encoder's `mask_forward`, then through the classifier.
    """
    encoded = self.encoder.mask_forward(x)
    return self.clf(encoded)
def save(self): def save(self):
enc, clf = self.encoder, self.clf enc, clf = self.encoder, self.clf
enc.save() enc.save()
...@@ -423,3 +426,45 @@ class MultiscaleSupervisedMaggot(nn.Module): ...@@ -423,3 +426,45 @@ class MultiscaleSupervisedMaggot(nn.Module):
self.clf.model # force parameter loading or initialization self.clf.model # force parameter loading or initialization
return super().parameters(self) return super().parameters(self)
"""
Bagging for `SupervisedMaggot`.
Taggers in the bag are individually trained.
For now the bag itself cannot be trained and is used for prediction only.
Bags of taggers are stored so that the models directory only contains
subdirectories, each subdirectory specifying an individual tagger.
"""
class MaggotBag(nn.Module):
    """
    Bag of taggers whose individual outputs are combined by a vote.

    Arguments:
        paths: iterable; each element is passed as first argument to `cls`
            to build one tagger.
        behaviors: passed as second argument to `cls`; a new empty list is
            used if not specified.
        n_layers: passed as third argument to `cls`.
        cls: callable that builds a tagger; each tagger is expected to
            expose `mask_forward`, `encoder` and `clf`.
    """
    def __init__(self, paths, behaviors=None, n_layers=1, cls=SupervisedMaggot):
        super().__init__()
        # `None` replaces the former mutable `[]` default, so that the default
        # list is not shared between bags
        if behaviors is None:
            behaviors = []
        # NOTE(review): the taggers are kept in a plain list, hence they are
        # not registered as submodules of the bag (as nn.ModuleList would do);
        # confirm this is intended
        self.maggots = [cls(path, behaviors, n_layers) for path in paths]
        self._lead_maggot = None

    def forward(self, x):
        """
        Apply `mask_forward` of every tagger to `x` and return the vote.
        """
        return self.vote([maggot.mask_forward(x) for maggot in self.maggots])

    def vote(self, y):
        """
        Stack the tensors in `y` along a new trailing dimension and return the
        most frequent value along that dimension (`torch.mode`).
        """
        stacked = torch.stack(y, dim=len(y[0].shape))
        vote, _ = torch.mode(stacked)
        return vote

    @property
    def encoder(self):
        """Encoder of the lead tagger."""
        return self.maggots[self.lead_maggot].encoder

    @property
    def clf(self):
        """Classifier of the lead tagger."""
        return self.maggots[self.lead_maggot].clf

    @property
    def lead_maggot(self):
        """
        Index of the tagger whose encoder config has the largest `len_traj`.

        The first tagger wins ties. The index is computed on first access and
        cached. Raises `ValueError` if the bag is empty.
        """
        if self._lead_maggot is None:
            if not self.maggots:
                raise ValueError("cannot pick a lead tagger from an empty bag")
            # `max` returns the first maximum, like the former strict `<`
            # comparison, but always yields a valid index (the former loop left
            # the index undefined if no `len_traj` was positive)
            self._lead_maggot = max(
                    range(len(self.maggots)),
                    key=lambda i: self.maggots[i].encoder.config['len_traj'])
        return self._lead_maggot
from taggingbackends.data.labels import Labels from taggingbackends.data.labels import Labels
from taggingbackends.features.skeleton import get_5point_spines from taggingbackends.features.skeleton import get_5point_spines
from maggotuba.models.trainers import MaggotTrainer, MultiscaleMaggotTrainer, new_generator from maggotuba.models.trainers import MaggotTrainer, MultiscaleMaggotTrainer, MaggotBagging, new_generator
import numpy as np import numpy as np
import logging import logging
...@@ -27,20 +27,25 @@ def predict_model(backend, **kwargs): ...@@ -27,20 +27,25 @@ def predict_model(backend, **kwargs):
assert 0 < len(input_files_and_labels) assert 0 < len(input_files_and_labels)
# load the model # load the model
model_files = backend.list_model_files() model_files = backend.list_model_files()
config_file = [file for file in model_files if file.name.endswith("config.json")] config_files = [file
n_config_files = len(config_file) for file in model_files
if n_config_files == 0: if file.name.endswith('config.json')]
raise RuntimeError(f"no such tagger found: {backend.model_instance}") if len(config_files) == 0:
config_file = [file raise RuntimeError(f"no config files found for tagger: {backend.model_instance}")
for file in config_file single_encoder_classifier = len(config_files) == 2
if file.name.endswith("clf_config.json") config_files = [file
and file.parent == backend.model_dir()] for file in config_files
assert len(config_file) == 1 if file.name == 'clf_config.json']
config_file = config_file[-1] if len(config_files) == 0:
if 2 < n_config_files: raise RuntimeError(f"no classifier config files found; is {backend.model_instance} tagger trained?")
model = MultiscaleMaggotTrainer(config_file) elif len(config_files) == 1:
config_file = config_files[0]
if single_encoder_classifier:
model = MaggotTrainer(config_file)
else:
model = MultiscaleMaggotTrainer(config_file)
else: else:
model = MaggotTrainer(config_file) model = MaggotBagging(config_files)
# #
if len(input_files) == 1: if len(input_files) == 1:
input_files = next(iter(input_files.values())) input_files = next(iter(input_files.values()))
......
...@@ -3,7 +3,7 @@ import torch ...@@ -3,7 +3,7 @@ import torch
import torch.nn as nn import torch.nn as nn
from behavior_model.models.neural_nets import device from behavior_model.models.neural_nets import device
#import behavior_model.data.utils as data_utils #import behavior_model.data.utils as data_utils
from maggotuba.models.modules import SupervisedMaggot, MultiscaleSupervisedMaggot from maggotuba.models.modules import SupervisedMaggot, MultiscaleSupervisedMaggot, MaggotBag
from taggingbackends.features.skeleton import interpolate from taggingbackends.features.skeleton import interpolate
""" """
...@@ -265,6 +265,14 @@ class MultiscaleMaggotTrainer(MaggotTrainer): ...@@ -265,6 +265,14 @@ class MultiscaleMaggotTrainer(MaggotTrainer):
return self._default_encoder_config return self._default_encoder_config
class MaggotBagging(MaggotTrainer):
    """
    Trainer-like wrapper whose model is a `MaggotBag`.

    Arguments:
        cfgfilepaths: passed to `MaggotBag` as its first argument.
        behaviors: passed to `MaggotBag`; a new empty list is used if not
            specified.
        n_layers: passed to `MaggotBag`.
        average_body_length: stored as is.
        device: stored as is.
    """
    def __init__(self, cfgfilepaths, behaviors=None, n_layers=1,
            average_body_length=1.0, device=device):
        # the parent constructor is not called; the attributes are set directly
        # `None` replaces the former mutable `[]` default; a list is always
        # passed down to `MaggotBag`
        if behaviors is None:
            behaviors = []
        self.model = MaggotBag(cfgfilepaths, behaviors, n_layers)
        self.average_body_length = average_body_length # usually set later
        self.device = device
""" """
Pick the adequate trainer following a rapid inspection of the config file(s). Pick the adequate trainer following a rapid inspection of the config file(s).
...@@ -272,6 +280,8 @@ For now, config files are actually not inspected. However, using this function ...@@ -272,6 +280,8 @@ For now, config files are actually not inspected. However, using this function
is highly recommended as more models are introduced with future releases. is highly recommended as more models are introduced with future releases.
""" """
def make_trainer(config_file, *args, **kwargs): def make_trainer(config_file, *args, **kwargs):
# the type criterion does not fail in the case of unimplemented bagging,
# as config files are listed in a pretrained_models subdirectory.
if isinstance(config_file, list): # multiple encoders if isinstance(config_file, list): # multiple encoders
config_files = config_file config_files = config_file
model = MultiscaleMaggotTrainer(config_files, *args, **kwargs) model = MultiscaleMaggotTrainer(config_files, *args, **kwargs)
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment