synth-ai 0.2.13.dev1__py3-none-any.whl → 0.2.14__py3-none-any.whl

This diff compares two publicly available versions of the package as released to a supported registry. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their public registries.

Potentially problematic release.


This version of synth-ai might be problematic.

Files changed (291)
  1. examples/multi_step/configs/README_verilog_rl.md +77 -0
  2. examples/multi_step/configs/VERILOG_REWARDS.md +90 -0
  3. examples/multi_step/configs/VERILOG_RL_CHECKLIST.md +183 -0
  4. examples/multi_step/configs/crafter_eval_synth_qwen4b.toml +35 -0
  5. examples/multi_step/configs/crafter_eval_text_only_groq_qwen32b.toml +36 -0
  6. examples/multi_step/configs/crafter_rl_stepwise_hosted_judge.toml +17 -5
  7. examples/multi_step/configs/crafter_synth_backend.md +40 -0
  8. examples/multi_step/configs/verilog_eval_groq_qwen32b.toml +31 -0
  9. examples/multi_step/configs/verilog_eval_synth_qwen8b.toml +33 -0
  10. examples/multi_step/configs/verilog_rl_lora.toml +190 -0
  11. examples/multi_step/judges/crafter_backend_judge.py +220 -0
  12. examples/multi_step/judges/verilog_backend_judge.py +234 -0
  13. examples/multi_step/readme.md +48 -0
  14. examples/multi_step/verilog_rl_lora.md +218 -0
  15. examples/qwen_coder/configs/coder_lora_30b.toml +1 -1
  16. examples/sft/evaluate.py +2 -0
  17. examples/sft/generate_traces.py +2 -0
  18. examples/swe/task_app/grpo_swe_mini.py +56 -26
  19. examples/swe/task_app/hosted/rollout.py +42 -0
  20. examples/swe/task_app/hosted/test_service.py +5 -6
  21. examples/task_apps/IMAGE_ONLY_EVAL_QUICKSTART.md +258 -0
  22. examples/task_apps/TESTING.md +275 -0
  23. examples/task_apps/__init__.py +0 -0
  24. examples/task_apps/crafter/CREATE_SFT_DATASET.md +273 -0
  25. examples/task_apps/crafter/EVAL_IMAGE_ONLY_RESULTS.md +152 -0
  26. examples/task_apps/crafter/FILTER_COMMAND_STATUS.md +174 -0
  27. examples/task_apps/crafter/FILTER_COMMAND_SUCCESS.md +268 -0
  28. examples/task_apps/crafter/QUERY_EXAMPLES.md +203 -0
  29. examples/task_apps/crafter/README_IMAGE_ONLY_EVAL.md +316 -0
  30. examples/task_apps/crafter/__init__.py +0 -0
  31. examples/task_apps/crafter/eval_image_only_gpt4o.toml +28 -0
  32. examples/task_apps/crafter/eval_text_only_groq_llama.toml +36 -0
  33. examples/task_apps/crafter/filter_sft_dataset.toml +16 -0
  34. examples/task_apps/crafter/task_app/__init__.py +5 -0
  35. examples/{warming_up_to_rl → task_apps/crafter}/task_app/grpo_crafter.py +324 -21
  36. examples/{warming_up_to_rl → task_apps/crafter}/task_app/grpo_crafter_task_app.py +1 -1
  37. examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/environment.py +10 -0
  38. examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/policy.py +76 -7
  39. examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/react_agent.py +17 -2
  40. examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/inference/openai_client.py +25 -3
  41. examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/policy_routes.py +77 -4
  42. examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/rollout.py +117 -9
  43. examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/test_service.py +5 -6
  44. examples/task_apps/crafter/task_app/synth_envs_hosted/utils.py +218 -0
  45. examples/task_apps/dev/pokemon_emerald/__init__.py +2 -0
  46. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/README.md +811 -0
  47. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/__init__.py +120 -0
  48. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/action.py +160 -0
  49. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/memory.py +155 -0
  50. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/perception.py +69 -0
  51. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/planning.py +96 -0
  52. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/simple.py +1502 -0
  53. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/system_prompt.py +4 -0
  54. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/grab_map.py +68 -0
  55. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/manual.py +216 -0
  56. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/__init__.py +35 -0
  57. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/emerald_utils.py +631 -0
  58. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/emulator.py +1544 -0
  59. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/enums.py +1428 -0
  60. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/memory_reader.py +4848 -0
  61. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/types.py +41 -0
  62. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/utils.py +298 -0
  63. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pyproject.toml +95 -0
  64. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/run.py +204 -0
  65. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/server/__init__.py +0 -0
  66. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/server/app.py +2152 -0
  67. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/server/client.py +429 -0
  68. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/server/frame_server.py +155 -0
  69. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/README.md +78 -0
  70. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/__init__.py +0 -0
  71. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/run_tests.py +122 -0
  72. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_agent_direct.py +76 -0
  73. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_agent_prompts.py +413 -0
  74. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_battle_state_formatting.py +204 -0
  75. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_dialogue_detection.py +133 -0
  76. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_dialogue_detection_comprehensive.py +229 -0
  77. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_direct_agent_emulator.py +300 -0
  78. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_fps_adjustment_pytest.py +205 -0
  79. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_house_to_outside_direct.py +200 -0
  80. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_house_to_outside_transition.py +284 -0
  81. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_map_ground_truth_comparison.py +468 -0
  82. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_memory_map.py +575 -0
  83. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_server_map_validation.py +311 -0
  84. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_torchic_state.py +259 -0
  85. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/__init__.py +0 -0
  86. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/anticheat.py +372 -0
  87. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/checkpoint.py +296 -0
  88. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/error_handler.py +275 -0
  89. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/get_local_ip.py +22 -0
  90. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/helpers.py +44 -0
  91. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/llm_logger.py +514 -0
  92. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/map_formatter.py +415 -0
  93. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/map_stitcher.py +1763 -0
  94. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/map_stitcher_singleton.py +33 -0
  95. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/map_trimmer.py +106 -0
  96. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/map_visualizer.py +334 -0
  97. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/ocr_dialogue.py +1020 -0
  98. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/recording.py +188 -0
  99. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/state_formatter.py +1481 -0
  100. examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/vlm.py +862 -0
  101. examples/task_apps/dev/pokemon_emerald/modal_app.py +114 -0
  102. examples/task_apps/dev/pokemon_emerald/task_app/README.md +81 -0
  103. examples/task_apps/dev/pokemon_emerald/task_app/__init__.py +6 -0
  104. examples/task_apps/dev/pokemon_emerald/task_app/pokemon_emerald.py +685 -0
  105. examples/task_apps/enron/__init__.py +1 -0
  106. examples/task_apps/enron/eval_groq_qwen32.toml +16 -0
  107. examples/task_apps/enron/filter_sft.toml +5 -0
  108. examples/task_apps/enron/task_app/README.md +14 -0
  109. examples/task_apps/enron/task_app/__init__.py +1 -0
  110. examples/task_apps/enron/task_app/grpo_enron.py +906 -0
  111. examples/task_apps/enron/task_app/grpo_enron_task_app.py +146 -0
  112. examples/task_apps/enron/tests/__init__.py +4 -0
  113. examples/task_apps/enron/tests/conftest.py +115 -0
  114. examples/task_apps/enron/tests/integration/__init__.py +4 -0
  115. examples/task_apps/enron/tests/integration/test_enron_eval.py +179 -0
  116. examples/task_apps/enron/tests/integration/test_enron_rollout.py +135 -0
  117. examples/task_apps/enron/tests/unit/__init__.py +4 -0
  118. examples/task_apps/enron/tests/unit/test_enron_environment.py +126 -0
  119. examples/task_apps/math/__init__.py +0 -0
  120. examples/{rl/task_app → task_apps/math}/math_single_step.py +19 -10
  121. examples/task_apps/pokemon_battle/__init__.py +2 -0
  122. examples/task_apps/pokemon_battle/modal_app.py +104 -0
  123. examples/task_apps/pokemon_battle/task_app/README.md +68 -0
  124. examples/task_apps/pokemon_battle/task_app/__init__.py +6 -0
  125. examples/task_apps/pokemon_battle/task_app/pokemon_showdown.py +932 -0
  126. examples/task_apps/pokemon_red/EVAL_IMAGE_ONLY_COMPLETE.md +283 -0
  127. examples/task_apps/pokemon_red/EVAL_IMAGE_ONLY_STATUS.md +155 -0
  128. examples/task_apps/pokemon_red/README.md +357 -0
  129. examples/task_apps/pokemon_red/README_IMAGE_ONLY_EVAL.md +415 -0
  130. examples/task_apps/pokemon_red/__init__.py +3 -0
  131. examples/task_apps/pokemon_red/eval_image_only_gpt4o.toml +29 -0
  132. examples/task_apps/pokemon_red/eval_pokemon_red_policy.py +225 -0
  133. examples/task_apps/pokemon_red/pallet_town_rl_config.toml +75 -0
  134. examples/task_apps/pokemon_red/task_app.py +799 -0
  135. examples/task_apps/pokemon_red/test_pallet_town_rewards.py +193 -0
  136. examples/task_apps/sokoban/README.md +307 -0
  137. examples/task_apps/sokoban/__init__.py +3 -0
  138. examples/task_apps/sokoban/eval_groq_qwen32.toml +16 -0
  139. examples/task_apps/sokoban/eval_openai_gpt5.toml +16 -0
  140. examples/task_apps/sokoban/filter_sft.toml +5 -0
  141. examples/task_apps/sokoban/task_app.py +1058 -0
  142. examples/task_apps/sokoban/tests/__init__.py +4 -0
  143. examples/task_apps/sokoban/tests/conftest.py +113 -0
  144. examples/task_apps/sokoban/tests/integration/__init__.py +4 -0
  145. examples/task_apps/sokoban/tests/integration/test_sokoban_eval.py +57 -0
  146. examples/task_apps/sokoban/tests/integration/test_sokoban_rollout.py +198 -0
  147. examples/task_apps/sokoban/tests/unit/__init__.py +4 -0
  148. examples/task_apps/sokoban/tests/unit/test_sokoban_environment.py +114 -0
  149. examples/task_apps/verilog/__init__.py +1 -0
  150. examples/task_apps/verilog/eval_groq_qwen32b.toml +24 -0
  151. examples/task_apps/verilog/filter_sft.toml +5 -0
  152. examples/task_apps/verilog/task_app/README.md +12 -0
  153. examples/task_apps/verilog/task_app/__init__.py +1 -0
  154. examples/task_apps/verilog/task_app/grpo_verilog.py +1166 -0
  155. examples/task_apps/verilog/task_app/grpo_verilog_task_app.py +145 -0
  156. examples/task_apps/verilog/tests/__init__.py +4 -0
  157. examples/task_apps/verilog/tests/conftest.py +115 -0
  158. examples/task_apps/verilog/tests/integration/__init__.py +4 -0
  159. examples/task_apps/verilog/tests/integration/test_verilog_eval.py +181 -0
  160. examples/task_apps/verilog/tests/integration/test_verilog_rollout.py +55 -0
  161. examples/task_apps/verilog/tests/unit/__init__.py +4 -0
  162. examples/task_apps/verilog/tests/unit/test_verilog_scoring.py +118 -0
  163. examples/vlm/crafter_openai_vlm_agent.py +4 -4
  164. examples/vlm/run_crafter_vlm_benchmark.py +4 -4
  165. examples/warming_up_to_rl/groq_test.py +2 -0
  166. examples/warming_up_to_rl/run_local_rollout.py +2 -0
  167. examples/warming_up_to_rl/run_local_rollout_modal.py +2 -0
  168. examples/warming_up_to_rl/run_local_rollout_parallel.py +2 -0
  169. examples/warming_up_to_rl/run_local_rollout_traced.py +2 -0
  170. examples/warming_up_to_rl/run_rollout_remote.py +2 -0
  171. examples/workflows/__init__.py +0 -0
  172. examples/workflows/math_rl/__init__.py +0 -0
  173. examples/workflows/math_rl/download_dataset.py +80 -0
  174. synth_ai/__init__.py +2 -2
  175. synth_ai/api/models/supported.py +1 -0
  176. synth_ai/api/train/builders.py +25 -11
  177. synth_ai/api/train/cli.py +12 -6
  178. synth_ai/api/train/configs/__init__.py +10 -10
  179. synth_ai/api/train/configs/rl.py +5 -4
  180. synth_ai/api/train/configs/sft.py +4 -3
  181. synth_ai/api/train/env_resolver.py +5 -2
  182. synth_ai/api/train/supported_algos.py +10 -5
  183. synth_ai/api/train/utils.py +7 -4
  184. synth_ai/cli/__init__.py +48 -59
  185. synth_ai/cli/_modal_wrapper.py +3 -2
  186. synth_ai/cli/_storage.py +4 -3
  187. synth_ai/cli/_validate_task_app.py +11 -0
  188. synth_ai/cli/balance.py +4 -3
  189. synth_ai/cli/calc.py +2 -2
  190. synth_ai/cli/demo.py +14 -7
  191. synth_ai/cli/legacy_root_backup.py +1 -1
  192. synth_ai/cli/recent.py +1 -1
  193. synth_ai/cli/rl_demo.py +8 -7
  194. synth_ai/cli/root.py +0 -97
  195. synth_ai/cli/status.py +1 -1
  196. synth_ai/cli/task_apps.py +1922 -190
  197. synth_ai/cli/traces.py +1 -1
  198. synth_ai/cli/tui.py +57 -0
  199. synth_ai/cli/turso.py +1 -1
  200. synth_ai/cli/watch.py +1 -1
  201. synth_ai/demos/demo_task_apps/crafter/grpo_crafter_task_app.py +29 -17
  202. synth_ai/environments/examples/crafter_classic/environment.py +1 -1
  203. synth_ai/environments/examples/enron/engine.py +7 -2
  204. synth_ai/environments/examples/enron/environment.py +68 -0
  205. synth_ai/environments/examples/red/engine.py +27 -0
  206. synth_ai/environments/examples/red/engine_helpers/memory_map.py +7 -0
  207. synth_ai/environments/examples/red/engine_helpers/reward_library/pallet_town_progression.py +477 -0
  208. synth_ai/environments/examples/red/engine_helpers/state_extraction.py +32 -0
  209. synth_ai/environments/examples/red/environment.py +60 -0
  210. synth_ai/environments/examples/sokoban/taskset.py +116 -0
  211. synth_ai/environments/examples/verilog/engine.py +104 -12
  212. synth_ai/evals/client.py +58 -61
  213. synth_ai/jobs/client.py +16 -4
  214. synth_ai/judge_schemas.py +9 -9
  215. synth_ai/py.typed +0 -0
  216. synth_ai/task/__init__.py +24 -5
  217. synth_ai/task/apps/__init__.py +1 -0
  218. synth_ai/task/config.py +257 -0
  219. synth_ai/task/contracts.py +138 -39
  220. synth_ai/task/proxy.py +48 -56
  221. synth_ai/task/rubrics/__init__.py +56 -0
  222. synth_ai/task/rubrics/loaders.py +152 -0
  223. synth_ai/task/rubrics/models.py +57 -0
  224. synth_ai/task/rubrics/scoring.py +116 -0
  225. synth_ai/{rubrics/validators.py → task/rubrics/strict.py} +53 -30
  226. synth_ai/task/server.py +8 -7
  227. synth_ai/task/trace_correlation_helpers.py +315 -0
  228. synth_ai/task/validators.py +413 -6
  229. synth_ai/tracing_v3/abstractions.py +3 -3
  230. synth_ai/tracing_v3/decorators.py +7 -3
  231. synth_ai/tracing_v3/llm_call_record_helpers.py +5 -5
  232. synth_ai/tracing_v3/replica_sync.py +4 -4
  233. synth_ai/tracing_v3/serialization.py +5 -5
  234. synth_ai/tracing_v3/session_tracer.py +16 -6
  235. synth_ai/tracing_v3/storage/base.py +29 -29
  236. synth_ai/tracing_v3/storage/config.py +3 -3
  237. synth_ai/tracing_v3/trace_utils.py +317 -0
  238. synth_ai/tracing_v3/turso/daemon.py +8 -7
  239. synth_ai/tracing_v3/turso/native_manager.py +66 -43
  240. synth_ai/tracing_v3/utils.py +3 -3
  241. synth_ai/tui/__init__.py +5 -0
  242. synth_ai/tui/__main__.py +13 -0
  243. synth_ai/tui/cli/__init__.py +1 -0
  244. synth_ai/tui/cli/query_experiments.py +164 -0
  245. synth_ai/tui/cli/query_experiments_v3.py +164 -0
  246. synth_ai/tui/dashboard.py +906 -0
  247. {synth_ai-0.2.13.dev1.dist-info → synth_ai-0.2.14.dist-info}/METADATA +4 -1
  248. {synth_ai-0.2.13.dev1.dist-info → synth_ai-0.2.14.dist-info}/RECORD +278 -126
  249. examples/agora_ex/README_MoE.md +0 -224
  250. examples/agora_ex/__init__.py +0 -7
  251. examples/agora_ex/agora_ex.py +0 -65
  252. examples/agora_ex/agora_ex_task_app.py +0 -590
  253. examples/agora_ex/configs/rl_lora_qwen3_moe_2xh200.toml +0 -121
  254. examples/agora_ex/reward_fn_grpo-human.py +0 -129
  255. examples/agora_ex/system_prompt_CURRENT.md +0 -63
  256. examples/agora_ex/task_app/agora_ex_task_app.py +0 -590
  257. examples/agora_ex/task_app/reward_fn_grpo-human.py +0 -129
  258. examples/agora_ex/task_app/system_prompt_CURRENT.md +0 -63
  259. examples/warming_up_to_rl/task_app/synth_envs_hosted/utils.py +0 -62
  260. synth_ai/rubrics/__init__.py +0 -22
  261. synth_ai/task/rubrics.py +0 -219
  262. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/README.md +0 -0
  263. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/README.md +0 -0
  264. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/__init__.py +0 -0
  265. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/branching.py +0 -0
  266. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/environment_routes.py +0 -0
  267. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/__init__.py +0 -0
  268. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/__init__.py +0 -0
  269. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/app.py +0 -0
  270. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/shared.py +0 -0
  271. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/tools.py +0 -0
  272. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/hosted_app.py +0 -0
  273. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/inference/__init__.py +0 -0
  274. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/main.py +0 -0
  275. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/registry.py +0 -0
  276. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/storage/__init__.py +0 -0
  277. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/storage/volume.py +0 -0
  278. /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/test_agents.py +0 -0
  279. /examples/{rl/task_app → task_apps/math}/README.md +0 -0
  280. /examples/{rl/task_app → task_apps/math}/math_task_app.py +0 -0
  281. /examples/{rl → workflows/math_rl}/configs/eval_base_qwen.toml +0 -0
  282. /examples/{rl → workflows/math_rl}/configs/eval_rl_qwen.toml +0 -0
  283. /examples/{rl → workflows/math_rl}/configs/rl_from_base_qwen.toml +0 -0
  284. /examples/{rl → workflows/math_rl}/configs/rl_from_base_qwen17.toml +0 -0
  285. /examples/{rl → workflows/math_rl}/configs/rl_from_ft_qwen.toml +0 -0
  286. /examples/{rl → workflows/math_rl}/run_eval.py +0 -0
  287. /examples/{rl → workflows/math_rl}/run_rl_and_save.py +0 -0
  288. {synth_ai-0.2.13.dev1.dist-info → synth_ai-0.2.14.dist-info}/WHEEL +0 -0
  289. {synth_ai-0.2.13.dev1.dist-info → synth_ai-0.2.14.dist-info}/entry_points.txt +0 -0
  290. {synth_ai-0.2.13.dev1.dist-info → synth_ai-0.2.14.dist-info}/licenses/LICENSE +0 -0
  291. {synth_ai-0.2.13.dev1.dist-info → synth_ai-0.2.14.dist-info}/top_level.txt +0 -0
synth_ai/cli/task_apps.py CHANGED
@@ -1,48 +1,80 @@
 from __future__ import annotations

+import argparse
 import ast
 import asyncio
 import contextlib
+import functools
 import hashlib
 import importlib
 import importlib.util
 import inspect
 import json
 import os
+import shlex
 import shutil
 import signal
+import sqlite3
 import subprocess
 import sys
 import tempfile
 import textwrap
+import time
 import types
+import uuid
 from collections.abc import Callable, Iterable, Iterator, Sequence
 from dataclasses import dataclass
+from datetime import datetime, timezone
 from pathlib import Path
-from typing import Any, cast
+from typing import Any, Optional, cast

 try: # Python 3.11+
     import tomllib as _toml
 except Exception: # pragma: no cover - fallback
     _toml = None # type: ignore
-import uuid

 import click
 from click.exceptions import Abort

+# Tracing imports - make conditional for optional dependencies
+try:
+    from synth_ai.tracing_v3 import ( # type: ignore[import-untyped]
+        BaseEvent,
+        EnvironmentEvent,
+        RuntimeEvent,
+        SessionEventMarkovBlanketMessage,
+        SessionMessageContent,
+        SessionTimeStep,
+        SessionTracer,
+        TimeRecord,
+    )
+    from synth_ai.tracing_v3 import ( # type: ignore[import-untyped]
+        SessionTrace as V3SessionTrace,
+    )
+    _TRACING_AVAILABLE = True
+except (ImportError, ModuleNotFoundError, TypeError):
+    # Tracing system not available (missing optional dependencies)
+    BaseEvent = EnvironmentEvent = RuntimeEvent = None # type: ignore
+    SessionEventMarkovBlanketMessage = SessionMessageContent = None # type: ignore
+    SessionTimeStep = SessionTracer = TimeRecord = None # type: ignore
+    V3SessionTrace = None # type: ignore
+    _TRACING_AVAILABLE = False
+
 # ---------------------------------------------------------------------------
 # Dynamic imports to avoid hard dependencies during type checking.
 # ---------------------------------------------------------------------------
 ModalDeploymentConfigType = TaskAppConfigType = TaskAppEntryType = Any

 try: # Resolve base URL defaults lazily
-    _config_module = importlib.import_module("synth_ai.config.base_url")
+    _config_module = cast(
+        Any, importlib.import_module("synth_ai.config.base_url")
+    )
     PROD_BASE_URL_DEFAULT = cast(str, _config_module.PROD_BASE_URL_DEFAULT)
 except Exception: # pragma: no cover - fallback
     PROD_BASE_URL_DEFAULT = "https://agent-learning.onrender.com"

 try:
-    _task_apps_module = importlib.import_module("synth_ai.task.apps")
+    _task_apps_module = cast(Any, importlib.import_module("synth_ai.task.apps"))
     ModalDeploymentConfig = cast(
         type[ModalDeploymentConfigType], _task_apps_module.ModalDeploymentConfig
     )
@@ -53,21 +85,23 @@ except Exception as exc: # pragma: no cover - critical dependency
     raise RuntimeError("Unable to load task app registry") from exc

 try:
-    _task_server_module = importlib.import_module("synth_ai.task.server")
-    create_task_app = _task_server_module.create_task_app
-    run_task_app = _task_server_module.run_task_app
+    _task_server_module = cast(Any, importlib.import_module("synth_ai.task.server"))
+    create_task_app = cast(Callable[..., Any], _task_server_module.create_task_app)
+    run_task_app = cast(Callable[..., Any], _task_server_module.run_task_app)
 except Exception as exc: # pragma: no cover - critical dependency
     raise RuntimeError("Unable to load task app server utilities") from exc


-def _load_demo_directory() -> Path | None:
+def _load_demo_directory() -> Optional[Path]:
     """Return the demo task apps directory if available."""

     try:
-        module = importlib.import_module("synth_ai.demos.demo_task_apps.core")
-        loader = module.load_demo_dir
+        module = cast(
+            Any, importlib.import_module("synth_ai.demos.demo_task_apps.core")
+        )
+        loader = cast(Callable[[], Optional[str | Path]], module.load_demo_dir)
         demo_dir = loader()
-        if isinstance(demo_dir, (str, Path)):
+        if isinstance(demo_dir, str | Path):
             demo_path = Path(demo_dir)
             if demo_path.exists():
                 return demo_path.resolve()
@@ -105,13 +139,32 @@ DEFAULT_SEARCH_RELATIVE = (
 )


+def _pearson(xs: Sequence[float], ys: Sequence[float]) -> Optional[float]:
+    if len(xs) != len(ys) or len(xs) < 2:
+        return None
+    mean_x = sum(xs) / len(xs)
+    mean_y = sum(ys) / len(ys)
+    num = 0.0
+    denom_x = 0.0
+    denom_y = 0.0
+    for x, y in zip(xs, ys, strict=False):
+        dx = x - mean_x
+        dy = y - mean_y
+        num += dx * dy
+        denom_x += dx * dx
+        denom_y += dy * dy
+    if denom_x <= 0 or denom_y <= 0:
+        return None
+    return num / (denom_x ** 0.5 * denom_y ** 0.5)
+
+
 @dataclass
 class AppChoice:
     app_id: str
     label: str
     path: Path
     source: str
-    description: str | None = None
+    description: Optional[str] = None
     aliases: tuple[str, ...] = ()
     entry: TaskAppEntryType | None = None
     entry_loader: Callable[[], TaskAppEntryType] | None = None
@@ -128,6 +181,193 @@ class AppChoice:
         return entry


+@dataclass
+class JudgeSpec:
+    name: str
+    fn: Callable[..., Any]
+    kwargs: dict[str, Any]
+
+
+def _parse_datetime_for_trace(value: Any) -> Optional[datetime]:
+    if isinstance(value, datetime):
+        return value if value.tzinfo else value.replace(tzinfo=timezone.utc)
+    if isinstance(value, str):
+        value = value.replace("Z", "+00:00")
+        try:
+            dt = datetime.fromisoformat(value)
+        except ValueError:
+            try:
+                dt = datetime.fromtimestamp(float(value), tz=timezone.utc)
+            except Exception:
+                return None
+        return dt if dt.tzinfo else dt.replace(tzinfo=timezone.utc)
+    if isinstance(value, int | float):
+        return datetime.fromtimestamp(float(value), tz=timezone.utc)
+    return None
+
+
+def _time_record_from_dict(payload: dict[str, Any] | None) -> TimeRecord:
+    payload = payload or {}
+    event_time = payload.get("event_time")
+    if not isinstance(event_time, int | float):
+        try:
+            event_time = float(event_time)
+        except Exception:
+            event_time = float(time.time())
+    message_time = payload.get("message_time")
+    if message_time is not None:
+        try:
+            message_time = int(message_time)
+        except Exception:
+            message_time = None
+    return TimeRecord(event_time=event_time, message_time=message_time)
+
+
+def _event_from_dict(payload: dict[str, Any]) -> BaseEvent:
+    base_kwargs = {
+        "system_instance_id": payload.get("system_instance_id", ""),
+        "time_record": _time_record_from_dict(payload.get("time_record")),
+        "metadata": payload.get("metadata") or {},
+        "event_metadata": payload.get("event_metadata"),
+    }
+    if "actions" in payload:
+        return RuntimeEvent(actions=payload.get("actions") or [], **base_kwargs)
+    if any(key in payload for key in ("reward", "terminated", "truncated")):
+        return EnvironmentEvent(
+            reward=float(payload.get("reward", 0.0) or 0.0),
+            terminated=bool(payload.get("terminated", False)),
+            truncated=bool(payload.get("truncated", False)),
+            system_state_before=payload.get("system_state_before"),
+            system_state_after=payload.get("system_state_after"),
+            **base_kwargs,
+        )
+    return BaseEvent(**base_kwargs)
+
+
+def _markov_message_from_dict(payload: dict[str, Any]) -> SessionEventMarkovBlanketMessage:
+    content_payload = payload.get("content") or {}
+    content = SessionMessageContent(
+        text=content_payload.get("text"),
+        json_payload=content_payload.get("json_payload"),
+    )
+    raw_type = (payload.get("message_type") or "").lower()
+    if raw_type == "observation":
+        normalized_type = "system"
+    elif raw_type == "action":
+        normalized_type = "assistant"
+    elif raw_type in {"user", "assistant", "system", "tool_use", "tool_result"}:
+        normalized_type = raw_type
+    else:
+        normalized_type = "system"
+
+    return SessionEventMarkovBlanketMessage(
+        content=content,
+        message_type=normalized_type,
+        time_record=_time_record_from_dict(payload.get("time_record")),
+        metadata=payload.get("metadata") or {},
+    )
+
+
+def _step_from_dict(payload: dict[str, Any]) -> SessionTimeStep:
+    events = [
+        _event_from_dict(event)
+        for event in payload.get("events", [])
+        if isinstance(event, dict)
+    ]
+    messages = [
+        _markov_message_from_dict(msg)
+        for msg in payload.get("markov_blanket_messages", [])
+        if isinstance(msg, dict)
+    ]
+    timestamp = _parse_datetime_for_trace(payload.get("timestamp")) or datetime.now(timezone.utc)
+    completed_at = _parse_datetime_for_trace(payload.get("completed_at"))
+    return SessionTimeStep(
+        step_id=payload.get("step_id", ""),
+        step_index=int(payload.get("step_index", 0) or 0),
+        timestamp=timestamp,
+        turn_number=payload.get("turn_number"),
+        events=events,
+        markov_blanket_messages=messages,
+        step_metadata=payload.get("step_metadata") or {},
+        completed_at=completed_at,
+    )
+
+
+def _session_trace_from_dict(payload: dict[str, Any]) -> Optional[V3SessionTrace]:
+    if not isinstance(payload, dict):
+        return None
+    steps = [
+        _step_from_dict(step)
+        for step in payload.get("session_time_steps", [])
+        if isinstance(step, dict)
+    ]
+    events = [
+        _event_from_dict(event)
+        for event in payload.get("event_history", [])
+        if isinstance(event, dict)
+    ]
+    markov_history = [
+        _markov_message_from_dict(msg)
+        for msg in payload.get("markov_blanket_message_history", [])
+        if isinstance(msg, dict)
+    ]
+    created_at = _parse_datetime_for_trace(payload.get("created_at")) or datetime.now(timezone.utc)
+    metadata = payload.get("metadata") or {}
+    session_metadata = payload.get("session_metadata")
+    return V3SessionTrace(
+        session_id=payload.get("session_id", ""),
+        created_at=created_at,
+        session_time_steps=steps,
+        event_history=events,
+        markov_blanket_message_history=markov_history,
+        metadata=metadata,
+        session_metadata=session_metadata,
+    )
+
+
+async def _store_trace(
+    tracer: SessionTracer | None,
+    trace_namespace: dict[str, Any] | None,
+    extra_metadata: dict[str, Any] | None = None,
+):
+    import logging
+    _logger = logging.getLogger(__name__)
+
+    _logger.info(f"[STORE_TRACE_DEBUG] Called with tracer={tracer is not None}, trace_namespace={trace_namespace is not None}")
+
+    if tracer is None or not isinstance(trace_namespace, dict):
+        _logger.warning(f"[STORE_TRACE_DEBUG] Early return: tracer={tracer is not None}, trace_namespace type={type(trace_namespace)}")
+        return
+
+    _logger.info(f"[STORE_TRACE_DEBUG] trace_namespace keys: {list(trace_namespace.keys())}")
+
+    session_payload = trace_namespace.get("session_trace")
+    if not isinstance(session_payload, dict):
+        _logger.warning(f"[STORE_TRACE_DEBUG] No session_trace found or wrong type: {type(session_payload)}")
+        return
+
+    _logger.info(f"[STORE_TRACE_DEBUG] session_payload keys: {list(session_payload.keys())}")
+    msg_count = len(session_payload.get("markov_blanket_message_history", []))
+    _logger.info(f"[STORE_TRACE_DEBUG] Found {msg_count} messages in session_payload")
+
+    trace_obj = _session_trace_from_dict(session_payload)
+    if trace_obj is None:
+        _logger.warning(f"[STORE_TRACE_DEBUG] _session_trace_from_dict returned None")
+        return
+
+    _logger.info(f"[STORE_TRACE_DEBUG] Created SessionTrace object with {len(trace_obj.markov_blanket_message_history)} messages")
+
+    if tracer.db is None:
+        await tracer.initialize()
+    meta = dict(trace_obj.metadata or {})
+    if extra_metadata:
+        meta.update(extra_metadata)
+    trace_obj.metadata = meta
+
+    _logger.info(f"[STORE_TRACE_DEBUG] Calling insert_session_trace for session_id={trace_obj.session_id}")
+    await tracer.db.insert_session_trace(trace_obj)
+    _logger.info(f"[STORE_TRACE_DEBUG] Successfully inserted trace")
+
 def _temporary_sys_path(paths: Sequence[Path]):
     """Context manager to prepend entries to sys.path temporarily."""

@@ -676,36 +916,44 @@ def _build_modal_config_from_ast(modal_call: ast.Call) -> ModalDeploymentConfigT
         elif kw.arg == "pip_packages" and isinstance(kw.value, (ast.List, ast.Tuple)):
             # Handle pip_packages list/tuple
             packages: list[str] = []
-            for elt in kw.value.elts:
-                if isinstance(elt, ast.Constant):
-                    packages.append(elt.value)
+            value_node = kw.value
+            if isinstance(value_node, (ast.List, ast.Tuple)):
+                for elt in value_node.elts:
+                    if isinstance(elt, ast.Constant):
+                        packages.append(elt.value)
             kwargs[kw.arg] = tuple(packages)
         elif kw.arg == "extra_local_dirs" and isinstance(kw.value, (ast.List, ast.Tuple)):
             # Handle extra_local_dirs list/tuple of tuples
             dirs = []
-            for elt in kw.value.elts:
-                if isinstance(elt, (ast.List, ast.Tuple)) and len(elt.elts) == 2:
-                    src = elt.elts[0].value if isinstance(elt.elts[0], ast.Constant) else None
-                    dst = elt.elts[1].value if isinstance(elt.elts[1], ast.Constant) else None
-                    if src and dst:
-                        dirs.append((src, dst))
+            value_node = kw.value
+            if isinstance(value_node, (ast.List, ast.Tuple)):
+                for elt in value_node.elts:
+                    if isinstance(elt, (ast.List, ast.Tuple)) and len(elt.elts) == 2:
+                        src = elt.elts[0].value if isinstance(elt.elts[0], ast.Constant) else None
+                        dst = elt.elts[1].value if isinstance(elt.elts[1], ast.Constant) else None
+                        if src and dst:
+                            dirs.append((src, dst))
             kwargs[kw.arg] = tuple(dirs)
         elif kw.arg == "secret_names" and isinstance(kw.value, (ast.List, ast.Tuple)):
             # Handle secret_names list/tuple
             secrets = []
-            for elt in kw.value.elts:
-                if isinstance(elt, ast.Constant):
-                    secrets.append(elt.value)
+            value_node = kw.value
+            if isinstance(value_node, (ast.List, ast.Tuple)):
+                for elt in value_node.elts:
+                    if isinstance(elt, ast.Constant):
+                        secrets.append(elt.value)
             kwargs[kw.arg] = tuple(secrets)
         elif kw.arg == "volume_mounts" and isinstance(kw.value, (ast.List, ast.Tuple)):
             # Handle volume_mounts list/tuple of tuples
             mounts = []
-            for elt in kw.value.elts:
-                if isinstance(elt, (ast.List, ast.Tuple)) and len(elt.elts) == 2:
-                    name = elt.elts[0].value if isinstance(elt.elts[0], ast.Constant) else None
-                    mount = elt.elts[1].value if isinstance(elt.elts[1], ast.Constant) else None
-                    if name and mount:
-                        mounts.append((name, mount))
+            value_node = kw.value
+            if isinstance(value_node, (ast.List, ast.Tuple)):
+                for elt in value_node.elts:
+                    if isinstance(elt, (ast.List, ast.Tuple)) and len(elt.elts) == 2:
+                        name = elt.elts[0].value if isinstance(elt.elts[0], ast.Constant) else None
+                        mount = elt.elts[1].value if isinstance(elt.elts[1], ast.Constant) else None
+                        if name and mount:
+                            mounts.append((name, mount))
             kwargs[kw.arg] = tuple(mounts)

     return ModalDeploymentConfig(**kwargs)
@@ -832,6 +1080,71 @@ def _import_task_app_module(
     return module


+@contextlib.contextmanager
+def _safe_import_context() -> Iterator[None]:
+    """Guard module imports against argparse/uvicorn side effects."""
+
+    original_argv = sys.argv[:]
+    sys.argv = [original_argv[0]] if original_argv else ["python"]
+
+    parser_cls = argparse.ArgumentParser
+    old_parse_args = parser_cls.parse_args
+
+    def _parse_noargs(self, args=None, namespace=None): # type: ignore[override]
+        if args is None:
+            args = []
+        if namespace is None:
+            namespace = argparse.Namespace()
+        try:
+            return old_parse_args(self, args, namespace)
+        except SystemExit:
+            return namespace
+
+    parser_cls.parse_args = _parse_noargs # type: ignore[assignment]
+
+    uvicorn_run = None
+    run_task_app_orig = None
+    try:
+        import uvicorn # type: ignore
+
+        uvicorn_run = uvicorn.run
+        uvicorn.run = lambda *args, **kwargs: None # type: ignore[assignment]
+    except Exception:
+        uvicorn_run = None
+
+    try:
+        _task_server_patch = cast(
+            Any, importlib.import_module("synth_ai.task.server")
+        )
+        run_task_app_orig = cast(Callable[..., Any], _task_server_patch.run_task_app)
+        _task_server_patch.run_task_app = ( # type: ignore[assignment]
+            lambda *args, **kwargs: None
+        )
+    except Exception:
+        run_task_app_orig = None
+
+    try:
+        yield
+    finally:
+        sys.argv = original_argv
+        parser_cls.parse_args = old_parse_args # type: ignore[assignment]
+        if uvicorn_run is not None:
+            try:
+                import uvicorn # type: ignore
+
+                uvicorn.run = uvicorn_run # type: ignore[assignment]
+            except Exception:
+                pass
+        if run_task_app_orig is not None:
+            try:
+                _task_server_patch = cast(
+                    Any, importlib.import_module("synth_ai.task.server")
+                )
+                _task_server_patch.run_task_app = run_task_app_orig # type: ignore[assignment]
+            except Exception:
+                pass
+
+
 def _load_entry_from_path(
     path: Path, app_id: str, module_search_roots: Sequence[Path] | None = None
 ) -> TaskAppEntryType:
@@ -859,13 +1172,14 @@ def _load_entry_from_path(

     for module_name, namespace_root in _possible_module_names(resolved, search_roots):
         try:
-            module = _import_task_app_module(
-                resolved,
-                module_name,
-                namespace_root=namespace_root,
-                sys_path_roots=search_roots,
-                ensure_namespace=True,
-            )
+            with _safe_import_context():
+                module = _import_task_app_module(
+                    resolved,
+                    module_name,
+                    namespace_root=namespace_root,
+                    sys_path_roots=search_roots,
+                    ensure_namespace=True,
+                )
             break
         except Exception as exc: # pragma: no cover - best-effort fallbacks
             last_error = exc
@@ -874,13 +1188,14 @@ def _load_entry_from_path(
     if module is None:
         hashed_name = f"_synth_task_app_{hashlib.md5(str(resolved).encode(), usedforsecurity=False).hexdigest()}"
         try:
-            module = _import_task_app_module(
-                resolved,
-                hashed_name,
-                namespace_root=None,
-                sys_path_roots=search_roots,
-                ensure_namespace=False,
-            )
+            with _safe_import_context():
+                module = _import_task_app_module(
+                    resolved,
+                    hashed_name,
+                    namespace_root=None,
+                    sys_path_roots=search_roots,
+                    ensure_namespace=False,
+                )
         except Exception as exc: # pragma: no cover - propagate meaningful error
             detail = last_error or exc
             raise click.ClickException(f"Failed to import {resolved}: {detail}") from detail
@@ -928,7 +1243,10 @@ def _load_entry_from_path(
         if has_required:
             continue
         try:
-            result = attr()
+            with _safe_import_context():
+                result = attr()
+        except SystemExit:
+            continue
         except Exception:
             continue
         if isinstance(result, TaskAppConfig) and result.app_id == app_id:
@@ -1024,21 +1342,173 @@ def _resolve_env_paths_for_script(script_path: Path, explicit: Sequence[str]) ->
     return [env_candidates[choice - 1]]


+def _path_is_within(child: Path, parent: Path) -> bool:
+    try:
+        child.resolve().relative_to(parent.resolve())
+        return True
+    except Exception:
+        return False
+
+
+@functools.lru_cache(maxsize=16)
+def _is_modal_shim(path_str: str) -> bool:
+    """Return True if the candidate CLI path refers to the synth-ai shim."""
+
+    path = Path(path_str)
+    try:
+        resolved = path.resolve(strict=True)
+    except Exception:
+        resolved = path
+
+    if not resolved.exists() or resolved.is_dir():
+        return False
+
+    snippet = ""
+    try:
+        snippet = resolved.read_bytes()[:4096].decode("utf-8", errors="ignore")
+    except Exception:
+        snippet = ""
+
+    shim_markers = (
+        "synth_ai.cli._modal_wrapper",
+        "from modal.__main__ import main",
+        "import modal.__main__",
+        "run_module('modal.__main__'",
+    )
+    if snippet and any(marker in snippet for marker in shim_markers):
+        return True
+
+    try:
+        size = resolved.stat().st_size
+    except Exception:
+        size = None
+
+    if (
+        size is not None
+        and size < 2048
+        and "python" in (snippet.splitlines() or [""])[0]
+        and (
+            "modal.__main__" in snippet
+            or "modal.__main__" in snippet.replace(" ", "")
+        )
+    ):
+        return True
+
+    virtual_env = os.environ.get("VIRTUAL_ENV")
+    if virtual_env and _path_is_within(resolved, Path(virtual_env)):
+        return True
+
+    if _path_is_within(resolved, REPO_ROOT):
+        return True
+
+    uv_tools_dir = Path.home() / ".local" / "share" / "uv" / "tools"
+    return uv_tools_dir.exists() and _path_is_within(resolved, uv_tools_dir)
+
+
+def _find_modal_executable(modal_cli: str) -> tuple[str | None, str | None]:
+    """Return the first non-shim executable and the first shim discovered on PATH."""
+
+    if not modal_cli:
+        modal_cli = "modal"
+
+    candidate_path = Path(modal_cli).expanduser()
+    if candidate_path.is_absolute() or len(candidate_path.parts) > 1:
+        resolved_candidate = candidate_path
+        if not resolved_candidate.is_absolute():
+            resolved_candidate = (Path.cwd() / resolved_candidate).resolve()
+        else:
+            resolved_candidate = resolved_candidate.resolve()
+        if not resolved_candidate.exists():
+            raise click.ClickException(f"--modal-cli path does not exist: {resolved_candidate}")
+        if not os.access(resolved_candidate, os.X_OK):
+            raise click.ClickException(f"--modal-cli is not executable: {resolved_candidate}")
+        return str(resolved_candidate), None
+
+    path_env = os.environ.get("PATH", "")
+    if not path_env:
+        return None, None
+
+    seen_dirs: set[str] = set()
+    seen_candidates: set[str] = set()
+    shim_path: str | None = None
+
+    for raw_entry in path_env.split(os.pathsep):
+        if not raw_entry:
+            continue
+        try:
+            resolved_entry = str(Path(raw_entry).resolve())
+        except Exception:
+            resolved_entry = os.path.normpath(raw_entry)
+        if resolved_entry in seen_dirs:
+            continue
+        seen_dirs.add(resolved_entry)
+
+        candidate = shutil.which(modal_cli, path=raw_entry)
+        if candidate is None:
+            continue
+        if candidate in seen_candidates:
+            continue
+        seen_candidates.add(candidate)
+
+        if _is_modal_shim(candidate):
+            if shim_path is None:
+                shim_path = candidate
+            continue
+        return candidate, shim_path
+
+    return None, shim_path
+
+
 def _modal_command_prefix(modal_cli: str) -> list[str]:
     """Resolve a command prefix for invoking the Modal CLI within the active environment."""
-    if modal_cli == "modal" and importlib.util.find_spec("modal") is not None:
+
+    force_wrapper_env = os.environ.get("SYNTH_FORCE_MODAL_WRAPPER", "").strip().lower()
+    if force_wrapper_env in {"1", "true", "yes"}:
+        click.secho(
+            "[modal-prefix] SYNTH_FORCE_MODAL_WRAPPER=1 -> using in-process wrapper",
+            fg="yellow",
+        )
         return [sys.executable, "-m", "synth_ai.cli._modal_wrapper"]

-    modal_path = shutil.which(modal_cli)
-    if modal_path is not None:
-        return [modal_path]
+    lookup = modal_cli or "modal"
+    spec = importlib.util.find_spec("modal") if lookup == "modal" else None
+
+    preferred, shim_candidate = _find_modal_executable(lookup)
+    if preferred is not None:
+        detail = f"[modal-prefix] modal_cli={lookup} selected={preferred}"
+        if lookup == "modal":
+            detail += f" spec={'yes' if spec else 'no'}"
+        click.secho(detail, fg="cyan")
+        return [preferred]
+
+    if lookup != "modal":
+        raise click.ClickException(f"Modal CLI not found (looked for '{lookup}')")
+
+    if spec is not None:
+        warning = "[modal-prefix] Using synth-ai modal shim; pass --modal-cli /path/to/modal to override."
+        if shim_candidate is not None:
+            warning = (
+                f"[modal-prefix] Using synth-ai modal shim at {shim_candidate}; "
+                "pass --modal-cli /path/to/modal to override."
+            )
+        click.secho(warning, fg="yellow")
+        click.secho(
+            "[modal-prefix] modal_cli=modal selected=module-wrapper spec=yes",
+            fg="yellow",
+        )
+        return [sys.executable, "-m", "synth_ai.cli._modal_wrapper"]

-    if modal_cli == "modal":
+    if shim_candidate is not None:
         raise click.ClickException(
-            "Modal CLI not found. Install the 'modal' package in this environment or pass "
-            "--modal-cli with an explicit path."
+            "Modal CLI resolution found the synth-ai shim but the 'modal' package "
+            "is not importable in this environment. Install the official Modal CLI "
+            "or pass --modal-cli with its path."
         )
-    raise click.ClickException(f"Modal CLI not found (looked for '{modal_cli}')")
+
+    raise click.ClickException(
+        "Modal CLI not found. Install the 'modal' package in this environment or pass "
+        "--modal-cli with an explicit path."
+    )


 def _build_modal_app_wrapper(original_script: Path) -> tuple[Path, Path]:
@@ -1173,8 +1643,15 @@ def _run_modal_script(
     if modal_name and command == "deploy":
         cmd.extend(["--name", modal_name])
     if dry_run:
-        click.echo("Dry run: " + " ".join(cmd))
+        click.echo(
+            "Dry run: " + " ".join(shlex.quote(component) for component in cmd),
+            err=False,
+        )
         return
+    click.secho(
+        "[modal-exec] " + " ".join(shlex.quote(component) for component in cmd),
+        fg="cyan",
+    )
     try:
         # Stream output live for better diagnostics
         proc = subprocess.Popen(
@@ -1429,7 +1906,6 @@ def _run_modal_with_entry(
         inline_secret_values=inline_secret_values,
     )
     cmd = [*_modal_command_prefix(modal_cli), command, str(script_path)]
-
     if modal_name and command == "deploy":
         cmd.extend(["--name", modal_name])

@@ -1444,9 +1920,13 @@ def _run_modal_with_entry(
     proc_env["PYTHONPATH"] = os.pathsep.join(list(dict.fromkeys(pythonpath_entries)))

     if dry_run:
-        click.echo("Dry run: " + " ".join(cmd))
+        click.echo("Dry run: " + " ".join(shlex.quote(component) for component in cmd))
         script_path.unlink(missing_ok=True)
         return
+    click.secho(
+        "[modal-exec] " + " ".join(shlex.quote(component) for component in cmd),
+        fg="cyan",
+    )

     try:
         # Stream output live for better diagnostics
@@ -1531,6 +2011,10 @@ def _parse_env_file(path: Path) -> dict[str, str]:


 def _interactive_fill_env(env_path: Path) -> Path | None:
+    if not sys.stdin.isatty():
+        raise click.ClickException(
+            "ENVIRONMENT_API_KEY missing. Provide --env-file or run `synth-ai setup` in an interactive shell to create one."
+        )
     existing = _parse_env_file(env_path) if env_path.exists() else {}

     def _prompt(label: str, *, default: str = "", required: bool) -> str | None:
@@ -1570,6 +2054,10 @@ def _ensure_env_values(env_paths: list[Path], fallback_dir: Path) -> None:
     if (os.environ.get("ENVIRONMENT_API_KEY") or "").strip():
         return
     target = env_paths[0] if env_paths else (fallback_dir / ".env").resolve()
+    click.echo(
+        "⚠️ ENVIRONMENT_API_KEY not set. Run `uvx synth-ai setup`, "
+        "or pass --env-file pointing at a .env with ENVIRONMENT_API_KEY."
+    )
     result = _interactive_fill_env(target)
     if result is None:
         raise click.ClickException("ENVIRONMENT_API_KEY required to continue")
@@ -1593,7 +2081,7 @@ def _deploy_entry(
             f"Task app '{entry.app_id}' does not define Modal deployment settings"
         )

-    env_paths = _determine_env_files(entry, env_file)
+    env_paths = _determine_env_files(entry, env_file, original_path=original_path)
     click.echo("Using env file(s): " + ", ".join(str(p.resolve()) for p in env_paths))
     _run_modal_with_entry(
         entry,
@@ -1620,7 +2108,7 @@ def _modal_serve_entry(
             f"Task app '{entry.app_id}' does not define Modal deployment settings"
         )

-    env_paths = _determine_env_files(entry, env_file)
+    env_paths = _determine_env_files(entry, env_file, original_path=original_path)
     click.echo("Using env file(s): " + ", ".join(str(p.resolve()) for p in env_paths))
     _run_modal_with_entry(
         entry,
@@ -1651,6 +2139,255 @@ def list_apps() -> None:
         click.echo(f"- {entry.app_id}{aliases}: {entry.description}")


+@task_app_group.command("validate")
+@click.argument("app_id", type=str, required=True)
+@click.option(
+    "--url",
+    type=str,
+    default=None,
+    help="Task app URL to validate (if not provided, starts a local server)",
+)
+@click.option(
+    "--port",
+    type=int,
+    default=8765,
+    help="Port to use for temporary server (default: 8765)",
+)
+@click.option(
+    "--api-key",
+    type=str,
+    default=None,
+    envvar="ENVIRONMENT_API_KEY",
+    help="API key for authentication (default: $ENVIRONMENT_API_KEY)",
+)
+@click.option(
+    "--min-instances",
+    type=int,
+    default=10,
+    help="Minimum number of task instances required (default: 10)",
+)
+@click.option(
+    "--verbose",
+    "-v",
+    is_flag=True,
+    help="Show detailed information about the task app",
+)
+@click.option(
+    "--json",
+    "output_json",
+    is_flag=True,
+    help="Output results as JSON",
+)
+def validate_task_app_cmd(
+    app_id: str,
+    url: str | None,
+    port: int,
+    api_key: str | None,
+    min_instances: int,
+    verbose: bool,
+    output_json: bool,
+) -> None:
+    """Validate a task app deployment readiness.
+
+    This command verifies that a task app is properly configured and ready to run
+    by checking all required HTTP endpoints, authentication, and task availability.
+
+    By default, it starts a temporary local server for validation. You can also
+    validate a remote deployment by passing --url.
+
+    \b
+    What gets validated:
+      • Root endpoint (/) responds correctly
+      • Health endpoint (/health) is accessible with proper authentication
+      • Info endpoint (/info) returns valid task metadata
+      • Task info endpoint (/task_info) provides task instances
+      • Rollout endpoint (/rollout) is registered
+      • At least N task instances are available (default: 10)
+
+    \b
+    Examples:
+
+    \b
+    Validate grpo-crafter (starts local server automatically):
+        $ synth-ai task-app validate grpo-crafter
+
+    \b
+    Validate sokoban with verbose output:
+        $ synth-ai task-app validate sokoban --verbose
+
+    \b
+    Validate with custom port:
+        $ synth-ai task-app validate sokoban --port 9000
+
+    \b
+    Validate a remote deployment:
+        $ synth-ai task-app validate grpo-crafter --url https://my-crafter.modal.run
+
+    \b
+    Require at least 20 task instances:
+        $ synth-ai task-app validate grpo-crafter --min-instances 20
+
+    \b
+    Get JSON output for automation:
+        $ synth-ai task-app validate sokoban --json
+
+    \b
+    Common use cases:
+      • Pre-deployment verification: Check task app works before deploying to Modal
+      • CI/CD integration: Use --json flag for automated validation in pipelines
+      • Debug failing deployments: Use --verbose to see detailed endpoint responses
+      • Test API key configuration: Verify authentication is set up correctly
+    """
+    import asyncio
+    import socket
+    import subprocess
+    import tempfile
+    import time
+
+    # Import the validate_task_app function defined in this module
+    from synth_ai.cli._validate_task_app import validate_task_app # type: ignore[attr-defined]
+
+    proc = None
+    task_app_url = url
+
+    try:
+        # If no URL provided, start a temporary server
+        if not task_app_url:
+            # Find an available port
+            def is_port_available(port: int) -> bool:
+                with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+                    try:
+                        s.bind(("", port))
+                        return True
+                    except OSError:
+                        return False
+
+            while not is_port_available(port):
+                port += 1
+
+            task_app_url = f"http://localhost:{port}"
+
+            if not output_json:
+                click.echo(f"Starting temporary {app_id} server on port {port}...")
+
+            # Start the server in background
+            env = os.environ.copy()
+            if api_key:
+                env["ENVIRONMENT_API_KEY"] = api_key
+
+            # Create a temporary trace DB and trace dir to avoid prompts
+            import tempfile
+            temp_dir = tempfile.mkdtemp()
+            temp_trace_db = os.path.join(temp_dir, "validate_trace.db")
+            temp_trace_dir = os.path.join(temp_dir, "traces")
+            os.makedirs(temp_trace_dir, exist_ok=True)
+
+            proc = subprocess.Popen(
+                [
+                    "uv",
+                    "run",
+                    "synth-ai",
+                    "task-app",
+                    "serve",
+                    app_id,
+                    "--port",
+                    str(port),
+                    "--no-reload",
+                    "--trace",
+                    temp_trace_dir,
+                    "--trace-db",
+                    temp_trace_db,
+                ],
+                env=env,
+                stdin=subprocess.PIPE, # Add stdin to handle any prompts
+                stdout=subprocess.DEVNULL if output_json else subprocess.PIPE,
+                stderr=subprocess.DEVNULL if output_json else subprocess.PIPE,
+                text=True,
+            )
+
+            # Write empty input to stdin to skip any prompts
+            if proc.stdin:
+                try:
+                    proc.stdin.write("\n")
+                    proc.stdin.flush()
+                    proc.stdin.close()
+                except Exception:
+                    pass
+
+            # Wait for server to be ready
+            if not output_json:
+                click.echo("Waiting for server to start...")
+
+            import httpx
+            for _attempt in range(60): # 30 seconds timeout
+                try:
+                    async def check_health():
+                        async with httpx.AsyncClient(timeout=2.0) as client:
+                            resp = await client.get(f"{task_app_url}/")
+                            return resp.status_code == 200
+
+                    if asyncio.run(check_health()):
+                        break
+                except Exception:
+                    pass
+
+                # Check if process died
+                if proc.poll() is not None:
+                    stderr_output = ""
+                    if proc.stderr and not output_json:
+                        stderr_output = proc.stderr.read()
+                    click.echo(click.style("✗ Server process exited unexpectedly", fg="red"), err=True)
+                    if stderr_output and not output_json:
+                        click.echo(f"Error output:\n{stderr_output}", err=True)
+                    sys.exit(1)
+
+                time.sleep(0.5)
+            else:
+                click.echo(click.style("✗ Server failed to start within 30 seconds", fg="red"), err=True)
+                sys.exit(1)
+
+            if not output_json:
+                click.echo(click.style("✓ Server started", fg="green"))
+                click.echo()
+
+        # Ensure URL doesn't have trailing slash
+        task_app_url = task_app_url.rstrip("/")
+
+        async def _run() -> tuple[bool, dict[str, Any]]:
+            return await validate_task_app(
+                url=task_app_url,
+                api_key=api_key,
+                min_instances=min_instances,
+                verbose=verbose,
+            )
+
+        success, results = asyncio.run(_run())
+
+        if output_json:
+            import json as _json
+            click.echo(_json.dumps(results, indent=2))
+
+        sys.exit(0 if success else 1)
+
+    finally:
+        # Cleanup: stop the temporary server
+        if proc is not None:
+            if not output_json:
+                click.echo("\nStopping temporary server...")
+            try:
+                proc.terminate()
+                proc.wait(timeout=5)
+            except Exception:
+                proc.kill()
+
+        # Cleanup temp trace DB
+        if not url and 'temp_dir' in locals():
+            import contextlib
+            import shutil
+            with contextlib.suppress(Exception):
+                shutil.rmtree(temp_dir, ignore_errors=True)
+
+
 def _load_env_files_into_process(paths: Sequence[str]) -> None:
     for p in paths:
         try:
@@ -1907,7 +2644,9 @@ def serve_task_group(
1907
2644
  )
1908
2645
 
1909
2646
 
1910
- def _determine_env_files(entry: TaskAppEntryType, user_env_files: Sequence[str]) -> list[Path]:
2647
+ def _determine_env_files(
2648
+ entry: TaskAppEntryType, user_env_files: Sequence[str], *, original_path: Path | None = None
2649
+ ) -> list[Path]:
1911
2650
  resolved: list[Path] = []
1912
2651
  for candidate in user_env_files:
1913
2652
  p = Path(candidate).expanduser()
@@ -1917,30 +2656,46 @@ def _determine_env_files(entry: TaskAppEntryType, user_env_files: Sequence[str])
1917
2656
  if resolved:
1918
2657
  return resolved
1919
2658
 
1920
- # Always prompt for env file selection instead of auto-loading defaults
1921
- # Look for env files in current working directory first, then repo root
1922
- cwd = Path.cwd()
1923
- env_candidates = []
1924
-
1925
- # Add CWD env files first (prioritized)
1926
- cwd_env_files = sorted(cwd.glob("**/*.env"))
1927
- env_candidates.extend(cwd_env_files)
2659
+ declared: list[Path] = []
2660
+ for candidate in getattr(entry, "env_files", ()) or ():
2661
+ try:
2662
+ p = Path(candidate).expanduser()
2663
+ except Exception:
2664
+ continue
2665
+ if p.exists() and p.is_file():
2666
+ declared.append(p)
2667
+ if declared:
2668
+ return declared
1928
2669
 
1929
- # Add repo root env files
1930
- repo_env_files = sorted(REPO_ROOT.glob("**/*.env"))
1931
- # Avoid duplicates
1932
- for repo_file in repo_env_files:
1933
- if repo_file not in env_candidates:
1934
- env_candidates.append(repo_file)
2670
+ def _append_candidate(collection: list[Path], candidate: Path) -> None:
2671
+ if candidate.exists() and candidate.is_file() and candidate not in collection:
2672
+ collection.append(candidate)
1935
2673
 
1936
- if not env_candidates:
1937
- raise click.ClickException("No env file found. Pass --env-file explicitly.")
2674
+ auto_candidates: list[Path] = []
1938
2675
 
1939
- click.echo("Select env file to load:")
1940
- for idx, path in enumerate(env_candidates, start=1):
1941
- click.echo(f" {idx}) {path.resolve()}")
1942
- choice = click.prompt("Enter choice", type=click.IntRange(1, len(env_candidates)), default=1)
1943
- return [env_candidates[choice - 1]]
2676
+ search_dirs: list[Path] = []
2677
+ if original_path is not None:
2678
+ search_dirs.append(original_path.parent.resolve())
2679
+ for parent in original_path.parent.resolve().parents:
2680
+ search_dirs.append(parent)
2681
+ cwd = Path.cwd().resolve()
2682
+ if cwd not in search_dirs:
2683
+ search_dirs.append(cwd)
2684
+ repo_root = REPO_ROOT.resolve()
2685
+ if repo_root not in search_dirs:
2686
+ search_dirs.append(repo_root)
2687
+
2688
+ for directory in search_dirs:
2689
+ _append_candidate(auto_candidates, directory / ".env")
2690
+ for candidate in sorted(directory.glob("*.env")):
2691
+ _append_candidate(auto_candidates, candidate)
2692
+
2693
+ if auto_candidates:
2694
+ return [auto_candidates[0]]
2695
+
2696
+ raise click.ClickException(
2697
+ "No .env file discovered automatically. Pass --env-file /path/to/.env or generate one with `uvx synth-ai setup`."
2698
+ )
1944
2699
 
1945
2700
 
1946
2701
  def _ensure_port_free(port: int, host: str, *, force: bool) -> None:
@@ -2242,7 +2997,14 @@ def deploy_app(
2242
2997
  def modal_serve_app(
2243
2998
  app_id: str | None, modal_cli: str, modal_name: str | None, env_file: Sequence[str]
2244
2999
  ) -> None:
2245
- choice = _select_app_choice(app_id, purpose="modal-serve")
3000
+ click.echo(f"[modal-serve] requested app_id={app_id or '(auto)'} modal_cli={modal_cli}")
3001
+ try:
3002
+ choice = _select_app_choice(app_id, purpose="modal-serve")
3003
+ except SystemExit as exc: # bubble up with context (legacy argparse would trigger this)
3004
+ raise click.ClickException(
3005
+ f"Legacy CLI intercepted modal-serve (exit {exc.code}). "
3006
+ "Make sure you're running the Click CLI (synth_ai.cli:cli)."
3007
+ ) from exc
2246
3008
 
2247
3009
  if choice.modal_script:
2248
3010
  env_paths = _resolve_env_paths_for_script(choice.modal_script, env_file)
@@ -2251,6 +3013,7 @@ def modal_serve_app(
2251
3013
  return
2252
3014
 
2253
3015
  entry = choice.ensure_entry()
3016
+ click.echo(f"[modal-serve] serving entry {entry.app_id} from {choice.path}")
2254
3017
  _modal_serve_entry(entry, modal_name, modal_cli, env_file, original_path=choice.path)
2255
3018
 
2256
3019
 
@@ -2313,6 +3076,11 @@ def _write_modal_entrypoint(
2313
3076
  if not any(str(p).startswith("synth-ai") for p in pip_packages):
2314
3077
  pip_packages.insert(0, synth_pkg)
2315
3078
 
3079
+ apt_packages = list(modal_cfg.apt_packages)
3080
+ click.echo(f"[DEBUG] modal_cfg.apt_packages type: {type(modal_cfg.apt_packages)}")
3081
+ click.echo(f"[DEBUG] modal_cfg.apt_packages value: {modal_cfg.apt_packages}")
3082
+ click.echo(f"[DEBUG] apt_packages after list(): {apt_packages}")
3083
+
2316
3084
  local_dirs = [(str(Path(src)), dst) for src, dst in modal_cfg.extra_local_dirs]
2317
3085
  # Also mount the host synth_ai source if available to ensure latest code is used
2318
3086
  if host_synth is not None:
@@ -2359,6 +3127,15 @@ INLINE_SECRET_VALUES = {inline_secret_values!r}
2359
3127
 
2360
3128
  image = Image.debian_slim(python_version={modal_cfg.python_version!r})
2361
3129
 
3130
+ # CRITICAL: Install iverilog for Verilog task app (hardcoded to prevent config issues)
3131
+ if {entry.app_id!r} == "grpo-verilog":
3132
+ image = image.apt_install("iverilog")
3133
+
3134
+ # Install apt packages first (before pip)
3135
+ apt_packages = {apt_packages!r}
3136
+ if apt_packages:
3137
+ image = image.apt_install(*apt_packages)
3138
+
2362
3139
  pip_packages = {pip_packages!r}
2363
3140
  if pip_packages:
2364
3141
  image = image.pip_install(*pip_packages)
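For context, the generated entrypoint chains Modal's image builders in exactly this order (apt packages before pip packages). A standalone sketch, with package names purely illustrative:

```python
from modal import Image

# System packages go on first; Python packages are layered on top.
image = (
    Image.debian_slim(python_version="3.11")
    .apt_install("iverilog")             # illustrative apt package
    .pip_install("synth-ai", "fastapi")  # illustrative pip packages
)
```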
@@ -2480,22 +3257,60 @@ def register(cli: click.Group) -> None:
2480
3257
  cli.add_command(serve_command)
2481
3258
  cli.add_command(task_app_group)
2482
3259
  cli.add_command(eval_command)
3260
+ cli.add_command(filter_command)
2483
3261
 
2484
3262
 
2485
- @click.command("eval")
3263
+ @click.command(
3264
+ "eval",
3265
+ help="Run one-off rollouts against a task app and print judge/eval summaries.",
3266
+ )
2486
3267
  @click.argument("app_id", type=str, required=False)
2487
- @click.option("--config", type=click.Path(), default=None, help="Path to eval TOML (short schema)")
3268
+ @click.option(
3269
+ "--config",
3270
+ type=click.Path(),
3271
+ default=None,
3272
+ help="Path to eval TOML (short schema). Auto-discovers the first matching file when omitted.",
3273
+ )
2488
3274
  @click.option(
2489
3275
  "--url",
2490
3276
  "task_app_url",
2491
3277
  type=str,
2492
3278
  default=None,
2493
- help="Base URL of a running task app (skip in-process server)",
3279
+ help="Base URL of a running task app instead of spawning locally (requires --env-file for secrets).",
3280
+ )
3281
+ @click.option(
3282
+ "--seeds",
3283
+ default="0,1,2,3,4",
3284
+ help="Comma-separated seeds/indices to evaluate. Use negative numbers to wrap around the dataset.",
2494
3285
  )
2495
- @click.option("--seeds", default="0,1,2,3,4", help="Comma-separated seeds/indices to evaluate")
2496
3286
  @click.option("--split", default="train", show_default=True, help="Dataset split to use")
2497
- @click.option("--model", default=None, help="Model identifier (prompted if omitted)")
2498
- @click.option("--env-file", multiple=True, type=click.Path(), help="Env file(s) for keys")
3287
+ @click.option(
3288
+ "--model",
3289
+ default=None,
3290
+ help="Model identifier. When omitted the CLI will prompt based on task metadata.",
3291
+ )
3292
+ @click.option(
3293
+ "--env-file",
3294
+ multiple=True,
3295
+ type=click.Path(),
3296
+ help="Env file(s) to load (API keys, etc.). Required when using --url or remote judges.",
3297
+ )
3298
+ @click.option(
3299
+ "--trace-db",
3300
+ default="traces/v3/synth_ai.db",
3301
+ show_default=True,
3302
+ help="SQLite/Turso URL for storing rollout traces set to 'none' to disable persistence.",
3303
+ )
3304
+ @click.option(
3305
+ "--metadata",
3306
+ multiple=True,
3307
+ help="Filter tasks by key=value metadata (e.g., --metadata difficulty=easy)",
3308
+ )
3309
+ @click.option(
3310
+ "--metadata-sql",
3311
+ default=None,
3312
+ help="SQLite query that returns seeds to evaluate (e.g., SELECT seed FROM tasks WHERE difficulty='easy' LIMIT 5)",
3313
+ )
2499
3314
  def eval_command(
2500
3315
  app_id: str | None,
2501
3316
  config: str | None,
@@ -2504,10 +3319,24 @@ def eval_command(
2504
3319
  split: str,
2505
3320
  model: str | None,
2506
3321
  env_file: Sequence[str],
3322
+ trace_db: str,
3323
+ metadata: Sequence[str],
3324
+ metadata_sql: str | None,
2507
3325
  ) -> None:
2508
- """Run local rollouts against a task app using in-process ASGI and summarize results."""
3326
+ """Run rollouts against a task app and report judge statistics.
3327
+
3328
+ By default the command spins up the selected task app in-process, executes the
3329
+ requested seeds, and prints aggregate scores (official and custom judges). When
3330
+ pointing at a remote `--url`, supply matching `--env-file` values so the CLI can
3331
+ forward authentication headers to the running service.
3332
+ """
3333
+ # Parse and validate TOML config
3334
+ from synth_ai.task.config import EvalConfig
3335
+
2509
3336
  cfg: dict[str, Any] = {}
3337
+ eval_cfg: EvalConfig | None = None
2510
3338
  config_path: Path | None = None
3339
+
2511
3340
  if config:
2512
3341
  config_path = Path(config)
2513
3342
  else:
@@ -2529,10 +3358,73 @@ def eval_command(
2529
3358
  if isinstance(parsed, dict):
2530
3359
  section = parsed.get("eval")
2531
3360
  cfg = dict(section) if isinstance(section, dict) else dict(parsed)
3361
+
3362
+ # Validate config with dataclass
3363
+ try:
3364
+ eval_cfg = EvalConfig.from_dict(cfg)
3365
+ click.echo(f"✓ Config validated: {len(eval_cfg.seeds)} seeds, model={eval_cfg.model}")
3366
+ except (ValueError, TypeError) as validation_error:
3367
+ raise click.ClickException(f"Invalid eval config: {validation_error}") from validation_error
3368
+ except click.ClickException:
3369
+ raise
2532
3370
  except Exception as exc:
2533
3371
  raise click.ClickException(f"Failed to parse TOML '{config_path}': {exc}") from exc
2534
3372
 
2535
- app_id = app_id or (cfg.get("app_id") if isinstance(cfg.get("app_id"), str) else None) # type: ignore
3373
+ # CLI args override config
3374
+ if eval_cfg:
3375
+ app_id = app_id or eval_cfg.app_id
3376
+ else:
3377
+ app_id = app_id or (cfg.get("app_id") if isinstance(cfg.get("app_id"), str) else None) # type: ignore
3378
+
3379
+ metadata_filters: dict[str, str] = {}
3380
+ if eval_cfg:
3381
+ metadata_filters.update(eval_cfg.metadata)
3382
+ else:
3383
+ cfg_metadata = cfg.get("metadata")
3384
+ if isinstance(cfg_metadata, dict):
3385
+ for key, value in cfg_metadata.items():
3386
+ metadata_filters[str(key)] = str(value)
3387
+ elif isinstance(cfg_metadata, list):
3388
+ for item in cfg_metadata:
3389
+ if isinstance(item, str) and "=" in item:
3390
+ key, value = item.split("=", 1)
3391
+ metadata_filters[key.strip()] = value.strip()
3392
+
3393
+ for item in metadata or ():
3394
+ if "=" not in item:
3395
+ raise click.ClickException(f"Metadata filters must be key=value (got: {item})")
3396
+ key, value = item.split("=", 1)
3397
+ key = key.strip()
3398
+ value = value.strip()
3399
+ if not key or not value:
3400
+ raise click.ClickException(f"Invalid metadata filter: {item}")
3401
+ metadata_filters[key] = value
3402
+
3403
+ metadata_sql_query: str | None = None
3404
+ if eval_cfg and eval_cfg.metadata_sql:
3405
+ metadata_sql_query = eval_cfg.metadata_sql
3406
+ else:
3407
+ cfg_metadata_sql = cfg.get("metadata_sql")
3408
+ if isinstance(cfg_metadata_sql, dict):
3409
+ metadata_sql_query = cfg_metadata_sql.get("query") or cfg_metadata_sql.get("sql")
3410
+ elif isinstance(cfg_metadata_sql, str):
3411
+ metadata_sql_query = cfg_metadata_sql
3412
+
3413
+ if metadata_sql:
3414
+ metadata_sql_query = metadata_sql
3415
+ if metadata_sql_query is not None:
3416
+ metadata_sql_query = str(metadata_sql_query)
3417
+
3418
+ trace_db_url: str | None = None
3419
+ trace_db = (trace_db or "").strip()
3420
+ if trace_db and trace_db.lower() not in {"none", "off", "disable"}:
3421
+ if "://" in trace_db:
3422
+ trace_db_url = trace_db
3423
+ else:
3424
+ trace_path = Path(trace_db).expanduser()
3425
+ trace_path.parent.mkdir(parents=True, exist_ok=True)
3426
+ trace_db_url = f"sqlite+aiosqlite:///{trace_path}"
3427
+ trace_tracer: SessionTracer | None = SessionTracer(db_url=trace_db_url, auto_save=True) if trace_db_url else None
2536
3428
 
2537
3429
  # Determine selection params (CLI takes precedence; TOML only fills unset model/seeds/env)
2538
3430
  if cfg.get("model") and not model:
@@ -2553,14 +3445,16 @@ def eval_command(
2553
3445
  elif isinstance(ef, list):
2554
3446
  env_file = tuple(str(x) for x in ef) # type: ignore[assignment]
2555
3447
 
3448
+ choice_for_env: AppChoice | None = None
2556
3449
  entry: TaskAppEntryType | None = None
2557
3450
  if task_app_url is None:
2558
- choice = _select_app_choice(app_id, purpose="eval")
2559
- entry = choice.ensure_entry()
3451
+ choice_for_env = _select_app_choice(app_id, purpose="eval")
3452
+ entry = choice_for_env.ensure_entry()
2560
3453
 
2561
3454
  env_paths: list[Path] = []
2562
3455
  if entry is not None:
2563
- env_paths = _determine_env_files(entry, env_file)
3456
+ original_env_path = choice_for_env.path if choice_for_env is not None else None
3457
+ env_paths = _determine_env_files(entry, env_file, original_path=original_env_path)
2564
3458
  else:
2565
3459
  if not env_file:
2566
3460
  raise click.ClickException("--env-file is required when using --url")
@@ -2583,12 +3477,30 @@ def eval_command(
2583
3477
  app = create_task_app(config)
2584
3478
 
2585
3479
  # Determine supported models
3480
+ inference_meta: dict[str, Any] = {}
2586
3481
  supported: list[str] = []
3482
+ seen_models: set[str] = set()
3483
+
3484
+ def _add_supported_model(candidate: Any) -> None:
3485
+ if not candidate:
3486
+ return
3487
+ text = str(candidate).strip()
3488
+ if not text or text in seen_models:
3489
+ return
3490
+ supported.append(text)
3491
+ seen_models.add(text)
3492
+
2587
3493
  if task_app_url is None:
2588
3494
  try:
2589
- supported = list((config.base_task_info.inference or {}).get("models") or []) # type: ignore[union-attr]
3495
+ if hasattr(config, "base_task_info") and config.base_task_info:
3496
+ inf_obj = getattr(config.base_task_info, "inference", None)
3497
+ if inf_obj is not None:
3498
+ if hasattr(inf_obj, "model_dump"):
3499
+ inference_meta = dict(inf_obj.model_dump(exclude_none=True)) # type: ignore[attr-defined]
3500
+ elif isinstance(inf_obj, dict):
3501
+ inference_meta = dict(inf_obj)
2590
3502
  except Exception:
2591
- supported = []
3503
+ inference_meta = {}
2592
3504
  else:
2593
3505
  try:
2594
3506
  import httpx as _hx
@@ -2601,38 +3513,38 @@ def eval_command(
2601
3513
  info = c.get("/info").json()
2602
3514
  inf = info.get("inference") if isinstance(info, dict) else None
2603
3515
  if isinstance(inf, dict):
2604
- m = inf.get("models")
2605
- if isinstance(m, list):
2606
- supported = [str(x) for x in m]
2607
- if not supported:
2608
- providers = inf.get("providers")
2609
- if isinstance(providers, list):
2610
- if "openai" in providers:
2611
- supported.append("gpt-5")
2612
- if "groq" in providers:
2613
- supported.append("groq:llama-3.1-70b-versatile")
2614
- supported.append("synth:qwen-0.6b")
3516
+ inference_meta = dict(inf)
2615
3517
  except Exception:
2616
- supported = []
2617
- if not supported:
2618
- # Only fall back to local config-derived providers when running in-process
2619
- if task_app_url is None:
2620
- try:
2621
- providers = list((config.base_task_info.inference or {}).get("providers") or []) # type: ignore[union-attr]
2622
- except Exception:
2623
- providers = []
2624
- if "openai" in providers:
2625
- supported.append("gpt-5")
2626
- if "groq" in providers:
2627
- supported.append("groq:llama-3.1-70b-versatile")
2628
- # Always include a local synth model option for smoke tests
2629
- supported.append("synth:qwen-0.6b")
3518
+ inference_meta = {}
3519
+
3520
+ default_model = inference_meta.get("model")
3521
+ if isinstance(default_model, str):
3522
+ _add_supported_model(default_model)
3523
+
3524
+ models_field = inference_meta.get("models")
3525
+ if isinstance(models_field, list):
3526
+ for candidate in models_field:
3527
+ _add_supported_model(candidate)
3528
+
3529
+ supported_models = inference_meta.get("supported_models")
3530
+ if isinstance(supported_models, list):
3531
+ for candidate in supported_models:
3532
+ _add_supported_model(candidate)
3533
+
3534
+ providers = inference_meta.get("providers")
3535
+ if isinstance(providers, list):
3536
+ if "openai" in providers:
3537
+ _add_supported_model("gpt-5")
3538
+ if "groq" in providers:
3539
+ _add_supported_model("groq:llama-3.1-70b-versatile")
3540
+
3541
+ _add_supported_model("synth:qwen-0.6b")
2630
3542
 
2631
3543
  selected_model = model
2632
3544
  if not selected_model:
2633
3545
  if not supported:
2634
3546
  raise click.ClickException(
2635
- "No supported models; supply --model or add base_task_info.inference.models"
3547
+ "No supported models; supply --model or add base_task_info.inference.model"
2636
3548
  )
2637
3549
  click.echo("Select model to evaluate:")
2638
3550
  for idx, m in enumerate(supported, start=1):
@@ -2652,70 +3564,402 @@ def eval_command(
2652
3564
  if api_key:
2653
3565
  headers["X-API-Key"] = api_key
2654
3566
 
3567
+ # Precompute optional policy overrides from TOML
3568
+ policy_overrides: dict[str, Any] = {}
3569
+ try:
3570
+ # Accept [eval.policy] table or top-level keys for convenience
3571
+ if isinstance(cfg.get("policy"), dict):
3572
+ policy_overrides.update(dict(cfg["policy"]))
3573
+ # Back-compat: allow temperature/max_tokens at top level
3574
+ for k in (
3575
+ "temperature",
3576
+ "max_tokens",
3577
+ "reasoning_effort",
3578
+ "system_hint",
3579
+ "tool_choice",
3580
+ "inference_url",
3581
+ ):
3582
+ if k in cfg and k not in policy_overrides:
3583
+ policy_overrides[k] = cfg.get(k)
3584
+ except Exception:
3585
+ policy_overrides = {}
3586
+
3587
+ raw_concurrency = cfg.get("concurrency")
3588
+ try:
3589
+ concurrency_limit = int(raw_concurrency) if raw_concurrency is not None else 1
3590
+ except Exception:
3591
+ concurrency_limit = 1
3592
+ if concurrency_limit <= 0:
3593
+ concurrency_limit = 1
3594
+ concurrency_limit = min(concurrency_limit, max(1, len(seed_values)))
3595
+
3596
+ judge_specs: list[JudgeSpec] = []
3597
+
3598
+ def _register_judge(name_hint: str | None, judge_cfg: dict[str, Any]) -> None:
3599
+ if not judge_cfg:
3600
+ return
3601
+ judge_module = judge_cfg.get("module")
3602
+ judge_path = judge_cfg.get("path")
3603
+ judge_callable_name = judge_cfg.get("callable") or judge_cfg.get("function")
3604
+ if judge_module and judge_path:
3605
+ raise click.ClickException("Judge config cannot set both 'module' and 'path'")
3606
+ if not judge_module and not judge_path:
3607
+ raise click.ClickException("Judge config requires 'module' or 'path'")
3608
+ try:
3609
+ if judge_module:
3610
+ module = importlib.import_module(str(judge_module))
3611
+ else:
3612
+ path = Path(str(judge_path)).expanduser()
3613
+ if not path.exists():
3614
+ raise click.ClickException(f"Judge module path not found: {path}")
3615
+ spec = importlib.util.spec_from_file_location(
3616
+ f"_eval_judge_{path.stem}", path
3617
+ )
3618
+ if not spec or not spec.loader:
3619
+ raise click.ClickException(f"Failed to load judge module from {path}")
3620
+ module = importlib.util.module_from_spec(spec)
3621
+ sys.modules[spec.name] = module
3622
+ spec.loader.exec_module(module)
3623
+ except click.ClickException:
3624
+ raise
3625
+ except Exception as exc:
3626
+ raise click.ClickException(f"Unable to load judge module: {exc}") from exc
3627
+
3628
+ if judge_callable_name:
3629
+ try:
3630
+ judge_fn = getattr(module, str(judge_callable_name))
3631
+ except AttributeError as exc:
3632
+ raise click.ClickException(
3633
+ f"Judge callable '{judge_callable_name}' not found in module"
3634
+ ) from exc
3635
+ else:
3636
+ if hasattr(module, "judge"):
3637
+ judge_fn = module.judge
3638
+ else:
3639
+ raise click.ClickException("Judge module must expose 'judge' callable")
3640
+
3641
+ if not callable(judge_fn):
3642
+ raise click.ClickException("Judge callable is not callable")
3643
+
3644
+ judge_kwargs = {
3645
+ k: v
3646
+ for k, v in judge_cfg.items()
3647
+ if k not in {"module", "path", "callable", "function", "name"}
3648
+ }
3649
+ display_name = str(
3650
+ judge_cfg.get("name")
3651
+ or name_hint
3652
+ or f"judge{len(judge_specs) + 1}"
3653
+ )
3654
+ judge_specs.append(JudgeSpec(display_name, judge_fn, judge_kwargs))
3655
+
3656
+ raw_judge_cfg = cfg.get("judge")
3657
+ if isinstance(raw_judge_cfg, dict) and raw_judge_cfg:
3658
+ direct_keys = {"module", "path", "callable", "function", "name"}
3659
+ has_direct_keys = any(key in raw_judge_cfg for key in direct_keys)
3660
+ nested_candidates = [
3661
+ (key, value)
3662
+ for key, value in raw_judge_cfg.items()
3663
+ if isinstance(value, dict)
3664
+ ]
3665
+ if has_direct_keys and not nested_candidates:
3666
+ _register_judge(None, raw_judge_cfg)
3667
+ else:
3668
+ for sub_name, sub_cfg in nested_candidates:
3669
+ _register_judge(sub_name, sub_cfg)
3670
+
3671
+ raw_judges_list = cfg.get("judges")
3672
+ if isinstance(raw_judges_list, list):
3673
+ for _index, entry in enumerate(raw_judges_list, start=1):
3674
+ if isinstance(entry, dict):
3675
+ _register_judge(entry.get("name") or f"judge{len(judge_specs) + 1}", entry)
3676
+
3677
+ records: list[dict[str, Any]] = []
3678
+
2655
3679
  successes = 0
2656
3680
  failures = 0
2657
3681
  # Aggregate outcome stats across successful seeds
2658
3682
  outcome_sum: float = 0.0
2659
3683
  outcome_count: int = 0
2660
3684
  outcome_correct: int = 0
2661
- if task_app_url is None:
2662
- transport = httpx.ASGITransport(app=app) # type: ignore[name-defined]
2663
- # Newer httpx types consider ASGITransport under httpx._transports; cast to satisfy type checker
2664
- client = httpx.Client(
2665
- transport=cast(Any, transport),
2666
- base_url="http://eval.local",
2667
- timeout=60.0,
2668
- headers=headers,
2669
- )
2670
- else:
2671
- client = httpx.Client(base_url=task_app_url, timeout=60.0, headers=headers)
2672
- try:
2673
- with contextlib.suppress(Exception):
2674
- client.get("/task_info")
2675
- # Precompute optional policy overrides from TOML
2676
- policy_overrides: dict[str, Any] = {}
2677
- try:
2678
- # Accept [eval.policy] table or top-level keys for convenience
2679
- if isinstance(cfg.get("policy"), dict):
2680
- policy_overrides.update(dict(cfg["policy"]))
2681
- # Back-compat: allow temperature/max_tokens at top level
2682
- for k in (
2683
- "temperature",
2684
- "max_tokens",
2685
- "reasoning_effort",
2686
- "system_hint",
2687
- "tool_choice",
2688
- ):
2689
- if k in cfg and k not in policy_overrides:
2690
- policy_overrides[k] = cfg.get(k)
2691
- except Exception:
2692
- policy_overrides = {}
2693
-
2694
- for seed_val in seed_values:
2695
- body = {
2696
- "run_id": str(uuid.uuid4()),
2697
- "env": {"config": {"split": split, "index": seed_val}, "seed": seed_val},
2698
- "policy": {
2699
- "policy_name": selected_model,
2700
- "config": {"model": selected_model, **policy_overrides},
2701
- },
2702
- "ops": [],
3685
+
3686
+ def _build_task_rows(taskset: Any) -> dict[int, dict[str, Any]]:
3687
+ rows: dict[int, dict[str, Any]] = {}
3688
+ if not isinstance(taskset, dict):
3689
+ return rows
3690
+
3691
+ scenario_ids = taskset.get("scenario_ids") or []
3692
+ loop_ids = taskset.get("loop_ids") or []
3693
+ thread_ids = taskset.get("thread_ids") or []
3694
+ difficulty_map = taskset.get("difficulty_map") or {}
3695
+
3696
+ max_len = max(len(scenario_ids), len(loop_ids), len(thread_ids))
3697
+ for seed in range(max_len):
3698
+ scenario_id = scenario_ids[seed] if seed < len(scenario_ids) else None
3699
+ loop_id = loop_ids[seed] if seed < len(loop_ids) else None
3700
+ thread_id = thread_ids[seed] if seed < len(thread_ids) else None
3701
+ difficulty = None
3702
+ if isinstance(difficulty_map, dict):
3703
+ if scenario_id and scenario_id in difficulty_map:
3704
+ difficulty = difficulty_map.get(scenario_id)
3705
+ elif str(seed) in difficulty_map:
3706
+ difficulty = difficulty_map.get(str(seed))
3707
+
3708
+ rows[seed] = {
3709
+ "seed": seed,
3710
+ "scenario_id": scenario_id,
3711
+ "loop_id": loop_id,
3712
+ "thread_id": thread_id,
3713
+ "difficulty": difficulty,
2703
3714
  }
3715
+ return rows
3716
+
3717
+ def _apply_metadata_filters(
3718
+ rows: dict[int, dict[str, Any]], seeds_list: list[int], filters: dict[str, str]
3719
+ ) -> list[int]:
3720
+ if not filters:
3721
+ return seeds_list
3722
+ filtered: list[int] = []
3723
+ for seed in seeds_list:
3724
+ row = rows.get(seed)
3725
+ if not row:
3726
+ continue
3727
+ include = True
3728
+ for key, expected in filters.items():
3729
+ actual = row.get(key)
3730
+ if actual is None:
3731
+ include = False
3732
+ break
3733
+ if str(actual).lower() != expected.lower():
3734
+ include = False
3735
+ break
3736
+ if include:
3737
+ filtered.append(seed)
3738
+ return filtered
3739
+
3740
+ def _apply_metadata_sql(
3741
+ rows: dict[int, dict[str, Any]], seeds_list: list[int], query: str
3742
+ ) -> list[int]:
3743
+ """Return seeds that satisfy an arbitrary SQL query.
3744
+
3745
+ The query is executed against an in-memory SQLite table named `tasks`
3746
+ with columns (seed INTEGER, scenario_id TEXT, loop_id TEXT, thread_id TEXT, difficulty TEXT).
3747
+ Rows whose `seed` value (or the first column if `seed` is absent) appears in the result set are retained.
3748
+ """
3749
+ if not query:
3750
+ return seeds_list
3751
+ conn = sqlite3.connect(":memory:")
3752
+ try:
3753
+ cur = conn.cursor()
3754
+ cur.execute(
3755
+ "CREATE TABLE tasks (seed INTEGER, scenario_id TEXT, loop_id TEXT, thread_id TEXT, difficulty TEXT)"
3756
+ )
3757
+ insert_stmt = (
3758
+ "INSERT INTO tasks (seed, scenario_id, loop_id, thread_id, difficulty) VALUES (?,?,?,?,?)"
3759
+ )
3760
+ for seed in seeds_list:
3761
+ row = rows.get(seed, {})
3762
+ cur.execute(
3763
+ insert_stmt,
3764
+ [
3765
+ seed,
3766
+ row.get("scenario_id"),
3767
+ row.get("loop_id"),
3768
+ row.get("thread_id"),
3769
+ row.get("difficulty"),
3770
+ ],
3771
+ )
3772
+
3773
+ result = cur.execute(query)
3774
+ fetched = result.fetchall()
3775
+ if not fetched:
3776
+ return []
3777
+ description = result.description or []
3778
+ col_names = [col[0] for col in description]
3779
+ seeds_out: list[int] = []
3780
+ for entry in fetched:
3781
+ value = entry[col_names.index("seed")] if "seed" in col_names else entry[0]
3782
+ try:
3783
+ seeds_out.append(int(value))
3784
+ except Exception as exc:
3785
+ raise click.ClickException(
3786
+ "metadata SQL query must return seed integers"
3787
+ ) from exc
3788
+ seeds_set = set(seeds_out)
3789
+ return [seed for seed in seeds_list if seed in seeds_set]
3790
+ except sqlite3.Error as exc:
3791
+ raise click.ClickException(f"Failed to execute metadata SQL query: {exc}") from exc
3792
+ finally:
3793
+ conn.close()
3794
+
3795
+ async def _run_eval() -> None:
3796
+ nonlocal successes, failures, outcome_sum, outcome_count, outcome_correct, records, seed_values
3797
+
3798
+ if trace_tracer is not None and trace_tracer.db is None:
3799
+ await trace_tracer.initialize()
3800
+
3801
+ if task_app_url is None:
3802
+ transport = httpx.ASGITransport(app=app) # type: ignore[name-defined]
3803
+ async_client = httpx.AsyncClient(
3804
+ transport=cast(Any, transport),
3805
+ base_url="http://eval.local",
3806
+ timeout=300.0,
3807
+ follow_redirects=True,
3808
+ headers=headers,
3809
+ )
3810
+ else:
3811
+ async_client = httpx.AsyncClient(
3812
+ base_url=task_app_url,
3813
+ timeout=300.0,
3814
+ follow_redirects=True,
3815
+ headers=headers,
3816
+ )
3817
+
3818
+ try:
3819
+ taskset_payload: dict[str, Any] | None = None
2704
3820
  try:
2705
- resp = client.post("/rollout", json=body)
2706
- ok = 200 <= resp.status_code < 300
3821
+ task_info_response = await async_client.get("/task_info")
3822
+ except Exception:
3823
+ task_info_response = None
3824
+ if task_info_response is not None and task_info_response.status_code == 200:
3825
+ with contextlib.suppress(Exception):
3826
+ payload_json = task_info_response.json()
3827
+ if isinstance(payload_json, dict) and "taskset" in payload_json:
3828
+ taskset_payload = payload_json.get("taskset")
3829
+ if not isinstance(taskset_payload, dict):
3830
+ taskset_payload = None
3831
+ elif isinstance(payload_json, dict):
3832
+ taskset_payload = payload_json
3833
+
3834
+ available_seeds = list(seed_values)
3835
+ if metadata_sql_query or metadata_filters:
3836
+ if not taskset_payload:
3837
+ raise click.ClickException(
3838
+ "Task metadata filters require the task app to expose /task_info metadata"
3839
+ )
3840
+ rows = _build_task_rows(taskset_payload)
3841
+ if metadata_sql_query:
3842
+ available_seeds = _apply_metadata_sql(rows, available_seeds, metadata_sql_query)
3843
+ if metadata_filters:
3844
+ available_seeds = _apply_metadata_filters(rows, available_seeds, metadata_filters)
3845
+ if not available_seeds:
3846
+ raise click.ClickException("No seeds match the provided metadata filters")
3847
+ seed_values = available_seeds
3848
+
3849
+ semaphore = asyncio.Semaphore(concurrency_limit)
3850
+
3851
+ async def _run_seed(seed_val: int) -> None:
3852
+ nonlocal successes, failures, outcome_sum, outcome_count, outcome_correct, records
3853
+ # Read env_name and policy_name from config if available
3854
+ env_name = cfg.get("env_name") or (cfg.get("env", {}).get("env_name") if isinstance(cfg.get("env"), dict) else None)
3855
+ policy_name = cfg.get("policy_name") or (cfg.get("policy", {}).get("policy_name") if isinstance(cfg.get("policy"), dict) else None)
3856
+ env_config_overrides = cfg.get("env_config", {}) if isinstance(cfg.get("env_config"), dict) else {}
3857
+ policy_config_overrides = cfg.get("policy_config", {}) if isinstance(cfg.get("policy_config"), dict) else {}
3858
+
3859
+ # Debug: print config parsing
3860
+ if seed_val == 0:
3861
+ click.echo(f"[DEBUG] env_name from config: {env_name}")
3862
+ click.echo(f"[DEBUG] policy_name from config: {policy_name}")
3863
+
3864
+ # Generate default ops sequence if not provided
3865
+ max_llm_calls = policy_config_overrides.get("max_llm_calls", 10)
3866
+ ops_list = cfg.get("ops", [])
3867
+ if not ops_list:
3868
+ # Generate default "agent, env" pairs for max_llm_calls
3869
+ ops_list = ["agent", "env"] * int(max_llm_calls)
3870
+
3871
+ body = {
3872
+ "run_id": str(uuid.uuid4()),
3873
+ "env": {"config": {"split": split, "index": seed_val, **env_config_overrides}, "seed": seed_val},
3874
+ "policy": {
3875
+ "policy_name": policy_name or selected_model,
3876
+ "config": {"model": selected_model, **policy_overrides, **policy_config_overrides},
3877
+ },
3878
+ "ops": ops_list,
3879
+ "record": {
3880
+ "return_trace": cfg.get("return_trace", True),
3881
+ "trace_format": cfg.get("trace_format", "structured"),
3882
+ },
3883
+ "mode": "eval", # RolloutMode.EVAL: use inference URLs as-is, no transformations
3884
+ }
3885
+ if env_name:
3886
+ body["env"]["env_name"] = env_name
3887
+
3888
+ # Debug: print the body being sent
3889
+ if seed_val == 0:
3890
+ click.echo(f"[DEBUG] rollout body env: {body['env']}")
3891
+ click.echo(f"[DEBUG] rollout body policy: {body['policy']}")
3892
+ click.echo(f"[DEBUG] rollout body mode: {body.get('mode', 'NOT SET')}")
3893
+ rollout_elapsed: float | None = None
3894
+ rollout_start = time.perf_counter()
3895
+ try:
3896
+ import logging
3897
+ _log = logging.getLogger(__name__)
3898
+ _log.info(f"[EVAL_BODY_DEBUG] Sending body with mode={body.get('mode')}")
3899
+ async with semaphore:
3900
+ response = await async_client.post("/rollout", json=body)
3901
+ rollout_elapsed = time.perf_counter() - rollout_start
3902
+ except Exception as exc:
3903
+ failures += 1
3904
+ click.echo(f"seed={seed_val} error={exc}")
3905
+ return
3906
+
3907
+ ok = 200 <= response.status_code < 300
2707
3908
  if ok:
2708
3909
  successes += 1
2709
3910
  else:
2710
3911
  failures += 1
2711
3912
 
2712
- # Print summary with any available metrics/tool calls
2713
- summary = [f"seed={seed_val}", f"status={resp.status_code}"]
3913
+ summary = [f"seed={seed_val}", f"status={response.status_code}"]
3914
+ data: Any
2714
3915
  try:
2715
- data = resp.json()
3916
+ data = response.json()
2716
3917
  except Exception:
2717
3918
  data = None
3919
+
3920
+ # Debug: print validation errors
3921
+ if response.status_code == 422 and data:
3922
+ click.echo(f"[DEBUG] 422 Validation Error: {data}")
3923
+
3924
+ metrics: dict[str, Any] | None = None
3925
+ completion: str | None = None
3926
+ prompt_index: int | None = None
3927
+ prompt_text: str | None = None
3928
+ task_id: str | None = None
3929
+ task_split: str | None = None
3930
+ task_rubric_id: str | None = None
3931
+
3932
+ trace_namespace: dict[str, Any] | None = None
3933
+ session_trace_dict: dict[str, Any] | None = None
3934
+
2718
3935
  if isinstance(data, dict):
3936
+ import logging
3937
+ _logger = logging.getLogger(__name__)
3938
+ _logger.info(f"[EVAL_DEBUG] Response data keys: {list(data.keys())}")
3939
+ if "detail" in data:
3940
+ _logger.error(f"[EVAL_DEBUG] Task app returned error: {data['detail']}")
3941
+ trace_namespace = data.get("trace")
3942
+ _logger.info(f"[EVAL_DEBUG] trace_namespace type: {type(trace_namespace)}, value: {trace_namespace if not isinstance(trace_namespace, dict) else 'dict with keys: ' + str(list(trace_namespace.keys()) if trace_namespace else 'None')}")
3943
+ if not isinstance(trace_namespace, dict):
3944
+ raise RuntimeError(
3945
+ "The 'synth-ai eval' command requires trace payloads in rollout responses. "
3946
+ "Ensure the rollout request includes 'trace_format': 'structured' and 'return_trace': true, "
3947
+ "and that task app tracing is enabled (TASKAPP_TRACING_ENABLED=1). "
3948
+ "Note: This is specific to the eval command - general rollout endpoints don't require traces."
3949
+ )
3950
+ # Handle both "compact" and "full" trace formats:
3951
+ # - compact: trace_namespace contains {session_id, metadata, ...}
3952
+ # - full: trace_namespace IS the full session_trace dict
3953
+ session_trace_dict = trace_namespace.get("session_trace")
3954
+ if not isinstance(session_trace_dict, dict):
3955
+ # If no session_trace key, assume "full" format where trace itself is the session_trace
3956
+ if "session_id" in trace_namespace:
3957
+ session_trace_dict = trace_namespace
3958
+ else:
3959
+ raise RuntimeError(
3960
+ "The 'synth-ai eval' command requires 'session_trace' in the trace payload or a valid full trace format. "
3961
+ "Ensure the task app is using tracing_v3 and returning structured trace data."
3962
+ )
2719
3963
  metrics = data.get("metrics") if isinstance(data.get("metrics"), dict) else None
2720
3964
  if metrics:
2721
3965
  mean_return = metrics.get("mean_return") or metrics.get("total_reward")
@@ -2724,7 +3968,6 @@ def eval_command(
2724
3968
  summary.append(f"mean_return={mean_return}")
2725
3969
  if outcome is not None:
2726
3970
  summary.append(f"outcome={outcome}")
2727
- # Aggregate outcome stats
2728
3971
  try:
2729
3972
  val = float(outcome)
2730
3973
  outcome_sum += val
@@ -2733,7 +3976,6 @@ def eval_command(
2733
3976
  outcome_correct += 1
2734
3977
  except Exception:
2735
3978
  pass
2736
- # Try to infer tool call count from first trajectory step
2737
3979
  trajs = (
2738
3980
  data.get("trajectories")
2739
3981
  if isinstance(data.get("trajectories"), list)
@@ -2747,38 +3989,164 @@ def eval_command(
2747
3989
  tool_calls = step0.get("tool_calls") or step0.get("tools") or []
2748
3990
  if isinstance(tool_calls, list):
2749
3991
  summary.append(f"tool_calls={len(tool_calls)}")
3992
+ obs = step0.get("obs") if isinstance(step0, dict) else None
3993
+ if isinstance(obs, dict):
3994
+ idx_val = obs.get("prompt_index")
3995
+ if isinstance(idx_val, int):
3996
+ prompt_index = idx_val
3997
+ prompt_raw = obs.get("prompt")
3998
+ if isinstance(prompt_raw, str):
3999
+ prompt_text = prompt_raw
4000
+ if task_id is None:
4001
+ candidate_id = obs.get("task_id")
4002
+ if isinstance(candidate_id, str) and candidate_id:
4003
+ task_id = candidate_id
4004
+ if task_split is None:
4005
+ candidate_split = obs.get("task_split")
4006
+ if isinstance(candidate_split, str) and candidate_split:
4007
+ task_split = candidate_split
4008
+ if task_rubric_id is None:
4009
+ candidate_rid = obs.get("task_rubric_id")
4010
+ if isinstance(candidate_rid, str) and candidate_rid:
4011
+ task_rubric_id = candidate_rid
4012
+ final = first.get("final") if isinstance(first, dict) else None
4013
+ if isinstance(final, dict):
4014
+ final_obs = final.get("observation")
4015
+ if isinstance(final_obs, dict):
4016
+ comp_val = final_obs.get("completion")
4017
+ if isinstance(comp_val, str):
4018
+ completion = comp_val
4019
+ if task_id is None:
4020
+ candidate_id = final_obs.get("task_id")
4021
+ if isinstance(candidate_id, str) and candidate_id:
4022
+ task_id = candidate_id
4023
+ if task_split is None:
4024
+ candidate_split = final_obs.get("task_split")
4025
+ if isinstance(candidate_split, str) and candidate_split:
4026
+ task_split = candidate_split
4027
+ if task_rubric_id is None:
4028
+ candidate_rid = final_obs.get("task_rubric_id")
4029
+ if isinstance(candidate_rid, str) and candidate_rid:
4030
+ task_rubric_id = candidate_rid
4031
+ final_info = final.get("info")
4032
+ if isinstance(final_info, dict):
4033
+ if task_id is None:
4034
+ candidate_id = final_info.get("task_id")
4035
+ if isinstance(candidate_id, str) and candidate_id:
4036
+ task_id = candidate_id
4037
+ if task_split is None:
4038
+ candidate_split = final_info.get("task_split")
4039
+ if isinstance(candidate_split, str) and candidate_split:
4040
+ task_split = candidate_split
4041
+ if task_rubric_id is None:
4042
+ candidate_rid = final_info.get("task_rubric_id")
4043
+ if isinstance(candidate_rid, str) and candidate_rid:
4044
+ task_rubric_id = candidate_rid
4045
+ if task_id:
4046
+ summary.append(f"task_id={task_id}")
2750
4047
  click.echo(" ".join(summary))
2751
- # Print the full response JSON (trace, trajectories, metrics)
2752
4048
  with contextlib.suppress(Exception):
2753
4049
  click.echo(json.dumps(data, indent=2))
2754
4050
  else:
2755
4051
  click.echo(" ".join(summary))
2756
- except Exception as exc:
2757
- failures += 1
2758
- click.echo(f"seed={seed_val} error={exc}")
2759
4052
 
2760
- finally:
2761
- try:
2762
- client.close()
2763
- except AttributeError:
2764
- transport_obj = getattr(client, "_transport", None)
2765
- if transport_obj and hasattr(transport_obj, "aclose"):
2766
- try:
2767
- asyncio.run(transport_obj.aclose())
2768
- except RuntimeError:
2769
- # Fallback when already inside a running loop (rare for CLI).
2770
- new_loop = asyncio.new_event_loop()
4053
+ official_score = None
4054
+ if isinstance(metrics, dict):
4055
+ for key in ("mean_return", "total_reward", "outcome_score"):
4056
+ val = metrics.get(key)
4057
+ if isinstance(val, int | float):
4058
+ official_score = float(val)
4059
+ break
4060
+ if official_score is None and isinstance(data, dict):
2771
4061
  try:
2772
- new_loop.run_until_complete(transport_obj.aclose())
2773
- finally:
2774
- new_loop.close()
2775
- except Exception:
2776
- pass
4062
+ reward_val = data["trajectories"][0]["steps"][0].get("reward")
4063
+ if isinstance(reward_val, int | float):
4064
+ official_score = float(reward_val)
4065
+ except Exception:
4066
+ pass
4067
+
4068
+ if official_score is not None:
4069
+ if official_score < 0.0:
4070
+ official_score = 0.0
4071
+ elif official_score > 1.0:
4072
+ official_score = min(1.0, official_score)
4073
+
4074
+ judge_scores: dict[str, float | None] = {}
4075
+ judges_timings: dict[str, float | None] = {}
4076
+ timings: dict[str, Any] = {
4077
+ "rollout_s": rollout_elapsed,
4078
+ "judges": judges_timings,
4079
+ }
4080
+ if judge_specs:
4081
+ for spec in judge_specs:
4082
+ score_value: float | None = None
4083
+ judge_elapsed: float | None = None
4084
+ # Run judges for all tasks (text-based and trajectory-based)
4085
+ # Text-based tasks have completion, trajectory-based tasks use response
4086
+ judge_payload = {
4087
+ "seed": seed_val,
4088
+ "prompt_index": prompt_index,
4089
+ "prompt": prompt_text,
4090
+ "completion": completion,
4091
+ "metrics": metrics,
4092
+ "response": data,
4093
+ "trace": trace_namespace,
4094
+ }
4095
+ try:
4096
+ judge_start = time.perf_counter()
4097
+ result = spec.fn(judge_payload, **spec.kwargs)
4098
+ judge_elapsed = time.perf_counter() - judge_start
4099
+ if isinstance(result, int | float):
4100
+ score_value = float(result)
4101
+ except Exception as exc:
4102
+ if judge_elapsed is None:
4103
+ judge_elapsed = time.perf_counter() - judge_start
4104
+ click.echo(f"seed={seed_val} judge[{spec.name}]_error={exc}")
4105
+ judges_timings[spec.name] = judge_elapsed
4106
+ judge_scores[spec.name] = score_value
4107
+
4108
+ if trace_tracer is not None and trace_namespace:
4109
+ storage_metadata = {
4110
+ "eval_seed": seed_val,
4111
+ "prompt_index": prompt_index,
4112
+ "task_id": task_id,
4113
+ "task_split": task_split,
4114
+ "task_rubric_id": task_rubric_id,
4115
+ "official_score": official_score,
4116
+ "judge_scores": judge_scores,
4117
+ "model": selected_model,
4118
+ "prompt": prompt_text,
4119
+ "completion": completion,
4120
+ }
4121
+ await _store_trace(trace_tracer, trace_namespace, storage_metadata)
4122
+
4123
+ records.append(
4124
+ {
4125
+ "seed": seed_val,
4126
+ "prompt_index": prompt_index,
4127
+ "task_id": task_id,
4128
+ "task_split": task_split,
4129
+ "task_rubric_id": task_rubric_id,
4130
+ "official_score": official_score,
4131
+ "judge_scores": judge_scores,
4132
+ "timings": timings,
4133
+ }
4134
+ )
4135
+
4136
+ await asyncio.gather(*[_run_seed(seed_val) for seed_val in seed_values])
4137
+ finally:
4138
+ await async_client.aclose()
4139
+
4140
+ try:
4141
+ asyncio.run(_run_eval())
4142
+ finally:
4143
+ if trace_tracer is not None and trace_tracer.db is not None:
4144
+ asyncio.run(trace_tracer.db.close())
2777
4145
 
2778
4146
  click.echo(
2779
4147
  f"Eval complete: {successes} ok, {failures} failed; model={selected_model}, split={split}"
2780
4148
  )
2781
- # Print outcome summary if any successes
4149
+
2782
4150
  if outcome_count > 0:
2783
4151
  mean_outcome = outcome_sum / float(outcome_count)
2784
4152
  frac_right = outcome_correct / float(outcome_count)
@@ -2786,6 +4154,370 @@ def eval_command(
2786
4154
  f"Outcome summary: correct={outcome_correct}/{outcome_count} ({frac_right:.2%}), mean_outcome={mean_outcome:.3f}"
2787
4155
  )
2788
4156
 
4157
+ if records:
4158
+ judge_specs = judge_specs or [] # ensure iterable
4159
+ official_scores = [
4160
+ r["official_score"] for r in records if r["official_score"] is not None
4161
+ ]
4162
+ if official_scores:
4163
+ click.echo(f" Official mean: {sum(official_scores) / len(official_scores):.3f}")
4164
+ else:
4165
+ click.echo(" Official mean: n/a")
4166
+
4167
+ for spec in judge_specs:
4168
+ spec_scores = [
4169
+ record["judge_scores"].get(spec.name)
4170
+ for record in records
4171
+ if record["judge_scores"].get(spec.name) is not None
4172
+ ]
4173
+ if spec_scores:
4174
+ mean_spec = sum(spec_scores) / len(spec_scores)
4175
+ click.echo(f" [{spec.name}] mean: {mean_spec:.3f}")
4176
+ else:
4177
+ click.echo(f" [{spec.name}] mean: n/a")
4178
+
4179
+ paired = [
4180
+ (
4181
+ record["official_score"],
4182
+ record["judge_scores"].get(spec.name),
4183
+ )
4184
+ for record in records
4185
+ if record["official_score"] is not None
4186
+ and record["judge_scores"].get(spec.name) is not None
4187
+ ]
4188
+ if len(paired) >= 2:
4189
+ corr = _pearson(
4190
+ [p[0] for p in paired if p[0] is not None],
4191
+ [p[1] for p in paired if p[1] is not None],
4192
+ )
4193
+ if corr is not None:
4194
+ click.echo(f" Pearson r: {corr:.3f}")
4195
+ else:
4196
+ click.echo(" Pearson r: undefined (zero variance)")
4197
+ else:
4198
+ click.echo(" Pearson r: n/a (need ≥2 paired scores)")
4199
+
4200
+ header = ["Seed", "Prompt", "Official"]
4201
+ header.extend(spec.name for spec in judge_specs)
4202
+ rows: list[list[str]] = []
4203
+ for record in sorted(records, key=lambda r: (r["seed"], r.get("prompt_index") or -1)):
4204
+ seed_val = str(record["seed"])
4205
+ prompt_idx = (
4206
+ str(record["prompt_index"])
4207
+ if record["prompt_index"] is not None
4208
+ else "-"
4209
+ )
4210
+ official_val = (
4211
+ f"{record['official_score']:.3f}"
4212
+ if record["official_score"] is not None
4213
+ else "-"
4214
+ )
4215
+ row = [seed_val, prompt_idx, official_val]
4216
+ for spec in judge_specs:
4217
+ score_val = record["judge_scores"].get(spec.name)
4218
+ row.append(f"{score_val:.3f}" if isinstance(score_val, int | float) else "-")
4219
+ rows.append(row)
4220
+
4221
+ widths = [len(col) for col in header]
4222
+ for row in rows:
4223
+ for idx, cell in enumerate(row):
4224
+ widths[idx] = max(widths[idx], len(cell))
4225
+
4226
+ click.echo("")
4227
+ click.echo(" ".join(h.ljust(widths[idx]) for idx, h in enumerate(header)))
4228
+ click.echo(" ".join("-" * widths[idx] for idx in range(len(header))))
4229
+ for row in rows:
4230
+ click.echo(" ".join(cell.ljust(widths[idx]) for idx, cell in enumerate(row)))
4231
+
4232
+
4233
+
4234
+ @click.command(
4235
+ "filter",
4236
+ help="Export filtered tracing sessions to SFT-ready JSONL based on a TOML config.",
4237
+ )
4238
+ @click.option(
4239
+ "--config",
4240
+ "config_path",
4241
+ type=click.Path(),
4242
+ required=True,
4243
+ help="Path to TOML config describing the input trace DB, score thresholds, and output JSONL.",
4244
+ )
4245
+ def filter_command(config_path: str) -> None:
4246
+ """Render tracing sessions that match filter rules into SFT JSONL.
4247
+
4248
+ The TOML file should contain a `[filter]` table with at least:
4249
+
4250
+ db = \"path/to/traces.db\" # sqlite path or URL (sqlite+aiosqlite://...)
4251
+ output = \"ft_data/out.jsonl\" # destination JSONL
4252
+
4253
+ Optional keys such as `splits`, `task_ids`, `models`, `min_official_score`, or
4254
+ `min_judge_scores.my_judge = 0.7` allow you to narrow the dataset down to
4255
+ high-quality traces. See `customers/agora_single_file/configs/filter_local.toml`
4256
+ for a working example.
4257
+ """
4258
+ # Parse and validate TOML config
4259
+ from synth_ai.task.config import FilterConfig
4260
+
4261
+ if _toml is None:
4262
+ raise click.ClickException("TOML parser not available; install tomli or use Python 3.11+")
4263
+
4264
+ cfg_path = Path(config_path)
4265
+ if not cfg_path.exists():
4266
+ raise click.ClickException(f"Filter config not found: {cfg_path}")
4267
+
4268
+ try:
4269
+ config_data = _toml.loads(cfg_path.read_text(encoding="utf-8"))
4270
+ except Exception as exc:
4271
+ raise click.ClickException(f"Failed to parse TOML '{cfg_path}': {exc}") from exc
4272
+
4273
+ filter_cfg_dict = config_data.get("filter") if isinstance(config_data, dict) else None
4274
+ if not isinstance(filter_cfg_dict, dict):
4275
+ raise click.ClickException("Config must contain a [filter] table")
4276
+
4277
+ # Validate config with dataclass
4278
+ try:
4279
+ filter_cfg = FilterConfig.from_dict(filter_cfg_dict)
4280
+ click.echo(f"✓ Config validated: db={filter_cfg.db}, output={filter_cfg.output}")
4281
+ if filter_cfg.min_official_score is not None:
4282
+ click.echo(f" → Filtering for official score >= {filter_cfg.min_official_score}")
4283
+ if filter_cfg.limit:
4284
+ click.echo(f" → Limiting to {filter_cfg.limit} examples")
4285
+ except (ValueError, TypeError) as validation_error:
4286
+ raise click.ClickException(f"Invalid filter config: {validation_error}") from validation_error
4287
+
4288
+ # Use validated config
4289
+ db_url = filter_cfg.get_db_url()
4290
+ output_path = filter_cfg.get_output_path()
4291
+
4292
+ # Extract validated fields from dataclass
4293
+ splits = set(filter_cfg.splits)
4294
+ task_ids = set(filter_cfg.task_ids)
4295
+ models = set(filter_cfg.models)
4296
+ min_official = filter_cfg.min_official_score
4297
+ max_official = filter_cfg.max_official_score
4298
+ min_judge_scores = filter_cfg.min_judge_scores
4299
+ max_judge_scores = filter_cfg.max_judge_scores
4300
+ # Note: min_created_at and max_created_at not yet in FilterConfig dataclass
4301
+ min_created = _parse_datetime_for_trace(filter_cfg_dict.get("min_created_at"))
4302
+ max_created = _parse_datetime_for_trace(filter_cfg_dict.get("max_created_at"))
4303
+ limit = filter_cfg.limit
4304
+
4305
+ def _score_ok(value: Any, min_val: Any, max_val: Any) -> bool:
4306
+ try:
4307
+ if value is None:
4308
+ return min_val is None
4309
+ value = float(value)
4310
+ except Exception:
4311
+ return False
4312
+ if min_val is not None and value < float(min_val):
4313
+ return False
4314
+ return not (max_val is not None and value > float(max_val))
4315
+
4316
+ async def _run_filter() -> None:
4317
+ tracer = SessionTracer(db_url=db_url, auto_save=False)
4318
+ await tracer.initialize()
4319
+
4320
+ df = await tracer.db.query_traces(
4321
+ "SELECT session_id, created_at, metadata FROM session_traces ORDER BY created_at"
4322
+ )
4323
+ if getattr(df, "empty", True):
4324
+ raise click.ClickException("No traces found in database")
4325
+
4326
+ sessions = df.to_dict("records")
4327
+ accepted: list[dict[str, Any]] = []
4328
+
4329
+ for row in sessions:
4330
+ metadata_raw = row.get("metadata")
4331
+ if isinstance(metadata_raw, str):
4332
+ try:
4333
+ metadata = json.loads(metadata_raw)
4334
+ except Exception:
4335
+ metadata = {}
4336
+ elif isinstance(metadata_raw, dict):
4337
+ metadata = dict(metadata_raw)
4338
+ else:
4339
+ metadata = {}
4340
+
4341
+ created_at_raw = row.get("created_at")
4342
+ created_at_dt = _parse_datetime_for_trace(created_at_raw)
4343
+
4344
+ session_id = row.get("session_id")
4345
+
4346
+ if splits and metadata.get("task_split") not in splits:
4347
+ continue
4348
+ if task_ids and metadata.get("task_id") not in task_ids:
4349
+ continue
4350
+ if models and metadata.get("model") not in models:
4351
+ continue
4352
+
4353
+ if min_created and (created_at_dt is None or created_at_dt < min_created):
4354
+ continue
4355
+ if max_created and (created_at_dt is None or created_at_dt > max_created):
4356
+ continue
4357
+
4358
+ # Check against outcome_rewards if score filter is set
4359
+ total_reward = None
4360
+ achievements_count = None
4361
+ if min_official is not None or max_official is not None:
4362
+ reward_query = "SELECT total_reward, achievements_count FROM outcome_rewards WHERE session_id = :session_id"
4363
+ reward_rows = await tracer.db.query_traces(reward_query, {"session_id": session_id})
4364
+ reward_records = reward_rows.to_dict("records") if hasattr(reward_rows, "to_dict") else []
4365
+ if reward_records:
4366
+ total_reward = reward_records[0].get("total_reward")
4367
+ achievements_count = reward_records[0].get("achievements_count")
4368
+ if not _score_ok(total_reward, min_official, max_official):
4369
+ continue
4370
+ elif min_official is not None:
4371
+ # No reward found, but score filter requires it
4372
+ continue
4373
+
4374
+ judge_scores = metadata.get("judge_scores") or {}
4375
+ include = True
4376
+ for judge_name, threshold in (min_judge_scores or {}).items():
4377
+ if not _score_ok(judge_scores.get(judge_name), threshold, None):
4378
+ include = False
4379
+ break
4380
+ if not include:
4381
+ continue
4382
+ for judge_name, threshold in (max_judge_scores or {}).items():
4383
+ if not _score_ok(judge_scores.get(judge_name), None, threshold):
4384
+ include = False
4385
+ break
4386
+ if not include:
4387
+ continue
4388
+
4389
+ # Query messages for this session
4390
+ messages_query = """
4391
+ SELECT message_type, content, timestamp
4392
+ FROM messages
4393
+ WHERE session_id = :session_id
4394
+ ORDER BY timestamp ASC, id ASC
4395
+ """
4396
+ msg_df = await tracer.db.query_traces(messages_query, {"session_id": session_id})
4397
+ message_rows = msg_df.to_dict("records") if hasattr(msg_df, "to_dict") else []
4398
+
4399
+ if not message_rows:
4400
+ # Fallback: check if prompt/completion in metadata (old format)
4401
+ prompt = metadata.get("prompt") or ""
4402
+ completion = metadata.get("completion") or ""
4403
+ if prompt and completion:
4404
+ record = {
4405
+ "messages": [
4406
+ {"role": "user", "content": str(prompt)},
4407
+ {"role": "assistant", "content": str(completion)},
4408
+ ],
4409
+ "metadata": {
4410
+ "session_id": session_id,
4411
+ "env_name": metadata.get("env_name"),
4412
+ "policy_name": metadata.get("policy_name"),
4413
+ "seed": metadata.get("seed"),
4414
+ "total_reward": total_reward,
4415
+ "achievements_count": achievements_count,
4416
+ "model": metadata.get("model"),
4417
+ "created_at": created_at_dt.isoformat() if created_at_dt else created_at_raw,
4418
+ },
4419
+ }
4420
+ accepted.append(record)
4421
+ continue
4422
+
4423
+ # Extract user/assistant pairs from messages
4424
+ for i, msg_row in enumerate(message_rows):
4425
+ msg_type = msg_row.get("message_type")
4426
+ content_raw = msg_row.get("content")
4427
+
4428
+ # Look for user message
4429
+ if msg_type in ("user", "policy_user_prompt"):
4430
+ # Find next policy_system_prompt or assistant
4431
+ assistant_msg = None
4432
+ for j in range(i + 1, len(message_rows)):
4433
+ next_type = message_rows[j].get("message_type")
4434
+ if next_type in ("assistant", "policy_system_prompt"):
4435
+ if next_type == "assistant":
4436
+ assistant_msg = message_rows[j]
4437
+ break
4438
+
4439
+ # Parse content
4440
+ try:
4441
+ user_content = json.loads(content_raw) if isinstance(content_raw, str) else content_raw
4442
+ except Exception:
4443
+ user_content = content_raw
4444
+
4445
+ # Extract text from structured content
4446
+ def extract_text(content: Any) -> str:
4447
+ if isinstance(content, str):
4448
+ return content
4449
+ if isinstance(content, dict):
4450
+ # Try payload.content for user prompts
4451
+ if "payload" in content and isinstance(content["payload"], dict):
4452
+ payload = content["payload"]
4453
+ if "content" in payload:
4454
+ return extract_text(payload["content"])
4455
+ # Try common keys
4456
+ for key in ["text", "content", "content_text"]:
4457
+ if key in content:
4458
+ val = content[key]
4459
+ if isinstance(val, str):
4460
+ return val
4461
+ return json.dumps(content)
4462
+ if isinstance(content, list):
4463
+ # Multimodal content - concatenate text parts
4464
+ parts = []
4465
+ for item in content:
4466
+ if isinstance(item, dict) and item.get("type") == "text":
4467
+ parts.append(item.get("text", ""))
4468
+ return " ".join(parts) if parts else str(content)
4469
+ return str(content)
4470
+
4471
+ user_text = extract_text(user_content)
4472
+
4473
+ # For assistant, we might not have it recorded, so use tool calls as completion
4474
+ assistant_text = ""
4475
+ if assistant_msg:
4476
+ assistant_content_raw = assistant_msg.get("content")
4477
+ try:
4478
+ assistant_content = json.loads(assistant_content_raw) if isinstance(assistant_content_raw, str) else assistant_content_raw
4479
+ except Exception:
4480
+ assistant_content = assistant_content_raw
4481
+ assistant_text = extract_text(assistant_content)
4482
+
4483
+ if not user_text:
4484
+ continue
4485
+
4486
+ record = {
4487
+ "messages": [
4488
+ {"role": "user", "content": user_text},
4489
+ {"role": "assistant", "content": assistant_text if assistant_text else "[no response recorded]"},
4490
+ ],
4491
+ "metadata": {
4492
+ "session_id": session_id,
4493
+ "env_name": metadata.get("env_name"),
4494
+ "policy_name": metadata.get("policy_name"),
4495
+ "seed": metadata.get("seed"),
4496
+ "total_reward": total_reward,
4497
+ "achievements_count": achievements_count,
4498
+ "model": metadata.get("model"),
4499
+ "created_at": created_at_dt.isoformat() if created_at_dt else created_at_raw,
4500
+ },
4501
+ }
4502
+ accepted.append(record)
4503
+
4504
+ if not accepted:
4505
+ raise click.ClickException("No sessions matched the provided filters")
4506
+
4507
+ if limit is not None and limit > 0:
4508
+ accepted = accepted[:limit]
4509
+
4510
+ output_path.parent.mkdir(parents=True, exist_ok=True)
4511
+ with output_path.open("w", encoding="utf-8") as handle:
4512
+ for item in accepted:
4513
+ handle.write(json.dumps(item, ensure_ascii=False))
4514
+ handle.write("\n")
4515
+
4516
+ click.echo(f"Wrote {len(accepted)} examples -> {output_path}")
4517
+ await tracer.db.close()
4518
+
4519
+ asyncio.run(_run_filter())
4520
+
2789
4521
 
2790
4522
  def register_eval(cli: click.Group) -> None:
2791
4523
  cli.add_command(eval_command)