hippoformer 0.0.14__tar.gz → 0.0.16__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {hippoformer-0.0.14 → hippoformer-0.0.16}/.gitignore +2 -0
- {hippoformer-0.0.14 → hippoformer-0.0.16}/PKG-INFO +1 -1
- {hippoformer-0.0.14 → hippoformer-0.0.16}/hippoformer/hippoformer.py +10 -10
- {hippoformer-0.0.14 → hippoformer-0.0.16}/pyproject.toml +1 -1
- hippoformer-0.0.16/train_memory_maze.py +348 -0
- {hippoformer-0.0.14 → hippoformer-0.0.16}/.github/workflows/python-publish.yml +0 -0
- {hippoformer-0.0.14 → hippoformer-0.0.16}/.github/workflows/test.yml +0 -0
- {hippoformer-0.0.14 → hippoformer-0.0.16}/LICENSE +0 -0
- {hippoformer-0.0.14 → hippoformer-0.0.16}/README.md +0 -0
- {hippoformer-0.0.14 → hippoformer-0.0.16}/hippoformer/__init__.py +0 -0
- {hippoformer-0.0.14 → hippoformer-0.0.16}/hippoformer-fig6.png +0 -0
- {hippoformer-0.0.14 → hippoformer-0.0.16}/tests/test_hippoformer.py +0 -0
{hippoformer-0.0.14 → hippoformer-0.0.16}/hippoformer/hippoformer.py

@@ -285,7 +285,7 @@ class Attention(Module):
         q, k1, v1, k2, v2 = tuple(self.split_heads(t) for t in (q, k1, v1, k2, v2))

         if exists(kv_cache):
-            ck1, cv1,
+            ck1, cv1, ck2, cv2 = kv_cache
             k1 = cat((ck1, k1), dim = -2)
             v1 = cat((cv1, v1), dim = -2)
             k2 = cat((ck2, k2), dim = -2)
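For context on this hunk: the cache carries both attention streams, and all four tensors must be unpacked before the new keys/values are appended along the sequence dimension. A minimal sketch with made-up shapes (batch, heads, seq, dim_head); only the unpack-and-concat pattern is taken from the diff:

import torch
from torch import cat

# hypothetical cached tensors and single-step new projections
ck1, cv1, ck2, cv2 = (torch.randn(1, 8, 32, 16) for _ in range(4))
k1, v1, k2, v2 = (torch.randn(1, 8, 1, 16) for _ in range(4))

kv_cache = (ck1, cv1, ck2, cv2)

ck1, cv1, ck2, cv2 = kv_cache          # unpack all four streams
k1 = cat((ck1, k1), dim = -2)          # append new keys along the seq dim
v1 = cat((cv1, v1), dim = -2)
k2 = cat((ck2, k2), dim = -2)
v2 = cat((cv2, v2), dim = -2)

assert k1.shape[-2] == 33              # 32 cached + 1 new position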
@@ -300,12 +300,12 @@ class Attention(Module):

         i, j = sim.shape[-2:]

-        j_seq = arange(
-        i_seq = arange(
+        j_seq = arange(i, device = device)[:, None]
+        i_seq = arange(j, device = device)[None, :] + (j - i)

         windowed_causal_mask_without_diagonal = (i_seq > j_seq) & ((i_seq - j_seq) <= self.window_size)

-        sim = sim.masked_fill(windowed_causal_mask_without_diagonal, -torch.finfo(sim.dtype).max)
+        sim = sim.masked_fill(~windowed_causal_mask_without_diagonal, -torch.finfo(sim.dtype).max)

         # attention sink, for token as well as for attention sinking - from gpt-oss

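The second change flips the fill to the complement: scores are now masked everywhere *outside* the windowed mask rather than inside it. A generic sliding-window illustration of that masked_fill pattern (the indexing here is the textbook convention, deliberately simpler than hippoformer's, which offsets by j - i for the KV cache and excludes the diagonal):

import torch

i = j = 6
window_size = 2

q_idx = torch.arange(i)[:, None]    # query positions as a column
k_idx = torch.arange(j)[None, :]    # key positions as a row

# keep keys at or before each query, within the window
keep = (k_idx <= q_idx) & ((q_idx - k_idx) <= window_size)

sim = torch.randn(i, j)
sim = sim.masked_fill(~keep, -torch.finfo(sim.dtype).max)   # mask the complement

attn = sim.softmax(dim = -1)        # each row attends to at most window_size + 1 keys
assert int((attn > 0).sum(dim = -1).max()) == window_size + 1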
@@ -365,7 +365,7 @@ class TEMTransformerBlock(Module):

         x = self.ff(x) + x

-        next_kv_cache =
+        next_kv_cache = tuple(t[:, -self.window_size:] for t in next_kv_cache)

         return x, next_kv_cache

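This is the companion fix to the windowed mask: the returned cache is clipped to the trailing window_size positions so it stays bounded during long rollouts. A standalone sketch (cache shapes are made up; the slice is the line from the diff):

import torch

window_size = 64

# hypothetical (batch, seq, dim) cache tensors that have outgrown the window
next_kv_cache = tuple(torch.randn(1, 100, 32) for _ in range(4))

next_kv_cache = tuple(t[:, -window_size:] for t in next_kv_cache)

assert all(t.shape[1] == window_size for t in next_kv_cache)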
@@ -401,7 +401,7 @@ class TEMTransformer(Module):
                 **transformer_kwargs
             )

-            layers.append(block)
+            self.layers.append(block)

     def forward(
         self,
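This one-line fix matters for training: submodules only contribute parameters when registered on the parent module (e.g. in an nn.ModuleList, which self.layers presumably is here); appending to a plain local list leaves the blocks invisible to .parameters() and hence to the optimizer. A minimal demonstration of the difference:

import torch.nn as nn

class Broken(nn.Module):
    def __init__(self):
        super().__init__()
        layers = []                       # local Python list, never registered
        layers.append(nn.Linear(4, 4))

class Fixed(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList()     # registered container
        self.layers.append(nn.Linear(4, 4))

print(sum(p.numel() for p in Broken().parameters()))   # 0
print(sum(p.numel() for p in Fixed().parameters()))    # 20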
@@ -411,7 +411,7 @@ class TEMTransformer(Module):
         prev_kv_cache = None # for the specialized transformer blocks for inducing the grid-cells
     ):

-        structure, next_hiddens = self.
+        structure, next_hiddens = self.path_integrator(actions, prev_hiddens)

         encoded_sensory = self.sensory_encoder(sensory)

@@ -484,7 +484,7 @@ class mmTEM(Module):
         self.to_values = nn.Linear(dim_joint_rep, dim, bias = False)

         self.to_learned_optim_hparams = nn.Linear(dim_joint_rep, 3, bias = False) # for learning rate, forget gate, and momentum
-        self.assoc_scan = AssocScan(
+        self.assoc_scan = AssocScan(**assoc_scan_kwargs)

         self.meta_memory_mlp = create_mlp(
             dim = dim * 2,
@@ -495,7 +495,7 @@ class mmTEM(Module):
         )

         def forward_with_mse_loss(params, keys, values):
-            pred = functional_call(self.meta_memory_mlp, params, keys)
+            pred = functional_call(self.meta_memory_mlp, params, (keys,))
             return F.mse_loss(pred, values)

         grad_fn = grad(forward_with_mse_loss)
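For readers unfamiliar with the torch.func idiom in this hunk: functional_call runs a module under an explicit parameter dict, and grad differentiates the scalar loss with respect to that dict (argument 0), which is how the meta-memory MLP receives its fast-weight gradients. A self-contained sketch, with a plain Sequential standing in for create_mlp:

import torch
import torch.nn.functional as F
from torch.func import functional_call, grad

mlp = torch.nn.Sequential(torch.nn.Linear(8, 16), torch.nn.ReLU(), torch.nn.Linear(16, 8))

def forward_with_mse_loss(params, keys, values):
    pred = functional_call(mlp, params, (keys,))   # positional args passed as a tuple
    return F.mse_loss(pred, values)

grad_fn = grad(forward_with_mse_loss)              # d(loss)/d(params)

params = dict(mlp.named_parameters())
keys, values = torch.randn(4, 8), torch.randn(4, 8)

grads = grad_fn(params, keys, values)              # dict keyed like named_parameters
print({name: tuple(g.shape) for name, g in grads.items()})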
@@ -596,7 +596,7 @@ class mmTEM(Module):

         # 2b. structure from structure

-        decoded_structure, decoded_encoded_sensory = self.retrieve(
+        decoded_structure, decoded_encoded_sensory = self.retrieve(structural_codes, zeros_like(encoded_sensory))

         structure_from_structure_loss = F.mse_loss(decoded_structure, structural_codes)

hippoformer-0.0.16/train_memory_maze.py (new file)

@@ -0,0 +1,348 @@
+# /// script
+# dependencies = [
+#   "torch",
+#   "accelerate",
+#   "einops",
+#   "gym==0.25.2",
+#   "memory-maze",
+#   "dm-control",
+#   "matplotlib",
+#   "numpy<2",
+#   "beartype",
+#   "pillow",
+#   "scipy",
+#   "assoc-scan",
+#   "einx",
+#   "x-mlps-pytorch",
+# ]
+# ///
+
+import os
+os.environ['MUJOCO_GL'] = 'glfw'
+
+from pathlib import Path
+
+import torch
+from torch import nn, Tensor, pi, stack
+import torch.nn.functional as F
+from torch.utils.data import Dataset, DataLoader
+from torch.optim import Adam
+from einops import rearrange
+from accelerate import Accelerator
+
+import gym
+import memory_maze
+
+from hippoformer.hippoformer import mmTEM, maze_sensory_enc_dec
+
+import matplotlib.pyplot as plt
+import numpy as np
+from PIL import Image
+from scipy.signal import correlate2d
+
+# helpers
+
+def exists(v):
+    return v is not None
+
+def default(v, d):
+    return v if exists(v) else d
+
+def divisible_by(num, den):
+    return (num % den) == 0
+
+# MemoryMaze environment wrapper
+
+def find_physics(env):
+    curr = env
+    for _ in range(20): # depth limit
+        if hasattr(curr, '_physics'): return curr._physics
+        if hasattr(curr, 'physics'): return curr.physics
+        if hasattr(curr, 'env'): curr = curr.env
+        elif hasattr(curr, 'unwrapped'): curr = curr.unwrapped
+        else: break
+    return None
+
+class MemoryMazeEnv:
+    def __init__(self, env_name = 'MemoryMaze-9x9-v0'):
+        self.env_name = env_name
+        self.env = gym.make(env_name)
+        self.observation_space = self.env.observation_space
+        self.action_space = self.env.action_space
+        self.physics = None
+
+    def reset(self):
+        obs = self.env.reset()
+        self.physics = find_physics(self.env)
+        return obs
+
+    def step(self, action):
+        return self.env.step(action)
+
+    def get_pos(self):
+        if self.physics is None:
+            self.physics = find_physics(self.env)
+        try:
+            return self.physics.data.qpos[:2].copy()
+        except Exception:
+            return np.array([0., 0.])
+
+    def generate_trajectory(self, steps = 100, skip_obs = False):
+        obs = self.reset()
+        observations, actions, positions = [], [], []
+
+        for _ in range(steps):
+            action = self.action_space.sample()
+
+            if not skip_obs:
+                obs_t = torch.from_numpy(obs.copy()).float()
+                obs_t = rearrange(obs_t, 'h w c -> c h w') / 255.0
+                observations.append(obs_t)
+
+            v_w = torch.zeros(2, dtype = torch.float32)
+            if action == 1: v_w[0] = 0.5 # Move forward
+            elif action == 2: v_w[1] = -0.5 # Rotate right
+            elif action == 3: v_w[1] = 0.5 # Rotate left
+
+            actions.append(v_w)
+            positions.append(torch.from_numpy(self.get_pos()).float())
+
+            step_res = self.step(action)
+            obs, done = step_res[0], step_res[2]
+            if done: obs = self.reset()
+
+        return stack(observations) if not skip_obs else None, stack(actions), stack(positions)
+
+# dataset
+
+class TrajectoryDataset(Dataset):
+    def __init__(self, world, num_trajectories = 32, steps = 100):
+        self.data = [world.generate_trajectory(steps) for _ in range(num_trajectories)]
+
+    def __len__(self):
+        return len(self.data)
+
+    def __getitem__(self, idx):
+        return self.data[idx]
+
+# grid cell visualization
+
+def get_sac(rate_map: Tensor):
+    """Spatial Autocorrelogram (SAC) using Torch"""
+    # rate_map: (res, res)
+    mask = ~rate_map.isnan()
+    if not mask.any():
+        return torch.zeros_like(rate_map)
+
+    m = rate_map.clone()
+    mean = rate_map[mask].mean()
+    m[mask] -= mean
+    m[~mask] = 0.
+
+    # 2D correlation via conv2d
+    # correlate2d(m, m, mode='full')
+    h, w = m.shape
+    m_batch = rearrange(m, 'h w -> 1 1 h w')
+
+    sac = F.conv2d(
+        F.pad(m_batch, (w - 1, w - 1, h - 1, h - 1)),
+        m_batch
+    )
+    return rearrange(sac, '1 1 h w -> h w')
+
+def gaussian_blur_2d(img: Tensor, sigma: float = 1.0):
+    """2D Gaussian Blur in Torch"""
+    # img: (c, h, w)
+    ksize = int(2 * 3 * sigma + 1)
+    if ksize % 2 == 0: ksize += 1
+
+    x = torch.linspace(-3 * sigma, 3 * sigma, ksize)
+    pdf = torch.exp(-0.5 * (x / sigma).pow(2))
+    kernel1d = pdf / pdf.sum()
+    kernel2d = kernel1d[:, None] * kernel1d[None, :]
+
+    c = img.shape[0]
+    kernel2d = rearrange(kernel2d, 'h w -> 1 1 h w').to(img.device)
+    kernel2d = kernel2d.expand(c, 1, -1, -1)
+
+    padding = ksize // 2
+    # reflect pad to avoid edge artifacts
+    img_padded = F.pad(rearrange(img, 'c h w -> 1 c h w'), (padding, padding, padding, padding), mode = 'reflect')
+    blurred = F.conv2d(img_padded, kernel2d, groups = c)
+    return rearrange(blurred, '1 c h w -> c h w')
+
+class GridCellVisualizer:
+    def __init__(
+        self,
+        world: MemoryMazeEnv,
+        resolution: int = 40,
+        spatial_range: tuple[float, float] = (-5.0, 5.0)
+    ):
+        self.world = world
+        self.resolution = resolution
+        self.spatial_range = spatial_range
+
+    @torch.no_grad()
+    def get_rate_maps(self, model: nn.Module, steps: int = 5000):
+        model.eval()
+        device = next(model.parameters()).device
+
+        # Probing trajectory (skip observations for speed)
+        _, actions, positions = self.world.generate_trajectory(steps = steps, skip_obs = True)
+
+        actions = actions.to(device)
+        positions = positions.to(device)
+
+        actions_in = rearrange(actions, 't d -> 1 t d')
+        structural_codes = model.path_integrator(actions_in)
+        structural_codes = rearrange(structural_codes, '1 t d -> t d')
+
+        # Vectorized binning in Torch
+        res = self.resolution
+        p_min, p_max = self.spatial_range
+
+        # Map positions to [0, resolution - 1]
+        indices = ((positions - p_min) / (p_max - p_min + 1e-5) * (res - 1)).long()
+        indices = torch.clamp(indices, 0, res - 1)
+
+        num_cells = structural_codes.shape[-1]
+        activations = torch.zeros((num_cells, res, res), device = device)
+        counts = torch.zeros((res, res), device = device)
+
+        # Flat indices for index_add_
+        flat_indices = indices[:, 0] * res + indices[:, 1]
+
+        activations_flat = rearrange(activations, 'd h w -> d (h w)')
+        activations_flat.index_add_(1, flat_indices, structural_codes.T)
+
+        counts_flat = counts.view(-1)
+        counts_flat.index_add_(0, flat_indices, torch.ones_like(flat_indices, dtype = torch.float32))
+
+        # Occupancy normalization
+        rate_maps = activations / rearrange(counts.clamp(min = 1), 'h w -> 1 h w')
+        mask = counts < 1
+
+        # Fill NaNs before smoothing
+        has_visits = (~mask).any()
+        if has_visits:
+            # For each cell, fill unvisited with its own mean
+            # rate_maps: (c, h, w)
+            # mask: (h, w)
+            for i in range(num_cells):
+                rmap = rate_maps[i]
+                rmap[mask] = rmap[~mask].mean()
+
+        # Smoothing
+        rate_maps = gaussian_blur_2d(rate_maps, sigma = 1.0)
+
+        # Normalize to [0, 1] per cell
+        rm_min = rearrange(rate_maps.amin(dim = (1, 2)), 'c -> c 1 1')
+        rm_max = rearrange(rate_maps.amax(dim = (1, 2)), 'c -> c 1 1')
+
+        rate_maps = (rate_maps - rm_min) / (rm_max - rm_min).clamp(min = 1e-5)
+
+        # Restore NaNs for visualization transparency
+        rate_maps[:, mask] = float('nan')
+
+        return rate_maps
+
+    def visualize(
+        self,
+        model: nn.Module,
+        epoch: int,
+        path_to_save: str | Path,
+        probing_steps: int = 5000
+    ):
+        path_to_save = Path(path_to_save)
+        rate_maps = self.get_rate_maps(model, steps = probing_steps)
+        rate_maps_cpu = rate_maps.cpu()
+
+        # Sort by spatial variance to find high-information cells
+        # variance handling NaNs
+        variances = torch.from_numpy(np.nanvar(rate_maps_cpu.numpy(), axis = (1, 2)))
+        top_indices = torch.argsort(variances, descending = True)[:8]
+
+        fig, axes = plt.subplots(4, 4, figsize = (14, 14), facecolor = 'white')
+
+        cmap_rate = plt.get_cmap('rainbow').copy()
+        cmap_rate.set_bad('white')
+
+        for i, idx in enumerate(top_indices):
+            # Rate Map
+            ax_rate = axes[i // 2, (i % 2) * 2]
+            rate_map = rate_maps_cpu[idx]
+
+            ax_rate.imshow(rate_map.numpy(), cmap = cmap_rate, interpolation = 'nearest', origin = 'lower')
+            ax_rate.axis('off')
+            ax_rate.set_title(f'Rate Map {idx}')
+
+            # Spatial Autocorrelogram
+            ax_sac = axes[i // 2, (i % 2) * 2 + 1]
+            sac = get_sac(rate_map)
+
+            ax_sac.imshow(sac.numpy(), cmap = 'jet', interpolation = 'gaussian', origin = 'lower')
+            ax_sac.axis('off')
+            ax_sac.set_title(f'SAC {idx}')
+
+        plt.tight_layout()
+        plt.suptitle(f'Grid Cell Discovery (Epoch {epoch})', fontsize = 18)
+
+        plt.savefig(path_to_save)
+        plt.close()
+
+# main simulation
+
+def run_simulation():
+    accelerator = Accelerator()
+    accelerator.print(f"Using device: {accelerator.device}")
+
+    world = MemoryMazeEnv('MemoryMaze-9x9-v0')
+    visualizer = GridCellVisualizer(world)
+
+    model = mmTEM(
+        dim = 32,
+        sensory_encoder_decoder = maze_sensory_enc_dec,
+        dim_sensory = (3, 64, 64),
+        dim_action = 2,
+        dim_encoded_sensory = 32,
+        dim_structure = 64
+    )
+
+    optimizer = Adam(model.parameters(), lr = 1e-3)
+
+    accelerator.print("Generating training dataset (scale: 64x100)...")
+    dataset = TrajectoryDataset(world, num_trajectories = 64, steps = 100)
+    loader = DataLoader(dataset, batch_size = 16, shuffle = True)
+
+    model, optimizer, loader = accelerator.prepare(model, optimizer, loader)
+
+    results_folder = Path('results')
+    results_folder.mkdir(parents = True, exist_ok = True)
+
+    accelerator.print("Starting extended training on MemoryMaze3D...")
+    for epoch in range(1, 16):
+        model.train()
+        total_loss = 0
+        for obs, actions, _ in loader:
+            obs = rearrange(obs, 'b t c h w -> b c t h w')
+            loss = model(obs, actions)
+            optimizer.zero_grad()
+            accelerator.backward(loss)
+            optimizer.step()
+            total_loss += loss.item()
+
+        accelerator.print(f"Epoch {epoch}, Loss: {total_loss / len(loader):.4f}")
+
+        if divisible_by(epoch, 5):
+            visualizer.visualize(accelerator.unwrap_model(model), epoch, path_to_save = results_folder / f'grid_cells_epoch_{epoch}.png', probing_steps = 5000)
+            accelerator.print(f"Grid cell visualization (epoch {epoch}) saved.")
+
+    visualizer.visualize(accelerator.unwrap_model(model), 15, path_to_save = results_folder / 'grid_cells_final.png', probing_steps = 10000)
+
+    obs, _, _ = world.generate_trajectory(steps = 1)
+    sample_img = rearrange(obs[0], 'c h w -> h w c').numpy()
+    Image.fromarray((sample_img * 255).astype(np.uint8)).save(results_folder / 'sample_view.png')
+    accelerator.print("Extended simulation complete. Results saved to 'results/' folder.")

+if __name__ == "__main__":
+    run_simulation()
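One piece of the new script worth unpacking is the vectorized occupancy binning in get_rate_maps: positions are quantized to a grid, flattened to bin ids, and index_add_ scatter-sums both activations and visit counts in a single pass. A tiny self-contained version, with random tensors standing in for the model's structural codes:

import torch

res, num_cells, T = 4, 3, 100                  # grid resolution, cells, timesteps

positions = torch.rand(T, 2) * 10 - 5          # fake (x, y) positions in [-5, 5)
codes = torch.rand(T, num_cells)               # fake per-step cell activations

p_min, p_max = -5.0, 5.0
indices = ((positions - p_min) / (p_max - p_min + 1e-5) * (res - 1)).long().clamp(0, res - 1)
flat = indices[:, 0] * res + indices[:, 1]     # (row, col) -> flat bin id

activations = torch.zeros(num_cells, res * res)
activations.index_add_(1, flat, codes.T)       # sum activations into bins

counts = torch.zeros(res * res)
counts.index_add_(0, flat, torch.ones(T))      # visits per bin

# occupancy-normalized rate maps, one (res, res) map per cell
rate_maps = (activations / counts.clamp(min = 1)).view(num_cells, res, res)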