synth-ai 0.2.9.dev4__py3-none-any.whl → 0.2.9.dev6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of synth-ai might be problematic.

Files changed (353)
  1. examples/__init__.py +16 -0
  2. examples/crafter_debug_render.py +23 -17
  3. examples/qwen_coder/README.md +102 -0
  4. examples/qwen_coder/_shared.py +113 -0
  5. examples/qwen_coder/configs/coder_lora_30b.toml +61 -0
  6. examples/qwen_coder/configs/coder_lora_4b.toml +57 -0
  7. examples/qwen_coder/configs/coder_lora_small.toml +58 -0
  8. examples/qwen_coder/generate_dataset.py +98 -0
  9. examples/qwen_coder/infer_ft_smoke.py +64 -0
  10. examples/qwen_coder/infer_prod_proxy.py +73 -0
  11. examples/qwen_coder/infer_via_synth.py +87 -0
  12. examples/qwen_coder/scripts/infer_coder.sh +18 -0
  13. examples/qwen_coder/scripts/train_coder_30b.sh +21 -0
  14. examples/qwen_coder/sft_full_17b.py +103 -0
  15. examples/qwen_coder/sft_lora_30b.py +110 -0
  16. examples/qwen_coder/subset_jsonl.py +38 -0
  17. examples/qwen_coder/validate_jsonl.py +59 -0
  18. examples/rl/configs/eval_base_qwen.toml +1 -1
  19. examples/rl/configs/rl_from_base_qwen17.toml +1 -1
  20. examples/rl/download_dataset.py +26 -10
  21. examples/rl/run_eval.py +53 -52
  22. examples/rl/run_rl_and_save.py +29 -12
  23. examples/rl/task_app/math_single_step.py +180 -41
  24. examples/rl/task_app/math_task_app.py +14 -6
  25. examples/sft/README.md +139 -0
  26. examples/sft/configs/crafter_fft_qwen0p6b.toml +44 -0
  27. examples/sft/configs/crafter_lora_qwen0p6b.toml +45 -0
  28. examples/sft/evaluate.py +117 -0
  29. examples/sft/export_dataset.py +117 -0
  30. examples/sft/generate_traces.py +162 -0
  31. examples/swe/__init__.py +12 -0
  32. examples/swe/task_app/README.md +105 -0
  33. examples/swe/task_app/__init__.py +2 -0
  34. examples/swe/task_app/grpo_swe_mini.py +571 -0
  35. examples/swe/task_app/grpo_swe_mini_task_app.py +136 -0
  36. examples/swe/task_app/hosted/README.md +173 -0
  37. examples/swe/task_app/hosted/__init__.py +5 -0
  38. examples/swe/task_app/hosted/branching.py +143 -0
  39. examples/swe/task_app/hosted/environment_routes.py +1289 -0
  40. examples/swe/task_app/hosted/envs/__init__.py +1 -0
  41. examples/swe/task_app/hosted/envs/crafter/__init__.py +6 -0
  42. examples/swe/task_app/hosted/envs/crafter/app.py +1 -0
  43. examples/swe/task_app/hosted/envs/crafter/environment.py +522 -0
  44. examples/swe/task_app/hosted/envs/crafter/policy.py +478 -0
  45. examples/swe/task_app/hosted/envs/crafter/react_agent.py +108 -0
  46. examples/swe/task_app/hosted/envs/crafter/shared.py +305 -0
  47. examples/swe/task_app/hosted/envs/crafter/tools.py +47 -0
  48. examples/swe/task_app/hosted/envs/mini_swe/__init__.py +8 -0
  49. examples/swe/task_app/hosted/envs/mini_swe/environment.py +1164 -0
  50. examples/swe/task_app/hosted/envs/mini_swe/policy.py +355 -0
  51. examples/swe/task_app/hosted/envs/mini_swe/shared.py +83 -0
  52. examples/swe/task_app/hosted/envs/mini_swe/tools.py +96 -0
  53. examples/swe/task_app/hosted/hosted_app.py +204 -0
  54. examples/swe/task_app/hosted/inference/__init__.py +5 -0
  55. examples/swe/task_app/hosted/inference/openai_client.py +618 -0
  56. examples/swe/task_app/hosted/main.py +100 -0
  57. examples/swe/task_app/hosted/policy_routes.py +1079 -0
  58. examples/swe/task_app/hosted/registry.py +195 -0
  59. examples/swe/task_app/hosted/rollout.py +1869 -0
  60. examples/swe/task_app/hosted/storage/__init__.py +5 -0
  61. examples/swe/task_app/hosted/storage/volume.py +211 -0
  62. examples/swe/task_app/hosted/test_agents.py +161 -0
  63. examples/swe/task_app/hosted/test_service.py +137 -0
  64. examples/swe/task_app/hosted/utils.py +62 -0
  65. examples/vlm/README.md +68 -0
  66. examples/vlm/configs/crafter_vlm_gpt4o.toml +44 -0
  67. examples/vlm/crafter_image_only_agent.py +207 -0
  68. examples/vlm/crafter_openai_vlm_agent.py +277 -0
  69. examples/vlm/filter_image_rows.py +63 -0
  70. examples/vlm/run_crafter_vlm_benchmark.py +316 -0
  71. examples/warming_up_to_rl/analyze_trace_db.py +12 -10
  72. examples/warming_up_to_rl/configs/rl_from_base_qwen4b.toml +11 -1
  73. examples/warming_up_to_rl/export_trace_sft.py +218 -36
  74. examples/warming_up_to_rl/groq_test.py +15 -8
  75. examples/warming_up_to_rl/manage_secrets.py +29 -25
  76. examples/warming_up_to_rl/readme.md +9 -2
  77. examples/warming_up_to_rl/run_eval.py +137 -61
  78. examples/warming_up_to_rl/run_fft_and_save.py +131 -60
  79. examples/warming_up_to_rl/run_local_rollout.py +88 -39
  80. examples/warming_up_to_rl/run_local_rollout_modal.py +114 -28
  81. examples/warming_up_to_rl/run_local_rollout_parallel.py +81 -20
  82. examples/warming_up_to_rl/run_local_rollout_traced.py +126 -23
  83. examples/warming_up_to_rl/run_rl_and_save.py +35 -12
  84. examples/warming_up_to_rl/run_rollout_remote.py +44 -19
  85. examples/warming_up_to_rl/task_app/README.md +6 -2
  86. examples/warming_up_to_rl/task_app/grpo_crafter.py +319 -57
  87. examples/warming_up_to_rl/task_app/grpo_crafter_task_app.py +11 -30
  88. examples/warming_up_to_rl/task_app/synth_envs_hosted/__init__.py +1 -1
  89. examples/warming_up_to_rl/task_app/synth_envs_hosted/branching.py +9 -11
  90. examples/warming_up_to_rl/task_app/synth_envs_hosted/environment_routes.py +137 -182
  91. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/__init__.py +1 -1
  92. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/__init__.py +1 -1
  93. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/app.py +1 -1
  94. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/environment.py +150 -57
  95. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/policy.py +105 -69
  96. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/react_agent.py +19 -7
  97. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/shared.py +45 -42
  98. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/tools.py +1 -1
  99. examples/warming_up_to_rl/task_app/synth_envs_hosted/hosted_app.py +47 -45
  100. examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/__init__.py +1 -1
  101. examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/openai_client.py +198 -92
  102. examples/warming_up_to_rl/task_app/synth_envs_hosted/main.py +0 -2
  103. examples/warming_up_to_rl/task_app/synth_envs_hosted/policy_routes.py +361 -263
  104. examples/warming_up_to_rl/task_app/synth_envs_hosted/registry.py +21 -23
  105. examples/warming_up_to_rl/task_app/synth_envs_hosted/rollout.py +394 -274
  106. examples/warming_up_to_rl/task_app/synth_envs_hosted/storage/__init__.py +1 -1
  107. examples/warming_up_to_rl/task_app/synth_envs_hosted/storage/volume.py +56 -62
  108. examples/warming_up_to_rl/task_app/synth_envs_hosted/test_agents.py +1 -0
  109. examples/warming_up_to_rl/task_app/synth_envs_hosted/test_service.py +6 -15
  110. examples/warming_up_to_rl/task_app/synth_envs_hosted/utils.py +4 -3
  111. synth/__init__.py +14 -0
  112. synth_ai/__init__.py +20 -4
  113. synth_ai/api/models/supported.py +376 -0
  114. synth_ai/api/train/builders.py +157 -26
  115. synth_ai/api/train/cli.py +213 -57
  116. synth_ai/api/train/config_finder.py +65 -5
  117. synth_ai/api/train/env_resolver.py +33 -15
  118. synth_ai/api/train/pollers.py +13 -4
  119. synth_ai/api/train/supported_algos.py +139 -0
  120. synth_ai/api/train/task_app.py +5 -3
  121. synth_ai/api/train/utils.py +33 -48
  122. synth_ai/cli/__init__.py +19 -4
  123. synth_ai/cli/_modal_wrapper.py +28 -0
  124. synth_ai/cli/_typer_patch.py +49 -0
  125. synth_ai/cli/balance.py +2 -3
  126. synth_ai/cli/calc.py +1 -1
  127. synth_ai/cli/demo.py +21 -6
  128. synth_ai/cli/recent.py +2 -2
  129. synth_ai/cli/rl_demo.py +77 -17
  130. synth_ai/cli/root.py +116 -39
  131. synth_ai/cli/status.py +2 -2
  132. synth_ai/cli/task_apps.py +1709 -243
  133. synth_ai/cli/traces.py +7 -4
  134. synth_ai/cli/turso.py +73 -0
  135. synth_ai/cli/watch.py +12 -18
  136. synth_ai/core/experiment.py +0 -2
  137. synth_ai/demo_registry.py +68 -31
  138. synth_ai/demos/core/cli.py +516 -194
  139. synth_ai/demos/demo_task_apps/__init__.py +3 -3
  140. synth_ai/demos/demo_task_apps/core.py +64 -28
  141. synth_ai/demos/demo_task_apps/crafter/configs/crafter_fft_4b.toml +2 -3
  142. synth_ai/demos/demo_task_apps/crafter/grpo_crafter_task_app.py +37 -30
  143. synth_ai/demos/demo_task_apps/math/_common.py +1 -2
  144. synth_ai/demos/demo_task_apps/math/app.py +2 -1
  145. synth_ai/demos/demo_task_apps/math/deploy_modal.py +3 -6
  146. synth_ai/demos/demo_task_apps/math/modal_task_app.py +183 -82
  147. synth_ai/demos/demo_task_apps/math/task_app_entry.py +0 -2
  148. synth_ai/environments/examples/bandit/engine.py +12 -4
  149. synth_ai/environments/examples/bandit/taskset.py +4 -4
  150. synth_ai/environments/examples/crafter_classic/environment.py +76 -1
  151. synth_ai/environments/reproducibility/tree.py +5 -6
  152. synth_ai/environments/service/app.py +11 -12
  153. synth_ai/environments/service/core_routes.py +10 -9
  154. synth_ai/environments/stateful/engine.py +1 -1
  155. synth_ai/environments/tasks/core.py +1 -0
  156. synth_ai/environments/tasks/filters.py +5 -6
  157. synth_ai/environments/tasks/utils.py +4 -5
  158. synth_ai/evals/base.py +0 -2
  159. synth_ai/handshake.py +11 -9
  160. synth_ai/http.py +1 -1
  161. synth_ai/http_client.py +43 -11
  162. synth_ai/inference/__init__.py +0 -2
  163. synth_ai/inference/client.py +20 -6
  164. synth_ai/jobs/client.py +103 -78
  165. synth_ai/learning/__init__.py +41 -6
  166. synth_ai/learning/algorithms.py +14 -0
  167. synth_ai/learning/client.py +121 -29
  168. synth_ai/learning/config.py +2 -40
  169. synth_ai/learning/constants.py +0 -2
  170. synth_ai/learning/ft_client.py +4 -56
  171. synth_ai/learning/health.py +13 -7
  172. synth_ai/learning/jobs.py +43 -47
  173. synth_ai/{rl → learning/rl}/__init__.py +14 -5
  174. synth_ai/learning/rl/client.py +267 -0
  175. synth_ai/learning/rl/config.py +31 -0
  176. synth_ai/{rl → learning/rl}/contracts.py +5 -10
  177. synth_ai/{rl → learning/rl}/env_keys.py +45 -16
  178. synth_ai/learning/rl/secrets.py +13 -0
  179. synth_ai/learning/rl_client.py +2 -253
  180. synth_ai/learning/sft/__init__.py +29 -0
  181. synth_ai/learning/sft/client.py +68 -0
  182. synth_ai/learning/sft/config.py +270 -0
  183. synth_ai/learning/sft/data.py +295 -0
  184. synth_ai/learning/sse.py +25 -26
  185. synth_ai/learning/validators.py +25 -24
  186. synth_ai/lm/__init__.py +21 -47
  187. synth_ai/task/__init__.py +26 -27
  188. synth_ai/task/apps/__init__.py +18 -19
  189. synth_ai/task/auth.py +35 -23
  190. synth_ai/task/client.py +15 -13
  191. synth_ai/task/contracts.py +37 -35
  192. synth_ai/task/datasets.py +9 -6
  193. synth_ai/task/errors.py +11 -10
  194. synth_ai/task/health.py +17 -11
  195. synth_ai/task/json.py +58 -24
  196. synth_ai/task/proxy.py +15 -14
  197. synth_ai/task/rubrics.py +22 -15
  198. synth_ai/task/server.py +43 -17
  199. synth_ai/task/tracing_utils.py +12 -7
  200. synth_ai/task/validators.py +0 -1
  201. synth_ai/task/vendors.py +5 -7
  202. synth_ai/tracing_v3/__init__.py +2 -0
  203. synth_ai/tracing_v3/abstractions.py +21 -4
  204. synth_ai/tracing_v3/db_config.py +26 -1
  205. synth_ai/tracing_v3/decorators.py +18 -15
  206. synth_ai/tracing_v3/examples/basic_usage.py +3 -2
  207. synth_ai/tracing_v3/hooks.py +6 -4
  208. synth_ai/tracing_v3/llm_call_record_helpers.py +6 -6
  209. synth_ai/tracing_v3/replica_sync.py +1 -0
  210. synth_ai/tracing_v3/session_tracer.py +63 -16
  211. synth_ai/tracing_v3/storage/base.py +89 -1
  212. synth_ai/tracing_v3/storage/config.py +21 -8
  213. synth_ai/tracing_v3/storage/factory.py +10 -8
  214. synth_ai/tracing_v3/storage/utils.py +4 -2
  215. synth_ai/tracing_v3/turso/daemon.py +7 -2
  216. synth_ai/tracing_v3/turso/models.py +5 -2
  217. synth_ai/tracing_v3/turso/native_manager.py +1173 -0
  218. synth_ai/tracing_v3/utils.py +4 -3
  219. synth_ai/v0/api/__init__.py +8 -0
  220. synth_ai/v0/api/models/__init__.py +8 -0
  221. synth_ai/v0/api/models/supported.py +8 -0
  222. synth_ai/v0/config/__init__.py +15 -0
  223. synth_ai/v0/config/base_url.py +12 -0
  224. synth_ai/v0/lm/__init__.py +51 -0
  225. synth_ai/{lm → v0/lm}/caching/ephemeral.py +3 -5
  226. synth_ai/{lm → v0/lm}/caching/handler.py +4 -4
  227. synth_ai/{lm → v0/lm}/caching/initialize.py +1 -1
  228. synth_ai/{lm → v0/lm}/caching/persistent.py +1 -1
  229. synth_ai/{lm → v0/lm}/config.py +6 -1
  230. synth_ai/{lm → v0/lm}/core/all.py +9 -9
  231. synth_ai/{lm → v0/lm}/core/exceptions.py +0 -2
  232. synth_ai/{lm → v0/lm}/core/main.py +19 -7
  233. synth_ai/{lm → v0/lm}/core/main_v3.py +10 -10
  234. synth_ai/{lm → v0/lm}/core/synth_models.py +2 -15
  235. synth_ai/{lm → v0/lm}/core/vendor_clients.py +6 -4
  236. synth_ai/{lm → v0/lm}/overrides.py +4 -4
  237. synth_ai/{lm → v0/lm}/provider_support/anthropic.py +4 -4
  238. synth_ai/{lm → v0/lm}/provider_support/openai.py +5 -5
  239. synth_ai/{lm → v0/lm}/structured_outputs/handler.py +5 -5
  240. synth_ai/{lm → v0/lm}/structured_outputs/rehabilitate.py +1 -1
  241. synth_ai/{lm → v0/lm}/vendors/core/anthropic_api.py +16 -16
  242. synth_ai/{lm → v0/lm}/vendors/core/gemini_api.py +5 -5
  243. synth_ai/{lm → v0/lm}/vendors/core/mistral_api.py +5 -5
  244. synth_ai/{lm → v0/lm}/vendors/core/openai_api.py +12 -10
  245. synth_ai/{lm → v0/lm}/vendors/openai_standard.py +11 -9
  246. synth_ai/{lm → v0/lm}/vendors/openai_standard_responses.py +8 -5
  247. synth_ai/{lm → v0/lm}/vendors/supported/custom_endpoint.py +4 -6
  248. synth_ai/{lm → v0/lm}/vendors/supported/deepseek.py +2 -2
  249. synth_ai/{lm → v0/lm}/vendors/supported/grok.py +2 -2
  250. synth_ai/{lm → v0/lm}/vendors/supported/groq.py +1 -1
  251. synth_ai/{lm → v0/lm}/vendors/supported/ollama.py +1 -1
  252. synth_ai/{lm → v0/lm}/vendors/supported/openrouter.py +3 -3
  253. synth_ai/{lm → v0/lm}/vendors/supported/together.py +1 -1
  254. synth_ai/{lm → v0/lm}/vendors/synth_client.py +38 -11
  255. synth_ai/v0/tracing/upload.py +32 -135
  256. synth_ai/v0/tracing_v3/__init__.py +10 -0
  257. synth_ai/v0/tracing_v3/abstractions.py +3 -0
  258. synth_ai/v0/tracing_v3/decorators.py +3 -0
  259. synth_ai/v0/tracing_v3/llm_call_record_helpers.py +3 -0
  260. synth_ai/v0/tracing_v3/session_tracer.py +3 -0
  261. synth_ai-0.2.9.dev6.dist-info/METADATA +191 -0
  262. {synth_ai-0.2.9.dev4.dist-info → synth_ai-0.2.9.dev6.dist-info}/RECORD +291 -264
  263. {synth_ai-0.2.9.dev4.dist-info → synth_ai-0.2.9.dev6.dist-info}/top_level.txt +1 -0
  264. examples/common_old/backend.py +0 -21
  265. examples/evals_old/README.md +0 -98
  266. examples/evals_old/__init__.py +0 -6
  267. examples/evals_old/compare_models.py +0 -1037
  268. examples/evals_old/example_log.md +0 -145
  269. examples/evals_old/run_demo.sh +0 -126
  270. examples/evals_old/trace_analysis.py +0 -270
  271. examples/finetuning_old/_backup_synth_qwen/config.toml +0 -29
  272. examples/finetuning_old/_backup_synth_qwen/example_log.md +0 -324
  273. examples/finetuning_old/_backup_synth_qwen/filter_traces.py +0 -60
  274. examples/finetuning_old/_backup_synth_qwen/filter_traces_achievements.py +0 -239
  275. examples/finetuning_old/_backup_synth_qwen/purge_v3_traces.py +0 -109
  276. examples/finetuning_old/_backup_synth_qwen/react_agent_lm.py +0 -1924
  277. examples/finetuning_old/_backup_synth_qwen/readme.md +0 -49
  278. examples/finetuning_old/_backup_synth_qwen/run_crafter_qwen4b.py +0 -114
  279. examples/finetuning_old/_backup_synth_qwen/run_demo.sh +0 -195
  280. examples/finetuning_old/_backup_synth_qwen/sft_kickoff.py +0 -118
  281. examples/finetuning_old/synth_qwen_v1/README.md +0 -68
  282. examples/finetuning_old/synth_qwen_v1/filter_traces.py +0 -60
  283. examples/finetuning_old/synth_qwen_v1/filter_traces_achievements.py +0 -239
  284. examples/finetuning_old/synth_qwen_v1/finetune.py +0 -46
  285. examples/finetuning_old/synth_qwen_v1/hello_ft_model.py +0 -71
  286. examples/finetuning_old/synth_qwen_v1/infer.py +0 -37
  287. examples/finetuning_old/synth_qwen_v1/poll.py +0 -44
  288. examples/finetuning_old/synth_qwen_v1/prepare_data.py +0 -35
  289. examples/finetuning_old/synth_qwen_v1/purge_v3_traces.py +0 -109
  290. examples/finetuning_old/synth_qwen_v1/react_agent_lm.py +0 -1932
  291. examples/finetuning_old/synth_qwen_v1/run_crafter_sft_job.py +0 -207
  292. examples/finetuning_old/synth_qwen_v1/run_ft_job.py +0 -232
  293. examples/finetuning_old/synth_qwen_v1/upload_data.py +0 -34
  294. examples/finetuning_old/synth_qwen_v1/util.py +0 -147
  295. examples/rl_old/task_app.py +0 -962
  296. examples/warming_up_to_rl/old/event_rewards.md +0 -234
  297. examples/warming_up_to_rl/old/notes.md +0 -73
  298. examples/warming_up_to_rl/task_app/synth_envs_hosted/test_stepwise_rewards.py +0 -58
  299. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/filter_traces_sft_turso.py +0 -738
  300. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/filter_traces_sft_turso.py +0 -580
  301. synth_ai/environments/examples/sokoban/units/astar_common.py +0 -95
  302. synth_ai/experimental/synth_oss.py +0 -446
  303. synth_ai/install_sqld.sh +0 -40
  304. synth_ai/learning/filtering.py +0 -0
  305. synth_ai/learning/offline/dpo.py +0 -0
  306. synth_ai/learning/offline/providers.py +0 -7
  307. synth_ai/learning/offline/sft.py +0 -0
  308. synth_ai/learning/offline/shared.py +0 -0
  309. synth_ai/learning/online/grpo.py +0 -0
  310. synth_ai/learning/online/irft.py +0 -0
  311. synth_ai/learning/prompts/banking77_injection_eval.py +0 -168
  312. synth_ai/learning/prompts/gepa.py +0 -0
  313. synth_ai/learning/prompts/hello_world_in_context_injection_ex.py +0 -213
  314. synth_ai/learning/prompts/mipro.py +0 -289
  315. synth_ai/learning/prompts/random_search.py +0 -246
  316. synth_ai/learning/prompts/run_mipro_banking77.py +0 -172
  317. synth_ai/learning/prompts/run_random_search_banking77.py +0 -324
  318. synth_ai/rl/secrets.py +0 -19
  319. synth_ai/scripts/verify_rewards.py +0 -100
  320. synth_ai/tracing/__init__.py +0 -30
  321. synth_ai/tracing_v1/__init__.py +0 -33
  322. synth_ai/tracing_v3/turso/__init__.py +0 -25
  323. synth_ai/tracing_v3/turso/manager.py +0 -774
  324. synth_ai/zyk/__init__.py +0 -30
  325. synth_ai-0.2.9.dev4.dist-info/METADATA +0 -131
  326. /synth_ai/{lm → v0/lm}/caching/__init__.py +0 -0
  327. /synth_ai/{lm → v0/lm}/caching/constants.py +0 -0
  328. /synth_ai/{lm → v0/lm}/caching/dbs.py +0 -0
  329. /synth_ai/{lm → v0/lm}/constants.py +0 -0
  330. /synth_ai/{lm → v0/lm}/core/__init__.py +0 -0
  331. /synth_ai/{lm → v0/lm}/cost/__init__.py +0 -0
  332. /synth_ai/{lm → v0/lm}/cost/monitor.py +0 -0
  333. /synth_ai/{lm → v0/lm}/cost/statefulness.py +0 -0
  334. /synth_ai/{lm → v0/lm}/injection.py +0 -0
  335. /synth_ai/{lm → v0/lm}/provider_support/__init__.py +0 -0
  336. /synth_ai/{lm → v0/lm}/provider_support/suppress_logging.py +0 -0
  337. /synth_ai/{lm → v0/lm}/structured_outputs/__init__.py +0 -0
  338. /synth_ai/{lm → v0/lm}/structured_outputs/inject.py +0 -0
  339. /synth_ai/{lm → v0/lm}/tools/__init__.py +0 -0
  340. /synth_ai/{lm → v0/lm}/tools/base.py +0 -0
  341. /synth_ai/{lm → v0/lm}/unified_interface.py +0 -0
  342. /synth_ai/{lm → v0/lm}/vendors/__init__.py +0 -0
  343. /synth_ai/{lm → v0/lm}/vendors/base.py +0 -0
  344. /synth_ai/{lm → v0/lm}/vendors/core/__init__.py +0 -0
  345. /synth_ai/{lm → v0/lm}/vendors/core/synth_dev_api.py +0 -0
  346. /synth_ai/{lm → v0/lm}/vendors/local/__init__.py +0 -0
  347. /synth_ai/{lm → v0/lm}/vendors/local/ollama.py +0 -0
  348. /synth_ai/{lm → v0/lm}/vendors/retries.py +0 -0
  349. /synth_ai/{lm → v0/lm}/vendors/supported/__init__.py +0 -0
  350. /synth_ai/{lm → v0/lm}/warmup.py +0 -0
  351. {synth_ai-0.2.9.dev4.dist-info → synth_ai-0.2.9.dev6.dist-info}/WHEEL +0 -0
  352. {synth_ai-0.2.9.dev4.dist-info → synth_ai-0.2.9.dev6.dist-info}/entry_points.txt +0 -0
  353. {synth_ai-0.2.9.dev4.dist-info → synth_ai-0.2.9.dev6.dist-info}/licenses/LICENSE +0 -0
examples/swe/task_app/hosted/rollout.py
@@ -0,0 +1,1869 @@
1
+ from __future__ import annotations
2
+
3
+ import contextlib
4
+ import json
5
+ import logging
6
+ import os
7
+ import time as _time
8
+ from datetime import datetime
9
+ from typing import Any
10
+
11
+ from fastapi import APIRouter, HTTPException, Request, status
12
+ from pydantic import BaseModel
13
+ from synth_ai.lm.vendors.base import BaseLMResponse
14
+ from synth_ai.task.tracing_utils import unique_sft_path
15
+ from synth_ai.tracing_v3.abstractions import EnvironmentEvent, LMCAISEvent, TimeRecord
16
+ from synth_ai.tracing_v3.llm_call_record_helpers import create_llm_call_record_from_response
17
+ from synth_ai.tracing_v3.session_tracer import SessionTracer
18
+
19
+ from .registry import registry
20
+
21
+ logger = logging.getLogger(__name__)
22
+
23
+
24
+ # --- Seeding utilities (robust, optional deps) ---
25
+ def _set_global_seed(seed_value: int) -> dict[str, Any]:
26
+ """Set global RNG seeds across common libraries; return details for logging/restoration.
27
+
28
+ Returns a dict containing which libraries were seeded and prior states if obtainable.
29
+ """
30
+ seeded: dict[str, Any] = {"seed": int(seed_value), "libs": []}
31
+ with contextlib.suppress(Exception):
32
+ import random as _random # type: ignore
33
+
34
+ _random.seed(seed_value)
35
+ seeded["libs"].append("random")
36
+ with contextlib.suppress(Exception):
37
+ import numpy as _np # type: ignore
38
+
39
+ _np.random.seed(seed_value)
40
+ seeded["libs"].append("numpy")
41
+ with contextlib.suppress(Exception):
42
+ import torch as _torch # type: ignore
43
+
44
+ if hasattr(_torch, "manual_seed"):
45
+ _torch.manual_seed(seed_value)
46
+ seeded["libs"].append("torch")
47
+ # Make CUDA deterministic if present (best-effort)
48
+ with contextlib.suppress(Exception):
49
+ if getattr(_torch, "cuda", None) and _torch.cuda.is_available():
50
+ _torch.cuda.manual_seed_all(seed_value)
51
+ seeded.setdefault("cuda", True)
52
+ # CUDNN deterministic flags (optional)
53
+ with contextlib.suppress(Exception):
54
+ if getattr(_torch, "backends", None) and getattr(_torch.backends, "cudnn", None):
55
+ _torch.backends.cudnn.deterministic = True # type: ignore[attr-defined]
56
+ _torch.backends.cudnn.benchmark = False # type: ignore[attr-defined]
57
+ return seeded
58
+
59
+
60
+ def _clear_seed_side_effects() -> None:
61
+ """Best-effort cleanup to avoid global deterministic side-effects between requests."""
62
+ # We cannot truly restore prior RNG states without capturing them; we just avoid
63
+ # leaving aggressive deterministic flags enabled where it matters.
64
+ with contextlib.suppress(Exception):
65
+ import torch as _torch # type: ignore
66
+
67
+ with contextlib.suppress(Exception):
68
+ if getattr(_torch, "backends", None) and getattr(_torch.backends, "cudnn", None):
69
+ # Re-enable cudnn.benchmark default True only if it was True; safest is False -> leave as is.
70
+ # We'll keep deterministic False to avoid global impact; benchmark left False for stability.
71
+ _torch.backends.cudnn.deterministic = False # type: ignore[attr-defined]
72
+
73
+
74
+ router = APIRouter()
75
+
76
+
77
+ class RolloutEnvSpec(BaseModel):
78
+ env_id: str | None = None
79
+ env_name: str | None = None
80
+ config: dict[str, Any] = {}
81
+ seed: int | None = None
82
+
83
+
84
+ class RolloutPolicySpec(BaseModel):
85
+ policy_id: str | None = None
86
+ policy_name: str | None = None
87
+ config: dict[str, Any] = {}
88
+
89
+
90
+ class RolloutBranchConfig(BaseModel):
91
+ branch_every_n_steps: int = 0
92
+ branch_on_condition: str | None = None
93
+ max_branches: int = 0
94
+ branch_policy: bool = False
95
+ branch_env: bool = False
96
+
97
+
98
+ class RolloutRecordConfig(BaseModel):
99
+ trajectories: bool = True
100
+ logprobs: bool = False
101
+ value: bool = False
102
+ return_trace: bool = False
103
+ trace_format: str = "compact"
104
+
105
+
106
+ class RolloutSafetyConfig(BaseModel):
107
+ max_ops: int = 100000
108
+ max_time_s: float = 3600.0
109
+
110
+
111
+ class RolloutRequest(BaseModel):
112
+ run_id: str
113
+ env: RolloutEnvSpec
114
+ policy: RolloutPolicySpec
115
+ ops: list[str] # ["agent", "env", ...]
116
+ record: RolloutRecordConfig = RolloutRecordConfig()
117
+ on_done: str = "reset" # "reset" | "terminate"
118
+ branch: RolloutBranchConfig | None = None
119
+ safety: RolloutSafetyConfig = RolloutSafetyConfig()
120
+ # Optional run/session context
121
+ training_session_id: str | None = None
122
+ synth_base_url: str | None = None
123
+
124
+
125
+ class RolloutStep(BaseModel):
126
+ obs: dict[str, Any]
127
+ tool_calls: list[dict[str, Any]]
128
+ reward: float | None = None
129
+ done: bool = False
130
+ truncated: bool | None = None
131
+ logprob: float | None = None
132
+ value: float | None = None
133
+ info: dict[str, Any] | None = None
134
+
135
+
136
+ class RolloutTrajectory(BaseModel):
137
+ env_id: str
138
+ policy_id: str
139
+ steps: list[RolloutStep]
140
+ final: dict[str, Any] | None = None
141
+ length: int
142
+ decision_samples: list[dict[str, Any]] | None = None
143
+
144
+
145
+ def compute_stepwise_reward(
146
+ prev_achievements: dict[str, bool],
147
+ new_achievements: dict[str, bool],
148
+ decision_index: int,
149
+ actions_summary: list[dict[str, Any]],
150
+ indicator_lambda: float,
151
+ ) -> tuple[dict[str, Any], dict[str, Any], dict[str, float]]:
152
+ """Compute stepwise reward metadata given achievement states before/after a decision."""
153
+
154
+ prev_map = prev_achievements or {}
155
+ next_map = new_achievements or {}
156
+
157
+ unlocked = [name for name, value in next_map.items() if value and not prev_map.get(name, False)]
158
+ indicator = 1 if unlocked else 0
159
+ reward_value = float(indicator_lambda) * indicator
160
+
161
+ stepwise_info = {
162
+ "decision_index": decision_index,
163
+ "indicator": indicator,
164
+ "new_achievements": unlocked,
165
+ "reward": reward_value,
166
+ }
167
+ decision_sample = {
168
+ "decision_index": decision_index,
169
+ "indicator": indicator,
170
+ "r_i": reward_value,
171
+ "actions": actions_summary,
172
+ }
173
+ stats = {
174
+ "indicator": float(indicator),
175
+ "reward": reward_value,
176
+ "new_achievements_count": float(len(unlocked)),
177
+ }
178
+ return stepwise_info, decision_sample, stats
179
+
180
+
181
+ class RolloutMetrics(BaseModel):
182
+ episode_returns: list[float]
183
+ mean_return: float
184
+ num_steps: int
185
+ num_episodes: int = 0
186
+
187
+
188
+ class RolloutResponse(BaseModel):
189
+ run_id: str
190
+ trajectories: list[RolloutTrajectory]
191
+ branches: dict[str, list[str]] = {}
192
+ metrics: RolloutMetrics
193
+ aborted: bool = False
194
+ ops_executed: int = 0
195
+ trace: dict[str, Any] | None = None
196
+
197
+
198
+ class RolloutTracingContext:
199
+ """Helper managing tracing_v3 recording and optional SFT dumps for a rollout."""
200
+
201
+ def __init__(
202
+ self,
203
+ tracer: SessionTracer | None,
204
+ request: RolloutRequest,
205
+ fastapi_request: Request,
206
+ ) -> None:
207
+ self.tracer = tracer
208
+ self.enabled = tracer is not None
209
+ self.request = request
210
+ self.fastapi_request = fastapi_request
211
+ self.run_id = request.run_id
212
+ self.current_step_id: str | None = None
213
+ self.current_turn: int | None = None
214
+ self.lm_calls_summary: list[dict[str, Any]] = []
215
+ self.decision_rewards: list[dict[str, Any]] = []
216
+ self.sft_records: list[dict[str, Any]] = []
217
+ self.latest_system_messages: list[str] = []
218
+ self.latest_user_messages: list[str] = []
219
+ self.latest_system_prompt_content: list[Any] = []
220
+ self.latest_user_prompt_content: list[Any] = []
221
+ self.trace_format = (
222
+ getattr(request.record, "trace_format", "compact") or "compact"
223
+ ).lower()
224
+ self.return_trace = bool(getattr(request.record, "return_trace", False))
225
+ self.sft_output_dir = getattr(fastapi_request.app.state, "sft_output_dir", None)
226
+ self.session_trace = None
227
+ self.metadata_updates: dict[str, Any] = {}
228
+ self.policy_name = request.policy.policy_name or ""
229
+ self.env_name = request.env.env_name or ""
230
+ self.metadata_base: dict[str, Any] = {
231
+ "run_id": self.run_id,
232
+ "policy_name": self.policy_name,
233
+ "policy_id": request.policy.policy_id,
234
+ "env_name": self.env_name,
235
+ "env_id": request.env.env_id,
236
+ "seed": request.env.seed,
237
+ "training_session_id": request.training_session_id,
238
+ "synth_base_url": request.synth_base_url,
239
+ }
240
+
241
+ # Expose context for downstream calls inside this request lifecycle
242
+ fastapi_request.state.rollout_tracing = self
243
+ fastapi_request.state.rollout_run_id = self.run_id
244
+
245
+ async def start_session(self) -> None:
246
+ if not self.enabled or self.tracer is None:
247
+ return
248
+ try:
249
+ await self.tracer.initialize()
250
+ except Exception as exc:
251
+ logger.debug("TRACING_INIT_FAIL: %s", exc)
252
+ try:
253
+ await self.tracer.start_session(
254
+ session_id=self.run_id, metadata=dict(self.metadata_base)
255
+ )
256
+ except Exception as exc:
257
+ logger.warning("TRACING_START_FAIL: %s", exc)
258
+ self.enabled = False
259
+ self.tracer = None
260
+
261
+ async def start_decision(self, turn_number: int) -> None:
262
+ self.current_turn = turn_number
263
+ self.current_step_id = f"decision_{turn_number}"
264
+ if not self.enabled or self.tracer is None:
265
+ return
266
+ try:
267
+ await self.tracer.start_timestep(step_id=self.current_step_id, turn_number=turn_number)
268
+ except Exception as exc:
269
+ logger.debug("TRACING_STEP_START_FAIL: %s", exc)
270
+
271
+ async def end_decision(self) -> None:
272
+ if not self.enabled or self.tracer is None:
273
+ return
274
+ try:
275
+ await self.tracer.end_timestep(step_id=self.current_step_id)
276
+ except Exception as exc:
277
+ logger.debug("TRACING_STEP_END_FAIL: %s", exc)
278
+ finally:
279
+ self.current_step_id = None
280
+
281
+ def _message_metadata(self) -> dict[str, Any]:
282
+ return {
283
+ "turn": self.current_turn,
284
+ "step_id": self.current_step_id,
285
+ }
286
+
287
+ async def record_policy_prompts(
288
+ self,
289
+ system_messages: list[Any],
290
+ user_messages: list[Any],
291
+ ) -> None:
292
+ self.latest_system_messages = [self._prompt_text(entry) for entry in system_messages]
293
+ self.latest_user_messages = [self._prompt_text(entry) for entry in user_messages]
294
+ self.latest_system_prompt_content = [
295
+ self._prompt_content(entry, role="system") for entry in system_messages
296
+ ]
297
+ self.latest_user_prompt_content = [
298
+ self._prompt_content(entry, role="user") for entry in user_messages
299
+ ]
300
+ if not self.enabled or self.tracer is None:
301
+ return
302
+ for entry in system_messages:
303
+ try:
304
+ await self.tracer.record_message(
305
+ content=self._prompt_payload(entry, role="system"),
306
+ message_type="policy_system_prompt",
307
+ metadata=self._message_metadata(),
308
+ )
309
+ except Exception as exc:
310
+ logger.debug("TRACING_SYSTEM_MSG_FAIL: %s", exc)
311
+ for entry in user_messages:
312
+ try:
313
+ await self.tracer.record_message(
314
+ content=self._prompt_payload(entry, role="user"),
315
+ message_type="policy_user_prompt",
316
+ metadata=self._message_metadata(),
317
+ )
318
+ except Exception as exc:
319
+ logger.debug("TRACING_USER_MSG_FAIL: %s", exc)
320
+
321
+ def _content_to_text(self, content: Any) -> str:
322
+ if isinstance(content, str):
323
+ return content
324
+ if isinstance(content, list):
325
+ parts: list[str] = []
326
+ for seg in content:
327
+ if isinstance(seg, dict):
328
+ text_val = seg.get("text") or seg.get("content")
329
+ if isinstance(text_val, str):
330
+ parts.append(text_val)
331
+ return "".join(parts)
332
+ if content is None:
333
+ return ""
334
+ return str(content)
335
+
336
+ def _prompt_text(self, entry: Any) -> str:
337
+ if isinstance(entry, dict):
338
+ text = entry.get("text")
339
+ if isinstance(text, str):
340
+ return text
341
+ content = entry.get("content")
342
+ return self._content_to_text(content)
343
+ return self._content_to_text(entry)
344
+
345
+ def _prompt_payload(self, entry: Any, *, role: str) -> dict[str, Any]:
346
+ if isinstance(entry, dict):
347
+ payload = dict(entry)
348
+ payload.setdefault("role", role)
349
+ return payload
350
+ return {
351
+ "role": role,
352
+ "text": self._prompt_text(entry),
353
+ "content": entry,
354
+ }
355
+
356
+ def _prompt_content(self, entry: Any, *, role: str) -> Any:
357
+ payload = self._prompt_payload(entry, role=role)
358
+ return payload.get("content", payload.get("text"))
359
+
360
+ def _content_has_image(self, content: Any) -> bool:
361
+ if isinstance(content, list):
362
+ return any(
363
+ isinstance(seg, dict)
364
+ and seg.get("type") in {"image", "image_url"}
365
+ for seg in content
366
+ )
367
+ if isinstance(content, dict):
368
+ if content.get("type") in {"image", "image_url"}:
369
+ return True
370
+ inner = content.get("content")
371
+ if isinstance(inner, list):
372
+ return any(
373
+ isinstance(seg, dict)
374
+ and seg.get("type") in {"image", "image_url"}
375
+ for seg in inner
376
+ )
377
+ return False
378
+
379
+ def _safe_json(self, payload: Any, limit: int = 4000) -> str:
380
+ try:
381
+ text = json.dumps(payload, ensure_ascii=False)
382
+ except Exception:
383
+ text = str(payload)
384
+ if len(text) > limit:
385
+ return text[:limit] + "…"
386
+ return text
387
+
388
+ async def record_tool_invocation(self, tool_calls: list[dict[str, Any]] | None) -> None:
389
+ if tool_calls is None:
390
+ return
391
+ if self.enabled and self.tracer is not None:
392
+ try:
393
+ await self.tracer.record_message(
394
+ content=self._safe_json(tool_calls),
395
+ message_type="policy_tool_call",
396
+ metadata=self._message_metadata(),
397
+ )
398
+ except Exception as exc:
399
+ logger.debug("TRACING_TOOL_MSG_FAIL: %s", exc)
400
+
401
+ async def _record_event(self, event: Any) -> int | None:
402
+ if not self.enabled or self.tracer is None:
403
+ return None
404
+ try:
405
+ return await self.tracer.record_event(event)
406
+ except Exception as exc:
407
+ logger.debug("TRACING_EVENT_FAIL: %s", exc)
408
+ return None
409
+
410
+ async def record_llm_call(
411
+ self,
412
+ *,
413
+ inference_request: dict[str, Any],
414
+ inference_response: dict[str, Any],
415
+ tool_calls: list[dict[str, Any]] | None,
416
+ provider: str,
417
+ model_name: str,
418
+ started_at: datetime,
419
+ completed_at: datetime,
420
+ latency_ms: int | None,
421
+ ) -> None:
422
+ usage = inference_response.get("usage") or {}
423
+ input_tokens = usage.get("input_tokens") or usage.get("prompt_tokens")
424
+ output_tokens = usage.get("output_tokens") or usage.get("completion_tokens")
425
+ total_tokens = usage.get("total_tokens")
426
+ cost_usd = usage.get("cost_usd") or usage.get("cost") or usage.get("total_cost")
427
+
428
+ assistant_message = None
429
+ choices = inference_response.get("choices") or []
430
+ if choices:
431
+ assistant_message = choices[0].get("message") or {}
432
+ assistant_content = (
433
+ assistant_message.get("content") if isinstance(assistant_message, dict) else None
434
+ )
435
+
436
+ raw_response = self._content_to_text(assistant_content)
437
+ if not raw_response:
438
+ raw_response = self._safe_json(inference_response, limit=2000)
439
+
440
+ base_response = BaseLMResponse(
441
+ raw_response=raw_response,
442
+ tool_calls=assistant_message.get("tool_calls")
443
+ if isinstance(assistant_message, dict)
444
+ else None,
445
+ usage=usage or None,
446
+ api_type="chat_completions",
447
+ )
448
+
449
+ request_messages = inference_request.get("messages") or []
450
+ try:
451
+ temperature = float(inference_request.get("temperature"))
452
+ except Exception:
453
+ temperature = 0.0
454
+
455
+ call_record = create_llm_call_record_from_response(
456
+ response=base_response,
457
+ model_name=model_name,
458
+ provider=provider,
459
+ messages=request_messages,
460
+ temperature=temperature,
461
+ request_params=inference_request,
462
+ tools=inference_request.get("tools"),
463
+ started_at=started_at,
464
+ completed_at=completed_at,
465
+ latency_ms=latency_ms,
466
+ )
467
+
468
+ event_metadata = {
469
+ "policy_id": self.request.policy.policy_id,
470
+ "turn": self.current_turn,
471
+ "run_id": self.run_id,
472
+ }
473
+
474
+ event = LMCAISEvent(
475
+ system_instance_id=f"policy:{self.policy_name or 'unknown'}",
476
+ time_record=TimeRecord(event_time=completed_at.timestamp()),
477
+ model_name=model_name,
478
+ provider=provider,
479
+ input_tokens=input_tokens,
480
+ output_tokens=output_tokens,
481
+ total_tokens=total_tokens,
482
+ cost_usd=cost_usd,
483
+ latency_ms=latency_ms,
484
+ call_records=[call_record],
485
+ metadata=event_metadata,
486
+ )
487
+
488
+ await self._record_event(event)
489
+
490
+ self.lm_calls_summary.append(
491
+ {
492
+ "turn": self.current_turn,
493
+ "model": model_name,
494
+ "provider": provider,
495
+ "total_tokens": total_tokens,
496
+ "input_tokens": input_tokens,
497
+ "output_tokens": output_tokens,
498
+ "latency_ms": latency_ms,
499
+ "tool_calls": len(tool_calls or []),
500
+ }
501
+ )
502
+
503
+ if self.sft_output_dir is not None:
504
+ assistant_structured = assistant_content if assistant_content is not None else ""
505
+ assistant_text = self._content_to_text(assistant_content)
506
+ dialogue_structured: list[dict[str, Any]] = []
507
+ for content in self.latest_system_prompt_content:
508
+ if content is None:
509
+ continue
510
+ dialogue_structured.append({"role": "system", "content": content})
511
+ for content in self.latest_user_prompt_content:
512
+ if content is None:
513
+ continue
514
+ dialogue_structured.append({"role": "user", "content": content})
515
+ dialogue_text = (
516
+ [{"role": "system", "content": s} for s in self.latest_system_messages]
517
+ + [{"role": "user", "content": u} for u in self.latest_user_messages]
518
+ )
519
+ user_has_image = any(
520
+ self._content_has_image(content) for content in self.latest_user_prompt_content
521
+ )
522
+ assistant_has_image = self._content_has_image(assistant_structured)
523
+ record = {
524
+ "run_id": self.run_id,
525
+ "turn": self.current_turn,
526
+ "model": model_name,
527
+ "provider": provider,
528
+ "dialogue": dialogue_structured,
529
+ "dialogue_text": dialogue_text,
530
+ "assistant": {
531
+ "content": assistant_structured,
532
+ "content_text": assistant_text,
533
+ "tool_calls": assistant_message.get("tool_calls")
534
+ if isinstance(assistant_message, dict)
535
+ else [],
536
+ "has_image": assistant_has_image,
537
+ },
538
+ "metadata": {
539
+ "user_has_image": user_has_image,
540
+ "assistant_has_image": assistant_has_image,
541
+ "has_image": user_has_image or assistant_has_image,
542
+ },
543
+ "timestamp": datetime.utcnow().isoformat(),
544
+ }
545
+ self.sft_records.append(record)
546
+
547
+ async def record_environment_event(
548
+ self,
549
+ *,
550
+ env_handle: Any,
551
+ prev_obs: dict[str, Any] | None,
552
+ env_response: Any,
553
+ next_obs: dict[str, Any] | None,
554
+ metadata: dict[str, Any] | None = None,
555
+ ) -> int | None:
556
+ if not self.enabled or self.tracer is None:
557
+ return None
558
+
559
+ try:
560
+ prev_summary = (
561
+ _summarize_observation_for_storage(env_handle, prev_obs or {})
562
+ if prev_obs is not None
563
+ else None
564
+ )
565
+ except Exception:
566
+ prev_summary = None
567
+ try:
568
+ next_summary = (
569
+ _summarize_observation_for_storage(env_handle, next_obs or {})
570
+ if next_obs is not None
571
+ else None
572
+ )
573
+ except Exception:
574
+ next_summary = None
575
+
576
+ reward_val = getattr(env_response, "reward", None)
577
+ try:
578
+ reward_float = float(reward_val) if reward_val is not None else 0.0
579
+ except Exception:
580
+ reward_float = 0.0
581
+
582
+ event = EnvironmentEvent(
583
+ system_instance_id=f"environment:{self.env_name or 'unknown'}",
584
+ time_record=TimeRecord(event_time=datetime.utcnow().timestamp()),
585
+ reward=reward_float,
586
+ terminated=bool(getattr(env_response, "done", False)),
587
+ truncated=bool(getattr(env_response, "truncated", False)),
588
+ system_state_before=prev_summary,
589
+ system_state_after=next_summary,
590
+ metadata={
591
+ "turn": self.current_turn,
592
+ "run_id": self.run_id,
593
+ **(metadata or {}),
594
+ },
595
+ )
596
+
597
+ return await self._record_event(event)
598
+
599
+ async def record_decision_reward(
600
+ self,
601
+ *,
602
+ event_id: int | None,
603
+ decision_meta: dict[str, Any] | None,
604
+ ) -> None:
605
+ decision_meta = decision_meta or {}
606
+ ach_delta = int(decision_meta.get("ach_delta", 0))
607
+ unique_delta = int(decision_meta.get("unique_delta", 0))
608
+ all_ach = list(decision_meta.get("all") or [])
609
+ unique_ach = list(decision_meta.get("unique") or [])
610
+
611
+ self.decision_rewards.append(
612
+ {
613
+ "turn": self.current_turn,
614
+ "ach_delta": ach_delta,
615
+ "unique_delta": unique_delta,
616
+ "achievements": all_ach,
617
+ "unique_achievements": unique_ach,
618
+ }
619
+ )
620
+
621
+ if not self.enabled or self.tracer is None or event_id is None:
622
+ return
623
+ try:
624
+ await self.tracer.record_event_reward(
625
+ event_id=event_id,
626
+ turn_number=self.current_turn,
627
+ reward_value=float(ach_delta),
628
+ reward_type="achievement_delta",
629
+ annotation={"achievements": all_ach},
630
+ source="environment",
631
+ )
632
+ if unique_delta:
633
+ await self.tracer.record_event_reward(
634
+ event_id=event_id,
635
+ turn_number=self.current_turn,
636
+ reward_value=float(unique_delta),
637
+ reward_type="unique_achievement_delta",
638
+ annotation={"achievements": unique_ach},
639
+ source="environment",
640
+ )
641
+ except Exception as exc:
642
+ logger.debug("TRACING_REWARD_FAIL: %s", exc)
643
+
644
+ def update_metadata(self, **kwargs: Any) -> None:
645
+ self.metadata_updates.update({k: v for k, v in kwargs.items() if v is not None})
646
+
647
+ async def finalize(
648
+ self,
649
+ *,
650
+ total_reward: float,
651
+ achievement_state: dict[str, bool] | None,
652
+ total_steps: int,
653
+ ) -> Any:
654
+ final_achievements = [key for key, val in (achievement_state or {}).items() if val]
655
+ self.metadata_updates.setdefault("final_achievements", final_achievements)
656
+ if self.enabled and self.tracer is not None:
657
+ try:
658
+ await self.tracer.record_outcome_reward(
659
+ total_reward=int(total_reward),
660
+ achievements_count=len(final_achievements),
661
+ total_steps=int(total_steps),
662
+ reward_metadata=dict(self.metadata_updates),
663
+ )
664
+ except Exception as exc:
665
+ logger.debug("TRACING_OUTCOME_FAIL: %s", exc)
666
+ try:
667
+ self.session_trace = await self.tracer.end_session()
668
+ if self.session_trace is not None:
669
+ self.session_trace.metadata.update(self.metadata_updates)
670
+ except Exception as exc:
671
+ logger.debug("TRACING_END_SESSION_FAIL: %s", exc)
672
+ self.session_trace = None
673
+ with contextlib.suppress(Exception):
674
+ await self.tracer.close()
675
+
676
+ if self.sft_records and self.sft_output_dir:
677
+ self.write_sft_records()
678
+
679
+ # Clear context from request state to avoid leaks
680
+ self.fastapi_request.state.rollout_tracing = None
681
+
682
+ return self.session_trace
683
+
684
+ def write_sft_records(self) -> None:
685
+ if not self.sft_output_dir or not self.sft_records:
686
+ return
687
+ try:
688
+ path = unique_sft_path(self.sft_output_dir, run_id=self.run_id)
689
+ path.parent.mkdir(parents=True, exist_ok=True)
690
+ with path.open("w", encoding="utf-8") as fh:
691
+ for record in self.sft_records:
692
+ json.dump(record, fh, ensure_ascii=False)
693
+ fh.write("\n")
694
+ logger.info(f"SFT_WRITTEN: {path}")
695
+ except Exception as exc:
696
+ logger.warning(f"SFT_WRITE_FAIL: {exc}")
697
+ finally:
698
+ self.sft_records.clear()
699
+
700
+ def build_trace_payload(self, session_trace: Any) -> dict[str, Any] | None:
701
+ if not self.return_trace or session_trace is None:
702
+ return None
703
+ if self.trace_format == "full":
704
+ payload = session_trace.to_dict()
705
+ payload.setdefault("metadata", {}).update(self.metadata_updates)
706
+ return payload
707
+ metadata = dict(session_trace.metadata)
708
+ metadata.update(self.metadata_updates)
709
+ return {
710
+ "session_id": session_trace.session_id,
711
+ "created_at": session_trace.created_at.isoformat(),
712
+ "metadata": metadata,
713
+ "events_count": len(session_trace.event_history),
714
+ "messages_count": len(session_trace.markov_blanket_message_history),
715
+ "lm_calls": self.lm_calls_summary,
716
+ "decision_rewards": self.decision_rewards,
717
+ }
718
+
719
+
720
+ def _summarize_observation_for_storage(
721
+ env_handle: Any, observation: dict[str, Any]
722
+ ) -> dict[str, Any]:
723
+ """Return a compact dict for trajectory storage instead of the raw observation.
724
+
725
+ - For Crafter, use the same summary used for the policy user prompt
726
+ - For others, keep a minimal subset or plain text preview
727
+ """
728
+ # Try Crafter-specific formatter
729
+ crafter_wrapper = None
730
+ with contextlib.suppress(Exception):
731
+ from .envs.crafter.environment import (
732
+ CrafterEnvironmentWrapper as _CrafterWrapper, # type: ignore
733
+ )
734
+
735
+ crafter_wrapper = _CrafterWrapper # type: ignore[assignment]
736
+
737
+ if crafter_wrapper is not None and isinstance(
738
+ getattr(env_handle, "env", None), crafter_wrapper
739
+ ):
740
+ with contextlib.suppress(Exception):
741
+ from .envs.crafter.shared import format_observation as _fmt # type: ignore
742
+
743
+ text = _fmt(observation or {})
744
+ return {"text": text}
745
+
746
+ # Generic fallback: extract a few small fields if present; avoid huge arrays
747
+ with contextlib.suppress(Exception):
748
+ inv = observation.get("inventory") if isinstance(observation, dict) else None
749
+ ach = observation.get("achievements_status") if isinstance(observation, dict) else None
750
+ pos = observation.get("player_position") if isinstance(observation, dict) else None
751
+ health = None
752
+ if isinstance(inv, dict):
753
+ health = inv.get("health")
754
+ summary = {
755
+ "position": pos,
756
+ "health": health,
757
+ "inventory_keys": sorted(k for k, v in (inv or {}).items() if v)[:10]
758
+ if isinstance(inv, dict)
759
+ else None,
760
+ "achievements_unlocked": sorted(k for k, v in (ach or {}).items() if v)[:10]
761
+ if isinstance(ach, dict)
762
+ else None,
763
+ }
764
+ return {"text": json.dumps(summary, ensure_ascii=False)}
765
+
766
+ # Last resort: plain string preview
767
+ try:
768
+ return {"text": str(observation)[:10000]}
769
+ except Exception:
770
+ return {"text": ""}
771
+
772
+
773
+ class RunAbortRequest(BaseModel):
774
+ run_id: str
775
+
776
+
777
+ class RunAbortResponse(BaseModel):
778
+ ok: bool
779
+ run_id: str
780
+
781
+
782
+ class RunStatusResponse(BaseModel):
783
+ run_id: str
784
+ status: str
785
+ started_at: datetime
786
+ finished_at: datetime | None = None
787
+
788
+
789
+ @router.post("/rollout", response_model=RolloutResponse)
790
+ async def execute_rollout(
791
+ request: RolloutRequest,
792
+ req: Request,
793
+ ) -> RolloutResponse:
794
+ """Execute a rollout with coordinated environment and policy steps."""
795
+ # Emit rollout identifier early for correlation
796
+ with contextlib.suppress(Exception):
797
+ _rid = getattr(request, "run_id", None)
798
+ _pol = getattr(request.policy, "policy_name", None) or getattr(request.policy, "policy_id", None)
799
+ _env = getattr(request.env, "env_name", None) or getattr(request.env, "env_id", None)
800
+ logger.info("ROLLOUT_BEGIN: run_id=%s policy=%s env=%s", _rid, _pol, _env)
801
+ print(f"[rollout] begin run_id={_rid} policy={_pol} env={_env}", flush=True)
802
+ # Enforce per-episode step cap via env-specific parameters; default to 20 if omitted
803
+ try:
804
+ _env_params = {}
805
+ if isinstance(request.env, RolloutEnvSpec) and isinstance(request.env.config, dict):
806
+ _env_params = dict(request.env.config.get("env_params") or {})
807
+ max_steps_per_episode = int(_env_params.get("max_steps_per_episode") or 20)
808
+ assert max_steps_per_episode > 0, "max_steps_per_episode must be a positive integer"
809
+ except Exception as _mse:
810
+ raise HTTPException(
811
+ status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
812
+ detail={
813
+ "error": "invalid_env_params",
814
+ "message": f"Invalid or missing env_params.max_steps_per_episode: {_mse}",
815
+ },
816
+ ) from _mse
817
+ # Truncate incoming ops to the enforced cap (each step is [agent, env])
818
+ ops_seq: list[str] = list(request.ops or [])
819
+ allowed_ops = max(0, int(max_steps_per_episode) * 2)
820
+ if len(ops_seq) > allowed_ops:
821
+ with contextlib.suppress(Exception):
822
+ logger.info(
823
+ "ROLL_OUT: truncating ops to cap: requested_ops=%s allowed_ops=%s",
824
+ str(len(ops_seq)),
825
+ str(allowed_ops),
826
+ )
827
+ ops_seq = ops_seq[:allowed_ops]
828
+ # Simple API key auth for inbound rollout
829
+ header_key = req.headers.get("x-api-key")
830
+ env_key = os.getenv("ENVIRONMENT_API_KEY")
831
+ dev_key = os.getenv("DEV_ENVIRONMENT_API_KEY")
832
+ # Accept either ENVIRONMENT_API_KEY or DEV_ENVIRONMENT_API_KEY
833
+ expected_keys = [k for k in (env_key, dev_key) if k]
834
+ if not expected_keys:
835
+ missing = []
836
+ if not env_key:
837
+ missing.append("ENVIRONMENT_API_KEY")
838
+ if not dev_key:
839
+ missing.append("DEV_ENVIRONMENT_API_KEY")
840
+ msg = f"Auth not configured: missing {', '.join(missing)} in task service environment"
841
+ logger.error(msg)
842
+ raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=msg)
843
+ if not header_key:
844
+ raise HTTPException(
845
+ status_code=status.HTTP_401_UNAUTHORIZED,
846
+ detail="Invalid or missing API key: X-API-Key header not provided",
847
+ )
848
+ if header_key not in expected_keys:
849
+ # Do not leak secrets; include short prefix for diagnostics
850
+ exp_src = env_key if env_key else (dev_key or "")
851
+ exp_prefix = (exp_src[:7] + "…") if len(exp_src) >= 7 else "set"
852
+ got_prefix = (header_key[:7] + "…") if len(header_key) >= 7 else "set"
853
+ raise HTTPException(
854
+ status_code=status.HTTP_401_UNAUTHORIZED,
855
+ detail=f"Invalid API key: header does not match expected (got={got_prefix}, expected_prefix={exp_prefix})",
856
+ )
857
+
858
+ # Log contextual fields for traceability
859
+ if request.training_session_id:
860
+ logger.info(f"ROLL_OUT: training_session_id={request.training_session_id}")
861
+ if request.synth_base_url:
862
+ logger.info(f"ROLL_OUT: synth_base_url={request.synth_base_url}")
863
+
864
+ # Log masked OpenAI API key presence for diagnostics
865
+ with contextlib.suppress(Exception):
866
+ _oa = os.getenv("OPENAI_API_KEY")
867
+ if _oa:
868
+ _pref = (_oa[:6] + "…") if len(_oa) >= 6 else "set"
869
+ logger.info(f"ROLL_OUT: OPENAI_API_KEY present (prefix={_pref})")
870
+ else:
871
+ logger.warning("ROLL_OUT: OPENAI_API_KEY missing")
872
+
873
+ # Make synth_base_url available for outbound calls in this app
874
+ with contextlib.suppress(Exception):
875
+ task_app = req.app.state.task_app
876
+ if request.synth_base_url:
877
+ task_app.synth_base_url = request.synth_base_url
878
+
879
+ tracer_factory = getattr(req.app.state, "session_tracer_factory", None)
880
+ tracer_instance: SessionTracer | None = None
881
+ if callable(tracer_factory):
882
+ try:
883
+ inst = tracer_factory()
884
+ tracer_instance = inst if isinstance(inst, SessionTracer) else None
885
+ except Exception as exc:
886
+ logger.debug(f"TRACER_FACTORY_FAIL: {exc}")
887
+ tracing_context = RolloutTracingContext(tracer_instance, request, req)
888
+ await tracing_context.start_session()
889
+ # Print whether tracing is active for this rollout
890
+ try:
891
+ print(
892
+ f"[rollout] tracing enabled={bool(tracing_context.enabled)} run_id={request.run_id}",
893
+ flush=True,
894
+ )
895
+ except Exception:
896
+ pass
897
+
898
+ # Register run
899
+ registry.register_run(request.run_id)
900
+
901
+ # Track resources created during this rollout so we can guarantee cleanup
902
+ created_env_id: str | None = None
903
+ created_policy_id: str | None = None
904
+ env_seed_used: int | None = None
905
+ trajectory_steps: list[RolloutStep] = []
906
+ decision_samples: list[dict[str, Any]] = []
907
+ pending_tool_calls: Any = None
908
+ current_obs: Any = {}
909
+ total_reward: float = 0.0
910
+ ops_executed = 0
911
+ last_agent_response_ts: float | None = None
912
+ last_policy_meta: dict[str, Any] | None = None
913
+ last_env_step_ms: float | None = None
914
+ last_env_step_completed_ts: float | None = None
915
+ decision_open = False
916
+ finalized = False
917
+    prev_achievements: dict[str, bool] = {}
+    session_trace = None
+    step_rewards_active = False
+
+    try:
+        # Initialize deterministic seed early for the entire rollout
+        seed_value: int | None = None
+        try:
+            if request.env and request.env.seed is not None:
+                seed_value = int(request.env.seed)
+            else:
+                # Derive a stable seed from run_id
+                import hashlib as _hashlib  # local import to avoid global deps
+
+                _digest = _hashlib.sha256(request.run_id.encode("utf-8")).hexdigest()
+                # Use lower 32 bits to fit common RNG ranges
+                seed_value = int(_digest[:8], 16)
+        except Exception:
+            # Fallback to time-based seed if anything goes wrong
+            try:
+                seed_value = int((_time.time_ns() // 1_000_000) % (2**31 - 1))
+            except Exception:
+                seed_value = 42
+
+        _seed_info = _set_global_seed(int(seed_value))
+        with contextlib.suppress(Exception):
+            logger.info(
+                "ROLL_OUT: RNG seeded seed=%s libs=%s",
+                str(_seed_info.get("seed")),
+                ",".join(_seed_info.get("libs", [])),
+            )
+        # Resolve or create environment
+        if request.env.env_id:
+            env_handle = registry.get_env(request.env.env_id)
+            if not env_handle:
+                raise HTTPException(
+                    status_code=404,
+                    detail=f"Environment {request.env.env_id} not found",
+                )
+            env_id = request.env.env_id
+        else:
+            # Create new environment
+            from .environment_routes import EnvCreateRequest, create_environment
+
+            if not request.env.env_name:
+                raise ValueError("FATAL: env_name is required - NO FALLBACKS!")
+
+            # Propagate training_session_id via env config for downstream usage
+            _env_config = dict(request.env.config or {})
+            if request.training_session_id is not None:
+                _env_config.setdefault("training_session_id", request.training_session_id)
+            env_response = await create_environment(
+                EnvCreateRequest(
+                    env_name=request.env.env_name,
+                    config=_env_config,
+                    seed=request.env.seed,
+                    rl_run_id=request.run_id,
+                )
+            )
+            env_id = env_response.env_id
+            env_handle = registry.get_env(env_id)
+            created_env_id = env_id
+
+        tracing_context.update_metadata(env_id=env_id)
+
+        # Resolve or create policy
+        if request.policy.policy_id:
+            policy_handle = registry.get_policy(request.policy.policy_id)
+            if not policy_handle:
+                raise HTTPException(
+                    status_code=404,
+                    detail=f"Policy {request.policy.policy_id} not found",
+                )
+            policy_id = request.policy.policy_id
+        else:
+            # Create new policy
+            from .policy_routes import PolicyCreateRequest, create_policy
+
+            if not request.policy.policy_name:
+                raise ValueError("FATAL: policy_name is required - NO FALLBACKS!")
+
+            # Propagate training_session_id and synth_base_url via policy config
+            _policy_config = dict(request.policy.config or {})
+            if request.training_session_id is not None:
+                _policy_config.setdefault("training_session_id", request.training_session_id)
+            if request.synth_base_url is not None:
+                _policy_config.setdefault("synth_base_url", request.synth_base_url)
+            policy_response = await create_policy(
+                PolicyCreateRequest(
+                    policy_name=request.policy.policy_name,
+                    config=_policy_config,
+                    rl_run_id=request.run_id,
+                    bound_env_id=env_id,
+                ),
+                req,
+            )
+            policy_id = policy_response.policy_id
+            policy_handle = registry.get_policy(policy_id)
+            created_policy_id = policy_id
+
+        tracing_context.update_metadata(policy_id=policy_id)
+
+        # Bind policy to environment if not already bound
+        if policy_handle and not policy_handle.bound_env_id:
+            policy_handle.bound_env_id = env_id
+
+        # Record seed bound to environment for end-of-rollout verification/logging
+        try:
+            env_seed_used = int(getattr(env_handle, "seed", 0) or 0)
+        except Exception:
+            env_seed_used = None
+        tracing_context.update_metadata(env_seed=env_seed_used)
+        # Initialize trajectory
+        trajectory_steps = []
+        pending_tool_calls = None
+        current_obs = env_handle.last_observation
+        total_reward = 0.0
+        ops_executed = 0
+        last_agent_response_ts = None
+        last_policy_meta = None
+        last_env_step_ms = None
+        last_env_step_completed_ts = None
+
+        # Stepwise reward configuration (Crafter shaping; gate on explicit enable)
+        step_rewards_cfg_raw: dict[str, Any] = {}
+        try:
+            if isinstance(request.policy.config, dict):
+                step_rewards_cfg_raw = dict(request.policy.config.get("step_rewards") or {})
+        except Exception:
+            step_rewards_cfg_raw = {}
+        if not step_rewards_cfg_raw:
+            try:
+                if isinstance(request.env.config, dict):
+                    step_rewards_cfg_raw = dict(request.env.config.get("step_rewards") or {})
+            except Exception:
+                step_rewards_cfg_raw = {}
+
+        step_rewards_enabled = bool(step_rewards_cfg_raw.get("enabled", False))
+        step_rewards_mode = str(step_rewards_cfg_raw.get("mode") or "off").lower()
+        try:
+            step_rewards_indicator_lambda = float(
+                step_rewards_cfg_raw.get("indicator_lambda") or 0.0
+            )
+        except Exception:
+            step_rewards_indicator_lambda = 0.0
+        try:
+            step_rewards_beta = float(step_rewards_cfg_raw.get("step_beta") or 0.0)
+        except Exception:
+            step_rewards_beta = 0.0
+        step_rewards_active = step_rewards_enabled and step_rewards_mode == "decision_stepwise"
+
+        def _extract_achievements(obs: Any) -> dict[str, bool]:
+            if not isinstance(obs, dict):
+                return {}
+            ach = obs.get("achievements_status")
+            if isinstance(ach, dict):
+                return {str(k): bool(v) for k, v in ach.items()}
+            return {}
+
+        def _summarize_tool_calls(tool_calls: Any) -> list[dict[str, Any]]:
+            if not tool_calls:
+                return []
+            try:
+                items = (
+                    tool_calls
+                    if isinstance(tool_calls, list)
+                    else list(tool_calls)  # tolerates tuples or pydantic lists
+                )
+            except Exception:
+                return []
+            summary: list[dict[str, Any]] = []
+            for tc in items:
+                tool_name = None
+                args: Any = {}
+                if isinstance(tc, dict):
+                    tool_name = tc.get("tool") or tc.get("tool_name") or tc.get("name")
+                    raw_args = tc.get("arguments") or tc.get("args") or {}
+                else:
+                    tool_name = getattr(tc, "tool", None) or getattr(tc, "tool_name", None)
+                    raw_args = getattr(tc, "arguments", None) or getattr(tc, "args", None) or {}
+                args = raw_args
+                if isinstance(raw_args, str):
+                    try:
+                        args = json.loads(raw_args)
+                    except Exception:
+                        args = raw_args
+                summary.append({"tool": tool_name, "args": args})
+            return summary
+
+        decision_samples: list[dict[str, Any]] = []
+        decision_index = 0
+        decision_open = False
+        session_trace = None
+        finalized = False
+        prev_achievements = _extract_achievements(current_obs)
+        # Track episode-level achievements that have been seen as true at any point so far
+        episode_seen_achievements: set[str] = {
+            k for k, v in (prev_achievements or {}).items() if bool(v)
+        }
+        stepwise_indicator_sum = 0.0
+        stepwise_reward_sum = 0.0
+        stepwise_new_achievements_total = 0
+        final_achievement_count = sum(1 for v in prev_achievements.values() if v)
+
+        # Execute ops sequence (capped by env_params.max_steps_per_episode)
+        for op_idx, op in enumerate(ops_seq):
+            # Check for abort
+            if registry.is_run_aborted(request.run_id):
+                logger.info(f"Run {request.run_id} aborted at op {op_idx}")
+                break
+
+            # Check safety limits
+            if ops_executed >= request.safety.max_ops:
+                logger.warning(f"Reached max_ops limit ({request.safety.max_ops})")
+                break
+
+            if op == "agent":
+                # Policy step
+                from .policy_routes import PolicyStepRequest, step_policy
+
+                if not decision_open:
+                    await tracing_context.start_decision(decision_index)
+                    decision_open = True
+
+                agent_request_start = _time.perf_counter()
+                if last_agent_response_ts is not None and last_policy_meta is not None:
+                    with contextlib.suppress(Exception):
+                        timing_prev = last_policy_meta.setdefault("timing", {})
+                        decision_ms = max(
+                            0.0,
+                            (agent_request_start - float(last_agent_response_ts)) * 1000.0,
+                        )
+                        # Update timing on prior policy meta (kept by previous env step)
+                        timing_prev["decision_ms"] = decision_ms
+                        if last_env_step_ms is not None:
+                            timing_prev["env_step_ms"] = float(last_env_step_ms)
+                            timing_prev["overhead_ms"] = max(
+                                0.0, decision_ms - float(last_env_step_ms)
+                            )
+                        else:
+                            timing_prev.setdefault("overhead_ms", 0.0)
+                        timing_prev["decision_ready_s"] = agent_request_start
+                    # Also backfill the last appended trajectory step so the trainer
+                    # can always see decision_ms without relying on shared dict refs.
+                    if trajectory_steps:
+                        with contextlib.suppress(Exception):
+                            _last = trajectory_steps[-1]
+                            _info = dict(_last.info or {})
+                            _meta = dict(_info.get("meta") or {})
+                            _timing = dict(_meta.get("timing") or {})
+                            _timing["decision_ms"] = decision_ms
+                            if last_env_step_ms is not None:
+                                _timing.setdefault("env_step_ms", float(last_env_step_ms))
+                                _timing.setdefault(
+                                    "overhead_ms",
+                                    max(0.0, decision_ms - float(last_env_step_ms)),
+                                )
+                            else:
+                                _timing.setdefault("overhead_ms", 0.0)
+                            _meta["timing"] = _timing
+                            _info["meta"] = _meta
+                            _last.info = _info
+                    last_env_step_ms = None
+                    last_env_step_completed_ts = None
+
+                # Build metadata for policy (carry previous tool_calls and env result)
+                metadata = {}
+                if pending_tool_calls:
+                    metadata["prev_tool_calls"] = pending_tool_calls
+                if len(trajectory_steps) > 0:
+                    last_step = trajectory_steps[-1]
+                    # Prefer the last executed tool calls to seed history
+                    if last_step.tool_calls:
+                        metadata["prev_tool_calls"] = last_step.tool_calls
+                    # Provide a compact env result snapshot
+                    metadata["prev_env_result"] = {
+                        "observation": last_step.obs,
+                        "reward": last_step.reward,
+                        "done": last_step.done,
+                        "truncated": last_step.truncated,
+                        "info": last_step.info,
+                    }
+
+                # Log compact metadata summary to confirm history threading
+                with contextlib.suppress(Exception):
+                    _prev_calls = metadata.get("prev_tool_calls")
+                    _count = len(_prev_calls) if isinstance(_prev_calls, list) else 0
+                    _first_guess = None
+                    if _count > 0 and isinstance(_prev_calls[0], dict):
+                        _args = _prev_calls[0].get("arguments", None)
+                        if isinstance(_args, str):
+                            import json as _json
+
+                            with contextlib.suppress(Exception):
+                                _args = _json.loads(_args)
+                        if not isinstance(_args, dict):
+                            _args = {}
+                        _first_guess = _args.get("guess") or _args.get("word")
+                    logger.info(
+                        "POLICY_METADATA: prev_tool_calls=%d first_guess=%r has_prev_env_result=%s",
+                        _count,
+                        _first_guess,
+                        str("prev_env_result" in metadata),
+                    )
+
+                try:
+                    policy_response = await step_policy(
+                        PolicyStepRequest(
+                            policy_id=policy_id,
+                            observation=current_obs,
+                            metadata=metadata,
+                        ),
+                        req,
+                    )
+                except Exception as _pe:
+                    # Do not 500 the rollout; finalize with partial trajectory
+                    with contextlib.suppress(Exception):
+                        logger.warning(
+                            "POLICY_STEP_FAIL: terminating episode early run_id=%s op_idx=%s err=%s",
+                            request.run_id,
+                            str(op_idx),
+                            str(_pe),
+                        )
+
+                    # Build partial trajectory and return HTTP 200
+                    trajectory = RolloutTrajectory(
+                        env_id=env_id,
+                        policy_id=policy_id,
+                        steps=trajectory_steps,
+                        final={
+                            "observation": current_obs,
+                            "rollout_status": "partial_policy_error",
+                            "error": str(_pe),
+                            "at_op": op,
+                        },
+                        length=len(trajectory_steps),
+                        decision_samples=decision_samples if step_rewards_active else None,
+                    )
+                    metrics = RolloutMetrics(
+                        episode_returns=[total_reward],
+                        mean_return=total_reward,
+                        num_steps=len(trajectory_steps),
+                        num_episodes=1,
+                    )
+                    aborted = registry.is_run_aborted(request.run_id)
+                    if not aborted:
+                        registry.complete_run(request.run_id)
+                    if decision_open:
+                        await tracing_context.end_decision()
+                        decision_open = False
+                    if not finalized:
+                        session_trace = await tracing_context.finalize(
+                            total_reward=total_reward,
+                            achievement_state=prev_achievements,
+                            total_steps=len(trajectory_steps),
+                        )
+                        finalized = True
+                    trace_payload = tracing_context.build_trace_payload(session_trace)
+                    return RolloutResponse(
+                        run_id=request.run_id,
+                        trajectories=[trajectory],
+                        branches={},
+                        metrics=metrics,
+                        aborted=aborted,
+                        ops_executed=ops_executed,
+                        trace=trace_payload,
+                    )
+
+                agent_response_ts = _time.perf_counter()
+                if isinstance(policy_response.meta, dict):
+                    with contextlib.suppress(Exception):
+                        timing_cur = policy_response.meta.setdefault("timing", {})
+                        timing_cur["agent_request_start_s"] = agent_request_start
+                        timing_cur["agent_response_s"] = agent_response_ts
+                        if "inference_ms" in policy_response.meta:
+                            with contextlib.suppress(Exception):
+                                timing_cur.setdefault(
+                                    "inference_ms",
+                                    float(policy_response.meta["inference_ms"]),
+                                )
+                                timing_cur.setdefault(
+                                    "inference_s",
+                                    float(policy_response.meta["inference_ms"]) / 1000.0,
+                                )
+                    last_policy_meta = policy_response.meta
+                else:
+                    last_policy_meta = None
+                last_agent_response_ts = agent_response_ts
+
+                # Diagnostic: summarize policy step target and tool calls
+                try:
+                    model_name = None
+                    target_url = None
+                    if isinstance(policy_response.meta, dict):
+                        req_body = policy_response.meta.get("inference_request") or {}
+                        model_name = req_body.get("model")
+                        target_url = policy_response.meta.get("inference_url")
+                    _tc = policy_response.tool_calls or []
+                    print(
+                        {
+                            "rollout.policy_step": True,
+                            "run_id": request.run_id,
+                            "model": model_name,
+                            "inference_url": target_url,
+                            "tool_calls_count": len(_tc) if isinstance(_tc, list) else 0,
+                        },
+                        flush=True,
+                    )
+                except Exception:
+                    pass
+
+                pending_tool_calls = policy_response.tool_calls
+                # Log summarized agent tool calls
+                with contextlib.suppress(Exception):
+                    _tc = pending_tool_calls or []
+                    _summary = []
+                    for _item in (_tc if isinstance(_tc, list) else []):
+                        try:
+                            if isinstance(_item, dict):
+                                _tool = _item.get("tool")
+                                _args = _item.get("args")
+                                _keys = list(_args.keys()) if isinstance(_args, dict) else []
+                                _summary.append({"tool": _tool, "args_keys": _keys})
+                        except Exception:
+                            continue
+                    _rid = getattr(request, "run_id", None)
+                    logger.info("AGENT_TOOL_CALLS: run_id=%s count=%d summary=%s", _rid, len(_tc), _summary)
+                    print(f"[rollout] agent tool_calls run_id={_rid} count={len(_tc)} summary={_summary}", flush=True)
+                await tracing_context.record_tool_invocation(pending_tool_calls)
+                ops_executed += 1
+
+            elif op == "env":
+                if not pending_tool_calls:
+                    # Treat absence of tool calls as a soft terminal condition; yield partial trajectory
+                    with contextlib.suppress(Exception):
+                        logger.warning(
+                            "NO_TOOL_CALLS: terminating episode early run_id=%s op_idx=%s",
+                            request.run_id,
+                            str(op_idx),
+                        )
+                        print(
+                            f"[rollout] no tool_calls; terminating early run_id={request.run_id} op_idx={op_idx}",
+                            flush=True,
+                        )
+                    term_step = RolloutStep(
+                        obs=current_obs,
+                        tool_calls=[],
+                        reward=None,
+                        done=True,
+                        truncated=False,
+                        info={
+                            "terminated": True,
+                            "reason": "no_tool_calls",
+                        },
+                    )
+                    trajectory_steps.append(term_step)
+                    trajectory = RolloutTrajectory(
+                        env_id=env_id,
+                        policy_id=policy_id,
+                        steps=trajectory_steps,
+                        final={
+                            "observation": current_obs,
+                            "rollout_status": "partial_no_tool_calls",
+                            "at_op": op,
+                        },
+                        length=len(trajectory_steps),
+                        decision_samples=decision_samples if step_rewards_active else None,
+                    )
+                    metrics = RolloutMetrics(
+                        episode_returns=[total_reward],
+                        mean_return=total_reward,
+                        num_steps=len(trajectory_steps),
+                        num_episodes=1,
+                    )
+                    aborted = registry.is_run_aborted(request.run_id)
+                    if not aborted:
+                        registry.complete_run(request.run_id)
+                    if decision_open:
+                        await tracing_context.end_decision()
+                        decision_open = False
+                    if not finalized:
+                        session_trace = await tracing_context.finalize(
+                            total_reward=total_reward,
+                            achievement_state=prev_achievements,
+                            total_steps=len(trajectory_steps),
+                        )
+                        finalized = True
+                    trace_payload = tracing_context.build_trace_payload(session_trace)
+                    return RolloutResponse(
+                        run_id=request.run_id,
+                        trajectories=[trajectory],
+                        branches={},
+                        metrics=metrics,
+                        aborted=aborted,
+                        ops_executed=ops_executed,
+                        trace=trace_payload,
+                    )
+
+                # Environment step
+                from .environment_routes import EnvStepRequest, step_environment
+
+                env_step_error: Exception | None = None
+                env_response = None
+                env_step_start = _time.perf_counter()
+                try:
+                    env_response = await step_environment(
+                        EnvStepRequest(
+                            env_id=env_id,
+                            tool_calls=pending_tool_calls,
+                        )
+                    )
+                except Exception as _ee:
+                    env_step_error = _ee
+                env_step_end = _time.perf_counter()
+                env_step_duration_ms = (env_step_end - env_step_start) * 1000.0
+                last_env_step_ms = env_step_duration_ms
+                last_env_step_completed_ts = env_step_end
+                if last_policy_meta is not None:
+                    with contextlib.suppress(Exception):
+                        timing_env = last_policy_meta.setdefault("timing", {})
+                        timing_env["env_step_ms"] = env_step_duration_ms
+                        timing_env["env_step_end_s"] = env_step_end
+
+                if env_step_error is not None:
+                    # Invalid action or environment rejection — terminate episode early with partial trajectory
+                    with contextlib.suppress(Exception):
+                        logger.warning(
+                            "ENV_STEP_FAIL: terminating episode early run_id=%s op_idx=%s err=%s",
+                            request.run_id,
+                            str(op_idx),
+                            str(env_step_error),
+                        )
+
+                    term_step = RolloutStep(
+                        obs=current_obs,
+                        tool_calls=pending_tool_calls,
+                        reward=None,
+                        done=True,
+                        truncated=False,
+                        info={
+                            "terminated": True,
+                            "reason": "invalid_action",
+                            "error": str(env_step_error),
+                        },
+                    )
+                    trajectory_steps.append(term_step)
+                    # Build partial response
+                    trajectory = RolloutTrajectory(
+                        env_id=env_id,
+                        policy_id=policy_id,
+                        steps=trajectory_steps,
+                        final={
+                            "observation": current_obs,
+                            "rollout_status": "partial_invalid_action",
+                            "error": str(env_step_error),
+                            "at_op": op,
+                        },
+                        length=len(trajectory_steps),
+                        decision_samples=decision_samples if step_rewards_active else None,
+                    )
+                    metrics = RolloutMetrics(
+                        episode_returns=[total_reward],
+                        mean_return=total_reward,
+                        num_steps=len(trajectory_steps),
+                        num_episodes=1,
+                    )
+                    aborted = registry.is_run_aborted(request.run_id)
+                    if not aborted:
+                        registry.complete_run(request.run_id)
+                    if (
+                        last_policy_meta is not None
+                        and last_agent_response_ts is not None
+                        and "decision_ms" not in last_policy_meta.get("timing", {})
+                    ):
+                        with contextlib.suppress(Exception):
+                            timing_last = last_policy_meta.setdefault("timing", {})
+                            decision_ms = max(
+                                0.0,
+                                (env_step_end - float(last_agent_response_ts)) * 1000.0,
+                            )
+                            timing_last["decision_ms"] = decision_ms
+                            timing_last.setdefault(
+                                "overhead_ms", max(0.0, decision_ms - env_step_duration_ms)
+                            )
+                    if decision_open:
+                        await tracing_context.end_decision()
+                        decision_open = False
+                    if not finalized:
+                        session_trace = await tracing_context.finalize(
+                            total_reward=total_reward,
+                            achievement_state=prev_achievements,
+                            total_steps=len(trajectory_steps),
+                        )
+                        finalized = True
+                    trace_payload = tracing_context.build_trace_payload(session_trace)
+                    return RolloutResponse(
+                        run_id=request.run_id,
+                        trajectories=[trajectory],
+                        branches={},
+                        metrics=metrics,
+                        aborted=aborted,
+                        ops_executed=ops_executed,
+                        trace=trace_payload,
+                    )
+
+                # Reaching here means env step succeeded
+                assert env_response is not None
+
+                # Record step, including policy meta if present for timing/tokens observability
+                _info = env_response.info if isinstance(env_response.info, dict) else {}
+                # Attach policy meta from the immediately preceding agent step
+                with contextlib.suppress(Exception):
+                    prev_meta = {}
+                    if "policy_response" in locals() and isinstance(policy_response.meta, dict):  # type: ignore[name-defined]
+                        prev_meta = policy_response.meta
+                    if prev_meta:
+                        _info = dict(_info)
+                        _info["meta"] = prev_meta
+
+                event_metadata = {
+                    "op_index": op_idx,
+                }
+                event_id = await tracing_context.record_environment_event(
+                    env_handle=env_handle,
+                    prev_obs=current_obs,
+                    env_response=env_response,
+                    next_obs=getattr(env_response, "observation", None),
+                    metadata=event_metadata,
+                )
+
+                decision_index += 1
+                next_obs = env_response.observation
+                new_achievement_state = _extract_achievements(next_obs)
+                final_achievement_count = sum(
+                    1 for _, unlocked in new_achievement_state.items() if unlocked
+                )
+                indicator_val = 0
+                reward_stepwise = 0.0
+                decision_rewards_meta: dict[str, Any] | None = None
+                if step_rewards_active:
+                    decision_actions = _summarize_tool_calls(pending_tool_calls)
+                    stepwise_info, decision_record, stats = compute_stepwise_reward(
+                        prev_achievements or {},
+                        new_achievement_state,
+                        decision_index,
+                        decision_actions,
+                        step_rewards_indicator_lambda,
+                    )
+                    indicator_val = int(stats.get("indicator", 0.0))
+                    reward_stepwise = float(stats.get("reward", 0.0))
+                    stepwise_indicator_sum += float(stats.get("indicator", 0.0))
+                    stepwise_reward_sum += reward_stepwise
+                    stepwise_new_achievements_total += int(stats.get("new_achievements_count", 0.0))
+                    _info = {} if not isinstance(_info, dict) else dict(_info)
+                    _info["stepwise"] = stepwise_info
+                    # Compute decision-level rewards (absolute vs unique) and attach to metadata
+                    with contextlib.suppress(Exception):
+                        turned_true = set(stepwise_info.get("new_achievements") or [])
+                        seen_before = set(episode_seen_achievements)
+                        new_unique = sorted(turned_true - seen_before)
+                        ach_delta = int(len(turned_true))
+                        unique_delta = int(len(new_unique))
+                        # Prepare stable lists for logging/metadata
+                        all_list = sorted(turned_true)
+                        # Ensure nested meta exists
+                        meta_block = (
+                            _info.get("meta") if isinstance(_info.get("meta"), dict) else {}
+                        )
+                        decision_rewards = {
+                            "turn": int(decision_index),
+                            "ach_delta": ach_delta,
+                            "unique_delta": unique_delta,
+                            "all": all_list,
+                            "unique": new_unique,
+                        }
+                        decision_rewards_meta = decision_rewards
+                        meta_block["decision_rewards"] = decision_rewards
+                        _info["meta"] = meta_block
+                        # Update episode-level seen set after attributing uniqueness to this decision
+                        episode_seen_achievements.update(turned_true)
+                    decision_samples.append(decision_record)
+                prev_achievements = new_achievement_state
+
+                await tracing_context.record_decision_reward(
+                    event_id=event_id,
+                    decision_meta=decision_rewards_meta,
+                )
+
+                step = RolloutStep(
+                    obs=_summarize_observation_for_storage(env_handle, current_obs),
+                    tool_calls=pending_tool_calls,
+                    reward=env_response.reward,
+                    done=env_response.done,
+                    truncated=env_response.truncated,
+                    info=_info,
+                )
+                # Log summarized env application of tool calls and immediate reward/done
+                with contextlib.suppress(Exception):
+                    _tc = pending_tool_calls or []
+                    _summary = []
+                    for _item in (_tc if isinstance(_tc, list) else []):
+                        try:
+                            if isinstance(_item, dict):
+                                _tool = _item.get("tool")
+                                _args = _item.get("args")
+                                _keys = list(_args.keys()) if isinstance(_args, dict) else []
+                                _summary.append({"tool": _tool, "args_keys": _keys})
+                        except Exception:
+                            continue
+                    _rid = getattr(request, "run_id", None)
+                    logger.info(
+                        "ENV_APPLY: run_id=%s tool_calls=%d reward=%s done=%s summary=%s",
+                        _rid,
+                        len(_tc),
+                        str(env_response.reward),
+                        str(env_response.done),
+                        _summary,
+                    )
+                    print(
+                        f"[rollout] env apply run_id={_rid} tool_calls={len(_tc)} reward={env_response.reward} done={env_response.done} summary={_summary}",
+                        flush=True,
+                    )
+                trajectory_steps.append(step)
+
+                if env_response.reward is not None:
+                    total_reward += env_response.reward
+
+                # Update state
+                current_obs = next_obs
+                pending_tool_calls = None
+                ops_executed += 1
+
+                # Handle episode end
+                if env_response.done:
+                    if request.on_done == "reset":
+                        # Reset environment
+                        from .environment_routes import (
+                            EnvResetRequest,
+                            reset_environment,
+                        )
+
+                        reset_response = await reset_environment(EnvResetRequest(env_id=env_id))
+                        current_obs = reset_response.observation
+                    elif request.on_done == "terminate":
+                        break
+
+                if decision_open:
+                    await tracing_context.end_decision()
+                    decision_open = False
+
+            else:
+                logger.warning(f"Unknown op: {op}")
+
+        if (
+            last_policy_meta is not None
+            and last_agent_response_ts is not None
+            and "timing" in last_policy_meta
+            and isinstance(last_policy_meta["timing"], dict)
+            and "decision_ms" not in last_policy_meta["timing"]
+        ):
+            with contextlib.suppress(Exception):
+                final_now = last_env_step_completed_ts or _time.perf_counter()
+                final_decision_ms = max(0.0, (final_now - float(last_agent_response_ts)) * 1000.0)
+                timing_final = last_policy_meta.setdefault("timing", {})
+                timing_final["decision_ms"] = final_decision_ms
+                if last_env_step_ms is not None:
+                    timing_final.setdefault("env_step_ms", float(last_env_step_ms))
+                    timing_final.setdefault(
+                        "overhead_ms",
+                        max(0.0, final_decision_ms - float(last_env_step_ms)),
+                    )
+                else:
+                    timing_final.setdefault("overhead_ms", 0.0)
+
+        # Build trajectory
+        trajectory = RolloutTrajectory(
+            env_id=env_id,
+            policy_id=policy_id,
+            steps=trajectory_steps,
+            final={"observation": _summarize_observation_for_storage(env_handle, current_obs)},
+            length=len(trajectory_steps),
+            decision_samples=decision_samples if step_rewards_active else None,
+        )
+
+        # Build metrics
+        metrics = RolloutMetrics(
+            episode_returns=[total_reward],
+            mean_return=total_reward,
+            num_steps=len(trajectory_steps),
+            num_episodes=1,
+        )
+
+        # Environment-specific: Log summary if available
+        try:
+            # Check if this is a Wordle environment and use Wordle helpers (lazy import)
+            wordle_wrapper_cls = None
+            try:
+                from .envs.wordle.environment import WordleEnvironmentWrapper
+                from .envs.wordle.helpers import (
+                    get_wordle_rollout_summary,
+                    log_wordle_rollout_summary,
+                )
+
+                wordle_wrapper_cls = WordleEnvironmentWrapper
+            except Exception:
+                wordle_wrapper_cls = None  # type: ignore[assignment]
+                get_wordle_rollout_summary = None  # type: ignore
+                log_wordle_rollout_summary = None  # type: ignore
+
+            is_wordle = wordle_wrapper_cls is not None and isinstance(
+                env_handle.env,
+                wordle_wrapper_cls,  # type: ignore[arg-type]
+            )
+            if is_wordle:
+                # Convert trajectory steps to expected format
+                formatted_steps = []
+                for step in trajectory_steps:
+                    formatted_steps.append({"tool_calls": step.tool_calls or []})
+
+                if (
+                    get_wordle_rollout_summary is not None
+                    and log_wordle_rollout_summary is not None
+                ):
+                    summary = get_wordle_rollout_summary(formatted_steps, current_obs, env_handle)
+                    log_wordle_rollout_summary(request.run_id, summary)
+        except ImportError:
+            # Wordle helpers not available, skip Wordle-specific logging
+            pass
+        except Exception as e:
+            logger.warning(f"Failed to generate environment-specific summary: {e}")
+
+        # Mark run as completed
+        aborted = registry.is_run_aborted(request.run_id)
+        if not aborted:
+            registry.complete_run(request.run_id)
+        if decision_open:
+            await tracing_context.end_decision()
+            decision_open = False
+        if not finalized:
+            session_trace = await tracing_context.finalize(
+                total_reward=total_reward,
+                achievement_state=prev_achievements,
+                total_steps=len(trajectory_steps),
+            )
+            finalized = True
+        trace_payload = tracing_context.build_trace_payload(session_trace)
+
+        return RolloutResponse(
+            run_id=request.run_id,
+            trajectories=[trajectory],
+            branches={},
+            metrics=metrics,
+            aborted=aborted,
+            ops_executed=ops_executed,
+            trace=trace_payload,
+        )
+
+    except Exception as e:
+        logger.error(f"Rollout failed for run {request.run_id}: {e}")
+        registry.abort_run(request.run_id)
+        if decision_open:
+            with contextlib.suppress(Exception):
+                await tracing_context.end_decision()
+            decision_open = False
+        if not finalized:
+            session_trace = None
+            with contextlib.suppress(Exception):
+                session_trace = await tracing_context.finalize(
+                    total_reward=total_reward,
+                    achievement_state=prev_achievements,
+                    total_steps=len(trajectory_steps),
+                )
+            finalized = True
+        raise HTTPException(status_code=500, detail=str(e)) from e
+    finally:
+        # Ensure any environment created for this rollout is terminated (no reuse across rollouts)
+        try:
+            if created_env_id:
+                from .environment_routes import EnvTerminateRequest, terminate_environment
+
+                await terminate_environment(EnvTerminateRequest(env_id=created_env_id))
+                logger.info(
+                    "ROLL_OUT: terminated environment env_id=%s seed=%s",
+                    str(created_env_id),
+                    str(env_seed_used) if env_seed_used is not None else "unknown",
+                )
+                # Verify removal from registry
+                with contextlib.suppress(Exception):
+                    _post = registry.get_env(created_env_id)
+                    logger.info(
+                        "ROLL_OUT: env_killed=%s (post_lookup=%s)",
+                        str(_post is None),
+                        str(_post),
+                    )
+        except Exception as _te:
+            logger.warning(f"ROLL_OUT: failed to terminate environment {created_env_id}: {_te}")
+
+        # Best-effort policy cleanup if we created one (avoid reuse across rollouts)
+        with contextlib.suppress(Exception):
+            if created_policy_id:
+                from .policy_routes import PolicyTerminateRequest, terminate_policy
+
+                await terminate_policy(PolicyTerminateRequest(policy_id=created_policy_id))
+                logger.info("ROLL_OUT: terminated policy policy_id=%s", str(created_policy_id))
+
+        if not finalized:
+            session_trace = None
+            with contextlib.suppress(Exception):
+                session_trace = await tracing_context.finalize(
+                    total_reward=total_reward,
+                    achievement_state=prev_achievements,
+                    total_steps=len(trajectory_steps),
+                )
+            finalized = True
+
+        with contextlib.suppress(Exception):
+            _clear_seed_side_effects()
+            logger.info("ROLL_OUT: RNG seed terminated/cleared before conclusion")
+
+
+ @router.post("/run/abort", response_model=RunAbortResponse)
+ async def abort_run(request: RunAbortRequest) -> RunAbortResponse:
+     """Abort a running rollout."""
+     success = registry.abort_run(request.run_id)
+
+     if not success:
+         raise HTTPException(
+             status_code=404,
+             detail=f"Run {request.run_id} not found",
+         )
+
+     return RunAbortResponse(
+         ok=True,
+         run_id=request.run_id,
+     )
+
+
+ @router.get("/run/status/{run_id}", response_model=RunStatusResponse)
+ async def get_run_status(run_id: str) -> RunStatusResponse:
+     """Get the status of a run."""
+     run_handle = registry.get_run(run_id)
+
+     if not run_handle:
+         raise HTTPException(
+             status_code=404,
+             detail=f"Run {run_id} not found",
+         )
+
+     return RunStatusResponse(
+         run_id=run_id,
+         status=run_handle.status,
+         started_at=run_handle.started_at,
+         finished_at=run_handle.finished_at,
+     )