synth-ai 0.2.8.dev12__py3-none-any.whl → 0.2.9.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42):
  1. synth_ai/api/train/__init__.py +5 -0
  2. synth_ai/api/train/builders.py +165 -0
  3. synth_ai/api/train/cli.py +450 -0
  4. synth_ai/api/train/config_finder.py +168 -0
  5. synth_ai/api/train/env_resolver.py +302 -0
  6. synth_ai/api/train/pollers.py +66 -0
  7. synth_ai/api/train/task_app.py +193 -0
  8. synth_ai/api/train/utils.py +232 -0
  9. synth_ai/cli/__init__.py +23 -0
  10. synth_ai/cli/rl_demo.py +18 -6
  11. synth_ai/cli/root.py +38 -6
  12. synth_ai/cli/task_apps.py +1107 -0
  13. synth_ai/demo_registry.py +258 -0
  14. synth_ai/demos/core/cli.py +147 -111
  15. synth_ai/demos/demo_task_apps/__init__.py +7 -1
  16. synth_ai/demos/demo_task_apps/math/config.toml +55 -110
  17. synth_ai/demos/demo_task_apps/math/modal_task_app.py +157 -21
  18. synth_ai/demos/demo_task_apps/math/task_app_entry.py +39 -0
  19. synth_ai/task/__init__.py +94 -1
  20. synth_ai/task/apps/__init__.py +88 -0
  21. synth_ai/task/apps/grpo_crafter.py +438 -0
  22. synth_ai/task/apps/math_single_step.py +852 -0
  23. synth_ai/task/auth.py +153 -0
  24. synth_ai/task/client.py +165 -0
  25. synth_ai/task/contracts.py +29 -14
  26. synth_ai/task/datasets.py +105 -0
  27. synth_ai/task/errors.py +49 -0
  28. synth_ai/task/json.py +77 -0
  29. synth_ai/task/proxy.py +258 -0
  30. synth_ai/task/rubrics.py +212 -0
  31. synth_ai/task/server.py +398 -0
  32. synth_ai/task/tracing_utils.py +79 -0
  33. synth_ai/task/vendors.py +61 -0
  34. synth_ai/tracing_v3/session_tracer.py +13 -5
  35. synth_ai/tracing_v3/storage/base.py +10 -12
  36. synth_ai/tracing_v3/turso/manager.py +20 -6
  37. {synth_ai-0.2.8.dev12.dist-info → synth_ai-0.2.9.dev0.dist-info}/METADATA +3 -2
  38. {synth_ai-0.2.8.dev12.dist-info → synth_ai-0.2.9.dev0.dist-info}/RECORD +42 -18
  39. {synth_ai-0.2.8.dev12.dist-info → synth_ai-0.2.9.dev0.dist-info}/WHEEL +0 -0
  40. {synth_ai-0.2.8.dev12.dist-info → synth_ai-0.2.9.dev0.dist-info}/entry_points.txt +0 -0
  41. {synth_ai-0.2.8.dev12.dist-info → synth_ai-0.2.9.dev0.dist-info}/licenses/LICENSE +0 -0
  42. {synth_ai-0.2.8.dev12.dist-info → synth_ai-0.2.9.dev0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,5 @@
1
+ from __future__ import annotations
2
+
3
+ from .cli import register, train_command
4
+
5
+ __all__ = ["register", "train_command"]
@@ -0,0 +1,165 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from pathlib import Path
5
+ from typing import Any
6
+
7
+ import click
8
+
9
+ from .utils import ensure_api_base, load_toml, TrainError
10
+
11
+
12
+ @dataclass(slots=True)
13
+ class RLBuildResult:
14
+ payload: dict[str, Any]
15
+ task_url: str
16
+ idempotency: str | None
17
+
18
+
19
+ @dataclass(slots=True)
20
+ class SFTBuildResult:
21
+ payload: dict[str, Any]
22
+ train_file: Path
23
+ validation_file: Path | None
24
+
25
+
26
def build_rl_payload(
    *,
    config_path: Path,
    task_url: str,
    overrides: dict[str, Any],
    idempotency: str | None,
) -> RLBuildResult:
    """Assemble the RL job-creation payload from a training TOML.

    Args:
        config_path: Path of the TOML file, read via ``load_toml``.
        task_url: Task app base URL supplied by the caller; used when
            ``overrides["task_url"]`` is empty.
        overrides: Optional CLI overrides. The keys read here are
            ``"task_url"``, ``"model"`` and ``"backend"``.
        idempotency: Passed through unchanged into the result.

    Returns:
        An ``RLBuildResult`` holding the payload, the effective task URL and
        the idempotency key.

    Raises:
        click.ClickException: If no task URL can be determined, if the model
            section does not name exactly one of ``source`` / ``base``, or if
            one of those settings is present but not a string.
    """

    def _text(value: Any, label: str) -> str:
        # TOML values are untyped at this point; report a CLI error rather
        # than letting ``.strip()`` fail with AttributeError on e.g. an int.
        if not isinstance(value, str):
            raise click.ClickException(f"{label} must be a string, got {type(value).__name__}")
        return value.strip()

    data = load_toml(config_path)
    services = data.get("services")
    if not isinstance(services, dict):
        services = {}
    model_cfg = data.get("model")
    if not isinstance(model_cfg, dict):
        model_cfg = {}

    # Precedence: explicit override, then caller-provided URL, then the TOML.
    final_task_url = _text(
        overrides.get("task_url") or task_url or services.get("task_url") or "",
        "Task app URL",
    )
    if not final_task_url:
        raise click.ClickException("Task app URL required (provide --task-url or set services.task_url in TOML)")

    model_source = _text(model_cfg.get("source") or "", "[model].source")
    model_base = _text(model_cfg.get("base") or "", "[model].base")
    override_model = _text(overrides.get("model") or "", "--model")
    if override_model:
        # A model override always wins and is treated as a source model.
        model_source = override_model
        model_base = ""
    if bool(model_source) == bool(model_base):
        raise click.ClickException("Model section must specify exactly one of [model].source or [model].base")

    # Force TOML services.task_url to the effective endpoint to avoid split URLs.
    # ``services`` is either the dict already stored in ``data`` or a fresh one,
    # so a plain assignment covers both cases.
    services["task_url"] = final_task_url
    data["services"] = services

    payload: dict[str, Any] = {
        "job_type": "rl",
        "compute": data.get("compute", {}),
        "data": {
            "endpoint_base_url": final_task_url.rstrip("/"),
            "config": data,
        },
        "tags": {"source": "train-cli"},
    }
    if model_source:
        payload["data"]["model"] = model_source
    if model_base:
        payload["data"]["base_model"] = model_base

    backend = overrides.get("backend")
    if backend:
        payload.setdefault("metadata", {})["backend_base_url"] = ensure_api_base(str(backend))

    return RLBuildResult(payload=payload, task_url=final_task_url, idempotency=idempotency)
78
+
79
+
80
def build_sft_payload(
    *,
    config_path: Path,
    dataset_override: Path | None,
) -> SFTBuildResult:
    """Build the SFT job payload and locate its dataset files.

    The training dataset comes from ``dataset_override`` if given, otherwise
    from ``[job].data`` or ``[job].data_path``. Relative paths are resolved
    against the directory containing ``config_path``. A configured validation
    file that does not exist only produces a warning.

    Raises:
        TrainError: When no dataset is specified or the dataset path does not
            exist.
    """
    data = load_toml(config_path)
    base_dir = config_path.parent

    def _table(source: dict[str, Any], key: str) -> dict[str, Any]:
        # Return the named sub-table, or an empty dict if it is not a table.
        value = source.get(key)
        return value if isinstance(value, dict) else {}

    def _anchored(raw: Any) -> Path:
        # Resolve ``raw`` relative to the config file's directory.
        candidate = Path(raw)
        if not candidate.is_absolute():
            candidate = base_dir / candidate
        return candidate.resolve()

    job_cfg = _table(data, "job")
    data_cfg = _table(data, "data")
    hp_cfg = _table(data, "hyperparameters")
    train_cfg = _table(data, "training")
    compute_cfg = _table(data, "compute")

    # --- training dataset -------------------------------------------------
    raw_dataset = dataset_override or job_cfg.get("data") or job_cfg.get("data_path")
    if not raw_dataset:
        raise TrainError("Dataset not specified; pass --dataset or set [job].data")
    dataset_path = _anchored(raw_dataset)
    if not dataset_path.exists():
        raise TrainError(f"Dataset not found: {dataset_path}")

    # --- optional validation dataset --------------------------------------
    validation_file: Path | None = None
    configured_validation = data_cfg.get("validation_path")
    if isinstance(configured_validation, str) and configured_validation:
        vpath = _anchored(configured_validation)
        if vpath.exists():
            validation_file = vpath
        else:
            click.echo(f"[WARN] Validation dataset {vpath} missing; continuing without validation")

    # --- hyperparameters --------------------------------------------------
    passthrough_keys = (
        "batch_size",
        "global_batch",
        "per_device_batch",
        "gradient_accumulation_steps",
        "sequence_length",
        "learning_rate",
        "warmup_ratio",
        "train_kind",
    )
    hp_block: dict[str, Any] = {"n_epochs": int(hp_cfg.get("n_epochs", 1))}
    hp_block.update({name: hp_cfg[name] for name in passthrough_keys if name in hp_cfg})
    parallelism = hp_cfg.get("parallelism")
    if isinstance(parallelism, dict):
        hp_block["parallelism"] = parallelism

    # --- effective-config metadata ----------------------------------------
    compute_block = {name: compute_cfg[name] for name in ("gpu_type", "gpu_count", "nodes") if name in compute_cfg}
    training_block = {name: value for name, value in train_cfg.items() if name in ("mode", "use_qlora")}
    effective = {
        "compute": compute_block,
        "data": {"topology": _table(data_cfg, "topology")},
        "training": training_block,
    }

    validation_cfg = train_cfg.get("validation")
    if isinstance(validation_cfg, dict):
        hp_block["evaluation_strategy"] = validation_cfg.get("evaluation_strategy", "steps")
        hp_block["eval_steps"] = int(validation_cfg.get("eval_steps", 0) or 0)
        hp_block["save_best_model_at_end"] = bool(validation_cfg.get("save_best_model_at_end", True))
        hp_block["metric_for_best_model"] = validation_cfg.get("metric_for_best_model", "val.loss")
        hp_block["greater_is_better"] = bool(validation_cfg.get("greater_is_better", False))
        training_block["validation"] = {"enabled": bool(validation_cfg.get("enabled", True))}

    payload = {
        "model": job_cfg.get("model") or data.get("model"),
        "training_file_id": None,  # populated after upload
        "training_type": "sft_offline",
        "hyperparameters": hp_block,
        "metadata": {"effective_config": effective},
    }

    return SFTBuildResult(payload=payload, train_file=dataset_path, validation_file=validation_file)
158
+
159
+
160
+ __all__ = [
161
+ "RLBuildResult",
162
+ "SFTBuildResult",
163
+ "build_rl_payload",
164
+ "build_sft_payload",
165
+ ]
@@ -0,0 +1,450 @@
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ from pathlib import Path
5
+ from typing import Any, Dict
6
+
7
+ import click
8
+
9
+ from .builders import RLBuildResult, SFTBuildResult, build_rl_payload, build_sft_payload
10
+ from .config_finder import discover_configs, prompt_for_config
11
+ from .env_resolver import KeySpec, resolve_env
12
+ from .pollers import RLJobPoller, SFTJobPoller
13
+ from .task_app import check_task_app_health
14
+ from .utils import (
15
+ TrainError,
16
+ REPO_ROOT,
17
+ ensure_api_base,
18
+ http_post,
19
+ http_get,
20
+ limit_jsonl_examples,
21
+ mask_value,
22
+ post_multipart,
23
+ preview_json,
24
+ sleep,
25
+ validate_sft_jsonl,
26
+ )
27
+
28
+
29
def _discover_dataset_candidates(config_path: Path, limit: int = 50) -> list[Path]:
    """Collect non-empty ``*.jsonl`` files that could serve as SFT datasets.

    Searches, recursively and in this order: the config's directory, its
    ``datasets`` sub-directory, ``REPO_ROOT / "traces"`` and
    ``REPO_ROOT / "datasets"``. Each file is reported once, by resolved path.

    Args:
        config_path: Path of the selected training config.
        limit: Maximum number of candidates to return; values <= 0 yield an
            empty list.

    Returns:
        Resolved paths of up to ``limit`` non-empty regular files.
    """
    if limit <= 0:
        return []

    search_dirs: list[Path] = [
        config_path.parent,
        config_path.parent / "datasets",
        REPO_ROOT / "traces",
        REPO_ROOT / "datasets",
    ]

    candidates: list[Path] = []
    seen: set[Path] = set()
    for directory in search_dirs:
        if not directory.is_dir():
            continue
        for path in directory.rglob("*.jsonl"):
            try:
                resolved = path.resolve()
                if resolved in seen:
                    continue
                seen.add(resolved)
                # stat() raises for dangling symlinks or unreadable entries,
                # so it must live inside the guard; directories that happen
                # to be named *.jsonl are not datasets either.
                if not resolved.is_file() or resolved.stat().st_size == 0:
                    continue
            except OSError:
                continue
            candidates.append(resolved)
            if len(candidates) >= limit:
                return candidates
    return candidates
56
+
57
+
58
def prompt_for_dataset(config_path: Path) -> Path:
    """Interactively obtain the path of an existing ``.jsonl`` dataset.

    Offers the files found by ``_discover_dataset_candidates`` as a numbered
    menu (with manual entry and abort options); when nothing was discovered
    it goes straight to manual entry. Repeats until the chosen path exists
    and has a ``.jsonl`` suffix.

    Raises:
        click.ClickException: If the user picks the abort option.
    """
    options = _discover_dataset_candidates(config_path)

    def _pick_from_menu() -> Path | None:
        # Show the menu once; return None when the answer was not usable.
        click.echo("Select dataset JSONL file:")
        for number, option in enumerate(options, start=1):
            click.echo(f"  {number}) {option}")
        click.echo("  m) Enter path manually")
        click.echo("  0) Abort")
        answer = click.prompt("Choice", default="m").strip().lower()
        if answer == "0":
            raise click.ClickException("Aborted by user")
        if answer in {"m", "manual"}:
            return _prompt_manual_dataset()
        try:
            number = int(answer)
        except ValueError:
            number = 0  # falls into the out-of-range branch below
        if not 1 <= number <= len(options):
            click.echo("Invalid selection; try again")
            return None
        return options[number - 1]

    while True:
        selected = _pick_from_menu() if options else _prompt_manual_dataset()
        if selected is None:
            continue
        if selected.exists() and selected.suffix == ".jsonl":
            return selected.resolve()
        click.echo("File not found or not a .jsonl; please try again.")
88
+
89
+
90
def _prompt_manual_dataset() -> Path:
    """Prompt for a dataset path and return it with ``~`` expanded.

    The path is not checked for existence here.
    """
    typed: str = click.prompt("Enter dataset JSONL path", type=str)
    return Path(typed.strip()).expanduser()
93
+
94
+
95
@click.command("train")
@click.option("--config", "config_paths", multiple=True, type=click.Path(), help="Path to training TOML (repeatable)")
@click.option("--type", "train_type", type=click.Choice(["auto", "rl", "sft"]), default="auto")
@click.option("--env-file", "env_files", multiple=True, type=click.Path(), help=".env file(s) to preload (skips selection prompt)")
@click.option("--task-url", default=None, help="Override task app base URL (RL only)")
@click.option("--dataset", "dataset_path", type=click.Path(), default=None, help="Override dataset JSONL path (SFT)")
@click.option("--backend", default=lambda: os.getenv("BACKEND_BASE_URL", "http://localhost:8000/api"), help="Backend base URL")
@click.option("--model", default=None, help="Override model identifier")
@click.option("--idempotency", default=None, help="Idempotency-Key header for job creation")
@click.option("--dry-run", is_flag=True, help="Preview payload without submitting")
@click.option("--poll/--no-poll", default=True, help="Poll job status until terminal state")
@click.option("--poll-timeout", default=3600.0, type=float, help="Maximum seconds to poll before timing out")
@click.option("--poll-interval", default=5.0, type=float, help="Seconds between poll attempts")
@click.option("--examples", "examples_limit", type=int, default=None, help="Limit SFT training to the first N examples")
def train_command(
    config_paths: tuple[str, ...],
    train_type: str,
    env_files: tuple[str, ...],
    task_url: str | None,
    dataset_path: str | None,
    backend: str,
    model: str | None,
    idempotency: str | None,
    dry_run: bool,
    poll: bool,
    poll_timeout: float,
    poll_interval: float,
    examples_limit: int | None,
) -> None:
    """Interactive launcher for RL / SFT jobs."""
    # Flow: pick a config -> settle on rl/sft -> resolve required env values
    # (prompting for missing ones) -> dispatch to handle_rl / handle_sft.

    # "auto" means "no preference" for config discovery.
    requested_type = None if train_type == "auto" else train_type
    found = discover_configs(list(config_paths), requested_type=requested_type)
    selection = prompt_for_config(found, requested_type=requested_type)

    effective_type = selection.train_type if requested_type is None else requested_type
    if effective_type not in {"rl", "sft"}:
        effective_type = click.prompt("Detected config type is ambiguous. Enter type", type=click.Choice(["rl", "sft"]))
    is_rl = effective_type == "rl"

    cfg_path = selection.path
    click.echo(f"Using config: {cfg_path} ({effective_type})")

    # Both job types need the backend key; RL additionally needs task-app settings.
    required_keys: list[KeySpec] = [KeySpec("SYNTH_API_KEY", "Synth API key for backend")]
    if is_rl:
        required_keys += [
            KeySpec(
                "ENVIRONMENT_API_KEY",
                "Environment API key for task app",
                allow_modal_secret=True,
                modal_secret_pattern="env",
            ),
            KeySpec(
                "TASK_APP_URL",
                "Task app base URL",
                secret=False,
                allow_modal_app=True,
                # An explicit --task-url makes the env value unnecessary.
                optional=bool(task_url),
            ),
        ]

    env_path, env_values = resolve_env(
        config_path=cfg_path,
        explicit_env_paths=env_files,
        required_keys=required_keys,
    )

    def _resolved(name: str) -> str | None:
        # Value from the resolved env file, falling back to the process env.
        return env_values.get(name) or os.environ.get(name)

    missing_keys = [spec.name for spec in required_keys if not (spec.optional or _resolved(spec.name))]
    if missing_keys:
        try:
            from synth_ai.cli.task_apps import _interactive_fill_env
        except Exception as exc:  # pragma: no cover - protective fallback
            raise click.ClickException(f"Unable to prompt for env values: {exc}") from exc

        generated = _interactive_fill_env(cfg_path.parent / ".env")
        if generated is None:
            raise click.ClickException("Required environment values missing; aborting.")
        # Re-run resolution against the freshly written env file.
        env_path, env_values = resolve_env(
            config_path=cfg_path,
            explicit_env_paths=(str(generated),),
            required_keys=required_keys,
        )
    click.echo(f"Using env file: {env_path}")

    synth_key = _resolved("SYNTH_API_KEY")
    if not synth_key:
        raise click.ClickException("SYNTH_API_KEY required")

    backend_base = ensure_api_base(backend)
    click.echo(f"Backend base: {backend_base} (key {mask_value(synth_key)})")

    if is_rl:
        handle_rl(
            cfg_path=cfg_path,
            backend_base=backend_base,
            synth_key=synth_key,
            task_url_override=task_url,
            model_override=model,
            idempotency=idempotency,
            dry_run=dry_run,
            poll=poll,
            poll_timeout=poll_timeout,
            poll_interval=poll_interval,
        )
        return

    dataset_override_path = None
    if dataset_path:
        dataset_override_path = Path(dataset_path).expanduser().resolve()
    handle_sft(
        cfg_path=cfg_path,
        backend_base=backend_base,
        synth_key=synth_key,
        dataset_override=dataset_override_path,
        dry_run=dry_run,
        poll=poll,
        poll_timeout=poll_timeout,
        poll_interval=poll_interval,
        examples_limit=examples_limit,
    )
220
+
221
+
222
def _wait_for_training_file(backend_base: str, api_key: str, file_id: str, *, timeout: float = 120.0) -> None:
    """Block until the backend reports an uploaded training file as usable.

    Polls ``{backend_base}/learning/files/{file_id}`` every two seconds. A 200
    response whose ``status`` / ``state`` / ``storage_state`` field is one of
    ``ready``, ``uploaded``, ``stored`` or ``complete`` ends the wait; a 200
    response without any such field (or with an unparsable or non-object
    body) is also treated as ready. A 404 keeps polling; any other status is
    reported as a warning and polling continues.

    Args:
        backend_base: Backend API base URL.
        api_key: Bearer token sent in the ``Authorization`` header.
        file_id: Identifier of the uploaded file.
        timeout: Maximum number of seconds to wait.

    Raises:
        click.ClickException: If the file is not ready within ``timeout``.
    """
    import time  # local import keeps this block self-contained

    url = f"{backend_base}/learning/files/{file_id}"
    headers = {"Authorization": f"Bearer {api_key}"}
    ready_states = {"ready", "uploaded", "stored", "complete"}
    interval = 2.0
    slept = 0.0
    started = time.monotonic()
    while True:
        resp = http_get(url, headers=headers, timeout=30.0)
        if resp.status_code == 200:
            try:
                data = resp.json()
            except Exception:
                data = {}
            if not isinstance(data, dict):
                # A JSON array/scalar has no status field; treat like an empty body.
                data = {}
            status = str(data.get("status") or data.get("state") or data.get("storage_state") or "ready").lower()
            if status in ready_states:
                return
        elif resp.status_code == 404:
            # Keep polling; object may not be visible yet
            pass
        else:
            click.echo(f"[WARN] Unexpected response while checking training file {file_id}: {resp.status_code}")

        # Count wall-clock time as well as accumulated sleep: each request may
        # itself take up to 30s, which the sleep counter alone would ignore.
        if max(slept, time.monotonic() - started) >= timeout:
            raise click.ClickException(f"Training file {file_id} not ready after {timeout:.0f}s")
        sleep(interval)
        slept += interval
247
+
248
+
249
def handle_rl(
    *,
    cfg_path: Path,
    backend_base: str,
    synth_key: str,
    task_url_override: str | None,
    model_override: str | None,
    idempotency: str | None,
    dry_run: bool,
    poll: bool,
    poll_timeout: float,
    poll_interval: float,
) -> None:
    """Verify the task app, then create (and optionally poll) an RL job.

    Steps: build the payload, ask the backend to verify the task app, run a
    direct task-app health check using ``ENVIRONMENT_API_KEY`` from the
    process environment, preview the payload, and - unless ``dry_run`` is
    set - POST it to ``{backend_base}/rl/jobs``. Note that verification and
    the health check run even in dry-run mode.

    Raises:
        click.ClickException: On verification or health-check failure, a
            missing ``ENVIRONMENT_API_KEY``, a non-200/201 creation response,
            or a creation response without a job id.
    """

    def _json_dict(resp: Any) -> Dict[str, Any]:
        # Normalise a response body to a dict: an unparsable or non-object
        # body becomes a small status/text summary so ``.get`` is always safe.
        try:
            parsed = resp.json()
        except Exception:
            parsed = None
        if isinstance(parsed, dict):
            return parsed
        return {"status": resp.status_code, "text": (resp.text or "")[:400]}

    overrides: Dict[str, Any] = {"backend": backend_base, "task_url": task_url_override, "model": model_override}
    build = build_rl_payload(
        config_path=cfg_path,
        task_url=task_url_override or os.environ.get("TASK_APP_URL", ""),
        overrides=overrides,
        idempotency=idempotency,
    )

    # Backend-side verification: try ALL org environment keys against /health and /task_info
    verify_url = f"{backend_base}/rl/verify_task_app"
    verify_headers = {"Authorization": f"Bearer {synth_key}", "Content-Type": "application/json"}
    try:
        vresp = http_post(verify_url, headers=verify_headers, json_body={"endpoint_base_url": build.task_url})
        vjs = _json_dict(vresp)
    except Exception as _ve:
        raise click.ClickException(f"Task app verification call failed: {type(_ve).__name__}: {_ve}") from _ve
    if vresp.status_code >= 400:
        click.echo("Task app verification error:\n" + preview_json(vjs, limit=800))
        raise click.ClickException(f"Verification failed with status {vresp.status_code}")
    if not bool(vjs.get("any_ok")):
        click.echo("Task app verification failed; no auth combination succeeded. Full report:")
        click.echo(preview_json(vjs, limit=1200))
        raise click.ClickException("Task app verification failed (auth)")

    # Print concise summary; tolerate unexpected shapes instead of hiding errors.
    cands = vjs.get("candidates_first15") or []
    attempts = vjs.get("attempts")
    if not isinstance(attempts, list):
        attempts = []
    statuses = [a.get("status") if isinstance(a, dict) else None for a in attempts]
    click.echo(f"Verification OK (candidates={cands}, statuses={statuses})")

    env_key = os.environ.get("ENVIRONMENT_API_KEY")
    if not env_key:
        raise click.ClickException("ENVIRONMENT_API_KEY required for RL flow")

    click.echo("Performing task app health check…")
    health = check_task_app_health(build.task_url, env_key)
    if not health.ok:
        click.echo(f"Task app health check failed: {health.detail}")
        raise click.ClickException("Aborting due to failing health check")
    click.echo("Task app healthy")

    create_url = f"{backend_base}/rl/jobs"
    headers = {"Authorization": f"Bearer {synth_key}", "Content-Type": "application/json"}
    if build.idempotency:
        headers["Idempotency-Key"] = build.idempotency

    click.echo(f"POST {create_url}")
    click.echo("Payload preview:\n" + preview_json(build.payload, limit=800))
    if dry_run:
        click.echo("Dry run enabled; skipping submission")
        return

    resp = http_post(create_url, headers=headers, json_body=build.payload)
    js = _json_dict(resp)
    click.echo(f"Response {resp.status_code}: {preview_json(js, limit=400)}")
    if resp.status_code not in (200, 201):
        raise click.ClickException("Job creation failed")
    job_id = js.get("job_id") or js.get("id")
    if not job_id:
        raise click.ClickException("Response missing job id")

    if not poll:
        click.echo(f"Created job {job_id} (polling disabled)")
        return

    poller = RLJobPoller(backend_base, synth_key, interval=poll_interval, timeout=poll_timeout)
    outcome = poller.poll_job(job_id)
    click.echo(f"Final status: {outcome.status}")
    click.echo(preview_json(outcome.payload, limit=600))
341
+
342
+
343
def handle_sft(
    *,
    cfg_path: Path,
    backend_base: str,
    synth_key: str,
    dataset_override: Path | None,
    dry_run: bool,
    poll: bool,
    poll_timeout: float,
    poll_interval: float,
    examples_limit: int | None,
) -> None:
    """Validate and upload SFT datasets, then create, start and poll the job.

    If the payload cannot be built (``TrainError``), the user is prompted for
    a dataset and the build is retried. With ``examples_limit`` set, a
    truncated copy of the training file is used and removed afterwards. In
    dry-run mode nothing is uploaded, no file-readiness polling happens and
    no job is submitted; only the payload preview is printed.

    Raises:
        click.ClickException: If the training upload fails, the uploaded file
            does not become ready, job creation returns a non-200/201 status,
            or the creation response carries no job id.
    """

    def _json_dict(resp: Any) -> Dict[str, Any]:
        # Only trust bodies that declare JSON and decode to an object; anything
        # else is reported as "no data" so callers can fall back to resp.text.
        if not resp.headers.get("content-type", "").startswith("application/json"):
            return {}
        try:
            parsed = resp.json()
        except ValueError:
            return {}
        return parsed if isinstance(parsed, dict) else {}

    dataset_path = dataset_override

    while True:
        try:
            build = build_sft_payload(config_path=cfg_path, dataset_override=dataset_path)
            break
        except TrainError as exc:
            click.echo(str(exc))
            dataset_path = prompt_for_dataset(cfg_path)

    limited_path: Path | None = None

    try:
        if examples_limit is not None:
            limited_path = limit_jsonl_examples(build.train_file, examples_limit)
            click.echo(
                f"Using first {examples_limit} examples from {build.train_file} -> {limited_path}"
            )
            build.train_file = limited_path

        click.echo("Validating training dataset…")
        validate_sft_jsonl(build.train_file)
        if build.validation_file and build.validation_file.suffix == ".jsonl":
            click.echo("Validating validation dataset…")
            validate_sft_jsonl(build.validation_file)

        upload_url = f"{backend_base}/learning/files"
        click.echo(f"Uploading dataset {build.train_file}")
        val_file_id = None
        if dry_run:
            click.echo("Dry run: skipping upload")
            train_file_id = "dry-run-train"
        else:
            resp = post_multipart(upload_url, api_key=synth_key, file_field="file", file_path=build.train_file)
            js = _json_dict(resp)
            if resp.status_code >= 400 or "id" not in js:
                raise click.ClickException(
                    f"Training file upload failed ({resp.status_code}): {js or (resp.text or '')[:200]}"
                )
            train_file_id = js["id"]
            if build.validation_file:
                # Validation upload is best-effort: a failure only warns.
                click.echo(f"Uploading validation dataset {build.validation_file}")
                vresp = post_multipart(upload_url, api_key=synth_key, file_field="file", file_path=build.validation_file)
                vjs = _json_dict(vresp)
                if vresp.status_code < 400 and "id" in vjs:
                    val_file_id = vjs["id"]
                else:
                    click.echo(f"[WARN] Validation upload failed: {vresp.status_code} {vjs or (vresp.text or '')[:200]}")

        payload = dict(build.payload)
        payload["training_file_id"] = train_file_id
        if val_file_id:
            payload.setdefault("metadata", {}).setdefault("effective_config", {}).setdefault("data", {})["validation_files"] = [val_file_id]

        # The placeholder id used in dry-run mode never exists on the backend,
        # so waiting for it would poll until the timeout and then fail.
        if not dry_run:
            try:
                _wait_for_training_file(backend_base, synth_key, train_file_id)
            except click.ClickException as exc:
                raise click.ClickException(f"Training file {train_file_id} not ready: {exc}") from exc

        click.echo("FFT job payload:\n" + preview_json(payload, limit=800))
        if dry_run:
            click.echo("Dry run: skipping job submission")
            return

        create_url = f"{backend_base}/learning/jobs"
        headers = {"Authorization": f"Bearer {synth_key}", "Content-Type": "application/json"}
        resp = http_post(create_url, headers=headers, json_body=payload)
        js = _json_dict(resp)
        click.echo(f"Response {resp.status_code}: {preview_json(js, limit=400)}")
        if resp.status_code not in (200, 201):
            raise click.ClickException("Failed to create learning job")
        job_id = js.get("job_id") or js.get("id")
        if not job_id:
            raise click.ClickException("Response missing job id")

        start_url = f"{backend_base}/learning/jobs/{job_id}/start"
        click.echo(f"POST {start_url} (start)")
        start_resp = http_post(start_url, headers=headers, json_body={})
        if start_resp.status_code >= 400:
            # Surface the problem but keep going, as before: polling below
            # will show the job's actual state.
            click.echo(f"[WARN] Start request for job {job_id} returned {start_resp.status_code}")

        if not poll:
            click.echo(f"Started job {job_id} (polling disabled)")
            return

        poller = SFTJobPoller(backend_base, synth_key, interval=poll_interval, timeout=poll_timeout)
        outcome = poller.poll_job(job_id)
        click.echo(f"Final status: {outcome.status}")
        click.echo(preview_json(outcome.payload, limit=600))
    finally:
        if limited_path is not None:
            # Best-effort cleanup of the truncated copy. rmdir only succeeds on
            # an empty directory - presumably a temp dir created by
            # limit_jsonl_examples; TODO confirm against that helper.
            try:
                limited_path.unlink(missing_ok=True)
                limited_path.parent.rmdir()
            except OSError:
                pass
447
+
448
+
449
def register(cli: click.Group) -> None:
    """Attach the ``train`` command to the given click group."""
    cli.add_command(train_command)