train_model.py

    from taggingbackends.data.labels import Labels
    from taggingbackends.data.dataset import LarvaDataset
    from maggotuba.models.trainers import make_trainer, new_generator, enforce_reproducibility
    import glob
    
    def train_model(backend, layers=1, pretrained_model_instance="default",
                    subsets=(1, 0, 0), rng_seed=None, iterations=1000, **kwargs):
        # list training data files;
        # we actually expect a single larva_dataset file that make_dataset generated
        # or moved into data/interim/{instance}/
        #larva_dataset_file = backend.list_interim_files("larva_dataset_*.hdf5") # this one is recursive
        larva_dataset_file = glob.glob(str(backend.interim_data_dir() / "larva_dataset_*.hdf5")) # this other one is not recursive
        assert len(larva_dataset_file) == 1
    
        # argument `rng_seed` predates `seed`
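        # `seed`, if passed, is honored only when `rng_seed` is not explicitly set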
        try:
            seed = kwargs.pop('seed')
        except KeyError:
            pass
        else:
            if rng_seed is None:
                rng_seed = seed
    
        # instantiate a LarvaDataset object, which is similar to a PyTorch DataLoader
        # and can initialize a Labels object
        # note: subsets=(1, 0, 0) => all data are training data; no validation or test subsets
        dataset = LarvaDataset(larva_dataset_file[0], new_generator(rng_seed),
                               subsets=subsets, **kwargs)
    
        # initialize a Labels object
        labels = dataset.labels
        assert 0 < len(labels)
    
        # the labels may be bytes objects; convert to str
        labels = labels if isinstance(labels[0], str) else [s.decode() for s in labels]
    
        # could be moved into `make_trainer`, but it needs access to the dataset's generator
        enforce_reproducibility(dataset.generator)
    
        # copy and load the pretrained model into the model instance directory
        model = make_trainer(backend, pretrained_model_instance, labels, layers, iterations)
    
        # fine-tune the pretrained model on the loaded dataset
        model.train(dataset)
    
        # add post-prediction rule ABC -> AAC;
        # see https://gitlab.pasteur.fr/nyx/larvatagger.jl/-/issues/62
        model.clf_config['post_filters'] = ['ABC->AAC']
    
        # save the model
        print(f"saving model \"{backend.model_instance}\"")
        model.save()
    
    
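    # command-line entry point: when this script is run by the TaggingBackends
    # launcher, `main` parses the command-line arguments and forwards them to
    # `train_model`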
    from taggingbackends.main import main
    
    if __name__ == "__main__":
        main(train_model)
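
For reference, the `seed`/`rng_seed` fallback and the bytes-to-str label conversion above can be illustrated with plain Python, independently of the taggingbackends and maggotuba packages. The sketch below is not part of the repository; the helper name `resolve_seed` and the sample labels are made up for illustration:

    def resolve_seed(rng_seed=None, **kwargs):
        # the legacy keyword `seed` is honored only when `rng_seed` is not given
        seed = kwargs.pop("seed", None)
        return rng_seed if rng_seed is not None else seed

    assert resolve_seed(seed=7) == 7
    assert resolve_seed(rng_seed=3, seed=7) == 3

    # labels read from the HDF5 larva_dataset file may be bytes; decode to str
    labels = [b"run", b"cast", "hunch"]
    labels = [s if isinstance(s, str) else s.decode() for s in labels]
    assert labels == ["run", "cast", "hunch"]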