diff --git a/LICENSE b/LICENSE index e047d11a0921d8543224f4dce4482b27c6b8057f..1473c0cede580821495a5e623fb05b8985616a6d 100644 --- a/LICENSE +++ b/LICENSE @@ -1,6 +1,6 @@ MIT License -Copyright (c) 2022 François Laurent, Institut Pasteur +Copyright (c) 2022-2023 François Laurent, Institut Pasteur Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal diff --git a/src/LarvaDatasets.jl b/src/LarvaDatasets.jl index 7580a90cc75270e79e07143b50d64bc4444f981a..b531b73254fe1606624399e3a95467e1e656a4f0 100644 --- a/src/LarvaDatasets.jl +++ b/src/LarvaDatasets.jl @@ -266,9 +266,11 @@ function thresholdedcounts(counts; majorityweight=20) end function write_larva_dataset_hdf5(path, counts, files, refs, nsteps_before, nsteps_after; - fixmwt=false, frameinterval=nothing, includeall=nothing, + fixmwt=false, frameinterval=nothing, includeall=nothing, seed=nothing, ) fixmwt && @warn "`fixmwt=true` is no longer supported" + @debug "Seeding the random number generator" seed + isnothing(seed) || Random.seed!(seed) # this method mutates argument `refs` T = eltype(keys(counts)) refs′= Tuple{Int, Int, Int, T}[] @@ -480,7 +482,8 @@ end """ write_larva_dataset_hdf5(output_directory, input_files, window_length=20) write_larva_dataset_hdf5(...; labels=nothing, labelpointers=nothing) - write_larva_dataset_hdf5(...; sample_size=nothing, balance=true, chunks=false, shallow=false) + write_larva_dataset_hdf5(...; sample_size=nothing, balance=true, seed=nothing) + write_larva_dataset_hdf5(...; chunks=false, shallow=false) write_larva_dataset_hdf5(...; file_filter, timestep_filter) Sample series of 5-point spines from data files and save them in a hdf5 file, @@ -542,7 +545,8 @@ function write_larva_dataset_hdf5(output_dir::String, balance=true, fixmwt=false, frameinterval=nothing, - includeall="edited") + includeall="edited", + seed=nothing) files = if input_data isa String repository = input_data labelledfiles(repository, chunks; selection_rule=file_filter, shallow=shallow) @@ -601,7 +605,7 @@ function write_larva_dataset_hdf5(output_dir::String, write_larva_dataset_hdf5(output_file, sample_sizes, files, refs, nsteps_before, nsteps_after; fixmwt=fixmwt, frameinterval=frameinterval, - includeall=includeall) + includeall=includeall, seed=seed) h5open(output_file, "cw") do h5 attributes(h5["samples"])["len_traj"] = window_length diff --git a/src/taggingbackends/explorer.py b/src/taggingbackends/explorer.py index 3d59007aa63d1ac740a1b702e074040ff5f06959..7c30b065a563f825e2cb07a4fdd4b3d958787dc8 100644 --- a/src/taggingbackends/explorer.py +++ b/src/taggingbackends/explorer.py @@ -474,7 +474,7 @@ run `poetry add {pkg}` from directory: \n def generate_dataset(self, input_files, labels=None, window_length=20, sample_size=None, balance=True, - include_all=None, frame_interval=None): + include_all=None, frame_interval=None, seed=None): """ Generate a *larva_dataset hdf5* file in data/interim/{instance}/ """ @@ -486,7 +486,8 @@ run `poetry add {pkg}` from directory: \n sample_size=sample_size, balance=balance, includeall=include_all, - frameinterval=frame_interval) + frameinterval=frame_interval, + seed=seed) def compile_trxmat_database(self, input_dir, labels=None, window_length=20, sample_size=None, reuse_h5files=False): diff --git a/src/taggingbackends/main.py b/src/taggingbackends/main.py index f530b149389f3eaba15cca571a46603acb349eed..e59ce4d929a82b22d170f983ae87d99c9f181d17 100644 --- a/src/taggingbackends/main.py +++ b/src/taggingbackends/main.py @@ -12,6 +12,7 @@ Usage: tagging-backend [train|predict] --model-instance <name> tagging-backend train ... --pretrained-model-instance <name> tagging-backend train ... --include-all <secondary-label> tagging-backend train ... --skip-make-dataset --skip-build-features + tagging-backend train ... --seed <seed> tagging-backend predict ... --make-dataset --build-features tagging-backend predict ... --sandbox <token> tagging-backend --help @@ -99,6 +100,7 @@ def main(fun=None): balancing_strategy = 'auto' include_all = None class_weights = None + seed = None unknown_args = {} k = 2 while k < len(sys.argv): @@ -150,6 +152,9 @@ def main(fun=None): elif sys.argv[k] == "--include-all": k = k + 1 include_all = sys.argv[k] + elif sys.argv[k] == "--seed": + k = k + 1 + seed = sys.argv[k] else: unknown_args[sys.argv[k].lstrip('-').replace('-', '_')] = sys.argv[k+1] k = k + 1 @@ -187,6 +192,8 @@ def main(fun=None): logging.info("option --reuse-h5files is ignored in the absence of --trxmat-only") if include_all: make_dataset_kwargs["include_all"] = include_all + if seed is not None: + make_dataset_kwargs["seed"] = seed backend._run_script(backend.make_dataset, **make_dataset_kwargs) if build_features: backend._run_script(backend.build_features) @@ -198,6 +205,8 @@ def main(fun=None): train_kwargs["pretrained_model_instance"] = pretrained_model_instance if class_weights: train_kwargs['class_weights'] = class_weights + if seed is not None: + train_kwargs['seed'] = seed backend._run_script(backend.train_model, trailing=unknown_args, **train_kwargs) else: # called by make_dataset, build_features, train_model and predict_model