synth-ai 0.2.14__py3-none-any.whl → 0.2.17__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of synth-ai has been flagged as potentially problematic; see the registry's advisory for this release for details.
- examples/README.md +1 -0
- examples/analyze_semantic_words.sh +2 -2
- examples/blog_posts/pokemon_vl/README.md +98 -0
- examples/blog_posts/pokemon_vl/configs/eval_qwen3_vl.toml +25 -0
- examples/blog_posts/pokemon_vl/configs/eval_rl_final.toml +24 -0
- examples/blog_posts/pokemon_vl/configs/filter_high_reward.toml +10 -0
- examples/blog_posts/pokemon_vl/configs/train_rl_from_sft.toml +42 -0
- examples/blog_posts/pokemon_vl/configs/train_sft_qwen4b_vl.toml +40 -0
- examples/blog_posts/warming_up_to_rl/README.md +158 -0
- examples/blog_posts/warming_up_to_rl/configs/eval_ft_qwen4b.toml +25 -0
- examples/blog_posts/warming_up_to_rl/configs/eval_groq_qwen32b.toml +25 -0
- examples/blog_posts/warming_up_to_rl/configs/eval_openai_gpt_oss_120b.toml +29 -0
- examples/blog_posts/warming_up_to_rl/configs/filter_high_reward_dataset.toml +10 -0
- examples/blog_posts/warming_up_to_rl/configs/train_rl_from_sft.toml +41 -0
- examples/blog_posts/warming_up_to_rl/configs/train_sft_qwen4b.toml +40 -0
- examples/dev/qwen3_32b_qlora_4xh100.toml +5 -0
- examples/multi_step/SFT_README.md +147 -0
- examples/multi_step/configs/crafter_rl_outcome.toml +1 -1
- examples/multi_step/configs/crafter_rl_stepwise_hosted_judge.toml +73 -115
- examples/multi_step/configs/crafter_rl_stepwise_shaped.toml +1 -1
- examples/multi_step/configs/crafter_rl_stepwise_simple.toml +1 -1
- examples/multi_step/configs/crafter_rl_stepwise_simple_NEW_FORMAT.toml +105 -0
- examples/multi_step/configs/crafter_sft_qwen30b_lora.toml +62 -0
- examples/multi_step/configs/verilog_rl_lora.toml +80 -123
- examples/multi_step/convert_traces_to_sft.py +84 -0
- examples/multi_step/run_sft_qwen30b.sh +45 -0
- examples/qwen_coder/configs/coder_lora_30b.toml +1 -2
- examples/qwen_coder/configs/coder_lora_4b.toml +5 -1
- examples/qwen_coder/configs/coder_lora_small.toml +1 -2
- examples/qwen_vl/BUGS_AND_FIXES.md +232 -0
- examples/qwen_vl/IMAGE_VALIDATION_COMPLETE.md +271 -0
- examples/qwen_vl/IMAGE_VALIDATION_SUMMARY.md +260 -0
- examples/qwen_vl/INFERENCE_SFT_TESTS.md +412 -0
- examples/qwen_vl/NEXT_STEPS_2B.md +325 -0
- examples/qwen_vl/QUICKSTART.md +327 -0
- examples/qwen_vl/QUICKSTART_RL_VISION.md +110 -0
- examples/qwen_vl/README.md +152 -0
- examples/qwen_vl/RL_VISION_COMPLETE.md +475 -0
- examples/qwen_vl/RL_VISION_TESTING.md +333 -0
- examples/qwen_vl/SDK_VISION_INTEGRATION.md +328 -0
- examples/qwen_vl/SETUP_COMPLETE.md +274 -0
- examples/qwen_vl/VISION_TESTS_COMPLETE.md +489 -0
- examples/qwen_vl/VLM_PIPELINE_COMPLETE.md +242 -0
- examples/qwen_vl/__init__.py +2 -0
- examples/qwen_vl/collect_data_via_cli.md +415 -0
- examples/qwen_vl/collect_vision_traces.py +368 -0
- examples/qwen_vl/configs/crafter_rl_vision_qwen3vl4b.toml +110 -0
- examples/qwen_vl/configs/crafter_vlm_sft_example.toml +59 -0
- examples/qwen_vl/configs/eval_gpt4o_mini_vision.toml +26 -0
- examples/qwen_vl/configs/eval_gpt4o_vision_proper.toml +29 -0
- examples/qwen_vl/configs/eval_gpt5nano_vision.toml +26 -0
- examples/qwen_vl/configs/eval_qwen3vl_vision.toml +26 -0
- examples/qwen_vl/configs/filter_qwen3vl_sft.toml +49 -0
- examples/qwen_vl/configs/filter_vision_sft.toml +52 -0
- examples/qwen_vl/configs/filter_vision_test.toml +8 -0
- examples/qwen_vl/configs/sft_qwen3_vl_2b_test.toml +54 -0
- examples/qwen_vl/crafter_gpt5nano_agent.py +308 -0
- examples/qwen_vl/crafter_qwen_vl_agent.py +300 -0
- examples/qwen_vl/run_vision_comparison.sh +61 -0
- examples/qwen_vl/run_vision_sft_pipeline.sh +175 -0
- examples/qwen_vl/test_image_validation.py +201 -0
- examples/qwen_vl/test_sft_vision_data.py +110 -0
- examples/rl/README.md +6 -6
- examples/rl/configs/eval_base_qwen.toml +17 -0
- examples/rl/configs/eval_rl_qwen.toml +13 -0
- examples/rl/configs/rl_from_base_qwen.toml +62 -0
- examples/rl/configs/rl_from_base_qwen17.toml +79 -0
- examples/rl/configs/rl_from_ft_qwen.toml +37 -0
- examples/rl/run_eval.py +436 -0
- examples/rl/run_rl_and_save.py +111 -0
- examples/rl/task_app/README.md +21 -0
- examples/rl/task_app/math_single_step.py +990 -0
- examples/rl/task_app/math_task_app.py +111 -0
- examples/run_crafter_demo.sh +2 -2
- examples/sft/README.md +6 -6
- examples/sft/configs/crafter_fft_qwen0p6b.toml +7 -2
- examples/sft/configs/crafter_lora_qwen0p6b.toml +7 -3
- examples/sft/evaluate.py +2 -4
- examples/sft/export_dataset.py +7 -4
- examples/swe/task_app/README.md +33 -3
- examples/swe/task_app/grpo_swe_mini.py +4 -1
- examples/swe/task_app/grpo_swe_mini_task_app.py +0 -12
- examples/swe/task_app/hosted/envs/crafter/react_agent.py +1 -1
- examples/swe/task_app/hosted/envs/mini_swe/environment.py +50 -23
- examples/swe/task_app/hosted/inference/openai_client.py +4 -4
- examples/swe/task_app/hosted/policy_routes.py +0 -2
- examples/swe/task_app/hosted/rollout.py +0 -8
- examples/swe/task_app/morph_backend.py +178 -0
- examples/task_apps/crafter/task_app/README.md +1 -1
- examples/task_apps/crafter/task_app/grpo_crafter.py +70 -10
- examples/task_apps/crafter/task_app/grpo_crafter_task_app.py +1 -1
- examples/task_apps/crafter/task_app/synth_envs_hosted/envs/crafter/policy.py +63 -27
- examples/task_apps/crafter/task_app/synth_envs_hosted/envs/crafter/react_agent.py +1 -2
- examples/task_apps/crafter/task_app/synth_envs_hosted/inference/openai_client.py +48 -50
- examples/task_apps/crafter/task_app/synth_envs_hosted/policy_routes.py +75 -36
- examples/task_apps/crafter/task_app/synth_envs_hosted/rollout.py +31 -15
- examples/task_apps/enron/__init__.py +1 -0
- examples/task_apps/enron/task_app/grpo_enron_task_app.py +1 -1
- examples/task_apps/math/README.md +1 -2
- examples/task_apps/pokemon_red/README.md +3 -4
- examples/task_apps/pokemon_red/eval_image_only_gpt4o.toml +6 -5
- examples/task_apps/pokemon_red/eval_pokemon_red_policy.py +1 -2
- examples/task_apps/pokemon_red/task_app.py +36 -5
- examples/task_apps/sokoban/README.md +2 -3
- examples/task_apps/verilog/eval_groq_qwen32b.toml +12 -14
- examples/task_apps/verilog/task_app/grpo_verilog_task_app.py +1 -1
- examples/vlm/README.md +3 -3
- examples/vlm/configs/crafter_vlm_gpt4o.toml +5 -0
- examples/vlm/crafter_openai_vlm_agent.py +3 -5
- examples/vlm/filter_image_rows.py +1 -1
- examples/vlm/run_crafter_vlm_benchmark.py +2 -2
- examples/warming_up_to_rl/_utils.py +92 -0
- examples/warming_up_to_rl/analyze_trace_db.py +1 -1
- examples/warming_up_to_rl/configs/crafter_fft.toml +5 -0
- examples/warming_up_to_rl/configs/eval_fft_qwen4b.toml +2 -0
- examples/warming_up_to_rl/configs/eval_groq_qwen32b.toml +2 -0
- examples/warming_up_to_rl/configs/eval_modal_qwen4b.toml +2 -1
- examples/warming_up_to_rl/configs/rl_from_base_qwen4b.toml +2 -1
- examples/warming_up_to_rl/configs/rl_from_ft.toml +2 -0
- examples/warming_up_to_rl/export_trace_sft.py +174 -60
- examples/warming_up_to_rl/readme.md +63 -132
- examples/warming_up_to_rl/run_fft_and_save.py +1 -1
- examples/warming_up_to_rl/run_local_rollout_traced.py +1 -1
- examples/warming_up_to_rl/run_rl_and_save.py +1 -1
- examples/warming_up_to_rl/task_app/README.md +42 -0
- examples/warming_up_to_rl/task_app/grpo_crafter.py +827 -0
- examples/warming_up_to_rl/task_app/grpo_crafter_task_app.py +135 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/README.md +173 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/__init__.py +5 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/branching.py +143 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/environment_routes.py +1226 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/__init__.py +1 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/__init__.py +6 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/app.py +1 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/environment.py +522 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/policy.py +454 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/react_agent.py +108 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/shared.py +305 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/tools.py +47 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/hosted_app.py +204 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/__init__.py +5 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/openai_client.py +618 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/main.py +100 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/policy_routes.py +1084 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/registry.py +195 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/rollout.py +1861 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/storage/__init__.py +5 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/storage/volume.py +211 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/test_agents.py +161 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/test_service.py +137 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/utils.py +62 -0
- examples/workflows/math_rl/configs/rl_from_base_qwen.toml +27 -0
- examples/workflows/math_rl/configs/rl_from_base_qwen17.toml +5 -0
- synth_ai/__init__.py +44 -30
- synth_ai/_utils/__init__.py +47 -0
- synth_ai/_utils/base_url.py +10 -0
- synth_ai/_utils/http.py +10 -0
- synth_ai/_utils/prompts.py +10 -0
- synth_ai/_utils/task_app_state.py +12 -0
- synth_ai/_utils/user_config.py +10 -0
- synth_ai/api/models/supported.py +144 -7
- synth_ai/api/train/__init__.py +13 -1
- synth_ai/api/train/builders.py +9 -3
- synth_ai/api/train/cli.py +155 -17
- synth_ai/api/train/config_finder.py +18 -11
- synth_ai/api/train/configs/__init__.py +8 -1
- synth_ai/api/train/configs/rl.py +32 -7
- synth_ai/api/train/configs/sft.py +6 -2
- synth_ai/api/train/configs/shared.py +59 -2
- synth_ai/api/train/env_resolver.py +13 -10
- synth_ai/auth/credentials.py +119 -0
- synth_ai/cli/__init__.py +61 -69
- synth_ai/cli/_modal_wrapper.py +7 -5
- synth_ai/cli/_typer_patch.py +0 -2
- synth_ai/cli/_validate_task_app.py +22 -4
- synth_ai/cli/commands/__init__.py +17 -0
- synth_ai/cli/commands/demo/__init__.py +6 -0
- synth_ai/cli/commands/demo/core.py +163 -0
- synth_ai/cli/commands/deploy/__init__.py +23 -0
- synth_ai/cli/commands/deploy/core.py +614 -0
- synth_ai/cli/commands/deploy/errors.py +72 -0
- synth_ai/cli/commands/deploy/validation.py +11 -0
- synth_ai/cli/commands/eval/__init__.py +19 -0
- synth_ai/cli/commands/eval/core.py +1109 -0
- synth_ai/cli/commands/eval/errors.py +81 -0
- synth_ai/cli/commands/eval/validation.py +133 -0
- synth_ai/cli/commands/filter/__init__.py +12 -0
- synth_ai/cli/commands/filter/core.py +388 -0
- synth_ai/cli/commands/filter/errors.py +55 -0
- synth_ai/cli/commands/filter/validation.py +77 -0
- synth_ai/cli/commands/help/__init__.py +177 -0
- synth_ai/cli/commands/help/core.py +73 -0
- synth_ai/cli/commands/status/__init__.py +64 -0
- synth_ai/cli/commands/status/client.py +192 -0
- synth_ai/cli/commands/status/config.py +92 -0
- synth_ai/cli/commands/status/errors.py +20 -0
- synth_ai/cli/commands/status/formatters.py +164 -0
- synth_ai/cli/commands/status/subcommands/__init__.py +9 -0
- synth_ai/cli/commands/status/subcommands/files.py +79 -0
- synth_ai/cli/commands/status/subcommands/jobs.py +334 -0
- synth_ai/cli/commands/status/subcommands/models.py +79 -0
- synth_ai/cli/commands/status/subcommands/runs.py +81 -0
- synth_ai/cli/commands/status/subcommands/summary.py +47 -0
- synth_ai/cli/commands/status/utils.py +114 -0
- synth_ai/cli/commands/train/__init__.py +53 -0
- synth_ai/cli/commands/train/core.py +21 -0
- synth_ai/cli/commands/train/errors.py +117 -0
- synth_ai/cli/commands/train/judge_schemas.py +199 -0
- synth_ai/cli/commands/train/judge_validation.py +304 -0
- synth_ai/cli/commands/train/validation.py +443 -0
- synth_ai/cli/demo.py +2 -162
- synth_ai/cli/deploy/__init__.py +28 -0
- synth_ai/cli/deploy/core.py +5 -0
- synth_ai/cli/deploy/errors.py +23 -0
- synth_ai/cli/deploy/validation.py +5 -0
- synth_ai/cli/eval/__init__.py +36 -0
- synth_ai/cli/eval/core.py +5 -0
- synth_ai/cli/eval/errors.py +31 -0
- synth_ai/cli/eval/validation.py +5 -0
- synth_ai/cli/filter/__init__.py +28 -0
- synth_ai/cli/filter/core.py +5 -0
- synth_ai/cli/filter/errors.py +23 -0
- synth_ai/cli/filter/validation.py +5 -0
- synth_ai/cli/legacy_root_backup.py +3 -1
- synth_ai/cli/lib/__init__.py +10 -0
- synth_ai/cli/lib/task_app_discovery.py +7 -0
- synth_ai/cli/lib/task_app_env.py +518 -0
- synth_ai/cli/modal_serve/__init__.py +12 -0
- synth_ai/cli/modal_serve/core.py +14 -0
- synth_ai/cli/modal_serve/errors.py +8 -0
- synth_ai/cli/modal_serve/validation.py +11 -0
- synth_ai/cli/recent.py +2 -1
- synth_ai/cli/serve/__init__.py +12 -0
- synth_ai/cli/serve/core.py +14 -0
- synth_ai/cli/serve/errors.py +8 -0
- synth_ai/cli/serve/validation.py +11 -0
- synth_ai/cli/setup.py +21 -0
- synth_ai/cli/status.py +7 -126
- synth_ai/cli/task_app_deploy.py +7 -0
- synth_ai/cli/task_app_list.py +25 -0
- synth_ai/cli/task_app_modal_serve.py +11 -0
- synth_ai/cli/task_app_serve.py +11 -0
- synth_ai/cli/task_apps.py +110 -1499
- synth_ai/cli/traces.py +1 -1
- synth_ai/cli/train/__init__.py +12 -0
- synth_ai/cli/train/core.py +21 -0
- synth_ai/cli/train/errors.py +8 -0
- synth_ai/cli/train/validation.py +24 -0
- synth_ai/cli/train.py +5 -0
- synth_ai/cli/turso.py +1 -1
- synth_ai/cli/watch.py +1 -1
- synth_ai/demos/__init__.py +10 -0
- synth_ai/demos/core/__init__.py +28 -1
- synth_ai/demos/crafter/__init__.py +1 -0
- synth_ai/demos/crafter/crafter_fft_4b.toml +55 -0
- synth_ai/demos/crafter/grpo_crafter_task_app.py +185 -0
- synth_ai/demos/crafter/rl_from_base_qwen4b.toml +74 -0
- synth_ai/demos/demo_registry.py +176 -0
- synth_ai/demos/demo_task_apps/crafter/grpo_crafter_task_app.py +1 -1
- synth_ai/demos/math/__init__.py +1 -0
- synth_ai/demos/math/_common.py +16 -0
- synth_ai/demos/math/app.py +38 -0
- synth_ai/demos/math/config.toml +76 -0
- synth_ai/demos/math/deploy_modal.py +54 -0
- synth_ai/demos/math/modal_task_app.py +702 -0
- synth_ai/demos/math/task_app_entry.py +51 -0
- synth_ai/environments/environment/core.py +7 -1
- synth_ai/environments/examples/bandit/engine.py +0 -1
- synth_ai/environments/examples/bandit/environment.py +0 -1
- synth_ai/environments/examples/red/engine.py +33 -12
- synth_ai/environments/examples/red/engine_helpers/reward_components.py +151 -179
- synth_ai/environments/examples/red/environment.py +26 -0
- synth_ai/environments/examples/red/trace_hooks_v3.py +168 -0
- synth_ai/environments/examples/wordle/environment.py +0 -1
- synth_ai/evals/base.py +16 -5
- synth_ai/evals/client.py +1 -1
- synth_ai/http.py +8 -22
- synth_ai/inference/client.py +1 -1
- synth_ai/judge_schemas.py +4 -5
- synth_ai/learning/client.py +1 -1
- synth_ai/learning/health.py +1 -1
- synth_ai/learning/jobs.py +1 -1
- synth_ai/learning/rl/client.py +4 -2
- synth_ai/learning/rl/env_keys.py +1 -1
- synth_ai/learning/rl/secrets.py +1 -1
- synth_ai/learning/sft/client.py +1 -1
- synth_ai/learning/sft/data.py +407 -4
- synth_ai/learning/validators.py +4 -1
- synth_ai/streaming/__init__.py +29 -0
- synth_ai/streaming/config.py +94 -0
- synth_ai/streaming/handlers.py +469 -0
- synth_ai/streaming/streamer.py +301 -0
- synth_ai/streaming/types.py +95 -0
- synth_ai/task/apps/__init__.py +4 -2
- synth_ai/task/config.py +6 -4
- synth_ai/task/rubrics/__init__.py +1 -2
- synth_ai/task/rubrics/loaders.py +14 -10
- synth_ai/task/rubrics.py +219 -0
- synth_ai/task/trace_correlation_helpers.py +24 -11
- synth_ai/task/tracing_utils.py +14 -3
- synth_ai/task/validators.py +0 -1
- synth_ai/tracing_v3/abstractions.py +3 -3
- synth_ai/tracing_v3/config.py +15 -13
- synth_ai/tracing_v3/constants.py +21 -0
- synth_ai/tracing_v3/db_config.py +3 -1
- synth_ai/tracing_v3/decorators.py +10 -7
- synth_ai/tracing_v3/llm_call_record_helpers.py +5 -5
- synth_ai/tracing_v3/migration_helper.py +1 -2
- synth_ai/tracing_v3/session_tracer.py +7 -7
- synth_ai/tracing_v3/storage/base.py +29 -29
- synth_ai/tracing_v3/storage/config.py +3 -3
- synth_ai/tracing_v3/turso/daemon.py +8 -9
- synth_ai/tracing_v3/turso/native_manager.py +80 -72
- synth_ai/tracing_v3/utils.py +2 -2
- synth_ai/utils/__init__.py +101 -0
- synth_ai/utils/base_url.py +94 -0
- synth_ai/utils/cli.py +131 -0
- synth_ai/utils/env.py +294 -0
- synth_ai/utils/http.py +172 -0
- synth_ai/utils/modal.py +308 -0
- synth_ai/utils/process.py +212 -0
- synth_ai/utils/prompts.py +39 -0
- synth_ai/utils/sqld.py +122 -0
- synth_ai/utils/task_app_discovery.py +882 -0
- synth_ai/utils/task_app_env.py +186 -0
- synth_ai/utils/task_app_state.py +318 -0
- synth_ai/utils/user_config.py +137 -0
- synth_ai/v0/config/__init__.py +1 -5
- synth_ai/v0/config/base_url.py +1 -7
- synth_ai/v0/tracing/config.py +1 -1
- synth_ai/v0/tracing/decorators.py +1 -1
- synth_ai/v0/tracing/upload.py +1 -1
- synth_ai/v0/tracing_v1/config.py +1 -1
- synth_ai/v0/tracing_v1/decorators.py +1 -1
- synth_ai/v0/tracing_v1/upload.py +1 -1
- {synth_ai-0.2.14.dist-info → synth_ai-0.2.17.dist-info}/METADATA +91 -32
- {synth_ai-0.2.14.dist-info → synth_ai-0.2.17.dist-info}/RECORD +341 -154
- synth_ai/cli/man.py +0 -106
- synth_ai/cli/tui.py +0 -57
- synth_ai/compound/cais.py +0 -0
- synth_ai/core/experiment.py +0 -13
- synth_ai/core/system.py +0 -15
- synth_ai/demo_registry.py +0 -295
- synth_ai/handshake.py +0 -109
- synth_ai/tui/__init__.py +0 -5
- synth_ai/tui/__main__.py +0 -13
- synth_ai/tui/cli/__init__.py +0 -1
- synth_ai/tui/cli/query_experiments.py +0 -164
- synth_ai/tui/cli/query_experiments_v3.py +0 -164
- synth_ai/tui/dashboard.py +0 -906
- {synth_ai-0.2.14.dist-info → synth_ai-0.2.17.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.14.dist-info → synth_ai-0.2.17.dist-info}/entry_points.txt +0 -0
- {synth_ai-0.2.14.dist-info → synth_ai-0.2.17.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.2.14.dist-info → synth_ai-0.2.17.dist-info}/top_level.txt +0 -0
synth_ai/api/train/cli.py
CHANGED
|
@@ -1,7 +1,9 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
+
import asyncio
|
|
3
4
|
import importlib
|
|
4
5
|
import os
|
|
6
|
+
import time
|
|
5
7
|
from collections.abc import Callable, Mapping
|
|
6
8
|
from pathlib import Path
|
|
7
9
|
from typing import Any, cast
|
|
@@ -16,10 +18,18 @@ try:
|
|
|
16
18
|
except Exception as exc: # pragma: no cover - critical dependency
|
|
17
19
|
raise RuntimeError("Unable to load backend configuration helpers") from exc
|
|
18
20
|
|
|
21
|
+
from synth_ai.streaming import (
|
|
22
|
+
CLIHandler,
|
|
23
|
+
JobStreamer,
|
|
24
|
+
LossCurveHandler,
|
|
25
|
+
StreamConfig,
|
|
26
|
+
StreamEndpoints,
|
|
27
|
+
StreamType,
|
|
28
|
+
)
|
|
29
|
+
|
|
19
30
|
from .builders import build_rl_payload, build_sft_payload
|
|
20
31
|
from .config_finder import discover_configs, prompt_for_config
|
|
21
32
|
from .env_resolver import KeySpec, resolve_env
|
|
22
|
-
from .pollers import RLJobPoller, SFTJobPoller
|
|
23
33
|
from .task_app import check_task_app_health
|
|
24
34
|
from .utils import (
|
|
25
35
|
REPO_ROOT,
|
|
@@ -36,20 +46,41 @@ from .utils import (
|
|
|
36
46
|
)
|
|
37
47
|
|
|
38
48
|
|
|
39
|
-
def _discover_dataset_candidates(
|
|
49
|
+
def _discover_dataset_candidates(
|
|
50
|
+
config_path: Path, limit: int = 50, timeout: float = 10.0
|
|
51
|
+
) -> list[Path]:
|
|
52
|
+
root = config_path.parent
|
|
53
|
+
parent = root.parent
|
|
54
|
+
cwd = Path.cwd()
|
|
55
|
+
|
|
40
56
|
search_dirs: list[Path] = [
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
57
|
+
root,
|
|
58
|
+
root / "datasets",
|
|
59
|
+
parent,
|
|
60
|
+
parent / "datasets",
|
|
61
|
+
parent / "ft_data",
|
|
62
|
+
cwd,
|
|
63
|
+
cwd / "datasets",
|
|
64
|
+
cwd / "ft_data",
|
|
44
65
|
REPO_ROOT / "datasets",
|
|
66
|
+
REPO_ROOT / "ft_data",
|
|
67
|
+
REPO_ROOT / "traces",
|
|
45
68
|
]
|
|
46
69
|
|
|
47
70
|
candidates: list[Path] = []
|
|
48
71
|
seen: set[Path] = set()
|
|
72
|
+
start = time.monotonic()
|
|
73
|
+
timed_out = False
|
|
49
74
|
for directory in search_dirs:
|
|
75
|
+
if timed_out or time.monotonic() - start > timeout:
|
|
76
|
+
timed_out = True
|
|
77
|
+
break
|
|
50
78
|
if not directory.exists() or not directory.is_dir():
|
|
51
79
|
continue
|
|
52
80
|
for path in directory.rglob("*.jsonl"):
|
|
81
|
+
if time.monotonic() - start > timeout:
|
|
82
|
+
timed_out = True
|
|
83
|
+
break
|
|
53
84
|
try:
|
|
54
85
|
resolved = path.resolve()
|
|
55
86
|
except OSError:
|
|
@@ -113,6 +144,62 @@ def _default_backend() -> str:
|
|
|
113
144
|
return f"{base}/api" if not base.endswith("/api") else base
|
|
114
145
|
|
|
115
146
|
|
|
147
|
+
_DEFAULT_SFT_HIDDEN_EVENTS = {
|
|
148
|
+
"sft.created",
|
|
149
|
+
"sft.pricing.check.requested",
|
|
150
|
+
"sft.pricing.check.allowed",
|
|
151
|
+
"sft.stage",
|
|
152
|
+
"snapshot.fetch",
|
|
153
|
+
"hatchet.preflight",
|
|
154
|
+
"hatchet.submission.attempt",
|
|
155
|
+
"hatchet.submission.result",
|
|
156
|
+
"sft.running",
|
|
157
|
+
"sft.status",
|
|
158
|
+
"sft.worker.alive",
|
|
159
|
+
"sft.dispatch.selected",
|
|
160
|
+
"sft.config.prepared",
|
|
161
|
+
"sft.strategy.selected",
|
|
162
|
+
"sft.training.args",
|
|
163
|
+
}
|
|
164
|
+
|
|
165
|
+
_DEFAULT_RL_HIDDEN_SUBSTRINGS = {"modal", "hatchet"}
|
|
166
|
+
|
|
167
|
+
|
|
168
|
+
def _build_stream_components(
|
|
169
|
+
stream_format: str,
|
|
170
|
+
*,
|
|
171
|
+
hidden_event_types: set[str] | None = None,
|
|
172
|
+
hidden_event_substrings: set[str] | None = None,
|
|
173
|
+
) -> tuple[StreamConfig, list]:
|
|
174
|
+
"""Return stream configuration and handlers for the requested format."""
|
|
175
|
+
if stream_format == "chart":
|
|
176
|
+
config = StreamConfig(
|
|
177
|
+
enabled_streams={StreamType.STATUS, StreamType.EVENTS, StreamType.METRICS},
|
|
178
|
+
event_types={
|
|
179
|
+
"sft.progress",
|
|
180
|
+
"sft.training.started",
|
|
181
|
+
"sft.training.finish",
|
|
182
|
+
"sft.validation.summary",
|
|
183
|
+
"rl.train.step",
|
|
184
|
+
"rl.train.started",
|
|
185
|
+
"rl.train.completed",
|
|
186
|
+
"workflow.completed",
|
|
187
|
+
"workflow.failed",
|
|
188
|
+
},
|
|
189
|
+
metric_names={"train.loss"},
|
|
190
|
+
)
|
|
191
|
+
handlers = [LossCurveHandler()]
|
|
192
|
+
else:
|
|
193
|
+
config = StreamConfig.default()
|
|
194
|
+
handlers = [
|
|
195
|
+
CLIHandler(
|
|
196
|
+
hidden_event_types=hidden_event_types or set(),
|
|
197
|
+
hidden_event_substrings=hidden_event_substrings or set(),
|
|
198
|
+
)
|
|
199
|
+
]
|
|
200
|
+
return config, handlers
|
|
201
|
+
|
|
202
|
+
|
|
116
203
|
@click.command("train")
|
|
117
204
|
@click.option(
|
|
118
205
|
"--config",
|
|
@@ -161,6 +248,13 @@ def _default_backend() -> str:
|
|
|
161
248
|
"--poll-timeout", default=3600.0, type=float, help="Maximum seconds to poll before timing out"
|
|
162
249
|
)
|
|
163
250
|
@click.option("--poll-interval", default=5.0, type=float, help="Seconds between poll attempts")
|
|
251
|
+
@click.option(
|
|
252
|
+
"--stream-format",
|
|
253
|
+
type=click.Choice(["cli", "chart"]),
|
|
254
|
+
default="cli",
|
|
255
|
+
show_default=True,
|
|
256
|
+
help="Streaming output style (cli = line updates, chart = live loss panel)",
|
|
257
|
+
)
|
|
164
258
|
@click.option(
|
|
165
259
|
"--examples",
|
|
166
260
|
"examples_limit",
|
|
@@ -182,6 +276,7 @@ def train_command(
|
|
|
182
276
|
poll: bool,
|
|
183
277
|
poll_timeout: float,
|
|
184
278
|
poll_interval: float,
|
|
279
|
+
stream_format: str,
|
|
185
280
|
examples_limit: int | None,
|
|
186
281
|
) -> None:
|
|
187
282
|
"""Interactive launcher for RL / SFT jobs."""
|
|
@@ -280,6 +375,7 @@ def train_command(
|
|
|
280
375
|
poll=poll,
|
|
281
376
|
poll_timeout=poll_timeout,
|
|
282
377
|
poll_interval=poll_interval,
|
|
378
|
+
stream_format=stream_format,
|
|
283
379
|
)
|
|
284
380
|
else:
|
|
285
381
|
dataset_override_path = Path(dataset_path).expanduser().resolve() if dataset_path else None
|
|
@@ -293,14 +389,23 @@ def train_command(
|
|
|
293
389
|
poll=poll,
|
|
294
390
|
poll_timeout=poll_timeout,
|
|
295
391
|
poll_interval=poll_interval,
|
|
392
|
+
stream_format=stream_format,
|
|
296
393
|
examples_limit=examples_limit,
|
|
297
394
|
)
|
|
298
395
|
|
|
299
396
|
|
|
300
397
|
def _wait_for_training_file(
|
|
301
|
-
backend_base: str, api_key: str, file_id: str, *, timeout: float =
|
|
398
|
+
backend_base: str, api_key: str, file_id: str, *, timeout: float = 10.0
|
|
302
399
|
) -> None:
|
|
303
|
-
|
|
400
|
+
"""Wait for training file to be visible after upload.
|
|
401
|
+
|
|
402
|
+
Reduced from 120s to 10s because:
|
|
403
|
+
- POST response already confirms file is uploaded
|
|
404
|
+
- Backend now forces read-your-writes consistency
|
|
405
|
+
- By job creation time, replica lag has resolved
|
|
406
|
+
- Quick sanity check only, not critical path
|
|
407
|
+
"""
|
|
408
|
+
url = f"{backend_base.rstrip('/')}/files/{file_id}"
|
|
304
409
|
headers = {"Authorization": f"Bearer {api_key}"}
|
|
305
410
|
elapsed = 0.0
|
|
306
411
|
interval = 2.0
|
|
@@ -378,6 +483,7 @@ def handle_rl(
|
|
|
378
483
|
poll: bool,
|
|
379
484
|
poll_timeout: float,
|
|
380
485
|
poll_interval: float,
|
|
486
|
+
stream_format: str,
|
|
381
487
|
) -> None:
|
|
382
488
|
overrides: dict[str, Any] = {
|
|
383
489
|
"backend": backend_base,
|
|
@@ -475,10 +581,25 @@ def handle_rl(
|
|
|
475
581
|
click.echo(f"Created job {job_id} (polling disabled)")
|
|
476
582
|
return
|
|
477
583
|
|
|
478
|
-
|
|
479
|
-
|
|
480
|
-
|
|
481
|
-
|
|
584
|
+
click.echo("\n=== Streaming Job Progress ===")
|
|
585
|
+
config, handlers = _build_stream_components(
|
|
586
|
+
stream_format, hidden_event_substrings=_DEFAULT_RL_HIDDEN_SUBSTRINGS
|
|
587
|
+
)
|
|
588
|
+
if stream_format == "chart":
|
|
589
|
+
click.echo("Using live loss chart (metric=train.loss)")
|
|
590
|
+
streamer = JobStreamer(
|
|
591
|
+
base_url=backend_base,
|
|
592
|
+
api_key=synth_key,
|
|
593
|
+
job_id=job_id,
|
|
594
|
+
endpoints=StreamEndpoints.rl(job_id),
|
|
595
|
+
config=config,
|
|
596
|
+
handlers=handlers,
|
|
597
|
+
interval_seconds=poll_interval,
|
|
598
|
+
timeout_seconds=poll_timeout,
|
|
599
|
+
)
|
|
600
|
+
final_status = asyncio.run(streamer.stream_until_terminal())
|
|
601
|
+
click.echo(f"Final status: {final_status.get('status', 'unknown')}")
|
|
602
|
+
click.echo(preview_json(final_status, limit=600))
|
|
482
603
|
|
|
483
604
|
|
|
484
605
|
def handle_sft(
|
|
@@ -492,6 +613,7 @@ def handle_sft(
|
|
|
492
613
|
poll: bool,
|
|
493
614
|
poll_timeout: float,
|
|
494
615
|
poll_interval: float,
|
|
616
|
+
stream_format: str,
|
|
495
617
|
examples_limit: int | None,
|
|
496
618
|
) -> None:
|
|
497
619
|
dataset_path = dataset_override
|
|
@@ -524,7 +646,7 @@ def handle_sft(
|
|
|
524
646
|
click.echo("Validating validation dataset…")
|
|
525
647
|
validate_sft_jsonl(build.validation_file)
|
|
526
648
|
|
|
527
|
-
upload_url = f"{backend_base}/
|
|
649
|
+
upload_url = f"{backend_base.rstrip('/')}/files"
|
|
528
650
|
click.echo("\n=== Uploading Training Data ===")
|
|
529
651
|
click.echo(f"Dataset: {build.train_file}")
|
|
530
652
|
click.echo(f"Destination: {upload_url}")
|
|
@@ -579,7 +701,8 @@ def handle_sft(
|
|
|
579
701
|
try:
|
|
580
702
|
_wait_for_training_file(backend_base, synth_key, train_file_id)
|
|
581
703
|
except click.ClickException as exc:
|
|
582
|
-
|
|
704
|
+
click.echo(f"[WARN] File readiness check failed: {exc}")
|
|
705
|
+
click.echo("Proceeding anyway - backend will validate file during job creation...")
|
|
583
706
|
|
|
584
707
|
click.echo("\n=== Creating Training Job ===")
|
|
585
708
|
click.echo("Job payload preview:")
|
|
@@ -618,10 +741,25 @@ def handle_sft(
|
|
|
618
741
|
click.echo(f"Started job {job_id} (polling disabled)")
|
|
619
742
|
return
|
|
620
743
|
|
|
621
|
-
|
|
622
|
-
|
|
623
|
-
|
|
624
|
-
|
|
744
|
+
click.echo("\n=== Streaming Job Progress ===")
|
|
745
|
+
config, handlers = _build_stream_components(
|
|
746
|
+
stream_format, hidden_event_types=_DEFAULT_SFT_HIDDEN_EVENTS
|
|
747
|
+
)
|
|
748
|
+
if stream_format == "chart":
|
|
749
|
+
click.echo("Using live loss chart (metric=train.loss)")
|
|
750
|
+
streamer = JobStreamer(
|
|
751
|
+
base_url=backend_base,
|
|
752
|
+
api_key=synth_key,
|
|
753
|
+
job_id=job_id,
|
|
754
|
+
endpoints=StreamEndpoints.learning(job_id),
|
|
755
|
+
config=config,
|
|
756
|
+
handlers=handlers,
|
|
757
|
+
interval_seconds=poll_interval,
|
|
758
|
+
timeout_seconds=poll_timeout,
|
|
759
|
+
)
|
|
760
|
+
final_status = asyncio.run(streamer.stream_until_terminal())
|
|
761
|
+
click.echo(f"Final status: {final_status.get('status', 'unknown')}")
|
|
762
|
+
click.echo(preview_json(final_status, limit=600))
|
|
625
763
|
finally:
|
|
626
764
|
if limited_path is not None:
|
|
627
765
|
try:
|
|
@@ -1,7 +1,6 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
3
|
import json
|
|
4
|
-
import os
|
|
5
4
|
from collections.abc import Iterable
|
|
6
5
|
from dataclasses import dataclass
|
|
7
6
|
from pathlib import Path
|
|
@@ -11,7 +10,9 @@ import click
|
|
|
11
10
|
from .utils import REPO_ROOT, load_toml, preview_json
|
|
12
11
|
|
|
13
12
|
_SKIP_DIRS = {".git", "__pycache__", ".venv", "node_modules", "dist", "build"}
|
|
14
|
-
|
|
13
|
+
|
|
14
|
+
_STATE_DIR = Path.home() / ".synth-ai"
|
|
15
|
+
_STATE_FILE = _STATE_DIR / "train_cli.json"
|
|
15
16
|
|
|
16
17
|
|
|
17
18
|
@dataclass(slots=True)
|
|
@@ -23,8 +24,8 @@ class ConfigCandidate:
|
|
|
23
24
|
def _load_last_config() -> Path | None:
|
|
24
25
|
"""Load the last used training config path from state file."""
|
|
25
26
|
try:
|
|
26
|
-
if
|
|
27
|
-
with open(
|
|
27
|
+
if _STATE_FILE.is_file():
|
|
28
|
+
with _STATE_FILE.open() as fh:
|
|
28
29
|
data = json.load(fh)
|
|
29
30
|
if isinstance(data, dict):
|
|
30
31
|
last_config = data.get("LAST_CONFIG")
|
|
@@ -41,14 +42,14 @@ def _save_last_config(config_path: Path) -> None:
|
|
|
41
42
|
"""Save the last used training config path to state file."""
|
|
42
43
|
try:
|
|
43
44
|
data = {}
|
|
44
|
-
if
|
|
45
|
-
with open(
|
|
45
|
+
if _STATE_FILE.is_file():
|
|
46
|
+
with _STATE_FILE.open() as fh:
|
|
46
47
|
data = json.load(fh) or {}
|
|
47
48
|
if not isinstance(data, dict):
|
|
48
49
|
data = {}
|
|
49
50
|
data["LAST_CONFIG"] = str(config_path.resolve())
|
|
50
|
-
|
|
51
|
-
with open(
|
|
51
|
+
_STATE_DIR.mkdir(parents=True, exist_ok=True)
|
|
52
|
+
with _STATE_FILE.open("w") as fh:
|
|
52
53
|
json.dump(data, fh)
|
|
53
54
|
except Exception:
|
|
54
55
|
pass
|
|
@@ -77,6 +78,7 @@ def _iter_candidate_paths() -> Iterable[Path]:
|
|
|
77
78
|
REPO_ROOT / "configs",
|
|
78
79
|
REPO_ROOT / "examples",
|
|
79
80
|
REPO_ROOT / "training",
|
|
81
|
+
REPO_ROOT / "synth_ai" / "demos",
|
|
80
82
|
]
|
|
81
83
|
for base in preferred:
|
|
82
84
|
if not base.exists():
|
|
@@ -148,6 +150,10 @@ def discover_configs(explicit: list[str], *, requested_type: str | None) -> list
|
|
|
148
150
|
raise click.ClickException(f"Config not found: {path}")
|
|
149
151
|
data = load_toml(path)
|
|
150
152
|
cfg_type = _infer_config_type(data)
|
|
153
|
+
if cfg_type == "unknown":
|
|
154
|
+
raise click.ClickException(
|
|
155
|
+
f"Config {path} is missing algorithm.type/method metadata. Add type = 'rl' or 'sft'."
|
|
156
|
+
)
|
|
151
157
|
candidates.append(ConfigCandidate(path=path, train_type=cfg_type))
|
|
152
158
|
seen.add(path)
|
|
153
159
|
|
|
@@ -162,10 +168,12 @@ def discover_configs(explicit: list[str], *, requested_type: str | None) -> list
|
|
|
162
168
|
except Exception:
|
|
163
169
|
continue
|
|
164
170
|
cfg_type = _infer_config_type(data)
|
|
171
|
+
if cfg_type == "unknown":
|
|
172
|
+
continue
|
|
165
173
|
candidates.append(ConfigCandidate(path=path, train_type=cfg_type))
|
|
166
174
|
|
|
167
175
|
if requested_type and requested_type != "auto":
|
|
168
|
-
candidates = [c for c in candidates if c.train_type
|
|
176
|
+
candidates = [c for c in candidates if c.train_type == requested_type]
|
|
169
177
|
|
|
170
178
|
# De-dupe by path and keep deterministic ordering by directory depth then name
|
|
171
179
|
candidates.sort(key=lambda c: (len(c.path.parts), str(c.path)))
|
|
@@ -196,9 +204,8 @@ def prompt_for_config(
|
|
|
196
204
|
|
|
197
205
|
click.echo("Select a training config:")
|
|
198
206
|
for idx, cand in enumerate(candidates, start=1):
|
|
199
|
-
label = cand.train_type if cand.train_type != "unknown" else "?"
|
|
200
207
|
last_marker = " (last used)" if last_config and cand.path.resolve() == last_config else ""
|
|
201
|
-
click.echo(f" {idx})
|
|
208
|
+
click.echo(f" {idx}) {cand.path}{last_marker}")
|
|
202
209
|
click.echo(" 0) Abort")
|
|
203
210
|
|
|
204
211
|
choice = click.prompt("Enter choice", type=int, default=default_idx)
|
|
@@ -5,10 +5,12 @@ from .rl import (
|
|
|
5
5
|
JudgeConfig,
|
|
6
6
|
JudgeOptionsConfig,
|
|
7
7
|
ModelConfig,
|
|
8
|
+
RewardsConfig,
|
|
8
9
|
RLConfig,
|
|
9
10
|
RLServicesConfig,
|
|
10
11
|
RLTrainingConfig,
|
|
11
12
|
RolloutConfig,
|
|
13
|
+
RubricConfig,
|
|
12
14
|
WeightSyncConfig,
|
|
13
15
|
)
|
|
14
16
|
from .sft import (
|
|
@@ -20,7 +22,7 @@ from .sft import (
|
|
|
20
22
|
TrainingConfig,
|
|
21
23
|
TrainingValidationConfig,
|
|
22
24
|
)
|
|
23
|
-
from .shared import AlgorithmConfig, ComputeConfig
|
|
25
|
+
from .shared import AlgorithmConfig, ComputeConfig, LoraConfig, PolicyConfig, TopologyConfig
|
|
24
26
|
|
|
25
27
|
__all__ = [
|
|
26
28
|
"AlgorithmConfig",
|
|
@@ -31,13 +33,18 @@ __all__ = [
|
|
|
31
33
|
"JobConfig",
|
|
32
34
|
"JudgeConfig",
|
|
33
35
|
"JudgeOptionsConfig",
|
|
36
|
+
"LoraConfig",
|
|
34
37
|
"ModelConfig",
|
|
38
|
+
"PolicyConfig",
|
|
39
|
+
"RewardsConfig",
|
|
35
40
|
"RLConfig",
|
|
36
41
|
"RLServicesConfig",
|
|
37
42
|
"RLTrainingConfig",
|
|
38
43
|
"RolloutConfig",
|
|
44
|
+
"RubricConfig",
|
|
39
45
|
"SFTConfig",
|
|
40
46
|
"SFTDataConfig",
|
|
47
|
+
"TopologyConfig",
|
|
41
48
|
"TrainingConfig",
|
|
42
49
|
"TrainingValidationConfig",
|
|
43
50
|
"WeightSyncConfig",
|
synth_ai/api/train/configs/rl.py
CHANGED
|
@@ -7,7 +7,7 @@ from typing import Any
|
|
|
7
7
|
from pydantic import model_validator
|
|
8
8
|
|
|
9
9
|
from ..utils import load_toml
|
|
10
|
-
from .shared import AlgorithmConfig, ComputeConfig, ExtraModel
|
|
10
|
+
from .shared import AlgorithmConfig, ComputeConfig, ExtraModel, LoraConfig, PolicyConfig
|
|
11
11
|
|
|
12
12
|
|
|
13
13
|
class RLServicesConfig(ExtraModel):
|
|
@@ -48,6 +48,16 @@ class WeightSyncConfig(ExtraModel):
|
|
|
48
48
|
verify_every_k: int | None = None
|
|
49
49
|
|
|
50
50
|
|
|
51
|
+
class RewardsConfig(ExtraModel):
|
|
52
|
+
"""Rewards configuration for RL training."""
|
|
53
|
+
step_rewards_enabled: bool | None = None
|
|
54
|
+
step_rewards_mode: str | None = None
|
|
55
|
+
step_rewards_indicator_lambda: float | None = None
|
|
56
|
+
step_rewards_beta: float | None = None
|
|
57
|
+
step_rewards_strategy: str | None = None
|
|
58
|
+
event_rewards_kind: str | None = None
|
|
59
|
+
|
|
60
|
+
|
|
51
61
|
class RLTrainingConfig(ExtraModel):
|
|
52
62
|
num_epochs: int
|
|
53
63
|
iterations_per_epoch: int
|
|
@@ -59,13 +69,17 @@ class RLTrainingConfig(ExtraModel):
|
|
|
59
69
|
learning_rate: float
|
|
60
70
|
log_interval: int | None = None
|
|
61
71
|
weight_sync_interval: int | None = None
|
|
72
|
+
# DEPRECATED: flat reward fields (use rewards.* instead)
|
|
62
73
|
step_rewards_enabled: bool | None = None
|
|
63
74
|
step_rewards_mode: str | None = None
|
|
64
75
|
step_rewards_indicator_lambda: float | None = None
|
|
65
76
|
step_rewards_beta: float | None = None
|
|
66
77
|
step_rewards_strategy: str | None = None
|
|
67
78
|
event_rewards_kind: str | None = None
|
|
79
|
+
# NEW: nested configs
|
|
68
80
|
weight_sync: WeightSyncConfig | None = None
|
|
81
|
+
lora: LoraConfig | None = None
|
|
82
|
+
rewards: RewardsConfig | None = None
|
|
69
83
|
|
|
70
84
|
|
|
71
85
|
class EvaluationConfig(ExtraModel):
|
|
@@ -86,9 +100,18 @@ class JudgeOptionsConfig(ExtraModel):
|
|
|
86
100
|
max_concurrency: int | None = None
|
|
87
101
|
|
|
88
102
|
|
|
103
|
+
class RubricConfig(ExtraModel):
|
|
104
|
+
"""Rubric configuration for reward blending."""
|
|
105
|
+
enabled: bool = False
|
|
106
|
+
reward_blend: dict[str, float] | None = None # env, event, outcome weights
|
|
107
|
+
|
|
108
|
+
|
|
89
109
|
class JudgeConfig(ExtraModel):
|
|
90
110
|
type: str | None = None
|
|
91
111
|
timeout_s: int | None = None
|
|
112
|
+
enabled: bool | None = None # Master switch for judge/rubric
|
|
113
|
+
reward_blend: dict[str, float] | None = None # NEW: nested reward blending (replaces rubric.weights)
|
|
114
|
+
rubric: RubricConfig | None = None # DEPRECATED: use flat fields instead
|
|
92
115
|
options: JudgeOptionsConfig | None = None
|
|
93
116
|
|
|
94
117
|
|
|
@@ -96,15 +119,16 @@ class RLConfig(ExtraModel):
|
|
|
96
119
|
algorithm: AlgorithmConfig
|
|
97
120
|
services: RLServicesConfig
|
|
98
121
|
compute: ComputeConfig | None = None
|
|
99
|
-
topology: dict[str, Any] | None = None
|
|
122
|
+
topology: dict[str, Any] | None = None # DEPRECATED: use compute.topology instead
|
|
100
123
|
vllm: dict[str, Any] | None = None
|
|
101
|
-
reference: dict[str, Any] | None = None
|
|
102
|
-
model: ModelConfig
|
|
103
|
-
|
|
124
|
+
reference: dict[str, Any] | None = None # DEPRECATED: use compute.topology.reference_placement instead
|
|
125
|
+
model: ModelConfig | None = None # DEPRECATED: use policy instead
|
|
126
|
+
policy: PolicyConfig | None = None # NEW: unified policy (preferred)
|
|
127
|
+
lora: dict[str, Any] | None = None # DEPRECATED: use training.lora instead
|
|
104
128
|
rollout: RolloutConfig | None = None
|
|
105
129
|
evaluation: EvaluationConfig | None = None
|
|
106
130
|
training: RLTrainingConfig | None = None
|
|
107
|
-
rubric: dict[str, Any] | None = None
|
|
131
|
+
rubric: dict[str, Any] | None = None # DEPRECATED: use judge.reward_blend and judge.enabled instead
|
|
108
132
|
judge: JudgeConfig | None = None
|
|
109
133
|
tags: dict[str, Any] | None = None
|
|
110
134
|
|
|
@@ -113,7 +137,8 @@ class RLConfig(ExtraModel):
|
|
|
113
137
|
|
|
114
138
|
@classmethod
|
|
115
139
|
def from_mapping(cls, data: Mapping[str, Any]) -> RLConfig:
|
|
116
|
-
|
|
140
|
+
"""Load RL config from dict/TOML mapping."""
|
|
141
|
+
return cls.model_validate(data)
|
|
117
142
|
|
|
118
143
|
@classmethod
|
|
119
144
|
def from_path(cls, path: Path) -> RLConfig:
|
|
@@ -7,7 +7,7 @@ from typing import Any
|
|
|
7
7
|
from pydantic import Field
|
|
8
8
|
|
|
9
9
|
from ..utils import load_toml
|
|
10
|
-
from .shared import AlgorithmConfig, ComputeConfig, ExtraModel
|
|
10
|
+
from .shared import AlgorithmConfig, ComputeConfig, ExtraModel, LoraConfig, PolicyConfig
|
|
11
11
|
|
|
12
12
|
|
|
13
13
|
class JobConfig(ExtraModel):
|
|
@@ -35,6 +35,7 @@ class TrainingConfig(ExtraModel):
|
|
|
35
35
|
mode: str | None = None
|
|
36
36
|
use_qlora: bool | None = None
|
|
37
37
|
validation: TrainingValidationConfig | None = None
|
|
38
|
+
lora: LoraConfig | None = None # NEW: nested LoRA config
|
|
38
39
|
|
|
39
40
|
|
|
40
41
|
class HyperparametersParallelism(ExtraModel):
|
|
@@ -65,10 +66,12 @@ class HyperparametersConfig(ExtraModel):
|
|
|
65
66
|
class SFTConfig(ExtraModel):
|
|
66
67
|
algorithm: AlgorithmConfig | None = None
|
|
67
68
|
job: JobConfig
|
|
69
|
+
policy: PolicyConfig | None = None # NEW: unified policy section
|
|
68
70
|
compute: ComputeConfig | None = None
|
|
69
71
|
data: SFTDataConfig | None = None
|
|
70
72
|
training: TrainingConfig | None = None
|
|
71
73
|
hyperparameters: HyperparametersConfig = Field(default_factory=HyperparametersConfig)
|
|
74
|
+
lora: dict[str, Any] | None = None # DEPRECATED: use training.lora instead
|
|
72
75
|
tags: dict[str, Any] | None = None
|
|
73
76
|
|
|
74
77
|
def to_dict(self) -> dict[str, Any]:
|
|
@@ -76,7 +79,8 @@ class SFTConfig(ExtraModel):
|
|
|
76
79
|
|
|
77
80
|
@classmethod
|
|
78
81
|
def from_mapping(cls, data: Mapping[str, Any]) -> SFTConfig:
|
|
79
|
-
|
|
82
|
+
"""Load SFT config from dict/TOML mapping."""
|
|
83
|
+
return cls.model_validate(data)
|
|
80
84
|
|
|
81
85
|
@classmethod
|
|
82
86
|
def from_path(cls, path: Path) -> SFTConfig:
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
-
from pydantic import BaseModel, ConfigDict
|
|
3
|
+
from pydantic import BaseModel, ConfigDict, model_validator
|
|
4
4
|
|
|
5
5
|
|
|
6
6
|
class ExtraModel(BaseModel):
|
|
@@ -15,10 +15,67 @@ class AlgorithmConfig(ExtraModel):
|
|
|
15
15
|
variety: str
|
|
16
16
|
|
|
17
17
|
|
|
18
|
+
class TopologyConfig(ExtraModel):
|
|
19
|
+
"""Compute topology configuration - how GPUs are distributed across processes."""
|
|
20
|
+
type: str | None = None # e.g., "single_node_split"
|
|
21
|
+
gpus_for_vllm: int | None = None
|
|
22
|
+
gpus_for_training: int | None = None
|
|
23
|
+
gpus_for_ref: int | None = None
|
|
24
|
+
tensor_parallel: int | None = None
|
|
25
|
+
reference_placement: str | None = None # NEW: e.g., "none", "shared", "dedicated"
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class LoraConfig(ExtraModel):
|
|
29
|
+
"""LoRA (Low-Rank Adaptation) training configuration."""
|
|
30
|
+
r: int | None = None # Rank
|
|
31
|
+
alpha: int | None = None
|
|
32
|
+
dropout: float | None = None
|
|
33
|
+
target_modules: list[str] | None = None
|
|
34
|
+
|
|
35
|
+
|
|
18
36
|
class ComputeConfig(ExtraModel):
|
|
19
37
|
gpu_type: str
|
|
20
38
|
gpu_count: int
|
|
21
39
|
nodes: int | None = None
|
|
40
|
+
topology: TopologyConfig | None = None # NEW: nested topology
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class PolicyConfig(ExtraModel):
|
|
44
|
+
"""Unified policy configuration for both SFT and RL.
|
|
45
|
+
|
|
46
|
+
This is the SINGLE SOURCE OF TRUTH for:
|
|
47
|
+
- What model to use (model_name or source)
|
|
48
|
+
- How to sample from it (temperature, max_tokens, etc.)
|
|
49
|
+
- How to train it (trainer_mode, label)
|
|
50
|
+
"""
|
|
51
|
+
|
|
52
|
+
# Model specification (exactly one required)
|
|
53
|
+
model_name: str | None = None # e.g., "Qwen/Qwen3-4B"
|
|
54
|
+
source: str | None = None # e.g., "ft:abc123" for checkpoints
|
|
55
|
+
|
|
56
|
+
# Sampling parameters (with sensible defaults)
|
|
57
|
+
max_tokens: int = 512
|
|
58
|
+
temperature: float = 0.7
|
|
59
|
+
top_p: float = 0.95
|
|
60
|
+
top_k: int | None = None
|
|
61
|
+
repetition_penalty: float = 1.0
|
|
62
|
+
stop_sequences: list[str] | None = None
|
|
63
|
+
|
|
64
|
+
# Training-specific
|
|
65
|
+
trainer_mode: str # "lora", "full", "qlora"
|
|
66
|
+
label: str # Model identifier/name
|
|
67
|
+
|
|
68
|
+
# Optional - for distributed inference
|
|
69
|
+
inference_url: str | None = None
|
|
70
|
+
|
|
71
|
+
@model_validator(mode="after")
|
|
72
|
+
def _ensure_exactly_one_source(self) -> PolicyConfig:
|
|
73
|
+
"""Ensure exactly one of model_name or source is set."""
|
|
74
|
+
if not (bool(self.model_name) ^ bool(self.source)):
|
|
75
|
+
raise ValueError(
|
|
76
|
+
"Must set exactly one: [policy].model_name OR [policy].source"
|
|
77
|
+
)
|
|
78
|
+
return self
|
|
22
79
|
|
|
23
80
|
|
|
24
|
-
__all__ = ["ExtraModel", "AlgorithmConfig", "ComputeConfig"]
|
|
81
|
+
__all__ = ["ExtraModel", "AlgorithmConfig", "ComputeConfig", "PolicyConfig", "TopologyConfig", "LoraConfig"]
|
|
@@ -8,6 +8,7 @@ from pathlib import Path
|
|
|
8
8
|
from typing import Any, cast
|
|
9
9
|
|
|
10
10
|
import click
|
|
11
|
+
from synth_ai.utils.env import resolve_env_var
|
|
11
12
|
|
|
12
13
|
from . import task_app
|
|
13
14
|
from .utils import REPO_ROOT, mask_value, read_env_file, write_env_value
|
|
@@ -232,18 +233,16 @@ def _resolve_key(resolver: EnvResolver, spec: KeySpec) -> str:
|
|
|
232
233
|
_maybe_persist(resolver, spec, env_val)
|
|
233
234
|
os.environ[spec.name] = env_val
|
|
234
235
|
return env_val
|
|
235
|
-
options: list[tuple[str, Callable[[], str | None]]] = []
|
|
236
236
|
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
return value
|
|
237
|
+
resolve_env_var(spec.name)
|
|
238
|
+
resolved_value = os.environ.get(spec.name)
|
|
239
|
+
if resolved_value:
|
|
240
|
+
click.echo(f"Found {spec.name} via secrets helper: {mask_value(resolved_value)}")
|
|
241
|
+
_maybe_persist(resolver, spec, resolved_value)
|
|
242
|
+
os.environ[spec.name] = resolved_value
|
|
243
|
+
return resolved_value
|
|
245
244
|
|
|
246
|
-
options
|
|
245
|
+
options: list[tuple[str, Callable[[], str | None]]] = []
|
|
247
246
|
|
|
248
247
|
def _pick_env() -> str | None:
|
|
249
248
|
resolver.select_new_env()
|
|
@@ -276,6 +275,10 @@ def _resolve_key(resolver: EnvResolver, spec: KeySpec) -> str:
|
|
|
276
275
|
|
|
277
276
|
def _maybe_persist(resolver: EnvResolver, spec: KeySpec, value: str) -> None:
|
|
278
277
|
# Automatically save (no prompt)
|
|
278
|
+
# Skip auto-persisting TASK_APP_URL to prevent overwriting CLI overrides
|
|
279
|
+
if spec.name == "TASK_APP_URL":
|
|
280
|
+
click.echo(f"Skipping auto-persist for {spec.name} (use CLI flags to override)")
|
|
281
|
+
return
|
|
279
282
|
resolver.set_value(spec.name, value)
|
|
280
283
|
click.echo(f"Saved {spec.name} to {resolver.current_path}")
|
|
281
284
|
|