views_api.py 34.6 KB
Newer Older
Bryan  BRANCOTTE's avatar
Bryan BRANCOTTE committed
1
import itertools
2
import math
3
import re
4
import time
5
from functools import reduce
6
from typing import List
7
from urllib.parse import urlencode
8

9
from django.contrib.auth.models import AnonymousUser
Bryan  BRANCOTTE's avatar
Bryan BRANCOTTE committed
10
from django.db import connection
11
12
from django.db.models import Avg, Q, Count, Func, When, IntegerField, Case, Value, Max, Min, CharField, FloatField, \
    BooleanField
13
from django.db.models.functions import Concat, Coalesce, Replace, Length, Upper, Cast
14
from django.db.models.lookups import Transform
Bryan  BRANCOTTE's avatar
Bryan BRANCOTTE committed
15
from django.http import JsonResponse, HttpResponseBadRequest
16
from django.urls import reverse
17
from rest_framework import views as drf_views, viewsets
18
19
20
from rest_framework.filters import BaseFilterBackend
from rest_framework.response import Response

21
from viralhostrangedb import models, mixins, serializers, forms, business_process
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39


class MyAPIView(drf_views.APIView):
    """Minimal APIView exposing ``get_queryset``/``filter_queryset`` so filter backends can be used.

    Subclasses set ``queryset`` and ``filter_backends`` and implement ``get``.
    """
    filter_backends = []
    queryset = None

    def get_queryset(self):
        """Return the class-level ``queryset`` unchanged."""
        return self.queryset

    def filter_queryset(self, queryset):
        """Pass *queryset* through each backend of ``filter_backends``, in declaration order."""
        return reduce(
            lambda filtered, backend: backend().filter_queryset(self.request, filtered, self),
            list(self.filter_backends),
            queryset,
        )

    def get(self, request, *args, **kwargs):
        raise NotImplementedError()


44
45
46
47
48
49
50
51
def to_int_array(input_list: List):
    """Lazily yield ``int(item)`` for every item of *input_list* that converts cleanly.

    Items raising ``ValueError`` on conversion (e.g. ``""`` or ``"abc"``) are skipped.
    """
    for candidate in input_list:
        try:
            converted = int(candidate)
        except ValueError:
            continue
        yield converted


Bryan  BRANCOTTE's avatar
Bryan BRANCOTTE committed
52
53
54
55
56
57
58
59
def filter_ds(queryset, path_to_data_source, qd, **kwargs):
    """Restrict *queryset* to the data sources whose pk is listed in the ``ds`` parameter.

    ``ds`` may be repeated and/or comma-separated (``?ds=1,2&ds=3``); values that are
    not integers are ignored. Without ``ds`` the queryset is returned unchanged.

    :param queryset: queryset to filter
    :param path_to_data_source: lookup prefix leading to the data source (e.g. ``"data_source__"``)
    :param qd: query dict of the request (``request.GET``)
    :return: the (possibly) filtered queryset
    """
    if "ds" not in qd:
        return queryset
    # flatten every occurrence of the parameter, each one possibly comma-separated
    raw_values = [value for entry in qd.getlist("ds") for value in entry.split(',')]
    return queryset.filter(**{path_to_data_source + "pk__in": list(to_int_array(raw_values))})


def filter_owner(queryset, path_to_data_source, qd, **kwargs):
    """Restrict *queryset* to data sources owned by the users listed in the ``owner`` parameter.

    ``owner`` may be repeated and/or comma-separated; values that are not integers are
    ignored. Without ``owner`` the queryset is returned unchanged.

    :param queryset: queryset to filter
    :param path_to_data_source: lookup prefix leading to the data source (e.g. ``"data_source__"``)
    :param qd: query dict of the request (``request.GET``)
    :return: the (possibly) filtered queryset
    """
    if "owner" not in qd:
        return queryset
    # flatten every occurrence of the parameter, each one possibly comma-separated
    raw_values = [value for entry in qd.getlist("owner") for value in entry.split(',')]
    return queryset.filter(**{path_to_data_source + "owner__pk__in": list(to_int_array(raw_values))})


def filter_host(queryset, path_to_host, qd, **kwargs):
    """Restrict *queryset* to the hosts whose pk is listed in the ``host`` parameter.

    ``host`` may be repeated and/or comma-separated; values that are not integers are
    ignored. Without ``host`` the queryset is returned unchanged.

    :param queryset: queryset to filter
    :param path_to_host: lookup prefix leading to the host (``""`` when filtering hosts directly)
    :param qd: query dict of the request (``request.GET``)
    :return: the (possibly) filtered queryset
    """
    if "host" not in qd:
        return queryset
    # flatten every occurrence of the parameter, each one possibly comma-separated
    raw_values = [value for entry in qd.getlist("host") for value in entry.split(',')]
    return queryset.filter(**{path_to_host + "pk__in": list(to_int_array(raw_values))})


def filter_virus(queryset, path_to_virus, qd, **kwargs):
    """Restrict *queryset* to the viruses whose pk is listed in the ``virus`` parameter.

    ``virus`` may be repeated and/or comma-separated; values that are not integers are
    ignored. Without ``virus`` the queryset is returned unchanged.

    :param queryset: queryset to filter
    :param path_to_virus: lookup prefix leading to the virus (``""`` when filtering viruses directly)
    :param qd: query dict of the request (``request.GET``)
    :return: the (possibly) filtered queryset
    """
    if "virus" not in qd:
        return queryset
    # flatten every occurrence of the parameter, each one possibly comma-separated
    raw_values = [value for entry in qd.getlist("virus") for value in entry.split(',')]
    return queryset.filter(**{path_to_virus + "pk__in": list(to_int_array(raw_values))})


def filter_only_published_data(queryset, path_to_data_source, qd, **kwargs):
    """When ``only_published_data`` is present, keep rows whose data source has a publication URL.

    The parameter value is ignored; only its presence matters.
    """
    if "only_published_data" in qd:
        lookup = path_to_data_source + "publication_url__isnull"
        queryset = queryset.filter(**{lookup: False})
    return queryset
Bryan  BRANCOTTE's avatar
Bryan BRANCOTTE committed
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110


def filter_only_virus_ncbi_id(queryset, path_to_virus, qd, **kwargs):
    """When ``only_virus_ncbi_id`` is present, keep rows whose virus has ``is_ncbi_identifier_value`` set."""
    if "only_virus_ncbi_id" in qd:
        lookup = path_to_virus + "is_ncbi_identifier_value"
        queryset = queryset.filter(**{lookup: True})
    return queryset


def filter_only_host_ncbi_id(queryset, path_to_host, qd, **kwargs):
    """When ``only_host_ncbi_id`` is present, keep rows whose host has ``is_ncbi_identifier_value`` set."""
    if "only_host_ncbi_id" in qd:
        lookup = path_to_host + "is_ncbi_identifier_value"
        queryset = queryset.filter(**{lookup: True})
    return queryset


def filter_life_domain(queryset, path_to_data_source, qd, **kwargs):
    """Keep rows whose data source ``life_domain`` equals the ``life_domain`` parameter.

    A missing or empty ``life_domain`` parameter leaves the queryset unchanged.
    """
    if "life_domain" in qd and len(qd["life_domain"]) > 0:
        lookup = path_to_data_source + "life_domain"
        return queryset.filter(**{lookup: qd["life_domain"]})
    return queryset


class MetaFilterResponseBackend(BaseFilterBackend):
    """Apply all the GET-parameter driven filters (``ds``, ``owner``, ``host``, ``virus``, ...).

    Subclasses only override ``paths``: the lookup prefixes leading from the filtered
    model to its data source, virus and host.
    """
    paths = {
        "path_to_data_source": "data_source__",
        "path_to_virus": "virus__",
        "path_to_host": "host__",
    }

    def filter_queryset(self, request, queryset, view):
        query_dict = request.GET
        # applied in this exact order, each one receiving the output of the previous one
        for apply_filter in (
                filter_ds,
                filter_owner,
                filter_host,
                filter_virus,
                filter_only_published_data,
                filter_only_virus_ncbi_id,
                filter_only_host_ncbi_id,
                filter_life_domain,
        ):
            queryset = apply_filter(queryset=queryset, qd=query_dict, **self.paths)
        return queryset


Bryan  BRANCOTTE's avatar
Bryan BRANCOTTE committed
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
class FromGetParamsFilterResponseBackend(MetaFilterResponseBackend):
    """Response filter that refuses to return everything when nothing narrows the query.

    Unless one of ``allow_overflow``, ``ds``, ``host`` or ``virus`` is among the GET
    parameters, an empty queryset is returned.
    """
    paths = {
        "path_to_data_source": "data_source__",
        "path_to_virus": "virus__",
        "path_to_host": "host__",
    }

    def filter_queryset(self, request, queryset, view):
        narrowing_params = ("allow_overflow", "ds", "host", "virus")
        if any(param in request.GET for param in narrowing_params):
            return super().filter_queryset(request, queryset, view)
        return queryset.none()


class FromGetParamsFilterVirusBackend(MetaFilterResponseBackend):
    """GET-parameter filters for a queryset of viruses (hosts reached through responses)."""
    paths = {
        "path_to_data_source": "data_source__",
        "path_to_virus": "",
        "path_to_host": "responseindatasource__host__",
    }


class FromGetParamsFilterHostBackend(MetaFilterResponseBackend):
    """GET-parameter filters for a queryset of hosts (viruses reached through responses)."""
    paths = {
        "path_to_data_source": "data_source__",
        "path_to_virus": "responseindatasource__virus__",
        "path_to_host": "",
    }


157
158
159
160
161
162
class FromGetParamsFilterDataSourceBackend(MetaFilterResponseBackend):
    """GET-parameter filters for a queryset of data sources."""
    paths = {
        "path_to_data_source": "",
        "path_to_virus": "virus__",
        "path_to_host": "host__",
    }
163
164


165
class OnlyPublicOrGrantedOrOwnedFilterBackend(BaseFilterBackend):
    """Delegate visibility filtering to ``mixins.only_public_or_granted_or_owned_queryset_filter``.

    The filtered model reaches its data source through ``data_source__``.
    """

    def filter_queryset(self, request, queryset, view):
        visibility_filter = mixins.only_public_or_granted_or_owned_queryset_filter
        return visibility_filter(
            self,
            request=request,
            queryset=queryset,
            path_to_data_source="data_source__",
        )


175
176
class OnlyMappedResponseFilterBackend(BaseFilterBackend):
    """Drop responses whose global response is the "not mapped yet" value."""

    def filter_queryset(self, request, queryset, view):
        not_mapped_pk = models.GlobalViralHostResponseValue.get_not_mapped_yet_pk()
        return queryset.filter(~Q(response__pk=not_mapped_pk))
178
179


180
181
182
class Round(Func):
    """SQL ``ROUND(expression, precision)`` with a float result by default.

    Outside SQLite the first argument is cast with ``::numeric`` (via ``arg_joiner``),
    presumably because PostgreSQL only provides ``round(numeric, integer)``. SQLite
    does not understand that cast, so ``as_sqlite`` joins with a plain comma.
    NOTE(review): other backends (e.g. MySQL) would also receive ``::numeric`` — confirm
    only PostgreSQL and SQLite are used.
    """
    function = 'ROUND'
    arity = 2
    arg_joiner = '::numeric, '

    def __init__(self, *expressions, **extra):
        # Default to a float, but let callers choose another output_field instead of
        # failing with "got multiple values for keyword argument 'output_field'".
        extra.setdefault('output_field', FloatField())
        super().__init__(*expressions, **extra)

    def as_sqlite(self, compiler, connection, **extra_context):
        return super().as_sqlite(compiler, connection, arg_joiner=", ", **extra_context)
190
191


192
class AggregatedResponseViewSet(MyAPIView):
    """Mean response per (virus, host) pair across the visible, mapped responses.

    Answers ``{virus_pk: {host_pk: {"val": mean rounded to 2 decimals, "diff": number of
    distinct response values}}}``.
    """
    queryset = models.ViralHostResponseValueInDataSource.objects
    filter_backends = [OnlyMappedResponseFilterBackend, FromGetParamsFilterResponseBackend,
                       OnlyPublicOrGrantedOrOwnedFilterBackend]

    def get(self, request):
        rows = (
            self.filter_queryset(self.get_queryset())
            .values("virus__pk", "host__pk")
            .annotate(val=Round(Avg('response__value'), 2))
            .annotate(diff=Count('response__value', distinct=True))
            .order_by('virus__pk')
        )

        table = {}
        for row in rows:
            # rows are ordered by virus, so each virus gets a single host dict
            per_host = table.setdefault(row['virus__pk'], {})
            per_host[row['host__pk']] = dict(val=row['val'], diff=row['diff'])
        return Response(table)
216
217
218


class CompleteResponseViewSet(MyAPIView):
    """Every visible, mapped response value, per virus, host and data source.

    Answers ``{virus_pk: {host_pk: {data_source_pk: response_value}}}``.
    """
    queryset = models.ViralHostResponseValueInDataSource.objects
    filter_backends = [OnlyMappedResponseFilterBackend, FromGetParamsFilterResponseBackend,
                       OnlyPublicOrGrantedOrOwnedFilterBackend]

    def get(self, request):
        rows = self.filter_queryset(self.get_queryset()) \
            .values("virus__pk", "host__pk", "data_source__pk", "response__value") \
            .order_by('virus__pk', 'host__pk', 'data_source__pk', )

        table = {}
        for row in rows:
            per_host = table.setdefault(row['virus__pk'], {})
            per_data_source = per_host.setdefault(row['host__pk'], {})
            per_data_source[row['data_source__pk']] = row['response__value']
        return Response(table)
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267


class PKFromGetParamsFilterBackend(BaseFilterBackend):
    """Restrict the queryset to the pks listed in the ``pks``, ``pk``, ``ids`` or ``id`` parameters.

    Each parameter holds a comma-separated list of ids. When several of these parameters
    are given, their restrictions are all applied (intersection). Values that are not
    integers are ignored, as in the other GET filters of this module, instead of reaching
    the database lookup as invalid ids.
    """
    pk_params = ("pks", "pk", "ids", "id")

    def filter_queryset(self, request, queryset, view):
        for param in self.pk_params:
            if param in request.GET:
                wanted_pks = list(to_int_array(request.GET[param].split(',')))
                queryset = queryset.filter(pk__in=wanted_pks)
        return queryset


class VirusViewSet(viewsets.ReadOnlyModelViewSet):
    """Read-only virus endpoint.

    Backends, applied in order:
    - visibility (``OnlyPublicOrGrantedOrOwnedFilterBackend``);
    - explicit pks (``pk``/``pks``/``id``/``ids``);
    - the GET filters seen from a virus (``FromGetParamsFilterVirusBackend``).
    """
    serializer_class = serializers.VirusSerializer
    queryset = models.Virus.objects
    filter_backends = [
        OnlyPublicOrGrantedOrOwnedFilterBackend,
        PKFromGetParamsFilterBackend,
        FromGetParamsFilterVirusBackend,
    ]


274
275
276
class HostViewSet(viewsets.ReadOnlyModelViewSet):
    """Read-only host endpoint.

    Backends, applied in order:
    - visibility (``OnlyPublicOrGrantedOrOwnedFilterBackend``);
    - explicit pks (``pk``/``pks``/``id``/``ids``);
    - the GET filters seen from a host (``FromGetParamsFilterHostBackend``).
    """
    serializer_class = serializers.HostSerializer
    queryset = models.Host.objects
    filter_backends = [
        OnlyPublicOrGrantedOrOwnedFilterBackend,
        PKFromGetParamsFilterBackend,
        FromGetParamsFilterHostBackend,
    ]


284
class OnlyPublicOrGrantedOrOwnedDataSourceFilterBackend(BaseFilterBackend):
    """Visibility filter for querysets of data sources.

    Calls ``mixins.only_public_or_granted_or_owned_queryset_filter`` without a
    ``path_to_data_source``, i.e. relying on the helper's default for that argument.
    """

    def filter_queryset(self, request, queryset, view):
        visibility_filter = mixins.only_public_or_granted_or_owned_queryset_filter
        return visibility_filter(self, request=request, queryset=queryset)


class DataSourceViewSet(viewsets.ReadOnlyModelViewSet):
    """Read-only data source endpoint.

    Backends, applied in order:
    - visibility;
    - explicit pks (``pk``/``pks``/``id``/``ids``);
    - the GET filters seen from a data source.
    """
    serializer_class = serializers.DataSourceSerializer
    queryset = models.DataSource.objects
    filter_backends = [
        OnlyPublicOrGrantedOrOwnedDataSourceFilterBackend,
        PKFromGetParamsFilterBackend,
        FromGetParamsFilterDataSourceBackend,
    ]
301
302
303


class VirusInfectionRatioViewSet(MyAPIView):
    """Infection ratios of viruses (over hosts) or of hosts (over viruses).

    ``slug`` must be ``"virus"`` or ``"host"``. When ``slug_pk`` is given, only responses
    involving that virus/host are used.

    Each counted group is flagged as infected or not:
    - groups are (host, virus) pairs, plus the data source when
      ``data_source_aggregated`` is False;
    - GET ``weak_infection``: any response above the lowest mappable value is an
      infection; otherwise only responses reaching the highest mappable value are;
    - GET ``agreed_infection``: a group is infected only if all its responses are
      (Min); otherwise one infected response is enough (Max).

    The result maps a pk to ``{"ratio": infected / total, "total": total}``:
    - when ``data_source_aggregated`` is True, keyed by the ``slug`` pk;
    - otherwise keyed by the other slug's pk, returned under that slug's name next to
      per-data-source ratios under ``"data_source"``.
    """
    data_source_aggregated = True
    queryset = models.ViralHostResponseValueInDataSource.objects
    filter_backends = [OnlyMappedResponseFilterBackend, FromGetParamsFilterResponseBackend,
                       OnlyPublicOrGrantedOrOwnedFilterBackend]

    def get(self, request, slug=None, slug_pk=None, *args, **kwargs):
        queryset = self.get_queryset()

        other_slug = {'virus': 'host', 'host': 'virus'}.get(slug)
        if other_slug is None:
            return HttpResponseBadRequest("unknown slug '%s'" % slug)

        # filter only for the current slug
        if slug_pk is not None:
            queryset = queryset.filter(**{slug + "__pk": int(str(slug_pk))})
            # FromGetParamsFilterResponseBackend returns an empty queryset when no GET
            # parameter narrows the query; the filter above already does, so explicitly
            # allow it. Note: this replaces request.GET with a modified copy.
            request.GET = request.GET.copy()
            request.GET.update({"allow_overflow": True})

        # filter with the backends
        queryset = self.filter_queryset(queryset)

        min_level = self._get_min_infection_level(request)

        # annotating the query such as infection is now 0/1
        queryset = queryset.annotate(infection=Case(
            When(response__value__gte=min_level, then=1),
            default=Value(0),
            output_field=IntegerField(),
        ))

        # do we enforce consensus among the responses of a group or not
        if "agreed_infection" in request.GET:
            infection_aggregated = Min('infection')
        else:
            infection_aggregated = Max('infection')

        if self.data_source_aggregated:
            queryset = queryset.values('host__pk', 'virus__pk')
            # one counter per pk of the requested slug
            aggregation_key = slug + "__pk"
            infection_aggregated_counters_on_data_source = None
        else:
            queryset = queryset.values('host__pk', 'virus__pk', 'data_source__pk')
            # one counter per pk of the other slug: for a host, we get one per virus
            aggregation_key = other_slug + "__pk"
            infection_aggregated_counters_on_data_source = {}
        queryset = queryset.annotate(infection_aggregated=infection_aggregated)

        # counters are [number of non infections, number of infections]
        infection_aggregated_counters = {}
        for o in queryset:
            infected = o['infection_aggregated']
            infection_aggregated_counters.setdefault(o[aggregation_key], [0, 0])[infected] += 1
            if infection_aggregated_counters_on_data_source is not None:
                infection_aggregated_counters_on_data_source \
                    .setdefault(o['data_source__pk'], [0, 0])[infected] += 1

        # build the final ratios
        self.aggregate_infection_counters(infection_aggregated_counters)

        if self.data_source_aggregated:
            return Response(infection_aggregated_counters)

        self.aggregate_infection_counters(infection_aggregated_counters_on_data_source)

        return Response({
            other_slug: infection_aggregated_counters,
            "data_source": infection_aggregated_counters_on_data_source
        })

    @staticmethod
    def _get_min_infection_level(request):
        """Return the lowest global response value counted as an infection."""
        mappable = models.GlobalViralHostResponseValue.objects_mappable
        if "weak_infection" in request.GET:
            # an infection is everything except the lowest level (i.e: No Lysis)
            lowest = mappable().aggregate(v=Min("value"))['v']
            return mappable().filter(value__gt=lowest).aggregate(v=Min("value"))['v']
        # an infection is only the highest level (i.e: Lysis)
        return mappable().aggregate(v=Max("value"))['v']

    def aggregate_infection_counters(self, infection_aggregated_counters):
        """Replace, in place, each ``[no_infection, infection]`` counter by ``{"ratio", "total"}``."""
        for k, infection_aggregated_counter in infection_aggregated_counters.items():
            s = sum(infection_aggregated_counter)
            # compute the ratio as it is what we want, keep the total in case third part service need it
            infection_aggregated_counters[k] = dict(
                ratio=infection_aggregated_counter[1] / s,
                total=s,
            )
409
410


411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
def remplace_greek_letters_name_by_themselves(searched_text):
    """Replace the English names of Greek letters by the letters themselves.

    Matching is case-insensitive and does not require word boundaries, so names inside
    longer tokens are replaced too (``"phiX174"`` -> ``"φX174"``). The compiled pattern
    is cached as an attribute of this function.

    :param searched_text: text to transform
    :return: ``(transformed_text, True if at least one name was replaced)``
    """
    func_cache = remplace_greek_letters_name_by_themselves.__dict__
    pattern = func_cache.get("__pattern_instance")
    if pattern is None:
        # insertion order is the order in which the regex alternatives are tried
        letters = {
            'alpha': 'α',
            'beta': 'β',
            'gamma': 'γ',
            'delta': 'δ',
            'epsilon': 'ε',
            'zeta': 'ζ',
            'eta': 'η',
            'theta': 'θ',
            'iota': 'ι',
            'kappa': 'κ',
            'lambda': 'λ',
            'mu': 'μ',
            'nu': 'ν',
            'xi': 'ξ',
            'omicron': 'ο',
            'pi': 'π',
            'rho': 'ρ',
            'sigma': 'σ',
            'tau': 'τ',
            'upsilon': 'υ',
            'phi': 'φ',
            'chi': 'χ',
            'psi': 'ψ',
            'omega': 'ω',
        }
        pattern = re.compile(
            pattern="|".join("(%s)" % name for name in letters),
            flags=re.IGNORECASE,
        )
        func_cache["__pattern_instance"] = pattern
        func_cache["__greek_letters"] = letters
    letters = func_cache["__greek_letters"]

    replaced_text, nb_replacements = pattern.subn(
        lambda found: letters[found.group().lower()],
        searched_text,
    )
    return replaced_text, nb_replacements > 0


488
489
490
491
492
class Unaccent(Transform):
    """``unaccent`` lookup transform wrapping the SQL ``UNACCENT`` function.

    NOTE(review): presumably registered on text fields elsewhere and backed by
    PostgreSQL's ``unaccent`` extension — confirm.
    """
    lookup_name = 'unaccent'
    function = 'UNACCENT'


493
def search(request):
    """Search viruses, hosts and data sources, answering as JSON.

    Query syntax (``search`` field of ``forms.SearchForm``, read from POST or GET):
    - the text is split on spaces;
    - double-quoted phrases are kept together;
    - ``+word`` makes a word mandatory, ``-word`` bans it;
    - prefixes such as ``name:``, ``id:`` or ``tax:`` restrict the targeted fields
      (see ``extract_word_and_target_pairs``);
    - Greek letter names are also searched as the letters themselves.

    Errors are reported in the JSON body with ``app_error_code`` 1 (invalid query) or 2
    (single-letter query).
    """
    form = forms.SearchForm(data=request.POST or request.GET, user=request.user)
    if not form.is_valid():
        return JsonResponse(dict(err="Invalid query", detail=form.errors, app_error_code=1, searched_text=""))

    sorting = form.cleaned_data["sorting"] or "TFIDFnCOV"
    use_tf = sorting[:2] == "TF"
    use_idf = sorting[:5] == "TFIDF"
    use_cov = sorting == "TFIDFnCOV"

    # get the raw searched string
    searched_text = form.cleaned_data["search"]

    # protect quoted phrases from the split below: their quotes and spaces become %20,
    # e.g. <i>bla "foo bar" bli</i> becomes <i>bla %20foo%20bar%20 bli</i>
    for quoted in [m.group() for m in re.finditer('"[^"]+"', searched_text)]:
        searched_text = searched_text.replace(quoted, quoted.replace("\"", "%20").replace(" ", "%20"))

    # split the search text by space
    raw_searched_texts = [x for x in searched_text.split(" ") if len(x) > 0]

    # find greek letters
    greek_searched_text, has_greek = remplace_greek_letters_name_by_themselves(searched_text)
    if has_greek:
        # if found, also search the text with the letters themselves, again split by space;
        # a lone letter is wrapped in %20 so it is not rejected as a single-letter word
        raw_searched_texts += [x if len(x) > 1 else "%20" + x + "%20"
                               for x in greek_searched_text.split(" ")
                               if len(x.strip()) > 0]

    # extract mandatory tokens (prefixed with "+")
    mandatory = [x[1:] for x in raw_searched_texts if x[0] == '+']
    # extract banned tokens (prefixed with "-")
    banned = [x[1:] for x in raw_searched_texts if x[0] == '-']

    # remove leading "+" and words starting with "-"
    raw_searched_texts = [x[1:] if x[0] == '+' else x for x in raw_searched_texts if x[0] != '-']

    # turn tokens into (targets, word) pairs; %20 is replaced back with blank space there
    searched_texts = extract_word_and_target_pairs(raw_searched_texts)
    mandatory = extract_word_and_target_pairs(mandatory)
    banned = extract_word_and_target_pairs(banned)
    if len(searched_texts) == 0:
        if len(extract_word_and_target_pairs(raw_searched_texts, True)) > 0:
            return JsonResponse(dict(
                err="Single letter query are not permitted, put it between double quote to override",
                app_error_code=2,
                searched_text=searched_text,
            ))
        return JsonResponse(dict(err="Invalid query", app_error_code=1, searched_text=""))

    # search options
    sample_size = form.cleaned_data["sample_size"]
    kind = form.cleaned_data["kind"] or 'all'
    search_key = "__unaccent__icontains" if connection.vendor == "postgresql" else "__icontains"
    ui_help = form.cleaned_data["ui_help"]
    str_max_length = form.cleaned_data["str_max_length"]

    if sample_size == 0:
        sample_size = 7 if kind == "all" else 15

    qs_virus = models.Virus.objects.annotate(
        her_identifier_s=Coalesce(Cast('her_identifier', CharField()), Value('')),
        tax_id_value_s=Coalesce('tax_id_value', Value('')),
        has_identifier=Case(
            When(
                Q(is_ncbi_identifier_value=True)
                | Q(her_identifier__isnull=False)
                | Q(tax_id_value__isnull=False),
                then=1,
            ),
            default=0,
            output_field=IntegerField(),
        )
    )
    qs_virus = filter_and_sort_qs(
        searched_texts=searched_texts,
        banned=banned,
        mandatory=mandatory,
        search_key=search_key,
        fields_to_search_in=["name", "identifier", "her_identifier_s", "tax_id_value_s"],
        path_to_data_source="data_source__",
        queryset=qs_virus,
        request=request,
        use_idf=use_idf,
        use_tf=use_tf,
        use_cov=use_cov,
        cls_filter_backend=FromGetParamsFilterVirusBackend,
        additional_ordering=["-has_identifier", ],
    )

    qs_host = models.Host.objects.annotate(
        tax_id_value_s=Coalesce('tax_id_value', Value('')),
        has_identifier=Case(
            When(
                Q(is_ncbi_identifier_value=True)
                | Q(tax_id_value__isnull=False),
                then=1,
            ),
            default=0,
            output_field=IntegerField(),
        )
    )
    qs_host = filter_and_sort_qs(
        searched_texts=searched_texts,
        banned=banned,
        mandatory=mandatory,
        search_key=search_key,
        fields_to_search_in=["name", "identifier", "tax_id_value_s"],
        path_to_data_source="data_source__",
        queryset=qs_host,
        request=request,
        use_idf=use_idf,
        use_tf=use_tf,
        use_cov=use_cov,
        # hosts need the host-side paths (path_to_host="" and viruses through responses);
        # the virus backend would match virus ids against host pks
        cls_filter_backend=FromGetParamsFilterHostBackend,
        additional_ordering=["-has_identifier", ],
    )

    # annotating the qs so we have the provider ie provider_* if defined, owner otherwise
    qs_data_source = (
        models.DataSource.objects
        .annotate(publication_url_or_empty=Coalesce('publication_url', Value('')))
        .annotate(owner_str=Concat('owner__first_name', Value(' '), 'owner__last_name'))
        .annotate(provider_raw=Concat('provider_first_name', Value(' '), 'provider_last_name'))
        .annotate(provider_str=Case(
            When(provider_raw=' ', then=None),
            default='provider_raw',
            output_field=CharField(),
        ))
        .annotate(provider=Coalesce('provider_str', 'owner_str'))
        # NOTE(review): despite its name, this is True when publication_url IS NULL; with the
        # ascending additional_ordering below, sources having a URL (False) sort before the
        # others — presumably intended, but confirm nothing else reads this annotation.
        .annotate(has_publication_url=Case(When(publication_url__isnull=True, then=True),
                                           default=False,
                                           output_field=BooleanField(),
                                           ))
    )
    qs_data_source = filter_and_sort_qs(
        searched_texts=searched_texts,
        banned=banned,
        mandatory=mandatory,
        search_key=search_key,
        fields_to_search_in=["name", "description", "provider", "publication_url_or_empty"],
        path_to_data_source="",
        queryset=qs_data_source,
        request=request,
        use_idf=use_idf,
        use_tf=use_tf,
        use_cov=use_cov,
        cls_filter_backend=FromGetParamsFilterDataSourceBackend,
        additional_ordering=["has_publication_url"],
    )

    responses = dict(
        query=dict(
            searched_text=searched_text,
            mandatory_searched_texts=[x[1] for x in mandatory],
            banned_searched_texts=[x[1] for x in banned],
            alternative_searched_texts=[x[1] for x in searched_texts],
            sample_size=sample_size,
            sorting=sorting,
            kind=kind,
        )
    )
    if ui_help:
        responses["ui_help"] = dict()

    # GET parameters reproducing this search for a single kind, without sampling
    get_param = dict(form.cleaned_data)
    get_param["owner"] = list(get_param["owner"].values_list("pk", flat=True))
    get_param["sample_size"] = -1
    get_param.pop("ui_help")
    get_param.pop("kind")
    for k, v in list(get_param.items()):
        if not v:
            get_param.pop(k)
    get_param = urlencode(get_param, doseq=True)

    if kind in ['virus', 'all']:
        if ui_help:
            responses["ui_help"]["virus"] = models.Virus._meta.verbose_name.title()
        responses["virus"] = dict(
            count=qs_virus.count(),
            sample=serializers.VirusSerializerWithURL(
                qs_virus[0:sample_size] if sample_size > 0 else qs_virus, many=True).data,
            subset_url="%s?kind=virus&%s" % (reverse("viralhostrangedb:search"), get_param)
        )
    if kind in ['host', 'all']:
        if ui_help:
            responses["ui_help"]["host"] = models.Host._meta.verbose_name.title()
        responses["host"] = dict(
            count=qs_host.count(),
            sample=serializers.HostSerializerWithURL(
                qs_host[0:sample_size] if sample_size > 0 else qs_host, many=True).data,
            subset_url="%s?kind=host&%s" % (reverse("viralhostrangedb:search"), get_param)
        )
    if kind in ['data_source', 'all']:
        if ui_help:
            responses["ui_help"]["data_source"] = models.DataSource._meta.verbose_name.title()
        responses["data_source"] = dict(
            count=qs_data_source.count(),
            sample=serializers.DataSourceSerializerForSearch(
                instance=qs_data_source[0:sample_size] if sample_size > 0 else qs_data_source,
                many=True,
                str_max_length=str_max_length,
            ).data,
            subset_url="%s?kind=data_source&%s" % (reverse("viralhostrangedb:search"), get_param)
        )
    return JsonResponse(responses)
709
710


711
# Field-restriction prefixes a search term may start with (matched case-insensitively),
# mapped to the model fields the rest of the term is searched in.
_SEARCH_PREFIX_TARGETS = (
    ("ID:", ("identifier", "her_identifier_s")),
    ("NCBI:", ("identifier",)),
    ("HER:", ("her_identifier_s",)),
    ("TAX:", ("tax_id_value",)),
    ("NAME:", ("name",)),
    ("DESC:", ("description",)),
    ("OWNER:", ("provider",)),
    ("PROVIDER:", ("provider",)),
    ("PROV:", ("provider",)),
)


def extract_word_and_target_pairs(texts, allow_single_letter=False):
    """Split raw search terms into ``(targets, word)`` pairs.

    A term may start with a case-insensitive prefix (``ID:``, ``NCBI:``, ``HER:``, ``TAX:``,
    ``NAME:``, ``DESC:``, ``OWNER:``, ``PROVIDER:``, ``PROV:``) restricting the fields it is
    searched in. ``targets`` is then a fresh list of those field names, otherwise ``None``
    (search in every field). In the word, ``%20`` is turned into a space and surrounding
    whitespace is stripped.

    :param texts: iterable of raw search terms (str).
    :param allow_single_letter: keep words that are exactly one character long.
    :return: list of ``(targets, word)`` tuples in input order. Words that are empty after
        clean-up are dropped, and so are single-character words unless
        ``allow_single_letter`` is set.
    """

    def clean_up(x):
        return x.replace("%20", " ").strip()

    pairs = []
    for s in texts:
        upper_s = s.upper()
        targets, word = None, s
        for prefix, fields in _SEARCH_PREFIX_TARGETS:
            if upper_s.startswith(prefix):
                # A new list per pair so callers cannot alter the shared table.
                targets, word = list(fields), s[len(prefix):]
                break
        # Check the length of the *cleaned* word. Checking the raw text let blank terms
        # (e.g. "  " or "ID:%20") through as empty search strings.
        word = clean_up(word)
        if len(word) > 1 or (allow_single_letter and len(word) == 1):
            pairs.append((targets, word))
    return pairs


752
753
754
755
756
757
758
def filter_and_sort_qs(
        searched_texts,
        banned,
        mandatory,
        search_key,
        fields_to_search_in,
        path_to_data_source,
        queryset,
        request,
        use_idf,
        use_tf,
        use_cov,
        cls_filter_backend,
        additional_ordering=None,
):
    """Filter ``queryset`` with parsed search terms and order it by relevance.

    Every term is a ``(targets, word)`` pair: ``targets`` is ``None`` (any field of
    ``fields_to_search_in``) or the field names the word is restricted to.

    :param searched_texts: optional terms. When ``mandatory`` is empty, a row must match at
        least one of them; otherwise they only contribute to the relevance score.
    :param banned: terms a row must not match in any of their allowed fields.
    :param mandatory: terms a row must match, each in at least one of its allowed fields.
    :param search_key: lookup suffix appended to each field name to build the filter
        (presumably something like ``"__icontains"``; confirm against callers).
    :param fields_to_search_in: model field names the terms are searched in.
    :param path_to_data_source: lookup prefix from the model to its data source, passed to
        ``mixins.only_public_or_granted_or_owned_queryset_filter``.
    :param queryset: queryset to filter. It is not modified; a new queryset is returned.
    :param request: current request, used by the filter backend and the visibility filter.
    :param use_idf: weight words with ``compute_idf`` (otherwise every word uses the default weight).
    :param use_tf: annotate rows with a ``relevance`` score (``get_hit_count``) and sort on it.
    :param use_cov: forwarded to ``get_hit_count``.
    :param cls_filter_backend: filter backend class, instantiated and applied with ``view=None``.
    :param additional_ordering: extra ``order_by`` fields, applied after ``-relevance``.
        The list is copied, never modified.
    :return: the filtered queryset, ordered by ``-relevance`` (when ``use_tf``) and then
        ``additional_ordering``.
    """
    # Work on a copy: "-relevance" is inserted below and must not leak into the caller's list.
    additional_ordering = list(additional_ordering or [])
    # Remove any ordering (including the model's default one) on a clone, instead of mutating
    # the caller's queryset through the private Query.clear_ordering() API.
    queryset = queryset.order_by()

    if len(mandatory) == 0:
        # filtering data source to get entries matching one of the search key
        # if mandatory is not empty this step is useless as all match must have all keyword in mandatory, and the other
        # can be, but also can not be present. We will use the searched_texts only for sorting
        qs = []
        for s, f in itertools.product(searched_texts, fields_to_search_in):
            if s[0] is None or f in s[0]:
                qs.append(Q(**{f + search_key: s[1]}))
        # Q(pk=None) matches nothing, so an empty list of conditions yields no rows
        queryset = queryset.filter(reduce(lambda x, y: x | y, qs, Q(pk=None)))
    for word in mandatory:
        # each mandatory word must be found in at least one of its allowed fields
        qs = []
        for f in fields_to_search_in:
            if word[0] is None or f in word[0]:
                qs.append(Q(**{f + search_key: word[1]}))
        queryset = queryset.filter(reduce(lambda x, y: x | y, qs, Q(pk=None)))
    for word, f in itertools.product(banned, fields_to_search_in):
        if word[0] is None or f in word[0]:
            queryset = queryset.exclude(**{f + search_key: word[1]})
    if use_idf:
        idf_ds = compute_idf(fields_to_search_in, queryset, search_key, searched_texts, mandatory)
    else:
        idf_ds = {}
    if use_tf:
        # annotating each row with the number of time one of the search key have been found in one of the fields
        queryset = queryset.annotate(relevance=get_hit_count(
            fields=fields_to_search_in,
            searched_texts=searched_texts,
            idf=idf_ds,
            use_cov=use_cov,
        ))
        additional_ordering.insert(0, "-relevance")

    queryset = cls_filter_backend().filter_queryset(request, queryset, None)
    queryset = mixins.only_public_or_granted_or_owned_queryset_filter(
        None,
        request=request,
        queryset=queryset,
        path_to_data_source=path_to_data_source,
    )
    return queryset.order_by(*additional_ordering)
813
814


815
def compute_idf(fields_names, qs, search_key, searched_texts, mandatory):
    """Compute an inverse-document-frequency weight for every non-mandatory search word.

    A word that matches ``matching`` rows out of the ``total`` rows of ``qs`` gets the weight
    ``log(1 + total / matching)``, and ``0`` when it matches no row. Only the word's target
    fields are considered, or all of ``fields_names`` when it has no target. Mandatory words
    get no entry; ``get_hit_count`` then falls back to its default weight for them.

    :param fields_names: field names the words may be searched in.
    :param qs: queryset the frequencies are measured on.
    :param search_key: lookup suffix appended to each field name.
    :param searched_texts: ``(targets, word)`` pairs.
    :param mandatory: ``(targets, word)`` pairs whose words are skipped.
    :return: dict mapping each word to its weight.
    """
    available_fields = set(fields_names)
    weights = {}
    total = qs.count()
    skipped_words = [pair[1] for pair in mandatory]
    for pair in searched_texts:
        word = pair[1]
        if word in skipped_words:
            continue
        if pair[0] is None:
            candidate_fields = available_fields
        else:
            candidate_fields = available_fields.intersection(set(pair[0]))
        conditions = [Q(**{field + search_key: word}) for field in candidate_fields]
        matching = qs.filter(reduce(lambda left, right: left | right, conditions, Q(pk=None))).count()
        # A word found nowhere would divide by zero: its weight is 0.
        weights[word] = math.log(1 + total / matching) if matching else 0
    return weights


840
841
def _upper_transformer(x):
    """Upper-case the expression ``x``."""
    return Upper(x)


def _unaccent_upper_transformer(x):
    """Upper-case the expression ``x`` and wrap it in ``Unaccent`` (used on PostgreSQL only)."""
    return Unaccent(Upper(x))


# Normalisation applied to both the field values and the searched word before counting hits,
# so matching ignores case (and, on PostgreSQL, whatever Unaccent normalises). Chosen once, at import time.
transformer = (
    _unaccent_upper_transformer
    if connection.vendor == "postgresql"
    else _upper_transformer
)

# Weight used for words without a computed IDF: log(1 + N / N), i.e. a word present in every row.
__default_idf = math.log(2)
848

849

850
def get_hit_count(fields, searched_texts, idf=None, default_idf=__default_idf, use_cov=False):
    """Build a database expression scoring how well a row matches the searched words.

    For each ``(targets, word)`` pair, the occurrences of ``word`` in each allowed field are
    counted in SQL. The count is the normalised field length minus the length of the field with
    every occurrence of ``word`` removed, divided by ``len(word)``. Field and word are both
    normalised with ``transformer`` (upper-case, plus Unaccent on PostgreSQL). Per-field counts
    are summed and weighted by ``idf[word]`` (``default_idf`` when absent). The weighted sums of
    all words are then added together.

    :param fields: model field names to look into.
    :param searched_texts: ``(targets, word)`` pairs; ``targets`` is ``None`` (all ``fields``)
        or the field names the word is restricted to. Empty words are ignored.
    :param idf: optional mapping word -> weight, as returned by ``compute_idf``.
    :param default_idf: weight of words missing from ``idf``.
    :param use_cov: also add, per field, the fraction of the field's characters covered by the word.
    :return: a float expression suitable for ``QuerySet.annotate``; ``0.0`` when nothing can match.
    """
    if idf is None:
        idf = dict()
    fields = set(fields)
    tf_idf_s = []
    for target_and_word in searched_texts:
        s = target_and_word[1]
        if not s:
            # An empty word "matches" everywhere and would make the SQL below divide by zero.
            continue
        len_s = len(s)
        nb_hits = []
        if target_and_word[0] is None:
            fields_for_this_s = fields
        else:
            fields_for_this_s = fields.intersection(set(target_and_word[0]))
        for f in fields_for_this_s:
            # Both lengths are taken on the *transformed* field. The transformation (Unaccent in
            # particular) may change a string's length, and subtracting from the raw Length(f)
            # would then produce phantom, or even negative, hit counts.
            removed_chars = Length(transformer(f)) - Length(
                Replace(transformer(f), transformer(Value(s)), Value(''))
            )
            nb_hit = Cast(removed_chars, FloatField()) / Value(len_s)
            if use_cov:
                # the small epsilon avoids a division by zero on empty fields
                nb_hit = Cast(
                    nb_hit + Cast(removed_chars, FloatField()) / (Length(transformer(f)) + Value(0.00000001)),
                    FloatField(),
                )
            nb_hits.append(nb_hit)
        if nb_hits:
            tf_idf_s.append(
                Value(idf.get(s, default_idf)) * (reduce(lambda x, y: x + y, nb_hits))
            )
    return reduce(lambda x, y: x + y, tf_idf_s, Value(0.0, output_field=FloatField()))
877
878


879
880
881
882
883
884
885
886
887
888
def fetch_identifier_status(request):
    """Run ``business_process.fetch_identifier_status`` for viruses, then hosts.

    Each call receives the queryset returned by ``business_process.get_identifier_status_unknown``
    for that model (presumably the entries whose identifier status is not known yet).

    :return: JSON ``{"duration": <seconds spent>}``.
    """
    # NOTE(review): any request reaching this view triggers the whole fetch; confirm the URL
    # configuration restricts who can call it.
    start = time.perf_counter()  # monotonic clock: durations are unaffected by system clock changes
    for model in (models.Virus, models.Host):
        business_process.fetch_identifier_status(
            queryset=business_process.get_identifier_status_unknown(model.objects),
        )
    return JsonResponse(dict(duration=time.perf_counter() - start))
889

890

891
def get_statistics(request):
    """Return, as JSON, the number of viruses, hosts, responses and data sources visible to a user.

    Every count goes through ``mixins.only_public_or_granted_or_owned_queryset_filter``. The user
    is the requesting one, or an anonymous one when the ``public`` GET parameter equals ``true``
    (case-insensitive).
    """
    as_public = request.GET.get("public", "FALSE").upper() == "TRUE"
    user = AnonymousUser() if as_public else request.user

    def visible_count(queryset, path_to_data_source):
        # number of rows of queryset that remain after the visibility filter for user
        return mixins.only_public_or_granted_or_owned_queryset_filter(
            self=None,
            request=None,
            user=user,
            queryset=queryset,
            path_to_data_source=path_to_data_source,
        ).count()

    return JsonResponse({
        "virus_count": visible_count(models.Virus.objects, "data_source__"),
        "host_count": visible_count(models.Host.objects, "data_source__"),
        "response_count": visible_count(models.ViralHostResponseValueInDataSource.objects, "data_source__"),
        "data_source_count": visible_count(models.DataSource.objects, ""),
    })