synth-ai 0.2.8.dev2__py3-none-any.whl → 0.4.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- synth_ai/__init__.py +44 -24
- synth_ai/__main__.py +30 -3
- synth_ai/cli/__init__.py +103 -48
- synth_ai/cli/__main__.py +42 -0
- synth_ai/cli/_internal/__init__.py +5 -0
- synth_ai/cli/_internal/modal_wrapper.py +31 -0
- synth_ai/cli/_internal/storage.py +20 -0
- synth_ai/cli/_internal/typer_patch.py +47 -0
- synth_ai/cli/_internal/validate_task_app.py +29 -0
- synth_ai/cli/agents/__init__.py +17 -0
- synth_ai/cli/agents/claude.py +77 -0
- synth_ai/cli/agents/codex.py +265 -0
- synth_ai/cli/agents/opencode.py +253 -0
- synth_ai/cli/commands/__init__.py +18 -0
- synth_ai/cli/commands/artifacts/__init__.py +13 -0
- synth_ai/cli/commands/artifacts/client.py +119 -0
- synth_ai/cli/commands/artifacts/config.py +57 -0
- synth_ai/cli/commands/artifacts/core.py +24 -0
- synth_ai/cli/commands/artifacts/download.py +188 -0
- synth_ai/cli/commands/artifacts/export.py +186 -0
- synth_ai/cli/commands/artifacts/list.py +156 -0
- synth_ai/cli/commands/artifacts/parsing.py +250 -0
- synth_ai/cli/commands/artifacts/show.py +336 -0
- synth_ai/cli/commands/demo/__init__.py +3 -0
- synth_ai/cli/commands/demo/core.py +153 -0
- synth_ai/cli/commands/eval/__init__.py +10 -0
- synth_ai/cli/commands/eval/config.py +338 -0
- synth_ai/cli/commands/eval/core.py +256 -0
- synth_ai/cli/commands/eval/runner.py +704 -0
- synth_ai/cli/commands/eval/validation.py +60 -0
- synth_ai/cli/commands/filter/__init__.py +12 -0
- synth_ai/cli/commands/filter/core.py +424 -0
- synth_ai/cli/commands/filter/errors.py +55 -0
- synth_ai/cli/commands/filter/validation.py +77 -0
- synth_ai/cli/commands/help/__init__.py +185 -0
- synth_ai/cli/commands/help/core.py +72 -0
- synth_ai/cli/commands/scan/__init__.py +19 -0
- synth_ai/cli/commands/scan/cloudflare_scanner.py +403 -0
- synth_ai/cli/commands/scan/core.py +344 -0
- synth_ai/cli/commands/scan/health_checker.py +242 -0
- synth_ai/cli/commands/scan/local_scanner.py +278 -0
- synth_ai/cli/commands/scan/models.py +83 -0
- synth_ai/cli/commands/smoke/__init__.py +7 -0
- synth_ai/cli/commands/smoke/core.py +1428 -0
- synth_ai/cli/commands/status/__init__.py +3 -0
- synth_ai/cli/commands/status/client.py +91 -0
- synth_ai/cli/commands/status/config.py +12 -0
- synth_ai/cli/commands/status/errors.py +11 -0
- synth_ai/cli/commands/status/subcommands/__init__.py +3 -0
- synth_ai/cli/commands/status/subcommands/config.py +13 -0
- synth_ai/cli/commands/status/subcommands/files.py +34 -0
- synth_ai/cli/commands/status/subcommands/jobs.py +51 -0
- synth_ai/cli/commands/status/subcommands/models.py +35 -0
- synth_ai/cli/commands/status/subcommands/runs.py +34 -0
- synth_ai/cli/commands/status/subcommands/session.py +77 -0
- synth_ai/cli/commands/status/subcommands/summary.py +39 -0
- synth_ai/cli/commands/status/subcommands/utils.py +41 -0
- synth_ai/cli/commands/status/utils.py +23 -0
- synth_ai/cli/commands/train/__init__.py +53 -0
- synth_ai/cli/commands/train/core.py +22 -0
- synth_ai/cli/commands/train/errors.py +117 -0
- synth_ai/cli/commands/train/judge_schemas.py +201 -0
- synth_ai/cli/commands/train/judge_validation.py +305 -0
- synth_ai/cli/commands/train/prompt_learning_validation.py +633 -0
- synth_ai/cli/commands/train/validation.py +392 -0
- synth_ai/cli/demo_apps/__init__.py +10 -0
- synth_ai/cli/demo_apps/core/__init__.py +28 -0
- synth_ai/{demos → cli/demo_apps}/core/cli.py +783 -441
- synth_ai/cli/demo_apps/crafter/__init__.py +1 -0
- synth_ai/cli/demo_apps/crafter/crafter_fft_4b.toml +55 -0
- synth_ai/cli/demo_apps/crafter/grpo_crafter_task_app.py +186 -0
- synth_ai/cli/demo_apps/crafter/rl_from_base_qwen4b.toml +74 -0
- synth_ai/cli/demo_apps/demo_registry.py +176 -0
- synth_ai/cli/demo_apps/demo_task_apps/__init__.py +7 -0
- synth_ai/{demos → cli/demo_apps}/demo_task_apps/core.py +75 -37
- synth_ai/cli/demo_apps/demo_task_apps/crafter/__init__.py +1 -0
- synth_ai/cli/demo_apps/demo_task_apps/crafter/configs/crafter_fft_4b.toml +53 -0
- synth_ai/cli/demo_apps/demo_task_apps/crafter/configs/rl_from_base_qwen4b.toml +73 -0
- synth_ai/cli/demo_apps/demo_task_apps/crafter/grpo_crafter_task_app.py +185 -0
- synth_ai/{demos → cli/demo_apps}/demo_task_apps/math/_common.py +1 -2
- synth_ai/{demos → cli/demo_apps}/demo_task_apps/math/app.py +2 -1
- synth_ai/cli/demo_apps/demo_task_apps/math/config.toml +73 -0
- synth_ai/{demos → cli/demo_apps}/demo_task_apps/math/deploy_modal.py +3 -6
- synth_ai/cli/demo_apps/demo_task_apps/math/modal_task_app.py +738 -0
- synth_ai/cli/demo_apps/demo_task_apps/math/task_app_entry.py +39 -0
- synth_ai/cli/demo_apps/math/__init__.py +1 -0
- synth_ai/cli/demo_apps/math/_common.py +16 -0
- synth_ai/cli/demo_apps/math/app.py +38 -0
- synth_ai/cli/demo_apps/math/config.toml +75 -0
- synth_ai/cli/demo_apps/math/deploy_modal.py +54 -0
- synth_ai/cli/demo_apps/math/modal_task_app.py +698 -0
- synth_ai/cli/demo_apps/math/task_app_entry.py +53 -0
- synth_ai/cli/demo_apps/mipro/main.py +271 -0
- synth_ai/cli/demo_apps/mipro/task_app.py +922 -0
- synth_ai/cli/demo_apps/mipro/train_cfg.toml +92 -0
- synth_ai/cli/demos/__init__.py +12 -0
- synth_ai/cli/demos/demo.py +32 -0
- synth_ai/cli/demos/rl_demo.py +254 -0
- synth_ai/cli/deploy.py +216 -0
- synth_ai/cli/infra/__init__.py +14 -0
- synth_ai/cli/{balance.py → infra/balance.py} +16 -4
- synth_ai/cli/infra/mcp.py +35 -0
- synth_ai/cli/infra/modal_app.py +36 -0
- synth_ai/cli/infra/setup.py +69 -0
- synth_ai/cli/infra/status.py +16 -0
- synth_ai/cli/infra/turso.py +77 -0
- synth_ai/cli/lib/__init__.py +10 -0
- synth_ai/cli/lib/agents.py +76 -0
- synth_ai/cli/lib/apps/modal_app.py +101 -0
- synth_ai/cli/lib/apps/task_app.py +642 -0
- synth_ai/cli/lib/bin.py +39 -0
- synth_ai/cli/lib/env.py +375 -0
- synth_ai/cli/lib/errors.py +85 -0
- synth_ai/cli/lib/modal.py +315 -0
- synth_ai/cli/lib/plotting.py +126 -0
- synth_ai/cli/lib/prompt_args.py +39 -0
- synth_ai/cli/lib/prompts.py +284 -0
- synth_ai/cli/lib/sqld.py +122 -0
- synth_ai/cli/lib/task_app_discovery.py +884 -0
- synth_ai/cli/lib/task_app_env.py +295 -0
- synth_ai/cli/lib/train_cfgs.py +300 -0
- synth_ai/cli/lib/tunnel_records.py +207 -0
- synth_ai/cli/local/__init__.py +14 -0
- synth_ai/cli/local/experiment_queue/__init__.py +72 -0
- synth_ai/cli/local/experiment_queue/api_schemas.py +221 -0
- synth_ai/cli/local/experiment_queue/celery_app.py +208 -0
- synth_ai/cli/local/experiment_queue/config.py +128 -0
- synth_ai/cli/local/experiment_queue/config_utils.py +272 -0
- synth_ai/cli/local/experiment_queue/database.py +175 -0
- synth_ai/cli/local/experiment_queue/dispatcher.py +119 -0
- synth_ai/cli/local/experiment_queue/models.py +231 -0
- synth_ai/cli/local/experiment_queue/progress_info.py +160 -0
- synth_ai/cli/local/experiment_queue/results.py +373 -0
- synth_ai/cli/local/experiment_queue/schemas.py +131 -0
- synth_ai/cli/local/experiment_queue/service.py +344 -0
- synth_ai/cli/local/experiment_queue/status.py +372 -0
- synth_ai/cli/local/experiment_queue/status_tracker.py +360 -0
- synth_ai/cli/local/experiment_queue/tasks.py +1984 -0
- synth_ai/cli/local/experiment_queue/trace_storage.py +65 -0
- synth_ai/cli/local/experiment_queue/validation.py +157 -0
- synth_ai/cli/local/session/__init__.py +92 -0
- synth_ai/cli/local/session/client.py +383 -0
- synth_ai/cli/local/session/constants.py +63 -0
- synth_ai/cli/local/session/exceptions.py +105 -0
- synth_ai/cli/local/session/manager.py +139 -0
- synth_ai/cli/local/session/models.py +89 -0
- synth_ai/cli/local/session/query.py +110 -0
- synth_ai/cli/root.py +150 -108
- synth_ai/cli/task_apps/__init__.py +37 -0
- synth_ai/cli/task_apps/commands.py +3145 -0
- synth_ai/cli/task_apps/deploy.py +7 -0
- synth_ai/cli/task_apps/list.py +26 -0
- synth_ai/cli/task_apps/main.py +36 -0
- synth_ai/cli/task_apps/modal_serve.py +11 -0
- synth_ai/cli/task_apps/serve.py +11 -0
- synth_ai/cli/training/__init__.py +8 -0
- synth_ai/cli/training/train.py +5 -0
- synth_ai/cli/training/train_cfg.py +34 -0
- synth_ai/cli/{watch.py → training/watch.py} +13 -18
- synth_ai/cli/turso.py +52 -0
- synth_ai/cli/utils/__init__.py +8 -0
- synth_ai/cli/utils/experiments.py +235 -0
- synth_ai/cli/utils/queue.py +504 -0
- synth_ai/cli/{recent.py → utils/recent.py} +13 -7
- synth_ai/cli/{traces.py → utils/traces.py} +9 -5
- synth_ai/contracts/__init__.py +67 -0
- synth_ai/core/__init__.py +100 -0
- synth_ai/core/_utils/__init__.py +54 -0
- synth_ai/core/_utils/base_url.py +10 -0
- synth_ai/core/_utils/http.py +10 -0
- synth_ai/core/_utils/prompts.py +14 -0
- synth_ai/core/_utils/task_app_state.py +12 -0
- synth_ai/core/_utils/user_config.py +10 -0
- synth_ai/core/apps/common.py +116 -0
- synth_ai/core/auth.py +95 -0
- synth_ai/core/cfgs.py +240 -0
- synth_ai/core/config/__init__.py +16 -0
- synth_ai/core/config/base.py +168 -0
- synth_ai/core/config/resolver.py +89 -0
- synth_ai/core/env.py +231 -0
- synth_ai/core/errors.py +126 -0
- synth_ai/core/http.py +230 -0
- synth_ai/core/integrations/__init__.py +11 -0
- synth_ai/core/integrations/cloudflare.py +1710 -0
- synth_ai/core/integrations/mcp/__init__.py +6 -0
- synth_ai/core/integrations/mcp/__main__.py +8 -0
- synth_ai/core/integrations/mcp/claude.py +36 -0
- synth_ai/core/integrations/mcp/main.py +254 -0
- synth_ai/core/integrations/mcp/setup.py +100 -0
- synth_ai/core/integrations/modal.py +277 -0
- synth_ai/core/json.py +72 -0
- synth_ai/core/log_filter.py +99 -0
- synth_ai/core/logging.py +82 -0
- synth_ai/core/paths.py +107 -0
- synth_ai/core/pricing.py +109 -0
- synth_ai/core/process.py +233 -0
- synth_ai/core/ssl.py +25 -0
- synth_ai/core/storage/__init__.py +71 -0
- synth_ai/core/task_app_state.py +318 -0
- synth_ai/core/telemetry.py +282 -0
- synth_ai/{tracing_v3 → core/tracing_v3}/__init__.py +5 -1
- synth_ai/{tracing_v3 → core/tracing_v3}/abstractions.py +21 -4
- synth_ai/core/tracing_v3/config.py +229 -0
- synth_ai/core/tracing_v3/constants.py +21 -0
- synth_ai/{tracing_v3 → core/tracing_v3}/db_config.py +42 -29
- synth_ai/{tracing_v3 → core/tracing_v3}/decorators.py +80 -45
- synth_ai/{tracing_v3 → core/tracing_v3}/examples/basic_usage.py +15 -9
- synth_ai/{tracing_v3 → core/tracing_v3}/hooks.py +6 -4
- synth_ai/{tracing_v3 → core/tracing_v3}/llm_call_record_helpers.py +161 -61
- synth_ai/{tracing_v3 → core/tracing_v3}/migration_helper.py +1 -2
- synth_ai/{tracing_v3 → core/tracing_v3}/replica_sync.py +12 -7
- synth_ai/core/tracing_v3/serialization.py +130 -0
- synth_ai/{tracing_v3 → core/tracing_v3}/session_tracer.py +88 -21
- synth_ai/{tracing_v3 → core/tracing_v3}/storage/base.py +99 -12
- synth_ai/core/tracing_v3/storage/config.py +109 -0
- synth_ai/{tracing_v3 → core/tracing_v3}/storage/factory.py +11 -9
- synth_ai/{tracing_v3 → core/tracing_v3}/storage/utils.py +15 -11
- synth_ai/core/tracing_v3/trace_utils.py +326 -0
- synth_ai/core/tracing_v3/turso/__init__.py +12 -0
- synth_ai/core/tracing_v3/turso/daemon.py +278 -0
- synth_ai/{tracing_v3 → core/tracing_v3}/turso/models.py +7 -3
- synth_ai/core/tracing_v3/turso/native_manager.py +1385 -0
- synth_ai/{tracing_v3 → core/tracing_v3}/utils.py +5 -4
- synth_ai/core/urls.py +18 -0
- synth_ai/core/user_config.py +137 -0
- synth_ai/core/uvicorn.py +222 -0
- synth_ai/data/__init__.py +83 -0
- synth_ai/data/enums.py +123 -0
- synth_ai/data/rewards.py +152 -0
- synth_ai/data/traces.py +35 -0
- synth_ai/products/__init__.py +6 -0
- synth_ai/products/graph_evolve/__init__.py +46 -0
- synth_ai/products/graph_evolve/client.py +226 -0
- synth_ai/products/graph_evolve/config.py +591 -0
- synth_ai/products/graph_evolve/converters/__init__.py +42 -0
- synth_ai/products/graph_evolve/converters/openai_sft.py +484 -0
- synth_ai/products/graph_evolve/examples/hotpotqa/config.toml +109 -0
- synth_ai/products/graph_evolve/run.py +222 -0
- synth_ai/products/graph_gepa/__init__.py +23 -0
- synth_ai/products/graph_gepa/converters/__init__.py +19 -0
- synth_ai/products/graph_gepa/converters/openai_sft.py +29 -0
- synth_ai/sdk/__init__.py +123 -0
- synth_ai/sdk/api/__init__.py +1 -0
- synth_ai/sdk/api/models/supported.py +514 -0
- synth_ai/sdk/api/research_agent/__init__.py +296 -0
- synth_ai/sdk/api/train/__init__.py +85 -0
- synth_ai/sdk/api/train/builders.py +895 -0
- synth_ai/sdk/api/train/cli.py +2199 -0
- synth_ai/sdk/api/train/config_finder.py +267 -0
- synth_ai/sdk/api/train/configs/__init__.py +65 -0
- synth_ai/sdk/api/train/configs/prompt_learning.py +1706 -0
- synth_ai/sdk/api/train/configs/rl.py +187 -0
- synth_ai/sdk/api/train/configs/sft.py +99 -0
- synth_ai/sdk/api/train/configs/shared.py +81 -0
- synth_ai/sdk/api/train/context_learning.py +312 -0
- synth_ai/sdk/api/train/env_resolver.py +418 -0
- synth_ai/sdk/api/train/graph_validators.py +216 -0
- synth_ai/sdk/api/train/graphgen.py +984 -0
- synth_ai/sdk/api/train/graphgen_models.py +823 -0
- synth_ai/sdk/api/train/graphgen_validators.py +109 -0
- synth_ai/sdk/api/train/local_api.py +10 -0
- synth_ai/sdk/api/train/pollers.py +124 -0
- synth_ai/sdk/api/train/progress/__init__.py +97 -0
- synth_ai/sdk/api/train/progress/dataclasses.py +569 -0
- synth_ai/sdk/api/train/progress/events.py +326 -0
- synth_ai/sdk/api/train/progress/results.py +428 -0
- synth_ai/sdk/api/train/progress/tracker.py +641 -0
- synth_ai/sdk/api/train/prompt_learning.py +469 -0
- synth_ai/sdk/api/train/rl.py +441 -0
- synth_ai/sdk/api/train/sft.py +396 -0
- synth_ai/sdk/api/train/summary.py +522 -0
- synth_ai/sdk/api/train/supported_algos.py +147 -0
- synth_ai/sdk/api/train/task_app.py +351 -0
- synth_ai/sdk/api/train/utils.py +279 -0
- synth_ai/sdk/api/train/validators.py +2424 -0
- synth_ai/sdk/graphs/__init__.py +15 -0
- synth_ai/sdk/graphs/completions.py +570 -0
- synth_ai/{inference → sdk/inference}/__init__.py +0 -1
- synth_ai/sdk/inference/client.py +128 -0
- synth_ai/sdk/jobs/__init__.py +16 -0
- synth_ai/sdk/jobs/client.py +371 -0
- synth_ai/sdk/judging/__init__.py +14 -0
- synth_ai/sdk/judging/base.py +24 -0
- synth_ai/sdk/judging/client.py +40 -0
- synth_ai/sdk/judging/schemas.py +222 -0
- synth_ai/sdk/judging/types.py +42 -0
- synth_ai/sdk/learning/__init__.py +99 -0
- synth_ai/sdk/learning/algorithms.py +14 -0
- synth_ai/{learning → sdk/learning}/client.py +121 -30
- synth_ai/sdk/learning/config.py +5 -0
- synth_ai/{learning → sdk/learning}/constants.py +0 -2
- synth_ai/sdk/learning/context_learning_client.py +531 -0
- synth_ai/sdk/learning/context_learning_types.py +292 -0
- synth_ai/sdk/learning/ft_client.py +7 -0
- synth_ai/{learning → sdk/learning}/health.py +15 -9
- synth_ai/{learning → sdk/learning}/jobs.py +44 -47
- synth_ai/sdk/learning/prompt_extraction.py +334 -0
- synth_ai/sdk/learning/prompt_learning_client.py +455 -0
- synth_ai/sdk/learning/prompt_learning_types.py +186 -0
- synth_ai/{rl → sdk/learning/rl}/__init__.py +13 -8
- synth_ai/{learning/rl_client.py → sdk/learning/rl/client.py} +89 -77
- synth_ai/sdk/learning/rl/config.py +31 -0
- synth_ai/{rl → sdk/learning/rl}/contracts.py +5 -14
- synth_ai/{rl → sdk/learning/rl}/env_keys.py +45 -16
- synth_ai/sdk/learning/rl/secrets.py +13 -0
- synth_ai/sdk/learning/rl_client.py +5 -0
- synth_ai/sdk/learning/sft/__init__.py +29 -0
- synth_ai/sdk/learning/sft/client.py +95 -0
- synth_ai/sdk/learning/sft/config.py +270 -0
- synth_ai/sdk/learning/sft/data.py +698 -0
- synth_ai/sdk/learning/sse.py +57 -0
- synth_ai/sdk/learning/validators.py +52 -0
- synth_ai/sdk/localapi/__init__.py +40 -0
- synth_ai/sdk/localapi/apps/__init__.py +28 -0
- synth_ai/sdk/localapi/client.py +10 -0
- synth_ai/sdk/localapi/contracts.py +10 -0
- synth_ai/sdk/localapi/helpers.py +519 -0
- synth_ai/sdk/localapi/rollouts.py +87 -0
- synth_ai/sdk/localapi/server.py +29 -0
- synth_ai/sdk/localapi/template.py +70 -0
- synth_ai/sdk/streaming/__init__.py +35 -0
- synth_ai/sdk/streaming/config.py +94 -0
- synth_ai/sdk/streaming/handlers.py +1997 -0
- synth_ai/sdk/streaming/streamer.py +713 -0
- synth_ai/sdk/streaming/types.py +112 -0
- synth_ai/sdk/task/__init__.py +164 -0
- synth_ai/sdk/task/apps/__init__.py +169 -0
- synth_ai/sdk/task/auth.py +165 -0
- synth_ai/sdk/task/client.py +175 -0
- synth_ai/sdk/task/config.py +257 -0
- synth_ai/sdk/task/contracts.py +219 -0
- synth_ai/sdk/task/datasets.py +108 -0
- synth_ai/sdk/task/errors.py +50 -0
- synth_ai/sdk/task/health.py +34 -0
- synth_ai/sdk/task/in_process.py +1190 -0
- synth_ai/sdk/task/in_process_runner.py +314 -0
- synth_ai/sdk/task/inference_api.py +299 -0
- synth_ai/sdk/task/json.py +111 -0
- synth_ai/sdk/task/proxy.py +287 -0
- synth_ai/sdk/task/rubrics/__init__.py +55 -0
- synth_ai/sdk/task/rubrics/loaders.py +156 -0
- synth_ai/sdk/task/rubrics/models.py +57 -0
- synth_ai/sdk/task/rubrics/scoring.py +116 -0
- synth_ai/sdk/task/rubrics/strict.py +149 -0
- synth_ai/sdk/task/rubrics.py +219 -0
- synth_ai/sdk/task/server.py +631 -0
- synth_ai/sdk/task/trace_correlation_helpers.py +539 -0
- synth_ai/sdk/task/tracing_utils.py +95 -0
- synth_ai/sdk/task/validators.py +441 -0
- synth_ai/sdk/task/vendors.py +59 -0
- synth_ai/sdk/training/__init__.py +102 -0
- synth_ai/sdk/tunnels/__init__.py +83 -0
- synth_ai/sdk/tunnels/cleanup.py +83 -0
- synth_ai/sdk/tunnels/ports.py +120 -0
- synth_ai/utils/__init__.py +213 -0
- synth_ai-0.4.3.dist-info/METADATA +262 -0
- synth_ai-0.4.3.dist-info/RECORD +370 -0
- {synth_ai-0.2.8.dev2.dist-info → synth_ai-0.4.3.dist-info}/entry_points.txt +0 -1
- synth_ai/cli/calc.py +0 -69
- synth_ai/cli/demo.py +0 -144
- synth_ai/cli/legacy_root_backup.py +0 -470
- synth_ai/cli/man.py +0 -106
- synth_ai/cli/rl_demo.py +0 -202
- synth_ai/cli/status.py +0 -133
- synth_ai/config/base_url.py +0 -107
- synth_ai/core/experiment.py +0 -15
- synth_ai/core/system.py +0 -15
- synth_ai/demos/core/__init__.py +0 -1
- synth_ai/demos/demo_task_apps/__init__.py +0 -1
- synth_ai/demos/demo_task_apps/math/config.toml +0 -129
- synth_ai/demos/demo_task_apps/math/deploy_task_app.sh +0 -22
- synth_ai/demos/demo_task_apps/math/modal_task_app.py +0 -415
- synth_ai/environments/__init__.py +0 -31
- synth_ai/environments/environment/__init__.py +0 -1
- synth_ai/environments/environment/artifacts/__init__.py +0 -1
- synth_ai/environments/environment/artifacts/base.py +0 -52
- synth_ai/environments/environment/core.py +0 -67
- synth_ai/environments/environment/db/__init__.py +0 -1
- synth_ai/environments/environment/db/sqlite.py +0 -45
- synth_ai/environments/environment/registry.py +0 -233
- synth_ai/environments/environment/resources/sqlite.py +0 -45
- synth_ai/environments/environment/results.py +0 -1
- synth_ai/environments/environment/rewards/__init__.py +0 -1
- synth_ai/environments/environment/rewards/core.py +0 -29
- synth_ai/environments/environment/shared_engine.py +0 -26
- synth_ai/environments/environment/tools/__init__.py +0 -200
- synth_ai/environments/examples/__init__.py +0 -1
- synth_ai/environments/examples/bandit/__init__.py +0 -33
- synth_ai/environments/examples/bandit/engine.py +0 -294
- synth_ai/environments/examples/bandit/environment.py +0 -194
- synth_ai/environments/examples/bandit/taskset.py +0 -200
- synth_ai/environments/examples/crafter_classic/__init__.py +0 -8
- synth_ai/environments/examples/crafter_classic/agent_demos/analyze_semantic_words_markdown.py +0 -250
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_comprehensive_evaluation.py +0 -59
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_evaluation_browser.py +0 -152
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_evaluation_config.toml +0 -24
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_evaluation_framework.py +0 -1194
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/crafter_synth_config.toml +0 -56
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/filter_config_modal.toml +0 -32
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/filter_traces_sft_turso.py +0 -738
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/kick_off_ft_modal.py +0 -384
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/analyze_action_results.py +0 -53
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/analyze_agent_actions.py +0 -178
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/analyze_latest_run.py +0 -222
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/analyze_lm_traces.py +0 -183
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/analyze_no_rewards.py +0 -210
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/analyze_trace_issue.py +0 -206
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/check_db_schema.py +0 -49
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/check_latest_results.py +0 -64
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/debug_agent_responses.py +0 -88
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/old/quick_trace_check.py +0 -77
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/compare_experiments.py +0 -324
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/filter_traces_sft_turso.py +0 -580
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/kick_off_ft_oai.py +0 -362
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/multi_model_config.toml +0 -49
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/analyze_enhanced_hooks.py +0 -332
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/analyze_hook_events.py +0 -97
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/analyze_hook_results.py +0 -217
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/check_hook_storage.py +0 -87
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/check_seeds.py +0 -88
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/compare_seed_performance.py +0 -195
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/custom_eval_pipelines.py +0 -400
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/plot_hook_frequency.py +0 -195
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/old/seed_analysis_summary.py +0 -56
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/run_rollouts_for_models_and_compare_v3.py +0 -858
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_quick_evaluation.py +0 -52
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_react_agent.py +0 -874
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_trace_evaluation.py +0 -1412
- synth_ai/environments/examples/crafter_classic/agent_demos/example_v3_usage.py +0 -216
- synth_ai/environments/examples/crafter_classic/agent_demos/old/compare_traces.py +0 -296
- synth_ai/environments/examples/crafter_classic/agent_demos/old/crafter_comprehensive_evaluation.py +0 -58
- synth_ai/environments/examples/crafter_classic/agent_demos/old/crafter_env_serialization.py +0 -464
- synth_ai/environments/examples/crafter_classic/agent_demos/old/crafter_evaluation_browser.py +0 -152
- synth_ai/environments/examples/crafter_classic/agent_demos/old/crafter_quick_evaluation.py +0 -51
- synth_ai/environments/examples/crafter_classic/agent_demos/old/crafter_trace_evaluation.py +0 -1412
- synth_ai/environments/examples/crafter_classic/agent_demos/old/debug_player_loss.py +0 -112
- synth_ai/environments/examples/crafter_classic/agent_demos/old/diagnose_service.py +0 -203
- synth_ai/environments/examples/crafter_classic/agent_demos/old/diagnose_slowness.py +0 -305
- synth_ai/environments/examples/crafter_classic/agent_demos/old/eval_by_difficulty.py +0 -126
- synth_ai/environments/examples/crafter_classic/agent_demos/old/eval_example.py +0 -94
- synth_ai/environments/examples/crafter_classic/agent_demos/old/explore_saved_states.py +0 -142
- synth_ai/environments/examples/crafter_classic/agent_demos/old/filter_traces_sft.py +0 -26
- synth_ai/environments/examples/crafter_classic/agent_demos/old/filter_traces_sft_OLD.py +0 -984
- synth_ai/environments/examples/crafter_classic/agent_demos/old/generate_ft_data_gemini.py +0 -724
- synth_ai/environments/examples/crafter_classic/agent_demos/old/generate_ft_data_modal.py +0 -386
- synth_ai/environments/examples/crafter_classic/agent_demos/old/generate_ft_metadata.py +0 -205
- synth_ai/environments/examples/crafter_classic/agent_demos/old/kick_off_ft_gemini.py +0 -150
- synth_ai/environments/examples/crafter_classic/agent_demos/old/kick_off_ft_modal.py +0 -283
- synth_ai/environments/examples/crafter_classic/agent_demos/old/prepare_vertex_ft.py +0 -280
- synth_ai/environments/examples/crafter_classic/agent_demos/old/profile_env_slowness.py +0 -456
- synth_ai/environments/examples/crafter_classic/agent_demos/old/replicate_issue.py +0 -166
- synth_ai/environments/examples/crafter_classic/agent_demos/old/run_and_eval.py +0 -102
- synth_ai/environments/examples/crafter_classic/agent_demos/old/run_comparison.py +0 -128
- synth_ai/environments/examples/crafter_classic/agent_demos/old/run_qwen_rollouts.py +0 -655
- synth_ai/environments/examples/crafter_classic/agent_demos/old/trace_eval_OLD.py +0 -202
- synth_ai/environments/examples/crafter_classic/agent_demos/old/validate_openai_format.py +0 -166
- synth_ai/environments/examples/crafter_classic/config_logging.py +0 -111
- synth_ai/environments/examples/crafter_classic/debug_translation.py +0 -0
- synth_ai/environments/examples/crafter_classic/engine.py +0 -579
- synth_ai/environments/examples/crafter_classic/engine_deterministic_patch.py +0 -64
- synth_ai/environments/examples/crafter_classic/engine_helpers/action_map.py +0 -6
- synth_ai/environments/examples/crafter_classic/engine_helpers/serialization.py +0 -75
- synth_ai/environments/examples/crafter_classic/engine_serialization_patch_v3.py +0 -267
- synth_ai/environments/examples/crafter_classic/environment.py +0 -404
- synth_ai/environments/examples/crafter_classic/taskset.py +0 -233
- synth_ai/environments/examples/crafter_classic/trace_hooks_v3.py +0 -228
- synth_ai/environments/examples/crafter_classic/world_config_patch_simple.py +0 -299
- synth_ai/environments/examples/crafter_custom/__init__.py +0 -4
- synth_ai/environments/examples/crafter_custom/agent_demos/__init__.py +0 -1
- synth_ai/environments/examples/crafter_custom/agent_demos/trace_eval.py +0 -202
- synth_ai/environments/examples/crafter_custom/crafter/__init__.py +0 -7
- synth_ai/environments/examples/crafter_custom/crafter/config.py +0 -182
- synth_ai/environments/examples/crafter_custom/crafter/constants.py +0 -8
- synth_ai/environments/examples/crafter_custom/crafter/engine.py +0 -269
- synth_ai/environments/examples/crafter_custom/crafter/env.py +0 -262
- synth_ai/environments/examples/crafter_custom/crafter/objects.py +0 -417
- synth_ai/environments/examples/crafter_custom/crafter/recorder.py +0 -187
- synth_ai/environments/examples/crafter_custom/crafter/worldgen.py +0 -118
- synth_ai/environments/examples/crafter_custom/dataset_builder.py +0 -373
- synth_ai/environments/examples/crafter_custom/environment.py +0 -312
- synth_ai/environments/examples/crafter_custom/old/analyze_diamond_issue.py +0 -159
- synth_ai/environments/examples/crafter_custom/old/analyze_diamond_spawning.py +0 -158
- synth_ai/environments/examples/crafter_custom/old/compare_worlds.py +0 -71
- synth_ai/environments/examples/crafter_custom/old/dataset_stats.py +0 -105
- synth_ai/environments/examples/crafter_custom/old/diamond_spawning_summary.py +0 -119
- synth_ai/environments/examples/crafter_custom/old/example_dataset_usage.py +0 -52
- synth_ai/environments/examples/crafter_custom/run_dataset.py +0 -305
- synth_ai/environments/examples/enron/art_helpers/email_search_tools.py +0 -156
- synth_ai/environments/examples/enron/art_helpers/local_email_db.py +0 -281
- synth_ai/environments/examples/enron/art_helpers/types_enron.py +0 -25
- synth_ai/environments/examples/enron/engine.py +0 -295
- synth_ai/environments/examples/enron/environment.py +0 -166
- synth_ai/environments/examples/enron/taskset.py +0 -112
- synth_ai/environments/examples/enron/units/keyword_stats.py +0 -112
- synth_ai/environments/examples/minigrid/__init__.py +0 -48
- synth_ai/environments/examples/minigrid/agent_demos/minigrid_evaluation_framework.py +0 -1188
- synth_ai/environments/examples/minigrid/agent_demos/minigrid_quick_evaluation.py +0 -48
- synth_ai/environments/examples/minigrid/agent_demos/minigrid_react_agent.py +0 -562
- synth_ai/environments/examples/minigrid/agent_demos/minigrid_trace_evaluation.py +0 -221
- synth_ai/environments/examples/minigrid/engine.py +0 -589
- synth_ai/environments/examples/minigrid/environment.py +0 -274
- synth_ai/environments/examples/minigrid/environment_mapping.py +0 -242
- synth_ai/environments/examples/minigrid/puzzle_loader.py +0 -417
- synth_ai/environments/examples/minigrid/taskset.py +0 -583
- synth_ai/environments/examples/nethack/__init__.py +0 -7
- synth_ai/environments/examples/nethack/achievements.py +0 -337
- synth_ai/environments/examples/nethack/agent_demos/nethack_evaluation_framework.py +0 -981
- synth_ai/environments/examples/nethack/agent_demos/nethack_quick_evaluation.py +0 -74
- synth_ai/environments/examples/nethack/agent_demos/nethack_react_agent.py +0 -831
- synth_ai/environments/examples/nethack/engine.py +0 -739
- synth_ai/environments/examples/nethack/environment.py +0 -256
- synth_ai/environments/examples/nethack/helpers/__init__.py +0 -41
- synth_ai/environments/examples/nethack/helpers/action_mapping.py +0 -301
- synth_ai/environments/examples/nethack/helpers/nle_wrapper.py +0 -402
- synth_ai/environments/examples/nethack/helpers/observation_utils.py +0 -433
- synth_ai/environments/examples/nethack/helpers/recording_wrapper.py +0 -200
- synth_ai/environments/examples/nethack/helpers/trajectory_recorder.py +0 -269
- synth_ai/environments/examples/nethack/helpers/visualization/replay_viewer.py +0 -308
- synth_ai/environments/examples/nethack/helpers/visualization/visualizer.py +0 -431
- synth_ai/environments/examples/nethack/taskset.py +0 -323
- synth_ai/environments/examples/red/__init__.py +0 -7
- synth_ai/environments/examples/red/agent_demos/__init__.py +0 -1
- synth_ai/environments/examples/red/config_logging.py +0 -110
- synth_ai/environments/examples/red/engine.py +0 -694
- synth_ai/environments/examples/red/engine_helpers/__init__.py +0 -1
- synth_ai/environments/examples/red/engine_helpers/memory_map.py +0 -28
- synth_ai/environments/examples/red/engine_helpers/reward_components.py +0 -276
- synth_ai/environments/examples/red/engine_helpers/reward_library/__init__.py +0 -142
- synth_ai/environments/examples/red/engine_helpers/reward_library/adaptive_rewards.py +0 -57
- synth_ai/environments/examples/red/engine_helpers/reward_library/battle_rewards.py +0 -284
- synth_ai/environments/examples/red/engine_helpers/reward_library/composite_rewards.py +0 -150
- synth_ai/environments/examples/red/engine_helpers/reward_library/economy_rewards.py +0 -138
- synth_ai/environments/examples/red/engine_helpers/reward_library/efficiency_rewards.py +0 -57
- synth_ai/environments/examples/red/engine_helpers/reward_library/exploration_rewards.py +0 -331
- synth_ai/environments/examples/red/engine_helpers/reward_library/novelty_rewards.py +0 -121
- synth_ai/environments/examples/red/engine_helpers/reward_library/pallet_town_rewards.py +0 -559
- synth_ai/environments/examples/red/engine_helpers/reward_library/pokemon_rewards.py +0 -313
- synth_ai/environments/examples/red/engine_helpers/reward_library/social_rewards.py +0 -148
- synth_ai/environments/examples/red/engine_helpers/reward_library/story_rewards.py +0 -247
- synth_ai/environments/examples/red/engine_helpers/screen_analysis.py +0 -368
- synth_ai/environments/examples/red/engine_helpers/state_extraction.py +0 -140
- synth_ai/environments/examples/red/environment.py +0 -238
- synth_ai/environments/examples/red/taskset.py +0 -79
- synth_ai/environments/examples/red/units/__init__.py +0 -1
- synth_ai/environments/examples/sokoban/__init__.py +0 -1
- synth_ai/environments/examples/sokoban/agent_demos/sokoban_full_eval.py +0 -899
- synth_ai/environments/examples/sokoban/engine.py +0 -678
- synth_ai/environments/examples/sokoban/engine_helpers/__init__.py +0 -1
- synth_ai/environments/examples/sokoban/engine_helpers/room_utils.py +0 -657
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/__init__.py +0 -18
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/__init__.py +0 -3
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/boxoban_env.py +0 -131
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/render_utils.py +0 -370
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/room_utils.py +0 -332
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env.py +0 -306
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_fixed_targets.py +0 -67
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_pull.py +0 -115
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_two_player.py +0 -123
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_variations.py +0 -394
- synth_ai/environments/examples/sokoban/environment.py +0 -229
- synth_ai/environments/examples/sokoban/generate_verified_puzzles.py +0 -440
- synth_ai/environments/examples/sokoban/puzzle_loader.py +0 -312
- synth_ai/environments/examples/sokoban/taskset.py +0 -428
- synth_ai/environments/examples/sokoban/units/astar_common.py +0 -95
- synth_ai/environments/examples/tictactoe/__init__.py +0 -1
- synth_ai/environments/examples/tictactoe/engine.py +0 -368
- synth_ai/environments/examples/tictactoe/environment.py +0 -240
- synth_ai/environments/examples/tictactoe/taskset.py +0 -215
- synth_ai/environments/examples/verilog/__init__.py +0 -10
- synth_ai/environments/examples/verilog/engine.py +0 -329
- synth_ai/environments/examples/verilog/environment.py +0 -350
- synth_ai/environments/examples/verilog/taskset.py +0 -420
- synth_ai/environments/examples/wordle/__init__.py +0 -29
- synth_ai/environments/examples/wordle/engine.py +0 -398
- synth_ai/environments/examples/wordle/environment.py +0 -159
- synth_ai/environments/examples/wordle/helpers/generate_instances_wordfreq.py +0 -75
- synth_ai/environments/examples/wordle/taskset.py +0 -230
- synth_ai/environments/reproducibility/core.py +0 -42
- synth_ai/environments/reproducibility/helpers.py +0 -0
- synth_ai/environments/reproducibility/tree.py +0 -364
- synth_ai/environments/service/app.py +0 -98
- synth_ai/environments/service/core_routes.py +0 -1020
- synth_ai/environments/service/external_registry.py +0 -56
- synth_ai/environments/service/registry.py +0 -9
- synth_ai/environments/stateful/__init__.py +0 -1
- synth_ai/environments/stateful/core.py +0 -163
- synth_ai/environments/stateful/engine.py +0 -21
- synth_ai/environments/stateful/state.py +0 -7
- synth_ai/environments/tasks/api.py +0 -19
- synth_ai/environments/tasks/core.py +0 -80
- synth_ai/environments/tasks/filters.py +0 -41
- synth_ai/environments/tasks/utils.py +0 -91
- synth_ai/environments/v0_observability/history.py +0 -3
- synth_ai/environments/v0_observability/log.py +0 -2
- synth_ai/evals/base.py +0 -15
- synth_ai/experimental/synth_oss.py +0 -446
- synth_ai/handshake.py +0 -63
- synth_ai/http.py +0 -26
- synth_ai/http_client.py +0 -104
- synth_ai/inference/client.py +0 -20
- synth_ai/install_sqld.sh +0 -40
- synth_ai/jobs/client.py +0 -246
- synth_ai/learning/__init__.py +0 -24
- synth_ai/learning/config.py +0 -43
- synth_ai/learning/filtering.py +0 -0
- synth_ai/learning/ft_client.py +0 -59
- synth_ai/learning/offline/dpo.py +0 -0
- synth_ai/learning/offline/providers.py +0 -7
- synth_ai/learning/offline/sft.py +0 -0
- synth_ai/learning/offline/shared.py +0 -0
- synth_ai/learning/online/grpo.py +0 -0
- synth_ai/learning/online/irft.py +0 -0
- synth_ai/learning/prompts/banking77_injection_eval.py +0 -168
- synth_ai/learning/prompts/gepa.py +0 -0
- synth_ai/learning/prompts/hello_world_in_context_injection_ex.py +0 -213
- synth_ai/learning/prompts/mipro.py +0 -289
- synth_ai/learning/prompts/random_search.py +0 -246
- synth_ai/learning/prompts/run_mipro_banking77.py +0 -172
- synth_ai/learning/prompts/run_random_search_banking77.py +0 -324
- synth_ai/learning/sse.py +0 -58
- synth_ai/learning/validators.py +0 -48
- synth_ai/lm/__init__.py +0 -51
- synth_ai/lm/caching/constants.py +0 -6
- synth_ai/lm/caching/dbs.py +0 -0
- synth_ai/lm/caching/ephemeral.py +0 -102
- synth_ai/lm/caching/handler.py +0 -137
- synth_ai/lm/caching/initialize.py +0 -11
- synth_ai/lm/caching/persistent.py +0 -114
- synth_ai/lm/config.py +0 -110
- synth_ai/lm/constants.py +0 -32
- synth_ai/lm/core/__init__.py +0 -8
- synth_ai/lm/core/all.py +0 -73
- synth_ai/lm/core/exceptions.py +0 -7
- synth_ai/lm/core/main.py +0 -319
- synth_ai/lm/core/main_v3.py +0 -594
- synth_ai/lm/core/synth_models.py +0 -48
- synth_ai/lm/core/vendor_clients.py +0 -188
- synth_ai/lm/cost/__init__.py +0 -0
- synth_ai/lm/cost/monitor.py +0 -1
- synth_ai/lm/cost/statefulness.py +0 -1
- synth_ai/lm/injection.py +0 -80
- synth_ai/lm/overrides.py +0 -206
- synth_ai/lm/provider_support/__init__.py +0 -8
- synth_ai/lm/provider_support/anthropic.py +0 -972
- synth_ai/lm/provider_support/openai.py +0 -1139
- synth_ai/lm/provider_support/suppress_logging.py +0 -31
- synth_ai/lm/structured_outputs/__init__.py +0 -0
- synth_ai/lm/structured_outputs/handler.py +0 -440
- synth_ai/lm/structured_outputs/inject.py +0 -297
- synth_ai/lm/structured_outputs/rehabilitate.py +0 -185
- synth_ai/lm/tools/__init__.py +0 -3
- synth_ai/lm/tools/base.py +0 -172
- synth_ai/lm/unified_interface.py +0 -202
- synth_ai/lm/vendors/__init__.py +0 -0
- synth_ai/lm/vendors/base.py +0 -81
- synth_ai/lm/vendors/core/__init__.py +0 -0
- synth_ai/lm/vendors/core/anthropic_api.py +0 -387
- synth_ai/lm/vendors/core/gemini_api.py +0 -292
- synth_ai/lm/vendors/core/mistral_api.py +0 -322
- synth_ai/lm/vendors/core/openai_api.py +0 -225
- synth_ai/lm/vendors/core/synth_dev_api.py +0 -0
- synth_ai/lm/vendors/local/__init__.py +0 -0
- synth_ai/lm/vendors/local/ollama.py +0 -0
- synth_ai/lm/vendors/openai_standard.py +0 -780
- synth_ai/lm/vendors/openai_standard_responses.py +0 -256
- synth_ai/lm/vendors/retries.py +0 -22
- synth_ai/lm/vendors/supported/__init__.py +0 -0
- synth_ai/lm/vendors/supported/custom_endpoint.py +0 -417
- synth_ai/lm/vendors/supported/deepseek.py +0 -69
- synth_ai/lm/vendors/supported/grok.py +0 -75
- synth_ai/lm/vendors/supported/groq.py +0 -16
- synth_ai/lm/vendors/supported/ollama.py +0 -15
- synth_ai/lm/vendors/supported/openrouter.py +0 -74
- synth_ai/lm/vendors/supported/together.py +0 -11
- synth_ai/lm/vendors/synth_client.py +0 -808
- synth_ai/lm/warmup.py +0 -186
- synth_ai/rl/secrets.py +0 -19
- synth_ai/scripts/verify_rewards.py +0 -100
- synth_ai/task/__init__.py +0 -10
- synth_ai/task/contracts.py +0 -120
- synth_ai/task/health.py +0 -28
- synth_ai/task/validators.py +0 -12
- synth_ai/tracing/__init__.py +0 -30
- synth_ai/tracing_v1/__init__.py +0 -33
- synth_ai/tracing_v3/config.py +0 -84
- synth_ai/tracing_v3/storage/config.py +0 -62
- synth_ai/tracing_v3/turso/__init__.py +0 -25
- synth_ai/tracing_v3/turso/daemon.py +0 -144
- synth_ai/tracing_v3/turso/manager.py +0 -760
- synth_ai/v0/tracing/__init__.py +0 -0
- synth_ai/v0/tracing/abstractions.py +0 -224
- synth_ai/v0/tracing/base_client.py +0 -91
- synth_ai/v0/tracing/client_manager.py +0 -131
- synth_ai/v0/tracing/config.py +0 -142
- synth_ai/v0/tracing/context.py +0 -146
- synth_ai/v0/tracing/decorators.py +0 -682
- synth_ai/v0/tracing/events/__init__.py +0 -0
- synth_ai/v0/tracing/events/manage.py +0 -147
- synth_ai/v0/tracing/events/scope.py +0 -86
- synth_ai/v0/tracing/events/store.py +0 -228
- synth_ai/v0/tracing/immediate_client.py +0 -151
- synth_ai/v0/tracing/local.py +0 -18
- synth_ai/v0/tracing/log_client_base.py +0 -73
- synth_ai/v0/tracing/retry_queue.py +0 -186
- synth_ai/v0/tracing/trackers.py +0 -515
- synth_ai/v0/tracing/upload.py +0 -512
- synth_ai/v0/tracing/utils.py +0 -9
- synth_ai/v0/tracing_v1/__init__.py +0 -16
- synth_ai/v0/tracing_v1/abstractions.py +0 -224
- synth_ai/v0/tracing_v1/base_client.py +0 -91
- synth_ai/v0/tracing_v1/client_manager.py +0 -131
- synth_ai/v0/tracing_v1/config.py +0 -142
- synth_ai/v0/tracing_v1/context.py +0 -146
- synth_ai/v0/tracing_v1/decorators.py +0 -703
- synth_ai/v0/tracing_v1/events/__init__.py +0 -0
- synth_ai/v0/tracing_v1/events/manage.py +0 -147
- synth_ai/v0/tracing_v1/events/scope.py +0 -86
- synth_ai/v0/tracing_v1/events/store.py +0 -228
- synth_ai/v0/tracing_v1/immediate_client.py +0 -151
- synth_ai/v0/tracing_v1/local.py +0 -18
- synth_ai/v0/tracing_v1/log_client_base.py +0 -73
- synth_ai/v0/tracing_v1/retry_queue.py +0 -186
- synth_ai/v0/tracing_v1/trackers.py +0 -515
- synth_ai/v0/tracing_v1/upload.py +0 -527
- synth_ai/v0/tracing_v1/utils.py +0 -9
- synth_ai/zyk/__init__.py +0 -30
- synth_ai-0.2.8.dev2.dist-info/METADATA +0 -129
- synth_ai-0.2.8.dev2.dist-info/RECORD +0 -420
- /synth_ai/{demos → cli/demo_apps}/demo_task_apps/math/__init__.py +0 -0
- /synth_ai/{lm/caching → core/apps}/__init__.py +0 -0
- /synth_ai/{tracing_v3 → core/tracing_v3}/lm_call_record_abstractions.py +0 -0
- /synth_ai/{tracing_v3 → core/tracing_v3}/storage/__init__.py +0 -0
- /synth_ai/{tracing_v3 → core/tracing_v3}/storage/exceptions.py +0 -0
- /synth_ai/{tracing_v3 → core/tracing_v3}/storage/types.py +0 -0
- /synth_ai/{compound/cais.py → py.typed} +0 -0
- /synth_ai/{learning → sdk/learning}/core.py +0 -0
- /synth_ai/{learning → sdk/learning}/gateway.py +0 -0
- {synth_ai-0.2.8.dev2.dist-info → synth_ai-0.4.3.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.8.dev2.dist-info → synth_ai-0.4.3.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.2.8.dev2.dist-info → synth_ai-0.4.3.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,2424 @@
|
|
|
1
|
+
"""SDK-side validation for training configs - catch errors BEFORE sending to backend."""
|
|
2
|
+
|
|
3
|
+
import re
|
|
4
|
+
import warnings
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from typing import Any, List, Tuple
|
|
7
|
+
|
|
8
|
+
import click
|
|
9
|
+
import toml
|
|
10
|
+
|
|
11
|
+
# Import unknown field validation from CLI module
|
|
12
|
+
from synth_ai.cli.commands.train.prompt_learning_validation import (
|
|
13
|
+
validate_prompt_learning_config as _validate_unknown_fields,
|
|
14
|
+
)
|
|
15
|
+
from synth_ai.core.telemetry import log_info
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class ConfigValidationError(Exception):
    """Raised when a training config fails SDK-side validation."""
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
# Supported models for prompt learning (GEPA & MIPRO)
# NOTE: gpt-5-pro is explicitly EXCLUDED - too expensive for prompt learning
# Membership is tested case-insensitively after any "provider/" prefix has
# been stripped (see _is_supported_openai_model).
OPENAI_SUPPORTED_MODELS = {
    "gpt-4o",
    "gpt-4o-mini",
    "gpt-4.1",
    "gpt-4.1-mini",
    "gpt-4.1-nano",
    "gpt-5",
    "gpt-5-mini",
    "gpt-5-nano",
    # Explicitly EXCLUDED: "gpt-5-pro" - too expensive
}
|
|
36
|
+
|
|
37
|
+
# Groq supported models - patterns and exact matches
# Models can be in format "model-name" or "provider/model-name" (e.g., "openai/gpt-oss-20b")
# NOTE: _is_supported_groq_model applies these regexes to the full lowercased
# model string (prefix still attached), which is why each pattern allows an
# optional provider prefix itself.
GROQ_SUPPORTED_PATTERNS = [
    re.compile(r"^(openai/)?gpt-oss-\d+b"),  # e.g., gpt-oss-20b, openai/gpt-oss-120b
    re.compile(r"^(llama-3\.3-70b|groq/llama-3\.3-70b)"),  # e.g., llama-3.3-70b-versatile
    re.compile(r"^(qwen.*32b|groq/qwen.*32b)"),  # e.g., qwen-32b, qwen3-32b, groq/qwen3-32b
]

# Exact Groq model names, matched case-insensitively AFTER any "provider/"
# prefix has been stripped (see _is_supported_groq_model).
GROQ_EXACT_MATCHES = {
    "llama-3.3-70b",
    "llama-3.1-8b-instant",
    "qwen-32b",
    "qwen3-32b",
}
|
|
51
|
+
|
|
52
|
+
# Google/Gemini supported models
# Membership is tested case-insensitively after any "provider/" prefix has
# been stripped (see _is_supported_google_model).
GOOGLE_SUPPORTED_MODELS = {
    "gemini-2.5-pro",
    "gemini-2.5-pro-gt200k",
    "gemini-2.5-flash",
    "gemini-2.5-flash-lite",
}
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def _is_supported_openai_model(model: str) -> bool:
    """Return True when *model* names a supported OpenAI model.

    An optional "provider/" qualifier (e.g. "openai/gpt-4o" -> "gpt-4o") is
    stripped first; matching is case-insensitive.
    """
    candidate = model.lower().strip()
    if "/" in candidate:
        # Keep only the model id after the provider qualifier.
        _, candidate = candidate.split("/", 1)
    supported = {name.lower() for name in OPENAI_SUPPORTED_MODELS}
    return candidate in supported
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def _is_supported_groq_model(model: str) -> bool:
    """Return True when *model* names a supported Groq model.

    Exact names are compared case-insensitively with any "provider/" prefix
    stripped; the regex patterns are then tried against the full normalized
    string, since they already allow an optional prefix.
    """
    normalized = model.lower().strip()

    bare = normalized.split("/", 1)[1] if "/" in normalized else normalized

    # Exact names win; fall back to the prefix-aware patterns.
    exact = {name.lower() for name in GROQ_EXACT_MATCHES}
    if bare in exact:
        return True

    return any(rx.match(normalized) for rx in GROQ_SUPPORTED_PATTERNS)
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
def _is_supported_google_model(model: str) -> bool:
    """Return True when *model* names a supported Google/Gemini model.

    An optional "provider/" qualifier (e.g. "google/gemini-2.5-flash-lite")
    is stripped first; matching is case-insensitive.
    """
    candidate = model.lower().strip()
    if "/" in candidate:
        # Keep only the model id after the provider qualifier.
        _, candidate = candidate.split("/", 1)
    supported = {name.lower() for name in GOOGLE_SUPPORTED_MODELS}
    return candidate in supported
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
def _validate_adaptive_pool_config(
|
|
96
|
+
adaptive_pool_section: dict[str, Any],
|
|
97
|
+
prefix: str, # e.g., "gepa.adaptive_pool" or "mipro.adaptive_pool"
|
|
98
|
+
errors: list[str],
|
|
99
|
+
) -> None:
|
|
100
|
+
"""Validate adaptive_pool configuration section.
|
|
101
|
+
|
|
102
|
+
Validates all fields in adaptive_pool config including:
|
|
103
|
+
- Level presets (NONE, LOW, MODERATE, HIGH)
|
|
104
|
+
- Numeric fields with min/max constraints
|
|
105
|
+
- Relationship constraints (pool_init_size >= pool_min_size >= anchor_size)
|
|
106
|
+
- String enum fields (anchor_selection_method, exploration_strategy, etc.)
|
|
107
|
+
- Heat-up phase configuration
|
|
108
|
+
|
|
109
|
+
Args:
|
|
110
|
+
adaptive_pool_section: Dict containing adaptive_pool config with fields:
|
|
111
|
+
- level: Preset level (NONE, LOW, MODERATE, HIGH)
|
|
112
|
+
- anchor_size: Number of anchor examples (always evaluated)
|
|
113
|
+
- pool_init_size: Initial pool size
|
|
114
|
+
- pool_min_size: Target minimum pool size after annealing
|
|
115
|
+
- warmup_iters: Iterations before starting annealing
|
|
116
|
+
- anneal_stop_iter: Iteration when pool reaches min_size
|
|
117
|
+
- pool_update_period: Update informativeness every N generations
|
|
118
|
+
- min_evals_per_example: Min evals before computing informativeness
|
|
119
|
+
- k_info_prompts: Number of prompts for informativeness
|
|
120
|
+
- info_buffer_factor: Buffer factor (0.0-1.0) for preserving info
|
|
121
|
+
- info_epsilon: Epsilon for informativeness calculations
|
|
122
|
+
- anchor_selection_method: "random" or "clustering"
|
|
123
|
+
- exploration_strategy: "random" or "diversity"
|
|
124
|
+
- heatup_trigger: "after_min_size", "immediate", or "every_N_trials_after_min"
|
|
125
|
+
- heatup_schedule: "repeat" or "once"
|
|
126
|
+
- heatup_size: Number of seeds to add during heat-up
|
|
127
|
+
- heatup_cooldown_trials: Trials to wait before cooling down
|
|
128
|
+
- heatup_reserve_pool: Optional list of seed IDs for heat-up
|
|
129
|
+
prefix: Prefix for error messages (e.g., "gepa.adaptive_pool" or "mipro.adaptive_pool")
|
|
130
|
+
errors: List to append validation errors to
|
|
131
|
+
"""
|
|
132
|
+
if not isinstance(adaptive_pool_section, dict):
|
|
133
|
+
errors.append(f"❌ {prefix} must be a table/dict when provided")
|
|
134
|
+
return
|
|
135
|
+
|
|
136
|
+
# Validate level
|
|
137
|
+
level = adaptive_pool_section.get("level")
|
|
138
|
+
if level is not None:
|
|
139
|
+
valid_levels = {"NONE", "LOW", "MODERATE", "HIGH"}
|
|
140
|
+
if str(level).upper() not in valid_levels:
|
|
141
|
+
errors.append(
|
|
142
|
+
f"❌ {prefix}.level must be one of {valid_levels}, got '{level}'"
|
|
143
|
+
)
|
|
144
|
+
|
|
145
|
+
# Validate numeric fields
|
|
146
|
+
for field, min_val in [
|
|
147
|
+
("anchor_size", 0),
|
|
148
|
+
("pool_init_size", 0),
|
|
149
|
+
("pool_min_size", 0),
|
|
150
|
+
("warmup_iters", 0),
|
|
151
|
+
("anneal_stop_iter", 0),
|
|
152
|
+
("pool_update_period", 1),
|
|
153
|
+
("min_evals_per_example", 1),
|
|
154
|
+
("k_info_prompts", 0),
|
|
155
|
+
]:
|
|
156
|
+
val = adaptive_pool_section.get(field)
|
|
157
|
+
if val is not None:
|
|
158
|
+
try:
|
|
159
|
+
ival = int(val)
|
|
160
|
+
if ival < min_val:
|
|
161
|
+
errors.append(f"❌ {prefix}.{field} must be >= {min_val}, got {ival}")
|
|
162
|
+
except (TypeError, ValueError):
|
|
163
|
+
errors.append(f"❌ {prefix}.{field} must be an integer, got {type(val).__name__}")
|
|
164
|
+
|
|
165
|
+
# Validate pool_init_size >= pool_min_size if both provided
|
|
166
|
+
pool_init = adaptive_pool_section.get("pool_init_size")
|
|
167
|
+
pool_min = adaptive_pool_section.get("pool_min_size")
|
|
168
|
+
if pool_init is not None and pool_min is not None:
|
|
169
|
+
try:
|
|
170
|
+
pool_init_int = int(pool_init)
|
|
171
|
+
pool_min_int = int(pool_min)
|
|
172
|
+
if pool_init_int < pool_min_int:
|
|
173
|
+
errors.append(
|
|
174
|
+
f"❌ {prefix}.pool_init_size ({pool_init}) must be >= pool_min_size ({pool_min})"
|
|
175
|
+
)
|
|
176
|
+
except (TypeError, ValueError):
|
|
177
|
+
pass # Already validated above
|
|
178
|
+
|
|
179
|
+
# Validate pool_min_size >= anchor_size if both provided
|
|
180
|
+
anchor_size = adaptive_pool_section.get("anchor_size")
|
|
181
|
+
if pool_min is not None and anchor_size is not None:
|
|
182
|
+
try:
|
|
183
|
+
pool_min_int = int(pool_min)
|
|
184
|
+
anchor_size_int = int(anchor_size)
|
|
185
|
+
if pool_min_int < anchor_size_int:
|
|
186
|
+
errors.append(
|
|
187
|
+
f"❌ {prefix}.pool_min_size ({pool_min}) must be >= anchor_size ({anchor_size})"
|
|
188
|
+
)
|
|
189
|
+
except (TypeError, ValueError):
|
|
190
|
+
pass # Already validated above
|
|
191
|
+
|
|
192
|
+
# Validate info_buffer_factor and info_epsilon
|
|
193
|
+
for field, min_val, max_val in [("info_buffer_factor", 0.0, 1.0), ("info_epsilon", 0.0, None)]:
|
|
194
|
+
val = adaptive_pool_section.get(field)
|
|
195
|
+
if val is not None:
|
|
196
|
+
try:
|
|
197
|
+
fval = float(val)
|
|
198
|
+
if fval < min_val:
|
|
199
|
+
errors.append(f"❌ {prefix}.{field} must be >= {min_val}, got {fval}")
|
|
200
|
+
if max_val is not None and fval > max_val:
|
|
201
|
+
errors.append(f"❌ {prefix}.{field} must be <= {max_val}, got {fval}")
|
|
202
|
+
except (TypeError, ValueError):
|
|
203
|
+
errors.append(f"❌ {prefix}.{field} must be numeric, got {type(val).__name__}")
|
|
204
|
+
|
|
205
|
+
# Validate string fields
|
|
206
|
+
anchor_method = adaptive_pool_section.get("anchor_selection_method")
|
|
207
|
+
if anchor_method is not None and anchor_method not in ("random", "clustering"):
|
|
208
|
+
errors.append(
|
|
209
|
+
f"❌ {prefix}.anchor_selection_method must be 'random' or 'clustering', got '{anchor_method}'"
|
|
210
|
+
)
|
|
211
|
+
|
|
212
|
+
exploration_strategy = adaptive_pool_section.get("exploration_strategy")
|
|
213
|
+
if exploration_strategy is not None and exploration_strategy not in ("random", "diversity"):
|
|
214
|
+
errors.append(
|
|
215
|
+
f"❌ {prefix}.exploration_strategy must be 'random' or 'diversity', got '{exploration_strategy}'"
|
|
216
|
+
)
|
|
217
|
+
|
|
218
|
+
# Validate heatup fields
|
|
219
|
+
heatup_trigger = adaptive_pool_section.get("heatup_trigger")
|
|
220
|
+
if heatup_trigger is not None and heatup_trigger not in ("after_min_size", "immediate", "every_N_trials_after_min"):
|
|
221
|
+
errors.append(
|
|
222
|
+
f"❌ {prefix}.heatup_trigger must be 'after_min_size', 'immediate', or 'every_N_trials_after_min', got '{heatup_trigger}'"
|
|
223
|
+
)
|
|
224
|
+
|
|
225
|
+
heatup_schedule = adaptive_pool_section.get("heatup_schedule")
|
|
226
|
+
if heatup_schedule is not None and heatup_schedule not in ("repeat", "once"):
|
|
227
|
+
errors.append(
|
|
228
|
+
f"❌ {prefix}.heatup_schedule must be 'repeat' or 'once', got '{heatup_schedule}'"
|
|
229
|
+
)
|
|
230
|
+
|
|
231
|
+
heatup_size = adaptive_pool_section.get("heatup_size")
|
|
232
|
+
if heatup_size is not None:
|
|
233
|
+
try:
|
|
234
|
+
if int(heatup_size) <= 0:
|
|
235
|
+
errors.append(f"❌ {prefix}.heatup_size must be > 0, got {heatup_size}")
|
|
236
|
+
except (TypeError, ValueError):
|
|
237
|
+
errors.append(f"❌ {prefix}.heatup_size must be an integer, got {type(heatup_size).__name__}")
|
|
238
|
+
|
|
239
|
+
heatup_cooldown_trials = adaptive_pool_section.get("heatup_cooldown_trials")
|
|
240
|
+
if heatup_cooldown_trials is not None:
|
|
241
|
+
try:
|
|
242
|
+
if int(heatup_cooldown_trials) < 0:
|
|
243
|
+
errors.append(f"❌ {prefix}.heatup_cooldown_trials must be >= 0, got {heatup_cooldown_trials}")
|
|
244
|
+
except (TypeError, ValueError):
|
|
245
|
+
errors.append(f"❌ {prefix}.heatup_cooldown_trials must be an integer, got {type(heatup_cooldown_trials).__name__}")
|
|
246
|
+
|
|
247
|
+
heatup_reserve_pool = adaptive_pool_section.get("heatup_reserve_pool")
|
|
248
|
+
if heatup_reserve_pool is not None:
|
|
249
|
+
if not isinstance(heatup_reserve_pool, list):
|
|
250
|
+
errors.append(f"❌ {prefix}.heatup_reserve_pool must be a list, got {type(heatup_reserve_pool).__name__}")
|
|
251
|
+
elif not all(isinstance(s, int) for s in heatup_reserve_pool):
|
|
252
|
+
errors.append(f"❌ {prefix}.heatup_reserve_pool must contain only integers")
|
|
253
|
+
|
|
254
|
+
|
|
255
|
+
def _validate_model_for_provider(model: str, provider: str, field_name: str, *, allow_nano: bool = False) -> list[str]:
    """
    Validate that a model is supported for the given provider.

    Models can be specified with or without provider prefix (e.g., "gpt-4o" or "openai/gpt-4o").
    The provider prefix is stripped before validation.

    REJECTS gpt-5-pro explicitly (too expensive).
    REJECTS nano models for proposal/mutation models (unless allow_nano=True).

    Args:
        model: Model name to validate
        provider: Provider name (openai, groq, google)
        field_name: Field name for error messages (e.g., "prompt_learning.policy.model")
        allow_nano: If True, allow nano models (for policy models). If False, reject nano models.

    Returns:
        List of error messages (empty if valid)
    """
    errors: list[str] = []

    # Guard: a missing, non-string, or blank model short-circuits validation.
    if not model or not isinstance(model, str) or not model.strip():
        errors.append(f"Missing or empty {field_name}")
        return errors

    provider_lower = provider.lower().strip()
    model_lower = model.lower().strip()

    # Strip provider prefix if present (e.g., "openai/gpt-4o" -> "gpt-4o")
    model_without_prefix = model_lower.split("/", 1)[1] if "/" in model_lower else model_lower

    # Explicitly reject gpt-5-pro (too expensive)
    if model_without_prefix == "gpt-5-pro":
        errors.append(
            f"Model '{model}' is not supported for prompt learning (too expensive).\n"
            f" gpt-5-pro is excluded due to high cost ($15/$120 per 1M tokens).\n"
            f" Please use a supported model instead."
        )
        return errors

    # Reject nano models for proposal/mutation models (unless explicitly allowed)
    # NOTE(review): keyed off the "-nano" suffix, so any future *-nano model
    # is rejected here as well.
    if not allow_nano and model_without_prefix.endswith("-nano"):
        errors.append(
            f"Model '{model}' is not supported for {field_name}.\n"
            f" ❌ Nano models (e.g., gpt-4.1-nano, gpt-5-nano) are NOT allowed for proposal/mutation models.\n"
            f" \n"
            f" Why?\n"
            f" Proposal and mutation models need to be SMART and capable of generating high-quality,\n"
            f" creative prompt variations. Nano models are too small and lack the reasoning capability\n"
            f" needed for effective prompt optimization.\n"
            f" \n"
            f" ✅ Use a larger model instead:\n"
            f" - For OpenAI: gpt-4.1-mini, gpt-4o-mini, gpt-4o, or gpt-4.1\n"
            f" - For Groq: openai/gpt-oss-120b, llama-3.3-70b-versatile\n"
            f" - For Google: gemini-2.5-flash, gemini-2.5-pro\n"
            f" \n"
            f" Note: Nano models ARE allowed for policy models (task execution), but NOT for\n"
            f" proposal/mutation models (prompt generation)."
        )
        return errors

    # Dispatch on provider; each branch uses the provider-specific helper
    # and emits a message listing the supported models.
    if provider_lower == "openai":
        if not _is_supported_openai_model(model_without_prefix):
            errors.append(
                f"Unsupported OpenAI model: '{model}'\n"
                f" Supported OpenAI models for prompt learning:\n"
                f" - gpt-4o\n"
                f" - gpt-4o-mini\n"
                f" - gpt-4.1, gpt-4.1-mini, gpt-4.1-nano\n"
                f" - gpt-5, gpt-5-mini, gpt-5-nano\n"
                f" Note: gpt-5-pro is excluded (too expensive)\n"
                f" Got: '{model}'"
            )
    elif provider_lower == "groq":
        # For Groq, check both with and without prefix since models can be "openai/gpt-oss-20b"
        if not _is_supported_groq_model(model_lower):
            errors.append(
                f"Unsupported Groq model: '{model}'\n"
                f" Supported Groq models for prompt learning:\n"
                f" - gpt-oss-Xb (e.g., gpt-oss-20b, openai/gpt-oss-120b)\n"
                f" - llama-3.3-70b (and variants like llama-3.3-70b-versatile)\n"
                f" - llama-3.1-8b-instant\n"
                f" - qwen/qwen3-32b (and variants)\n"
                f" Got: '{model}'"
            )
    elif provider_lower == "google":
        if not _is_supported_google_model(model_without_prefix):
            errors.append(
                f"Unsupported Google/Gemini model: '{model}'\n"
                f" Supported Google models for prompt learning:\n"
                f" - gemini-2.5-pro, gemini-2.5-pro-gt200k\n"
                f" - gemini-2.5-flash\n"
                f" - gemini-2.5-flash-lite\n"
                f" Got: '{model}'"
            )
    else:
        errors.append(
            f"Unsupported provider: '{provider}'\n"
            f" Supported providers for prompt learning: 'openai', 'groq', 'google'\n"
            f" Got: '{provider}'"
        )

    return errors
|
|
358
|
+
|
|
359
|
+
|
|
360
|
+
def validate_prompt_learning_config(config_data: dict[str, Any], config_path: Path) -> None:
|
|
361
|
+
"""
|
|
362
|
+
Validate prompt learning config BEFORE sending to backend.
|
|
363
|
+
|
|
364
|
+
This catches common errors early with clear messages instead of cryptic backend errors.
|
|
365
|
+
|
|
366
|
+
Args:
|
|
367
|
+
config_data: Parsed TOML/JSON config
|
|
368
|
+
config_path: Path to config file (for error messages)
|
|
369
|
+
|
|
370
|
+
Raises:
|
|
371
|
+
ConfigValidationError: If config is invalid
|
|
372
|
+
click.ClickException: If validation fails (for CLI)
|
|
373
|
+
"""
|
|
374
|
+
ctx: dict[str, Any] = {"config_path": str(config_path)}
|
|
375
|
+
log_info("validate_prompt_learning_config invoked", ctx=ctx)
|
|
376
|
+
errors: list[str] = []
|
|
377
|
+
|
|
378
|
+
# Run unknown field validation (warnings only, doesn't raise)
|
|
379
|
+
try:
|
|
380
|
+
validation_result = _validate_unknown_fields(config_data, config_path=config_path)
|
|
381
|
+
# Print warnings about unknown fields and deprecated sections
|
|
382
|
+
for warning_msg in validation_result.warnings:
|
|
383
|
+
warnings.warn(warning_msg, UserWarning, stacklevel=3)
|
|
384
|
+
except Exception:
|
|
385
|
+
# Don't fail validation if unknown field check fails
|
|
386
|
+
pass
|
|
387
|
+
|
|
388
|
+
# Check for prompt_learning section
|
|
389
|
+
pl_section = config_data.get("prompt_learning")
|
|
390
|
+
if not pl_section:
|
|
391
|
+
errors.append(
|
|
392
|
+
"Missing [prompt_learning] section in config. "
|
|
393
|
+
"Expected: [prompt_learning] with algorithm, task_app_url, etc."
|
|
394
|
+
)
|
|
395
|
+
_raise_validation_errors(errors, config_path)
|
|
396
|
+
return
|
|
397
|
+
|
|
398
|
+
if not isinstance(pl_section, dict):
|
|
399
|
+
errors.append(
|
|
400
|
+
f"[prompt_learning] must be a table/dict, got {type(pl_section).__name__}"
|
|
401
|
+
)
|
|
402
|
+
_raise_validation_errors(errors, config_path)
|
|
403
|
+
return
|
|
404
|
+
|
|
405
|
+
# CRITICAL: Validate algorithm field
|
|
406
|
+
algorithm = pl_section.get("algorithm")
|
|
407
|
+
if not algorithm:
|
|
408
|
+
errors.append(
|
|
409
|
+
"Missing required field: prompt_learning.algorithm\n"
|
|
410
|
+
" Must be one of: 'gepa', 'mipro'\n"
|
|
411
|
+
" Example:\n"
|
|
412
|
+
" [prompt_learning]\n"
|
|
413
|
+
" algorithm = \"gepa\""
|
|
414
|
+
)
|
|
415
|
+
elif algorithm not in ("gepa", "mipro"):
|
|
416
|
+
errors.append(
|
|
417
|
+
f"Invalid algorithm: '{algorithm}'\n"
|
|
418
|
+
f" Must be one of: 'gepa', 'mipro'\n"
|
|
419
|
+
f" Got: '{algorithm}'"
|
|
420
|
+
)
|
|
421
|
+
|
|
422
|
+
# Validate task_app_url
|
|
423
|
+
task_app_url = pl_section.get("task_app_url")
|
|
424
|
+
if not task_app_url:
|
|
425
|
+
errors.append(
|
|
426
|
+
"Missing required field: prompt_learning.task_app_url\n"
|
|
427
|
+
" Example:\n"
|
|
428
|
+
" task_app_url = \"http://127.0.0.1:8102\""
|
|
429
|
+
)
|
|
430
|
+
elif not isinstance(task_app_url, str):
|
|
431
|
+
errors.append(
|
|
432
|
+
f"task_app_url must be a string, got {type(task_app_url).__name__}"
|
|
433
|
+
)
|
|
434
|
+
elif not task_app_url.startswith(("http://", "https://")):
|
|
435
|
+
errors.append(
|
|
436
|
+
f"task_app_url must start with http:// or https://, got: '{task_app_url}'"
|
|
437
|
+
)
|
|
438
|
+
|
|
439
|
+
# Validate initial_prompt if present
|
|
440
|
+
initial_prompt = pl_section.get("initial_prompt")
|
|
441
|
+
if initial_prompt:
|
|
442
|
+
if not isinstance(initial_prompt, dict):
|
|
443
|
+
errors.append(
|
|
444
|
+
f"prompt_learning.initial_prompt must be a table/dict, got {type(initial_prompt).__name__}"
|
|
445
|
+
)
|
|
446
|
+
else:
|
|
447
|
+
# Validate messages array
|
|
448
|
+
messages = initial_prompt.get("messages")
|
|
449
|
+
if messages is not None:
|
|
450
|
+
if not isinstance(messages, list):
|
|
451
|
+
errors.append(
|
|
452
|
+
f"prompt_learning.initial_prompt.messages must be an array, got {type(messages).__name__}"
|
|
453
|
+
)
|
|
454
|
+
elif len(messages) == 0:
|
|
455
|
+
errors.append(
|
|
456
|
+
"prompt_learning.initial_prompt.messages is empty (must have at least one message)"
|
|
457
|
+
)
|
|
458
|
+
|
|
459
|
+
# Validate policy config
|
|
460
|
+
policy = pl_section.get("policy")
|
|
461
|
+
if not policy or not isinstance(policy, dict):
|
|
462
|
+
errors.append("Missing [prompt_learning.policy] section or not a table")
|
|
463
|
+
else:
|
|
464
|
+
# Enforce inference_mode
|
|
465
|
+
mode = str(policy.get("inference_mode", "")).strip().lower()
|
|
466
|
+
if not mode:
|
|
467
|
+
errors.append("Missing required field: prompt_learning.policy.inference_mode (must be 'synth_hosted')")
|
|
468
|
+
elif mode != "synth_hosted":
|
|
469
|
+
errors.append("prompt_learning.policy.inference_mode must be 'synth_hosted' (bring_your_own unsupported)")
|
|
470
|
+
# Required fields for synth_hosted
|
|
471
|
+
provider = (policy.get("provider") or "").strip()
|
|
472
|
+
model = (policy.get("model") or "").strip()
|
|
473
|
+
if not provider:
|
|
474
|
+
errors.append("Missing required field: prompt_learning.policy.provider")
|
|
475
|
+
if not model:
|
|
476
|
+
errors.append("Missing required field: prompt_learning.policy.model")
|
|
477
|
+
else:
|
|
478
|
+
# Validate model is supported for the provider
|
|
479
|
+
if provider:
|
|
480
|
+
errors.extend(_validate_model_for_provider(
|
|
481
|
+
model, provider, "prompt_learning.policy.model", allow_nano=True
|
|
482
|
+
))
|
|
483
|
+
# VALIDATION: Reject inference_url in config - trainer must provide it in rollout requests
|
|
484
|
+
if "inference_url" in policy:
|
|
485
|
+
errors.append(
|
|
486
|
+
"inference_url must not be specified in [prompt_learning.policy]. "
|
|
487
|
+
"The trainer provides the inference URL in rollout requests. "
|
|
488
|
+
"Remove inference_url from your config file."
|
|
489
|
+
)
|
|
490
|
+
if "api_base" in policy:
|
|
491
|
+
errors.append(
|
|
492
|
+
"api_base must not be specified in [prompt_learning.policy]. "
|
|
493
|
+
"The trainer provides the inference URL in rollout requests. "
|
|
494
|
+
"Remove api_base from your config file."
|
|
495
|
+
)
|
|
496
|
+
if "base_url" in policy:
|
|
497
|
+
errors.append(
|
|
498
|
+
"base_url must not be specified in [prompt_learning.policy]. "
|
|
499
|
+
"The trainer provides the inference URL in rollout requests. "
|
|
500
|
+
"Remove base_url from your config file."
|
|
501
|
+
)
|
|
502
|
+
|
|
503
|
+
# Validate proxy_models config (can be at top-level or algorithm-specific)
|
|
504
|
+
proxy_models_section = pl_section.get("proxy_models")
|
|
505
|
+
if proxy_models_section:
|
|
506
|
+
if not isinstance(proxy_models_section, dict):
|
|
507
|
+
errors.append(f"prompt_learning.proxy_models must be a table/dict, got {type(proxy_models_section).__name__}")
|
|
508
|
+
else:
|
|
509
|
+
required_fields = ["hi_provider", "hi_model", "lo_provider", "lo_model"]
|
|
510
|
+
for field in required_fields:
|
|
511
|
+
if not proxy_models_section.get(field):
|
|
512
|
+
errors.append(f"prompt_learning.proxy_models.{field} is required")
|
|
513
|
+
# Validate numeric fields
|
|
514
|
+
for field, min_val in [("n_min_hi", 0), ("r2_thresh", 0.0), ("r2_stop", 0.0), ("sigma_max", 0.0), ("sigma_stop", 0.0), ("verify_every", 0)]:
|
|
515
|
+
val = proxy_models_section.get(field)
|
|
516
|
+
if val is not None:
|
|
517
|
+
try:
|
|
518
|
+
if field in ("r2_thresh", "r2_stop"):
|
|
519
|
+
fval = float(val)
|
|
520
|
+
if not (0.0 <= fval <= 1.0):
|
|
521
|
+
errors.append(f"prompt_learning.proxy_models.{field} must be between 0.0 and 1.0, got {fval}")
|
|
522
|
+
elif field.startswith("sigma"):
|
|
523
|
+
fval = float(val)
|
|
524
|
+
if fval < min_val:
|
|
525
|
+
errors.append(f"prompt_learning.proxy_models.{field} must be >= {min_val}, got {fval}")
|
|
526
|
+
else:
|
|
527
|
+
ival = int(val)
|
|
528
|
+
if ival < min_val:
|
|
529
|
+
errors.append(f"prompt_learning.proxy_models.{field} must be >= {min_val}, got {ival}")
|
|
530
|
+
except (TypeError, ValueError):
|
|
531
|
+
errors.append(f"prompt_learning.proxy_models.{field} must be numeric, got {type(val).__name__}")
|
|
532
|
+
# Validate provider/model combinations
|
|
533
|
+
if proxy_models_section.get("hi_provider") and proxy_models_section.get("hi_model"):
|
|
534
|
+
hi_errors = _validate_model_for_provider(
|
|
535
|
+
proxy_models_section["hi_model"],
|
|
536
|
+
proxy_models_section["hi_provider"],
|
|
537
|
+
"prompt_learning.proxy_models.hi_model",
|
|
538
|
+
allow_nano=True,
|
|
539
|
+
)
|
|
540
|
+
errors.extend(hi_errors)
|
|
541
|
+
if proxy_models_section.get("lo_provider") and proxy_models_section.get("lo_model"):
|
|
542
|
+
lo_errors = _validate_model_for_provider(
|
|
543
|
+
proxy_models_section["lo_model"],
|
|
544
|
+
proxy_models_section["lo_provider"],
|
|
545
|
+
"prompt_learning.proxy_models.lo_model",
|
|
546
|
+
allow_nano=True,
|
|
547
|
+
)
|
|
548
|
+
errors.extend(lo_errors)
|
|
549
|
+
|
|
550
|
+
# Validate judge config (shared by GEPA and MIPRO)
|
|
551
|
+
judge_section = pl_section.get("judge") or {}
|
|
552
|
+
if judge_section:
|
|
553
|
+
if not isinstance(judge_section, dict):
|
|
554
|
+
errors.append(f"prompt_learning.judge must be a table/dict, got {type(judge_section).__name__}")
|
|
555
|
+
else:
|
|
556
|
+
reward_source = str(judge_section.get("reward_source", "task_app")).strip().lower()
|
|
557
|
+
enabled = bool(judge_section.get("enabled"))
|
|
558
|
+
if reward_source and reward_source not in {"task_app", "judge", "fused"}:
|
|
559
|
+
errors.append("prompt_learning.judge.reward_source must be 'task_app', 'judge', or 'fused'")
|
|
560
|
+
backend_base = str(judge_section.get("backend_base", "") or "").strip()
|
|
561
|
+
backend_provider = str(judge_section.get("backend_provider", "") or "").strip()
|
|
562
|
+
backend_model = str(judge_section.get("backend_model", "") or "").strip()
|
|
563
|
+
if enabled:
|
|
564
|
+
pass
|
|
565
|
+
if reward_source == "fused":
|
|
566
|
+
weight_event = judge_section.get("weight_event", 0.0)
|
|
567
|
+
weight_outcome = judge_section.get("weight_outcome", 0.0)
|
|
568
|
+
try:
|
|
569
|
+
weight_event_f = float(weight_event)
|
|
570
|
+
except (TypeError, ValueError):
|
|
571
|
+
errors.append("prompt_learning.judge.weight_event must be numeric")
|
|
572
|
+
weight_event_f = 0.0
|
|
573
|
+
try:
|
|
574
|
+
weight_outcome_f = float(weight_outcome)
|
|
575
|
+
except (TypeError, ValueError):
|
|
576
|
+
errors.append("prompt_learning.judge.weight_outcome must be numeric")
|
|
577
|
+
weight_outcome_f = 0.0
|
|
578
|
+
if weight_event_f <= 0 and weight_outcome_f <= 0:
|
|
579
|
+
errors.append(
|
|
580
|
+
"prompt_learning.judge.reward_source='fused' requires weight_event > 0 or weight_outcome > 0"
|
|
581
|
+
)
|
|
582
|
+
|
|
583
|
+
# Check for multi-stage/multi-module pipeline config
|
|
584
|
+
initial_prompt = pl_section.get("initial_prompt", {})
|
|
585
|
+
pipeline_modules: list[str | dict[str, Any]] = []
|
|
586
|
+
if isinstance(initial_prompt, dict):
|
|
587
|
+
metadata = initial_prompt.get("metadata", {})
|
|
588
|
+
pipeline_modules = metadata.get("pipeline_modules", [])
|
|
589
|
+
if not isinstance(pipeline_modules, list):
|
|
590
|
+
pipeline_modules = []
|
|
591
|
+
has_multi_stage = isinstance(pipeline_modules, list) and len(pipeline_modules) > 0
|
|
592
|
+
|
|
593
|
+
# Validate algorithm-specific config
|
|
594
|
+
if algorithm == "gepa":
|
|
595
|
+
gepa_config = pl_section.get("gepa")
|
|
596
|
+
if not gepa_config or not isinstance(gepa_config, dict):
|
|
597
|
+
errors.append("Missing [prompt_learning.gepa] section for GEPA algorithm")
|
|
598
|
+
else:
|
|
599
|
+
# Multi-stage validation
|
|
600
|
+
modules_config = gepa_config.get("modules")
|
|
601
|
+
if has_multi_stage:
|
|
602
|
+
if not modules_config or not isinstance(modules_config, list) or len(modules_config) == 0:
|
|
603
|
+
errors.append(
|
|
604
|
+
f"GEPA multi-stage pipeline detected (found {len(pipeline_modules)} modules in "
|
|
605
|
+
f"prompt_learning.initial_prompt.metadata.pipeline_modules), "
|
|
606
|
+
f"but [prompt_learning.gepa.modules] is missing or empty. "
|
|
607
|
+
f"Define module configs for each pipeline stage."
|
|
608
|
+
)
|
|
609
|
+
else:
|
|
610
|
+
# Validate module IDs match pipeline_modules
|
|
611
|
+
module_ids = []
|
|
612
|
+
for m in modules_config:
|
|
613
|
+
if isinstance(m, dict):
|
|
614
|
+
module_id = m.get("module_id") or m.get("stage_id")
|
|
615
|
+
if module_id:
|
|
616
|
+
module_ids.append(str(module_id).strip())
|
|
617
|
+
elif hasattr(m, "module_id"):
|
|
618
|
+
module_ids.append(str(m.module_id).strip())
|
|
619
|
+
elif hasattr(m, "stage_id"):
|
|
620
|
+
module_ids.append(str(m.stage_id).strip())
|
|
621
|
+
|
|
622
|
+
# Extract pipeline module names (can be strings or dicts with 'name' field)
|
|
623
|
+
pipeline_module_names = []
|
|
624
|
+
for m in pipeline_modules:
|
|
625
|
+
if isinstance(m, str):
|
|
626
|
+
pipeline_module_names.append(m.strip())
|
|
627
|
+
elif isinstance(m, dict):
|
|
628
|
+
name = m.get("name") or m.get("module_id") or m.get("stage_id")
|
|
629
|
+
if name:
|
|
630
|
+
pipeline_module_names.append(str(name).strip())
|
|
631
|
+
|
|
632
|
+
# Check for missing modules
|
|
633
|
+
missing_modules = set(pipeline_module_names) - set(module_ids)
|
|
634
|
+
if missing_modules:
|
|
635
|
+
errors.append(
|
|
636
|
+
f"Pipeline modules {sorted(missing_modules)} are missing from "
|
|
637
|
+
f"[prompt_learning.gepa.modules]. Each pipeline module must have a corresponding "
|
|
638
|
+
f"module config with matching module_id."
|
|
639
|
+
)
|
|
640
|
+
|
|
641
|
+
# Check for extra modules (warn but don't error)
|
|
642
|
+
extra_modules = set(module_ids) - set(pipeline_module_names)
|
|
643
|
+
if extra_modules:
|
|
644
|
+
# This is a warning, not an error - extra modules are allowed
|
|
645
|
+
pass
|
|
646
|
+
|
|
647
|
+
# Numeric sanity checks
|
|
648
|
+
def _pos_int(name: str) -> None:
|
|
649
|
+
val = gepa_config.get(name)
|
|
650
|
+
if val is not None:
|
|
651
|
+
try:
|
|
652
|
+
ival = int(val)
|
|
653
|
+
if ival <= 0:
|
|
654
|
+
errors.append(f"prompt_learning.gepa.{name} must be > 0")
|
|
655
|
+
except Exception:
|
|
656
|
+
errors.append(f"prompt_learning.gepa.{name} must be an integer")
|
|
657
|
+
|
|
658
|
+
def _pos_int_nested(section: str, name: str) -> None:
|
|
659
|
+
"""Check positive int in nested section."""
|
|
660
|
+
section_config = gepa_config.get(section)
|
|
661
|
+
if section_config and isinstance(section_config, dict):
|
|
662
|
+
val = section_config.get(name)
|
|
663
|
+
if val is not None:
|
|
664
|
+
try:
|
|
665
|
+
ival = int(val)
|
|
666
|
+
if ival <= 0:
|
|
667
|
+
errors.append(f"prompt_learning.gepa.{section}.{name} must be > 0")
|
|
668
|
+
except Exception:
|
|
669
|
+
errors.append(f"prompt_learning.gepa.{section}.{name} must be an integer")
|
|
670
|
+
|
|
671
|
+
def _non_neg_int(name: str) -> None:
|
|
672
|
+
"""Check non-negative int."""
|
|
673
|
+
val = gepa_config.get(name)
|
|
674
|
+
if val is not None:
|
|
675
|
+
try:
|
|
676
|
+
ival = int(val)
|
|
677
|
+
if ival < 0:
|
|
678
|
+
errors.append(f"prompt_learning.gepa.{name} must be >= 0")
|
|
679
|
+
except Exception:
|
|
680
|
+
errors.append(f"prompt_learning.gepa.{name} must be an integer")
|
|
681
|
+
|
|
682
|
+
def _rate_float(name: str) -> None:
|
|
683
|
+
"""Check float in [0.0, 1.0] range."""
|
|
684
|
+
val = gepa_config.get(name)
|
|
685
|
+
if val is not None:
|
|
686
|
+
try:
|
|
687
|
+
fval = float(val)
|
|
688
|
+
if not (0.0 <= fval <= 1.0):
|
|
689
|
+
errors.append(f"prompt_learning.gepa.{name} must be between 0.0 and 1.0")
|
|
690
|
+
except Exception:
|
|
691
|
+
errors.append(f"prompt_learning.gepa.{name} must be numeric")
|
|
692
|
+
|
|
693
|
+
def _pos_float(name: str) -> None:
|
|
694
|
+
"""Check positive float."""
|
|
695
|
+
val = gepa_config.get(name)
|
|
696
|
+
if val is not None:
|
|
697
|
+
try:
|
|
698
|
+
fval = float(val)
|
|
699
|
+
if fval <= 0:
|
|
700
|
+
errors.append(f"prompt_learning.gepa.{name} must be > 0")
|
|
701
|
+
except Exception:
|
|
702
|
+
errors.append(f"prompt_learning.gepa.{name} must be numeric")
|
|
703
|
+
|
|
704
|
+
# Required positive integers
|
|
705
|
+
for fld in ("initial_population_size", "num_generations", "children_per_generation", "max_concurrent_rollouts"):
|
|
706
|
+
_pos_int(fld)
|
|
707
|
+
|
|
708
|
+
# Nested rollout config validation
|
|
709
|
+
_pos_int_nested("rollout", "budget")
|
|
710
|
+
_pos_int_nested("rollout", "max_concurrent")
|
|
711
|
+
_pos_int_nested("rollout", "minibatch_size")
|
|
712
|
+
|
|
713
|
+
# Nested population config validation
|
|
714
|
+
_pos_int_nested("population", "initial_size")
|
|
715
|
+
_pos_int_nested("population", "num_generations")
|
|
716
|
+
_pos_int_nested("population", "children_per_generation")
|
|
717
|
+
_rate_float("mutation_rate") # Can be at top level or in mutation section
|
|
718
|
+
_rate_float("crossover_rate") # Can be at top level or in population section
|
|
719
|
+
_pos_float("selection_pressure") # Must be >= 1.0
|
|
720
|
+
selection_pressure = gepa_config.get("selection_pressure")
|
|
721
|
+
if selection_pressure is not None:
|
|
722
|
+
try:
|
|
723
|
+
sp = float(selection_pressure)
|
|
724
|
+
if sp < 1.0:
|
|
725
|
+
errors.append("prompt_learning.gepa.selection_pressure must be >= 1.0")
|
|
726
|
+
except Exception:
|
|
727
|
+
pass # Already caught by type check
|
|
728
|
+
_non_neg_int("patience_generations")
|
|
729
|
+
|
|
730
|
+
# Nested archive config validation
|
|
731
|
+
_pos_int_nested("archive", "size")
|
|
732
|
+
_pos_int_nested("archive", "pareto_set_size")
|
|
733
|
+
_pos_float("pareto_eps") # Must be > 0, typically very small
|
|
734
|
+
_rate_float("feedback_fraction")
|
|
735
|
+
|
|
736
|
+
# Nested mutation config validation
|
|
737
|
+
mutation_config = gepa_config.get("mutation")
|
|
738
|
+
if mutation_config and isinstance(mutation_config, dict):
|
|
739
|
+
_rate_float("mutation_rate") # Check in mutation section too
|
|
740
|
+
mutation_model = mutation_config.get("llm_model")
|
|
741
|
+
mutation_provider = mutation_config.get("llm_provider", "").strip()
|
|
742
|
+
if mutation_model:
|
|
743
|
+
if not mutation_provider:
|
|
744
|
+
errors.append(
|
|
745
|
+
"Missing required field: prompt_learning.gepa.mutation.llm_provider\n"
|
|
746
|
+
" Required when prompt_learning.gepa.mutation.llm_model is set"
|
|
747
|
+
)
|
|
748
|
+
else:
|
|
749
|
+
errors.extend(_validate_model_for_provider(
|
|
750
|
+
mutation_model, mutation_provider, "prompt_learning.gepa.mutation.llm_model", allow_nano=False
|
|
751
|
+
))
|
|
752
|
+
|
|
753
|
+
# Top-level mutation_rate and crossover_rate (if not in nested sections)
|
|
754
|
+
if not (mutation_config and isinstance(mutation_config, dict) and "rate" in mutation_config):
|
|
755
|
+
_rate_float("mutation_rate")
|
|
756
|
+
population_config = gepa_config.get("population")
|
|
757
|
+
if not (population_config and isinstance(population_config, dict) and "crossover_rate" in population_config):
|
|
758
|
+
_rate_float("crossover_rate")
|
|
759
|
+
|
|
760
|
+
# Budget cap
|
|
761
|
+
max_spend = gepa_config.get("max_spend_usd")
|
|
762
|
+
if max_spend is not None:
|
|
763
|
+
try:
|
|
764
|
+
f = float(max_spend)
|
|
765
|
+
if f <= 0:
|
|
766
|
+
errors.append("prompt_learning.gepa.max_spend_usd must be > 0 when provided")
|
|
767
|
+
except (ValueError, TypeError):
|
|
768
|
+
errors.append("prompt_learning.gepa.max_spend_usd must be numeric")
|
|
769
|
+
|
|
770
|
+
# Rollout budget validation
|
|
771
|
+
rollout_config = gepa_config.get("rollout")
|
|
772
|
+
rollout_budget = None
|
|
773
|
+
if rollout_config and isinstance(rollout_config, dict):
|
|
774
|
+
rollout_budget = rollout_config.get("budget")
|
|
775
|
+
if rollout_budget is None:
|
|
776
|
+
rollout_budget = gepa_config.get("rollout_budget")
|
|
777
|
+
if rollout_budget is not None:
|
|
778
|
+
try:
|
|
779
|
+
rb = int(rollout_budget)
|
|
780
|
+
if rb <= 0:
|
|
781
|
+
errors.append("prompt_learning.gepa.rollout.budget (or rollout_budget) must be > 0 when provided")
|
|
782
|
+
except Exception:
|
|
783
|
+
errors.append("prompt_learning.gepa.rollout.budget (or rollout_budget) must be an integer")
|
|
784
|
+
|
|
785
|
+
# Minibatch size validation
|
|
786
|
+
minibatch_size = None
|
|
787
|
+
if rollout_config and isinstance(rollout_config, dict):
|
|
788
|
+
minibatch_size = rollout_config.get("minibatch_size")
|
|
789
|
+
if minibatch_size is None:
|
|
790
|
+
minibatch_size = gepa_config.get("minibatch_size")
|
|
791
|
+
if minibatch_size is not None:
|
|
792
|
+
try:
|
|
793
|
+
mbs = int(minibatch_size)
|
|
794
|
+
if mbs <= 0:
|
|
795
|
+
errors.append("prompt_learning.gepa.rollout.minibatch_size (or minibatch_size) must be > 0")
|
|
796
|
+
except Exception:
|
|
797
|
+
errors.append("prompt_learning.gepa.rollout.minibatch_size (or minibatch_size) must be an integer")
|
|
798
|
+
|
|
799
|
+
# Proposer type validation
|
|
800
|
+
proposer_type = gepa_config.get("proposer_type", "dspy")
|
|
801
|
+
if proposer_type not in ("dspy", "spec", "synth", "gepa-ai"):
|
|
802
|
+
errors.append(
|
|
803
|
+
f"Invalid proposer_type: '{proposer_type}'\n"
|
|
804
|
+
f" Must be one of: 'dspy', 'spec', 'synth', 'gepa-ai'\n"
|
|
805
|
+
f" Got: '{proposer_type}'"
|
|
806
|
+
)
|
|
807
|
+
|
|
808
|
+
# Proposer effort validation
|
|
809
|
+
proposer_effort = str(gepa_config.get("proposer_effort", "LOW")).upper()
|
|
810
|
+
valid_effort_levels = {"LOW_CONTEXT", "LOW", "MEDIUM", "HIGH"}
|
|
811
|
+
if proposer_effort not in valid_effort_levels:
|
|
812
|
+
errors.append(
|
|
813
|
+
f"Invalid proposer_effort: '{proposer_effort}'\n"
|
|
814
|
+
f" Must be one of: {', '.join(sorted(valid_effort_levels))}\n"
|
|
815
|
+
f" Got: '{proposer_effort}'"
|
|
816
|
+
)
|
|
817
|
+
|
|
818
|
+
# Proposer output tokens validation
|
|
819
|
+
proposer_output_tokens = str(gepa_config.get("proposer_output_tokens", "FAST")).upper()
|
|
820
|
+
valid_output_tokens = {"RAPID", "FAST", "SLOW"}
|
|
821
|
+
if proposer_output_tokens not in valid_output_tokens:
|
|
822
|
+
errors.append(
|
|
823
|
+
f"Invalid proposer_output_tokens: '{proposer_output_tokens}'\n"
|
|
824
|
+
f" Must be one of: {', '.join(sorted(valid_output_tokens))}\n"
|
|
825
|
+
f" Got: '{proposer_output_tokens}'"
|
|
826
|
+
)
|
|
827
|
+
|
|
828
|
+
# Note: RAPID can now be used with any proposer_effort level (5000 tokens)
|
|
829
|
+
|
|
830
|
+
# Spec validation when proposer_type is "spec"
|
|
831
|
+
if proposer_type == "spec":
|
|
832
|
+
spec_path = gepa_config.get("spec_path")
|
|
833
|
+
if not spec_path:
|
|
834
|
+
errors.append(
|
|
835
|
+
"Missing required field: prompt_learning.gepa.spec_path\n"
|
|
836
|
+
" Required when proposer_type='spec'\n"
|
|
837
|
+
" Example:\n"
|
|
838
|
+
" [prompt_learning.gepa]\n"
|
|
839
|
+
" proposer_type = \"spec\"\n"
|
|
840
|
+
" spec_path = \"examples/task_apps/banking77/banking77_spec.json\""
|
|
841
|
+
)
|
|
842
|
+
else:
|
|
843
|
+
# Validate spec_max_tokens if provided
|
|
844
|
+
spec_max_tokens = gepa_config.get("spec_max_tokens")
|
|
845
|
+
if spec_max_tokens is not None:
|
|
846
|
+
try:
|
|
847
|
+
smt = int(spec_max_tokens)
|
|
848
|
+
if smt <= 0:
|
|
849
|
+
errors.append("prompt_learning.gepa.spec_max_tokens must be > 0")
|
|
850
|
+
except Exception:
|
|
851
|
+
errors.append("prompt_learning.gepa.spec_max_tokens must be an integer")
|
|
852
|
+
|
|
853
|
+
# Validate spec_priority_threshold if provided
|
|
854
|
+
spec_priority_threshold = gepa_config.get("spec_priority_threshold")
|
|
855
|
+
if spec_priority_threshold is not None:
|
|
856
|
+
try:
|
|
857
|
+
spt = int(spec_priority_threshold)
|
|
858
|
+
if spt < 0:
|
|
859
|
+
errors.append("prompt_learning.gepa.spec_priority_threshold must be >= 0")
|
|
860
|
+
except Exception:
|
|
861
|
+
errors.append("prompt_learning.gepa.spec_priority_threshold must be an integer")
|
|
862
|
+
|
|
863
|
+
# Archive size validation
|
|
864
|
+
archive_config = gepa_config.get("archive")
|
|
865
|
+
archive_size = None
|
|
866
|
+
if archive_config and isinstance(archive_config, dict):
|
|
867
|
+
archive_size = archive_config.get("size")
|
|
868
|
+
if archive_size is None:
|
|
869
|
+
archive_size = gepa_config.get("archive_size")
|
|
870
|
+
if archive_size is not None:
|
|
871
|
+
try:
|
|
872
|
+
asize = int(archive_size)
|
|
873
|
+
if asize <= 0:
|
|
874
|
+
errors.append("prompt_learning.gepa.archive.size (or archive_size) must be > 0")
|
|
875
|
+
except Exception:
|
|
876
|
+
errors.append("prompt_learning.gepa.archive.size (or archive_size) must be an integer")
|
|
877
|
+
|
|
878
|
+
# CRITICAL: Validate pareto_set_size vs seeds BEFORE submitting to backend
|
|
879
|
+
# This catches config errors immediately instead of after job submission
|
|
880
|
+
eval_config = gepa_config.get("evaluation")
|
|
881
|
+
if eval_config and isinstance(eval_config, dict):
|
|
882
|
+
train_seeds = eval_config.get("seeds") or eval_config.get("train_seeds")
|
|
883
|
+
if train_seeds and isinstance(train_seeds, list) and len(train_seeds) > 0:
|
|
884
|
+
total_seeds = len(train_seeds)
|
|
885
|
+
|
|
886
|
+
# Get pareto_set_size (can be in archive section or top-level)
|
|
887
|
+
pareto_set_size = None
|
|
888
|
+
if archive_config and isinstance(archive_config, dict):
|
|
889
|
+
pareto_set_size = archive_config.get("pareto_set_size")
|
|
890
|
+
if pareto_set_size is None:
|
|
891
|
+
pareto_set_size = gepa_config.get("pareto_set_size", 64) # Default from backend
|
|
892
|
+
|
|
893
|
+
try:
|
|
894
|
+
pareto_count = int(pareto_set_size)
|
|
895
|
+
feedback_fraction = 0.5 # Default
|
|
896
|
+
if archive_config and isinstance(archive_config, dict):
|
|
897
|
+
feedback_fraction = archive_config.get("feedback_fraction", 0.5)
|
|
898
|
+
if feedback_fraction is None:
|
|
899
|
+
feedback_fraction = gepa_config.get("feedback_fraction", 0.5)
|
|
900
|
+
feedback_fraction = float(feedback_fraction)
|
|
901
|
+
|
|
902
|
+
# Calculate split
|
|
903
|
+
feedback_count = total_seeds - pareto_count
|
|
904
|
+
|
|
905
|
+
# Constants matching backend
|
|
906
|
+
min_pareto_set_size = 10
|
|
907
|
+
min_feedback_seeds = 3
|
|
908
|
+
|
|
909
|
+
# Validate pareto_set_size <= total_seeds
|
|
910
|
+
if pareto_count > total_seeds:
|
|
911
|
+
errors.append(
|
|
912
|
+
f"CONFIG ERROR: pareto_set_size={pareto_count} > total_seeds={total_seeds}. "
|
|
913
|
+
f"Increase [prompt_learning.gepa.evaluation].seeds or decrease "
|
|
914
|
+
f"[prompt_learning.gepa.archive].pareto_set_size. "
|
|
915
|
+
f"Seeds: {train_seeds[:10]}{'...' if len(train_seeds) > 10 else ''}"
|
|
916
|
+
)
|
|
917
|
+
|
|
918
|
+
# Validate pareto_set_size >= min_pareto_set_size
|
|
919
|
+
if pareto_count < min_pareto_set_size:
|
|
920
|
+
errors.append(
|
|
921
|
+
f"CONFIG ERROR: pareto_set_size={pareto_count} < MIN_PARETO_SET_SIZE={min_pareto_set_size}. "
|
|
922
|
+
f"Increase [prompt_learning.gepa.archive].pareto_set_size to at least {min_pareto_set_size}. "
|
|
923
|
+
f"Below this threshold, accuracy estimates are too noisy for reliable optimization."
|
|
924
|
+
)
|
|
925
|
+
|
|
926
|
+
# Validate feedback_count >= min_feedback_seeds
|
|
927
|
+
if feedback_count < min_feedback_seeds:
|
|
928
|
+
errors.append(
|
|
929
|
+
f"CONFIG ERROR: feedback_count={feedback_count} < MIN_FEEDBACK_SEEDS={min_feedback_seeds}. "
|
|
930
|
+
f"Increase total seeds or decrease pareto_set_size to ensure at least {min_feedback_seeds} feedback seeds. "
|
|
931
|
+
f"Below this threshold, reflection prompts lack sufficient diversity."
|
|
932
|
+
)
|
|
933
|
+
except (ValueError, TypeError):
|
|
934
|
+
pass # Type errors already caught by _pos_int_nested above
|
|
935
|
+
|
|
936
|
+
# Pareto eps validation
|
|
937
|
+
pareto_eps = None
|
|
938
|
+
if archive_config and isinstance(archive_config, dict):
|
|
939
|
+
pareto_eps = archive_config.get("pareto_eps")
|
|
940
|
+
if pareto_eps is None:
|
|
941
|
+
pareto_eps = gepa_config.get("pareto_eps")
|
|
942
|
+
if pareto_eps is not None:
|
|
943
|
+
try:
|
|
944
|
+
pe = float(pareto_eps)
|
|
945
|
+
if pe <= 0:
|
|
946
|
+
errors.append("prompt_learning.gepa.archive.pareto_eps (or pareto_eps) must be > 0")
|
|
947
|
+
elif pe >= 1.0:
|
|
948
|
+
errors.append("prompt_learning.gepa.archive.pareto_eps (or pareto_eps) should be < 1.0 (typically 1e-6)")
|
|
949
|
+
except Exception:
|
|
950
|
+
errors.append("prompt_learning.gepa.archive.pareto_eps (or pareto_eps) must be numeric")
|
|
951
|
+
|
|
952
|
+
# Feedback fraction validation
|
|
953
|
+
feedback_fraction = None
|
|
954
|
+
if archive_config and isinstance(archive_config, dict):
|
|
955
|
+
feedback_fraction = archive_config.get("feedback_fraction")
|
|
956
|
+
if feedback_fraction is None:
|
|
957
|
+
feedback_fraction = gepa_config.get("feedback_fraction")
|
|
958
|
+
if feedback_fraction is not None:
|
|
959
|
+
try:
|
|
960
|
+
ff = float(feedback_fraction)
|
|
961
|
+
if not (0.0 <= ff <= 1.0):
|
|
962
|
+
errors.append("prompt_learning.gepa.archive.feedback_fraction (or feedback_fraction) must be between 0.0 and 1.0")
|
|
963
|
+
except Exception:
|
|
964
|
+
errors.append("prompt_learning.gepa.archive.feedback_fraction (or feedback_fraction) must be numeric")
|
|
965
|
+
|
|
966
|
+
# Token counting model validation (should be a valid model name)
|
|
967
|
+
token_config = gepa_config.get("token")
|
|
968
|
+
token_counting_model = None
|
|
969
|
+
if token_config and isinstance(token_config, dict):
|
|
970
|
+
token_counting_model = token_config.get("counting_model")
|
|
971
|
+
if token_counting_model is None:
|
|
972
|
+
token_counting_model = gepa_config.get("token_counting_model")
|
|
973
|
+
if token_counting_model and (not isinstance(token_counting_model, str) or not token_counting_model.strip()):
|
|
974
|
+
# Basic validation - should be a non-empty string
|
|
975
|
+
errors.append("prompt_learning.gepa.token.counting_model (or token_counting_model) must be a non-empty string")
|
|
976
|
+
|
|
977
|
+
# Module/stage validation for multi-stage
|
|
978
|
+
if has_multi_stage:
|
|
979
|
+
modules_config = gepa_config.get("modules")
|
|
980
|
+
if modules_config and isinstance(modules_config, list):
|
|
981
|
+
for idx, module_entry in enumerate(modules_config):
|
|
982
|
+
if isinstance(module_entry, dict):
|
|
983
|
+
module_id = module_entry.get("module_id") or module_entry.get("stage_id") or f"module_{idx}"
|
|
984
|
+
max_instruction_slots = module_entry.get("max_instruction_slots")
|
|
985
|
+
max_tokens = module_entry.get("max_tokens")
|
|
986
|
+
allowed_tools = module_entry.get("allowed_tools")
|
|
987
|
+
|
|
988
|
+
# Validate max_instruction_slots
|
|
989
|
+
if max_instruction_slots is not None:
|
|
990
|
+
try:
|
|
991
|
+
mis = int(max_instruction_slots)
|
|
992
|
+
if mis < 1:
|
|
993
|
+
errors.append(
|
|
994
|
+
f"prompt_learning.gepa.modules[{idx}].max_instruction_slots must be >= 1"
|
|
995
|
+
)
|
|
996
|
+
except Exception:
|
|
997
|
+
errors.append(
|
|
998
|
+
f"prompt_learning.gepa.modules[{idx}].max_instruction_slots must be an integer"
|
|
999
|
+
)
|
|
1000
|
+
|
|
1001
|
+
# Validate max_tokens
|
|
1002
|
+
if max_tokens is not None:
|
|
1003
|
+
try:
|
|
1004
|
+
mt = int(max_tokens)
|
|
1005
|
+
if mt <= 0:
|
|
1006
|
+
errors.append(
|
|
1007
|
+
f"prompt_learning.gepa.modules[{idx}].max_tokens must be > 0"
|
|
1008
|
+
)
|
|
1009
|
+
except Exception:
|
|
1010
|
+
errors.append(
|
|
1011
|
+
f"prompt_learning.gepa.modules[{idx}].max_tokens must be an integer"
|
|
1012
|
+
)
|
|
1013
|
+
|
|
1014
|
+
# Validate allowed_tools
|
|
1015
|
+
if allowed_tools is not None:
|
|
1016
|
+
if not isinstance(allowed_tools, list):
|
|
1017
|
+
errors.append(
|
|
1018
|
+
f"prompt_learning.gepa.modules[{idx}].allowed_tools must be a list"
|
|
1019
|
+
)
|
|
1020
|
+
else:
|
|
1021
|
+
if len(allowed_tools) == 0:
|
|
1022
|
+
errors.append(
|
|
1023
|
+
f"prompt_learning.gepa.modules[{idx}].allowed_tools cannot be empty (use null/omit to allow all tools)"
|
|
1024
|
+
)
|
|
1025
|
+
else:
|
|
1026
|
+
# Check for duplicates
|
|
1027
|
+
seen_tools = set()
|
|
1028
|
+
for tool_idx, tool in enumerate(allowed_tools):
|
|
1029
|
+
if not isinstance(tool, str):
|
|
1030
|
+
errors.append(
|
|
1031
|
+
f"prompt_learning.gepa.modules[{idx}].allowed_tools[{tool_idx}] must be a string"
|
|
1032
|
+
)
|
|
1033
|
+
elif not tool.strip():
|
|
1034
|
+
errors.append(
|
|
1035
|
+
f"prompt_learning.gepa.modules[{idx}].allowed_tools[{tool_idx}] cannot be empty"
|
|
1036
|
+
)
|
|
1037
|
+
elif tool.strip() in seen_tools:
|
|
1038
|
+
errors.append(
|
|
1039
|
+
f"prompt_learning.gepa.modules[{idx}].allowed_tools contains duplicate '{tool.strip()}'"
|
|
1040
|
+
)
|
|
1041
|
+
else:
|
|
1042
|
+
seen_tools.add(tool.strip())
|
|
1043
|
+
|
|
1044
|
+
# Validate per-module policy config (REQUIRED)
|
|
1045
|
+
module_policy = module_entry.get("policy")
|
|
1046
|
+
if module_policy is None:
|
|
1047
|
+
errors.append(
|
|
1048
|
+
f"❌ gepa.modules[{idx}]: [policy] table is REQUIRED. "
|
|
1049
|
+
f"Each module must have its own policy configuration with 'model' and 'provider' fields."
|
|
1050
|
+
)
|
|
1051
|
+
elif not isinstance(module_policy, dict):
|
|
1052
|
+
errors.append(
|
|
1053
|
+
f"❌ gepa.modules[{idx}]: [policy] must be a table/dict, got {type(module_policy).__name__}"
|
|
1054
|
+
)
|
|
1055
|
+
else:
|
|
1056
|
+
# Validate required fields in module policy
|
|
1057
|
+
if not module_policy.get("model"):
|
|
1058
|
+
errors.append(
|
|
1059
|
+
f"❌ gepa.modules[{idx}]: [policy].model is required"
|
|
1060
|
+
)
|
|
1061
|
+
if not module_policy.get("provider"):
|
|
1062
|
+
errors.append(
|
|
1063
|
+
f"❌ gepa.modules[{idx}]: [policy].provider is required"
|
|
1064
|
+
)
|
|
1065
|
+
# Validate model/provider combination
|
|
1066
|
+
module_model = module_policy.get("model")
|
|
1067
|
+
module_provider = module_policy.get("provider")
|
|
1068
|
+
if module_model and module_provider:
|
|
1069
|
+
errors.extend(_validate_model_for_provider(
|
|
1070
|
+
module_model, module_provider,
|
|
1071
|
+
f"prompt_learning.gepa.modules[{idx}].policy.model",
|
|
1072
|
+
allow_nano=True, # Policy models can be nano
|
|
1073
|
+
))
|
|
1074
|
+
# Reject inference_url in module policy (trainer provides it)
|
|
1075
|
+
if "inference_url" in module_policy:
|
|
1076
|
+
errors.append(
|
|
1077
|
+
f"❌ gepa.modules[{idx}]: [policy].inference_url must not be specified. "
|
|
1078
|
+
f"The trainer provides the inference URL in rollout requests. Remove inference_url from module policy."
|
|
1079
|
+
)
|
|
1080
|
+
if "api_base" in module_policy:
|
|
1081
|
+
errors.append(
|
|
1082
|
+
f"❌ gepa.modules[{idx}]: [policy].api_base must not be specified. "
|
|
1083
|
+
f"Remove api_base from module policy."
|
|
1084
|
+
)
|
|
1085
|
+
if "base_url" in module_policy:
|
|
1086
|
+
errors.append(
|
|
1087
|
+
f"❌ gepa.modules[{idx}]: [policy].base_url must not be specified. "
|
|
1088
|
+
f"Remove base_url from module policy."
|
|
1089
|
+
)
|
|
1090
|
+
|
|
1091
|
+
elif algorithm == "mipro":
|
|
1092
|
+
mipro_config = pl_section.get("mipro")
|
|
1093
|
+
if not mipro_config or not isinstance(mipro_config, dict):
|
|
1094
|
+
errors.append("Missing [prompt_learning.mipro] section for MIPRO algorithm")
|
|
1095
|
+
else:
|
|
1096
|
+
# Validate required MIPRO fields
|
|
1097
|
+
def _pos_int(name: str) -> None:
|
|
1098
|
+
val = mipro_config.get(name)
|
|
1099
|
+
if val is not None:
|
|
1100
|
+
try:
|
|
1101
|
+
ival = int(val)
|
|
1102
|
+
if ival <= 0:
|
|
1103
|
+
errors.append(f"prompt_learning.mipro.{name} must be > 0")
|
|
1104
|
+
except Exception:
|
|
1105
|
+
errors.append(f"prompt_learning.mipro.{name} must be an integer")
|
|
1106
|
+
|
|
1107
|
+
def _non_neg_int(name: str) -> None:
|
|
1108
|
+
"""Check non-negative int."""
|
|
1109
|
+
val = mipro_config.get(name)
|
|
1110
|
+
if val is not None:
|
|
1111
|
+
try:
|
|
1112
|
+
ival = int(val)
|
|
1113
|
+
if ival < 0:
|
|
1114
|
+
errors.append(f"prompt_learning.mipro.{name} must be >= 0")
|
|
1115
|
+
except Exception:
|
|
1116
|
+
errors.append(f"prompt_learning.mipro.{name} must be an integer")
|
|
1117
|
+
|
|
1118
|
+
def _rate_float(name: str) -> None:
|
|
1119
|
+
"""Check float in [0.0, 1.0] range."""
|
|
1120
|
+
val = mipro_config.get(name)
|
|
1121
|
+
if val is not None:
|
|
1122
|
+
try:
|
|
1123
|
+
fval = float(val)
|
|
1124
|
+
if not (0.0 <= fval <= 1.0):
|
|
1125
|
+
errors.append(f"prompt_learning.mipro.{name} must be between 0.0 and 1.0")
|
|
1126
|
+
except Exception:
|
|
1127
|
+
errors.append(f"prompt_learning.mipro.{name} must be numeric")
|
|
1128
|
+
|
|
1129
|
+
def _pos_float(name: str) -> None:
|
|
1130
|
+
"""Check positive float."""
|
|
1131
|
+
val = mipro_config.get(name)
|
|
1132
|
+
if val is not None:
|
|
1133
|
+
try:
|
|
1134
|
+
fval = float(val)
|
|
1135
|
+
if fval <= 0:
|
|
1136
|
+
errors.append(f"prompt_learning.mipro.{name} must be > 0")
|
|
1137
|
+
except Exception:
|
|
1138
|
+
errors.append(f"prompt_learning.mipro.{name} must be numeric")
|
|
1139
|
+
|
|
1140
|
+
# Required numeric fields
|
|
1141
|
+
for fld in ("num_iterations", "num_evaluations_per_iteration", "batch_size", "max_concurrent"):
|
|
1142
|
+
_pos_int(fld)
|
|
1143
|
+
|
|
1144
|
+
# Additional MIPRO numeric validations
|
|
1145
|
+
_pos_int("max_demo_set_size")
|
|
1146
|
+
_pos_int("max_demo_sets")
|
|
1147
|
+
_pos_int("max_instruction_sets")
|
|
1148
|
+
_pos_int("full_eval_every_k")
|
|
1149
|
+
_pos_int("instructions_per_batch")
|
|
1150
|
+
_pos_int("max_instructions")
|
|
1151
|
+
_pos_int("duplicate_retry_limit")
|
|
1152
|
+
|
|
1153
|
+
# Validate meta_model if set (optional - backend applies defaults)
|
|
1154
|
+
meta_model = mipro_config.get("meta_model")
|
|
1155
|
+
meta_model_provider = mipro_config.get("meta_model_provider", "").strip()
|
|
1156
|
+
if meta_model:
|
|
1157
|
+
# If meta_model is explicitly set, validate it
|
|
1158
|
+
if not meta_model_provider:
|
|
1159
|
+
errors.append(
|
|
1160
|
+
"Missing required field: prompt_learning.mipro.meta_model_provider\n"
|
|
1161
|
+
" Required when prompt_learning.mipro.meta_model is set"
|
|
1162
|
+
)
|
|
1163
|
+
else:
|
|
1164
|
+
errors.extend(_validate_model_for_provider(
|
|
1165
|
+
meta_model, meta_model_provider, "prompt_learning.mipro.meta_model", allow_nano=False
|
|
1166
|
+
))
|
|
1167
|
+
# If meta_model is not set, backend will use defaults (llama-3.3-70b-versatile/groq)
|
|
1168
|
+
|
|
1169
|
+
# Validate meta model temperature
|
|
1170
|
+
meta_temperature = mipro_config.get("meta_model_temperature")
|
|
1171
|
+
if meta_temperature is not None:
|
|
1172
|
+
try:
|
|
1173
|
+
temp = float(meta_temperature)
|
|
1174
|
+
if temp < 0.0:
|
|
1175
|
+
errors.append("prompt_learning.mipro.meta_model_temperature must be >= 0.0")
|
|
1176
|
+
except Exception:
|
|
1177
|
+
errors.append("prompt_learning.mipro.meta_model_temperature must be numeric")
|
|
1178
|
+
|
|
1179
|
+
# Validate meta model max_tokens
|
|
1180
|
+
meta_max_tokens = mipro_config.get("meta_model_max_tokens")
|
|
1181
|
+
if meta_max_tokens is not None:
|
|
1182
|
+
try:
|
|
1183
|
+
mmt = int(meta_max_tokens)
|
|
1184
|
+
if mmt <= 0:
|
|
1185
|
+
errors.append("prompt_learning.mipro.meta_model_max_tokens must be > 0")
|
|
1186
|
+
except Exception:
|
|
1187
|
+
errors.append("prompt_learning.mipro.meta_model_max_tokens must be an integer")
|
|
1188
|
+
|
|
1189
|
+
# Validate generate_at_iterations
|
|
1190
|
+
generate_at_iterations = mipro_config.get("generate_at_iterations")
|
|
1191
|
+
if generate_at_iterations is not None:
|
|
1192
|
+
if not isinstance(generate_at_iterations, list):
|
|
1193
|
+
errors.append("prompt_learning.mipro.generate_at_iterations must be a list")
|
|
1194
|
+
else:
|
|
1195
|
+
for idx, iter_val in enumerate(generate_at_iterations):
|
|
1196
|
+
try:
|
|
1197
|
+
iter_int = int(iter_val)
|
|
1198
|
+
if iter_int < 0:
|
|
1199
|
+
errors.append(
|
|
1200
|
+
f"prompt_learning.mipro.generate_at_iterations[{idx}] must be >= 0"
|
|
1201
|
+
)
|
|
1202
|
+
except Exception:
|
|
1203
|
+
errors.append(
|
|
1204
|
+
f"prompt_learning.mipro.generate_at_iterations[{idx}] must be an integer"
|
|
1205
|
+
)
|
|
1206
|
+
|
|
1207
|
+
# Validate spec configuration
|
|
1208
|
+
spec_path = mipro_config.get("spec_path")
|
|
1209
|
+
if spec_path:
|
|
1210
|
+
# Validate spec_max_tokens if provided
|
|
1211
|
+
spec_max_tokens = mipro_config.get("spec_max_tokens")
|
|
1212
|
+
if spec_max_tokens is not None:
|
|
1213
|
+
try:
|
|
1214
|
+
smt = int(spec_max_tokens)
|
|
1215
|
+
if smt <= 0:
|
|
1216
|
+
errors.append("prompt_learning.mipro.spec_max_tokens must be > 0")
|
|
1217
|
+
except Exception:
|
|
1218
|
+
errors.append("prompt_learning.mipro.spec_max_tokens must be an integer")
|
|
1219
|
+
|
|
1220
|
+
# Validate spec_priority_threshold if provided
|
|
1221
|
+
spec_priority_threshold = mipro_config.get("spec_priority_threshold")
|
|
1222
|
+
if spec_priority_threshold is not None:
|
|
1223
|
+
try:
|
|
1224
|
+
spt = int(spec_priority_threshold)
|
|
1225
|
+
if spt < 0:
|
|
1226
|
+
errors.append("prompt_learning.mipro.spec_priority_threshold must be >= 0")
|
|
1227
|
+
except Exception:
|
|
1228
|
+
errors.append("prompt_learning.mipro.spec_priority_threshold must be an integer")
|
|
1229
|
+
|
|
1230
|
+
# Validate modules/stages configuration
|
|
1231
|
+
modules_config = mipro_config.get("modules")
|
|
1232
|
+
if modules_config and isinstance(modules_config, list):
|
|
1233
|
+
max_instruction_sets = mipro_config.get("max_instruction_sets", 128)
|
|
1234
|
+
max_demo_sets = mipro_config.get("max_demo_sets", 128)
|
|
1235
|
+
seen_module_ids = set()
|
|
1236
|
+
seen_stage_ids = set()
|
|
1237
|
+
|
|
1238
|
+
for module_idx, module_entry in enumerate(modules_config):
|
|
1239
|
+
if not isinstance(module_entry, dict):
|
|
1240
|
+
errors.append(
|
|
1241
|
+
f"prompt_learning.mipro.modules[{module_idx}] must be a table/dict"
|
|
1242
|
+
)
|
|
1243
|
+
continue
|
|
1244
|
+
|
|
1245
|
+
module_id = module_entry.get("module_id") or module_entry.get("id") or f"module_{module_idx}"
|
|
1246
|
+
if module_id in seen_module_ids:
|
|
1247
|
+
errors.append(
|
|
1248
|
+
f"Duplicate module_id '{module_id}' in prompt_learning.mipro.modules"
|
|
1249
|
+
)
|
|
1250
|
+
seen_module_ids.add(module_id)
|
|
1251
|
+
|
|
1252
|
+
# Validate stages
|
|
1253
|
+
stages = module_entry.get("stages")
|
|
1254
|
+
if stages is not None:
|
|
1255
|
+
if not isinstance(stages, list):
|
|
1256
|
+
errors.append(
|
|
1257
|
+
f"prompt_learning.mipro.modules[{module_idx}].stages must be a list"
|
|
1258
|
+
)
|
|
1259
|
+
else:
|
|
1260
|
+
for stage_idx, stage_entry in enumerate(stages):
|
|
1261
|
+
if isinstance(stage_entry, dict):
|
|
1262
|
+
stage_id = stage_entry.get("stage_id") or stage_entry.get("module_stage_id") or f"stage_{stage_idx}"
|
|
1263
|
+
if stage_id in seen_stage_ids:
|
|
1264
|
+
errors.append(
|
|
1265
|
+
f"Duplicate stage_id '{stage_id}' across modules"
|
|
1266
|
+
)
|
|
1267
|
+
seen_stage_ids.add(stage_id)
|
|
1268
|
+
|
|
1269
|
+
# Validate max_instruction_slots <= max_instruction_sets
|
|
1270
|
+
max_instr_slots = stage_entry.get("max_instruction_slots")
|
|
1271
|
+
if max_instr_slots is not None:
|
|
1272
|
+
try:
|
|
1273
|
+
mis = int(max_instr_slots)
|
|
1274
|
+
if mis < 1:
|
|
1275
|
+
errors.append(
|
|
1276
|
+
f"prompt_learning.mipro.modules[{module_idx}].stages[{stage_idx}].max_instruction_slots must be >= 1"
|
|
1277
|
+
)
|
|
1278
|
+
elif mis > max_instruction_sets:
|
|
1279
|
+
errors.append(
|
|
1280
|
+
f"prompt_learning.mipro.modules[{module_idx}].stages[{stage_idx}].max_instruction_slots ({mis}) "
|
|
1281
|
+
f"exceeds max_instruction_sets ({max_instruction_sets})"
|
|
1282
|
+
)
|
|
1283
|
+
except Exception:
|
|
1284
|
+
errors.append(
|
|
1285
|
+
f"prompt_learning.mipro.modules[{module_idx}].stages[{stage_idx}].max_instruction_slots must be an integer"
|
|
1286
|
+
)
|
|
1287
|
+
|
|
1288
|
+
# Validate max_demo_slots <= max_demo_sets
|
|
1289
|
+
max_demo_slots = stage_entry.get("max_demo_slots")
|
|
1290
|
+
if max_demo_slots is not None:
|
|
1291
|
+
try:
|
|
1292
|
+
mds = int(max_demo_slots)
|
|
1293
|
+
if mds < 0:
|
|
1294
|
+
errors.append(
|
|
1295
|
+
f"prompt_learning.mipro.modules[{module_idx}].stages[{stage_idx}].max_demo_slots must be >= 0"
|
|
1296
|
+
)
|
|
1297
|
+
elif mds > max_demo_sets:
|
|
1298
|
+
errors.append(
|
|
1299
|
+
f"prompt_learning.mipro.modules[{module_idx}].stages[{stage_idx}].max_demo_slots ({mds}) "
|
|
1300
|
+
f"exceeds max_demo_sets ({max_demo_sets})"
|
|
1301
|
+
)
|
|
1302
|
+
except Exception:
|
|
1303
|
+
errors.append(
|
|
1304
|
+
f"prompt_learning.mipro.modules[{module_idx}].stages[{stage_idx}].max_demo_slots must be an integer"
|
|
1305
|
+
)
|
|
1306
|
+
|
|
1307
|
+
# Validate edges reference valid stages
|
|
1308
|
+
edges = module_entry.get("edges")
|
|
1309
|
+
if edges is not None:
|
|
1310
|
+
if not isinstance(edges, list):
|
|
1311
|
+
errors.append(
|
|
1312
|
+
f"prompt_learning.mipro.modules[{module_idx}].edges must be a list"
|
|
1313
|
+
)
|
|
1314
|
+
else:
|
|
1315
|
+
stage_ids_in_module = set()
|
|
1316
|
+
if stages and isinstance(stages, list):
|
|
1317
|
+
for stage_entry in stages:
|
|
1318
|
+
if isinstance(stage_entry, dict):
|
|
1319
|
+
sid = stage_entry.get("stage_id") or stage_entry.get("module_stage_id")
|
|
1320
|
+
if sid:
|
|
1321
|
+
stage_ids_in_module.add(str(sid))
|
|
1322
|
+
|
|
1323
|
+
for edge_idx, edge in enumerate(edges):
|
|
1324
|
+
if isinstance(edge, list | tuple) and len(edge) == 2:
|
|
1325
|
+
source, target = edge
|
|
1326
|
+
elif isinstance(edge, dict):
|
|
1327
|
+
source = edge.get("from") or edge.get("source")
|
|
1328
|
+
target = edge.get("to") or edge.get("target")
|
|
1329
|
+
else:
|
|
1330
|
+
errors.append(
|
|
1331
|
+
f"prompt_learning.mipro.modules[{module_idx}].edges[{edge_idx}] must be a pair or mapping"
|
|
1332
|
+
)
|
|
1333
|
+
continue
|
|
1334
|
+
|
|
1335
|
+
source_str = str(source or "").strip()
|
|
1336
|
+
target_str = str(target or "").strip()
|
|
1337
|
+
if source_str and source_str not in stage_ids_in_module:
|
|
1338
|
+
errors.append(
|
|
1339
|
+
f"prompt_learning.mipro.modules[{module_idx}].edges[{edge_idx}] references unknown source stage '{source_str}'"
|
|
1340
|
+
)
|
|
1341
|
+
if target_str and target_str not in stage_ids_in_module:
|
|
1342
|
+
errors.append(
|
|
1343
|
+
f"prompt_learning.mipro.modules[{module_idx}].edges[{edge_idx}] references unknown target stage '{target_str}'"
|
|
1344
|
+
)
|
|
1345
|
+
|
|
1346
|
+
# CRITICAL: Validate bootstrap_train_seeds and online_pool (can be at top level or under mipro)
|
|
1347
|
+
bootstrap_seeds = pl_section.get("bootstrap_train_seeds") or (mipro_config.get("bootstrap_train_seeds") if isinstance(mipro_config, dict) else None)
|
|
1348
|
+
online_pool = pl_section.get("online_pool") or (mipro_config.get("online_pool") if isinstance(mipro_config, dict) else None)
|
|
1349
|
+
|
|
1350
|
+
if not bootstrap_seeds:
|
|
1351
|
+
errors.append(
|
|
1352
|
+
"Missing required field: prompt_learning.bootstrap_train_seeds\n"
|
|
1353
|
+
" MIPRO requires bootstrap seeds for the few-shot bootstrapping phase.\n"
|
|
1354
|
+
" Example:\n"
|
|
1355
|
+
" [prompt_learning]\n"
|
|
1356
|
+
" bootstrap_train_seeds = [0, 1, 2, 3, 4]"
|
|
1357
|
+
)
|
|
1358
|
+
elif not isinstance(bootstrap_seeds, list):
|
|
1359
|
+
errors.append("prompt_learning.bootstrap_train_seeds must be an array")
|
|
1360
|
+
elif len(bootstrap_seeds) == 0:
|
|
1361
|
+
errors.append("prompt_learning.bootstrap_train_seeds cannot be empty")
|
|
1362
|
+
|
|
1363
|
+
if not online_pool:
|
|
1364
|
+
errors.append(
|
|
1365
|
+
"Missing required field: prompt_learning.online_pool\n"
|
|
1366
|
+
" MIPRO requires online_pool seeds for mini-batch evaluation during optimization.\n"
|
|
1367
|
+
" Example:\n"
|
|
1368
|
+
" [prompt_learning]\n"
|
|
1369
|
+
" online_pool = [5, 6, 7, 8, 9]"
|
|
1370
|
+
)
|
|
1371
|
+
elif not isinstance(online_pool, list):
|
|
1372
|
+
errors.append("prompt_learning.online_pool must be an array")
|
|
1373
|
+
elif len(online_pool) == 0:
|
|
1374
|
+
errors.append("prompt_learning.online_pool cannot be empty")
|
|
1375
|
+
|
|
1376
|
+
# Validate few_shot_score_threshold (if mipro_config exists)
|
|
1377
|
+
if isinstance(mipro_config, dict):
|
|
1378
|
+
threshold = mipro_config.get("few_shot_score_threshold")
|
|
1379
|
+
if threshold is not None:
|
|
1380
|
+
try:
|
|
1381
|
+
f = float(threshold)
|
|
1382
|
+
if not (0.0 <= f <= 1.0):
|
|
1383
|
+
errors.append("prompt_learning.mipro.few_shot_score_threshold must be between 0.0 and 1.0")
|
|
1384
|
+
except Exception:
|
|
1385
|
+
errors.append("prompt_learning.mipro.few_shot_score_threshold must be a number")
|
|
1386
|
+
|
|
1387
|
+
# Validate min_bootstrap_demos (strict bootstrap mode)
|
|
1388
|
+
min_bootstrap_demos = mipro_config.get("min_bootstrap_demos")
|
|
1389
|
+
if min_bootstrap_demos is not None:
|
|
1390
|
+
try:
|
|
1391
|
+
min_demos_int = int(min_bootstrap_demos)
|
|
1392
|
+
if min_demos_int < 0:
|
|
1393
|
+
errors.append("prompt_learning.mipro.min_bootstrap_demos must be >= 0")
|
|
1394
|
+
elif bootstrap_seeds and min_demos_int > len(bootstrap_seeds):
|
|
1395
|
+
errors.append(
|
|
1396
|
+
f"prompt_learning.mipro.min_bootstrap_demos ({min_demos_int}) exceeds "
|
|
1397
|
+
f"bootstrap_train_seeds count ({len(bootstrap_seeds)}). "
|
|
1398
|
+
f"You can never have more demos than bootstrap seeds."
|
|
1399
|
+
)
|
|
1400
|
+
except (TypeError, ValueError):
|
|
1401
|
+
errors.append("prompt_learning.mipro.min_bootstrap_demos must be an integer")
|
|
1402
|
+
|
|
1403
|
+
# Validate reference pool doesn't overlap with bootstrap/online/test pools
|
|
1404
|
+
reference_pool = mipro_config.get("reference_pool") or pl_section.get("reference_pool")
|
|
1405
|
+
if reference_pool:
|
|
1406
|
+
if not isinstance(reference_pool, list):
|
|
1407
|
+
errors.append("prompt_learning.mipro.reference_pool (or prompt_learning.reference_pool) must be an array")
|
|
1408
|
+
else:
|
|
1409
|
+
all_train_test = set(bootstrap_seeds or []) | set(online_pool or []) | set(mipro_config.get("test_pool") or pl_section.get("test_pool") or [])
|
|
1410
|
+
overlapping = set(reference_pool) & all_train_test
|
|
1411
|
+
if overlapping:
|
|
1412
|
+
errors.append(
|
|
1413
|
+
f"reference_pool seeds must not overlap with bootstrap/online/test pools. "
|
|
1414
|
+
f"Found overlapping seeds: {sorted(overlapping)}"
|
|
1415
|
+
)
|
|
1416
|
+
|
|
1417
|
+
# Raise all errors at once for better UX
|
|
1418
|
+
if errors:
|
|
1419
|
+
_raise_validation_errors(errors, config_path)
|
|
1420
|
+
|
|
1421
|
+
|
|
1422
|
+
def _raise_validation_errors(errors: list[str], config_path: Path) -> None:
    """Format all collected validation errors into one message and raise it.

    Args:
        errors: Accumulated human-readable error strings.
        config_path: Config file the errors refer to (shown in the header).

    Raises:
        click.ClickException: Always — carries the fully formatted report.
    """
    # Header names the offending config file and the total error count.
    parts: list[str] = [
        f"\n❌ Invalid prompt learning config: {config_path}\n\n"
        f"Found {len(errors)} error(s):\n\n"
    ]

    # Number each error; continuation lines of multi-line errors are indented.
    for idx, err in enumerate(errors, 1):
        parts.append(f"{idx}. " + "\n ".join(err.split("\n")) + "\n\n")

    # Point the user at known-good example configs.
    parts.append(
        "📖 See example configs:\n"
        " - cookbooks/dev/blog_posts/gepa/configs/banking77_gepa_local.toml\n"
        " - cookbooks/dev/blog_posts/mipro/configs/banking77_mipro_local.toml\n"
    )

    raise click.ClickException("".join(parts))
|
|
1441
|
+
|
|
1442
|
+
|
|
1443
|
+
def validate_rl_config(config_data: dict[str, Any], config_path: Path) -> None:
    """
    Validate RL config BEFORE sending to backend.

    Args:
        config_data: Parsed TOML/JSON config
        config_path: Path to config file (for error messages)

    Raises:
        ConfigValidationError: If config is invalid
        click.ClickException: If validation fails (for CLI)
    """
    errors: list[str] = []

    # Check for rl section (either [rl] or the legacy [online_rl] name).
    rl_section = config_data.get("rl") or config_data.get("online_rl")
    if not rl_section:
        errors.append(
            "Missing [rl] or [online_rl] section in config"
        )
        _raise_validation_errors(errors, config_path)
        return

    # FIX: guard against a non-table section value (e.g. ``rl = "foo"``),
    # which previously crashed with AttributeError on ``.get`` below.
    # Mirrors the isinstance(..., dict) guards used by the sibling
    # prompt-learning validators in this module.
    if not isinstance(rl_section, dict):
        errors.append(
            f"[rl] section must be a table/dict, got {type(rl_section).__name__}"
        )
        _raise_validation_errors(errors, config_path)
        return

    # Validate algorithm (presence only — the backend checks the value).
    algorithm = rl_section.get("algorithm")
    if not algorithm:
        errors.append(
            "Missing required field: rl.algorithm\n"
            " Must be one of: 'grpo', 'ppo', etc."
        )

    # Validate task_url: required and must be a string.
    task_url = rl_section.get("task_url")
    if not task_url:
        errors.append(
            "Missing required field: rl.task_url"
        )
    elif not isinstance(task_url, str):
        errors.append(
            f"task_url must be a string, got {type(task_url).__name__}"
        )

    # Raise all errors at once for better UX.
    if errors:
        _raise_validation_errors(errors, config_path)
|
|
1487
|
+
|
|
1488
|
+
|
|
1489
|
+
def validate_sft_config(config_data: dict[str, Any], config_path: Path) -> None:
    """
    Validate SFT config BEFORE sending to backend.

    Args:
        config_data: Parsed TOML/JSON config
        config_path: Path to config file (for error messages)

    Raises:
        ConfigValidationError: If config is invalid
        click.ClickException: If validation fails (for CLI)
    """
    problems: list[str] = []

    # The [sft] table is mandatory; report and bail out immediately when absent.
    sft_section = config_data.get("sft")
    if not sft_section:
        problems.append(
            "Missing [sft] section in config"
        )
        _raise_validation_errors(problems, config_path)
        return

    # A model identifier is the only hard requirement inside [sft].
    if not sft_section.get("model"):
        problems.append(
            "Missing required field: sft.model"
        )

    if problems:
        _raise_validation_errors(problems, config_path)
|
|
1521
|
+
|
|
1522
|
+
|
|
1523
|
+
def validate_gepa_config_from_file(config_path: Path) -> Tuple[bool, List[str]]:
|
|
1524
|
+
"""Validate GEPA config from TOML file with comprehensive checks.
|
|
1525
|
+
|
|
1526
|
+
Returns:
|
|
1527
|
+
(is_valid, errors) tuple where errors is a list of error messages
|
|
1528
|
+
"""
|
|
1529
|
+
errors = []
|
|
1530
|
+
|
|
1531
|
+
try:
|
|
1532
|
+
with open(config_path) as f:
|
|
1533
|
+
config_dict = toml.load(f)
|
|
1534
|
+
except Exception as e:
|
|
1535
|
+
return False, [f"Failed to parse TOML: {e}"]
|
|
1536
|
+
|
|
1537
|
+
pl_section = config_dict.get("prompt_learning", {})
|
|
1538
|
+
if not isinstance(pl_section, dict):
|
|
1539
|
+
errors.append("❌ [prompt_learning] section is missing or invalid")
|
|
1540
|
+
return False, errors
|
|
1541
|
+
|
|
1542
|
+
# Check algorithm
|
|
1543
|
+
algorithm = pl_section.get("algorithm")
|
|
1544
|
+
if algorithm != "gepa":
|
|
1545
|
+
errors.append(f"❌ Expected algorithm='gepa', got '{algorithm}'")
|
|
1546
|
+
|
|
1547
|
+
# Check required top-level fields (env_name is now in gepa section)
|
|
1548
|
+
required_top_level = ["task_app_url", "task_app_api_key"]
|
|
1549
|
+
for field in required_top_level:
|
|
1550
|
+
if not pl_section.get(field):
|
|
1551
|
+
errors.append(f"❌ [prompt_learning].{field} is required")
|
|
1552
|
+
|
|
1553
|
+
# Check GEPA section
|
|
1554
|
+
gepa_section = pl_section.get("gepa", {})
|
|
1555
|
+
if not isinstance(gepa_section, dict):
|
|
1556
|
+
errors.append("❌ [prompt_learning.gepa] section is missing or invalid")
|
|
1557
|
+
return False, errors
|
|
1558
|
+
|
|
1559
|
+
# Check env_name in gepa section (required)
|
|
1560
|
+
if not gepa_section.get("env_name"):
|
|
1561
|
+
errors.append("❌ [prompt_learning.gepa].env_name is required")
|
|
1562
|
+
|
|
1563
|
+
# Check required GEPA subsections
|
|
1564
|
+
required_sections = ["evaluation", "rollout", "mutation", "population", "archive", "token"]
|
|
1565
|
+
missing_sections = [s for s in required_sections if not gepa_section.get(s)]
|
|
1566
|
+
if missing_sections:
|
|
1567
|
+
errors.append(
|
|
1568
|
+
f"❌ Missing required GEPA sections: {', '.join(f'[prompt_learning.gepa.{s}]' for s in missing_sections)}"
|
|
1569
|
+
)
|
|
1570
|
+
|
|
1571
|
+
# Validate evaluation section
|
|
1572
|
+
eval_section = gepa_section.get("evaluation", {})
|
|
1573
|
+
if isinstance(eval_section, dict):
|
|
1574
|
+
# Check train_seeds (required, can be in eval section or top-level)
|
|
1575
|
+
train_seeds = (
|
|
1576
|
+
eval_section.get("train_seeds") or
|
|
1577
|
+
eval_section.get("seeds") or
|
|
1578
|
+
pl_section.get("train_seeds")
|
|
1579
|
+
)
|
|
1580
|
+
if not train_seeds:
|
|
1581
|
+
errors.append(
|
|
1582
|
+
"❌ train_seeds is required. "
|
|
1583
|
+
"Must be in [prompt_learning.gepa.evaluation].train_seeds or [prompt_learning].train_seeds"
|
|
1584
|
+
)
|
|
1585
|
+
elif not isinstance(train_seeds, list):
|
|
1586
|
+
errors.append(f"❌ train_seeds must be a list, got {type(train_seeds).__name__}")
|
|
1587
|
+
elif len(train_seeds) == 0:
|
|
1588
|
+
errors.append("❌ train_seeds cannot be empty")
|
|
1589
|
+
elif not all(isinstance(s, int) for s in train_seeds):
|
|
1590
|
+
errors.append("❌ train_seeds must contain only integers")
|
|
1591
|
+
|
|
1592
|
+
# Check val_seeds (required)
|
|
1593
|
+
val_seeds = eval_section.get("val_seeds") or eval_section.get("validation_seeds")
|
|
1594
|
+
if not val_seeds:
|
|
1595
|
+
errors.append(
|
|
1596
|
+
"❌ val_seeds is required in [prompt_learning.gepa.evaluation].val_seeds"
|
|
1597
|
+
)
|
|
1598
|
+
elif not isinstance(val_seeds, list):
|
|
1599
|
+
errors.append(f"❌ val_seeds must be a list, got {type(val_seeds).__name__}")
|
|
1600
|
+
elif len(val_seeds) == 0:
|
|
1601
|
+
errors.append("❌ val_seeds cannot be empty")
|
|
1602
|
+
elif not all(isinstance(s, int) for s in val_seeds):
|
|
1603
|
+
errors.append("❌ val_seeds must contain only integers")
|
|
1604
|
+
|
|
1605
|
+
# Check validation_pool (optional but should be valid if present)
|
|
1606
|
+
validation_pool = eval_section.get("validation_pool")
|
|
1607
|
+
if validation_pool is not None:
|
|
1608
|
+
if not isinstance(validation_pool, str):
|
|
1609
|
+
errors.append(f"❌ validation_pool must be a string, got {type(validation_pool).__name__}")
|
|
1610
|
+
elif validation_pool not in ("train", "test", "val", "validation"):
|
|
1611
|
+
errors.append(
|
|
1612
|
+
f"❌ validation_pool must be one of: train, test, val, validation. Got '{validation_pool}'"
|
|
1613
|
+
)
|
|
1614
|
+
|
|
1615
|
+
# Check validation_top_k (optional but should be valid if present)
|
|
1616
|
+
validation_top_k = eval_section.get("validation_top_k")
|
|
1617
|
+
if validation_top_k is not None:
|
|
1618
|
+
if not isinstance(validation_top_k, int):
|
|
1619
|
+
errors.append(f"❌ validation_top_k must be an integer, got {type(validation_top_k).__name__}")
|
|
1620
|
+
elif validation_top_k <= 0:
|
|
1621
|
+
errors.append(f"❌ validation_top_k must be > 0, got {validation_top_k}")
|
|
1622
|
+
|
|
1623
|
+
# Validate rollout section
|
|
1624
|
+
rollout_section = gepa_section.get("rollout", {})
|
|
1625
|
+
if isinstance(rollout_section, dict):
|
|
1626
|
+
budget = rollout_section.get("budget")
|
|
1627
|
+
if budget is None:
|
|
1628
|
+
errors.append("❌ [prompt_learning.gepa.rollout].budget is required")
|
|
1629
|
+
elif not isinstance(budget, int):
|
|
1630
|
+
errors.append(f"❌ rollout.budget must be an integer, got {type(budget).__name__}")
|
|
1631
|
+
elif budget <= 0:
|
|
1632
|
+
errors.append(f"❌ rollout.budget must be > 0, got {budget}")
|
|
1633
|
+
|
|
1634
|
+
max_concurrent = rollout_section.get("max_concurrent")
|
|
1635
|
+
if max_concurrent is not None:
|
|
1636
|
+
if not isinstance(max_concurrent, int):
|
|
1637
|
+
errors.append(f"❌ rollout.max_concurrent must be an integer, got {type(max_concurrent).__name__}")
|
|
1638
|
+
elif max_concurrent <= 0:
|
|
1639
|
+
errors.append(f"❌ rollout.max_concurrent must be > 0, got {max_concurrent}")
|
|
1640
|
+
|
|
1641
|
+
# Validate mutation section
|
|
1642
|
+
mutation_section = gepa_section.get("mutation", {})
|
|
1643
|
+
if isinstance(mutation_section, dict):
|
|
1644
|
+
required_mutation_fields = ["llm_model", "llm_provider"]
|
|
1645
|
+
for field in required_mutation_fields:
|
|
1646
|
+
if not mutation_section.get(field):
|
|
1647
|
+
errors.append(f"❌ [prompt_learning.gepa.mutation].{field} is required")
|
|
1648
|
+
|
|
1649
|
+
rate = mutation_section.get("rate")
|
|
1650
|
+
if rate is not None:
|
|
1651
|
+
if not isinstance(rate, int | float):
|
|
1652
|
+
errors.append(f"❌ mutation.rate must be a number, got {type(rate).__name__}")
|
|
1653
|
+
elif not (0.0 <= rate <= 1.0):
|
|
1654
|
+
errors.append(f"❌ mutation.rate must be between 0.0 and 1.0, got {rate}")
|
|
1655
|
+
|
|
1656
|
+
# Validate population section
|
|
1657
|
+
population_section = gepa_section.get("population", {})
|
|
1658
|
+
if isinstance(population_section, dict):
|
|
1659
|
+
initial_size = population_section.get("initial_size")
|
|
1660
|
+
if initial_size is not None:
|
|
1661
|
+
if not isinstance(initial_size, int):
|
|
1662
|
+
errors.append(f"❌ population.initial_size must be an integer, got {type(initial_size).__name__}")
|
|
1663
|
+
elif initial_size <= 0:
|
|
1664
|
+
errors.append(f"❌ population.initial_size must be > 0, got {initial_size}")
|
|
1665
|
+
|
|
1666
|
+
num_generations = population_section.get("num_generations")
|
|
1667
|
+
if num_generations is not None:
|
|
1668
|
+
if not isinstance(num_generations, int):
|
|
1669
|
+
errors.append(f"❌ population.num_generations must be an integer, got {type(num_generations).__name__}")
|
|
1670
|
+
elif num_generations <= 0:
|
|
1671
|
+
errors.append(f"❌ population.num_generations must be > 0, got {num_generations}")
|
|
1672
|
+
|
|
1673
|
+
# Validate archive section
|
|
1674
|
+
archive_section = gepa_section.get("archive", {})
|
|
1675
|
+
if isinstance(archive_section, dict):
|
|
1676
|
+
max_size = archive_section.get("max_size")
|
|
1677
|
+
if max_size is not None:
|
|
1678
|
+
if not isinstance(max_size, int):
|
|
1679
|
+
errors.append(f"❌ archive.max_size must be an integer, got {type(max_size).__name__}")
|
|
1680
|
+
elif max_size < 0:
|
|
1681
|
+
errors.append(f"❌ archive.max_size must be >= 0, got {max_size}")
|
|
1682
|
+
|
|
1683
|
+
# Validate token section
|
|
1684
|
+
token_section = gepa_section.get("token", {})
|
|
1685
|
+
if isinstance(token_section, dict):
|
|
1686
|
+
max_limit = token_section.get("max_limit")
|
|
1687
|
+
if max_limit is not None:
|
|
1688
|
+
if not isinstance(max_limit, int):
|
|
1689
|
+
errors.append(f"❌ token.max_limit must be an integer, got {type(max_limit).__name__}")
|
|
1690
|
+
elif max_limit <= 0:
|
|
1691
|
+
errors.append(f"❌ token.max_limit must be > 0, got {max_limit}")
|
|
1692
|
+
|
|
1693
|
+
# Check initial_prompt section
|
|
1694
|
+
initial_prompt = pl_section.get("initial_prompt", {})
|
|
1695
|
+
if not isinstance(initial_prompt, dict):
|
|
1696
|
+
errors.append("❌ [prompt_learning.initial_prompt] section is missing or invalid")
|
|
1697
|
+
else:
|
|
1698
|
+
if not initial_prompt.get("id"):
|
|
1699
|
+
errors.append("❌ [prompt_learning.initial_prompt].id is required")
|
|
1700
|
+
if not initial_prompt.get("messages"):
|
|
1701
|
+
errors.append("❌ [prompt_learning.initial_prompt].messages is required (must be a list)")
|
|
1702
|
+
elif not isinstance(initial_prompt.get("messages"), list):
|
|
1703
|
+
errors.append("❌ [prompt_learning.initial_prompt].messages must be a list")
|
|
1704
|
+
elif len(initial_prompt.get("messages", [])) == 0:
|
|
1705
|
+
errors.append("❌ [prompt_learning.initial_prompt].messages cannot be empty")
|
|
1706
|
+
|
|
1707
|
+
# Check policy section
|
|
1708
|
+
policy_section = pl_section.get("policy", {})
|
|
1709
|
+
if not isinstance(policy_section, dict):
|
|
1710
|
+
errors.append("❌ [prompt_learning.policy] section is missing or invalid")
|
|
1711
|
+
else:
|
|
1712
|
+
# Validate policy section - reject inference_url (backend requirement)
|
|
1713
|
+
if "inference_url" in policy_section:
|
|
1714
|
+
errors.append(
|
|
1715
|
+
"❌ inference_url must not be specified in [prompt_learning.policy]. "
|
|
1716
|
+
"The trainer provides the inference URL in rollout requests. "
|
|
1717
|
+
"Remove inference_url from your config file."
|
|
1718
|
+
)
|
|
1719
|
+
if "api_base" in policy_section:
|
|
1720
|
+
errors.append(
|
|
1721
|
+
"❌ api_base must not be specified in [prompt_learning.policy]. "
|
|
1722
|
+
"The trainer provides the inference URL in rollout requests. "
|
|
1723
|
+
"Remove api_base from your config file."
|
|
1724
|
+
)
|
|
1725
|
+
if "base_url" in policy_section:
|
|
1726
|
+
errors.append(
|
|
1727
|
+
"❌ base_url must not be specified in [prompt_learning.policy]. "
|
|
1728
|
+
"The trainer provides the inference URL in rollout requests. "
|
|
1729
|
+
"Remove base_url from your config file."
|
|
1730
|
+
)
|
|
1731
|
+
|
|
1732
|
+
if not policy_section.get("model"):
|
|
1733
|
+
errors.append("❌ [prompt_learning.policy].model is required")
|
|
1734
|
+
if not policy_section.get("provider"):
|
|
1735
|
+
errors.append("❌ [prompt_learning.policy].provider is required")
|
|
1736
|
+
|
|
1737
|
+
# Validate proxy_models section (can be at top-level or gepa-specific)
|
|
1738
|
+
proxy_models_section = pl_section.get("proxy_models") or gepa_section.get("proxy_models")
|
|
1739
|
+
if proxy_models_section:
|
|
1740
|
+
if not isinstance(proxy_models_section, dict):
|
|
1741
|
+
errors.append("❌ proxy_models must be a table/dict when provided")
|
|
1742
|
+
else:
|
|
1743
|
+
required_fields = ["hi_provider", "hi_model", "lo_provider", "lo_model"]
|
|
1744
|
+
for field in required_fields:
|
|
1745
|
+
if not proxy_models_section.get(field):
|
|
1746
|
+
errors.append(f"❌ proxy_models.{field} is required")
|
|
1747
|
+
# Validate numeric fields
|
|
1748
|
+
for field, min_val in [("n_min_hi", 0), ("r2_thresh", 0.0), ("r2_stop", 0.0), ("sigma_max", 0.0), ("sigma_stop", 0.0), ("verify_every", 0)]:
|
|
1749
|
+
val = proxy_models_section.get(field)
|
|
1750
|
+
if val is not None:
|
|
1751
|
+
try:
|
|
1752
|
+
if field in ("r2_thresh", "r2_stop"):
|
|
1753
|
+
fval = float(val)
|
|
1754
|
+
if not (0.0 <= fval <= 1.0):
|
|
1755
|
+
errors.append(f"❌ proxy_models.{field} must be between 0.0 and 1.0, got {fval}")
|
|
1756
|
+
elif field.startswith("sigma"):
|
|
1757
|
+
fval = float(val)
|
|
1758
|
+
if fval < min_val:
|
|
1759
|
+
errors.append(f"❌ proxy_models.{field} must be >= {min_val}, got {fval}")
|
|
1760
|
+
else:
|
|
1761
|
+
ival = int(val)
|
|
1762
|
+
if ival < min_val:
|
|
1763
|
+
errors.append(f"❌ proxy_models.{field} must be >= {min_val}, got {ival}")
|
|
1764
|
+
except (TypeError, ValueError):
|
|
1765
|
+
errors.append(f"❌ proxy_models.{field} must be numeric, got {type(val).__name__}")
|
|
1766
|
+
# Validate provider/model combinations
|
|
1767
|
+
if proxy_models_section.get("hi_provider") and proxy_models_section.get("hi_model"):
|
|
1768
|
+
hi_errors = _validate_model_for_provider(
|
|
1769
|
+
proxy_models_section["hi_model"],
|
|
1770
|
+
proxy_models_section["hi_provider"],
|
|
1771
|
+
"proxy_models.hi_model",
|
|
1772
|
+
allow_nano=True,
|
|
1773
|
+
)
|
|
1774
|
+
errors.extend(hi_errors)
|
|
1775
|
+
if proxy_models_section.get("lo_provider") and proxy_models_section.get("lo_model"):
|
|
1776
|
+
lo_errors = _validate_model_for_provider(
|
|
1777
|
+
proxy_models_section["lo_model"],
|
|
1778
|
+
proxy_models_section["lo_provider"],
|
|
1779
|
+
"proxy_models.lo_model",
|
|
1780
|
+
allow_nano=True,
|
|
1781
|
+
)
|
|
1782
|
+
errors.extend(lo_errors)
|
|
1783
|
+
|
|
1784
|
+
# Validate adaptive_pool section (GEPA-specific)
|
|
1785
|
+
adaptive_pool_section = gepa_section.get("adaptive_pool")
|
|
1786
|
+
if adaptive_pool_section:
|
|
1787
|
+
_validate_adaptive_pool_config(adaptive_pool_section, "gepa.adaptive_pool", errors)
|
|
1788
|
+
|
|
1789
|
+
# Validate adaptive_batch section (GEPA-specific)
|
|
1790
|
+
adaptive_batch_section = gepa_section.get("adaptive_batch")
|
|
1791
|
+
if adaptive_batch_section:
|
|
1792
|
+
if not isinstance(adaptive_batch_section, dict):
|
|
1793
|
+
errors.append("❌ gepa.adaptive_batch must be a table/dict when provided")
|
|
1794
|
+
else:
|
|
1795
|
+
level = adaptive_batch_section.get("level")
|
|
1796
|
+
if level is not None:
|
|
1797
|
+
valid_levels = {"NONE", "LOW", "MODERATE", "HIGH"}
|
|
1798
|
+
if str(level).upper() not in valid_levels:
|
|
1799
|
+
errors.append(
|
|
1800
|
+
f"❌ gepa.adaptive_batch.level must be one of {valid_levels}, got '{level}'"
|
|
1801
|
+
)
|
|
1802
|
+
# Validate numeric fields
|
|
1803
|
+
for field, min_val in [
|
|
1804
|
+
("reflection_minibatch_size", 1),
|
|
1805
|
+
("val_subsample_size", 1),
|
|
1806
|
+
]:
|
|
1807
|
+
val = adaptive_batch_section.get(field)
|
|
1808
|
+
if val is not None:
|
|
1809
|
+
try:
|
|
1810
|
+
ival = int(val)
|
|
1811
|
+
if ival < min_val:
|
|
1812
|
+
errors.append(f"❌ gepa.adaptive_batch.{field} must be >= {min_val}, got {ival}")
|
|
1813
|
+
except (TypeError, ValueError):
|
|
1814
|
+
errors.append(f"❌ gepa.adaptive_batch.{field} must be an integer, got {type(val).__name__}")
|
|
1815
|
+
# Validate min_local_improvement
|
|
1816
|
+
min_improvement = adaptive_batch_section.get("min_local_improvement")
|
|
1817
|
+
if min_improvement is not None:
|
|
1818
|
+
try:
|
|
1819
|
+
float(min_improvement) # Just validate it's numeric
|
|
1820
|
+
except (TypeError, ValueError):
|
|
1821
|
+
errors.append(
|
|
1822
|
+
f"❌ gepa.adaptive_batch.min_local_improvement must be numeric, got {type(min_improvement).__name__}"
|
|
1823
|
+
)
|
|
1824
|
+
# Validate val_evaluation_mode
|
|
1825
|
+
val_mode = adaptive_batch_section.get("val_evaluation_mode")
|
|
1826
|
+
if val_mode is not None and val_mode not in ("full", "subsample"):
|
|
1827
|
+
errors.append(
|
|
1828
|
+
f"❌ gepa.adaptive_batch.val_evaluation_mode must be 'full' or 'subsample', got '{val_mode}'"
|
|
1829
|
+
)
|
|
1830
|
+
# Validate candidate_selection_strategy
|
|
1831
|
+
selection_strategy = adaptive_batch_section.get("candidate_selection_strategy")
|
|
1832
|
+
if selection_strategy is not None and selection_strategy not in ("coverage", "random"):
|
|
1833
|
+
errors.append(
|
|
1834
|
+
f"❌ gepa.adaptive_batch.candidate_selection_strategy must be 'coverage' or 'random', got '{selection_strategy}'"
|
|
1835
|
+
)
|
|
1836
|
+
# Validate val_evaluation_mode="subsample" requires val_subsample_size > 0
|
|
1837
|
+
val_mode = adaptive_batch_section.get("val_evaluation_mode")
|
|
1838
|
+
if val_mode == "subsample":
|
|
1839
|
+
subsample_size = adaptive_batch_section.get("val_subsample_size")
|
|
1840
|
+
if subsample_size is None:
|
|
1841
|
+
errors.append(
|
|
1842
|
+
"❌ gepa.adaptive_batch.val_evaluation_mode='subsample' requires val_subsample_size to be set"
|
|
1843
|
+
)
|
|
1844
|
+
elif isinstance(subsample_size, int | float) and subsample_size <= 0:
|
|
1845
|
+
errors.append(
|
|
1846
|
+
f"❌ gepa.adaptive_batch.val_subsample_size must be > 0 when val_evaluation_mode='subsample', got {subsample_size}"
|
|
1847
|
+
)
|
|
1848
|
+
|
|
1849
|
+
return len(errors) == 0, errors
|
|
1850
|
+
|
|
1851
|
+
|
|
1852
|
+
def validate_mipro_config_from_file(config_path: Path) -> Tuple[bool, List[str]]:
    """Validate MIPRO config from TOML file with comprehensive checks.

    Parses the TOML at ``config_path`` and accumulates every problem found
    (rather than stopping at the first) so a user can fix the config in one
    pass. Only hard structural failures short-circuit: unparseable TOML, or a
    missing/invalid ``[prompt_learning]`` / ``[prompt_learning.mipro]`` table.

    Args:
        config_path: Path to the TOML config file.

    Returns:
        (is_valid, errors) tuple where errors is a list of error messages
    """
    errors: List[str] = []

    # Hard failure: an unparseable file yields a single explanatory error.
    try:
        with open(config_path) as f:
            config_dict = toml.load(f)
    except Exception as e:
        return False, [f"Failed to parse TOML: {e}"]

    pl_section = config_dict.get("prompt_learning", {})
    if not isinstance(pl_section, dict):
        errors.append("❌ [prompt_learning] section is missing or invalid")
        return False, errors

    # Check algorithm
    algorithm = pl_section.get("algorithm")
    if algorithm != "mipro":
        errors.append(f"❌ Expected algorithm='mipro', got '{algorithm}'")

    # Check required top-level fields
    required_top_level = ["task_app_url", "task_app_api_key"]
    for field in required_top_level:
        if not pl_section.get(field):
            errors.append(f"❌ [prompt_learning].{field} is required")

    # Check env_name (required - can be at top level or in mipro section)
    env_name = pl_section.get("env_name") or pl_section.get("task_app_id")
    mipro_section = pl_section.get("mipro", {})
    if isinstance(mipro_section, dict):
        env_name = env_name or mipro_section.get("env_name")
    if not env_name:
        errors.append(
            "❌ env_name is required. "
            "Must be in [prompt_learning].env_name, [prompt_learning].task_app_id, or [prompt_learning.mipro].env_name"
        )

    # Check MIPRO section
    if not isinstance(mipro_section, dict):
        errors.append("❌ [prompt_learning.mipro] section is missing or invalid")
        return False, errors

    # Validate policy section - reject inference_url (the trainer injects the
    # inference endpoint per rollout, so a hard-coded URL would be ignored or wrong).
    policy_section = pl_section.get("policy", {})
    if isinstance(policy_section, dict):
        if "inference_url" in policy_section:
            errors.append(
                "❌ inference_url must not be specified in [prompt_learning.policy]. "
                "The trainer provides the inference URL in rollout requests. "
                "Remove inference_url from your config file."
            )
        if "api_base" in policy_section:
            errors.append(
                "❌ api_base must not be specified in [prompt_learning.policy]. "
                "The trainer provides the inference URL in rollout requests. "
                "Remove api_base from your config file."
            )
        if "base_url" in policy_section:
            errors.append(
                "❌ base_url must not be specified in [prompt_learning.policy]. "
                "The trainer provides the inference URL in rollout requests. "
                "Remove base_url from your config file."
            )

    # CRITICAL: Validate bootstrap_train_seeds and online_pool (can be at top level or under mipro)
    bootstrap_seeds = (
        mipro_section.get("bootstrap_train_seeds") or
        pl_section.get("bootstrap_train_seeds")
    )
    if not bootstrap_seeds:
        errors.append(
            "❌ bootstrap_train_seeds is required. "
            "Must be in [prompt_learning].bootstrap_train_seeds or [prompt_learning.mipro].bootstrap_train_seeds"
        )
    elif not isinstance(bootstrap_seeds, list):
        errors.append(f"❌ bootstrap_train_seeds must be a list, got {type(bootstrap_seeds).__name__}")
    elif len(bootstrap_seeds) == 0:
        errors.append("❌ bootstrap_train_seeds cannot be empty")
    elif not all(isinstance(s, int) for s in bootstrap_seeds):
        errors.append("❌ bootstrap_train_seeds must contain only integers")

    online_pool = (
        mipro_section.get("online_pool") or
        pl_section.get("online_pool")
    )
    if not online_pool:
        errors.append(
            "❌ online_pool is required. "
            "Must be in [prompt_learning].online_pool or [prompt_learning.mipro].online_pool"
        )
    elif not isinstance(online_pool, list):
        errors.append(f"❌ online_pool must be a list, got {type(online_pool).__name__}")
    elif len(online_pool) == 0:
        errors.append("❌ online_pool cannot be empty")
    elif not all(isinstance(s, int) for s in online_pool):
        errors.append("❌ online_pool must contain only integers")

    # CRITICAL: Validate reference_pool is required (backend requires it)
    reference_pool = (
        mipro_section.get("reference_pool") or
        pl_section.get("reference_pool")
    )
    if not reference_pool:
        errors.append(
            "❌ reference_pool is required for MIPRO. "
            "reference_pool seeds are used to build the reference corpus for meta-prompt context. "
            "Add reference_pool at [prompt_learning] or [prompt_learning.mipro] level. "
            "Example: reference_pool = [30, 31, 32, 33, 34]"
        )
    elif not isinstance(reference_pool, list):
        errors.append(f"❌ reference_pool must be a list, got {type(reference_pool).__name__}")
    elif len(reference_pool) == 0:
        errors.append("❌ reference_pool cannot be empty")
    elif not all(isinstance(s, int) for s in reference_pool):
        errors.append("❌ reference_pool must contain only integers")
    else:
        # Validate reference pool doesn't overlap with bootstrap/online/test pools
        # (reference seeds must stay held out of training/evaluation).
        test_pool = (
            mipro_section.get("test_pool") or
            pl_section.get("test_pool") or
            []
        )
        all_train_test = set(bootstrap_seeds or []) | set(online_pool or []) | set(test_pool)
        overlapping = set(reference_pool) & all_train_test
        if overlapping:
            errors.append(
                f"❌ reference_pool seeds must not overlap with bootstrap/online/test pools. "
                f"Found overlapping seeds: {sorted(overlapping)}"
            )

    # Validate required numeric fields (must be positive integers)
    required_numeric_fields = [
        "num_iterations",
        "num_evaluations_per_iteration",
        "batch_size",
        "max_concurrent",
    ]
    for field in required_numeric_fields:
        val = mipro_section.get(field)
        if val is None:
            errors.append(f"❌ [prompt_learning.mipro].{field} is required")
        elif not isinstance(val, int):
            errors.append(f"❌ mipro.{field} must be an integer, got {type(val).__name__}")
        elif val <= 0:
            errors.append(f"❌ mipro.{field} must be > 0, got {val}")

    # Validate optional numeric fields; the bool flags whether > 0 is required
    # (vs merely >= 0).
    optional_numeric_fields = [
        ("max_demo_set_size", True),
        ("max_demo_sets", True),
        ("max_instruction_sets", True),
        ("full_eval_every_k", True),
        ("instructions_per_batch", True),
        ("max_instructions", True),
        ("duplicate_retry_limit", True),
    ]
    for field, must_be_positive in optional_numeric_fields:
        val = mipro_section.get(field)
        if val is not None:
            if not isinstance(val, int):
                errors.append(f"❌ mipro.{field} must be an integer, got {type(val).__name__}")
            elif must_be_positive and val <= 0:
                errors.append(f"❌ mipro.{field} must be > 0, got {val}")
            elif not must_be_positive and val < 0:
                errors.append(f"❌ mipro.{field} must be >= 0, got {val}")

    # Validate meta_model if set (optional - backend applies defaults)
    meta_model = mipro_section.get("meta_model")
    # str(... or "") guards against non-string TOML values (e.g. an integer),
    # which would otherwise crash the .strip() call with AttributeError.
    meta_model_provider = str(mipro_section.get("meta_model_provider") or "").strip()
    if meta_model:
        # If meta_model is explicitly set, validate it
        if not meta_model_provider:
            errors.append(
                "❌ [prompt_learning.mipro].meta_model_provider is required when meta_model is set"
            )
        else:
            errors.extend(_validate_model_for_provider(
                meta_model, meta_model_provider, "prompt_learning.mipro.meta_model", allow_nano=False
            ))
    # If meta_model is not set, backend will use defaults (llama-3.3-70b-versatile/groq)

    # Validate meta model temperature
    meta_temperature = mipro_section.get("meta_model_temperature")
    if meta_temperature is not None:
        if not isinstance(meta_temperature, int | float):
            errors.append(f"❌ mipro.meta_model_temperature must be numeric, got {type(meta_temperature).__name__}")
        else:
            temp = float(meta_temperature)
            if temp < 0.0:
                errors.append(f"❌ mipro.meta_model_temperature must be >= 0.0, got {temp}")

    # Validate meta model max_tokens (single consolidated check; the original
    # validated this field twice and could emit duplicate errors for one value).
    meta_max_tokens = mipro_section.get("meta_model_max_tokens")
    if meta_max_tokens is not None:
        if not isinstance(meta_max_tokens, int):
            errors.append(f"❌ mipro.meta_model_max_tokens must be an integer, got {type(meta_max_tokens).__name__}")
        elif meta_max_tokens <= 0:
            errors.append(f"❌ mipro.meta_model_max_tokens must be > 0, got {meta_max_tokens}")

    # Validate proposer_effort (can be in instructions section or top-level mipro section)
    instructions_section = mipro_section.get("instructions", {})
    if not isinstance(instructions_section, dict):
        instructions_section = {}
    proposer_effort = str(
        instructions_section.get("proposer_effort") or
        mipro_section.get("proposer_effort") or
        "LOW"
    ).upper()
    valid_effort_levels = {"LOW_CONTEXT", "LOW", "MEDIUM", "HIGH"}
    if proposer_effort not in valid_effort_levels:
        errors.append(
            f"❌ Invalid proposer_effort: '{proposer_effort}'\n"
            f"   Must be one of: {', '.join(sorted(valid_effort_levels))}\n"
            f"   Got: '{proposer_effort}'"
        )

    # Validate proposer_output_tokens (can be in instructions section or top-level mipro section)
    proposer_output_tokens = str(
        instructions_section.get("proposer_output_tokens") or
        mipro_section.get("proposer_output_tokens") or
        "FAST"
    ).upper()
    valid_output_tokens = {"RAPID", "FAST", "SLOW"}
    if proposer_output_tokens not in valid_output_tokens:
        errors.append(
            f"❌ Invalid proposer_output_tokens: '{proposer_output_tokens}'\n"
            f"   Must be one of: {', '.join(sorted(valid_output_tokens))}\n"
            f"   Got: '{proposer_output_tokens}'"
        )

    # Note: RAPID can now be used with any proposer_effort level (5000 tokens)

    # Validate generate_at_iterations: optional list of non-negative iteration indices
    generate_at_iterations = mipro_section.get("generate_at_iterations")
    if generate_at_iterations is not None:
        if not isinstance(generate_at_iterations, list):
            errors.append(f"❌ mipro.generate_at_iterations must be a list, got {type(generate_at_iterations).__name__}")
        else:
            for idx, iter_val in enumerate(generate_at_iterations):
                try:
                    iter_int = int(iter_val)
                    if iter_int < 0:
                        errors.append(
                            f"❌ mipro.generate_at_iterations[{idx}] must be >= 0, got {iter_int}"
                        )
                except Exception:
                    errors.append(
                        f"❌ mipro.generate_at_iterations[{idx}] must be an integer, got {iter_val!r}"
                    )

    # Validate spec configuration (spec_* fields only matter when a spec_path is given)
    spec_path = mipro_section.get("spec_path")
    if spec_path:
        # Validate spec_max_tokens if provided
        spec_max_tokens = mipro_section.get("spec_max_tokens")
        if spec_max_tokens is not None:
            if not isinstance(spec_max_tokens, int):
                errors.append(f"❌ mipro.spec_max_tokens must be an integer, got {type(spec_max_tokens).__name__}")
            elif spec_max_tokens <= 0:
                errors.append(f"❌ mipro.spec_max_tokens must be > 0, got {spec_max_tokens}")

        # Validate spec_priority_threshold if provided
        spec_priority_threshold = mipro_section.get("spec_priority_threshold")
        if spec_priority_threshold is not None:
            if not isinstance(spec_priority_threshold, int):
                errors.append(f"❌ mipro.spec_priority_threshold must be an integer, got {type(spec_priority_threshold).__name__}")
            elif spec_priority_threshold < 0:
                errors.append(f"❌ mipro.spec_priority_threshold must be >= 0, got {spec_priority_threshold}")

    # Validate few_shot_score_threshold (a score, so must lie in [0, 1])
    few_shot_score_threshold = mipro_section.get("few_shot_score_threshold")
    if few_shot_score_threshold is not None:
        if not isinstance(few_shot_score_threshold, int | float):
            errors.append(f"❌ mipro.few_shot_score_threshold must be numeric, got {type(few_shot_score_threshold).__name__}")
        else:
            threshold = float(few_shot_score_threshold)
            if not (0.0 <= threshold <= 1.0):
                errors.append(f"❌ mipro.few_shot_score_threshold must be between 0.0 and 1.0, got {threshold}")

    # Validate modules/stages configuration
    modules_config = mipro_section.get("modules")
    if modules_config is not None:
        if not isinstance(modules_config, list):
            errors.append(f"❌ mipro.modules must be a list, got {type(modules_config).__name__}")
        else:
            # Slot counts per stage may not exceed the global set budgets.
            max_instruction_sets = mipro_section.get("max_instruction_sets", 128)
            max_demo_sets = mipro_section.get("max_demo_sets", 128)
            seen_module_ids = set()
            # NOTE: stage_ids are required to be unique across ALL modules,
            # not just within one module.
            seen_stage_ids = set()

            for module_idx, module_entry in enumerate(modules_config):
                if not isinstance(module_entry, dict):
                    errors.append(
                        f"❌ mipro.modules[{module_idx}] must be a table/dict, got {type(module_entry).__name__}"
                    )
                    continue

                module_id = module_entry.get("module_id") or module_entry.get("id") or f"module_{module_idx}"
                if module_id in seen_module_ids:
                    errors.append(
                        f"❌ Duplicate module_id '{module_id}' in mipro.modules"
                    )
                seen_module_ids.add(module_id)

                # Validate stages
                stages = module_entry.get("stages")
                if stages is not None:
                    if not isinstance(stages, list):
                        errors.append(
                            f"❌ mipro.modules[{module_idx}].stages must be a list, got {type(stages).__name__}"
                        )
                    else:
                        for stage_idx, stage_entry in enumerate(stages):
                            if isinstance(stage_entry, dict):
                                stage_id = stage_entry.get("stage_id") or stage_entry.get("module_stage_id") or f"stage_{stage_idx}"
                                if stage_id in seen_stage_ids:
                                    errors.append(
                                        f"❌ Duplicate stage_id '{stage_id}' across modules"
                                    )
                                seen_stage_ids.add(stage_id)

                                # Validate max_instruction_slots <= max_instruction_sets
                                max_instr_slots = stage_entry.get("max_instruction_slots")
                                if max_instr_slots is not None:
                                    try:
                                        mis = int(max_instr_slots)
                                        if mis < 1:
                                            errors.append(
                                                f"❌ mipro.modules[{module_idx}].stages[{stage_idx}].max_instruction_slots must be >= 1, got {mis}"
                                            )
                                        elif mis > max_instruction_sets:
                                            errors.append(
                                                f"❌ mipro.modules[{module_idx}].stages[{stage_idx}].max_instruction_slots ({mis}) "
                                                f"exceeds max_instruction_sets ({max_instruction_sets})"
                                            )
                                    except Exception:
                                        errors.append(
                                            f"❌ mipro.modules[{module_idx}].stages[{stage_idx}].max_instruction_slots must be an integer"
                                        )

                                # Validate max_demo_slots <= max_demo_sets
                                max_demo_slots = stage_entry.get("max_demo_slots")
                                if max_demo_slots is not None:
                                    try:
                                        mds = int(max_demo_slots)
                                        if mds < 0:
                                            errors.append(
                                                f"❌ mipro.modules[{module_idx}].stages[{stage_idx}].max_demo_slots must be >= 0, got {mds}"
                                            )
                                        elif mds > max_demo_sets:
                                            errors.append(
                                                f"❌ mipro.modules[{module_idx}].stages[{stage_idx}].max_demo_slots ({mds}) "
                                                f"exceeds max_demo_sets ({max_demo_sets})"
                                            )
                                    except Exception:
                                        errors.append(
                                            f"❌ mipro.modules[{module_idx}].stages[{stage_idx}].max_demo_slots must be an integer"
                                        )

                                # Validate per-stage policy config (REQUIRED)
                                stage_policy = stage_entry.get("policy")
                                if stage_policy is None:
                                    errors.append(
                                        f"❌ mipro.modules[{module_idx}].stages[{stage_idx}]: [policy] table is REQUIRED. "
                                        f"Each stage must have its own policy configuration with 'model' and 'provider' fields."
                                    )
                                elif not isinstance(stage_policy, dict):
                                    errors.append(
                                        f"❌ mipro.modules[{module_idx}].stages[{stage_idx}]: [policy] must be a table/dict, got {type(stage_policy).__name__}"
                                    )
                                else:
                                    # Validate required fields in stage policy
                                    if not stage_policy.get("model"):
                                        errors.append(
                                            f"❌ mipro.modules[{module_idx}].stages[{stage_idx}]: [policy].model is required"
                                        )
                                    if not stage_policy.get("provider"):
                                        errors.append(
                                            f"❌ mipro.modules[{module_idx}].stages[{stage_idx}]: [policy].provider is required"
                                        )
                                    # Validate model/provider combination
                                    stage_model = stage_policy.get("model")
                                    stage_provider = stage_policy.get("provider")
                                    if stage_model and stage_provider:
                                        errors.extend(_validate_model_for_provider(
                                            stage_model, stage_provider,
                                            f"prompt_learning.mipro.modules[{module_idx}].stages[{stage_idx}].policy.model",
                                            allow_nano=True,  # Policy models can be nano
                                        ))
                                    # Reject inference_url in stage policy (trainer provides it)
                                    if "inference_url" in stage_policy:
                                        errors.append(
                                            f"❌ mipro.modules[{module_idx}].stages[{stage_idx}]: [policy].inference_url must not be specified. "
                                            f"The trainer provides the inference URL in rollout requests. Remove inference_url from stage policy."
                                        )
                                    if "api_base" in stage_policy:
                                        errors.append(
                                            f"❌ mipro.modules[{module_idx}].stages[{stage_idx}]: [policy].api_base must not be specified. "
                                            f"Remove api_base from stage policy."
                                        )
                                    if "base_url" in stage_policy:
                                        errors.append(
                                            f"❌ mipro.modules[{module_idx}].stages[{stage_idx}]: [policy].base_url must not be specified. "
                                            f"Remove base_url from stage policy."
                                        )

                # Validate edges reference valid stages
                edges = module_entry.get("edges")
                if edges is not None:
                    if not isinstance(edges, list):
                        errors.append(
                            f"❌ mipro.modules[{module_idx}].edges must be a list, got {type(edges).__name__}"
                        )
                    else:
                        stage_ids_in_module = set()
                        if stages and isinstance(stages, list):
                            for stage_entry in stages:
                                if isinstance(stage_entry, dict):
                                    sid = stage_entry.get("stage_id") or stage_entry.get("module_stage_id")
                                    if sid:
                                        stage_ids_in_module.add(str(sid))

                        for edge_idx, edge in enumerate(edges):
                            # An edge may be written as a 2-item array or as a
                            # mapping with from/source and to/target keys.
                            if isinstance(edge, list | tuple) and len(edge) == 2:
                                source, target = edge
                            elif isinstance(edge, dict):
                                source = edge.get("from") or edge.get("source")
                                target = edge.get("to") or edge.get("target")
                            else:
                                errors.append(
                                    f"❌ mipro.modules[{module_idx}].edges[{edge_idx}] must be a pair or mapping"
                                )
                                continue

                            source_str = str(source or "").strip()
                            target_str = str(target or "").strip()
                            if source_str and source_str not in stage_ids_in_module:
                                errors.append(
                                    f"❌ mipro.modules[{module_idx}].edges[{edge_idx}] references unknown source stage '{source_str}'"
                                )
                            if target_str and target_str not in stage_ids_in_module:
                                errors.append(
                                    f"❌ mipro.modules[{module_idx}].edges[{edge_idx}] references unknown target stage '{target_str}'"
                                )

    # Check initial_prompt section
    initial_prompt = pl_section.get("initial_prompt", {})
    if not isinstance(initial_prompt, dict):
        errors.append("❌ [prompt_learning.initial_prompt] section is missing or invalid")
    else:
        if not initial_prompt.get("id"):
            errors.append("❌ [prompt_learning.initial_prompt].id is required")
        if not initial_prompt.get("messages"):
            errors.append("❌ [prompt_learning.initial_prompt].messages is required (must be a list)")
        elif not isinstance(initial_prompt.get("messages"), list):
            errors.append("❌ [prompt_learning.initial_prompt].messages must be a list")
        elif len(initial_prompt.get("messages", [])) == 0:
            errors.append("❌ [prompt_learning.initial_prompt].messages cannot be empty")

    # Check policy section
    if not isinstance(policy_section, dict):
        errors.append("❌ [prompt_learning.policy] section is missing or invalid")
    else:
        if not policy_section.get("model"):
            errors.append("❌ [prompt_learning.policy].model is required")
        if not policy_section.get("provider"):
            errors.append("❌ [prompt_learning.policy].provider is required")

    # Validate proxy_models section (can be at top-level or mipro-specific)
    proxy_models_section = pl_section.get("proxy_models") or mipro_section.get("proxy_models")
    if proxy_models_section:
        if not isinstance(proxy_models_section, dict):
            errors.append("❌ proxy_models must be a table/dict when provided")
        else:
            required_fields = ["hi_provider", "hi_model", "lo_provider", "lo_model"]
            for field in required_fields:
                if not proxy_models_section.get(field):
                    errors.append(f"❌ proxy_models.{field} is required")
            # Validate numeric fields (same as GEPA): r2_* are ratios in [0, 1],
            # sigma_* are non-negative floats, the rest are non-negative ints.
            for field, min_val in [("n_min_hi", 0), ("r2_thresh", 0.0), ("r2_stop", 0.0), ("sigma_max", 0.0), ("sigma_stop", 0.0), ("verify_every", 0)]:
                val = proxy_models_section.get(field)
                if val is not None:
                    try:
                        if field in ("r2_thresh", "r2_stop"):
                            fval = float(val)
                            if not (0.0 <= fval <= 1.0):
                                errors.append(f"❌ proxy_models.{field} must be between 0.0 and 1.0, got {fval}")
                        elif field.startswith("sigma"):
                            fval = float(val)
                            if fval < min_val:
                                errors.append(f"❌ proxy_models.{field} must be >= {min_val}, got {fval}")
                        else:
                            ival = int(val)
                            if ival < min_val:
                                errors.append(f"❌ proxy_models.{field} must be >= {min_val}, got {ival}")
                    except (TypeError, ValueError):
                        errors.append(f"❌ proxy_models.{field} must be numeric, got {type(val).__name__}")
            # Validate provider/model combinations
            if proxy_models_section.get("hi_provider") and proxy_models_section.get("hi_model"):
                hi_errors = _validate_model_for_provider(
                    proxy_models_section["hi_model"],
                    proxy_models_section["hi_provider"],
                    "proxy_models.hi_model",
                    allow_nano=True,
                )
                errors.extend(hi_errors)
            if proxy_models_section.get("lo_provider") and proxy_models_section.get("lo_model"):
                lo_errors = _validate_model_for_provider(
                    proxy_models_section["lo_model"],
                    proxy_models_section["lo_provider"],
                    "proxy_models.lo_model",
                    allow_nano=True,
                )
                errors.extend(lo_errors)

    # Validate adaptive_pool section (MIPRO-specific, can be nested or flat)
    adaptive_pool_section = mipro_section.get("adaptive_pool")
    if adaptive_pool_section:
        _validate_adaptive_pool_config(adaptive_pool_section, "mipro.adaptive_pool", errors)

    return len(errors) == 0, errors
|
|
2384
|
+
|
|
2385
|
+
|
|
2386
|
+
def validate_prompt_learning_config_from_file(config_path: Path, algorithm: str) -> None:
    """Validate prompt learning config from TOML file and raise ConfigValidationError if invalid.

    Dispatches to the algorithm-specific file validator and, on failure,
    raises with all collected error messages wrapped in a banner.

    Args:
        config_path: Path to TOML config file
        algorithm: Either 'gepa' or 'mipro'

    Raises:
        ConfigValidationError: If validation fails, with detailed error messages
        ValueError: If ``algorithm`` is not a recognized algorithm name
    """
    ctx: dict[str, Any] = {"config_path": str(config_path), "algorithm": algorithm}
    log_info("validate_prompt_learning_config_from_file invoked", ctx=ctx)

    # Map each supported algorithm to its dedicated validator.
    validators = {
        "gepa": validate_gepa_config_from_file,
        "mipro": validate_mipro_config_from_file,
    }
    validator = validators.get(algorithm)
    if validator is None:
        raise ValueError(f"Unknown algorithm: {algorithm}. Must be 'gepa' or 'mipro'")

    is_valid, errors = validator(config_path)
    if is_valid:
        return

    combined = "\n".join(errors)
    raise ConfigValidationError(
        f"\n{'=' * 80}\n"
        f"❌ Config Validation Failed ({algorithm.upper()})\n"
        f"{'=' * 80}\n"
        f"{combined}\n"
        f"{'=' * 80}\n"
    )
|
|
2414
|
+
|
|
2415
|
+
|
|
2416
|
+
# Public API of this module: the names exported by `from <module> import *`
# and the supported entry points for config validation.
__all__ = [
    "ConfigValidationError",
    "validate_prompt_learning_config",
    "validate_prompt_learning_config_from_file",
    "validate_gepa_config_from_file",
    "validate_mipro_config_from_file",
    "validate_rl_config",
    "validate_sft_config",
]
|