Skip to content
Snippets Groups Projects
Select Git revision
  • d91f5113e5effbade177c1fd547c2539fc012802
  • master default protected
  • bbrancot-master-patch-78606
  • patch-1
4 results

import_command.py

Blame
  • Forked from Hervé MENAGER / django-diu
    Source project has a limited visibility.
    import_command.py 8.98 KiB
    import glob
    import re
    import copy
    
    from django.utils import timezone
    from django.db import transaction
    from django.core.management import BaseCommand, CommandError
    import mysql.connector
    import requests_cache
    import click
    import xml.etree.ElementTree as ET
    
    
    class MyConverter(mysql.connector.conversion.MySQLConverter):
        """MySQL row converter that decodes ``bytearray`` values to ``str``.

        Rows are first converted by the parent ``MySQLConverter``; any column
        value that comes back as a ``bytearray`` is then decoded as UTF-8 so
        import tasks receive text instead of raw bytes.
        """

        def row_to_python(self, row, fields):
            """Convert *row* with the parent converter, then decode bytearrays.

            :param row: raw row values, as passed in by the connector.
            :param fields: column descriptions for *row*, as passed in by the
                connector.
            :returns: a list of column values in which every ``bytearray``
                (including subclasses) is decoded to ``str`` using UTF-8;
                all other values are returned unchanged.
            """
            converted = super(MyConverter, self).row_to_python(row, fields)
            return [
                col.decode('utf-8') if isinstance(col, bytearray) else col
                for col in converted
            ]
    
    class ImportTask(object):
        """Base class for one import step run by ``ImportCommand``.

        Subclasses provide ``open_data_source`` (fills ``self.rows``) and
        ``migrate_row`` (turns one source row into a saved object). ``run``
        deletes every object of ``target_classes`` and then imports all rows;
        ``check`` only compares the source row count with the current row
        count of ``main_class``.
        """

        # human-readable label, used in the command help and the progress bar
        description = "Abstract import task"

        # name of this task for the command's ``task`` argument
        option = ""

        # model classes whose objects are all deleted before the import;
        # subclasses override this with an iterable of model classes
        target_classes = {}

        # model whose final row count is compared with the source row count
        # (no comparison is made when this is None)
        main_class = None

        # import task classes that must run before this one
        depends_on = []

        def __init__(self, command, **kwargs):
            """Bind the task to the management *command* running it.

            :param command: the running management command; its ``stdout``,
                ``stderr`` and ``style`` are used for all output.
            :param kwargs: ``traceback`` (write tracebacks of failed rows),
                ``stoponfail`` (raise ``CommandError`` on the first failure),
                ``progress_bar`` (show a progress bar instead of one line per
                inserted row) and ``conn`` (connection object, used by
                ``MysqlImportTask.get_cursor``).
            """
            self.out_stream = command.stdout
            self.err_stream = command.stderr
            self.style = command.style
            self.traceback = kwargs.get('traceback', False)
            self.stop_on_fail = kwargs.get('stoponfail', False)
            self.progress_bar = kwargs.get('progress_bar', False)
            # The command always passes a ``conn`` key, None when no connection
            # was supplied. Only store a real connection so that a ``conn``
            # defined as a class attribute on a subclass is not shadowed.
            conn = kwargs.get('conn')
            if conn is not None:
                self.conn = conn
            self.done = False

        def _flush_target_models(self):
            """Delete every object of each model listed in ``target_classes``."""
            for target_class in self.target_classes:
                self.out_stream.write(
                    'Flushing {target_class} models...'.format(target_class=target_class.__name__))
                target_class.objects.all().delete()
                self.out_stream.write(self.style.SUCCESS(
                    'Successfully flushed {target_class} models!'.format(target_class=target_class.__name__)))

        def check_final_count(self):
            """Compare ``main_class``'s row count with ``self.source_count``.

            Does nothing when ``main_class`` is None. On a mismatch, raises
            ``CommandError`` if ``stop_on_fail`` is set, otherwise writes an
            error line to the output stream.
            """
            if self.main_class:
                count = self.main_class.objects.count()
                if count == self.source_count:
                    self.out_stream.write(self.style.SUCCESS(
                        'rows count ok for model {}: expected {}, counted {}'.format(self.main_class.__name__, self.source_count, count)))
                else:
                    message = 'rows count not ok for model {}: expected {}, counted {}'.format(
                        self.main_class.__name__, self.source_count, count)
                    if self.stop_on_fail:
                        raise CommandError(message)
                    else:
                        self.out_stream.write(self.style.ERROR(message))

        def migrate_row(self, row):
            """Import one source *row*; must be implemented by subclasses.

            The return value is only used in the success log line.
            """
            raise NotImplementedError()

        def open_data_source(self):
            """Load the source rows into ``self.rows``; must be implemented."""
            raise NotImplementedError()

        def post_process(self):
            """Hook called once after all rows are processed; no-op by default."""
            pass

        def _process_rows(self, rows, progress=None):
            """Migrate each row of *rows* in its own transaction.

            A failing row is reported (with its traceback if ``traceback`` is
            set) and skipped, unless ``stop_on_fail`` is set, in which case
            ``CommandError`` is raised. When *progress* is given it is advanced
            once per row; otherwise each successful row is logged.
            """
            for row in rows:
                try:
                    with transaction.atomic():
                        new_object = self.migrate_row(row)
                except Exception as error:
                    # deliberate per-row boundary: any failure is reported
                    if self.traceback:
                        import traceback
                        self.err_stream.write(self.style.NOTICE(traceback.format_exc()))
                    # migrate_row raised, so there is no new object to show:
                    # report the source row and the error instead
                    message = 'Failed inserting {}: {}'.format(row, error)
                    if self.stop_on_fail:
                        raise CommandError(message)
                    self.out_stream.write(self.style.ERROR(message))
                else:
                    if progress is None:
                        self.out_stream.write(self.style.SUCCESS(
                            'Successfully inserted {}'.format(new_object)))
                if progress is not None:
                    progress.update(1)

        def count_source(self):
            """Store the number of source rows in ``self.source_count``."""
            self.source_count = len(self.rows)

        def _run_import_loop(self):
            """Open the source, import every row, post-process and verify."""
            self.open_data_source()
            self.count_source()
            if self.progress_bar:
                label = 'Importing {} ({} rows to process).'.format(
                    self.description, self.source_count)
                with click.progressbar(length=self.source_count, label=label) as progress:
                    self._process_rows(self.rows, progress)
            else:
                self._process_rows(self.rows)
            self.post_process()
            self.check_final_count()
            self.done = True

        def check(self):
            """Compare source and database row counts without importing."""
            self.open_data_source()
            self.count_source()
            self.check_final_count()

        def run(self):
            """Flush the target models, then run the full import."""
            self._flush_target_models()
            self._run_import_loop()
    
    class MysqlImportTask(ImportTask):
        """Import task whose rows are the result of an SQL query.

        Subclasses set ``outer_sql``; it is executed on a cursor obtained
        from ``self.conn``.
        """

        description = "Abstract MySQL import task"

        # SQL query whose result rows are imported
        outer_sql = ""

        def get_cursor(self):
            """Return a new cursor on ``self.conn``.

            :raises CommandError: if the task has no connection (neither a
                ``conn`` attribute on the class nor one passed by the command).
            """
            conn = getattr(self, 'conn', None)
            if conn is None:
                raise CommandError(
                    '{} needs a database connection (conn) but none was given'.format(
                        type(self).__name__))
            return conn.cursor()

        def open_data_source(self):
            """Run ``outer_sql`` and load all of its rows into ``self.rows``."""
            cursor = self.get_cursor()
            try:
                cursor.execute(self.outer_sql)
                self.rows = cursor.fetchall()
            finally:
                # all rows are already in memory, so the cursor can be released
                cursor.close()
    
    class ListImportTask(ImportTask):
        """Import task whose rows come from the in-memory ``DATA`` attribute."""

        description = "Abstract Python list import task"

        # rows to import; subclasses override this
        DATA = []

        def open_data_source(self):
            """Use a list copy of ``DATA`` as the rows to import.

            Copying lets ``DATA`` be any iterable, since ``count_source`` needs
            ``len(self.rows)``. It also keeps ``self.rows`` from aliasing the
            class-level ``DATA`` shared by every instance.
            """
            self.rows = list(self.DATA)
    
    
    class DataFrameImportTask(ImportTask):
        """Import task whose rows are the rows of a pandas ``DataFrame``.

        Rows in which every value is missing are skipped, both when iterating
        and when counting, so the expected count matches the imported rows.
        """

        description = "Abstract Dataframe import task"

        # DataFrame to import; subclasses override this
        dataframe = None

        def _source_frame(self):
            """Return ``dataframe`` without the rows whose values are all NA."""
            return self.dataframe.dropna(how='all')

        def open_data_source(self):
            """Expose the non-empty rows of ``dataframe`` as a lazy iterator."""
            self.rows = (row for _, row in self._source_frame().iterrows())

        def count_source(self):
            """Count the rows that ``open_data_source`` will actually yield.

            ``self.rows`` is a generator, so it cannot be measured with
            ``len``; count the same filtered frame instead.
            """
            self.source_count = self._source_frame().shape[0]
    
    
    class XMLImportTask(ImportTask):
        """Import task whose rows are elements selected from an XML document.

        Subclasses set ``xmlFile`` to the document to parse and
        ``xpathSelector`` to the ElementTree path that selects one element
        per row, evaluated from the document root.
        """

        description = "Abstract XML import task"

        # file name or file object of the XML document (required)
        xmlFile = None
        # ElementTree path expression; by default every child of the root
        xpathSelector = "*"

        def open_data_source(self):
            """Parse ``xmlFile`` and store the matching elements in ``self.rows``.

            :raises CommandError: if ``xmlFile`` has not been set.
            """
            if self.xmlFile is None:
                raise CommandError(
                    '{}.xmlFile must be set to the XML file to import'.format(
                        type(self).__name__))
            root = ET.parse(self.xmlFile).getroot()
            self.rows = root.findall(self.xpathSelector)

        def count_source(self):
            """Store the number of selected elements in ``self.source_count``."""
            self.source_count = len(self.rows)
    
    class ImportCommand(BaseCommand):
        """Management command that runs a selection of import tasks.

        Subclasses list their ``ImportTask`` subclasses in ``task_classes``.
        The positional ``task`` argument names one task (its ``option``) or
        ``all``. The named task runs together with every task that depends on
        it, directly or indirectly. A task always runs after the selected
        tasks listed in its ``depends_on``; otherwise tasks run in the order
        of ``task_classes``.
        """
        help = "Generic command to import data into a django database"

        # list all the import tasks that should be available via the command
        task_classes = []

        def add_arguments(self, parser):
            """Declare the ``task`` positional argument and the option flags."""
            task_choices = [task_class.option for task_class in self.task_classes]
            task_choices.append('all')
            task_help = 'Import task to be run.\n '
            task_help += ',\n \n'.join([task_class.option + ': ' + task_class.description for task_class in self.task_classes])
            task_help += ',\n \nall: import everything.'
            parser.add_argument(
                'task',
                type=str,
                choices=task_choices,
                help=task_help,
            )
            # NOTE(review): --all is accepted but not read by handle(); the
            # 'all' value of the task argument is what selects every task.
            parser.add_argument(
                '--all',
                action='store_true',
                dest='all',
                default=False,
                help='import everything',
            )
            parser.add_argument(
                '--wscache',
                action='store_true',
                dest='wscache',
                default=False,
                help='use web services cache',
            )
            parser.add_argument(
                '--errortb',
                action='store_true',
                dest='errortb',
                default=False,
                help='show tracebacks on errors',
            )
            parser.add_argument(
                '--stoponfail',
                action='store_true',
                dest='stoponfail',
                default=False,
                help='stop on first error',
            )
            parser.add_argument(
                '--progressbar',
                action='store_true',
                dest='progress_bar',
                default=False,
                help='show progressbar instead of logging inserted entries',
            )
            parser.add_argument(
                '--check',
                action='store_true',
                dest='check',
                default=False,
                help='only check instead of running import',
            )

        def handle(self, *args, **options):
            """Run (or, with ``--check``, only check) the selected tasks.

            :raises CommandError: if a task depends on a task class that is
                not listed in ``task_classes``, if the selected tasks depend
                on each other circularly, or if a task fails with
                ``--stoponfail``.
            """
            # use web service cache if option has been selected
            if options.get('wscache'):
                requests_cache.install_cache('ws_cache')
            task_option = options.get('task')
            # map task names to task classes
            option_to_task = {task_class.option: task_class for task_class in self.task_classes}
            # map task names to the list of task names that depend upon them
            dependency_to_task = {option: [] for option in option_to_task}
            for task_class in self.task_classes:
                for dependency_class in task_class.depends_on:
                    if dependency_class.option not in dependency_to_task:
                        raise CommandError(
                            'Task {} depends on {}, which is not listed in task_classes'.format(
                                task_class.option, dependency_class.option))
                    dependency_to_task[dependency_class.option].append(task_class.option)
            # 'all' behaves as if every task depended on it
            dependency_to_task['all'] = list(option_to_task)
            # task names to run because they depend on the initial option,
            # directly or not (empty when no task was given)
            if task_option:
                selected_options = self._collect_dependents(task_option, dependency_to_task)
            else:
                selected_options = set()
            pending = {task_class for opt, task_class in option_to_task.items()
                       if opt in selected_options}
            task_kwargs = {
                'traceback': options['errortb'],
                'stoponfail': options['stoponfail'],
                'progress_bar': options['progress_bar'],
                'conn': options.get('conn')
            }
            self._run_tasks(pending, task_kwargs, check_only=options.get('check', False))

        @staticmethod
        def _collect_dependents(option, dependency_to_task):
            """Return *option* plus every task name depending on it, transitively.

            Each name is visited once, so dependency cycles cannot recurse
            forever.
            """
            collected = set()
            stack = [option]
            while stack:
                current = stack.pop()
                if current in collected:
                    continue
                collected.add(current)
                stack.extend(dependency_to_task.get(current, []))
            return collected

        def _run_tasks(self, pending, task_kwargs, check_only=False):
            """Instantiate and run every task class in *pending*, dependencies first.

            At each step the first class in ``task_classes`` order whose
            pending dependencies have all been run is executed: ``check()``
            when *check_only* is true, otherwise ``run()``.

            :raises CommandError: if no pending task can run because the
                remaining tasks depend on each other circularly.
            """
            pending = set(pending)
            while pending:
                ready = [task_class for task_class in self.task_classes
                         if task_class in pending
                         and not (set(task_class.depends_on) & pending)]
                if not ready:
                    raise CommandError(
                        'Circular dependency between import tasks: {}'.format(
                            ', '.join(sorted(task_class.option for task_class in pending))))
                task_class = ready[0]
                task = task_class(self, **task_kwargs)
                if check_only:
                    # --check must never flush or re-import data
                    task.check()
                else:
                    task.run()
                pending.discard(task_class)