synth-ai 0.2.8.dev11__py3-none-any.whl → 0.2.8.dev13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of synth-ai might be problematic. Click here for more details.

Files changed (37)
  1. synth_ai/api/train/__init__.py +5 -0
  2. synth_ai/api/train/builders.py +165 -0
  3. synth_ai/api/train/cli.py +429 -0
  4. synth_ai/api/train/config_finder.py +120 -0
  5. synth_ai/api/train/env_resolver.py +302 -0
  6. synth_ai/api/train/pollers.py +66 -0
  7. synth_ai/api/train/task_app.py +128 -0
  8. synth_ai/api/train/utils.py +232 -0
  9. synth_ai/cli/__init__.py +23 -0
  10. synth_ai/cli/rl_demo.py +2 -2
  11. synth_ai/cli/root.py +2 -1
  12. synth_ai/cli/task_apps.py +520 -0
  13. synth_ai/demos/demo_task_apps/math/modal_task_app.py +31 -25
  14. synth_ai/task/__init__.py +94 -1
  15. synth_ai/task/apps/__init__.py +88 -0
  16. synth_ai/task/apps/grpo_crafter.py +438 -0
  17. synth_ai/task/apps/math_single_step.py +852 -0
  18. synth_ai/task/auth.py +132 -0
  19. synth_ai/task/client.py +148 -0
  20. synth_ai/task/contracts.py +29 -14
  21. synth_ai/task/datasets.py +105 -0
  22. synth_ai/task/errors.py +49 -0
  23. synth_ai/task/json.py +77 -0
  24. synth_ai/task/proxy.py +258 -0
  25. synth_ai/task/rubrics.py +212 -0
  26. synth_ai/task/server.py +398 -0
  27. synth_ai/task/tracing_utils.py +79 -0
  28. synth_ai/task/vendors.py +61 -0
  29. synth_ai/tracing_v3/session_tracer.py +13 -5
  30. synth_ai/tracing_v3/storage/base.py +10 -12
  31. synth_ai/tracing_v3/turso/manager.py +20 -6
  32. {synth_ai-0.2.8.dev11.dist-info → synth_ai-0.2.8.dev13.dist-info}/METADATA +3 -2
  33. {synth_ai-0.2.8.dev11.dist-info → synth_ai-0.2.8.dev13.dist-info}/RECORD +37 -15
  34. {synth_ai-0.2.8.dev11.dist-info → synth_ai-0.2.8.dev13.dist-info}/WHEEL +0 -0
  35. {synth_ai-0.2.8.dev11.dist-info → synth_ai-0.2.8.dev13.dist-info}/entry_points.txt +0 -0
  36. {synth_ai-0.2.8.dev11.dist-info → synth_ai-0.2.8.dev13.dist-info}/licenses/LICENSE +0 -0
  37. {synth_ai-0.2.8.dev11.dist-info → synth_ai-0.2.8.dev13.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,5 @@
1
+ from __future__ import annotations
2
+
3
+ from .cli import register, train_command
4
+
5
# Public surface of the train API package: the click command and its registration hook.
__all__ = [
    "register",
    "train_command",
]
@@ -0,0 +1,165 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from pathlib import Path
5
+ from typing import Any
6
+
7
+ import click
8
+
9
+ from .utils import ensure_api_base, load_toml, TrainError
10
+
11
+
12
+ @dataclass(slots=True)
13
+ class RLBuildResult:
14
+ payload: dict[str, Any]
15
+ task_url: str
16
+ idempotency: str | None
17
+
18
+
19
+ @dataclass(slots=True)
20
+ class SFTBuildResult:
21
+ payload: dict[str, Any]
22
+ train_file: Path
23
+ validation_file: Path | None
24
+
25
+
26
def _optional_str(value: Any, label: str) -> str:
    """Return *value* stripped of surrounding whitespace, or "" when it is falsy.

    Raises:
        click.ClickException: if *value* is truthy but not a string, so a
            malformed TOML entry is reported as a CLI error rather than an
            ``AttributeError`` from ``.strip()``.
    """
    if not value:
        return ""
    if not isinstance(value, str):
        raise click.ClickException(f"{label} must be a string, got {type(value).__name__}")
    return value.strip()


def build_rl_payload(
    *,
    config_path: Path,
    task_url: str,
    overrides: dict[str, Any],
    idempotency: str | None,
) -> RLBuildResult:
    """Build the JSON body for creating an RL job from a TOML config.

    Task URL precedence is ``overrides["task_url"]``, then *task_url*, then
    ``[services].task_url`` from the TOML. An ``overrides["model"]`` value
    replaces the config's model and is sent as ``data.model``. Otherwise
    exactly one of ``[model].source`` / ``[model].base`` must be set.

    The loaded config is embedded in the payload under ``data.config``, with
    ``services.task_url`` rewritten to the effective task URL. When
    ``overrides["backend"]`` is set, it is recorded as
    ``metadata.backend_base_url``.

    Args:
        config_path: Path to the RL TOML config.
        task_url: Fallback task app URL (e.g. from the environment).
        overrides: CLI overrides; keys read are "task_url", "model", "backend".
        idempotency: Optional Idempotency-Key passed through to the result.

    Returns:
        An :class:`RLBuildResult` with the payload and effective task URL.

    Raises:
        click.ClickException: if no task URL is available, the model section
            is ambiguous/missing, or a relevant value is not a string.
    """
    data = load_toml(config_path)
    services = data.get("services") if isinstance(data.get("services"), dict) else {}
    model_cfg = data.get("model") if isinstance(data.get("model"), dict) else {}

    final_task_url = _optional_str(
        overrides.get("task_url") or task_url or services.get("task_url"),
        "task_url",
    )
    if not final_task_url:
        raise click.ClickException("Task app URL required (provide --task-url or set services.task_url in TOML)")

    model_source = _optional_str(model_cfg.get("source"), "[model].source")
    model_base = _optional_str(model_cfg.get("base"), "[model].base")
    override_model = _optional_str(overrides.get("model"), "--model")
    if override_model:
        # An explicit model override always wins and is treated as a source model.
        model_source, model_base = override_model, ""
    if bool(model_source) == bool(model_base):
        raise click.ClickException("Model section must specify exactly one of [model].source or [model].base")

    # Force TOML services.task_url to the effective endpoint to avoid split URLs:
    # the embedded config and endpoint_base_url must point at the same task app.
    try:
        if isinstance(data.get("services"), dict):
            data["services"]["task_url"] = final_task_url
        else:
            data["services"] = {"task_url": final_task_url}
    except TypeError:
        # Best-effort only: tolerate a read-only mapping from the loader
        # (assumption — load_toml normally returns a plain dict).
        pass

    payload: dict[str, Any] = {
        "job_type": "rl",
        "compute": data.get("compute", {}),
        "data": {
            "endpoint_base_url": final_task_url.rstrip("/"),
            "config": data,
        },
        "tags": {"source": "train-cli"},
    }
    if model_source:
        payload["data"]["model"] = model_source
    if model_base:
        payload["data"]["base_model"] = model_base

    backend = overrides.get("backend")
    if backend:
        payload.setdefault("metadata", {})["backend_base_url"] = ensure_api_base(str(backend))

    return RLBuildResult(payload=payload, task_url=final_task_url, idempotency=idempotency)
78
+
79
+
80
def _resolve_config_relative(config_path: Path, raw: str | Path) -> Path:
    """Return *raw* as an absolute path, resolving relative paths against the config's directory.

    ``~`` is expanded first, so home-relative paths in TOML behave as they do on the CLI.
    """
    candidate = Path(raw).expanduser()
    if not candidate.is_absolute():
        candidate = config_path.parent / candidate
    return candidate.resolve()


def build_sft_payload(
    *,
    config_path: Path,
    dataset_override: Path | None,
) -> SFTBuildResult:
    """Build the job body for an offline SFT run from a TOML config.

    The dataset is taken from *dataset_override*, else ``[job].data``, else
    ``[job].data_path``. Relative paths resolve against the config's
    directory. An optional ``[data].validation_path`` is used only when it
    points at an existing file; otherwise a warning is printed.

    Hyperparameters are copied from ``[hyperparameters]`` (``n_epochs``
    defaults to 1, plus a fixed allow-list of keys and ``parallelism``).
    ``[training].validation`` adds evaluation settings. Selected compute,
    topology and training fields are recorded under
    ``metadata.effective_config``.

    Returns:
        An :class:`SFTBuildResult`; ``payload["training_file_id"]`` is None
        until the caller uploads the dataset.

    Raises:
        TrainError: if no dataset is configured, the dataset value is not a
            path string, or it does not name an existing file.
    """
    data = load_toml(config_path)
    job_cfg = data.get("job") if isinstance(data.get("job"), dict) else {}
    data_cfg = data.get("data") if isinstance(data.get("data"), dict) else {}
    hp_cfg = data.get("hyperparameters") if isinstance(data.get("hyperparameters"), dict) else {}
    train_cfg = data.get("training") if isinstance(data.get("training"), dict) else {}
    compute_cfg = data.get("compute") if isinstance(data.get("compute"), dict) else {}

    raw_dataset = dataset_override or job_cfg.get("data") or job_cfg.get("data_path")
    if not raw_dataset:
        raise TrainError("Dataset not specified; pass --dataset or set [job].data")
    if not isinstance(raw_dataset, (str, Path)):
        raise TrainError(f"Dataset path must be a string, got {type(raw_dataset).__name__}")
    dataset_path = _resolve_config_relative(config_path, raw_dataset)
    # is_file() rather than exists(): a directory would otherwise pass here and fail at upload time.
    if not dataset_path.is_file():
        raise TrainError(f"Dataset not found: {dataset_path}")

    raw_validation = data_cfg.get("validation_path")
    validation_file: Path | None = None
    if isinstance(raw_validation, str) and raw_validation:
        vpath = _resolve_config_relative(config_path, raw_validation)
        if vpath.is_file():
            validation_file = vpath
        else:
            click.echo(f"[WARN] Validation dataset {vpath} missing; continuing without validation")

    hp_block: dict[str, Any] = {
        "n_epochs": int(hp_cfg.get("n_epochs", 1)),
    }
    for key in (
        "batch_size",
        "global_batch",
        "per_device_batch",
        "gradient_accumulation_steps",
        "sequence_length",
        "learning_rate",
        "warmup_ratio",
        "train_kind",
    ):
        if key in hp_cfg:
            hp_block[key] = hp_cfg[key]
    if isinstance(hp_cfg.get("parallelism"), dict):
        hp_block["parallelism"] = hp_cfg["parallelism"]

    compute_block = {k: compute_cfg[k] for k in ("gpu_type", "gpu_count", "nodes") if k in compute_cfg}
    topology = data_cfg.get("topology")

    effective: dict[str, Any] = {
        "compute": compute_block,
        "data": {"topology": topology if isinstance(topology, dict) else {}},
        "training": {k: v for k, v in train_cfg.items() if k in ("mode", "use_qlora")},
    }

    validation_cfg = train_cfg.get("validation")
    if isinstance(validation_cfg, dict):
        hp_block.update(
            {
                "evaluation_strategy": validation_cfg.get("evaluation_strategy", "steps"),
                "eval_steps": int(validation_cfg.get("eval_steps", 0) or 0),
                "save_best_model_at_end": bool(validation_cfg.get("save_best_model_at_end", True)),
                "metric_for_best_model": validation_cfg.get("metric_for_best_model", "val.loss"),
                "greater_is_better": bool(validation_cfg.get("greater_is_better", False)),
            }
        )
        effective["training"]["validation"] = {"enabled": bool(validation_cfg.get("enabled", True))}

    payload = {
        "model": job_cfg.get("model") or data.get("model"),
        "training_file_id": None,  # populated after upload
        "training_type": "sft_offline",
        "hyperparameters": hp_block,
        "metadata": {"effective_config": effective},
    }

    return SFTBuildResult(payload=payload, train_file=dataset_path, validation_file=validation_file)
158
+
159
+
160
# Names this module exports; the train CLI imports all four from here.
__all__ = ["RLBuildResult", "SFTBuildResult", "build_rl_payload", "build_sft_payload"]
@@ -0,0 +1,429 @@
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ from pathlib import Path
5
+ from typing import Any, Dict
6
+
7
+ import click
8
+
9
+ from .builders import RLBuildResult, SFTBuildResult, build_rl_payload, build_sft_payload
10
+ from .config_finder import discover_configs, prompt_for_config
11
+ from .env_resolver import KeySpec, resolve_env
12
+ from .pollers import RLJobPoller, SFTJobPoller
13
+ from .task_app import check_task_app_health
14
+ from .utils import (
15
+ TrainError,
16
+ REPO_ROOT,
17
+ ensure_api_base,
18
+ http_post,
19
+ http_get,
20
+ limit_jsonl_examples,
21
+ mask_value,
22
+ post_multipart,
23
+ preview_json,
24
+ sleep,
25
+ validate_sft_jsonl,
26
+ )
27
+
28
+
29
def _discover_dataset_candidates(config_path: Path, limit: int = 50) -> list[Path]:
    """Find non-empty ``*.jsonl`` files that could serve as an SFT dataset.

    Searches (recursively, in order) the config's directory, its
    ``datasets`` subdirectory, and ``REPO_ROOT/traces`` / ``REPO_ROOT/datasets``.
    Paths are de-duplicated after resolution. Unreadable entries (broken
    symlinks, permission errors) are skipped rather than aborting the scan.

    Args:
        config_path: Path to the training config; its parent anchors the search.
        limit: Maximum number of candidates to return; ``<= 0`` returns none.

    Returns:
        Up to *limit* resolved paths, in discovery order.
    """
    if limit <= 0:
        return []

    search_dirs: list[Path] = [
        config_path.parent,
        config_path.parent / "datasets",
        REPO_ROOT / "traces",
        REPO_ROOT / "datasets",
    ]

    candidates: list[Path] = []
    seen: set[Path] = set()
    for directory in search_dirs:
        if not directory.is_dir():
            continue
        try:
            for path in directory.rglob("*.jsonl"):
                try:
                    resolved = path.resolve()
                    if resolved in seen:
                        continue
                    seen.add(resolved)
                    # stat() can fail on dangling symlinks / unreadable files, so it
                    # lives inside the guard too.
                    if resolved.stat().st_size == 0:
                        continue
                except OSError:
                    continue
                candidates.append(resolved)
                if len(candidates) >= limit:
                    return candidates
        except OSError:
            # rglob itself may hit an unreadable subdirectory; keep what was found so far.
            continue
    return candidates
56
+
57
+
58
def prompt_for_dataset(config_path: Path) -> Path:
    """Interactively choose an SFT dataset JSONL file.

    Lists auto-discovered candidates near *config_path* (see
    ``_discover_dataset_candidates``) and also offers manual entry. When no
    candidates exist, it goes straight to manual entry. It repeats until an
    existing regular file with a ``.jsonl`` suffix is chosen.

    Returns:
        The resolved path of the selected file.

    Raises:
        click.ClickException: if the user selects ``0`` to abort.
    """
    candidates = _discover_dataset_candidates(config_path)
    while True:
        if candidates:
            click.echo("Select dataset JSONL file:")
            for idx, candidate in enumerate(candidates, start=1):
                click.echo(f" {idx}) {candidate}")
            click.echo(" m) Enter path manually")
            click.echo(" 0) Abort")
            choice = click.prompt("Choice", default="m").strip().lower()
            if choice == "0":
                raise click.ClickException("Aborted by user")
            if choice in {"m", "manual"}:
                selected = _prompt_manual_dataset()
            else:
                try:
                    idx = int(choice)
                except ValueError:
                    click.echo("Invalid selection; try again")
                    continue
                if idx < 1 or idx > len(candidates):
                    click.echo("Invalid selection; try again")
                    continue
                selected = candidates[idx - 1]
        else:
            selected = _prompt_manual_dataset()

        # is_file() rather than exists(): a directory named "*.jsonl" must be rejected here,
        # not discovered later when validation/upload tries to read it.
        if selected.is_file() and selected.suffix == ".jsonl":
            return selected.resolve()
        click.echo("File not found or not a .jsonl; please try again.")
88
+
89
+
90
def _prompt_manual_dataset() -> Path:
    """Ask the user to type a dataset path and return it with ``~`` expanded (not resolved)."""
    typed_path = click.prompt("Enter dataset JSONL path", type=str)
    cleaned = typed_path.strip()
    return Path(cleaned).expanduser()
93
+
94
+
95
@click.command("train")
@click.option("--config", "config_paths", multiple=True, type=click.Path(), help="Path to training TOML (repeatable)")
@click.option("--type", "train_type", type=click.Choice(["auto", "rl", "sft"]), default="auto")
@click.option("--env-file", "env_files", multiple=True, type=click.Path(), help=".env file(s) to preload (skips selection prompt)")
@click.option("--task-url", default=None, help="Override task app base URL (RL only)")
@click.option("--dataset", "dataset_path", type=click.Path(), default=None, help="Override dataset JSONL path (SFT)")
@click.option("--backend", default=lambda: os.getenv("BACKEND_BASE_URL", "http://localhost:8000/api"), help="Backend base URL")
@click.option("--model", default=None, help="Override model identifier")
@click.option("--idempotency", default=None, help="Idempotency-Key header for job creation")
@click.option("--dry-run", is_flag=True, help="Preview payload without submitting")
@click.option("--poll/--no-poll", default=True, help="Poll job status until terminal state")
@click.option("--poll-timeout", default=3600.0, type=float, help="Maximum seconds to poll before timing out")
@click.option("--poll-interval", default=5.0, type=float, help="Seconds between poll attempts")
@click.option(
    "--examples",
    "examples_limit",
    # Reject 0 / negative counts at parse time instead of producing an empty training file.
    type=click.IntRange(min=1),
    default=None,
    help="Limit SFT training to the first N examples",
)
def train_command(
    config_paths: tuple[str, ...],
    train_type: str,
    env_files: tuple[str, ...],
    task_url: str | None,
    dataset_path: str | None,
    backend: str,
    model: str | None,
    idempotency: str | None,
    dry_run: bool,
    poll: bool,
    poll_timeout: float,
    poll_interval: float,
    examples_limit: int | None,
) -> None:
    """Interactive launcher for RL / SFT jobs."""
    # (The docstring above doubles as click's --help text, so details live in comments.)
    # Flow: select a config, resolve credentials via resolve_env(), require a
    # SYNTH_API_KEY, then dispatch to handle_rl / handle_sft.

    # "auto" means "infer from the selected config"; discovery/prompting expect None for that.
    requested_type = None if train_type == "auto" else train_type
    candidates = discover_configs(list(config_paths), requested_type=requested_type)
    selection = prompt_for_config(candidates, requested_type=requested_type)

    effective_type = requested_type or selection.train_type
    if effective_type not in {"rl", "sft"}:
        effective_type = click.prompt("Detected config type is ambiguous. Enter type", type=click.Choice(["rl", "sft"]))

    cfg_path = selection.path
    click.echo(f"Using config: {cfg_path} ({effective_type})")

    # Both flows talk to the Synth backend; RL additionally needs task app credentials.
    required_keys: list[KeySpec] = [KeySpec("SYNTH_API_KEY", "Synth API key for backend")]
    if effective_type == "rl":
        required_keys.append(
            KeySpec(
                "ENVIRONMENT_API_KEY",
                "Environment API key for task app",
                allow_modal_secret=True,
                modal_secret_pattern="env",
            )
        )
        required_keys.append(
            KeySpec(
                "TASK_APP_URL",
                "Task app base URL",
                secret=False,
                allow_modal_app=True,
                # An explicit --task-url makes the env value unnecessary.
                optional=bool(task_url),
            )
        )

    env_path, env_values = resolve_env(
        config_path=cfg_path,
        explicit_env_paths=env_files,
        required_keys=required_keys,
    )
    click.echo(f"Using env file: {env_path}")

    synth_key = env_values.get("SYNTH_API_KEY") or os.environ.get("SYNTH_API_KEY")
    if not synth_key:
        raise click.ClickException("SYNTH_API_KEY required")

    backend_base = ensure_api_base(backend)
    click.echo(f"Backend base: {backend_base} (key {mask_value(synth_key)})")

    if effective_type == "rl":
        handle_rl(
            cfg_path=cfg_path,
            backend_base=backend_base,
            synth_key=synth_key,
            task_url_override=task_url,
            model_override=model,
            idempotency=idempotency,
            dry_run=dry_run,
            poll=poll,
            poll_timeout=poll_timeout,
            poll_interval=poll_interval,
        )
    else:
        dataset_override_path = Path(dataset_path).expanduser().resolve() if dataset_path else None
        handle_sft(
            cfg_path=cfg_path,
            backend_base=backend_base,
            synth_key=synth_key,
            dataset_override=dataset_override_path,
            dry_run=dry_run,
            poll=poll,
            poll_timeout=poll_timeout,
            poll_interval=poll_interval,
            examples_limit=examples_limit,
        )
199
+
200
+
201
def _wait_for_training_file(backend_base: str, api_key: str, file_id: str, *, timeout: float = 120.0) -> None:
    """Block until the backend reports uploaded file *file_id* as ready.

    Polls ``GET {backend_base}/learning/files/{file_id}`` roughly every 2 seconds:

    - A 200 response ends the wait if its ``status`` / ``state`` /
      ``storage_state`` field is one of ready/uploaded/stored/complete.
      A missing field or an unparsable body is treated as ready.
    - A 404 means the object is not visible yet, so polling continues.
    - Any other status code prints a warning and polling continues.

    Raises:
        click.ClickException: if the file is not ready within *timeout* seconds
            of wall-clock time, including time spent in the HTTP requests.
    """
    import time

    url = f"{backend_base}/learning/files/{file_id}"
    headers = {"Authorization": f"Bearer {api_key}"}
    interval = 2.0
    # Track real elapsed time: each GET may take up to 30s, so counting only the
    # sleep intervals could overshoot the timeout many times over.
    deadline = time.monotonic() + timeout
    while True:
        resp = http_get(url, headers=headers, timeout=30.0)
        if resp.status_code == 200:
            try:
                data = resp.json()
            except Exception:
                data = {}
            if not isinstance(data, dict):
                # Non-object JSON carries no status fields; treat like an unparsable body.
                data = {}
            status = str(data.get("status") or data.get("state") or data.get("storage_state") or "ready").lower()
            if status in {"ready", "uploaded", "stored", "complete"}:
                return
        elif resp.status_code != 404:
            # 404: the object may not be visible yet; anything else is worth surfacing.
            click.echo(f"[WARN] Unexpected response while checking training file {file_id}: {resp.status_code}")

        remaining = deadline - time.monotonic()
        if remaining <= 0:
            raise click.ClickException(f"Training file {file_id} not ready after {timeout:.0f}s")
        sleep(min(interval, remaining))
226
+
227
+
228
def handle_rl(
    *,
    cfg_path: Path,
    backend_base: str,
    synth_key: str,
    task_url_override: str | None,
    model_override: str | None,
    idempotency: str | None,
    dry_run: bool,
    poll: bool,
    poll_timeout: float,
    poll_interval: float,
) -> None:
    """Verify the task app, then create (and optionally poll) an RL job.

    Steps:
    1. Build the payload from *cfg_path*.
    2. Require ``ENVIRONMENT_API_KEY`` in the environment.
    3. Ask the backend to verify the task app (``POST /rl/verify_task_app``).
    4. Run a local health check.
    5. Print the payload preview. With *dry_run*, stop here.
    6. ``POST /rl/jobs`` and, when *poll* is set, poll until a terminal state.

    Raises:
        click.ClickException: on a missing ENVIRONMENT_API_KEY, a failed
            verification or health check, or a failed job creation.
    """
    overrides: Dict[str, Any] = {"backend": backend_base, "task_url": task_url_override, "model": model_override}
    build = build_rl_payload(
        config_path=cfg_path,
        task_url=task_url_override or os.environ.get("TASK_APP_URL", ""),
        overrides=overrides,
        idempotency=idempotency,
    )

    # Check before any network traffic; the health check below cannot run without it.
    # NOTE(review): only os.environ is consulted here, which assumes resolve_env()
    # exported the resolved .env values into the process environment — confirm.
    env_key = os.environ.get("ENVIRONMENT_API_KEY")
    if not env_key:
        raise click.ClickException("ENVIRONMENT_API_KEY required for RL flow")

    # Backend-side verification: try ALL org environment keys against /health and /task_info
    verify_url = f"{backend_base}/rl/verify_task_app"
    verify_headers = {"Authorization": f"Bearer {synth_key}", "Content-Type": "application/json"}
    try:
        vresp = http_post(verify_url, headers=verify_headers, json_body={"endpoint_base_url": build.task_url})
        try:
            vjs = vresp.json()
        except Exception:
            vjs = {"status": vresp.status_code, "text": (vresp.text or "")[:400]}
    except Exception as _ve:
        raise click.ClickException(f"Task app verification call failed: {type(_ve).__name__}: {_ve}") from _ve
    if vresp.status_code >= 400:
        click.echo("Task app verification error:\n" + preview_json(vjs, limit=800))
        raise click.ClickException(f"Verification failed with status {vresp.status_code}")
    if not isinstance(vjs, dict):
        # A non-object JSON body cannot carry "any_ok"; report it instead of crashing on .get().
        vjs = {"status": vresp.status_code, "body": vjs}
    if not bool(vjs.get("any_ok")):
        click.echo("Task app verification failed; no auth combination succeeded. Full report:")
        click.echo(preview_json(vjs, limit=1200))
        raise click.ClickException("Task app verification failed (auth)")

    # Concise summary; purely cosmetic, so malformed entries must not abort the flow.
    try:
        cands = vjs.get("candidates_first15") or []
        attempts = vjs.get("attempts") or []
        statuses = [a.get("status") for a in attempts]
        click.echo(f"Verification OK (candidates={cands}, statuses={statuses})")
    except (AttributeError, TypeError):
        pass

    click.echo("Performing task app health check…")
    health = check_task_app_health(build.task_url, env_key)
    if not health.ok:
        click.echo(f"Task app health check failed: {health.detail}")
        raise click.ClickException("Aborting due to failing health check")
    click.echo("Task app healthy")

    create_url = f"{backend_base}/rl/jobs"
    headers = {"Authorization": f"Bearer {synth_key}", "Content-Type": "application/json"}
    if build.idempotency:
        headers["Idempotency-Key"] = build.idempotency

    click.echo(f"POST {create_url}")
    click.echo("Payload preview:\n" + preview_json(build.payload, limit=800))
    if dry_run:
        click.echo("Dry run enabled; skipping submission")
        return

    resp = http_post(create_url, headers=headers, json_body=build.payload)
    try:
        js = resp.json()
    except Exception:
        js = {"status": resp.status_code, "text": resp.text[:400]}
    click.echo(f"Response {resp.status_code}: {preview_json(js, limit=400)}")
    if resp.status_code not in (200, 201):
        raise click.ClickException("Job creation failed")
    job_id = (js.get("job_id") or js.get("id")) if isinstance(js, dict) else None
    if not job_id:
        raise click.ClickException("Response missing job id")

    if not poll:
        click.echo(f"Created job {job_id} (polling disabled)")
        return

    poller = RLJobPoller(backend_base, synth_key, interval=poll_interval, timeout=poll_timeout)
    outcome = poller.poll_job(job_id)
    click.echo(f"Final status: {outcome.status}")
    click.echo(preview_json(outcome.payload, limit=600))
320
+
321
+
322
def handle_sft(
    *,
    cfg_path: Path,
    backend_base: str,
    synth_key: str,
    dataset_override: Path | None,
    dry_run: bool,
    poll: bool,
    poll_timeout: float,
    poll_interval: float,
    examples_limit: int | None,
) -> None:
    """Validate and upload SFT data, then create, start and optionally poll a learning job.

    If building the payload raises :class:`TrainError` (e.g. a missing
    dataset), the error is shown, the user is prompted for a dataset, and the
    build is retried.

    When *examples_limit* is set, the file returned by ``limit_jsonl_examples``
    replaces the training file. It is deleted afterwards, along with its parent
    directory if that directory is empty.

    With *dry_run*, no uploads, readiness checks or job submission happen;
    only the payload preview is printed.

    Raises:
        click.ClickException: on upload failure, the training file not
            becoming ready, or a failed job creation.
    """
    dataset_path = dataset_override

    while True:
        try:
            build = build_sft_payload(config_path=cfg_path, dataset_override=dataset_path)
            break
        except TrainError as exc:
            click.echo(str(exc))
            dataset_path = prompt_for_dataset(cfg_path)

    limited_path: Path | None = None

    try:
        if examples_limit is not None:
            limited_path = limit_jsonl_examples(build.train_file, examples_limit)
            click.echo(
                f"Using first {examples_limit} examples from {build.train_file} -> {limited_path}"
            )
            build.train_file = limited_path

        click.echo("Validating training dataset…")
        validate_sft_jsonl(build.train_file)
        if build.validation_file and build.validation_file.suffix == ".jsonl":
            click.echo("Validating validation dataset…")
            validate_sft_jsonl(build.validation_file)

        upload_url = f"{backend_base}/learning/files"
        click.echo(f"Uploading dataset {build.train_file}")
        if dry_run:
            click.echo("Dry run: skipping upload")
            train_file_id = "dry-run-train"
            val_file_id = None
        else:
            resp = post_multipart(upload_url, api_key=synth_key, file_field="file", file_path=build.train_file)
            js = resp.json() if resp.headers.get("content-type", "").startswith("application/json") else {}
            if resp.status_code >= 400 or "id" not in js:
                raise click.ClickException(f"Training file upload failed ({resp.status_code}): {js or resp.text[:200]}")
            train_file_id = js["id"]
            val_file_id = None
            if build.validation_file:
                click.echo(f"Uploading validation dataset {build.validation_file}")
                vresp = post_multipart(upload_url, api_key=synth_key, file_field="file", file_path=build.validation_file)
                vjs = vresp.json() if vresp.headers.get("content-type", "").startswith("application/json") else {}
                if vresp.status_code < 400 and "id" in vjs:
                    val_file_id = vjs["id"]
                else:
                    click.echo(f"[WARN] Validation upload failed: {vresp.status_code} {vjs or vresp.text[:200]}")
            # Only a real upload yields an id the backend knows about. Waiting on the
            # dry-run placeholder would 404 until the timeout and abort the preview.
            _wait_for_training_file(backend_base, synth_key, train_file_id)

        payload = dict(build.payload)
        payload["training_file_id"] = train_file_id
        if val_file_id:
            payload.setdefault("metadata", {}).setdefault("effective_config", {}).setdefault("data", {})["validation_files"] = [val_file_id]

        click.echo("FFT job payload:\n" + preview_json(payload, limit=800))
        if dry_run:
            click.echo("Dry run: skipping job submission")
            return

        create_url = f"{backend_base}/learning/jobs"
        headers = {"Authorization": f"Bearer {synth_key}", "Content-Type": "application/json"}
        resp = http_post(create_url, headers=headers, json_body=payload)
        js = resp.json() if resp.headers.get("content-type", "").startswith("application/json") else {}
        click.echo(f"Response {resp.status_code}: {preview_json(js, limit=400)}")
        if resp.status_code not in (200, 201):
            raise click.ClickException("Failed to create learning job")
        job_id = (js.get("job_id") or js.get("id")) if isinstance(js, dict) else None
        if not job_id:
            raise click.ClickException("Response missing job id")

        start_url = f"{backend_base}/learning/jobs/{job_id}/start"
        click.echo(f"POST {start_url} (start)")
        start_resp = http_post(start_url, headers=headers, json_body={})
        if start_resp.status_code >= 400:
            # Surface a failed start; otherwise it only shows up later as a poll that never finishes.
            click.echo(f"[WARN] Job start returned {start_resp.status_code}: {(start_resp.text or '')[:200]}")

        if not poll:
            click.echo(f"Started job {job_id} (polling disabled)")
            return

        poller = SFTJobPoller(backend_base, synth_key, interval=poll_interval, timeout=poll_timeout)
        outcome = poller.poll_job(job_id)
        click.echo(f"Final status: {outcome.status}")
        click.echo(preview_json(outcome.payload, limit=600))
    finally:
        if limited_path is not None:
            # Best-effort cleanup of the limited copy; rmdir only succeeds if the directory is empty.
            try:
                limited_path.unlink(missing_ok=True)
                limited_path.parent.rmdir()
            except Exception:
                pass
426
+
427
+
428
def register(cli: click.Group) -> None:
    """Attach this module's ``train`` command to the given click group."""
    command = train_command
    cli.add_command(command)