From 3bb5da6daf24ecc1168ca74f401156310c403033 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Laurent?= <francois.laurent@posteo.net> Date: Mon, 17 Apr 2023 20:56:50 +0200 Subject: [PATCH] implements https://gitlab.pasteur.fr/nyx/larvatagger.jl/-/issues/110 --- src/maggotuba/models/modules.py | 1 - src/maggotuba/models/train_model.py | 5 ++++- src/maggotuba/models/trainers.py | 16 +++++++++++++++- 3 files changed, 19 insertions(+), 3 deletions(-) diff --git a/src/maggotuba/models/modules.py b/src/maggotuba/models/modules.py index 9ea3bba..19912ad 100644 --- a/src/maggotuba/models/modules.py +++ b/src/maggotuba/models/modules.py @@ -177,7 +177,6 @@ class MaggotEncoder(MaggotModule): try: encoder.load_state_dict(torch.load(path)) except Exception as e: - raise _reason = e config['load_state'] = False # for `was_pretrained` to properly work else: diff --git a/src/maggotuba/models/train_model.py b/src/maggotuba/models/train_model.py index 17d0ec0..c966538 100644 --- a/src/maggotuba/models/train_model.py +++ b/src/maggotuba/models/train_model.py @@ -1,6 +1,6 @@ from taggingbackends.data.labels import Labels from taggingbackends.data.dataset import LarvaDataset -from maggotuba.models.trainers import make_trainer, new_generator +from maggotuba.models.trainers import make_trainer, new_generator, enforce_reproducibility import glob def train_model(backend, layers=1, pretrained_model_instance="default", @@ -34,6 +34,9 @@ def train_model(backend, layers=1, pretrained_model_instance="default", # the labels may be bytes objects; convert to str labels = labels if isinstance(labels[0], str) else [s.decode() for s in labels] + # could be moved into `make_trainer`, but we need it to access the generator + enforce_reproducibility(dataset.generator) + # copy and load the pretrained model into the model instance directory model = make_trainer(backend, pretrained_model_instance, labels, layers, iterations) diff --git a/src/maggotuba/models/trainers.py b/src/maggotuba/models/trainers.py index 00f3040..c0fae21 100644 --- a/src/maggotuba/models/trainers.py +++ b/src/maggotuba/models/trainers.py @@ -2,7 +2,6 @@ import numpy as np import torch import torch.nn as nn from behavior_model.models.neural_nets import device -#import behavior_model.data.utils as data_utils from maggotuba.models.modules import SupervisedMaggot, MultiscaleSupervisedMaggot, MaggotBag from taggingbackends.features.skeleton import interpolate from taggingbackends.explorer import BackendExplorer, check_permissions @@ -252,6 +251,20 @@ def new_generator(seed=None): if seed is None: seed = 0b11010111001001101001110 return generator.manual_seed(seed) +def enforce_reproducibility(generator=None): + import random + if generator is None: + seed = 0b11010111001001101001110 + else: + seed = generator.initial_seed() + # see https://pytorch.org/docs/1.13/notes/randomness.html + torch.use_deterministic_algorithms(True) + # torch.backends.cudnn.deterministic = True + torch.manual_seed(seed) + seed = seed % 2**32 + np.random.seed(seed) + random.seed(seed) + class MultiscaleMaggotTrainer(MaggotTrainer): def __init__(self, cfgfilepath, behaviors=[], n_layers=1, n_iterations=None, @@ -314,6 +327,7 @@ def make_trainer(first_arg, *args, **kwargs): else: config_file = first_arg + #enforce_reproducibility() # the type criterion does not fail in the case of unimplemented bagging, # as config files are listed in a pretrained_models subdirectory. -- GitLab