From 3bb5da6daf24ecc1168ca74f401156310c403033 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Fran=C3=A7ois=20Laurent?= <francois.laurent@posteo.net>
Date: Mon, 17 Apr 2023 20:56:50 +0200
Subject: [PATCH] enforce reproducible training: seed torch/numpy/random and enable
 deterministic algorithms; drop stray `raise` that broke the pretrained-model
 load fallback (implements
 https://gitlab.pasteur.fr/nyx/larvatagger.jl/-/issues/110)

---
 src/maggotuba/models/modules.py     |  1 -
 src/maggotuba/models/train_model.py |  5 ++++-
 src/maggotuba/models/trainers.py    | 16 +++++++++++++++-
 3 files changed, 19 insertions(+), 3 deletions(-)

diff --git a/src/maggotuba/models/modules.py b/src/maggotuba/models/modules.py
index 9ea3bba..19912ad 100644
--- a/src/maggotuba/models/modules.py
+++ b/src/maggotuba/models/modules.py
@@ -177,7 +177,6 @@ class MaggotEncoder(MaggotModule):
             try:
                 encoder.load_state_dict(torch.load(path))
             except Exception as e:
-                raise
                 _reason = e
                 config['load_state'] = False # for `was_pretrained` to properly work
         else:
diff --git a/src/maggotuba/models/train_model.py b/src/maggotuba/models/train_model.py
index 17d0ec0..c966538 100644
--- a/src/maggotuba/models/train_model.py
+++ b/src/maggotuba/models/train_model.py
@@ -1,6 +1,6 @@
 from taggingbackends.data.labels import Labels
 from taggingbackends.data.dataset import LarvaDataset
-from maggotuba.models.trainers import make_trainer, new_generator
+from maggotuba.models.trainers import make_trainer, new_generator, enforce_reproducibility
 import glob
 
 def train_model(backend, layers=1, pretrained_model_instance="default",
@@ -34,6 +34,9 @@ def train_model(backend, layers=1, pretrained_model_instance="default",
     # the labels may be bytes objects; convert to str
     labels = labels if isinstance(labels[0], str) else [s.decode() for s in labels]
 
+    # could be moved into `make_trainer`, but we need it to access the generator
+    enforce_reproducibility(dataset.generator)
+
     # copy and load the pretrained model into the model instance directory
     model = make_trainer(backend, pretrained_model_instance, labels, layers, iterations)
 
diff --git a/src/maggotuba/models/trainers.py b/src/maggotuba/models/trainers.py
index 00f3040..c0fae21 100644
--- a/src/maggotuba/models/trainers.py
+++ b/src/maggotuba/models/trainers.py
@@ -2,7 +2,6 @@ import numpy as np
 import torch
 import torch.nn as nn
 from behavior_model.models.neural_nets import device
-#import behavior_model.data.utils as data_utils
 from maggotuba.models.modules import SupervisedMaggot, MultiscaleSupervisedMaggot, MaggotBag
 from taggingbackends.features.skeleton import interpolate
 from taggingbackends.explorer import BackendExplorer, check_permissions
@@ -252,6 +251,20 @@ def new_generator(seed=None):
     if seed is None: seed = 0b11010111001001101001110
     return generator.manual_seed(seed)
 
+def enforce_reproducibility(generator=None):
+    import random
+    if generator is None:
+        seed = 0b11010111001001101001110
+    else:
+        seed = generator.initial_seed()
+    # see https://pytorch.org/docs/1.13/notes/randomness.html
+    torch.use_deterministic_algorithms(True)
+    # torch.backends.cudnn.deterministic = True
+    torch.manual_seed(seed)
+    seed = seed % 2**32
+    np.random.seed(seed)
+    random.seed(seed)
+
 
 class MultiscaleMaggotTrainer(MaggotTrainer):
     def __init__(self, cfgfilepath, behaviors=[], n_layers=1, n_iterations=None,
@@ -314,6 +327,7 @@ def make_trainer(first_arg, *args, **kwargs):
 
     else:
         config_file = first_arg
+        #enforce_reproducibility()
 
         # the type criterion does not fail in the case of unimplemented bagging,
         # as config files are listed in a pretrained_models subdirectory.
-- 
GitLab