synth-ai 0.2.8.dev4__py3-none-any.whl → 0.2.23.dev3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- examples/README.md +1 -0
- examples/__init__.py +16 -0
- examples/analyze_semantic_words.sh +17 -0
- examples/baseline/banking77_baseline.py +243 -0
- examples/baseline/banking77_pipeline_baseline.py +294 -0
- examples/baseline/crafter_baseline.py +407 -0
- examples/baseline/pokemon_red_baseline.py +326 -0
- examples/baseline/simple_baseline.py +56 -0
- examples/baseline/warming_up_to_rl_baseline.py +239 -0
- examples/blog_posts/gepa/README.md +355 -0
- examples/blog_posts/gepa/configs/banking77_gepa_local.toml +95 -0
- examples/blog_posts/gepa/configs/banking77_gepa_test.toml +80 -0
- examples/blog_posts/gepa/configs/banking77_mipro_local.toml +50 -0
- examples/blog_posts/gepa/configs/banking77_pipeline_gepa_local.toml +101 -0
- examples/blog_posts/gepa/configs/banking77_pipeline_gepa_test.toml +96 -0
- examples/blog_posts/gepa/configs/hotpotqa_gepa_local.toml +57 -0
- examples/blog_posts/gepa/configs/hotpotqa_gepa_qwen.toml +35 -0
- examples/blog_posts/gepa/configs/hotpotqa_mipro_local.toml +51 -0
- examples/blog_posts/gepa/configs/hover_gepa_local.toml +57 -0
- examples/blog_posts/gepa/configs/hover_gepa_qwen.toml +35 -0
- examples/blog_posts/gepa/configs/hover_mipro_local.toml +51 -0
- examples/blog_posts/gepa/configs/ifbench_gepa_local.toml +57 -0
- examples/blog_posts/gepa/configs/ifbench_gepa_qwen.toml +35 -0
- examples/blog_posts/gepa/configs/ifbench_mipro_local.toml +51 -0
- examples/blog_posts/gepa/configs/pupa_gepa_local.toml +58 -0
- examples/blog_posts/gepa/configs/pupa_mipro_local.toml +52 -0
- examples/blog_posts/gepa/deploy_banking77_task_app.sh +54 -0
- examples/blog_posts/gepa/gepa_baseline.py +204 -0
- examples/blog_posts/gepa/query_prompts_example.py +97 -0
- examples/blog_posts/gepa/run_gepa_banking77.sh +112 -0
- examples/blog_posts/gepa/run_gepa_banking77_pipeline.sh +163 -0
- examples/blog_posts/gepa/task_apps.py +105 -0
- examples/blog_posts/gepa/test_gepa_local.sh +67 -0
- examples/blog_posts/gepa/verify_banking77_setup.sh +123 -0
- examples/blog_posts/mipro/README.md +415 -0
- examples/blog_posts/mipro/configs/banking77_mipro_local.toml +91 -0
- examples/blog_posts/mipro/configs/banking77_mipro_test.toml +87 -0
- examples/blog_posts/mipro/configs/banking77_pipeline_mipro_gemini_flash_lite_local.toml +98 -0
- examples/blog_posts/mipro/configs/banking77_pipeline_mipro_gpt41mini_local.toml +96 -0
- examples/blog_posts/mipro/configs/banking77_pipeline_mipro_local.toml +94 -0
- examples/blog_posts/mipro/configs/banking77_pipeline_mipro_test.toml +170 -0
- examples/blog_posts/mipro/deploy_banking77_pipeline_task_app.sh +59 -0
- examples/blog_posts/mipro/deploy_banking77_task_app.sh +41 -0
- examples/blog_posts/mipro/multi_step.md +79 -0
- examples/blog_posts/mipro/run_mipro_banking77.sh +191 -0
- examples/blog_posts/mipro/run_mipro_banking77_pipeline.sh +171 -0
- examples/blog_posts/mipro/run_mipro_banking77_pipeline_gemini_flash_lite.sh +177 -0
- examples/blog_posts/mipro/run_mipro_banking77_pipeline_gpt41mini.sh +173 -0
- examples/blog_posts/mipro/verify_banking77_setup.sh +117 -0
- examples/blog_posts/pokemon_vl/README.md +98 -0
- examples/blog_posts/pokemon_vl/configs/eval_gpt5nano.toml +26 -0
- examples/blog_posts/pokemon_vl/configs/eval_qwen3_vl.toml +27 -0
- examples/blog_posts/pokemon_vl/configs/eval_rl_final.toml +24 -0
- examples/blog_posts/pokemon_vl/configs/filter_high_reward.toml +10 -0
- examples/blog_posts/pokemon_vl/configs/train_rl_from_sft.toml +43 -0
- examples/blog_posts/pokemon_vl/configs/train_sft_qwen4b_vl.toml +40 -0
- examples/blog_posts/pokemon_vl/extract_images.py +239 -0
- examples/blog_posts/pokemon_vl/pokemon_vl_baseline.py +326 -0
- examples/blog_posts/pokemon_vl/run_eval_extract_images.py +209 -0
- examples/blog_posts/pokemon_vl/run_qwen_eval_extract_images.py +212 -0
- examples/blog_posts/pokemon_vl/text_box_analysis.md +106 -0
- examples/blog_posts/warming_up_to_rl/ARCHITECTURE.md +195 -0
- examples/blog_posts/warming_up_to_rl/FINAL_TEST_RESULTS.md +127 -0
- examples/blog_posts/warming_up_to_rl/INFERENCE_SUCCESS.md +132 -0
- examples/blog_posts/warming_up_to_rl/README.md +158 -0
- examples/blog_posts/warming_up_to_rl/SMOKE_TESTING.md +164 -0
- examples/blog_posts/warming_up_to_rl/SMOKE_TEST_COMPLETE.md +253 -0
- examples/blog_posts/warming_up_to_rl/configs/eval_baseline_qwen32b_10x20.toml +25 -0
- examples/blog_posts/warming_up_to_rl/configs/eval_ft_qwen4b.toml +25 -0
- examples/blog_posts/warming_up_to_rl/configs/eval_ft_qwen4b_10x20.toml +26 -0
- examples/blog_posts/warming_up_to_rl/configs/eval_groq_qwen32b.toml +25 -0
- examples/blog_posts/warming_up_to_rl/configs/eval_openai_gpt_oss_120b.toml +29 -0
- examples/blog_posts/warming_up_to_rl/configs/filter_high_reward_dataset.toml +10 -0
- examples/blog_posts/warming_up_to_rl/configs/smoke_test.toml +75 -0
- examples/blog_posts/warming_up_to_rl/configs/train_rl_from_sft.toml +91 -0
- examples/blog_posts/warming_up_to_rl/configs/train_sft_qwen4b.toml +40 -0
- examples/blog_posts/warming_up_to_rl/warming_up_to_rl_baseline.py +187 -0
- examples/crafter_debug_render.py +186 -0
- examples/dev/qwen3_32b_qlora_4xh100.toml +45 -0
- examples/gepa/banking77_pipeline_gepa.toml +96 -0
- examples/gepa/multi_stage_gepa_example.toml +84 -0
- examples/gepa/run_gepa_banking77_pipeline.sh +157 -0
- examples/multi_step/SFT_README.md +147 -0
- examples/multi_step/configs/README_verilog_rl.md +77 -0
- examples/multi_step/configs/VERILOG_REWARDS.md +103 -0
- examples/multi_step/configs/VERILOG_RL_CHECKLIST.md +196 -0
- examples/multi_step/configs/crafter_eval_synth_qwen4b.toml +35 -0
- examples/multi_step/configs/crafter_eval_text_only_groq_qwen32b.toml +36 -0
- examples/multi_step/configs/crafter_rl_outcome.toml +75 -0
- examples/multi_step/configs/crafter_rl_stepwise_hosted_judge.toml +145 -0
- examples/multi_step/configs/crafter_rl_stepwise_shaped.toml +84 -0
- examples/multi_step/configs/crafter_rl_stepwise_simple.toml +79 -0
- examples/multi_step/configs/crafter_rl_stepwise_simple_NEW_FORMAT.toml +105 -0
- examples/multi_step/configs/crafter_sft_qwen30b_lora.toml +62 -0
- examples/multi_step/configs/crafter_synth_backend.md +40 -0
- examples/multi_step/configs/verilog_eval_groq_qwen32b.toml +31 -0
- examples/multi_step/configs/verilog_eval_synth_qwen8b.toml +33 -0
- examples/multi_step/configs/verilog_rl_lora.toml +147 -0
- examples/multi_step/convert_traces_to_sft.py +84 -0
- examples/multi_step/crafter_rl_lora.md +70 -0
- examples/multi_step/judges/crafter_backend_judge.py +220 -0
- examples/multi_step/judges/verilog_backend_judge.py +234 -0
- examples/multi_step/readme.md +48 -0
- examples/multi_step/run_sft_qwen30b.sh +45 -0
- examples/multi_step/sse_metrics_streaming_notes.md +357 -0
- examples/multi_step/task_app_config_notes.md +494 -0
- examples/multi_step/verilog_rl_lora.md +218 -0
- examples/qwen_coder/README.md +102 -0
- examples/qwen_coder/_shared.py +113 -0
- examples/qwen_coder/configs/coder_lora_30b.toml +60 -0
- examples/qwen_coder/configs/coder_lora_4b.toml +61 -0
- examples/qwen_coder/configs/coder_lora_small.toml +57 -0
- examples/qwen_coder/generate_dataset.py +98 -0
- examples/qwen_coder/infer_ft_smoke.py +65 -0
- examples/qwen_coder/infer_prod_proxy.py +73 -0
- examples/qwen_coder/infer_via_synth.py +87 -0
- examples/qwen_coder/scripts/infer_coder.sh +19 -0
- examples/qwen_coder/scripts/train_coder_30b.sh +22 -0
- examples/qwen_coder/sft_full_17b.py +103 -0
- examples/qwen_coder/sft_lora_30b.py +110 -0
- examples/qwen_coder/subset_jsonl.py +39 -0
- examples/qwen_coder/todos.md +38 -0
- examples/qwen_coder/validate_jsonl.py +60 -0
- examples/qwen_vl/BUGS_AND_FIXES.md +232 -0
- examples/qwen_vl/IMAGE_VALIDATION_COMPLETE.md +271 -0
- examples/qwen_vl/IMAGE_VALIDATION_SUMMARY.md +260 -0
- examples/qwen_vl/INFERENCE_SFT_TESTS.md +412 -0
- examples/qwen_vl/NEXT_STEPS_2B.md +325 -0
- examples/qwen_vl/QUICKSTART.md +327 -0
- examples/qwen_vl/QUICKSTART_RL_VISION.md +110 -0
- examples/qwen_vl/README.md +152 -0
- examples/qwen_vl/RL_VISION_COMPLETE.md +475 -0
- examples/qwen_vl/RL_VISION_TESTING.md +333 -0
- examples/qwen_vl/SDK_VISION_INTEGRATION.md +328 -0
- examples/qwen_vl/SETUP_COMPLETE.md +274 -0
- examples/qwen_vl/VISION_TESTS_COMPLETE.md +489 -0
- examples/qwen_vl/VLM_PIPELINE_COMPLETE.md +242 -0
- examples/qwen_vl/__init__.py +2 -0
- examples/qwen_vl/collect_data_via_cli.md +415 -0
- examples/qwen_vl/collect_vision_traces.py +368 -0
- examples/qwen_vl/configs/crafter_rl_vision_qwen3vl4b.toml +110 -0
- examples/qwen_vl/configs/crafter_vlm_sft_example.toml +59 -0
- examples/qwen_vl/configs/eval_gpt4o_mini_vision.toml +26 -0
- examples/qwen_vl/configs/eval_gpt4o_vision_proper.toml +29 -0
- examples/qwen_vl/configs/eval_gpt5nano_vision.toml +26 -0
- examples/qwen_vl/configs/eval_qwen3vl_vision.toml +26 -0
- examples/qwen_vl/configs/filter_qwen3vl_sft.toml +49 -0
- examples/qwen_vl/configs/filter_vision_sft.toml +52 -0
- examples/qwen_vl/configs/filter_vision_test.toml +8 -0
- examples/qwen_vl/configs/sft_qwen3_vl_2b_test.toml +54 -0
- examples/qwen_vl/crafter_gpt5nano_agent.py +308 -0
- examples/qwen_vl/crafter_qwen_vl_agent.py +300 -0
- examples/qwen_vl/run_vision_comparison.sh +61 -0
- examples/qwen_vl/run_vision_sft_pipeline.sh +175 -0
- examples/qwen_vl/test_image_validation.py +201 -0
- examples/qwen_vl/test_sft_vision_data.py +110 -0
- examples/rl/README.md +169 -0
- examples/rl/configs/eval_base_qwen.toml +17 -0
- examples/rl/configs/eval_rl_qwen.toml +13 -0
- examples/rl/configs/rl_from_base_qwen.toml +62 -0
- examples/rl/configs/rl_from_base_qwen17.toml +80 -0
- examples/rl/configs/rl_from_ft_qwen.toml +37 -0
- examples/rl/download_dataset.py +80 -0
- examples/rl/run_eval.py +436 -0
- examples/rl/run_rl_and_save.py +111 -0
- examples/rl/task_app/README.md +21 -0
- examples/rl/task_app/math_single_step.py +990 -0
- examples/rl/task_app/math_task_app.py +111 -0
- examples/run_crafter_demo.sh +10 -0
- examples/sdk_prompt_learning_example.py +55 -0
- examples/sft/README.md +139 -0
- examples/sft/configs/crafter_fft_qwen0p6b.toml +49 -0
- examples/sft/configs/crafter_lora_qwen0p6b.toml +49 -0
- examples/sft/evaluate.py +117 -0
- examples/sft/export_dataset.py +120 -0
- examples/sft/generate_traces.py +164 -0
- examples/swe/__init__.py +12 -0
- examples/swe/task_app/README.md +135 -0
- examples/swe/task_app/__init__.py +2 -0
- examples/swe/task_app/grpo_swe_mini.py +604 -0
- examples/swe/task_app/grpo_swe_mini_task_app.py +124 -0
- examples/swe/task_app/hosted/README.md +173 -0
- examples/swe/task_app/hosted/__init__.py +5 -0
- examples/swe/task_app/hosted/branching.py +143 -0
- examples/swe/task_app/hosted/environment_routes.py +1289 -0
- examples/swe/task_app/hosted/envs/__init__.py +1 -0
- examples/swe/task_app/hosted/envs/crafter/__init__.py +6 -0
- examples/swe/task_app/hosted/envs/crafter/app.py +1 -0
- examples/swe/task_app/hosted/envs/crafter/environment.py +522 -0
- examples/swe/task_app/hosted/envs/crafter/policy.py +478 -0
- examples/swe/task_app/hosted/envs/crafter/react_agent.py +108 -0
- examples/swe/task_app/hosted/envs/crafter/shared.py +305 -0
- examples/swe/task_app/hosted/envs/crafter/tools.py +47 -0
- examples/swe/task_app/hosted/envs/mini_swe/__init__.py +8 -0
- examples/swe/task_app/hosted/envs/mini_swe/environment.py +1191 -0
- examples/swe/task_app/hosted/envs/mini_swe/policy.py +355 -0
- examples/swe/task_app/hosted/envs/mini_swe/shared.py +83 -0
- examples/swe/task_app/hosted/envs/mini_swe/tools.py +96 -0
- examples/swe/task_app/hosted/hosted_app.py +204 -0
- examples/swe/task_app/hosted/inference/__init__.py +5 -0
- examples/swe/task_app/hosted/inference/openai_client.py +584 -0
- examples/swe/task_app/hosted/main.py +100 -0
- examples/swe/task_app/hosted/policy_routes.py +1094 -0
- examples/swe/task_app/hosted/registry.py +195 -0
- examples/swe/task_app/hosted/rollout.py +1905 -0
- examples/swe/task_app/hosted/storage/__init__.py +5 -0
- examples/swe/task_app/hosted/storage/volume.py +211 -0
- examples/swe/task_app/hosted/test_agents.py +161 -0
- examples/swe/task_app/hosted/test_service.py +136 -0
- examples/swe/task_app/hosted/utils.py +62 -0
- examples/swe/task_app/morph_backend.py +178 -0
- examples/task_apps/IMAGE_ONLY_EVAL_QUICKSTART.md +258 -0
- examples/task_apps/TESTING.md +275 -0
- examples/task_apps/banking77/__init__.py +6 -0
- examples/task_apps/banking77/banking77_task_app.py +912 -0
- examples/task_apps/banking77/deploy_wrapper.py +46 -0
- examples/task_apps/banking77_pipeline/__init__.py +6 -0
- examples/task_apps/banking77_pipeline/banking77_pipeline_task_app.py +489 -0
- examples/task_apps/banking77_pipeline/deploy_wrapper.py +50 -0
- examples/task_apps/crafter/CREATE_SFT_DATASET.md +286 -0
- examples/task_apps/crafter/EVAL_IMAGE_ONLY_RESULTS.md +152 -0
- examples/task_apps/crafter/FILTER_COMMAND_STATUS.md +187 -0
- examples/task_apps/crafter/FILTER_COMMAND_SUCCESS.md +281 -0
- examples/task_apps/crafter/QUERY_EXAMPLES.md +203 -0
- examples/task_apps/crafter/README_IMAGE_ONLY_EVAL.md +316 -0
- examples/task_apps/crafter/eval_image_only_gpt4o.toml +28 -0
- examples/task_apps/crafter/eval_text_only_groq_llama.toml +36 -0
- examples/task_apps/crafter/filter_sft_dataset.toml +16 -0
- examples/task_apps/crafter/task_app/README.md +42 -0
- examples/task_apps/crafter/task_app/__init__.py +5 -0
- examples/task_apps/crafter/task_app/grpo_crafter.py +1055 -0
- examples/task_apps/crafter/task_app/grpo_crafter_task_app.py +146 -0
- examples/task_apps/crafter/task_app/synth_envs_hosted/README.md +173 -0
- examples/task_apps/crafter/task_app/synth_envs_hosted/__init__.py +5 -0
- examples/task_apps/crafter/task_app/synth_envs_hosted/branching.py +143 -0
- examples/task_apps/crafter/task_app/synth_envs_hosted/environment_routes.py +1226 -0
- examples/task_apps/crafter/task_app/synth_envs_hosted/envs/__init__.py +1 -0
- examples/task_apps/crafter/task_app/synth_envs_hosted/envs/crafter/__init__.py +6 -0
- examples/task_apps/crafter/task_app/synth_envs_hosted/envs/crafter/app.py +1 -0
- examples/task_apps/crafter/task_app/synth_envs_hosted/envs/crafter/environment.py +532 -0
- examples/task_apps/crafter/task_app/synth_envs_hosted/envs/crafter/policy.py +583 -0
- examples/task_apps/crafter/task_app/synth_envs_hosted/envs/crafter/react_agent.py +122 -0
- examples/task_apps/crafter/task_app/synth_envs_hosted/envs/crafter/shared.py +305 -0
- examples/task_apps/crafter/task_app/synth_envs_hosted/envs/crafter/tools.py +47 -0
- examples/task_apps/crafter/task_app/synth_envs_hosted/hosted_app.py +253 -0
- examples/task_apps/crafter/task_app/synth_envs_hosted/inference/__init__.py +5 -0
- examples/task_apps/crafter/task_app/synth_envs_hosted/inference/openai_client.py +999 -0
- examples/task_apps/crafter/task_app/synth_envs_hosted/main.py +100 -0
- examples/task_apps/crafter/task_app/synth_envs_hosted/policy_routes.py +1252 -0
- examples/task_apps/crafter/task_app/synth_envs_hosted/registry.py +195 -0
- examples/task_apps/crafter/task_app/synth_envs_hosted/rollout.py +2233 -0
- examples/task_apps/crafter/task_app/synth_envs_hosted/storage/__init__.py +5 -0
- examples/task_apps/crafter/task_app/synth_envs_hosted/storage/volume.py +211 -0
- examples/task_apps/crafter/task_app/synth_envs_hosted/test_agents.py +161 -0
- examples/task_apps/crafter/task_app/synth_envs_hosted/test_service.py +136 -0
- examples/task_apps/crafter/task_app/synth_envs_hosted/utils.py +411 -0
- examples/task_apps/dev/pokemon_emerald/__init__.py +2 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/README.md +811 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/__init__.py +120 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/action.py +160 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/memory.py +155 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/perception.py +69 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/planning.py +96 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/simple.py +1502 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/agent/system_prompt.py +4 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/grab_map.py +68 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/manual.py +216 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/__init__.py +35 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/emerald_utils.py +631 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/emulator.py +1544 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/enums.py +1428 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/memory_reader.py +4848 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/types.py +41 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pokemon_env/utils.py +298 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/pyproject.toml +95 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/run.py +204 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/server/app.py +2152 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/server/client.py +429 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/server/frame_server.py +155 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/README.md +78 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/run_tests.py +122 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_agent_direct.py +76 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_agent_prompts.py +413 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_battle_state_formatting.py +204 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_dialogue_detection.py +133 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_dialogue_detection_comprehensive.py +229 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_direct_agent_emulator.py +300 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_fps_adjustment_pytest.py +205 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_house_to_outside_direct.py +200 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_house_to_outside_transition.py +284 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_map_ground_truth_comparison.py +468 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_memory_map.py +575 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_server_map_validation.py +311 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests/test_torchic_state.py +259 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/anticheat.py +372 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/checkpoint.py +296 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/error_handler.py +275 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/get_local_ip.py +22 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/helpers.py +44 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/llm_logger.py +514 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/map_formatter.py +415 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/map_stitcher.py +1763 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/map_stitcher_singleton.py +33 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/map_trimmer.py +106 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/map_visualizer.py +334 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/ocr_dialogue.py +1020 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/recording.py +188 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/state_formatter.py +1481 -0
- examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils/vlm.py +862 -0
- examples/task_apps/dev/pokemon_emerald/modal_app.py +114 -0
- examples/task_apps/dev/pokemon_emerald/task_app/README.md +81 -0
- examples/task_apps/dev/pokemon_emerald/task_app/__init__.py +6 -0
- examples/task_apps/dev/pokemon_emerald/task_app/pokemon_emerald.py +685 -0
- examples/task_apps/enron/__init__.py +2 -0
- examples/task_apps/enron/eval_groq_qwen32.toml +16 -0
- examples/task_apps/enron/filter_sft.toml +5 -0
- examples/task_apps/enron/task_app/README.md +14 -0
- examples/task_apps/enron/task_app/__init__.py +1 -0
- examples/task_apps/enron/task_app/grpo_enron.py +906 -0
- examples/task_apps/enron/task_app/grpo_enron_task_app.py +146 -0
- examples/task_apps/enron/tests/__init__.py +4 -0
- examples/task_apps/enron/tests/conftest.py +115 -0
- examples/task_apps/enron/tests/integration/__init__.py +4 -0
- examples/task_apps/enron/tests/integration/test_enron_eval.py +179 -0
- examples/task_apps/enron/tests/integration/test_enron_rollout.py +135 -0
- examples/task_apps/enron/tests/unit/__init__.py +4 -0
- examples/task_apps/enron/tests/unit/test_enron_environment.py +126 -0
- examples/task_apps/gepa_benchmarks/__init__.py +7 -0
- examples/task_apps/gepa_benchmarks/common.py +260 -0
- examples/task_apps/gepa_benchmarks/hotpotqa_task_app.py +507 -0
- examples/task_apps/gepa_benchmarks/hover_task_app.py +436 -0
- examples/task_apps/gepa_benchmarks/ifbench_task_app.py +563 -0
- examples/task_apps/gepa_benchmarks/pupa_task_app.py +460 -0
- examples/task_apps/math/README.md +21 -0
- examples/task_apps/math/math_single_step.py +1000 -0
- examples/task_apps/math/math_task_app.py +115 -0
- examples/task_apps/pokemon_battle/__init__.py +2 -0
- examples/task_apps/pokemon_battle/modal_app.py +104 -0
- examples/task_apps/pokemon_battle/task_app/README.md +68 -0
- examples/task_apps/pokemon_battle/task_app/__init__.py +6 -0
- examples/task_apps/pokemon_battle/task_app/pokemon_showdown.py +932 -0
- examples/task_apps/pokemon_red/EVAL_IMAGE_ONLY_COMPLETE.md +283 -0
- examples/task_apps/pokemon_red/EVAL_IMAGE_ONLY_STATUS.md +155 -0
- examples/task_apps/pokemon_red/README.md +356 -0
- examples/task_apps/pokemon_red/README_IMAGE_ONLY_EVAL.md +428 -0
- examples/task_apps/pokemon_red/__init__.py +3 -0
- examples/task_apps/pokemon_red/eval_image_only_gpt4o.toml +30 -0
- examples/task_apps/pokemon_red/eval_pokemon_red_policy.py +224 -0
- examples/task_apps/pokemon_red/pallet_town_rl_config.toml +75 -0
- examples/task_apps/pokemon_red/task_app.py +1048 -0
- examples/task_apps/pokemon_red/test_pallet_town_rewards.py +193 -0
- examples/task_apps/sokoban/README.md +306 -0
- examples/task_apps/sokoban/__init__.py +3 -0
- examples/task_apps/sokoban/eval_groq_qwen32.toml +16 -0
- examples/task_apps/sokoban/eval_openai_gpt5.toml +16 -0
- examples/task_apps/sokoban/filter_sft.toml +5 -0
- examples/task_apps/sokoban/task_app.py +1058 -0
- examples/task_apps/sokoban/tests/__init__.py +4 -0
- examples/task_apps/sokoban/tests/conftest.py +113 -0
- examples/task_apps/sokoban/tests/integration/__init__.py +4 -0
- examples/task_apps/sokoban/tests/integration/test_sokoban_eval.py +57 -0
- examples/task_apps/sokoban/tests/integration/test_sokoban_rollout.py +198 -0
- examples/task_apps/sokoban/tests/unit/__init__.py +4 -0
- examples/task_apps/sokoban/tests/unit/test_sokoban_environment.py +114 -0
- examples/task_apps/verilog/__init__.py +1 -0
- examples/task_apps/verilog/eval_groq_qwen32b.toml +22 -0
- examples/task_apps/verilog/filter_sft.toml +5 -0
- examples/task_apps/verilog/task_app/README.md +12 -0
- examples/task_apps/verilog/task_app/__init__.py +1 -0
- examples/task_apps/verilog/task_app/grpo_verilog.py +1166 -0
- examples/task_apps/verilog/task_app/grpo_verilog_task_app.py +145 -0
- examples/task_apps/verilog/tests/__init__.py +4 -0
- examples/task_apps/verilog/tests/conftest.py +115 -0
- examples/task_apps/verilog/tests/integration/__init__.py +4 -0
- examples/task_apps/verilog/tests/integration/test_verilog_eval.py +181 -0
- examples/task_apps/verilog/tests/integration/test_verilog_rollout.py +55 -0
- examples/task_apps/verilog/tests/unit/__init__.py +4 -0
- examples/task_apps/verilog/tests/unit/test_verilog_scoring.py +118 -0
- examples/tunnel_gepa_banking77/README.md +106 -0
- examples/tunnel_gepa_banking77/banking77_gepa_tunnel.toml +95 -0
- examples/tunnel_gepa_banking77/keep_tunnel_running.py +60 -0
- examples/tunnel_gepa_banking77/run_gepa_with_tunnel.sh +226 -0
- examples/vlm/PROPOSAL.md +53 -0
- examples/vlm/README.md +68 -0
- examples/vlm/configs/crafter_vlm_gpt4o.toml +49 -0
- examples/vlm/crafter_image_only_agent.py +207 -0
- examples/vlm/crafter_openai_vlm_agent.py +275 -0
- examples/vlm/filter_image_rows.py +63 -0
- examples/vlm/run_crafter_vlm_benchmark.py +316 -0
- examples/warming_up_to_rl/_utils.py +92 -0
- examples/warming_up_to_rl/analyze_trace_db.py +422 -0
- examples/warming_up_to_rl/configs/crafter_fft.toml +53 -0
- examples/warming_up_to_rl/configs/crafter_fft_4b.toml +54 -0
- examples/warming_up_to_rl/configs/eval_fft_qwen4b.toml +22 -0
- examples/warming_up_to_rl/configs/eval_groq_qwen32b.toml +15 -0
- examples/warming_up_to_rl/configs/eval_modal_qwen4b.toml +24 -0
- examples/warming_up_to_rl/configs/eval_stepwise_complex.toml +35 -0
- examples/warming_up_to_rl/configs/eval_stepwise_consistent.toml +26 -0
- examples/warming_up_to_rl/configs/eval_stepwise_per_achievement.toml +36 -0
- examples/warming_up_to_rl/configs/eval_stepwise_simple.toml +32 -0
- examples/warming_up_to_rl/configs/rl_from_base_qwen4b.toml +85 -0
- examples/warming_up_to_rl/configs/rl_from_ft.toml +58 -0
- examples/warming_up_to_rl/export_trace_sft.py +837 -0
- examples/warming_up_to_rl/groq_test.py +97 -0
- examples/warming_up_to_rl/manage_secrets.py +131 -0
- examples/warming_up_to_rl/old/event_rewards.md +234 -0
- examples/warming_up_to_rl/old/notes.md +73 -0
- examples/warming_up_to_rl/readme.md +110 -0
- examples/warming_up_to_rl/run_eval.py +736 -0
- examples/warming_up_to_rl/run_fft_and_save.py +380 -0
- examples/warming_up_to_rl/run_local_rollout.py +239 -0
- examples/warming_up_to_rl/run_local_rollout_modal.py +248 -0
- examples/warming_up_to_rl/run_local_rollout_parallel.py +405 -0
- examples/warming_up_to_rl/run_local_rollout_traced.py +477 -0
- examples/warming_up_to_rl/run_rl_and_save.py +124 -0
- examples/warming_up_to_rl/run_rollout_remote.py +156 -0
- examples/warming_up_to_rl/task_app/README.md +42 -0
- examples/warming_up_to_rl/task_app/grpo_crafter.py +876 -0
- examples/warming_up_to_rl/task_app/grpo_crafter_task_app.py +135 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/README.md +173 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/__init__.py +5 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/branching.py +143 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/environment_routes.py +1226 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/__init__.py +1 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/__init__.py +6 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/app.py +1 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/environment.py +522 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/policy.py +454 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/react_agent.py +108 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/shared.py +305 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/tools.py +47 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/hosted_app.py +253 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/__init__.py +5 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/openai_client.py +729 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/main.py +100 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/policy_routes.py +1114 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/registry.py +195 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/rollout.py +1891 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/storage/__init__.py +5 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/storage/volume.py +211 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/test_agents.py +161 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/test_service.py +137 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/utils.py +129 -0
- examples/workflows/math_rl/configs/eval_base_qwen.toml +15 -0
- examples/workflows/math_rl/configs/eval_rl_qwen.toml +11 -0
- examples/workflows/math_rl/configs/rl_from_base_qwen.toml +62 -0
- examples/workflows/math_rl/configs/rl_from_base_qwen17.toml +80 -0
- examples/workflows/math_rl/configs/rl_from_ft_qwen.toml +35 -0
- examples/workflows/math_rl/download_dataset.py +80 -0
- examples/workflows/math_rl/run_eval.py +436 -0
- examples/workflows/math_rl/run_rl_and_save.py +111 -0
- synth_ai/__init__.py +47 -23
- synth_ai/_utils/__init__.py +47 -0
- synth_ai/_utils/base_url.py +10 -0
- synth_ai/_utils/http.py +10 -0
- synth_ai/_utils/prompts.py +10 -0
- synth_ai/_utils/task_app_state.py +12 -0
- synth_ai/_utils/user_config.py +10 -0
- synth_ai/api/models/supported.py +514 -0
- synth_ai/api/train/__init__.py +63 -0
- synth_ai/api/train/builders.py +473 -0
- synth_ai/api/train/cli.py +1185 -0
- synth_ai/api/train/config_finder.py +246 -0
- synth_ai/api/train/configs/__init__.py +65 -0
- synth_ai/api/train/configs/prompt_learning.py +496 -0
- synth_ai/api/train/configs/rl.py +188 -0
- synth_ai/api/train/configs/sft.py +99 -0
- synth_ai/api/train/configs/shared.py +81 -0
- synth_ai/api/train/env_resolver.py +352 -0
- synth_ai/api/train/pollers.py +91 -0
- synth_ai/api/train/prompt_learning.py +425 -0
- synth_ai/api/train/sft.py +390 -0
- synth_ai/api/train/supported_algos.py +147 -0
- synth_ai/api/train/task_app.py +195 -0
- synth_ai/api/train/utils.py +244 -0
- synth_ai/api/train/validators.py +1117 -0
- synth_ai/api/tunnel.py +49 -0
- synth_ai/auth/credentials.py +94 -0
- synth_ai/baseline/__init__.py +25 -0
- synth_ai/baseline/config.py +209 -0
- synth_ai/baseline/discovery.py +214 -0
- synth_ai/baseline/execution.py +146 -0
- synth_ai/cfgs.py +227 -0
- synth_ai/cli/__init__.py +90 -45
- synth_ai/cli/_modal_wrapper.py +31 -0
- synth_ai/cli/_storage.py +20 -0
- synth_ai/cli/_typer_patch.py +47 -0
- synth_ai/cli/_validate_task_app.py +29 -0
- synth_ai/cli/balance.py +16 -4
- synth_ai/cli/calc.py +36 -21
- synth_ai/cli/claude.py +70 -0
- synth_ai/cli/codex.py +267 -0
- synth_ai/cli/commands/__init__.py +18 -0
- synth_ai/cli/commands/baseline/__init__.py +12 -0
- synth_ai/cli/commands/baseline/core.py +637 -0
- synth_ai/cli/commands/baseline/list.py +93 -0
- synth_ai/cli/commands/demo/__init__.py +6 -0
- synth_ai/cli/commands/demo/core.py +163 -0
- synth_ai/cli/commands/eval/__init__.py +19 -0
- synth_ai/cli/commands/eval/core.py +1112 -0
- synth_ai/cli/commands/eval/errors.py +81 -0
- synth_ai/cli/commands/eval/validation.py +133 -0
- synth_ai/cli/commands/filter/__init__.py +12 -0
- synth_ai/cli/commands/filter/core.py +424 -0
- synth_ai/cli/commands/filter/errors.py +55 -0
- synth_ai/cli/commands/filter/validation.py +77 -0
- synth_ai/cli/commands/help/__init__.py +185 -0
- synth_ai/cli/commands/help/core.py +72 -0
- synth_ai/cli/commands/smoke/__init__.py +7 -0
- synth_ai/cli/commands/smoke/core.py +1437 -0
- synth_ai/cli/commands/status/__init__.py +66 -0
- synth_ai/cli/commands/status/client.py +192 -0
- synth_ai/cli/commands/status/config.py +92 -0
- synth_ai/cli/commands/status/errors.py +20 -0
- synth_ai/cli/commands/status/formatters.py +164 -0
- synth_ai/cli/commands/status/subcommands/__init__.py +9 -0
- synth_ai/cli/commands/status/subcommands/files.py +79 -0
- synth_ai/cli/commands/status/subcommands/jobs.py +334 -0
- synth_ai/cli/commands/status/subcommands/models.py +79 -0
- synth_ai/cli/commands/status/subcommands/pricing.py +22 -0
- synth_ai/cli/commands/status/subcommands/runs.py +81 -0
- synth_ai/cli/commands/status/subcommands/session.py +183 -0
- synth_ai/cli/commands/status/subcommands/summary.py +47 -0
- synth_ai/cli/commands/status/subcommands/usage.py +203 -0
- synth_ai/cli/commands/status/utils.py +114 -0
- synth_ai/cli/commands/train/__init__.py +53 -0
- synth_ai/cli/commands/train/core.py +21 -0
- synth_ai/cli/commands/train/errors.py +117 -0
- synth_ai/cli/commands/train/judge_schemas.py +200 -0
- synth_ai/cli/commands/train/judge_validation.py +305 -0
- synth_ai/cli/commands/train/validation.py +386 -0
- synth_ai/cli/demo.py +32 -140
- synth_ai/cli/deploy.py +233 -0
- synth_ai/cli/eval/__init__.py +36 -0
- synth_ai/cli/eval/core.py +5 -0
- synth_ai/cli/eval/errors.py +31 -0
- synth_ai/cli/eval/validation.py +5 -0
- synth_ai/cli/filter/__init__.py +28 -0
- synth_ai/cli/filter/core.py +5 -0
- synth_ai/cli/filter/errors.py +23 -0
- synth_ai/cli/filter/validation.py +5 -0
- synth_ai/cli/legacy_root_backup.py +28 -22
- synth_ai/cli/lib/__init__.py +10 -0
- synth_ai/cli/lib/task_app_discovery.py +7 -0
- synth_ai/cli/lib/task_app_env.py +518 -0
- synth_ai/cli/mcp.py +34 -0
- synth_ai/cli/modal_serve/__init__.py +12 -0
- synth_ai/cli/modal_serve/core.py +14 -0
- synth_ai/cli/modal_serve/errors.py +8 -0
- synth_ai/cli/modal_serve/validation.py +11 -0
- synth_ai/cli/opencode.py +256 -0
- synth_ai/cli/recent.py +13 -7
- synth_ai/cli/rl_demo.py +166 -114
- synth_ai/cli/root.py +143 -112
- synth_ai/cli/serve/__init__.py +12 -0
- synth_ai/cli/serve/core.py +14 -0
- synth_ai/cli/serve/errors.py +8 -0
- synth_ai/cli/serve/validation.py +11 -0
- synth_ai/cli/setup.py +49 -0
- synth_ai/cli/status.py +7 -125
- synth_ai/cli/task_app_deploy.py +7 -0
- synth_ai/cli/task_app_list.py +25 -0
- synth_ai/cli/task_app_modal_serve.py +11 -0
- synth_ai/cli/task_app_serve.py +11 -0
- synth_ai/cli/task_apps.py +3134 -0
- synth_ai/cli/traces.py +9 -5
- synth_ai/cli/train/__init__.py +12 -0
- synth_ai/cli/train/core.py +21 -0
- synth_ai/cli/train/errors.py +8 -0
- synth_ai/cli/train/validation.py +24 -0
- synth_ai/cli/train.py +5 -0
- synth_ai/cli/turso.py +73 -0
- synth_ai/cli/watch.py +13 -18
- synth_ai/demos/__init__.py +10 -0
- synth_ai/demos/core/__init__.py +28 -1
- synth_ai/demos/core/cli.py +745 -416
- synth_ai/demos/crafter/__init__.py +1 -0
- synth_ai/demos/crafter/crafter_fft_4b.toml +55 -0
- synth_ai/demos/crafter/grpo_crafter_task_app.py +185 -0
- synth_ai/demos/crafter/rl_from_base_qwen4b.toml +74 -0
- synth_ai/demos/demo_registry.py +176 -0
- synth_ai/demos/demo_task_apps/__init__.py +7 -1
- synth_ai/demos/demo_task_apps/core.py +75 -37
- synth_ai/demos/demo_task_apps/crafter/__init__.py +1 -0
- synth_ai/demos/demo_task_apps/crafter/configs/crafter_fft_4b.toml +53 -0
- synth_ai/demos/demo_task_apps/crafter/configs/rl_from_base_qwen4b.toml +73 -0
- synth_ai/demos/demo_task_apps/crafter/grpo_crafter_task_app.py +184 -0
- synth_ai/demos/demo_task_apps/math/_common.py +1 -2
- synth_ai/demos/demo_task_apps/math/app.py +2 -1
- synth_ai/demos/demo_task_apps/math/config.toml +55 -110
- synth_ai/demos/demo_task_apps/math/deploy_modal.py +3 -6
- synth_ai/demos/demo_task_apps/math/modal_task_app.py +491 -166
- synth_ai/demos/demo_task_apps/math/task_app_entry.py +37 -0
- synth_ai/demos/math/__init__.py +1 -0
- synth_ai/demos/math/_common.py +16 -0
- synth_ai/demos/math/app.py +38 -0
- synth_ai/demos/math/config.toml +76 -0
- synth_ai/demos/math/deploy_modal.py +54 -0
- synth_ai/demos/math/modal_task_app.py +703 -0
- synth_ai/demos/math/task_app_entry.py +51 -0
- synth_ai/environments/environment/core.py +7 -1
- synth_ai/environments/examples/bandit/engine.py +12 -5
- synth_ai/environments/examples/bandit/environment.py +0 -1
- synth_ai/environments/examples/bandit/taskset.py +4 -4
- synth_ai/environments/examples/crafter_classic/engine_deterministic_patch.py +7 -4
- synth_ai/environments/examples/crafter_classic/engine_serialization_patch_v3.py +9 -5
- synth_ai/environments/examples/crafter_classic/environment.py +93 -2
- synth_ai/environments/examples/crafter_classic/world_config_patch_simple.py +4 -3
- synth_ai/environments/examples/enron/engine.py +7 -2
- synth_ai/environments/examples/enron/environment.py +68 -0
- synth_ai/environments/examples/red/engine.py +60 -12
- synth_ai/environments/examples/red/engine_helpers/memory_map.py +7 -0
- synth_ai/environments/examples/red/engine_helpers/reward_components.py +151 -179
- synth_ai/environments/examples/red/engine_helpers/reward_library/pallet_town_progression.py +477 -0
- synth_ai/environments/examples/red/engine_helpers/state_extraction.py +32 -0
- synth_ai/environments/examples/red/environment.py +86 -0
- synth_ai/environments/examples/red/trace_hooks_v3.py +168 -0
- synth_ai/environments/examples/sokoban/taskset.py +116 -0
- synth_ai/environments/examples/verilog/engine.py +104 -12
- synth_ai/environments/examples/wordle/environment.py +0 -1
- synth_ai/environments/reproducibility/tree.py +5 -6
- synth_ai/environments/service/app.py +11 -12
- synth_ai/environments/service/core_routes.py +10 -9
- synth_ai/environments/stateful/engine.py +1 -1
- synth_ai/environments/tasks/core.py +1 -0
- synth_ai/environments/tasks/filters.py +5 -6
- synth_ai/environments/tasks/utils.py +4 -5
- synth_ai/evals/__init__.py +15 -0
- synth_ai/evals/base.py +14 -5
- synth_ai/evals/client.py +82 -0
- synth_ai/evals/types.py +42 -0
- synth_ai/http.py +8 -22
- synth_ai/http_client.py +45 -12
- synth_ai/inference/__init__.py +0 -2
- synth_ai/inference/client.py +21 -7
- synth_ai/jobs/client.py +129 -80
- synth_ai/judge_schemas.py +127 -0
- synth_ai/learning/__init__.py +51 -6
- synth_ai/learning/algorithms.py +14 -0
- synth_ai/learning/client.py +122 -30
- synth_ai/learning/config.py +2 -40
- synth_ai/learning/constants.py +0 -2
- synth_ai/learning/ft_client.py +4 -56
- synth_ai/learning/health.py +14 -8
- synth_ai/learning/jobs.py +43 -47
- synth_ai/learning/prompt_learning_client.py +276 -0
- synth_ai/learning/prompt_learning_types.py +185 -0
- synth_ai/{rl → learning/rl}/__init__.py +14 -5
- synth_ai/learning/rl/client.py +269 -0
- synth_ai/learning/rl/config.py +31 -0
- synth_ai/{rl → learning/rl}/contracts.py +5 -10
- synth_ai/{rl → learning/rl}/env_keys.py +45 -16
- synth_ai/learning/rl/secrets.py +13 -0
- synth_ai/learning/rl_client.py +2 -253
- synth_ai/learning/sft/__init__.py +29 -0
- synth_ai/learning/sft/client.py +68 -0
- synth_ai/learning/sft/config.py +270 -0
- synth_ai/learning/sft/data.py +698 -0
- synth_ai/learning/sse.py +25 -26
- synth_ai/learning/validators.py +29 -25
- synth_ai/mcp/__init__.py +5 -0
- synth_ai/mcp/__main__.py +8 -0
- synth_ai/mcp/main.py +254 -0
- synth_ai/mcp/setup.py +100 -0
- synth_ai/modal.py +257 -0
- synth_ai/pricing/__init__.py +3 -0
- synth_ai/pricing/model_pricing.py +64 -0
- synth_ai/session/__init__.py +75 -0
- synth_ai/session/client.py +383 -0
- synth_ai/session/constants.py +63 -0
- synth_ai/session/exceptions.py +105 -0
- synth_ai/session/manager.py +139 -0
- synth_ai/session/models.py +89 -0
- synth_ai/session/query.py +110 -0
- synth_ai/spec/__init__.py +46 -0
- synth_ai/spec/dataclasses.py +149 -0
- synth_ai/spec/loader.py +144 -0
- synth_ai/spec/serializer.py +199 -0
- synth_ai/spec/validation.py +250 -0
- synth_ai/streaming/__init__.py +29 -0
- synth_ai/streaming/config.py +94 -0
- synth_ai/streaming/handlers.py +589 -0
- synth_ai/streaming/streamer.py +320 -0
- synth_ai/streaming/types.py +95 -0
- synth_ai/task/__init__.py +116 -3
- synth_ai/task/apps/__init__.py +132 -0
- synth_ai/task/auth.py +165 -0
- synth_ai/task/client.py +167 -0
- synth_ai/task/config.py +261 -0
- synth_ai/task/contracts.py +173 -57
- synth_ai/task/datasets.py +108 -0
- synth_ai/task/errors.py +50 -0
- synth_ai/task/health.py +17 -11
- synth_ai/task/inference_api.py +101 -0
- synth_ai/task/json.py +111 -0
- synth_ai/task/proxy.py +251 -0
- synth_ai/task/rubrics/__init__.py +55 -0
- synth_ai/task/rubrics/loaders.py +156 -0
- synth_ai/task/rubrics/models.py +57 -0
- synth_ai/task/rubrics/scoring.py +116 -0
- synth_ai/task/rubrics/strict.py +149 -0
- synth_ai/task/rubrics.py +219 -0
- synth_ai/task/server.py +432 -0
- synth_ai/task/trace_correlation_helpers.py +328 -0
- synth_ai/task/tracing_utils.py +95 -0
- synth_ai/task/validators.py +449 -6
- synth_ai/task/vendors.py +59 -0
- synth_ai/tracing_v3/__init__.py +4 -0
- synth_ai/tracing_v3/abstractions.py +21 -4
- synth_ai/tracing_v3/config.py +167 -22
- synth_ai/tracing_v3/constants.py +21 -0
- synth_ai/tracing_v3/db_config.py +42 -29
- synth_ai/tracing_v3/decorators.py +80 -45
- synth_ai/tracing_v3/examples/basic_usage.py +15 -9
- synth_ai/tracing_v3/hooks.py +6 -4
- synth_ai/tracing_v3/llm_call_record_helpers.py +161 -61
- synth_ai/tracing_v3/migration_helper.py +1 -2
- synth_ai/tracing_v3/replica_sync.py +12 -7
- synth_ai/tracing_v3/serialization.py +130 -0
- synth_ai/tracing_v3/session_tracer.py +86 -21
- synth_ai/tracing_v3/storage/base.py +98 -12
- synth_ai/tracing_v3/storage/config.py +63 -16
- synth_ai/tracing_v3/storage/factory.py +11 -9
- synth_ai/tracing_v3/storage/utils.py +15 -11
- synth_ai/tracing_v3/trace_utils.py +317 -0
- synth_ai/tracing_v3/turso/__init__.py +8 -21
- synth_ai/tracing_v3/turso/daemon.py +123 -15
- synth_ai/tracing_v3/turso/models.py +5 -2
- synth_ai/tracing_v3/turso/native_manager.py +1293 -0
- synth_ai/tracing_v3/utils.py +5 -4
- synth_ai/tunnel.py +143 -0
- synth_ai/tunnel_deploy.py +278 -0
- synth_ai/types.py +8 -0
- synth_ai/urls.py +11 -0
- synth_ai/utils/__init__.py +166 -0
- synth_ai/utils/agents.py +74 -0
- synth_ai/utils/apps.py +152 -0
- synth_ai/utils/base_url.py +94 -0
- synth_ai/utils/bin.py +39 -0
- synth_ai/utils/claude.py +36 -0
- synth_ai/utils/cli.py +284 -0
- synth_ai/utils/config.py +81 -0
- synth_ai/utils/env.py +346 -0
- synth_ai/utils/errors.py +85 -0
- synth_ai/utils/http.py +172 -0
- synth_ai/utils/json.py +72 -0
- synth_ai/utils/log_filter.py +99 -0
- synth_ai/utils/logging.py +198 -0
- synth_ai/utils/modal.py +299 -0
- synth_ai/utils/paths.py +95 -0
- synth_ai/utils/process.py +233 -0
- synth_ai/utils/prompts.py +39 -0
- synth_ai/utils/sqld.py +122 -0
- synth_ai/utils/ssl.py +25 -0
- synth_ai/utils/task_app_discovery.py +882 -0
- synth_ai/utils/task_app_env.py +186 -0
- synth_ai/utils/task_app_state.py +318 -0
- synth_ai/utils/tunnel/__init__.py +12 -0
- synth_ai/utils/tunnel/config.py +55 -0
- synth_ai/utils/user_config.py +137 -0
- synth_ai/uvicorn.py +77 -0
- synth_ai-0.2.23.dev3.dist-info/METADATA +357 -0
- synth_ai-0.2.23.dev3.dist-info/RECORD +983 -0
- {synth_ai-0.2.8.dev4.dist-info → synth_ai-0.2.23.dev3.dist-info}/entry_points.txt +0 -1
- {synth_ai-0.2.8.dev4.dist-info → synth_ai-0.2.23.dev3.dist-info}/top_level.txt +1 -0
- synth_ai/cli/man.py +0 -106
- synth_ai/core/experiment.py +0 -15
- synth_ai/core/system.py +0 -15
- synth_ai/environments/examples/sokoban/units/astar_common.py +0 -95
- synth_ai/experimental/synth_oss.py +0 -446
- synth_ai/handshake.py +0 -63
- synth_ai/install_sqld.sh +0 -40
- synth_ai/learning/offline/dpo.py +0 -0
- synth_ai/learning/offline/providers.py +0 -7
- synth_ai/learning/offline/sft.py +0 -0
- synth_ai/learning/offline/shared.py +0 -0
- synth_ai/learning/online/grpo.py +0 -0
- synth_ai/learning/online/irft.py +0 -0
- synth_ai/learning/prompts/banking77_injection_eval.py +0 -168
- synth_ai/learning/prompts/gepa.py +0 -0
- synth_ai/learning/prompts/hello_world_in_context_injection_ex.py +0 -213
- synth_ai/learning/prompts/mipro.py +0 -289
- synth_ai/learning/prompts/random_search.py +0 -246
- synth_ai/learning/prompts/run_mipro_banking77.py +0 -172
- synth_ai/learning/prompts/run_random_search_banking77.py +0 -324
- synth_ai/lm/__init__.py +0 -51
- synth_ai/lm/caching/constants.py +0 -6
- synth_ai/lm/caching/dbs.py +0 -0
- synth_ai/lm/caching/ephemeral.py +0 -102
- synth_ai/lm/caching/handler.py +0 -137
- synth_ai/lm/caching/initialize.py +0 -11
- synth_ai/lm/caching/persistent.py +0 -114
- synth_ai/lm/config.py +0 -110
- synth_ai/lm/constants.py +0 -32
- synth_ai/lm/core/__init__.py +0 -8
- synth_ai/lm/core/all.py +0 -73
- synth_ai/lm/core/exceptions.py +0 -7
- synth_ai/lm/core/main.py +0 -319
- synth_ai/lm/core/main_v3.py +0 -594
- synth_ai/lm/core/synth_models.py +0 -48
- synth_ai/lm/core/vendor_clients.py +0 -188
- synth_ai/lm/cost/monitor.py +0 -1
- synth_ai/lm/cost/statefulness.py +0 -1
- synth_ai/lm/injection.py +0 -80
- synth_ai/lm/overrides.py +0 -206
- synth_ai/lm/provider_support/__init__.py +0 -8
- synth_ai/lm/provider_support/anthropic.py +0 -972
- synth_ai/lm/provider_support/openai.py +0 -1139
- synth_ai/lm/provider_support/suppress_logging.py +0 -31
- synth_ai/lm/structured_outputs/handler.py +0 -440
- synth_ai/lm/structured_outputs/inject.py +0 -297
- synth_ai/lm/structured_outputs/rehabilitate.py +0 -185
- synth_ai/lm/tools/__init__.py +0 -3
- synth_ai/lm/tools/base.py +0 -172
- synth_ai/lm/unified_interface.py +0 -202
- synth_ai/lm/vendors/base.py +0 -81
- synth_ai/lm/vendors/core/anthropic_api.py +0 -387
- synth_ai/lm/vendors/core/gemini_api.py +0 -292
- synth_ai/lm/vendors/core/mistral_api.py +0 -322
- synth_ai/lm/vendors/core/openai_api.py +0 -225
- synth_ai/lm/vendors/core/synth_dev_api.py +0 -0
- synth_ai/lm/vendors/local/ollama.py +0 -0
- synth_ai/lm/vendors/openai_standard.py +0 -780
- synth_ai/lm/vendors/openai_standard_responses.py +0 -256
- synth_ai/lm/vendors/retries.py +0 -22
- synth_ai/lm/vendors/supported/custom_endpoint.py +0 -417
- synth_ai/lm/vendors/supported/deepseek.py +0 -69
- synth_ai/lm/vendors/supported/grok.py +0 -75
- synth_ai/lm/vendors/supported/groq.py +0 -16
- synth_ai/lm/vendors/supported/ollama.py +0 -15
- synth_ai/lm/vendors/supported/openrouter.py +0 -74
- synth_ai/lm/vendors/supported/together.py +0 -11
- synth_ai/lm/vendors/synth_client.py +0 -808
- synth_ai/lm/warmup.py +0 -186
- synth_ai/rl/secrets.py +0 -19
- synth_ai/scripts/verify_rewards.py +0 -100
- synth_ai/tracing/__init__.py +0 -30
- synth_ai/tracing_v1/__init__.py +0 -33
- synth_ai/tracing_v3/turso/manager.py +0 -760
- synth_ai/v0/tracing/abstractions.py +0 -224
- synth_ai/v0/tracing/base_client.py +0 -91
- synth_ai/v0/tracing/client_manager.py +0 -131
- synth_ai/v0/tracing/config.py +0 -142
- synth_ai/v0/tracing/context.py +0 -146
- synth_ai/v0/tracing/decorators.py +0 -682
- synth_ai/v0/tracing/events/__init__.py +0 -0
- synth_ai/v0/tracing/events/manage.py +0 -147
- synth_ai/v0/tracing/events/scope.py +0 -86
- synth_ai/v0/tracing/events/store.py +0 -228
- synth_ai/v0/tracing/immediate_client.py +0 -151
- synth_ai/v0/tracing/local.py +0 -18
- synth_ai/v0/tracing/log_client_base.py +0 -73
- synth_ai/v0/tracing/retry_queue.py +0 -186
- synth_ai/v0/tracing/trackers.py +0 -515
- synth_ai/v0/tracing/upload.py +0 -512
- synth_ai/v0/tracing/utils.py +0 -9
- synth_ai/v0/tracing_v1/__init__.py +0 -16
- synth_ai/v0/tracing_v1/abstractions.py +0 -224
- synth_ai/v0/tracing_v1/base_client.py +0 -91
- synth_ai/v0/tracing_v1/client_manager.py +0 -131
- synth_ai/v0/tracing_v1/config.py +0 -142
- synth_ai/v0/tracing_v1/context.py +0 -146
- synth_ai/v0/tracing_v1/decorators.py +0 -703
- synth_ai/v0/tracing_v1/events/__init__.py +0 -0
- synth_ai/v0/tracing_v1/events/manage.py +0 -147
- synth_ai/v0/tracing_v1/events/scope.py +0 -86
- synth_ai/v0/tracing_v1/events/store.py +0 -228
- synth_ai/v0/tracing_v1/immediate_client.py +0 -151
- synth_ai/v0/tracing_v1/local.py +0 -18
- synth_ai/v0/tracing_v1/log_client_base.py +0 -73
- synth_ai/v0/tracing_v1/retry_queue.py +0 -186
- synth_ai/v0/tracing_v1/trackers.py +0 -515
- synth_ai/v0/tracing_v1/upload.py +0 -527
- synth_ai/v0/tracing_v1/utils.py +0 -9
- synth_ai/zyk/__init__.py +0 -30
- synth_ai-0.2.8.dev4.dist-info/METADATA +0 -129
- synth_ai-0.2.8.dev4.dist-info/RECORD +0 -420
- {synth_ai/lm/caching → examples/task_apps}/__init__.py +0 -0
- {synth_ai/lm/cost → examples/task_apps/crafter}/__init__.py +0 -0
- {synth_ai/lm/structured_outputs → examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/server}/__init__.py +0 -0
- {synth_ai/lm/vendors → examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/tests}/__init__.py +0 -0
- {synth_ai/lm/vendors/core → examples/task_apps/dev/pokemon_emerald/external/pokeagent-speedrun/utils}/__init__.py +0 -0
- {synth_ai/lm/vendors/local → examples/task_apps/math}/__init__.py +0 -0
- {synth_ai/lm/vendors/supported → examples/workflows}/__init__.py +0 -0
- {synth_ai/v0/tracing → examples/workflows/math_rl}/__init__.py +0 -0
- /synth_ai/{compound/cais.py → cli/__main__.py} +0 -0
- /synth_ai/{learning/filtering.py → py.typed} +0 -0
- {synth_ai-0.2.8.dev4.dist-info → synth_ai-0.2.23.dev3.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.8.dev4.dist-info → synth_ai-0.2.23.dev3.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,1117 @@
|
|
|
1
|
+
"""SDK-side validation for training configs - catch errors BEFORE sending to backend."""
|
|
2
|
+
|
|
3
|
+
import re
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
import click
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class ConfigValidationError(Exception):
    """Error type signalling that a training config did not pass validation.

    Carries no state of its own beyond what ``Exception`` provides; the
    message passed to the constructor describes the problem.
    """
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
# Model allow-lists for prompt learning (GEPA & MIPRO), one per provider.
#
# OpenAI: "gpt-5-pro" is deliberately absent from this set - it is considered
# too expensive for prompt learning.
OPENAI_SUPPORTED_MODELS = {
    "gpt-4o",
    "gpt-4o-mini",
    "gpt-4.1",
    "gpt-4.1-mini",
    "gpt-4.1-nano",
    "gpt-5",
    "gpt-5-mini",
    "gpt-5-nano",
}

# Groq: a model is accepted when it matches one of the prefix patterns below or
# is listed verbatim in GROQ_EXACT_MATCHES. Names may be bare ("model-name") or
# provider-qualified ("provider/model-name", e.g. "openai/gpt-oss-20b").
GROQ_SUPPORTED_PATTERNS = [
    re.compile(regex)
    for regex in (
        r"^(openai/)?gpt-oss-\d+b",  # gpt-oss-20b, openai/gpt-oss-120b
        r"^(llama-3\.3-70b|groq/llama-3\.3-70b)",  # llama-3.3-70b-versatile
        r"^(qwen.*32b|groq/qwen.*32b)",  # qwen-32b, qwen3-32b, groq/qwen3-32b
    )
]

GROQ_EXACT_MATCHES = {
    "llama-3.3-70b",
    "qwen-32b",
    "qwen3-32b",
}

# Google / Gemini.
GOOGLE_SUPPORTED_MODELS = {
    "gemini-2.5-pro",
    "gemini-2.5-pro-gt200k",
    "gemini-2.5-flash",
    "gemini-2.5-flash-lite",
}
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def _is_supported_openai_model(model: str) -> bool:
    """Return True when *model* is listed in OPENAI_SUPPORTED_MODELS.

    The comparison is case-insensitive and ignores surrounding whitespace.
    Everything up to and including the first "/" is treated as a provider
    prefix and discarded (e.g. "openai/gpt-4o" is compared as "gpt-4o").
    """
    name = model.lower().strip()
    _prefix, slash, remainder = name.partition("/")
    if slash:
        name = remainder
    allowed = {entry.lower() for entry in OPENAI_SUPPORTED_MODELS}
    return name in allowed
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def _is_supported_groq_model(model: str) -> bool:
    """Return True when *model* is an accepted Groq model name.

    Two checks are made on the lower-cased, whitespace-stripped name:

    1. with any provider prefix (text up to the first "/") removed, it is
       looked up in GROQ_EXACT_MATCHES;
    2. otherwise the name *including* its prefix is matched against each
       regex in GROQ_SUPPORTED_PATTERNS (those patterns encode the prefixes
       they accept themselves).
    """
    normalized = model.lower().strip()

    # Exact-name lookup works on the bare name, e.g. "openai/gpt-oss-20b" -> "gpt-oss-20b".
    bare_name = normalized.partition("/")[2] if "/" in normalized else normalized
    if bare_name in {entry.lower() for entry in GROQ_EXACT_MATCHES}:
        return True

    # Pattern lookup deliberately uses the un-stripped name.
    for pattern in GROQ_SUPPORTED_PATTERNS:
        if pattern.match(normalized):
            return True
    return False
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
def _is_supported_google_model(model: str) -> bool:
    """Return True when *model* is listed in GOOGLE_SUPPORTED_MODELS.

    Matching is case-insensitive, ignores surrounding whitespace, and drops a
    leading provider prefix (e.g. "google/gemini-2.5-flash-lite" is compared
    as "gemini-2.5-flash-lite").
    """
    # split(..., 1)[-1] yields the text after the first "/" or, when there is
    # no "/", the whole name.
    bare_name = model.lower().strip().split("/", 1)[-1]
    known = {entry.lower() for entry in GOOGLE_SUPPORTED_MODELS}
    return bare_name in known
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
def _validate_model_for_provider(model: str, provider: str, field_name: str, *, allow_nano: bool = False) -> list[str]:
|
|
87
|
+
"""
|
|
88
|
+
Validate that a model is supported for the given provider.
|
|
89
|
+
|
|
90
|
+
Models can be specified with or without provider prefix (e.g., "gpt-4o" or "openai/gpt-4o").
|
|
91
|
+
The provider prefix is stripped before validation.
|
|
92
|
+
|
|
93
|
+
REJECTS gpt-5-pro explicitly (too expensive).
|
|
94
|
+
REJECTS nano models for proposal/mutation models (unless allow_nano=True).
|
|
95
|
+
|
|
96
|
+
Args:
|
|
97
|
+
model: Model name to validate
|
|
98
|
+
provider: Provider name (openai, groq, google)
|
|
99
|
+
field_name: Field name for error messages (e.g., "prompt_learning.policy.model")
|
|
100
|
+
allow_nano: If True, allow nano models (for policy models). If False, reject nano models.
|
|
101
|
+
|
|
102
|
+
Returns:
|
|
103
|
+
List of error messages (empty if valid)
|
|
104
|
+
"""
|
|
105
|
+
errors: list[str] = []
|
|
106
|
+
|
|
107
|
+
if not model or not isinstance(model, str) or not model.strip():
|
|
108
|
+
errors.append(f"Missing or empty {field_name}")
|
|
109
|
+
return errors
|
|
110
|
+
|
|
111
|
+
provider_lower = provider.lower().strip()
|
|
112
|
+
model_lower = model.lower().strip()
|
|
113
|
+
|
|
114
|
+
# Strip provider prefix if present (e.g., "openai/gpt-4o" -> "gpt-4o")
|
|
115
|
+
model_without_prefix = model_lower.split("/", 1)[1] if "/" in model_lower else model_lower
|
|
116
|
+
|
|
117
|
+
# Explicitly reject gpt-5-pro (too expensive)
|
|
118
|
+
if model_without_prefix == "gpt-5-pro":
|
|
119
|
+
errors.append(
|
|
120
|
+
f"Model '{model}' is not supported for prompt learning (too expensive).\n"
|
|
121
|
+
f" gpt-5-pro is excluded due to high cost ($15/$120 per 1M tokens).\n"
|
|
122
|
+
f" Please use a supported model instead."
|
|
123
|
+
)
|
|
124
|
+
return errors
|
|
125
|
+
|
|
126
|
+
# Reject nano models for proposal/mutation models (unless explicitly allowed)
|
|
127
|
+
if not allow_nano and model_without_prefix.endswith("-nano"):
|
|
128
|
+
errors.append(
|
|
129
|
+
f"Model '{model}' is not supported for {field_name}.\n"
|
|
130
|
+
f" ā Nano models (e.g., gpt-4.1-nano, gpt-5-nano) are NOT allowed for proposal/mutation models.\n"
|
|
131
|
+
f" \n"
|
|
132
|
+
f" Why?\n"
|
|
133
|
+
f" Proposal and mutation models need to be SMART and capable of generating high-quality,\n"
|
|
134
|
+
f" creative prompt variations. Nano models are too small and lack the reasoning capability\n"
|
|
135
|
+
f" needed for effective prompt optimization.\n"
|
|
136
|
+
f" \n"
|
|
137
|
+
f" ā
Use a larger model instead:\n"
|
|
138
|
+
f" - For OpenAI: gpt-4.1-mini, gpt-4o-mini, gpt-4o, or gpt-4.1\n"
|
|
139
|
+
f" - For Groq: openai/gpt-oss-120b, llama-3.3-70b-versatile\n"
|
|
140
|
+
f" - For Google: gemini-2.5-flash, gemini-2.5-pro\n"
|
|
141
|
+
f" \n"
|
|
142
|
+
f" Note: Nano models ARE allowed for policy models (task execution), but NOT for\n"
|
|
143
|
+
f" proposal/mutation models (prompt generation)."
|
|
144
|
+
)
|
|
145
|
+
return errors
|
|
146
|
+
|
|
147
|
+
if provider_lower == "openai":
|
|
148
|
+
if not _is_supported_openai_model(model_without_prefix):
|
|
149
|
+
errors.append(
|
|
150
|
+
f"Unsupported OpenAI model: '{model}'\n"
|
|
151
|
+
f" Supported OpenAI models for prompt learning:\n"
|
|
152
|
+
f" - gpt-4o\n"
|
|
153
|
+
f" - gpt-4o-mini\n"
|
|
154
|
+
f" - gpt-4.1, gpt-4.1-mini, gpt-4.1-nano\n"
|
|
155
|
+
f" - gpt-5, gpt-5-mini, gpt-5-nano\n"
|
|
156
|
+
f" Note: gpt-5-pro is excluded (too expensive)\n"
|
|
157
|
+
f" Got: '{model}'"
|
|
158
|
+
)
|
|
159
|
+
elif provider_lower == "groq":
|
|
160
|
+
# For Groq, check both with and without prefix since models can be "openai/gpt-oss-20b"
|
|
161
|
+
if not _is_supported_groq_model(model_lower):
|
|
162
|
+
errors.append(
|
|
163
|
+
f"Unsupported Groq model: '{model}'\n"
|
|
164
|
+
f" Supported Groq models for prompt learning:\n"
|
|
165
|
+
f" - gpt-oss-Xb (e.g., gpt-oss-20b, openai/gpt-oss-120b)\n"
|
|
166
|
+
f" - llama-3.3-70b (and variants like llama-3.3-70b-versatile)\n"
|
|
167
|
+
f" - qwen/qwen3-32b (and variants)\n"
|
|
168
|
+
f" Got: '{model}'"
|
|
169
|
+
)
|
|
170
|
+
elif provider_lower == "google":
|
|
171
|
+
if not _is_supported_google_model(model_without_prefix):
|
|
172
|
+
errors.append(
|
|
173
|
+
f"Unsupported Google/Gemini model: '{model}'\n"
|
|
174
|
+
f" Supported Google models for prompt learning:\n"
|
|
175
|
+
f" - gemini-2.5-pro, gemini-2.5-pro-gt200k\n"
|
|
176
|
+
f" - gemini-2.5-flash\n"
|
|
177
|
+
f" - gemini-2.5-flash-lite\n"
|
|
178
|
+
f" Got: '{model}'"
|
|
179
|
+
)
|
|
180
|
+
else:
|
|
181
|
+
errors.append(
|
|
182
|
+
f"Unsupported provider: '{provider}'\n"
|
|
183
|
+
f" Supported providers for prompt learning: 'openai', 'groq', 'google'\n"
|
|
184
|
+
f" Got: '{provider}'"
|
|
185
|
+
)
|
|
186
|
+
|
|
187
|
+
return errors
|
|
188
|
+
|
|
189
|
+
|
|
190
|
+
def validate_prompt_learning_config(config_data: dict[str, Any], config_path: Path) -> None:
|
|
191
|
+
"""
|
|
192
|
+
Validate prompt learning config BEFORE sending to backend.
|
|
193
|
+
|
|
194
|
+
This catches common errors early with clear messages instead of cryptic backend errors.
|
|
195
|
+
|
|
196
|
+
Args:
|
|
197
|
+
config_data: Parsed TOML/JSON config
|
|
198
|
+
config_path: Path to config file (for error messages)
|
|
199
|
+
|
|
200
|
+
Raises:
|
|
201
|
+
ConfigValidationError: If config is invalid
|
|
202
|
+
click.ClickException: If validation fails (for CLI)
|
|
203
|
+
"""
|
|
204
|
+
errors: list[str] = []
|
|
205
|
+
|
|
206
|
+
# Check for prompt_learning section
|
|
207
|
+
pl_section = config_data.get("prompt_learning")
|
|
208
|
+
if not pl_section:
|
|
209
|
+
errors.append(
|
|
210
|
+
"Missing [prompt_learning] section in config. "
|
|
211
|
+
"Expected: [prompt_learning] with algorithm, task_app_url, etc."
|
|
212
|
+
)
|
|
213
|
+
_raise_validation_errors(errors, config_path)
|
|
214
|
+
return
|
|
215
|
+
|
|
216
|
+
if not isinstance(pl_section, dict):
|
|
217
|
+
errors.append(
|
|
218
|
+
f"[prompt_learning] must be a table/dict, got {type(pl_section).__name__}"
|
|
219
|
+
)
|
|
220
|
+
_raise_validation_errors(errors, config_path)
|
|
221
|
+
return
|
|
222
|
+
|
|
223
|
+
# CRITICAL: Validate algorithm field
|
|
224
|
+
algorithm = pl_section.get("algorithm")
|
|
225
|
+
if not algorithm:
|
|
226
|
+
errors.append(
|
|
227
|
+
"Missing required field: prompt_learning.algorithm\n"
|
|
228
|
+
" Must be one of: 'gepa', 'mipro'\n"
|
|
229
|
+
" Example:\n"
|
|
230
|
+
" [prompt_learning]\n"
|
|
231
|
+
" algorithm = \"gepa\""
|
|
232
|
+
)
|
|
233
|
+
elif algorithm not in ("gepa", "mipro"):
|
|
234
|
+
errors.append(
|
|
235
|
+
f"Invalid algorithm: '{algorithm}'\n"
|
|
236
|
+
f" Must be one of: 'gepa', 'mipro'\n"
|
|
237
|
+
f" Got: '{algorithm}'"
|
|
238
|
+
)
|
|
239
|
+
|
|
240
|
+
# Validate task_app_url
|
|
241
|
+
task_app_url = pl_section.get("task_app_url")
|
|
242
|
+
if not task_app_url:
|
|
243
|
+
errors.append(
|
|
244
|
+
"Missing required field: prompt_learning.task_app_url\n"
|
|
245
|
+
" Example:\n"
|
|
246
|
+
" task_app_url = \"http://127.0.0.1:8102\""
|
|
247
|
+
)
|
|
248
|
+
elif not isinstance(task_app_url, str):
|
|
249
|
+
errors.append(
|
|
250
|
+
f"task_app_url must be a string, got {type(task_app_url).__name__}"
|
|
251
|
+
)
|
|
252
|
+
elif not task_app_url.startswith(("http://", "https://")):
|
|
253
|
+
errors.append(
|
|
254
|
+
f"task_app_url must start with http:// or https://, got: '{task_app_url}'"
|
|
255
|
+
)
|
|
256
|
+
|
|
257
|
+
# Validate initial_prompt if present
|
|
258
|
+
initial_prompt = pl_section.get("initial_prompt")
|
|
259
|
+
if initial_prompt:
|
|
260
|
+
if not isinstance(initial_prompt, dict):
|
|
261
|
+
errors.append(
|
|
262
|
+
f"prompt_learning.initial_prompt must be a table/dict, got {type(initial_prompt).__name__}"
|
|
263
|
+
)
|
|
264
|
+
else:
|
|
265
|
+
# Validate messages array
|
|
266
|
+
messages = initial_prompt.get("messages")
|
|
267
|
+
if messages is not None:
|
|
268
|
+
if not isinstance(messages, list):
|
|
269
|
+
errors.append(
|
|
270
|
+
f"prompt_learning.initial_prompt.messages must be an array, got {type(messages).__name__}"
|
|
271
|
+
)
|
|
272
|
+
elif len(messages) == 0:
|
|
273
|
+
errors.append(
|
|
274
|
+
"prompt_learning.initial_prompt.messages is empty (must have at least one message)"
|
|
275
|
+
)
|
|
276
|
+
|
|
277
|
+
# Validate policy config
|
|
278
|
+
policy = pl_section.get("policy")
|
|
279
|
+
if not policy or not isinstance(policy, dict):
|
|
280
|
+
errors.append("Missing [prompt_learning.policy] section or not a table")
|
|
281
|
+
else:
|
|
282
|
+
# Enforce inference_mode
|
|
283
|
+
mode = str(policy.get("inference_mode", "")).strip().lower()
|
|
284
|
+
if not mode:
|
|
285
|
+
errors.append("Missing required field: prompt_learning.policy.inference_mode (must be 'synth_hosted')")
|
|
286
|
+
elif mode != "synth_hosted":
|
|
287
|
+
errors.append("prompt_learning.policy.inference_mode must be 'synth_hosted' (bring_your_own unsupported)")
|
|
288
|
+
# Required fields for synth_hosted
|
|
289
|
+
provider = (policy.get("provider") or "").strip()
|
|
290
|
+
model = (policy.get("model") or "").strip()
|
|
291
|
+
if not provider:
|
|
292
|
+
errors.append("Missing required field: prompt_learning.policy.provider")
|
|
293
|
+
if not model:
|
|
294
|
+
errors.append("Missing required field: prompt_learning.policy.model")
|
|
295
|
+
else:
|
|
296
|
+
# Validate model is supported for the provider
|
|
297
|
+
if provider:
|
|
298
|
+
errors.extend(_validate_model_for_provider(
|
|
299
|
+
model, provider, "prompt_learning.policy.model", allow_nano=True
|
|
300
|
+
))
|
|
301
|
+
# Validate inference_url format if provided (even though trainer provides it in rollout requests)
|
|
302
|
+
inference_url = policy.get("inference_url")
|
|
303
|
+
if inference_url is not None:
|
|
304
|
+
if not isinstance(inference_url, str):
|
|
305
|
+
errors.append("prompt_learning.policy.inference_url must be a string")
|
|
306
|
+
else:
|
|
307
|
+
inference_url_stripped = inference_url.strip()
|
|
308
|
+
if inference_url_stripped and not inference_url_stripped.startswith(("http://", "https://")):
|
|
309
|
+
errors.append("prompt_learning.policy.inference_url must start with http:// or https://")
|
|
310
|
+
if not inference_url_stripped:
|
|
311
|
+
errors.append("prompt_learning.policy.inference_url must start with http:// or https://")
|
|
312
|
+
# inference_url is NOT required - trainer provides it in rollout requests
|
|
313
|
+
|
|
314
|
+
# Check for multi-stage/multi-module pipeline config
|
|
315
|
+
initial_prompt = pl_section.get("initial_prompt", {})
|
|
316
|
+
pipeline_modules: list[str | dict[str, Any]] = []
|
|
317
|
+
if isinstance(initial_prompt, dict):
|
|
318
|
+
metadata = initial_prompt.get("metadata", {})
|
|
319
|
+
pipeline_modules = metadata.get("pipeline_modules", [])
|
|
320
|
+
if not isinstance(pipeline_modules, list):
|
|
321
|
+
pipeline_modules = []
|
|
322
|
+
has_multi_stage = isinstance(pipeline_modules, list) and len(pipeline_modules) > 0
|
|
323
|
+
|
|
324
|
+
# Validate algorithm-specific config
|
|
325
|
+
if algorithm == "gepa":
|
|
326
|
+
gepa_config = pl_section.get("gepa")
|
|
327
|
+
if not gepa_config or not isinstance(gepa_config, dict):
|
|
328
|
+
errors.append("Missing [prompt_learning.gepa] section for GEPA algorithm")
|
|
329
|
+
else:
|
|
330
|
+
# Multi-stage validation
|
|
331
|
+
modules_config = gepa_config.get("modules")
|
|
332
|
+
if has_multi_stage:
|
|
333
|
+
if not modules_config or not isinstance(modules_config, list) or len(modules_config) == 0:
|
|
334
|
+
errors.append(
|
|
335
|
+
f"GEPA multi-stage pipeline detected (found {len(pipeline_modules)} modules in "
|
|
336
|
+
f"prompt_learning.initial_prompt.metadata.pipeline_modules), "
|
|
337
|
+
f"but [prompt_learning.gepa.modules] is missing or empty. "
|
|
338
|
+
f"Define module configs for each pipeline stage."
|
|
339
|
+
)
|
|
340
|
+
else:
|
|
341
|
+
# Validate module IDs match pipeline_modules
|
|
342
|
+
module_ids = []
|
|
343
|
+
for m in modules_config:
|
|
344
|
+
if isinstance(m, dict):
|
|
345
|
+
module_id = m.get("module_id") or m.get("stage_id")
|
|
346
|
+
if module_id:
|
|
347
|
+
module_ids.append(str(module_id).strip())
|
|
348
|
+
elif hasattr(m, "module_id"):
|
|
349
|
+
module_ids.append(str(m.module_id).strip())
|
|
350
|
+
elif hasattr(m, "stage_id"):
|
|
351
|
+
module_ids.append(str(m.stage_id).strip())
|
|
352
|
+
|
|
353
|
+
# Extract pipeline module names (can be strings or dicts with 'name' field)
|
|
354
|
+
pipeline_module_names = []
|
|
355
|
+
for m in pipeline_modules:
|
|
356
|
+
if isinstance(m, str):
|
|
357
|
+
pipeline_module_names.append(m.strip())
|
|
358
|
+
elif isinstance(m, dict):
|
|
359
|
+
name = m.get("name") or m.get("module_id") or m.get("stage_id")
|
|
360
|
+
if name:
|
|
361
|
+
pipeline_module_names.append(str(name).strip())
|
|
362
|
+
|
|
363
|
+
# Check for missing modules
|
|
364
|
+
missing_modules = set(pipeline_module_names) - set(module_ids)
|
|
365
|
+
if missing_modules:
|
|
366
|
+
errors.append(
|
|
367
|
+
f"Pipeline modules {sorted(missing_modules)} are missing from "
|
|
368
|
+
f"[prompt_learning.gepa.modules]. Each pipeline module must have a corresponding "
|
|
369
|
+
f"module config with matching module_id."
|
|
370
|
+
)
|
|
371
|
+
|
|
372
|
+
# Check for extra modules (warn but don't error)
|
|
373
|
+
extra_modules = set(module_ids) - set(pipeline_module_names)
|
|
374
|
+
if extra_modules:
|
|
375
|
+
# This is a warning, not an error - extra modules are allowed
|
|
376
|
+
pass
|
|
377
|
+
|
|
378
|
+
# Numeric sanity checks
|
|
379
|
+
def _pos_int(name: str) -> None:
|
|
380
|
+
val = gepa_config.get(name)
|
|
381
|
+
if val is not None:
|
|
382
|
+
try:
|
|
383
|
+
ival = int(val)
|
|
384
|
+
if ival <= 0:
|
|
385
|
+
errors.append(f"prompt_learning.gepa.{name} must be > 0")
|
|
386
|
+
except Exception:
|
|
387
|
+
errors.append(f"prompt_learning.gepa.{name} must be an integer")
|
|
388
|
+
|
|
389
|
+
def _pos_int_nested(section: str, name: str) -> None:
|
|
390
|
+
"""Check positive int in nested section."""
|
|
391
|
+
section_config = gepa_config.get(section)
|
|
392
|
+
if section_config and isinstance(section_config, dict):
|
|
393
|
+
val = section_config.get(name)
|
|
394
|
+
if val is not None:
|
|
395
|
+
try:
|
|
396
|
+
ival = int(val)
|
|
397
|
+
if ival <= 0:
|
|
398
|
+
errors.append(f"prompt_learning.gepa.{section}.{name} must be > 0")
|
|
399
|
+
except Exception:
|
|
400
|
+
errors.append(f"prompt_learning.gepa.{section}.{name} must be an integer")
|
|
401
|
+
|
|
402
|
+
def _non_neg_int(name: str) -> None:
|
|
403
|
+
"""Check non-negative int."""
|
|
404
|
+
val = gepa_config.get(name)
|
|
405
|
+
if val is not None:
|
|
406
|
+
try:
|
|
407
|
+
ival = int(val)
|
|
408
|
+
if ival < 0:
|
|
409
|
+
errors.append(f"prompt_learning.gepa.{name} must be >= 0")
|
|
410
|
+
except Exception:
|
|
411
|
+
errors.append(f"prompt_learning.gepa.{name} must be an integer")
|
|
412
|
+
|
|
413
|
+
def _rate_float(name: str) -> None:
|
|
414
|
+
"""Check float in [0.0, 1.0] range."""
|
|
415
|
+
val = gepa_config.get(name)
|
|
416
|
+
if val is not None:
|
|
417
|
+
try:
|
|
418
|
+
fval = float(val)
|
|
419
|
+
if not (0.0 <= fval <= 1.0):
|
|
420
|
+
errors.append(f"prompt_learning.gepa.{name} must be between 0.0 and 1.0")
|
|
421
|
+
except Exception:
|
|
422
|
+
errors.append(f"prompt_learning.gepa.{name} must be numeric")
|
|
423
|
+
|
|
424
|
+
def _pos_float(name: str) -> None:
|
|
425
|
+
"""Check positive float."""
|
|
426
|
+
val = gepa_config.get(name)
|
|
427
|
+
if val is not None:
|
|
428
|
+
try:
|
|
429
|
+
fval = float(val)
|
|
430
|
+
if fval <= 0:
|
|
431
|
+
errors.append(f"prompt_learning.gepa.{name} must be > 0")
|
|
432
|
+
except Exception:
|
|
433
|
+
errors.append(f"prompt_learning.gepa.{name} must be numeric")
|
|
434
|
+
|
|
435
|
+
# Required positive integers
|
|
436
|
+
for fld in ("initial_population_size", "num_generations", "children_per_generation", "max_concurrent_rollouts"):
|
|
437
|
+
_pos_int(fld)
|
|
438
|
+
|
|
439
|
+
# Nested rollout config validation
|
|
440
|
+
_pos_int_nested("rollout", "budget")
|
|
441
|
+
_pos_int_nested("rollout", "max_concurrent")
|
|
442
|
+
_pos_int_nested("rollout", "minibatch_size")
|
|
443
|
+
|
|
444
|
+
# Nested population config validation
|
|
445
|
+
_pos_int_nested("population", "initial_size")
|
|
446
|
+
_pos_int_nested("population", "num_generations")
|
|
447
|
+
_pos_int_nested("population", "children_per_generation")
|
|
448
|
+
_rate_float("mutation_rate") # Can be at top level or in mutation section
|
|
449
|
+
_rate_float("crossover_rate") # Can be at top level or in population section
|
|
450
|
+
_pos_float("selection_pressure") # Must be >= 1.0
|
|
451
|
+
selection_pressure = gepa_config.get("selection_pressure")
|
|
452
|
+
if selection_pressure is not None:
|
|
453
|
+
try:
|
|
454
|
+
sp = float(selection_pressure)
|
|
455
|
+
if sp < 1.0:
|
|
456
|
+
errors.append("prompt_learning.gepa.selection_pressure must be >= 1.0")
|
|
457
|
+
except Exception:
|
|
458
|
+
pass # Already caught by type check
|
|
459
|
+
_non_neg_int("patience_generations")
|
|
460
|
+
|
|
461
|
+
# Nested archive config validation
|
|
462
|
+
_pos_int_nested("archive", "size")
|
|
463
|
+
_pos_int_nested("archive", "pareto_set_size")
|
|
464
|
+
_pos_float("pareto_eps") # Must be > 0, typically very small
|
|
465
|
+
_rate_float("feedback_fraction")
|
|
466
|
+
|
|
467
|
+
# Nested mutation config validation
|
|
468
|
+
mutation_config = gepa_config.get("mutation")
|
|
469
|
+
if mutation_config and isinstance(mutation_config, dict):
|
|
470
|
+
_rate_float("mutation_rate") # Check in mutation section too
|
|
471
|
+
mutation_model = mutation_config.get("llm_model")
|
|
472
|
+
mutation_provider = mutation_config.get("llm_provider", "").strip()
|
|
473
|
+
if mutation_model:
|
|
474
|
+
if not mutation_provider:
|
|
475
|
+
errors.append(
|
|
476
|
+
"Missing required field: prompt_learning.gepa.mutation.llm_provider\n"
|
|
477
|
+
" Required when prompt_learning.gepa.mutation.llm_model is set"
|
|
478
|
+
)
|
|
479
|
+
else:
|
|
480
|
+
errors.extend(_validate_model_for_provider(
|
|
481
|
+
mutation_model, mutation_provider, "prompt_learning.gepa.mutation.llm_model", allow_nano=False
|
|
482
|
+
))
|
|
483
|
+
|
|
484
|
+
# Top-level mutation_rate and crossover_rate (if not in nested sections)
|
|
485
|
+
if not (mutation_config and isinstance(mutation_config, dict) and "rate" in mutation_config):
|
|
486
|
+
_rate_float("mutation_rate")
|
|
487
|
+
population_config = gepa_config.get("population")
|
|
488
|
+
if not (population_config and isinstance(population_config, dict) and "crossover_rate" in population_config):
|
|
489
|
+
_rate_float("crossover_rate")
|
|
490
|
+
|
|
491
|
+
# Budget cap
|
|
492
|
+
max_spend = gepa_config.get("max_spend_usd")
|
|
493
|
+
if max_spend is not None:
|
|
494
|
+
try:
|
|
495
|
+
f = float(max_spend)
|
|
496
|
+
if f <= 0:
|
|
497
|
+
errors.append("prompt_learning.gepa.max_spend_usd must be > 0 when provided")
|
|
498
|
+
except (ValueError, TypeError):
|
|
499
|
+
errors.append("prompt_learning.gepa.max_spend_usd must be numeric")
|
|
500
|
+
|
|
501
|
+
# Rollout budget validation
|
|
502
|
+
rollout_config = gepa_config.get("rollout")
|
|
503
|
+
rollout_budget = None
|
|
504
|
+
if rollout_config and isinstance(rollout_config, dict):
|
|
505
|
+
rollout_budget = rollout_config.get("budget")
|
|
506
|
+
if rollout_budget is None:
|
|
507
|
+
rollout_budget = gepa_config.get("rollout_budget")
|
|
508
|
+
if rollout_budget is not None:
|
|
509
|
+
try:
|
|
510
|
+
rb = int(rollout_budget)
|
|
511
|
+
if rb <= 0:
|
|
512
|
+
errors.append("prompt_learning.gepa.rollout.budget (or rollout_budget) must be > 0 when provided")
|
|
513
|
+
except Exception:
|
|
514
|
+
errors.append("prompt_learning.gepa.rollout.budget (or rollout_budget) must be an integer")
|
|
515
|
+
|
|
516
|
+
# Minibatch size validation
|
|
517
|
+
minibatch_size = None
|
|
518
|
+
if rollout_config and isinstance(rollout_config, dict):
|
|
519
|
+
minibatch_size = rollout_config.get("minibatch_size")
|
|
520
|
+
if minibatch_size is None:
|
|
521
|
+
minibatch_size = gepa_config.get("minibatch_size")
|
|
522
|
+
if minibatch_size is not None:
|
|
523
|
+
try:
|
|
524
|
+
mbs = int(minibatch_size)
|
|
525
|
+
if mbs <= 0:
|
|
526
|
+
errors.append("prompt_learning.gepa.rollout.minibatch_size (or minibatch_size) must be > 0")
|
|
527
|
+
except Exception:
|
|
528
|
+
errors.append("prompt_learning.gepa.rollout.minibatch_size (or minibatch_size) must be an integer")
|
|
529
|
+
|
|
530
|
+
# Proposer type validation
|
|
531
|
+
proposer_type = gepa_config.get("proposer_type", "dspy")
|
|
532
|
+
if proposer_type not in ("dspy", "spec"):
|
|
533
|
+
errors.append(
|
|
534
|
+
f"Invalid proposer_type: '{proposer_type}'\n"
|
|
535
|
+
f" Must be one of: 'dspy', 'spec'\n"
|
|
536
|
+
f" Got: '{proposer_type}'"
|
|
537
|
+
)
|
|
538
|
+
|
|
539
|
+
# Spec validation when proposer_type is "spec"
|
|
540
|
+
if proposer_type == "spec":
|
|
541
|
+
spec_path = gepa_config.get("spec_path")
|
|
542
|
+
if not spec_path:
|
|
543
|
+
errors.append(
|
|
544
|
+
"Missing required field: prompt_learning.gepa.spec_path\n"
|
|
545
|
+
" Required when proposer_type='spec'\n"
|
|
546
|
+
" Example:\n"
|
|
547
|
+
" [prompt_learning.gepa]\n"
|
|
548
|
+
" proposer_type = \"spec\"\n"
|
|
549
|
+
" spec_path = \"examples/task_apps/banking77/banking77_spec.json\""
|
|
550
|
+
)
|
|
551
|
+
else:
|
|
552
|
+
# Validate spec_max_tokens if provided
|
|
553
|
+
spec_max_tokens = gepa_config.get("spec_max_tokens")
|
|
554
|
+
if spec_max_tokens is not None:
|
|
555
|
+
try:
|
|
556
|
+
smt = int(spec_max_tokens)
|
|
557
|
+
if smt <= 0:
|
|
558
|
+
errors.append("prompt_learning.gepa.spec_max_tokens must be > 0")
|
|
559
|
+
except Exception:
|
|
560
|
+
errors.append("prompt_learning.gepa.spec_max_tokens must be an integer")
|
|
561
|
+
|
|
562
|
+
# Validate spec_priority_threshold if provided
|
|
563
|
+
spec_priority_threshold = gepa_config.get("spec_priority_threshold")
|
|
564
|
+
if spec_priority_threshold is not None:
|
|
565
|
+
try:
|
|
566
|
+
spt = int(spec_priority_threshold)
|
|
567
|
+
if spt < 0:
|
|
568
|
+
errors.append("prompt_learning.gepa.spec_priority_threshold must be >= 0")
|
|
569
|
+
except Exception:
|
|
570
|
+
errors.append("prompt_learning.gepa.spec_priority_threshold must be an integer")
|
|
571
|
+
|
|
572
|
+
# Archive size validation
|
|
573
|
+
archive_config = gepa_config.get("archive")
|
|
574
|
+
archive_size = None
|
|
575
|
+
if archive_config and isinstance(archive_config, dict):
|
|
576
|
+
archive_size = archive_config.get("size")
|
|
577
|
+
if archive_size is None:
|
|
578
|
+
archive_size = gepa_config.get("archive_size")
|
|
579
|
+
if archive_size is not None:
|
|
580
|
+
try:
|
|
581
|
+
asize = int(archive_size)
|
|
582
|
+
if asize <= 0:
|
|
583
|
+
errors.append("prompt_learning.gepa.archive.size (or archive_size) must be > 0")
|
|
584
|
+
except Exception:
|
|
585
|
+
errors.append("prompt_learning.gepa.archive.size (or archive_size) must be an integer")
|
|
586
|
+
|
|
587
|
+
# Pareto eps validation
|
|
588
|
+
pareto_eps = None
|
|
589
|
+
if archive_config and isinstance(archive_config, dict):
|
|
590
|
+
pareto_eps = archive_config.get("pareto_eps")
|
|
591
|
+
if pareto_eps is None:
|
|
592
|
+
pareto_eps = gepa_config.get("pareto_eps")
|
|
593
|
+
if pareto_eps is not None:
|
|
594
|
+
try:
|
|
595
|
+
pe = float(pareto_eps)
|
|
596
|
+
if pe <= 0:
|
|
597
|
+
errors.append("prompt_learning.gepa.archive.pareto_eps (or pareto_eps) must be > 0")
|
|
598
|
+
elif pe >= 1.0:
|
|
599
|
+
errors.append("prompt_learning.gepa.archive.pareto_eps (or pareto_eps) should be < 1.0 (typically 1e-6)")
|
|
600
|
+
except Exception:
|
|
601
|
+
errors.append("prompt_learning.gepa.archive.pareto_eps (or pareto_eps) must be numeric")
|
|
602
|
+
|
|
603
|
+
# Feedback fraction validation
|
|
604
|
+
feedback_fraction = None
|
|
605
|
+
if archive_config and isinstance(archive_config, dict):
|
|
606
|
+
feedback_fraction = archive_config.get("feedback_fraction")
|
|
607
|
+
if feedback_fraction is None:
|
|
608
|
+
feedback_fraction = gepa_config.get("feedback_fraction")
|
|
609
|
+
if feedback_fraction is not None:
|
|
610
|
+
try:
|
|
611
|
+
ff = float(feedback_fraction)
|
|
612
|
+
if not (0.0 <= ff <= 1.0):
|
|
613
|
+
errors.append("prompt_learning.gepa.archive.feedback_fraction (or feedback_fraction) must be between 0.0 and 1.0")
|
|
614
|
+
except Exception:
|
|
615
|
+
errors.append("prompt_learning.gepa.archive.feedback_fraction (or feedback_fraction) must be numeric")
|
|
616
|
+
|
|
617
|
+
# Token counting model validation (should be a valid model name)
|
|
618
|
+
token_config = gepa_config.get("token")
|
|
619
|
+
token_counting_model = None
|
|
620
|
+
if token_config and isinstance(token_config, dict):
|
|
621
|
+
token_counting_model = token_config.get("counting_model")
|
|
622
|
+
if token_counting_model is None:
|
|
623
|
+
token_counting_model = gepa_config.get("token_counting_model")
|
|
624
|
+
if token_counting_model and (not isinstance(token_counting_model, str) or not token_counting_model.strip()):
|
|
625
|
+
# Basic validation - should be a non-empty string
|
|
626
|
+
errors.append("prompt_learning.gepa.token.counting_model (or token_counting_model) must be a non-empty string")
|
|
627
|
+
|
|
628
|
+
# Module/stage validation for multi-stage
|
|
629
|
+
if has_multi_stage:
|
|
630
|
+
modules_config = gepa_config.get("modules")
|
|
631
|
+
if modules_config and isinstance(modules_config, list):
|
|
632
|
+
for idx, module_entry in enumerate(modules_config):
|
|
633
|
+
if isinstance(module_entry, dict):
|
|
634
|
+
module_id = module_entry.get("module_id") or module_entry.get("stage_id") or f"module_{idx}"
|
|
635
|
+
max_instruction_slots = module_entry.get("max_instruction_slots")
|
|
636
|
+
max_tokens = module_entry.get("max_tokens")
|
|
637
|
+
allowed_tools = module_entry.get("allowed_tools")
|
|
638
|
+
|
|
639
|
+
# Validate max_instruction_slots
|
|
640
|
+
if max_instruction_slots is not None:
|
|
641
|
+
try:
|
|
642
|
+
mis = int(max_instruction_slots)
|
|
643
|
+
if mis < 1:
|
|
644
|
+
errors.append(
|
|
645
|
+
f"prompt_learning.gepa.modules[{idx}].max_instruction_slots must be >= 1"
|
|
646
|
+
)
|
|
647
|
+
except Exception:
|
|
648
|
+
errors.append(
|
|
649
|
+
f"prompt_learning.gepa.modules[{idx}].max_instruction_slots must be an integer"
|
|
650
|
+
)
|
|
651
|
+
|
|
652
|
+
# Validate max_tokens
|
|
653
|
+
if max_tokens is not None:
|
|
654
|
+
try:
|
|
655
|
+
mt = int(max_tokens)
|
|
656
|
+
if mt <= 0:
|
|
657
|
+
errors.append(
|
|
658
|
+
f"prompt_learning.gepa.modules[{idx}].max_tokens must be > 0"
|
|
659
|
+
)
|
|
660
|
+
except Exception:
|
|
661
|
+
errors.append(
|
|
662
|
+
f"prompt_learning.gepa.modules[{idx}].max_tokens must be an integer"
|
|
663
|
+
)
|
|
664
|
+
|
|
665
|
+
# Validate allowed_tools
|
|
666
|
+
if allowed_tools is not None:
|
|
667
|
+
if not isinstance(allowed_tools, list):
|
|
668
|
+
errors.append(
|
|
669
|
+
f"prompt_learning.gepa.modules[{idx}].allowed_tools must be a list"
|
|
670
|
+
)
|
|
671
|
+
else:
|
|
672
|
+
if len(allowed_tools) == 0:
|
|
673
|
+
errors.append(
|
|
674
|
+
f"prompt_learning.gepa.modules[{idx}].allowed_tools cannot be empty (use null/omit to allow all tools)"
|
|
675
|
+
)
|
|
676
|
+
else:
|
|
677
|
+
# Check for duplicates
|
|
678
|
+
seen_tools = set()
|
|
679
|
+
for tool_idx, tool in enumerate(allowed_tools):
|
|
680
|
+
if not isinstance(tool, str):
|
|
681
|
+
errors.append(
|
|
682
|
+
f"prompt_learning.gepa.modules[{idx}].allowed_tools[{tool_idx}] must be a string"
|
|
683
|
+
)
|
|
684
|
+
elif not tool.strip():
|
|
685
|
+
errors.append(
|
|
686
|
+
f"prompt_learning.gepa.modules[{idx}].allowed_tools[{tool_idx}] cannot be empty"
|
|
687
|
+
)
|
|
688
|
+
elif tool.strip() in seen_tools:
|
|
689
|
+
errors.append(
|
|
690
|
+
f"prompt_learning.gepa.modules[{idx}].allowed_tools contains duplicate '{tool.strip()}'"
|
|
691
|
+
)
|
|
692
|
+
else:
|
|
693
|
+
seen_tools.add(tool.strip())
|
|
694
|
+
|
|
695
|
+
elif algorithm == "mipro":
|
|
696
|
+
mipro_config = pl_section.get("mipro")
|
|
697
|
+
if not mipro_config or not isinstance(mipro_config, dict):
|
|
698
|
+
errors.append("Missing [prompt_learning.mipro] section for MIPRO algorithm")
|
|
699
|
+
else:
|
|
700
|
+
# Validate required MIPRO fields
|
|
701
|
+
def _pos_int(name: str) -> None:
|
|
702
|
+
val = mipro_config.get(name)
|
|
703
|
+
if val is not None:
|
|
704
|
+
try:
|
|
705
|
+
ival = int(val)
|
|
706
|
+
if ival <= 0:
|
|
707
|
+
errors.append(f"prompt_learning.mipro.{name} must be > 0")
|
|
708
|
+
except Exception:
|
|
709
|
+
errors.append(f"prompt_learning.mipro.{name} must be an integer")
|
|
710
|
+
|
|
711
|
+
def _non_neg_int(name: str) -> None:
|
|
712
|
+
"""Check non-negative int."""
|
|
713
|
+
val = mipro_config.get(name)
|
|
714
|
+
if val is not None:
|
|
715
|
+
try:
|
|
716
|
+
ival = int(val)
|
|
717
|
+
if ival < 0:
|
|
718
|
+
errors.append(f"prompt_learning.mipro.{name} must be >= 0")
|
|
719
|
+
except Exception:
|
|
720
|
+
errors.append(f"prompt_learning.mipro.{name} must be an integer")
|
|
721
|
+
|
|
722
|
+
def _rate_float(name: str) -> None:
|
|
723
|
+
"""Check float in [0.0, 1.0] range."""
|
|
724
|
+
val = mipro_config.get(name)
|
|
725
|
+
if val is not None:
|
|
726
|
+
try:
|
|
727
|
+
fval = float(val)
|
|
728
|
+
if not (0.0 <= fval <= 1.0):
|
|
729
|
+
errors.append(f"prompt_learning.mipro.{name} must be between 0.0 and 1.0")
|
|
730
|
+
except Exception:
|
|
731
|
+
errors.append(f"prompt_learning.mipro.{name} must be numeric")
|
|
732
|
+
|
|
733
|
+
def _pos_float(name: str) -> None:
|
|
734
|
+
"""Check positive float."""
|
|
735
|
+
val = mipro_config.get(name)
|
|
736
|
+
if val is not None:
|
|
737
|
+
try:
|
|
738
|
+
fval = float(val)
|
|
739
|
+
if fval <= 0:
|
|
740
|
+
errors.append(f"prompt_learning.mipro.{name} must be > 0")
|
|
741
|
+
except Exception:
|
|
742
|
+
errors.append(f"prompt_learning.mipro.{name} must be numeric")
|
|
743
|
+
|
|
744
|
+
# Required numeric fields
|
|
745
|
+
for fld in ("num_iterations", "num_evaluations_per_iteration", "batch_size", "max_concurrent"):
|
|
746
|
+
_pos_int(fld)
|
|
747
|
+
|
|
748
|
+
# Additional MIPRO numeric validations
|
|
749
|
+
_pos_int("max_demo_set_size")
|
|
750
|
+
_pos_int("max_demo_sets")
|
|
751
|
+
_pos_int("max_instruction_sets")
|
|
752
|
+
_pos_int("full_eval_every_k")
|
|
753
|
+
_pos_int("instructions_per_batch")
|
|
754
|
+
_pos_int("max_instructions")
|
|
755
|
+
_pos_int("duplicate_retry_limit")
|
|
756
|
+
|
|
757
|
+
# Validate meta_model is set and supported
|
|
758
|
+
meta_model = mipro_config.get("meta_model")
|
|
759
|
+
meta_model_provider = mipro_config.get("meta_model_provider", "").strip()
|
|
760
|
+
if not meta_model:
|
|
761
|
+
errors.append("Missing required field: prompt_learning.mipro.meta_model")
|
|
762
|
+
else:
|
|
763
|
+
if not meta_model_provider:
|
|
764
|
+
errors.append(
|
|
765
|
+
"Missing required field: prompt_learning.mipro.meta_model_provider\n"
|
|
766
|
+
" Required when prompt_learning.mipro.meta_model is set"
|
|
767
|
+
)
|
|
768
|
+
else:
|
|
769
|
+
errors.extend(_validate_model_for_provider(
|
|
770
|
+
meta_model, meta_model_provider, "prompt_learning.mipro.meta_model", allow_nano=False
|
|
771
|
+
))
|
|
772
|
+
|
|
773
|
+
# Validate meta model temperature
|
|
774
|
+
meta_temperature = mipro_config.get("meta_model_temperature")
|
|
775
|
+
if meta_temperature is not None:
|
|
776
|
+
try:
|
|
777
|
+
temp = float(meta_temperature)
|
|
778
|
+
if temp < 0.0:
|
|
779
|
+
errors.append("prompt_learning.mipro.meta_model_temperature must be >= 0.0")
|
|
780
|
+
except Exception:
|
|
781
|
+
errors.append("prompt_learning.mipro.meta_model_temperature must be numeric")
|
|
782
|
+
|
|
783
|
+
# Validate meta model max_tokens
|
|
784
|
+
meta_max_tokens = mipro_config.get("meta_model_max_tokens")
|
|
785
|
+
if meta_max_tokens is not None:
|
|
786
|
+
try:
|
|
787
|
+
mmt = int(meta_max_tokens)
|
|
788
|
+
if mmt <= 0:
|
|
789
|
+
errors.append("prompt_learning.mipro.meta_model_max_tokens must be > 0")
|
|
790
|
+
except Exception:
|
|
791
|
+
errors.append("prompt_learning.mipro.meta_model_max_tokens must be an integer")
|
|
792
|
+
|
|
793
|
+
# Validate generate_at_iterations
|
|
794
|
+
generate_at_iterations = mipro_config.get("generate_at_iterations")
|
|
795
|
+
if generate_at_iterations is not None:
|
|
796
|
+
if not isinstance(generate_at_iterations, list):
|
|
797
|
+
errors.append("prompt_learning.mipro.generate_at_iterations must be a list")
|
|
798
|
+
else:
|
|
799
|
+
for idx, iter_val in enumerate(generate_at_iterations):
|
|
800
|
+
try:
|
|
801
|
+
iter_int = int(iter_val)
|
|
802
|
+
if iter_int < 0:
|
|
803
|
+
errors.append(
|
|
804
|
+
f"prompt_learning.mipro.generate_at_iterations[{idx}] must be >= 0"
|
|
805
|
+
)
|
|
806
|
+
except Exception:
|
|
807
|
+
errors.append(
|
|
808
|
+
f"prompt_learning.mipro.generate_at_iterations[{idx}] must be an integer"
|
|
809
|
+
)
|
|
810
|
+
|
|
811
|
+
# Validate spec configuration
|
|
812
|
+
spec_path = mipro_config.get("spec_path")
|
|
813
|
+
if spec_path:
|
|
814
|
+
# Validate spec_max_tokens if provided
|
|
815
|
+
spec_max_tokens = mipro_config.get("spec_max_tokens")
|
|
816
|
+
if spec_max_tokens is not None:
|
|
817
|
+
try:
|
|
818
|
+
smt = int(spec_max_tokens)
|
|
819
|
+
if smt <= 0:
|
|
820
|
+
errors.append("prompt_learning.mipro.spec_max_tokens must be > 0")
|
|
821
|
+
except Exception:
|
|
822
|
+
errors.append("prompt_learning.mipro.spec_max_tokens must be an integer")
|
|
823
|
+
|
|
824
|
+
# Validate spec_priority_threshold if provided
|
|
825
|
+
spec_priority_threshold = mipro_config.get("spec_priority_threshold")
|
|
826
|
+
if spec_priority_threshold is not None:
|
|
827
|
+
try:
|
|
828
|
+
spt = int(spec_priority_threshold)
|
|
829
|
+
if spt < 0:
|
|
830
|
+
errors.append("prompt_learning.mipro.spec_priority_threshold must be >= 0")
|
|
831
|
+
except Exception:
|
|
832
|
+
errors.append("prompt_learning.mipro.spec_priority_threshold must be an integer")
|
|
833
|
+
|
|
834
|
+
# Validate modules/stages configuration
|
|
835
|
+
modules_config = mipro_config.get("modules")
|
|
836
|
+
if modules_config and isinstance(modules_config, list):
|
|
837
|
+
max_instruction_sets = mipro_config.get("max_instruction_sets", 128)
|
|
838
|
+
max_demo_sets = mipro_config.get("max_demo_sets", 128)
|
|
839
|
+
seen_module_ids = set()
|
|
840
|
+
seen_stage_ids = set()
|
|
841
|
+
|
|
842
|
+
for module_idx, module_entry in enumerate(modules_config):
|
|
843
|
+
if not isinstance(module_entry, dict):
|
|
844
|
+
errors.append(
|
|
845
|
+
f"prompt_learning.mipro.modules[{module_idx}] must be a table/dict"
|
|
846
|
+
)
|
|
847
|
+
continue
|
|
848
|
+
|
|
849
|
+
module_id = module_entry.get("module_id") or module_entry.get("id") or f"module_{module_idx}"
|
|
850
|
+
if module_id in seen_module_ids:
|
|
851
|
+
errors.append(
|
|
852
|
+
f"Duplicate module_id '{module_id}' in prompt_learning.mipro.modules"
|
|
853
|
+
)
|
|
854
|
+
seen_module_ids.add(module_id)
|
|
855
|
+
|
|
856
|
+
# Validate stages
|
|
857
|
+
stages = module_entry.get("stages")
|
|
858
|
+
if stages is not None:
|
|
859
|
+
if not isinstance(stages, list):
|
|
860
|
+
errors.append(
|
|
861
|
+
f"prompt_learning.mipro.modules[{module_idx}].stages must be a list"
|
|
862
|
+
)
|
|
863
|
+
else:
|
|
864
|
+
for stage_idx, stage_entry in enumerate(stages):
|
|
865
|
+
if isinstance(stage_entry, dict):
|
|
866
|
+
stage_id = stage_entry.get("stage_id") or stage_entry.get("module_stage_id") or f"stage_{stage_idx}"
|
|
867
|
+
if stage_id in seen_stage_ids:
|
|
868
|
+
errors.append(
|
|
869
|
+
f"Duplicate stage_id '{stage_id}' across modules"
|
|
870
|
+
)
|
|
871
|
+
seen_stage_ids.add(stage_id)
|
|
872
|
+
|
|
873
|
+
# Validate max_instruction_slots <= max_instruction_sets
|
|
874
|
+
max_instr_slots = stage_entry.get("max_instruction_slots")
|
|
875
|
+
if max_instr_slots is not None:
|
|
876
|
+
try:
|
|
877
|
+
mis = int(max_instr_slots)
|
|
878
|
+
if mis < 1:
|
|
879
|
+
errors.append(
|
|
880
|
+
f"prompt_learning.mipro.modules[{module_idx}].stages[{stage_idx}].max_instruction_slots must be >= 1"
|
|
881
|
+
)
|
|
882
|
+
elif mis > max_instruction_sets:
|
|
883
|
+
errors.append(
|
|
884
|
+
f"prompt_learning.mipro.modules[{module_idx}].stages[{stage_idx}].max_instruction_slots ({mis}) "
|
|
885
|
+
f"exceeds max_instruction_sets ({max_instruction_sets})"
|
|
886
|
+
)
|
|
887
|
+
except Exception:
|
|
888
|
+
errors.append(
|
|
889
|
+
f"prompt_learning.mipro.modules[{module_idx}].stages[{stage_idx}].max_instruction_slots must be an integer"
|
|
890
|
+
)
|
|
891
|
+
|
|
892
|
+
# Validate max_demo_slots <= max_demo_sets
|
|
893
|
+
max_demo_slots = stage_entry.get("max_demo_slots")
|
|
894
|
+
if max_demo_slots is not None:
|
|
895
|
+
try:
|
|
896
|
+
mds = int(max_demo_slots)
|
|
897
|
+
if mds < 0:
|
|
898
|
+
errors.append(
|
|
899
|
+
f"prompt_learning.mipro.modules[{module_idx}].stages[{stage_idx}].max_demo_slots must be >= 0"
|
|
900
|
+
)
|
|
901
|
+
elif mds > max_demo_sets:
|
|
902
|
+
errors.append(
|
|
903
|
+
f"prompt_learning.mipro.modules[{module_idx}].stages[{stage_idx}].max_demo_slots ({mds}) "
|
|
904
|
+
f"exceeds max_demo_sets ({max_demo_sets})"
|
|
905
|
+
)
|
|
906
|
+
except Exception:
|
|
907
|
+
errors.append(
|
|
908
|
+
f"prompt_learning.mipro.modules[{module_idx}].stages[{stage_idx}].max_demo_slots must be an integer"
|
|
909
|
+
)
|
|
910
|
+
|
|
911
|
+
# Validate edges reference valid stages
|
|
912
|
+
edges = module_entry.get("edges")
|
|
913
|
+
if edges is not None:
|
|
914
|
+
if not isinstance(edges, list):
|
|
915
|
+
errors.append(
|
|
916
|
+
f"prompt_learning.mipro.modules[{module_idx}].edges must be a list"
|
|
917
|
+
)
|
|
918
|
+
else:
|
|
919
|
+
stage_ids_in_module = set()
|
|
920
|
+
if stages and isinstance(stages, list):
|
|
921
|
+
for stage_entry in stages:
|
|
922
|
+
if isinstance(stage_entry, dict):
|
|
923
|
+
sid = stage_entry.get("stage_id") or stage_entry.get("module_stage_id")
|
|
924
|
+
if sid:
|
|
925
|
+
stage_ids_in_module.add(str(sid))
|
|
926
|
+
|
|
927
|
+
for edge_idx, edge in enumerate(edges):
|
|
928
|
+
if isinstance(edge, list | tuple) and len(edge) == 2:
|
|
929
|
+
source, target = edge
|
|
930
|
+
elif isinstance(edge, dict):
|
|
931
|
+
source = edge.get("from") or edge.get("source")
|
|
932
|
+
target = edge.get("to") or edge.get("target")
|
|
933
|
+
else:
|
|
934
|
+
errors.append(
|
|
935
|
+
f"prompt_learning.mipro.modules[{module_idx}].edges[{edge_idx}] must be a pair or mapping"
|
|
936
|
+
)
|
|
937
|
+
continue
|
|
938
|
+
|
|
939
|
+
source_str = str(source or "").strip()
|
|
940
|
+
target_str = str(target or "").strip()
|
|
941
|
+
if source_str and source_str not in stage_ids_in_module:
|
|
942
|
+
errors.append(
|
|
943
|
+
f"prompt_learning.mipro.modules[{module_idx}].edges[{edge_idx}] references unknown source stage '{source_str}'"
|
|
944
|
+
)
|
|
945
|
+
if target_str and target_str not in stage_ids_in_module:
|
|
946
|
+
errors.append(
|
|
947
|
+
f"prompt_learning.mipro.modules[{module_idx}].edges[{edge_idx}] references unknown target stage '{target_str}'"
|
|
948
|
+
)
|
|
949
|
+
|
|
950
|
+
# CRITICAL: Validate bootstrap_train_seeds and online_pool (can be at top level or under mipro)
|
|
951
|
+
bootstrap_seeds = pl_section.get("bootstrap_train_seeds") or (mipro_config.get("bootstrap_train_seeds") if isinstance(mipro_config, dict) else None)
|
|
952
|
+
online_pool = pl_section.get("online_pool") or (mipro_config.get("online_pool") if isinstance(mipro_config, dict) else None)
|
|
953
|
+
|
|
954
|
+
if not bootstrap_seeds:
|
|
955
|
+
errors.append(
|
|
956
|
+
"Missing required field: prompt_learning.bootstrap_train_seeds\n"
|
|
957
|
+
" MIPRO requires bootstrap seeds for the few-shot bootstrapping phase.\n"
|
|
958
|
+
" Example:\n"
|
|
959
|
+
" [prompt_learning]\n"
|
|
960
|
+
" bootstrap_train_seeds = [0, 1, 2, 3, 4]"
|
|
961
|
+
)
|
|
962
|
+
elif not isinstance(bootstrap_seeds, list):
|
|
963
|
+
errors.append("prompt_learning.bootstrap_train_seeds must be an array")
|
|
964
|
+
elif len(bootstrap_seeds) == 0:
|
|
965
|
+
errors.append("prompt_learning.bootstrap_train_seeds cannot be empty")
|
|
966
|
+
|
|
967
|
+
if not online_pool:
|
|
968
|
+
errors.append(
|
|
969
|
+
"Missing required field: prompt_learning.online_pool\n"
|
|
970
|
+
" MIPRO requires online_pool seeds for mini-batch evaluation during optimization.\n"
|
|
971
|
+
" Example:\n"
|
|
972
|
+
" [prompt_learning]\n"
|
|
973
|
+
" online_pool = [5, 6, 7, 8, 9]"
|
|
974
|
+
)
|
|
975
|
+
elif not isinstance(online_pool, list):
|
|
976
|
+
errors.append("prompt_learning.online_pool must be an array")
|
|
977
|
+
elif len(online_pool) == 0:
|
|
978
|
+
errors.append("prompt_learning.online_pool cannot be empty")
|
|
979
|
+
|
|
980
|
+
# Validate few_shot_score_threshold (if mipro_config exists)
|
|
981
|
+
if isinstance(mipro_config, dict):
|
|
982
|
+
threshold = mipro_config.get("few_shot_score_threshold")
|
|
983
|
+
if threshold is not None:
|
|
984
|
+
try:
|
|
985
|
+
f = float(threshold)
|
|
986
|
+
if not (0.0 <= f <= 1.0):
|
|
987
|
+
errors.append("prompt_learning.mipro.few_shot_score_threshold must be between 0.0 and 1.0")
|
|
988
|
+
except Exception:
|
|
989
|
+
errors.append("prompt_learning.mipro.few_shot_score_threshold must be a number")
|
|
990
|
+
|
|
991
|
+
# Validate reference pool doesn't overlap with bootstrap/online/test pools
|
|
992
|
+
reference_pool = mipro_config.get("reference_pool") or pl_section.get("reference_pool")
|
|
993
|
+
if reference_pool:
|
|
994
|
+
if not isinstance(reference_pool, list):
|
|
995
|
+
errors.append("prompt_learning.mipro.reference_pool (or prompt_learning.reference_pool) must be an array")
|
|
996
|
+
else:
|
|
997
|
+
all_train_test = set(bootstrap_seeds or []) | set(online_pool or []) | set(mipro_config.get("test_pool") or pl_section.get("test_pool") or [])
|
|
998
|
+
overlapping = set(reference_pool) & all_train_test
|
|
999
|
+
if overlapping:
|
|
1000
|
+
errors.append(
|
|
1001
|
+
f"reference_pool seeds must not overlap with bootstrap/online/test pools. "
|
|
1002
|
+
f"Found overlapping seeds: {sorted(overlapping)}"
|
|
1003
|
+
)
|
|
1004
|
+
|
|
1005
|
+
# Raise all errors at once for better UX
|
|
1006
|
+
if errors:
|
|
1007
|
+
_raise_validation_errors(errors, config_path)
|
|
1008
|
+
|
|
1009
|
+
|
|
1010
|
+
def _raise_validation_errors(errors: list[str], config_path: Path) -> None:
    """Format the collected validation errors and raise them as one CLI error.

    Builds a single report consisting of a header naming the config file,
    one numbered entry per error, and a pointer to the example configs, so
    the user sees every problem at once.

    Args:
        errors: Human-readable error messages. An entry may span several
            lines; its continuation lines are indented in the report.
        config_path: Path of the config file, shown in the report header.

    Raises:
        click.ClickException: Always, carrying the formatted report.
    """
    header = (
        f"\n❌ Invalid prompt learning config: {config_path}\n\n"
        f"Found {len(errors)} error(s):\n\n"
    )
    # Indent the continuation lines of multi-line errors so they stay
    # visually attached to their numbered entry.
    entries = [
        f"{number}. " + error.replace("\n", "\n ") + "\n\n"
        for number, error in enumerate(errors, 1)
    ]
    footer = (
        "📚 See example configs:\n"
        " - examples/blog_posts/gepa/configs/banking77_gepa_local.toml\n"
        " - examples/blog_posts/mipro/configs/banking77_mipro_local.toml\n"
    )
    # Assemble the report in one pass instead of repeated string `+=`.
    raise click.ClickException("".join([header, *entries, footer]))
|
|
1029
|
+
|
|
1030
|
+
|
|
1031
|
+
def validate_rl_config(config_data: dict[str, Any], config_path: Path) -> None:
    """Validate an RL config before it is sent to the backend.

    Checks that an ``[rl]`` (or, failing that, ``[online_rl]``) table is
    present and that it provides ``algorithm`` and a string ``task_url``.
    Field-level problems are collected and reported together.

    Args:
        config_data: Parsed TOML/JSON config.
        config_path: Path to the config file (used in error messages).

    Raises:
        click.ClickException: If validation fails (raised through
            ``_raise_validation_errors``).
    """
    errors: list[str] = []

    # An empty/missing [rl] table falls through to [online_rl].
    rl_section = config_data.get("rl") or config_data.get("online_rl")
    if not rl_section:
        errors.append("Missing [rl] or [online_rl] section in config")
        _raise_validation_errors(errors, config_path)
        return

    if not isinstance(rl_section, dict):
        # A scalar or array here (e.g. `rl = "grpo"`) has no .get(); report
        # it as a config error instead of crashing with AttributeError.
        errors.append(
            "[rl] or [online_rl] section must be a table/dict, "
            f"got {type(rl_section).__name__}"
        )
        _raise_validation_errors(errors, config_path)
        return

    # Validate algorithm
    if not rl_section.get("algorithm"):
        errors.append(
            "Missing required field: rl.algorithm\n"
            " Must be one of: 'grpo', 'ppo', etc."
        )

    # Validate task_url
    task_url = rl_section.get("task_url")
    if not task_url:
        errors.append("Missing required field: rl.task_url")
    elif not isinstance(task_url, str):
        errors.append(
            f"task_url must be a string, got {type(task_url).__name__}"
        )

    if errors:
        _raise_validation_errors(errors, config_path)
|
|
1075
|
+
|
|
1076
|
+
|
|
1077
|
+
def validate_sft_config(config_data: dict[str, Any], config_path: Path) -> None:
    """Validate an SFT config before it is sent to the backend.

    Checks that an ``[sft]`` table is present and that it names a ``model``.

    Args:
        config_data: Parsed TOML/JSON config.
        config_path: Path to the config file (used in error messages).

    Raises:
        click.ClickException: If validation fails (raised through
            ``_raise_validation_errors``).
    """
    errors: list[str] = []

    # Check for sft section
    sft_section = config_data.get("sft")
    if not sft_section:
        errors.append("Missing [sft] section in config")
        _raise_validation_errors(errors, config_path)
        return

    if not isinstance(sft_section, dict):
        # A scalar or array here (e.g. `sft = "model-name"`) has no .get();
        # report it as a config error instead of crashing with AttributeError.
        errors.append(
            f"[sft] section must be a table/dict, got {type(sft_section).__name__}"
        )
        _raise_validation_errors(errors, config_path)
        return

    # Validate model
    if not sft_section.get("model"):
        errors.append("Missing required field: sft.model")

    if errors:
        _raise_validation_errors(errors, config_path)
|
|
1109
|
+
|
|
1110
|
+
|
|
1111
|
+
# Public names exported by this module (what `from <module> import *` binds).
__all__ = [
    "ConfigValidationError",
    "validate_prompt_learning_config",
    "validate_rl_config",
    "validate_sft_config",
]
|
|
1117
|
+
|