synth-ai 0.2.8.dev12__py3-none-any.whl → 0.2.9.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- synth_ai/api/train/__init__.py +5 -0
- synth_ai/api/train/builders.py +165 -0
- synth_ai/api/train/cli.py +450 -0
- synth_ai/api/train/config_finder.py +168 -0
- synth_ai/api/train/env_resolver.py +302 -0
- synth_ai/api/train/pollers.py +66 -0
- synth_ai/api/train/task_app.py +193 -0
- synth_ai/api/train/utils.py +232 -0
- synth_ai/cli/__init__.py +23 -0
- synth_ai/cli/rl_demo.py +18 -6
- synth_ai/cli/root.py +38 -6
- synth_ai/cli/task_apps.py +1107 -0
- synth_ai/demo_registry.py +258 -0
- synth_ai/demos/core/cli.py +147 -111
- synth_ai/demos/demo_task_apps/__init__.py +7 -1
- synth_ai/demos/demo_task_apps/math/config.toml +55 -110
- synth_ai/demos/demo_task_apps/math/modal_task_app.py +157 -21
- synth_ai/demos/demo_task_apps/math/task_app_entry.py +39 -0
- synth_ai/task/__init__.py +94 -1
- synth_ai/task/apps/__init__.py +88 -0
- synth_ai/task/apps/grpo_crafter.py +438 -0
- synth_ai/task/apps/math_single_step.py +852 -0
- synth_ai/task/auth.py +153 -0
- synth_ai/task/client.py +165 -0
- synth_ai/task/contracts.py +29 -14
- synth_ai/task/datasets.py +105 -0
- synth_ai/task/errors.py +49 -0
- synth_ai/task/json.py +77 -0
- synth_ai/task/proxy.py +258 -0
- synth_ai/task/rubrics.py +212 -0
- synth_ai/task/server.py +398 -0
- synth_ai/task/tracing_utils.py +79 -0
- synth_ai/task/vendors.py +61 -0
- synth_ai/tracing_v3/session_tracer.py +13 -5
- synth_ai/tracing_v3/storage/base.py +10 -12
- synth_ai/tracing_v3/turso/manager.py +20 -6
- {synth_ai-0.2.8.dev12.dist-info → synth_ai-0.2.9.dev0.dist-info}/METADATA +3 -2
- {synth_ai-0.2.8.dev12.dist-info → synth_ai-0.2.9.dev0.dist-info}/RECORD +42 -18
- {synth_ai-0.2.8.dev12.dist-info → synth_ai-0.2.9.dev0.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.8.dev12.dist-info → synth_ai-0.2.9.dev0.dist-info}/entry_points.txt +0 -0
- {synth_ai-0.2.8.dev12.dist-info → synth_ai-0.2.9.dev0.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.2.8.dev12.dist-info → synth_ai-0.2.9.dev0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,165 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
import click
|
|
8
|
+
|
|
9
|
+
from .utils import ensure_api_base, load_toml, TrainError
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
@dataclass(slots=True)
class RLBuildResult:
    """Outcome of ``build_rl_payload``: the request body plus routing details."""

    # Request body for RL job creation (job_type, compute, data, tags, ...).
    payload: dict[str, Any]
    # Effective task app URL after overrides were applied; whitespace-stripped.
    task_url: str
    # Idempotency key passed through unchanged from the caller; may be None.
    idempotency: str | None
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
@dataclass(slots=True)
class SFTBuildResult:
    """Outcome of ``build_sft_payload``: the request body plus dataset paths."""

    # Request body for SFT job creation; ``training_file_id`` is None when built.
    payload: dict[str, Any]
    # Resolved path of the training dataset (checked to exist when built).
    train_file: Path
    # Resolved validation dataset path, or None if not configured or missing on disk.
    validation_file: Path | None
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def build_rl_payload(
    *,
    config_path: Path,
    task_url: str,
    overrides: dict[str, Any],
    idempotency: str | None,
) -> RLBuildResult:
    """Assemble the request body for creating an RL training job.

    The TOML file at ``config_path`` is loaded and embedded in the payload
    under ``data.config``. The task app URL is the first non-blank string among
    ``overrides["task_url"]``, ``task_url`` and ``[services].task_url``. The
    model is ``overrides["model"]`` when given; otherwise exactly one of
    ``[model].source`` / ``[model].base`` must be set in the TOML.

    Args:
        config_path: Path of the RL training TOML.
        task_url: Task app URL supplied by the caller (may be empty).
        overrides: Optional ``task_url``, ``model`` and ``backend`` overrides.
        idempotency: Value passed through unchanged into the result.

    Returns:
        An ``RLBuildResult`` holding the payload, the effective task URL and
        the idempotency key.

    Raises:
        click.ClickException: If no task URL can be determined, or if the model
            section does not resolve to exactly one of source / base.
    """
    data = load_toml(config_path)
    services = data.get("services")
    if not isinstance(services, dict):
        services = {}
    model_cfg = data.get("model")
    if not isinstance(model_cfg, dict):
        model_cfg = {}

    # Strip every candidate *before* testing it, so a whitespace-only override
    # falls through to the next source instead of shadowing a usable URL.
    # Non-string values are skipped rather than crashing on ``.strip()``.
    final_task_url = ""
    for candidate in (overrides.get("task_url"), task_url, services.get("task_url")):
        if isinstance(candidate, str) and candidate.strip():
            final_task_url = candidate.strip()
            break
    if not final_task_url:
        raise click.ClickException("Task app URL required (provide --task-url or set services.task_url in TOML)")

    model_source = (model_cfg.get("source") or "").strip()
    model_base = (model_cfg.get("base") or "").strip()
    override_model = (overrides.get("model") or "").strip()
    if override_model:
        # An explicit override always wins and is treated as a source model.
        model_source = override_model
        model_base = ""
    if bool(model_source) == bool(model_base):
        raise click.ClickException("Model section must specify exactly one of [model].source or [model].base")

    # Force TOML services.task_url to the effective endpoint to avoid split URLs.
    # ``services`` is either the original table (updated in place) or a fresh
    # dict that replaces a missing / non-table value.
    services["task_url"] = final_task_url
    data["services"] = services

    payload: dict[str, Any] = {
        "job_type": "rl",
        "compute": data.get("compute", {}),
        "data": {
            "endpoint_base_url": final_task_url.rstrip("/"),
            "config": data,
        },
        "tags": {"source": "train-cli"},
    }
    if model_source:
        payload["data"]["model"] = model_source
    if model_base:
        payload["data"]["base_model"] = model_base

    backend = overrides.get("backend")
    if backend:
        payload.setdefault("metadata", {})["backend_base_url"] = ensure_api_base(str(backend))

    return RLBuildResult(payload=payload, task_url=final_task_url, idempotency=idempotency)
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def build_sft_payload(
    *,
    config_path: Path,
    dataset_override: Path | None,
) -> SFTBuildResult:
    """Assemble the request body for an SFT job from a TOML config.

    Args:
        config_path: Path of the SFT training TOML. Relative dataset paths in
            the config are resolved against its parent directory.
        dataset_override: Training dataset to use instead of ``[job].data`` /
            ``[job].data_path``; ``None`` defers to the config.

    Returns:
        An ``SFTBuildResult`` with the payload (``training_file_id`` left as
        ``None``), the resolved training file and the validation file if one
        is configured and present on disk.

    Raises:
        TrainError: If no dataset is specified or the dataset does not exist.
    """
    data = load_toml(config_path)

    def _table(name: str) -> dict[str, Any]:
        # A TOML section counts only when it is a table; anything else is empty.
        section = data.get(name)
        return section if isinstance(section, dict) else {}

    job_cfg = _table("job")
    data_cfg = _table("data")
    hp_cfg = _table("hyperparameters")
    train_cfg = _table("training")
    compute_cfg = _table("compute")
    base_dir = config_path.parent

    # --- training dataset -------------------------------------------------
    raw_dataset = dataset_override or job_cfg.get("data") or job_cfg.get("data_path")
    if not raw_dataset:
        raise TrainError("Dataset not specified; pass --dataset or set [job].data")
    dataset_path = Path(raw_dataset)
    if not dataset_path.is_absolute():
        dataset_path = base_dir / dataset_path
    dataset_path = dataset_path.resolve()
    if not dataset_path.exists():
        raise TrainError(f"Dataset not found: {dataset_path}")

    # --- optional validation dataset --------------------------------------
    validation_file = None
    configured_validation = data_cfg.get("validation_path")
    if isinstance(configured_validation, str) and configured_validation:
        vpath = Path(configured_validation)
        if not vpath.is_absolute():
            vpath = base_dir / vpath
        vpath = vpath.resolve()
        if vpath.exists():
            validation_file = vpath
        else:
            # A missing validation set is not fatal; training proceeds without it.
            click.echo(f"[WARN] Validation dataset {vpath} missing; continuing without validation")

    # --- hyperparameters --------------------------------------------------
    passthrough_keys = (
        "batch_size",
        "global_batch",
        "per_device_batch",
        "gradient_accumulation_steps",
        "sequence_length",
        "learning_rate",
        "warmup_ratio",
        "train_kind",
    )
    hp_block: dict[str, Any] = {"n_epochs": int(hp_cfg.get("n_epochs", 1))}
    hp_block.update({name: hp_cfg[name] for name in passthrough_keys if name in hp_cfg})
    parallelism = hp_cfg.get("parallelism")
    if isinstance(parallelism, dict):
        hp_block["parallelism"] = parallelism

    # --- effective config echoed back in metadata --------------------------
    compute_block = {}
    for name in ("gpu_type", "gpu_count", "nodes"):
        if name in compute_cfg:
            compute_block[name] = compute_cfg[name]

    topology = data_cfg.get("topology")
    if not isinstance(topology, dict):
        topology = {}

    effective = {
        "compute": compute_block,
        "data": {"topology": topology},
        "training": {k: v for k, v in train_cfg.items() if k in ("mode", "use_qlora")},
    }

    validation_cfg = train_cfg.get("validation")
    if isinstance(validation_cfg, dict):
        validation_hparams = {
            "evaluation_strategy": validation_cfg.get("evaluation_strategy", "steps"),
            "eval_steps": int(validation_cfg.get("eval_steps", 0) or 0),
            "save_best_model_at_end": bool(validation_cfg.get("save_best_model_at_end", True)),
            "metric_for_best_model": validation_cfg.get("metric_for_best_model", "val.loss"),
            "greater_is_better": bool(validation_cfg.get("greater_is_better", False)),
        }
        hp_block.update(validation_hparams)
        effective["training"]["validation"] = {"enabled": bool(validation_cfg.get("enabled", True))}

    payload = {
        "model": job_cfg.get("model") or data.get("model"),
        "training_file_id": None,  # populated after upload
        "training_type": "sft_offline",
        "hyperparameters": hp_block,
        "metadata": {"effective_config": effective},
    }

    return SFTBuildResult(payload=payload, train_file=dataset_path, validation_file=validation_file)
|
|
158
|
+
|
|
159
|
+
|
|
160
|
+
# Names exported by ``from ... import *``: the two result dataclasses and
# the two payload builders defined in this module.
__all__ = [
    "RLBuildResult",
    "SFTBuildResult",
    "build_rl_payload",
    "build_sft_payload",
]
|
|
@@ -0,0 +1,450 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
from typing import Any, Dict
|
|
6
|
+
|
|
7
|
+
import click
|
|
8
|
+
|
|
9
|
+
from .builders import RLBuildResult, SFTBuildResult, build_rl_payload, build_sft_payload
|
|
10
|
+
from .config_finder import discover_configs, prompt_for_config
|
|
11
|
+
from .env_resolver import KeySpec, resolve_env
|
|
12
|
+
from .pollers import RLJobPoller, SFTJobPoller
|
|
13
|
+
from .task_app import check_task_app_health
|
|
14
|
+
from .utils import (
|
|
15
|
+
TrainError,
|
|
16
|
+
REPO_ROOT,
|
|
17
|
+
ensure_api_base,
|
|
18
|
+
http_post,
|
|
19
|
+
http_get,
|
|
20
|
+
limit_jsonl_examples,
|
|
21
|
+
mask_value,
|
|
22
|
+
post_multipart,
|
|
23
|
+
preview_json,
|
|
24
|
+
sleep,
|
|
25
|
+
validate_sft_jsonl,
|
|
26
|
+
)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def _discover_dataset_candidates(config_path: Path, limit: int = 50) -> list[Path]:
    """Collect non-empty ``*.jsonl`` files that could serve as SFT datasets.

    Searches, recursively and in this order: the config's directory, its
    ``datasets`` subdirectory, ``REPO_ROOT / "traces"`` and
    ``REPO_ROOT / "datasets"``. Results are resolved paths, de-duplicated,
    in discovery order.

    Args:
        config_path: Path of the training config; its parent seeds the search.
        limit: Maximum number of candidates to return.

    Returns:
        Up to ``limit`` resolved paths of regular, non-empty ``.jsonl`` files.
    """
    if limit <= 0:
        # Previously one candidate slipped through before the limit check.
        return []

    search_dirs: list[Path] = [
        config_path.parent,
        config_path.parent / "datasets",
        REPO_ROOT / "traces",
        REPO_ROOT / "datasets",
    ]

    candidates: list[Path] = []
    seen: set[Path] = set()
    for directory in search_dirs:
        if not directory.is_dir():
            continue
        for path in directory.rglob("*.jsonl"):
            try:
                resolved = path.resolve()
                if resolved in seen:
                    continue
                seen.add(resolved)
                # stat() belongs inside the try: dangling symlinks and
                # unreadable files raise OSError and must not abort discovery.
                # Directories that merely end in ".jsonl" are skipped too.
                if not resolved.is_file() or resolved.stat().st_size == 0:
                    continue
            except OSError:
                continue
            candidates.append(resolved)
            if len(candidates) >= limit:
                return candidates
    return candidates
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def prompt_for_dataset(config_path: Path) -> Path:
    """Interactively ask the user for a training dataset until a valid one is given.

    Discovered candidates (if any) are shown as a numbered menu with a manual
    entry option and an abort option; with no candidates the user is asked for
    a path directly. The loop repeats until the chosen path exists and has a
    ``.jsonl`` suffix.

    Returns:
        The resolved path of the selected file.

    Raises:
        click.ClickException: If the user chooses ``0`` (abort).
    """
    options = _discover_dataset_candidates(config_path)
    while True:
        if not options:
            picked = _prompt_manual_dataset()
        else:
            click.echo("Select dataset JSONL file:")
            for number, option in enumerate(options, start=1):
                click.echo(f" {number}) {option}")
            click.echo(" m) Enter path manually")
            click.echo(" 0) Abort")
            answer = click.prompt("Choice", default="m").strip().lower()
            if answer == "0":
                raise click.ClickException("Aborted by user")
            if answer in ("m", "manual"):
                picked = _prompt_manual_dataset()
            else:
                try:
                    position = int(answer)
                except ValueError:
                    # Non-numeric input is handled like an out-of-range number.
                    position = 0
                if not 1 <= position <= len(options):
                    click.echo("Invalid selection; try again")
                    continue
                picked = options[position - 1]

        if picked.suffix == ".jsonl" and picked.exists():
            return picked.resolve()
        click.echo("File not found or not a .jsonl; please try again.")
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
def _prompt_manual_dataset() -> Path:
    """Ask for a dataset path on the terminal and return it with ``~`` expanded.

    The path is not checked for existence here; callers validate it.
    """
    entered = click.prompt("Enter dataset JSONL path", type=str)
    trimmed = entered.strip()
    return Path(trimmed).expanduser()
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
# NOTE: the function docstring below is left untouched on purpose; click uses
# a command's docstring as its --help text, so it is user-facing output.
@click.command("train")
@click.option("--config", "config_paths", multiple=True, type=click.Path(), help="Path to training TOML (repeatable)")
@click.option("--type", "train_type", type=click.Choice(["auto", "rl", "sft"]), default="auto")
@click.option("--env-file", "env_files", multiple=True, type=click.Path(), help=".env file(s) to preload (skips selection prompt)")
@click.option("--task-url", default=None, help="Override task app base URL (RL only)")
@click.option("--dataset", "dataset_path", type=click.Path(), default=None, help="Override dataset JSONL path (SFT)")
# The default is a callable so BACKEND_BASE_URL is read when the command runs,
# not when this module is imported.
@click.option("--backend", default=lambda: os.getenv("BACKEND_BASE_URL", "http://localhost:8000/api"), help="Backend base URL")
@click.option("--model", default=None, help="Override model identifier")
@click.option("--idempotency", default=None, help="Idempotency-Key header for job creation")
@click.option("--dry-run", is_flag=True, help="Preview payload without submitting")
@click.option("--poll/--no-poll", default=True, help="Poll job status until terminal state")
@click.option("--poll-timeout", default=3600.0, type=float, help="Maximum seconds to poll before timing out")
@click.option("--poll-interval", default=5.0, type=float, help="Seconds between poll attempts")
@click.option("--examples", "examples_limit", type=int, default=None, help="Limit SFT training to the first N examples")
def train_command(
    config_paths: tuple[str, ...],
    train_type: str,
    env_files: tuple[str, ...],
    task_url: str | None,
    dataset_path: str | None,
    backend: str,
    model: str | None,
    idempotency: str | None,
    dry_run: bool,
    poll: bool,
    poll_timeout: float,
    poll_interval: float,
    examples_limit: int | None,
) -> None:
    """Interactive launcher for RL / SFT jobs."""

    # Step 1: pick a config. "auto" means no type filter is passed to discovery.
    candidates = discover_configs(list(config_paths), requested_type=train_type if train_type != "auto" else None)
    selection = prompt_for_config(candidates, requested_type=train_type if train_type != "auto" else None)

    # Step 2: settle on the job type. An explicit --type wins; otherwise use the
    # type attached to the selected config, prompting if it is neither rl nor sft.
    effective_type = train_type if train_type != "auto" else selection.train_type
    if effective_type not in {"rl", "sft"}:
        effective_type = click.prompt("Detected config type is ambiguous. Enter type", type=click.Choice(["rl", "sft"]))

    cfg_path = selection.path
    click.echo(f"Using config: {cfg_path} ({effective_type})")

    # Step 3: declare which environment values this job type needs.
    required_keys: list[KeySpec] = []
    if effective_type == "rl":
        required_keys.append(KeySpec("SYNTH_API_KEY", "Synth API key for backend"))
        required_keys.append(
            KeySpec(
                "ENVIRONMENT_API_KEY",
                "Environment API key for task app",
                allow_modal_secret=True,
                modal_secret_pattern="env",
            )
        )
        required_keys.append(
            KeySpec(
                "TASK_APP_URL",
                "Task app base URL",
                secret=False,
                allow_modal_app=True,
                # Passing --task-url makes the env-provided URL optional.
                optional=bool(task_url),
            )
        )
    else:  # sft
        required_keys.append(KeySpec("SYNTH_API_KEY", "Synth API key for backend"))

    # Step 4: resolve env values from the given / discovered .env files.
    env_path, env_values = resolve_env(
        config_path=cfg_path,
        explicit_env_paths=env_files,
        required_keys=required_keys,
    )

    # A key counts as present if it is in the resolved values or in os.environ.
    missing_keys = [
        spec.name
        for spec in required_keys
        if not spec.optional and not (env_values.get(spec.name) or os.environ.get(spec.name))
    ]
    if missing_keys:
        # Imported only when needed; an import failure becomes a ClickException.
        try:
            from synth_ai.cli.task_apps import _interactive_fill_env
        except Exception as exc:  # pragma: no cover - protective fallback
            raise click.ClickException(f"Unable to prompt for env values: {exc}") from exc

        # Prompt the user to fill a .env next to the config, then resolve again
        # using that file as the only explicit env path.
        target_dir = cfg_path.parent
        generated = _interactive_fill_env(target_dir / ".env")
        if generated is None:
            raise click.ClickException("Required environment values missing; aborting.")
        env_path, env_values = resolve_env(
            config_path=cfg_path,
            explicit_env_paths=(str(generated),),
            required_keys=required_keys,
        )
    click.echo(f"Using env file: {env_path}")

    synth_key = env_values.get("SYNTH_API_KEY") or os.environ.get("SYNTH_API_KEY")
    if not synth_key:
        raise click.ClickException("SYNTH_API_KEY required")

    backend_base = ensure_api_base(backend)
    # The key is passed through mask_value before being echoed.
    click.echo(f"Backend base: {backend_base} (key {mask_value(synth_key)})")

    # Step 5: dispatch to the type-specific handler.
    if effective_type == "rl":
        handle_rl(
            cfg_path=cfg_path,
            backend_base=backend_base,
            synth_key=synth_key,
            task_url_override=task_url,
            model_override=model,
            idempotency=idempotency,
            dry_run=dry_run,
            poll=poll,
            poll_timeout=poll_timeout,
            poll_interval=poll_interval,
        )
    else:
        # Expand "~" and make the --dataset path absolute before handing it on.
        dataset_override_path = Path(dataset_path).expanduser().resolve() if dataset_path else None
        handle_sft(
            cfg_path=cfg_path,
            backend_base=backend_base,
            synth_key=synth_key,
            dataset_override=dataset_override_path,
            dry_run=dry_run,
            poll=poll,
            poll_timeout=poll_timeout,
            poll_interval=poll_interval,
            examples_limit=examples_limit,
        )
|
|
220
|
+
|
|
221
|
+
|
|
222
|
+
def _wait_for_training_file(backend_base: str, api_key: str, file_id: str, *, timeout: float = 120.0) -> None:
    """Poll ``{backend_base}/learning/files/{file_id}`` until the file is usable.

    A 200 response counts as ready when its ``status`` / ``state`` /
    ``storage_state`` field (first non-empty one) is one of ``ready``,
    ``uploaded``, ``stored`` or ``complete``. A 200 with no such field, or
    with a body that is not a JSON object, is also treated as ready. A 404
    keeps polling; any other status is reported as a warning and polling
    continues.

    Args:
        backend_base: Backend API base URL.
        api_key: Bearer token for the ``Authorization`` header.
        file_id: Identifier of the uploaded file.
        timeout: Budget in seconds. Only the sleep intervals are counted, so
            time spent inside the HTTP requests is not included.

    Raises:
        click.ClickException: If the file is not ready once the budget is spent.
    """
    url = f"{backend_base}/learning/files/{file_id}"
    headers = {"Authorization": f"Bearer {api_key}"}
    ready_states = {"ready", "uploaded", "stored", "complete"}
    elapsed = 0.0
    interval = 2.0
    while True:
        resp = http_get(url, headers=headers, timeout=30.0)
        if resp.status_code == 200:
            try:
                data = resp.json()
            except Exception:
                # Unparseable body on a 200: fall through to the "ready" default.
                data = {}
            if not isinstance(data, dict):
                # Valid JSON that is not an object (list, string, null) used to
                # crash on ``.get``; treat it like an empty object instead.
                data = {}
            status = str(data.get("status") or data.get("state") or data.get("storage_state") or "ready").lower()
            if status in ready_states:
                return
        elif resp.status_code != 404:
            # 404 is expected while the object is not visible yet; stay quiet.
            click.echo(f"[WARN] Unexpected response while checking training file {file_id}: {resp.status_code}")

        if elapsed >= timeout:
            raise click.ClickException(f"Training file {file_id} not ready after {timeout:.0f}s")
        sleep(interval)
        elapsed += interval
|
|
247
|
+
|
|
248
|
+
|
|
249
|
+
def handle_rl(
    *,
    cfg_path: Path,
    backend_base: str,
    synth_key: str,
    task_url_override: str | None,
    model_override: str | None,
    idempotency: str | None,
    dry_run: bool,
    poll: bool,
    poll_timeout: float,
    poll_interval: float,
) -> None:
    """Verify the task app, then create (and optionally poll) an RL job.

    Sequence: build the payload, ask the backend to verify the task app,
    run a direct health check using ``ENVIRONMENT_API_KEY`` from the process
    environment, preview the payload, and - unless ``dry_run`` - submit it.
    With ``poll`` enabled the job is followed until the poller returns.

    Raises:
        click.ClickException: On verification, health-check or job-creation
            failure, or when ``ENVIRONMENT_API_KEY`` is not set.
    """
    build = build_rl_payload(
        config_path=cfg_path,
        task_url=task_url_override or os.environ.get("TASK_APP_URL", ""),
        overrides={"backend": backend_base, "task_url": task_url_override, "model": model_override},
        idempotency=idempotency,
    )
    bearer = f"Bearer {synth_key}"

    # Backend-side verification: try ALL org environment keys against /health and /task_info
    try:
        vresp = http_post(
            f"{backend_base}/rl/verify_task_app",
            headers={"Authorization": bearer, "Content-Type": "application/json"},
            json_body={"endpoint_base_url": build.task_url},
        )
        try:
            vjs = vresp.json()
        except Exception:
            # Non-JSON reply: keep a short excerpt for the error report.
            vjs = {"status": vresp.status_code, "text": (vresp.text or "")[:400]}
    except Exception as _ve:
        raise click.ClickException(f"Task app verification call failed: {type(_ve).__name__}: {_ve}") from _ve

    if vresp.status_code >= 400:
        click.echo("Task app verification error:\n" + preview_json(vjs, limit=800))
        raise click.ClickException(f"Verification failed with status {vresp.status_code}")
    if not vjs.get("any_ok"):
        click.echo("Task app verification failed; no auth combination succeeded. Full report:")
        click.echo(preview_json(vjs, limit=1200))
        raise click.ClickException("Task app verification failed (auth)")

    # Print concise summary; purely informational, so failures here are ignored.
    try:
        cands = vjs.get("candidates_first15") or []
        statuses = [attempt.get("status") for attempt in (vjs.get("attempts") or [])]
        click.echo(f"Verification OK (candidates={cands}, statuses={statuses})")
    except Exception:
        pass

    env_key = os.environ.get("ENVIRONMENT_API_KEY")
    if not env_key:
        raise click.ClickException("ENVIRONMENT_API_KEY required for RL flow")

    click.echo("Performing task app health check…")
    health = check_task_app_health(build.task_url, env_key)
    if not health.ok:
        click.echo(f"Task app health check failed: {health.detail}")
        raise click.ClickException("Aborting due to failing health check")
    click.echo("Task app healthy")

    create_url = f"{backend_base}/rl/jobs"
    headers = {"Authorization": bearer, "Content-Type": "application/json"}
    if build.idempotency:
        headers["Idempotency-Key"] = build.idempotency

    click.echo(f"POST {create_url}")
    click.echo("Payload preview:\n" + preview_json(build.payload, limit=800))
    if dry_run:
        click.echo("Dry run enabled; skipping submission")
        return

    resp = http_post(create_url, headers=headers, json_body=build.payload)
    try:
        js = resp.json()
    except Exception:
        js = {"status": resp.status_code, "text": resp.text[:400]}
    click.echo(f"Response {resp.status_code}: {preview_json(js, limit=400)}")
    if resp.status_code not in (200, 201):
        raise click.ClickException("Job creation failed")

    # The backend may report the identifier under either key.
    job_id = js.get("job_id") or js.get("id")
    if not job_id:
        raise click.ClickException("Response missing job id")

    if poll:
        poller = RLJobPoller(backend_base, synth_key, interval=poll_interval, timeout=poll_timeout)
        outcome = poller.poll_job(job_id)
        click.echo(f"Final status: {outcome.status}")
        click.echo(preview_json(outcome.payload, limit=600))
    else:
        click.echo(f"Created job {job_id} (polling disabled)")
|
|
341
|
+
|
|
342
|
+
|
|
343
|
+
def handle_sft(
    *,
    cfg_path: Path,
    backend_base: str,
    synth_key: str,
    dataset_override: Path | None,
    dry_run: bool,
    poll: bool,
    poll_timeout: float,
    poll_interval: float,
    examples_limit: int | None,
) -> None:
    """Validate and upload the dataset(s), then create and start an SFT job.

    Sequence: build the payload (prompting for a dataset whenever the build
    raises ``TrainError``), optionally truncate the training file to the first
    ``examples_limit`` examples, validate the JSONL files, upload them, wait
    for the training file to become ready, create the job, start it and -
    with ``poll`` enabled - follow it until the poller returns.

    With ``dry_run`` nothing is uploaded or submitted and the backend is not
    contacted: a placeholder file id is used and the final payload is printed.

    Args:
        cfg_path: Path of the SFT training TOML.
        backend_base: Backend API base URL.
        synth_key: Bearer token for backend requests.
        dataset_override: Dataset to use instead of the one in the config.
        dry_run: Preview only; skip upload, readiness wait and submission.
        poll: Follow the job after starting it.
        poll_timeout: Timeout in seconds handed to the poller.
        poll_interval: Interval in seconds handed to the poller.
        examples_limit: If set, train on only the first N examples.

    Raises:
        click.ClickException: If the upload, readiness wait or job creation
            fails, or the user aborts the dataset prompt.
    """
    dataset_path = dataset_override

    # Keep asking for a dataset until the payload builds.
    # NOTE(review): every TrainError leads to the dataset prompt, including
    # ones that might not be dataset-related - confirm that is intended.
    while True:
        try:
            build = build_sft_payload(config_path=cfg_path, dataset_override=dataset_path)
            break
        except TrainError as exc:
            click.echo(str(exc))
            dataset_path = prompt_for_dataset(cfg_path)

    limited_path: Path | None = None

    try:
        if examples_limit is not None:
            limited_path = limit_jsonl_examples(build.train_file, examples_limit)
            click.echo(
                f"Using first {examples_limit} examples from {build.train_file} -> {limited_path}"
            )
            build.train_file = limited_path

        click.echo("Validating training dataset…")
        validate_sft_jsonl(build.train_file)
        if build.validation_file and build.validation_file.suffix == ".jsonl":
            click.echo("Validating validation dataset…")
            validate_sft_jsonl(build.validation_file)

        upload_url = f"{backend_base}/learning/files"
        click.echo(f"Uploading dataset {build.train_file}")
        if dry_run:
            click.echo("Dry run: skipping upload")
            train_file_id = "dry-run-train"
            val_file_id = None
        else:
            resp = post_multipart(upload_url, api_key=synth_key, file_field="file", file_path=build.train_file)
            js = resp.json() if resp.headers.get("content-type", "").startswith("application/json") else {}
            if resp.status_code >= 400 or "id" not in js:
                raise click.ClickException(f"Training file upload failed ({resp.status_code}): {js or resp.text[:200]}")
            train_file_id = js["id"]
            val_file_id = None
            if build.validation_file:
                click.echo(f"Uploading validation dataset {build.validation_file}")
                vresp = post_multipart(upload_url, api_key=synth_key, file_field="file", file_path=build.validation_file)
                vjs = vresp.json() if vresp.headers.get("content-type", "").startswith("application/json") else {}
                if vresp.status_code < 400 and "id" in vjs:
                    val_file_id = vjs["id"]
                else:
                    # A failed validation upload only disables validation.
                    click.echo(f"[WARN] Validation upload failed: {vresp.status_code} {vjs or vresp.text[:200]}")

        payload = dict(build.payload)
        payload["training_file_id"] = train_file_id
        if val_file_id:
            payload.setdefault("metadata", {}).setdefault("effective_config", {}).setdefault("data", {})["validation_files"] = [val_file_id]

        if not dry_run:
            # In a dry run the id is a placeholder that was never uploaded;
            # waiting on it would poll the backend until the timeout and fail.
            try:
                _wait_for_training_file(backend_base, synth_key, train_file_id)
            except click.ClickException as exc:
                raise click.ClickException(f"Training file {train_file_id} not ready: {exc}") from exc

        click.echo("FFT job payload:\n" + preview_json(payload, limit=800))
        if dry_run:
            click.echo("Dry run: skipping job submission")
            return

        create_url = f"{backend_base}/learning/jobs"
        headers = {"Authorization": f"Bearer {synth_key}", "Content-Type": "application/json"}
        resp = http_post(create_url, headers=headers, json_body=payload)
        js = resp.json() if resp.headers.get("content-type", "").startswith("application/json") else {}
        click.echo(f"Response {resp.status_code}: {preview_json(js, limit=400)}")
        if resp.status_code not in (200, 201):
            raise click.ClickException("Failed to create learning job")
        job_id = js.get("job_id") or js.get("id")
        if not job_id:
            raise click.ClickException("Response missing job id")

        start_url = f"{backend_base}/learning/jobs/{job_id}/start"
        click.echo(f"POST {start_url} (start)")
        start_resp = http_post(start_url, headers=headers, json_body={})
        if start_resp.status_code >= 400:
            # Surface a failed start instead of discarding the response; it is
            # a warning only, so the existing flow (including polling) goes on.
            click.echo(f"[WARN] Start request for job {job_id} returned {start_resp.status_code}")

        if not poll:
            click.echo(f"Started job {job_id} (polling disabled)")
            return

        poller = SFTJobPoller(backend_base, synth_key, interval=poll_interval, timeout=poll_timeout)
        outcome = poller.poll_job(job_id)
        click.echo(f"Final status: {outcome.status}")
        click.echo(preview_json(outcome.payload, limit=600))
    finally:
        if limited_path is not None:
            # Best-effort removal of the truncated copy and its directory;
            # rmdir fails harmlessly (OSError) if the directory is not empty.
            try:
                limited_path.unlink(missing_ok=True)
                limited_path.parent.rmdir()
            except OSError:
                pass
|
|
447
|
+
|
|
448
|
+
|
|
449
|
+
def register(cli: click.Group) -> None:
    """Attach the ``train`` command defined in this module to ``cli``."""
    command = train_command
    cli.add_command(command)
|