Skip to content
Snippets Groups Projects
Commit 53daae46 authored by François LAURENT's avatar François LAURENT
Browse files

additional comments

parent 5c9c32f7
Branches
Tags
No related merge requests found
...@@ -22,9 +22,11 @@ def predict_model(backend, **kwargs): ...@@ -22,9 +22,11 @@ def predict_model(backend, **kwargs):
if not input_files: if not input_files:
input_files = backend.list_input_files(group_by_directories=True) input_files = backend.list_input_files(group_by_directories=True)
assert 0 < len(input_files) assert 0 < len(input_files)
# initialize output labels # initialize output labels
input_files_and_labels = backend.prepare_labels(input_files) input_files_and_labels = backend.prepare_labels(input_files)
assert 0 < len(input_files_and_labels) assert 0 < len(input_files_and_labels)
# load the model # load the model
model_files = backend.list_model_files() model_files = backend.list_model_files()
config_files = [file config_files = [file
...@@ -46,14 +48,14 @@ def predict_model(backend, **kwargs): ...@@ -46,14 +48,14 @@ def predict_model(backend, **kwargs):
model = MultiscaleMaggotTrainer(config_file) model = MultiscaleMaggotTrainer(config_file)
else: else:
model = MaggotBagging(config_files) model = MaggotBagging(config_files)
#
# call the `predict` logic on the input data files
if len(input_files) == 1: if len(input_files) == 1:
input_files = next(iter(input_files.values())) input_files = next(iter(input_files.values()))
if len(input_files) == 1: if len(input_files) == 1:
file = input_files[0] file = input_files[0]
if file.name.startswith("larva_dataset_") and file.name.endswith(".hdf5"): if file.name.startswith("larva_dataset_") and file.name.endswith(".hdf5"):
return predict_larva_dataset(backend, model, file, **kwargs) return predict_larva_dataset(backend, model, file, **kwargs)
#
predict_individual_data_files(backend, model, input_files_and_labels) predict_individual_data_files(backend, model, input_files_and_labels)
def predict_individual_data_files(backend, model, input_files_and_labels): def predict_individual_data_files(backend, model, input_files_and_labels):
......
...@@ -7,16 +7,26 @@ import glob ...@@ -7,16 +7,26 @@ import glob
def train_model(backend, layers=1, pretrained_model_instance="default", def train_model(backend, layers=1, pretrained_model_instance="default",
subsets=(1, 0, 0), rng_seed=None, iterations=1000, **kwargs): subsets=(1, 0, 0), rng_seed=None, iterations=1000, **kwargs):
# make_dataset generated or moved the larva_dataset file into data/interim/{instance}/ # list training data files;
#larva_dataset_file = backend.list_interim_files("larva_dataset_*.hdf5") # recursive # we actually expect a single larva_dataset file that make_dataset generated
larva_dataset_file = glob.glob(str(backend.interim_data_dir() / "larva_dataset_*.hdf5")) # not recursive (faster) # or moved into data/interim/{instance}/
#larva_dataset_file = backend.list_interim_files("larva_dataset_*.hdf5") # this one is recursive
larva_dataset_file = glob.glob(str(backend.interim_data_dir() / "larva_dataset_*.hdf5")) # this other one is not recursive
assert len(larva_dataset_file) == 1 assert len(larva_dataset_file) == 1
# subsets=(1, 0, 0) => all data are training data; no validation or test subsets
# instanciate a LarvaDataset object, that is similar to a PyTorch DataLoader
# and can initialize a Labels object
# note: subsets=(1, 0, 0) => all data are training data; no validation or test subsets
dataset = LarvaDataset(larva_dataset_file[0], new_generator(rng_seed), dataset = LarvaDataset(larva_dataset_file[0], new_generator(rng_seed),
subsets=subsets, **kwargs) subsets=subsets, **kwargs)
# initialize a Labels object
labels = dataset.labels labels = dataset.labels
assert 0 < len(labels) assert 0 < len(labels)
# the labels may be bytes objects; convert to str
labels = labels if isinstance(labels[0], str) else [s.decode() for s in labels] labels = labels if isinstance(labels[0], str) else [s.decode() for s in labels]
# copy and load the pretrained model into the model instance directory # copy and load the pretrained model into the model instance directory
if isinstance(pretrained_model_instance, str): if isinstance(pretrained_model_instance, str):
config_file = import_pretrained_model(backend, pretrained_model_instance) config_file = import_pretrained_model(backend, pretrained_model_instance)
...@@ -25,16 +35,26 @@ def train_model(backend, layers=1, pretrained_model_instance="default", ...@@ -25,16 +35,26 @@ def train_model(backend, layers=1, pretrained_model_instance="default",
pretrained_model_instances = pretrained_model_instance pretrained_model_instances = pretrained_model_instance
config_files = import_pretrained_models(backend, pretrained_model_instances) config_files = import_pretrained_models(backend, pretrained_model_instances)
model = make_trainer(config_files, labels, layers, iterations) model = make_trainer(config_files, labels, layers, iterations)
# fine-tune the model
# fine-tune the pretrained model on the loaded dataset
model.train(dataset) model.train(dataset)
# add post-prediction rule ABC -> AAC
# add post-prediction rule ABC -> AAC;
# see https://gitlab.pasteur.fr/nyx/larvatagger.jl/-/issues/62
model.clf_config['post_filters'] = ['ABC->AAC'] model.clf_config['post_filters'] = ['ABC->AAC']
# save the model # save the model
print(f"saving model \"{backend.model_instance}\"") print(f"saving model \"{backend.model_instance}\"")
model.save() model.save()
# TODO: merge the below two functions # TODO: merge the below two functions
"""
The files of the pretrained model are located in the `pretrained_models`
directory. Importing a pretrained model consists in creating a directory in
the `models` directory, named by the instance, and copying the model files.
The train step will make more files in the model instance directory.
"""
def import_pretrained_model(backend, pretrained_model_instance): def import_pretrained_model(backend, pretrained_model_instance):
pretrained_autoencoder_dir = backend.project_dir / "pretrained_models" / pretrained_model_instance pretrained_autoencoder_dir = backend.project_dir / "pretrained_models" / pretrained_model_instance
config_file = None config_file = None
......
Loading… If the file fails to load, please try again.
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment