diff --git a/src/LarvaDatasets.jl b/src/LarvaDatasets.jl
index 8188429153f1f25c3387aee6a23d671981116271..eef36fe5d7a8856ac59e57bad724e044a0d70d7e 100644
--- a/src/LarvaDatasets.jl
+++ b/src/LarvaDatasets.jl
@@ -604,6 +604,7 @@ function new_write_larva_dataset_hdf5(output_dir, input_data;
         includeall="edited",
         past_future_extensions=true,
         seed=nothing,
+        sample_size=nothing,
         kwargs...)
     repo = if input_data isa String
         isnothing(file_filter) ? Repository(input_data) : Repository(input_data, file_filter)
@@ -642,6 +643,11 @@ function new_write_larva_dataset_hdf5(output_dir, input_data;
         @error "Most likely cause: no time segments could be isolated"
         rethrow()
     end
+    if !isnothing(sample_size)
+        # TODO: move the below function into PlanarLarvae.Dataloaders
+        isnothing(seed) || @warn "Random number generation is not explicitly seeded with sample_size defined"
+        samplesize!(index, sample_size)
+    end
     total_sample_size = length(index)
     classcounts, _ = Dataloaders.groupby(index.sampler.selectors, index.targetcounts)
     #
@@ -846,4 +852,91 @@ end
 
 runid(file) = splitpath(file.source)[end-1]
 
+"""
+    samplesize!(index, sample_size)
+
+Rescale the per-file target counts of `index` so that they add up to `sample_size`
+whenever the bounds allow, and store the result in `index.targetcounts`.
+
+`index` is returned unchanged if `sample_size` is not smaller than `length(index)`.
+Otherwise the ratio `sample_size / length(index)` is applied to the per-label totals and
+to each file's counts; rounding residuals are redistributed at random, keeping every
+count between zero and the corresponding entry of `index.maxcounts`.
+
+Random draws (`shuffle`, `rand`) use the default RNG; no seed is set here.
+
+# Throws
+- `ArgumentError` if `sample_size` is negative.
+"""
+function samplesize!(index, sample_size)
+    0 <= sample_size ||
+        throw(ArgumentError("sample_size must be non-negative, got $sample_size"))
+    total_sample_size = length(index)
+    sample_size < total_sample_size || return index
+
+    ratio = sample_size / total_sample_size
+
+    # apply `ratio` to the total counts first
+    maxcounts = Dataloaders.total(index.maxcounts)
+    targetcounts = Dataloaders.total(index.targetcounts)
+    for (label, count) in pairs(targetcounts)
+        targetcounts[label] = round(Int, count * ratio)
+    end
+
+    # rounding can leave the grand total off target: move one unit per label and per
+    # pass, in random order, until the total matches or no label can move any further
+    gap = sample_size - sum(values(targetcounts)) # > 0: units to add; < 0: to remove
+    step = Int(sign(gap))
+    labels = collect(keys(targetcounts)) # `shuffle` requires an array
+    moved = true
+    while 0 < gap * step && moved
+        moved = false
+        for label in shuffle(labels)
+            room = 0 < step ? maxcounts[label] - targetcounts[label] : targetcounts[label]
+            0 < room || continue
+            targetcounts[label] += step
+            gap -= step
+            moved = true
+            0 < gap * step || break
+        end
+    end
+
+    # apply `ratio` at the per-file level, working on copies so that the collections
+    # currently held by `index.targetcounts` are not modified in place
+    targetcountsperfile = copy(index.targetcounts)
+    for (file, counts) in pairs(index.targetcounts)
+        scaled = copy(counts)
+        for (label, count) in pairs(counts)
+            scaled[label] = round(Int, count * ratio)
+        end
+        targetcountsperfile[file] = scaled
+    end
+
+    # make the per-file counts add up to the per-label targets computed above
+    targetcounts′ = Dataloaders.total(targetcountsperfile)
+    files = collect(keys(targetcountsperfile))
+    for (label, targetcount) in pairs(targetcounts)
+        count = targetcounts′[label]
+        direction = sign(targetcount - count)
+        while count != targetcount
+            # draw uniformly among the files that can still move in the right direction
+            candidates = filter(files) do file
+                current = get(targetcountsperfile[file], label, 0)
+                0 < direction ? current < get(index.maxcounts[file], label, 0) : 0 < current
+            end
+            if isempty(candidates)
+                # not expected with consistent counts; avoids looping forever otherwise
+                @warn "Cannot reach the target count" label count targetcount
+                break
+            end
+            counts = targetcountsperfile[rand(candidates)]
+            counts[label] = get(counts, label, 0) + direction
+            count += direction
+        end
+    end
+
+    index.targetcounts = targetcountsperfile
+    return index
+end
+
 end