class EdgePairWithCommonNodeDatasetFull(Dataset):
    """Exhaustive dataset of edge pairs that share their first node.

    Every ordered triple ``(u, v, w)`` of distinct nodes is kept when the
    weights of ``(u, v)`` and ``(u, w)`` differ; a missing edge counts as
    weight 0.  Because triples are ordered, both ``(u, v, w)`` and
    ``(u, w, v)`` are stored, so each unordered pair of edges appears twice.

    Args:
        graph: graph object exposing ``nodes()``, ``order()``, ``has_edge()``
            and ``get_edge_data()`` (networkx-style -- TODO confirm the
            concrete graph class used by callers).  Existing edges must carry
            a ``'weight'`` attribute.
        node_mapping: optional dict with ``'id_to_key'`` and ``'key_to_id'``
            sub-dicts translating between integer ids and node keys.  When
            omitted, an identity mapping over ``range(graph.order())`` is
            used, which assumes the node keys are exactly those integers.
    """

    def __init__(self, graph, node_mapping=None):
        import itertools

        self.graph = graph
        if node_mapping is None:
            n_nodes = graph.order()
            node_mapping = {
                'id_to_key': {i: i for i in range(n_nodes)},
                'key_to_id': {i: i for i in range(n_nodes)},
            }
        self.node_mapping = node_mapping
        self.node_list = list(self.graph.nodes())

        self.indices = []
        for position, u in enumerate(self.node_list):
            # Removing the anchor by position keeps the iteration order equal
            # to itertools.permutations(nodes, 3) restricted to this anchor.
            others = self.node_list[:position] + self.node_list[position + 1:]
            # The weight of (u, x) depends only on the pair, so look it up
            # once per pair instead of once per triple.
            weight_from_u = {
                node: self._edge_weight(u, node) for node in others
            }
            self.indices.extend(
                (u, v, w)
                for v, w in itertools.permutations(others, 2)
                if weight_from_u[v] != weight_from_u[w]
            )
        # Report the dataset size on construction.
        print(len(self))

    def _edge_weight(self, source, target):
        """Return the ``'weight'`` of edge (source, target), or 0 if absent."""
        if self.graph.has_edge(source, target):
            return self.graph.get_edge_data(source, target)['weight']
        return 0

    def __len__(self):
        """Return the number of stored (u, v, w) triples."""
        return len(self.indices)

    def __getitem__(self, i):
        """Return ``(positive_edge, negative_edge)`` for the i-th triple.

        Both edges start at the shared node; the positive edge is the one
        with the larger weight.  Node keys are translated to ids through
        ``node_mapping['key_to_id']``.
        """
        u, v, w = self.indices[i]
        e1 = (u, v)
        e2 = (u, w)

        # Weights are re-read from the graph rather than cached at build time.
        if self._edge_weight(*e1) > self._edge_weight(*e2):
            positive_edge, negative_edge = e1, e2
        else:
            positive_edge, negative_edge = e2, e1

        key_to_id = self.node_mapping['key_to_id']
        positive_edge = (key_to_id[positive_edge[0]], key_to_id[positive_edge[1]])
        negative_edge = (key_to_id[negative_edge[0]], key_to_id[negative_edge[1]])

        return positive_edge, negative_edge
"""Experiment 001: train the single-space kernel model on a synfire-chain graph.

The script trains full-batch with Adam, logs the loss, kernel matrix and
embedding to TensorBoard, and finally saves an animation of the first two
embedding dimensions into the TensorBoard log directory.
"""
import os

import matplotlib.pyplot as plt
import numpy as np  # used for the axis limits below; previously never imported explicitly
import torch
from matplotlib.animation import FuncAnimation
from torch.optim import SGD, Adam, LBFGS
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

from dataloaders import EdgePairWithCommonNodeDataset
from kernels import GaussianKernel
from single_space_model import Model
from synthetic_graphs import synfire_chain

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Arguments of synfire_chain(); M is also the block size used to colour nodes.
N, M = 15, 12
graph, _ = synfire_chain(N, M)
dataset = EdgePairWithCommonNodeDataset(graph)
# Full-batch training: a single batch holds the whole dataset.
dataloader = DataLoader(dataset, shuffle=False, batch_size=len(dataset))

model = Model(graph.order(), 3, GaussianKernel(3, bias=True), activation=torch.sigmoid)
model.init_embedding(graph)
model.to(device)
optim = Adam(model.parameters())
# Alternative optimiser; the closure-based step below works with either.
# optim = LBFGS(model.parameters(), lr=1, line_search_fn='strong_wolfe')
n_epochs = 100000

writer = SummaryWriter(comment='EXP001')
# Clamp to 1 so that `epoch % interval` cannot divide by zero for short runs.
slow_interval = max(1, n_epochs // 20)
fast_interval = max(1, n_epochs // 1000)
bias_history = []
embedding_snap = []


def fast_logging(epoch):
    """Record the current kernel bias and embedding for the final animation."""
    bias_history.append(model.kernel.bias.detach().cpu().numpy().flatten())
    embedding_snap.append(model.x.weight.cpu().detach().numpy())


def slow_logging(epoch):
    """Write the pairwise kernel matrix and an embedding scatter to TensorBoard."""
    n_nodes = len(graph)
    all_nodes = torch.arange(n_nodes)
    with torch.no_grad():
        dot_product_matrix = torch.zeros((n_nodes, n_nodes))
        for row in range(n_nodes):
            # Kernel between node `row` and every node, one matrix row at a time.
            anchors = torch.full((n_nodes,), row, dtype=torch.long)
            dot_product_matrix[row, :] = model.apply_kernel(anchors, all_nodes).cpu()

    writer.add_image('kernel_product', dot_product_matrix, epoch, dataformats='HW')
    bias = model.kernel.bias.detach().cpu().numpy().flatten()
    embedding = model.x.weight.cpu().detach().numpy()
    fig = plt.figure()
    # Only the first two embedding dimensions are drawn.
    plt.scatter(embedding[:, 0], embedding[:, 1], c=[i // M for i in range(n_nodes)], cmap='Set1')
    plt.arrow(0, 0, bias[0], bias[1])
    plt.axis('equal')
    writer.add_figure('embedding', fig, epoch)
    plt.close(fig)


def closure(e1, e2):  # for LBFGS
    """Build the closure that `optim.step` calls to evaluate loss and gradients."""
    def f():
        optim.zero_grad()
        loss = -model.triplet_loss(e1, e2).mean()
        loss.backward()
        return loss
    return f


# Bound up front so the interrupt handler can still log if the loop never started.
epoch = 0
try:
    for epoch in range(n_epochs):
        sum_loss = 0
        if epoch and not epoch % 100:
            print(f'epoch {epoch}/{n_epochs}')
        for e1, e2 in dataloader:
            optim.step(closure=closure(e1, e2))
            # model.center_rescale_embedding()
            with torch.no_grad():
                # Loss is re-evaluated after the parameter update.
                sum_loss += -model.triplet_loss(e1, e2).mean().item()
        # Shifted by +1, presumably so the logged value is non-negative -- TODO confirm.
        epoch_loss = sum_loss / len(dataloader) + 1
        writer.add_scalar('loss', epoch_loss, epoch)

        if not (epoch % fast_interval) or epoch == n_epochs - 1:
            fast_logging(epoch)
        if not (epoch % slow_interval) or epoch == n_epochs - 1:
            slow_logging(epoch)
except KeyboardInterrupt:
    print("Final logging. Interrupt again to skip.")
    fast_logging(epoch)
    slow_logging(epoch)

print("Final logging.")
fig = plt.figure()
pointcloud = plt.scatter(
    embedding_snap[0][:, 0],
    embedding_snap[0][:, 1],
    c=[i // M for i in range(len(graph))],  # same grouping as slow_logging
)
arrow = plt.arrow(0, 0, 0, 0)
stacked = np.concatenate(embedding_snap, axis=0)
lows, highs = stacked.min(axis=0), stacked.max(axis=0)
# Pad by a fraction of the span: scaling the bounds themselves would shrink
# the view whenever a minimum is positive or a maximum is negative.
pad = 0.05 * (highs - lows)
plt.xlim(lows[0] - pad[0], highs[0] + pad[0])
plt.ylim(lows[1] - pad[1], highs[1] + pad[1])


def anim(i):
    """Update the scatter and bias arrow to the i-th recorded snapshot."""
    # Snapshots have one column per embedding dimension; the plot is 2-D.
    pointcloud.set_offsets(embedding_snap[i][:, :2])
    arrow.set_data(dx=bias_history[i][0], dy=bias_history[i][1])
    return pointcloud, arrow


try:
    ani = FuncAnimation(fig, anim, len(embedding_snap), interval=50, repeat=True)
    ani.save(os.path.join(writer.log_dir, 'latent_space_animation.mp4'))
finally:
    # Release the figure and flush TensorBoard even if saving the video fails.
    plt.close(fig)
    writer.close()
class Model(nn.Module):
    """Node embedding in a single latent space, scored pairwise by a kernel.

    Args:
        n_nodes: number of rows of the embedding table.
        dim: dimension of each node embedding.
        kernel: callable taking two batches of embedding vectors and
            returning their kernel values.
        activation: callable stored on the model; it is not used by the
            methods shown here.

    NOTE(review): the diff this class was taken from elides the methods
    located between ``apply_kernel`` and ``triplet_loss``; they are not
    reproduced here.
    """

    def __init__(self, n_nodes, dim, kernel, activation):
        super().__init__()
        self.n_nodes = n_nodes
        self.dim = dim
        self.x = nn.Embedding(n_nodes, dim)
        self.kernel = kernel
        self.activation = activation

    def init_embedding(self, graph, node_mapping=None):
        """Initialise the embedding table from a spectral layout of ``graph``.

        Row ``i`` of the table receives the layout position of node
        ``node_mapping['id_to_key'][i]``; the whole table is then divided by
        its standard deviation.

        Args:
            graph: graph passed to ``networkx.spectral_layout``.
            node_mapping: optional dict whose ``'id_to_key'`` entry maps row
                ids to node keys.  Defaults to the identity over
                ``range(graph.order())``.
        """
        if node_mapping is None:
            id_to_key = {i: i for i in range(graph.order())}
        else:
            id_to_key = node_mapping['id_to_key']

        weight = self.x.weight
        layout = spectral_layout(graph, dim=self.dim)
        rows = [
            torch.as_tensor(layout[id_to_key[i]], dtype=weight.dtype)
            for i in range(len(graph))
        ]
        init_embedding = torch.stack(rows, dim=0)

        # Skip the rescaling for degenerate layouts (zero or undefined spread),
        # which would otherwise fill the table with inf/NaN.
        spread = init_embedding.std()
        if torch.isfinite(spread) and spread > 0:
            init_embedding = init_embedding / spread

        # Keep the parameter on whatever device the model currently lives on,
        # so this also works when called after `model.to(...)`.
        self.x.weight.data = init_embedding.to(weight.device)

    def apply_kernel(self, node1, node2):
        """Return the kernel between the embeddings of ``node1`` and ``node2``.

        Both arguments are index tensors; they are moved to the device of the
        embedding table before the lookup.
        """
        # Follow the embedding's actual device instead of a module-level guess.
        target = self.x.weight.device
        vec1, vec2 = self.x(node1.to(target)), self.x(node2.to(target))
        return self.kernel(vec1, vec2)

    def triplet_loss(self, e_pos, e_neg):
        """Return ``k_pos / (k_pos + k_neg)`` for a positive and a negative edge.

        Each edge is a pair of index tensors that is unpacked into
        ``apply_kernel``.
        """
        k_pos = self.apply_kernel(*e_pos)
        k_neg = self.apply_kernel(*e_neg)
        return k_pos / (k_pos + k_neg)