Commit 1df5541d authored by François LAURENT

confusion matrices

parent 65965546
2 merge requests: !9 Set of commits to be tagged 0.16, !7 Save larva_dataset file along with model files
Pipeline #117147 passed
@@ -312,11 +312,11 @@ version = "1.9.2"
 [[deps.PlanarLarvae]]
 deps = ["DelimitedFiles", "HDF5", "JSON3", "LinearAlgebra", "MAT", "Meshes", "OrderedCollections", "Random", "SHA", "StaticArrays", "Statistics", "StatsBase", "StructTypes"]
-git-tree-sha1 = "a6ced965b03efe596835f093d0ffbf1e8d991d50"
-repo-rev = "v0.14a2"
-repo-url = "https://gitlab.pasteur.fr/nyx/planarlarvae.jl"
+git-tree-sha1 = "3672831638c71be6e111ffde3376580d0d2bfaf5"
+repo-rev = "3bffeb68ab85f82c048bcb5aea3d19048e211ef9"
+repo-url = "https://gitlab.pasteur.fr/nyx/PlanarLarvae.jl"
 uuid = "c2615984-ef14-4d40-b148-916c85b43307"
-version = "0.14.0-a"
+version = "0.14.0"
 [[deps.PrecompileTools]]
 deps = ["Preferences"]
...
import os
from glob import glob
import numpy as np
from sklearn.metrics import confusion_matrix
from taggingbackends.data.labels import Labels
from taggingbackends.data.dataset import LarvaDataset
from taggingbackends.explorer import BackendExplorer
"""
Generic function for true labels possibly in the shape of tag arrays.
"""
def index(labels, tags):
if isinstance(tags, str):
if tags == 'edited':
# probably a (manual) tagging mistake
return -1
else:
return labels.index(tags)
else:
for i, label in enumerate(labels):
if label in tags:
return i
print("Incompatible labels")
print(" expected labels:")
print(labels)
print(" labels at a time step:")
print(tags)
return -1
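# Illustrative only (not part of the original script): with hypothetical labels
# ['run', 'cast', 'hunch'], `index` behaves as follows:
#   index(['run', 'cast', 'hunch'], 'cast')             ->  1
#   index(['run', 'cast', 'hunch'], 'edited')           -> -1  (discarded later)
#   index(['run', 'cast', 'hunch'], ['cast', 'small'])  ->  1  (first expected label found in the tags)
#   index(['run', 'cast', 'hunch'], ['small'])          -> -1  (prints "Incompatible labels")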
# if HOST_UID/HOST_GID are set (e.g. when running in a container), output files
# are chown'ed back to the host user
uid, gid = os.getenv('HOST_UID', None), os.getenv('HOST_GID', None)
if uid is not None:
    uid, gid = int(uid), int(gid)

labels = None
cm = None
fn_true = 'groundtruth.label'
fn_pred = 'predicted.label'
for assay, _, files in os.walk('/data'):
    # process only the directories that contain both ground-truth and predicted labels
    if fn_true in files and fn_pred in files:
        expected = Labels(os.path.join(assay, fn_true))
        predicted = Labels(os.path.join(assay, fn_pred))
        if labels is None:
            labels = predicted.labelspec
        else:
            assert labels == predicted.labelspec
        cm_ = None
        for larva in expected:
            # note: not every larva in `expected` is necessarily present in `predicted`
            y_pred = np.array([labels.index(label) for label in predicted[larva].values()])
            y_true = np.array([index(labels, tags) for tags in expected[larva].values()])
            # discard time steps with unusable ground truth (index returned -1)
            ok = 0 <= y_true
            cm__ = confusion_matrix(y_true[ok], y_pred[ok], labels=range(len(labels)))
            cm_ = cm__ if cm_ is None else cm_ + cm__
        cm = cm_ if cm is None else cm + cm_
        assert cm_ is not None
        # save the per-assay confusion matrix as csv, with the labels as header
        path = os.path.join(assay, 'confusion.csv')
        with open(path, 'w') as f:
            f.write(",".join(labels))
            for row in cm_:
                f.write("\n")
                f.write(",".join([str(count) for count in row]))
        if uid is not None:
            os.chown(path, uid, gid)
assert cm is not None
print('labels:')
print(labels)
print('confusion matrix:')
print(cm)
# rows of `cm` are true classes, columns are predicted classes, hence
# precision = diagonal / column sums and recall = diagonal / row sums
precision = np.diag(cm) / cm.sum(axis=0)
recall = np.diag(cm) / cm.sum(axis=1)
assert np.all(0 < precision)
assert np.all(0 < recall)
f1score = 2 * precision * recall / (precision + recall)
print('f1-scores per class:')
print(f1score)
print('f1-score:')
# unweighted (macro) average of the per-class f1-scores
print(np.mean(f1score))
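For reference, the metrics printed at the end follow the usual definitions: with cm[i, j] counting time steps of true class i predicted as class j, precision is the diagonal over the column sums, recall the diagonal over the row sums, and the reported f1-score is the unweighted (macro) mean. A minimal sketch, separate from the script above and using a made-up 2x2 confusion matrix chosen for illustration only:

import numpy as np

cm = np.array([[8, 2],    # true class 0: 8 correct, 2 confused with class 1
               [1, 9]])   # true class 1: 1 confused with class 0, 9 correct
precision = np.diag(cm) / cm.sum(axis=0)   # [8/9, 9/11] ~ [0.889, 0.818]
recall = np.diag(cm) / cm.sum(axis=1)      # [8/10, 9/10] = [0.8, 0.9]
f1 = 2 * precision * recall / (precision + recall)
print(f1, np.mean(f1))                     # per-class f1-scores and macro f1-score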