synth-ai 0.2.13.dev1__py3-none-any.whl → 0.2.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of synth-ai might be problematic. Click here for more details.

Files changed (291) hide show
  1. examples/multi_step/configs/README_verilog_rl.md +77 -0
  2. examples/multi_step/configs/VERILOG_REWARDS.md +90 -0
  3. examples/multi_step/configs/VERILOG_RL_CHECKLIST.md +183 -0
  4. examples/multi_step/configs/crafter_eval_synth_qwen4b.toml +35 -0
  5. examples/multi_step/configs/crafter_eval_text_only_groq_qwen32b.toml +36 -0
  6. examples/multi_step/configs/crafter_rl_stepwise_hosted_judge.toml +17 -5
  7. examples/multi_step/configs/crafter_synth_backend.md +40 -0
  8. examples/multi_step/configs/verilog_eval_groq_qwen32b.toml +31 -0
  9. examples/multi_step/configs/verilog_eval_synth_qwen8b.toml +33 -0
  10. examples/multi_step/configs/verilog_rl_lora.toml +190 -0
  11. examples/multi_step/judges/crafter_backend_judge.py +220 -0
  12. examples/multi_step/judges/verilog_backend_judge.py +234 -0
  13. examples/multi_step/readme.md +48 -0
  14. examples/multi_step/verilog_rl_lora.md +218 -0
  15. examples/qwen_coder/configs/coder_lora_30b.toml +1 -1
  16. examples/sft/evaluate.py +2 -0
  17. examples/sft/generate_traces.py +2 -0
  18. examples/swe/task_app/grpo_swe_mini.py +56 -26
  19. examples/swe/task_app/hosted/rollout.py +42 -0
  20. examples/swe/task_app/hosted/test_service.py +5 -6
  21. examples/task_apps/IMAGE_ONLY_EVAL_QUICKSTART.md +258 -0
  22. examples/task_apps/TESTING.md +275 -0
  23. examples/task_apps/__init__.py +0 -0
  24. examples/task_apps/crafter/CREATE_SFT_DATASET.md +273 -0
  25. examples/task_apps/crafter/EVAL_IMAGE_ONLY_RESULTS.md +152 -0
  26. examples/task_apps/crafter/FILTER_COMMAND_STATUS.md +174 -0
  27. examples/task_apps/crafter/FILTER_COMMAND_SUCCESS.md +268 -0
  28. examples/task_apps/crafter/QUERY_EXAMPLES.md +203 -0
  29. examples/task_apps/crafter/README_IMAGE_ONLY_EVAL.md +316 -0
  30. examples/task_apps/crafter/__init__.py +0 -0
  31. examples/task_apps/crafter/eval_image_only_gpt4o.toml +28 -0
  32. examples/task_apps/crafter/eval_text_only_groq_llama.toml +36 -0
  33. examples/task_apps/crafter/filter_sft_dataset.toml +16 -0
  34. examples/task_apps/crafter/task_app/__init__.py +5 -0
  35. examples/{warming_up_to_rl → task_apps/crafter}/task_app/grpo_crafter.py +324 -21
  36. examples/{warming_up_to_rl → task_apps/crafter}/task_app/grpo_crafter_task_app.py +1 -1
  37. examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/environment.py +10 -0
  38. examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/policy.py +76 -7
  39. examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/react_agent.py +17 -2
  40. examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/inference/openai_client.py +25 -3
  41. examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/policy_routes.py +77 -4
  42. examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/rollout.py +117 -9
  43. examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/test_service.py +5 -6
  44. examples/task_apps/crafter/task_app/synth_envs_hosted/utils.py +218 -0
  45. examples/task_apps/dev/pokemon_emerald/__init__.py +2 -0
  46. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/README.md +811 -0
  47. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/__init__.py +120 -0
  48. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/action.py +160 -0
  49. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/memory.py +155 -0
  50. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/perception.py +69 -0
  51. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/planning.py +96 -0
  52. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/simple.py +1502 -0
  53. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/system_prompt.py +4 -0
  54. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/grab_map.py +68 -0
  55. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/manual.py +216 -0
  56. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/__init__.py +35 -0
  57. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/emerald_utils.py +631 -0
  58. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/emulator.py +1544 -0
  59. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/enums.py +1428 -0
  60. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/memory_reader.py +4848 -0
  61. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/types.py +41 -0
  62. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/utils.py +298 -0
  63. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pyproject.toml +95 -0
  64. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/run.py +204 -0
  65. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/server/__init__.py +0 -0
  66. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/server/app.py +2152 -0
  67. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/server/client.py +429 -0
  68. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/server/frame_server.py +155 -0
  69. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/README.md +78 -0
  70. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/__init__.py +0 -0
  71. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/run_tests.py +122 -0
  72. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_agent_direct.py +76 -0
  73. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_agent_prompts.py +413 -0
  74. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_battle_state_formatting.py +204 -0
  75. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_dialogue_detection.py +133 -0
  76. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_dialogue_detection_comprehensive.py +229 -0
  77. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_direct_agent_emulator.py +300 -0
  78. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_fps_adjustment_pytest.py +205 -0
  79. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_house_to_outside_direct.py +200 -0
  80. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_house_to_outside_transition.py +284 -0
  81. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_map_ground_truth_comparison.py +468 -0
  82. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_memory_map.py +575 -0
  83. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_server_map_validation.py +311 -0
  84. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_torchic_state.py +259 -0
  85. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/__init__.py +0 -0
  86. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/anticheat.py +372 -0
  87. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/checkpoint.py +296 -0
  88. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/error_handler.py +275 -0
  89. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/get_local_ip.py +22 -0
  90. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/helpers.py +44 -0
  91. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/llm_logger.py +514 -0
  92. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/map_formatter.py +415 -0
  93. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/map_stitcher.py +1763 -0
  94. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/map_stitcher_singleton.py +33 -0
  95. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/map_trimmer.py +106 -0
  96. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/map_visualizer.py +334 -0
  97. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/ocr_dialogue.py +1020 -0
  98. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/recording.py +188 -0
  99. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/state_formatter.py +1481 -0
  100. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/vlm.py +862 -0
  101. examples/task_apps/dev/pokemon_emerald/modal_app.py +114 -0
  102. examples/task_apps/dev/pokemon_emerald/task_app/README.md +81 -0
  103. examples/task_apps/dev/pokemon_emerald/task_app/__init__.py +6 -0
  104. examples/task_apps/dev/pokemon_emerald/task_app/pokemon_emerald.py +685 -0
  105. examples/task_apps/enron/__init__.py +1 -0
  106. examples/task_apps/enron/eval_groq_qwen32.toml +16 -0
  107. examples/task_apps/enron/filter_sft.toml +5 -0
  108. examples/task_apps/enron/task_app/README.md +14 -0
  109. examples/task_apps/enron/task_app/__init__.py +1 -0
  110. examples/task_apps/enron/task_app/grpo_enron.py +906 -0
  111. examples/task_apps/enron/task_app/grpo_enron_task_app.py +146 -0
  112. examples/task_apps/enron/tests/__init__.py +4 -0
  113. examples/task_apps/enron/tests/conftest.py +115 -0
  114. examples/task_apps/enron/tests/integration/__init__.py +4 -0
  115. examples/task_apps/enron/tests/integration/test_enron_eval.py +179 -0
  116. examples/task_apps/enron/tests/integration/test_enron_rollout.py +135 -0
  117. examples/task_apps/enron/tests/unit/__init__.py +4 -0
  118. examples/task_apps/enron/tests/unit/test_enron_environment.py +126 -0
  119. examples/task_apps/math/__init__.py +0 -0
  120. examples/{rl/task_app → task_apps/math}/math_single_step.py +19 -10
  121. examples/task_apps/pokemon_battle/__init__.py +2 -0
  122. examples/task_apps/pokemon_battle/modal_app.py +104 -0
  123. examples/task_apps/pokemon_battle/task_app/README.md +68 -0
  124. examples/task_apps/pokemon_battle/task_app/__init__.py +6 -0
  125. examples/task_apps/pokemon_battle/task_app/pokemon_showdown.py +932 -0
  126. examples/task_apps/pokemon_red/EVAL_IMAGE_ONLY_COMPLETE.md +283 -0
  127. examples/task_apps/pokemon_red/EVAL_IMAGE_ONLY_STATUS.md +155 -0
  128. examples/task_apps/pokemon_red/README.md +357 -0
  129. examples/task_apps/pokemon_red/README_IMAGE_ONLY_EVAL.md +415 -0
  130. examples/task_apps/pokemon_red/__init__.py +3 -0
  131. examples/task_apps/pokemon_red/eval_image_only_gpt4o.toml +29 -0
  132. examples/task_apps/pokemon_red/eval_pokemon_red_policy.py +225 -0
  133. examples/task_apps/pokemon_red/pallet_town_rl_config.toml +75 -0
  134. examples/task_apps/pokemon_red/task_app.py +799 -0
  135. examples/task_apps/pokemon_red/test_pallet_town_rewards.py +193 -0
  136. examples/task_apps/sokoban/README.md +307 -0
  137. examples/task_apps/sokoban/__init__.py +3 -0
  138. examples/task_apps/sokoban/eval_groq_qwen32.toml +16 -0
  139. examples/task_apps/sokoban/eval_openai_gpt5.toml +16 -0
  140. examples/task_apps/sokoban/filter_sft.toml +5 -0
  141. examples/task_apps/sokoban/task_app.py +1058 -0
  142. examples/task_apps/sokoban/tests/__init__.py +4 -0
  143. examples/task_apps/sokoban/tests/conftest.py +113 -0
  144. examples/task_apps/sokoban/tests/integration/__init__.py +4 -0
  145. examples/task_apps/sokoban/tests/integration/test_sokoban_eval.py +57 -0
  146. examples/task_apps/sokoban/tests/integration/test_sokoban_rollout.py +198 -0
  147. examples/task_apps/sokoban/tests/unit/__init__.py +4 -0
  148. examples/task_apps/sokoban/tests/unit/test_sokoban_environment.py +114 -0
  149. examples/task_apps/verilog/__init__.py +1 -0
  150. examples/task_apps/verilog/eval_groq_qwen32b.toml +24 -0
  151. examples/task_apps/verilog/filter_sft.toml +5 -0
  152. examples/task_apps/verilog/task_app/README.md +12 -0
  153. examples/task_apps/verilog/task_app/__init__.py +1 -0
  154. examples/task_apps/verilog/task_app/grpo_verilog.py +1166 -0
  155. examples/task_apps/verilog/task_app/grpo_verilog_task_app.py +145 -0
  156. examples/task_apps/verilog/tests/__init__.py +4 -0
  157. examples/task_apps/verilog/tests/conftest.py +115 -0
  158. examples/task_apps/verilog/tests/integration/__init__.py +4 -0
  159. examples/task_apps/verilog/tests/integration/test_verilog_eval.py +181 -0
  160. examples/task_apps/verilog/tests/integration/test_verilog_rollout.py +55 -0
  161. examples/task_apps/verilog/tests/unit/__init__.py +4 -0
  162. examples/task_apps/verilog/tests/unit/test_verilog_scoring.py +118 -0
  163. examples/vlm/crafter_openai_vlm_agent.py +4 -4
  164. examples/vlm/run_crafter_vlm_benchmark.py +4 -4
  165. examples/warming_up_to_rl/groq_test.py +2 -0
  166. examples/warming_up_to_rl/run_local_rollout.py +2 -0
  167. examples/warming_up_to_rl/run_local_rollout_modal.py +2 -0
  168. examples/warming_up_to_rl/run_local_rollout_parallel.py +2 -0
  169. examples/warming_up_to_rl/run_local_rollout_traced.py +2 -0
  170. examples/warming_up_to_rl/run_rollout_remote.py +2 -0
  171. examples/workflows/__init__.py +0 -0
  172. examples/workflows/math_rl/__init__.py +0 -0
  173. examples/workflows/math_rl/download_dataset.py +80 -0
  174. synth_ai/__init__.py +2 -2
  175. synth_ai/api/models/supported.py +1 -0
  176. synth_ai/api/train/builders.py +25 -11
  177. synth_ai/api/train/cli.py +12 -6
  178. synth_ai/api/train/configs/__init__.py +10 -10
  179. synth_ai/api/train/configs/rl.py +5 -4
  180. synth_ai/api/train/configs/sft.py +4 -3
  181. synth_ai/api/train/env_resolver.py +5 -2
  182. synth_ai/api/train/supported_algos.py +10 -5
  183. synth_ai/api/train/utils.py +7 -4
  184. synth_ai/cli/__init__.py +48 -59
  185. synth_ai/cli/_modal_wrapper.py +3 -2
  186. synth_ai/cli/_storage.py +4 -3
  187. synth_ai/cli/_validate_task_app.py +11 -0
  188. synth_ai/cli/balance.py +4 -3
  189. synth_ai/cli/calc.py +2 -2
  190. synth_ai/cli/demo.py +14 -7
  191. synth_ai/cli/legacy_root_backup.py +1 -1
  192. synth_ai/cli/recent.py +1 -1
  193. synth_ai/cli/rl_demo.py +8 -7
  194. synth_ai/cli/root.py +0 -97
  195. synth_ai/cli/status.py +1 -1
  196. synth_ai/cli/task_apps.py +1922 -190
  197. synth_ai/cli/traces.py +1 -1
  198. synth_ai/cli/tui.py +57 -0
  199. synth_ai/cli/turso.py +1 -1
  200. synth_ai/cli/watch.py +1 -1
  201. synth_ai/demos/demo_task_apps/crafter/grpo_crafter_task_app.py +29 -17
  202. synth_ai/environments/examples/crafter_classic/environment.py +1 -1
  203. synth_ai/environments/examples/enron/engine.py +7 -2
  204. synth_ai/environments/examples/enron/environment.py +68 -0
  205. synth_ai/environments/examples/red/engine.py +27 -0
  206. synth_ai/environments/examples/red/engine_helpers/memory_map.py +7 -0
  207. synth_ai/environments/examples/red/engine_helpers/reward_library/pallet_town_progression.py +477 -0
  208. synth_ai/environments/examples/red/engine_helpers/state_extraction.py +32 -0
  209. synth_ai/environments/examples/red/environment.py +60 -0
  210. synth_ai/environments/examples/sokoban/taskset.py +116 -0
  211. synth_ai/environments/examples/verilog/engine.py +104 -12
  212. synth_ai/evals/client.py +58 -61
  213. synth_ai/jobs/client.py +16 -4
  214. synth_ai/judge_schemas.py +9 -9
  215. synth_ai/py.typed +0 -0
  216. synth_ai/task/__init__.py +24 -5
  217. synth_ai/task/apps/__init__.py +1 -0
  218. synth_ai/task/config.py +257 -0
  219. synth_ai/task/contracts.py +138 -39
  220. synth_ai/task/proxy.py +48 -56
  221. synth_ai/task/rubrics/__init__.py +56 -0
  222. synth_ai/task/rubrics/loaders.py +152 -0
  223. synth_ai/task/rubrics/models.py +57 -0
  224. synth_ai/task/rubrics/scoring.py +116 -0
  225. synth_ai/{rubrics/validators.py → task/rubrics/strict.py} +53 -30
  226. synth_ai/task/server.py +8 -7
  227. synth_ai/task/trace_correlation_helpers.py +315 -0
  228. synth_ai/task/validators.py +413 -6
  229. synth_ai/tracing_v3/abstractions.py +3 -3
  230. synth_ai/tracing_v3/decorators.py +7 -3
  231. synth_ai/tracing_v3/llm_call_record_helpers.py +5 -5
  232. synth_ai/tracing_v3/replica_sync.py +4 -4
  233. synth_ai/tracing_v3/serialization.py +5 -5
  234. synth_ai/tracing_v3/session_tracer.py +16 -6
  235. synth_ai/tracing_v3/storage/base.py +29 -29
  236. synth_ai/tracing_v3/storage/config.py +3 -3
  237. synth_ai/tracing_v3/trace_utils.py +317 -0
  238. synth_ai/tracing_v3/turso/daemon.py +8 -7
  239. synth_ai/tracing_v3/turso/native_manager.py +66 -43
  240. synth_ai/tracing_v3/utils.py +3 -3
  241. synth_ai/tui/__init__.py +5 -0
  242. synth_ai/tui/__main__.py +13 -0
  243. synth_ai/tui/cli/__init__.py +1 -0
  244. synth_ai/tui/cli/query_experiments.py +164 -0
  245. synth_ai/tui/cli/query_experiments_v3.py +164 -0
  246. synth_ai/tui/dashboard.py +906 -0
  247. {synth_ai-0.2.13.dev1.dist-info → synth_ai-0.2.14.dist-info}/METADATA +4 -1
  248. {synth_ai-0.2.13.dev1.dist-info → synth_ai-0.2.14.dist-info}/RECORD +278 -126
  249. examples/agora_ex/README_MoE.md +0 -224
  250. examples/agora_ex/__init__.py +0 -7
  251. examples/agora_ex/agora_ex.py +0 -65
  252. examples/agora_ex/agora_ex_task_app.py +0 -590
  253. examples/agora_ex/configs/rl_lora_qwen3_moe_2xh200.toml +0 -121
  254. examples/agora_ex/reward_fn_grpo-human.py +0 -129
  255. examples/agora_ex/system_prompt_CURRENT.md +0 -63
  256. examples/agora_ex/task_app/agora_ex_task_app.py +0 -590
  257. examples/agora_ex/task_app/reward_fn_grpo-human.py +0 -129
  258. examples/agora_ex/task_app/system_prompt_CURRENT.md +0 -63
  259. examples/warming_up_to_rl/task_app/synth_envs_hosted/utils.py +0 -62
  260. synth_ai/rubrics/__init__.py +0 -22
  261. synth_ai/task/rubrics.py +0 -219
  262. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/README.md +0 -0
  263. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/README.md +0 -0
  264. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/__init__.py +0 -0
  265. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/branching.py +0 -0
  266. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/environment_routes.py +0 -0
  267. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/__init__.py +0 -0
  268. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/__init__.py +0 -0
  269. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/app.py +0 -0
  270. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/shared.py +0 -0
  271. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/tools.py +0 -0
  272. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/hosted_app.py +0 -0
  273. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/inference/__init__.py +0 -0
  274. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/main.py +0 -0
  275. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/registry.py +0 -0
  276. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/storage/__init__.py +0 -0
  277. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/storage/volume.py +0 -0
  278. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/test_agents.py +0 -0
  279. /examples/{rl/task_app → task_apps/math}/README.md +0 -0
  280. /examples/{rl/task_app → task_apps/math}/math_task_app.py +0 -0
  281. /examples/{rl → workflows/math_rl}/configs/eval_base_qwen.toml +0 -0
  282. /examples/{rl → workflows/math_rl}/configs/eval_rl_qwen.toml +0 -0
  283. /examples/{rl → workflows/math_rl}/configs/rl_from_base_qwen.toml +0 -0
  284. /examples/{rl → workflows/math_rl}/configs/rl_from_base_qwen17.toml +0 -0
  285. /examples/{rl → workflows/math_rl}/configs/rl_from_ft_qwen.toml +0 -0
  286. /examples/{rl → workflows/math_rl}/run_eval.py +0 -0
  287. /examples/{rl → workflows/math_rl}/run_rl_and_save.py +0 -0
  288. {synth_ai-0.2.13.dev1.dist-info → synth_ai-0.2.14.dist-info}/WHEEL +0 -0
  289. {synth_ai-0.2.13.dev1.dist-info → synth_ai-0.2.14.dist-info}/entry_points.txt +0 -0
  290. {synth_ai-0.2.13.dev1.dist-info → synth_ai-0.2.14.dist-info}/licenses/LICENSE +0 -0
  291. {synth_ai-0.2.13.dev1.dist-info → synth_ai-0.2.14.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,799 @@
1
+ from __future__ import annotations
2
+
3
+ import logging
4
+ from typing import Any, Dict, Iterable, Mapping, Sequence
5
+
6
+ from fastapi import HTTPException, Request
7
+ import httpx
8
+
9
+ from synth_ai.environments.examples.red.environment import PokemonRedEnvironment
10
+ from synth_ai.environments.environment.tools import EnvToolCall
11
+ from synth_ai.environments.examples.red.taskset import INSTANCE as RED_DEFAULT_INSTANCE
12
+ from synth_ai.environments.examples.red.engine_helpers.reward_library.pallet_town_progression import (
13
+ PalletTownProgressionCompositeReward,
14
+ )
15
+ from synth_ai.task.apps import TaskAppEntry, register_task_app
16
+ from synth_ai.task.contracts import (
17
+ RolloutMetrics,
18
+ RolloutRequest,
19
+ RolloutResponse,
20
+ RolloutStep,
21
+ RolloutTrajectory,
22
+ TaskInfo,
23
+ )
24
+ from synth_ai.task.server import ProxyConfig, TaskAppConfig
25
+ from synth_ai.task.tracing_utils import (
26
+ build_tracer_factory,
27
+ resolve_sft_output_dir,
28
+ resolve_tracing_db_url,
29
+ tracing_env_enabled,
30
+ )
31
+ from synth_ai.tracing_v3.session_tracer import SessionTracer
32
+
33
+ logger = logging.getLogger(__name__)
34
+
35
+
36
def _base_task_info() -> TaskInfo:
    """Assemble the static ``TaskInfo`` descriptor for the Pokémon Red task.

    The descriptor advertises:

    * a ``tool_call`` action space with two tools (``press_button`` and
      ``execute_sequence``) and ``max_calls`` of 1,
    * the observation keys exposed to the policy,
    * dataset / rubric identifiers,
    * the proxy inference endpoints, and
    * a step limit of 1000.

    Returns:
        A freshly constructed ``TaskInfo``; nothing is cached between calls.
    """
    # Buttons accepted by the sequence tool's JSON schema.
    button_names = ["UP", "DOWN", "LEFT", "RIGHT", "A", "B", "START", "SELECT"]

    # Schema of one entry inside the ``actions`` array of ``execute_sequence``.
    sequence_item_schema = {
        "type": "object",
        "properties": {
            "button": {"type": "string", "enum": button_names},
            "frames": {"type": "integer", "minimum": 1, "maximum": 120},
        },
        "required": ["button", "frames"],
    }

    press_button_tool = {
        "name": "press_button",
        "schema": {"button": "string", "frames": "int"},
    }
    execute_sequence_tool = {
        "name": "execute_sequence",
        "description": "Execute multiple button presses in sequence. More efficient than separate calls. Recommended: 5-10 actions per call.",
        "schema": {
            "type": "object",
            "properties": {
                "actions": {
                    "type": "array",
                    "items": sequence_item_schema,
                    "minItems": 1,
                    "maxItems": 20,
                },
            },
            "required": ["actions"],
        },
    }

    observation_keys = [
        "position",
        "badges_earned",
        "badges_bitfield",
        "hp_status",
        "party_level",
        "party_xp",
        "in_battle",
        "step_count",
        "reward_last_step",
        "total_reward",
        "terminated",
    ]

    proxy_endpoints = {
        "openai": "/proxy/v1/chat/completions",
        "groq": "/proxy/groq/v1/chat/completions",
    }

    return TaskInfo(
        task={"id": "pokemon_red", "name": "Pokémon Red", "version": "0.1.0"},
        environment="pokemon_red",
        action_space={
            "type": "tool_call",
            "tools": [press_button_tool, execute_sequence_tool],
            "max_calls": 1,
        },
        observation={
            "summary": "GB memory-derived state with reward fields.",
            "keys": observation_keys,
        },
        dataset={"id": "pokemon_red_default", "name": "Pokémon Red Default", "version": "0.1.0"},
        rubric={"version": "1", "criteria_count": 1, "source": "inline"},
        inference={
            "supports_proxy": True,
            "tool": {"name": "press_button", "parallel_tool_calls": False},
            "endpoints": proxy_endpoints,
        },
        limits={"max_steps": 1000},
    )
101
+
102
+
103
+ def _describe_taskset() -> dict[str, Any]:
104
+ return {"id": "pokemon_red_default", "name": "Pokémon Red Default"}
105
+
106
+
107
def _provide_task_instances(seeds: Sequence[int]) -> Iterable[TaskInfo]:
    """Lazily yield one ``TaskInfo`` per seed.

    Every yielded instance reuses the fields of the base task descriptor; only
    ``observation`` differs, receiving an extra ``"seed"`` entry holding the
    seed the instance was generated for.

    Args:
        seeds: Seeds to generate instances for, in order.

    Yields:
        A ``TaskInfo`` for each element of ``seeds``.
    """
    template = _base_task_info()

    # Fields copied verbatim from the template into every instance.
    shared_fields = {
        "task": template.task,
        "environment": template.environment,
        "action_space": template.action_space,
        "dataset": template.dataset,
        "rubric": template.rubric,
        "inference": template.inference,
        "limits": template.limits,
    }

    for seed in seeds:
        seeded_observation = {**template.observation, "seed": seed}
        yield TaskInfo(observation=seeded_observation, **shared_fields)
120
+
121
+
122
+ def _build_action_context(prev_state: dict[str, Any], current_state: dict[str, Any]) -> dict[str, Any]:
123
+ """Build action context dict with prev_ fields for reward calculation."""
124
+ return {
125
+ "prev_map_id": prev_state.get("map_id", 0),
126
+ "prev_player_x": prev_state.get("player_x", 0),
127
+ "prev_player_y": prev_state.get("player_y", 0),
128
+ "prev_party_count": prev_state.get("party_count", 0),
129
+ "prev_in_battle": prev_state.get("in_battle", False),
130
+ "prev_text_box_active": prev_state.get("text_box_active", False),
131
+ "prev_enemy_hp_current": prev_state.get("enemy_hp_current", 0),
132
+ "prev_enemy_hp_percentage": prev_state.get("enemy_hp_percentage", 0.0),
133
+ "prev_badges": prev_state.get("badges", 0),
134
+ "prev_party_level": prev_state.get("party_level", 0),
135
+ "prev_party_xp": prev_state.get("party_xp", 0),
136
+ }
137
+
138
+
139
+ def _describe_milestone(current_state: dict[str, Any], prev_state: dict[str, Any], reward: float) -> str:
140
+ """Generate human-readable milestone description."""
141
+ descriptions = []
142
+
143
+ # Map transitions
144
+ prev_map = prev_state.get("map_id", -1)
145
+ curr_map = current_state.get("map_id", -1)
146
+ if prev_map != curr_map:
147
+ map_names = {0: "Pallet Town", 1: "Bedroom", 2: "House", 3: "Oak's Lab"}
148
+ descriptions.append(f"Moved from {map_names.get(prev_map, f'Map{prev_map}')} to {map_names.get(curr_map, f'Map{curr_map}')}")
149
+
150
+ # Party changes
151
+ prev_party = prev_state.get("party_count", 0)
152
+ curr_party = current_state.get("party_count", 0)
153
+ if curr_party > prev_party:
154
+ descriptions.append(f"Received Pokémon (party: {prev_party}→{curr_party})")
155
+
156
+ # Battle state
157
+ prev_battle = prev_state.get("in_battle", False)
158
+ curr_battle = current_state.get("in_battle", False)
159
+ if not prev_battle and curr_battle:
160
+ descriptions.append("Entered battle")
161
+ elif prev_battle and not curr_battle:
162
+ battle_outcome = current_state.get("battle_outcome", 0)
163
+ if battle_outcome == 1:
164
+ descriptions.append("Won battle")
165
+ elif battle_outcome == 2:
166
+ descriptions.append("Lost battle")
167
+
168
+ # HP damage
169
+ prev_enemy_hp = prev_state.get("enemy_hp_current", 0)
170
+ curr_enemy_hp = current_state.get("enemy_hp_current", 0)
171
+ if prev_enemy_hp > curr_enemy_hp > 0:
172
+ damage = prev_enemy_hp - curr_enemy_hp
173
+ descriptions.append(f"Dealt {damage} damage to enemy")
174
+
175
+ return " | ".join(descriptions) if descriptions else f"Progress (+{reward:.0f})"
176
+
177
+
178
+ def _calculate_outcome_score(final_state: dict[str, Any], total_reward: float) -> float:
179
+ """Calculate outcome score based on final state and total reward."""
180
+ # Normalize reward to 0-1 scale (max expected is ~700)
181
+ reward_score = min(total_reward / 700.0, 1.0)
182
+
183
+ # Bonus for having Pokemon
184
+ has_pokemon = 1.0 if final_state.get("party_count", 0) > 0 else 0.0
185
+
186
+ # Bonus for being in Oak's lab or having left it
187
+ map_id = final_state.get("map_id", -1)
188
+ map_bonus = 0.5 if map_id in [0, 3] else 0.0 # Pallet Town or Oak's Lab
189
+
190
+ # Weighted combination
191
+ return (reward_score * 0.7) + (has_pokemon * 0.2) + (map_bonus * 0.1)
192
+
193
+
194
+ async def rollout_executor(request: RolloutRequest, fastapi_request: Request) -> RolloutResponse:
195
+ # Initialize SessionTracer for this rollout
196
+ tracer_factory = getattr(fastapi_request.app.state, "session_tracer_factory", None)
197
+ tracer_instance: SessionTracer | None = None
198
+ if callable(tracer_factory):
199
+ try:
200
+ inst = tracer_factory()
201
+ tracer_instance = inst if isinstance(inst, SessionTracer) else None
202
+ except Exception as exc:
203
+ logger.debug(f"TRACER_FACTORY_FAIL: {exc}")
204
+
205
+ # Start tracing session
206
+ if tracer_instance is not None:
207
+ try:
208
+ await tracer_instance.initialize()
209
+ await tracer_instance.start_session(
210
+ session_id=request.run_id,
211
+ metadata={
212
+ "run_id": request.run_id,
213
+ "env_name": "pokemon_red",
214
+ "policy_name": request.policy.policy_name or "default",
215
+ "seed": request.env.seed,
216
+ }
217
+ )
218
+ logger.info(f"[pokemon_red] tracing enabled for run_id={request.run_id}")
219
+ except Exception as exc:
220
+ logger.warning(f"[pokemon_red] tracing init failed: {exc}")
221
+ tracer_instance = None
222
+
223
+ async def _call_inference(policy_cfg: Mapping[str, Any], observation: Mapping[str, Any]) -> Mapping[str, Any]:
224
+ # Check if vision mode is enabled
225
+ use_vision = bool(policy_cfg.get("use_vision", False))
226
+ image_only_mode = bool(policy_cfg.get("image_only_mode", False))
227
+
228
+ # Build user message content
229
+ if use_vision and "observation_image_data_url" in observation:
230
+ # Extract image data URL
231
+ image_data_url = observation["observation_image_data_url"]
232
+
233
+ # Build state summary (text observation)
234
+ state_summary = "State summary: " + str({
235
+ k: observation.get(k)
236
+ for k in observation.keys()
237
+ if k not in ["error", "observation_image_base64", "observation_image_data_url",
238
+ "observation_image_format", "observation_image_width", "observation_image_height"]
239
+ })
240
+
241
+ # Image-only mode: only send image, no text
242
+ if image_only_mode:
243
+ user_content = [
244
+ {"type": "image_url", "image_url": {"url": image_data_url}}
245
+ ]
246
+ else:
247
+ # Vision mode with text: send both text and image
248
+ user_content = [
249
+ {"type": "text", "text": state_summary},
250
+ {"type": "image_url", "image_url": {"url": image_data_url}}
251
+ ]
252
+ else:
253
+ # Text-only mode (default)
254
+ state_summary = "State summary: " + str({
255
+ k: observation.get(k) for k in observation.keys() if k != "error"
256
+ })
257
+ user_content = state_summary
258
+
259
+ messages = [
260
+ {
261
+ "role": "system",
262
+ "content": (
263
+ "You are controlling Pokémon Red. Respond with a single tool call named 'press_button' "
264
+ "with JSON arguments {button: 'A|B|UP|DOWN|LEFT|RIGHT|START|SELECT', frames: 1-120}."
265
+ ),
266
+ },
267
+ {
268
+ "role": "user",
269
+ "content": user_content,
270
+ },
271
+ ]
272
+ payload = {
273
+ "model": policy_cfg.get("model") or "qwen-2.5-7b",
274
+ "messages": messages,
275
+ "tools": [
276
+ {
277
+ "type": "function",
278
+ "function": {
279
+ "name": "execute_sequence",
280
+ "description": "Execute multiple button presses in sequence. More efficient than separate calls. Recommended: 5-10 actions per call.",
281
+ "parameters": {
282
+ "type": "object",
283
+ "properties": {
284
+ "actions": {
285
+ "type": "array",
286
+ "items": {
287
+ "type": "object",
288
+ "properties": {
289
+ "button": {
290
+ "type": "string",
291
+ "enum": ["UP", "DOWN", "LEFT", "RIGHT", "A", "B", "START", "SELECT"],
292
+ "description": "Game Boy button to press"
293
+ },
294
+ "frames": {
295
+ "type": "integer",
296
+ "minimum": 1,
297
+ "maximum": 120,
298
+ "description": "Number of frames to hold the button (30 frames = 0.5 seconds)"
299
+ }
300
+ },
301
+ "required": ["button", "frames"]
302
+ },
303
+ "minItems": 1,
304
+ "maxItems": 20,
305
+ "description": "Sequence of button presses to execute"
306
+ }
307
+ },
308
+ "required": ["actions"],
309
+ "additionalProperties": False,
310
+ },
311
+ },
312
+ },
313
+ {
314
+ "type": "function",
315
+ "function": {
316
+ "name": "press_button",
317
+ "description": "Press a single Game Boy button for N frames (use execute_sequence for multiple actions)",
318
+ "parameters": {
319
+ "type": "object",
320
+ "properties": {
321
+ "button": {"type": "string", "enum": ["UP", "DOWN", "LEFT", "RIGHT", "A", "B", "START", "SELECT"]},
322
+ "frames": {"type": "integer", "minimum": 1, "maximum": 120},
323
+ },
324
+ "required": ["button"],
325
+ "additionalProperties": False,
326
+ },
327
+ },
328
+ }
329
+ ],
330
+ "tool_choice": {"type": "function", "function": {"name": "execute_sequence"}},
331
+ "temperature": float(policy_cfg.get("temperature") or 0.0),
332
+ "top_p": float(policy_cfg.get("top_p") or 1.0),
333
+ "max_tokens": int(policy_cfg.get("max_tokens") or 500),
334
+ }
335
+ inference_url = str(policy_cfg.get("inference_url") or "").rstrip("/")
336
+
337
+ # Determine if this is an external URL or internal proxy
338
+ is_external = inference_url.startswith("http://") or inference_url.startswith("https://")
339
+
340
+ if not inference_url:
341
+ # Prefer built-in proxy endpoints from app if no external URL
342
+ provider = (policy_cfg.get("provider") or "").lower()
343
+ if provider == "groq":
344
+ inference_url = "/proxy/groq/v1/chat/completions"
345
+ else:
346
+ inference_url = "/proxy/v1/chat/completions"
347
+ is_external = False
348
+ elif is_external:
349
+ # Add /v1/chat/completions if using OpenAI directly
350
+ if "api.openai.com" in inference_url and not inference_url.endswith("/chat/completions"):
351
+ inference_url = inference_url + "/v1/chat/completions"
352
+
353
+ if is_external:
354
+ # External API: use direct HTTP client with auth header
355
+ headers = {}
356
+ if "api.openai.com" in inference_url:
357
+ import os
358
+ api_key = os.getenv("OPENAI_API_KEY")
359
+ if api_key:
360
+ headers["Authorization"] = f"Bearer {api_key}"
361
+
362
+ async with httpx.AsyncClient(timeout=httpx.Timeout(60.0)) as client:
363
+ resp = await client.post(inference_url, json=payload, headers=headers)
364
+ else:
365
+ # Internal proxy: use local base_url
366
+ async with httpx.AsyncClient(
367
+ base_url="http://127.0.0.1:" + str(fastapi_request.url.port or 8913),
368
+ timeout=httpx.Timeout(60.0)
369
+ ) as client:
370
+ resp = await client.post(inference_url, json=payload)
371
+
372
+ resp.raise_for_status()
373
+ data = resp.json()
374
+ # Extract first tool call
375
+ choices = data.get("choices") or []
376
+ if not choices:
377
+ return {}
378
+ message = choices[0].get("message") or {}
379
+ raw_calls = message.get("tool_calls") or []
380
+ if not raw_calls:
381
+ return {}
382
+ f = raw_calls[0].get("function") or {}
383
+ tool_name = f.get("name", "")
384
+ args = f.get("arguments")
385
+ import json as _json
386
+ try:
387
+ parsed_args = _json.loads(args) if isinstance(args, str) else dict(args or {})
388
+ except Exception:
389
+ parsed_args = {}
390
+
391
+ # Handle execute_sequence tool
392
+ if tool_name == "execute_sequence":
393
+ return {"actions": parsed_args.get("actions", [])}
394
+
395
+ # Handle press_button tool (legacy single action)
396
+ return {"button": parsed_args.get("button"), "frames": int(parsed_args.get("frames") or 30)}
397
+
398
+ # Initialize reward function
399
+ reward_fn = PalletTownProgressionCompositeReward()
400
+
401
+ env = PokemonRedEnvironment(RED_DEFAULT_INSTANCE)
402
+ obs0 = await env.initialize()
403
+
404
+ # Track cumulative stats
405
+ total_reward = 0.0
406
+ all_reward_components: list[dict[str, Any]] = []
407
+ milestone_events: list[dict[str, Any]] = []
408
+
409
+ steps: list[RolloutStep] = [
410
+ RolloutStep(obs=obs0, tool_calls=[], reward=0.0, done=False, info={"step_type": "initial"}),
411
+ ]
412
+
413
+ # Track previous state for reward calculation
414
+ prev_state = dict(obs0) if isinstance(obs0, Mapping) else {}
415
+
416
+ # Process all ops (explicit actions)
417
+ final_obs = obs0
418
+ for step_idx, op in enumerate(request.ops or []):
419
+ macro = None
420
+ if isinstance(op, dict):
421
+ macro = op.get("action") or op
422
+
423
+ if isinstance(macro, dict):
424
+ # Check if this is an execute_sequence call
425
+ if "actions" in macro:
426
+ # Handle execute_sequence: multiple actions in one call
427
+ actions_list = macro.get("actions", [])
428
+ sequence_reward = 0.0
429
+ sequence_tool_calls = []
430
+
431
+ for action_item in actions_list:
432
+ button = action_item.get("button", "A")
433
+ frames = int(action_item.get("frames", 1))
434
+
435
+ obs1 = await env.step(EnvToolCall(tool="press_button", args={"button": button, "frames": frames}))
436
+ current_state = dict(obs1) if isinstance(obs1, Mapping) else {}
437
+ action_context = _build_action_context(prev_state, current_state)
438
+ step_reward = await reward_fn.score(current_state, action_context)
439
+
440
+ sequence_reward += step_reward
441
+ sequence_tool_calls.append({"tool": "press_button", "args": {"button": button, "frames": frames}})
442
+
443
+ if step_reward > 0:
444
+ reward_component = {
445
+ "step": step_idx + 1,
446
+ "reward": step_reward,
447
+ "button": button,
448
+ "map_id": current_state.get("map_id"),
449
+ "position": f"({current_state.get('player_x')},{current_state.get('player_y')})",
450
+ }
451
+ all_reward_components.append(reward_component)
452
+ milestone_events.append({
453
+ "type": "milestone",
454
+ "step": step_idx + 1,
455
+ "reward": step_reward,
456
+ "description": _describe_milestone(current_state, prev_state, step_reward),
457
+ })
458
+
459
+ final_obs = obs1
460
+ prev_state = current_state
461
+
462
+ total_reward += sequence_reward
463
+ step_info = {
464
+ "step_type": "sequence",
465
+ "step_idx": step_idx,
466
+ "actions_count": len(actions_list),
467
+ "cumulative_reward": total_reward,
468
+ }
469
+ if sequence_reward > 0:
470
+ step_info["sequence_reward"] = sequence_reward
471
+
472
+ steps.append(
473
+ RolloutStep(
474
+ obs=final_obs,
475
+ tool_calls=sequence_tool_calls,
476
+ reward=sequence_reward,
477
+ done=False,
478
+ info=step_info,
479
+ )
480
+ )
481
+ else:
482
+ # Handle single press_button call
483
+ button = macro.get("button") or "A"
484
+ frames = int(macro.get("frames") or 1)
485
+ obs1 = await env.step(EnvToolCall(tool="press_button", args={"button": button, "frames": frames}))
486
+
487
+ # Calculate step reward
488
+ current_state = dict(obs1) if isinstance(obs1, Mapping) else {}
489
+ action_context = _build_action_context(prev_state, current_state)
490
+ step_reward = await reward_fn.score(current_state, action_context)
491
+ total_reward += step_reward
492
+
493
+ # Track reward components if non-zero
494
+ step_info: dict[str, Any] = {"step_type": "action", "step_idx": step_idx}
495
+ if step_reward > 0:
496
+ reward_component = {
497
+ "step": step_idx + 1,
498
+ "reward": step_reward,
499
+ "button": button,
500
+ "map_id": current_state.get("map_id"),
501
+ "position": f"({current_state.get('player_x')},{current_state.get('player_y')})",
502
+ }
503
+ all_reward_components.append(reward_component)
504
+ step_info["reward_component"] = reward_component
505
+
506
+ # Track milestone events
507
+ milestone_events.append({
508
+ "type": "milestone",
509
+ "step": step_idx + 1,
510
+ "reward": step_reward,
511
+ "description": _describe_milestone(current_state, prev_state, step_reward),
512
+ })
513
+
514
+ step_info["cumulative_reward"] = total_reward
515
+
516
+ steps.append(
517
+ RolloutStep(
518
+ obs=obs1,
519
+ tool_calls=[{"tool": "press_button", "args": {"button": button, "frames": frames}}],
520
+ reward=step_reward,
521
+ done=False,
522
+ info=step_info,
523
+ )
524
+ )
525
+ final_obs = obs1
526
+ prev_state = current_state
527
+ else:
528
+ # Attempt policy-driven step if policy.config present
529
+ policy_cfg = request.policy.config or {}
530
+ if policy_cfg:
531
+ try:
532
+ action = await _call_inference(policy_cfg, final_obs if isinstance(final_obs, Mapping) else {})
533
+
534
+ # Handle execute_sequence from policy
535
+ if "actions" in action:
536
+ actions_list = action.get("actions", [])
537
+ sequence_reward = 0.0
538
+ sequence_tool_calls = []
539
+
540
+ for action_item in actions_list:
541
+ button = action_item.get("button", "A")
542
+ frames = int(action_item.get("frames", 30))
543
+
544
+ obs1 = await env.step(EnvToolCall(tool="press_button", args={"button": button, "frames": frames}))
545
+ current_state = dict(obs1) if isinstance(obs1, Mapping) else {}
546
+ action_context = _build_action_context(prev_state, current_state)
547
+ step_reward = await reward_fn.score(current_state, action_context)
548
+
549
+ sequence_reward += step_reward
550
+ sequence_tool_calls.append({"tool": "press_button", "args": {"button": button, "frames": frames}})
551
+
552
+ if step_reward > 0:
553
+ reward_component = {
554
+ "step": step_idx + 1,
555
+ "reward": step_reward,
556
+ "button": button,
557
+ "map_id": current_state.get("map_id"),
558
+ "position": f"({current_state.get('player_x')},{current_state.get('player_y')})",
559
+ }
560
+ all_reward_components.append(reward_component)
561
+ milestone_events.append({
562
+ "type": "milestone",
563
+ "step": step_idx + 1,
564
+ "reward": step_reward,
565
+ "description": _describe_milestone(current_state, prev_state, step_reward),
566
+ })
567
+
568
+ final_obs = obs1
569
+ prev_state = current_state
570
+
571
+ total_reward += sequence_reward
572
+ step_info = {
573
+ "step_type": "policy_sequence",
574
+ "step_idx": step_idx,
575
+ "actions_count": len(actions_list),
576
+ "cumulative_reward": total_reward,
577
+ }
578
+ if sequence_reward > 0:
579
+ step_info["sequence_reward"] = sequence_reward
580
+
581
+ steps.append(
582
+ RolloutStep(
583
+ obs=final_obs,
584
+ tool_calls=sequence_tool_calls,
585
+ reward=sequence_reward,
586
+ done=False,
587
+ info=step_info,
588
+ )
589
+ )
590
+
591
+ # Handle single button press from policy
592
+ elif action.get("button"):
593
+ obs1 = await env.step(EnvToolCall(tool="press_button", args=action))
594
+
595
+ # Calculate step reward
596
+ current_state = dict(obs1) if isinstance(obs1, Mapping) else {}
597
+ action_context = _build_action_context(prev_state, current_state)
598
+ step_reward = await reward_fn.score(current_state, action_context)
599
+ total_reward += step_reward
600
+
601
+ step_info_policy: dict[str, Any] = {
602
+ "step_type": "policy",
603
+ "step_idx": step_idx,
604
+ "cumulative_reward": total_reward,
605
+ "proxy": True,
606
+ }
607
+ if step_reward > 0:
608
+ step_info_policy["reward_earned"] = step_reward
609
+
610
+ steps.append(
611
+ RolloutStep(
612
+ obs=obs1,
613
+ tool_calls=[{"tool": "press_button", "args": action}],
614
+ reward=step_reward,
615
+ done=False,
616
+ info=step_info_policy,
617
+ )
618
+ )
619
+ final_obs = obs1
620
+ prev_state = current_state
621
+ except Exception:
622
+ pass
623
+
624
+ # Calculate outcome score based on milestones achieved
625
+ final_state = dict(final_obs) if isinstance(final_obs, Mapping) else {}
626
+ outcome_score = _calculate_outcome_score(final_state, total_reward)
627
+
628
+ metrics = RolloutMetrics(
629
+ episode_returns=[total_reward],
630
+ mean_return=total_reward,
631
+ num_steps=len(steps),
632
+ num_episodes=1,
633
+ outcome_score=outcome_score,
634
+ details={
635
+ "total_reward": total_reward,
636
+ "reward_components": all_reward_components,
637
+ "milestone_events": milestone_events,
638
+ "final_map": final_state.get("map_id"),
639
+ "party_count": final_state.get("party_count", 0),
640
+ "badges": final_state.get("badges", 0),
641
+ },
642
+ )
643
+
644
+ # Extract inference_url from policy config
645
+ inference_url = (policy_cfg or {}).get("inference_url")
646
+
647
+ trajectory = RolloutTrajectory(
648
+ env_id="pokemon_red",
649
+ policy_id=request.policy.policy_id or "policy",
650
+ steps=steps,
651
+ final={"observation": final_obs, "reward": total_reward},
652
+ length=len(steps),
653
+ inference_url=inference_url, # NEW: Required for trace correlation
654
+ )
655
+
656
+ # Record outcome rewards and end session
657
+ trace_payload = None
658
+ if tracer_instance is not None:
659
+ try:
660
+ # Count achievements (milestones)
661
+ achievements_count = len(milestone_events)
662
+
663
+ # Build metadata with all relevant info
664
+ reward_metadata = {
665
+ "run_id": request.run_id,
666
+ "env_name": "pokemon_red",
667
+ "final_map": final_state.get("map_id", -1),
668
+ "party_count": final_state.get("party_count", 0),
669
+ "badges": final_state.get("badges", 0),
670
+ "steps": len(steps),
671
+ "milestone_events": milestone_events,
672
+ "reward_components": all_reward_components,
673
+ }
674
+
675
+ # Record outcome reward to Turso
676
+ await tracer_instance.record_outcome_reward(
677
+ total_reward=int(total_reward),
678
+ achievements_count=achievements_count,
679
+ total_steps=len(steps),
680
+ reward_metadata=reward_metadata,
681
+ )
682
+ logger.info(f"[pokemon_red] recorded outcome: reward={total_reward}, achievements={achievements_count}")
683
+
684
+ # End session and get trace
685
+ session_trace = await tracer_instance.end_session()
686
+
687
+ # Build trace payload if requested
688
+ record_config = getattr(request, 'record', None)
689
+ if record_config and getattr(record_config, 'return_trace', False) and session_trace:
690
+ trace_payload = {
691
+ "session_id": session_trace.session_id,
692
+ "created_at": session_trace.created_at.isoformat() if session_trace.created_at else None,
693
+ "metadata": dict(session_trace.metadata or {}),
694
+ "num_timesteps": session_trace.num_timesteps,
695
+ "num_events": session_trace.num_events,
696
+ "num_messages": session_trace.num_messages,
697
+ }
698
+ except Exception as exc:
699
+ logger.warning(f"[pokemon_red] tracing finalization failed: {exc}")
700
+
701
+ # Fallback trace payload if no tracer but CLI needs it
702
+ if trace_payload is None:
703
+ record_config = getattr(request, 'record', None)
704
+ if record_config and getattr(record_config, 'return_trace', False):
705
+ trace_payload = {
706
+ "session_id": request.run_id,
707
+ "created_at": import_datetime().now().isoformat(),
708
+ "metadata": {
709
+ "run_id": request.run_id,
710
+ "env_name": "pokemon_red",
711
+ "total_reward": int(total_reward),
712
+ "final_map": final_state.get("map_id", -1),
713
+ "party_count": final_state.get("party_count", 0),
714
+ "badges": final_state.get("badges", 0),
715
+ "steps": len(steps),
716
+ },
717
+ "num_timesteps": len(steps),
718
+ "num_events": len(steps),
719
+ "num_messages": len(steps) * 2,
720
+ }
721
+
722
+ return RolloutResponse(
723
+ run_id=request.run_id,
724
+ trajectories=[trajectory],
725
+ branches={},
726
+ metrics=metrics,
727
+ aborted=False,
728
+ ops_executed=len(request.ops or []),
729
+ trace=trace_payload,
730
+ )
731
+
732
+
733
def import_datetime():
    """Return the standard-library ``datetime`` class.

    The import happens inside the function, so the module is only loaded
    when a trace timestamp is actually needed.
    """
    import datetime as _datetime_module

    return _datetime_module.datetime
737
+
738
+
739
def build_config() -> TaskAppConfig:
    """Assemble the ``TaskAppConfig`` for the Pokémon Red demo task app.

    Gathers the base task info, the tracing settings and the optional SFT
    output directory, stores them in the shared ``app_state`` mapping, and
    returns the fully populated configuration object.

    Returns:
        TaskAppConfig: configuration with both proxies enabled, no API key
        requirement, the debug environment exposed and CORS open to any origin.
    """
    task_info = _base_task_info()

    # Tracing: whether it is on, which database to use, and the tracer factory.
    trace_on = tracing_env_enabled()
    trace_db = resolve_tracing_db_url()
    make_tracer = build_tracer_factory(
        SessionTracer,
        enabled=trace_on,
        db_url=trace_db,
    )
    sft_dir = resolve_sft_output_dir()

    # Optional entries are added only when a value is actually available.
    shared_state: dict[str, Any] = {"tracing_enabled": trace_on}
    if make_tracer is not None:
        shared_state["session_tracer_factory"] = make_tracer
    if sft_dir:
        shared_state["sft_output_dir"] = sft_dir

    if trace_on:
        # Reported through both the logger and stdout.
        banner = f"[task:tracing] enabled (db={trace_db or 'default'})"
        logger.info(banner)
        print(banner, flush=True)

    proxy_settings = ProxyConfig(
        enable_openai=True,
        enable_groq=True,
        system_hint=(
            "You control Pokémon Red. Use 'execute_sequence' with 5-10 actions to play efficiently. "
            "Plan ahead: navigate rooms, advance dialogue, battle strategically. "
            "Example: {\"tool\": \"execute_sequence\", \"args\": {\"actions\": [{\"button\": \"DOWN\", \"frames\": 30}, ...]}}"
        ),
    )

    return TaskAppConfig(
        app_id="pokemon_red",
        name="Pokémon Red Task App",
        description="Expose Pokémon Red via Synth task framework (demo).",
        base_task_info=task_info,
        describe_taskset=_describe_taskset,
        provide_task_instances=_provide_task_instances,
        rollout=rollout_executor,
        dataset_registry=None,
        proxy=proxy_settings,
        app_state=shared_state,
        require_api_key=False,
        expose_debug_env=True,
        cors_origins=["*"],
    )
786
+
787
+
788
# Registry entry for the Pokémon Red demo; ``build_config`` is passed as the
# config factory rather than called here.
_pokemon_red_entry = TaskAppEntry(
    app_id="pokemon_red",
    description="Pokémon Red demo task app",
    config_factory=build_config,
    aliases=("pokemon_red_demo",),
    env_files=(),
    modal=None,
)

register_task_app(entry=_pokemon_red_entry)
798
+
799
+