synth-ai 0.2.12__py3-none-any.whl → 0.2.13.dev1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of synth-ai might be problematic. Click here for more details.
- examples/agora_ex/README_MoE.md +224 -0
- examples/agora_ex/__init__.py +7 -0
- examples/agora_ex/agora_ex.py +65 -0
- examples/agora_ex/agora_ex_task_app.py +590 -0
- examples/agora_ex/configs/rl_lora_qwen3_moe_2xh200.toml +121 -0
- examples/agora_ex/reward_fn_grpo-human.py +129 -0
- examples/agora_ex/system_prompt_CURRENT.md +63 -0
- examples/agora_ex/task_app/agora_ex_task_app.py +590 -0
- examples/agora_ex/task_app/reward_fn_grpo-human.py +129 -0
- examples/agora_ex/task_app/system_prompt_CURRENT.md +63 -0
- examples/multi_step/configs/crafter_rl_outcome.toml +74 -0
- examples/multi_step/configs/crafter_rl_stepwise_hosted_judge.toml +175 -0
- examples/multi_step/configs/crafter_rl_stepwise_shaped.toml +83 -0
- examples/multi_step/configs/crafter_rl_stepwise_simple.toml +78 -0
- examples/multi_step/crafter_rl_lora.md +51 -10
- examples/multi_step/sse_metrics_streaming_notes.md +357 -0
- examples/multi_step/task_app_config_notes.md +7 -1
- examples/warming_up_to_rl/configs/eval_stepwise_complex.toml +4 -2
- examples/warming_up_to_rl/configs/eval_stepwise_simple.toml +4 -2
- examples/warming_up_to_rl/run_eval.py +127 -18
- examples/warming_up_to_rl/task_app/grpo_crafter.py +3 -33
- examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/openai_client.py +109 -45
- examples/warming_up_to_rl/task_app/synth_envs_hosted/policy_routes.py +42 -46
- examples/warming_up_to_rl/task_app/synth_envs_hosted/rollout.py +232 -193
- synth_ai/__init__.py +41 -1
- synth_ai/api/train/builders.py +49 -19
- synth_ai/api/train/configs/__init__.py +44 -0
- synth_ai/api/train/configs/rl.py +133 -0
- synth_ai/api/train/configs/sft.py +94 -0
- synth_ai/api/train/configs/shared.py +24 -0
- synth_ai/cli/demo.py +38 -39
- synth_ai/cli/rl_demo.py +81 -102
- synth_ai/cli/task_apps.py +3 -0
- synth_ai/demos/core/cli.py +121 -159
- synth_ai/environments/examples/crafter_classic/environment.py +16 -0
- synth_ai/evals/__init__.py +15 -0
- synth_ai/evals/client.py +85 -0
- synth_ai/evals/types.py +42 -0
- synth_ai/judge_schemas.py +127 -0
- synth_ai/rubrics/__init__.py +22 -0
- synth_ai/rubrics/validators.py +126 -0
- synth_ai/tracing_v3/serialization.py +130 -0
- {synth_ai-0.2.12.dist-info → synth_ai-0.2.13.dev1.dist-info}/METADATA +1 -1
- {synth_ai-0.2.12.dist-info → synth_ai-0.2.13.dev1.dist-info}/RECORD +48 -22
- {synth_ai-0.2.12.dist-info → synth_ai-0.2.13.dev1.dist-info}/entry_points.txt +0 -1
- {synth_ai-0.2.12.dist-info → synth_ai-0.2.13.dev1.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.12.dist-info → synth_ai-0.2.13.dev1.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.2.12.dist-info → synth_ai-0.2.13.dev1.dist-info}/top_level.txt +0 -0
synth_ai/api/train/builders.py
CHANGED
|
@@ -6,6 +6,7 @@ from pathlib import Path
|
|
|
6
6
|
from typing import Any, cast
|
|
7
7
|
|
|
8
8
|
import click
|
|
9
|
+
from pydantic import ValidationError
|
|
9
10
|
|
|
10
11
|
try:
|
|
11
12
|
_models_module = importlib.import_module("synth_ai.api.models.supported")
|
|
@@ -25,7 +26,8 @@ from .supported_algos import (
|
|
|
25
26
|
ensure_model_supported_for_algorithm,
|
|
26
27
|
validate_algorithm_config,
|
|
27
28
|
)
|
|
28
|
-
from .utils import TrainError, ensure_api_base
|
|
29
|
+
from .utils import TrainError, ensure_api_base
|
|
30
|
+
from .configs import RLConfig, SFTConfig
|
|
29
31
|
|
|
30
32
|
|
|
31
33
|
@dataclass(slots=True)
|
|
@@ -42,6 +44,16 @@ class SFTBuildResult:
|
|
|
42
44
|
validation_file: Path | None
|
|
43
45
|
|
|
44
46
|
|
|
47
|
+
def _format_validation_error(path: Path, exc: ValidationError) -> str:
|
|
48
|
+
lines: list[str] = []
|
|
49
|
+
for error in exc.errors():
|
|
50
|
+
loc = ".".join(str(part) for part in error.get("loc", ()))
|
|
51
|
+
msg = error.get("msg", "invalid value")
|
|
52
|
+
lines.append(f"{loc or '<root>'}: {msg}")
|
|
53
|
+
details = "\n".join(f" - {line}" for line in lines) or " - Invalid configuration"
|
|
54
|
+
return f"Config validation failed ({path}):\n{details}"
|
|
55
|
+
|
|
56
|
+
|
|
45
57
|
def build_rl_payload(
|
|
46
58
|
*,
|
|
47
59
|
config_path: Path,
|
|
@@ -50,13 +62,30 @@ def build_rl_payload(
|
|
|
50
62
|
idempotency: str | None,
|
|
51
63
|
allow_experimental: bool | None = None,
|
|
52
64
|
) -> RLBuildResult:
|
|
53
|
-
data = load_toml(config_path)
|
|
54
65
|
try:
|
|
55
|
-
|
|
66
|
+
rl_cfg = RLConfig.from_path(config_path)
|
|
67
|
+
except ValidationError as exc:
|
|
68
|
+
raise click.ClickException(_format_validation_error(config_path, exc)) from exc
|
|
69
|
+
|
|
70
|
+
data = rl_cfg.to_dict()
|
|
71
|
+
# Ensure required [reference] section for backend validators
|
|
72
|
+
try:
|
|
73
|
+
ref_cfg = data.get("reference") if isinstance(data, dict) else None
|
|
74
|
+
if not isinstance(ref_cfg, dict):
|
|
75
|
+
data["reference"] = {"placement": "none"}
|
|
76
|
+
else:
|
|
77
|
+
ref_cfg.setdefault("placement", "none")
|
|
78
|
+
except Exception:
|
|
79
|
+
# Defensive: never fail builder due to optional defaults
|
|
80
|
+
data["reference"] = {"placement": "none"}
|
|
81
|
+
try:
|
|
82
|
+
spec = validate_algorithm_config(
|
|
83
|
+
rl_cfg.algorithm.model_dump(), expected_family="rl"
|
|
84
|
+
)
|
|
56
85
|
except AlgorithmValidationError as exc:
|
|
57
86
|
raise click.ClickException(str(exc)) from exc
|
|
58
87
|
services = data.get("services") if isinstance(data.get("services"), dict) else {}
|
|
59
|
-
model_cfg =
|
|
88
|
+
model_cfg = rl_cfg.model
|
|
60
89
|
|
|
61
90
|
final_task_url = (
|
|
62
91
|
overrides.get("task_url")
|
|
@@ -69,10 +98,8 @@ def build_rl_payload(
|
|
|
69
98
|
"Task app URL required (provide --task-url or set services.task_url in TOML)"
|
|
70
99
|
)
|
|
71
100
|
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
raw_base = model_cfg.get("base") if isinstance(model_cfg, dict) else ""
|
|
75
|
-
model_base = str(raw_base or "").strip()
|
|
101
|
+
model_source = (model_cfg.source or "").strip()
|
|
102
|
+
model_base = (model_cfg.base or "").strip()
|
|
76
103
|
override_model = (overrides.get("model") or "").strip()
|
|
77
104
|
if override_model:
|
|
78
105
|
model_source = override_model
|
|
@@ -160,22 +187,23 @@ def build_sft_payload(
|
|
|
160
187
|
dataset_override: Path | None,
|
|
161
188
|
allow_experimental: bool | None,
|
|
162
189
|
) -> SFTBuildResult:
|
|
163
|
-
data = load_toml(config_path)
|
|
164
190
|
try:
|
|
165
|
-
|
|
191
|
+
sft_cfg = SFTConfig.from_path(config_path)
|
|
192
|
+
except ValidationError as exc:
|
|
193
|
+
raise TrainError(_format_validation_error(config_path, exc)) from exc
|
|
194
|
+
|
|
195
|
+
data = sft_cfg.to_dict()
|
|
196
|
+
try:
|
|
197
|
+
algo_mapping = sft_cfg.algorithm.model_dump() if sft_cfg.algorithm else None
|
|
198
|
+
spec = validate_algorithm_config(algo_mapping, expected_family="sft")
|
|
166
199
|
except AlgorithmValidationError as exc:
|
|
167
200
|
raise TrainError(str(exc)) from exc
|
|
168
|
-
job_cfg = data.get("job") if isinstance(data.get("job"), dict) else {}
|
|
169
201
|
data_cfg = data.get("data") if isinstance(data.get("data"), dict) else {}
|
|
170
202
|
hp_cfg = data.get("hyperparameters") if isinstance(data.get("hyperparameters"), dict) else {}
|
|
171
203
|
train_cfg = data.get("training") if isinstance(data.get("training"), dict) else {}
|
|
172
204
|
compute_cfg = data.get("compute") if isinstance(data.get("compute"), dict) else {}
|
|
173
205
|
|
|
174
|
-
raw_dataset =
|
|
175
|
-
dataset_override
|
|
176
|
-
or (job_cfg.get("data") if isinstance(job_cfg, dict) else None)
|
|
177
|
-
or (job_cfg.get("data_path") if isinstance(job_cfg, dict) else None)
|
|
178
|
-
)
|
|
206
|
+
raw_dataset = dataset_override or sft_cfg.job.data or sft_cfg.job.data_path
|
|
179
207
|
if not raw_dataset:
|
|
180
208
|
raise TrainError("Dataset not specified; pass --dataset or set [job].data")
|
|
181
209
|
dataset_path = Path(raw_dataset)
|
|
@@ -260,9 +288,11 @@ def build_sft_payload(
|
|
|
260
288
|
"enabled": bool(validation_cfg.get("enabled", True))
|
|
261
289
|
}
|
|
262
290
|
|
|
263
|
-
raw_model =
|
|
264
|
-
|
|
265
|
-
|
|
291
|
+
raw_model = (sft_cfg.job.model or "").strip()
|
|
292
|
+
if not raw_model:
|
|
293
|
+
model_block = data.get("model")
|
|
294
|
+
if isinstance(model_block, str):
|
|
295
|
+
raw_model = model_block.strip()
|
|
266
296
|
if not raw_model:
|
|
267
297
|
raise TrainError("Model not specified; set [job].model or [model].base in the config")
|
|
268
298
|
|
|
@@ -0,0 +1,44 @@
|
|
|
1
|
+
"""Typed training config loaders for RL and SFT jobs."""
|
|
2
|
+
|
|
3
|
+
from .shared import AlgorithmConfig, ComputeConfig
|
|
4
|
+
from .sft import (
|
|
5
|
+
HyperparametersConfig,
|
|
6
|
+
HyperparametersParallelism,
|
|
7
|
+
JobConfig,
|
|
8
|
+
SFTConfig,
|
|
9
|
+
SFTDataConfig,
|
|
10
|
+
TrainingConfig,
|
|
11
|
+
TrainingValidationConfig,
|
|
12
|
+
)
|
|
13
|
+
from .rl import (
|
|
14
|
+
EvaluationConfig,
|
|
15
|
+
JudgeConfig,
|
|
16
|
+
JudgeOptionsConfig,
|
|
17
|
+
ModelConfig,
|
|
18
|
+
RLConfig,
|
|
19
|
+
RLServicesConfig,
|
|
20
|
+
RLTrainingConfig,
|
|
21
|
+
RolloutConfig,
|
|
22
|
+
WeightSyncConfig,
|
|
23
|
+
)
|
|
24
|
+
|
|
25
|
+
__all__ = [
|
|
26
|
+
"AlgorithmConfig",
|
|
27
|
+
"ComputeConfig",
|
|
28
|
+
"EvaluationConfig",
|
|
29
|
+
"HyperparametersConfig",
|
|
30
|
+
"HyperparametersParallelism",
|
|
31
|
+
"JobConfig",
|
|
32
|
+
"JudgeConfig",
|
|
33
|
+
"JudgeOptionsConfig",
|
|
34
|
+
"ModelConfig",
|
|
35
|
+
"RLConfig",
|
|
36
|
+
"RLServicesConfig",
|
|
37
|
+
"RLTrainingConfig",
|
|
38
|
+
"RolloutConfig",
|
|
39
|
+
"SFTConfig",
|
|
40
|
+
"SFTDataConfig",
|
|
41
|
+
"TrainingConfig",
|
|
42
|
+
"TrainingValidationConfig",
|
|
43
|
+
"WeightSyncConfig",
|
|
44
|
+
]
|
|
@@ -0,0 +1,133 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
from typing import Any, Mapping
|
|
5
|
+
|
|
6
|
+
from pydantic import model_validator
|
|
7
|
+
|
|
8
|
+
from ..utils import load_toml
|
|
9
|
+
from .shared import AlgorithmConfig, ComputeConfig, ExtraModel
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class RLServicesConfig(ExtraModel):
    """Schema for the ``services`` section of an RL config.

    ``task_url`` is required; ``judge_url`` may be omitted.
    """

    task_url: str
    judge_url: str | None = None
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class ModelConfig(ExtraModel):
    """Schema for the ``[model]`` section of an RL config.

    ``trainer_mode`` and ``label`` are required. Exactly one of ``source``
    or ``base`` must be set to a non-empty value; setting both, or neither,
    is rejected after field validation.
    """

    source: str | None = None
    base: str | None = None
    trainer_mode: str
    label: str

    @model_validator(mode="after")
    def _ensure_exactly_one_source_or_base(self) -> "ModelConfig":
        """Raise ``ValueError`` unless exactly one of ``source``/``base`` is truthy."""
        has_source = bool(self.source)
        has_base = bool(self.base)
        # Equal flags mean "both set" or "both empty" - either way invalid.
        if has_source == has_base:
            raise ValueError("Config must set exactly one of [model].source or [model].base")
        return self
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class RolloutConfig(ExtraModel):
    """Schema for the ``rollout`` section of an RL config.

    Required: ``env_name``, ``policy_name``, ``max_turns``,
    ``episodes_per_batch`` and ``max_concurrent_rollouts``. The remaining
    fields default to ``None``.
    """

    env_name: str
    policy_name: str
    env_config: dict[str, Any] | None = None
    policy_config: dict[str, Any] | None = None
    max_turns: int
    episodes_per_batch: int
    max_concurrent_rollouts: int
    batches_per_step: int | None = None
    ops: list[str] | None = None
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class WeightSyncConfig(ExtraModel):
    """Schema for the nested ``weight_sync`` table of the RL training section.

    Every field is optional and defaults to ``None``.
    """

    enable: bool | None = None
    targets: list[str] | None = None
    mode: str | None = None
    direct: bool | None = None
    verify_every_k: int | None = None
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
class RLTrainingConfig(ExtraModel):
    """Schema for the ``training`` section of an RL config.

    Required: ``num_epochs``, ``iterations_per_epoch``, ``max_turns``,
    ``batch_size``, ``group_size`` and ``learning_rate``. All other fields,
    including the nested ``weight_sync`` table, default to ``None``.
    """

    num_epochs: int
    iterations_per_epoch: int
    gradient_accumulation_steps: int | None = None
    max_accumulated_minibatch: int | None = None
    max_turns: int
    batch_size: int
    group_size: int
    learning_rate: float
    log_interval: int | None = None
    weight_sync_interval: int | None = None
    step_rewards_enabled: bool | None = None
    step_rewards_mode: str | None = None
    step_rewards_indicator_lambda: float | None = None
    step_rewards_beta: float | None = None
    step_rewards_strategy: str | None = None
    event_rewards_kind: str | None = None
    weight_sync: WeightSyncConfig | None = None
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
class EvaluationConfig(ExtraModel):
    """Schema for the ``evaluation`` section of an RL config; all fields are required."""

    instances: int
    every_n_iters: int
    seeds: list[int]
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
class JudgeOptionsConfig(ExtraModel):
    """Schema for the ``options`` table nested under the judge section.

    Every field is optional and defaults to ``None``.
    """

    event: bool | None = None
    outcome: bool | None = None
    provider: str | None = None
    model: str | None = None
    rubric_id: str | None = None
    rubric_overrides: dict[str, Any] | None = None
    tracks: list[str] | None = None
    weights: dict[str, float] | None = None
    max_concurrency: int | None = None
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
class JudgeConfig(ExtraModel):
    """Schema for the ``judge`` section of an RL config.

    Every field, including the nested ``options`` table, is optional.
    """

    type: str | None = None
    timeout_s: int | None = None
    options: JudgeOptionsConfig | None = None
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
class RLConfig(ExtraModel):
    """Top-level schema for an RL training config.

    ``algorithm``, ``services`` and ``model`` are required sections; every
    other section is optional and defaults to ``None``.
    """

    algorithm: AlgorithmConfig
    services: RLServicesConfig
    compute: ComputeConfig | None = None
    topology: dict[str, Any] | None = None
    vllm: dict[str, Any] | None = None
    reference: dict[str, Any] | None = None
    model: ModelConfig
    lora: dict[str, Any] | None = None
    rollout: RolloutConfig | None = None
    evaluation: EvaluationConfig | None = None
    training: RLTrainingConfig | None = None
    rubric: dict[str, Any] | None = None
    judge: JudgeConfig | None = None
    tags: dict[str, Any] | None = None

    def to_dict(self) -> dict[str, Any]:
        """Dump the config in python mode, omitting fields whose value is ``None``."""
        return self.model_dump(exclude_none=True, mode="python")

    @classmethod
    def from_mapping(cls, data: Mapping[str, Any]) -> "RLConfig":
        """Validate ``data`` (copied into a plain ``dict`` first) as an ``RLConfig``."""
        payload = dict(data)
        return cls.model_validate(payload)

    @classmethod
    def from_path(cls, path: Path) -> "RLConfig":
        """Load ``path`` via ``load_toml`` and validate the result as an ``RLConfig``."""
        return cls.from_mapping(load_toml(path))
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
__all__ = [
|
|
124
|
+
"EvaluationConfig",
|
|
125
|
+
"JudgeConfig",
|
|
126
|
+
"JudgeOptionsConfig",
|
|
127
|
+
"ModelConfig",
|
|
128
|
+
"RLConfig",
|
|
129
|
+
"RLServicesConfig",
|
|
130
|
+
"RLTrainingConfig",
|
|
131
|
+
"RolloutConfig",
|
|
132
|
+
"WeightSyncConfig",
|
|
133
|
+
]
|
|
@@ -0,0 +1,94 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
from typing import Any, Mapping
|
|
5
|
+
|
|
6
|
+
from pydantic import Field
|
|
7
|
+
|
|
8
|
+
from ..utils import load_toml
|
|
9
|
+
from .shared import AlgorithmConfig, ComputeConfig, ExtraModel
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class JobConfig(ExtraModel):
    """Schema for the ``job`` section of an SFT config.

    ``model`` is required; ``data``, ``data_path`` and ``poll_seconds`` are
    optional and default to ``None``.
    """

    model: str
    data: str | None = None
    data_path: str | None = None
    poll_seconds: int | None = None
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class SFTDataConfig(ExtraModel):
    """Schema for the ``data`` section of an SFT config; both fields are optional."""

    topology: dict[str, Any] | None = None
    validation_path: str | None = None
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class TrainingValidationConfig(ExtraModel):
    """Schema for the ``validation`` table nested under the SFT training section.

    Every field is optional and defaults to ``None``.
    """

    enabled: bool | None = None
    evaluation_strategy: str | None = None
    eval_steps: int | None = None
    save_best_model_at_end: bool | None = None
    metric_for_best_model: str | None = None
    greater_is_better: bool | None = None
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class TrainingConfig(ExtraModel):
    """Schema for the ``training`` section of an SFT config.

    Every field, including the nested ``validation`` table, is optional.
    """

    mode: str | None = None
    use_qlora: bool | None = None
    validation: TrainingValidationConfig | None = None
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class HyperparametersParallelism(ExtraModel):
    """Schema for the ``parallelism`` table nested under SFT hyperparameters.

    Every field is optional and defaults to ``None``.
    """

    use_deepspeed: bool | None = None
    deepspeed_stage: int | None = None
    fsdp: bool | None = None
    bf16: bool | None = None
    fp16: bool | None = None
    activation_checkpointing: bool | None = None
    tensor_parallel_size: int | None = None
    pipeline_parallel_size: int | None = None
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
class HyperparametersConfig(ExtraModel):
    """Schema for the ``hyperparameters`` section of an SFT config.

    ``n_epochs`` defaults to ``1``; every other field defaults to ``None``.
    """

    n_epochs: int = 1
    batch_size: int | None = None
    global_batch: int | None = None
    per_device_batch: int | None = None
    gradient_accumulation_steps: int | None = None
    sequence_length: int | None = None
    learning_rate: float | None = None
    warmup_ratio: float | None = None
    train_kind: str | None = None
    weight_decay: float | None = None
    parallelism: HyperparametersParallelism | None = None
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
class SFTConfig(ExtraModel):
    """Top-level schema for an SFT training config.

    Only ``job`` is required. ``hyperparameters`` falls back to a default
    ``HyperparametersConfig`` instance; the other sections default to ``None``.
    """

    algorithm: AlgorithmConfig | None = None
    job: JobConfig
    compute: ComputeConfig | None = None
    data: SFTDataConfig | None = None
    training: TrainingConfig | None = None
    hyperparameters: HyperparametersConfig = Field(default_factory=HyperparametersConfig)
    tags: dict[str, Any] | None = None

    def to_dict(self) -> dict[str, Any]:
        """Dump the config in python mode, omitting fields whose value is ``None``."""
        return self.model_dump(exclude_none=True, mode="python")

    @classmethod
    def from_mapping(cls, data: Mapping[str, Any]) -> "SFTConfig":
        """Validate ``data`` (copied into a plain ``dict`` first) as an ``SFTConfig``."""
        payload = dict(data)
        return cls.model_validate(payload)

    @classmethod
    def from_path(cls, path: Path) -> "SFTConfig":
        """Load ``path`` via ``load_toml`` and validate the result as an ``SFTConfig``."""
        return cls.from_mapping(load_toml(path))
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
__all__ = [
|
|
87
|
+
"HyperparametersConfig",
|
|
88
|
+
"HyperparametersParallelism",
|
|
89
|
+
"JobConfig",
|
|
90
|
+
"SFTConfig",
|
|
91
|
+
"SFTDataConfig",
|
|
92
|
+
"TrainingConfig",
|
|
93
|
+
"TrainingValidationConfig",
|
|
94
|
+
]
|
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from pydantic import BaseModel, ConfigDict
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class ExtraModel(BaseModel):
    """Shared base for config models: unknown keys are accepted, not rejected.

    Accepting extras keeps configs forward compatible with keys these
    schemas do not declare yet.
    """

    # extra="allow" is what makes undeclared keys acceptable on every subclass.
    model_config = ConfigDict(extra="allow")
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class AlgorithmConfig(ExtraModel):
    """Schema for the ``algorithm`` section; ``type``, ``method`` and ``variety`` are all required."""

    type: str
    method: str
    variety: str
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class ComputeConfig(ExtraModel):
    """Schema for the ``compute`` section.

    ``gpu_type`` and ``gpu_count`` are required; ``nodes`` defaults to ``None``.
    """

    gpu_type: str
    gpu_count: int
    nodes: int | None = None
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
__all__ = ["ExtraModel", "AlgorithmConfig", "ComputeConfig"]
|
synth_ai/cli/demo.py
CHANGED
|
@@ -1,9 +1,9 @@
|
|
|
1
1
|
#!/usr/bin/env python3
|
|
2
2
|
"""
|
|
3
|
-
CLI: interactive launcher for example demos and
|
|
3
|
+
CLI: interactive launcher for example demos and RL demo helpers.
|
|
4
4
|
|
|
5
|
-
- `synth-ai demo` (no subcommand) ->
|
|
6
|
-
- `synth-ai demo deploy|configure|run` ->
|
|
5
|
+
- `synth-ai demo` (no subcommand) -> initialize RL demo files into ./synth_demo/
|
|
6
|
+
- `synth-ai demo deploy|configure|run` -> invoke RL demo helpers directly.
|
|
7
7
|
"""
|
|
8
8
|
|
|
9
9
|
from __future__ import annotations
|
|
@@ -14,6 +14,8 @@ from pathlib import Path
|
|
|
14
14
|
|
|
15
15
|
import click
|
|
16
16
|
|
|
17
|
+
from synth_ai.demos.core import cli as demo_commands
|
|
18
|
+
|
|
17
19
|
|
|
18
20
|
def _find_demo_scripts(root: Path) -> list[Path]:
|
|
19
21
|
if not root.exists():
|
|
@@ -21,17 +23,23 @@ def _find_demo_scripts(root: Path) -> list[Path]:
|
|
|
21
23
|
return sorted([p for p in root.rglob("run_demo.sh") if p.is_file()])
|
|
22
24
|
|
|
23
25
|
|
|
24
|
-
def
|
|
25
|
-
|
|
26
|
+
def _run_demo_command(func, *args, **kwargs) -> None:
    """Invoke a demo command and exit via Click on non-zero status codes.

    Args:
        func: Callable implementing the demo command.
        *args: Positional arguments forwarded to ``func``.
        **kwargs: Keyword arguments forwarded to ``func``.

    Raises:
        click.exceptions.Exit: When ``func`` raises ``SystemExit`` (the exit
            status is carried over), or when ``func`` returns a value that
            converts to a non-zero integer.

    A return value of ``None``, or one that cannot be converted with
    ``int()``, is treated as success and the function simply returns.
    """
    try:
        result = func(*args, **kwargs)
    except SystemExit as exc:  # pragma: no cover - defensive
        status = exc.code
        if status is None:
            # A bare ``sys.exit()`` signals success; do not turn it into a failure.
            status = 0
        elif not isinstance(status, int):
            # ``sys.exit("message")``: mirror the interpreter, which prints the
            # message to stderr and exits with status 1.
            click.echo(str(status), err=True)
            status = 1
        raise click.exceptions.Exit(status) from exc

    if result is None:
        return

    try:
        code = int(result)
    except (TypeError, ValueError):
        # Non-numeric return values carry no exit status.
        return
    if code != 0:
        raise click.exceptions.Exit(code)
|
|
35
43
|
|
|
36
44
|
|
|
37
45
|
def register(cli):
|
|
@@ -92,11 +100,8 @@ def register(cli):
|
|
|
92
100
|
click.echo("\n🛑 Demo interrupted by user")
|
|
93
101
|
return
|
|
94
102
|
|
|
95
|
-
# Default:
|
|
96
|
-
|
|
97
|
-
if force:
|
|
98
|
-
args.append("--force")
|
|
99
|
-
_forward_to_new(args)
|
|
103
|
+
# Default: initialize RL demo files via new command
|
|
104
|
+
_run_demo_command(demo_commands.init, force=force)
|
|
100
105
|
|
|
101
106
|
# (prepare command removed; configure now prepares baseline TOML)
|
|
102
107
|
|
|
@@ -122,24 +127,21 @@ def register(cli):
|
|
|
122
127
|
help="Path to deploy_task_app.sh (optional legacy)",
|
|
123
128
|
)
|
|
124
129
|
def demo_deploy(local: bool, app: str | None, name: str, script: str | None):
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
if script:
|
|
133
|
-
args.extend(["--script", script])
|
|
134
|
-
_forward_to_new(args)
|
|
130
|
+
_run_demo_command(
|
|
131
|
+
demo_commands.deploy,
|
|
132
|
+
local=local,
|
|
133
|
+
app=app,
|
|
134
|
+
name=name,
|
|
135
|
+
script=script,
|
|
136
|
+
)
|
|
135
137
|
|
|
136
138
|
@_dg.command("configure")
|
|
137
139
|
def demo_configure():
|
|
138
|
-
|
|
140
|
+
_run_demo_command(demo_commands.run)
|
|
139
141
|
|
|
140
142
|
@_dg.command("setup")
|
|
141
143
|
def demo_setup():
|
|
142
|
-
|
|
144
|
+
_run_demo_command(demo_commands.setup)
|
|
143
145
|
|
|
144
146
|
@_dg.command("run")
|
|
145
147
|
@click.option("--batch-size", type=int, default=None)
|
|
@@ -147,13 +149,10 @@ def register(cli):
|
|
|
147
149
|
@click.option("--model", type=str, default=None)
|
|
148
150
|
@click.option("--timeout", type=int, default=600)
|
|
149
151
|
def demo_run(batch_size: int | None, group_size: int | None, model: str | None, timeout: int):
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
if timeout:
|
|
158
|
-
args.extend(["--timeout", str(timeout)])
|
|
159
|
-
_forward_to_new(args)
|
|
152
|
+
_run_demo_command(
|
|
153
|
+
demo_commands.run,
|
|
154
|
+
batch_size=batch_size,
|
|
155
|
+
group_size=group_size,
|
|
156
|
+
model=model,
|
|
157
|
+
timeout=timeout,
|
|
158
|
+
)
|