Commit 53bc313e authored by Bryan  BRANCOTTE's avatar Bryan BRANCOTTE
Browse files

more elegant way to sort the queryset, WIP #105

parent 1ae594d0
def order_queryset_specifically(*, queryset, actual_order: dict):
"""
From https://stackoverflow.com/a/51291817/2144569
:param queryset: the queryset of objects to apply a specific order
:param actual_order: a dict where key is pk, and value is the rank of the associated object
:type dict
:return: the queryset annotate such as the order will follow actual_order
"""
from django.db.models import Case, When, Value, IntegerField
return queryset.annotate(
rank=Case(
*[When(pk=pk, then=Value(actual_order[pk])) for pk in actual_order.keys()],
default=Value(len(actual_order)),
output_field=IntegerField(),
),
).order_by('rank')
......@@ -35,6 +35,7 @@ from viralhostrangedb import forms, business_process, views_api
from viralhostrangedb import mixins
from viralhostrangedb import models
from viralhostrangedb.business_process import MessageImportationObserver
from viralhostrangedb.utils import order_queryset_specifically
def index(request):
......@@ -696,16 +697,20 @@ def download_responses(request):
aggregated_responses = json.loads(response.rendered_content.decode('utf-8'))
virus = form.cleaned_data["virus"]
if virus.exists():
actual_order = dict((int(o), i) for i, o in enumerate(request.GET.getlist("virus")))
virus = sorted(virus, key=lambda o: actual_order[o.pk])
virus = order_queryset_specifically(
queryset=virus,
actual_order=dict((int(o), i) for i, o in enumerate(request.GET.getlist("virus"))),
)
else:
virus = models.Virus.objects.filter(
pk__in=aggregated_responses.keys()
).order_by('pk')
host = form.cleaned_data["host"]
if host.exists():
actual_order = dict((int(o), i) for i, o in enumerate(request.GET.getlist("host")))
host = sorted(host, key=lambda o: actual_order[o.pk])
host = order_queryset_specifically(
queryset=host,
actual_order=dict((int(o), i) for i, o in enumerate(request.GET.getlist("host"))),
)
else:
host = models.Host.objects.filter(
pk__in=itertools.chain(*[d.keys() for d in aggregated_responses.values()])
......@@ -725,11 +730,11 @@ def download_responses(request):
# get at which position each host is, mandatory if data are sparse
col_pos = dict([(pk, i) for i, pk in enumerate(
[o.pk for o in host],
host.values_list('pk', flat=True),
start=virus_infection_ratio_shift,
)])
row_pos = dict([(pk, i) for i, pk in enumerate(
[o.pk for o in virus],
virus.values_list('pk', flat=True),
start=host_infection_ratio_shift,
)])
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment