confusion.py

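    """
    Compare predicted behaviour tags against ground truth.

    The script walks /data for assay directories that contain both a
    `groundtruth.label` and a `predicted.label` file, builds a confusion
    matrix per assay (saved as `confusion.csv` next to the label files),
    accumulates an overall confusion matrix, and prints per-class and
    mean f1-scores.
    """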
    import os
    import numpy as np
    from sklearn.metrics import confusion_matrix
    from taggingbackends.data.labels import Labels
    
    """
    Generic function for true labels possibly in the shape of tag arrays.
    """
    def index(labels, tags):
        if isinstance(tags, str):
            if tags == 'edited':
                # probably a (manual) tagging mistake
                return -1
            else:
                return labels.index(tags)
        else:
            for i, label in enumerate(labels):
                if label in tags:
                    return i
        print("Incompatible labels")
        print("  expected labels:")
        print(labels)
        print("  labels at a time step:")
        print(tags)
        return -1
    
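    # if HOST_UID/HOST_GID are set (e.g. by a container wrapper), the generated
    # files are handed back to the corresponding host user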
    uid, gid = os.getenv('HOST_UID', None), os.getenv('HOST_GID', None)
    if uid is not None:
        uid, gid = int(uid), int(gid)
    
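    # label set shared by all assays, and confusion matrix accumulated over all assays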
    labels = None
    cm = None
    
    fn_true = 'groundtruth.label'
    fn_pred = 'predicted.label'
    
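    # look for assay directories that contain both a ground-truth and a
    # predicted label file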
    for assay, _, files in os.walk('/data'):
        if fn_true in files and fn_pred in files:
            expected = Labels(os.path.join(assay, fn_true))
            predicted = Labels(os.path.join(assay, fn_pred))
    
            if labels is None:
                labels = predicted.labelspec
            else:
                assert labels == predicted.labelspec
    
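            # confusion matrix for the current assay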
            cm_ = None
            for larva in expected:
                # note: not all the larvae in `expected` may be in `predicted`
                if larva not in predicted:
                    continue
                y_pred = np.array([labels.index(label) for label in predicted[larva].values()])
                y_true = np.array([index(labels, tags) for tags in expected[larva].values()])
                # discard time steps whose ground-truth tag could not be resolved
                ok = 0 <= y_true
                cm__ = confusion_matrix(y_true[ok], y_pred[ok], labels=range(len(labels)))
                cm_ = cm__ if cm_ is None else cm_ + cm__

            assert cm_ is not None
            # fold the per-assay matrix into the overall confusion matrix once per
            # assay; doing this inside the larva loop would count earlier larvae
            # several times
            cm = cm_ if cm is None else cm + cm_
    
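            # save the per-assay confusion matrix as CSV, with the label names
            # as header row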
            path = os.path.join(assay, 'confusion.csv')
            with open(path, 'w') as f:
                f.write(",".join(labels))
                for row in cm_:
                    f.write("\n")
                    f.write(",".join([str(count) for count in row]))
    
            if uid is not None:
                os.chown(path, uid, gid)
    
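    # at least one assay with both label files must have been found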
    assert cm is not None
    print('labels:')
    print(labels)
    print('confusion matrix:')
    print(cm)
    
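    # per-class precision and recall; with scikit-learn's convention, rows of
    # `cm` are true classes and columns are predicted classes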
    precision = np.diag(cm) / cm.sum(axis=0)
    recall = np.diag(cm) / cm.sum(axis=1)
    assert np.all(0 < precision)
    assert np.all(0 < recall)
    f1score = 2 * precision * recall / (precision + recall)
    print('f1-scores per class:')
    print(f1score)
    print('f1-score:')
    print(np.mean(f1score))