From 2fe3bd45d0f090e4e7fbfcff76b3eb162117d86d Mon Sep 17 00:00:00 2001
From: Kenzo-Hugo Hillion <kenzo-hugo.hillion1@pasteur.fr>
Date: Mon, 2 Dec 2019 16:17:51 +0100
Subject: [PATCH] add route to retrieve counts of taxonomical annotations per
 level

---
 .../api/catalog/qparams_validators/gene.py    |  9 ++++
 backend/metagenedb/api/catalog/views/gene.py  | 46 ++++++++++++++++++-
 .../metagenedb/api/catalog/views/test_gene.py | 46 ++++++++++++++++++-
 .../metagenedb/apps/catalog/factory/gene.py   |  4 +-
 .../common/utils/mocks/metagenedb.py          |  9 ++++
 5 files changed, 110 insertions(+), 4 deletions(-)

diff --git a/backend/metagenedb/api/catalog/qparams_validators/gene.py b/backend/metagenedb/api/catalog/qparams_validators/gene.py
index 160adf7..9da2277 100644
--- a/backend/metagenedb/api/catalog/qparams_validators/gene.py
+++ b/backend/metagenedb/api/catalog/qparams_validators/gene.py
@@ -1,7 +1,12 @@
 from marshmallow import Schema, fields
+from marshmallow.validate import OneOf
 
 from metagenedb.common.django_default.qparams_validators import PaginatedQueryParams
 
+TAXA_CHOICES = [
+    'superkingdom', 'kingdom', 'phylum', 'class', 'order', 'family', 'genus', 'species'
+]
+
 
 class GeneLengthQueryParams(Schema):
     window_size = fields.Integer()
@@ -11,3 +16,7 @@ class GeneLengthQueryParams(Schema):
 class GeneQueryParams(PaginatedQueryParams):
     no_taxonomy = fields.Boolean()
     no_functions = fields.Boolean()
+
+
+class TaxCountQueryParams(Schema):
+    level = fields.String(validate=OneOf(choices=TAXA_CHOICES))
diff --git a/backend/metagenedb/api/catalog/views/gene.py b/backend/metagenedb/api/catalog/views/gene.py
index dcf485e..65f9b9b 100644
--- a/backend/metagenedb/api/catalog/views/gene.py
+++ b/backend/metagenedb/api/catalog/views/gene.py
@@ -1,3 +1,5 @@
+from collections import defaultdict
+
 from django.db.models import Max
 from drf_yasg import openapi
 from drf_yasg.utils import swagger_auto_schema
@@ -8,7 +10,7 @@ from rest_framework.status import HTTP_204_NO_CONTENT, HTTP_422_UNPROCESSABLE_EN
 
 from metagenedb.apps.catalog.models import Gene
 from metagenedb.api.catalog.filters import GeneFilter
-from metagenedb.api.catalog.qparams_validators.gene import GeneLengthQueryParams, GeneQueryParams
+from metagenedb.api.catalog.qparams_validators.gene import GeneLengthQueryParams, GeneQueryParams, TaxCountQueryParams
 from metagenedb.apps.catalog.serializers import GeneSerializer
 
 from .bulk_viewset import BulkViewSet
@@ -55,6 +57,8 @@ class GeneViewSet(BulkViewSet):
     DEFAULT_WINDOW_SIZE = 1000
     DEFAULT_STOP_AT = 10000
 
+    DEFAULT_LEVEL = 'phylum'
+
     def get_permissions(self):
         return super(self.__class__, self).get_permissions()
 
@@ -105,7 +109,6 @@ class GeneViewSet(BulkViewSet):
 
         window_size = query_params.get('window_size', self.DEFAULT_WINDOW_SIZE)
         stop_at = query_params.get('stop_at', self.DEFAULT_STOP_AT)
-        # df = read_frame(Gene.objects.all(), fieldnames=[self.GENE_LENGTH_COL])
         queryset = Gene.objects.all()
         if not queryset.exists():
             return Response(
@@ -115,3 +118,42 @@ class GeneViewSet(BulkViewSet):
         return Response(
             {'results': self._count_windows(queryset, window_size=window_size, stop_at=stop_at)}
         )
+
+    def _taxonomy_counts(self, queryset, level=DEFAULT_LEVEL):
+        filter_no_annotation = {f"taxonomy__{level}__isnull": True}
+        filter_annotation = {f"taxonomy__{level}__isnull": False}
+        value_to_retrieve = f'taxonomy__{level}__name'
+        taxonomy_counts = {}
+        taxonomy_counts['counts'] = defaultdict(lambda: 0)
+        taxonomy_counts['counts']['No annotation'] = queryset.filter(**filter_no_annotation).values().count()
+        if taxonomy_counts['counts']['No annotation'] == 0:
+            del taxonomy_counts['counts']['No annotation']
+        for value in queryset.filter(**filter_annotation).values(value_to_retrieve):
+            tax_name = value[value_to_retrieve]
+            taxonomy_counts['counts'][tax_name] += 1
+        return taxonomy_counts
+
+    @action(methods=['get'], detail=False)
+    def taxonomy_counts(self, request):
+        try:
+            query_params = TaxCountQueryParams().load(request.query_params)
+        except ValidationError as validation_error:
+            error_message = validation_error.normalized_messages()
+            error_message.update({
+                'allowed_query_params': ', '.join(GeneLengthQueryParams().declared_fields.keys())
+            })
+            return Response(error_message, status=HTTP_422_UNPROCESSABLE_ENTITY)
+
+        level = query_params.get('level', self.DEFAULT_LEVEL)
+        level = 'class_rank' if level == 'class' else level  # deal with class exception @TODO fix cleaner way
+        queryset = Gene.objects.all().select_related(f'taxonomy__{level}')
+        if not queryset.exists():
+            return Response(
+                {},
+                status=HTTP_204_NO_CONTENT
+            )
+        counts = self._taxonomy_counts(queryset, level=level)
+        counts['level'] = query_params.get('level', self.DEFAULT_LEVEL)
+        return Response(
+            {'results': counts}
+        )
diff --git a/backend/metagenedb/api/catalog/views/test_gene.py b/backend/metagenedb/api/catalog/views/test_gene.py
index 24e2ee2..e2e4421 100644
--- a/backend/metagenedb/api/catalog/views/test_gene.py
+++ b/backend/metagenedb/api/catalog/views/test_gene.py
@@ -5,7 +5,7 @@ from rest_framework import status
 from rest_framework.test import APITestCase
 from rest_framework_jwt.settings import api_settings
 
-from metagenedb.apps.catalog.factory import GeneFactory
+from metagenedb.apps.catalog.factory import GeneFactory, TaxonomyFactory
 from metagenedb.common.utils.mocks.metagenedb import MetageneDBCatalogGeneAPIMock
 
 
@@ -71,3 +71,47 @@ class TestCountWindowsAPI(APITestCase):
             'stop_at': 2000
         }
         self.assertDictEqual(self.gene_api.get_gene_length(params=query_params), expected_dict)
+
+
+class TestTaxonomyCountsAPI(APITestCase):
+
+    def setUp(self):
+        self.gene_api = MetageneDBCatalogGeneAPIMock(self.client)
+
+    def test_taxonomy_counts_no_content(self):
+        self.assertFalse(self.gene_api.get_tax_counts())
+
+    def test_taxonomy_counts_api(self):
+        tax_name = "TaxTest"
+        taxonomy = TaxonomyFactory(rank='phylum', name=tax_name)
+        taxonomy.phylum = taxonomy  # link taxonomy to itself as phylum
+        taxonomy.save()
+        gene = GeneFactory.create(taxonomy=taxonomy)  # noqa
+        expected_dict = {
+            'results': {
+                'level': 'phylum',
+                'counts': {
+                    tax_name: 1
+                }
+            }
+        }
+        self.assertDictEqual(self.gene_api.get_tax_counts(), expected_dict)
+
+    def test_taxonomy_counts_api_class_level(self):
+        tax_name = "TaxTest"
+        taxonomy = TaxonomyFactory(rank='class_rank', name=tax_name)
+        taxonomy.class_rank = taxonomy  # link taxonomy to itself as phylum
+        taxonomy.save()
+        gene = GeneFactory.create(taxonomy=taxonomy)  # noqa
+        expected_dict = {
+            'results': {
+                'level': 'class',
+                'counts': {
+                    tax_name: 1
+                }
+            }
+        }
+        query_params = {
+            'level': 'class'
+        }
+        self.assertDictEqual(self.gene_api.get_tax_counts(params=query_params), expected_dict)
diff --git a/backend/metagenedb/apps/catalog/factory/gene.py b/backend/metagenedb/apps/catalog/factory/gene.py
index ec337ea..7423229 100644
--- a/backend/metagenedb/apps/catalog/factory/gene.py
+++ b/backend/metagenedb/apps/catalog/factory/gene.py
@@ -1,9 +1,10 @@
-from factory import DjangoModelFactory, fuzzy
+from factory import DjangoModelFactory, SubFactory, fuzzy
 from faker import Factory
 
 from metagenedb.apps.catalog import models
 
 from .fuzzy_base import FuzzyLowerText
+from .taxonomy import TaxonomyFactory
 
 faker = Factory.create()
 
@@ -17,3 +18,4 @@ class GeneFactory(DjangoModelFactory):
     gene_id = FuzzyLowerText(prefix='gene-', length=15)
     gene_name = fuzzy.FuzzyText(prefix='name-', length=15)
     length = fuzzy.FuzzyInteger(200, 10000)
+    taxonomy = SubFactory(TaxonomyFactory)
diff --git a/backend/metagenedb/common/utils/mocks/metagenedb.py b/backend/metagenedb/common/utils/mocks/metagenedb.py
index 8f39b15..01f65ed 100644
--- a/backend/metagenedb/common/utils/mocks/metagenedb.py
+++ b/backend/metagenedb/common/utils/mocks/metagenedb.py
@@ -56,6 +56,15 @@ class MetageneDBCatalogGeneAPIMock(MetageneDBAPIMock):
             return {}
         return response.json()
 
+    def get_tax_counts(self, params=None):
+        reverse_path = f"{self.reverse_path}-taxonomy-counts"
+        response = self.client.get(reverse(reverse_path), params)
+        if response.status_code in self.BAD_REQUESTS:
+            raise HTTPError
+        if response.status_code == 204:  # no content
+            return {}
+        return response.json()
+
 
 class MetageneDBCatalogTaxonomyAPIMock(MetageneDBAPIMock):
     KEY_ID = 'gene_id'
-- 
GitLab