synth-ai 0.2.0__py3-none-any.whl → 0.2.1.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (266)
  1. synth_ai/__init__.py +28 -2
  2. synth_ai/core/system.py +4 -0
  3. synth_ai/environments/__init__.py +35 -0
  4. synth_ai/environments/environment/__init__.py +1 -0
  5. synth_ai/environments/environment/artifacts/__init__.py +1 -0
  6. synth_ai/environments/environment/artifacts/base.py +50 -0
  7. synth_ai/environments/environment/core.py +22 -0
  8. synth_ai/environments/environment/db/__init__.py +1 -0
  9. synth_ai/environments/environment/db/sqlite.py +45 -0
  10. synth_ai/environments/environment/registry.py +24 -0
  11. synth_ai/environments/environment/resources/sqlite.py +46 -0
  12. synth_ai/environments/environment/results.py +1 -0
  13. synth_ai/environments/environment/rewards/__init__.py +1 -0
  14. synth_ai/environments/environment/rewards/core.py +28 -0
  15. synth_ai/environments/environment/shared_engine.py +26 -0
  16. synth_ai/environments/environment/tools/__init__.py +34 -0
  17. synth_ai/environments/examples/__init__.py +1 -0
  18. synth_ai/environments/examples/crafter_classic/__init__.py +8 -0
  19. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_comprehensive_evaluation.py +58 -0
  20. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_evaluation_browser.py +152 -0
  21. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_evaluation_framework.py +1194 -0
  22. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_quick_evaluation.py +51 -0
  23. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_react_agent.py +872 -0
  24. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_trace_evaluation.py +1412 -0
  25. synth_ai/environments/examples/crafter_classic/agent_demos/test_crafter_react_agent.py +1110 -0
  26. synth_ai/environments/examples/crafter_classic/config_logging.py +111 -0
  27. synth_ai/environments/examples/crafter_classic/engine.py +502 -0
  28. synth_ai/environments/examples/crafter_classic/engine_deterministic_patch.py +63 -0
  29. synth_ai/environments/examples/crafter_classic/engine_helpers/action_map.py +5 -0
  30. synth_ai/environments/examples/crafter_classic/engine_helpers/serialization.py +74 -0
  31. synth_ai/environments/examples/crafter_classic/environment.py +255 -0
  32. synth_ai/environments/examples/crafter_classic/taskset.py +228 -0
  33. synth_ai/environments/examples/enron/agent_demos/test_synth_react.py +535 -0
  34. synth_ai/environments/examples/enron/art_helpers/email_search_tools.py +156 -0
  35. synth_ai/environments/examples/enron/art_helpers/local_email_db.py +280 -0
  36. synth_ai/environments/examples/enron/art_helpers/types_enron.py +24 -0
  37. synth_ai/environments/examples/enron/engine.py +291 -0
  38. synth_ai/environments/examples/enron/environment.py +165 -0
  39. synth_ai/environments/examples/enron/taskset.py +112 -0
  40. synth_ai/environments/examples/enron/units/keyword_stats.py +111 -0
  41. synth_ai/environments/examples/enron/units/test_email_index.py +8 -0
  42. synth_ai/environments/examples/minigrid/__init__.py +48 -0
  43. synth_ai/environments/examples/minigrid/agent_demos/minigrid_evaluation_framework.py +1188 -0
  44. synth_ai/environments/examples/minigrid/agent_demos/minigrid_quick_evaluation.py +47 -0
  45. synth_ai/environments/examples/minigrid/agent_demos/minigrid_react_agent.py +562 -0
  46. synth_ai/environments/examples/minigrid/agent_demos/minigrid_trace_evaluation.py +220 -0
  47. synth_ai/environments/examples/minigrid/agent_demos/test_minigrid_react_agent.py +393 -0
  48. synth_ai/environments/examples/minigrid/engine.py +589 -0
  49. synth_ai/environments/examples/minigrid/environment.py +274 -0
  50. synth_ai/environments/examples/minigrid/environment_mapping.py +242 -0
  51. synth_ai/environments/examples/minigrid/puzzle_loader.py +416 -0
  52. synth_ai/environments/examples/minigrid/taskset.py +583 -0
  53. synth_ai/environments/examples/minigrid/units/test_action_behavior.py +226 -0
  54. synth_ai/environments/examples/minigrid/units/test_debug_messages.py +83 -0
  55. synth_ai/environments/examples/minigrid/units/test_exploration.py +120 -0
  56. synth_ai/environments/examples/minigrid/units/test_minigrid_engine.py +214 -0
  57. synth_ai/environments/examples/minigrid/units/test_minigrid_environment.py +238 -0
  58. synth_ai/environments/examples/minigrid/units/test_minigrid_environment_mapping.py +301 -0
  59. synth_ai/environments/examples/minigrid/units/test_minigrid_taskset.py +210 -0
  60. synth_ai/environments/examples/nethack/__init__.py +7 -0
  61. synth_ai/environments/examples/nethack/achievements.py +337 -0
  62. synth_ai/environments/examples/nethack/agent_demos/nethack_evaluation_framework.py +981 -0
  63. synth_ai/environments/examples/nethack/agent_demos/nethack_quick_evaluation.py +74 -0
  64. synth_ai/environments/examples/nethack/agent_demos/nethack_react_agent.py +832 -0
  65. synth_ai/environments/examples/nethack/agent_demos/test_nethack_react_agent.py +1112 -0
  66. synth_ai/environments/examples/nethack/engine.py +738 -0
  67. synth_ai/environments/examples/nethack/environment.py +255 -0
  68. synth_ai/environments/examples/nethack/helpers/__init__.py +42 -0
  69. synth_ai/environments/examples/nethack/helpers/action_mapping.py +301 -0
  70. synth_ai/environments/examples/nethack/helpers/nle_wrapper.py +401 -0
  71. synth_ai/environments/examples/nethack/helpers/observation_utils.py +433 -0
  72. synth_ai/environments/examples/nethack/helpers/recording_wrapper.py +201 -0
  73. synth_ai/environments/examples/nethack/helpers/trajectory_recorder.py +268 -0
  74. synth_ai/environments/examples/nethack/helpers/visualization/replay_viewer.py +308 -0
  75. synth_ai/environments/examples/nethack/helpers/visualization/visualizer.py +430 -0
  76. synth_ai/environments/examples/nethack/taskset.py +323 -0
  77. synth_ai/environments/examples/nethack/units/test_nethack_engine.py +277 -0
  78. synth_ai/environments/examples/nethack/units/test_nethack_environment.py +281 -0
  79. synth_ai/environments/examples/nethack/units/test_nethack_taskset.py +213 -0
  80. synth_ai/environments/examples/nethack/units/test_recording.py +307 -0
  81. synth_ai/environments/examples/red/__init__.py +7 -0
  82. synth_ai/environments/examples/red/agent_demos/__init__.py +1 -0
  83. synth_ai/environments/examples/red/agent_demos/test_synth_react.py +1471 -0
  84. synth_ai/environments/examples/red/config_logging.py +110 -0
  85. synth_ai/environments/examples/red/engine.py +693 -0
  86. synth_ai/environments/examples/red/engine_helpers/__init__.py +1 -0
  87. synth_ai/environments/examples/red/engine_helpers/memory_map.py +28 -0
  88. synth_ai/environments/examples/red/engine_helpers/reward_components.py +275 -0
  89. synth_ai/environments/examples/red/engine_helpers/reward_library/__init__.py +142 -0
  90. synth_ai/environments/examples/red/engine_helpers/reward_library/adaptive_rewards.py +56 -0
  91. synth_ai/environments/examples/red/engine_helpers/reward_library/battle_rewards.py +283 -0
  92. synth_ai/environments/examples/red/engine_helpers/reward_library/composite_rewards.py +149 -0
  93. synth_ai/environments/examples/red/engine_helpers/reward_library/economy_rewards.py +137 -0
  94. synth_ai/environments/examples/red/engine_helpers/reward_library/efficiency_rewards.py +56 -0
  95. synth_ai/environments/examples/red/engine_helpers/reward_library/exploration_rewards.py +330 -0
  96. synth_ai/environments/examples/red/engine_helpers/reward_library/novelty_rewards.py +120 -0
  97. synth_ai/environments/examples/red/engine_helpers/reward_library/pallet_town_rewards.py +558 -0
  98. synth_ai/environments/examples/red/engine_helpers/reward_library/pokemon_rewards.py +312 -0
  99. synth_ai/environments/examples/red/engine_helpers/reward_library/social_rewards.py +147 -0
  100. synth_ai/environments/examples/red/engine_helpers/reward_library/story_rewards.py +246 -0
  101. synth_ai/environments/examples/red/engine_helpers/screen_analysis.py +367 -0
  102. synth_ai/environments/examples/red/engine_helpers/state_extraction.py +139 -0
  103. synth_ai/environments/examples/red/environment.py +235 -0
  104. synth_ai/environments/examples/red/taskset.py +77 -0
  105. synth_ai/environments/examples/red/test_fixes.py +125 -0
  106. synth_ai/environments/examples/red/test_fixes_mock.py +148 -0
  107. synth_ai/environments/examples/red/units/__init__.py +1 -0
  108. synth_ai/environments/examples/red/units/test_basic_functionality.py +97 -0
  109. synth_ai/environments/examples/red/units/test_button_press_requirements.py +217 -0
  110. synth_ai/environments/examples/red/units/test_engine.py +192 -0
  111. synth_ai/environments/examples/red/units/test_environment.py +455 -0
  112. synth_ai/environments/examples/red/units/test_exploration_strategy.py +227 -0
  113. synth_ai/environments/examples/red/units/test_integration.py +217 -0
  114. synth_ai/environments/examples/red/units/test_memory_extraction.py +111 -0
  115. synth_ai/environments/examples/red/units/test_menu_bug_reproduction.py +1100 -0
  116. synth_ai/environments/examples/red/units/test_movement_debug.py +255 -0
  117. synth_ai/environments/examples/red/units/test_pokemon_mcts_debug.py +163 -0
  118. synth_ai/environments/examples/red/units/test_pokemon_mcts_verbose.py +117 -0
  119. synth_ai/environments/examples/red/units/test_red_basic.py +145 -0
  120. synth_ai/environments/examples/red/units/test_red_comprehensive.py +323 -0
  121. synth_ai/environments/examples/red/units/test_retry_movement.py +195 -0
  122. synth_ai/environments/examples/red/units/test_reward_components.py +186 -0
  123. synth_ai/environments/examples/red/units/test_rom_integration.py +260 -0
  124. synth_ai/environments/examples/red/units/test_taskset.py +116 -0
  125. synth_ai/environments/examples/red/units/test_tree.py +448 -0
  126. synth_ai/environments/examples/sokoban/__init__.py +1 -0
  127. synth_ai/environments/examples/sokoban/agent_demos/sokoban_full_eval.py +900 -0
  128. synth_ai/environments/examples/sokoban/agent_demos/test_dspy_react.py +1 -0
  129. synth_ai/environments/examples/sokoban/agent_demos/test_sokoban_react_agent.py +498 -0
  130. synth_ai/environments/examples/sokoban/agent_demos/test_synth_lats.py +1 -0
  131. synth_ai/environments/examples/sokoban/agent_demos/test_synth_react_locally.py +748 -0
  132. synth_ai/environments/examples/sokoban/agent_demos/test_synth_react_service.py +296 -0
  133. synth_ai/environments/examples/sokoban/engine.py +675 -0
  134. synth_ai/environments/examples/sokoban/engine_helpers/__init__.py +1 -0
  135. synth_ai/environments/examples/sokoban/engine_helpers/room_utils.py +656 -0
  136. synth_ai/environments/examples/sokoban/engine_helpers/vendored/__init__.py +17 -0
  137. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/__init__.py +3 -0
  138. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/boxoban_env.py +129 -0
  139. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/render_utils.py +370 -0
  140. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/room_utils.py +331 -0
  141. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env.py +305 -0
  142. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_fixed_targets.py +66 -0
  143. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_pull.py +114 -0
  144. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_two_player.py +122 -0
  145. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_variations.py +394 -0
  146. synth_ai/environments/examples/sokoban/environment.py +228 -0
  147. synth_ai/environments/examples/sokoban/generate_verified_puzzles.py +438 -0
  148. synth_ai/environments/examples/sokoban/puzzle_loader.py +311 -0
  149. synth_ai/environments/examples/sokoban/taskset.py +425 -0
  150. synth_ai/environments/examples/sokoban/units/astar_common.py +94 -0
  151. synth_ai/environments/examples/sokoban/units/test_building_task_set.py +49 -0
  152. synth_ai/environments/examples/sokoban/units/test_false_positive.py +120 -0
  153. synth_ai/environments/examples/sokoban/units/test_simple_run_through_environment.py +119 -0
  154. synth_ai/environments/examples/sokoban/units/test_sokoban_environment.py +98 -0
  155. synth_ai/environments/examples/sokoban/units/test_tree.py +364 -0
  156. synth_ai/environments/examples/tictactoe/__init__.py +1 -0
  157. synth_ai/environments/examples/tictactoe/agent_demos/test_synth_react.py +266 -0
  158. synth_ai/environments/examples/tictactoe/agent_demos/test_tictactoe_react_agent.py +470 -0
  159. synth_ai/environments/examples/tictactoe/engine.py +368 -0
  160. synth_ai/environments/examples/tictactoe/environment.py +239 -0
  161. synth_ai/environments/examples/tictactoe/taskset.py +214 -0
  162. synth_ai/environments/examples/tictactoe/units/test_tictactoe_engine.py +393 -0
  163. synth_ai/environments/examples/tictactoe/units/test_tictactoe_environment.py +493 -0
  164. synth_ai/environments/examples/tictactoe/units/test_tictactoe_taskset.py +191 -0
  165. synth_ai/environments/examples/verilog/__init__.py +10 -0
  166. synth_ai/environments/examples/verilog/agent_demos/test_synth_react.py +520 -0
  167. synth_ai/environments/examples/verilog/engine.py +328 -0
  168. synth_ai/environments/examples/verilog/environment.py +349 -0
  169. synth_ai/environments/examples/verilog/taskset.py +418 -0
  170. synth_ai/environments/examples/verilog/units/test_verilog_engine.py +466 -0
  171. synth_ai/environments/examples/verilog/units/test_verilog_environment.py +585 -0
  172. synth_ai/environments/examples/verilog/units/test_verilog_integration.py +383 -0
  173. synth_ai/environments/examples/verilog/units/test_verilog_taskset.py +457 -0
  174. synth_ai/environments/reproducibility/core.py +42 -0
  175. synth_ai/environments/reproducibility/tree.py +364 -0
  176. synth_ai/environments/service/app.py +78 -0
  177. synth_ai/environments/service/core_routes.py +775 -0
  178. synth_ai/environments/service/external_registry.py +57 -0
  179. synth_ai/environments/service/registry.py +9 -0
  180. synth_ai/environments/stateful/__init__.py +1 -0
  181. synth_ai/environments/stateful/core.py +28 -0
  182. synth_ai/environments/stateful/engine.py +21 -0
  183. synth_ai/environments/stateful/state.py +7 -0
  184. synth_ai/environments/tasks/api.py +19 -0
  185. synth_ai/environments/tasks/core.py +78 -0
  186. synth_ai/environments/tasks/filters.py +39 -0
  187. synth_ai/environments/tasks/utils.py +89 -0
  188. synth_ai/environments/v0_observability/history.py +3 -0
  189. synth_ai/environments/v0_observability/log.py +2 -0
  190. synth_ai/lm/caching/constants.py +1 -0
  191. synth_ai/{zyk/lms → lm}/caching/ephemeral.py +4 -8
  192. synth_ai/{zyk/lms → lm}/caching/handler.py +15 -15
  193. synth_ai/{zyk/lms → lm}/caching/initialize.py +2 -4
  194. synth_ai/{zyk/lms → lm}/caching/persistent.py +4 -10
  195. synth_ai/{zyk/lms → lm}/config.py +2 -1
  196. synth_ai/{zyk/lms → lm}/constants.py +2 -2
  197. synth_ai/{zyk/lms → lm}/core/all.py +10 -10
  198. synth_ai/{zyk/lms → lm}/core/main.py +57 -33
  199. synth_ai/{zyk/lms → lm}/core/vendor_clients.py +12 -10
  200. synth_ai/lm/cost/monitor.py +1 -0
  201. synth_ai/lm/cost/statefulness.py +1 -0
  202. synth_ai/lm/provider_support/__init__.py +8 -0
  203. synth_ai/lm/provider_support/anthropic.py +945 -0
  204. synth_ai/lm/provider_support/openai.py +1115 -0
  205. synth_ai/lm/provider_support/suppress_logging.py +31 -0
  206. synth_ai/{zyk/lms → lm}/structured_outputs/handler.py +58 -80
  207. synth_ai/{zyk/lms → lm}/structured_outputs/inject.py +6 -20
  208. synth_ai/{zyk/lms → lm}/structured_outputs/rehabilitate.py +6 -12
  209. synth_ai/{zyk/lms → lm}/vendors/core/anthropic_api.py +21 -30
  210. synth_ai/{zyk/lms → lm}/vendors/core/gemini_api.py +35 -32
  211. synth_ai/{zyk/lms → lm}/vendors/core/mistral_api.py +19 -28
  212. synth_ai/{zyk/lms → lm}/vendors/core/openai_api.py +26 -36
  213. synth_ai/{zyk/lms → lm}/vendors/openai_standard.py +29 -33
  214. synth_ai/{zyk/lms → lm}/vendors/retries.py +1 -1
  215. synth_ai/lm/vendors/supported/__init__.py +0 -0
  216. synth_ai/{zyk/lms → lm}/vendors/supported/custom_endpoint.py +131 -118
  217. synth_ai/{zyk/lms → lm}/vendors/supported/deepseek.py +4 -8
  218. synth_ai/{zyk/lms → lm}/vendors/supported/grok.py +6 -8
  219. synth_ai/{zyk/lms → lm}/vendors/supported/groq.py +1 -1
  220. synth_ai/{zyk/lms → lm}/vendors/supported/ollama.py +2 -2
  221. synth_ai/{zyk/lms → lm}/vendors/supported/openrouter.py +18 -16
  222. synth_ai/{zyk/lms → lm}/vendors/supported/together.py +1 -1
  223. synth_ai/tracing/__init__.py +0 -0
  224. synth_ai/tracing/abstractions.py +224 -0
  225. synth_ai/tracing/base_client.py +91 -0
  226. synth_ai/tracing/client_manager.py +131 -0
  227. synth_ai/tracing/config.py +140 -0
  228. synth_ai/tracing/context.py +146 -0
  229. synth_ai/tracing/decorators.py +679 -0
  230. synth_ai/tracing/events/__init__.py +0 -0
  231. synth_ai/tracing/events/manage.py +147 -0
  232. synth_ai/tracing/events/scope.py +86 -0
  233. synth_ai/tracing/events/store.py +227 -0
  234. synth_ai/tracing/immediate_client.py +152 -0
  235. synth_ai/tracing/local.py +18 -0
  236. synth_ai/tracing/log_client_base.py +74 -0
  237. synth_ai/tracing/retry_queue.py +187 -0
  238. synth_ai/tracing/trackers.py +515 -0
  239. synth_ai/tracing/upload.py +504 -0
  240. synth_ai/tracing/utils.py +9 -0
  241. synth_ai/zyk/__init__.py +28 -2
  242. synth_ai-0.2.1.dev0.dist-info/METADATA +349 -0
  243. synth_ai-0.2.1.dev0.dist-info/RECORD +261 -0
  244. {synth_ai-0.2.0.dist-info → synth_ai-0.2.1.dev0.dist-info}/WHEEL +1 -1
  245. synth_ai/zyk/lms/caching/constants.py +0 -1
  246. synth_ai/zyk/lms/cost/monitor.py +0 -1
  247. synth_ai/zyk/lms/cost/statefulness.py +0 -1
  248. synth_ai-0.2.0.dist-info/METADATA +0 -36
  249. synth_ai-0.2.0.dist-info/RECORD +0 -50
  250. /synth_ai/{zyk/lms/__init__.py → environments/reproducibility/helpers.py} +0 -0
  251. /synth_ai/{zyk/lms/caching → lm}/__init__.py +0 -0
  252. /synth_ai/{zyk/lms/core → lm/caching}/__init__.py +0 -0
  253. /synth_ai/{zyk/lms → lm}/caching/dbs.py +0 -0
  254. /synth_ai/{zyk/lms/cost → lm/core}/__init__.py +0 -0
  255. /synth_ai/{zyk/lms → lm}/core/exceptions.py +0 -0
  256. /synth_ai/{zyk/lms/structured_outputs → lm/cost}/__init__.py +0 -0
  257. /synth_ai/{zyk/lms/vendors → lm/structured_outputs}/__init__.py +0 -0
  258. /synth_ai/{zyk/lms → lm}/tools/__init__.py +0 -0
  259. /synth_ai/{zyk/lms → lm}/tools/base.py +0 -0
  260. /synth_ai/{zyk/lms/vendors/core → lm/vendors}/__init__.py +0 -0
  261. /synth_ai/{zyk/lms → lm}/vendors/base.py +0 -0
  262. /synth_ai/{zyk/lms/vendors/local → lm/vendors/core}/__init__.py +0 -0
  263. /synth_ai/{zyk/lms/vendors/supported → lm/vendors/local}/__init__.py +0 -0
  264. /synth_ai/{zyk/lms → lm}/vendors/local/ollama.py +0 -0
  265. {synth_ai-0.2.0.dist-info → synth_ai-0.2.1.dev0.dist-info/licenses}/LICENSE +0 -0
  266. {synth_ai-0.2.0.dist-info → synth_ai-0.2.1.dev0.dist-info}/top_level.txt +0 -0
synth_ai/environments/examples/nethack/agent_demos/test_nethack_react_agent.py
@@ -0,0 +1,1112 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Test script to run ReAct agents against the NetHack environment on the synth service (port 8901).
4
+ Runs multiple easy NetHack instances with enhanced debugging.
5
+ """
6
+
7
+ import asyncio
8
+ import json
9
+ import uuid
10
+ from datetime import datetime
11
+ from typing import Dict, Any, Optional, List
12
+ from pydantic import BaseModel, Field
13
+ from httpx import AsyncClient
14
+ import sys
15
+ import os
16
+ from tqdm import tqdm
17
+
18
+ # Add the src directory to the path
19
+ sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..", "..", "..", "src"))
20
+
21
+ from synth_ai.zyk import LM
22
+ from synth_ai.zyk.lms.tools.base import BaseTool
23
+
24
+
25
+ # --- Configuration Class ---
26
+ class NetHackConfig:
27
+ """Configuration for NetHack evaluation (mirrors CrafterConfig)."""
28
+
29
+ def __init__(self, config_path: Optional[str] = None):
30
+ # Default values
31
+ self.model_name = "gpt-4.1-mini"
32
+ self.num_instances = 2
33
+ self.max_turns = 40
34
+ self.difficulty = "beginner"
35
+ self.service_base_url = "http://localhost:8901"
36
+ self.service_timeout = 30.0
37
+ self.seed = 42
38
+ self.save_traces = True
39
+ self.save_detailed_results = True
40
+
41
+ # Load from TOML if supplied
42
+ if config_path and os.path.exists(config_path):
43
+ try:
44
+ import toml
45
+
46
+ cfg = toml.load(config_path)
47
+
48
+ eval_cfg = cfg.get("eval", {})
49
+ self.model_name = eval_cfg.get("model_name", self.model_name)
50
+ self.num_instances = eval_cfg.get("episodes", self.num_instances)
51
+ self.max_turns = eval_cfg.get("max_steps", self.max_turns)
52
+ self.difficulty = eval_cfg.get("difficulty", self.difficulty)
53
+ self.seed = eval_cfg.get("seed", self.seed)
54
+
55
+ svc_cfg = cfg.get("service", {})
56
+ self.service_base_url = svc_cfg.get("base_url", self.service_base_url)
57
+ self.service_timeout = svc_cfg.get("timeout", self.service_timeout)
58
+
59
+ out_cfg = cfg.get("output", {})
60
+ self.save_traces = out_cfg.get("save_traces", self.save_traces)
61
+ self.save_detailed_results = out_cfg.get(
62
+ "save_detailed_results", self.save_detailed_results
63
+ )
64
+ except Exception as e:
65
+ print(f"[WARNING] Failed to load config from {config_path}: {e}")
66
+
67
+
68
+ # Instantiate default config (may be overridden by CLI later)
69
+ config = NetHackConfig()
70
+
71
+
72
+ # Overwrite the original global constants to use config values (so rest of script works unchanged)
73
+ def _apply_config_to_globals(cfg: NetHackConfig):
74
+ globals()["MODEL_NAME"] = cfg.model_name
75
+ globals()["NUM_INSTANCES"] = cfg.num_instances
76
+ globals()["MAX_TURNS"] = cfg.max_turns
77
+ globals()["DIFFICULTY"] = cfg.difficulty
78
+ globals()["SERVICE_BASE_URL"] = cfg.service_base_url
79
+
80
+
81
+ _apply_config_to_globals(config)
82
+
83
+ # --- CLI Override (similar to Crafter script) ---
84
+ # CLI parsing moved to end of file after main() is defined
85
+
86
+
87
+ # --- Service Configuration ---
88
+ SERVICE_BASE_URL = "http://localhost:8901"
89
+ MODEL_NAME = "gpt-4.1-mini"
90
+ NUM_INSTANCES = 2
91
+ MAX_TURNS = 40
92
+ DIFFICULTY = "beginner" # beginner, intermediate, advanced, expert
93
+
94
+
95
+ # --- Tool Definitions ---
96
+ class NetHackActionArgs(BaseModel):
97
+ """Arguments for nethack actions."""
98
+
99
+ actions: List[str] = Field(
100
+ description="List of 1-3 action names to execute in sequence (e.g., ['north', 'search', 'inventory'])"
101
+ )
102
+ reasoning: str = Field(description="Brief explanation of why these actions were chosen")
103
+
104
+
105
+ class TerminateArgs(BaseModel):
106
+ """Arguments for termination."""
107
+
108
+ reason: str = Field(description="Reason for termination")
109
+
110
+
111
+ class NetHackActionTool(BaseTool):
112
+ """Tool for performing actions in the NetHack environment."""
113
+
114
+ name: str = "interact"
115
+ arguments: type[BaseModel] = NetHackActionArgs
116
+ description: str = "Perform 1-3 actions in sequence in the NetHack environment."
117
+
118
+
119
+ class TerminateTool(BaseTool):
120
+ """Tool to terminate the episode."""
121
+
122
+ name: str = "terminate"
123
+ arguments: type[BaseModel] = TerminateArgs
124
+ description: str = "End the episode when finished or no progress can be made."
125
+
126
+
127
+ # --- Base ReAct Agent ---
128
+ class BaseReActAgent:
129
+ """Base ReAct agent for environment interaction."""
130
+
131
+ def __init__(self, llm: LM, max_turns: int = 30, verbose: bool = False):
132
+ self.llm = llm
133
+ self.max_turns = max_turns
134
+ self.verbose = verbose
135
+ self.history = []
136
+ self.system_name = "base-react-agent"
137
+
138
+ # Define tools in OpenAI format
139
+ self.tools = [
140
+ NetHackActionTool(),
141
+ TerminateTool(),
142
+ ]
143
+
144
+ async def decide(self, obs: str, system_message: str, turn: int) -> Dict[str, Any]:
145
+ """Get agent decision based on observation."""
146
+ # Create conversation context
147
+ context = f"Turn {turn + 1}/{self.max_turns}\n\n{obs}"
148
+
149
+ # Generate response using LLM
150
+ response_obj = await self.llm.respond_async(
151
+ system_message=system_message, user_message=context, tools=self.tools
152
+ )
153
+
154
+ tool_calls = response_obj.tool_calls
155
+
156
+ # Handle case where tool_calls is None or empty (graceful fallback)
157
+ if not tool_calls:
158
+ if self.verbose:
159
+ print(f"[WARNING] No tool calls returned by LLM, using default action")
160
+ return {
161
+ "name": "interact",
162
+ "parameters": {
163
+ "actions": ["inventory"],
164
+ "reasoning": "Default action - no tool call received",
165
+ },
166
+ }
167
+
168
+ tool_call_data = tool_calls[0]
169
+
170
+ # Handle both dict and object formats
171
+ if isinstance(tool_call_data, dict):
172
+ tool_name = tool_call_data["function"]["name"]
173
+ tool_args_str = tool_call_data["function"]["arguments"]
174
+ else:
175
+ tool_name = tool_call_data.function.name
176
+ tool_args_str = tool_call_data.function.arguments
177
+
178
+ tool_arguments = json.loads(tool_args_str)
179
+
180
+ return {"name": tool_name, "parameters": tool_arguments}
181
+
182
+
183
+ # --- NetHack ReAct Agent ---
184
+ class NetHackReActAgent(BaseReActAgent):
185
+ """ReAct agent for NetHack environment."""
186
+
187
+ def __init__(self, llm: LM, max_turns: int = 30, verbose: bool = False):
188
+ super().__init__(llm, max_turns, verbose)
189
+ self.system_name = "nethack-react-agent"
190
+
191
+ def get_system_message(self) -> str:
192
+ return """You are an expert NetHack player. Your goal is to explore the dungeon, survive, and make progress.
193
+
194
+ MOVEMENT ACTIONS:
195
+ - north, south, east, west: Move in cardinal directions
196
+ - northeast, northwest, southeast, southwest: Move diagonally
197
+ - go_up, go_down: Use stairs (must be on < or > symbol)
198
+
199
+ EXPLORATION ACTIONS:
200
+ - search: Look for secret doors or traps
201
+ - open: Open doors
202
+ - close: Close doors
203
+ - look: Examine surroundings (FREE ACTION)
204
+
205
+ INVENTORY ACTIONS:
206
+ - inventory: Check your items (FREE ACTION)
207
+ - pickup: Pick up items
208
+ - drop: Drop items
209
+ - wear: Put on armor
210
+ - wield: Equip weapon
211
+ - eat: Consume food
212
+ - drink: Drink potion
213
+ - read: Read scroll
214
+
215
+ INTERACTION:
216
+ - wait: Rest for one turn
217
+ - chat: Talk to NPCs
218
+ - pay: Pay shopkeeper
219
+ - kick: Kick something
220
+
221
+ MAP SYMBOLS:
222
+ - @ = you (the player)
223
+ - . = floor
224
+ - # = corridor (walls are - and |)
225
+ - + = closed door
226
+ - - = open door
227
+ - < = stairs up
228
+ - > = stairs down
229
+ - $ = gold
230
+ - % = food
231
+ - ! = potion
232
+ - ? = scroll
233
+ - / = wand
234
+ - ) = weapon
235
+ - [ = armor
236
+ - d,f = pets (dog/cat)
237
+ - Letters = monsters
238
+
239
+ STRATEGY:
240
+ 1. Explore systematically to map the dungeon
241
+ 2. Collect useful items and gold
242
+ 3. Manage hunger by eating food
243
+ 4. Fight weak monsters for experience
244
+ 5. Use 'look' and 'inventory' frequently (they're free!)
245
+ 6. Be cautious around unknown monsters
246
+
247
+ Remember: NetHack is complex but rewarding. Take your time and observe carefully."""
248
+
249
+ def format_observation(self, obs: Dict[str, Any]) -> str:
250
+ """Format observation for NetHack."""
251
+ parts = []
252
+
253
+ if "ascii_map" in obs:
254
+ parts.append("ASCII Map:")
255
+ parts.append(obs["ascii_map"])
256
+
257
+ if "message" in obs and obs["message"]:
258
+ parts.append(f"Message: {obs['message']}")
259
+
260
+ if "character_stats" in obs:
261
+ stats = obs["character_stats"]
262
+ stat_items = []
263
+ for key, value in stats.items():
264
+ if key in ["HP", "level", "gold", "score", "turn"]:
265
+ stat_items.append(f"{key}: {value}")
266
+ if stat_items:
267
+ parts.append(f"Stats: {', '.join(stat_items)}")
268
+
269
+ if "inventory_summary" in obs:
270
+ parts.append(f"Inventory: {obs['inventory_summary']}")
271
+
272
+ if "hunger_status" in obs and obs["hunger_status"]:
273
+ parts.append(f"Hunger: {obs['hunger_status']}")
274
+
275
+ if "terminated" in obs:
276
+ parts.append(f"Terminated: {obs['terminated']}")
277
+
278
+ if "reward" in obs:
279
+ parts.append(f"Reward: {obs['reward']}")
280
+
281
+ return "\n".join(parts) if parts else "No formatted observation available"
282
+
283
+
284
+ # --- Episode Runner ---
285
+ async def run_single_episode(
286
+ client: AsyncClient,
287
+ agent: NetHackReActAgent,
288
+ task_instance,
289
+ instance_num: int,
290
+ progress_bar=None,
291
+ ) -> Dict[str, Any]:
292
+ """Run a single NetHack episode and return episode metrics."""
293
+ try:
294
+ # Create environment using the task instance
295
+ create_resp = await client.post(
296
+ f"/env/NetHack/initialize", json={"task_instance": await task_instance.serialize()}
297
+ )
298
+
299
+ if create_resp.status_code != 200:
300
+ print(
301
+ f" Instance {instance_num}: Failed to create environment - {create_resp.status_code}: {create_resp.text}"
302
+ )
303
+ return {"eval_metric": 0.0, "rubric": {}, "error": True}
304
+
305
+ env_id = create_resp.json()["env_id"]
306
+
307
+ # Get initial observation
308
+ obs = create_resp.json()["observation"]
309
+ formatted_obs = agent.format_observation(obs)
310
+
311
+ # DEBUG: Print initial state
312
+ # print(f"\n Instance {instance_num}: Starting NetHack adventure")
313
+ # print(f" Character: {task_instance.metadata.character_role}")
314
+ # print(f" Goal: Reach depth {task_instance.metadata.target_depth}")
315
+
316
+ # Track progress
317
+ initial_depth = 1
318
+ max_depth_reached = initial_depth
319
+ max_reward = 0.0
320
+ final_stats = {}
321
+ balrog_score = 0.0
322
+ balrog_total_reward = 0.0
323
+ achievements_unlocked = []
324
+
325
+ # Track additional progress metrics
326
+ monsters_killed = 0
327
+ items_picked_up = 0
328
+ scrolls_read = 0
329
+ potions_drunk = 0
330
+ rooms_explored = 0
331
+ secret_doors_found = 0
332
+ stairs_found = 0
333
+ traps_encountered = 0
334
+ spells_cast = 0
335
+ prayers_attempted = 0
336
+ max_score = 0
337
+
338
+ # Track shaped rewards (requires previous observation)
339
+ prev_obs = None
340
+ shaped_rewards = {
341
+ # Survival & Progress
342
+ "depth_delta_total": 0.0,
343
+ "stairs_seen_total": 0,
344
+ "turn_alive_total": 0.0,
345
+ "hp_gain_total": 0.0,
346
+ "hunger_ok_total": 0,
347
+ # Exploration
348
+ "new_tiles_total": 0,
349
+ "rooms_explored_delta_total": 0,
350
+ "secret_doors_delta_total": 0,
351
+ "traps_identified_delta_total": 0,
352
+ # Combat
353
+ "monsters_killed_delta_total": 0,
354
+ "dmg_dealt_total": 0.0,
355
+ "dmg_taken_total": 0.0,
356
+ # Resources
357
+ "gold_delta_total": 0,
358
+ "items_picked_delta_total": 0,
359
+ "scrolls_read_delta_total": 0,
360
+ "potions_quaffed_delta_total": 0,
361
+ "spells_cast_delta_total": 0,
362
+ # Skill/Utility
363
+ "first_prayer_achieved": False,
364
+ "first_spell_achieved": False,
365
+ "identify_item_total": 0,
366
+ # Achievements
367
+ "achievement_unlocked_total": 0,
368
+ # Intermediate reward accumulation
369
+ "total_intermediate_reward": 0.0,
370
+ }
371
+
372
+ # Run episode
373
+ for turn in range(agent.max_turns):
374
+ # Get agent decision
375
+ action = await agent.decide(formatted_obs, agent.get_system_message(), turn)
376
+
377
+ # Check for termination
378
+ if action["name"] == "terminate":
379
+ print(
380
+ f" Agent terminated: {action['parameters'].get('reason', 'no reason given')}"
381
+ )
382
+ break
383
+
384
+ # Execute actions in environment
385
+ action_sequence = action["parameters"]["actions"]
386
+
387
+ step_resp = await client.post(
388
+ f"/env/NetHack/step",
389
+ json={
390
+ "env_id": env_id,
391
+ "request_id": str(uuid.uuid4()),
392
+ "action": {
393
+ "tool_calls": [{"tool": "interact", "args": {"actions": action_sequence}}]
394
+ },
395
+ },
396
+ )
397
+
398
+ if step_resp.status_code != 200:
399
+ print(f" āŒ Step failed: {step_resp.status_code}: {step_resp.text}")
400
+ break
401
+
402
+ obs = step_resp.json()["observation"]
403
+ formatted_obs = agent.format_observation(obs)
404
+
405
+ # Calculate shaped rewards if we have a previous observation
406
+ if prev_obs is not None:
407
+ # --- Survival & Progress ---
408
+ current_depth = obs.get("character_stats", {}).get("dungeon_level", 1)
409
+ prev_depth = prev_obs.get("character_stats", {}).get("dungeon_level", 1)
410
+ depth_delta = current_depth - prev_depth
411
+ shaped_rewards["depth_delta_total"] += depth_delta
412
+
413
+ stairs_seen = int(obs.get("stairs_found", 0) > prev_obs.get("stairs_found", 0))
414
+ shaped_rewards["stairs_seen_total"] += stairs_seen
415
+
416
+ shaped_rewards["turn_alive_total"] += 0.01 # tiny tick reward every step survived
417
+
418
+ # HP calculations
419
+ current_hp = obs.get("character_stats", {}).get("hp", 1)
420
+ current_max_hp = obs.get("character_stats", {}).get("max_hp", 1)
421
+ prev_hp = prev_obs.get("character_stats", {}).get("hp", 1)
422
+ prev_max_hp = prev_obs.get("character_stats", {}).get("max_hp", 1)
423
+
424
+ if current_max_hp > 0 and prev_max_hp > 0:
425
+ hp_pct = current_hp / current_max_hp
426
+ prev_hp_pct = prev_hp / prev_max_hp
427
+ hp_gain = hp_pct - prev_hp_pct
428
+ shaped_rewards["hp_gain_total"] += hp_gain
429
+
430
+ hunger_ok = int(obs.get("hunger_status", "") in ("Not hungry", "Satiated", ""))
431
+ shaped_rewards["hunger_ok_total"] += hunger_ok
432
+
433
+ # --- Exploration ---
434
+ new_tiles = obs.get("exploration_stats", {}).get("new_tiles", 0)
435
+ shaped_rewards["new_tiles_total"] += new_tiles
436
+
437
+ rooms_explored_delta = obs.get("rooms_explored", 0) - prev_obs.get(
438
+ "rooms_explored", 0
439
+ )
440
+ shaped_rewards["rooms_explored_delta_total"] += rooms_explored_delta
441
+
442
+ secret_doors_delta = obs.get("secret_doors_found", 0) - prev_obs.get(
443
+ "secret_doors_found", 0
444
+ )
445
+ shaped_rewards["secret_doors_delta_total"] += secret_doors_delta
446
+
447
+ traps_identified_delta = obs.get("traps_encountered", 0) - prev_obs.get(
448
+ "traps_encountered", 0
449
+ )
450
+ shaped_rewards["traps_identified_delta_total"] += traps_identified_delta
451
+
452
+ # --- Combat ---
453
+ monsters_killed_delta = obs.get("achievement_stats", {}).get(
454
+ "monsters_killed", 0
455
+ ) - prev_obs.get("achievement_stats", {}).get("monsters_killed", 0)
456
+ shaped_rewards["monsters_killed_delta_total"] += monsters_killed_delta
457
+
458
+ dmg_dealt = obs.get("combat", {}).get("damage_dealt", 0)
459
+ shaped_rewards["dmg_dealt_total"] += dmg_dealt
460
+
461
+ dmg_taken = obs.get("combat", {}).get("damage_taken", 0)
462
+ shaped_rewards["dmg_taken_total"] += dmg_taken
463
+
464
+ # --- Resources ---
465
+ gold_delta = obs.get("character_stats", {}).get("gold", 0) - prev_obs.get(
466
+ "character_stats", {}
467
+ ).get("gold", 0)
468
+ shaped_rewards["gold_delta_total"] += gold_delta
469
+
470
+ items_picked_delta = obs.get("items_collected", 0) - prev_obs.get(
471
+ "items_collected", 0
472
+ )
473
+ shaped_rewards["items_picked_delta_total"] += items_picked_delta
474
+
475
+ scrolls_read_delta = obs.get("scrolls_read", 0) - prev_obs.get("scrolls_read", 0)
476
+ shaped_rewards["scrolls_read_delta_total"] += scrolls_read_delta
477
+
478
+ potions_quaffed_delta = obs.get("potions_drunk", 0) - prev_obs.get(
479
+ "potions_drunk", 0
480
+ )
481
+ shaped_rewards["potions_quaffed_delta_total"] += potions_quaffed_delta
482
+
483
+ spells_cast_delta = obs.get("spells_cast", 0) - prev_obs.get("spells_cast", 0)
484
+ shaped_rewards["spells_cast_delta_total"] += spells_cast_delta
485
+
486
+ # --- Skill/Utility ---
487
+ if (
488
+ obs.get("prayers_attempted", 0) > 0
489
+ and prev_obs.get("prayers_attempted", 0) == 0
490
+ ):
491
+ shaped_rewards["first_prayer_achieved"] = True
492
+
493
+ if spells_cast_delta > 0 and prev_obs.get("spells_cast", 0) == 0:
494
+ shaped_rewards["first_spell_achieved"] = True
495
+
496
+ message = obs.get("message", "")
497
+ if isinstance(message, bytes):
498
+ message = message.decode("ascii", errors="ignore").strip("\x00")
499
+ if "You identify" in message:
500
+ shaped_rewards["identify_item_total"] += 1
501
+
502
+ # --- Achievements ---
503
+ current_achievements = obs.get("achievements_unlocked", {})
504
+ prev_achievements = prev_obs.get("achievements_unlocked", {})
505
+ achievement_unlocked = sum(
506
+ int(v and not prev_achievements.get(k, False))
507
+ for k, v in current_achievements.items()
508
+ )
509
+ shaped_rewards["achievement_unlocked_total"] += achievement_unlocked
510
+
511
+ # --- Calculate intermediate reward ---
512
+ intermediate_reward = (
513
+ 1.0 * depth_delta
514
+ + 0.2 * new_tiles
515
+ + 2.0 * monsters_killed_delta
516
+ - 0.5 * dmg_taken / 10
517
+ + 0.1 * gold_delta
518
+ + 5.0 * achievement_unlocked
519
+ )
520
+ shaped_rewards["total_intermediate_reward"] += intermediate_reward
521
+
522
+ # Store current observation as previous for next iteration
523
+ prev_obs = obs.copy() if obs else None
524
+
525
+ # Track progress
526
+ if "character_stats" in obs:
527
+ final_stats = obs["character_stats"]
528
+ if "dungeon_level" in final_stats:
529
+ current_depth = final_stats["dungeon_level"]
530
+ max_depth_reached = max(max_depth_reached, current_depth)
531
+
532
+ reward = obs.get("reward", 0.0)
533
+ max_reward = max(max_reward, reward)
534
+
535
+ # Track achievements and Balrog rewards (like in main agent)
536
+ if "achievements_unlocked" in obs:
537
+ for ach, unlocked in obs["achievements_unlocked"].items():
538
+ if unlocked and ach not in achievements_unlocked:
539
+ achievements_unlocked.append(ach)
540
+
541
+ if "balrog_total_reward" in obs:
542
+ balrog_total_reward = obs["balrog_total_reward"]
543
+
544
+ if "achievement_stats" in obs and "balrog_score" in obs["achievement_stats"]:
545
+ balrog_score = obs["achievement_stats"]["balrog_score"]
546
+
547
+ # Track additional progress metrics from achievement stats
548
+ if "achievement_stats" in obs:
549
+ ach_stats = obs["achievement_stats"]
550
+ monsters_killed = ach_stats.get("monsters_killed", 0)
551
+ items_picked_up = ach_stats.get("items_collected", 0)
552
+ rooms_explored = ach_stats.get("rooms_explored", 0)
553
+ secret_doors_found = ach_stats.get("secret_doors_found", 0)
554
+ stairs_found = ach_stats.get("stairs_found", 0)
555
+
556
+ # Track score progression
557
+ current_score = obs.get("score", 0)
558
+ max_score = max(max_score, current_score)
559
+
560
+ # Parse message for additional events
561
+ message = obs.get("message", "")
562
+ if isinstance(message, bytes):
563
+ message = message.decode("ascii", errors="ignore").strip("\x00")
564
+
565
+ # Look for specific events in messages
566
+ if "You read" in message:
567
+ scrolls_read += 1
568
+ elif "You drink" in message:
569
+ potions_drunk += 1
570
+ elif "You cast" in message or "spell" in message.lower():
571
+ spells_cast += 1
572
+ elif "You pray" in message:
573
+ prayers_attempted += 1
574
+ elif "trap" in message.lower():
575
+ traps_encountered += 1
576
+
577
+ # Check if episode ended
578
+ terminated = obs.get("terminated", False)
579
+
580
+ if terminated:
581
+ print(
582
+ f" šŸ“Š Instance {instance_num}: Episode ended at depth {max_depth_reached}, reward: {max_reward:.3f}"
583
+ )
584
+ break
585
+
586
+ # Update progress bar
587
+ if progress_bar is not None:
588
+ progress_bar.update(1)
589
+
590
+ # Cleanup
591
+ await client.post(f"/env/NetHack/terminate", json={"env_id": env_id})
592
+
593
+ # Ensure progress bar completes
594
+ if progress_bar is not None:
595
+ progress_bar.n = progress_bar.total
596
+ progress_bar.close()
597
+
598
+ # Calculate eval metric and rubric
599
+ target_depth = task_instance.metadata.target_depth
600
+
601
+ # Balrog score: Use proper score from observation (like in main agent)
602
+ # This is the standard NetHack evaluation metric
603
+
604
+ # Eval metric is the normalized Balrog score (0-1)
605
+ eval_metric = balrog_score / 100.0
606
+
607
+ # Create rubric with specific achievements
608
+ rubric = {
609
+ # Core progression metrics
610
+ "reached_target_depth": 1.0 if max_depth_reached >= target_depth else 0.0,
611
+ "depth_progress": min(1.0, max_depth_reached / target_depth),
612
+ "gained_experience": 1.0 if final_stats.get("experience", 0) > 0 else 0.0,
613
+ "collected_gold": 1.0 if final_stats.get("gold", 0) > 100 else 0.0,
614
+ "gained_levels": 1.0 if final_stats.get("level", 1) > 1 else 0.0,
615
+ "survived_turns": min(1.0, len(agent.history) / 20.0), # Normalize to 20 turns
616
+ "positive_reward": 1.0 if max_reward > 0 else 0.0,
617
+ "achievement_fraction": len(achievements_unlocked)
618
+ / 100.0, # Core Balrog metric (approximated)
619
+ # Combat and interaction metrics
620
+ "monsters_defeated": min(1.0, monsters_killed / 5.0), # Normalize to 5 kills
621
+ "items_collected": min(1.0, items_picked_up / 10.0), # Normalize to 10 items
622
+ "scrolls_used": min(1.0, scrolls_read / 3.0), # Normalize to 3 scrolls
623
+ "potions_used": min(1.0, potions_drunk / 2.0), # Normalize to 2 potions
624
+ "spells_cast": min(1.0, spells_cast / 2.0), # Normalize to 2 spells
625
+ # Exploration metrics
626
+ "rooms_explored": min(1.0, rooms_explored / 5.0), # Normalize to 5 rooms
627
+ "secret_doors_found": 1.0 if secret_doors_found > 0 else 0.0,
628
+ "stairs_found": 1.0 if stairs_found > 0 else 0.0,
629
+ "traps_encountered": 1.0 if traps_encountered > 0 else 0.0,
630
+ # Advanced metrics
631
+ "prayers_attempted": 1.0 if prayers_attempted > 0 else 0.0,
632
+ "score_progress": min(1.0, max_score / 100.0), # Normalize to 100 points
633
+ "active_exploration": 1.0
634
+ if (rooms_explored + secret_doors_found + stairs_found) > 0
635
+ else 0.0,
636
+ "item_interaction": 1.0 if (scrolls_read + potions_drunk + spells_cast) > 0 else 0.0,
637
+ # --- Shaped Rewards ---
638
+ # Survival & Progress
639
+ "depth_progress_reward": max(0.0, shaped_rewards["depth_delta_total"]),
640
+ "stairs_discovery_reward": min(1.0, shaped_rewards["stairs_seen_total"] / 5.0),
641
+ "survival_reward": min(
642
+ 1.0, shaped_rewards["turn_alive_total"] / 1.0
643
+ ), # Normalize to 1.0 for 100 turns
644
+ "hp_management_reward": max(0.0, shaped_rewards["hp_gain_total"]),
645
+ "hunger_management_reward": min(
646
+ 1.0, shaped_rewards["hunger_ok_total"] / (len(agent.history) or 1)
647
+ ),
648
+ # Exploration
649
+ "new_tiles_reward": min(
650
+ 1.0, shaped_rewards["new_tiles_total"] / 100.0
651
+ ), # Normalize to 100 tiles
652
+ "room_discovery_reward": min(1.0, shaped_rewards["rooms_explored_delta_total"] / 5.0),
653
+ "secret_discovery_reward": min(1.0, shaped_rewards["secret_doors_delta_total"] / 3.0),
654
+ "trap_discovery_reward": min(1.0, shaped_rewards["traps_identified_delta_total"] / 3.0),
655
+ # Combat
656
+ "combat_success_reward": min(1.0, shaped_rewards["monsters_killed_delta_total"] / 5.0),
657
+ "damage_dealt_reward": min(1.0, shaped_rewards["dmg_dealt_total"] / 50.0),
658
+ "damage_avoided_reward": max(0.0, 1.0 - shaped_rewards["dmg_taken_total"] / 50.0),
659
+ # Resources
660
+ "wealth_accumulation_reward": min(1.0, shaped_rewards["gold_delta_total"] / 100.0),
661
+ "item_collection_reward": min(1.0, shaped_rewards["items_picked_delta_total"] / 10.0),
662
+ "scroll_usage_reward": min(1.0, shaped_rewards["scrolls_read_delta_total"] / 3.0),
663
+ "potion_usage_reward": min(1.0, shaped_rewards["potions_quaffed_delta_total"] / 3.0),
664
+ "spell_usage_reward": min(1.0, shaped_rewards["spells_cast_delta_total"] / 3.0),
665
+ # Skill/Utility
666
+ "first_prayer_reward": 1.0 if shaped_rewards["first_prayer_achieved"] else 0.0,
667
+ "first_spell_reward": 1.0 if shaped_rewards["first_spell_achieved"] else 0.0,
668
+ "identification_reward": min(1.0, shaped_rewards["identify_item_total"] / 3.0),
669
+ # Achievements
670
+ "achievement_unlock_reward": min(
671
+ 1.0, shaped_rewards["achievement_unlocked_total"] / 10.0
672
+ ),
673
+ # Overall shaped reward
674
+ "total_intermediate_reward": shaped_rewards["total_intermediate_reward"],
675
+ "normalized_intermediate_reward": min(
676
+ 1.0, max(0.0, shaped_rewards["total_intermediate_reward"] / 20.0)
677
+ ),
678
+ }
679
+
680
+ # Remove or mark irrelevant rubric keys
681
+ irrelevant_rubric = {}
682
+ for k in list(rubric.keys()):
683
+ if k in IRRELEVANT_RUBRIC_KEYS:
684
+ irrelevant_rubric[k] = rubric.pop(k)
685
+
686
+ # Success determination
687
+ success = max_depth_reached >= target_depth or max_reward > 10.0 or balrog_score > 5.0
688
+
689
+ if success:
690
+ print(
691
+ f" āœ… Instance {instance_num}: SUCCESS! Depth {max_depth_reached}, Balrog score: {balrog_score:.0f}"
692
+ )
693
+ else:
694
+ print(
695
+ f" āŒ Instance {instance_num}: Partial progress - depth {max_depth_reached}/{target_depth}, Balrog score: {balrog_score:.0f}"
696
+ )
697
+
698
+ return {
699
+ "eval_metric": eval_metric,
700
+ "rubric": rubric,
701
+ "max_depth_reached": max_depth_reached,
702
+ "target_depth": target_depth,
703
+ "max_reward": max_reward,
704
+ "balrog_score": balrog_score,
705
+ "balrog_total_reward": balrog_total_reward,
706
+ "achievements_unlocked": achievements_unlocked,
707
+ "final_stats": final_stats,
708
+ "success": success,
709
+ "error": False,
710
+ # Additional progress metrics
711
+ "monsters_killed": monsters_killed,
712
+ "items_picked_up": items_picked_up,
713
+ "scrolls_read": scrolls_read,
714
+ "potions_drunk": potions_drunk,
715
+ "rooms_explored": rooms_explored,
716
+ "secret_doors_found": secret_doors_found,
717
+ "stairs_found": stairs_found,
718
+ "traps_encountered": traps_encountered,
719
+ "spells_cast": spells_cast,
720
+ "prayers_attempted": prayers_attempted,
721
+ "max_score": max_score,
722
+ # Shaped rewards
723
+ "shaped_rewards": shaped_rewards,
724
+ "irrelevant_rubric": irrelevant_rubric,
725
+ }
726
+
727
+ except Exception as e:
728
+ print(f" Instance {instance_num}: Error - {e}")
729
+ import traceback
730
+
731
+ traceback.print_exc()
732
+ return {"eval_metric": 0.0, "rubric": {}, "error": True}
733
+
734
+
735
+ # --- Batch Evaluation ---
736
+ async def evaluate_nethack_batch() -> Dict[str, Any]:
737
+ """Evaluate NetHack agent on multiple easy instances."""
738
+ print(f"šŸŽÆ Evaluating NetHack on {NUM_INSTANCES} {DIFFICULTY} instances...")
739
+
740
+ llm = LM(model_name=MODEL_NAME, formatting_model_name=MODEL_NAME, temperature=0.0)
741
+
742
+ # Get task instances using the taskset system
743
+ from synth_ai.environments.examples.nethack.taskset import create_nethack_taskset
744
+
745
+ taskset = await create_nethack_taskset()
746
+
747
+ # Filter for the desired difficulty
748
+ task_instances = [inst for inst in taskset.instances if inst.metadata.difficulty == DIFFICULTY][
749
+ :NUM_INSTANCES
750
+ ]
751
+
752
+ if len(task_instances) < NUM_INSTANCES:
753
+ print(f" āš ļø Only found {len(task_instances)} {DIFFICULTY} instances, using all available")
754
+
755
+ print(f" šŸ“ Using {len(task_instances)} {DIFFICULTY} task instances")
756
+
757
+ async with AsyncClient(
758
+ base_url=SERVICE_BASE_URL, timeout=60.0
759
+ ) as client: # Longer timeout for NetHack
760
+ tasks = []
761
+ bars = []
762
+ for i, task_instance in enumerate(task_instances):
763
+ bar = tqdm(total=MAX_TURNS, desc=f"Ep {i + 1}", position=i, leave=True)
764
+ bars.append(bar)
765
+ agent = NetHackReActAgent(llm, max_turns=MAX_TURNS, verbose=False)
766
+ tasks.append(run_single_episode(client, agent, task_instance, i + 1, bar))
767
+
768
+ results = await asyncio.gather(*tasks)
769
+
770
+ # Filter out error results
771
+ valid_results = [r for r in results if not r.get("error", False)]
772
+
773
+ if not valid_results:
774
+ return {
775
+ "eval_metrics": [],
776
+ "mean_eval_metric": 0.0,
777
+ "mean_rubric": {},
778
+ "num_episodes": 0,
779
+ }
780
+
781
+ # Extract eval metrics and rubrics
782
+ eval_metrics = [r["eval_metric"] for r in valid_results]
783
+ mean_eval_metric = sum(eval_metrics) / len(eval_metrics)
784
+
785
+ # Extract Balrog scores
786
+ balrog_scores = [r.get("balrog_score", 0.0) for r in valid_results]
787
+ mean_balrog_score = sum(balrog_scores) / len(balrog_scores) if balrog_scores else 0.0
788
+
789
+ # Extract Balrog total rewards
790
+ balrog_total_rewards = [r.get("balrog_total_reward", 0.0) for r in valid_results]
791
+ mean_balrog_total_reward = (
792
+ sum(balrog_total_rewards) / len(balrog_total_rewards) if balrog_total_rewards else 0.0
793
+ )
794
+
795
+ # Extract additional progress metrics
796
+ progress_metrics = {
797
+ "monsters_killed": [r.get("monsters_killed", 0) for r in valid_results],
798
+ "items_picked_up": [r.get("items_picked_up", 0) for r in valid_results],
799
+ "scrolls_read": [r.get("scrolls_read", 0) for r in valid_results],
800
+ "potions_drunk": [r.get("potions_drunk", 0) for r in valid_results],
801
+ "rooms_explored": [r.get("rooms_explored", 0) for r in valid_results],
802
+ "secret_doors_found": [r.get("secret_doors_found", 0) for r in valid_results],
803
+ "stairs_found": [r.get("stairs_found", 0) for r in valid_results],
804
+ "traps_encountered": [r.get("traps_encountered", 0) for r in valid_results],
805
+ "spells_cast": [r.get("spells_cast", 0) for r in valid_results],
806
+ "prayers_attempted": [r.get("prayers_attempted", 0) for r in valid_results],
807
+ "max_score": [r.get("max_score", 0) for r in valid_results],
808
+ }
809
+
810
+ # Calculate means for progress metrics
811
+ mean_progress_metrics = {}
812
+ for key, values in progress_metrics.items():
813
+ mean_progress_metrics[key] = sum(values) / len(values) if values else 0.0
814
+
815
+ # Extract shaped rewards
816
+ shaped_rewards_summary = {}
817
+ irrelevant_shaped_summary = {}
818
+ if valid_results and "shaped_rewards" in valid_results[0]:
819
+ shaped_reward_keys = valid_results[0]["shaped_rewards"].keys()
820
+ for key in shaped_reward_keys:
821
+ values = [r.get("shaped_rewards", {}).get(key, 0) for r in valid_results]
822
+ if isinstance(values[0], bool):
823
+ avg_value = sum(values) / len(values) # Fraction of episodes
824
+ else:
825
+ avg_value = sum(values) / len(values) if values else 0.0
826
+
827
+ if key in IRRELEVANT_RUBRIC_KEYS:
828
+ irrelevant_shaped_summary[key] = avg_value
829
+ else:
830
+ shaped_rewards_summary[key] = avg_value
831
+
832
+ # Calculate individual relevant shaped rewards sums
833
+ individual_relevant_sums = []
834
+ if valid_results and "shaped_rewards" in valid_results[0]:
835
+ for result in valid_results:
836
+ episode_shaped_rewards = result.get("shaped_rewards", {})
837
+ relevant_sum = sum(
838
+ v for k, v in episode_shaped_rewards.items() if k not in IRRELEVANT_RUBRIC_KEYS
839
+ )
840
+ individual_relevant_sums.append(relevant_sum)
841
+
842
+ # Calculate mean of relevant shaped rewards sums
843
+ relevant_shaped_rewards_sum = (
844
+ sum(individual_relevant_sums) / len(individual_relevant_sums)
845
+ if individual_relevant_sums
846
+ else 0.0
847
+ )
848
+
849
+ # Calculate individual relevant rubric sums
850
+ individual_relevant_rubric_sums = []
851
+ for result in valid_results:
852
+ episode_rubric = result.get("rubric", {})
853
+ relevant_rubric_sum = sum(
854
+ v for k, v in episode_rubric.items() if k not in IRRELEVANT_RUBRIC_KEYS
855
+ )
856
+ individual_relevant_rubric_sums.append(relevant_rubric_sum)
857
+
858
+ # Calculate mean of relevant rubric sums
859
+ relevant_rubric_sum = (
860
+ sum(individual_relevant_rubric_sums) / len(individual_relevant_rubric_sums)
861
+ if individual_relevant_rubric_sums
862
+ else 0.0
863
+ )
864
+
865
+ # Calculate mean rubric values (excluding irrelevant)
866
+ all_rubric_keys = set()
867
+ for r in valid_results:
868
+ all_rubric_keys.update(
869
+ [k for k in r["rubric"].keys() if k not in IRRELEVANT_RUBRIC_KEYS]
870
+ )
871
+
872
+ mean_rubric = {}
873
+ for key in all_rubric_keys:
874
+ values = [r["rubric"].get(key, 0.0) for r in valid_results]
875
+ mean_rubric[key] = sum(values) / len(values)
876
+
877
+ # Collect irrelevant rubric metrics summary
878
+ irrelevant_summary = {}
879
+ for key in IRRELEVANT_RUBRIC_KEYS:
880
+ vals = [r.get("irrelevant_rubric", {}).get(key, 0.0) for r in valid_results]
881
+ irrelevant_summary[key] = sum(vals) / len(vals) if vals else 0.0
882
+
883
+ return {
884
+ "eval_metrics": eval_metrics,
885
+ "mean_eval_metric": mean_eval_metric,
886
+ "balrog_scores": balrog_scores,
887
+ "mean_balrog_score": mean_balrog_score,
888
+ "balrog_total_rewards": balrog_total_rewards,
889
+ "mean_balrog_total_reward": mean_balrog_total_reward,
890
+ "mean_rubric": mean_rubric,
891
+ "progress_metrics": progress_metrics,
892
+ "mean_progress_metrics": mean_progress_metrics,
893
+ "shaped_rewards_summary": shaped_rewards_summary,
894
+ "irrelevant_summary": irrelevant_summary,
895
+ "irrelevant_shaped_summary": irrelevant_shaped_summary,
896
+ "relevant_shaped_rewards_sum": relevant_shaped_rewards_sum,
897
+ "individual_relevant_sums": individual_relevant_sums,
898
+ "individual_relevant_rubric_sums": individual_relevant_rubric_sums,
899
+ "relevant_rubric_sum": relevant_rubric_sum,
900
+ "num_episodes": len(valid_results),
901
+ }
902
+
903
+
904
+ async def main():
905
+ """Run NetHack evaluation."""
906
+ print(f"šŸŽ® NetHack ReAct Agent Evaluation")
907
+ print(f"Model: {MODEL_NAME}")
908
+ print(f"Service: {SERVICE_BASE_URL}")
909
+ print(f"Instances: {NUM_INSTANCES}")
910
+ print(f"Difficulty: {DIFFICULTY}")
911
+ print(f"Time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
912
+ print("=" * 50)
913
+
914
+ # Test service health
915
+ async with AsyncClient(base_url=SERVICE_BASE_URL, timeout=10.0) as client:
916
+ try:
917
+ health_resp = await client.get("/health")
918
+ health_data = health_resp.json()
919
+
920
+ if "NetHack" not in health_data.get("supported_environments", []):
921
+ print("āŒ NetHack not available on service")
922
+ return
923
+
924
+ print("āœ… Service health check passed")
925
+
926
+ except Exception as e:
927
+ print(f"āŒ Service health check failed: {e}")
928
+ return
929
+
930
+ # Run evaluation
931
+ try:
932
+ results = await evaluate_nethack_batch()
933
+
934
+ print("\n" + "=" * 80)
935
+ print("šŸ† FINAL NETHACK EVALUATION RESULTS")
936
+ print("=" * 80)
937
+
938
+ # Print eval metrics
939
+ print(f"šŸ“Š EVAL METRICS:")
940
+ print(f" Episodes: {results['num_episodes']}")
941
+ print(f" Individual Scores: {[f'{x:.2f}' for x in results['eval_metrics']]}")
942
+ print(f" Mean Eval Metric: {results['mean_eval_metric']:.2f}")
943
+
944
+ # Print Balrog scores
945
+ print(f"\nāš”ļø BALROG SCORES:")
946
+ print(f" Individual Scores: {[f'{x:.3f}' for x in results['balrog_scores']]}")
947
+ print(f" Mean Balrog Score: {results['mean_balrog_score']:.3f}")
948
+
949
+ # Print Balrog total rewards
950
+ print(f"\nšŸ† BALROG TOTAL REWARDS:")
951
+ print(f" Individual Rewards: {[f'{x:.2f}' for x in results['balrog_total_rewards']]}")
952
+ print(f" Mean Balrog Total Reward: {results['mean_balrog_total_reward']:.2f}")
953
+
954
+ # Print relevant sums
955
+ print(f"\nšŸ’Æ RELEVANT RUBRIC SUMS:")
956
+ print(
957
+ f" Individual Sums: {[f'{x:.3f}' for x in results.get('individual_relevant_rubric_sums', [])]}"
958
+ )
959
+ print(f" Mean Relevant Rubric Sum: {results.get('relevant_rubric_sum', 0.0):.3f}")
960
+
961
+ print(f"\nšŸ’Æ RELEVANT SHAPED REWARD SUMS:")
962
+ print(
963
+ f" Individual Sums: {[f'{x:.3f}' for x in results.get('individual_relevant_sums', [])]}"
964
+ )
965
+ print(
966
+ f" Mean Relevant Shaped Reward Sum: {results.get('relevant_shaped_rewards_sum', 0.0):.3f}"
967
+ )
968
+
969
+ # Print rubric results
970
+ print(f"\nšŸŽÆ RUBRIC RESULTS:")
971
+ if results["mean_rubric"]:
972
+ for achievement, score in sorted(results["mean_rubric"].items()):
973
+ print(f" {achievement}: {score:.2f}")
974
+ else:
975
+ print(" No rubric data available")
976
+
977
+ # Print progress metrics
978
+ print(f"\nšŸ“ˆ PROGRESS METRICS:")
979
+ if results["mean_progress_metrics"]:
980
+ for metric, value in sorted(results["mean_progress_metrics"].items()):
981
+ print(f" {metric}: {value:.1f}")
982
+ else:
983
+ print(" No progress data available")
984
+
985
+ # Print shaped rewards summary
986
+ print(f"\nšŸŽÆ SHAPED REWARDS SUMMARY:")
987
+ if results.get("shaped_rewards_summary"):
988
+ for reward_key, value in sorted(results["shaped_rewards_summary"].items()):
989
+ if isinstance(value, bool):
990
+ print(f" {reward_key}: {value}")
991
+ else:
992
+ print(f" {reward_key}: {value:.3f}")
993
+ else:
994
+ print(" No shaped rewards data available")
995
+
996
+ # Print irrelevant shaped rewards
997
+ print(f"\n🚫 IRRELEVANT SHAPED REWARDS:")
998
+ if results.get("irrelevant_shaped_summary"):
999
+ for reward_key, value in sorted(results["irrelevant_shaped_summary"].items()):
1000
+ print(f" {reward_key}: {value:.3f}")
1001
+ else:
1002
+ print(" None")
1003
+
1004
+ # Print irrelevant rubric metrics
1005
+ print(f"\n🚫 IRRELEVANT RUBRIC METRICS:")
1006
+ if results.get("irrelevant_summary"):
1007
+ for metric, value in sorted(results["irrelevant_summary"].items()):
1008
+ print(f" {metric}: {value:.2f}")
1009
+ else:
1010
+ print(" None")
1011
+
1012
+ # Overall assessment
1013
+ print(f"\nšŸ” ASSESSMENT:")
1014
+ balrog_score = results["mean_balrog_score"]
1015
+ eval_metric = results["mean_eval_metric"]
1016
+
1017
+ if eval_metric > 0.8 or balrog_score > 40.0:
1018
+ print("šŸŽ‰ Excellent performance - mastering the dungeon!")
1019
+ elif eval_metric > 0.6 or balrog_score > 20.0:
1020
+ print("āœ… Good performance - making solid progress!")
1021
+ elif eval_metric > 0.4 or balrog_score > 10.0:
1022
+ print("āš ļø Moderate performance - learning the ropes")
1023
+ elif balrog_score > 5.0:
1024
+ print("šŸ“ˆ Decent exploration - building dungeon skills")
1025
+ else:
1026
+ print("šŸƒ Early exploration - focus on basic survival and movement")
1027
+
1028
+ # Output markdown table row for README collation
1029
+ print(f"\nšŸ“‹ MARKDOWN TABLE ROW:")
1030
+ print(
1031
+ "| Model | Episodes | Mean Eval | Mean Balrog | Mean Relevant Rubric | Mean Relevant Shaped | Non-Zero Progress | Non-Zero Rubric | Assessment |"
1032
+ )
1033
+ print(
1034
+ "|------------------|----------|-----------|-------------|----------------------|----------------------|-------------------|-----------------|------------|"
1035
+ )
1036
+ relevant_rubric_sum = results.get("relevant_rubric_sum", 0.0)
1037
+ relevant_shaped_sum = results.get("relevant_shaped_rewards_sum", 0.0)
1038
+
1039
+ # Count non-zero progress metrics
1040
+ progress_metrics = results.get("mean_progress_metrics", {})
1041
+ non_zero_progress = sum(1 for value in progress_metrics.values() if value > 0.0)
1042
+
1043
+ # Count non-zero rubric results (excluding irrelevant ones)
1044
+ rubric_results = results.get("mean_rubric", {})
1045
+ non_zero_rubric = sum(
1046
+ 1
1047
+ for key, value in rubric_results.items()
1048
+ if value > 0.0 and key not in IRRELEVANT_RUBRIC_KEYS
1049
+ )
1050
+
1051
+ if eval_metric > 0.6 or balrog_score > 20.0:
1052
+ assessment = "Excellent"
1053
+ elif eval_metric > 0.4 or balrog_score > 10.0:
1054
+ assessment = "Good"
1055
+ elif balrog_score > 5.0:
1056
+ assessment = "Moderate"
1057
+ else:
1058
+ assessment = "Learning"
1059
+
1060
+ print(
1061
+ f"| {MODEL_NAME:<16} | {results['num_episodes']:>8} | {eval_metric:>9.3f} | {balrog_score:>11.3f} | {relevant_rubric_sum:>20.3f} | {relevant_shaped_sum:>20.3f} | {non_zero_progress:>17} | {non_zero_rubric:>15} | {assessment:<10} |"
1062
+ )
1063
+
1064
+ except Exception as e:
1065
+ print(f"āŒ Evaluation failed: {e}")
1066
+
1067
+
1068
+ # Metrics that are considered baseline / always-positive and should be treated as irrelevant when summarizing
1069
+ IRRELEVANT_RUBRIC_KEYS = {
1070
+ "survival_reward",
1071
+ "hunger_management_reward",
1072
+ "damage_avoided_reward",
1073
+ "stairs_discovery_reward",
1074
+ "turn_alive_total", # from shaped summary
1075
+ "hunger_ok_total", # from shaped summary
1076
+ }
1077
+
1078
+ # --- CLI Entry Point ---
1079
+ if __name__ == "__main__":
1080
+ import argparse
1081
+ import asyncio
1082
+
1083
+ parser = argparse.ArgumentParser(
1084
+ description="Run NetHack ReAct Agent Evaluation (TOML configurable)"
1085
+ )
1086
+ parser.add_argument("--config", "-c", type=str, help="Path to TOML configuration file")
1087
+ parser.add_argument("--model", "-m", type=str, help="Model name (overrides config)")
1088
+ parser.add_argument("--episodes", "-e", type=int, help="Number of episodes (overrides config)")
1089
+ parser.add_argument("--max-turns", "-t", type=int, help="Maximum turns (overrides config)")
1090
+ parser.add_argument("--difficulty", "-d", type=str, help="Difficulty (overrides config)")
1091
+
1092
+ args = parser.parse_args()
1093
+
1094
+ if args.config:
1095
+ config = NetHackConfig(args.config)
1096
+ else:
1097
+ config = NetHackConfig()
1098
+
1099
+ # Apply CLI overrides
1100
+ if args.model:
1101
+ config.model_name = args.model
1102
+ if args.episodes:
1103
+ config.num_instances = args.episodes
1104
+ if args.max_turns:
1105
+ config.max_turns = args.max_turns
1106
+ if args.difficulty:
1107
+ config.difficulty = args.difficulty
1108
+
1109
+ _apply_config_to_globals(config)
1110
+
1111
+ # Run the evaluation
1112
+ asyncio.run(main())
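
For reference, the NetHackConfig class in this diff reads its settings from the [eval], [service], and [output] tables of a TOML file passed via --config. Below is a minimal sketch, written in Python, of generating such a file; the file name is hypothetical, the values simply mirror the script's built-in defaults, and only keys the script actually reads are shown.

    # Sketch: produce a config file accepted by `test_nethack_react_agent.py --config`
    # (values mirror the script's defaults; the output path is illustrative)
    example_toml = """
    [eval]
    model_name = "gpt-4.1-mini"    # LLM used by the ReAct agent
    episodes = 2                   # number of NetHack instances to run
    max_steps = 40                 # maximum agent turns per episode
    difficulty = "beginner"
    seed = 42

    [service]
    base_url = "http://localhost:8901"   # synth environment service
    timeout = 30.0

    [output]
    save_traces = true
    save_detailed_results = true
    """

    with open("nethack_eval.toml", "w") as fh:   # hypothetical path
        fh.write(example_toml)

    # Then run:  python test_nethack_react_agent.py --config nethack_eval.toml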