synth-ai 0.2.8.dev4__py3-none-any.whl → 0.2.23.dev3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (889)
  1. examples/README.md +1 -0
  2. examples/__init__.py +16 -0
  3. examples/analyze_semantic_words.sh +17 -0
  4. examples/baseline/banking77_baseline.py +243 -0
  5. examples/baseline/banking77_pipeline_baseline.py +294 -0
  6. examples/baseline/crafter_baseline.py +407 -0
  7. examples/baseline/pokemon_red_baseline.py +326 -0
  8. examples/baseline/simple_baseline.py +56 -0
  9. examples/baseline/warming_up_to_rl_baseline.py +239 -0
  10. examples/blog_posts/gepa/README.md +355 -0
  11. examples/blog_posts/gepa/configs/banking77_gepa_local.toml +95 -0
  12. examples/blog_posts/gepa/configs/banking77_gepa_test.toml +80 -0
  13. examples/blog_posts/gepa/configs/banking77_mipro_local.toml +50 -0
  14. examples/blog_posts/gepa/configs/banking77_pipeline_gepa_local.toml +101 -0
  15. examples/blog_posts/gepa/configs/banking77_pipeline_gepa_test.toml +96 -0
  16. examples/blog_posts/gepa/configs/hotpotqa_gepa_local.toml +57 -0
  17. examples/blog_posts/gepa/configs/hotpotqa_gepa_qwen.toml +35 -0
  18. examples/blog_posts/gepa/configs/hotpotqa_mipro_local.toml +51 -0
  19. examples/blog_posts/gepa/configs/hover_gepa_local.toml +57 -0
  20. examples/blog_posts/gepa/configs/hover_gepa_qwen.toml +35 -0
  21. examples/blog_posts/gepa/configs/hover_mipro_local.toml +51 -0
  22. examples/blog_posts/gepa/configs/ifbench_gepa_local.toml +57 -0
  23. examples/blog_posts/gepa/configs/ifbench_gepa_qwen.toml +35 -0
  24. examples/blog_posts/gepa/configs/ifbench_mipro_local.toml +51 -0
  25. examples/blog_posts/gepa/configs/pupa_gepa_local.toml +58 -0
  26. examples/blog_posts/gepa/configs/pupa_mipro_local.toml +52 -0
  27. examples/blog_posts/gepa/deploy_banking77_task_app.sh +54 -0
  28. examples/blog_posts/gepa/gepa_baseline.py +204 -0
  29. examples/blog_posts/gepa/query_prompts_example.py +97 -0
  30. examples/blog_posts/gepa/run_gepa_banking77.sh +112 -0
  31. examples/blog_posts/gepa/run_gepa_banking77_pipeline.sh +163 -0
  32. examples/blog_posts/gepa/task_apps.py +105 -0
  33. examples/blog_posts/gepa/test_gepa_local.sh +67 -0
  34. examples/blog_posts/gepa/verify_banking77_setup.sh +123 -0
  35. examples/blog_posts/mipro/README.md +415 -0
  36. examples/blog_posts/mipro/configs/banking77_mipro_local.toml +91 -0
  37. examples/blog_posts/mipro/configs/banking77_mipro_test.toml +87 -0
  38. examples/blog_posts/mipro/configs/banking77_pipeline_mipro_gemini_flash_lite_local.toml +98 -0
  39. examples/blog_posts/mipro/configs/banking77_pipeline_mipro_gpt41mini_local.toml +96 -0
  40. examples/blog_posts/mipro/configs/banking77_pipeline_mipro_local.toml +94 -0
  41. examples/blog_posts/mipro/configs/banking77_pipeline_mipro_test.toml +170 -0
  42. examples/blog_posts/mipro/deploy_banking77_pipeline_task_app.sh +59 -0
  43. examples/blog_posts/mipro/deploy_banking77_task_app.sh +41 -0
  44. examples/blog_posts/mipro/multi_step.md +79 -0
  45. examples/blog_posts/mipro/run_mipro_banking77.sh +191 -0
  46. examples/blog_posts/mipro/run_mipro_banking77_pipeline.sh +171 -0
  47. examples/blog_posts/mipro/run_mipro_banking77_pipeline_gemini_flash_lite.sh +177 -0
  48. examples/blog_posts/mipro/run_mipro_banking77_pipeline_gpt41mini.sh +173 -0
  49. examples/blog_posts/mipro/verify_banking77_setup.sh +117 -0
  50. examples/blog_posts/pokemon_vl/README.md +98 -0
  51. examples/blog_posts/pokemon_vl/configs/eval_gpt5nano.toml +26 -0
  52. examples/blog_posts/pokemon_vl/configs/eval_qwen3_vl.toml +27 -0
  53. examples/blog_posts/pokemon_vl/configs/eval_rl_final.toml +24 -0
  54. examples/blog_posts/pokemon_vl/configs/filter_high_reward.toml +10 -0
  55. examples/blog_posts/pokemon_vl/configs/train_rl_from_sft.toml +43 -0
  56. examples/blog_posts/pokemon_vl/configs/train_sft_qwen4b_vl.toml +40 -0
  57. examples/blog_posts/pokemon_vl/extract_images.py +239 -0
  58. examples/blog_posts/pokemon_vl/pokemon_vl_baseline.py +326 -0
  59. examples/blog_posts/pokemon_vl/run_eval_extract_images.py +209 -0
  60. examples/blog_posts/pokemon_vl/run_qwen_eval_extract_images.py +212 -0
  61. examples/blog_posts/pokemon_vl/text_box_analysis.md +106 -0
  62. examples/blog_posts/warming_up_to_rl/ARCHITECTURE.md +195 -0
  63. examples/blog_posts/warming_up_to_rl/FINAL_TEST_RESULTS.md +127 -0
  64. examples/blog_posts/warming_up_to_rl/INFERENCE_SUCCESS.md +132 -0
  65. examples/blog_posts/warming_up_to_rl/README.md +158 -0
  66. examples/blog_posts/warming_up_to_rl/SMOKE_TESTING.md +164 -0
  67. examples/blog_posts/warming_up_to_rl/SMOKE_TEST_COMPLETE.md +253 -0
  68. examples/blog_posts/warming_up_to_rl/configs/eval_baseline_qwen32b_10x20.toml +25 -0
  69. examples/blog_posts/warming_up_to_rl/configs/eval_ft_qwen4b.toml +25 -0
  70. examples/blog_posts/warming_up_to_rl/configs/eval_ft_qwen4b_10x20.toml +26 -0
  71. examples/blog_posts/warming_up_to_rl/configs/eval_groq_qwen32b.toml +25 -0
  72. examples/blog_posts/warming_up_to_rl/configs/eval_openai_gpt_oss_120b.toml +29 -0
  73. examples/blog_posts/warming_up_to_rl/configs/filter_high_reward_dataset.toml +10 -0
  74. examples/blog_posts/warming_up_to_rl/configs/smoke_test.toml +75 -0
  75. examples/blog_posts/warming_up_to_rl/configs/train_rl_from_sft.toml +91 -0
  76. examples/blog_posts/warming_up_to_rl/configs/train_sft_qwen4b.toml +40 -0
  77. examples/blog_posts/warming_up_to_rl/warming_up_to_rl_baseline.py +187 -0
  78. examples/crafter_debug_render.py +186 -0
  79. examples/dev/qwen3_32b_qlora_4xh100.toml +45 -0
  80. examples/gepa/banking77_pipeline_gepa.toml +96 -0
  81. examples/gepa/multi_stage_gepa_example.toml +84 -0
  82. examples/gepa/run_gepa_banking77_pipeline.sh +157 -0
  83. examples/multi_step/SFT_README.md +147 -0
  84. examples/multi_step/configs/README_verilog_rl.md +77 -0
  85. examples/multi_step/configs/VERILOG_REWARDS.md +103 -0
  86. examples/multi_step/configs/VERILOG_RL_CHECKLIST.md +196 -0
  87. examples/multi_step/configs/crafter_eval_synth_qwen4b.toml +35 -0
  88. examples/multi_step/configs/crafter_eval_text_only_groq_qwen32b.toml +36 -0
  89. examples/multi_step/configs/crafter_rl_outcome.toml +75 -0
  90. examples/multi_step/configs/crafter_rl_stepwise_hosted_judge.toml +145 -0
  91. examples/multi_step/configs/crafter_rl_stepwise_shaped.toml +84 -0
  92. examples/multi_step/configs/crafter_rl_stepwise_simple.toml +79 -0
  93. examples/multi_step/configs/crafter_rl_stepwise_simple_NEW_FORMAT.toml +105 -0
  94. examples/multi_step/configs/crafter_sft_qwen30b_lora.toml +62 -0
  95. examples/multi_step/configs/crafter_synth_backend.md +40 -0
  96. examples/multi_step/configs/verilog_eval_groq_qwen32b.toml +31 -0
  97. examples/multi_step/configs/verilog_eval_synth_qwen8b.toml +33 -0
  98. examples/multi_step/configs/verilog_rl_lora.toml +147 -0
  99. examples/multi_step/convert_traces_to_sft.py +84 -0
  100. examples/multi_step/crafter_rl_lora.md +70 -0
  101. examples/multi_step/judges/crafter_backend_judge.py +220 -0
  102. examples/multi_step/judges/verilog_backend_judge.py +234 -0
  103. examples/multi_step/readme.md +48 -0
  104. examples/multi_step/run_sft_qwen30b.sh +45 -0
  105. examples/multi_step/sse_metrics_streaming_notes.md +357 -0
  106. examples/multi_step/task_app_config_notes.md +494 -0
  107. examples/multi_step/verilog_rl_lora.md +218 -0
  108. examples/qwen_coder/README.md +102 -0
  109. examples/qwen_coder/_shared.py +113 -0
  110. examples/qwen_coder/configs/coder_lora_30b.toml +60 -0
  111. examples/qwen_coder/configs/coder_lora_4b.toml +61 -0
  112. examples/qwen_coder/configs/coder_lora_small.toml +57 -0
  113. examples/qwen_coder/generate_dataset.py +98 -0
  114. examples/qwen_coder/infer_ft_smoke.py +65 -0
  115. examples/qwen_coder/infer_prod_proxy.py +73 -0
  116. examples/qwen_coder/infer_via_synth.py +87 -0
  117. examples/qwen_coder/scripts/infer_coder.sh +19 -0
  118. examples/qwen_coder/scripts/train_coder_30b.sh +22 -0
  119. examples/qwen_coder/sft_full_17b.py +103 -0
  120. examples/qwen_coder/sft_lora_30b.py +110 -0
  121. examples/qwen_coder/subset_jsonl.py +39 -0
  122. examples/qwen_coder/todos.md +38 -0
  123. examples/qwen_coder/validate_jsonl.py +60 -0
  124. examples/qwen_vl/BUGS_AND_FIXES.md +232 -0
  125. examples/qwen_vl/IMAGE_VALIDATION_COMPLETE.md +271 -0
  126. examples/qwen_vl/IMAGE_VALIDATION_SUMMARY.md +260 -0
  127. examples/qwen_vl/INFERENCE_SFT_TESTS.md +412 -0
  128. examples/qwen_vl/NEXT_STEPS_2B.md +325 -0
  129. examples/qwen_vl/QUICKSTART.md +327 -0
  130. examples/qwen_vl/QUICKSTART_RL_VISION.md +110 -0
  131. examples/qwen_vl/README.md +152 -0
  132. examples/qwen_vl/RL_VISION_COMPLETE.md +475 -0
  133. examples/qwen_vl/RL_VISION_TESTING.md +333 -0
  134. examples/qwen_vl/SDK_VISION_INTEGRATION.md +328 -0
  135. examples/qwen_vl/SETUP_COMPLETE.md +274 -0
  136. examples/qwen_vl/VISION_TESTS_COMPLETE.md +489 -0
  137. examples/qwen_vl/VLM_PIPELINE_COMPLETE.md +242 -0
  138. examples/qwen_vl/__init__.py +2 -0
  139. examples/qwen_vl/collect_data_via_cli.md +415 -0
  140. examples/qwen_vl/collect_vision_traces.py +368 -0
  141. examples/qwen_vl/configs/crafter_rl_vision_qwen3vl4b.toml +110 -0
  142. examples/qwen_vl/configs/crafter_vlm_sft_example.toml +59 -0
  143. examples/qwen_vl/configs/eval_gpt4o_mini_vision.toml +26 -0
  144. examples/qwen_vl/configs/eval_gpt4o_vision_proper.toml +29 -0
  145. examples/qwen_vl/configs/eval_gpt5nano_vision.toml +26 -0
  146. examples/qwen_vl/configs/eval_qwen3vl_vision.toml +26 -0
  147. examples/qwen_vl/configs/filter_qwen3vl_sft.toml +49 -0
  148. examples/qwen_vl/configs/filter_vision_sft.toml +52 -0
  149. examples/qwen_vl/configs/filter_vision_test.toml +8 -0
  150. examples/qwen_vl/configs/sft_qwen3_vl_2b_test.toml +54 -0
  151. examples/qwen_vl/crafter_gpt5nano_agent.py +308 -0
  152. examples/qwen_vl/crafter_qwen_vl_agent.py +300 -0
  153. examples/qwen_vl/run_vision_comparison.sh +61 -0
  154. examples/qwen_vl/run_vision_sft_pipeline.sh +175 -0
  155. examples/qwen_vl/test_image_validation.py +201 -0
  156. examples/qwen_vl/test_sft_vision_data.py +110 -0
  157. examples/rl/README.md +169 -0
  158. examples/rl/configs/eval_base_qwen.toml +17 -0
  159. examples/rl/configs/eval_rl_qwen.toml +13 -0
  160. examples/rl/configs/rl_from_base_qwen.toml +62 -0
  161. examples/rl/configs/rl_from_base_qwen17.toml +80 -0
  162. examples/rl/configs/rl_from_ft_qwen.toml +37 -0
  163. examples/rl/download_dataset.py +80 -0
  164. examples/rl/run_eval.py +436 -0
  165. examples/rl/run_rl_and_save.py +111 -0
  166. examples/rl/task_app/README.md +21 -0
  167. examples/rl/task_app/math_single_step.py +990 -0
  168. examples/rl/task_app/math_task_app.py +111 -0
  169. examples/run_crafter_demo.sh +10 -0
  170. examples/sdk_prompt_learning_example.py +55 -0
  171. examples/sft/README.md +139 -0
  172. examples/sft/configs/crafter_fft_qwen0p6b.toml +49 -0
  173. examples/sft/configs/crafter_lora_qwen0p6b.toml +49 -0
  174. examples/sft/evaluate.py +117 -0
  175. examples/sft/export_dataset.py +120 -0
  176. examples/sft/generate_traces.py +164 -0
  177. examples/swe/__init__.py +12 -0
  178. examples/swe/task_app/README.md +135 -0
  179. examples/swe/task_app/__init__.py +2 -0
  180. examples/swe/task_app/grpo_swe_mini.py +604 -0
  181. examples/swe/task_app/grpo_swe_mini_task_app.py +124 -0
  182. examples/swe/task_app/hosted/README.md +173 -0
  183. examples/swe/task_app/hosted/__init__.py +5 -0
  184. examples/swe/task_app/hosted/branching.py +143 -0
  185. examples/swe/task_app/hosted/environment_routes.py +1289 -0
  186. examples/swe/task_app/hosted/envs/__init__.py +1 -0
  187. examples/swe/task_app/hosted/envs/crafter/__init__.py +6 -0
  188. examples/swe/task_app/hosted/envs/crafter/app.py +1 -0
  189. examples/swe/task_app/hosted/envs/crafter/environment.py +522 -0
  190. examples/swe/task_app/hosted/envs/crafter/policy.py +478 -0
  191. examples/swe/task_app/hosted/envs/crafter/react_agent.py +108 -0
  192. examples/swe/task_app/hosted/envs/crafter/shared.py +305 -0
  193. examples/swe/task_app/hosted/envs/crafter/tools.py +47 -0
  194. examples/swe/task_app/hosted/envs/mini_swe/__init__.py +8 -0
  195. examples/swe/task_app/hosted/envs/mini_swe/environment.py +1191 -0
  196. examples/swe/task_app/hosted/envs/mini_swe/policy.py +355 -0
  197. examples/swe/task_app/hosted/envs/mini_swe/shared.py +83 -0
  198. examples/swe/task_app/hosted/envs/mini_swe/tools.py +96 -0
  199. examples/swe/task_app/hosted/hosted_app.py +204 -0
  200. examples/swe/task_app/hosted/inference/__init__.py +5 -0
  201. examples/swe/task_app/hosted/inference/openai_client.py +584 -0
  202. examples/swe/task_app/hosted/main.py +100 -0
  203. examples/swe/task_app/hosted/policy_routes.py +1094 -0
  204. examples/swe/task_app/hosted/registry.py +195 -0
  205. examples/swe/task_app/hosted/rollout.py +1905 -0
  206. examples/swe/task_app/hosted/storage/__init__.py +5 -0
  207. examples/swe/task_app/hosted/storage/volume.py +211 -0
  208. examples/swe/task_app/hosted/test_agents.py +161 -0
  209. examples/swe/task_app/hosted/test_service.py +136 -0
  210. examples/swe/task_app/hosted/utils.py +62 -0
  211. examples/swe/task_app/morph_backend.py +178 -0
  212. examples/task_apps/IMAGE_ONLY_EVAL_QUICKSTART.md +258 -0
  213. examples/task_apps/TESTING.md +275 -0
  214. examples/task_apps/banking77/__init__.py +6 -0
  215. examples/task_apps/banking77/banking77_task_app.py +912 -0
  216. examples/task_apps/banking77/deploy_wrapper.py +46 -0
  217. examples/task_apps/banking77_pipeline/__init__.py +6 -0
  218. examples/task_apps/banking77_pipeline/banking77_pipeline_task_app.py +489 -0
  219. examples/task_apps/banking77_pipeline/deploy_wrapper.py +50 -0
  220. examples/task_apps/crafter/CREATE_SFT_DATASET.md +286 -0
  221. examples/task_apps/crafter/EVAL_IMAGE_ONLY_RESULTS.md +152 -0
  222. examples/task_apps/crafter/FILTER_COMMAND_STATUS.md +187 -0
  223. examples/task_apps/crafter/FILTER_COMMAND_SUCCESS.md +281 -0
  224. examples/task_apps/crafter/QUERY_EXAMPLES.md +203 -0
  225. examples/task_apps/crafter/README_IMAGE_ONLY_EVAL.md +316 -0
  226. examples/task_apps/crafter/eval_image_only_gpt4o.toml +28 -0
  227. examples/task_apps/crafter/eval_text_only_groq_llama.toml +36 -0
  228. examples/task_apps/crafter/filter_sft_dataset.toml +16 -0
  229. examples/task_apps/crafter/task_app/README.md +42 -0
  230. examples/task_apps/crafter/task_app/__init__.py +5 -0
  231. examples/task_apps/crafter/task_app/grpo_crafter.py +1055 -0
  232. examples/task_apps/crafter/task_app/grpo_crafter_task_app.py +146 -0
  233. examples/task_apps/crafter/task_app/synth_envs_hosted/README.md +173 -0
  234. examples/task_apps/crafter/task_app/synth_envs_hosted/__init__.py +5 -0
  235. examples/task_apps/crafter/task_app/synth_envs_hosted/branching.py +143 -0
  236. examples/task_apps/crafter/task_app/synth_envs_hosted/environment_routes.py +1226 -0
  237. examples/task_apps/crafter/task_app/synth_envs_hosted/envs/__init__.py +1 -0
  238. examples/task_apps/crafter/task_app/synth_envs_hosted/envs/crafter/__init__.py +6 -0
  239. examples/task_apps/crafter/task_app/synth_envs_hosted/envs/crafter/app.py +1 -0
  240. examples/task_apps/crafter/task_app/synth_envs_hosted/envs/crafter/environment.py +532 -0
  241. examples/task_apps/crafter/task_app/synth_envs_hosted/envs/crafter/policy.py +583 -0
  242. examples/task_apps/crafter/task_app/synth_envs_hosted/envs/crafter/react_agent.py +122 -0
  243. examples/task_apps/crafter/task_app/synth_envs_hosted/envs/crafter/shared.py +305 -0
  244. examples/task_apps/crafter/task_app/synth_envs_hosted/envs/crafter/tools.py +47 -0
  245. examples/task_apps/crafter/task_app/synth_envs_hosted/hosted_app.py +253 -0
  246. examples/task_apps/crafter/task_app/synth_envs_hosted/inference/__init__.py +5 -0
  247. examples/task_apps/crafter/task_app/synth_envs_hosted/inference/openai_client.py +999 -0
  248. examples/task_apps/crafter/task_app/synth_envs_hosted/main.py +100 -0
  249. examples/task_apps/crafter/task_app/synth_envs_hosted/policy_routes.py +1252 -0
  250. examples/task_apps/crafter/task_app/synth_envs_hosted/registry.py +195 -0
  251. examples/task_apps/crafter/task_app/synth_envs_hosted/rollout.py +2233 -0
  252. examples/task_apps/crafter/task_app/synth_envs_hosted/storage/__init__.py +5 -0
  253. examples/task_apps/crafter/task_app/synth_envs_hosted/storage/volume.py +211 -0
  254. examples/task_apps/crafter/task_app/synth_envs_hosted/test_agents.py +161 -0
  255. examples/task_apps/crafter/task_app/synth_envs_hosted/test_service.py +136 -0
  256. examples/task_apps/crafter/task_app/synth_envs_hosted/utils.py +411 -0
  257. examples/task_apps/dev/pokemon_emerald/__init__.py +2 -0
  258. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/README.md +811 -0
  259. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/__init__.py +120 -0
  260. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/action.py +160 -0
  261. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/memory.py +155 -0
  262. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/perception.py +69 -0
  263. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/planning.py +96 -0
  264. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/simple.py +1502 -0
  265. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/system_prompt.py +4 -0
  266. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/grab_map.py +68 -0
  267. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/manual.py +216 -0
  268. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/__init__.py +35 -0
  269. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/emerald_utils.py +631 -0
  270. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/emulator.py +1544 -0
  271. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/enums.py +1428 -0
  272. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/memory_reader.py +4848 -0
  273. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/types.py +41 -0
  274. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/utils.py +298 -0
  275. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pyproject.toml +95 -0
  276. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/run.py +204 -0
  277. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/server/app.py +2152 -0
  278. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/server/client.py +429 -0
  279. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/server/frame_server.py +155 -0
  280. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/README.md +78 -0
  281. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/run_tests.py +122 -0
  282. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_agent_direct.py +76 -0
  283. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_agent_prompts.py +413 -0
  284. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_battle_state_formatting.py +204 -0
  285. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_dialogue_detection.py +133 -0
  286. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_dialogue_detection_comprehensive.py +229 -0
  287. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_direct_agent_emulator.py +300 -0
  288. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_fps_adjustment_pytest.py +205 -0
  289. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_house_to_outside_direct.py +200 -0
  290. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_house_to_outside_transition.py +284 -0
  291. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_map_ground_truth_comparison.py +468 -0
  292. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_memory_map.py +575 -0
  293. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_server_map_validation.py +311 -0
  294. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_torchic_state.py +259 -0
  295. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/anticheat.py +372 -0
  296. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/checkpoint.py +296 -0
  297. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/error_handler.py +275 -0
  298. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/get_local_ip.py +22 -0
  299. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/helpers.py +44 -0
  300. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/llm_logger.py +514 -0
  301. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/map_formatter.py +415 -0
  302. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/map_stitcher.py +1763 -0
  303. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/map_stitcher_singleton.py +33 -0
  304. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/map_trimmer.py +106 -0
  305. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/map_visualizer.py +334 -0
  306. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/ocr_dialogue.py +1020 -0
  307. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/recording.py +188 -0
  308. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/state_formatter.py +1481 -0
  309. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/vlm.py +862 -0
  310. examples/task_apps/dev/pokemon_emerald/modal_app.py +114 -0
  311. examples/task_apps/dev/pokemon_emerald/task_app/README.md +81 -0
  312. examples/task_apps/dev/pokemon_emerald/task_app/__init__.py +6 -0
  313. examples/task_apps/dev/pokemon_emerald/task_app/pokemon_emerald.py +685 -0
  314. examples/task_apps/enron/__init__.py +2 -0
  315. examples/task_apps/enron/eval_groq_qwen32.toml +16 -0
  316. examples/task_apps/enron/filter_sft.toml +5 -0
  317. examples/task_apps/enron/task_app/README.md +14 -0
  318. examples/task_apps/enron/task_app/__init__.py +1 -0
  319. examples/task_apps/enron/task_app/grpo_enron.py +906 -0
  320. examples/task_apps/enron/task_app/grpo_enron_task_app.py +146 -0
  321. examples/task_apps/enron/tests/__init__.py +4 -0
  322. examples/task_apps/enron/tests/conftest.py +115 -0
  323. examples/task_apps/enron/tests/integration/__init__.py +4 -0
  324. examples/task_apps/enron/tests/integration/test_enron_eval.py +179 -0
  325. examples/task_apps/enron/tests/integration/test_enron_rollout.py +135 -0
  326. examples/task_apps/enron/tests/unit/__init__.py +4 -0
  327. examples/task_apps/enron/tests/unit/test_enron_environment.py +126 -0
  328. examples/task_apps/gepa_benchmarks/__init__.py +7 -0
  329. examples/task_apps/gepa_benchmarks/common.py +260 -0
  330. examples/task_apps/gepa_benchmarks/hotpotqa_task_app.py +507 -0
  331. examples/task_apps/gepa_benchmarks/hover_task_app.py +436 -0
  332. examples/task_apps/gepa_benchmarks/ifbench_task_app.py +563 -0
  333. examples/task_apps/gepa_benchmarks/pupa_task_app.py +460 -0
  334. examples/task_apps/math/README.md +21 -0
  335. examples/task_apps/math/math_single_step.py +1000 -0
  336. examples/task_apps/math/math_task_app.py +115 -0
  337. examples/task_apps/pokemon_battle/__init__.py +2 -0
  338. examples/task_apps/pokemon_battle/modal_app.py +104 -0
  339. examples/task_apps/pokemon_battle/task_app/README.md +68 -0
  340. examples/task_apps/pokemon_battle/task_app/__init__.py +6 -0
  341. examples/task_apps/pokemon_battle/task_app/pokemon_showdown.py +932 -0
  342. examples/task_apps/pokemon_red/EVAL_IMAGE_ONLY_COMPLETE.md +283 -0
  343. examples/task_apps/pokemon_red/EVAL_IMAGE_ONLY_STATUS.md +155 -0
  344. examples/task_apps/pokemon_red/README.md +356 -0
  345. examples/task_apps/pokemon_red/README_IMAGE_ONLY_EVAL.md +428 -0
  346. examples/task_apps/pokemon_red/__init__.py +3 -0
  347. examples/task_apps/pokemon_red/eval_image_only_gpt4o.toml +30 -0
  348. examples/task_apps/pokemon_red/eval_pokemon_red_policy.py +224 -0
  349. examples/task_apps/pokemon_red/pallet_town_rl_config.toml +75 -0
  350. examples/task_apps/pokemon_red/task_app.py +1048 -0
  351. examples/task_apps/pokemon_red/test_pallet_town_rewards.py +193 -0
  352. examples/task_apps/sokoban/README.md +306 -0
  353. examples/task_apps/sokoban/__init__.py +3 -0
  354. examples/task_apps/sokoban/eval_groq_qwen32.toml +16 -0
  355. examples/task_apps/sokoban/eval_openai_gpt5.toml +16 -0
  356. examples/task_apps/sokoban/filter_sft.toml +5 -0
  357. examples/task_apps/sokoban/task_app.py +1058 -0
  358. examples/task_apps/sokoban/tests/__init__.py +4 -0
  359. examples/task_apps/sokoban/tests/conftest.py +113 -0
  360. examples/task_apps/sokoban/tests/integration/__init__.py +4 -0
  361. examples/task_apps/sokoban/tests/integration/test_sokoban_eval.py +57 -0
  362. examples/task_apps/sokoban/tests/integration/test_sokoban_rollout.py +198 -0
  363. examples/task_apps/sokoban/tests/unit/__init__.py +4 -0
  364. examples/task_apps/sokoban/tests/unit/test_sokoban_environment.py +114 -0
  365. examples/task_apps/verilog/__init__.py +1 -0
  366. examples/task_apps/verilog/eval_groq_qwen32b.toml +22 -0
  367. examples/task_apps/verilog/filter_sft.toml +5 -0
  368. examples/task_apps/verilog/task_app/README.md +12 -0
  369. examples/task_apps/verilog/task_app/__init__.py +1 -0
  370. examples/task_apps/verilog/task_app/grpo_verilog.py +1166 -0
  371. examples/task_apps/verilog/task_app/grpo_verilog_task_app.py +145 -0
  372. examples/task_apps/verilog/tests/__init__.py +4 -0
  373. examples/task_apps/verilog/tests/conftest.py +115 -0
  374. examples/task_apps/verilog/tests/integration/__init__.py +4 -0
  375. examples/task_apps/verilog/tests/integration/test_verilog_eval.py +181 -0
  376. examples/task_apps/verilog/tests/integration/test_verilog_rollout.py +55 -0
  377. examples/task_apps/verilog/tests/unit/__init__.py +4 -0
  378. examples/task_apps/verilog/tests/unit/test_verilog_scoring.py +118 -0
  379. examples/tunnel_gepa_banking77/README.md +106 -0
  380. examples/tunnel_gepa_banking77/banking77_gepa_tunnel.toml +95 -0
  381. examples/tunnel_gepa_banking77/keep_tunnel_running.py +60 -0
  382. examples/tunnel_gepa_banking77/run_gepa_with_tunnel.sh +226 -0
  383. examples/vlm/PROPOSAL.md +53 -0
  384. examples/vlm/README.md +68 -0
  385. examples/vlm/configs/crafter_vlm_gpt4o.toml +49 -0
  386. examples/vlm/crafter_image_only_agent.py +207 -0
  387. examples/vlm/crafter_openai_vlm_agent.py +275 -0
  388. examples/vlm/filter_image_rows.py +63 -0
  389. examples/vlm/run_crafter_vlm_benchmark.py +316 -0
  390. examples/warming_up_to_rl/_utils.py +92 -0
  391. examples/warming_up_to_rl/analyze_trace_db.py +422 -0
  392. examples/warming_up_to_rl/configs/crafter_fft.toml +53 -0
  393. examples/warming_up_to_rl/configs/crafter_fft_4b.toml +54 -0
  394. examples/warming_up_to_rl/configs/eval_fft_qwen4b.toml +22 -0
  395. examples/warming_up_to_rl/configs/eval_groq_qwen32b.toml +15 -0
  396. examples/warming_up_to_rl/configs/eval_modal_qwen4b.toml +24 -0
  397. examples/warming_up_to_rl/configs/eval_stepwise_complex.toml +35 -0
  398. examples/warming_up_to_rl/configs/eval_stepwise_consistent.toml +26 -0
  399. examples/warming_up_to_rl/configs/eval_stepwise_per_achievement.toml +36 -0
  400. examples/warming_up_to_rl/configs/eval_stepwise_simple.toml +32 -0
  401. examples/warming_up_to_rl/configs/rl_from_base_qwen4b.toml +85 -0
  402. examples/warming_up_to_rl/configs/rl_from_ft.toml +58 -0
  403. examples/warming_up_to_rl/export_trace_sft.py +837 -0
  404. examples/warming_up_to_rl/groq_test.py +97 -0
  405. examples/warming_up_to_rl/manage_secrets.py +131 -0
  406. examples/warming_up_to_rl/old/event_rewards.md +234 -0
  407. examples/warming_up_to_rl/old/notes.md +73 -0
  408. examples/warming_up_to_rl/readme.md +110 -0
  409. examples/warming_up_to_rl/run_eval.py +736 -0
  410. examples/warming_up_to_rl/run_fft_and_save.py +380 -0
  411. examples/warming_up_to_rl/run_local_rollout.py +239 -0
  412. examples/warming_up_to_rl/run_local_rollout_modal.py +248 -0
  413. examples/warming_up_to_rl/run_local_rollout_parallel.py +405 -0
  414. examples/warming_up_to_rl/run_local_rollout_traced.py +477 -0
  415. examples/warming_up_to_rl/run_rl_and_save.py +124 -0
  416. examples/warming_up_to_rl/run_rollout_remote.py +156 -0
  417. examples/warming_up_to_rl/task_app/README.md +42 -0
  418. examples/warming_up_to_rl/task_app/grpo_crafter.py +876 -0
  419. examples/warming_up_to_rl/task_app/grpo_crafter_task_app.py +135 -0
  420. examples/warming_up_to_rl/task_app/synth_envs_hosted/README.md +173 -0
  421. examples/warming_up_to_rl/task_app/synth_envs_hosted/__init__.py +5 -0
  422. examples/warming_up_to_rl/task_app/synth_envs_hosted/branching.py +143 -0
  423. examples/warming_up_to_rl/task_app/synth_envs_hosted/environment_routes.py +1226 -0
  424. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/__init__.py +1 -0
  425. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/__init__.py +6 -0
  426. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/app.py +1 -0
  427. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/environment.py +522 -0
  428. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/policy.py +454 -0
  429. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/react_agent.py +108 -0
  430. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/shared.py +305 -0
  431. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/tools.py +47 -0
  432. examples/warming_up_to_rl/task_app/synth_envs_hosted/hosted_app.py +253 -0
  433. examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/__init__.py +5 -0
  434. examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/openai_client.py +729 -0
  435. examples/warming_up_to_rl/task_app/synth_envs_hosted/main.py +100 -0
  436. examples/warming_up_to_rl/task_app/synth_envs_hosted/policy_routes.py +1114 -0
  437. examples/warming_up_to_rl/task_app/synth_envs_hosted/registry.py +195 -0
  438. examples/warming_up_to_rl/task_app/synth_envs_hosted/rollout.py +1891 -0
  439. examples/warming_up_to_rl/task_app/synth_envs_hosted/storage/__init__.py +5 -0
  440. examples/warming_up_to_rl/task_app/synth_envs_hosted/storage/volume.py +211 -0
  441. examples/warming_up_to_rl/task_app/synth_envs_hosted/test_agents.py +161 -0
  442. examples/warming_up_to_rl/task_app/synth_envs_hosted/test_service.py +137 -0
  443. examples/warming_up_to_rl/task_app/synth_envs_hosted/utils.py +129 -0
  444. examples/workflows/math_rl/configs/eval_base_qwen.toml +15 -0
  445. examples/workflows/math_rl/configs/eval_rl_qwen.toml +11 -0
  446. examples/workflows/math_rl/configs/rl_from_base_qwen.toml +62 -0
  447. examples/workflows/math_rl/configs/rl_from_base_qwen17.toml +80 -0
  448. examples/workflows/math_rl/configs/rl_from_ft_qwen.toml +35 -0
  449. examples/workflows/math_rl/download_dataset.py +80 -0
  450. examples/workflows/math_rl/run_eval.py +436 -0
  451. examples/workflows/math_rl/run_rl_and_save.py +111 -0
  452. synth_ai/__init__.py +47 -23
  453. synth_ai/_utils/__init__.py +47 -0
  454. synth_ai/_utils/base_url.py +10 -0
  455. synth_ai/_utils/http.py +10 -0
  456. synth_ai/_utils/prompts.py +10 -0
  457. synth_ai/_utils/task_app_state.py +12 -0
  458. synth_ai/_utils/user_config.py +10 -0
  459. synth_ai/api/models/supported.py +514 -0
  460. synth_ai/api/train/__init__.py +63 -0
  461. synth_ai/api/train/builders.py +473 -0
  462. synth_ai/api/train/cli.py +1185 -0
  463. synth_ai/api/train/config_finder.py +246 -0
  464. synth_ai/api/train/configs/__init__.py +65 -0
  465. synth_ai/api/train/configs/prompt_learning.py +496 -0
  466. synth_ai/api/train/configs/rl.py +188 -0
  467. synth_ai/api/train/configs/sft.py +99 -0
  468. synth_ai/api/train/configs/shared.py +81 -0
  469. synth_ai/api/train/env_resolver.py +352 -0
  470. synth_ai/api/train/pollers.py +91 -0
  471. synth_ai/api/train/prompt_learning.py +425 -0
  472. synth_ai/api/train/sft.py +390 -0
  473. synth_ai/api/train/supported_algos.py +147 -0
  474. synth_ai/api/train/task_app.py +195 -0
  475. synth_ai/api/train/utils.py +244 -0
  476. synth_ai/api/train/validators.py +1117 -0
  477. synth_ai/api/tunnel.py +49 -0
  478. synth_ai/auth/credentials.py +94 -0
  479. synth_ai/baseline/__init__.py +25 -0
  480. synth_ai/baseline/config.py +209 -0
  481. synth_ai/baseline/discovery.py +214 -0
  482. synth_ai/baseline/execution.py +146 -0
  483. synth_ai/cfgs.py +227 -0
  484. synth_ai/cli/__init__.py +90 -45
  485. synth_ai/cli/_modal_wrapper.py +31 -0
  486. synth_ai/cli/_storage.py +20 -0
  487. synth_ai/cli/_typer_patch.py +47 -0
  488. synth_ai/cli/_validate_task_app.py +29 -0
  489. synth_ai/cli/balance.py +16 -4
  490. synth_ai/cli/calc.py +36 -21
  491. synth_ai/cli/claude.py +70 -0
  492. synth_ai/cli/codex.py +267 -0
  493. synth_ai/cli/commands/__init__.py +18 -0
  494. synth_ai/cli/commands/baseline/__init__.py +12 -0
  495. synth_ai/cli/commands/baseline/core.py +637 -0
  496. synth_ai/cli/commands/baseline/list.py +93 -0
  497. synth_ai/cli/commands/demo/__init__.py +6 -0
  498. synth_ai/cli/commands/demo/core.py +163 -0
  499. synth_ai/cli/commands/eval/__init__.py +19 -0
  500. synth_ai/cli/commands/eval/core.py +1112 -0
  501. synth_ai/cli/commands/eval/errors.py +81 -0
  502. synth_ai/cli/commands/eval/validation.py +133 -0
  503. synth_ai/cli/commands/filter/__init__.py +12 -0
  504. synth_ai/cli/commands/filter/core.py +424 -0
  505. synth_ai/cli/commands/filter/errors.py +55 -0
  506. synth_ai/cli/commands/filter/validation.py +77 -0
  507. synth_ai/cli/commands/help/__init__.py +185 -0
  508. synth_ai/cli/commands/help/core.py +72 -0
  509. synth_ai/cli/commands/smoke/__init__.py +7 -0
  510. synth_ai/cli/commands/smoke/core.py +1437 -0
  511. synth_ai/cli/commands/status/__init__.py +66 -0
  512. synth_ai/cli/commands/status/client.py +192 -0
  513. synth_ai/cli/commands/status/config.py +92 -0
  514. synth_ai/cli/commands/status/errors.py +20 -0
  515. synth_ai/cli/commands/status/formatters.py +164 -0
  516. synth_ai/cli/commands/status/subcommands/__init__.py +9 -0
  517. synth_ai/cli/commands/status/subcommands/files.py +79 -0
  518. synth_ai/cli/commands/status/subcommands/jobs.py +334 -0
  519. synth_ai/cli/commands/status/subcommands/models.py +79 -0
  520. synth_ai/cli/commands/status/subcommands/pricing.py +22 -0
  521. synth_ai/cli/commands/status/subcommands/runs.py +81 -0
  522. synth_ai/cli/commands/status/subcommands/session.py +183 -0
  523. synth_ai/cli/commands/status/subcommands/summary.py +47 -0
  524. synth_ai/cli/commands/status/subcommands/usage.py +203 -0
  525. synth_ai/cli/commands/status/utils.py +114 -0
  526. synth_ai/cli/commands/train/__init__.py +53 -0
  527. synth_ai/cli/commands/train/core.py +21 -0
  528. synth_ai/cli/commands/train/errors.py +117 -0
  529. synth_ai/cli/commands/train/judge_schemas.py +200 -0
  530. synth_ai/cli/commands/train/judge_validation.py +305 -0
  531. synth_ai/cli/commands/train/validation.py +386 -0
  532. synth_ai/cli/demo.py +32 -140
  533. synth_ai/cli/deploy.py +233 -0
  534. synth_ai/cli/eval/__init__.py +36 -0
  535. synth_ai/cli/eval/core.py +5 -0
  536. synth_ai/cli/eval/errors.py +31 -0
  537. synth_ai/cli/eval/validation.py +5 -0
  538. synth_ai/cli/filter/__init__.py +28 -0
  539. synth_ai/cli/filter/core.py +5 -0
  540. synth_ai/cli/filter/errors.py +23 -0
  541. synth_ai/cli/filter/validation.py +5 -0
  542. synth_ai/cli/legacy_root_backup.py +28 -22
  543. synth_ai/cli/lib/__init__.py +10 -0
  544. synth_ai/cli/lib/task_app_discovery.py +7 -0
  545. synth_ai/cli/lib/task_app_env.py +518 -0
  546. synth_ai/cli/mcp.py +34 -0
  547. synth_ai/cli/modal_serve/__init__.py +12 -0
  548. synth_ai/cli/modal_serve/core.py +14 -0
  549. synth_ai/cli/modal_serve/errors.py +8 -0
  550. synth_ai/cli/modal_serve/validation.py +11 -0
  551. synth_ai/cli/opencode.py +256 -0
  552. synth_ai/cli/recent.py +13 -7
  553. synth_ai/cli/rl_demo.py +166 -114
  554. synth_ai/cli/root.py +143 -112
  555. synth_ai/cli/serve/__init__.py +12 -0
  556. synth_ai/cli/serve/core.py +14 -0
  557. synth_ai/cli/serve/errors.py +8 -0
  558. synth_ai/cli/serve/validation.py +11 -0
  559. synth_ai/cli/setup.py +49 -0
  560. synth_ai/cli/status.py +7 -125
  561. synth_ai/cli/task_app_deploy.py +7 -0
  562. synth_ai/cli/task_app_list.py +25 -0
  563. synth_ai/cli/task_app_modal_serve.py +11 -0
  564. synth_ai/cli/task_app_serve.py +11 -0
  565. synth_ai/cli/task_apps.py +3134 -0
  566. synth_ai/cli/traces.py +9 -5
  567. synth_ai/cli/train/__init__.py +12 -0
  568. synth_ai/cli/train/core.py +21 -0
  569. synth_ai/cli/train/errors.py +8 -0
  570. synth_ai/cli/train/validation.py +24 -0
  571. synth_ai/cli/train.py +5 -0
  572. synth_ai/cli/turso.py +73 -0
  573. synth_ai/cli/watch.py +13 -18
  574. synth_ai/demos/__init__.py +10 -0
  575. synth_ai/demos/core/__init__.py +28 -1
  576. synth_ai/demos/core/cli.py +745 -416
  577. synth_ai/demos/crafter/__init__.py +1 -0
  578. synth_ai/demos/crafter/crafter_fft_4b.toml +55 -0
  579. synth_ai/demos/crafter/grpo_crafter_task_app.py +185 -0
  580. synth_ai/demos/crafter/rl_from_base_qwen4b.toml +74 -0
  581. synth_ai/demos/demo_registry.py +176 -0
  582. synth_ai/demos/demo_task_apps/__init__.py +7 -1
  583. synth_ai/demos/demo_task_apps/core.py +75 -37
  584. synth_ai/demos/demo_task_apps/crafter/__init__.py +1 -0
  585. synth_ai/demos/demo_task_apps/crafter/configs/crafter_fft_4b.toml +53 -0
  586. synth_ai/demos/demo_task_apps/crafter/configs/rl_from_base_qwen4b.toml +73 -0
  587. synth_ai/demos/demo_task_apps/crafter/grpo_crafter_task_app.py +184 -0
  588. synth_ai/demos/demo_task_apps/math/_common.py +1 -2
  589. synth_ai/demos/demo_task_apps/math/app.py +2 -1
  590. synth_ai/demos/demo_task_apps/math/config.toml +55 -110
  591. synth_ai/demos/demo_task_apps/math/deploy_modal.py +3 -6
  592. synth_ai/demos/demo_task_apps/math/modal_task_app.py +491 -166
  593. synth_ai/demos/demo_task_apps/math/task_app_entry.py +37 -0
  594. synth_ai/demos/math/__init__.py +1 -0
  595. synth_ai/demos/math/_common.py +16 -0
  596. synth_ai/demos/math/app.py +38 -0
  597. synth_ai/demos/math/config.toml +76 -0
  598. synth_ai/demos/math/deploy_modal.py +54 -0
  599. synth_ai/demos/math/modal_task_app.py +703 -0
  600. synth_ai/demos/math/task_app_entry.py +51 -0
  601. synth_ai/environments/environment/core.py +7 -1
  602. synth_ai/environments/examples/bandit/engine.py +12 -5
  603. synth_ai/environments/examples/bandit/environment.py +0 -1
  604. synth_ai/environments/examples/bandit/taskset.py +4 -4
  605. synth_ai/environments/examples/crafter_classic/engine_deterministic_patch.py +7 -4
  606. synth_ai/environments/examples/crafter_classic/engine_serialization_patch_v3.py +9 -5
  607. synth_ai/environments/examples/crafter_classic/environment.py +93 -2
  608. synth_ai/environments/examples/crafter_classic/world_config_patch_simple.py +4 -3
  609. synth_ai/environments/examples/enron/engine.py +7 -2
  610. synth_ai/environments/examples/enron/environment.py +68 -0
  611. synth_ai/environments/examples/red/engine.py +60 -12
  612. synth_ai/environments/examples/red/engine_helpers/memory_map.py +7 -0
  613. synth_ai/environments/examples/red/engine_helpers/reward_components.py +151 -179
  614. synth_ai/environments/examples/red/engine_helpers/reward_library/pallet_town_progression.py +477 -0
  615. synth_ai/environments/examples/red/engine_helpers/state_extraction.py +32 -0
  616. synth_ai/environments/examples/red/environment.py +86 -0
  617. synth_ai/environments/examples/red/trace_hooks_v3.py +168 -0
  618. synth_ai/environments/examples/sokoban/taskset.py +116 -0
  619. synth_ai/environments/examples/verilog/engine.py +104 -12
  620. synth_ai/environments/examples/wordle/environment.py +0 -1
  621. synth_ai/environments/reproducibility/tree.py +5 -6
  622. synth_ai/environments/service/app.py +11 -12
  623. synth_ai/environments/service/core_routes.py +10 -9
  624. synth_ai/environments/stateful/engine.py +1 -1
  625. synth_ai/environments/tasks/core.py +1 -0
  626. synth_ai/environments/tasks/filters.py +5 -6
  627. synth_ai/environments/tasks/utils.py +4 -5
  628. synth_ai/evals/__init__.py +15 -0
  629. synth_ai/evals/base.py +14 -5
  630. synth_ai/evals/client.py +82 -0
  631. synth_ai/evals/types.py +42 -0
  632. synth_ai/http.py +8 -22
  633. synth_ai/http_client.py +45 -12
  634. synth_ai/inference/__init__.py +0 -2
  635. synth_ai/inference/client.py +21 -7
  636. synth_ai/jobs/client.py +129 -80
  637. synth_ai/judge_schemas.py +127 -0
  638. synth_ai/learning/__init__.py +51 -6
  639. synth_ai/learning/algorithms.py +14 -0
  640. synth_ai/learning/client.py +122 -30
  641. synth_ai/learning/config.py +2 -40
  642. synth_ai/learning/constants.py +0 -2
  643. synth_ai/learning/ft_client.py +4 -56
  644. synth_ai/learning/health.py +14 -8
  645. synth_ai/learning/jobs.py +43 -47
  646. synth_ai/learning/prompt_learning_client.py +276 -0
  647. synth_ai/learning/prompt_learning_types.py +185 -0
  648. synth_ai/{rl → learning/rl}/__init__.py +14 -5
  649. synth_ai/learning/rl/client.py +269 -0
  650. synth_ai/learning/rl/config.py +31 -0
  651. synth_ai/{rl → learning/rl}/contracts.py +5 -10
  652. synth_ai/{rl → learning/rl}/env_keys.py +45 -16
  653. synth_ai/learning/rl/secrets.py +13 -0
  654. synth_ai/learning/rl_client.py +2 -253
  655. synth_ai/learning/sft/__init__.py +29 -0
  656. synth_ai/learning/sft/client.py +68 -0
  657. synth_ai/learning/sft/config.py +270 -0
  658. synth_ai/learning/sft/data.py +698 -0
  659. synth_ai/learning/sse.py +25 -26
  660. synth_ai/learning/validators.py +29 -25
  661. synth_ai/mcp/__init__.py +5 -0
  662. synth_ai/mcp/__main__.py +8 -0
  663. synth_ai/mcp/main.py +254 -0
  664. synth_ai/mcp/setup.py +100 -0
  665. synth_ai/modal.py +257 -0
  666. synth_ai/pricing/__init__.py +3 -0
  667. synth_ai/pricing/model_pricing.py +64 -0
  668. synth_ai/session/__init__.py +75 -0
  669. synth_ai/session/client.py +383 -0
  670. synth_ai/session/constants.py +63 -0
  671. synth_ai/session/exceptions.py +105 -0
  672. synth_ai/session/manager.py +139 -0
  673. synth_ai/session/models.py +89 -0
  674. synth_ai/session/query.py +110 -0
  675. synth_ai/spec/__init__.py +46 -0
  676. synth_ai/spec/dataclasses.py +149 -0
  677. synth_ai/spec/loader.py +144 -0
  678. synth_ai/spec/serializer.py +199 -0
  679. synth_ai/spec/validation.py +250 -0
  680. synth_ai/streaming/__init__.py +29 -0
  681. synth_ai/streaming/config.py +94 -0
  682. synth_ai/streaming/handlers.py +589 -0
  683. synth_ai/streaming/streamer.py +320 -0
  684. synth_ai/streaming/types.py +95 -0
  685. synth_ai/task/__init__.py +116 -3
  686. synth_ai/task/apps/__init__.py +132 -0
  687. synth_ai/task/auth.py +165 -0
  688. synth_ai/task/client.py +167 -0
  689. synth_ai/task/config.py +261 -0
  690. synth_ai/task/contracts.py +173 -57
  691. synth_ai/task/datasets.py +108 -0
  692. synth_ai/task/errors.py +50 -0
  693. synth_ai/task/health.py +17 -11
  694. synth_ai/task/inference_api.py +101 -0
  695. synth_ai/task/json.py +111 -0
  696. synth_ai/task/proxy.py +251 -0
  697. synth_ai/task/rubrics/__init__.py +55 -0
  698. synth_ai/task/rubrics/loaders.py +156 -0
  699. synth_ai/task/rubrics/models.py +57 -0
  700. synth_ai/task/rubrics/scoring.py +116 -0
  701. synth_ai/task/rubrics/strict.py +149 -0
  702. synth_ai/task/rubrics.py +219 -0
  703. synth_ai/task/server.py +432 -0
  704. synth_ai/task/trace_correlation_helpers.py +328 -0
  705. synth_ai/task/tracing_utils.py +95 -0
  706. synth_ai/task/validators.py +449 -6
  707. synth_ai/task/vendors.py +59 -0
  708. synth_ai/tracing_v3/__init__.py +4 -0
  709. synth_ai/tracing_v3/abstractions.py +21 -4
  710. synth_ai/tracing_v3/config.py +167 -22
  711. synth_ai/tracing_v3/constants.py +21 -0
  712. synth_ai/tracing_v3/db_config.py +42 -29
  713. synth_ai/tracing_v3/decorators.py +80 -45
  714. synth_ai/tracing_v3/examples/basic_usage.py +15 -9
  715. synth_ai/tracing_v3/hooks.py +6 -4
  716. synth_ai/tracing_v3/llm_call_record_helpers.py +161 -61
  717. synth_ai/tracing_v3/migration_helper.py +1 -2
  718. synth_ai/tracing_v3/replica_sync.py +12 -7
  719. synth_ai/tracing_v3/serialization.py +130 -0
  720. synth_ai/tracing_v3/session_tracer.py +86 -21
  721. synth_ai/tracing_v3/storage/base.py +98 -12
  722. synth_ai/tracing_v3/storage/config.py +63 -16
  723. synth_ai/tracing_v3/storage/factory.py +11 -9
  724. synth_ai/tracing_v3/storage/utils.py +15 -11
  725. synth_ai/tracing_v3/trace_utils.py +317 -0
  726. synth_ai/tracing_v3/turso/__init__.py +8 -21
  727. synth_ai/tracing_v3/turso/daemon.py +123 -15
  728. synth_ai/tracing_v3/turso/models.py +5 -2
  729. synth_ai/tracing_v3/turso/native_manager.py +1293 -0
  730. synth_ai/tracing_v3/utils.py +5 -4
  731. synth_ai/tunnel.py +143 -0
  732. synth_ai/tunnel_deploy.py +278 -0
  733. synth_ai/types.py +8 -0
  734. synth_ai/urls.py +11 -0
  735. synth_ai/utils/__init__.py +166 -0
  736. synth_ai/utils/agents.py +74 -0
  737. synth_ai/utils/apps.py +152 -0
  738. synth_ai/utils/base_url.py +94 -0
  739. synth_ai/utils/bin.py +39 -0
  740. synth_ai/utils/claude.py +36 -0
  741. synth_ai/utils/cli.py +284 -0
  742. synth_ai/utils/config.py +81 -0
  743. synth_ai/utils/env.py +346 -0
  744. synth_ai/utils/errors.py +85 -0
  745. synth_ai/utils/http.py +172 -0
  746. synth_ai/utils/json.py +72 -0
  747. synth_ai/utils/log_filter.py +99 -0
  748. synth_ai/utils/logging.py +198 -0
  749. synth_ai/utils/modal.py +299 -0
  750. synth_ai/utils/paths.py +95 -0
  751. synth_ai/utils/process.py +233 -0
  752. synth_ai/utils/prompts.py +39 -0
  753. synth_ai/utils/sqld.py +122 -0
  754. synth_ai/utils/ssl.py +25 -0
  755. synth_ai/utils/task_app_discovery.py +882 -0
  756. synth_ai/utils/task_app_env.py +186 -0
  757. synth_ai/utils/task_app_state.py +318 -0
  758. synth_ai/utils/tunnel/__init__.py +12 -0
  759. synth_ai/utils/tunnel/config.py +55 -0
  760. synth_ai/utils/user_config.py +137 -0
  761. synth_ai/uvicorn.py +77 -0
  762. synth_ai-0.2.23.dev3.dist-info/METADATA +357 -0
  763. synth_ai-0.2.23.dev3.dist-info/RECORD +983 -0
  764. {synth_ai-0.2.8.dev4.dist-info → synth_ai-0.2.23.dev3.dist-info}/entry_points.txt +0 -1
  765. {synth_ai-0.2.8.dev4.dist-info → synth_ai-0.2.23.dev3.dist-info}/top_level.txt +1 -0
  766. synth_ai/cli/man.py +0 -106
  767. synth_ai/core/experiment.py +0 -15
  768. synth_ai/core/system.py +0 -15
  769. synth_ai/environments/examples/sokoban/units/astar_common.py +0 -95
  770. synth_ai/experimental/synth_oss.py +0 -446
  771. synth_ai/handshake.py +0 -63
  772. synth_ai/install_sqld.sh +0 -40
  773. synth_ai/learning/offline/dpo.py +0 -0
  774. synth_ai/learning/offline/providers.py +0 -7
  775. synth_ai/learning/offline/sft.py +0 -0
  776. synth_ai/learning/offline/shared.py +0 -0
  777. synth_ai/learning/online/grpo.py +0 -0
  778. synth_ai/learning/online/irft.py +0 -0
  779. synth_ai/learning/prompts/banking77_injection_eval.py +0 -168
  780. synth_ai/learning/prompts/gepa.py +0 -0
  781. synth_ai/learning/prompts/hello_world_in_context_injection_ex.py +0 -213
  782. synth_ai/learning/prompts/mipro.py +0 -289
  783. synth_ai/learning/prompts/random_search.py +0 -246
  784. synth_ai/learning/prompts/run_mipro_banking77.py +0 -172
  785. synth_ai/learning/prompts/run_random_search_banking77.py +0 -324
  786. synth_ai/lm/__init__.py +0 -51
  787. synth_ai/lm/caching/constants.py +0 -6
  788. synth_ai/lm/caching/dbs.py +0 -0
  789. synth_ai/lm/caching/ephemeral.py +0 -102
  790. synth_ai/lm/caching/handler.py +0 -137
  791. synth_ai/lm/caching/initialize.py +0 -11
  792. synth_ai/lm/caching/persistent.py +0 -114
  793. synth_ai/lm/config.py +0 -110
  794. synth_ai/lm/constants.py +0 -32
  795. synth_ai/lm/core/__init__.py +0 -8
  796. synth_ai/lm/core/all.py +0 -73
  797. synth_ai/lm/core/exceptions.py +0 -7
  798. synth_ai/lm/core/main.py +0 -319
  799. synth_ai/lm/core/main_v3.py +0 -594
  800. synth_ai/lm/core/synth_models.py +0 -48
  801. synth_ai/lm/core/vendor_clients.py +0 -188
  802. synth_ai/lm/cost/monitor.py +0 -1
  803. synth_ai/lm/cost/statefulness.py +0 -1
  804. synth_ai/lm/injection.py +0 -80
  805. synth_ai/lm/overrides.py +0 -206
  806. synth_ai/lm/provider_support/__init__.py +0 -8
  807. synth_ai/lm/provider_support/anthropic.py +0 -972
  808. synth_ai/lm/provider_support/openai.py +0 -1139
  809. synth_ai/lm/provider_support/suppress_logging.py +0 -31
  810. synth_ai/lm/structured_outputs/handler.py +0 -440
  811. synth_ai/lm/structured_outputs/inject.py +0 -297
  812. synth_ai/lm/structured_outputs/rehabilitate.py +0 -185
  813. synth_ai/lm/tools/__init__.py +0 -3
  814. synth_ai/lm/tools/base.py +0 -172
  815. synth_ai/lm/unified_interface.py +0 -202
  816. synth_ai/lm/vendors/base.py +0 -81
  817. synth_ai/lm/vendors/core/anthropic_api.py +0 -387
  818. synth_ai/lm/vendors/core/gemini_api.py +0 -292
  819. synth_ai/lm/vendors/core/mistral_api.py +0 -322
  820. synth_ai/lm/vendors/core/openai_api.py +0 -225
  821. synth_ai/lm/vendors/core/synth_dev_api.py +0 -0
  822. synth_ai/lm/vendors/local/ollama.py +0 -0
  823. synth_ai/lm/vendors/openai_standard.py +0 -780
  824. synth_ai/lm/vendors/openai_standard_responses.py +0 -256
  825. synth_ai/lm/vendors/retries.py +0 -22
  826. synth_ai/lm/vendors/supported/custom_endpoint.py +0 -417
  827. synth_ai/lm/vendors/supported/deepseek.py +0 -69
  828. synth_ai/lm/vendors/supported/grok.py +0 -75
  829. synth_ai/lm/vendors/supported/groq.py +0 -16
  830. synth_ai/lm/vendors/supported/ollama.py +0 -15
  831. synth_ai/lm/vendors/supported/openrouter.py +0 -74
  832. synth_ai/lm/vendors/supported/together.py +0 -11
  833. synth_ai/lm/vendors/synth_client.py +0 -808
  834. synth_ai/lm/warmup.py +0 -186
  835. synth_ai/rl/secrets.py +0 -19
  836. synth_ai/scripts/verify_rewards.py +0 -100
  837. synth_ai/tracing/__init__.py +0 -30
  838. synth_ai/tracing_v1/__init__.py +0 -33
  839. synth_ai/tracing_v3/turso/manager.py +0 -760
  840. synth_ai/v0/tracing/abstractions.py +0 -224
  841. synth_ai/v0/tracing/base_client.py +0 -91
  842. synth_ai/v0/tracing/client_manager.py +0 -131
  843. synth_ai/v0/tracing/config.py +0 -142
  844. synth_ai/v0/tracing/context.py +0 -146
  845. synth_ai/v0/tracing/decorators.py +0 -682
  846. synth_ai/v0/tracing/events/__init__.py +0 -0
  847. synth_ai/v0/tracing/events/manage.py +0 -147
  848. synth_ai/v0/tracing/events/scope.py +0 -86
  849. synth_ai/v0/tracing/events/store.py +0 -228
  850. synth_ai/v0/tracing/immediate_client.py +0 -151
  851. synth_ai/v0/tracing/local.py +0 -18
  852. synth_ai/v0/tracing/log_client_base.py +0 -73
  853. synth_ai/v0/tracing/retry_queue.py +0 -186
  854. synth_ai/v0/tracing/trackers.py +0 -515
  855. synth_ai/v0/tracing/upload.py +0 -512
  856. synth_ai/v0/tracing/utils.py +0 -9
  857. synth_ai/v0/tracing_v1/__init__.py +0 -16
  858. synth_ai/v0/tracing_v1/abstractions.py +0 -224
  859. synth_ai/v0/tracing_v1/base_client.py +0 -91
  860. synth_ai/v0/tracing_v1/client_manager.py +0 -131
  861. synth_ai/v0/tracing_v1/config.py +0 -142
  862. synth_ai/v0/tracing_v1/context.py +0 -146
  863. synth_ai/v0/tracing_v1/decorators.py +0 -703
  864. synth_ai/v0/tracing_v1/events/__init__.py +0 -0
  865. synth_ai/v0/tracing_v1/events/manage.py +0 -147
  866. synth_ai/v0/tracing_v1/events/scope.py +0 -86
  867. synth_ai/v0/tracing_v1/events/store.py +0 -228
  868. synth_ai/v0/tracing_v1/immediate_client.py +0 -151
  869. synth_ai/v0/tracing_v1/local.py +0 -18
  870. synth_ai/v0/tracing_v1/log_client_base.py +0 -73
  871. synth_ai/v0/tracing_v1/retry_queue.py +0 -186
  872. synth_ai/v0/tracing_v1/trackers.py +0 -515
  873. synth_ai/v0/tracing_v1/upload.py +0 -527
  874. synth_ai/v0/tracing_v1/utils.py +0 -9
  875. synth_ai/zyk/__init__.py +0 -30
  876. synth_ai-0.2.8.dev4.dist-info/METADATA +0 -129
  877. synth_ai-0.2.8.dev4.dist-info/RECORD +0 -420
  878. {synth_ai/lm/caching → examples/task_apps}/__init__.py +0 -0
  879. {synth_ai/lm/cost → examples/task_apps/crafter}/__init__.py +0 -0
  880. {synth_ai/lm/structured_outputs → examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/server}/__init__.py +0 -0
  881. {synth_ai/lm/vendors → examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests}/__init__.py +0 -0
  882. {synth_ai/lm/vendors/core → examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils}/__init__.py +0 -0
  883. {synth_ai/lm/vendors/local → examples/task_apps/math}/__init__.py +0 -0
  884. {synth_ai/lm/vendors/supported → examples/workflows}/__init__.py +0 -0
  885. {synth_ai/v0/tracing → examples/workflows/math_rl}/__init__.py +0 -0
  886. /synth_ai/{compound/cais.py → cli/__main__.py} +0 -0
  887. /synth_ai/{learning/filtering.py → py.typed} +0 -0
  888. {synth_ai-0.2.8.dev4.dist-info → synth_ai-0.2.23.dev3.dist-info}/WHEEL +0 -0
  889. {synth_ai-0.2.8.dev4.dist-info → synth_ai-0.2.23.dev3.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,1891 @@
1
+ from __future__ import annotations
2
+
3
+ import contextlib
4
+ import json
5
+ import logging
6
+ import os
7
+ import time as _time
8
+ from datetime import datetime
9
+ from typing import Any
10
+
11
+ from fastapi import APIRouter, HTTPException, Request, status
12
+ from pydantic import BaseModel
13
+ from synth_ai.tracing_v3 import BaseLMResponse
14
+ from synth_ai.task.tracing_utils import unique_sft_path
15
+ from synth_ai.tracing_v3.abstractions import EnvironmentEvent, LMCAISEvent, TimeRecord
16
+ from synth_ai.tracing_v3.llm_call_record_helpers import create_llm_call_record_from_response
17
+ from synth_ai.tracing_v3.session_tracer import SessionTracer
18
+
19
+ from .registry import registry
20
+
21
+ logger = logging.getLogger(__name__)
22
+
23
+
24
+ # --- Seeding utilities (robust, optional deps) ---
25
+ def _set_global_seed(seed_value: int) -> dict[str, Any]:
26
+ """Set global RNG seeds across common libraries; return details for logging/restoration.
27
+
28
+ Returns a dict containing which libraries were seeded and prior states if obtainable.
29
+ """
30
+ seeded: dict[str, Any] = {"seed": int(seed_value), "libs": []}
31
+ with contextlib.suppress(Exception):
32
+ import random as _random # type: ignore
33
+
34
+ _random.seed(seed_value)
35
+ seeded["libs"].append("random")
36
+ with contextlib.suppress(Exception):
37
+ import numpy as _np # type: ignore
38
+
39
+ _np.random.seed(seed_value)
40
+ seeded["libs"].append("numpy")
41
+ with contextlib.suppress(Exception):
42
+ import torch as _torch # type: ignore
43
+
44
+ if hasattr(_torch, "manual_seed"):
45
+ _torch.manual_seed(seed_value)
46
+ seeded["libs"].append("torch")
47
+ # Make CUDA deterministic if present (best-effort)
48
+ with contextlib.suppress(Exception):
49
+ if getattr(_torch, "cuda", None) and _torch.cuda.is_available():
50
+ _torch.cuda.manual_seed_all(seed_value)
51
+ seeded.setdefault("cuda", True)
52
+ # CUDNN deterministic flags (optional)
53
+ with contextlib.suppress(Exception):
54
+ if getattr(_torch, "backends", None) and getattr(_torch.backends, "cudnn", None):
55
+ _torch.backends.cudnn.deterministic = True # type: ignore[attr-defined]
56
+ _torch.backends.cudnn.benchmark = False # type: ignore[attr-defined]
57
+ return seeded
58
+
59
+
60
+ def _clear_seed_side_effects() -> None:
61
+ """Best-effort cleanup to avoid global deterministic side-effects between requests."""
62
+ # We cannot truly restore prior RNG states without capturing them; we just avoid
63
+ # leaving aggressive deterministic flags enabled where it matters.
64
+ with contextlib.suppress(Exception):
65
+ import torch as _torch # type: ignore
66
+
67
+ with contextlib.suppress(Exception):
68
+ if getattr(_torch, "backends", None) and getattr(_torch.backends, "cudnn", None):
69
+ # Re-enable cudnn.benchmark default True only if it was True; safest is False -> leave as is.
70
+ # We'll keep deterministic False to avoid global impact; benchmark left False for stability.
71
+ _torch.backends.cudnn.deterministic = False # type: ignore[attr-defined]
72
+
73
+
74
# Module-level FastAPI router. NOTE(review): the rollout endpoints are presumably
# attached to it further down this file — confirm.
router = APIRouter()
75
+
76
+
77
# Environment selection for a rollout request. RolloutTracingContext copies
# env_name, env_id and seed into its tracing metadata.
class RolloutEnvSpec(BaseModel):
    # NOTE(review): how env_id and env_name interact (e.g. reuse vs create) is
    # decided by the rollout handler, not shown here — confirm.
    env_id: str | None = None
    env_name: str | None = None
    # Free-form environment options. pydantic copies mutable defaults per
    # instance, so this `{}` is presumably not shared — confirm for the pinned version.
    config: dict[str, Any] = {}
    seed: int | None = None
82
+
83
+
84
# Policy selection for a rollout request. RolloutTracingContext records
# policy_name (defaulting to "") and policy_id in its tracing metadata.
class RolloutPolicySpec(BaseModel):
    policy_id: str | None = None
    policy_name: str | None = None
    # Free-form policy options; presumably forwarded to the policy — confirm.
    config: dict[str, Any] = {}
88
+
89
+
90
# Optional branching settings for a rollout.
# NOTE(review): the all-zero/False defaults presumably mean "no branching"; the
# enforcing logic is not in this chunk — confirm against the rollout handler.
class RolloutBranchConfig(BaseModel):
    branch_every_n_steps: int = 0
    branch_on_condition: str | None = None
    max_branches: int = 0
    branch_policy: bool = False
    branch_env: bool = False
96
+
97
+
98
# What a rollout should record and return.
class RolloutRecordConfig(BaseModel):
    trajectories: bool = True
    # Presumably control whether RolloutStep.logprob / RolloutStep.value are
    # populated — TODO confirm.
    logprobs: bool = False
    value: bool = False
    # Read by RolloutTracingContext: return_trace is coerced with bool(), and
    # trace_format is lower-cased, with an empty value falling back to "compact".
    return_trace: bool = False
    trace_format: str = "compact"
104
+
105
+
106
# Safety limits for a rollout.
# NOTE(review): presumably caps on executed ops and wall-clock seconds; where
# they are enforced is not visible here — confirm.
class RolloutSafetyConfig(BaseModel):
    max_ops: int = 100000
    max_time_s: float = 3600.0
109
+
110
+
111
# Request model for a rollout. run_id doubles as the tracing session id in
# RolloutTracingContext.
class RolloutRequest(BaseModel):
    run_id: str
    env: RolloutEnvSpec
    policy: RolloutPolicySpec
    ops: list[str]  # ["agent", "env", ...]
    record: RolloutRecordConfig = RolloutRecordConfig()
    on_done: str = "reset"  # "reset" | "terminate"
    branch: RolloutBranchConfig | None = None
    safety: RolloutSafetyConfig = RolloutSafetyConfig()
    # Optional run/session context (copied verbatim into tracing metadata)
    training_session_id: str | None = None
    synth_base_url: str | None = None
123
+
124
+
125
# One step of a rollout trajectory: the observation, the tool calls issued and
# the environment outcome.
class RolloutStep(BaseModel):
    obs: dict[str, Any]
    tool_calls: list[dict[str, Any]]
    reward: float | None = None
    done: bool = False
    truncated: bool | None = None
    # Presumably filled only when RolloutRecordConfig.logprobs / .value are
    # enabled — TODO confirm.
    logprob: float | None = None
    value: float | None = None
    info: dict[str, Any] | None = None
134
+
135
+
136
# A single trajectory produced by a rollout for one env/policy pair.
class RolloutTrajectory(BaseModel):
    env_id: str
    policy_id: str
    steps: list[RolloutStep]
    final: dict[str, Any] | None = None
    length: int
    # NOTE(review): presumably the decision_sample dicts built by
    # compute_stepwise_reward (decision_index, indicator, r_i, actions) — confirm.
    decision_samples: list[dict[str, Any]] | None = None
143
+
144
+
145
def compute_stepwise_reward(
    prev_achievements: dict[str, bool],
    new_achievements: dict[str, bool],
    decision_index: int,
    actions_summary: list[dict[str, Any]],
    indicator_lambda: float,
) -> tuple[dict[str, Any], dict[str, Any], dict[str, float]]:
    """Score one decision by whether it unlocked at least one new achievement.

    An achievement counts as newly unlocked when it is truthy in
    ``new_achievements`` and falsy or absent in ``prev_achievements``. Either
    map may be ``None`` (or empty) and is then treated as empty. The reward is
    ``float(indicator_lambda)`` when anything was unlocked and zero otherwise.
    The number of unlocks does not scale the reward.

    Returns:
        ``(stepwise_info, decision_sample, stats)``:

        - stepwise_info: ``decision_index``, ``indicator`` (int 0/1),
          ``new_achievements`` (names in ``new_achievements`` order), ``reward``.
        - decision_sample: ``decision_index``, ``indicator``, ``r_i`` (the
          reward) and ``actions`` (the caller's ``actions_summary`` object).
        - stats: ``indicator``, ``reward`` and ``new_achievements_count``, as floats.
    """
    before = prev_achievements or {}
    after = new_achievements or {}

    newly_unlocked: list[str] = []
    for achievement, achieved in after.items():
        if achieved and not before.get(achievement, False):
            newly_unlocked.append(achievement)

    indicator = 1 if len(newly_unlocked) > 0 else 0
    reward = float(indicator_lambda) * indicator

    stepwise_info = dict(
        decision_index=decision_index,
        indicator=indicator,
        new_achievements=newly_unlocked,
        reward=reward,
    )
    decision_sample = dict(
        decision_index=decision_index,
        indicator=indicator,
        r_i=reward,
        actions=actions_summary,
    )
    stats = dict(
        indicator=float(indicator),
        reward=reward,
        new_achievements_count=float(len(newly_unlocked)),
    )
    return stepwise_info, decision_sample, stats
179
+
180
+
181
# Aggregate metrics reported for a rollout.
# NOTE(review): mean_return is presumably the mean of episode_returns; it is
# computed outside this chunk — confirm.
class RolloutMetrics(BaseModel):
    episode_returns: list[float]
    mean_return: float
    num_steps: int
    num_episodes: int = 0
186
+
187
+
188
# Response model for a rollout, echoing the request's run_id.
class RolloutResponse(BaseModel):
    run_id: str
    trajectories: list[RolloutTrajectory]
    branches: dict[str, list[str]] = {}
    metrics: RolloutMetrics
    aborted: bool = False
    ops_executed: int = 0
    # Presumably populated only when RolloutRecordConfig.return_trace is set
    # (see RolloutTracingContext.return_trace) — TODO confirm.
    trace: dict[str, Any] | None = None
196
+
197
+
198
+ class RolloutTracingContext:
199
+ """Helper managing tracing_v3 recording and optional SFT dumps for a rollout."""
200
+
201
+ def __init__(
202
+ self,
203
+ tracer: SessionTracer | None,
204
+ request: RolloutRequest,
205
+ fastapi_request: Request,
206
+ ) -> None:
207
+ self.tracer = tracer
208
+ self.enabled = tracer is not None
209
+ self.request = request
210
+ self.fastapi_request = fastapi_request
211
+ self.run_id = request.run_id
212
+ self.current_step_id: str | None = None
213
+ self.current_turn: int | None = None
214
+ self.lm_calls_summary: list[dict[str, Any]] = []
215
+ self.decision_rewards: list[dict[str, Any]] = []
216
+ self.sft_records: list[dict[str, Any]] = []
217
+ self.latest_system_messages: list[str] = []
218
+ self.latest_user_messages: list[str] = []
219
+ self.latest_system_prompt_content: list[Any] = []
220
+ self.latest_user_prompt_content: list[Any] = []
221
+ self.trace_format = (
222
+ getattr(request.record, "trace_format", "compact") or "compact"
223
+ ).lower()
224
+ self.return_trace = bool(getattr(request.record, "return_trace", False))
225
+ self.sft_output_dir = getattr(fastapi_request.app.state, "sft_output_dir", None)
226
+ print(f"[TRACE_DEBUG] RolloutTracingContext init: trace_format={self.trace_format} return_trace={self.return_trace}", flush=True)
227
+ self.session_trace = None
228
+ self.metadata_updates: dict[str, Any] = {}
229
+ self.policy_name = request.policy.policy_name or ""
230
+ self.env_name = request.env.env_name or ""
231
+ self.metadata_base: dict[str, Any] = {
232
+ "run_id": self.run_id,
233
+ "policy_name": self.policy_name,
234
+ "policy_id": request.policy.policy_id,
235
+ "env_name": self.env_name,
236
+ "env_id": request.env.env_id,
237
+ "seed": request.env.seed,
238
+ "training_session_id": request.training_session_id,
239
+ "synth_base_url": request.synth_base_url,
240
+ }
241
+
242
+ # Expose context for downstream calls inside this request lifecycle
243
+ fastapi_request.state.rollout_tracing = self
244
+ fastapi_request.state.rollout_run_id = self.run_id
245
+
246
+ async def start_session(self) -> None:
247
+ if not self.enabled or self.tracer is None:
248
+ print("[TRACE_DEBUG] start_session skipped: tracer disabled", flush=True)
249
+ return
250
+ try:
251
+ await self.tracer.initialize()
252
+ print("[TRACE_DEBUG] tracer initialized", flush=True)
253
+ except Exception as exc:
254
+ logger.debug("TRACING_INIT_FAIL: %s", exc)
255
+ # Hard fail: tracing requested but cannot initialize
256
+ raise
257
+ try:
258
+ await self.tracer.start_session(
259
+ session_id=self.run_id, metadata=dict(self.metadata_base)
260
+ )
261
+ print(f"[TRACE_DEBUG] start_session succeeded for run_id={self.run_id}", flush=True)
262
+ except Exception as exc:
263
+ logger.warning("TRACING_START_FAIL: %s", exc)
264
+ # Hard fail: tracing requested but cannot start session
265
+ raise
266
+
267
async def start_decision(self, turn_number: int) -> None:
    """Mark the start of decision ``turn_number`` and open a tracer timestep.

    ``current_turn`` and ``current_step_id`` (``"decision_<n>"``) are updated
    even when tracing is disabled. Tracer errors are logged at debug level and
    swallowed, so a tracing hiccup never aborts the rollout.
    """
    step_id = f"decision_{turn_number}"
    self.current_turn = turn_number
    self.current_step_id = step_id
    if self.enabled and self.tracer is not None:
        try:
            await self.tracer.start_timestep(step_id=step_id, turn_number=turn_number)
        except Exception as exc:
            logger.debug("TRACING_STEP_START_FAIL: %s", exc)
276
+
277
async def end_decision(self) -> None:
    """Close the tracer timestep opened by ``start_decision``.

    ``current_step_id`` is reset to ``None`` on every path, including when
    tracing is disabled. Previously it was cleared only on the enabled path,
    so a disabled context kept reporting a stale step id after the decision
    ended. Tracer errors are logged at debug level and swallowed.
    """
    try:
        if self.enabled and self.tracer is not None:
            await self.tracer.end_timestep(step_id=self.current_step_id)
    except Exception as exc:
        logger.debug("TRACING_STEP_END_FAIL: %s", exc)
    finally:
        self.current_step_id = None
286
+
287
+ def _message_metadata(self) -> dict[str, Any]:
288
+ return {
289
+ "turn": self.current_turn,
290
+ "step_id": self.current_step_id,
291
+ }
292
+
293
async def record_policy_prompts(
    self,
    system_messages: list[Any],
    user_messages: list[Any],
) -> None:
    """Remember the latest policy prompts and record them as trace messages.

    The flattened text (``latest_system_messages`` / ``latest_user_messages``)
    and the structured content (``latest_system_prompt_content`` /
    ``latest_user_prompt_content``) are always stored, because
    ``record_llm_call`` builds SFT records from them.

    When tracing is enabled, each entry is also recorded as a
    ``policy_system_prompt`` or ``policy_user_prompt`` message, system prompts
    first. A failure on one entry is logged at debug level and does not stop
    the others.
    """
    self.latest_system_messages = [self._prompt_text(entry) for entry in system_messages]
    self.latest_user_messages = [self._prompt_text(entry) for entry in user_messages]
    self.latest_system_prompt_content = [
        self._prompt_content(entry, role="system") for entry in system_messages
    ]
    self.latest_user_prompt_content = [
        self._prompt_content(entry, role="user") for entry in user_messages
    ]
    if not self.enabled or self.tracer is None:
        return

    # (role, message_type, failure log format, entries), with system prompts first.
    batches = (
        ("system", "policy_system_prompt", "TRACING_SYSTEM_MSG_FAIL: %s", system_messages),
        ("user", "policy_user_prompt", "TRACING_USER_MSG_FAIL: %s", user_messages),
    )
    for role, message_type, failure_fmt, entries in batches:
        for entry in entries:
            try:
                await self.tracer.record_message(
                    content=self._prompt_payload(entry, role=role),
                    message_type=message_type,
                    metadata=self._message_metadata(),
                )
            except Exception as exc:
                logger.debug(failure_fmt, exc)

    # Debug-only message count. ``_current_trace`` is private tracer state, so
    # it is read defensively: a tracer without it, or with an unexpected trace
    # shape, must not turn successful recording into a crash.
    with contextlib.suppress(Exception):
        current_trace = getattr(self.tracer, "_current_trace", None)
        if current_trace:
            msg_count = len(current_trace.markov_blanket_message_history)
            print(f"[TRACE_DEBUG] After record_policy_prompts: {msg_count} messages", flush=True)
329
+
330
+ def _content_to_text(self, content: Any) -> str:
331
+ if isinstance(content, str):
332
+ return content
333
+ if isinstance(content, list):
334
+ parts: list[str] = []
335
+ for seg in content:
336
+ if isinstance(seg, dict):
337
+ text_val = seg.get("text") or seg.get("content")
338
+ if isinstance(text_val, str):
339
+ parts.append(text_val)
340
+ return "".join(parts)
341
+ if content is None:
342
+ return ""
343
+ return str(content)
344
+
345
+ def _prompt_text(self, entry: Any) -> str:
346
+ if isinstance(entry, dict):
347
+ text = entry.get("text")
348
+ if isinstance(text, str):
349
+ return text
350
+ content = entry.get("content")
351
+ return self._content_to_text(content)
352
+ return self._content_to_text(entry)
353
+
354
+ def _prompt_payload(self, entry: Any, *, role: str) -> dict[str, Any]:
355
+ if isinstance(entry, dict):
356
+ payload = dict(entry)
357
+ payload.setdefault("role", role)
358
+ return payload
359
+ return {
360
+ "role": role,
361
+ "text": self._prompt_text(entry),
362
+ "content": entry,
363
+ }
364
+
365
+ def _prompt_content(self, entry: Any, *, role: str) -> Any:
366
+ payload = self._prompt_payload(entry, role=role)
367
+ return payload.get("content", payload.get("text"))
368
+
369
+ def _content_has_image(self, content: Any) -> bool:
370
+ if isinstance(content, list):
371
+ return any(
372
+ isinstance(seg, dict)
373
+ and seg.get("type") in {"image", "image_url"}
374
+ for seg in content
375
+ )
376
+ if isinstance(content, dict):
377
+ if content.get("type") in {"image", "image_url"}:
378
+ return True
379
+ inner = content.get("content")
380
+ if isinstance(inner, list):
381
+ return any(
382
+ isinstance(seg, dict)
383
+ and seg.get("type") in {"image", "image_url"}
384
+ for seg in inner
385
+ )
386
+ return False
387
+
388
+ def _safe_json(self, payload: Any, limit: int = 4000) -> str:
389
+ try:
390
+ text = json.dumps(payload, ensure_ascii=False)
391
+ except Exception:
392
+ text = str(payload)
393
+ if len(text) > limit:
394
+ return text[:limit] + "…"
395
+ return text
396
+
397
+ async def record_tool_invocation(self, tool_calls: list[dict[str, Any]] | None) -> None:
398
+ if tool_calls is None:
399
+ return
400
+ if self.enabled and self.tracer is not None:
401
+ try:
402
+ await self.tracer.record_message(
403
+ content=self._safe_json(tool_calls),
404
+ message_type="policy_tool_call",
405
+ metadata=self._message_metadata(),
406
+ )
407
+ if self.tracer._current_trace:
408
+ print(
409
+ f"[TRACE_DEBUG] After tool invocation: messages={len(self.tracer._current_trace.markov_blanket_message_history)}",
410
+ flush=True,
411
+ )
412
+ except Exception as exc:
413
+ logger.debug("TRACING_TOOL_MSG_FAIL: %s", exc)
414
+
415
+ async def _record_event(self, event: Any) -> int | None:
416
+ if not self.enabled or self.tracer is None:
417
+ return None
418
+ try:
419
+ return await self.tracer.record_event(event)
420
+ except Exception as exc:
421
+ logger.debug("TRACING_EVENT_FAIL: %s", exc)
422
+ return None
423
+
424
async def record_llm_call(
    self,
    *,
    inference_request: dict[str, Any],
    inference_response: dict[str, Any],
    tool_calls: list[dict[str, Any]] | None,
    provider: str,
    model_name: str,
    started_at: datetime,
    completed_at: datetime,
    latency_ms: int | None,
) -> None:
    """Record one policy LLM call in the trace, the LM-call summary and the SFT buffer.

    Three things happen, in order:

    1. An ``LMCAISEvent`` is sent to the tracer via ``_record_event`` (a no-op
       when tracing is disabled). It carries token counts, cost, latency and
       a call record built from the request and response.
    2. A compact row is appended to ``self.lm_calls_summary``.
    3. When ``self.sft_output_dir`` is set, an SFT record is appended to
       ``self.sft_records``. It pairs the latest system/user prompts (captured
       by ``record_policy_prompts``) with the assistant reply.

    ``inference_response`` is read as a chat-completions style dict, using
    ``usage`` and ``choices[0].message``. A malformed ``usage`` or first
    choice is treated as missing instead of raising.

    Args:
        inference_request: Request payload sent to the model; ``messages``,
            ``temperature`` and ``tools`` are read from it.
        inference_response: Response payload returned by the model.
        tool_calls: Parsed tool calls; only their count is summarised.
        provider: Provider label stored on the event and records.
        model_name: Model identifier stored on the event and records.
        started_at: Call start time.
        completed_at: Call end time; also used as the event time.
        latency_ms: Measured latency, if known.
    """
    from datetime import timezone  # module scope only imports the ``datetime`` class

    # --- usage / cost ---------------------------------------------------------
    raw_usage = inference_response.get("usage")
    # Only a dict can be queried below; anything else means "no usage data".
    usage: dict[str, Any] = raw_usage if isinstance(raw_usage, dict) else {}
    input_tokens = usage.get("input_tokens") or usage.get("prompt_tokens")
    output_tokens = usage.get("output_tokens") or usage.get("completion_tokens")
    total_tokens = usage.get("total_tokens")
    cost_usd = usage.get("cost_usd") or usage.get("cost") or usage.get("total_cost")

    # --- assistant message (first choice) -------------------------------------
    choices = inference_response.get("choices") or []
    try:
        first_choice = choices[0] if choices else None
    except (IndexError, KeyError, TypeError):
        first_choice = None
    message: dict[str, Any] | None = None
    if isinstance(first_choice, dict):
        candidate = first_choice.get("message") or {}
        if isinstance(candidate, dict):
            message = candidate
    assistant_content = message.get("content") if message is not None else None
    message_tool_calls = message.get("tool_calls") if message is not None else None

    raw_response = self._content_to_text(assistant_content)
    if not raw_response:
        # No assistant text (e.g. a pure tool-call reply): keep a truncated dump instead.
        raw_response = self._safe_json(inference_response, limit=2000)

    base_response = BaseLMResponse(
        raw_response=raw_response,
        tool_calls=message_tool_calls,
        usage=usage or None,
        api_type="chat_completions",
    )

    request_messages = inference_request.get("messages") or []
    raw_temperature = inference_request.get("temperature")
    try:
        temperature = float(raw_temperature) if raw_temperature is not None else 0.0
    except (TypeError, ValueError, OverflowError):
        # Unparseable temperature: fall back to 0.0 rather than failing the rollout.
        temperature = 0.0

    call_record = create_llm_call_record_from_response(
        response=base_response,
        model_name=model_name,
        provider=provider,
        messages=request_messages,
        temperature=temperature,
        request_params=inference_request,
        tools=inference_request.get("tools"),
        started_at=started_at,
        completed_at=completed_at,
        latency_ms=latency_ms,
    )

    event = LMCAISEvent(
        system_instance_id=f"policy:{self.policy_name or 'unknown'}",
        # NOTE(review): a naive ``completed_at`` is interpreted as local time by
        # ``.timestamp()``; confirm callers pass timezone-aware datetimes.
        time_record=TimeRecord(event_time=completed_at.timestamp()),
        model_name=model_name,
        provider=provider,
        input_tokens=input_tokens,
        output_tokens=output_tokens,
        total_tokens=total_tokens,
        cost_usd=cost_usd,
        latency_ms=latency_ms,
        call_records=[call_record],
        metadata={
            "policy_id": self.request.policy.policy_id,
            "turn": self.current_turn,
            "run_id": self.run_id,
        },
    )
    await self._record_event(event)

    self.lm_calls_summary.append(
        {
            "turn": self.current_turn,
            "model": model_name,
            "provider": provider,
            "total_tokens": total_tokens,
            "input_tokens": input_tokens,
            "output_tokens": output_tokens,
            "latency_ms": latency_ms,
            "tool_calls": len(tool_calls or []),
        }
    )

    if self.sft_output_dir is None:
        return

    # --- SFT record -------------------------------------------------------------
    assistant_structured = assistant_content if assistant_content is not None else ""
    assistant_text = self._content_to_text(assistant_content)
    dialogue_structured: list[dict[str, Any]] = [
        {"role": "system", "content": content}
        for content in self.latest_system_prompt_content
        if content is not None
    ] + [
        {"role": "user", "content": content}
        for content in self.latest_user_prompt_content
        if content is not None
    ]
    dialogue_text = [{"role": "system", "content": s} for s in self.latest_system_messages] + [
        {"role": "user", "content": u} for u in self.latest_user_messages
    ]
    user_has_image = any(
        self._content_has_image(content) for content in self.latest_user_prompt_content
    )
    assistant_has_image = self._content_has_image(assistant_structured)
    self.sft_records.append(
        {
            "run_id": self.run_id,
            "turn": self.current_turn,
            "model": model_name,
            "provider": provider,
            "dialogue": dialogue_structured,
            "dialogue_text": dialogue_text,
            "assistant": {
                "content": assistant_structured,
                "content_text": assistant_text,
                # Always a list: a message without tool calls previously produced null here.
                "tool_calls": message_tool_calls if message_tool_calls is not None else [],
                "has_image": assistant_has_image,
            },
            "metadata": {
                "user_has_image": user_has_image,
                "assistant_has_image": assistant_has_image,
                "has_image": user_has_image or assistant_has_image,
            },
            # Same naive-UTC ISO string as ``datetime.utcnow().isoformat()``,
            # without the Python 3.12 deprecation warning.
            "timestamp": datetime.now(timezone.utc).replace(tzinfo=None).isoformat(),
        }
    )
560
+
561
+ async def record_environment_event(
562
+ self,
563
+ *,
564
+ env_handle: Any,
565
+ prev_obs: dict[str, Any] | None,
566
+ env_response: Any,
567
+ next_obs: dict[str, Any] | None,
568
+ metadata: dict[str, Any] | None = None,
569
+ ) -> int | None:
570
+ if not self.enabled or self.tracer is None:
571
+ return None
572
+
573
+ try:
574
+ prev_summary = (
575
+ _summarize_observation_for_storage(env_handle, prev_obs or {})
576
+ if prev_obs is not None
577
+ else None
578
+ )
579
+ except Exception:
580
+ prev_summary = None
581
+ try:
582
+ next_summary = (
583
+ _summarize_observation_for_storage(env_handle, next_obs or {})
584
+ if next_obs is not None
585
+ else None
586
+ )
587
+ except Exception:
588
+ next_summary = None
589
+
590
+ reward_val = getattr(env_response, "reward", None)
591
+ try:
592
+ reward_float = float(reward_val) if reward_val is not None else 0.0
593
+ except Exception:
594
+ reward_float = 0.0
595
+
596
+ event = EnvironmentEvent(
597
+ system_instance_id=f"environment:{self.env_name or 'unknown'}",
598
+ time_record=TimeRecord(event_time=datetime.utcnow().timestamp()),
599
+ reward=reward_float,
600
+ terminated=bool(getattr(env_response, "done", False)),
601
+ truncated=bool(getattr(env_response, "truncated", False)),
602
+ system_state_before=prev_summary,
603
+ system_state_after=next_summary,
604
+ metadata={
605
+ "turn": self.current_turn,
606
+ "run_id": self.run_id,
607
+ **(metadata or {}),
608
+ },
609
+ )
610
+
611
+ return await self._record_event(event)
612
+
613
+ async def record_decision_reward(
614
+ self,
615
+ *,
616
+ event_id: int | None,
617
+ decision_meta: dict[str, Any] | None,
618
+ ) -> None:
619
+ decision_meta = decision_meta or {}
620
+ ach_delta = int(decision_meta.get("ach_delta", 0))
621
+ unique_delta = int(decision_meta.get("unique_delta", 0))
622
+ all_ach = list(decision_meta.get("all") or [])
623
+ unique_ach = list(decision_meta.get("unique") or [])
624
+
625
+ self.decision_rewards.append(
626
+ {
627
+ "turn": self.current_turn,
628
+ "ach_delta": ach_delta,
629
+ "unique_delta": unique_delta,
630
+ "achievements": all_ach,
631
+ "unique_achievements": unique_ach,
632
+ }
633
+ )
634
+
635
+ if not self.enabled or self.tracer is None or event_id is None:
636
+ return
637
+ try:
638
+ await self.tracer.record_event_reward(
639
+ event_id=event_id,
640
+ turn_number=self.current_turn,
641
+ reward_value=float(ach_delta),
642
+ reward_type="achievement_delta",
643
+ annotation={"achievements": all_ach},
644
+ source="environment",
645
+ )
646
+ if unique_delta:
647
+ await self.tracer.record_event_reward(
648
+ event_id=event_id,
649
+ turn_number=self.current_turn,
650
+ reward_value=float(unique_delta),
651
+ reward_type="unique_achievement_delta",
652
+ annotation={"achievements": unique_ach},
653
+ source="environment",
654
+ )
655
+ except Exception as exc:
656
+ logger.debug("TRACING_REWARD_FAIL: %s", exc)
657
+
658
def update_metadata(self, **kwargs: Any) -> None:
    """Merge keyword arguments into ``metadata_updates``, skipping ``None`` values.

    Existing keys are overwritten; passing ``None`` never clears an entry.
    """
    for key, value in kwargs.items():
        if value is None:
            continue
        self.metadata_updates[key] = value
660
+
661
+ async def finalize(
662
+ self,
663
+ *,
664
+ total_reward: float,
665
+ achievement_state: dict[str, bool] | None,
666
+ total_steps: int,
667
+ ) -> Any:
668
+ final_achievements = [key for key, val in (achievement_state or {}).items() if val]
669
+ self.metadata_updates.setdefault("final_achievements", final_achievements)
670
+ if self.enabled and self.tracer is not None:
671
+ try:
672
+ await self.tracer.record_outcome_reward(
673
+ total_reward=int(total_reward),
674
+ achievements_count=len(final_achievements),
675
+ total_steps=int(total_steps),
676
+ reward_metadata=dict(self.metadata_updates),
677
+ )
678
+ except Exception as exc:
679
+ logger.debug("TRACING_OUTCOME_FAIL: %s", exc)
680
+ try:
681
+ if self.tracer._current_trace:
682
+ msg_count = len(self.tracer._current_trace.markov_blanket_message_history)
683
+ print(f"[TRACE_DEBUG] Before end_session: {msg_count} messages in trace", flush=True)
684
+ self.session_trace = await self.tracer.end_session()
685
+ if self.session_trace is not None:
686
+ self.session_trace.metadata.update(self.metadata_updates)
687
+ print(
688
+ f"[TRACE_DEBUG] Session ended successfully, session_id={self.session_trace.session_id}",
689
+ flush=True,
690
+ )
691
+ print(
692
+ f"[TRACE_DEBUG] session_trace.metadata keys: {list(self.session_trace.metadata.keys())}",
693
+ flush=True,
694
+ )
695
+ except Exception as exc:
696
+ logger.debug("TRACING_END_SESSION_FAIL: %s", exc)
697
+ self.session_trace = None
698
+ print(f"[TRACE_DEBUG] end_session failed for run_id={self.run_id}: {exc}", flush=True)
699
+ with contextlib.suppress(Exception):
700
+ await self.tracer.close()
701
+
702
+ if self.sft_records and self.sft_output_dir:
703
+ self.write_sft_records()
704
+
705
+ # Clear context from request state to avoid leaks
706
+ self.fastapi_request.state.rollout_tracing = None
707
+
708
+ return self.session_trace
709
+
710
def write_sft_records(self) -> None:
    """Flush buffered SFT records to a fresh JSONL file under ``sft_output_dir``.

    Does nothing when no output directory is configured or the buffer is
    empty. The target path comes from ``unique_sft_path(..., run_id=...)``.
    Parent directories are created as needed, and each record is written as
    one JSON line with non-ASCII characters kept as-is.

    Any failure is logged as a warning instead of raised, and the buffer is
    emptied either way.
    """
    if not (self.sft_output_dir and self.sft_records):
        return
    try:
        out_path = unique_sft_path(self.sft_output_dir, run_id=self.run_id)
        out_path.parent.mkdir(parents=True, exist_ok=True)
        with out_path.open("w", encoding="utf-8") as out_file:
            for sft_record in self.sft_records:
                json.dump(sft_record, out_file, ensure_ascii=False)
                out_file.write("\n")
        logger.info(f"SFT_WRITTEN: {out_path}")
    except Exception as write_error:
        logger.warning(f"SFT_WRITE_FAIL: {write_error}")
    finally:
        self.sft_records.clear()
725
+
726
+ def build_trace_payload(self, session_trace: Any) -> dict[str, Any] | None:
727
+ if not self.return_trace or session_trace is None:
728
+ return None
729
+ if self.trace_format in ("full", "structured"):
730
+ payload = session_trace.to_dict()
731
+ payload.setdefault("metadata", {}).update(self.metadata_updates)
732
+ print(
733
+ f"[TRACE_DEBUG] build_trace_payload returning structured trace with messages={len(payload.get('markov_blanket_message_history') or [])}",
734
+ flush=True,
735
+ )
736
+ return payload
737
+ metadata = dict(session_trace.metadata)
738
+ metadata.update(self.metadata_updates)
739
+ return {
740
+ "session_id": session_trace.session_id,
741
+ "created_at": session_trace.created_at.isoformat(),
742
+ "metadata": metadata,
743
+ "events_count": len(session_trace.event_history),
744
+ "messages_count": len(session_trace.markov_blanket_message_history),
745
+ "lm_calls": self.lm_calls_summary,
746
+ "decision_rewards": self.decision_rewards,
747
+ }
748
+
749
+
750
+ def _summarize_observation_for_storage(
751
+ env_handle: Any, observation: dict[str, Any]
752
+ ) -> dict[str, Any]:
753
+ """Return a compact dict for trajectory storage instead of the raw observation.
754
+
755
+ - For Crafter, use the same summary used for the policy user prompt
756
+ - For others, keep a minimal subset or plain text preview
757
+ """
758
+ # Try Crafter-specific formatter
759
+ crafter_wrapper = None
760
+ with contextlib.suppress(Exception):
761
+ from .envs.crafter.environment import (
762
+ CrafterEnvironmentWrapper as _CrafterWrapper, # type: ignore
763
+ )
764
+
765
+ crafter_wrapper = _CrafterWrapper # type: ignore[assignment]
766
+
767
+ if crafter_wrapper is not None and isinstance(
768
+ getattr(env_handle, "env", None), crafter_wrapper
769
+ ):
770
+ with contextlib.suppress(Exception):
771
+ from .envs.crafter.shared import format_observation as _fmt # type: ignore
772
+
773
+ text = _fmt(observation or {})
774
+ return {"text": text}
775
+
776
+ # Generic fallback: extract a few small fields if present; avoid huge arrays
777
+ with contextlib.suppress(Exception):
778
+ inv = observation.get("inventory") if isinstance(observation, dict) else None
779
+ ach = observation.get("achievements_status") if isinstance(observation, dict) else None
780
+ pos = observation.get("player_position") if isinstance(observation, dict) else None
781
+ health = None
782
+ if isinstance(inv, dict):
783
+ health = inv.get("health")
784
+ summary = {
785
+ "position": pos,
786
+ "health": health,
787
+ "inventory_keys": sorted(k for k, v in (inv or {}).items() if v)[:10]
788
+ if isinstance(inv, dict)
789
+ else None,
790
+ "achievements_unlocked": sorted(k for k, v in (ach or {}).items() if v)[:10]
791
+ if isinstance(ach, dict)
792
+ else None,
793
+ }
794
+ return {"text": json.dumps(summary, ensure_ascii=False)}
795
+
796
+ # Last resort: plain string preview
797
+ try:
798
+ return {"text": str(observation)[:10000]}
799
+ except Exception:
800
+ return {"text": ""}
801
+
802
+
803
class RunAbortRequest(BaseModel):
    # Request body naming the rollout run (by ``run_id``) that an abort targets.
    run_id: str
805
+
806
+
807
class RunAbortResponse(BaseModel):
    # Reply to an abort request: an ``ok`` flag plus the targeted ``run_id``.
    ok: bool
    run_id: str
810
+
811
+
812
class RunStatusResponse(BaseModel):
    # Status snapshot for a rollout run.
    run_id: str
    status: str
    started_at: datetime
    # Optional, defaults to None (presumably while the run is still in progress; confirm).
    finished_at: datetime | None = None
817
+
818
+
819
+ @router.post("/rollout", response_model=RolloutResponse)
820
+ async def execute_rollout(
821
+ request: RolloutRequest,
822
+ req: Request,
823
+ ) -> RolloutResponse:
824
+ """Execute a rollout with coordinated environment and policy steps."""
825
+ # Emit rollout identifier early for correlation
826
+ with contextlib.suppress(Exception):
827
+ _rid = getattr(request, "run_id", None)
828
+ _pol = getattr(request.policy, "policy_name", None) or getattr(request.policy, "policy_id", None)
829
+ _env = getattr(request.env, "env_name", None) or getattr(request.env, "env_id", None)
830
+ logger.info("ROLLOUT_BEGIN: run_id=%s policy=%s env=%s", _rid, _pol, _env)
831
+ print(f"[rollout] begin run_id={_rid} policy={_pol} env={_env}", flush=True)
832
+ # Enforce per-episode step cap via env-specific parameters; default to 20 if omitted
833
+ try:
834
+ _env_params = {}
835
+ if isinstance(request.env, RolloutEnvSpec) and isinstance(request.env.config, dict):
836
+ _env_params = dict(request.env.config.get("env_params") or {})
837
+ max_steps_per_episode = int(_env_params.get("max_steps_per_episode") or 20)
838
+ assert max_steps_per_episode > 0, "max_steps_per_episode must be a positive integer"
839
+ except Exception as _mse:
840
+ raise HTTPException(
841
+ status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
842
+ detail={
843
+ "error": "invalid_env_params",
844
+ "message": f"Invalid or missing env_params.max_steps_per_episode: {_mse}",
845
+ },
846
+ ) from _mse
847
+ # Truncate incoming ops to the enforced cap (each step is [agent, env])
848
+ ops_seq: list[str] = list(request.ops or [])
849
+ allowed_ops = max(0, int(max_steps_per_episode) * 2)
850
+ if len(ops_seq) > allowed_ops:
851
+ with contextlib.suppress(Exception):
852
+ logger.info(
853
+ "ROLL_OUT: truncating ops to cap: requested_ops=%s allowed_ops=%s",
854
+ str(len(ops_seq)),
855
+ str(allowed_ops),
856
+ )
857
+ ops_seq = ops_seq[:allowed_ops]
858
+ # Simple API key auth for inbound rollout
859
+ header_key = req.headers.get("x-api-key")
860
+ env_key = os.getenv("ENVIRONMENT_API_KEY")
861
+ dev_key = os.getenv("DEV_ENVIRONMENT_API_KEY")
862
+ # Accept either ENVIRONMENT_API_KEY or DEV_ENVIRONMENT_API_KEY
863
+ expected_keys = [k for k in (env_key, dev_key) if k]
864
+ if not expected_keys:
865
+ missing = []
866
+ if not env_key:
867
+ missing.append("ENVIRONMENT_API_KEY")
868
+ if not dev_key:
869
+ missing.append("DEV_ENVIRONMENT_API_KEY")
870
+ msg = f"Auth not configured: missing {', '.join(missing)} in task service environment"
871
+ logger.error(msg)
872
+ raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=msg)
873
+ if not header_key:
874
+ raise HTTPException(
875
+ status_code=status.HTTP_401_UNAUTHORIZED,
876
+ detail="Invalid or missing API key: X-API-Key header not provided",
877
+ )
878
+ if header_key not in expected_keys:
879
+ # Do not leak secrets; include short prefix for diagnostics
880
+ exp_src = env_key if env_key else (dev_key or "")
881
+ exp_prefix = (exp_src[:7] + "…") if len(exp_src) >= 7 else "set"
882
+ got_prefix = (header_key[:7] + "…") if len(header_key) >= 7 else "set"
883
+ raise HTTPException(
884
+ status_code=status.HTTP_401_UNAUTHORIZED,
885
+ detail=f"Invalid API key: header does not match expected (got={got_prefix}, expected_prefix={exp_prefix})",
886
+ )
887
+
888
+ # Log contextual fields for traceability
889
+ if request.training_session_id:
890
+ logger.info(f"ROLL_OUT: training_session_id={request.training_session_id}")
891
+ if request.synth_base_url:
892
+ logger.info(f"ROLL_OUT: synth_base_url={request.synth_base_url}")
893
+
894
+ # Log masked OpenAI API key presence for diagnostics
895
+ with contextlib.suppress(Exception):
896
+ _oa = os.getenv("OPENAI_API_KEY")
897
+ if _oa:
898
+ _pref = (_oa[:6] + "…") if len(_oa) >= 6 else "set"
899
+ logger.info(f"ROLL_OUT: OPENAI_API_KEY present (prefix={_pref})")
900
+ else:
901
+ logger.warning("ROLL_OUT: OPENAI_API_KEY missing")
902
+
903
+ # Make synth_base_url available for outbound calls in this app
904
+ with contextlib.suppress(Exception):
905
+ task_app = req.app.state.task_app
906
+ if request.synth_base_url:
907
+ task_app.synth_base_url = request.synth_base_url
908
+
909
+ tracer_factory = getattr(req.app.state, "session_tracer_factory", None)
910
+ tracer_instance: SessionTracer | None = None
911
+ if callable(tracer_factory):
912
+ try:
913
+ inst = tracer_factory()
914
+ tracer_instance = inst if isinstance(inst, SessionTracer) else None
915
+ except Exception as exc:
916
+ logger.debug(f"TRACER_FACTORY_FAIL: {exc}")
917
+ tracing_context = RolloutTracingContext(tracer_instance, request, req)
918
+ await tracing_context.start_session()
919
+
920
+ # Register run
921
+ registry.register_run(request.run_id)
922
+
923
+ # Track resources created during this rollout so we can guarantee cleanup
924
+ created_env_id: str | None = None
925
+ created_policy_id: str | None = None
926
+ env_seed_used: int | None = None
927
+ trajectory_steps: list[RolloutStep] = []
928
+ decision_samples: list[dict[str, Any]] = []
929
+ pending_tool_calls: Any = None
930
+ current_obs: Any = {}
931
+ total_reward: float = 0.0
932
+ ops_executed = 0
933
+ last_agent_response_ts: float | None = None
934
+ last_policy_meta: dict[str, Any] | None = None
935
+ last_env_step_ms: float | None = None
936
+ last_env_step_completed_ts: float | None = None
937
+ decision_open = False
938
+ finalized = False
939
+ prev_achievements: dict[str, bool] = {}
940
+ session_trace = None
941
+ step_rewards_active = False
942
+
943
+ try:
944
+ # Initialize deterministic seed early for the entire rollout
945
+ seed_value: int | None = None
946
+ try:
947
+ if request.env and request.env.seed is not None:
948
+ seed_value = int(request.env.seed)
949
+ else:
950
+ # Derive a stable seed from run_id
951
+ import hashlib as _hashlib # local import to avoid global deps
952
+
953
+ _digest = _hashlib.sha256(request.run_id.encode("utf-8")).hexdigest()
954
+ # Use lower 32 bits to fit common RNG ranges
955
+ seed_value = int(_digest[:8], 16)
956
+ except Exception:
957
+ # Fallback to time-based seed if anything goes wrong
958
+ try:
959
+ seed_value = int((_time.time_ns() // 1_000_000) % (2**31 - 1))
960
+ except Exception:
961
+ seed_value = 42
962
+
963
+ _seed_info = _set_global_seed(int(seed_value))
964
+ with contextlib.suppress(Exception):
965
+ logger.info(
966
+ "ROLL_OUT: RNG seeded seed=%s libs=%s",
967
+ str(_seed_info.get("seed")),
968
+ ",".join(_seed_info.get("libs", [])),
969
+ )
970
+ # Resolve or create environment
971
+ if request.env.env_id:
972
+ env_handle = registry.get_env(request.env.env_id)
973
+ if not env_handle:
974
+ raise HTTPException(
975
+ status_code=404,
976
+ detail=f"Environment {request.env.env_id} not found",
977
+ )
978
+ env_id = request.env.env_id
979
+ else:
980
+ # Create new environment
981
+ from .environment_routes import EnvCreateRequest, create_environment
982
+
983
+ if not request.env.env_name:
984
+ raise ValueError("FATAL: env_name is required - NO FALLBACKS!")
985
+
986
+ # Propagate training_session_id via env config for downstream usage
987
+ _env_config = dict(request.env.config or {})
988
+ if request.training_session_id is not None:
989
+ _env_config.setdefault("training_session_id", request.training_session_id)
990
+ env_response = await create_environment(
991
+ EnvCreateRequest(
992
+ env_name=request.env.env_name,
993
+ config=_env_config,
994
+ seed=request.env.seed,
995
+ rl_run_id=request.run_id,
996
+ )
997
+ )
998
+ env_id = env_response.env_id
999
+ env_handle = registry.get_env(env_id)
1000
+ created_env_id = env_id
1001
+
1002
+ tracing_context.update_metadata(env_id=env_id)
1003
+
1004
+ # Resolve or create policy
1005
+ if request.policy.policy_id:
1006
+ policy_handle = registry.get_policy(request.policy.policy_id)
1007
+ if not policy_handle:
1008
+ raise HTTPException(
1009
+ status_code=404,
1010
+ detail=f"Policy {request.policy.policy_id} not found",
1011
+ )
1012
+ policy_id = request.policy.policy_id
1013
+ else:
1014
+ # Create new policy
1015
+ from .policy_routes import PolicyCreateRequest, create_policy
1016
+
1017
+ if not request.policy.policy_name:
1018
+ raise ValueError("FATAL: policy_name is required - NO FALLBACKS!")
1019
+
1020
+ # Propagate training_session_id and synth_base_url via policy config
1021
+ _policy_config = dict(request.policy.config or {})
1022
+ if request.training_session_id is not None:
1023
+ _policy_config.setdefault("training_session_id", request.training_session_id)
1024
+ if request.synth_base_url is not None:
1025
+ _policy_config.setdefault("synth_base_url", request.synth_base_url)
1026
+ policy_response = await create_policy(
1027
+ PolicyCreateRequest(
1028
+ policy_name=request.policy.policy_name,
1029
+ config=_policy_config,
1030
+ rl_run_id=request.run_id,
1031
+ bound_env_id=env_id,
1032
+ ),
1033
+ req,
1034
+ )
1035
+ policy_id = policy_response.policy_id
1036
+ policy_handle = registry.get_policy(policy_id)
1037
+ created_policy_id = policy_id
1038
+
1039
+ tracing_context.update_metadata(policy_id=policy_id)
1040
+
1041
+ # Bind policy to environment if not already bound
1042
+ if policy_handle and not policy_handle.bound_env_id:
1043
+ policy_handle.bound_env_id = env_id
1044
+
1045
+ # Record seed bound to environment for end-of-rollout verification/logging
1046
+ try:
1047
+ env_seed_used = int(getattr(env_handle, "seed", 0) or 0)
1048
+ except Exception:
1049
+ env_seed_used = None
1050
+ tracing_context.update_metadata(env_seed=env_seed_used)
1051
+ # Initialize trajectory
1052
+ trajectory_steps = []
1053
+ pending_tool_calls = None
1054
+ current_obs = env_handle.last_observation
1055
+ total_reward = 0.0
1056
+ ops_executed = 0
1057
+ last_agent_response_ts = None
1058
+ last_policy_meta = None
1059
+ last_env_step_ms = None
1060
+ last_env_step_completed_ts = None
1061
+
1062
+ # Stepwise reward configuration (Crafter shaping; gate on explicit enable)
1063
+ step_rewards_cfg_raw: dict[str, Any] = {}
1064
+ try:
1065
+ if isinstance(request.policy.config, dict):
1066
+ step_rewards_cfg_raw = dict(request.policy.config.get("step_rewards") or {})
1067
+ except Exception:
1068
+ step_rewards_cfg_raw = {}
1069
+ if not step_rewards_cfg_raw:
1070
+ try:
1071
+ if isinstance(request.env.config, dict):
1072
+ step_rewards_cfg_raw = dict(request.env.config.get("step_rewards") or {})
1073
+ except Exception:
1074
+ step_rewards_cfg_raw = {}
1075
+
1076
+ step_rewards_enabled = bool(step_rewards_cfg_raw.get("enabled", False))
1077
+ step_rewards_mode = str(step_rewards_cfg_raw.get("mode") or "off").lower()
1078
+ try:
1079
+ step_rewards_indicator_lambda = float(
1080
+ step_rewards_cfg_raw.get("indicator_lambda") or 0.0
1081
+ )
1082
+ except Exception:
1083
+ step_rewards_indicator_lambda = 0.0
1084
+ try:
1085
+ step_rewards_beta = float(step_rewards_cfg_raw.get("step_beta") or 0.0)
1086
+ except Exception:
1087
+ step_rewards_beta = 0.0
1088
+ step_rewards_active = step_rewards_enabled and step_rewards_mode == "decision_stepwise"
1089
+
1090
+ def _extract_achievements(obs: Any) -> dict[str, bool]:
1091
+ if not isinstance(obs, dict):
1092
+ return {}
1093
+ ach = obs.get("achievements_status")
1094
+ if isinstance(ach, dict):
1095
+ return {str(k): bool(v) for k, v in ach.items()}
1096
+ return {}
1097
+
1098
+ def _summarize_tool_calls(tool_calls: Any) -> list[dict[str, Any]]:
1099
+ if not tool_calls:
1100
+ return []
1101
+ try:
1102
+ items = (
1103
+ tool_calls
1104
+ if isinstance(tool_calls, list)
1105
+ else list(tool_calls) # tolerates tuples or pydantic lists
1106
+ )
1107
+ except Exception:
1108
+ return []
1109
+ summary: list[dict[str, Any]] = []
1110
+ for tc in items:
1111
+ tool_name = None
1112
+ args: Any = {}
1113
+ if isinstance(tc, dict):
1114
+ tool_name = tc.get("tool") or tc.get("tool_name") or tc.get("name")
1115
+ raw_args = tc.get("arguments") or tc.get("args") or {}
1116
+ else:
1117
+ tool_name = getattr(tc, "tool", None) or getattr(tc, "tool_name", None)
1118
+ raw_args = getattr(tc, "arguments", None) or getattr(tc, "args", None) or {}
1119
+ args = raw_args
1120
+ if isinstance(raw_args, str):
1121
+ try:
1122
+ args = json.loads(raw_args)
1123
+ except Exception:
1124
+ args = raw_args
1125
+ summary.append({"tool": tool_name, "args": args})
1126
+ return summary
1127
+
1128
+ decision_samples: list[dict[str, Any]] = []
1129
+ decision_index = 0
1130
+ decision_open = False
1131
+ session_trace = None
1132
+ finalized = False
1133
+ prev_achievements = _extract_achievements(current_obs)
1134
+ # Track episode-level achievements that have been seen as true at any point so far
1135
+ episode_seen_achievements: set[str] = {
1136
+ k for k, v in (prev_achievements or {}).items() if bool(v)
1137
+ }
1138
+ stepwise_indicator_sum = 0.0
1139
+ stepwise_reward_sum = 0.0
1140
+ stepwise_new_achievements_total = 0
1141
+ final_achievement_count = sum(1 for v in prev_achievements.values() if v)
1142
+
1143
+ # Execute ops sequence (capped by env_params.max_steps_per_episode)
1144
+ for op_idx, op in enumerate(ops_seq):
1145
+ # Check for abort
1146
+ if registry.is_run_aborted(request.run_id):
1147
+ logger.info(f"Run {request.run_id} aborted at op {op_idx}")
1148
+ break
1149
+
1150
+ # Check safety limits
1151
+ if ops_executed >= request.safety.max_ops:
1152
+ logger.warning(f"Reached max_ops limit ({request.safety.max_ops})")
1153
+ break
1154
+
1155
+ if op == "agent":
1156
+ # Policy step
1157
+ from .policy_routes import PolicyStepRequest, step_policy
1158
+
1159
+ if not decision_open:
1160
+ await tracing_context.start_decision(decision_index)
1161
+ decision_open = True
1162
+
1163
+ agent_request_start = _time.perf_counter()
1164
+ if last_agent_response_ts is not None and last_policy_meta is not None:
1165
+ with contextlib.suppress(Exception):
1166
+ timing_prev = last_policy_meta.setdefault("timing", {})
1167
+ decision_ms = max(
1168
+ 0.0,
1169
+ (agent_request_start - float(last_agent_response_ts)) * 1000.0,
1170
+ )
1171
+ # Update timing on prior policy meta (kept by previous env step)
1172
+ timing_prev["decision_ms"] = decision_ms
1173
+ if last_env_step_ms is not None:
1174
+ timing_prev["env_step_ms"] = float(last_env_step_ms)
1175
+ timing_prev["overhead_ms"] = max(
1176
+ 0.0, decision_ms - float(last_env_step_ms)
1177
+ )
1178
+ else:
1179
+ timing_prev.setdefault("overhead_ms", 0.0)
1180
+ timing_prev["decision_ready_s"] = agent_request_start
1181
+ # Also backfill the last appended trajectory step so the trainer
1182
+ # can always see decision_ms without relying on shared dict refs.
1183
+ if trajectory_steps:
1184
+ with contextlib.suppress(Exception):
1185
+ _last = trajectory_steps[-1]
1186
+ _info = dict(_last.info or {})
1187
+ _meta = dict(_info.get("meta") or {})
1188
+ _timing = dict(_meta.get("timing") or {})
1189
+ _timing["decision_ms"] = decision_ms
1190
+ if last_env_step_ms is not None:
1191
+ _timing.setdefault("env_step_ms", float(last_env_step_ms))
1192
+ _timing.setdefault(
1193
+ "overhead_ms",
1194
+ max(0.0, decision_ms - float(last_env_step_ms)),
1195
+ )
1196
+ else:
1197
+ _timing.setdefault("overhead_ms", 0.0)
1198
+ _meta["timing"] = _timing
1199
+ _info["meta"] = _meta
1200
+ _last.info = _info
1201
+ last_env_step_ms = None
1202
+ last_env_step_completed_ts = None
1203
+
1204
+ # Build metadata for policy (carry previous tool_calls and env result)
1205
+ metadata = {}
1206
+ if pending_tool_calls:
1207
+ metadata["prev_tool_calls"] = pending_tool_calls
1208
+ if len(trajectory_steps) > 0:
1209
+ last_step = trajectory_steps[-1]
1210
+ # Prefer the last executed tool calls to seed history
1211
+ if last_step.tool_calls:
1212
+ metadata["prev_tool_calls"] = last_step.tool_calls
1213
+ # Provide a compact env result snapshot
1214
+ metadata["prev_env_result"] = {
1215
+ "observation": last_step.obs,
1216
+ "reward": last_step.reward,
1217
+ "done": last_step.done,
1218
+ "truncated": last_step.truncated,
1219
+ "info": last_step.info,
1220
+ }
1221
+
1222
+ # Log compact metadata summary to confirm history threading
1223
+ with contextlib.suppress(Exception):
1224
+ _prev_calls = metadata.get("prev_tool_calls")
1225
+ _count = len(_prev_calls) if isinstance(_prev_calls, list) else 0
1226
+ _first_guess = None
1227
+ if _count > 0 and isinstance(_prev_calls[0], dict):
1228
+ _args = _prev_calls[0].get("arguments", None)
1229
+ if isinstance(_args, str):
1230
+ import json as _json
1231
+ with contextlib.suppress(Exception):
1232
+ _args = _json.loads(_args)
1233
+ if not isinstance(_args, dict):
1234
+ _args = {}
1235
+ _first_guess = _args.get("guess") or _args.get("word")
1236
+ logger.info(
1237
+ "POLICY_METADATA: prev_tool_calls=%d first_guess=%r has_prev_env_result=%s",
1238
+ _count,
1239
+ _first_guess,
1240
+ str("prev_env_result" in metadata),
1241
+ )
1242
+
1243
+ try:
1244
+ policy_response = await step_policy(
1245
+ PolicyStepRequest(
1246
+ policy_id=policy_id,
1247
+ observation=current_obs,
1248
+ metadata=metadata,
1249
+ ),
1250
+ req,
1251
+ )
1252
+ except Exception as _pe:
1253
+ # Do not 500 the rollout; finalize with partial trajectory
1254
+ with contextlib.suppress(Exception):
1255
+ logger.warning(
1256
+ "POLICY_STEP_FAIL: terminating episode early run_id=%s op_idx=%s err=%s",
1257
+ request.run_id,
1258
+ str(op_idx),
1259
+ str(_pe),
1260
+ )
1261
+
1262
+ # Build partial trajectory and return HTTP 200
1263
+ trajectory = RolloutTrajectory(
1264
+ env_id=env_id,
1265
+ policy_id=policy_id,
1266
+ steps=trajectory_steps,
1267
+ final={
1268
+ "observation": current_obs,
1269
+ "rollout_status": "partial_policy_error",
1270
+ "error": str(_pe),
1271
+ "at_op": op,
1272
+ },
1273
+ length=len(trajectory_steps),
1274
+ decision_samples=decision_samples if step_rewards_active else None,
1275
+ )
1276
+ metrics = RolloutMetrics(
1277
+ episode_returns=[total_reward],
1278
+ mean_return=total_reward,
1279
+ num_steps=len(trajectory_steps),
1280
+ num_episodes=1,
1281
+ )
1282
+ aborted = registry.is_run_aborted(request.run_id)
1283
+ if not aborted:
1284
+ registry.complete_run(request.run_id)
1285
+ if decision_open:
1286
+ await tracing_context.end_decision()
1287
+ decision_open = False
1288
+ if not finalized:
1289
+ session_trace = await tracing_context.finalize(
1290
+ total_reward=total_reward,
1291
+ achievement_state=prev_achievements,
1292
+ total_steps=len(trajectory_steps),
1293
+ )
1294
+ finalized = True
1295
+ trace_payload = tracing_context.build_trace_payload(session_trace)
1296
+ return RolloutResponse(
1297
+ run_id=request.run_id,
1298
+ trajectories=[trajectory],
1299
+ branches={},
1300
+ metrics=metrics,
1301
+ aborted=aborted,
1302
+ ops_executed=ops_executed,
1303
+ trace=trace_payload,
1304
+ )
1305
+
1306
+ agent_response_ts = _time.perf_counter()
1307
+ if isinstance(policy_response.meta, dict):
1308
+ with contextlib.suppress(Exception):
1309
+ timing_cur = policy_response.meta.setdefault("timing", {})
1310
+ timing_cur["agent_request_start_s"] = agent_request_start
1311
+ timing_cur["agent_response_s"] = agent_response_ts
1312
+ if "inference_ms" in policy_response.meta:
1313
+ with contextlib.suppress(Exception):
1314
+ timing_cur.setdefault(
1315
+ "inference_ms",
1316
+ float(policy_response.meta["inference_ms"]),
1317
+ )
1318
+ timing_cur.setdefault(
1319
+ "inference_s",
1320
+ float(policy_response.meta["inference_ms"]) / 1000.0,
1321
+ )
1322
+ last_policy_meta = policy_response.meta
1323
+ else:
1324
+ last_policy_meta = None
1325
+ last_agent_response_ts = agent_response_ts
1326
+
1327
+ # Diagnostic: summarize policy step target and tool calls
1328
+ try:
1329
+ model_name = None
1330
+ target_url = None
1331
+ if isinstance(policy_response.meta, dict):
1332
+ req_body = policy_response.meta.get("inference_request") or {}
1333
+ model_name = req_body.get("model")
1334
+ target_url = policy_response.meta.get("inference_url")
1335
+ _tc = policy_response.tool_calls or []
1336
+ print(
1337
+ {
1338
+ "rollout.policy_step": True,
1339
+ "run_id": request.run_id,
1340
+ "model": model_name,
1341
+ "inference_url": target_url,
1342
+ "tool_calls_count": len(_tc) if isinstance(_tc, list) else 0,
1343
+ },
1344
+ flush=True,
1345
+ )
1346
+ except Exception:
1347
+ pass
1348
+
1349
+ pending_tool_calls = policy_response.tool_calls
1350
+ # Log summarized agent tool calls
1351
+ with contextlib.suppress(Exception):
1352
+ _tc = pending_tool_calls or []
1353
+ _summary = []
1354
+ for _item in (_tc if isinstance(_tc, list) else []):
1355
+ try:
1356
+ if isinstance(_item, dict):
1357
+ _tool = _item.get("tool")
1358
+ _args = _item.get("args")
1359
+ _keys = list(_args.keys()) if isinstance(_args, dict) else []
1360
+ _summary.append({"tool": _tool, "args_keys": _keys})
1361
+ except Exception:
1362
+ continue
1363
+ _rid = getattr(request, "run_id", None)
1364
+ logger.info("AGENT_TOOL_CALLS: run_id=%s count=%d summary=%s", _rid, len(_tc), _summary)
1365
+ print(f"[rollout] agent tool_calls run_id={_rid} count={len(_tc)} summary={_summary}", flush=True)
1366
+ await tracing_context.record_tool_invocation(pending_tool_calls)
1367
+ ops_executed += 1
1368
+
1369
+ elif op == "env":
1370
+ if not pending_tool_calls:
1371
+ # Treat absence of tool calls as a soft terminal condition; yield partial trajectory
1372
+ with contextlib.suppress(Exception):
1373
+ logger.warning(
1374
+ "NO_TOOL_CALLS: terminating episode early run_id=%s op_idx=%s",
1375
+ request.run_id,
1376
+ str(op_idx),
1377
+ )
1378
+ print(
1379
+ f"[rollout] no tool_calls; terminating early run_id={request.run_id} op_idx={op_idx}",
1380
+ flush=True,
1381
+ )
1382
+ term_step = RolloutStep(
1383
+ obs=current_obs,
1384
+ tool_calls=[],
1385
+ reward=None,
1386
+ done=True,
1387
+ truncated=False,
1388
+ info={
1389
+ "terminated": True,
1390
+ "reason": "no_tool_calls",
1391
+ },
1392
+ )
1393
+ trajectory_steps.append(term_step)
1394
+ trajectory = RolloutTrajectory(
1395
+ env_id=env_id,
1396
+ policy_id=policy_id,
1397
+ steps=trajectory_steps,
1398
+ final={
1399
+ "observation": current_obs,
1400
+ "rollout_status": "partial_no_tool_calls",
1401
+ "at_op": op,
1402
+ },
1403
+ length=len(trajectory_steps),
1404
+ decision_samples=decision_samples if step_rewards_active else None,
1405
+ )
1406
+ metrics = RolloutMetrics(
1407
+ episode_returns=[total_reward],
1408
+ mean_return=total_reward,
1409
+ num_steps=len(trajectory_steps),
1410
+ num_episodes=1,
1411
+ )
1412
+ aborted = registry.is_run_aborted(request.run_id)
1413
+ if not aborted:
1414
+ registry.complete_run(request.run_id)
1415
+ if decision_open:
1416
+ await tracing_context.end_decision()
1417
+ decision_open = False
1418
+ if not finalized:
1419
+ session_trace = await tracing_context.finalize(
1420
+ total_reward=total_reward,
1421
+ achievement_state=prev_achievements,
1422
+ total_steps=len(trajectory_steps),
1423
+ )
1424
+ finalized = True
1425
+ trace_payload = tracing_context.build_trace_payload(session_trace)
1426
+ return RolloutResponse(
1427
+ run_id=request.run_id,
1428
+ trajectories=[trajectory],
1429
+ branches={},
1430
+ metrics=metrics,
1431
+ aborted=aborted,
1432
+ ops_executed=ops_executed,
1433
+ trace=trace_payload,
1434
+ )
1435
+
1436
+ # Environment step
1437
+ from .environment_routes import EnvStepRequest, step_environment
1438
+
1439
+ env_step_error: Exception | None = None
1440
+ env_response = None
1441
+ env_step_start = _time.perf_counter()
1442
+ try:
1443
+ env_response = await step_environment(
1444
+ EnvStepRequest(
1445
+ env_id=env_id,
1446
+ tool_calls=pending_tool_calls,
1447
+ )
1448
+ )
1449
+ except Exception as _ee:
1450
+ env_step_error = _ee
1451
+ env_step_end = _time.perf_counter()
1452
+ env_step_duration_ms = (env_step_end - env_step_start) * 1000.0
1453
+ last_env_step_ms = env_step_duration_ms
1454
+ last_env_step_completed_ts = env_step_end
1455
+ if last_policy_meta is not None:
1456
+ with contextlib.suppress(Exception):
1457
+ timing_env = last_policy_meta.setdefault("timing", {})
1458
+ timing_env["env_step_ms"] = env_step_duration_ms
1459
+ timing_env["env_step_end_s"] = env_step_end
1460
+
1461
+ if env_step_error is not None:
1462
+ # Invalid action or environment rejection — terminate episode early with partial trajectory
1463
+ with contextlib.suppress(Exception):
1464
+ logger.warning(
1465
+ "ENV_STEP_FAIL: terminating episode early run_id=%s op_idx=%s err=%s",
1466
+ request.run_id,
1467
+ str(op_idx),
1468
+ str(env_step_error),
1469
+ )
1470
+
1471
+ term_step = RolloutStep(
1472
+ obs=current_obs,
1473
+ tool_calls=pending_tool_calls,
1474
+ reward=None,
1475
+ done=True,
1476
+ truncated=False,
1477
+ info={
1478
+ "terminated": True,
1479
+ "reason": "invalid_action",
1480
+ "error": str(env_step_error),
1481
+ },
1482
+ )
1483
+ trajectory_steps.append(term_step)
1484
+ # Build partial response
1485
+ trajectory = RolloutTrajectory(
1486
+ env_id=env_id,
1487
+ policy_id=policy_id,
1488
+ steps=trajectory_steps,
1489
+ final={
1490
+ "observation": current_obs,
1491
+ "rollout_status": "partial_invalid_action",
1492
+ "error": str(env_step_error),
1493
+ "at_op": op,
1494
+ },
1495
+ length=len(trajectory_steps),
1496
+ decision_samples=decision_samples if step_rewards_active else None,
1497
+ )
1498
+ metrics = RolloutMetrics(
1499
+ episode_returns=[total_reward],
1500
+ mean_return=total_reward,
1501
+ num_steps=len(trajectory_steps),
1502
+ num_episodes=1,
1503
+ )
1504
+ aborted = registry.is_run_aborted(request.run_id)
1505
+ if not aborted:
1506
+ registry.complete_run(request.run_id)
1507
+ if (
1508
+ last_policy_meta is not None
1509
+ and last_agent_response_ts is not None
1510
+ and "decision_ms" not in last_policy_meta.get("timing", {})
1511
+ ):
1512
+ with contextlib.suppress(Exception):
1513
+ timing_last = last_policy_meta.setdefault("timing", {})
1514
+ decision_ms = max(
1515
+ 0.0,
1516
+ (env_step_end - float(last_agent_response_ts)) * 1000.0,
1517
+ )
1518
+ timing_last["decision_ms"] = decision_ms
1519
+ timing_last.setdefault(
1520
+ "overhead_ms", max(0.0, decision_ms - env_step_duration_ms)
1521
+ )
1522
+ if decision_open:
1523
+ await tracing_context.end_decision()
1524
+ decision_open = False
1525
+ if not finalized:
1526
+ session_trace = await tracing_context.finalize(
1527
+ total_reward=total_reward,
1528
+ achievement_state=prev_achievements,
1529
+ total_steps=len(trajectory_steps),
1530
+ )
1531
+ finalized = True
1532
+ trace_payload = tracing_context.build_trace_payload(session_trace)
1533
+ return RolloutResponse(
1534
+ run_id=request.run_id,
1535
+ trajectories=[trajectory],
1536
+ branches={},
1537
+ metrics=metrics,
1538
+ aborted=aborted,
1539
+ ops_executed=ops_executed,
1540
+ trace=trace_payload,
1541
+ )
1542
+
1543
+ # Reaching here means env step succeeded
1544
+ assert env_response is not None
1545
+
1546
+ # Record step, including policy meta if present for timing/tokens observability
1547
+ _info = env_response.info if isinstance(env_response.info, dict) else {}
1548
+ # Attach policy meta from the immediately preceding agent step
1549
+ with contextlib.suppress(Exception):
1550
+ prev_meta = {}
1551
+ if "policy_response" in locals() and isinstance(policy_response.meta, dict): # type: ignore[name-defined]
1552
+ prev_meta = policy_response.meta
1553
+ if prev_meta:
1554
+ _info = dict(_info)
1555
+ _info["meta"] = prev_meta
1556
+
1557
+ event_metadata = {
1558
+ "op_index": op_idx,
1559
+ }
1560
+ event_id = await tracing_context.record_environment_event(
1561
+ env_handle=env_handle,
1562
+ prev_obs=current_obs,
1563
+ env_response=env_response,
1564
+ next_obs=getattr(env_response, "observation", None),
1565
+ metadata=event_metadata,
1566
+ )
1567
+
1568
+ decision_index += 1
1569
+ next_obs = env_response.observation
1570
+ new_achievement_state = _extract_achievements(next_obs)
1571
+ final_achievement_count = sum(
1572
+ 1 for _, unlocked in new_achievement_state.items() if unlocked
1573
+ )
1574
+ indicator_val = 0
1575
+ reward_stepwise = 0.0
1576
+ decision_rewards_meta: dict[str, Any] | None = None
1577
+ if step_rewards_active:
1578
+ decision_actions = _summarize_tool_calls(pending_tool_calls)
1579
+ stepwise_info, decision_record, stats = compute_stepwise_reward(
1580
+ prev_achievements or {},
1581
+ new_achievement_state,
1582
+ decision_index,
1583
+ decision_actions,
1584
+ step_rewards_indicator_lambda,
1585
+ )
1586
+ indicator_val = int(stats.get("indicator", 0.0))
1587
+ reward_stepwise = float(stats.get("reward", 0.0))
1588
+ stepwise_indicator_sum += float(stats.get("indicator", 0.0))
1589
+ stepwise_reward_sum += reward_stepwise
1590
+ stepwise_new_achievements_total += int(stats.get("new_achievements_count", 0.0))
1591
+ _info = {} if not isinstance(_info, dict) else dict(_info)
1592
+ _info["stepwise"] = stepwise_info
1593
+ # Compute decision-level rewards (absolute vs unique) and attach to metadata
1594
+ with contextlib.suppress(Exception):
1595
+ turned_true = set(stepwise_info.get("new_achievements") or [])
1596
+ seen_before = set(episode_seen_achievements)
1597
+ new_unique = sorted(turned_true - seen_before)
1598
+ ach_delta = int(len(turned_true))
1599
+ unique_delta = int(len(new_unique))
1600
+ # Prepare stable lists for logging/metadata
1601
+ all_list = sorted(turned_true)
1602
+ # Ensure nested meta exists
1603
+ meta_block = (
1604
+ _info.get("meta") if isinstance(_info.get("meta"), dict) else {}
1605
+ )
1606
+ decision_rewards = {
1607
+ "turn": int(decision_index),
1608
+ "ach_delta": ach_delta,
1609
+ "unique_delta": unique_delta,
1610
+ "all": all_list,
1611
+ "unique": new_unique,
1612
+ }
1613
+ decision_rewards_meta = decision_rewards
1614
+ meta_block["decision_rewards"] = decision_rewards
1615
+ _info["meta"] = meta_block
1616
+ # Update episode-level seen set after attributing uniqueness to this decision
1617
+ episode_seen_achievements.update(turned_true)
1618
+ decision_samples.append(decision_record)
1619
+ prev_achievements = new_achievement_state
1620
+
1621
+ await tracing_context.record_decision_reward(
1622
+ event_id=event_id,
1623
+ decision_meta=decision_rewards_meta,
1624
+ )
1625
+
1626
+ step = RolloutStep(
1627
+ obs=_summarize_observation_for_storage(env_handle, current_obs),
1628
+ tool_calls=pending_tool_calls,
1629
+ reward=env_response.reward,
1630
+ done=env_response.done,
1631
+ truncated=env_response.truncated,
1632
+ info=_info,
1633
+ )
1634
+ # Log summarized env application of tool calls and immediate reward/done
1635
+ with contextlib.suppress(Exception):
1636
+ _tc = pending_tool_calls or []
1637
+ _summary = []
1638
+ for _item in (_tc if isinstance(_tc, list) else []):
1639
+ try:
1640
+ if isinstance(_item, dict):
1641
+ _tool = _item.get("tool")
1642
+ _args = _item.get("args")
1643
+ _keys = list(_args.keys()) if isinstance(_args, dict) else []
1644
+ _summary.append({"tool": _tool, "args_keys": _keys})
1645
+ except Exception:
1646
+ continue
1647
+ _rid = getattr(request, "run_id", None)
1648
+ logger.info(
1649
+ "ENV_APPLY: run_id=%s tool_calls=%d reward=%s done=%s summary=%s",
1650
+ _rid,
1651
+ len(_tc),
1652
+ str(env_response.reward),
1653
+ str(env_response.done),
1654
+ _summary,
1655
+ )
1656
+ print(
1657
+ f"[rollout] env apply run_id={_rid} tool_calls={len(_tc)} reward={env_response.reward} done={env_response.done} summary={_summary}",
1658
+ flush=True,
1659
+ )
1660
+ trajectory_steps.append(step)
1661
+
1662
+ if env_response.reward is not None:
1663
+ total_reward += env_response.reward
1664
+
1665
+ # Update state
1666
+ current_obs = next_obs
1667
+ pending_tool_calls = None
1668
+ ops_executed += 1
1669
+
1670
+ # Handle episode end
1671
+ if env_response.done:
1672
+ if request.on_done == "reset":
1673
+ # Reset environment
1674
+ from .environment_routes import (
1675
+ EnvResetRequest,
1676
+ reset_environment,
1677
+ )
1678
+
1679
+ reset_response = await reset_environment(EnvResetRequest(env_id=env_id))
1680
+ current_obs = reset_response.observation
1681
+ elif request.on_done == "terminate":
1682
+ break
1683
+
1684
+ if decision_open:
1685
+ await tracing_context.end_decision()
1686
+ decision_open = False
1687
+
1688
+ else:
1689
+ logger.warning(f"Unknown op: {op}")
1690
+
1691
+ if (
1692
+ last_policy_meta is not None
1693
+ and last_agent_response_ts is not None
1694
+ and "timing" in last_policy_meta
1695
+ and isinstance(last_policy_meta["timing"], dict)
1696
+ and "decision_ms" not in last_policy_meta["timing"]
1697
+ ):
1698
+ with contextlib.suppress(Exception):
1699
+ final_now = last_env_step_completed_ts or _time.perf_counter()
1700
+ final_decision_ms = max(0.0, (final_now - float(last_agent_response_ts)) * 1000.0)
1701
+ timing_final = last_policy_meta.setdefault("timing", {})
1702
+ timing_final["decision_ms"] = final_decision_ms
1703
+ if last_env_step_ms is not None:
1704
+ timing_final.setdefault("env_step_ms", float(last_env_step_ms))
1705
+ timing_final.setdefault(
1706
+ "overhead_ms",
1707
+ max(0.0, final_decision_ms - float(last_env_step_ms)),
1708
+ )
1709
+ else:
1710
+ timing_final.setdefault("overhead_ms", 0.0)
1711
+
1712
+ # Build trajectory
1713
+ trajectory = RolloutTrajectory(
1714
+ env_id=env_id,
1715
+ policy_id=policy_id,
1716
+ steps=trajectory_steps,
1717
+ final={"observation": _summarize_observation_for_storage(env_handle, current_obs)},
1718
+ length=len(trajectory_steps),
1719
+ decision_samples=decision_samples if step_rewards_active else None,
1720
+ )
1721
+
1722
+ # Build metrics
1723
+ metrics = RolloutMetrics(
1724
+ episode_returns=[total_reward],
1725
+ mean_return=total_reward,
1726
+ num_steps=len(trajectory_steps),
1727
+ num_episodes=1,
1728
+ )
1729
+
1730
+ # Environment-specific: Log summary if available
1731
+ try:
1732
+ # Check if this is a Wordle environment and use Wordle helpers (lazy import)
1733
+ wordle_wrapper_cls = None
1734
+ try:
1735
+ from .envs.wordle.environment import WordleEnvironmentWrapper
1736
+ from .envs.wordle.helpers import (
1737
+ get_wordle_rollout_summary,
1738
+ log_wordle_rollout_summary,
1739
+ )
1740
+
1741
+ wordle_wrapper_cls = WordleEnvironmentWrapper
1742
+ except Exception:
1743
+ wordle_wrapper_cls = None # type: ignore[assignment]
1744
+ get_wordle_rollout_summary = None # type: ignore
1745
+ log_wordle_rollout_summary = None # type: ignore
1746
+
1747
+ is_wordle = wordle_wrapper_cls is not None and isinstance(
1748
+ env_handle.env,
1749
+ wordle_wrapper_cls, # type: ignore[arg-type]
1750
+ )
1751
+ if is_wordle:
1752
+ # Convert trajectory steps to expected format
1753
+ formatted_steps = []
1754
+ for step in trajectory_steps:
1755
+ formatted_steps.append({"tool_calls": step.tool_calls or []})
1756
+
1757
+ if (
1758
+ get_wordle_rollout_summary is not None
1759
+ and log_wordle_rollout_summary is not None
1760
+ ):
1761
+ summary = get_wordle_rollout_summary(formatted_steps, current_obs, env_handle)
1762
+ log_wordle_rollout_summary(request.run_id, summary)
1763
+ except ImportError:
1764
+ # Wordle helpers not available, skip Wordle-specific logging
1765
+ pass
1766
+ except Exception as e:
1767
+ logger.warning(f"Failed to generate environment-specific summary: {e}")
1768
+
1769
+ # Mark run as completed
1770
+ aborted = registry.is_run_aborted(request.run_id)
1771
+ if not aborted:
1772
+ registry.complete_run(request.run_id)
1773
+ if decision_open:
1774
+ await tracing_context.end_decision()
1775
+ decision_open = False
1776
+ if not finalized:
1777
+ session_trace = await tracing_context.finalize(
1778
+ total_reward=total_reward,
1779
+ achievement_state=prev_achievements,
1780
+ total_steps=len(trajectory_steps),
1781
+ )
1782
+ finalized = True
1783
+ trace_payload = tracing_context.build_trace_payload(session_trace)
1784
+
1785
+ return RolloutResponse(
1786
+ run_id=request.run_id,
1787
+ trajectories=[trajectory],
1788
+ branches={},
1789
+ metrics=metrics,
1790
+ aborted=aborted,
1791
+ ops_executed=ops_executed,
1792
+ trace=trace_payload,
1793
+ )
1794
+
1795
+ except Exception as e:
1796
+ logger.error(f"Rollout failed for run {request.run_id}: {e}")
1797
+ registry.abort_run(request.run_id)
1798
+ if decision_open:
1799
+ with contextlib.suppress(Exception):
1800
+ await tracing_context.end_decision()
1801
+ decision_open = False
1802
+ if not finalized:
1803
+ session_trace = None
1804
+ with contextlib.suppress(Exception):
1805
+ session_trace = await tracing_context.finalize(
1806
+ total_reward=total_reward,
1807
+ achievement_state=prev_achievements,
1808
+ total_steps=len(trajectory_steps),
1809
+ )
1810
+ finalized = True
1811
+ raise HTTPException(status_code=500, detail=str(e)) from e
1812
+ finally:
1813
+ # Ensure any environment created for this rollout is terminated (no reuse across rollouts)
1814
+ try:
1815
+ if created_env_id:
1816
+ from .environment_routes import EnvTerminateRequest, terminate_environment
1817
+
1818
+ await terminate_environment(EnvTerminateRequest(env_id=created_env_id))
1819
+ logger.info(
1820
+ "ROLL_OUT: terminated environment env_id=%s seed=%s",
1821
+ str(created_env_id),
1822
+ str(env_seed_used) if env_seed_used is not None else "unknown",
1823
+ )
1824
+ # Verify removal from registry
1825
+ with contextlib.suppress(Exception):
1826
+ _post = registry.get_env(created_env_id)
1827
+ logger.info(
1828
+ "ROLL_OUT: env_killed=%s (post_lookup=%s)",
1829
+ str(_post is None),
1830
+ str(_post),
1831
+ )
1832
+ except Exception as _te:
1833
+ logger.warning(f"ROLL_OUT: failed to terminate environment {created_env_id}: {_te}")
1834
+
1835
+ # Best-effort policy cleanup if we created one (avoid reuse across rollouts)
1836
+ with contextlib.suppress(Exception):
1837
+ if created_policy_id:
1838
+ from .policy_routes import PolicyTerminateRequest, terminate_policy
1839
+
1840
+ await terminate_policy(PolicyTerminateRequest(policy_id=created_policy_id))
1841
+ logger.info("ROLL_OUT: terminated policy policy_id=%s", str(created_policy_id))
1842
+
1843
+ if not finalized:
1844
+ session_trace = None
1845
+ with contextlib.suppress(Exception):
1846
+ session_trace = await tracing_context.finalize(
1847
+ total_reward=total_reward,
1848
+ achievement_state=prev_achievements,
1849
+ total_steps=len(trajectory_steps),
1850
+ )
1851
+ finalized = True
1852
+
1853
+ with contextlib.suppress(Exception):
1854
+ _clear_seed_side_effects()
1855
+ logger.info("ROLL_OUT: RNG seed terminated/cleared before conclusion")
1856
+
1857
+
1858
@router.post("/run/abort", response_model=RunAbortResponse)
async def abort_run(request: RunAbortRequest) -> RunAbortResponse:
    """Abort an in-flight rollout identified by ``request.run_id``.

    Delegates to ``registry.abort_run``. A falsy result is reported to the
    client as HTTP 404.
    """
    run_id = request.run_id
    if registry.abort_run(run_id):
        return RunAbortResponse(ok=True, run_id=run_id)
    raise HTTPException(status_code=404, detail=f"Run {run_id} not found")
+
1874
+
1875
@router.get("/run/status/{run_id}", response_model=RunStatusResponse)
async def get_run_status(run_id: str) -> RunStatusResponse:
    """Report the status and start/finish timestamps of a registered run.

    Raises HTTP 404 when ``registry.get_run`` returns a falsy handle.
    """
    handle = registry.get_run(run_id)
    if handle:
        return RunStatusResponse(
            run_id=run_id,
            status=handle.status,
            started_at=handle.started_at,
            finished_at=handle.finished_at,
        )
    raise HTTPException(status_code=404, detail=f"Run {run_id} not found")