npcpy 1.3.4__py3-none-any.whl → 1.3.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- npcpy/build_funcs.py +288 -0
- npcpy/data/load.py +1 -1
- npcpy/data/web.py +5 -4
- npcpy/gen/image_gen.py +2 -1
- npcpy/gen/response.py +119 -66
- npcpy/gen/world_gen.py +609 -0
- npcpy/llm_funcs.py +177 -271
- npcpy/memory/command_history.py +107 -2
- npcpy/memory/knowledge_graph.py +1 -1
- npcpy/npc_compiler.py +176 -32
- npcpy/npc_sysenv.py +5 -5
- npcpy/serve.py +311 -2
- npcpy/sql/npcsql.py +272 -59
- npcpy/work/browser.py +30 -0
- {npcpy-1.3.4.dist-info → npcpy-1.3.6.dist-info}/METADATA +1 -1
- {npcpy-1.3.4.dist-info → npcpy-1.3.6.dist-info}/RECORD +19 -16
- {npcpy-1.3.4.dist-info → npcpy-1.3.6.dist-info}/WHEEL +0 -0
- {npcpy-1.3.4.dist-info → npcpy-1.3.6.dist-info}/licenses/LICENSE +0 -0
- {npcpy-1.3.4.dist-info → npcpy-1.3.6.dist-info}/top_level.txt +0 -0
npcpy/gen/world_gen.py
ADDED
|
@@ -0,0 +1,609 @@
|
|
|
1
|
+
"""
World model generation: predict the next state(s) given the current state and an action.

Intended to wrap local world-simulation models such as:
- DIAMOND (diffusion world model for Atari) - implemented
- GameNGen (neural game engine) - placeholder; currently raises NotImplementedError
- Dreamer/DreamerV3 (latent world models) - placeholder; currently raises NotImplementedError

These are interactive models: they are conditioned on actions and keep the
simulated world consistent across steps, unlike video generation, which only
produces a video from a prompt.
"""
|
|
12
|
+
|
|
13
|
+
from typing import List, Optional, Union, Dict, Any
|
|
14
|
+
import numpy as np
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def world_step(
    frames: List[np.ndarray],
    action: Optional[Union[int, np.ndarray]] = None,
    model_path: str = None,
    model_type: str = "diamond",
    num_steps: int = 1,
    device: str = "cuda",
    **kwargs
) -> Dict[str, Any]:
    """
    Advance a world model by ``num_steps`` frames from a frame history.

    Thin dispatcher: every argument (including extra keyword arguments) is
    forwarded unchanged to the backend selected by ``model_type``.

    Args:
        frames: Recent frames as numpy arrays, (H, W, C) or (C, H, W).
        action: Conditioning action - an int for discrete action spaces, an
            array for continuous ones. ``None`` is passed through to the backend.
        model_path: Path to a local model checkpoint.
        model_type: Backend name: "diamond", "gamengen" or "dreamer".
        num_steps: How many future frames to predict.
        device: "cuda" or "cpu".
        **kwargs: Backend-specific options.

    Returns:
        Dict with "frames" (predicted frames), "latent" (optional latent state
        for continuing the simulation) and "metadata" (backend-specific info).

    Raises:
        ValueError: If ``model_type`` is not one of the supported names.
    """
    if model_type not in ("diamond", "gamengen", "dreamer"):
        raise ValueError(f"Unknown model_type: {model_type}. Supported: diamond, gamengen, dreamer")

    backend_args = (frames, action, model_path, num_steps, device)

    if model_type == "diamond":
        return _step_diamond(*backend_args, **kwargs)
    if model_type == "gamengen":
        return _step_gamengen(*backend_args, **kwargs)
    return _step_dreamer(*backend_args, **kwargs)
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def world_rollout(
    initial_frames: List[np.ndarray],
    actions: List[Union[int, np.ndarray]],
    model_path: str = None,
    model_type: str = "diamond",
    device: str = "cuda",
    **kwargs
) -> Dict[str, Any]:
    """
    Roll a world model forward one step per action, feeding predictions back in.

    After each step the oldest context frames are dropped and the newly
    predicted ones appended. With one predicted frame per step, the context
    window therefore keeps the length of a non-empty ``initial_frames``.

    NOTE: each step goes through ``world_step``, so the backend may reload
    its model on every action.

    Args:
        initial_frames: Starting frame history (any sequence of arrays; it is
            copied, never mutated).
        actions: Actions to execute in order; one ``world_step`` call each.
        model_path: Path to local model checkpoint.
        model_type: One of "diamond", "gamengen", "dreamer".
        device: "cuda" or "cpu".
        **kwargs: Forwarded to ``world_step`` on every step.

    Returns:
        Dict with:
            - "frames": initial frames followed by every predicted frame
            - "latents": latent states from each step, or None if the
              backend returned none
    """
    all_frames = list(initial_frames)
    latents = []
    # Always keep the window as a list: the slide below concatenates with
    # `+`, which raises for tuples and adds element-wise for ndarrays.
    current_frames = list(initial_frames)

    for action in actions:
        result = world_step(
            current_frames,
            action=action,
            model_path=model_path,
            model_type=model_type,
            num_steps=1,
            device=device,
            **kwargs
        )

        # Normalise to a list so the concatenation below is list + list.
        predicted = list(result["frames"])
        all_frames.extend(predicted)

        if result.get("latent") is not None:
            latents.append(result["latent"])

        # Slide window for next step
        current_frames = current_frames[len(predicted):] + predicted

    return {
        "frames": all_frames,
        "latents": latents if latents else None
    }
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
# ============== Model-specific implementations ==============
|
|
109
|
+
|
|
110
|
+
def _step_diamond(
|
|
111
|
+
frames: List[np.ndarray],
|
|
112
|
+
action: Optional[int],
|
|
113
|
+
model_path: str,
|
|
114
|
+
num_steps: int,
|
|
115
|
+
device: str,
|
|
116
|
+
num_actions: int = 18, # Atari default
|
|
117
|
+
denoising_steps: int = 3, # DIAMOND default
|
|
118
|
+
**kwargs
|
|
119
|
+
) -> Dict[str, Any]:
|
|
120
|
+
"""
|
|
121
|
+
DIAMOND: Diffusion for World Modeling
|
|
122
|
+
https://github.com/eloialonso/diamond
|
|
123
|
+
|
|
124
|
+
Trained on Atari games, uses diffusion to predict next frames.
|
|
125
|
+
|
|
126
|
+
Setup:
|
|
127
|
+
git clone https://github.com/eloialonso/diamond.git
|
|
128
|
+
cd diamond
|
|
129
|
+
pip install -r requirements.txt
|
|
130
|
+
|
|
131
|
+
Pretrained checkpoints downloaded via:
|
|
132
|
+
python src/play.py --pretrained
|
|
133
|
+
|
|
134
|
+
Args:
|
|
135
|
+
frames: List of 4 recent frames as (H, W, C) uint8 arrays (64x64 for Atari)
|
|
136
|
+
action: Discrete action index (0 to num_actions-1)
|
|
137
|
+
model_path: Path to checkpoint .pt file
|
|
138
|
+
num_steps: Number of frames to predict (rolls out autoregressively)
|
|
139
|
+
device: "cuda" or "cpu"
|
|
140
|
+
num_actions: Action space size (18 for Atari)
|
|
141
|
+
denoising_steps: Diffusion denoising steps (3 is DIAMOND default)
|
|
142
|
+
"""
|
|
143
|
+
import torch
|
|
144
|
+
import sys
|
|
145
|
+
from pathlib import Path
|
|
146
|
+
|
|
147
|
+
# Add diamond src to path if needed
|
|
148
|
+
diamond_path = kwargs.get("diamond_src_path")
|
|
149
|
+
if diamond_path and diamond_path not in sys.path:
|
|
150
|
+
sys.path.insert(0, diamond_path)
|
|
151
|
+
|
|
152
|
+
try:
|
|
153
|
+
from agent import Agent
|
|
154
|
+
from envs.world_model_env import WorldModelEnv
|
|
155
|
+
from hydra.utils import instantiate
|
|
156
|
+
from omegaconf import OmegaConf
|
|
157
|
+
except ImportError as e:
|
|
158
|
+
raise ImportError(
|
|
159
|
+
f"DIAMOND dependencies not found: {e}. "
|
|
160
|
+
"Clone https://github.com/eloialonso/diamond and add src/ to path, "
|
|
161
|
+
"or pass diamond_src_path kwarg."
|
|
162
|
+
)
|
|
163
|
+
|
|
164
|
+
# Load agent from checkpoint
|
|
165
|
+
ckpt_path = Path(model_path)
|
|
166
|
+
if not ckpt_path.exists():
|
|
167
|
+
raise FileNotFoundError(f"Checkpoint not found: {model_path}")
|
|
168
|
+
|
|
169
|
+
# Load checkpoint to get config
|
|
170
|
+
ckpt = torch.load(ckpt_path, map_location=device)
|
|
171
|
+
|
|
172
|
+
# Build agent from config in checkpoint
|
|
173
|
+
cfg = ckpt.get("config")
|
|
174
|
+
if cfg is None:
|
|
175
|
+
raise ValueError("Checkpoint missing config - ensure it's a DIAMOND checkpoint")
|
|
176
|
+
|
|
177
|
+
agent_cfg = OmegaConf.create(cfg["agent"]) if isinstance(cfg, dict) else cfg.agent
|
|
178
|
+
agent = Agent(instantiate(agent_cfg, num_actions=num_actions)).to(device).eval()
|
|
179
|
+
agent.load(ckpt_path)
|
|
180
|
+
|
|
181
|
+
# Prepare input frames
|
|
182
|
+
# DIAMOND expects (B, T, C, H, W) with T=4 context frames, normalized to [-0.5, 0.5]
|
|
183
|
+
if len(frames) < 4:
|
|
184
|
+
# Pad with first frame if not enough history
|
|
185
|
+
frames = [frames[0]] * (4 - len(frames)) + list(frames)
|
|
186
|
+
frames = frames[-4:] # Take last 4
|
|
187
|
+
|
|
188
|
+
# Convert to tensor
|
|
189
|
+
frames_t = []
|
|
190
|
+
for f in frames:
|
|
191
|
+
if f.dtype != np.uint8:
|
|
192
|
+
f = (f * 255).astype(np.uint8)
|
|
193
|
+
# Normalize to [-0.5, 0.5]
|
|
194
|
+
f_norm = f.astype(np.float32) / 255.0 - 0.5
|
|
195
|
+
# (H, W, C) -> (C, H, W)
|
|
196
|
+
if f_norm.shape[-1] in [1, 3]:
|
|
197
|
+
f_norm = np.transpose(f_norm, (2, 0, 1))
|
|
198
|
+
frames_t.append(torch.from_numpy(f_norm))
|
|
199
|
+
|
|
200
|
+
obs_buffer = torch.stack(frames_t, dim=0).unsqueeze(0).to(device) # (1, 4, C, H, W)
|
|
201
|
+
|
|
202
|
+
# Action as tensor
|
|
203
|
+
if action is None:
|
|
204
|
+
action = 0
|
|
205
|
+
act_tensor = torch.tensor([[action]], dtype=torch.long, device=device) # (1, 1)
|
|
206
|
+
|
|
207
|
+
predicted_frames = []
|
|
208
|
+
|
|
209
|
+
with torch.no_grad():
|
|
210
|
+
for step in range(num_steps):
|
|
211
|
+
# Sample next observation using diffusion
|
|
212
|
+
# The denoiser expects: obs (B, T, C, H, W), act (B, T)
|
|
213
|
+
next_obs = agent.denoiser.sample(
|
|
214
|
+
obs=obs_buffer,
|
|
215
|
+
act=act_tensor.expand(-1, obs_buffer.shape[1]),
|
|
216
|
+
n_steps=denoising_steps
|
|
217
|
+
) # Returns (B, C, H, W)
|
|
218
|
+
|
|
219
|
+
# Convert back to numpy
|
|
220
|
+
pred_np = next_obs[0].cpu().numpy() # (C, H, W)
|
|
221
|
+
pred_np = np.transpose(pred_np, (1, 2, 0)) # (H, W, C)
|
|
222
|
+
pred_np = ((pred_np + 0.5) * 255).clip(0, 255).astype(np.uint8)
|
|
223
|
+
predicted_frames.append(pred_np)
|
|
224
|
+
|
|
225
|
+
# Roll buffer for next step
|
|
226
|
+
obs_buffer = torch.cat([
|
|
227
|
+
obs_buffer[:, 1:],
|
|
228
|
+
next_obs.unsqueeze(1)
|
|
229
|
+
], dim=1)
|
|
230
|
+
|
|
231
|
+
return {
|
|
232
|
+
"frames": predicted_frames,
|
|
233
|
+
"latent": None,
|
|
234
|
+
"metadata": {
|
|
235
|
+
"model_type": "diamond",
|
|
236
|
+
"denoising_steps": denoising_steps,
|
|
237
|
+
"num_actions": num_actions
|
|
238
|
+
}
|
|
239
|
+
}
|
|
240
|
+
|
|
241
|
+
|
|
242
|
+
def _step_gamengen(
|
|
243
|
+
frames: List[np.ndarray],
|
|
244
|
+
action: Optional[int],
|
|
245
|
+
model_path: str,
|
|
246
|
+
num_steps: int,
|
|
247
|
+
device: str,
|
|
248
|
+
**kwargs
|
|
249
|
+
) -> Dict[str, Any]:
|
|
250
|
+
"""
|
|
251
|
+
GameNGen: Neural game engine (e.g., DOOM in a neural net)
|
|
252
|
+
https://gamengen.github.io/
|
|
253
|
+
|
|
254
|
+
Uses diffusion conditioned on actions to generate game frames.
|
|
255
|
+
"""
|
|
256
|
+
# TODO: Implement when open weights/code available
|
|
257
|
+
#
|
|
258
|
+
# GameNGen architecture:
|
|
259
|
+
# - Takes ~64 previous frames for context
|
|
260
|
+
# - Action conditioning via cross-attention
|
|
261
|
+
# - Stable Diffusion backbone with game-specific fine-tuning
|
|
262
|
+
|
|
263
|
+
raise NotImplementedError(
|
|
264
|
+
"GameNGen support not yet implemented. "
|
|
265
|
+
"No open weights currently available."
|
|
266
|
+
)
|
|
267
|
+
|
|
268
|
+
|
|
269
|
+
def _step_dreamer(
|
|
270
|
+
frames: List[np.ndarray],
|
|
271
|
+
action: Optional[np.ndarray],
|
|
272
|
+
model_path: str,
|
|
273
|
+
num_steps: int,
|
|
274
|
+
device: str,
|
|
275
|
+
**kwargs
|
|
276
|
+
) -> Dict[str, Any]:
|
|
277
|
+
"""
|
|
278
|
+
DreamerV3: Latent world model
|
|
279
|
+
https://github.com/danijar/dreamerv3
|
|
280
|
+
|
|
281
|
+
Operates in latent space, good for RL and planning.
|
|
282
|
+
"""
|
|
283
|
+
# TODO: Implement with dreamerv3 package
|
|
284
|
+
#
|
|
285
|
+
# import dreamerv3
|
|
286
|
+
#
|
|
287
|
+
# agent = dreamerv3.Agent.load(model_path)
|
|
288
|
+
#
|
|
289
|
+
# # Encode current observation to latent
|
|
290
|
+
# latent = agent.wm.encoder(frames[-1])
|
|
291
|
+
#
|
|
292
|
+
# # Predict forward in latent space
|
|
293
|
+
# predicted_latents = []
|
|
294
|
+
# current = latent
|
|
295
|
+
# for _ in range(num_steps):
|
|
296
|
+
# current = agent.wm.dynamics.img_step(current, action)
|
|
297
|
+
# predicted_latents.append(current)
|
|
298
|
+
#
|
|
299
|
+
# # Decode back to images
|
|
300
|
+
# predicted_frames = [agent.wm.decoder(l) for l in predicted_latents]
|
|
301
|
+
#
|
|
302
|
+
# return {"frames": predicted_frames, "latent": current, "metadata": {}}
|
|
303
|
+
|
|
304
|
+
raise NotImplementedError(
|
|
305
|
+
"DreamerV3 support not yet implemented. "
|
|
306
|
+
"See https://github.com/danijar/dreamerv3 for setup."
|
|
307
|
+
)
|
|
308
|
+
|
|
309
|
+
|
|
310
|
+
# ============== Setup/Download ==============
|
|
311
|
+
|
|
312
|
+
# Upstream DIAMOND repository that setup_diamond() clones.
DIAMOND_REPO = "https://github.com/eloialonso/diamond.git"

# Games with pretrained DIAMOND checkpoints; setup_diamond() downloads all of
# them when no explicit game list is given.
DIAMOND_GAMES = [
    "Asterix",
    "Breakout",
    "Boxing",
    "Pong",
    "Seaquest",
    "SpaceInvaders",
]
|
|
314
|
+
|
|
315
|
+
|
|
316
|
+
def setup_diamond(
    install_path: str = None,
    games: List[str] = None,
    device: str = "cuda"
) -> str:
    """
    Clone the DIAMOND repo (if missing) and download pretrained checkpoints.

    Cloning and installing requirements must succeed. The checkpoint download
    is best-effort: on failure it prints manual instructions instead of raising.
    Checkpoints are checked per game, so a game requested later is still
    fetched even if other games were downloaded earlier.

    Args:
        install_path: Where to clone the repo. Defaults to ~/.npcpy/diamond
        games: Games whose checkpoints to download. Defaults to DIAMOND_GAMES.
        device: Currently unused; kept for API compatibility.

    Returns:
        Path to diamond/src directory (add to sys.path)

    Raises:
        subprocess.CalledProcessError: If ``git clone`` or the requirements
            install fails.
    """
    import os
    import subprocess
    import sys

    if install_path is None:
        install_path = os.path.expanduser("~/.npcpy/diamond")

    diamond_dir = os.path.join(install_path, "diamond")
    src_path = os.path.join(diamond_dir, "src")

    # Clone if not exists
    if not os.path.exists(diamond_dir):
        print(f"Cloning DIAMOND to {diamond_dir}...")
        os.makedirs(install_path, exist_ok=True)
        subprocess.run(
            ["git", "clone", DIAMOND_REPO],
            cwd=install_path,
            check=True
        )
        print("Installing requirements...")
        # `python -m pip` installs into the interpreter running this code; a
        # bare `pip` on PATH may belong to a different environment.
        subprocess.run(
            [sys.executable, "-m", "pip", "install", "-r", "requirements.txt"],
            cwd=diamond_dir,
            check=True
        )
    else:
        print(f"DIAMOND already exists at {diamond_dir}")

    # Google Drive file IDs of DIAMOND's pretrained checkpoints (from their play.py)
    gdrive_ids = {
        "Asterix": "1KVvh0E2pFYGLdSU2dKiYntRXM3wPB7R7",
        "Boxing": "1_yy3ZcOa7mPCNBPyu1vGO-oiMpwzX8qT",
        "Breakout": "1VYKyFnEo0_pI9kAZmvJ5p0MdIp9u0gdj",
        "Pong": "1FBB-L8t-LzBV9wPZHCB9HUQRYP2K-9Xw",
        "Seaquest": "1nhEG9HHfQlCYVadG9TFbzQdHAqQ0G_Ty",
        "SpaceInvaders": "1N8rp7SYAalJkJT-2RQmfiU4UNZHUhm0W",
    }

    # Decide per game what is missing. Checking only whether checkpoints/ is
    # empty would skip a game requested later once any game is present.
    checkpoints_dir = os.path.join(diamond_dir, "checkpoints")
    missing = []
    for game in games or DIAMOND_GAMES:
        if game not in gdrive_ids:
            print(f"Unknown game: {game}, skipping")
        elif os.path.exists(os.path.join(checkpoints_dir, game)):
            print(f"{game} checkpoint already exists")
        else:
            missing.append(game)

    if missing:
        print("Downloading pretrained checkpoints...")
        # The download helper is imported from DIAMOND's src directory.
        if src_path not in sys.path:
            sys.path.insert(0, src_path)
        try:
            from utils import download_and_unzip_from_google_drive

            os.makedirs(checkpoints_dir, exist_ok=True)
            for game in missing:
                print(f"Downloading {game}...")
                download_and_unzip_from_google_drive(gdrive_ids[game], checkpoints_dir)
        except Exception as e:
            # Best-effort: report and point at the manual route.
            print(f"Auto-download failed: {e}")
            print(f"Run manually: cd {diamond_dir} && python src/play.py --pretrained")

    return src_path
|
|
399
|
+
|
|
400
|
+
|
|
401
|
+
def get_diamond_checkpoint(game: str, install_path: str = None) -> str:
    """
    Get path to a DIAMOND checkpoint for a specific game.

    Searches ``<install_path>/diamond/checkpoints/<game>`` recursively. A
    ``.pt`` file directly in that directory wins over ones in subdirectories.
    Within a directory the alphabetically first ``.pt`` file is chosen, so
    the result does not depend on filesystem listing order.

    Args:
        game: One of Asterix, Boxing, Breakout, Pong, Seaquest, SpaceInvaders
        install_path: DIAMOND install location. Defaults to ~/.npcpy/diamond

    Returns:
        Path to checkpoint file

    Raises:
        FileNotFoundError: If the game directory is missing or has no .pt file.
    """
    import os

    if install_path is None:
        install_path = os.path.expanduser("~/.npcpy/diamond")

    ckpt_dir = os.path.join(install_path, "diamond", "checkpoints", game)

    if not os.path.isdir(ckpt_dir):
        raise FileNotFoundError(
            f"Checkpoint for {game} not found at {ckpt_dir}. "
            "Run setup_diamond() first or download manually."
        )

    # os.walk is top-down, so top-level files are inspected before
    # subdirectories; sorting makes the choice deterministic.
    for root, dirs, files in os.walk(ckpt_dir):
        dirs.sort()
        for name in sorted(files):
            if name.endswith(".pt"):
                return os.path.join(root, name)

    raise FileNotFoundError(f"No .pt checkpoint found in {ckpt_dir}")
|
|
437
|
+
|
|
438
|
+
|
|
439
|
+
def get_diamond_src_path(install_path: str = None) -> str:
    """
    Return the ``diamond/src`` directory to put on sys.path for DIAMOND imports.

    Args:
        install_path: DIAMOND install location; ``None`` means ~/.npcpy/diamond.
    """
    import os
    base = os.path.expanduser("~/.npcpy/diamond") if install_path is None else install_path
    return os.path.join(base, "diamond", "src")
|
|
445
|
+
|
|
446
|
+
|
|
447
|
+
def diamond_step(
    frames: List[np.ndarray],
    action: int,
    game: str = "Breakout",
    num_steps: int = 1,
    device: str = "cuda",
    auto_setup: bool = True,
    install_path: str = None,
    **kwargs
) -> Dict[str, Any]:
    """
    High-level API: predict next frame(s) using DIAMOND.

    Locates the DIAMOND sources and the checkpoint for ``game``. With
    ``auto_setup`` it installs whatever is missing, then runs ``world_step``
    with model_type="diamond". Nothing is cached here: the checkpoint is
    loaded from disk on every call.

    Args:
        frames: List of recent frames (at least 1, ideally 4). 64x64 uint8 RGB.
        action: Atari action index (0-17)
        game: One of Asterix, Boxing, Breakout, Pong, Seaquest, SpaceInvaders
        num_steps: Number of frames to predict
        device: "cuda" or "cpu"
        auto_setup: If True, automatically clone repo and download checkpoints
        install_path: DIAMOND install location. Defaults to ~/.npcpy/diamond
        **kwargs: Extra DIAMOND options forwarded to ``world_step``
            (e.g. ``denoising_steps``, ``num_actions``).

    Returns:
        Dict with "frames" (list of predicted np arrays) and "metadata"

    Raises:
        FileNotFoundError: If the checkpoint is missing and ``auto_setup`` is
            False, or it is still missing after setup.

    Example:
        >>> from npcpy.gen.world_gen import diamond_step
        >>> import numpy as np
        >>> frame = np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8)
        >>> result = diamond_step([frame], action=1, game="Breakout")
        >>> next_frame = result["frames"][0]
    """
    import os

    # Auto-setup if needed
    src_path = get_diamond_src_path(install_path)
    if auto_setup and not os.path.exists(src_path):
        print("DIAMOND not found, running setup...")
        src_path = setup_diamond(install_path=install_path, games=[game])

    # Get checkpoint
    try:
        ckpt_path = get_diamond_checkpoint(game, install_path)
    except FileNotFoundError:
        if not auto_setup:
            raise
        print(f"Checkpoint for {game} not found, downloading...")
        setup_diamond(install_path=install_path, games=[game])
        ckpt_path = get_diamond_checkpoint(game, install_path)

    return world_step(
        frames=frames,
        action=action,
        model_path=ckpt_path,
        model_type="diamond",
        num_steps=num_steps,
        device=device,
        diamond_src_path=src_path,
        **kwargs
    )
|
|
506
|
+
|
|
507
|
+
|
|
508
|
+
# ============== Utilities ==============
|
|
509
|
+
|
|
510
|
+
# Process-wide store of already-loaded world models, keyed by strings built
# by the model-loading helpers. Emptied by clear_model_cache().
_MODEL_CACHE: Dict[str, Any] = dict()
|
|
512
|
+
|
|
513
|
+
|
|
514
|
+
def load_diamond_model(
    model_path: str,
    device: str = "cuda",
    num_actions: int = 18,
    diamond_src_path: str = None
):
    """
    Pre-load a DIAMOND model for faster repeated inference.

    The loaded agent is memoized in the module-level ``_MODEL_CACHE``. Repeated
    calls with the same path, device and ``num_actions`` return the same object.

    Args:
        model_path: Path to checkpoint .pt file
        device: "cuda" or "cpu"
        num_actions: Action space size
        diamond_src_path: Path to diamond/src directory

    Returns:
        The DIAMOND ``Agent``, moved to ``device`` and in eval mode.

    Raises:
        FileNotFoundError: If ``model_path`` does not exist.
        ValueError: If the checkpoint has no "config" entry.
    """
    import sys
    from pathlib import Path

    # num_actions changes the instantiated network, so it must be part of the
    # key; otherwise a model built for a different action space is returned.
    cache_key = f"diamond:{model_path}:{device}:{num_actions}"
    if cache_key in _MODEL_CACHE:
        return _MODEL_CACHE[cache_key]

    ckpt_path = Path(model_path)
    if not ckpt_path.exists():
        raise FileNotFoundError(f"Checkpoint not found: {model_path}")

    import torch

    if diamond_src_path and diamond_src_path not in sys.path:
        sys.path.insert(0, diamond_src_path)

    from agent import Agent
    from hydra.utils import instantiate
    from omegaconf import OmegaConf

    ckpt = torch.load(ckpt_path, map_location=device)
    cfg = ckpt.get("config")
    if cfg is None:
        raise ValueError("Checkpoint missing config - ensure it's a DIAMOND checkpoint")

    agent_cfg = OmegaConf.create(cfg["agent"]) if isinstance(cfg, dict) else cfg.agent
    agent = Agent(instantiate(agent_cfg, num_actions=num_actions)).to(device).eval()
    agent.load(ckpt_path)

    _MODEL_CACHE[cache_key] = agent
    return agent
|
|
557
|
+
|
|
558
|
+
|
|
559
|
+
def clear_model_cache():
    """Empty the shared model cache, dropping its references to loaded models."""
    # Mutating the dict in place needs no `global` declaration, and any code
    # holding a reference to the cache sees it emptied too.
    _MODEL_CACHE.clear()
|
|
563
|
+
|
|
564
|
+
|
|
565
|
+
def load_world_model(model_path: str, model_type: str, device: str = "cuda", **kwargs):
    """
    Pre-load a world model so callers can reuse it across inference calls.

    Only the "diamond" backend supports pre-loading. It delegates to
    ``load_diamond_model`` and forwards any extra keyword arguments.

    Args:
        model_path: Path to checkpoint
        model_type: "diamond", "gamengen", or "dreamer"
        device: "cuda" or "cpu"

    Returns:
        Loaded model object (type depends on model_type)

    Raises:
        NotImplementedError: For any ``model_type`` other than "diamond".
    """
    if model_type != "diamond":
        raise NotImplementedError(f"Model loading not implemented for {model_type}")
    return load_diamond_model(model_path, device, **kwargs)
|
|
581
|
+
|
|
582
|
+
|
|
583
|
+
def frames_to_video(
    frames: List[np.ndarray],
    output_path: str,
    fps: int = 30
) -> str:
    """
    Save predicted frames to video file.

    The video size is taken from the first frame, and every frame must match it.

    Args:
        frames: Frames in RGB channel order: (H, W, 3), (H, W, 4) RGBA, or
            grayscale (H, W) / (H, W, 1). uint8 is written as-is; other
            integer dtypes are clipped to [0, 255]; float dtypes are treated
            as [0, 1] and clipped before scaling.
        output_path: Destination file; missing parent directories are created.
        fps: Frames per second of the output video.

    Returns:
        ``output_path``.

    Raises:
        ValueError: If ``frames`` is empty or a frame's size differs from the first.
        OSError: If OpenCV cannot open a writer for ``output_path``
            (e.g. the "mp4v" codec is unavailable).
    """
    # len() rather than truthiness so a stacked ndarray of frames also works.
    if len(frames) == 0:
        raise ValueError("No frames to save")

    import cv2
    import os

    h, w = np.asarray(frames[0]).shape[:2]
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")

    os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)
    writer = cv2.VideoWriter(output_path, fourcc, fps, (w, h))
    if not writer.isOpened():
        raise OSError(f"Could not open video writer for {output_path}")

    try:
        for index, frame in enumerate(frames):
            frame = np.asarray(frame)
            # OpenCV silently drops frames of the wrong size; fail loudly instead.
            if frame.shape[:2] != (h, w):
                raise ValueError(
                    f"Frame {index} has size {frame.shape[:2]}, expected {(h, w)}"
                )
            if frame.dtype != np.uint8:
                if np.issubdtype(frame.dtype, np.integer):
                    frame = np.clip(frame, 0, 255).astype(np.uint8)
                else:
                    # Clip so out-of-range floats don't wrap on the uint8 cast.
                    frame = (np.clip(frame, 0.0, 1.0) * 255).astype(np.uint8)
            # The writer is opened in color mode, so every frame must be 3-channel BGR.
            if frame.ndim == 2 or frame.shape[-1] == 1:
                frame = cv2.cvtColor(frame.reshape(h, w), cv2.COLOR_GRAY2BGR)
            elif frame.shape[-1] == 3:
                frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
            elif frame.shape[-1] == 4:
                frame = cv2.cvtColor(frame, cv2.COLOR_RGBA2BGR)
            writer.write(frame)
    finally:
        writer.release()

    return output_path
|