Skip to content
Snippets Groups Projects
Commit 4a660112 authored by François  LAURENT's avatar François LAURENT
Browse files

implements bagging (#larvatagger.jl#63, prediction only)

parent becb7f82
No related branches found
No related tags found
No related merge requests found
......@@ -380,6 +380,9 @@ class SupervisedMaggot(nn.Module):
def forward(self, x):
return self.clf(self.encoder(x))
def mask_forward(self, x):
return self.clf(self.encoder.mask_forward(x))
def save(self):
enc, clf = self.encoder, self.clf
enc.save()
......@@ -423,3 +426,45 @@ class MultiscaleSupervisedMaggot(nn.Module):
self.clf.model # force parameter loading or initialization
return super().parameters(self)
"""
Bagging for `SupervisedMaggot`.
Taggers in the bag are individually trained.
For now the bag itself cannot be trained and is used for prediction only.
Bags of taggers are stored so that the models directory only contains
subdirectories, each subdirectory specifying an individual tagger.
"""
class MaggotBag(nn.Module):
def __init__(self, paths, behaviors=[], n_layers=1, cls=SupervisedMaggot):
super().__init__()
self.maggots = [cls(path, behaviors, n_layers) for path in paths]
self._lead_maggot = None
def forward(self, x):
#return torch.cat([encoder.mask_forward(x) for encoder in self.encoders], dim=1)
return self.vote([maggot.mask_forward(x) for maggot in self.maggots])
def vote(self, y):
vote, _ = torch.mode(torch.stack(y, dim=len(y[0].shape)))
return vote
@property
def encoder(self):
return self.maggots[self.lead_maggot].encoder
@property
def clf(self):
return self.maggots[self.lead_maggot].clf
@property
def lead_maggot(self):
if self._lead_maggot is None:
len_traj = 0
for i, maggot in enumerate(self.maggots):
len_traj_ = maggot.encoder.config['len_traj']
if len_traj < len_traj_:
len_traj = len_traj_
self._lead_maggot = i
return self._lead_maggot
from taggingbackends.data.labels import Labels
from taggingbackends.features.skeleton import get_5point_spines
from maggotuba.models.trainers import MaggotTrainer, MultiscaleMaggotTrainer, new_generator
from maggotuba.models.trainers import MaggotTrainer, MultiscaleMaggotTrainer, MaggotBagging, new_generator
import numpy as np
import logging
......@@ -27,20 +27,25 @@ def predict_model(backend, **kwargs):
assert 0 < len(input_files_and_labels)
# load the model
model_files = backend.list_model_files()
config_file = [file for file in model_files if file.name.endswith("config.json")]
n_config_files = len(config_file)
if n_config_files == 0:
raise RuntimeError(f"no such tagger found: {backend.model_instance}")
config_file = [file
for file in config_file
if file.name.endswith("clf_config.json")
and file.parent == backend.model_dir()]
assert len(config_file) == 1
config_file = config_file[-1]
if 2 < n_config_files:
model = MultiscaleMaggotTrainer(config_file)
config_files = [file
for file in model_files
if file.name.endswith('config.json')]
if len(config_files) == 0:
raise RuntimeError(f"no config files found for tagger: {backend.model_instance}")
single_encoder_classifier = len(config_files) == 2
config_files = [file
for file in config_files
if file.name == 'clf_config.json']
if len(config_files) == 0:
raise RuntimeError(f"no classifier config files found; is {backend.model_instance} tagger trained?")
elif len(config_files) == 1:
config_file = config_files[0]
if single_encoder_classifier:
model = MaggotTrainer(config_file)
else:
model = MultiscaleMaggotTrainer(config_file)
else:
model = MaggotTrainer(config_file)
model = MaggotBagging(config_files)
#
if len(input_files) == 1:
input_files = next(iter(input_files.values()))
......
......@@ -3,7 +3,7 @@ import torch
import torch.nn as nn
from behavior_model.models.neural_nets import device
#import behavior_model.data.utils as data_utils
from maggotuba.models.modules import SupervisedMaggot, MultiscaleSupervisedMaggot
from maggotuba.models.modules import SupervisedMaggot, MultiscaleSupervisedMaggot, MaggotBag
from taggingbackends.features.skeleton import interpolate
"""
......@@ -265,6 +265,14 @@ class MultiscaleMaggotTrainer(MaggotTrainer):
return self._default_encoder_config
class MaggotBagging(MaggotTrainer):
def __init__(self, cfgfilepaths, behaviors=[], n_layers=1,
average_body_length=1.0, device=device):
self.model = MaggotBag(cfgfilepaths, behaviors, n_layers)
self.average_body_length = average_body_length # usually set later
self.device = device
"""
Pick the adequate trainer following a rapid inspection of the config file(s).
......@@ -272,6 +280,8 @@ For now, config files are actually not inspected. However, using this function
is highly recommended as more models are introduced with future releases.
"""
def make_trainer(config_file, *args, **kwargs):
# the type criterion does not fail in the case of unimplemented bagging,
# as config files are listed in a pretrained_models subdirectory.
if isinstance(config_file, list): # multiple encoders
config_files = config_file
model = MultiscaleMaggotTrainer(config_files, *args, **kwargs)
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment