Skip to content
Snippets Groups Projects
Select Git revision
  • 061f40922d04f7b86a67b95f1ff07ebee2df58e5
  • master default protected
  • dev
  • score_test
4 results

AbstractDGFactory.py

Blame
  • user avatar
    Yoann Dufresne authored
    061f4092
    History
    AbstractDGFactory.py 2.34 KiB
    import networkx as nx
    import sys
    from abc import abstractmethod
    from multiprocessing import Pool, Value
    
    from deconvolution.dgraph.FixedDGIndex import FixedDGIndex, AbstractDGIndex
    
    # Shared progress counter (a multiprocessing.Value), (re)assigned by
    # AbstractDGFactory. It numbers nodes in the order processing starts on
    # them and is only used for verbose progress messages.
    counter = None
    
    def process_node(factory, node):
        """Generate and domination-filter the d-graphs of a single node.

        Defined at module level so it can be pickled and dispatched by
        ``Pool.starmap``.

        :param factory: factory object providing ``graph``, ``verbose``,
            ``nb_nodes`` and ``generate_by_node(node, subgraph)``.
        :param node: node of ``factory.graph`` to process.
        :return: tuple ``(node, dgs)`` where ``dgs`` is the result of
            ``AbstractDGIndex.filter_entry`` on the generated d-graphs.
        """
        # Read-and-increment must happen under the Value's lock: the counter is
        # shared by all worker processes, and an unlocked ``+=`` lets two
        # workers read the same number (and lose an increment).
        with counter.get_lock():
            my_value = counter.value
            counter.value += 1
    
        if factory.verbose:
            print(f"{my_value}: Generating d-graphs")
            sys.stdout.flush()
    
        # udg generation: work on the subgraph induced by the node's neighbors
        neighbors = list(factory.graph[node])
        subgraph = nx.Graph(factory.graph.subgraph(neighbors))
        dgs = factory.generate_by_node(node, subgraph)
    
        if factory.verbose:
            print(f"{my_value}: d-graphs generated, starting filtering")
            print(f"{my_value}: {len(dgs)} udg to filter")
            sys.stdout.flush()
    
        # udg domination filtering
        dgs = AbstractDGIndex.filter_entry(dgs)
    
        if factory.verbose:
            print(f"{my_value}: {len(dgs)} udg remaining after filtering")
            print(f"{my_value}({factory.nb_nodes}) terminated")
            sys.stdout.flush()
    
        return node, dgs
    
    def _init_worker_counter(shared_counter):
        """Pool initializer: install the parent's shared counter in a worker.

        Under the "spawn"/"forkserver" start methods a worker re-imports this
        module, so the module-level ``counter`` would still be None there.
        Passing the Value through ``initargs`` shares it with every worker
        regardless of the start method.
        """
        global counter
        counter = shared_counter
    
    class AbstractDGFactory:
        """Base class for factories that build d-graphs for each graph node.

        Subclasses implement :meth:`generate_by_node`.
        :meth:`generate_all_dgraphs` runs it over every node, optionally in a
        process pool, and collects the filtered results in a ``FixedDGIndex``.
        """

        def __init__(self, graph, debug=False):
            """Store the input graph and reset the shared progress counter.

            :param graph: graph to process; this module uses ``nodes()``,
                ``graph[node]`` and ``subgraph()`` on it (presumably a
                networkx graph -- confirm with callers).
            :param debug: stored as ``self.debug``; not read by this class.
            """
            self.debug = debug
            self.graph = graph
            self.nb_nodes = len(self.graph.nodes())
            # Verbosity is chosen per call in generate_all_dgraphs.
            self.verbose = False
            global counter
            counter = Value('i', 0)
    
        def generate_all_dgraphs(self, verbose=False, threads=8):
            """Generate and filter the d-graphs of every node, then index them.

            :param verbose: print progress messages; stored on ``self.verbose``
                so that the factory copies sent to workers see it too.
            :param threads: number of worker processes; ``threads <= 1`` runs
                everything sequentially in the current process.
            :return: a ``FixedDGIndex(size=1)`` in which each node's d-graphs
                are added under the key ``frozenset({node})``.
            """
            index = FixedDGIndex(size=1)
            nodes = list(self.graph.nodes())
            self.verbose = verbose
            global counter
            # Fresh progress counter for this run.
            counter = Value('i', 0)
    
            if verbose:
                print("Start parallel work")
    
            if threads > 1:
                # Give the counter to each worker at start-up. Synchronized
                # Values cannot travel as task arguments, and relying on the
                # inherited global only works with the "fork" start method.
                with Pool(processes=threads,
                          initializer=_init_worker_counter,
                          initargs=(counter,)) as pool:
                    results = pool.starmap(
                        process_node, [(self, node) for node in nodes]
                    )
            else:
                # Lazy generator: each node is processed right before its
                # d-graphs are indexed, as in a plain loop.
                results = (process_node(self, node) for node in nodes)
    
            # Fill the index by node
            for node, dgs in results:
                key = frozenset({node})
                for dg in dgs:
                    index.add_value(key, dg)
    
            return index
    
        @abstractmethod
        def generate_by_node(self, node, subgraph):
            """Return the d-graphs of ``node`` built from ``subgraph``.

            ``subgraph`` is the subgraph induced by the neighbors of ``node``.
            The class does not use ``ABCMeta``, so ``@abstractmethod`` is not
            enforced at instantiation. Raising here makes a missing override
            fail loudly instead of returning None.
            """
            raise NotImplementedError(
                f"{type(self).__name__} must implement generate_by_node()"
            )