Commit 492d4203 authored by Kenzo-Hugo Hillion's avatar Kenzo-Hugo Hillion
Browse files

update way to extract m2m fields on bulk lists

parent 536dd17b
Pipeline #16698 passed with stages
in 2 minutes and 47 seconds
from collections import defaultdict
from rest_framework import serializers
from rest_framework.exceptions import ValidationError, ErrorDetail
from rest_framework.fields import SkipField
......@@ -7,13 +9,20 @@ from rest_framework.utils import html, model_meta
class BulkListSerializer(serializers.ListSerializer):
def _extract_many_to_many(self, validated_data, info):
many_to_many = [{} for v in validated_data]
def _extract_many_to_many(self, validated_data, info, lookup_field):
many_to_many = {
'keys': set(),
'values': defaultdict(list)
}
for field_name, relation_info in info.relations.items():
if relation_info.to_many:
for data_item, many_to_many_item in zip(validated_data, many_to_many):
for data_item in validated_data:
if field_name in data_item:
many_to_many_item[field_name] = data_item.pop(field_name)
many_to_many['keys'].add(field_name)
many_to_many['values'][field_name].append({
field_name: data_item.pop(field_name),
lookup_field: data_item[lookup_field]
})
return many_to_many
def _get_db_index_fields(self, info):
......@@ -93,9 +102,15 @@ class BulkListSerializer(serializers.ListSerializer):
def create(self, validated_data):
ModelClass = self.Meta.model
info = model_meta.get_field_info(ModelClass)
lookup_field = self._get_db_index_fields(info)[0]
many_to_many = self._extract_many_to_many(validated_data, info, lookup_field)
instances = ModelClass.objects.bulk_create(
[ModelClass(**item) for item in validated_data]
)
if many_to_many:
for field in many_to_many['keys']:
getattr(self, f'_handle_{field}', None)(many_to_many['values'][field])
return instances
def update(self, instances, validated_data):
......@@ -105,8 +120,8 @@ class BulkListSerializer(serializers.ListSerializer):
"""
ModelClass = self.Meta.model
info = model_meta.get_field_info(ModelClass)
db_index_fields = self._get_db_index_fields(info)
lookup_field = db_index_fields[0]
lookup_field = self._get_db_index_fields(info)[0]
many_to_many = self._extract_many_to_many(validated_data, info, lookup_field)
updated_keys = self._get_all_key_fields(validated_data)
validated_data = {item.pop(lookup_field): item for item in validated_data}
for item_id, validated_data_element in validated_data.items():
......@@ -116,6 +131,10 @@ class BulkListSerializer(serializers.ListSerializer):
list(instances.values()),
updated_keys
)
# Link existing many-to-many relationships.
if many_to_many:
for field in many_to_many['keys']:
getattr(self, f'_handle_{field}', None)(many_to_many['values'][field])
return list(instances.values())
class Meta:
......
......@@ -24,8 +24,8 @@ class BaseTestBulkListSerializerMethods(TestCase):
def setUp(self):
self.data = [
{'field1': 'value1', 'field2': 'value2'},
{'field1': 'value3', 'field2': 'value4'}
{'id': 'entry_1', 'field1': 'value1', 'field2': 'value2'},
{'id': 'entry_2', 'field1': 'value3', 'field2': 'value4'}
]
self.bulk_list_serializer = BulkListSerializerTestExtractManyToMany()
self.info = Mock()
......@@ -39,12 +39,17 @@ class TestExtractManyToMany(BaseTestBulkListSerializerMethods):
'field2': Mock(to_many=False)
}
ori_list = deepcopy(self.data)
expected_list = [
{'field1': 'value1'},
{'field1': 'value3'}
]
tested_list = self.bulk_list_serializer._extract_many_to_many(self.data, self.info)
self.assertListEqual(tested_list, expected_list)
expected_dict = {
'keys': {'field1'},
'values': {
'field1': [
{'field1': 'value1', 'id': 'entry_1'},
{'field1': 'value3', 'id': 'entry_2'}
]
}
}
tested_dict = self.bulk_list_serializer._extract_many_to_many(self.data, self.info, 'id')
self.assertDictEqual(tested_dict, expected_dict)
self.assertNotEqual(ori_list, self.data)
def test_extract_no_many_to_many(self):
......@@ -53,12 +58,12 @@ class TestExtractManyToMany(BaseTestBulkListSerializerMethods):
'field2': Mock(to_many=False)
}
ori_list = deepcopy(self.data)
expected_list = [
{},
{}
]
tested_list = self.bulk_list_serializer._extract_many_to_many(self.data, self.info)
self.assertListEqual(tested_list, expected_list)
expected_dict = {
'keys': set(),
'values': {}
}
tested_dict = self.bulk_list_serializer._extract_many_to_many(self.data, self.info, 'id')
self.assertDictEqual(tested_dict, expected_dict)
self.assertListEqual(ori_list, self.data)
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment