synth-ai 0.2.0__py3-none-any.whl → 0.2.1.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (266)
  1. synth_ai/__init__.py +28 -2
  2. synth_ai/core/system.py +4 -0
  3. synth_ai/environments/__init__.py +35 -0
  4. synth_ai/environments/environment/__init__.py +1 -0
  5. synth_ai/environments/environment/artifacts/__init__.py +1 -0
  6. synth_ai/environments/environment/artifacts/base.py +50 -0
  7. synth_ai/environments/environment/core.py +22 -0
  8. synth_ai/environments/environment/db/__init__.py +1 -0
  9. synth_ai/environments/environment/db/sqlite.py +45 -0
  10. synth_ai/environments/environment/registry.py +24 -0
  11. synth_ai/environments/environment/resources/sqlite.py +46 -0
  12. synth_ai/environments/environment/results.py +1 -0
  13. synth_ai/environments/environment/rewards/__init__.py +1 -0
  14. synth_ai/environments/environment/rewards/core.py +28 -0
  15. synth_ai/environments/environment/shared_engine.py +26 -0
  16. synth_ai/environments/environment/tools/__init__.py +34 -0
  17. synth_ai/environments/examples/__init__.py +1 -0
  18. synth_ai/environments/examples/crafter_classic/__init__.py +8 -0
  19. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_comprehensive_evaluation.py +58 -0
  20. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_evaluation_browser.py +152 -0
  21. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_evaluation_framework.py +1194 -0
  22. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_quick_evaluation.py +51 -0
  23. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_react_agent.py +872 -0
  24. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_trace_evaluation.py +1412 -0
  25. synth_ai/environments/examples/crafter_classic/agent_demos/test_crafter_react_agent.py +1110 -0
  26. synth_ai/environments/examples/crafter_classic/config_logging.py +111 -0
  27. synth_ai/environments/examples/crafter_classic/engine.py +502 -0
  28. synth_ai/environments/examples/crafter_classic/engine_deterministic_patch.py +63 -0
  29. synth_ai/environments/examples/crafter_classic/engine_helpers/action_map.py +5 -0
  30. synth_ai/environments/examples/crafter_classic/engine_helpers/serialization.py +74 -0
  31. synth_ai/environments/examples/crafter_classic/environment.py +255 -0
  32. synth_ai/environments/examples/crafter_classic/taskset.py +228 -0
  33. synth_ai/environments/examples/enron/agent_demos/test_synth_react.py +535 -0
  34. synth_ai/environments/examples/enron/art_helpers/email_search_tools.py +156 -0
  35. synth_ai/environments/examples/enron/art_helpers/local_email_db.py +280 -0
  36. synth_ai/environments/examples/enron/art_helpers/types_enron.py +24 -0
  37. synth_ai/environments/examples/enron/engine.py +291 -0
  38. synth_ai/environments/examples/enron/environment.py +165 -0
  39. synth_ai/environments/examples/enron/taskset.py +112 -0
  40. synth_ai/environments/examples/enron/units/keyword_stats.py +111 -0
  41. synth_ai/environments/examples/enron/units/test_email_index.py +8 -0
  42. synth_ai/environments/examples/minigrid/__init__.py +48 -0
  43. synth_ai/environments/examples/minigrid/agent_demos/minigrid_evaluation_framework.py +1188 -0
  44. synth_ai/environments/examples/minigrid/agent_demos/minigrid_quick_evaluation.py +47 -0
  45. synth_ai/environments/examples/minigrid/agent_demos/minigrid_react_agent.py +562 -0
  46. synth_ai/environments/examples/minigrid/agent_demos/minigrid_trace_evaluation.py +220 -0
  47. synth_ai/environments/examples/minigrid/agent_demos/test_minigrid_react_agent.py +393 -0
  48. synth_ai/environments/examples/minigrid/engine.py +589 -0
  49. synth_ai/environments/examples/minigrid/environment.py +274 -0
  50. synth_ai/environments/examples/minigrid/environment_mapping.py +242 -0
  51. synth_ai/environments/examples/minigrid/puzzle_loader.py +416 -0
  52. synth_ai/environments/examples/minigrid/taskset.py +583 -0
  53. synth_ai/environments/examples/minigrid/units/test_action_behavior.py +226 -0
  54. synth_ai/environments/examples/minigrid/units/test_debug_messages.py +83 -0
  55. synth_ai/environments/examples/minigrid/units/test_exploration.py +120 -0
  56. synth_ai/environments/examples/minigrid/units/test_minigrid_engine.py +214 -0
  57. synth_ai/environments/examples/minigrid/units/test_minigrid_environment.py +238 -0
  58. synth_ai/environments/examples/minigrid/units/test_minigrid_environment_mapping.py +301 -0
  59. synth_ai/environments/examples/minigrid/units/test_minigrid_taskset.py +210 -0
  60. synth_ai/environments/examples/nethack/__init__.py +7 -0
  61. synth_ai/environments/examples/nethack/achievements.py +337 -0
  62. synth_ai/environments/examples/nethack/agent_demos/nethack_evaluation_framework.py +981 -0
  63. synth_ai/environments/examples/nethack/agent_demos/nethack_quick_evaluation.py +74 -0
  64. synth_ai/environments/examples/nethack/agent_demos/nethack_react_agent.py +832 -0
  65. synth_ai/environments/examples/nethack/agent_demos/test_nethack_react_agent.py +1112 -0
  66. synth_ai/environments/examples/nethack/engine.py +738 -0
  67. synth_ai/environments/examples/nethack/environment.py +255 -0
  68. synth_ai/environments/examples/nethack/helpers/__init__.py +42 -0
  69. synth_ai/environments/examples/nethack/helpers/action_mapping.py +301 -0
  70. synth_ai/environments/examples/nethack/helpers/nle_wrapper.py +401 -0
  71. synth_ai/environments/examples/nethack/helpers/observation_utils.py +433 -0
  72. synth_ai/environments/examples/nethack/helpers/recording_wrapper.py +201 -0
  73. synth_ai/environments/examples/nethack/helpers/trajectory_recorder.py +268 -0
  74. synth_ai/environments/examples/nethack/helpers/visualization/replay_viewer.py +308 -0
  75. synth_ai/environments/examples/nethack/helpers/visualization/visualizer.py +430 -0
  76. synth_ai/environments/examples/nethack/taskset.py +323 -0
  77. synth_ai/environments/examples/nethack/units/test_nethack_engine.py +277 -0
  78. synth_ai/environments/examples/nethack/units/test_nethack_environment.py +281 -0
  79. synth_ai/environments/examples/nethack/units/test_nethack_taskset.py +213 -0
  80. synth_ai/environments/examples/nethack/units/test_recording.py +307 -0
  81. synth_ai/environments/examples/red/__init__.py +7 -0
  82. synth_ai/environments/examples/red/agent_demos/__init__.py +1 -0
  83. synth_ai/environments/examples/red/agent_demos/test_synth_react.py +1471 -0
  84. synth_ai/environments/examples/red/config_logging.py +110 -0
  85. synth_ai/environments/examples/red/engine.py +693 -0
  86. synth_ai/environments/examples/red/engine_helpers/__init__.py +1 -0
  87. synth_ai/environments/examples/red/engine_helpers/memory_map.py +28 -0
  88. synth_ai/environments/examples/red/engine_helpers/reward_components.py +275 -0
  89. synth_ai/environments/examples/red/engine_helpers/reward_library/__init__.py +142 -0
  90. synth_ai/environments/examples/red/engine_helpers/reward_library/adaptive_rewards.py +56 -0
  91. synth_ai/environments/examples/red/engine_helpers/reward_library/battle_rewards.py +283 -0
  92. synth_ai/environments/examples/red/engine_helpers/reward_library/composite_rewards.py +149 -0
  93. synth_ai/environments/examples/red/engine_helpers/reward_library/economy_rewards.py +137 -0
  94. synth_ai/environments/examples/red/engine_helpers/reward_library/efficiency_rewards.py +56 -0
  95. synth_ai/environments/examples/red/engine_helpers/reward_library/exploration_rewards.py +330 -0
  96. synth_ai/environments/examples/red/engine_helpers/reward_library/novelty_rewards.py +120 -0
  97. synth_ai/environments/examples/red/engine_helpers/reward_library/pallet_town_rewards.py +558 -0
  98. synth_ai/environments/examples/red/engine_helpers/reward_library/pokemon_rewards.py +312 -0
  99. synth_ai/environments/examples/red/engine_helpers/reward_library/social_rewards.py +147 -0
  100. synth_ai/environments/examples/red/engine_helpers/reward_library/story_rewards.py +246 -0
  101. synth_ai/environments/examples/red/engine_helpers/screen_analysis.py +367 -0
  102. synth_ai/environments/examples/red/engine_helpers/state_extraction.py +139 -0
  103. synth_ai/environments/examples/red/environment.py +235 -0
  104. synth_ai/environments/examples/red/taskset.py +77 -0
  105. synth_ai/environments/examples/red/test_fixes.py +125 -0
  106. synth_ai/environments/examples/red/test_fixes_mock.py +148 -0
  107. synth_ai/environments/examples/red/units/__init__.py +1 -0
  108. synth_ai/environments/examples/red/units/test_basic_functionality.py +97 -0
  109. synth_ai/environments/examples/red/units/test_button_press_requirements.py +217 -0
  110. synth_ai/environments/examples/red/units/test_engine.py +192 -0
  111. synth_ai/environments/examples/red/units/test_environment.py +455 -0
  112. synth_ai/environments/examples/red/units/test_exploration_strategy.py +227 -0
  113. synth_ai/environments/examples/red/units/test_integration.py +217 -0
  114. synth_ai/environments/examples/red/units/test_memory_extraction.py +111 -0
  115. synth_ai/environments/examples/red/units/test_menu_bug_reproduction.py +1100 -0
  116. synth_ai/environments/examples/red/units/test_movement_debug.py +255 -0
  117. synth_ai/environments/examples/red/units/test_pokemon_mcts_debug.py +163 -0
  118. synth_ai/environments/examples/red/units/test_pokemon_mcts_verbose.py +117 -0
  119. synth_ai/environments/examples/red/units/test_red_basic.py +145 -0
  120. synth_ai/environments/examples/red/units/test_red_comprehensive.py +323 -0
  121. synth_ai/environments/examples/red/units/test_retry_movement.py +195 -0
  122. synth_ai/environments/examples/red/units/test_reward_components.py +186 -0
  123. synth_ai/environments/examples/red/units/test_rom_integration.py +260 -0
  124. synth_ai/environments/examples/red/units/test_taskset.py +116 -0
  125. synth_ai/environments/examples/red/units/test_tree.py +448 -0
  126. synth_ai/environments/examples/sokoban/__init__.py +1 -0
  127. synth_ai/environments/examples/sokoban/agent_demos/sokoban_full_eval.py +900 -0
  128. synth_ai/environments/examples/sokoban/agent_demos/test_dspy_react.py +1 -0
  129. synth_ai/environments/examples/sokoban/agent_demos/test_sokoban_react_agent.py +498 -0
  130. synth_ai/environments/examples/sokoban/agent_demos/test_synth_lats.py +1 -0
  131. synth_ai/environments/examples/sokoban/agent_demos/test_synth_react_locally.py +748 -0
  132. synth_ai/environments/examples/sokoban/agent_demos/test_synth_react_service.py +296 -0
  133. synth_ai/environments/examples/sokoban/engine.py +675 -0
  134. synth_ai/environments/examples/sokoban/engine_helpers/__init__.py +1 -0
  135. synth_ai/environments/examples/sokoban/engine_helpers/room_utils.py +656 -0
  136. synth_ai/environments/examples/sokoban/engine_helpers/vendored/__init__.py +17 -0
  137. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/__init__.py +3 -0
  138. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/boxoban_env.py +129 -0
  139. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/render_utils.py +370 -0
  140. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/room_utils.py +331 -0
  141. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env.py +305 -0
  142. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_fixed_targets.py +66 -0
  143. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_pull.py +114 -0
  144. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_two_player.py +122 -0
  145. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_variations.py +394 -0
  146. synth_ai/environments/examples/sokoban/environment.py +228 -0
  147. synth_ai/environments/examples/sokoban/generate_verified_puzzles.py +438 -0
  148. synth_ai/environments/examples/sokoban/puzzle_loader.py +311 -0
  149. synth_ai/environments/examples/sokoban/taskset.py +425 -0
  150. synth_ai/environments/examples/sokoban/units/astar_common.py +94 -0
  151. synth_ai/environments/examples/sokoban/units/test_building_task_set.py +49 -0
  152. synth_ai/environments/examples/sokoban/units/test_false_positive.py +120 -0
  153. synth_ai/environments/examples/sokoban/units/test_simple_run_through_environment.py +119 -0
  154. synth_ai/environments/examples/sokoban/units/test_sokoban_environment.py +98 -0
  155. synth_ai/environments/examples/sokoban/units/test_tree.py +364 -0
  156. synth_ai/environments/examples/tictactoe/__init__.py +1 -0
  157. synth_ai/environments/examples/tictactoe/agent_demos/test_synth_react.py +266 -0
  158. synth_ai/environments/examples/tictactoe/agent_demos/test_tictactoe_react_agent.py +470 -0
  159. synth_ai/environments/examples/tictactoe/engine.py +368 -0
  160. synth_ai/environments/examples/tictactoe/environment.py +239 -0
  161. synth_ai/environments/examples/tictactoe/taskset.py +214 -0
  162. synth_ai/environments/examples/tictactoe/units/test_tictactoe_engine.py +393 -0
  163. synth_ai/environments/examples/tictactoe/units/test_tictactoe_environment.py +493 -0
  164. synth_ai/environments/examples/tictactoe/units/test_tictactoe_taskset.py +191 -0
  165. synth_ai/environments/examples/verilog/__init__.py +10 -0
  166. synth_ai/environments/examples/verilog/agent_demos/test_synth_react.py +520 -0
  167. synth_ai/environments/examples/verilog/engine.py +328 -0
  168. synth_ai/environments/examples/verilog/environment.py +349 -0
  169. synth_ai/environments/examples/verilog/taskset.py +418 -0
  170. synth_ai/environments/examples/verilog/units/test_verilog_engine.py +466 -0
  171. synth_ai/environments/examples/verilog/units/test_verilog_environment.py +585 -0
  172. synth_ai/environments/examples/verilog/units/test_verilog_integration.py +383 -0
  173. synth_ai/environments/examples/verilog/units/test_verilog_taskset.py +457 -0
  174. synth_ai/environments/reproducibility/core.py +42 -0
  175. synth_ai/environments/reproducibility/tree.py +364 -0
  176. synth_ai/environments/service/app.py +78 -0
  177. synth_ai/environments/service/core_routes.py +775 -0
  178. synth_ai/environments/service/external_registry.py +57 -0
  179. synth_ai/environments/service/registry.py +9 -0
  180. synth_ai/environments/stateful/__init__.py +1 -0
  181. synth_ai/environments/stateful/core.py +28 -0
  182. synth_ai/environments/stateful/engine.py +21 -0
  183. synth_ai/environments/stateful/state.py +7 -0
  184. synth_ai/environments/tasks/api.py +19 -0
  185. synth_ai/environments/tasks/core.py +78 -0
  186. synth_ai/environments/tasks/filters.py +39 -0
  187. synth_ai/environments/tasks/utils.py +89 -0
  188. synth_ai/environments/v0_observability/history.py +3 -0
  189. synth_ai/environments/v0_observability/log.py +2 -0
  190. synth_ai/lm/caching/constants.py +1 -0
  191. synth_ai/{zyk/lms → lm}/caching/ephemeral.py +4 -8
  192. synth_ai/{zyk/lms → lm}/caching/handler.py +15 -15
  193. synth_ai/{zyk/lms → lm}/caching/initialize.py +2 -4
  194. synth_ai/{zyk/lms → lm}/caching/persistent.py +4 -10
  195. synth_ai/{zyk/lms → lm}/config.py +2 -1
  196. synth_ai/{zyk/lms → lm}/constants.py +2 -2
  197. synth_ai/{zyk/lms → lm}/core/all.py +10 -10
  198. synth_ai/{zyk/lms → lm}/core/main.py +57 -33
  199. synth_ai/{zyk/lms → lm}/core/vendor_clients.py +12 -10
  200. synth_ai/lm/cost/monitor.py +1 -0
  201. synth_ai/lm/cost/statefulness.py +1 -0
  202. synth_ai/lm/provider_support/__init__.py +8 -0
  203. synth_ai/lm/provider_support/anthropic.py +945 -0
  204. synth_ai/lm/provider_support/openai.py +1115 -0
  205. synth_ai/lm/provider_support/suppress_logging.py +31 -0
  206. synth_ai/{zyk/lms → lm}/structured_outputs/handler.py +58 -80
  207. synth_ai/{zyk/lms → lm}/structured_outputs/inject.py +6 -20
  208. synth_ai/{zyk/lms → lm}/structured_outputs/rehabilitate.py +6 -12
  209. synth_ai/{zyk/lms → lm}/vendors/core/anthropic_api.py +21 -30
  210. synth_ai/{zyk/lms → lm}/vendors/core/gemini_api.py +35 -32
  211. synth_ai/{zyk/lms → lm}/vendors/core/mistral_api.py +19 -28
  212. synth_ai/{zyk/lms → lm}/vendors/core/openai_api.py +26 -36
  213. synth_ai/{zyk/lms → lm}/vendors/openai_standard.py +29 -33
  214. synth_ai/{zyk/lms → lm}/vendors/retries.py +1 -1
  215. synth_ai/lm/vendors/supported/__init__.py +0 -0
  216. synth_ai/{zyk/lms → lm}/vendors/supported/custom_endpoint.py +131 -118
  217. synth_ai/{zyk/lms → lm}/vendors/supported/deepseek.py +4 -8
  218. synth_ai/{zyk/lms → lm}/vendors/supported/grok.py +6 -8
  219. synth_ai/{zyk/lms → lm}/vendors/supported/groq.py +1 -1
  220. synth_ai/{zyk/lms → lm}/vendors/supported/ollama.py +2 -2
  221. synth_ai/{zyk/lms → lm}/vendors/supported/openrouter.py +18 -16
  222. synth_ai/{zyk/lms → lm}/vendors/supported/together.py +1 -1
  223. synth_ai/tracing/__init__.py +0 -0
  224. synth_ai/tracing/abstractions.py +224 -0
  225. synth_ai/tracing/base_client.py +91 -0
  226. synth_ai/tracing/client_manager.py +131 -0
  227. synth_ai/tracing/config.py +140 -0
  228. synth_ai/tracing/context.py +146 -0
  229. synth_ai/tracing/decorators.py +679 -0
  230. synth_ai/tracing/events/__init__.py +0 -0
  231. synth_ai/tracing/events/manage.py +147 -0
  232. synth_ai/tracing/events/scope.py +86 -0
  233. synth_ai/tracing/events/store.py +227 -0
  234. synth_ai/tracing/immediate_client.py +152 -0
  235. synth_ai/tracing/local.py +18 -0
  236. synth_ai/tracing/log_client_base.py +74 -0
  237. synth_ai/tracing/retry_queue.py +187 -0
  238. synth_ai/tracing/trackers.py +515 -0
  239. synth_ai/tracing/upload.py +504 -0
  240. synth_ai/tracing/utils.py +9 -0
  241. synth_ai/zyk/__init__.py +28 -2
  242. synth_ai-0.2.1.dev0.dist-info/METADATA +349 -0
  243. synth_ai-0.2.1.dev0.dist-info/RECORD +261 -0
  244. {synth_ai-0.2.0.dist-info → synth_ai-0.2.1.dev0.dist-info}/WHEEL +1 -1
  245. synth_ai/zyk/lms/caching/constants.py +0 -1
  246. synth_ai/zyk/lms/cost/monitor.py +0 -1
  247. synth_ai/zyk/lms/cost/statefulness.py +0 -1
  248. synth_ai-0.2.0.dist-info/METADATA +0 -36
  249. synth_ai-0.2.0.dist-info/RECORD +0 -50
  250. /synth_ai/{zyk/lms/__init__.py → environments/reproducibility/helpers.py} +0 -0
  251. /synth_ai/{zyk/lms/caching → lm}/__init__.py +0 -0
  252. /synth_ai/{zyk/lms/core → lm/caching}/__init__.py +0 -0
  253. /synth_ai/{zyk/lms → lm}/caching/dbs.py +0 -0
  254. /synth_ai/{zyk/lms/cost → lm/core}/__init__.py +0 -0
  255. /synth_ai/{zyk/lms → lm}/core/exceptions.py +0 -0
  256. /synth_ai/{zyk/lms/structured_outputs → lm/cost}/__init__.py +0 -0
  257. /synth_ai/{zyk/lms/vendors → lm/structured_outputs}/__init__.py +0 -0
  258. /synth_ai/{zyk/lms → lm}/tools/__init__.py +0 -0
  259. /synth_ai/{zyk/lms → lm}/tools/base.py +0 -0
  260. /synth_ai/{zyk/lms/vendors/core → lm/vendors}/__init__.py +0 -0
  261. /synth_ai/{zyk/lms → lm}/vendors/base.py +0 -0
  262. /synth_ai/{zyk/lms/vendors/local → lm/vendors/core}/__init__.py +0 -0
  263. /synth_ai/{zyk/lms/vendors/supported → lm/vendors/local}/__init__.py +0 -0
  264. /synth_ai/{zyk/lms → lm}/vendors/local/ollama.py +0 -0
  265. {synth_ai-0.2.0.dist-info → synth_ai-0.2.1.dev0.dist-info/licenses}/LICENSE +0 -0
  266. {synth_ai-0.2.0.dist-info → synth_ai-0.2.1.dev0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,832 @@
1
+ """ReAct agent demo for NetHack environment."""
2
+
3
+ import asyncio
4
+ import json
5
+ import logging
6
+ import argparse
7
+ from typing import Dict, Any, List, Optional, TYPE_CHECKING, cast
8
+ from pydantic import BaseModel, Field
9
+
10
+ from synth_ai.zyk import LM
11
+
12
+ from synth_ai.environments.examples.nethack.environment import NetHackEnvironment
13
+ from synth_ai.environments.examples.nethack.taskset import (
14
+ create_nethack_taskset,
15
+ NetHackTaskInstanceMetadata,
16
+ )
17
+ from synth_ai.environments.examples.nethack.helpers import (
18
+ format_observation_for_llm,
19
+ extract_game_context,
20
+ get_actions_for_context,
21
+ )
22
+ from synth_ai.zyk.lms.tools.base import BaseTool
23
+
24
+ logging.basicConfig(level=logging.INFO)
25
+ logger = logging.getLogger(__name__)
26
+
27
+
28
class TerminateArgs(BaseModel):
    """Arguments for termination."""

    # Argument schema for the ``terminate`` tool (TerminateTool uses this model
    # as its ``arguments``). The class docstring is left untouched because
    # pydantic may surface it as the JSON-schema description sent to the LLM.
    # Free-text explanation for ending the game.
    reason: str
32
+
33
+
34
class NetHackInteractArgs(BaseModel):
    """Arguments for NetHack interaction."""

    # Argument schema for the ``nethack_interact`` tool. The Field descriptions
    # (and possibly the class docstring) are runtime schema text passed along
    # with the tool, so they are intentionally not reworded here.
    # Free-text rationale; decide() copies it into the recent-actions record.
    reasoning: str = Field(description="Explain your reasoning for these actions")
    # Action names intended to be performed in order. decide() substitutes
    # ["wait"] when the tool arguments cannot be parsed.
    actions: List[str] = Field(description="List of actions to perform in sequence")
39
+
40
+
41
class NetHackInteractTool(BaseTool):
    """Tool for interacting with NetHack environment"""

    # Tool name; decide() also returns this name in its "wait" fallbacks.
    name: str = "nethack_interact"
    # Argument schema: a reasoning string plus an ordered list of action names.
    arguments: type[BaseModel] = NetHackInteractArgs
    # Description presumably surfaced to the LLM through the tool schema
    # (decide() passes these tools to respond_async and dumps to_openai_tool()),
    # so any edit here changes the model's instructions.
    # NOTE(review): this VALID ACTIONS list is a subset of what the system
    # prompt advertises (e.g. run_*, kick, dropall, upper-case menu letters are
    # absent here) — consider keeping the two lists in sync.
    description: str = (
        "Perform one or more actions in NetHack. Use EXACT action names from the system prompt!\n"
        "\n"
        "VALID ACTIONS: north, south, east, west, northeast, northwest, southeast, southwest, "
        "wait, go_up, go_down, search, open, close, inventory, pickup, drop, wear, wield, "
        "eat, drink, read, zap, apply, throw, fire, cast, pray, look, help, quit, save, "
        "plus menu letters (a-z) and numbers (0-9).\n"
        "\n"
        "CRITICAL RULES:\n"
        "- Use exact names: 'go_up' not 'up', 'open' not 'open door'\n"
        "- No compound actions: use ['open', 'west'] not ['open west']\n"
        "- Combat: move INTO monsters, there is no 'fight' action\n"
        "- Stairs: must be ON stairs (< or >) to use go_up/go_down\n"
        "\n"
        "Invalid actions will return helpful error messages."
    )
62
+
63
+
64
class TerminateTool(BaseTool):
    """Tool for terminating the game"""

    # Tool name the LLM calls to end the game; how the episode loop reacts to
    # it lives in run_episode (not visible in this chunk).
    name: str = "terminate"
    # Argument schema: a single free-text ``reason`` field.
    arguments: type[BaseModel] = TerminateArgs
    # Guidance on when to call this tool, presumably surfaced to the LLM via
    # the tool schema passed to respond_async.
    description: str = "End the game when you die, complete objectives, or decide to quit"
70
+
71
+
72
+ class NetHackReActAgent:
73
+ """ReAct agent for playing NetHack."""
74
+
75
def __init__(self, llm: LM, max_turns: int = 50):
    """Set up a fresh agent with empty history.

    Args:
        llm: Language-model client; decide() calls its ``respond_async``.
        max_turns: Turn budget stored on the instance (presumably consumed by
            run_episode, which is not visible here).
    """
    self.llm, self.max_turns = llm, max_turns
    self.system_name = "nethack-react"

    # Sliding windows maintained by decide(): user messages (last 10) and
    # recorded tool calls with their results (last 3).
    self.history: List[Dict[str, Any]] = []
    self.recent_actions: List[Dict[str, Any]] = []

    # Tools offered to the LLM on every decide() call.
    self.tools = [NetHackInteractTool(), TerminateTool()]
84
+
85
def _create_system_prompt(self, task_instructions: str) -> str:
    """Build the system prompt: a fixed intro, the task instructions, then a
    fixed reference of valid actions, rules, map symbols, strategy and
    feedback hints.

    Args:
        task_instructions: Task-specific text inserted verbatim between the
            intro paragraph and the action reference.

    Returns:
        The complete system prompt string.
    """
    # NOTE(review): the prompt lists "inventory" under "INVENTORY (consumes
    # time)" while also calling it FREE; verify which is correct for this
    # environment. Also verify "# = corridor/wall" — NetHack normally draws
    # walls with '|' and '-'; confirm against the observation renderer.
    base_prompt = (
        """You are an expert NetHack player. Your goal is to navigate the dungeon and complete objectives.

"""
        + task_instructions
        + """

=== EXACT VALID ACTIONS ===
You MUST use only these exact action names. Invalid actions will cause errors!

MOVEMENT (consumes time):
- north, south, east, west, northeast, northwest, southeast, southwest
- wait (rest for one turn)
- go_up (use stairs up), go_down (use stairs down)
- run_north, run_south, run_east, run_west (run until something interesting)

EXPLORATION (consumes time):
- search (for secret doors/traps)
- open (open nearest door)
- close (close nearest door)
- kick (kick something)
- force (force a lock)
- untrap (disarm trap)

INVENTORY (consumes time):
- inventory (check items - but this is actually FREE!)
- pickup (pick up items here)
- drop (drop items)
- dropall (drop everything)
- wear (put on armor)
- take_off (remove armor)
- wield (equip weapon)
- unwield (unequip weapon)
- quiver (ready ammunition)
- put_on (wear accessories)
- remove (take off accessories)

USING ITEMS (consumes time):
- eat (consume food)
- drink (drink potion - but NetHack calls this "quaff")
- read (read scroll/book)
- zap (use wand)
- apply (use tool)
- invoke (use artifact power)
- rub (rub lamp/stone)
- throw (throw item)
- fire (shoot from quiver)

MAGIC & DIVINE (consumes time):
- cast (cast spell)
- pray (pray to deity)
- offer (sacrifice at altar)
- turn_undead (priest ability)

CHARACTER ACTIONS (consumes time):
- enhance (improve skills)
- sit (sit down)
- pay (pay shopkeeper)
- chat (talk to someone)
- loot (search container)
- engrave (write on ground)

FREE ACTIONS (do NOT consume time):
- look (look around - FREE!)
- farlook (examine specific location - FREE!)
- whatis (identify map symbol - FREE!)
- identify (check item details - FREE!)
- discoveries (list known items - FREE!)
- conduct (check game conduct - FREE!)
- attributes (check character stats - FREE!)
- help (show help - FREE!)
- version (show version - FREE!)
- history (show message history - FREE!)

GAME COMMANDS:
- save (save game)
- quit (quit game)

MENU/PROMPT RESPONSES:
- yes, no, all, none, escape
- Single letters: a, b, c, ..., z, A, B, C, ..., Z
- Single numbers: 0, 1, 2, ..., 9

=== CRITICAL RULES ===

1. EXACT ACTION NAMES: Use only the exact names above! "go_up" not "up", "open" not "open door"

2. NO COMPOUND ACTIONS: Don't use "open west" or "go to stairs". Use separate actions: ["open", "west"]

3. COMBAT: NO "fight" action exists! To attack monsters, MOVE INTO them using directional moves!

4. STAIRS: Must be standing ON stairs (< or >) to use go_up/go_down. Move to them first!

5. FREE ACTIONS: look, inventory, help, etc. don't advance time. Use sparingly!

6. DOORS: Use "open" (not "open door"). It opens the nearest door automatically.

7. TIME MANAGEMENT: Every action except FREE ACTIONS consumes a turn and lets monsters act!

=== MAP SYMBOLS ===
@ = you (player)
f = kitten (pet), d = dog (pet)
a-z = monsters (letters are monsters!)
$ = gold, * = gem
+ = closed door, - = horizontal open door, | = vertical open door
< = stairs up, > = stairs down
. = floor, # = corridor/wall
! = potion, ? = scroll, / = wand, = = ring, " = amulet
[ = armor, ) = weapon, ( = tool, % = food
^ = trap, { = fountain, } = pool

=== STRATEGY ===
1. EXPLORE: Move around systematically to map the dungeon
2. DESCEND: Find stairs down (>) and use go_down to reach deeper levels
3. COMBAT: Move into monsters to attack them
4. LOOT: Pick up gold ($) and useful items
5. SURVIVE: Monitor health, eat when hungry
6. ACHIEVE: Try to unlock achievements through exploration and combat

=== LEARNING FROM FEEDBACK ===
- If "No stairs here to up/down" → You're not standing on stairs, move to them first
- If "Unknown action" → You used an invalid action name, use exact names from list above
- If nothing happens after "look" → Use a time-consuming action to advance the game
- If "You see no objects here" → Move to explore other areas

Always think step by step and use the exact action names provided above!"""
    )

    return base_prompt
216
+
217
+ def _update_last_action_result(self, result_message: str):
218
+ """Update the result of the most recent action."""
219
+ if self.recent_actions:
220
+ self.recent_actions[-1]["result"] = result_message[:100] # Truncate long messages
221
+
222
def _format_observation(self, obs: Dict[str, Any]) -> str:
    """Render an environment observation as the user message for the LLM.

    The text is assembled in this order:
      1. ``format_observation_for_llm(obs)``;
      2. the last (up to) three entries of ``self.recent_actions``;
      3. alert lines from ``extract_game_context(obs)`` (combat, low health,
         hunger, standing on stairs);
      4. every action in ``ACTION_CATEGORIES``, grouped by category;
      5. suggestions from ``get_actions_for_context(obs)``, if any.

    Args:
        obs: Raw observation dict from the NetHack environment.

    Returns:
        The formatted observation string.
    """
    # ACTION_CATEGORIES ships in synth_ai/environments/examples/nethack/
    # helpers/action_mapping.py. The previous import path
    # (src.synth_env.examples.nethack.helpers.action_mapping) does not exist in
    # the synth_ai package, so this method raised ImportError when run from
    # the installed package.
    from synth_ai.environments.examples.nethack.helpers.action_mapping import (
        ACTION_CATEGORIES,
    )

    # Base rendering of the observation.
    formatted = format_observation_for_llm(obs)

    # Recent tool calls and their (truncated) results, oldest first.
    if self.recent_actions:
        formatted += "\n\n=== RECENT ACTIONS (Last 3 turns) ==="
        for action in self.recent_actions[-3:]:
            formatted += f"\nTurn {action['turn']}: {action['actions']} → {action['result']}"
        formatted += "\n=== END RECENT ACTIONS ==="

    # Alert lines derived from the game context.
    context = extract_game_context(obs)
    if context["in_combat"]:
        formatted += "\n\n⚔️ COMBAT ALERT: Monster nearby!"
    if context["low_health"]:
        formatted += "\n\n❤️ WARNING: Low health!"
    if context["hungry"]:
        formatted += "\n\n🍖 WARNING: You are hungry!"
    if context["at_stairs"]:
        formatted += f"\n\n🪜 You are at the {context.get('stairs_type', 'stairs')}!"

    # Full action reference, grouped by category.
    formatted += "\n\n=== VALID ACTIONS ==="
    for category in ACTION_CATEGORIES:
        formatted += f"\n{category.name} ({category.description}):"
        formatted += f"\n {', '.join(category.actions)}"

    # Situation-specific suggestions.
    suggested_actions = get_actions_for_context(obs)
    if suggested_actions:
        formatted += f"\n\n💡 Suggested for current situation: {', '.join(suggested_actions)}"

    return formatted
261
+
262
async def decide(self, obs: str) -> Dict[str, Any]:
    """Ask the LLM for the next tool call given an observation string.

    Args:
        obs: Observation text sent as the user message.

    Returns:
        ``{"name": <tool name>, "parameters": <argument dict>}`` built from the
        first tool call in the LLM response. Unparsable or non-dict tool
        arguments are replaced by
        ``{"reasoning": "Failed to parse arguments", "actions": ["wait"]}``.
        If the response carries no tool call, or anything here raises an
        ``Exception``, a ``nethack_interact`` call performing a single
        ``"wait"`` is returned instead.

    Side effects:
        * Writes the full prompt to ``nethack_prompt_turn_<n>.txt`` in the
          current working directory as a debugging aid. A write failure is
          logged and does not abort the decision.
        * Appends ``obs`` to ``self.history``, keeping only the last 10 entries.
        * Appends a record to ``self.recent_actions``, keeping only the last 3.
          Its ``"result"`` stays "pending" until
          ``_update_last_action_result`` fills it in.

    Reads ``self.system_prompt``, which run_episode sets before the first call.
    """
    try:
        # len(self.history) is capped at 10, so it cannot number turns beyond
        # the 10th; every later prompt dump overwrote nethack_prompt_turn_11.txt.
        # Keep an explicit counter instead. Restart it whenever history is
        # empty, so a caller that clears history also restarts the numbering.
        if not self.history:
            self._decision_count = 0
        turn_num = getattr(self, "_decision_count", 0) + 1
        self._decision_count = turn_num

        # Debug: log the first few observations.
        if turn_num <= 3:
            logger.info(f"Turn {turn_num} observation preview: {obs[:300]}...")

        self._dump_prompt(turn_num, obs)

        # Keep a bounded record of observations. The LLM call below sends
        # only the current observation, not this history.
        self.history.append({"role": "user", "content": obs})
        if len(self.history) > 10:
            self.history = self.history[-10:]

        response = await self.llm.respond_async(
            system_message=self.system_prompt, user_message=obs, tools=self.tools
        )

        logger.info(f"Response type: {type(response)}")
        logger.info(f"Has tool_calls: {hasattr(response, 'tool_calls')}")
        if hasattr(response, "tool_calls"):
            logger.info(f"tool_calls value: {response.tool_calls}")

        tool_calls = getattr(response, "tool_calls", None)
        if tool_calls:
            logger.info(f"Found {len(tool_calls)} tool calls")
            # Only the first tool call is acted on.
            tool_name, raw_args = self._extract_tool_call(tool_calls[0])
            logger.info(f"Tool name: {tool_name}, Args: {raw_args}")
            args = self._parse_tool_args(raw_args)
            self._record_action(args)
            return {"name": tool_name, "parameters": args}

        # No tool call: log as much of the response as possible, then wait.
        logger.warning("No tool call found in LLM response, defaulting to wait")
        logger.warning(f"Response type: {type(response)}")
        logger.warning(f"Response attrs: {dir(response)}")
        logger.warning(f"Response: {response}")
        for attr in dir(response):
            if attr.startswith("_"):
                continue
            try:
                value = getattr(response, attr)
                if not callable(value):
                    logger.warning(f" {attr}: {value}")
            except Exception:
                # Best-effort diagnostics: a failing property or __str__ must
                # not prevent the fallback decision below.
                continue
        return {
            "name": "nethack_interact",
            "parameters": {
                "reasoning": "No valid response from LLM",
                "actions": ["wait"],
            },
        }

    except Exception as e:
        # logger.exception keeps the traceback, which logger.error dropped.
        logger.exception(f"Error in decide: {e}")
        # Default safe action.
        return {
            "name": "nethack_interact",
            "parameters": {"reasoning": f"Error: {str(e)}", "actions": ["wait"]},
        }

def _dump_prompt(self, turn_num: int, obs: str) -> None:
    """Write the full prompt (system, observation, tool schemas) to a text file.

    UTF-8 is used explicitly. Observations built by _format_observation
    contain emoji, which raise UnicodeEncodeError under a non-UTF-8 default
    locale encoding. Previously that error aborted decide() and forced a
    "wait". Filesystem errors are logged rather than propagated, because this
    dump is only a debugging aid.
    """
    path = f"nethack_prompt_turn_{turn_num}.txt"
    try:
        with open(path, "w", encoding="utf-8") as f:
            f.write("=== SYSTEM PROMPT ===\n")
            f.write(self.system_prompt)
            f.write("\n\n=== USER MESSAGE (OBSERVATION) ===\n")
            f.write(obs)
            f.write("\n\n=== TOOLS ===\n")
            f.write(json.dumps([tool.to_openai_tool() for tool in self.tools], indent=2))
    except OSError as err:
        logger.warning(f"Could not write prompt dump {path}: {err}")

@staticmethod
def _extract_tool_call(tool_call: Any) -> tuple[str, Any]:
    """Return ``(name, raw_arguments)`` from an object- or dict-shaped tool call.

    Supports objects exposing ``.function.name`` / ``.function.arguments`` and
    dicts shaped like ``{"function": {"name": ..., "arguments": ...}}``.
    Returns ``("", "")`` for any other shape.
    """
    function = getattr(tool_call, "function", None)
    if function is not None and hasattr(function, "name") and hasattr(function, "arguments"):
        return function.name, function.arguments
    if isinstance(tool_call, dict) and isinstance(tool_call.get("function"), dict):
        function_dict = tool_call["function"]
        return function_dict.get("name") or "", function_dict.get("arguments")
    return "", ""

@staticmethod
def _parse_tool_args(raw_args: Any) -> Dict[str, Any]:
    """Decode tool-call arguments into a dict.

    ``raw_args`` may be a JSON string or an already-decoded object. Anything
    that does not yield a dict (invalid JSON, ``None``, a list, a scalar) is
    replaced by a single-"wait" argument set, so callers can always use
    ``.get()`` on the result.
    """
    fallback = {"reasoning": "Failed to parse arguments", "actions": ["wait"]}
    if isinstance(raw_args, str):
        try:
            parsed = json.loads(raw_args)
        except json.JSONDecodeError:
            return fallback
    else:
        parsed = raw_args
    return parsed if isinstance(parsed, dict) else fallback

def _record_action(self, args: Dict[str, Any]) -> None:
    """Append a pending record for ``args`` to ``self.recent_actions`` (last 3 kept)."""
    # Number records from the previous record rather than the list length.
    # The list is capped at 3, so len()+1 labelled every action after the 4th
    # as "Turn 4" in the RECENT ACTIONS context.
    previous_turn = self.recent_actions[-1].get("turn", 0) if self.recent_actions else 0
    self.recent_actions.append(
        {
            "turn": previous_turn + 1,
            "actions": args.get("actions", [args.get("action", "unknown")]),
            "reasoning": args.get("reasoning", ""),
            "result": "pending",  # filled in by _update_last_action_result
        }
    )
    if len(self.recent_actions) > 3:
        self.recent_actions = self.recent_actions[-3:]
384
+
385
async def run_episode(self, env: NetHackEnvironment) -> Dict[str, Any]:
    """Run one episode of the ReAct agent against ``env``.

    The environment is initialized. Then, for up to ``self.max_turns`` turns,
    the current observation is formatted and ``self.decide`` picks a tool call.
    The resulting NetHack action(s) are executed. The environment is always
    terminated on exit, even when an error occurs inside the turn loop.

    Args:
        env: Environment for a single task instance. Its
            ``task_instance.impetus.instructions`` seed the system prompt, and
            ``task_instance.metadata.target_depth`` defines the objective.

    Returns:
        A statistics dict with keys ``turns``, ``max_depth``, ``final_score``,
        ``total_reward``, ``balrog_total_reward``, ``balrog_score``,
        ``terminated``, ``death_reason``, ``objectives_completed`` (0 or 1),
        ``achievements_unlocked``, ``achievement_details``, ``error`` (message
        of any exception raised during the turn loop, else ``None``),
        ``actions_taken`` and ``observations``.

    Raises:
        Exceptions from ``env.initialize()`` propagate to the caller.
        Exceptions raised inside the turn loop are caught and recorded in
        ``stats["error"]``.
    """
    # Build the system prompt from this task's instructions.
    self.system_prompt = self._create_system_prompt(env.task_instance.impetus.instructions)

    obs = await env.initialize()

    stats: Dict[str, Any] = {
        "turns": 0,
        "max_depth": 1,
        "final_score": 0,
        "total_reward": 0.0,
        "balrog_total_reward": 0.0,
        "balrog_score": 0.0,
        "terminated": False,
        "death_reason": None,
        "objectives_completed": 0,
        "achievements_unlocked": [],
        "achievement_details": {},
        "error": None,
        "actions_taken": [],  # One record per executed action
        "observations": [],  # Periodic snapshots of key observation fields
    }
    # Reaching the target depth is a single objective. Count it only once,
    # even though the agent keeps playing (and may stay at that depth).
    objective_reached = False

    try:
        # The environment's metadata is NetHackTaskInstanceMetadata.
        metadata = cast(NetHackTaskInstanceMetadata, env.task_instance.metadata)
        target_depth = metadata.target_depth

        for turn in range(self.max_turns):
            stats["turns"] = turn + 1

            formatted_obs = self._format_observation(obs)
            action = await self.decide(formatted_obs)

            action_record = {
                "turn": turn + 1,
                "action_type": action["name"],
                "action": "unknown",  # Filled in by the branch that handles it
                "action_params": action.get("parameters", {}),
                "position_before": obs.get("position", "unknown"),
                "dungeon_level": obs.get("dungeon_level", 1),
            }

            if action["name"] == "terminate":
                stats["terminated"] = True
                stats["death_reason"] = action["parameters"].get("reason", "Agent terminated")
                action_record["action"] = "terminate"
                action_record["result"] = "terminated"
                stats["actions_taken"].append(action_record)
                break

            if action["name"] == "nethack_interact":
                params = action["parameters"]

                reasoning = params.get("reasoning", "No reasoning provided")
                logger.info(f"Reasoning: {reasoning}")

                # Accept both the multi-action format ("actions") and the
                # legacy single-action format ("action").
                if "actions" in params:
                    actions_list = params["actions"]
                elif "action" in params:
                    actions_list = [params["action"]]
                else:
                    actions_list = ["wait"]
                # The model may emit a bare string (or null) instead of a list.
                # A string would otherwise be iterated character by character.
                if isinstance(actions_list, str):
                    actions_list = [actions_list]
                elif actions_list is None:
                    actions_list = ["wait"]

                for idx, act in enumerate(actions_list):
                    if obs.get("terminated", False):
                        break

                    # There is no "fight" command: attacking is done by moving into a monster.
                    if act == "fight":
                        logger.warning(
                            "'fight' is not a valid action - to attack, move into the monster!"
                        )
                        continue

                    pos_before = obs.get("position", "unknown")
                    obs = await env.step(act)

                    # Feed the game's response back into the agent's recent-action context.
                    result_msg = obs.get("message", "").rstrip("\x00").strip()
                    if not result_msg:
                        result_msg = "No message"
                    self._update_last_action_result(result_msg)

                    stats["actions_taken"].append(
                        {
                            "turn": stats["turns"],
                            "action_type": "nethack_interact",
                            "action": act,
                            # Only the first action of the batch carries the reasoning.
                            "reasoning": reasoning if idx == 0 else "continuation",
                            "position_before": pos_before,
                            "position_after": obs.get("position", "unknown"),
                            "message": obs.get("message", "").rstrip("\x00")[:100],
                            "reward": obs.get("reward_last", 0),
                            "hp": obs.get("character_stats", {}).get("hp", "unknown"),
                        }
                    )
            else:
                logger.warning(f"Unknown action: {action['name']}")
                obs = await env.step("wait")
                action_record["action"] = "wait (fallback)"
                stats["actions_taken"].append(action_record)

            # These fields are expected to be present in every observation.
            stats["max_depth"] = max(stats["max_depth"], obs["dungeon_level"])
            stats["total_reward"] = obs["total_reward"]
            stats["final_score"] = obs["score"]

            # Track achievements and Balrog rewards.
            if "achievements_unlocked" in obs:
                for ach, unlocked in obs["achievements_unlocked"].items():
                    if unlocked and ach not in stats["achievements_unlocked"]:
                        stats["achievements_unlocked"].append(ach)

            if "balrog_total_reward" in obs:
                stats["balrog_total_reward"] = obs["balrog_total_reward"]

            if "achievement_stats" in obs and "balrog_score" in obs["achievement_stats"]:
                stats["balrog_score"] = obs["achievement_stats"]["balrog_score"]

            # Snapshot key observation fields every 5 turns (turns 1, 6, 11, ...).
            if turn % 5 == 0:
                character_stats = obs.get("character_stats", {})
                stats["observations"].append(
                    {
                        "turn": turn + 1,
                        "position": obs.get("position", "unknown"),
                        "dungeon_level": obs.get("dungeon_level", 1),
                        "hp": f"{character_stats.get('hp', '?')}/{character_stats.get('max_hp', '?')}",
                        "score": obs.get("score", 0),
                        "message": obs.get("message", "").rstrip("\x00")[:100],
                    }
                )

            if obs["terminated"]:
                stats["terminated"] = True
                if "died" in obs["message"].lower():
                    stats["death_reason"] = "Character died"
                else:
                    stats["death_reason"] = "Game ended"
                break

            if not objective_reached and obs["dungeon_level"] >= target_depth:
                objective_reached = True
                logger.info(f"Objective achieved! Reached depth {target_depth}")
                stats["objectives_completed"] += 1

    except Exception as e:
        logger.error(f"Error during episode: {e}")
        stats["error"] = str(e)

    finally:
        # Always release the environment, even after an error.
        await env.terminate()

    return stats
558
+
559
+
560
async def eval_react_nethack(
    model_name: str = "gpt-4.1-nano", num_episodes: int = 3, max_turns: int = 50
) -> List[Dict[str, Any]]:
    """Run the ReAct agent on a subset of the NetHack taskset.

    Selects up to ``num_episodes`` task instances whose metadata has
    ``target_depth > 1`` and a ``"beginner"`` or ``"intermediate"``
    difficulty. One episode is run per selected instance.

    Args:
        model_name: Model used both for responses and for formatting.
        num_episodes: Maximum number of task instances to evaluate.
        max_turns: Per-episode turn budget handed to the agent.

    Returns:
        One dict per evaluated instance, with exactly one entry per instance.
        Successful episodes hold the ``run_episode`` statistics plus
        ``task_id``, ``success`` and (when metadata is available)
        ``character_role``, ``target_depth``, ``time_limit`` and
        ``difficulty``. Episodes that raised hold only ``task_id``, ``error``
        and ``success=False``.
    """
    logger.info(f"Starting NetHack evaluation with model: {model_name}")

    taskset = await create_nethack_taskset()
    logger.info(f"Loaded {len(taskset.instances)} task instances")

    llm = LM(model_name=model_name, formatting_model_name=model_name, temperature=0.7)

    # Focus on tasks that require actual exploration (target_depth > 1).
    eval_instances = [
        inst
        for inst in taskset.instances
        if hasattr(inst.metadata, "difficulty")
        and hasattr(inst.metadata, "target_depth")
        and cast(NetHackTaskInstanceMetadata, inst.metadata).target_depth > 1
        and cast(NetHackTaskInstanceMetadata, inst.metadata).difficulty
        in ["beginner", "intermediate"]
    ][:num_episodes]

    logger.info(f"Evaluating on {len(eval_instances)} instances")

    results: List[Dict[str, Any]] = []
    for i, instance in enumerate(eval_instances):
        logger.info(f"\nEpisode {i + 1}/{len(eval_instances)}")
        metadata = (
            cast(NetHackTaskInstanceMetadata, instance.metadata)
            if hasattr(instance.metadata, "character_role")
            else None
        )
        if metadata is not None:
            logger.info(f"Character: {metadata.character_role}")
            logger.info(f"Target depth: {metadata.target_depth}")
            logger.info(f"Time limit: {metadata.time_limit} turns")

        # Use a fresh agent per episode so per-episode state (e.g. the recent
        # action history fed back into prompts) cannot leak across episodes.
        agent = NetHackReActAgent(llm, max_turns=max_turns)

        try:
            env = NetHackEnvironment(instance)
            result = await agent.run_episode(env)

            result["task_id"] = str(instance.id)
            if metadata is not None:
                result["character_role"] = metadata.character_role
                result["target_depth"] = metadata.target_depth
                result["time_limit"] = metadata.time_limit
                result["difficulty"] = metadata.difficulty
                result["success"] = result["max_depth"] >= metadata.target_depth
            else:
                result["success"] = False
        except Exception as e:
            logger.error(f"Failed to run episode: {e}")
            results.append({"task_id": str(instance.id), "error": str(e), "success": False})
            continue

        # Append only after the episode fully succeeded, so an instance is
        # never recorded twice (once as a result and once as an error).
        results.append(result)

        logger.info(
            f"Episode completed - Success: {result['success']}, "
            f"Depth: {result['max_depth']}/{result.get('target_depth', '?')}, "
            f"Turns: {result['turns']}, Score: {result['final_score']}, "
            f"Balrog: {result.get('balrog_score', 0):.1f}%"
        )

    return results
628
+
629
+
630
def analyze_nethack_results(results: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Aggregate NetHack-specific performance metrics over episode results.

    An episode is errored when its ``error`` entry is present and not
    ``None``. Errored episodes count towards ``num_episodes`` and
    ``error_rate`` only. Every other metric is averaged over the remaining
    (valid) episodes.

    Args:
        results: Per-episode result dicts, as produced by ``eval_react_nethack``.

    Returns:
        - ``{}`` when ``results`` is empty.
        - ``{"error_rate": 1.0, "num_episodes": N}`` when every episode errored.
        - Otherwise, mostly float rates/averages plus:
          - ``num_episodes`` and ``total_unique_achievements`` (ints);
          - ``most_common_achievements``: up to five ``(name, count)`` tuples,
            present only when some achievement was unlocked;
          - per-difficulty and per-role breakdowns.

        The breakdowns cover the known difficulties/roles first, then any
        other string values found in the results (sorted).
    """
    from collections import Counter

    if not results:
        return {}

    # An explicit ``"error": None`` (as set by run_episode) still counts as valid.
    valid_results = [r for r in results if r.get("error") is None]
    if not valid_results:
        return {"error_rate": 1.0, "num_episodes": len(results)}

    n_valid = len(valid_results)

    def _num_achievements(r: Dict[str, Any]) -> int:
        return len(r.get("achievements_unlocked", []))

    def _categories(field: str, preferred: tuple) -> list:
        # Known values keep their fixed order; unseen string values follow, sorted.
        present = {r.get(field) for r in valid_results}
        extras = sorted(v for v in present if isinstance(v, str) and v not in preferred)
        return [*preferred, *extras]

    metrics: Dict[str, Any] = {
        "num_episodes": len(results),
        "success_rate": sum(1 for r in valid_results if r.get("success", False)) / n_valid,
        "avg_depth_reached": sum(r.get("max_depth", 1) for r in valid_results) / n_valid,
        "avg_turns": sum(r.get("turns", 0) for r in valid_results) / n_valid,
        "avg_score": sum(r.get("final_score", 0) for r in valid_results) / n_valid,
        "avg_reward": sum(r.get("total_reward", 0.0) for r in valid_results) / n_valid,
        "avg_balrog_reward": sum(r.get("balrog_total_reward", 0.0) for r in valid_results)
        / n_valid,
        "avg_balrog_score": sum(r.get("balrog_score", 0.0) for r in valid_results) / n_valid,
        "death_rate": sum(
            1 for r in valid_results if "died" in str(r.get("death_reason", "")).lower()
        )
        / n_valid,
        # Episodes without a time_limit can never time out.
        "timeout_rate": sum(
            1 for r in valid_results if r.get("turns", 0) >= r.get("time_limit", float("inf"))
        )
        / n_valid,
        "error_rate": sum(1 for r in results if r.get("error") is not None) / len(results),
        # Achievement metrics
        "avg_achievements_unlocked": sum(_num_achievements(r) for r in valid_results) / n_valid,
        "total_unique_achievements": len(
            {ach for r in valid_results for ach in r.get("achievements_unlocked", [])}
        ),
    }

    achievement_counts = Counter(
        ach for r in valid_results for ach in r.get("achievements_unlocked", [])
    )
    if achievement_counts:
        # most_common keeps first-seen order among ties, like a stable sort.
        metrics["most_common_achievements"] = achievement_counts.most_common(5)

    # Breakdown by difficulty
    difficulties = ("tutorial", "beginner", "intermediate", "advanced", "expert")
    for difficulty in _categories("difficulty", difficulties):
        group = [r for r in valid_results if r.get("difficulty") == difficulty]
        if not group:
            continue
        metrics[f"{difficulty}_success_rate"] = sum(
            1 for r in group if r.get("success", False)
        ) / len(group)
        metrics[f"{difficulty}_avg_depth"] = sum(r.get("max_depth", 1) for r in group) / len(
            group
        )
        metrics[f"{difficulty}_avg_achievements"] = sum(
            _num_achievements(r) for r in group
        ) / len(group)

    # Breakdown by character role
    roles = ("tourist", "knight", "wizard", "barbarian")
    for role in _categories("character_role", roles):
        group = [r for r in valid_results if r.get("character_role") == role]
        if not group:
            continue
        metrics[f"{role}_success_rate"] = sum(
            1 for r in group if r.get("success", False)
        ) / len(group)
        metrics[f"{role}_avg_achievements"] = sum(_num_achievements(r) for r in group) / len(
            group
        )

    return metrics
710
+
711
+
712
def _print_summary_metrics(metrics: Dict[str, Any]) -> None:
    """Print the headline BALROG score followed by the aggregate analysis metrics."""
    # OFFICIAL LEADERBOARD SCORE FIRST
    print("\n" + "=" * 80)
    print("🏆 OFFICIAL BALROG LEADERBOARD SCORE 🏆")
    print("=" * 80)
    print(f"📊 BALROG SCORE (0-100%): {metrics.get('avg_balrog_score', 0):.3f}%")
    print("📈 Current SOTA benchmark: ~1-2%")
    print(f"🎯 Episodes evaluated: {metrics.get('num_episodes', 0)}")
    print("=" * 80)
    print("⚠️ This is the ONLY score that matters for SOTA claims and leaderboard comparisons!")
    print("⚠️ All other metrics below are for analysis/debugging only.")

    print("\n" + "=" * 80)
    print("📋 ANALYSIS METRICS (Not for leaderboard comparison)")
    print("=" * 80)
    print(f"Success rate (task completion): {metrics.get('success_rate', 0):.2%}")
    print(f"Average depth reached: {metrics.get('avg_depth_reached', 0):.1f}")
    print(f"Average turns: {metrics.get('avg_turns', 0):.0f}")
    print(f"Average game score: {metrics.get('avg_score', 0):.0f}")
    print(f"Death rate: {metrics.get('death_rate', 0):.2%}")
    print(f"Error rate: {metrics.get('error_rate', 0):.2%}")

    print("\n=== Training Signal Metrics (Shaped Rewards) ===")
    print(f"Average custom reward: {metrics.get('avg_reward', 0):.2f}")
    print(f"Average Balrog shaped reward: {metrics.get('avg_balrog_reward', 0):.2f}")
    print("(These are training signals, NOT the leaderboard score)")

    print("\n=== Achievement Metrics ===")
    print(f"Average achievements unlocked: {metrics.get('avg_achievements_unlocked', 0):.1f}")
    print(f"Total unique achievements: {metrics.get('total_unique_achievements', 0)}")

    if "most_common_achievements" in metrics and metrics["most_common_achievements"]:
        print("\nMost common achievements:")
        for ach, count in metrics["most_common_achievements"]:
            print(f" {ach}: {count} times ({count / metrics['num_episodes'] * 100:.0f}%)")


def _print_episode_details(results: List[Dict[str, Any]]) -> None:
    """Print a per-episode summary (stats, achievements, first actions, observations)."""
    print("\n=== Detailed Episode Summary ===")
    for i, result in enumerate(results):
        print(f"\nEpisode {i + 1}:")
        # Errored episodes carry only task_id/error/success; surface the error.
        if result.get("error"):
            print(f" Error: {result['error']}")
        print(f" 🏆 BALROG LEADERBOARD SCORE: {result.get('balrog_score', 0):.3f}%")
        print(f" Character: {result.get('character_role', 'unknown')}")
        print(f" Target depth: {result.get('target_depth', 'unknown')}")
        print(f" Max depth reached: {result.get('max_depth', 0)}")
        print(f" Total turns: {result.get('turns', 0)}")
        print(f" Success: {result.get('success', False)}")
        print(f" Final game score: {result.get('final_score', 0)}")
        print(f" Custom shaped reward: {result.get('total_reward', 0):.2f}")
        print(f" Balrog shaped reward: {result.get('balrog_total_reward', 0):.2f}")

        achievements = result.get("achievements_unlocked")
        if achievements:
            print(f" Achievements unlocked ({len(achievements)}):")
            for ach in achievements[:10]:  # Show first 10
                print(f" - {ach}")
            if len(achievements) > 10:
                print(f" ... and {len(achievements) - 10} more")

        if result.get("actions_taken"):
            print("\n First 10 actions:")
            for action in result["actions_taken"][:10]:
                print(
                    f" Turn {action.get('turn', '?')}: {action.get('action', 'unknown')} "
                    f"(pos: {action.get('position_before', '?')} → {action.get('position_after', '?')}, "
                    f"HP: {action.get('hp', '?')})"
                )
                if action.get("reasoning") and action["reasoning"] != "continuation":
                    print(f" Reasoning: {action['reasoning'][:80]}...")
                if action.get("message", "").strip():
                    print(f" Message: {action['message'][:60]}...")

        if result.get("observations"):
            print("\n Key observations:")
            for obs in result["observations"]:
                print(
                    f" Turn {obs['turn']}: Level {obs['dungeon_level']}, "
                    f"HP: {obs['hp']}, Score: {obs['score']}"
                )


async def main():
    """CLI entry point: run the evaluation, print a report, and save results to JSON.

    Command-line options: ``--model``, ``--episodes`` and ``--max-turns``.
    Results and metrics are written to ``nethack_react_results.json`` in the
    current directory.
    """
    parser = argparse.ArgumentParser(description="Evaluate NetHack ReAct agent")
    parser.add_argument(
        "--model",
        type=str,
        default="gpt-4.1-nano",
        help="Model to use (default: gpt-4.1-nano)",
    )
    parser.add_argument(
        "--episodes", type=int, default=5, help="Number of episodes to run (default: 5)"
    )
    parser.add_argument(
        "--max-turns",
        type=int,
        default=50,
        help="Maximum turns per episode (default: 50)",
    )
    args = parser.parse_args()

    print("Running NetHack evaluation with:")
    print(f" Model: {args.model}")
    print(f" Episodes: {args.episodes}")
    print(f" Max turns: {args.max_turns}")

    results = await eval_react_nethack(
        model_name=args.model, num_episodes=args.episodes, max_turns=args.max_turns
    )
    metrics = analyze_nethack_results(results)

    _print_summary_metrics(metrics)

    with open("nethack_react_results.json", "w", encoding="utf-8") as f:
        json.dump({"results": results, "metrics": metrics}, f, indent=2)
    print("\nResults saved to nethack_react_results.json")

    # Detailed action summary for a sanity check.
    _print_episode_details(results)
829
+
830
+
831
if __name__ == "__main__":
    asyncio.run(main())  # Script entry: drive the async CLI to completion on a new event loop.