diff --git a/Manifest.toml b/Manifest.toml
index 329dd1231d64a5a71ca6a3c1a0fd4280c951ce4a..4a73b4fbe0141384268c4a7706e3ac384ef55a61 100644
--- a/Manifest.toml
+++ b/Manifest.toml
@@ -1,6 +1,6 @@
 # This file is machine-generated - editing it directly is not advised
 
-julia_version = "1.8.2"
+julia_version = "1.8.3"
 manifest_format = "2.0"
 project_hash = "2c20afabe03d014276e9478d0fdccbc2cdd634c1"
 
@@ -309,12 +309,12 @@ uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
 version = "1.8.0"
 
 [[deps.PlanarLarvae]]
-deps = ["DelimitedFiles", "HDF5", "JSON3", "MAT", "Meshes", "OrderedCollections", "SHA", "StaticArrays", "Statistics", "StatsBase", "StructTypes"]
-git-tree-sha1 = "607572b4d9404105e64e5b9b0f2f047bd307eb17"
+deps = ["DelimitedFiles", "HDF5", "JSON3", "LinearAlgebra", "MAT", "Meshes", "OrderedCollections", "SHA", "StaticArrays", "Statistics", "StatsBase", "StructTypes"]
+git-tree-sha1 = "33b53b6c16da1bd0982ce3b95096807c43b977dc"
 repo-rev = "main"
 repo-url = "https://gitlab.pasteur.fr/nyx/planarlarvae.jl"
 uuid = "c2615984-ef14-4d40-b148-916c85b43307"
-version = "0.6.0"
+version = "0.8.0"
 
 [[deps.Preferences]]
 deps = ["TOML"]
diff --git a/Project.toml b/Project.toml
index 3180012fc2c59edde30b63ae215adbf6cca3a2c3..a75dfba0f839e4b7833a667211df8ef22c653b2f 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,7 +1,7 @@
 name = "TaggingBackends"
 uuid = "e551f703-3b82-4335-b341-d497b48d519b"
 authors = ["François Laurent", "Institut Pasteur"]
-version = "0.9"
+version = "0.10.0"
 
 [deps]
 Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
diff --git a/pyproject.toml b/pyproject.toml
index cb7e01260aeb61f0c67ab38d5f9704ac8a81c8f7..e4b1161777c9e95a16a142e48462a3a767c44482 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "TaggingBackends"
-version = "0.8"
+version = "0.10"
 description = "Backbone for LarvaTagger.jl tagging backends"
 authors = ["François Laurent"]
 
diff --git a/src/LarvaDatasets.jl b/src/LarvaDatasets.jl
index 328675fa321eb82f51ae73e3a217f58f7ad4895d..22d590cc09a4509a7117207aea3d194e9b3da433 100644
--- a/src/LarvaDatasets.jl
+++ b/src/LarvaDatasets.jl
@@ -17,6 +17,7 @@ structure of *trx.mat* files, an alternative implementation is provided by modul
 """
 
 using PlanarLarvae, PlanarLarvae.Formats, PlanarLarvae.Features, PlanarLarvae.MWT
+using PlanarLarvae.Datasets: coerce
 using Random
 using HDF5
 using Dates
@@ -264,14 +265,55 @@ function thresholdedcounts(counts; majorityweight=20)
 end
 
 function write_larva_dataset_hdf5(path, counts, files, refs, nsteps_before, nsteps_after;
-        fixmwt=false, frameinterval=nothing,
+        fixmwt=false, frameinterval=nothing, includeall=nothing,
     )
     fixmwt && @warn "`fixmwt=true` is no longer supported"
     # this method mutates argument `refs`
-    refs′= Tuple{Int, Int, Int, eltype(keys(counts))}[]
-    for (label, count) in pairs(counts)
-        for (i, j, k) in shuffle(refs[label])[1:count]
-            push!(refs′, (i, j, k, label))
+    T = eltype(keys(counts))
+    refs′= Tuple{Int, Int, Int, T}[]
+    if !isnothing(includeall)
+        includeall = coerce(T, includeall)
+        if haskey(counts, includeall)
+            count = counts[includeall]
+            T′= Vector{Tuple{Int, Int, Int, T}}
+            specialrefs = Dict{T, T′}()
+            for (i, j, k) in refs[includeall]
+                for l in keys(refs)
+                    if l != includeall
+                        m = findfirst(==((i, j, k)), refs[l])
+                        if !isnothing(m)
+                            push!(get!(specialrefs, l, T′()), (i, j, k, l))
+                            deleteat!(refs[l], m)
+                        end
+                    end
+                end
+            end
+            if !isempty(specialrefs)
+                @info "Explicit inclusions based on label \"$(includeall)\":" [Symbol(label) => length(refs″) for (label, refs″) in pairs(specialrefs)]...
+                for (label, count) in pairs(counts)
+                    label == includeall && continue
+                    if label in keys(specialrefs)
+                        refs″= specialrefs[label]
+                        if count < length(refs″)
+                            refs″ = shuffle(refs″)[1:count]
+                        end
+                        refs′= vcat(refs′, refs″)
+                        count = count - length(refs″)
+                    end
+                    if 0 < count
+                        for (i, j, k) in shuffle(refs[label])[1:count]
+                            push!(refs′, (i, j, k, label))
+                        end
+                    end
+                end
+            end
+        end
+    end
+    if isempty(refs′)
+        for (label, count) in pairs(counts)
+            for (i, j, k) in shuffle(refs[label])[1:count]
+                push!(refs′, (i, j, k, label))
+            end
         end
     end
     empty!(refs) # free memory
@@ -295,6 +337,7 @@ function write_larva_dataset_hdf5(path, counts, files, refs, nsteps_before, nste
         attributes(g)["n_samples"] = sampleid
         # extension
         h5["labels"] = collect(keys(counts))
+        h5["label_counts"] = collect(values(counts))
         #h5["files"] = [f.source for f in files]
         if !isnothing(frameinterval)
             attributes(g)["frame_interval"] = frameinterval
@@ -473,6 +516,12 @@ Note that, if `input_data` lists all files, `labelledfiles` is not called and ar
 Similarly, if `labelpointers` is defined, `labelcounts` is not called and argument
 `timestep_filter` is not used.
 
+*New in version 0.10*: `includeall` specifies a secondary label for systematic inclusion in
+the dataset. The time segments with this secondary label are accounted for under the
+associated primary label, prior to applying the balancing rule. If the label specified by
+`includeall` is found in `labels`, it is considered primary and `includeall` is ignored.
+Generally speaking, `labels` should not include any secondary label.
+
 Known issue: ASCII-compatible string attributes are ASCII encoded and deserialized as
 `bytes` by the *h5py* Python library.
 """
@@ -488,7 +537,8 @@ function write_larva_dataset_hdf5(output_dir::String,
         shallow=false,
         balance=true,
         fixmwt=false,
-        frameinterval=nothing)
+        frameinterval=nothing,
+        includeall="edited")
     files = if input_data isa String
         repository = input_data
         labelledfiles(repository, chunks; selection_rule=file_filter, shallow=shallow)
@@ -514,10 +564,23 @@ function write_larva_dataset_hdf5(output_dir::String,
         refs = labelpointers
         counts = Dict{String, Int}(label=>length(pointers) for (label, pointers) in pairs(labelpointers))
     end
-    if !isnothing(labels)
-        labels′ = p -> string(p[1]) in labels
+    if isnothing(labels)
+        labels = collect(keys(counts))
+        if !isnothing(includeall) && includeall ∈ labels
+            labels = [label for label in labels if label != includeall]
+        end
+    else
+        if !isnothing(includeall) && includeall ∈ labels
+            includeall = nothing
+        end
+        labels′= if isnothing(includeall)
+            p -> string(p[1]) in labels
+        else
+            p -> string(p[1]) in labels || string(p[1]) == includeall
+        end
         filter!(labels′, counts)
         filter!(labels′, refs)
+        isempty(counts) && throw("None of the specified labels were found")
     end
     if balance
         sample_sizes, total_sample_size = balancedcounts(counts, sample_size)
@@ -525,12 +588,13 @@ function write_larva_dataset_hdf5(output_dir::String,
         isnothing(sample_size) || @error "Argument sample_size not supported for the specified balancing strategy"
         sample_sizes, total_sample_size = thresholdedcounts(counts)
     end
-    @info "Sample sizes (observed, selected):" [Symbol(label) => (count, get(sample_sizes, label, 0)) for (label, count) in pairs(counts)]...
+    @info "Sample sizes (observed, selected):" [Symbol(label) => (get(counts, label, 0), get(sample_sizes, label, 0)) for label in labels]...
     date = Dates.format(Dates.now(), "yyyy_mm_dd")
     output_file = joinpath(output_dir,
        "larva_dataset_$(date)_$(window_length)_$(window_length)_$(total_sample_size).hdf5")
     write_larva_dataset_hdf5(output_file, sample_sizes, files, refs, nsteps_before, nsteps_after;
-        fixmwt=fixmwt, frameinterval=frameinterval)
+        fixmwt=fixmwt, frameinterval=frameinterval,
+        includeall=includeall)
 
     h5open(output_file, "cw") do h5
         attributes(h5["samples"])["len_traj"] = window_length
diff --git a/src/taggingbackends/data/dataset.py b/src/taggingbackends/data/dataset.py
index 3e9d2d32ca6520b0c9e848cf20bbce58bf1cd37a..9ce597029f74b3433edf874862d23b2b18caf4ec 100644
--- a/src/taggingbackends/data/dataset.py
+++ b/src/taggingbackends/data/dataset.py
@@ -230,8 +230,11 @@ class LarvaDataset:
     @property
     def class_weights(self):
        if not isinstance(self._class_weights, np.ndarray) and self._class_weights in (None, True):
-            _, class_counts = np.unique(self.training_labels, return_counts=True)
-            class_counts = np.array([class_counts[i] for i in range(len(self.labels))])
+            try:
+                class_counts = np.asarray(self.full_set["label_counts"])
+            except KeyError:
+                _, class_counts = np.unique(self.training_labels, return_counts=True)
+                class_counts = np.array([class_counts[i] for i in range(len(self.labels))])
             self._class_weights = 1 - class_counts / np.sum(class_counts)
         return None if self._class_weights is False else self._class_weights
 
diff --git a/src/taggingbackends/data/labels.py b/src/taggingbackends/data/labels.py
index f9402f502e3a8ee746d1bd54bd07cd57fea41206..e2b93ad59c98dc9e2e86584be438592b5bcf8bb7 100644
--- a/src/taggingbackends/data/labels.py
+++ b/src/taggingbackends/data/labels.py
@@ -60,9 +60,13 @@ def retagged_trxmat_confusion_matrix(label_file):
 ```
 
 The `labelspec` attribute is assumed to be a list, which is valid for *.label*
-generated by automatic tagging.
-Manual tagging will store label names as `Labels.labelspec['names']`, because
-label colors are also stored in the `labelspec` attribute.
+files generated by automatic tagging.
+In contrast, manual tagging stores label names as `Labels.labelspec['names']`,
+because label colors are also stored in the `labelspec` attribute.
+
+With taggingbackends==0.9, a related attribute was introduced:
+`secondarylabelspec`. To get a single indexable array that contains both primary
+and secondary labels, use `full_label_list` instead.
 """
 class Labels:
 
@@ -77,6 +81,7 @@ class Labels:
         #
         self.labels, self.metadata = labels, metadata
         self.labelspec, self.units = labelspec, units
+        self.secondarylabelspec = None
         self._tracking = tracking
         self._input_labels = None
         #
@@ -258,6 +263,7 @@ class Labels:
             new_self = json.load(file, cls=LabelDecoder)
         self.labels, self.metadata = new_self.labels, new_self.metadata
         self.labelspec, self.units = new_self.labelspec, new_self.units
+        self.secondarylabelspec = new_self.secondarylabelspec
         self.tracking = new_self.tracking
         return self
 
@@ -275,6 +281,8 @@ class Labels:
             data["units"] = self.units
         if self.labelspec:
             data["labels"] = self.labelspec
+        if self.secondarylabelspec:
+            data["secondarylabels"] = self.secondarylabelspec
         if self.tracking:
             if isinstance(self.tracking, dict):
                 data["dependencies"] = self.tracking
@@ -297,6 +305,7 @@ class Labels:
             run = self.metadata.pop("id")
         self.units = data.get("units", {})
         self.labelspec = data.get("labels", {})
+        self.secondarylabelspec = data.get("secondarylabels", [])
         self._tracking = data.get("dependencies", [])
         if isinstance(self._tracking, dict):
             self._tracking = [self._tracking]
@@ -309,6 +318,19 @@ class Labels:
                     for timestamp, label in zip(track["t"], track["labels"])}
         return self
 
+    """
+    List of str: all labels, including both primary and secondary labels.
+    """
+    @property
+    def full_label_list(self):
+        if isinstance(self.labelspec, dict):
+            labelset = self.labelspec['names']
+        else:
+            labelset = self.labelspec
+        if self.secondarylabelspec:
+            labelset = labelset + self.secondarylabelspec
+        return labelset
+
     """
     Encode the text labels as indices (`int` or `list` of `int`).
 
@@ -323,10 +345,7 @@ class Labels:
         elif isinstance(label, dict):
             encoded = {t: self.encode(l) for t, l in label.items()}
         else:
-            if isinstance(self.labelspec, dict):
-                labelset = self.labelspec['names']
-            else:
-                labelset = self.labelspec
+            labelset = self.full_label_list
             if isinstance(label, str):
                 encoded = labelset.index(label) + 1
             elif isinstance(label, int):
@@ -349,10 +368,7 @@ class Labels:
         elif isinstance(label, dict):
             decoded = {t: self.decode(l) for t, l in label.items()}
         else:
-            if isinstance(self.labelspec, dict):
-                labelset = self.labelspec['names']
-            else:
-                labelset = self.labelspec
+            labelset = self.full_label_list
             if isinstance(label, int):
                 decoded = labelset[label-1]
             elif isinstance(label, str):
diff --git a/src/taggingbackends/explorer.py b/src/taggingbackends/explorer.py
index 13bc5815381f358c03494e172d6d440d46b9c61a..3d59007aa63d1ac740a1b702e074040ff5f06959 100644
--- a/src/taggingbackends/explorer.py
+++ b/src/taggingbackends/explorer.py
@@ -474,7 +474,7 @@ run `poetry add {pkg}` from directory: \n
 
     def generate_dataset(self, input_files,
             labels=None, window_length=20, sample_size=None, balance=True,
-            frame_interval=None):
+            include_all=None, frame_interval=None):
        """
        Generate a *larva_dataset hdf5* file in data/interim/{instance}/
        """
@@ -485,6 +485,7 @@ run `poetry add {pkg}` from directory: \n
                labels=labels,
                sample_size=sample_size,
                balance=balance,
+                includeall=include_all,
                frameinterval=frame_interval)
 
    def compile_trxmat_database(self, input_dir,
diff --git a/src/taggingbackends/main.py b/src/taggingbackends/main.py
index a6c6d9c19f3f25267b24b1cde85b8e83f035a8ec..76761fdf39cb7697dca78051aa00d6633e80db67 100644
--- a/src/taggingbackends/main.py
+++ b/src/taggingbackends/main.py
@@ -10,7 +10,9 @@ Usage: tagging-backend [train|predict] --model-instance <name>
        tagging-backend train ... --sample-size <N> --balancing-strategy <strategy>
        tagging-backend train ... --frame-interval <I> --window-length <T>
        tagging-backend train ... --pretrained-model-instance <name>
-       tagging-backend predict ... --skip-make-dataset --sandbox <token>
+       tagging-backend train ... --include-all <secondary-label>
+       tagging-backend train ... --skip-make-dataset --skip-build-features
+       tagging-backend predict ... --make-dataset --build-features --sandbox <token>
 
 `tagging-backend` typically is run using `poetry run`.
 A name must be provided to identify the trained model and its location within
@@ -36,6 +38,10 @@ the `make_dataset` module is loaded and this may take quite some time due to
 dependencies (e.g. Julia FFI). The `--skip-make-dataset` option makes `train`
 truly skip this step; the corresponding module is not loaded.
 
+Since version 0.10, `predict` makes `--skip-make-dataset` and
+`--skip-build-features` the default behavior. Conversely, it now accepts the
+`--make-dataset` and `--build-features` arguments.
+
 `--sandbox <token>` makes `tagging-backend` use a token instead of <name>
 as directory name in data/raw, data/interim and data/processed.
 This is intended to prevent conflicts on running `predict` in parallel on
@@ -59,10 +65,11 @@ def main(fun=None):
     input_files, labels = [], []
     sample_size = window_length = frame_interval = None
     trxmat_only = reuse_h5files = False
-    skip_make_dataset = skip_build_features = False
+    make_dataset = build_features = None
     pretrained_model_instance = None
     sandbox = False
     balancing_strategy = 'auto'
+    include_all = None
     unknown_args = {}
     k = 2
     while k < len(sys.argv):
@@ -92,9 +99,13 @@ def main(fun=None):
         elif sys.argv[k] == "--reuse-h5files":
             reuse_h5files = True
         elif sys.argv[k] == "--skip-make-dataset":
-            skip_make_dataset = True
+            make_dataset = False
         elif sys.argv[k] == "--skip-build-features":
-            skip_build_features = True
+            build_features = False
+        elif sys.argv[k] == '--make-dataset':
+            make_dataset = True
+        elif sys.argv[k] == '--build-features':
+            build_features = True
         elif sys.argv[k] == "--pretrained-model-instance":
             k = k + 1
             pretrained_model_instance = sys.argv[k]
@@ -104,6 +115,9 @@ def main(fun=None):
         elif sys.argv[k] == "--balancing-strategy":
             k = k + 1
             balancing_strategy = sys.argv[k]
+        elif sys.argv[k] == "--include-all":
+            k = k + 1
+            include_all = sys.argv[k]
         else:
             unknown_args[sys.argv[k].lstrip('-').replace('-', '_')] = sys.argv[k+1]
         k = k + 1
@@ -116,7 +130,11 @@ def main(fun=None):
     if input_files:
         for file in input_files:
             backend.move_to_raw(file)
-    if not skip_make_dataset:
+    if make_dataset is None and train_or_predict == 'train':
+        make_dataset = True
+    if build_features is None and train_or_predict == 'train':
+        build_features = True
+    if make_dataset:
         make_dataset_kwargs = dict(labels_expected=train_or_predict == "train",
                 balancing_strategy=balancing_strategy)
         if labels:
@@ -133,8 +151,10 @@ def main(fun=None):
             make_dataset_kwargs["reuse_h5files"] = True
         elif reuse_h5files:
             logging.info("option --reuse-h5files is ignored in the absence of --trxmat-only")
+        if include_all:
+            make_dataset_kwargs["include_all"] = include_all
         backend._run_script(backend.make_dataset, **make_dataset_kwargs)
-    if not skip_build_features:
+    if build_features:
         backend._run_script(backend.build_features)
     if train_or_predict == "predict":
         backend._run_script(backend.predict_model, trailing=unknown_args)
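
Usage sketch (not part of the patch): the snippet below illustrates how the new `secondarylabelspec` attribute and the `full_label_list` property added in src/taggingbackends/data/labels.py combine primary and secondary labels for `encode`/`decode`. The label names are placeholders, and calling `Labels()` with no arguments assumes the constructor parameters all have defaults, which this diff does not show.

    from taggingbackends.data.labels import Labels

    labels = Labels()  # assumption: all __init__ arguments are optional
    labels.labelspec = ["run", "cast", "hunch"]  # primary labels (placeholder names)
    labels.secondarylabelspec = ["edited"]       # secondary label, e.g. the default `includeall` value

    print(labels.full_label_list)   # expected: ['run', 'cast', 'hunch', 'edited']
    print(labels.encode("edited"))  # expected: 4 (secondary labels are indexed after the primary ones)
    print(labels.decode(4))         # expected: 'edited'

On the command line, the same secondary label is requested at training time with `--include-all edited`; main.py forwards it to the backend's make_dataset script, and `generate_dataset` (explorer.py) passes it on as `includeall` to the Julia-side `write_larva_dataset_hdf5`.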