synth-ai 0.2.13.dev1__py3-none-any.whl → 0.2.13.dev2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of synth-ai might be problematic. Click here for more details.

Files changed (226) hide show
  1. examples/multi_step/configs/crafter_rl_stepwise_hosted_judge.toml +12 -1
  2. examples/swe/task_app/grpo_swe_mini.py +55 -26
  3. examples/swe/task_app/hosted/rollout.py +40 -0
  4. examples/swe/task_app/hosted/test_service.py +5 -6
  5. examples/task_apps/TESTING.md +275 -0
  6. examples/task_apps/__init__.py +0 -0
  7. examples/task_apps/crafter/__init__.py +0 -0
  8. examples/task_apps/crafter/task_app/__init__.py +2 -0
  9. examples/{warming_up_to_rl → task_apps/crafter}/task_app/grpo_crafter.py +18 -13
  10. examples/{warming_up_to_rl → task_apps/crafter}/task_app/grpo_crafter_task_app.py +1 -1
  11. examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/policy.py +60 -4
  12. examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/policy_routes.py +25 -3
  13. examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/rollout.py +10 -0
  14. examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/test_service.py +5 -6
  15. examples/task_apps/dev/pokemon_emerald/__init__.py +2 -0
  16. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/README.md +811 -0
  17. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/__init__.py +120 -0
  18. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/action.py +160 -0
  19. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/memory.py +155 -0
  20. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/perception.py +69 -0
  21. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/planning.py +96 -0
  22. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/simple.py +1502 -0
  23. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/system_prompt.py +4 -0
  24. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/grab_map.py +68 -0
  25. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/manual.py +216 -0
  26. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/__init__.py +35 -0
  27. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/emerald_utils.py +631 -0
  28. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/emulator.py +1544 -0
  29. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/enums.py +1428 -0
  30. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/memory_reader.py +4848 -0
  31. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/types.py +41 -0
  32. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/utils.py +298 -0
  33. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pyproject.toml +95 -0
  34. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/run.py +204 -0
  35. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/server/__init__.py +0 -0
  36. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/server/app.py +2152 -0
  37. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/server/client.py +429 -0
  38. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/server/frame_server.py +155 -0
  39. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/README.md +78 -0
  40. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/__init__.py +0 -0
  41. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/run_tests.py +122 -0
  42. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_agent_direct.py +76 -0
  43. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_agent_prompts.py +413 -0
  44. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_battle_state_formatting.py +204 -0
  45. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_dialogue_detection.py +133 -0
  46. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_dialogue_detection_comprehensive.py +229 -0
  47. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_direct_agent_emulator.py +300 -0
  48. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_fps_adjustment_pytest.py +205 -0
  49. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_house_to_outside_direct.py +200 -0
  50. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_house_to_outside_transition.py +284 -0
  51. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_map_ground_truth_comparison.py +468 -0
  52. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_memory_map.py +575 -0
  53. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_server_map_validation.py +311 -0
  54. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_torchic_state.py +259 -0
  55. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/__init__.py +0 -0
  56. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/anticheat.py +372 -0
  57. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/checkpoint.py +296 -0
  58. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/error_handler.py +275 -0
  59. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/get_local_ip.py +22 -0
  60. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/helpers.py +44 -0
  61. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/llm_logger.py +514 -0
  62. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/map_formatter.py +415 -0
  63. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/map_stitcher.py +1763 -0
  64. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/map_stitcher_singleton.py +33 -0
  65. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/map_trimmer.py +106 -0
  66. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/map_visualizer.py +334 -0
  67. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/ocr_dialogue.py +1020 -0
  68. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/recording.py +188 -0
  69. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/state_formatter.py +1481 -0
  70. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/vlm.py +862 -0
  71. examples/task_apps/dev/pokemon_emerald/modal_app.py +114 -0
  72. examples/task_apps/dev/pokemon_emerald/task_app/README.md +81 -0
  73. examples/task_apps/dev/pokemon_emerald/task_app/__init__.py +6 -0
  74. examples/task_apps/dev/pokemon_emerald/task_app/pokemon_emerald.py +685 -0
  75. examples/task_apps/enron/__init__.py +1 -0
  76. examples/task_apps/enron/eval_groq_qwen32.toml +16 -0
  77. examples/task_apps/enron/task_app/README.md +14 -0
  78. examples/task_apps/enron/task_app/__init__.py +1 -0
  79. examples/task_apps/enron/task_app/grpo_enron.py +906 -0
  80. examples/task_apps/enron/task_app/grpo_enron_task_app.py +146 -0
  81. examples/task_apps/enron/tests/__init__.py +2 -0
  82. examples/task_apps/enron/tests/conftest.py +115 -0
  83. examples/task_apps/enron/tests/integration/__init__.py +2 -0
  84. examples/task_apps/enron/tests/integration/test_enron_eval.py +177 -0
  85. examples/task_apps/enron/tests/integration/test_enron_rollout.py +135 -0
  86. examples/task_apps/enron/tests/unit/__init__.py +2 -0
  87. examples/task_apps/enron/tests/unit/test_enron_environment.py +126 -0
  88. examples/task_apps/math/__init__.py +0 -0
  89. examples/{rl/task_app → task_apps/math}/math_single_step.py +19 -10
  90. examples/task_apps/pokemon_battle/__init__.py +2 -0
  91. examples/task_apps/pokemon_battle/modal_app.py +104 -0
  92. examples/task_apps/pokemon_battle/task_app/README.md +68 -0
  93. examples/task_apps/pokemon_battle/task_app/__init__.py +6 -0
  94. examples/task_apps/pokemon_battle/task_app/pokemon_showdown.py +932 -0
  95. examples/task_apps/pokemon_red/README.md +357 -0
  96. examples/task_apps/pokemon_red/__init__.py +3 -0
  97. examples/task_apps/pokemon_red/eval_pokemon_red_policy.py +225 -0
  98. examples/task_apps/pokemon_red/pallet_town_rl_config.toml +73 -0
  99. examples/task_apps/pokemon_red/task_app.py +606 -0
  100. examples/task_apps/pokemon_red/test_pallet_town_rewards.py +191 -0
  101. examples/task_apps/sokoban/README.md +307 -0
  102. examples/task_apps/sokoban/__init__.py +3 -0
  103. examples/task_apps/sokoban/eval_groq_qwen32.toml +16 -0
  104. examples/task_apps/sokoban/eval_openai_gpt5.toml +16 -0
  105. examples/task_apps/sokoban/task_app.py +1058 -0
  106. examples/task_apps/sokoban/tests/__init__.py +2 -0
  107. examples/task_apps/sokoban/tests/conftest.py +113 -0
  108. examples/task_apps/sokoban/tests/integration/__init__.py +2 -0
  109. examples/task_apps/sokoban/tests/integration/test_sokoban_eval.py +57 -0
  110. examples/task_apps/sokoban/tests/integration/test_sokoban_rollout.py +198 -0
  111. examples/task_apps/sokoban/tests/unit/__init__.py +2 -0
  112. examples/task_apps/sokoban/tests/unit/test_sokoban_environment.py +114 -0
  113. examples/task_apps/verilog/__init__.py +1 -0
  114. examples/task_apps/verilog/eval_groq_qwen32b.toml +20 -0
  115. examples/task_apps/verilog/task_app/README.md +12 -0
  116. examples/task_apps/verilog/task_app/__init__.py +1 -0
  117. examples/task_apps/verilog/task_app/grpo_verilog.py +931 -0
  118. examples/task_apps/verilog/task_app/grpo_verilog_task_app.py +145 -0
  119. examples/task_apps/verilog/tests/__init__.py +2 -0
  120. examples/task_apps/verilog/tests/conftest.py +115 -0
  121. examples/task_apps/verilog/tests/integration/__init__.py +2 -0
  122. examples/task_apps/verilog/tests/integration/test_verilog_eval.py +179 -0
  123. examples/task_apps/verilog/tests/integration/test_verilog_rollout.py +55 -0
  124. examples/task_apps/verilog/tests/unit/__init__.py +2 -0
  125. examples/task_apps/verilog/tests/unit/test_verilog_scoring.py +118 -0
  126. examples/vlm/crafter_openai_vlm_agent.py +4 -4
  127. examples/vlm/run_crafter_vlm_benchmark.py +4 -4
  128. examples/workflows/__init__.py +0 -0
  129. examples/workflows/math_rl/__init__.py +0 -0
  130. examples/workflows/math_rl/download_dataset.py +80 -0
  131. synth_ai/__init__.py +2 -2
  132. synth_ai/api/train/builders.py +25 -11
  133. synth_ai/api/train/cli.py +12 -6
  134. synth_ai/api/train/configs/__init__.py +10 -10
  135. synth_ai/api/train/configs/rl.py +5 -4
  136. synth_ai/api/train/configs/sft.py +4 -3
  137. synth_ai/api/train/env_resolver.py +5 -2
  138. synth_ai/api/train/supported_algos.py +10 -5
  139. synth_ai/api/train/utils.py +7 -4
  140. synth_ai/cli/__init__.py +7 -51
  141. synth_ai/cli/_storage.py +4 -3
  142. synth_ai/cli/_validate_task_app.py +11 -0
  143. synth_ai/cli/balance.py +4 -3
  144. synth_ai/cli/calc.py +2 -2
  145. synth_ai/cli/demo.py +14 -7
  146. synth_ai/cli/legacy_root_backup.py +1 -1
  147. synth_ai/cli/rl_demo.py +8 -7
  148. synth_ai/cli/root.py +0 -97
  149. synth_ai/cli/task_apps.py +1707 -186
  150. synth_ai/demos/demo_task_apps/crafter/grpo_crafter_task_app.py +28 -16
  151. synth_ai/environments/examples/enron/engine.py +7 -2
  152. synth_ai/environments/examples/enron/environment.py +68 -0
  153. synth_ai/environments/examples/red/engine.py +27 -0
  154. synth_ai/environments/examples/red/engine_helpers/memory_map.py +7 -0
  155. synth_ai/environments/examples/red/engine_helpers/reward_library/pallet_town_progression.py +477 -0
  156. synth_ai/environments/examples/red/engine_helpers/state_extraction.py +32 -0
  157. synth_ai/environments/examples/red/environment.py +60 -0
  158. synth_ai/environments/examples/sokoban/taskset.py +116 -0
  159. synth_ai/environments/examples/verilog/engine.py +30 -4
  160. synth_ai/evals/client.py +58 -61
  161. synth_ai/jobs/client.py +16 -4
  162. synth_ai/judge_schemas.py +16 -16
  163. synth_ai/py.typed +0 -0
  164. synth_ai/task/__init__.py +14 -5
  165. synth_ai/task/contracts.py +124 -38
  166. synth_ai/task/proxy.py +48 -56
  167. synth_ai/task/rubrics/__init__.py +53 -0
  168. synth_ai/task/rubrics/loaders.py +133 -0
  169. synth_ai/task/rubrics/models.py +57 -0
  170. synth_ai/task/rubrics/scoring.py +113 -0
  171. synth_ai/{rubrics/validators.py → task/rubrics/strict.py} +53 -30
  172. synth_ai/task/server.py +8 -7
  173. synth_ai/task/validators.py +269 -6
  174. synth_ai/tracing_v3/decorators.py +7 -3
  175. synth_ai/tracing_v3/replica_sync.py +4 -4
  176. synth_ai/tracing_v3/serialization.py +5 -5
  177. synth_ai/tracing_v3/trace_utils.py +317 -0
  178. synth_ai/tracing_v3/turso/native_manager.py +3 -3
  179. {synth_ai-0.2.13.dev1.dist-info → synth_ai-0.2.13.dev2.dist-info}/METADATA +4 -1
  180. {synth_ai-0.2.13.dev1.dist-info → synth_ai-0.2.13.dev2.dist-info}/RECORD +214 -101
  181. examples/agora_ex/README_MoE.md +0 -224
  182. examples/agora_ex/__init__.py +0 -7
  183. examples/agora_ex/agora_ex.py +0 -65
  184. examples/agora_ex/agora_ex_task_app.py +0 -590
  185. examples/agora_ex/configs/rl_lora_qwen3_moe_2xh200.toml +0 -121
  186. examples/agora_ex/reward_fn_grpo-human.py +0 -129
  187. examples/agora_ex/system_prompt_CURRENT.md +0 -63
  188. examples/agora_ex/task_app/agora_ex_task_app.py +0 -590
  189. examples/agora_ex/task_app/reward_fn_grpo-human.py +0 -129
  190. examples/agora_ex/task_app/system_prompt_CURRENT.md +0 -63
  191. synth_ai/rubrics/__init__.py +0 -22
  192. synth_ai/task/rubrics.py +0 -219
  193. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/README.md +0 -0
  194. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/README.md +0 -0
  195. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/__init__.py +0 -0
  196. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/branching.py +0 -0
  197. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/environment_routes.py +0 -0
  198. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/__init__.py +0 -0
  199. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/__init__.py +0 -0
  200. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/app.py +0 -0
  201. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/environment.py +0 -0
  202. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/react_agent.py +0 -0
  203. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/shared.py +0 -0
  204. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/tools.py +0 -0
  205. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/hosted_app.py +0 -0
  206. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/inference/__init__.py +0 -0
  207. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/inference/openai_client.py +0 -0
  208. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/main.py +0 -0
  209. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/registry.py +0 -0
  210. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/storage/__init__.py +0 -0
  211. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/storage/volume.py +0 -0
  212. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/test_agents.py +0 -0
  213. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/utils.py +0 -0
  214. /examples/{rl/task_app → task_apps/math}/README.md +0 -0
  215. /examples/{rl/task_app → task_apps/math}/math_task_app.py +0 -0
  216. /examples/{rl → workflows/math_rl}/configs/eval_base_qwen.toml +0 -0
  217. /examples/{rl → workflows/math_rl}/configs/eval_rl_qwen.toml +0 -0
  218. /examples/{rl → workflows/math_rl}/configs/rl_from_base_qwen.toml +0 -0
  219. /examples/{rl → workflows/math_rl}/configs/rl_from_base_qwen17.toml +0 -0
  220. /examples/{rl → workflows/math_rl}/configs/rl_from_ft_qwen.toml +0 -0
  221. /examples/{rl → workflows/math_rl}/run_eval.py +0 -0
  222. /examples/{rl → workflows/math_rl}/run_rl_and_save.py +0 -0
  223. {synth_ai-0.2.13.dev1.dist-info → synth_ai-0.2.13.dev2.dist-info}/WHEEL +0 -0
  224. {synth_ai-0.2.13.dev1.dist-info → synth_ai-0.2.13.dev2.dist-info}/entry_points.txt +0 -0
  225. {synth_ai-0.2.13.dev1.dist-info → synth_ai-0.2.13.dev2.dist-info}/licenses/LICENSE +0 -0
  226. {synth_ai-0.2.13.dev1.dist-info → synth_ai-0.2.13.dev2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,477 @@
1
+ """
2
+ Ultra-Rich Reward Shaping for Pallet Town First Section
3
+
4
+ This module provides fine-grained reward components that track important
5
+ achievements in the initial Pallet Town sequence: leaving the house, finding
6
+ Oak's lab, talking to Oak, starting the rival battle, attacking and damaging
7
+ the opponent, winning the battle, getting a party member, and leaving the lab.
8
+
9
+ Each milestone is carefully weighted to provide dense, meaningful feedback
10
+ for reinforcement learning agents learning to play Pokemon Red.
11
+ """
12
+
13
+ from typing import Any, Dict
14
+
15
+ from synth_ai.environments.environment.rewards.core import RewardComponent
16
+
17
+
18
class LeaveBedroomReward(RewardComponent):
    """
    One-time bonus for walking down the stairs of the starting house.

    Fires when the previous map id was 38 (bedroom) and the current map id
    is 37 (ground floor).

    Reward: +20 points (one-time)
    """

    def __init__(self):
        # Latched once the bonus has been paid so it can never pay twice.
        self.triggered = False

    async def score(self, state: Dict[str, Any], action: Dict[str, Any]) -> float:
        """Return 20.0 on the first 38 -> 37 map transition, else 0.0."""
        if self.triggered:
            return 0.0

        came_from = action.get("prev_map_id", -1)
        now_on = state.get("map_id", -1)

        # Bedroom is map 38 (0x26); the floor below is map 37 (0x25).
        went_downstairs = came_from == 38 and now_on == 37
        if not went_downstairs:
            return 0.0

        self.triggered = True
        return 20.0
42
+
43
+
44
class ExitHouseFirstTimeReward(RewardComponent):
    """
    One-time bonus for stepping out of the starting house.

    Fires when the previous map id was 37 (ground floor of the house) and
    the current map id is anything other than 37 or 38 (the two house maps).

    Reward: +30 points (one-time)
    """

    def __init__(self):
        # Latched once the bonus has been paid.
        self.triggered = False

    async def score(self, state: Dict[str, Any], action: Dict[str, Any]) -> float:
        """Return 30.0 the first time the agent leaves the house, else 0.0."""
        if self.triggered:
            return 0.0

        came_from = action.get("prev_map_id", -1)
        now_on = state.get("map_id", -1)

        # Any destination that is not one of the two house interiors counts
        # as "outside"; the check does not pin a specific outdoor map id.
        if came_from != 37:
            return 0.0
        if now_on == 37 or now_on == 38:
            return 0.0

        self.triggered = True
        return 30.0
68
+
69
+
70
class FindOakLabReward(RewardComponent):
    """
    One-time bonus for walking from Pallet Town into Oak's Lab.

    Fires on a map transition from id 0 to id 3.

    Reward: +40 points (one-time)
    """

    def __init__(self):
        # Latched once the lab has been entered and the bonus paid.
        self.lab_found = False

    async def score(self, state: Dict[str, Any], action: Dict[str, Any]) -> float:
        """Return 40.0 on the first 0 -> 3 map transition, else 0.0."""
        if self.lab_found:
            return 0.0

        came_from = action.get("prev_map_id", -1)
        now_on = state.get("map_id", -1)

        # NOTE(review): 3 is assumed to be Oak's Lab. The pokered map table
        # lists OAKS_LAB as 0x28 (40) and 3 as Cerulean City - confirm
        # against the map_id the engine actually emits.
        entered_lab = came_from == 0 and now_on == 3
        if not entered_lab:
            return 0.0

        self.lab_found = True
        return 40.0
93
+
94
+
95
class TalkToOakReward(RewardComponent):
    """
    One-time bonus for the first dialogue opened inside Oak's Lab.

    Fires when the current map id is 3, a text box is active now, and no
    text box was active on the previous step.

    Reward: +50 points (one-time)
    """

    def __init__(self):
        # Latched once the bonus has been paid.
        self.oak_talked_to = False

    async def score(self, state: Dict[str, Any], action: Dict[str, Any]) -> float:
        """Return 50.0 on the first text-box rising edge on map 3, else 0.0."""
        if self.oak_talked_to:
            return 0.0

        # NOTE(review): map id 3 is assumed to be Oak's Lab (pokered lists
        # OAKS_LAB as 0x28) - confirm against the engine.
        if state.get("map_id", -1) != 3:
            return 0.0
        if not state.get("text_box_active", False):
            return 0.0

        # Only a rising edge counts, so a text box that stays open across
        # several steps does not qualify again.
        if action.get("prev_text_box_active", False):
            return 0.0

        self.oak_talked_to = True
        return 50.0
117
+
118
+
119
class ReceiveStarterPokemonReward(RewardComponent):
    """
    One-time bonus for obtaining the first party member.

    Fires when the party count goes from 0 to exactly 1 while the current
    map id is 3.

    Reward: +100 points (one-time)
    """

    def __init__(self):
        # Latched once the bonus has been paid.
        self.starter_received = False

    async def score(self, state: Dict[str, Any], action: Dict[str, Any]) -> float:
        """Return 100.0 when the party grows 0 -> 1 on map 3, else 0.0."""
        if self.starter_received:
            return 0.0

        party_before = action.get("prev_party_count", 0)
        party_now = state.get("party_count", 0)

        gained_first = party_before == 0 and party_now == 1
        if not gained_first:
            return 0.0

        # NOTE(review): map id 3 is assumed to be Oak's Lab (pokered lists
        # OAKS_LAB as 0x28) - confirm against the engine.
        if state.get("map_id", -1) != 3:
            return 0.0

        self.starter_received = True
        return 100.0
144
+
145
+
146
class EnterFirstBattleReward(RewardComponent):
    """
    One-time bonus for starting a battle inside Oak's Lab.

    Fires when `in_battle` flips from falsy to truthy while the current
    map id is 3.

    Reward: +75 points (one-time)
    """

    def __init__(self):
        # Latched once the bonus has been paid.
        self.first_battle_entered = False

    async def score(self, state: Dict[str, Any], action: Dict[str, Any]) -> float:
        """Return 75.0 on the first battle start observed on map 3, else 0.0."""
        if self.first_battle_entered:
            return 0.0

        was_fighting = action.get("prev_in_battle", False)
        is_fighting = state.get("in_battle", False)

        # Require a rising edge of the battle flag.
        if was_fighting or not is_fighting:
            return 0.0

        # NOTE(review): map id 3 is assumed to be Oak's Lab (pokered lists
        # OAKS_LAB as 0x28) - confirm against the engine.
        if state.get("map_id", -1) != 3:
            return 0.0

        self.first_battle_entered = True
        return 75.0
171
+
172
+
173
class DealDamageToRivalReward(RewardComponent):
    """
    Reward for landing damaging hits on the opposing Pokemon.

    A hit is detected as a decrease in `enemy_hp_current` between the
    previous step (`prev_enemy_hp_current` in `action`) and the current
    state while `in_battle` is truthy.

    Reward: a flat +5 points per damaging step, regardless of how much HP
    was removed, for at most `max_instances` (10) steps, i.e. up to 50 points.
    """

    def __init__(self):
        # Number of damaging steps already rewarded.
        self.damage_instances = 0
        # Hard cap on how many damaging steps can be rewarded.
        self.max_instances = 10
        # Not read by score(); retained so the attribute still exists for
        # any external code that inspects it.
        self.prev_enemy_hp = None

    async def score(self, state: Dict[str, Any], action: Dict[str, Any]) -> float:
        """Return 5.0 for a step on which enemy HP dropped, else 0.0.

        Args:
            state: Current game state; reads `in_battle` and
                `enemy_hp_current`.
            action: Step metadata; reads `prev_enemy_hp_current`, defaulting
                to the current HP (so a missing key never counts as damage).

        Returns:
            5.0 when damage was detected and the cap is not exhausted,
            otherwise 0.0.
        """
        if self.damage_instances >= self.max_instances:
            return 0.0

        if not state.get("in_battle", False):
            return 0.0

        current_enemy_hp = state.get("enemy_hp_current", 0)
        prev_enemy_hp = action.get("prev_enemy_hp_current", current_enemy_hp)

        # The "> 0" bound means a hit that takes the enemy to exactly 0 HP
        # is not counted here.
        # NOTE(review): presumably deliberate, to ignore zeroed HP readings -
        # confirm whether the knockout hit should also earn the +5.
        if prev_enemy_hp > current_enemy_hp > 0:
            self.damage_instances += 1
            return 5.0

        return 0.0
203
+
204
+
205
class ReduceEnemyHPByHalfReward(RewardComponent):
    """
    One-time bonus for pushing the enemy's HP percentage below 50.

    Fires during a battle when the previous percentage was >= 50.0 and the
    current percentage is < 50.0.

    Reward: +25 points (one-time)
    """

    def __init__(self):
        # Latched once the bonus has been paid.
        self.half_hp_achieved = False

    async def score(self, state: Dict[str, Any], action: Dict[str, Any]) -> float:
        """Return 25.0 when enemy HP first crosses below 50%, else 0.0."""
        if self.half_hp_achieved:
            return 0.0

        if not state.get("in_battle", False):
            return 0.0

        pct_now = state.get("enemy_hp_percentage", 0.0)
        # A missing previous value is treated as full health.
        pct_before = action.get("prev_enemy_hp_percentage", 100.0)

        crossed_half = pct_before >= 50.0 and pct_now < 50.0
        if not crossed_half:
            return 0.0

        self.half_hp_achieved = True
        return 25.0
230
+
231
+
232
class ReduceEnemyHPToLowReward(RewardComponent):
    """
    One-time bonus for pushing the enemy's HP percentage below 25.

    Fires during a battle when the previous percentage was >= 25.0 and the
    current percentage is < 25.0.

    Reward: +35 points (one-time)
    """

    def __init__(self):
        # Latched once the bonus has been paid.
        self.low_hp_achieved = False

    async def score(self, state: Dict[str, Any], action: Dict[str, Any]) -> float:
        """Return 35.0 when enemy HP first crosses below 25%, else 0.0."""
        if self.low_hp_achieved:
            return 0.0

        if not state.get("in_battle", False):
            return 0.0

        pct_now = state.get("enemy_hp_percentage", 0.0)
        # A missing previous value is treated as full health.
        pct_before = action.get("prev_enemy_hp_percentage", 100.0)

        crossed_quarter = pct_before >= 25.0 and pct_now < 25.0
        if not crossed_quarter:
            return 0.0

        self.low_hp_achieved = True
        return 35.0
257
+
258
+
259
class WinFirstBattleReward(RewardComponent):
    """
    One-time bonus for winning a battle inside Oak's Lab.

    Fires on the step where `in_battle` goes from truthy to falsy, the
    state's `battle_outcome` equals 1, and the current map id is 3.

    Reward: +150 points (one-time)
    """

    def __init__(self):
        # Latched once the bonus has been paid.
        self.first_battle_won = False

    async def score(self, state: Dict[str, Any], action: Dict[str, Any]) -> float:
        """Return 150.0 on the first won battle that ends on map 3, else 0.0."""
        if self.first_battle_won:
            return 0.0

        was_fighting = action.get("prev_in_battle", False)
        is_fighting = state.get("in_battle", False)
        # Outcome codes per the original comment: 0=ongoing, 1=win, 2=lose.
        outcome = state.get("battle_outcome", 0)

        battle_just_won = was_fighting and not is_fighting and outcome == 1
        if not battle_just_won:
            return 0.0

        # NOTE(review): map id 3 is assumed to be Oak's Lab (pokered lists
        # OAKS_LAB as 0x28) - confirm against the engine.
        if state.get("map_id", -1) != 3:
            return 0.0

        self.first_battle_won = True
        return 150.0
287
+
288
+
289
class ExitLabAfterBattleReward(RewardComponent):
    """
    One-time bonus for walking out of Oak's Lab with a Pokemon in the party.

    Fires on a map transition from id 3 to id 0 while `party_count` is
    greater than zero. The battle result itself is not inspected here.

    Reward: +60 points (one-time, requires having a party member)
    """

    def __init__(self):
        # Latched once the bonus has been paid.
        self.exited_with_pokemon = False

    async def score(self, state: Dict[str, Any], action: Dict[str, Any]) -> float:
        """Return 60.0 on the first 3 -> 0 transition with a party, else 0.0."""
        if self.exited_with_pokemon:
            return 0.0

        came_from = action.get("prev_map_id", -1)
        now_on = state.get("map_id", -1)

        # NOTE(review): map id 3 is assumed to be Oak's Lab (pokered lists
        # OAKS_LAB as 0x28) - confirm against the engine.
        left_lab = came_from == 3 and now_on == 0
        if not left_lab:
            return 0.0

        # Leaving empty-handed does not count.
        if not state.get("party_count", 0) > 0:
            return 0.0

        self.exited_with_pokemon = True
        return 60.0
315
+
316
+
317
class FirstBattleEfficiencyReward(RewardComponent):
    """
    Bonus for winning the lab battle in few turns.

    While `in_battle` is truthy the highest `battle_turn` seen is recorded.
    When a battle ends with `battle_outcome == 1` on map id 3, that maximum
    is graded once and the component goes quiet afterwards.

    Reward: +20 points if won in <=5 turns, +10 if <=8 turns, otherwise 0
    """

    def __init__(self):
        # Set as soon as a qualifying win has been graded, even if the
        # grade was 0 points.
        self.efficiency_rewarded = False
        # Largest battle_turn value observed while in battle.
        self.max_turns_seen = 0

    async def score(self, state: Dict[str, Any], action: Dict[str, Any]) -> float:
        """Track battle turns and grade the first won battle on map 3."""
        if self.efficiency_rewarded:
            return 0.0

        is_fighting = state.get("in_battle", False)
        if is_fighting:
            turn_now = state.get("battle_turn", 0)
            if turn_now > self.max_turns_seen:
                self.max_turns_seen = turn_now

        was_fighting = action.get("prev_in_battle", False)
        outcome = state.get("battle_outcome", 0)

        battle_just_won = was_fighting and not is_fighting and outcome == 1
        # NOTE(review): map id 3 is assumed to be Oak's Lab (pokered lists
        # OAKS_LAB as 0x28) - confirm against the engine.
        if not battle_just_won or state.get("map_id", -1) != 3:
            return 0.0

        self.efficiency_rewarded = True
        for turn_limit, bonus in ((5, 20.0), (8, 10.0)):
            if self.max_turns_seen <= turn_limit:
                return bonus

        # Won, but too slowly to earn a bonus.
        return 0.0
351
+
352
+
353
class KeepPokemonHealthyReward(RewardComponent):
    """
    Bonus for finishing the lab battle with the lead Pokemon above 50% HP.

    The check runs on the step where a battle ends with
    `battle_outcome == 1` on map id 3. The latch is only set when the bonus
    is actually paid, so a win that ends at or below 50% HP leaves the
    component armed.

    Reward: +30 points (one-time)
    """

    def __init__(self):
        # Latched once the bonus has been paid.
        self.health_bonus_given = False

    async def score(self, state: Dict[str, Any], action: Dict[str, Any]) -> float:
        """Return 30.0 if a won battle on map 3 ends with lead HP above 50%.

        Args:
            state: Current game state; reads `in_battle`, `battle_outcome`,
                `map_id` and `party_pokemon` (a sequence of dicts, each with
                an `hp_percentage` entry).
            action: Step metadata; reads `prev_in_battle`.

        Returns:
            30.0 the first time the condition holds, otherwise 0.0.
        """
        if self.health_bonus_given:
            return 0.0

        prev_in_battle = action.get("prev_in_battle", False)
        current_in_battle = state.get("in_battle", False)
        battle_outcome = state.get("battle_outcome", 0)

        battle_just_won = prev_in_battle and not current_in_battle and battle_outcome == 1
        if not battle_just_won:
            return 0.0

        # NOTE(review): map id 3 is assumed to be Oak's Lab (pokered lists
        # OAKS_LAB as 0x28) - confirm against the engine.
        if state.get("map_id", -1) != 3:
            return 0.0

        # Single lookup; "or []" also covers a key that is present but None,
        # which would otherwise raise TypeError in len().
        party = state.get("party_pokemon") or []
        if not party:
            return 0.0

        hp_pct = party[0].get("hp_percentage", 0)
        if hp_pct > 50.0:
            self.health_bonus_given = True
            return 30.0

        return 0.0
384
+
385
+
386
class NavigationSpeedReward(RewardComponent):
    """
    Bonus graded on how many scoring steps it took to leave the lab.

    Every call to `score` before the bonus is paid counts as one step. The
    bonus is paid on a map transition from id 3 to id 0 while `party_count`
    is greater than zero.

    Reward: 50 if <=40 steps, 30 if <=60, 15 if <=80, otherwise 5
    """

    def __init__(self):
        # Number of score() calls made before the bonus was paid.
        self.step_count = 0
        # Set when the lab exit with a party member has been observed.
        self.sequence_complete = False
        # Latched once the bonus has been paid.
        self.reward_given = False

    async def score(self, state: Dict[str, Any], action: Dict[str, Any]) -> float:
        """Count a step; pay the graded bonus on the qualifying lab exit."""
        if self.reward_given:
            return 0.0

        self.step_count += 1

        came_from = action.get("prev_map_id", -1)
        now_on = state.get("map_id", -1)

        # NOTE(review): map id 3 is assumed to be Oak's Lab (pokered lists
        # OAKS_LAB as 0x28) - confirm against the engine.
        left_lab = came_from == 3 and now_on == 0
        if not left_lab:
            return 0.0
        if not state.get("party_count", 0) > 0:
            return 0.0

        self.sequence_complete = True
        self.reward_given = True

        # Step budget tiers, checked from strictest to loosest; the original
        # comment puts the optimal path at roughly 30-40 steps.
        tiers = (
            (40, 50.0),  # very efficient
            (60, 30.0),  # good
            (80, 15.0),  # acceptable
        )
        for step_limit, bonus in tiers:
            if self.step_count <= step_limit:
                return bonus

        # Completed, but slowly.
        return 5.0
427
+
428
+
429
+ # Composite reward for the complete Pallet Town sequence
430
class PalletTownProgressionCompositeReward(RewardComponent):
    """
    Sum of every Pallet Town milestone component defined in this module.

    Maximum contribution of each component (735 points in total):
    - Leave bedroom: 20
    - Exit house: 30
    - Find lab: 40
    - Talk to Oak: 50
    - Get starter: 100
    - Enter battle: 75
    - Deal damage: 50 (10 instances x 5)
    - Half HP: 25
    - Low HP: 35
    - Win battle: 150
    - Exit lab: 60
    - Efficiency: 20
    - Keep healthy: 30
    - Navigation: 50

    Each component keeps its own one-time/cap state, so a fresh composite
    should be created for every episode.
    """

    def __init__(self):
        # Components are instantiated in the order they are scored.
        component_types = (
            LeaveBedroomReward,
            ExitHouseFirstTimeReward,
            FindOakLabReward,
            TalkToOakReward,
            ReceiveStarterPokemonReward,
            EnterFirstBattleReward,
            DealDamageToRivalReward,
            ReduceEnemyHPByHalfReward,
            ReduceEnemyHPToLowReward,
            WinFirstBattleReward,
            ExitLabAfterBattleReward,
            FirstBattleEfficiencyReward,
            KeepPokemonHealthyReward,
            NavigationSpeedReward,
        )
        self.components = [make_component() for make_component in component_types]

    async def score(self, state: Dict[str, Any], action: Dict[str, Any]) -> float:
        """Score every component in order and return the running total."""
        running_total = 0.0
        for part in self.components:
            running_total += await part.score(state, action)
        return running_total
477
+
@@ -85,6 +85,33 @@ def extract_inventory(memory) -> List[Dict[str, Any]]:
85
85
  return inventory
86
86
 
87
87
 
88
def extract_battle_state(memory) -> Dict[str, Any]:
    """Extract battle-specific state.

    Args:
        memory: Memory object handed unchanged to ``get_byte`` / ``get_word``.

    Returns:
        A dict with the keys ``enemy_hp_current``, ``enemy_hp_max``,
        ``enemy_level``, ``enemy_species_id``, ``enemy_hp_percentage`` and
        ``battle_turn``. When the byte at ``IN_BATTLE_FLAG`` is not positive,
        every field is zero. ``enemy_hp_percentage`` is always a float,
        rounded to one decimal place, and is ``0.0`` whenever
        ``enemy_hp_max`` is not positive.
    """
    in_battle = get_byte(memory, IN_BATTLE_FLAG) > 0

    if not in_battle:
        # Outside of battle the enemy fields are reported as zeros rather
        # than whatever happens to be sitting at those addresses.
        return {
            "enemy_hp_current": 0,
            "enemy_hp_max": 0,
            "enemy_level": 0,
            "enemy_species_id": 0,
            "enemy_hp_percentage": 0.0,
            "battle_turn": 0,
        }

    enemy_hp_current = get_word(memory, ENEMY_HP_CURRENT)
    enemy_hp_max = get_word(memory, ENEMY_HP_MAX)
    enemy_level = get_byte(memory, ENEMY_LEVEL)
    enemy_species_id = get_byte(memory, ENEMY_SPECIES)

    # Guard the division, and use a float literal in the fallback:
    # ``round(0, 1)`` would yield the int 0, making the field's type depend
    # on the memory contents.
    if enemy_hp_max > 0:
        enemy_hp_percentage = round(enemy_hp_current / enemy_hp_max * 100, 1)
    else:
        enemy_hp_percentage = 0.0

    return {
        "enemy_hp_current": enemy_hp_current,
        "enemy_hp_max": enemy_hp_max,
        "enemy_level": enemy_level,
        "enemy_species_id": enemy_species_id,
        "enemy_hp_percentage": enemy_hp_percentage,
        "battle_turn": get_byte(memory, BATTLE_TURN),
    }
113
+
114
+
88
115
  def extract_game_state(memory) -> Dict[str, Any]:
89
116
  """Extract comprehensive game state from Game Boy memory"""
90
117
  # Get party and inventory details
@@ -93,6 +120,9 @@ def extract_game_state(memory) -> Dict[str, Any]:
93
120
 
94
121
  # Get money
95
122
  money = get_bcd_3byte(memory, MONEY)
123
+
124
+ # Get battle state
125
+ battle_state = extract_battle_state(memory)
96
126
 
97
127
  # Basic game state
98
128
  state = {
@@ -111,6 +141,8 @@ def extract_game_state(memory) -> Dict[str, Any]:
111
141
  "party_pokemon": party,
112
142
  "inventory_count": len(inventory),
113
143
  "inventory_items": inventory,
144
+ # Battle state
145
+ **battle_state,
114
146
  # Legacy fields for compatibility (use first Pokemon if available)
115
147
  "party_level": party[0]["level"] if party else 0,
116
148
  "party_hp_current": party[0]["hp_current"] if party else 0,
@@ -1,6 +1,8 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  from typing import Any, Dict, List, Optional, Union
4
+ import base64
5
+ from io import BytesIO
4
6
 
5
7
  from pydantic import BaseModel, Field
6
8
 
@@ -17,6 +19,12 @@ from synth_ai.environments.environment.tools import (
17
19
  )
18
20
  from synth_ai.environments.reproducibility.core import ReproducibleEnvironment
19
21
  from synth_ai.environments.stateful.core import StatefulEnvironment
22
+ try: # optional for image encoding
23
+ import numpy as _np # type: ignore
24
+ from PIL import Image as _PILImage # type: ignore
25
+ except Exception: # pragma: no cover - optional dependency
26
+ _np = None # type: ignore
27
+ _PILImage = None # type: ignore
20
28
 
21
29
  # Import logging configuration to suppress JAX debug messages
22
30
  from .engine import (
@@ -220,6 +228,58 @@ class PokemonRedEnvironment(StatefulEnvironment, ReproducibleEnvironment[Pokemon
220
228
  """Convert states to observation using the specified callback"""
221
229
  active_obs_cb = obs_cb or PokemonRedObservationCallable()
222
230
  observation = await active_obs_cb.get_observation(pub, priv)
231
+
232
+ # Include raw state fields for reward calculation
233
+ if isinstance(observation, dict):
234
+ observation["map_id"] = pub.world.map_id if pub.world else None
235
+ observation["player_x"] = pub.world.player_x if pub.world else None
236
+ observation["player_y"] = pub.world.player_y if pub.world else None
237
+ observation["party_count"] = len(pub.party) if pub.party else 0
238
+ observation["party_pokemon"] = [
239
+ {
240
+ "species_id": p.species_id,
241
+ "level": p.level,
242
+ "hp_current": p.hp_current,
243
+ "hp_max": p.hp_max,
244
+ "hp_percentage": (p.hp_current / p.hp_max * 100) if p.hp_max > 0 else 0,
245
+ }
246
+ for p in (pub.party or [])
247
+ ]
248
+ observation["in_battle"] = pub.system.in_battle if pub.system else False
249
+ observation["battle_outcome"] = pub.system.battle_outcome if pub.system else 0
250
+ observation["text_box_active"] = pub.system.text_box_active if pub.system else False
251
+ observation["enemy_hp_current"] = pub.system.enemy_hp_current if pub.system else 0
252
+ observation["enemy_hp_max"] = pub.system.enemy_hp_max if pub.system else 0
253
+ observation["enemy_hp_percentage"] = pub.system.enemy_hp_percentage if pub.system else 0.0
254
+ observation["badges"] = pub.progress.badges if pub.progress else 0
255
+ # Attach latest PNG frame for VLM agents if available
256
+ try:
257
+ emulator = getattr(self.engine, "emulator", None)
258
+ screen = getattr(emulator, "screen", None)
259
+ if screen is not None and _np is not None and _PILImage is not None:
260
+ # Prefer documented ndarray property if present
261
+ frame = getattr(screen, "ndarray", None)
262
+ if frame is None and hasattr(screen, "image"):
263
+ frame = screen.image
264
+ if isinstance(frame, _np.ndarray) and frame.ndim == 3 and frame.shape[0] > 0 and frame.shape[1] > 0:
265
+ array_uint8 = (
266
+ frame.astype("uint8") if frame.dtype != _np.uint8 else frame
267
+ )
268
+ # PyBoy gives RGBA; convert to RGB
269
+ if array_uint8.shape[-1] == 4:
270
+ array_uint8 = array_uint8[:, :, :3]
271
+ img = _PILImage.fromarray(array_uint8, mode="RGB")
272
+ buf = BytesIO()
273
+ img.save(buf, format="PNG")
274
+ encoded = base64.b64encode(buf.getvalue()).decode("ascii")
275
+ if isinstance(observation, dict):
276
+ observation["observation_image_base64"] = encoded
277
+ observation["observation_image_format"] = "png"
278
+ observation["observation_image_width"] = int(array_uint8.shape[1])
279
+ observation["observation_image_height"] = int(array_uint8.shape[0])
280
+ observation["observation_image_data_url"] = f"data:image/png;base64,{encoded}"
281
+ except Exception:
282
+ pass
223
283
  if extra_obs and isinstance(observation, dict):
224
284
  observation.update(extra_obs)
225
285
  return observation