synth-ai 0.2.14__py3-none-any.whl → 0.2.17__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of synth-ai might be problematic.

Files changed (354)
  1. examples/README.md +1 -0
  2. examples/analyze_semantic_words.sh +2 -2
  3. examples/blog_posts/pokemon_vl/README.md +98 -0
  4. examples/blog_posts/pokemon_vl/configs/eval_qwen3_vl.toml +25 -0
  5. examples/blog_posts/pokemon_vl/configs/eval_rl_final.toml +24 -0
  6. examples/blog_posts/pokemon_vl/configs/filter_high_reward.toml +10 -0
  7. examples/blog_posts/pokemon_vl/configs/train_rl_from_sft.toml +42 -0
  8. examples/blog_posts/pokemon_vl/configs/train_sft_qwen4b_vl.toml +40 -0
  9. examples/blog_posts/warming_up_to_rl/README.md +158 -0
  10. examples/blog_posts/warming_up_to_rl/configs/eval_ft_qwen4b.toml +25 -0
  11. examples/blog_posts/warming_up_to_rl/configs/eval_groq_qwen32b.toml +25 -0
  12. examples/blog_posts/warming_up_to_rl/configs/eval_openai_gpt_oss_120b.toml +29 -0
  13. examples/blog_posts/warming_up_to_rl/configs/filter_high_reward_dataset.toml +10 -0
  14. examples/blog_posts/warming_up_to_rl/configs/train_rl_from_sft.toml +41 -0
  15. examples/blog_posts/warming_up_to_rl/configs/train_sft_qwen4b.toml +40 -0
  16. examples/dev/qwen3_32b_qlora_4xh100.toml +5 -0
  17. examples/multi_step/SFT_README.md +147 -0
  18. examples/multi_step/configs/crafter_rl_outcome.toml +1 -1
  19. examples/multi_step/configs/crafter_rl_stepwise_hosted_judge.toml +73 -115
  20. examples/multi_step/configs/crafter_rl_stepwise_shaped.toml +1 -1
  21. examples/multi_step/configs/crafter_rl_stepwise_simple.toml +1 -1
  22. examples/multi_step/configs/crafter_rl_stepwise_simple_NEW_FORMAT.toml +105 -0
  23. examples/multi_step/configs/crafter_sft_qwen30b_lora.toml +62 -0
  24. examples/multi_step/configs/verilog_rl_lora.toml +80 -123
  25. examples/multi_step/convert_traces_to_sft.py +84 -0
  26. examples/multi_step/run_sft_qwen30b.sh +45 -0
  27. examples/qwen_coder/configs/coder_lora_30b.toml +1 -2
  28. examples/qwen_coder/configs/coder_lora_4b.toml +5 -1
  29. examples/qwen_coder/configs/coder_lora_small.toml +1 -2
  30. examples/qwen_vl/BUGS_AND_FIXES.md +232 -0
  31. examples/qwen_vl/IMAGE_VALIDATION_COMPLETE.md +271 -0
  32. examples/qwen_vl/IMAGE_VALIDATION_SUMMARY.md +260 -0
  33. examples/qwen_vl/INFERENCE_SFT_TESTS.md +412 -0
  34. examples/qwen_vl/NEXT_STEPS_2B.md +325 -0
  35. examples/qwen_vl/QUICKSTART.md +327 -0
  36. examples/qwen_vl/QUICKSTART_RL_VISION.md +110 -0
  37. examples/qwen_vl/README.md +152 -0
  38. examples/qwen_vl/RL_VISION_COMPLETE.md +475 -0
  39. examples/qwen_vl/RL_VISION_TESTING.md +333 -0
  40. examples/qwen_vl/SDK_VISION_INTEGRATION.md +328 -0
  41. examples/qwen_vl/SETUP_COMPLETE.md +274 -0
  42. examples/qwen_vl/VISION_TESTS_COMPLETE.md +489 -0
  43. examples/qwen_vl/VLM_PIPELINE_COMPLETE.md +242 -0
  44. examples/qwen_vl/__init__.py +2 -0
  45. examples/qwen_vl/collect_data_via_cli.md +415 -0
  46. examples/qwen_vl/collect_vision_traces.py +368 -0
  47. examples/qwen_vl/configs/crafter_rl_vision_qwen3vl4b.toml +110 -0
  48. examples/qwen_vl/configs/crafter_vlm_sft_example.toml +59 -0
  49. examples/qwen_vl/configs/eval_gpt4o_mini_vision.toml +26 -0
  50. examples/qwen_vl/configs/eval_gpt4o_vision_proper.toml +29 -0
  51. examples/qwen_vl/configs/eval_gpt5nano_vision.toml +26 -0
  52. examples/qwen_vl/configs/eval_qwen3vl_vision.toml +26 -0
  53. examples/qwen_vl/configs/filter_qwen3vl_sft.toml +49 -0
  54. examples/qwen_vl/configs/filter_vision_sft.toml +52 -0
  55. examples/qwen_vl/configs/filter_vision_test.toml +8 -0
  56. examples/qwen_vl/configs/sft_qwen3_vl_2b_test.toml +54 -0
  57. examples/qwen_vl/crafter_gpt5nano_agent.py +308 -0
  58. examples/qwen_vl/crafter_qwen_vl_agent.py +300 -0
  59. examples/qwen_vl/run_vision_comparison.sh +61 -0
  60. examples/qwen_vl/run_vision_sft_pipeline.sh +175 -0
  61. examples/qwen_vl/test_image_validation.py +201 -0
  62. examples/qwen_vl/test_sft_vision_data.py +110 -0
  63. examples/rl/README.md +6 -6
  64. examples/rl/configs/eval_base_qwen.toml +17 -0
  65. examples/rl/configs/eval_rl_qwen.toml +13 -0
  66. examples/rl/configs/rl_from_base_qwen.toml +62 -0
  67. examples/rl/configs/rl_from_base_qwen17.toml +79 -0
  68. examples/rl/configs/rl_from_ft_qwen.toml +37 -0
  69. examples/rl/run_eval.py +436 -0
  70. examples/rl/run_rl_and_save.py +111 -0
  71. examples/rl/task_app/README.md +21 -0
  72. examples/rl/task_app/math_single_step.py +990 -0
  73. examples/rl/task_app/math_task_app.py +111 -0
  74. examples/run_crafter_demo.sh +2 -2
  75. examples/sft/README.md +6 -6
  76. examples/sft/configs/crafter_fft_qwen0p6b.toml +7 -2
  77. examples/sft/configs/crafter_lora_qwen0p6b.toml +7 -3
  78. examples/sft/evaluate.py +2 -4
  79. examples/sft/export_dataset.py +7 -4
  80. examples/swe/task_app/README.md +33 -3
  81. examples/swe/task_app/grpo_swe_mini.py +4 -1
  82. examples/swe/task_app/grpo_swe_mini_task_app.py +0 -12
  83. examples/swe/task_app/hosted/envs/crafter/react_agent.py +1 -1
  84. examples/swe/task_app/hosted/envs/mini_swe/environment.py +50 -23
  85. examples/swe/task_app/hosted/inference/openai_client.py +4 -4
  86. examples/swe/task_app/hosted/policy_routes.py +0 -2
  87. examples/swe/task_app/hosted/rollout.py +0 -8
  88. examples/swe/task_app/morph_backend.py +178 -0
  89. examples/task_apps/crafter/task_app/README.md +1 -1
  90. examples/task_apps/crafter/task_app/grpo_crafter.py +70 -10
  91. examples/task_apps/crafter/task_app/grpo_crafter_task_app.py +1 -1
  92. examples/task_apps/crafter/task_app/synth_envs_hosted/envs/crafter/policy.py +63 -27
  93. examples/task_apps/crafter/task_app/synth_envs_hosted/envs/crafter/react_agent.py +1 -2
  94. examples/task_apps/crafter/task_app/synth_envs_hosted/inference/openai_client.py +48 -50
  95. examples/task_apps/crafter/task_app/synth_envs_hosted/policy_routes.py +75 -36
  96. examples/task_apps/crafter/task_app/synth_envs_hosted/rollout.py +31 -15
  97. examples/task_apps/enron/__init__.py +1 -0
  98. examples/task_apps/enron/task_app/grpo_enron_task_app.py +1 -1
  99. examples/task_apps/math/README.md +1 -2
  100. examples/task_apps/pokemon_red/README.md +3 -4
  101. examples/task_apps/pokemon_red/eval_image_only_gpt4o.toml +6 -5
  102. examples/task_apps/pokemon_red/eval_pokemon_red_policy.py +1 -2
  103. examples/task_apps/pokemon_red/task_app.py +36 -5
  104. examples/task_apps/sokoban/README.md +2 -3
  105. examples/task_apps/verilog/eval_groq_qwen32b.toml +12 -14
  106. examples/task_apps/verilog/task_app/grpo_verilog_task_app.py +1 -1
  107. examples/vlm/README.md +3 -3
  108. examples/vlm/configs/crafter_vlm_gpt4o.toml +5 -0
  109. examples/vlm/crafter_openai_vlm_agent.py +3 -5
  110. examples/vlm/filter_image_rows.py +1 -1
  111. examples/vlm/run_crafter_vlm_benchmark.py +2 -2
  112. examples/warming_up_to_rl/_utils.py +92 -0
  113. examples/warming_up_to_rl/analyze_trace_db.py +1 -1
  114. examples/warming_up_to_rl/configs/crafter_fft.toml +5 -0
  115. examples/warming_up_to_rl/configs/eval_fft_qwen4b.toml +2 -0
  116. examples/warming_up_to_rl/configs/eval_groq_qwen32b.toml +2 -0
  117. examples/warming_up_to_rl/configs/eval_modal_qwen4b.toml +2 -1
  118. examples/warming_up_to_rl/configs/rl_from_base_qwen4b.toml +2 -1
  119. examples/warming_up_to_rl/configs/rl_from_ft.toml +2 -0
  120. examples/warming_up_to_rl/export_trace_sft.py +174 -60
  121. examples/warming_up_to_rl/readme.md +63 -132
  122. examples/warming_up_to_rl/run_fft_and_save.py +1 -1
  123. examples/warming_up_to_rl/run_local_rollout_traced.py +1 -1
  124. examples/warming_up_to_rl/run_rl_and_save.py +1 -1
  125. examples/warming_up_to_rl/task_app/README.md +42 -0
  126. examples/warming_up_to_rl/task_app/grpo_crafter.py +827 -0
  127. examples/warming_up_to_rl/task_app/grpo_crafter_task_app.py +135 -0
  128. examples/warming_up_to_rl/task_app/synth_envs_hosted/README.md +173 -0
  129. examples/warming_up_to_rl/task_app/synth_envs_hosted/__init__.py +5 -0
  130. examples/warming_up_to_rl/task_app/synth_envs_hosted/branching.py +143 -0
  131. examples/warming_up_to_rl/task_app/synth_envs_hosted/environment_routes.py +1226 -0
  132. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/__init__.py +1 -0
  133. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/__init__.py +6 -0
  134. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/app.py +1 -0
  135. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/environment.py +522 -0
  136. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/policy.py +454 -0
  137. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/react_agent.py +108 -0
  138. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/shared.py +305 -0
  139. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/tools.py +47 -0
  140. examples/warming_up_to_rl/task_app/synth_envs_hosted/hosted_app.py +204 -0
  141. examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/__init__.py +5 -0
  142. examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/openai_client.py +618 -0
  143. examples/warming_up_to_rl/task_app/synth_envs_hosted/main.py +100 -0
  144. examples/warming_up_to_rl/task_app/synth_envs_hosted/policy_routes.py +1084 -0
  145. examples/warming_up_to_rl/task_app/synth_envs_hosted/registry.py +195 -0
  146. examples/warming_up_to_rl/task_app/synth_envs_hosted/rollout.py +1861 -0
  147. examples/warming_up_to_rl/task_app/synth_envs_hosted/storage/__init__.py +5 -0
  148. examples/warming_up_to_rl/task_app/synth_envs_hosted/storage/volume.py +211 -0
  149. examples/warming_up_to_rl/task_app/synth_envs_hosted/test_agents.py +161 -0
  150. examples/warming_up_to_rl/task_app/synth_envs_hosted/test_service.py +137 -0
  151. examples/warming_up_to_rl/task_app/synth_envs_hosted/utils.py +62 -0
  152. examples/workflows/math_rl/configs/rl_from_base_qwen.toml +27 -0
  153. examples/workflows/math_rl/configs/rl_from_base_qwen17.toml +5 -0
  154. synth_ai/__init__.py +44 -30
  155. synth_ai/_utils/__init__.py +47 -0
  156. synth_ai/_utils/base_url.py +10 -0
  157. synth_ai/_utils/http.py +10 -0
  158. synth_ai/_utils/prompts.py +10 -0
  159. synth_ai/_utils/task_app_state.py +12 -0
  160. synth_ai/_utils/user_config.py +10 -0
  161. synth_ai/api/models/supported.py +144 -7
  162. synth_ai/api/train/__init__.py +13 -1
  163. synth_ai/api/train/builders.py +9 -3
  164. synth_ai/api/train/cli.py +155 -17
  165. synth_ai/api/train/config_finder.py +18 -11
  166. synth_ai/api/train/configs/__init__.py +8 -1
  167. synth_ai/api/train/configs/rl.py +32 -7
  168. synth_ai/api/train/configs/sft.py +6 -2
  169. synth_ai/api/train/configs/shared.py +59 -2
  170. synth_ai/api/train/env_resolver.py +13 -10
  171. synth_ai/auth/credentials.py +119 -0
  172. synth_ai/cli/__init__.py +61 -69
  173. synth_ai/cli/_modal_wrapper.py +7 -5
  174. synth_ai/cli/_typer_patch.py +0 -2
  175. synth_ai/cli/_validate_task_app.py +22 -4
  176. synth_ai/cli/commands/__init__.py +17 -0
  177. synth_ai/cli/commands/demo/__init__.py +6 -0
  178. synth_ai/cli/commands/demo/core.py +163 -0
  179. synth_ai/cli/commands/deploy/__init__.py +23 -0
  180. synth_ai/cli/commands/deploy/core.py +614 -0
  181. synth_ai/cli/commands/deploy/errors.py +72 -0
  182. synth_ai/cli/commands/deploy/validation.py +11 -0
  183. synth_ai/cli/commands/eval/__init__.py +19 -0
  184. synth_ai/cli/commands/eval/core.py +1109 -0
  185. synth_ai/cli/commands/eval/errors.py +81 -0
  186. synth_ai/cli/commands/eval/validation.py +133 -0
  187. synth_ai/cli/commands/filter/__init__.py +12 -0
  188. synth_ai/cli/commands/filter/core.py +388 -0
  189. synth_ai/cli/commands/filter/errors.py +55 -0
  190. synth_ai/cli/commands/filter/validation.py +77 -0
  191. synth_ai/cli/commands/help/__init__.py +177 -0
  192. synth_ai/cli/commands/help/core.py +73 -0
  193. synth_ai/cli/commands/status/__init__.py +64 -0
  194. synth_ai/cli/commands/status/client.py +192 -0
  195. synth_ai/cli/commands/status/config.py +92 -0
  196. synth_ai/cli/commands/status/errors.py +20 -0
  197. synth_ai/cli/commands/status/formatters.py +164 -0
  198. synth_ai/cli/commands/status/subcommands/__init__.py +9 -0
  199. synth_ai/cli/commands/status/subcommands/files.py +79 -0
  200. synth_ai/cli/commands/status/subcommands/jobs.py +334 -0
  201. synth_ai/cli/commands/status/subcommands/models.py +79 -0
  202. synth_ai/cli/commands/status/subcommands/runs.py +81 -0
  203. synth_ai/cli/commands/status/subcommands/summary.py +47 -0
  204. synth_ai/cli/commands/status/utils.py +114 -0
  205. synth_ai/cli/commands/train/__init__.py +53 -0
  206. synth_ai/cli/commands/train/core.py +21 -0
  207. synth_ai/cli/commands/train/errors.py +117 -0
  208. synth_ai/cli/commands/train/judge_schemas.py +199 -0
  209. synth_ai/cli/commands/train/judge_validation.py +304 -0
  210. synth_ai/cli/commands/train/validation.py +443 -0
  211. synth_ai/cli/demo.py +2 -162
  212. synth_ai/cli/deploy/__init__.py +28 -0
  213. synth_ai/cli/deploy/core.py +5 -0
  214. synth_ai/cli/deploy/errors.py +23 -0
  215. synth_ai/cli/deploy/validation.py +5 -0
  216. synth_ai/cli/eval/__init__.py +36 -0
  217. synth_ai/cli/eval/core.py +5 -0
  218. synth_ai/cli/eval/errors.py +31 -0
  219. synth_ai/cli/eval/validation.py +5 -0
  220. synth_ai/cli/filter/__init__.py +28 -0
  221. synth_ai/cli/filter/core.py +5 -0
  222. synth_ai/cli/filter/errors.py +23 -0
  223. synth_ai/cli/filter/validation.py +5 -0
  224. synth_ai/cli/legacy_root_backup.py +3 -1
  225. synth_ai/cli/lib/__init__.py +10 -0
  226. synth_ai/cli/lib/task_app_discovery.py +7 -0
  227. synth_ai/cli/lib/task_app_env.py +518 -0
  228. synth_ai/cli/modal_serve/__init__.py +12 -0
  229. synth_ai/cli/modal_serve/core.py +14 -0
  230. synth_ai/cli/modal_serve/errors.py +8 -0
  231. synth_ai/cli/modal_serve/validation.py +11 -0
  232. synth_ai/cli/recent.py +2 -1
  233. synth_ai/cli/serve/__init__.py +12 -0
  234. synth_ai/cli/serve/core.py +14 -0
  235. synth_ai/cli/serve/errors.py +8 -0
  236. synth_ai/cli/serve/validation.py +11 -0
  237. synth_ai/cli/setup.py +21 -0
  238. synth_ai/cli/status.py +7 -126
  239. synth_ai/cli/task_app_deploy.py +7 -0
  240. synth_ai/cli/task_app_list.py +25 -0
  241. synth_ai/cli/task_app_modal_serve.py +11 -0
  242. synth_ai/cli/task_app_serve.py +11 -0
  243. synth_ai/cli/task_apps.py +110 -1499
  244. synth_ai/cli/traces.py +1 -1
  245. synth_ai/cli/train/__init__.py +12 -0
  246. synth_ai/cli/train/core.py +21 -0
  247. synth_ai/cli/train/errors.py +8 -0
  248. synth_ai/cli/train/validation.py +24 -0
  249. synth_ai/cli/train.py +5 -0
  250. synth_ai/cli/turso.py +1 -1
  251. synth_ai/cli/watch.py +1 -1
  252. synth_ai/demos/__init__.py +10 -0
  253. synth_ai/demos/core/__init__.py +28 -1
  254. synth_ai/demos/crafter/__init__.py +1 -0
  255. synth_ai/demos/crafter/crafter_fft_4b.toml +55 -0
  256. synth_ai/demos/crafter/grpo_crafter_task_app.py +185 -0
  257. synth_ai/demos/crafter/rl_from_base_qwen4b.toml +74 -0
  258. synth_ai/demos/demo_registry.py +176 -0
  259. synth_ai/demos/demo_task_apps/crafter/grpo_crafter_task_app.py +1 -1
  260. synth_ai/demos/math/__init__.py +1 -0
  261. synth_ai/demos/math/_common.py +16 -0
  262. synth_ai/demos/math/app.py +38 -0
  263. synth_ai/demos/math/config.toml +76 -0
  264. synth_ai/demos/math/deploy_modal.py +54 -0
  265. synth_ai/demos/math/modal_task_app.py +702 -0
  266. synth_ai/demos/math/task_app_entry.py +51 -0
  267. synth_ai/environments/environment/core.py +7 -1
  268. synth_ai/environments/examples/bandit/engine.py +0 -1
  269. synth_ai/environments/examples/bandit/environment.py +0 -1
  270. synth_ai/environments/examples/red/engine.py +33 -12
  271. synth_ai/environments/examples/red/engine_helpers/reward_components.py +151 -179
  272. synth_ai/environments/examples/red/environment.py +26 -0
  273. synth_ai/environments/examples/red/trace_hooks_v3.py +168 -0
  274. synth_ai/environments/examples/wordle/environment.py +0 -1
  275. synth_ai/evals/base.py +16 -5
  276. synth_ai/evals/client.py +1 -1
  277. synth_ai/http.py +8 -22
  278. synth_ai/inference/client.py +1 -1
  279. synth_ai/judge_schemas.py +4 -5
  280. synth_ai/learning/client.py +1 -1
  281. synth_ai/learning/health.py +1 -1
  282. synth_ai/learning/jobs.py +1 -1
  283. synth_ai/learning/rl/client.py +4 -2
  284. synth_ai/learning/rl/env_keys.py +1 -1
  285. synth_ai/learning/rl/secrets.py +1 -1
  286. synth_ai/learning/sft/client.py +1 -1
  287. synth_ai/learning/sft/data.py +407 -4
  288. synth_ai/learning/validators.py +4 -1
  289. synth_ai/streaming/__init__.py +29 -0
  290. synth_ai/streaming/config.py +94 -0
  291. synth_ai/streaming/handlers.py +469 -0
  292. synth_ai/streaming/streamer.py +301 -0
  293. synth_ai/streaming/types.py +95 -0
  294. synth_ai/task/apps/__init__.py +4 -2
  295. synth_ai/task/config.py +6 -4
  296. synth_ai/task/rubrics/__init__.py +1 -2
  297. synth_ai/task/rubrics/loaders.py +14 -10
  298. synth_ai/task/rubrics.py +219 -0
  299. synth_ai/task/trace_correlation_helpers.py +24 -11
  300. synth_ai/task/tracing_utils.py +14 -3
  301. synth_ai/task/validators.py +0 -1
  302. synth_ai/tracing_v3/abstractions.py +3 -3
  303. synth_ai/tracing_v3/config.py +15 -13
  304. synth_ai/tracing_v3/constants.py +21 -0
  305. synth_ai/tracing_v3/db_config.py +3 -1
  306. synth_ai/tracing_v3/decorators.py +10 -7
  307. synth_ai/tracing_v3/llm_call_record_helpers.py +5 -5
  308. synth_ai/tracing_v3/migration_helper.py +1 -2
  309. synth_ai/tracing_v3/session_tracer.py +7 -7
  310. synth_ai/tracing_v3/storage/base.py +29 -29
  311. synth_ai/tracing_v3/storage/config.py +3 -3
  312. synth_ai/tracing_v3/turso/daemon.py +8 -9
  313. synth_ai/tracing_v3/turso/native_manager.py +80 -72
  314. synth_ai/tracing_v3/utils.py +2 -2
  315. synth_ai/utils/__init__.py +101 -0
  316. synth_ai/utils/base_url.py +94 -0
  317. synth_ai/utils/cli.py +131 -0
  318. synth_ai/utils/env.py +294 -0
  319. synth_ai/utils/http.py +172 -0
  320. synth_ai/utils/modal.py +308 -0
  321. synth_ai/utils/process.py +212 -0
  322. synth_ai/utils/prompts.py +39 -0
  323. synth_ai/utils/sqld.py +122 -0
  324. synth_ai/utils/task_app_discovery.py +882 -0
  325. synth_ai/utils/task_app_env.py +186 -0
  326. synth_ai/utils/task_app_state.py +318 -0
  327. synth_ai/utils/user_config.py +137 -0
  328. synth_ai/v0/config/__init__.py +1 -5
  329. synth_ai/v0/config/base_url.py +1 -7
  330. synth_ai/v0/tracing/config.py +1 -1
  331. synth_ai/v0/tracing/decorators.py +1 -1
  332. synth_ai/v0/tracing/upload.py +1 -1
  333. synth_ai/v0/tracing_v1/config.py +1 -1
  334. synth_ai/v0/tracing_v1/decorators.py +1 -1
  335. synth_ai/v0/tracing_v1/upload.py +1 -1
  336. {synth_ai-0.2.14.dist-info → synth_ai-0.2.17.dist-info}/METADATA +91 -32
  337. {synth_ai-0.2.14.dist-info → synth_ai-0.2.17.dist-info}/RECORD +341 -154
  338. synth_ai/cli/man.py +0 -106
  339. synth_ai/cli/tui.py +0 -57
  340. synth_ai/compound/cais.py +0 -0
  341. synth_ai/core/experiment.py +0 -13
  342. synth_ai/core/system.py +0 -15
  343. synth_ai/demo_registry.py +0 -295
  344. synth_ai/handshake.py +0 -109
  345. synth_ai/tui/__init__.py +0 -5
  346. synth_ai/tui/__main__.py +0 -13
  347. synth_ai/tui/cli/__init__.py +0 -1
  348. synth_ai/tui/cli/query_experiments.py +0 -164
  349. synth_ai/tui/cli/query_experiments_v3.py +0 -164
  350. synth_ai/tui/dashboard.py +0 -906
  351. {synth_ai-0.2.14.dist-info → synth_ai-0.2.17.dist-info}/WHEEL +0 -0
  352. {synth_ai-0.2.14.dist-info → synth_ai-0.2.17.dist-info}/entry_points.txt +0 -0
  353. {synth_ai-0.2.14.dist-info → synth_ai-0.2.17.dist-info}/licenses/LICENSE +0 -0
  354. {synth_ai-0.2.14.dist-info → synth_ai-0.2.17.dist-info}/top_level.txt +0 -0
examples/warming_up_to_rl/task_app/synth_envs_hosted/rollout.py
@@ -0,0 +1,1861 @@
+ from __future__ import annotations
+
+ import contextlib
+ import json
+ import logging
+ import os
+ import time as _time
+ from datetime import datetime
+ from typing import Any
+
+ from fastapi import APIRouter, HTTPException, Request, status
+ from pydantic import BaseModel
+ from synth_ai.lm.vendors.base import BaseLMResponse
+ from synth_ai.task.tracing_utils import unique_sft_path
+ from synth_ai.tracing_v3.abstractions import EnvironmentEvent, LMCAISEvent, TimeRecord
+ from synth_ai.tracing_v3.llm_call_record_helpers import create_llm_call_record_from_response
+ from synth_ai.tracing_v3.session_tracer import SessionTracer
+
+ from .registry import registry
+
+ logger = logging.getLogger(__name__)
+
+
+ # --- Seeding utilities (robust, optional deps) ---
+ def _set_global_seed(seed_value: int) -> dict[str, Any]:
+ """Set global RNG seeds across common libraries; return details for logging/restoration.
+
+ Returns a dict containing which libraries were seeded and prior states if obtainable.
+ """
+ seeded: dict[str, Any] = {"seed": int(seed_value), "libs": []}
+ with contextlib.suppress(Exception):
+ import random as _random # type: ignore
+
+ _random.seed(seed_value)
+ seeded["libs"].append("random")
+ with contextlib.suppress(Exception):
+ import numpy as _np # type: ignore
+
+ _np.random.seed(seed_value)
+ seeded["libs"].append("numpy")
+ with contextlib.suppress(Exception):
+ import torch as _torch # type: ignore
+
+ if hasattr(_torch, "manual_seed"):
+ _torch.manual_seed(seed_value)
+ seeded["libs"].append("torch")
+ # Make CUDA deterministic if present (best-effort)
+ with contextlib.suppress(Exception):
+ if getattr(_torch, "cuda", None) and _torch.cuda.is_available():
+ _torch.cuda.manual_seed_all(seed_value)
+ seeded.setdefault("cuda", True)
+ # CUDNN deterministic flags (optional)
+ with contextlib.suppress(Exception):
+ if getattr(_torch, "backends", None) and getattr(_torch.backends, "cudnn", None):
+ _torch.backends.cudnn.deterministic = True # type: ignore[attr-defined]
+ _torch.backends.cudnn.benchmark = False # type: ignore[attr-defined]
+ return seeded
+
+
+ def _clear_seed_side_effects() -> None:
+ """Best-effort cleanup to avoid global deterministic side-effects between requests."""
+ # We cannot truly restore prior RNG states without capturing them; we just avoid
+ # leaving aggressive deterministic flags enabled where it matters.
+ with contextlib.suppress(Exception):
+ import torch as _torch # type: ignore
+
+ with contextlib.suppress(Exception):
+ if getattr(_torch, "backends", None) and getattr(_torch.backends, "cudnn", None):
+ # Re-enable cudnn.benchmark default True only if it was True; safest is False -> leave as is.
+ # We'll keep deterministic False to avoid global impact; benchmark left False for stability.
+ _torch.backends.cudnn.deterministic = False # type: ignore[attr-defined]
+
+
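# Illustrative sketch (not in the diff): pairing the two helpers above the way the
# execute_rollout handler further down does. When a request omits env.seed, the handler
# derives a stable 32-bit seed from run_id via sha256 before calling _set_global_seed;
# _clear_seed_side_effects is the intended best-effort cleanup between requests.
import hashlib

def _derived_seed(run_id: str) -> int:
    return int(hashlib.sha256(run_id.encode("utf-8")).hexdigest()[:8], 16)

seed_info = _set_global_seed(_derived_seed("run-demo-001"))
print(seed_info["seed"], seed_info["libs"])
_clear_seed_side_effects()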
+ router = APIRouter()
+
+
+ class RolloutEnvSpec(BaseModel):
+ env_id: str | None = None
+ env_name: str | None = None
+ config: dict[str, Any] = {}
+ seed: int | None = None
+
+
+ class RolloutPolicySpec(BaseModel):
+ policy_id: str | None = None
+ policy_name: str | None = None
+ config: dict[str, Any] = {}
+
+
+ class RolloutBranchConfig(BaseModel):
+ branch_every_n_steps: int = 0
+ branch_on_condition: str | None = None
+ max_branches: int = 0
+ branch_policy: bool = False
+ branch_env: bool = False
+
+
+ class RolloutRecordConfig(BaseModel):
+ trajectories: bool = True
+ logprobs: bool = False
+ value: bool = False
+ return_trace: bool = False
+ trace_format: str = "compact"
+
+
+ class RolloutSafetyConfig(BaseModel):
+ max_ops: int = 100000
+ max_time_s: float = 3600.0
+
+
+ class RolloutRequest(BaseModel):
+ run_id: str
+ env: RolloutEnvSpec
+ policy: RolloutPolicySpec
+ ops: list[str] # ["agent", "env", ...]
+ record: RolloutRecordConfig = RolloutRecordConfig()
+ on_done: str = "reset" # "reset" | "terminate"
+ branch: RolloutBranchConfig | None = None
+ safety: RolloutSafetyConfig = RolloutSafetyConfig()
+ # Optional run/session context
+ training_session_id: str | None = None
+ synth_base_url: str | None = None
+
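# Illustrative sketch (not in the diff): a minimal client payload matching the
# RolloutRequest / RolloutEnvSpec / RolloutPolicySpec models above, posted with httpx.
# The URL, model name, and env/policy names are placeholders; the handler further down
# authenticates the X-API-Key header against ENVIRONMENT_API_KEY (or DEV_ENVIRONMENT_API_KEY).
import os
import httpx

payload = {
    "run_id": "run-demo-001",
    "env": {
        "env_name": "crafter",
        "seed": 42,
        "config": {"env_params": {"max_steps_per_episode": 10}},
    },
    "policy": {"policy_name": "crafter-react", "config": {"model": "gpt-4o-mini"}},
    "ops": ["agent", "env"] * 10,  # one [agent, env] pair per decision
    "record": {"return_trace": True, "trace_format": "compact"},
}
resp = httpx.post(
    "http://localhost:8001/rollout",
    json=payload,
    headers={"X-API-Key": os.environ["ENVIRONMENT_API_KEY"]},
    timeout=600.0,
)
print(resp.json()["metrics"])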
+
+ class RolloutStep(BaseModel):
+ obs: dict[str, Any]
+ tool_calls: list[dict[str, Any]]
+ reward: float | None = None
+ done: bool = False
+ truncated: bool | None = None
+ logprob: float | None = None
+ value: float | None = None
+ info: dict[str, Any] | None = None
+
+
+ class RolloutTrajectory(BaseModel):
+ env_id: str
+ policy_id: str
+ steps: list[RolloutStep]
+ final: dict[str, Any] | None = None
+ length: int
+ decision_samples: list[dict[str, Any]] | None = None
+
+
+ def compute_stepwise_reward(
+ prev_achievements: dict[str, bool],
+ new_achievements: dict[str, bool],
+ decision_index: int,
+ actions_summary: list[dict[str, Any]],
+ indicator_lambda: float,
+ ) -> tuple[dict[str, Any], dict[str, Any], dict[str, float]]:
+ """Compute stepwise reward metadata given achievement states before/after a decision."""
+
+ prev_map = prev_achievements or {}
+ next_map = new_achievements or {}
+
+ unlocked = [name for name, value in next_map.items() if value and not prev_map.get(name, False)]
+ indicator = 1 if unlocked else 0
+ reward_value = float(indicator_lambda) * indicator
+
+ stepwise_info = {
+ "decision_index": decision_index,
+ "indicator": indicator,
+ "new_achievements": unlocked,
+ "reward": reward_value,
+ }
+ decision_sample = {
+ "decision_index": decision_index,
+ "indicator": indicator,
+ "r_i": reward_value,
+ "actions": actions_summary,
+ }
+ stats = {
+ "indicator": float(indicator),
+ "reward": reward_value,
+ "new_achievements_count": float(len(unlocked)),
+ }
+ return stepwise_info, decision_sample, stats
+
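# Illustrative sketch (not in the diff): compute_stepwise_reward fires when at least one
# achievement flips from False to True across a decision, scaling the 0/1 indicator by
# indicator_lambda. With one newly unlocked achievement and indicator_lambda=0.5:
info, sample, stats = compute_stepwise_reward(
    prev_achievements={"collect_wood": True, "place_table": False},
    new_achievements={"collect_wood": True, "place_table": True},
    decision_index=3,
    actions_summary=[{"tool": "interact", "args": {"action": "place_table"}}],
    indicator_lambda=0.5,
)
assert info["new_achievements"] == ["place_table"]
assert sample["indicator"] == 1 and sample["r_i"] == 0.5
assert stats["new_achievements_count"] == 1.0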
+
+ class RolloutMetrics(BaseModel):
+ episode_returns: list[float]
+ mean_return: float
+ num_steps: int
+ num_episodes: int = 0
+
+
+ class RolloutResponse(BaseModel):
+ run_id: str
+ trajectories: list[RolloutTrajectory]
+ branches: dict[str, list[str]] = {}
+ metrics: RolloutMetrics
+ aborted: bool = False
+ ops_executed: int = 0
+ trace: dict[str, Any] | None = None
+
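# Illustrative sketch (not in the diff): the shape of the JSON a caller gets back,
# mirroring RolloutResponse / RolloutTrajectory / RolloutMetrics above (values made up).
example_response = {
    "run_id": "run-demo-001",
    "trajectories": [
        {
            "env_id": "env-123",
            "policy_id": "pol-456",
            "steps": [{"obs": {"text": "..."}, "tool_calls": [], "reward": 1.0, "done": False}],
            "length": 1,
            "decision_samples": [{"decision_index": 0, "indicator": 1, "r_i": 0.5, "actions": []}],
        }
    ],
    "branches": {},
    "metrics": {"episode_returns": [3.0], "mean_return": 3.0, "num_steps": 10, "num_episodes": 1},
    "aborted": False,
    "ops_executed": 20,
    "trace": None,
}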
+
+ class RolloutTracingContext:
+ """Helper managing tracing_v3 recording and optional SFT dumps for a rollout."""
+
+ def __init__(
+ self,
+ tracer: SessionTracer | None,
+ request: RolloutRequest,
+ fastapi_request: Request,
+ ) -> None:
+ self.tracer = tracer
+ self.enabled = tracer is not None
+ self.request = request
+ self.fastapi_request = fastapi_request
+ self.run_id = request.run_id
+ self.current_step_id: str | None = None
+ self.current_turn: int | None = None
+ self.lm_calls_summary: list[dict[str, Any]] = []
+ self.decision_rewards: list[dict[str, Any]] = []
+ self.sft_records: list[dict[str, Any]] = []
+ self.latest_system_messages: list[str] = []
+ self.latest_user_messages: list[str] = []
+ self.latest_system_prompt_content: list[Any] = []
+ self.latest_user_prompt_content: list[Any] = []
+ self.trace_format = (
+ getattr(request.record, "trace_format", "compact") or "compact"
+ ).lower()
+ self.return_trace = bool(getattr(request.record, "return_trace", False))
+ self.sft_output_dir = getattr(fastapi_request.app.state, "sft_output_dir", None)
+ self.session_trace = None
+ self.metadata_updates: dict[str, Any] = {}
+ self.policy_name = request.policy.policy_name or ""
+ self.env_name = request.env.env_name or ""
+ self.metadata_base: dict[str, Any] = {
+ "run_id": self.run_id,
+ "policy_name": self.policy_name,
+ "policy_id": request.policy.policy_id,
+ "env_name": self.env_name,
+ "env_id": request.env.env_id,
+ "seed": request.env.seed,
+ "training_session_id": request.training_session_id,
+ "synth_base_url": request.synth_base_url,
+ }
+
+ # Expose context for downstream calls inside this request lifecycle
+ fastapi_request.state.rollout_tracing = self
+ fastapi_request.state.rollout_run_id = self.run_id
+
+ async def start_session(self) -> None:
+ if not self.enabled or self.tracer is None:
+ return
+ try:
+ await self.tracer.initialize()
+ except Exception as exc:
+ logger.debug("TRACING_INIT_FAIL: %s", exc)
+ try:
+ await self.tracer.start_session(
+ session_id=self.run_id, metadata=dict(self.metadata_base)
+ )
+ except Exception as exc:
+ logger.warning("TRACING_START_FAIL: %s", exc)
+ self.enabled = False
+ self.tracer = None
+
+ async def start_decision(self, turn_number: int) -> None:
+ self.current_turn = turn_number
+ self.current_step_id = f"decision_{turn_number}"
+ if not self.enabled or self.tracer is None:
+ return
+ try:
+ await self.tracer.start_timestep(step_id=self.current_step_id, turn_number=turn_number)
+ except Exception as exc:
+ logger.debug("TRACING_STEP_START_FAIL: %s", exc)
+
+ async def end_decision(self) -> None:
+ if not self.enabled or self.tracer is None:
+ return
+ try:
+ await self.tracer.end_timestep(step_id=self.current_step_id)
+ except Exception as exc:
+ logger.debug("TRACING_STEP_END_FAIL: %s", exc)
+ finally:
+ self.current_step_id = None
+
+ def _message_metadata(self) -> dict[str, Any]:
+ return {
+ "turn": self.current_turn,
+ "step_id": self.current_step_id,
+ }
+
+ async def record_policy_prompts(
+ self,
+ system_messages: list[Any],
+ user_messages: list[Any],
+ ) -> None:
+ self.latest_system_messages = [self._prompt_text(entry) for entry in system_messages]
+ self.latest_user_messages = [self._prompt_text(entry) for entry in user_messages]
+ self.latest_system_prompt_content = [
+ self._prompt_content(entry, role="system") for entry in system_messages
+ ]
+ self.latest_user_prompt_content = [
+ self._prompt_content(entry, role="user") for entry in user_messages
+ ]
+ if not self.enabled or self.tracer is None:
+ return
+ for entry in system_messages:
+ try:
+ await self.tracer.record_message(
+ content=self._prompt_payload(entry, role="system"),
+ message_type="policy_system_prompt",
+ metadata=self._message_metadata(),
+ )
+ except Exception as exc:
+ logger.debug("TRACING_SYSTEM_MSG_FAIL: %s", exc)
+ for entry in user_messages:
+ try:
+ await self.tracer.record_message(
+ content=self._prompt_payload(entry, role="user"),
+ message_type="policy_user_prompt",
+ metadata=self._message_metadata(),
+ )
+ except Exception as exc:
+ logger.debug("TRACING_USER_MSG_FAIL: %s", exc)
+
+ def _content_to_text(self, content: Any) -> str:
+ if isinstance(content, str):
+ return content
+ if isinstance(content, list):
+ parts: list[str] = []
+ for seg in content:
+ if isinstance(seg, dict):
+ text_val = seg.get("text") or seg.get("content")
+ if isinstance(text_val, str):
+ parts.append(text_val)
+ return "".join(parts)
+ if content is None:
+ return ""
+ return str(content)
+
+ def _prompt_text(self, entry: Any) -> str:
+ if isinstance(entry, dict):
+ text = entry.get("text")
+ if isinstance(text, str):
+ return text
+ content = entry.get("content")
+ return self._content_to_text(content)
+ return self._content_to_text(entry)
+
+ def _prompt_payload(self, entry: Any, *, role: str) -> dict[str, Any]:
+ if isinstance(entry, dict):
+ payload = dict(entry)
+ payload.setdefault("role", role)
+ return payload
+ return {
+ "role": role,
+ "text": self._prompt_text(entry),
+ "content": entry,
+ }
+
+ def _prompt_content(self, entry: Any, *, role: str) -> Any:
+ payload = self._prompt_payload(entry, role=role)
+ return payload.get("content", payload.get("text"))
+
+ def _content_has_image(self, content: Any) -> bool:
+ if isinstance(content, list):
+ return any(
+ isinstance(seg, dict)
+ and seg.get("type") in {"image", "image_url"}
+ for seg in content
+ )
+ if isinstance(content, dict):
+ if content.get("type") in {"image", "image_url"}:
+ return True
+ inner = content.get("content")
+ if isinstance(inner, list):
+ return any(
+ isinstance(seg, dict)
+ and seg.get("type") in {"image", "image_url"}
+ for seg in inner
+ )
+ return False
+
+ def _safe_json(self, payload: Any, limit: int = 4000) -> str:
+ try:
+ text = json.dumps(payload, ensure_ascii=False)
+ except Exception:
+ text = str(payload)
+ if len(text) > limit:
+ return text[:limit] + "…"
+ return text
+
+ async def record_tool_invocation(self, tool_calls: list[dict[str, Any]] | None) -> None:
+ if tool_calls is None:
+ return
+ if self.enabled and self.tracer is not None:
+ try:
+ await self.tracer.record_message(
+ content=self._safe_json(tool_calls),
+ message_type="policy_tool_call",
+ metadata=self._message_metadata(),
+ )
+ except Exception as exc:
+ logger.debug("TRACING_TOOL_MSG_FAIL: %s", exc)
+
+ async def _record_event(self, event: Any) -> int | None:
+ if not self.enabled or self.tracer is None:
+ return None
+ try:
+ return await self.tracer.record_event(event)
+ except Exception as exc:
+ logger.debug("TRACING_EVENT_FAIL: %s", exc)
+ return None
+
+ async def record_llm_call(
+ self,
+ *,
+ inference_request: dict[str, Any],
+ inference_response: dict[str, Any],
+ tool_calls: list[dict[str, Any]] | None,
+ provider: str,
+ model_name: str,
+ started_at: datetime,
+ completed_at: datetime,
+ latency_ms: int | None,
+ ) -> None:
+ usage = inference_response.get("usage") or {}
+ input_tokens = usage.get("input_tokens") or usage.get("prompt_tokens")
+ output_tokens = usage.get("output_tokens") or usage.get("completion_tokens")
+ total_tokens = usage.get("total_tokens")
+ cost_usd = usage.get("cost_usd") or usage.get("cost") or usage.get("total_cost")
+
+ assistant_message = None
+ choices = inference_response.get("choices") or []
+ if choices:
+ assistant_message = choices[0].get("message") or {}
+ assistant_content = (
+ assistant_message.get("content") if isinstance(assistant_message, dict) else None
+ )
+
+ raw_response = self._content_to_text(assistant_content)
+ if not raw_response:
+ raw_response = self._safe_json(inference_response, limit=2000)
+
+ base_response = BaseLMResponse(
+ raw_response=raw_response,
+ tool_calls=assistant_message.get("tool_calls")
+ if isinstance(assistant_message, dict)
+ else None,
+ usage=usage or None,
+ api_type="chat_completions",
+ )
+
+ request_messages = inference_request.get("messages") or []
+ try:
+ temperature = float(inference_request.get("temperature"))
+ except Exception:
+ temperature = 0.0
+
+ call_record = create_llm_call_record_from_response(
+ response=base_response,
+ model_name=model_name,
+ provider=provider,
+ messages=request_messages,
+ temperature=temperature,
+ request_params=inference_request,
+ tools=inference_request.get("tools"),
+ started_at=started_at,
+ completed_at=completed_at,
+ latency_ms=latency_ms,
+ )
+
+ event_metadata = {
+ "policy_id": self.request.policy.policy_id,
+ "turn": self.current_turn,
+ "run_id": self.run_id,
+ }
+
+ event = LMCAISEvent(
+ system_instance_id=f"policy:{self.policy_name or 'unknown'}",
+ time_record=TimeRecord(event_time=completed_at.timestamp()),
+ model_name=model_name,
+ provider=provider,
+ input_tokens=input_tokens,
+ output_tokens=output_tokens,
+ total_tokens=total_tokens,
+ cost_usd=cost_usd,
+ latency_ms=latency_ms,
+ call_records=[call_record],
+ metadata=event_metadata,
+ )
+
+ await self._record_event(event)
+
+ self.lm_calls_summary.append(
+ {
+ "turn": self.current_turn,
+ "model": model_name,
+ "provider": provider,
+ "total_tokens": total_tokens,
+ "input_tokens": input_tokens,
+ "output_tokens": output_tokens,
+ "latency_ms": latency_ms,
+ "tool_calls": len(tool_calls or []),
+ }
+ )
+
+ if self.sft_output_dir is not None:
+ assistant_structured = assistant_content if assistant_content is not None else ""
+ assistant_text = self._content_to_text(assistant_content)
+ dialogue_structured: list[dict[str, Any]] = []
+ for content in self.latest_system_prompt_content:
+ if content is None:
+ continue
+ dialogue_structured.append({"role": "system", "content": content})
+ for content in self.latest_user_prompt_content:
+ if content is None:
+ continue
+ dialogue_structured.append({"role": "user", "content": content})
+ dialogue_text = (
+ [{"role": "system", "content": s} for s in self.latest_system_messages]
+ + [{"role": "user", "content": u} for u in self.latest_user_messages]
+ )
+ user_has_image = any(
+ self._content_has_image(content) for content in self.latest_user_prompt_content
+ )
+ assistant_has_image = self._content_has_image(assistant_structured)
+ record = {
+ "run_id": self.run_id,
+ "turn": self.current_turn,
+ "model": model_name,
+ "provider": provider,
+ "dialogue": dialogue_structured,
+ "dialogue_text": dialogue_text,
+ "assistant": {
+ "content": assistant_structured,
+ "content_text": assistant_text,
+ "tool_calls": assistant_message.get("tool_calls")
+ if isinstance(assistant_message, dict)
+ else [],
+ "has_image": assistant_has_image,
+ },
+ "metadata": {
+ "user_has_image": user_has_image,
+ "assistant_has_image": assistant_has_image,
+ "has_image": user_has_image or assistant_has_image,
+ },
+ "timestamp": datetime.utcnow().isoformat(),
+ }
+ self.sft_records.append(record)
+
+ async def record_environment_event(
+ self,
+ *,
+ env_handle: Any,
+ prev_obs: dict[str, Any] | None,
+ env_response: Any,
+ next_obs: dict[str, Any] | None,
+ metadata: dict[str, Any] | None = None,
+ ) -> int | None:
+ if not self.enabled or self.tracer is None:
+ return None
+
+ try:
+ prev_summary = (
+ _summarize_observation_for_storage(env_handle, prev_obs or {})
+ if prev_obs is not None
+ else None
+ )
+ except Exception:
+ prev_summary = None
+ try:
+ next_summary = (
+ _summarize_observation_for_storage(env_handle, next_obs or {})
+ if next_obs is not None
+ else None
+ )
+ except Exception:
+ next_summary = None
+
+ reward_val = getattr(env_response, "reward", None)
+ try:
+ reward_float = float(reward_val) if reward_val is not None else 0.0
+ except Exception:
+ reward_float = 0.0
+
+ event = EnvironmentEvent(
+ system_instance_id=f"environment:{self.env_name or 'unknown'}",
+ time_record=TimeRecord(event_time=datetime.utcnow().timestamp()),
+ reward=reward_float,
+ terminated=bool(getattr(env_response, "done", False)),
+ truncated=bool(getattr(env_response, "truncated", False)),
+ system_state_before=prev_summary,
+ system_state_after=next_summary,
+ metadata={
+ "turn": self.current_turn,
+ "run_id": self.run_id,
+ **(metadata or {}),
+ },
+ )
+
+ return await self._record_event(event)
+
+ async def record_decision_reward(
+ self,
+ *,
+ event_id: int | None,
+ decision_meta: dict[str, Any] | None,
+ ) -> None:
+ decision_meta = decision_meta or {}
+ ach_delta = int(decision_meta.get("ach_delta", 0))
+ unique_delta = int(decision_meta.get("unique_delta", 0))
+ all_ach = list(decision_meta.get("all") or [])
+ unique_ach = list(decision_meta.get("unique") or [])
+
+ self.decision_rewards.append(
+ {
+ "turn": self.current_turn,
+ "ach_delta": ach_delta,
+ "unique_delta": unique_delta,
+ "achievements": all_ach,
+ "unique_achievements": unique_ach,
+ }
+ )
+
+ if not self.enabled or self.tracer is None or event_id is None:
+ return
+ try:
+ await self.tracer.record_event_reward(
+ event_id=event_id,
+ turn_number=self.current_turn,
+ reward_value=float(ach_delta),
+ reward_type="achievement_delta",
+ annotation={"achievements": all_ach},
+ source="environment",
+ )
+ if unique_delta:
+ await self.tracer.record_event_reward(
+ event_id=event_id,
+ turn_number=self.current_turn,
+ reward_value=float(unique_delta),
+ reward_type="unique_achievement_delta",
+ annotation={"achievements": unique_ach},
+ source="environment",
+ )
+ except Exception as exc:
+ logger.debug("TRACING_REWARD_FAIL: %s", exc)
+
+ def update_metadata(self, **kwargs: Any) -> None:
+ self.metadata_updates.update({k: v for k, v in kwargs.items() if v is not None})
+
+ async def finalize(
+ self,
+ *,
+ total_reward: float,
+ achievement_state: dict[str, bool] | None,
+ total_steps: int,
+ ) -> Any:
+ final_achievements = [key for key, val in (achievement_state or {}).items() if val]
+ self.metadata_updates.setdefault("final_achievements", final_achievements)
+ if self.enabled and self.tracer is not None:
+ try:
+ await self.tracer.record_outcome_reward(
+ total_reward=int(total_reward),
+ achievements_count=len(final_achievements),
+ total_steps=int(total_steps),
+ reward_metadata=dict(self.metadata_updates),
+ )
+ except Exception as exc:
+ logger.debug("TRACING_OUTCOME_FAIL: %s", exc)
+ try:
+ self.session_trace = await self.tracer.end_session()
+ if self.session_trace is not None:
+ self.session_trace.metadata.update(self.metadata_updates)
+ except Exception as exc:
+ logger.debug("TRACING_END_SESSION_FAIL: %s", exc)
+ self.session_trace = None
+ with contextlib.suppress(Exception):
+ await self.tracer.close()
+
+ if self.sft_records and self.sft_output_dir:
+ self.write_sft_records()
+
+ # Clear context from request state to avoid leaks
+ self.fastapi_request.state.rollout_tracing = None
+
+ return self.session_trace
+
+ def write_sft_records(self) -> None:
+ if not self.sft_output_dir or not self.sft_records:
+ return
+ try:
+ path = unique_sft_path(self.sft_output_dir, run_id=self.run_id)
+ path.parent.mkdir(parents=True, exist_ok=True)
+ with path.open("w", encoding="utf-8") as fh:
+ for record in self.sft_records:
+ json.dump(record, fh, ensure_ascii=False)
+ fh.write("\n")
+ logger.info(f"SFT_WRITTEN: {path}")
+ except Exception as exc:
+ logger.warning(f"SFT_WRITE_FAIL: {exc}")
+ finally:
+ self.sft_records.clear()
+
+ def build_trace_payload(self, session_trace: Any) -> dict[str, Any] | None:
+ if not self.return_trace or session_trace is None:
+ return None
+ if self.trace_format == "full":
+ payload = session_trace.to_dict()
+ payload.setdefault("metadata", {}).update(self.metadata_updates)
+ return payload
+ metadata = dict(session_trace.metadata)
+ metadata.update(self.metadata_updates)
+ return {
+ "session_id": session_trace.session_id,
+ "created_at": session_trace.created_at.isoformat(),
+ "metadata": metadata,
+ "events_count": len(session_trace.event_history),
+ "messages_count": len(session_trace.markov_blanket_message_history),
+ "lm_calls": self.lm_calls_summary,
+ "decision_rewards": self.decision_rewards,
+ }
+
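# Illustrative sketch (not in the diff): the per-decision order RolloutTracingContext's
# methods are designed to be driven in. In the real service some calls happen from the
# policy/inference layer, which picks the context up via request.state.rollout_tracing;
# argument values here are placeholders.
async def _example_decision_flow(
    ctx: "RolloutTracingContext",
    env_handle: Any,
    prev_obs: dict[str, Any],
    env_response: Any,
    next_obs: dict[str, Any],
    inference_request: dict[str, Any],
    inference_response: dict[str, Any],
) -> None:
    await ctx.start_session()
    await ctx.start_decision(turn_number=0)
    await ctx.record_policy_prompts(system_messages=[], user_messages=[])
    await ctx.record_llm_call(
        inference_request=inference_request,
        inference_response=inference_response,
        tool_calls=None,
        provider="openai",
        model_name="example-model",
        started_at=datetime.utcnow(),
        completed_at=datetime.utcnow(),
        latency_ms=0,
    )
    event_id = await ctx.record_environment_event(
        env_handle=env_handle, prev_obs=prev_obs, env_response=env_response, next_obs=next_obs
    )
    await ctx.record_decision_reward(event_id=event_id, decision_meta={"ach_delta": 1})
    await ctx.end_decision()
    await ctx.finalize(total_reward=1.0, achievement_state={}, total_steps=1)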
+
+ def _summarize_observation_for_storage(
+ env_handle: Any, observation: dict[str, Any]
+ ) -> dict[str, Any]:
+ """Return a compact dict for trajectory storage instead of the raw observation.
+
+ - For Crafter, use the same summary used for the policy user prompt
+ - For others, keep a minimal subset or plain text preview
+ """
+ # Try Crafter-specific formatter
+ crafter_wrapper = None
+ with contextlib.suppress(Exception):
+ from .envs.crafter.environment import (
+ CrafterEnvironmentWrapper as _CrafterWrapper, # type: ignore
+ )
+
+ crafter_wrapper = _CrafterWrapper # type: ignore[assignment]
+
+ if crafter_wrapper is not None and isinstance(
+ getattr(env_handle, "env", None), crafter_wrapper
+ ):
+ with contextlib.suppress(Exception):
+ from .envs.crafter.shared import format_observation as _fmt # type: ignore
+
+ text = _fmt(observation or {})
+ return {"text": text}
+
+ # Generic fallback: extract a few small fields if present; avoid huge arrays
+ with contextlib.suppress(Exception):
+ inv = observation.get("inventory") if isinstance(observation, dict) else None
+ ach = observation.get("achievements_status") if isinstance(observation, dict) else None
+ pos = observation.get("player_position") if isinstance(observation, dict) else None
+ health = None
+ if isinstance(inv, dict):
+ health = inv.get("health")
+ summary = {
+ "position": pos,
+ "health": health,
+ "inventory_keys": sorted(k for k, v in (inv or {}).items() if v)[:10]
+ if isinstance(inv, dict)
+ else None,
+ "achievements_unlocked": sorted(k for k, v in (ach or {}).items() if v)[:10]
+ if isinstance(ach, dict)
+ else None,
+ }
+ return {"text": json.dumps(summary, ensure_ascii=False)}
+
+ # Last resort: plain string preview
+ try:
+ return {"text": str(observation)[:10000]}
+ except Exception:
+ return {"text": ""}
+
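# Illustrative sketch (not in the diff): what the generic (non-Crafter) fallback in
# _summarize_observation_for_storage produces for a small observation. env_handle=None
# is enough here because only Crafter handles take the wrapper-specific branch.
obs = {
    "inventory": {"health": 9, "wood": 2, "stone": 0},
    "achievements_status": {"collect_wood": True, "place_table": False},
    "player_position": [32, 40],
}
print(_summarize_observation_for_storage(None, obs))
# -> {'text': '{"position": [32, 40], "health": 9, "inventory_keys": ["health", "wood"], ...}'}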
772
+
773
+ class RunAbortRequest(BaseModel):
774
+ run_id: str
775
+
776
+
777
+ class RunAbortResponse(BaseModel):
778
+ ok: bool
779
+ run_id: str
780
+
781
+
782
+ class RunStatusResponse(BaseModel):
783
+ run_id: str
784
+ status: str
785
+ started_at: datetime
786
+ finished_at: datetime | None = None
787
+
788
+
789
+ @router.post("/rollout", response_model=RolloutResponse)
790
+ async def execute_rollout(
791
+ request: RolloutRequest,
792
+ req: Request,
793
+ ) -> RolloutResponse:
794
+ """Execute a rollout with coordinated environment and policy steps."""
795
+ # Emit rollout identifier early for correlation
796
+ with contextlib.suppress(Exception):
797
+ _rid = getattr(request, "run_id", None)
798
+ _pol = getattr(request.policy, "policy_name", None) or getattr(request.policy, "policy_id", None)
799
+ _env = getattr(request.env, "env_name", None) or getattr(request.env, "env_id", None)
800
+ logger.info("ROLLOUT_BEGIN: run_id=%s policy=%s env=%s", _rid, _pol, _env)
801
+ print(f"[rollout] begin run_id={_rid} policy={_pol} env={_env}", flush=True)
802
+ # Enforce per-episode step cap via env-specific parameters; default to 20 if omitted
803
+ try:
804
+ _env_params = {}
805
+ if isinstance(request.env, RolloutEnvSpec) and isinstance(request.env.config, dict):
806
+ _env_params = dict(request.env.config.get("env_params") or {})
807
+ max_steps_per_episode = int(_env_params.get("max_steps_per_episode") or 20)
808
+ assert max_steps_per_episode > 0, "max_steps_per_episode must be a positive integer"
809
+ except Exception as _mse:
810
+ raise HTTPException(
811
+ status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
812
+ detail={
813
+ "error": "invalid_env_params",
814
+ "message": f"Invalid or missing env_params.max_steps_per_episode: {_mse}",
815
+ },
816
+ ) from _mse
817
+ # Truncate incoming ops to the enforced cap (each step is [agent, env])
818
+ ops_seq: list[str] = list(request.ops or [])
819
+ allowed_ops = max(0, int(max_steps_per_episode) * 2)
820
+ if len(ops_seq) > allowed_ops:
821
+ with contextlib.suppress(Exception):
822
+ logger.info(
823
+ "ROLL_OUT: truncating ops to cap: requested_ops=%s allowed_ops=%s",
824
+ str(len(ops_seq)),
825
+ str(allowed_ops),
826
+ )
827
+ ops_seq = ops_seq[:allowed_ops]
828
+ # Simple API key auth for inbound rollout
829
+ header_key = req.headers.get("x-api-key")
830
+ env_key = os.getenv("ENVIRONMENT_API_KEY")
831
+ dev_key = os.getenv("DEV_ENVIRONMENT_API_KEY")
832
+ # Accept either ENVIRONMENT_API_KEY or DEV_ENVIRONMENT_API_KEY
833
+ expected_keys = [k for k in (env_key, dev_key) if k]
834
+ if not expected_keys:
835
+ missing = []
836
+ if not env_key:
837
+ missing.append("ENVIRONMENT_API_KEY")
838
+ if not dev_key:
839
+ missing.append("DEV_ENVIRONMENT_API_KEY")
840
+ msg = f"Auth not configured: missing {', '.join(missing)} in task service environment"
841
+ logger.error(msg)
842
+ raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=msg)
843
+ if not header_key:
844
+ raise HTTPException(
845
+ status_code=status.HTTP_401_UNAUTHORIZED,
846
+ detail="Invalid or missing API key: X-API-Key header not provided",
847
+ )
848
+ if header_key not in expected_keys:
849
+ # Do not leak secrets; include short prefix for diagnostics
850
+ exp_src = env_key if env_key else (dev_key or "")
851
+ exp_prefix = (exp_src[:7] + "…") if len(exp_src) >= 7 else "set"
852
+ got_prefix = (header_key[:7] + "…") if len(header_key) >= 7 else "set"
853
+ raise HTTPException(
854
+ status_code=status.HTTP_401_UNAUTHORIZED,
855
+ detail=f"Invalid API key: header does not match expected (got={got_prefix}, expected_prefix={exp_prefix})",
856
+ )
857
+
858
+ # Log contextual fields for traceability
859
+ if request.training_session_id:
860
+ logger.info(f"ROLL_OUT: training_session_id={request.training_session_id}")
861
+ if request.synth_base_url:
862
+ logger.info(f"ROLL_OUT: synth_base_url={request.synth_base_url}")
863
+
864
+ # Log masked OpenAI API key presence for diagnostics
865
+ with contextlib.suppress(Exception):
866
+ _oa = os.getenv("OPENAI_API_KEY")
867
+ if _oa:
868
+ _pref = (_oa[:6] + "…") if len(_oa) >= 6 else "set"
869
+ logger.info(f"ROLL_OUT: OPENAI_API_KEY present (prefix={_pref})")
870
+ else:
871
+ logger.warning("ROLL_OUT: OPENAI_API_KEY missing")
872
+
873
+ # Make synth_base_url available for outbound calls in this app
874
+ with contextlib.suppress(Exception):
875
+ task_app = req.app.state.task_app
876
+ if request.synth_base_url:
877
+ task_app.synth_base_url = request.synth_base_url
878
+
879
+ tracer_factory = getattr(req.app.state, "session_tracer_factory", None)
880
+ tracer_instance: SessionTracer | None = None
881
+ if callable(tracer_factory):
882
+ try:
883
+ inst = tracer_factory()
884
+ tracer_instance = inst if isinstance(inst, SessionTracer) else None
885
+ except Exception as exc:
886
+ logger.debug(f"TRACER_FACTORY_FAIL: {exc}")
887
+ tracing_context = RolloutTracingContext(tracer_instance, request, req)
888
+ await tracing_context.start_session()
889
+
890
+ # Register run
891
+ registry.register_run(request.run_id)
892
+
893
+ # Track resources created during this rollout so we can guarantee cleanup
894
+ created_env_id: str | None = None
895
+ created_policy_id: str | None = None
896
+ env_seed_used: int | None = None
897
+ trajectory_steps: list[RolloutStep] = []
898
+ decision_samples: list[dict[str, Any]] = []
899
+ pending_tool_calls: Any = None
900
+ current_obs: Any = {}
901
+ total_reward: float = 0.0
902
+ ops_executed = 0
903
+ last_agent_response_ts: float | None = None
904
+ last_policy_meta: dict[str, Any] | None = None
905
+ last_env_step_ms: float | None = None
906
+ last_env_step_completed_ts: float | None = None
907
+ decision_open = False
908
+ finalized = False
909
+ prev_achievements: dict[str, bool] = {}
910
+ session_trace = None
911
+ step_rewards_active = False
912
+
913
+ try:
914
+ # Initialize deterministic seed early for the entire rollout
915
+ seed_value: int | None = None
916
+ try:
917
+ if request.env and request.env.seed is not None:
918
+ seed_value = int(request.env.seed)
919
+ else:
920
+ # Derive a stable seed from run_id
921
+ import hashlib as _hashlib # local import to avoid global deps
922
+
923
+ _digest = _hashlib.sha256(request.run_id.encode("utf-8")).hexdigest()
924
+ # Use lower 32 bits to fit common RNG ranges
925
+ seed_value = int(_digest[:8], 16)
926
+ except Exception:
927
+ # Fallback to time-based seed if anything goes wrong
928
+ try:
929
+ seed_value = int((_time.time_ns() // 1_000_000) % (2**31 - 1))
930
+ except Exception:
931
+ seed_value = 42
932
+
933
+ _seed_info = _set_global_seed(int(seed_value))
934
+ with contextlib.suppress(Exception):
935
+ logger.info(
936
+ "ROLL_OUT: RNG seeded seed=%s libs=%s",
937
+ str(_seed_info.get("seed")),
938
+ ",".join(_seed_info.get("libs", [])),
939
+ )
940
+ # Resolve or create environment
941
+ if request.env.env_id:
942
+ env_handle = registry.get_env(request.env.env_id)
943
+ if not env_handle:
944
+ raise HTTPException(
945
+ status_code=404,
946
+ detail=f"Environment {request.env.env_id} not found",
947
+ )
948
+ env_id = request.env.env_id
949
+ else:
950
+ # Create new environment
951
+ from .environment_routes import EnvCreateRequest, create_environment
952
+
953
+ if not request.env.env_name:
954
+ raise ValueError("FATAL: env_name is required - NO FALLBACKS!")
955
+
956
+ # Propagate training_session_id via env config for downstream usage
957
+ _env_config = dict(request.env.config or {})
958
+ if request.training_session_id is not None:
959
+ _env_config.setdefault("training_session_id", request.training_session_id)
960
+ env_response = await create_environment(
961
+ EnvCreateRequest(
962
+ env_name=request.env.env_name,
963
+ config=_env_config,
964
+ seed=request.env.seed,
965
+ rl_run_id=request.run_id,
966
+ )
967
+ )
968
+ env_id = env_response.env_id
969
+ env_handle = registry.get_env(env_id)
970
+ created_env_id = env_id
971
+
972
+ tracing_context.update_metadata(env_id=env_id)
973
+
974
+ # Resolve or create policy
975
+ if request.policy.policy_id:
976
+ policy_handle = registry.get_policy(request.policy.policy_id)
977
+ if not policy_handle:
978
+ raise HTTPException(
979
+ status_code=404,
980
+ detail=f"Policy {request.policy.policy_id} not found",
981
+ )
982
+ policy_id = request.policy.policy_id
983
+ else:
984
+ # Create new policy
985
+ from .policy_routes import PolicyCreateRequest, create_policy
986
+
987
+ if not request.policy.policy_name:
988
+ raise ValueError("FATAL: policy_name is required - NO FALLBACKS!")
989
+
990
+ # Propagate training_session_id and synth_base_url via policy config
991
+ _policy_config = dict(request.policy.config or {})
992
+ if request.training_session_id is not None:
993
+ _policy_config.setdefault("training_session_id", request.training_session_id)
994
+ if request.synth_base_url is not None:
995
+ _policy_config.setdefault("synth_base_url", request.synth_base_url)
996
+ policy_response = await create_policy(
997
+ PolicyCreateRequest(
998
+ policy_name=request.policy.policy_name,
999
+ config=_policy_config,
1000
+ rl_run_id=request.run_id,
1001
+ bound_env_id=env_id,
1002
+ ),
1003
+ req,
1004
+ )
1005
+ policy_id = policy_response.policy_id
1006
+ policy_handle = registry.get_policy(policy_id)
1007
+ created_policy_id = policy_id
1008
+
1009
+ tracing_context.update_metadata(policy_id=policy_id)
1010
+
1011
+ # Bind policy to environment if not already bound
1012
+ if policy_handle and not policy_handle.bound_env_id:
1013
+ policy_handle.bound_env_id = env_id
1014
+
1015
+ # Record seed bound to environment for end-of-rollout verification/logging
1016
+ try:
1017
+ env_seed_used = int(getattr(env_handle, "seed", 0) or 0)
1018
+ except Exception:
1019
+ env_seed_used = None
1020
+ tracing_context.update_metadata(env_seed=env_seed_used)
1021
+ # Initialize trajectory
1022
+ trajectory_steps = []
1023
+ pending_tool_calls = None
1024
+ current_obs = env_handle.last_observation
1025
+ total_reward = 0.0
1026
+ ops_executed = 0
1027
+ last_agent_response_ts = None
1028
+ last_policy_meta = None
1029
+ last_env_step_ms = None
1030
+ last_env_step_completed_ts = None
1031
+
1032
+ # Stepwise reward configuration (Crafter shaping; gated on an explicit enable flag)
1033
+ step_rewards_cfg_raw: dict[str, Any] = {}
1034
+ try:
1035
+ if isinstance(request.policy.config, dict):
1036
+ step_rewards_cfg_raw = dict(request.policy.config.get("step_rewards") or {})
1037
+ except Exception:
1038
+ step_rewards_cfg_raw = {}
1039
+ if not step_rewards_cfg_raw:
1040
+ try:
1041
+ if isinstance(request.env.config, dict):
1042
+ step_rewards_cfg_raw = dict(request.env.config.get("step_rewards") or {})
1043
+ except Exception:
1044
+ step_rewards_cfg_raw = {}
1045
+
1046
+ step_rewards_enabled = bool(step_rewards_cfg_raw.get("enabled", False))
1047
+ step_rewards_mode = str(step_rewards_cfg_raw.get("mode") or "off").lower()
1048
+ try:
1049
+ step_rewards_indicator_lambda = float(
1050
+ step_rewards_cfg_raw.get("indicator_lambda") or 0.0
1051
+ )
1052
+ except Exception:
1053
+ step_rewards_indicator_lambda = 0.0
1054
+ try:
1055
+ step_rewards_beta = float(step_rewards_cfg_raw.get("step_beta") or 0.0)
1056
+ except Exception:
1057
+ step_rewards_beta = 0.0
1058
+ step_rewards_active = step_rewards_enabled and step_rewards_mode == "decision_stepwise"
1059
+
1060
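The parsing above accepts a `step_rewards` table from the policy config or, failing that, from the env config; shaping is only active when `enabled` is true and `mode` is exactly "decision_stepwise". A minimal example of a config dict that would activate it (values are illustrative):

    policy_config_sketch = {
        "step_rewards": {
            "enabled": True,
            "mode": "decision_stepwise",  # any other value leaves step_rewards_active False
            "indicator_lambda": 1.0,      # forwarded to compute_stepwise_reward below
            "step_beta": 0.0,             # parsed above; illustrative value
        }
    }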
+ def _extract_achievements(obs: Any) -> dict[str, bool]:
1061
+ if not isinstance(obs, dict):
1062
+ return {}
1063
+ ach = obs.get("achievements_status")
1064
+ if isinstance(ach, dict):
1065
+ return {str(k): bool(v) for k, v in ach.items()}
1066
+ return {}
1067
+
1068
+ def _summarize_tool_calls(tool_calls: Any) -> list[dict[str, Any]]:
1069
+ if not tool_calls:
1070
+ return []
1071
+ try:
1072
+ items = (
1073
+ tool_calls
1074
+ if isinstance(tool_calls, list)
1075
+ else list(tool_calls) # tolerates tuples or pydantic lists
1076
+ )
1077
+ except Exception:
1078
+ return []
1079
+ summary: list[dict[str, Any]] = []
1080
+ for tc in items:
1081
+ tool_name = None
1082
+ args: Any = {}
1083
+ if isinstance(tc, dict):
1084
+ tool_name = tc.get("tool") or tc.get("tool_name") or tc.get("name")
1085
+ raw_args = tc.get("arguments") or tc.get("args") or {}
1086
+ else:
1087
+ tool_name = getattr(tc, "tool", None) or getattr(tc, "tool_name", None)
1088
+ raw_args = getattr(tc, "arguments", None) or getattr(tc, "args", None) or {}
1089
+ args = raw_args
1090
+ if isinstance(raw_args, str):
1091
+ try:
1092
+ args = json.loads(raw_args)
1093
+ except Exception:
1094
+ args = raw_args
1095
+ summary.append({"tool": tool_name, "args": args})
1096
+ return summary
1097
+
1098
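As a quick illustration of the normalization `_summarize_tool_calls` performs (dict- or attribute-style entries accepted, JSON-string arguments decoded when possible; tool and argument names here are illustrative):

    # Dict-style tool call with stringified JSON arguments:
    example_calls = [{"tool_name": "interact", "arguments": '{"action": "move_up"}'}]
    # _summarize_tool_calls(example_calls) would yield:
    #   [{"tool": "interact", "args": {"action": "move_up"}}]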
+ decision_samples: list[dict[str, Any]] = []
1099
+ decision_index = 0
1100
+ decision_open = False
1101
+ session_trace = None
1102
+ finalized = False
1103
+ prev_achievements = _extract_achievements(current_obs)
1104
+ # Track which achievements have been seen as unlocked at any point in the episode so far
1105
+ episode_seen_achievements: set[str] = {
1106
+ k for k, v in (prev_achievements or {}).items() if bool(v)
1107
+ }
1108
+ stepwise_indicator_sum = 0.0
1109
+ stepwise_reward_sum = 0.0
1110
+ stepwise_new_achievements_total = 0
1111
+ final_achievement_count = sum(1 for v in prev_achievements.values() if v)
1112
+
1113
+ # Execute ops sequence (capped by env_params.max_steps_per_episode)
1114
+ for op_idx, op in enumerate(ops_seq):
1115
+ # Check for abort
1116
+ if registry.is_run_aborted(request.run_id):
1117
+ logger.info(f"Run {request.run_id} aborted at op {op_idx}")
1118
+ break
1119
+
1120
+ # Check safety limits
1121
+ if ops_executed >= request.safety.max_ops:
1122
+ logger.warning(f"Reached max_ops limit ({request.safety.max_ops})")
1123
+ break
1124
+
1125
+ if op == "agent":
1126
+ # Policy step
1127
+ from .policy_routes import PolicyStepRequest, step_policy
1128
+
1129
+ if not decision_open:
1130
+ await tracing_context.start_decision(decision_index)
1131
+ decision_open = True
1132
+
1133
+ agent_request_start = _time.perf_counter()
1134
+ if last_agent_response_ts is not None and last_policy_meta is not None:
1135
+ with contextlib.suppress(Exception):
1136
+ timing_prev = last_policy_meta.setdefault("timing", {})
1137
+ decision_ms = max(
1138
+ 0.0,
1139
+ (agent_request_start - float(last_agent_response_ts)) * 1000.0,
1140
+ )
1141
+ # Update timing on prior policy meta (kept by previous env step)
1142
+ timing_prev["decision_ms"] = decision_ms
1143
+ if last_env_step_ms is not None:
1144
+ timing_prev["env_step_ms"] = float(last_env_step_ms)
1145
+ timing_prev["overhead_ms"] = max(
1146
+ 0.0, decision_ms - float(last_env_step_ms)
1147
+ )
1148
+ else:
1149
+ timing_prev.setdefault("overhead_ms", 0.0)
1150
+ timing_prev["decision_ready_s"] = agent_request_start
1151
+ # Also backfill the last appended trajectory step so the trainer
1152
+ # can always see decision_ms without relying on shared dict refs.
1153
+ if trajectory_steps:
1154
+ with contextlib.suppress(Exception):
1155
+ _last = trajectory_steps[-1]
1156
+ _info = dict(_last.info or {})
1157
+ _meta = dict(_info.get("meta") or {})
1158
+ _timing = dict(_meta.get("timing") or {})
1159
+ _timing["decision_ms"] = decision_ms
1160
+ if last_env_step_ms is not None:
1161
+ _timing.setdefault("env_step_ms", float(last_env_step_ms))
1162
+ _timing.setdefault(
1163
+ "overhead_ms",
1164
+ max(0.0, decision_ms - float(last_env_step_ms)),
1165
+ )
1166
+ else:
1167
+ _timing.setdefault("overhead_ms", 0.0)
1168
+ _meta["timing"] = _timing
1169
+ _info["meta"] = _meta
1170
+ _last.info = _info
1171
+ last_env_step_ms = None
1172
+ last_env_step_completed_ts = None
1173
+
1174
+ # Build metadata for policy (carry previous tool_calls and env result)
1175
+ metadata = {}
1176
+ if pending_tool_calls:
1177
+ metadata["prev_tool_calls"] = pending_tool_calls
1178
+ if len(trajectory_steps) > 0:
1179
+ last_step = trajectory_steps[-1]
1180
+ # Prefer the last executed tool calls to seed history
1181
+ if last_step.tool_calls:
1182
+ metadata["prev_tool_calls"] = last_step.tool_calls
1183
+ # Provide a compact env result snapshot
1184
+ metadata["prev_env_result"] = {
1185
+ "observation": last_step.obs,
1186
+ "reward": last_step.reward,
1187
+ "done": last_step.done,
1188
+ "truncated": last_step.truncated,
1189
+ "info": last_step.info,
1190
+ }
1191
+
1192
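The metadata assembled above is what threads conversation history into the next policy step. Once at least one env step has completed it looks roughly like this (keys follow the assignments above; values are placeholders):

    metadata_sketch = {
        "prev_tool_calls": [{"tool": "interact", "args": {"action": "move_up"}}],
        "prev_env_result": {
            "observation": {},   # last_step.obs
            "reward": 0.0,       # last_step.reward
            "done": False,       # last_step.done
            "truncated": False,  # last_step.truncated
            "info": {},          # last_step.info (carries the prior policy meta/timing)
        },
    }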
+ # Log compact metadata summary to confirm history threading
1193
+ with contextlib.suppress(Exception):
1194
+ _prev_calls = metadata.get("prev_tool_calls")
1195
+ _count = len(_prev_calls) if isinstance(_prev_calls, list) else 0
1196
+ _first_guess = None
1197
+ if _count > 0 and isinstance(_prev_calls[0], dict):
1198
+ _args = _prev_calls[0].get("arguments", None)
1199
+ if isinstance(_args, str):
1200
+ import json as _json
1201
+ with contextlib.suppress(Exception):
1202
+ _args = _json.loads(_args)
1203
+ if not isinstance(_args, dict):
1204
+ _args = {}
1205
+ _first_guess = _args.get("guess") or _args.get("word")
1206
+ logger.info(
1207
+ "POLICY_METADATA: prev_tool_calls=%d first_guess=%r has_prev_env_result=%s",
1208
+ _count,
1209
+ _first_guess,
1210
+ str("prev_env_result" in metadata),
1211
+ )
1212
+
1213
+ try:
1214
+ policy_response = await step_policy(
1215
+ PolicyStepRequest(
1216
+ policy_id=policy_id,
1217
+ observation=current_obs,
1218
+ metadata=metadata,
1219
+ ),
1220
+ req,
1221
+ )
1222
+ except Exception as _pe:
1223
+ # Do not 500 the rollout; finalize with partial trajectory
1224
+ with contextlib.suppress(Exception):
1225
+ logger.warning(
1226
+ "POLICY_STEP_FAIL: terminating episode early run_id=%s op_idx=%s err=%s",
1227
+ request.run_id,
1228
+ str(op_idx),
1229
+ str(_pe),
1230
+ )
1231
+
1232
+ # Build partial trajectory and return HTTP 200
1233
+ trajectory = RolloutTrajectory(
1234
+ env_id=env_id,
1235
+ policy_id=policy_id,
1236
+ steps=trajectory_steps,
1237
+ final={
1238
+ "observation": current_obs,
1239
+ "rollout_status": "partial_policy_error",
1240
+ "error": str(_pe),
1241
+ "at_op": op,
1242
+ },
1243
+ length=len(trajectory_steps),
1244
+ decision_samples=decision_samples if step_rewards_active else None,
1245
+ )
1246
+ metrics = RolloutMetrics(
1247
+ episode_returns=[total_reward],
1248
+ mean_return=total_reward,
1249
+ num_steps=len(trajectory_steps),
1250
+ num_episodes=1,
1251
+ )
1252
+ aborted = registry.is_run_aborted(request.run_id)
1253
+ if not aborted:
1254
+ registry.complete_run(request.run_id)
1255
+ if decision_open:
1256
+ await tracing_context.end_decision()
1257
+ decision_open = False
1258
+ if not finalized:
1259
+ session_trace = await tracing_context.finalize(
1260
+ total_reward=total_reward,
1261
+ achievement_state=prev_achievements,
1262
+ total_steps=len(trajectory_steps),
1263
+ )
1264
+ finalized = True
1265
+ trace_payload = tracing_context.build_trace_payload(session_trace)
1266
+ return RolloutResponse(
1267
+ run_id=request.run_id,
1268
+ trajectories=[trajectory],
1269
+ branches={},
1270
+ metrics=metrics,
1271
+ aborted=aborted,
1272
+ ops_executed=ops_executed,
1273
+ trace=trace_payload,
1274
+ )
1275
+
1276
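Note that a policy failure is surfaced as a successful HTTP response whose trajectory `final` carries `rollout_status = "partial_policy_error"` rather than as a 5xx; the same pattern appears below for "partial_no_tool_calls" and "partial_invalid_action". A hedged sketch of how a caller might separate these partial outcomes from a complete rollout:

    # Assumes `response` is the parsed JSON body returned by this endpoint.
    def is_partial_rollout(response: dict) -> bool:
        trajectories = response.get("trajectories") or []
        if not trajectories:
            return True
        final = trajectories[0].get("final") or {}
        status = str(final.get("rollout_status") or "")
        return status.startswith("partial_")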
+ agent_response_ts = _time.perf_counter()
1277
+ if isinstance(policy_response.meta, dict):
1278
+ with contextlib.suppress(Exception):
1279
+ timing_cur = policy_response.meta.setdefault("timing", {})
1280
+ timing_cur["agent_request_start_s"] = agent_request_start
1281
+ timing_cur["agent_response_s"] = agent_response_ts
1282
+ if "inference_ms" in policy_response.meta:
1283
+ with contextlib.suppress(Exception):
1284
+ timing_cur.setdefault(
1285
+ "inference_ms",
1286
+ float(policy_response.meta["inference_ms"]),
1287
+ )
1288
+ timing_cur.setdefault(
1289
+ "inference_s",
1290
+ float(policy_response.meta["inference_ms"]) / 1000.0,
1291
+ )
1292
+ last_policy_meta = policy_response.meta
1293
+ else:
1294
+ last_policy_meta = None
1295
+ last_agent_response_ts = agent_response_ts
1296
+
1297
+ # Diagnostic: summarize policy step target and tool calls
1298
+ try:
1299
+ model_name = None
1300
+ target_url = None
1301
+ if isinstance(policy_response.meta, dict):
1302
+ req_body = policy_response.meta.get("inference_request") or {}
1303
+ model_name = req_body.get("model")
1304
+ target_url = policy_response.meta.get("inference_url")
1305
+ _tc = policy_response.tool_calls or []
1306
+ print(
1307
+ {
1308
+ "rollout.policy_step": True,
1309
+ "run_id": request.run_id,
1310
+ "model": model_name,
1311
+ "inference_url": target_url,
1312
+ "tool_calls_count": len(_tc) if isinstance(_tc, list) else 0,
1313
+ },
1314
+ flush=True,
1315
+ )
1316
+ except Exception:
1317
+ pass
1318
+
1319
+ pending_tool_calls = policy_response.tool_calls
1320
+ # Log summarized agent tool calls
1321
+ with contextlib.suppress(Exception):
1322
+ _tc = pending_tool_calls or []
1323
+ _summary = []
1324
+ for _item in (_tc if isinstance(_tc, list) else []):
1325
+ try:
1326
+ if isinstance(_item, dict):
1327
+ _tool = _item.get("tool")
1328
+ _args = _item.get("args")
1329
+ _keys = list(_args.keys()) if isinstance(_args, dict) else []
1330
+ _summary.append({"tool": _tool, "args_keys": _keys})
1331
+ except Exception:
1332
+ continue
1333
+ _rid = getattr(request, "run_id", None)
1334
+ logger.info("AGENT_TOOL_CALLS: run_id=%s count=%d summary=%s", _rid, len(_tc), _summary)
1335
+ print(f"[rollout] agent tool_calls run_id={_rid} count={len(_tc)} summary={_summary}", flush=True)
1336
+ await tracing_context.record_tool_invocation(pending_tool_calls)
1337
+ ops_executed += 1
1338
+
1339
+ elif op == "env":
1340
+ if not pending_tool_calls:
1341
+ # Treat absence of tool calls as a soft terminal condition; yield partial trajectory
1342
+ with contextlib.suppress(Exception):
1343
+ logger.warning(
1344
+ "NO_TOOL_CALLS: terminating episode early run_id=%s op_idx=%s",
1345
+ request.run_id,
1346
+ str(op_idx),
1347
+ )
1348
+ print(
1349
+ f"[rollout] no tool_calls; terminating early run_id={request.run_id} op_idx={op_idx}",
1350
+ flush=True,
1351
+ )
1352
+ term_step = RolloutStep(
1353
+ obs=current_obs,
1354
+ tool_calls=[],
1355
+ reward=None,
1356
+ done=True,
1357
+ truncated=False,
1358
+ info={
1359
+ "terminated": True,
1360
+ "reason": "no_tool_calls",
1361
+ },
1362
+ )
1363
+ trajectory_steps.append(term_step)
1364
+ trajectory = RolloutTrajectory(
1365
+ env_id=env_id,
1366
+ policy_id=policy_id,
1367
+ steps=trajectory_steps,
1368
+ final={
1369
+ "observation": current_obs,
1370
+ "rollout_status": "partial_no_tool_calls",
1371
+ "at_op": op,
1372
+ },
1373
+ length=len(trajectory_steps),
1374
+ decision_samples=decision_samples if step_rewards_active else None,
1375
+ )
1376
+ metrics = RolloutMetrics(
1377
+ episode_returns=[total_reward],
1378
+ mean_return=total_reward,
1379
+ num_steps=len(trajectory_steps),
1380
+ num_episodes=1,
1381
+ )
1382
+ aborted = registry.is_run_aborted(request.run_id)
1383
+ if not aborted:
1384
+ registry.complete_run(request.run_id)
1385
+ if decision_open:
1386
+ await tracing_context.end_decision()
1387
+ decision_open = False
1388
+ if not finalized:
1389
+ session_trace = await tracing_context.finalize(
1390
+ total_reward=total_reward,
1391
+ achievement_state=prev_achievements,
1392
+ total_steps=len(trajectory_steps),
1393
+ )
1394
+ finalized = True
1395
+ trace_payload = tracing_context.build_trace_payload(session_trace)
1396
+ return RolloutResponse(
1397
+ run_id=request.run_id,
1398
+ trajectories=[trajectory],
1399
+ branches={},
1400
+ metrics=metrics,
1401
+ aborted=aborted,
1402
+ ops_executed=ops_executed,
1403
+ trace=trace_payload,
1404
+ )
1405
+
1406
+ # Environment step
1407
+ from .environment_routes import EnvStepRequest, step_environment
1408
+
1409
+ env_step_error: Exception | None = None
1410
+ env_response = None
1411
+ env_step_start = _time.perf_counter()
1412
+ try:
1413
+ env_response = await step_environment(
1414
+ EnvStepRequest(
1415
+ env_id=env_id,
1416
+ tool_calls=pending_tool_calls,
1417
+ )
1418
+ )
1419
+ except Exception as _ee:
1420
+ env_step_error = _ee
1421
+ env_step_end = _time.perf_counter()
1422
+ env_step_duration_ms = (env_step_end - env_step_start) * 1000.0
1423
+ last_env_step_ms = env_step_duration_ms
1424
+ last_env_step_completed_ts = env_step_end
1425
+ if last_policy_meta is not None:
1426
+ with contextlib.suppress(Exception):
1427
+ timing_env = last_policy_meta.setdefault("timing", {})
1428
+ timing_env["env_step_ms"] = env_step_duration_ms
1429
+ timing_env["env_step_end_s"] = env_step_end
1430
+
1431
+ if env_step_error is not None:
1432
+ # Invalid action or environment rejection; terminate the episode early with a partial trajectory
1433
+ with contextlib.suppress(Exception):
1434
+ logger.warning(
1435
+ "ENV_STEP_FAIL: terminating episode early run_id=%s op_idx=%s err=%s",
1436
+ request.run_id,
1437
+ str(op_idx),
1438
+ str(env_step_error),
1439
+ )
1440
+
1441
+ term_step = RolloutStep(
1442
+ obs=current_obs,
1443
+ tool_calls=pending_tool_calls,
1444
+ reward=None,
1445
+ done=True,
1446
+ truncated=False,
1447
+ info={
1448
+ "terminated": True,
1449
+ "reason": "invalid_action",
1450
+ "error": str(env_step_error),
1451
+ },
1452
+ )
1453
+ trajectory_steps.append(term_step)
1454
+ # Build partial response
1455
+ trajectory = RolloutTrajectory(
1456
+ env_id=env_id,
1457
+ policy_id=policy_id,
1458
+ steps=trajectory_steps,
1459
+ final={
1460
+ "observation": current_obs,
1461
+ "rollout_status": "partial_invalid_action",
1462
+ "error": str(env_step_error),
1463
+ "at_op": op,
1464
+ },
1465
+ length=len(trajectory_steps),
1466
+ decision_samples=decision_samples if step_rewards_active else None,
1467
+ )
1468
+ metrics = RolloutMetrics(
1469
+ episode_returns=[total_reward],
1470
+ mean_return=total_reward,
1471
+ num_steps=len(trajectory_steps),
1472
+ num_episodes=1,
1473
+ )
1474
+ aborted = registry.is_run_aborted(request.run_id)
1475
+ if not aborted:
1476
+ registry.complete_run(request.run_id)
1477
+ if (
1478
+ last_policy_meta is not None
1479
+ and last_agent_response_ts is not None
1480
+ and "decision_ms" not in last_policy_meta.get("timing", {})
1481
+ ):
1482
+ with contextlib.suppress(Exception):
1483
+ timing_last = last_policy_meta.setdefault("timing", {})
1484
+ decision_ms = max(
1485
+ 0.0,
1486
+ (env_step_end - float(last_agent_response_ts)) * 1000.0,
1487
+ )
1488
+ timing_last["decision_ms"] = decision_ms
1489
+ timing_last.setdefault(
1490
+ "overhead_ms", max(0.0, decision_ms - env_step_duration_ms)
1491
+ )
1492
+ if decision_open:
1493
+ await tracing_context.end_decision()
1494
+ decision_open = False
1495
+ if not finalized:
1496
+ session_trace = await tracing_context.finalize(
1497
+ total_reward=total_reward,
1498
+ achievement_state=prev_achievements,
1499
+ total_steps=len(trajectory_steps),
1500
+ )
1501
+ finalized = True
1502
+ trace_payload = tracing_context.build_trace_payload(session_trace)
1503
+ return RolloutResponse(
1504
+ run_id=request.run_id,
1505
+ trajectories=[trajectory],
1506
+ branches={},
1507
+ metrics=metrics,
1508
+ aborted=aborted,
1509
+ ops_executed=ops_executed,
1510
+ trace=trace_payload,
1511
+ )
1512
+
1513
+ # Reaching here means env step succeeded
1514
+ assert env_response is not None
1515
+
1516
+ # Record step, including policy meta if present for timing/tokens observability
1517
+ _info = env_response.info if isinstance(env_response.info, dict) else {}
1518
+ # Attach policy meta from the immediately preceding agent step
1519
+ with contextlib.suppress(Exception):
1520
+ prev_meta = {}
1521
+ if "policy_response" in locals() and isinstance(policy_response.meta, dict): # type: ignore[name-defined]
1522
+ prev_meta = policy_response.meta
1523
+ if prev_meta:
1524
+ _info = dict(_info)
1525
+ _info["meta"] = prev_meta
1526
+
1527
+ event_metadata = {
1528
+ "op_index": op_idx,
1529
+ }
1530
+ event_id = await tracing_context.record_environment_event(
1531
+ env_handle=env_handle,
1532
+ prev_obs=current_obs,
1533
+ env_response=env_response,
1534
+ next_obs=getattr(env_response, "observation", None),
1535
+ metadata=event_metadata,
1536
+ )
1537
+
1538
+ decision_index += 1
1539
+ next_obs = env_response.observation
1540
+ new_achievement_state = _extract_achievements(next_obs)
1541
+ final_achievement_count = sum(
1542
+ 1 for _, unlocked in new_achievement_state.items() if unlocked
1543
+ )
1544
+ indicator_val = 0
1545
+ reward_stepwise = 0.0
1546
+ decision_rewards_meta: dict[str, Any] | None = None
1547
+ if step_rewards_active:
1548
+ decision_actions = _summarize_tool_calls(pending_tool_calls)
1549
+ stepwise_info, decision_record, stats = compute_stepwise_reward(
1550
+ prev_achievements or {},
1551
+ new_achievement_state,
1552
+ decision_index,
1553
+ decision_actions,
1554
+ step_rewards_indicator_lambda,
1555
+ )
1556
+ indicator_val = int(stats.get("indicator", 0.0))
1557
+ reward_stepwise = float(stats.get("reward", 0.0))
1558
+ stepwise_indicator_sum += float(stats.get("indicator", 0.0))
1559
+ stepwise_reward_sum += reward_stepwise
1560
+ stepwise_new_achievements_total += int(stats.get("new_achievements_count", 0.0))
1561
+ _info = {} if not isinstance(_info, dict) else dict(_info)
1562
+ _info["stepwise"] = stepwise_info
1563
+ # Compute decision-level rewards (absolute vs unique) and attach to metadata
1564
+ with contextlib.suppress(Exception):
1565
+ turned_true = set(stepwise_info.get("new_achievements") or [])
1566
+ seen_before = set(episode_seen_achievements)
1567
+ new_unique = sorted(turned_true - seen_before)
1568
+ ach_delta = int(len(turned_true))
1569
+ unique_delta = int(len(new_unique))
1570
+ # Prepare stable lists for logging/metadata
1571
+ all_list = sorted(turned_true)
1572
+ # Ensure nested meta exists
1573
+ meta_block = (
1574
+ _info.get("meta") if isinstance(_info.get("meta"), dict) else {}
1575
+ )
1576
+ decision_rewards = {
1577
+ "turn": int(decision_index),
1578
+ "ach_delta": ach_delta,
1579
+ "unique_delta": unique_delta,
1580
+ "all": all_list,
1581
+ "unique": new_unique,
1582
+ }
1583
+ decision_rewards_meta = decision_rewards
1584
+ meta_block["decision_rewards"] = decision_rewards
1585
+ _info["meta"] = meta_block
1586
+ # Update episode-level seen set after attributing uniqueness to this decision
1587
+ episode_seen_achievements.update(turned_true)
1588
+ decision_samples.append(decision_record)
1589
+ prev_achievements = new_achievement_state
1590
+
1591
+ await tracing_context.record_decision_reward(
1592
+ event_id=event_id,
1593
+ decision_meta=decision_rewards_meta,
1594
+ )
1595
+
1596
+ step = RolloutStep(
1597
+ obs=_summarize_observation_for_storage(env_handle, current_obs),
1598
+ tool_calls=pending_tool_calls,
1599
+ reward=env_response.reward,
1600
+ done=env_response.done,
1601
+ truncated=env_response.truncated,
1602
+ info=_info,
1603
+ )
1604
+ # Log a summary of the tool calls applied to the env plus the immediate reward/done
1605
+ with contextlib.suppress(Exception):
1606
+ _tc = pending_tool_calls or []
1607
+ _summary = []
1608
+ for _item in (_tc if isinstance(_tc, list) else []):
1609
+ try:
1610
+ if isinstance(_item, dict):
1611
+ _tool = _item.get("tool")
1612
+ _args = _item.get("args")
1613
+ _keys = list(_args.keys()) if isinstance(_args, dict) else []
1614
+ _summary.append({"tool": _tool, "args_keys": _keys})
1615
+ except Exception:
1616
+ continue
1617
+ _rid = getattr(request, "run_id", None)
1618
+ logger.info(
1619
+ "ENV_APPLY: run_id=%s tool_calls=%d reward=%s done=%s summary=%s",
1620
+ _rid,
1621
+ len(_tc),
1622
+ str(env_response.reward),
1623
+ str(env_response.done),
1624
+ _summary,
1625
+ )
1626
+ print(
1627
+ f"[rollout] env apply run_id={_rid} tool_calls={len(_tc)} reward={env_response.reward} done={env_response.done} summary={_summary}",
1628
+ flush=True,
1629
+ )
1630
+ trajectory_steps.append(step)
1631
+
1632
+ if env_response.reward is not None:
1633
+ total_reward += env_response.reward
1634
+
1635
+ # Update state
1636
+ current_obs = next_obs
1637
+ pending_tool_calls = None
1638
+ ops_executed += 1
1639
+
1640
+ # Handle episode end
1641
+ if env_response.done:
1642
+ if request.on_done == "reset":
1643
+ # Reset environment
1644
+ from .environment_routes import (
1645
+ EnvResetRequest,
1646
+ reset_environment,
1647
+ )
1648
+
1649
+ reset_response = await reset_environment(EnvResetRequest(env_id=env_id))
1650
+ current_obs = reset_response.observation
1651
+ elif request.on_done == "terminate":
1652
+ break
1653
+
1654
+ if decision_open:
1655
+ await tracing_context.end_decision()
1656
+ decision_open = False
1657
+
1658
+ else:
1659
+ logger.warning(f"Unknown op: {op}")
1660
+
1661
+ if (
1662
+ last_policy_meta is not None
1663
+ and last_agent_response_ts is not None
1664
+ and "timing" in last_policy_meta
1665
+ and isinstance(last_policy_meta["timing"], dict)
1666
+ and "decision_ms" not in last_policy_meta["timing"]
1667
+ ):
1668
+ with contextlib.suppress(Exception):
1669
+ final_now = last_env_step_completed_ts or _time.perf_counter()
1670
+ final_decision_ms = max(0.0, (final_now - float(last_agent_response_ts)) * 1000.0)
1671
+ timing_final = last_policy_meta.setdefault("timing", {})
1672
+ timing_final["decision_ms"] = final_decision_ms
1673
+ if last_env_step_ms is not None:
1674
+ timing_final.setdefault("env_step_ms", float(last_env_step_ms))
1675
+ timing_final.setdefault(
1676
+ "overhead_ms",
1677
+ max(0.0, final_decision_ms - float(last_env_step_ms)),
1678
+ )
1679
+ else:
1680
+ timing_final.setdefault("overhead_ms", 0.0)
1681
+
1682
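The timing fields patched throughout this handler fit together as follows: `decision_ms` is the wall-clock gap from one agent response to the next agent request start (or, for the final decision, to the last env-step completion), `env_step_ms` is the portion of that gap spent inside the environment step, and `overhead_ms` is the clamped remainder. A compact restatement of that arithmetic:

    def split_decision_timing(last_agent_response_s: float,
                              next_boundary_s: float,
                              env_step_ms: float) -> dict[str, float]:
        # Mirrors the max(0.0, ...) clamping used above: decision time is measured
        # between consecutive agent boundaries, and env time is subtracted as overhead.
        decision_ms = max(0.0, (next_boundary_s - last_agent_response_s) * 1000.0)
        return {
            "decision_ms": decision_ms,
            "env_step_ms": env_step_ms,
            "overhead_ms": max(0.0, decision_ms - env_step_ms),
        }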
+ # Build trajectory
1683
+ trajectory = RolloutTrajectory(
1684
+ env_id=env_id,
1685
+ policy_id=policy_id,
1686
+ steps=trajectory_steps,
1687
+ final={"observation": _summarize_observation_for_storage(env_handle, current_obs)},
1688
+ length=len(trajectory_steps),
1689
+ decision_samples=decision_samples if step_rewards_active else None,
1690
+ )
1691
+
1692
+ # Build metrics
1693
+ metrics = RolloutMetrics(
1694
+ episode_returns=[total_reward],
1695
+ mean_return=total_reward,
1696
+ num_steps=len(trajectory_steps),
1697
+ num_episodes=1,
1698
+ )
1699
+
1700
+ # Environment-specific: Log summary if available
1701
+ try:
1702
+ # Check if this is a Wordle environment and use Wordle helpers (lazy import)
1703
+ wordle_wrapper_cls = None
1704
+ try:
1705
+ from .envs.wordle.environment import WordleEnvironmentWrapper
1706
+ from .envs.wordle.helpers import (
1707
+ get_wordle_rollout_summary,
1708
+ log_wordle_rollout_summary,
1709
+ )
1710
+
1711
+ wordle_wrapper_cls = WordleEnvironmentWrapper
1712
+ except Exception:
1713
+ wordle_wrapper_cls = None # type: ignore[assignment]
1714
+ get_wordle_rollout_summary = None # type: ignore
1715
+ log_wordle_rollout_summary = None # type: ignore
1716
+
1717
+ is_wordle = wordle_wrapper_cls is not None and isinstance(
1718
+ env_handle.env,
1719
+ wordle_wrapper_cls, # type: ignore[arg-type]
1720
+ )
1721
+ if is_wordle:
1722
+ # Convert trajectory steps to expected format
1723
+ formatted_steps = []
1724
+ for step in trajectory_steps:
1725
+ formatted_steps.append({"tool_calls": step.tool_calls or []})
1726
+
1727
+ if (
1728
+ get_wordle_rollout_summary is not None
1729
+ and log_wordle_rollout_summary is not None
1730
+ ):
1731
+ summary = get_wordle_rollout_summary(formatted_steps, current_obs, env_handle)
1732
+ log_wordle_rollout_summary(request.run_id, summary)
1733
+ except ImportError:
1734
+ # Wordle helpers not available, skip Wordle-specific logging
1735
+ pass
1736
+ except Exception as e:
1737
+ logger.warning(f"Failed to generate environment-specific summary: {e}")
1738
+
1739
+ # Mark run as completed
1740
+ aborted = registry.is_run_aborted(request.run_id)
1741
+ if not aborted:
1742
+ registry.complete_run(request.run_id)
1743
+ if decision_open:
1744
+ await tracing_context.end_decision()
1745
+ decision_open = False
1746
+ if not finalized:
1747
+ session_trace = await tracing_context.finalize(
1748
+ total_reward=total_reward,
1749
+ achievement_state=prev_achievements,
1750
+ total_steps=len(trajectory_steps),
1751
+ )
1752
+ finalized = True
1753
+ trace_payload = tracing_context.build_trace_payload(session_trace)
1754
+
1755
+ return RolloutResponse(
1756
+ run_id=request.run_id,
1757
+ trajectories=[trajectory],
1758
+ branches={},
1759
+ metrics=metrics,
1760
+ aborted=aborted,
1761
+ ops_executed=ops_executed,
1762
+ trace=trace_payload,
1763
+ )
1764
+
1765
+ except Exception as e:
1766
+ logger.error(f"Rollout failed for run {request.run_id}: {e}")
1767
+ registry.abort_run(request.run_id)
1768
+ if decision_open:
1769
+ with contextlib.suppress(Exception):
1770
+ await tracing_context.end_decision()
1771
+ decision_open = False
1772
+ if not finalized:
1773
+ session_trace = None
1774
+ with contextlib.suppress(Exception):
1775
+ session_trace = await tracing_context.finalize(
1776
+ total_reward=total_reward,
1777
+ achievement_state=prev_achievements,
1778
+ total_steps=len(trajectory_steps),
1779
+ )
1780
+ finalized = True
1781
+ raise HTTPException(status_code=500, detail=str(e)) from e
1782
+ finally:
1783
+ # Ensure any environment created for this rollout is terminated (no reuse across rollouts)
1784
+ try:
1785
+ if created_env_id:
1786
+ from .environment_routes import EnvTerminateRequest, terminate_environment
1787
+
1788
+ await terminate_environment(EnvTerminateRequest(env_id=created_env_id))
1789
+ logger.info(
1790
+ "ROLL_OUT: terminated environment env_id=%s seed=%s",
1791
+ str(created_env_id),
1792
+ str(env_seed_used) if env_seed_used is not None else "unknown",
1793
+ )
1794
+ # Verify removal from registry
1795
+ with contextlib.suppress(Exception):
1796
+ _post = registry.get_env(created_env_id)
1797
+ logger.info(
1798
+ "ROLL_OUT: env_killed=%s (post_lookup=%s)",
1799
+ str(_post is None),
1800
+ str(_post),
1801
+ )
1802
+ except Exception as _te:
1803
+ logger.warning(f"ROLL_OUT: failed to terminate environment {created_env_id}: {_te}")
1804
+
1805
+ # Best-effort policy cleanup if we created one (avoid reuse across rollouts)
1806
+ with contextlib.suppress(Exception):
1807
+ if created_policy_id:
1808
+ from .policy_routes import PolicyTerminateRequest, terminate_policy
1809
+
1810
+ await terminate_policy(PolicyTerminateRequest(policy_id=created_policy_id))
1811
+ logger.info("ROLL_OUT: terminated policy policy_id=%s", str(created_policy_id))
1812
+
1813
+ if not finalized:
1814
+ session_trace = None
1815
+ with contextlib.suppress(Exception):
1816
+ session_trace = await tracing_context.finalize(
1817
+ total_reward=total_reward,
1818
+ achievement_state=prev_achievements,
1819
+ total_steps=len(trajectory_steps),
1820
+ )
1821
+ finalized = True
1822
+
1823
+ with contextlib.suppress(Exception):
1824
+ _clear_seed_side_effects()
1825
+ logger.info("ROLL_OUT: RNG seed terminated/cleared before conclusion")
1826
+
1827
+
1828
+ @router.post("/run/abort", response_model=RunAbortResponse)
1829
+ async def abort_run(request: RunAbortRequest) -> RunAbortResponse:
1830
+ """Abort a running rollout."""
1831
+ success = registry.abort_run(request.run_id)
1832
+
1833
+ if not success:
1834
+ raise HTTPException(
1835
+ status_code=404,
1836
+ detail=f"Run {request.run_id} not found",
1837
+ )
1838
+
1839
+ return RunAbortResponse(
1840
+ ok=True,
1841
+ run_id=request.run_id,
1842
+ )
1843
+
1844
+
1845
+ @router.get("/run/status/{run_id}", response_model=RunStatusResponse)
1846
+ async def get_run_status(run_id: str) -> RunStatusResponse:
1847
+ """Get the status of a run."""
1848
+ run_handle = registry.get_run(run_id)
1849
+
1850
+ if not run_handle:
1851
+ raise HTTPException(
1852
+ status_code=404,
1853
+ detail=f"Run {run_id} not found",
1854
+ )
1855
+
1856
+ return RunStatusResponse(
1857
+ run_id=run_id,
1858
+ status=run_handle.status,
1859
+ started_at=run_handle.started_at,
1860
+ finished_at=run_handle.finished_at,
1861
+ )
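For completeness, a hedged client-side sketch of these two endpoints; the router prefix and any authentication are not shown in this hunk, so the base URL below is an assumption:

    import httpx

    BASE = "http://localhost:8000"  # assumed host; prepend the router prefix if one is configured

    def abort_and_poll(run_id: str) -> dict:
        with httpx.Client(base_url=BASE, timeout=10.0) as client:
            # POST /run/abort returns {"ok": true, "run_id": ...} or 404 if the run is unknown.
            client.post("/run/abort", json={"run_id": run_id}).raise_for_status()
            # GET /run/status/{run_id} returns run_id, status, started_at, finished_at.
            status = client.get(f"/run/status/{run_id}")
            status.raise_for_status()
            return status.json()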