synth-ai 0.2.10__py3-none-any.whl → 0.2.13.dev1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of synth-ai might be problematic. Click here for more details.
- examples/agora_ex/README_MoE.md +224 -0
- examples/agora_ex/__init__.py +7 -0
- examples/agora_ex/agora_ex.py +65 -0
- examples/agora_ex/agora_ex_task_app.py +590 -0
- examples/agora_ex/configs/rl_lora_qwen3_moe_2xh200.toml +121 -0
- examples/agora_ex/reward_fn_grpo-human.py +129 -0
- examples/agora_ex/system_prompt_CURRENT.md +63 -0
- examples/agora_ex/task_app/agora_ex_task_app.py +590 -0
- examples/agora_ex/task_app/reward_fn_grpo-human.py +129 -0
- examples/agora_ex/task_app/system_prompt_CURRENT.md +63 -0
- examples/multi_step/configs/crafter_rl_outcome.toml +74 -0
- examples/multi_step/configs/crafter_rl_stepwise_hosted_judge.toml +175 -0
- examples/multi_step/configs/crafter_rl_stepwise_shaped.toml +83 -0
- examples/multi_step/configs/crafter_rl_stepwise_simple.toml +78 -0
- examples/multi_step/crafter_rl_lora.md +51 -10
- examples/multi_step/sse_metrics_streaming_notes.md +357 -0
- examples/multi_step/task_app_config_notes.md +494 -0
- examples/warming_up_to_rl/configs/eval_stepwise_complex.toml +35 -0
- examples/warming_up_to_rl/configs/eval_stepwise_consistent.toml +26 -0
- examples/warming_up_to_rl/configs/eval_stepwise_per_achievement.toml +36 -0
- examples/warming_up_to_rl/configs/eval_stepwise_simple.toml +32 -0
- examples/warming_up_to_rl/run_eval.py +267 -41
- examples/warming_up_to_rl/task_app/grpo_crafter.py +3 -33
- examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/openai_client.py +109 -45
- examples/warming_up_to_rl/task_app/synth_envs_hosted/policy_routes.py +42 -46
- examples/warming_up_to_rl/task_app/synth_envs_hosted/rollout.py +376 -193
- synth_ai/__init__.py +41 -1
- synth_ai/api/train/builders.py +74 -33
- synth_ai/api/train/cli.py +29 -6
- synth_ai/api/train/configs/__init__.py +44 -0
- synth_ai/api/train/configs/rl.py +133 -0
- synth_ai/api/train/configs/sft.py +94 -0
- synth_ai/api/train/configs/shared.py +24 -0
- synth_ai/api/train/env_resolver.py +18 -19
- synth_ai/api/train/supported_algos.py +8 -5
- synth_ai/api/train/utils.py +6 -1
- synth_ai/cli/__init__.py +4 -2
- synth_ai/cli/_storage.py +19 -0
- synth_ai/cli/balance.py +14 -2
- synth_ai/cli/calc.py +37 -22
- synth_ai/cli/demo.py +38 -39
- synth_ai/cli/legacy_root_backup.py +12 -14
- synth_ai/cli/recent.py +12 -7
- synth_ai/cli/rl_demo.py +81 -102
- synth_ai/cli/status.py +4 -3
- synth_ai/cli/task_apps.py +146 -137
- synth_ai/cli/traces.py +4 -3
- synth_ai/cli/watch.py +3 -2
- synth_ai/demos/core/cli.py +121 -159
- synth_ai/environments/examples/crafter_classic/environment.py +16 -0
- synth_ai/evals/__init__.py +15 -0
- synth_ai/evals/client.py +85 -0
- synth_ai/evals/types.py +42 -0
- synth_ai/jobs/client.py +15 -3
- synth_ai/judge_schemas.py +127 -0
- synth_ai/rubrics/__init__.py +22 -0
- synth_ai/rubrics/validators.py +126 -0
- synth_ai/task/server.py +14 -7
- synth_ai/tracing_v3/decorators.py +51 -26
- synth_ai/tracing_v3/examples/basic_usage.py +12 -7
- synth_ai/tracing_v3/llm_call_record_helpers.py +107 -53
- synth_ai/tracing_v3/replica_sync.py +8 -4
- synth_ai/tracing_v3/serialization.py +130 -0
- synth_ai/tracing_v3/storage/utils.py +11 -9
- synth_ai/tracing_v3/turso/__init__.py +12 -0
- synth_ai/tracing_v3/turso/daemon.py +2 -1
- synth_ai/tracing_v3/turso/native_manager.py +28 -15
- {synth_ai-0.2.10.dist-info → synth_ai-0.2.13.dev1.dist-info}/METADATA +4 -2
- {synth_ai-0.2.10.dist-info → synth_ai-0.2.13.dev1.dist-info}/RECORD +73 -40
- {synth_ai-0.2.10.dist-info → synth_ai-0.2.13.dev1.dist-info}/entry_points.txt +0 -1
- {synth_ai-0.2.10.dist-info → synth_ai-0.2.13.dev1.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.10.dist-info → synth_ai-0.2.13.dev1.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.2.10.dist-info → synth_ai-0.2.13.dev1.dist-info}/top_level.txt +0 -0
synth_ai/__init__.py
CHANGED
|
@@ -2,6 +2,28 @@
|
|
|
2
2
|
Synth AI - Software for aiding the best and multiplying the will.
|
|
3
3
|
"""
|
|
4
4
|
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
from importlib import metadata as _metadata
|
|
8
|
+
from importlib.metadata import PackageNotFoundError
|
|
9
|
+
from pathlib import Path
|
|
10
|
+
|
|
11
|
+
try: # Prefer the installed package metadata when available
|
|
12
|
+
__version__ = _metadata.version("synth-ai")
|
|
13
|
+
except PackageNotFoundError: # Fallback to pyproject version for editable installs
|
|
14
|
+
try:
|
|
15
|
+
import tomllib as _toml # Python 3.11+
|
|
16
|
+
except ModuleNotFoundError: # pragma: no cover - legacy interpreter guard
|
|
17
|
+
import tomli as _toml # type: ignore[no-redef]
|
|
18
|
+
|
|
19
|
+
try:
|
|
20
|
+
pyproject_path = Path(__file__).resolve().parents[1] / "pyproject.toml"
|
|
21
|
+
with pyproject_path.open("rb") as fh:
|
|
22
|
+
_pyproject = _toml.load(fh)
|
|
23
|
+
__version__ = str(_pyproject["project"]["version"])
|
|
24
|
+
except Exception:
|
|
25
|
+
__version__ = "0.0.0.dev0"
|
|
26
|
+
|
|
5
27
|
# Environment exports - moved from synth-env
|
|
6
28
|
from synth_ai.environments import * # noqa
|
|
7
29
|
import synth_ai.environments as environments # expose module name for __all__
|
|
@@ -21,12 +43,22 @@ try:
|
|
|
21
43
|
except Exception:
|
|
22
44
|
AsyncOpenAI = OpenAI = None # type: ignore
|
|
23
45
|
|
|
46
|
+
# Judge API contract schemas
|
|
47
|
+
from synth_ai.judge_schemas import (
|
|
48
|
+
JudgeScoreRequest,
|
|
49
|
+
JudgeScoreResponse,
|
|
50
|
+
JudgeOptions,
|
|
51
|
+
JudgeTaskApp,
|
|
52
|
+
JudgeTracePayload,
|
|
53
|
+
ReviewPayload,
|
|
54
|
+
CriterionScorePayload,
|
|
55
|
+
)
|
|
56
|
+
|
|
24
57
|
# Legacy tracing v1 is not required for v3 usage and can be unavailable in minimal envs.
|
|
25
58
|
tracing = None # type: ignore
|
|
26
59
|
EventPartitionElement = RewardSignal = SystemTrace = TrainingQuestion = None # type: ignore
|
|
27
60
|
trace_event_async = trace_event_sync = upload = None # type: ignore
|
|
28
61
|
|
|
29
|
-
__version__ = "0.2.6.dev4"
|
|
30
62
|
__all__ = [
|
|
31
63
|
"LM",
|
|
32
64
|
"OpenAI",
|
|
@@ -34,4 +66,12 @@ __all__ = [
|
|
|
34
66
|
"Anthropic",
|
|
35
67
|
"AsyncAnthropic",
|
|
36
68
|
"environments",
|
|
69
|
+
# Judge API contracts
|
|
70
|
+
"JudgeScoreRequest",
|
|
71
|
+
"JudgeScoreResponse",
|
|
72
|
+
"JudgeOptions",
|
|
73
|
+
"JudgeTaskApp",
|
|
74
|
+
"JudgeTracePayload",
|
|
75
|
+
"ReviewPayload",
|
|
76
|
+
"CriterionScorePayload",
|
|
37
77
|
] # Explicitly define public API (v1 tracing omitted in minimal env)
|
synth_ai/api/train/builders.py
CHANGED
|
@@ -1,23 +1,33 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
+
import importlib
|
|
3
4
|
from dataclasses import dataclass
|
|
4
5
|
from pathlib import Path
|
|
5
|
-
from typing import Any
|
|
6
|
+
from typing import Any, cast
|
|
6
7
|
|
|
7
8
|
import click
|
|
8
|
-
from
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
9
|
+
from pydantic import ValidationError
|
|
10
|
+
|
|
11
|
+
try:
|
|
12
|
+
_models_module = importlib.import_module("synth_ai.api.models.supported")
|
|
13
|
+
UnsupportedModelError = _models_module.UnsupportedModelError
|
|
14
|
+
ensure_allowed_model = _models_module.ensure_allowed_model
|
|
15
|
+
normalize_model_identifier = _models_module.normalize_model_identifier
|
|
16
|
+
except Exception as exc: # pragma: no cover - critical dependency
|
|
17
|
+
raise RuntimeError("Unable to load supported model helpers") from exc
|
|
18
|
+
|
|
19
|
+
try:
|
|
20
|
+
prepare_sft_job_payload = importlib.import_module("synth_ai.learning.sft.config").prepare_sft_job_payload
|
|
21
|
+
except Exception as exc: # pragma: no cover - critical dependency
|
|
22
|
+
raise RuntimeError("Unable to load SFT payload helpers") from exc
|
|
14
23
|
|
|
15
24
|
from .supported_algos import (
|
|
16
25
|
AlgorithmValidationError,
|
|
17
26
|
ensure_model_supported_for_algorithm,
|
|
18
27
|
validate_algorithm_config,
|
|
19
28
|
)
|
|
20
|
-
from .utils import TrainError, ensure_api_base
|
|
29
|
+
from .utils import TrainError, ensure_api_base
|
|
30
|
+
from .configs import RLConfig, SFTConfig
|
|
21
31
|
|
|
22
32
|
|
|
23
33
|
@dataclass(slots=True)
|
|
@@ -34,6 +44,16 @@ class SFTBuildResult:
|
|
|
34
44
|
validation_file: Path | None
|
|
35
45
|
|
|
36
46
|
|
|
47
|
+
def _format_validation_error(path: Path, exc: ValidationError) -> str:
|
|
48
|
+
lines: list[str] = []
|
|
49
|
+
for error in exc.errors():
|
|
50
|
+
loc = ".".join(str(part) for part in error.get("loc", ()))
|
|
51
|
+
msg = error.get("msg", "invalid value")
|
|
52
|
+
lines.append(f"{loc or '<root>'}: {msg}")
|
|
53
|
+
details = "\n".join(f" - {line}" for line in lines) or " - Invalid configuration"
|
|
54
|
+
return f"Config validation failed ({path}):\n{details}"
|
|
55
|
+
|
|
56
|
+
|
|
37
57
|
def build_rl_payload(
|
|
38
58
|
*,
|
|
39
59
|
config_path: Path,
|
|
@@ -42,13 +62,30 @@ def build_rl_payload(
|
|
|
42
62
|
idempotency: str | None,
|
|
43
63
|
allow_experimental: bool | None = None,
|
|
44
64
|
) -> RLBuildResult:
|
|
45
|
-
data = load_toml(config_path)
|
|
46
65
|
try:
|
|
47
|
-
|
|
66
|
+
rl_cfg = RLConfig.from_path(config_path)
|
|
67
|
+
except ValidationError as exc:
|
|
68
|
+
raise click.ClickException(_format_validation_error(config_path, exc)) from exc
|
|
69
|
+
|
|
70
|
+
data = rl_cfg.to_dict()
|
|
71
|
+
# Ensure required [reference] section for backend validators
|
|
72
|
+
try:
|
|
73
|
+
ref_cfg = data.get("reference") if isinstance(data, dict) else None
|
|
74
|
+
if not isinstance(ref_cfg, dict):
|
|
75
|
+
data["reference"] = {"placement": "none"}
|
|
76
|
+
else:
|
|
77
|
+
ref_cfg.setdefault("placement", "none")
|
|
78
|
+
except Exception:
|
|
79
|
+
# Defensive: never fail builder due to optional defaults
|
|
80
|
+
data["reference"] = {"placement": "none"}
|
|
81
|
+
try:
|
|
82
|
+
spec = validate_algorithm_config(
|
|
83
|
+
rl_cfg.algorithm.model_dump(), expected_family="rl"
|
|
84
|
+
)
|
|
48
85
|
except AlgorithmValidationError as exc:
|
|
49
86
|
raise click.ClickException(str(exc)) from exc
|
|
50
87
|
services = data.get("services") if isinstance(data.get("services"), dict) else {}
|
|
51
|
-
model_cfg =
|
|
88
|
+
model_cfg = rl_cfg.model
|
|
52
89
|
|
|
53
90
|
final_task_url = (
|
|
54
91
|
overrides.get("task_url")
|
|
@@ -61,10 +98,8 @@ def build_rl_payload(
|
|
|
61
98
|
"Task app URL required (provide --task-url or set services.task_url in TOML)"
|
|
62
99
|
)
|
|
63
100
|
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
raw_base = model_cfg.get("base") if isinstance(model_cfg, dict) else ""
|
|
67
|
-
model_base = str(raw_base or "").strip()
|
|
101
|
+
model_source = (model_cfg.source or "").strip()
|
|
102
|
+
model_base = (model_cfg.base or "").strip()
|
|
68
103
|
override_model = (overrides.get("model") or "").strip()
|
|
69
104
|
if override_model:
|
|
70
105
|
model_source = override_model
|
|
@@ -122,23 +157,26 @@ def build_rl_payload(
|
|
|
122
157
|
except Exception:
|
|
123
158
|
pass
|
|
124
159
|
|
|
160
|
+
payload_data: dict[str, Any] = {
|
|
161
|
+
"endpoint_base_url": final_task_url.rstrip("/"),
|
|
162
|
+
"config": data,
|
|
163
|
+
}
|
|
125
164
|
payload: dict[str, Any] = {
|
|
126
165
|
"job_type": "rl",
|
|
127
166
|
"compute": data.get("compute", {}),
|
|
128
|
-
"data":
|
|
129
|
-
"endpoint_base_url": final_task_url.rstrip("/"),
|
|
130
|
-
"config": data,
|
|
131
|
-
},
|
|
167
|
+
"data": payload_data,
|
|
132
168
|
"tags": {"source": "train-cli"},
|
|
133
169
|
}
|
|
134
170
|
if model_source:
|
|
135
|
-
|
|
171
|
+
payload_data["model"] = model_source
|
|
136
172
|
if model_base:
|
|
137
|
-
|
|
173
|
+
payload_data["base_model"] = model_base
|
|
138
174
|
|
|
139
175
|
backend = overrides.get("backend")
|
|
140
176
|
if backend:
|
|
141
|
-
|
|
177
|
+
metadata_default: dict[str, Any] = {}
|
|
178
|
+
metadata = cast(dict[str, Any], payload.setdefault("metadata", metadata_default))
|
|
179
|
+
metadata["backend_base_url"] = ensure_api_base(str(backend))
|
|
142
180
|
|
|
143
181
|
return RLBuildResult(payload=payload, task_url=final_task_url, idempotency=idempotency)
|
|
144
182
|
|
|
@@ -149,22 +187,23 @@ def build_sft_payload(
|
|
|
149
187
|
dataset_override: Path | None,
|
|
150
188
|
allow_experimental: bool | None,
|
|
151
189
|
) -> SFTBuildResult:
|
|
152
|
-
data = load_toml(config_path)
|
|
153
190
|
try:
|
|
154
|
-
|
|
191
|
+
sft_cfg = SFTConfig.from_path(config_path)
|
|
192
|
+
except ValidationError as exc:
|
|
193
|
+
raise TrainError(_format_validation_error(config_path, exc)) from exc
|
|
194
|
+
|
|
195
|
+
data = sft_cfg.to_dict()
|
|
196
|
+
try:
|
|
197
|
+
algo_mapping = sft_cfg.algorithm.model_dump() if sft_cfg.algorithm else None
|
|
198
|
+
spec = validate_algorithm_config(algo_mapping, expected_family="sft")
|
|
155
199
|
except AlgorithmValidationError as exc:
|
|
156
200
|
raise TrainError(str(exc)) from exc
|
|
157
|
-
job_cfg = data.get("job") if isinstance(data.get("job"), dict) else {}
|
|
158
201
|
data_cfg = data.get("data") if isinstance(data.get("data"), dict) else {}
|
|
159
202
|
hp_cfg = data.get("hyperparameters") if isinstance(data.get("hyperparameters"), dict) else {}
|
|
160
203
|
train_cfg = data.get("training") if isinstance(data.get("training"), dict) else {}
|
|
161
204
|
compute_cfg = data.get("compute") if isinstance(data.get("compute"), dict) else {}
|
|
162
205
|
|
|
163
|
-
raw_dataset =
|
|
164
|
-
dataset_override
|
|
165
|
-
or (job_cfg.get("data") if isinstance(job_cfg, dict) else None)
|
|
166
|
-
or (job_cfg.get("data_path") if isinstance(job_cfg, dict) else None)
|
|
167
|
-
)
|
|
206
|
+
raw_dataset = dataset_override or sft_cfg.job.data or sft_cfg.job.data_path
|
|
168
207
|
if not raw_dataset:
|
|
169
208
|
raise TrainError("Dataset not specified; pass --dataset or set [job].data")
|
|
170
209
|
dataset_path = Path(raw_dataset)
|
|
@@ -249,9 +288,11 @@ def build_sft_payload(
|
|
|
249
288
|
"enabled": bool(validation_cfg.get("enabled", True))
|
|
250
289
|
}
|
|
251
290
|
|
|
252
|
-
raw_model =
|
|
253
|
-
|
|
254
|
-
|
|
291
|
+
raw_model = (sft_cfg.job.model or "").strip()
|
|
292
|
+
if not raw_model:
|
|
293
|
+
model_block = data.get("model")
|
|
294
|
+
if isinstance(model_block, str):
|
|
295
|
+
raw_model = model_block.strip()
|
|
255
296
|
if not raw_model:
|
|
256
297
|
raise TrainError("Model not specified; set [job].model or [model].base in the config")
|
|
257
298
|
|
synth_ai/api/train/cli.py
CHANGED
|
@@ -1,11 +1,18 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
+
import importlib
|
|
3
4
|
import os
|
|
5
|
+
from collections.abc import Mapping
|
|
4
6
|
from pathlib import Path
|
|
5
7
|
from typing import Any
|
|
6
8
|
|
|
7
9
|
import click
|
|
8
|
-
|
|
10
|
+
|
|
11
|
+
try:
|
|
12
|
+
_config_module = importlib.import_module("synth_ai.config.base_url")
|
|
13
|
+
get_backend_from_env = _config_module.get_backend_from_env
|
|
14
|
+
except Exception as exc: # pragma: no cover - critical dependency
|
|
15
|
+
raise RuntimeError("Unable to load backend configuration helpers") from exc
|
|
9
16
|
|
|
10
17
|
from .builders import build_rl_payload, build_sft_payload
|
|
11
18
|
from .config_finder import discover_configs, prompt_for_config
|
|
@@ -231,7 +238,8 @@ def train_command(
|
|
|
231
238
|
]
|
|
232
239
|
if missing_keys:
|
|
233
240
|
try:
|
|
234
|
-
|
|
241
|
+
_task_apps_module = importlib.import_module("synth_ai.cli.task_apps")
|
|
242
|
+
_interactive_fill_env = _task_apps_module._interactive_fill_env
|
|
235
243
|
except Exception as exc: # pragma: no cover - protective fallback
|
|
236
244
|
raise click.ClickException(f"Unable to prompt for env values: {exc}") from exc
|
|
237
245
|
|
|
@@ -386,9 +394,19 @@ def handle_rl(
|
|
|
386
394
|
verify_url, headers=verify_headers, json_body={"endpoint_base_url": build.task_url}
|
|
387
395
|
)
|
|
388
396
|
try:
|
|
389
|
-
|
|
397
|
+
parsed_json = vresp.json()
|
|
390
398
|
except Exception:
|
|
391
|
-
|
|
399
|
+
parsed_json = None
|
|
400
|
+
|
|
401
|
+
if isinstance(parsed_json, Mapping):
|
|
402
|
+
vjs: dict[str, Any] = dict(parsed_json)
|
|
403
|
+
else:
|
|
404
|
+
vjs = {
|
|
405
|
+
"status": vresp.status_code,
|
|
406
|
+
"text": (vresp.text or "")[:400],
|
|
407
|
+
}
|
|
408
|
+
if parsed_json is not None:
|
|
409
|
+
vjs["body"] = parsed_json
|
|
392
410
|
except Exception as _ve:
|
|
393
411
|
raise click.ClickException(
|
|
394
412
|
f"Task app verification call failed: {type(_ve).__name__}: {_ve}"
|
|
@@ -404,8 +422,13 @@ def handle_rl(
|
|
|
404
422
|
# Print concise summary
|
|
405
423
|
try:
|
|
406
424
|
cands = vjs.get("candidates_first15") or []
|
|
407
|
-
|
|
408
|
-
|
|
425
|
+
attempts_raw = vjs.get("attempts")
|
|
426
|
+
attempts: list[Mapping[str, Any]] = (
|
|
427
|
+
[a for a in attempts_raw if isinstance(a, Mapping)]
|
|
428
|
+
if isinstance(attempts_raw, list)
|
|
429
|
+
else []
|
|
430
|
+
)
|
|
431
|
+
statuses = [attempt.get("status") for attempt in attempts]
|
|
409
432
|
click.echo(f"Verification OK (candidates={cands}, statuses={statuses})")
|
|
410
433
|
except Exception:
|
|
411
434
|
pass
|
|
@@ -0,0 +1,44 @@
|
|
|
1
|
+
"""Typed training config loaders for RL and SFT jobs."""
|
|
2
|
+
|
|
3
|
+
from .shared import AlgorithmConfig, ComputeConfig
|
|
4
|
+
from .sft import (
|
|
5
|
+
HyperparametersConfig,
|
|
6
|
+
HyperparametersParallelism,
|
|
7
|
+
JobConfig,
|
|
8
|
+
SFTConfig,
|
|
9
|
+
SFTDataConfig,
|
|
10
|
+
TrainingConfig,
|
|
11
|
+
TrainingValidationConfig,
|
|
12
|
+
)
|
|
13
|
+
from .rl import (
|
|
14
|
+
EvaluationConfig,
|
|
15
|
+
JudgeConfig,
|
|
16
|
+
JudgeOptionsConfig,
|
|
17
|
+
ModelConfig,
|
|
18
|
+
RLConfig,
|
|
19
|
+
RLServicesConfig,
|
|
20
|
+
RLTrainingConfig,
|
|
21
|
+
RolloutConfig,
|
|
22
|
+
WeightSyncConfig,
|
|
23
|
+
)
|
|
24
|
+
|
|
25
|
+
__all__ = [
|
|
26
|
+
"AlgorithmConfig",
|
|
27
|
+
"ComputeConfig",
|
|
28
|
+
"EvaluationConfig",
|
|
29
|
+
"HyperparametersConfig",
|
|
30
|
+
"HyperparametersParallelism",
|
|
31
|
+
"JobConfig",
|
|
32
|
+
"JudgeConfig",
|
|
33
|
+
"JudgeOptionsConfig",
|
|
34
|
+
"ModelConfig",
|
|
35
|
+
"RLConfig",
|
|
36
|
+
"RLServicesConfig",
|
|
37
|
+
"RLTrainingConfig",
|
|
38
|
+
"RolloutConfig",
|
|
39
|
+
"SFTConfig",
|
|
40
|
+
"SFTDataConfig",
|
|
41
|
+
"TrainingConfig",
|
|
42
|
+
"TrainingValidationConfig",
|
|
43
|
+
"WeightSyncConfig",
|
|
44
|
+
]
|
|
@@ -0,0 +1,133 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
from typing import Any, Mapping
|
|
5
|
+
|
|
6
|
+
from pydantic import model_validator
|
|
7
|
+
|
|
8
|
+
from ..utils import load_toml
|
|
9
|
+
from .shared import AlgorithmConfig, ComputeConfig, ExtraModel
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class RLServicesConfig(ExtraModel):
|
|
13
|
+
task_url: str
|
|
14
|
+
judge_url: str | None = None
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class ModelConfig(ExtraModel):
|
|
18
|
+
source: str | None = None
|
|
19
|
+
base: str | None = None
|
|
20
|
+
trainer_mode: str
|
|
21
|
+
label: str
|
|
22
|
+
|
|
23
|
+
@model_validator(mode="after")
|
|
24
|
+
def _ensure_exactly_one_source_or_base(self) -> "ModelConfig":
|
|
25
|
+
if bool(self.source) == bool(self.base):
|
|
26
|
+
raise ValueError("Config must set exactly one of [model].source or [model].base")
|
|
27
|
+
return self
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class RolloutConfig(ExtraModel):
|
|
31
|
+
env_name: str
|
|
32
|
+
policy_name: str
|
|
33
|
+
env_config: dict[str, Any] | None = None
|
|
34
|
+
policy_config: dict[str, Any] | None = None
|
|
35
|
+
max_turns: int
|
|
36
|
+
episodes_per_batch: int
|
|
37
|
+
max_concurrent_rollouts: int
|
|
38
|
+
batches_per_step: int | None = None
|
|
39
|
+
ops: list[str] | None = None
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class WeightSyncConfig(ExtraModel):
|
|
43
|
+
enable: bool | None = None
|
|
44
|
+
targets: list[str] | None = None
|
|
45
|
+
mode: str | None = None
|
|
46
|
+
direct: bool | None = None
|
|
47
|
+
verify_every_k: int | None = None
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
class RLTrainingConfig(ExtraModel):
|
|
51
|
+
num_epochs: int
|
|
52
|
+
iterations_per_epoch: int
|
|
53
|
+
gradient_accumulation_steps: int | None = None
|
|
54
|
+
max_accumulated_minibatch: int | None = None
|
|
55
|
+
max_turns: int
|
|
56
|
+
batch_size: int
|
|
57
|
+
group_size: int
|
|
58
|
+
learning_rate: float
|
|
59
|
+
log_interval: int | None = None
|
|
60
|
+
weight_sync_interval: int | None = None
|
|
61
|
+
step_rewards_enabled: bool | None = None
|
|
62
|
+
step_rewards_mode: str | None = None
|
|
63
|
+
step_rewards_indicator_lambda: float | None = None
|
|
64
|
+
step_rewards_beta: float | None = None
|
|
65
|
+
step_rewards_strategy: str | None = None
|
|
66
|
+
event_rewards_kind: str | None = None
|
|
67
|
+
weight_sync: WeightSyncConfig | None = None
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
class EvaluationConfig(ExtraModel):
|
|
71
|
+
instances: int
|
|
72
|
+
every_n_iters: int
|
|
73
|
+
seeds: list[int]
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
class JudgeOptionsConfig(ExtraModel):
|
|
77
|
+
event: bool | None = None
|
|
78
|
+
outcome: bool | None = None
|
|
79
|
+
provider: str | None = None
|
|
80
|
+
model: str | None = None
|
|
81
|
+
rubric_id: str | None = None
|
|
82
|
+
rubric_overrides: dict[str, Any] | None = None
|
|
83
|
+
tracks: list[str] | None = None
|
|
84
|
+
weights: dict[str, float] | None = None
|
|
85
|
+
max_concurrency: int | None = None
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
class JudgeConfig(ExtraModel):
|
|
89
|
+
type: str | None = None
|
|
90
|
+
timeout_s: int | None = None
|
|
91
|
+
options: JudgeOptionsConfig | None = None
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
class RLConfig(ExtraModel):
|
|
95
|
+
algorithm: AlgorithmConfig
|
|
96
|
+
services: RLServicesConfig
|
|
97
|
+
compute: ComputeConfig | None = None
|
|
98
|
+
topology: dict[str, Any] | None = None
|
|
99
|
+
vllm: dict[str, Any] | None = None
|
|
100
|
+
reference: dict[str, Any] | None = None
|
|
101
|
+
model: ModelConfig
|
|
102
|
+
lora: dict[str, Any] | None = None
|
|
103
|
+
rollout: RolloutConfig | None = None
|
|
104
|
+
evaluation: EvaluationConfig | None = None
|
|
105
|
+
training: RLTrainingConfig | None = None
|
|
106
|
+
rubric: dict[str, Any] | None = None
|
|
107
|
+
judge: JudgeConfig | None = None
|
|
108
|
+
tags: dict[str, Any] | None = None
|
|
109
|
+
|
|
110
|
+
def to_dict(self) -> dict[str, Any]:
|
|
111
|
+
return self.model_dump(mode="python", exclude_none=True)
|
|
112
|
+
|
|
113
|
+
@classmethod
|
|
114
|
+
def from_mapping(cls, data: Mapping[str, Any]) -> "RLConfig":
|
|
115
|
+
return cls.model_validate(dict(data))
|
|
116
|
+
|
|
117
|
+
@classmethod
|
|
118
|
+
def from_path(cls, path: Path) -> "RLConfig":
|
|
119
|
+
content = load_toml(path)
|
|
120
|
+
return cls.from_mapping(content)
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
__all__ = [
|
|
124
|
+
"EvaluationConfig",
|
|
125
|
+
"JudgeConfig",
|
|
126
|
+
"JudgeOptionsConfig",
|
|
127
|
+
"ModelConfig",
|
|
128
|
+
"RLConfig",
|
|
129
|
+
"RLServicesConfig",
|
|
130
|
+
"RLTrainingConfig",
|
|
131
|
+
"RolloutConfig",
|
|
132
|
+
"WeightSyncConfig",
|
|
133
|
+
]
|
|
@@ -0,0 +1,94 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
from typing import Any, Mapping
|
|
5
|
+
|
|
6
|
+
from pydantic import Field
|
|
7
|
+
|
|
8
|
+
from ..utils import load_toml
|
|
9
|
+
from .shared import AlgorithmConfig, ComputeConfig, ExtraModel
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class JobConfig(ExtraModel):
|
|
13
|
+
model: str
|
|
14
|
+
data: str | None = None
|
|
15
|
+
data_path: str | None = None
|
|
16
|
+
poll_seconds: int | None = None
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class SFTDataConfig(ExtraModel):
|
|
20
|
+
topology: dict[str, Any] | None = None
|
|
21
|
+
validation_path: str | None = None
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class TrainingValidationConfig(ExtraModel):
|
|
25
|
+
enabled: bool | None = None
|
|
26
|
+
evaluation_strategy: str | None = None
|
|
27
|
+
eval_steps: int | None = None
|
|
28
|
+
save_best_model_at_end: bool | None = None
|
|
29
|
+
metric_for_best_model: str | None = None
|
|
30
|
+
greater_is_better: bool | None = None
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class TrainingConfig(ExtraModel):
|
|
34
|
+
mode: str | None = None
|
|
35
|
+
use_qlora: bool | None = None
|
|
36
|
+
validation: TrainingValidationConfig | None = None
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class HyperparametersParallelism(ExtraModel):
|
|
40
|
+
use_deepspeed: bool | None = None
|
|
41
|
+
deepspeed_stage: int | None = None
|
|
42
|
+
fsdp: bool | None = None
|
|
43
|
+
bf16: bool | None = None
|
|
44
|
+
fp16: bool | None = None
|
|
45
|
+
activation_checkpointing: bool | None = None
|
|
46
|
+
tensor_parallel_size: int | None = None
|
|
47
|
+
pipeline_parallel_size: int | None = None
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
class HyperparametersConfig(ExtraModel):
|
|
51
|
+
n_epochs: int = 1
|
|
52
|
+
batch_size: int | None = None
|
|
53
|
+
global_batch: int | None = None
|
|
54
|
+
per_device_batch: int | None = None
|
|
55
|
+
gradient_accumulation_steps: int | None = None
|
|
56
|
+
sequence_length: int | None = None
|
|
57
|
+
learning_rate: float | None = None
|
|
58
|
+
warmup_ratio: float | None = None
|
|
59
|
+
train_kind: str | None = None
|
|
60
|
+
weight_decay: float | None = None
|
|
61
|
+
parallelism: HyperparametersParallelism | None = None
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
class SFTConfig(ExtraModel):
|
|
65
|
+
algorithm: AlgorithmConfig | None = None
|
|
66
|
+
job: JobConfig
|
|
67
|
+
compute: ComputeConfig | None = None
|
|
68
|
+
data: SFTDataConfig | None = None
|
|
69
|
+
training: TrainingConfig | None = None
|
|
70
|
+
hyperparameters: HyperparametersConfig = Field(default_factory=HyperparametersConfig)
|
|
71
|
+
tags: dict[str, Any] | None = None
|
|
72
|
+
|
|
73
|
+
def to_dict(self) -> dict[str, Any]:
|
|
74
|
+
return self.model_dump(mode="python", exclude_none=True)
|
|
75
|
+
|
|
76
|
+
@classmethod
|
|
77
|
+
def from_mapping(cls, data: Mapping[str, Any]) -> "SFTConfig":
|
|
78
|
+
return cls.model_validate(dict(data))
|
|
79
|
+
|
|
80
|
+
@classmethod
|
|
81
|
+
def from_path(cls, path: Path) -> "SFTConfig":
|
|
82
|
+
content = load_toml(path)
|
|
83
|
+
return cls.from_mapping(content)
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
__all__ = [
|
|
87
|
+
"HyperparametersConfig",
|
|
88
|
+
"HyperparametersParallelism",
|
|
89
|
+
"JobConfig",
|
|
90
|
+
"SFTConfig",
|
|
91
|
+
"SFTDataConfig",
|
|
92
|
+
"TrainingConfig",
|
|
93
|
+
"TrainingValidationConfig",
|
|
94
|
+
]
|
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from pydantic import BaseModel, ConfigDict
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class ExtraModel(BaseModel):
|
|
7
|
+
"""Base model that tolerates unknown keys so configs keep forward compatibility."""
|
|
8
|
+
|
|
9
|
+
model_config = ConfigDict(extra="allow")
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class AlgorithmConfig(ExtraModel):
|
|
13
|
+
type: str
|
|
14
|
+
method: str
|
|
15
|
+
variety: str
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class ComputeConfig(ExtraModel):
|
|
19
|
+
gpu_type: str
|
|
20
|
+
gpu_count: int
|
|
21
|
+
nodes: int | None = None
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
__all__ = ["ExtraModel", "AlgorithmConfig", "ComputeConfig"]
|
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
+
import importlib
|
|
3
4
|
import os
|
|
4
5
|
from collections.abc import Callable, Iterable, MutableMapping
|
|
5
6
|
from dataclasses import dataclass
|
|
@@ -11,6 +12,18 @@ from . import task_app
|
|
|
11
12
|
from .utils import REPO_ROOT, mask_value, read_env_file, write_env_value
|
|
12
13
|
|
|
13
14
|
|
|
15
|
+
def _load_saved_env_path() -> Path | None:
|
|
16
|
+
try:
|
|
17
|
+
module = importlib.import_module("synth_ai.demos.demo_task_apps.core")
|
|
18
|
+
loader = module.load_env_file_path
|
|
19
|
+
saved_path = loader()
|
|
20
|
+
if saved_path:
|
|
21
|
+
return Path(saved_path)
|
|
22
|
+
except Exception:
|
|
23
|
+
return None
|
|
24
|
+
return None
|
|
25
|
+
|
|
26
|
+
|
|
14
27
|
@dataclass(slots=True)
|
|
15
28
|
class KeySpec:
|
|
16
29
|
name: str
|
|
@@ -156,25 +169,11 @@ def resolve_env(
|
|
|
156
169
|
raise click.ClickException(f"Env file not found: {path}")
|
|
157
170
|
resolver = EnvResolver(provided)
|
|
158
171
|
else:
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
if saved_env_path:
|
|
165
|
-
saved_path = Path(saved_env_path)
|
|
166
|
-
if saved_path.exists():
|
|
167
|
-
click.echo(f"Using .env file: {saved_path}")
|
|
168
|
-
resolver = EnvResolver([saved_path])
|
|
169
|
-
else:
|
|
170
|
-
# Saved path no longer exists, fall back to prompt
|
|
171
|
-
resolver = EnvResolver(_collect_default_candidates(config_path))
|
|
172
|
-
resolver.select_new_env()
|
|
173
|
-
else:
|
|
174
|
-
resolver = EnvResolver(_collect_default_candidates(config_path))
|
|
175
|
-
resolver.select_new_env()
|
|
176
|
-
except Exception:
|
|
177
|
-
# If import fails or any error, fall back to original behavior
|
|
172
|
+
saved_path = _load_saved_env_path()
|
|
173
|
+
if saved_path and saved_path.exists():
|
|
174
|
+
click.echo(f"Using .env file: {saved_path}")
|
|
175
|
+
resolver = EnvResolver([saved_path])
|
|
176
|
+
else:
|
|
178
177
|
resolver = EnvResolver(_collect_default_candidates(config_path))
|
|
179
178
|
resolver.select_new_env()
|
|
180
179
|
|