Skip to content
Snippets Groups Projects

Save larva_dataset file along with model files

Closed François LAURENT requested to merge save_dataset into main
6 files
+ 147
13
Compare changes
  • Side-by-side
  • Inline
Files
6
+ 86
0
 
import os
 
from glob import glob
 
import numpy as np
 
from sklearn.metrics import confusion_matrix
 
from taggingbackends.data.labels import Labels
 
from taggingbackends.data.dataset import LarvaDataset
 
from taggingbackends.explorer import BackendExplorer
 
 
"""
 
Generic function for true labels possibly in the shape of tag arrays.
 
"""
 
def index(labels, tags):
 
if isinstance(tags, str):
 
if tags == 'edited':
 
# probably a (manual) tagging mistake
 
return -1
 
else:
 
return labels.index(tags)
 
else:
 
for i, label in enumerate(labels):
 
if label in tags:
 
return i
 
print("Incompatible labels")
 
print(" expected labels:")
 
print(labels)
 
print(" labels at a time step:")
 
print(tags)
 
return -1
 
 
# Optional host user/group ids, set by the container entry point so that
# generated files can be chown'ed back to the invoking host user.
uid, gid = os.getenv('HOST_UID', None), os.getenv('HOST_GID', None)
if uid is not None:
    uid = int(uid)
    # The original code crashed (int(None)) when HOST_UID was set without
    # HOST_GID; fall back to the uid as group id in that case.
    gid = uid if gid is None else int(gid)

# Accumulators filled in by the directory walk below.
labels = None  # shared label vocabulary (taken from the first assay)
cm = None      # global confusion matrix, summed over all assays

# Expected label file names inside each assay directory.
fn_true = 'groundtruth.label'
fn_pred = 'predicted.label'
 
 
# Walk the data tree and build one confusion matrix per assay directory
# (any directory holding both the ground-truth and the predicted label
# files), accumulating them into the global matrix `cm`.
for assay, _, files in os.walk('/data'):
    # membership test replaces the original any([...]) scans
    if fn_true in files and fn_pred in files:
        expected = Labels(os.path.join(assay, fn_true))
        predicted = Labels(os.path.join(assay, fn_pred))

        # All assays must share the same label vocabulary.
        if labels is None:
            labels = predicted.labelspec
        else:
            assert labels == predicted.labelspec

        # Per-assay confusion matrix, summed over larvae.
        cm_ = None
        for larva in expected:
            # note: not all the larvae in `expected` may be in `predicted`
            # NOTE(review): if a larva were actually missing from `predicted`,
            # the lookup below would raise — confirm upstream guarantees.
            y_pred = np.array([labels.index(label) for label in predicted[larva].values()])
            y_true = np.array([index(labels, tags) for tags in expected[larva].values()])
            # drop time steps whose true label could not be resolved (-1)
            ok = 0 <= y_true
            cm__ = confusion_matrix(y_true[ok], y_pred[ok], labels=range(len(labels)))
            cm_ = cm__ if cm_ is None else cm_ + cm__

        # Fold the assay matrix into the global one (once per assay).
        cm = cm_ if cm is None else cm + cm_
        assert cm_ is not None

        # Dump the per-assay confusion matrix as csv; header row = labels.
        path = os.path.join(assay, 'confusion.csv')
        with open(path, 'w') as f:
            f.write(",".join(labels))
            for row in cm_:
                f.write("\n")
                f.write(",".join([str(count) for count in row]))

        # Hand the generated file back to the host user if requested.
        if uid is not None:
            os.chown(path, uid, gid)
 
 
# At least one assay must have contributed a confusion matrix.
assert cm is not None

# Report the raw material first: vocabulary and global confusion matrix.
for heading, value in (('labels:', labels), ('confusion matrix:', cm)):
    print(heading)
    print(value)

# Per-class precision (column-normalized diagonal) and recall
# (row-normalized diagonal) of the global confusion matrix.
diagonal = np.diag(cm)
precision = diagonal / cm.sum(axis=0)
recall = diagonal / cm.sum(axis=1)

# Degenerate classes (zero precision/recall, or NaN from empty rows or
# columns) would make the f1 score meaningless; fail loudly instead.
assert np.all(0 < precision)
assert np.all(0 < recall)

# Harmonic mean of precision and recall, per class, then macro-averaged.
f1score = 2 * precision * recall / (precision + recall)
for heading, value in (('f1-scores per class:', f1score),
                       ('f1-score:', np.mean(f1score))):
    print(heading)
    print(value)
Loading