diff --git a/LICENSE b/LICENSE index e047d11a0921d8543224f4dce4482b27c6b8057f..1473c0cede580821495a5e623fb05b8985616a6d 100644 --- a/LICENSE +++ b/LICENSE @@ -1,6 +1,6 @@ MIT License -Copyright (c) 2022 François Laurent, Institut Pasteur +Copyright (c) 2022-2023 François Laurent, Institut Pasteur Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal diff --git a/src/maggotuba/models/modules.py b/src/maggotuba/models/modules.py index a598e01eff69e95ee964142fb14d9a2f921e1b6b..9ea3bbaaff29d6c18d3b41a1e6dbbb16fae6575e 100644 --- a/src/maggotuba/models/modules.py +++ b/src/maggotuba/models/modules.py @@ -106,7 +106,9 @@ class MaggotModule(nn.Module): self.config[entry] = self.path_for_config(Path(self.config[entry])) def path_for_config(self, path): - if self.root_dir and path.is_absolute(): + if path.name.endswith('.pt'): + path = path.name + elif self.root_dir and path.is_absolute(): # Path.is_relative_to available from Python >= 3.9 only; # we want the present code to run on Python >= 3.8 relativepath = path.relative_to(self.root_dir) @@ -175,6 +177,7 @@ class MaggotEncoder(MaggotModule): try: encoder.load_state_dict(torch.load(path)) except Exception as e: + raise _reason = e config['load_state'] = False # for `was_pretrained` to properly work else: @@ -182,7 +185,7 @@ class MaggotEncoder(MaggotModule): # if state file not found or config option "load_state" is False, # (re-)initialize the model's weights if _reason: - logging.debug(f"initializing the encoder ({_reason})") + logging.info(f"initializing the encoder ({_reason})") _init, _bias = config.get('init', None), config.get('bias', None) for child in encoder.children(): if isinstance(child, diff --git a/src/maggotuba/models/train_model.py b/src/maggotuba/models/train_model.py index aedc21506816c90214ec0c7737c083461d32d366..17d0ec098f3b992b09f04cc080531d2f69147bdc 100644 --- a/src/maggotuba/models/train_model.py +++ b/src/maggotuba/models/train_model.py @@ -12,6 +12,15 @@ def train_model(backend, layers=1, pretrained_model_instance="default", larva_dataset_file = glob.glob(str(backend.interim_data_dir() / "larva_dataset_*.hdf5")) # this other one is not recursive assert len(larva_dataset_file) == 1 + # argument `rng_seed` predates `seed` + try: + seed = kwargs.pop('seed') + except KeyError: + pass + else: + if rng_seed is None: + rng_seed = seed + # instanciate a LarvaDataset object, that is similar to a PyTorch DataLoader # add can initialize a Labels object # note: subsets=(1, 0, 0) => all data are training data; no validation or test subsets