Commit 3bb5da6d authored by François LAURENT

implements larvatagger.jl#110

parent 2646a8eb
@@ -177,7 +177,6 @@ class MaggotEncoder(MaggotModule):
         try:
             encoder.load_state_dict(torch.load(path))
         except Exception as e:
-            raise
             _reason = e
             config['load_state'] = False # for `was_pretrained` to properly work
         else:
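With `raise` removed, a failed `encoder.load_state_dict` no longer aborts model construction: the exception is kept in `_reason` and `load_state` is cleared so that `was_pretrained` can report the failure. A minimal sketch of the resulting pattern (the function name, return value, and `else` branch are hypothetical, not from the excerpt):

    import torch

    def try_load_pretrained(encoder, path, config):
        # Restore pretrained weights if possible; record why it failed otherwise.
        reason = None
        try:
            encoder.load_state_dict(torch.load(path))
        except Exception as e:
            reason = e
            config['load_state'] = False  # for `was_pretrained` to properly work
        else:
            config['load_state'] = True   # assumed; the real branch is elided above
        return reason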
 from taggingbackends.data.labels import Labels
 from taggingbackends.data.dataset import LarvaDataset
-from maggotuba.models.trainers import make_trainer, new_generator
+from maggotuba.models.trainers import make_trainer, new_generator, enforce_reproducibility
 import glob

 def train_model(backend, layers=1, pretrained_model_instance="default",
@@ -34,6 +34,9 @@ def train_model(backend, layers=1, pretrained_model_instance="default",
     # the labels may be bytes objects; convert to str
     labels = labels if isinstance(labels[0], str) else [s.decode() for s in labels]
+    # could be moved into `make_trainer`, but we need it to access the generator
+    enforce_reproducibility(dataset.generator)
     # copy and load the pretrained model into the model instance directory
     model = make_trainer(backend, pretrained_model_instance, labels, layers, iterations)
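`enforce_reproducibility` is called with the dataset's `torch.Generator` so that the global RNGs are seeded consistently with the generator already driving data loading. A sketch of the intended call order, assuming the dataset is built with a generator from `new_generator` (the `LarvaDataset` arguments are elided here):

    from maggotuba.models.trainers import (
        make_trainer, new_generator, enforce_reproducibility)

    generator = new_generator()         # torch.Generator seeded with the default seed
    # dataset = LarvaDataset(..., generator=generator)   # arguments elided
    enforce_reproducibility(generator)  # seeds torch/numpy/random from initial_seed()
    # model = make_trainer(...)         # training now runs with deterministic algorithms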
@@ -2,7 +2,6 @@ import numpy as np
 import torch
 import torch.nn as nn
 from behavior_model.models.neural_nets import device
-#import behavior_model.data.utils as data_utils
 from maggotuba.models.modules import SupervisedMaggot, MultiscaleSupervisedMaggot, MaggotBag
 from taggingbackends.features.skeleton import interpolate
 from taggingbackends.explorer import BackendExplorer, check_permissions
@@ -252,6 +251,20 @@ def new_generator(seed=None):
     if seed is None: seed = 0b11010111001001101001110
     return generator.manual_seed(seed)

+def enforce_reproducibility(generator=None):
+    import random
+    if generator is None:
+        seed = 0b11010111001001101001110
+    else:
+        seed = generator.initial_seed()
+    # see https://pytorch.org/docs/1.13/notes/randomness.html
+    torch.use_deterministic_algorithms(True)
+    # torch.backends.cudnn.deterministic = True
+    torch.manual_seed(seed)
+    seed = seed % 2**32
+    np.random.seed(seed)
+    random.seed(seed)
+
 class MultiscaleMaggotTrainer(MaggotTrainer):
     def __init__(self, cfgfilepath, behaviors=[], n_layers=1, n_iterations=None,
@@ -314,6 +327,7 @@ def make_trainer(first_arg, *args, **kwargs):
     else:
         config_file = first_arg
+    #enforce_reproducibility()
     # the type criterion does not fail in the case of unimplemented bagging,
     # as config files are listed in a pretrained_models subdirectory.
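For reference on the `seed % 2**32` step: `torch.Generator.initial_seed()` can return a 64-bit value, while `numpy.random.seed` rejects anything outside [0, 2**32), so the seed is reduced before seeding NumPy. A standalone sketch of the same recipe (the function name is hypothetical; behaviour per the PyTorch 1.13 randomness notes linked above):

    import random
    import numpy as np
    import torch

    def seed_all(seed=0b11010111001001101001110):
        # free-standing variant of `enforce_reproducibility`, for illustration
        torch.use_deterministic_algorithms(True)  # error out on nondeterministic ops
        torch.manual_seed(seed)                   # torch accepts 64-bit seeds
        np.random.seed(seed % 2**32)              # numpy needs a seed in [0, 2**32)
        random.seed(seed)

    seed_all()
    print(torch.rand(2), np.random.rand(2), random.random())  # identical across runs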