synth-ai 0.2.14__py3-none-any.whl → 0.2.17__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of synth-ai might be problematic. Click here for more details.
- examples/README.md +1 -0
- examples/analyze_semantic_words.sh +2 -2
- examples/blog_posts/pokemon_vl/README.md +98 -0
- examples/blog_posts/pokemon_vl/configs/eval_qwen3_vl.toml +25 -0
- examples/blog_posts/pokemon_vl/configs/eval_rl_final.toml +24 -0
- examples/blog_posts/pokemon_vl/configs/filter_high_reward.toml +10 -0
- examples/blog_posts/pokemon_vl/configs/train_rl_from_sft.toml +42 -0
- examples/blog_posts/pokemon_vl/configs/train_sft_qwen4b_vl.toml +40 -0
- examples/blog_posts/warming_up_to_rl/README.md +158 -0
- examples/blog_posts/warming_up_to_rl/configs/eval_ft_qwen4b.toml +25 -0
- examples/blog_posts/warming_up_to_rl/configs/eval_groq_qwen32b.toml +25 -0
- examples/blog_posts/warming_up_to_rl/configs/eval_openai_gpt_oss_120b.toml +29 -0
- examples/blog_posts/warming_up_to_rl/configs/filter_high_reward_dataset.toml +10 -0
- examples/blog_posts/warming_up_to_rl/configs/train_rl_from_sft.toml +41 -0
- examples/blog_posts/warming_up_to_rl/configs/train_sft_qwen4b.toml +40 -0
- examples/dev/qwen3_32b_qlora_4xh100.toml +5 -0
- examples/multi_step/SFT_README.md +147 -0
- examples/multi_step/configs/crafter_rl_outcome.toml +1 -1
- examples/multi_step/configs/crafter_rl_stepwise_hosted_judge.toml +73 -115
- examples/multi_step/configs/crafter_rl_stepwise_shaped.toml +1 -1
- examples/multi_step/configs/crafter_rl_stepwise_simple.toml +1 -1
- examples/multi_step/configs/crafter_rl_stepwise_simple_NEW_FORMAT.toml +105 -0
- examples/multi_step/configs/crafter_sft_qwen30b_lora.toml +62 -0
- examples/multi_step/configs/verilog_rl_lora.toml +80 -123
- examples/multi_step/convert_traces_to_sft.py +84 -0
- examples/multi_step/run_sft_qwen30b.sh +45 -0
- examples/qwen_coder/configs/coder_lora_30b.toml +1 -2
- examples/qwen_coder/configs/coder_lora_4b.toml +5 -1
- examples/qwen_coder/configs/coder_lora_small.toml +1 -2
- examples/qwen_vl/BUGS_AND_FIXES.md +232 -0
- examples/qwen_vl/IMAGE_VALIDATION_COMPLETE.md +271 -0
- examples/qwen_vl/IMAGE_VALIDATION_SUMMARY.md +260 -0
- examples/qwen_vl/INFERENCE_SFT_TESTS.md +412 -0
- examples/qwen_vl/NEXT_STEPS_2B.md +325 -0
- examples/qwen_vl/QUICKSTART.md +327 -0
- examples/qwen_vl/QUICKSTART_RL_VISION.md +110 -0
- examples/qwen_vl/README.md +152 -0
- examples/qwen_vl/RL_VISION_COMPLETE.md +475 -0
- examples/qwen_vl/RL_VISION_TESTING.md +333 -0
- examples/qwen_vl/SDK_VISION_INTEGRATION.md +328 -0
- examples/qwen_vl/SETUP_COMPLETE.md +274 -0
- examples/qwen_vl/VISION_TESTS_COMPLETE.md +489 -0
- examples/qwen_vl/VLM_PIPELINE_COMPLETE.md +242 -0
- examples/qwen_vl/__init__.py +2 -0
- examples/qwen_vl/collect_data_via_cli.md +415 -0
- examples/qwen_vl/collect_vision_traces.py +368 -0
- examples/qwen_vl/configs/crafter_rl_vision_qwen3vl4b.toml +110 -0
- examples/qwen_vl/configs/crafter_vlm_sft_example.toml +59 -0
- examples/qwen_vl/configs/eval_gpt4o_mini_vision.toml +26 -0
- examples/qwen_vl/configs/eval_gpt4o_vision_proper.toml +29 -0
- examples/qwen_vl/configs/eval_gpt5nano_vision.toml +26 -0
- examples/qwen_vl/configs/eval_qwen3vl_vision.toml +26 -0
- examples/qwen_vl/configs/filter_qwen3vl_sft.toml +49 -0
- examples/qwen_vl/configs/filter_vision_sft.toml +52 -0
- examples/qwen_vl/configs/filter_vision_test.toml +8 -0
- examples/qwen_vl/configs/sft_qwen3_vl_2b_test.toml +54 -0
- examples/qwen_vl/crafter_gpt5nano_agent.py +308 -0
- examples/qwen_vl/crafter_qwen_vl_agent.py +300 -0
- examples/qwen_vl/run_vision_comparison.sh +61 -0
- examples/qwen_vl/run_vision_sft_pipeline.sh +175 -0
- examples/qwen_vl/test_image_validation.py +201 -0
- examples/qwen_vl/test_sft_vision_data.py +110 -0
- examples/rl/README.md +6 -6
- examples/rl/configs/eval_base_qwen.toml +17 -0
- examples/rl/configs/eval_rl_qwen.toml +13 -0
- examples/rl/configs/rl_from_base_qwen.toml +62 -0
- examples/rl/configs/rl_from_base_qwen17.toml +79 -0
- examples/rl/configs/rl_from_ft_qwen.toml +37 -0
- examples/rl/run_eval.py +436 -0
- examples/rl/run_rl_and_save.py +111 -0
- examples/rl/task_app/README.md +21 -0
- examples/rl/task_app/math_single_step.py +990 -0
- examples/rl/task_app/math_task_app.py +111 -0
- examples/run_crafter_demo.sh +2 -2
- examples/sft/README.md +6 -6
- examples/sft/configs/crafter_fft_qwen0p6b.toml +7 -2
- examples/sft/configs/crafter_lora_qwen0p6b.toml +7 -3
- examples/sft/evaluate.py +2 -4
- examples/sft/export_dataset.py +7 -4
- examples/swe/task_app/README.md +33 -3
- examples/swe/task_app/grpo_swe_mini.py +4 -1
- examples/swe/task_app/grpo_swe_mini_task_app.py +0 -12
- examples/swe/task_app/hosted/envs/crafter/react_agent.py +1 -1
- examples/swe/task_app/hosted/envs/mini_swe/environment.py +50 -23
- examples/swe/task_app/hosted/inference/openai_client.py +4 -4
- examples/swe/task_app/hosted/policy_routes.py +0 -2
- examples/swe/task_app/hosted/rollout.py +0 -8
- examples/swe/task_app/morph_backend.py +178 -0
- examples/task_apps/crafter/task_app/README.md +1 -1
- examples/task_apps/crafter/task_app/grpo_crafter.py +70 -10
- examples/task_apps/crafter/task_app/grpo_crafter_task_app.py +1 -1
- examples/task_apps/crafter/task_app/synth_envs_hosted/envs/crafter/policy.py +63 -27
- examples/task_apps/crafter/task_app/synth_envs_hosted/envs/crafter/react_agent.py +1 -2
- examples/task_apps/crafter/task_app/synth_envs_hosted/inference/openai_client.py +48 -50
- examples/task_apps/crafter/task_app/synth_envs_hosted/policy_routes.py +75 -36
- examples/task_apps/crafter/task_app/synth_envs_hosted/rollout.py +31 -15
- examples/task_apps/enron/__init__.py +1 -0
- examples/task_apps/enron/task_app/grpo_enron_task_app.py +1 -1
- examples/task_apps/math/README.md +1 -2
- examples/task_apps/pokemon_red/README.md +3 -4
- examples/task_apps/pokemon_red/eval_image_only_gpt4o.toml +6 -5
- examples/task_apps/pokemon_red/eval_pokemon_red_policy.py +1 -2
- examples/task_apps/pokemon_red/task_app.py +36 -5
- examples/task_apps/sokoban/README.md +2 -3
- examples/task_apps/verilog/eval_groq_qwen32b.toml +12 -14
- examples/task_apps/verilog/task_app/grpo_verilog_task_app.py +1 -1
- examples/vlm/README.md +3 -3
- examples/vlm/configs/crafter_vlm_gpt4o.toml +5 -0
- examples/vlm/crafter_openai_vlm_agent.py +3 -5
- examples/vlm/filter_image_rows.py +1 -1
- examples/vlm/run_crafter_vlm_benchmark.py +2 -2
- examples/warming_up_to_rl/_utils.py +92 -0
- examples/warming_up_to_rl/analyze_trace_db.py +1 -1
- examples/warming_up_to_rl/configs/crafter_fft.toml +5 -0
- examples/warming_up_to_rl/configs/eval_fft_qwen4b.toml +2 -0
- examples/warming_up_to_rl/configs/eval_groq_qwen32b.toml +2 -0
- examples/warming_up_to_rl/configs/eval_modal_qwen4b.toml +2 -1
- examples/warming_up_to_rl/configs/rl_from_base_qwen4b.toml +2 -1
- examples/warming_up_to_rl/configs/rl_from_ft.toml +2 -0
- examples/warming_up_to_rl/export_trace_sft.py +174 -60
- examples/warming_up_to_rl/readme.md +63 -132
- examples/warming_up_to_rl/run_fft_and_save.py +1 -1
- examples/warming_up_to_rl/run_local_rollout_traced.py +1 -1
- examples/warming_up_to_rl/run_rl_and_save.py +1 -1
- examples/warming_up_to_rl/task_app/README.md +42 -0
- examples/warming_up_to_rl/task_app/grpo_crafter.py +827 -0
- examples/warming_up_to_rl/task_app/grpo_crafter_task_app.py +135 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/README.md +173 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/__init__.py +5 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/branching.py +143 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/environment_routes.py +1226 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/__init__.py +1 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/__init__.py +6 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/app.py +1 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/environment.py +522 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/policy.py +454 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/react_agent.py +108 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/shared.py +305 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/tools.py +47 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/hosted_app.py +204 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/__init__.py +5 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/openai_client.py +618 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/main.py +100 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/policy_routes.py +1084 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/registry.py +195 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/rollout.py +1861 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/storage/__init__.py +5 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/storage/volume.py +211 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/test_agents.py +161 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/test_service.py +137 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/utils.py +62 -0
- examples/workflows/math_rl/configs/rl_from_base_qwen.toml +27 -0
- examples/workflows/math_rl/configs/rl_from_base_qwen17.toml +5 -0
- synth_ai/__init__.py +44 -30
- synth_ai/_utils/__init__.py +47 -0
- synth_ai/_utils/base_url.py +10 -0
- synth_ai/_utils/http.py +10 -0
- synth_ai/_utils/prompts.py +10 -0
- synth_ai/_utils/task_app_state.py +12 -0
- synth_ai/_utils/user_config.py +10 -0
- synth_ai/api/models/supported.py +144 -7
- synth_ai/api/train/__init__.py +13 -1
- synth_ai/api/train/builders.py +9 -3
- synth_ai/api/train/cli.py +155 -17
- synth_ai/api/train/config_finder.py +18 -11
- synth_ai/api/train/configs/__init__.py +8 -1
- synth_ai/api/train/configs/rl.py +32 -7
- synth_ai/api/train/configs/sft.py +6 -2
- synth_ai/api/train/configs/shared.py +59 -2
- synth_ai/api/train/env_resolver.py +13 -10
- synth_ai/auth/credentials.py +119 -0
- synth_ai/cli/__init__.py +61 -69
- synth_ai/cli/_modal_wrapper.py +7 -5
- synth_ai/cli/_typer_patch.py +0 -2
- synth_ai/cli/_validate_task_app.py +22 -4
- synth_ai/cli/commands/__init__.py +17 -0
- synth_ai/cli/commands/demo/__init__.py +6 -0
- synth_ai/cli/commands/demo/core.py +163 -0
- synth_ai/cli/commands/deploy/__init__.py +23 -0
- synth_ai/cli/commands/deploy/core.py +614 -0
- synth_ai/cli/commands/deploy/errors.py +72 -0
- synth_ai/cli/commands/deploy/validation.py +11 -0
- synth_ai/cli/commands/eval/__init__.py +19 -0
- synth_ai/cli/commands/eval/core.py +1109 -0
- synth_ai/cli/commands/eval/errors.py +81 -0
- synth_ai/cli/commands/eval/validation.py +133 -0
- synth_ai/cli/commands/filter/__init__.py +12 -0
- synth_ai/cli/commands/filter/core.py +388 -0
- synth_ai/cli/commands/filter/errors.py +55 -0
- synth_ai/cli/commands/filter/validation.py +77 -0
- synth_ai/cli/commands/help/__init__.py +177 -0
- synth_ai/cli/commands/help/core.py +73 -0
- synth_ai/cli/commands/status/__init__.py +64 -0
- synth_ai/cli/commands/status/client.py +192 -0
- synth_ai/cli/commands/status/config.py +92 -0
- synth_ai/cli/commands/status/errors.py +20 -0
- synth_ai/cli/commands/status/formatters.py +164 -0
- synth_ai/cli/commands/status/subcommands/__init__.py +9 -0
- synth_ai/cli/commands/status/subcommands/files.py +79 -0
- synth_ai/cli/commands/status/subcommands/jobs.py +334 -0
- synth_ai/cli/commands/status/subcommands/models.py +79 -0
- synth_ai/cli/commands/status/subcommands/runs.py +81 -0
- synth_ai/cli/commands/status/subcommands/summary.py +47 -0
- synth_ai/cli/commands/status/utils.py +114 -0
- synth_ai/cli/commands/train/__init__.py +53 -0
- synth_ai/cli/commands/train/core.py +21 -0
- synth_ai/cli/commands/train/errors.py +117 -0
- synth_ai/cli/commands/train/judge_schemas.py +199 -0
- synth_ai/cli/commands/train/judge_validation.py +304 -0
- synth_ai/cli/commands/train/validation.py +443 -0
- synth_ai/cli/demo.py +2 -162
- synth_ai/cli/deploy/__init__.py +28 -0
- synth_ai/cli/deploy/core.py +5 -0
- synth_ai/cli/deploy/errors.py +23 -0
- synth_ai/cli/deploy/validation.py +5 -0
- synth_ai/cli/eval/__init__.py +36 -0
- synth_ai/cli/eval/core.py +5 -0
- synth_ai/cli/eval/errors.py +31 -0
- synth_ai/cli/eval/validation.py +5 -0
- synth_ai/cli/filter/__init__.py +28 -0
- synth_ai/cli/filter/core.py +5 -0
- synth_ai/cli/filter/errors.py +23 -0
- synth_ai/cli/filter/validation.py +5 -0
- synth_ai/cli/legacy_root_backup.py +3 -1
- synth_ai/cli/lib/__init__.py +10 -0
- synth_ai/cli/lib/task_app_discovery.py +7 -0
- synth_ai/cli/lib/task_app_env.py +518 -0
- synth_ai/cli/modal_serve/__init__.py +12 -0
- synth_ai/cli/modal_serve/core.py +14 -0
- synth_ai/cli/modal_serve/errors.py +8 -0
- synth_ai/cli/modal_serve/validation.py +11 -0
- synth_ai/cli/recent.py +2 -1
- synth_ai/cli/serve/__init__.py +12 -0
- synth_ai/cli/serve/core.py +14 -0
- synth_ai/cli/serve/errors.py +8 -0
- synth_ai/cli/serve/validation.py +11 -0
- synth_ai/cli/setup.py +21 -0
- synth_ai/cli/status.py +7 -126
- synth_ai/cli/task_app_deploy.py +7 -0
- synth_ai/cli/task_app_list.py +25 -0
- synth_ai/cli/task_app_modal_serve.py +11 -0
- synth_ai/cli/task_app_serve.py +11 -0
- synth_ai/cli/task_apps.py +110 -1499
- synth_ai/cli/traces.py +1 -1
- synth_ai/cli/train/__init__.py +12 -0
- synth_ai/cli/train/core.py +21 -0
- synth_ai/cli/train/errors.py +8 -0
- synth_ai/cli/train/validation.py +24 -0
- synth_ai/cli/train.py +5 -0
- synth_ai/cli/turso.py +1 -1
- synth_ai/cli/watch.py +1 -1
- synth_ai/demos/__init__.py +10 -0
- synth_ai/demos/core/__init__.py +28 -1
- synth_ai/demos/crafter/__init__.py +1 -0
- synth_ai/demos/crafter/crafter_fft_4b.toml +55 -0
- synth_ai/demos/crafter/grpo_crafter_task_app.py +185 -0
- synth_ai/demos/crafter/rl_from_base_qwen4b.toml +74 -0
- synth_ai/demos/demo_registry.py +176 -0
- synth_ai/demos/demo_task_apps/crafter/grpo_crafter_task_app.py +1 -1
- synth_ai/demos/math/__init__.py +1 -0
- synth_ai/demos/math/_common.py +16 -0
- synth_ai/demos/math/app.py +38 -0
- synth_ai/demos/math/config.toml +76 -0
- synth_ai/demos/math/deploy_modal.py +54 -0
- synth_ai/demos/math/modal_task_app.py +702 -0
- synth_ai/demos/math/task_app_entry.py +51 -0
- synth_ai/environments/environment/core.py +7 -1
- synth_ai/environments/examples/bandit/engine.py +0 -1
- synth_ai/environments/examples/bandit/environment.py +0 -1
- synth_ai/environments/examples/red/engine.py +33 -12
- synth_ai/environments/examples/red/engine_helpers/reward_components.py +151 -179
- synth_ai/environments/examples/red/environment.py +26 -0
- synth_ai/environments/examples/red/trace_hooks_v3.py +168 -0
- synth_ai/environments/examples/wordle/environment.py +0 -1
- synth_ai/evals/base.py +16 -5
- synth_ai/evals/client.py +1 -1
- synth_ai/http.py +8 -22
- synth_ai/inference/client.py +1 -1
- synth_ai/judge_schemas.py +4 -5
- synth_ai/learning/client.py +1 -1
- synth_ai/learning/health.py +1 -1
- synth_ai/learning/jobs.py +1 -1
- synth_ai/learning/rl/client.py +4 -2
- synth_ai/learning/rl/env_keys.py +1 -1
- synth_ai/learning/rl/secrets.py +1 -1
- synth_ai/learning/sft/client.py +1 -1
- synth_ai/learning/sft/data.py +407 -4
- synth_ai/learning/validators.py +4 -1
- synth_ai/streaming/__init__.py +29 -0
- synth_ai/streaming/config.py +94 -0
- synth_ai/streaming/handlers.py +469 -0
- synth_ai/streaming/streamer.py +301 -0
- synth_ai/streaming/types.py +95 -0
- synth_ai/task/apps/__init__.py +4 -2
- synth_ai/task/config.py +6 -4
- synth_ai/task/rubrics/__init__.py +1 -2
- synth_ai/task/rubrics/loaders.py +14 -10
- synth_ai/task/rubrics.py +219 -0
- synth_ai/task/trace_correlation_helpers.py +24 -11
- synth_ai/task/tracing_utils.py +14 -3
- synth_ai/task/validators.py +0 -1
- synth_ai/tracing_v3/abstractions.py +3 -3
- synth_ai/tracing_v3/config.py +15 -13
- synth_ai/tracing_v3/constants.py +21 -0
- synth_ai/tracing_v3/db_config.py +3 -1
- synth_ai/tracing_v3/decorators.py +10 -7
- synth_ai/tracing_v3/llm_call_record_helpers.py +5 -5
- synth_ai/tracing_v3/migration_helper.py +1 -2
- synth_ai/tracing_v3/session_tracer.py +7 -7
- synth_ai/tracing_v3/storage/base.py +29 -29
- synth_ai/tracing_v3/storage/config.py +3 -3
- synth_ai/tracing_v3/turso/daemon.py +8 -9
- synth_ai/tracing_v3/turso/native_manager.py +80 -72
- synth_ai/tracing_v3/utils.py +2 -2
- synth_ai/utils/__init__.py +101 -0
- synth_ai/utils/base_url.py +94 -0
- synth_ai/utils/cli.py +131 -0
- synth_ai/utils/env.py +294 -0
- synth_ai/utils/http.py +172 -0
- synth_ai/utils/modal.py +308 -0
- synth_ai/utils/process.py +212 -0
- synth_ai/utils/prompts.py +39 -0
- synth_ai/utils/sqld.py +122 -0
- synth_ai/utils/task_app_discovery.py +882 -0
- synth_ai/utils/task_app_env.py +186 -0
- synth_ai/utils/task_app_state.py +318 -0
- synth_ai/utils/user_config.py +137 -0
- synth_ai/v0/config/__init__.py +1 -5
- synth_ai/v0/config/base_url.py +1 -7
- synth_ai/v0/tracing/config.py +1 -1
- synth_ai/v0/tracing/decorators.py +1 -1
- synth_ai/v0/tracing/upload.py +1 -1
- synth_ai/v0/tracing_v1/config.py +1 -1
- synth_ai/v0/tracing_v1/decorators.py +1 -1
- synth_ai/v0/tracing_v1/upload.py +1 -1
- {synth_ai-0.2.14.dist-info → synth_ai-0.2.17.dist-info}/METADATA +91 -32
- {synth_ai-0.2.14.dist-info → synth_ai-0.2.17.dist-info}/RECORD +341 -154
- synth_ai/cli/man.py +0 -106
- synth_ai/cli/tui.py +0 -57
- synth_ai/compound/cais.py +0 -0
- synth_ai/core/experiment.py +0 -13
- synth_ai/core/system.py +0 -15
- synth_ai/demo_registry.py +0 -295
- synth_ai/handshake.py +0 -109
- synth_ai/tui/__init__.py +0 -5
- synth_ai/tui/__main__.py +0 -13
- synth_ai/tui/cli/__init__.py +0 -1
- synth_ai/tui/cli/query_experiments.py +0 -164
- synth_ai/tui/cli/query_experiments_v3.py +0 -164
- synth_ai/tui/dashboard.py +0 -906
- {synth_ai-0.2.14.dist-info → synth_ai-0.2.17.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.14.dist-info → synth_ai-0.2.17.dist-info}/entry_points.txt +0 -0
- {synth_ai-0.2.14.dist-info → synth_ai-0.2.17.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.2.14.dist-info → synth_ai-0.2.17.dist-info}/top_level.txt +0 -0
synth_ai/cli/task_apps.py
CHANGED
|
@@ -9,24 +9,21 @@ import hashlib
|
|
|
9
9
|
import importlib
|
|
10
10
|
import importlib.util
|
|
11
11
|
import inspect
|
|
12
|
-
import json
|
|
13
12
|
import os
|
|
14
13
|
import shlex
|
|
15
14
|
import shutil
|
|
16
15
|
import signal
|
|
17
|
-
import sqlite3
|
|
18
16
|
import subprocess
|
|
19
17
|
import sys
|
|
20
18
|
import tempfile
|
|
21
19
|
import textwrap
|
|
22
20
|
import time
|
|
23
21
|
import types
|
|
24
|
-
import uuid
|
|
25
22
|
from collections.abc import Callable, Iterable, Iterator, Sequence
|
|
26
23
|
from dataclasses import dataclass
|
|
27
|
-
from datetime import
|
|
24
|
+
from datetime import UTC, datetime
|
|
28
25
|
from pathlib import Path
|
|
29
|
-
from typing import Any,
|
|
26
|
+
from typing import Any, cast
|
|
30
27
|
|
|
31
28
|
try: # Python 3.11+
|
|
32
29
|
import tomllib as _toml
|
|
@@ -35,6 +32,9 @@ except Exception: # pragma: no cover - fallback
|
|
|
35
32
|
|
|
36
33
|
import click
|
|
37
34
|
from click.exceptions import Abort
|
|
35
|
+
from synth_ai.cli.commands import deploy as _deploy_commands
|
|
36
|
+
from synth_ai.cli.commands.eval import core as eval_core
|
|
37
|
+
from synth_ai.cli.commands.filter import core as filter_core
|
|
38
38
|
|
|
39
39
|
# Tracing imports - make conditional for optional dependencies
|
|
40
40
|
try:
|
|
@@ -92,14 +92,14 @@ except Exception as exc: # pragma: no cover - critical dependency
|
|
|
92
92
|
raise RuntimeError("Unable to load task app server utilities") from exc
|
|
93
93
|
|
|
94
94
|
|
|
95
|
-
def _load_demo_directory() ->
|
|
95
|
+
def _load_demo_directory() -> Path | None:
|
|
96
96
|
"""Return the demo task apps directory if available."""
|
|
97
97
|
|
|
98
98
|
try:
|
|
99
99
|
module = cast(
|
|
100
100
|
Any, importlib.import_module("synth_ai.demos.demo_task_apps.core")
|
|
101
101
|
)
|
|
102
|
-
loader = cast(Callable[[],
|
|
102
|
+
loader = cast(Callable[[], str | Path | None], module.load_demo_dir)
|
|
103
103
|
demo_dir = loader()
|
|
104
104
|
if isinstance(demo_dir, str | Path):
|
|
105
105
|
demo_path = Path(demo_dir)
|
|
@@ -139,7 +139,7 @@ DEFAULT_SEARCH_RELATIVE = (
|
|
|
139
139
|
)
|
|
140
140
|
|
|
141
141
|
|
|
142
|
-
def _pearson(xs: Sequence[float], ys: Sequence[float]) ->
|
|
142
|
+
def _pearson(xs: Sequence[float], ys: Sequence[float]) -> float | None:
|
|
143
143
|
if len(xs) != len(ys) or len(xs) < 2:
|
|
144
144
|
return None
|
|
145
145
|
mean_x = sum(xs) / len(xs)
|
|
@@ -164,7 +164,7 @@ class AppChoice:
|
|
|
164
164
|
label: str
|
|
165
165
|
path: Path
|
|
166
166
|
source: str
|
|
167
|
-
description:
|
|
167
|
+
description: str | None = None
|
|
168
168
|
aliases: tuple[str, ...] = ()
|
|
169
169
|
entry: TaskAppEntryType | None = None
|
|
170
170
|
entry_loader: Callable[[], TaskAppEntryType] | None = None
|
|
@@ -188,21 +188,21 @@ class JudgeSpec:
|
|
|
188
188
|
kwargs: dict[str, Any]
|
|
189
189
|
|
|
190
190
|
|
|
191
|
-
def _parse_datetime_for_trace(value: Any) ->
|
|
191
|
+
def _parse_datetime_for_trace(value: Any) -> datetime | None:
|
|
192
192
|
if isinstance(value, datetime):
|
|
193
|
-
return value if value.tzinfo else value.replace(tzinfo=
|
|
193
|
+
return value if value.tzinfo else value.replace(tzinfo=UTC)
|
|
194
194
|
if isinstance(value, str):
|
|
195
195
|
value = value.replace("Z", "+00:00")
|
|
196
196
|
try:
|
|
197
197
|
dt = datetime.fromisoformat(value)
|
|
198
198
|
except ValueError:
|
|
199
199
|
try:
|
|
200
|
-
dt = datetime.fromtimestamp(float(value), tz=
|
|
200
|
+
dt = datetime.fromtimestamp(float(value), tz=UTC)
|
|
201
201
|
except Exception:
|
|
202
202
|
return None
|
|
203
|
-
return dt if dt.tzinfo else dt.replace(tzinfo=
|
|
203
|
+
return dt if dt.tzinfo else dt.replace(tzinfo=UTC)
|
|
204
204
|
if isinstance(value, int | float):
|
|
205
|
-
return datetime.fromtimestamp(float(value), tz=
|
|
205
|
+
return datetime.fromtimestamp(float(value), tz=UTC)
|
|
206
206
|
return None
|
|
207
207
|
|
|
208
208
|
|
|
@@ -241,6 +241,24 @@ def _event_from_dict(payload: dict[str, Any]) -> BaseEvent:
|
|
|
241
241
|
system_state_after=payload.get("system_state_after"),
|
|
242
242
|
**base_kwargs,
|
|
243
243
|
)
|
|
244
|
+
# Check for LM CAIS event fields
|
|
245
|
+
if any(key in payload for key in ("model_name", "provider", "call_records")):
|
|
246
|
+
from synth_ai.tracing_v3.abstractions import LMCAISEvent
|
|
247
|
+
# Note: call_records are left as dicts - the storage layer will handle serialization
|
|
248
|
+
call_records = payload.get("call_records") or []
|
|
249
|
+
return LMCAISEvent(
|
|
250
|
+
model_name=payload.get("model_name", ""),
|
|
251
|
+
provider=payload.get("provider", ""),
|
|
252
|
+
input_tokens=payload.get("input_tokens"),
|
|
253
|
+
output_tokens=payload.get("output_tokens"),
|
|
254
|
+
total_tokens=payload.get("total_tokens"),
|
|
255
|
+
cost_usd=payload.get("cost_usd"),
|
|
256
|
+
latency_ms=payload.get("latency_ms"),
|
|
257
|
+
span_id=payload.get("span_id"),
|
|
258
|
+
trace_id=payload.get("trace_id"),
|
|
259
|
+
call_records=call_records,
|
|
260
|
+
**base_kwargs,
|
|
261
|
+
)
|
|
244
262
|
return BaseEvent(**base_kwargs)
|
|
245
263
|
|
|
246
264
|
|
|
@@ -279,7 +297,7 @@ def _step_from_dict(payload: dict[str, Any]) -> SessionTimeStep:
|
|
|
279
297
|
for msg in payload.get("markov_blanket_messages", [])
|
|
280
298
|
if isinstance(msg, dict)
|
|
281
299
|
]
|
|
282
|
-
timestamp = _parse_datetime_for_trace(payload.get("timestamp")) or datetime.now(
|
|
300
|
+
timestamp = _parse_datetime_for_trace(payload.get("timestamp")) or datetime.now(UTC)
|
|
283
301
|
completed_at = _parse_datetime_for_trace(payload.get("completed_at"))
|
|
284
302
|
return SessionTimeStep(
|
|
285
303
|
step_id=payload.get("step_id", ""),
|
|
@@ -293,7 +311,7 @@ def _step_from_dict(payload: dict[str, Any]) -> SessionTimeStep:
|
|
|
293
311
|
)
|
|
294
312
|
|
|
295
313
|
|
|
296
|
-
def _session_trace_from_dict(payload: dict[str, Any]) ->
|
|
314
|
+
def _session_trace_from_dict(payload: dict[str, Any]) -> V3SessionTrace | None:
|
|
297
315
|
if not isinstance(payload, dict):
|
|
298
316
|
return None
|
|
299
317
|
steps = [
|
|
@@ -311,7 +329,7 @@ def _session_trace_from_dict(payload: dict[str, Any]) -> Optional[V3SessionTrace
|
|
|
311
329
|
for msg in payload.get("markov_blanket_message_history", [])
|
|
312
330
|
if isinstance(msg, dict)
|
|
313
331
|
]
|
|
314
|
-
created_at = _parse_datetime_for_trace(payload.get("created_at")) or datetime.now(
|
|
332
|
+
created_at = _parse_datetime_for_trace(payload.get("created_at")) or datetime.now(UTC)
|
|
315
333
|
metadata = payload.get("metadata") or {}
|
|
316
334
|
session_metadata = payload.get("session_metadata")
|
|
317
335
|
return V3SessionTrace(
|
|
@@ -336,15 +354,32 @@ async def _store_trace(
|
|
|
336
354
|
_logger.info(f"[STORE_TRACE_DEBUG] Called with tracer={tracer is not None}, trace_namespace={trace_namespace is not None}")
|
|
337
355
|
|
|
338
356
|
if tracer is None or not isinstance(trace_namespace, dict):
|
|
339
|
-
|
|
340
|
-
|
|
357
|
+
message = (
|
|
358
|
+
f"Trace storage requires a tracer instance and dict payload. "
|
|
359
|
+
f"Got tracer_present={tracer is not None}, payload_type={type(trace_namespace)}"
|
|
360
|
+
)
|
|
361
|
+
_logger.error("[STORE_TRACE_DEBUG] %s", message)
|
|
362
|
+
raise ValueError(message)
|
|
341
363
|
|
|
342
364
|
_logger.info(f"[STORE_TRACE_DEBUG] trace_namespace keys: {list(trace_namespace.keys())}")
|
|
343
365
|
|
|
366
|
+
# Handle both formats:
|
|
367
|
+
# - With session_trace key: {"session_trace": {...}}
|
|
368
|
+
# - Without session_trace key (trace itself is the session): {"session_id": ..., "markov_blanket_message_history": ...}
|
|
344
369
|
session_payload = trace_namespace.get("session_trace")
|
|
345
370
|
if not isinstance(session_payload, dict):
|
|
346
|
-
|
|
347
|
-
|
|
371
|
+
# If no session_trace key, assume "full" format where trace itself is the session_trace
|
|
372
|
+
if "session_id" in trace_namespace:
|
|
373
|
+
session_payload = trace_namespace
|
|
374
|
+
_logger.info("[STORE_TRACE_DEBUG] Using trace_namespace directly as session_payload (no session_trace key)")
|
|
375
|
+
else:
|
|
376
|
+
message = (
|
|
377
|
+
"Trace payload did not contain a 'session_trace' dict and lacked top-level "
|
|
378
|
+
"session fields (session_id, markov_blanket_message_history). "
|
|
379
|
+
f"Payload keys: {list(trace_namespace.keys())}"
|
|
380
|
+
)
|
|
381
|
+
_logger.error("[STORE_TRACE_DEBUG] %s", message)
|
|
382
|
+
raise ValueError(message)
|
|
348
383
|
|
|
349
384
|
_logger.info(f"[STORE_TRACE_DEBUG] session_payload keys: {list(session_payload.keys())}")
|
|
350
385
|
msg_count = len(session_payload.get("markov_blanket_message_history", []))
|
|
@@ -352,8 +387,26 @@ async def _store_trace(
|
|
|
352
387
|
|
|
353
388
|
trace_obj = _session_trace_from_dict(session_payload)
|
|
354
389
|
if trace_obj is None:
|
|
355
|
-
|
|
356
|
-
|
|
390
|
+
message = "Session trace payload could not be parsed into a SessionTrace object."
|
|
391
|
+
_logger.error("[STORE_TRACE_DEBUG] %s", message)
|
|
392
|
+
raise ValueError(message)
|
|
393
|
+
|
|
394
|
+
if not trace_obj.markov_blanket_message_history:
|
|
395
|
+
message = (
|
|
396
|
+
"Session trace is missing markov_blanket_message_history; "
|
|
397
|
+
"eval output must include all prompts/tool calls. "
|
|
398
|
+
f"session_id={trace_obj.session_id}"
|
|
399
|
+
)
|
|
400
|
+
_logger.error("[STORE_TRACE_DEBUG] %s", message)
|
|
401
|
+
raise ValueError(message)
|
|
402
|
+
|
|
403
|
+
if not trace_obj.event_history:
|
|
404
|
+
message = (
|
|
405
|
+
"Session trace is missing event_history; rollout should emit environment/LLM events. "
|
|
406
|
+
f"session_id={trace_obj.session_id}"
|
|
407
|
+
)
|
|
408
|
+
_logger.error("[STORE_TRACE_DEBUG] %s", message)
|
|
409
|
+
raise ValueError(message)
|
|
357
410
|
|
|
358
411
|
_logger.info(f"[STORE_TRACE_DEBUG] Created SessionTrace object with {len(trace_obj.markov_blanket_message_history)} messages")
|
|
359
412
|
|
|
@@ -366,7 +419,7 @@ async def _store_trace(
|
|
|
366
419
|
|
|
367
420
|
_logger.info(f"[STORE_TRACE_DEBUG] Calling insert_session_trace for session_id={trace_obj.session_id}")
|
|
368
421
|
await tracer.db.insert_session_trace(trace_obj)
|
|
369
|
-
_logger.info(
|
|
422
|
+
_logger.info("[STORE_TRACE_DEBUG] Successfully inserted trace")
|
|
370
423
|
|
|
371
424
|
def _temporary_sys_path(paths: Sequence[Path]):
|
|
372
425
|
"""Context manager to prepend entries to sys.path temporarily."""
|
|
@@ -480,49 +533,6 @@ def _candidate_search_roots() -> list[Path]:
|
|
|
480
533
|
return ordered
|
|
481
534
|
|
|
482
535
|
|
|
483
|
-
def _eval_config_sort_key(path: Path) -> tuple[int, int, int, str]:
|
|
484
|
-
name = path.name.lower()
|
|
485
|
-
parent_names = {p.name.lower() for p in path.parents}
|
|
486
|
-
in_configs = 0 if "configs" in parent_names else 1
|
|
487
|
-
in_examples = 0 if "examples" in parent_names else 1
|
|
488
|
-
starts_eval = 0 if name.startswith("eval") else 1
|
|
489
|
-
return (in_configs, in_examples, starts_eval, str(path))
|
|
490
|
-
|
|
491
|
-
|
|
492
|
-
def _discover_eval_config_paths() -> list[Path]:
|
|
493
|
-
"""Find candidate eval TOML files near the current working directory."""
|
|
494
|
-
|
|
495
|
-
candidates: list[Path] = []
|
|
496
|
-
seen: set[Path] = set()
|
|
497
|
-
search_roots = _candidate_search_roots()
|
|
498
|
-
for root in search_roots:
|
|
499
|
-
if not root.exists() or not root.is_dir():
|
|
500
|
-
continue
|
|
501
|
-
try:
|
|
502
|
-
root = root.resolve()
|
|
503
|
-
except Exception:
|
|
504
|
-
continue
|
|
505
|
-
for path in root.rglob("*.toml"):
|
|
506
|
-
if not path.is_file():
|
|
507
|
-
continue
|
|
508
|
-
if _should_ignore_path(path):
|
|
509
|
-
continue
|
|
510
|
-
name_lower = path.name.lower()
|
|
511
|
-
if "eval" not in name_lower and "evaluation" not in name_lower:
|
|
512
|
-
continue
|
|
513
|
-
try:
|
|
514
|
-
resolved = path.resolve()
|
|
515
|
-
except Exception:
|
|
516
|
-
continue
|
|
517
|
-
if resolved in seen:
|
|
518
|
-
continue
|
|
519
|
-
seen.add(resolved)
|
|
520
|
-
candidates.append(resolved)
|
|
521
|
-
|
|
522
|
-
candidates.sort(key=_eval_config_sort_key)
|
|
523
|
-
return candidates
|
|
524
|
-
|
|
525
|
-
|
|
526
536
|
class _TaskAppConfigVisitor(ast.NodeVisitor):
|
|
527
537
|
def __init__(self) -> None:
|
|
528
538
|
self.matches: list[tuple[str, int]] = []
|
|
@@ -913,43 +923,43 @@ def _build_modal_config_from_ast(modal_call: ast.Call) -> ModalDeploymentConfigT
|
|
|
913
923
|
for kw in modal_call.keywords:
|
|
914
924
|
if kw.arg and isinstance(kw.value, ast.Constant):
|
|
915
925
|
kwargs[kw.arg] = kw.value.value
|
|
916
|
-
elif kw.arg == "pip_packages" and isinstance(kw.value,
|
|
926
|
+
elif kw.arg == "pip_packages" and isinstance(kw.value, ast.List | ast.Tuple):
|
|
917
927
|
# Handle pip_packages list/tuple
|
|
918
928
|
packages: list[str] = []
|
|
919
929
|
value_node = kw.value
|
|
920
|
-
if isinstance(value_node,
|
|
930
|
+
if isinstance(value_node, ast.List | ast.Tuple):
|
|
921
931
|
for elt in value_node.elts:
|
|
922
932
|
if isinstance(elt, ast.Constant):
|
|
923
933
|
packages.append(elt.value)
|
|
924
934
|
kwargs[kw.arg] = tuple(packages)
|
|
925
|
-
elif kw.arg == "extra_local_dirs" and isinstance(kw.value,
|
|
935
|
+
elif kw.arg == "extra_local_dirs" and isinstance(kw.value, ast.List | ast.Tuple):
|
|
926
936
|
# Handle extra_local_dirs list/tuple of tuples
|
|
927
937
|
dirs = []
|
|
928
938
|
value_node = kw.value
|
|
929
|
-
if isinstance(value_node,
|
|
939
|
+
if isinstance(value_node, ast.List | ast.Tuple):
|
|
930
940
|
for elt in value_node.elts:
|
|
931
|
-
if isinstance(elt,
|
|
941
|
+
if isinstance(elt, ast.List | ast.Tuple) and len(elt.elts) == 2:
|
|
932
942
|
src = elt.elts[0].value if isinstance(elt.elts[0], ast.Constant) else None
|
|
933
943
|
dst = elt.elts[1].value if isinstance(elt.elts[1], ast.Constant) else None
|
|
934
944
|
if src and dst:
|
|
935
945
|
dirs.append((src, dst))
|
|
936
946
|
kwargs[kw.arg] = tuple(dirs)
|
|
937
|
-
elif kw.arg == "secret_names" and isinstance(kw.value,
|
|
947
|
+
elif kw.arg == "secret_names" and isinstance(kw.value, ast.List | ast.Tuple):
|
|
938
948
|
# Handle secret_names list/tuple
|
|
939
949
|
secrets = []
|
|
940
950
|
value_node = kw.value
|
|
941
|
-
if isinstance(value_node,
|
|
951
|
+
if isinstance(value_node, ast.List | ast.Tuple):
|
|
942
952
|
for elt in value_node.elts:
|
|
943
953
|
if isinstance(elt, ast.Constant):
|
|
944
954
|
secrets.append(elt.value)
|
|
945
955
|
kwargs[kw.arg] = tuple(secrets)
|
|
946
|
-
elif kw.arg == "volume_mounts" and isinstance(kw.value,
|
|
956
|
+
elif kw.arg == "volume_mounts" and isinstance(kw.value, ast.List | ast.Tuple):
|
|
947
957
|
# Handle volume_mounts list/tuple of tuples
|
|
948
958
|
mounts = []
|
|
949
959
|
value_node = kw.value
|
|
950
|
-
if isinstance(value_node,
|
|
960
|
+
if isinstance(value_node, ast.List | ast.Tuple):
|
|
951
961
|
for elt in value_node.elts:
|
|
952
|
-
if isinstance(elt,
|
|
962
|
+
if isinstance(elt, ast.List | ast.Tuple) and len(elt.elts) == 2:
|
|
953
963
|
name = elt.elts[0].value if isinstance(elt.elts[0], ast.Constant) else None
|
|
954
964
|
mount = elt.elts[1].value if isinstance(elt.elts[1], ast.Constant) else None
|
|
955
965
|
if name and mount:
|
|
@@ -2238,14 +2248,13 @@ def validate_task_app_cmd(
|
|
|
2238
2248
|
• Debug failing deployments: Use --verbose to see detailed endpoint responses
|
|
2239
2249
|
• Test API key configuration: Verify authentication is set up correctly
|
|
2240
2250
|
"""
|
|
2241
|
-
import asyncio
|
|
2242
2251
|
import socket
|
|
2243
2252
|
import subprocess
|
|
2244
2253
|
import tempfile
|
|
2245
2254
|
import time
|
|
2246
2255
|
|
|
2247
2256
|
# Import the validate_task_app function defined in this module
|
|
2248
|
-
from
|
|
2257
|
+
from ._validate_task_app import validate_task_app # type: ignore[attr-defined]
|
|
2249
2258
|
|
|
2250
2259
|
proc = None
|
|
2251
2260
|
task_app_url = url
|
|
@@ -2445,48 +2454,15 @@ def serve_command(
|
|
|
2445
2454
|
trace_dir: str | None,
|
|
2446
2455
|
trace_db: str | None,
|
|
2447
2456
|
) -> None:
|
|
2448
|
-
|
|
2449
|
-
|
|
2450
|
-
|
|
2451
|
-
|
|
2452
|
-
|
|
2453
|
-
|
|
2454
|
-
|
|
2455
|
-
|
|
2456
|
-
|
|
2457
|
-
|
|
2458
|
-
# Prompt for port if not provided
|
|
2459
|
-
if port is None:
|
|
2460
|
-
port = click.prompt("Port to serve on", type=int, default=8001)
|
|
2461
|
-
|
|
2462
|
-
# Prompt for trace directory if not provided
|
|
2463
|
-
if trace_dir is None:
|
|
2464
|
-
click.echo(
|
|
2465
|
-
"\nTracing captures rollout data (actions, rewards, model outputs) to a local SQLite DB."
|
|
2466
|
-
)
|
|
2467
|
-
click.echo("This data can be exported to JSONL for supervised fine-tuning (SFT).")
|
|
2468
|
-
enable_tracing = click.confirm("Enable tracing?", default=True)
|
|
2469
|
-
if enable_tracing:
|
|
2470
|
-
demo_base = Path(os.environ.get("SYNTH_DEMO_DIR") or Path.cwd())
|
|
2471
|
-
default_trace_dir = str((demo_base / "traces/v3").resolve())
|
|
2472
|
-
trace_dir = click.prompt(
|
|
2473
|
-
"Trace directory", type=str, default=default_trace_dir, show_default=True
|
|
2474
|
-
)
|
|
2475
|
-
else:
|
|
2476
|
-
trace_dir = None
|
|
2477
|
-
|
|
2478
|
-
# Prompt for trace DB if not provided and tracing is enabled
|
|
2479
|
-
if trace_dir and trace_db is None:
|
|
2480
|
-
demo_base = Path(os.environ.get("SYNTH_DEMO_DIR") or Path.cwd())
|
|
2481
|
-
default_trace_db = str((demo_base / "traces/v3/synth_ai.db").resolve())
|
|
2482
|
-
trace_db = click.prompt(
|
|
2483
|
-
"Trace DB path", type=str, default=default_trace_db, show_default=True
|
|
2484
|
-
)
|
|
2485
|
-
|
|
2486
|
-
choice = _select_app_choice(app_id, purpose="serve")
|
|
2487
|
-
entry = choice.ensure_entry()
|
|
2488
|
-
_serve_entry(
|
|
2489
|
-
entry, host, port, env_file, reload_flag, force, trace_dir=trace_dir, trace_db=trace_db
|
|
2457
|
+
_deploy_commands.run_uvicorn_runtime(
|
|
2458
|
+
app_id,
|
|
2459
|
+
host,
|
|
2460
|
+
port,
|
|
2461
|
+
env_file,
|
|
2462
|
+
reload_flag,
|
|
2463
|
+
force,
|
|
2464
|
+
trace_dir,
|
|
2465
|
+
trace_db,
|
|
2490
2466
|
)
|
|
2491
2467
|
|
|
2492
2468
|
|
|
@@ -2599,49 +2575,19 @@ def serve_task_group(
|
|
|
2599
2575
|
trace_dir: str | None,
|
|
2600
2576
|
trace_db: str | None,
|
|
2601
2577
|
) -> None:
|
|
2602
|
-
|
|
2603
|
-
|
|
2604
|
-
|
|
2605
|
-
|
|
2606
|
-
|
|
2607
|
-
|
|
2608
|
-
|
|
2609
|
-
|
|
2610
|
-
|
|
2611
|
-
|
|
2612
|
-
# Prompt for port if not provided
|
|
2613
|
-
if port is None:
|
|
2614
|
-
port = click.prompt("Port to serve on", type=int, default=8001)
|
|
2615
|
-
|
|
2616
|
-
# Prompt for trace directory if not provided
|
|
2617
|
-
if trace_dir is None:
|
|
2618
|
-
click.echo(
|
|
2619
|
-
"\nTracing captures rollout data (actions, rewards, model outputs) to a local SQLite DB."
|
|
2620
|
-
)
|
|
2621
|
-
click.echo("This data can be exported to JSONL for supervised fine-tuning (SFT).")
|
|
2622
|
-
enable_tracing = click.confirm("Enable tracing?", default=True)
|
|
2623
|
-
if enable_tracing:
|
|
2624
|
-
demo_base = Path(os.environ.get("SYNTH_DEMO_DIR") or Path.cwd())
|
|
2625
|
-
default_trace_dir = str((demo_base / "traces/v3").resolve())
|
|
2626
|
-
trace_dir = click.prompt(
|
|
2627
|
-
"Trace directory", type=str, default=default_trace_dir, show_default=True
|
|
2628
|
-
)
|
|
2629
|
-
else:
|
|
2630
|
-
trace_dir = None
|
|
2578
|
+
_deploy_commands.run_uvicorn_runtime(
|
|
2579
|
+
app_id,
|
|
2580
|
+
host,
|
|
2581
|
+
port,
|
|
2582
|
+
env_file,
|
|
2583
|
+
reload_flag,
|
|
2584
|
+
force,
|
|
2585
|
+
trace_dir,
|
|
2586
|
+
trace_db,
|
|
2587
|
+
)
|
|
2631
2588
|
|
|
2632
|
-
# Prompt for trace DB if not provided and tracing is enabled
|
|
2633
|
-
if trace_dir and trace_db is None:
|
|
2634
|
-
demo_base = Path(os.environ.get("SYNTH_DEMO_DIR") or Path.cwd())
|
|
2635
|
-
default_trace_db = str((demo_base / "traces/v3/synth_ai.db").resolve())
|
|
2636
|
-
trace_db = click.prompt(
|
|
2637
|
-
"Trace DB path", type=str, default=default_trace_db, show_default=True
|
|
2638
|
-
)
|
|
2639
2589
|
|
|
2640
|
-
|
|
2641
|
-
entry = choice.ensure_entry()
|
|
2642
|
-
_serve_entry(
|
|
2643
|
-
entry, host, port, env_file, reload_flag, force, trace_dir=trace_dir, trace_db=trace_db
|
|
2644
|
-
)
|
|
2590
|
+
_deploy_commands.register_task_app_commands(task_app_group)
|
|
2645
2591
|
|
|
2646
2592
|
|
|
2647
2593
|
def _determine_env_files(
|
|
@@ -2936,87 +2882,6 @@ def _serve_entry(
|
|
|
2936
2882
|
)
|
|
2937
2883
|
|
|
2938
2884
|
|
|
2939
|
-
@task_app_group.command("deploy")
|
|
2940
|
-
@click.argument("app_id", type=str, required=False)
|
|
2941
|
-
@click.option("--name", "modal_name", default=None, help="Override Modal app name")
|
|
2942
|
-
@click.option("--dry-run", is_flag=True, help="Print modal deploy command without executing")
|
|
2943
|
-
@click.option("--modal-cli", default="modal", help="Path to modal CLI executable")
|
|
2944
|
-
@click.option(
|
|
2945
|
-
"--env-file",
|
|
2946
|
-
multiple=True,
|
|
2947
|
-
type=click.Path(),
|
|
2948
|
-
help="Env file to load into the container (can be repeated)",
|
|
2949
|
-
)
|
|
2950
|
-
def deploy_app(
|
|
2951
|
-
app_id: str | None,
|
|
2952
|
-
modal_name: str | None,
|
|
2953
|
-
dry_run: bool,
|
|
2954
|
-
modal_cli: str,
|
|
2955
|
-
env_file: Sequence[str],
|
|
2956
|
-
) -> None:
|
|
2957
|
-
"""Deploy a task app to Modal."""
|
|
2958
|
-
|
|
2959
|
-
demo_dir_path = _load_demo_directory()
|
|
2960
|
-
if demo_dir_path:
|
|
2961
|
-
if not demo_dir_path.is_dir():
|
|
2962
|
-
raise click.ClickException(
|
|
2963
|
-
f"Demo directory not found: {demo_dir_path}\nRun 'synth-ai demo' to create a demo."
|
|
2964
|
-
)
|
|
2965
|
-
os.chdir(demo_dir_path)
|
|
2966
|
-
click.echo(f"Using demo directory: {demo_dir_path}\n")
|
|
2967
|
-
|
|
2968
|
-
choice = _select_app_choice(app_id, purpose="deploy")
|
|
2969
|
-
|
|
2970
|
-
if choice.modal_script:
|
|
2971
|
-
env_paths = _resolve_env_paths_for_script(choice.modal_script, env_file)
|
|
2972
|
-
click.echo("Using env file(s): " + ", ".join(str(p.resolve()) for p in env_paths))
|
|
2973
|
-
_run_modal_script(
|
|
2974
|
-
choice.modal_script,
|
|
2975
|
-
modal_cli,
|
|
2976
|
-
"deploy",
|
|
2977
|
-
env_paths,
|
|
2978
|
-
modal_name=modal_name,
|
|
2979
|
-
dry_run=dry_run,
|
|
2980
|
-
)
|
|
2981
|
-
return
|
|
2982
|
-
|
|
2983
|
-
entry = choice.ensure_entry()
|
|
2984
|
-
_deploy_entry(entry, modal_name, dry_run, modal_cli, env_file, original_path=choice.path)
|
|
2985
|
-
|
|
2986
|
-
|
|
2987
|
-
@task_app_group.command("modal-serve")
|
|
2988
|
-
@click.argument("app_id", type=str, required=False)
|
|
2989
|
-
@click.option("--modal-cli", default="modal", help="Path to modal CLI executable")
|
|
2990
|
-
@click.option("--name", "modal_name", default=None, help="Override Modal app name (optional)")
|
|
2991
|
-
@click.option(
|
|
2992
|
-
"--env-file",
|
|
2993
|
-
multiple=True,
|
|
2994
|
-
type=click.Path(),
|
|
2995
|
-
help="Env file to load into the container (can be repeated)",
|
|
2996
|
-
)
|
|
2997
|
-
def modal_serve_app(
|
|
2998
|
-
app_id: str | None, modal_cli: str, modal_name: str | None, env_file: Sequence[str]
|
|
2999
|
-
) -> None:
|
|
3000
|
-
click.echo(f"[modal-serve] requested app_id={app_id or '(auto)'} modal_cli={modal_cli}")
|
|
3001
|
-
try:
|
|
3002
|
-
choice = _select_app_choice(app_id, purpose="modal-serve")
|
|
3003
|
-
except SystemExit as exc: # bubble up with context (legacy argparse would trigger this)
|
|
3004
|
-
raise click.ClickException(
|
|
3005
|
-
f"Legacy CLI intercepted modal-serve (exit {exc.code}). "
|
|
3006
|
-
"Make sure you're running the Click CLI (synth_ai.cli:cli)."
|
|
3007
|
-
) from exc
|
|
3008
|
-
|
|
3009
|
-
if choice.modal_script:
|
|
3010
|
-
env_paths = _resolve_env_paths_for_script(choice.modal_script, env_file)
|
|
3011
|
-
click.echo("Using env file(s): " + ", ".join(str(p.resolve()) for p in env_paths))
|
|
3012
|
-
_run_modal_script(choice.modal_script, modal_cli, "serve", env_paths, modal_name=modal_name)
|
|
3013
|
-
return
|
|
3014
|
-
|
|
3015
|
-
entry = choice.ensure_entry()
|
|
3016
|
-
click.echo(f"[modal-serve] serving entry {entry.app_id} from {choice.path}")
|
|
3017
|
-
_modal_serve_entry(entry, modal_name, modal_cli, env_file, original_path=choice.path)
|
|
3018
|
-
|
|
3019
|
-
|
|
3020
2885
|
def _write_modal_entrypoint(
|
|
3021
2886
|
entry: TaskAppEntryType,
|
|
3022
2887
|
modal_cfg: ModalDeploymentConfigType,
|
|
@@ -3260,1263 +3125,9 @@ def register(cli: click.Group) -> None:
|
|
|
3260
3125
|
cli.add_command(filter_command)
|
|
3261
3126
|
|
|
3262
3127
|
|
|
3263
|
-
|
|
3264
|
-
"eval",
|
|
3265
|
-
help="Run one-off rollouts against a task app and print judge/eval summaries.",
|
|
3266
|
-
)
|
|
3267
|
-
@click.argument("app_id", type=str, required=False)
|
|
3268
|
-
@click.option(
|
|
3269
|
-
"--config",
|
|
3270
|
-
type=click.Path(),
|
|
3271
|
-
default=None,
|
|
3272
|
-
help="Path to eval TOML (short schema). Auto-discovers the first matching file when omitted.",
|
|
3273
|
-
)
|
|
3274
|
-
@click.option(
|
|
3275
|
-
"--url",
|
|
3276
|
-
"task_app_url",
|
|
3277
|
-
type=str,
|
|
3278
|
-
default=None,
|
|
3279
|
-
help="Base URL of a running task app instead of spawning locally (requires --env-file for secrets).",
|
|
3280
|
-
)
|
|
3281
|
-
@click.option(
|
|
3282
|
-
"--seeds",
|
|
3283
|
-
default="0,1,2,3,4",
|
|
3284
|
-
help="Comma-separated seeds/indices to evaluate. Use negative numbers to wrap around the dataset.",
|
|
3285
|
-
)
|
|
3286
|
-
@click.option("--split", default="train", show_default=True, help="Dataset split to use")
|
|
3287
|
-
@click.option(
|
|
3288
|
-
"--model",
|
|
3289
|
-
default=None,
|
|
3290
|
-
help="Model identifier. When omitted the CLI will prompt based on task metadata.",
|
|
3291
|
-
)
|
|
3292
|
-
@click.option(
|
|
3293
|
-
"--env-file",
|
|
3294
|
-
multiple=True,
|
|
3295
|
-
type=click.Path(),
|
|
3296
|
-
help="Env file(s) to load (API keys, etc.). Required when using --url or remote judges.",
|
|
3297
|
-
)
|
|
3298
|
-
@click.option(
|
|
3299
|
-
"--trace-db",
|
|
3300
|
-
default="traces/v3/synth_ai.db",
|
|
3301
|
-
show_default=True,
|
|
3302
|
-
help="SQLite/Turso URL for storing rollout traces set to 'none' to disable persistence.",
|
|
3303
|
-
)
|
|
3304
|
-
@click.option(
|
|
3305
|
-
"--metadata",
|
|
3306
|
-
multiple=True,
|
|
3307
|
-
help="Filter tasks by key=value metadata (e.g., --metadata difficulty=easy)",
|
|
3308
|
-
)
|
|
3309
|
-
@click.option(
|
|
3310
|
-
"--metadata-sql",
|
|
3311
|
-
default=None,
|
|
3312
|
-
help="SQLite query that returns seeds to evaluate (e.g., SELECT seed FROM tasks WHERE difficulty='easy' LIMIT 5)",
|
|
3313
|
-
)
|
|
3314
|
-
def eval_command(
|
|
3315
|
-
app_id: str | None,
|
|
3316
|
-
config: str | None,
|
|
3317
|
-
task_app_url: str | None,
|
|
3318
|
-
seeds: str,
|
|
3319
|
-
split: str,
|
|
3320
|
-
model: str | None,
|
|
3321
|
-
env_file: Sequence[str],
|
|
3322
|
-
trace_db: str,
|
|
3323
|
-
metadata: Sequence[str],
|
|
3324
|
-
metadata_sql: str | None,
|
|
3325
|
-
) -> None:
|
|
3326
|
-
"""Run rollouts against a task app and report judge statistics.
|
|
3327
|
-
|
|
3328
|
-
By default the command spins up the selected task app in-process, executes the
|
|
3329
|
-
requested seeds, and prints aggregate scores (official and custom judges). When
|
|
3330
|
-
pointing at a remote `--url`, supply matching `--env-file` values so the CLI can
|
|
3331
|
-
forward authentication headers to the running service.
|
|
3332
|
-
"""
|
|
3333
|
-
# Parse and validate TOML config
|
|
3334
|
-
from synth_ai.task.config import EvalConfig
|
|
3335
|
-
|
|
3336
|
-
cfg: dict[str, Any] = {}
|
|
3337
|
-
eval_cfg: EvalConfig | None = None
|
|
3338
|
-
config_path: Path | None = None
|
|
3339
|
-
|
|
3340
|
-
if config:
|
|
3341
|
-
config_path = Path(config)
|
|
3342
|
-
else:
|
|
3343
|
-
auto_configs = _discover_eval_config_paths()
|
|
3344
|
-
if auto_configs:
|
|
3345
|
-
config_path = auto_configs[0]
|
|
3346
|
-
click.echo(f"Using eval config: {config_path}")
|
|
3347
|
-
|
|
3348
|
-
if config_path:
|
|
3349
|
-
if _toml is None:
|
|
3350
|
-
raise click.ClickException(
|
|
3351
|
-
"TOML parser not available; use Python 3.11+ or install tomli"
|
|
3352
|
-
)
|
|
3353
|
-
if not config_path.exists():
|
|
3354
|
-
raise click.ClickException(f"Eval config not found: {config_path}")
|
|
3355
|
-
try:
|
|
3356
|
-
data = config_path.read_bytes()
|
|
3357
|
-
parsed = _toml.loads(data.decode("utf-8"))
|
|
3358
|
-
if isinstance(parsed, dict):
|
|
3359
|
-
section = parsed.get("eval")
|
|
3360
|
-
cfg = dict(section) if isinstance(section, dict) else dict(parsed)
|
|
3361
|
-
|
|
3362
|
-
# Validate config with dataclass
|
|
3363
|
-
try:
|
|
3364
|
-
eval_cfg = EvalConfig.from_dict(cfg)
|
|
3365
|
-
click.echo(f"✓ Config validated: {len(eval_cfg.seeds)} seeds, model={eval_cfg.model}")
|
|
3366
|
-
except (ValueError, TypeError) as validation_error:
|
|
3367
|
-
raise click.ClickException(f"Invalid eval config: {validation_error}") from validation_error
|
|
3368
|
-
except click.ClickException:
|
|
3369
|
-
raise
|
|
3370
|
-
except Exception as exc:
|
|
3371
|
-
raise click.ClickException(f"Failed to parse TOML '{config_path}': {exc}") from exc
|
|
3372
|
-
|
|
3373
|
-
# CLI args override config
|
|
3374
|
-
if eval_cfg:
|
|
3375
|
-
app_id = app_id or eval_cfg.app_id
|
|
3376
|
-
else:
|
|
3377
|
-
app_id = app_id or (cfg.get("app_id") if isinstance(cfg.get("app_id"), str) else None) # type: ignore
|
|
3378
|
-
|
|
3379
|
-
metadata_filters: dict[str, str] = {}
|
|
3380
|
-
if eval_cfg:
|
|
3381
|
-
metadata_filters.update(eval_cfg.metadata)
|
|
3382
|
-
else:
|
|
3383
|
-
cfg_metadata = cfg.get("metadata")
|
|
3384
|
-
if isinstance(cfg_metadata, dict):
|
|
3385
|
-
for key, value in cfg_metadata.items():
|
|
3386
|
-
metadata_filters[str(key)] = str(value)
|
|
3387
|
-
elif isinstance(cfg_metadata, list):
|
|
3388
|
-
for item in cfg_metadata:
|
|
3389
|
-
if isinstance(item, str) and "=" in item:
|
|
3390
|
-
key, value = item.split("=", 1)
|
|
3391
|
-
metadata_filters[key.strip()] = value.strip()
|
|
3392
|
-
|
|
3393
|
-
for item in metadata or ():
|
|
3394
|
-
if "=" not in item:
|
|
3395
|
-
raise click.ClickException(f"Metadata filters must be key=value (got: {item})")
|
|
3396
|
-
key, value = item.split("=", 1)
|
|
3397
|
-
key = key.strip()
|
|
3398
|
-
value = value.strip()
|
|
3399
|
-
if not key or not value:
|
|
3400
|
-
raise click.ClickException(f"Invalid metadata filter: {item}")
|
|
3401
|
-
metadata_filters[key] = value
|
|
3402
|
-
|
|
3403
|
-
metadata_sql_query: str | None = None
|
|
3404
|
-
if eval_cfg and eval_cfg.metadata_sql:
|
|
3405
|
-
metadata_sql_query = eval_cfg.metadata_sql
|
|
3406
|
-
else:
|
|
3407
|
-
cfg_metadata_sql = cfg.get("metadata_sql")
|
|
3408
|
-
if isinstance(cfg_metadata_sql, dict):
|
|
3409
|
-
metadata_sql_query = cfg_metadata_sql.get("query") or cfg_metadata_sql.get("sql")
|
|
3410
|
-
elif isinstance(cfg_metadata_sql, str):
|
|
3411
|
-
metadata_sql_query = cfg_metadata_sql
|
|
3412
|
-
|
|
3413
|
-
if metadata_sql:
|
|
3414
|
-
metadata_sql_query = metadata_sql
|
|
3415
|
-
if metadata_sql_query is not None:
|
|
3416
|
-
metadata_sql_query = str(metadata_sql_query)
|
|
3417
|
-
|
|
3418
|
-
trace_db_url: str | None = None
|
|
3419
|
-
trace_db = (trace_db or "").strip()
|
|
3420
|
-
if trace_db and trace_db.lower() not in {"none", "off", "disable"}:
|
|
3421
|
-
if "://" in trace_db:
|
|
3422
|
-
trace_db_url = trace_db
|
|
3423
|
-
else:
|
|
3424
|
-
trace_path = Path(trace_db).expanduser()
|
|
3425
|
-
trace_path.parent.mkdir(parents=True, exist_ok=True)
|
|
3426
|
-
trace_db_url = f"sqlite+aiosqlite:///{trace_path}"
|
|
3427
|
-
trace_tracer: SessionTracer | None = SessionTracer(db_url=trace_db_url, auto_save=True) if trace_db_url else None
|
|
3428
|
-
|
|
3429
|
-
# Determine selection params (CLI takes precedence; TOML only fills unset model/seeds/env)
|
|
3430
|
-
if cfg.get("model") and not model:
|
|
3431
|
-
model = str(cfg["model"]) # type: ignore[index]
|
|
3432
|
-
if cfg.get("seeds") and seeds == "0,1,2,3,4":
|
|
3433
|
-
val = cfg["seeds"]
|
|
3434
|
-
if isinstance(val, list):
|
|
3435
|
-
with contextlib.suppress(Exception):
|
|
3436
|
-
seeds = ",".join(str(int(x)) for x in val)
|
|
3437
|
-
elif isinstance(val, str):
|
|
3438
|
-
seeds = val
|
|
3439
|
-
elif isinstance(val, int):
|
|
3440
|
-
seeds = str(val)
|
|
3441
|
-
if cfg.get("env_file") and not env_file:
|
|
3442
|
-
ef = cfg["env_file"]
|
|
3443
|
-
if isinstance(ef, str):
|
|
3444
|
-
env_file = (ef,) # type: ignore[assignment]
|
|
3445
|
-
elif isinstance(ef, list):
|
|
3446
|
-
env_file = tuple(str(x) for x in ef) # type: ignore[assignment]
|
|
3447
|
-
|
|
3448
|
-
choice_for_env: AppChoice | None = None
|
|
3449
|
-
entry: TaskAppEntryType | None = None
|
|
3450
|
-
if task_app_url is None:
|
|
3451
|
-
choice_for_env = _select_app_choice(app_id, purpose="eval")
|
|
3452
|
-
entry = choice_for_env.ensure_entry()
|
|
3453
|
-
|
|
3454
|
-
env_paths: list[Path] = []
|
|
3455
|
-
if entry is not None:
|
|
3456
|
-
original_env_path = choice_for_env.path if choice_for_env is not None else None
|
|
3457
|
-
env_paths = _determine_env_files(entry, env_file, original_path=original_env_path)
|
|
3458
|
-
else:
|
|
3459
|
-
if not env_file:
|
|
3460
|
-
raise click.ClickException("--env-file is required when using --url")
|
|
3461
|
-
for candidate in env_file:
|
|
3462
|
-
p = Path(candidate).expanduser()
|
|
3463
|
-
if not p.exists():
|
|
3464
|
-
raise click.ClickException(f"Env file not found: {p}")
|
|
3465
|
-
env_paths.append(p)
|
|
3466
|
-
|
|
3467
|
-
click.echo("Using env file(s): " + ", ".join(str(p) for p in env_paths))
|
|
3468
|
-
_load_env_files_into_process([str(Path(p)) for p in env_paths])
|
|
3469
|
-
|
|
3470
|
-
if task_app_url is None:
|
|
3471
|
-
config = entry.config_factory() # type: ignore[union-attr]
|
|
3472
|
-
# Help the type checker; runtime check also enforced in server.run_task_app
|
|
3473
|
-
if not isinstance(config, TaskAppConfig):
|
|
3474
|
-
raise click.ClickException(
|
|
3475
|
-
"Invalid task app: config_factory did not return TaskAppConfig"
|
|
3476
|
-
)
|
|
3477
|
-
app = create_task_app(config)
|
|
3478
|
-
|
|
3479
|
-
# Determine supported models
|
|
3480
|
-
inference_meta: dict[str, Any] = {}
|
|
3481
|
-
supported: list[str] = []
|
|
3482
|
-
seen_models: set[str] = set()
|
|
3483
|
-
|
|
3484
|
-
def _add_supported_model(candidate: Any) -> None:
|
|
3485
|
-
if not candidate:
|
|
3486
|
-
return
|
|
3487
|
-
text = str(candidate).strip()
|
|
3488
|
-
if not text or text in seen_models:
|
|
3489
|
-
return
|
|
3490
|
-
supported.append(text)
|
|
3491
|
-
seen_models.add(text)
|
|
3492
|
-
|
|
3493
|
-
if task_app_url is None:
|
|
3494
|
-
try:
|
|
3495
|
-
if hasattr(config, "base_task_info") and config.base_task_info:
|
|
3496
|
-
inf_obj = getattr(config.base_task_info, "inference", None)
|
|
3497
|
-
if inf_obj is not None:
|
|
3498
|
-
if hasattr(inf_obj, "model_dump"):
|
|
3499
|
-
inference_meta = dict(inf_obj.model_dump(exclude_none=True)) # type: ignore[attr-defined]
|
|
3500
|
-
elif isinstance(inf_obj, dict):
|
|
3501
|
-
inference_meta = dict(inf_obj)
|
|
3502
|
-
except Exception:
|
|
3503
|
-
inference_meta = {}
|
|
3504
|
-
else:
|
|
3505
|
-
try:
|
|
3506
|
-
import httpx as _hx
|
|
3507
|
-
|
|
3508
|
-
headers = {}
|
|
3509
|
-
api_key = (os.environ.get("ENVIRONMENT_API_KEY") or "").strip()
|
|
3510
|
-
if api_key:
|
|
3511
|
-
headers["X-API-Key"] = api_key
|
|
3512
|
-
with _hx.Client(base_url=task_app_url, headers=headers, timeout=15.0) as c:
|
|
3513
|
-
info = c.get("/info").json()
|
|
3514
|
-
inf = info.get("inference") if isinstance(info, dict) else None
|
|
3515
|
-
if isinstance(inf, dict):
|
|
3516
|
-
inference_meta = dict(inf)
|
|
3517
|
-
except Exception:
|
|
3518
|
-
inference_meta = {}
|
|
3519
|
-
|
|
3520
|
-
default_model = inference_meta.get("model")
|
|
3521
|
-
if isinstance(default_model, str):
|
|
3522
|
-
_add_supported_model(default_model)
|
|
3523
|
-
|
|
3524
|
-
models_field = inference_meta.get("models")
|
|
3525
|
-
if isinstance(models_field, list):
|
|
3526
|
-
for candidate in models_field:
|
|
3527
|
-
_add_supported_model(candidate)
|
|
3528
|
-
|
|
3529
|
-
supported_models = inference_meta.get("supported_models")
|
|
3530
|
-
if isinstance(supported_models, list):
|
|
3531
|
-
for candidate in supported_models:
|
|
3532
|
-
_add_supported_model(candidate)
|
|
3533
|
-
|
|
3534
|
-
providers = inference_meta.get("providers")
|
|
3535
|
-
if isinstance(providers, list):
|
|
3536
|
-
if "openai" in providers:
|
|
3537
|
-
_add_supported_model("gpt-5")
|
|
3538
|
-
if "groq" in providers:
|
|
3539
|
-
_add_supported_model("groq:llama-3.1-70b-versatile")
|
|
3540
|
-
|
|
3541
|
-
_add_supported_model("synth:qwen-0.6b")
|
|
3542
|
-
|
|
3543
|
-
selected_model = model
|
|
3544
|
-
if not selected_model:
|
|
3545
|
-
if not supported:
|
|
3546
|
-
raise click.ClickException(
|
|
3547
|
-
"No supported models; supply --model or add base_task_info.inference.model"
|
|
3548
|
-
)
|
|
3549
|
-
click.echo("Select model to evaluate:")
|
|
3550
|
-
for idx, m in enumerate(supported, start=1):
|
|
3551
|
-
click.echo(f" {idx}) {m}")
|
|
3552
|
-
choice_idx = click.prompt("Enter choice", type=click.IntRange(1, len(supported)))
|
|
3553
|
-
selected_model = supported[choice_idx - 1]
|
|
3554
|
-
|
|
3555
|
-
try:
|
|
3556
|
-
seed_values = [int(s.strip()) for s in seeds.split(",") if s.strip()]
|
|
3557
|
-
except Exception as exc:
|
|
3558
|
-
raise click.ClickException("Invalid --seeds; expected comma-separated integers") from exc
|
|
3559
|
-
|
|
3560
|
-
import httpx
|
|
3561
|
-
|
|
3562
|
-
headers = {}
|
|
3563
|
-
api_key = (os.environ.get("ENVIRONMENT_API_KEY") or "").strip()
|
|
3564
|
-
if api_key:
|
|
3565
|
-
headers["X-API-Key"] = api_key
|
|
3566
|
-
|
|
3567
|
-
# Precompute optional policy overrides from TOML
|
|
3568
|
-
policy_overrides: dict[str, Any] = {}
|
|
3569
|
-
try:
|
|
3570
|
-
# Accept [eval.policy] table or top-level keys for convenience
|
|
3571
|
-
if isinstance(cfg.get("policy"), dict):
|
|
3572
|
-
policy_overrides.update(dict(cfg["policy"]))
|
|
3573
|
-
# Back-compat: allow temperature/max_tokens at top level
|
|
3574
|
-
for k in (
|
|
3575
|
-
"temperature",
|
|
3576
|
-
"max_tokens",
|
|
3577
|
-
"reasoning_effort",
|
|
3578
|
-
"system_hint",
|
|
3579
|
-
"tool_choice",
|
|
3580
|
-
"inference_url",
|
|
3581
|
-
):
|
|
3582
|
-
if k in cfg and k not in policy_overrides:
|
|
3583
|
-
policy_overrides[k] = cfg.get(k)
|
|
3584
|
-
except Exception:
|
|
3585
|
-
policy_overrides = {}
|
|
3586
|
-
|
|
3587
|
-
raw_concurrency = cfg.get("concurrency")
|
|
3588
|
-
try:
|
|
3589
|
-
concurrency_limit = int(raw_concurrency) if raw_concurrency is not None else 1
|
|
3590
|
-
except Exception:
|
|
3591
|
-
concurrency_limit = 1
|
|
3592
|
-
if concurrency_limit <= 0:
|
|
3593
|
-
concurrency_limit = 1
|
|
3594
|
-
concurrency_limit = min(concurrency_limit, max(1, len(seed_values)))
|
|
3595
|
-
|
|
3596
|
-
judge_specs: list[JudgeSpec] = []
|
|
3597
|
-
|
|
3598
|
-
def _register_judge(name_hint: str | None, judge_cfg: dict[str, Any]) -> None:
|
|
3599
|
-
if not judge_cfg:
|
|
3600
|
-
return
|
|
3601
|
-
judge_module = judge_cfg.get("module")
|
|
3602
|
-
judge_path = judge_cfg.get("path")
|
|
3603
|
-
judge_callable_name = judge_cfg.get("callable") or judge_cfg.get("function")
|
|
3604
|
-
if judge_module and judge_path:
|
|
3605
|
-
raise click.ClickException("Judge config cannot set both 'module' and 'path'")
|
|
3606
|
-
if not judge_module and not judge_path:
|
|
3607
|
-
raise click.ClickException("Judge config requires 'module' or 'path'")
|
|
3608
|
-
try:
|
|
3609
|
-
if judge_module:
|
|
3610
|
-
module = importlib.import_module(str(judge_module))
|
|
3611
|
-
else:
|
|
3612
|
-
path = Path(str(judge_path)).expanduser()
|
|
3613
|
-
if not path.exists():
|
|
3614
|
-
raise click.ClickException(f"Judge module path not found: {path}")
|
|
3615
|
-
spec = importlib.util.spec_from_file_location(
|
|
3616
|
-
f"_eval_judge_{path.stem}", path
|
|
3617
|
-
)
|
|
3618
|
-
if not spec or not spec.loader:
|
|
3619
|
-
raise click.ClickException(f"Failed to load judge module from {path}")
|
|
3620
|
-
module = importlib.util.module_from_spec(spec)
|
|
3621
|
-
sys.modules[spec.name] = module
|
|
3622
|
-
spec.loader.exec_module(module)
|
|
3623
|
-
except click.ClickException:
|
|
3624
|
-
raise
|
|
3625
|
-
except Exception as exc:
|
|
3626
|
-
raise click.ClickException(f"Unable to load judge module: {exc}") from exc
|
|
3627
|
-
|
|
3628
|
-
if judge_callable_name:
|
|
3629
|
-
try:
|
|
3630
|
-
judge_fn = getattr(module, str(judge_callable_name))
|
|
3631
|
-
except AttributeError as exc:
|
|
3632
|
-
raise click.ClickException(
|
|
3633
|
-
f"Judge callable '{judge_callable_name}' not found in module"
|
|
3634
|
-
) from exc
|
|
3635
|
-
else:
|
|
3636
|
-
if hasattr(module, "judge"):
|
|
3637
|
-
judge_fn = module.judge
|
|
3638
|
-
else:
|
|
3639
|
-
raise click.ClickException("Judge module must expose 'judge' callable")
|
|
3640
|
-
|
|
3641
|
-
if not callable(judge_fn):
|
|
3642
|
-
raise click.ClickException("Judge callable is not callable")
|
|
3643
|
-
|
|
3644
|
-
judge_kwargs = {
|
|
3645
|
-
k: v
|
|
3646
|
-
for k, v in judge_cfg.items()
|
|
3647
|
-
if k not in {"module", "path", "callable", "function", "name"}
|
|
3648
|
-
}
|
|
3649
|
-
display_name = str(
|
|
3650
|
-
judge_cfg.get("name")
|
|
3651
|
-
or name_hint
|
|
3652
|
-
or f"judge{len(judge_specs) + 1}"
|
|
3653
|
-
)
|
|
3654
|
-
judge_specs.append(JudgeSpec(display_name, judge_fn, judge_kwargs))
|
|
3655
|
-
|
|
3656
|
-
raw_judge_cfg = cfg.get("judge")
|
|
3657
|
-
if isinstance(raw_judge_cfg, dict) and raw_judge_cfg:
|
|
3658
|
-
direct_keys = {"module", "path", "callable", "function", "name"}
|
|
3659
|
-
has_direct_keys = any(key in raw_judge_cfg for key in direct_keys)
|
|
3660
|
-
nested_candidates = [
|
|
3661
|
-
(key, value)
|
|
3662
|
-
for key, value in raw_judge_cfg.items()
|
|
3663
|
-
if isinstance(value, dict)
|
|
3664
|
-
]
|
|
3665
|
-
if has_direct_keys and not nested_candidates:
|
|
3666
|
-
_register_judge(None, raw_judge_cfg)
|
|
3667
|
-
else:
|
|
3668
|
-
for sub_name, sub_cfg in nested_candidates:
|
|
3669
|
-
_register_judge(sub_name, sub_cfg)
|
|
3670
|
-
|
|
3671
|
-
raw_judges_list = cfg.get("judges")
|
|
3672
|
-
if isinstance(raw_judges_list, list):
|
|
3673
|
-
for _index, entry in enumerate(raw_judges_list, start=1):
|
|
3674
|
-
if isinstance(entry, dict):
|
|
3675
|
-
_register_judge(entry.get("name") or f"judge{len(judge_specs) + 1}", entry)
|
|
3676
|
-
|
|
3677
|
-
records: list[dict[str, Any]] = []
|
|
3678
|
-
|
|
3679
|
-
successes = 0
|
|
3680
|
-
failures = 0
|
|
3681
|
-
# Aggregate outcome stats across successful seeds
|
|
3682
|
-
outcome_sum: float = 0.0
|
|
3683
|
-
outcome_count: int = 0
|
|
3684
|
-
outcome_correct: int = 0
|
|
3685
|
-
|
|
3686
|
-
def _build_task_rows(taskset: Any) -> dict[int, dict[str, Any]]:
|
|
3687
|
-
rows: dict[int, dict[str, Any]] = {}
|
|
3688
|
-
if not isinstance(taskset, dict):
|
|
3689
|
-
return rows
|
|
3690
|
-
|
|
3691
|
-
scenario_ids = taskset.get("scenario_ids") or []
|
|
3692
|
-
loop_ids = taskset.get("loop_ids") or []
|
|
3693
|
-
thread_ids = taskset.get("thread_ids") or []
|
|
3694
|
-
difficulty_map = taskset.get("difficulty_map") or {}
|
|
3695
|
-
|
|
3696
|
-
max_len = max(len(scenario_ids), len(loop_ids), len(thread_ids))
|
|
3697
|
-
for seed in range(max_len):
|
|
3698
|
-
scenario_id = scenario_ids[seed] if seed < len(scenario_ids) else None
|
|
3699
|
-
loop_id = loop_ids[seed] if seed < len(loop_ids) else None
|
|
3700
|
-
thread_id = thread_ids[seed] if seed < len(thread_ids) else None
|
|
3701
|
-
difficulty = None
|
|
3702
|
-
if isinstance(difficulty_map, dict):
|
|
3703
|
-
if scenario_id and scenario_id in difficulty_map:
|
|
3704
|
-
difficulty = difficulty_map.get(scenario_id)
|
|
3705
|
-
elif str(seed) in difficulty_map:
|
|
3706
|
-
difficulty = difficulty_map.get(str(seed))
|
|
3707
|
-
|
|
3708
|
-
rows[seed] = {
|
|
3709
|
-
"seed": seed,
|
|
3710
|
-
"scenario_id": scenario_id,
|
|
3711
|
-
"loop_id": loop_id,
|
|
3712
|
-
"thread_id": thread_id,
|
|
3713
|
-
"difficulty": difficulty,
|
|
3714
|
-
}
|
|
3715
|
-
return rows
|
|
3716
|
-
|
|
3717
|
-
def _apply_metadata_filters(
|
|
3718
|
-
rows: dict[int, dict[str, Any]], seeds_list: list[int], filters: dict[str, str]
|
|
3719
|
-
) -> list[int]:
|
|
3720
|
-
if not filters:
|
|
3721
|
-
return seeds_list
|
|
3722
|
-
filtered: list[int] = []
|
|
3723
|
-
for seed in seeds_list:
|
|
3724
|
-
row = rows.get(seed)
|
|
3725
|
-
if not row:
|
|
3726
|
-
continue
|
|
3727
|
-
include = True
|
|
3728
|
-
for key, expected in filters.items():
|
|
3729
|
-
actual = row.get(key)
|
|
3730
|
-
if actual is None:
|
|
3731
|
-
include = False
|
|
3732
|
-
break
|
|
3733
|
-
if str(actual).lower() != expected.lower():
|
|
3734
|
-
include = False
|
|
3735
|
-
break
|
|
3736
|
-
if include:
|
|
3737
|
-
filtered.append(seed)
|
|
3738
|
-
return filtered
|
|
3739
|
-
|
|
3740
|
-
def _apply_metadata_sql(
|
|
3741
|
-
rows: dict[int, dict[str, Any]], seeds_list: list[int], query: str
|
|
3742
|
-
) -> list[int]:
|
|
3743
|
-
"""Return seeds that satisfy an arbitrary SQL query.
|
|
3744
|
-
|
|
3745
|
-
The query is executed against an in-memory SQLite table named `tasks`
|
|
3746
|
-
with columns (seed INTEGER, scenario_id TEXT, loop_id TEXT, thread_id TEXT, difficulty TEXT).
|
|
3747
|
-
Any rows whose `seed` value (or first column if `seed` is absent) appear in the result set are retained.
|
|
3748
|
-
"""
|
|
3749
|
-
if not query:
|
|
3750
|
-
return seeds_list
|
|
3751
|
-
conn = sqlite3.connect(":memory:")
|
|
3752
|
-
try:
|
|
3753
|
-
cur = conn.cursor()
|
|
3754
|
-
cur.execute(
|
|
3755
|
-
"CREATE TABLE tasks (seed INTEGER, scenario_id TEXT, loop_id TEXT, thread_id TEXT, difficulty TEXT)"
|
|
3756
|
-
)
|
|
3757
|
-
insert_stmt = (
|
|
3758
|
-
"INSERT INTO tasks (seed, scenario_id, loop_id, thread_id, difficulty) VALUES (?,?,?,?,?)"
|
|
3759
|
-
)
|
|
3760
|
-
for seed in seeds_list:
|
|
3761
|
-
row = rows.get(seed, {})
|
|
3762
|
-
cur.execute(
|
|
3763
|
-
insert_stmt,
|
|
3764
|
-
[
|
|
3765
|
-
seed,
|
|
3766
|
-
row.get("scenario_id"),
|
|
3767
|
-
row.get("loop_id"),
|
|
3768
|
-
row.get("thread_id"),
|
|
3769
|
-
row.get("difficulty"),
|
|
3770
|
-
],
|
|
3771
|
-
)
|
|
3772
|
-
|
|
3773
|
-
result = cur.execute(query)
|
|
3774
|
-
fetched = result.fetchall()
|
|
3775
|
-
if not fetched:
|
|
3776
|
-
return []
|
|
3777
|
-
description = result.description or []
|
|
3778
|
-
col_names = [col[0] for col in description]
|
|
3779
|
-
seeds_out: list[int] = []
|
|
3780
|
-
for entry in fetched:
|
|
3781
|
-
value = entry[col_names.index("seed")] if "seed" in col_names else entry[0]
|
|
3782
|
-
try:
|
|
3783
|
-
seeds_out.append(int(value))
|
|
3784
|
-
except Exception as exc:
|
|
3785
|
-
raise click.ClickException(
|
|
3786
|
-
"metadata SQL query must return seed integers"
|
|
3787
|
-
) from exc
|
|
3788
|
-
seeds_set = set(seeds_out)
|
|
3789
|
-
return [seed for seed in seeds_list if seed in seeds_set]
|
|
3790
|
-
except sqlite3.Error as exc:
|
|
3791
|
-
raise click.ClickException(f"Failed to execute metadata SQL query: {exc}") from exc
|
|
3792
|
-
finally:
|
|
3793
|
-
conn.close()
|
|
3794
|
-
|
|
3795
|
-
async def _run_eval() -> None:
    """Execute the evaluation rollouts for every selected seed.

    Creates an httpx client that talks either to the in-process ASGI ``app``
    (when ``task_app_url`` is None) or to the remote ``task_app_url``. If a
    metadata SQL query or metadata filters were supplied, ``seed_values`` is
    narrowed using the taskset metadata returned by ``GET /task_info``. Every
    seed is then run concurrently via ``_run_seed``; per-seed results are
    appended to the enclosing scope's ``records`` and the success/failure/
    outcome counters are updated through ``nonlocal``. The client is always
    closed on exit.

    Raises:
        click.ClickException: metadata filtering was requested but /task_info
            yielded no usable taskset, or no seeds survived the filters.
        RuntimeError: propagated from ``_run_seed`` when a rollout response
            lacks the required trace payload.
    """
    nonlocal successes, failures, outcome_sum, outcome_count, outcome_correct, records, seed_values

    # Initialise the trace store only if it has not been set up yet.
    if trace_tracer is not None and trace_tracer.db is None:
        await trace_tracer.initialize()

    if task_app_url is None:
        # No URL given: route requests to the in-process ASGI app.
        transport = httpx.ASGITransport(app=app)  # type: ignore[name-defined]
        async_client = httpx.AsyncClient(
            transport=cast(Any, transport),
            base_url="http://eval.local",
            timeout=300.0,
            follow_redirects=True,
            headers=headers,
        )
    else:
        async_client = httpx.AsyncClient(
            base_url=task_app_url,
            timeout=300.0,
            follow_redirects=True,
            headers=headers,
        )

    try:
        # Best-effort fetch of taskset metadata; any failure leaves it as None.
        taskset_payload: dict[str, Any] | None = None
        try:
            task_info_response = await async_client.get("/task_info")
        except Exception:
            task_info_response = None
        if task_info_response is not None and task_info_response.status_code == 200:
            with contextlib.suppress(Exception):
                payload_json = task_info_response.json()
                # Accept either {"taskset": {...}} or a bare taskset dict.
                if isinstance(payload_json, dict) and "taskset" in payload_json:
                    taskset_payload = payload_json.get("taskset")
                    if not isinstance(taskset_payload, dict):
                        taskset_payload = None
                elif isinstance(payload_json, dict):
                    taskset_payload = payload_json

        available_seeds = list(seed_values)
        if metadata_sql_query or metadata_filters:
            if not taskset_payload:
                raise click.ClickException(
                    "Task metadata filters require the task app to expose /task_info metadata"
                )
            rows = _build_task_rows(taskset_payload)
            # The SQL query is applied first; key/value filters then narrow its output.
            if metadata_sql_query:
                available_seeds = _apply_metadata_sql(rows, available_seeds, metadata_sql_query)
            if metadata_filters:
                available_seeds = _apply_metadata_filters(rows, available_seeds, metadata_filters)
            if not available_seeds:
                raise click.ClickException("No seeds match the provided metadata filters")
            seed_values = available_seeds

        # Bounds the number of in-flight /rollout requests only; judges run outside it.
        semaphore = asyncio.Semaphore(concurrency_limit)

        async def _run_seed(seed_val: int) -> None:
            """Run one rollout for ``seed_val``, score it, and append a record.

            Transport errors count as a failure and return early. Non-2xx
            responses count as a failure but are still summarised and recorded.
            """
            nonlocal successes, failures, outcome_sum, outcome_count, outcome_correct, records
            # Read env_name and policy_name from config if available
            env_name = cfg.get("env_name") or (cfg.get("env", {}).get("env_name") if isinstance(cfg.get("env"), dict) else None)
            policy_name = cfg.get("policy_name") or (cfg.get("policy", {}).get("policy_name") if isinstance(cfg.get("policy"), dict) else None)
            env_config_overrides = cfg.get("env_config", {}) if isinstance(cfg.get("env_config"), dict) else {}
            policy_config_overrides = cfg.get("policy_config", {}) if isinstance(cfg.get("policy_config"), dict) else {}

            # Debug: print config parsing (only when seed 0 is being evaluated)
            if seed_val == 0:
                click.echo(f"[DEBUG] env_name from config: {env_name}")
                click.echo(f"[DEBUG] policy_name from config: {policy_name}")

            # Generate default ops sequence if not provided
            max_llm_calls = policy_config_overrides.get("max_llm_calls", 10)
            ops_list = cfg.get("ops", [])
            if not ops_list:
                # Generate default "agent, env" pairs for max_llm_calls
                ops_list = ["agent", "env"] * int(max_llm_calls)

            # Config overrides are spread last, so they win over split/index/model.
            body = {
                "run_id": str(uuid.uuid4()),
                "env": {"config": {"split": split, "index": seed_val, **env_config_overrides}, "seed": seed_val},
                "policy": {
                    "policy_name": policy_name or selected_model,
                    "config": {"model": selected_model, **policy_overrides, **policy_config_overrides},
                },
                "ops": ops_list,
                "record": {
                    "return_trace": cfg.get("return_trace", True),
                    "trace_format": cfg.get("trace_format", "structured"),
                },
                "mode": "eval",  # RolloutMode.EVAL: use inference URLs as-is, no transformations
            }
            if env_name:
                body["env"]["env_name"] = env_name

            # Debug: print the body being sent
            if seed_val == 0:
                click.echo(f"[DEBUG] rollout body env: {body['env']}")
                click.echo(f"[DEBUG] rollout body policy: {body['policy']}")
                click.echo(f"[DEBUG] rollout body mode: {body.get('mode', 'NOT SET')}")
            rollout_elapsed: float | None = None
            # NOTE: the timer starts before the semaphore is acquired, so
            # rollout_s includes any time spent waiting for a concurrency slot.
            rollout_start = time.perf_counter()
            try:
                import logging
                _log = logging.getLogger(__name__)
                _log.info(f"[EVAL_BODY_DEBUG] Sending body with mode={body.get('mode')}")
                async with semaphore:
                    response = await async_client.post("/rollout", json=body)
                rollout_elapsed = time.perf_counter() - rollout_start
            except Exception as exc:
                failures += 1
                click.echo(f"seed={seed_val} error={exc}")
                return

            ok = 200 <= response.status_code < 300
            if ok:
                successes += 1
            else:
                failures += 1

            summary = [f"seed={seed_val}", f"status={response.status_code}"]
            data: Any
            try:
                data = response.json()
            except Exception:
                data = None

            # Debug: print validation errors
            if response.status_code == 422 and data:
                click.echo(f"[DEBUG] 422 Validation Error: {data}")

            # Per-seed fields filled in from the response below; they stay None
            # when the response does not contain them.
            metrics: dict[str, Any] | None = None
            completion: str | None = None
            prompt_index: int | None = None
            prompt_text: str | None = None
            task_id: str | None = None
            task_split: str | None = None
            task_rubric_id: str | None = None

            trace_namespace: dict[str, Any] | None = None
            session_trace_dict: dict[str, Any] | None = None

            if isinstance(data, dict):
                import logging
                _logger = logging.getLogger(__name__)
                _logger.info(f"[EVAL_DEBUG] Response data keys: {list(data.keys())}")
                if "detail" in data:
                    _logger.error(f"[EVAL_DEBUG] Task app returned error: {data['detail']}")
                trace_namespace = data.get("trace")
                _logger.info(f"[EVAL_DEBUG] trace_namespace type: {type(trace_namespace)}, value: {trace_namespace if not isinstance(trace_namespace, dict) else 'dict with keys: ' + str(list(trace_namespace.keys()) if trace_namespace else 'None')}")
                # NOTE(review): this RuntimeError (and the one below) propagates
                # through asyncio.gather and aborts the whole eval run; the
                # success/failure counters for this seed were already updated.
                if not isinstance(trace_namespace, dict):
                    raise RuntimeError(
                        "The 'synth-ai eval' command requires trace payloads in rollout responses. "
                        "Ensure the rollout request includes 'trace_format': 'structured' and 'return_trace': true, "
                        "and that task app tracing is enabled (TASKAPP_TRACING_ENABLED=1). "
                        "Note: This is specific to the eval command - general rollout endpoints don't require traces."
                    )
                # Handle both "compact" and "full" trace formats:
                # - compact: trace_namespace contains {session_id, metadata, ...}
                # - full: trace_namespace IS the full session_trace dict
                session_trace_dict = trace_namespace.get("session_trace")
                if not isinstance(session_trace_dict, dict):
                    # If no session_trace key, assume "full" format where trace itself is the session_trace
                    if "session_id" in trace_namespace:
                        session_trace_dict = trace_namespace
                    else:
                        raise RuntimeError(
                            "The 'synth-ai eval' command requires 'session_trace' in the trace payload or a valid full trace format. "
                            "Ensure the task app is using tracing_v3 and returning structured trace data."
                        )
                metrics = data.get("metrics") if isinstance(data.get("metrics"), dict) else None
                if metrics:
                    # NOTE(review): `or` means a mean_return of 0 / 0.0 falls
                    # through to total_reward for the summary line.
                    mean_return = metrics.get("mean_return") or metrics.get("total_reward")
                    outcome = metrics.get("outcome_score")
                    if mean_return is not None:
                        summary.append(f"mean_return={mean_return}")
                    if outcome is not None:
                        summary.append(f"outcome={outcome}")
                        # Aggregate outcome stats; an outcome >= 0.5 counts as "correct".
                        # Values that cannot be converted to float are silently skipped.
                        try:
                            val = float(outcome)
                            outcome_sum += val
                            outcome_count += 1
                            if val >= 0.5:
                                outcome_correct += 1
                        except Exception:
                            pass
                trajs = (
                    data.get("trajectories")
                    if isinstance(data.get("trajectories"), list)
                    else None
                )
                if trajs:
                    # Only the first trajectory is inspected. Prompt info comes from
                    # its first step's observation; completion comes from its final
                    # observation. task_id/split/rubric_id are taken from the first
                    # source that provides a non-empty string: step obs, final obs,
                    # then final info.
                    first = trajs[0] if trajs else None
                    steps = first.get("steps") if isinstance(first, dict) else None
                    if isinstance(steps, list) and steps:
                        step0 = steps[0]
                        # NOTE(review): step0.get is called before the isinstance(step0, dict)
                        # check below, so a non-dict step raises AttributeError here.
                        tool_calls = step0.get("tool_calls") or step0.get("tools") or []
                        if isinstance(tool_calls, list):
                            summary.append(f"tool_calls={len(tool_calls)}")
                        obs = step0.get("obs") if isinstance(step0, dict) else None
                        if isinstance(obs, dict):
                            idx_val = obs.get("prompt_index")
                            if isinstance(idx_val, int):
                                prompt_index = idx_val
                            prompt_raw = obs.get("prompt")
                            if isinstance(prompt_raw, str):
                                prompt_text = prompt_raw
                            if task_id is None:
                                candidate_id = obs.get("task_id")
                                if isinstance(candidate_id, str) and candidate_id:
                                    task_id = candidate_id
                            if task_split is None:
                                candidate_split = obs.get("task_split")
                                if isinstance(candidate_split, str) and candidate_split:
                                    task_split = candidate_split
                            if task_rubric_id is None:
                                candidate_rid = obs.get("task_rubric_id")
                                if isinstance(candidate_rid, str) and candidate_rid:
                                    task_rubric_id = candidate_rid
                    final = first.get("final") if isinstance(first, dict) else None
                    if isinstance(final, dict):
                        final_obs = final.get("observation")
                        if isinstance(final_obs, dict):
                            comp_val = final_obs.get("completion")
                            if isinstance(comp_val, str):
                                completion = comp_val
                            if task_id is None:
                                candidate_id = final_obs.get("task_id")
                                if isinstance(candidate_id, str) and candidate_id:
                                    task_id = candidate_id
                            if task_split is None:
                                candidate_split = final_obs.get("task_split")
                                if isinstance(candidate_split, str) and candidate_split:
                                    task_split = candidate_split
                            if task_rubric_id is None:
                                candidate_rid = final_obs.get("task_rubric_id")
                                if isinstance(candidate_rid, str) and candidate_rid:
                                    task_rubric_id = candidate_rid
                        final_info = final.get("info")
                        if isinstance(final_info, dict):
                            if task_id is None:
                                candidate_id = final_info.get("task_id")
                                if isinstance(candidate_id, str) and candidate_id:
                                    task_id = candidate_id
                            if task_split is None:
                                candidate_split = final_info.get("task_split")
                                if isinstance(candidate_split, str) and candidate_split:
                                    task_split = candidate_split
                            if task_rubric_id is None:
                                candidate_rid = final_info.get("task_rubric_id")
                                if isinstance(candidate_rid, str) and candidate_rid:
                                    task_rubric_id = candidate_rid
                if task_id:
                    summary.append(f"task_id={task_id}")
                click.echo(" ".join(summary))
                # Dumps the full rollout response for every seed (can be very verbose).
                with contextlib.suppress(Exception):
                    click.echo(json.dumps(data, indent=2))
            else:
                click.echo(" ".join(summary))

            # Official score: the first numeric value among mean_return,
            # total_reward and outcome_score, else the first step's reward.
            official_score = None
            if isinstance(metrics, dict):
                for key in ("mean_return", "total_reward", "outcome_score"):
                    val = metrics.get(key)
                    if isinstance(val, int | float):
                        official_score = float(val)
                        break
            if official_score is None and isinstance(data, dict):
                try:
                    reward_val = data["trajectories"][0]["steps"][0].get("reward")
                    if isinstance(reward_val, int | float):
                        official_score = float(reward_val)
                except Exception:
                    pass

            # Clamp the official score to [0.0, 1.0] (min() in the elif always yields 1.0).
            if official_score is not None:
                if official_score < 0.0:
                    official_score = 0.0
                elif official_score > 1.0:
                    official_score = min(1.0, official_score)

            judge_scores: dict[str, float | None] = {}
            judges_timings: dict[str, float | None] = {}
            timings: dict[str, Any] = {
                "rollout_s": rollout_elapsed,
                "judges": judges_timings,
            }
            if judge_specs:
                for spec in judge_specs:
                    score_value: float | None = None
                    judge_elapsed: float | None = None
                    # Run judges for all tasks (text-based and trajectory-based)
                    # Text-based tasks have completion, trajectory-based tasks use response
                    judge_payload = {
                        "seed": seed_val,
                        "prompt_index": prompt_index,
                        "prompt": prompt_text,
                        "completion": completion,
                        "metrics": metrics,
                        "response": data,
                        "trace": trace_namespace,
                    }
                    # NOTE(review): spec.fn is called synchronously (not awaited),
                    # so a slow judge blocks the event loop for the other seeds.
                    # Non-numeric judge results are recorded as None.
                    try:
                        judge_start = time.perf_counter()
                        result = spec.fn(judge_payload, **spec.kwargs)
                        judge_elapsed = time.perf_counter() - judge_start
                        if isinstance(result, int | float):
                            score_value = float(result)
                    except Exception as exc:
                        if judge_elapsed is None:
                            judge_elapsed = time.perf_counter() - judge_start
                        click.echo(f"seed={seed_val} judge[{spec.name}]_error={exc}")
                    judges_timings[spec.name] = judge_elapsed
                    judge_scores[spec.name] = score_value

            # Persist the trace together with eval metadata when tracing is enabled.
            if trace_tracer is not None and trace_namespace:
                storage_metadata = {
                    "eval_seed": seed_val,
                    "prompt_index": prompt_index,
                    "task_id": task_id,
                    "task_split": task_split,
                    "task_rubric_id": task_rubric_id,
                    "official_score": official_score,
                    "judge_scores": judge_scores,
                    "model": selected_model,
                    "prompt": prompt_text,
                    "completion": completion,
                }
                await _store_trace(trace_tracer, trace_namespace, storage_metadata)

            records.append(
                {
                    "seed": seed_val,
                    "prompt_index": prompt_index,
                    "task_id": task_id,
                    "task_split": task_split,
                    "task_rubric_id": task_rubric_id,
                    "official_score": official_score,
                    "judge_scores": judge_scores,
                    "timings": timings,
                }
            )

        await asyncio.gather(*[_run_seed(seed_val) for seed_val in seed_values])
    finally:
        await async_client.aclose()
|
|
4139
|
-
|
|
4140
|
-
try:
|
|
4141
|
-
asyncio.run(_run_eval())
|
|
4142
|
-
finally:
|
|
4143
|
-
if trace_tracer is not None and trace_tracer.db is not None:
|
|
4144
|
-
asyncio.run(trace_tracer.db.close())
|
|
4145
|
-
|
|
4146
|
-
click.echo(
|
|
4147
|
-
f"Eval complete: {successes} ok, {failures} failed; model={selected_model}, split={split}"
|
|
4148
|
-
)
|
|
4149
|
-
|
|
4150
|
-
if outcome_count > 0:
|
|
4151
|
-
mean_outcome = outcome_sum / float(outcome_count)
|
|
4152
|
-
frac_right = outcome_correct / float(outcome_count)
|
|
4153
|
-
click.echo(
|
|
4154
|
-
f"Outcome summary: correct={outcome_correct}/{outcome_count} ({frac_right:.2%}), mean_outcome={mean_outcome:.3f}"
|
|
4155
|
-
)
|
|
4156
|
-
|
|
4157
|
-
if records:
|
|
4158
|
-
judge_specs = judge_specs or [] # ensure iterable
|
|
4159
|
-
official_scores = [
|
|
4160
|
-
r["official_score"] for r in records if r["official_score"] is not None
|
|
4161
|
-
]
|
|
4162
|
-
if official_scores:
|
|
4163
|
-
click.echo(f" Official mean: {sum(official_scores) / len(official_scores):.3f}")
|
|
4164
|
-
else:
|
|
4165
|
-
click.echo(" Official mean: n/a")
|
|
4166
|
-
|
|
4167
|
-
for spec in judge_specs:
|
|
4168
|
-
spec_scores = [
|
|
4169
|
-
record["judge_scores"].get(spec.name)
|
|
4170
|
-
for record in records
|
|
4171
|
-
if record["judge_scores"].get(spec.name) is not None
|
|
4172
|
-
]
|
|
4173
|
-
if spec_scores:
|
|
4174
|
-
mean_spec = sum(spec_scores) / len(spec_scores)
|
|
4175
|
-
click.echo(f" [{spec.name}] mean: {mean_spec:.3f}")
|
|
4176
|
-
else:
|
|
4177
|
-
click.echo(f" [{spec.name}] mean: n/a")
|
|
4178
|
-
|
|
4179
|
-
paired = [
|
|
4180
|
-
(
|
|
4181
|
-
record["official_score"],
|
|
4182
|
-
record["judge_scores"].get(spec.name),
|
|
4183
|
-
)
|
|
4184
|
-
for record in records
|
|
4185
|
-
if record["official_score"] is not None
|
|
4186
|
-
and record["judge_scores"].get(spec.name) is not None
|
|
4187
|
-
]
|
|
4188
|
-
if len(paired) >= 2:
|
|
4189
|
-
corr = _pearson(
|
|
4190
|
-
[p[0] for p in paired if p[0] is not None],
|
|
4191
|
-
[p[1] for p in paired if p[1] is not None],
|
|
4192
|
-
)
|
|
4193
|
-
if corr is not None:
|
|
4194
|
-
click.echo(f" Pearson r: {corr:.3f}")
|
|
4195
|
-
else:
|
|
4196
|
-
click.echo(" Pearson r: undefined (zero variance)")
|
|
4197
|
-
else:
|
|
4198
|
-
click.echo(" Pearson r: n/a (need ≥2 paired scores)")
|
|
4199
|
-
|
|
4200
|
-
header = ["Seed", "Prompt", "Official"]
|
|
4201
|
-
header.extend(spec.name for spec in judge_specs)
|
|
4202
|
-
rows: list[list[str]] = []
|
|
4203
|
-
for record in sorted(records, key=lambda r: (r["seed"], r.get("prompt_index") or -1)):
|
|
4204
|
-
seed_val = str(record["seed"])
|
|
4205
|
-
prompt_idx = (
|
|
4206
|
-
str(record["prompt_index"])
|
|
4207
|
-
if record["prompt_index"] is not None
|
|
4208
|
-
else "-"
|
|
4209
|
-
)
|
|
4210
|
-
official_val = (
|
|
4211
|
-
f"{record['official_score']:.3f}"
|
|
4212
|
-
if record["official_score"] is not None
|
|
4213
|
-
else "-"
|
|
4214
|
-
)
|
|
4215
|
-
row = [seed_val, prompt_idx, official_val]
|
|
4216
|
-
for spec in judge_specs:
|
|
4217
|
-
score_val = record["judge_scores"].get(spec.name)
|
|
4218
|
-
row.append(f"{score_val:.3f}" if isinstance(score_val, int | float) else "-")
|
|
4219
|
-
rows.append(row)
|
|
4220
|
-
|
|
4221
|
-
widths = [len(col) for col in header]
|
|
4222
|
-
for row in rows:
|
|
4223
|
-
for idx, cell in enumerate(row):
|
|
4224
|
-
widths[idx] = max(widths[idx], len(cell))
|
|
4225
|
-
|
|
4226
|
-
click.echo("")
|
|
4227
|
-
click.echo(" ".join(h.ljust(widths[idx]) for idx, h in enumerate(header)))
|
|
4228
|
-
click.echo(" ".join("-" * widths[idx] for idx in range(len(header))))
|
|
4229
|
-
for row in rows:
|
|
4230
|
-
click.echo(" ".join(cell.ljust(widths[idx]) for idx, cell in enumerate(row)))
|
|
4231
|
-
|
|
4232
|
-
|
|
4233
|
-
|
|
4234
|
-
@click.command(
|
|
4235
|
-
"filter",
|
|
4236
|
-
help="Export filtered tracing sessions to SFT-ready JSONL based on a TOML config.",
|
|
4237
|
-
)
|
|
4238
|
-
@click.option(
|
|
4239
|
-
"--config",
|
|
4240
|
-
"config_path",
|
|
4241
|
-
type=click.Path(),
|
|
4242
|
-
required=True,
|
|
4243
|
-
help="Path to TOML config describing the input trace DB, score thresholds, and output JSONL.",
|
|
4244
|
-
)
|
|
4245
|
-
def filter_command(config_path: str) -> None:
|
|
4246
|
-
"""Render tracing sessions that match filter rules into SFT JSONL.
|
|
4247
|
-
|
|
4248
|
-
The TOML file should contain a `[filter]` table with at least:
|
|
4249
|
-
|
|
4250
|
-
db = \"path/to/traces.db\" # sqlite path or URL (sqlite+aiosqlite://...)
|
|
4251
|
-
output = \"ft_data/out.jsonl\" # destination JSONL
|
|
4252
|
-
|
|
4253
|
-
Optional keys such as `splits`, `task_ids`, `models`, `min_official_score`, or
|
|
4254
|
-
`min_judge_scores.my_judge = 0.7` allow you to narrow the dataset down to
|
|
4255
|
-
high-quality traces. See `customers/agora_single_file/configs/filter_local.toml`
|
|
4256
|
-
for a working example.
|
|
4257
|
-
"""
|
|
4258
|
-
# Parse and validate TOML config
|
|
4259
|
-
from synth_ai.task.config import FilterConfig
|
|
4260
|
-
|
|
4261
|
-
if _toml is None:
|
|
4262
|
-
raise click.ClickException("TOML parser not available; install tomli or use Python 3.11+")
|
|
4263
|
-
|
|
4264
|
-
cfg_path = Path(config_path)
|
|
4265
|
-
if not cfg_path.exists():
|
|
4266
|
-
raise click.ClickException(f"Filter config not found: {cfg_path}")
|
|
4267
|
-
|
|
4268
|
-
try:
|
|
4269
|
-
config_data = _toml.loads(cfg_path.read_text(encoding="utf-8"))
|
|
4270
|
-
except Exception as exc:
|
|
4271
|
-
raise click.ClickException(f"Failed to parse TOML '{cfg_path}': {exc}") from exc
|
|
4272
|
-
|
|
4273
|
-
filter_cfg_dict = config_data.get("filter") if isinstance(config_data, dict) else None
|
|
4274
|
-
if not isinstance(filter_cfg_dict, dict):
|
|
4275
|
-
raise click.ClickException("Config must contain a [filter] table")
|
|
4276
|
-
|
|
4277
|
-
# Validate config with dataclass
|
|
4278
|
-
try:
|
|
4279
|
-
filter_cfg = FilterConfig.from_dict(filter_cfg_dict)
|
|
4280
|
-
click.echo(f"✓ Config validated: db={filter_cfg.db}, output={filter_cfg.output}")
|
|
4281
|
-
if filter_cfg.min_official_score is not None:
|
|
4282
|
-
click.echo(f" → Filtering for official score >= {filter_cfg.min_official_score}")
|
|
4283
|
-
if filter_cfg.limit:
|
|
4284
|
-
click.echo(f" → Limiting to {filter_cfg.limit} examples")
|
|
4285
|
-
except (ValueError, TypeError) as validation_error:
|
|
4286
|
-
raise click.ClickException(f"Invalid filter config: {validation_error}") from validation_error
|
|
4287
|
-
|
|
4288
|
-
# Use validated config
|
|
4289
|
-
db_url = filter_cfg.get_db_url()
|
|
4290
|
-
output_path = filter_cfg.get_output_path()
|
|
4291
|
-
|
|
4292
|
-
# Extract validated fields from dataclass
|
|
4293
|
-
splits = set(filter_cfg.splits)
|
|
4294
|
-
task_ids = set(filter_cfg.task_ids)
|
|
4295
|
-
models = set(filter_cfg.models)
|
|
4296
|
-
min_official = filter_cfg.min_official_score
|
|
4297
|
-
max_official = filter_cfg.max_official_score
|
|
4298
|
-
min_judge_scores = filter_cfg.min_judge_scores
|
|
4299
|
-
max_judge_scores = filter_cfg.max_judge_scores
|
|
4300
|
-
# Note: min_created_at and max_created_at not yet in FilterConfig dataclass
|
|
4301
|
-
min_created = _parse_datetime_for_trace(filter_cfg_dict.get("min_created_at"))
|
|
4302
|
-
max_created = _parse_datetime_for_trace(filter_cfg_dict.get("max_created_at"))
|
|
4303
|
-
limit = filter_cfg.limit
|
|
4304
|
-
|
|
4305
|
-
def _score_ok(value: Any, min_val: Any, max_val: Any) -> bool:
|
|
4306
|
-
try:
|
|
4307
|
-
if value is None:
|
|
4308
|
-
return min_val is None
|
|
4309
|
-
value = float(value)
|
|
4310
|
-
except Exception:
|
|
4311
|
-
return False
|
|
4312
|
-
if min_val is not None and value < float(min_val):
|
|
4313
|
-
return False
|
|
4314
|
-
return not (max_val is not None and value > float(max_val))
|
|
4315
|
-
|
|
4316
|
-
async def _run_filter() -> None:
|
|
4317
|
-
tracer = SessionTracer(db_url=db_url, auto_save=False)
|
|
4318
|
-
await tracer.initialize()
|
|
4319
|
-
|
|
4320
|
-
df = await tracer.db.query_traces(
|
|
4321
|
-
"SELECT session_id, created_at, metadata FROM session_traces ORDER BY created_at"
|
|
4322
|
-
)
|
|
4323
|
-
if getattr(df, "empty", True):
|
|
4324
|
-
raise click.ClickException("No traces found in database")
|
|
4325
|
-
|
|
4326
|
-
sessions = df.to_dict("records")
|
|
4327
|
-
accepted: list[dict[str, Any]] = []
|
|
4328
|
-
|
|
4329
|
-
for row in sessions:
|
|
4330
|
-
metadata_raw = row.get("metadata")
|
|
4331
|
-
if isinstance(metadata_raw, str):
|
|
4332
|
-
try:
|
|
4333
|
-
metadata = json.loads(metadata_raw)
|
|
4334
|
-
except Exception:
|
|
4335
|
-
metadata = {}
|
|
4336
|
-
elif isinstance(metadata_raw, dict):
|
|
4337
|
-
metadata = dict(metadata_raw)
|
|
4338
|
-
else:
|
|
4339
|
-
metadata = {}
|
|
4340
|
-
|
|
4341
|
-
created_at_raw = row.get("created_at")
|
|
4342
|
-
created_at_dt = _parse_datetime_for_trace(created_at_raw)
|
|
4343
|
-
|
|
4344
|
-
session_id = row.get("session_id")
|
|
4345
|
-
|
|
4346
|
-
if splits and metadata.get("task_split") not in splits:
|
|
4347
|
-
continue
|
|
4348
|
-
if task_ids and metadata.get("task_id") not in task_ids:
|
|
4349
|
-
continue
|
|
4350
|
-
if models and metadata.get("model") not in models:
|
|
4351
|
-
continue
|
|
4352
|
-
|
|
4353
|
-
if min_created and (created_at_dt is None or created_at_dt < min_created):
|
|
4354
|
-
continue
|
|
4355
|
-
if max_created and (created_at_dt is None or created_at_dt > max_created):
|
|
4356
|
-
continue
|
|
4357
|
-
|
|
4358
|
-
# Check against outcome_rewards if score filter is set
|
|
4359
|
-
total_reward = None
|
|
4360
|
-
achievements_count = None
|
|
4361
|
-
if min_official is not None or max_official is not None:
|
|
4362
|
-
reward_query = "SELECT total_reward, achievements_count FROM outcome_rewards WHERE session_id = :session_id"
|
|
4363
|
-
reward_rows = await tracer.db.query_traces(reward_query, {"session_id": session_id})
|
|
4364
|
-
reward_records = reward_rows.to_dict("records") if hasattr(reward_rows, "to_dict") else []
|
|
4365
|
-
if reward_records:
|
|
4366
|
-
total_reward = reward_records[0].get("total_reward")
|
|
4367
|
-
achievements_count = reward_records[0].get("achievements_count")
|
|
4368
|
-
if not _score_ok(total_reward, min_official, max_official):
|
|
4369
|
-
continue
|
|
4370
|
-
elif min_official is not None:
|
|
4371
|
-
# No reward found, but score filter requires it
|
|
4372
|
-
continue
|
|
4373
|
-
|
|
4374
|
-
judge_scores = metadata.get("judge_scores") or {}
|
|
4375
|
-
include = True
|
|
4376
|
-
for judge_name, threshold in (min_judge_scores or {}).items():
|
|
4377
|
-
if not _score_ok(judge_scores.get(judge_name), threshold, None):
|
|
4378
|
-
include = False
|
|
4379
|
-
break
|
|
4380
|
-
if not include:
|
|
4381
|
-
continue
|
|
4382
|
-
for judge_name, threshold in (max_judge_scores or {}).items():
|
|
4383
|
-
if not _score_ok(judge_scores.get(judge_name), None, threshold):
|
|
4384
|
-
include = False
|
|
4385
|
-
break
|
|
4386
|
-
if not include:
|
|
4387
|
-
continue
|
|
4388
|
-
|
|
4389
|
-
# Query messages for this session
|
|
4390
|
-
messages_query = """
|
|
4391
|
-
SELECT message_type, content, timestamp
|
|
4392
|
-
FROM messages
|
|
4393
|
-
WHERE session_id = :session_id
|
|
4394
|
-
ORDER BY timestamp ASC, id ASC
|
|
4395
|
-
"""
|
|
4396
|
-
msg_df = await tracer.db.query_traces(messages_query, {"session_id": session_id})
|
|
4397
|
-
message_rows = msg_df.to_dict("records") if hasattr(msg_df, "to_dict") else []
|
|
4398
|
-
|
|
4399
|
-
if not message_rows:
|
|
4400
|
-
# Fallback: check if prompt/completion in metadata (old format)
|
|
4401
|
-
prompt = metadata.get("prompt") or ""
|
|
4402
|
-
completion = metadata.get("completion") or ""
|
|
4403
|
-
if prompt and completion:
|
|
4404
|
-
record = {
|
|
4405
|
-
"messages": [
|
|
4406
|
-
{"role": "user", "content": str(prompt)},
|
|
4407
|
-
{"role": "assistant", "content": str(completion)},
|
|
4408
|
-
],
|
|
4409
|
-
"metadata": {
|
|
4410
|
-
"session_id": session_id,
|
|
4411
|
-
"env_name": metadata.get("env_name"),
|
|
4412
|
-
"policy_name": metadata.get("policy_name"),
|
|
4413
|
-
"seed": metadata.get("seed"),
|
|
4414
|
-
"total_reward": total_reward,
|
|
4415
|
-
"achievements_count": achievements_count,
|
|
4416
|
-
"model": metadata.get("model"),
|
|
4417
|
-
"created_at": created_at_dt.isoformat() if created_at_dt else created_at_raw,
|
|
4418
|
-
},
|
|
4419
|
-
}
|
|
4420
|
-
accepted.append(record)
|
|
4421
|
-
continue
|
|
4422
|
-
|
|
4423
|
-
# Extract user/assistant pairs from messages
|
|
4424
|
-
for i, msg_row in enumerate(message_rows):
|
|
4425
|
-
msg_type = msg_row.get("message_type")
|
|
4426
|
-
content_raw = msg_row.get("content")
|
|
4427
|
-
|
|
4428
|
-
# Look for user message
|
|
4429
|
-
if msg_type in ("user", "policy_user_prompt"):
|
|
4430
|
-
# Find next policy_system_prompt or assistant
|
|
4431
|
-
assistant_msg = None
|
|
4432
|
-
for j in range(i + 1, len(message_rows)):
|
|
4433
|
-
next_type = message_rows[j].get("message_type")
|
|
4434
|
-
if next_type in ("assistant", "policy_system_prompt"):
|
|
4435
|
-
if next_type == "assistant":
|
|
4436
|
-
assistant_msg = message_rows[j]
|
|
4437
|
-
break
|
|
4438
|
-
|
|
4439
|
-
# Parse content
|
|
4440
|
-
try:
|
|
4441
|
-
user_content = json.loads(content_raw) if isinstance(content_raw, str) else content_raw
|
|
4442
|
-
except Exception:
|
|
4443
|
-
user_content = content_raw
|
|
4444
|
-
|
|
4445
|
-
# Extract text from structured content
|
|
4446
|
-
def extract_text(content: Any) -> str:
|
|
4447
|
-
if isinstance(content, str):
|
|
4448
|
-
return content
|
|
4449
|
-
if isinstance(content, dict):
|
|
4450
|
-
# Try payload.content for user prompts
|
|
4451
|
-
if "payload" in content and isinstance(content["payload"], dict):
|
|
4452
|
-
payload = content["payload"]
|
|
4453
|
-
if "content" in payload:
|
|
4454
|
-
return extract_text(payload["content"])
|
|
4455
|
-
# Try common keys
|
|
4456
|
-
for key in ["text", "content", "content_text"]:
|
|
4457
|
-
if key in content:
|
|
4458
|
-
val = content[key]
|
|
4459
|
-
if isinstance(val, str):
|
|
4460
|
-
return val
|
|
4461
|
-
return json.dumps(content)
|
|
4462
|
-
if isinstance(content, list):
|
|
4463
|
-
# Multimodal content - concatenate text parts
|
|
4464
|
-
parts = []
|
|
4465
|
-
for item in content:
|
|
4466
|
-
if isinstance(item, dict) and item.get("type") == "text":
|
|
4467
|
-
parts.append(item.get("text", ""))
|
|
4468
|
-
return " ".join(parts) if parts else str(content)
|
|
4469
|
-
return str(content)
|
|
4470
|
-
|
|
4471
|
-
user_text = extract_text(user_content)
|
|
4472
|
-
|
|
4473
|
-
# For assistant, we might not have it recorded, so use tool calls as completion
|
|
4474
|
-
assistant_text = ""
|
|
4475
|
-
if assistant_msg:
|
|
4476
|
-
assistant_content_raw = assistant_msg.get("content")
|
|
4477
|
-
try:
|
|
4478
|
-
assistant_content = json.loads(assistant_content_raw) if isinstance(assistant_content_raw, str) else assistant_content_raw
|
|
4479
|
-
except Exception:
|
|
4480
|
-
assistant_content = assistant_content_raw
|
|
4481
|
-
assistant_text = extract_text(assistant_content)
|
|
4482
|
-
|
|
4483
|
-
if not user_text:
|
|
4484
|
-
continue
|
|
4485
|
-
|
|
4486
|
-
record = {
|
|
4487
|
-
"messages": [
|
|
4488
|
-
{"role": "user", "content": user_text},
|
|
4489
|
-
{"role": "assistant", "content": assistant_text if assistant_text else "[no response recorded]"},
|
|
4490
|
-
],
|
|
4491
|
-
"metadata": {
|
|
4492
|
-
"session_id": session_id,
|
|
4493
|
-
"env_name": metadata.get("env_name"),
|
|
4494
|
-
"policy_name": metadata.get("policy_name"),
|
|
4495
|
-
"seed": metadata.get("seed"),
|
|
4496
|
-
"total_reward": total_reward,
|
|
4497
|
-
"achievements_count": achievements_count,
|
|
4498
|
-
"model": metadata.get("model"),
|
|
4499
|
-
"created_at": created_at_dt.isoformat() if created_at_dt else created_at_raw,
|
|
4500
|
-
},
|
|
4501
|
-
}
|
|
4502
|
-
accepted.append(record)
|
|
4503
|
-
|
|
4504
|
-
if not accepted:
|
|
4505
|
-
raise click.ClickException("No sessions matched the provided filters")
|
|
4506
|
-
|
|
4507
|
-
if limit is not None and limit > 0:
|
|
4508
|
-
accepted = accepted[:limit]
|
|
4509
|
-
|
|
4510
|
-
output_path.parent.mkdir(parents=True, exist_ok=True)
|
|
4511
|
-
with output_path.open("w", encoding="utf-8") as handle:
|
|
4512
|
-
for item in accepted:
|
|
4513
|
-
handle.write(json.dumps(item, ensure_ascii=False))
|
|
4514
|
-
handle.write("\n")
|
|
4515
|
-
|
|
4516
|
-
click.echo(f"Wrote {len(accepted)} examples -> {output_path}")
|
|
4517
|
-
await tracer.db.close()
|
|
3128
|
+
eval_command = eval_core.command
|
|
4518
3129
|
|
|
4519
|
-
|
|
3130
|
+
filter_command = filter_core.command
|
|
4520
3131
|
|
|
4521
3132
|
|
|
4522
3133
|
def register_eval(cli: click.Group) -> None:
|