From 2646a8ebc298e5821b84a14b1bc6487ae5bc7221 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Laurent?= <francois.laurent@posteo.net> Date: Sun, 16 Apr 2023 22:37:39 +0200 Subject: [PATCH] fixes https://gitlab.pasteur.fr/nyx/larvatagger.jl/-/issues/112 --- LICENSE | 2 +- src/maggotuba/models/modules.py | 7 +++++-- src/maggotuba/models/train_model.py | 9 +++++++++ 3 files changed, 15 insertions(+), 3 deletions(-) diff --git a/LICENSE b/LICENSE index e047d11..1473c0c 100644 --- a/LICENSE +++ b/LICENSE @@ -1,6 +1,6 @@ MIT License -Copyright (c) 2022 François Laurent, Institut Pasteur +Copyright (c) 2022-2023 François Laurent, Institut Pasteur Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal diff --git a/src/maggotuba/models/modules.py b/src/maggotuba/models/modules.py index a598e01..9ea3bba 100644 --- a/src/maggotuba/models/modules.py +++ b/src/maggotuba/models/modules.py @@ -106,7 +106,9 @@ class MaggotModule(nn.Module): self.config[entry] = self.path_for_config(Path(self.config[entry])) def path_for_config(self, path): - if self.root_dir and path.is_absolute(): + if path.name.endswith('.pt'): + path = path.name + elif self.root_dir and path.is_absolute(): # Path.is_relative_to available from Python >= 3.9 only; # we want the present code to run on Python >= 3.8 relativepath = path.relative_to(self.root_dir) @@ -175,6 +177,7 @@ class MaggotEncoder(MaggotModule): try: encoder.load_state_dict(torch.load(path)) except Exception as e: + raise _reason = e config['load_state'] = False # for `was_pretrained` to properly work else: @@ -182,7 +185,7 @@ class MaggotEncoder(MaggotModule): # if state file not found or config option "load_state" is False, # (re-)initialize the model's weights if _reason: - logging.debug(f"initializing the encoder ({_reason})") + logging.info(f"initializing the encoder ({_reason})") _init, _bias = config.get('init', None), config.get('bias', None) for child in encoder.children(): if isinstance(child, diff --git a/src/maggotuba/models/train_model.py b/src/maggotuba/models/train_model.py index aedc215..17d0ec0 100644 --- a/src/maggotuba/models/train_model.py +++ b/src/maggotuba/models/train_model.py @@ -12,6 +12,15 @@ def train_model(backend, layers=1, pretrained_model_instance="default", larva_dataset_file = glob.glob(str(backend.interim_data_dir() / "larva_dataset_*.hdf5")) # this other one is not recursive assert len(larva_dataset_file) == 1 + # argument `rng_seed` predates `seed` + try: + seed = kwargs.pop('seed') + except KeyError: + pass + else: + if rng_seed is None: + rng_seed = seed + # instanciate a LarvaDataset object, that is similar to a PyTorch DataLoader # add can initialize a Labels object # note: subsets=(1, 0, 0) => all data are training data; no validation or test subsets -- GitLab