Skip to content
Snippets Groups Projects
Commit b7fa1f91 authored by François  LAURENT's avatar François LAURENT
Browse files

introspection window_length in LarvaDataset

parent e415c473
Branches
No related tags found
No related merge requests found
Pipeline #85932 passed
......@@ -45,6 +45,12 @@ It returns `true` for labels to be accounted for, and `false` for those to be sk
If `unload=true`, data are unloaded from memory as soon as the corresponding file have been
read.
!!! note
this function typically is the main bottleneck in the process of generating
*larva_dataset hdf5* files.
"""
function labelcounts(files, headlength=nothing, taillength=nothing;
selection_rule=nothing, unload=false)
......@@ -165,12 +171,14 @@ function balancedcounts(counts, targetcount=nothing, majorityweight=2)
end
function write_larva_dataset_hdf5(path, counts, files, refs, nsteps_before, nsteps_after)
# this method mutates argument `refs`
refs′= Tuple{Int, Int, Int, eltype(keys(counts))}[]
for (label, count) in pairs(counts)
for (i, j, k) in shuffle(refs[label])[1:count]
push!(refs′, (i, j, k, label))
end
end
empty!(refs) # free memory
refs′= sort(refs′)
sampleid = 0
ch = Channel(; spawn=true) do ch
......
......@@ -113,6 +113,7 @@ end
function LarvaDatasets.labelcounts(
files::Vector{LarvaH5}, nsteps_before=nothing, nsteps_after=nothing;
unused...)
# TODO: multithread the IO and rework the index to decrease the memory consumption.
if isnothing(nsteps_before)
nsteps_before = isnothing(nsteps_after) ? 0 : nsteps_after + 1
end
......
......@@ -14,6 +14,7 @@ class LarvaDataset:
self._encode = True
self._sample_size = None
self._mask = slice(0, None)
self._window_length = None
"""
*h5py.File*: *larva_dataset hdf5* file handler.
"""
......@@ -123,9 +124,24 @@ class LarvaDataset:
self.split()
return self._validation_set
"""
Draw an item or sample.
Draw an observation.
"""
def getsample(self, validation_set=False):
def getobservation(self, validation_set=False):
dataset = self.validation_set if validation_set else self.training_set
return next(dataset)
"""
Alias for `getobservation`.
"""
def getsample(self, validation_set=False):
return self.getobservation(validation_set)
"""
*int*: number of time points in a segment.
"""
@property
def window_length(self):
if self._window_length is None:
allrecords = self.full_set["samples"]
anyrecord = allrecords[next(iter(allrecords.keys()))]
self._window_length = anyrecord.shape[0]
return self._window_length
......@@ -69,7 +69,7 @@ def main(fun=None):
sample_size = sys.argv[k]
elif sys.argv[k] == "--window-length":
k = k + 1
sample_size = sys.argv[k]
window_length = sys.argv[k]
elif sys.argv[k] == "--trxmat-only":
trxmat_only = True
elif sys.argv[k] == "--reuse-h5files":
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment