synth-ai 0.1.9__py3-none-any.whl → 0.2.1.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (266) hide show
  1. synth_ai/__init__.py +28 -2
  2. synth_ai/core/system.py +4 -0
  3. synth_ai/environments/__init__.py +35 -0
  4. synth_ai/environments/environment/__init__.py +1 -0
  5. synth_ai/environments/environment/artifacts/__init__.py +1 -0
  6. synth_ai/environments/environment/artifacts/base.py +50 -0
  7. synth_ai/environments/environment/core.py +22 -0
  8. synth_ai/environments/environment/db/__init__.py +1 -0
  9. synth_ai/environments/environment/db/sqlite.py +45 -0
  10. synth_ai/environments/environment/registry.py +24 -0
  11. synth_ai/environments/environment/resources/sqlite.py +46 -0
  12. synth_ai/environments/environment/results.py +1 -0
  13. synth_ai/environments/environment/rewards/__init__.py +1 -0
  14. synth_ai/environments/environment/rewards/core.py +28 -0
  15. synth_ai/environments/environment/shared_engine.py +26 -0
  16. synth_ai/environments/environment/tools/__init__.py +34 -0
  17. synth_ai/environments/examples/__init__.py +1 -0
  18. synth_ai/environments/examples/crafter_classic/__init__.py +8 -0
  19. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_comprehensive_evaluation.py +58 -0
  20. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_evaluation_browser.py +152 -0
  21. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_evaluation_framework.py +1194 -0
  22. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_quick_evaluation.py +51 -0
  23. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_react_agent.py +872 -0
  24. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_trace_evaluation.py +1412 -0
  25. synth_ai/environments/examples/crafter_classic/agent_demos/test_crafter_react_agent.py +1110 -0
  26. synth_ai/environments/examples/crafter_classic/config_logging.py +111 -0
  27. synth_ai/environments/examples/crafter_classic/engine.py +502 -0
  28. synth_ai/environments/examples/crafter_classic/engine_deterministic_patch.py +63 -0
  29. synth_ai/environments/examples/crafter_classic/engine_helpers/action_map.py +5 -0
  30. synth_ai/environments/examples/crafter_classic/engine_helpers/serialization.py +74 -0
  31. synth_ai/environments/examples/crafter_classic/environment.py +255 -0
  32. synth_ai/environments/examples/crafter_classic/taskset.py +228 -0
  33. synth_ai/environments/examples/enron/agent_demos/test_synth_react.py +535 -0
  34. synth_ai/environments/examples/enron/art_helpers/email_search_tools.py +156 -0
  35. synth_ai/environments/examples/enron/art_helpers/local_email_db.py +280 -0
  36. synth_ai/environments/examples/enron/art_helpers/types_enron.py +24 -0
  37. synth_ai/environments/examples/enron/engine.py +291 -0
  38. synth_ai/environments/examples/enron/environment.py +165 -0
  39. synth_ai/environments/examples/enron/taskset.py +112 -0
  40. synth_ai/environments/examples/enron/units/keyword_stats.py +111 -0
  41. synth_ai/environments/examples/enron/units/test_email_index.py +8 -0
  42. synth_ai/environments/examples/minigrid/__init__.py +48 -0
  43. synth_ai/environments/examples/minigrid/agent_demos/minigrid_evaluation_framework.py +1188 -0
  44. synth_ai/environments/examples/minigrid/agent_demos/minigrid_quick_evaluation.py +47 -0
  45. synth_ai/environments/examples/minigrid/agent_demos/minigrid_react_agent.py +562 -0
  46. synth_ai/environments/examples/minigrid/agent_demos/minigrid_trace_evaluation.py +220 -0
  47. synth_ai/environments/examples/minigrid/agent_demos/test_minigrid_react_agent.py +393 -0
  48. synth_ai/environments/examples/minigrid/engine.py +589 -0
  49. synth_ai/environments/examples/minigrid/environment.py +274 -0
  50. synth_ai/environments/examples/minigrid/environment_mapping.py +242 -0
  51. synth_ai/environments/examples/minigrid/puzzle_loader.py +416 -0
  52. synth_ai/environments/examples/minigrid/taskset.py +583 -0
  53. synth_ai/environments/examples/minigrid/units/test_action_behavior.py +226 -0
  54. synth_ai/environments/examples/minigrid/units/test_debug_messages.py +83 -0
  55. synth_ai/environments/examples/minigrid/units/test_exploration.py +120 -0
  56. synth_ai/environments/examples/minigrid/units/test_minigrid_engine.py +214 -0
  57. synth_ai/environments/examples/minigrid/units/test_minigrid_environment.py +238 -0
  58. synth_ai/environments/examples/minigrid/units/test_minigrid_environment_mapping.py +301 -0
  59. synth_ai/environments/examples/minigrid/units/test_minigrid_taskset.py +210 -0
  60. synth_ai/environments/examples/nethack/__init__.py +7 -0
  61. synth_ai/environments/examples/nethack/achievements.py +337 -0
  62. synth_ai/environments/examples/nethack/agent_demos/nethack_evaluation_framework.py +981 -0
  63. synth_ai/environments/examples/nethack/agent_demos/nethack_quick_evaluation.py +74 -0
  64. synth_ai/environments/examples/nethack/agent_demos/nethack_react_agent.py +832 -0
  65. synth_ai/environments/examples/nethack/agent_demos/test_nethack_react_agent.py +1112 -0
  66. synth_ai/environments/examples/nethack/engine.py +738 -0
  67. synth_ai/environments/examples/nethack/environment.py +255 -0
  68. synth_ai/environments/examples/nethack/helpers/__init__.py +42 -0
  69. synth_ai/environments/examples/nethack/helpers/action_mapping.py +301 -0
  70. synth_ai/environments/examples/nethack/helpers/nle_wrapper.py +401 -0
  71. synth_ai/environments/examples/nethack/helpers/observation_utils.py +433 -0
  72. synth_ai/environments/examples/nethack/helpers/recording_wrapper.py +201 -0
  73. synth_ai/environments/examples/nethack/helpers/trajectory_recorder.py +268 -0
  74. synth_ai/environments/examples/nethack/helpers/visualization/replay_viewer.py +308 -0
  75. synth_ai/environments/examples/nethack/helpers/visualization/visualizer.py +430 -0
  76. synth_ai/environments/examples/nethack/taskset.py +323 -0
  77. synth_ai/environments/examples/nethack/units/test_nethack_engine.py +277 -0
  78. synth_ai/environments/examples/nethack/units/test_nethack_environment.py +281 -0
  79. synth_ai/environments/examples/nethack/units/test_nethack_taskset.py +213 -0
  80. synth_ai/environments/examples/nethack/units/test_recording.py +307 -0
  81. synth_ai/environments/examples/red/__init__.py +7 -0
  82. synth_ai/environments/examples/red/agent_demos/__init__.py +1 -0
  83. synth_ai/environments/examples/red/agent_demos/test_synth_react.py +1471 -0
  84. synth_ai/environments/examples/red/config_logging.py +110 -0
  85. synth_ai/environments/examples/red/engine.py +693 -0
  86. synth_ai/environments/examples/red/engine_helpers/__init__.py +1 -0
  87. synth_ai/environments/examples/red/engine_helpers/memory_map.py +28 -0
  88. synth_ai/environments/examples/red/engine_helpers/reward_components.py +275 -0
  89. synth_ai/environments/examples/red/engine_helpers/reward_library/__init__.py +142 -0
  90. synth_ai/environments/examples/red/engine_helpers/reward_library/adaptive_rewards.py +56 -0
  91. synth_ai/environments/examples/red/engine_helpers/reward_library/battle_rewards.py +283 -0
  92. synth_ai/environments/examples/red/engine_helpers/reward_library/composite_rewards.py +149 -0
  93. synth_ai/environments/examples/red/engine_helpers/reward_library/economy_rewards.py +137 -0
  94. synth_ai/environments/examples/red/engine_helpers/reward_library/efficiency_rewards.py +56 -0
  95. synth_ai/environments/examples/red/engine_helpers/reward_library/exploration_rewards.py +330 -0
  96. synth_ai/environments/examples/red/engine_helpers/reward_library/novelty_rewards.py +120 -0
  97. synth_ai/environments/examples/red/engine_helpers/reward_library/pallet_town_rewards.py +558 -0
  98. synth_ai/environments/examples/red/engine_helpers/reward_library/pokemon_rewards.py +312 -0
  99. synth_ai/environments/examples/red/engine_helpers/reward_library/social_rewards.py +147 -0
  100. synth_ai/environments/examples/red/engine_helpers/reward_library/story_rewards.py +246 -0
  101. synth_ai/environments/examples/red/engine_helpers/screen_analysis.py +367 -0
  102. synth_ai/environments/examples/red/engine_helpers/state_extraction.py +139 -0
  103. synth_ai/environments/examples/red/environment.py +235 -0
  104. synth_ai/environments/examples/red/taskset.py +77 -0
  105. synth_ai/environments/examples/red/test_fixes.py +125 -0
  106. synth_ai/environments/examples/red/test_fixes_mock.py +148 -0
  107. synth_ai/environments/examples/red/units/__init__.py +1 -0
  108. synth_ai/environments/examples/red/units/test_basic_functionality.py +97 -0
  109. synth_ai/environments/examples/red/units/test_button_press_requirements.py +217 -0
  110. synth_ai/environments/examples/red/units/test_engine.py +192 -0
  111. synth_ai/environments/examples/red/units/test_environment.py +455 -0
  112. synth_ai/environments/examples/red/units/test_exploration_strategy.py +227 -0
  113. synth_ai/environments/examples/red/units/test_integration.py +217 -0
  114. synth_ai/environments/examples/red/units/test_memory_extraction.py +111 -0
  115. synth_ai/environments/examples/red/units/test_menu_bug_reproduction.py +1100 -0
  116. synth_ai/environments/examples/red/units/test_movement_debug.py +255 -0
  117. synth_ai/environments/examples/red/units/test_pokemon_mcts_debug.py +163 -0
  118. synth_ai/environments/examples/red/units/test_pokemon_mcts_verbose.py +117 -0
  119. synth_ai/environments/examples/red/units/test_red_basic.py +145 -0
  120. synth_ai/environments/examples/red/units/test_red_comprehensive.py +323 -0
  121. synth_ai/environments/examples/red/units/test_retry_movement.py +195 -0
  122. synth_ai/environments/examples/red/units/test_reward_components.py +186 -0
  123. synth_ai/environments/examples/red/units/test_rom_integration.py +260 -0
  124. synth_ai/environments/examples/red/units/test_taskset.py +116 -0
  125. synth_ai/environments/examples/red/units/test_tree.py +448 -0
  126. synth_ai/environments/examples/sokoban/__init__.py +1 -0
  127. synth_ai/environments/examples/sokoban/agent_demos/sokoban_full_eval.py +900 -0
  128. synth_ai/environments/examples/sokoban/agent_demos/test_dspy_react.py +1 -0
  129. synth_ai/environments/examples/sokoban/agent_demos/test_sokoban_react_agent.py +498 -0
  130. synth_ai/environments/examples/sokoban/agent_demos/test_synth_lats.py +1 -0
  131. synth_ai/environments/examples/sokoban/agent_demos/test_synth_react_locally.py +748 -0
  132. synth_ai/environments/examples/sokoban/agent_demos/test_synth_react_service.py +296 -0
  133. synth_ai/environments/examples/sokoban/engine.py +675 -0
  134. synth_ai/environments/examples/sokoban/engine_helpers/__init__.py +1 -0
  135. synth_ai/environments/examples/sokoban/engine_helpers/room_utils.py +656 -0
  136. synth_ai/environments/examples/sokoban/engine_helpers/vendored/__init__.py +17 -0
  137. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/__init__.py +3 -0
  138. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/boxoban_env.py +129 -0
  139. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/render_utils.py +370 -0
  140. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/room_utils.py +331 -0
  141. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env.py +305 -0
  142. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_fixed_targets.py +66 -0
  143. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_pull.py +114 -0
  144. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_two_player.py +122 -0
  145. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_variations.py +394 -0
  146. synth_ai/environments/examples/sokoban/environment.py +228 -0
  147. synth_ai/environments/examples/sokoban/generate_verified_puzzles.py +438 -0
  148. synth_ai/environments/examples/sokoban/puzzle_loader.py +311 -0
  149. synth_ai/environments/examples/sokoban/taskset.py +425 -0
  150. synth_ai/environments/examples/sokoban/units/astar_common.py +94 -0
  151. synth_ai/environments/examples/sokoban/units/test_building_task_set.py +49 -0
  152. synth_ai/environments/examples/sokoban/units/test_false_positive.py +120 -0
  153. synth_ai/environments/examples/sokoban/units/test_simple_run_through_environment.py +119 -0
  154. synth_ai/environments/examples/sokoban/units/test_sokoban_environment.py +98 -0
  155. synth_ai/environments/examples/sokoban/units/test_tree.py +364 -0
  156. synth_ai/environments/examples/tictactoe/__init__.py +1 -0
  157. synth_ai/environments/examples/tictactoe/agent_demos/test_synth_react.py +266 -0
  158. synth_ai/environments/examples/tictactoe/agent_demos/test_tictactoe_react_agent.py +470 -0
  159. synth_ai/environments/examples/tictactoe/engine.py +368 -0
  160. synth_ai/environments/examples/tictactoe/environment.py +239 -0
  161. synth_ai/environments/examples/tictactoe/taskset.py +214 -0
  162. synth_ai/environments/examples/tictactoe/units/test_tictactoe_engine.py +393 -0
  163. synth_ai/environments/examples/tictactoe/units/test_tictactoe_environment.py +493 -0
  164. synth_ai/environments/examples/tictactoe/units/test_tictactoe_taskset.py +191 -0
  165. synth_ai/environments/examples/verilog/__init__.py +10 -0
  166. synth_ai/environments/examples/verilog/agent_demos/test_synth_react.py +520 -0
  167. synth_ai/environments/examples/verilog/engine.py +328 -0
  168. synth_ai/environments/examples/verilog/environment.py +349 -0
  169. synth_ai/environments/examples/verilog/taskset.py +418 -0
  170. synth_ai/environments/examples/verilog/units/test_verilog_engine.py +466 -0
  171. synth_ai/environments/examples/verilog/units/test_verilog_environment.py +585 -0
  172. synth_ai/environments/examples/verilog/units/test_verilog_integration.py +383 -0
  173. synth_ai/environments/examples/verilog/units/test_verilog_taskset.py +457 -0
  174. synth_ai/environments/reproducibility/core.py +42 -0
  175. synth_ai/environments/reproducibility/tree.py +364 -0
  176. synth_ai/environments/service/app.py +78 -0
  177. synth_ai/environments/service/core_routes.py +775 -0
  178. synth_ai/environments/service/external_registry.py +57 -0
  179. synth_ai/environments/service/registry.py +9 -0
  180. synth_ai/environments/stateful/__init__.py +1 -0
  181. synth_ai/environments/stateful/core.py +28 -0
  182. synth_ai/environments/stateful/engine.py +21 -0
  183. synth_ai/environments/stateful/state.py +7 -0
  184. synth_ai/environments/tasks/api.py +19 -0
  185. synth_ai/environments/tasks/core.py +78 -0
  186. synth_ai/environments/tasks/filters.py +39 -0
  187. synth_ai/environments/tasks/utils.py +89 -0
  188. synth_ai/environments/v0_observability/history.py +3 -0
  189. synth_ai/environments/v0_observability/log.py +2 -0
  190. synth_ai/lm/caching/constants.py +1 -0
  191. synth_ai/{zyk/lms → lm}/caching/ephemeral.py +4 -8
  192. synth_ai/{zyk/lms → lm}/caching/handler.py +15 -15
  193. synth_ai/{zyk/lms → lm}/caching/initialize.py +2 -4
  194. synth_ai/{zyk/lms → lm}/caching/persistent.py +4 -10
  195. synth_ai/{zyk/lms → lm}/config.py +2 -1
  196. synth_ai/{zyk/lms → lm}/constants.py +2 -2
  197. synth_ai/{zyk/lms → lm}/core/all.py +10 -10
  198. synth_ai/{zyk/lms → lm}/core/main.py +57 -33
  199. synth_ai/{zyk/lms → lm}/core/vendor_clients.py +12 -10
  200. synth_ai/lm/cost/monitor.py +1 -0
  201. synth_ai/lm/cost/statefulness.py +1 -0
  202. synth_ai/lm/provider_support/__init__.py +8 -0
  203. synth_ai/lm/provider_support/anthropic.py +945 -0
  204. synth_ai/lm/provider_support/openai.py +1115 -0
  205. synth_ai/lm/provider_support/suppress_logging.py +31 -0
  206. synth_ai/{zyk/lms → lm}/structured_outputs/handler.py +58 -80
  207. synth_ai/{zyk/lms → lm}/structured_outputs/inject.py +6 -20
  208. synth_ai/{zyk/lms → lm}/structured_outputs/rehabilitate.py +6 -12
  209. synth_ai/{zyk/lms → lm}/vendors/core/anthropic_api.py +21 -30
  210. synth_ai/{zyk/lms → lm}/vendors/core/gemini_api.py +37 -32
  211. synth_ai/{zyk/lms → lm}/vendors/core/mistral_api.py +19 -28
  212. synth_ai/{zyk/lms → lm}/vendors/core/openai_api.py +26 -36
  213. synth_ai/{zyk/lms → lm}/vendors/openai_standard.py +29 -33
  214. synth_ai/{zyk/lms → lm}/vendors/retries.py +1 -1
  215. synth_ai/lm/vendors/supported/__init__.py +0 -0
  216. synth_ai/{zyk/lms → lm}/vendors/supported/custom_endpoint.py +131 -118
  217. synth_ai/{zyk/lms → lm}/vendors/supported/deepseek.py +4 -8
  218. synth_ai/{zyk/lms → lm}/vendors/supported/grok.py +6 -8
  219. synth_ai/{zyk/lms → lm}/vendors/supported/groq.py +1 -1
  220. synth_ai/{zyk/lms → lm}/vendors/supported/ollama.py +2 -2
  221. synth_ai/{zyk/lms → lm}/vendors/supported/openrouter.py +18 -16
  222. synth_ai/{zyk/lms → lm}/vendors/supported/together.py +1 -1
  223. synth_ai/tracing/__init__.py +0 -0
  224. synth_ai/tracing/abstractions.py +224 -0
  225. synth_ai/tracing/base_client.py +91 -0
  226. synth_ai/tracing/client_manager.py +131 -0
  227. synth_ai/tracing/config.py +140 -0
  228. synth_ai/tracing/context.py +146 -0
  229. synth_ai/tracing/decorators.py +679 -0
  230. synth_ai/tracing/events/__init__.py +0 -0
  231. synth_ai/tracing/events/manage.py +147 -0
  232. synth_ai/tracing/events/scope.py +86 -0
  233. synth_ai/tracing/events/store.py +227 -0
  234. synth_ai/tracing/immediate_client.py +152 -0
  235. synth_ai/tracing/local.py +18 -0
  236. synth_ai/tracing/log_client_base.py +74 -0
  237. synth_ai/tracing/retry_queue.py +187 -0
  238. synth_ai/tracing/trackers.py +515 -0
  239. synth_ai/tracing/upload.py +504 -0
  240. synth_ai/tracing/utils.py +9 -0
  241. synth_ai/zyk/__init__.py +28 -2
  242. synth_ai-0.2.1.dev0.dist-info/METADATA +349 -0
  243. synth_ai-0.2.1.dev0.dist-info/RECORD +261 -0
  244. synth_ai/zyk/lms/caching/constants.py +0 -1
  245. synth_ai/zyk/lms/cost/monitor.py +0 -1
  246. synth_ai/zyk/lms/cost/statefulness.py +0 -1
  247. synth_ai-0.1.9.dist-info/METADATA +0 -37
  248. synth_ai-0.1.9.dist-info/RECORD +0 -50
  249. /synth_ai/{zyk/lms/__init__.py → environments/reproducibility/helpers.py} +0 -0
  250. /synth_ai/{zyk/lms/caching → lm}/__init__.py +0 -0
  251. /synth_ai/{zyk/lms/core → lm/caching}/__init__.py +0 -0
  252. /synth_ai/{zyk/lms → lm}/caching/dbs.py +0 -0
  253. /synth_ai/{zyk/lms/cost → lm/core}/__init__.py +0 -0
  254. /synth_ai/{zyk/lms → lm}/core/exceptions.py +0 -0
  255. /synth_ai/{zyk/lms/structured_outputs → lm/cost}/__init__.py +0 -0
  256. /synth_ai/{zyk/lms/vendors → lm/structured_outputs}/__init__.py +0 -0
  257. /synth_ai/{zyk/lms → lm}/tools/__init__.py +0 -0
  258. /synth_ai/{zyk/lms → lm}/tools/base.py +0 -0
  259. /synth_ai/{zyk/lms/vendors/core → lm/vendors}/__init__.py +0 -0
  260. /synth_ai/{zyk/lms → lm}/vendors/base.py +0 -0
  261. /synth_ai/{zyk/lms/vendors/local → lm/vendors/core}/__init__.py +0 -0
  262. /synth_ai/{zyk/lms/vendors/supported → lm/vendors/local}/__init__.py +0 -0
  263. /synth_ai/{zyk/lms → lm}/vendors/local/ollama.py +0 -0
  264. {synth_ai-0.1.9.dist-info → synth_ai-0.2.1.dev0.dist-info}/WHEEL +0 -0
  265. {synth_ai-0.1.9.dist-info → synth_ai-0.2.1.dev0.dist-info}/licenses/LICENSE +0 -0
  266. {synth_ai-0.1.9.dist-info → synth_ai-0.2.1.dev0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,235 @@
1
+ from __future__ import annotations
2
+ from typing import List, Optional, Any, Dict, Union
3
+ from pydantic import BaseModel, Field
4
+
5
+ # Import logging configuration to suppress JAX debug messages
6
+
7
+ from .engine import (
8
+ PokemonRedEngine,
9
+ PokemonRedPrivateState,
10
+ PokemonRedPublicState,
11
+ PokemonRedEngineSnapshot,
12
+ )
13
+ from .taskset import PokemonRedTaskInstance, INSTANCE as DEFAULT_TASK_INSTANCE
14
+ from synth_ai.environments.environment.shared_engine import (
15
+ GetObservationCallable,
16
+ InternalObservation,
17
+ )
18
+ from synth_ai.environments.reproducibility.core import ReproducibleEnvironment
19
+ from synth_ai.environments.stateful.core import StatefulEnvironment
20
+ from synth_ai.environments.environment.tools import (
21
+ AbstractTool,
22
+ EnvToolCall,
23
+ ToolResult,
24
+ TOOL_REGISTRY,
25
+ register_tool,
26
+ )
27
+
28
+
29
+ # Tool input schemas
30
class PressButtonInput(BaseModel):
    # Argument schema for the "press_button" tool.
    # Documented with comments rather than a docstring on purpose: pydantic
    # may surface a model docstring in the generated schema.

    # Which button to press (required).
    button: str = Field(
        default=...,
        description="Game Boy button: A, B, UP, DOWN, LEFT, RIGHT, START, SELECT",
    )
    # How long to hold it; defaults to a single frame.
    frames: int = Field(default=1, description="Number of frames to hold the button")
35
+
36
+
37
+ # Tool definitions
38
class PressButtonTool(AbstractTool):
    """Tool that forwards one button press to a ``PokemonRedEngine``."""

    name = "press_button"
    description = "Press a Game Boy button for the specified number of frames"
    call_schema = PressButtonInput
    result_schema = ToolResult

    def __init__(self, engine: PokemonRedEngine):
        # Engine that every invocation of this tool steps.
        self.engine = engine

    async def __call__(self, call: EnvToolCall) -> ToolResult:
        """Validate ``call.args`` and step the engine with the requested press.

        Returns:
            On success, a ``ToolResult`` with ``ok=True`` whose payload holds
            the ``"public"`` and ``"private"`` states returned by the engine.
            On any exception, a ``ToolResult`` with ``ok=False``, the
            stringified exception in ``error`` and only the ``"public"`` state
            in the payload.
        """
        try:
            args = self.call_schema(**call.args)
            private_state, public_state = await self.engine._step_engine(
                {"button": args.button, "frames": args.frames}
            )
            # Built inside the ``try`` so a failure here is also reported as
            # an error result instead of propagating.
            return ToolResult(
                ok=True,
                payload={"public": public_state, "private": private_state},
            )
        except Exception as exc:
            # Fetch the current state so the error result still carries context.
            _, public_state = self.engine._create_states(reward=0.0)
            return ToolResult(
                ok=False,
                error=str(exc),
                payload={"public": public_state},
            )
67
+
68
+
69
+ # Observation callable for Pokemon Red
70
class PokemonRedObservationCallable(GetObservationCallable):
    """Default observation builder for the Pokemon Red environment."""

    async def get_observation(
        self, pub: PokemonRedPublicState, priv: PokemonRedPrivateState
    ) -> InternalObservation:
        """Flatten the public and private engine states into an observation dict.

        The ``"error"`` key is only present when ``pub.error_info`` is truthy.
        """
        # Imported lazily, as in the original implementation.
        from .engine_helpers.state_extraction import (
            get_badge_count,
            format_position,
            format_hp_status,
        )

        badges_earned = get_badge_count(pub.badges)
        where = format_position(pub.player_x, pub.player_y, pub.map_id)
        hp = format_hp_status(pub.party_hp_current, pub.party_hp_max)

        observation = {
            # Derived / formatted fields.
            "position": where,
            "badges_earned": badges_earned,
            "badges_bitfield": pub.badges,
            "hp_status": hp,
            # Fields copied straight from the public state.
            "party_level": pub.party_level,
            "party_xp": pub.party_xp,
            "in_battle": pub.in_battle,
            "step_count": pub.step_count,
            # Fields copied straight from the private state.
            "reward_last_step": priv.reward_last_step,
            "total_reward": priv.total_reward,
            "terminated": priv.terminated,
        }

        if pub.error_info:
            observation["error"] = pub.error_info

        return observation
103
+
104
+
105
class PokemonRedEnvironment(StatefulEnvironment, ReproducibleEnvironment[PokemonRedEngine]):
    """Pokemon Red stateful game environment for AI agents.

    Wraps a ``PokemonRedEngine`` and exposes it through a single
    ``press_button`` tool. Observations are produced by the configured
    ``GetObservationCallable`` objects (one for steps, one for checkpoints).
    """

    def __init__(
        self,
        task_instance: Optional[PokemonRedTaskInstance] = None,
        custom_step_obs: Optional[GetObservationCallable] = None,
        custom_ckpt_obs: Optional[GetObservationCallable] = None,
    ):
        """Create the engine and the button tool for ``task_instance``.

        Args:
            task_instance: Task to run; falls back to ``DEFAULT_TASK_INSTANCE``.
            custom_step_obs: Observation builder used by ``initialize``,
                ``step`` and ``terminate``; defaults to
                ``PokemonRedObservationCallable``.
            custom_ckpt_obs: Observation builder used by ``checkpoint``;
                defaults to ``PokemonRedObservationCallable``.
        """
        self.name = "PokemonRed"
        self.task_instance = task_instance or DEFAULT_TASK_INSTANCE
        self.custom_step_observation_callable = custom_step_obs or PokemonRedObservationCallable()
        self.custom_checkpoint_observation_callable = (
            custom_ckpt_obs or PokemonRedObservationCallable()
        )
        self.engine = PokemonRedEngine(self.task_instance)

        # Register tools. Only the first tool registered under this name ends
        # up in TOOL_REGISTRY; this environment always calls its own instance.
        self._press_button_tool = PressButtonTool(self.engine)
        if self._press_button_tool.name not in TOOL_REGISTRY:
            register_tool(self._press_button_tool)

    async def initialize(self) -> InternalObservation:
        """Reset the engine and return the first observation."""
        priv, pub = await self.engine._reset_engine()
        return await self._to_observation(priv, pub, self.custom_step_observation_callable)

    async def terminate(self) -> InternalObservation:
        """Return a final observation flagged as terminated."""
        priv, pub = self.engine._create_states(reward=0.0, terminated=True)
        obs_dict = {
            "terminated": True,
            "message": "Pokemon Red environment terminated.",
        }
        return await self._to_observation(
            priv, pub, self.custom_step_observation_callable, extra_obs=obs_dict
        )

    def validate_tool_calls(
        self, tool_calls: Union[EnvToolCall, List[EnvToolCall], List[List[EnvToolCall]]]
    ) -> EnvToolCall:
        """Validate and normalize tool calls to a single ``EnvToolCall``.

        Accepts a bare call, a list of calls, or a list of lists of calls.
        Only the first call is used; any further calls are ignored.

        Raises:
            ValueError: If the outer or first inner list is empty, or the
                selected call does not target ``press_button``.
            TypeError: If ``tool_calls`` or the selected call has an
                unexpected type.
        """
        if isinstance(tool_calls, list):
            if not tool_calls:
                raise ValueError("Received empty list of tool calls.")
            first = tool_calls[0]
            if isinstance(first, list):
                if not first:
                    raise ValueError("Received empty inner list of tool calls.")
                agent_call = first[0]
            else:
                agent_call = first
        elif isinstance(tool_calls, EnvToolCall):
            agent_call = tool_calls
        else:
            raise TypeError(f"Unexpected type for tool_calls: {type(tool_calls)}")

        if not isinstance(agent_call, EnvToolCall):
            raise TypeError(f"Processed call is not EnvToolCall: {type(agent_call)}")
        if agent_call.tool != "press_button":
            raise ValueError(f"Unknown tool: {agent_call.tool}. Expected 'press_button'.")

        return agent_call

    async def step(
        self, tool_calls: Union[EnvToolCall, List[EnvToolCall], List[List[EnvToolCall]]]
    ) -> InternalObservation:
        """Execute one button press and return the resulting observation.

        If the tool reports failure, or its payload does not contain both
        states, fresh states are taken from the engine instead. Any tool error
        is copied onto ``pub_state.error_info`` when that attribute exists.
        """
        agent_call = self.validate_tool_calls(tool_calls)
        tool_result: ToolResult = await self._press_button_tool(agent_call)

        payload = tool_result.payload
        priv_state = pub_state = None
        if tool_result.ok and isinstance(payload, dict):
            # States are already state objects, no need to reconstruct them.
            priv_state = payload.get("private")
            pub_state = payload.get("public")

        if priv_state is None or pub_state is None:
            # Fallback: the tool failed or returned an incomplete payload.
            priv_state, pub_state = self.engine._create_states(reward=0.0)

        # The original repeated this check in each branch; it applies to all.
        if tool_result.error and hasattr(pub_state, "error_info"):
            pub_state.error_info = tool_result.error

        return await self._to_observation(
            priv_state, pub_state, self.custom_step_observation_callable
        )

    async def checkpoint(self) -> InternalObservation:
        """Return a checkpoint observation with the engine snapshot attached.

        The snapshot is stored under ``"engine_snapshot_data"`` when the
        observation is a dict.
        """
        engine_snapshot: PokemonRedEngineSnapshot = await self.engine._serialize_engine()
        priv, pub = self.engine._create_states(reward=0.0)
        obs_data = await self._to_observation(
            priv, pub, self.custom_checkpoint_observation_callable
        )
        if isinstance(obs_data, dict):
            obs_data["engine_snapshot_data"] = engine_snapshot.model_dump()
        return obs_data

    async def _to_observation(
        self,
        priv: PokemonRedPrivateState,
        pub: PokemonRedPublicState,
        obs_cb: Optional[GetObservationCallable],
        extra_obs: Optional[Dict[str, Any]] = None,
    ) -> InternalObservation:
        """Convert states to an observation using ``obs_cb``.

        Falls back to ``PokemonRedObservationCallable`` when ``obs_cb`` is
        falsy. ``extra_obs`` is merged in only when the observation is a dict.
        """
        active_obs_cb = obs_cb or PokemonRedObservationCallable()
        observation = await active_obs_cb.get_observation(pub, priv)
        if extra_obs and isinstance(observation, dict):
            observation.update(extra_obs)
        return observation

    # ReproducibleEnvironment methods
    async def _serialize_engine(self) -> PokemonRedEngineSnapshot:
        """Return the engine's own snapshot."""
        return await self.engine._serialize_engine()

    @classmethod
    async def _deserialize_engine(
        cls, snapshot: PokemonRedEngineSnapshot, task_instance: PokemonRedTaskInstance
    ) -> "PokemonRedEnvironment":
        """Rebuild an environment around an engine restored from ``snapshot``."""
        eng = await PokemonRedEngine._deserialize_engine(snapshot, task_instance)
        env = cls(task_instance)
        env.engine = eng
        # The tool captured the throwaway engine built in __init__; without
        # rebinding, step() would keep driving that engine instead of the
        # restored one.
        env._press_button_tool.engine = eng
        return env
@@ -0,0 +1,77 @@
1
+ from __future__ import annotations
2
+ from dataclasses import dataclass
3
+ from pathlib import Path
4
+ import uuid
5
+ from synth_ai.environments.tasks.core import (
6
+ Task,
7
+ TaskInstance,
8
+ Impetus,
9
+ Intent,
10
+ TaskInstanceMetadata,
11
+ )
12
+
13
+ # Define the main task for Pokemon Red
14
# Shared task description: premises, constraints and objective for every
# Pokemon Red task instance defined in this module.
TASK = Task(
    global_premises=(
        "You are playing Pokemon Red. Start in Pewter City with a level-10 Pikachu."
    ),
    global_constraints=(
        "No glitches or exploits. Play within normal game mechanics."
    ),
    global_objectives=(
        "Defeat Brock at the Pewter Gym to earn the Boulder Badge."
    ),
    shared_env_params={},
)

# Location of the optional initial save state (a save near Pewter Gym),
# resolved relative to this module's directory.
INITIAL_SNAPSHOT = Path(__file__).parent / "snapshots" / "pewter_start.state"
23
+
24
+
25
@dataclass
class PokemonRedTaskInstance(TaskInstance):
    """Task instance for Pokemon Red challenges"""

    async def serialize(self) -> dict:
        """Return a plain-dict form of this instance.

        ``metadata`` is always written as an empty dict and
        ``gold_trajectories`` as ``None``; the snapshot is stringified when
        truthy, otherwise ``None``.
        """
        impetus_payload = {"instructions": self.impetus.instructions}
        intent_payload = {
            "rubric": self.intent.rubric,
            "gold_trajectories": None,
            "gold_state_diff": self.intent.gold_state_diff,
        }
        if self.initial_engine_snapshot:
            snapshot_payload = str(self.initial_engine_snapshot)
        else:
            snapshot_payload = None

        return {
            "id": str(self.id),
            "impetus": impetus_payload,
            "intent": intent_payload,
            "metadata": {},
            "is_reproducible": self.is_reproducible,
            "initial_engine_snapshot": snapshot_payload,
        }

    @classmethod
    async def deserialize(cls, data: dict) -> "PokemonRedTaskInstance":
        """Build a task instance from a dict shaped like ``serialize`` output.

        NOTE(review): the serialized ``initial_engine_snapshot`` and
        ``metadata`` are not restored here (``None`` / fresh metadata are
        used) — confirm this is intentional.
        """
        intent_data = data["intent"]
        return cls(
            id=uuid.UUID(data["id"]),
            impetus=Impetus(instructions=data["impetus"]["instructions"]),
            intent=Intent(
                rubric=intent_data["rubric"],
                gold_trajectories=None,
                gold_state_diff=intent_data["gold_state_diff"],
            ),
            metadata=TaskInstanceMetadata(),
            is_reproducible=data["is_reproducible"],
            initial_engine_snapshot=None,
        )
61
+
62
+
63
+ # Main task instance - beat Brock for Boulder Badge
64
# Default task instance: beat Brock for the Boulder Badge. The snapshot is
# only attached when the save-state file actually exists on disk.
INSTANCE = PokemonRedTaskInstance(
    id=uuid.UUID("12345678-1234-5678-9abc-123456789abc"),
    impetus=Impetus(
        instructions=(
            "Navigate to Pewter Gym and defeat Brock to earn the Boulder Badge. Use strategic Pokemon battles and item management."
        ),
    ),
    intent=Intent(
        rubric=(
            "Successfully obtain the Boulder Badge by defeating Brock at Pewter Gym. Efficiency measured by minimal steps and strategic Pokemon usage."
        ),
        gold_trajectories=None,
        gold_state_diff={"badges": 1},
    ),
    metadata=TaskInstanceMetadata(),
    is_reproducible=True,
    initial_engine_snapshot=(INITIAL_SNAPSHOT if INITIAL_SNAPSHOT.exists() else None),
)
@@ -0,0 +1,125 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Test script to verify red environment fixes.
4
+ Tests JAX logging suppression and error handling.
5
+ """
6
+
7
+ import asyncio
8
+ import logging
9
+ import sys
10
+
11
+ from synth_ai.environments.examples.red.environment import PokemonRedEnvironment
12
+ from synth_ai.environments.examples.red.taskset import INSTANCE as POKEMON_TASK
13
+ from synth_ai.environments.environment.tools import EnvToolCall
14
+
15
+
16
class PressButtonCall(EnvToolCall):
    """Build a ``press_button`` tool call carrying a button name and frame count."""

    def __init__(self, button: str, frames: int = 1):
        payload = {"button": button, "frames": frames}
        super().__init__(tool="press_button", args=payload)
21
+
22
+
23
async def test_environment_setup():
    """Test that the environment can be set up without errors.

    Runs the environment through creation, initialization, a single "A"
    button press and termination, printing progress after each stage.

    Returns:
        True when every stage completes. False if any stage raises; the error
        is printed and also logged with its traceback.
    """
    print("Testing Pokemon Red environment setup...")

    try:
        # Create environment instance
        env = PokemonRedEnvironment(POKEMON_TASK)
        print("✅ Environment created successfully")

        # Try to initialize
        obs = await env.initialize()
        print("✅ Environment initialized successfully")
        print(f"Initial observation keys: {list(obs.keys())}")

        # Try a simple step
        obs = await env.step(PressButtonCall("A"))
        print("✅ Environment step executed successfully")
        print(
            f"Step observation: step_count={obs.get('step_count')}, terminated={obs.get('terminated')}"
        )

        # Terminate; the final observation is not inspected, only that the
        # call completes.
        await env.terminate()
        print("✅ Environment terminated successfully")

        return True

    except Exception as e:
        # Deliberately broad: this is the top-level boundary of a smoke test,
        # and any failure is reported rather than propagated.
        print(f"❌ Failed to setup environment: {e}")
        logging.exception("Failed to setup environment, aborting test")
        return False
54
+
55
+
56
def test_logging_configuration():
    """Print, for each JAX logger, whether its level is WARNING or higher.

    Only reports; nothing is returned or raised for a failing logger. The
    final line is printed unconditionally after emitting one debug record.
    """
    print("Testing logging configuration...")

    # Loggers expected to sit at WARNING level or above
    watched = (
        "jax._src.cache_key",
        "jax._src.compilation_cache",
        "jax._src.compiler",
        "jax._src.dispatch",
    )

    for logger_name in watched:
        level = logging.getLogger(logger_name).level
        level_name = logging.getLevelName(level)
        if level < logging.WARNING:
            print(f"❌ {logger_name} logger level: {level_name} (should be WARNING or higher)")
            continue
        print(f"✅ {logger_name} logger level: {level_name}")

    # Emit a debug record that is expected to be filtered out
    logging.getLogger("jax._src.cache_key").debug("This debug message should not appear")
    print("✅ JAX debug logging appears to be suppressed")
81
+
82
+
83
def test_safe_compare():
    """Run ``safe_compare`` over a table of cases and print pass/fail per case.

    Only reports; mismatches are printed but do not raise or change a return value.
    """
    print("Testing safe comparison function...")

    from synth_ai.environments.examples.red.config_logging import safe_compare

    # Each row: (left operand, right operand, operator, expected result)
    cases = (
        ("5", 3, ">", True),  # String vs int
        (5, "3", ">", True),  # Int vs string
        ("abc", 5, ">", False),  # Invalid string vs int
        ("5", "3", ">", True),  # String vs string (numeric)
        ("abc", "def", ">", False),  # String vs string (alphabetic)
        (5, 3, ">", True),  # Normal int comparison
    )

    for lhs, rhs, operator, wanted in cases:
        got = safe_compare(lhs, rhs, operator)
        mark = "❌" if got != wanted else "✅"
        print(f"{mark} safe_compare({lhs}, {rhs}, '{operator}') = {got} (expected {wanted})")
103
+
104
+
105
async def main():
    """Run all checks in order and return a process exit code.

    Returns:
        0 when the environment setup check succeeds, 1 otherwise. The logging
        and comparison checks only print and do not affect the result.
    """
    print("Running Pokemon Red environment fixes test...\n")

    # Logging configuration first, then safe comparison; a blank line after each
    for check in (test_logging_configuration, test_safe_compare):
        check()
        print()

    # Environment setup decides the overall outcome
    success = await test_environment_setup()

    verdict = "✅ PASSED" if success else "❌ FAILED"
    print(f"\nOverall test result: {verdict}")
    if success:
        return 0
    return 1


if __name__ == "__main__":
    exit_code = asyncio.run(main())
    sys.exit(exit_code)
@@ -0,0 +1,148 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Mock test script to verify red environment fixes without ROM file.
4
+ Tests JAX logging suppression and error handling.
5
+ """
6
+
7
+ import logging
8
+ import sys
9
+ from unittest.mock import Mock, patch
10
+
11
+
12
def test_logging_configuration():
    """Call ``configure_logging`` and check the JAX logger levels.

    Returns:
        True when every listed logger has a level of WARNING or higher,
        False if at least one does not.
    """
    print("Testing logging configuration...")

    # Import configuration to trigger setup
    from synth_ai.environments.examples.red.config_logging import configure_logging

    configure_logging()

    # Loggers expected to sit at WARNING level or above
    watched = (
        "jax._src.cache_key",
        "jax._src.compilation_cache",
        "jax._src.compiler",
        "jax._src.dispatch",
    )

    success = True
    for logger_name in watched:
        level = logging.getLogger(logger_name).level
        level_name = logging.getLevelName(level)
        if level >= logging.WARNING:
            print(f"✅ {logger_name} logger level: {level_name}")
            continue
        print(f"❌ {logger_name} logger level: {level_name} (should be WARNING or higher)")
        success = False

    # Emit a debug record that is expected to be filtered out
    logging.getLogger("jax._src.cache_key").debug("This debug message should not appear")
    print("✅ JAX debug logging appears to be suppressed")

    return success
46
+
47
+
48
def test_safe_compare():
    """Run ``safe_compare`` over a table of cases and report each outcome.

    Returns:
        True when every case produced its expected result, False otherwise.
    """
    print("Testing safe comparison function...")

    from synth_ai.environments.examples.red.config_logging import safe_compare

    # Cases that previously would cause the string vs int error.
    # Each row: (left operand, right operand, operator, expected result)
    cases = (
        ("5", 3, ">", True),  # String vs int
        (5, "3", ">", True),  # Int vs string
        ("abc", 5, ">", False),  # Invalid string vs int
        ("5", "3", ">", True),  # String vs string (numeric)
        ("abc", "def", ">", False),  # String vs string (alphabetic)
        (5, 3, ">", True),  # Normal int comparison
        ("10", 5, ">=", True),  # String number >= int
        (3, "10", "<=", True),  # Int <= string number
    )

    mismatches = 0
    for lhs, rhs, operator, wanted in cases:
        got = safe_compare(lhs, rhs, operator)
        mark = "✅" if got == wanted else "❌"
        print(f"{mark} safe_compare({lhs}, {rhs}, '{operator}') = {got} (expected {wanted})")
        if got != wanted:
            mismatches += 1

    return mismatches == 0
75
+
76
+
77
def test_state_creation_error_handling():
    """Test that state creation handles type errors gracefully.

    Builds a ``PokemonRedEngine`` with the PyBoy emulator mocked out, then
    calls ``_create_states`` twice: once with ``_extract_current_state``
    returning string-typed values, and once with it raising.

    Returns:
        True if both ``_create_states`` calls complete without raising,
        False if anything inside the mocked section raises.
    """
    print("Testing state creation error handling...")

    from synth_ai.environments.examples.red.engine import PokemonRedEngine
    from synth_ai.environments.examples.red.taskset import INSTANCE as POKEMON_TASK

    try:
        # Mock the PyBoy emulator to avoid ROM requirement. The patch target
        # must name the module the engine is actually imported from above; a
        # different path would fail to resolve and be reported as a test
        # failure by the except clause below.
        with patch("synth_ai.environments.examples.red.engine.PyBoy") as mock_pyboy:
            mock_pyboy.return_value = Mock()

            # Create engine instance
            engine = PokemonRedEngine(POKEMON_TASK)

            # Mock extract_game_state to return problematic data that could cause comparison errors
            with patch.object(engine, "_extract_current_state") as mock_extract:
                # Test with string badges that could cause comparison error
                mock_extract.return_value = {
                    "map_id": "1",  # String instead of int
                    "player_x": "10",
                    "player_y": "20",
                    "badges": "abc",  # Non-numeric string
                    "in_battle": "false",  # String instead of bool
                    "party_level": "5",
                    "party_hp_current": "50",
                    "party_hp_max": "50",
                    "party_xp": "100",
                }

                # This should not crash due to our error handling
                _, pub_state = engine._create_states(0.0, False)

                print("✅ State creation handles problematic data gracefully")
                print(f"✅ Created states: badges={pub_state.badges}, map_id={pub_state.map_id}")

                # Test with completely invalid data; only completion matters,
                # so the returned states are not kept.
                mock_extract.side_effect = Exception("Memory read error")
                engine._create_states(0.0, False)
                print("✅ State creation handles extraction errors gracefully")

        return True

    except Exception as e:
        # Deliberately broad: any failure is reported as a failed check.
        print(f"❌ State creation error handling failed: {e}")
        return False
124
+
125
+
126
def main():
    """Run the three mock-based checks and return a process exit code.

    Returns:
        0 when all three checks report success, 1 otherwise.
    """
    print("Running Pokemon Red environment fixes test (mock version)...\n")

    # Order: logging configuration, safe comparison, error handling.
    # All three always run; a blank line follows each.
    outcomes = []
    for check in (
        test_logging_configuration,
        test_safe_compare,
        test_state_creation_error_handling,
    ):
        outcomes.append(check())
        print()

    success = all(outcomes)
    verdict = "✅ PASSED" if success else "❌ FAILED"
    print(f"Overall test result: {verdict}")
    if success:
        return 0
    return 1


if __name__ == "__main__":
    exit_code = main()
    sys.exit(exit_code)
@@ -0,0 +1 @@
1
+ # Unit tests for Pokemon Red environment
@@ -0,0 +1,97 @@
1
+ import pytest
2
+ from synth_ai.environments.examples.red.environment import PokemonRedEnvironment
3
+ from synth_ai.environments.examples.red.taskset import INSTANCE as POKEMON_TASK
4
+ from synth_ai.environments.environment.tools import EnvToolCall
5
+
6
+
7
class PressButtonCall(EnvToolCall):
    """Build a ``press_button`` tool call carrying a button name and frame count."""

    def __init__(self, button: str, frames: int = 1):
        payload = {"button": button, "frames": frames}
        super().__init__(tool="press_button", args=payload)
12
+
13
+
14
@pytest.mark.asyncio
async def test_pokemon_red_basic():
    """Check initialize, two button presses and terminate on a fresh environment."""
    env = PokemonRedEnvironment(POKEMON_TASK)

    # Initial observation: required keys present, and no badges yet
    start_obs = await env.initialize()
    assert "position" in start_obs
    assert "badges_earned" in start_obs
    assert start_obs["badges_earned"] == 0

    # First press: the step counter appears and reads 1
    after_a = await env.step(PressButtonCall("A"))
    assert "step_count" in after_a
    assert after_a["step_count"] == 1

    # Second press (RIGHT, 2 frames): the counter is expected to read 2
    after_right = await env.step(PressButtonCall("RIGHT", 2))
    assert after_right["step_count"] == 2

    # Termination must flag the final observation as terminated
    closing_obs = await env.terminate()
    assert closing_obs["terminated"] is True
36
+
37
+
38
@pytest.mark.asyncio
async def test_pokemon_red_multiple_actions():
    """Check observations and counters across a sequence of five button presses."""
    env = PokemonRedEnvironment(POKEMON_TASK)

    obs = await env.initialize()
    initial_reward = obs["total_reward"]

    # Sequence of movements and actions
    actions = [PressButtonCall(name) for name in ("RIGHT", "UP", "A", "DOWN", "B")]

    for action in actions:
        obs = await env.step(action)
        # Every step observation must expose these keys
        assert "position" in obs
        assert "hp_status" in obs
        assert "party_level" in obs

    # Reward is expected not to exceed its starting value (step penalties),
    # and the step counter should match the number of presses.
    assert obs["total_reward"] <= initial_reward
    assert obs["step_count"] == len(actions)
64
+
65
+
66
@pytest.mark.asyncio
async def test_pokemon_red_checkpointing():
    """Check that a checkpoint taken after two steps carries snapshot data."""
    env = PokemonRedEnvironment(POKEMON_TASK)

    # Initialize, then take two steps before checkpointing
    await env.initialize()
    for button in ("RIGHT", "A"):
        await env.step(PressButtonCall(button))

    # The checkpoint observation carries the snapshot and the step counter
    checkpoint_obs = await env.checkpoint()
    assert "engine_snapshot_data" in checkpoint_obs
    assert checkpoint_obs["step_count"] == 2

    # The snapshot itself must expose these fields
    snapshot = checkpoint_obs["engine_snapshot_data"]
    assert "state_data" in snapshot
    assert "total_reward" in snapshot
    assert "step_count" in snapshot
86
+
87
+
88
@pytest.mark.asyncio
async def test_pokemon_red_invalid_button():
    """Check that stepping with an unknown button still yields an observation."""
    env = PokemonRedEnvironment(POKEMON_TASK)
    await env.initialize()

    # An unrecognized button name is expected to be handled, not raised
    bad_press = PressButtonCall("INVALID_BUTTON")
    obs = await env.step(bad_press)

    # The observation should still be valid even if the action failed
    assert "position" in obs