Commit 60353fb3 authored by alexandre-blanc

WIP

parent b2bd3edb
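# ---------------------------------------------------------------------------
# Modified file: the dataset definitions (presumably dataloaders.py, judging by
# "from dataloaders import *" in the training script added further below).
# ---------------------------------------------------------------------------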
@@ -62,6 +62,41 @@ class EdgePairWithCommonNodeDataset(Dataset):
        return positive_edge, negative_edge
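# New dataset variant: exhaustively enumerates every ordered node triple (u, v, w)
# whose two edges (u, v) and (u, w) share the anchor node u and carry different
# weights (a missing edge counts as weight 0). Each item returns the heavier edge
# as the positive example and the lighter one as the negative example, with node
# keys mapped to integer ids via node_mapping.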
class EdgePairWithCommonNodeDatasetFull(Dataset):
    def __init__(self, graph, node_mapping=None):
        self.graph = graph
        if node_mapping is None:
            node_mapping = {'id_to_key': {i: i for i in range(graph.order())},
                            'key_to_id': {i: i for i in range(graph.order())}}
        self.node_mapping = node_mapping
        self.node_list = list(self.graph.nodes())
        import itertools
        self.indices = []
        for u, v, w in itertools.permutations(self.graph.nodes(), 3):
            e1 = (u, v)
            e2 = (u, w)
            w1 = self.graph.get_edge_data(*e1)['weight'] if self.graph.has_edge(*e1) else 0
            w2 = self.graph.get_edge_data(*e2)['weight'] if self.graph.has_edge(*e2) else 0
            if w1 != w2:
                self.indices.append((u, v, w))
        print(len(self))

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, i):
        u, v, w = self.indices[i]
        e1 = (u, v)
        e2 = (u, w)
        w1 = self.graph.get_edge_data(*e1)['weight'] if self.graph.has_edge(*e1) else 0
        w2 = self.graph.get_edge_data(*e2)['weight'] if self.graph.has_edge(*e2) else 0
        # recover the corresponding edges in the directed graph
        positive_edge, negative_edge = (e1, e2) if w1 > w2 else (e2, e1)
        positive_edge = (self.node_mapping['key_to_id'][positive_edge[0]], self.node_mapping['key_to_id'][positive_edge[1]])
        negative_edge = (self.node_mapping['key_to_id'][negative_edge[0]], self.node_mapping['key_to_id'][negative_edge[1]])
        return positive_edge, negative_edge
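# Minimal usage sketch (hypothetical, not part of the commit; assumes a weighted
# networkx graph such as the one returned by synthetic_graphs.synfire_chain):
#   graph, _ = synfire_chain(15, 12)
#   dataset = EdgePairWithCommonNodeDatasetFull(graph)
#   positive_edge, negative_edge = dataset[0]   # pairs of node ids, heavier edge first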
class EdgePairDataset(IterableDataset):
    def __init__(self, graph, node_mapping=None):
        self.graph = graph
@@ -125,16 +160,15 @@ class EdgeDataset(Dataset):
        anchor_key, target_key = self.node_mapping['id_to_key'][i], self.node_mapping['id_to_key'][j]
        return (i, j), (self.graph.get_edge_data(anchor_key, target_key)['weight'] if self.graph.has_edge(anchor_key, target_key) else 0)
-if __name__ == '__main__':
-    from torch.utils.data import DataLoader
-    from synthetic_graphs import *
-    from dataloaders import EdgeTripletDataset
+# if __name__ == '__main__':
+# from torch.utils.data import DataLoader
+# from synthetic_graphs import *
-    N, M = 15, 12
-    graph, _ = synfire_chain(N, M)
-    dataset = EdgeTripletDataset(graph)
-    dataloader = DataLoader(dataset, shuffle=False, batch_size=N*M)
+# N, M = 15, 12
+# graph, _ = synfire_chain(N, M)
+# dataset = EdgeTripletDataset(graph)
+# dataloader = DataLoader(dataset, shuffle=False, batch_size=N*M)
-    for _ in dataloader:
-        pass
+# for _ in dataloader:
+# pass
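# ---------------------------------------------------------------------------
# New file added by this commit: a standalone training script (its filename is
# not shown in the extracted diff; it imports dataloaders, synthetic_graphs,
# single_space_model and kernels from the same repository).
# ---------------------------------------------------------------------------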
import os
import numpy as np  # used for np.max / np.min on the embedding snapshots below
from torch.utils.data import DataLoader
import torch
from torch.optim import SGD, Adam, LBFGS
from synthetic_graphs import *
from dataloaders import *
from single_space_model import Model
from kernels import GaussianKernel
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from torch.utils.tensorboard import SummaryWriter

device = 'cuda' if torch.cuda.is_available() else 'cpu'

N, M = 15, 12
graph, _ = synfire_chain(N, M)
dataset = EdgePairWithCommonNodeDataset(graph)
dataloader = DataLoader(dataset, shuffle=False, batch_size=len(dataset))
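# batch_size equals the dataset size, so each epoch below is a single full-batch
# pass over every qualifying (positive, negative) edge pair.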
model = Model(graph.order(), 3, GaussianKernel(3, bias=True), activation=torch.sigmoid)
model.init_embedding(graph)
model.to(device)
optim = Adam(model.parameters())
# optim = LBFGS(model.parameters(), lr=1, line_search_fn='strong_wolfe')
n_epochs = 100000
writer = SummaryWriter(comment='EXP001')
slow_interval = n_epochs//20
fast_interval = n_epochs//1000
loss_history = []
bias_history = []
embedding_snap = []
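# fast_logging: cheap per-interval snapshots of the kernel bias and the node embedding;
# slow_logging: heavier TensorBoard artifacts (a kernel-product image and a 2-D scatter
# of the current embedding with the bias drawn as an arrow).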
def fast_logging(epoch):
    bias_history.append(model.kernel.bias.detach().cpu().numpy().flatten())
    embedding_snap.append(model.x.weight.cpu().detach().numpy())

def slow_logging(epoch):
    with torch.no_grad():
        dot_product_matrix = torch.zeros((len(graph), len(graph)))
        for i in torch.arange(len(graph), dtype=int).reshape(-1, 1):
            dot_product_matrix[i.item(), :] = model.apply_kernel(i.repeat(len(graph)), torch.arange(len(graph))).cpu()
        writer.add_image('kernel_product', dot_product_matrix, epoch, dataformats='HW')
        bias = model.kernel.bias.detach().cpu().numpy().flatten()
        embedding = model.x.weight.cpu().detach().numpy()
        fig = plt.figure()
        plt.scatter(embedding[:, 0], embedding[:, 1], c=[i//M for i in range(len(graph))], cmap='Set1')
        plt.arrow(0, 0, bias[0], bias[1])
        plt.axis('equal')
        writer.add_figure('embedding', fig, epoch)
        plt.close()
def closure(e1, e2):  # for LBFGS
    def f():
        optim.zero_grad()
        loss = -model.triplet_loss(e1, e2).mean()
        loss.backward()
        return loss
    return f
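# Note: torch optimizers take an optional closure in step(). LBFGS requires one
# because it re-evaluates the loss during its line search, while Adam simply calls
# it once, so the same loop below works with either optimizer.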
try:
    for epoch in range(n_epochs):
        sum_loss = 0
        if not(epoch % 100) and epoch:
            print(f'epoch {epoch}/{n_epochs}')
        for e1, e2 in dataloader:
            optim.step(closure=closure(e1, e2))
            # model.center_rescale_embedding()
            with torch.no_grad():
                sum_loss += -model.triplet_loss(e1, e2).mean().cpu().numpy()
        epoch_loss = sum_loss/len(dataloader) + 1
        writer.add_scalar('loss', epoch_loss, epoch)
        if not(epoch % fast_interval) or epoch == n_epochs-1:
            fast_logging(epoch)
        if not(epoch % slow_interval) or epoch == n_epochs-1:
            slow_logging(epoch)
except KeyboardInterrupt:
    print("Final logging. Interrupt again to skip.")
    fast_logging(epoch)
    slow_logging(epoch)
print("Final logging.")
fig = plt.figure()
pointcloud = plt.scatter(embedding_snap[0][:, 0], embedding_snap[0][:, 1], c=[i//12 for i in range(len(graph))])
arrow = plt.arrow(0, 0, 0, 0)
maxs = 1.05*np.max(np.concatenate(embedding_snap, axis=0), axis=0)
mins = 1.05*np.min(np.concatenate(embedding_snap, axis=0), axis=0)
plt.xlim(mins[0], maxs[0])
plt.ylim(mins[1], maxs[1])

def anim(i):
    # keep only the first two embedding dimensions: set_offsets expects an (N, 2)
    # array, and the scatter above only shows dimensions 0 and 1
    pointcloud.set_offsets(embedding_snap[i][:, :2])
    arrow.set_data(dx=bias_history[i][0], dy=bias_history[i][1])
    return pointcloud, arrow

ani = FuncAnimation(fig, anim, len(embedding_snap), interval=50, repeat=True)
ani.save(os.path.join(writer.log_dir, 'latent_space_animation.mp4'))
plt.close()
writer.close()
\ No newline at end of file
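# ---------------------------------------------------------------------------
# Modified file: the model definition (presumably single_space_model.py, judging
# by "from single_space_model import Model" in the training script above).
# ---------------------------------------------------------------------------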
import torch.nn as nn
import torch
from networkx import spectral_layout
device = 'cuda' if torch.cuda.is_available() else 'cpu'
class Model(nn.Module):
-    def __init__(self, graph, dim, kernel, activation):
+    def __init__(self, n_nodes, dim, kernel, activation):
        super().__init__()
-        self.x = nn.Embedding(graph.order(), dim)
+        self.n_nodes = n_nodes
+        self.dim = dim
+        self.x = nn.Embedding(n_nodes, dim)
        self.kernel = kernel
        self.activation = activation
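    # init_embedding: initialize node positions from a networkx spectral_layout of
    # the graph (rescaled to unit standard deviation) instead of the default random
    # embedding initialization.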
    def init_embedding(self, graph, node_mapping=None):
        if node_mapping is None:
            node_mapping = {'id_to_key': {i: i for i in range(graph.order())},
                            'key_to_id': {i: i for i in range(graph.order())}}
        init_embedding = spectral_layout(graph, dim=self.dim)
        init_embedding = [torch.Tensor(init_embedding[node_mapping['id_to_key'][i]]) for i in range(len(graph))]
        init_embedding = torch.stack(init_embedding, dim=0)
        init_embedding = init_embedding/init_embedding.std()
        self.x.weight.data = init_embedding

    def apply_kernel(self, node1, node2):
        vec1, vec2 = self.x(node1.to(device)), self.x(node2.to(device))
        return self.kernel(vec1, vec2)
@@ -25,93 +37,4 @@ class Model(nn.Module):
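    # triplet_loss: given a positive and a negative edge sharing a node, returns
    # k_pos / (k_pos + k_neg), the kernel value of the positive edge normalized over
    # the pair; the training scripts maximize its mean by minimizing the negative.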
    def triplet_loss(self, e_pos, e_neg):
        k_pos = self.apply_kernel(*e_pos)
        k_neg = self.apply_kernel(*e_neg)
-        return k_pos/(k_pos + k_neg)
-if __name__ == '__main__':
-    import os
-    from torch.utils.data import DataLoader
-    from torch.optim import SGD, Adam
-    from synthetic_graphs import *
-    from dataloaders import *
-    from kernels import GaussianKernel
-    import matplotlib.pyplot as plt
-    from matplotlib.animation import FuncAnimation
-    from torch.utils.tensorboard import SummaryWriter
-    N, M = 15, 12
-    graph, _ = synfire_chain(N, M)
-    dataset = EdgePairWithCommonNodeDataset(graph)
-    dataloader = DataLoader(dataset, shuffle=False, batch_size=N*M)
-    model = Model(graph, 3, GaussianKernel(3, bias=True), activation=torch.sigmoid)
-    model.to(device)
-    optim = Adam(model.parameters())
-    n_epochs = 100000
-    writer = SummaryWriter()
-    slow_interval = n_epochs//20
-    fast_interval = n_epochs//1000
-    loss_history = []
-    bias_history = []
-    embedding_snap = []
-    def fast_logging(epoch):
-        bias_history.append(model.kernel.bias.detach().cpu().numpy().flatten())
-        embedding_snap.append(model.x.weight.cpu().detach().numpy())
-    def slow_logging(epoch):
-        with torch.no_grad():
-            dot_product_matrix = torch.zeros((len(graph), len(graph)))
-            for i in torch.arange(len(graph), dtype=int).reshape(-1, 1):
-                dot_product_matrix[i.item(), :] = model.apply_kernel(i.repeat(len(graph)), torch.arange(len(graph))).cpu()
-            writer.add_image('kernel_product', dot_product_matrix, epoch, dataformats='HW')
-            bias = model.kernel.bias.detach().cpu().numpy().flatten()
-            embedding = model.x.weight.cpu().detach().numpy()
-            fig = plt.figure()
-            plt.scatter(embedding[:, 0], embedding[:, 1], c=[i//M for i in range(len(graph))], cmap='Set1')
-            plt.arrow(0, 0, bias[0], bias[1])
-            plt.axis('equal')
-            writer.add_figure('embedding', fig, epoch)
-            plt.close()
-    try:
-        for epoch in range(n_epochs):
-            sum_loss = 0
-            if not(epoch % 100) and epoch:
-                print(f'epoch {epoch}/{n_epochs}')
-            for e1, e2 in dataloader:
-                optim.zero_grad()
-                loss = -model.triplet_loss(e1, e2).mean()
-                loss.backward()
-                optim.step()
-                # model.center_rescale_embedding()
-                sum_loss += loss.cpu().detach().item()
-            epoch_loss = sum_loss/len(dataloader) + 1
-            writer.add_scalar('loss', epoch_loss, epoch)
-            if not(epoch % fast_interval) or epoch == n_epochs-1:
-                fast_logging(epoch)
-            if not(epoch % slow_interval) or epoch == n_epochs-1:
-                slow_logging(epoch)
-    except KeyboardInterrupt:
-        print("Final logging. Interrupt again to skip.")
-        fast_logging(epoch)
-        slow_logging(epoch)
-    print("Final logging.")
-    fig = plt.figure()
-    pointcloud = plt.scatter(embedding_snap[0][:, 0], embedding_snap[0][:, 1], c=[i//12 for i in range(len(graph))])
-    arrow = plt.arrow(0, 0, 0, 0)
-    maxs = 1.05*np.max(np.concatenate(embedding_snap, axis=0), axis=0)
-    mins = 1.05*np.min(np.concatenate(embedding_snap, axis=0), axis=0)
-    plt.xlim(mins[0], maxs[0])
-    plt.ylim(mins[1], maxs[1])
-    def anim(i):
-        pointcloud.set_offsets(embedding_snap[i])
-        arrow.set_data(dx=bias_history[i][0], dy=bias_history[i][1])
-        return pointcloud, arrow
-    ani = FuncAnimation(fig, anim, len(embedding_snap), interval=50, repeat=True)
-    ani.save(os.path.join(writer.log_dir, 'latent_space_animation.mp4'))
-    plt.close()
-    writer.close()
\ No newline at end of file
+        return k_pos/(k_pos + k_neg)
\ No newline at end of file