mlxsmith 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mlxsmith/__init__.py +2 -0
- mlxsmith/accel/__init__.py +10 -0
- mlxsmith/accel/base.py +17 -0
- mlxsmith/accel/none.py +13 -0
- mlxsmith/accel/zmlx_backend.py +42 -0
- mlxsmith/adapters.py +46 -0
- mlxsmith/api/__init__.py +48 -0
- mlxsmith/api/handlers.py +1217 -0
- mlxsmith/api/schemas.py +436 -0
- mlxsmith/auth.py +88 -0
- mlxsmith/bench.py +102 -0
- mlxsmith/cli.py +950 -0
- mlxsmith/config.py +543 -0
- mlxsmith/config_models.py +261 -0
- mlxsmith/data.py +493 -0
- mlxsmith/envs/__init__.py +33 -0
- mlxsmith/envs/system.py +388 -0
- mlxsmith/envs/token_env.py +191 -0
- mlxsmith/eval.py +112 -0
- mlxsmith/infer.py +140 -0
- mlxsmith/llm/__init__.py +16 -0
- mlxsmith/llm/backend.py +126 -0
- mlxsmith/llm/interface.py +212 -0
- mlxsmith/llm/mlx_lm_backend.py +509 -0
- mlxsmith/llm/mock_backend.py +228 -0
- mlxsmith/llm/registry.py +12 -0
- mlxsmith/models.py +257 -0
- mlxsmith/orchestrator/__init__.py +25 -0
- mlxsmith/orchestrator/daemon.py +454 -0
- mlxsmith/orchestrator/inference_worker.py +496 -0
- mlxsmith/orchestrator/queue.py +355 -0
- mlxsmith/orchestrator/trainer_worker.py +437 -0
- mlxsmith/rlm/__init__.py +8 -0
- mlxsmith/rlm/corpus.py +74 -0
- mlxsmith/rlm/gating.py +90 -0
- mlxsmith/rlm/generate.py +249 -0
- mlxsmith/rlm/history.py +12 -0
- mlxsmith/rlm/inference.py +150 -0
- mlxsmith/rlm/loop.py +1297 -0
- mlxsmith/rlm/mutate.py +82 -0
- mlxsmith/rlm/trainer.py +73 -0
- mlxsmith/rlm/weights.py +263 -0
- mlxsmith/runs.py +44 -0
- mlxsmith/sdk/__init__.py +392 -0
- mlxsmith/sdk/future.py +486 -0
- mlxsmith/sdk/losses.py +262 -0
- mlxsmith/sdk/sampling_client.py +729 -0
- mlxsmith/sdk/training_client.py +676 -0
- mlxsmith/server.py +376 -0
- mlxsmith/train/__init__.py +0 -0
- mlxsmith/train/distill.py +279 -0
- mlxsmith/train/lora.py +280 -0
- mlxsmith/train/pref.py +180 -0
- mlxsmith/train/rft.py +458 -0
- mlxsmith/train/sft.py +151 -0
- mlxsmith/util.py +174 -0
- mlxsmith/verifiers/__init__.py +3 -0
- mlxsmith/verifiers/compose.py +109 -0
- mlxsmith/verifiers/docker_verifier.py +111 -0
- mlxsmith/verifiers/jsonschema.py +54 -0
- mlxsmith/verifiers/pytest_verifier.py +82 -0
- mlxsmith/verifiers/regex.py +15 -0
- mlxsmith/verifiers/types.py +10 -0
- mlxsmith-0.1.0.dist-info/METADATA +163 -0
- mlxsmith-0.1.0.dist-info/RECORD +69 -0
- mlxsmith-0.1.0.dist-info/WHEEL +5 -0
- mlxsmith-0.1.0.dist-info/entry_points.txt +2 -0
- mlxsmith-0.1.0.dist-info/licenses/LICENSE +21 -0
- mlxsmith-0.1.0.dist-info/top_level.txt +1 -0
mlxsmith/train/rft.py
ADDED
@@ -0,0 +1,458 @@
from __future__ import annotations

import random
import time
from pathlib import Path

from rich.console import Console

from ..accel import get_backend
from ..config import ProjectConfig
from ..models import resolve_model_spec
from ..runs import RunPaths, new_run, snapshot_config
from ..envs.token_env import TokenEnvStep, StringTaskTokenEnv, create_token_env, load_token_env_spec
from ..util import ensure_dir, write_jsonl, now_ts, sha1_text, latency_summary_ms
from ..llm.registry import get_llm_backend
from ..llm.backend import BackendNotAvailable
from .lora import LoRAConfig

console = Console()


def load_verifier(verifier_path: Path):
    import importlib.util

    spec = importlib.util.spec_from_file_location(verifier_path.stem, verifier_path)
    if spec is None or spec.loader is None:
        raise RuntimeError(f"Could not load verifier: {verifier_path}")
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)  # type: ignore
    verify_fn = getattr(module, "verify", None)
    if not callable(verify_fn):
        raise RuntimeError(f"Verifier must define verify(...): {verifier_path}")
    return verify_fn


def _normalize_observation(obs: list[int] | TokenEnvStep) -> tuple[list[int], float, bool, dict]:
    if isinstance(obs, TokenEnvStep):
        return list(obs.observation), float(obs.reward), bool(obs.done), dict(obs.info or {})
    return list(obs), 0.0, False, {}


def _rollout_token_env(
    llm,
    env,
    *,
    max_steps: int,
    temperature: float,
    seed: int,
) -> tuple[list[int], int, str, float, dict, int]:
    obs = env.initial_observation()
    obs_tokens, reward, done, info = _normalize_observation(obs)
    prompt_len = len(obs_tokens)
    full_tokens = list(obs_tokens)
    gen_tokens = 0

    for idx in range(max_steps):
        if done:
            break
        prompt_text = llm.decode(obs_tokens)
        gen = llm.generate_with_logprobs(
            prompt_text,
            max_new_tokens=1,
            temperature=temperature,
            top_p=1.0,
            top_k_sampling=None,
            seed=(seed + idx) % (2**31 - 1),
            logprobs=0,
        )
        new_token = int(gen.token_ids[-1])
        full_tokens.append(new_token)
        gen_tokens += 1

        step = env.step(new_token)
        reward += float(step.reward)
        done = bool(step.done)
        info = dict(step.info or {})
        obs_tokens = list(step.observation) if step.observation else list(full_tokens)

    completion = llm.decode(full_tokens[prompt_len:])
    return full_tokens, prompt_len, completion, reward, info, gen_tokens


def run_rft(project_root: Path, cfg: ProjectConfig, env_path: Path, verifier_path: Path, base_model_path: Path, accel: str) -> RunPaths:
    run = new_run(project_root, "rft")
    snapshot_config(cfg.model_dump(), run.config_snapshot_path)

    backend = get_backend(accel)
    backend.patch()
    console.print(f"[bold]RFT[/bold] run: {run.run_dir.name} algo={cfg.rft.algo} accel={backend.name}")

    verify = load_verifier(verifier_path)

    import yaml

    env = yaml.safe_load(env_path.read_text(encoding="utf-8")) or {}
    token_env_spec = load_token_env_spec(project_root, env)
    tasks = env.get("tasks") or []
    if token_env_spec is None and not tasks:
        raise RuntimeError("Env has no tasks. Add `tasks:` list in env YAML.")
    if token_env_spec is not None and token_env_spec.kind == "tasks" and not tasks:
        raise RuntimeError("token_env is set to tasks shim but env has no tasks.")

    accepted_path = run.run_dir / "accepted.jsonl"

    llm = get_llm_backend(cfg.model.backend)
    base_model, adapter_path, adapter_meta = resolve_model_spec(project_root, str(base_model_path), cfg)

    try:
        llm.load(
            base_model,
            max_seq_len=cfg.model.max_seq_len,
            dtype=cfg.model.dtype,
            trust_remote_code=cfg.model.trust_remote_code,
        )
        if adapter_path:
            llm.apply_adapter(str(adapter_path))
        else:
            lora_cfg = LoRAConfig(
                r=cfg.lora.r,
                alpha=cfg.lora.alpha,
                dropout=cfg.lora.dropout,
                target_modules=list(cfg.lora.target_modules or []),
                num_layers=cfg.lora.num_layers,
                scale=cfg.lora.scale,
                fine_tune_type=cfg.lora.fine_tune_type,
            )
            llm.apply_lora_from_config(lora_cfg)
    except BackendNotAvailable as e:
        console.print(f"[yellow]MLX backend unavailable[/yellow]: {e}")
        (run.adapter_dir / "ADAPTER.txt").write_text(
            f"Backend unavailable in this environment.\nbase={base_model}\naccel={backend.name}\n",
            encoding="utf-8",
        )
        return run

    ref_llm = None
    if cfg.rft.reference_model:
        ref_llm = get_llm_backend(cfg.model.backend)
        try:
            ref_llm.load(
                cfg.rft.reference_model,
                max_seq_len=cfg.model.max_seq_len,
                dtype=cfg.model.dtype,
                trust_remote_code=cfg.model.trust_remote_code,
            )
        except BackendNotAvailable:
            ref_llm = None

    opt, _params = llm.optimizer_and_params(lr=cfg.train.lr, weight_decay=cfg.train.weight_decay)

    rng = random.Random(cfg.train.seed)
    total_iters = int(cfg.train.iters)
    rollouts = int(cfg.rft.rollouts)
    temperature = float(cfg.rft.temperature)
    max_new = int(getattr(cfg.rft, "max_new_tokens", 256))
    kl_coeff = float(cfg.rft.kl_coeff)
    normalize_adv = bool(cfg.rft.normalize_advantage)

    if token_env_spec is not None:
        base_name = env.get("name") or "token_env"
        eos_token_id = getattr(getattr(llm, "tokenizer", None), "eos_token_id", None)

        for step in range(1, total_iters + 1):
            if token_env_spec.kind == "tasks":
                task = tasks[(step - 1) % len(tasks)]
                prompt = task.get("prompt", "")
                task_id = task.get("id") or sha1_text(prompt)[:12]
                tests = task.get("tests", "")
                verifier_kwargs = task.get("verifier_kwargs") or {}
            else:
                prompt = ""
                tests = ""
                verifier_kwargs = {}
                task_id = f"{base_name}_{step:06d}"

            gens = []
            gen_tokens = 0
            gen_start = time.time()
            verifier_latencies_ms: list[float] = []
            per_verifier_latencies: dict[str, list[float]] = {}

            for k in range(rollouts):
                wdir = ensure_dir(run.artifacts_dir / task_id / f"step_{step:06d}" / f"rollout_{k:02d}")
                if token_env_spec.kind == "tasks":
                    env_instance = StringTaskTokenEnv(
                        prompt=prompt,
                        tests=tests,
                        verifier_fn=verify,
                        workdir=wdir,
                        max_steps=max_new,
                        encode=llm.encode,
                        decode=llm.decode,
                        verifier_kwargs=verifier_kwargs,
                        eos_token_id=eos_token_id,
                    )
                else:
                    env_instance = create_token_env(
                        token_env_spec,
                        workdir=wdir,
                        encode=llm.encode,
                        decode=llm.decode,
                        tokenizer=getattr(llm, "tokenizer", None),
                        max_steps=max_new,
                        seed=rng.randint(0, 2**31 - 1),
                    )

                token_ids, prompt_len, completion, reward, info, gen_count = _rollout_token_env(
                    llm,
                    env_instance,
                    max_steps=max_new,
                    temperature=temperature,
                    seed=rng.randint(0, 2**31 - 1),
                )
                gen_tokens += gen_count

                verifier_latency = info.get("verifier_latency_ms")
                if verifier_latency is not None:
                    verifier_latencies_ms.append(float(verifier_latency))
                per_lat = info.get("verifier_latencies_ms")
                if isinstance(per_lat, dict):
                    for path, val in per_lat.items():
                        try:
                            per_verifier_latencies.setdefault(str(path), []).append(float(val))
                        except (TypeError, ValueError):
                            continue

                passed = bool(info.get("passed", reward > 0.0))
                gens.append((token_ids, prompt_len, completion, passed, reward, info))

            gen_elapsed = max(time.time() - gen_start, 1e-6)
            tps = gen_tokens / gen_elapsed

            mean_r = sum(r for *_rest, r, _info in gens) / max(1, len(gens))
            std_r = (
                sum((r - mean_r) ** 2 for *_rest, r, _info in gens) / max(1, len(gens))
            ) ** 0.5
            advs = [r - mean_r for *_rest, r, _info in gens]
            if normalize_adv and std_r > 1e-6:
                advs = [a / std_r for a in advs]

            def loss_fn(_model):
                loss = llm.mx.array(0.0)  # type: ignore
                for (token_ids, prompt_len, _comp, _passed, _reward, _info), adv in zip(gens, advs):
                    logp = llm.sequence_logprob(token_ids, prompt_len=prompt_len)
                    pg = -llm.mx.array(float(adv)) * logp  # type: ignore
                    if ref_llm is not None and kl_coeff > 0:
                        ref_logp = ref_llm.sequence_logprob(token_ids, prompt_len=prompt_len)
                        pg = pg + llm.mx.array(kl_coeff) * (logp - ref_logp)  # type: ignore
                    loss = loss + pg
                return loss / llm.mx.array(float(len(gens)))  # type: ignore

            lval, grads = llm.value_and_grad(loss_fn)
            if grads is not None:
                llm.apply_grads(opt, grads)

            best_idx = max(range(len(gens)), key=lambda i: gens[i][4])
            best = gens[best_idx]
            pass_at_1 = 1.0 if gens[0][3] else 0.0
            pass_at_k = 1.0 if any(passed for *_g, passed, _r, _i in gens) else 0.0
            acceptance = sum(1 for *_g, passed, _r, _i in gens if passed) / max(1, len(gens))

            latency_summary = latency_summary_ms(verifier_latencies_ms)
            per_verifier_summary = {
                path: latency_summary_ms(vals) for path, vals in per_verifier_latencies.items()
            }

            if step % cfg.train.log_every == 0 or step == 1 or step == total_iters:
                metrics = {
                    "ts": now_ts(),
                    "step": step,
                    "kind": "rft",
                    "algo": cfg.rft.algo,
                    "task_id": task_id,
                    "mean_reward": mean_r,
                    "std_reward": std_r,
                    "best_reward": best[4],
                    "best_passed": best[3],
                    "pass@1": pass_at_1,
                    "pass@k": pass_at_k,
                    "acceptance": acceptance,
                    "tokens_per_sec": tps,
                    "loss": float(lval.item()) if hasattr(lval, "item") else float(lval),
                    "accel": backend.name,
                }
                if latency_summary:
                    metrics["verifier_latency_ms"] = latency_summary["mean"]
                    for key, val in latency_summary.items():
                        metrics[f"verifier_latency_ms_{key}"] = val
                if per_verifier_summary:
                    metrics["verifier_latency_ms_by_path"] = per_verifier_summary
                write_jsonl(run.metrics_path, [metrics])

            for (token_ids, prompt_len, completion, passed, reward, _info) in gens:
                if passed:
                    prompt_text = llm.decode(token_ids[:prompt_len]) if prompt_len > 0 else ""
                    write_jsonl(
                        accepted_path,
                        [
                            {
                                "prompt": prompt_text,
                                "response": completion,
                                "reward": reward,
                                "task_id": task_id,
                            }
                        ],
                    )

            if step % cfg.train.save_every == 0 or step == total_iters:
                llm.save_adapter(
                    str(run.adapter_dir),
                    metadata={
                        "base_model": base_model,
                        "source_adapter": str(adapter_path) if adapter_path else None,
                        "run": run.run_dir.name,
                        "kind": "rft",
                    },
                )

        console.print(f"[green]Saved adapter[/green] {run.adapter_dir}")
        return run

    for step in range(1, total_iters + 1):
        task = tasks[(step - 1) % len(tasks)]
        prompt = task.get("prompt", "")
        task_id = task.get("id") or sha1_text(prompt)[:12]

        gens = []
        gen_tokens = 0
        gen_start = time.time()
        verifier_times = []
        per_verifier_latencies: dict[str, list[float]] = {}

        for k in range(rollouts):
            gen = llm.generate(
                prompt,
                max_new_tokens=max_new,
                temperature=temperature,
                seed=rng.randint(0, 2**31 - 1),
            )
            completion = gen.text[len(prompt) :] if gen.text.startswith(prompt) else gen.text
            gen_tokens += max(0, len(gen.token_ids) - gen.prompt_len)

            wdir = ensure_dir(run.artifacts_dir / task_id / f"step_{step:06d}" / f"rollout_{k:02d}")
            if "tests" in task:
                tdir = ensure_dir(wdir / "tests")
                (tdir / "test_task.py").write_text(task["tests"], encoding="utf-8")

            t0 = time.time()
            res = verify(prompt, completion, str(wdir), **(task.get("verifier_kwargs") or {}))
            verifier_times.append(time.time() - t0)
            per_lat = getattr(res, "info", {}) or {}
            per_lat = per_lat.get("verifier_latencies_ms") if isinstance(per_lat, dict) else None
            if isinstance(per_lat, dict):
                for path, val in per_lat.items():
                    try:
                        per_verifier_latencies.setdefault(str(path), []).append(float(val))
                    except (TypeError, ValueError):
                        continue

            passed = bool(getattr(res, "passed", False))
            reward = float(getattr(res, "reward", 0.0))
            gens.append((gen, completion, passed, reward))

        gen_elapsed = max(time.time() - gen_start, 1e-6)
        tps = gen_tokens / gen_elapsed

        mean_r = sum(r for *_rest, r in gens) / max(1, len(gens))
        std_r = (
            sum((r - mean_r) ** 2 for *_rest, r in gens) / max(1, len(gens))
        ) ** 0.5
        advs = [r - mean_r for *_rest, r in gens]
        if normalize_adv and std_r > 1e-6:
            advs = [a / std_r for a in advs]

        def loss_fn(_model):
            loss = llm.mx.array(0.0)  # type: ignore
            for (gen, _comp, _passed, _reward), adv in zip(gens, advs):
                logp = llm.sequence_logprob(gen.token_ids, prompt_len=gen.prompt_len)
                pg = -llm.mx.array(float(adv)) * logp  # type: ignore
                if ref_llm is not None and kl_coeff > 0:
                    ref_logp = ref_llm.sequence_logprob(gen.token_ids, prompt_len=gen.prompt_len)
                    pg = pg + llm.mx.array(kl_coeff) * (logp - ref_logp)  # type: ignore
                loss = loss + pg
            return loss / llm.mx.array(float(len(gens)))  # type: ignore

        lval, grads = llm.value_and_grad(loss_fn)
        if grads is not None:
            llm.apply_grads(opt, grads)

        best_idx = max(range(len(gens)), key=lambda i: gens[i][3])
        best = gens[best_idx]
        pass_at_1 = 1.0 if gens[0][2] else 0.0
        pass_at_k = 1.0 if any(passed for _g, _c, passed, _r in gens) else 0.0
        acceptance = sum(1 for *_rest, passed, _reward in gens if passed) / max(1, len(gens))

        latency_summary = latency_summary_ms([t * 1000.0 for t in verifier_times])
        per_verifier_summary = {
            path: latency_summary_ms(vals) for path, vals in per_verifier_latencies.items()
        }

        if step % cfg.train.log_every == 0 or step == 1 or step == total_iters:
            write_jsonl(
                run.metrics_path,
                [
                    {
                        "ts": now_ts(),
                        "step": step,
                        "kind": "rft",
                        "algo": cfg.rft.algo,
                        "task_id": task_id,
                        "mean_reward": mean_r,
                        "std_reward": std_r,
                        "best_reward": best[3],
                        "best_passed": best[2],
                        "pass@1": pass_at_1,
                        "pass@k": pass_at_k,
                        "acceptance": acceptance,
                        "verifier_latency_ms": latency_summary.get("mean", 0.0),
                        "verifier_latency_ms_mean": latency_summary.get("mean", 0.0),
                        "verifier_latency_ms_p50": latency_summary.get("p50", 0.0),
                        "verifier_latency_ms_p90": latency_summary.get("p90", 0.0),
                        "verifier_latency_ms_p99": latency_summary.get("p99", 0.0),
                        "verifier_latency_ms_max": latency_summary.get("max", 0.0),
                        "verifier_latency_ms_by_path": per_verifier_summary,
                        "tokens_per_sec": tps,
                        "loss": float(lval.item()) if hasattr(lval, "item") else float(lval),
                        "accel": backend.name,
                    }
                ],
            )

        for (gen, completion, passed, reward) in gens:
            if passed:
                write_jsonl(
                    accepted_path,
                    [
                        {
                            "prompt": prompt,
                            "response": completion,
                            "reward": reward,
                            "task_id": task_id,
                        }
                    ],
                )

        if step % cfg.train.save_every == 0 or step == total_iters:
            llm.save_adapter(
                str(run.adapter_dir),
                metadata={
                    "base_model": base_model,
                    "source_adapter": str(adapter_path) if adapter_path else None,
                    "run": run.run_dir.name,
                    "kind": "rft",
                },
            )

    console.print(f"[green]Saved adapter[/green] {run.adapter_dir}")
    return run
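For reference, the string-task path above calls verify(prompt, completion, str(wdir), **verifier_kwargs) and reads the result through getattr(res, "passed", False), getattr(res, "reward", 0.0), and an optional res.info dict. A minimal verifier module compatible with that interface could look like the sketch below; the VerifyResult class name and the trivial "non-empty completion passes" rule are illustrative assumptions, not part of the package.

# verifier_example.py -- hypothetical sketch of the interface load_verifier() expects.
# The only hard requirement implied by run_rft is a callable named `verify` whose
# return value exposes .passed, .reward, and optionally .info.
from dataclasses import dataclass, field


@dataclass
class VerifyResult:
    passed: bool
    reward: float
    info: dict = field(default_factory=dict)


def verify(prompt: str, completion: str, workdir: str, **kwargs) -> VerifyResult:
    # Toy acceptance rule for illustration: any non-empty completion passes.
    ok = bool(completion.strip())
    return VerifyResult(passed=ok, reward=1.0 if ok else 0.0)

A real verifier would typically run tests or other checks inside workdir (run_rft writes task tests to workdir/tests/test_task.py in the string-task path) and could report per-check timings via info["verifier_latencies_ms"], which the metrics code above aggregates when present.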
mlxsmith/train/sft.py
ADDED
@@ -0,0 +1,151 @@
from __future__ import annotations

import json
import random
from pathlib import Path

from rich.console import Console

from ..accel import get_backend
from ..config import ProjectConfig
from ..models import resolve_model_spec
from ..runs import RunPaths, new_run, snapshot_config
from ..util import write_jsonl, now_ts, tree_add, tree_scale, clip_grad_norm
from ..llm.registry import get_llm_backend
from ..llm.backend import BackendNotAvailable
from .lora import LoRAConfig

console = Console()


def _load_sft_rows(train_path: Path) -> list[dict]:
    return [json.loads(line) for line in train_path.read_text(encoding="utf-8").splitlines() if line.strip()]


def _row_to_prompt_response(row: dict) -> tuple[str, str]:
    prompt = row.get("prompt") or row.get("instruction") or row.get("input") or ""
    response = row.get("response") or row.get("output") or row.get("completion") or row.get("answer") or ""
    if not response and "messages" in row:
        msgs = row.get("messages") or []
        if msgs:
            prompt = "\n".join([m.get("content", "") for m in msgs[:-1]])
            response = msgs[-1].get("content", "") or ""
    return prompt, response


def run_sft(project_root: Path, cfg: ProjectConfig, data_dir: Path, model_id_or_path: str, accel: str) -> RunPaths:
    run = new_run(project_root, "sft")
    snapshot_config(cfg.model_dump(), run.config_snapshot_path)

    backend = get_backend(accel)
    backend.patch()
    console.print(f"[bold]SFT[/bold] run: {run.run_dir.name} accel={backend.name}")

    train_path = data_dir / "train.jsonl"
    if not train_path.exists():
        raise RuntimeError(
            "Missing train.jsonl. Run `mlxsmith data split` or point --data to a dir containing train.jsonl"
        )
    rows = _load_sft_rows(train_path)

    llm = get_llm_backend(cfg.model.backend)
    base_model, adapter_path, adapter_meta = resolve_model_spec(project_root, model_id_or_path, cfg)

    try:
        llm.load(
            base_model,
            max_seq_len=cfg.model.max_seq_len,
            dtype=cfg.model.dtype,
            trust_remote_code=cfg.model.trust_remote_code,
        )
        if adapter_path:
            llm.apply_adapter(str(adapter_path))
        else:
            lora_cfg = LoRAConfig(
                r=cfg.lora.r,
                alpha=cfg.lora.alpha,
                dropout=cfg.lora.dropout,
                target_modules=list(cfg.lora.target_modules or []),
                num_layers=cfg.lora.num_layers,
                scale=cfg.lora.scale,
                fine_tune_type=cfg.lora.fine_tune_type,
            )
            llm.apply_lora_from_config(lora_cfg)
    except BackendNotAvailable as e:
        console.print(f"[yellow]MLX backend unavailable[/yellow]: {e}")
        (run.adapter_dir / "ADAPTER.txt").write_text(
            f"Backend unavailable in this environment.\nmodel={model_id_or_path}\naccel={backend.name}\n",
            encoding="utf-8",
        )
        return run

    opt, _params = llm.optimizer_and_params(lr=cfg.train.lr, weight_decay=cfg.train.weight_decay)

    total = int(cfg.train.iters)
    grad_accum = max(1, int(cfg.train.grad_accum))
    train_on_prompt = bool(getattr(cfg.train, "train_on_prompt", False))
    max_grad_norm = float(getattr(cfg.train, "max_grad_norm", 1.0))

    rng = random.Random(cfg.train.seed)
    accum_grads = None
    accum_loss = 0.0

    for step in range(1, total + 1):
        row = rng.choice(rows)
        prompt, response = _row_to_prompt_response(row)
        if not response:
            continue

        text = f"{prompt}{response}"
        prompt_ids = llm.encode(prompt)
        ids = llm.encode(text)
        max_len = int(cfg.model.max_seq_len)
        if max_len and len(ids) > max_len:
            overflow = len(ids) - max_len
            ids = ids[overflow:]
            prompt_ids = prompt_ids[overflow:] if overflow < len(prompt_ids) else []

        def loss_fn(_model):
            return llm.sft_loss(ids, train_on_prompt=train_on_prompt, prompt_len=len(prompt_ids))

        lval, grads = llm.value_and_grad(loss_fn)
        accum_loss += float(lval.item()) if hasattr(lval, "item") else float(lval)
        if grads is not None:
            accum_grads = tree_add(accum_grads, grads)

        if step % grad_accum == 0:
            if accum_grads is not None:
                scaled = tree_scale(accum_grads, 1.0 / grad_accum)
                if max_grad_norm > 0:
                    scaled = clip_grad_norm(scaled, max_grad_norm)
                llm.apply_grads(opt, scaled)
            accum_grads = None
            accum_loss = 0.0

        if step % cfg.train.log_every == 0 or step == 1 or step == total:
            write_jsonl(
                run.metrics_path,
                [
                    {
                        "ts": now_ts(),
                        "step": step,
                        "kind": "sft",
                        "loss": float(lval.item()) if hasattr(lval, "item") else float(lval),
                        "accel": backend.name,
                    }
                ],
            )

        if step % cfg.train.save_every == 0 or step == total:
            llm.save_adapter(
                str(run.adapter_dir),
                metadata={
                    "base_model": base_model,
                    "source_adapter": str(adapter_path) if adapter_path else None,
                    "run": run.run_dir.name,
                    "kind": "sft",
                },
            )

    console.print(f"[green]Saved adapter[/green] {run.adapter_dir}")
    return run
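The SFT loader is permissive about row shape: _row_to_prompt_response() looks for prompt/instruction/input and response/output/completion/answer keys, and falls back to a messages list whose last entry is treated as the response while the earlier contents are joined into the prompt. The sketch below writes a few rows that this loader would accept; the example values (and the "role" fields, which the loader does not read) are invented for illustration only.

# make_train_jsonl.py -- illustrative sketch of train.jsonl rows accepted by _row_to_prompt_response().
import json

rows = [
    {"prompt": "Translate to French: hello", "response": "bonjour"},
    {"instruction": "Add 2 and 3.", "output": "5"},
    {
        "messages": [
            {"role": "user", "content": "What is the capital of France?"},
            {"role": "assistant", "content": "Paris"},
        ]
    },
]

with open("train.jsonl", "w", encoding="utf-8") as f:
    for row in rows:
        f.write(json.dumps(row) + "\n")

Note that run_sft concatenates prompt and response verbatim (text = f"{prompt}{response}") and, unless train_on_prompt is enabled, masks the prompt tokens via prompt_len when computing llm.sft_loss, so any separator or chat template must already be baked into the row text.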