Skip to content
Snippets Groups Projects
Commit d12bb23c authored by François  LAURENT's avatar François LAURENT
Browse files

pretrained_model_instance arg

parent 18710e82
No related branches found
No related tags found
No related merge requests found
...@@ -7,7 +7,7 @@ import torch ...@@ -7,7 +7,7 @@ import torch
import os import os
import glob import glob
def train_model(backend): def train_model(backend, pretrained_model_instance="default"):
# make_dataset generated or moved the larva_dataset file into data/interim/{instance}/ # make_dataset generated or moved the larva_dataset file into data/interim/{instance}/
#larva_dataset_file = backend.list_interim_files("larva_dataset_*.hdf5") # recursive #larva_dataset_file = backend.list_interim_files("larva_dataset_*.hdf5") # recursive
larva_dataset_file = glob.glob(str(backend.interim_data_dir() / "larva_dataset_*.hdf5")) # not recursive (faster) larva_dataset_file = glob.glob(str(backend.interim_data_dir() / "larva_dataset_*.hdf5")) # not recursive (faster)
...@@ -16,7 +16,7 @@ def train_model(backend): ...@@ -16,7 +16,7 @@ def train_model(backend):
nlabels = len(dataset.labels) nlabels = len(dataset.labels)
assert 0 < nlabels assert 0 < nlabels
# copy the pretrained model into the model instance directory # copy the pretrained model into the model instance directory
pretrained_autoencoder_dir = backend.model_dir() / ".." pretrained_autoencoder_dir = backend.model_dir() / "pretrained_models" / pretrained_model_instance
config_file = None config_file = None
for file in pretrained_autoencoder_dir.iterdir(): for file in pretrained_autoencoder_dir.iterdir():
if not file.is_file(): if not file.is_file():
...@@ -28,9 +28,9 @@ def train_model(backend): ...@@ -28,9 +28,9 @@ def train_model(backend):
dir = backend.model_dir().relative_to(backend.project_dir) dir = backend.model_dir().relative_to(backend.project_dir)
config["log_dir"] = str(dir) config["log_dir"] = str(dir)
# optional updates? # optional updates?
config["project_dir"] = config["exp_folder"] = str(dir) #config["project_dir"] = config["exp_folder"] = str(dir)
config["exp_name"] = backend.model_instance #config["exp_name"] = backend.model_instance
config["config"] = str(dir / os.path.basename(config["config"])) #config["config"] = str(dir / os.path.basename(config["config"]))
with open(str(dst), "w") as f: with open(str(dst), "w") as f:
json.dump(config, f, indent=2) json.dump(config, f, indent=2)
assert config_file is None assert config_file is None
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment