mlxsmith-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69)
  1. mlxsmith/__init__.py +2 -0
  2. mlxsmith/accel/__init__.py +10 -0
  3. mlxsmith/accel/base.py +17 -0
  4. mlxsmith/accel/none.py +13 -0
  5. mlxsmith/accel/zmlx_backend.py +42 -0
  6. mlxsmith/adapters.py +46 -0
  7. mlxsmith/api/__init__.py +48 -0
  8. mlxsmith/api/handlers.py +1217 -0
  9. mlxsmith/api/schemas.py +436 -0
  10. mlxsmith/auth.py +88 -0
  11. mlxsmith/bench.py +102 -0
  12. mlxsmith/cli.py +950 -0
  13. mlxsmith/config.py +543 -0
  14. mlxsmith/config_models.py +261 -0
  15. mlxsmith/data.py +493 -0
  16. mlxsmith/envs/__init__.py +33 -0
  17. mlxsmith/envs/system.py +388 -0
  18. mlxsmith/envs/token_env.py +191 -0
  19. mlxsmith/eval.py +112 -0
  20. mlxsmith/infer.py +140 -0
  21. mlxsmith/llm/__init__.py +16 -0
  22. mlxsmith/llm/backend.py +126 -0
  23. mlxsmith/llm/interface.py +212 -0
  24. mlxsmith/llm/mlx_lm_backend.py +509 -0
  25. mlxsmith/llm/mock_backend.py +228 -0
  26. mlxsmith/llm/registry.py +12 -0
  27. mlxsmith/models.py +257 -0
  28. mlxsmith/orchestrator/__init__.py +25 -0
  29. mlxsmith/orchestrator/daemon.py +454 -0
  30. mlxsmith/orchestrator/inference_worker.py +496 -0
  31. mlxsmith/orchestrator/queue.py +355 -0
  32. mlxsmith/orchestrator/trainer_worker.py +437 -0
  33. mlxsmith/rlm/__init__.py +8 -0
  34. mlxsmith/rlm/corpus.py +74 -0
  35. mlxsmith/rlm/gating.py +90 -0
  36. mlxsmith/rlm/generate.py +249 -0
  37. mlxsmith/rlm/history.py +12 -0
  38. mlxsmith/rlm/inference.py +150 -0
  39. mlxsmith/rlm/loop.py +1297 -0
  40. mlxsmith/rlm/mutate.py +82 -0
  41. mlxsmith/rlm/trainer.py +73 -0
  42. mlxsmith/rlm/weights.py +263 -0
  43. mlxsmith/runs.py +44 -0
  44. mlxsmith/sdk/__init__.py +392 -0
  45. mlxsmith/sdk/future.py +486 -0
  46. mlxsmith/sdk/losses.py +262 -0
  47. mlxsmith/sdk/sampling_client.py +729 -0
  48. mlxsmith/sdk/training_client.py +676 -0
  49. mlxsmith/server.py +376 -0
  50. mlxsmith/train/__init__.py +0 -0
  51. mlxsmith/train/distill.py +279 -0
  52. mlxsmith/train/lora.py +280 -0
  53. mlxsmith/train/pref.py +180 -0
  54. mlxsmith/train/rft.py +458 -0
  55. mlxsmith/train/sft.py +151 -0
  56. mlxsmith/util.py +174 -0
  57. mlxsmith/verifiers/__init__.py +3 -0
  58. mlxsmith/verifiers/compose.py +109 -0
  59. mlxsmith/verifiers/docker_verifier.py +111 -0
  60. mlxsmith/verifiers/jsonschema.py +54 -0
  61. mlxsmith/verifiers/pytest_verifier.py +82 -0
  62. mlxsmith/verifiers/regex.py +15 -0
  63. mlxsmith/verifiers/types.py +10 -0
  64. mlxsmith-0.1.0.dist-info/METADATA +163 -0
  65. mlxsmith-0.1.0.dist-info/RECORD +69 -0
  66. mlxsmith-0.1.0.dist-info/WHEEL +5 -0
  67. mlxsmith-0.1.0.dist-info/entry_points.txt +2 -0
  68. mlxsmith-0.1.0.dist-info/licenses/LICENSE +21 -0
  69. mlxsmith-0.1.0.dist-info/top_level.txt +1 -0
mlxsmith/train/rft.py ADDED
@@ -0,0 +1,458 @@
+ from __future__ import annotations
+
+ import random
+ import time
+ from pathlib import Path
+
+ from rich.console import Console
+
+ from ..accel import get_backend
+ from ..config import ProjectConfig
+ from ..models import resolve_model_spec
+ from ..runs import RunPaths, new_run, snapshot_config
+ from ..envs.token_env import TokenEnvStep, StringTaskTokenEnv, create_token_env, load_token_env_spec
+ from ..util import ensure_dir, write_jsonl, now_ts, sha1_text, latency_summary_ms
+ from ..llm.registry import get_llm_backend
+ from ..llm.backend import BackendNotAvailable
+ from .lora import LoRAConfig
+
+ console = Console()
+
+
+ def load_verifier(verifier_path: Path):
+     import importlib.util
+
+     spec = importlib.util.spec_from_file_location(verifier_path.stem, verifier_path)
+     if spec is None or spec.loader is None:
+         raise RuntimeError(f"Could not load verifier: {verifier_path}")
+     module = importlib.util.module_from_spec(spec)
+     spec.loader.exec_module(module)  # type: ignore
+     verify_fn = getattr(module, "verify", None)
+     if not callable(verify_fn):
+         raise RuntimeError(f"Verifier must define verify(...): {verifier_path}")
+     return verify_fn
+
+
+ def _normalize_observation(obs: list[int] | TokenEnvStep) -> tuple[list[int], float, bool, dict]:
+     if isinstance(obs, TokenEnvStep):
+         return list(obs.observation), float(obs.reward), bool(obs.done), dict(obs.info or {})
+     return list(obs), 0.0, False, {}
+
+
+ def _rollout_token_env(
+     llm,
+     env,
+     *,
+     max_steps: int,
+     temperature: float,
+     seed: int,
+ ) -> tuple[list[int], int, str, float, dict, int]:
+     obs = env.initial_observation()
+     obs_tokens, reward, done, info = _normalize_observation(obs)
+     prompt_len = len(obs_tokens)
+     full_tokens = list(obs_tokens)
+     gen_tokens = 0
+
+     for idx in range(max_steps):
+         if done:
+             break
+         prompt_text = llm.decode(obs_tokens)
+         gen = llm.generate_with_logprobs(
+             prompt_text,
+             max_new_tokens=1,
+             temperature=temperature,
+             top_p=1.0,
+             top_k_sampling=None,
+             seed=(seed + idx) % (2**31 - 1),
+             logprobs=0,
+         )
+         new_token = int(gen.token_ids[-1])
+         full_tokens.append(new_token)
+         gen_tokens += 1
+
+         step = env.step(new_token)
+         reward += float(step.reward)
+         done = bool(step.done)
+         info = dict(step.info or {})
+         obs_tokens = list(step.observation) if step.observation else list(full_tokens)
+
+     completion = llm.decode(full_tokens[prompt_len:])
+     return full_tokens, prompt_len, completion, reward, info, gen_tokens
+
+
+ def run_rft(project_root: Path, cfg: ProjectConfig, env_path: Path, verifier_path: Path, base_model_path: Path, accel: str) -> RunPaths:
+     run = new_run(project_root, "rft")
+     snapshot_config(cfg.model_dump(), run.config_snapshot_path)
+
+     backend = get_backend(accel)
+     backend.patch()
+     console.print(f"[bold]RFT[/bold] run: {run.run_dir.name} algo={cfg.rft.algo} accel={backend.name}")
+
+     verify = load_verifier(verifier_path)
+
+     import yaml
+
+     env = yaml.safe_load(env_path.read_text(encoding="utf-8")) or {}
+     token_env_spec = load_token_env_spec(project_root, env)
+     tasks = env.get("tasks") or []
+     if token_env_spec is None and not tasks:
+         raise RuntimeError("Env has no tasks. Add `tasks:` list in env YAML.")
+     if token_env_spec is not None and token_env_spec.kind == "tasks" and not tasks:
+         raise RuntimeError("token_env is set to tasks shim but env has no tasks.")
+
+     accepted_path = run.run_dir / "accepted.jsonl"
+
+     llm = get_llm_backend(cfg.model.backend)
+     base_model, adapter_path, adapter_meta = resolve_model_spec(project_root, str(base_model_path), cfg)
+
+     try:
+         llm.load(
+             base_model,
+             max_seq_len=cfg.model.max_seq_len,
+             dtype=cfg.model.dtype,
+             trust_remote_code=cfg.model.trust_remote_code,
+         )
+         if adapter_path:
+             llm.apply_adapter(str(adapter_path))
+         else:
+             lora_cfg = LoRAConfig(
+                 r=cfg.lora.r,
+                 alpha=cfg.lora.alpha,
+                 dropout=cfg.lora.dropout,
+                 target_modules=list(cfg.lora.target_modules or []),
+                 num_layers=cfg.lora.num_layers,
+                 scale=cfg.lora.scale,
+                 fine_tune_type=cfg.lora.fine_tune_type,
+             )
+             llm.apply_lora_from_config(lora_cfg)
+     except BackendNotAvailable as e:
+         console.print(f"[yellow]MLX backend unavailable[/yellow]: {e}")
+         (run.adapter_dir / "ADAPTER.txt").write_text(
+             f"Backend unavailable in this environment.\nbase={base_model}\naccel={backend.name}\n",
+             encoding="utf-8",
+         )
+         return run
+
+     ref_llm = None
+     if cfg.rft.reference_model:
+         ref_llm = get_llm_backend(cfg.model.backend)
+         try:
+             ref_llm.load(
+                 cfg.rft.reference_model,
+                 max_seq_len=cfg.model.max_seq_len,
+                 dtype=cfg.model.dtype,
+                 trust_remote_code=cfg.model.trust_remote_code,
+             )
+         except BackendNotAvailable:
+             ref_llm = None
+
+     opt, _params = llm.optimizer_and_params(lr=cfg.train.lr, weight_decay=cfg.train.weight_decay)
+
+     rng = random.Random(cfg.train.seed)
+     total_iters = int(cfg.train.iters)
+     rollouts = int(cfg.rft.rollouts)
+     temperature = float(cfg.rft.temperature)
+     max_new = int(getattr(cfg.rft, "max_new_tokens", 256))
+     kl_coeff = float(cfg.rft.kl_coeff)
+     normalize_adv = bool(cfg.rft.normalize_advantage)
+
+     if token_env_spec is not None:
+         base_name = env.get("name") or "token_env"
+         eos_token_id = getattr(getattr(llm, "tokenizer", None), "eos_token_id", None)
+
+         for step in range(1, total_iters + 1):
+             if token_env_spec.kind == "tasks":
+                 task = tasks[(step - 1) % len(tasks)]
+                 prompt = task.get("prompt", "")
+                 task_id = task.get("id") or sha1_text(prompt)[:12]
+                 tests = task.get("tests", "")
+                 verifier_kwargs = task.get("verifier_kwargs") or {}
+             else:
+                 prompt = ""
+                 tests = ""
+                 verifier_kwargs = {}
+                 task_id = f"{base_name}_{step:06d}"
+
+             gens = []
+             gen_tokens = 0
+             gen_start = time.time()
+             verifier_latencies_ms: list[float] = []
+             per_verifier_latencies: dict[str, list[float]] = {}
+
+             for k in range(rollouts):
+                 wdir = ensure_dir(run.artifacts_dir / task_id / f"step_{step:06d}" / f"rollout_{k:02d}")
+                 if token_env_spec.kind == "tasks":
+                     env_instance = StringTaskTokenEnv(
+                         prompt=prompt,
+                         tests=tests,
+                         verifier_fn=verify,
+                         workdir=wdir,
+                         max_steps=max_new,
+                         encode=llm.encode,
+                         decode=llm.decode,
+                         verifier_kwargs=verifier_kwargs,
+                         eos_token_id=eos_token_id,
+                     )
+                 else:
+                     env_instance = create_token_env(
+                         token_env_spec,
+                         workdir=wdir,
+                         encode=llm.encode,
+                         decode=llm.decode,
+                         tokenizer=getattr(llm, "tokenizer", None),
+                         max_steps=max_new,
+                         seed=rng.randint(0, 2**31 - 1),
+                     )
+
+                 token_ids, prompt_len, completion, reward, info, gen_count = _rollout_token_env(
+                     llm,
+                     env_instance,
+                     max_steps=max_new,
+                     temperature=temperature,
+                     seed=rng.randint(0, 2**31 - 1),
+                 )
+                 gen_tokens += gen_count
+
+                 verifier_latency = info.get("verifier_latency_ms")
+                 if verifier_latency is not None:
+                     verifier_latencies_ms.append(float(verifier_latency))
+                 per_lat = info.get("verifier_latencies_ms")
+                 if isinstance(per_lat, dict):
+                     for path, val in per_lat.items():
+                         try:
+                             per_verifier_latencies.setdefault(str(path), []).append(float(val))
+                         except (TypeError, ValueError):
+                             continue
+
+                 passed = bool(info.get("passed", reward > 0.0))
+                 gens.append((token_ids, prompt_len, completion, passed, reward, info))
+
+             gen_elapsed = max(time.time() - gen_start, 1e-6)
+             tps = gen_tokens / gen_elapsed
+
+             mean_r = sum(r for *_rest, r, _info in gens) / max(1, len(gens))
+             std_r = (
+                 sum((r - mean_r) ** 2 for *_rest, r, _info in gens) / max(1, len(gens))
+             ) ** 0.5
+             advs = [r - mean_r for *_rest, r, _info in gens]
+             if normalize_adv and std_r > 1e-6:
+                 advs = [a / std_r for a in advs]
+
+             def loss_fn(_model):
+                 loss = llm.mx.array(0.0)  # type: ignore
+                 for (token_ids, prompt_len, _comp, _passed, _reward, _info), adv in zip(gens, advs):
+                     logp = llm.sequence_logprob(token_ids, prompt_len=prompt_len)
+                     pg = -llm.mx.array(float(adv)) * logp  # type: ignore
+                     if ref_llm is not None and kl_coeff > 0:
+                         ref_logp = ref_llm.sequence_logprob(token_ids, prompt_len=prompt_len)
+                         pg = pg + llm.mx.array(kl_coeff) * (logp - ref_logp)  # type: ignore
+                     loss = loss + pg
+                 return loss / llm.mx.array(float(len(gens)))  # type: ignore
+
+             lval, grads = llm.value_and_grad(loss_fn)
+             if grads is not None:
+                 llm.apply_grads(opt, grads)
+
+             best_idx = max(range(len(gens)), key=lambda i: gens[i][4])
+             best = gens[best_idx]
+             pass_at_1 = 1.0 if gens[0][3] else 0.0
+             pass_at_k = 1.0 if any(passed for *_g, passed, _r, _i in gens) else 0.0
+             acceptance = sum(1 for *_g, passed, _r, _i in gens if passed) / max(1, len(gens))
+
+             latency_summary = latency_summary_ms(verifier_latencies_ms)
+             per_verifier_summary = {
+                 path: latency_summary_ms(vals) for path, vals in per_verifier_latencies.items()
+             }
+
+             if step % cfg.train.log_every == 0 or step == 1 or step == total_iters:
+                 metrics = {
+                     "ts": now_ts(),
+                     "step": step,
+                     "kind": "rft",
+                     "algo": cfg.rft.algo,
+                     "task_id": task_id,
+                     "mean_reward": mean_r,
+                     "std_reward": std_r,
+                     "best_reward": best[4],
+                     "best_passed": best[3],
+                     "pass@1": pass_at_1,
+                     "pass@k": pass_at_k,
+                     "acceptance": acceptance,
+                     "tokens_per_sec": tps,
+                     "loss": float(lval.item()) if hasattr(lval, "item") else float(lval),
+                     "accel": backend.name,
+                 }
+                 if latency_summary:
+                     metrics["verifier_latency_ms"] = latency_summary["mean"]
+                     for key, val in latency_summary.items():
+                         metrics[f"verifier_latency_ms_{key}"] = val
+                 if per_verifier_summary:
+                     metrics["verifier_latency_ms_by_path"] = per_verifier_summary
+                 write_jsonl(run.metrics_path, [metrics])
+
+             for (token_ids, prompt_len, completion, passed, reward, _info) in gens:
+                 if passed:
+                     prompt_text = llm.decode(token_ids[:prompt_len]) if prompt_len > 0 else ""
+                     write_jsonl(
+                         accepted_path,
+                         [
+                             {
+                                 "prompt": prompt_text,
+                                 "response": completion,
+                                 "reward": reward,
+                                 "task_id": task_id,
+                             }
+                         ],
+                     )
+
+             if step % cfg.train.save_every == 0 or step == total_iters:
+                 llm.save_adapter(
+                     str(run.adapter_dir),
+                     metadata={
+                         "base_model": base_model,
+                         "source_adapter": str(adapter_path) if adapter_path else None,
+                         "run": run.run_dir.name,
+                         "kind": "rft",
+                     },
+                 )
+
+         console.print(f"[green]Saved adapter[/green] {run.adapter_dir}")
+         return run
+
+     for step in range(1, total_iters + 1):
+         task = tasks[(step - 1) % len(tasks)]
+         prompt = task.get("prompt", "")
+         task_id = task.get("id") or sha1_text(prompt)[:12]
+
+         gens = []
+         gen_tokens = 0
+         gen_start = time.time()
+         verifier_times = []
+         per_verifier_latencies: dict[str, list[float]] = {}
+
+         for k in range(rollouts):
+             gen = llm.generate(
+                 prompt,
+                 max_new_tokens=max_new,
+                 temperature=temperature,
+                 seed=rng.randint(0, 2**31 - 1),
+             )
+             completion = gen.text[len(prompt) :] if gen.text.startswith(prompt) else gen.text
+             gen_tokens += max(0, len(gen.token_ids) - gen.prompt_len)
+
+             wdir = ensure_dir(run.artifacts_dir / task_id / f"step_{step:06d}" / f"rollout_{k:02d}")
+             if "tests" in task:
+                 tdir = ensure_dir(wdir / "tests")
+                 (tdir / "test_task.py").write_text(task["tests"], encoding="utf-8")
+
+             t0 = time.time()
+             res = verify(prompt, completion, str(wdir), **(task.get("verifier_kwargs") or {}))
+             verifier_times.append(time.time() - t0)
+             per_lat = getattr(res, "info", {}) or {}
+             per_lat = per_lat.get("verifier_latencies_ms") if isinstance(per_lat, dict) else None
+             if isinstance(per_lat, dict):
+                 for path, val in per_lat.items():
+                     try:
+                         per_verifier_latencies.setdefault(str(path), []).append(float(val))
+                     except (TypeError, ValueError):
+                         continue
+
+             passed = bool(getattr(res, "passed", False))
+             reward = float(getattr(res, "reward", 0.0))
+             gens.append((gen, completion, passed, reward))
+
+         gen_elapsed = max(time.time() - gen_start, 1e-6)
+         tps = gen_tokens / gen_elapsed
+
+         mean_r = sum(r for *_rest, r in gens) / max(1, len(gens))
+         std_r = (
+             sum((r - mean_r) ** 2 for *_rest, r in gens) / max(1, len(gens))
+         ) ** 0.5
+         advs = [r - mean_r for *_rest, r in gens]
+         if normalize_adv and std_r > 1e-6:
+             advs = [a / std_r for a in advs]
+
+         def loss_fn(_model):
+             loss = llm.mx.array(0.0)  # type: ignore
+             for (gen, _comp, _passed, _reward), adv in zip(gens, advs):
+                 logp = llm.sequence_logprob(gen.token_ids, prompt_len=gen.prompt_len)
+                 pg = -llm.mx.array(float(adv)) * logp  # type: ignore
+                 if ref_llm is not None and kl_coeff > 0:
+                     ref_logp = ref_llm.sequence_logprob(gen.token_ids, prompt_len=gen.prompt_len)
+                     pg = pg + llm.mx.array(kl_coeff) * (logp - ref_logp)  # type: ignore
+                 loss = loss + pg
+             return loss / llm.mx.array(float(len(gens)))  # type: ignore
+
+         lval, grads = llm.value_and_grad(loss_fn)
+         if grads is not None:
+             llm.apply_grads(opt, grads)
+
+         best_idx = max(range(len(gens)), key=lambda i: gens[i][3])
+         best = gens[best_idx]
+         pass_at_1 = 1.0 if gens[0][2] else 0.0
+         pass_at_k = 1.0 if any(passed for _g, _c, passed, _r in gens) else 0.0
+         acceptance = sum(1 for *_rest, passed, _reward in gens if passed) / max(1, len(gens))
+
+         latency_summary = latency_summary_ms([t * 1000.0 for t in verifier_times])
+         per_verifier_summary = {
+             path: latency_summary_ms(vals) for path, vals in per_verifier_latencies.items()
+         }
+
+         if step % cfg.train.log_every == 0 or step == 1 or step == total_iters:
+             write_jsonl(
+                 run.metrics_path,
+                 [
+                     {
+                         "ts": now_ts(),
+                         "step": step,
+                         "kind": "rft",
+                         "algo": cfg.rft.algo,
+                         "task_id": task_id,
+                         "mean_reward": mean_r,
+                         "std_reward": std_r,
+                         "best_reward": best[3],
+                         "best_passed": best[2],
+                         "pass@1": pass_at_1,
+                         "pass@k": pass_at_k,
+                         "acceptance": acceptance,
+                         "verifier_latency_ms": latency_summary.get("mean", 0.0),
+                         "verifier_latency_ms_mean": latency_summary.get("mean", 0.0),
+                         "verifier_latency_ms_p50": latency_summary.get("p50", 0.0),
+                         "verifier_latency_ms_p90": latency_summary.get("p90", 0.0),
+                         "verifier_latency_ms_p99": latency_summary.get("p99", 0.0),
+                         "verifier_latency_ms_max": latency_summary.get("max", 0.0),
+                         "verifier_latency_ms_by_path": per_verifier_summary,
+                         "tokens_per_sec": tps,
+                         "loss": float(lval.item()) if hasattr(lval, "item") else float(lval),
+                         "accel": backend.name,
+                     }
+                 ],
+             )
+
+         for (gen, completion, passed, reward) in gens:
+             if passed:
+                 write_jsonl(
+                     accepted_path,
+                     [
+                         {
+                             "prompt": prompt,
+                             "response": completion,
+                             "reward": reward,
+                             "task_id": task_id,
+                         }
+                     ],
+                 )
+
+         if step % cfg.train.save_every == 0 or step == total_iters:
+             llm.save_adapter(
+                 str(run.adapter_dir),
+                 metadata={
+                     "base_model": base_model,
+                     "source_adapter": str(adapter_path) if adapter_path else None,
+                     "run": run.run_dir.name,
+                     "kind": "rft",
+                 },
+             )
+
+     console.print(f"[green]Saved adapter[/green] {run.adapter_dir}")
+     return run
mlxsmith/train/sft.py ADDED
@@ -0,0 +1,151 @@
+ from __future__ import annotations
+
+ import json
+ import random
+ from pathlib import Path
+
+ from rich.console import Console
+
+ from ..accel import get_backend
+ from ..config import ProjectConfig
+ from ..models import resolve_model_spec
+ from ..runs import RunPaths, new_run, snapshot_config
+ from ..util import write_jsonl, now_ts, tree_add, tree_scale, clip_grad_norm
+ from ..llm.registry import get_llm_backend
+ from ..llm.backend import BackendNotAvailable
+ from .lora import LoRAConfig
+
+ console = Console()
+
+
+ def _load_sft_rows(train_path: Path) -> list[dict]:
+     return [json.loads(line) for line in train_path.read_text(encoding="utf-8").splitlines() if line.strip()]
+
+
+ def _row_to_prompt_response(row: dict) -> tuple[str, str]:
+     prompt = row.get("prompt") or row.get("instruction") or row.get("input") or ""
+     response = row.get("response") or row.get("output") or row.get("completion") or row.get("answer") or ""
+     if not response and "messages" in row:
+         msgs = row.get("messages") or []
+         if msgs:
+             prompt = "\n".join([m.get("content", "") for m in msgs[:-1]])
+             response = msgs[-1].get("content", "") or ""
+     return prompt, response
+
+
+ def run_sft(project_root: Path, cfg: ProjectConfig, data_dir: Path, model_id_or_path: str, accel: str) -> RunPaths:
+     run = new_run(project_root, "sft")
+     snapshot_config(cfg.model_dump(), run.config_snapshot_path)
+
+     backend = get_backend(accel)
+     backend.patch()
+     console.print(f"[bold]SFT[/bold] run: {run.run_dir.name} accel={backend.name}")
+
+     train_path = data_dir / "train.jsonl"
+     if not train_path.exists():
+         raise RuntimeError(
+             "Missing train.jsonl. Run `mlxsmith data split` or point --data to a dir containing train.jsonl"
+         )
+     rows = _load_sft_rows(train_path)
+
+     llm = get_llm_backend(cfg.model.backend)
+     base_model, adapter_path, adapter_meta = resolve_model_spec(project_root, model_id_or_path, cfg)
+
+     try:
+         llm.load(
+             base_model,
+             max_seq_len=cfg.model.max_seq_len,
+             dtype=cfg.model.dtype,
+             trust_remote_code=cfg.model.trust_remote_code,
+         )
+         if adapter_path:
+             llm.apply_adapter(str(adapter_path))
+         else:
+             lora_cfg = LoRAConfig(
+                 r=cfg.lora.r,
+                 alpha=cfg.lora.alpha,
+                 dropout=cfg.lora.dropout,
+                 target_modules=list(cfg.lora.target_modules or []),
+                 num_layers=cfg.lora.num_layers,
+                 scale=cfg.lora.scale,
+                 fine_tune_type=cfg.lora.fine_tune_type,
+             )
+             llm.apply_lora_from_config(lora_cfg)
+     except BackendNotAvailable as e:
+         console.print(f"[yellow]MLX backend unavailable[/yellow]: {e}")
+         (run.adapter_dir / "ADAPTER.txt").write_text(
+             f"Backend unavailable in this environment.\nmodel={model_id_or_path}\naccel={backend.name}\n",
+             encoding="utf-8",
+         )
+         return run
+
+     opt, _params = llm.optimizer_and_params(lr=cfg.train.lr, weight_decay=cfg.train.weight_decay)
+
+     total = int(cfg.train.iters)
+     grad_accum = max(1, int(cfg.train.grad_accum))
+     train_on_prompt = bool(getattr(cfg.train, "train_on_prompt", False))
+     max_grad_norm = float(getattr(cfg.train, "max_grad_norm", 1.0))
+
+     rng = random.Random(cfg.train.seed)
+     accum_grads = None
+     accum_loss = 0.0
+
+     for step in range(1, total + 1):
+         row = rng.choice(rows)
+         prompt, response = _row_to_prompt_response(row)
+         if not response:
+             continue
+
+         text = f"{prompt}{response}"
+         prompt_ids = llm.encode(prompt)
+         ids = llm.encode(text)
+         max_len = int(cfg.model.max_seq_len)
+         if max_len and len(ids) > max_len:
+             overflow = len(ids) - max_len
+             ids = ids[overflow:]
+             prompt_ids = prompt_ids[overflow:] if overflow < len(prompt_ids) else []
+
+         def loss_fn(_model):
+             return llm.sft_loss(ids, train_on_prompt=train_on_prompt, prompt_len=len(prompt_ids))
+
+         lval, grads = llm.value_and_grad(loss_fn)
+         accum_loss += float(lval.item()) if hasattr(lval, "item") else float(lval)
+         if grads is not None:
+             accum_grads = tree_add(accum_grads, grads)
+
+         if step % grad_accum == 0:
+             if accum_grads is not None:
+                 scaled = tree_scale(accum_grads, 1.0 / grad_accum)
+                 if max_grad_norm > 0:
+                     scaled = clip_grad_norm(scaled, max_grad_norm)
+                 llm.apply_grads(opt, scaled)
+             accum_grads = None
+             accum_loss = 0.0
+
+         if step % cfg.train.log_every == 0 or step == 1 or step == total:
+             write_jsonl(
+                 run.metrics_path,
+                 [
+                     {
+                         "ts": now_ts(),
+                         "step": step,
+                         "kind": "sft",
+                         "loss": float(lval.item()) if hasattr(lval, "item") else float(lval),
+                         "accel": backend.name,
+                     }
+                 ],
+             )
+
+         if step % cfg.train.save_every == 0 or step == total:
+             llm.save_adapter(
+                 str(run.adapter_dir),
+                 metadata={
+                     "base_model": base_model,
+                     "source_adapter": str(adapter_path) if adapter_path else None,
+                     "run": run.run_dir.name,
+                     "kind": "sft",
+                 },
+             )
+
+     console.print(f"[green]Saved adapter[/green] {run.adapter_dir}")
+     return run