synth-ai 0.1.9__py3-none-any.whl → 0.2.1.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (266)
  1. synth_ai/__init__.py +28 -2
  2. synth_ai/core/system.py +4 -0
  3. synth_ai/environments/__init__.py +35 -0
  4. synth_ai/environments/environment/__init__.py +1 -0
  5. synth_ai/environments/environment/artifacts/__init__.py +1 -0
  6. synth_ai/environments/environment/artifacts/base.py +50 -0
  7. synth_ai/environments/environment/core.py +22 -0
  8. synth_ai/environments/environment/db/__init__.py +1 -0
  9. synth_ai/environments/environment/db/sqlite.py +45 -0
  10. synth_ai/environments/environment/registry.py +24 -0
  11. synth_ai/environments/environment/resources/sqlite.py +46 -0
  12. synth_ai/environments/environment/results.py +1 -0
  13. synth_ai/environments/environment/rewards/__init__.py +1 -0
  14. synth_ai/environments/environment/rewards/core.py +28 -0
  15. synth_ai/environments/environment/shared_engine.py +26 -0
  16. synth_ai/environments/environment/tools/__init__.py +34 -0
  17. synth_ai/environments/examples/__init__.py +1 -0
  18. synth_ai/environments/examples/crafter_classic/__init__.py +8 -0
  19. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_comprehensive_evaluation.py +58 -0
  20. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_evaluation_browser.py +152 -0
  21. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_evaluation_framework.py +1194 -0
  22. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_quick_evaluation.py +51 -0
  23. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_react_agent.py +872 -0
  24. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_trace_evaluation.py +1412 -0
  25. synth_ai/environments/examples/crafter_classic/agent_demos/test_crafter_react_agent.py +1110 -0
  26. synth_ai/environments/examples/crafter_classic/config_logging.py +111 -0
  27. synth_ai/environments/examples/crafter_classic/engine.py +502 -0
  28. synth_ai/environments/examples/crafter_classic/engine_deterministic_patch.py +63 -0
  29. synth_ai/environments/examples/crafter_classic/engine_helpers/action_map.py +5 -0
  30. synth_ai/environments/examples/crafter_classic/engine_helpers/serialization.py +74 -0
  31. synth_ai/environments/examples/crafter_classic/environment.py +255 -0
  32. synth_ai/environments/examples/crafter_classic/taskset.py +228 -0
  33. synth_ai/environments/examples/enron/agent_demos/test_synth_react.py +535 -0
  34. synth_ai/environments/examples/enron/art_helpers/email_search_tools.py +156 -0
  35. synth_ai/environments/examples/enron/art_helpers/local_email_db.py +280 -0
  36. synth_ai/environments/examples/enron/art_helpers/types_enron.py +24 -0
  37. synth_ai/environments/examples/enron/engine.py +291 -0
  38. synth_ai/environments/examples/enron/environment.py +165 -0
  39. synth_ai/environments/examples/enron/taskset.py +112 -0
  40. synth_ai/environments/examples/enron/units/keyword_stats.py +111 -0
  41. synth_ai/environments/examples/enron/units/test_email_index.py +8 -0
  42. synth_ai/environments/examples/minigrid/__init__.py +48 -0
  43. synth_ai/environments/examples/minigrid/agent_demos/minigrid_evaluation_framework.py +1188 -0
  44. synth_ai/environments/examples/minigrid/agent_demos/minigrid_quick_evaluation.py +47 -0
  45. synth_ai/environments/examples/minigrid/agent_demos/minigrid_react_agent.py +562 -0
  46. synth_ai/environments/examples/minigrid/agent_demos/minigrid_trace_evaluation.py +220 -0
  47. synth_ai/environments/examples/minigrid/agent_demos/test_minigrid_react_agent.py +393 -0
  48. synth_ai/environments/examples/minigrid/engine.py +589 -0
  49. synth_ai/environments/examples/minigrid/environment.py +274 -0
  50. synth_ai/environments/examples/minigrid/environment_mapping.py +242 -0
  51. synth_ai/environments/examples/minigrid/puzzle_loader.py +416 -0
  52. synth_ai/environments/examples/minigrid/taskset.py +583 -0
  53. synth_ai/environments/examples/minigrid/units/test_action_behavior.py +226 -0
  54. synth_ai/environments/examples/minigrid/units/test_debug_messages.py +83 -0
  55. synth_ai/environments/examples/minigrid/units/test_exploration.py +120 -0
  56. synth_ai/environments/examples/minigrid/units/test_minigrid_engine.py +214 -0
  57. synth_ai/environments/examples/minigrid/units/test_minigrid_environment.py +238 -0
  58. synth_ai/environments/examples/minigrid/units/test_minigrid_environment_mapping.py +301 -0
  59. synth_ai/environments/examples/minigrid/units/test_minigrid_taskset.py +210 -0
  60. synth_ai/environments/examples/nethack/__init__.py +7 -0
  61. synth_ai/environments/examples/nethack/achievements.py +337 -0
  62. synth_ai/environments/examples/nethack/agent_demos/nethack_evaluation_framework.py +981 -0
  63. synth_ai/environments/examples/nethack/agent_demos/nethack_quick_evaluation.py +74 -0
  64. synth_ai/environments/examples/nethack/agent_demos/nethack_react_agent.py +832 -0
  65. synth_ai/environments/examples/nethack/agent_demos/test_nethack_react_agent.py +1112 -0
  66. synth_ai/environments/examples/nethack/engine.py +738 -0
  67. synth_ai/environments/examples/nethack/environment.py +255 -0
  68. synth_ai/environments/examples/nethack/helpers/__init__.py +42 -0
  69. synth_ai/environments/examples/nethack/helpers/action_mapping.py +301 -0
  70. synth_ai/environments/examples/nethack/helpers/nle_wrapper.py +401 -0
  71. synth_ai/environments/examples/nethack/helpers/observation_utils.py +433 -0
  72. synth_ai/environments/examples/nethack/helpers/recording_wrapper.py +201 -0
  73. synth_ai/environments/examples/nethack/helpers/trajectory_recorder.py +268 -0
  74. synth_ai/environments/examples/nethack/helpers/visualization/replay_viewer.py +308 -0
  75. synth_ai/environments/examples/nethack/helpers/visualization/visualizer.py +430 -0
  76. synth_ai/environments/examples/nethack/taskset.py +323 -0
  77. synth_ai/environments/examples/nethack/units/test_nethack_engine.py +277 -0
  78. synth_ai/environments/examples/nethack/units/test_nethack_environment.py +281 -0
  79. synth_ai/environments/examples/nethack/units/test_nethack_taskset.py +213 -0
  80. synth_ai/environments/examples/nethack/units/test_recording.py +307 -0
  81. synth_ai/environments/examples/red/__init__.py +7 -0
  82. synth_ai/environments/examples/red/agent_demos/__init__.py +1 -0
  83. synth_ai/environments/examples/red/agent_demos/test_synth_react.py +1471 -0
  84. synth_ai/environments/examples/red/config_logging.py +110 -0
  85. synth_ai/environments/examples/red/engine.py +693 -0
  86. synth_ai/environments/examples/red/engine_helpers/__init__.py +1 -0
  87. synth_ai/environments/examples/red/engine_helpers/memory_map.py +28 -0
  88. synth_ai/environments/examples/red/engine_helpers/reward_components.py +275 -0
  89. synth_ai/environments/examples/red/engine_helpers/reward_library/__init__.py +142 -0
  90. synth_ai/environments/examples/red/engine_helpers/reward_library/adaptive_rewards.py +56 -0
  91. synth_ai/environments/examples/red/engine_helpers/reward_library/battle_rewards.py +283 -0
  92. synth_ai/environments/examples/red/engine_helpers/reward_library/composite_rewards.py +149 -0
  93. synth_ai/environments/examples/red/engine_helpers/reward_library/economy_rewards.py +137 -0
  94. synth_ai/environments/examples/red/engine_helpers/reward_library/efficiency_rewards.py +56 -0
  95. synth_ai/environments/examples/red/engine_helpers/reward_library/exploration_rewards.py +330 -0
  96. synth_ai/environments/examples/red/engine_helpers/reward_library/novelty_rewards.py +120 -0
  97. synth_ai/environments/examples/red/engine_helpers/reward_library/pallet_town_rewards.py +558 -0
  98. synth_ai/environments/examples/red/engine_helpers/reward_library/pokemon_rewards.py +312 -0
  99. synth_ai/environments/examples/red/engine_helpers/reward_library/social_rewards.py +147 -0
  100. synth_ai/environments/examples/red/engine_helpers/reward_library/story_rewards.py +246 -0
  101. synth_ai/environments/examples/red/engine_helpers/screen_analysis.py +367 -0
  102. synth_ai/environments/examples/red/engine_helpers/state_extraction.py +139 -0
  103. synth_ai/environments/examples/red/environment.py +235 -0
  104. synth_ai/environments/examples/red/taskset.py +77 -0
  105. synth_ai/environments/examples/red/test_fixes.py +125 -0
  106. synth_ai/environments/examples/red/test_fixes_mock.py +148 -0
  107. synth_ai/environments/examples/red/units/__init__.py +1 -0
  108. synth_ai/environments/examples/red/units/test_basic_functionality.py +97 -0
  109. synth_ai/environments/examples/red/units/test_button_press_requirements.py +217 -0
  110. synth_ai/environments/examples/red/units/test_engine.py +192 -0
  111. synth_ai/environments/examples/red/units/test_environment.py +455 -0
  112. synth_ai/environments/examples/red/units/test_exploration_strategy.py +227 -0
  113. synth_ai/environments/examples/red/units/test_integration.py +217 -0
  114. synth_ai/environments/examples/red/units/test_memory_extraction.py +111 -0
  115. synth_ai/environments/examples/red/units/test_menu_bug_reproduction.py +1100 -0
  116. synth_ai/environments/examples/red/units/test_movement_debug.py +255 -0
  117. synth_ai/environments/examples/red/units/test_pokemon_mcts_debug.py +163 -0
  118. synth_ai/environments/examples/red/units/test_pokemon_mcts_verbose.py +117 -0
  119. synth_ai/environments/examples/red/units/test_red_basic.py +145 -0
  120. synth_ai/environments/examples/red/units/test_red_comprehensive.py +323 -0
  121. synth_ai/environments/examples/red/units/test_retry_movement.py +195 -0
  122. synth_ai/environments/examples/red/units/test_reward_components.py +186 -0
  123. synth_ai/environments/examples/red/units/test_rom_integration.py +260 -0
  124. synth_ai/environments/examples/red/units/test_taskset.py +116 -0
  125. synth_ai/environments/examples/red/units/test_tree.py +448 -0
  126. synth_ai/environments/examples/sokoban/__init__.py +1 -0
  127. synth_ai/environments/examples/sokoban/agent_demos/sokoban_full_eval.py +900 -0
  128. synth_ai/environments/examples/sokoban/agent_demos/test_dspy_react.py +1 -0
  129. synth_ai/environments/examples/sokoban/agent_demos/test_sokoban_react_agent.py +498 -0
  130. synth_ai/environments/examples/sokoban/agent_demos/test_synth_lats.py +1 -0
  131. synth_ai/environments/examples/sokoban/agent_demos/test_synth_react_locally.py +748 -0
  132. synth_ai/environments/examples/sokoban/agent_demos/test_synth_react_service.py +296 -0
  133. synth_ai/environments/examples/sokoban/engine.py +675 -0
  134. synth_ai/environments/examples/sokoban/engine_helpers/__init__.py +1 -0
  135. synth_ai/environments/examples/sokoban/engine_helpers/room_utils.py +656 -0
  136. synth_ai/environments/examples/sokoban/engine_helpers/vendored/__init__.py +17 -0
  137. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/__init__.py +3 -0
  138. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/boxoban_env.py +129 -0
  139. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/render_utils.py +370 -0
  140. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/room_utils.py +331 -0
  141. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env.py +305 -0
  142. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_fixed_targets.py +66 -0
  143. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_pull.py +114 -0
  144. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_two_player.py +122 -0
  145. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_variations.py +394 -0
  146. synth_ai/environments/examples/sokoban/environment.py +228 -0
  147. synth_ai/environments/examples/sokoban/generate_verified_puzzles.py +438 -0
  148. synth_ai/environments/examples/sokoban/puzzle_loader.py +311 -0
  149. synth_ai/environments/examples/sokoban/taskset.py +425 -0
  150. synth_ai/environments/examples/sokoban/units/astar_common.py +94 -0
  151. synth_ai/environments/examples/sokoban/units/test_building_task_set.py +49 -0
  152. synth_ai/environments/examples/sokoban/units/test_false_positive.py +120 -0
  153. synth_ai/environments/examples/sokoban/units/test_simple_run_through_environment.py +119 -0
  154. synth_ai/environments/examples/sokoban/units/test_sokoban_environment.py +98 -0
  155. synth_ai/environments/examples/sokoban/units/test_tree.py +364 -0
  156. synth_ai/environments/examples/tictactoe/__init__.py +1 -0
  157. synth_ai/environments/examples/tictactoe/agent_demos/test_synth_react.py +266 -0
  158. synth_ai/environments/examples/tictactoe/agent_demos/test_tictactoe_react_agent.py +470 -0
  159. synth_ai/environments/examples/tictactoe/engine.py +368 -0
  160. synth_ai/environments/examples/tictactoe/environment.py +239 -0
  161. synth_ai/environments/examples/tictactoe/taskset.py +214 -0
  162. synth_ai/environments/examples/tictactoe/units/test_tictactoe_engine.py +393 -0
  163. synth_ai/environments/examples/tictactoe/units/test_tictactoe_environment.py +493 -0
  164. synth_ai/environments/examples/tictactoe/units/test_tictactoe_taskset.py +191 -0
  165. synth_ai/environments/examples/verilog/__init__.py +10 -0
  166. synth_ai/environments/examples/verilog/agent_demos/test_synth_react.py +520 -0
  167. synth_ai/environments/examples/verilog/engine.py +328 -0
  168. synth_ai/environments/examples/verilog/environment.py +349 -0
  169. synth_ai/environments/examples/verilog/taskset.py +418 -0
  170. synth_ai/environments/examples/verilog/units/test_verilog_engine.py +466 -0
  171. synth_ai/environments/examples/verilog/units/test_verilog_environment.py +585 -0
  172. synth_ai/environments/examples/verilog/units/test_verilog_integration.py +383 -0
  173. synth_ai/environments/examples/verilog/units/test_verilog_taskset.py +457 -0
  174. synth_ai/environments/reproducibility/core.py +42 -0
  175. synth_ai/environments/reproducibility/tree.py +364 -0
  176. synth_ai/environments/service/app.py +78 -0
  177. synth_ai/environments/service/core_routes.py +775 -0
  178. synth_ai/environments/service/external_registry.py +57 -0
  179. synth_ai/environments/service/registry.py +9 -0
  180. synth_ai/environments/stateful/__init__.py +1 -0
  181. synth_ai/environments/stateful/core.py +28 -0
  182. synth_ai/environments/stateful/engine.py +21 -0
  183. synth_ai/environments/stateful/state.py +7 -0
  184. synth_ai/environments/tasks/api.py +19 -0
  185. synth_ai/environments/tasks/core.py +78 -0
  186. synth_ai/environments/tasks/filters.py +39 -0
  187. synth_ai/environments/tasks/utils.py +89 -0
  188. synth_ai/environments/v0_observability/history.py +3 -0
  189. synth_ai/environments/v0_observability/log.py +2 -0
  190. synth_ai/lm/caching/constants.py +1 -0
  191. synth_ai/{zyk/lms → lm}/caching/ephemeral.py +4 -8
  192. synth_ai/{zyk/lms → lm}/caching/handler.py +15 -15
  193. synth_ai/{zyk/lms → lm}/caching/initialize.py +2 -4
  194. synth_ai/{zyk/lms → lm}/caching/persistent.py +4 -10
  195. synth_ai/{zyk/lms → lm}/config.py +2 -1
  196. synth_ai/{zyk/lms → lm}/constants.py +2 -2
  197. synth_ai/{zyk/lms → lm}/core/all.py +10 -10
  198. synth_ai/{zyk/lms → lm}/core/main.py +57 -33
  199. synth_ai/{zyk/lms → lm}/core/vendor_clients.py +12 -10
  200. synth_ai/lm/cost/monitor.py +1 -0
  201. synth_ai/lm/cost/statefulness.py +1 -0
  202. synth_ai/lm/provider_support/__init__.py +8 -0
  203. synth_ai/lm/provider_support/anthropic.py +945 -0
  204. synth_ai/lm/provider_support/openai.py +1115 -0
  205. synth_ai/lm/provider_support/suppress_logging.py +31 -0
  206. synth_ai/{zyk/lms → lm}/structured_outputs/handler.py +58 -80
  207. synth_ai/{zyk/lms → lm}/structured_outputs/inject.py +6 -20
  208. synth_ai/{zyk/lms → lm}/structured_outputs/rehabilitate.py +6 -12
  209. synth_ai/{zyk/lms → lm}/vendors/core/anthropic_api.py +21 -30
  210. synth_ai/{zyk/lms → lm}/vendors/core/gemini_api.py +37 -32
  211. synth_ai/{zyk/lms → lm}/vendors/core/mistral_api.py +19 -28
  212. synth_ai/{zyk/lms → lm}/vendors/core/openai_api.py +26 -36
  213. synth_ai/{zyk/lms → lm}/vendors/openai_standard.py +29 -33
  214. synth_ai/{zyk/lms → lm}/vendors/retries.py +1 -1
  215. synth_ai/lm/vendors/supported/__init__.py +0 -0
  216. synth_ai/{zyk/lms → lm}/vendors/supported/custom_endpoint.py +131 -118
  217. synth_ai/{zyk/lms → lm}/vendors/supported/deepseek.py +4 -8
  218. synth_ai/{zyk/lms → lm}/vendors/supported/grok.py +6 -8
  219. synth_ai/{zyk/lms → lm}/vendors/supported/groq.py +1 -1
  220. synth_ai/{zyk/lms → lm}/vendors/supported/ollama.py +2 -2
  221. synth_ai/{zyk/lms → lm}/vendors/supported/openrouter.py +18 -16
  222. synth_ai/{zyk/lms → lm}/vendors/supported/together.py +1 -1
  223. synth_ai/tracing/__init__.py +0 -0
  224. synth_ai/tracing/abstractions.py +224 -0
  225. synth_ai/tracing/base_client.py +91 -0
  226. synth_ai/tracing/client_manager.py +131 -0
  227. synth_ai/tracing/config.py +140 -0
  228. synth_ai/tracing/context.py +146 -0
  229. synth_ai/tracing/decorators.py +679 -0
  230. synth_ai/tracing/events/__init__.py +0 -0
  231. synth_ai/tracing/events/manage.py +147 -0
  232. synth_ai/tracing/events/scope.py +86 -0
  233. synth_ai/tracing/events/store.py +227 -0
  234. synth_ai/tracing/immediate_client.py +152 -0
  235. synth_ai/tracing/local.py +18 -0
  236. synth_ai/tracing/log_client_base.py +74 -0
  237. synth_ai/tracing/retry_queue.py +187 -0
  238. synth_ai/tracing/trackers.py +515 -0
  239. synth_ai/tracing/upload.py +504 -0
  240. synth_ai/tracing/utils.py +9 -0
  241. synth_ai/zyk/__init__.py +28 -2
  242. synth_ai-0.2.1.dev0.dist-info/METADATA +349 -0
  243. synth_ai-0.2.1.dev0.dist-info/RECORD +261 -0
  244. synth_ai/zyk/lms/caching/constants.py +0 -1
  245. synth_ai/zyk/lms/cost/monitor.py +0 -1
  246. synth_ai/zyk/lms/cost/statefulness.py +0 -1
  247. synth_ai-0.1.9.dist-info/METADATA +0 -37
  248. synth_ai-0.1.9.dist-info/RECORD +0 -50
  249. /synth_ai/{zyk/lms/__init__.py → environments/reproducibility/helpers.py} +0 -0
  250. /synth_ai/{zyk/lms/caching → lm}/__init__.py +0 -0
  251. /synth_ai/{zyk/lms/core → lm/caching}/__init__.py +0 -0
  252. /synth_ai/{zyk/lms → lm}/caching/dbs.py +0 -0
  253. /synth_ai/{zyk/lms/cost → lm/core}/__init__.py +0 -0
  254. /synth_ai/{zyk/lms → lm}/core/exceptions.py +0 -0
  255. /synth_ai/{zyk/lms/structured_outputs → lm/cost}/__init__.py +0 -0
  256. /synth_ai/{zyk/lms/vendors → lm/structured_outputs}/__init__.py +0 -0
  257. /synth_ai/{zyk/lms → lm}/tools/__init__.py +0 -0
  258. /synth_ai/{zyk/lms → lm}/tools/base.py +0 -0
  259. /synth_ai/{zyk/lms/vendors/core → lm/vendors}/__init__.py +0 -0
  260. /synth_ai/{zyk/lms → lm}/vendors/base.py +0 -0
  261. /synth_ai/{zyk/lms/vendors/local → lm/vendors/core}/__init__.py +0 -0
  262. /synth_ai/{zyk/lms/vendors/supported → lm/vendors/local}/__init__.py +0 -0
  263. /synth_ai/{zyk/lms → lm}/vendors/local/ollama.py +0 -0
  264. {synth_ai-0.1.9.dist-info → synth_ai-0.2.1.dev0.dist-info}/WHEEL +0 -0
  265. {synth_ai-0.1.9.dist-info → synth_ai-0.2.1.dev0.dist-info}/licenses/LICENSE +0 -0
  266. {synth_ai-0.1.9.dist-info → synth_ai-0.2.1.dev0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,738 @@
1
+ """NetHack engine implementation with state management and NLE integration."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import asyncio
6
+ import base64
7
+ from dataclasses import dataclass, field
8
+ from typing import Dict, Any, Optional, Tuple, List, TYPE_CHECKING, cast
9
+ import numpy as np
10
+ import logging
11
+
12
+ from synth_ai.environments.stateful.engine import StatefulEngine, StatefulEngineSnapshot
13
+ from synth_ai.environments.reproducibility.core import IReproducibleEngine
14
+ from synth_ai.environments.environment.rewards.core import RewardStack, RewardComponent
15
+ from synth_ai.environments.environment.shared_engine import (
16
+ GetObservationCallable,
17
+ InternalObservation,
18
+ )
19
+ from synth_ai.environments.tasks.core import TaskInstance
20
+
21
+ logger = logging.getLogger(__name__)
22
+
23
+ # NLE imports are required
24
+ try:
25
+ from .helpers.nle_wrapper import NLEWrapper
26
+ from .helpers.action_mapping import convert_action_to_nle
27
+ from .achievements import NetHackAchievements, calculate_balrog_reward
28
+ except ImportError as e:
29
+ raise ImportError(
30
+ "NLE (NetHack Learning Environment) is required but not installed. "
31
+ "Please install it with: pip install nle"
32
+ ) from e
33
+
34
+ if TYPE_CHECKING:
35
+ from .taskset import NetHackTaskInstanceMetadata
36
+
37
+
38
@dataclass
class NetHackPublicState:
    """Agent-visible snapshot of a NetHack game.

    Groups the game-facing data (dungeon level, stats, inventory, position),
    the raw observation text (ASCII map, message line, cursor), episode
    bookkeeping (turn counter, turn limit, last action, termination flag),
    menu context, and achievement tracking.
    """

    # --- Game state ---
    dungeon_level: int = 1
    character_stats: Dict[str, Any] = field(default_factory=dict)
    inventory: List[Dict[str, Any]] = field(default_factory=list)
    position: Tuple[int, int] = (0, 0)

    # --- Observation data ---
    ascii_map: str = ""
    message: str = ""
    cursor_position: Tuple[int, int] = (0, 0)

    # --- Episode bookkeeping ---
    turn_count: int = 0
    max_turns: int = 10000
    last_action: str = ""
    terminated: bool = False

    # --- Menu context ---
    in_menu: bool = False
    menu_items: List[str] = field(default_factory=list)

    # --- Achievements ---
    achievements: NetHackAchievements = field(default_factory=NetHackAchievements)
    achievements_unlocked: Dict[str, bool] = field(default_factory=dict)

    def diff(self, prev_state: "NetHackPublicState") -> Dict[str, Any]:
        """Return ``{field_name: (old_value, new_value)}`` for changed tracked fields.

        Only dungeon_level, position, message, turn_count, terminated and
        last_action are compared (in that order). All other fields are
        ignored.
        """
        tracked = (
            "dungeon_level",
            "position",
            "message",
            "turn_count",
            "terminated",
            "last_action",
        )
        changes: Dict[str, Any] = {}
        for name in tracked:
            before = getattr(prev_state, name)
            after = getattr(self, name)
            if after != before:
                changes[name] = (before, after)
        return changes

    @property
    def map_text(self) -> str:
        """The ASCII dungeon map (same value as ``ascii_map``)."""
        return self.ascii_map
93
+
94
+
95
@dataclass
class NetHackPrivateState:
    """Internal (not agent-visible) state: rewards, termination flags, progress."""

    reward_last: float = 0.0
    total_reward: float = 0.0
    terminated: bool = False
    truncated: bool = False

    # Progress tracking
    score: int = 0
    depth_reached: int = 1
    experience_level: int = 1
    monsters_killed: int = 0
    items_collected: int = 0

    # Balrog reward tracking
    balrog_reward_last: float = 0.0
    balrog_total_reward: float = 0.0

    def diff(self, prev_state: "NetHackPrivateState") -> Dict[str, Any]:
        """Return ``{field_name: (old_value, new_value)}`` for every changed field.

        Every dataclass field is compared, in declaration order. Changes to
        termination/truncation flags, experience level, kill and item
        counters, and Balrog rewards are therefore reported alongside reward,
        score and depth changes. An empty dict means nothing changed.

        Args:
            prev_state: The state to compare against (the "old" side).
        """
        # Local import: the module only imports ``dataclass`` and ``field``.
        from dataclasses import fields

        differences: Dict[str, Any] = {}
        for f in fields(self):
            old = getattr(prev_state, f.name)
            new = getattr(self, f.name)
            if new != old:
                differences[f.name] = (old, new)
        return differences
132
+
133
+
134
@dataclass
class NetHackEngineSnapshot(StatefulEngineSnapshot):
    """Serialization container for a NetHack engine's state."""

    # Task instance data, as a plain dict.
    task_instance_dict: Dict[str, Any]
    # Engine state, as a plain dict.
    engine_snapshot: Dict[str, Any]
    # Optional NLE-specific state; defaults to None when not provided.
    nle_state: Dict[str, Any] | None = None
141
+
142
+
143
class NetHackSurvivalComponent(RewardComponent):
    """Per-step survival reward: a small bonus each step, a penalty on termination."""

    async def score(self, state: NetHackPublicState, action: str) -> float:
        """Return -1.0 when ``state.terminated`` is set, otherwise 0.01.

        ``action`` is accepted for interface compatibility and is not used.
        """
        return -1.0 if state.terminated else 0.01
150
+
151
+
152
class NetHackProgressComponent(RewardComponent):
    """Rewards descending below the deepest dungeon level rewarded so far."""

    def __init__(self):
        # High-water mark: the deepest level already rewarded.
        self.last_depth = 1

    async def score(self, state: NetHackPublicState, action: str) -> float:
        """Return 1.0 per level beyond ``last_depth``, then raise the mark.

        Returns 0.0 when the current level is not deeper than ``last_depth``.
        Moving back up never lowers the mark.
        """
        levels_gained = state.dungeon_level - self.last_depth
        if levels_gained <= 0:
            return 0.0
        self.last_depth = state.dungeon_level
        return 1.0 * levels_gained
167
+
168
+
169
class NetHackScoreComponent(RewardComponent):
    """Rewards increases in the in-game score, scaled down by a factor of 100."""

    def __init__(self):
        # Score seen on the previous call; the baseline for the next delta.
        self.last_score = 0

    async def score(self, state: NetHackPublicState, action: str) -> float:
        """Return ``(current - last_score) / 100`` for gains, else 0.0.

        Reads ``state.character_stats["score"]`` and raises KeyError if it is
        missing. The baseline is updated on every call, so a drop in score
        is not penalised but does lower the baseline.
        """
        current = state.character_stats["score"]
        gained = current - self.last_score
        self.last_score = current

        # NLE scores can be large, so the raw delta is scaled down.
        if gained > 0:
            return gained / 100.0
        return 0.0
185
+
186
+
187
class NetHackAchievementComponent(RewardComponent):
    """Rewards achievements that appear as unlocked since the previous call."""

    # Ordered (required substrings, reward) rules. The first rule whose
    # substrings all occur in the achievement name wins.
    _RULES = (
        (("first_",), 1.0),  # first-time achievements
        (("reached_dlvl_",), 2.0),  # depth achievements
        (("killed_", "monsters"), 0.5),  # kill milestones
        (("collected_", "gold"), 0.5),  # gold milestones
        (("reached_level_",), 1.5),  # experience level milestones
        (("minetown",), 5.0),  # special locations
        (("castle",), 5.0),  # special locations
        (("quest",), 10.0),  # quest achievements
    )
    # Reward for names that match no rule.
    _DEFAULT_REWARD = 0.5

    def __init__(self):
        # Names that were unlocked as of the previous call.
        self.last_unlocked = set()

    async def score(self, state: NetHackPublicState, action: str) -> float:
        """Sum the rule-based rewards of achievements newly marked unlocked.

        An achievement counts as unlocked when its value in
        ``state.achievements_unlocked`` is truthy. The unlocked set is then
        stored as the baseline for the next call.
        """
        unlocked_now = {name for name, done in state.achievements_unlocked.items() if done}
        fresh = unlocked_now - self.last_unlocked

        total = 0.0
        for name in fresh:
            total += self._reward_for(name)

        self.last_unlocked = unlocked_now
        return total

    @classmethod
    def _reward_for(cls, name: str) -> float:
        """Return the reward of the first rule matching ``name``, or the default."""
        for needles, value in cls._RULES:
            if all(needle in name for needle in needles):
                return value
        return cls._DEFAULT_REWARD
221
+
222
+
223
+ class NetHackEngine(StatefulEngine, IReproducibleEngine):
224
+ """NetHack game engine with NLE backend."""
225
+
226
def __init__(self, task_instance: TaskInstance):
    """Create an engine for ``task_instance``.

    Builds the NLE wrapper for the task's character role and the reward
    stack. Game state stays ``None`` until ``_reset_engine`` runs.

    Raises:
        TypeError: if ``task_instance.metadata`` is not a
            ``NetHackTaskInstanceMetadata``.
    """
    self.task_instance = task_instance

    # Runtime import: at module level this name is only imported under TYPE_CHECKING.
    from .taskset import NetHackTaskInstanceMetadata

    meta = task_instance.metadata
    if not isinstance(meta, NetHackTaskInstanceMetadata):
        raise TypeError(
            f"Expected NetHackTaskInstanceMetadata, got {type(meta).__name__}"
        )

    self.character_role = meta.character_role
    self.max_turns = meta.time_limit

    # NLE backend configured for the chosen role.
    self.nle = NLEWrapper(character_role=self.character_role)

    # Reward shaping: depth progress, score gains, achievement unlocks.
    # No per-turn survival component is included.
    self.progress_component = NetHackProgressComponent()
    self.score_component = NetHackScoreComponent()
    self.achievement_component = NetHackAchievementComponent()
    self.reward_stack = RewardStack(
        [self.progress_component, self.score_component, self.achievement_component]
    )

    # Populated by _reset_engine.
    self.public_state: Optional[NetHackPublicState] = None
    self.private_state: Optional[NetHackPrivateState] = None

    # Most recent raw observation returned by the NLE wrapper.
    self.last_nle_obs = None
263
+
264
async def _reset_engine(
    self, *, seed: int | None = None
) -> Tuple[NetHackPrivateState, NetHackPublicState]:
    """Start a new episode and rebuild private/public state from the NLE observation.

    Args:
        seed: Forwarded to ``NLEWrapper.reset``.

    Returns:
        ``(private_state, public_state)`` for the fresh episode.

    Raises:
        KeyError: if the observation lacks ``player_stats``, any stat read
            below, ``ascii_map`` or ``message`` (no fallbacks for these).
    """
    # The wrapper's reset is synchronous; run it in a worker thread so the
    # event loop is not blocked.
    obs = await asyncio.to_thread(self.nle.reset, seed)
    self.last_nle_obs = obs

    # Log what NLE actually returned (lazy %-formatting; same rendered text).
    logger.info("NLE reset returned observation keys: %s", list(obs.keys()))
    if "player_stats" in obs:
        logger.info("Player stats keys: %s", list(obs["player_stats"].keys()))

    # Required: raises KeyError if missing.
    stats = obs["player_stats"]

    self.private_state = NetHackPrivateState(
        reward_last=0.0,
        total_reward=0.0,
        terminated=False,
        truncated=False,
        score=stats["score"],
        depth_reached=stats["depth"],
        experience_level=stats["experience_level"],
        monsters_killed=0,
        items_collected=0,
        balrog_reward_last=0.0,
        balrog_total_reward=0.0,
    )

    self.public_state = NetHackPublicState(
        dungeon_level=stats["depth"],
        character_stats={
            "hp": stats["hp"],
            "max_hp": stats["max_hp"],
            "strength": stats["strength"],
            "dexterity": stats["dexterity"],
            "constitution": stats["constitution"],
            "intelligence": stats["intelligence"],
            "wisdom": stats["wisdom"],
            "charisma": stats["charisma"],
            "gold": stats["gold"],
            "experience": stats["experience_points"],
            "level": stats["experience_level"],
            "ac": stats["ac"],
            # NetHackScoreComponent requires character_stats["score"], so the
            # initial public state carries it as well.
            "score": stats["score"],
        },
        inventory=self._process_inventory(obs["inventory"]) if "inventory" in obs else [],
        position=(stats["y"], stats["x"]),
        ascii_map=obs["ascii_map"],
        message=obs["message"],
        # The processed observation may omit the cursor; fall back to the player position.
        cursor_position=obs.get("cursor", (stats["y"], stats["x"])),
        turn_count=0,
        max_turns=self.max_turns,
        last_action="",
        terminated=False,
        in_menu=obs.get("in_menu", False),  # menu detection is heuristic-based
        menu_items=obs.get("menu_text", []),  # only present while a menu is open
        achievements=NetHackAchievements(),
        achievements_unlocked={},
    )

    # Re-seed every stateful reward component for the new episode.
    self.progress_component.last_depth = self.public_state.dungeon_level
    self.score_component.last_score = self.private_state.score
    # Without this, achievements unlocked in an earlier episode remain in the
    # baseline set and are not rewarded when re-earned in this episode.
    self.achievement_component.last_unlocked = set()
    # The consecutive free-action counter used by _step_engine is per-episode.
    self._consecutive_free_actions = 0

    return self.private_state, self.public_state
332
+
333
def _process_inventory(self, inventory_items: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Convert NLE inventory entries into this engine's item dicts.

    Each entry must provide ``"description"`` and ``"letter"`` (KeyError
    otherwise). Every output dict has ``name``, ``count`` and ``letter``;
    ``count`` is always 1.
    """
    # NLE doesn't always provide a count, so each entry reports 1.
    return [
        {"name": entry["description"], "count": 1, "letter": entry["letter"]}
        for entry in inventory_items
    ]
345
+
346
async def _step_engine(self, action: str) -> Tuple[NetHackPrivateState, NetHackPublicState]:
    """Execute one action in NLE and fold the result into the engine state.

    Args:
        action: A key of ``self.nle.action_map``, the special ``"terminate"``
            action, or a single alphanumeric character (passed through as a
            menu selection).

    Returns:
        The updated ``(private_state, public_state)`` pair.

    Raises:
        RuntimeError: If the engine has not been reset yet.
        ValueError: If ``action`` is not a recognised action.

    Notes:
        ``turn_count`` counts every agent call, including free actions and
        ``"terminate"``. At most ``max_turns`` actions are sent to NLE. The
        action that reaches the limit is executed, and then the episode is
        marked truncated. Calls made after the limit never advance NLE.
    """
    if self.public_state is None or self.private_state is None:
        raise RuntimeError("Engine not initialized. Call _reset_engine first.")

    # Validate action. Unknown single alphanumeric characters are let through
    # because they are likely menu selections.
    if action not in self.nle.action_map and action != "terminate":
        is_menu_selection = len(action) == 1 and (action.isalpha() or action.isdigit())
        if not is_menu_selection:
            raise ValueError(
                f"Invalid action: {action}. Valid actions: {list(self.nle.action_map.keys())}"
            )

    self.public_state.turn_count += 1

    # Actions that don't advance game time. Repeating them lets an agent stall,
    # so after 3 in a row a 'wait' is substituted.
    non_turn_actions = (
        "look",
        "farlook",
        "whatis",
        "identify",
        "discoveries",
        "conduct",
        "attributes",
        "help",
        "version",
        "history",
    )
    if action in non_turn_actions:
        logger.warning(f"Action '{action}' is a free action that doesn't advance game time!")
        self._consecutive_free_actions = getattr(self, "_consecutive_free_actions", 0) + 1
        if self._consecutive_free_actions >= 3:
            logger.warning(
                f"Too many consecutive free actions ({self._consecutive_free_actions}), forcing 'wait'"
            )
            action = "wait"
            self._consecutive_free_actions = 0
    else:
        self._consecutive_free_actions = 0

    # Record the action that will actually be executed (possibly the forced 'wait').
    self.public_state.last_action = action

    # Manual termination: no NLE step happens, so this step earns no reward
    # (don't leave the previous step's reward in place).
    if action == "terminate":
        self.public_state.terminated = True
        self.private_state.terminated = True
        self.public_state.message = "Game terminated by agent."
        self.private_state.reward_last = 0.0
        self.private_state.balrog_reward_last = 0.0
        return self.private_state, self.public_state

    # The turn budget was already used up by earlier steps: do not advance NLE.
    if self.public_state.turn_count > self.public_state.max_turns:
        self.public_state.terminated = True
        self.private_state.terminated = True
        self.private_state.truncated = True
        self.public_state.message = "Time limit reached. Game over!"
        self.private_state.reward_last = 0.0
        self.private_state.balrog_reward_last = 0.0
        return self.private_state, self.public_state

    # Save the previous observation BEFORE stepping: achievements and the
    # Balrog reward are computed from the (current, previous) pair.
    prev_obs = self.last_nle_obs
    try:
        obs, reward, done, info = await asyncio.to_thread(self.nle.step, action)
    except Exception as e:
        logger.error(f"NLE step failed for action '{action}': {e}")
        raise
    logger.debug(f"NLE step returned - reward: {reward}, done: {done}, info: {info}")

    # Log observation structure on the first few steps for debugging.
    if self.public_state.turn_count < 3:
        logger.info(f"Turn {self.public_state.turn_count} observation keys: {list(obs.keys())}")

    # Deliberately strict: a missing key raises KeyError rather than being papered over.
    player_stats = obs["player_stats"]

    # Private state
    self.private_state.score = player_stats["score"]
    self.private_state.depth_reached = max(
        self.private_state.depth_reached, player_stats["depth"]
    )
    self.private_state.experience_level = player_stats["experience_level"]

    # Public state
    self.public_state.dungeon_level = player_stats["depth"]
    self.public_state.position = (player_stats["y"], player_stats["x"])
    self.public_state.ascii_map = obs["ascii_map"]
    self.public_state.message = obs["message"]
    # The cursor might not be in the processed obs; fall back to the player position.
    self.public_state.cursor_position = obs.get(
        "cursor", (player_stats["y"], player_stats["x"])
    )
    self.public_state.in_menu = obs.get("in_menu", False)
    self.public_state.menu_items = obs.get("menu_text", [])

    self.public_state.character_stats = {
        "hp": player_stats["hp"],
        "max_hp": player_stats["max_hp"],
        "strength": player_stats["strength"],
        "dexterity": player_stats["dexterity"],
        "constitution": player_stats["constitution"],
        "intelligence": player_stats["intelligence"],
        "wisdom": player_stats["wisdom"],
        "charisma": player_stats["charisma"],
        "gold": player_stats["gold"],
        "experience": player_stats["experience_points"],
        "level": player_stats["experience_level"],
        "ac": player_stats["ac"],
        "score": player_stats["score"],
    }

    self.public_state.inventory = (
        self._process_inventory(obs["inventory"]) if "inventory" in obs else []
    )

    # Termination reported by NLE.
    if done:
        self.public_state.terminated = True
        self.private_state.terminated = True
        logger.info(f"Game ended - info: {info}")
        # NLE's StepStatus enum is ABORTED=-1, RUNNING=0, DEATH=1, so death is 1.
        # NOTE(review): assumes self.nle forwards NLE's info["end_status"]
        # unchanged; confirm against the wrapper.
        if info.get("end_status") == 1:
            # death_reason might not always exist
            self.public_state.message = info.get("death_reason", "You died!")
        else:
            self.public_state.message = "Game ended."

    # Update achievements before calculating rewards.
    newly_unlocked = self.public_state.achievements.update_from_observation(obs, prev_obs)
    self.public_state.achievements_unlocked.update(
        self.public_state.achievements.get_unlocked_achievements()
    )
    if newly_unlocked:
        logger.info(f"Achievements unlocked: {list(newly_unlocked.keys())}")

    # Reward = NLE base reward + shaping from the reward stack.
    step_reward = await self.reward_stack.step_reward(self.public_state, action)
    self.private_state.reward_last = reward + step_reward
    self.private_state.total_reward += self.private_state.reward_last

    # Balrog-style reward
    self.private_state.balrog_reward_last = calculate_balrog_reward(obs, prev_obs)
    self.private_state.balrog_total_reward += self.private_state.balrog_reward_last

    if self.private_state.balrog_reward_last > 0:
        progress = self.public_state.achievements.balrog_progress
        logger.info(
            f"🏆 BALROG REWARD: +{self.private_state.balrog_reward_last:.3f} (total: {self.private_state.balrog_total_reward:.3f})"
        )
        logger.info(
            f"   Balrog score: {progress.percent}% (dungeon: {progress.dungeon_progression}, exp: {progress.experience_progression})"
        )

    # Only now does this observation become the "previous" one for the next step.
    self.last_nle_obs = obs

    # The action that reached the turn limit has been executed; truncate unless
    # NLE already ended the game.
    if not done and self.public_state.turn_count >= self.public_state.max_turns:
        self.public_state.terminated = True
        self.private_state.terminated = True
        self.private_state.truncated = True
        self.public_state.message = "Time limit reached. Game over!"

    return self.private_state, self.public_state
526
+
527
+ def __del__(self):
528
+ """Cleanup NLE environment on deletion."""
529
+ if hasattr(self, "nle"):
530
+ self.nle.close()
531
+
532
async def _serialize_engine(self) -> NetHackEngineSnapshot:
    """Snapshot the engine's public/private state and, if possible, the NLE state.

    The NLE state is captured best-effort. If ``get_state`` fails, a warning is
    logged and the snapshot carries ``nle_state=None``.

    Raises:
        RuntimeError: If the engine has not been reset yet.
    """
    if self.public_state is None or self.private_state is None:
        raise RuntimeError("Cannot serialize uninitialized engine")

    nle_state = None
    try:
        nle_state_bytes = await asyncio.to_thread(self.nle.get_state)
        # Base64 so the raw bytes can travel inside a JSON-friendly snapshot.
        nle_state = base64.b64encode(nle_state_bytes).decode("ascii")
    except Exception as e:
        logger.warning(f"Failed to serialize NLE state: {e}")

    task_dict = await self.task_instance.serialize()
    logger.debug(f"Serialized task instance: {task_dict}")

    pub = self.public_state
    priv = self.private_state
    return NetHackEngineSnapshot(
        task_instance_dict=task_dict,
        engine_snapshot={
            "public_state": {
                "dungeon_level": pub.dungeon_level,
                "character_stats": pub.character_stats,
                "inventory": pub.inventory,
                "position": pub.position,
                "ascii_map": pub.ascii_map,
                "message": pub.message,
                "cursor_position": pub.cursor_position,
                "turn_count": pub.turn_count,
                "max_turns": pub.max_turns,
                "last_action": pub.last_action,
                "terminated": pub.terminated,
                "in_menu": pub.in_menu,
                "menu_items": pub.menu_items,
                # Copy: the engine updates this dict in place on every step.
                # NOTE(review): the NetHackAchievements tracker object itself is
                # not serialized; only the unlocked-achievement record is.
                "achievements_unlocked": dict(pub.achievements_unlocked),
            },
            "private_state": {
                "reward_last": priv.reward_last,
                "total_reward": priv.total_reward,
                "balrog_reward_last": priv.balrog_reward_last,
                "balrog_total_reward": priv.balrog_total_reward,
                "terminated": priv.terminated,
                "truncated": priv.truncated,
                "score": priv.score,
                "depth_reached": priv.depth_reached,
                "experience_level": priv.experience_level,
                "monsters_killed": priv.monsters_killed,
                "items_collected": priv.items_collected,
            },
            "character_role": self.character_role,
            "progress_last_depth": self.progress_component.last_depth,
            "score_last_score": self.score_component.last_score,
        },
        nle_state=nle_state,
    )
584
+
585
@classmethod
async def _deserialize_engine(cls, snapshot: NetHackEngineSnapshot) -> "NetHackEngine":
    """Rebuild an engine from a snapshot produced by ``_serialize_engine``.

    Fields that older snapshots may lack (``achievements_unlocked`` and the
    Balrog reward values) are read with defaults, so those snapshots still load.
    If the stored NLE state cannot be restored, NLE is reset instead.

    Raises:
        ValueError: If the task instance cannot be deserialized.
    """
    from .taskset import NetHackTaskInstance

    task_instance = await NetHackTaskInstance.deserialize(snapshot.task_instance_dict)
    if task_instance is None:
        raise ValueError("Failed to deserialize task instance")
    engine = cls(task_instance)

    engine_data = snapshot.engine_snapshot
    pub_data = engine_data["public_state"]
    priv_data = engine_data["private_state"]

    engine.public_state = NetHackPublicState(
        dungeon_level=pub_data["dungeon_level"],
        character_stats=pub_data["character_stats"],
        inventory=pub_data["inventory"],
        # Rebuild tuples: a JSON round-trip may have turned them into lists.
        position=(pub_data["position"][0], pub_data["position"][1]),
        ascii_map=pub_data["ascii_map"],
        message=pub_data["message"],
        cursor_position=(
            pub_data["cursor_position"][0],
            pub_data["cursor_position"][1],
        ),
        turn_count=pub_data["turn_count"],
        max_turns=pub_data["max_turns"],
        last_action=pub_data["last_action"],
        terminated=pub_data["terminated"],
        in_menu=pub_data["in_menu"],
        menu_items=pub_data["menu_items"],
        # The tracker's internal counters are not in the snapshot, so start a
        # fresh tracker (as the reset path does) but keep what was unlocked.
        achievements=NetHackAchievements(),
        achievements_unlocked=dict(pub_data.get("achievements_unlocked", {})),
    )

    engine.private_state = NetHackPrivateState(
        reward_last=priv_data["reward_last"],
        total_reward=priv_data["total_reward"],
        terminated=priv_data["terminated"],
        truncated=priv_data["truncated"],
        score=priv_data["score"],
        depth_reached=priv_data["depth_reached"],
        experience_level=priv_data["experience_level"],
        monsters_killed=priv_data["monsters_killed"],
        items_collected=priv_data["items_collected"],
    )
    # Assigned after construction. Default 0.0 for snapshots that predate these
    # fields (presumably the fresh-episode value; TODO confirm against the reset path).
    engine.private_state.balrog_reward_last = priv_data.get("balrog_reward_last", 0.0)
    engine.private_state.balrog_total_reward = priv_data.get("balrog_total_reward", 0.0)

    engine.character_role = engine_data["character_role"]

    # Restore reward component states
    engine.progress_component.last_depth = engine_data["progress_last_depth"]
    engine.score_component.last_score = engine_data["score_last_score"]

    # Restore NLE state if available
    if snapshot.nle_state:
        try:
            nle_state_bytes = base64.b64decode(snapshot.nle_state)
            await asyncio.to_thread(engine.nle.set_state, nle_state_bytes)
        except Exception as e:
            logger.warning(f"Failed to restore NLE state: {e}")
            # Fall back to a fresh NLE game. NOTE(review): the restored
            # public/private state above then no longer matches the NLE game.
            await asyncio.to_thread(engine.nle.reset)

    return engine
648
+
649
def get_current_states_for_observation(
    self,
) -> "Tuple[NetHackPrivateState, NetHackPublicState]":
    """Return ``(private_state, public_state)`` as they are, without stepping the game.

    Raises:
        RuntimeError: If either state is still ``None`` (the engine was never reset).
    """
    public, private = self.public_state, self.private_state
    if public is not None and private is not None:
        return private, public
    raise RuntimeError("Engine not initialized")
656
+
657
+
658
class NetHackObservationCallable(GetObservationCallable):
    """Standard per-step observation callable for NetHack."""

    async def get_observation(
        self, pub: NetHackPublicState, priv: NetHackPrivateState
    ) -> InternalObservation:
        """Build the per-step observation dict from the engine's public/private state.

        The mutable dicts taken from the state are copied, so a returned
        observation is a snapshot. The engine updates ``achievements_unlocked``
        in place on every step, so returning the live dict would make
        previously returned observations change retroactively.
        """
        observation = {
            "ascii_map": pub.ascii_map,
            "message": pub.message,
            "character_stats": dict(pub.character_stats),
            "inventory_summary": self._format_inventory(pub.inventory),
            "dungeon_level": pub.dungeon_level,
            "position": pub.position,
            "turn_count": pub.turn_count,
            "last_action": pub.last_action,
            "reward_last": priv.reward_last,
            "total_reward": priv.total_reward,
            "balrog_reward_last": priv.balrog_reward_last,
            "balrog_total_reward": priv.balrog_total_reward,
            "score": priv.score,
            "experience_level": priv.experience_level,
            "terminated": priv.terminated,
            "in_menu": pub.in_menu,
            # Menu items are only reported while a menu is open.
            "menu_items": pub.menu_items if pub.in_menu else [],
            "achievements_unlocked": dict(pub.achievements_unlocked),
            "achievements_summary": self._format_achievements(pub.achievements_unlocked),
        }
        return observation  # type: ignore[return-value]

    def _format_inventory(self, inventory: List[Dict[str, Any]]) -> str:
        """Render inventory items one per line as ``- <name> (count: <n>)``.

        A missing ``count`` is shown as 1; an empty inventory gets a fixed sentence.
        """
        if not inventory:
            return "Your inventory is empty."
        return "\n".join(
            f"- {item['name']} (count: {item.get('count', 1)})" for item in inventory
        )

    def _format_achievements(self, achievements: Dict[str, bool]) -> str:
        """Summarise unlocked (truthy) achievements, listing at most five by name."""
        unlocked = [name for name, status in achievements.items() if status]
        if not unlocked:
            return "None unlocked yet"
        if len(unlocked) <= 5:
            return ", ".join(unlocked)
        return f"{', '.join(unlocked[:5])} and {len(unlocked) - 5} more"
706
+
707
+
708
+ class NetHackCheckpointObservationCallable(GetObservationCallable):
709
+ """Checkpoint observation callable for NetHack."""
710
+
711
async def get_observation(
    self, pub: "NetHackPublicState", priv: "NetHackPrivateState"
) -> "InternalObservation":
    """Build the end-of-episode (checkpoint) summary observation.

    ``achievements_unlocked`` and ``achievements_count`` are derived from the
    same truthy entries of ``pub.achievements_unlocked``, so they always agree.
    """
    unlocked = [name for name, is_unlocked in pub.achievements_unlocked.items() if is_unlocked]
    achievements = pub.achievements
    observation = {
        "final_score": priv.score,
        "max_depth": priv.depth_reached,
        "experience_level": priv.experience_level,
        "monsters_killed": priv.monsters_killed,
        "items_collected": priv.items_collected,
        "turn_count_final": pub.turn_count,
        "total_reward": priv.total_reward,
        "balrog_total_reward": priv.balrog_total_reward,
        "terminated": priv.terminated,
        "truncated": priv.truncated,
        # NOTE(review): the engine's character_stats dicts in this file never
        # contain "role", so this is "unknown" unless something else adds it.
        # The role lives on the engine (character_role), which this callable
        # cannot see.
        "character_role": pub.character_stats.get("role", "unknown"),
        "achievements_unlocked": unlocked,
        "achievements_count": len(unlocked),
        "achievement_stats": {
            "depth_reached": achievements.depth_reached,
            "monsters_killed": achievements.monsters_killed,
            "gold_collected": achievements.gold_collected,
            "items_collected": achievements.items_picked_up,
            "max_level": achievements.max_level_reached,
            "turns_survived": achievements.turns_survived,
            "balrog_score": achievements.balrog_progress.percent,
        },
    }
    return observation  # type: ignore[return-value]