diff --git a/src/dataloaders.py b/src/dataloaders.py
index 3fd6fa5e07a216fa3fa835fdb8d4bcd6c2b5bfa4..60adba8c09f932506679163923529a55b68ca8b7 100644
--- a/src/dataloaders.py
+++ b/src/dataloaders.py
@@ -62,6 +62,41 @@ class EdgePairWithCommonNodeDataset(Dataset):
return positive_edge, negative_edge
+import itertools
+
+class EdgePairWithCommonNodeDatasetFull(Dataset):
+ def __init__(self, graph, node_mapping=None):
+ self.graph = graph
+ if node_mapping is None:
+ node_mapping = {'id_to_key': {i: i for i in range(graph.order())}, 'key_to_id': {i: i for i in range(graph.order())}}
+ self.node_mapping = node_mapping
+ self.node_list = list(self.graph.nodes())
+ self.indices = []
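+ # enumerate all ordered triples (u, v, w): e1 = (u, v) and e2 = (u, w)
+ # share the source node u; a missing edge counts as weight 0, and only
+ # pairs with differing weights are kept, so one edge can always be
+ # ranked above the other (O(n^3) triples, so small graphs only)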
+ for u, v, w in itertools.permutations(self.graph.nodes(), 3):
+ e1 = (u, v)
+ e2 = (u, w)
+ w1 = self.graph.get_edge_data(*e1)['weight'] if self.graph.has_edge(*e1) else 0
+ w2 = self.graph.get_edge_data(*e2)['weight'] if self.graph.has_edge(*e2) else 0
+ if w1 != w2:
+ self.indices.append((u,v,w))
+ print(f'{len(self)} comparable edge pairs')
+
+ def __len__(self):
+ return len(self.indices)
+
+ def __getitem__(self, i):
+ u, v, w = self.indices[i]
+ e1 = (u, v)
+ e2 = (u, w)
+ w1 = self.graph.get_edge_data(*e1)['weight'] if self.graph.has_edge(*e1) else 0
+ w2 = self.graph.get_edge_data(*e2)['weight'] if self.graph.has_edge(*e2) else 0
+
+ # the heavier edge is the positive example, the lighter the negative
+ positive_edge, negative_edge = (e1, e2) if w1 > w2 else (e2, e1)
+ # translate graph node keys back to integer embedding ids
+ positive_edge = (self.node_mapping['key_to_id'][positive_edge[0]], self.node_mapping['key_to_id'][positive_edge[1]])
+ negative_edge = (self.node_mapping['key_to_id'][negative_edge[0]], self.node_mapping['key_to_id'][negative_edge[1]])
+
+ return positive_edge, negative_edge
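+
+# Example usage (a sketch; the batch size is arbitrary):
+# pairs = DataLoader(EdgePairWithCommonNodeDatasetFull(graph), batch_size=256)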
+
class EdgePairDataset(IterableDataset):
def __init__(self, graph, node_mapping=None):
self.graph = graph
@@ -125,16 +160,15 @@ class EdgeDataset(Dataset):
anchor_key, target_key = self.node_mapping['id_to_key'][i], self.node_mapping['id_to_key'][j]
return (i,j), (self.graph.get_edge_data(anchor_key, target_key)['weight'] if self.graph.has_edge(anchor_key,target_key) else 0)
-if __name__ == '__main__':
- from torch.utils.data import DataLoader
- from synthetic_graphs import *
- from dataloaders import EdgeTripletDataset
+# if __name__ == '__main__':
+# from torch.utils.data import DataLoader
+# from synthetic_graphs import *
- N, M = 15, 12
- graph, _ = synfire_chain(N, M)
- dataset = EdgeTripletDataset(graph)
- dataloader = DataLoader(dataset, shuffle=False, batch_size=N*M)
+# N, M = 15, 12
+# graph, _ = synfire_chain(N, M)
+# dataset = EdgeTripletDataset(graph)
+# dataloader = DataLoader(dataset, shuffle=False, batch_size=N*M)
- for _ in dataloader:
- pass
+# for _ in dataloader:
+# pass
diff --git a/src/experiment1.py b/src/experiment1.py
new file mode 100644
index 0000000000000000000000000000000000000000..8649abe8470816873b6ba19fdb3ea3b1ddff8947
--- /dev/null
+++ b/src/experiment1.py
@@ -0,0 +1,99 @@
+import os
+import numpy as np
+from torch.utils.data import DataLoader
+import torch
+from torch.optim import SGD, Adam, LBFGS
+from synthetic_graphs import *
+from dataloaders import *
+from single_space_model import Model
+from kernels import GaussianKernel
+import matplotlib.pyplot as plt
+from matplotlib.animation import FuncAnimation
+from torch.utils.tensorboard import SummaryWriter
+
+device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+N, M = 15, 12
+graph, _ = synfire_chain(N, M)
+dataset = EdgePairWithCommonNodeDataset(graph)
+dataloader = DataLoader(dataset, shuffle=False, batch_size=len(dataset))
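+# full-batch training: every edge pair is in a single batch each epoch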
+
+model = Model(graph.order(), 3, GaussianKernel(3, bias=True), activation=torch.sigmoid)
+model.init_embedding(graph)
+model.to(device)
+optim = Adam(model.parameters())
+# optim = LBFGS(model.parameters(), lr=1, line_search_fn='strong_wolfe')
+n_epochs = 100000
+
+writer = SummaryWriter(comment='EXP001')
+slow_interval = n_epochs//20
+fast_interval = n_epochs//1000
+loss_history = []
+bias_history = []
+embedding_snap = []
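+# fast_logging records cheap in-memory snapshots for the final animation;
+# slow_logging writes the full kernel matrix and a figure to TensorBoard.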
+def fast_logging(epoch):
+ bias_history.append(model.kernel.bias.detach().cpu().numpy().flatten())
+ embedding_snap.append(model.x.weight.detach().cpu().numpy())
+def slow_logging(epoch):
+ with torch.no_grad():
+ # kernel similarity between every ordered pair of nodes
+ dot_product_matrix = torch.zeros((len(graph), len(graph)))
+ for i in range(len(graph)):
+ row = torch.full((len(graph),), i, dtype=torch.long)
+ dot_product_matrix[i, :] = model.apply_kernel(row, torch.arange(len(graph))).cpu()
+
+ writer.add_image('kernel_product', dot_product_matrix, epoch, dataformats='HW')
+ bias = model.kernel.bias.detach().cpu().numpy().flatten()
+ embedding = model.x.weight.detach().cpu().numpy()
+ fig = plt.figure()
+ plt.scatter(embedding[:,0], embedding[:,1], c=[i//M for i in range(len(graph))], cmap='Set1')
+ plt.arrow(0,0,bias[0], bias[1])
+ plt.axis('equal')
+ writer.add_figure('embedding', fig, epoch)
+ plt.close()
+
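+# LBFGS may re-evaluate the objective several times per step, so the loss
+# computation is wrapped in a closure; Adam also accepts a closure and
+# simply evaluates it once per step.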
+def closure(e1, e2): # for LBFGS
+ def f():
+ optim.zero_grad()
+ loss = -model.triplet_loss(e1, e2).mean()
+ loss.backward()
+ return loss
+ return f
+
+try:
+ for epoch in range(n_epochs):
+ sum_loss = 0
+ if epoch % 100 == 0 and epoch > 0:
+ print(f'epoch {epoch}/{n_epochs}')
+ for e1, e2 in dataloader:
+ optim.step(closure=closure(e1,e2))
+ # model.center_rescale_embedding()
+ with torch.no_grad():
+ sum_loss += -model.triplet_loss(e1, e2).mean().item()
+ epoch_loss = sum_loss/len(dataloader) + 1 # raw loss lies in [-1, 0]; shift into [0, 1] for logging
+ writer.add_scalar('loss', epoch_loss, epoch)
+
+ if epoch % fast_interval == 0 or epoch == n_epochs-1:
+ fast_logging(epoch)
+ if epoch % slow_interval == 0 or epoch == n_epochs-1:
+ slow_logging(epoch)
+except KeyboardInterrupt:
+ print("Final logging. Interrupt again to skip.")
+ fast_logging(epoch)
+ slow_logging(epoch)
+
+print("Final logging.")
+fig = plt.figure()
+pointcloud = plt.scatter(embedding_snap[0][:,0], embedding_snap[0][:,1], c=[i//M for i in range(len(graph))])
+arrow = plt.arrow(0, 0, 0, 0)
+maxs = 1.05*np.max(np.concatenate(embedding_snap, axis=0), axis=0)
+mins = 1.05*np.min(np.concatenate(embedding_snap, axis=0), axis=0)
+plt.xlim(mins[0], maxs[0])
+plt.ylim(mins[1], maxs[1])
+
+def anim(i):
+ pointcloud.set_offsets(embedding_snap[i])
+ arrow.set_data(dx=bias_history[i][0], dy=bias_history[i][1])
+ return pointcloud, arrow
+
+ani = FuncAnimation(fig, anim, len(embedding_snap), interval=50, repeat=True)
+ani.save(os.path.join(writer.log_dir, 'latent_space_animation.mp4'))
+plt.close()
+writer.close()
\ No newline at end of file
diff --git a/src/single_space_model.py b/src/single_space_model.py
index 49c102c1a92fd28b433ba0196b2f74a1d27781b9..96a10e7f94aa2048a07ef416decb41faacc7ec5d 100644
--- a/src/single_space_model.py
+++ b/src/single_space_model.py
@@ -1,15 +1,27 @@
import torch.nn as nn
import torch
+from networkx import spectral_layout
device = 'cuda' if torch.cuda.is_available() else 'cpu'
class Model(nn.Module):
- def __init__(self, graph, dim, kernel, activation):
+ def __init__(self, n_nodes, dim, kernel, activation):
super().__init__()
- self.x = nn.Embedding(graph.order(), dim)
+ self.n_nodes = n_nodes
+ self.dim = dim
+ self.x = nn.Embedding(n_nodes, dim)
self.kernel= kernel
self.activation = activation
+ def init_embedding(self, graph, node_mapping = None):
+ if node_mapping is None:
+ node_mapping = {'id_to_key': {i: i for i in range(graph.order())}, 'key_to_id': {i: i for i in range(graph.order())}}
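+ # networkx's spectral_layout gives a structure-aware starting point;
+ # rows are reordered to embedding ids and scaled to unit std below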
+ init_embedding = spectral_layout(graph, dim=self.dim)
+ init_embedding = [torch.Tensor(init_embedding[node_mapping['id_to_key'][i]]) for i in range(len(graph))]
+ init_embedding = torch.stack(init_embedding, dim=0)
+ init_embedding = init_embedding/init_embedding.std()
+ self.x.weight.data = init_embedding
+
def apply_kernel(self, node1, node2):
vec1, vec2 = self.x(node1.to(device)), self.x(node2.to(device))
return self.kernel(vec1, vec2)
@@ -25,93 +37,4 @@ class Model(nn.Module):
def triplet_loss(self, e_pos, e_neg):
k_pos = self.apply_kernel(*e_pos)
k_neg = self.apply_kernel(*e_neg)
- return k_pos/(k_pos + k_neg)
-
-if __name__ == '__main__':
- import os
- from torch.utils.data import DataLoader
- from torch.optim import SGD, Adam
- from synthetic_graphs import *
- from dataloaders import *
- from kernels import GaussianKernel
- import matplotlib.pyplot as plt
- from matplotlib.animation import FuncAnimation
- from torch.utils.tensorboard import SummaryWriter
-
- N, M = 15, 12
- graph, _ = synfire_chain(N, M)
- dataset = EdgePairWithCommonNodeDataset(graph)
- dataloader = DataLoader(dataset, shuffle=False, batch_size=N*M)
-
- model = Model(graph, 3, GaussianKernel(3, bias=True), activation=torch.sigmoid)
- model.to(device)
- optim = Adam(model.parameters())
- n_epochs = 100000
-
- writer = SummaryWriter()
- slow_interval = n_epochs//20
- fast_interval = n_epochs//1000
- loss_history = []
- bias_history = []
- embedding_snap = []
- def fast_logging(epoch):
- bias_history.append(model.kernel.bias.detach().cpu().numpy().flatten())
- embedding_snap.append(model.x.weight.cpu().detach().numpy())
- def slow_logging(epoch):
- with torch.no_grad():
- dot_product_matrix = torch.zeros((len(graph), len(graph)))
- for i in torch.arange(len(graph), dtype=int).reshape(-1,1):
- dot_product_matrix[i.item(),:] = model.apply_kernel(i.repeat(len(graph)), torch.arange(len(graph))).cpu()
- writer.add_image('kernel_product', dot_product_matrix, epoch, dataformats='HW')
-
- bias = model.kernel.bias.detach().cpu().numpy().flatten()
- embedding = model.x.weight.cpu().detach().numpy()
- fig = plt.figure()
- plt.scatter(embedding[:,0], embedding[:,1], c=[i//M for i in range(len(graph))], cmap='Set1')
- plt.arrow(0,0,bias[0], bias[1])
- plt.axis('equal')
- writer.add_figure('embedding', fig, epoch)
- plt.close()
-
- try:
- for epoch in range(n_epochs):
- sum_loss = 0
- if not(epoch%100) and epoch:
- print(f'epoch {epoch}/{n_epochs}')
- for e1, e2 in dataloader:
- optim.zero_grad()
- loss = -model.triplet_loss(e1, e2).mean()
- loss.backward()
- optim.step()
- # model.center_rescale_embedding()
- sum_loss += loss.cpu().detach().item()
- epoch_loss = sum_loss/len(dataloader)+1
- writer.add_scalar('loss', epoch_loss, epoch)
-
- if not(epoch%fast_interval) or epoch == n_epochs-1:
- fast_logging(epoch)
- if not(epoch%slow_interval) or epoch == n_epochs-1:
- slow_logging(epoch)
- except KeyboardInterrupt:
- print("Final logging. Interrupt again to skip.")
- fast_logging(epoch)
- slow_logging(epoch)
-
- print("Final logging.")
- fig = plt.figure()
- pointcloud = plt.scatter(embedding_snap[0][:,0], embedding_snap[0][:,1], c=[i//12 for i in range(len(graph))])
- arrow = plt.arrow(0, 0, 0, 0)
- maxs = 1.05*np.max(np.concatenate(embedding_snap, axis=0), axis=0)
- mins = 1.05*np.min(np.concatenate(embedding_snap, axis=0), axis=0)
- plt.xlim(mins[0], maxs[0])
- plt.ylim(mins[1], maxs[1])
-
- def anim(i):
- pointcloud.set_offsets(embedding_snap[i])
- arrow.set_data(dx=bias_history[i][0], dy=bias_history[i][1])
- return pointcloud, arrow
-
- ani = FuncAnimation(fig, anim, len(embedding_snap), interval=50, repeat=True)
- ani.save(os.path.join(writer.log_dir, 'latent_space_animation.mp4'))
- plt.close()
- writer.close()
\ No newline at end of file
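+ # fraction of kernel mass on the positive edge; lies in [0, 1] and is
+ # maximised when the positive pair is far more similar than the negative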
+ return k_pos/(k_pos + k_neg)
\ No newline at end of file