synth-ai 0.2.4.dev3__py3-none-any.whl → 0.2.4.dev5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (105)
  1. synth_ai/environments/examples/__init__.py +1 -0
  2. synth_ai/environments/examples/crafter_classic/__init__.py +8 -0
  3. synth_ai/environments/examples/crafter_classic/config_logging.py +111 -0
  4. synth_ai/environments/examples/crafter_classic/debug_translation.py +0 -0
  5. synth_ai/environments/examples/crafter_classic/engine.py +575 -0
  6. synth_ai/environments/examples/crafter_classic/engine_deterministic_patch.py +63 -0
  7. synth_ai/environments/examples/crafter_classic/engine_helpers/action_map.py +5 -0
  8. synth_ai/environments/examples/crafter_classic/engine_helpers/serialization.py +74 -0
  9. synth_ai/environments/examples/crafter_classic/engine_serialization_patch_v3.py +266 -0
  10. synth_ai/environments/examples/crafter_classic/environment.py +364 -0
  11. synth_ai/environments/examples/crafter_classic/taskset.py +233 -0
  12. synth_ai/environments/examples/crafter_classic/trace_hooks_v3.py +229 -0
  13. synth_ai/environments/examples/crafter_classic/world_config_patch_simple.py +298 -0
  14. synth_ai/environments/examples/crafter_custom/__init__.py +4 -0
  15. synth_ai/environments/examples/crafter_custom/crafter/__init__.py +7 -0
  16. synth_ai/environments/examples/crafter_custom/crafter/config.py +182 -0
  17. synth_ai/environments/examples/crafter_custom/crafter/constants.py +8 -0
  18. synth_ai/environments/examples/crafter_custom/crafter/engine.py +269 -0
  19. synth_ai/environments/examples/crafter_custom/crafter/env.py +266 -0
  20. synth_ai/environments/examples/crafter_custom/crafter/objects.py +418 -0
  21. synth_ai/environments/examples/crafter_custom/crafter/recorder.py +187 -0
  22. synth_ai/environments/examples/crafter_custom/crafter/worldgen.py +119 -0
  23. synth_ai/environments/examples/crafter_custom/dataset_builder.py +373 -0
  24. synth_ai/environments/examples/crafter_custom/environment.py +312 -0
  25. synth_ai/environments/examples/crafter_custom/run_dataset.py +305 -0
  26. synth_ai/environments/examples/enron/art_helpers/email_search_tools.py +156 -0
  27. synth_ai/environments/examples/enron/art_helpers/local_email_db.py +280 -0
  28. synth_ai/environments/examples/enron/art_helpers/types_enron.py +24 -0
  29. synth_ai/environments/examples/enron/engine.py +291 -0
  30. synth_ai/environments/examples/enron/environment.py +165 -0
  31. synth_ai/environments/examples/enron/taskset.py +112 -0
  32. synth_ai/environments/examples/minigrid/__init__.py +48 -0
  33. synth_ai/environments/examples/minigrid/engine.py +589 -0
  34. synth_ai/environments/examples/minigrid/environment.py +274 -0
  35. synth_ai/environments/examples/minigrid/environment_mapping.py +242 -0
  36. synth_ai/environments/examples/minigrid/puzzle_loader.py +416 -0
  37. synth_ai/environments/examples/minigrid/taskset.py +583 -0
  38. synth_ai/environments/examples/nethack/__init__.py +7 -0
  39. synth_ai/environments/examples/nethack/achievements.py +337 -0
  40. synth_ai/environments/examples/nethack/engine.py +738 -0
  41. synth_ai/environments/examples/nethack/environment.py +255 -0
  42. synth_ai/environments/examples/nethack/helpers/__init__.py +42 -0
  43. synth_ai/environments/examples/nethack/helpers/action_mapping.py +301 -0
  44. synth_ai/environments/examples/nethack/helpers/nle_wrapper.py +401 -0
  45. synth_ai/environments/examples/nethack/helpers/observation_utils.py +433 -0
  46. synth_ai/environments/examples/nethack/helpers/recording_wrapper.py +201 -0
  47. synth_ai/environments/examples/nethack/helpers/trajectory_recorder.py +268 -0
  48. synth_ai/environments/examples/nethack/helpers/visualization/replay_viewer.py +308 -0
  49. synth_ai/environments/examples/nethack/helpers/visualization/visualizer.py +430 -0
  50. synth_ai/environments/examples/nethack/taskset.py +323 -0
  51. synth_ai/environments/examples/red/__init__.py +7 -0
  52. synth_ai/environments/examples/red/config_logging.py +110 -0
  53. synth_ai/environments/examples/red/engine.py +693 -0
  54. synth_ai/environments/examples/red/engine_helpers/__init__.py +1 -0
  55. synth_ai/environments/examples/red/engine_helpers/memory_map.py +28 -0
  56. synth_ai/environments/examples/red/engine_helpers/reward_components.py +275 -0
  57. synth_ai/environments/examples/red/engine_helpers/reward_library/__init__.py +142 -0
  58. synth_ai/environments/examples/red/engine_helpers/reward_library/adaptive_rewards.py +56 -0
  59. synth_ai/environments/examples/red/engine_helpers/reward_library/battle_rewards.py +283 -0
  60. synth_ai/environments/examples/red/engine_helpers/reward_library/composite_rewards.py +149 -0
  61. synth_ai/environments/examples/red/engine_helpers/reward_library/economy_rewards.py +137 -0
  62. synth_ai/environments/examples/red/engine_helpers/reward_library/efficiency_rewards.py +56 -0
  63. synth_ai/environments/examples/red/engine_helpers/reward_library/exploration_rewards.py +330 -0
  64. synth_ai/environments/examples/red/engine_helpers/reward_library/novelty_rewards.py +120 -0
  65. synth_ai/environments/examples/red/engine_helpers/reward_library/pallet_town_rewards.py +558 -0
  66. synth_ai/environments/examples/red/engine_helpers/reward_library/pokemon_rewards.py +312 -0
  67. synth_ai/environments/examples/red/engine_helpers/reward_library/social_rewards.py +147 -0
  68. synth_ai/environments/examples/red/engine_helpers/reward_library/story_rewards.py +246 -0
  69. synth_ai/environments/examples/red/engine_helpers/screen_analysis.py +367 -0
  70. synth_ai/environments/examples/red/engine_helpers/state_extraction.py +139 -0
  71. synth_ai/environments/examples/red/environment.py +235 -0
  72. synth_ai/environments/examples/red/taskset.py +77 -0
  73. synth_ai/environments/examples/sokoban/__init__.py +1 -0
  74. synth_ai/environments/examples/sokoban/engine.py +675 -0
  75. synth_ai/environments/examples/sokoban/engine_helpers/__init__.py +1 -0
  76. synth_ai/environments/examples/sokoban/engine_helpers/room_utils.py +656 -0
  77. synth_ai/environments/examples/sokoban/engine_helpers/vendored/__init__.py +17 -0
  78. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/__init__.py +3 -0
  79. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/boxoban_env.py +129 -0
  80. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/render_utils.py +370 -0
  81. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/room_utils.py +331 -0
  82. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env.py +305 -0
  83. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_fixed_targets.py +66 -0
  84. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_pull.py +114 -0
  85. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_two_player.py +122 -0
  86. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_variations.py +394 -0
  87. synth_ai/environments/examples/sokoban/environment.py +228 -0
  88. synth_ai/environments/examples/sokoban/generate_verified_puzzles.py +438 -0
  89. synth_ai/environments/examples/sokoban/puzzle_loader.py +311 -0
  90. synth_ai/environments/examples/sokoban/taskset.py +425 -0
  91. synth_ai/environments/examples/tictactoe/__init__.py +1 -0
  92. synth_ai/environments/examples/tictactoe/engine.py +368 -0
  93. synth_ai/environments/examples/tictactoe/environment.py +239 -0
  94. synth_ai/environments/examples/tictactoe/taskset.py +214 -0
  95. synth_ai/environments/examples/verilog/__init__.py +10 -0
  96. synth_ai/environments/examples/verilog/engine.py +328 -0
  97. synth_ai/environments/examples/verilog/environment.py +349 -0
  98. synth_ai/environments/examples/verilog/taskset.py +418 -0
  99. synth_ai/tracing_v3/examples/basic_usage.py +188 -0
  100. {synth_ai-0.2.4.dev3.dist-info → synth_ai-0.2.4.dev5.dist-info}/METADATA +1 -1
  101. {synth_ai-0.2.4.dev3.dist-info → synth_ai-0.2.4.dev5.dist-info}/RECORD +105 -6
  102. {synth_ai-0.2.4.dev3.dist-info → synth_ai-0.2.4.dev5.dist-info}/WHEEL +0 -0
  103. {synth_ai-0.2.4.dev3.dist-info → synth_ai-0.2.4.dev5.dist-info}/entry_points.txt +0 -0
  104. {synth_ai-0.2.4.dev3.dist-info → synth_ai-0.2.4.dev5.dist-info}/licenses/LICENSE +0 -0
  105. {synth_ai-0.2.4.dev3.dist-info → synth_ai-0.2.4.dev5.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,268 @@
1
+ """Trajectory recording and replay functionality for NetHack."""
2
+
3
+ import json
4
+ import pickle
5
+ import gzip
6
+ from pathlib import Path
7
+ from typing import Dict, Any, List, Optional, Tuple
8
+ from datetime import datetime
9
+ import numpy as np
10
+ from dataclasses import dataclass, asdict
11
+ import base64
12
+
13
+
@dataclass
class TrajectoryFrame:
    """Single frame (one environment step) in a trajectory.

    ``to_dict`` produces a JSON-compatible dict and ``from_dict`` reverses it.
    Numpy arrays found anywhere inside the frame (including inside nested
    dicts, lists and tuples of ``observation`` and ``info``) are encoded as
    ``{"type": "ndarray", "data": [...], "dtype": str, "shape": tuple}`` and
    restored with their original dtype and shape. Numpy scalars are converted
    to the equivalent Python scalar, because ``json`` cannot encode them.
    """

    step: int
    action: str
    observation: Dict[str, Any]
    reward: float
    done: bool
    info: Dict[str, Any]
    timestamp: float

    def to_dict(self) -> Dict[str, Any]:
        """Convert to a dict that ``json.dump`` can serialize.

        Every field goes through ``_serialize_value``, so ``reward``/``done``
        are covered too in case they are numpy scalars (e.g. ``np.float32``).
        """
        d = asdict(self)
        return {key: self._serialize_value(value) for key, value in d.items()}

    def _serialize_observation(self, obs: Dict[str, Any]) -> Dict[str, Any]:
        """Serialize a (possibly nested) observation dict; see ``_serialize_value``."""
        return {key: self._serialize_value(value) for key, value in obs.items()}

    @staticmethod
    def _serialize_value(value: Any) -> Any:
        """Recursively convert ``value`` into JSON-compatible builtins."""
        if isinstance(value, np.ndarray):
            return {
                "type": "ndarray",
                "data": value.tolist(),
                "dtype": str(value.dtype),
                "shape": value.shape,
            }
        if isinstance(value, np.generic):
            # numpy scalars (np.int64, np.float32, np.bool_, ...) -> Python scalars
            return value.item()
        if isinstance(value, dict):
            return {key: TrajectoryFrame._serialize_value(item) for key, item in value.items()}
        if isinstance(value, (list, tuple)):
            items = [TrajectoryFrame._serialize_value(item) for item in value]
            return tuple(items) if isinstance(value, tuple) else items
        return value

    @classmethod
    def from_dict(cls, d: Dict[str, Any]) -> "TrajectoryFrame":
        """Reconstruct a frame from ``to_dict`` output (or its JSON round-trip).

        The input dict is not modified.
        """
        d = dict(d)  # shallow copy: never rewrite the caller's dict in place
        for key in ("observation", "info"):
            if d.get(key):
                d[key] = cls._deserialize_observation(d[key])
        return cls(**d)

    @staticmethod
    def _deserialize_observation(obs: Dict[str, Any]) -> Dict[str, Any]:
        """Deserialize a (possibly nested) observation dict; see ``_deserialize_value``."""
        return {key: TrajectoryFrame._deserialize_value(value) for key, value in obs.items()}

    @staticmethod
    def _deserialize_value(value: Any) -> Any:
        """Recursively restore numpy arrays encoded by ``_serialize_value``."""
        if isinstance(value, dict):
            if value.get("type") == "ndarray" and "data" in value:
                return np.array(value["data"], dtype=value["dtype"]).reshape(value["shape"])
            return {key: TrajectoryFrame._deserialize_value(item) for key, item in value.items()}
        if isinstance(value, (list, tuple)):
            items = [TrajectoryFrame._deserialize_value(item) for item in value]
            return tuple(items) if isinstance(value, tuple) else items
        return value
84
+
85
+
@dataclass
class TrajectoryMetadata:
    """Metadata describing one recorded trajectory.

    ``to_dict``/``from_dict`` convert to and from a JSON-compatible dict;
    ``start_time``/``end_time`` are stored as ISO-8601 strings.
    """

    trajectory_id: str
    character_role: str
    task_id: Optional[str]
    start_time: datetime
    end_time: Optional[datetime]
    total_steps: int
    total_reward: float
    final_status: str  # e.g. 'in_progress', 'completed', 'died', 'quit', 'truncated'
    max_depth_reached: int
    achievements: Dict[str, bool]

    def to_dict(self) -> Dict[str, Any]:
        """Convert to a dict that ``json.dump`` can serialize."""
        d = asdict(self)
        d["start_time"] = self.start_time.isoformat()
        if self.end_time is not None:
            d["end_time"] = self.end_time.isoformat()
        # Totals/achievements may have been accumulated from numpy scalars
        # (e.g. np.float32 rewards), which json cannot encode.
        for key in ("total_steps", "total_reward", "max_depth_reached"):
            if isinstance(d[key], np.generic):
                d[key] = d[key].item()
        d["achievements"] = {
            name: value.item() if isinstance(value, np.generic) else value
            for name, value in d["achievements"].items()
        }
        return d

    @classmethod
    def from_dict(cls, d: Dict[str, Any]) -> "TrajectoryMetadata":
        """Reconstruct metadata from ``to_dict`` output.

        The input dict is not modified, so it can be reused, e.g. parsed twice.
        """
        d = dict(d)  # shallow copy: never rewrite the caller's dict in place
        d["start_time"] = datetime.fromisoformat(d["start_time"])
        if d.get("end_time"):
            d["end_time"] = datetime.fromisoformat(d["end_time"])
        return cls(**d)
116
+
117
+
class TrajectoryRecorder:
    """Records NetHack game trajectories and saves/loads them to/from disk.

    Typical use: ``start_recording`` -> ``record_step`` per step ->
    ``stop_recording`` -> ``save_trajectory``. A saved trajectory is
    gzip-compressed JSON (``<trajectory_id>.trajectory.gz`` by default) plus a
    plain ``<trajectory_id>.info.json`` file holding only the metadata.
    """

    def __init__(self, save_dir: str = "temp/nethack_trajectories"):
        """Create a recorder that saves into ``save_dir`` (created if missing)."""
        self.save_dir = Path(save_dir)
        self.save_dir.mkdir(parents=True, exist_ok=True)

        self.frames: List[TrajectoryFrame] = []
        self.metadata: Optional[TrajectoryMetadata] = None
        self.is_recording = False
        self.current_step = 0
        self.total_reward = 0.0
        self.max_depth = 1

    def start_recording(self, character_role: str, task_id: Optional[str] = None) -> str:
        """Start a new trajectory, discarding frames from any previous one.

        Returns:
            The trajectory id, ``"<YYYYmmdd_HHMMSS>_<character_role>"``.
        """
        # One timestamp for both the id and start_time so they always agree.
        started_at = datetime.now()
        trajectory_id = started_at.strftime("%Y%m%d_%H%M%S_") + character_role

        self.metadata = TrajectoryMetadata(
            trajectory_id=trajectory_id,
            character_role=character_role,
            task_id=task_id,
            start_time=started_at,
            end_time=None,
            total_steps=0,
            total_reward=0.0,
            final_status="in_progress",
            max_depth_reached=1,
            achievements={},
        )

        self.frames = []
        self.is_recording = True
        self.current_step = 0
        self.total_reward = 0.0
        self.max_depth = 1

        return trajectory_id

    def record_step(
        self,
        action: str,
        observation: Dict[str, Any],
        reward: float,
        done: bool,
        info: Dict[str, Any],
    ):
        """Append one step to the current trajectory.

        ``observation`` and ``info`` are snapshotted (deep-copied when
        possible), so later in-place changes by the caller to nested dicts or
        arrays do not alter frames that were already recorded.

        Raises:
            ValueError: If no recording is in progress.
        """
        if not self.is_recording:
            raise ValueError("Recording not started. Call start_recording first.")

        frame = TrajectoryFrame(
            step=self.current_step,
            action=action,
            observation=self._snapshot(observation),
            reward=reward,
            done=done,
            info=self._snapshot(info),
            timestamp=datetime.now().timestamp(),
        )

        self.frames.append(frame)
        self.current_step += 1
        self.total_reward += reward

        # Update max depth if the observation reports one
        depth = self._observed_depth(observation)
        if depth is not None:
            self.max_depth = max(self.max_depth, depth)

    @staticmethod
    def _snapshot(data: Dict[str, Any]) -> Dict[str, Any]:
        """Return an independent copy of ``data``.

        Falls back to a shallow ``dict.copy()`` when some value cannot be
        deep-copied.
        """
        import copy

        try:
            return copy.deepcopy(data)
        except (TypeError, copy.Error):
            return data.copy()

    @staticmethod
    def _observed_depth(observation: Optional[Dict[str, Any]]) -> Optional[int]:
        """Return ``observation["player_stats"]["depth"]`` (default 1).

        Returns None when the observation has no usable player stats or the
        depth is None.
        """
        stats = observation.get("player_stats") if observation else None
        if not stats:
            return None
        return stats.get("depth", 1)

    def stop_recording(
        self,
        final_status: str = "completed",
        achievements: Optional[Dict[str, bool]] = None,
    ):
        """Stop recording and copy the running totals into the metadata.

        Does nothing if no recording is in progress. ``achievements`` replaces
        the stored achievements only when it is non-empty.
        """
        if not self.is_recording:
            return

        self.is_recording = False

        if self.metadata:
            self.metadata.end_time = datetime.now()
            self.metadata.total_steps = self.current_step
            self.metadata.total_reward = self.total_reward
            self.metadata.final_status = final_status
            self.metadata.max_depth_reached = self.max_depth
            if achievements:
                self.metadata.achievements = achievements

    def save_trajectory(self, filename: Optional[str] = None) -> str:
        """Write the trajectory as gzip-compressed JSON and return its path.

        Args:
            filename: File name inside ``save_dir``; defaults to
                ``"<trajectory_id>.trajectory.gz"``.

        Raises:
            ValueError: If no trajectory has been started or loaded.
        """
        if not self.metadata:
            raise ValueError("No trajectory to save")

        if filename is None:
            filename = f"{self.metadata.trajectory_id}.trajectory.gz"

        filepath = self.save_dir / filename
        metadata_dict = self.metadata.to_dict()

        trajectory_data = {
            "metadata": metadata_dict,
            "frames": [frame.to_dict() for frame in self.frames],
        }

        # Save as compressed JSON
        with gzip.open(filepath, "wt", encoding="utf-8") as f:
            json.dump(trajectory_data, f, indent=2)

        # Also save a quick info file (metadata only)
        info_file = self.save_dir / f"{self.metadata.trajectory_id}.info.json"
        with open(info_file, "w", encoding="utf-8") as f:
            json.dump(metadata_dict, f, indent=2)

        return str(filepath)

    @classmethod
    def load_trajectory(
        cls, filepath: str
    ) -> Tuple["TrajectoryRecorder", TrajectoryMetadata, List[TrajectoryFrame]]:
        """Load a trajectory written by ``save_trajectory``.

        Returns:
            ``(recorder, metadata, frames)``. The recorder holds the loaded
            metadata and frames. Its running totals (step count, total
            reward, max depth) are rebuilt from them, so ``get_summary``
            describes the loaded trajectory.
        """
        recorder = cls()

        with gzip.open(filepath, "rt", encoding="utf-8") as f:
            data = json.load(f)

        metadata = TrajectoryMetadata.from_dict(data["metadata"])
        frames = [TrajectoryFrame.from_dict(frame) for frame in data["frames"]]

        recorder.metadata = metadata
        recorder.frames = frames
        # Restore running totals; frames are used (not only the metadata) so
        # trajectories saved before stop_recording are summarized correctly.
        recorder.current_step = len(frames)
        recorder.total_reward = sum((frame.reward for frame in frames), 0.0)
        depths = [cls._observed_depth(frame.observation) for frame in frames]
        recorder.max_depth = max(
            [metadata.max_depth_reached] + [depth for depth in depths if depth is not None]
        )

        return recorder, metadata, frames

    def get_summary(self) -> Dict[str, Any]:
        """Summarize the frames currently held (recorded or loaded).

        Returns an empty dict when there are no frames.
        """
        if not self.frames:
            return {}

        actions_taken: Dict[str, int] = {}
        for frame in self.frames:
            actions_taken[frame.action] = actions_taken.get(frame.action, 0) + 1

        return {
            "total_steps": len(self.frames),
            "total_reward": self.total_reward,
            "max_depth": self.max_depth,
            "actions_distribution": actions_taken,
            "unique_actions": len(actions_taken),
            # frames is non-empty here (guard above), so no zero division
            "average_reward_per_step": self.total_reward / len(self.frames),
        }
@@ -0,0 +1,308 @@
1
+ """Interactive replay viewer for NetHack trajectories."""
2
+
3
+ import os
4
+ import sys
5
+ from pathlib import Path
6
+ from typing import List, Dict, Any, Optional, Tuple
7
+ import json
8
+ import gzip
9
+ from datetime import datetime
10
+
11
+ # Add parent directory to path for imports
12
+ sys.path.append(str(Path(__file__).parent.parent.parent.parent.parent.parent))
13
+
14
+ from src.synth_env.examples.nethack.helpers.trajectory_recorder import (
15
+ TrajectoryRecorder,
16
+ TrajectoryFrame,
17
+ )
18
+ from src.synth_env.examples.nethack.helpers.visualization.visualizer import (
19
+ NetHackVisualizer,
20
+ )
21
+
22
+
class ReplayViewer:
    """Interactive console viewer for NetHack trajectory replays."""

    def __init__(self, trajectory_path: str):
        """Load ``trajectory_path`` and print a short summary of it.

        Args:
            trajectory_path: Path to a trajectory file as written by
                ``TrajectoryRecorder.save_trajectory`` (``.trajectory.gz``).
        """
        self.trajectory_path = Path(trajectory_path)

        # Load trajectory
        self.recorder, self.metadata, self.frames = TrajectoryRecorder.load_trajectory(
            str(trajectory_path)
        )

        # Initialize visualizer (used by the video/stats exports)
        self.visualizer = NetHackVisualizer()

        # Playback state
        self.current_frame = 0
        self.is_playing = False  # currently not read anywhere in this class
        self.playback_speed = 1.0  # Frames per second

        print(f"Loaded trajectory: {self.metadata.trajectory_id}")
        print(f"Character: {self.metadata.character_role}")
        print(f"Total steps: {self.metadata.total_steps}")
        print(f"Total reward: {self.metadata.total_reward:.2f}")
        print(f"Final status: {self.metadata.final_status}")
        print(f"Max depth reached: {self.metadata.max_depth_reached}")

    def show_frame(self, frame_idx: int):
        """Print one frame: step/action/reward, message, player stats and map.

        The map is cropped to rows ``y-10..y+10`` and columns ``x-20..x+20``
        around the player's position; the player's row is marked ``>>> <<<``.
        Indices outside ``[0, len(frames))`` are ignored.
        """
        if not 0 <= frame_idx < len(self.frames):
            return

        frame = self.frames[frame_idx]
        obs = frame.observation or {}

        print(f"\n{'=' * 80}")
        print(f"Frame {frame_idx}/{len(self.frames) - 1} | Step: {frame.step}")
        print(f"Action: {frame.action}")
        print(f"Reward: {frame.reward:+.2f}")

        # Show message; it may be missing or None in a recorded observation
        raw_message = obs.get("message")
        message = "" if raw_message is None else str(raw_message).strip()
        if message:
            print(f"Message: {message}")

        # Show stats; player_stats may be missing or None
        stats = obs.get("player_stats") or {}
        print(
            f"Position: ({stats.get('x', 0)}, {stats.get('y', 0)}) | "
            f"HP: {stats.get('hp', 0)}/{stats.get('max_hp', 0)} | "
            f"Level: {stats.get('experience_level', 1)} | "
            f"Depth: {stats.get('depth', 1)} | "
            f"Gold: {stats.get('gold', 0)}"
        )

        # Show map
        ascii_map = obs.get("ascii_map", "")
        if not ascii_map:
            return
        lines = ascii_map.split("\n")
        px, py = stats.get("x", 0), stats.get("y", 0)

        # Show area around player; the range bounds keep y inside lines
        print("\nMap view:")
        for y in range(max(0, py - 10), min(len(lines), py + 11)):
            line = lines[y]
            start = max(0, px - 20)
            end = min(len(line), px + 21)
            if y == py:
                print(f">>> {line[start:end]} <<<")
            else:
                print(f" {line[start:end]}")

    def interactive_replay(self):
        """Run the interactive command loop until 'q' or end of input."""
        print("\n=== Interactive Replay Viewer ===")
        print("Commands:")
        print(" n/next - Next frame")
        print(" p/prev - Previous frame")
        print(" g <num> - Go to frame number")
        print(" f/first - Go to first frame")
        print(" l/last - Go to last frame")
        print(" i/info - Show trajectory info")
        print(" s/search <action> - Find frames with action")
        print(" export <type> - Export (video/stats/actions)")
        print(" q/quit - Exit viewer")

        # Show first frame
        self.show_frame(0)

        while True:
            try:
                cmd = input(f"\nFrame {self.current_frame}> ").strip().lower()
                if not self._handle_command(cmd):
                    break
            except EOFError:
                # stdin is closed (Ctrl-D / exhausted pipe): input() would keep
                # raising, so end the session instead of looping forever.
                break
            except KeyboardInterrupt:
                print("\nUse 'q' to quit")
            except Exception as e:
                # Report per-command failures and keep the session alive
                print(f"Error: {e}")

    def _handle_command(self, cmd: str) -> bool:
        """Execute one (already lower-cased) viewer command.

        Returns False when the user asked to quit, True otherwise.
        """
        if not cmd:
            # Bare Enter: nothing to do
            return True

        if cmd in ["q", "quit"]:
            return False

        last_index = len(self.frames) - 1

        if cmd in ["n", "next"]:
            if self.current_frame < last_index:
                self.current_frame += 1
                self.show_frame(self.current_frame)
            else:
                print("Already at last frame")

        elif cmd in ["p", "prev"]:
            if self.current_frame > 0:
                self.current_frame -= 1
                self.show_frame(self.current_frame)
            else:
                print("Already at first frame")

        elif cmd.startswith("g "):
            try:
                frame_num = int(cmd.split()[1])
            except (ValueError, IndexError):
                print("Usage: g <frame_number>")
            else:
                if 0 <= frame_num <= last_index:
                    self.current_frame = frame_num
                    self.show_frame(self.current_frame)
                else:
                    print(f"Frame number must be between 0 and {len(self.frames) - 1}")

        elif cmd in ["f", "first"]:
            self.current_frame = 0
            self.show_frame(self.current_frame)

        elif cmd in ["l", "last"]:
            # Clamp so an empty trajectory does not leave current_frame at -1
            self.current_frame = max(last_index, 0)
            self.show_frame(self.current_frame)

        elif cmd in ["i", "info"]:
            self.show_trajectory_info()

        elif cmd.startswith("s ") or cmd.startswith("search "):
            action = " ".join(cmd.split()[1:])
            self.search_action(action)

        elif cmd.startswith("export"):
            parts = cmd.split()
            if len(parts) > 1:
                self.export_trajectory(parts[1])
            else:
                print("Usage: export <video|stats|actions>")

        else:
            print("Unknown command. Type 'q' to quit.")

        return True

    def show_trajectory_info(self):
        """Print trajectory metadata and the ten most frequent actions.

        Duration and average reward are shown as ``N/A`` when the trajectory
        has no end time or no steps (e.g. it was saved before stop_recording).
        """
        meta = self.metadata
        print("\n=== Trajectory Information ===")
        print(f"ID: {meta.trajectory_id}")
        print(f"Character: {meta.character_role}")
        print(f"Task ID: {meta.task_id or 'N/A'}")
        print(f"Start time: {meta.start_time}")
        print(f"End time: {meta.end_time}")
        if meta.end_time is not None:
            duration = (meta.end_time - meta.start_time).total_seconds()
            print(f"Duration: {duration:.1f} seconds")
        else:
            print("Duration: N/A")
        print(f"Total steps: {meta.total_steps}")
        print(f"Total reward: {meta.total_reward:.2f}")
        if meta.total_steps:
            print(f"Average reward/step: {meta.total_reward / meta.total_steps:.4f}")
        else:
            print("Average reward/step: N/A")
        print(f"Final status: {meta.final_status}")
        print(f"Max depth: {meta.max_depth_reached}")

        # Action distribution
        action_counts: Dict[str, int] = {}
        for frame in self.frames:
            action_counts[frame.action] = action_counts.get(frame.action, 0) + 1

        print("\nTop 10 actions:")
        total_frames = len(self.frames)
        for action, count in sorted(action_counts.items(), key=lambda x: x[1], reverse=True)[:10]:
            print(f" {action}: {count} ({count / total_frames * 100:.1f}%)")

    def search_action(self, action: str):
        """List up to 10 frames whose action contains ``action`` (case-insensitive)."""
        needle = action.lower()
        matches = [i for i, frame in enumerate(self.frames) if needle in frame.action.lower()]

        if not matches:
            print(f"No frames found with action '{action}'")
            return

        print(f"Found {len(matches)} frames with action '{action}':")
        for frame_idx in matches[:10]:  # Show first 10
            frame = self.frames[frame_idx]
            print(f" Frame {frame_idx}: {frame.action} (reward: {frame.reward:+.2f})")

        if len(matches) > 10:
            print(f" ... and {len(matches) - 10} more")

    def export_trajectory(self, export_type: str):
        """Export the trajectory into an ``exports`` directory next to the file.

        Args:
            export_type: ``"video"`` (an .mp4 via the visualizer), ``"stats"``
                (PNG plots via the visualizer) or ``"actions"`` (plain-text
                action list). Any other value prints the available types.
        """
        output_dir = self.trajectory_path.parent / "exports"
        output_dir.mkdir(exist_ok=True)

        if export_type == "video":
            print("Creating video...")
            output_path = output_dir / f"{self.metadata.trajectory_id}.mp4"

            # Convert frames to format expected by visualizer
            vis_frames = [
                {"action": frame.action, "observation": frame.observation}
                for frame in self.frames
            ]

            try:
                video_path = self.visualizer.create_trajectory_video(
                    vis_frames, str(output_path), fps=4, include_stats=True
                )
                print(f"Video saved to: {video_path}")
            except Exception as e:
                print(f"Error creating video: {e}")
                print("Make sure ffmpeg is installed for video export")

        elif export_type == "stats":
            print("Creating statistics plots...")
            stats_path = output_dir / f"{self.metadata.trajectory_id}_stats.png"
            action_path = output_dir / f"{self.metadata.trajectory_id}_actions.png"

            # Convert frames for visualizer
            vis_frames = [
                {
                    "action": frame.action,
                    "observation": frame.observation,
                    "reward": frame.reward,
                }
                for frame in self.frames
            ]

            self.visualizer.plot_trajectory_stats(vis_frames, str(stats_path))
            self.visualizer.plot_action_distribution(vis_frames, str(action_path))
            print(f"Stats saved to: {stats_path}")
            print(f"Action distribution saved to: {action_path}")

        elif export_type == "actions":
            print("Exporting action sequence...")
            actions_path = output_dir / f"{self.metadata.trajectory_id}_actions.txt"

            with open(actions_path, "w", encoding="utf-8") as f:
                f.write(f"# Trajectory: {self.metadata.trajectory_id}\n")
                f.write(f"# Character: {self.metadata.character_role}\n")
                f.write(f"# Total steps: {self.metadata.total_steps}\n")
                f.write(f"# Total reward: {self.metadata.total_reward}\n\n")

                for frame in self.frames:
                    f.write(f"{frame.step:4d}: {frame.action:15s} (reward: {frame.reward:+.2f})\n")

            print(f"Action sequence saved to: {actions_path}")

        else:
            print(f"Unknown export type: {export_type}")
            print("Available types: video, stats, actions")
279
+
280
+
def main():
    """Command-line entry point for the trajectory replay viewer.

    With ``--export`` the trajectory is exported non-interactively; otherwise
    an interactive replay session is started. A missing trajectory file
    prints an error and exits with status 1.
    """
    import argparse

    cli = argparse.ArgumentParser(description="NetHack Trajectory Replay Viewer")
    cli.add_argument("trajectory", help="Path to trajectory file (.trajectory.gz)")
    cli.add_argument(
        "--export",
        choices=["video", "stats", "actions"],
        help="Export trajectory without interactive mode",
    )
    options = cli.parse_args()

    trajectory_file = Path(options.trajectory)
    if not trajectory_file.exists():
        print(f"Error: Trajectory file not found: {options.trajectory}")
        sys.exit(1)

    viewer = ReplayViewer(options.trajectory)

    if not options.export:
        viewer.interactive_replay()
        return
    viewer.export_trajectory(options.export)


if __name__ == "__main__":
    main()