hippoformer 0.0.15.tar.gz → 0.0.17.tar.gz

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
@@ -1,3 +1,5 @@
+ results/
+
  # Byte-compiled / optimized / DLL files
  __pycache__/
  *.py[codz]
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: hippoformer
- Version: 0.0.15
+ Version: 0.0.17
  Summary: hippoformer
  Project-URL: Homepage, https://pypi.org/project/hippoformer/
  Project-URL: Repository, https://github.com/lucidrains/hippoformer
@@ -285,7 +285,7 @@ class Attention(Module):
          q, k1, v1, k2, v2 = tuple(self.split_heads(t) for t in (q, k1, v1, k2, v2))

          if exists(kv_cache):
-             ck1, cv1, vk2, cv2 = kv_cache
+             ck1, cv1, ck2, cv2 = kv_cache
              k1 = cat((ck1, k1), dim = -2)
              v1 = cat((cv1, v1), dim = -2)
              k2 = cat((ck2, k2), dim = -2)
@@ -300,12 +300,12 @@ class Attention(Module):

          i, j = sim.shape[-2:]

-         j_seq = arange(j, device = device)[:, None]
-         i_seq = arange(i, device = device)[None, :] + (j - i)
+         i_seq = arange(i, device = device)[:, None] + (j - i)
+         j_seq = arange(j, device = device)[None, :]

          windowed_causal_mask_without_diagonal = (i_seq > j_seq) & ((i_seq - j_seq) <= self.window_size)

-         sim = sim.masked_fill(windowed_causal_mask_without_diagonal, -torch.finfo(sim.dtype).max)
+         sim = sim.masked_fill(~windowed_causal_mask_without_diagonal, -torch.finfo(sim.dtype).max)

          # attention sink, for token as well as for attention sinking - from gpt-oss

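The fix above builds absolute positions for the queries (offset by the cached prefix length j - i) and keys, keeps only keys strictly in the past and within the window, and masks the complement. A minimal standalone sketch of the corrected logic (toy sizes chosen here for illustration, not from the package):

import torch
from torch import arange

i, j, window_size = 4, 6, 2  # 4 new queries over 6 keys, 2 of which are cached

i_seq = arange(i)[:, None] + (j - i)  # absolute position of each query row
j_seq = arange(j)[None, :]            # absolute position of each key column

# True where the key is strictly in the past and within the window
mask = (i_seq > j_seq) & ((i_seq - j_seq) <= window_size)
print(mask.int())
# tensor([[1, 1, 0, 0, 0, 0],
#         [0, 1, 1, 0, 0, 0],
#         [0, 0, 1, 1, 0, 0],
#         [0, 0, 0, 1, 1, 0]])

Because the mask marks the positions to keep, masked_fill must be applied to its complement (~mask), which is the second fix; the strict inequality also excludes the diagonal, matching the _without_diagonal name.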
@@ -401,24 +401,28 @@ class TEMTransformer(Module):
                  **transformer_kwargs
              )

-             layers.append(block)
+             self.layers.append(block)

      def forward(
          self,
          sensory,
          actions,
-         prev_hiddens = None, # for the GRU based path integrator
-         prev_kv_cache = None # for the specialized transformer blocks for inducing the grid-cells
+         prev_hiddens = None, # for the GRU based path integrator
+         prev_kv_cache = None, # for the specialized transformer blocks for inducing the grid-cells
+         return_memories = False
      ):

-         structure, next_hiddens = self.gru_path_integrator(actions, prev_hiddens)
+         structure, next_hiddens = self.path_integrator(actions, prev_hiddens)

          encoded_sensory = self.sensory_encoder(sensory)

+         prev_kv_cache = default(prev_kv_cache, (None,) * len(self.layers))
+         iter_prev_kv_cache = iter(prev_kv_cache)
+
          next_kv_cache = []

          for layer in self.layers:
-             structure, layer_next_cache = layer(structure, encoded_sensory)
+             structure, layer_next_cache = layer(structure, encoded_sensory, kv_cache = next(iter_prev_kv_cache, None))
              next_kv_cache.append(layer_next_cache)

          decoded_sensory = self.sensory_decoder(structure)
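The new forward threads one kv cache per transformer block, defaulting to an all-None tuple on the first step. A minimal sketch of the pattern, with a hypothetical layer signature (not the package's code):

def step(layers, x, prev_kv_cache = None):
    # None-filled on the first call, one cache entry per layer afterwards
    prev_kv_cache = prev_kv_cache if prev_kv_cache is not None else (None,) * len(layers)
    iter_cache = iter(prev_kv_cache)

    next_kv_cache = []

    for layer in layers:
        x, layer_cache = layer(x, kv_cache = next(iter_cache, None))
        next_kv_cache.append(layer_cache)

    return x, next_kv_cache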
@@ -427,7 +431,10 @@ class TEMTransformer(Module):

          pred_loss = F.mse_loss(encoded_sensory, decoded_sensory)

-         return pred_loss
+         if not return_memories:
+             return pred_loss
+
+         return pred_loss, next_memories

  # proposed mmTEM

@@ -484,7 +491,7 @@ class mmTEM(Module):
          self.to_values = nn.Linear(dim_joint_rep, dim, bias = False)

          self.to_learned_optim_hparams = nn.Linear(dim_joint_rep, 3, bias = False) # for learning rate, forget gate, and momentum
-         self.assoc_scan = AssocScan(*assoc_scan_kwargs)
+         self.assoc_scan = AssocScan(**assoc_scan_kwargs)

          self.meta_memory_mlp = create_mlp(
              dim = dim * 2,
@@ -495,7 +502,7 @@ class mmTEM(Module):
          )

          def forward_with_mse_loss(params, keys, values):
-             pred = functional_call(self.meta_memory_mlp, params, keys)
+             pred = functional_call(self.meta_memory_mlp, params, (keys,))
              return F.mse_loss(pred, values)

          grad_fn = grad(forward_with_mse_loss)
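For context, torch.func.functional_call runs a module with an explicit parameter dict, and torch.func.grad then differentiates the loss with respect to that dict; positional inputs are conventionally passed as a tuple of args, which the (keys,) change makes explicit. A self-contained sketch with a toy module (not the package's meta_memory_mlp):

import torch
import torch.nn.functional as F
from torch import nn
from torch.func import functional_call, grad

mlp = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 4))
params = dict(mlp.named_parameters())

def forward_with_mse_loss(params, keys, values):
    pred = functional_call(mlp, params, (keys,))  # args passed as a tuple
    return F.mse_loss(pred, values)

grad_fn = grad(forward_with_mse_loss)  # gradients w.r.t. the first argument (params)

keys, values = torch.randn(16, 4), torch.randn(16, 4)
param_grads = grad_fn(params, keys, values)  # dict keyed like params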
@@ -596,7 +603,7 @@ class mmTEM(Module):

          # 2b. structure from structure

-         decoded_structure, decoded_encoded_sensory = self.retrieve(zeros_like(structural_codes), encoded_sensory)
+         decoded_structure, decoded_encoded_sensory = self.retrieve(structural_codes, zeros_like(encoded_sensory))

          structure_from_structure_loss = F.mse_loss(decoded_structure, structural_codes)

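The corrected call probes the memory with the structural code alone by zeroing the sensory half of the query, rather than the structural half. Assuming retrieve joins its two inputs into one joint query (an assumption - its body is not shown in this diff), the intent is roughly:

import torch
from torch import cat, zeros_like

structural_codes = torch.randn(2, 16, 64)  # toy shapes
encoded_sensory = torch.randn(2, 16, 32)

# keep the structural code, blank the sensory slot
query = cat((structural_codes, zeros_like(encoded_sensory)), dim = -1)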
@@ -1,6 +1,6 @@
  [project]
  name = "hippoformer"
- version = "0.0.15"
+ version = "0.0.17"
  description = "hippoformer"
  authors = [
      { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -0,0 +1,348 @@
+ # /// script
+ # dependencies = [
+ #     "torch",
+ #     "accelerate",
+ #     "einops",
+ #     "gym==0.25.2",
+ #     "memory-maze",
+ #     "dm-control",
+ #     "matplotlib",
+ #     "numpy<2",
+ #     "beartype",
+ #     "pillow",
+ #     "scipy",
+ #     "assoc-scan",
+ #     "einx",
+ #     "x-mlps-pytorch",
+ # ]
+ # ///
+
+ import os
+ os.environ['MUJOCO_GL'] = 'glfw'
+
+ from pathlib import Path
+
+ import torch
+ from torch import nn, Tensor, pi, stack
+ import torch.nn.functional as F
+ from torch.utils.data import Dataset, DataLoader
+ from torch.optim import Adam
+ from einops import rearrange
+ from accelerate import Accelerator
+
+ import gym
+ import memory_maze
+
+ from hippoformer.hippoformer import mmTEM, maze_sensory_enc_dec
+
+ import matplotlib.pyplot as plt
+ import numpy as np
+ from PIL import Image
+ from scipy.signal import correlate2d
+
+ # helpers
+
+ def exists(v):
+     return v is not None
+
+ def default(v, d):
+     return v if exists(v) else d
+
+ def divisible_by(num, den):
+     return (num % den) == 0
+
+ # MemoryMaze environment wrapper
+
+ def find_physics(env):
+     curr = env
+     for _ in range(20): # depth limit
+         if hasattr(curr, '_physics'): return curr._physics
+         if hasattr(curr, 'physics'): return curr.physics
+         if hasattr(curr, 'env'): curr = curr.env
+         elif hasattr(curr, 'unwrapped'): curr = curr.unwrapped
+         else: break
+     return None
+
+ class MemoryMazeEnv:
+     def __init__(self, env_name = 'MemoryMaze-9x9-v0'):
+         self.env_name = env_name
+         self.env = gym.make(env_name)
+         self.observation_space = self.env.observation_space
+         self.action_space = self.env.action_space
+         self.physics = None
+
+     def reset(self):
+         obs = self.env.reset()
+         self.physics = find_physics(self.env)
+         return obs
+
+     def step(self, action):
+         return self.env.step(action)
+
+     def get_pos(self):
+         if self.physics is None:
+             self.physics = find_physics(self.env)
+         try:
+             return self.physics.data.qpos[:2].copy()
+         except Exception:
+             return np.array([0., 0.])
+
+     def generate_trajectory(self, steps = 100, skip_obs = False):
+         obs = self.reset()
+         observations, actions, positions = [], [], []
+
+         for _ in range(steps):
+             action = self.action_space.sample()
+
+             if not skip_obs:
+                 obs_t = torch.from_numpy(obs.copy()).float()
+                 obs_t = rearrange(obs_t, 'h w c -> c h w') / 255.0
+                 observations.append(obs_t)
+
+             v_w = torch.zeros(2, dtype = torch.float32)
+             if action == 1: v_w[0] = 0.5 # Move forward
+             elif action == 2: v_w[1] = -0.5 # Rotate right
+             elif action == 3: v_w[1] = 0.5 # Rotate left
+
+             actions.append(v_w)
+             positions.append(torch.from_numpy(self.get_pos()).float())
+
+             step_res = self.step(action)
+             obs, done = step_res[0], step_res[2]
+             if done: obs = self.reset()
+
+         return stack(observations) if not skip_obs else None, stack(actions), stack(positions)
+
+ # dataset
+
+ class TrajectoryDataset(Dataset):
+     def __init__(self, world, num_trajectories = 32, steps = 100):
+         self.data = [world.generate_trajectory(steps) for _ in range(num_trajectories)]
+
+     def __len__(self):
+         return len(self.data)
+
+     def __getitem__(self, idx):
+         return self.data[idx]
+
+ # grid cell visualization
+
+ def get_sac(rate_map: Tensor):
+     """Spatial Autocorrelogram (SAC) using Torch"""
+     # rate_map: (res, res)
+     mask = ~rate_map.isnan()
+     if not mask.any():
+         return torch.zeros_like(rate_map)
+
+     m = rate_map.clone()
+     mean = rate_map[mask].mean()
+     m[mask] -= mean
+     m[~mask] = 0.
+
+     # 2D correlation via conv2d
+     # correlate2d(m, m, mode='full')
+     h, w = m.shape
+     m_batch = rearrange(m, 'h w -> 1 1 h w')
+
+     sac = F.conv2d(
+         F.pad(m_batch, (w - 1, w - 1, h - 1, h - 1)),
+         m_batch
+     )
+     return rearrange(sac, '1 1 h w -> h w')
+
+ def gaussian_blur_2d(img: Tensor, sigma: float = 1.0):
+     """2D Gaussian Blur in Torch"""
+     # img: (c, h, w)
+     ksize = int(2 * 3 * sigma + 1)
+     if ksize % 2 == 0: ksize += 1
+
+     x = torch.linspace(-3 * sigma, 3 * sigma, ksize)
+     pdf = torch.exp(-0.5 * (x / sigma).pow(2))
+     kernel1d = pdf / pdf.sum()
+     kernel2d = kernel1d[:, None] * kernel1d[None, :]
+
+     c = img.shape[0]
+     kernel2d = rearrange(kernel2d, 'h w -> 1 1 h w').to(img.device)
+     kernel2d = kernel2d.expand(c, 1, -1, -1)
+
+     padding = ksize // 2
+     # reflect pad to avoid edge artifacts
+     img_padded = F.pad(rearrange(img, 'c h w -> 1 c h w'), (padding, padding, padding, padding), mode = 'reflect')
+     blurred = F.conv2d(img_padded, kernel2d, groups = c)
+     return rearrange(blurred, '1 c h w -> c h w')
+
+ class GridCellVisualizer:
+     def __init__(
+         self,
+         world: MemoryMazeEnv,
+         resolution: int = 40,
+         spatial_range: tuple[float, float] = (-5.0, 5.0)
+     ):
+         self.world = world
+         self.resolution = resolution
+         self.spatial_range = spatial_range
+
+     @torch.no_grad()
+     def get_rate_maps(self, model: nn.Module, steps: int = 5000):
+         model.eval()
+         device = next(model.parameters()).device
+
+         # Probing trajectory (skip observations for speed)
+         _, actions, positions = self.world.generate_trajectory(steps = steps, skip_obs = True)
+
+         actions = actions.to(device)
+         positions = positions.to(device)
+
+         actions_in = rearrange(actions, 't d -> 1 t d')
+         structural_codes = model.path_integrator(actions_in)
+         structural_codes = rearrange(structural_codes, '1 t d -> t d')
+
+         # Vectorized binning in Torch
+         res = self.resolution
+         p_min, p_max = self.spatial_range
+
+         # Map positions to [0, resolution - 1]
+         indices = ((positions - p_min) / (p_max - p_min + 1e-5) * (res - 1)).long()
+         indices = torch.clamp(indices, 0, res - 1)
+
+         num_cells = structural_codes.shape[-1]
+         activations = torch.zeros((num_cells, res, res), device = device)
+         counts = torch.zeros((res, res), device = device)
+
+         # Flat indices for index_add_
+         flat_indices = indices[:, 0] * res + indices[:, 1]
+
+         activations_flat = rearrange(activations, 'd h w -> d (h w)')
+         activations_flat.index_add_(1, flat_indices, structural_codes.T)
+
+         counts_flat = counts.view(-1)
+         counts_flat.index_add_(0, flat_indices, torch.ones_like(flat_indices, dtype = torch.float32))
+
+         # Occupancy normalization
+         rate_maps = activations / rearrange(counts.clamp(min = 1), 'h w -> 1 h w')
+         mask = counts < 1
+
+         # Fill NaNs before smoothing
+         has_visits = (~mask).any()
+         if has_visits:
+             # For each cell, fill unvisited with its own mean
+             # rate_maps: (c, h, w)
+             # mask: (h, w)
+             for i in range(num_cells):
+                 rmap = rate_maps[i]
+                 rmap[mask] = rmap[~mask].mean()
+
+         # Smoothing
+         rate_maps = gaussian_blur_2d(rate_maps, sigma = 1.0)
+
+         # Normalize to [0, 1] per cell
+         rm_min = rearrange(rate_maps.amin(dim = (1, 2)), 'c -> c 1 1')
+         rm_max = rearrange(rate_maps.amax(dim = (1, 2)), 'c -> c 1 1')
+
+         rate_maps = (rate_maps - rm_min) / (rm_max - rm_min).clamp(min = 1e-5)
+
+         # Restore NaNs for visualization transparency
+         rate_maps[:, mask] = float('nan')
+
+         return rate_maps
+
+     def visualize(
+         self,
+         model: nn.Module,
+         epoch: int,
+         path_to_save: str | Path,
+         probing_steps: int = 5000
+     ):
+         path_to_save = Path(path_to_save)
+         rate_maps = self.get_rate_maps(model, steps = probing_steps)
+         rate_maps_cpu = rate_maps.cpu()
+
+         # Sort by spatial variance to find high-information cells
+         # variance handling NaNs
+         variances = torch.from_numpy(np.nanvar(rate_maps_cpu.numpy(), axis = (1, 2)))
+         top_indices = torch.argsort(variances, descending = True)[:8]
+
+         fig, axes = plt.subplots(4, 4, figsize = (14, 14), facecolor = 'white')
+
+         cmap_rate = plt.get_cmap('rainbow').copy()
+         cmap_rate.set_bad('white')
+
+         for i, idx in enumerate(top_indices):
+             # Rate Map
+             ax_rate = axes[i // 2, (i % 2) * 2]
+             rate_map = rate_maps_cpu[idx]
+
+             ax_rate.imshow(rate_map.numpy(), cmap = cmap_rate, interpolation = 'nearest', origin = 'lower')
+             ax_rate.axis('off')
+             ax_rate.set_title(f'Rate Map {idx}')
+
+             # Spatial Autocorrelogram
+             ax_sac = axes[i // 2, (i % 2) * 2 + 1]
+             sac = get_sac(rate_map)
+
+             ax_sac.imshow(sac.numpy(), cmap = 'jet', interpolation = 'gaussian', origin = 'lower')
+             ax_sac.axis('off')
+             ax_sac.set_title(f'SAC {idx}')
+
+         plt.tight_layout()
+         plt.suptitle(f'Grid Cell Discovery (Epoch {epoch})', fontsize = 18)
+
+         plt.savefig(path_to_save)
+         plt.close()
+
+ # main simulation
+
+ def run_simulation():
+     accelerator = Accelerator()
+     accelerator.print(f"Using device: {accelerator.device}")
+
+     world = MemoryMazeEnv('MemoryMaze-9x9-v0')
+     visualizer = GridCellVisualizer(world)
+
+     model = mmTEM(
+         dim = 32,
+         sensory_encoder_decoder = maze_sensory_enc_dec,
+         dim_sensory = (3, 64, 64),
+         dim_action = 2,
+         dim_encoded_sensory = 32,
+         dim_structure = 64
+     )
+
+     optimizer = Adam(model.parameters(), lr = 1e-3)
+
+     accelerator.print("Generating training dataset (scale: 64x100)...")
+     dataset = TrajectoryDataset(world, num_trajectories = 64, steps = 100)
+     loader = DataLoader(dataset, batch_size = 16, shuffle = True)
+
+     model, optimizer, loader = accelerator.prepare(model, optimizer, loader)
+
+     results_folder = Path('results')
+     results_folder.mkdir(parents = True, exist_ok = True)
+
+     accelerator.print("Starting extended training on MemoryMaze3D...")
+     for epoch in range(1, 16):
+         model.train()
+         total_loss = 0
+         for obs, actions, _ in loader:
+             obs = rearrange(obs, 'b t c h w -> b c t h w')
+             loss = model(obs, actions)
+             optimizer.zero_grad()
+             accelerator.backward(loss)
+             optimizer.step()
+             total_loss += loss.item()
+
+         accelerator.print(f"Epoch {epoch}, Loss: {total_loss / len(loader):.4f}")
+
+         if divisible_by(epoch, 5):
+             visualizer.visualize(accelerator.unwrap_model(model), epoch, path_to_save = results_folder / f'grid_cells_epoch_{epoch}.png', probing_steps = 5000)
+             accelerator.print(f"Grid cell visualization (epoch {epoch}) saved.")
+
+     visualizer.visualize(accelerator.unwrap_model(model), 15, path_to_save = results_folder / 'grid_cells_final.png', probing_steps = 10000)
+
+     obs, _, _ = world.generate_trajectory(steps = 1)
+     sample_img = rearrange(obs[0], 'c h w -> h w c').numpy()
+     Image.fromarray((sample_img * 255).astype(np.uint8)).save(results_folder / 'sample_view.png')
+     accelerator.print("Extended simulation complete. Results saved to 'results/' folder.")
+
+ if __name__ == "__main__":
+     run_simulation()
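The comment header at the top of this new file is PEP 723 inline script metadata, so a PEP 723-aware runner can execute it directly - e.g. uv run <script>.py (the new file's name is not shown in this diff) - resolving the listed dependencies into an ephemeral environment first.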