synth-ai 0.2.16__py3-none-any.whl → 0.2.17__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of synth-ai might be problematic. Click here for more details.
- examples/analyze_semantic_words.sh +2 -2
- examples/blog_posts/pokemon_vl/README.md +98 -0
- examples/blog_posts/pokemon_vl/configs/eval_qwen3_vl.toml +25 -0
- examples/blog_posts/pokemon_vl/configs/eval_rl_final.toml +24 -0
- examples/blog_posts/pokemon_vl/configs/filter_high_reward.toml +10 -0
- examples/blog_posts/pokemon_vl/configs/train_rl_from_sft.toml +42 -0
- examples/blog_posts/pokemon_vl/configs/train_sft_qwen4b_vl.toml +40 -0
- examples/blog_posts/warming_up_to_rl/README.md +158 -0
- examples/blog_posts/warming_up_to_rl/configs/eval_ft_qwen4b.toml +25 -0
- examples/blog_posts/warming_up_to_rl/configs/eval_groq_qwen32b.toml +25 -0
- examples/blog_posts/warming_up_to_rl/configs/eval_openai_gpt_oss_120b.toml +29 -0
- examples/blog_posts/warming_up_to_rl/configs/filter_high_reward_dataset.toml +10 -0
- examples/blog_posts/warming_up_to_rl/configs/train_rl_from_sft.toml +41 -0
- examples/blog_posts/warming_up_to_rl/configs/train_sft_qwen4b.toml +40 -0
- examples/dev/qwen3_32b_qlora_4xh100.toml +5 -0
- examples/multi_step/configs/crafter_rl_outcome.toml +1 -1
- examples/multi_step/configs/crafter_rl_stepwise_hosted_judge.toml +65 -107
- examples/multi_step/configs/crafter_rl_stepwise_shaped.toml +1 -1
- examples/multi_step/configs/crafter_rl_stepwise_simple.toml +1 -1
- examples/multi_step/configs/crafter_rl_stepwise_simple_NEW_FORMAT.toml +105 -0
- examples/multi_step/configs/verilog_rl_lora.toml +80 -123
- examples/qwen_coder/configs/coder_lora_30b.toml +1 -3
- examples/qwen_coder/configs/coder_lora_4b.toml +4 -1
- examples/qwen_coder/configs/coder_lora_small.toml +1 -3
- examples/qwen_vl/README.md +10 -12
- examples/qwen_vl/SETUP_COMPLETE.md +7 -8
- examples/qwen_vl/VISION_TESTS_COMPLETE.md +2 -3
- examples/qwen_vl/collect_data_via_cli.md +76 -84
- examples/qwen_vl/collect_vision_traces.py +4 -4
- examples/qwen_vl/configs/crafter_rl_vision_qwen3vl4b.toml +40 -57
- examples/qwen_vl/configs/crafter_vlm_sft_example.toml +1 -2
- examples/qwen_vl/configs/eval_gpt4o_mini_vision.toml +20 -37
- examples/qwen_vl/configs/eval_gpt5nano_vision.toml +21 -40
- examples/qwen_vl/configs/eval_qwen3vl_vision.toml +26 -0
- examples/qwen_vl/configs/{filter_qwen2vl_sft.toml → filter_qwen3vl_sft.toml} +4 -5
- examples/qwen_vl/configs/filter_vision_sft.toml +2 -3
- examples/qwen_vl/crafter_qwen_vl_agent.py +5 -5
- examples/qwen_vl/run_vision_comparison.sh +6 -7
- examples/rl/README.md +5 -5
- examples/rl/configs/rl_from_base_qwen.toml +26 -1
- examples/rl/configs/rl_from_base_qwen17.toml +5 -2
- examples/rl/task_app/README.md +1 -2
- examples/rl/task_app/math_single_step.py +2 -2
- examples/run_crafter_demo.sh +2 -2
- examples/sft/README.md +1 -1
- examples/sft/configs/crafter_fft_qwen0p6b.toml +4 -1
- examples/sft/configs/crafter_lora_qwen0p6b.toml +4 -1
- examples/swe/task_app/README.md +32 -2
- examples/swe/task_app/grpo_swe_mini.py +4 -0
- examples/swe/task_app/hosted/envs/crafter/react_agent.py +1 -1
- examples/swe/task_app/hosted/envs/mini_swe/environment.py +37 -10
- examples/swe/task_app/hosted/inference/openai_client.py +4 -4
- examples/swe/task_app/morph_backend.py +178 -0
- examples/task_apps/crafter/task_app/README.md +1 -1
- examples/task_apps/crafter/task_app/grpo_crafter.py +66 -3
- examples/task_apps/crafter/task_app/grpo_crafter_task_app.py +1 -1
- examples/task_apps/crafter/task_app/synth_envs_hosted/envs/crafter/policy.py +4 -26
- examples/task_apps/crafter/task_app/synth_envs_hosted/envs/crafter/react_agent.py +1 -2
- examples/task_apps/crafter/task_app/synth_envs_hosted/inference/openai_client.py +17 -49
- examples/task_apps/crafter/task_app/synth_envs_hosted/policy_routes.py +13 -5
- examples/task_apps/crafter/task_app/synth_envs_hosted/rollout.py +15 -1
- examples/task_apps/enron/task_app/grpo_enron_task_app.py +1 -1
- examples/task_apps/math/README.md +1 -2
- examples/task_apps/pokemon_red/README.md +3 -4
- examples/task_apps/pokemon_red/eval_image_only_gpt4o.toml +6 -5
- examples/task_apps/pokemon_red/eval_pokemon_red_policy.py +1 -2
- examples/task_apps/pokemon_red/task_app.py +36 -5
- examples/task_apps/sokoban/README.md +2 -3
- examples/task_apps/verilog/eval_groq_qwen32b.toml +12 -14
- examples/task_apps/verilog/task_app/grpo_verilog_task_app.py +1 -1
- examples/vlm/configs/crafter_vlm_gpt4o.toml +4 -1
- examples/warming_up_to_rl/configs/crafter_fft.toml +4 -1
- examples/warming_up_to_rl/configs/crafter_fft_4b.toml +0 -2
- examples/warming_up_to_rl/configs/rl_from_base_qwen4b.toml +2 -2
- examples/warming_up_to_rl/run_local_rollout_traced.py +1 -1
- examples/warming_up_to_rl/task_app/README.md +1 -1
- examples/warming_up_to_rl/task_app/grpo_crafter.py +134 -3
- examples/warming_up_to_rl/task_app/grpo_crafter_task_app.py +1 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/policy.py +3 -27
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/react_agent.py +1 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/openai_client.py +4 -4
- examples/warming_up_to_rl/task_app/synth_envs_hosted/policy_routes.py +6 -3
- examples/workflows/math_rl/configs/rl_from_base_qwen.toml +27 -0
- examples/workflows/math_rl/configs/rl_from_base_qwen17.toml +5 -0
- synth_ai/api/train/builders.py +9 -3
- synth_ai/api/train/cli.py +125 -10
- synth_ai/api/train/configs/__init__.py +8 -1
- synth_ai/api/train/configs/rl.py +32 -7
- synth_ai/api/train/configs/sft.py +6 -2
- synth_ai/api/train/configs/shared.py +59 -2
- synth_ai/auth/credentials.py +119 -0
- synth_ai/cli/__init__.py +12 -4
- synth_ai/cli/commands/__init__.py +17 -0
- synth_ai/cli/commands/demo/__init__.py +6 -0
- synth_ai/cli/commands/demo/core.py +163 -0
- synth_ai/cli/commands/deploy/__init__.py +23 -0
- synth_ai/cli/commands/deploy/core.py +614 -0
- synth_ai/cli/commands/deploy/errors.py +72 -0
- synth_ai/cli/commands/deploy/validation.py +11 -0
- synth_ai/cli/commands/eval/__init__.py +19 -0
- synth_ai/cli/commands/eval/core.py +1109 -0
- synth_ai/cli/commands/eval/errors.py +81 -0
- synth_ai/cli/commands/eval/validation.py +133 -0
- synth_ai/cli/commands/filter/__init__.py +12 -0
- synth_ai/cli/commands/filter/core.py +388 -0
- synth_ai/cli/commands/filter/errors.py +55 -0
- synth_ai/cli/commands/filter/validation.py +77 -0
- synth_ai/cli/commands/help/__init__.py +177 -0
- synth_ai/cli/commands/help/core.py +73 -0
- synth_ai/cli/commands/status/__init__.py +64 -0
- synth_ai/cli/commands/status/client.py +192 -0
- synth_ai/cli/commands/status/config.py +92 -0
- synth_ai/cli/commands/status/errors.py +20 -0
- synth_ai/cli/commands/status/formatters.py +164 -0
- synth_ai/cli/commands/status/subcommands/__init__.py +9 -0
- synth_ai/cli/commands/status/subcommands/files.py +79 -0
- synth_ai/cli/commands/status/subcommands/jobs.py +334 -0
- synth_ai/cli/commands/status/subcommands/models.py +79 -0
- synth_ai/cli/commands/status/subcommands/runs.py +81 -0
- synth_ai/cli/commands/status/subcommands/summary.py +47 -0
- synth_ai/cli/commands/status/utils.py +114 -0
- synth_ai/cli/commands/train/__init__.py +53 -0
- synth_ai/cli/commands/train/core.py +21 -0
- synth_ai/cli/commands/train/errors.py +117 -0
- synth_ai/cli/commands/train/judge_schemas.py +199 -0
- synth_ai/cli/commands/train/judge_validation.py +304 -0
- synth_ai/cli/commands/train/validation.py +443 -0
- synth_ai/cli/demo.py +2 -162
- synth_ai/cli/deploy/__init__.py +28 -0
- synth_ai/cli/deploy/core.py +5 -0
- synth_ai/cli/deploy/errors.py +23 -0
- synth_ai/cli/deploy/validation.py +5 -0
- synth_ai/cli/eval/__init__.py +36 -0
- synth_ai/cli/eval/core.py +5 -0
- synth_ai/cli/eval/errors.py +31 -0
- synth_ai/cli/eval/validation.py +5 -0
- synth_ai/cli/filter/__init__.py +28 -0
- synth_ai/cli/filter/core.py +5 -0
- synth_ai/cli/filter/errors.py +23 -0
- synth_ai/cli/filter/validation.py +5 -0
- synth_ai/cli/modal_serve/__init__.py +12 -0
- synth_ai/cli/modal_serve/core.py +14 -0
- synth_ai/cli/modal_serve/errors.py +8 -0
- synth_ai/cli/modal_serve/validation.py +11 -0
- synth_ai/cli/serve/__init__.py +12 -0
- synth_ai/cli/serve/core.py +14 -0
- synth_ai/cli/serve/errors.py +8 -0
- synth_ai/cli/serve/validation.py +11 -0
- synth_ai/cli/setup.py +20 -265
- synth_ai/cli/status.py +7 -126
- synth_ai/cli/task_app_deploy.py +1 -10
- synth_ai/cli/task_app_modal_serve.py +4 -9
- synth_ai/cli/task_app_serve.py +4 -11
- synth_ai/cli/task_apps.py +58 -1487
- synth_ai/cli/train/__init__.py +12 -0
- synth_ai/cli/train/core.py +21 -0
- synth_ai/cli/train/errors.py +8 -0
- synth_ai/cli/train/validation.py +24 -0
- synth_ai/cli/train.py +1 -14
- synth_ai/demos/crafter/grpo_crafter_task_app.py +1 -1
- synth_ai/demos/demo_task_apps/crafter/grpo_crafter_task_app.py +1 -1
- synth_ai/environments/examples/red/engine.py +33 -12
- synth_ai/environments/examples/red/engine_helpers/reward_components.py +151 -179
- synth_ai/environments/examples/red/environment.py +26 -0
- synth_ai/environments/examples/red/trace_hooks_v3.py +168 -0
- synth_ai/http.py +12 -0
- synth_ai/judge_schemas.py +10 -11
- synth_ai/learning/rl/client.py +3 -1
- synth_ai/streaming/__init__.py +29 -0
- synth_ai/streaming/config.py +94 -0
- synth_ai/streaming/handlers.py +469 -0
- synth_ai/streaming/streamer.py +301 -0
- synth_ai/streaming/types.py +95 -0
- synth_ai/task/validators.py +2 -2
- synth_ai/tracing_v3/migration_helper.py +1 -2
- synth_ai/utils/env.py +25 -18
- synth_ai/utils/http.py +4 -1
- synth_ai/utils/modal.py +2 -2
- {synth_ai-0.2.16.dist-info → synth_ai-0.2.17.dist-info}/METADATA +8 -3
- {synth_ai-0.2.16.dist-info → synth_ai-0.2.17.dist-info}/RECORD +184 -109
- examples/qwen_vl/configs/eval_qwen2vl_vision.toml +0 -44
- synth_ai/cli/tui.py +0 -62
- synth_ai/tui/__init__.py +0 -5
- synth_ai/tui/__main__.py +0 -13
- synth_ai/tui/cli/__init__.py +0 -1
- synth_ai/tui/cli/query_experiments.py +0 -164
- synth_ai/tui/cli/query_experiments_v3.py +0 -164
- synth_ai/tui/dashboard.py +0 -911
- {synth_ai-0.2.16.dist-info → synth_ai-0.2.17.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.16.dist-info → synth_ai-0.2.17.dist-info}/entry_points.txt +0 -0
- {synth_ai-0.2.16.dist-info → synth_ai-0.2.17.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.2.16.dist-info → synth_ai-0.2.17.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,443 @@
|
|
|
1
|
+
"""TOML validation logic for train commands (SFT and RL)."""
|
|
2
|
+
|
|
3
|
+
from collections.abc import MutableMapping
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
from pydantic import ValidationError
|
|
8
|
+
from synth_ai.api.train.configs.rl import RLConfig
|
|
9
|
+
from synth_ai.api.train.configs.sft import SFTConfig
|
|
10
|
+
from synth_ai.api.train.utils import load_toml
|
|
11
|
+
|
|
12
|
+
from .errors import (
|
|
13
|
+
InvalidJudgeConfigError,
|
|
14
|
+
InvalidRLConfigError,
|
|
15
|
+
InvalidRubricConfigError,
|
|
16
|
+
InvalidSFTConfigError,
|
|
17
|
+
MissingAlgorithmError,
|
|
18
|
+
MissingComputeError,
|
|
19
|
+
MissingDatasetError,
|
|
20
|
+
MissingModelError,
|
|
21
|
+
TomlParseError,
|
|
22
|
+
UnsupportedAlgorithmError,
|
|
23
|
+
)
|
|
24
|
+
from .judge_validation import extract_and_validate_judge_rubric
|
|
25
|
+
|
|
26
|
+
__all__ = [
|
|
27
|
+
"validate_sft_config",
|
|
28
|
+
"validate_rl_config",
|
|
29
|
+
"load_and_validate_sft",
|
|
30
|
+
"load_and_validate_rl",
|
|
31
|
+
]
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def validate_sft_config(config: MutableMapping[str, Any]) -> dict[str, Any]:
    """Validate SFT configuration from TOML.

    Runs a sequence of structural checks first so that the most common
    mistakes produce a targeted error, then passes the mapping to
    ``SFTConfig.from_mapping`` and returns the result of ``to_dict()``.

    Checks run in this order: ``[algorithm]`` present, ``[job]`` present,
    ``[job].model``, ``[job].data``/``[job].data_path``, ``algorithm.type``,
    ``algorithm.method``, ``algorithm.variety``, ``[compute]`` present,
    ``compute.gpu_type``, ``compute.gpu_count``.

    Args:
        config: Raw configuration dictionary from TOML

    Returns:
        Validated configuration dictionary

    Raises:
        InvalidSFTConfigError: If the ``[job]`` section is missing or the
            Pydantic validation fails
        MissingAlgorithmError: If algorithm section is missing or invalid
        UnsupportedAlgorithmError: If ``algorithm.type`` or
            ``algorithm.method`` has an unsupported value
        MissingModelError: If model is not specified
        MissingDatasetError: If dataset path is not specified
        MissingComputeError: If compute section is missing required fields
    """
    # Function-scope import: the module only imports MutableMapping at the top.
    from collections.abc import Mapping

    def _is_table(value: Any) -> bool:
        # A section only counts when it is a non-empty table. A scalar such as
        # `algorithm = "sft"` would otherwise crash on `.get` further down.
        return isinstance(value, Mapping) and bool(value)

    # Check for required top-level sections
    algorithm = config.get("algorithm")
    if not _is_table(algorithm):
        raise MissingAlgorithmError(
            detail="[algorithm] section is required for SFT configs"
        )

    job = config.get("job")
    if not _is_table(job):
        raise InvalidSFTConfigError(
            detail="[job] section is required for SFT configs"
        )

    if not job.get("model"):
        raise MissingModelError(
            detail="[job].model is required (e.g., 'Qwen/Qwen3-4B')"
        )

    # Check that at least one dataset source is specified
    if not (job.get("data") or job.get("data_path")):
        raise MissingDatasetError(
            detail="[job].data or [job].data_path must be specified",
            hint="Provide path to training JSONL file",
        )

    # Validate algorithm type, method, and variety.
    # Tuples (not sets) are used for membership so that an unhashable TOML
    # value (e.g. an array) is reported as unsupported instead of raising
    # TypeError.
    algorithm_type = algorithm.get("type")
    if algorithm_type not in ("offline", None):
        raise UnsupportedAlgorithmError(
            algorithm_type=algorithm_type,
            expected="offline",
            hint="SFT requires algorithm.type = 'offline'",
        )

    method = algorithm.get("method", "")
    if method and method not in ("sft", "supervised_finetune"):
        raise UnsupportedAlgorithmError(
            algorithm_type=method,
            expected="sft or supervised_finetune",
            hint="SFT requires algorithm.method = 'sft' or 'supervised_finetune'",
        )

    # Validate variety is present
    if not algorithm.get("variety"):
        raise MissingAlgorithmError(
            detail="[algorithm].variety is required (e.g., 'fft', 'lora', 'qlora')"
        )

    # Validate compute section
    compute = config.get("compute")
    if not _is_table(compute):
        raise MissingComputeError(
            detail="[compute] section is required",
            hint="Specify gpu_type, gpu_count, and nodes",
        )

    if not compute.get("gpu_type"):
        raise MissingComputeError(
            detail="[compute].gpu_type is required (e.g., 'H100', 'A100')"
        )

    if not compute.get("gpu_count"):
        raise MissingComputeError(
            detail="[compute].gpu_count is required"
        )

    # Validate using Pydantic model
    try:
        validated = SFTConfig.from_mapping(config)
        return validated.to_dict()
    except ValidationError as exc:
        # One bullet per Pydantic error, keyed by the dotted field location.
        errors = [
            f" • {'.'.join(str(part) for part in error['loc'])}: {error['msg']}"
            for error in exc.errors()
        ]
        raise InvalidSFTConfigError(
            detail="Pydantic validation failed:\n" + "\n".join(errors)
        ) from exc
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
def validate_rl_config(config: MutableMapping[str, Any]) -> dict[str, Any]:
    """Validate RL configuration from TOML.

    Runs structural checks in a fixed order (algorithm, model/policy, compute,
    rollout, topology/vllm, training, evaluation), injects a few defaults into
    ``config`` in place, validates the judge/rubric configuration, and finally
    passes the mapping to ``RLConfig.from_mapping``.

    Note that ``config`` is mutated: ``services``,
    ``compute.topology.reference_placement`` and ``training.weight_sync`` may
    be added or updated.

    Args:
        config: Raw configuration dictionary from TOML

    Returns:
        Validated configuration dictionary

    Raises:
        InvalidRLConfigError: If validation fails
        MissingAlgorithmError: If algorithm section is missing or invalid
        UnsupportedAlgorithmError: If ``algorithm.type`` or
            ``algorithm.method`` has an unsupported value
        MissingModelError: If model is not specified
        MissingComputeError: If compute section is missing required fields
    """
    # Check for required top-level sections
    algorithm = config.get("algorithm")
    if not _rl_is_table(algorithm):
        raise MissingAlgorithmError(
            detail="[algorithm] section is required for RL configs"
        )

    # Check for model OR policy (policy is the new format)
    if "policy" not in config and "model" not in config:
        raise MissingModelError(
            detail="[policy] or [model] section is required for RL configs"
        )

    _validate_rl_algorithm(algorithm)
    _validate_rl_model(config)
    compute = _validate_rl_compute(config)
    _validate_rl_rollout(config)
    _validate_rl_topology(config, compute)
    _validate_rl_training(config)
    _validate_rl_evaluation(config)
    _inject_rl_defaults(config)

    # Validate judge/rubric configuration with formalized Pydantic models.
    # Only the pass/fail outcome is needed here, so the returned value is
    # deliberately discarded.
    try:
        extract_and_validate_judge_rubric(config)
    except (InvalidJudgeConfigError, InvalidRubricConfigError) as exc:
        raise InvalidRLConfigError(
            detail=f"Judge/Rubric validation failed: {exc.detail}",
            hint="Check JUDGE_RUBRIC_CLEANUP_GUIDE.md for migration help.",
        ) from exc

    # Validate using Pydantic model
    try:
        validated = RLConfig.from_mapping(config)
        return validated.to_dict()
    except ValidationError as exc:
        # One bullet per Pydantic error, keyed by the dotted field location.
        errors = [
            f" • {'.'.join(str(part) for part in error['loc'])}: {error['msg']}"
            for error in exc.errors()
        ]
        raise InvalidRLConfigError(
            detail="Pydantic validation failed:\n" + "\n".join(errors)
        ) from exc


def _rl_is_table(value: Any) -> bool:
    """Return True when *value* is a non-empty mapping (a populated TOML table)."""
    # Function-scope import: the module only imports MutableMapping at the top.
    from collections.abc import Mapping

    return isinstance(value, Mapping) and bool(value)


def _rl_optional_table(config: MutableMapping[str, Any], key: str) -> Any:
    """Return the optional ``[key]`` table, or ``{}`` when absent/empty; reject non-tables."""
    value = config.get(key)
    if not value:
        return {}
    if not _rl_is_table(value):
        # Without this a scalar such as `vllm = "x"` would crash on `.get`.
        raise InvalidRLConfigError(
            detail=f"[{key}] must be a table",
            hint=f"Define {key} as a [{key}] section with key/value pairs",
        )
    return value


def _validate_rl_algorithm(algorithm: Any) -> None:
    """Check ``algorithm.type``, ``algorithm.method`` and ``algorithm.variety``."""
    # Tuples (not sets) are used for membership so that an unhashable TOML
    # value (e.g. an array) is reported as unsupported instead of raising
    # TypeError.
    algorithm_type = algorithm.get("type")
    if algorithm_type not in ("online", None):
        raise UnsupportedAlgorithmError(
            algorithm_type=algorithm_type,
            expected="online",
            hint="RL requires algorithm.type = 'online'",
        )

    # NOTE(review): 'ppo' and 'gspo' are accepted here but the error text only
    # advertises 'policy_gradient' — confirm whether that is intentional.
    method = algorithm.get("method", "")
    if method and method not in ("policy_gradient", "ppo", "gspo"):
        raise UnsupportedAlgorithmError(
            algorithm_type=method,
            expected="policy_gradient",
            hint="RL requires algorithm.method = 'policy_gradient'",
        )

    # Validate variety is present
    if not algorithm.get("variety"):
        raise MissingAlgorithmError(
            detail="[algorithm].variety is required (e.g., 'gspo', 'ppo')"
        )


def _validate_rl_model(config: MutableMapping[str, Any]) -> None:
    """Check the ``[policy]`` table (preferred) or the legacy ``[model]`` table."""
    policy = config.get("policy")
    model = config.get("model")

    # Use policy if available, otherwise fall back to model
    if _rl_is_table(policy):
        section, name_key, table = "policy", "model_name", policy
    elif _rl_is_table(model):
        section, name_key, table = "model", "base", model
    else:
        # A present-but-empty table carries no model information, so it is
        # treated like a missing section instead of skipping validation.
        raise MissingModelError(
            detail="[policy] or [model] section is required for RL configs"
        )

    if not table.get(name_key) and not table.get("source"):
        raise MissingModelError(
            detail=f"[{section}].{name_key} or [{section}].source must be specified",
            hint="Provide base model (e.g., 'Qwen/Qwen3-4B') or source checkpoint",
        )

    if not table.get("trainer_mode"):
        raise InvalidRLConfigError(
            detail=f"[{section}].trainer_mode is required (e.g., 'full', 'lora')"
        )

    if not table.get("label"):
        raise InvalidRLConfigError(
            detail=f"[{section}].label is required (e.g., 'my-rl-model')",
            hint="Provide a descriptive label for this model",
        )


def _validate_rl_compute(config: MutableMapping[str, Any]) -> Any:
    """Check the required ``[compute]`` table and return it."""
    compute = config.get("compute")
    if not _rl_is_table(compute):
        raise MissingComputeError(
            detail="[compute] section is required",
            hint="Specify gpu_type and gpu_count",
        )

    if not compute.get("gpu_type"):
        raise MissingComputeError(
            detail="[compute].gpu_type is required (e.g., 'H100', 'A100')"
        )

    if not compute.get("gpu_count"):
        raise MissingComputeError(
            detail="[compute].gpu_count is required"
        )

    return compute


def _validate_rl_rollout(config: MutableMapping[str, Any]) -> None:
    """Check the required ``[rollout]`` table and its mandatory keys."""
    rollout = config.get("rollout")
    if not _rl_is_table(rollout):
        raise InvalidRLConfigError(
            detail="[rollout] section is required for RL configs",
            hint="Specify env_name, policy_name, max_turns, etc.",
        )

    if not rollout.get("env_name"):
        raise InvalidRLConfigError(
            detail="[rollout].env_name is required (e.g., 'math', 'crafter')"
        )

    if not rollout.get("policy_name"):
        raise InvalidRLConfigError(
            detail="[rollout].policy_name is required"
        )


def _validate_rl_topology(config: MutableMapping[str, Any], compute: Any) -> None:
    """Check the topology table and its tensor-parallel agreement with ``[vllm]``."""
    # Topology can be top-level or nested under compute
    topology = config.get("topology") or compute.get("topology", {})
    if not _rl_is_table(topology):
        raise InvalidRLConfigError(
            detail="[topology] or [compute.topology] section is required",
            hint="Specify gpus_for_vllm, gpus_for_training, etc.",
        )

    # Either both tensor-parallel settings are given or neither is
    vllm = _rl_optional_table(config, "vllm")
    topology_tensor_parallel = topology.get("tensor_parallel")
    vllm_tensor_parallel = vllm.get("tensor_parallel_size")

    if topology_tensor_parallel and not vllm_tensor_parallel:
        raise InvalidRLConfigError(
            detail="Both [topology].tensor_parallel and [vllm].tensor_parallel_size must be provided",
            hint=f"Add [vllm] section with tensor_parallel_size={topology_tensor_parallel}",
        )

    if vllm_tensor_parallel and not topology_tensor_parallel:
        raise InvalidRLConfigError(
            detail="Both [topology].tensor_parallel and [vllm].tensor_parallel_size must be provided",
            hint=f"Add tensor_parallel={vllm_tensor_parallel} to [topology] section",
        )


def _validate_rl_training(config: MutableMapping[str, Any]) -> None:
    """Check the optional ``[training]`` table and normalize its ``weight_sync`` block."""
    training = _rl_optional_table(config, "training")
    if not training:
        return

    required_training_fields = {
        "num_epochs": "number of training epochs",
        "iterations_per_epoch": "iterations per epoch",
        "max_turns": "maximum turns",
        "batch_size": "batch size",
        "group_size": "group size",
        "learning_rate": "learning rate",
        "weight_sync_interval": "weight sync interval",
        "log_interval": "logging interval",
    }

    for field, description in required_training_fields.items():
        if field not in training:
            raise InvalidRLConfigError(
                detail=f"[training].{field} is required ({description})",
                hint=f"Add {field} to the [training] section",
            )

    # Validate weight_sync_interval is positive
    weight_sync_interval = training.get("weight_sync_interval")
    if weight_sync_interval is not None:
        try:
            non_positive = weight_sync_interval <= 0
        except TypeError:
            # A value that cannot be compared with 0 (e.g. a string) is
            # reported as a config error instead of crashing.
            non_positive = True
        if non_positive:
            raise InvalidRLConfigError(
                detail="[training].weight_sync_interval must be a positive integer",
                hint="Set weight_sync_interval to a value >= 1",
            )

    # Ensure weight_sync block exists with proper defaults
    # Backend requires mode="direct" - always inject it
    if "weight_sync" not in training:
        training["weight_sync"] = {
            "enable": True,
            "mode": "direct",  # Backend requirement
            "targets": ["policy"],
            "interval": training.get("weight_sync_interval", 1),
        }
        return

    weight_sync = training["weight_sync"]
    if not isinstance(weight_sync, MutableMapping):
        # The block is written to below, so it has to be a mutable table.
        raise InvalidRLConfigError(
            detail="[training.weight_sync] must be a table",
            hint="Define [training.weight_sync] with enable, targets and interval",
        )

    # Always force mode to "direct" (backend requirement); this happens before
    # the checks below, so the mode is overwritten even when they fail.
    weight_sync["mode"] = "direct"

    # Validate existing weight_sync block
    if not weight_sync.get("enable"):
        raise InvalidRLConfigError(
            detail="[training.weight_sync].enable must be true",
            hint="Set enable=true in the weight_sync section",
        )

    targets = weight_sync.get("targets", [])
    if not targets or "policy" not in targets:
        raise InvalidRLConfigError(
            detail="[training.weight_sync].targets must include 'policy'",
            hint="Add targets=['policy'] to the weight_sync section",
        )

    # Inject interval if not present
    if "interval" not in weight_sync:
        weight_sync["interval"] = training.get("weight_sync_interval", 1)


def _validate_rl_evaluation(config: MutableMapping[str, Any]) -> None:
    """Check that an ``[evaluation]`` table, when given, has its required keys."""
    evaluation = _rl_optional_table(config, "evaluation")
    if not evaluation:
        return

    required_eval_fields = {
        "instances": "number of evaluation instances",
        "every_n_iters": "evaluation frequency",
        "seeds": "evaluation seeds",
    }

    for field, description in required_eval_fields.items():
        if field not in evaluation:
            raise InvalidRLConfigError(
                detail=f"[evaluation].{field} is required ({description})",
                hint=f"Add {field} to the [evaluation] section",
            )


def _inject_rl_defaults(config: MutableMapping[str, Any]) -> None:
    """Add the ``services`` placeholder and default ``reference_placement`` in place."""
    # Inject services section if not present (will be populated at runtime)
    config.setdefault("services", {"task_url": "placeholder"})

    # Reference placement lives under compute.topology.reference_placement.
    # NOTE(review): when topology was given as a top-level [topology] table,
    # this still creates compute.topology holding only reference_placement —
    # confirm that downstream code merges the two.
    topology = config.setdefault("compute", {}).setdefault("topology", {})
    topology.setdefault("reference_placement", "none")
|
|
396
|
+
|
|
397
|
+
|
|
398
|
+
def load_and_validate_sft(config_path: Path) -> dict[str, Any]:
    """Read an SFT TOML file from disk and run it through SFT validation.

    Args:
        config_path: Location of the TOML configuration file

    Returns:
        The dictionary produced by ``validate_sft_config``

    Raises:
        TomlParseError: If ``load_toml`` raises any exception for the path
        InvalidSFTConfigError: If validation fails
    """
    try:
        parsed = load_toml(config_path)
    except Exception as load_error:
        # Every loader failure is re-raised as TomlParseError carrying the path.
        raise TomlParseError(
            path=str(config_path),
            detail=str(load_error),
        ) from load_error
    else:
        # Validation runs outside the try so its errors propagate unchanged.
        return validate_sft_config(parsed)
|
|
420
|
+
|
|
421
|
+
|
|
422
|
+
def load_and_validate_rl(config_path: Path) -> dict[str, Any]:
    """Read an RL TOML file from disk and run it through RL validation.

    Args:
        config_path: Location of the TOML configuration file

    Returns:
        The dictionary produced by ``validate_rl_config``

    Raises:
        TomlParseError: If ``load_toml`` raises any exception for the path
        InvalidRLConfigError: If validation fails
    """
    try:
        parsed = load_toml(config_path)
    except Exception as load_error:
        # Every loader failure is re-raised as TomlParseError carrying the path.
        raise TomlParseError(
            path=str(config_path),
            detail=str(load_error),
        ) from load_error
    else:
        # Validation runs outside the try so its errors propagate unchanged.
        return validate_rl_config(parsed)
|
synth_ai/cli/demo.py
CHANGED
|
@@ -1,165 +1,5 @@
|
|
|
1
|
-
#!/usr/bin/env python3
|
|
2
|
-
"""
|
|
3
|
-
CLI: interactive launcher for example demos and RL demo helpers.
|
|
4
|
-
|
|
5
|
-
- `synth-ai demo` (no subcommand) -> initialize RL demo files into ./synth_demo/
|
|
6
|
-
- `synth-ai demo deploy|configure|run` -> invoke RL demo helpers directly.
|
|
7
|
-
"""
|
|
8
|
-
|
|
9
1
|
from __future__ import annotations
|
|
10
2
|
|
|
11
|
-
import
|
|
12
|
-
import os
|
|
13
|
-
import subprocess
|
|
14
|
-
from pathlib import Path
|
|
15
|
-
from typing import Any, cast
|
|
16
|
-
|
|
17
|
-
import click
|
|
18
|
-
from click.exceptions import Exit
|
|
19
|
-
|
|
20
|
-
demo_commands = cast(
|
|
21
|
-
Any, importlib.import_module("synth_ai.demos.core.cli")
|
|
22
|
-
)
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
def _find_demo_scripts(root: Path) -> list[Path]:
    """Return the sorted ``run_demo.sh`` files found anywhere under *root*.

    An empty list is returned when *root* does not exist. Entries named
    ``run_demo.sh`` that are not regular files are ignored.
    """
    scripts: list[Path] = []
    if root.exists():
        scripts.extend(
            candidate
            for candidate in root.rglob("run_demo.sh")
            if candidate.is_file()
        )
        scripts.sort()
    return scripts
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
def _run_demo_command(func, *args, **kwargs) -> None:
    """Call *func* and turn a non-zero integer result into a Click ``Exit``.

    A ``SystemExit`` raised by *func* is re-raised as ``Exit`` with the same
    code (or 1 when the code is falsy). A ``None`` result, or a result that
    cannot be converted with ``int()``, is treated as success.
    """
    try:
        outcome = func(*args, **kwargs)
    except SystemExit as exit_request:  # pragma: no cover - defensive
        raise Exit(exit_request.code or 1) from exit_request

    if outcome is not None:
        try:
            status = int(outcome)
        except (TypeError, ValueError):
            # Non-numeric results carry no status code; treat as success.
            status = 0
        if status:
            raise Exit(status)
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
def register(cli):
|
|
51
|
-
@cli.group("demo", invoke_without_command=True)
|
|
52
|
-
@click.option(
|
|
53
|
-
"--force", is_flag=True, help="Overwrite existing files in CWD when initializing demo"
|
|
54
|
-
)
|
|
55
|
-
@click.option("--list", "list_only", is_flag=True, help="List available legacy demos and exit")
|
|
56
|
-
@click.option("-f", "filter_term", default="", help="Filter legacy demos by substring")
|
|
57
|
-
@click.pass_context
|
|
58
|
-
def demo(ctx: click.Context, force: bool, list_only: bool, filter_term: str):
|
|
59
|
-
"""Demo helpers.
|
|
60
|
-
|
|
61
|
-
- Default (no subcommand): initialize RL demo files into ./synth_demo/ (alias of rl_demo init)
|
|
62
|
-
- Legacy mode: with --list, find and run examples/*/run_demo.sh
|
|
63
|
-
- New RL demo subcommands: deploy, configure, run
|
|
64
|
-
"""
|
|
65
|
-
if ctx.invoked_subcommand is not None:
|
|
66
|
-
return
|
|
67
|
-
|
|
68
|
-
# If explicitly asked to list legacy demos, show interactive picker
|
|
69
|
-
if list_only:
|
|
70
|
-
repo_root = Path(os.getcwd())
|
|
71
|
-
examples_dir = repo_root / "examples"
|
|
72
|
-
demos = _find_demo_scripts(examples_dir)
|
|
73
|
-
if filter_term:
|
|
74
|
-
demos = [p for p in demos if filter_term.lower() in str(p).lower()]
|
|
75
|
-
|
|
76
|
-
if not demos:
|
|
77
|
-
click.echo("No run_demo.sh scripts found under examples/.")
|
|
78
|
-
return
|
|
79
|
-
|
|
80
|
-
click.echo("Available demos:")
|
|
81
|
-
for idx, p in enumerate(demos, start=1):
|
|
82
|
-
click.echo(f" {idx}. {p.relative_to(repo_root)}")
|
|
83
|
-
click.echo("")
|
|
84
|
-
|
|
85
|
-
def _validate_choice(val: str) -> int:
|
|
86
|
-
try:
|
|
87
|
-
i = int(val)
|
|
88
|
-
except Exception as err:
|
|
89
|
-
raise click.BadParameter("Enter a number from the list") from err
|
|
90
|
-
if i < 1 or i > len(demos):
|
|
91
|
-
raise click.BadParameter(f"Choose a number between 1 and {len(demos)}")
|
|
92
|
-
return i
|
|
93
|
-
|
|
94
|
-
choice = click.prompt("Select a demo to run", value_proc=_validate_choice)
|
|
95
|
-
script = demos[choice - 1]
|
|
96
|
-
|
|
97
|
-
click.echo("")
|
|
98
|
-
click.echo(f"š Running {script.relative_to(repo_root)}\n")
|
|
99
|
-
|
|
100
|
-
try:
|
|
101
|
-
subprocess.run(["bash", str(script)], check=True)
|
|
102
|
-
except subprocess.CalledProcessError as e:
|
|
103
|
-
click.echo(f"ā Demo exited with non-zero status: {e.returncode}")
|
|
104
|
-
except KeyboardInterrupt:
|
|
105
|
-
click.echo("\nš Demo interrupted by user")
|
|
106
|
-
return
|
|
107
|
-
|
|
108
|
-
# Default: initialize RL demo files via new command
|
|
109
|
-
_run_demo_command(demo_commands.init, force=force)
|
|
110
|
-
|
|
111
|
-
# (prepare command removed; configure now prepares baseline TOML)
|
|
112
|
-
|
|
113
|
-
# Help pyright understand dynamic Click group attributes
|
|
114
|
-
_dg = cast(Any, demo)
|
|
115
|
-
|
|
116
|
-
@_dg.command("deploy")
|
|
117
|
-
@click.option("--local", is_flag=True, help="Run local FastAPI instead of Modal deploy")
|
|
118
|
-
@click.option(
|
|
119
|
-
"--app",
|
|
120
|
-
type=click.Path(),
|
|
121
|
-
default=None,
|
|
122
|
-
help="Path to Modal app.py for uv run modal deploy",
|
|
123
|
-
)
|
|
124
|
-
@click.option("--name", type=str, default="synth-math-demo", help="Modal app name")
|
|
125
|
-
@click.option(
|
|
126
|
-
"--script",
|
|
127
|
-
type=click.Path(),
|
|
128
|
-
default=None,
|
|
129
|
-
help="Path to deploy_task_app.sh (optional legacy)",
|
|
130
|
-
)
|
|
131
|
-
def demo_deploy(local: bool, app: str | None, name: str, script: str | None):
|
|
132
|
-
_run_demo_command(
|
|
133
|
-
demo_commands.deploy,
|
|
134
|
-
local=local,
|
|
135
|
-
app=app,
|
|
136
|
-
name=name,
|
|
137
|
-
script=script,
|
|
138
|
-
)
|
|
139
|
-
|
|
140
|
-
@_dg.command("configure")
|
|
141
|
-
def demo_configure():
|
|
142
|
-
_run_demo_command(demo_commands.run)
|
|
143
|
-
|
|
144
|
-
@_dg.command("setup")
|
|
145
|
-
def demo_setup():
|
|
146
|
-
_run_demo_command(demo_commands.setup)
|
|
147
|
-
|
|
148
|
-
@_dg.command("run")
|
|
149
|
-
@click.option("--batch-size", type=int, default=None)
|
|
150
|
-
@click.option("--group-size", type=int, default=None)
|
|
151
|
-
@click.option("--model", type=str, default=None)
|
|
152
|
-
@click.option("--timeout", type=int, default=600)
|
|
153
|
-
def demo_run(batch_size: int | None, group_size: int | None, model: str | None, timeout: int):
|
|
154
|
-
_run_demo_command(
|
|
155
|
-
demo_commands.run,
|
|
156
|
-
batch_size=batch_size,
|
|
157
|
-
group_size=group_size,
|
|
158
|
-
model=model,
|
|
159
|
-
timeout=timeout,
|
|
160
|
-
)
|
|
3
|
+
from synth_ai.cli.commands.demo.core import register
|
|
161
4
|
|
|
162
|
-
|
|
163
|
-
def setup_alias():
|
|
164
|
-
"""Perform SDK handshake and write keys to .env."""
|
|
165
|
-
_run_demo_command(demo_commands.setup)
|
|
5
|
+
__all__ = ["register"]
|
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from .core import command, get_command
|
|
4
|
+
from .errors import (
|
|
5
|
+
DeployCliError,
|
|
6
|
+
EnvFileDiscoveryError,
|
|
7
|
+
EnvironmentKeyLoadError,
|
|
8
|
+
EnvKeyPreflightError,
|
|
9
|
+
MissingEnvironmentApiKeyError,
|
|
10
|
+
ModalCliResolutionError,
|
|
11
|
+
ModalExecutionError,
|
|
12
|
+
TaskAppNotFoundError,
|
|
13
|
+
)
|
|
14
|
+
from .validation import validate_deploy_options
|
|
15
|
+
|
|
16
|
+
__all__ = [
|
|
17
|
+
"command",
|
|
18
|
+
"get_command",
|
|
19
|
+
"DeployCliError",
|
|
20
|
+
"MissingEnvironmentApiKeyError",
|
|
21
|
+
"EnvironmentKeyLoadError",
|
|
22
|
+
"EnvFileDiscoveryError",
|
|
23
|
+
"TaskAppNotFoundError",
|
|
24
|
+
"ModalCliResolutionError",
|
|
25
|
+
"ModalExecutionError",
|
|
26
|
+
"EnvKeyPreflightError",
|
|
27
|
+
"validate_deploy_options",
|
|
28
|
+
]
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from synth_ai.cli.commands.deploy.errors import (
|
|
4
|
+
DeployCliError,
|
|
5
|
+
EnvFileDiscoveryError,
|
|
6
|
+
EnvironmentKeyLoadError,
|
|
7
|
+
EnvKeyPreflightError,
|
|
8
|
+
MissingEnvironmentApiKeyError,
|
|
9
|
+
ModalCliResolutionError,
|
|
10
|
+
ModalExecutionError,
|
|
11
|
+
TaskAppNotFoundError,
|
|
12
|
+
)
|
|
13
|
+
|
|
14
|
+
__all__ = [
|
|
15
|
+
"DeployCliError",
|
|
16
|
+
"MissingEnvironmentApiKeyError",
|
|
17
|
+
"EnvironmentKeyLoadError",
|
|
18
|
+
"EnvFileDiscoveryError",
|
|
19
|
+
"TaskAppNotFoundError",
|
|
20
|
+
"ModalCliResolutionError",
|
|
21
|
+
"ModalExecutionError",
|
|
22
|
+
"EnvKeyPreflightError",
|
|
23
|
+
]
|