diff --git a/models/autoencoder_config.json b/pretrained_models/default/autoencoder_config.json similarity index 100% rename from models/autoencoder_config.json rename to pretrained_models/default/autoencoder_config.json diff --git a/models/best_validated_encoder.pt b/pretrained_models/default/best_validated_encoder.pt similarity index 100% rename from models/best_validated_encoder.pt rename to pretrained_models/default/best_validated_encoder.pt diff --git a/src/maggotuba/models/train_model.py b/src/maggotuba/models/train_model.py index 0401cd76cba4377ddff3fd0856fe30646ebc5b28..5527b97429f2933f1afcdd178838bed557a12199 100644 --- a/src/maggotuba/models/train_model.py +++ b/src/maggotuba/models/train_model.py @@ -7,7 +7,7 @@ import torch import os import glob -def train_model(backend): +def train_model(backend, pretrained_model_instance="default"): # make_dataset generated or moved the larva_dataset file into data/interim/{instance}/ #larva_dataset_file = backend.list_interim_files("larva_dataset_*.hdf5") # recursive larva_dataset_file = glob.glob(str(backend.interim_data_dir() / "larva_dataset_*.hdf5")) # not recursive (faster) @@ -16,7 +16,7 @@ def train_model(backend): nlabels = len(dataset.labels) assert 0 < nlabels # copy the pretrained model into the model instance directory - pretrained_autoencoder_dir = backend.model_dir() / ".." + pretrained_autoencoder_dir = backend.model_dir() / "pretrained_models" / pretrained_model_instance config_file = None for file in pretrained_autoencoder_dir.iterdir(): if not file.is_file(): @@ -28,9 +28,9 @@ def train_model(backend): dir = backend.model_dir().relative_to(backend.project_dir) config["log_dir"] = str(dir) # optional updates? - config["project_dir"] = config["exp_folder"] = str(dir) - config["exp_name"] = backend.model_instance - config["config"] = str(dir / os.path.basename(config["config"])) + #config["project_dir"] = config["exp_folder"] = str(dir) + #config["exp_name"] = backend.model_instance + #config["config"] = str(dir / os.path.basename(config["config"])) with open(str(dst), "w") as f: json.dump(config, f, indent=2) assert config_file is None