synth-ai 0.2.9.dev4__py3-none-any.whl → 0.2.9.dev6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of synth-ai might be problematic. Click here for more details.
- examples/__init__.py +16 -0
- examples/crafter_debug_render.py +23 -17
- examples/qwen_coder/README.md +102 -0
- examples/qwen_coder/_shared.py +113 -0
- examples/qwen_coder/configs/coder_lora_30b.toml +61 -0
- examples/qwen_coder/configs/coder_lora_4b.toml +57 -0
- examples/qwen_coder/configs/coder_lora_small.toml +58 -0
- examples/qwen_coder/generate_dataset.py +98 -0
- examples/qwen_coder/infer_ft_smoke.py +64 -0
- examples/qwen_coder/infer_prod_proxy.py +73 -0
- examples/qwen_coder/infer_via_synth.py +87 -0
- examples/qwen_coder/scripts/infer_coder.sh +18 -0
- examples/qwen_coder/scripts/train_coder_30b.sh +21 -0
- examples/qwen_coder/sft_full_17b.py +103 -0
- examples/qwen_coder/sft_lora_30b.py +110 -0
- examples/qwen_coder/subset_jsonl.py +38 -0
- examples/qwen_coder/validate_jsonl.py +59 -0
- examples/rl/configs/eval_base_qwen.toml +1 -1
- examples/rl/configs/rl_from_base_qwen17.toml +1 -1
- examples/rl/download_dataset.py +26 -10
- examples/rl/run_eval.py +53 -52
- examples/rl/run_rl_and_save.py +29 -12
- examples/rl/task_app/math_single_step.py +180 -41
- examples/rl/task_app/math_task_app.py +14 -6
- examples/sft/README.md +139 -0
- examples/sft/configs/crafter_fft_qwen0p6b.toml +44 -0
- examples/sft/configs/crafter_lora_qwen0p6b.toml +45 -0
- examples/sft/evaluate.py +117 -0
- examples/sft/export_dataset.py +117 -0
- examples/sft/generate_traces.py +162 -0
- examples/swe/__init__.py +12 -0
- examples/swe/task_app/README.md +105 -0
- examples/swe/task_app/__init__.py +2 -0
- examples/swe/task_app/grpo_swe_mini.py +571 -0
- examples/swe/task_app/grpo_swe_mini_task_app.py +136 -0
- examples/swe/task_app/hosted/README.md +173 -0
- examples/swe/task_app/hosted/__init__.py +5 -0
- examples/swe/task_app/hosted/branching.py +143 -0
- examples/swe/task_app/hosted/environment_routes.py +1289 -0
- examples/swe/task_app/hosted/envs/__init__.py +1 -0
- examples/swe/task_app/hosted/envs/crafter/__init__.py +6 -0
- examples/swe/task_app/hosted/envs/crafter/app.py +1 -0
- examples/swe/task_app/hosted/envs/crafter/environment.py +522 -0
- examples/swe/task_app/hosted/envs/crafter/policy.py +478 -0
- examples/swe/task_app/hosted/envs/crafter/react_agent.py +108 -0
- examples/swe/task_app/hosted/envs/crafter/shared.py +305 -0
- examples/swe/task_app/hosted/envs/crafter/tools.py +47 -0
- examples/swe/task_app/hosted/envs/mini_swe/__init__.py +8 -0
- examples/swe/task_app/hosted/envs/mini_swe/environment.py +1164 -0
- examples/swe/task_app/hosted/envs/mini_swe/policy.py +355 -0
- examples/swe/task_app/hosted/envs/mini_swe/shared.py +83 -0
- examples/swe/task_app/hosted/envs/mini_swe/tools.py +96 -0
- examples/swe/task_app/hosted/hosted_app.py +204 -0
- examples/swe/task_app/hosted/inference/__init__.py +5 -0
- examples/swe/task_app/hosted/inference/openai_client.py +618 -0
- examples/swe/task_app/hosted/main.py +100 -0
- examples/swe/task_app/hosted/policy_routes.py +1079 -0
- examples/swe/task_app/hosted/registry.py +195 -0
- examples/swe/task_app/hosted/rollout.py +1869 -0
- examples/swe/task_app/hosted/storage/__init__.py +5 -0
- examples/swe/task_app/hosted/storage/volume.py +211 -0
- examples/swe/task_app/hosted/test_agents.py +161 -0
- examples/swe/task_app/hosted/test_service.py +137 -0
- examples/swe/task_app/hosted/utils.py +62 -0
- examples/vlm/README.md +68 -0
- examples/vlm/configs/crafter_vlm_gpt4o.toml +44 -0
- examples/vlm/crafter_image_only_agent.py +207 -0
- examples/vlm/crafter_openai_vlm_agent.py +277 -0
- examples/vlm/filter_image_rows.py +63 -0
- examples/vlm/run_crafter_vlm_benchmark.py +316 -0
- examples/warming_up_to_rl/analyze_trace_db.py +12 -10
- examples/warming_up_to_rl/configs/rl_from_base_qwen4b.toml +11 -1
- examples/warming_up_to_rl/export_trace_sft.py +218 -36
- examples/warming_up_to_rl/groq_test.py +15 -8
- examples/warming_up_to_rl/manage_secrets.py +29 -25
- examples/warming_up_to_rl/readme.md +9 -2
- examples/warming_up_to_rl/run_eval.py +137 -61
- examples/warming_up_to_rl/run_fft_and_save.py +131 -60
- examples/warming_up_to_rl/run_local_rollout.py +88 -39
- examples/warming_up_to_rl/run_local_rollout_modal.py +114 -28
- examples/warming_up_to_rl/run_local_rollout_parallel.py +81 -20
- examples/warming_up_to_rl/run_local_rollout_traced.py +126 -23
- examples/warming_up_to_rl/run_rl_and_save.py +35 -12
- examples/warming_up_to_rl/run_rollout_remote.py +44 -19
- examples/warming_up_to_rl/task_app/README.md +6 -2
- examples/warming_up_to_rl/task_app/grpo_crafter.py +319 -57
- examples/warming_up_to_rl/task_app/grpo_crafter_task_app.py +11 -30
- examples/warming_up_to_rl/task_app/synth_envs_hosted/__init__.py +1 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/branching.py +9 -11
- examples/warming_up_to_rl/task_app/synth_envs_hosted/environment_routes.py +137 -182
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/__init__.py +1 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/__init__.py +1 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/app.py +1 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/environment.py +150 -57
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/policy.py +105 -69
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/react_agent.py +19 -7
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/shared.py +45 -42
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/tools.py +1 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/hosted_app.py +47 -45
- examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/__init__.py +1 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/openai_client.py +198 -92
- examples/warming_up_to_rl/task_app/synth_envs_hosted/main.py +0 -2
- examples/warming_up_to_rl/task_app/synth_envs_hosted/policy_routes.py +361 -263
- examples/warming_up_to_rl/task_app/synth_envs_hosted/registry.py +21 -23
- examples/warming_up_to_rl/task_app/synth_envs_hosted/rollout.py +394 -274
- examples/warming_up_to_rl/task_app/synth_envs_hosted/storage/__init__.py +1 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/storage/volume.py +56 -62
- examples/warming_up_to_rl/task_app/synth_envs_hosted/test_agents.py +1 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/test_service.py +6 -15
- examples/warming_up_to_rl/task_app/synth_envs_hosted/utils.py +4 -3
- synth/__init__.py +14 -0
- synth_ai/__init__.py +20 -4
- synth_ai/api/models/supported.py +376 -0
- synth_ai/api/train/builders.py +157 -26
- synth_ai/api/train/cli.py +213 -57
- synth_ai/api/train/config_finder.py +65 -5
- synth_ai/api/train/env_resolver.py +33 -15
- synth_ai/api/train/pollers.py +13 -4
- synth_ai/api/train/supported_algos.py +139 -0
- synth_ai/api/train/task_app.py +5 -3
- synth_ai/api/train/utils.py +33 -48
- synth_ai/cli/__init__.py +19 -4
- synth_ai/cli/_modal_wrapper.py +28 -0
- synth_ai/cli/_typer_patch.py +49 -0
- synth_ai/cli/balance.py +2 -3
- synth_ai/cli/calc.py +1 -1
- synth_ai/cli/demo.py +21 -6
- synth_ai/cli/recent.py +2 -2
- synth_ai/cli/rl_demo.py +77 -17
- synth_ai/cli/root.py +116 -39
- synth_ai/cli/status.py +2 -2
- synth_ai/cli/task_apps.py +1709 -243
- synth_ai/cli/traces.py +7 -4
- synth_ai/cli/turso.py +73 -0
- synth_ai/cli/watch.py +12 -18
- synth_ai/core/experiment.py +0 -2
- synth_ai/demo_registry.py +68 -31
- synth_ai/demos/core/cli.py +516 -194
- synth_ai/demos/demo_task_apps/__init__.py +3 -3
- synth_ai/demos/demo_task_apps/core.py +64 -28
- synth_ai/demos/demo_task_apps/crafter/configs/crafter_fft_4b.toml +2 -3
- synth_ai/demos/demo_task_apps/crafter/grpo_crafter_task_app.py +37 -30
- synth_ai/demos/demo_task_apps/math/_common.py +1 -2
- synth_ai/demos/demo_task_apps/math/app.py +2 -1
- synth_ai/demos/demo_task_apps/math/deploy_modal.py +3 -6
- synth_ai/demos/demo_task_apps/math/modal_task_app.py +183 -82
- synth_ai/demos/demo_task_apps/math/task_app_entry.py +0 -2
- synth_ai/environments/examples/bandit/engine.py +12 -4
- synth_ai/environments/examples/bandit/taskset.py +4 -4
- synth_ai/environments/examples/crafter_classic/environment.py +76 -1
- synth_ai/environments/reproducibility/tree.py +5 -6
- synth_ai/environments/service/app.py +11 -12
- synth_ai/environments/service/core_routes.py +10 -9
- synth_ai/environments/stateful/engine.py +1 -1
- synth_ai/environments/tasks/core.py +1 -0
- synth_ai/environments/tasks/filters.py +5 -6
- synth_ai/environments/tasks/utils.py +4 -5
- synth_ai/evals/base.py +0 -2
- synth_ai/handshake.py +11 -9
- synth_ai/http.py +1 -1
- synth_ai/http_client.py +43 -11
- synth_ai/inference/__init__.py +0 -2
- synth_ai/inference/client.py +20 -6
- synth_ai/jobs/client.py +103 -78
- synth_ai/learning/__init__.py +41 -6
- synth_ai/learning/algorithms.py +14 -0
- synth_ai/learning/client.py +121 -29
- synth_ai/learning/config.py +2 -40
- synth_ai/learning/constants.py +0 -2
- synth_ai/learning/ft_client.py +4 -56
- synth_ai/learning/health.py +13 -7
- synth_ai/learning/jobs.py +43 -47
- synth_ai/{rl → learning/rl}/__init__.py +14 -5
- synth_ai/learning/rl/client.py +267 -0
- synth_ai/learning/rl/config.py +31 -0
- synth_ai/{rl → learning/rl}/contracts.py +5 -10
- synth_ai/{rl → learning/rl}/env_keys.py +45 -16
- synth_ai/learning/rl/secrets.py +13 -0
- synth_ai/learning/rl_client.py +2 -253
- synth_ai/learning/sft/__init__.py +29 -0
- synth_ai/learning/sft/client.py +68 -0
- synth_ai/learning/sft/config.py +270 -0
- synth_ai/learning/sft/data.py +295 -0
- synth_ai/learning/sse.py +25 -26
- synth_ai/learning/validators.py +25 -24
- synth_ai/lm/__init__.py +21 -47
- synth_ai/task/__init__.py +26 -27
- synth_ai/task/apps/__init__.py +18 -19
- synth_ai/task/auth.py +35 -23
- synth_ai/task/client.py +15 -13
- synth_ai/task/contracts.py +37 -35
- synth_ai/task/datasets.py +9 -6
- synth_ai/task/errors.py +11 -10
- synth_ai/task/health.py +17 -11
- synth_ai/task/json.py +58 -24
- synth_ai/task/proxy.py +15 -14
- synth_ai/task/rubrics.py +22 -15
- synth_ai/task/server.py +43 -17
- synth_ai/task/tracing_utils.py +12 -7
- synth_ai/task/validators.py +0 -1
- synth_ai/task/vendors.py +5 -7
- synth_ai/tracing_v3/__init__.py +2 -0
- synth_ai/tracing_v3/abstractions.py +21 -4
- synth_ai/tracing_v3/db_config.py +26 -1
- synth_ai/tracing_v3/decorators.py +18 -15
- synth_ai/tracing_v3/examples/basic_usage.py +3 -2
- synth_ai/tracing_v3/hooks.py +6 -4
- synth_ai/tracing_v3/llm_call_record_helpers.py +6 -6
- synth_ai/tracing_v3/replica_sync.py +1 -0
- synth_ai/tracing_v3/session_tracer.py +63 -16
- synth_ai/tracing_v3/storage/base.py +89 -1
- synth_ai/tracing_v3/storage/config.py +21 -8
- synth_ai/tracing_v3/storage/factory.py +10 -8
- synth_ai/tracing_v3/storage/utils.py +4 -2
- synth_ai/tracing_v3/turso/daemon.py +7 -2
- synth_ai/tracing_v3/turso/models.py +5 -2
- synth_ai/tracing_v3/turso/native_manager.py +1173 -0
- synth_ai/tracing_v3/utils.py +4 -3
- synth_ai/v0/api/__init__.py +8 -0
- synth_ai/v0/api/models/__init__.py +8 -0
- synth_ai/v0/api/models/supported.py +8 -0
- synth_ai/v0/config/__init__.py +15 -0
- synth_ai/v0/config/base_url.py +12 -0
- synth_ai/v0/lm/__init__.py +51 -0
- synth_ai/{lm → v0/lm}/caching/ephemeral.py +3 -5
- synth_ai/{lm → v0/lm}/caching/handler.py +4 -4
- synth_ai/{lm → v0/lm}/caching/initialize.py +1 -1
- synth_ai/{lm → v0/lm}/caching/persistent.py +1 -1
- synth_ai/{lm → v0/lm}/config.py +6 -1
- synth_ai/{lm → v0/lm}/core/all.py +9 -9
- synth_ai/{lm → v0/lm}/core/exceptions.py +0 -2
- synth_ai/{lm → v0/lm}/core/main.py +19 -7
- synth_ai/{lm → v0/lm}/core/main_v3.py +10 -10
- synth_ai/{lm → v0/lm}/core/synth_models.py +2 -15
- synth_ai/{lm → v0/lm}/core/vendor_clients.py +6 -4
- synth_ai/{lm → v0/lm}/overrides.py +4 -4
- synth_ai/{lm → v0/lm}/provider_support/anthropic.py +4 -4
- synth_ai/{lm → v0/lm}/provider_support/openai.py +5 -5
- synth_ai/{lm → v0/lm}/structured_outputs/handler.py +5 -5
- synth_ai/{lm → v0/lm}/structured_outputs/rehabilitate.py +1 -1
- synth_ai/{lm → v0/lm}/vendors/core/anthropic_api.py +16 -16
- synth_ai/{lm → v0/lm}/vendors/core/gemini_api.py +5 -5
- synth_ai/{lm → v0/lm}/vendors/core/mistral_api.py +5 -5
- synth_ai/{lm → v0/lm}/vendors/core/openai_api.py +12 -10
- synth_ai/{lm → v0/lm}/vendors/openai_standard.py +11 -9
- synth_ai/{lm → v0/lm}/vendors/openai_standard_responses.py +8 -5
- synth_ai/{lm → v0/lm}/vendors/supported/custom_endpoint.py +4 -6
- synth_ai/{lm → v0/lm}/vendors/supported/deepseek.py +2 -2
- synth_ai/{lm → v0/lm}/vendors/supported/grok.py +2 -2
- synth_ai/{lm → v0/lm}/vendors/supported/groq.py +1 -1
- synth_ai/{lm → v0/lm}/vendors/supported/ollama.py +1 -1
- synth_ai/{lm → v0/lm}/vendors/supported/openrouter.py +3 -3
- synth_ai/{lm → v0/lm}/vendors/supported/together.py +1 -1
- synth_ai/{lm → v0/lm}/vendors/synth_client.py +38 -11
- synth_ai/v0/tracing/upload.py +32 -135
- synth_ai/v0/tracing_v3/__init__.py +10 -0
- synth_ai/v0/tracing_v3/abstractions.py +3 -0
- synth_ai/v0/tracing_v3/decorators.py +3 -0
- synth_ai/v0/tracing_v3/llm_call_record_helpers.py +3 -0
- synth_ai/v0/tracing_v3/session_tracer.py +3 -0
- synth_ai-0.2.9.dev6.dist-info/METADATA +191 -0
- {synth_ai-0.2.9.dev4.dist-info → synth_ai-0.2.9.dev6.dist-info}/RECORD +291 -264
- {synth_ai-0.2.9.dev4.dist-info → synth_ai-0.2.9.dev6.dist-info}/top_level.txt +1 -0
- examples/common_old/backend.py +0 -21
- examples/evals_old/README.md +0 -98
- examples/evals_old/__init__.py +0 -6
- examples/evals_old/compare_models.py +0 -1037
- examples/evals_old/example_log.md +0 -145
- examples/evals_old/run_demo.sh +0 -126
- examples/evals_old/trace_analysis.py +0 -270
- examples/finetuning_old/_backup_synth_qwen/config.toml +0 -29
- examples/finetuning_old/_backup_synth_qwen/example_log.md +0 -324
- examples/finetuning_old/_backup_synth_qwen/filter_traces.py +0 -60
- examples/finetuning_old/_backup_synth_qwen/filter_traces_achievements.py +0 -239
- examples/finetuning_old/_backup_synth_qwen/purge_v3_traces.py +0 -109
- examples/finetuning_old/_backup_synth_qwen/react_agent_lm.py +0 -1924
- examples/finetuning_old/_backup_synth_qwen/readme.md +0 -49
- examples/finetuning_old/_backup_synth_qwen/run_crafter_qwen4b.py +0 -114
- examples/finetuning_old/_backup_synth_qwen/run_demo.sh +0 -195
- examples/finetuning_old/_backup_synth_qwen/sft_kickoff.py +0 -118
- examples/finetuning_old/synth_qwen_v1/README.md +0 -68
- examples/finetuning_old/synth_qwen_v1/filter_traces.py +0 -60
- examples/finetuning_old/synth_qwen_v1/filter_traces_achievements.py +0 -239
- examples/finetuning_old/synth_qwen_v1/finetune.py +0 -46
- examples/finetuning_old/synth_qwen_v1/hello_ft_model.py +0 -71
- examples/finetuning_old/synth_qwen_v1/infer.py +0 -37
- examples/finetuning_old/synth_qwen_v1/poll.py +0 -44
- examples/finetuning_old/synth_qwen_v1/prepare_data.py +0 -35
- examples/finetuning_old/synth_qwen_v1/purge_v3_traces.py +0 -109
- examples/finetuning_old/synth_qwen_v1/react_agent_lm.py +0 -1932
- examples/finetuning_old/synth_qwen_v1/run_crafter_sft_job.py +0 -207
- examples/finetuning_old/synth_qwen_v1/run_ft_job.py +0 -232
- examples/finetuning_old/synth_qwen_v1/upload_data.py +0 -34
- examples/finetuning_old/synth_qwen_v1/util.py +0 -147
- examples/rl_old/task_app.py +0 -962
- examples/warming_up_to_rl/old/event_rewards.md +0 -234
- examples/warming_up_to_rl/old/notes.md +0 -73
- examples/warming_up_to_rl/task_app/synth_envs_hosted/test_stepwise_rewards.py +0 -58
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/filter_traces_sft_turso.py +0 -738
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/filter_traces_sft_turso.py +0 -580
- synth_ai/environments/examples/sokoban/units/astar_common.py +0 -95
- synth_ai/experimental/synth_oss.py +0 -446
- synth_ai/install_sqld.sh +0 -40
- synth_ai/learning/filtering.py +0 -0
- synth_ai/learning/offline/dpo.py +0 -0
- synth_ai/learning/offline/providers.py +0 -7
- synth_ai/learning/offline/sft.py +0 -0
- synth_ai/learning/offline/shared.py +0 -0
- synth_ai/learning/online/grpo.py +0 -0
- synth_ai/learning/online/irft.py +0 -0
- synth_ai/learning/prompts/banking77_injection_eval.py +0 -168
- synth_ai/learning/prompts/gepa.py +0 -0
- synth_ai/learning/prompts/hello_world_in_context_injection_ex.py +0 -213
- synth_ai/learning/prompts/mipro.py +0 -289
- synth_ai/learning/prompts/random_search.py +0 -246
- synth_ai/learning/prompts/run_mipro_banking77.py +0 -172
- synth_ai/learning/prompts/run_random_search_banking77.py +0 -324
- synth_ai/rl/secrets.py +0 -19
- synth_ai/scripts/verify_rewards.py +0 -100
- synth_ai/tracing/__init__.py +0 -30
- synth_ai/tracing_v1/__init__.py +0 -33
- synth_ai/tracing_v3/turso/__init__.py +0 -25
- synth_ai/tracing_v3/turso/manager.py +0 -774
- synth_ai/zyk/__init__.py +0 -30
- synth_ai-0.2.9.dev4.dist-info/METADATA +0 -131
- /synth_ai/{lm → v0/lm}/caching/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/caching/constants.py +0 -0
- /synth_ai/{lm → v0/lm}/caching/dbs.py +0 -0
- /synth_ai/{lm → v0/lm}/constants.py +0 -0
- /synth_ai/{lm → v0/lm}/core/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/cost/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/cost/monitor.py +0 -0
- /synth_ai/{lm → v0/lm}/cost/statefulness.py +0 -0
- /synth_ai/{lm → v0/lm}/injection.py +0 -0
- /synth_ai/{lm → v0/lm}/provider_support/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/provider_support/suppress_logging.py +0 -0
- /synth_ai/{lm → v0/lm}/structured_outputs/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/structured_outputs/inject.py +0 -0
- /synth_ai/{lm → v0/lm}/tools/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/tools/base.py +0 -0
- /synth_ai/{lm → v0/lm}/unified_interface.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/base.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/core/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/core/synth_dev_api.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/local/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/local/ollama.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/retries.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/supported/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/warmup.py +0 -0
- {synth_ai-0.2.9.dev4.dist-info → synth_ai-0.2.9.dev6.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.9.dev4.dist-info → synth_ai-0.2.9.dev6.dist-info}/entry_points.txt +0 -0
- {synth_ai-0.2.9.dev4.dist-info → synth_ai-0.2.9.dev6.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,376 @@
|
|
|
1
|
+
"""Catalog of Synth-hosted base models and helpers (core vs experimental)."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import os
|
|
6
|
+
import warnings
|
|
7
|
+
from collections.abc import Iterable, Iterator, Sequence
|
|
8
|
+
from dataclasses import dataclass
|
|
9
|
+
|
|
10
|
+
# ------------------------------------------------------------------------------
|
|
11
|
+
# Model families
|
|
12
|
+
# ------------------------------------------------------------------------------
|
|
13
|
+
|
|
14
|
+
# Qwen3 text-only base models hosted by Synth: the core size ladder plus the
# Thinking variants exercised by the RL examples.
QWEN3_MODELS: list[str] = [
    # Core Qwen3 base models
    "Qwen/Qwen3-0.6B",
    "Qwen/Qwen3-1.7B",
    "Qwen/Qwen3-4B",
    "Qwen/Qwen3-8B",
    "Qwen/Qwen3-14B",
    "Qwen/Qwen3-30B-A3B",
    "Qwen/Qwen3-32B",
    # Include 4B-2507 and Thinking variants used in RL
    "Qwen/Qwen3-4B-Thinking-2507",
    "Qwen/Qwen3-30B-A3B-Thinking-2507",
    "Qwen/Qwen3-235B-A22B-Thinking-2507",
]
|
|
28
|
+
|
|
29
|
+
# Qwen3 Coder family (backend-supported); text-only, SFT/inference.
# Note: the Coder models do not appear in RL_SUPPORTED_MODELS below.
QWEN3_CODER_MODELS: list[str] = [
    # Instruct variants used for coding tasks
    "Qwen/Qwen3-Coder-30B-A3B-Instruct",
    "Qwen/Qwen3-Coder-480B-A35B-Instruct",
]
|
|
35
|
+
|
|
36
|
+
# Training support sets
# Base models eligible for RL training jobs: a strict subset of QWEN3_MODELS
# (notably excluding 0.6B, 32B, 235B, and the Coder family).
RL_SUPPORTED_MODELS: frozenset[str] = frozenset(
    {
        "Qwen/Qwen3-1.7B",
        "Qwen/Qwen3-4B",
        "Qwen/Qwen3-4B-Thinking-2507",
        "Qwen/Qwen3-8B",
        "Qwen/Qwen3-14B",
        "Qwen/Qwen3-30B-A3B",
        "Qwen/Qwen3-30B-A3B-Thinking-2507",
    }
)
|
|
48
|
+
|
|
49
|
+
# SFT allowlist includes core Qwen3 plus the Coder family — i.e. every
# cataloged base model may be fine-tuned via SFT.
SFT_SUPPORTED_MODELS: frozenset[str] = frozenset([*QWEN3_MODELS, *QWEN3_CODER_MODELS])
|
|
51
|
+
|
|
52
|
+
# ------------------------------------------------------------------------------
|
|
53
|
+
# Lifecycle classification (core vs experimental)
|
|
54
|
+
# ------------------------------------------------------------------------------
|
|
55
|
+
|
|
56
|
+
# Which base models are considered "experimental" by default.
# This set can be extended at runtime via the SDK_EXPERIMENTAL_MODELS env var
# (see _parse_experimental_env below).
_EXPERIMENTAL_DEFAULTS: frozenset[str] = frozenset(
    {
        # Larger (>= 64B) or bleeding-edge variants are experimental by default.
        "Qwen/Qwen3-235B-A22B-Thinking-2507",
        "Qwen/Qwen3-Coder-480B-A35B-Instruct",
        # Thinking variants can fluctuate more rapidly.
        "Qwen/Qwen3-30B-A3B-Thinking-2507",
        "Qwen/Qwen3-4B-Thinking-2507",
    }
)
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
def _parse_experimental_env() -> frozenset[str]:
    """Read extra experimental model ids from ``SDK_EXPERIMENTAL_MODELS``.

    The variable holds a comma-separated list; surrounding whitespace and
    blank entries are ignored. Returns an empty frozenset when unset.
    """
    value = os.getenv("SDK_EXPERIMENTAL_MODELS", "")
    entries = (part.strip() for part in value.split(","))
    return frozenset(entry for entry in entries if entry)
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
# Final experimental set (defaults ∪ optional env override).
# The union of two frozensets is already a frozenset, so no extra
# frozenset(...) conversion is needed.
EXPERIMENTAL_MODELS: frozenset[str] = _EXPERIMENTAL_DEFAULTS | _parse_experimental_env()
|
|
78
|
+
|
|
79
|
+
# Build catalog entries for both core and coder families under unified "Qwen3"
_ALL_QWEN3_IDS: list[str] = [*QWEN3_MODELS, *QWEN3_CODER_MODELS]

# Everything not flagged experimental is "core".
CORE_MODELS: frozenset[str] = frozenset(m for m in _ALL_QWEN3_IDS if m not in EXPERIMENTAL_MODELS)
|
|
83
|
+
|
|
84
|
+
# ------------------------------------------------------------------------------
|
|
85
|
+
# Experimental gating / warnings
|
|
86
|
+
# ------------------------------------------------------------------------------
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
class ExperimentalWarning(UserWarning):
    """Emitted when callers rely on experimental SDK models or APIs."""
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
def _experimental_enabled() -> bool:
    """Return True when the global SDK_EXPERIMENTAL env toggle is "1"."""
    flag = os.getenv("SDK_EXPERIMENTAL", "0")
    return flag == "1"
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
def _warn_if_experimental(model_id: str) -> None:
    """Emit an ExperimentalWarning when *model_id* is in the experimental set."""
    if model_id not in EXPERIMENTAL_MODELS:
        return
    # stacklevel=2 points the warning at the caller, not at this helper.
    warnings.warn(
        f"Model '{model_id}' is experimental and may change or be removed.",
        category=ExperimentalWarning,
        stacklevel=2,
    )
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
# ------------------------------------------------------------------------------
|
|
108
|
+
# Model metadata + catalog
|
|
109
|
+
# ------------------------------------------------------------------------------
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
@dataclass(frozen=True, slots=True)
|
|
113
|
+
class SupportedModel:
|
|
114
|
+
"""Metadata describing a supported base model."""
|
|
115
|
+
|
|
116
|
+
model_id: str
|
|
117
|
+
family: str
|
|
118
|
+
provider: str
|
|
119
|
+
modalities: tuple[str, ...] = ()
|
|
120
|
+
training_modes: tuple[str, ...] = ()
|
|
121
|
+
lifecycle: str = "core" # "core" | "experimental"
|
|
122
|
+
|
|
123
|
+
def as_dict(self) -> dict[str, object]:
|
|
124
|
+
data: dict[str, object] = {
|
|
125
|
+
"model_id": self.model_id,
|
|
126
|
+
"family": self.family,
|
|
127
|
+
"provider": self.provider,
|
|
128
|
+
"lifecycle": self.lifecycle,
|
|
129
|
+
}
|
|
130
|
+
if self.modalities:
|
|
131
|
+
data["modalities"] = list(self.modalities)
|
|
132
|
+
if self.training_modes:
|
|
133
|
+
data["training_modes"] = list(self.training_modes)
|
|
134
|
+
return data
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
# Full catalog: one SupportedModel per id, with training modes derived from
# the SFT/RL allowlists and lifecycle derived from EXPERIMENTAL_MODELS.
SUPPORTED_MODELS: tuple[SupportedModel, ...] = tuple(
    SupportedModel(
        model_id=model,
        family="Qwen3",
        provider="Qwen",
        modalities=("text",),
        # sorted() gives a deterministic ("rl", "sft") ordering.
        training_modes=tuple(
            sorted(
                {
                    *(("sft",) if model in SFT_SUPPORTED_MODELS else ()),
                    *(("rl",) if model in RL_SUPPORTED_MODELS else ()),
                }
            )
        ),
        lifecycle=("experimental" if model in EXPERIMENTAL_MODELS else "core"),
    )
    for model in _ALL_QWEN3_IDS
)
|
|
155
|
+
|
|
156
|
+
# Case-insensitive id -> canonical id lookup used by _extract_base_model.
_BASE_LOOKUP = {model.model_id.lower(): model.model_id for model in SUPPORTED_MODELS}
# Canonical ids only (catalog casing).
SUPPORTED_BASE_MODEL_IDS: frozenset[str] = frozenset(_BASE_LOOKUP.values())
# Prefixes that mark fine-tuned model ids, e.g. "ft:<base>:<job>".
FINE_TUNED_PREFIXES: tuple[str, ...] = ("ft:", "fft:", "qft:", "rl:")
# Canonical id -> SupportedModel metadata record.
_MODEL_BY_ID = {model.model_id: model for model in SUPPORTED_MODELS}
|
|
160
|
+
|
|
161
|
+
# ------------------------------------------------------------------------------
|
|
162
|
+
# Public API
|
|
163
|
+
# ------------------------------------------------------------------------------
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
class UnsupportedModelError(ValueError):
    """Raised for model identifiers that do not resolve to a Synth-supported base model."""
|
|
168
|
+
|
|
169
|
+
|
|
170
|
+
def _extract_base_model(candidate: str, *, allow_finetuned_prefixes: bool) -> str | None:
    """Resolve *candidate* to a canonical base-model id, or None.

    A direct (case-insensitive) match against the catalog wins. Otherwise,
    when fine-tuned prefixes are allowed and the id contains ':', every
    colon-separated segment after the first is tried as a base-model id.
    """
    text = candidate.strip()
    direct = _BASE_LOOKUP.get(text.lower())
    if direct:
        return direct
    if allow_finetuned_prefixes and ":" in text:
        for part in text.split(":")[1:]:
            trimmed = part.strip()
            if not trimmed:
                continue
            match = _BASE_LOOKUP.get(trimmed.lower())
            if match:
                return match
    return None
|
|
188
|
+
|
|
189
|
+
|
|
190
|
+
def ensure_supported_model(
    model_id: str,
    *,
    allow_finetuned_prefixes: bool = True,
) -> str:
    """Resolve *model_id* to its canonical supported base model.

    Lifecycle (core vs experimental) is deliberately not checked here; use
    ensure_allowed_model when gating matters.

    Raises:
        UnsupportedModelError: if the identifier is empty or does not
            resolve to any cataloged base model.
    """
    cleaned = (model_id or "").strip()
    if not cleaned:
        raise UnsupportedModelError("Model identifier is empty")
    resolved = _extract_base_model(cleaned, allow_finetuned_prefixes=allow_finetuned_prefixes)
    if resolved:
        return resolved
    raise UnsupportedModelError(
        f"Model '{cleaned}' is not supported. Call supported_model_ids() for available base models."
    )
|
|
207
|
+
|
|
208
|
+
|
|
209
|
+
def ensure_allowed_model(
    model_id: str,
    *,
    allow_finetuned_prefixes: bool = True,
    allow_experimental: bool | None = None,
) -> str:
    """Validate *model_id* and gate experimental models.

    Resolves to the canonical base model, then rejects experimental models
    unless explicitly permitted (via argument or SDK_EXPERIMENTAL=1).
    Permitted experimental usage still emits an ExperimentalWarning.

    Raises:
        UnsupportedModelError: unknown model, or experimental while gated.
    """
    base = ensure_supported_model(model_id, allow_finetuned_prefixes=allow_finetuned_prefixes)
    is_experimental = base in EXPERIMENTAL_MODELS
    permitted = _experimental_enabled() if allow_experimental is None else allow_experimental
    if is_experimental:
        if not permitted:
            raise UnsupportedModelError(
                f"Model '{base}' is experimental and disabled. "
                "Set SDK_EXPERIMENTAL=1 or pass allow_experimental=True."
            )
        _warn_if_experimental(base)
    return base
|
|
227
|
+
|
|
228
|
+
|
|
229
|
+
def normalize_model_identifier(
    model_id: str,
    *,
    allow_finetuned_prefixes: bool = True,
) -> str:
    """Return a cleaned model identifier suitable for job payloads (no lifecycle gate).

    Base-model ids are canonicalized to catalog casing; fine-tuned ids
    (e.g. "ft:<base>:<job>") are kept verbatim after whitespace stripping.
    Raises UnsupportedModelError when validation fails.
    """
    canonical = ensure_supported_model(model_id, allow_finetuned_prefixes=allow_finetuned_prefixes)
    stripped = (model_id or "").strip()
    refers_to_base = not stripped or stripped.lower() in _BASE_LOOKUP
    return canonical if refers_to_base else stripped
|
|
242
|
+
|
|
243
|
+
|
|
244
|
+
def is_supported_model(model_id: str, *, allow_finetuned_prefixes: bool = True) -> bool:
    """Return True if *model_id* resolves to a supported base model (ignores lifecycle)."""
    try:
        ensure_supported_model(model_id, allow_finetuned_prefixes=allow_finetuned_prefixes)
    except UnsupportedModelError:
        return False
    else:
        return True
|
|
251
|
+
|
|
252
|
+
|
|
253
|
+
def is_experimental_model(model_id: str) -> bool:
    """Return True if *model_id* resolves to a model marked experimental."""
    try:
        resolved = ensure_supported_model(model_id, allow_finetuned_prefixes=True)
    except UnsupportedModelError:
        return False
    return resolved in EXPERIMENTAL_MODELS
|
|
260
|
+
|
|
261
|
+
|
|
262
|
+
def is_core_model(model_id: str) -> bool:
    """Return True if *model_id* resolves to a model marked core."""
    try:
        resolved = ensure_supported_model(model_id, allow_finetuned_prefixes=True)
    except UnsupportedModelError:
        return False
    return resolved in CORE_MODELS
|
|
269
|
+
|
|
270
|
+
|
|
271
|
+
def iter_supported_models(
    *,
    families: Sequence[str] | None = None,
    include: Sequence[str] | None = None,
    exclude: Sequence[str] | None = None,
) -> Iterator[SupportedModel]:
    """Yield supported models filtered by family and lifecycle.

    *families* keeps only matching model families; *include*/*exclude*
    filter on lifecycle ("core"/"experimental"). Matching is
    case-insensitive and a None filter means "no restriction".
    """
    wanted_families = {name.lower() for name in families} if families else None
    wanted_lifecycles = {name.lower() for name in include} if include else None
    blocked_lifecycles = {name.lower() for name in exclude} if exclude else None

    for model in SUPPORTED_MODELS:
        lifecycle = model.lifecycle.lower()
        family_ok = wanted_families is None or model.family.lower() in wanted_families
        included = wanted_lifecycles is None or lifecycle in wanted_lifecycles
        excluded = blocked_lifecycles is not None and lifecycle in blocked_lifecycles
        if family_ok and included and not excluded:
            yield model
|
|
290
|
+
|
|
291
|
+
|
|
292
|
+
def list_supported_models(
    *,
    families: Sequence[str] | None = None,
    include: Sequence[str] | None = None,
    exclude: Sequence[str] | None = None,
) -> list[SupportedModel]:
    """Materialize iter_supported_models into a list for easier consumption."""
    matches = iter_supported_models(families=families, include=include, exclude=exclude)
    return list(matches)
|
|
300
|
+
|
|
301
|
+
|
|
302
|
+
def supported_model_ids(
    *,
    families: Sequence[str] | None = None,
    include: Sequence[str] | None = None,
    exclude: Sequence[str] | None = None,
) -> list[str]:
    """Return just the identifiers of the supported models matching the filters."""
    matches = iter_supported_models(families=families, include=include, exclude=exclude)
    return [model.model_id for model in matches]
|
|
310
|
+
|
|
311
|
+
|
|
312
|
+
def experimental_model_ids(*, families: Sequence[str] | None = None) -> list[str]:
    """Return identifiers of supported models whose lifecycle is "experimental"."""
    return supported_model_ids(families=families, include=("experimental",))
|
|
315
|
+
|
|
316
|
+
|
|
317
|
+
def core_model_ids(*, families: Sequence[str] | None = None) -> list[str]:
|
|
318
|
+
"""Return identifiers for core supported models."""
|
|
319
|
+
return supported_model_ids(families=families, include=("core",))
|
|
320
|
+
|
|
321
|
+
|
|
322
|
+
def format_supported_models(
|
|
323
|
+
*,
|
|
324
|
+
families: Sequence[str] | None = None,
|
|
325
|
+
include: Sequence[str] | None = None,
|
|
326
|
+
exclude: Sequence[str] | None = None,
|
|
327
|
+
) -> str:
|
|
328
|
+
"""Produce a human readable table of supported models."""
|
|
329
|
+
rows: Iterable[SupportedModel] = iter_supported_models(families=families, include=include, exclude=exclude)
|
|
330
|
+
lines = ["model_id | family | provider | lifecycle | modalities | training_modes", "-" * 96]
|
|
331
|
+
for model in rows:
|
|
332
|
+
modalities = ",".join(model.modalities) or "-"
|
|
333
|
+
training = ",".join(model.training_modes) or "-"
|
|
334
|
+
lines.append(
|
|
335
|
+
f"{model.model_id} | {model.family} | {model.provider} | {model.lifecycle} | {modalities} | {training}"
|
|
336
|
+
)
|
|
337
|
+
return "\n".join(lines)
|
|
338
|
+
|
|
339
|
+
|
|
340
|
+
def training_modes_for_model(model_id: str) -> tuple[str, ...]:
|
|
341
|
+
"""Return the supported training modes (e.g., ('sft','rl')) for the given base model."""
|
|
342
|
+
canonical = ensure_supported_model(model_id, allow_finetuned_prefixes=True)
|
|
343
|
+
model = _MODEL_BY_ID.get(canonical)
|
|
344
|
+
if not model:
|
|
345
|
+
raise UnsupportedModelError(f"Model '{model_id}' is not registered as supported.")
|
|
346
|
+
return model.training_modes
|
|
347
|
+
|
|
348
|
+
|
|
349
|
+
__all__ = [
|
|
350
|
+
"QWEN3_MODELS",
|
|
351
|
+
"QWEN3_CODER_MODELS",
|
|
352
|
+
"RL_SUPPORTED_MODELS",
|
|
353
|
+
"SFT_SUPPORTED_MODELS",
|
|
354
|
+
"EXPERIMENTAL_MODELS",
|
|
355
|
+
"CORE_MODELS",
|
|
356
|
+
"ExperimentalWarning",
|
|
357
|
+
"SupportedModel",
|
|
358
|
+
"SUPPORTED_MODELS",
|
|
359
|
+
"SUPPORTED_BASE_MODEL_IDS",
|
|
360
|
+
"FINE_TUNED_PREFIXES",
|
|
361
|
+
"UnsupportedModelError",
|
|
362
|
+
"ensure_supported_model",
|
|
363
|
+
"ensure_allowed_model",
|
|
364
|
+
"normalize_model_identifier",
|
|
365
|
+
"is_supported_model",
|
|
366
|
+
"is_experimental_model",
|
|
367
|
+
"is_core_model",
|
|
368
|
+
"iter_supported_models",
|
|
369
|
+
"list_supported_models",
|
|
370
|
+
"supported_model_ids",
|
|
371
|
+
"experimental_model_ids",
|
|
372
|
+
"core_model_ids",
|
|
373
|
+
"format_supported_models",
|
|
374
|
+
"training_modes_for_model",
|
|
375
|
+
]
|
|
376
|
+
|
synth_ai/api/train/builders.py
CHANGED
|
@@ -5,8 +5,19 @@ from pathlib import Path
|
|
|
5
5
|
from typing import Any
|
|
6
6
|
|
|
7
7
|
import click
|
|
8
|
+
from synth_ai.api.models.supported import (
|
|
9
|
+
UnsupportedModelError,
|
|
10
|
+
ensure_allowed_model,
|
|
11
|
+
normalize_model_identifier,
|
|
12
|
+
)
|
|
13
|
+
from synth_ai.learning.sft.config import prepare_sft_job_payload
|
|
8
14
|
|
|
9
|
-
from .
|
|
15
|
+
from .supported_algos import (
|
|
16
|
+
AlgorithmValidationError,
|
|
17
|
+
ensure_model_supported_for_algorithm,
|
|
18
|
+
validate_algorithm_config,
|
|
19
|
+
)
|
|
20
|
+
from .utils import TrainError, ensure_api_base, load_toml
|
|
10
21
|
|
|
11
22
|
|
|
12
23
|
@dataclass(slots=True)
|
|
@@ -29,23 +40,78 @@ def build_rl_payload(
|
|
|
29
40
|
task_url: str,
|
|
30
41
|
overrides: dict[str, Any],
|
|
31
42
|
idempotency: str | None,
|
|
43
|
+
allow_experimental: bool | None = None,
|
|
32
44
|
) -> RLBuildResult:
|
|
33
45
|
data = load_toml(config_path)
|
|
46
|
+
try:
|
|
47
|
+
spec = validate_algorithm_config(data.get("algorithm"), expected_family="rl")
|
|
48
|
+
except AlgorithmValidationError as exc:
|
|
49
|
+
raise click.ClickException(str(exc)) from exc
|
|
34
50
|
services = data.get("services") if isinstance(data.get("services"), dict) else {}
|
|
35
51
|
model_cfg = data.get("model") if isinstance(data.get("model"), dict) else {}
|
|
36
52
|
|
|
37
|
-
final_task_url = (
|
|
53
|
+
final_task_url = (
|
|
54
|
+
overrides.get("task_url")
|
|
55
|
+
or task_url
|
|
56
|
+
or (services.get("task_url") if isinstance(services, dict) else None)
|
|
57
|
+
or ""
|
|
58
|
+
).strip()
|
|
38
59
|
if not final_task_url:
|
|
39
|
-
raise click.ClickException(
|
|
60
|
+
raise click.ClickException(
|
|
61
|
+
"Task app URL required (provide --task-url or set services.task_url in TOML)"
|
|
62
|
+
)
|
|
40
63
|
|
|
41
|
-
|
|
42
|
-
|
|
64
|
+
raw_source = model_cfg.get("source") if isinstance(model_cfg, dict) else ""
|
|
65
|
+
model_source = str(raw_source or "").strip()
|
|
66
|
+
raw_base = model_cfg.get("base") if isinstance(model_cfg, dict) else ""
|
|
67
|
+
model_base = str(raw_base or "").strip()
|
|
43
68
|
override_model = (overrides.get("model") or "").strip()
|
|
44
69
|
if override_model:
|
|
45
70
|
model_source = override_model
|
|
46
71
|
model_base = ""
|
|
47
72
|
if bool(model_source) == bool(model_base):
|
|
48
|
-
|
|
73
|
+
details = (
|
|
74
|
+
f"Config: {config_path}\n"
|
|
75
|
+
f"[model].source={model_source!r} | [model].base={model_base!r}"
|
|
76
|
+
)
|
|
77
|
+
hint = (
|
|
78
|
+
"Set exactly one: [model].base for a base model (e.g. 'Qwen/Qwen3-1.7B') "
|
|
79
|
+
"or [model].source for a fine-tuned model id. Also remove any conflicting "
|
|
80
|
+
"'[policy].model' entries."
|
|
81
|
+
)
|
|
82
|
+
raise click.ClickException(
|
|
83
|
+
"Invalid model config: exactly one of [model].source or [model].base is required.\n"
|
|
84
|
+
+ details
|
|
85
|
+
+ "\nHint: "
|
|
86
|
+
+ hint
|
|
87
|
+
)
|
|
88
|
+
|
|
89
|
+
try:
|
|
90
|
+
if model_source:
|
|
91
|
+
model_source = normalize_model_identifier(model_source)
|
|
92
|
+
if model_base:
|
|
93
|
+
model_base = normalize_model_identifier(model_base, allow_finetuned_prefixes=False)
|
|
94
|
+
except UnsupportedModelError as exc:
|
|
95
|
+
raise click.ClickException(str(exc)) from exc
|
|
96
|
+
|
|
97
|
+
base_model_for_training: str | None = None
|
|
98
|
+
if model_source:
|
|
99
|
+
base_model_for_training = ensure_allowed_model(
|
|
100
|
+
model_source,
|
|
101
|
+
allow_finetuned_prefixes=True,
|
|
102
|
+
allow_experimental=allow_experimental,
|
|
103
|
+
)
|
|
104
|
+
elif model_base:
|
|
105
|
+
base_model_for_training = ensure_allowed_model(
|
|
106
|
+
model_base,
|
|
107
|
+
allow_finetuned_prefixes=False,
|
|
108
|
+
allow_experimental=allow_experimental,
|
|
109
|
+
)
|
|
110
|
+
if base_model_for_training:
|
|
111
|
+
try:
|
|
112
|
+
ensure_model_supported_for_algorithm(base_model_for_training, spec)
|
|
113
|
+
except AlgorithmValidationError as exc:
|
|
114
|
+
raise click.ClickException(str(exc)) from exc
|
|
49
115
|
|
|
50
116
|
# Force TOML services.task_url to the effective endpoint to avoid split URLs
|
|
51
117
|
try:
|
|
@@ -81,34 +147,53 @@ def build_sft_payload(
|
|
|
81
147
|
*,
|
|
82
148
|
config_path: Path,
|
|
83
149
|
dataset_override: Path | None,
|
|
150
|
+
allow_experimental: bool | None,
|
|
84
151
|
) -> SFTBuildResult:
|
|
85
152
|
data = load_toml(config_path)
|
|
153
|
+
try:
|
|
154
|
+
spec = validate_algorithm_config(data.get("algorithm"), expected_family="sft")
|
|
155
|
+
except AlgorithmValidationError as exc:
|
|
156
|
+
raise TrainError(str(exc)) from exc
|
|
86
157
|
job_cfg = data.get("job") if isinstance(data.get("job"), dict) else {}
|
|
87
158
|
data_cfg = data.get("data") if isinstance(data.get("data"), dict) else {}
|
|
88
159
|
hp_cfg = data.get("hyperparameters") if isinstance(data.get("hyperparameters"), dict) else {}
|
|
89
160
|
train_cfg = data.get("training") if isinstance(data.get("training"), dict) else {}
|
|
90
161
|
compute_cfg = data.get("compute") if isinstance(data.get("compute"), dict) else {}
|
|
91
162
|
|
|
92
|
-
raw_dataset =
|
|
163
|
+
raw_dataset = (
|
|
164
|
+
dataset_override
|
|
165
|
+
or (job_cfg.get("data") if isinstance(job_cfg, dict) else None)
|
|
166
|
+
or (job_cfg.get("data_path") if isinstance(job_cfg, dict) else None)
|
|
167
|
+
)
|
|
93
168
|
if not raw_dataset:
|
|
94
169
|
raise TrainError("Dataset not specified; pass --dataset or set [job].data")
|
|
95
170
|
dataset_path = Path(raw_dataset)
|
|
96
|
-
|
|
171
|
+
# Resolve relative paths from current working directory, not config directory
|
|
172
|
+
dataset_path = (
|
|
173
|
+
dataset_path if dataset_path.is_absolute() else (Path.cwd() / dataset_path)
|
|
174
|
+
).resolve()
|
|
97
175
|
if not dataset_path.exists():
|
|
98
176
|
raise TrainError(f"Dataset not found: {dataset_path}")
|
|
99
177
|
|
|
100
|
-
validation_path =
|
|
178
|
+
validation_path = (
|
|
179
|
+
data_cfg.get("validation_path")
|
|
180
|
+
if isinstance(data_cfg, dict)
|
|
181
|
+
else None
|
|
182
|
+
if isinstance(data_cfg, dict) and isinstance(data_cfg.get("validation_path"), str)
|
|
183
|
+
else None
|
|
184
|
+
)
|
|
101
185
|
validation_file = None
|
|
102
186
|
if validation_path:
|
|
103
187
|
vpath = Path(validation_path)
|
|
104
|
-
|
|
188
|
+
# Resolve relative paths from current working directory, not config directory
|
|
189
|
+
vpath = (vpath if vpath.is_absolute() else (Path.cwd() / vpath)).resolve()
|
|
105
190
|
if not vpath.exists():
|
|
106
191
|
click.echo(f"[WARN] Validation dataset {vpath} missing; continuing without validation")
|
|
107
192
|
else:
|
|
108
193
|
validation_file = vpath
|
|
109
194
|
|
|
110
195
|
hp_block: dict[str, Any] = {
|
|
111
|
-
"n_epochs": int(hp_cfg.get("n_epochs", 1)),
|
|
196
|
+
"n_epochs": int(hp_cfg.get("n_epochs", 1) if isinstance(hp_cfg, dict) else 1),
|
|
112
197
|
}
|
|
113
198
|
for key in (
|
|
114
199
|
"batch_size",
|
|
@@ -120,20 +205,36 @@ def build_sft_payload(
|
|
|
120
205
|
"warmup_ratio",
|
|
121
206
|
"train_kind",
|
|
122
207
|
):
|
|
123
|
-
if key in hp_cfg:
|
|
208
|
+
if isinstance(hp_cfg, dict) and key in hp_cfg:
|
|
124
209
|
hp_block[key] = hp_cfg[key]
|
|
125
|
-
if isinstance(hp_cfg.get("parallelism"), dict):
|
|
210
|
+
if isinstance(hp_cfg, dict) and isinstance(hp_cfg.get("parallelism"), dict):
|
|
126
211
|
hp_block["parallelism"] = hp_cfg["parallelism"]
|
|
127
212
|
|
|
128
|
-
compute_block = {
|
|
213
|
+
compute_block = {
|
|
214
|
+
k: compute_cfg[k]
|
|
215
|
+
for k in ("gpu_type", "gpu_count", "nodes")
|
|
216
|
+
if isinstance(compute_cfg, dict) and k in compute_cfg
|
|
217
|
+
}
|
|
129
218
|
|
|
130
219
|
effective = {
|
|
131
220
|
"compute": compute_block,
|
|
132
|
-
"data": {
|
|
133
|
-
|
|
221
|
+
"data": {
|
|
222
|
+
"topology": data_cfg.get("topology", {})
|
|
223
|
+
if isinstance(data_cfg, dict) and isinstance(data_cfg.get("topology"), dict)
|
|
224
|
+
else {}
|
|
225
|
+
},
|
|
226
|
+
"training": {
|
|
227
|
+
k: v
|
|
228
|
+
for k, v in (train_cfg.items() if isinstance(train_cfg, dict) else [])
|
|
229
|
+
if k in ("mode", "use_qlora")
|
|
230
|
+
},
|
|
134
231
|
}
|
|
135
232
|
|
|
136
|
-
validation_cfg =
|
|
233
|
+
validation_cfg = (
|
|
234
|
+
train_cfg.get("validation")
|
|
235
|
+
if isinstance(train_cfg, dict) and isinstance(train_cfg.get("validation"), dict)
|
|
236
|
+
else None
|
|
237
|
+
)
|
|
137
238
|
if isinstance(validation_cfg, dict):
|
|
138
239
|
hp_block.update(
|
|
139
240
|
{
|
|
@@ -144,15 +245,45 @@ def build_sft_payload(
|
|
|
144
245
|
"greater_is_better": bool(validation_cfg.get("greater_is_better", False)),
|
|
145
246
|
}
|
|
146
247
|
)
|
|
147
|
-
effective.setdefault("training", {})["validation"] = {
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
"
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
248
|
+
effective.setdefault("training", {})["validation"] = {
|
|
249
|
+
"enabled": bool(validation_cfg.get("enabled", True))
|
|
250
|
+
}
|
|
251
|
+
|
|
252
|
+
raw_model = str(
|
|
253
|
+
job_cfg.get("model") if isinstance(job_cfg, dict) else None or data.get("model") or ""
|
|
254
|
+
).strip()
|
|
255
|
+
if not raw_model:
|
|
256
|
+
raise TrainError("Model not specified; set [job].model or [model].base in the config")
|
|
257
|
+
|
|
258
|
+
try:
|
|
259
|
+
base_model = ensure_allowed_model(
|
|
260
|
+
raw_model,
|
|
261
|
+
allow_finetuned_prefixes=False,
|
|
262
|
+
allow_experimental=allow_experimental,
|
|
263
|
+
)
|
|
264
|
+
except UnsupportedModelError as exc:
|
|
265
|
+
raise TrainError(str(exc)) from exc
|
|
266
|
+
try:
|
|
267
|
+
ensure_model_supported_for_algorithm(base_model, spec)
|
|
268
|
+
except AlgorithmValidationError as exc:
|
|
269
|
+
raise TrainError(str(exc)) from exc
|
|
270
|
+
|
|
271
|
+
try:
|
|
272
|
+
payload = prepare_sft_job_payload(
|
|
273
|
+
model=raw_model,
|
|
274
|
+
training_file=None,
|
|
275
|
+
hyperparameters=hp_block,
|
|
276
|
+
metadata={"effective_config": effective},
|
|
277
|
+
training_type="sft_offline",
|
|
278
|
+
training_file_field="training_file_id",
|
|
279
|
+
require_training_file=False,
|
|
280
|
+
include_training_file_when_none=True,
|
|
281
|
+
allow_finetuned_prefixes=False,
|
|
282
|
+
)
|
|
283
|
+
except UnsupportedModelError as exc:
|
|
284
|
+
raise TrainError(str(exc)) from exc
|
|
285
|
+
except ValueError as exc:
|
|
286
|
+
raise TrainError(str(exc)) from exc
|
|
156
287
|
|
|
157
288
|
return SFTBuildResult(payload=payload, train_file=dataset_path, validation_file=validation_file)
|
|
158
289
|
|