synth-ai 0.2.13.dev1__py3-none-any.whl → 0.2.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of synth-ai might be problematic. Click here for more details.
- examples/multi_step/configs/README_verilog_rl.md +77 -0
- examples/multi_step/configs/VERILOG_REWARDS.md +90 -0
- examples/multi_step/configs/VERILOG_RL_CHECKLIST.md +183 -0
- examples/multi_step/configs/crafter_eval_synth_qwen4b.toml +35 -0
- examples/multi_step/configs/crafter_eval_text_only_groq_qwen32b.toml +36 -0
- examples/multi_step/configs/crafter_rl_stepwise_hosted_judge.toml +17 -5
- examples/multi_step/configs/crafter_synth_backend.md +40 -0
- examples/multi_step/configs/verilog_eval_groq_qwen32b.toml +31 -0
- examples/multi_step/configs/verilog_eval_synth_qwen8b.toml +33 -0
- examples/multi_step/configs/verilog_rl_lora.toml +190 -0
- examples/multi_step/judges/crafter_backend_judge.py +220 -0
- examples/multi_step/judges/verilog_backend_judge.py +234 -0
- examples/multi_step/readme.md +48 -0
- examples/multi_step/verilog_rl_lora.md +218 -0
- examples/qwen_coder/configs/coder_lora_30b.toml +1 -1
- examples/sft/evaluate.py +2 -0
- examples/sft/generate_traces.py +2 -0
- examples/swe/task_app/grpo_swe_mini.py +56 -26
- examples/swe/task_app/hosted/rollout.py +42 -0
- examples/swe/task_app/hosted/test_service.py +5 -6
- examples/task_apps/IMAGE_ONLY_EVAL_QUICKSTART.md +258 -0
- examples/task_apps/TESTING.md +275 -0
- examples/task_apps/__init__.py +0 -0
- examples/task_apps/crafter/CREATE_SFT_DATASET.md +273 -0
- examples/task_apps/crafter/EVAL_IMAGE_ONLY_RESULTS.md +152 -0
- examples/task_apps/crafter/FILTER_COMMAND_STATUS.md +174 -0
- examples/task_apps/crafter/FILTER_COMMAND_SUCCESS.md +268 -0
- examples/task_apps/crafter/QUERY_EXAMPLES.md +203 -0
- examples/task_apps/crafter/README_IMAGE_ONLY_EVAL.md +316 -0
- examples/task_apps/crafter/__init__.py +0 -0
- examples/task_apps/crafter/eval_image_only_gpt4o.toml +28 -0
- examples/task_apps/crafter/eval_text_only_groq_llama.toml +36 -0
- examples/task_apps/crafter/filter_sft_dataset.toml +16 -0
- examples/task_apps/crafter/task_app/__init__.py +5 -0
- examples/{warming_up_to_rl → task_apps/crafter}/task_app/grpo_crafter.py +324 -21
- examples/{warming_up_to_rl → task_apps/crafter}/task_app/grpo_crafter_task_app.py +1 -1
- examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/environment.py +10 -0
- examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/policy.py +76 -7
- examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/react_agent.py +17 -2
- examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/inference/openai_client.py +25 -3
- examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/policy_routes.py +77 -4
- examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/rollout.py +117 -9
- examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/test_service.py +5 -6
- examples/task_apps/crafter/task_app/synth_envs_hosted/utils.py +218 -0
- examples/task_apps/dev/pokemon_emerald/__init__.py +2 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/README.md +811 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/__init__.py +120 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/action.py +160 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/memory.py +155 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/perception.py +69 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/planning.py +96 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/simple.py +1502 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/system_prompt.py +4 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/grab_map.py +68 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/manual.py +216 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/__init__.py +35 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/emerald_utils.py +631 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/emulator.py +1544 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/enums.py +1428 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/memory_reader.py +4848 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/types.py +41 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/utils.py +298 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pyproject.toml +95 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/run.py +204 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/server/__init__.py +0 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/server/app.py +2152 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/server/client.py +429 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/server/frame_server.py +155 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/README.md +78 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/__init__.py +0 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/run_tests.py +122 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_agent_direct.py +76 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_agent_prompts.py +413 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_battle_state_formatting.py +204 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_dialogue_detection.py +133 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_dialogue_detection_comprehensive.py +229 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_direct_agent_emulator.py +300 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_fps_adjustment_pytest.py +205 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_house_to_outside_direct.py +200 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_house_to_outside_transition.py +284 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_map_ground_truth_comparison.py +468 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_memory_map.py +575 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_server_map_validation.py +311 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_torchic_state.py +259 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/__init__.py +0 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/anticheat.py +372 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/checkpoint.py +296 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/error_handler.py +275 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/get_local_ip.py +22 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/helpers.py +44 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/llm_logger.py +514 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/map_formatter.py +415 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/map_stitcher.py +1763 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/map_stitcher_singleton.py +33 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/map_trimmer.py +106 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/map_visualizer.py +334 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/ocr_dialogue.py +1020 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/recording.py +188 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/state_formatter.py +1481 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/vlm.py +862 -0
- examples/task_apps/dev/pokemon_emerald/modal_app.py +114 -0
- examples/task_apps/dev/pokemon_emerald/task_app/README.md +81 -0
- examples/task_apps/dev/pokemon_emerald/task_app/__init__.py +6 -0
- examples/task_apps/dev/pokemon_emerald/task_app/pokemon_emerald.py +685 -0
- examples/task_apps/enron/__init__.py +1 -0
- examples/task_apps/enron/eval_groq_qwen32.toml +16 -0
- examples/task_apps/enron/filter_sft.toml +5 -0
- examples/task_apps/enron/task_app/README.md +14 -0
- examples/task_apps/enron/task_app/__init__.py +1 -0
- examples/task_apps/enron/task_app/grpo_enron.py +906 -0
- examples/task_apps/enron/task_app/grpo_enron_task_app.py +146 -0
- examples/task_apps/enron/tests/__init__.py +4 -0
- examples/task_apps/enron/tests/conftest.py +115 -0
- examples/task_apps/enron/tests/integration/__init__.py +4 -0
- examples/task_apps/enron/tests/integration/test_enron_eval.py +179 -0
- examples/task_apps/enron/tests/integration/test_enron_rollout.py +135 -0
- examples/task_apps/enron/tests/unit/__init__.py +4 -0
- examples/task_apps/enron/tests/unit/test_enron_environment.py +126 -0
- examples/task_apps/math/__init__.py +0 -0
- examples/{rl/task_app → task_apps/math}/math_single_step.py +19 -10
- examples/task_apps/pokemon_battle/__init__.py +2 -0
- examples/task_apps/pokemon_battle/modal_app.py +104 -0
- examples/task_apps/pokemon_battle/task_app/README.md +68 -0
- examples/task_apps/pokemon_battle/task_app/__init__.py +6 -0
- examples/task_apps/pokemon_battle/task_app/pokemon_showdown.py +932 -0
- examples/task_apps/pokemon_red/EVAL_IMAGE_ONLY_COMPLETE.md +283 -0
- examples/task_apps/pokemon_red/EVAL_IMAGE_ONLY_STATUS.md +155 -0
- examples/task_apps/pokemon_red/README.md +357 -0
- examples/task_apps/pokemon_red/README_IMAGE_ONLY_EVAL.md +415 -0
- examples/task_apps/pokemon_red/__init__.py +3 -0
- examples/task_apps/pokemon_red/eval_image_only_gpt4o.toml +29 -0
- examples/task_apps/pokemon_red/eval_pokemon_red_policy.py +225 -0
- examples/task_apps/pokemon_red/pallet_town_rl_config.toml +75 -0
- examples/task_apps/pokemon_red/task_app.py +799 -0
- examples/task_apps/pokemon_red/test_pallet_town_rewards.py +193 -0
- examples/task_apps/sokoban/README.md +307 -0
- examples/task_apps/sokoban/__init__.py +3 -0
- examples/task_apps/sokoban/eval_groq_qwen32.toml +16 -0
- examples/task_apps/sokoban/eval_openai_gpt5.toml +16 -0
- examples/task_apps/sokoban/filter_sft.toml +5 -0
- examples/task_apps/sokoban/task_app.py +1058 -0
- examples/task_apps/sokoban/tests/__init__.py +4 -0
- examples/task_apps/sokoban/tests/conftest.py +113 -0
- examples/task_apps/sokoban/tests/integration/__init__.py +4 -0
- examples/task_apps/sokoban/tests/integration/test_sokoban_eval.py +57 -0
- examples/task_apps/sokoban/tests/integration/test_sokoban_rollout.py +198 -0
- examples/task_apps/sokoban/tests/unit/__init__.py +4 -0
- examples/task_apps/sokoban/tests/unit/test_sokoban_environment.py +114 -0
- examples/task_apps/verilog/__init__.py +1 -0
- examples/task_apps/verilog/eval_groq_qwen32b.toml +24 -0
- examples/task_apps/verilog/filter_sft.toml +5 -0
- examples/task_apps/verilog/task_app/README.md +12 -0
- examples/task_apps/verilog/task_app/__init__.py +1 -0
- examples/task_apps/verilog/task_app/grpo_verilog.py +1166 -0
- examples/task_apps/verilog/task_app/grpo_verilog_task_app.py +145 -0
- examples/task_apps/verilog/tests/__init__.py +4 -0
- examples/task_apps/verilog/tests/conftest.py +115 -0
- examples/task_apps/verilog/tests/integration/__init__.py +4 -0
- examples/task_apps/verilog/tests/integration/test_verilog_eval.py +181 -0
- examples/task_apps/verilog/tests/integration/test_verilog_rollout.py +55 -0
- examples/task_apps/verilog/tests/unit/__init__.py +4 -0
- examples/task_apps/verilog/tests/unit/test_verilog_scoring.py +118 -0
- examples/vlm/crafter_openai_vlm_agent.py +4 -4
- examples/vlm/run_crafter_vlm_benchmark.py +4 -4
- examples/warming_up_to_rl/groq_test.py +2 -0
- examples/warming_up_to_rl/run_local_rollout.py +2 -0
- examples/warming_up_to_rl/run_local_rollout_modal.py +2 -0
- examples/warming_up_to_rl/run_local_rollout_parallel.py +2 -0
- examples/warming_up_to_rl/run_local_rollout_traced.py +2 -0
- examples/warming_up_to_rl/run_rollout_remote.py +2 -0
- examples/workflows/__init__.py +0 -0
- examples/workflows/math_rl/__init__.py +0 -0
- examples/workflows/math_rl/download_dataset.py +80 -0
- synth_ai/__init__.py +2 -2
- synth_ai/api/models/supported.py +1 -0
- synth_ai/api/train/builders.py +25 -11
- synth_ai/api/train/cli.py +12 -6
- synth_ai/api/train/configs/__init__.py +10 -10
- synth_ai/api/train/configs/rl.py +5 -4
- synth_ai/api/train/configs/sft.py +4 -3
- synth_ai/api/train/env_resolver.py +5 -2
- synth_ai/api/train/supported_algos.py +10 -5
- synth_ai/api/train/utils.py +7 -4
- synth_ai/cli/__init__.py +48 -59
- synth_ai/cli/_modal_wrapper.py +3 -2
- synth_ai/cli/_storage.py +4 -3
- synth_ai/cli/_validate_task_app.py +11 -0
- synth_ai/cli/balance.py +4 -3
- synth_ai/cli/calc.py +2 -2
- synth_ai/cli/demo.py +14 -7
- synth_ai/cli/legacy_root_backup.py +1 -1
- synth_ai/cli/recent.py +1 -1
- synth_ai/cli/rl_demo.py +8 -7
- synth_ai/cli/root.py +0 -97
- synth_ai/cli/status.py +1 -1
- synth_ai/cli/task_apps.py +1922 -190
- synth_ai/cli/traces.py +1 -1
- synth_ai/cli/tui.py +57 -0
- synth_ai/cli/turso.py +1 -1
- synth_ai/cli/watch.py +1 -1
- synth_ai/demos/demo_task_apps/crafter/grpo_crafter_task_app.py +29 -17
- synth_ai/environments/examples/crafter_classic/environment.py +1 -1
- synth_ai/environments/examples/enron/engine.py +7 -2
- synth_ai/environments/examples/enron/environment.py +68 -0
- synth_ai/environments/examples/red/engine.py +27 -0
- synth_ai/environments/examples/red/engine_helpers/memory_map.py +7 -0
- synth_ai/environments/examples/red/engine_helpers/reward_library/pallet_town_progression.py +477 -0
- synth_ai/environments/examples/red/engine_helpers/state_extraction.py +32 -0
- synth_ai/environments/examples/red/environment.py +60 -0
- synth_ai/environments/examples/sokoban/taskset.py +116 -0
- synth_ai/environments/examples/verilog/engine.py +104 -12
- synth_ai/evals/client.py +58 -61
- synth_ai/jobs/client.py +16 -4
- synth_ai/judge_schemas.py +9 -9
- synth_ai/py.typed +0 -0
- synth_ai/task/__init__.py +24 -5
- synth_ai/task/apps/__init__.py +1 -0
- synth_ai/task/config.py +257 -0
- synth_ai/task/contracts.py +138 -39
- synth_ai/task/proxy.py +48 -56
- synth_ai/task/rubrics/__init__.py +56 -0
- synth_ai/task/rubrics/loaders.py +152 -0
- synth_ai/task/rubrics/models.py +57 -0
- synth_ai/task/rubrics/scoring.py +116 -0
- synth_ai/{rubrics/validators.py → task/rubrics/strict.py} +53 -30
- synth_ai/task/server.py +8 -7
- synth_ai/task/trace_correlation_helpers.py +315 -0
- synth_ai/task/validators.py +413 -6
- synth_ai/tracing_v3/abstractions.py +3 -3
- synth_ai/tracing_v3/decorators.py +7 -3
- synth_ai/tracing_v3/llm_call_record_helpers.py +5 -5
- synth_ai/tracing_v3/replica_sync.py +4 -4
- synth_ai/tracing_v3/serialization.py +5 -5
- synth_ai/tracing_v3/session_tracer.py +16 -6
- synth_ai/tracing_v3/storage/base.py +29 -29
- synth_ai/tracing_v3/storage/config.py +3 -3
- synth_ai/tracing_v3/trace_utils.py +317 -0
- synth_ai/tracing_v3/turso/daemon.py +8 -7
- synth_ai/tracing_v3/turso/native_manager.py +66 -43
- synth_ai/tracing_v3/utils.py +3 -3
- synth_ai/tui/__init__.py +5 -0
- synth_ai/tui/__main__.py +13 -0
- synth_ai/tui/cli/__init__.py +1 -0
- synth_ai/tui/cli/query_experiments.py +164 -0
- synth_ai/tui/cli/query_experiments_v3.py +164 -0
- synth_ai/tui/dashboard.py +906 -0
- {synth_ai-0.2.13.dev1.dist-info → synth_ai-0.2.14.dist-info}/METADATA +4 -1
- {synth_ai-0.2.13.dev1.dist-info → synth_ai-0.2.14.dist-info}/RECORD +278 -126
- examples/agora_ex/README_MoE.md +0 -224
- examples/agora_ex/__init__.py +0 -7
- examples/agora_ex/agora_ex.py +0 -65
- examples/agora_ex/agora_ex_task_app.py +0 -590
- examples/agora_ex/configs/rl_lora_qwen3_moe_2xh200.toml +0 -121
- examples/agora_ex/reward_fn_grpo-human.py +0 -129
- examples/agora_ex/system_prompt_CURRENT.md +0 -63
- examples/agora_ex/task_app/agora_ex_task_app.py +0 -590
- examples/agora_ex/task_app/reward_fn_grpo-human.py +0 -129
- examples/agora_ex/task_app/system_prompt_CURRENT.md +0 -63
- examples/warming_up_to_rl/task_app/synth_envs_hosted/utils.py +0 -62
- synth_ai/rubrics/__init__.py +0 -22
- synth_ai/task/rubrics.py +0 -219
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/README.md +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/README.md +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/__init__.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/branching.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/environment_routes.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/__init__.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/__init__.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/app.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/shared.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/envs/crafter/tools.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/hosted_app.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/inference/__init__.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/main.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/registry.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/storage/__init__.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/storage/volume.py +0 -0
- /examples/{warming_up_to_rl → task_apps/crafter}/task_app/synth_envs_hosted/test_agents.py +0 -0
- /examples/{rl/task_app → task_apps/math}/README.md +0 -0
- /examples/{rl/task_app → task_apps/math}/math_task_app.py +0 -0
- /examples/{rl → workflows/math_rl}/configs/eval_base_qwen.toml +0 -0
- /examples/{rl → workflows/math_rl}/configs/eval_rl_qwen.toml +0 -0
- /examples/{rl → workflows/math_rl}/configs/rl_from_base_qwen.toml +0 -0
- /examples/{rl → workflows/math_rl}/configs/rl_from_base_qwen17.toml +0 -0
- /examples/{rl → workflows/math_rl}/configs/rl_from_ft_qwen.toml +0 -0
- /examples/{rl → workflows/math_rl}/run_eval.py +0 -0
- /examples/{rl → workflows/math_rl}/run_rl_and_save.py +0 -0
- {synth_ai-0.2.13.dev1.dist-info → synth_ai-0.2.14.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.13.dev1.dist-info → synth_ai-0.2.14.dist-info}/entry_points.txt +0 -0
- {synth_ai-0.2.13.dev1.dist-info → synth_ai-0.2.14.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.2.13.dev1.dist-info → synth_ai-0.2.14.dist-info}/top_level.txt +0 -0
synth_ai/cli/task_apps.py
CHANGED
|
@@ -1,48 +1,80 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
+
import argparse
|
|
3
4
|
import ast
|
|
4
5
|
import asyncio
|
|
5
6
|
import contextlib
|
|
7
|
+
import functools
|
|
6
8
|
import hashlib
|
|
7
9
|
import importlib
|
|
8
10
|
import importlib.util
|
|
9
11
|
import inspect
|
|
10
12
|
import json
|
|
11
13
|
import os
|
|
14
|
+
import shlex
|
|
12
15
|
import shutil
|
|
13
16
|
import signal
|
|
17
|
+
import sqlite3
|
|
14
18
|
import subprocess
|
|
15
19
|
import sys
|
|
16
20
|
import tempfile
|
|
17
21
|
import textwrap
|
|
22
|
+
import time
|
|
18
23
|
import types
|
|
24
|
+
import uuid
|
|
19
25
|
from collections.abc import Callable, Iterable, Iterator, Sequence
|
|
20
26
|
from dataclasses import dataclass
|
|
27
|
+
from datetime import datetime, timezone
|
|
21
28
|
from pathlib import Path
|
|
22
|
-
from typing import Any, cast
|
|
29
|
+
from typing import Any, Optional, cast
|
|
23
30
|
|
|
24
31
|
try: # Python 3.11+
|
|
25
32
|
import tomllib as _toml
|
|
26
33
|
except Exception: # pragma: no cover - fallback
|
|
27
34
|
_toml = None # type: ignore
|
|
28
|
-
import uuid
|
|
29
35
|
|
|
30
36
|
import click
|
|
31
37
|
from click.exceptions import Abort
|
|
32
38
|
|
|
39
|
+
# Tracing imports - make conditional for optional dependencies
|
|
40
|
+
try:
|
|
41
|
+
from synth_ai.tracing_v3 import ( # type: ignore[import-untyped]
|
|
42
|
+
BaseEvent,
|
|
43
|
+
EnvironmentEvent,
|
|
44
|
+
RuntimeEvent,
|
|
45
|
+
SessionEventMarkovBlanketMessage,
|
|
46
|
+
SessionMessageContent,
|
|
47
|
+
SessionTimeStep,
|
|
48
|
+
SessionTracer,
|
|
49
|
+
TimeRecord,
|
|
50
|
+
)
|
|
51
|
+
from synth_ai.tracing_v3 import ( # type: ignore[import-untyped]
|
|
52
|
+
SessionTrace as V3SessionTrace,
|
|
53
|
+
)
|
|
54
|
+
_TRACING_AVAILABLE = True
|
|
55
|
+
except (ImportError, ModuleNotFoundError, TypeError):
|
|
56
|
+
# Tracing system not available (missing optional dependencies)
|
|
57
|
+
BaseEvent = EnvironmentEvent = RuntimeEvent = None # type: ignore
|
|
58
|
+
SessionEventMarkovBlanketMessage = SessionMessageContent = None # type: ignore
|
|
59
|
+
SessionTimeStep = SessionTracer = TimeRecord = None # type: ignore
|
|
60
|
+
V3SessionTrace = None # type: ignore
|
|
61
|
+
_TRACING_AVAILABLE = False
|
|
62
|
+
|
|
33
63
|
# ---------------------------------------------------------------------------
|
|
34
64
|
# Dynamic imports to avoid hard dependencies during type checking.
|
|
35
65
|
# ---------------------------------------------------------------------------
|
|
36
66
|
ModalDeploymentConfigType = TaskAppConfigType = TaskAppEntryType = Any
|
|
37
67
|
|
|
38
68
|
try: # Resolve base URL defaults lazily
|
|
39
|
-
_config_module =
|
|
69
|
+
_config_module = cast(
|
|
70
|
+
Any, importlib.import_module("synth_ai.config.base_url")
|
|
71
|
+
)
|
|
40
72
|
PROD_BASE_URL_DEFAULT = cast(str, _config_module.PROD_BASE_URL_DEFAULT)
|
|
41
73
|
except Exception: # pragma: no cover - fallback
|
|
42
74
|
PROD_BASE_URL_DEFAULT = "https://agent-learning.onrender.com"
|
|
43
75
|
|
|
44
76
|
try:
|
|
45
|
-
_task_apps_module = importlib.import_module("synth_ai.task.apps")
|
|
77
|
+
_task_apps_module = cast(Any, importlib.import_module("synth_ai.task.apps"))
|
|
46
78
|
ModalDeploymentConfig = cast(
|
|
47
79
|
type[ModalDeploymentConfigType], _task_apps_module.ModalDeploymentConfig
|
|
48
80
|
)
|
|
@@ -53,21 +85,23 @@ except Exception as exc: # pragma: no cover - critical dependency
|
|
|
53
85
|
raise RuntimeError("Unable to load task app registry") from exc
|
|
54
86
|
|
|
55
87
|
try:
|
|
56
|
-
_task_server_module = importlib.import_module("synth_ai.task.server")
|
|
57
|
-
create_task_app = _task_server_module.create_task_app
|
|
58
|
-
run_task_app = _task_server_module.run_task_app
|
|
88
|
+
_task_server_module = cast(Any, importlib.import_module("synth_ai.task.server"))
|
|
89
|
+
create_task_app = cast(Callable[..., Any], _task_server_module.create_task_app)
|
|
90
|
+
run_task_app = cast(Callable[..., Any], _task_server_module.run_task_app)
|
|
59
91
|
except Exception as exc: # pragma: no cover - critical dependency
|
|
60
92
|
raise RuntimeError("Unable to load task app server utilities") from exc
|
|
61
93
|
|
|
62
94
|
|
|
63
|
-
def _load_demo_directory() -> Path
|
|
95
|
+
def _load_demo_directory() -> Optional[Path]:
|
|
64
96
|
"""Return the demo task apps directory if available."""
|
|
65
97
|
|
|
66
98
|
try:
|
|
67
|
-
module =
|
|
68
|
-
|
|
99
|
+
module = cast(
|
|
100
|
+
Any, importlib.import_module("synth_ai.demos.demo_task_apps.core")
|
|
101
|
+
)
|
|
102
|
+
loader = cast(Callable[[], Optional[str | Path]], module.load_demo_dir)
|
|
69
103
|
demo_dir = loader()
|
|
70
|
-
if isinstance(demo_dir,
|
|
104
|
+
if isinstance(demo_dir, str | Path):
|
|
71
105
|
demo_path = Path(demo_dir)
|
|
72
106
|
if demo_path.exists():
|
|
73
107
|
return demo_path.resolve()
|
|
@@ -105,13 +139,32 @@ DEFAULT_SEARCH_RELATIVE = (
|
|
|
105
139
|
)
|
|
106
140
|
|
|
107
141
|
|
|
142
|
+
def _pearson(xs: Sequence[float], ys: Sequence[float]) -> Optional[float]:
|
|
143
|
+
if len(xs) != len(ys) or len(xs) < 2:
|
|
144
|
+
return None
|
|
145
|
+
mean_x = sum(xs) / len(xs)
|
|
146
|
+
mean_y = sum(ys) / len(ys)
|
|
147
|
+
num = 0.0
|
|
148
|
+
denom_x = 0.0
|
|
149
|
+
denom_y = 0.0
|
|
150
|
+
for x, y in zip(xs, ys, strict=False):
|
|
151
|
+
dx = x - mean_x
|
|
152
|
+
dy = y - mean_y
|
|
153
|
+
num += dx * dy
|
|
154
|
+
denom_x += dx * dx
|
|
155
|
+
denom_y += dy * dy
|
|
156
|
+
if denom_x <= 0 or denom_y <= 0:
|
|
157
|
+
return None
|
|
158
|
+
return num / (denom_x ** 0.5 * denom_y ** 0.5)
|
|
159
|
+
|
|
160
|
+
|
|
108
161
|
@dataclass
|
|
109
162
|
class AppChoice:
|
|
110
163
|
app_id: str
|
|
111
164
|
label: str
|
|
112
165
|
path: Path
|
|
113
166
|
source: str
|
|
114
|
-
description: str
|
|
167
|
+
description: Optional[str] = None
|
|
115
168
|
aliases: tuple[str, ...] = ()
|
|
116
169
|
entry: TaskAppEntryType | None = None
|
|
117
170
|
entry_loader: Callable[[], TaskAppEntryType] | None = None
|
|
@@ -128,6 +181,193 @@ class AppChoice:
|
|
|
128
181
|
return entry
|
|
129
182
|
|
|
130
183
|
|
|
184
|
+
@dataclass
class JudgeSpec:
    """Plain record describing one judge: its name, callable, and kwargs.

    The class only stores these values; it never invokes ``fn`` itself.
    """

    # Identifier used to refer to this judge.
    name: str
    # The judge callable.
    fn: Callable[..., Any]
    # Keyword arguments kept alongside the judge (presumably forwarded to
    # ``fn`` by callers -- NOTE(review): confirm at the call site).
    kwargs: dict[str, Any]
|
|
189
|
+
|
|
190
|
+
|
|
191
|
+
def _parse_datetime_for_trace(value: Any) -> Optional[datetime]:
|
|
192
|
+
if isinstance(value, datetime):
|
|
193
|
+
return value if value.tzinfo else value.replace(tzinfo=timezone.utc)
|
|
194
|
+
if isinstance(value, str):
|
|
195
|
+
value = value.replace("Z", "+00:00")
|
|
196
|
+
try:
|
|
197
|
+
dt = datetime.fromisoformat(value)
|
|
198
|
+
except ValueError:
|
|
199
|
+
try:
|
|
200
|
+
dt = datetime.fromtimestamp(float(value), tz=timezone.utc)
|
|
201
|
+
except Exception:
|
|
202
|
+
return None
|
|
203
|
+
return dt if dt.tzinfo else dt.replace(tzinfo=timezone.utc)
|
|
204
|
+
if isinstance(value, int | float):
|
|
205
|
+
return datetime.fromtimestamp(float(value), tz=timezone.utc)
|
|
206
|
+
return None
|
|
207
|
+
|
|
208
|
+
|
|
209
|
+
def _time_record_from_dict(payload: dict[str, Any] | None) -> TimeRecord:
    """Build a ``TimeRecord`` from a serialized ``time_record`` mapping.

    ``event_time`` keeps int/float values as-is. Anything else is coerced
    with ``float()``, falling back to the current wall clock (``time.time()``)
    when that fails or the key is missing.

    ``message_time`` is coerced with ``int()`` and becomes ``None`` when it is
    missing or unparseable.

    A missing or non-dict payload is treated as an empty mapping. Previously
    a non-empty list or string raised ``AttributeError`` on ``.get``.
    """
    if not isinstance(payload, dict):
        payload = {}

    event_time = payload.get("event_time")
    if not isinstance(event_time, int | float):
        try:
            event_time = float(event_time)
        except (TypeError, ValueError, OverflowError):
            event_time = float(time.time())

    message_time = payload.get("message_time")
    if message_time is not None:
        try:
            message_time = int(message_time)
        except (TypeError, ValueError, OverflowError):
            # e.g. non-numeric strings or NaN/inf floats.
            message_time = None

    return TimeRecord(event_time=event_time, message_time=message_time)
|
|
224
|
+
|
|
225
|
+
|
|
226
|
+
def _event_from_dict(payload: dict[str, Any]) -> BaseEvent:
    """Rehydrate a serialized tracing event into its event class.

    Dispatch is by key presence, checked in this order:
      1. an ``actions`` key yields a ``RuntimeEvent``;
      2. any of ``reward`` / ``terminated`` / ``truncated`` yields an
         ``EnvironmentEvent``;
      3. anything else becomes a plain ``BaseEvent``.

    All three share the same base fields: instance id, time record,
    metadata, and event metadata.
    """
    shared = dict(
        system_instance_id=payload.get("system_instance_id", ""),
        time_record=_time_record_from_dict(payload.get("time_record")),
        metadata=payload.get("metadata") or {},
        event_metadata=payload.get("event_metadata"),
    )

    if "actions" in payload:
        return RuntimeEvent(actions=payload.get("actions") or [], **shared)

    environment_keys = ("reward", "terminated", "truncated")
    if not any(marker in payload for marker in environment_keys):
        return BaseEvent(**shared)

    # A null/absent reward is recorded as 0.0; absent flags default to False.
    return EnvironmentEvent(
        reward=float(payload.get("reward", 0.0) or 0.0),
        terminated=bool(payload.get("terminated", False)),
        truncated=bool(payload.get("truncated", False)),
        system_state_before=payload.get("system_state_before"),
        system_state_after=payload.get("system_state_after"),
        **shared,
    )
|
|
245
|
+
|
|
246
|
+
|
|
247
|
+
def _markov_message_from_dict(payload: dict[str, Any]) -> SessionEventMarkovBlanketMessage:
    """Rehydrate a serialized markov-blanket message.

    ``content`` may be a mapping with ``text`` / ``json_payload`` keys or a
    bare string, which is treated as ``text``. Any other value yields empty
    content. Previously a non-empty string raised ``AttributeError``.

    ``message_type`` is normalized case-insensitively:
      * the legacy ``observation`` maps to ``system``;
      * ``action`` maps to ``assistant``;
      * ``user`` / ``assistant`` / ``system`` / ``tool_use`` / ``tool_result``
        pass through unchanged;
      * anything else, including a missing or non-string type, becomes
        ``system``.
    """
    raw_content = payload.get("content")
    if isinstance(raw_content, str):
        content = SessionMessageContent(text=raw_content, json_payload=None)
    else:
        content_payload = raw_content if isinstance(raw_content, dict) else {}
        content = SessionMessageContent(
            text=content_payload.get("text"),
            json_payload=content_payload.get("json_payload"),
        )

    # str() guards against non-string types (e.g. ints) that lack .lower().
    raw_type = str(payload.get("message_type") or "").lower()
    legacy_aliases = {"observation": "system", "action": "assistant"}
    known_types = {"user", "assistant", "system", "tool_use", "tool_result"}
    normalized_type = legacy_aliases.get(
        raw_type, raw_type if raw_type in known_types else "system"
    )

    return SessionEventMarkovBlanketMessage(
        content=content,
        message_type=normalized_type,
        time_record=_time_record_from_dict(payload.get("time_record")),
        metadata=payload.get("metadata") or {},
    )
|
|
269
|
+
|
|
270
|
+
|
|
271
|
+
def _step_from_dict(payload: dict[str, Any]) -> SessionTimeStep:
    """Rehydrate a serialized session time step.

    Non-dict entries in ``events`` / ``markov_blanket_messages`` are skipped.
    An explicit ``null`` for either list is treated as empty; previously it
    raised ``TypeError`` when iterated.

    A missing or unparseable ``timestamp`` falls back to the current UTC
    time, while ``completed_at`` stays ``None``.
    """
    events = [
        _event_from_dict(event)
        for event in (payload.get("events") or [])
        if isinstance(event, dict)
    ]
    messages = [
        _markov_message_from_dict(msg)
        for msg in (payload.get("markov_blanket_messages") or [])
        if isinstance(msg, dict)
    ]
    timestamp = _parse_datetime_for_trace(payload.get("timestamp")) or datetime.now(timezone.utc)
    completed_at = _parse_datetime_for_trace(payload.get("completed_at"))
    return SessionTimeStep(
        step_id=payload.get("step_id", ""),
        step_index=int(payload.get("step_index", 0) or 0),
        timestamp=timestamp,
        turn_number=payload.get("turn_number"),
        events=events,
        markov_blanket_messages=messages,
        step_metadata=payload.get("step_metadata") or {},
        completed_at=completed_at,
    )
|
|
294
|
+
|
|
295
|
+
|
|
296
|
+
def _session_trace_from_dict(payload: dict[str, Any]) -> Optional[V3SessionTrace]:
    """Rehydrate a serialized v3 session trace, or return ``None`` for non-dicts.

    Non-dict entries in ``session_time_steps``, ``event_history`` and
    ``markov_blanket_message_history`` are skipped. An explicit ``null`` for
    any of these lists is treated as empty; previously it raised
    ``TypeError`` when iterated.

    A missing or unparseable ``created_at`` falls back to the current UTC
    time.
    """
    if not isinstance(payload, dict):
        return None
    steps = [
        _step_from_dict(step)
        for step in (payload.get("session_time_steps") or [])
        if isinstance(step, dict)
    ]
    events = [
        _event_from_dict(event)
        for event in (payload.get("event_history") or [])
        if isinstance(event, dict)
    ]
    markov_history = [
        _markov_message_from_dict(msg)
        for msg in (payload.get("markov_blanket_message_history") or [])
        if isinstance(msg, dict)
    ]
    created_at = _parse_datetime_for_trace(payload.get("created_at")) or datetime.now(timezone.utc)
    metadata = payload.get("metadata") or {}
    session_metadata = payload.get("session_metadata")
    return V3SessionTrace(
        session_id=payload.get("session_id", ""),
        created_at=created_at,
        session_time_steps=steps,
        event_history=events,
        markov_blanket_message_history=markov_history,
        metadata=metadata,
        session_metadata=session_metadata,
    )
|
|
326
|
+
|
|
327
|
+
|
|
328
|
+
async def _store_trace(
    tracer: SessionTracer | None,
    trace_namespace: dict[str, Any] | None,
    extra_metadata: dict[str, Any] | None = None,
):
    """Persist the session trace embedded in *trace_namespace* through *tracer*.

    Args:
        tracer: Tracer whose ``db`` receives the trace. When ``None`` nothing
            is stored.
        trace_namespace: Mapping expected to hold the serialized trace under
            ``"session_trace"``. Any non-dict value is ignored.
        extra_metadata: Optional entries merged over the trace's own
            ``metadata`` before insertion; these win on key collisions.

    The tracer is initialized first if ``tracer.db`` is ``None``. Missing or
    malformed trace input is logged and skipped rather than raised.
    """
    import logging

    _logger = logging.getLogger(__name__)

    _logger.debug(
        "[STORE_TRACE_DEBUG] Called with tracer=%s, trace_namespace=%s",
        tracer is not None,
        trace_namespace is not None,
    )

    if tracer is None or not isinstance(trace_namespace, dict):
        # Tracing disabled or no namespace for this rollout: routine, not a warning.
        _logger.debug(
            "[STORE_TRACE_DEBUG] Early return: tracer=%s, trace_namespace type=%s",
            tracer is not None,
            type(trace_namespace),
        )
        return

    _logger.debug("[STORE_TRACE_DEBUG] trace_namespace keys: %s", list(trace_namespace.keys()))

    session_payload = trace_namespace.get("session_trace")
    if not isinstance(session_payload, dict):
        _logger.warning(
            "[STORE_TRACE_DEBUG] No session_trace found or wrong type: %s",
            type(session_payload),
        )
        return

    _logger.debug("[STORE_TRACE_DEBUG] session_payload keys: %s", list(session_payload.keys()))
    history = session_payload.get("markov_blanket_message_history")
    # The field may be ``null`` in serialized payloads; len(None) would raise.
    msg_count = len(history) if isinstance(history, (list, tuple)) else 0
    _logger.debug("[STORE_TRACE_DEBUG] Found %d messages in session_payload", msg_count)

    trace_obj = _session_trace_from_dict(session_payload)
    if trace_obj is None:
        _logger.warning("[STORE_TRACE_DEBUG] _session_trace_from_dict returned None")
        return

    _logger.debug(
        "[STORE_TRACE_DEBUG] Created SessionTrace object with %d messages",
        len(trace_obj.markov_blanket_message_history),
    )

    if tracer.db is None:
        await tracer.initialize()

    # Caller-supplied metadata overrides what the serialized trace carried.
    meta = dict(trace_obj.metadata or {})
    if extra_metadata:
        meta.update(extra_metadata)
    trace_obj.metadata = meta

    _logger.debug(
        "[STORE_TRACE_DEBUG] Calling insert_session_trace for session_id=%s",
        trace_obj.session_id,
    )
    await tracer.db.insert_session_trace(trace_obj)
    _logger.debug("[STORE_TRACE_DEBUG] Successfully inserted trace")
|
|
370
|
+
|
|
131
371
|
def _temporary_sys_path(paths: Sequence[Path]):
|
|
132
372
|
"""Context manager to prepend entries to sys.path temporarily."""
|
|
133
373
|
|
|
@@ -676,36 +916,44 @@ def _build_modal_config_from_ast(modal_call: ast.Call) -> ModalDeploymentConfigT
|
|
|
676
916
|
elif kw.arg == "pip_packages" and isinstance(kw.value, (ast.List, ast.Tuple)):
|
|
677
917
|
# Handle pip_packages list/tuple
|
|
678
918
|
packages: list[str] = []
|
|
679
|
-
|
|
680
|
-
|
|
681
|
-
|
|
919
|
+
value_node = kw.value
|
|
920
|
+
if isinstance(value_node, (ast.List, ast.Tuple)):
|
|
921
|
+
for elt in value_node.elts:
|
|
922
|
+
if isinstance(elt, ast.Constant):
|
|
923
|
+
packages.append(elt.value)
|
|
682
924
|
kwargs[kw.arg] = tuple(packages)
|
|
683
925
|
elif kw.arg == "extra_local_dirs" and isinstance(kw.value, (ast.List, ast.Tuple)):
|
|
684
926
|
# Handle extra_local_dirs list/tuple of tuples
|
|
685
927
|
dirs = []
|
|
686
|
-
|
|
687
|
-
|
|
688
|
-
|
|
689
|
-
|
|
690
|
-
|
|
691
|
-
|
|
928
|
+
value_node = kw.value
|
|
929
|
+
if isinstance(value_node, (ast.List, ast.Tuple)):
|
|
930
|
+
for elt in value_node.elts:
|
|
931
|
+
if isinstance(elt, (ast.List, ast.Tuple)) and len(elt.elts) == 2:
|
|
932
|
+
src = elt.elts[0].value if isinstance(elt.elts[0], ast.Constant) else None
|
|
933
|
+
dst = elt.elts[1].value if isinstance(elt.elts[1], ast.Constant) else None
|
|
934
|
+
if src and dst:
|
|
935
|
+
dirs.append((src, dst))
|
|
692
936
|
kwargs[kw.arg] = tuple(dirs)
|
|
693
937
|
elif kw.arg == "secret_names" and isinstance(kw.value, (ast.List, ast.Tuple)):
|
|
694
938
|
# Handle secret_names list/tuple
|
|
695
939
|
secrets = []
|
|
696
|
-
|
|
697
|
-
|
|
698
|
-
|
|
940
|
+
value_node = kw.value
|
|
941
|
+
if isinstance(value_node, (ast.List, ast.Tuple)):
|
|
942
|
+
for elt in value_node.elts:
|
|
943
|
+
if isinstance(elt, ast.Constant):
|
|
944
|
+
secrets.append(elt.value)
|
|
699
945
|
kwargs[kw.arg] = tuple(secrets)
|
|
700
946
|
elif kw.arg == "volume_mounts" and isinstance(kw.value, (ast.List, ast.Tuple)):
|
|
701
947
|
# Handle volume_mounts list/tuple of tuples
|
|
702
948
|
mounts = []
|
|
703
|
-
|
|
704
|
-
|
|
705
|
-
|
|
706
|
-
|
|
707
|
-
|
|
708
|
-
|
|
949
|
+
value_node = kw.value
|
|
950
|
+
if isinstance(value_node, (ast.List, ast.Tuple)):
|
|
951
|
+
for elt in value_node.elts:
|
|
952
|
+
if isinstance(elt, (ast.List, ast.Tuple)) and len(elt.elts) == 2:
|
|
953
|
+
name = elt.elts[0].value if isinstance(elt.elts[0], ast.Constant) else None
|
|
954
|
+
mount = elt.elts[1].value if isinstance(elt.elts[1], ast.Constant) else None
|
|
955
|
+
if name and mount:
|
|
956
|
+
mounts.append((name, mount))
|
|
709
957
|
kwargs[kw.arg] = tuple(mounts)
|
|
710
958
|
|
|
711
959
|
return ModalDeploymentConfig(**kwargs)
|
|
@@ -832,6 +1080,71 @@ def _import_task_app_module(
|
|
|
832
1080
|
return module
|
|
833
1081
|
|
|
834
1082
|
|
|
1083
|
+
@contextlib.contextmanager
|
|
1084
|
+
def _safe_import_context() -> Iterator[None]:
|
|
1085
|
+
"""Guard module imports against argparse/uvicorn side effects."""
|
|
1086
|
+
|
|
1087
|
+
original_argv = sys.argv[:]
|
|
1088
|
+
sys.argv = [original_argv[0]] if original_argv else ["python"]
|
|
1089
|
+
|
|
1090
|
+
parser_cls = argparse.ArgumentParser
|
|
1091
|
+
old_parse_args = parser_cls.parse_args
|
|
1092
|
+
|
|
1093
|
+
def _parse_noargs(self, args=None, namespace=None): # type: ignore[override]
|
|
1094
|
+
if args is None:
|
|
1095
|
+
args = []
|
|
1096
|
+
if namespace is None:
|
|
1097
|
+
namespace = argparse.Namespace()
|
|
1098
|
+
try:
|
|
1099
|
+
return old_parse_args(self, args, namespace)
|
|
1100
|
+
except SystemExit:
|
|
1101
|
+
return namespace
|
|
1102
|
+
|
|
1103
|
+
parser_cls.parse_args = _parse_noargs # type: ignore[assignment]
|
|
1104
|
+
|
|
1105
|
+
uvicorn_run = None
|
|
1106
|
+
run_task_app_orig = None
|
|
1107
|
+
try:
|
|
1108
|
+
import uvicorn # type: ignore
|
|
1109
|
+
|
|
1110
|
+
uvicorn_run = uvicorn.run
|
|
1111
|
+
uvicorn.run = lambda *args, **kwargs: None # type: ignore[assignment]
|
|
1112
|
+
except Exception:
|
|
1113
|
+
uvicorn_run = None
|
|
1114
|
+
|
|
1115
|
+
try:
|
|
1116
|
+
_task_server_patch = cast(
|
|
1117
|
+
Any, importlib.import_module("synth_ai.task.server")
|
|
1118
|
+
)
|
|
1119
|
+
run_task_app_orig = cast(Callable[..., Any], _task_server_patch.run_task_app)
|
|
1120
|
+
_task_server_patch.run_task_app = ( # type: ignore[assignment]
|
|
1121
|
+
lambda *args, **kwargs: None
|
|
1122
|
+
)
|
|
1123
|
+
except Exception:
|
|
1124
|
+
run_task_app_orig = None
|
|
1125
|
+
|
|
1126
|
+
try:
|
|
1127
|
+
yield
|
|
1128
|
+
finally:
|
|
1129
|
+
sys.argv = original_argv
|
|
1130
|
+
parser_cls.parse_args = old_parse_args # type: ignore[assignment]
|
|
1131
|
+
if uvicorn_run is not None:
|
|
1132
|
+
try:
|
|
1133
|
+
import uvicorn # type: ignore
|
|
1134
|
+
|
|
1135
|
+
uvicorn.run = uvicorn_run # type: ignore[assignment]
|
|
1136
|
+
except Exception:
|
|
1137
|
+
pass
|
|
1138
|
+
if run_task_app_orig is not None:
|
|
1139
|
+
try:
|
|
1140
|
+
_task_server_patch = cast(
|
|
1141
|
+
Any, importlib.import_module("synth_ai.task.server")
|
|
1142
|
+
)
|
|
1143
|
+
_task_server_patch.run_task_app = run_task_app_orig # type: ignore[assignment]
|
|
1144
|
+
except Exception:
|
|
1145
|
+
pass
|
|
1146
|
+
|
|
1147
|
+
|
|
835
1148
|
def _load_entry_from_path(
|
|
836
1149
|
path: Path, app_id: str, module_search_roots: Sequence[Path] | None = None
|
|
837
1150
|
) -> TaskAppEntryType:
|
|
@@ -859,13 +1172,14 @@ def _load_entry_from_path(
|
|
|
859
1172
|
|
|
860
1173
|
for module_name, namespace_root in _possible_module_names(resolved, search_roots):
|
|
861
1174
|
try:
|
|
862
|
-
|
|
863
|
-
|
|
864
|
-
|
|
865
|
-
|
|
866
|
-
|
|
867
|
-
|
|
868
|
-
|
|
1175
|
+
with _safe_import_context():
|
|
1176
|
+
module = _import_task_app_module(
|
|
1177
|
+
resolved,
|
|
1178
|
+
module_name,
|
|
1179
|
+
namespace_root=namespace_root,
|
|
1180
|
+
sys_path_roots=search_roots,
|
|
1181
|
+
ensure_namespace=True,
|
|
1182
|
+
)
|
|
869
1183
|
break
|
|
870
1184
|
except Exception as exc: # pragma: no cover - best-effort fallbacks
|
|
871
1185
|
last_error = exc
|
|
@@ -874,13 +1188,14 @@ def _load_entry_from_path(
|
|
|
874
1188
|
if module is None:
|
|
875
1189
|
hashed_name = f"_synth_task_app_{hashlib.md5(str(resolved).encode(), usedforsecurity=False).hexdigest()}"
|
|
876
1190
|
try:
|
|
877
|
-
|
|
878
|
-
|
|
879
|
-
|
|
880
|
-
|
|
881
|
-
|
|
882
|
-
|
|
883
|
-
|
|
1191
|
+
with _safe_import_context():
|
|
1192
|
+
module = _import_task_app_module(
|
|
1193
|
+
resolved,
|
|
1194
|
+
hashed_name,
|
|
1195
|
+
namespace_root=None,
|
|
1196
|
+
sys_path_roots=search_roots,
|
|
1197
|
+
ensure_namespace=False,
|
|
1198
|
+
)
|
|
884
1199
|
except Exception as exc: # pragma: no cover - propagate meaningful error
|
|
885
1200
|
detail = last_error or exc
|
|
886
1201
|
raise click.ClickException(f"Failed to import {resolved}: {detail}") from detail
|
|
@@ -928,7 +1243,10 @@ def _load_entry_from_path(
|
|
|
928
1243
|
if has_required:
|
|
929
1244
|
continue
|
|
930
1245
|
try:
|
|
931
|
-
|
|
1246
|
+
with _safe_import_context():
|
|
1247
|
+
result = attr()
|
|
1248
|
+
except SystemExit:
|
|
1249
|
+
continue
|
|
932
1250
|
except Exception:
|
|
933
1251
|
continue
|
|
934
1252
|
if isinstance(result, TaskAppConfig) and result.app_id == app_id:
|
|
@@ -1024,21 +1342,173 @@ def _resolve_env_paths_for_script(script_path: Path, explicit: Sequence[str]) ->
|
|
|
1024
1342
|
return [env_candidates[choice - 1]]
|
|
1025
1343
|
|
|
1026
1344
|
|
|
1345
|
+
def _path_is_within(child: Path, parent: Path) -> bool:
|
|
1346
|
+
try:
|
|
1347
|
+
child.resolve().relative_to(parent.resolve())
|
|
1348
|
+
return True
|
|
1349
|
+
except Exception:
|
|
1350
|
+
return False
|
|
1351
|
+
|
|
1352
|
+
|
|
1353
|
+
@functools.lru_cache(maxsize=16)
|
|
1354
|
+
def _is_modal_shim(path_str: str) -> bool:
|
|
1355
|
+
"""Return True if the candidate CLI path refers to the synth-ai shim."""
|
|
1356
|
+
|
|
1357
|
+
path = Path(path_str)
|
|
1358
|
+
try:
|
|
1359
|
+
resolved = path.resolve(strict=True)
|
|
1360
|
+
except Exception:
|
|
1361
|
+
resolved = path
|
|
1362
|
+
|
|
1363
|
+
if not resolved.exists() or resolved.is_dir():
|
|
1364
|
+
return False
|
|
1365
|
+
|
|
1366
|
+
snippet = ""
|
|
1367
|
+
try:
|
|
1368
|
+
snippet = resolved.read_bytes()[:4096].decode("utf-8", errors="ignore")
|
|
1369
|
+
except Exception:
|
|
1370
|
+
snippet = ""
|
|
1371
|
+
|
|
1372
|
+
shim_markers = (
|
|
1373
|
+
"synth_ai.cli._modal_wrapper",
|
|
1374
|
+
"from modal.__main__ import main",
|
|
1375
|
+
"import modal.__main__",
|
|
1376
|
+
"run_module('modal.__main__'",
|
|
1377
|
+
)
|
|
1378
|
+
if snippet and any(marker in snippet for marker in shim_markers):
|
|
1379
|
+
return True
|
|
1380
|
+
|
|
1381
|
+
try:
|
|
1382
|
+
size = resolved.stat().st_size
|
|
1383
|
+
except Exception:
|
|
1384
|
+
size = None
|
|
1385
|
+
|
|
1386
|
+
if (
|
|
1387
|
+
size is not None
|
|
1388
|
+
and size < 2048
|
|
1389
|
+
and "python" in (snippet.splitlines() or [""])[0]
|
|
1390
|
+
and (
|
|
1391
|
+
"modal.__main__" in snippet
|
|
1392
|
+
or "modal.__main__" in snippet.replace(" ", "")
|
|
1393
|
+
)
|
|
1394
|
+
):
|
|
1395
|
+
return True
|
|
1396
|
+
|
|
1397
|
+
virtual_env = os.environ.get("VIRTUAL_ENV")
|
|
1398
|
+
if virtual_env and _path_is_within(resolved, Path(virtual_env)):
|
|
1399
|
+
return True
|
|
1400
|
+
|
|
1401
|
+
if _path_is_within(resolved, REPO_ROOT):
|
|
1402
|
+
return True
|
|
1403
|
+
|
|
1404
|
+
uv_tools_dir = Path.home() / ".local" / "share" / "uv" / "tools"
|
|
1405
|
+
return uv_tools_dir.exists() and _path_is_within(resolved, uv_tools_dir)
|
|
1406
|
+
|
|
1407
|
+
|
|
1408
|
+
def _find_modal_executable(modal_cli: str) -> tuple[str | None, str | None]:
|
|
1409
|
+
"""Return the first non-shim executable and the first shim discovered on PATH."""
|
|
1410
|
+
|
|
1411
|
+
if not modal_cli:
|
|
1412
|
+
modal_cli = "modal"
|
|
1413
|
+
|
|
1414
|
+
candidate_path = Path(modal_cli).expanduser()
|
|
1415
|
+
if candidate_path.is_absolute() or len(candidate_path.parts) > 1:
|
|
1416
|
+
resolved_candidate = candidate_path
|
|
1417
|
+
if not resolved_candidate.is_absolute():
|
|
1418
|
+
resolved_candidate = (Path.cwd() / resolved_candidate).resolve()
|
|
1419
|
+
else:
|
|
1420
|
+
resolved_candidate = resolved_candidate.resolve()
|
|
1421
|
+
if not resolved_candidate.exists():
|
|
1422
|
+
raise click.ClickException(f"--modal-cli path does not exist: {resolved_candidate}")
|
|
1423
|
+
if not os.access(resolved_candidate, os.X_OK):
|
|
1424
|
+
raise click.ClickException(f"--modal-cli is not executable: {resolved_candidate}")
|
|
1425
|
+
return str(resolved_candidate), None
|
|
1426
|
+
|
|
1427
|
+
path_env = os.environ.get("PATH", "")
|
|
1428
|
+
if not path_env:
|
|
1429
|
+
return None, None
|
|
1430
|
+
|
|
1431
|
+
seen_dirs: set[str] = set()
|
|
1432
|
+
seen_candidates: set[str] = set()
|
|
1433
|
+
shim_path: str | None = None
|
|
1434
|
+
|
|
1435
|
+
for raw_entry in path_env.split(os.pathsep):
|
|
1436
|
+
if not raw_entry:
|
|
1437
|
+
continue
|
|
1438
|
+
try:
|
|
1439
|
+
resolved_entry = str(Path(raw_entry).resolve())
|
|
1440
|
+
except Exception:
|
|
1441
|
+
resolved_entry = os.path.normpath(raw_entry)
|
|
1442
|
+
if resolved_entry in seen_dirs:
|
|
1443
|
+
continue
|
|
1444
|
+
seen_dirs.add(resolved_entry)
|
|
1445
|
+
|
|
1446
|
+
candidate = shutil.which(modal_cli, path=raw_entry)
|
|
1447
|
+
if candidate is None:
|
|
1448
|
+
continue
|
|
1449
|
+
if candidate in seen_candidates:
|
|
1450
|
+
continue
|
|
1451
|
+
seen_candidates.add(candidate)
|
|
1452
|
+
|
|
1453
|
+
if _is_modal_shim(candidate):
|
|
1454
|
+
if shim_path is None:
|
|
1455
|
+
shim_path = candidate
|
|
1456
|
+
continue
|
|
1457
|
+
return candidate, shim_path
|
|
1458
|
+
|
|
1459
|
+
return None, shim_path
|
|
1460
|
+
|
|
1461
|
+
|
|
1027
1462
|
def _modal_command_prefix(modal_cli: str) -> list[str]:
    """Resolve a command prefix for invoking the Modal CLI within the active environment.

    Resolution order:

    1. ``SYNTH_FORCE_MODAL_WRAPPER`` set to ``1``/``true``/``yes``
       (case-insensitive, surrounding whitespace ignored) selects the
       in-process wrapper module unconditionally.
    2. A non-shim executable located by ``_find_modal_executable`` is used
       directly.
    3. A custom (non-``"modal"``) name that could not be located is an error.
    4. For the default ``"modal"`` name, the wrapper module is used when the
       ``modal`` package is importable. Otherwise an error explains why.

    Returns:
        The argv prefix to which Modal subcommands are appended.

    Raises:
        click.ClickException: no usable Modal CLI could be resolved.
    """
    wrapper_prefix = [sys.executable, "-m", "synth_ai.cli._modal_wrapper"]

    forced = os.environ.get("SYNTH_FORCE_MODAL_WRAPPER", "").strip().lower()
    if forced in {"1", "true", "yes"}:
        click.secho(
            "[modal-prefix] SYNTH_FORCE_MODAL_WRAPPER=1 -> using in-process wrapper",
            fg="yellow",
        )
        return wrapper_prefix

    lookup = modal_cli or "modal"
    default_name = lookup == "modal"
    # Only the default name may fall back to the wrapper module, so the
    # importability of the ``modal`` package matters only in that case.
    spec = importlib.util.find_spec("modal") if default_name else None

    preferred, shim_candidate = _find_modal_executable(lookup)
    if preferred is not None:
        spec_suffix = f" spec={'yes' if spec else 'no'}" if default_name else ""
        click.secho(
            f"[modal-prefix] modal_cli={lookup} selected={preferred}{spec_suffix}",
            fg="cyan",
        )
        return [preferred]

    if not default_name:
        raise click.ClickException(f"Modal CLI not found (looked for '{lookup}')")

    if spec is None:
        if shim_candidate is not None:
            raise click.ClickException(
                "Modal CLI resolution found the synth-ai shim but the 'modal' package "
                "is not importable in this environment. Install the official Modal CLI "
                "or pass --modal-cli with its path."
            )
        raise click.ClickException(
            "Modal CLI not found. Install the 'modal' package in this environment or pass "
            "--modal-cli with an explicit path."
        )

    if shim_candidate is None:
        warning = "[modal-prefix] Using synth-ai modal shim; pass --modal-cli /path/to/modal to override."
    else:
        warning = (
            f"[modal-prefix] Using synth-ai modal shim at {shim_candidate}; "
            "pass --modal-cli /path/to/modal to override."
        )
    click.secho(warning, fg="yellow")
    click.secho(
        "[modal-prefix] modal_cli=modal selected=module-wrapper spec=yes",
        fg="yellow",
    )
    return wrapper_prefix
|
|
1042
1512
|
|
|
1043
1513
|
|
|
1044
1514
|
def _build_modal_app_wrapper(original_script: Path) -> tuple[Path, Path]:
|
|
@@ -1173,8 +1643,15 @@ def _run_modal_script(
|
|
|
1173
1643
|
if modal_name and command == "deploy":
|
|
1174
1644
|
cmd.extend(["--name", modal_name])
|
|
1175
1645
|
if dry_run:
|
|
1176
|
-
click.echo(
|
|
1646
|
+
click.echo(
|
|
1647
|
+
"Dry run: " + " ".join(shlex.quote(component) for component in cmd),
|
|
1648
|
+
err=False,
|
|
1649
|
+
)
|
|
1177
1650
|
return
|
|
1651
|
+
click.secho(
|
|
1652
|
+
"[modal-exec] " + " ".join(shlex.quote(component) for component in cmd),
|
|
1653
|
+
fg="cyan",
|
|
1654
|
+
)
|
|
1178
1655
|
try:
|
|
1179
1656
|
# Stream output live for better diagnostics
|
|
1180
1657
|
proc = subprocess.Popen(
|
|
@@ -1429,7 +1906,6 @@ def _run_modal_with_entry(
|
|
|
1429
1906
|
inline_secret_values=inline_secret_values,
|
|
1430
1907
|
)
|
|
1431
1908
|
cmd = [*_modal_command_prefix(modal_cli), command, str(script_path)]
|
|
1432
|
-
|
|
1433
1909
|
if modal_name and command == "deploy":
|
|
1434
1910
|
cmd.extend(["--name", modal_name])
|
|
1435
1911
|
|
|
@@ -1444,9 +1920,13 @@ def _run_modal_with_entry(
|
|
|
1444
1920
|
proc_env["PYTHONPATH"] = os.pathsep.join(list(dict.fromkeys(pythonpath_entries)))
|
|
1445
1921
|
|
|
1446
1922
|
if dry_run:
|
|
1447
|
-
click.echo("Dry run: " + " ".join(cmd))
|
|
1923
|
+
click.echo("Dry run: " + " ".join(shlex.quote(component) for component in cmd))
|
|
1448
1924
|
script_path.unlink(missing_ok=True)
|
|
1449
1925
|
return
|
|
1926
|
+
click.secho(
|
|
1927
|
+
"[modal-exec] " + " ".join(shlex.quote(component) for component in cmd),
|
|
1928
|
+
fg="cyan",
|
|
1929
|
+
)
|
|
1450
1930
|
|
|
1451
1931
|
try:
|
|
1452
1932
|
# Stream output live for better diagnostics
|
|
@@ -1531,6 +2011,10 @@ def _parse_env_file(path: Path) -> dict[str, str]:
|
|
|
1531
2011
|
|
|
1532
2012
|
|
|
1533
2013
|
def _interactive_fill_env(env_path: Path) -> Path | None:
|
|
2014
|
+
if not sys.stdin.isatty():
|
|
2015
|
+
raise click.ClickException(
|
|
2016
|
+
"ENVIRONMENT_API_KEY missing. Provide --env-file or run `synth-ai setup` in an interactive shell to create one."
|
|
2017
|
+
)
|
|
1534
2018
|
existing = _parse_env_file(env_path) if env_path.exists() else {}
|
|
1535
2019
|
|
|
1536
2020
|
def _prompt(label: str, *, default: str = "", required: bool) -> str | None:
|
|
@@ -1570,6 +2054,10 @@ def _ensure_env_values(env_paths: list[Path], fallback_dir: Path) -> None:
|
|
|
1570
2054
|
if (os.environ.get("ENVIRONMENT_API_KEY") or "").strip():
|
|
1571
2055
|
return
|
|
1572
2056
|
target = env_paths[0] if env_paths else (fallback_dir / ".env").resolve()
|
|
2057
|
+
click.echo(
|
|
2058
|
+
"⚠️ ENVIRONMENT_API_KEY not set. Run `uvx synth-ai setup`, "
|
|
2059
|
+
"or pass --env-file pointing at a .env with ENVIRONMENT_API_KEY."
|
|
2060
|
+
)
|
|
1573
2061
|
result = _interactive_fill_env(target)
|
|
1574
2062
|
if result is None:
|
|
1575
2063
|
raise click.ClickException("ENVIRONMENT_API_KEY required to continue")
|
|
@@ -1593,7 +2081,7 @@ def _deploy_entry(
|
|
|
1593
2081
|
f"Task app '{entry.app_id}' does not define Modal deployment settings"
|
|
1594
2082
|
)
|
|
1595
2083
|
|
|
1596
|
-
env_paths = _determine_env_files(entry, env_file)
|
|
2084
|
+
env_paths = _determine_env_files(entry, env_file, original_path=original_path)
|
|
1597
2085
|
click.echo("Using env file(s): " + ", ".join(str(p.resolve()) for p in env_paths))
|
|
1598
2086
|
_run_modal_with_entry(
|
|
1599
2087
|
entry,
|
|
@@ -1620,7 +2108,7 @@ def _modal_serve_entry(
|
|
|
1620
2108
|
f"Task app '{entry.app_id}' does not define Modal deployment settings"
|
|
1621
2109
|
)
|
|
1622
2110
|
|
|
1623
|
-
env_paths = _determine_env_files(entry, env_file)
|
|
2111
|
+
env_paths = _determine_env_files(entry, env_file, original_path=original_path)
|
|
1624
2112
|
click.echo("Using env file(s): " + ", ".join(str(p.resolve()) for p in env_paths))
|
|
1625
2113
|
_run_modal_with_entry(
|
|
1626
2114
|
entry,
|
|
@@ -1651,6 +2139,255 @@ def list_apps() -> None:
|
|
|
1651
2139
|
click.echo(f"- {entry.app_id}{aliases}: {entry.description}")
|
|
1652
2140
|
|
|
1653
2141
|
|
|
2142
|
+
@task_app_group.command("validate")
@click.argument("app_id", type=str, required=True)
@click.option(
    "--url",
    type=str,
    default=None,
    help="Task app URL to validate (if not provided, starts a local server)",
)
@click.option(
    "--port",
    type=int,
    default=8765,
    help="Port to use for temporary server (default: 8765)",
)
@click.option(
    "--api-key",
    type=str,
    default=None,
    envvar="ENVIRONMENT_API_KEY",
    help="API key for authentication (default: $ENVIRONMENT_API_KEY)",
)
@click.option(
    "--min-instances",
    type=int,
    default=10,
    help="Minimum number of task instances required (default: 10)",
)
@click.option(
    "--verbose",
    "-v",
    is_flag=True,
    help="Show detailed information about the task app",
)
@click.option(
    "--json",
    "output_json",
    is_flag=True,
    help="Output results as JSON",
)
def validate_task_app_cmd(
    app_id: str,
    url: str | None,
    port: int,
    api_key: str | None,
    min_instances: int,
    verbose: bool,
    output_json: bool,
) -> None:
    """Validate a task app deployment readiness.

    This command verifies that a task app is properly configured and ready to run
    by checking all required HTTP endpoints, authentication, and task availability.

    By default, it starts a temporary local server for validation. You can also
    validate a remote deployment by passing --url.

    \b
    What gets validated:
      • Root endpoint (/) responds correctly
      • Health endpoint (/health) is accessible with proper authentication
      • Info endpoint (/info) returns valid task metadata
      • Task info endpoint (/task_info) provides task instances
      • Rollout endpoint (/rollout) is registered
      • At least N task instances are available (default: 10)

    \b
    Examples:

    \b
    Validate grpo-crafter (starts local server automatically):
      $ synth-ai task-app validate grpo-crafter

    \b
    Validate sokoban with verbose output:
      $ synth-ai task-app validate sokoban --verbose

    \b
    Validate with custom port:
      $ synth-ai task-app validate sokoban --port 9000

    \b
    Validate a remote deployment:
      $ synth-ai task-app validate grpo-crafter --url https://my-crafter.modal.run

    \b
    Require at least 20 task instances:
      $ synth-ai task-app validate grpo-crafter --min-instances 20

    \b
    Get JSON output for automation:
      $ synth-ai task-app validate sokoban --json

    \b
    Common use cases:
      • Pre-deployment verification: Check task app works before deploying to Modal
      • CI/CD integration: Use --json flag for automated validation in pipelines
      • Debug failing deployments: Use --verbose to see detailed endpoint responses
      • Test API key configuration: Verify authentication is set up correctly
    """
    import asyncio
    import contextlib
    import shutil
    import socket
    import subprocess
    import tempfile
    import time

    # The validation routine lives in a sibling CLI module.
    from synth_ai.cli._validate_task_app import validate_task_app  # type: ignore[attr-defined]

    def _port_is_free(candidate: int) -> bool:
        """Return True if *candidate* can currently be bound on all interfaces."""
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
            try:
                sock.bind(("", candidate))
            except OSError:
                return False
            return True

    def _read_log(path: str) -> str:
        """Best-effort read of the temporary server's combined output."""
        try:
            with open(path, encoding="utf-8", errors="replace") as handle:
                return handle.read()
        except OSError:
            return ""

    proc = None
    server_log = None
    temp_dir: str | None = None
    task_app_url = url

    try:
        # If no URL provided, start a temporary server
        if not task_app_url:
            while not _port_is_free(port):
                port += 1
                if port > 65535:
                    raise click.ClickException(
                        "No free TCP port available for the temporary server"
                    )

            task_app_url = f"http://localhost:{port}"

            if not output_json:
                click.echo(f"Starting temporary {app_id} server on port {port}...")

            env = os.environ.copy()
            if api_key:
                env["ENVIRONMENT_API_KEY"] = api_key

            # Isolated trace DB/dir so the server does not prompt for locations.
            temp_dir = tempfile.mkdtemp()
            temp_trace_db = os.path.join(temp_dir, "validate_trace.db")
            temp_trace_dir = os.path.join(temp_dir, "traces")
            os.makedirs(temp_trace_dir, exist_ok=True)
            log_path = os.path.join(temp_dir, "server.log")

            # Send server output to a file rather than a pipe: nothing drains a
            # pipe while validation runs, so a chatty server would block once
            # the OS pipe buffer filled up. Closed in the finally block.
            server_log = open(log_path, "w", encoding="utf-8")
            proc = subprocess.Popen(
                [
                    "uv",
                    "run",
                    "synth-ai",
                    "task-app",
                    "serve",
                    app_id,
                    "--port",
                    str(port),
                    "--no-reload",
                    "--trace",
                    temp_trace_dir,
                    "--trace-db",
                    temp_trace_db,
                ],
                env=env,
                stdin=subprocess.PIPE,  # lets us answer any prompt with a blank line
                stdout=server_log,
                stderr=subprocess.STDOUT,
                text=True,
            )

            # Write empty input to stdin to skip any prompts
            if proc.stdin:
                try:
                    proc.stdin.write("\n")
                    proc.stdin.flush()
                    proc.stdin.close()
                except Exception:
                    pass

            if not output_json:
                click.echo("Waiting for server to start...")

            import httpx

            for _attempt in range(60):  # ~30 seconds at 0.5s intervals
                try:
                    if httpx.get(f"{task_app_url}/", timeout=2.0).status_code == 200:
                        break
                except Exception:
                    pass  # not listening yet

                # Check if process died
                if proc.poll() is not None:
                    click.echo(click.style("✗ Server process exited unexpectedly", fg="red"), err=True)
                    if not output_json:
                        server_log.flush()
                        output = _read_log(log_path)
                        if output:
                            click.echo(f"Error output:\n{output}", err=True)
                    sys.exit(1)

                time.sleep(0.5)
            else:
                click.echo(click.style("✗ Server failed to start within 30 seconds", fg="red"), err=True)
                sys.exit(1)

            if not output_json:
                click.echo(click.style("✓ Server started", fg="green"))
                click.echo()

        # Ensure URL doesn't have trailing slash
        task_app_url = task_app_url.rstrip("/")

        async def _run() -> tuple[bool, dict[str, Any]]:
            return await validate_task_app(
                url=task_app_url,
                api_key=api_key,
                min_instances=min_instances,
                verbose=verbose,
            )

        success, results = asyncio.run(_run())

        if output_json:
            import json as _json

            click.echo(_json.dumps(results, indent=2))

        sys.exit(0 if success else 1)

    finally:
        # Cleanup: stop the temporary server
        if proc is not None:
            if not output_json:
                click.echo("\nStopping temporary server...")
            try:
                proc.terminate()
                proc.wait(timeout=5)
            except Exception:
                proc.kill()
                # Reap the killed process so it does not linger as a zombie.
                with contextlib.suppress(Exception):
                    proc.wait(timeout=5)

        if server_log is not None:
            with contextlib.suppress(Exception):
                server_log.close()

        # Cleanup temp trace DB, trace dir and server log
        if temp_dir is not None:
            shutil.rmtree(temp_dir, ignore_errors=True)
|
|
2389
|
+
|
|
2390
|
+
|
|
1654
2391
|
def _load_env_files_into_process(paths: Sequence[str]) -> None:
|
|
1655
2392
|
for p in paths:
|
|
1656
2393
|
try:
|
|
@@ -1907,7 +2644,9 @@ def serve_task_group(
|
|
|
1907
2644
|
)
|
|
1908
2645
|
|
|
1909
2646
|
|
|
1910
|
-
def _determine_env_files(
|
|
2647
|
+
def _determine_env_files(
|
|
2648
|
+
entry: TaskAppEntryType, user_env_files: Sequence[str], *, original_path: Path | None = None
|
|
2649
|
+
) -> list[Path]:
|
|
1911
2650
|
resolved: list[Path] = []
|
|
1912
2651
|
for candidate in user_env_files:
|
|
1913
2652
|
p = Path(candidate).expanduser()
|
|
@@ -1917,30 +2656,46 @@ def _determine_env_files(entry: TaskAppEntryType, user_env_files: Sequence[str])
|
|
|
1917
2656
|
if resolved:
|
|
1918
2657
|
return resolved
|
|
1919
2658
|
|
|
1920
|
-
|
|
1921
|
-
|
|
1922
|
-
|
|
1923
|
-
|
|
1924
|
-
|
|
1925
|
-
|
|
1926
|
-
|
|
1927
|
-
|
|
2659
|
+
declared: list[Path] = []
|
|
2660
|
+
for candidate in getattr(entry, "env_files", ()) or ():
|
|
2661
|
+
try:
|
|
2662
|
+
p = Path(candidate).expanduser()
|
|
2663
|
+
except Exception:
|
|
2664
|
+
continue
|
|
2665
|
+
if p.exists() and p.is_file():
|
|
2666
|
+
declared.append(p)
|
|
2667
|
+
if declared:
|
|
2668
|
+
return declared
|
|
1928
2669
|
|
|
1929
|
-
|
|
1930
|
-
|
|
1931
|
-
|
|
1932
|
-
for repo_file in repo_env_files:
|
|
1933
|
-
if repo_file not in env_candidates:
|
|
1934
|
-
env_candidates.append(repo_file)
|
|
2670
|
+
def _append_candidate(collection: list[Path], candidate: Path) -> None:
|
|
2671
|
+
if candidate.exists() and candidate.is_file() and candidate not in collection:
|
|
2672
|
+
collection.append(candidate)
|
|
1935
2673
|
|
|
1936
|
-
|
|
1937
|
-
raise click.ClickException("No env file found. Pass --env-file explicitly.")
|
|
2674
|
+
auto_candidates: list[Path] = []
|
|
1938
2675
|
|
|
1939
|
-
|
|
1940
|
-
|
|
1941
|
-
|
|
1942
|
-
|
|
1943
|
-
|
|
2676
|
+
search_dirs: list[Path] = []
|
|
2677
|
+
if original_path is not None:
|
|
2678
|
+
search_dirs.append(original_path.parent.resolve())
|
|
2679
|
+
for parent in original_path.parent.resolve().parents:
|
|
2680
|
+
search_dirs.append(parent)
|
|
2681
|
+
cwd = Path.cwd().resolve()
|
|
2682
|
+
if cwd not in search_dirs:
|
|
2683
|
+
search_dirs.append(cwd)
|
|
2684
|
+
repo_root = REPO_ROOT.resolve()
|
|
2685
|
+
if repo_root not in search_dirs:
|
|
2686
|
+
search_dirs.append(repo_root)
|
|
2687
|
+
|
|
2688
|
+
for directory in search_dirs:
|
|
2689
|
+
_append_candidate(auto_candidates, directory / ".env")
|
|
2690
|
+
for candidate in sorted(directory.glob("*.env")):
|
|
2691
|
+
_append_candidate(auto_candidates, candidate)
|
|
2692
|
+
|
|
2693
|
+
if auto_candidates:
|
|
2694
|
+
return [auto_candidates[0]]
|
|
2695
|
+
|
|
2696
|
+
raise click.ClickException(
|
|
2697
|
+
"No .env file discovered automatically. Pass --env-file /path/to/.env or generate one with `uvx synth-ai setup`."
|
|
2698
|
+
)
|
|
1944
2699
|
|
|
1945
2700
|
|
|
1946
2701
|
def _ensure_port_free(port: int, host: str, *, force: bool) -> None:
|
|
@@ -2242,7 +2997,14 @@ def deploy_app(
|
|
|
2242
2997
|
def modal_serve_app(
|
|
2243
2998
|
app_id: str | None, modal_cli: str, modal_name: str | None, env_file: Sequence[str]
|
|
2244
2999
|
) -> None:
|
|
2245
|
-
|
|
3000
|
+
click.echo(f"[modal-serve] requested app_id={app_id or '(auto)'} modal_cli={modal_cli}")
|
|
3001
|
+
try:
|
|
3002
|
+
choice = _select_app_choice(app_id, purpose="modal-serve")
|
|
3003
|
+
except SystemExit as exc: # bubble up with context (legacy argparse would trigger this)
|
|
3004
|
+
raise click.ClickException(
|
|
3005
|
+
f"Legacy CLI intercepted modal-serve (exit {exc.code}). "
|
|
3006
|
+
"Make sure you're running the Click CLI (synth_ai.cli:cli)."
|
|
3007
|
+
) from exc
|
|
2246
3008
|
|
|
2247
3009
|
if choice.modal_script:
|
|
2248
3010
|
env_paths = _resolve_env_paths_for_script(choice.modal_script, env_file)
|
|
@@ -2251,6 +3013,7 @@ def modal_serve_app(
|
|
|
2251
3013
|
return
|
|
2252
3014
|
|
|
2253
3015
|
entry = choice.ensure_entry()
|
|
3016
|
+
click.echo(f"[modal-serve] serving entry {entry.app_id} from {choice.path}")
|
|
2254
3017
|
_modal_serve_entry(entry, modal_name, modal_cli, env_file, original_path=choice.path)
|
|
2255
3018
|
|
|
2256
3019
|
|
|
@@ -2313,6 +3076,11 @@ def _write_modal_entrypoint(
|
|
|
2313
3076
|
if not any(str(p).startswith("synth-ai") for p in pip_packages):
|
|
2314
3077
|
pip_packages.insert(0, synth_pkg)
|
|
2315
3078
|
|
|
3079
|
+
apt_packages = list(modal_cfg.apt_packages)
|
|
3080
|
+
click.echo(f"[DEBUG] modal_cfg.apt_packages type: {type(modal_cfg.apt_packages)}")
|
|
3081
|
+
click.echo(f"[DEBUG] modal_cfg.apt_packages value: {modal_cfg.apt_packages}")
|
|
3082
|
+
click.echo(f"[DEBUG] apt_packages after list(): {apt_packages}")
|
|
3083
|
+
|
|
2316
3084
|
local_dirs = [(str(Path(src)), dst) for src, dst in modal_cfg.extra_local_dirs]
|
|
2317
3085
|
# Also mount the host synth_ai source if available to ensure latest code is used
|
|
2318
3086
|
if host_synth is not None:
|
|
@@ -2359,6 +3127,15 @@ INLINE_SECRET_VALUES = {inline_secret_values!r}
|
|
|
2359
3127
|
|
|
2360
3128
|
image = Image.debian_slim(python_version={modal_cfg.python_version!r})
|
|
2361
3129
|
|
|
3130
|
+
# CRITICAL: Install iverilog for Verilog task app (hardcoded to prevent config issues)
|
|
3131
|
+
if {entry.app_id!r} == "grpo-verilog":
|
|
3132
|
+
image = image.apt_install("iverilog")
|
|
3133
|
+
|
|
3134
|
+
# Install apt packages first (before pip)
|
|
3135
|
+
apt_packages = {apt_packages!r}
|
|
3136
|
+
if apt_packages:
|
|
3137
|
+
image = image.apt_install(*apt_packages)
|
|
3138
|
+
|
|
2362
3139
|
pip_packages = {pip_packages!r}
|
|
2363
3140
|
if pip_packages:
|
|
2364
3141
|
image = image.pip_install(*pip_packages)
|
|
@@ -2480,22 +3257,60 @@ def register(cli: click.Group) -> None:
|
|
|
2480
3257
|
cli.add_command(serve_command)
|
|
2481
3258
|
cli.add_command(task_app_group)
|
|
2482
3259
|
cli.add_command(eval_command)
|
|
3260
|
+
cli.add_command(filter_command)
|
|
2483
3261
|
|
|
2484
3262
|
|
|
2485
|
-
@click.command(
|
|
3263
|
+
@click.command(
|
|
3264
|
+
"eval",
|
|
3265
|
+
help="Run one-off rollouts against a task app and print judge/eval summaries.",
|
|
3266
|
+
)
|
|
2486
3267
|
@click.argument("app_id", type=str, required=False)
|
|
2487
|
-
@click.option(
|
|
3268
|
+
@click.option(
|
|
3269
|
+
"--config",
|
|
3270
|
+
type=click.Path(),
|
|
3271
|
+
default=None,
|
|
3272
|
+
help="Path to eval TOML (short schema). Auto-discovers the first matching file when omitted.",
|
|
3273
|
+
)
|
|
2488
3274
|
@click.option(
|
|
2489
3275
|
"--url",
|
|
2490
3276
|
"task_app_url",
|
|
2491
3277
|
type=str,
|
|
2492
3278
|
default=None,
|
|
2493
|
-
help="Base URL of a running task app (
|
|
3279
|
+
help="Base URL of a running task app instead of spawning locally (requires --env-file for secrets).",
|
|
3280
|
+
)
|
|
3281
|
+
@click.option(
|
|
3282
|
+
"--seeds",
|
|
3283
|
+
default="0,1,2,3,4",
|
|
3284
|
+
help="Comma-separated seeds/indices to evaluate. Use negative numbers to wrap around the dataset.",
|
|
2494
3285
|
)
|
|
2495
|
-
@click.option("--seeds", default="0,1,2,3,4", help="Comma-separated seeds/indices to evaluate")
|
|
2496
3286
|
@click.option("--split", default="train", show_default=True, help="Dataset split to use")
|
|
2497
|
-
@click.option(
|
|
2498
|
-
|
|
3287
|
+
@click.option(
|
|
3288
|
+
"--model",
|
|
3289
|
+
default=None,
|
|
3290
|
+
help="Model identifier. When omitted the CLI will prompt based on task metadata.",
|
|
3291
|
+
)
|
|
3292
|
+
@click.option(
|
|
3293
|
+
"--env-file",
|
|
3294
|
+
multiple=True,
|
|
3295
|
+
type=click.Path(),
|
|
3296
|
+
help="Env file(s) to load (API keys, etc.). Required when using --url or remote judges.",
|
|
3297
|
+
)
|
|
3298
|
+
@click.option(
|
|
3299
|
+
"--trace-db",
|
|
3300
|
+
default="traces/v3/synth_ai.db",
|
|
3301
|
+
show_default=True,
|
|
3302
|
+
help="SQLite/Turso URL for storing rollout traces set to 'none' to disable persistence.",
|
|
3303
|
+
)
|
|
3304
|
+
@click.option(
|
|
3305
|
+
"--metadata",
|
|
3306
|
+
multiple=True,
|
|
3307
|
+
help="Filter tasks by key=value metadata (e.g., --metadata difficulty=easy)",
|
|
3308
|
+
)
|
|
3309
|
+
@click.option(
|
|
3310
|
+
"--metadata-sql",
|
|
3311
|
+
default=None,
|
|
3312
|
+
help="SQLite query that returns seeds to evaluate (e.g., SELECT seed FROM tasks WHERE difficulty='easy' LIMIT 5)",
|
|
3313
|
+
)
|
|
2499
3314
|
def eval_command(
|
|
2500
3315
|
app_id: str | None,
|
|
2501
3316
|
config: str | None,
|
|
@@ -2504,10 +3319,24 @@ def eval_command(
|
|
|
2504
3319
|
split: str,
|
|
2505
3320
|
model: str | None,
|
|
2506
3321
|
env_file: Sequence[str],
|
|
3322
|
+
trace_db: str,
|
|
3323
|
+
metadata: Sequence[str],
|
|
3324
|
+
metadata_sql: str | None,
|
|
2507
3325
|
) -> None:
|
|
2508
|
-
"""Run
|
|
3326
|
+
"""Run rollouts against a task app and report judge statistics.
|
|
3327
|
+
|
|
3328
|
+
By default the command spins up the selected task app in-process, executes the
|
|
3329
|
+
requested seeds, and prints aggregate scores (official and custom judges). When
|
|
3330
|
+
pointing at a remote `--url`, supply matching `--env-file` values so the CLI can
|
|
3331
|
+
forward authentication headers to the running service.
|
|
3332
|
+
"""
|
|
3333
|
+
# Parse and validate TOML config
|
|
3334
|
+
from synth_ai.task.config import EvalConfig
|
|
3335
|
+
|
|
2509
3336
|
cfg: dict[str, Any] = {}
|
|
3337
|
+
eval_cfg: EvalConfig | None = None
|
|
2510
3338
|
config_path: Path | None = None
|
|
3339
|
+
|
|
2511
3340
|
if config:
|
|
2512
3341
|
config_path = Path(config)
|
|
2513
3342
|
else:
|
|
@@ -2529,10 +3358,73 @@ def eval_command(
|
|
|
2529
3358
|
if isinstance(parsed, dict):
|
|
2530
3359
|
section = parsed.get("eval")
|
|
2531
3360
|
cfg = dict(section) if isinstance(section, dict) else dict(parsed)
|
|
3361
|
+
|
|
3362
|
+
# Validate config with dataclass
|
|
3363
|
+
try:
|
|
3364
|
+
eval_cfg = EvalConfig.from_dict(cfg)
|
|
3365
|
+
click.echo(f"✓ Config validated: {len(eval_cfg.seeds)} seeds, model={eval_cfg.model}")
|
|
3366
|
+
except (ValueError, TypeError) as validation_error:
|
|
3367
|
+
raise click.ClickException(f"Invalid eval config: {validation_error}") from validation_error
|
|
3368
|
+
except click.ClickException:
|
|
3369
|
+
raise
|
|
2532
3370
|
except Exception as exc:
|
|
2533
3371
|
raise click.ClickException(f"Failed to parse TOML '{config_path}': {exc}") from exc
|
|
2534
3372
|
|
|
2535
|
-
|
|
3373
|
+
# CLI args override config
|
|
3374
|
+
if eval_cfg:
|
|
3375
|
+
app_id = app_id or eval_cfg.app_id
|
|
3376
|
+
else:
|
|
3377
|
+
app_id = app_id or (cfg.get("app_id") if isinstance(cfg.get("app_id"), str) else None) # type: ignore
|
|
3378
|
+
|
|
3379
|
+
metadata_filters: dict[str, str] = {}
|
|
3380
|
+
if eval_cfg:
|
|
3381
|
+
metadata_filters.update(eval_cfg.metadata)
|
|
3382
|
+
else:
|
|
3383
|
+
cfg_metadata = cfg.get("metadata")
|
|
3384
|
+
if isinstance(cfg_metadata, dict):
|
|
3385
|
+
for key, value in cfg_metadata.items():
|
|
3386
|
+
metadata_filters[str(key)] = str(value)
|
|
3387
|
+
elif isinstance(cfg_metadata, list):
|
|
3388
|
+
for item in cfg_metadata:
|
|
3389
|
+
if isinstance(item, str) and "=" in item:
|
|
3390
|
+
key, value = item.split("=", 1)
|
|
3391
|
+
metadata_filters[key.strip()] = value.strip()
|
|
3392
|
+
|
|
3393
|
+
for item in metadata or ():
|
|
3394
|
+
if "=" not in item:
|
|
3395
|
+
raise click.ClickException(f"Metadata filters must be key=value (got: {item})")
|
|
3396
|
+
key, value = item.split("=", 1)
|
|
3397
|
+
key = key.strip()
|
|
3398
|
+
value = value.strip()
|
|
3399
|
+
if not key or not value:
|
|
3400
|
+
raise click.ClickException(f"Invalid metadata filter: {item}")
|
|
3401
|
+
metadata_filters[key] = value
|
|
3402
|
+
|
|
3403
|
+
metadata_sql_query: str | None = None
|
|
3404
|
+
if eval_cfg and eval_cfg.metadata_sql:
|
|
3405
|
+
metadata_sql_query = eval_cfg.metadata_sql
|
|
3406
|
+
else:
|
|
3407
|
+
cfg_metadata_sql = cfg.get("metadata_sql")
|
|
3408
|
+
if isinstance(cfg_metadata_sql, dict):
|
|
3409
|
+
metadata_sql_query = cfg_metadata_sql.get("query") or cfg_metadata_sql.get("sql")
|
|
3410
|
+
elif isinstance(cfg_metadata_sql, str):
|
|
3411
|
+
metadata_sql_query = cfg_metadata_sql
|
|
3412
|
+
|
|
3413
|
+
if metadata_sql:
|
|
3414
|
+
metadata_sql_query = metadata_sql
|
|
3415
|
+
if metadata_sql_query is not None:
|
|
3416
|
+
metadata_sql_query = str(metadata_sql_query)
|
|
3417
|
+
|
|
3418
|
+
trace_db_url: str | None = None
|
|
3419
|
+
trace_db = (trace_db or "").strip()
|
|
3420
|
+
if trace_db and trace_db.lower() not in {"none", "off", "disable"}:
|
|
3421
|
+
if "://" in trace_db:
|
|
3422
|
+
trace_db_url = trace_db
|
|
3423
|
+
else:
|
|
3424
|
+
trace_path = Path(trace_db).expanduser()
|
|
3425
|
+
trace_path.parent.mkdir(parents=True, exist_ok=True)
|
|
3426
|
+
trace_db_url = f"sqlite+aiosqlite:///{trace_path}"
|
|
3427
|
+
trace_tracer: SessionTracer | None = SessionTracer(db_url=trace_db_url, auto_save=True) if trace_db_url else None
|
|
2536
3428
|
|
|
2537
3429
|
# Determine selection params (CLI takes precedence; TOML only fills unset model/seeds/env)
|
|
2538
3430
|
if cfg.get("model") and not model:
|
|
@@ -2553,14 +3445,16 @@ def eval_command(
|
|
|
2553
3445
|
elif isinstance(ef, list):
|
|
2554
3446
|
env_file = tuple(str(x) for x in ef) # type: ignore[assignment]
|
|
2555
3447
|
|
|
3448
|
+
choice_for_env: AppChoice | None = None
|
|
2556
3449
|
entry: TaskAppEntryType | None = None
|
|
2557
3450
|
if task_app_url is None:
|
|
2558
|
-
|
|
2559
|
-
entry =
|
|
3451
|
+
choice_for_env = _select_app_choice(app_id, purpose="eval")
|
|
3452
|
+
entry = choice_for_env.ensure_entry()
|
|
2560
3453
|
|
|
2561
3454
|
env_paths: list[Path] = []
|
|
2562
3455
|
if entry is not None:
|
|
2563
|
-
|
|
3456
|
+
original_env_path = choice_for_env.path if choice_for_env is not None else None
|
|
3457
|
+
env_paths = _determine_env_files(entry, env_file, original_path=original_env_path)
|
|
2564
3458
|
else:
|
|
2565
3459
|
if not env_file:
|
|
2566
3460
|
raise click.ClickException("--env-file is required when using --url")
|
|
@@ -2583,12 +3477,30 @@ def eval_command(
|
|
|
2583
3477
|
app = create_task_app(config)
|
|
2584
3478
|
|
|
2585
3479
|
# Determine supported models
|
|
3480
|
+
inference_meta: dict[str, Any] = {}
|
|
2586
3481
|
supported: list[str] = []
|
|
3482
|
+
seen_models: set[str] = set()
|
|
3483
|
+
|
|
3484
|
+
def _add_supported_model(candidate: Any) -> None:
|
|
3485
|
+
if not candidate:
|
|
3486
|
+
return
|
|
3487
|
+
text = str(candidate).strip()
|
|
3488
|
+
if not text or text in seen_models:
|
|
3489
|
+
return
|
|
3490
|
+
supported.append(text)
|
|
3491
|
+
seen_models.add(text)
|
|
3492
|
+
|
|
2587
3493
|
if task_app_url is None:
|
|
2588
3494
|
try:
|
|
2589
|
-
|
|
3495
|
+
if hasattr(config, "base_task_info") and config.base_task_info:
|
|
3496
|
+
inf_obj = getattr(config.base_task_info, "inference", None)
|
|
3497
|
+
if inf_obj is not None:
|
|
3498
|
+
if hasattr(inf_obj, "model_dump"):
|
|
3499
|
+
inference_meta = dict(inf_obj.model_dump(exclude_none=True)) # type: ignore[attr-defined]
|
|
3500
|
+
elif isinstance(inf_obj, dict):
|
|
3501
|
+
inference_meta = dict(inf_obj)
|
|
2590
3502
|
except Exception:
|
|
2591
|
-
|
|
3503
|
+
inference_meta = {}
|
|
2592
3504
|
else:
|
|
2593
3505
|
try:
|
|
2594
3506
|
import httpx as _hx
|
|
@@ -2601,38 +3513,38 @@ def eval_command(
|
|
|
2601
3513
|
info = c.get("/info").json()
|
|
2602
3514
|
inf = info.get("inference") if isinstance(info, dict) else None
|
|
2603
3515
|
if isinstance(inf, dict):
|
|
2604
|
-
|
|
2605
|
-
if isinstance(m, list):
|
|
2606
|
-
supported = [str(x) for x in m]
|
|
2607
|
-
if not supported:
|
|
2608
|
-
providers = inf.get("providers")
|
|
2609
|
-
if isinstance(providers, list):
|
|
2610
|
-
if "openai" in providers:
|
|
2611
|
-
supported.append("gpt-5")
|
|
2612
|
-
if "groq" in providers:
|
|
2613
|
-
supported.append("groq:llama-3.1-70b-versatile")
|
|
2614
|
-
supported.append("synth:qwen-0.6b")
|
|
3516
|
+
inference_meta = dict(inf)
|
|
2615
3517
|
except Exception:
|
|
2616
|
-
|
|
2617
|
-
|
|
2618
|
-
|
|
2619
|
-
|
|
2620
|
-
|
|
2621
|
-
|
|
2622
|
-
|
|
2623
|
-
|
|
2624
|
-
|
|
2625
|
-
|
|
2626
|
-
|
|
2627
|
-
|
|
2628
|
-
|
|
2629
|
-
|
|
3518
|
+
inference_meta = {}
|
|
3519
|
+
|
|
3520
|
+
default_model = inference_meta.get("model")
|
|
3521
|
+
if isinstance(default_model, str):
|
|
3522
|
+
_add_supported_model(default_model)
|
|
3523
|
+
|
|
3524
|
+
models_field = inference_meta.get("models")
|
|
3525
|
+
if isinstance(models_field, list):
|
|
3526
|
+
for candidate in models_field:
|
|
3527
|
+
_add_supported_model(candidate)
|
|
3528
|
+
|
|
3529
|
+
supported_models = inference_meta.get("supported_models")
|
|
3530
|
+
if isinstance(supported_models, list):
|
|
3531
|
+
for candidate in supported_models:
|
|
3532
|
+
_add_supported_model(candidate)
|
|
3533
|
+
|
|
3534
|
+
providers = inference_meta.get("providers")
|
|
3535
|
+
if isinstance(providers, list):
|
|
3536
|
+
if "openai" in providers:
|
|
3537
|
+
_add_supported_model("gpt-5")
|
|
3538
|
+
if "groq" in providers:
|
|
3539
|
+
_add_supported_model("groq:llama-3.1-70b-versatile")
|
|
3540
|
+
|
|
3541
|
+
_add_supported_model("synth:qwen-0.6b")
|
|
2630
3542
|
|
|
2631
3543
|
selected_model = model
|
|
2632
3544
|
if not selected_model:
|
|
2633
3545
|
if not supported:
|
|
2634
3546
|
raise click.ClickException(
|
|
2635
|
-
"No supported models; supply --model or add base_task_info.inference.
|
|
3547
|
+
"No supported models; supply --model or add base_task_info.inference.model"
|
|
2636
3548
|
)
|
|
2637
3549
|
click.echo("Select model to evaluate:")
|
|
2638
3550
|
for idx, m in enumerate(supported, start=1):
|
|
@@ -2652,70 +3564,402 @@ def eval_command(
|
|
|
2652
3564
|
if api_key:
|
|
2653
3565
|
headers["X-API-Key"] = api_key
|
|
2654
3566
|
|
|
3567
|
+
# Precompute optional policy overrides from TOML
|
|
3568
|
+
policy_overrides: dict[str, Any] = {}
|
|
3569
|
+
try:
|
|
3570
|
+
# Accept [eval.policy] table or top-level keys for convenience
|
|
3571
|
+
if isinstance(cfg.get("policy"), dict):
|
|
3572
|
+
policy_overrides.update(dict(cfg["policy"]))
|
|
3573
|
+
# Back-compat: allow temperature/max_tokens at top level
|
|
3574
|
+
for k in (
|
|
3575
|
+
"temperature",
|
|
3576
|
+
"max_tokens",
|
|
3577
|
+
"reasoning_effort",
|
|
3578
|
+
"system_hint",
|
|
3579
|
+
"tool_choice",
|
|
3580
|
+
"inference_url",
|
|
3581
|
+
):
|
|
3582
|
+
if k in cfg and k not in policy_overrides:
|
|
3583
|
+
policy_overrides[k] = cfg.get(k)
|
|
3584
|
+
except Exception:
|
|
3585
|
+
policy_overrides = {}
|
|
3586
|
+
|
|
3587
|
+
raw_concurrency = cfg.get("concurrency")
|
|
3588
|
+
try:
|
|
3589
|
+
concurrency_limit = int(raw_concurrency) if raw_concurrency is not None else 1
|
|
3590
|
+
except Exception:
|
|
3591
|
+
concurrency_limit = 1
|
|
3592
|
+
if concurrency_limit <= 0:
|
|
3593
|
+
concurrency_limit = 1
|
|
3594
|
+
concurrency_limit = min(concurrency_limit, max(1, len(seed_values)))
|
|
3595
|
+
|
|
3596
|
+
judge_specs: list[JudgeSpec] = []
|
|
3597
|
+
|
|
3598
|
+
def _register_judge(name_hint: str | None, judge_cfg: dict[str, Any]) -> None:
|
|
3599
|
+
if not judge_cfg:
|
|
3600
|
+
return
|
|
3601
|
+
judge_module = judge_cfg.get("module")
|
|
3602
|
+
judge_path = judge_cfg.get("path")
|
|
3603
|
+
judge_callable_name = judge_cfg.get("callable") or judge_cfg.get("function")
|
|
3604
|
+
if judge_module and judge_path:
|
|
3605
|
+
raise click.ClickException("Judge config cannot set both 'module' and 'path'")
|
|
3606
|
+
if not judge_module and not judge_path:
|
|
3607
|
+
raise click.ClickException("Judge config requires 'module' or 'path'")
|
|
3608
|
+
try:
|
|
3609
|
+
if judge_module:
|
|
3610
|
+
module = importlib.import_module(str(judge_module))
|
|
3611
|
+
else:
|
|
3612
|
+
path = Path(str(judge_path)).expanduser()
|
|
3613
|
+
if not path.exists():
|
|
3614
|
+
raise click.ClickException(f"Judge module path not found: {path}")
|
|
3615
|
+
spec = importlib.util.spec_from_file_location(
|
|
3616
|
+
f"_eval_judge_{path.stem}", path
|
|
3617
|
+
)
|
|
3618
|
+
if not spec or not spec.loader:
|
|
3619
|
+
raise click.ClickException(f"Failed to load judge module from {path}")
|
|
3620
|
+
module = importlib.util.module_from_spec(spec)
|
|
3621
|
+
sys.modules[spec.name] = module
|
|
3622
|
+
spec.loader.exec_module(module)
|
|
3623
|
+
except click.ClickException:
|
|
3624
|
+
raise
|
|
3625
|
+
except Exception as exc:
|
|
3626
|
+
raise click.ClickException(f"Unable to load judge module: {exc}") from exc
|
|
3627
|
+
|
|
3628
|
+
if judge_callable_name:
|
|
3629
|
+
try:
|
|
3630
|
+
judge_fn = getattr(module, str(judge_callable_name))
|
|
3631
|
+
except AttributeError as exc:
|
|
3632
|
+
raise click.ClickException(
|
|
3633
|
+
f"Judge callable '{judge_callable_name}' not found in module"
|
|
3634
|
+
) from exc
|
|
3635
|
+
else:
|
|
3636
|
+
if hasattr(module, "judge"):
|
|
3637
|
+
judge_fn = module.judge
|
|
3638
|
+
else:
|
|
3639
|
+
raise click.ClickException("Judge module must expose 'judge' callable")
|
|
3640
|
+
|
|
3641
|
+
if not callable(judge_fn):
|
|
3642
|
+
raise click.ClickException("Judge callable is not callable")
|
|
3643
|
+
|
|
3644
|
+
judge_kwargs = {
|
|
3645
|
+
k: v
|
|
3646
|
+
for k, v in judge_cfg.items()
|
|
3647
|
+
if k not in {"module", "path", "callable", "function", "name"}
|
|
3648
|
+
}
|
|
3649
|
+
display_name = str(
|
|
3650
|
+
judge_cfg.get("name")
|
|
3651
|
+
or name_hint
|
|
3652
|
+
or f"judge{len(judge_specs) + 1}"
|
|
3653
|
+
)
|
|
3654
|
+
judge_specs.append(JudgeSpec(display_name, judge_fn, judge_kwargs))
|
|
3655
|
+
|
|
3656
|
+
raw_judge_cfg = cfg.get("judge")
|
|
3657
|
+
if isinstance(raw_judge_cfg, dict) and raw_judge_cfg:
|
|
3658
|
+
direct_keys = {"module", "path", "callable", "function", "name"}
|
|
3659
|
+
has_direct_keys = any(key in raw_judge_cfg for key in direct_keys)
|
|
3660
|
+
nested_candidates = [
|
|
3661
|
+
(key, value)
|
|
3662
|
+
for key, value in raw_judge_cfg.items()
|
|
3663
|
+
if isinstance(value, dict)
|
|
3664
|
+
]
|
|
3665
|
+
if has_direct_keys and not nested_candidates:
|
|
3666
|
+
_register_judge(None, raw_judge_cfg)
|
|
3667
|
+
else:
|
|
3668
|
+
for sub_name, sub_cfg in nested_candidates:
|
|
3669
|
+
_register_judge(sub_name, sub_cfg)
|
|
3670
|
+
|
|
3671
|
+
raw_judges_list = cfg.get("judges")
|
|
3672
|
+
if isinstance(raw_judges_list, list):
|
|
3673
|
+
for _index, entry in enumerate(raw_judges_list, start=1):
|
|
3674
|
+
if isinstance(entry, dict):
|
|
3675
|
+
_register_judge(entry.get("name") or f"judge{len(judge_specs) + 1}", entry)
|
|
3676
|
+
|
|
3677
|
+
records: list[dict[str, Any]] = []
|
|
3678
|
+
|
|
2655
3679
|
successes = 0
|
|
2656
3680
|
failures = 0
|
|
2657
3681
|
# Aggregate outcome stats across successful seeds
|
|
2658
3682
|
outcome_sum: float = 0.0
|
|
2659
3683
|
outcome_count: int = 0
|
|
2660
3684
|
outcome_correct: int = 0
|
|
2661
|
-
|
|
2662
|
-
|
|
2663
|
-
|
|
2664
|
-
|
|
2665
|
-
|
|
2666
|
-
|
|
2667
|
-
|
|
2668
|
-
|
|
2669
|
-
)
|
|
2670
|
-
|
|
2671
|
-
|
|
2672
|
-
|
|
2673
|
-
|
|
2674
|
-
|
|
2675
|
-
|
|
2676
|
-
|
|
2677
|
-
|
|
2678
|
-
|
|
2679
|
-
|
|
2680
|
-
|
|
2681
|
-
|
|
2682
|
-
|
|
2683
|
-
|
|
2684
|
-
|
|
2685
|
-
"
|
|
2686
|
-
"
|
|
2687
|
-
"
|
|
2688
|
-
|
|
2689
|
-
|
|
2690
|
-
policy_overrides[k] = cfg.get(k)
|
|
2691
|
-
except Exception:
|
|
2692
|
-
policy_overrides = {}
|
|
2693
|
-
|
|
2694
|
-
for seed_val in seed_values:
|
|
2695
|
-
body = {
|
|
2696
|
-
"run_id": str(uuid.uuid4()),
|
|
2697
|
-
"env": {"config": {"split": split, "index": seed_val}, "seed": seed_val},
|
|
2698
|
-
"policy": {
|
|
2699
|
-
"policy_name": selected_model,
|
|
2700
|
-
"config": {"model": selected_model, **policy_overrides},
|
|
2701
|
-
},
|
|
2702
|
-
"ops": [],
|
|
3685
|
+
|
|
3686
|
+
def _build_task_rows(taskset: Any) -> dict[int, dict[str, Any]]:
|
|
3687
|
+
rows: dict[int, dict[str, Any]] = {}
|
|
3688
|
+
if not isinstance(taskset, dict):
|
|
3689
|
+
return rows
|
|
3690
|
+
|
|
3691
|
+
scenario_ids = taskset.get("scenario_ids") or []
|
|
3692
|
+
loop_ids = taskset.get("loop_ids") or []
|
|
3693
|
+
thread_ids = taskset.get("thread_ids") or []
|
|
3694
|
+
difficulty_map = taskset.get("difficulty_map") or {}
|
|
3695
|
+
|
|
3696
|
+
max_len = max(len(scenario_ids), len(loop_ids), len(thread_ids))
|
|
3697
|
+
for seed in range(max_len):
|
|
3698
|
+
scenario_id = scenario_ids[seed] if seed < len(scenario_ids) else None
|
|
3699
|
+
loop_id = loop_ids[seed] if seed < len(loop_ids) else None
|
|
3700
|
+
thread_id = thread_ids[seed] if seed < len(thread_ids) else None
|
|
3701
|
+
difficulty = None
|
|
3702
|
+
if isinstance(difficulty_map, dict):
|
|
3703
|
+
if scenario_id and scenario_id in difficulty_map:
|
|
3704
|
+
difficulty = difficulty_map.get(scenario_id)
|
|
3705
|
+
elif str(seed) in difficulty_map:
|
|
3706
|
+
difficulty = difficulty_map.get(str(seed))
|
|
3707
|
+
|
|
3708
|
+
rows[seed] = {
|
|
3709
|
+
"seed": seed,
|
|
3710
|
+
"scenario_id": scenario_id,
|
|
3711
|
+
"loop_id": loop_id,
|
|
3712
|
+
"thread_id": thread_id,
|
|
3713
|
+
"difficulty": difficulty,
|
|
2703
3714
|
}
|
|
3715
|
+
return rows
|
|
3716
|
+
|
|
3717
|
+
def _apply_metadata_filters(
|
|
3718
|
+
rows: dict[int, dict[str, Any]], seeds_list: list[int], filters: dict[str, str]
|
|
3719
|
+
) -> list[int]:
|
|
3720
|
+
if not filters:
|
|
3721
|
+
return seeds_list
|
|
3722
|
+
filtered: list[int] = []
|
|
3723
|
+
for seed in seeds_list:
|
|
3724
|
+
row = rows.get(seed)
|
|
3725
|
+
if not row:
|
|
3726
|
+
continue
|
|
3727
|
+
include = True
|
|
3728
|
+
for key, expected in filters.items():
|
|
3729
|
+
actual = row.get(key)
|
|
3730
|
+
if actual is None:
|
|
3731
|
+
include = False
|
|
3732
|
+
break
|
|
3733
|
+
if str(actual).lower() != expected.lower():
|
|
3734
|
+
include = False
|
|
3735
|
+
break
|
|
3736
|
+
if include:
|
|
3737
|
+
filtered.append(seed)
|
|
3738
|
+
return filtered
|
|
3739
|
+
|
|
3740
|
+
def _apply_metadata_sql(
|
|
3741
|
+
rows: dict[int, dict[str, Any]], seeds_list: list[int], query: str
|
|
3742
|
+
) -> list[int]:
|
|
3743
|
+
"""Return seeds that satisfy an arbitrary SQL query.
|
|
3744
|
+
|
|
3745
|
+
The query is executed against an in-memory SQLite table named `tasks`
|
|
3746
|
+
with columns (seed INTEGER, scenario_id TEXT, loop_id TEXT, thread_id TEXT, difficulty TEXT).
|
|
3747
|
+
Any rows whose `seed` value (or first column if `seed` is absent) appear in the result set are retained.
|
|
3748
|
+
"""
|
|
3749
|
+
if not query:
|
|
3750
|
+
return seeds_list
|
|
3751
|
+
conn = sqlite3.connect(":memory:")
|
|
3752
|
+
try:
|
|
3753
|
+
cur = conn.cursor()
|
|
3754
|
+
cur.execute(
|
|
3755
|
+
"CREATE TABLE tasks (seed INTEGER, scenario_id TEXT, loop_id TEXT, thread_id TEXT, difficulty TEXT)"
|
|
3756
|
+
)
|
|
3757
|
+
insert_stmt = (
|
|
3758
|
+
"INSERT INTO tasks (seed, scenario_id, loop_id, thread_id, difficulty) VALUES (?,?,?,?,?)"
|
|
3759
|
+
)
|
|
3760
|
+
for seed in seeds_list:
|
|
3761
|
+
row = rows.get(seed, {})
|
|
3762
|
+
cur.execute(
|
|
3763
|
+
insert_stmt,
|
|
3764
|
+
[
|
|
3765
|
+
seed,
|
|
3766
|
+
row.get("scenario_id"),
|
|
3767
|
+
row.get("loop_id"),
|
|
3768
|
+
row.get("thread_id"),
|
|
3769
|
+
row.get("difficulty"),
|
|
3770
|
+
],
|
|
3771
|
+
)
|
|
3772
|
+
|
|
3773
|
+
result = cur.execute(query)
|
|
3774
|
+
fetched = result.fetchall()
|
|
3775
|
+
if not fetched:
|
|
3776
|
+
return []
|
|
3777
|
+
description = result.description or []
|
|
3778
|
+
col_names = [col[0] for col in description]
|
|
3779
|
+
seeds_out: list[int] = []
|
|
3780
|
+
for entry in fetched:
|
|
3781
|
+
value = entry[col_names.index("seed")] if "seed" in col_names else entry[0]
|
|
3782
|
+
try:
|
|
3783
|
+
seeds_out.append(int(value))
|
|
3784
|
+
except Exception as exc:
|
|
3785
|
+
raise click.ClickException(
|
|
3786
|
+
"metadata SQL query must return seed integers"
|
|
3787
|
+
) from exc
|
|
3788
|
+
seeds_set = set(seeds_out)
|
|
3789
|
+
return [seed for seed in seeds_list if seed in seeds_set]
|
|
3790
|
+
except sqlite3.Error as exc:
|
|
3791
|
+
raise click.ClickException(f"Failed to execute metadata SQL query: {exc}") from exc
|
|
3792
|
+
finally:
|
|
3793
|
+
conn.close()
|
|
3794
|
+
|
|
3795
|
+
async def _run_eval() -> None:
|
|
3796
|
+
nonlocal successes, failures, outcome_sum, outcome_count, outcome_correct, records, seed_values
|
|
3797
|
+
|
|
3798
|
+
if trace_tracer is not None and trace_tracer.db is None:
|
|
3799
|
+
await trace_tracer.initialize()
|
|
3800
|
+
|
|
3801
|
+
if task_app_url is None:
|
|
3802
|
+
transport = httpx.ASGITransport(app=app) # type: ignore[name-defined]
|
|
3803
|
+
async_client = httpx.AsyncClient(
|
|
3804
|
+
transport=cast(Any, transport),
|
|
3805
|
+
base_url="http://eval.local",
|
|
3806
|
+
timeout=300.0,
|
|
3807
|
+
follow_redirects=True,
|
|
3808
|
+
headers=headers,
|
|
3809
|
+
)
|
|
3810
|
+
else:
|
|
3811
|
+
async_client = httpx.AsyncClient(
|
|
3812
|
+
base_url=task_app_url,
|
|
3813
|
+
timeout=300.0,
|
|
3814
|
+
follow_redirects=True,
|
|
3815
|
+
headers=headers,
|
|
3816
|
+
)
|
|
3817
|
+
|
|
3818
|
+
try:
|
|
3819
|
+
taskset_payload: dict[str, Any] | None = None
|
|
2704
3820
|
try:
|
|
2705
|
-
|
|
2706
|
-
|
|
3821
|
+
task_info_response = await async_client.get("/task_info")
|
|
3822
|
+
except Exception:
|
|
3823
|
+
task_info_response = None
|
|
3824
|
+
if task_info_response is not None and task_info_response.status_code == 200:
|
|
3825
|
+
with contextlib.suppress(Exception):
|
|
3826
|
+
payload_json = task_info_response.json()
|
|
3827
|
+
if isinstance(payload_json, dict) and "taskset" in payload_json:
|
|
3828
|
+
taskset_payload = payload_json.get("taskset")
|
|
3829
|
+
if not isinstance(taskset_payload, dict):
|
|
3830
|
+
taskset_payload = None
|
|
3831
|
+
elif isinstance(payload_json, dict):
|
|
3832
|
+
taskset_payload = payload_json
|
|
3833
|
+
|
|
3834
|
+
available_seeds = list(seed_values)
|
|
3835
|
+
if metadata_sql_query or metadata_filters:
|
|
3836
|
+
if not taskset_payload:
|
|
3837
|
+
raise click.ClickException(
|
|
3838
|
+
"Task metadata filters require the task app to expose /task_info metadata"
|
|
3839
|
+
)
|
|
3840
|
+
rows = _build_task_rows(taskset_payload)
|
|
3841
|
+
if metadata_sql_query:
|
|
3842
|
+
available_seeds = _apply_metadata_sql(rows, available_seeds, metadata_sql_query)
|
|
3843
|
+
if metadata_filters:
|
|
3844
|
+
available_seeds = _apply_metadata_filters(rows, available_seeds, metadata_filters)
|
|
3845
|
+
if not available_seeds:
|
|
3846
|
+
raise click.ClickException("No seeds match the provided metadata filters")
|
|
3847
|
+
seed_values = available_seeds
|
|
3848
|
+
|
|
3849
|
+
semaphore = asyncio.Semaphore(concurrency_limit)
|
|
3850
|
+
|
|
3851
|
+
async def _run_seed(seed_val: int) -> None:
|
|
3852
|
+
nonlocal successes, failures, outcome_sum, outcome_count, outcome_correct, records
|
|
3853
|
+
# Read env_name and policy_name from config if available
|
|
3854
|
+
env_name = cfg.get("env_name") or (cfg.get("env", {}).get("env_name") if isinstance(cfg.get("env"), dict) else None)
|
|
3855
|
+
policy_name = cfg.get("policy_name") or (cfg.get("policy", {}).get("policy_name") if isinstance(cfg.get("policy"), dict) else None)
|
|
3856
|
+
env_config_overrides = cfg.get("env_config", {}) if isinstance(cfg.get("env_config"), dict) else {}
|
|
3857
|
+
policy_config_overrides = cfg.get("policy_config", {}) if isinstance(cfg.get("policy_config"), dict) else {}
|
|
3858
|
+
|
|
3859
|
+
# Debug: print config parsing
|
|
3860
|
+
if seed_val == 0:
|
|
3861
|
+
click.echo(f"[DEBUG] env_name from config: {env_name}")
|
|
3862
|
+
click.echo(f"[DEBUG] policy_name from config: {policy_name}")
|
|
3863
|
+
|
|
3864
|
+
# Generate default ops sequence if not provided
|
|
3865
|
+
max_llm_calls = policy_config_overrides.get("max_llm_calls", 10)
|
|
3866
|
+
ops_list = cfg.get("ops", [])
|
|
3867
|
+
if not ops_list:
|
|
3868
|
+
# Generate default "agent, env" pairs for max_llm_calls
|
|
3869
|
+
ops_list = ["agent", "env"] * int(max_llm_calls)
|
|
3870
|
+
|
|
3871
|
+
body = {
|
|
3872
|
+
"run_id": str(uuid.uuid4()),
|
|
3873
|
+
"env": {"config": {"split": split, "index": seed_val, **env_config_overrides}, "seed": seed_val},
|
|
3874
|
+
"policy": {
|
|
3875
|
+
"policy_name": policy_name or selected_model,
|
|
3876
|
+
"config": {"model": selected_model, **policy_overrides, **policy_config_overrides},
|
|
3877
|
+
},
|
|
3878
|
+
"ops": ops_list,
|
|
3879
|
+
"record": {
|
|
3880
|
+
"return_trace": cfg.get("return_trace", True),
|
|
3881
|
+
"trace_format": cfg.get("trace_format", "structured"),
|
|
3882
|
+
},
|
|
3883
|
+
"mode": "eval", # RolloutMode.EVAL: use inference URLs as-is, no transformations
|
|
3884
|
+
}
|
|
3885
|
+
if env_name:
|
|
3886
|
+
body["env"]["env_name"] = env_name
|
|
3887
|
+
|
|
3888
|
+
# Debug: print the body being sent
|
|
3889
|
+
if seed_val == 0:
|
|
3890
|
+
click.echo(f"[DEBUG] rollout body env: {body['env']}")
|
|
3891
|
+
click.echo(f"[DEBUG] rollout body policy: {body['policy']}")
|
|
3892
|
+
click.echo(f"[DEBUG] rollout body mode: {body.get('mode', 'NOT SET')}")
|
|
3893
|
+
rollout_elapsed: float | None = None
|
|
3894
|
+
rollout_start = time.perf_counter()
|
|
3895
|
+
try:
|
|
3896
|
+
import logging
|
|
3897
|
+
_log = logging.getLogger(__name__)
|
|
3898
|
+
_log.info(f"[EVAL_BODY_DEBUG] Sending body with mode={body.get('mode')}")
|
|
3899
|
+
async with semaphore:
|
|
3900
|
+
response = await async_client.post("/rollout", json=body)
|
|
3901
|
+
rollout_elapsed = time.perf_counter() - rollout_start
|
|
3902
|
+
except Exception as exc:
|
|
3903
|
+
failures += 1
|
|
3904
|
+
click.echo(f"seed={seed_val} error={exc}")
|
|
3905
|
+
return
|
|
3906
|
+
|
|
3907
|
+
ok = 200 <= response.status_code < 300
|
|
2707
3908
|
if ok:
|
|
2708
3909
|
successes += 1
|
|
2709
3910
|
else:
|
|
2710
3911
|
failures += 1
|
|
2711
3912
|
|
|
2712
|
-
|
|
2713
|
-
|
|
3913
|
+
summary = [f"seed={seed_val}", f"status={response.status_code}"]
|
|
3914
|
+
data: Any
|
|
2714
3915
|
try:
|
|
2715
|
-
data =
|
|
3916
|
+
data = response.json()
|
|
2716
3917
|
except Exception:
|
|
2717
3918
|
data = None
|
|
3919
|
+
|
|
3920
|
+
# Debug: print validation errors
|
|
3921
|
+
if response.status_code == 422 and data:
|
|
3922
|
+
click.echo(f"[DEBUG] 422 Validation Error: {data}")
|
|
3923
|
+
|
|
3924
|
+
metrics: dict[str, Any] | None = None
|
|
3925
|
+
completion: str | None = None
|
|
3926
|
+
prompt_index: int | None = None
|
|
3927
|
+
prompt_text: str | None = None
|
|
3928
|
+
task_id: str | None = None
|
|
3929
|
+
task_split: str | None = None
|
|
3930
|
+
task_rubric_id: str | None = None
|
|
3931
|
+
|
|
3932
|
+
trace_namespace: dict[str, Any] | None = None
|
|
3933
|
+
session_trace_dict: dict[str, Any] | None = None
|
|
3934
|
+
|
|
2718
3935
|
if isinstance(data, dict):
|
|
3936
|
+
import logging
|
|
3937
|
+
_logger = logging.getLogger(__name__)
|
|
3938
|
+
_logger.info(f"[EVAL_DEBUG] Response data keys: {list(data.keys())}")
|
|
3939
|
+
if "detail" in data:
|
|
3940
|
+
_logger.error(f"[EVAL_DEBUG] Task app returned error: {data['detail']}")
|
|
3941
|
+
trace_namespace = data.get("trace")
|
|
3942
|
+
_logger.info(f"[EVAL_DEBUG] trace_namespace type: {type(trace_namespace)}, value: {trace_namespace if not isinstance(trace_namespace, dict) else 'dict with keys: ' + str(list(trace_namespace.keys()) if trace_namespace else 'None')}")
|
|
3943
|
+
if not isinstance(trace_namespace, dict):
|
|
3944
|
+
raise RuntimeError(
|
|
3945
|
+
"The 'synth-ai eval' command requires trace payloads in rollout responses. "
|
|
3946
|
+
"Ensure the rollout request includes 'trace_format': 'structured' and 'return_trace': true, "
|
|
3947
|
+
"and that task app tracing is enabled (TASKAPP_TRACING_ENABLED=1). "
|
|
3948
|
+
"Note: This is specific to the eval command - general rollout endpoints don't require traces."
|
|
3949
|
+
)
|
|
3950
|
+
# Handle both "compact" and "full" trace formats:
|
|
3951
|
+
# - compact: trace_namespace contains {session_id, metadata, ...}
|
|
3952
|
+
# - full: trace_namespace IS the full session_trace dict
|
|
3953
|
+
session_trace_dict = trace_namespace.get("session_trace")
|
|
3954
|
+
if not isinstance(session_trace_dict, dict):
|
|
3955
|
+
# If no session_trace key, assume "full" format where trace itself is the session_trace
|
|
3956
|
+
if "session_id" in trace_namespace:
|
|
3957
|
+
session_trace_dict = trace_namespace
|
|
3958
|
+
else:
|
|
3959
|
+
raise RuntimeError(
|
|
3960
|
+
"The 'synth-ai eval' command requires 'session_trace' in the trace payload or a valid full trace format. "
|
|
3961
|
+
"Ensure the task app is using tracing_v3 and returning structured trace data."
|
|
3962
|
+
)
|
|
2719
3963
|
metrics = data.get("metrics") if isinstance(data.get("metrics"), dict) else None
|
|
2720
3964
|
if metrics:
|
|
2721
3965
|
mean_return = metrics.get("mean_return") or metrics.get("total_reward")
|
|
@@ -2724,7 +3968,6 @@ def eval_command(
|
|
|
2724
3968
|
summary.append(f"mean_return={mean_return}")
|
|
2725
3969
|
if outcome is not None:
|
|
2726
3970
|
summary.append(f"outcome={outcome}")
|
|
2727
|
-
# Aggregate outcome stats
|
|
2728
3971
|
try:
|
|
2729
3972
|
val = float(outcome)
|
|
2730
3973
|
outcome_sum += val
|
|
@@ -2733,7 +3976,6 @@ def eval_command(
|
|
|
2733
3976
|
outcome_correct += 1
|
|
2734
3977
|
except Exception:
|
|
2735
3978
|
pass
|
|
2736
|
-
# Try to infer tool call count from first trajectory step
|
|
2737
3979
|
trajs = (
|
|
2738
3980
|
data.get("trajectories")
|
|
2739
3981
|
if isinstance(data.get("trajectories"), list)
|
|
@@ -2747,38 +3989,164 @@ def eval_command(
|
|
|
2747
3989
|
tool_calls = step0.get("tool_calls") or step0.get("tools") or []
|
|
2748
3990
|
if isinstance(tool_calls, list):
|
|
2749
3991
|
summary.append(f"tool_calls={len(tool_calls)}")
|
|
3992
|
+
obs = step0.get("obs") if isinstance(step0, dict) else None
|
|
3993
|
+
if isinstance(obs, dict):
|
|
3994
|
+
idx_val = obs.get("prompt_index")
|
|
3995
|
+
if isinstance(idx_val, int):
|
|
3996
|
+
prompt_index = idx_val
|
|
3997
|
+
prompt_raw = obs.get("prompt")
|
|
3998
|
+
if isinstance(prompt_raw, str):
|
|
3999
|
+
prompt_text = prompt_raw
|
|
4000
|
+
if task_id is None:
|
|
4001
|
+
candidate_id = obs.get("task_id")
|
|
4002
|
+
if isinstance(candidate_id, str) and candidate_id:
|
|
4003
|
+
task_id = candidate_id
|
|
4004
|
+
if task_split is None:
|
|
4005
|
+
candidate_split = obs.get("task_split")
|
|
4006
|
+
if isinstance(candidate_split, str) and candidate_split:
|
|
4007
|
+
task_split = candidate_split
|
|
4008
|
+
if task_rubric_id is None:
|
|
4009
|
+
candidate_rid = obs.get("task_rubric_id")
|
|
4010
|
+
if isinstance(candidate_rid, str) and candidate_rid:
|
|
4011
|
+
task_rubric_id = candidate_rid
|
|
4012
|
+
final = first.get("final") if isinstance(first, dict) else None
|
|
4013
|
+
if isinstance(final, dict):
|
|
4014
|
+
final_obs = final.get("observation")
|
|
4015
|
+
if isinstance(final_obs, dict):
|
|
4016
|
+
comp_val = final_obs.get("completion")
|
|
4017
|
+
if isinstance(comp_val, str):
|
|
4018
|
+
completion = comp_val
|
|
4019
|
+
if task_id is None:
|
|
4020
|
+
candidate_id = final_obs.get("task_id")
|
|
4021
|
+
if isinstance(candidate_id, str) and candidate_id:
|
|
4022
|
+
task_id = candidate_id
|
|
4023
|
+
if task_split is None:
|
|
4024
|
+
candidate_split = final_obs.get("task_split")
|
|
4025
|
+
if isinstance(candidate_split, str) and candidate_split:
|
|
4026
|
+
task_split = candidate_split
|
|
4027
|
+
if task_rubric_id is None:
|
|
4028
|
+
candidate_rid = final_obs.get("task_rubric_id")
|
|
4029
|
+
if isinstance(candidate_rid, str) and candidate_rid:
|
|
4030
|
+
task_rubric_id = candidate_rid
|
|
4031
|
+
final_info = final.get("info")
|
|
4032
|
+
if isinstance(final_info, dict):
|
|
4033
|
+
if task_id is None:
|
|
4034
|
+
candidate_id = final_info.get("task_id")
|
|
4035
|
+
if isinstance(candidate_id, str) and candidate_id:
|
|
4036
|
+
task_id = candidate_id
|
|
4037
|
+
if task_split is None:
|
|
4038
|
+
candidate_split = final_info.get("task_split")
|
|
4039
|
+
if isinstance(candidate_split, str) and candidate_split:
|
|
4040
|
+
task_split = candidate_split
|
|
4041
|
+
if task_rubric_id is None:
|
|
4042
|
+
candidate_rid = final_info.get("task_rubric_id")
|
|
4043
|
+
if isinstance(candidate_rid, str) and candidate_rid:
|
|
4044
|
+
task_rubric_id = candidate_rid
|
|
4045
|
+
if task_id:
|
|
4046
|
+
summary.append(f"task_id={task_id}")
|
|
2750
4047
|
click.echo(" ".join(summary))
|
|
2751
|
-
# Print the full response JSON (trace, trajectories, metrics)
|
|
2752
4048
|
with contextlib.suppress(Exception):
|
|
2753
4049
|
click.echo(json.dumps(data, indent=2))
|
|
2754
4050
|
else:
|
|
2755
4051
|
click.echo(" ".join(summary))
|
|
2756
|
-
except Exception as exc:
|
|
2757
|
-
failures += 1
|
|
2758
|
-
click.echo(f"seed={seed_val} error={exc}")
|
|
2759
4052
|
|
|
2760
|
-
|
|
2761
|
-
|
|
2762
|
-
|
|
2763
|
-
|
|
2764
|
-
|
|
2765
|
-
|
|
2766
|
-
|
|
2767
|
-
|
|
2768
|
-
except RuntimeError:
|
|
2769
|
-
# Fallback when already inside a running loop (rare for CLI).
|
|
2770
|
-
new_loop = asyncio.new_event_loop()
|
|
4053
|
+
official_score = None
|
|
4054
|
+
if isinstance(metrics, dict):
|
|
4055
|
+
for key in ("mean_return", "total_reward", "outcome_score"):
|
|
4056
|
+
val = metrics.get(key)
|
|
4057
|
+
if isinstance(val, int | float):
|
|
4058
|
+
official_score = float(val)
|
|
4059
|
+
break
|
|
4060
|
+
if official_score is None and isinstance(data, dict):
|
|
2771
4061
|
try:
|
|
2772
|
-
|
|
2773
|
-
|
|
2774
|
-
|
|
2775
|
-
|
|
2776
|
-
|
|
4062
|
+
reward_val = data["trajectories"][0]["steps"][0].get("reward")
|
|
4063
|
+
if isinstance(reward_val, int | float):
|
|
4064
|
+
official_score = float(reward_val)
|
|
4065
|
+
except Exception:
|
|
4066
|
+
pass
|
|
4067
|
+
|
|
4068
|
+
if official_score is not None:
|
|
4069
|
+
if official_score < 0.0:
|
|
4070
|
+
official_score = 0.0
|
|
4071
|
+
elif official_score > 1.0:
|
|
4072
|
+
official_score = min(1.0, official_score)
|
|
4073
|
+
|
|
4074
|
+
judge_scores: dict[str, float | None] = {}
|
|
4075
|
+
judges_timings: dict[str, float | None] = {}
|
|
4076
|
+
timings: dict[str, Any] = {
|
|
4077
|
+
"rollout_s": rollout_elapsed,
|
|
4078
|
+
"judges": judges_timings,
|
|
4079
|
+
}
|
|
4080
|
+
if judge_specs:
|
|
4081
|
+
for spec in judge_specs:
|
|
4082
|
+
score_value: float | None = None
|
|
4083
|
+
judge_elapsed: float | None = None
|
|
4084
|
+
# Run judges for all tasks (text-based and trajectory-based)
|
|
4085
|
+
# Text-based tasks have completion, trajectory-based tasks use response
|
|
4086
|
+
judge_payload = {
|
|
4087
|
+
"seed": seed_val,
|
|
4088
|
+
"prompt_index": prompt_index,
|
|
4089
|
+
"prompt": prompt_text,
|
|
4090
|
+
"completion": completion,
|
|
4091
|
+
"metrics": metrics,
|
|
4092
|
+
"response": data,
|
|
4093
|
+
"trace": trace_namespace,
|
|
4094
|
+
}
|
|
4095
|
+
try:
|
|
4096
|
+
judge_start = time.perf_counter()
|
|
4097
|
+
result = spec.fn(judge_payload, **spec.kwargs)
|
|
4098
|
+
judge_elapsed = time.perf_counter() - judge_start
|
|
4099
|
+
if isinstance(result, int | float):
|
|
4100
|
+
score_value = float(result)
|
|
4101
|
+
except Exception as exc:
|
|
4102
|
+
if judge_elapsed is None:
|
|
4103
|
+
judge_elapsed = time.perf_counter() - judge_start
|
|
4104
|
+
click.echo(f"seed={seed_val} judge[{spec.name}]_error={exc}")
|
|
4105
|
+
judges_timings[spec.name] = judge_elapsed
|
|
4106
|
+
judge_scores[spec.name] = score_value
|
|
4107
|
+
|
|
4108
|
+
if trace_tracer is not None and trace_namespace:
|
|
4109
|
+
storage_metadata = {
|
|
4110
|
+
"eval_seed": seed_val,
|
|
4111
|
+
"prompt_index": prompt_index,
|
|
4112
|
+
"task_id": task_id,
|
|
4113
|
+
"task_split": task_split,
|
|
4114
|
+
"task_rubric_id": task_rubric_id,
|
|
4115
|
+
"official_score": official_score,
|
|
4116
|
+
"judge_scores": judge_scores,
|
|
4117
|
+
"model": selected_model,
|
|
4118
|
+
"prompt": prompt_text,
|
|
4119
|
+
"completion": completion,
|
|
4120
|
+
}
|
|
4121
|
+
await _store_trace(trace_tracer, trace_namespace, storage_metadata)
|
|
4122
|
+
|
|
4123
|
+
records.append(
|
|
4124
|
+
{
|
|
4125
|
+
"seed": seed_val,
|
|
4126
|
+
"prompt_index": prompt_index,
|
|
4127
|
+
"task_id": task_id,
|
|
4128
|
+
"task_split": task_split,
|
|
4129
|
+
"task_rubric_id": task_rubric_id,
|
|
4130
|
+
"official_score": official_score,
|
|
4131
|
+
"judge_scores": judge_scores,
|
|
4132
|
+
"timings": timings,
|
|
4133
|
+
}
|
|
4134
|
+
)
|
|
4135
|
+
|
|
4136
|
+
await asyncio.gather(*[_run_seed(seed_val) for seed_val in seed_values])
|
|
4137
|
+
finally:
|
|
4138
|
+
await async_client.aclose()
|
|
4139
|
+
|
|
4140
|
+
try:
|
|
4141
|
+
asyncio.run(_run_eval())
|
|
4142
|
+
finally:
|
|
4143
|
+
if trace_tracer is not None and trace_tracer.db is not None:
|
|
4144
|
+
asyncio.run(trace_tracer.db.close())
|
|
2777
4145
|
|
|
2778
4146
|
click.echo(
|
|
2779
4147
|
f"Eval complete: {successes} ok, {failures} failed; model={selected_model}, split={split}"
|
|
2780
4148
|
)
|
|
2781
|
-
|
|
4149
|
+
|
|
2782
4150
|
if outcome_count > 0:
|
|
2783
4151
|
mean_outcome = outcome_sum / float(outcome_count)
|
|
2784
4152
|
frac_right = outcome_correct / float(outcome_count)
|
|
@@ -2786,6 +4154,370 @@ def eval_command(
|
|
|
2786
4154
|
f"Outcome summary: correct={outcome_correct}/{outcome_count} ({frac_right:.2%}), mean_outcome={mean_outcome:.3f}"
|
|
2787
4155
|
)
|
|
2788
4156
|
|
|
4157
|
+
if records:
|
|
4158
|
+
judge_specs = judge_specs or [] # ensure iterable
|
|
4159
|
+
official_scores = [
|
|
4160
|
+
r["official_score"] for r in records if r["official_score"] is not None
|
|
4161
|
+
]
|
|
4162
|
+
if official_scores:
|
|
4163
|
+
click.echo(f" Official mean: {sum(official_scores) / len(official_scores):.3f}")
|
|
4164
|
+
else:
|
|
4165
|
+
click.echo(" Official mean: n/a")
|
|
4166
|
+
|
|
4167
|
+
for spec in judge_specs:
|
|
4168
|
+
spec_scores = [
|
|
4169
|
+
record["judge_scores"].get(spec.name)
|
|
4170
|
+
for record in records
|
|
4171
|
+
if record["judge_scores"].get(spec.name) is not None
|
|
4172
|
+
]
|
|
4173
|
+
if spec_scores:
|
|
4174
|
+
mean_spec = sum(spec_scores) / len(spec_scores)
|
|
4175
|
+
click.echo(f" [{spec.name}] mean: {mean_spec:.3f}")
|
|
4176
|
+
else:
|
|
4177
|
+
click.echo(f" [{spec.name}] mean: n/a")
|
|
4178
|
+
|
|
4179
|
+
paired = [
|
|
4180
|
+
(
|
|
4181
|
+
record["official_score"],
|
|
4182
|
+
record["judge_scores"].get(spec.name),
|
|
4183
|
+
)
|
|
4184
|
+
for record in records
|
|
4185
|
+
if record["official_score"] is not None
|
|
4186
|
+
and record["judge_scores"].get(spec.name) is not None
|
|
4187
|
+
]
|
|
4188
|
+
if len(paired) >= 2:
|
|
4189
|
+
corr = _pearson(
|
|
4190
|
+
[p[0] for p in paired if p[0] is not None],
|
|
4191
|
+
[p[1] for p in paired if p[1] is not None],
|
|
4192
|
+
)
|
|
4193
|
+
if corr is not None:
|
|
4194
|
+
click.echo(f" Pearson r: {corr:.3f}")
|
|
4195
|
+
else:
|
|
4196
|
+
click.echo(" Pearson r: undefined (zero variance)")
|
|
4197
|
+
else:
|
|
4198
|
+
click.echo(" Pearson r: n/a (need ≥2 paired scores)")
|
|
4199
|
+
|
|
4200
|
+
header = ["Seed", "Prompt", "Official"]
|
|
4201
|
+
header.extend(spec.name for spec in judge_specs)
|
|
4202
|
+
rows: list[list[str]] = []
|
|
4203
|
+
for record in sorted(records, key=lambda r: (r["seed"], r.get("prompt_index") or -1)):
|
|
4204
|
+
seed_val = str(record["seed"])
|
|
4205
|
+
prompt_idx = (
|
|
4206
|
+
str(record["prompt_index"])
|
|
4207
|
+
if record["prompt_index"] is not None
|
|
4208
|
+
else "-"
|
|
4209
|
+
)
|
|
4210
|
+
official_val = (
|
|
4211
|
+
f"{record['official_score']:.3f}"
|
|
4212
|
+
if record["official_score"] is not None
|
|
4213
|
+
else "-"
|
|
4214
|
+
)
|
|
4215
|
+
row = [seed_val, prompt_idx, official_val]
|
|
4216
|
+
for spec in judge_specs:
|
|
4217
|
+
score_val = record["judge_scores"].get(spec.name)
|
|
4218
|
+
row.append(f"{score_val:.3f}" if isinstance(score_val, int | float) else "-")
|
|
4219
|
+
rows.append(row)
|
|
4220
|
+
|
|
4221
|
+
widths = [len(col) for col in header]
|
|
4222
|
+
for row in rows:
|
|
4223
|
+
for idx, cell in enumerate(row):
|
|
4224
|
+
widths[idx] = max(widths[idx], len(cell))
|
|
4225
|
+
|
|
4226
|
+
click.echo("")
|
|
4227
|
+
click.echo(" ".join(h.ljust(widths[idx]) for idx, h in enumerate(header)))
|
|
4228
|
+
click.echo(" ".join("-" * widths[idx] for idx in range(len(header))))
|
|
4229
|
+
for row in rows:
|
|
4230
|
+
click.echo(" ".join(cell.ljust(widths[idx]) for idx, cell in enumerate(row)))
|
|
4231
|
+
|
|
4232
|
+
|
|
4233
|
+
|
|
4234
|
+
@click.command(
    "filter",
    help="Export filtered tracing sessions to SFT-ready JSONL based on a TOML config.",
)
@click.option(
    "--config",
    "config_path",
    type=click.Path(),
    required=True,
    help="Path to TOML config describing the input trace DB, score thresholds, and output JSONL.",
)
def filter_command(config_path: str) -> None:
    """Render tracing sessions that match filter rules into SFT JSONL.

    The TOML file should contain a `[filter]` table with at least:

        db = "path/to/traces.db"      # sqlite path or URL (sqlite+aiosqlite://...)
        output = "ft_data/out.jsonl"  # destination JSONL

    Optional keys such as `splits`, `task_ids`, `models`, `min_official_score`, or
    `min_judge_scores.my_judge = 0.7` allow you to narrow the dataset down to
    high-quality traces. See `customers/agora_single_file/configs/filter_local.toml`
    for a working example.

    Each accepted user prompt becomes one JSONL line shaped like
    ``{"messages": [user, assistant], "metadata": {...}}``. Sessions with no rows
    in the ``messages`` table fall back to ``prompt``/``completion`` stored in the
    session metadata.

    Raises:
        click.ClickException: if the TOML parser is unavailable, the config is
            missing or invalid, the database holds no traces, or no session
            passes the filters.
    """
    # Parse and validate TOML config
    from synth_ai.task.config import FilterConfig

    if _toml is None:
        raise click.ClickException("TOML parser not available; install tomli or use Python 3.11+")

    cfg_path = Path(config_path)
    if not cfg_path.exists():
        raise click.ClickException(f"Filter config not found: {cfg_path}")

    try:
        config_data = _toml.loads(cfg_path.read_text(encoding="utf-8"))
    except Exception as exc:
        raise click.ClickException(f"Failed to parse TOML '{cfg_path}': {exc}") from exc

    filter_cfg_dict = config_data.get("filter") if isinstance(config_data, dict) else None
    if not isinstance(filter_cfg_dict, dict):
        raise click.ClickException("Config must contain a [filter] table")

    # Validate config with dataclass
    try:
        filter_cfg = FilterConfig.from_dict(filter_cfg_dict)
        click.echo(f"✓ Config validated: db={filter_cfg.db}, output={filter_cfg.output}")
        if filter_cfg.min_official_score is not None:
            click.echo(f"  → Filtering for official score >= {filter_cfg.min_official_score}")
        if filter_cfg.limit:
            click.echo(f"  → Limiting to {filter_cfg.limit} examples")
    except (ValueError, TypeError) as validation_error:
        raise click.ClickException(f"Invalid filter config: {validation_error}") from validation_error

    # Use validated config
    db_url = filter_cfg.get_db_url()
    output_path = filter_cfg.get_output_path()

    # Extract validated fields from dataclass
    splits = set(filter_cfg.splits)
    task_ids = set(filter_cfg.task_ids)
    models = set(filter_cfg.models)
    min_official = filter_cfg.min_official_score
    max_official = filter_cfg.max_official_score
    min_judge_scores = filter_cfg.min_judge_scores
    max_judge_scores = filter_cfg.max_judge_scores
    # Note: min_created_at and max_created_at not yet in FilterConfig dataclass
    min_created = _parse_datetime_for_trace(filter_cfg_dict.get("min_created_at"))
    max_created = _parse_datetime_for_trace(filter_cfg_dict.get("max_created_at"))
    limit = filter_cfg.limit

    # message_type values treated as the start of a user turn.
    user_types = ("user", "policy_user_prompt")

    messages_query = """
        SELECT message_type, content, timestamp
        FROM messages
        WHERE session_id = :session_id
        ORDER BY timestamp ASC, id ASC
    """

    def _score_ok(value: Any, min_val: Any, max_val: Any) -> bool:
        """Return True when *value* lies within [min_val, max_val] (either bound optional).

        A missing value (None) passes only when no minimum is required; values
        that cannot be converted to float fail.
        """
        try:
            if value is None:
                return min_val is None
            value = float(value)
        except Exception:
            return False
        if min_val is not None and value < float(min_val):
            return False
        return not (max_val is not None and value > float(max_val))

    def _parse_metadata(raw: Any) -> dict[str, Any]:
        """Decode a session's metadata column into a dict ({} when absent or malformed)."""
        if isinstance(raw, str):
            try:
                raw = json.loads(raw)
            except ValueError:
                return {}
        # JSON such as "null" or "[...]" decodes to a non-dict; treat it as empty
        # so the .get() filters below cannot crash.
        return dict(raw) if isinstance(raw, dict) else {}

    def _decode_json(raw: Any) -> Any:
        """Decode *raw* if it is a JSON string; otherwise (or on failure) return it unchanged."""
        if isinstance(raw, str):
            try:
                return json.loads(raw)
            except ValueError:
                return raw
        return raw

    def _extract_text(content: Any) -> str:
        """Flatten a stored message payload (str, dict, or multimodal list) into text."""
        if isinstance(content, str):
            return content
        if isinstance(content, dict):
            # User prompts may nest their text under payload.content.
            payload = content.get("payload")
            if isinstance(payload, dict) and "content" in payload:
                return _extract_text(payload["content"])
            for key in ("text", "content", "content_text"):
                val = content.get(key)
                if isinstance(val, str):
                    return val
            return json.dumps(content)
        if isinstance(content, list):
            # Multimodal content - concatenate text parts
            parts = [
                item.get("text", "")
                for item in content
                if isinstance(item, dict) and item.get("type") == "text"
            ]
            return " ".join(parts) if parts else str(content)
        return str(content)

    def _make_record(
        user_text: str, assistant_text: str, base_metadata: dict[str, Any]
    ) -> dict[str, Any]:
        """Build one SFT example; metadata is copied so records never share a dict."""
        return {
            "messages": [
                {"role": "user", "content": user_text},
                {"role": "assistant", "content": assistant_text},
            ],
            "metadata": dict(base_metadata),
        }

    def _pair_messages(
        message_rows: list[dict[str, Any]], base_metadata: dict[str, Any]
    ) -> list[dict[str, Any]]:
        """Pair each user prompt with the assistant reply that answers it, if recorded."""
        records: list[dict[str, Any]] = []
        for i, msg_row in enumerate(message_rows):
            if msg_row.get("message_type") not in user_types:
                continue

            assistant_msg = None
            for later in message_rows[i + 1 :]:
                later_type = later.get("message_type")
                if later_type == "assistant":
                    assistant_msg = later
                    break
                if later_type == "policy_system_prompt" or later_type in user_types:
                    # A new turn began before any reply, so this prompt is
                    # unanswered; don't borrow the next turn's assistant message.
                    break

            user_text = _extract_text(_decode_json(msg_row.get("content")))
            if not user_text:
                continue

            assistant_text = ""
            if assistant_msg is not None:
                assistant_text = _extract_text(_decode_json(assistant_msg.get("content")))

            records.append(
                _make_record(
                    user_text,
                    assistant_text if assistant_text else "[no response recorded]",
                    base_metadata,
                )
            )
        return records

    async def _collect_records(tracer: Any) -> list[dict[str, Any]]:
        """Return SFT records for every stored session that passes the configured filters."""
        df = await tracer.db.query_traces(
            "SELECT session_id, created_at, metadata FROM session_traces ORDER BY created_at"
        )
        if getattr(df, "empty", True):
            raise click.ClickException("No traces found in database")

        accepted: list[dict[str, Any]] = []
        for row in df.to_dict("records"):
            if limit is not None and limit > 0 and len(accepted) >= limit:
                # Output is truncated to `limit` anyway; skip the remaining
                # per-session queries.
                break

            metadata = _parse_metadata(row.get("metadata"))
            created_at_raw = row.get("created_at")
            created_at_dt = _parse_datetime_for_trace(created_at_raw)
            session_id = row.get("session_id")

            if splits and metadata.get("task_split") not in splits:
                continue
            if task_ids and metadata.get("task_id") not in task_ids:
                continue
            if models and metadata.get("model") not in models:
                continue

            if min_created and (created_at_dt is None or created_at_dt < min_created):
                continue
            if max_created and (created_at_dt is None or created_at_dt > max_created):
                continue

            # Check against outcome_rewards only when a score filter is set
            total_reward = None
            achievements_count = None
            if min_official is not None or max_official is not None:
                reward_rows = await tracer.db.query_traces(
                    "SELECT total_reward, achievements_count FROM outcome_rewards WHERE session_id = :session_id",
                    {"session_id": session_id},
                )
                reward_records = (
                    reward_rows.to_dict("records") if hasattr(reward_rows, "to_dict") else []
                )
                if reward_records:
                    total_reward = reward_records[0].get("total_reward")
                    achievements_count = reward_records[0].get("achievements_count")
                    if not _score_ok(total_reward, min_official, max_official):
                        continue
                elif min_official is not None:
                    # No reward found, but score filter requires it
                    continue

            judge_scores = metadata.get("judge_scores")
            if not isinstance(judge_scores, dict):
                judge_scores = {}
            if not all(
                _score_ok(judge_scores.get(name), threshold, None)
                for name, threshold in (min_judge_scores or {}).items()
            ):
                continue
            if not all(
                _score_ok(judge_scores.get(name), None, threshold)
                for name, threshold in (max_judge_scores or {}).items()
            ):
                continue

            base_metadata = {
                "session_id": session_id,
                "env_name": metadata.get("env_name"),
                "policy_name": metadata.get("policy_name"),
                "seed": metadata.get("seed"),
                "total_reward": total_reward,
                "achievements_count": achievements_count,
                "model": metadata.get("model"),
                "created_at": created_at_dt.isoformat() if created_at_dt else created_at_raw,
            }

            msg_df = await tracer.db.query_traces(messages_query, {"session_id": session_id})
            message_rows = msg_df.to_dict("records") if hasattr(msg_df, "to_dict") else []

            if not message_rows:
                # Fallback: check if prompt/completion in metadata (old format)
                prompt = metadata.get("prompt") or ""
                completion = metadata.get("completion") or ""
                if prompt and completion:
                    accepted.append(_make_record(str(prompt), str(completion), base_metadata))
                continue

            accepted.extend(_pair_messages(message_rows, base_metadata))

        return accepted

    async def _run_filter() -> None:
        tracer = SessionTracer(db_url=db_url, auto_save=False)
        await tracer.initialize()
        try:
            accepted = await _collect_records(tracer)
            if not accepted:
                raise click.ClickException("No sessions matched the provided filters")

            if limit is not None and limit > 0:
                accepted = accepted[:limit]

            output_path.parent.mkdir(parents=True, exist_ok=True)
            with output_path.open("w", encoding="utf-8") as handle:
                for item in accepted:
                    handle.write(json.dumps(item, ensure_ascii=False))
                    handle.write("\n")

            click.echo(f"Wrote {len(accepted)} examples -> {output_path}")
        finally:
            # Close the DB on every exit path, including the ClickException
            # raised for "no traces" / "no matches".
            if tracer.db is not None:
                await tracer.db.close()

    asyncio.run(_run_filter())
|
|
4520
|
+
|
|
2789
4521
|
|
|
2790
4522
|
def register_eval(cli: click.Group) -> None:
    """Attach the ``eval`` command to the given Click group *cli*."""
    command = eval_command
    cli.add_command(command)
|