diff --git a/src/dataloaders.py b/src/dataloaders.py
index 3fd6fa5e07a216fa3fa835fdb8d4bcd6c2b5bfa4..60adba8c09f932506679163923529a55b68ca8b7 100644
--- a/src/dataloaders.py
+++ b/src/dataloaders.py
@@ -62,6 +62,41 @@ class EdgePairWithCommonNodeDataset(Dataset):
 
         return positive_edge, negative_edge
     
+class EdgePairWithCommonNodeDatasetFull(Dataset):
+    def __init__(self, graph, node_mapping=None):
+        self.graph = graph
+        if node_mapping is None:
+            node_mapping = {'id_to_key':{i:i for i in range(graph.order())},'key_to_id':{i:i for i in range(graph.order())}}
+        self.node_mapping = node_mapping
+        self.node_list = list(self.graph.nodes())
+        import itertools
+        self.indices = []
+        for u, v, w in itertools.permutations(self.graph.nodes(), 3):
+            e1 = (u, v)
+            e2 = (u, w)
+            w1 = self.graph.get_edge_data(*e1)['weight'] if self.graph.has_edge(*e1) else 0
+            w2 = self.graph.get_edge_data(*e2)['weight'] if self.graph.has_edge(*e2) else 0
+            if w1 != w2:
+                self.indices.append((u,v,w))
+        print(len(self))
+
+    def __len__(self):
+        return len(self.indices)
+    
+    def __getitem__(self, i):
+        u, v, w = self.indices[i]
+        e1 = (u, v)
+        e2 = (u, w)
+        w1 = self.graph.get_edge_data(*e1)['weight'] if self.graph.has_edge(*e1) else 0
+        w2 = self.graph.get_edge_data(*e2)['weight'] if self.graph.has_edge(*e2) else 0
+        
+        # recover the corresponding edges in the directed graph
+        positive_edge, negative_edge = (e1, e2) if w1 > w2 else (e2, e1)
+        positive_edge = (self.node_mapping['key_to_id'][positive_edge[0]], self.node_mapping['key_to_id'][positive_edge[1]])
+        negative_edge = (self.node_mapping['key_to_id'][negative_edge[0]], self.node_mapping['key_to_id'][negative_edge[1]])
+
+        return positive_edge, negative_edge
+    
 class EdgePairDataset(IterableDataset):
     def __init__(self, graph, node_mapping=None):
         self.graph = graph
@@ -125,16 +160,15 @@ class EdgeDataset(Dataset):
         anchor_key, target_key = self.node_mapping['id_to_key'][i], self.node_mapping['id_to_key'][j]
         return (i,j), (self.graph.get_edge_data(anchor_key, target_key)['weight'] if self.graph.has_edge(anchor_key,target_key) else 0)
         
-if __name__ == '__main__':
-    from torch.utils.data import DataLoader
-    from synthetic_graphs import *
-    from dataloaders import EdgeTripletDataset
+# if __name__ == '__main__':
+#     from torch.utils.data import DataLoader
+#     from synthetic_graphs import *
 
-    N, M = 15, 12
-    graph, _ = synfire_chain(N, M)
-    dataset = EdgeTripletDataset(graph)
-    dataloader = DataLoader(dataset, shuffle=False, batch_size=N*M)
+#     N, M = 15, 12
+#     graph, _ = synfire_chain(N, M)
+#     dataset = EdgeTripletDataset(graph)
+#     dataloader = DataLoader(dataset, shuffle=False, batch_size=N*M)
 
-    for _ in dataloader:
-        pass
+#     for _ in dataloader:
+#         pass
         
diff --git a/src/experiment1.py b/src/experiment1.py
new file mode 100644
index 0000000000000000000000000000000000000000..8649abe8470816873b6ba19fdb3ea3b1ddff8947
--- /dev/null
+++ b/src/experiment1.py
@@ -0,0 +1,99 @@
+import os
+from torch.utils.data import DataLoader
+import torch
+from torch.optim import SGD, Adam, LBFGS
+from synthetic_graphs import *
+from dataloaders import *
+from single_space_model import Model
+from kernels import GaussianKernel
+import matplotlib.pyplot as plt
+from matplotlib.animation import FuncAnimation
+from torch.utils.tensorboard import SummaryWriter
+
+device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+N, M = 15, 12
+graph, _ = synfire_chain(N, M)
+dataset = EdgePairWithCommonNodeDataset(graph)
+dataloader = DataLoader(dataset, shuffle=False, batch_size=len(dataset))
+
+model = Model(graph.order(), 3, GaussianKernel(3, bias=True), activation=torch.sigmoid)
+model.init_embedding(graph)
+model.to(device)
+optim = Adam(model.parameters())
+# optim = LBFGS(model.parameters(), lr=1, line_search_fn='strong_wolfe')
+n_epochs = 100000
+
+writer = SummaryWriter(comment='EXP001')
+slow_interval = n_epochs//20
+fast_interval = n_epochs//1000
+loss_history = []
+bias_history = []
+embedding_snap = []
+def fast_logging(epoch):
+    bias_history.append(model.kernel.bias.detach().cpu().numpy().flatten())
+    embedding_snap.append(model.x.weight.cpu().detach().numpy())
+def slow_logging(epoch):
+    with torch.no_grad():
+        dot_product_matrix = torch.zeros((len(graph), len(graph)))
+        for i in torch.arange(len(graph), dtype=int).reshape(-1,1):
+            dot_product_matrix[i.item(),:] = model.apply_kernel(i.repeat(len(graph)), torch.arange(len(graph))).cpu()
+
+    writer.add_image('kernel_product', dot_product_matrix, epoch, dataformats='HW')
+    bias = model.kernel.bias.detach().cpu().numpy().flatten()
+    embedding = model.x.weight.cpu().detach().numpy()
+    fig = plt.figure()
+    plt.scatter(embedding[:,0], embedding[:,1], c=[i//M for i in range(len(graph))], cmap='Set1')
+    plt.arrow(0,0,bias[0], bias[1])
+    plt.axis('equal')
+    writer.add_figure('embedding', fig, epoch)
+    plt.close()
+
+def closure(e1, e2): # for LBFGS
+    def f():
+        optim.zero_grad()
+        loss = -model.triplet_loss(e1, e2).mean()
+        loss.backward()
+        return loss
+    return f
+
+try:
+    for epoch in range(n_epochs):
+        sum_loss = 0
+        if not(epoch%100) and epoch:
+            print(f'epoch {epoch}/{n_epochs}')
+        for e1, e2 in dataloader:
+            optim.step(closure=closure(e1,e2))
+            # model.center_rescale_embedding()
+            with torch.no_grad():
+                sum_loss += -model.triplet_loss(e1, e2).mean().cpu().numpy()
+        epoch_loss = sum_loss/len(dataloader)+1
+        writer.add_scalar('loss', epoch_loss, epoch)
+
+        if not(epoch%fast_interval) or epoch == n_epochs-1:
+            fast_logging(epoch)
+        if not(epoch%slow_interval) or epoch == n_epochs-1:
+            slow_logging(epoch)
+except KeyboardInterrupt:
+    print("Final logging. Interrupt again to skip.")
+    fast_logging(epoch)
+    slow_logging(epoch)
+
+print("Final logging.")
+fig = plt.figure()
+pointcloud = plt.scatter(embedding_snap[0][:,0], embedding_snap[0][:,1], c=[i//12 for i in range(len(graph))])
+arrow = plt.arrow(0, 0, 0, 0)
+maxs = 1.05*np.max(np.concatenate(embedding_snap, axis=0), axis=0)
+mins = 1.05*np.min(np.concatenate(embedding_snap, axis=0), axis=0)
+plt.xlim(mins[0], maxs[0])
+plt.ylim(mins[1], maxs[1])
+
+def anim(i):
+    pointcloud.set_offsets(embedding_snap[i])
+    arrow.set_data(dx=bias_history[i][0], dy=bias_history[i][1])
+    return pointcloud, arrow
+
+ani = FuncAnimation(fig, anim, len(embedding_snap), interval=50, repeat=True)
+ani.save(os.path.join(writer.log_dir, 'latent_space_animation.mp4'))
+plt.close()
+writer.close()
\ No newline at end of file
diff --git a/src/single_space_model.py b/src/single_space_model.py
index 49c102c1a92fd28b433ba0196b2f74a1d27781b9..96a10e7f94aa2048a07ef416decb41faacc7ec5d 100644
--- a/src/single_space_model.py
+++ b/src/single_space_model.py
@@ -1,15 +1,27 @@
 import torch.nn as nn
 import torch
+from networkx import spectral_layout
 
 device = 'cuda' if torch.cuda.is_available() else 'cpu'
 
 class Model(nn.Module):
-    def __init__(self, graph, dim, kernel, activation):
+    def __init__(self, n_nodes, dim, kernel, activation):
         super().__init__()
-        self.x = nn.Embedding(graph.order(), dim)
+        self.n_nodes = n_nodes
+        self.dim = dim
+        self.x = nn.Embedding(n_nodes, dim)
         self.kernel= kernel
         self.activation = activation
 
+    def init_embedding(self, graph, node_mapping = None):
+        if node_mapping is None:
+            node_mapping = {'id_to_key':{i:i for i in range(graph.order())},'key_to_id':{i:i for i in range(graph.order())}}
+        init_embedding = spectral_layout(graph, dim=self.dim)
+        init_embedding = [torch.Tensor(init_embedding[node_mapping['id_to_key'][i]]) for i in range(len(graph))]
+        init_embedding = torch.stack(init_embedding, dim=0)
+        init_embedding = init_embedding/init_embedding.std()
+        self.x.weight.data = init_embedding
+
     def apply_kernel(self, node1, node2):
         vec1, vec2 = self.x(node1.to(device)), self.x(node2.to(device))
         return self.kernel(vec1, vec2)
@@ -25,93 +37,4 @@ class Model(nn.Module):
     def triplet_loss(self, e_pos, e_neg):
         k_pos = self.apply_kernel(*e_pos)
         k_neg = self.apply_kernel(*e_neg)
-        return k_pos/(k_pos + k_neg)
-    
-if __name__ == '__main__':
-    import os
-    from torch.utils.data import DataLoader
-    from torch.optim import SGD, Adam
-    from synthetic_graphs import *
-    from dataloaders import *
-    from kernels import GaussianKernel
-    import matplotlib.pyplot as plt
-    from matplotlib.animation import FuncAnimation
-    from torch.utils.tensorboard import SummaryWriter
-
-    N, M = 15, 12
-    graph, _ = synfire_chain(N, M)
-    dataset = EdgePairWithCommonNodeDataset(graph)
-    dataloader = DataLoader(dataset, shuffle=False, batch_size=N*M)
-
-    model = Model(graph, 3, GaussianKernel(3, bias=True), activation=torch.sigmoid)
-    model.to(device)
-    optim = Adam(model.parameters())
-    n_epochs = 100000
-
-    writer = SummaryWriter()
-    slow_interval = n_epochs//20
-    fast_interval = n_epochs//1000
-    loss_history = []
-    bias_history = []
-    embedding_snap = []
-    def fast_logging(epoch):
-        bias_history.append(model.kernel.bias.detach().cpu().numpy().flatten())
-        embedding_snap.append(model.x.weight.cpu().detach().numpy())
-    def slow_logging(epoch):
-        with torch.no_grad():
-            dot_product_matrix = torch.zeros((len(graph), len(graph)))
-            for i in torch.arange(len(graph), dtype=int).reshape(-1,1):
-                dot_product_matrix[i.item(),:] = model.apply_kernel(i.repeat(len(graph)), torch.arange(len(graph))).cpu()
-        writer.add_image('kernel_product', dot_product_matrix, epoch, dataformats='HW')
-
-        bias = model.kernel.bias.detach().cpu().numpy().flatten()
-        embedding = model.x.weight.cpu().detach().numpy()
-        fig = plt.figure()
-        plt.scatter(embedding[:,0], embedding[:,1], c=[i//M for i in range(len(graph))], cmap='Set1')
-        plt.arrow(0,0,bias[0], bias[1])
-        plt.axis('equal')
-        writer.add_figure('embedding', fig, epoch)
-        plt.close()
-
-    try:
-        for epoch in range(n_epochs):
-            sum_loss = 0
-            if not(epoch%100) and epoch:
-                print(f'epoch {epoch}/{n_epochs}')
-            for e1, e2 in dataloader:
-                optim.zero_grad()
-                loss = -model.triplet_loss(e1, e2).mean()
-                loss.backward()
-                optim.step()
-                # model.center_rescale_embedding()
-                sum_loss += loss.cpu().detach().item()
-            epoch_loss = sum_loss/len(dataloader)+1
-            writer.add_scalar('loss', epoch_loss, epoch)
-
-            if not(epoch%fast_interval) or epoch == n_epochs-1:
-                fast_logging(epoch)
-            if not(epoch%slow_interval) or epoch == n_epochs-1:
-                slow_logging(epoch)
-    except KeyboardInterrupt:
-        print("Final logging. Interrupt again to skip.")
-        fast_logging(epoch)
-        slow_logging(epoch)
-
-    print("Final logging.")
-    fig = plt.figure()
-    pointcloud = plt.scatter(embedding_snap[0][:,0], embedding_snap[0][:,1], c=[i//12 for i in range(len(graph))])
-    arrow = plt.arrow(0, 0, 0, 0)
-    maxs = 1.05*np.max(np.concatenate(embedding_snap, axis=0), axis=0)
-    mins = 1.05*np.min(np.concatenate(embedding_snap, axis=0), axis=0)
-    plt.xlim(mins[0], maxs[0])
-    plt.ylim(mins[1], maxs[1])
-
-    def anim(i):
-        pointcloud.set_offsets(embedding_snap[i])
-        arrow.set_data(dx=bias_history[i][0], dy=bias_history[i][1])
-        return pointcloud, arrow
-    
-    ani = FuncAnimation(fig, anim, len(embedding_snap), interval=50, repeat=True)
-    ani.save(os.path.join(writer.log_dir, 'latent_space_animation.mp4'))
-    plt.close()
-    writer.close()
\ No newline at end of file
+        return k_pos/(k_pos + k_neg)
\ No newline at end of file