import os
import tempfile
import unittest

import networkx as nx
from scipy.special import comb

from d2_graph import D2Graph
from d_graph import Dgraph
import graph_manipulator as gm
from tests.d_graph_data import complete_graph


class TestD2Graph(unittest.TestCase):
    """Tests for building, indexing and reloading a D2Graph."""

    def test_construction(self):
        """Build a D2Graph from ``complete_graph`` and check candidates and index.

        Only nodes "C" and "B2" are expected to produce a candidate unit
        d-graph. The index must contain 13 entries. Exactly one key
        (``overlap_key``) is shared by two distinct d-graphs; every other key
        maps to a single d-graph.
        """
        d2 = D2Graph(complete_graph)
        d2.construct_from_barcodes(index_size=6, verbose=False)

        # Evaluate the number of candidate unit d_graphs generated
        for node, candidates in d2.d_graphs_per_node.items():
            expected = 1 if node in ("C", "B2") else 0
            self.assertEqual(expected, len(candidates), msg=node)

        # Evaluate the index
        self.assertEqual(13, len(d2.index))

        # The overlapping key must point to two different d-graphs; every
        # other key must be owned by exactly one d-graph.
        overlap_key = ('A1', 'A2', 'B0', 'B1', 'B2', 'C')
        for dmer, dg_lst in d2.index.items():
            if dmer == overlap_key:
                self.assertEqual(2, len(dg_lst))
                self.assertNotEqual(dg_lst[0], dg_lst[1])
            else:
                self.assertEqual(1, len(dg_lst))

    def test_linear_d2_construction(self):
        """Check d-graph count, index size and edge distances on chains.

        For every d in 1..9, a chain of ``2d + 3`` nodes is generated with
        ``gm.generate_d_graph_chain`` and indexed with ``index_size = 2d - 1``.
        Each d is run as its own subTest so a failure reports the offending d.
        """
        for d in range(1, 10):
            with self.subTest(d=d):
                size = 2 * d + 3
                index_k = 2 * d - 1

                G = gm.generate_d_graph_chain(size, d)
                d2 = D2Graph(G)
                d2.construct_from_barcodes(index_size=index_k, verbose=False)

                # Test the number of d-graphs
                awaited_d_num = size - 2 * d
                self.assertEqual(awaited_d_num, len(d2.all_d_graphs))

                # Test index. C(2d+1, k) counts the k-subsets of the first
                # (2d+1)-node window; each remaining window presumably adds
                # C(2d, k-1) new subsets (those containing its new node).
                # exact=True keeps the expected size an int rather than a float.
                awaited_index_size = (
                    comb(2 * d + 1, index_k, exact=True)
                    + (size - (2 * d + 1)) * comb(2 * d, index_k - 1, exact=True)
                )
                self.assertEqual(awaited_index_size, len(d2.index))

                # Test connectivity: expected edge distance between the
                # d-graphs, keyed by their center node names.
                c1, c2, c3 = d, d + 1, d + 2
                awaited_distances = {
                    c1: {c2: 2, c3: 4},
                    c2: {c1: 2, c3: 2},
                    c3: {c1: 4, c2: 2},
                }

                for x, y, data in d2.edges(data=True):
                    # The d-graph index is the first space-separated token of
                    # the node label; it is resolved through node_by_idx.
                    dg1 = d2.node_by_idx[int(x.split(" ")[0])]
                    dg2 = d2.node_by_idx[int(y.split(" ")[0])]
                    awaited_dist = awaited_distances[dg1.center][dg2.center]
                    self.assertEqual(data["distance"], awaited_dist)

    def test_reloading(self):
        """Write a D2Graph to GEXF, reload it and compare it to the original."""
        # Parameters
        d = 3
        size = 2 * d + 3
        index_k = 2 * d - 1

        # Create a d2 graph
        G = gm.generate_d_graph_chain(size, d)
        d2 = D2Graph(G)
        d2.construct_from_barcodes(index_size=index_k, verbose=False)

        # Save and reload through a file inside a temporary directory. Unlike
        # reopening a still-open NamedTemporaryFile by name, this also works
        # on Windows, and the directory is removed on exit.
        with tempfile.TemporaryDirectory() as tmp_dir:
            path = os.path.join(tmp_dir, "d2.gexf")
            # Save
            nx.write_gexf(d2, path)

            # Reload
            d2_reloaded = D2Graph(G)
            d2_reloaded.load(path)

            # Test the nx graph
            self.assertIsNotNone(d2_reloaded)
            self.assertEqual(len(d2_reloaded.nodes()), len(d2.nodes()))
            self.assertEqual(len(d2_reloaded.edges()), len(d2.edges()))

            # TODO: Verify distances

            # Test all_d_graphs
            self.assertEqual(len(d2_reloaded.all_d_graphs), len(d2.all_d_graphs))
            # Every original d-graph index must survive the round trip.
            reloaded_idxs = {dg.idx for dg in d2_reloaded.all_d_graphs}
            for dg in d2.all_d_graphs:
                self.assertIn(dg.idx, reloaded_idxs)


# Allow running this test module directly: ``python d2_graph_test.py``.
if __name__ == "__main__":
    unittest.main()