synth-ai 0.2.8.dev11__py3-none-any.whl → 0.2.8.dev13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of synth-ai might be problematic. Click here for more details.
- synth_ai/api/train/__init__.py +5 -0
- synth_ai/api/train/builders.py +165 -0
- synth_ai/api/train/cli.py +429 -0
- synth_ai/api/train/config_finder.py +120 -0
- synth_ai/api/train/env_resolver.py +302 -0
- synth_ai/api/train/pollers.py +66 -0
- synth_ai/api/train/task_app.py +128 -0
- synth_ai/api/train/utils.py +232 -0
- synth_ai/cli/__init__.py +23 -0
- synth_ai/cli/rl_demo.py +2 -2
- synth_ai/cli/root.py +2 -1
- synth_ai/cli/task_apps.py +520 -0
- synth_ai/demos/demo_task_apps/math/modal_task_app.py +31 -25
- synth_ai/task/__init__.py +94 -1
- synth_ai/task/apps/__init__.py +88 -0
- synth_ai/task/apps/grpo_crafter.py +438 -0
- synth_ai/task/apps/math_single_step.py +852 -0
- synth_ai/task/auth.py +132 -0
- synth_ai/task/client.py +148 -0
- synth_ai/task/contracts.py +29 -14
- synth_ai/task/datasets.py +105 -0
- synth_ai/task/errors.py +49 -0
- synth_ai/task/json.py +77 -0
- synth_ai/task/proxy.py +258 -0
- synth_ai/task/rubrics.py +212 -0
- synth_ai/task/server.py +398 -0
- synth_ai/task/tracing_utils.py +79 -0
- synth_ai/task/vendors.py +61 -0
- synth_ai/tracing_v3/session_tracer.py +13 -5
- synth_ai/tracing_v3/storage/base.py +10 -12
- synth_ai/tracing_v3/turso/manager.py +20 -6
- {synth_ai-0.2.8.dev11.dist-info → synth_ai-0.2.8.dev13.dist-info}/METADATA +3 -2
- {synth_ai-0.2.8.dev11.dist-info → synth_ai-0.2.8.dev13.dist-info}/RECORD +37 -15
- {synth_ai-0.2.8.dev11.dist-info → synth_ai-0.2.8.dev13.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.8.dev11.dist-info → synth_ai-0.2.8.dev13.dist-info}/entry_points.txt +0 -0
- {synth_ai-0.2.8.dev11.dist-info → synth_ai-0.2.8.dev13.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.2.8.dev11.dist-info → synth_ai-0.2.8.dev13.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,232 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
import os
|
|
5
|
+
import re
|
|
6
|
+
import subprocess
|
|
7
|
+
import sys
|
|
8
|
+
import tempfile
|
|
9
|
+
import time
|
|
10
|
+
from dataclasses import dataclass
|
|
11
|
+
from pathlib import Path
|
|
12
|
+
from typing import Any, Iterable, Mapping
|
|
13
|
+
|
|
14
|
+
import requests
|
|
15
|
+
import tomllib
|
|
16
|
+
|
|
17
|
+
# Directory three levels above the folder that contains this file
# (``parents[0]`` is the containing folder, so ``parents[3]`` is three up).
REPO_ROOT = Path(__file__).resolve().parents[3]
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class TrainError(RuntimeError):
    """Error raised for failures in the interactive training CLI.

    Helpers in this module raise it with a human-readable message, e.g.
    for a missing config file, malformed TOML, or an invalid dataset.
    """
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def load_toml(path: Path) -> dict[str, Any]:
|
|
25
|
+
try:
|
|
26
|
+
with path.open("rb") as fh:
|
|
27
|
+
return tomllib.load(fh)
|
|
28
|
+
except FileNotFoundError as exc: # pragma: no cover - guarded by CLI
|
|
29
|
+
raise TrainError(f"Config not found: {path}") from exc
|
|
30
|
+
except tomllib.TOMLDecodeError as exc: # pragma: no cover - malformed input
|
|
31
|
+
raise TrainError(f"Failed to parse TOML: {path}\n{exc}") from exc
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def mask_value(value: str | None) -> str:
|
|
35
|
+
if not value:
|
|
36
|
+
return "<unset>"
|
|
37
|
+
value = str(value)
|
|
38
|
+
if len(value) <= 6:
|
|
39
|
+
return "****"
|
|
40
|
+
return f"{value[:4]}…{value[-2:]}"
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
# Matches ``KEY=value`` assignments, optionally prefixed with ``export``.
_ENV_LINE = re.compile(r"^\s*(?:export\s+)?(?P<key>[A-Za-z0-9_]+)\s*=\s*(?P<value>.*)$")


def read_env_file(path: Path) -> dict[str, str]:
    """Parse a dotenv-style file into a ``{key: value}`` mapping.

    A missing file yields an empty dict. Lines that are not assignments are
    skipped, values are whitespace-stripped, and one pair of matching
    surrounding quotes (single or double) is removed.
    """
    if not path.exists():
        return {}
    parsed: dict[str, str] = {}
    text = path.read_text(encoding="utf-8", errors="ignore")
    for entry in text.splitlines():
        match = _ENV_LINE.match(entry)
        if match is None:
            continue
        value = match.group("value").strip()
        is_quoted = len(value) >= 2 and value[0] in ('"', "'") and value[-1] == value[0]
        if is_quoted:
            value = value[1:-1]
        parsed[match.group("key")] = value
    return parsed
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def write_env_value(path: Path, key: str, value: str) -> None:
    """Set ``key=value`` in the dotenv-style file at *path*.

    Every existing assignment of *key* is replaced in place; if there is
    none, the assignment is appended. All other lines are kept as they are,
    and the file is created when it does not exist.
    """
    if path.exists():
        current = path.read_text(encoding="utf-8", errors="ignore").splitlines()
    else:
        current = []
    assignment = f"{key}={value}"
    rewritten: list[str] = []
    replaced = False
    for entry in current:
        match = _ENV_LINE.match(entry)
        if match is not None and match.group("key") == key:
            rewritten.append(assignment)
            replaced = True
            continue
        rewritten.append(entry)
    if not replaced:
        rewritten.append(assignment)
    path.write_text("\n".join(rewritten) + "\n", encoding="utf-8")
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
@dataclass(slots=True)
|
|
80
|
+
class CLIResult:
|
|
81
|
+
code: int
|
|
82
|
+
stdout: str
|
|
83
|
+
stderr: str
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
def run_cli(
    args: Iterable[str],
    *,
    cwd: Path | None = None,
    env: Mapping[str, str] | None = None,
    timeout: float | None = None,
) -> CLIResult:
    """Run a command and return its exit code and stripped text output.

    Args:
        args: Command and arguments, passed to ``subprocess.run`` as a list.
        cwd: Working directory for the child process.
        env: Extra variables layered on top of the current ``os.environ``.
        timeout: Seconds to wait, forwarded to ``subprocess.run``.
    """
    overrides = env or {}
    child_env = dict(os.environ, **overrides)
    completed = subprocess.run(
        list(args),
        cwd=cwd,
        env=child_env,
        capture_output=True,
        text=True,
        timeout=timeout,
    )
    out_text = completed.stdout.strip()
    err_text = completed.stderr.strip()
    return CLIResult(code=completed.returncode, stdout=out_text, stderr=err_text)
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
def http_post(
    url: str,
    *,
    headers: Mapping[str, str] | None = None,
    json_body: Any | None = None,
    timeout: float = 60.0,
) -> requests.Response:
    """POST *json_body* as JSON to *url* and return the raw response.

    The response is returned as-is; no status checking is done here.
    """
    request_headers = dict(headers) if headers else {}
    return requests.post(url, headers=request_headers, json=json_body, timeout=timeout)
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
def http_get(
    url: str,
    *,
    headers: Mapping[str, str] | None = None,
    timeout: float = 30.0,
) -> requests.Response:
    """GET *url* and return the raw response.

    The response is returned as-is; no status checking is done here.
    """
    request_headers = dict(headers) if headers else {}
    return requests.get(url, headers=request_headers, timeout=timeout)
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
def post_multipart(
    url: str,
    *,
    api_key: str,
    file_field: str,
    file_path: Path,
    purpose: str = "fine-tune",
) -> requests.Response:
    """Upload *file_path* to *url* as a multipart form with bearer auth.

    The file is read fully into memory and sent under *file_field* with
    content type ``application/jsonl``; *purpose* goes as a form field.
    """
    payload = file_path.read_bytes()
    upload = (file_path.name, payload, "application/jsonl")
    return requests.post(
        url,
        headers={"Authorization": f"Bearer {api_key}"},
        files={file_field: upload},
        data={"purpose": purpose},
        timeout=300,
    )
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
def fmt_duration(seconds: float) -> str:
    """Format a duration in seconds as a short human-readable string.

    Returns ``"12.3s"`` below one minute, ``"4m05s"`` below one hour, and
    ``"2h07m"`` otherwise. Minute and hour forms truncate the smaller unit
    rather than rounding it.
    """
    if seconds < 60:
        text = f"{seconds:.1f}s"
        # Values in [59.95, 60) round up to "60.0s" under ``.1f``; promote
        # them to the minute form so the output never shows sixty seconds.
        if text != "60.0s":
            return text
        seconds = 60.0
    minutes, secs = divmod(seconds, 60)
    if minutes < 60:
        return f"{int(minutes)}m{int(secs):02d}s"
    hours, mins = divmod(minutes, 60)
    return f"{int(hours)}h{int(mins):02d}m"
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
def validate_sft_jsonl(path: Path, *, max_errors: int = 20) -> None:
|
|
126
|
+
errors: list[str] = []
|
|
127
|
+
try:
|
|
128
|
+
fh = path.open("r", encoding="utf-8")
|
|
129
|
+
except FileNotFoundError as exc: # pragma: no cover - upstream ensures existence
|
|
130
|
+
raise TrainError(f"Dataset not found: {path}") from exc
|
|
131
|
+
|
|
132
|
+
with fh:
|
|
133
|
+
for idx, line in enumerate(fh, start=1):
|
|
134
|
+
stripped = line.strip()
|
|
135
|
+
if not stripped:
|
|
136
|
+
continue
|
|
137
|
+
try:
|
|
138
|
+
record = json.loads(stripped)
|
|
139
|
+
except json.JSONDecodeError as exc:
|
|
140
|
+
errors.append(f"Line {idx}: invalid JSON ({exc.msg})")
|
|
141
|
+
if len(errors) >= max_errors:
|
|
142
|
+
break
|
|
143
|
+
continue
|
|
144
|
+
|
|
145
|
+
messages = record.get("messages")
|
|
146
|
+
if not isinstance(messages, list) or not messages:
|
|
147
|
+
errors.append(f"Line {idx}: missing or empty 'messages' list")
|
|
148
|
+
if len(errors) >= max_errors:
|
|
149
|
+
break
|
|
150
|
+
continue
|
|
151
|
+
|
|
152
|
+
for msg_idx, msg in enumerate(messages):
|
|
153
|
+
if not isinstance(msg, dict):
|
|
154
|
+
errors.append(f"Line {idx}: message {msg_idx} is not an object")
|
|
155
|
+
break
|
|
156
|
+
if "role" not in msg or "content" not in msg:
|
|
157
|
+
errors.append(f"Line {idx}: message {msg_idx} missing 'role' or 'content'")
|
|
158
|
+
break
|
|
159
|
+
if not isinstance(msg["role"], str) or not isinstance(msg["content"], str):
|
|
160
|
+
errors.append(f"Line {idx}: message {msg_idx} has non-string role/content")
|
|
161
|
+
break
|
|
162
|
+
if len(errors) >= max_errors:
|
|
163
|
+
break
|
|
164
|
+
|
|
165
|
+
if errors:
|
|
166
|
+
suffix = "" if len(errors) < max_errors else f" (showing first {max_errors} issues)"
|
|
167
|
+
details = "\n - ".join(errors)
|
|
168
|
+
raise TrainError(f"Dataset validation failed{suffix}:\n - {details}")
|
|
169
|
+
|
|
170
|
+
|
|
171
|
+
def limit_jsonl_examples(src: Path, limit: int) -> Path:
    """Copy the first *limit* non-blank lines of *src* into a new temp file.

    The copy is written to a fresh ``sft_subset_*`` temporary directory as
    ``<stem>.head<limit><suffix>``, and its path is returned.

    Raises:
        TrainError: if *limit* is not positive, *src* does not exist, or no
            non-blank line was copied.
    """
    if limit <= 0:
        raise TrainError("Example limit must be positive")
    if not src.exists():
        raise TrainError(f"Dataset not found: {src}")

    workdir = Path(tempfile.mkdtemp(prefix="sft_subset_"))
    dest = workdir / f"{src.stem}.head{limit}{src.suffix}"

    count = 0
    with src.open("r", encoding="utf-8") as reader, dest.open("w", encoding="utf-8") as writer:
        non_blank = (row for row in reader if row.strip())
        for count, row in enumerate(non_blank, start=1):
            writer.write(row)
            if count >= limit:
                break

    if count == 0:
        raise TrainError("Subset dataset is empty; check limit value")

    return dest
|
|
194
|
+
|
|
195
|
+
|
|
196
|
+
def ensure_api_base(base: str) -> str:
    """Return *base* without trailing slashes and ending in ``/api``.

    ``/api`` is appended only when the trimmed URL does not already end
    with it.
    """
    trimmed = base.rstrip("/")
    return trimmed if trimmed.endswith("/api") else f"{trimmed}/api"
|
|
201
|
+
|
|
202
|
+
|
|
203
|
+
def preview_json(data: Any, limit: int = 600) -> str:
    """Return at most *limit* characters of an indented JSON dump of *data*.

    If *data* cannot be serialized as JSON, ``str(data)`` is used instead.
    """
    try:
        rendered = json.dumps(data, indent=2)
    except Exception:
        # Best-effort preview: anything json cannot encode is shown via str().
        rendered = str(data)
    return rendered[:limit]
|
|
208
|
+
|
|
209
|
+
|
|
210
|
+
def sleep(seconds: float) -> None:
    """Pause for *seconds* by delegating to ``time.sleep``."""
    time.sleep(seconds)
|
|
212
|
+
|
|
213
|
+
|
|
214
|
+
# Public names exported by this module.
__all__ = [
    "CLIResult",
    "REPO_ROOT",
    "TrainError",
    "ensure_api_base",
    "fmt_duration",
    "http_get",
    "http_post",
    "load_toml",
    "mask_value",
    "post_multipart",
    "preview_json",
    "read_env_file",
    "run_cli",
    "sleep",
    "limit_jsonl_examples",
    "validate_sft_jsonl",
    "write_env_value",
]
|
synth_ai/cli/__init__.py
CHANGED
|
@@ -75,3 +75,26 @@ try:
|
|
|
75
75
|
_rl_demo.register(cli)
|
|
76
76
|
except Exception:
|
|
77
77
|
pass
|
|
78
|
+
try:
|
|
79
|
+
from synth_ai.api.train import register as _train_register
|
|
80
|
+
|
|
81
|
+
_train_register(cli)
|
|
82
|
+
except Exception:
|
|
83
|
+
pass
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
from .task_apps import task_app_group
|
|
88
|
+
cli.add_command(task_app_group, name="task-app")
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
try:
|
|
92
|
+
from . import task_apps as _task_apps
|
|
93
|
+
_task_apps.register(cli)
|
|
94
|
+
except Exception:
|
|
95
|
+
pass
|
|
96
|
+
|
|
97
|
+
cli.add_command(task_app_group.commands['serve'], name='serve')
|
|
98
|
+
cli.add_command(task_app_group.commands['deploy'], name='deploy')
|
|
99
|
+
|
|
100
|
+
cli.add_command(task_app_group.commands['modal-serve'], name='modal-serve')
|
synth_ai/cli/rl_demo.py
CHANGED
|
@@ -161,12 +161,12 @@ def register(cli):
|
|
|
161
161
|
_forward(args)
|
|
162
162
|
|
|
163
163
|
# Top-level convenience alias: `synth-ai deploy`
|
|
164
|
-
@cli.command("deploy")
|
|
164
|
+
@cli.command("demo-deploy")
|
|
165
165
|
@click.option("--local", is_flag=True, help="Run local FastAPI instead of Modal deploy")
|
|
166
166
|
@click.option("--app", type=click.Path(), default=None, help="Path to Modal app.py for uv run modal deploy")
|
|
167
167
|
@click.option("--name", type=str, default="synth-math-demo", help="Modal app name")
|
|
168
168
|
@click.option("--script", type=click.Path(), default=None, help="Path to deploy_task_app.sh (optional legacy)")
|
|
169
|
-
def
|
|
169
|
+
def deploy_demo(local: bool, app: str | None, name: str, script: str | None):
|
|
170
170
|
args: list[str] = ["rl_demo.deploy"]
|
|
171
171
|
if local:
|
|
172
172
|
args.append("--local")
|
synth_ai/cli/root.py
CHANGED
|
@@ -164,7 +164,7 @@ def setup():
|
|
|
164
164
|
default=True,
|
|
165
165
|
help="Kill any process already bound to --env-port without prompting",
|
|
166
166
|
)
|
|
167
|
-
def
|
|
167
|
+
def serve_deprecated(
|
|
168
168
|
db_file: str,
|
|
169
169
|
sqld_port: int,
|
|
170
170
|
env_port: int,
|
|
@@ -174,6 +174,7 @@ def serve(
|
|
|
174
174
|
force: bool,
|
|
175
175
|
):
|
|
176
176
|
logging.basicConfig(level=logging.INFO, format="%(message)s")
|
|
177
|
+
click.echo("⚠️ 'synth-ai serve' now targets task apps; use 'synth-ai serve' for task apps or 'synth-ai serve-deprecated' for this legacy service.", err=True)
|
|
177
178
|
processes = []
|
|
178
179
|
|
|
179
180
|
def signal_handler(sig, frame):
|