Skip to content
Snippets Groups Projects
Commit 2646a8eb authored by François  LAURENT's avatar François LAURENT
Browse files
parent 599f3bb4
No related branches found
No related tags found
No related merge requests found
MIT License MIT License
Copyright (c) 2022 François Laurent, Institut Pasteur Copyright (c) 2022-2023 François Laurent, Institut Pasteur
Permission is hereby granted, free of charge, to any person obtaining a copy Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal of this software and associated documentation files (the "Software"), to deal
......
...@@ -106,7 +106,9 @@ class MaggotModule(nn.Module): ...@@ -106,7 +106,9 @@ class MaggotModule(nn.Module):
self.config[entry] = self.path_for_config(Path(self.config[entry])) self.config[entry] = self.path_for_config(Path(self.config[entry]))
def path_for_config(self, path): def path_for_config(self, path):
if self.root_dir and path.is_absolute(): if path.name.endswith('.pt'):
path = path.name
elif self.root_dir and path.is_absolute():
# Path.is_relative_to available from Python >= 3.9 only; # Path.is_relative_to available from Python >= 3.9 only;
# we want the present code to run on Python >= 3.8 # we want the present code to run on Python >= 3.8
relativepath = path.relative_to(self.root_dir) relativepath = path.relative_to(self.root_dir)
...@@ -175,6 +177,7 @@ class MaggotEncoder(MaggotModule): ...@@ -175,6 +177,7 @@ class MaggotEncoder(MaggotModule):
try: try:
encoder.load_state_dict(torch.load(path)) encoder.load_state_dict(torch.load(path))
except Exception as e: except Exception as e:
raise
_reason = e _reason = e
config['load_state'] = False # for `was_pretrained` to properly work config['load_state'] = False # for `was_pretrained` to properly work
else: else:
...@@ -182,7 +185,7 @@ class MaggotEncoder(MaggotModule): ...@@ -182,7 +185,7 @@ class MaggotEncoder(MaggotModule):
# if state file not found or config option "load_state" is False, # if state file not found or config option "load_state" is False,
# (re-)initialize the model's weights # (re-)initialize the model's weights
if _reason: if _reason:
logging.debug(f"initializing the encoder ({_reason})") logging.info(f"initializing the encoder ({_reason})")
_init, _bias = config.get('init', None), config.get('bias', None) _init, _bias = config.get('init', None), config.get('bias', None)
for child in encoder.children(): for child in encoder.children():
if isinstance(child, if isinstance(child,
......
...@@ -12,6 +12,15 @@ def train_model(backend, layers=1, pretrained_model_instance="default", ...@@ -12,6 +12,15 @@ def train_model(backend, layers=1, pretrained_model_instance="default",
larva_dataset_file = glob.glob(str(backend.interim_data_dir() / "larva_dataset_*.hdf5")) # this other one is not recursive larva_dataset_file = glob.glob(str(backend.interim_data_dir() / "larva_dataset_*.hdf5")) # this other one is not recursive
assert len(larva_dataset_file) == 1 assert len(larva_dataset_file) == 1
# argument `rng_seed` predates `seed`
try:
seed = kwargs.pop('seed')
except KeyError:
pass
else:
if rng_seed is None:
rng_seed = seed
# instanciate a LarvaDataset object, that is similar to a PyTorch DataLoader # instanciate a LarvaDataset object, that is similar to a PyTorch DataLoader
# add can initialize a Labels object # add can initialize a Labels object
# note: subsets=(1, 0, 0) => all data are training data; no validation or test subsets # note: subsets=(1, 0, 0) => all data are training data; no validation or test subsets
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment