synth-ai 0.2.0__py3-none-any.whl → 0.2.1.dev0__py3-none-any.whl

This diff covers publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
Files changed (266)
  1. synth_ai/__init__.py +28 -2
  2. synth_ai/core/system.py +4 -0
  3. synth_ai/environments/__init__.py +35 -0
  4. synth_ai/environments/environment/__init__.py +1 -0
  5. synth_ai/environments/environment/artifacts/__init__.py +1 -0
  6. synth_ai/environments/environment/artifacts/base.py +50 -0
  7. synth_ai/environments/environment/core.py +22 -0
  8. synth_ai/environments/environment/db/__init__.py +1 -0
  9. synth_ai/environments/environment/db/sqlite.py +45 -0
  10. synth_ai/environments/environment/registry.py +24 -0
  11. synth_ai/environments/environment/resources/sqlite.py +46 -0
  12. synth_ai/environments/environment/results.py +1 -0
  13. synth_ai/environments/environment/rewards/__init__.py +1 -0
  14. synth_ai/environments/environment/rewards/core.py +28 -0
  15. synth_ai/environments/environment/shared_engine.py +26 -0
  16. synth_ai/environments/environment/tools/__init__.py +34 -0
  17. synth_ai/environments/examples/__init__.py +1 -0
  18. synth_ai/environments/examples/crafter_classic/__init__.py +8 -0
  19. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_comprehensive_evaluation.py +58 -0
  20. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_evaluation_browser.py +152 -0
  21. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_evaluation_framework.py +1194 -0
  22. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_quick_evaluation.py +51 -0
  23. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_react_agent.py +872 -0
  24. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_trace_evaluation.py +1412 -0
  25. synth_ai/environments/examples/crafter_classic/agent_demos/test_crafter_react_agent.py +1110 -0
  26. synth_ai/environments/examples/crafter_classic/config_logging.py +111 -0
  27. synth_ai/environments/examples/crafter_classic/engine.py +502 -0
  28. synth_ai/environments/examples/crafter_classic/engine_deterministic_patch.py +63 -0
  29. synth_ai/environments/examples/crafter_classic/engine_helpers/action_map.py +5 -0
  30. synth_ai/environments/examples/crafter_classic/engine_helpers/serialization.py +74 -0
  31. synth_ai/environments/examples/crafter_classic/environment.py +255 -0
  32. synth_ai/environments/examples/crafter_classic/taskset.py +228 -0
  33. synth_ai/environments/examples/enron/agent_demos/test_synth_react.py +535 -0
  34. synth_ai/environments/examples/enron/art_helpers/email_search_tools.py +156 -0
  35. synth_ai/environments/examples/enron/art_helpers/local_email_db.py +280 -0
  36. synth_ai/environments/examples/enron/art_helpers/types_enron.py +24 -0
  37. synth_ai/environments/examples/enron/engine.py +291 -0
  38. synth_ai/environments/examples/enron/environment.py +165 -0
  39. synth_ai/environments/examples/enron/taskset.py +112 -0
  40. synth_ai/environments/examples/enron/units/keyword_stats.py +111 -0
  41. synth_ai/environments/examples/enron/units/test_email_index.py +8 -0
  42. synth_ai/environments/examples/minigrid/__init__.py +48 -0
  43. synth_ai/environments/examples/minigrid/agent_demos/minigrid_evaluation_framework.py +1188 -0
  44. synth_ai/environments/examples/minigrid/agent_demos/minigrid_quick_evaluation.py +47 -0
  45. synth_ai/environments/examples/minigrid/agent_demos/minigrid_react_agent.py +562 -0
  46. synth_ai/environments/examples/minigrid/agent_demos/minigrid_trace_evaluation.py +220 -0
  47. synth_ai/environments/examples/minigrid/agent_demos/test_minigrid_react_agent.py +393 -0
  48. synth_ai/environments/examples/minigrid/engine.py +589 -0
  49. synth_ai/environments/examples/minigrid/environment.py +274 -0
  50. synth_ai/environments/examples/minigrid/environment_mapping.py +242 -0
  51. synth_ai/environments/examples/minigrid/puzzle_loader.py +416 -0
  52. synth_ai/environments/examples/minigrid/taskset.py +583 -0
  53. synth_ai/environments/examples/minigrid/units/test_action_behavior.py +226 -0
  54. synth_ai/environments/examples/minigrid/units/test_debug_messages.py +83 -0
  55. synth_ai/environments/examples/minigrid/units/test_exploration.py +120 -0
  56. synth_ai/environments/examples/minigrid/units/test_minigrid_engine.py +214 -0
  57. synth_ai/environments/examples/minigrid/units/test_minigrid_environment.py +238 -0
  58. synth_ai/environments/examples/minigrid/units/test_minigrid_environment_mapping.py +301 -0
  59. synth_ai/environments/examples/minigrid/units/test_minigrid_taskset.py +210 -0
  60. synth_ai/environments/examples/nethack/__init__.py +7 -0
  61. synth_ai/environments/examples/nethack/achievements.py +337 -0
  62. synth_ai/environments/examples/nethack/agent_demos/nethack_evaluation_framework.py +981 -0
  63. synth_ai/environments/examples/nethack/agent_demos/nethack_quick_evaluation.py +74 -0
  64. synth_ai/environments/examples/nethack/agent_demos/nethack_react_agent.py +832 -0
  65. synth_ai/environments/examples/nethack/agent_demos/test_nethack_react_agent.py +1112 -0
  66. synth_ai/environments/examples/nethack/engine.py +738 -0
  67. synth_ai/environments/examples/nethack/environment.py +255 -0
  68. synth_ai/environments/examples/nethack/helpers/__init__.py +42 -0
  69. synth_ai/environments/examples/nethack/helpers/action_mapping.py +301 -0
  70. synth_ai/environments/examples/nethack/helpers/nle_wrapper.py +401 -0
  71. synth_ai/environments/examples/nethack/helpers/observation_utils.py +433 -0
  72. synth_ai/environments/examples/nethack/helpers/recording_wrapper.py +201 -0
  73. synth_ai/environments/examples/nethack/helpers/trajectory_recorder.py +268 -0
  74. synth_ai/environments/examples/nethack/helpers/visualization/replay_viewer.py +308 -0
  75. synth_ai/environments/examples/nethack/helpers/visualization/visualizer.py +430 -0
  76. synth_ai/environments/examples/nethack/taskset.py +323 -0
  77. synth_ai/environments/examples/nethack/units/test_nethack_engine.py +277 -0
  78. synth_ai/environments/examples/nethack/units/test_nethack_environment.py +281 -0
  79. synth_ai/environments/examples/nethack/units/test_nethack_taskset.py +213 -0
  80. synth_ai/environments/examples/nethack/units/test_recording.py +307 -0
  81. synth_ai/environments/examples/red/__init__.py +7 -0
  82. synth_ai/environments/examples/red/agent_demos/__init__.py +1 -0
  83. synth_ai/environments/examples/red/agent_demos/test_synth_react.py +1471 -0
  84. synth_ai/environments/examples/red/config_logging.py +110 -0
  85. synth_ai/environments/examples/red/engine.py +693 -0
  86. synth_ai/environments/examples/red/engine_helpers/__init__.py +1 -0
  87. synth_ai/environments/examples/red/engine_helpers/memory_map.py +28 -0
  88. synth_ai/environments/examples/red/engine_helpers/reward_components.py +275 -0
  89. synth_ai/environments/examples/red/engine_helpers/reward_library/__init__.py +142 -0
  90. synth_ai/environments/examples/red/engine_helpers/reward_library/adaptive_rewards.py +56 -0
  91. synth_ai/environments/examples/red/engine_helpers/reward_library/battle_rewards.py +283 -0
  92. synth_ai/environments/examples/red/engine_helpers/reward_library/composite_rewards.py +149 -0
  93. synth_ai/environments/examples/red/engine_helpers/reward_library/economy_rewards.py +137 -0
  94. synth_ai/environments/examples/red/engine_helpers/reward_library/efficiency_rewards.py +56 -0
  95. synth_ai/environments/examples/red/engine_helpers/reward_library/exploration_rewards.py +330 -0
  96. synth_ai/environments/examples/red/engine_helpers/reward_library/novelty_rewards.py +120 -0
  97. synth_ai/environments/examples/red/engine_helpers/reward_library/pallet_town_rewards.py +558 -0
  98. synth_ai/environments/examples/red/engine_helpers/reward_library/pokemon_rewards.py +312 -0
  99. synth_ai/environments/examples/red/engine_helpers/reward_library/social_rewards.py +147 -0
  100. synth_ai/environments/examples/red/engine_helpers/reward_library/story_rewards.py +246 -0
  101. synth_ai/environments/examples/red/engine_helpers/screen_analysis.py +367 -0
  102. synth_ai/environments/examples/red/engine_helpers/state_extraction.py +139 -0
  103. synth_ai/environments/examples/red/environment.py +235 -0
  104. synth_ai/environments/examples/red/taskset.py +77 -0
  105. synth_ai/environments/examples/red/test_fixes.py +125 -0
  106. synth_ai/environments/examples/red/test_fixes_mock.py +148 -0
  107. synth_ai/environments/examples/red/units/__init__.py +1 -0
  108. synth_ai/environments/examples/red/units/test_basic_functionality.py +97 -0
  109. synth_ai/environments/examples/red/units/test_button_press_requirements.py +217 -0
  110. synth_ai/environments/examples/red/units/test_engine.py +192 -0
  111. synth_ai/environments/examples/red/units/test_environment.py +455 -0
  112. synth_ai/environments/examples/red/units/test_exploration_strategy.py +227 -0
  113. synth_ai/environments/examples/red/units/test_integration.py +217 -0
  114. synth_ai/environments/examples/red/units/test_memory_extraction.py +111 -0
  115. synth_ai/environments/examples/red/units/test_menu_bug_reproduction.py +1100 -0
  116. synth_ai/environments/examples/red/units/test_movement_debug.py +255 -0
  117. synth_ai/environments/examples/red/units/test_pokemon_mcts_debug.py +163 -0
  118. synth_ai/environments/examples/red/units/test_pokemon_mcts_verbose.py +117 -0
  119. synth_ai/environments/examples/red/units/test_red_basic.py +145 -0
  120. synth_ai/environments/examples/red/units/test_red_comprehensive.py +323 -0
  121. synth_ai/environments/examples/red/units/test_retry_movement.py +195 -0
  122. synth_ai/environments/examples/red/units/test_reward_components.py +186 -0
  123. synth_ai/environments/examples/red/units/test_rom_integration.py +260 -0
  124. synth_ai/environments/examples/red/units/test_taskset.py +116 -0
  125. synth_ai/environments/examples/red/units/test_tree.py +448 -0
  126. synth_ai/environments/examples/sokoban/__init__.py +1 -0
  127. synth_ai/environments/examples/sokoban/agent_demos/sokoban_full_eval.py +900 -0
  128. synth_ai/environments/examples/sokoban/agent_demos/test_dspy_react.py +1 -0
  129. synth_ai/environments/examples/sokoban/agent_demos/test_sokoban_react_agent.py +498 -0
  130. synth_ai/environments/examples/sokoban/agent_demos/test_synth_lats.py +1 -0
  131. synth_ai/environments/examples/sokoban/agent_demos/test_synth_react_locally.py +748 -0
  132. synth_ai/environments/examples/sokoban/agent_demos/test_synth_react_service.py +296 -0
  133. synth_ai/environments/examples/sokoban/engine.py +675 -0
  134. synth_ai/environments/examples/sokoban/engine_helpers/__init__.py +1 -0
  135. synth_ai/environments/examples/sokoban/engine_helpers/room_utils.py +656 -0
  136. synth_ai/environments/examples/sokoban/engine_helpers/vendored/__init__.py +17 -0
  137. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/__init__.py +3 -0
  138. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/boxoban_env.py +129 -0
  139. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/render_utils.py +370 -0
  140. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/room_utils.py +331 -0
  141. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env.py +305 -0
  142. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_fixed_targets.py +66 -0
  143. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_pull.py +114 -0
  144. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_two_player.py +122 -0
  145. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_variations.py +394 -0
  146. synth_ai/environments/examples/sokoban/environment.py +228 -0
  147. synth_ai/environments/examples/sokoban/generate_verified_puzzles.py +438 -0
  148. synth_ai/environments/examples/sokoban/puzzle_loader.py +311 -0
  149. synth_ai/environments/examples/sokoban/taskset.py +425 -0
  150. synth_ai/environments/examples/sokoban/units/astar_common.py +94 -0
  151. synth_ai/environments/examples/sokoban/units/test_building_task_set.py +49 -0
  152. synth_ai/environments/examples/sokoban/units/test_false_positive.py +120 -0
  153. synth_ai/environments/examples/sokoban/units/test_simple_run_through_environment.py +119 -0
  154. synth_ai/environments/examples/sokoban/units/test_sokoban_environment.py +98 -0
  155. synth_ai/environments/examples/sokoban/units/test_tree.py +364 -0
  156. synth_ai/environments/examples/tictactoe/__init__.py +1 -0
  157. synth_ai/environments/examples/tictactoe/agent_demos/test_synth_react.py +266 -0
  158. synth_ai/environments/examples/tictactoe/agent_demos/test_tictactoe_react_agent.py +470 -0
  159. synth_ai/environments/examples/tictactoe/engine.py +368 -0
  160. synth_ai/environments/examples/tictactoe/environment.py +239 -0
  161. synth_ai/environments/examples/tictactoe/taskset.py +214 -0
  162. synth_ai/environments/examples/tictactoe/units/test_tictactoe_engine.py +393 -0
  163. synth_ai/environments/examples/tictactoe/units/test_tictactoe_environment.py +493 -0
  164. synth_ai/environments/examples/tictactoe/units/test_tictactoe_taskset.py +191 -0
  165. synth_ai/environments/examples/verilog/__init__.py +10 -0
  166. synth_ai/environments/examples/verilog/agent_demos/test_synth_react.py +520 -0
  167. synth_ai/environments/examples/verilog/engine.py +328 -0
  168. synth_ai/environments/examples/verilog/environment.py +349 -0
  169. synth_ai/environments/examples/verilog/taskset.py +418 -0
  170. synth_ai/environments/examples/verilog/units/test_verilog_engine.py +466 -0
  171. synth_ai/environments/examples/verilog/units/test_verilog_environment.py +585 -0
  172. synth_ai/environments/examples/verilog/units/test_verilog_integration.py +383 -0
  173. synth_ai/environments/examples/verilog/units/test_verilog_taskset.py +457 -0
  174. synth_ai/environments/reproducibility/core.py +42 -0
  175. synth_ai/environments/reproducibility/tree.py +364 -0
  176. synth_ai/environments/service/app.py +78 -0
  177. synth_ai/environments/service/core_routes.py +775 -0
  178. synth_ai/environments/service/external_registry.py +57 -0
  179. synth_ai/environments/service/registry.py +9 -0
  180. synth_ai/environments/stateful/__init__.py +1 -0
  181. synth_ai/environments/stateful/core.py +28 -0
  182. synth_ai/environments/stateful/engine.py +21 -0
  183. synth_ai/environments/stateful/state.py +7 -0
  184. synth_ai/environments/tasks/api.py +19 -0
  185. synth_ai/environments/tasks/core.py +78 -0
  186. synth_ai/environments/tasks/filters.py +39 -0
  187. synth_ai/environments/tasks/utils.py +89 -0
  188. synth_ai/environments/v0_observability/history.py +3 -0
  189. synth_ai/environments/v0_observability/log.py +2 -0
  190. synth_ai/lm/caching/constants.py +1 -0
  191. synth_ai/{zyk/lms → lm}/caching/ephemeral.py +4 -8
  192. synth_ai/{zyk/lms → lm}/caching/handler.py +15 -15
  193. synth_ai/{zyk/lms → lm}/caching/initialize.py +2 -4
  194. synth_ai/{zyk/lms → lm}/caching/persistent.py +4 -10
  195. synth_ai/{zyk/lms → lm}/config.py +2 -1
  196. synth_ai/{zyk/lms → lm}/constants.py +2 -2
  197. synth_ai/{zyk/lms → lm}/core/all.py +10 -10
  198. synth_ai/{zyk/lms → lm}/core/main.py +57 -33
  199. synth_ai/{zyk/lms → lm}/core/vendor_clients.py +12 -10
  200. synth_ai/lm/cost/monitor.py +1 -0
  201. synth_ai/lm/cost/statefulness.py +1 -0
  202. synth_ai/lm/provider_support/__init__.py +8 -0
  203. synth_ai/lm/provider_support/anthropic.py +945 -0
  204. synth_ai/lm/provider_support/openai.py +1115 -0
  205. synth_ai/lm/provider_support/suppress_logging.py +31 -0
  206. synth_ai/{zyk/lms → lm}/structured_outputs/handler.py +58 -80
  207. synth_ai/{zyk/lms → lm}/structured_outputs/inject.py +6 -20
  208. synth_ai/{zyk/lms → lm}/structured_outputs/rehabilitate.py +6 -12
  209. synth_ai/{zyk/lms → lm}/vendors/core/anthropic_api.py +21 -30
  210. synth_ai/{zyk/lms → lm}/vendors/core/gemini_api.py +35 -32
  211. synth_ai/{zyk/lms → lm}/vendors/core/mistral_api.py +19 -28
  212. synth_ai/{zyk/lms → lm}/vendors/core/openai_api.py +26 -36
  213. synth_ai/{zyk/lms → lm}/vendors/openai_standard.py +29 -33
  214. synth_ai/{zyk/lms → lm}/vendors/retries.py +1 -1
  215. synth_ai/lm/vendors/supported/__init__.py +0 -0
  216. synth_ai/{zyk/lms → lm}/vendors/supported/custom_endpoint.py +131 -118
  217. synth_ai/{zyk/lms → lm}/vendors/supported/deepseek.py +4 -8
  218. synth_ai/{zyk/lms → lm}/vendors/supported/grok.py +6 -8
  219. synth_ai/{zyk/lms → lm}/vendors/supported/groq.py +1 -1
  220. synth_ai/{zyk/lms → lm}/vendors/supported/ollama.py +2 -2
  221. synth_ai/{zyk/lms → lm}/vendors/supported/openrouter.py +18 -16
  222. synth_ai/{zyk/lms → lm}/vendors/supported/together.py +1 -1
  223. synth_ai/tracing/__init__.py +0 -0
  224. synth_ai/tracing/abstractions.py +224 -0
  225. synth_ai/tracing/base_client.py +91 -0
  226. synth_ai/tracing/client_manager.py +131 -0
  227. synth_ai/tracing/config.py +140 -0
  228. synth_ai/tracing/context.py +146 -0
  229. synth_ai/tracing/decorators.py +679 -0
  230. synth_ai/tracing/events/__init__.py +0 -0
  231. synth_ai/tracing/events/manage.py +147 -0
  232. synth_ai/tracing/events/scope.py +86 -0
  233. synth_ai/tracing/events/store.py +227 -0
  234. synth_ai/tracing/immediate_client.py +152 -0
  235. synth_ai/tracing/local.py +18 -0
  236. synth_ai/tracing/log_client_base.py +74 -0
  237. synth_ai/tracing/retry_queue.py +187 -0
  238. synth_ai/tracing/trackers.py +515 -0
  239. synth_ai/tracing/upload.py +504 -0
  240. synth_ai/tracing/utils.py +9 -0
  241. synth_ai/zyk/__init__.py +28 -2
  242. synth_ai-0.2.1.dev0.dist-info/METADATA +349 -0
  243. synth_ai-0.2.1.dev0.dist-info/RECORD +261 -0
  244. {synth_ai-0.2.0.dist-info → synth_ai-0.2.1.dev0.dist-info}/WHEEL +1 -1
  245. synth_ai/zyk/lms/caching/constants.py +0 -1
  246. synth_ai/zyk/lms/cost/monitor.py +0 -1
  247. synth_ai/zyk/lms/cost/statefulness.py +0 -1
  248. synth_ai-0.2.0.dist-info/METADATA +0 -36
  249. synth_ai-0.2.0.dist-info/RECORD +0 -50
  250. /synth_ai/{zyk/lms/__init__.py → environments/reproducibility/helpers.py} +0 -0
  251. /synth_ai/{zyk/lms/caching → lm}/__init__.py +0 -0
  252. /synth_ai/{zyk/lms/core → lm/caching}/__init__.py +0 -0
  253. /synth_ai/{zyk/lms → lm}/caching/dbs.py +0 -0
  254. /synth_ai/{zyk/lms/cost → lm/core}/__init__.py +0 -0
  255. /synth_ai/{zyk/lms → lm}/core/exceptions.py +0 -0
  256. /synth_ai/{zyk/lms/structured_outputs → lm/cost}/__init__.py +0 -0
  257. /synth_ai/{zyk/lms/vendors → lm/structured_outputs}/__init__.py +0 -0
  258. /synth_ai/{zyk/lms → lm}/tools/__init__.py +0 -0
  259. /synth_ai/{zyk/lms → lm}/tools/base.py +0 -0
  260. /synth_ai/{zyk/lms/vendors/core → lm/vendors}/__init__.py +0 -0
  261. /synth_ai/{zyk/lms → lm}/vendors/base.py +0 -0
  262. /synth_ai/{zyk/lms/vendors/local → lm/vendors/core}/__init__.py +0 -0
  263. /synth_ai/{zyk/lms/vendors/supported → lm/vendors/local}/__init__.py +0 -0
  264. /synth_ai/{zyk/lms → lm}/vendors/local/ollama.py +0 -0
  265. {synth_ai-0.2.0.dist-info → synth_ai-0.2.1.dev0.dist-info/licenses}/LICENSE +0 -0
  266. {synth_ai-0.2.0.dist-info → synth_ai-0.2.1.dev0.dist-info}/top_level.txt +0 -0
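Before the per-file diff, one pattern in the list above is worth calling out: outside the new `synth_ai/environments/` tree, most entries (191–226, 251–268) record the move of `synth_ai/zyk/lms/` to `synth_ai/lm/`, while `synth_ai/zyk/__init__.py` (entry 241) survives with changes, suggesting a compatibility shim. A minimal sketch of the import migration this implies for downstream code follows; the new-path re-exports are an assumption based on the file moves (entries 198 and 259), not something the file list itself confirms:

    # 0.2.0-era imports (the test file below still uses these):
    from synth_ai.zyk import LM
    from synth_ai.zyk.lms.tools.base import BaseTool

    # 0.2.1.dev0 equivalents after the move (assumed re-exports):
    from synth_ai.lm.core.main import LM         # entry 198: zyk/lms/core/main.py -> lm/core/main.py
    from synth_ai.lm.tools.base import BaseTool  # entry 259: zyk/lms/tools/base.py -> lm/tools/base.py
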
synth_ai/environments/examples/red/agent_demos/test_synth_react.py (entry 83)
@@ -0,0 +1,1471 @@
+ import asyncio
+ import uuid
+ import pytest
+ import json
+ from pathlib import Path
+ from typing import Dict, Any, List, Optional, Deque, Literal
+ from pydantic import BaseModel, Field, validator
+ from collections import deque
+ from synth_ai.zyk import LM
+ from synth_ai.zyk.lms.tools.base import BaseTool
+ from synth_sdk.tracing.decorators import trace_event_async
+ from synth_sdk.tracing.abstractions import RewardSignal, Dataset, TrainingQuestion
+ from synth_sdk.tracing.utils import get_system_id
+
+ # Monkey patch the zyk cache handler to allow mixed content types (for images)
+ try:
+     from synth_ai.zyk.lms.caching.handler import CacheHandler
+
+     original_validate_messages = CacheHandler._validate_messages
+
+     def patched_validate_messages(self, messages: List[Dict[str, Any]]) -> None:
+         """Validate that messages are in the correct format - PATCHED to allow mixed content for images."""
+         # Allow mixed content types when images are involved - just check that messages exist
+         assert all(isinstance(msg, dict) and "content" in msg for msg in messages), (
+             "All messages must be dicts with content"
+         )
+
+     CacheHandler._validate_messages = patched_validate_messages
+     print("[DEBUG] Successfully monkey patched zyk cache validation to support images")
+ except Exception as e:
+     print(f"[DEBUG] Failed to monkey patch zyk cache validation: {e}")
+     # Continue anyway - the assertion might not be hit in all cases
+
+ # Pokemon Red specific imports
+ from synth_ai.environments.examples.red.environment import (
+     PokemonRedEnvironment,
+     PokemonRedPublicState,
+     PokemonRedPrivateState,
+ )
+
+ # Import early game reward components
+ from synth_ai.environments.examples.red.engine_helpers.reward_library.pallet_town_rewards import (
+     LeaveStartingRoomReward,
+     TalkToMomReward,
+     InteractWithTVReward,
+     CheckComputerReward,
+     ExitHouseReward,
+     ExploreTownReward,
+     TalkToNPCsReward,
+     OakLabDiscoveryReward,
+     AttemptRoute1Reward,
+     ChooseStarterPokemonReward,
+     DoorInteractionReward,
+     ObjectInteractionReward,
+     TryAllDirectionsReward,
+ )
+ from synth_ai.environments.examples.red.engine_helpers.reward_library.exploration_rewards import (
+     NewAreaDiscoveryReward,
+     BuildingEntryReward,
+ )
+ from synth_ai.environments.examples.red.engine_helpers.reward_library.novelty_rewards import (
+     FirstBattleReward,
+     FirstPokemonCenterVisitReward,
+ )
+
+ from synth_ai.environments.environment.shared_engine import (
+     GetObservationCallable,
+     InternalObservation,
+ )
+ from synth_ai.environments.examples.red.taskset import PokemonRedTaskInstance
+ from synth_ai.environments.tasks.core import Impetus, Intent, TaskInstanceMetadata
+ from synth_ai.environments.environment.tools import EnvToolCall
+
+ # Import screen analysis functions
+ from synth_ai.environments.examples.red.engine_helpers.screen_analysis import (
+     analyze_screen_buffer,
+     create_detailed_screen_description,
+ )
+
+ import logging
+
+ logging.disable(logging.CRITICAL)
+
+
+ # --- Early Game Reward Manager ---
+ class EarlyGameRewardManager:
+     """Manages early game rewards for Pokemon Red to encourage exploration and progress"""
+
+     def __init__(self):
+         # Initialize early game reward components
+         self.rewards = [
+             # Pallet Town house exploration
+             LeaveStartingRoomReward(),
+             TalkToMomReward(),
+             InteractWithTVReward(),
+             CheckComputerReward(),
+             ExitHouseReward(),
+             # Town and building exploration
+             ExploreTownReward(),
+             TalkToNPCsReward(),
+             NewAreaDiscoveryReward(),
+             BuildingEntryReward(),
+             # Story progression
+             OakLabDiscoveryReward(),
+             AttemptRoute1Reward(),
+             ChooseStarterPokemonReward(),
+             # Basic interactions
+             DoorInteractionReward(),
+             ObjectInteractionReward(),
+             TryAllDirectionsReward(),
+             # First time experiences
+             FirstBattleReward(),
+             FirstPokemonCenterVisitReward(),
+         ]
+
+         self.total_reward_earned = 0.0
+         self.reward_history = []
+
+     async def calculate_rewards(
+         self,
+         current_state: Dict[str, Any],
+         prev_state: Dict[str, Any],
+         action_info: Dict[str, Any],
+     ) -> float:
+         """Calculate rewards for the current state transition"""
+         total_reward = 0.0
+         step_rewards = []
+
+         # Create action context with previous state info
+         action_context = {
+             "prev_map_id": prev_state.get("map_id", -1),
+             "prev_player_x": prev_state.get("player_x", -1),
+             "prev_player_y": prev_state.get("player_y", -1),
+             "prev_text_box_active": prev_state.get("text_box_active", False),
+             "prev_in_battle": prev_state.get("in_battle", False),
+             "prev_party": prev_state.get("party", []),
+             "prev_inventory": prev_state.get("inventory", []),
+             "prev_money": prev_state.get("money", 0),
+             **action_info,  # Include any additional action info
+         }
+
+         # Calculate rewards from each component
+         for reward_component in self.rewards:
+             try:
+                 reward = await reward_component.score(current_state, action_context)
+                 if reward > 0:
+                     total_reward += reward
+                     step_rewards.append(
+                         {
+                             "component": reward_component.__class__.__name__,
+                             "reward": reward,
+                         }
+                     )
+                     print(f"[REWARD] {reward_component.__class__.__name__}: +{reward:.1f}")
+             except Exception as e:
+                 print(f"[REWARD_ERROR] {reward_component.__class__.__name__}: {e}")
+                 continue
+
+         if total_reward > 0:
+             self.total_reward_earned += total_reward
+             self.reward_history.append(
+                 {
+                     "step": current_state.get("step_count", 0),
+                     "total_reward": total_reward,
+                     "components": step_rewards,
+                 }
+             )
+             print(
+                 f"[REWARD_TOTAL] Step {current_state.get('step_count', 0)}: +{total_reward:.1f} (Total: {self.total_reward_earned:.1f})"
+             )
+
+         return total_reward
+
+
+ # --- Helper function to format observation for LLM ---
+ def format_obs_for_llm_from_states(
+     pub: PokemonRedPublicState,
+     priv: PokemonRedPrivateState,
+     screen_analysis: Optional[dict] = None,
+     mode: str = "state_and_screen",
+ ) -> str:
+     """Format Pokemon Red observation for LLM consumption with comprehensive text-based state information.
+
+     This function provides rich, semantic game state information to eliminate
+     the need for visual processing, as specified in text_port.txt requirements.
+     """
+
+     obs_lines = [
+         "=== POKEMON RED GAME STATE ===",
+         f"Step: {pub.progress.step_count}",
+     ]
+
+     # === VISUAL SCREEN INFORMATION ===
+     if screen_analysis:
+         obs_lines.extend(["", "=== VISUAL SCREEN ANALYSIS ==="])
+
+         # Add detailed screen description - only include ASCII for state_and_ascii mode
+         if mode == "state_and_ascii":
+             screen_description = create_detailed_screen_description(screen_analysis)
+         else:
+             # For state_and_screen mode, show summary without ASCII
+             screen_description = f"SCREEN TYPE: {screen_analysis.get('screen_type', 'UNKNOWN')}\n"
+
+             # Add color analysis
+             if "colors" in screen_analysis:
+                 colors_text = "DOMINANT COLORS: " + ", ".join(
+                     [f"{color}({pct}%)" for color, pct in screen_analysis["colors"].items()]
+                 )
+                 screen_description += colors_text + "\n"
+
+             # Add entity detection summary
+             if "entities" in screen_analysis:
+                 screen_description += (
+                     f"DETECTED ENTITIES: {len(screen_analysis['entities'])} sprite-like objects\n"
+                 )
+
+             # Add UI elements
+             if "ui_elements" in screen_analysis:
+                 ui_elements = screen_analysis["ui_elements"]
+                 if ui_elements:
+                     screen_description += f"UI: {', '.join(ui_elements)} detected\n"
+
+         obs_lines.append(screen_description)
+
+     # === WORLD INFORMATION ===
+     obs_lines.extend(
+         [
+             "",
+             "=== WORLD LOCATION ===",
+             f"Map ID: {pub.world.map_id} | Position: ({pub.world.player_x}, {pub.world.player_y})",
+         ]
+     )
+
+     # === PLAYER PROGRESS ===
+     obs_lines.extend(
+         [
+             "",
+             "=== PLAYER PROGRESS ===",
+             f"Badges: {pub.progress.badge_count}/8 (0x{pub.progress.badges:02X})",
+             f"Money: ${pub.progress.money:,}",
+         ]
+     )
+
+     # === POKEMON PARTY ===
+     obs_lines.extend(["", "=== POKEMON PARTY ==="])
+
+     if pub.party:
+         for i, pokemon in enumerate(pub.party, 1):
+             status_icon = "●" if pokemon.hp_current > 0 else "✗"
+             obs_lines.append(
+                 f"{i}. Species#{pokemon.species_id:03d} L{pokemon.level} | "
+                 f"HP:{pokemon.hp_current}/{pokemon.hp_max} ({pokemon.hp_percentage:.1f}%) {status_icon} | "
+                 f"XP:{pokemon.xp:,}"
+             )
+     else:
+         obs_lines.append("No Pokemon in party")
+
+     # === INVENTORY ===
+     obs_lines.extend(["", "=== INVENTORY ==="])
+
+     if pub.inventory:
+         # Show first 8 items with quantities
+         for item in pub.inventory[:8]:
+             obs_lines.append(f"Item#{item.item_id:03d} x{item.quantity}")
+
+         if len(pub.inventory) > 8:
+             obs_lines.append(f"... and {len(pub.inventory) - 8} more items")
+
+         obs_lines.append(f"Total Items: {len(pub.inventory)}")
+     else:
+         obs_lines.append("No items in inventory")
+
+     # === GAME SYSTEM STATE ===
+     obs_lines.extend(["", "=== GAME SYSTEM STATE ==="])
+
+     # Just show raw state without interpretation
+     if pub.system.in_battle:
+         obs_lines.append("In Battle: True")
+         obs_lines.append(f"Battle Outcome: {pub.system.battle_outcome}")
+     else:
+         obs_lines.append("In Battle: False")
+
+     if pub.system.text_box_active:
+         obs_lines.append("Text Box Active: True")
+     else:
+         obs_lines.append("Text Box Active: False")
+
+     obs_lines.append(f"Warp Flag: {pub.system.warp_flag}")
+
+     # === TECHNICAL INFO ===
+     obs_lines.extend(
+         [
+             "",
+             "=== TECHNICAL INFO ===",
+             f"Last Reward: {priv.reward_last_step:.3f}",
+             f"Total Reward: {priv.total_reward:.3f}",
+             f"Terminated: {priv.terminated} | Truncated: {priv.truncated}",
+         ]
+     )
+
+     if pub.error_info:
+         obs_lines.append(f"Error: {pub.error_info}")
+
+     obs_lines.append("=== END GAME STATE ===")
+
+     return "\n".join(obs_lines)
+
+
+ # --- Custom observation callable for Pokemon Red ---
+ class PokemonRedHistoryObservationCallable(GetObservationCallable):
+     def __init__(
+         self,
+         max_history: int = 1,
+         mode: Literal["state_and_ascii", "state_and_screen"] = "state_and_screen",
+     ):
+         self._hist_obs: Deque[str] = deque(maxlen=max_history)
+         self._hist_pub_state: Deque[PokemonRedPublicState] = deque(maxlen=max_history)
+         self._hist_priv_state: Deque[PokemonRedPrivateState] = deque(maxlen=max_history)
+         self._last_state_hash = None
+         self._stuck_count = 0
+         self.screen_buffer = None  # Store screen buffer for agent access
+         self.mode = mode  # Store mode for observation formatting
+
+         # Initialize reward manager for early game rewards
+         self.reward_manager = EarlyGameRewardManager()
+         self._last_state_dict = None  # Store previous state for reward calculation
+
+     async def get_observation(
+         self, pub: PokemonRedPublicState, priv: PokemonRedPrivateState
+     ) -> InternalObservation:
+         if pub is None or priv is None:
+             raise RuntimeError("Missing public or private state in get_observation - HARD FAIL")
+
+         # Create current state dict for reward calculation
+         current_state_dict = {
+             "map_id": pub.map_id,
+             "player_x": pub.player_x,
+             "player_y": pub.player_y,
+             "step_count": pub.step_count,
+             "text_box_active": pub.system.text_box_active,
+             "in_battle": pub.system.in_battle,
+             "party": [
+                 {
+                     "species_id": p.species_id,
+                     "level": p.level,
+                     "hp_current": p.hp_current,
+                     "hp_max": p.hp_max,
+                 }
+                 for p in pub.party
+             ],
+             "inventory": [
+                 {"item_id": item.item_id, "quantity": item.quantity} for item in pub.inventory
+             ],
+             "money": pub.progress.money,
+             "badges": pub.progress.badges,
+         }
+
+         # Calculate rewards if we have a previous state
+         additional_reward = 0.0
+         if self._last_state_dict is not None:
+             try:
+                 additional_reward = await self.reward_manager.calculate_rewards(
+                     current_state_dict,
+                     self._last_state_dict,
+                     {"buttons_pressed": []},  # Could track actual buttons if needed
+                 )
+             except Exception as e:
+                 print(f"[REWARD_ERROR] Failed to calculate rewards: {e}")
+
+         # Store current state for next iteration
+         self._last_state_dict = current_state_dict.copy()
+
+         # Check if we're stuck (same position and menu state for multiple steps)
+         # Use property accessors that handle the new state structure
+         current_state_hash = hash((pub.player_x, pub.player_y, pub.map_id, pub.step_count))
+         if self._last_state_hash == current_state_hash and pub.step_count > 1:
+             self._stuck_count += 1
+             if self._stuck_count >= 3:
+                 raise RuntimeError(
+                     f"Agent stuck in same state for {self._stuck_count} steps - HARD FAIL. Position: ({pub.player_x}, {pub.player_y}), Map: {pub.map_id}"
+                 )
+         else:
+             self._stuck_count = 0
+             self._last_state_hash = current_state_hash
+
+         # Extract screen buffer for agent vision - FAIL HARD if screen access doesn't work
+         additional_context = ""
+         screen_analysis = None
+
+         try:
+             # Look for environment in call stack to access engine/emulator
+             import inspect
+
+             frame = inspect.currentframe()
+             env = None
+
+             # Walk up the call stack to find the environment
+             while frame:
+                 if "self" in frame.f_locals and hasattr(frame.f_locals["self"], "engine"):
+                     env = frame.f_locals["self"]
+                     break
+                 frame = frame.f_back
+
+             if not env or not hasattr(env, "engine") or not env.engine:
+                 raise RuntimeError("Cannot access environment engine - HARD FAIL")
+
+             # REQUIRE screen access to work
+             if not hasattr(env.engine, "emulator") or not env.engine.emulator:
+                 raise RuntimeError("Emulator not available - HARD FAIL")
+
+             if not hasattr(env.engine.emulator, "screen"):
+                 raise RuntimeError("Emulator screen not available - HARD FAIL")
+
+             # Use PyBoy's documented screen.ndarray property - shape (144, 160, 4) RGBA
+             screen_buffer = (
+                 env.engine.emulator.screen.ndarray.copy()
+             )  # Copy to avoid reference issues
+
+             if screen_buffer is None:
+                 raise RuntimeError("Screen ndarray is None - HARD FAIL")
+
+             # Store screen buffer for agent to access
+             self.screen_buffer = screen_buffer
+             print(f"[DEBUG] Successfully extracted screen buffer with shape: {screen_buffer.shape}")
+
+             # Perform detailed screen analysis
+             screen_analysis = analyze_screen_buffer(screen_buffer)
+             print(
+                 f"[DEBUG] Screen analysis completed - type: {screen_analysis.get('screen_type', 'UNKNOWN')}"
+             )
+
+             # Get additional game state context - REQUIRE this to work
+             current_state = env.engine._extract_current_state()
+             if not current_state:
+                 raise RuntimeError("Failed to extract game state - HARD FAIL")
+
+             # Use the new structured state information from the public state
+             additional_context += f"\nWarp Flag: {pub.system.warp_flag}"
+             additional_context += f"\nBattle Outcome: {pub.system.battle_outcome}"
+             additional_context += f"\nInventory Count: {len(pub.inventory)}"
+
+         except Exception as e:
+             # HARD FAIL on any screen/context extraction errors
+             raise RuntimeError(f"Screen/context extraction HARD FAIL: {e}")
+
+         # Format the base observation with screen analysis
+         if self.mode == "state_and_ascii":
+             # Include ASCII analysis but no screen buffer in observation
+             formatted_obs = format_obs_for_llm_from_states(pub, priv, screen_analysis, self.mode)
+         else:
+             # Include screen analysis for screen mode
+             formatted_obs = format_obs_for_llm_from_states(pub, priv, screen_analysis, self.mode)
+
+         # Add context info
+         enhanced_obs = formatted_obs.replace(
+             "\n=== END GAME STATE ===", f"{additional_context}\n=== END GAME STATE ==="
+         )
+
+         # Add reward information to the observation
+         if additional_reward > 0 or self.reward_manager.total_reward_earned > 0:
+             reward_info = "\n\n=== REWARD PROGRESS ===\n"
+             if additional_reward > 0:
+                 reward_info += f"Step Reward: +{additional_reward:.1f}\n"
+             reward_info += f"Total Rewards Earned: {self.reward_manager.total_reward_earned:.1f}\n"
+
+             # Show recent reward achievements (last 3)
+             if self.reward_manager.reward_history:
+                 reward_info += "Recent Achievements:\n"
+                 for achievement in self.reward_manager.reward_history[-3:]:
+                     for component in achievement["components"]:
+                         reward_info += f"• {component['component']}: +{component['reward']:.1f}\n"
+
+             enhanced_obs = enhanced_obs.replace(
+                 "\n=== END GAME STATE ===", f"{reward_info}=== END GAME STATE ==="
+             )
+
+         self._hist_obs.append(enhanced_obs)
+         self._hist_pub_state.append(pub)
+         self._hist_priv_state.append(priv)
+
+         observation_dict = {
+             "public": pub,
+             "private": priv,
+             "formatted_obs": enhanced_obs,
+             "history_formatted_obs": list(self._hist_obs),
+             "history_public_states": list(self._hist_pub_state),
+             "history_private_states": list(self._hist_priv_state),
+         }
+
+         # Only include screen buffer for screen mode
+         if self.mode == "state_and_screen":
+             observation_dict["screen_buffer"] = self.screen_buffer
+
+         return observation_dict  # type: ignore[return-value]
+
+
+ # --- Pydantic Models for Tool Arguments ---
+ class PokemonRedInteractArgs(BaseModel):
+     buttons: List[str] = Field(
+         description="A sequence of 1-5 buttons to press in Pokemon Red (e.g., ['A'], ['UP', 'RIGHT'], ['START', 'DOWN', 'A']). Each button should be one of: A, B, UP, DOWN, LEFT, RIGHT, START, SELECT."
+     )
+     reasoning: str = Field(
+         description="A brief explanation of why this sequence of buttons was chosen and what you expect to accomplish."
+     )
+
+     @validator("buttons")
+     def validate_buttons(cls, v):
+         valid_buttons = {"A", "B", "UP", "DOWN", "LEFT", "RIGHT", "START", "SELECT"}
+         if not v or len(v) == 0:
+             raise ValueError("Must provide at least one button")
+         if len(v) > 5:  # Reduced from 20 to 5
+             raise ValueError("Cannot provide more than 5 buttons in sequence")
+         for button in v:
+             if button.upper() not in valid_buttons:
+                 raise ValueError(f"Invalid button: {button}. Valid buttons: {valid_buttons}")
+         return [button.upper() for button in v]  # Normalize to uppercase
+
+
+ class TerminateArgs(BaseModel):
+     reason: str = Field(
+         description="Reason for termination (e.g., 'all tasks complete', 'stuck', 'max_steps_reached')."
+     )
+
+
+ # --- Environment tool call wrapper ---
+ class PressButtonCall(EnvToolCall):
+     """Helper class for creating button press calls"""
+
+     def __init__(self, button: str, frames: int = 1):
+         super().__init__(tool="press_button", args={"button": button, "frames": frames})
+
+
+ # --- ReAct agent for Pokemon Red ---
+ class ReActAgent:
+     def __init__(self, llm, max_turns: int = 50):
+         self.llm, self.max_turns = llm, max_turns
+         self.history: List[Dict[str, Any]] = []
+         self.system_name: str = "pokemon-red-react"
+         self.system_id: Any = get_system_id(self.system_name)
+         self.system_instance_id: str = str(uuid.uuid4())
+         self.last_obs_dict: Optional[Dict[str, Any]] = None
+         self.current_badges: int = 0
+
+         # Valid button inputs for Pokemon Red
+         self.valid_buttons = [
+             "A",
+             "B",
+             "UP",
+             "DOWN",
+             "LEFT",
+             "RIGHT",
+             "START",
+             "SELECT",
+         ]
+
+         # Create proper BaseTool objects for zyk
+         self.tools = [
+             BaseTool(
+                 name="pokemon_red_interact",
+                 description="Interacts with the Pokemon Red game by pressing a sequence of buttons.",
+                 arguments=PokemonRedInteractArgs,
+             ),
+             BaseTool(
+                 name="terminate",
+                 description="Terminates the agent's execution if the task is considered complete or no useful progress can be made.",
+                 arguments=TerminateArgs,
+             ),
+         ]
+
+     def _format_history_for_prompt(self) -> str:
+         prompt_history = []
+         for entry in self.history:
+             if entry["type"] == "obs":
+                 prompt_history.append(f"OBSERVATION:\n{entry['content']}")
+             elif entry["type"] == "tool_call":
+                 args_str = json.dumps(entry["tool_arguments"])
+                 prompt_history.append(
+                     f"THOUGHT:\nI will call the tool `{entry['tool_name']}` with arguments: {args_str}\nACTION: (Tool call executed)"
+                 )
+             elif entry["type"] == "tool_response":
+                 prompt_history.append(
+                     "TOOL_RESPONSE:\n(Button pressed, new observation will follow if not terminal)"
+                 )
+         return "\n".join(prompt_history)
+
+     def _get_recent_reasoning_traces(self, k: int = 5) -> str:
+         """Get the reasoning from the last k tool calls to help agent avoid repeating mistakes."""
+         recent_reasoning = []
+         tool_calls = [entry for entry in self.history if entry["type"] == "tool_call"]
+
+         # Get last k tool calls
+         for tool_call in tool_calls[-k:]:
+             if "tool_arguments" in tool_call and "reasoning" in tool_call["tool_arguments"]:
+                 step_num = len(
+                     [
+                         e
+                         for e in self.history[: self.history.index(tool_call) + 1]
+                         if e["type"] == "tool_call"
+                     ]
+                 )
+                 reasoning = tool_call["tool_arguments"]["reasoning"]
+                 buttons = tool_call["tool_arguments"].get("buttons", ["unknown"])
+                 recent_reasoning.append(
+                     f"Step {step_num}: Pressed {buttons} - Reasoning: {reasoning}"
+                 )
+
+         if recent_reasoning:
+             # Add warning if same button pressed many times OR same button sequence repeated
+             if len(recent_reasoning) >= 3:
+                 last_3_buttons = []
+                 last_3_sequences = []
+                 for trace in recent_reasoning[-3:]:
+                     # Extract buttons from trace
+                     if "Pressed ['" in trace:
+                         start = trace.find("Pressed ['") + 10
+                         end = trace.find("']", start)
+                         if end > start:
+                             buttons_str = trace[start:end]
+                             # Handle both single buttons and sequences
+                             if "', '" in buttons_str:
+                                 buttons = buttons_str.split("', '")
+                             else:
+                                 buttons = [buttons_str]
+                             last_3_buttons.append(buttons[0] if buttons else "unknown")
+                             last_3_sequences.append(str(buttons))
+
+                 # Check for repeated single button
+                 if len(set(last_3_buttons)) == 1 and len(last_3_buttons) >= 3:
+                     warning = f"\n⚠️ WARNING: You've pressed '{last_3_buttons[0]}' button {len(last_3_buttons)} times in a row! This button may not be working for the current situation. Try a different approach like pressing 'B' to cancel, or movement buttons to navigate away.\n"
+                     return (
+                         "RECENT REASONING HISTORY:\n" + "\n".join(recent_reasoning) + warning + "\n"
+                     )
+
+                 # Check for repeated button sequences
+                 if len(set(last_3_sequences)) == 1 and len(last_3_sequences) >= 3:
+                     warning = f"\n⚠️ WARNING: You've used the same button sequence {last_3_sequences[0]} {len(last_3_sequences)} times in a row! This sequence may not be working. Try a completely different approach like 'B' to cancel or different movement directions.\n"
+                     return (
+                         "RECENT REASONING HISTORY:\n" + "\n".join(recent_reasoning) + warning + "\n"
+                     )
+
+             return "RECENT REASONING HISTORY:\n" + "\n".join(recent_reasoning) + "\n\n"
+         return ""
+
+     @trace_event_async(event_type="react_agent_decide")
+     async def decide(
+         self,
+         obs_str: str,
+         current_raw_obs: Dict[str, Any],
+         mode: Literal["state_and_ascii", "state_and_screen"] = "state_and_screen",
+     ) -> List[str]:
+         print(f"[AGENT_DEBUG] Starting decide with obs: {obs_str[:100]}...")
+         self.history.append({"type": "obs", "content": obs_str})
+         self.last_obs_dict = current_raw_obs
+
+         # Update current badge count from the raw observation
+         if current_raw_obs and isinstance(current_raw_obs.get("public"), PokemonRedPublicState):
+             pub_state: PokemonRedPublicState = current_raw_obs["public"]
+             self.current_badges = pub_state.badges
+
+         print(f"[AGENT_DEBUG] History length: {len(self.history)}")
+
+         # Extract current step count for cache busting
+         current_step_count = 0
+         if current_raw_obs and isinstance(current_raw_obs.get("public"), PokemonRedPublicState):
+             pub_state: PokemonRedPublicState = current_raw_obs["public"]
+             current_step_count = pub_state.step_count
+
+         # Extract screen buffer for vision only in screen mode
+         screen_images_bytes = []
+         if mode == "state_and_screen":
+             try:
+                 # Get screen buffer directly from the observation
+                 if (
+                     current_raw_obs
+                     and "screen_buffer" in current_raw_obs
+                     and current_raw_obs["screen_buffer"] is not None
+                 ):
+                     screen_buffer = current_raw_obs["screen_buffer"]
+                     print(f"[AGENT_DEBUG] Got screen buffer with shape: {screen_buffer.shape}")
+
+                     # Convert screen buffer to base64 image
+                     import base64
+                     import io
+                     from PIL import Image
+                     import numpy as np
+
+                     # Ensure the array is in the right format (0-255 uint8)
+                     if screen_buffer.dtype != np.uint8:
+                         if screen_buffer.max() <= 1.0:
+                             screen_array = (screen_buffer * 255).astype(np.uint8)
+                         else:
+                             screen_array = screen_buffer.astype(np.uint8)
+                     else:
+                         screen_array = screen_buffer
+
+                     # PyBoy screen format is (144, 160, 4) RGBA
+                     if len(screen_array.shape) == 3 and screen_array.shape[2] == 4:  # RGBA
+                         # Convert RGBA to RGB by dropping alpha channel
+                         image = Image.fromarray(screen_array[:, :, :3], mode="RGB")
+                     else:
+                         raise ValueError(f"Unsupported screen array shape: {screen_array.shape}")
+
+                     # DEBUG: Save the image to debug directory
+                     debug_dir = Path(__file__).parent / "debug"
+                     debug_dir.mkdir(exist_ok=True)
+                     debug_filename = (
+                         f"step_{current_step_count:04d}_agent_{self.system_instance_id[-8:]}.png"
+                     )
+                     debug_path = debug_dir / debug_filename
+                     image.save(debug_path)
+                     print(f"[DEBUG] Saved screen image to: {debug_path}")
+
+                     # Convert to base64
+                     buffer = io.BytesIO()
+                     image.save(buffer, format="PNG")
+                     buffer.seek(0)
+                     base64_image = base64.b64encode(buffer.getvalue()).decode("utf-8")
+                     screen_images_bytes = [base64_image]
+                     print("[AGENT_DEBUG] Successfully converted screen to base64 image")
+                 else:
+                     print("[AGENT_DEBUG] No screen buffer available in observation")
+
+             except Exception as e:
+                 print(f"[AGENT_DEBUG] Failed to extract screen buffer: {e}")
+                 # Continue without screen - the text observation should still work
+
+         # Create appropriate prompt based on mode
+         if mode == "state_and_ascii":
+             prompt = (
+                 f"{self._get_recent_reasoning_traces(k=5)}"
+                 f"CURRENT OBSERVATION:\n{obs_str}\n\n"
+                 "Based on the game state text and ASCII representation above, "
+                 "what is your reasoning and which tool (`pokemon_red_interact` or `terminate`) should you call next? "
+                 "The ASCII representation shows the visual layout of the screen. "
+                 "Look at your recent reasoning history to avoid repeating the same ineffective actions. "
+                 "Focus on making progress: collect badges, heal when HP is low, explore new areas, and interact with the world.\n"
+                 f"[Turn: {current_step_count}]"
+             )
+         else:  # state_and_screen
+             prompt = (
+                 f"{self._get_recent_reasoning_traces(k=5)}"
+                 f"CURRENT OBSERVATION:\n{obs_str}\n\n"
+                 "Based on the game state text above AND the game screen image (if provided), "
+                 "what is your reasoning and which tool (`pokemon_red_interact` or `terminate`) should you call next? "
+                 "Look at both the text information and the visual screen to understand what's happening in the game. "
+                 "Look at your recent reasoning history to avoid repeating the same ineffective actions. "
+                 "Focus on making progress: collect badges, heal when HP is low, explore new areas, and interact with the world.\n"
+                 f"[Turn: {current_step_count}]"
+             )
+
+         system_message = (
+             "You are an agent playing Pokemon Red. You receive structured game state information "
+             "and can execute button sequences to interact with the game. "
+             "Your goal is to progress through the game by collecting badges, training Pokemon, and exploring.\n\n"
+             "GAME STATE INFORMATION:\n"
+             "You receive detailed information about:\n"
+             "• World Location: Current map ID and position coordinates\n"
+             "• Player Progress: Badge count and money\n"
+             "• Pokemon Party: Each Pokemon's species, level, HP, and XP\n"
+             "• Inventory: Items with quantities\n"
+             "• Game System State: Raw system flags and states\n"
+         )
+
+         if mode == "state_and_ascii":
+             system_message += (
+                 "• Visual Screen Analysis: ASCII representation and entity detection\n\n"
+             )
+         else:
+             system_message += (
+                 "• Visual Screen Analysis: ASCII representation and actual screen images\n\n"
+             )
+
+         system_message += (
+             "AVAILABLE ACTIONS:\n"
+             "You can execute sequences of 1-5 buttons. Use as many button presses as are appropriate - sometimes 1 or 2, occasionally 3-5:\n"
+             f"• Available buttons: {', '.join(self.valid_buttons)}\n"
+             "• Examples: ['A'], ['UP', 'RIGHT'], ['START', 'DOWN', 'A']\n\n"
+             "IMPORTANT GUIDANCE:\n"
+             "• If 'Text Box Active: True' and A button isn't working, try B to cancel or navigate away\n"
+             "• If you're repeating the same button many times without progress, try a different approach\n"
+             "• When stuck, try movement buttons (UP, DOWN, LEFT, RIGHT) to explore or navigate menus\n"
+             "• B button often cancels menus or text boxes when A doesn't work\n"
+             "• Look at your recent reasoning history to avoid ineffective repeated actions\n"
+             "• Use shorter button sequences (1-3 buttons) rather than long sequences\n"
+             "• If the same action doesn't work after 2-3 tries, try something completely different\n\n"
+             "TOOLS AVAILABLE:\n"
+             "• pokemon_red_interact: Execute button sequences\n"
+             "• terminate: End the session\n\n"
+             "Make decisions based on the game state information provided. "
+             "Always provide reasoning that references the specific state information."
+         )
+
+         print("=" * 80)
+         print("[AI_INPUT] SYSTEM MESSAGE:")
+         print(system_message)
+         print("-" * 40)
+         print("[AI_INPUT] USER MESSAGE:")
+         print(prompt)
+         print("-" * 40)
+         print("[AI_INPUT] TOOLS:")
+         print(json.dumps([tool.to_openai_tool() for tool in self.tools], indent=2))
+         print("-" * 40)
+         print(f"[AI_INPUT] IMAGES: {len(screen_images_bytes)} image(s) provided")
+         print("=" * 80)
+
+         print(
+             f"[AGENT_DEBUG] Calling LLM with prompt length: {len(prompt)}, images: {len(screen_images_bytes)}"
+         )
+         response_obj = await self.llm.respond_async(
+             system_message=system_message,
+             user_message=prompt,
+             tools=self.tools,
+             images_as_bytes=screen_images_bytes,
+         )
+         print("[AGENT_DEBUG] LLM response received")
+
+         print("=" * 80)
+         print("[AI_OUTPUT] RESPONSE OBJECT:")
+         print(f"Response type: {type(response_obj)}")
+         print(f"Response content: {response_obj}")
+         if hasattr(response_obj, "tool_calls"):
+             print(f"Tool calls: {response_obj.tool_calls}")
+         if hasattr(response_obj, "content"):
+             print(f"Content: {response_obj.content}")
+         print("=" * 80)
+
+         assert response_obj.tool_calls, "Response object didn't have tool call"
+         tool_calls = None
+
+         try:
+             if hasattr(response_obj, "tool_calls") and response_obj.tool_calls:
+                 tool_calls = response_obj.tool_calls
+                 print(f"[AGENT_DEBUG] Found {len(tool_calls)} tool calls")
+
+             if not tool_calls:
+                 print("[AGENT_DEBUG] No tool calls found, falling back to A")
+                 self.history.append(
+                     {
+                         "type": "tool_call",
+                         "tool_name": "pokemon_red_interact",
+                         "tool_arguments": {
+                             "buttons": ["A"],
+                             "reasoning": "LLM failed to provide tool_calls, fallback to A button.",
+                         },
+                     }
+                 )
+                 return ["A"]
+
+             if len(tool_calls) == 0:
+                 print("[AGENT_DEBUG] Empty tool calls list, falling back to A")
+                 self.history.append(
+                     {"type": "error", "content": "LLM returned empty tool_calls list."}
+                 )
+                 return ["A"]
+
+             tool_call_data = tool_calls[0]
+             tool_name = ""
+             tool_args_str = ""
+
+             if (
+                 hasattr(tool_call_data, "function")
+                 and hasattr(tool_call_data.function, "name")
+                 and hasattr(tool_call_data.function, "arguments")
+             ):
+                 tool_name = tool_call_data.function.name
+                 tool_args_str = tool_call_data.function.arguments
+             elif (
+                 isinstance(tool_call_data, dict)
+                 and "function" in tool_call_data
+                 and isinstance(tool_call_data["function"], dict)
+             ):
+                 tool_name = tool_call_data["function"].get("name")
+                 tool_args_str = tool_call_data["function"].get("arguments")
+                 if not isinstance(tool_args_str, str):
+                     tool_arguments_dict = tool_args_str
+                     tool_args_str = json.dumps(tool_arguments_dict)
+                 else:
+                     tool_arguments_dict = json.loads(tool_args_str)
+             else:
+                 print("[AGENT_DEBUG] Unexpected tool_call structure, falling back to A")
+                 self.history.append({"type": "error", "content": "Unexpected tool_call structure."})
+                 return ["A"]
+
+             print(f"[AGENT_DEBUG] Tool name: {tool_name}, Args: {tool_args_str}")
+
+             if not tool_args_str:
+                 print(f"[AGENT_DEBUG] Missing arguments for tool {tool_name}, falling back to A")
+                 self.history.append(
+                     {
+                         "type": "error",
+                         "content": f"Missing arguments for tool {tool_name}. Args string: '{tool_args_str}'",
+                     }
+                 )
+                 return ["A"]
+
+             tool_arguments = json.loads(tool_args_str)
+
+             self.history.append(
+                 {
+                     "type": "tool_call",
+                     "tool_name": tool_name,
+                     "tool_arguments": tool_arguments,
+                 }
+             )
+
+             if tool_name == "pokemon_red_interact":
+                 print("[AGENT_DEBUG] Processing pokemon_red_interact tool call")
+                 validated_args = PokemonRedInteractArgs(**tool_arguments)
+                 buttons = validated_args.buttons
+                 print(
+                     f"[AGENT_DEBUG] Buttons: {buttons}, Valid: {[button in self.valid_buttons for button in buttons]}"
+                 )
+
+                 invalid_buttons = [button for button in buttons if button not in self.valid_buttons]
+                 if invalid_buttons:
+                     print(f"[AGENT_DEBUG] Invalid buttons: {invalid_buttons}, falling back to A")
+                     self.history.append(
+                         {
+                             "type": "error",
+                             "content": f"Invalid buttons: {invalid_buttons}. Falling back to A.",
+                         }
+                     )
+                     return ["A"]
+                 print(f"[AGENT_DEBUG] Returning buttons: {buttons}")
+                 return buttons
+
+             elif tool_name == "terminate":
+                 print("[AGENT_DEBUG] Processing terminate tool call")
+                 # Allow termination if agent decides
+                 print("[AGENT_DEBUG] Agent decided to terminate, returning TERMINATE")
+                 return ["TERMINATE"]
+
+             else:
+                 print(f"[AGENT_DEBUG] Unknown tool_name: {tool_name}, falling back to A")
+                 self.history.append({"type": "error", "content": f"Unknown tool_name: {tool_name}"})
+                 return ["A"]
+
+         except Exception as e:
+             error_content = (
+                 f"Error processing LLM response: {str(e)}. Response: {str(response_obj)[:500]}"
+             )
+             print(f"[AGENT_DEBUG] Exception in decide: {error_content}")
+             self.history.append({"type": "error", "content": error_content})
+             return ["A"]
+
+
947
+# --- Test for a single agent run ---
+@pytest.mark.asyncio
+async def test_react_agent_pokemon_red(
+    tmp_path: Path,
+    mode: Literal["state_and_ascii", "state_and_screen"] = "state_and_screen",
+):
+    # Create a simple Pokemon Red task instance for testing
+    task_metadata = TaskInstanceMetadata()
+    inst = PokemonRedTaskInstance(
+        id=uuid.uuid4(),
+        impetus=Impetus(instructions="Start your Pokemon journey and collect badges."),
+        intent=Intent(
+            rubric={"goal": "Collect badges and progress"},
+            gold_trajectories=None,
+            gold_state_diff={},
+        ),
+        metadata=task_metadata,
+        is_reproducible=True,
+        initial_engine_snapshot=None,
+    )
+
+    hist_cb = PokemonRedHistoryObservationCallable(max_history=1, mode=mode)
+    env = PokemonRedEnvironment(inst, custom_step_obs=hist_cb)
+
+    llm = LM(model_name="gpt-4.1-nano", formatting_model_name="gpt-4.1-nano", temperature=0.0)
+    agent = ReActAgent(llm, max_turns=30)
+
+    async def run_episode():
+        obs_payload = await env.initialize()
+
+        if "error" in obs_payload:
+            print(f"Error during env.initialize: {obs_payload['error']}")
+            return False, 0
+
+        current_formatted_obs = obs_payload["formatted_obs"]
+        raw_obs_for_agent_decision = obs_payload
+
+        for turn in range(agent.max_turns):
+            buttons = await agent.decide(current_formatted_obs, raw_obs_for_agent_decision, mode)
+
+            if "TERMINATE" in buttons:
+                obs_payload_next = obs_payload
+                break
+
+            # Execute the button sequence one press at a time
+            for i, button in enumerate(buttons):
+                print(f"[DEBUG] Executing button {i + 1}/{len(buttons)}: {button}")
+                obs_payload_next = await env.step([[PressButtonCall(button)]])
+
+                if "error" in obs_payload_next:
+                    raise RuntimeError(
+                        f"Environment step error on button {i + 1}: {obs_payload_next['error']}"
+                    )
+
+                # Update observation after each button press
+                obs_payload = obs_payload_next
+
+                # Check if environment terminated after this button
+                if obs_payload["private"].terminated or obs_payload["private"].truncated:
+                    print(
+                        f"[DEBUG] Environment terminated/truncated after button {i + 1}/{len(buttons)}"
+                    )
+                    break
+
+            # Refresh the agent's view for the next turn (otherwise it would keep
+            # replanning against the initial observation) and stop if the episode ended
+            current_formatted_obs = obs_payload["formatted_obs"]
+            raw_obs_for_agent_decision = obs_payload
+            if obs_payload["private"].terminated or obs_payload["private"].truncated:
+                break
+
+        if "obs_payload_next" not in locals():
+            obs_payload_next = obs_payload
+
+        if "error" in obs_payload_next:
+            return False, agent.current_badges
+
+        final_private_state: PokemonRedPrivateState = obs_payload_next["private"]
+        episode_successful = final_private_state.terminated or final_private_state.truncated
+        return episode_successful, agent.current_badges
+
+    episode_completed, badges_collected = await run_episode()
+
+    dataset = Dataset(
+        questions=[
+            TrainingQuestion(
+                id="pokemon_red_ep_test",
+                intent="progress_in_game",
+                criteria="completed_episode_or_collected_badges",
+            )
+        ],
+        reward_signals=[
+            RewardSignal(
+                question_id="pokemon_red_ep_test",
+                run_id=agent.system_instance_id,
+                system_instance_id=agent.system_instance_id,
+                reward=1 if episode_completed or badges_collected > 0 else 0,
+                error_message="" if episode_completed else "Episode not completed as expected.",
+                metadata={
+                    "agent_history": agent.history,
+                    "badges_collected": badges_collected,
+                    "total_reward_earned": hist_cb.reward_manager.total_reward_earned,
+                    "reward_history": hist_cb.reward_manager.reward_history,
+                },
+            )
+        ],
+    )
+    # upload(dataset=dataset)  # Optional: uncomment to upload trace
+
+    assert episode_completed or badges_collected > 0, (
+        "Agent failed to complete the episode or collect any badges in the test."
+    )
+
+
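Given the `@pytest.mark.asyncio` marker, the single-episode test above presumably runs under pytest with `pytest-asyncio` installed; one way to invoke just that test programmatically (a sketch, assuming a standard pytest setup):

    # Run only the single-episode test from this module, with output unbuffered.
    # Assumes pytest and pytest-asyncio are available in the environment.
    import pytest

    raise SystemExit(pytest.main(["-s", "-k", "test_react_agent_pokemon_red"]))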
+async def eval_react_pokemon_red(
+    model_name: str = "gpt-4o-mini",
+    max_turns: int = 20,
+    mode: Literal["state_and_ascii", "state_and_screen"] = "state_and_screen",
+) -> None:
+    """
+    Run ReAct agents on Pokemon Red instances of different difficulties,
+    and print aggregated success rates and average badges collected.
+    """
+    from tabulate import tabulate
+
+    current_model_name_for_eval = model_name
+
+    _temp_llm_for_names = LM(
+        model_name=current_model_name_for_eval,
+        formatting_model_name=current_model_name_for_eval,
+        temperature=0.0,
+    )
+    _temp_agent_for_names = ReActAgent(_temp_llm_for_names)
+    actual_system_name = _temp_agent_for_names.system_name
+
+    # ------------------------------------------------------------------ helpers
+    async def run_episode_eval(
+        inst: PokemonRedTaskInstance, agent_max_turns: int
+    ) -> tuple[bool, int, float, list]:
+        """Run a single agent/instance episode and return (success_status, badges_collected, total_rewards, reward_history)."""
+        print(f"[DEBUG] Starting episode for instance {inst.id}")
+        hist_cb = PokemonRedHistoryObservationCallable(max_history=1, mode=mode)
+        env = PokemonRedEnvironment(inst, custom_step_obs=hist_cb)
+
+        llm_for_episode = LM(
+            model_name=current_model_name_for_eval,
+            formatting_model_name=current_model_name_for_eval,
+            temperature=0.0,
+        )
+        agent = ReActAgent(llm_for_episode, max_turns=agent_max_turns)
+        print(f"[DEBUG] Created agent with max_turns={agent_max_turns}")
+
+        print("[DEBUG] Initializing environment...")
+        obs_payload = await env.initialize()
+        print(
+            f"[DEBUG] Environment initialized. Obs keys: {list(obs_payload.keys()) if isinstance(obs_payload, dict) else type(obs_payload)}"
+        )
+        if "error" in obs_payload:
+            raise RuntimeError(f"Environment initialization failed: {obs_payload['error']}")
+
+        current_formatted_obs = obs_payload["formatted_obs"]
+        raw_obs_for_agent_decision = obs_payload
+        print(f"[DEBUG] Initial formatted obs: {current_formatted_obs[:200]}...")
+
+        # Track state changes to detect whether the agent is stuck
+        last_position = None
+        last_map_id = None
+        stuck_count = 0
+        same_button_count = 0
+        last_button = None
+
+        turn_count = 0
+        for turn_idx in range(agent.max_turns):
+            turn_count += 1
+            print(f"[DEBUG] === Turn {turn_idx + 1}/{agent.max_turns} ===")
+            print(f"[DEBUG] Agent deciding on obs: {current_formatted_obs[:100]}...")
+
+            buttons = await agent.decide(current_formatted_obs, raw_obs_for_agent_decision, mode)
+            print(f"[DEBUG] Agent decided buttons: {buttons}")
+
+            # Check for repeated button presses
+            if buttons[0] == last_button:
+                same_button_count += 1
+                # Increased tolerance since the engine now handles retries automatically,
+                # and some game states may legitimately require the same button multiple times
+                if same_button_count >= 8:  # Increased from 4 to 8
+                    print(
+                        f"[WARNING] Agent pressed same button '{buttons[0]}' {same_button_count} times in a row"
+                    )
+                    # Don't hard fail anymore - let the engine's retry mechanism handle it
+                    # raise RuntimeError(f"Agent pressing same button '{buttons[0]}' {same_button_count} times in a row - HARD FAIL")
+            else:
+                same_button_count = 1
+            last_button = buttons[0]
+
+            if "TERMINATE" in buttons:
+                print(f"[DEBUG] Agent decided to terminate after {turn_count} turns")
+                break
+
+            print(f"[DEBUG] Stepping environment with buttons {buttons}")
+
+            try:
+                # Execute the button sequence one press at a time
+                for i, button in enumerate(buttons):
+                    print(f"[DEBUG] Executing button {i + 1}/{len(buttons)}: {button}")
+                    obs_payload_next = await env.step([[PressButtonCall(button)]])
+
+                    if "error" in obs_payload_next:
+                        raise RuntimeError(
+                            f"Environment step error on button {i + 1}: {obs_payload_next['error']}"
+                        )
+
+                    # Update observation after each button press
+                    obs_payload = obs_payload_next
+
+                    # Check if environment terminated after this button
+                    if obs_payload["private"].terminated or obs_payload["private"].truncated:
+                        print(
+                            f"[DEBUG] Environment terminated/truncated after button {i + 1}/{len(buttons)}"
+                        )
+                        break
+            except RuntimeError as e:
+                if "HARD FAIL" in str(e):
+                    raise  # Re-raise hard failures immediately
+                raise RuntimeError(f"Environment step failed: {e}")
+
+            print(
+                f"[DEBUG] Environment step completed. Obs keys: {list(obs_payload.keys()) if isinstance(obs_payload, dict) else type(obs_payload)}"
+            )
+
+            if "error" in obs_payload:
+                raise RuntimeError(f"Environment step error: {obs_payload['error']}")
+
+            # Check if state is changing meaningfully using screen buffer hashes
+            screen_changed = True
+            if obs_payload.get("screen_buffer") is not None:
+                import hashlib
+
+                current_screen_hash = hashlib.md5(
+                    obs_payload["screen_buffer"].tobytes()
+                ).hexdigest()
+                # NOTE: function attributes persist across calls, so concurrently
+                # gathered episodes share these counters.
+                if not hasattr(run_episode_eval, "last_screen_hash"):
+                    run_episode_eval.last_screen_hash = None
+                    run_episode_eval.same_screen_count = 0
+
+                if run_episode_eval.last_screen_hash == current_screen_hash:
+                    run_episode_eval.same_screen_count += 1
+                    screen_changed = False
+                else:
+                    run_episode_eval.same_screen_count = 0
+                    screen_changed = True
+
+                run_episode_eval.last_screen_hash = current_screen_hash
+                print(
+                    f"[DEBUG] Screen hash: {current_screen_hash[:8]}..., Same count: {run_episode_eval.same_screen_count}, Changed: {screen_changed}"
+                )
+
+            # More intelligent failure detection for Pokemon Red.
+            # Based on investigation: menu_state=1 is the normal overworld state, not a stuck condition,
+            # and the B button doing nothing is often expected (no menu to close).
+            button_tolerance = {
+                "B": 15,  # B often does nothing in overworld - very lenient
+                "A": 10,  # A for interactions/dialogue - moderately lenient
+                "START": 8,  # START for menu opening - moderate
+                "SELECT": 8,  # SELECT for menu navigation - moderate
+                "UP": 5,  # Movement buttons - less lenient
+                "DOWN": 5,
+                "LEFT": 5,
+                "RIGHT": 5,
+            }
+
+            max_same_button = button_tolerance.get(
+                buttons[0], 5
+            )  # Default to 5 for unknown buttons
+            min_screen_unchanged = 12  # Increased - Pokemon Red often has static screens
+            min_turn_threshold = 10  # Increased - allow more exploration time
+
+            # Only fail if BOTH conditions are met:
+            # 1. Screen hasn't changed for many turns (visual stuckness)
+            # 2. Agent is repeating ineffective actions beyond reasonable tolerance
+            if (
+                run_episode_eval.same_screen_count >= min_screen_unchanged
+                and turn_idx > min_turn_threshold
+                and same_button_count >= max_same_button
+            ):
+                # Additional check: don't fail on B button if menu_state indicates normal overworld
+                if buttons[0] == "B":
+                    # B button in overworld is often ineffective but not necessarily wrong;
+                    # just be more lenient with the B button in general
+                    if same_button_count < 20:  # Much more lenient for B button
+                        print(
+                            f"[DEBUG] B button often ineffective in overworld - allowing more attempts ({same_button_count}/20)"
+                        )
+                        # Continue without failing
+                        obs_payload = obs_payload_next
+                        continue
+
+                print(
+                    f"[WARNING] Agent appears stuck - screen unchanged for {run_episode_eval.same_screen_count} turns with repeated button '{buttons[0]}' {same_button_count} times"
+                )
+                print(
+                    f"[WARNING] Button tolerance for '{buttons[0]}': {max_same_button}, screen unchanged threshold: {min_screen_unchanged}"
+                )
+                raise RuntimeError(
+                    f"Agent stuck - screen unchanged for {run_episode_eval.same_screen_count} turns with repeated button '{buttons[0]}' ({same_button_count} times, tolerance: {max_same_button}) - HARD FAIL"
+                )
+
+            # Legacy position-based detection (kept as a fallback, but more lenient)
+            current_pub = obs_payload["public"]
+            current_position = (current_pub.player_x, current_pub.player_y)
+            current_map_id = current_pub.map_id
+
+            # Only check position-based stuckness if the screen is also not changing
+            if (
+                last_position == current_position
+                and last_map_id == current_map_id
+                and not screen_changed
+                and turn_idx > 8
+            ):  # Much more lenient - allow many turns for dialogue
+                stuck_count += 1
+                if stuck_count >= 8:  # Require many more turns of true stuck state
+                    raise RuntimeError(
+                        f"Agent truly stuck - no position or screen changes for {stuck_count} turns. Position: {current_position}, Map: {current_map_id} - HARD FAIL"
+                    )
+            else:
+                stuck_count = 0
+
+            last_position = current_position
+            last_map_id = current_map_id
+
+            current_formatted_obs = obs_payload["formatted_obs"]
+            raw_obs_for_agent_decision = obs_payload
+
+            agent.history.append(
+                {
+                    "type": "tool_response",
+                    "content": f"Button sequence executed: {buttons}",
+                }
+            )
+
+            print(f"[DEBUG] New formatted obs: {current_formatted_obs[:100]}...")
+
+            if obs_payload["private"].terminated or obs_payload["private"].truncated:
+                print(f"[DEBUG] Environment terminated/truncated after {turn_count} turns")
+                print(
+                    f"[DEBUG] Terminated: {obs_payload['private'].terminated}, Truncated: {obs_payload['private'].truncated}"
+                )
+                break
+
+        print(f"[DEBUG] Episode completed after {turn_count} turns")
+        final_private_state: PokemonRedPrivateState = obs_payload["private"]
+        run_successful = final_private_state.terminated or final_private_state.truncated
+        badges_collected = agent.current_badges
+        total_rewards = hist_cb.reward_manager.total_reward_earned
+        print(
+            f"[DEBUG] Episode result - successful: {run_successful}, badges: {badges_collected}, rewards: {total_rewards:.1f}"
+        )
+        print(
+            f"[DEBUG] Final private state - terminated: {final_private_state.terminated}, truncated: {final_private_state.truncated}"
+        )
+        print(f"[DEBUG] Total reward: {final_private_state.total_reward}")
+        return (
+            run_successful,
+            badges_collected,
+            total_rewards,
+            hist_cb.reward_manager.reward_history,
+        )
+
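One caveat worth noting about the loop above: the screen-hash counters are stored as attributes on `run_episode_eval` itself, so concurrently gathered episodes share them. A per-episode tracker along these lines would isolate that state (a sketch, not part of the package; assumes a numpy-like screen buffer exposing `.tobytes()`):

    import hashlib

    class ScreenChangeTracker:
        """Per-episode duplicate-frame counter keyed on an MD5 of the raw screen buffer."""

        def __init__(self) -> None:
            self.last_hash: str | None = None
            self.same_count = 0

        def observe(self, screen_buffer) -> bool:
            """Record one frame; return True if it differs from the previous frame."""
            digest = hashlib.md5(screen_buffer.tobytes()).hexdigest()
            changed = digest != self.last_hash
            self.same_count = 0 if changed else self.same_count + 1
            self.last_hash = digest
            return changed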
+    # ---------------------------------------------------------------- instance factory
+    async def make_pokemon_red_instances(
+        difficulty: str, n_instances: int = 3, start_seed: int = 0
+    ) -> List[PokemonRedTaskInstance]:
+        instances = []
+
+        for i in range(n_instances):
+            current_seed = start_seed + i  # NOTE: computed but not yet consumed by the instance
+            metadata = TaskInstanceMetadata()
+            instance = PokemonRedTaskInstance(
+                id=uuid.uuid4(),
+                impetus=Impetus(
+                    instructions=f"Play Pokemon Red on {difficulty} difficulty and collect badges."
+                ),
+                intent=Intent(rubric={}, gold_trajectories=None, gold_state_diff={}),
+                metadata=metadata,
+                is_reproducible=True,
+                initial_engine_snapshot=None,
+            )
+            instances.append(instance)
+        return instances
+
+    # ---------------------------------------------------------------- evaluation
+    configs = [
+        (
+            "easy",
+            1,
+            max_turns,
+        ),  # (difficulty_label, num_agents/instances, max_turns_per_episode) - uses the function parameter
+    ]
+    table_rows = []
+    base_seed_for_difficulty = {"easy": 1000, "hard": 2000}
+
+    print("Starting Pokemon Red ReAct Agent Evaluation...")
+    print(f"Model: {current_model_name_for_eval}, System: {actual_system_name}")
+
+    all_generated_task_data = []
+    all_reward_achievements = {}  # Track all rewards across all runs
+
+    print("\nGenerating task instances...")
+    all_tasks_for_eval: Dict[str, List[PokemonRedTaskInstance]] = {}
+    for label, num_agents, _ in configs:
+        insts = await make_pokemon_red_instances(
+            label, n_instances=num_agents, start_seed=base_seed_for_difficulty[label]
+        )
+        all_tasks_for_eval[label] = insts
+        for inst in insts:
+            instance_dict = await inst.serialize()
+            all_generated_task_data.append(instance_dict)
+        print(f"Generated {len(insts)} instances for {label} difficulty.")
+
+    # Save all generated task data to a single JSON file
+    dataset_dir = Path(__file__).parent.parent / "dataset"
+    dataset_dir.mkdir(parents=True, exist_ok=True)
+    synthetic_mix_path = dataset_dir / "synthetic_mix.json"
+    with open(synthetic_mix_path, "w") as f:
+        json.dump(all_generated_task_data, f, indent=2)
+    print(
+        f"Saved all {len(all_generated_task_data)} generated task instances to {synthetic_mix_path}"
+    )
+
+    # Now run the evaluations using the generated tasks
+    for label, num_agents, max_episode_turns in configs:
+        print(
+            f"\nRunning {num_agents} agents on {label} difficulty tasks (max_turns: {max_episode_turns})..."
+        )
+        current_difficulty_instances = all_tasks_for_eval[label]
+        print(f"[DEBUG] About to run {len(current_difficulty_instances)} instances")
+
+        import time
+
+        start_time = time.time()
+        print(
+            f"[DEBUG] Starting asyncio.gather for {len(current_difficulty_instances)} episodes at {start_time}"
+        )
+        results = await asyncio.gather(
+            *(run_episode_eval(inst, max_episode_turns) for inst in current_difficulty_instances)
+        )
+        end_time = time.time()
+        print(f"[DEBUG] Completed asyncio.gather in {end_time - start_time:.2f} seconds")
+        print(f"[DEBUG] Results: {results}")
+
+        num_successful_runs = sum(1 for r_success, _, _, _ in results if r_success)
+        total_badges = sum(r_badges for _, r_badges, _, _ in results)
+        total_rewards = sum(r_rewards for _, _, r_rewards, _ in results)
+        avg_badges = total_badges / len(results) if results else 0.0
+        avg_rewards = total_rewards / len(results) if results else 0.0
+
+        # Collect reward data for the summary; run_episode_eval already returns
+        # each episode's reward history, so just index it by instance
+        reward_counts = {}
+        for inst_idx, (_, _, _, reward_history) in enumerate(results):
+            reward_counts[inst_idx] = reward_history
+
+        # Aggregate rewards across all instances for this difficulty
+        for inst_idx, reward_history in reward_counts.items():
+            for achievement in reward_history:
+                for component in achievement["components"]:
+                    component_name = component["component"]
+                    if component_name not in all_reward_achievements:
+                        all_reward_achievements[component_name] = 0
+                    all_reward_achievements[component_name] += 1
+
+        table_rows.append(
+            [
+                label,
+                f"{num_successful_runs}/{len(current_difficulty_instances)}",
+                f"{avg_badges:.2f}",
+                f"{avg_rewards:.1f}",
+            ]
+        )
+        print(
+            f"Completed {label}: {num_successful_runs}/{len(current_difficulty_instances)} successful, Avg. Badges: {avg_badges:.2f}, Avg. Rewards: {avg_rewards:.1f}"
+        )
+
+    print("\n--- Evaluation Summary ---")
+    print(f"Model: {current_model_name_for_eval}, System: {actual_system_name}")
+    print(
+        tabulate(
+            table_rows,
+            headers=[
+                "Difficulty",
+                "Successful Runs",
+                "Avg Badges Collected",
+                "Avg Rewards Earned",
+            ],
+            tablefmt="github",
+        )
+    )
+
+    # Display reward achievements summary
+    if all_reward_achievements:
+        print("\n--- Reward Achievements Summary ---")
+        reward_summary_rows = []
+        for reward_name, count in sorted(
+            all_reward_achievements.items(), key=lambda x: x[1], reverse=True
+        ):
+            reward_summary_rows.append([reward_name, count])
+
+        print(
+            tabulate(
+                reward_summary_rows,
+                headers=["Reward Component", "Times Achieved"],
+                tablefmt="github",
+            )
+        )
+        print(f"\nTotal Unique Rewards Achieved: {len(all_reward_achievements)}")
+        print(f"Total Reward Instances: {sum(all_reward_achievements.values())}")
+    else:
+        print("\n--- No Rewards Achieved ---")
+
+
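The tally above implies each `reward_history` entry is shaped like `{"components": [{"component": <name>, ...}, ...]}`; that shape is inferred from the lookups, not documented here. Under that assumption, `collections.Counter` expresses the same aggregation more compactly (a sketch only):

    from collections import Counter

    def tally_reward_components(reward_histories: list[list[dict]]) -> Counter:
        """Count how often each reward component fired across all episodes.

        Assumes entries shaped like {"components": [{"component": "<name>", ...}]}.
        """
        counts: Counter = Counter()
        for history in reward_histories:
            for achievement in history:
                counts.update(c["component"] for c in achievement["components"])
        return counts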
+if __name__ == "__main__":
+    # To run the test:
+    # import tempfile
+    # with tempfile.TemporaryDirectory() as tmpdir:
+    #     asyncio.run(test_react_agent_pokemon_red(Path(tmpdir)))
+
+    # TODO: better state management
+    # To run the evaluation:
+    asyncio.run(
+        eval_react_pokemon_red(model_name="gpt-4.1-mini", max_turns=10, mode="state_and_screen")
+    )
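The entry point hard-codes a single configuration; since `eval_react_pokemon_red` already takes `model_name`, `max_turns`, and `mode` as parameters, a thin command-line wrapper is straightforward. A hypothetical sketch, not shipped in the package:

    import argparse
    import asyncio

    def main() -> None:
        # Illustrative CLI over eval_react_pokemon_red's existing parameters.
        parser = argparse.ArgumentParser(description="Evaluate a ReAct agent on Pokemon Red")
        parser.add_argument("--model-name", default="gpt-4.1-mini")
        parser.add_argument("--max-turns", type=int, default=10)
        parser.add_argument(
            "--mode", choices=["state_and_ascii", "state_and_screen"], default="state_and_screen"
        )
        args = parser.parse_args()
        asyncio.run(
            eval_react_pokemon_red(
                model_name=args.model_name, max_turns=args.max_turns, mode=args.mode
            )
        )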