Commit 8e05be86 authored by Kenzo-Hugo Hillion's avatar Kenzo-Hugo Hillion
Browse files

Merge branch '67-statistics-on-genes' into 'dev'

Add more statistical graph about gene catalog

Closes #67

See merge request !23
parents 9650a34f 584a5f59
Pipeline #19270 passed with stages
in 2 minutes and 23 seconds
from marshmallow import Schema, fields from marshmallow import Schema, fields
from marshmallow.validate import OneOf
from metagenedb.common.django_default.qparams_validators import PaginatedQueryParams from metagenedb.common.django_default.qparams_validators import PaginatedQueryParams
TAXA_CHOICES = [
'superkingdom', 'kingdom', 'phylum', 'class', 'order', 'family', 'genus', 'species'
]
class GeneLengthQueryParams(Schema): class GeneLengthQueryParams(Schema):
window_size = fields.Integer() window_size = fields.Integer()
...@@ -11,3 +16,7 @@ class GeneLengthQueryParams(Schema): ...@@ -11,3 +16,7 @@ class GeneLengthQueryParams(Schema):
class GeneQueryParams(PaginatedQueryParams): class GeneQueryParams(PaginatedQueryParams):
no_taxonomy = fields.Boolean() no_taxonomy = fields.Boolean()
no_functions = fields.Boolean() no_functions = fields.Boolean()
class TaxCountQueryParams(Schema):
level = fields.String(validate=OneOf(choices=TAXA_CHOICES))
from django_pandas.io import read_frame from collections import defaultdict
from django.db.models import Max
from drf_yasg import openapi from drf_yasg import openapi
from drf_yasg.utils import swagger_auto_schema from drf_yasg.utils import swagger_auto_schema
from marshmallow.exceptions import ValidationError from marshmallow.exceptions import ValidationError
...@@ -8,9 +10,8 @@ from rest_framework.status import HTTP_204_NO_CONTENT, HTTP_422_UNPROCESSABLE_EN ...@@ -8,9 +10,8 @@ from rest_framework.status import HTTP_204_NO_CONTENT, HTTP_422_UNPROCESSABLE_EN
from metagenedb.apps.catalog.models import Gene from metagenedb.apps.catalog.models import Gene
from metagenedb.api.catalog.filters import GeneFilter from metagenedb.api.catalog.filters import GeneFilter
from metagenedb.api.catalog.qparams_validators.gene import GeneLengthQueryParams, GeneQueryParams from metagenedb.api.catalog.qparams_validators.gene import GeneLengthQueryParams, GeneQueryParams, TaxCountQueryParams
from metagenedb.apps.catalog.serializers import GeneSerializer from metagenedb.apps.catalog.serializers import GeneSerializer
from metagenedb.common.utils.df_operations import get_mask
from .bulk_viewset import BulkViewSet from .bulk_viewset import BulkViewSet
...@@ -56,18 +57,21 @@ class GeneViewSet(BulkViewSet): ...@@ -56,18 +57,21 @@ class GeneViewSet(BulkViewSet):
DEFAULT_WINDOW_SIZE = 1000 DEFAULT_WINDOW_SIZE = 1000
DEFAULT_STOP_AT = 10000 DEFAULT_STOP_AT = 10000
DEFAULT_LEVEL = 'phylum'
def get_permissions(self): def get_permissions(self):
return super(self.__class__, self).get_permissions() return super(self.__class__, self).get_permissions()
def _count_windows(self, df, window_size=DEFAULT_WINDOW_SIZE, window_col=GENE_LENGTH_COL, stop_at=DEFAULT_STOP_AT): def _count_windows(self, queryset, window_size=DEFAULT_WINDOW_SIZE, window_col=GENE_LENGTH_COL,
stop_at=DEFAULT_STOP_AT):
""" """
Count how many line of the df belong to each windows defined by the window_size for the window_col Count how many entries by performing one query per range
:param df: :param queryset:
:param window_col: column concerned by the window :param window_col: column concerned by the window
:param window_size: size of the window :param window_size: size of the window
:return: {'data': COUNTS_BY_WINDOW, 'labels': START-END} :return: {'data': COUNTS_BY_WINDOW, 'labels': START-END}
""" """
length_max = df[window_col].max() length_max = queryset.aggregate(Max('length')).get('length__max', 0)
stop_at = length_max if length_max < stop_at else stop_at stop_at = length_max if length_max < stop_at else stop_at
all_ranges = [[i, i + window_size] for i in range(0, stop_at + 1, window_size)] all_ranges = [[i, i + window_size] for i in range(0, stop_at + 1, window_size)]
all_ranges[-1][1] = length_max + 1 # last should contain all above the stop_at all_ranges[-1][1] = length_max + 1 # last should contain all above the stop_at
...@@ -75,7 +79,7 @@ class GeneViewSet(BulkViewSet): ...@@ -75,7 +79,7 @@ class GeneViewSet(BulkViewSet):
labels = [] labels = []
for rg in all_ranges: for rg in all_ranges:
labels.append(f"{rg[0]/1000}k-{rg[1]/1000}k") labels.append(f"{rg[0]/1000}k-{rg[1]/1000}k")
data.append(df[get_mask(df, rg, window_col)].count()[window_col]) data.append(queryset.filter(length__gte=rg[0], length__lt=rg[1]).count())
# Change labels # Change labels
labels[0] = f"<{labels[0].split('-')[1]}" labels[0] = f"<{labels[0].split('-')[1]}"
labels[-1] = f">{labels[-1].split('-')[0]}" labels[-1] = f">{labels[-1].split('-')[0]}"
...@@ -105,12 +109,51 @@ class GeneViewSet(BulkViewSet): ...@@ -105,12 +109,51 @@ class GeneViewSet(BulkViewSet):
window_size = query_params.get('window_size', self.DEFAULT_WINDOW_SIZE) window_size = query_params.get('window_size', self.DEFAULT_WINDOW_SIZE)
stop_at = query_params.get('stop_at', self.DEFAULT_STOP_AT) stop_at = query_params.get('stop_at', self.DEFAULT_STOP_AT)
df = read_frame(Gene.objects.all(), fieldnames=[self.GENE_LENGTH_COL]) queryset = Gene.objects.all()
if df.empty: if not queryset.exists():
return Response(
{},
status=HTTP_204_NO_CONTENT
)
return Response(
{'results': self._count_windows(queryset, window_size=window_size, stop_at=stop_at)}
)
def _taxonomy_counts(self, queryset, level=DEFAULT_LEVEL):
filter_no_annotation = {f"taxonomy__{level}__isnull": True}
filter_annotation = {f"taxonomy__{level}__isnull": False}
value_to_retrieve = f'taxonomy__{level}__name'
taxonomy_counts = {}
taxonomy_counts['counts'] = defaultdict(lambda: 0)
taxonomy_counts['counts']['No annotation'] = queryset.filter(**filter_no_annotation).values().count()
if taxonomy_counts['counts']['No annotation'] == 0:
del taxonomy_counts['counts']['No annotation']
for value in queryset.filter(**filter_annotation).values(value_to_retrieve):
tax_name = value[value_to_retrieve]
taxonomy_counts['counts'][tax_name] += 1
return taxonomy_counts
@action(methods=['get'], detail=False)
def taxonomy_counts(self, request):
try:
query_params = TaxCountQueryParams().load(request.query_params)
except ValidationError as validation_error:
error_message = validation_error.normalized_messages()
error_message.update({
'allowed_query_params': ', '.join(GeneLengthQueryParams().declared_fields.keys())
})
return Response(error_message, status=HTTP_422_UNPROCESSABLE_ENTITY)
level = query_params.get('level', self.DEFAULT_LEVEL)
level = 'class_rank' if level == 'class' else level # deal with class exception @TODO fix cleaner way
queryset = Gene.objects.all().select_related(f'taxonomy__{level}')
if not queryset.exists():
return Response( return Response(
{'results': {}}, {},
status=HTTP_204_NO_CONTENT status=HTTP_204_NO_CONTENT
) )
counts = self._taxonomy_counts(queryset, level=level)
counts['level'] = query_params.get('level', self.DEFAULT_LEVEL)
return Response( return Response(
{'results': self._count_windows(df, window_size=window_size, stop_at=stop_at)} {'results': counts}
) )
import pandas as pd
from django.contrib.auth.models import User from django.contrib.auth.models import User
from django.test import TestCase from django.test import TestCase
from django.urls import reverse from django.urls import reverse
...@@ -6,8 +5,7 @@ from rest_framework import status ...@@ -6,8 +5,7 @@ from rest_framework import status
from rest_framework.test import APITestCase from rest_framework.test import APITestCase
from rest_framework_jwt.settings import api_settings from rest_framework_jwt.settings import api_settings
from metagenedb.api.catalog.views.gene import GeneViewSet from metagenedb.apps.catalog.factory import GeneFactory, TaxonomyFactory
from metagenedb.apps.catalog.factory import GeneFactory
from metagenedb.common.utils.mocks.metagenedb import MetageneDBCatalogGeneAPIMock from metagenedb.common.utils.mocks.metagenedb import MetageneDBCatalogGeneAPIMock
...@@ -41,43 +39,17 @@ class TestGenes(TestCase): ...@@ -41,43 +39,17 @@ class TestGenes(TestCase):
self.assertEqual(resp.status_code, status.HTTP_200_OK) self.assertEqual(resp.status_code, status.HTTP_200_OK)
class TestCountWindows(TestCase):
def setUp(self):
self.window_col = "length"
self.df = pd.DataFrame(
[22, 29, 35],
columns=[self.window_col]
)
def test_simple_count_window10(self):
expected_dict = {
'labels': ['<0.01k', '0.01k-0.02k', '0.02k-0.03k', '>0.03k'],
'counts': [0, 0, 2, 1]
}
geneviewset = GeneViewSet()
test_dict = geneviewset._count_windows(self.df, 10, window_col=self.window_col)
self.assertDictEqual(test_dict, expected_dict)
def test_simple_count_window10_stop20(self):
expected_dict = {
'labels': ['<0.01k', '0.01k-0.02k', '>0.02k'],
'counts': [0, 0, 3]
}
geneviewset = GeneViewSet()
test_dict = geneviewset._count_windows(self.df, window_size=10,
window_col=self.window_col, stop_at=20)
self.assertDictEqual(test_dict, expected_dict)
class TestCountWindowsAPI(APITestCase): class TestCountWindowsAPI(APITestCase):
def setUp(self): def setUp(self):
self.gene_api = MetageneDBCatalogGeneAPIMock(self.client) self.gene_api = MetageneDBCatalogGeneAPIMock(self.client)
for i in range(2000, 4000, 350):
GeneFactory.create(length=i) def test_gene_length_no_content(self):
self.assertFalse(self.gene_api.get_gene_length())
def test_gene_length_api(self): def test_gene_length_api(self):
for i in range(2000, 4000, 350):
GeneFactory.create(length=i)
expected_dict = { expected_dict = {
'results': { 'results': {
'counts': [0, 0, 3, 3], 'counts': [0, 0, 3, 3],
...@@ -87,6 +59,8 @@ class TestCountWindowsAPI(APITestCase): ...@@ -87,6 +59,8 @@ class TestCountWindowsAPI(APITestCase):
self.assertDictEqual(self.gene_api.get_gene_length(), expected_dict) self.assertDictEqual(self.gene_api.get_gene_length(), expected_dict)
def test_gene_length_api_stop_at_2000(self): def test_gene_length_api_stop_at_2000(self):
for i in range(2000, 4000, 350):
GeneFactory.create(length=i)
expected_dict = { expected_dict = {
'results': { 'results': {
'counts': [0, 0, 6], 'counts': [0, 0, 6],
...@@ -97,3 +71,47 @@ class TestCountWindowsAPI(APITestCase): ...@@ -97,3 +71,47 @@ class TestCountWindowsAPI(APITestCase):
'stop_at': 2000 'stop_at': 2000
} }
self.assertDictEqual(self.gene_api.get_gene_length(params=query_params), expected_dict) self.assertDictEqual(self.gene_api.get_gene_length(params=query_params), expected_dict)
class TestTaxonomyCountsAPI(APITestCase):
def setUp(self):
self.gene_api = MetageneDBCatalogGeneAPIMock(self.client)
def test_taxonomy_counts_no_content(self):
self.assertFalse(self.gene_api.get_tax_counts())
def test_taxonomy_counts_api(self):
tax_name = "TaxTest"
taxonomy = TaxonomyFactory(rank='phylum', name=tax_name)
taxonomy.phylum = taxonomy # link taxonomy to itself as phylum
taxonomy.save()
gene = GeneFactory.create(taxonomy=taxonomy) # noqa
expected_dict = {
'results': {
'level': 'phylum',
'counts': {
tax_name: 1
}
}
}
self.assertDictEqual(self.gene_api.get_tax_counts(), expected_dict)
def test_taxonomy_counts_api_class_level(self):
tax_name = "TaxTest"
taxonomy = TaxonomyFactory(rank='class_rank', name=tax_name)
taxonomy.class_rank = taxonomy # link taxonomy to itself as phylum
taxonomy.save()
gene = GeneFactory.create(taxonomy=taxonomy) # noqa
expected_dict = {
'results': {
'level': 'class',
'counts': {
tax_name: 1
}
}
}
query_params = {
'level': 'class'
}
self.assertDictEqual(self.gene_api.get_tax_counts(params=query_params), expected_dict)
from factory import DjangoModelFactory, fuzzy from factory import DjangoModelFactory, SubFactory, fuzzy
from faker import Factory from faker import Factory
from metagenedb.apps.catalog import models from metagenedb.apps.catalog import models
from .fuzzy_base import FuzzyLowerText from .fuzzy_base import FuzzyLowerText
from .taxonomy import TaxonomyFactory
faker = Factory.create() faker = Factory.create()
...@@ -17,3 +18,4 @@ class GeneFactory(DjangoModelFactory): ...@@ -17,3 +18,4 @@ class GeneFactory(DjangoModelFactory):
gene_id = FuzzyLowerText(prefix='gene-', length=15) gene_id = FuzzyLowerText(prefix='gene-', length=15)
gene_name = fuzzy.FuzzyText(prefix='name-', length=15) gene_name = fuzzy.FuzzyText(prefix='name-', length=15)
length = fuzzy.FuzzyInteger(200, 10000) length = fuzzy.FuzzyInteger(200, 10000)
taxonomy = SubFactory(TaxonomyFactory)
...@@ -52,6 +52,17 @@ class MetageneDBCatalogGeneAPIMock(MetageneDBAPIMock): ...@@ -52,6 +52,17 @@ class MetageneDBCatalogGeneAPIMock(MetageneDBAPIMock):
response = self.client.get(reverse(reverse_path), params) response = self.client.get(reverse(reverse_path), params)
if response.status_code in self.BAD_REQUESTS: if response.status_code in self.BAD_REQUESTS:
raise HTTPError raise HTTPError
if response.status_code == 204: # no content
return {}
return response.json()
def get_tax_counts(self, params=None):
reverse_path = f"{self.reverse_path}-taxonomy-counts"
response = self.client.get(reverse(reverse_path), params)
if response.status_code in self.BAD_REQUESTS:
raise HTTPError
if response.status_code == 204: # no content
return {}
return response.json() return response.json()
......
<template>
<v-card class="pa-2">
<div class="card-body">
<div class="text-xs-center" v-if="noGraph">
<v-progress-circular
indeterminate
color="secondary"
></v-progress-circular>
</div>
<div>
<canvas :id="chartId"></canvas>
</div>
<v-layout
justify-space-around
row
>
<v-flex sm3>
<v-switch
v-model="hideLegend"
label="Hide legend">
</v-switch>
</v-flex>
<v-flex sm8>
<v-select
v-model="hideLabels"
:items="this.labels"
attach
chips
:label="hideLabelsLabel"
multiple
clearable
></v-select>
</v-flex>
</v-layout>
</div>
</v-card>
</template>
<script>
import Chart from 'chart.js';
export default {
props: {
doughnutData: {
type: Object,
required: true,
},
chartId: String,
},
data() {
return {
myChart: {},
options: {},
colors: [],
hideLegend: false,
labels: [],
hideLabels: [],
noGraph: true,
}
},
computed: {
hideLabelsLabel() {
if (Object.entries(this.doughnutData).length == 0) {
return "Hide fields";
} else {
return "Hide " + this.doughnutData.level;
}
},
displayLegend() {
return !this.hideLegend;
}
},
methods: {
createChart() {
const ctx = document.getElementById(this.chartId);
this.myChart = new Chart(ctx, {
type: 'doughnut',
});
},
updateChart() {
this.labels = Object.keys(this.doughnutData.data);
this.generateColorList();
this.updateChartOptions();
this.updateChartData();
},
generateColorList() { // # TODO get out of this and refactor -> return list of colors
this.color = [];
for (var i = 0; i<Object.keys(this.doughnutData.data).length; i++) {
this.colors.push(this.generateColor())
}
},
generateColor() { // Same as above
var letters = '0123456789ABCDEF';
var color = '#';
for (var i = 0; i < 6; i++) {
color += letters[Math.floor(Math.random() * 16)];
}
return color;
},
updateChartOptions() {
this.options = {
legend: {
display: this.displayLegend,
}
};
this.myChart.options = this.options;
this.myChart.update();
},
updateChartData() {
const dataDict = Object.assign({}, this.doughnutData.data);
for (let i=0; i < this.hideLabels.length; i++) {
delete dataDict[this.hideLabels[i]];
};
this.myChart.data = {
labels: Object.keys(dataDict),
datasets: [
{
data: Object.values(dataDict),
backgroundColor: this.colors,
},
],
};
this.myChart.update();
},
},
watch: {
doughnutData(val) {
if (this.noGraph) {
this.noGraph = false;
this.createChart();
}
this.updateChart();
},
hideLegend(val) {
this.updateChartOptions();
},
hideLabels() {
this.updateChartData();
}
},
};
</script>
<template> <template>
<v-flex xs12 md6 xl4> <v-card class="pa-2">
<v-toolbar <div class="card-body">
:class="histoData.class" <div class="text-xs-center" v-if="noGraph">
dark <v-progress-circular
v-if="histoData.data" indeterminate
> color="secondary"
<v-icon class="white--text">{{ histoData.icon }}</v-icon> ></v-progress-circular>
<v-toolbar-title>{{ histoData.title }}</v-toolbar-title>
</v-toolbar>
<v-card class="pa-2">
<div class="card-body">
<div v-if="histoData.data">
<canvas :id="histoData.chart_id"></canvas>
</div>
<div class="text-xs-center" v-else>
<v-progress-circular
indeterminate
color="secondary"
></v-progress-circular>
</div>
</div> </div>
</v-card> <div>
</v-flex> <canvas :id="chartId"></canvas>
</div>
</div>
</v-card>
</template> </template>
<script> <script>
...@@ -33,14 +23,27 @@ export default { ...@@ -33,14 +23,27 @@ export default {
type: Object, type: Object,
required: true, required: true,
}, },
chartId: String,
}, },
updated() { data() {
this.createChart(this.histoData.chart_id); return {
myChart: {},
options: {},
noGraph: true,
}
}, },
methods: { methods: {
createChart(chartId) { createChart() {
const ctx = document.getElementById(chartId); const ctx = document.getElementById(this.chartId);
const histoData = { this.myChart = new Chart(ctx, {
type: 'bar',
});
},
updateChart() {
this.updateChartData();
},
updateChartData() {
this.myChart.data = {
labels: this.histoData.labels, labels: this.histoData.labels,
datasets: [ datasets: [
{ {
...@@ -51,11 +54,16 @@ export default { ...@@ -51,11 +54,16 @@ export default {
}, },
], ],
}; };
// eslint-disable-next-line this.myChart.update();
const myChart = new Chart(ctx, { },
type: 'bar', },
data: histoData, watch: {
}); histoData(val) {
if (this.noGraph) {
this.noGraph = false;