diff --git a/backend/metagenedb/api/catalog/qparams_validators/gene.py b/backend/metagenedb/api/catalog/qparams_validators/gene.py
index 160adf7eeaa2f533eda16f1365f07da8b6817214..9da22772f3a331d8dcb8e49c17f51ad0acc0611c 100644
--- a/backend/metagenedb/api/catalog/qparams_validators/gene.py
+++ b/backend/metagenedb/api/catalog/qparams_validators/gene.py
@@ -1,7 +1,12 @@
 from marshmallow import Schema, fields
+from marshmallow.validate import OneOf
 
 from metagenedb.common.django_default.qparams_validators import PaginatedQueryParams
 
+TAXA_CHOICES = [
+    'superkingdom', 'kingdom', 'phylum', 'class', 'order', 'family', 'genus', 'species'
+]
+
 
 class GeneLengthQueryParams(Schema):
     window_size = fields.Integer()
@@ -11,3 +16,8 @@ class GeneLengthQueryParams(Schema):
 class GeneQueryParams(PaginatedQueryParams):
     no_taxonomy = fields.Boolean()
     no_functions = fields.Boolean()
+
+
+class TaxCountQueryParams(Schema):
+    """Query parameters of the taxonomy counts endpoint; `level` must be one of TAXA_CHOICES."""
+    level = fields.String(validate=OneOf(choices=TAXA_CHOICES))
diff --git a/backend/metagenedb/api/catalog/views/gene.py b/backend/metagenedb/api/catalog/views/gene.py
index dcf485ea771fd551a1be2c4b1240db40e34d0cdb..65f9b9bd6e73850ac5f0140a895368a20daf1f4c 100644
--- a/backend/metagenedb/api/catalog/views/gene.py
+++ b/backend/metagenedb/api/catalog/views/gene.py
@@ -1,3 +1,5 @@
+from collections import defaultdict
+
 from django.db.models import Max
 from drf_yasg import openapi
 from drf_yasg.utils import swagger_auto_schema
@@ -8,7 +10,7 @@ from rest_framework.status import HTTP_204_NO_CONTENT, HTTP_422_UNPROCESSABLE_EN
 
 from metagenedb.apps.catalog.models import Gene
 from metagenedb.api.catalog.filters import GeneFilter
-from metagenedb.api.catalog.qparams_validators.gene import GeneLengthQueryParams, GeneQueryParams
+from metagenedb.api.catalog.qparams_validators.gene import GeneLengthQueryParams, GeneQueryParams, TaxCountQueryParams
 from metagenedb.apps.catalog.serializers import GeneSerializer
 
 from .bulk_viewset import BulkViewSet
@@ -55,6 +57,8 @@ class GeneViewSet(BulkViewSet):
     DEFAULT_WINDOW_SIZE = 1000
     DEFAULT_STOP_AT = 10000
 
+    DEFAULT_LEVEL = 'phylum'
+
     def get_permissions(self):
         return super(self.__class__, self).get_permissions()
 
@@ -105,7 +109,6 @@ class GeneViewSet(BulkViewSet):
 
         window_size = query_params.get('window_size', self.DEFAULT_WINDOW_SIZE)
         stop_at = query_params.get('stop_at', self.DEFAULT_STOP_AT)
-        # df = read_frame(Gene.objects.all(), fieldnames=[self.GENE_LENGTH_COL])
         queryset = Gene.objects.all()
         if not queryset.exists():
             return Response(
@@ -115,3 +118,46 @@ class GeneViewSet(BulkViewSet):
         return Response(
             {'results': self._count_windows(queryset, window_size=window_size, stop_at=stop_at)}
         )
+
+    def _taxonomy_counts(self, queryset, level=DEFAULT_LEVEL):
+        """Count the genes of `queryset` per taxonomy name at `level`.
+
+        Genes for which `taxonomy__<level>` is null are summed under 'No annotation'.
+        """
+        filter_no_annotation = {f"taxonomy__{level}__isnull": True}
+        filter_annotation = {f"taxonomy__{level}__isnull": False}
+        value_to_retrieve = f'taxonomy__{level}__name'
+        counts = defaultdict(int)
+        no_annotation = queryset.filter(**filter_no_annotation).count()
+        if no_annotation:  # the key is only reported when it is non-zero
+            counts['No annotation'] = no_annotation
+        for value in queryset.filter(**filter_annotation).values(value_to_retrieve):
+            counts[value[value_to_retrieve]] += 1
+        return {'counts': counts}
+
+    @action(methods=['get'], detail=False)
+    def taxonomy_counts(self, request):
+        """Return the number of genes per taxonomy name at the requested `level`."""
+        try:
+            query_params = TaxCountQueryParams().load(request.query_params)
+        except ValidationError as validation_error:
+            error_message = validation_error.normalized_messages()
+            error_message.update({
+                'allowed_query_params': ', '.join(TaxCountQueryParams().declared_fields.keys())
+            })
+            return Response(error_message, status=HTTP_422_UNPROCESSABLE_ENTITY)
+
+        requested_level = query_params.get('level', self.DEFAULT_LEVEL)
+        # the relation for the 'class' level is named class_rank @TODO fix cleaner way
+        level = 'class_rank' if requested_level == 'class' else requested_level
+        queryset = Gene.objects.all().select_related(f'taxonomy__{level}')
+        if not queryset.exists():
+            return Response(
+                {},
+                status=HTTP_204_NO_CONTENT
+            )
+        counts = self._taxonomy_counts(queryset, level=level)
+        counts['level'] = requested_level
+        return Response(
+            {'results': counts}
+        )
diff --git a/backend/metagenedb/api/catalog/views/test_gene.py b/backend/metagenedb/api/catalog/views/test_gene.py
index 24e2ee252eaa0e6239ea21fd90d94d9ae40a0564..e2e442150d04da56e6b6a21d856929ed539cbc6b 100644
--- a/backend/metagenedb/api/catalog/views/test_gene.py
+++ b/backend/metagenedb/api/catalog/views/test_gene.py
@@ -5,7 +5,7 @@ from rest_framework import status
 from rest_framework.test import APITestCase
 from rest_framework_jwt.settings import api_settings
 
-from metagenedb.apps.catalog.factory import GeneFactory
+from metagenedb.apps.catalog.factory import GeneFactory, TaxonomyFactory
 from metagenedb.common.utils.mocks.metagenedb import MetageneDBCatalogGeneAPIMock
 
 
@@ -71,3 +71,47 @@ class TestCountWindowsAPI(APITestCase):
             'stop_at': 2000
         }
         self.assertDictEqual(self.gene_api.get_gene_length(params=query_params), expected_dict)
+
+
+class TestTaxonomyCountsAPI(APITestCase):
+
+    def setUp(self):
+        self.gene_api = MetageneDBCatalogGeneAPIMock(self.client)
+
+    def test_taxonomy_counts_no_content(self):
+        self.assertFalse(self.gene_api.get_tax_counts())
+
+    def test_taxonomy_counts_api(self):
+        tax_name = "TaxTest"
+        taxonomy = TaxonomyFactory(rank='phylum', name=tax_name)
+        taxonomy.phylum = taxonomy  # link taxonomy to itself as phylum
+        taxonomy.save()
+        GeneFactory.create(taxonomy=taxonomy)
+        expected_dict = {
+            'results': {
+                'level': 'phylum',
+                'counts': {
+                    tax_name: 1
+                }
+            }
+        }
+        self.assertDictEqual(self.gene_api.get_tax_counts(), expected_dict)
+
+    def test_taxonomy_counts_api_class_level(self):
+        tax_name = "TaxTest"
+        taxonomy = TaxonomyFactory(rank='class_rank', name=tax_name)
+        taxonomy.class_rank = taxonomy  # link taxonomy to itself as class
+        taxonomy.save()
+        GeneFactory.create(taxonomy=taxonomy)
+        expected_dict = {
+            'results': {
+                'level': 'class',
+                'counts': {
+                    tax_name: 1
+                }
+            }
+        }
+        query_params = {
+            'level': 'class'
+        }
+        self.assertDictEqual(self.gene_api.get_tax_counts(params=query_params), expected_dict)
diff --git a/backend/metagenedb/apps/catalog/factory/gene.py b/backend/metagenedb/apps/catalog/factory/gene.py
index ec337ea510d7406d6a3a6c9acf536a56380091ff..7423229bfb649f1fb7a58587ba2af7390e012e85 100644
--- a/backend/metagenedb/apps/catalog/factory/gene.py
+++ b/backend/metagenedb/apps/catalog/factory/gene.py
@@ -1,9 +1,10 @@
-from factory import DjangoModelFactory, fuzzy
+from factory import DjangoModelFactory, SubFactory, fuzzy
 from faker import Factory
 
 from metagenedb.apps.catalog import models
 
 from .fuzzy_base import FuzzyLowerText
+from .taxonomy import TaxonomyFactory
 
 faker = Factory.create()
 
@@ -17,3 +18,4 @@ class GeneFactory(DjangoModelFactory):
     gene_id = FuzzyLowerText(prefix='gene-', length=15)
     gene_name = fuzzy.FuzzyText(prefix='name-', length=15)
     length = fuzzy.FuzzyInteger(200, 10000)
+    taxonomy = SubFactory(TaxonomyFactory)
diff --git a/backend/metagenedb/common/utils/mocks/metagenedb.py b/backend/metagenedb/common/utils/mocks/metagenedb.py
index 8f39b15b95c8a352579b214d1c68a16d0be03cc1..01f65ed17b35509a45149d9a4cd93ec1fc32eef4 100644
--- a/backend/metagenedb/common/utils/mocks/metagenedb.py
+++ b/backend/metagenedb/common/utils/mocks/metagenedb.py
@@ -56,6 +56,16 @@ class MetageneDBCatalogGeneAPIMock(MetageneDBAPIMock):
             return {}
         return response.json()
 
+    def get_tax_counts(self, params=None):
+        """GET the taxonomy-counts route; return its JSON body, or {} on a 204 response."""
+        reverse_path = f"{self.reverse_path}-taxonomy-counts"
+        response = self.client.get(reverse(reverse_path), params)
+        if response.status_code in self.BAD_REQUESTS:
+            raise HTTPError
+        if response.status_code == 204:  # no content
+            return {}
+        return response.json()
+
 
 class MetageneDBCatalogTaxonomyAPIMock(MetageneDBAPIMock):
     KEY_ID = 'gene_id'