mlxsmith 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69)
  1. mlxsmith/__init__.py +2 -0
  2. mlxsmith/accel/__init__.py +10 -0
  3. mlxsmith/accel/base.py +17 -0
  4. mlxsmith/accel/none.py +13 -0
  5. mlxsmith/accel/zmlx_backend.py +42 -0
  6. mlxsmith/adapters.py +46 -0
  7. mlxsmith/api/__init__.py +48 -0
  8. mlxsmith/api/handlers.py +1217 -0
  9. mlxsmith/api/schemas.py +436 -0
  10. mlxsmith/auth.py +88 -0
  11. mlxsmith/bench.py +102 -0
  12. mlxsmith/cli.py +950 -0
  13. mlxsmith/config.py +543 -0
  14. mlxsmith/config_models.py +261 -0
  15. mlxsmith/data.py +493 -0
  16. mlxsmith/envs/__init__.py +33 -0
  17. mlxsmith/envs/system.py +388 -0
  18. mlxsmith/envs/token_env.py +191 -0
  19. mlxsmith/eval.py +112 -0
  20. mlxsmith/infer.py +140 -0
  21. mlxsmith/llm/__init__.py +16 -0
  22. mlxsmith/llm/backend.py +126 -0
  23. mlxsmith/llm/interface.py +212 -0
  24. mlxsmith/llm/mlx_lm_backend.py +509 -0
  25. mlxsmith/llm/mock_backend.py +228 -0
  26. mlxsmith/llm/registry.py +12 -0
  27. mlxsmith/models.py +257 -0
  28. mlxsmith/orchestrator/__init__.py +25 -0
  29. mlxsmith/orchestrator/daemon.py +454 -0
  30. mlxsmith/orchestrator/inference_worker.py +496 -0
  31. mlxsmith/orchestrator/queue.py +355 -0
  32. mlxsmith/orchestrator/trainer_worker.py +437 -0
  33. mlxsmith/rlm/__init__.py +8 -0
  34. mlxsmith/rlm/corpus.py +74 -0
  35. mlxsmith/rlm/gating.py +90 -0
  36. mlxsmith/rlm/generate.py +249 -0
  37. mlxsmith/rlm/history.py +12 -0
  38. mlxsmith/rlm/inference.py +150 -0
  39. mlxsmith/rlm/loop.py +1297 -0
  40. mlxsmith/rlm/mutate.py +82 -0
  41. mlxsmith/rlm/trainer.py +73 -0
  42. mlxsmith/rlm/weights.py +263 -0
  43. mlxsmith/runs.py +44 -0
  44. mlxsmith/sdk/__init__.py +392 -0
  45. mlxsmith/sdk/future.py +486 -0
  46. mlxsmith/sdk/losses.py +262 -0
  47. mlxsmith/sdk/sampling_client.py +729 -0
  48. mlxsmith/sdk/training_client.py +676 -0
  49. mlxsmith/server.py +376 -0
  50. mlxsmith/train/__init__.py +0 -0
  51. mlxsmith/train/distill.py +279 -0
  52. mlxsmith/train/lora.py +280 -0
  53. mlxsmith/train/pref.py +180 -0
  54. mlxsmith/train/rft.py +458 -0
  55. mlxsmith/train/sft.py +151 -0
  56. mlxsmith/util.py +174 -0
  57. mlxsmith/verifiers/__init__.py +3 -0
  58. mlxsmith/verifiers/compose.py +109 -0
  59. mlxsmith/verifiers/docker_verifier.py +111 -0
  60. mlxsmith/verifiers/jsonschema.py +54 -0
  61. mlxsmith/verifiers/pytest_verifier.py +82 -0
  62. mlxsmith/verifiers/regex.py +15 -0
  63. mlxsmith/verifiers/types.py +10 -0
  64. mlxsmith-0.1.0.dist-info/METADATA +163 -0
  65. mlxsmith-0.1.0.dist-info/RECORD +69 -0
  66. mlxsmith-0.1.0.dist-info/WHEEL +5 -0
  67. mlxsmith-0.1.0.dist-info/entry_points.txt +2 -0
  68. mlxsmith-0.1.0.dist-info/licenses/LICENSE +21 -0
  69. mlxsmith-0.1.0.dist-info/top_level.txt +1 -0
mlxsmith/cli.py ADDED
@@ -0,0 +1,950 @@
1
+ import json
2
+ from pathlib import Path
3
+ from typing import Optional
4
+
5
+ import typer
6
+ from rich.console import Console
7
+ from rich.table import Table
8
+
9
+ from .auth import get_status as get_auth_status, login as hf_login, logout as hf_logout
10
+ from .config import (
11
+ ProjectConfig,
12
+ dump_config,
13
+ get_config,
14
+ load_config,
15
+ resolve_config_path,
16
+ show_merged_config,
17
+ write_default_config,
18
+ )
19
+ from .data import import_sharegpt, split_jsonl, pull_hf_dataset, list_presets, resolve_preset, analyze_jsonl
20
+ from .models import hf_pull, quantize_stub
21
+ from .util import detect_system, ensure_dir
22
+ from .accel import get_backend
23
+ from .train.sft import run_sft
24
+ from .train.pref import run_pref
25
+ from .train.rft import run_rft
26
+ from .train.distill import run_distill
27
+ from .eval import run_eval
28
+ from .bench import run_bench
29
+ from .rlm import run_rlm, run_rlm_orchestrated
30
+ from .rlm.gating import load_state as _load_rlm_state
31
+ from .adapters import merge_adapters
32
+ from .envs import (
33
+ init_env as init_env_plugin,
34
+ install_env as install_env_plugin,
35
+ list_registry_packages,
36
+ package_env as package_env_plugin,
37
+ pull_env as pull_env_plugin,
38
+ publish_env as publish_env_plugin,
39
+ registry_info as registry_info_plugin,
40
+ resolve_env_path as resolve_env_path_plugin,
41
+ load_manifest as load_env_manifest,
42
+ )
43
+
44
+ app = typer.Typer(
45
+ add_completion=False,
46
+ help="mlxsmith — MLX fine-tuning + OpenAI-compatible serving (SFT stable; preference/RL experimental)",
47
+ )
48
+ console = Console()
49
+
50
+
51
def project_root_from_cwd() -> Path:
    """Return the directory used as the project root.

    The current working directory is returned as-is; no marker file is
    searched for.
    """
    working_dir = Path.cwd()
    return working_dir
53
+
54
+
55
@app.command()
def init(path: str = typer.Argument(..., help="Project directory to create")):
    # Scaffold a project directory: standard sub-folders, a default config
    # and sample env / verifier / eval-suite files.
    # Kept as a comment: Typer would expose a docstring as the command's help text.
    # NOTE(review): indentation was lost in the diff rendering; this assumes
    # only the default config is guarded by the existence check and the sample
    # files are always (re)written - confirm against the real source.
    project_dir = Path(path)
    project_dir.mkdir(parents=True, exist_ok=True)
    layout = (
        "data/sft",
        "data/prefs",
        "models",
        "envs",
        "verifiers",
        "eval/suites",
        "runs",
        "cache",
        "bench",
    )
    for relative in layout:
        (project_dir / relative).mkdir(parents=True, exist_ok=True)

    config_file = project_dir / "mlxsmith.yaml"
    if not config_file.exists():
        write_default_config(config_file)

    envs_dir = project_dir / "envs"
    verifiers_dir = project_dir / "verifiers"
    suites_dir = project_dir / "eval" / "suites"
    (envs_dir / "coding.yaml").write_text(_sample_env_yaml(), encoding="utf-8")
    (verifiers_dir / "regex.py").write_text(_sample_verifier_regex(), encoding="utf-8")
    (verifiers_dir / "pytest.py").write_text(_sample_verifier_pytest(), encoding="utf-8")
    (verifiers_dir / "jsonschema.py").write_text(_sample_verifier_jsonschema(), encoding="utf-8")
    (suites_dir / "coding.yaml").write_text(_sample_eval_suite(), encoding="utf-8")

    console.print(f"[green]Initialized[/green] {project_dir.resolve()}")
70
+
71
+
72
@app.command()
def doctor():
    # Print a two-column table of what detect_system() reports for this host.
    # Kept as a comment: Typer would expose a docstring as the command's help text.
    system = detect_system()
    rows = [
        ("python", system.python),
        ("python_arch", system.python_arch),
        ("platform", system.platform),
        ("macos_version", system.macos_version or "n/a"),
        ("machine", system.machine),
        ("cpu_count", str(system.cpu_count)),
        ("metal", str(system.has_metal)),
        # The version is appended only when one was detected.
        ("mlx", f"{system.has_mlx} {system.mlx_version or ''}".strip()),
        ("zmlx", str(system.has_zmlx)),
    ]

    report = Table(title="mlxsmith doctor")
    for heading in ("item", "value"):
        report.add_column(heading)
    for label, cell in rows:
        report.add_row(label, cell)
    console.print(report)
88
+
89
+
90
@app.command()
def pull(
    model: str = typer.Argument(..., help="Hugging Face model id"),
    out: Optional[str] = typer.Option(None, "--out", help="MLX output path (defaults to cache/mlx/<model>)"),
    no_convert: bool = typer.Option(False, "--no-convert", help="Only download HF snapshot, skip MLX conversion"),
    quantize: bool = typer.Option(False, "--quantize", help="Quantize during conversion"),
    q_bits: Optional[int] = typer.Option(None, "--q-bits"),
    q_group_size: Optional[int] = typer.Option(None, "--q-group-size"),
    q_mode: Optional[str] = typer.Option(None, "--q-mode"),
    quant_predicate: Optional[str] = typer.Option(None, "--quant-predicate"),
    trust_remote_code: bool = typer.Option(False, "--trust-remote-code"),
):
    # Hand the model id to hf_pull with <cwd>/cache as the cache directory and
    # report the path it returns.
    # Kept as a comment: Typer would expose a docstring as the command's help text.
    cache_root = ensure_dir(project_root_from_cwd() / "cache")
    # An omitted (or empty) --out is forwarded as None.
    mlx_target = Path(out) if out else None
    conversion_options = {
        "convert": not no_convert,
        "mlx_path": mlx_target,
        "quantize": quantize,
        "q_bits": q_bits,
        "q_group_size": q_group_size,
        "q_mode": q_mode,
        "quant_predicate": quant_predicate,
        "trust_remote_code": trust_remote_code,
    }
    destination = hf_pull(model, cache_root, **conversion_options)
    console.print(f"[green]Pulled[/green] {model} -> {destination}")
118
+
119
+
120
@app.command()
def quantize(
    model_path: str = typer.Argument(...),
    to: str = typer.Option("q4"),
    out: str = typer.Option("models/quantized"),
):
    # Run quantize_stub on model_path; a relative --out is placed under the
    # project root (the current working directory).
    # Kept as a comment: Typer would expose a docstring as the command's help text.
    root = project_root_from_cwd()
    requested = Path(out)
    destination = requested if requested.is_absolute() else root / requested
    written = quantize_stub(Path(model_path), destination, to)
    console.print(f"[green]Quant stub wrote[/green] {written}")
132
+
133
+
134
+ data_app = typer.Typer(help="Dataset utilities")
135
+ app.add_typer(data_app, name="data")
136
+
137
+
138
@data_app.command("import")
def data_import(
    in_path: str = typer.Option(..., "--in"),
    fmt: str = typer.Option("sharegpt", "--format"),
    out_path: str = typer.Option(..., "--out"),
):
    # Convert an input file to the project's JSONL layout. Only the
    # "sharegpt" format (case-insensitive) is handled.
    # Kept as a comment: Typer would expose a docstring as the command's help text.
    if fmt.lower() != "sharegpt":
        raise typer.BadParameter(f"Unsupported format: {fmt}")

    root = project_root_from_cwd()
    source = Path(in_path)
    target = Path(out_path)
    # Only the output path is anchored to the project root.
    if not target.is_absolute():
        target = root / target
    written = import_sharegpt(source, target)
    console.print(f"[green]Wrote[/green] {written} rows -> {target}")
150
+
151
+
152
@data_app.command("split")
def data_split(
    in_path: str = typer.Option(..., "--in"),
    out_dir: str = typer.Option("data/sft", "--out-dir"),
    valid: float = typer.Option(0.02),
    test: float = typer.Option(0.02),
    seed: int = typer.Option(1337),
):
    # Split a JSONL file via split_jsonl, passing the validation/test
    # fractions and the seed straight through.
    # Kept as a comment: Typer would expose a docstring as the command's help text.
    root = project_root_from_cwd()
    source = Path(in_path)
    destination = Path(out_dir)
    destination = destination if destination.is_absolute() else root / destination
    summary = split_jsonl(source, destination, valid_frac=valid, test_frac=test, seed=seed)
    console.print(f"[green]Split[/green] -> {destination} {summary}")
167
+
168
+
169
@data_app.command("stats")
def data_stats(
    in_path: str = typer.Option(..., "--in"),
    kind: Optional[str] = typer.Option(None, "--kind", help="sft | prefs (auto if omitted)"),
    limit: Optional[int] = typer.Option(None, "--limit"),
):
    # Render the counters returned by analyze_jsonl as a metric/value table.
    # Kept as a comment: Typer would expose a docstring as the command's help text.
    root = project_root_from_cwd()
    source = Path(in_path)
    if not source.is_absolute():
        source = root / source
    stats = analyze_jsonl(source, kind=kind, limit=limit)

    def average_chars(role: str) -> str:
        # Character total divided by row count; the count is clamped to 1 so
        # an empty dataset yields 0.0 instead of dividing by zero.
        total = stats.get(f"{role}_chars", 0)
        count = max(1, stats.get(f"{role}_count", 0))
        return f"{total / count:.1f}"

    # Preference data has chosen/rejected completions; anything else is
    # reported as prompt/response data.
    roles = ("chosen", "rejected") if stats.get("kind") == "prefs" else ("response",)

    rows = [
        (metric, str(stats.get(metric)))
        for metric in ("kind", "rows", "empty_lines", "bad_json", "missing_prompt")
    ]
    rows.extend((f"missing_{role}", str(stats.get(f"missing_{role}"))) for role in roles)
    rows.extend((f"avg_{role}_chars", average_chars(role)) for role in ("prompt", *roles))

    report = Table(title="mlxsmith data stats")
    report.add_column("metric")
    report.add_column("value")
    for metric, value in rows:
        report.add_row(metric, value)
    console.print(report)
202
+
203
+
204
@data_app.command("validate")
def data_validate(
    in_path: str = typer.Option(..., "--in"),
    kind: Optional[str] = typer.Option(None, "--kind", help="sft | prefs (auto if omitted)"),
    limit: Optional[int] = typer.Option(None, "--limit"),
    strict: bool = typer.Option(True, "--strict/--no-strict"),
):
    # Report non-zero problem counters from analyze_jsonl. With --strict
    # (the default) any problem exits with status 1; otherwise the problems
    # are printed and the OK line still follows.
    # Kept as a comment: Typer would expose a docstring as the command's help text.
    root = project_root_from_cwd()
    source = Path(in_path)
    if not source.is_absolute():
        source = root / source
    stats = analyze_jsonl(source, kind=kind, limit=limit)

    checked = ["bad_json", "missing_prompt"]
    if stats.get("kind") == "prefs":
        checked += ["missing_chosen", "missing_rejected"]
    else:
        checked.append("missing_response")

    problems = [f"{counter}={stats.get(counter)}" for counter in checked if stats.get(counter, 0)]
    if problems:
        console.print(f"[yellow]Issues:[/yellow] {', '.join(problems)}")
        if strict:
            raise typer.Exit(code=1)
    console.print(f"[green]OK[/green] kind={stats.get('kind')} rows={stats.get('rows')}")
234
+
235
+
236
@data_app.command("presets")
def data_presets():
    # List the dataset presets returned by list_presets(), one row each.
    # Kept as a comment: Typer would expose a docstring as the command's help text.
    presets = list_presets()
    if not presets:
        console.print("[yellow]No presets defined[/yellow]")
        return

    detail_fields = ("dataset", "kind", "split", "license")
    listing = Table(title="mlxsmith data presets")
    for heading in ("name", *detail_fields):
        listing.add_column(heading)
    for preset_name, preset_cfg in presets.items():
        # Missing fields are shown as empty cells.
        details = [str(preset_cfg.get(field, "")) for field in detail_fields]
        listing.add_row(preset_name, *details)
    console.print(listing)
257
+
258
+
259
@data_app.command("pull")
def data_pull(
    dataset: Optional[str] = typer.Option(None, "--dataset", help="HF dataset name"),
    preset: Optional[str] = typer.Option(None, "--preset", help="Preset name"),
    split: str = typer.Option("train", "--split"),
    out_dir: str = typer.Option("data/sft", "--out-dir"),
    kind: str = typer.Option("sft", "--kind", help="Dataset kind: sft or prefs"),
    limit: Optional[int] = typer.Option(None, "--limit"),
    prompt_field: Optional[str] = typer.Option(None, "--prompt-field"),
    response_field: Optional[str] = typer.Option(None, "--response-field"),
    chosen_field: Optional[str] = typer.Option(None, "--chosen-field"),
    rejected_field: Optional[str] = typer.Option(None, "--rejected-field"),
    config: Optional[str] = typer.Option(None, "--config"),
    revision: Optional[str] = typer.Option(None, "--revision"),
):
    """Pull a Hugging Face dataset into the project, optionally from a preset.

    With --preset, values missing on the command line are filled from the
    preset. Explicit --dataset, --config, --revision and the *-field options
    take precedence over the preset. --kind and --split always carry a default
    value, so an explicit value cannot be told apart from the default and the
    preset's value wins when it defines one.

    Raises:
        typer.BadParameter: if neither --dataset nor a preset supplies a
            dataset name.
    """
    root = project_root_from_cwd()
    outd = Path(out_dir)
    if not outd.is_absolute():
        outd = root / outd

    license_name = None
    notes = None
    if preset:
        preset_cfg = resolve_preset(preset)
        dataset = dataset or preset_cfg.get("dataset")
        kind = preset_cfg.get("kind", kind)
        split = preset_cfg.get("split", split)
        # Previously the preset silently replaced an explicit --config or
        # --revision; the command line now wins, as it does for the fields.
        config = config or preset_cfg.get("config")
        revision = revision or preset_cfg.get("revision")
        prompt_field = prompt_field or preset_cfg.get("prompt_field")
        response_field = response_field or preset_cfg.get("response_field")
        chosen_field = chosen_field or preset_cfg.get("chosen_field")
        rejected_field = rejected_field or preset_cfg.get("rejected_field")
        license_name = preset_cfg.get("license")
        notes = preset_cfg.get("notes")

    if not dataset:
        raise typer.BadParameter("Missing --dataset (or use --preset)")

    stats = pull_hf_dataset(
        dataset,
        outd,
        split=split,
        limit=limit,
        prompt_field=prompt_field,
        response_field=response_field,
        chosen_field=chosen_field,
        rejected_field=rejected_field,
        config=config,
        revision=revision,
        kind=kind,
        license=license_name,
        notes=notes,
        preset=preset,
    )
    console.print(f"[green]Pulled[/green] {stats}")
312
+
313
+
314
@app.command()
def sft(
    config: str = typer.Option("mlxsmith.yaml", "-c", "--config", help="Config file path"),
    data: str = typer.Option("data/sft", "--data"),
    model: Optional[str] = typer.Option(None, "--model", help="Override model.id"),
    accel: Optional[str] = typer.Option(None, "--accel", help="Override accel.backend"),
    lr: Optional[float] = typer.Option(None, "--lr", help="Override train.lr (learning rate)"),
    iters: Optional[int] = typer.Option(None, "--iters", help="Override train.iters"),
    batch_size: Optional[int] = typer.Option(None, "--batch-size", help="Override train.batch_size"),
):
    # Resolve the config with the CLI overrides, then start run_sft on the
    # data directory under the project root and print the run directory.
    # Kept as a comment: Typer would expose a docstring as the command's help text.
    root = project_root_from_cwd()
    overrides = {
        "model_id": model,
        "accel_backend": accel,
        "lr": lr,
        "iters": iters,
        "batch_size": batch_size,
    }
    cfg = get_config(config_path=config, root=root, **overrides)
    outcome = run_sft(root, cfg, root / data, cfg.model.id, cfg.accel.backend)
    console.print(f"[bold]Run:[/bold] {outcome.run_dir}")
337
+
338
+
339
@app.command()
def pref(
    config: str = typer.Option("mlxsmith.yaml", "-c", "--config", help="Config file path"),
    data: str = typer.Option("data/prefs", "--data"),
    model: str = typer.Option(..., "--model", help="Base adapter or model path (e.g., runs/sft_0001/adapter)"),
    accel: Optional[str] = typer.Option(None, "--accel", help="Override accel.backend"),
    algo: Optional[str] = typer.Option(None, "--algo", help="Override pref.algo (dpo|orpo|grpo)"),
):
    # Resolve the config with the CLI overrides, then start run_pref from the
    # given base model path and print the run directory.
    # Kept as a comment: Typer would expose a docstring as the command's help text.
    root = project_root_from_cwd()
    cfg = get_config(config_path=config, root=root, accel_backend=accel, algo=algo)
    base_model = Path(model)
    outcome = run_pref(root, cfg, root / data, base_model, cfg.accel.backend)
    console.print(f"[bold]Run:[/bold] {outcome.run_dir}")
357
+
358
+
359
@app.command()
def rft(
    config: str = typer.Option("mlxsmith.yaml", "-c", "--config", help="Config file path"),
    env: str = typer.Option("envs/coding.yaml", "--env"),
    verifier: str = typer.Option("verifiers/regex.py", "--verifier"),
    model: str = typer.Option(..., "--model"),
    accel: Optional[str] = typer.Option(None, "--accel", help="Override accel.backend"),
    rollouts: Optional[int] = typer.Option(None, "--rollouts", help="Override rft.rollouts"),
):
    # Resolve the config with the CLI overrides, then start run_rft with the
    # env and verifier files located under the project root.
    # Kept as a comment: Typer would expose a docstring as the command's help text.
    root = project_root_from_cwd()
    cfg = get_config(config_path=config, root=root, accel_backend=accel, rollouts=rollouts)
    env_file = root / env
    verifier_file = root / verifier
    outcome = run_rft(root, cfg, env_file, verifier_file, Path(model), cfg.accel.backend)
    console.print(f"[bold]Run:[/bold] {outcome.run_dir}")
377
+
378
+
379
@app.command()
def pipeline(
    config: str = typer.Option("mlxsmith.yaml", "-c", "--config", help="Config file path"),
    model: Optional[str] = typer.Option(None, "--model", help="Override model.id"),
    data_sft: str = typer.Option("data/sft", "--data-sft"),
    data_pref: str = typer.Option("data/prefs", "--data-pref"),
    env: str = typer.Option("envs/coding.yaml", "--env"),
    verifier: str = typer.Option("verifiers/regex.py", "--verifier"),
    iterations: Optional[int] = typer.Option(None, "--iterations", help="Override rlm.iterations"),
    resume: bool = typer.Option(False, "--resume"),
    orchestrated: bool = typer.Option(False, "--orchestrated", help="Use multi-process orchestrator mode"),
):
    """Run SFT -> Pref -> RFT -> RLM in one command."""
    root = project_root_from_cwd()
    cfg = get_config(config_path=config, root=root, model_id=model, iterations=iterations)
    backend = cfg.accel.backend

    # Each stage starts from the adapter directory produced by the previous one.
    sft_run = run_sft(root, cfg, root / data_sft, cfg.model.id, backend)
    console.print(f"[bold]SFT[/bold] {sft_run.run_dir}")

    pref_run = run_pref(root, cfg, root / data_pref, sft_run.adapter_dir, backend)
    console.print(f"[bold]PREF[/bold] {pref_run.run_dir}")

    rft_run = run_rft(root, cfg, root / env, root / verifier, pref_run.adapter_dir, backend)
    console.print(f"[bold]RFT[/bold] {rft_run.run_dir}")

    # Final stage: the RLM loop, single-process unless --orchestrated is set.
    rlm_runner = run_rlm_orchestrated if orchestrated else run_rlm
    rlm_runner(root, cfg, model_spec=str(rft_run.adapter_dir), iterations=iterations, resume=resume)
412
+
413
+
414
@app.command()
def distill(
    data: str = typer.Option(..., "--data", help="JSONL with prompts"),
    teacher: str = typer.Option(..., "--teacher"),
    student: str = typer.Option(..., "--student"),
    mode: str = typer.Option("offline", "--mode", help="offline | opd"),
    config: str = typer.Option("mlxsmith.yaml", "-c", "--config", help="Config file path"),
    accel: Optional[str] = typer.Option(None, "--accel", help="Override accel.backend"),
    max_new_tokens: int = typer.Option(256, "--max-new-tokens"),
    temperature: float = typer.Option(0.7, "--temperature"),
):
    # Resolve the config, then start run_distill with the teacher/student
    # model names and sampling settings. The data path is passed as given,
    # not joined to the project root.
    # Kept as a comment: Typer would expose a docstring as the command's help text.
    root = project_root_from_cwd()
    cfg = get_config(config_path=config, root=root, accel_backend=accel)
    prompts_file = Path(data)
    outcome = run_distill(
        root,
        cfg,
        prompts_file,
        teacher_model=teacher,
        student_model=student,
        accel=cfg.accel.backend,
        mode=mode,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
    )
    console.print(f"[bold]Run:[/bold] {outcome.run_dir}")
439
+
440
+
441
@app.command()
def eval(
    suite: str = typer.Option("eval/suites/coding.yaml", "--suite"),
    model: str = typer.Option(..., "--model"),
):
    # Run run_eval for the suite file under the project root and print
    # whatever it returns. The function name doubles as the CLI command name,
    # hence the shadowed builtin.
    # Kept as a comment: Typer would expose a docstring as the command's help text.
    root = project_root_from_cwd()
    suite_file = root / suite
    result = run_eval(root, suite_file, Path(model))
    console.print(f"[bold]Eval:[/bold] {result}")
449
+
450
+
451
@app.command()
def serve(
    config: str = typer.Option("mlxsmith.yaml", "-c", "--config", help="Config file path"),
    model: str = typer.Option(..., "--model"),
    host: Optional[str] = typer.Option(None, "--host", help="Override serve.host"),
    port: Optional[int] = typer.Option(None, "--port", help="Override serve.port"),
    ui: Optional[bool] = typer.Option(None, "--ui", help="Override serve.ui (true/false)"),
):
    # Build the server app for the model and run it under uvicorn on the
    # host/port taken from the resolved config.
    # Kept as a comment: Typer would expose a docstring as the command's help text.
    root = project_root_from_cwd()
    cfg = get_config(config_path=config, root=root, host=host, port=port, ui=ui)
    bind_host, bind_port = cfg.serve.host, cfg.serve.port

    # Imported here so the server stack is only loaded by this command.
    import uvicorn
    from .server import create_app

    asgi_app = create_app(model, cfg)
    console.print(f"[green]Serving[/green] model={model} on http://{bind_host}:{bind_port}")
    uvicorn.run(asgi_app, host=bind_host, port=bind_port)
476
+
477
+
478
@app.command()
def bench(
    config: str = typer.Option("mlxsmith.yaml", "-c", "--config", help="Config file path"),
    model: Optional[str] = typer.Option(None, "--model", help="Override model.id"),
    accel: Optional[str] = typer.Option(None, "--accel", help="Override accel.backend"),
    prompt: str = typer.Option("Hello", "--prompt"),
    max_tokens: int = typer.Option(128, "--max-tokens"),
    reps: int = typer.Option(3, "--reps"),
    mode: str = typer.Option("inference", "--mode", help="inference | trainer | end_to_end"),
    steps: int = typer.Option(5, "--steps", help="trainer mode steps per rep"),
):
    # Resolve the config with the CLI overrides and pass the benchmark
    # settings to run_bench, printing whatever it returns.
    # Kept as a comment: Typer would expose a docstring as the command's help text.
    root = project_root_from_cwd()
    cfg = get_config(config_path=config, root=root, model_id=model, accel_backend=accel)
    settings = {
        "prompt": prompt,
        "max_tokens": max_tokens,
        "reps": reps,
        "mode": mode,
        "steps": steps,
    }
    result = run_bench(root, cfg, cfg.model.id, cfg.accel.backend, **settings)
    console.print(f"[bold]Bench:[/bold] {result}")
508
+
509
+
510
+ # Config subcommand
511
+ config_app = typer.Typer(help="Configuration management")
512
+ app.add_typer(config_app, name="config")
513
+
514
+
515
@config_app.command("show")
def config_show(
    config: str = typer.Option("mlxsmith.yaml", "-c", "--config", help="Config file path"),
    format: str = typer.Option("yaml", "-f", "--format", help="Output format: yaml, json, toml"),
    sources: bool = typer.Option(False, "--sources", help="Show value sources (default, env, file, cli)"),
    root: Optional[str] = typer.Option(None, "--root", help="Project root directory"),
):
    """Show merged configuration with all overrides applied."""
    project_root = Path(root) if root else project_root_from_cwd()
    cfg_path = resolve_config_path(config, root=project_root)
    # A missing config file is not an error here: the loaders receive None.
    existing_path = cfg_path if cfg_path.exists() else None

    if sources:
        # This branch never reads --format; the per-source view is printed
        # as show_merged_config returns it.
        from .config import get_config_sources

        cfg, srcs = get_config_sources(existing_path)
        console.print(show_merged_config(cfg, show_sources=True, sources=srcs))
        return

    cfg = load_config(existing_path)
    # Only dump_config is guarded: its ValueError is turned into a CLI usage
    # error, chained so the original cause stays in the traceback.
    try:
        output = dump_config(cfg, format=format)
    except ValueError as exc:
        raise typer.BadParameter(str(exc)) from exc
    console.print(output)
538
+
539
+
540
@config_app.command("init")
def config_init(
    path: str = typer.Argument("mlxsmith.yaml", help="Output config file path"),
    format: str = typer.Option("yaml", "-f", "--format", help="Output format: yaml, json, toml"),
):
    """Initialize a new configuration file with defaults."""
    target = Path(path)
    # Ask before clobbering an existing file; declining exits without writing.
    if target.exists() and not typer.confirm(f"File {path} already exists. Overwrite?"):
        raise typer.Exit()

    write_default_config(target, format=format)
    console.print(f"[green]Created config file:[/green] {target.resolve()}")
554
+
555
+
556
@config_app.command("validate")
def config_validate(
    config: str = typer.Argument(..., help="Config file path to validate"),
    root: Optional[str] = typer.Option(None, "--root", help="Project root directory"),
):
    """Validate a configuration file."""
    project_root = Path(root) if root else project_root_from_cwd()
    cfg_path = resolve_config_path(config, root=project_root)

    # Only loading decides validity, so only loading is guarded. Previously
    # the summary rendering sat inside the same try, and any error there was
    # reported as a validation failure.
    try:
        cfg = load_config(cfg_path, require=True)
    except Exception as exc:  # CLI boundary: report any load error and exit 1
        console.print(f"[red]✗ Configuration validation failed:[/red] {exc}")
        raise typer.Exit(code=1) from exc

    console.print("[green]✓ Configuration is valid[/green]")

    # Summary: for each dict-valued section, show its first three settings.
    table = Table(title="Configuration Summary")
    table.add_column("Section")
    table.add_column("Key Settings")
    for section, values in cfg.model_dump().items():
        if not isinstance(values, dict):
            continue
        settings = list(values.items())
        summary = ", ".join(f"{k}={v}" for k, v in settings[:3])
        if len(settings) > 3:
            summary += f" ... ({len(settings) - 3} more)"
        table.add_row(section, summary)
    console.print(table)
587
+
588
+
589
@config_app.command("env")
def config_env(
    prefix: str = typer.Option("MLXSMITH__", "--prefix", help="Environment variable prefix"),
):
    """Show available environment variables."""
    defaults = ProjectConfig().model_dump()

    console.print("\n[bold]Environment Variable Configuration[/bold]")
    console.print(f"Prefix: [cyan]{prefix}[/cyan]")
    console.print("Nested delimiter: [cyan]__[/cyan] (double underscore)\n")

    listing = Table(title=f"Available {prefix}* Environment Variables")
    for heading in ("Environment Variable", "Default Value", "Description"):
        listing.add_column(heading)

    for section_name, section_data in defaults.items():
        # Only dict-valued sections contribute variables.
        if not isinstance(section_data, dict):
            continue
        for key, value in section_data.items():
            variable = f"{prefix}{section_name.upper()}__{key.upper()}"
            shown = str(value)
            # Long defaults are cut to 40 characters including the ellipsis.
            if len(shown) > 40:
                shown = shown[:37] + "..."
            listing.add_row(variable, shown, f"{section_name}.{key}")

    console.print(listing)
    console.print("\n[dim]Example: MLXSMITH__MODEL__ID=custom/model mlxsmith sft[/dim]")
618
+
619
+
620
+ auth_app = typer.Typer(help="Hugging Face authentication")
621
+ app.add_typer(auth_app, name="auth")
622
+
623
+
624
@auth_app.command("login")
def auth_login(
    token: Optional[str] = typer.Option(None, "--token", envvar="HF_TOKEN"),
    # A True-by-default flag needs an explicit negative form; with only
    # "--validate" there was no way to switch validation off.
    validate: bool = typer.Option(True, "--validate/--no-validate", help="Validate token with HF API"),
):
    """Log in to Hugging Face with an access token.

    The token comes from --token or the HF_TOKEN environment variable; when
    neither is set it is prompted for with hidden input. Pass --no-validate
    to skip validating the token against the HF API.
    """
    if not token:
        token = typer.prompt("Hugging Face token", hide_input=True)
    status = hf_login(token, validate=validate)
    if status.user:
        console.print(f"[green]Logged in[/green] as {status.user}")
    else:
        # No user name on the status: fall back to the token hint, if any.
        hint = f" ({status.token_hint})" if status.token_hint else ""
        console.print(f"[green]Token saved[/green]{hint}")
    for warning in status.warnings:
        console.print(f"[yellow]{warning}[/yellow]")
639
+
640
+
641
@auth_app.command("status")
def auth_status(validate: bool = typer.Option(False, "--validate", help="Validate token with HF API")):
    # Report whether a token is present and, when the status carries one,
    # which user it belongs to; warnings from the status follow.
    # Kept as a comment: Typer would expose a docstring as the command's help text.
    status = get_auth_status(validate=validate)
    if not status.token_present:
        console.print("[yellow]No token found[/yellow]")
        return

    if status.user:
        headline = f"[green]Logged in[/green] as {status.user}"
    else:
        suffix = f" ({status.token_hint})" if status.token_hint else ""
        headline = f"[green]Token present[/green]{suffix}"
    console.print(headline)

    for message in status.warnings:
        console.print(f"[yellow]{message}[/yellow]")
654
+
655
+
656
@auth_app.command("logout")
def auth_logout():
    # Call hf_logout() and report according to its truthy/falsy result.
    # Kept as a comment: Typer would expose a docstring as the command's help text.
    removed = hf_logout()
    message = "[green]Logged out[/green]" if removed else "[yellow]No token found[/yellow]"
    console.print(message)
662
+
663
+
664
+ accel_app = typer.Typer(help="Acceleration utilities")
665
+ app.add_typer(accel_app, name="accel")
666
+
667
+ rlm_app = typer.Typer(help="Recursive Language Model (RLM) loop")
668
+ app.add_typer(rlm_app, name="rlm")
669
+
670
+ env_app = typer.Typer(help="Environment plugins")
671
+ app.add_typer(env_app, name="env")
672
+
673
+ adapter_app = typer.Typer(help="Adapter utilities")
674
+ app.add_typer(adapter_app, name="adapters")
675
+
676
+
677
@rlm_app.callback(invoke_without_command=True)
def rlm_callback(
    ctx: typer.Context,
    config: str = typer.Option("mlxsmith.yaml", "-c", "--config", help="Config file path"),
    model: Optional[str] = typer.Option(None, "--model", help="Override model.id"),
    iterations: Optional[int] = typer.Option(None, "--iterations", help="Override rlm.iterations"),
    resume: bool = typer.Option(False, "--resume"),
    orchestrated: bool = typer.Option(False, "--orchestrated", help="Use multi-process orchestrator mode"),
):
    # Run the RLM loop when `rlm` is invoked without a subcommand; with a
    # subcommand this callback does nothing and lets it run.
    # Kept as a comment: Typer would expose a docstring as the group's help text.
    if ctx.invoked_subcommand is None:
        root = project_root_from_cwd()
        cfg = get_config(config_path=config, root=root, model_id=model, iterations=iterations)
        loop_runner = run_rlm_orchestrated if orchestrated else run_rlm
        loop_runner(root, cfg, model_spec=model, iterations=iterations, resume=resume)
699
+
700
+
701
@rlm_app.command("status")
def rlm_status():
    # Show the RLM state loaded from runs/rlm_state.json under the project root.
    # Kept as a comment: Typer would expose a docstring as the command's help text.
    state = _load_rlm_state(project_root_from_cwd() / "runs" / "rlm_state.json")

    def score_text(score):
        # A score of 0 is a real value, so only None is shown as "n/a".
        return "n/a" if score is None else str(score)

    rows = (
        ("last_iteration", str(state.last_iteration)),
        ("current_adapter", state.current_adapter or "n/a"),
        ("best_adapter", state.best_adapter or "n/a"),
        ("best_score", score_text(state.best_score)),
        ("ema_score", score_text(state.ema_score)),
    )
    report = Table(title="mlxsmith rlm status")
    report.add_column("item")
    report.add_column("value")
    for label, cell in rows:
        report.add_row(label, cell)
    console.print(report)
715
+
716
+
717
@rlm_app.command("history")
def rlm_history(limit: int = typer.Option(10, "--limit")):
    # Print the last `limit` lines of runs/rlm_history.jsonl; a limit of zero
    # or less prints every line.
    # Kept as a comment: Typer would expose a docstring as the command's help text.
    history_file = project_root_from_cwd() / "runs" / "rlm_history.jsonl"
    if not history_file.exists():
        console.print("[yellow]No history found[/yellow]")
        return

    entries = history_file.read_text(encoding="utf-8").splitlines()
    if limit > 0:
        entries = entries[-limit:]
    for entry in entries:
        console.print(entry)
728
+
729
+
730
@accel_app.command("status")
def accel_status():
    # Tabulate each listed acceleration backend with its availability and notes.
    table = Table(title="mlxsmith accel status")
    for heading in ("backend", "available", "notes"):
        table.add_column(heading)

    for backend_name in ("none", "zmlx"):
        stats = get_backend(backend_name).stats()
        notes = stats.notes or {}
        # A backend is reported unavailable only when its notes contain "error".
        available = "no" if "error" in notes else "yes"
        table.add_row(stats.backend, available, json.dumps(notes))

    console.print(table)
742
+
743
+
744
@env_app.command("init")
def env_init(name: str = typer.Argument(..., help="Environment name")):
    # Create the environment via init_env_plugin and report the returned directory.
    created_dir = init_env_plugin(project_root_from_cwd(), name)
    console.print(f"[green]Initialized env[/green] {created_dir}")
749
+
750
+
751
@env_app.command("list")
def env_list(
    name: Optional[str] = typer.Argument(None, help="Filter by env name"),
    all_versions: bool = typer.Option(False, "--all", help="Show all versions"),
):
    # Show the registry packages as a name/version/description table.
    entries = list_registry_packages(project_root_from_cwd(), name=name, all_versions=all_versions)
    if not entries:
        console.print("[yellow]No registry entries found[/yellow]")
        return

    columns = ("name", "version", "description")
    table = Table(title="mlxsmith env registry")
    for column in columns:
        table.add_column(column)
    for entry in entries:
        # Missing or empty fields are rendered as blank cells.
        table.add_row(*[str(entry.get(column) or "") for column in columns])
    console.print(table)
772
+
773
+
774
@env_app.command("info")
def env_info(
    env: str = typer.Argument(..., help="Env name (optionally name@version or name==version)"),
    version: Optional[str] = typer.Option(None, "--version", help="Pin to a specific version"),
):
    # Look up one registry environment and show its manifest as a field/value table.
    pkg, manifest = registry_info_plugin(project_root_from_cwd(), env, version=version)

    details = (
        ("name", manifest.name),
        ("version", manifest.version),
        ("description", manifest.description or "n/a"),
        ("verifier", manifest.verifier or "n/a"),
        ("tasks", str(len(manifest.tasks or []))),
        ("token_env", "yes" if manifest.token_env else "no"),
        ("registry_path", str(pkg.get("path") or "")),
    )

    table = Table(title=f"mlxsmith env info: {manifest.name}")
    for heading in ("field", "value"):
        table.add_column(heading)
    for field_name, field_value in details:
        table.add_row(field_name, field_value)
    console.print(table)
792
+
793
+
794
@env_app.command("install")
def env_install(
    source: str = typer.Argument(..., help="Env dir, package path, or registry name"),
    version: Optional[str] = typer.Option(None, "--version", help="Pin to a specific version when using registry"),
):
    # Install the environment via install_env_plugin and report the returned directory.
    installed_dir = install_env_plugin(project_root_from_cwd(), source, version=version)
    console.print(f"[green]Installed env[/green] {installed_dir}")
802
+
803
+
804
@env_app.command("package")
def env_package(
    name: str = typer.Argument(..., help="Env name (directory under envs/)"),
    # The default is None, so the option is annotated Optional[str] like the
    # other optional path options in this file.
    out: Optional[str] = typer.Option(None, "--out", help="Output directory for package"),
):
    # Package the named environment via package_env_plugin and report the
    # path it returns. `out` is forwarded unchanged (None when --out is omitted).
    root = project_root_from_cwd()
    out_path = package_env_plugin(root, name, out_path=out)
    console.print(f"[green]Packaged env[/green] {out_path}")
812
+
813
+
814
@env_app.command("publish")
def env_publish(
    package: str = typer.Argument(..., help="Path to .tar.gz package"),
):
    # Publish the package via publish_env_plugin and report the returned destination.
    published_to = publish_env_plugin(project_root_from_cwd(), package)
    console.print(f"[green]Published env[/green] {published_to}")
821
+
822
+
823
@env_app.command("pull")
def env_pull(
    env: str = typer.Argument(..., help="Env name (optionally name@version or name==version)"),
    out: Optional[str] = typer.Option(None, "--out", help="Output directory"),
    version: Optional[str] = typer.Option(None, "--version", help="Pin to a specific version"),
    force: bool = typer.Option(False, "--force", help="Overwrite destination if it exists"),
):
    # Pull the environment via pull_env_plugin and report the returned destination.
    pulled_to = pull_env_plugin(
        project_root_from_cwd(),
        env,
        out_dir=out,
        version=version,
        force=force,
    )
    console.print(f"[green]Pulled env[/green] {pulled_to}")
833
+
834
+
835
@env_app.command("run")
def env_run(
    env: str = typer.Argument(..., help="Env name or path to env.yaml"),
    model: str = typer.Option(..., "--model"),
    accel: Optional[str] = typer.Option(None, "--accel"),
    verifier: Optional[str] = typer.Option(None, "--verifier"),
    config: str = typer.Option("mlxsmith.yaml", "-c", "--config", help="Config file path"),
):
    # Resolve the environment manifest and verifier, then launch run_rft on them.
    project_root = project_root_from_cwd()
    settings = get_config(config_path=config, root=project_root, accel_backend=accel)

    # A directory argument refers to the env.yaml inside it.
    manifest_path = resolve_env_path_plugin(project_root, env)
    if manifest_path.is_dir():
        manifest_path = manifest_path / "env.yaml"
    if not manifest_path.exists():
        raise typer.BadParameter(f"Env not found: {env}")
    manifest = load_env_manifest(manifest_path)

    # Verifier precedence: --verifier, then the manifest, then the regex default.
    verifier_file = Path(verifier or manifest.verifier or "verifiers/regex.py")
    if not verifier_file.is_absolute():
        verifier_file = project_root / verifier_file

    outcome = run_rft(
        project_root,
        settings,
        manifest_path,
        verifier_file,
        Path(model),
        settings.accel.backend,
    )
    console.print(f"[bold]Run:[/bold] {outcome.run_dir}")
857
+
858
+
859
@env_app.command("registry")
def env_registry():
    # Dump the raw contents of envs/registry.json, or say that there is none.
    registry_file = project_root_from_cwd() / "envs" / "registry.json"
    if registry_file.exists():
        console.print(registry_file.read_text(encoding="utf-8"))
    else:
        console.print("[yellow]No registry found[/yellow]")
867
+
868
+
869
@adapter_app.command("merge")
def adapters_merge(
    base: str = typer.Option(..., "--base", help="Base model id or path"),
    adapters: str = typer.Option(..., "--adapters", help="Comma-separated adapter paths"),
    out: str = typer.Option("models/merged_adapter", "--out"),
    weights: Optional[str] = typer.Option(None, "--weights", help="Comma-separated weights"),
):
    # Merge the listed adapters via merge_adapters and print its result.
    # Raises typer.BadParameter when --weights contains a non-numeric entry.
    root = project_root_from_cwd()

    # Blank entries (e.g. from a trailing comma) are ignored.
    adapter_paths = [Path(part.strip()) for part in adapters.split(",") if part.strip()]

    # A relative output path is anchored at the project root.
    out_path = Path(out)
    if not out_path.is_absolute():
        out_path = root / out_path

    merge_weights = None
    if weights:
        try:
            merge_weights = [float(token) for token in weights.split(",") if token.strip()]
        except ValueError as exc:
            # Report bad input as a CLI usage error instead of a raw traceback.
            raise typer.BadParameter(
                f"--weights must be comma-separated numbers, got {weights!r}"
            ) from exc

    result = merge_adapters(base, adapter_paths, out_path, weights=merge_weights)
    console.print(f"[green]Merged adapters[/green] {result}")
886
+
887
+
888
+ def _sample_env_yaml() -> str:
889
+ return """name: coding-sample
890
+ tasks:
891
+ - id: hello
892
+ prompt: |
893
+ Write a Python function `add(a, b)` that returns the sum.
894
+ gold: |
895
+ def add(a, b):
896
+ return a + b
897
+ verifier_kwargs:
898
+ pattern: "def\\s+add\\("
899
+ - id: pytest_task
900
+ prompt: |
901
+ Implement `mul(a,b)` in main.py. Tests provided.
902
+ tests: |
903
+ from main import mul
904
+ def test_mul():
905
+ assert mul(2,3) == 6
906
+ assert mul(-1,5) == -5
907
+ gold: |
908
+ def mul(a,b):
909
+ return a*b
910
+ """
911
+
912
+
913
+ def _sample_verifier_regex() -> str:
914
+ return """from mlxsmith.verifiers.regex import verify as _verify
915
+
916
+ def verify(prompt: str, completion: str, workdir: str, **kwargs):
917
+ return _verify(prompt, completion, workdir, **kwargs)
918
+ """
919
+
920
+
921
+ def _sample_verifier_pytest() -> str:
922
+ return """from mlxsmith.verifiers.pytest_verifier import verify as _verify
923
+
924
+ def verify(prompt: str, completion: str, workdir: str, **kwargs):
925
+ return _verify(prompt, completion, workdir, **kwargs)
926
+ """
927
+
928
+
929
+ def _sample_verifier_jsonschema() -> str:
930
+ return """from mlxsmith.verifiers.jsonschema import verify as _verify
931
+
932
+ def verify(prompt: str, completion: str, workdir: str, **kwargs):
933
+ return _verify(prompt, completion, workdir, **kwargs)
934
+ """
935
+
936
+
937
+ def _sample_eval_suite() -> str:
938
+ return """name: coding-eval-sample
939
+ notes: |
940
+ Minimal eval suite for smoke testing.
941
+ tasks:
942
+ - id: add
943
+ prompt: |
944
+ Write a Python function `add(a, b)` that returns the sum.
945
+ k: 2
946
+ max_new_tokens: 128
947
+ verifier: verifiers/regex.py
948
+ verifier_kwargs:
949
+ pattern: "def\\s+add\\("
950
+ """