Skip to content
Snippets Groups Projects
Commit ede676e6 authored by François  LAURENT's avatar François LAURENT
Browse files

Merge branch 'dev' into 'main'

sample_size argument and check_larva_dataset_hdf5 function

See merge request !6
parents 8a761bba 5d034e98
No related branches found
No related tags found
1 merge request!6sample_size argument and check_larva_dataset_hdf5 function
Pipeline #111760 passed
# This file is machine-generated - editing it directly is not advised # This file is machine-generated - editing it directly is not advised
julia_version = "1.9.0" julia_version = "1.9.2"
manifest_format = "2.0" manifest_format = "2.0"
project_hash = "2c20afabe03d014276e9478d0fdccbc2cdd634c1" project_hash = "2c20afabe03d014276e9478d0fdccbc2cdd634c1"
...@@ -64,7 +64,7 @@ weakdeps = ["Dates", "LinearAlgebra"] ...@@ -64,7 +64,7 @@ weakdeps = ["Dates", "LinearAlgebra"]
[[deps.CompilerSupportLibraries_jll]] [[deps.CompilerSupportLibraries_jll]]
deps = ["Artifacts", "Libdl"] deps = ["Artifacts", "Libdl"]
uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae" uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae"
version = "1.0.2+0" version = "1.0.5+0"
[[deps.Conda]] [[deps.Conda]]
deps = ["Downloads", "JSON", "VersionParsing"] deps = ["Downloads", "JSON", "VersionParsing"]
...@@ -308,15 +308,15 @@ version = "2.5.10" ...@@ -308,15 +308,15 @@ version = "2.5.10"
[[deps.Pkg]] [[deps.Pkg]]
deps = ["Artifacts", "Dates", "Downloads", "FileWatching", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"] deps = ["Artifacts", "Dates", "Downloads", "FileWatching", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"]
uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
version = "1.9.0" version = "1.9.2"
[[deps.PlanarLarvae]] [[deps.PlanarLarvae]]
deps = ["DelimitedFiles", "HDF5", "JSON3", "LinearAlgebra", "MAT", "Meshes", "OrderedCollections", "Random", "SHA", "StaticArrays", "Statistics", "StatsBase", "StructTypes"] deps = ["DelimitedFiles", "HDF5", "JSON3", "LinearAlgebra", "MAT", "Meshes", "OrderedCollections", "Random", "SHA", "StaticArrays", "Statistics", "StatsBase", "StructTypes"]
git-tree-sha1 = "025970ba7e5b0b9455b2487c1ad2150a17edba0c" git-tree-sha1 = "25dede7c9e34786f3c9a576fc2da3c3448c12d80"
repo-rev = "main" repo-rev = "main"
repo-url = "https://gitlab.pasteur.fr/nyx/planarlarvae.jl" repo-url = "https://gitlab.pasteur.fr/nyx/planarlarvae.jl"
uuid = "c2615984-ef14-4d40-b148-916c85b43307" uuid = "c2615984-ef14-4d40-b148-916c85b43307"
version = "0.13.0" version = "0.14.0"
[[deps.PrecompileTools]] [[deps.PrecompileTools]]
deps = ["Preferences"] deps = ["Preferences"]
...@@ -483,7 +483,7 @@ version = "1.2.13+0" ...@@ -483,7 +483,7 @@ version = "1.2.13+0"
[[deps.libblastrampoline_jll]] [[deps.libblastrampoline_jll]]
deps = ["Artifacts", "Libdl"] deps = ["Artifacts", "Libdl"]
uuid = "8e850b90-86db-534c-a0d3-1478176c7d93" uuid = "8e850b90-86db-534c-a0d3-1478176c7d93"
version = "5.7.0+0" version = "5.8.0+0"
[[deps.nghttp2_jll]] [[deps.nghttp2_jll]]
deps = ["Artifacts", "Libdl"] deps = ["Artifacts", "Libdl"]
......
name = "TaggingBackends" name = "TaggingBackends"
uuid = "e551f703-3b82-4335-b341-d497b48d519b" uuid = "e551f703-3b82-4335-b341-d497b48d519b"
authors = ["François Laurent", "Institut Pasteur"] authors = ["François Laurent", "Institut Pasteur"]
version = "0.15.2" version = "0.15.3"
[deps] [deps]
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a" Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
......
[tool.poetry] [tool.poetry]
name = "TaggingBackends" name = "TaggingBackends"
version = "0.15.2" version = "0.15.3"
description = "Backbone for LarvaTagger.jl tagging backends" description = "Backbone for LarvaTagger.jl tagging backends"
authors = ["François Laurent"] authors = ["François Laurent"]
......
...@@ -26,7 +26,7 @@ using Statistics ...@@ -26,7 +26,7 @@ using Statistics
using Memoization using Memoization
using OrderedCollections using OrderedCollections
export write_larva_dataset_hdf5, first_stimulus, labelcounts export write_larva_dataset_hdf5, first_stimulus, labelcounts, check_larva_dataset_hdf5
""" """
labelcounts(files) labelcounts(files)
...@@ -604,6 +604,7 @@ function new_write_larva_dataset_hdf5(output_dir, input_data; ...@@ -604,6 +604,7 @@ function new_write_larva_dataset_hdf5(output_dir, input_data;
includeall="edited", includeall="edited",
past_future_extensions=true, past_future_extensions=true,
seed=nothing, seed=nothing,
sample_size=nothing,
kwargs...) kwargs...)
repo = if input_data isa String repo = if input_data isa String
isnothing(file_filter) ? Repository(input_data) : Repository(input_data, file_filter) isnothing(file_filter) ? Repository(input_data) : Repository(input_data, file_filter)
...@@ -642,6 +643,7 @@ function new_write_larva_dataset_hdf5(output_dir, input_data; ...@@ -642,6 +643,7 @@ function new_write_larva_dataset_hdf5(output_dir, input_data;
@error "Most likely cause: no time segments could be isolated" @error "Most likely cause: no time segments could be isolated"
rethrow() rethrow()
end end
isnothing(sample_size) || samplesize!(index, sample_size)
total_sample_size = length(index) total_sample_size = length(index)
classcounts, _ = Dataloaders.groupby(index.sampler.selectors, index.targetcounts) classcounts, _ = Dataloaders.groupby(index.sampler.selectors, index.targetcounts)
# #
...@@ -703,6 +705,9 @@ function new_write_larva_dataset_hdf5(output_dir, input_data; ...@@ -703,6 +705,9 @@ function new_write_larva_dataset_hdf5(output_dir, input_data;
else else
# ensure labels are ordered as provided in input; # ensure labels are ordered as provided in input;
# see https://gitlab.pasteur.fr/nyx/TaggingBackends/-/issues/24 # see https://gitlab.pasteur.fr/nyx/TaggingBackends/-/issues/24
if labels isa AbstractDict
labels = string.(keys(labels))
end
h5["labels"] = labels h5["labels"] = labels
h5["label_counts"] = [get(classcounts, Symbol(label), 0) for label in labels] h5["label_counts"] = [get(classcounts, Symbol(label), 0) for label in labels]
end end
...@@ -846,4 +851,30 @@ end ...@@ -846,4 +851,30 @@ end
runid(file) = splitpath(file.source)[end-1] runid(file) = splitpath(file.source)[end-1]
"""
check_larva_dataset_hdf5(path)
Read the total label counts and return example time points.
"""
function check_larva_dataset_hdf5(path; print=true)
h5open(path, "r") do h5
labels = read(h5, "labels")
labelcounts = read(h5, "label_counts")
labelcounts = Dict(Symbol(label) => count for (label, count) in zip(labels, labelcounts))
print && @info "Labels:" pairs(labelcounts)...
examples = Dict{Symbol, NamedTuple{(:path, :larva_number, :reference_time), Tuple{String, Int, Float64}}}()
g = h5["samples"]
for sampleid in 1:read(attributes(g), "n_samples")
h = g["sample_$sampleid"]
label = Symbol(read(attributes(h), "behavior"))
if label keys(examples)
examples[label] = (path=read(attributes(h), "path"), larva_number=read(attributes(h), "larva_number"), reference_time=read(attributes(h), "reference_time"))
@info "$(label) example" examples[label]...
length(examples) == length(labels) && break
end
end
return labelcounts, examples
end
end
end end
...@@ -202,7 +202,11 @@ class LarvaDataset: ...@@ -202,7 +202,11 @@ class LarvaDataset:
elif subset.startswith("test"): elif subset.startswith("test"):
dataset = self.test_set dataset = self.test_set
if nbatches == "all": if nbatches == "all":
nbatches = len(dataset) if isinstance(dataset, itertools.cycle):
logging.warning("drawing unlimited number of batches from circular dataset")
nbatches = np.inf
else:
nbatches = len(dataset)
try: try:
while 0 < nbatches: while 0 < nbatches:
nbatches -= 1 nbatches -= 1
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment