synth-ai 0.2.4.dev3__py3-none-any.whl → 0.2.4.dev5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- synth_ai/environments/examples/__init__.py +1 -0
- synth_ai/environments/examples/crafter_classic/__init__.py +8 -0
- synth_ai/environments/examples/crafter_classic/config_logging.py +111 -0
- synth_ai/environments/examples/crafter_classic/debug_translation.py +0 -0
- synth_ai/environments/examples/crafter_classic/engine.py +575 -0
- synth_ai/environments/examples/crafter_classic/engine_deterministic_patch.py +63 -0
- synth_ai/environments/examples/crafter_classic/engine_helpers/action_map.py +5 -0
- synth_ai/environments/examples/crafter_classic/engine_helpers/serialization.py +74 -0
- synth_ai/environments/examples/crafter_classic/engine_serialization_patch_v3.py +266 -0
- synth_ai/environments/examples/crafter_classic/environment.py +364 -0
- synth_ai/environments/examples/crafter_classic/taskset.py +233 -0
- synth_ai/environments/examples/crafter_classic/trace_hooks_v3.py +229 -0
- synth_ai/environments/examples/crafter_classic/world_config_patch_simple.py +298 -0
- synth_ai/environments/examples/crafter_custom/__init__.py +4 -0
- synth_ai/environments/examples/crafter_custom/crafter/__init__.py +7 -0
- synth_ai/environments/examples/crafter_custom/crafter/config.py +182 -0
- synth_ai/environments/examples/crafter_custom/crafter/constants.py +8 -0
- synth_ai/environments/examples/crafter_custom/crafter/engine.py +269 -0
- synth_ai/environments/examples/crafter_custom/crafter/env.py +266 -0
- synth_ai/environments/examples/crafter_custom/crafter/objects.py +418 -0
- synth_ai/environments/examples/crafter_custom/crafter/recorder.py +187 -0
- synth_ai/environments/examples/crafter_custom/crafter/worldgen.py +119 -0
- synth_ai/environments/examples/crafter_custom/dataset_builder.py +373 -0
- synth_ai/environments/examples/crafter_custom/environment.py +312 -0
- synth_ai/environments/examples/crafter_custom/run_dataset.py +305 -0
- synth_ai/environments/examples/enron/art_helpers/email_search_tools.py +156 -0
- synth_ai/environments/examples/enron/art_helpers/local_email_db.py +280 -0
- synth_ai/environments/examples/enron/art_helpers/types_enron.py +24 -0
- synth_ai/environments/examples/enron/engine.py +291 -0
- synth_ai/environments/examples/enron/environment.py +165 -0
- synth_ai/environments/examples/enron/taskset.py +112 -0
- synth_ai/environments/examples/minigrid/__init__.py +48 -0
- synth_ai/environments/examples/minigrid/engine.py +589 -0
- synth_ai/environments/examples/minigrid/environment.py +274 -0
- synth_ai/environments/examples/minigrid/environment_mapping.py +242 -0
- synth_ai/environments/examples/minigrid/puzzle_loader.py +416 -0
- synth_ai/environments/examples/minigrid/taskset.py +583 -0
- synth_ai/environments/examples/nethack/__init__.py +7 -0
- synth_ai/environments/examples/nethack/achievements.py +337 -0
- synth_ai/environments/examples/nethack/engine.py +738 -0
- synth_ai/environments/examples/nethack/environment.py +255 -0
- synth_ai/environments/examples/nethack/helpers/__init__.py +42 -0
- synth_ai/environments/examples/nethack/helpers/action_mapping.py +301 -0
- synth_ai/environments/examples/nethack/helpers/nle_wrapper.py +401 -0
- synth_ai/environments/examples/nethack/helpers/observation_utils.py +433 -0
- synth_ai/environments/examples/nethack/helpers/recording_wrapper.py +201 -0
- synth_ai/environments/examples/nethack/helpers/trajectory_recorder.py +268 -0
- synth_ai/environments/examples/nethack/helpers/visualization/replay_viewer.py +308 -0
- synth_ai/environments/examples/nethack/helpers/visualization/visualizer.py +430 -0
- synth_ai/environments/examples/nethack/taskset.py +323 -0
- synth_ai/environments/examples/red/__init__.py +7 -0
- synth_ai/environments/examples/red/config_logging.py +110 -0
- synth_ai/environments/examples/red/engine.py +693 -0
- synth_ai/environments/examples/red/engine_helpers/__init__.py +1 -0
- synth_ai/environments/examples/red/engine_helpers/memory_map.py +28 -0
- synth_ai/environments/examples/red/engine_helpers/reward_components.py +275 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/__init__.py +142 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/adaptive_rewards.py +56 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/battle_rewards.py +283 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/composite_rewards.py +149 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/economy_rewards.py +137 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/efficiency_rewards.py +56 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/exploration_rewards.py +330 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/novelty_rewards.py +120 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/pallet_town_rewards.py +558 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/pokemon_rewards.py +312 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/social_rewards.py +147 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/story_rewards.py +246 -0
- synth_ai/environments/examples/red/engine_helpers/screen_analysis.py +367 -0
- synth_ai/environments/examples/red/engine_helpers/state_extraction.py +139 -0
- synth_ai/environments/examples/red/environment.py +235 -0
- synth_ai/environments/examples/red/taskset.py +77 -0
- synth_ai/environments/examples/sokoban/__init__.py +1 -0
- synth_ai/environments/examples/sokoban/engine.py +675 -0
- synth_ai/environments/examples/sokoban/engine_helpers/__init__.py +1 -0
- synth_ai/environments/examples/sokoban/engine_helpers/room_utils.py +656 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/__init__.py +17 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/__init__.py +3 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/boxoban_env.py +129 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/render_utils.py +370 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/room_utils.py +331 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env.py +305 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_fixed_targets.py +66 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_pull.py +114 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_two_player.py +122 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_variations.py +394 -0
- synth_ai/environments/examples/sokoban/environment.py +228 -0
- synth_ai/environments/examples/sokoban/generate_verified_puzzles.py +438 -0
- synth_ai/environments/examples/sokoban/puzzle_loader.py +311 -0
- synth_ai/environments/examples/sokoban/taskset.py +425 -0
- synth_ai/environments/examples/tictactoe/__init__.py +1 -0
- synth_ai/environments/examples/tictactoe/engine.py +368 -0
- synth_ai/environments/examples/tictactoe/environment.py +239 -0
- synth_ai/environments/examples/tictactoe/taskset.py +214 -0
- synth_ai/environments/examples/verilog/__init__.py +10 -0
- synth_ai/environments/examples/verilog/engine.py +328 -0
- synth_ai/environments/examples/verilog/environment.py +349 -0
- synth_ai/environments/examples/verilog/taskset.py +418 -0
- synth_ai/tracing_v3/examples/basic_usage.py +188 -0
- {synth_ai-0.2.4.dev3.dist-info → synth_ai-0.2.4.dev5.dist-info}/METADATA +1 -1
- {synth_ai-0.2.4.dev3.dist-info → synth_ai-0.2.4.dev5.dist-info}/RECORD +105 -6
- {synth_ai-0.2.4.dev3.dist-info → synth_ai-0.2.4.dev5.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.4.dev3.dist-info → synth_ai-0.2.4.dev5.dist-info}/entry_points.txt +0 -0
- {synth_ai-0.2.4.dev3.dist-info → synth_ai-0.2.4.dev5.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.2.4.dev3.dist-info → synth_ai-0.2.4.dev5.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,430 @@
|
|
1
|
+
"""Visualization tools for NetHack trajectories."""
|
2
|
+
|
3
|
+
import numpy as np
|
4
|
+
from typing import List, Dict, Any, Optional, Tuple
|
5
|
+
from pathlib import Path
|
6
|
+
import matplotlib.pyplot as plt
|
7
|
+
import matplotlib.animation as animation
|
8
|
+
from matplotlib.patches import Rectangle
|
9
|
+
import seaborn as sns
|
10
|
+
from datetime import datetime
|
11
|
+
import json
|
12
|
+
|
13
|
+
try:
|
14
|
+
from PIL import Image, ImageDraw, ImageFont
|
15
|
+
|
16
|
+
HAS_PIL = True
|
17
|
+
except ImportError:
|
18
|
+
HAS_PIL = False
|
19
|
+
print("Warning: PIL not available. Some visualization features will be limited.")
|
20
|
+
|
21
|
+
|
22
|
+
class NetHackVisualizer:
    """Visualize NetHack game states and trajectories.

    Single frames are rendered to RGB ``np.ndarray`` images. When PIL is
    importable (module flag ``HAS_PIL``) characters are drawn as colored
    glyphs; otherwise a flat grayscale block rendering built with NumPy is
    returned. Trajectory videos and statistics plots use matplotlib.
    """

    # Character to color mapping for ASCII visualization
    CHAR_COLORS = {
        "@": "#FF0000",  # Player - red
        "d": "#8B4513",  # Dog/pet - brown
        "f": "#FFA500",  # Cat - orange
        ">": "#0000FF",  # Stairs down - blue
        "<": "#00FF00",  # Stairs up - green
        ".": "#D3D3D3",  # Floor - light gray
        "#": "#696969",  # Corridor (NetHack draws walls with - and |) - dark gray
        "+": "#8B4513",  # Closed door (or spellbook) - brown
        "-": "#8B4513",  # Horizontal wall / open door - brown
        "|": "#8B4513",  # Vertical wall / open door - brown
        "{": "#00CED1",  # Fountain - dark turquoise
        "}": "#4682B4",  # Pool - steel blue
        "^": "#FF1493",  # Trap - deep pink
        "%": "#FFD700",  # Food - gold
        "!": "#FF69B4",  # Potion - hot pink
        "?": "#DDA0DD",  # Scroll - plum
        "/": "#9370DB",  # Wand - medium purple
        "=": "#FFD700",  # Ring - gold
        '"': "#FF4500",  # Amulet - orange red
        "[": "#C0C0C0",  # Armor - silver
        ")": "#A9A9A9",  # Weapon - dark gray
        "*": "#FFFF00",  # Gold/gem - yellow
        "$": "#FFD700",  # Gold - gold
        "`": "#8B4513",  # Boulder/statue - brown
    }

    # Default color for unknown characters
    DEFAULT_COLOR = "#FFFFFF"

    # Monospace fonts tried in order before falling back to PIL's default font.
    _FONT_CANDIDATES = (
        "/System/Library/Fonts/Monaco.ttf",
        "/usr/share/fonts/truetype/dejavu/DejaVuSansMono.ttf",
    )

    # Grayscale levels for the PIL-free fallback renderer.
    _GRAY_LEVELS = {
        "@": 255,  # Player - white
        "#": 64,  # Corridor - dark gray
        ".": 128,  # Floor - gray
        ">": 200,  # Stairs - light
        "<": 200,  # Stairs - light
        " ": 0,  # Empty - black
    }
    # Gray level for any character not listed in _GRAY_LEVELS.
    _GRAY_UNKNOWN = 100

    def __init__(self, cell_size: int = 10, font_size: int = 8):
        """Create a visualizer.

        Args:
            cell_size: Width and height in pixels of one map character cell.
            font_size: Point size used when loading TrueType fonts (PIL only).
        """
        self.cell_size = cell_size
        self.font_size = font_size

    @staticmethod
    def _split_map_lines(ascii_map: str) -> List[str]:
        """Split an ASCII map into rows while keeping column alignment.

        Whitespace-only rows at the start and end are dropped, and trailing
        whitespace on the final row is removed. Leading whitespace on the
        first kept row is preserved, so every row keeps the same column
        origin. (A plain ``str.strip()`` would shift only the first row.)
        Always returns at least one row (``[""]`` for a blank map).
        """
        rows = ascii_map.split("\n")
        start = 0
        while start < len(rows) and not rows[start].strip():
            start += 1
        end = len(rows)
        while end > start and not rows[end - 1].strip():
            end -= 1
        rows = rows[start:end]
        if not rows:
            return [""]
        rows[-1] = rows[-1].rstrip()
        return rows

    def _load_font(self):
        """Return the first loadable monospace font, else PIL's default font."""
        for font_path in self._FONT_CANDIDATES:
            try:
                return ImageFont.truetype(font_path, self.font_size)
            except (OSError, ValueError):
                # Missing/unreadable font file or unusable size: try the next one.
                continue
        return ImageFont.load_default()

    def ascii_to_image(
        self, ascii_map: str, highlight_pos: Optional[Tuple[int, int]] = None
    ) -> np.ndarray:
        """Convert an ASCII map to a colored RGB image.

        Args:
            ascii_map: Newline-separated map text.
            highlight_pos: Optional ``(x, y)`` cell (column, row) to draw on a
                yellow background. Any 2-element sequence is accepted.

        Returns:
            ``np.ndarray`` of shape ``(rows * cell_size, cols * cell_size, 3)``.
            Without PIL this delegates to the grayscale fallback renderer.
        """
        if not HAS_PIL:
            return self._simple_ascii_to_image(ascii_map, highlight_pos)

        lines = self._split_map_lines(ascii_map)
        height = len(lines)
        width = max(len(line) for line in lines)
        highlight = tuple(highlight_pos) if highlight_pos is not None else None
        cs = self.cell_size

        image = Image.new("RGB", (width * cs, height * cs), color="black")
        draw = ImageDraw.Draw(image)
        font = self._load_font()

        for y, line in enumerate(lines):
            for x, char in enumerate(line):
                if char == " ":
                    continue

                color = self.CHAR_COLORS.get(char, self.DEFAULT_COLOR)

                if (x, y) == highlight:
                    draw.rectangle(
                        [x * cs, y * cs, (x + 1) * cs, (y + 1) * cs],
                        fill="yellow",
                    )

                draw.text((x * cs + 2, y * cs), char, fill=color, font=font)

        return np.array(image)

    def _simple_ascii_to_image(
        self, ascii_map: str, highlight_pos: Optional[Tuple[int, int]] = None
    ) -> np.ndarray:
        """Render an ASCII map as flat grayscale cells without PIL.

        Each character fills a ``cell_size`` x ``cell_size`` block with a gray
        level from ``_GRAY_LEVELS`` (``_GRAY_UNKNOWN`` otherwise). The
        ``highlight_pos`` cell, given as ``(x, y)``, is filled red instead.

        Returns:
            ``np.ndarray`` of dtype ``uint8`` and shape
            ``(rows * cell_size, cols * cell_size, 3)``.
        """
        lines = self._split_map_lines(ascii_map)
        height = len(lines)
        width = max(len(line) for line in lines)
        highlight = tuple(highlight_pos) if highlight_pos is not None else None
        cs = self.cell_size

        image = np.zeros((height * cs, width * cs, 3), dtype=np.uint8)

        for y, line in enumerate(lines):
            y0 = y * cs
            for x, char in enumerate(line):
                x0 = x * cs
                cell = image[y0 : y0 + cs, x0 : x0 + cs]
                if (x, y) == highlight:
                    cell[:] = (255, 0, 0)
                else:
                    cell[:] = self._GRAY_LEVELS.get(char, self._GRAY_UNKNOWN)

        return image

    def create_frame_image(
        self, observation: Dict[str, Any], include_stats: bool = True
    ) -> np.ndarray:
        """Create a single frame image from an observation dict.

        With PIL, renders a full 80x24 terminal view (map plus status lines).
        Without PIL, renders only ``observation["ascii_map"]``, highlighting
        the player cell when ``player_stats`` provides both ``x`` and ``y``.

        Args:
            observation: Dict with optional keys such as ``ascii_map``,
                ``player_stats``, ``message``.
            include_stats: Currently unused; kept for API compatibility.
        """
        # Standard NetHack terminal size
        term_width, term_height = 80, 24

        if HAS_PIL:
            return self._create_terminal_view(observation, term_width, term_height)

        # Fallback to simple view
        ascii_map = observation.get("ascii_map", "")
        player_stats = observation.get("player_stats") or {}
        px, py = player_stats.get("x"), player_stats.get("y")
        # Only highlight when the position is actually known (no bogus (0, 0)).
        player_pos = (px, py) if px is not None and py is not None else None
        return self.ascii_to_image(ascii_map, player_pos)

    def _create_terminal_view(
        self, observation: Dict[str, Any], width: int = 80, height: int = 24
    ) -> np.ndarray:
        """Render a terminal-style NetHack screen (requires PIL).

        Layout: map rows ``0 .. height - 4``, a separator line at row
        ``height - 3``, and two status lines on the last two rows. The message
        (if any) overwrites row 0. For the default 24-row terminal this is
        21 map rows, separator at 21, status at 22 and 23.
        """
        cs = self.cell_size
        img_width = width * cs
        img_height = height * cs
        image = Image.new("RGB", (img_width, img_height), color="black")
        draw = ImageDraw.Draw(image)
        font = self._load_font()

        # Get game data
        ascii_map = observation.get("ascii_map") or ""
        player_stats = observation.get("player_stats") or {}
        message = (observation.get("message") or "").strip()
        player_pos = (player_stats.get("x", -1), player_stats.get("y", -1))

        map_rows = max(height - 3, 0)

        # Draw the map
        for y, line in enumerate(ascii_map.split("\n")[:map_rows]):
            for x, char in enumerate(line[:width]):
                if char == " ":
                    continue

                color = self.CHAR_COLORS.get(char, self.DEFAULT_COLOR)

                if (x, y) == player_pos:
                    # Dim yellow background behind the player glyph
                    draw.rectangle(
                        [x * cs, y * cs, (x + 1) * cs, (y + 1) * cs],
                        fill="#333300",
                    )

                draw.text((x * cs + 1, y * cs), char, fill=color, font=font)

        # Separator line below the map area
        draw.line(
            [(0, map_rows * cs), (img_width, map_rows * cs)],
            fill="#666666",
            width=1,
        )

        # First status line: name, role, attributes
        status_y = (height - 2) * cs
        character_name = observation.get("character_name", "Agent")
        character_role = observation.get("character_role", "Adventurer")

        status_parts = [
            f"{character_name} the {character_role}",
            f"St:{player_stats.get('strength', 10)}",
            f"Dx:{player_stats.get('dexterity', 10)}",
            f"Co:{player_stats.get('constitution', 10)}",
            f"In:{player_stats.get('intelligence', 10)}",
            f"Wi:{player_stats.get('wisdom', 10)}",
            f"Ch:{player_stats.get('charisma', 10)}",
            "Neutral",  # Alignment (not read from the observation)
        ]
        draw.text((5, status_y), " ".join(status_parts), fill="white", font=font)

        # Second status line: depth, gold, HP, power, AC, XP (and turn count)
        status2_y = (height - 1) * cs
        dlvl = player_stats.get("depth", 1)
        gold = player_stats.get("gold", 0)
        hp = player_stats.get("hp", 10)
        max_hp = player_stats.get("max_hp", 10)
        pw = player_stats.get("energy", 0)
        max_pw = player_stats.get("max_energy", 0)
        ac = player_stats.get("ac", 10)
        xp = player_stats.get("experience_level", 1)

        status2_parts = [
            f"Dlvl:{dlvl}",
            f"$:{gold}",
            f"HP:{hp}({max_hp})",
            f"Pw:{pw}({max_pw})",
            f"AC:{ac}",
            f"Xp:{xp}",
        ]
        if "turn_count" in observation:
            status2_parts.append(f"T:{observation['turn_count']}")

        draw.text((5, status2_y), " ".join(status2_parts), fill="white", font=font)

        # Draw message at top if present, clearing row 0 first
        if message:
            draw.rectangle([0, 0, img_width, cs], fill="black")
            draw.text((5, 0), message[: width - 1], fill="white", font=font)

        return np.array(image)

    def create_trajectory_video(
        self,
        frames: List[Dict[str, Any]],
        output_path: str,
        fps: int = 4,
        include_stats: bool = True,
    ) -> str:
        """Create a video (or GIF) from trajectory frames.

        Args:
            frames: Dicts each containing ``"observation"`` and ``"action"``.
            output_path: Destination file. A ``.gif`` suffix (any case) uses
                the ``pillow`` writer; anything else uses ``ffmpeg``. Parent
                directories are created as needed.
            fps: Frames per second; must be positive.
            include_stats: Forwarded to ``create_frame_image``.

        Returns:
            The output path as a string.

        Raises:
            ValueError: If ``frames`` is empty or ``fps`` is not positive.
        """
        if not frames:
            raise ValueError("No frames to create video from")
        if fps <= 0:
            raise ValueError("fps must be positive")

        fig, ax = plt.subplots(figsize=(10, 12))
        try:
            ax.axis("off")

            first_img = self.create_frame_image(frames[0]["observation"], include_stats)
            im = ax.imshow(first_img)

            def animate(i):
                if i < len(frames):
                    img = self.create_frame_image(frames[i]["observation"], include_stats)
                    im.set_array(img)
                    ax.set_title(f"Step {i}: {frames[i]['action']}")
                return [im]

            anim = animation.FuncAnimation(
                fig, animate, frames=len(frames), interval=1000 / fps, blit=True
            )

            out_path = Path(output_path)
            out_path.parent.mkdir(parents=True, exist_ok=True)

            writer = "pillow" if out_path.suffix.lower() == ".gif" else "ffmpeg"
            anim.save(str(out_path), writer=writer, fps=fps)
        finally:
            # Close this specific figure even if saving fails (e.g. no ffmpeg).
            plt.close(fig)

        return str(out_path)

    def plot_trajectory_stats(
        self, frames: List[Dict[str, Any]], output_path: Optional[str] = None
    ) -> None:
        """Plot cumulative reward, depth, HP and positions over a trajectory.

        Each frame must contain ``"reward"`` and ``"observation"``; values are
        read from ``observation["player_stats"]`` when present. Saves to
        ``output_path`` (and closes the figure) if given, otherwise calls
        ``plt.show()``. Does nothing for an empty ``frames`` list.
        """
        if not frames:
            return

        steps = []
        rewards = []
        depths = []
        hps = []
        positions_x = []
        positions_y = []

        cumulative_reward = 0
        for i, frame in enumerate(frames):
            steps.append(i)
            cumulative_reward += frame["reward"]
            rewards.append(cumulative_reward)

            stats = frame["observation"].get("player_stats", {})
            depths.append(stats.get("depth", 1))
            hps.append(stats.get("hp", 0))
            positions_x.append(stats.get("x", 0))
            positions_y.append(stats.get("y", 0))

        fig, axes = plt.subplots(2, 2, figsize=(12, 10))

        # Cumulative reward
        axes[0, 0].plot(steps, rewards, "b-")
        axes[0, 0].set_xlabel("Step")
        axes[0, 0].set_ylabel("Cumulative Reward")
        axes[0, 0].set_title("Reward Progress")
        axes[0, 0].grid(True)

        # Depth progression
        axes[0, 1].plot(steps, depths, "g-")
        axes[0, 1].set_xlabel("Step")
        axes[0, 1].set_ylabel("Dungeon Depth")
        axes[0, 1].set_title("Depth Exploration")
        axes[0, 1].grid(True)

        # HP over time
        axes[1, 0].plot(steps, hps, "r-")
        axes[1, 0].set_xlabel("Step")
        axes[1, 0].set_ylabel("Hit Points")
        axes[1, 0].set_title("Health Over Time")
        axes[1, 0].grid(True)

        # Positions colored by step
        axes[1, 1].scatter(positions_x, positions_y, c=steps, cmap="viridis", s=1)
        axes[1, 1].set_xlabel("X Position")
        axes[1, 1].set_ylabel("Y Position")
        axes[1, 1].set_title("Movement Pattern")
        axes[1, 1].invert_yaxis()  # Screen rows grow downward

        fig.tight_layout()

        if output_path:
            fig.savefig(output_path, dpi=150)
            plt.close(fig)
        else:
            plt.show()

    def plot_action_distribution(
        self, frames: List[Dict[str, Any]], output_path: Optional[str] = None
    ) -> None:
        """Plot a bar chart of the 20 most frequent actions.

        Each frame must contain ``"action"``. Ties keep first-seen order.
        Saves to ``output_path`` (and closes the figure) if given, otherwise
        calls ``plt.show()``. Does nothing for an empty ``frames`` list.
        """
        if not frames:
            return

        action_counts = {}
        for frame in frames:
            action = frame["action"]
            action_counts[action] = action_counts.get(action, 0) + 1

        # Stable sort: equal counts keep first-seen order
        actions = sorted(action_counts.items(), key=lambda item: item[1], reverse=True)
        actions_list, counts = zip(*actions[:20])  # Top 20 actions

        fig = plt.figure(figsize=(12, 6))
        plt.bar(range(len(actions_list)), counts)
        plt.xticks(range(len(actions_list)), actions_list, rotation=45, ha="right")
        plt.xlabel("Action")
        plt.ylabel("Count")
        plt.title("Action Distribution (Top 20)")
        fig.tight_layout()

        if output_path:
            fig.savefig(output_path, dpi=150)
            plt.close(fig)
        else:
            plt.show()
|