synth-ai 0.2.4.dev3__py3-none-any.whl → 0.2.4.dev5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- synth_ai/environments/examples/__init__.py +1 -0
- synth_ai/environments/examples/crafter_classic/__init__.py +8 -0
- synth_ai/environments/examples/crafter_classic/config_logging.py +111 -0
- synth_ai/environments/examples/crafter_classic/debug_translation.py +0 -0
- synth_ai/environments/examples/crafter_classic/engine.py +575 -0
- synth_ai/environments/examples/crafter_classic/engine_deterministic_patch.py +63 -0
- synth_ai/environments/examples/crafter_classic/engine_helpers/action_map.py +5 -0
- synth_ai/environments/examples/crafter_classic/engine_helpers/serialization.py +74 -0
- synth_ai/environments/examples/crafter_classic/engine_serialization_patch_v3.py +266 -0
- synth_ai/environments/examples/crafter_classic/environment.py +364 -0
- synth_ai/environments/examples/crafter_classic/taskset.py +233 -0
- synth_ai/environments/examples/crafter_classic/trace_hooks_v3.py +229 -0
- synth_ai/environments/examples/crafter_classic/world_config_patch_simple.py +298 -0
- synth_ai/environments/examples/crafter_custom/__init__.py +4 -0
- synth_ai/environments/examples/crafter_custom/crafter/__init__.py +7 -0
- synth_ai/environments/examples/crafter_custom/crafter/config.py +182 -0
- synth_ai/environments/examples/crafter_custom/crafter/constants.py +8 -0
- synth_ai/environments/examples/crafter_custom/crafter/engine.py +269 -0
- synth_ai/environments/examples/crafter_custom/crafter/env.py +266 -0
- synth_ai/environments/examples/crafter_custom/crafter/objects.py +418 -0
- synth_ai/environments/examples/crafter_custom/crafter/recorder.py +187 -0
- synth_ai/environments/examples/crafter_custom/crafter/worldgen.py +119 -0
- synth_ai/environments/examples/crafter_custom/dataset_builder.py +373 -0
- synth_ai/environments/examples/crafter_custom/environment.py +312 -0
- synth_ai/environments/examples/crafter_custom/run_dataset.py +305 -0
- synth_ai/environments/examples/enron/art_helpers/email_search_tools.py +156 -0
- synth_ai/environments/examples/enron/art_helpers/local_email_db.py +280 -0
- synth_ai/environments/examples/enron/art_helpers/types_enron.py +24 -0
- synth_ai/environments/examples/enron/engine.py +291 -0
- synth_ai/environments/examples/enron/environment.py +165 -0
- synth_ai/environments/examples/enron/taskset.py +112 -0
- synth_ai/environments/examples/minigrid/__init__.py +48 -0
- synth_ai/environments/examples/minigrid/engine.py +589 -0
- synth_ai/environments/examples/minigrid/environment.py +274 -0
- synth_ai/environments/examples/minigrid/environment_mapping.py +242 -0
- synth_ai/environments/examples/minigrid/puzzle_loader.py +416 -0
- synth_ai/environments/examples/minigrid/taskset.py +583 -0
- synth_ai/environments/examples/nethack/__init__.py +7 -0
- synth_ai/environments/examples/nethack/achievements.py +337 -0
- synth_ai/environments/examples/nethack/engine.py +738 -0
- synth_ai/environments/examples/nethack/environment.py +255 -0
- synth_ai/environments/examples/nethack/helpers/__init__.py +42 -0
- synth_ai/environments/examples/nethack/helpers/action_mapping.py +301 -0
- synth_ai/environments/examples/nethack/helpers/nle_wrapper.py +401 -0
- synth_ai/environments/examples/nethack/helpers/observation_utils.py +433 -0
- synth_ai/environments/examples/nethack/helpers/recording_wrapper.py +201 -0
- synth_ai/environments/examples/nethack/helpers/trajectory_recorder.py +268 -0
- synth_ai/environments/examples/nethack/helpers/visualization/replay_viewer.py +308 -0
- synth_ai/environments/examples/nethack/helpers/visualization/visualizer.py +430 -0
- synth_ai/environments/examples/nethack/taskset.py +323 -0
- synth_ai/environments/examples/red/__init__.py +7 -0
- synth_ai/environments/examples/red/config_logging.py +110 -0
- synth_ai/environments/examples/red/engine.py +693 -0
- synth_ai/environments/examples/red/engine_helpers/__init__.py +1 -0
- synth_ai/environments/examples/red/engine_helpers/memory_map.py +28 -0
- synth_ai/environments/examples/red/engine_helpers/reward_components.py +275 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/__init__.py +142 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/adaptive_rewards.py +56 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/battle_rewards.py +283 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/composite_rewards.py +149 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/economy_rewards.py +137 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/efficiency_rewards.py +56 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/exploration_rewards.py +330 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/novelty_rewards.py +120 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/pallet_town_rewards.py +558 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/pokemon_rewards.py +312 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/social_rewards.py +147 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/story_rewards.py +246 -0
- synth_ai/environments/examples/red/engine_helpers/screen_analysis.py +367 -0
- synth_ai/environments/examples/red/engine_helpers/state_extraction.py +139 -0
- synth_ai/environments/examples/red/environment.py +235 -0
- synth_ai/environments/examples/red/taskset.py +77 -0
- synth_ai/environments/examples/sokoban/__init__.py +1 -0
- synth_ai/environments/examples/sokoban/engine.py +675 -0
- synth_ai/environments/examples/sokoban/engine_helpers/__init__.py +1 -0
- synth_ai/environments/examples/sokoban/engine_helpers/room_utils.py +656 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/__init__.py +17 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/__init__.py +3 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/boxoban_env.py +129 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/render_utils.py +370 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/room_utils.py +331 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env.py +305 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_fixed_targets.py +66 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_pull.py +114 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_two_player.py +122 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_variations.py +394 -0
- synth_ai/environments/examples/sokoban/environment.py +228 -0
- synth_ai/environments/examples/sokoban/generate_verified_puzzles.py +438 -0
- synth_ai/environments/examples/sokoban/puzzle_loader.py +311 -0
- synth_ai/environments/examples/sokoban/taskset.py +425 -0
- synth_ai/environments/examples/tictactoe/__init__.py +1 -0
- synth_ai/environments/examples/tictactoe/engine.py +368 -0
- synth_ai/environments/examples/tictactoe/environment.py +239 -0
- synth_ai/environments/examples/tictactoe/taskset.py +214 -0
- synth_ai/environments/examples/verilog/__init__.py +10 -0
- synth_ai/environments/examples/verilog/engine.py +328 -0
- synth_ai/environments/examples/verilog/environment.py +349 -0
- synth_ai/environments/examples/verilog/taskset.py +418 -0
- synth_ai/tracing_v3/examples/basic_usage.py +188 -0
- {synth_ai-0.2.4.dev3.dist-info → synth_ai-0.2.4.dev5.dist-info}/METADATA +1 -1
- {synth_ai-0.2.4.dev3.dist-info → synth_ai-0.2.4.dev5.dist-info}/RECORD +105 -6
- {synth_ai-0.2.4.dev3.dist-info → synth_ai-0.2.4.dev5.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.4.dev3.dist-info → synth_ai-0.2.4.dev5.dist-info}/entry_points.txt +0 -0
- {synth_ai-0.2.4.dev3.dist-info → synth_ai-0.2.4.dev5.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.2.4.dev3.dist-info → synth_ai-0.2.4.dev5.dist-info}/top_level.txt +0 -0
synth_ai/environments/examples/nethack/helpers/trajectory_recorder.py
@@ -0,0 +1,268 @@
"""Trajectory recording and replay functionality for NetHack."""

import json
import pickle
import gzip
from pathlib import Path
from typing import Dict, Any, List, Optional, Tuple
from datetime import datetime
import numpy as np
from dataclasses import dataclass, asdict
import base64


@dataclass
class TrajectoryFrame:
    """Single frame in a trajectory."""

    step: int
    action: str
    observation: Dict[str, Any]
    reward: float
    done: bool
    info: Dict[str, Any]
    timestamp: float

    def to_dict(self) -> Dict[str, Any]:
        """Convert to serializable dict."""
        d = asdict(self)
        # Handle numpy arrays in observation
        if "observation" in d and d["observation"]:
            d["observation"] = self._serialize_observation(d["observation"])
        return d

    def _serialize_observation(self, obs: Dict[str, Any]) -> Dict[str, Any]:
        """Serialize observation, converting numpy arrays to lists."""
        serialized = {}
        for key, value in obs.items():
            if isinstance(value, np.ndarray):
                serialized[key] = {
                    "type": "ndarray",
                    "data": value.tolist(),
                    "dtype": str(value.dtype),
                    "shape": value.shape,
                }
            elif isinstance(value, dict):
                serialized[key] = self._serialize_observation(value)
            elif isinstance(value, (list, tuple)) and len(value) > 0 and isinstance(value[0], dict):
                serialized[key] = [
                    self._serialize_observation(item) if isinstance(item, dict) else item
                    for item in value
                ]
            else:
                serialized[key] = value
        return serialized

    @classmethod
    def from_dict(cls, d: Dict[str, Any]) -> "TrajectoryFrame":
        """Reconstruct from dict."""
        if "observation" in d and d["observation"]:
            d["observation"] = cls._deserialize_observation(d["observation"])
        return cls(**d)

    @staticmethod
    def _deserialize_observation(obs: Dict[str, Any]) -> Dict[str, Any]:
        """Deserialize observation, converting lists back to numpy arrays."""
        deserialized = {}
        for key, value in obs.items():
            if isinstance(value, dict) and value.get("type") == "ndarray":
                deserialized[key] = np.array(value["data"], dtype=value["dtype"]).reshape(
                    value["shape"]
                )
            elif isinstance(value, dict):
                deserialized[key] = TrajectoryFrame._deserialize_observation(value)
            elif isinstance(value, list) and len(value) > 0 and isinstance(value[0], dict):
                deserialized[key] = [
                    TrajectoryFrame._deserialize_observation(item)
                    if isinstance(item, dict)
                    else item
                    for item in value
                ]
            else:
                deserialized[key] = value
        return deserialized


@dataclass
class TrajectoryMetadata:
    """Metadata for a trajectory."""

    trajectory_id: str
    character_role: str
    task_id: Optional[str]
    start_time: datetime
    end_time: Optional[datetime]
    total_steps: int
    total_reward: float
    final_status: str  # 'completed', 'died', 'quit', 'truncated'
    max_depth_reached: int
    achievements: Dict[str, bool]

    def to_dict(self) -> Dict[str, Any]:
        """Convert to serializable dict."""
        d = asdict(self)
        d["start_time"] = self.start_time.isoformat()
        if self.end_time:
            d["end_time"] = self.end_time.isoformat()
        return d

    @classmethod
    def from_dict(cls, d: Dict[str, Any]) -> "TrajectoryMetadata":
        """Reconstruct from dict."""
        d["start_time"] = datetime.fromisoformat(d["start_time"])
        if d.get("end_time"):
            d["end_time"] = datetime.fromisoformat(d["end_time"])
        return cls(**d)


class TrajectoryRecorder:
    """Records and saves NetHack game trajectories."""

    def __init__(self, save_dir: str = "temp/nethack_trajectories"):
        self.save_dir = Path(save_dir)
        self.save_dir.mkdir(parents=True, exist_ok=True)

        self.frames: List[TrajectoryFrame] = []
        self.metadata: Optional[TrajectoryMetadata] = None
        self.is_recording = False
        self.current_step = 0
        self.total_reward = 0.0
        self.max_depth = 1

    def start_recording(self, character_role: str, task_id: Optional[str] = None) -> str:
        """Start recording a new trajectory."""
        trajectory_id = datetime.now().strftime("%Y%m%d_%H%M%S_") + character_role

        self.metadata = TrajectoryMetadata(
            trajectory_id=trajectory_id,
            character_role=character_role,
            task_id=task_id,
            start_time=datetime.now(),
            end_time=None,
            total_steps=0,
            total_reward=0.0,
            final_status="in_progress",
            max_depth_reached=1,
            achievements={},
        )

        self.frames = []
        self.is_recording = True
        self.current_step = 0
        self.total_reward = 0.0
        self.max_depth = 1

        return trajectory_id

    def record_step(
        self,
        action: str,
        observation: Dict[str, Any],
        reward: float,
        done: bool,
        info: Dict[str, Any],
    ):
        """Record a single step."""
        if not self.is_recording:
            raise ValueError("Recording not started. Call start_recording first.")

        frame = TrajectoryFrame(
            step=self.current_step,
            action=action,
            observation=observation.copy(),
            reward=reward,
            done=done,
            info=info.copy(),
            timestamp=datetime.now().timestamp(),
        )

        self.frames.append(frame)
        self.current_step += 1
        self.total_reward += reward

        # Update max depth if available
        if "player_stats" in observation:
            depth = observation["player_stats"].get("depth", 1)
            self.max_depth = max(self.max_depth, depth)

    def stop_recording(
        self,
        final_status: str = "completed",
        achievements: Optional[Dict[str, bool]] = None,
    ):
        """Stop recording and finalize metadata."""
        if not self.is_recording:
            return

        self.is_recording = False

        if self.metadata:
            self.metadata.end_time = datetime.now()
            self.metadata.total_steps = self.current_step
            self.metadata.total_reward = self.total_reward
            self.metadata.final_status = final_status
            self.metadata.max_depth_reached = self.max_depth
            if achievements:
                self.metadata.achievements = achievements

    def save_trajectory(self, filename: Optional[str] = None) -> str:
        """Save trajectory to disk."""
        if not self.metadata:
            raise ValueError("No trajectory to save")

        if filename is None:
            filename = f"{self.metadata.trajectory_id}.trajectory.gz"

        filepath = self.save_dir / filename

        trajectory_data = {
            "metadata": self.metadata.to_dict(),
            "frames": [frame.to_dict() for frame in self.frames],
        }

        # Save as compressed JSON
        with gzip.open(filepath, "wt", encoding="utf-8") as f:
            json.dump(trajectory_data, f, indent=2)

        # Also save a quick info file
        info_file = self.save_dir / f"{self.metadata.trajectory_id}.info.json"
        with open(info_file, "w") as f:
            json.dump(self.metadata.to_dict(), f, indent=2)

        return str(filepath)

    @classmethod
    def load_trajectory(
        cls, filepath: str
    ) -> Tuple["TrajectoryRecorder", TrajectoryMetadata, List[TrajectoryFrame]]:
        """Load a trajectory from disk."""
        recorder = cls()

        with gzip.open(filepath, "rt", encoding="utf-8") as f:
            data = json.load(f)

        metadata = TrajectoryMetadata.from_dict(data["metadata"])
        frames = [TrajectoryFrame.from_dict(frame) for frame in data["frames"]]

        recorder.metadata = metadata
        recorder.frames = frames

        return recorder, metadata, frames

    def get_summary(self) -> Dict[str, Any]:
        """Get summary statistics of the current trajectory."""
        if not self.frames:
            return {}

        actions_taken = {}
        for frame in self.frames:
            actions_taken[frame.action] = actions_taken.get(frame.action, 0) + 1

        return {
            "total_steps": len(self.frames),
            "total_reward": self.total_reward,
            "max_depth": self.max_depth,
            "actions_distribution": actions_taken,
            "unique_actions": len(actions_taken),
            "average_reward_per_step": self.total_reward / len(self.frames) if self.frames else 0,
        }
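For orientation, a minimal usage sketch of the recorder added above. The import path follows the wheel layout in the file list; the character role, task id, and observation keys ("glyphs", "player_stats") are illustrative assumptions, not a documented contract.

import numpy as np
from synth_ai.environments.examples.nethack.helpers.trajectory_recorder import TrajectoryRecorder

recorder = TrajectoryRecorder(save_dir="temp/nethack_trajectories")
recorder.start_recording(character_role="valkyrie", task_id="demo-task")

# One hand-rolled step; a real loop would feed environment observations here.
observation = {
    "glyphs": np.zeros((21, 79), dtype=np.int32),  # stored as a tagged ndarray dict on save
    "player_stats": {"depth": 2, "hp": 14},        # "depth" drives max_depth_reached
}
recorder.record_step(action="move_north", observation=observation,
                     reward=0.1, done=False, info={})

recorder.stop_recording(final_status="completed", achievements={"reached_depth_2": True})
path = recorder.save_trajectory()  # writes <id>.trajectory.gz plus <id>.info.json
print(path, recorder.get_summary())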
synth_ai/environments/examples/nethack/helpers/visualization/replay_viewer.py
@@ -0,0 +1,308 @@
"""Interactive replay viewer for NetHack trajectories."""

import os
import sys
from pathlib import Path
from typing import List, Dict, Any, Optional, Tuple
import json
import gzip
from datetime import datetime

# Add parent directory to path for imports
sys.path.append(str(Path(__file__).parent.parent.parent.parent.parent.parent))

from src.synth_env.examples.nethack.helpers.trajectory_recorder import (
    TrajectoryRecorder,
    TrajectoryFrame,
)
from src.synth_env.examples.nethack.helpers.visualization.visualizer import (
    NetHackVisualizer,
)


class ReplayViewer:
    """Interactive viewer for NetHack trajectory replays."""

    def __init__(self, trajectory_path: str):
        """Initialize replay viewer with a trajectory file."""
        self.trajectory_path = Path(trajectory_path)

        # Load trajectory
        self.recorder, self.metadata, self.frames = TrajectoryRecorder.load_trajectory(
            str(trajectory_path)
        )

        # Initialize visualizer
        self.visualizer = NetHackVisualizer()

        # Playback state
        self.current_frame = 0
        self.is_playing = False
        self.playback_speed = 1.0  # Frames per second

        print(f"Loaded trajectory: {self.metadata.trajectory_id}")
        print(f"Character: {self.metadata.character_role}")
        print(f"Total steps: {self.metadata.total_steps}")
        print(f"Total reward: {self.metadata.total_reward:.2f}")
        print(f"Final status: {self.metadata.final_status}")
        print(f"Max depth reached: {self.metadata.max_depth_reached}")

    def show_frame(self, frame_idx: int):
        """Display a specific frame."""
        if 0 <= frame_idx < len(self.frames):
            frame = self.frames[frame_idx]
            obs = frame.observation

            print(f"\n{'=' * 80}")
            print(f"Frame {frame_idx}/{len(self.frames) - 1} | Step: {frame.step}")
            print(f"Action: {frame.action}")
            print(f"Reward: {frame.reward:+.2f}")

            # Show message
            message = obs.get("message", "").strip()
            if message:
                print(f"Message: {message}")

            # Show stats
            stats = obs.get("player_stats", {})
            print(
                f"Position: ({stats.get('x', 0)}, {stats.get('y', 0)}) | "
                f"HP: {stats.get('hp', 0)}/{stats.get('max_hp', 0)} | "
                f"Level: {stats.get('experience_level', 1)} | "
                f"Depth: {stats.get('depth', 1)} | "
                f"Gold: {stats.get('gold', 0)}"
            )

            # Show map
            ascii_map = obs.get("ascii_map", "")
            if ascii_map:
                lines = ascii_map.split("\n")
                px, py = stats.get("x", 0), stats.get("y", 0)

                # Show area around player
                print("\nMap view:")
                for y in range(max(0, py - 10), min(len(lines), py + 11)):
                    if 0 <= y < len(lines):
                        line = lines[y]
                        start = max(0, px - 20)
                        end = min(len(line), px + 21)
                        if y == py:
                            print(f">>> {line[start:end]} <<<")
                        else:
                            print(f"    {line[start:end]}")

    def interactive_replay(self):
        """Run interactive replay session."""
        print("\n=== Interactive Replay Viewer ===")
        print("Commands:")
        print("  n/next - Next frame")
        print("  p/prev - Previous frame")
        print("  g <num> - Go to frame number")
        print("  f/first - Go to first frame")
        print("  l/last - Go to last frame")
        print("  i/info - Show trajectory info")
        print("  s/search <action> - Find frames with action")
        print("  export <type> - Export (video/stats/actions)")
        print("  q/quit - Exit viewer")

        # Show first frame
        self.show_frame(0)

        while True:
            try:
                cmd = input(f"\nFrame {self.current_frame}> ").strip().lower()

                if cmd in ["q", "quit"]:
                    break

                elif cmd in ["n", "next"]:
                    if self.current_frame < len(self.frames) - 1:
                        self.current_frame += 1
                        self.show_frame(self.current_frame)
                    else:
                        print("Already at last frame")

                elif cmd in ["p", "prev"]:
                    if self.current_frame > 0:
                        self.current_frame -= 1
                        self.show_frame(self.current_frame)
                    else:
                        print("Already at first frame")

                elif cmd.startswith("g "):
                    try:
                        frame_num = int(cmd.split()[1])
                        if 0 <= frame_num < len(self.frames):
                            self.current_frame = frame_num
                            self.show_frame(self.current_frame)
                        else:
                            print(f"Frame number must be between 0 and {len(self.frames) - 1}")
                    except (ValueError, IndexError):
                        print("Usage: g <frame_number>")

                elif cmd in ["f", "first"]:
                    self.current_frame = 0
                    self.show_frame(self.current_frame)

                elif cmd in ["l", "last"]:
                    self.current_frame = len(self.frames) - 1
                    self.show_frame(self.current_frame)

                elif cmd in ["i", "info"]:
                    self.show_trajectory_info()

                elif cmd.startswith("s ") or cmd.startswith("search "):
                    action = " ".join(cmd.split()[1:])
                    self.search_action(action)

                elif cmd.startswith("export"):
                    parts = cmd.split()
                    if len(parts) > 1:
                        self.export_trajectory(parts[1])
                    else:
                        print("Usage: export <video|stats|actions>")

                else:
                    print("Unknown command. Type 'q' to quit.")

            except KeyboardInterrupt:
                print("\nUse 'q' to quit")
            except Exception as e:
                print(f"Error: {e}")

    def show_trajectory_info(self):
        """Display trajectory information and statistics."""
        print(f"\n=== Trajectory Information ===")
        print(f"ID: {self.metadata.trajectory_id}")
        print(f"Character: {self.metadata.character_role}")
        print(f"Task ID: {self.metadata.task_id or 'N/A'}")
        print(f"Start time: {self.metadata.start_time}")
        print(f"End time: {self.metadata.end_time}")
        print(
            f"Duration: {(self.metadata.end_time - self.metadata.start_time).total_seconds():.1f} seconds"
        )
        print(f"Total steps: {self.metadata.total_steps}")
        print(f"Total reward: {self.metadata.total_reward:.2f}")
        print(f"Average reward/step: {self.metadata.total_reward / self.metadata.total_steps:.4f}")
        print(f"Final status: {self.metadata.final_status}")
        print(f"Max depth: {self.metadata.max_depth_reached}")

        # Action distribution
        action_counts = {}
        for frame in self.frames:
            action_counts[frame.action] = action_counts.get(frame.action, 0) + 1

        print(f"\nTop 10 actions:")
        for action, count in sorted(action_counts.items(), key=lambda x: x[1], reverse=True)[:10]:
            print(f"  {action}: {count} ({count / len(self.frames) * 100:.1f}%)")

    def search_action(self, action: str):
        """Search for frames containing specific action."""
        matches = []
        for i, frame in enumerate(self.frames):
            if action.lower() in frame.action.lower():
                matches.append(i)

        if matches:
            print(f"Found {len(matches)} frames with action '{action}':")
            for i, frame_idx in enumerate(matches[:10]):  # Show first 10
                frame = self.frames[frame_idx]
                print(f"  Frame {frame_idx}: {frame.action} (reward: {frame.reward:+.2f})")

            if len(matches) > 10:
                print(f"  ... and {len(matches) - 10} more")
        else:
            print(f"No frames found with action '{action}'")

    def export_trajectory(self, export_type: str):
        """Export trajectory in various formats."""
        output_dir = self.trajectory_path.parent / "exports"
        output_dir.mkdir(exist_ok=True)

        if export_type == "video":
            print("Creating video...")
            output_path = output_dir / f"{self.metadata.trajectory_id}.mp4"

            # Convert frames to format expected by visualizer
            vis_frames = []
            for frame in self.frames:
                vis_frames.append({"action": frame.action, "observation": frame.observation})

            try:
                video_path = self.visualizer.create_trajectory_video(
                    vis_frames, str(output_path), fps=4, include_stats=True
                )
                print(f"Video saved to: {video_path}")
            except Exception as e:
                print(f"Error creating video: {e}")
                print("Make sure ffmpeg is installed for video export")

        elif export_type == "stats":
            print("Creating statistics plots...")
            stats_path = output_dir / f"{self.metadata.trajectory_id}_stats.png"
            action_path = output_dir / f"{self.metadata.trajectory_id}_actions.png"

            # Convert frames for visualizer
            vis_frames = []
            for frame in self.frames:
                vis_frames.append(
                    {
                        "action": frame.action,
                        "observation": frame.observation,
                        "reward": frame.reward,
                    }
                )

            self.visualizer.plot_trajectory_stats(vis_frames, str(stats_path))
            self.visualizer.plot_action_distribution(vis_frames, str(action_path))
            print(f"Stats saved to: {stats_path}")
            print(f"Action distribution saved to: {action_path}")

        elif export_type == "actions":
            print("Exporting action sequence...")
            actions_path = output_dir / f"{self.metadata.trajectory_id}_actions.txt"

            with open(actions_path, "w") as f:
                f.write(f"# Trajectory: {self.metadata.trajectory_id}\n")
                f.write(f"# Character: {self.metadata.character_role}\n")
                f.write(f"# Total steps: {self.metadata.total_steps}\n")
                f.write(f"# Total reward: {self.metadata.total_reward}\n\n")

                for frame in self.frames:
                    f.write(f"{frame.step:4d}: {frame.action:15s} (reward: {frame.reward:+.2f})\n")

            print(f"Action sequence saved to: {actions_path}")

        else:
            print(f"Unknown export type: {export_type}")
            print("Available types: video, stats, actions")


def main():
    """Main entry point for replay viewer."""
    import argparse

    parser = argparse.ArgumentParser(description="NetHack Trajectory Replay Viewer")
    parser.add_argument("trajectory", help="Path to trajectory file (.trajectory.gz)")
    parser.add_argument(
        "--export",
        choices=["video", "stats", "actions"],
        help="Export trajectory without interactive mode",
    )

    args = parser.parse_args()

    if not Path(args.trajectory).exists():
        print(f"Error: Trajectory file not found: {args.trajectory}")
        sys.exit(1)

    viewer = ReplayViewer(args.trajectory)

    if args.export:
        viewer.export_trajectory(args.export)
    else:
        viewer.interactive_replay()


if __name__ == "__main__":
    main()