synth-ai 0.2.4.dev4__py3-none-any.whl → 0.2.4.dev6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- synth_ai/environments/examples/__init__.py +1 -0
- synth_ai/environments/examples/crafter_classic/__init__.py +8 -0
- synth_ai/environments/examples/crafter_classic/config_logging.py +111 -0
- synth_ai/environments/examples/crafter_classic/debug_translation.py +0 -0
- synth_ai/environments/examples/crafter_classic/engine.py +579 -0
- synth_ai/environments/examples/crafter_classic/engine_deterministic_patch.py +63 -0
- synth_ai/environments/examples/crafter_classic/engine_helpers/action_map.py +5 -0
- synth_ai/environments/examples/crafter_classic/engine_helpers/serialization.py +74 -0
- synth_ai/environments/examples/crafter_classic/engine_serialization_patch_v3.py +266 -0
- synth_ai/environments/examples/crafter_classic/environment.py +364 -0
- synth_ai/environments/examples/crafter_classic/taskset.py +233 -0
- synth_ai/environments/examples/crafter_classic/trace_hooks_v3.py +229 -0
- synth_ai/environments/examples/crafter_classic/world_config_patch_simple.py +298 -0
- synth_ai/environments/examples/crafter_custom/__init__.py +4 -0
- synth_ai/environments/examples/crafter_custom/crafter/__init__.py +7 -0
- synth_ai/environments/examples/crafter_custom/crafter/config.py +182 -0
- synth_ai/environments/examples/crafter_custom/crafter/constants.py +8 -0
- synth_ai/environments/examples/crafter_custom/crafter/engine.py +269 -0
- synth_ai/environments/examples/crafter_custom/crafter/env.py +266 -0
- synth_ai/environments/examples/crafter_custom/crafter/objects.py +418 -0
- synth_ai/environments/examples/crafter_custom/crafter/recorder.py +187 -0
- synth_ai/environments/examples/crafter_custom/crafter/worldgen.py +119 -0
- synth_ai/environments/examples/crafter_custom/dataset_builder.py +373 -0
- synth_ai/environments/examples/crafter_custom/environment.py +312 -0
- synth_ai/environments/examples/crafter_custom/run_dataset.py +305 -0
- synth_ai/environments/examples/enron/art_helpers/email_search_tools.py +156 -0
- synth_ai/environments/examples/enron/art_helpers/local_email_db.py +280 -0
- synth_ai/environments/examples/enron/art_helpers/types_enron.py +24 -0
- synth_ai/environments/examples/enron/engine.py +291 -0
- synth_ai/environments/examples/enron/environment.py +165 -0
- synth_ai/environments/examples/enron/taskset.py +112 -0
- synth_ai/environments/examples/minigrid/__init__.py +48 -0
- synth_ai/environments/examples/minigrid/engine.py +589 -0
- synth_ai/environments/examples/minigrid/environment.py +274 -0
- synth_ai/environments/examples/minigrid/environment_mapping.py +242 -0
- synth_ai/environments/examples/minigrid/puzzle_loader.py +416 -0
- synth_ai/environments/examples/minigrid/taskset.py +583 -0
- synth_ai/environments/examples/nethack/__init__.py +7 -0
- synth_ai/environments/examples/nethack/achievements.py +337 -0
- synth_ai/environments/examples/nethack/engine.py +738 -0
- synth_ai/environments/examples/nethack/environment.py +255 -0
- synth_ai/environments/examples/nethack/helpers/__init__.py +42 -0
- synth_ai/environments/examples/nethack/helpers/action_mapping.py +301 -0
- synth_ai/environments/examples/nethack/helpers/nle_wrapper.py +401 -0
- synth_ai/environments/examples/nethack/helpers/observation_utils.py +433 -0
- synth_ai/environments/examples/nethack/helpers/recording_wrapper.py +201 -0
- synth_ai/environments/examples/nethack/helpers/trajectory_recorder.py +268 -0
- synth_ai/environments/examples/nethack/helpers/visualization/replay_viewer.py +308 -0
- synth_ai/environments/examples/nethack/helpers/visualization/visualizer.py +430 -0
- synth_ai/environments/examples/nethack/taskset.py +323 -0
- synth_ai/environments/examples/red/__init__.py +7 -0
- synth_ai/environments/examples/red/config_logging.py +110 -0
- synth_ai/environments/examples/red/engine.py +693 -0
- synth_ai/environments/examples/red/engine_helpers/__init__.py +1 -0
- synth_ai/environments/examples/red/engine_helpers/memory_map.py +28 -0
- synth_ai/environments/examples/red/engine_helpers/reward_components.py +275 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/__init__.py +142 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/adaptive_rewards.py +56 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/battle_rewards.py +283 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/composite_rewards.py +149 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/economy_rewards.py +137 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/efficiency_rewards.py +56 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/exploration_rewards.py +330 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/novelty_rewards.py +120 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/pallet_town_rewards.py +558 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/pokemon_rewards.py +312 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/social_rewards.py +147 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/story_rewards.py +246 -0
- synth_ai/environments/examples/red/engine_helpers/screen_analysis.py +367 -0
- synth_ai/environments/examples/red/engine_helpers/state_extraction.py +139 -0
- synth_ai/environments/examples/red/environment.py +235 -0
- synth_ai/environments/examples/red/taskset.py +77 -0
- synth_ai/environments/examples/sokoban/__init__.py +1 -0
- synth_ai/environments/examples/sokoban/engine.py +675 -0
- synth_ai/environments/examples/sokoban/engine_helpers/__init__.py +1 -0
- synth_ai/environments/examples/sokoban/engine_helpers/room_utils.py +656 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/__init__.py +17 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/__init__.py +3 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/boxoban_env.py +129 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/render_utils.py +370 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/room_utils.py +331 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env.py +305 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_fixed_targets.py +66 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_pull.py +114 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_two_player.py +122 -0
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_variations.py +394 -0
- synth_ai/environments/examples/sokoban/environment.py +228 -0
- synth_ai/environments/examples/sokoban/generate_verified_puzzles.py +438 -0
- synth_ai/environments/examples/sokoban/puzzle_loader.py +311 -0
- synth_ai/environments/examples/sokoban/taskset.py +425 -0
- synth_ai/environments/examples/tictactoe/__init__.py +1 -0
- synth_ai/environments/examples/tictactoe/engine.py +368 -0
- synth_ai/environments/examples/tictactoe/environment.py +239 -0
- synth_ai/environments/examples/tictactoe/taskset.py +214 -0
- synth_ai/environments/examples/verilog/__init__.py +10 -0
- synth_ai/environments/examples/verilog/engine.py +328 -0
- synth_ai/environments/examples/verilog/environment.py +349 -0
- synth_ai/environments/examples/verilog/taskset.py +418 -0
- synth_ai/environments/examples/wordle/__init__.py +29 -0
- synth_ai/environments/examples/wordle/engine.py +391 -0
- synth_ai/environments/examples/wordle/environment.py +154 -0
- synth_ai/environments/examples/wordle/helpers/generate_instances_wordfreq.py +75 -0
- synth_ai/environments/examples/wordle/taskset.py +222 -0
- synth_ai/environments/service/app.py +8 -0
- synth_ai/environments/service/core_routes.py +38 -0
- synth_ai/learning/prompts/banking77_injection_eval.py +163 -0
- synth_ai/learning/prompts/hello_world_in_context_injection_ex.py +201 -0
- synth_ai/learning/prompts/mipro.py +273 -1
- synth_ai/learning/prompts/random_search.py +247 -0
- synth_ai/learning/prompts/run_mipro_banking77.py +160 -0
- synth_ai/learning/prompts/run_random_search_banking77.py +305 -0
- synth_ai/lm/injection.py +81 -0
- synth_ai/lm/overrides.py +204 -0
- synth_ai/lm/provider_support/anthropic.py +39 -12
- synth_ai/lm/provider_support/openai.py +31 -4
- synth_ai/lm/vendors/core/anthropic_api.py +16 -0
- synth_ai/lm/vendors/openai_standard.py +35 -5
- {synth_ai-0.2.4.dev4.dist-info → synth_ai-0.2.4.dev6.dist-info}/METADATA +2 -1
- {synth_ai-0.2.4.dev4.dist-info → synth_ai-0.2.4.dev6.dist-info}/RECORD +123 -13
- {synth_ai-0.2.4.dev4.dist-info → synth_ai-0.2.4.dev6.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.4.dev4.dist-info → synth_ai-0.2.4.dev6.dist-info}/entry_points.txt +0 -0
- {synth_ai-0.2.4.dev4.dist-info → synth_ai-0.2.4.dev6.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.2.4.dev4.dist-info → synth_ai-0.2.4.dev6.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,433 @@
|
|
1
|
+
"""Observation processing utilities for NetHack."""
|
2
|
+
|
3
|
+
from typing import Dict, Any, List, Optional, Tuple
|
4
|
+
import re
|
5
|
+
|
6
|
+
|
7
|
+
def format_observation_for_llm(observation: Dict[str, Any]) -> str:
    """
    Format NetHack observation for LLM consumption.

    Sections are emitted in a fixed order: header, dungeon map, game
    message, character status, inventory, menu options, progress and
    (when the episode has ended) a game-over banner. Optional sections
    are skipped when their key is missing from ``observation``.

    Args:
        observation: Raw observation dictionary. Keys read here are
            ``turn_count``, ``dungeon_level``, ``ascii_map``, ``message``,
            ``character_stats``, ``inventory_summary``, ``in_menu``,
            ``menu_items``, ``score``, ``total_reward``, ``reward_last``
            and ``terminated``; all of them are optional.

    Returns:
        Formatted string suitable for LLM input
    """
    # Header with turn count and location
    lines = [
        f"=== NetHack - Turn {observation.get('turn_count', 0)} ===",
        f"Dungeon Level: {observation.get('dungeon_level', 1)}",
    ]

    # ASCII map
    if "ascii_map" in observation:
        lines.append("\n--- Dungeon Map ---")
        lines.append(observation["ascii_map"])

    # Game message
    if observation.get("message"):
        # Drop every NUL byte, then trim whitespace. Stripping only
        # trailing NULs first (as rstrip("\x00").strip() does) leaves NULs
        # behind when padding and whitespace are interleaved.
        message = observation["message"].replace("\x00", "").strip()
        if message:
            lines.append(f"\nMessage: {message}")

    # Character stats
    if "character_stats" in observation:
        stats = observation["character_stats"]
        lines.append("\n--- Character Status ---")
        lines.append(f"HP: {stats.get('hp', 0)}/{stats.get('max_hp', 0)}")
        lines.append(f"Level: {stats.get('level', 1)} (Exp: {stats.get('experience', 0)})")
        lines.append(f"AC: {stats.get('ac', 10)}, Gold: {stats.get('gold', 0)}")

        # Attributes, abbreviated to their first three letters (STR, DEX, ...)
        attribute_names = (
            "strength",
            "dexterity",
            "constitution",
            "intelligence",
            "wisdom",
            "charisma",
        )
        attrs = [
            f"{attr[:3].upper()}:{stats[attr]}" for attr in attribute_names if attr in stats
        ]
        if attrs:
            lines.append(f"Attributes: {' '.join(attrs)}")

    # Inventory summary
    if "inventory_summary" in observation:
        lines.append("\n--- Inventory ---")
        lines.append(observation["inventory_summary"])

    # Menu items if in menu
    if observation.get("in_menu", False) and "menu_items" in observation:
        lines.append("\n--- Menu Options ---")
        for i, item in enumerate(observation["menu_items"]):
            if i < 26:
                # First 26 entries are labelled a) .. z)
                lines.append(f"{chr(ord('a') + i)}) {item}")
            else:
                # Further entries are labelled 0), 1), ...
                lines.append(f"{i - 26}) {item}")

    # Score and rewards
    lines.append("\n--- Progress ---")
    lines.append(f"Score: {observation.get('score', 0)}")
    lines.append(f"Total Reward: {observation.get('total_reward', 0.0):.2f}")
    lines.append(f"Last Reward: {observation.get('reward_last', 0.0):.2f}")

    # Termination status
    if observation.get("terminated", False):
        lines.append("\n*** GAME OVER ***")

    return "\n".join(lines)
|
83
|
+
|
84
|
+
|
85
|
+
def parse_ascii_map(ascii_map: str) -> Dict[str, Any]:
    """
    Parse ASCII map to extract key information.

    Each character is classified once, in this priority order: player
    (``@``), stairs (``<`` / ``>``), doors (``+``, ``-``, ``|``), items,
    monsters (any alphabetic character), floor (``.``) and walls (``#`` or
    a space). Characters matching none of these (for example ``^``) are
    ignored. Positions are ``(x, y)`` tuples, with ``x`` the column and
    ``y`` the row of the character in the map text.

    Args:
        ascii_map: Raw ASCII map string

    Returns:
        Dictionary with extracted map information: ``width``, ``height``,
        ``player_position`` (``None`` when no ``@`` is present; the last
        ``@`` wins if there are several), and the lists
        ``stairs_positions`` (``("up" | "down", pos)``), ``door_positions``
        / ``item_positions`` / ``monster_positions`` (``(char, pos)``),
        ``wall_positions`` and ``floor_positions`` (bare positions).
    """
    # Only strip surrounding newlines. A plain strip() would also eat the
    # leading spaces of the first row (and whole leading rows made of
    # spaces), shifting every coordinate derived from them.
    rows = ascii_map.strip("\n").split("\n")

    map_info: Dict[str, Any] = {
        "width": max((len(row) for row in rows), default=0),
        "height": len(rows),
        "player_position": None,
        "stairs_positions": [],
        "door_positions": [],
        "item_positions": [],
        "monster_positions": [],
        "wall_positions": [],
        "floor_positions": [],
    }

    # NOTE(review): "-" and "|" are recorded as doors here, but NetHack
    # also appears to use these glyphs for room walls - confirm intent.
    door_chars = frozenset("+-|")
    item_chars = frozenset('$*!?/="[)(%')
    wall_chars = frozenset("# ")

    for y, row in enumerate(rows):
        for x, char in enumerate(row):
            pos = (x, y)

            if char == "@":
                map_info["player_position"] = pos
            elif char == "<":
                map_info["stairs_positions"].append(("up", pos))
            elif char == ">":
                map_info["stairs_positions"].append(("down", pos))
            elif char in door_chars:
                map_info["door_positions"].append((char, pos))
            elif char in item_chars:
                map_info["item_positions"].append((char, pos))
            elif char.isalpha():
                # Store both the character and position for monster identification
                map_info["monster_positions"].append((char, pos))
            elif char == ".":
                map_info["floor_positions"].append(pos)
            elif char in wall_chars:
                map_info["wall_positions"].append(pos)

    return map_info
|
161
|
+
|
162
|
+
|
163
|
+
def extract_game_context(observation: Dict[str, Any]) -> Dict[str, Any]:
    """
    Extract high-level game context from observation.

    Flags are derived from three sources: the ASCII map (combat, stairs,
    nearby items and doors), ``character_stats`` (low health) and keyword
    matches in the game message (hunger, encumbrance, shop).

    Args:
        observation: Raw observation dictionary. Reads ``ascii_map``,
            ``character_stats``, ``message``, ``in_menu`` and
            ``terminated``; all are optional.

    Returns:
        Dictionary with game context information. Always contains the
        boolean keys ``in_combat``, ``in_shop``, ``at_stairs``,
        ``items_nearby``, ``doors_nearby``, ``low_health``, ``hungry``,
        ``encumbered``, ``in_menu`` and ``game_over``; ``stairs_type``
        (``"up"`` / ``"down"``) is added only when ``at_stairs`` is set.
    """
    context = {
        "in_combat": False,
        "in_shop": False,
        "at_stairs": False,
        "items_nearby": False,
        "doors_nearby": False,
        "low_health": False,
        "hungry": False,
        "encumbered": False,
        "in_menu": observation.get("in_menu", False),
        "game_over": observation.get("terminated", False),
    }

    # Parse map for context
    if "ascii_map" in observation:
        map_info = parse_ascii_map(observation["ascii_map"])
        player = map_info["player_position"]

        # Check if player is near important features
        if player:
            px, py = player

            def within(pos: Tuple[int, int], radius: int) -> bool:
                """Return True when pos lies within radius squares of the player on both axes."""
                return abs(pos[0] - px) <= radius and abs(pos[1] - py) <= radius

            # Nearby monsters (within 2 squares) mean combat, except pets.
            pet_symbols = ("f", "d")  # f = kitten, d = dog
            context["in_combat"] = any(
                monster_char not in pet_symbols and within(pos, 2)
                for monster_char, pos in map_info["monster_positions"]
            )

            # Stairs exactly under the player; the first match decides the type.
            for stair_type, pos in map_info["stairs_positions"]:
                if pos == (px, py):
                    context["at_stairs"] = True
                    context["stairs_type"] = stair_type
                    break

            # Items and doors on the player's square or adjacent to it.
            context["items_nearby"] = any(
                within(pos, 1) for _, pos in map_info["item_positions"]
            )
            context["doors_nearby"] = any(
                within(pos, 1) for _, pos in map_info["door_positions"]
            )

    # Check health status: low when below 30% of maximum HP. With both
    # values missing the defaults (0 of 1) also count as low health.
    if "character_stats" in observation:
        stats = observation["character_stats"]
        hp = stats.get("hp", 0)
        max_hp = stats.get("max_hp", 1)
        if hp < max_hp * 0.3:
            context["low_health"] = True

    # Check for specific messages. "or ''" also covers a message that is
    # present but None, which .get()'s default alone would not.
    message = (observation.get("message") or "").lower()
    if "hungry" in message or "weak" in message:
        context["hungry"] = True
    if "burdened" in message or "stressed" in message:
        context["encumbered"] = True
    if "shop" in message:  # also matches "shopkeeper"
        context["in_shop"] = True

    return context
|
241
|
+
|
242
|
+
|
243
|
+
def simplify_observation(observation: Dict[str, Any]) -> Dict[str, Any]:
    """
    Create a simplified observation for agents that need less detail.

    Args:
        observation: Full observation dictionary

    Returns:
        Simplified observation dictionary with ``turn``, ``level``, ``hp``,
        ``max_hp``, ``message``, ``terminated`` and ``reward``. When the
        observation carries an ``ascii_map``, ``player_pos``,
        ``monsters_nearby``, ``items_nearby`` and ``at_stairs`` are added.
    """
    # HP values come from character_stats when present, otherwise stay 0.
    stats = observation["character_stats"] if "character_stats" in observation else {}

    simplified = {
        "turn": observation.get("turn_count", 0),
        "level": observation.get("dungeon_level", 1),
        "hp": stats.get("hp", 0),
        "max_hp": stats.get("max_hp", 0),
        "message": observation.get("message", ""),
        "terminated": observation.get("terminated", False),
        "reward": observation.get("reward_last", 0.0),
    }

    if "ascii_map" not in observation:
        return simplified

    # Extract key map features
    parsed = parse_ascii_map(observation["ascii_map"])
    player_pos = parsed["player_position"]

    # Despite the "nearby" names, these count every monster and item
    # found on the map, not only those close to the player.
    simplified["player_pos"] = player_pos
    simplified["monsters_nearby"] = len(parsed["monster_positions"])
    simplified["items_nearby"] = len(parsed["item_positions"])

    on_stairs = False
    for _, stair_pos in parsed["stairs_positions"]:
        if stair_pos == player_pos:
            on_stairs = True
            break
    simplified["at_stairs"] = on_stairs

    return simplified
|
280
|
+
|
281
|
+
|
282
|
+
def extract_inventory_from_message(message: str) -> List[Dict[str, Any]]:
    """
    Extract inventory information from NetHack inventory messages.

    Every line of the form ``<letter> - <description> (<status>)`` (the
    parenthesised status is optional) yields one item; other lines are
    skipped.

    Args:
        message: Inventory message string

    Returns:
        List of inventory items, each a dict with ``letter``,
        ``description``, ``status`` (``None`` when absent), ``quantity``,
        ``name`` and ``type``.
    """
    # Common inventory line patterns
    # Example: "a - a blessed +1 long sword (weapon in hand)"
    # Example: "b - an uncursed food ration"
    line_re = re.compile(r"^([a-zA-Z])\s*-\s*(.+?)(?:\s*\(([^)]+)\))?$")
    # A leading integer in the description is the stack size.
    quantity_re = re.compile(r"^(\d+)\s+(.+)")

    items = []
    for raw_line in message.split("\n"):
        parsed = line_re.match(raw_line.strip())
        if parsed is None:
            continue

        letter = parsed.group(1)
        description = parsed.group(2)
        status = parsed.group(3)

        counted = quantity_re.match(description)
        quantity = int(counted.group(1)) if counted else 1
        name = counted.group(2) if counted else description

        items.append(
            {
                "letter": letter,
                "description": description.strip(),
                "status": status.strip() if status else None,
                "quantity": quantity,
                "name": name,
                # Identify item type from the full description
                "type": identify_item_type(description),
            }
        )

    return items
|
325
|
+
|
326
|
+
|
327
|
+
def identify_item_type(description: str) -> str:
    """
    Identify the type of an item from its description.

    Matching is case-insensitive and substring based. Explicit class
    words ("potion", "scroll", "wand", "amulet") are checked first so that
    a keyword appearing inside the item's name cannot override them:
    "potion of fruit juice" is a potion rather than food, and "scroll of
    enchant armor" is a scroll rather than armor. The remaining
    categories are tried in order and the first hit wins.

    Args:
        description: Item description string

    Returns:
        Item type string: one of ``"weapon"``, ``"armor"``, ``"food"``,
        ``"potion"``, ``"scroll"``, ``"wand"``, ``"ring"``, ``"amulet"``,
        ``"tool"``, ``"gem"``, ``"gold"`` or ``"unknown"``.
    """
    desc_lower = description.lower()

    # Explicit object-class words take precedence over incidental keywords.
    for item_class in ("potion", "scroll", "wand", "amulet"):
        if item_class in desc_lower:
            return item_class

    # Ordered (type, keywords) table. "ring" stays after armor so that
    # "ring mail" keeps resolving to armor.
    keyword_table = (
        (
            "weapon",
            ("sword", "dagger", "spear", "axe", "mace", "bow", "arrow", "dart", "knife"),
        ),
        (
            "armor",
            ("armor", "mail", "helmet", "boots", "gloves", "shield", "cloak", "robe"),
        ),
        (
            "food",
            ("food", "ration", "corpse", "egg", "fruit", "meat", "candy", "cookie"),
        ),
        ("ring", ("ring",)),
        (
            "tool",
            ("pick", "key", "lamp", "candle", "bag", "sack", "horn", "whistle", "mirror"),
        ),
        ("gem", ("gem", "stone", "rock", "crystal")),
        ("gold", ("gold",)),
    )

    for item_type, keywords in keyword_table:
        if any(word in desc_lower for word in keywords):
            return item_type

    return "unknown"
|
@@ -0,0 +1,201 @@
|
|
1
|
+
"""Environment wrapper that adds trajectory recording capabilities."""
|
2
|
+
|
3
|
+
from typing import Dict, Any, Optional, Tuple
|
4
|
+
from pathlib import Path
|
5
|
+
import logging
|
6
|
+
|
7
|
+
from src.synth_env.examples.nethack.environment import NetHackEnvironment
|
8
|
+
from src.synth_env.examples.nethack.helpers.trajectory_recorder import (
|
9
|
+
TrajectoryRecorder,
|
10
|
+
)
|
11
|
+
from src.synth_env.environment.tools import EnvToolCall
|
12
|
+
|
13
|
+
|
14
|
+
logger = logging.getLogger(__name__)
|
15
|
+
|
16
|
+
|
17
|
+
class RecordingNetHackEnvironment(NetHackEnvironment):
    """NetHack environment with automatic trajectory recording.

    Wraps ``NetHackEnvironment`` so that every step taken through
    ``start`` / ``process_action`` is forwarded to a ``TrajectoryRecorder``.
    Recording begins automatically on ``start`` when ``auto_record`` is
    set, or manually via ``start_recording``; it is finalized and saved
    when an episode terminates or ``stop_recording`` is called.
    """

    def __init__(
        self,
        save_dir: str = "temp/nethack_trajectories",
        auto_record: bool = True,
        **kwargs,
    ):
        """Initialize recording wrapper.

        Args:
            save_dir: Directory to save trajectories (created if missing)
            auto_record: Whether to automatically record all episodes
            **kwargs: Arguments passed to NetHackEnvironment
        """
        super().__init__(**kwargs)

        self.save_dir = Path(save_dir)
        self.save_dir.mkdir(parents=True, exist_ok=True)

        self.auto_record = auto_record
        self.recorder = TrajectoryRecorder(save_dir)
        self.is_recording = False
        self.trajectory_id = None

    async def start(self, **kwargs) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """Start environment and optionally begin recording.

        Returns:
            The ``(public_state, private_state)`` pair from the base class.
        """
        public_state, private_state = await super().start(**kwargs)

        if self.auto_record:
            # Extract character role from task metadata
            character_role = "adventurer"  # default
            if hasattr(self.task_instance, "metadata") and hasattr(
                self.task_instance.metadata, "character_role"
            ):
                character_role = self.task_instance.metadata.character_role

            task_id = self.task_instance.task_id if self.task_instance else None
            self.trajectory_id = self.recorder.start_recording(character_role, task_id)
            self.is_recording = True

            logger.info(f"Started recording trajectory: {self.trajectory_id}")

            # Record initial state as a synthetic "reset" step
            obs = self._extract_observation(public_state)
            self.recorder.record_step("reset", obs, 0.0, False, {})

        return public_state, private_state

    async def process_action(
        self, tool_calls: list[EnvToolCall]
    ) -> Tuple[Any, float, bool, Dict[str, Any]]:
        """Process action and record if enabled.

        Args:
            tool_calls: Tool calls forwarded unchanged to the base class.
                Only the first call's ``args["action"]`` is used as the
                recorded action label.

        Returns:
            The ``(observation, reward, terminated, info)`` tuple from the
            base class, unmodified.
        """
        # Execute action
        observation, reward, terminated, info = await super().process_action(tool_calls)

        # Record step if recording
        if self.is_recording and self.recorder.is_recording:
            # Extract action from tool calls; stays "unknown" when the
            # first call has no usable args.
            action = "unknown"
            if tool_calls:
                args = getattr(tool_calls[0], "args", None)
                if args is not None and "action" in args:
                    action = args["action"]

            # Extract observation data
            obs_data = self._extract_observation(observation)

            # Record step
            self.recorder.record_step(action, obs_data, reward, terminated, info)

        # If episode ended, stop recording
        if terminated and self.is_recording:
            await self._finalize_recording(observation)

        return observation, reward, terminated, info

    def _extract_observation(self, state: Any) -> Dict[str, Any]:
        """Extract observation data from state object.

        Args:
            state: A state object with attributes, or an observation dict.

        Returns:
            For objects: a new dict with whichever of ``ascii_map``,
            ``message``, ``player_stats``, ``inventory``, ``in_menu`` and
            ``menu_items`` the state exposes. Dicts are returned as is;
            anything else yields an empty dict.
        """
        if hasattr(state, "__dict__"):
            # Convert state object to dict
            obs = {}

            # Extract relevant fields
            if hasattr(state, "ascii_map"):
                obs["ascii_map"] = state.ascii_map
            if hasattr(state, "message"):
                obs["message"] = state.message
            if hasattr(state, "character_stats"):
                # Copy the stats: x/y/depth are written into this dict
                # below, and aliasing it would mutate the state itself.
                stats = state.character_stats
                obs["player_stats"] = dict(stats) if isinstance(stats, dict) else stats
            if hasattr(state, "inventory"):
                obs["inventory"] = state.inventory
            if hasattr(state, "position"):
                obs["player_stats"] = obs.get("player_stats", {})
                obs["player_stats"]["x"] = state.position[0]
                obs["player_stats"]["y"] = state.position[1]
            if hasattr(state, "dungeon_level"):
                obs["player_stats"] = obs.get("player_stats", {})
                obs["player_stats"]["depth"] = state.dungeon_level
            if hasattr(state, "in_menu"):
                obs["in_menu"] = state.in_menu
            if hasattr(state, "menu_items"):
                obs["menu_items"] = state.menu_items

            return obs

        # If it's already a dict, return as is
        return state if isinstance(state, dict) else {}

    async def _finalize_recording(self, final_state: Any) -> None:
        """Finalize and save recording.

        The final status defaults to ``"completed"`` and is refined from
        keywords in the last game message: ``"died"``, ``"quit"`` or
        ``"truncated"``.

        Args:
            final_state: Last observation, as a state object or a dict.
        """
        if not self.is_recording:
            return

        # Read the message from either a state object or a dict, matching
        # the two shapes _extract_observation accepts.
        message = getattr(final_state, "message", None)
        if message is None and isinstance(final_state, dict):
            message = final_state.get("message")

        # Determine final status
        final_status = "completed"
        if isinstance(message, str):
            msg = message.lower()
            if "die" in msg or "killed" in msg:
                final_status = "died"
            elif "quit" in msg:
                final_status = "quit"
            elif "time limit" in msg or "truncated" in msg:
                final_status = "truncated"

        # Extract achievements if available
        achievements = {}
        # TODO: Extract achievements from game state

        # Stop recording
        self.recorder.stop_recording(final_status, achievements)

        # Save trajectory
        filepath = self.recorder.save_trajectory()
        logger.info(f"Saved trajectory to: {filepath}")

        # Get and log summary
        summary = self.recorder.get_summary()
        logger.info(f"Trajectory summary: {summary}")

        self.is_recording = False

    def start_recording(
        self, character_role: Optional[str] = None, task_id: Optional[str] = None
    ) -> str:
        """Manually start recording.

        Args:
            character_role: Role to record; when omitted, falls back to
                ``self.engine.character_role`` if available, otherwise
                ``"adventurer"``.
            task_id: Optional task identifier passed to the recorder.

        Returns:
            The trajectory id; the existing id if a recording is already
            in progress.
        """
        if self.is_recording:
            logger.warning("Recording already in progress")
            return self.trajectory_id

        if character_role is None:
            character_role = "adventurer"
            if hasattr(self, "engine") and hasattr(self.engine, "character_role"):
                character_role = self.engine.character_role

        self.trajectory_id = self.recorder.start_recording(character_role, task_id)
        self.is_recording = True
        logger.info(f"Started manual recording: {self.trajectory_id}")

        return self.trajectory_id

    def stop_recording(self, save: bool = True) -> Optional[str]:
        """Manually stop recording.

        Args:
            save: Whether to write the trajectory to disk.

        Returns:
            Whatever ``save_trajectory`` returned when ``save`` is true,
            otherwise ``None``; also ``None`` when nothing was recording.
        """
        if not self.is_recording:
            logger.warning("No recording in progress")
            return None

        self.recorder.stop_recording()

        filepath = None
        if save:
            filepath = self.recorder.save_trajectory()
            logger.info(f"Saved trajectory to: {filepath}")

        self.is_recording = False
        return filepath

    def get_recording_status(self) -> Dict[str, Any]:
        """Get current recording status.

        Returns:
            Dict with ``is_recording``, ``trajectory_id``, ``current_step``
            and ``total_reward``; the last two are reported as zero while
            no recording is active.
        """
        return {
            "is_recording": self.is_recording,
            "trajectory_id": self.trajectory_id,
            "current_step": self.recorder.current_step if self.is_recording else 0,
            "total_reward": self.recorder.total_reward if self.is_recording else 0.0,
        }
|