hippoformer 0.0.14__tar.gz → 0.0.16__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,3 +1,5 @@
+results/
+
 # Byte-compiled / optimized / DLL files
 __pycache__/
 *.py[codz]
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: hippoformer
-Version: 0.0.14
+Version: 0.0.16
 Summary: hippoformer
 Project-URL: Homepage, https://pypi.org/project/hippoformer/
 Project-URL: Repository, https://github.com/lucidrains/hippoformer
@@ -285,7 +285,7 @@ class Attention(Module):
         q, k1, v1, k2, v2 = tuple(self.split_heads(t) for t in (q, k1, v1, k2, v2))

         if exists(kv_cache):
-            ck1, cv1, vk2, cv2 = kv_cache
+            ck1, cv1, ck2, cv2 = kv_cache
             k1 = cat((ck1, k1), dim = -2)
             v1 = cat((cv1, v1), dim = -2)
             k2 = cat((ck2, k2), dim = -2)
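In the hunk above, the previous unpacking bound the second cached key tensor to `vk2`, leaving `ck2` undefined, so `cat((ck2, k2), dim = -2)` would have raised a `NameError` as soon as a cache was passed in.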
@@ -300,12 +300,12 @@ class Attention(Module):

         i, j = sim.shape[-2:]

-        j_seq = arange(j, device = device)[:, None]
-        i_seq = arange(i, device = device)[None, :] + (j - i)
+        j_seq = arange(i, device = device)[:, None]
+        i_seq = arange(j, device = device)[None, :] + (j - i)

         windowed_causal_mask_without_diagonal = (i_seq > j_seq) & ((i_seq - j_seq) <= self.window_size)

-        sim = sim.masked_fill(windowed_causal_mask_without_diagonal, -torch.finfo(sim.dtype).max)
+        sim = sim.masked_fill(~windowed_causal_mask_without_diagonal, -torch.finfo(sim.dtype).max)

         # attention sink, for token as well as for attention sinking - from gpt-oss

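Two things change in this windowed-attention hunk. Swapping the `arange` lengths gives the mask the same `(i, j)` shape as `sim` (the earlier version built a `(j, i)` grid, which only broadcasts when `i == j`), and `masked_fill` now receives the negated mask: `Tensor.masked_fill` writes into positions where the mask is `True`, so a mask that marks the entries to *keep* has to be inverted before filling. A minimal standalone illustration of the second point, with a hypothetical keep-mask:

```python
import torch

sim = torch.randn(4, 4)

# hypothetical keep-mask: True marks entries that should survive attention
keep = torch.tril(torch.ones(4, 4, dtype = torch.bool))

# masked_fill fills where the mask is True, so the keep-mask is inverted
out = sim.masked_fill(~keep, -torch.finfo(sim.dtype).max)
```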
@@ -365,7 +365,7 @@ class TEMTransformerBlock(Module):

         x = self.ff(x) + x

-        next_kv_cache = next_kv_cache[:, -self.window_size:]
+        next_kv_cache = tuple(t[:, -self.window_size:] for t in next_kv_cache)

         return x, next_kv_cache

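Here `next_kv_cache` is evidently a tuple of cache tensors rather than a single tensor, so the old `next_kv_cache[:, -self.window_size:]` raises `TypeError: tuple indices must be integers or slices, not tuple`. The fix maps the trim over each tensor. A sketch, with the cache shapes assumed:

```python
import torch

window_size = 3

# four cache tensors (ck1, cv1, ck2, cv2); the (batch, seq, dim) layout is assumed
cache = tuple(torch.randn(2, 8, 16) for _ in range(4))

# cache[:, -window_size:] would fail on the tuple itself;
# slice each tensor to keep only the last window of positions
trimmed = tuple(t[:, -window_size:] for t in cache)
```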
@@ -401,7 +401,7 @@ class TEMTransformer(Module):
                **transformer_kwargs
            )

-            layers.append(block)
+            self.layers.append(block)

    def forward(
        self,
@@ -411,7 +411,7 @@ class TEMTransformer(Module):
        prev_kv_cache = None # for the specialized transformer blocks for inducing the grid-cells
    ):

-        structure, next_hiddens = self.gru_path_integrator(actions, prev_hiddens)
+        structure, next_hiddens = self.path_integrator(actions, prev_hiddens)

        encoded_sensory = self.sensory_encoder(sensory)

@@ -484,7 +484,7 @@ class mmTEM(Module):
        self.to_values = nn.Linear(dim_joint_rep, dim, bias = False)

        self.to_learned_optim_hparams = nn.Linear(dim_joint_rep, 3, bias = False) # for learning rate, forget gate, and momentum
-        self.assoc_scan = AssocScan(*assoc_scan_kwargs)
+        self.assoc_scan = AssocScan(**assoc_scan_kwargs)

        self.meta_memory_mlp = create_mlp(
            dim = dim * 2,
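Single-star unpacking iterates the dict and passes its *keys* as positional strings; double-star unpacking is what forwards them as keyword arguments. A toy demonstration (the stub and its parameter are hypothetical; AssocScan's real signature is not part of this diff):

```python
def assoc_scan_stub(use_accelerated = False):
    return use_accelerated

kwargs = dict(use_accelerated = True)

assoc_scan_stub(*kwargs)   # passes the string 'use_accelerated' positionally
assoc_scan_stub(**kwargs)  # passes use_accelerated = True, as intended
```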
@@ -495,7 +495,7 @@ class mmTEM(Module):
        )

        def forward_with_mse_loss(params, keys, values):
-            pred = functional_call(self.meta_memory_mlp, params, keys)
+            pred = functional_call(self.meta_memory_mlp, params, (keys,))
            return F.mse_loss(pred, values)

        grad_fn = grad(forward_with_mse_loss)
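`torch.func.functional_call` takes the module's positional inputs through its `args` parameter; passing `(keys,)` makes the single-input case explicit and stays correct if more positional inputs are added later. The surrounding pattern, a `functional_call` inside a loss function that `torch.func.grad` differentiates with respect to a parameter dict, can be sketched standalone (the `Linear` stand-in for the meta-memory MLP is hypothetical):

```python
import torch
import torch.nn.functional as F
from torch.func import functional_call, grad

mlp = torch.nn.Linear(4, 4)  # stand-in for the meta-memory MLP
params = dict(mlp.named_parameters())

def forward_with_mse_loss(params, keys, values):
    pred = functional_call(mlp, params, (keys,))  # positional args as a tuple
    return F.mse_loss(pred, values)

# grad differentiates with respect to the first argument, the parameter dict
grads = grad(forward_with_mse_loss)(params, torch.randn(2, 4), torch.randn(2, 4))
```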
@@ -596,7 +596,7 @@ class mmTEM(Module):

        # 2b. structure from structure

-        decoded_structure, decoded_encoded_sensory = self.retrieve(zeros_like(structural_codes), encoded_sensory)
+        decoded_structure, decoded_encoded_sensory = self.retrieve(structural_codes, zeros_like(encoded_sensory))

        structure_from_structure_loss = F.mse_loss(decoded_structure, structural_codes)

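Given the `# 2b.` comment and the loss on the next line, the intent here appears to be cueing the memory with the structural codes while zeroing the sensory half of the query; the earlier version had the two halves reversed, so the loss was scoring sensory-cued retrieval rather than structure-from-structure.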
@@ -1,6 +1,6 @@
 [project]
 name = "hippoformer"
-version = "0.0.14"
+version = "0.0.16"
 description = "hippoformer"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -0,0 +1,348 @@
+# /// script
+# dependencies = [
+#   "torch",
+#   "accelerate",
+#   "einops",
+#   "gym==0.25.2",
+#   "memory-maze",
+#   "dm-control",
+#   "matplotlib",
+#   "numpy<2",
+#   "beartype",
+#   "pillow",
+#   "scipy",
+#   "assoc-scan",
+#   "einx",
+#   "x-mlps-pytorch",
+# ]
+# ///
+
+import os
+os.environ['MUJOCO_GL'] = 'glfw'
+
+from pathlib import Path
+
+import torch
+from torch import nn, Tensor, pi, stack
+import torch.nn.functional as F
+from torch.utils.data import Dataset, DataLoader
+from torch.optim import Adam
+from einops import rearrange
+from accelerate import Accelerator
+
+import gym
+import memory_maze
+
+from hippoformer.hippoformer import mmTEM, maze_sensory_enc_dec
+
+import matplotlib.pyplot as plt
+import numpy as np
+from PIL import Image
+from scipy.signal import correlate2d
+
+# helpers
+
+def exists(v):
+    return v is not None
+
+def default(v, d):
+    return v if exists(v) else d
+
+def divisible_by(num, den):
+    return (num % den) == 0
+
+# MemoryMaze environment wrapper
+
+def find_physics(env):
+    curr = env
+    for _ in range(20): # depth limit
+        if hasattr(curr, '_physics'): return curr._physics
+        if hasattr(curr, 'physics'): return curr.physics
+        if hasattr(curr, 'env'): curr = curr.env
+        elif hasattr(curr, 'unwrapped'): curr = curr.unwrapped
+        else: break
+    return None
+
+class MemoryMazeEnv:
+    def __init__(self, env_name = 'MemoryMaze-9x9-v0'):
+        self.env_name = env_name
+        self.env = gym.make(env_name)
+        self.observation_space = self.env.observation_space
+        self.action_space = self.env.action_space
+        self.physics = None
+
+    def reset(self):
+        obs = self.env.reset()
+        self.physics = find_physics(self.env)
+        return obs
+
+    def step(self, action):
+        return self.env.step(action)
+
+    def get_pos(self):
+        if self.physics is None:
+            self.physics = find_physics(self.env)
+        try:
+            return self.physics.data.qpos[:2].copy()
+        except Exception:
+            return np.array([0., 0.])
+
+    def generate_trajectory(self, steps = 100, skip_obs = False):
+        obs = self.reset()
+        observations, actions, positions = [], [], []
+
+        for _ in range(steps):
+            action = self.action_space.sample()
+
+            if not skip_obs:
+                obs_t = torch.from_numpy(obs.copy()).float()
+                obs_t = rearrange(obs_t, 'h w c -> c h w') / 255.0
+                observations.append(obs_t)
+
+            v_w = torch.zeros(2, dtype = torch.float32)
+            if action == 1: v_w[0] = 0.5 # Move forward
+            elif action == 2: v_w[1] = -0.5 # Rotate right
+            elif action == 3: v_w[1] = 0.5 # Rotate left
+
+            actions.append(v_w)
+            positions.append(torch.from_numpy(self.get_pos()).float())
+
+            step_res = self.step(action)
+            obs, done = step_res[0], step_res[2]
+            if done: obs = self.reset()
+
+        return stack(observations) if not skip_obs else None, stack(actions), stack(positions)
+
+# dataset
+
+class TrajectoryDataset(Dataset):
+    def __init__(self, world, num_trajectories = 32, steps = 100):
+        self.data = [world.generate_trajectory(steps) for _ in range(num_trajectories)]
+
+    def __len__(self):
+        return len(self.data)
+
+    def __getitem__(self, idx):
+        return self.data[idx]
+
+# grid cell visualization
+
+def get_sac(rate_map: Tensor):
+    """Spatial Autocorrelogram (SAC) using Torch"""
+    # rate_map: (res, res)
+    mask = ~rate_map.isnan()
+    if not mask.any():
+        return torch.zeros_like(rate_map)
+
+    m = rate_map.clone()
+    mean = rate_map[mask].mean()
+    m[mask] -= mean
+    m[~mask] = 0.
+
+    # 2D correlation via conv2d
+    # correlate2d(m, m, mode='full')
+    h, w = m.shape
+    m_batch = rearrange(m, 'h w -> 1 1 h w')
+
+    sac = F.conv2d(
+        F.pad(m_batch, (w - 1, w - 1, h - 1, h - 1)),
+        m_batch
+    )
+    return rearrange(sac, '1 1 h w -> h w')
+
+def gaussian_blur_2d(img: Tensor, sigma: float = 1.0):
+    """2D Gaussian Blur in Torch"""
+    # img: (c, h, w)
+    ksize = int(2 * 3 * sigma + 1)
+    if ksize % 2 == 0: ksize += 1
+
+    x = torch.linspace(-3 * sigma, 3 * sigma, ksize)
+    pdf = torch.exp(-0.5 * (x / sigma).pow(2))
+    kernel1d = pdf / pdf.sum()
+    kernel2d = kernel1d[:, None] * kernel1d[None, :]
+
+    c = img.shape[0]
+    kernel2d = rearrange(kernel2d, 'h w -> 1 1 h w').to(img.device)
+    kernel2d = kernel2d.expand(c, 1, -1, -1)
+
+    padding = ksize // 2
+    # reflect pad to avoid edge artifacts
+    img_padded = F.pad(rearrange(img, 'c h w -> 1 c h w'), (padding, padding, padding, padding), mode = 'reflect')
+    blurred = F.conv2d(img_padded, kernel2d, groups = c)
+    return rearrange(blurred, '1 c h w -> c h w')
+
+class GridCellVisualizer:
+    def __init__(
+        self,
+        world: MemoryMazeEnv,
+        resolution: int = 40,
+        spatial_range: tuple[float, float] = (-5.0, 5.0)
+    ):
+        self.world = world
+        self.resolution = resolution
+        self.spatial_range = spatial_range
+
+    @torch.no_grad()
+    def get_rate_maps(self, model: nn.Module, steps: int = 5000):
+        model.eval()
+        device = next(model.parameters()).device
+
+        # Probing trajectory (skip observations for speed)
+        _, actions, positions = self.world.generate_trajectory(steps = steps, skip_obs = True)
+
+        actions = actions.to(device)
+        positions = positions.to(device)
+
+        actions_in = rearrange(actions, 't d -> 1 t d')
+        structural_codes = model.path_integrator(actions_in)
+        structural_codes = rearrange(structural_codes, '1 t d -> t d')
+
+        # Vectorized binning in Torch
+        res = self.resolution
+        p_min, p_max = self.spatial_range
+
+        # Map positions to [0, resolution - 1]
+        indices = ((positions - p_min) / (p_max - p_min + 1e-5) * (res - 1)).long()
+        indices = torch.clamp(indices, 0, res - 1)
+
+        num_cells = structural_codes.shape[-1]
+        activations = torch.zeros((num_cells, res, res), device = device)
+        counts = torch.zeros((res, res), device = device)
+
+        # Flat indices for index_add_
+        flat_indices = indices[:, 0] * res + indices[:, 1]
+
+        activations_flat = rearrange(activations, 'd h w -> d (h w)')
+        activations_flat.index_add_(1, flat_indices, structural_codes.T)
+
+        counts_flat = counts.view(-1)
+        counts_flat.index_add_(0, flat_indices, torch.ones_like(flat_indices, dtype = torch.float32))
+
+        # Occupancy normalization
+        rate_maps = activations / rearrange(counts.clamp(min = 1), 'h w -> 1 h w')
+        mask = counts < 1
+
+        # Fill NaNs before smoothing
+        has_visits = (~mask).any()
+        if has_visits:
+            # For each cell, fill unvisited with its own mean
+            # rate_maps: (c, h, w)
+            # mask: (h, w)
+            for i in range(num_cells):
+                rmap = rate_maps[i]
+                rmap[mask] = rmap[~mask].mean()
+
+        # Smoothing
+        rate_maps = gaussian_blur_2d(rate_maps, sigma = 1.0)
+
+        # Normalize to [0, 1] per cell
+        rm_min = rearrange(rate_maps.amin(dim = (1, 2)), 'c -> c 1 1')
+        rm_max = rearrange(rate_maps.amax(dim = (1, 2)), 'c -> c 1 1')
+
+        rate_maps = (rate_maps - rm_min) / (rm_max - rm_min).clamp(min = 1e-5)
+
+        # Restore NaNs for visualization transparency
+        rate_maps[:, mask] = float('nan')
+
+        return rate_maps
+
+    def visualize(
+        self,
+        model: nn.Module,
+        epoch: int,
+        path_to_save: str | Path,
+        probing_steps: int = 5000
+    ):
+        path_to_save = Path(path_to_save)
+        rate_maps = self.get_rate_maps(model, steps = probing_steps)
+        rate_maps_cpu = rate_maps.cpu()
+
+        # Sort by spatial variance to find high-information cells
+        # variance handling NaNs
+        variances = torch.from_numpy(np.nanvar(rate_maps_cpu.numpy(), axis = (1, 2)))
+        top_indices = torch.argsort(variances, descending = True)[:8]
+
+        fig, axes = plt.subplots(4, 4, figsize = (14, 14), facecolor = 'white')
+
+        cmap_rate = plt.get_cmap('rainbow').copy()
+        cmap_rate.set_bad('white')
+
+        for i, idx in enumerate(top_indices):
+            # Rate Map
+            ax_rate = axes[i // 2, (i % 2) * 2]
+            rate_map = rate_maps_cpu[idx]
+
+            ax_rate.imshow(rate_map.numpy(), cmap = cmap_rate, interpolation = 'nearest', origin = 'lower')
+            ax_rate.axis('off')
+            ax_rate.set_title(f'Rate Map {idx}')
+
+            # Spatial Autocorrelogram
+            ax_sac = axes[i // 2, (i % 2) * 2 + 1]
+            sac = get_sac(rate_map)
+
+            ax_sac.imshow(sac.numpy(), cmap = 'jet', interpolation = 'gaussian', origin = 'lower')
+            ax_sac.axis('off')
+            ax_sac.set_title(f'SAC {idx}')
+
+        plt.tight_layout()
+        plt.suptitle(f'Grid Cell Discovery (Epoch {epoch})', fontsize = 18)
+
+        plt.savefig(path_to_save)
+        plt.close()
+
+# main simulation
+
+def run_simulation():
+    accelerator = Accelerator()
+    accelerator.print(f"Using device: {accelerator.device}")
+
+    world = MemoryMazeEnv('MemoryMaze-9x9-v0')
+    visualizer = GridCellVisualizer(world)
+
+    model = mmTEM(
+        dim = 32,
+        sensory_encoder_decoder = maze_sensory_enc_dec,
+        dim_sensory = (3, 64, 64),
+        dim_action = 2,
+        dim_encoded_sensory = 32,
+        dim_structure = 64
+    )
+
+    optimizer = Adam(model.parameters(), lr = 1e-3)
+
+    accelerator.print("Generating training dataset (scale: 64x100)...")
+    dataset = TrajectoryDataset(world, num_trajectories = 64, steps = 100)
+    loader = DataLoader(dataset, batch_size = 16, shuffle = True)
+
+    model, optimizer, loader = accelerator.prepare(model, optimizer, loader)
+
+    results_folder = Path('results')
+    results_folder.mkdir(parents = True, exist_ok = True)
+
+    accelerator.print("Starting extended training on MemoryMaze3D...")
+    for epoch in range(1, 16):
+        model.train()
+        total_loss = 0
+        for obs, actions, _ in loader:
+            obs = rearrange(obs, 'b t c h w -> b c t h w')
+            loss = model(obs, actions)
+            optimizer.zero_grad()
+            accelerator.backward(loss)
+            optimizer.step()
+            total_loss += loss.item()
+
+        accelerator.print(f"Epoch {epoch}, Loss: {total_loss / len(loader):.4f}")
+
+        if divisible_by(epoch, 5):
+            visualizer.visualize(accelerator.unwrap_model(model), epoch, path_to_save = results_folder / f'grid_cells_epoch_{epoch}.png', probing_steps = 5000)
+            accelerator.print(f"Grid cell visualization (epoch {epoch}) saved.")

+    visualizer.visualize(accelerator.unwrap_model(model), 15, path_to_save = results_folder / 'grid_cells_final.png', probing_steps = 10000)
+
+    obs, _, _ = world.generate_trajectory(steps = 1)
+    sample_img = rearrange(obs[0], 'c h w -> h w c').numpy()
+    Image.fromarray((sample_img * 255).astype(np.uint8)).save(results_folder / 'sample_view.png')
+    accelerator.print("Extended simulation complete. Results saved to 'results/' folder.")
+
+if __name__ == "__main__":
+    run_simulation()
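The file added in this final hunk is a self-contained demo script: it opens with a PEP 723 inline-metadata block (`# /// script … # ///`), so a compatible runner such as `uv run` can resolve the listed dependencies and execute it directly, without a separately managed environment.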