# Unit tests for D2Graph (d2_graph_test.py).
import unittest

from scipy.special import comb

from d2_graph import D2Graph
from d_graph import Dgraph
import graph_manipulator as gm
from tests.d_graph_data import complete_graph


class TestD2Graph(unittest.TestCase):
    def test_construction(self):
Yoann Dufresne's avatar
Yoann Dufresne committed
13
        d2 = D2Graph(complete_graph, 6)
14
15

        # Evaluate the number of candidate unit d_graphs generated
Yoann Dufresne's avatar
Yoann Dufresne committed
16
        for node, candidates in d2.d_graphs_per_node.items():
17
18
19
20
21
            if node == "C" or node == "B2":
                self.assertEquals(1, len(candidates))
            else:
                self.assertEquals(0, len(candidates))

Yoann Dufresne's avatar
Yoann Dufresne committed
22
23
        # Evaluate the index
        self.assertEquals(13, len(d2.index))
24

Yoann Dufresne's avatar
Yoann Dufresne committed
25
26
27
28
29
30
31
        overlap_key = ('A1', 'A2', 'B0', 'B1', 'B2', 'C')
        for dmer, dg_lst in d2.index.items():
            if dmer == overlap_key:
                self.assertEquals(2, len(d2.index[dmer]))
                self.assertNotEquals(d2.index[dmer][0], d2.index[dmer][1])
            else:
                self.assertEquals(1, len(d2.index[dmer]))
32
33


34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
    def test_linear_d2_construction(self):
        for d in range(1, 10):
            size = 2 * d + 3
            index_k = 2 * d - 1


            G = gm.generate_d_graph_chain(size, d)
            d2 = D2Graph(G, index_size=index_k)

            # Test the number of d-graphs
            awaited_d_num = size - 2 * d
            self.assertEquals(awaited_d_num, len(d2.all_d_graphs))

            # Test index
            awaited_index_size = comb(2*d+1, index_k) + (size - (2*d+1)) * comb(2*d, index_k-1)
            if len(d2.index) != awaited_index_size:
                dmers = [list(x) for x in d2.index]
51
                dmers = [str(x) for x in dmers if len(x) != len(frozenset(x))]
52

53
54
55
            self.assertEquals(awaited_index_size, len(d2.index))

            d2_nx = d2.nx_graph
56

57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
            # Test connectivity
            # Center node names
            c1 = d
            c2 = d+1
            c3 = d+2
            # Connectivity matrix
            awaited_distances = {
                c1:{c2:2, c3:4},
                c2:{c1:2, c3:2},
                c3:{c1:4, c2:2}
            }
            
            # distance tests
            for idx1, neighbors in d2.distances.items():
                dg1 = d2.node_by_idx[idx1]

                for idx2, dist in neighbors.items():
                    dg2 = d2.node_by_idx[idx2]
                    
                    awaited_dist = awaited_distances[dg1.center][dg2.center]
                    self.assertEquals(dist, awaited_dist)


# Run the test suite when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()