synth-ai 0.2.4.dev4__py3-none-any.whl → 0.2.4.dev5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (104)
  1. synth_ai/environments/examples/__init__.py +1 -0
  2. synth_ai/environments/examples/crafter_classic/__init__.py +8 -0
  3. synth_ai/environments/examples/crafter_classic/config_logging.py +111 -0
  4. synth_ai/environments/examples/crafter_classic/debug_translation.py +0 -0
  5. synth_ai/environments/examples/crafter_classic/engine.py +575 -0
  6. synth_ai/environments/examples/crafter_classic/engine_deterministic_patch.py +63 -0
  7. synth_ai/environments/examples/crafter_classic/engine_helpers/action_map.py +5 -0
  8. synth_ai/environments/examples/crafter_classic/engine_helpers/serialization.py +74 -0
  9. synth_ai/environments/examples/crafter_classic/engine_serialization_patch_v3.py +266 -0
  10. synth_ai/environments/examples/crafter_classic/environment.py +364 -0
  11. synth_ai/environments/examples/crafter_classic/taskset.py +233 -0
  12. synth_ai/environments/examples/crafter_classic/trace_hooks_v3.py +229 -0
  13. synth_ai/environments/examples/crafter_classic/world_config_patch_simple.py +298 -0
  14. synth_ai/environments/examples/crafter_custom/__init__.py +4 -0
  15. synth_ai/environments/examples/crafter_custom/crafter/__init__.py +7 -0
  16. synth_ai/environments/examples/crafter_custom/crafter/config.py +182 -0
  17. synth_ai/environments/examples/crafter_custom/crafter/constants.py +8 -0
  18. synth_ai/environments/examples/crafter_custom/crafter/engine.py +269 -0
  19. synth_ai/environments/examples/crafter_custom/crafter/env.py +266 -0
  20. synth_ai/environments/examples/crafter_custom/crafter/objects.py +418 -0
  21. synth_ai/environments/examples/crafter_custom/crafter/recorder.py +187 -0
  22. synth_ai/environments/examples/crafter_custom/crafter/worldgen.py +119 -0
  23. synth_ai/environments/examples/crafter_custom/dataset_builder.py +373 -0
  24. synth_ai/environments/examples/crafter_custom/environment.py +312 -0
  25. synth_ai/environments/examples/crafter_custom/run_dataset.py +305 -0
  26. synth_ai/environments/examples/enron/art_helpers/email_search_tools.py +156 -0
  27. synth_ai/environments/examples/enron/art_helpers/local_email_db.py +280 -0
  28. synth_ai/environments/examples/enron/art_helpers/types_enron.py +24 -0
  29. synth_ai/environments/examples/enron/engine.py +291 -0
  30. synth_ai/environments/examples/enron/environment.py +165 -0
  31. synth_ai/environments/examples/enron/taskset.py +112 -0
  32. synth_ai/environments/examples/minigrid/__init__.py +48 -0
  33. synth_ai/environments/examples/minigrid/engine.py +589 -0
  34. synth_ai/environments/examples/minigrid/environment.py +274 -0
  35. synth_ai/environments/examples/minigrid/environment_mapping.py +242 -0
  36. synth_ai/environments/examples/minigrid/puzzle_loader.py +416 -0
  37. synth_ai/environments/examples/minigrid/taskset.py +583 -0
  38. synth_ai/environments/examples/nethack/__init__.py +7 -0
  39. synth_ai/environments/examples/nethack/achievements.py +337 -0
  40. synth_ai/environments/examples/nethack/engine.py +738 -0
  41. synth_ai/environments/examples/nethack/environment.py +255 -0
  42. synth_ai/environments/examples/nethack/helpers/__init__.py +42 -0
  43. synth_ai/environments/examples/nethack/helpers/action_mapping.py +301 -0
  44. synth_ai/environments/examples/nethack/helpers/nle_wrapper.py +401 -0
  45. synth_ai/environments/examples/nethack/helpers/observation_utils.py +433 -0
  46. synth_ai/environments/examples/nethack/helpers/recording_wrapper.py +201 -0
  47. synth_ai/environments/examples/nethack/helpers/trajectory_recorder.py +268 -0
  48. synth_ai/environments/examples/nethack/helpers/visualization/replay_viewer.py +308 -0
  49. synth_ai/environments/examples/nethack/helpers/visualization/visualizer.py +430 -0
  50. synth_ai/environments/examples/nethack/taskset.py +323 -0
  51. synth_ai/environments/examples/red/__init__.py +7 -0
  52. synth_ai/environments/examples/red/config_logging.py +110 -0
  53. synth_ai/environments/examples/red/engine.py +693 -0
  54. synth_ai/environments/examples/red/engine_helpers/__init__.py +1 -0
  55. synth_ai/environments/examples/red/engine_helpers/memory_map.py +28 -0
  56. synth_ai/environments/examples/red/engine_helpers/reward_components.py +275 -0
  57. synth_ai/environments/examples/red/engine_helpers/reward_library/__init__.py +142 -0
  58. synth_ai/environments/examples/red/engine_helpers/reward_library/adaptive_rewards.py +56 -0
  59. synth_ai/environments/examples/red/engine_helpers/reward_library/battle_rewards.py +283 -0
  60. synth_ai/environments/examples/red/engine_helpers/reward_library/composite_rewards.py +149 -0
  61. synth_ai/environments/examples/red/engine_helpers/reward_library/economy_rewards.py +137 -0
  62. synth_ai/environments/examples/red/engine_helpers/reward_library/efficiency_rewards.py +56 -0
  63. synth_ai/environments/examples/red/engine_helpers/reward_library/exploration_rewards.py +330 -0
  64. synth_ai/environments/examples/red/engine_helpers/reward_library/novelty_rewards.py +120 -0
  65. synth_ai/environments/examples/red/engine_helpers/reward_library/pallet_town_rewards.py +558 -0
  66. synth_ai/environments/examples/red/engine_helpers/reward_library/pokemon_rewards.py +312 -0
  67. synth_ai/environments/examples/red/engine_helpers/reward_library/social_rewards.py +147 -0
  68. synth_ai/environments/examples/red/engine_helpers/reward_library/story_rewards.py +246 -0
  69. synth_ai/environments/examples/red/engine_helpers/screen_analysis.py +367 -0
  70. synth_ai/environments/examples/red/engine_helpers/state_extraction.py +139 -0
  71. synth_ai/environments/examples/red/environment.py +235 -0
  72. synth_ai/environments/examples/red/taskset.py +77 -0
  73. synth_ai/environments/examples/sokoban/__init__.py +1 -0
  74. synth_ai/environments/examples/sokoban/engine.py +675 -0
  75. synth_ai/environments/examples/sokoban/engine_helpers/__init__.py +1 -0
  76. synth_ai/environments/examples/sokoban/engine_helpers/room_utils.py +656 -0
  77. synth_ai/environments/examples/sokoban/engine_helpers/vendored/__init__.py +17 -0
  78. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/__init__.py +3 -0
  79. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/boxoban_env.py +129 -0
  80. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/render_utils.py +370 -0
  81. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/room_utils.py +331 -0
  82. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env.py +305 -0
  83. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_fixed_targets.py +66 -0
  84. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_pull.py +114 -0
  85. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_two_player.py +122 -0
  86. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_variations.py +394 -0
  87. synth_ai/environments/examples/sokoban/environment.py +228 -0
  88. synth_ai/environments/examples/sokoban/generate_verified_puzzles.py +438 -0
  89. synth_ai/environments/examples/sokoban/puzzle_loader.py +311 -0
  90. synth_ai/environments/examples/sokoban/taskset.py +425 -0
  91. synth_ai/environments/examples/tictactoe/__init__.py +1 -0
  92. synth_ai/environments/examples/tictactoe/engine.py +368 -0
  93. synth_ai/environments/examples/tictactoe/environment.py +239 -0
  94. synth_ai/environments/examples/tictactoe/taskset.py +214 -0
  95. synth_ai/environments/examples/verilog/__init__.py +10 -0
  96. synth_ai/environments/examples/verilog/engine.py +328 -0
  97. synth_ai/environments/examples/verilog/environment.py +349 -0
  98. synth_ai/environments/examples/verilog/taskset.py +418 -0
  99. {synth_ai-0.2.4.dev4.dist-info → synth_ai-0.2.4.dev5.dist-info}/METADATA +1 -1
  100. {synth_ai-0.2.4.dev4.dist-info → synth_ai-0.2.4.dev5.dist-info}/RECORD +104 -6
  101. {synth_ai-0.2.4.dev4.dist-info → synth_ai-0.2.4.dev5.dist-info}/WHEEL +0 -0
  102. {synth_ai-0.2.4.dev4.dist-info → synth_ai-0.2.4.dev5.dist-info}/entry_points.txt +0 -0
  103. {synth_ai-0.2.4.dev4.dist-info → synth_ai-0.2.4.dev5.dist-info}/licenses/LICENSE +0 -0
  104. {synth_ai-0.2.4.dev4.dist-info → synth_ai-0.2.4.dev5.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,433 @@
1
+ """Observation processing utilities for NetHack."""
2
+
3
+ from typing import Dict, Any, List, Optional, Tuple
4
+ import re
5
+
6
+
7
def format_observation_for_llm(observation: Dict[str, Any]) -> str:
    """
    Render a NetHack observation dict as a plain-text block for an LLM prompt.

    Sections are emitted in a fixed order: header (turn / dungeon level),
    map, message, character status, inventory, menu options, progress, and
    finally a game-over banner. Every key is optional; missing keys either
    skip their section or fall back to a default value.

    Args:
        observation: Raw observation dictionary (keys such as ``turn_count``,
            ``dungeon_level``, ``ascii_map``, ``message``, ``character_stats``,
            ``inventory_summary``, ``in_menu``, ``menu_items``, ``score``,
            ``total_reward``, ``reward_last``, ``terminated``).

    Returns:
        The formatted sections joined with newlines.
    """
    get = observation.get
    out = [
        f"=== NetHack - Turn {get('turn_count', 0)} ===",
        f"Dungeon Level: {get('dungeon_level', 1)}",
    ]

    if "ascii_map" in observation:
        out += ["\n--- Dungeon Map ---", observation["ascii_map"]]

    raw_message = get("message")
    if raw_message:
        # NetHack message buffers are NUL-padded; drop the padding and whitespace.
        cleaned = raw_message.rstrip("\x00").strip()
        if cleaned:
            out.append(f"\nMessage: {cleaned}")

    if "character_stats" in observation:
        stats = observation["character_stats"]
        out.append("\n--- Character Status ---")
        out.append(f"HP: {stats.get('hp', 0)}/{stats.get('max_hp', 0)}")
        out.append(f"Level: {stats.get('level', 1)} (Exp: {stats.get('experience', 0)})")
        out.append(f"AC: {stats.get('ac', 10)}, Gold: {stats.get('gold', 0)}")

        attribute_names = (
            "strength",
            "dexterity",
            "constitution",
            "intelligence",
            "wisdom",
            "charisma",
        )
        # Abbreviate each present attribute to its first three letters (STR, DEX, ...).
        shown = [f"{name[:3].upper()}:{stats[name]}" for name in attribute_names if name in stats]
        if shown:
            out.append(f"Attributes: {' '.join(shown)}")

    if "inventory_summary" in observation:
        out += ["\n--- Inventory ---", observation["inventory_summary"]]

    if get("in_menu", False) and "menu_items" in observation:
        out.append("\n--- Menu Options ---")
        for idx, entry in enumerate(observation["menu_items"]):
            # The first 26 entries get letters a-z; later ones restart numbering at 0.
            label = chr(ord("a") + idx) if idx < 26 else str(idx - 26)
            out.append(f"{label}) {entry}")

    out.extend(
        [
            "\n--- Progress ---",
            f"Score: {get('score', 0)}",
            f"Total Reward: {get('total_reward', 0.0):.2f}",
            f"Last Reward: {get('reward_last', 0.0):.2f}",
        ]
    )

    if get("terminated", False):
        out.append("\n*** GAME OVER ***")

    return "\n".join(out)
83
+
84
+
85
+ def parse_ascii_map(ascii_map: str) -> Dict[str, Any]:
86
+ """
87
+ Parse ASCII map to extract key information.
88
+
89
+ Args:
90
+ ascii_map: Raw ASCII map string
91
+
92
+ Returns:
93
+ Dictionary with extracted map information
94
+ """
95
+ lines = ascii_map.strip().split("\n")
96
+
97
+ map_info = {
98
+ "width": 0,
99
+ "height": len(lines),
100
+ "player_position": None,
101
+ "stairs_positions": [],
102
+ "door_positions": [],
103
+ "item_positions": [],
104
+ "monster_positions": [],
105
+ "wall_positions": [],
106
+ "floor_positions": [],
107
+ }
108
+
109
+ if lines:
110
+ map_info["width"] = max(len(line) for line in lines)
111
+
112
+ # Common NetHack ASCII symbols
113
+ symbols = {
114
+ "@": "player",
115
+ "<": "stairs_up",
116
+ ">": "stairs_down",
117
+ "+": "closed_door",
118
+ "-": "open_door_horizontal",
119
+ "|": "open_door_vertical",
120
+ ".": "floor",
121
+ "#": "corridor",
122
+ " ": "wall",
123
+ "$": "gold",
124
+ "*": "gem",
125
+ "!": "potion",
126
+ "?": "scroll",
127
+ "/": "wand",
128
+ "=": "ring",
129
+ '"': "amulet",
130
+ "[": "armor",
131
+ ")": "weapon",
132
+ "(": "tool",
133
+ "%": "food",
134
+ "^": "trap",
135
+ }
136
+
137
+ # Monster symbols (letters)
138
+ for y, line in enumerate(lines):
139
+ for x, char in enumerate(line):
140
+ pos = (x, y)
141
+
142
+ if char == "@":
143
+ map_info["player_position"] = pos
144
+ elif char == "<":
145
+ map_info["stairs_positions"].append(("up", pos))
146
+ elif char == ">":
147
+ map_info["stairs_positions"].append(("down", pos))
148
+ elif char in ["+", "-", "|"]:
149
+ map_info["door_positions"].append((char, pos))
150
+ elif char in ["$", "*", "!", "?", "/", "=", '"', "[", ")", "(", "%"]:
151
+ map_info["item_positions"].append((char, pos))
152
+ elif char.isalpha() and char != "@":
153
+ # Store both the character and position for monster identification
154
+ map_info["monster_positions"].append((char, pos))
155
+ elif char == ".":
156
+ map_info["floor_positions"].append(pos)
157
+ elif char in ["#", " "]:
158
+ map_info["wall_positions"].append(pos)
159
+
160
+ return map_info
161
+
162
+
163
+ def extract_game_context(observation: Dict[str, Any]) -> Dict[str, Any]:
164
+ """
165
+ Extract high-level game context from observation.
166
+
167
+ Args:
168
+ observation: Raw observation dictionary
169
+
170
+ Returns:
171
+ Dictionary with game context information
172
+ """
173
+ context = {
174
+ "in_combat": False,
175
+ "in_shop": False,
176
+ "at_stairs": False,
177
+ "items_nearby": False,
178
+ "doors_nearby": False,
179
+ "low_health": False,
180
+ "hungry": False,
181
+ "encumbered": False,
182
+ "in_menu": observation.get("in_menu", False),
183
+ "game_over": observation.get("terminated", False),
184
+ }
185
+
186
+ # Parse map for context
187
+ if "ascii_map" in observation:
188
+ map_info = parse_ascii_map(observation["ascii_map"])
189
+
190
+ # Check if player is near important features
191
+ if map_info["player_position"]:
192
+ px, py = map_info["player_position"]
193
+
194
+ # Check for nearby monsters (within 2 squares)
195
+ # Exclude pets from combat detection
196
+ pet_symbols = ["f", "d"] # f = kitten, d = dog
197
+ for monster_char, (mx, my) in map_info["monster_positions"]:
198
+ if abs(mx - px) <= 2 and abs(my - py) <= 2:
199
+ # Only trigger combat for non-pet monsters
200
+ if monster_char not in pet_symbols:
201
+ context["in_combat"] = True
202
+ break
203
+
204
+ # Check for stairs
205
+ for stair_type, (sx, sy) in map_info["stairs_positions"]:
206
+ if sx == px and sy == py:
207
+ context["at_stairs"] = True
208
+ context["stairs_type"] = stair_type
209
+ break
210
+
211
+ # Check for nearby items
212
+ for _, (ix, iy) in map_info["item_positions"]:
213
+ if abs(ix - px) <= 1 and abs(iy - py) <= 1:
214
+ context["items_nearby"] = True
215
+ break
216
+
217
+ # Check for nearby doors
218
+ for _, (dx, dy) in map_info["door_positions"]:
219
+ if abs(dx - px) <= 1 and abs(dy - py) <= 1:
220
+ context["doors_nearby"] = True
221
+ break
222
+
223
+ # Check health status
224
+ if "character_stats" in observation:
225
+ stats = observation["character_stats"]
226
+ hp = stats.get("hp", 0)
227
+ max_hp = stats.get("max_hp", 1)
228
+ if hp < max_hp * 0.3:
229
+ context["low_health"] = True
230
+
231
+ # Check for specific messages
232
+ message = observation.get("message", "").lower()
233
+ if "hungry" in message or "weak" in message:
234
+ context["hungry"] = True
235
+ if "burdened" in message or "stressed" in message:
236
+ context["encumbered"] = True
237
+ if "shop" in message or "shopkeeper" in message:
238
+ context["in_shop"] = True
239
+
240
+ return context
241
+
242
+
243
def simplify_observation(observation: Dict[str, Any]) -> Dict[str, Any]:
    """
    Reduce a full observation to a small flat summary for low-detail agents.

    Args:
        observation: Full observation dictionary.

    Returns:
        Dictionary with ``turn``, ``level``, ``hp``, ``max_hp``, ``message``,
        ``terminated`` and ``reward``. When ``ascii_map`` is present it also
        contains ``player_pos``, ``monsters_nearby``, ``items_nearby`` and
        ``at_stairs``. Note that ``monsters_nearby`` / ``items_nearby`` are
        counts over the whole parsed map, not within a radius of the player.
    """
    # A missing stats entry yields 0 for both HP fields.
    stats = observation.get("character_stats", {})
    summary = {
        "turn": observation.get("turn_count", 0),
        "level": observation.get("dungeon_level", 1),
        "hp": stats.get("hp", 0),
        "max_hp": stats.get("max_hp", 0),
        "message": observation.get("message", ""),
        "terminated": observation.get("terminated", False),
        "reward": observation.get("reward_last", 0.0),
    }

    if "ascii_map" not in observation:
        return summary

    parsed = parse_ascii_map(observation["ascii_map"])
    player = parsed["player_position"]
    summary.update(
        player_pos=player,
        monsters_nearby=len(parsed["monster_positions"]),
        items_nearby=len(parsed["item_positions"]),
        at_stairs=any(spot == player for _, spot in parsed["stairs_positions"]),
    )
    return summary
280
+
281
+
282
def extract_inventory_from_message(message: str) -> List[Dict[str, Any]]:
    """
    Parse NetHack inventory lines such as ``"a - a blessed +1 long sword (weapon in hand)"``.

    Only lines of the form ``<letter> - <description> [(<status>)]`` are
    recognised; every other line (headers, blank lines) is skipped.

    Args:
        message: Inventory text, one item per line.

    Returns:
        One dict per recognised line with the keys ``letter``, ``description``,
        ``status`` (text in the trailing parentheses, or ``None``),
        ``quantity`` (a leading integer, defaulting to 1), ``name`` (the
        description without that leading integer) and ``type`` (from
        ``identify_item_type``).
    """
    # Letter slot, " - ", lazily-matched description, optional "(status)" suffix.
    item_line = re.compile(r"^([a-zA-Z])\s*-\s*(.+?)(?:\s*\(([^)]+)\))?$")
    leading_count = re.compile(r"^(\d+)\s+(.+)")

    parsed_items = []
    for raw_line in message.split("\n"):
        found = item_line.match(raw_line.strip())
        if found is None:
            continue

        letter, description, status = found.groups()
        entry = {
            "letter": letter,
            "description": description.strip(),
            "status": status.strip() if status else None,
        }

        counted = leading_count.match(description)
        if counted:
            entry["quantity"] = int(counted.group(1))
            entry["name"] = counted.group(2)
        else:
            entry["quantity"] = 1
            entry["name"] = description

        entry["type"] = identify_item_type(description)
        parsed_items.append(entry)

    return parsed_items
325
+
326
+
327
def identify_item_type(description: str) -> str:
    """
    Classify an inventory item by keyword-matching its description.

    Categories are tried in a fixed priority order. The first category with
    any keyword occurring as a case-insensitive substring wins. For example,
    "ring mail" is "armor" because armor is tested before rings, and
    "pick-axe" is "weapon" (via "axe") because weapons are tested before tools.

    Args:
        description: Item description string, e.g. "a blessed +1 long sword".

    Returns:
        One of "weapon", "armor", "food", "potion", "scroll", "wand", "ring",
        "amulet", "tool", "gem", "gold", or "unknown" when nothing matches.
    """
    # Ordered (category, keywords) pairs. The order matters because matching
    # is by plain substring and some keywords overlap across categories.
    keyword_table = (
        ("weapon", ("sword", "dagger", "spear", "axe", "mace", "bow", "arrow", "dart", "knife")),
        ("armor", ("armor", "mail", "helmet", "boots", "gloves", "shield", "cloak", "robe")),
        ("food", ("food", "ration", "corpse", "egg", "fruit", "meat", "candy", "cookie")),
        ("potion", ("potion",)),
        ("scroll", ("scroll",)),
        ("wand", ("wand",)),
        ("ring", ("ring",)),
        ("amulet", ("amulet",)),
        ("tool", ("pick", "key", "lamp", "candle", "bag", "sack", "horn", "whistle", "mirror")),
        ("gem", ("gem", "stone", "rock", "crystal")),
        ("gold", ("gold",)),
    )

    lowered = description.lower()
    return next(
        (
            category
            for category, keywords in keyword_table
            if any(keyword in lowered for keyword in keywords)
        ),
        "unknown",
    )
@@ -0,0 +1,201 @@
1
+ """Environment wrapper that adds trajectory recording capabilities."""
2
+
3
+ from typing import Dict, Any, Optional, Tuple
4
+ from pathlib import Path
5
+ import logging
6
+
7
+ from src.synth_env.examples.nethack.environment import NetHackEnvironment
8
+ from src.synth_env.examples.nethack.helpers.trajectory_recorder import (
9
+ TrajectoryRecorder,
10
+ )
11
+ from src.synth_env.environment.tools import EnvToolCall
12
+
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
+
17
class RecordingNetHackEnvironment(NetHackEnvironment):
    """NetHack environment with automatic trajectory recording.

    Wraps ``NetHackEnvironment`` and forwards each step to a
    ``TrajectoryRecorder``. With ``auto_record=True`` a new trajectory is
    started on every ``start()`` and saved once an action reports
    ``terminated``. Recording can also be driven manually through
    ``start_recording()`` / ``stop_recording()``.

    NOTE(review): this module imports from ``src.synth_env...``; confirm that
    path resolves in the installed ``synth_ai`` package layout.
    """

    def __init__(
        self,
        save_dir: str = "temp/nethack_trajectories",
        auto_record: bool = True,
        **kwargs,
    ):
        """Initialize recording wrapper.

        Args:
            save_dir: Directory to save trajectories (created if missing).
            auto_record: Whether to automatically record all episodes.
            **kwargs: Arguments passed to NetHackEnvironment.
        """
        super().__init__(**kwargs)

        self.save_dir = Path(save_dir)
        self.save_dir.mkdir(parents=True, exist_ok=True)

        self.auto_record = auto_record
        self.recorder = TrajectoryRecorder(save_dir)
        self.is_recording = False
        self.trajectory_id = None

    async def start(self, **kwargs) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """Start the environment and, if ``auto_record`` is set, begin recording.

        The initial state is recorded as a ``"reset"`` step with zero reward.

        Returns:
            The ``(public_state, private_state)`` pair from the base environment.
        """
        public_state, private_state = await super().start(**kwargs)

        if self.auto_record:
            # Prefer the role from task metadata; otherwise use a generic label.
            character_role = "adventurer"
            if hasattr(self.task_instance, "metadata") and hasattr(
                self.task_instance.metadata, "character_role"
            ):
                character_role = self.task_instance.metadata.character_role

            task_id = self.task_instance.task_id if self.task_instance else None
            self.trajectory_id = self.recorder.start_recording(character_role, task_id)
            self.is_recording = True

            logger.info(f"Started recording trajectory: {self.trajectory_id}")

            # Record initial state
            obs = self._extract_observation(public_state)
            self.recorder.record_step("reset", obs, 0.0, False, {})

        return public_state, private_state

    async def process_action(
        self, tool_calls: list[EnvToolCall]
    ) -> Tuple[Any, float, bool, Dict[str, Any]]:
        """Run the action through the base environment and record the step.

        If the step reports ``terminated`` while recording, the trajectory is
        finalized and saved.

        Returns:
            The ``(observation, reward, terminated, info)`` tuple from the base
            environment, unchanged.
        """
        observation, reward, terminated, info = await super().process_action(tool_calls)

        # Record only while both this wrapper and the recorder report an
        # active recording.
        if self.is_recording and self.recorder.is_recording:
            # Only the first tool call's "action" argument is recorded.
            action = "unknown"
            if tool_calls:
                if hasattr(tool_calls[0], "args") and "action" in tool_calls[0].args:
                    action = tool_calls[0].args["action"]

            obs_data = self._extract_observation(observation)
            self.recorder.record_step(action, obs_data, reward, terminated, info)

        if terminated and self.is_recording:
            await self._finalize_recording(observation)

        return observation, reward, terminated, info

    def _extract_observation(self, state: Any) -> Dict[str, Any]:
        """Build a plain dict of the recorded fields from a state object.

        Objects with a ``__dict__`` are read attribute by attribute. Position
        and dungeon level are folded into ``player_stats``. A plain dict is
        returned as-is, and anything else becomes ``{}``.
        """
        import copy

        if hasattr(state, "__dict__"):
            obs = {}

            if hasattr(state, "ascii_map"):
                obs["ascii_map"] = state.ascii_map
            if hasattr(state, "message"):
                obs["message"] = state.message
            if hasattr(state, "character_stats"):
                # Shallow-copy: x/y/depth are written into player_stats below
                # and must not leak back into the live state's stats object.
                obs["player_stats"] = copy.copy(state.character_stats)
            if hasattr(state, "inventory"):
                obs["inventory"] = state.inventory
            if hasattr(state, "position"):
                # A missing or None stats entry starts as an empty dict.
                obs["player_stats"] = obs.get("player_stats") or {}
                obs["player_stats"]["x"] = state.position[0]
                obs["player_stats"]["y"] = state.position[1]
            if hasattr(state, "dungeon_level"):
                obs["player_stats"] = obs.get("player_stats") or {}
                obs["player_stats"]["depth"] = state.dungeon_level
            if hasattr(state, "in_menu"):
                obs["in_menu"] = state.in_menu
            if hasattr(state, "menu_items"):
                obs["menu_items"] = state.menu_items

            return obs

        # If it's already a dict, return as is
        return state if isinstance(state, dict) else {}

    async def _finalize_recording(self, final_state: Any):
        """Stop the active recording, save it, and log a summary.

        The final status is guessed from the last message: "died", "quit",
        "truncated", or otherwise "completed". Does nothing when no recording
        is active.
        """
        if not self.is_recording:
            return

        # Accept both attribute-style states and plain dicts (mirroring
        # _extract_observation), and tolerate a None message.
        if isinstance(final_state, dict):
            raw_message = final_state.get("message")
        else:
            raw_message = getattr(final_state, "message", None)
        msg = (raw_message or "").lower()

        final_status = "completed"
        if "die" in msg or "killed" in msg:
            final_status = "died"
        elif "quit" in msg:
            final_status = "quit"
        elif "time limit" in msg or "truncated" in msg:
            final_status = "truncated"

        # Extract achievements if available
        achievements = {}
        # TODO: Extract achievements from game state

        self.recorder.stop_recording(final_status, achievements)

        filepath = self.recorder.save_trajectory()
        logger.info(f"Saved trajectory to: {filepath}")

        summary = self.recorder.get_summary()
        logger.info(f"Trajectory summary: {summary}")

        self.is_recording = False

    def start_recording(
        self, character_role: Optional[str] = None, task_id: Optional[str] = None
    ) -> str:
        """Manually start recording.

        If a recording is already active, warns and returns its id. When no
        role is given, ``self.engine.character_role`` is used if available,
        otherwise "adventurer".

        Returns:
            The trajectory id.
        """
        if self.is_recording:
            logger.warning("Recording already in progress")
            return self.trajectory_id

        if character_role is None:
            character_role = "adventurer"
            if hasattr(self, "engine") and hasattr(self.engine, "character_role"):
                character_role = self.engine.character_role

        self.trajectory_id = self.recorder.start_recording(character_role, task_id)
        self.is_recording = True
        logger.info(f"Started manual recording: {self.trajectory_id}")

        return self.trajectory_id

    def stop_recording(self, save: bool = True) -> Optional[str]:
        """Manually stop recording.

        Args:
            save: Whether to save the trajectory to disk.

        Returns:
            The saved file path, or ``None`` if not saved or nothing was recording.
        """
        if not self.is_recording:
            logger.warning("No recording in progress")
            return None

        self.recorder.stop_recording()

        filepath = None
        if save:
            filepath = self.recorder.save_trajectory()
            logger.info(f"Saved trajectory to: {filepath}")

        self.is_recording = False
        return filepath

    def get_recording_status(self) -> Dict[str, Any]:
        """Return the recording flag, trajectory id, and current step/reward (0 when idle)."""
        return {
            "is_recording": self.is_recording,
            "trajectory_id": self.trajectory_id,
            "current_step": self.recorder.current_step if self.is_recording else 0,
            "total_reward": self.recorder.total_reward if self.is_recording else 0.0,
        }
+ }