synth-ai 0.2.16__py3-none-any.whl → 0.2.19__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of synth-ai might be problematic; review the advisory details before upgrading.

Files changed (299):
  1. examples/analyze_semantic_words.sh +2 -2
  2. examples/baseline/banking77_baseline.py +204 -0
  3. examples/baseline/crafter_baseline.py +407 -0
  4. examples/baseline/pokemon_red_baseline.py +326 -0
  5. examples/baseline/simple_baseline.py +56 -0
  6. examples/baseline/warming_up_to_rl_baseline.py +239 -0
  7. examples/blog_posts/gepa/README.md +355 -0
  8. examples/blog_posts/gepa/configs/banking77_gepa_local.toml +95 -0
  9. examples/blog_posts/gepa/configs/banking77_gepa_test.toml +82 -0
  10. examples/blog_posts/gepa/configs/banking77_mipro_local.toml +52 -0
  11. examples/blog_posts/gepa/configs/hotpotqa_gepa_local.toml +59 -0
  12. examples/blog_posts/gepa/configs/hotpotqa_gepa_qwen.toml +36 -0
  13. examples/blog_posts/gepa/configs/hotpotqa_mipro_local.toml +53 -0
  14. examples/blog_posts/gepa/configs/hover_gepa_local.toml +59 -0
  15. examples/blog_posts/gepa/configs/hover_gepa_qwen.toml +36 -0
  16. examples/blog_posts/gepa/configs/hover_mipro_local.toml +53 -0
  17. examples/blog_posts/gepa/configs/ifbench_gepa_local.toml +59 -0
  18. examples/blog_posts/gepa/configs/ifbench_gepa_qwen.toml +36 -0
  19. examples/blog_posts/gepa/configs/ifbench_mipro_local.toml +53 -0
  20. examples/blog_posts/gepa/configs/pupa_gepa_local.toml +60 -0
  21. examples/blog_posts/gepa/configs/pupa_mipro_local.toml +54 -0
  22. examples/blog_posts/gepa/deploy_banking77_task_app.sh +41 -0
  23. examples/blog_posts/gepa/gepa_baseline.py +204 -0
  24. examples/blog_posts/gepa/query_prompts_example.py +97 -0
  25. examples/blog_posts/gepa/run_gepa_banking77.sh +87 -0
  26. examples/blog_posts/gepa/task_apps.py +105 -0
  27. examples/blog_posts/gepa/test_gepa_local.sh +67 -0
  28. examples/blog_posts/gepa/verify_banking77_setup.sh +123 -0
  29. examples/blog_posts/pokemon_vl/README.md +98 -0
  30. examples/blog_posts/pokemon_vl/configs/eval_gpt5nano.toml +26 -0
  31. examples/blog_posts/pokemon_vl/configs/eval_qwen3_vl.toml +27 -0
  32. examples/blog_posts/pokemon_vl/configs/eval_rl_final.toml +24 -0
  33. examples/blog_posts/pokemon_vl/configs/filter_high_reward.toml +10 -0
  34. examples/blog_posts/pokemon_vl/configs/train_rl_from_sft.toml +43 -0
  35. examples/blog_posts/pokemon_vl/configs/train_sft_qwen4b_vl.toml +40 -0
  36. examples/blog_posts/pokemon_vl/extract_images.py +239 -0
  37. examples/blog_posts/pokemon_vl/pokemon_vl_baseline.py +326 -0
  38. examples/blog_posts/pokemon_vl/run_eval_extract_images.py +209 -0
  39. examples/blog_posts/pokemon_vl/run_qwen_eval_extract_images.py +212 -0
  40. examples/blog_posts/pokemon_vl/text_box_analysis.md +106 -0
  41. examples/blog_posts/warming_up_to_rl/ARCHITECTURE.md +195 -0
  42. examples/blog_posts/warming_up_to_rl/FINAL_TEST_RESULTS.md +127 -0
  43. examples/blog_posts/warming_up_to_rl/INFERENCE_SUCCESS.md +132 -0
  44. examples/blog_posts/warming_up_to_rl/README.md +158 -0
  45. examples/blog_posts/warming_up_to_rl/SMOKE_TESTING.md +164 -0
  46. examples/blog_posts/warming_up_to_rl/SMOKE_TEST_COMPLETE.md +253 -0
  47. examples/blog_posts/warming_up_to_rl/configs/eval_baseline_qwen32b_10x20.toml +25 -0
  48. examples/blog_posts/warming_up_to_rl/configs/eval_ft_qwen4b.toml +25 -0
  49. examples/blog_posts/warming_up_to_rl/configs/eval_ft_qwen4b_10x20.toml +26 -0
  50. examples/blog_posts/warming_up_to_rl/configs/eval_groq_qwen32b.toml +25 -0
  51. examples/blog_posts/warming_up_to_rl/configs/eval_openai_gpt_oss_120b.toml +29 -0
  52. examples/blog_posts/warming_up_to_rl/configs/filter_high_reward_dataset.toml +10 -0
  53. examples/blog_posts/warming_up_to_rl/configs/smoke_test.toml +75 -0
  54. examples/blog_posts/warming_up_to_rl/configs/train_rl_from_sft.toml +91 -0
  55. examples/blog_posts/warming_up_to_rl/configs/train_sft_qwen4b.toml +40 -0
  56. examples/blog_posts/warming_up_to_rl/warming_up_to_rl_baseline.py +187 -0
  57. examples/dev/qwen3_32b_qlora_4xh100.toml +5 -0
  58. examples/multi_step/configs/VERILOG_REWARDS.md +4 -0
  59. examples/multi_step/configs/VERILOG_RL_CHECKLIST.md +4 -0
  60. examples/multi_step/configs/crafter_rl_outcome.toml +2 -1
  61. examples/multi_step/configs/crafter_rl_stepwise_hosted_judge.toml +65 -107
  62. examples/multi_step/configs/crafter_rl_stepwise_shaped.toml +2 -1
  63. examples/multi_step/configs/crafter_rl_stepwise_simple.toml +2 -1
  64. examples/multi_step/configs/crafter_rl_stepwise_simple_NEW_FORMAT.toml +105 -0
  65. examples/multi_step/configs/verilog_rl_lora.toml +80 -123
  66. examples/qwen_coder/configs/coder_lora_30b.toml +1 -3
  67. examples/qwen_coder/configs/coder_lora_4b.toml +4 -1
  68. examples/qwen_coder/configs/coder_lora_small.toml +1 -3
  69. examples/qwen_vl/README.md +10 -12
  70. examples/qwen_vl/SETUP_COMPLETE.md +7 -8
  71. examples/qwen_vl/VISION_TESTS_COMPLETE.md +2 -3
  72. examples/qwen_vl/collect_data_via_cli.md +76 -84
  73. examples/qwen_vl/collect_vision_traces.py +4 -4
  74. examples/qwen_vl/configs/crafter_rl_vision_qwen3vl4b.toml +40 -57
  75. examples/qwen_vl/configs/crafter_vlm_sft_example.toml +1 -2
  76. examples/qwen_vl/configs/eval_gpt4o_mini_vision.toml +20 -37
  77. examples/qwen_vl/configs/eval_gpt5nano_vision.toml +21 -40
  78. examples/qwen_vl/configs/eval_qwen3vl_vision.toml +26 -0
  79. examples/qwen_vl/configs/{filter_qwen2vl_sft.toml → filter_qwen3vl_sft.toml} +4 -5
  80. examples/qwen_vl/configs/filter_vision_sft.toml +2 -3
  81. examples/qwen_vl/crafter_qwen_vl_agent.py +5 -5
  82. examples/qwen_vl/run_vision_comparison.sh +6 -7
  83. examples/rl/README.md +5 -5
  84. examples/rl/configs/rl_from_base_qwen.toml +26 -1
  85. examples/rl/configs/rl_from_base_qwen17.toml +6 -2
  86. examples/rl/task_app/README.md +1 -2
  87. examples/rl/task_app/math_single_step.py +2 -2
  88. examples/run_crafter_demo.sh +2 -2
  89. examples/sft/README.md +1 -1
  90. examples/sft/configs/crafter_fft_qwen0p6b.toml +4 -1
  91. examples/sft/configs/crafter_lora_qwen0p6b.toml +4 -1
  92. examples/swe/task_app/README.md +32 -2
  93. examples/swe/task_app/grpo_swe_mini.py +4 -0
  94. examples/swe/task_app/hosted/envs/crafter/react_agent.py +1 -1
  95. examples/swe/task_app/hosted/envs/mini_swe/environment.py +37 -10
  96. examples/swe/task_app/hosted/inference/openai_client.py +4 -38
  97. examples/swe/task_app/hosted/policy_routes.py +17 -0
  98. examples/swe/task_app/hosted/rollout.py +4 -2
  99. examples/swe/task_app/morph_backend.py +178 -0
  100. examples/task_apps/banking77/__init__.py +6 -0
  101. examples/task_apps/banking77/banking77_task_app.py +841 -0
  102. examples/task_apps/banking77/deploy_wrapper.py +46 -0
  103. examples/task_apps/crafter/CREATE_SFT_DATASET.md +4 -0
  104. examples/task_apps/crafter/FILTER_COMMAND_STATUS.md +4 -0
  105. examples/task_apps/crafter/FILTER_COMMAND_SUCCESS.md +4 -0
  106. examples/task_apps/crafter/task_app/README.md +1 -1
  107. examples/task_apps/crafter/task_app/grpo_crafter.py +90 -5
  108. examples/task_apps/crafter/task_app/grpo_crafter_task_app.py +1 -1
  109. examples/task_apps/crafter/task_app/synth_envs_hosted/envs/crafter/policy.py +4 -26
  110. examples/task_apps/crafter/task_app/synth_envs_hosted/envs/crafter/react_agent.py +1 -2
  111. examples/task_apps/crafter/task_app/synth_envs_hosted/hosted_app.py +49 -0
  112. examples/task_apps/crafter/task_app/synth_envs_hosted/inference/openai_client.py +372 -107
  113. examples/task_apps/crafter/task_app/synth_envs_hosted/policy_routes.py +81 -12
  114. examples/task_apps/crafter/task_app/synth_envs_hosted/rollout.py +82 -11
  115. examples/task_apps/crafter/task_app/synth_envs_hosted/utils.py +194 -1
  116. examples/task_apps/enron/task_app/grpo_enron_task_app.py +1 -1
  117. examples/task_apps/gepa_benchmarks/__init__.py +7 -0
  118. examples/task_apps/gepa_benchmarks/common.py +260 -0
  119. examples/task_apps/gepa_benchmarks/hotpotqa_task_app.py +507 -0
  120. examples/task_apps/gepa_benchmarks/hover_task_app.py +436 -0
  121. examples/task_apps/gepa_benchmarks/ifbench_task_app.py +563 -0
  122. examples/task_apps/gepa_benchmarks/pupa_task_app.py +460 -0
  123. examples/task_apps/math/README.md +1 -2
  124. examples/task_apps/pokemon_red/README.md +3 -4
  125. examples/task_apps/pokemon_red/README_IMAGE_ONLY_EVAL.md +4 -0
  126. examples/task_apps/pokemon_red/eval_image_only_gpt4o.toml +6 -5
  127. examples/task_apps/pokemon_red/eval_pokemon_red_policy.py +1 -2
  128. examples/task_apps/pokemon_red/task_app.py +288 -39
  129. examples/task_apps/sokoban/README.md +2 -3
  130. examples/task_apps/verilog/eval_groq_qwen32b.toml +12 -14
  131. examples/task_apps/verilog/task_app/grpo_verilog_task_app.py +1 -1
  132. examples/vlm/configs/crafter_vlm_gpt4o.toml +4 -1
  133. examples/warming_up_to_rl/configs/crafter_fft.toml +4 -1
  134. examples/warming_up_to_rl/configs/crafter_fft_4b.toml +0 -2
  135. examples/warming_up_to_rl/configs/rl_from_base_qwen4b.toml +3 -2
  136. examples/warming_up_to_rl/run_local_rollout_traced.py +1 -1
  137. examples/warming_up_to_rl/task_app/README.md +1 -1
  138. examples/warming_up_to_rl/task_app/grpo_crafter.py +185 -5
  139. examples/warming_up_to_rl/task_app/grpo_crafter_task_app.py +1 -1
  140. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/policy.py +3 -27
  141. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/react_agent.py +1 -1
  142. examples/warming_up_to_rl/task_app/synth_envs_hosted/hosted_app.py +49 -0
  143. examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/openai_client.py +156 -45
  144. examples/warming_up_to_rl/task_app/synth_envs_hosted/policy_routes.py +37 -4
  145. examples/warming_up_to_rl/task_app/synth_envs_hosted/rollout.py +33 -3
  146. examples/warming_up_to_rl/task_app/synth_envs_hosted/utils.py +67 -0
  147. examples/workflows/math_rl/configs/rl_from_base_qwen.toml +27 -0
  148. examples/workflows/math_rl/configs/rl_from_base_qwen17.toml +6 -0
  149. synth_ai/api/train/builders.py +99 -4
  150. synth_ai/api/train/cli.py +516 -26
  151. synth_ai/api/train/config_finder.py +13 -2
  152. synth_ai/api/train/configs/__init__.py +23 -2
  153. synth_ai/api/train/configs/prompt_learning.py +442 -0
  154. synth_ai/api/train/configs/rl.py +61 -7
  155. synth_ai/api/train/configs/sft.py +6 -2
  156. synth_ai/api/train/configs/shared.py +59 -2
  157. synth_ai/api/train/task_app.py +1 -1
  158. synth_ai/api/train/validators.py +277 -0
  159. synth_ai/auth/credentials.py +119 -0
  160. synth_ai/baseline/__init__.py +25 -0
  161. synth_ai/baseline/config.py +209 -0
  162. synth_ai/baseline/discovery.py +214 -0
  163. synth_ai/baseline/execution.py +146 -0
  164. synth_ai/cli/__init__.py +94 -18
  165. synth_ai/cli/__main__.py +0 -0
  166. synth_ai/cli/claude.py +70 -0
  167. synth_ai/cli/codex.py +84 -0
  168. synth_ai/cli/commands/__init__.py +18 -0
  169. synth_ai/cli/commands/baseline/__init__.py +12 -0
  170. synth_ai/cli/commands/baseline/core.py +637 -0
  171. synth_ai/cli/commands/baseline/list.py +93 -0
  172. synth_ai/cli/commands/demo/__init__.py +6 -0
  173. synth_ai/cli/commands/demo/core.py +163 -0
  174. synth_ai/cli/commands/eval/__init__.py +19 -0
  175. synth_ai/cli/commands/eval/core.py +1112 -0
  176. synth_ai/cli/commands/eval/errors.py +81 -0
  177. synth_ai/cli/commands/eval/validation.py +133 -0
  178. synth_ai/cli/commands/filter/__init__.py +12 -0
  179. synth_ai/cli/commands/filter/core.py +424 -0
  180. synth_ai/cli/commands/filter/errors.py +55 -0
  181. synth_ai/cli/commands/filter/validation.py +77 -0
  182. synth_ai/cli/commands/help/__init__.py +177 -0
  183. synth_ai/cli/commands/help/core.py +72 -0
  184. synth_ai/cli/commands/smoke/__init__.py +7 -0
  185. synth_ai/cli/commands/smoke/core.py +1436 -0
  186. synth_ai/cli/commands/status/__init__.py +64 -0
  187. synth_ai/cli/commands/status/client.py +192 -0
  188. synth_ai/cli/commands/status/config.py +92 -0
  189. synth_ai/cli/commands/status/errors.py +20 -0
  190. synth_ai/cli/commands/status/formatters.py +164 -0
  191. synth_ai/cli/commands/status/subcommands/__init__.py +9 -0
  192. synth_ai/cli/commands/status/subcommands/files.py +79 -0
  193. synth_ai/cli/commands/status/subcommands/jobs.py +334 -0
  194. synth_ai/cli/commands/status/subcommands/models.py +79 -0
  195. synth_ai/cli/commands/status/subcommands/pricing.py +22 -0
  196. synth_ai/cli/commands/status/subcommands/runs.py +81 -0
  197. synth_ai/cli/commands/status/subcommands/summary.py +47 -0
  198. synth_ai/cli/commands/status/subcommands/usage.py +203 -0
  199. synth_ai/cli/commands/status/utils.py +114 -0
  200. synth_ai/cli/commands/train/__init__.py +53 -0
  201. synth_ai/cli/commands/train/core.py +21 -0
  202. synth_ai/cli/commands/train/errors.py +117 -0
  203. synth_ai/cli/commands/train/judge_schemas.py +200 -0
  204. synth_ai/cli/commands/train/judge_validation.py +305 -0
  205. synth_ai/cli/commands/train/validation.py +386 -0
  206. synth_ai/cli/demo.py +30 -158
  207. synth_ai/cli/deploy/__init__.py +43 -0
  208. synth_ai/cli/deploy.py +162 -0
  209. synth_ai/cli/eval/__init__.py +36 -0
  210. synth_ai/cli/eval/core.py +5 -0
  211. synth_ai/cli/eval/errors.py +31 -0
  212. synth_ai/cli/eval/validation.py +5 -0
  213. synth_ai/cli/filter/__init__.py +28 -0
  214. synth_ai/cli/filter/core.py +5 -0
  215. synth_ai/cli/filter/errors.py +23 -0
  216. synth_ai/cli/filter/validation.py +5 -0
  217. synth_ai/cli/legacy_root_backup.py +14 -8
  218. synth_ai/cli/modal_serve/__init__.py +12 -0
  219. synth_ai/cli/modal_serve/core.py +14 -0
  220. synth_ai/cli/modal_serve/errors.py +8 -0
  221. synth_ai/cli/modal_serve/validation.py +11 -0
  222. synth_ai/cli/opencode.py +107 -0
  223. synth_ai/cli/root.py +9 -5
  224. synth_ai/cli/serve/__init__.py +12 -0
  225. synth_ai/cli/serve/core.py +14 -0
  226. synth_ai/cli/serve/errors.py +8 -0
  227. synth_ai/cli/serve/validation.py +11 -0
  228. synth_ai/cli/setup.py +20 -265
  229. synth_ai/cli/status.py +7 -126
  230. synth_ai/cli/task_app_deploy.py +1 -10
  231. synth_ai/cli/task_app_modal_serve.py +4 -9
  232. synth_ai/cli/task_app_serve.py +4 -11
  233. synth_ai/cli/task_apps.py +51 -1480
  234. synth_ai/cli/train/__init__.py +12 -0
  235. synth_ai/cli/train/core.py +21 -0
  236. synth_ai/cli/train/errors.py +8 -0
  237. synth_ai/cli/train/validation.py +24 -0
  238. synth_ai/cli/train.py +1 -14
  239. synth_ai/demos/crafter/grpo_crafter_task_app.py +1 -1
  240. synth_ai/demos/demo_task_apps/crafter/grpo_crafter_task_app.py +1 -1
  241. synth_ai/environments/examples/crafter_classic/engine_deterministic_patch.py +7 -4
  242. synth_ai/environments/examples/crafter_classic/engine_serialization_patch_v3.py +9 -5
  243. synth_ai/environments/examples/crafter_classic/world_config_patch_simple.py +4 -3
  244. synth_ai/environments/examples/red/engine.py +33 -12
  245. synth_ai/environments/examples/red/engine_helpers/reward_components.py +151 -179
  246. synth_ai/environments/examples/red/environment.py +26 -0
  247. synth_ai/environments/examples/red/trace_hooks_v3.py +168 -0
  248. synth_ai/http.py +12 -0
  249. synth_ai/judge_schemas.py +10 -10
  250. synth_ai/learning/__init__.py +10 -0
  251. synth_ai/learning/prompt_learning_client.py +276 -0
  252. synth_ai/learning/prompt_learning_types.py +184 -0
  253. synth_ai/learning/rl/client.py +3 -1
  254. synth_ai/pricing/__init__.py +2 -0
  255. synth_ai/pricing/model_pricing.py +57 -0
  256. synth_ai/streaming/__init__.py +29 -0
  257. synth_ai/streaming/config.py +94 -0
  258. synth_ai/streaming/handlers.py +518 -0
  259. synth_ai/streaming/streamer.py +320 -0
  260. synth_ai/streaming/types.py +95 -0
  261. synth_ai/task/apps/__init__.py +1 -0
  262. synth_ai/task/config.py +2 -0
  263. synth_ai/task/tracing_utils.py +25 -25
  264. synth_ai/task/validators.py +45 -9
  265. synth_ai/task_app_cfgs.py +21 -0
  266. synth_ai/tracing_v3/config.py +162 -19
  267. synth_ai/tracing_v3/constants.py +1 -1
  268. synth_ai/tracing_v3/db_config.py +24 -38
  269. synth_ai/tracing_v3/migration_helper.py +1 -2
  270. synth_ai/tracing_v3/storage/config.py +47 -13
  271. synth_ai/tracing_v3/storage/factory.py +3 -3
  272. synth_ai/tracing_v3/turso/daemon.py +113 -11
  273. synth_ai/tracing_v3/turso/native_manager.py +92 -16
  274. synth_ai/types.py +8 -0
  275. synth_ai/urls.py +11 -0
  276. synth_ai/utils/__init__.py +30 -1
  277. synth_ai/utils/agents.py +74 -0
  278. synth_ai/utils/bin.py +39 -0
  279. synth_ai/utils/cli.py +149 -5
  280. synth_ai/utils/env.py +40 -33
  281. synth_ai/utils/http.py +4 -1
  282. synth_ai/utils/json.py +72 -0
  283. synth_ai/utils/modal.py +285 -3
  284. synth_ai/utils/paths.py +48 -0
  285. synth_ai/utils/uvicorn.py +113 -0
  286. {synth_ai-0.2.16.dist-info → synth_ai-0.2.19.dist-info}/METADATA +109 -6
  287. {synth_ai-0.2.16.dist-info → synth_ai-0.2.19.dist-info}/RECORD +291 -142
  288. examples/qwen_vl/configs/eval_qwen2vl_vision.toml +0 -44
  289. synth_ai/cli/tui.py +0 -62
  290. synth_ai/tui/__init__.py +0 -5
  291. synth_ai/tui/__main__.py +0 -13
  292. synth_ai/tui/cli/__init__.py +0 -1
  293. synth_ai/tui/cli/query_experiments.py +0 -164
  294. synth_ai/tui/cli/query_experiments_v3.py +0 -164
  295. synth_ai/tui/dashboard.py +0 -911
  296. {synth_ai-0.2.16.dist-info → synth_ai-0.2.19.dist-info}/WHEEL +0 -0
  297. {synth_ai-0.2.16.dist-info → synth_ai-0.2.19.dist-info}/entry_points.txt +0 -0
  298. {synth_ai-0.2.16.dist-info → synth_ai-0.2.19.dist-info}/licenses/LICENSE +0 -0
  299. {synth_ai-0.2.16.dist-info → synth_ai-0.2.19.dist-info}/top_level.txt +0 -0
synth_ai/cli/task_apps.py CHANGED
@@ -9,19 +9,16 @@ import hashlib
9
9
  import importlib
10
10
  import importlib.util
11
11
  import inspect
12
- import json
13
12
  import os
14
13
  import shlex
15
14
  import shutil
16
15
  import signal
17
- import sqlite3
18
16
  import subprocess
19
17
  import sys
20
18
  import tempfile
21
19
  import textwrap
22
20
  import time
23
21
  import types
24
- import uuid
25
22
  from collections.abc import Callable, Iterable, Iterator, Sequence
26
23
  from dataclasses import dataclass
27
24
  from datetime import UTC, datetime
@@ -35,6 +32,8 @@ except Exception: # pragma: no cover - fallback
35
32
 
36
33
  import click
37
34
  from click.exceptions import Abort
35
+ from synth_ai.cli.commands.eval import core as eval_core
36
+ from synth_ai.cli.commands.filter import core as filter_core
38
37
 
39
38
  # Tracing imports - make conditional for optional dependencies
40
39
  try:
@@ -269,20 +268,25 @@ def _markov_message_from_dict(payload: dict[str, Any]) -> SessionEventMarkovBlan
269
268
  json_payload=content_payload.get("json_payload"),
270
269
  )
271
270
  raw_type = (payload.get("message_type") or "").lower()
272
- if raw_type == "observation":
271
+ original_type = payload.get("message_type") or raw_type
272
+
273
+ if raw_type in ("observation", "policy_system_prompt"):
273
274
  normalized_type = "system"
274
- elif raw_type == "action":
275
+ elif raw_type in ("action", "policy_tool_call"):
275
276
  normalized_type = "assistant"
276
277
  elif raw_type in {"user", "assistant", "system", "tool_use", "tool_result"}:
277
278
  normalized_type = raw_type
278
279
  else:
279
280
  normalized_type = "system"
280
281
 
282
+ metadata = dict(payload.get("metadata") or {})
283
+ metadata["original_message_type"] = original_type
284
+
281
285
  return SessionEventMarkovBlanketMessage(
282
286
  content=content,
283
287
  message_type=normalized_type,
284
288
  time_record=_time_record_from_dict(payload.get("time_record")),
285
- metadata=payload.get("metadata") or {},
289
+ metadata=metadata,
286
290
  )
287
291
 
288
292
 
@@ -506,49 +510,6 @@ def _candidate_search_roots() -> list[Path]:
506
510
  return ordered
507
511
 
508
512
 
509
- def _eval_config_sort_key(path: Path) -> tuple[int, int, int, str]:
510
- name = path.name.lower()
511
- parent_names = {p.name.lower() for p in path.parents}
512
- in_configs = 0 if "configs" in parent_names else 1
513
- in_examples = 0 if "examples" in parent_names else 1
514
- starts_eval = 0 if name.startswith("eval") else 1
515
- return (in_configs, in_examples, starts_eval, str(path))
516
-
517
-
518
- def _discover_eval_config_paths() -> list[Path]:
519
- """Find candidate eval TOML files near the current working directory."""
520
-
521
- candidates: list[Path] = []
522
- seen: set[Path] = set()
523
- search_roots = _candidate_search_roots()
524
- for root in search_roots:
525
- if not root.exists() or not root.is_dir():
526
- continue
527
- try:
528
- root = root.resolve()
529
- except Exception:
530
- continue
531
- for path in root.rglob("*.toml"):
532
- if not path.is_file():
533
- continue
534
- if _should_ignore_path(path):
535
- continue
536
- name_lower = path.name.lower()
537
- if "eval" not in name_lower and "evaluation" not in name_lower:
538
- continue
539
- try:
540
- resolved = path.resolve()
541
- except Exception:
542
- continue
543
- if resolved in seen:
544
- continue
545
- seen.add(resolved)
546
- candidates.append(resolved)
547
-
548
- candidates.sort(key=_eval_config_sort_key)
549
- return candidates
550
-
551
-
552
513
  class _TaskAppConfigVisitor(ast.NodeVisitor):
553
514
  def __init__(self) -> None:
554
515
  self.matches: list[tuple[str, int]] = []
@@ -2264,7 +2225,6 @@ def validate_task_app_cmd(
2264
2225
  • Debug failing deployments: Use --verbose to see detailed endpoint responses
2265
2226
  • Test API key configuration: Verify authentication is set up correctly
2266
2227
  """
2267
- import asyncio
2268
2228
  import socket
2269
2229
  import subprocess
2270
2230
  import tempfile
@@ -2471,49 +2431,7 @@ def serve_command(
2471
2431
  trace_dir: str | None,
2472
2432
  trace_db: str | None,
2473
2433
  ) -> None:
2474
- demo_dir_path = _load_demo_directory()
2475
- if demo_dir_path:
2476
- if not demo_dir_path.is_dir():
2477
- raise click.ClickException(
2478
- f"Demo directory not found: {demo_dir_path}\nRun 'synth-ai setup' to create a demo."
2479
- )
2480
- os.chdir(demo_dir_path)
2481
- click.echo(f"Using demo directory: {demo_dir_path}\n")
2482
- os.environ["SYNTH_DEMO_DIR"] = str(demo_dir_path.resolve())
2483
-
2484
- # Prompt for port if not provided
2485
- if port is None:
2486
- port = click.prompt("Port to serve on", type=int, default=8001)
2487
-
2488
- # Prompt for trace directory if not provided
2489
- if trace_dir is None:
2490
- click.echo(
2491
- "\nTracing captures rollout data (actions, rewards, model outputs) to a local SQLite DB."
2492
- )
2493
- click.echo("This data can be exported to JSONL for supervised fine-tuning (SFT).")
2494
- enable_tracing = click.confirm("Enable tracing?", default=True)
2495
- if enable_tracing:
2496
- demo_base = Path(os.environ.get("SYNTH_DEMO_DIR") or Path.cwd())
2497
- default_trace_dir = str((demo_base / "traces/v3").resolve())
2498
- trace_dir = click.prompt(
2499
- "Trace directory", type=str, default=default_trace_dir, show_default=True
2500
- )
2501
- else:
2502
- trace_dir = None
2503
-
2504
- # Prompt for trace DB if not provided and tracing is enabled
2505
- if trace_dir and trace_db is None:
2506
- demo_base = Path(os.environ.get("SYNTH_DEMO_DIR") or Path.cwd())
2507
- default_trace_db = str((demo_base / "traces/v3/synth_ai.db").resolve())
2508
- trace_db = click.prompt(
2509
- "Trace DB path", type=str, default=default_trace_db, show_default=True
2510
- )
2511
-
2512
- choice = _select_app_choice(app_id, purpose="serve")
2513
- entry = choice.ensure_entry()
2514
- _serve_entry(
2515
- entry, host, port, env_file, reload_flag, force, trace_dir=trace_dir, trace_db=trace_db
2516
- )
2434
+ return None
2517
2435
 
2518
2436
 
2519
2437
  @task_app_group.command("info")
@@ -2625,51 +2543,53 @@ def serve_task_group(
2625
2543
  trace_dir: str | None,
2626
2544
  trace_db: str | None,
2627
2545
  ) -> None:
2628
- demo_dir_path = _load_demo_directory()
2629
- if demo_dir_path:
2630
- if not demo_dir_path.is_dir():
2631
- raise click.ClickException(
2632
- f"Demo directory not found: {demo_dir_path}\nRun 'synth-ai setup' to create a demo."
2633
- )
2634
- os.chdir(demo_dir_path)
2635
- click.echo(f"Using demo directory: {demo_dir_path}\n")
2636
- os.environ["SYNTH_DEMO_DIR"] = str(demo_dir_path.resolve())
2637
-
2638
- # Prompt for port if not provided
2546
+ """Serve a TaskAppConfig-based task app using uvicorn."""
2547
+ import contextlib
2548
+
2549
+ if not host:
2550
+ host = "0.0.0.0"
2551
+
2639
2552
  if port is None:
2640
- port = click.prompt("Port to serve on", type=int, default=8001)
2641
-
2642
- # Prompt for trace directory if not provided
2643
- if trace_dir is None:
2644
- click.echo(
2645
- "\nTracing captures rollout data (actions, rewards, model outputs) to a local SQLite DB."
2646
- )
2647
- click.echo("This data can be exported to JSONL for supervised fine-tuning (SFT).")
2648
- enable_tracing = click.confirm("Enable tracing?", default=True)
2649
- if enable_tracing:
2650
- demo_base = Path(os.environ.get("SYNTH_DEMO_DIR") or Path.cwd())
2651
- default_trace_dir = str((demo_base / "traces/v3").resolve())
2652
- trace_dir = click.prompt(
2653
- "Trace directory", type=str, default=default_trace_dir, show_default=True
2654
- )
2655
- else:
2656
- trace_dir = None
2553
+ port = 8001
2554
+
2555
+ # Auto-enable tracing by default
2556
+ try:
2557
+ auto_trace = os.getenv("SYNTH_AUTO_TRACE", "1")
2558
+ auto_trace_enabled = auto_trace not in {"0", "false", "False", ""}
2559
+ except Exception:
2560
+ auto_trace_enabled = True
2657
2561
 
2658
- # Prompt for trace DB if not provided and tracing is enabled
2659
- if trace_dir and trace_db is None:
2562
+ if auto_trace_enabled:
2660
2563
  demo_base = Path(os.environ.get("SYNTH_DEMO_DIR") or Path.cwd())
2661
- default_trace_db = str((demo_base / "traces/v3/synth_ai.db").resolve())
2662
- trace_db = click.prompt(
2663
- "Trace DB path", type=str, default=default_trace_db, show_default=True
2664
- )
2665
-
2564
+ if trace_dir is None:
2565
+ default_trace_dir = (demo_base / "traces" / "v3").resolve()
2566
+ with contextlib.suppress(Exception):
2567
+ default_trace_dir.mkdir(parents=True, exist_ok=True)
2568
+ trace_dir = str(default_trace_dir)
2569
+ click.echo(f"[trace] Using trace directory: {trace_dir}")
2570
+ if trace_dir and trace_db is None:
2571
+ default_trace_db = (Path(trace_dir) / "synth_ai.db").resolve()
2572
+ with contextlib.suppress(Exception):
2573
+ default_trace_db.parent.mkdir(parents=True, exist_ok=True)
2574
+ trace_db = str(default_trace_db)
2575
+ click.echo(f"[trace] Using trace DB: {trace_db}")
2576
+
2577
+ # Select and serve the app
2666
2578
  choice = _select_app_choice(app_id, purpose="serve")
2667
2579
  entry = choice.ensure_entry()
2668
2580
  _serve_entry(
2669
- entry, host, port, env_file, reload_flag, force, trace_dir=trace_dir, trace_db=trace_db
2581
+ entry,
2582
+ host,
2583
+ port,
2584
+ env_file,
2585
+ reload_flag,
2586
+ force,
2587
+ trace_dir=trace_dir,
2588
+ trace_db=trace_db,
2670
2589
  )
2671
2590
 
2672
2591
 
2592
+
2673
2593
  def _determine_env_files(
2674
2594
  entry: TaskAppEntryType, user_env_files: Sequence[str], *, original_path: Path | None = None
2675
2595
  ) -> list[Path]:
@@ -2962,87 +2882,6 @@ def _serve_entry(
2962
2882
  )
2963
2883
 
2964
2884
 
2965
- @task_app_group.command("deploy")
2966
- @click.argument("app_id", type=str, required=False)
2967
- @click.option("--name", "modal_name", default=None, help="Override Modal app name")
2968
- @click.option("--dry-run", is_flag=True, help="Print modal deploy command without executing")
2969
- @click.option("--modal-cli", default="modal", help="Path to modal CLI executable")
2970
- @click.option(
2971
- "--env-file",
2972
- multiple=True,
2973
- type=click.Path(),
2974
- help="Env file to load into the container (can be repeated)",
2975
- )
2976
- def deploy_app(
2977
- app_id: str | None,
2978
- modal_name: str | None,
2979
- dry_run: bool,
2980
- modal_cli: str,
2981
- env_file: Sequence[str],
2982
- ) -> None:
2983
- """Deploy a task app to Modal."""
2984
-
2985
- demo_dir_path = _load_demo_directory()
2986
- if demo_dir_path:
2987
- if not demo_dir_path.is_dir():
2988
- raise click.ClickException(
2989
- f"Demo directory not found: {demo_dir_path}\nRun 'synth-ai demo' to create a demo."
2990
- )
2991
- os.chdir(demo_dir_path)
2992
- click.echo(f"Using demo directory: {demo_dir_path}\n")
2993
-
2994
- choice = _select_app_choice(app_id, purpose="deploy")
2995
-
2996
- if choice.modal_script:
2997
- env_paths = _resolve_env_paths_for_script(choice.modal_script, env_file)
2998
- click.echo("Using env file(s): " + ", ".join(str(p.resolve()) for p in env_paths))
2999
- _run_modal_script(
3000
- choice.modal_script,
3001
- modal_cli,
3002
- "deploy",
3003
- env_paths,
3004
- modal_name=modal_name,
3005
- dry_run=dry_run,
3006
- )
3007
- return
3008
-
3009
- entry = choice.ensure_entry()
3010
- _deploy_entry(entry, modal_name, dry_run, modal_cli, env_file, original_path=choice.path)
3011
-
3012
-
3013
- @task_app_group.command("modal-serve")
3014
- @click.argument("app_id", type=str, required=False)
3015
- @click.option("--modal-cli", default="modal", help="Path to modal CLI executable")
3016
- @click.option("--name", "modal_name", default=None, help="Override Modal app name (optional)")
3017
- @click.option(
3018
- "--env-file",
3019
- multiple=True,
3020
- type=click.Path(),
3021
- help="Env file to load into the container (can be repeated)",
3022
- )
3023
- def modal_serve_app(
3024
- app_id: str | None, modal_cli: str, modal_name: str | None, env_file: Sequence[str]
3025
- ) -> None:
3026
- click.echo(f"[modal-serve] requested app_id={app_id or '(auto)'} modal_cli={modal_cli}")
3027
- try:
3028
- choice = _select_app_choice(app_id, purpose="modal-serve")
3029
- except SystemExit as exc: # bubble up with context (legacy argparse would trigger this)
3030
- raise click.ClickException(
3031
- f"Legacy CLI intercepted modal-serve (exit {exc.code}). "
3032
- "Make sure you're running the Click CLI (synth_ai.cli:cli)."
3033
- ) from exc
3034
-
3035
- if choice.modal_script:
3036
- env_paths = _resolve_env_paths_for_script(choice.modal_script, env_file)
3037
- click.echo("Using env file(s): " + ", ".join(str(p.resolve()) for p in env_paths))
3038
- _run_modal_script(choice.modal_script, modal_cli, "serve", env_paths, modal_name=modal_name)
3039
- return
3040
-
3041
- entry = choice.ensure_entry()
3042
- click.echo(f"[modal-serve] serving entry {entry.app_id} from {choice.path}")
3043
- _modal_serve_entry(entry, modal_name, modal_cli, env_file, original_path=choice.path)
3044
-
3045
-
3046
2885
  def _write_modal_entrypoint(
3047
2886
  entry: TaskAppEntryType,
3048
2887
  modal_cfg: ModalDeploymentConfigType,
@@ -3286,1277 +3125,9 @@ def register(cli: click.Group) -> None:
3286
3125
  cli.add_command(filter_command)
3287
3126
 
3288
3127
 
3289
- @click.command(
3290
- "eval",
3291
- help="Run one-off rollouts against a task app and print judge/eval summaries.",
3292
- )
3293
- @click.argument("app_id", type=str, required=False)
3294
- @click.option(
3295
- "--config",
3296
- type=click.Path(),
3297
- default=None,
3298
- help="Path to eval TOML (short schema). Auto-discovers the first matching file when omitted.",
3299
- )
3300
- @click.option(
3301
- "--url",
3302
- "task_app_url",
3303
- type=str,
3304
- default=None,
3305
- help="Base URL of a running task app instead of spawning locally (requires --env-file for secrets).",
3306
- )
3307
- @click.option(
3308
- "--seeds",
3309
- default="0,1,2,3,4",
3310
- help="Comma-separated seeds/indices to evaluate. Use negative numbers to wrap around the dataset.",
3311
- )
3312
- @click.option("--split", default="train", show_default=True, help="Dataset split to use")
3313
- @click.option(
3314
- "--model",
3315
- default=None,
3316
- help="Model identifier. When omitted the CLI will prompt based on task metadata.",
3317
- )
3318
- @click.option(
3319
- "--env-file",
3320
- multiple=True,
3321
- type=click.Path(),
3322
- help="Env file(s) to load (API keys, etc.). Required when using --url or remote judges.",
3323
- )
3324
- @click.option(
3325
- "--trace-db",
3326
- default="traces/v3/synth_ai.db",
3327
- show_default=True,
3328
- help="SQLite/Turso URL for storing rollout traces set to 'none' to disable persistence.",
3329
- )
3330
- @click.option(
3331
- "--metadata",
3332
- multiple=True,
3333
- help="Filter tasks by key=value metadata (e.g., --metadata difficulty=easy)",
3334
- )
3335
- @click.option(
3336
- "--metadata-sql",
3337
- default=None,
3338
- help="SQLite query that returns seeds to evaluate (e.g., SELECT seed FROM tasks WHERE difficulty='easy' LIMIT 5)",
3339
- )
3340
- def eval_command(
3341
- app_id: str | None,
3342
- config: str | None,
3343
- task_app_url: str | None,
3344
- seeds: str,
3345
- split: str,
3346
- model: str | None,
3347
- env_file: Sequence[str],
3348
- trace_db: str,
3349
- metadata: Sequence[str],
3350
- metadata_sql: str | None,
3351
- ) -> None:
3352
- """Run rollouts against a task app and report judge statistics.
3353
-
3354
- By default the command spins up the selected task app in-process, executes the
3355
- requested seeds, and prints aggregate scores (official and custom judges). When
3356
- pointing at a remote `--url`, supply matching `--env-file` values so the CLI can
3357
- forward authentication headers to the running service.
3358
- """
3359
- # Parse and validate TOML config
3360
- from synth_ai.task.config import EvalConfig
3361
-
3362
- cfg: dict[str, Any] = {}
3363
- eval_cfg: EvalConfig | None = None
3364
- config_path: Path | None = None
3365
-
3366
- if config:
3367
- config_path = Path(config)
3368
- else:
3369
- auto_configs = _discover_eval_config_paths()
3370
- if auto_configs:
3371
- config_path = auto_configs[0]
3372
- click.echo(f"Using eval config: {config_path}")
3373
-
3374
- if config_path:
3375
- if _toml is None:
3376
- raise click.ClickException(
3377
- "TOML parser not available; use Python 3.11+ or install tomli"
3378
- )
3379
- if not config_path.exists():
3380
- raise click.ClickException(f"Eval config not found: {config_path}")
3381
- try:
3382
- data = config_path.read_bytes()
3383
- parsed = _toml.loads(data.decode("utf-8"))
3384
- if isinstance(parsed, dict):
3385
- section = parsed.get("eval")
3386
- cfg = dict(section) if isinstance(section, dict) else dict(parsed)
3387
-
3388
- # Validate config with dataclass
3389
- try:
3390
- eval_cfg = EvalConfig.from_dict(cfg)
3391
- click.echo(f"✓ Config validated: {len(eval_cfg.seeds)} seeds, model={eval_cfg.model}")
3392
- except (ValueError, TypeError) as validation_error:
3393
- raise click.ClickException(f"Invalid eval config: {validation_error}") from validation_error
3394
- except click.ClickException:
3395
- raise
3396
- except Exception as exc:
3397
- raise click.ClickException(f"Failed to parse TOML '{config_path}': {exc}") from exc
3398
-
3399
- # CLI args override config
3400
- if eval_cfg:
3401
- app_id = app_id or eval_cfg.app_id
3402
- else:
3403
- app_id = app_id or (cfg.get("app_id") if isinstance(cfg.get("app_id"), str) else None) # type: ignore
3404
-
3405
- metadata_filters: dict[str, str] = {}
3406
- if eval_cfg:
3407
- metadata_filters.update(eval_cfg.metadata)
3408
- else:
3409
- cfg_metadata = cfg.get("metadata")
3410
- if isinstance(cfg_metadata, dict):
3411
- for key, value in cfg_metadata.items():
3412
- metadata_filters[str(key)] = str(value)
3413
- elif isinstance(cfg_metadata, list):
3414
- for item in cfg_metadata:
3415
- if isinstance(item, str) and "=" in item:
3416
- key, value = item.split("=", 1)
3417
- metadata_filters[key.strip()] = value.strip()
3418
-
3419
- for item in metadata or ():
3420
- if "=" not in item:
3421
- raise click.ClickException(f"Metadata filters must be key=value (got: {item})")
3422
- key, value = item.split("=", 1)
3423
- key = key.strip()
3424
- value = value.strip()
3425
- if not key or not value:
3426
- raise click.ClickException(f"Invalid metadata filter: {item}")
3427
- metadata_filters[key] = value
3428
-
3429
- metadata_sql_query: str | None = None
3430
- if eval_cfg and eval_cfg.metadata_sql:
3431
- metadata_sql_query = eval_cfg.metadata_sql
3432
- else:
3433
- cfg_metadata_sql = cfg.get("metadata_sql")
3434
- if isinstance(cfg_metadata_sql, dict):
3435
- metadata_sql_query = cfg_metadata_sql.get("query") or cfg_metadata_sql.get("sql")
3436
- elif isinstance(cfg_metadata_sql, str):
3437
- metadata_sql_query = cfg_metadata_sql
3438
-
3439
- if metadata_sql:
3440
- metadata_sql_query = metadata_sql
3441
- if metadata_sql_query is not None:
3442
- metadata_sql_query = str(metadata_sql_query)
3443
-
3444
- trace_db_url: str | None = None
3445
- trace_db = (trace_db or "").strip()
3446
- if trace_db and trace_db.lower() not in {"none", "off", "disable"}:
3447
- if "://" in trace_db:
3448
- trace_db_url = trace_db
3449
- else:
3450
- trace_path = Path(trace_db).expanduser()
3451
- trace_path.parent.mkdir(parents=True, exist_ok=True)
3452
- trace_db_url = f"sqlite+aiosqlite:///{trace_path}"
3453
- trace_tracer: SessionTracer | None = SessionTracer(db_url=trace_db_url, auto_save=True) if trace_db_url else None
3454
-
3455
- # Determine selection params (CLI takes precedence; TOML only fills unset model/seeds/env)
3456
- if cfg.get("model") and not model:
3457
- model = str(cfg["model"]) # type: ignore[index]
3458
- if cfg.get("seeds") and seeds == "0,1,2,3,4":
3459
- val = cfg["seeds"]
3460
- if isinstance(val, list):
3461
- with contextlib.suppress(Exception):
3462
- seeds = ",".join(str(int(x)) for x in val)
3463
- elif isinstance(val, str):
3464
- seeds = val
3465
- elif isinstance(val, int):
3466
- seeds = str(val)
3467
- if cfg.get("env_file") and not env_file:
3468
- ef = cfg["env_file"]
3469
- if isinstance(ef, str):
3470
- env_file = (ef,) # type: ignore[assignment]
3471
- elif isinstance(ef, list):
3472
- env_file = tuple(str(x) for x in ef) # type: ignore[assignment]
3473
-
3474
- choice_for_env: AppChoice | None = None
3475
- entry: TaskAppEntryType | None = None
3476
- if task_app_url is None:
3477
- choice_for_env = _select_app_choice(app_id, purpose="eval")
3478
- entry = choice_for_env.ensure_entry()
3479
-
3480
- env_paths: list[Path] = []
3481
- if entry is not None:
3482
- original_env_path = choice_for_env.path if choice_for_env is not None else None
3483
- env_paths = _determine_env_files(entry, env_file, original_path=original_env_path)
3484
- else:
3485
- if not env_file:
3486
- raise click.ClickException("--env-file is required when using --url")
3487
- for candidate in env_file:
3488
- p = Path(candidate).expanduser()
3489
- if not p.exists():
3490
- raise click.ClickException(f"Env file not found: {p}")
3491
- env_paths.append(p)
3492
-
3493
- click.echo("Using env file(s): " + ", ".join(str(p) for p in env_paths))
3494
- _load_env_files_into_process([str(Path(p)) for p in env_paths])
3495
-
3496
- if task_app_url is None:
3497
- config = entry.config_factory() # type: ignore[union-attr]
3498
- # Help the type checker; runtime check also enforced in server.run_task_app
3499
- if not isinstance(config, TaskAppConfig):
3500
- raise click.ClickException(
3501
- "Invalid task app: config_factory did not return TaskAppConfig"
3502
- )
3503
- app = create_task_app(config)
3504
-
3505
- # Determine supported models
3506
- inference_meta: dict[str, Any] = {}
3507
- supported: list[str] = []
3508
- seen_models: set[str] = set()
3509
-
3510
- def _add_supported_model(candidate: Any) -> None:
3511
- if not candidate:
3512
- return
3513
- text = str(candidate).strip()
3514
- if not text or text in seen_models:
3515
- return
3516
- supported.append(text)
3517
- seen_models.add(text)
3518
-
3519
- if task_app_url is None:
3520
- try:
3521
- if hasattr(config, "base_task_info") and config.base_task_info:
3522
- inf_obj = getattr(config.base_task_info, "inference", None)
3523
- if inf_obj is not None:
3524
- if hasattr(inf_obj, "model_dump"):
3525
- inference_meta = dict(inf_obj.model_dump(exclude_none=True)) # type: ignore[attr-defined]
3526
- elif isinstance(inf_obj, dict):
3527
- inference_meta = dict(inf_obj)
3528
- except Exception:
3529
- inference_meta = {}
3530
- else:
3531
- try:
3532
- import httpx as _hx
3533
-
3534
- headers = {}
3535
- api_key = (os.environ.get("ENVIRONMENT_API_KEY") or "").strip()
3536
- if api_key:
3537
- headers["X-API-Key"] = api_key
3538
- with _hx.Client(base_url=task_app_url, headers=headers, timeout=15.0) as c:
3539
- info = c.get("/info").json()
3540
- inf = info.get("inference") if isinstance(info, dict) else None
3541
- if isinstance(inf, dict):
3542
- inference_meta = dict(inf)
3543
- except Exception:
3544
- inference_meta = {}
3545
-
3546
- default_model = inference_meta.get("model")
3547
- if isinstance(default_model, str):
3548
- _add_supported_model(default_model)
3549
-
3550
- models_field = inference_meta.get("models")
3551
- if isinstance(models_field, list):
3552
- for candidate in models_field:
3553
- _add_supported_model(candidate)
3554
-
3555
- supported_models = inference_meta.get("supported_models")
3556
- if isinstance(supported_models, list):
3557
- for candidate in supported_models:
3558
- _add_supported_model(candidate)
3559
-
3560
- providers = inference_meta.get("providers")
3561
- if isinstance(providers, list):
3562
- if "openai" in providers:
3563
- _add_supported_model("gpt-5")
3564
- if "groq" in providers:
3565
- _add_supported_model("groq:llama-3.1-70b-versatile")
3566
-
3567
- _add_supported_model("synth:qwen-0.6b")
3568
-
3569
- selected_model = model
3570
- if not selected_model:
3571
- if not supported:
3572
- raise click.ClickException(
3573
- "No supported models; supply --model or add base_task_info.inference.model"
3574
- )
3575
- click.echo("Select model to evaluate:")
3576
- for idx, m in enumerate(supported, start=1):
3577
- click.echo(f" {idx}) {m}")
3578
- choice_idx = click.prompt("Enter choice", type=click.IntRange(1, len(supported)))
3579
- selected_model = supported[choice_idx - 1]
3580
-
3581
- try:
3582
- seed_values = [int(s.strip()) for s in seeds.split(",") if s.strip()]
3583
- except Exception as exc:
3584
- raise click.ClickException("Invalid --seeds; expected comma-separated integers") from exc
3585
-
3586
- import httpx
3587
-
3588
- headers = {}
3589
- api_key = (os.environ.get("ENVIRONMENT_API_KEY") or "").strip()
3590
- if api_key:
3591
- headers["X-API-Key"] = api_key
3592
-
3593
- # Precompute optional policy overrides from TOML
3594
- policy_overrides: dict[str, Any] = {}
3595
- try:
3596
- # Accept [eval.policy] table or top-level keys for convenience
3597
- if isinstance(cfg.get("policy"), dict):
3598
- policy_overrides.update(dict(cfg["policy"]))
3599
- # Back-compat: allow temperature/max_tokens at top level
3600
- for k in (
3601
- "temperature",
3602
- "max_tokens",
3603
- "reasoning_effort",
3604
- "system_hint",
3605
- "tool_choice",
3606
- "inference_url",
3607
- ):
3608
- if k in cfg and k not in policy_overrides:
3609
- policy_overrides[k] = cfg.get(k)
3610
- except Exception:
3611
- policy_overrides = {}
3612
-
3613
- raw_concurrency = cfg.get("concurrency")
3614
- try:
3615
- concurrency_limit = int(raw_concurrency) if raw_concurrency is not None else 1
3616
- except Exception:
3617
- concurrency_limit = 1
3618
- if concurrency_limit <= 0:
3619
- concurrency_limit = 1
3620
- concurrency_limit = min(concurrency_limit, max(1, len(seed_values)))
3621
-
3622
- judge_specs: list[JudgeSpec] = []
3623
-
3624
- def _register_judge(name_hint: str | None, judge_cfg: dict[str, Any]) -> None:
3625
- if not judge_cfg:
3626
- return
3627
- judge_module = judge_cfg.get("module")
3628
- judge_path = judge_cfg.get("path")
3629
- judge_callable_name = judge_cfg.get("callable") or judge_cfg.get("function")
3630
- if judge_module and judge_path:
3631
- raise click.ClickException("Judge config cannot set both 'module' and 'path'")
3632
- if not judge_module and not judge_path:
3633
- raise click.ClickException("Judge config requires 'module' or 'path'")
3634
- try:
3635
- if judge_module:
3636
- module = importlib.import_module(str(judge_module))
3637
- else:
3638
- path = Path(str(judge_path)).expanduser()
3639
- if not path.exists():
3640
- raise click.ClickException(f"Judge module path not found: {path}")
3641
- spec = importlib.util.spec_from_file_location(
3642
- f"_eval_judge_{path.stem}", path
3643
- )
3644
- if not spec or not spec.loader:
3645
- raise click.ClickException(f"Failed to load judge module from {path}")
3646
- module = importlib.util.module_from_spec(spec)
3647
- sys.modules[spec.name] = module
3648
- spec.loader.exec_module(module)
3649
- except click.ClickException:
3650
- raise
3651
- except Exception as exc:
3652
- raise click.ClickException(f"Unable to load judge module: {exc}") from exc
3653
-
3654
- if judge_callable_name:
3655
- try:
3656
- judge_fn = getattr(module, str(judge_callable_name))
3657
- except AttributeError as exc:
3658
- raise click.ClickException(
3659
- f"Judge callable '{judge_callable_name}' not found in module"
3660
- ) from exc
3661
- else:
3662
- if hasattr(module, "judge"):
3663
- judge_fn = module.judge
3664
- else:
3665
- raise click.ClickException("Judge module must expose 'judge' callable")
3666
-
3667
- if not callable(judge_fn):
3668
- raise click.ClickException("Judge callable is not callable")
3669
-
3670
- judge_kwargs = {
3671
- k: v
3672
- for k, v in judge_cfg.items()
3673
- if k not in {"module", "path", "callable", "function", "name"}
3674
- }
3675
- display_name = str(
3676
- judge_cfg.get("name")
3677
- or name_hint
3678
- or f"judge{len(judge_specs) + 1}"
3679
- )
3680
- judge_specs.append(JudgeSpec(display_name, judge_fn, judge_kwargs))
3681
-
3682
- raw_judge_cfg = cfg.get("judge")
3683
- if isinstance(raw_judge_cfg, dict) and raw_judge_cfg:
3684
- direct_keys = {"module", "path", "callable", "function", "name"}
3685
- has_direct_keys = any(key in raw_judge_cfg for key in direct_keys)
3686
- nested_candidates = [
3687
- (key, value)
3688
- for key, value in raw_judge_cfg.items()
3689
- if isinstance(value, dict)
3690
- ]
3691
- if has_direct_keys and not nested_candidates:
3692
- _register_judge(None, raw_judge_cfg)
3693
- else:
3694
- for sub_name, sub_cfg in nested_candidates:
3695
- _register_judge(sub_name, sub_cfg)
3696
-
3697
- raw_judges_list = cfg.get("judges")
3698
- if isinstance(raw_judges_list, list):
3699
- for _index, entry in enumerate(raw_judges_list, start=1):
3700
- if isinstance(entry, dict):
3701
- _register_judge(entry.get("name") or f"judge{len(judge_specs) + 1}", entry)
3702
-
3703
- records: list[dict[str, Any]] = []
3704
-
3705
- successes = 0
3706
- failures = 0
3707
- # Aggregate outcome stats across successful seeds
3708
- outcome_sum: float = 0.0
3709
- outcome_count: int = 0
3710
- outcome_correct: int = 0
3711
-
3712
- def _build_task_rows(taskset: Any) -> dict[int, dict[str, Any]]:
3713
- rows: dict[int, dict[str, Any]] = {}
3714
- if not isinstance(taskset, dict):
3715
- return rows
3716
-
3717
- scenario_ids = taskset.get("scenario_ids") or []
3718
- loop_ids = taskset.get("loop_ids") or []
3719
- thread_ids = taskset.get("thread_ids") or []
3720
- difficulty_map = taskset.get("difficulty_map") or {}
3721
-
3722
- max_len = max(len(scenario_ids), len(loop_ids), len(thread_ids))
3723
- for seed in range(max_len):
3724
- scenario_id = scenario_ids[seed] if seed < len(scenario_ids) else None
3725
- loop_id = loop_ids[seed] if seed < len(loop_ids) else None
3726
- thread_id = thread_ids[seed] if seed < len(thread_ids) else None
3727
- difficulty = None
3728
- if isinstance(difficulty_map, dict):
3729
- if scenario_id and scenario_id in difficulty_map:
3730
- difficulty = difficulty_map.get(scenario_id)
3731
- elif str(seed) in difficulty_map:
3732
- difficulty = difficulty_map.get(str(seed))
3733
-
3734
- rows[seed] = {
3735
- "seed": seed,
3736
- "scenario_id": scenario_id,
3737
- "loop_id": loop_id,
3738
- "thread_id": thread_id,
3739
- "difficulty": difficulty,
3740
- }
3741
- return rows
3742
-
3743
- def _apply_metadata_filters(
3744
- rows: dict[int, dict[str, Any]], seeds_list: list[int], filters: dict[str, str]
3745
- ) -> list[int]:
3746
- if not filters:
3747
- return seeds_list
3748
- filtered: list[int] = []
3749
- for seed in seeds_list:
3750
- row = rows.get(seed)
3751
- if not row:
3752
- continue
3753
- include = True
3754
- for key, expected in filters.items():
3755
- actual = row.get(key)
3756
- if actual is None:
3757
- include = False
3758
- break
3759
- if str(actual).lower() != expected.lower():
3760
- include = False
3761
- break
3762
- if include:
3763
- filtered.append(seed)
3764
- return filtered
3765
-
3766
- def _apply_metadata_sql(
3767
- rows: dict[int, dict[str, Any]], seeds_list: list[int], query: str
3768
- ) -> list[int]:
3769
- """Return seeds that satisfy an arbitrary SQL query.
3770
-
3771
- The query is executed against an in-memory SQLite table named `tasks`
3772
- with columns (seed INTEGER, scenario_id TEXT, loop_id TEXT, thread_id TEXT, difficulty TEXT).
3773
- Any rows whose `seed` value (or first column if `seed` is absent) appear in the result set are retained.
3774
- """
3775
- if not query:
3776
- return seeds_list
3777
- conn = sqlite3.connect(":memory:")
3778
- try:
3779
- cur = conn.cursor()
3780
- cur.execute(
3781
- "CREATE TABLE tasks (seed INTEGER, scenario_id TEXT, loop_id TEXT, thread_id TEXT, difficulty TEXT)"
3782
- )
3783
- insert_stmt = (
3784
- "INSERT INTO tasks (seed, scenario_id, loop_id, thread_id, difficulty) VALUES (?,?,?,?,?)"
3785
- )
3786
- for seed in seeds_list:
3787
- row = rows.get(seed, {})
3788
- cur.execute(
3789
- insert_stmt,
3790
- [
3791
- seed,
3792
- row.get("scenario_id"),
3793
- row.get("loop_id"),
3794
- row.get("thread_id"),
3795
- row.get("difficulty"),
3796
- ],
3797
- )
3798
-
3799
- result = cur.execute(query)
3800
- fetched = result.fetchall()
3801
- if not fetched:
3802
- return []
3803
- description = result.description or []
3804
- col_names = [col[0] for col in description]
3805
- seeds_out: list[int] = []
3806
- for entry in fetched:
3807
- value = entry[col_names.index("seed")] if "seed" in col_names else entry[0]
3808
- try:
3809
- seeds_out.append(int(value))
3810
- except Exception as exc:
3811
- raise click.ClickException(
3812
- "metadata SQL query must return seed integers"
3813
- ) from exc
3814
- seeds_set = set(seeds_out)
3815
- return [seed for seed in seeds_list if seed in seeds_set]
3816
- except sqlite3.Error as exc:
3817
- raise click.ClickException(f"Failed to execute metadata SQL query: {exc}") from exc
3818
- finally:
3819
- conn.close()
3820
-
3821
- async def _run_eval() -> None:
3822
- nonlocal successes, failures, outcome_sum, outcome_count, outcome_correct, records, seed_values
3823
-
3824
- if trace_tracer is not None and trace_tracer.db is None:
3825
- await trace_tracer.initialize()
3826
-
3827
- if task_app_url is None:
3828
- transport = httpx.ASGITransport(app=app) # type: ignore[name-defined]
3829
- async_client = httpx.AsyncClient(
3830
- transport=cast(Any, transport),
3831
- base_url="http://eval.local",
3832
- timeout=300.0,
3833
- follow_redirects=True,
3834
- headers=headers,
3835
- )
3836
- else:
3837
- async_client = httpx.AsyncClient(
3838
- base_url=task_app_url,
3839
- timeout=300.0,
3840
- follow_redirects=True,
3841
- headers=headers,
3842
- )
3843
-
3844
- try:
3845
- taskset_payload: dict[str, Any] | None = None
3846
- try:
3847
- task_info_response = await async_client.get("/task_info")
3848
- except Exception:
3849
- task_info_response = None
3850
- if task_info_response is not None and task_info_response.status_code == 200:
3851
- with contextlib.suppress(Exception):
3852
- payload_json = task_info_response.json()
3853
- if isinstance(payload_json, dict) and "taskset" in payload_json:
3854
- taskset_payload = payload_json.get("taskset")
3855
- if not isinstance(taskset_payload, dict):
3856
- taskset_payload = None
3857
- elif isinstance(payload_json, dict):
3858
- taskset_payload = payload_json
3859
-
3860
- available_seeds = list(seed_values)
3861
- if metadata_sql_query or metadata_filters:
3862
- if not taskset_payload:
3863
- raise click.ClickException(
3864
- "Task metadata filters require the task app to expose /task_info metadata"
3865
- )
3866
- rows = _build_task_rows(taskset_payload)
3867
- if metadata_sql_query:
3868
- available_seeds = _apply_metadata_sql(rows, available_seeds, metadata_sql_query)
3869
- if metadata_filters:
3870
- available_seeds = _apply_metadata_filters(rows, available_seeds, metadata_filters)
3871
- if not available_seeds:
3872
- raise click.ClickException("No seeds match the provided metadata filters")
3873
- seed_values = available_seeds
3874
-
3875
- semaphore = asyncio.Semaphore(concurrency_limit)
3876
-
3877
- async def _run_seed(seed_val: int) -> None:
3878
- nonlocal successes, failures, outcome_sum, outcome_count, outcome_correct, records
3879
- # Read env_name and policy_name from config if available
3880
- env_name = cfg.get("env_name") or (cfg.get("env", {}).get("env_name") if isinstance(cfg.get("env"), dict) else None)
3881
- policy_name = cfg.get("policy_name") or (cfg.get("policy", {}).get("policy_name") if isinstance(cfg.get("policy"), dict) else None)
3882
- env_config_overrides = cfg.get("env_config", {}) if isinstance(cfg.get("env_config"), dict) else {}
3883
- policy_config_overrides = cfg.get("policy_config", {}) if isinstance(cfg.get("policy_config"), dict) else {}
3884
-
3885
- # Debug: print config parsing
3886
- if seed_val == 0:
3887
- click.echo(f"[DEBUG] env_name from config: {env_name}")
3888
- click.echo(f"[DEBUG] policy_name from config: {policy_name}")
3889
-
3890
- # Generate default ops sequence if not provided
3891
- max_llm_calls = policy_config_overrides.get("max_llm_calls", 10)
3892
- ops_list = cfg.get("ops", [])
3893
- if not ops_list:
3894
- # Generate default "agent, env" pairs for max_llm_calls
3895
- ops_list = ["agent", "env"] * int(max_llm_calls)
3896
-
3897
- body = {
3898
- "run_id": str(uuid.uuid4()),
3899
- "env": {"config": {"split": split, "index": seed_val, **env_config_overrides}, "seed": seed_val},
3900
- "policy": {
3901
- "policy_name": policy_name or selected_model,
3902
- "config": {"model": selected_model, **policy_overrides, **policy_config_overrides},
3903
- },
3904
- "ops": ops_list,
3905
- "record": {
3906
- "return_trace": cfg.get("return_trace", True),
3907
- "trace_format": cfg.get("trace_format", "structured"),
3908
- },
3909
- "mode": "eval", # RolloutMode.EVAL: use inference URLs as-is, no transformations
3910
- }
3911
- if env_name:
3912
- body["env"]["env_name"] = env_name
3913
-
3914
- # Debug: print the body being sent
3915
- if seed_val == 0:
3916
- click.echo(f"[DEBUG] rollout body env: {body['env']}")
3917
- click.echo(f"[DEBUG] rollout body policy: {body['policy']}")
3918
- click.echo(f"[DEBUG] rollout body mode: {body.get('mode', 'NOT SET')}")
3919
- rollout_elapsed: float | None = None
3920
- rollout_start = time.perf_counter()
3921
- try:
3922
- import logging
3923
- _log = logging.getLogger(__name__)
3924
- _log.info(f"[EVAL_BODY_DEBUG] Sending body with mode={body.get('mode')}")
3925
- async with semaphore:
3926
- response = await async_client.post("/rollout", json=body)
3927
- rollout_elapsed = time.perf_counter() - rollout_start
3928
- except Exception as exc:
3929
- failures += 1
3930
- click.echo(f"seed={seed_val} error={exc}")
3931
- return
3932
-
3933
- ok = 200 <= response.status_code < 300
3934
- if ok:
3935
- successes += 1
3936
- else:
3937
- failures += 1
3938
-
3939
- summary = [f"seed={seed_val}", f"status={response.status_code}"]
3940
- data: Any
3941
- try:
3942
- data = response.json()
3943
- except Exception:
3944
- data = None
3945
-
3946
- # Debug: print validation errors
3947
- if response.status_code == 422 and data:
3948
- click.echo(f"[DEBUG] 422 Validation Error: {data}")
3949
-
3950
- metrics: dict[str, Any] | None = None
3951
- completion: str | None = None
3952
- prompt_index: int | None = None
3953
- prompt_text: str | None = None
3954
- task_id: str | None = None
3955
- task_split: str | None = None
3956
- task_rubric_id: str | None = None
3957
-
3958
- trace_namespace: dict[str, Any] | None = None
3959
- session_trace_dict: dict[str, Any] | None = None
3960
-
3961
- if isinstance(data, dict):
3962
- import logging
3963
- _logger = logging.getLogger(__name__)
3964
- _logger.info(f"[EVAL_DEBUG] Response data keys: {list(data.keys())}")
3965
- if "detail" in data:
3966
- _logger.error(f"[EVAL_DEBUG] Task app returned error: {data['detail']}")
3967
- trace_namespace = data.get("trace")
3968
- _logger.info(f"[EVAL_DEBUG] trace_namespace type: {type(trace_namespace)}, value: {trace_namespace if not isinstance(trace_namespace, dict) else 'dict with keys: ' + str(list(trace_namespace.keys()) if trace_namespace else 'None')}")
3969
- if not isinstance(trace_namespace, dict):
3970
- raise RuntimeError(
3971
- "The 'synth-ai eval' command requires trace payloads in rollout responses. "
3972
- "Ensure the rollout request includes 'trace_format': 'structured' and 'return_trace': true, "
3973
- "and that task app tracing is enabled (TASKAPP_TRACING_ENABLED=1). "
3974
- "Note: This is specific to the eval command - general rollout endpoints don't require traces."
3975
- )
3976
- # Handle both "compact" and "full" trace formats:
3977
- # - compact: trace_namespace contains {session_id, metadata, ...}
3978
- # - full: trace_namespace IS the full session_trace dict
3979
- session_trace_dict = trace_namespace.get("session_trace")
3980
- if not isinstance(session_trace_dict, dict):
3981
- # If no session_trace key, assume "full" format where trace itself is the session_trace
3982
- if "session_id" in trace_namespace:
3983
- session_trace_dict = trace_namespace
3984
- else:
3985
- raise RuntimeError(
3986
- "The 'synth-ai eval' command requires 'session_trace' in the trace payload or a valid full trace format. "
3987
- "Ensure the task app is using tracing_v3 and returning structured trace data."
3988
- )
3989
- metrics = data.get("metrics") if isinstance(data.get("metrics"), dict) else None
3990
- if metrics:
3991
- mean_return = metrics.get("mean_return") or metrics.get("total_reward")
3992
- outcome = metrics.get("outcome_score")
3993
- if mean_return is not None:
3994
- summary.append(f"mean_return={mean_return}")
3995
- if outcome is not None:
3996
- summary.append(f"outcome={outcome}")
3997
- try:
3998
- val = float(outcome)
3999
- outcome_sum += val
4000
- outcome_count += 1
4001
- if val >= 0.5:
4002
- outcome_correct += 1
4003
- except Exception:
4004
- pass
4005
- trajs = (
4006
- data.get("trajectories")
4007
- if isinstance(data.get("trajectories"), list)
4008
- else None
4009
- )
4010
- if trajs:
4011
- first = trajs[0] if trajs else None
4012
- steps = first.get("steps") if isinstance(first, dict) else None
4013
- if isinstance(steps, list) and steps:
4014
- step0 = steps[0]
4015
- tool_calls = step0.get("tool_calls") or step0.get("tools") or []
4016
- if isinstance(tool_calls, list):
4017
- summary.append(f"tool_calls={len(tool_calls)}")
4018
- obs = step0.get("obs") if isinstance(step0, dict) else None
4019
- if isinstance(obs, dict):
4020
- idx_val = obs.get("prompt_index")
4021
- if isinstance(idx_val, int):
4022
- prompt_index = idx_val
4023
- prompt_raw = obs.get("prompt")
4024
- if isinstance(prompt_raw, str):
4025
- prompt_text = prompt_raw
4026
- if task_id is None:
4027
- candidate_id = obs.get("task_id")
4028
- if isinstance(candidate_id, str) and candidate_id:
4029
- task_id = candidate_id
4030
- if task_split is None:
4031
- candidate_split = obs.get("task_split")
4032
- if isinstance(candidate_split, str) and candidate_split:
4033
- task_split = candidate_split
4034
- if task_rubric_id is None:
4035
- candidate_rid = obs.get("task_rubric_id")
4036
- if isinstance(candidate_rid, str) and candidate_rid:
4037
- task_rubric_id = candidate_rid
4038
- final = first.get("final") if isinstance(first, dict) else None
4039
- if isinstance(final, dict):
4040
- final_obs = final.get("observation")
4041
- if isinstance(final_obs, dict):
4042
- comp_val = final_obs.get("completion")
4043
- if isinstance(comp_val, str):
4044
- completion = comp_val
4045
- if task_id is None:
4046
- candidate_id = final_obs.get("task_id")
4047
- if isinstance(candidate_id, str) and candidate_id:
4048
- task_id = candidate_id
4049
- if task_split is None:
4050
- candidate_split = final_obs.get("task_split")
4051
- if isinstance(candidate_split, str) and candidate_split:
4052
- task_split = candidate_split
4053
- if task_rubric_id is None:
4054
- candidate_rid = final_obs.get("task_rubric_id")
4055
- if isinstance(candidate_rid, str) and candidate_rid:
4056
- task_rubric_id = candidate_rid
4057
- final_info = final.get("info")
4058
- if isinstance(final_info, dict):
4059
- if task_id is None:
4060
- candidate_id = final_info.get("task_id")
4061
- if isinstance(candidate_id, str) and candidate_id:
4062
- task_id = candidate_id
4063
- if task_split is None:
4064
- candidate_split = final_info.get("task_split")
4065
- if isinstance(candidate_split, str) and candidate_split:
4066
- task_split = candidate_split
4067
- if task_rubric_id is None:
4068
- candidate_rid = final_info.get("task_rubric_id")
4069
- if isinstance(candidate_rid, str) and candidate_rid:
4070
- task_rubric_id = candidate_rid
4071
- if task_id:
4072
- summary.append(f"task_id={task_id}")
4073
- click.echo(" ".join(summary))
4074
- with contextlib.suppress(Exception):
4075
- click.echo(json.dumps(data, indent=2))
4076
- else:
4077
- click.echo(" ".join(summary))
4078
-
4079
- official_score = None
4080
- if isinstance(metrics, dict):
4081
- for key in ("mean_return", "total_reward", "outcome_score"):
4082
- val = metrics.get(key)
4083
- if isinstance(val, int | float):
4084
- official_score = float(val)
4085
- break
4086
- if official_score is None and isinstance(data, dict):
4087
- try:
4088
- reward_val = data["trajectories"][0]["steps"][0].get("reward")
4089
- if isinstance(reward_val, int | float):
4090
- official_score = float(reward_val)
4091
- except Exception:
4092
- pass
4093
-
4094
- if official_score is not None:
4095
- if official_score < 0.0:
4096
- official_score = 0.0
4097
- elif official_score > 1.0:
4098
- official_score = min(1.0, official_score)
4099
-
4100
- judge_scores: dict[str, float | None] = {}
4101
- judges_timings: dict[str, float | None] = {}
4102
- timings: dict[str, Any] = {
4103
- "rollout_s": rollout_elapsed,
4104
- "judges": judges_timings,
4105
- }
4106
- if judge_specs:
4107
- for spec in judge_specs:
4108
- score_value: float | None = None
4109
- judge_elapsed: float | None = None
4110
- # Run judges for all tasks (text-based and trajectory-based)
4111
- # Text-based tasks have completion, trajectory-based tasks use response
4112
- judge_payload = {
4113
- "seed": seed_val,
4114
- "prompt_index": prompt_index,
4115
- "prompt": prompt_text,
4116
- "completion": completion,
4117
- "metrics": metrics,
4118
- "response": data,
4119
- "trace": trace_namespace,
4120
- }
4121
- try:
4122
- judge_start = time.perf_counter()
4123
- result = spec.fn(judge_payload, **spec.kwargs)
4124
- judge_elapsed = time.perf_counter() - judge_start
4125
- if isinstance(result, int | float):
4126
- score_value = float(result)
4127
- except Exception as exc:
4128
- if judge_elapsed is None:
4129
- judge_elapsed = time.perf_counter() - judge_start
4130
- click.echo(f"seed={seed_val} judge[{spec.name}]_error={exc}")
4131
- judges_timings[spec.name] = judge_elapsed
4132
- judge_scores[spec.name] = score_value
4133
-
4134
- if trace_tracer is not None and trace_namespace:
4135
- storage_metadata = {
4136
- "eval_seed": seed_val,
4137
- "prompt_index": prompt_index,
4138
- "task_id": task_id,
4139
- "task_split": task_split,
4140
- "task_rubric_id": task_rubric_id,
4141
- "official_score": official_score,
4142
- "judge_scores": judge_scores,
4143
- "model": selected_model,
4144
- "prompt": prompt_text,
4145
- "completion": completion,
4146
- }
4147
- await _store_trace(trace_tracer, trace_namespace, storage_metadata)
4148
-
4149
- records.append(
4150
- {
4151
- "seed": seed_val,
4152
- "prompt_index": prompt_index,
4153
- "task_id": task_id,
4154
- "task_split": task_split,
4155
- "task_rubric_id": task_rubric_id,
4156
- "official_score": official_score,
4157
- "judge_scores": judge_scores,
4158
- "timings": timings,
4159
- }
4160
- )
4161
-
4162
- await asyncio.gather(*[_run_seed(seed_val) for seed_val in seed_values])
4163
- finally:
4164
- await async_client.aclose()
4165
-
4166
- try:
4167
- asyncio.run(_run_eval())
4168
- finally:
4169
- if trace_tracer is not None and trace_tracer.db is not None:
4170
- asyncio.run(trace_tracer.db.close())
4171
-
4172
- click.echo(
4173
- f"Eval complete: {successes} ok, {failures} failed; model={selected_model}, split={split}"
4174
- )
4175
-
4176
- if outcome_count > 0:
4177
- mean_outcome = outcome_sum / float(outcome_count)
4178
- frac_right = outcome_correct / float(outcome_count)
4179
- click.echo(
4180
- f"Outcome summary: correct={outcome_correct}/{outcome_count} ({frac_right:.2%}), mean_outcome={mean_outcome:.3f}"
4181
- )
4182
-
4183
- if records:
4184
- judge_specs = judge_specs or [] # ensure iterable
4185
- official_scores = [
4186
- r["official_score"] for r in records if r["official_score"] is not None
4187
- ]
4188
- if official_scores:
4189
- click.echo(f" Official mean: {sum(official_scores) / len(official_scores):.3f}")
4190
- else:
4191
- click.echo(" Official mean: n/a")
4192
-
4193
- for spec in judge_specs:
4194
- spec_scores = [
4195
- record["judge_scores"].get(spec.name)
4196
- for record in records
4197
- if record["judge_scores"].get(spec.name) is not None
4198
- ]
4199
- if spec_scores:
4200
- mean_spec = sum(spec_scores) / len(spec_scores)
4201
- click.echo(f" [{spec.name}] mean: {mean_spec:.3f}")
4202
- else:
4203
- click.echo(f" [{spec.name}] mean: n/a")
4204
-
4205
- paired = [
4206
- (
4207
- record["official_score"],
4208
- record["judge_scores"].get(spec.name),
4209
- )
4210
- for record in records
4211
- if record["official_score"] is not None
4212
- and record["judge_scores"].get(spec.name) is not None
4213
- ]
4214
- if len(paired) >= 2:
4215
- corr = _pearson(
4216
- [p[0] for p in paired if p[0] is not None],
4217
- [p[1] for p in paired if p[1] is not None],
4218
- )
4219
- if corr is not None:
4220
- click.echo(f" Pearson r: {corr:.3f}")
4221
- else:
4222
- click.echo(" Pearson r: undefined (zero variance)")
4223
- else:
4224
- click.echo(" Pearson r: n/a (need ≥2 paired scores)")
4225
-
4226
- header = ["Seed", "Prompt", "Official"]
4227
- header.extend(spec.name for spec in judge_specs)
4228
- rows: list[list[str]] = []
4229
- for record in sorted(records, key=lambda r: (r["seed"], r.get("prompt_index") or -1)):
4230
- seed_val = str(record["seed"])
4231
- prompt_idx = (
4232
- str(record["prompt_index"])
4233
- if record["prompt_index"] is not None
4234
- else "-"
4235
- )
4236
- official_val = (
4237
- f"{record['official_score']:.3f}"
4238
- if record["official_score"] is not None
4239
- else "-"
4240
- )
4241
- row = [seed_val, prompt_idx, official_val]
4242
- for spec in judge_specs:
4243
- score_val = record["judge_scores"].get(spec.name)
4244
- row.append(f"{score_val:.3f}" if isinstance(score_val, int | float) else "-")
4245
- rows.append(row)
4246
-
4247
- widths = [len(col) for col in header]
4248
- for row in rows:
4249
- for idx, cell in enumerate(row):
4250
- widths[idx] = max(widths[idx], len(cell))
4251
-
4252
- click.echo("")
4253
- click.echo(" ".join(h.ljust(widths[idx]) for idx, h in enumerate(header)))
4254
- click.echo(" ".join("-" * widths[idx] for idx in range(len(header))))
4255
- for row in rows:
4256
- click.echo(" ".join(cell.ljust(widths[idx]) for idx, cell in enumerate(row)))
4257
-
4258
-
4259
-
4260
- @click.command(
4261
- "filter",
4262
- help="Export filtered tracing sessions to SFT-ready JSONL based on a TOML config.",
4263
- )
4264
- @click.option(
4265
- "--config",
4266
- "config_path",
4267
- type=click.Path(),
4268
- required=True,
4269
- help="Path to TOML config describing the input trace DB, score thresholds, and output JSONL.",
4270
- )
4271
- def filter_command(config_path: str) -> None:
4272
- """Render tracing sessions that match filter rules into SFT JSONL.
4273
-
4274
- The TOML file should contain a `[filter]` table with at least:
4275
-
4276
- db = \"path/to/traces.db\" # sqlite path or URL (sqlite+aiosqlite://...)
4277
- output = \"ft_data/out.jsonl\" # destination JSONL
4278
-
4279
- Optional keys such as `splits`, `task_ids`, `models`, `min_official_score`, or
4280
- `min_judge_scores.my_judge = 0.7` allow you to narrow the dataset down to
4281
- high-quality traces. See `customers/agora_single_file/configs/filter_local.toml`
4282
- for a working example.
4283
- """
4284
- # Parse and validate TOML config
4285
- from synth_ai.task.config import FilterConfig
4286
-
4287
- if _toml is None:
4288
- raise click.ClickException("TOML parser not available; install tomli or use Python 3.11+")
4289
-
4290
- cfg_path = Path(config_path)
4291
- if not cfg_path.exists():
4292
- raise click.ClickException(f"Filter config not found: {cfg_path}")
4293
-
4294
- try:
4295
- config_data = _toml.loads(cfg_path.read_text(encoding="utf-8"))
4296
- except Exception as exc:
4297
- raise click.ClickException(f"Failed to parse TOML '{cfg_path}': {exc}") from exc
4298
-
4299
- filter_cfg_dict = config_data.get("filter") if isinstance(config_data, dict) else None
4300
- if not isinstance(filter_cfg_dict, dict):
4301
- raise click.ClickException("Config must contain a [filter] table")
4302
-
4303
- # Validate config with dataclass
4304
- try:
4305
- filter_cfg = FilterConfig.from_dict(filter_cfg_dict)
4306
- click.echo(f"✓ Config validated: db={filter_cfg.db}, output={filter_cfg.output}")
4307
- if filter_cfg.min_official_score is not None:
4308
- click.echo(f" → Filtering for official score >= {filter_cfg.min_official_score}")
4309
- if filter_cfg.limit:
4310
- click.echo(f" → Limiting to {filter_cfg.limit} examples")
4311
- except (ValueError, TypeError) as validation_error:
4312
- raise click.ClickException(f"Invalid filter config: {validation_error}") from validation_error
4313
-
4314
- # Use validated config
4315
- db_url = filter_cfg.get_db_url()
4316
- output_path = filter_cfg.get_output_path()
4317
-
4318
- # Extract validated fields from dataclass
4319
- splits = set(filter_cfg.splits)
4320
- task_ids = set(filter_cfg.task_ids)
4321
- models = set(filter_cfg.models)
4322
- min_official = filter_cfg.min_official_score
4323
- max_official = filter_cfg.max_official_score
4324
- min_judge_scores = filter_cfg.min_judge_scores
4325
- max_judge_scores = filter_cfg.max_judge_scores
4326
- # Note: min_created_at and max_created_at not yet in FilterConfig dataclass
4327
- min_created = _parse_datetime_for_trace(filter_cfg_dict.get("min_created_at"))
4328
- max_created = _parse_datetime_for_trace(filter_cfg_dict.get("max_created_at"))
4329
- limit = filter_cfg.limit
4330
-
4331
- def _score_ok(value: Any, min_val: Any, max_val: Any) -> bool:
4332
- try:
4333
- if value is None:
4334
- return min_val is None
4335
- value = float(value)
4336
- except Exception:
4337
- return False
4338
- if min_val is not None and value < float(min_val):
4339
- return False
4340
- return not (max_val is not None and value > float(max_val))
4341
-
4342
- async def _run_filter() -> None:
4343
- tracer = SessionTracer(db_url=db_url, auto_save=False)
4344
- await tracer.initialize()
4345
-
4346
- df = await tracer.db.query_traces(
4347
- "SELECT session_id, created_at, metadata FROM session_traces ORDER BY created_at"
4348
- )
4349
- if getattr(df, "empty", True):
4350
- raise click.ClickException("No traces found in database")
4351
-
4352
- sessions = df.to_dict("records")
4353
- accepted: list[dict[str, Any]] = []
4354
-
4355
- for row in sessions:
4356
- metadata_raw = row.get("metadata")
4357
- if isinstance(metadata_raw, str):
4358
- try:
4359
- metadata = json.loads(metadata_raw)
4360
- except Exception:
4361
- metadata = {}
4362
- elif isinstance(metadata_raw, dict):
4363
- metadata = dict(metadata_raw)
4364
- else:
4365
- metadata = {}
4366
-
4367
- created_at_raw = row.get("created_at")
4368
- created_at_dt = _parse_datetime_for_trace(created_at_raw)
4369
-
4370
- session_id = row.get("session_id")
4371
-
4372
- if splits and metadata.get("task_split") not in splits:
4373
- continue
4374
- if task_ids and metadata.get("task_id") not in task_ids:
4375
- continue
4376
- if models and metadata.get("model") not in models:
4377
- continue
4378
-
4379
- if min_created and (created_at_dt is None or created_at_dt < min_created):
4380
- continue
4381
- if max_created and (created_at_dt is None or created_at_dt > max_created):
4382
- continue
4383
-
4384
- # Check against outcome_rewards if score filter is set
4385
- total_reward = None
4386
- achievements_count = None
4387
- if min_official is not None or max_official is not None:
4388
- reward_query = "SELECT total_reward, achievements_count FROM outcome_rewards WHERE session_id = :session_id"
4389
- reward_rows = await tracer.db.query_traces(reward_query, {"session_id": session_id})
4390
- reward_records = reward_rows.to_dict("records") if hasattr(reward_rows, "to_dict") else []
4391
- if reward_records:
4392
- total_reward = reward_records[0].get("total_reward")
4393
- achievements_count = reward_records[0].get("achievements_count")
4394
- if not _score_ok(total_reward, min_official, max_official):
4395
- continue
4396
- elif min_official is not None:
4397
- # No reward found, but score filter requires it
4398
- continue
4399
-
4400
- judge_scores = metadata.get("judge_scores") or {}
4401
- include = True
4402
- for judge_name, threshold in (min_judge_scores or {}).items():
4403
- if not _score_ok(judge_scores.get(judge_name), threshold, None):
4404
- include = False
4405
- break
4406
- if not include:
4407
- continue
4408
- for judge_name, threshold in (max_judge_scores or {}).items():
4409
- if not _score_ok(judge_scores.get(judge_name), None, threshold):
4410
- include = False
4411
- break
4412
- if not include:
4413
- continue
4414
-
4415
- # Query messages for this session
4416
- messages_query = """
4417
- SELECT message_type, content, timestamp
4418
- FROM messages
4419
- WHERE session_id = :session_id
4420
- ORDER BY timestamp ASC, id ASC
4421
- """
4422
- msg_df = await tracer.db.query_traces(messages_query, {"session_id": session_id})
4423
- message_rows = msg_df.to_dict("records") if hasattr(msg_df, "to_dict") else []
4424
-
4425
- if not message_rows:
4426
- # Fallback: check if prompt/completion in metadata (old format)
4427
- prompt = metadata.get("prompt") or ""
4428
- completion = metadata.get("completion") or ""
4429
- if prompt and completion:
4430
- record = {
4431
- "messages": [
4432
- {"role": "user", "content": str(prompt)},
4433
- {"role": "assistant", "content": str(completion)},
4434
- ],
4435
- "metadata": {
4436
- "session_id": session_id,
4437
- "env_name": metadata.get("env_name"),
4438
- "policy_name": metadata.get("policy_name"),
4439
- "seed": metadata.get("seed"),
4440
- "total_reward": total_reward,
4441
- "achievements_count": achievements_count,
4442
- "model": metadata.get("model"),
4443
- "created_at": created_at_dt.isoformat() if created_at_dt else created_at_raw,
4444
- },
4445
- }
4446
- accepted.append(record)
4447
- continue
4448
-
4449
- # Extract user/assistant pairs from messages
4450
- for i, msg_row in enumerate(message_rows):
4451
- msg_type = msg_row.get("message_type")
4452
- content_raw = msg_row.get("content")
4453
-
4454
- # Look for user message
4455
- if msg_type in ("user", "policy_user_prompt"):
4456
- # Find next policy_system_prompt or assistant
4457
- assistant_msg = None
4458
- for j in range(i + 1, len(message_rows)):
4459
- next_type = message_rows[j].get("message_type")
4460
- if next_type in ("assistant", "policy_system_prompt"):
4461
- if next_type == "assistant":
4462
- assistant_msg = message_rows[j]
4463
- break
4464
-
4465
- # Parse content
4466
- try:
4467
- user_content = json.loads(content_raw) if isinstance(content_raw, str) else content_raw
4468
- except Exception:
4469
- user_content = content_raw
4470
-
4471
- # If user_content is a message dict with a 'content' key, extract it
4472
- if isinstance(user_content, dict) and "content" in user_content:
4473
- user_content = user_content["content"]
4474
-
4475
- # Extract text from structured content
4476
- def extract_text(content: Any) -> str:
4477
- if isinstance(content, str):
4478
- return content
4479
- if isinstance(content, dict):
4480
- # Try payload.content for user prompts
4481
- if "payload" in content and isinstance(content["payload"], dict):
4482
- payload = content["payload"]
4483
- if "content" in payload:
4484
- return extract_text(payload["content"])
4485
- # Try common keys
4486
- for key in ["text", "content", "content_text"]:
4487
- if key in content:
4488
- val = content[key]
4489
- if isinstance(val, str):
4490
- return val
4491
- return json.dumps(content)
4492
- if isinstance(content, list):
4493
- # Multimodal content - concatenate text parts
4494
- parts = []
4495
- for item in content:
4496
- if isinstance(item, dict) and item.get("type") == "text":
4497
- parts.append(item.get("text", ""))
4498
- return " ".join(parts) if parts else str(content)
4499
- return str(content)
4500
-
4501
- user_text = extract_text(user_content)
4502
-
4503
- # For assistant, we might not have it recorded, so use tool calls as completion
4504
- assistant_text = ""
4505
- assistant_content = None
4506
- if assistant_msg:
4507
- assistant_content_raw = assistant_msg.get("content")
4508
- try:
4509
- assistant_content = json.loads(assistant_content_raw) if isinstance(assistant_content_raw, str) else assistant_content_raw
4510
- except Exception:
4511
- assistant_content = assistant_content_raw
4512
-
4513
- # If assistant_content is a message dict with a 'content' key, extract it
4514
- if isinstance(assistant_content, dict) and "content" in assistant_content:
4515
- assistant_content = assistant_content["content"]
4516
-
4517
- assistant_text = extract_text(assistant_content)
4518
-
4519
- if not user_text:
4520
- continue
4521
-
4522
- # Use full multimodal content if it's a list (contains images), otherwise use text
4523
- user_content_for_message = user_content if isinstance(user_content, list) else user_text
4524
- assistant_content_for_message = assistant_content if isinstance(assistant_content, list) else (assistant_text if assistant_text else "[no response recorded]")
4525
-
4526
- record = {
4527
- "messages": [
4528
- {"role": "user", "content": user_content_for_message},
4529
- {"role": "assistant", "content": assistant_content_for_message},
4530
- ],
4531
- "metadata": {
4532
- "session_id": session_id,
4533
- "env_name": metadata.get("env_name"),
4534
- "policy_name": metadata.get("policy_name"),
4535
- "seed": metadata.get("seed"),
4536
- "total_reward": total_reward,
4537
- "achievements_count": achievements_count,
4538
- "model": metadata.get("model"),
4539
- "created_at": created_at_dt.isoformat() if created_at_dt else created_at_raw,
4540
- },
4541
- }
4542
- accepted.append(record)
4543
-
4544
- if not accepted:
4545
- raise click.ClickException("No sessions matched the provided filters")
4546
-
4547
- if limit is not None and limit > 0:
4548
- accepted = accepted[:limit]
4549
-
4550
- output_path.parent.mkdir(parents=True, exist_ok=True)
4551
- with output_path.open("w", encoding="utf-8") as handle:
4552
- for item in accepted:
4553
- handle.write(json.dumps(item, ensure_ascii=False))
4554
- handle.write("\n")
4555
-
4556
- click.echo(f"Wrote {len(accepted)} examples -> {output_path}")
4557
- await tracer.db.close()
3128
+ eval_command = eval_core.command
4558
3129
 
4559
- asyncio.run(_run_filter())
3130
+ filter_command = filter_core.command
4560
3131
 
4561
3132
 
4562
3133
  def register_eval(cli: click.Group) -> None: