Skip to content
Snippets Groups Projects
Commit 0fe18ae2 authored by François  LAURENT's avatar François LAURENT
Browse files

pretrained_model_instance arg to train_model + copy arg to move_to_raw

parent b7fa1f91
Branches
Tags
No related merge requests found
Pipeline #86070 passed
...@@ -267,18 +267,22 @@ run `poetry add {pkg}` from directory: \n ...@@ -267,18 +267,22 @@ run `poetry add {pkg}` from directory: \n
model_instance, model_instance,
create_if_missing) create_if_missing)
def move_to_raw(self, source): def move_to_raw(self, source, copy=True):
source = pathlib.Path(source) source = pathlib.Path(source)
output_dir = self.raw_data_dir() output_dir = self.raw_data_dir()
destination = output_dir / source.name destination = output_dir / source.name
destination.parent.mkdir(parents=True, exist_ok=True) destination.parent.mkdir(parents=True, exist_ok=True)
if not destination.exists() or not destination.samefile(source):
if copy:
# Q: can we simply rename the file in the case source and dest are # Q: can we simply rename the file in the case source and dest are
# located on different partitions? # located on different partitions?
if not destination.exists() or not destination.samefile(source):
with destination.open('wb') as f: with destination.open('wb') as f:
with source.open('rb') as g: with source.open('rb') as g:
f.write(g.read()) f.write(g.read())
source.unlink() source.unlink()
else:
destination.unlink(missing_ok=True)
destination.symlink_to(source)
def move_to_interim(self, source, destination=None, copy=False): def move_to_interim(self, source, destination=None, copy=False):
""" """
...@@ -307,6 +311,7 @@ run `poetry add {pkg}` from directory: \n ...@@ -307,6 +311,7 @@ run `poetry add {pkg}` from directory: \n
with source.open('rb') as g: with source.open('rb') as g:
f.write(g.read()) f.write(g.read())
else: else:
destination.unlink(missing_ok=True)
destination.symlink_to(source) destination.symlink_to(source)
def move_to_processed(self, source, destination=None, copy=False): def move_to_processed(self, source, destination=None, copy=False):
......
...@@ -9,6 +9,7 @@ Usage: tagging-backend [train|predict] --model-instance <name> ...@@ -9,6 +9,7 @@ Usage: tagging-backend [train|predict] --model-instance <name>
tagging-backend train ... --labels <comma-separated-list> tagging-backend train ... --labels <comma-separated-list>
tagging-backend train ... --sample-size <N> --window-length <T> tagging-backend train ... --sample-size <N> --window-length <T>
tagging-backend train ... --trxmat-only --reuse-h5files tagging-backend train ... --trxmat-only --reuse-h5files
tagging-backend train ... --pretrained-model-instance <name>
tagging-backend predict ... --skip-make-dataset tagging-backend predict ... --skip-make-dataset
`tagging-backend` typically is run using `poetry run`. `tagging-backend` typically is run using `poetry run`.
...@@ -50,6 +51,7 @@ def main(fun=None): ...@@ -50,6 +51,7 @@ def main(fun=None):
sample_size = window_length = None sample_size = window_length = None
trxmat_only = reuse_h5files = False trxmat_only = reuse_h5files = False
skip_make_dataset = skip_build_features = False skip_make_dataset = skip_build_features = False
pretrained_model_instance = None
k = 2 k = 2
while k < len(sys.argv): while k < len(sys.argv):
if sys.argv[k] == "--project-dir": if sys.argv[k] == "--project-dir":
...@@ -78,6 +80,9 @@ def main(fun=None): ...@@ -78,6 +80,9 @@ def main(fun=None):
skip_make_dataset = True skip_make_dataset = True
elif sys.argv[k] == "--skip-build-features": elif sys.argv[k] == "--skip-build-features":
skip_build_features = True skip_build_features = True
elif sys.argv[k] == "--pretrained-model-instance":
k = k + 1
pretrained_model_instance = sys.argv[k]
else: else:
logging.warning(f"unsupported argument '{sys.argv[k]}'") logging.warning(f"unsupported argument '{sys.argv[k]}'")
k = k + 1 k = k + 1
...@@ -102,9 +107,13 @@ def main(fun=None): ...@@ -102,9 +107,13 @@ def main(fun=None):
backend._run_script(backend.make_dataset, **make_dataset_kwargs) backend._run_script(backend.make_dataset, **make_dataset_kwargs)
if not skip_build_features: if not skip_build_features:
backend._run_script(backend.build_features) backend._run_script(backend.build_features)
backend._run_script(backend.predict_model \ if train_or_predict == "predict":
if train_or_predict == "predict" \ backend._run_script(backend.predict_model)
else backend.train_model) else:
train_kwargs = {}
if pretrained_model_instance:
train_kwargs["pretrained_model_instance"] = pretrained_model_instance
backend._run_script(backend.train_model, **train_kwargs)
else: else:
# called by make_dataset, build_features, train_model and predict_model # called by make_dataset, build_features, train_model and predict_model
backend = BackendExplorerDecoder().decode(sys.argv[1]) backend = BackendExplorerDecoder().decode(sys.argv[1])
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment