Commit 53daae46 authored by François LAURENT

additional comments

parent 5c9c32f7
@@ -22,9 +22,11 @@ def predict_model(backend, **kwargs):
     if not input_files:
         input_files = backend.list_input_files(group_by_directories=True)
     assert 0 < len(input_files)
+    # initialize output labels
     input_files_and_labels = backend.prepare_labels(input_files)
     assert 0 < len(input_files_and_labels)
+    # load the model
     model_files = backend.list_model_files()
     config_files = [file
@@ -46,14 +48,14 @@ def predict_model(backend, **kwargs):
         model = MultiscaleMaggotTrainer(config_file)
     else:
         model = MaggotBagging(config_files)
-    #
+    # call the `predict` logic on the input data files
     if len(input_files) == 1:
         input_files = next(iter(input_files.values()))
         if len(input_files) == 1:
             file = input_files[0]
             if file.name.startswith("larva_dataset_") and file.name.endswith(".hdf5"):
                 return predict_larva_dataset(backend, model, file, **kwargs)
     #
     predict_individual_data_files(backend, model, input_files_and_labels)

 def predict_individual_data_files(backend, model, input_files_and_labels):
...
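For context, a minimal sketch of the dispatch above, under the assumption that `backend.list_input_files(group_by_directories=True)` returns a mapping from directories to lists of files (inferred from the `next(iter(...))` call); the directory and file names are made up:

from pathlib import Path

# hypothetical return value of list_input_files(group_by_directories=True)
input_files = {Path("data/raw/run1"): [Path("data/raw/run1/larva_dataset_2023_05_10.hdf5")]}

if len(input_files) == 1:                     # exactly one input directory
    files = next(iter(input_files.values()))  # unwrap its list of files
    if len(files) == 1:                       # exactly one file in that directory
        file = files[0]
        # such a file is routed to predict_larva_dataset instead of per-file prediction
        assert file.name.startswith("larva_dataset_") and file.name.endswith(".hdf5")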
@@ -7,16 +7,26 @@ import glob
 def train_model(backend, layers=1, pretrained_model_instance="default",
         subsets=(1, 0, 0), rng_seed=None, iterations=1000, **kwargs):
-    # make_dataset generated or moved the larva_dataset file into data/interim/{instance}/
-    #larva_dataset_file = backend.list_interim_files("larva_dataset_*.hdf5") # recursive
-    larva_dataset_file = glob.glob(str(backend.interim_data_dir() / "larva_dataset_*.hdf5")) # not recursive (faster)
+    # list training data files;
+    # we actually expect a single larva_dataset file that make_dataset generated
+    # or moved into data/interim/{instance}/
+    #larva_dataset_file = backend.list_interim_files("larva_dataset_*.hdf5") # this one is recursive
+    larva_dataset_file = glob.glob(str(backend.interim_data_dir() / "larva_dataset_*.hdf5")) # this other one is not recursive
     assert len(larva_dataset_file) == 1
-    # subsets=(1, 0, 0) => all data are training data; no validation or test subsets
+    # instantiate a LarvaDataset object, which is similar to a PyTorch DataLoader
+    # and can initialize a Labels object;
+    # note: subsets=(1, 0, 0) => all data are training data; no validation or test subsets
     dataset = LarvaDataset(larva_dataset_file[0], new_generator(rng_seed),
                            subsets=subsets, **kwargs)
+    # initialize a Labels object
     labels = dataset.labels
     assert 0 < len(labels)
+    # the labels may be bytes objects; convert to str
     labels = labels if isinstance(labels[0], str) else [s.decode() for s in labels]
+    # copy and load the pretrained model into the model instance directory
     if isinstance(pretrained_model_instance, str):
         config_file = import_pretrained_model(backend, pretrained_model_instance)
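An illustration of the recursive/non-recursive distinction the new comments draw, using only the standard library; the interim directory path is a stand-in, and `backend.interim_data_dir()` is assumed to return a `pathlib.Path`:

import glob
from pathlib import Path

interim = Path("data/interim/instance")  # stand-in for backend.interim_data_dir()
# non-recursive: matches larva_dataset_*.hdf5 directly under interim/ only
glob.glob(str(interim / "larva_dataset_*.hdf5"))
# recursive: also descends into subdirectories, at extra cost
glob.glob(str(interim / "**" / "larva_dataset_*.hdf5"), recursive=True)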
@@ -25,16 +35,26 @@ def train_model(backend, layers=1, pretrained_model_instance="default",
         pretrained_model_instances = pretrained_model_instance
         config_files = import_pretrained_models(backend, pretrained_model_instances)
     model = make_trainer(config_files, labels, layers, iterations)
-    # fine-tune the model
+    # fine-tune the pretrained model on the loaded dataset
     model.train(dataset)
-    # add post-prediction rule ABC -> AAC
+    # add post-prediction rule ABC -> AAC;
+    # see https://gitlab.pasteur.fr/nyx/larvatagger.jl/-/issues/62
     model.clf_config['post_filters'] = ['ABC->AAC']
+    # save the model
     print(f"saving model \"{backend.model_instance}\"")
     model.save()
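A guess at what such a post-filter could do, as a standalone sketch. The actual semantics of 'ABC->AAC' live in the model code and the linked issue, so treat the following reading as an assumption: in a sequence of per-time-step labels, a single step B squeezed between two different behaviors A and C is merged into the preceding bout.

def apply_abc_to_aac(labels):
    # assumed reading of the 'ABC->AAC' rule: whenever three consecutive,
    # pairwise-distinct labels occur, re-assign the middle one to the first
    labels = list(labels)
    for i in range(1, len(labels) - 1):
        a, b, c = labels[i - 1], labels[i], labels[i + 1]
        if a != b and b != c and a != c:
            labels[i] = a
    return labels

print(apply_abc_to_aac(["run", "cast", "hunch"]))  # ['run', 'run', 'hunch']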
+# TODO: merge the below two functions
+"""
+The files of the pretrained model are located in the `pretrained_models`
+directory. Importing a pretrained model consists in creating a directory in
+the `models` directory, named after the instance, and copying the model files
+into it. The train step will create more files in the model instance directory.
+"""
 def import_pretrained_model(backend, pretrained_model_instance):
     pretrained_autoencoder_dir = backend.project_dir / "pretrained_models" / pretrained_model_instance
     config_file = None
...
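The body of `import_pretrained_model` is truncated above; a minimal sketch of the copy step the docstring describes, using only the standard library, with a hypothetical helper name and the directory layout taken from the docstring:

import shutil
from pathlib import Path

def copy_pretrained_files(project_dir, instance):
    # pretrained files live in pretrained_models/{instance} (per the docstring)
    src = Path(project_dir) / "pretrained_models" / instance
    # importing = creating models/{instance} and copying the model files into it
    dst = Path(project_dir) / "models" / instance
    dst.mkdir(parents=True, exist_ok=True)
    for file in src.iterdir():
        if file.is_file():
            shutil.copy2(file, dst / file.name)
    return dst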