Skip to content
Snippets Groups Projects
Commit ddf30bc1 authored by François LAURENT
Browse files

implements larvatagger.jl#110

parent 014b6021
No related branches found
No related tags found
No related merge requests found
Pipeline #102145 passed
MIT License
Copyright (c) 2022 François Laurent, Institut Pasteur
Copyright (c) 2022-2023 François Laurent, Institut Pasteur
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
......
......@@ -266,9 +266,11 @@ function thresholdedcounts(counts; majorityweight=20)
end
function write_larva_dataset_hdf5(path, counts, files, refs, nsteps_before, nsteps_after;
fixmwt=false, frameinterval=nothing, includeall=nothing,
fixmwt=false, frameinterval=nothing, includeall=nothing, seed=nothing,
)
fixmwt && @warn "`fixmwt=true` is no longer supported"
@debug "Seeding the random number generator" seed
isnothing(seed) || Random.seed!(seed)
# this method mutates argument `refs`
T = eltype(keys(counts))
refs′= Tuple{Int, Int, Int, T}[]
......@@ -480,7 +482,8 @@ end
"""
write_larva_dataset_hdf5(output_directory, input_files, window_length=20)
write_larva_dataset_hdf5(...; labels=nothing, labelpointers=nothing)
write_larva_dataset_hdf5(...; sample_size=nothing, balance=true, chunks=false, shallow=false)
write_larva_dataset_hdf5(...; sample_size=nothing, balance=true, seed=nothing)
write_larva_dataset_hdf5(...; chunks=false, shallow=false)
write_larva_dataset_hdf5(...; file_filter, timestep_filter)
Sample series of 5-point spines from data files and save them in a hdf5 file,
......@@ -542,7 +545,8 @@ function write_larva_dataset_hdf5(output_dir::String,
balance=true,
fixmwt=false,
frameinterval=nothing,
includeall="edited")
includeall="edited",
seed=nothing)
files = if input_data isa String
repository = input_data
labelledfiles(repository, chunks; selection_rule=file_filter, shallow=shallow)
......@@ -601,7 +605,7 @@ function write_larva_dataset_hdf5(output_dir::String,
write_larva_dataset_hdf5(output_file,
sample_sizes, files, refs, nsteps_before, nsteps_after;
fixmwt=fixmwt, frameinterval=frameinterval,
includeall=includeall)
includeall=includeall, seed=seed)
h5open(output_file, "cw") do h5
attributes(h5["samples"])["len_traj"] = window_length
......
......@@ -474,7 +474,7 @@ run `poetry add {pkg}` from directory: \n
def generate_dataset(self, input_files,
labels=None, window_length=20, sample_size=None, balance=True,
include_all=None, frame_interval=None):
include_all=None, frame_interval=None, seed=None):
"""
Generate a *larva_dataset hdf5* file in data/interim/{instance}/
"""
......@@ -486,7 +486,8 @@ run `poetry add {pkg}` from directory: \n
sample_size=sample_size,
balance=balance,
includeall=include_all,
frameinterval=frame_interval)
frameinterval=frame_interval,
seed=seed)
def compile_trxmat_database(self, input_dir,
labels=None, window_length=20, sample_size=None, reuse_h5files=False):
......
......@@ -12,6 +12,7 @@ Usage: tagging-backend [train|predict] --model-instance <name>
tagging-backend train ... --pretrained-model-instance <name>
tagging-backend train ... --include-all <secondary-label>
tagging-backend train ... --skip-make-dataset --skip-build-features
tagging-backend train ... --seed <seed>
tagging-backend predict ... --make-dataset --build-features
tagging-backend predict ... --sandbox <token>
tagging-backend --help
......@@ -99,6 +100,7 @@ def main(fun=None):
balancing_strategy = 'auto'
include_all = None
class_weights = None
seed = None
unknown_args = {}
k = 2
while k < len(sys.argv):
......@@ -150,6 +152,9 @@ def main(fun=None):
elif sys.argv[k] == "--include-all":
k = k + 1
include_all = sys.argv[k]
elif sys.argv[k] == "--seed":
k = k + 1
seed = sys.argv[k]
else:
unknown_args[sys.argv[k].lstrip('-').replace('-', '_')] = sys.argv[k+1]
k = k + 1
......@@ -187,6 +192,8 @@ def main(fun=None):
logging.info("option --reuse-h5files is ignored in the absence of --trxmat-only")
if include_all:
make_dataset_kwargs["include_all"] = include_all
if seed is not None:
make_dataset_kwargs["seed"] = seed
backend._run_script(backend.make_dataset, **make_dataset_kwargs)
if build_features:
backend._run_script(backend.build_features)
......@@ -198,6 +205,8 @@ def main(fun=None):
train_kwargs["pretrained_model_instance"] = pretrained_model_instance
if class_weights:
train_kwargs['class_weights'] = class_weights
if seed is not None:
train_kwargs['seed'] = seed
backend._run_script(backend.train_model, trailing=unknown_args, **train_kwargs)
else:
# called by make_dataset, build_features, train_model and predict_model
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment