npcpy 1.3.4__py3-none-any.whl → 1.3.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
npcpy/gen/world_gen.py ADDED
@@ -0,0 +1,609 @@
1
+ """
2
+ World model generation - predict next state(s) given current state and action.
3
+
4
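+ Provides a common interface for local world simulation models (only DIAMOND is implemented so far; the others currently raise NotImplementedError):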
5
+ - DIAMOND (diffusion world model for Atari)
6
+ - GameNGen (neural game engine)
7
+ - Dreamer/DreamerV3 (latent world models)
8
+
9
+ These are interactive models that take actions and maintain consistency,
10
+ unlike video generation, which only produces video from a prompt.
11
+ """
12
+
13
+ from typing import List, Optional, Union, Dict, Any
14
+ import numpy as np
15
+
16
+
17
def world_step(
    frames: List[np.ndarray],
    action: Optional[Union[int, np.ndarray]] = None,
    model_path: str = None,
    model_type: str = "diamond",
    num_steps: int = 1,
    device: str = "cuda",
    **kwargs
) -> Dict[str, Any]:
    """
    Advance a world model by ``num_steps`` frames from a frame history.

    The call is routed to the backend named by ``model_type``; every other
    argument (plus any extra keyword arguments) is forwarded unchanged.

    Args:
        frames: Recent frames, oldest first, as numpy arrays in (H, W, C) or
            (C, H, W) layout.
        action: Conditioning action - an int for discrete action spaces, an
            array for continuous ones. ``None`` is passed to the backend as-is.
        model_path: Path to a local model checkpoint.
        model_type: Backend name: "diamond", "gamengen" or "dreamer".
        num_steps: How many future frames to predict.
        device: Torch device string, e.g. "cuda" or "cpu".

    Returns:
        The backend's result dict with keys:
            - "frames": list of predicted frames as numpy arrays
            - "latent": optional latent state for continuing simulation
            - "metadata": backend-specific info

    Raises:
        ValueError: If ``model_type`` is not a supported backend name.
    """
    # Guard clause first so an unknown backend never touches any backend code.
    if model_type not in ("diamond", "gamengen", "dreamer"):
        raise ValueError(f"Unknown model_type: {model_type}. Supported: diamond, gamengen, dreamer")

    if model_type == "diamond":
        backend = _step_diamond
    elif model_type == "gamengen":
        backend = _step_gamengen
    else:
        backend = _step_dreamer
    return backend(frames, action, model_path, num_steps, device, **kwargs)
52
+
53
+
54
def world_rollout(
    initial_frames: List[np.ndarray],
    actions: List[Union[int, np.ndarray]],
    model_path: str = None,
    model_type: str = "diamond",
    device: str = "cuda",
    **kwargs
) -> Dict[str, Any]:
    """
    Roll out multiple steps given a sequence of actions.

    Each action produces one ``world_step`` call with ``num_steps=1``. The
    context passed to each call is a sliding window that keeps the length of
    ``initial_frames``: the oldest frames are dropped as predictions arrive.

    Args:
        initial_frames: Starting frame(s). Any sequence of frames is accepted
            (list, tuple, or an ndarray stacked along the first axis).
        actions: Actions to execute in sequence.
        model_path: Path to local model checkpoint.
        model_type: One of "diamond", "gamengen", "dreamer".
        device: "cuda" or "cpu".
        **kwargs: Forwarded to ``world_step``.

    Returns:
        Dict with:
            - "frames": initial frames followed by every predicted frame
            - "latents": latent states returned at each step, or None if the
              backend never returned one
    """
    all_frames = list(initial_frames)
    latents = []
    # Keep the window as a list: with a tuple, `+ predicted` raises TypeError,
    # and with a stacked ndarray it would add pixel values element-wise.
    window = list(initial_frames)

    for action in actions:
        result = world_step(
            window,
            action=action,
            model_path=model_path,
            model_type=model_type,
            num_steps=1,
            device=device,
            **kwargs
        )

        predicted = list(result["frames"])
        all_frames.extend(predicted)

        if result.get("latent") is not None:
            latents.append(result["latent"])

        # Slide the window: drop as many old frames as were just predicted.
        window = window[len(predicted):] + predicted

    return {
        "frames": all_frames,
        "latents": latents if latents else None,
    }
106
+
107
+
108
+ # ============== Model-specific implementations ==============
109
+
110
+ def _step_diamond(
111
+ frames: List[np.ndarray],
112
+ action: Optional[int],
113
+ model_path: str,
114
+ num_steps: int,
115
+ device: str,
116
+ num_actions: int = 18, # Atari default
117
+ denoising_steps: int = 3, # DIAMOND default
118
+ **kwargs
119
+ ) -> Dict[str, Any]:
120
+ """
121
+ DIAMOND: Diffusion for World Modeling
122
+ https://github.com/eloialonso/diamond
123
+
124
+ Trained on Atari games, uses diffusion to predict next frames.
125
+
126
+ Setup:
127
+ git clone https://github.com/eloialonso/diamond.git
128
+ cd diamond
129
+ pip install -r requirements.txt
130
+
131
+ Pretrained checkpoints downloaded via:
132
+ python src/play.py --pretrained
133
+
134
+ Args:
135
+ frames: List of 4 recent frames as (H, W, C) uint8 arrays (64x64 for Atari)
136
+ action: Discrete action index (0 to num_actions-1)
137
+ model_path: Path to checkpoint .pt file
138
+ num_steps: Number of frames to predict (rolls out autoregressively)
139
+ device: "cuda" or "cpu"
140
+ num_actions: Action space size (18 for Atari)
141
+ denoising_steps: Diffusion denoising steps (3 is DIAMOND default)
142
+ """
143
+ import torch
144
+ import sys
145
+ from pathlib import Path
146
+
147
+ # Add diamond src to path if needed
148
+ diamond_path = kwargs.get("diamond_src_path")
149
+ if diamond_path and diamond_path not in sys.path:
150
+ sys.path.insert(0, diamond_path)
151
+
152
+ try:
153
+ from agent import Agent
154
+ from envs.world_model_env import WorldModelEnv
155
+ from hydra.utils import instantiate
156
+ from omegaconf import OmegaConf
157
+ except ImportError as e:
158
+ raise ImportError(
159
+ f"DIAMOND dependencies not found: {e}. "
160
+ "Clone https://github.com/eloialonso/diamond and add src/ to path, "
161
+ "or pass diamond_src_path kwarg."
162
+ )
163
+
164
+ # Load agent from checkpoint
165
+ ckpt_path = Path(model_path)
166
+ if not ckpt_path.exists():
167
+ raise FileNotFoundError(f"Checkpoint not found: {model_path}")
168
+
169
+ # Load checkpoint to get config
170
+ ckpt = torch.load(ckpt_path, map_location=device)
171
+
172
+ # Build agent from config in checkpoint
173
+ cfg = ckpt.get("config")
174
+ if cfg is None:
175
+ raise ValueError("Checkpoint missing config - ensure it's a DIAMOND checkpoint")
176
+
177
+ agent_cfg = OmegaConf.create(cfg["agent"]) if isinstance(cfg, dict) else cfg.agent
178
+ agent = Agent(instantiate(agent_cfg, num_actions=num_actions)).to(device).eval()
179
+ agent.load(ckpt_path)
180
+
181
+ # Prepare input frames
182
+ # DIAMOND expects (B, T, C, H, W) with T=4 context frames, normalized to [-0.5, 0.5]
183
+ if len(frames) < 4:
184
+ # Pad with first frame if not enough history
185
+ frames = [frames[0]] * (4 - len(frames)) + list(frames)
186
+ frames = frames[-4:] # Take last 4
187
+
188
+ # Convert to tensor
189
+ frames_t = []
190
+ for f in frames:
191
+ if f.dtype != np.uint8:
192
+ f = (f * 255).astype(np.uint8)
193
+ # Normalize to [-0.5, 0.5]
194
+ f_norm = f.astype(np.float32) / 255.0 - 0.5
195
+ # (H, W, C) -> (C, H, W)
196
+ if f_norm.shape[-1] in [1, 3]:
197
+ f_norm = np.transpose(f_norm, (2, 0, 1))
198
+ frames_t.append(torch.from_numpy(f_norm))
199
+
200
+ obs_buffer = torch.stack(frames_t, dim=0).unsqueeze(0).to(device) # (1, 4, C, H, W)
201
+
202
+ # Action as tensor
203
+ if action is None:
204
+ action = 0
205
+ act_tensor = torch.tensor([[action]], dtype=torch.long, device=device) # (1, 1)
206
+
207
+ predicted_frames = []
208
+
209
+ with torch.no_grad():
210
+ for step in range(num_steps):
211
+ # Sample next observation using diffusion
212
+ # The denoiser expects: obs (B, T, C, H, W), act (B, T)
213
+ next_obs = agent.denoiser.sample(
214
+ obs=obs_buffer,
215
+ act=act_tensor.expand(-1, obs_buffer.shape[1]),
216
+ n_steps=denoising_steps
217
+ ) # Returns (B, C, H, W)
218
+
219
+ # Convert back to numpy
220
+ pred_np = next_obs[0].cpu().numpy() # (C, H, W)
221
+ pred_np = np.transpose(pred_np, (1, 2, 0)) # (H, W, C)
222
+ pred_np = ((pred_np + 0.5) * 255).clip(0, 255).astype(np.uint8)
223
+ predicted_frames.append(pred_np)
224
+
225
+ # Roll buffer for next step
226
+ obs_buffer = torch.cat([
227
+ obs_buffer[:, 1:],
228
+ next_obs.unsqueeze(1)
229
+ ], dim=1)
230
+
231
+ return {
232
+ "frames": predicted_frames,
233
+ "latent": None,
234
+ "metadata": {
235
+ "model_type": "diamond",
236
+ "denoising_steps": denoising_steps,
237
+ "num_actions": num_actions
238
+ }
239
+ }
240
+
241
+
242
+ def _step_gamengen(
243
+ frames: List[np.ndarray],
244
+ action: Optional[int],
245
+ model_path: str,
246
+ num_steps: int,
247
+ device: str,
248
+ **kwargs
249
+ ) -> Dict[str, Any]:
250
+ """
251
+ GameNGen: Neural game engine (e.g., DOOM in a neural net)
252
+ https://gamengen.github.io/
253
+
254
+ Uses diffusion conditioned on actions to generate game frames.
255
+ """
256
+ # TODO: Implement when open weights/code available
257
+ #
258
+ # GameNGen architecture:
259
+ # - Takes ~64 previous frames for context
260
+ # - Action conditioning via cross-attention
261
+ # - Stable Diffusion backbone with game-specific fine-tuning
262
+
263
+ raise NotImplementedError(
264
+ "GameNGen support not yet implemented. "
265
+ "No open weights currently available."
266
+ )
267
+
268
+
269
+ def _step_dreamer(
270
+ frames: List[np.ndarray],
271
+ action: Optional[np.ndarray],
272
+ model_path: str,
273
+ num_steps: int,
274
+ device: str,
275
+ **kwargs
276
+ ) -> Dict[str, Any]:
277
+ """
278
+ DreamerV3: Latent world model
279
+ https://github.com/danijar/dreamerv3
280
+
281
+ Operates in latent space, good for RL and planning.
282
+ """
283
+ # TODO: Implement with dreamerv3 package
284
+ #
285
+ # import dreamerv3
286
+ #
287
+ # agent = dreamerv3.Agent.load(model_path)
288
+ #
289
+ # # Encode current observation to latent
290
+ # latent = agent.wm.encoder(frames[-1])
291
+ #
292
+ # # Predict forward in latent space
293
+ # predicted_latents = []
294
+ # current = latent
295
+ # for _ in range(num_steps):
296
+ # current = agent.wm.dynamics.img_step(current, action)
297
+ # predicted_latents.append(current)
298
+ #
299
+ # # Decode back to images
300
+ # predicted_frames = [agent.wm.decoder(l) for l in predicted_latents]
301
+ #
302
+ # return {"frames": predicted_frames, "latent": current, "metadata": {}}
303
+
304
+ raise NotImplementedError(
305
+ "DreamerV3 support not yet implemented. "
306
+ "See https://github.com/danijar/dreamerv3 for setup."
307
+ )
308
+
309
+
310
+ # ============== Setup/Download ==============
311
+
312
# Upstream DIAMOND repository, cloned by setup_diamond().
DIAMOND_REPO = "https://github.com/eloialonso/diamond.git"

# Games setup_diamond() downloads checkpoints for when no explicit list is given.
DIAMOND_GAMES = [
    "Asterix",
    "Breakout",
    "Boxing",
    "Pong",
    "Seaquest",
    "SpaceInvaders",
]
314
+
315
+
316
def setup_diamond(
    install_path: str = None,
    games: List[str] = None,
    device: str = "cuda"
) -> str:
    """
    Clone DIAMOND repo and download pretrained checkpoints.

    Each part is skipped when already present:
      - the clone and requirements install run only if ``<install_path>/diamond``
        does not exist;
      - each requested game is downloaded only if
        ``<install_path>/diamond/checkpoints/<game>`` does not exist.

    Checkpoint download is best-effort: failures are printed together with a
    manual command rather than raised.

    Args:
        install_path: Where to clone the repo. Defaults to ~/.npcpy/diamond.
        games: Games to download. Defaults to DIAMOND_GAMES.
        device: Currently unused; kept for API compatibility.

    Returns:
        Path to diamond/src directory (add to sys.path).

    Raises:
        subprocess.CalledProcessError: if ``git clone`` or the requirements
            install exits with a non-zero status.
    """
    import os
    import subprocess
    import sys

    if install_path is None:
        install_path = os.path.expanduser("~/.npcpy/diamond")

    diamond_dir = os.path.join(install_path, "diamond")
    src_path = os.path.join(diamond_dir, "src")

    # Clone if not exists
    if not os.path.exists(diamond_dir):
        print(f"Cloning DIAMOND to {diamond_dir}...")
        os.makedirs(install_path, exist_ok=True)
        subprocess.run(
            ["git", "clone", DIAMOND_REPO],
            cwd=install_path,
            check=True
        )
        print("Installing requirements...")
        # Use this interpreter's pip so the requirements land in the same
        # environment that will later import DIAMOND; a bare "pip" on PATH
        # may belong to a different Python.
        subprocess.run(
            [sys.executable, "-m", "pip", "install", "-r", "requirements.txt"],
            cwd=diamond_dir,
            check=True
        )
    else:
        print(f"DIAMOND already exists at {diamond_dir}")

    # Their checkpoint IDs (from their play.py)
    gdrive_ids = {
        "Asterix": "1KVvh0E2pFYGLdSU2dKiYntRXM3wPB7R7",
        "Boxing": "1_yy3ZcOa7mPCNBPyu1vGO-oiMpwzX8qT",
        "Breakout": "1VYKyFnEo0_pI9kAZmvJ5p0MdIp9u0gdj",
        "Pong": "1FBB-L8t-LzBV9wPZHCB9HUQRYP2K-9Xw",
        "Seaquest": "1nhEG9HHfQlCYVadG9TFbzQdHAqQ0G_Ty",
        "SpaceInvaders": "1N8rp7SYAalJkJT-2RQmfiU4UNZHUhm0W",
    }

    # Decide per game what is missing. The directory as a whole may already
    # hold other games' checkpoints, so an "is it empty?" test is not enough.
    checkpoints_dir = os.path.join(diamond_dir, "checkpoints")
    missing = []
    for game in games or DIAMOND_GAMES:
        if game not in gdrive_ids:
            print(f"Unknown game: {game}, skipping")
        elif os.path.exists(os.path.join(checkpoints_dir, game)):
            print(f"{game} checkpoint already exists")
        else:
            missing.append(game)

    if not missing:
        return src_path

    print("Downloading pretrained checkpoints...")
    if src_path not in sys.path:
        sys.path.insert(0, src_path)

    try:
        # Download via DIAMOND's own utility module.
        # NOTE(review): a different top-level module named `utils` that is
        # already imported would shadow it - confirm if downloads fail oddly.
        from utils import download_and_unzip_from_google_drive

        os.makedirs(checkpoints_dir, exist_ok=True)
        for game in missing:
            print(f"Downloading {game}...")
            download_and_unzip_from_google_drive(gdrive_ids[game], checkpoints_dir)
    except Exception as e:
        # Deliberately best-effort: report and point at the manual route.
        print(f"Auto-download failed: {e}")
        print(f"Run manually: cd {diamond_dir} && python src/play.py --pretrained")

    return src_path
399
+
400
+
401
def get_diamond_checkpoint(game: str, install_path: str = None) -> str:
    """
    Get path to a DIAMOND checkpoint for a specific game.

    Searches ``<install_path>/diamond/checkpoints/<game>`` recursively for a
    ``.pt`` file. A file directly in that directory wins over ones in
    subdirectories. Within a directory, the alphabetically first match is
    chosen, so the result is deterministic.

    Args:
        game: One of Asterix, Boxing, Breakout, Pong, Seaquest, SpaceInvaders.
        install_path: DIAMOND install location. Defaults to ~/.npcpy/diamond.

    Returns:
        Path to checkpoint file.

    Raises:
        FileNotFoundError: the game directory is missing (or is not a
            directory), or it contains no ``.pt`` file.
    """
    import os

    if install_path is None:
        install_path = os.path.expanduser("~/.npcpy/diamond")

    ckpt_dir = os.path.join(install_path, "diamond", "checkpoints", game)

    if not os.path.isdir(ckpt_dir):
        raise FileNotFoundError(
            f"Checkpoint for {game} not found at {ckpt_dir}. "
            "Run setup_diamond() first or download manually."
        )

    # os.walk is top-down, so ckpt_dir's own files are examined first.
    # Sorting dirs in place also fixes the order subdirectories are visited.
    for root, dirs, files in os.walk(ckpt_dir):
        dirs.sort()
        for name in sorted(files):
            if name.endswith(".pt"):
                return os.path.join(root, name)

    raise FileNotFoundError(f"No .pt checkpoint found in {ckpt_dir}")
437
+
438
+
439
def get_diamond_src_path(install_path: str = None) -> str:
    """
    Return the ``diamond/src`` directory to put on ``sys.path`` for imports.

    Args:
        install_path: DIAMOND install root; ``None`` means ~/.npcpy/diamond.

    Returns:
        ``<install_path>/diamond/src`` (the path is not checked for existence).
    """
    import os

    root = os.path.expanduser("~/.npcpy/diamond") if install_path is None else install_path
    return os.path.join(root, "diamond", "src")
445
+
446
+
447
def diamond_step(
    frames: List[np.ndarray],
    action: int,
    game: str = "Breakout",
    num_steps: int = 1,
    device: str = "cuda",
    auto_setup: bool = True,
    install_path: Optional[str] = None,
    **kwargs
) -> Dict[str, Any]:
    """
    High-level API: predict next frame(s) using DIAMOND.

    Locates the DIAMOND source tree and the checkpoint for ``game``. When
    ``auto_setup`` is true and either is missing, setup_diamond() runs first.
    The call is then delegated to world_step() with ``model_type="diamond"``.
    This function does not cache loaded models itself.

    Args:
        frames: List of recent frames (at least 1, ideally 4). 64x64 uint8 RGB.
        action: Atari action index (0-17).
        game: One of Asterix, Boxing, Breakout, Pong, Seaquest, SpaceInvaders.
        num_steps: Number of frames to predict.
        device: "cuda" or "cpu".
        auto_setup: If True, automatically clone repo and download checkpoints.
        install_path: DIAMOND install root; None means ~/.npcpy/diamond.
        **kwargs: Extra options forwarded to world_step (e.g.
            ``denoising_steps``, ``num_actions``). ``diamond_src_path``
            defaults to the located src directory.

    Returns:
        Dict with "frames" (list of predicted np arrays) and "metadata".

    Raises:
        FileNotFoundError: the checkpoint is missing and ``auto_setup`` is
            False, or it is still missing after setup.

    Example:
        >>> from npcpy.gen.world_gen import diamond_step
        >>> import numpy as np
        >>> frame = np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8)
        >>> result = diamond_step([frame], action=1, game="Breakout")
        >>> next_frame = result["frames"][0]
    """
    import os

    # Auto-setup if the source tree is missing
    src_path = get_diamond_src_path(install_path)
    if auto_setup and not os.path.exists(src_path):
        print("DIAMOND not found, running setup...")
        src_path = setup_diamond(install_path=install_path, games=[game])

    # Get checkpoint, downloading it once if allowed
    try:
        ckpt_path = get_diamond_checkpoint(game, install_path=install_path)
    except FileNotFoundError:
        if not auto_setup:
            raise
        print(f"Checkpoint for {game} not found, downloading...")
        setup_diamond(install_path=install_path, games=[game])
        ckpt_path = get_diamond_checkpoint(game, install_path=install_path)

    kwargs.setdefault("diamond_src_path", src_path)
    return world_step(
        frames=frames,
        action=action,
        model_path=ckpt_path,
        model_type="diamond",
        num_steps=num_steps,
        device=device,
        **kwargs
    )
506
+
507
+
508
+ # ============== Utilities ==============
509
+
510
+ # Global cache for loaded models
511
+ _MODEL_CACHE: Dict[str, Any] = {}
512
+
513
+
514
def load_diamond_model(
    model_path: str,
    device: str = "cuda",
    num_actions: int = 18,
    diamond_src_path: str = None
):
    """
    Pre-load a DIAMOND agent for faster repeated inference.

    Agents are memoized in the module-level ``_MODEL_CACHE``, keyed by
    checkpoint path, device and action-space size. Repeated calls with the
    same arguments therefore return the same object.

    Args:
        model_path: Path to checkpoint .pt file.
        device: "cuda" or "cpu".
        num_actions: Action space size.
        diamond_src_path: Path to diamond/src directory (prepended to sys.path
            if not already present).

    Returns:
        The DIAMOND ``Agent``, moved to ``device`` and in eval mode.

    Raises:
        FileNotFoundError: the checkpoint does not exist.
        ImportError: DIAMOND / hydra / omegaconf cannot be imported.
        ValueError: the checkpoint has no "config" entry.
    """
    import sys
    from pathlib import Path

    # num_actions is part of the key: it changes the instantiated network, so
    # agents built for different action spaces must not share an entry.
    cache_key = f"diamond:{model_path}:{device}:{num_actions}"
    if cache_key in _MODEL_CACHE:
        return _MODEL_CACHE[cache_key]

    ckpt_path = Path(model_path)
    if not ckpt_path.exists():
        raise FileNotFoundError(f"Checkpoint not found: {model_path}")

    if diamond_src_path and diamond_src_path not in sys.path:
        sys.path.insert(0, diamond_src_path)

    import torch

    try:
        from agent import Agent
        from hydra.utils import instantiate
        from omegaconf import OmegaConf
    except ImportError as e:
        raise ImportError(
            f"DIAMOND dependencies not found: {e}. "
            "Clone https://github.com/eloialonso/diamond and add src/ to path, "
            "or pass diamond_src_path."
        ) from e

    # NOTE: torch.load unpickles the file - only load trusted checkpoints.
    ckpt = torch.load(ckpt_path, map_location=device)
    cfg = ckpt.get("config")
    if cfg is None:
        raise ValueError("Checkpoint missing config - ensure it's a DIAMOND checkpoint")

    agent_cfg = OmegaConf.create(cfg["agent"]) if isinstance(cfg, dict) else cfg.agent
    agent = Agent(instantiate(agent_cfg, num_actions=num_actions)).to(device).eval()
    agent.load(ckpt_path)

    _MODEL_CACHE[cache_key] = agent
    return agent
557
+
558
+
559
def clear_model_cache():
    """
    Empty the module-level model cache to free memory.

    Models become collectable once no other references to them remain.
    """
    # .clear() mutates the existing dict in place, so no `global` is needed.
    _MODEL_CACHE.clear()
563
+
564
+
565
def load_world_model(model_path: str, model_type: str, device: str = "cuda", **kwargs):
    """
    Pre-load a world model so repeated inference skips the loading cost.

    Only the "diamond" backend supports pre-loading; it delegates to
    load_diamond_model() with any extra keyword arguments.

    Args:
        model_path: Path to checkpoint.
        model_type: "diamond", "gamengen", or "dreamer".
        device: "cuda" or "cpu".

    Returns:
        Loaded model object (type depends on model_type).

    Raises:
        NotImplementedError: for any model_type other than "diamond".
    """
    if model_type != "diamond":
        raise NotImplementedError(f"Model loading not implemented for {model_type}")
    return load_diamond_model(model_path, device, **kwargs)
581
+
582
+
583
+ def frames_to_video(
584
+ frames: List[np.ndarray],
585
+ output_path: str,
586
+ fps: int = 30
587
+ ) -> str:
588
+ """Save predicted frames to video file."""
589
+ import cv2
590
+ import os
591
+
592
+ if not frames:
593
+ raise ValueError("No frames to save")
594
+
595
+ h, w = frames[0].shape[:2]
596
+ fourcc = cv2.VideoWriter_fourcc(*"mp4v")
597
+
598
+ os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)
599
+ writer = cv2.VideoWriter(output_path, fourcc, fps, (w, h))
600
+
601
+ for frame in frames:
602
+ if frame.dtype != np.uint8:
603
+ frame = (frame * 255).astype(np.uint8)
604
+ if frame.shape[-1] == 3:
605
+ frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
606
+ writer.write(frame)
607
+
608
+ writer.release()
609
+ return output_path