Select Git revision
train_model.py
train_model.py 2.21 KiB
from taggingbackends.data.labels import Labels
from taggingbackends.data.dataset import LarvaDataset
from maggotuba.models.trainers import make_trainer, new_generator, enforce_reproducibility
import glob
def train_model(backend, layers=1, pretrained_model_instance="default",
                subsets=(1, 0, 0), rng_seed=None, iterations=1000, **kwargs):
    """Fine-tune a pretrained model on the interim larva dataset and save it.

    Parameters
    ----------
    backend :
        Backend object; must provide `interim_data_dir()` (a path-like that
        supports the `/` operator) and `model_instance`. It is also handed
        over to `make_trainer`.
    layers, pretrained_model_instance, iterations :
        Forwarded unchanged to `make_trainer`.
    subsets :
        Forwarded to `LarvaDataset`. The default `(1, 0, 0)` means that all
        data are training data, with no validation or test subsets.
    rng_seed :
        Seed passed to `new_generator`. Keyword `seed` is accepted as an
        alias; `rng_seed` takes precedence when both are given.
    **kwargs :
        Extra keyword arguments forwarded to `LarvaDataset` (`seed`, if
        present, is consumed here and not forwarded).

    Raises
    ------
    FileNotFoundError
        If no `larva_dataset_*.hdf5` file is found in the interim directory.
    ValueError
        If several such files are found, or if the dataset has no labels.
    """
    # list training data files;
    # we expect a single larva_dataset file that make_dataset generated
    # or moved into data/interim/{instance}/
    # note: a plain glob is used because it is not recursive, unlike
    # `backend.list_interim_files`
    pattern = str(backend.interim_data_dir() / "larva_dataset_*.hdf5")
    larva_dataset_files = glob.glob(pattern)
    # explicit exceptions instead of `assert`, so that the checks survive
    # `python -O` and report what went wrong
    if not larva_dataset_files:
        raise FileNotFoundError(
            f"no larva_dataset file found; pattern: {pattern}")
    if len(larva_dataset_files) > 1:
        raise ValueError(
            "a single larva_dataset file is expected; "
            f"found {len(larva_dataset_files)}: {sorted(larva_dataset_files)}")
    larva_dataset_file = larva_dataset_files[0]

    # argument `rng_seed` predates `seed`;
    # `seed` is always removed from kwargs, and used only if `rng_seed` is not set
    seed = kwargs.pop('seed', None)
    if rng_seed is None:
        rng_seed = seed

    # instantiate a LarvaDataset object, that is similar to a PyTorch DataLoader
    # and can initialize a Labels object
    # note: subsets=(1, 0, 0) => all data are training data; no validation or test subsets
    dataset = LarvaDataset(larva_dataset_file, new_generator(rng_seed),
                           subsets=subsets, **kwargs)

    # initialize a Labels object
    labels = dataset.labels
    # `len` rather than truthiness, as `labels` may not be a plain list
    if len(labels) == 0:
        raise ValueError(f"no labels found in file: {larva_dataset_file}")
    # the labels may be bytes objects; convert to str
    # (the type of the first label is assumed to be representative)
    if not isinstance(labels[0], str):
        labels = [s.decode() for s in labels]

    # could be moved into `make_trainer`, but we need it to access the generator
    enforce_reproducibility(dataset.generator)

    # copy and load the pretrained model into the model instance directory
    model = make_trainer(backend, pretrained_model_instance, labels, layers, iterations)

    # fine-tune the pretrained model on the loaded dataset
    model.train(dataset)

    # add post-prediction rule ABC -> AAC;
    # see https://gitlab.pasteur.fr/nyx/larvatagger.jl/-/issues/62
    model.clf_config['post_filters'] = ['ABC->AAC']

    # save the model
    print(f"saving model \"{backend.model_instance}\"")
    model.save()
from taggingbackends.main import main
# Script entry point: when this file is executed directly (not imported),
# pass `train_model` to `main`, imported from `taggingbackends.main`.
if __name__ == "__main__":
    main(train_model)