synth-ai 0.2.14__py3-none-any.whl → 0.2.17__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of synth-ai might be problematic.

Files changed (354)
  1. examples/README.md +1 -0
  2. examples/analyze_semantic_words.sh +2 -2
  3. examples/blog_posts/pokemon_vl/README.md +98 -0
  4. examples/blog_posts/pokemon_vl/configs/eval_qwen3_vl.toml +25 -0
  5. examples/blog_posts/pokemon_vl/configs/eval_rl_final.toml +24 -0
  6. examples/blog_posts/pokemon_vl/configs/filter_high_reward.toml +10 -0
  7. examples/blog_posts/pokemon_vl/configs/train_rl_from_sft.toml +42 -0
  8. examples/blog_posts/pokemon_vl/configs/train_sft_qwen4b_vl.toml +40 -0
  9. examples/blog_posts/warming_up_to_rl/README.md +158 -0
  10. examples/blog_posts/warming_up_to_rl/configs/eval_ft_qwen4b.toml +25 -0
  11. examples/blog_posts/warming_up_to_rl/configs/eval_groq_qwen32b.toml +25 -0
  12. examples/blog_posts/warming_up_to_rl/configs/eval_openai_gpt_oss_120b.toml +29 -0
  13. examples/blog_posts/warming_up_to_rl/configs/filter_high_reward_dataset.toml +10 -0
  14. examples/blog_posts/warming_up_to_rl/configs/train_rl_from_sft.toml +41 -0
  15. examples/blog_posts/warming_up_to_rl/configs/train_sft_qwen4b.toml +40 -0
  16. examples/dev/qwen3_32b_qlora_4xh100.toml +5 -0
  17. examples/multi_step/SFT_README.md +147 -0
  18. examples/multi_step/configs/crafter_rl_outcome.toml +1 -1
  19. examples/multi_step/configs/crafter_rl_stepwise_hosted_judge.toml +73 -115
  20. examples/multi_step/configs/crafter_rl_stepwise_shaped.toml +1 -1
  21. examples/multi_step/configs/crafter_rl_stepwise_simple.toml +1 -1
  22. examples/multi_step/configs/crafter_rl_stepwise_simple_NEW_FORMAT.toml +105 -0
  23. examples/multi_step/configs/crafter_sft_qwen30b_lora.toml +62 -0
  24. examples/multi_step/configs/verilog_rl_lora.toml +80 -123
  25. examples/multi_step/convert_traces_to_sft.py +84 -0
  26. examples/multi_step/run_sft_qwen30b.sh +45 -0
  27. examples/qwen_coder/configs/coder_lora_30b.toml +1 -2
  28. examples/qwen_coder/configs/coder_lora_4b.toml +5 -1
  29. examples/qwen_coder/configs/coder_lora_small.toml +1 -2
  30. examples/qwen_vl/BUGS_AND_FIXES.md +232 -0
  31. examples/qwen_vl/IMAGE_VALIDATION_COMPLETE.md +271 -0
  32. examples/qwen_vl/IMAGE_VALIDATION_SUMMARY.md +260 -0
  33. examples/qwen_vl/INFERENCE_SFT_TESTS.md +412 -0
  34. examples/qwen_vl/NEXT_STEPS_2B.md +325 -0
  35. examples/qwen_vl/QUICKSTART.md +327 -0
  36. examples/qwen_vl/QUICKSTART_RL_VISION.md +110 -0
  37. examples/qwen_vl/README.md +152 -0
  38. examples/qwen_vl/RL_VISION_COMPLETE.md +475 -0
  39. examples/qwen_vl/RL_VISION_TESTING.md +333 -0
  40. examples/qwen_vl/SDK_VISION_INTEGRATION.md +328 -0
  41. examples/qwen_vl/SETUP_COMPLETE.md +274 -0
  42. examples/qwen_vl/VISION_TESTS_COMPLETE.md +489 -0
  43. examples/qwen_vl/VLM_PIPELINE_COMPLETE.md +242 -0
  44. examples/qwen_vl/__init__.py +2 -0
  45. examples/qwen_vl/collect_data_via_cli.md +415 -0
  46. examples/qwen_vl/collect_vision_traces.py +368 -0
  47. examples/qwen_vl/configs/crafter_rl_vision_qwen3vl4b.toml +110 -0
  48. examples/qwen_vl/configs/crafter_vlm_sft_example.toml +59 -0
  49. examples/qwen_vl/configs/eval_gpt4o_mini_vision.toml +26 -0
  50. examples/qwen_vl/configs/eval_gpt4o_vision_proper.toml +29 -0
  51. examples/qwen_vl/configs/eval_gpt5nano_vision.toml +26 -0
  52. examples/qwen_vl/configs/eval_qwen3vl_vision.toml +26 -0
  53. examples/qwen_vl/configs/filter_qwen3vl_sft.toml +49 -0
  54. examples/qwen_vl/configs/filter_vision_sft.toml +52 -0
  55. examples/qwen_vl/configs/filter_vision_test.toml +8 -0
  56. examples/qwen_vl/configs/sft_qwen3_vl_2b_test.toml +54 -0
  57. examples/qwen_vl/crafter_gpt5nano_agent.py +308 -0
  58. examples/qwen_vl/crafter_qwen_vl_agent.py +300 -0
  59. examples/qwen_vl/run_vision_comparison.sh +61 -0
  60. examples/qwen_vl/run_vision_sft_pipeline.sh +175 -0
  61. examples/qwen_vl/test_image_validation.py +201 -0
  62. examples/qwen_vl/test_sft_vision_data.py +110 -0
  63. examples/rl/README.md +6 -6
  64. examples/rl/configs/eval_base_qwen.toml +17 -0
  65. examples/rl/configs/eval_rl_qwen.toml +13 -0
  66. examples/rl/configs/rl_from_base_qwen.toml +62 -0
  67. examples/rl/configs/rl_from_base_qwen17.toml +79 -0
  68. examples/rl/configs/rl_from_ft_qwen.toml +37 -0
  69. examples/rl/run_eval.py +436 -0
  70. examples/rl/run_rl_and_save.py +111 -0
  71. examples/rl/task_app/README.md +21 -0
  72. examples/rl/task_app/math_single_step.py +990 -0
  73. examples/rl/task_app/math_task_app.py +111 -0
  74. examples/run_crafter_demo.sh +2 -2
  75. examples/sft/README.md +6 -6
  76. examples/sft/configs/crafter_fft_qwen0p6b.toml +7 -2
  77. examples/sft/configs/crafter_lora_qwen0p6b.toml +7 -3
  78. examples/sft/evaluate.py +2 -4
  79. examples/sft/export_dataset.py +7 -4
  80. examples/swe/task_app/README.md +33 -3
  81. examples/swe/task_app/grpo_swe_mini.py +4 -1
  82. examples/swe/task_app/grpo_swe_mini_task_app.py +0 -12
  83. examples/swe/task_app/hosted/envs/crafter/react_agent.py +1 -1
  84. examples/swe/task_app/hosted/envs/mini_swe/environment.py +50 -23
  85. examples/swe/task_app/hosted/inference/openai_client.py +4 -4
  86. examples/swe/task_app/hosted/policy_routes.py +0 -2
  87. examples/swe/task_app/hosted/rollout.py +0 -8
  88. examples/swe/task_app/morph_backend.py +178 -0
  89. examples/task_apps/crafter/task_app/README.md +1 -1
  90. examples/task_apps/crafter/task_app/grpo_crafter.py +70 -10
  91. examples/task_apps/crafter/task_app/grpo_crafter_task_app.py +1 -1
  92. examples/task_apps/crafter/task_app/synth_envs_hosted/envs/crafter/policy.py +63 -27
  93. examples/task_apps/crafter/task_app/synth_envs_hosted/envs/crafter/react_agent.py +1 -2
  94. examples/task_apps/crafter/task_app/synth_envs_hosted/inference/openai_client.py +48 -50
  95. examples/task_apps/crafter/task_app/synth_envs_hosted/policy_routes.py +75 -36
  96. examples/task_apps/crafter/task_app/synth_envs_hosted/rollout.py +31 -15
  97. examples/task_apps/enron/__init__.py +1 -0
  98. examples/task_apps/enron/task_app/grpo_enron_task_app.py +1 -1
  99. examples/task_apps/math/README.md +1 -2
  100. examples/task_apps/pokemon_red/README.md +3 -4
  101. examples/task_apps/pokemon_red/eval_image_only_gpt4o.toml +6 -5
  102. examples/task_apps/pokemon_red/eval_pokemon_red_policy.py +1 -2
  103. examples/task_apps/pokemon_red/task_app.py +36 -5
  104. examples/task_apps/sokoban/README.md +2 -3
  105. examples/task_apps/verilog/eval_groq_qwen32b.toml +12 -14
  106. examples/task_apps/verilog/task_app/grpo_verilog_task_app.py +1 -1
  107. examples/vlm/README.md +3 -3
  108. examples/vlm/configs/crafter_vlm_gpt4o.toml +5 -0
  109. examples/vlm/crafter_openai_vlm_agent.py +3 -5
  110. examples/vlm/filter_image_rows.py +1 -1
  111. examples/vlm/run_crafter_vlm_benchmark.py +2 -2
  112. examples/warming_up_to_rl/_utils.py +92 -0
  113. examples/warming_up_to_rl/analyze_trace_db.py +1 -1
  114. examples/warming_up_to_rl/configs/crafter_fft.toml +5 -0
  115. examples/warming_up_to_rl/configs/eval_fft_qwen4b.toml +2 -0
  116. examples/warming_up_to_rl/configs/eval_groq_qwen32b.toml +2 -0
  117. examples/warming_up_to_rl/configs/eval_modal_qwen4b.toml +2 -1
  118. examples/warming_up_to_rl/configs/rl_from_base_qwen4b.toml +2 -1
  119. examples/warming_up_to_rl/configs/rl_from_ft.toml +2 -0
  120. examples/warming_up_to_rl/export_trace_sft.py +174 -60
  121. examples/warming_up_to_rl/readme.md +63 -132
  122. examples/warming_up_to_rl/run_fft_and_save.py +1 -1
  123. examples/warming_up_to_rl/run_local_rollout_traced.py +1 -1
  124. examples/warming_up_to_rl/run_rl_and_save.py +1 -1
  125. examples/warming_up_to_rl/task_app/README.md +42 -0
  126. examples/warming_up_to_rl/task_app/grpo_crafter.py +827 -0
  127. examples/warming_up_to_rl/task_app/grpo_crafter_task_app.py +135 -0
  128. examples/warming_up_to_rl/task_app/synth_envs_hosted/README.md +173 -0
  129. examples/warming_up_to_rl/task_app/synth_envs_hosted/__init__.py +5 -0
  130. examples/warming_up_to_rl/task_app/synth_envs_hosted/branching.py +143 -0
  131. examples/warming_up_to_rl/task_app/synth_envs_hosted/environment_routes.py +1226 -0
  132. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/__init__.py +1 -0
  133. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/__init__.py +6 -0
  134. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/app.py +1 -0
  135. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/environment.py +522 -0
  136. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/policy.py +454 -0
  137. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/react_agent.py +108 -0
  138. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/shared.py +305 -0
  139. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/tools.py +47 -0
  140. examples/warming_up_to_rl/task_app/synth_envs_hosted/hosted_app.py +204 -0
  141. examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/__init__.py +5 -0
  142. examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/openai_client.py +618 -0
  143. examples/warming_up_to_rl/task_app/synth_envs_hosted/main.py +100 -0
  144. examples/warming_up_to_rl/task_app/synth_envs_hosted/policy_routes.py +1084 -0
  145. examples/warming_up_to_rl/task_app/synth_envs_hosted/registry.py +195 -0
  146. examples/warming_up_to_rl/task_app/synth_envs_hosted/rollout.py +1861 -0
  147. examples/warming_up_to_rl/task_app/synth_envs_hosted/storage/__init__.py +5 -0
  148. examples/warming_up_to_rl/task_app/synth_envs_hosted/storage/volume.py +211 -0
  149. examples/warming_up_to_rl/task_app/synth_envs_hosted/test_agents.py +161 -0
  150. examples/warming_up_to_rl/task_app/synth_envs_hosted/test_service.py +137 -0
  151. examples/warming_up_to_rl/task_app/synth_envs_hosted/utils.py +62 -0
  152. examples/workflows/math_rl/configs/rl_from_base_qwen.toml +27 -0
  153. examples/workflows/math_rl/configs/rl_from_base_qwen17.toml +5 -0
  154. synth_ai/__init__.py +44 -30
  155. synth_ai/_utils/__init__.py +47 -0
  156. synth_ai/_utils/base_url.py +10 -0
  157. synth_ai/_utils/http.py +10 -0
  158. synth_ai/_utils/prompts.py +10 -0
  159. synth_ai/_utils/task_app_state.py +12 -0
  160. synth_ai/_utils/user_config.py +10 -0
  161. synth_ai/api/models/supported.py +144 -7
  162. synth_ai/api/train/__init__.py +13 -1
  163. synth_ai/api/train/builders.py +9 -3
  164. synth_ai/api/train/cli.py +155 -17
  165. synth_ai/api/train/config_finder.py +18 -11
  166. synth_ai/api/train/configs/__init__.py +8 -1
  167. synth_ai/api/train/configs/rl.py +32 -7
  168. synth_ai/api/train/configs/sft.py +6 -2
  169. synth_ai/api/train/configs/shared.py +59 -2
  170. synth_ai/api/train/env_resolver.py +13 -10
  171. synth_ai/auth/credentials.py +119 -0
  172. synth_ai/cli/__init__.py +61 -69
  173. synth_ai/cli/_modal_wrapper.py +7 -5
  174. synth_ai/cli/_typer_patch.py +0 -2
  175. synth_ai/cli/_validate_task_app.py +22 -4
  176. synth_ai/cli/commands/__init__.py +17 -0
  177. synth_ai/cli/commands/demo/__init__.py +6 -0
  178. synth_ai/cli/commands/demo/core.py +163 -0
  179. synth_ai/cli/commands/deploy/__init__.py +23 -0
  180. synth_ai/cli/commands/deploy/core.py +614 -0
  181. synth_ai/cli/commands/deploy/errors.py +72 -0
  182. synth_ai/cli/commands/deploy/validation.py +11 -0
  183. synth_ai/cli/commands/eval/__init__.py +19 -0
  184. synth_ai/cli/commands/eval/core.py +1109 -0
  185. synth_ai/cli/commands/eval/errors.py +81 -0
  186. synth_ai/cli/commands/eval/validation.py +133 -0
  187. synth_ai/cli/commands/filter/__init__.py +12 -0
  188. synth_ai/cli/commands/filter/core.py +388 -0
  189. synth_ai/cli/commands/filter/errors.py +55 -0
  190. synth_ai/cli/commands/filter/validation.py +77 -0
  191. synth_ai/cli/commands/help/__init__.py +177 -0
  192. synth_ai/cli/commands/help/core.py +73 -0
  193. synth_ai/cli/commands/status/__init__.py +64 -0
  194. synth_ai/cli/commands/status/client.py +192 -0
  195. synth_ai/cli/commands/status/config.py +92 -0
  196. synth_ai/cli/commands/status/errors.py +20 -0
  197. synth_ai/cli/commands/status/formatters.py +164 -0
  198. synth_ai/cli/commands/status/subcommands/__init__.py +9 -0
  199. synth_ai/cli/commands/status/subcommands/files.py +79 -0
  200. synth_ai/cli/commands/status/subcommands/jobs.py +334 -0
  201. synth_ai/cli/commands/status/subcommands/models.py +79 -0
  202. synth_ai/cli/commands/status/subcommands/runs.py +81 -0
  203. synth_ai/cli/commands/status/subcommands/summary.py +47 -0
  204. synth_ai/cli/commands/status/utils.py +114 -0
  205. synth_ai/cli/commands/train/__init__.py +53 -0
  206. synth_ai/cli/commands/train/core.py +21 -0
  207. synth_ai/cli/commands/train/errors.py +117 -0
  208. synth_ai/cli/commands/train/judge_schemas.py +199 -0
  209. synth_ai/cli/commands/train/judge_validation.py +304 -0
  210. synth_ai/cli/commands/train/validation.py +443 -0
  211. synth_ai/cli/demo.py +2 -162
  212. synth_ai/cli/deploy/__init__.py +28 -0
  213. synth_ai/cli/deploy/core.py +5 -0
  214. synth_ai/cli/deploy/errors.py +23 -0
  215. synth_ai/cli/deploy/validation.py +5 -0
  216. synth_ai/cli/eval/__init__.py +36 -0
  217. synth_ai/cli/eval/core.py +5 -0
  218. synth_ai/cli/eval/errors.py +31 -0
  219. synth_ai/cli/eval/validation.py +5 -0
  220. synth_ai/cli/filter/__init__.py +28 -0
  221. synth_ai/cli/filter/core.py +5 -0
  222. synth_ai/cli/filter/errors.py +23 -0
  223. synth_ai/cli/filter/validation.py +5 -0
  224. synth_ai/cli/legacy_root_backup.py +3 -1
  225. synth_ai/cli/lib/__init__.py +10 -0
  226. synth_ai/cli/lib/task_app_discovery.py +7 -0
  227. synth_ai/cli/lib/task_app_env.py +518 -0
  228. synth_ai/cli/modal_serve/__init__.py +12 -0
  229. synth_ai/cli/modal_serve/core.py +14 -0
  230. synth_ai/cli/modal_serve/errors.py +8 -0
  231. synth_ai/cli/modal_serve/validation.py +11 -0
  232. synth_ai/cli/recent.py +2 -1
  233. synth_ai/cli/serve/__init__.py +12 -0
  234. synth_ai/cli/serve/core.py +14 -0
  235. synth_ai/cli/serve/errors.py +8 -0
  236. synth_ai/cli/serve/validation.py +11 -0
  237. synth_ai/cli/setup.py +21 -0
  238. synth_ai/cli/status.py +7 -126
  239. synth_ai/cli/task_app_deploy.py +7 -0
  240. synth_ai/cli/task_app_list.py +25 -0
  241. synth_ai/cli/task_app_modal_serve.py +11 -0
  242. synth_ai/cli/task_app_serve.py +11 -0
  243. synth_ai/cli/task_apps.py +110 -1499
  244. synth_ai/cli/traces.py +1 -1
  245. synth_ai/cli/train/__init__.py +12 -0
  246. synth_ai/cli/train/core.py +21 -0
  247. synth_ai/cli/train/errors.py +8 -0
  248. synth_ai/cli/train/validation.py +24 -0
  249. synth_ai/cli/train.py +5 -0
  250. synth_ai/cli/turso.py +1 -1
  251. synth_ai/cli/watch.py +1 -1
  252. synth_ai/demos/__init__.py +10 -0
  253. synth_ai/demos/core/__init__.py +28 -1
  254. synth_ai/demos/crafter/__init__.py +1 -0
  255. synth_ai/demos/crafter/crafter_fft_4b.toml +55 -0
  256. synth_ai/demos/crafter/grpo_crafter_task_app.py +185 -0
  257. synth_ai/demos/crafter/rl_from_base_qwen4b.toml +74 -0
  258. synth_ai/demos/demo_registry.py +176 -0
  259. synth_ai/demos/demo_task_apps/crafter/grpo_crafter_task_app.py +1 -1
  260. synth_ai/demos/math/__init__.py +1 -0
  261. synth_ai/demos/math/_common.py +16 -0
  262. synth_ai/demos/math/app.py +38 -0
  263. synth_ai/demos/math/config.toml +76 -0
  264. synth_ai/demos/math/deploy_modal.py +54 -0
  265. synth_ai/demos/math/modal_task_app.py +702 -0
  266. synth_ai/demos/math/task_app_entry.py +51 -0
  267. synth_ai/environments/environment/core.py +7 -1
  268. synth_ai/environments/examples/bandit/engine.py +0 -1
  269. synth_ai/environments/examples/bandit/environment.py +0 -1
  270. synth_ai/environments/examples/red/engine.py +33 -12
  271. synth_ai/environments/examples/red/engine_helpers/reward_components.py +151 -179
  272. synth_ai/environments/examples/red/environment.py +26 -0
  273. synth_ai/environments/examples/red/trace_hooks_v3.py +168 -0
  274. synth_ai/environments/examples/wordle/environment.py +0 -1
  275. synth_ai/evals/base.py +16 -5
  276. synth_ai/evals/client.py +1 -1
  277. synth_ai/http.py +8 -22
  278. synth_ai/inference/client.py +1 -1
  279. synth_ai/judge_schemas.py +4 -5
  280. synth_ai/learning/client.py +1 -1
  281. synth_ai/learning/health.py +1 -1
  282. synth_ai/learning/jobs.py +1 -1
  283. synth_ai/learning/rl/client.py +4 -2
  284. synth_ai/learning/rl/env_keys.py +1 -1
  285. synth_ai/learning/rl/secrets.py +1 -1
  286. synth_ai/learning/sft/client.py +1 -1
  287. synth_ai/learning/sft/data.py +407 -4
  288. synth_ai/learning/validators.py +4 -1
  289. synth_ai/streaming/__init__.py +29 -0
  290. synth_ai/streaming/config.py +94 -0
  291. synth_ai/streaming/handlers.py +469 -0
  292. synth_ai/streaming/streamer.py +301 -0
  293. synth_ai/streaming/types.py +95 -0
  294. synth_ai/task/apps/__init__.py +4 -2
  295. synth_ai/task/config.py +6 -4
  296. synth_ai/task/rubrics/__init__.py +1 -2
  297. synth_ai/task/rubrics/loaders.py +14 -10
  298. synth_ai/task/rubrics.py +219 -0
  299. synth_ai/task/trace_correlation_helpers.py +24 -11
  300. synth_ai/task/tracing_utils.py +14 -3
  301. synth_ai/task/validators.py +0 -1
  302. synth_ai/tracing_v3/abstractions.py +3 -3
  303. synth_ai/tracing_v3/config.py +15 -13
  304. synth_ai/tracing_v3/constants.py +21 -0
  305. synth_ai/tracing_v3/db_config.py +3 -1
  306. synth_ai/tracing_v3/decorators.py +10 -7
  307. synth_ai/tracing_v3/llm_call_record_helpers.py +5 -5
  308. synth_ai/tracing_v3/migration_helper.py +1 -2
  309. synth_ai/tracing_v3/session_tracer.py +7 -7
  310. synth_ai/tracing_v3/storage/base.py +29 -29
  311. synth_ai/tracing_v3/storage/config.py +3 -3
  312. synth_ai/tracing_v3/turso/daemon.py +8 -9
  313. synth_ai/tracing_v3/turso/native_manager.py +80 -72
  314. synth_ai/tracing_v3/utils.py +2 -2
  315. synth_ai/utils/__init__.py +101 -0
  316. synth_ai/utils/base_url.py +94 -0
  317. synth_ai/utils/cli.py +131 -0
  318. synth_ai/utils/env.py +294 -0
  319. synth_ai/utils/http.py +172 -0
  320. synth_ai/utils/modal.py +308 -0
  321. synth_ai/utils/process.py +212 -0
  322. synth_ai/utils/prompts.py +39 -0
  323. synth_ai/utils/sqld.py +122 -0
  324. synth_ai/utils/task_app_discovery.py +882 -0
  325. synth_ai/utils/task_app_env.py +186 -0
  326. synth_ai/utils/task_app_state.py +318 -0
  327. synth_ai/utils/user_config.py +137 -0
  328. synth_ai/v0/config/__init__.py +1 -5
  329. synth_ai/v0/config/base_url.py +1 -7
  330. synth_ai/v0/tracing/config.py +1 -1
  331. synth_ai/v0/tracing/decorators.py +1 -1
  332. synth_ai/v0/tracing/upload.py +1 -1
  333. synth_ai/v0/tracing_v1/config.py +1 -1
  334. synth_ai/v0/tracing_v1/decorators.py +1 -1
  335. synth_ai/v0/tracing_v1/upload.py +1 -1
  336. {synth_ai-0.2.14.dist-info → synth_ai-0.2.17.dist-info}/METADATA +91 -32
  337. {synth_ai-0.2.14.dist-info → synth_ai-0.2.17.dist-info}/RECORD +341 -154
  338. synth_ai/cli/man.py +0 -106
  339. synth_ai/cli/tui.py +0 -57
  340. synth_ai/compound/cais.py +0 -0
  341. synth_ai/core/experiment.py +0 -13
  342. synth_ai/core/system.py +0 -15
  343. synth_ai/demo_registry.py +0 -295
  344. synth_ai/handshake.py +0 -109
  345. synth_ai/tui/__init__.py +0 -5
  346. synth_ai/tui/__main__.py +0 -13
  347. synth_ai/tui/cli/__init__.py +0 -1
  348. synth_ai/tui/cli/query_experiments.py +0 -164
  349. synth_ai/tui/cli/query_experiments_v3.py +0 -164
  350. synth_ai/tui/dashboard.py +0 -906
  351. {synth_ai-0.2.14.dist-info → synth_ai-0.2.17.dist-info}/WHEEL +0 -0
  352. {synth_ai-0.2.14.dist-info → synth_ai-0.2.17.dist-info}/entry_points.txt +0 -0
  353. {synth_ai-0.2.14.dist-info → synth_ai-0.2.17.dist-info}/licenses/LICENSE +0 -0
  354. {synth_ai-0.2.14.dist-info → synth_ai-0.2.17.dist-info}/top_level.txt +0 -0
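
The largest single change below is synth_ai/cli/task_apps.py, whose diff repeatedly swaps Optional[...] annotations for X | None unions and datetime.timezone.utc for datetime.UTC. The following is a minimal illustrative sketch of that modernization pattern only; it is not code from the package.

# Illustrative sketch (assumes Python 3.11+, where datetime.UTC aliases timezone.utc).
from datetime import UTC, datetime


def parse_timestamp(value: str | None) -> datetime | None:
    # Union syntax (str | None) replaces Optional[str]; UTC replaces timezone.utc.
    if value is None:
        return None
    dt = datetime.fromisoformat(value.replace("Z", "+00:00"))
    return dt if dt.tzinfo else dt.replace(tzinfo=UTC)
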
synth_ai/cli/task_apps.py CHANGED
@@ -9,24 +9,21 @@ import hashlib
 import importlib
 import importlib.util
 import inspect
-import json
 import os
 import shlex
 import shutil
 import signal
-import sqlite3
 import subprocess
 import sys
 import tempfile
 import textwrap
 import time
 import types
-import uuid
 from collections.abc import Callable, Iterable, Iterator, Sequence
 from dataclasses import dataclass
-from datetime import datetime, timezone
+from datetime import UTC, datetime
 from pathlib import Path
-from typing import Any, Optional, cast
+from typing import Any, cast

 try: # Python 3.11+
     import tomllib as _toml
@@ -35,6 +32,9 @@ except Exception: # pragma: no cover - fallback

 import click
 from click.exceptions import Abort
+from synth_ai.cli.commands import deploy as _deploy_commands
+from synth_ai.cli.commands.eval import core as eval_core
+from synth_ai.cli.commands.filter import core as filter_core

 # Tracing imports - make conditional for optional dependencies
 try:
@@ -92,14 +92,14 @@ except Exception as exc: # pragma: no cover - critical dependency
     raise RuntimeError("Unable to load task app server utilities") from exc


-def _load_demo_directory() -> Optional[Path]:
+def _load_demo_directory() -> Path | None:
     """Return the demo task apps directory if available."""

     try:
         module = cast(
             Any, importlib.import_module("synth_ai.demos.demo_task_apps.core")
         )
-        loader = cast(Callable[[], Optional[str | Path]], module.load_demo_dir)
+        loader = cast(Callable[[], str | Path | None], module.load_demo_dir)
         demo_dir = loader()
         if isinstance(demo_dir, str | Path):
             demo_path = Path(demo_dir)
@@ -139,7 +139,7 @@ DEFAULT_SEARCH_RELATIVE = (
 )


-def _pearson(xs: Sequence[float], ys: Sequence[float]) -> Optional[float]:
+def _pearson(xs: Sequence[float], ys: Sequence[float]) -> float | None:
     if len(xs) != len(ys) or len(xs) < 2:
         return None
     mean_x = sum(xs) / len(xs)
@@ -164,7 +164,7 @@ class AppChoice:
     label: str
     path: Path
     source: str
-    description: Optional[str] = None
+    description: str | None = None
     aliases: tuple[str, ...] = ()
     entry: TaskAppEntryType | None = None
     entry_loader: Callable[[], TaskAppEntryType] | None = None
@@ -188,21 +188,21 @@ class JudgeSpec:
     kwargs: dict[str, Any]


-def _parse_datetime_for_trace(value: Any) -> Optional[datetime]:
+def _parse_datetime_for_trace(value: Any) -> datetime | None:
     if isinstance(value, datetime):
-        return value if value.tzinfo else value.replace(tzinfo=timezone.utc)
+        return value if value.tzinfo else value.replace(tzinfo=UTC)
     if isinstance(value, str):
         value = value.replace("Z", "+00:00")
         try:
             dt = datetime.fromisoformat(value)
         except ValueError:
             try:
-                dt = datetime.fromtimestamp(float(value), tz=timezone.utc)
+                dt = datetime.fromtimestamp(float(value), tz=UTC)
             except Exception:
                 return None
-        return dt if dt.tzinfo else dt.replace(tzinfo=timezone.utc)
+        return dt if dt.tzinfo else dt.replace(tzinfo=UTC)
     if isinstance(value, int | float):
-        return datetime.fromtimestamp(float(value), tz=timezone.utc)
+        return datetime.fromtimestamp(float(value), tz=UTC)
     return None


@@ -241,6 +241,24 @@ def _event_from_dict(payload: dict[str, Any]) -> BaseEvent:
             system_state_after=payload.get("system_state_after"),
             **base_kwargs,
         )
+    # Check for LM CAIS event fields
+    if any(key in payload for key in ("model_name", "provider", "call_records")):
+        from synth_ai.tracing_v3.abstractions import LMCAISEvent
+        # Note: call_records are left as dicts - the storage layer will handle serialization
+        call_records = payload.get("call_records") or []
+        return LMCAISEvent(
+            model_name=payload.get("model_name", ""),
+            provider=payload.get("provider", ""),
+            input_tokens=payload.get("input_tokens"),
+            output_tokens=payload.get("output_tokens"),
+            total_tokens=payload.get("total_tokens"),
+            cost_usd=payload.get("cost_usd"),
+            latency_ms=payload.get("latency_ms"),
+            span_id=payload.get("span_id"),
+            trace_id=payload.get("trace_id"),
+            call_records=call_records,
+            **base_kwargs,
+        )
     return BaseEvent(**base_kwargs)


@@ -279,7 +297,7 @@ def _step_from_dict(payload: dict[str, Any]) -> SessionTimeStep:
         for msg in payload.get("markov_blanket_messages", [])
         if isinstance(msg, dict)
     ]
-    timestamp = _parse_datetime_for_trace(payload.get("timestamp")) or datetime.now(timezone.utc)
+    timestamp = _parse_datetime_for_trace(payload.get("timestamp")) or datetime.now(UTC)
     completed_at = _parse_datetime_for_trace(payload.get("completed_at"))
     return SessionTimeStep(
         step_id=payload.get("step_id", ""),
@@ -293,7 +311,7 @@ def _step_from_dict(payload: dict[str, Any]) -> SessionTimeStep:
     )


-def _session_trace_from_dict(payload: dict[str, Any]) -> Optional[V3SessionTrace]:
+def _session_trace_from_dict(payload: dict[str, Any]) -> V3SessionTrace | None:
     if not isinstance(payload, dict):
         return None
     steps = [
@@ -311,7 +329,7 @@ def _session_trace_from_dict(payload: dict[str, Any]) -> Optional[V3SessionTrace
         for msg in payload.get("markov_blanket_message_history", [])
         if isinstance(msg, dict)
     ]
-    created_at = _parse_datetime_for_trace(payload.get("created_at")) or datetime.now(timezone.utc)
+    created_at = _parse_datetime_for_trace(payload.get("created_at")) or datetime.now(UTC)
     metadata = payload.get("metadata") or {}
     session_metadata = payload.get("session_metadata")
     return V3SessionTrace(
@@ -336,15 +354,32 @@ async def _store_trace(
     _logger.info(f"[STORE_TRACE_DEBUG] Called with tracer={tracer is not None}, trace_namespace={trace_namespace is not None}")

     if tracer is None or not isinstance(trace_namespace, dict):
-        _logger.warning(f"[STORE_TRACE_DEBUG] Early return: tracer={tracer is not None}, trace_namespace type={type(trace_namespace)}")
-        return
+        message = (
+            f"Trace storage requires a tracer instance and dict payload. "
+            f"Got tracer_present={tracer is not None}, payload_type={type(trace_namespace)}"
+        )
+        _logger.error("[STORE_TRACE_DEBUG] %s", message)
+        raise ValueError(message)

     _logger.info(f"[STORE_TRACE_DEBUG] trace_namespace keys: {list(trace_namespace.keys())}")

+    # Handle both formats:
+    # - With session_trace key: {"session_trace": {...}}
+    # - Without session_trace key (trace itself is the session): {"session_id": ..., "markov_blanket_message_history": ...}
     session_payload = trace_namespace.get("session_trace")
     if not isinstance(session_payload, dict):
-        _logger.warning(f"[STORE_TRACE_DEBUG] No session_trace found or wrong type: {type(session_payload)}")
-        return
+        # If no session_trace key, assume "full" format where trace itself is the session_trace
+        if "session_id" in trace_namespace:
+            session_payload = trace_namespace
+            _logger.info("[STORE_TRACE_DEBUG] Using trace_namespace directly as session_payload (no session_trace key)")
+        else:
+            message = (
+                "Trace payload did not contain a 'session_trace' dict and lacked top-level "
+                "session fields (session_id, markov_blanket_message_history). "
+                f"Payload keys: {list(trace_namespace.keys())}"
+            )
+            _logger.error("[STORE_TRACE_DEBUG] %s", message)
+            raise ValueError(message)

     _logger.info(f"[STORE_TRACE_DEBUG] session_payload keys: {list(session_payload.keys())}")
     msg_count = len(session_payload.get("markov_blanket_message_history", []))
@@ -352,8 +387,26 @@ async def _store_trace(

     trace_obj = _session_trace_from_dict(session_payload)
     if trace_obj is None:
-        _logger.warning(f"[STORE_TRACE_DEBUG] _session_trace_from_dict returned None")
-        return
+        message = "Session trace payload could not be parsed into a SessionTrace object."
+        _logger.error("[STORE_TRACE_DEBUG] %s", message)
+        raise ValueError(message)
+
+    if not trace_obj.markov_blanket_message_history:
+        message = (
+            "Session trace is missing markov_blanket_message_history; "
+            "eval output must include all prompts/tool calls. "
+            f"session_id={trace_obj.session_id}"
+        )
+        _logger.error("[STORE_TRACE_DEBUG] %s", message)
+        raise ValueError(message)
+
+    if not trace_obj.event_history:
+        message = (
+            "Session trace is missing event_history; rollout should emit environment/LLM events. "
+            f"session_id={trace_obj.session_id}"
+        )
+        _logger.error("[STORE_TRACE_DEBUG] %s", message)
+        raise ValueError(message)

     _logger.info(f"[STORE_TRACE_DEBUG] Created SessionTrace object with {len(trace_obj.markov_blanket_message_history)} messages")

@@ -366,7 +419,7 @@ async def _store_trace(

     _logger.info(f"[STORE_TRACE_DEBUG] Calling insert_session_trace for session_id={trace_obj.session_id}")
     await tracer.db.insert_session_trace(trace_obj)
-    _logger.info(f"[STORE_TRACE_DEBUG] Successfully inserted trace")
+    _logger.info("[STORE_TRACE_DEBUG] Successfully inserted trace")

 def _temporary_sys_path(paths: Sequence[Path]):
     """Context manager to prepend entries to sys.path temporarily."""
@@ -480,49 +533,6 @@ def _candidate_search_roots() -> list[Path]:
     return ordered


-def _eval_config_sort_key(path: Path) -> tuple[int, int, int, str]:
-    name = path.name.lower()
-    parent_names = {p.name.lower() for p in path.parents}
-    in_configs = 0 if "configs" in parent_names else 1
-    in_examples = 0 if "examples" in parent_names else 1
-    starts_eval = 0 if name.startswith("eval") else 1
-    return (in_configs, in_examples, starts_eval, str(path))
-
-
-def _discover_eval_config_paths() -> list[Path]:
-    """Find candidate eval TOML files near the current working directory."""
-
-    candidates: list[Path] = []
-    seen: set[Path] = set()
-    search_roots = _candidate_search_roots()
-    for root in search_roots:
-        if not root.exists() or not root.is_dir():
-            continue
-        try:
-            root = root.resolve()
-        except Exception:
-            continue
-        for path in root.rglob("*.toml"):
-            if not path.is_file():
-                continue
-            if _should_ignore_path(path):
-                continue
-            name_lower = path.name.lower()
-            if "eval" not in name_lower and "evaluation" not in name_lower:
-                continue
-            try:
-                resolved = path.resolve()
-            except Exception:
-                continue
-            if resolved in seen:
-                continue
-            seen.add(resolved)
-            candidates.append(resolved)
-
-    candidates.sort(key=_eval_config_sort_key)
-    return candidates
-
-
 class _TaskAppConfigVisitor(ast.NodeVisitor):
     def __init__(self) -> None:
         self.matches: list[tuple[str, int]] = []
@@ -913,43 +923,43 @@ def _build_modal_config_from_ast(modal_call: ast.Call) -> ModalDeploymentConfigT
     for kw in modal_call.keywords:
         if kw.arg and isinstance(kw.value, ast.Constant):
             kwargs[kw.arg] = kw.value.value
-        elif kw.arg == "pip_packages" and isinstance(kw.value, (ast.List, ast.Tuple)):
+        elif kw.arg == "pip_packages" and isinstance(kw.value, ast.List | ast.Tuple):
             # Handle pip_packages list/tuple
             packages: list[str] = []
             value_node = kw.value
-            if isinstance(value_node, (ast.List, ast.Tuple)):
+            if isinstance(value_node, ast.List | ast.Tuple):
                 for elt in value_node.elts:
                     if isinstance(elt, ast.Constant):
                         packages.append(elt.value)
                 kwargs[kw.arg] = tuple(packages)
-        elif kw.arg == "extra_local_dirs" and isinstance(kw.value, (ast.List, ast.Tuple)):
+        elif kw.arg == "extra_local_dirs" and isinstance(kw.value, ast.List | ast.Tuple):
             # Handle extra_local_dirs list/tuple of tuples
             dirs = []
             value_node = kw.value
-            if isinstance(value_node, (ast.List, ast.Tuple)):
+            if isinstance(value_node, ast.List | ast.Tuple):
                 for elt in value_node.elts:
-                    if isinstance(elt, (ast.List, ast.Tuple)) and len(elt.elts) == 2:
+                    if isinstance(elt, ast.List | ast.Tuple) and len(elt.elts) == 2:
                         src = elt.elts[0].value if isinstance(elt.elts[0], ast.Constant) else None
                         dst = elt.elts[1].value if isinstance(elt.elts[1], ast.Constant) else None
                         if src and dst:
                             dirs.append((src, dst))
                 kwargs[kw.arg] = tuple(dirs)
-        elif kw.arg == "secret_names" and isinstance(kw.value, (ast.List, ast.Tuple)):
+        elif kw.arg == "secret_names" and isinstance(kw.value, ast.List | ast.Tuple):
            # Handle secret_names list/tuple
             secrets = []
             value_node = kw.value
-            if isinstance(value_node, (ast.List, ast.Tuple)):
+            if isinstance(value_node, ast.List | ast.Tuple):
                 for elt in value_node.elts:
                     if isinstance(elt, ast.Constant):
                         secrets.append(elt.value)
                 kwargs[kw.arg] = tuple(secrets)
-        elif kw.arg == "volume_mounts" and isinstance(kw.value, (ast.List, ast.Tuple)):
+        elif kw.arg == "volume_mounts" and isinstance(kw.value, ast.List | ast.Tuple):
             # Handle volume_mounts list/tuple of tuples
             mounts = []
             value_node = kw.value
-            if isinstance(value_node, (ast.List, ast.Tuple)):
+            if isinstance(value_node, ast.List | ast.Tuple):
                 for elt in value_node.elts:
-                    if isinstance(elt, (ast.List, ast.Tuple)) and len(elt.elts) == 2:
+                    if isinstance(elt, ast.List | ast.Tuple) and len(elt.elts) == 2:
                         name = elt.elts[0].value if isinstance(elt.elts[0], ast.Constant) else None
                         mount = elt.elts[1].value if isinstance(elt.elts[1], ast.Constant) else None
                         if name and mount:
@@ -2238,14 +2248,13 @@ def validate_task_app_cmd(
     • Debug failing deployments: Use --verbose to see detailed endpoint responses
     • Test API key configuration: Verify authentication is set up correctly
     """
-    import asyncio
     import socket
     import subprocess
     import tempfile
     import time

     # Import the validate_task_app function defined in this module
-    from synth_ai.cli._validate_task_app import validate_task_app # type: ignore[attr-defined]
+    from ._validate_task_app import validate_task_app # type: ignore[attr-defined]

     proc = None
     task_app_url = url
@@ -2445,48 +2454,15 @@ def serve_command(
     trace_dir: str | None,
     trace_db: str | None,
 ) -> None:
-    demo_dir_path = _load_demo_directory()
-    if demo_dir_path:
-        if not demo_dir_path.is_dir():
-            raise click.ClickException(
-                f"Demo directory not found: {demo_dir_path}\nRun 'synth-ai setup' to create a demo."
-            )
-        os.chdir(demo_dir_path)
-        click.echo(f"Using demo directory: {demo_dir_path}\n")
-        os.environ["SYNTH_DEMO_DIR"] = str(demo_dir_path.resolve())
-
-    # Prompt for port if not provided
-    if port is None:
-        port = click.prompt("Port to serve on", type=int, default=8001)
-
-    # Prompt for trace directory if not provided
-    if trace_dir is None:
-        click.echo(
-            "\nTracing captures rollout data (actions, rewards, model outputs) to a local SQLite DB."
-        )
-        click.echo("This data can be exported to JSONL for supervised fine-tuning (SFT).")
-        enable_tracing = click.confirm("Enable tracing?", default=True)
-        if enable_tracing:
-            demo_base = Path(os.environ.get("SYNTH_DEMO_DIR") or Path.cwd())
-            default_trace_dir = str((demo_base / "traces/v3").resolve())
-            trace_dir = click.prompt(
-                "Trace directory", type=str, default=default_trace_dir, show_default=True
-            )
-        else:
-            trace_dir = None
-
-    # Prompt for trace DB if not provided and tracing is enabled
-    if trace_dir and trace_db is None:
-        demo_base = Path(os.environ.get("SYNTH_DEMO_DIR") or Path.cwd())
-        default_trace_db = str((demo_base / "traces/v3/synth_ai.db").resolve())
-        trace_db = click.prompt(
-            "Trace DB path", type=str, default=default_trace_db, show_default=True
-        )
-
-    choice = _select_app_choice(app_id, purpose="serve")
-    entry = choice.ensure_entry()
-    _serve_entry(
-        entry, host, port, env_file, reload_flag, force, trace_dir=trace_dir, trace_db=trace_db
+    _deploy_commands.run_uvicorn_runtime(
+        app_id,
+        host,
+        port,
+        env_file,
+        reload_flag,
+        force,
+        trace_dir,
+        trace_db,
     )


@@ -2599,49 +2575,19 @@ def serve_task_group(
     trace_dir: str | None,
     trace_db: str | None,
 ) -> None:
-    demo_dir_path = _load_demo_directory()
-    if demo_dir_path:
-        if not demo_dir_path.is_dir():
-            raise click.ClickException(
-                f"Demo directory not found: {demo_dir_path}\nRun 'synth-ai setup' to create a demo."
-            )
-        os.chdir(demo_dir_path)
-        click.echo(f"Using demo directory: {demo_dir_path}\n")
-        os.environ["SYNTH_DEMO_DIR"] = str(demo_dir_path.resolve())
-
-    # Prompt for port if not provided
-    if port is None:
-        port = click.prompt("Port to serve on", type=int, default=8001)
-
-    # Prompt for trace directory if not provided
-    if trace_dir is None:
-        click.echo(
-            "\nTracing captures rollout data (actions, rewards, model outputs) to a local SQLite DB."
-        )
-        click.echo("This data can be exported to JSONL for supervised fine-tuning (SFT).")
-        enable_tracing = click.confirm("Enable tracing?", default=True)
-        if enable_tracing:
-            demo_base = Path(os.environ.get("SYNTH_DEMO_DIR") or Path.cwd())
-            default_trace_dir = str((demo_base / "traces/v3").resolve())
-            trace_dir = click.prompt(
-                "Trace directory", type=str, default=default_trace_dir, show_default=True
-            )
-        else:
-            trace_dir = None
+    _deploy_commands.run_uvicorn_runtime(
+        app_id,
+        host,
+        port,
+        env_file,
+        reload_flag,
+        force,
+        trace_dir,
+        trace_db,
+    )

-    # Prompt for trace DB if not provided and tracing is enabled
-    if trace_dir and trace_db is None:
-        demo_base = Path(os.environ.get("SYNTH_DEMO_DIR") or Path.cwd())
-        default_trace_db = str((demo_base / "traces/v3/synth_ai.db").resolve())
-        trace_db = click.prompt(
-            "Trace DB path", type=str, default=default_trace_db, show_default=True
-        )

-    choice = _select_app_choice(app_id, purpose="serve")
-    entry = choice.ensure_entry()
-    _serve_entry(
-        entry, host, port, env_file, reload_flag, force, trace_dir=trace_dir, trace_db=trace_db
-    )
+    _deploy_commands.register_task_app_commands(task_app_group)


 def _determine_env_files(
@@ -2936,87 +2882,6 @@ def _serve_entry(
     )


-@task_app_group.command("deploy")
-@click.argument("app_id", type=str, required=False)
-@click.option("--name", "modal_name", default=None, help="Override Modal app name")
-@click.option("--dry-run", is_flag=True, help="Print modal deploy command without executing")
-@click.option("--modal-cli", default="modal", help="Path to modal CLI executable")
-@click.option(
-    "--env-file",
-    multiple=True,
-    type=click.Path(),
-    help="Env file to load into the container (can be repeated)",
-)
-def deploy_app(
-    app_id: str | None,
-    modal_name: str | None,
-    dry_run: bool,
-    modal_cli: str,
-    env_file: Sequence[str],
-) -> None:
-    """Deploy a task app to Modal."""
-
-    demo_dir_path = _load_demo_directory()
-    if demo_dir_path:
-        if not demo_dir_path.is_dir():
-            raise click.ClickException(
-                f"Demo directory not found: {demo_dir_path}\nRun 'synth-ai demo' to create a demo."
-            )
-        os.chdir(demo_dir_path)
-        click.echo(f"Using demo directory: {demo_dir_path}\n")
-
-    choice = _select_app_choice(app_id, purpose="deploy")
-
-    if choice.modal_script:
-        env_paths = _resolve_env_paths_for_script(choice.modal_script, env_file)
-        click.echo("Using env file(s): " + ", ".join(str(p.resolve()) for p in env_paths))
-        _run_modal_script(
-            choice.modal_script,
-            modal_cli,
-            "deploy",
-            env_paths,
-            modal_name=modal_name,
-            dry_run=dry_run,
-        )
-        return
-
-    entry = choice.ensure_entry()
-    _deploy_entry(entry, modal_name, dry_run, modal_cli, env_file, original_path=choice.path)
-
-
-@task_app_group.command("modal-serve")
-@click.argument("app_id", type=str, required=False)
-@click.option("--modal-cli", default="modal", help="Path to modal CLI executable")
-@click.option("--name", "modal_name", default=None, help="Override Modal app name (optional)")
-@click.option(
-    "--env-file",
-    multiple=True,
-    type=click.Path(),
-    help="Env file to load into the container (can be repeated)",
-)
-def modal_serve_app(
-    app_id: str | None, modal_cli: str, modal_name: str | None, env_file: Sequence[str]
-) -> None:
-    click.echo(f"[modal-serve] requested app_id={app_id or '(auto)'} modal_cli={modal_cli}")
-    try:
-        choice = _select_app_choice(app_id, purpose="modal-serve")
-    except SystemExit as exc: # bubble up with context (legacy argparse would trigger this)
-        raise click.ClickException(
-            f"Legacy CLI intercepted modal-serve (exit {exc.code}). "
-            "Make sure you're running the Click CLI (synth_ai.cli:cli)."
-        ) from exc
-
-    if choice.modal_script:
-        env_paths = _resolve_env_paths_for_script(choice.modal_script, env_file)
-        click.echo("Using env file(s): " + ", ".join(str(p.resolve()) for p in env_paths))
-        _run_modal_script(choice.modal_script, modal_cli, "serve", env_paths, modal_name=modal_name)
-        return
-
-    entry = choice.ensure_entry()
-    click.echo(f"[modal-serve] serving entry {entry.app_id} from {choice.path}")
-    _modal_serve_entry(entry, modal_name, modal_cli, env_file, original_path=choice.path)
-
-
 def _write_modal_entrypoint(
     entry: TaskAppEntryType,
     modal_cfg: ModalDeploymentConfigType,
@@ -3260,1263 +3125,9 @@ def register(cli: click.Group) -> None:
3260
3125
  cli.add_command(filter_command)
3261
3126
 
3262
3127
 
3263
- @click.command(
3264
- "eval",
3265
- help="Run one-off rollouts against a task app and print judge/eval summaries.",
3266
- )
3267
- @click.argument("app_id", type=str, required=False)
3268
- @click.option(
3269
- "--config",
3270
- type=click.Path(),
3271
- default=None,
3272
- help="Path to eval TOML (short schema). Auto-discovers the first matching file when omitted.",
3273
- )
3274
- @click.option(
3275
- "--url",
3276
- "task_app_url",
3277
- type=str,
3278
- default=None,
3279
- help="Base URL of a running task app instead of spawning locally (requires --env-file for secrets).",
3280
- )
3281
- @click.option(
3282
- "--seeds",
3283
- default="0,1,2,3,4",
3284
- help="Comma-separated seeds/indices to evaluate. Use negative numbers to wrap around the dataset.",
3285
- )
3286
- @click.option("--split", default="train", show_default=True, help="Dataset split to use")
3287
- @click.option(
3288
- "--model",
3289
- default=None,
3290
- help="Model identifier. When omitted the CLI will prompt based on task metadata.",
3291
- )
3292
- @click.option(
3293
- "--env-file",
3294
- multiple=True,
3295
- type=click.Path(),
3296
- help="Env file(s) to load (API keys, etc.). Required when using --url or remote judges.",
3297
- )
3298
- @click.option(
3299
- "--trace-db",
3300
- default="traces/v3/synth_ai.db",
3301
- show_default=True,
3302
- help="SQLite/Turso URL for storing rollout traces set to 'none' to disable persistence.",
3303
- )
3304
- @click.option(
3305
- "--metadata",
3306
- multiple=True,
3307
- help="Filter tasks by key=value metadata (e.g., --metadata difficulty=easy)",
3308
- )
3309
- @click.option(
3310
- "--metadata-sql",
3311
- default=None,
3312
- help="SQLite query that returns seeds to evaluate (e.g., SELECT seed FROM tasks WHERE difficulty='easy' LIMIT 5)",
3313
- )
3314
- def eval_command(
3315
- app_id: str | None,
3316
- config: str | None,
3317
- task_app_url: str | None,
3318
- seeds: str,
3319
- split: str,
3320
- model: str | None,
3321
- env_file: Sequence[str],
3322
- trace_db: str,
3323
- metadata: Sequence[str],
3324
- metadata_sql: str | None,
3325
- ) -> None:
3326
- """Run rollouts against a task app and report judge statistics.
3327
-
3328
- By default the command spins up the selected task app in-process, executes the
3329
- requested seeds, and prints aggregate scores (official and custom judges). When
3330
- pointing at a remote `--url`, supply matching `--env-file` values so the CLI can
3331
- forward authentication headers to the running service.
3332
- """
3333
- # Parse and validate TOML config
3334
- from synth_ai.task.config import EvalConfig
3335
-
3336
- cfg: dict[str, Any] = {}
3337
- eval_cfg: EvalConfig | None = None
3338
- config_path: Path | None = None
3339
-
3340
- if config:
3341
- config_path = Path(config)
3342
- else:
3343
- auto_configs = _discover_eval_config_paths()
3344
- if auto_configs:
3345
- config_path = auto_configs[0]
3346
- click.echo(f"Using eval config: {config_path}")
3347
-
3348
- if config_path:
3349
- if _toml is None:
3350
- raise click.ClickException(
3351
- "TOML parser not available; use Python 3.11+ or install tomli"
3352
- )
3353
- if not config_path.exists():
3354
- raise click.ClickException(f"Eval config not found: {config_path}")
3355
- try:
3356
- data = config_path.read_bytes()
3357
- parsed = _toml.loads(data.decode("utf-8"))
3358
- if isinstance(parsed, dict):
3359
- section = parsed.get("eval")
3360
- cfg = dict(section) if isinstance(section, dict) else dict(parsed)
3361
-
3362
- # Validate config with dataclass
3363
- try:
3364
- eval_cfg = EvalConfig.from_dict(cfg)
3365
- click.echo(f"✓ Config validated: {len(eval_cfg.seeds)} seeds, model={eval_cfg.model}")
3366
- except (ValueError, TypeError) as validation_error:
3367
- raise click.ClickException(f"Invalid eval config: {validation_error}") from validation_error
3368
- except click.ClickException:
3369
- raise
3370
- except Exception as exc:
3371
- raise click.ClickException(f"Failed to parse TOML '{config_path}': {exc}") from exc
3372
-
3373
- # CLI args override config
3374
- if eval_cfg:
3375
- app_id = app_id or eval_cfg.app_id
3376
- else:
3377
- app_id = app_id or (cfg.get("app_id") if isinstance(cfg.get("app_id"), str) else None) # type: ignore
3378
-
3379
- metadata_filters: dict[str, str] = {}
3380
- if eval_cfg:
3381
- metadata_filters.update(eval_cfg.metadata)
3382
- else:
3383
- cfg_metadata = cfg.get("metadata")
3384
- if isinstance(cfg_metadata, dict):
3385
- for key, value in cfg_metadata.items():
3386
- metadata_filters[str(key)] = str(value)
3387
- elif isinstance(cfg_metadata, list):
3388
- for item in cfg_metadata:
3389
- if isinstance(item, str) and "=" in item:
3390
- key, value = item.split("=", 1)
3391
- metadata_filters[key.strip()] = value.strip()
3392
-
3393
- for item in metadata or ():
3394
- if "=" not in item:
3395
- raise click.ClickException(f"Metadata filters must be key=value (got: {item})")
3396
- key, value = item.split("=", 1)
3397
- key = key.strip()
3398
- value = value.strip()
3399
- if not key or not value:
3400
- raise click.ClickException(f"Invalid metadata filter: {item}")
3401
- metadata_filters[key] = value
3402
-
3403
- metadata_sql_query: str | None = None
3404
- if eval_cfg and eval_cfg.metadata_sql:
3405
- metadata_sql_query = eval_cfg.metadata_sql
3406
- else:
3407
- cfg_metadata_sql = cfg.get("metadata_sql")
3408
- if isinstance(cfg_metadata_sql, dict):
3409
- metadata_sql_query = cfg_metadata_sql.get("query") or cfg_metadata_sql.get("sql")
3410
- elif isinstance(cfg_metadata_sql, str):
3411
- metadata_sql_query = cfg_metadata_sql
3412
-
3413
- if metadata_sql:
3414
- metadata_sql_query = metadata_sql
3415
- if metadata_sql_query is not None:
3416
- metadata_sql_query = str(metadata_sql_query)
3417
-
3418
- trace_db_url: str | None = None
3419
- trace_db = (trace_db or "").strip()
3420
- if trace_db and trace_db.lower() not in {"none", "off", "disable"}:
3421
- if "://" in trace_db:
3422
- trace_db_url = trace_db
3423
- else:
3424
- trace_path = Path(trace_db).expanduser()
3425
- trace_path.parent.mkdir(parents=True, exist_ok=True)
3426
- trace_db_url = f"sqlite+aiosqlite:///{trace_path}"
3427
- trace_tracer: SessionTracer | None = SessionTracer(db_url=trace_db_url, auto_save=True) if trace_db_url else None
3428
-
3429
- # Determine selection params (CLI takes precedence; TOML only fills unset model/seeds/env)
3430
- if cfg.get("model") and not model:
3431
- model = str(cfg["model"]) # type: ignore[index]
3432
- if cfg.get("seeds") and seeds == "0,1,2,3,4":
3433
- val = cfg["seeds"]
3434
- if isinstance(val, list):
3435
- with contextlib.suppress(Exception):
3436
- seeds = ",".join(str(int(x)) for x in val)
3437
- elif isinstance(val, str):
3438
- seeds = val
3439
- elif isinstance(val, int):
3440
- seeds = str(val)
3441
- if cfg.get("env_file") and not env_file:
3442
- ef = cfg["env_file"]
3443
- if isinstance(ef, str):
3444
- env_file = (ef,) # type: ignore[assignment]
3445
- elif isinstance(ef, list):
3446
- env_file = tuple(str(x) for x in ef) # type: ignore[assignment]
3447
-
3448
- choice_for_env: AppChoice | None = None
3449
- entry: TaskAppEntryType | None = None
3450
- if task_app_url is None:
3451
- choice_for_env = _select_app_choice(app_id, purpose="eval")
3452
- entry = choice_for_env.ensure_entry()
3453
-
3454
- env_paths: list[Path] = []
3455
- if entry is not None:
3456
- original_env_path = choice_for_env.path if choice_for_env is not None else None
3457
- env_paths = _determine_env_files(entry, env_file, original_path=original_env_path)
3458
- else:
3459
- if not env_file:
3460
- raise click.ClickException("--env-file is required when using --url")
3461
- for candidate in env_file:
3462
- p = Path(candidate).expanduser()
3463
- if not p.exists():
3464
- raise click.ClickException(f"Env file not found: {p}")
3465
- env_paths.append(p)
3466
-
3467
- click.echo("Using env file(s): " + ", ".join(str(p) for p in env_paths))
3468
- _load_env_files_into_process([str(Path(p)) for p in env_paths])
3469
-
3470
- if task_app_url is None:
3471
- config = entry.config_factory() # type: ignore[union-attr]
3472
- # Help the type checker; runtime check also enforced in server.run_task_app
3473
- if not isinstance(config, TaskAppConfig):
3474
- raise click.ClickException(
3475
- "Invalid task app: config_factory did not return TaskAppConfig"
3476
- )
3477
- app = create_task_app(config)
3478
-
3479
- # Determine supported models
3480
- inference_meta: dict[str, Any] = {}
3481
- supported: list[str] = []
3482
- seen_models: set[str] = set()
3483
-
3484
- def _add_supported_model(candidate: Any) -> None:
3485
- if not candidate:
3486
- return
3487
- text = str(candidate).strip()
3488
- if not text or text in seen_models:
3489
- return
3490
- supported.append(text)
3491
- seen_models.add(text)
3492
-
3493
- if task_app_url is None:
3494
- try:
3495
- if hasattr(config, "base_task_info") and config.base_task_info:
3496
- inf_obj = getattr(config.base_task_info, "inference", None)
3497
- if inf_obj is not None:
3498
- if hasattr(inf_obj, "model_dump"):
3499
- inference_meta = dict(inf_obj.model_dump(exclude_none=True)) # type: ignore[attr-defined]
3500
- elif isinstance(inf_obj, dict):
3501
- inference_meta = dict(inf_obj)
3502
- except Exception:
3503
- inference_meta = {}
3504
- else:
3505
- try:
3506
- import httpx as _hx
3507
-
3508
- headers = {}
3509
- api_key = (os.environ.get("ENVIRONMENT_API_KEY") or "").strip()
3510
- if api_key:
3511
- headers["X-API-Key"] = api_key
3512
- with _hx.Client(base_url=task_app_url, headers=headers, timeout=15.0) as c:
3513
- info = c.get("/info").json()
3514
- inf = info.get("inference") if isinstance(info, dict) else None
3515
- if isinstance(inf, dict):
3516
- inference_meta = dict(inf)
3517
- except Exception:
3518
- inference_meta = {}
3519
-
3520
- default_model = inference_meta.get("model")
3521
- if isinstance(default_model, str):
3522
- _add_supported_model(default_model)
3523
-
3524
- models_field = inference_meta.get("models")
3525
- if isinstance(models_field, list):
3526
- for candidate in models_field:
3527
- _add_supported_model(candidate)
3528
-
3529
- supported_models = inference_meta.get("supported_models")
3530
- if isinstance(supported_models, list):
3531
- for candidate in supported_models:
3532
- _add_supported_model(candidate)
3533
-
3534
- providers = inference_meta.get("providers")
3535
- if isinstance(providers, list):
3536
- if "openai" in providers:
3537
- _add_supported_model("gpt-5")
3538
- if "groq" in providers:
3539
- _add_supported_model("groq:llama-3.1-70b-versatile")
3540
-
3541
- _add_supported_model("synth:qwen-0.6b")
3542
-
3543
- selected_model = model
3544
- if not selected_model:
3545
- if not supported:
3546
- raise click.ClickException(
3547
- "No supported models; supply --model or add base_task_info.inference.model"
3548
- )
3549
- click.echo("Select model to evaluate:")
3550
- for idx, m in enumerate(supported, start=1):
3551
- click.echo(f" {idx}) {m}")
3552
- choice_idx = click.prompt("Enter choice", type=click.IntRange(1, len(supported)))
3553
- selected_model = supported[choice_idx - 1]
3554
-
3555
- try:
3556
- seed_values = [int(s.strip()) for s in seeds.split(",") if s.strip()]
3557
- except Exception as exc:
3558
- raise click.ClickException("Invalid --seeds; expected comma-separated integers") from exc
3559
-
3560
- import httpx
3561
-
3562
- headers = {}
3563
- api_key = (os.environ.get("ENVIRONMENT_API_KEY") or "").strip()
3564
- if api_key:
3565
- headers["X-API-Key"] = api_key
3566
-
3567
- # Precompute optional policy overrides from TOML
3568
- policy_overrides: dict[str, Any] = {}
3569
- try:
3570
- # Accept [eval.policy] table or top-level keys for convenience
3571
- if isinstance(cfg.get("policy"), dict):
3572
- policy_overrides.update(dict(cfg["policy"]))
3573
- # Back-compat: allow temperature/max_tokens at top level
3574
- for k in (
3575
- "temperature",
3576
- "max_tokens",
3577
- "reasoning_effort",
3578
- "system_hint",
3579
- "tool_choice",
3580
- "inference_url",
3581
- ):
3582
- if k in cfg and k not in policy_overrides:
3583
- policy_overrides[k] = cfg.get(k)
3584
- except Exception:
3585
- policy_overrides = {}
3586
-
3587
- raw_concurrency = cfg.get("concurrency")
3588
- try:
3589
- concurrency_limit = int(raw_concurrency) if raw_concurrency is not None else 1
3590
- except Exception:
3591
- concurrency_limit = 1
3592
- if concurrency_limit <= 0:
3593
- concurrency_limit = 1
3594
- concurrency_limit = min(concurrency_limit, max(1, len(seed_values)))
3595
-
3596
- judge_specs: list[JudgeSpec] = []
3597
-
3598
- def _register_judge(name_hint: str | None, judge_cfg: dict[str, Any]) -> None:
3599
- if not judge_cfg:
3600
- return
3601
- judge_module = judge_cfg.get("module")
3602
- judge_path = judge_cfg.get("path")
3603
- judge_callable_name = judge_cfg.get("callable") or judge_cfg.get("function")
3604
- if judge_module and judge_path:
3605
- raise click.ClickException("Judge config cannot set both 'module' and 'path'")
3606
- if not judge_module and not judge_path:
3607
- raise click.ClickException("Judge config requires 'module' or 'path'")
3608
- try:
3609
- if judge_module:
3610
- module = importlib.import_module(str(judge_module))
3611
- else:
3612
- path = Path(str(judge_path)).expanduser()
3613
- if not path.exists():
3614
- raise click.ClickException(f"Judge module path not found: {path}")
3615
- spec = importlib.util.spec_from_file_location(
3616
- f"_eval_judge_{path.stem}", path
3617
- )
3618
- if not spec or not spec.loader:
3619
- raise click.ClickException(f"Failed to load judge module from {path}")
3620
- module = importlib.util.module_from_spec(spec)
3621
- sys.modules[spec.name] = module
3622
- spec.loader.exec_module(module)
3623
- except click.ClickException:
3624
- raise
3625
- except Exception as exc:
3626
- raise click.ClickException(f"Unable to load judge module: {exc}") from exc
3627
-
3628
- if judge_callable_name:
3629
- try:
3630
- judge_fn = getattr(module, str(judge_callable_name))
3631
- except AttributeError as exc:
3632
- raise click.ClickException(
3633
- f"Judge callable '{judge_callable_name}' not found in module"
3634
- ) from exc
3635
- else:
3636
- if hasattr(module, "judge"):
3637
- judge_fn = module.judge
3638
- else:
3639
- raise click.ClickException("Judge module must expose 'judge' callable")
3640
-
3641
- if not callable(judge_fn):
3642
- raise click.ClickException("Judge callable is not callable")
3643
-
3644
- judge_kwargs = {
3645
- k: v
3646
- for k, v in judge_cfg.items()
3647
- if k not in {"module", "path", "callable", "function", "name"}
3648
- }
3649
- display_name = str(
3650
- judge_cfg.get("name")
3651
- or name_hint
3652
- or f"judge{len(judge_specs) + 1}"
3653
- )
3654
- judge_specs.append(JudgeSpec(display_name, judge_fn, judge_kwargs))
3655
-
3656
- raw_judge_cfg = cfg.get("judge")
3657
- if isinstance(raw_judge_cfg, dict) and raw_judge_cfg:
3658
- direct_keys = {"module", "path", "callable", "function", "name"}
3659
- has_direct_keys = any(key in raw_judge_cfg for key in direct_keys)
3660
- nested_candidates = [
3661
- (key, value)
3662
- for key, value in raw_judge_cfg.items()
3663
- if isinstance(value, dict)
3664
- ]
3665
- if has_direct_keys and not nested_candidates:
3666
- _register_judge(None, raw_judge_cfg)
3667
- else:
3668
- for sub_name, sub_cfg in nested_candidates:
3669
- _register_judge(sub_name, sub_cfg)
3670
-
3671
- raw_judges_list = cfg.get("judges")
3672
- if isinstance(raw_judges_list, list):
3673
- for _index, entry in enumerate(raw_judges_list, start=1):
3674
- if isinstance(entry, dict):
3675
- _register_judge(entry.get("name") or f"judge{len(judge_specs) + 1}", entry)
3676
-
3677
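A judge referenced by this config only needs to expose a callable; a minimal sketch of such a module (file name and scoring rule are made up for illustration):

# my_judge.py -- hypothetical module pointed at via the judge config's 'path' (or 'module') key.
from typing import Any


def judge(payload: dict[str, Any], **kwargs: Any) -> float:
    # The eval loop passes seed, prompt, completion, metrics, response, and trace;
    # extra keys from the judge config arrive as kwargs, and a float return value
    # is recorded as this judge's score for the seed.
    completion = payload.get("completion") or ""
    return 1.0 if completion.strip() else 0.0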
- records: list[dict[str, Any]] = []
3678
-
3679
- successes = 0
3680
- failures = 0
3681
- # Aggregate outcome stats across successful seeds
3682
- outcome_sum: float = 0.0
3683
- outcome_count: int = 0
3684
- outcome_correct: int = 0
3685
-
3686
- def _build_task_rows(taskset: Any) -> dict[int, dict[str, Any]]:
3687
- rows: dict[int, dict[str, Any]] = {}
3688
- if not isinstance(taskset, dict):
3689
- return rows
3690
-
3691
- scenario_ids = taskset.get("scenario_ids") or []
3692
- loop_ids = taskset.get("loop_ids") or []
3693
- thread_ids = taskset.get("thread_ids") or []
3694
- difficulty_map = taskset.get("difficulty_map") or {}
3695
-
3696
- max_len = max(len(scenario_ids), len(loop_ids), len(thread_ids))
3697
- for seed in range(max_len):
3698
- scenario_id = scenario_ids[seed] if seed < len(scenario_ids) else None
3699
- loop_id = loop_ids[seed] if seed < len(loop_ids) else None
3700
- thread_id = thread_ids[seed] if seed < len(thread_ids) else None
3701
- difficulty = None
3702
- if isinstance(difficulty_map, dict):
3703
- if scenario_id and scenario_id in difficulty_map:
3704
- difficulty = difficulty_map.get(scenario_id)
3705
- elif str(seed) in difficulty_map:
3706
- difficulty = difficulty_map.get(str(seed))
3707
-
3708
- rows[seed] = {
3709
- "seed": seed,
3710
- "scenario_id": scenario_id,
3711
- "loop_id": loop_id,
3712
- "thread_id": thread_id,
3713
- "difficulty": difficulty,
3714
- }
3715
- return rows
3716
-
3717
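To make the row construction above concrete, here is a small, illustrative taskset payload and the per-seed rows it would yield:

taskset_example = {
    "scenario_ids": ["intro", "cave"],
    "loop_ids": ["loop-a", "loop-b"],
    "thread_ids": ["t0", "t1"],
    "difficulty_map": {"intro": "easy", "1": "hard"},
}
# _build_task_rows(taskset_example) ->
#   {0: {"seed": 0, "scenario_id": "intro", "loop_id": "loop-a", "thread_id": "t0", "difficulty": "easy"},
#    1: {"seed": 1, "scenario_id": "cave", "loop_id": "loop-b", "thread_id": "t1", "difficulty": "hard"}}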
- def _apply_metadata_filters(
3718
- rows: dict[int, dict[str, Any]], seeds_list: list[int], filters: dict[str, str]
3719
- ) -> list[int]:
3720
- if not filters:
3721
- return seeds_list
3722
- filtered: list[int] = []
3723
- for seed in seeds_list:
3724
- row = rows.get(seed)
3725
- if not row:
3726
- continue
3727
- include = True
3728
- for key, expected in filters.items():
3729
- actual = row.get(key)
3730
- if actual is None:
3731
- include = False
3732
- break
3733
- if str(actual).lower() != expected.lower():
3734
- include = False
3735
- break
3736
- if include:
3737
- filtered.append(seed)
3738
- return filtered
3739
-
3740
- def _apply_metadata_sql(
3741
- rows: dict[int, dict[str, Any]], seeds_list: list[int], query: str
3742
- ) -> list[int]:
3743
- """Return seeds that satisfy an arbitrary SQL query.
3744
-
3745
- The query is executed against an in-memory SQLite table named `tasks`
3746
- with columns (seed INTEGER, scenario_id TEXT, loop_id TEXT, thread_id TEXT, difficulty TEXT).
3747
- Any rows whose `seed` value (or first column if `seed` is absent) appears in the result set are retained.
3748
- """
3749
- if not query:
3750
- return seeds_list
3751
- conn = sqlite3.connect(":memory:")
3752
- try:
3753
- cur = conn.cursor()
3754
- cur.execute(
3755
- "CREATE TABLE tasks (seed INTEGER, scenario_id TEXT, loop_id TEXT, thread_id TEXT, difficulty TEXT)"
3756
- )
3757
- insert_stmt = (
3758
- "INSERT INTO tasks (seed, scenario_id, loop_id, thread_id, difficulty) VALUES (?,?,?,?,?)"
3759
- )
3760
- for seed in seeds_list:
3761
- row = rows.get(seed, {})
3762
- cur.execute(
3763
- insert_stmt,
3764
- [
3765
- seed,
3766
- row.get("scenario_id"),
3767
- row.get("loop_id"),
3768
- row.get("thread_id"),
3769
- row.get("difficulty"),
3770
- ],
3771
- )
3772
-
3773
- result = cur.execute(query)
3774
- fetched = result.fetchall()
3775
- if not fetched:
3776
- return []
3777
- description = result.description or []
3778
- col_names = [col[0] for col in description]
3779
- seeds_out: list[int] = []
3780
- for entry in fetched:
3781
- value = entry[col_names.index("seed")] if "seed" in col_names else entry[0]
3782
- try:
3783
- seeds_out.append(int(value))
3784
- except Exception as exc:
3785
- raise click.ClickException(
3786
- "metadata SQL query must return seed integers"
3787
- ) from exc
3788
- seeds_set = set(seeds_out)
3789
- return [seed for seed in seeds_list if seed in seeds_set]
3790
- except sqlite3.Error as exc:
3791
- raise click.ClickException(f"Failed to execute metadata SQL query: {exc}") from exc
3792
- finally:
3793
- conn.close()
3794
-
3795
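As a concrete illustration of the SQL path above (table contents and query are made up), the filter behaves like this standalone snippet:

import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE tasks (seed INTEGER, scenario_id TEXT, loop_id TEXT, thread_id TEXT, difficulty TEXT)"
)
conn.executemany(
    "INSERT INTO tasks VALUES (?,?,?,?,?)",
    [(0, "intro", "loop-a", "t0", "easy"), (1, "cave", "loop-b", "t1", "hard")],
)
# Seeds returned by the query are the ones the eval keeps.
kept = [row[0] for row in conn.execute("SELECT seed FROM tasks WHERE difficulty = 'easy'")]
print(kept)  # -> [0]
conn.close()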
- async def _run_eval() -> None:
3796
- nonlocal successes, failures, outcome_sum, outcome_count, outcome_correct, records, seed_values
3797
-
3798
- if trace_tracer is not None and trace_tracer.db is None:
3799
- await trace_tracer.initialize()
3800
-
3801
- if task_app_url is None:
3802
- transport = httpx.ASGITransport(app=app) # type: ignore[name-defined]
3803
- async_client = httpx.AsyncClient(
3804
- transport=cast(Any, transport),
3805
- base_url="http://eval.local",
3806
- timeout=300.0,
3807
- follow_redirects=True,
3808
- headers=headers,
3809
- )
3810
- else:
3811
- async_client = httpx.AsyncClient(
3812
- base_url=task_app_url,
3813
- timeout=300.0,
3814
- follow_redirects=True,
3815
- headers=headers,
3816
- )
3817
-
3818
- try:
3819
- taskset_payload: dict[str, Any] | None = None
3820
- try:
3821
- task_info_response = await async_client.get("/task_info")
3822
- except Exception:
3823
- task_info_response = None
3824
- if task_info_response is not None and task_info_response.status_code == 200:
3825
- with contextlib.suppress(Exception):
3826
- payload_json = task_info_response.json()
3827
- if isinstance(payload_json, dict) and "taskset" in payload_json:
3828
- taskset_payload = payload_json.get("taskset")
3829
- if not isinstance(taskset_payload, dict):
3830
- taskset_payload = None
3831
- elif isinstance(payload_json, dict):
3832
- taskset_payload = payload_json
3833
-
3834
- available_seeds = list(seed_values)
3835
- if metadata_sql_query or metadata_filters:
3836
- if not taskset_payload:
3837
- raise click.ClickException(
3838
- "Task metadata filters require the task app to expose /task_info metadata"
3839
- )
3840
- rows = _build_task_rows(taskset_payload)
3841
- if metadata_sql_query:
3842
- available_seeds = _apply_metadata_sql(rows, available_seeds, metadata_sql_query)
3843
- if metadata_filters:
3844
- available_seeds = _apply_metadata_filters(rows, available_seeds, metadata_filters)
3845
- if not available_seeds:
3846
- raise click.ClickException("No seeds match the provided metadata filters")
3847
- seed_values = available_seeds
3848
-
3849
- semaphore = asyncio.Semaphore(concurrency_limit)
3850
-
3851
- async def _run_seed(seed_val: int) -> None:
3852
- nonlocal successes, failures, outcome_sum, outcome_count, outcome_correct, records
3853
- # Read env_name and policy_name from config if available
3854
- env_name = cfg.get("env_name") or (cfg.get("env", {}).get("env_name") if isinstance(cfg.get("env"), dict) else None)
3855
- policy_name = cfg.get("policy_name") or (cfg.get("policy", {}).get("policy_name") if isinstance(cfg.get("policy"), dict) else None)
3856
- env_config_overrides = cfg.get("env_config", {}) if isinstance(cfg.get("env_config"), dict) else {}
3857
- policy_config_overrides = cfg.get("policy_config", {}) if isinstance(cfg.get("policy_config"), dict) else {}
3858
-
3859
- # Debug: print config parsing
3860
- if seed_val == 0:
3861
- click.echo(f"[DEBUG] env_name from config: {env_name}")
3862
- click.echo(f"[DEBUG] policy_name from config: {policy_name}")
3863
-
3864
- # Generate default ops sequence if not provided
3865
- max_llm_calls = policy_config_overrides.get("max_llm_calls", 10)
3866
- ops_list = cfg.get("ops", [])
3867
- if not ops_list:
3868
- # Generate default "agent, env" pairs for max_llm_calls
3869
- ops_list = ["agent", "env"] * int(max_llm_calls)
3870
-
3871
- body = {
3872
- "run_id": str(uuid.uuid4()),
3873
- "env": {"config": {"split": split, "index": seed_val, **env_config_overrides}, "seed": seed_val},
3874
- "policy": {
3875
- "policy_name": policy_name or selected_model,
3876
- "config": {"model": selected_model, **policy_overrides, **policy_config_overrides},
3877
- },
3878
- "ops": ops_list,
3879
- "record": {
3880
- "return_trace": cfg.get("return_trace", True),
3881
- "trace_format": cfg.get("trace_format", "structured"),
3882
- },
3883
- "mode": "eval", # RolloutMode.EVAL: use inference URLs as-is, no transformations
3884
- }
3885
- if env_name:
3886
- body["env"]["env_name"] = env_name
3887
-
3888
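Put together, the request sent to /rollout for one seed ends up looking roughly like this (model, env name, and run id are placeholders; ops defaults to alternating agent/env pairs):

example_body = {
    "run_id": "00000000-0000-0000-0000-000000000003",  # a fresh uuid4 in practice
    "env": {"env_name": "crafter", "seed": 3, "config": {"split": "val", "index": 3}},
    "policy": {"policy_name": "crafter-react", "config": {"model": "gpt-5", "temperature": 0.2}},
    "ops": ["agent", "env"] * 10,  # 10 == max_llm_calls
    "record": {"return_trace": True, "trace_format": "structured"},
    "mode": "eval",
}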
- # Debug: print the body being sent
3889
- if seed_val == 0:
3890
- click.echo(f"[DEBUG] rollout body env: {body['env']}")
3891
- click.echo(f"[DEBUG] rollout body policy: {body['policy']}")
3892
- click.echo(f"[DEBUG] rollout body mode: {body.get('mode', 'NOT SET')}")
3893
- rollout_elapsed: float | None = None
3894
- rollout_start = time.perf_counter()
3895
- try:
3896
- import logging
3897
- _log = logging.getLogger(__name__)
3898
- _log.info(f"[EVAL_BODY_DEBUG] Sending body with mode={body.get('mode')}")
3899
- async with semaphore:
3900
- response = await async_client.post("/rollout", json=body)
3901
- rollout_elapsed = time.perf_counter() - rollout_start
3902
- except Exception as exc:
3903
- failures += 1
3904
- click.echo(f"seed={seed_val} error={exc}")
3905
- return
3906
-
3907
- ok = 200 <= response.status_code < 300
3908
- if ok:
3909
- successes += 1
3910
- else:
3911
- failures += 1
3912
-
3913
- summary = [f"seed={seed_val}", f"status={response.status_code}"]
3914
- data: Any
3915
- try:
3916
- data = response.json()
3917
- except Exception:
3918
- data = None
3919
-
3920
- # Debug: print validation errors
3921
- if response.status_code == 422 and data:
3922
- click.echo(f"[DEBUG] 422 Validation Error: {data}")
3923
-
3924
- metrics: dict[str, Any] | None = None
3925
- completion: str | None = None
3926
- prompt_index: int | None = None
3927
- prompt_text: str | None = None
3928
- task_id: str | None = None
3929
- task_split: str | None = None
3930
- task_rubric_id: str | None = None
3931
-
3932
- trace_namespace: dict[str, Any] | None = None
3933
- session_trace_dict: dict[str, Any] | None = None
3934
-
3935
- if isinstance(data, dict):
3936
- import logging
3937
- _logger = logging.getLogger(__name__)
3938
- _logger.info(f"[EVAL_DEBUG] Response data keys: {list(data.keys())}")
3939
- if "detail" in data:
3940
- _logger.error(f"[EVAL_DEBUG] Task app returned error: {data['detail']}")
3941
- trace_namespace = data.get("trace")
3942
- _logger.info(f"[EVAL_DEBUG] trace_namespace type: {type(trace_namespace)}, value: {trace_namespace if not isinstance(trace_namespace, dict) else 'dict with keys: ' + str(list(trace_namespace.keys()) if trace_namespace else 'None')}")
3943
- if not isinstance(trace_namespace, dict):
3944
- raise RuntimeError(
3945
- "The 'synth-ai eval' command requires trace payloads in rollout responses. "
3946
- "Ensure the rollout request includes 'trace_format': 'structured' and 'return_trace': true, "
3947
- "and that task app tracing is enabled (TASKAPP_TRACING_ENABLED=1). "
3948
- "Note: This is specific to the eval command - general rollout endpoints don't require traces."
3949
- )
3950
- # Handle both "compact" and "full" trace formats:
3951
- # - compact: trace_namespace contains {session_id, metadata, ...}
3952
- # - full: trace_namespace IS the full session_trace dict
3953
- session_trace_dict = trace_namespace.get("session_trace")
3954
- if not isinstance(session_trace_dict, dict):
3955
- # If no session_trace key, assume "full" format where trace itself is the session_trace
3956
- if "session_id" in trace_namespace:
3957
- session_trace_dict = trace_namespace
3958
- else:
3959
- raise RuntimeError(
3960
- "The 'synth-ai eval' command requires 'session_trace' in the trace payload or a valid full trace format. "
3961
- "Ensure the task app is using tracing_v3 and returning structured trace data."
3962
- )
3963
- metrics = data.get("metrics") if isinstance(data.get("metrics"), dict) else None
3964
- if metrics:
3965
- mean_return = metrics.get("mean_return") or metrics.get("total_reward")
3966
- outcome = metrics.get("outcome_score")
3967
- if mean_return is not None:
3968
- summary.append(f"mean_return={mean_return}")
3969
- if outcome is not None:
3970
- summary.append(f"outcome={outcome}")
3971
- try:
3972
- val = float(outcome)
3973
- outcome_sum += val
3974
- outcome_count += 1
3975
- if val >= 0.5:
3976
- outcome_correct += 1
3977
- except Exception:
3978
- pass
3979
- trajs = (
3980
- data.get("trajectories")
3981
- if isinstance(data.get("trajectories"), list)
3982
- else None
3983
- )
3984
- if trajs:
3985
- first = trajs[0] if trajs else None
3986
- steps = first.get("steps") if isinstance(first, dict) else None
3987
- if isinstance(steps, list) and steps:
3988
- step0 = steps[0]
3989
- tool_calls = step0.get("tool_calls") or step0.get("tools") or []
3990
- if isinstance(tool_calls, list):
3991
- summary.append(f"tool_calls={len(tool_calls)}")
3992
- obs = step0.get("obs") if isinstance(step0, dict) else None
3993
- if isinstance(obs, dict):
3994
- idx_val = obs.get("prompt_index")
3995
- if isinstance(idx_val, int):
3996
- prompt_index = idx_val
3997
- prompt_raw = obs.get("prompt")
3998
- if isinstance(prompt_raw, str):
3999
- prompt_text = prompt_raw
4000
- if task_id is None:
4001
- candidate_id = obs.get("task_id")
4002
- if isinstance(candidate_id, str) and candidate_id:
4003
- task_id = candidate_id
4004
- if task_split is None:
4005
- candidate_split = obs.get("task_split")
4006
- if isinstance(candidate_split, str) and candidate_split:
4007
- task_split = candidate_split
4008
- if task_rubric_id is None:
4009
- candidate_rid = obs.get("task_rubric_id")
4010
- if isinstance(candidate_rid, str) and candidate_rid:
4011
- task_rubric_id = candidate_rid
4012
- final = first.get("final") if isinstance(first, dict) else None
4013
- if isinstance(final, dict):
4014
- final_obs = final.get("observation")
4015
- if isinstance(final_obs, dict):
4016
- comp_val = final_obs.get("completion")
4017
- if isinstance(comp_val, str):
4018
- completion = comp_val
4019
- if task_id is None:
4020
- candidate_id = final_obs.get("task_id")
4021
- if isinstance(candidate_id, str) and candidate_id:
4022
- task_id = candidate_id
4023
- if task_split is None:
4024
- candidate_split = final_obs.get("task_split")
4025
- if isinstance(candidate_split, str) and candidate_split:
4026
- task_split = candidate_split
4027
- if task_rubric_id is None:
4028
- candidate_rid = final_obs.get("task_rubric_id")
4029
- if isinstance(candidate_rid, str) and candidate_rid:
4030
- task_rubric_id = candidate_rid
4031
- final_info = final.get("info")
4032
- if isinstance(final_info, dict):
4033
- if task_id is None:
4034
- candidate_id = final_info.get("task_id")
4035
- if isinstance(candidate_id, str) and candidate_id:
4036
- task_id = candidate_id
4037
- if task_split is None:
4038
- candidate_split = final_info.get("task_split")
4039
- if isinstance(candidate_split, str) and candidate_split:
4040
- task_split = candidate_split
4041
- if task_rubric_id is None:
4042
- candidate_rid = final_info.get("task_rubric_id")
4043
- if isinstance(candidate_rid, str) and candidate_rid:
4044
- task_rubric_id = candidate_rid
4045
- if task_id:
4046
- summary.append(f"task_id={task_id}")
4047
- click.echo(" ".join(summary))
4048
- with contextlib.suppress(Exception):
4049
- click.echo(json.dumps(data, indent=2))
4050
- else:
4051
- click.echo(" ".join(summary))
4052
-
4053
- official_score = None
4054
- if isinstance(metrics, dict):
4055
- for key in ("mean_return", "total_reward", "outcome_score"):
4056
- val = metrics.get(key)
4057
- if isinstance(val, int | float):
4058
- official_score = float(val)
4059
- break
4060
- if official_score is None and isinstance(data, dict):
4061
- try:
4062
- reward_val = data["trajectories"][0]["steps"][0].get("reward")
4063
- if isinstance(reward_val, int | float):
4064
- official_score = float(reward_val)
4065
- except Exception:
4066
- pass
4067
-
4068
- if official_score is not None:
4069
- if official_score < 0.0:
4070
- official_score = 0.0
4071
- elif official_score > 1.0:
4072
- official_score = min(1.0, official_score)
4073
-
4074
- judge_scores: dict[str, float | None] = {}
4075
- judges_timings: dict[str, float | None] = {}
4076
- timings: dict[str, Any] = {
4077
- "rollout_s": rollout_elapsed,
4078
- "judges": judges_timings,
4079
- }
4080
- if judge_specs:
4081
- for spec in judge_specs:
4082
- score_value: float | None = None
4083
- judge_elapsed: float | None = None
4084
- # Run judges for all tasks (text-based and trajectory-based)
4085
- # Text-based tasks have completion, trajectory-based tasks use response
4086
- judge_payload = {
4087
- "seed": seed_val,
4088
- "prompt_index": prompt_index,
4089
- "prompt": prompt_text,
4090
- "completion": completion,
4091
- "metrics": metrics,
4092
- "response": data,
4093
- "trace": trace_namespace,
4094
- }
4095
- try:
4096
- judge_start = time.perf_counter()
4097
- result = spec.fn(judge_payload, **spec.kwargs)
4098
- judge_elapsed = time.perf_counter() - judge_start
4099
- if isinstance(result, int | float):
4100
- score_value = float(result)
4101
- except Exception as exc:
4102
- if judge_elapsed is None:
4103
- judge_elapsed = time.perf_counter() - judge_start
4104
- click.echo(f"seed={seed_val} judge[{spec.name}]_error={exc}")
4105
- judges_timings[spec.name] = judge_elapsed
4106
- judge_scores[spec.name] = score_value
4107
-
4108
- if trace_tracer is not None and trace_namespace:
4109
- storage_metadata = {
4110
- "eval_seed": seed_val,
4111
- "prompt_index": prompt_index,
4112
- "task_id": task_id,
4113
- "task_split": task_split,
4114
- "task_rubric_id": task_rubric_id,
4115
- "official_score": official_score,
4116
- "judge_scores": judge_scores,
4117
- "model": selected_model,
4118
- "prompt": prompt_text,
4119
- "completion": completion,
4120
- }
4121
- await _store_trace(trace_tracer, trace_namespace, storage_metadata)
4122
-
4123
- records.append(
4124
- {
4125
- "seed": seed_val,
4126
- "prompt_index": prompt_index,
4127
- "task_id": task_id,
4128
- "task_split": task_split,
4129
- "task_rubric_id": task_rubric_id,
4130
- "official_score": official_score,
4131
- "judge_scores": judge_scores,
4132
- "timings": timings,
4133
- }
4134
- )
4135
-
4136
- await asyncio.gather(*[_run_seed(seed_val) for seed_val in seed_values])
4137
- finally:
4138
- await async_client.aclose()
4139
-
4140
- try:
4141
- asyncio.run(_run_eval())
4142
- finally:
4143
- if trace_tracer is not None and trace_tracer.db is not None:
4144
- asyncio.run(trace_tracer.db.close())
4145
-
4146
- click.echo(
4147
- f"Eval complete: {successes} ok, {failures} failed; model={selected_model}, split={split}"
4148
- )
4149
-
4150
- if outcome_count > 0:
4151
- mean_outcome = outcome_sum / float(outcome_count)
4152
- frac_right = outcome_correct / float(outcome_count)
4153
- click.echo(
4154
- f"Outcome summary: correct={outcome_correct}/{outcome_count} ({frac_right:.2%}), mean_outcome={mean_outcome:.3f}"
4155
- )
4156
-
4157
- if records:
4158
- judge_specs = judge_specs or [] # ensure iterable
4159
- official_scores = [
4160
- r["official_score"] for r in records if r["official_score"] is not None
4161
- ]
4162
- if official_scores:
4163
- click.echo(f" Official mean: {sum(official_scores) / len(official_scores):.3f}")
4164
- else:
4165
- click.echo(" Official mean: n/a")
4166
-
4167
- for spec in judge_specs:
4168
- spec_scores = [
4169
- record["judge_scores"].get(spec.name)
4170
- for record in records
4171
- if record["judge_scores"].get(spec.name) is not None
4172
- ]
4173
- if spec_scores:
4174
- mean_spec = sum(spec_scores) / len(spec_scores)
4175
- click.echo(f" [{spec.name}] mean: {mean_spec:.3f}")
4176
- else:
4177
- click.echo(f" [{spec.name}] mean: n/a")
4178
-
4179
- paired = [
4180
- (
4181
- record["official_score"],
4182
- record["judge_scores"].get(spec.name),
4183
- )
4184
- for record in records
4185
- if record["official_score"] is not None
4186
- and record["judge_scores"].get(spec.name) is not None
4187
- ]
4188
- if len(paired) >= 2:
4189
- corr = _pearson(
4190
- [p[0] for p in paired if p[0] is not None],
4191
- [p[1] for p in paired if p[1] is not None],
4192
- )
4193
- if corr is not None:
4194
- click.echo(f" Pearson r: {corr:.3f}")
4195
- else:
4196
- click.echo(" Pearson r: undefined (zero variance)")
4197
- else:
4198
- click.echo(" Pearson r: n/a (need ≥2 paired scores)")
4199
-
4200
- header = ["Seed", "Prompt", "Official"]
4201
- header.extend(spec.name for spec in judge_specs)
4202
- rows: list[list[str]] = []
4203
- for record in sorted(records, key=lambda r: (r["seed"], r.get("prompt_index") or -1)):
4204
- seed_val = str(record["seed"])
4205
- prompt_idx = (
4206
- str(record["prompt_index"])
4207
- if record["prompt_index"] is not None
4208
- else "-"
4209
- )
4210
- official_val = (
4211
- f"{record['official_score']:.3f}"
4212
- if record["official_score"] is not None
4213
- else "-"
4214
- )
4215
- row = [seed_val, prompt_idx, official_val]
4216
- for spec in judge_specs:
4217
- score_val = record["judge_scores"].get(spec.name)
4218
- row.append(f"{score_val:.3f}" if isinstance(score_val, int | float) else "-")
4219
- rows.append(row)
4220
-
4221
- widths = [len(col) for col in header]
4222
- for row in rows:
4223
- for idx, cell in enumerate(row):
4224
- widths[idx] = max(widths[idx], len(cell))
4225
-
4226
- click.echo("")
4227
- click.echo(" ".join(h.ljust(widths[idx]) for idx, h in enumerate(header)))
4228
- click.echo(" ".join("-" * widths[idx] for idx in range(len(header))))
4229
- for row in rows:
4230
- click.echo(" ".join(cell.ljust(widths[idx]) for idx, cell in enumerate(row)))
4231
-
4232
-
4233
-
4234
- @click.command(
4235
- "filter",
4236
- help="Export filtered tracing sessions to SFT-ready JSONL based on a TOML config.",
4237
- )
4238
- @click.option(
4239
- "--config",
4240
- "config_path",
4241
- type=click.Path(),
4242
- required=True,
4243
- help="Path to TOML config describing the input trace DB, score thresholds, and output JSONL.",
4244
- )
4245
- def filter_command(config_path: str) -> None:
4246
- """Render tracing sessions that match filter rules into SFT JSONL.
4247
-
4248
- The TOML file should contain a `[filter]` table with at least:
4249
-
4250
- db = \"path/to/traces.db\" # sqlite path or URL (sqlite+aiosqlite://...)
4251
- output = \"ft_data/out.jsonl\" # destination JSONL
4252
-
4253
- Optional keys such as `splits`, `task_ids`, `models`, `min_official_score`, or
4254
- `min_judge_scores.my_judge = 0.7` allow you to narrow the dataset down to
4255
- high-quality traces. See `customers/agora_single_file/configs/filter_local.toml`
4256
- for a working example.
4257
- """
4258
- # Parse and validate TOML config
4259
- from synth_ai.task.config import FilterConfig
4260
-
4261
- if _toml is None:
4262
- raise click.ClickException("TOML parser not available; install tomli or use Python 3.11+")
4263
-
4264
- cfg_path = Path(config_path)
4265
- if not cfg_path.exists():
4266
- raise click.ClickException(f"Filter config not found: {cfg_path}")
4267
-
4268
- try:
4269
- config_data = _toml.loads(cfg_path.read_text(encoding="utf-8"))
4270
- except Exception as exc:
4271
- raise click.ClickException(f"Failed to parse TOML '{cfg_path}': {exc}") from exc
4272
-
4273
- filter_cfg_dict = config_data.get("filter") if isinstance(config_data, dict) else None
4274
- if not isinstance(filter_cfg_dict, dict):
4275
- raise click.ClickException("Config must contain a [filter] table")
4276
-
4277
- # Validate config with dataclass
4278
- try:
4279
- filter_cfg = FilterConfig.from_dict(filter_cfg_dict)
4280
- click.echo(f"✓ Config validated: db={filter_cfg.db}, output={filter_cfg.output}")
4281
- if filter_cfg.min_official_score is not None:
4282
- click.echo(f" → Filtering for official score >= {filter_cfg.min_official_score}")
4283
- if filter_cfg.limit:
4284
- click.echo(f" → Limiting to {filter_cfg.limit} examples")
4285
- except (ValueError, TypeError) as validation_error:
4286
- raise click.ClickException(f"Invalid filter config: {validation_error}") from validation_error
4287
-
4288
- # Use validated config
4289
- db_url = filter_cfg.get_db_url()
4290
- output_path = filter_cfg.get_output_path()
4291
-
4292
- # Extract validated fields from dataclass
4293
- splits = set(filter_cfg.splits)
4294
- task_ids = set(filter_cfg.task_ids)
4295
- models = set(filter_cfg.models)
4296
- min_official = filter_cfg.min_official_score
4297
- max_official = filter_cfg.max_official_score
4298
- min_judge_scores = filter_cfg.min_judge_scores
4299
- max_judge_scores = filter_cfg.max_judge_scores
4300
- # Note: min_created_at and max_created_at not yet in FilterConfig dataclass
4301
- min_created = _parse_datetime_for_trace(filter_cfg_dict.get("min_created_at"))
4302
- max_created = _parse_datetime_for_trace(filter_cfg_dict.get("max_created_at"))
4303
- limit = filter_cfg.limit
4304
-
4305
- def _score_ok(value: Any, min_val: Any, max_val: Any) -> bool:
4306
- try:
4307
- if value is None:
4308
- return min_val is None
4309
- value = float(value)
4310
- except Exception:
4311
- return False
4312
- if min_val is not None and value < float(min_val):
4313
- return False
4314
- return not (max_val is not None and value > float(max_val))
4315
-
4316
- async def _run_filter() -> None:
4317
- tracer = SessionTracer(db_url=db_url, auto_save=False)
4318
- await tracer.initialize()
4319
-
4320
- df = await tracer.db.query_traces(
4321
- "SELECT session_id, created_at, metadata FROM session_traces ORDER BY created_at"
4322
- )
4323
- if getattr(df, "empty", True):
4324
- raise click.ClickException("No traces found in database")
4325
-
4326
- sessions = df.to_dict("records")
4327
- accepted: list[dict[str, Any]] = []
4328
-
4329
- for row in sessions:
4330
- metadata_raw = row.get("metadata")
4331
- if isinstance(metadata_raw, str):
4332
- try:
4333
- metadata = json.loads(metadata_raw)
4334
- except Exception:
4335
- metadata = {}
4336
- elif isinstance(metadata_raw, dict):
4337
- metadata = dict(metadata_raw)
4338
- else:
4339
- metadata = {}
4340
-
4341
- created_at_raw = row.get("created_at")
4342
- created_at_dt = _parse_datetime_for_trace(created_at_raw)
4343
-
4344
- session_id = row.get("session_id")
4345
-
4346
- if splits and metadata.get("task_split") not in splits:
4347
- continue
4348
- if task_ids and metadata.get("task_id") not in task_ids:
4349
- continue
4350
- if models and metadata.get("model") not in models:
4351
- continue
4352
-
4353
- if min_created and (created_at_dt is None or created_at_dt < min_created):
4354
- continue
4355
- if max_created and (created_at_dt is None or created_at_dt > max_created):
4356
- continue
4357
-
4358
- # Check against outcome_rewards if score filter is set
4359
- total_reward = None
4360
- achievements_count = None
4361
- if min_official is not None or max_official is not None:
4362
- reward_query = "SELECT total_reward, achievements_count FROM outcome_rewards WHERE session_id = :session_id"
4363
- reward_rows = await tracer.db.query_traces(reward_query, {"session_id": session_id})
4364
- reward_records = reward_rows.to_dict("records") if hasattr(reward_rows, "to_dict") else []
4365
- if reward_records:
4366
- total_reward = reward_records[0].get("total_reward")
4367
- achievements_count = reward_records[0].get("achievements_count")
4368
- if not _score_ok(total_reward, min_official, max_official):
4369
- continue
4370
- elif min_official is not None:
4371
- # No reward found, but score filter requires it
4372
- continue
4373
-
4374
- judge_scores = metadata.get("judge_scores") or {}
4375
- include = True
4376
- for judge_name, threshold in (min_judge_scores or {}).items():
4377
- if not _score_ok(judge_scores.get(judge_name), threshold, None):
4378
- include = False
4379
- break
4380
- if not include:
4381
- continue
4382
- for judge_name, threshold in (max_judge_scores or {}).items():
4383
- if not _score_ok(judge_scores.get(judge_name), None, threshold):
4384
- include = False
4385
- break
4386
- if not include:
4387
- continue
4388
-
4389
- # Query messages for this session
4390
- messages_query = """
4391
- SELECT message_type, content, timestamp
4392
- FROM messages
4393
- WHERE session_id = :session_id
4394
- ORDER BY timestamp ASC, id ASC
4395
- """
4396
- msg_df = await tracer.db.query_traces(messages_query, {"session_id": session_id})
4397
- message_rows = msg_df.to_dict("records") if hasattr(msg_df, "to_dict") else []
4398
-
4399
- if not message_rows:
4400
- # Fallback: check if prompt/completion in metadata (old format)
4401
- prompt = metadata.get("prompt") or ""
4402
- completion = metadata.get("completion") or ""
4403
- if prompt and completion:
4404
- record = {
4405
- "messages": [
4406
- {"role": "user", "content": str(prompt)},
4407
- {"role": "assistant", "content": str(completion)},
4408
- ],
4409
- "metadata": {
4410
- "session_id": session_id,
4411
- "env_name": metadata.get("env_name"),
4412
- "policy_name": metadata.get("policy_name"),
4413
- "seed": metadata.get("seed"),
4414
- "total_reward": total_reward,
4415
- "achievements_count": achievements_count,
4416
- "model": metadata.get("model"),
4417
- "created_at": created_at_dt.isoformat() if created_at_dt else created_at_raw,
4418
- },
4419
- }
4420
- accepted.append(record)
4421
- continue
4422
-
4423
- # Extract user/assistant pairs from messages
4424
- for i, msg_row in enumerate(message_rows):
4425
- msg_type = msg_row.get("message_type")
4426
- content_raw = msg_row.get("content")
4427
-
4428
- # Look for user message
4429
- if msg_type in ("user", "policy_user_prompt"):
4430
- # Find the next assistant message, stopping early if a new policy_system_prompt starts first
4431
- assistant_msg = None
4432
- for j in range(i + 1, len(message_rows)):
4433
- next_type = message_rows[j].get("message_type")
4434
- if next_type in ("assistant", "policy_system_prompt"):
4435
- if next_type == "assistant":
4436
- assistant_msg = message_rows[j]
4437
- break
4438
-
4439
- # Parse content
4440
- try:
4441
- user_content = json.loads(content_raw) if isinstance(content_raw, str) else content_raw
4442
- except Exception:
4443
- user_content = content_raw
4444
-
4445
- # Extract text from structured content
4446
- def extract_text(content: Any) -> str:
4447
- if isinstance(content, str):
4448
- return content
4449
- if isinstance(content, dict):
4450
- # Try payload.content for user prompts
4451
- if "payload" in content and isinstance(content["payload"], dict):
4452
- payload = content["payload"]
4453
- if "content" in payload:
4454
- return extract_text(payload["content"])
4455
- # Try common keys
4456
- for key in ["text", "content", "content_text"]:
4457
- if key in content:
4458
- val = content[key]
4459
- if isinstance(val, str):
4460
- return val
4461
- return json.dumps(content)
4462
- if isinstance(content, list):
4463
- # Multimodal content - concatenate text parts
4464
- parts = []
4465
- for item in content:
4466
- if isinstance(item, dict) and item.get("type") == "text":
4467
- parts.append(item.get("text", ""))
4468
- return " ".join(parts) if parts else str(content)
4469
- return str(content)
4470
-
4471
- user_text = extract_text(user_content)
4472
-
4473
- # The assistant reply (often just tool calls) may not be recorded; fall back to a placeholder when missing
4474
- assistant_text = ""
4475
- if assistant_msg:
4476
- assistant_content_raw = assistant_msg.get("content")
4477
- try:
4478
- assistant_content = json.loads(assistant_content_raw) if isinstance(assistant_content_raw, str) else assistant_content_raw
4479
- except Exception:
4480
- assistant_content = assistant_content_raw
4481
- assistant_text = extract_text(assistant_content)
4482
-
4483
- if not user_text:
4484
- continue
4485
-
4486
- record = {
4487
- "messages": [
4488
- {"role": "user", "content": user_text},
4489
- {"role": "assistant", "content": assistant_text if assistant_text else "[no response recorded]"},
4490
- ],
4491
- "metadata": {
4492
- "session_id": session_id,
4493
- "env_name": metadata.get("env_name"),
4494
- "policy_name": metadata.get("policy_name"),
4495
- "seed": metadata.get("seed"),
4496
- "total_reward": total_reward,
4497
- "achievements_count": achievements_count,
4498
- "model": metadata.get("model"),
4499
- "created_at": created_at_dt.isoformat() if created_at_dt else created_at_raw,
4500
- },
4501
- }
4502
- accepted.append(record)
4503
-
4504
- if not accepted:
4505
- raise click.ClickException("No sessions matched the provided filters")
4506
-
4507
- if limit is not None and limit > 0:
4508
- accepted = accepted[:limit]
4509
-
4510
- output_path.parent.mkdir(parents=True, exist_ok=True)
4511
- with output_path.open("w", encoding="utf-8") as handle:
4512
- for item in accepted:
4513
- handle.write(json.dumps(item, ensure_ascii=False))
4514
- handle.write("\n")
4515
-
4516
- click.echo(f"Wrote {len(accepted)} examples -> {output_path}")
4517
- await tracer.db.close()
3128
+ eval_command = eval_core.command
4518
3129
 
4519
- asyncio.run(_run_filter())
3130
+ filter_command = filter_core.command
4520
3131
 
4521
3132
 
4522
3133
  def register_eval(cli: click.Group) -> None: