synth-ai 0.2.4.dev3__py3-none-any.whl → 0.2.4.dev5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (105) hide show
  1. synth_ai/environments/examples/__init__.py +1 -0
  2. synth_ai/environments/examples/crafter_classic/__init__.py +8 -0
  3. synth_ai/environments/examples/crafter_classic/config_logging.py +111 -0
  4. synth_ai/environments/examples/crafter_classic/debug_translation.py +0 -0
  5. synth_ai/environments/examples/crafter_classic/engine.py +575 -0
  6. synth_ai/environments/examples/crafter_classic/engine_deterministic_patch.py +63 -0
  7. synth_ai/environments/examples/crafter_classic/engine_helpers/action_map.py +5 -0
  8. synth_ai/environments/examples/crafter_classic/engine_helpers/serialization.py +74 -0
  9. synth_ai/environments/examples/crafter_classic/engine_serialization_patch_v3.py +266 -0
  10. synth_ai/environments/examples/crafter_classic/environment.py +364 -0
  11. synth_ai/environments/examples/crafter_classic/taskset.py +233 -0
  12. synth_ai/environments/examples/crafter_classic/trace_hooks_v3.py +229 -0
  13. synth_ai/environments/examples/crafter_classic/world_config_patch_simple.py +298 -0
  14. synth_ai/environments/examples/crafter_custom/__init__.py +4 -0
  15. synth_ai/environments/examples/crafter_custom/crafter/__init__.py +7 -0
  16. synth_ai/environments/examples/crafter_custom/crafter/config.py +182 -0
  17. synth_ai/environments/examples/crafter_custom/crafter/constants.py +8 -0
  18. synth_ai/environments/examples/crafter_custom/crafter/engine.py +269 -0
  19. synth_ai/environments/examples/crafter_custom/crafter/env.py +266 -0
  20. synth_ai/environments/examples/crafter_custom/crafter/objects.py +418 -0
  21. synth_ai/environments/examples/crafter_custom/crafter/recorder.py +187 -0
  22. synth_ai/environments/examples/crafter_custom/crafter/worldgen.py +119 -0
  23. synth_ai/environments/examples/crafter_custom/dataset_builder.py +373 -0
  24. synth_ai/environments/examples/crafter_custom/environment.py +312 -0
  25. synth_ai/environments/examples/crafter_custom/run_dataset.py +305 -0
  26. synth_ai/environments/examples/enron/art_helpers/email_search_tools.py +156 -0
  27. synth_ai/environments/examples/enron/art_helpers/local_email_db.py +280 -0
  28. synth_ai/environments/examples/enron/art_helpers/types_enron.py +24 -0
  29. synth_ai/environments/examples/enron/engine.py +291 -0
  30. synth_ai/environments/examples/enron/environment.py +165 -0
  31. synth_ai/environments/examples/enron/taskset.py +112 -0
  32. synth_ai/environments/examples/minigrid/__init__.py +48 -0
  33. synth_ai/environments/examples/minigrid/engine.py +589 -0
  34. synth_ai/environments/examples/minigrid/environment.py +274 -0
  35. synth_ai/environments/examples/minigrid/environment_mapping.py +242 -0
  36. synth_ai/environments/examples/minigrid/puzzle_loader.py +416 -0
  37. synth_ai/environments/examples/minigrid/taskset.py +583 -0
  38. synth_ai/environments/examples/nethack/__init__.py +7 -0
  39. synth_ai/environments/examples/nethack/achievements.py +337 -0
  40. synth_ai/environments/examples/nethack/engine.py +738 -0
  41. synth_ai/environments/examples/nethack/environment.py +255 -0
  42. synth_ai/environments/examples/nethack/helpers/__init__.py +42 -0
  43. synth_ai/environments/examples/nethack/helpers/action_mapping.py +301 -0
  44. synth_ai/environments/examples/nethack/helpers/nle_wrapper.py +401 -0
  45. synth_ai/environments/examples/nethack/helpers/observation_utils.py +433 -0
  46. synth_ai/environments/examples/nethack/helpers/recording_wrapper.py +201 -0
  47. synth_ai/environments/examples/nethack/helpers/trajectory_recorder.py +268 -0
  48. synth_ai/environments/examples/nethack/helpers/visualization/replay_viewer.py +308 -0
  49. synth_ai/environments/examples/nethack/helpers/visualization/visualizer.py +430 -0
  50. synth_ai/environments/examples/nethack/taskset.py +323 -0
  51. synth_ai/environments/examples/red/__init__.py +7 -0
  52. synth_ai/environments/examples/red/config_logging.py +110 -0
  53. synth_ai/environments/examples/red/engine.py +693 -0
  54. synth_ai/environments/examples/red/engine_helpers/__init__.py +1 -0
  55. synth_ai/environments/examples/red/engine_helpers/memory_map.py +28 -0
  56. synth_ai/environments/examples/red/engine_helpers/reward_components.py +275 -0
  57. synth_ai/environments/examples/red/engine_helpers/reward_library/__init__.py +142 -0
  58. synth_ai/environments/examples/red/engine_helpers/reward_library/adaptive_rewards.py +56 -0
  59. synth_ai/environments/examples/red/engine_helpers/reward_library/battle_rewards.py +283 -0
  60. synth_ai/environments/examples/red/engine_helpers/reward_library/composite_rewards.py +149 -0
  61. synth_ai/environments/examples/red/engine_helpers/reward_library/economy_rewards.py +137 -0
  62. synth_ai/environments/examples/red/engine_helpers/reward_library/efficiency_rewards.py +56 -0
  63. synth_ai/environments/examples/red/engine_helpers/reward_library/exploration_rewards.py +330 -0
  64. synth_ai/environments/examples/red/engine_helpers/reward_library/novelty_rewards.py +120 -0
  65. synth_ai/environments/examples/red/engine_helpers/reward_library/pallet_town_rewards.py +558 -0
  66. synth_ai/environments/examples/red/engine_helpers/reward_library/pokemon_rewards.py +312 -0
  67. synth_ai/environments/examples/red/engine_helpers/reward_library/social_rewards.py +147 -0
  68. synth_ai/environments/examples/red/engine_helpers/reward_library/story_rewards.py +246 -0
  69. synth_ai/environments/examples/red/engine_helpers/screen_analysis.py +367 -0
  70. synth_ai/environments/examples/red/engine_helpers/state_extraction.py +139 -0
  71. synth_ai/environments/examples/red/environment.py +235 -0
  72. synth_ai/environments/examples/red/taskset.py +77 -0
  73. synth_ai/environments/examples/sokoban/__init__.py +1 -0
  74. synth_ai/environments/examples/sokoban/engine.py +675 -0
  75. synth_ai/environments/examples/sokoban/engine_helpers/__init__.py +1 -0
  76. synth_ai/environments/examples/sokoban/engine_helpers/room_utils.py +656 -0
  77. synth_ai/environments/examples/sokoban/engine_helpers/vendored/__init__.py +17 -0
  78. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/__init__.py +3 -0
  79. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/boxoban_env.py +129 -0
  80. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/render_utils.py +370 -0
  81. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/room_utils.py +331 -0
  82. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env.py +305 -0
  83. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_fixed_targets.py +66 -0
  84. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_pull.py +114 -0
  85. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_two_player.py +122 -0
  86. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_variations.py +394 -0
  87. synth_ai/environments/examples/sokoban/environment.py +228 -0
  88. synth_ai/environments/examples/sokoban/generate_verified_puzzles.py +438 -0
  89. synth_ai/environments/examples/sokoban/puzzle_loader.py +311 -0
  90. synth_ai/environments/examples/sokoban/taskset.py +425 -0
  91. synth_ai/environments/examples/tictactoe/__init__.py +1 -0
  92. synth_ai/environments/examples/tictactoe/engine.py +368 -0
  93. synth_ai/environments/examples/tictactoe/environment.py +239 -0
  94. synth_ai/environments/examples/tictactoe/taskset.py +214 -0
  95. synth_ai/environments/examples/verilog/__init__.py +10 -0
  96. synth_ai/environments/examples/verilog/engine.py +328 -0
  97. synth_ai/environments/examples/verilog/environment.py +349 -0
  98. synth_ai/environments/examples/verilog/taskset.py +418 -0
  99. synth_ai/tracing_v3/examples/basic_usage.py +188 -0
  100. {synth_ai-0.2.4.dev3.dist-info → synth_ai-0.2.4.dev5.dist-info}/METADATA +1 -1
  101. {synth_ai-0.2.4.dev3.dist-info → synth_ai-0.2.4.dev5.dist-info}/RECORD +105 -6
  102. {synth_ai-0.2.4.dev3.dist-info → synth_ai-0.2.4.dev5.dist-info}/WHEEL +0 -0
  103. {synth_ai-0.2.4.dev3.dist-info → synth_ai-0.2.4.dev5.dist-info}/entry_points.txt +0 -0
  104. {synth_ai-0.2.4.dev3.dist-info → synth_ai-0.2.4.dev5.dist-info}/licenses/LICENSE +0 -0
  105. {synth_ai-0.2.4.dev3.dist-info → synth_ai-0.2.4.dev5.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,430 @@
1
+ """Visualization tools for NetHack trajectories."""
2
+
3
+ import numpy as np
4
+ from typing import List, Dict, Any, Optional, Tuple
5
+ from pathlib import Path
6
+ import matplotlib.pyplot as plt
7
+ import matplotlib.animation as animation
8
+ from matplotlib.patches import Rectangle
9
+ import seaborn as sns
10
+ from datetime import datetime
11
+ import json
12
+
13
+ try:
14
+ from PIL import Image, ImageDraw, ImageFont
15
+
16
+ HAS_PIL = True
17
+ except ImportError:
18
+ HAS_PIL = False
19
+ print("Warning: PIL not available. Some visualization features will be limited.")
20
+
21
+
22
class NetHackVisualizer:
    """Render NetHack observations as images, videos and summary plots.

    Observations are plain dicts.  The keys read by this class (all optional)
    are ``ascii_map``, ``player_stats``, ``message``, ``character_name``,
    ``character_role`` and ``turn_count``.  Trajectory frames are dicts with
    an ``observation`` entry plus ``action`` (video, action plot) and
    ``reward`` (stats plot).
    """

    # Character to color mapping for ASCII visualization
    CHAR_COLORS = {
        "@": "#FF0000",  # Player - red
        "d": "#8B4513",  # Dog/pet - brown
        "f": "#FFA500",  # Cat - orange
        ">": "#0000FF",  # Stairs down - blue
        "<": "#00FF00",  # Stairs up - green
        ".": "#D3D3D3",  # Floor - light gray
        "#": "#696969",  # Wall - dark gray
        "+": "#8B4513",  # Door - brown
        "-": "#8B4513",  # Door - brown
        "|": "#8B4513",  # Door - brown
        "{": "#00CED1",  # Fountain - dark turquoise
        "}": "#4682B4",  # Pool - steel blue
        "^": "#FF1493",  # Trap - deep pink
        "%": "#FFD700",  # Food - gold
        "!": "#FF69B4",  # Potion - hot pink
        "?": "#DDA0DD",  # Scroll - plum
        "/": "#9370DB",  # Wand - medium purple
        "=": "#FFD700",  # Ring - gold
        '"': "#FF4500",  # Amulet - orange red
        "[": "#C0C0C0",  # Armor - silver
        ")": "#A9A9A9",  # Weapon - dark gray
        "*": "#FFFF00",  # Gold/gem - yellow
        "$": "#FFD700",  # Gold - gold
        "`": "#8B4513",  # Boulder/statue - brown
    }

    # Default color for unknown characters
    DEFAULT_COLOR = "#FFFFFF"

    # Monospace fonts tried in order before falling back to PIL's built-in font.
    FONT_CANDIDATES = (
        "/System/Library/Fonts/Monaco.ttf",
        "/usr/share/fonts/truetype/dejavu/DejaVuSansMono.ttf",
    )

    # Terminal layout (in text rows) used by the terminal-style view.
    MAP_ROWS = 21  # rows 0..20 hold the map
    SEPARATOR_ROW = 21  # thin rule drawn at the top edge of this row
    STATUS_ROW = 22  # name / attributes line
    STATUS2_ROW = 23  # dungeon level / HP / etc. line

    def __init__(self, cell_size: int = 10, font_size: int = 8):
        """Store rendering sizes.

        Args:
            cell_size: Width and height, in pixels, of one map character cell.
            font_size: Point size passed to the TrueType font loader.
        """
        self.cell_size = cell_size
        self.font_size = font_size

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _split_map(ascii_map: str) -> List[str]:
        """Split *ascii_map* into rows without disturbing cell coordinates.

        Only trailing whitespace is removed.  Stripping the leading side as
        well would drop leading blank rows and the indentation of the first
        row, shifting every cell relative to ``highlight_pos``.
        """
        return ascii_map.rstrip().split("\n")

    def _load_font(self):
        """Return the first loadable font from ``FONT_CANDIDATES``.

        Falls back to PIL's built-in bitmap font when none can be loaded, so
        rendering stays best-effort on machines without these font files.
        """
        for font_path in self.FONT_CANDIDATES:
            try:
                return ImageFont.truetype(font_path, self.font_size)
            except (OSError, ImportError, ValueError):
                # OSError: file missing/unreadable; ImportError: Pillow built
                # without FreeType; ValueError: non-positive font size.
                continue
        return ImageFont.load_default()

    def _draw_cells(
        self,
        draw,
        lines: List[str],
        font,
        highlight_pos: Optional[Tuple[int, int]],
        highlight_fill: str,
        x_pad: int,
    ) -> None:
        """Draw every non-space character of *lines* onto *draw*.

        Args:
            draw: ``ImageDraw.Draw`` object to paint on.
            lines: Map rows; row index is ``y``, column index is ``x``.
            font: PIL font used for the glyphs.
            highlight_pos: ``(x, y)`` cell that gets a filled background, or
                ``None`` for no highlight.
            highlight_fill: Fill color of the highlighted cell.
            x_pad: Horizontal pixel offset of the glyph inside its cell.
        """
        size = self.cell_size
        for y, line in enumerate(lines):
            for x, char in enumerate(line):
                if char == " ":
                    continue

                if highlight_pos and (x, y) == highlight_pos:
                    draw.rectangle(
                        [x * size, y * size, (x + 1) * size, (y + 1) * size],
                        fill=highlight_fill,
                    )

                draw.text(
                    (x * size + x_pad, y * size),
                    char,
                    fill=self.CHAR_COLORS.get(char, self.DEFAULT_COLOR),
                    font=font,
                )

    def _draw_status(
        self,
        draw,
        observation: Dict[str, Any],
        player_stats: Dict[str, Any],
        font,
        img_width: int,
    ) -> None:
        """Draw the separator rule and the two bottom status lines."""
        separator_y = self.SEPARATOR_ROW * self.cell_size
        draw.line([(0, separator_y), (img_width, separator_y)], fill="#666666", width=1)

        # First status line: "<name> the <role>" followed by the attributes.
        character_name = observation.get("character_name", "Agent")
        character_role = observation.get("character_role", "Adventurer")
        status_parts = [
            f"{character_name} the {character_role}",
            f"St:{player_stats.get('strength', 10)}",
            f"Dx:{player_stats.get('dexterity', 10)}",
            f"Co:{player_stats.get('constitution', 10)}",
            f"In:{player_stats.get('intelligence', 10)}",
            f"Wi:{player_stats.get('wisdom', 10)}",
            f"Ch:{player_stats.get('charisma', 10)}",
            "Neutral",  # Alignment is not read from the observation
        ]
        draw.text(
            (5, self.STATUS_ROW * self.cell_size),
            " ".join(status_parts),
            fill="white",
            font=font,
        )

        # Second status line: dungeon level, gold, HP, power, AC, XP level.
        status2_parts = [
            f"Dlvl:{player_stats.get('depth', 1)}",
            f"$:{player_stats.get('gold', 0)}",
            f"HP:{player_stats.get('hp', 10)}({player_stats.get('max_hp', 10)})",
            f"Pw:{player_stats.get('energy', 0)}({player_stats.get('max_energy', 0)})",
            f"AC:{player_stats.get('ac', 10)}",
            f"Xp:{player_stats.get('experience_level', 1)}",
        ]
        if "turn_count" in observation:
            status2_parts.append(f"T:{observation['turn_count']}")
        draw.text(
            (5, self.STATUS2_ROW * self.cell_size),
            " ".join(status2_parts),
            fill="white",
            font=font,
        )

    @staticmethod
    def _save_or_show(fig, output_path: Optional[str]) -> None:
        """Save *fig* to *output_path* (creating parent dirs) or show it."""
        if output_path:
            target = Path(output_path)
            target.parent.mkdir(parents=True, exist_ok=True)
            fig.savefig(str(target), dpi=150)
            # Close this specific figure, not whichever one is "current".
            plt.close(fig)
        else:
            plt.show()

    # ------------------------------------------------------------------
    # Image rendering
    # ------------------------------------------------------------------

    def ascii_to_image(
        self, ascii_map: str, highlight_pos: Optional[Tuple[int, int]] = None
    ) -> np.ndarray:
        """Convert an ASCII map to a colored RGB image.

        Args:
            ascii_map: Newline-separated map rows.
            highlight_pos: Optional ``(x, y)`` cell drawn on a yellow
                background.

        Returns:
            ``uint8`` array of shape ``(rows * cell_size, cols * cell_size, 3)``.
            Without PIL the grayscale fallback renderer is used instead.
        """
        if not HAS_PIL:
            return self._simple_ascii_to_image(ascii_map, highlight_pos)

        lines = self._split_map(ascii_map)
        height = len(lines)
        width = max((len(line) for line in lines), default=0)

        image = Image.new(
            "RGB", (width * self.cell_size, height * self.cell_size), color="black"
        )
        draw = ImageDraw.Draw(image)
        self._draw_cells(draw, lines, self._load_font(), highlight_pos, "yellow", 2)

        return np.array(image)

    def _simple_ascii_to_image(
        self, ascii_map: str, highlight_pos: Optional[Tuple[int, int]] = None
    ) -> np.ndarray:
        """Convert an ASCII map to a block image without PIL.

        Each character becomes a ``cell_size`` x ``cell_size`` gray square;
        the highlighted cell, if any, is painted pure red.

        Returns:
            ``uint8`` array of shape ``(rows * cell_size, cols * cell_size, 3)``.
        """
        lines = self._split_map(ascii_map)
        height = len(lines)
        width = max((len(line) for line in lines), default=0)
        size = self.cell_size

        image = np.zeros((height * size, width * size, 3), dtype=np.uint8)

        # Simple character to grayscale mapping; unknown characters use 100.
        char_values = {
            "@": 255,  # Player - white
            "#": 64,  # Wall - dark gray
            ".": 128,  # Floor - gray
            ">": 200,  # Stairs - light
            "<": 200,  # Stairs - light
            " ": 0,  # Empty - black
        }

        for y, line in enumerate(lines):
            for x, char in enumerate(line):
                cell = image[y * size : (y + 1) * size, x * size : (x + 1) * size]
                if highlight_pos and (x, y) == highlight_pos:
                    cell[...] = [255, 0, 0]
                else:
                    cell[...] = char_values.get(char, 100)

        return image

    def create_frame_image(
        self, observation: Dict[str, Any], include_stats: bool = True
    ) -> np.ndarray:
        """Create a single frame image from an observation.

        Args:
            observation: Observation dict (see class docstring).
            include_stats: When PIL is available, whether the separator and
                status lines are drawn.  The no-PIL fallback never draws text.

        Returns:
            RGB ``uint8`` image array.
        """
        # Standard NetHack terminal size
        TERM_WIDTH = 80
        TERM_HEIGHT = 24

        if HAS_PIL:
            return self._create_terminal_view(
                observation, TERM_WIDTH, TERM_HEIGHT, include_stats
            )

        # Fallback to simple view
        ascii_map = observation.get("ascii_map", "")
        player_stats = observation.get("player_stats", {})
        # Highlight only when a position is actually known; defaulting to
        # (0, 0) would paint a bogus marker on the top-left cell.
        if "x" in player_stats and "y" in player_stats:
            player_pos = (player_stats["x"], player_stats["y"])
        else:
            player_pos = None
        return self.ascii_to_image(ascii_map, player_pos)

    def _create_terminal_view(
        self,
        observation: Dict[str, Any],
        width: int = 80,
        height: int = 24,
        include_stats: bool = True,
    ) -> np.ndarray:
        """Create a full terminal-style NetHack view (requires PIL).

        Args:
            observation: Observation dict (see class docstring).
            width: Terminal width in character cells; map rows are clipped to it.
            height: Terminal height in character cells.
            include_stats: Draw the separator rule and both status lines.

        Returns:
            ``uint8`` array of shape ``(height * cell_size, width * cell_size, 3)``.
        """
        img_width = width * self.cell_size
        img_height = height * self.cell_size
        image = Image.new("RGB", (img_width, img_height), color="black")
        draw = ImageDraw.Draw(image)
        font = self._load_font()

        ascii_map = observation.get("ascii_map", "")
        player_stats = observation.get("player_stats", {})
        # ``or ""`` tolerates an explicit ``None`` message.
        message = (observation.get("message") or "").strip()

        # Map area: first MAP_ROWS rows, clipped to the terminal width.  The
        # player cell gets a dark olive background; (-1, -1) matches no cell.
        rows = [line[:width] for line in ascii_map.split("\n")[: self.MAP_ROWS]]
        player_pos = (player_stats.get("x", -1), player_stats.get("y", -1))
        self._draw_cells(draw, rows, font, player_pos, "#333300", 1)

        if include_stats:
            self._draw_status(draw, observation, player_stats, font, img_width)

        # The message replaces whatever was drawn on row 0.
        if message:
            draw.rectangle([0, 0, img_width, self.cell_size], fill="black")
            draw.text((5, 0), message[: width - 1], fill="white", font=font)

        return np.array(image)

    # ------------------------------------------------------------------
    # Trajectory outputs
    # ------------------------------------------------------------------

    def create_trajectory_video(
        self,
        frames: List[Dict[str, Any]],
        output_path: str,
        fps: int = 4,
        include_stats: bool = True,
    ) -> str:
        """Create a video (or GIF) from trajectory frames.

        Args:
            frames: Frame dicts with ``observation`` and ``action`` entries.
            output_path: Destination file; a ``.gif`` suffix selects the
                pillow writer, anything else the ffmpeg writer.  Parent
                directories are created.
            fps: Frames per second of the output.
            include_stats: Forwarded to :meth:`create_frame_image`.

        Returns:
            The output path as a string.

        Raises:
            ValueError: If *frames* is empty.
        """
        if not frames:
            raise ValueError("No frames to create video from")

        fig, ax = plt.subplots(figsize=(10, 12))
        try:
            ax.axis("off")

            first_img = self.create_frame_image(frames[0]["observation"], include_stats)
            im = ax.imshow(first_img)

            def animate(i):
                # FuncAnimation only passes indices in range(len(frames)).
                img = self.create_frame_image(frames[i]["observation"], include_stats)
                im.set_array(img)
                ax.set_title(f"Step {i}: {frames[i]['action']}")
                return [im]

            anim = animation.FuncAnimation(
                fig, animate, frames=len(frames), interval=1000 / fps, blit=True
            )

            destination = Path(output_path)
            destination.parent.mkdir(parents=True, exist_ok=True)

            writer = "pillow" if destination.suffix == ".gif" else "ffmpeg"
            anim.save(str(destination), writer=writer, fps=fps)
        finally:
            # Always release this figure, even if rendering or saving fails.
            plt.close(fig)

        return str(destination)

    def plot_trajectory_stats(
        self, frames: List[Dict[str, Any]], output_path: Optional[str] = None
    ):
        """Plot cumulative reward, depth, HP and movement for a trajectory.

        Does nothing when *frames* is empty.  The figure is saved to
        *output_path* when given, otherwise shown interactively.
        """
        if not frames:
            return

        steps = list(range(len(frames)))
        rewards = []
        depths = []
        hps = []
        positions_x = []
        positions_y = []

        cumulative_reward = 0
        for frame in frames:
            cumulative_reward += frame["reward"]
            rewards.append(cumulative_reward)

            stats = frame["observation"].get("player_stats", {})
            depths.append(stats.get("depth", 1))
            hps.append(stats.get("hp", 0))
            positions_x.append(stats.get("x", 0))
            positions_y.append(stats.get("y", 0))

        fig, axes = plt.subplots(2, 2, figsize=(12, 10))

        # (axis, series, line style, y label, title) for the three line plots.
        line_plots = [
            (axes[0, 0], rewards, "b-", "Cumulative Reward", "Reward Progress"),
            (axes[0, 1], depths, "g-", "Dungeon Depth", "Depth Exploration"),
            (axes[1, 0], hps, "r-", "Hit Points", "Health Over Time"),
        ]
        for axis, series, style, y_label, title in line_plots:
            axis.plot(steps, series, style)
            axis.set_xlabel("Step")
            axis.set_ylabel(y_label)
            axis.set_title(title)
            axis.grid(True)

        # Visited positions, colored by step index.
        axes[1, 1].scatter(positions_x, positions_y, c=steps, cmap="viridis", s=1)
        axes[1, 1].set_xlabel("X Position")
        axes[1, 1].set_ylabel("Y Position")
        axes[1, 1].set_title("Movement Pattern")
        axes[1, 1].invert_yaxis()  # Map rows grow downward

        fig.tight_layout()
        self._save_or_show(fig, output_path)

    def plot_action_distribution(
        self, frames: List[Dict[str, Any]], output_path: Optional[str] = None
    ):
        """Plot a bar chart of the 20 most frequent actions.

        Does nothing when *frames* is empty.  The figure is saved to
        *output_path* when given, otherwise shown interactively.
        """
        if not frames:
            return

        from collections import Counter

        top_actions = Counter(frame["action"] for frame in frames).most_common(20)
        # Stringify so non-string actions (e.g. ints) are valid tick labels.
        labels = [str(action) for action, _ in top_actions]
        counts = [count for _, count in top_actions]

        fig, ax = plt.subplots(figsize=(12, 6))
        ticks = range(len(labels))
        ax.bar(ticks, counts)
        ax.set_xticks(ticks)
        ax.set_xticklabels(labels, rotation=45, ha="right")
        ax.set_xlabel("Action")
        ax.set_ylabel("Count")
        ax.set_title("Action Distribution (Top 20)")
        fig.tight_layout()

        self._save_or_show(fig, output_path)