diff --git a/backend/metagenedb/apps/catalog/serializers/bulk_list.py b/backend/metagenedb/apps/catalog/serializers/bulk_list.py index 163a86aceb3cb938b4e8f2dd465e8cd900619611..a6de42970836ed98360c54980bcbbf1dcf2a8510 100644 --- a/backend/metagenedb/apps/catalog/serializers/bulk_list.py +++ b/backend/metagenedb/apps/catalog/serializers/bulk_list.py @@ -1,3 +1,5 @@ +from collections import defaultdict + from rest_framework import serializers from rest_framework.exceptions import ValidationError, ErrorDetail from rest_framework.fields import SkipField @@ -7,13 +9,20 @@ from rest_framework.utils import html, model_meta class BulkListSerializer(serializers.ListSerializer): - def _extract_many_to_many(self, validated_data, info): - many_to_many = [{} for v in validated_data] + def _extract_many_to_many(self, validated_data, info, lookup_field): + many_to_many = { + 'keys': set(), + 'values': defaultdict(list) + } for field_name, relation_info in info.relations.items(): if relation_info.to_many: - for data_item, many_to_many_item in zip(validated_data, many_to_many): + for data_item in validated_data: if field_name in data_item: - many_to_many_item[field_name] = data_item.pop(field_name) + many_to_many['keys'].add(field_name) + many_to_many['values'][field_name].append({ + field_name: data_item.pop(field_name), + lookup_field: data_item[lookup_field] + }) return many_to_many def _get_db_index_fields(self, info): @@ -93,9 +102,15 @@ class BulkListSerializer(serializers.ListSerializer): def create(self, validated_data): ModelClass = self.Meta.model + info = model_meta.get_field_info(ModelClass) + lookup_field = self._get_db_index_fields(info)[0] + many_to_many = self._extract_many_to_many(validated_data, info, lookup_field) instances = ModelClass.objects.bulk_create( [ModelClass(**item) for item in validated_data] ) + if many_to_many: + for field in many_to_many['keys']: + getattr(self, f'_handle_{field}', None)(many_to_many['values'][field]) return instances def update(self, instances, validated_data): @@ -105,8 +120,8 @@ class BulkListSerializer(serializers.ListSerializer): """ ModelClass = self.Meta.model info = model_meta.get_field_info(ModelClass) - db_index_fields = self._get_db_index_fields(info) - lookup_field = db_index_fields[0] + lookup_field = self._get_db_index_fields(info)[0] + many_to_many = self._extract_many_to_many(validated_data, info, lookup_field) updated_keys = self._get_all_key_fields(validated_data) validated_data = {item.pop(lookup_field): item for item in validated_data} for item_id, validated_data_element in validated_data.items(): @@ -116,6 +131,10 @@ class BulkListSerializer(serializers.ListSerializer): list(instances.values()), updated_keys ) + # Link existing many-to-many relationships. + if many_to_many: + for field in many_to_many['keys']: + getattr(self, f'_handle_{field}', None)(many_to_many['values'][field]) return list(instances.values()) class Meta: diff --git a/backend/metagenedb/apps/catalog/serializers/test_bulk_list.py b/backend/metagenedb/apps/catalog/serializers/test_bulk_list.py index 0e2cedae00931e55079b89dba73f9ec76a392052..10b0ec3f2b977b9eac18474c752da39e844f7969 100644 --- a/backend/metagenedb/apps/catalog/serializers/test_bulk_list.py +++ b/backend/metagenedb/apps/catalog/serializers/test_bulk_list.py @@ -24,8 +24,8 @@ class BaseTestBulkListSerializerMethods(TestCase): def setUp(self): self.data = [ - {'field1': 'value1', 'field2': 'value2'}, - {'field1': 'value3', 'field2': 'value4'} + {'id': 'entry_1', 'field1': 'value1', 'field2': 'value2'}, + {'id': 'entry_2', 'field1': 'value3', 'field2': 'value4'} ] self.bulk_list_serializer = BulkListSerializerTestExtractManyToMany() self.info = Mock() @@ -39,12 +39,17 @@ class TestExtractManyToMany(BaseTestBulkListSerializerMethods): 'field2': Mock(to_many=False) } ori_list = deepcopy(self.data) - expected_list = [ - {'field1': 'value1'}, - {'field1': 'value3'} - ] - tested_list = self.bulk_list_serializer._extract_many_to_many(self.data, self.info) - self.assertListEqual(tested_list, expected_list) + expected_dict = { + 'keys': {'field1'}, + 'values': { + 'field1': [ + {'field1': 'value1', 'id': 'entry_1'}, + {'field1': 'value3', 'id': 'entry_2'} + ] + } + } + tested_dict = self.bulk_list_serializer._extract_many_to_many(self.data, self.info, 'id') + self.assertDictEqual(tested_dict, expected_dict) self.assertNotEqual(ori_list, self.data) def test_extract_no_many_to_many(self): @@ -53,12 +58,12 @@ class TestExtractManyToMany(BaseTestBulkListSerializerMethods): 'field2': Mock(to_many=False) } ori_list = deepcopy(self.data) - expected_list = [ - {}, - {} - ] - tested_list = self.bulk_list_serializer._extract_many_to_many(self.data, self.info) - self.assertListEqual(tested_list, expected_list) + expected_dict = { + 'keys': set(), + 'values': {} + } + tested_dict = self.bulk_list_serializer._extract_many_to_many(self.data, self.info, 'id') + self.assertDictEqual(tested_dict, expected_dict) self.assertListEqual(ori_list, self.data)