Commit a6b89a0e authored by Kenzo-Hugo Hillion's avatar Kenzo-Hugo Hillion
Browse files

update put to behave as upsert

parent 8924454a
Pipeline #14435 passed with stages
in 2 minutes and 32 seconds
......@@ -18,13 +18,35 @@ class BulkViewSet(ModelViewSet):
}
return serializer.data
def create(self, request, *args, **kwargs):
if isinstance(request.data, list):
serializer = self.get_serializer(data=request.data, many=True)
def _updated_payload(self, create_serializer, update_serializer, request):
return {
'path': request.path_info,
'created': {
'count': len(create_serializer.data) if create_serializer else 0
},
'updated': {
'count': len(update_serializer.data) if update_serializer else 0
}
}
def _get_create_serializer(self, data):
if isinstance(data, list):
serializer = self.get_serializer(data=data, many=True)
else:
serializer = self.get_serializer(data=request.data)
serializer = self.get_serializer(data=data)
serializer.is_valid(raise_exception=True)
self.perform_create(serializer)
return serializer
def _get_update_serializer(self, instances, data):
serializer = self.get_serializer(instances, data=data, many=True)
serializer.is_valid(raise_exception=True)
if data:
self.perform_update(serializer)
return serializer
def create(self, request, *args, **kwargs):
serializer = self._get_create_serializer(request.data)
headers = self.get_success_headers(serializer.data)
return Response(
self._created_payload(serializer, request),
......@@ -32,24 +54,28 @@ class BulkViewSet(ModelViewSet):
)
def update(self, request, *args, **kwargs):
print(self.lookup_field)
if self.lookup_field in kwargs.keys():
# perform the classic update
return super().update(request, *args, **kwargs)
instance_ids = set([element[self.lookup_field] for element in request.data])
instances = self.get_objects(instance_ids)
serializer = self.get_serializer(instances.values(), data=request.data, many=True)
serializer.is_valid(raise_exception=True)
# instance = self.get_object()
# serializer = self.get_serializer(instance, data=request.data, partial=partial)
# serializer.is_valid(raise_exception=True)
# self.perform_update(serializer)
data_to_update = []
data_to_create = []
for item in request.data:
if item.get(self.lookup_field) in instances.keys():
data_to_update.append(item)
else:
data_to_create.append(item)
update_serializer = self._get_update_serializer(instances, data_to_update)
create_serializer = self._get_create_serializer(data_to_create)
# if getattr(instance, '_prefetched_objects_cache', None):
# # If 'prefetch_related' has been applied to a queryset, we need to
# # forcibly invalidate the prefetch cache on the instance.
# instance._prefetched_objects_cache = {}
# return Response(serializer.data)
return Response({'wait': 'for it'})
headers = self.get_success_headers(update_serializer.data)
return Response(
self._updated_payload(create_serializer, update_serializer, request),
status=status.HTTP_201_CREATED, headers=headers
)
......@@ -77,7 +77,62 @@ class TestOperationsBulkViewSet(APITestCase):
]
for element in data:
self.assertNotEqual(self.function_api.get(element['function_id']), element)
self.function_api.put(data)
response = self.function_api.put(data)
self.assertEqual(response.get('created').get('count'), 0)
self.assertEqual(response.get('updated').get('count'), 2)
self.assertEqual(self.function_api.get_all()['count'], 2)
for element in data:
self.assertDictEqual(self.function_api.get(element['function_id']), element)
def test_create_through_update_functions(self):
functions = FunctionFactory.build_batch(2)
data = [
{
"function_id": functions[0].function_id,
"source": functions[0].source,
"name": "Test 1"
},
{
"function_id": functions[1].function_id,
"source": functions[1].source,
"name": "Test 2"
}
]
response = self.function_api.put(data)
self.assertEqual(response.get('created').get('count'), 2)
self.assertEqual(response.get('updated').get('count'), 0)
self.assertEqual(self.function_api.get_all()['count'], 2)
for element in data:
self.assertDictEqual(self.function_api.get(element['function_id']), element)
def test_create_and_update_functions(self):
functions = FunctionFactory.create_batch(2)
new_functions = FunctionFactory.build_batch(2)
data = [
{
"function_id": functions[0].function_id,
"source": functions[0].source,
"name": "Test 1"
},
{
"function_id": functions[1].function_id,
"source": functions[1].source,
"name": "Test 2"
},
{
"function_id": new_functions[0].function_id,
"source": new_functions[0].source,
"name": "New Test 1"
},
{
"function_id": new_functions[1].function_id,
"source": new_functions[1].source,
"name": "New Test 2"
}
]
response = self.function_api.put(data)
self.assertEqual(response.get('created').get('count'), 2)
self.assertEqual(response.get('updated').get('count'), 2)
self.assertEqual(self.function_api.get_all()['count'], 4)
for element in data:
self.assertDictEqual(self.function_api.get(element['function_id']), element)
from rest_framework import serializers
from rest_framework.utils import model_meta # noqa
from rest_framework.exceptions import ValidationError
from rest_framework.fields import SkipField
from rest_framework.settings import api_settings
from rest_framework.utils import html, model_meta
class BulkListSerializer(serializers.ListSerializer):
......@@ -33,6 +36,56 @@ class BulkListSerializer(serializers.ListSerializer):
all_keys.append(key)
return list(set(all_keys))
def to_internal_value(self, data):
"""
Copy of original method to be overloaded when performing put to instances
List of dicts of native values <- List of dicts of primitive datatypes.
"""
ModelClass = self.Meta.model
info = model_meta.get_field_info(ModelClass)
db_index_fields = self._get_db_index_fields(info)
lookup_field = db_index_fields[0]
if html.is_html_input(data):
data = html.parse_html_list(data, default=[])
if not isinstance(data, list):
message = self.error_messages['not_a_list'].format(
input_type=type(data).__name__
)
raise ValidationError({
api_settings.NON_FIELD_ERRORS_KEY: [message]
}, code='not_a_list')
if not self.allow_empty and len(data) == 0:
if self.parent and self.partial:
raise SkipField()
message = self.error_messages['empty']
raise ValidationError({
api_settings.NON_FIELD_ERRORS_KEY: [message]
}, code='empty')
ret = []
errors = []
for item in data:
try:
if isinstance(self.instance, dict):
self.child.instance = self.instance[item[lookup_field]]
validated = self.child.run_validation(item)
else:
validated = self.child.run_validation(item)
except ValidationError as exc:
errors.append(exc.detail)
else:
ret.append(validated)
errors.append({})
if any(errors):
raise ValidationError(errors)
return ret
def create(self, validated_data):
ModelClass = self.Meta.model
instances = ModelClass.objects.bulk_create(
......@@ -41,19 +94,24 @@ class BulkListSerializer(serializers.ListSerializer):
return instances
def update(self, instances, validated_data):
"""
:param instances: instances to update
:type instances: DICT of instance object
"""
ModelClass = self.Meta.model
info = model_meta.get_field_info(ModelClass)
db_index_fields = self._get_db_index_fields(info)
lookup_field = db_index_fields[0]
updated_keys = self._get_all_key_fields(validated_data)
[data.pop(db_index_fields[0]) for data in validated_data] # remove db_index fields
for instance, validated_data_element in zip(instances, validated_data):
validated_data = {item.pop(lookup_field): item for item in validated_data}
for item_id, validated_data_element in validated_data.items():
for key, value in validated_data_element.items():
setattr(instance, key, value)
instances = ModelClass.objects.bulk_update(
instances,
setattr(instances[item_id], key, value)
ModelClass.objects.bulk_update(
list(instances.values()),
updated_keys
)
return instances
return list(instances.values())
class Meta:
model = NotImplemented
......@@ -132,6 +132,7 @@ class TestOperationsBulk(APITestCase):
"name": "Test 2"
}
]
functions = {item.function_id: item for item in functions}
for data in validated_data:
self.assertNotEqual(self.function_api.get(data['function_id']), data)
serializer = FunctionSerializer(many=True)
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment