From 2646a8ebc298e5821b84a14b1bc6487ae5bc7221 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Fran=C3=A7ois=20Laurent?= <francois.laurent@posteo.net>
Date: Sun, 16 Apr 2023 22:37:39 +0200
Subject: [PATCH] fixes
 https://gitlab.pasteur.fr/nyx/larvatagger.jl/-/issues/112

---
 LICENSE                             | 2 +-
 src/maggotuba/models/modules.py     | 7 +++++--
 src/maggotuba/models/train_model.py | 9 +++++++++
 3 files changed, 15 insertions(+), 3 deletions(-)

diff --git a/LICENSE b/LICENSE
index e047d11..1473c0c 100644
--- a/LICENSE
+++ b/LICENSE
@@ -1,6 +1,6 @@
 MIT License
 
-Copyright (c) 2022 François Laurent, Institut Pasteur
+Copyright (c) 2022-2023 François Laurent, Institut Pasteur
 
 Permission is hereby granted, free of charge, to any person obtaining a copy
 of this software and associated documentation files (the "Software"), to deal
diff --git a/src/maggotuba/models/modules.py b/src/maggotuba/models/modules.py
index a598e01..9ea3bba 100644
--- a/src/maggotuba/models/modules.py
+++ b/src/maggotuba/models/modules.py
@@ -106,7 +106,9 @@ class MaggotModule(nn.Module):
             self.config[entry] = self.path_for_config(Path(self.config[entry]))
 
     def path_for_config(self, path):
-        if self.root_dir and path.is_absolute():
+        if path.name.endswith('.pt'):
+            path = path.name
+        elif self.root_dir and path.is_absolute():
             # Path.is_relative_to available from Python >= 3.9 only;
             # we want the present code to run on Python >= 3.8
             relativepath = path.relative_to(self.root_dir)
@@ -182,7 +184,7 @@ class MaggotEncoder(MaggotModule):
         # if state file not found or config option "load_state" is False,
         # (re-)initialize the model's weights
         if _reason:
-            logging.debug(f"initializing the encoder ({_reason})")
+            logging.info(f"initializing the encoder ({_reason})")
             _init, _bias = config.get('init', None), config.get('bias', None)
             for child in encoder.children():
                 if isinstance(child,
diff --git a/src/maggotuba/models/train_model.py b/src/maggotuba/models/train_model.py
index aedc215..17d0ec0 100644
--- a/src/maggotuba/models/train_model.py
+++ b/src/maggotuba/models/train_model.py
@@ -12,6 +12,15 @@ def train_model(backend, layers=1, pretrained_model_instance="default",
     larva_dataset_file = glob.glob(str(backend.interim_data_dir() / "larva_dataset_*.hdf5")) # this other one is not recursive
     assert len(larva_dataset_file) == 1
 
+    # argument `rng_seed` predates `seed`
+    try:
+        seed = kwargs.pop('seed')
+    except KeyError:
+        pass
+    else:
+        if rng_seed is None:
+            rng_seed = seed
+
     # instanciate a LarvaDataset object, that is similar to a PyTorch DataLoader
     # add can initialize a Labels object
     # note: subsets=(1, 0, 0) => all data are training data; no validation or test subsets
-- 
GitLab