diff --git a/Manifest.toml b/Manifest.toml
index d1dbb68891c08febfe490dc5efaa434e0895aee2..84c6a4d2fb0bf89d829568da8b8273fbc65f14d6 100644
--- a/Manifest.toml
+++ b/Manifest.toml
@@ -1,6 +1,6 @@
 # This file is machine-generated - editing it directly is not advised
 
-julia_version = "1.9.0"
+julia_version = "1.9.2"
 manifest_format = "2.0"
 project_hash = "2c20afabe03d014276e9478d0fdccbc2cdd634c1"
 
@@ -64,7 +64,7 @@ weakdeps = ["Dates", "LinearAlgebra"]
 [[deps.CompilerSupportLibraries_jll]]
 deps = ["Artifacts", "Libdl"]
 uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae"
-version = "1.0.2+0"
+version = "1.0.5+0"
 
 [[deps.Conda]]
 deps = ["Downloads", "JSON", "VersionParsing"]
@@ -308,15 +308,15 @@ version = "2.5.10"
 [[deps.Pkg]]
 deps = ["Artifacts", "Dates", "Downloads", "FileWatching", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"]
 uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
-version = "1.9.0"
+version = "1.9.2"
 
 [[deps.PlanarLarvae]]
 deps = ["DelimitedFiles", "HDF5", "JSON3", "LinearAlgebra", "MAT", "Meshes", "OrderedCollections", "Random", "SHA", "StaticArrays", "Statistics", "StatsBase", "StructTypes"]
-git-tree-sha1 = "025970ba7e5b0b9455b2487c1ad2150a17edba0c"
+git-tree-sha1 = "25dede7c9e34786f3c9a576fc2da3c3448c12d80"
 repo-rev = "main"
 repo-url = "https://gitlab.pasteur.fr/nyx/planarlarvae.jl"
 uuid = "c2615984-ef14-4d40-b148-916c85b43307"
-version = "0.13.0"
+version = "0.14.0"
 
 [[deps.PrecompileTools]]
 deps = ["Preferences"]
@@ -483,7 +483,7 @@ version = "1.2.13+0"
 [[deps.libblastrampoline_jll]]
 deps = ["Artifacts", "Libdl"]
 uuid = "8e850b90-86db-534c-a0d3-1478176c7d93"
-version = "5.7.0+0"
+version = "5.8.0+0"
 
 [[deps.nghttp2_jll]]
 deps = ["Artifacts", "Libdl"]
diff --git a/Project.toml b/Project.toml
index eb48701da429f612664dddfce60f75064568c1eb..4731bc26eb3551067eaacf1e1b690ae5c67f448a 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,7 +1,7 @@
 name = "TaggingBackends"
 uuid = "e551f703-3b82-4335-b341-d497b48d519b"
 authors = ["François Laurent", "Institut Pasteur"]
-version = "0.15.2"
+version = "0.15.3"
 
 [deps]
 Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
diff --git a/pyproject.toml b/pyproject.toml
index 4b3f3ee2ce0113a910b8f04533315cf3f546be93..3a7f66ef70140ee3e975ff4f210efee5f54b1815 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "TaggingBackends"
-version = "0.15.2"
+version = "0.15.3"
 description = "Backbone for LarvaTagger.jl tagging backends"
 authors = ["François Laurent"]
 
diff --git a/src/LarvaDatasets.jl b/src/LarvaDatasets.jl
index 8188429153f1f25c3387aee6a23d671981116271..1e90013cc052a01c7711fed0a314f5e4520bc92f 100644
--- a/src/LarvaDatasets.jl
+++ b/src/LarvaDatasets.jl
@@ -26,7 +26,7 @@ using Statistics
 using Memoization
 using OrderedCollections
 
-export write_larva_dataset_hdf5, first_stimulus, labelcounts
+export write_larva_dataset_hdf5, first_stimulus, labelcounts, check_larva_dataset_hdf5
 
 """
     labelcounts(files)
@@ -604,6 +604,7 @@ function new_write_larva_dataset_hdf5(output_dir, input_data;
         includeall="edited",
         past_future_extensions=true,
         seed=nothing,
+        sample_size=nothing,
         kwargs...)
     repo = if input_data isa String
         isnothing(file_filter) ? Repository(input_data) : Repository(input_data, file_filter)
@@ -642,6 +643,7 @@ function new_write_larva_dataset_hdf5(output_dir, input_data;
         @error "Most likely cause: no time segments could be isolated"
         rethrow()
     end
+    isnothing(sample_size) || samplesize!(index, sample_size)
     total_sample_size = length(index)
     classcounts, _ = Dataloaders.groupby(index.sampler.selectors, index.targetcounts)
     #
@@ -703,6 +705,9 @@ function new_write_larva_dataset_hdf5(output_dir, input_data;
         else
             # ensure labels are ordered as provided in input;
             # see https://gitlab.pasteur.fr/nyx/TaggingBackends/-/issues/24
+            if labels isa AbstractDict
+                labels = string.(keys(labels))
+            end
             h5["labels"] = labels
             h5["label_counts"] = [get(classcounts, Symbol(label), 0) for label in labels]
         end
@@ -846,4 +851,37 @@ end
 
 runid(file) = splitpath(file.source)[end-1]
 
+"""
+    check_larva_dataset_hdf5(path; print=true)
+
+Read the total label counts and return example time points.
+
+Return `(labelcounts, examples)`: `labelcounts` maps each label (as a `Symbol`) to its
+count; `examples` maps a label to the `path`, `larva_number` and `reference_time`
+attributes of the first sample found with that label.
+If `print` is `true`, the counts and the examples are also logged with `@info`.
+"""
+function check_larva_dataset_hdf5(path; print=true)
+    h5open(path, "r") do h5
+        labels = read(h5, "labels")
+        labelcounts = read(h5, "label_counts")
+        labelcounts = Dict(Symbol(label) => count for (label, count) in zip(labels, labelcounts))
+        print && @info "Labels:" pairs(labelcounts)...
+        examples = Dict{Symbol, NamedTuple{(:path, :larva_number, :reference_time), Tuple{String, Int, Float64}}}()
+        g = h5["samples"]
+        for sampleid in 1:read(attributes(g), "n_samples")
+            h = g["sample_$sampleid"]
+            label = Symbol(read(attributes(h), "behavior"))
+            if label ∉ keys(examples)
+                examples[label] = (path=read(attributes(h), "path"), larva_number=read(attributes(h), "larva_number"), reference_time=read(attributes(h), "reference_time"))
+                # honor `print=false` for the per-label examples as well
+                print && @info "$(label) example" examples[label]...
+                # stop scanning as soon as every listed label has an example
+                length(examples) == length(labels) && break
+            end
+        end
+        return labelcounts, examples
+    end
+end
+
 end
diff --git a/src/taggingbackends/data/dataset.py b/src/taggingbackends/data/dataset.py
index 74c39aca8936d9a7db4be84b7f7992905a975130..0700b0ac2795c3b941b220f1a0ff5568bbc5990c 100644
--- a/src/taggingbackends/data/dataset.py
+++ b/src/taggingbackends/data/dataset.py
@@ -202,7 +202,11 @@ class LarvaDataset:
         elif subset.startswith("test"):
             dataset = self.test_set
         if nbatches == "all":
-            nbatches = len(dataset)
+            if isinstance(dataset, itertools.cycle):
+                logging.warning("drawing unlimited number of batches from circular dataset")
+                nbatches = np.inf
+            else:
+                nbatches = len(dataset)
         try:
             while 0 < nbatches:
                 nbatches -= 1