import copy
import glob
import re
import xml.etree.ElementTree as ET

import click
import mysql.connector
import requests_cache

from django.core.management import BaseCommand, CommandError
from django.db import transaction
from django.utils import timezone


class MyConverter(mysql.connector.conversion.MySQLConverter):
    """Converter that decodes bytearray columns into unicode strings."""

    def row_to_python(self, row, fields):
        row = super(MyConverter, self).row_to_python(row, fields)

        def to_unicode(col):
            if isinstance(col, bytearray):
                return col.decode('utf-8')
            return col

        return [to_unicode(col) for col in row]
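

# Usage example (illustrative sketch, not part of the library): MyConverter is
# meant to be passed to mysql.connector.connect() via converter_class so that
# bytearray columns come back as str.  The connection parameters below are
# placeholders, not real project settings.
def example_legacy_connection():
    return mysql.connector.connect(
        host='localhost',            # placeholder host
        user='legacy_user',          # placeholder credentials
        password='legacy_password',
        database='legacy_db',
        converter_class=MyConverter,
    )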


class ImportTask(object):
    description = "Abstract import task"
    option = ""
    target_classes = {}
    main_class = None
    depends_on = []

    def __init__(self, command, **kwargs):
        self.out_stream = command.stdout
        self.err_stream = command.stderr
        self.style = command.style
        self.traceback = kwargs.get('traceback', False)
        self.stop_on_fail = kwargs.get('stoponfail', False)
        self.progress_bar = kwargs.get('progress_bar', False)
        # optional shared database connection, used by MysqlImportTask
        self.conn = kwargs.get('conn')
        self.done = False

    def _flush_target_models(self):
        for target_class in self.target_classes:
            self.out_stream.write(
                'Flushing {target_class} models...'.format(target_class=target_class.__name__))
            target_class.objects.all().delete()
            self.out_stream.write(self.style.SUCCESS(
                'Successfully flushed {target_class} models!'.format(target_class=target_class.__name__)))

    def check_final_count(self):
        if self.main_class:
            count = self.main_class.objects.count()
            if count == self.source_count:
                self.out_stream.write(self.style.SUCCESS(
                    'row count OK for model {}: expected {}, counted {}'.format(
                        self.main_class.__name__, self.source_count, count)))
            else:
                message = 'row count not OK for model {}: expected {}, counted {}'.format(
                    self.main_class.__name__, self.source_count, count)
                if self.stop_on_fail:
                    raise CommandError(message)
                else:
                    self.out_stream.write(self.style.ERROR(message))

    def migrate_row(self, row):
        raise NotImplementedError()

    def open_data_source(self):
        raise NotImplementedError()

    def post_process(self):
        pass

    def _process_rows(self, rows, progress=None):
        for row in rows:
            try:
                with transaction.atomic():
                    new_object = self.migrate_row(row)
            except Exception as e:
                if self.traceback:
                    import traceback
                    self.err_stream.write(self.style.NOTICE(traceback.format_exc()))
                message = 'Failed inserting row {}: {}'.format(row, e)
                if self.stop_on_fail:
                    raise CommandError(message)
                else:
                    self.out_stream.write(self.style.ERROR(message))
            else:
                if self.progress_bar:
                    progress.update(1)
                else:
                    self.out_stream.write(self.style.SUCCESS(
                        'Successfully inserted {}'.format(new_object)))

    def count_source(self):
        self.source_count = len(self.rows)

    def _run_import_loop(self):
        self.open_data_source()
        self.count_source()
        if self.progress_bar:
            label = 'Importing {} ({} rows to process).'.format(
                self.description, self.source_count)
            with click.progressbar(length=self.source_count, label=label) as progress:
                self._process_rows(self.rows, progress)
        else:
            self._process_rows(self.rows)
        self.post_process()
        self.check_final_count()
        self.done = True

    def check(self):
        self.open_data_source()
        self.count_source()
        self.check_final_count()

    def run(self):
        self._flush_target_models()
        self._run_import_loop()


class MysqlImportTask(ImportTask):
    description = "Abstract MySQL import task"
    outer_sql = ""

    def get_cursor(self):
        return self.conn.cursor()

    def open_data_source(self):
        cursor = self.get_cursor()
        cursor.execute(self.outer_sql)
        self.rows = cursor.fetchall()


class ListImportTask(ImportTask):
    description = "Abstract Python list import task"
    DATA = []

    def open_data_source(self):
        self.rows = self.DATA
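

# Example concrete task (illustrative sketch): imports a few hard-coded tag
# names through ListImportTask.  The ``myapp.models.Tag`` model and the
# "tags" option name are hypothetical stand-ins for a real project's data.
# target_classes and main_class are left at their defaults here; a real task
# would normally point them at the imported model so flushing and the final
# count check apply.
class ExampleTagImportTask(ListImportTask):
    description = "example tags from a hard-coded list"
    option = "tags"
    DATA = ['python', 'django', 'mysql']

    def migrate_row(self, row):
        from myapp.models import Tag  # hypothetical model, imported lazily
        return Tag.objects.create(name=row)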


class DataFrameImportTask(ImportTask):
    description = "Abstract DataFrame import task"
    dataframe = None

    def open_data_source(self):
        # skip rows where every column is NaN, then iterate row by row
        self.rows = (row for idx, row in self.dataframe.dropna(how='all').iterrows())

    def count_source(self):
        # note: the count is taken before all-NaN rows are dropped
        self.source_count = self.dataframe.shape[0]


class XMLImportTask(ImportTask):
    description = "Abstract XML import task"
    xmlFile = None
    xpathSelector = "*"

    def open_data_source(self):
        tree = ET.parse(self.xmlFile)
        root = tree.getroot()
        self.rows = root.findall(self.xpathSelector)

    def count_source(self):
        self.source_count = len(self.rows)
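

# Example XML task (illustrative sketch): reads <book> elements from a local
# file.  The books.xml path, the XPath selector and the ``myapp.models.Book``
# model are hypothetical placeholders.
class ExampleBookXMLImportTask(XMLImportTask):
    description = "example books from an XML file"
    option = "books"
    xmlFile = 'books.xml'
    xpathSelector = './/book'

    def migrate_row(self, row):
        from myapp.models import Book  # hypothetical model, imported lazily
        return Book.objects.create(title=row.findtext('title'))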


class ImportCommand(BaseCommand):
    help = "Generic command to import data into a django database"
    # list all the import tasks that should be available via the command
    task_classes = []

    def add_arguments(self, parser):
        task_choices = [task_class.option for task_class in self.task_classes]
        task_choices.append('all')
        task_help = 'Import task to be run.\n '
        task_help += ',\n \n'.join(
            [task_class.option + ': ' + task_class.description for task_class in self.task_classes])
        task_help += ',\n \nall: import everything.'
        parser.add_argument(
            'task',
            type=str,
            choices=task_choices,
            help=task_help,
        )
        parser.add_argument(
            '--all',
            action='store_true',
            dest='all',
            default=False,
            help='import everything',
        )
        parser.add_argument(
            '--wscache',
            action='store_true',
            dest='wscache',
            default=False,
            help='use web services cache',
        )
        parser.add_argument(
            '--errortb',
            action='store_true',
            dest='errortb',
            default=False,
            help='show tracebacks on errors',
        )
        parser.add_argument(
            '--stoponfail',
            action='store_true',
            dest='stoponfail',
            default=False,
            help='stop on first error',
        )
        parser.add_argument(
            '--progressbar',
            action='store_true',
            dest='progress_bar',
            default=False,
            help='show progressbar instead of logging inserted entries',
        )
        parser.add_argument(
            '--check',
            action='store_true',
            dest='check',
            default=False,
            help='only check instead of running import',
        )

    def handle(self, *args, **options):
        # use web service cache if option has been selected
        if options.get('wscache'):
            requests_cache.install_cache('ws_cache')
        task_option = options.get('task')
        # map task names to task classes
        option_to_task = {task_class.option: task_class for task_class in self.task_classes}
        # map task names to the list of task names that depend upon them
        dependency_to_task = {task_class.option: [] for task_class in self.task_classes}
        for task_class in self.task_classes:
            for task_dependency_class in task_class.depends_on:
                dependency_to_task[task_dependency_class.option].append(task_class.option)
        dependency_to_task['all'] = list(option_to_task.keys())

        def append_dependencies(option):
            # collect the option itself plus every task that depends on it,
            # directly or transitively
            dependencies = set([option])
            for depending_option in dependency_to_task.get(option, []):
                dependencies.update(append_dependencies(depending_option))
            return dependencies

        if task_option:
            # list of tasks to run because they depend on the initial option,
            # directly or not
            dependencies = append_dependencies(task_option)
            dependencies_classes = {
                task_class for opt, task_class in option_to_task.items() if opt in dependencies}
            task_kwargs = {
                'traceback': options['errortb'],
                'stoponfail': options['stoponfail'],
                'progress_bar': options['progress_bar'],
                'conn': options.get('conn'),
            }
            # run tasks whose dependencies have already been satisfied until
            # none are left
            while dependencies_classes:
                for task_class in copy.copy(dependencies_classes):
                    if not set(task_class.depends_on) & dependencies_classes:
                        task = task_class(self, **task_kwargs)
                        if options.get('check'):
                            # with --check, only verify counts instead of importing
                            task.check()
                        else:
                            task.run()
                        dependencies_classes.remove(task_class)
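

# Example wiring (illustrative sketch): a real project would put a subclass
# like this in its own management/commands/<name>.py module and call it
# ``Command`` so Django discovers it; the task list here only contains the
# hypothetical ExampleTagImportTask defined above.
class ExampleImportCommand(ImportCommand):
    help = "Example command importing the demo tag data"
    task_classes = [ExampleTagImportTask]
# With such a command installed, an invocation could look like:
#   ./manage.py import_demo tags --progressbar --stoponfail
# where ``import_demo`` is the hypothetical command module name.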