From d12bb23c5e1a9b8a6bcca44ccc89fd42ab4bc1e2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Laurent?= <francois.laurent@posteo.net> Date: Mon, 1 Aug 2022 09:48:08 +0200 Subject: [PATCH] pretrained_model_instance arg --- .../default}/autoencoder_config.json | 0 .../default}/best_validated_encoder.pt | Bin src/maggotuba/models/train_model.py | 10 +++++----- 3 files changed, 5 insertions(+), 5 deletions(-) rename {models => pretrained_models/default}/autoencoder_config.json (100%) rename {models => pretrained_models/default}/best_validated_encoder.pt (100%) diff --git a/models/autoencoder_config.json b/pretrained_models/default/autoencoder_config.json similarity index 100% rename from models/autoencoder_config.json rename to pretrained_models/default/autoencoder_config.json diff --git a/models/best_validated_encoder.pt b/pretrained_models/default/best_validated_encoder.pt similarity index 100% rename from models/best_validated_encoder.pt rename to pretrained_models/default/best_validated_encoder.pt diff --git a/src/maggotuba/models/train_model.py b/src/maggotuba/models/train_model.py index 0401cd7..5527b97 100644 --- a/src/maggotuba/models/train_model.py +++ b/src/maggotuba/models/train_model.py @@ -7,7 +7,7 @@ import torch import os import glob -def train_model(backend): +def train_model(backend, pretrained_model_instance="default"): # make_dataset generated or moved the larva_dataset file into data/interim/{instance}/ #larva_dataset_file = backend.list_interim_files("larva_dataset_*.hdf5") # recursive larva_dataset_file = glob.glob(str(backend.interim_data_dir() / "larva_dataset_*.hdf5")) # not recursive (faster) @@ -16,7 +16,7 @@ def train_model(backend): nlabels = len(dataset.labels) assert 0 < nlabels # copy the pretrained model into the model instance directory - pretrained_autoencoder_dir = backend.model_dir() / ".." + pretrained_autoencoder_dir = backend.model_dir() / "pretrained_models" / pretrained_model_instance config_file = None for file in pretrained_autoencoder_dir.iterdir(): if not file.is_file(): @@ -28,9 +28,9 @@ def train_model(backend): dir = backend.model_dir().relative_to(backend.project_dir) config["log_dir"] = str(dir) # optional updates? - config["project_dir"] = config["exp_folder"] = str(dir) - config["exp_name"] = backend.model_instance - config["config"] = str(dir / os.path.basename(config["config"])) + #config["project_dir"] = config["exp_folder"] = str(dir) + #config["exp_name"] = backend.model_instance + #config["config"] = str(dir / os.path.basename(config["config"])) with open(str(dst), "w") as f: json.dump(config, f, indent=2) assert config_file is None -- GitLab