Commit 684d8d7d authored by Bryan BRANCOTTE
Browse files

parse the file before actually saving it, raise error during form validation...

Parse the file before actually saving it, and raise errors during form validation instead of showing messages after the save.
WIP #240
parent eaf1c94f
......@@ -47,6 +47,7 @@ ViralHostResponse = namedtuple(
"host",
"host_identifiers",
"response",
"parsed_response",
"row_id",
"col_id",
])
......@@ -70,13 +71,16 @@ class ImportationObserver:
class Meta:
abstract = True
DUPLICATED = 1
EMPTY_NAME = 2
def notify_response_error(self, virus, host, response_str, replaced):
pass
def notify_host_error(self, host, column_id, reason=None):
def notify_host_error(self, host, column_id, reason=None, reason_id=None):
pass
def notify_virus_error(self, virus, row_id, reason=None):
def notify_virus_error(self, virus, row_id, reason=None, reason_id=None):
pass
@staticmethod
......@@ -96,8 +100,11 @@ class MessageImportationObserver(ImportationObserver):
self.host_warned = set()
self.virus_warned = set()
def add_message(self, request, level, message):
messages.add_message(request=request, level=level, message=message)
def notify_response_error(self, virus, host, response_str, replaced):
messages.add_message(
self.add_message(
self.request,
messages.WARNING,
"[ImportErr2] " + gettext(
......@@ -111,7 +118,11 @@ class MessageImportationObserver(ImportationObserver):
)
@staticmethod
def reason_to_str(msg, reason):
def reason_to_str(msg, reason, reason_id):
if reason_id == ImportationObserver.DUPLICATED:
return msg + gettext(": ") + gettext("Duplicated")
if reason_id == ImportationObserver.EMPTY_NAME:
return msg + gettext(": ") + gettext("Empty name")
msg = [msg, "<br/>"]
if type(reason) == ValidationError:
reason = reason.error_dict
......@@ -128,7 +139,7 @@ class MessageImportationObserver(ImportationObserver):
msg.append(str(reason))
return "".join(msg)
def notify_host_error(self, host, column_id, reason=None):
def notify_host_error(self, host, column_id, reason=None, reason_id=None):
if column_id in self.host_warned:
return
self.host_warned.add(column_id)
......@@ -138,15 +149,15 @@ class MessageImportationObserver(ImportationObserver):
column_id=str(column_id),
column_index=str(self.id_to_excel_index(column_id)),
)
if reason:
msg = self.reason_to_str(msg, reason)
messages.add_message(
if reason or reason_id:
msg = self.reason_to_str(msg, reason, reason_id)
self.add_message(
self.request,
messages.WARNING,
mark_safe(msg),
)
def notify_virus_error(self, virus, row_id, reason=None):
def notify_virus_error(self, virus, row_id, reason=None, reason_id=None):
if row_id in self.virus_warned:
return
self.virus_warned.add(row_id)
......@@ -155,14 +166,35 @@ class MessageImportationObserver(ImportationObserver):
row_id=str(row_id),
)
if reason:
msg = self.reason_to_str(msg, reason)
messages.add_message(
msg = self.reason_to_str(msg, reason, reason_id)
self.add_message(
self.request,
messages.WARNING,
mark_safe(msg),
)
class StackErrorImportationObserver(MessageImportationObserver):
    """
    Importation observer built without any request, which stacks every message it receives
    in ``self.errors``.
    """

    def __init__(self):
        super().__init__(request=None)
        self.errors = []

    def add_message(self, request, level, message):
        # request and level are part of the signature but only the message is kept
        self.errors.append(message)

    def notify_response_error(self, virus, host, response_str, replaced):
        template = gettext(
            "Could not import response \"%(response)s\" for virus \"%(virus)s\", host\"%(host)s\"")
        details = {
            "response": str(response_str),
            "virus": str(virus),
            "host": str(host),
        }
        self.add_message(
            self.request,
            messages.WARNING,
            "[ImportErr2] " + template % details,
        )
def panda_color_mapping(v):
key = 'html_color_%s' % str(v)
color = cache.get(key)
......@@ -334,8 +366,8 @@ def __parse_file(file, importation_observer: ImportationObserver = None, sheet_n
except KeyError:
importation_observer.notify_host_error(
h,
header_col + start_at,
reason=gettext("Duplicated host, this occurrence will not be imported"),
header_col + start_at + 2,
reason_id=ImportationObserver.DUPLICATED,
)
break
elif id_col > 0 and sub_row[0] != "":
......@@ -355,7 +387,7 @@ def __parse_file(file, importation_observer: ImportationObserver = None, sheet_n
importation_observer.notify_virus_error(
sub_row[0],
row_id,
reason=gettext("Duplicated virus, this occurrence will overwrite the previous row"),
reason_id=ImportationObserver.DUPLICATED,
)
has_seen_data = True
h = header[id_col + start_at]
......@@ -364,16 +396,23 @@ def __parse_file(file, importation_observer: ImportationObserver = None, sheet_n
importation_observer.notify_host_error(
h,
id_col + start_at,
reason=gettext("Empty host, not imported"),
reason_id=ImportationObserver.EMPTY_NAME,
)
continue
host, host_identifiers = extract_name_and_identifiers(h)
try:
parsed_response = float(cell)
except ValueError:
parsed_response = -1000
if importation_observer:
importation_observer.notify_response_error(virus, host, cell, parsed_response)
yield ViralHostResponse(
virus=virus,
virus_identifiers=virus_identifiers,
host=host,
host_identifiers=host_identifiers,
response=cell,
parsed_response=parsed_response,
row_id=row_id,
col_id=id_col + start_at,
)
......@@ -474,8 +513,23 @@ def restore_backup(*, data_source, log_entry, importation_observer: ImportationO
)
def import_file_later(*, file):
    """
    Parse ``file`` right away and postpone the importation itself.

    Returns a tuple ``(errors, actually_import_file)``: the messages stacked by the observer during parsing,
    and a function taking ``data_source`` and ``importation_observer`` which imports the already parsed rows.
    """
    stacking_observer = StackErrorImportationObserver()
    # list() consumes the parser now, so the stacked errors are complete before returning
    parsed_rows = list(__parse_file(file, importation_observer=stacking_observer))

    def actually_import_file(data_source, importation_observer):
        return __import_file(
            data_source=data_source,
            parsed_file=parsed_rows,
            importation_observer=importation_observer,
        )

    return stacking_observer.errors, actually_import_file
@transaction.atomic
def import_file(*, data_source, file, importation_observer: ImportationObserver = None):
    """
    Parse ``file`` and import the parsed content for ``data_source``, notifying ``importation_observer``.
    """
    return __import_file(
        data_source=data_source,
        parsed_file=parse_file(file, importation_observer),
        importation_observer=importation_observer,
    )
def __import_file(*, data_source, parsed_file, importation_observer: ImportationObserver = None):
"""
Import the file and associate responses to the data source provided. If responses are already present they are
overwritten with the ones from the file. Updated responses are automatically mapped following the mapping observed in db
......@@ -498,7 +552,7 @@ def import_file(*, data_source, file, importation_observer: ImportationObserver
responses_to_create = []
# former mapping, if present empty dict otherwise
former_mapping = dict(data_source.get_mapping(only_pk=True))
for vhr in parse_file(file, importation_observer):
for vhr in parsed_file:
# if vhr.virus == "" or vhr.host == "" or vhr.response == "":
# continue
explicit_virus = explicit_item(
......@@ -587,11 +641,10 @@ def import_file(*, data_source, file, importation_observer: ImportationObserver
host_dict[explicit_host] = host
try:
raw_response = float(vhr.response)
float(vhr.response)
except ValueError as e:
if importation_observer:
raw_response = -1000
importation_observer.notify_response_error(vhr.virus, vhr.host, vhr.response, raw_response)
importation_observer.notify_response_error(vhr.virus, vhr.host, vhr.response, vhr.parsed_response)
else:
raise e
# update or create response in db
......@@ -600,8 +653,8 @@ def import_file(*, data_source, file, importation_observer: ImportationObserver
virus=virus,
host=host,
defaults=dict(
raw_response=raw_response,
response_id=former_mapping.get(raw_response, not_mapped_yet.pk),
raw_response=vhr.parsed_response,
response_id=former_mapping.get(vhr.parsed_response, not_mapped_yet.pk),
),
)
old_virus_pk.discard(virus.pk)
......
......@@ -55,11 +55,19 @@ class ImportDataSourceForm(forms.ModelForm):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.on_disk_file = None
self.import_file_fcn = None
def full_clean(self):
super().full_clean()
if not self.is_bound or not self.make_upload_mandatory: # Stop further processing.
if not self.is_bound: # Stop further processing.
return
if self.files is not None and len(self.files) == 1:
file = list(self.files.values())[0]
errors, self.import_file_fcn = business_process.import_file_later(file=file)
for e in errors:
self.add_error('file', e)
# else:
# self.add_error('file', "No file provided")
# if (self.files is None or len(self.files) == 0) and len(self.cleaned_data["url"]) == 0 or \
# (self.files is not None and len(self.files) > 0) and len(self.cleaned_data["url"]) > 0:
# self.add_error("url", _("You have to either provide a file or an URL, and not both."))
......@@ -70,33 +78,13 @@ class ImportDataSourceForm(forms.ModelForm):
if owner is not None:
instance.owner = owner
instance.save()
file = None
if self.files is not None and len(self.files) == 1:
file = list(self.files.values())[0]
if (instance.raw_name or "") == "":
instance.raw_name = file.name
instance.kind = "FILE"
# if len(self.cleaned_data["url"]) > 0:
# try:
# instance.kind = "URL"
# url = self.cleaned_data["url"]
# instance.raw_name = url
# raise NotImplementedError("Downloading from an URL is not available yet")
# except KeyError:
# pass
if file is not None:
# with NamedTemporaryFile(mode='wb+', suffix="-%s" % file.name, delete=False) as destination:
# print(destination.name)
# for chunk in file.chunks():
# destination.write(chunk)
# file = destination.name
# with open(file, "rb") as input_file:
business_process.import_file(
file=file,
data_source=instance,
importation_observer=importation_observer,
)
if (instance.raw_name or "") == "":
instance.raw_name = list(self.files.values())[0].name
instance.kind = "FILE"
self.import_file_fcn(
data_source=instance,
importation_observer=importation_observer,
)
if commit:
instance.save()
return instance
......
......@@ -39,7 +39,7 @@ class NoErrorImportationObserver(business_process.MessageImportationObserver):
)
)
def notify_host_error(self, host, column_id, reason=None):
def notify_host_error(self, host, column_id, reason=None, reason_id=None):
msg = "[ImportErr1] " + "Could not parse host \"%(host)s\" at column \"%(column_id)s\" (i.e column \"%(column_index)s\")" % dict(
host=str(host),
column_id=str(column_id),
......@@ -49,13 +49,13 @@ class NoErrorImportationObserver(business_process.MessageImportationObserver):
msg = self.reason_to_str(msg, reason)
raise HostException(msg)
def notify_virus_error(self, virus, row_id, reason=None):
def notify_virus_error(self, virus, row_id, reason=None, reason_id=None):
msg = "[ImportErr3] " + "Could not parse virus \"%(virus)s\" at row \"%(row_id)s\"" % dict(
virus=str(virus),
row_id=str(row_id),
)
if reason:
msg = self.reason_to_str(msg, reason)
msg = self.reason_to_str(msg, reason, reason_id)
raise VirusException(msg)
......
......@@ -498,16 +498,11 @@ class FileImportTestCase(ViewTestCase):
file=SimpleUploadedFile(f.name, f.read(), content_type="application/vnd.ms-excel"),
)
response = self.client.post(url, form_data, follow=True)
response = self.client.post(url, form_data, follow=False)
self.assertEqual(response.status_code, 200)
messages = list(response.context['messages'])
self.assertEqual(len(messages), 3)
error_codes = set()
for m in messages:
parts = str(m).split("]")
if len(parts) > 1:
error_codes.add(parts[0][1:])
self.assertSetEqual(error_codes, {"ImportErr1", "ImportErr2"})
str_content = str(response.content)
self.assertIn("ImportErr1", str_content)
self.assertIn("ImportErr2", str_content)
def test_works_messages_raised_with_too_long_virus(self):
url = reverse('viralhostrangedb:file-import-view')
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment