mlxsmith 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mlxsmith/__init__.py +2 -0
- mlxsmith/accel/__init__.py +10 -0
- mlxsmith/accel/base.py +17 -0
- mlxsmith/accel/none.py +13 -0
- mlxsmith/accel/zmlx_backend.py +42 -0
- mlxsmith/adapters.py +46 -0
- mlxsmith/api/__init__.py +48 -0
- mlxsmith/api/handlers.py +1217 -0
- mlxsmith/api/schemas.py +436 -0
- mlxsmith/auth.py +88 -0
- mlxsmith/bench.py +102 -0
- mlxsmith/cli.py +950 -0
- mlxsmith/config.py +543 -0
- mlxsmith/config_models.py +261 -0
- mlxsmith/data.py +493 -0
- mlxsmith/envs/__init__.py +33 -0
- mlxsmith/envs/system.py +388 -0
- mlxsmith/envs/token_env.py +191 -0
- mlxsmith/eval.py +112 -0
- mlxsmith/infer.py +140 -0
- mlxsmith/llm/__init__.py +16 -0
- mlxsmith/llm/backend.py +126 -0
- mlxsmith/llm/interface.py +212 -0
- mlxsmith/llm/mlx_lm_backend.py +509 -0
- mlxsmith/llm/mock_backend.py +228 -0
- mlxsmith/llm/registry.py +12 -0
- mlxsmith/models.py +257 -0
- mlxsmith/orchestrator/__init__.py +25 -0
- mlxsmith/orchestrator/daemon.py +454 -0
- mlxsmith/orchestrator/inference_worker.py +496 -0
- mlxsmith/orchestrator/queue.py +355 -0
- mlxsmith/orchestrator/trainer_worker.py +437 -0
- mlxsmith/rlm/__init__.py +8 -0
- mlxsmith/rlm/corpus.py +74 -0
- mlxsmith/rlm/gating.py +90 -0
- mlxsmith/rlm/generate.py +249 -0
- mlxsmith/rlm/history.py +12 -0
- mlxsmith/rlm/inference.py +150 -0
- mlxsmith/rlm/loop.py +1297 -0
- mlxsmith/rlm/mutate.py +82 -0
- mlxsmith/rlm/trainer.py +73 -0
- mlxsmith/rlm/weights.py +263 -0
- mlxsmith/runs.py +44 -0
- mlxsmith/sdk/__init__.py +392 -0
- mlxsmith/sdk/future.py +486 -0
- mlxsmith/sdk/losses.py +262 -0
- mlxsmith/sdk/sampling_client.py +729 -0
- mlxsmith/sdk/training_client.py +676 -0
- mlxsmith/server.py +376 -0
- mlxsmith/train/__init__.py +0 -0
- mlxsmith/train/distill.py +279 -0
- mlxsmith/train/lora.py +280 -0
- mlxsmith/train/pref.py +180 -0
- mlxsmith/train/rft.py +458 -0
- mlxsmith/train/sft.py +151 -0
- mlxsmith/util.py +174 -0
- mlxsmith/verifiers/__init__.py +3 -0
- mlxsmith/verifiers/compose.py +109 -0
- mlxsmith/verifiers/docker_verifier.py +111 -0
- mlxsmith/verifiers/jsonschema.py +54 -0
- mlxsmith/verifiers/pytest_verifier.py +82 -0
- mlxsmith/verifiers/regex.py +15 -0
- mlxsmith/verifiers/types.py +10 -0
- mlxsmith-0.1.0.dist-info/METADATA +163 -0
- mlxsmith-0.1.0.dist-info/RECORD +69 -0
- mlxsmith-0.1.0.dist-info/WHEEL +5 -0
- mlxsmith-0.1.0.dist-info/entry_points.txt +2 -0
- mlxsmith-0.1.0.dist-info/licenses/LICENSE +21 -0
- mlxsmith-0.1.0.dist-info/top_level.txt +1 -0
mlxsmith/cli.py
ADDED
|
@@ -0,0 +1,950 @@
|
|
|
1
|
+
import json
|
|
2
|
+
from pathlib import Path
|
|
3
|
+
from typing import Optional
|
|
4
|
+
|
|
5
|
+
import typer
|
|
6
|
+
from rich.console import Console
|
|
7
|
+
from rich.table import Table
|
|
8
|
+
|
|
9
|
+
from .auth import get_status as get_auth_status, login as hf_login, logout as hf_logout
|
|
10
|
+
from .config import (
|
|
11
|
+
ProjectConfig,
|
|
12
|
+
dump_config,
|
|
13
|
+
get_config,
|
|
14
|
+
load_config,
|
|
15
|
+
resolve_config_path,
|
|
16
|
+
show_merged_config,
|
|
17
|
+
write_default_config,
|
|
18
|
+
)
|
|
19
|
+
from .data import import_sharegpt, split_jsonl, pull_hf_dataset, list_presets, resolve_preset, analyze_jsonl
|
|
20
|
+
from .models import hf_pull, quantize_stub
|
|
21
|
+
from .util import detect_system, ensure_dir
|
|
22
|
+
from .accel import get_backend
|
|
23
|
+
from .train.sft import run_sft
|
|
24
|
+
from .train.pref import run_pref
|
|
25
|
+
from .train.rft import run_rft
|
|
26
|
+
from .train.distill import run_distill
|
|
27
|
+
from .eval import run_eval
|
|
28
|
+
from .bench import run_bench
|
|
29
|
+
from .rlm import run_rlm, run_rlm_orchestrated
|
|
30
|
+
from .rlm.gating import load_state as _load_rlm_state
|
|
31
|
+
from .adapters import merge_adapters
|
|
32
|
+
from .envs import (
|
|
33
|
+
init_env as init_env_plugin,
|
|
34
|
+
install_env as install_env_plugin,
|
|
35
|
+
list_registry_packages,
|
|
36
|
+
package_env as package_env_plugin,
|
|
37
|
+
pull_env as pull_env_plugin,
|
|
38
|
+
publish_env as publish_env_plugin,
|
|
39
|
+
registry_info as registry_info_plugin,
|
|
40
|
+
resolve_env_path as resolve_env_path_plugin,
|
|
41
|
+
load_manifest as load_env_manifest,
|
|
42
|
+
)
|
|
43
|
+
|
|
44
|
+
# Root Typer application: every command and sub-application in this module
# registers on it. Shell-completion installation commands are disabled.
app = typer.Typer(
    help="mlxsmith — MLX fine-tuning + OpenAI-compatible serving (SFT stable; preference/RL experimental)",
    add_completion=False,
)
# Shared Rich console used by all commands for styled terminal output.
console = Console()
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def project_root_from_cwd() -> Path:
    """Return the current working directory, used as the project root.

    Commands in this module join relative paths (data dirs, env files,
    output locations) onto the value returned here.
    """
    cwd = Path.cwd()
    return cwd
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
@app.command()
def init(
    path: str = typer.Argument(..., help="Project directory to create"),
    force: bool = typer.Option(False, "--force", help="Overwrite existing sample env/verifier/eval-suite files"),
):
    # Scaffold a project directory: the standard subfolders, a default
    # mlxsmith.yaml, and sample env / verifier / eval-suite files.
    #
    # Re-running `init` on an existing project is safe: like the config file,
    # the sample files are only written when missing (they are starting points
    # users are expected to edit), unless --force is given.
    p = Path(path)
    p.mkdir(parents=True, exist_ok=True)
    for d in ["data/sft", "data/prefs", "models", "envs", "verifiers", "eval/suites", "runs", "cache", "bench"]:
        (p / d).mkdir(parents=True, exist_ok=True)

    cfg_path = p / "mlxsmith.yaml"
    if not cfg_path.exists():
        write_default_config(cfg_path)

    # Target file -> function producing its sample contents (written in order).
    samples = {
        p / "envs" / "coding.yaml": _sample_env_yaml,
        p / "verifiers" / "regex.py": _sample_verifier_regex,
        p / "verifiers" / "pytest.py": _sample_verifier_pytest,
        p / "verifiers" / "jsonschema.py": _sample_verifier_jsonschema,
        p / "eval" / "suites" / "coding.yaml": _sample_eval_suite,
    }
    for target, render in samples.items():
        if force or not target.exists():
            target.write_text(render(), encoding="utf-8")

    console.print(f"[green]Initialized[/green] {p.resolve()}")
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
@app.command()
def doctor():
    # Print a two-column table describing the host: Python build, OS/machine,
    # CPU count, and the has_metal / has_mlx / has_zmlx flags from detect_system().
    sysinfo = detect_system()
    report_rows = [
        ("python", sysinfo.python),
        ("python_arch", sysinfo.python_arch),
        ("platform", sysinfo.platform),
        ("macos_version", sysinfo.macos_version or "n/a"),
        ("machine", sysinfo.machine),
        ("cpu_count", str(sysinfo.cpu_count)),
        ("metal", str(sysinfo.has_metal)),
        ("mlx", f"{sysinfo.has_mlx} {sysinfo.mlx_version or ''}".strip()),
        ("zmlx", str(sysinfo.has_zmlx)),
    ]

    report = Table(title="mlxsmith doctor")
    for header in ("item", "value"):
        report.add_column(header)
    for label, cell in report_rows:
        report.add_row(label, cell)
    console.print(report)
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
@app.command()
def pull(
    model: str = typer.Argument(..., help="Hugging Face model id"),
    out: Optional[str] = typer.Option(None, "--out", help="MLX output path (defaults to cache/mlx/<model>)"),
    no_convert: bool = typer.Option(False, "--no-convert", help="Only download HF snapshot, skip MLX conversion"),
    quantize: bool = typer.Option(False, "--quantize", help="Quantize during conversion"),
    q_bits: Optional[int] = typer.Option(None, "--q-bits"),
    q_group_size: Optional[int] = typer.Option(None, "--q-group-size"),
    q_mode: Optional[str] = typer.Option(None, "--q-mode"),
    quant_predicate: Optional[str] = typer.Option(None, "--quant-predicate"),
    trust_remote_code: bool = typer.Option(False, "--trust-remote-code"),
):
    # Fetch `model` via hf_pull into the project's cache/ directory and, unless
    # --no-convert is passed, request MLX conversion (optionally quantized).
    cache_root = ensure_dir(project_root_from_cwd() / "cache")
    quant_opts = {
        "quantize": quantize,
        "q_bits": q_bits,
        "q_group_size": q_group_size,
        "q_mode": q_mode,
        "quant_predicate": quant_predicate,
    }
    destination = hf_pull(
        model,
        cache_root,
        convert=not no_convert,
        # An empty/absent --out leaves the output location to hf_pull.
        mlx_path=None if not out else Path(out),
        trust_remote_code=trust_remote_code,
        **quant_opts,
    )
    console.print(f"[green]Pulled[/green] {model} -> {destination}")
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
@app.command()
def quantize(
    model_path: str = typer.Argument(...),
    to: str = typer.Option("q4"),
    out: str = typer.Option("models/quantized"),
):
    # Run the quantization stub on `model_path`; a relative --out is placed
    # under the project root.
    root = project_root_from_cwd()
    target = Path(out)
    destination = target if target.is_absolute() else root / target
    written = quantize_stub(Path(model_path), destination, to)
    console.print(f"[green]Quant stub wrote[/green] {written}")
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
# `mlxsmith data ...` sub-application; groups the dataset commands below
# (import, split, stats, validate, presets, pull).
data_app = typer.Typer(help="Dataset utilities")
app.add_typer(data_app, name="data")
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
@data_app.command("import")
def data_import(
    in_path: str = typer.Option(..., "--in"),
    fmt: str = typer.Option("sharegpt", "--format"),
    out_path: str = typer.Option(..., "--out"),
):
    # Convert an external dataset file into the project's JSONL layout.
    # Only the ShareGPT format (case-insensitive) is supported; a relative
    # --out is resolved against the project root.
    root = project_root_from_cwd()
    source = Path(in_path)
    destination = Path(out_path)
    if not destination.is_absolute():
        destination = root / destination

    if fmt.lower() != "sharegpt":
        raise typer.BadParameter(f"Unsupported format: {fmt}")

    written = import_sharegpt(source, destination)
    console.print(f"[green]Wrote[/green] {written} rows -> {destination}")
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
@data_app.command("split")
def data_split(
    in_path: str = typer.Option(..., "--in"),
    out_dir: str = typer.Option("data/sft", "--out-dir"),
    valid: float = typer.Option(0.02),
    test: float = typer.Option(0.02),
    seed: int = typer.Option(1337),
):
    # Split one JSONL file into train/valid/test files under --out-dir using
    # the given fractions and RNG seed; prints the stats split_jsonl returns.
    root = project_root_from_cwd()
    source = Path(in_path)
    target_dir = Path(out_dir)
    target_dir = target_dir if target_dir.is_absolute() else root / target_dir
    split_summary = split_jsonl(source, target_dir, valid_frac=valid, test_frac=test, seed=seed)
    console.print(f"[green]Split[/green] -> {target_dir} {split_summary}")
|
|
167
|
+
|
|
168
|
+
|
|
169
|
+
@data_app.command("stats")
def data_stats(
    in_path: str = typer.Option(..., "--in"),
    kind: Optional[str] = typer.Option(None, "--kind", help="sft | prefs (auto if omitted)"),
    limit: Optional[int] = typer.Option(None, "--limit"),
):
    # Analyze a JSONL dataset with analyze_jsonl and print a metric/value table.
    # Row set depends on the detected kind: "prefs" reports chosen/rejected
    # columns, anything else reports response columns.
    root = project_root_from_cwd()
    source = Path(in_path)
    if not source.is_absolute():
        source = root / source
    stats = analyze_jsonl(source, kind=kind, limit=limit)

    def _mean(total_key: str, count_key: str) -> str:
        # Average chars per item; a zero/missing count is clamped to 1 to avoid ZeroDivisionError.
        return f"{stats.get(total_key, 0) / max(1, stats.get(count_key, 0)):.1f}"

    common = ("kind", "rows", "empty_lines", "bad_json", "missing_prompt")
    rows = [(metric, str(stats.get(metric))) for metric in common]
    if stats.get("kind") == "prefs":
        rows += [
            ("missing_chosen", str(stats.get("missing_chosen"))),
            ("missing_rejected", str(stats.get("missing_rejected"))),
            ("avg_prompt_chars", _mean("prompt_chars", "prompt_count")),
            ("avg_chosen_chars", _mean("chosen_chars", "chosen_count")),
            ("avg_rejected_chars", _mean("rejected_chars", "rejected_count")),
        ]
    else:
        rows += [
            ("missing_response", str(stats.get("missing_response"))),
            ("avg_prompt_chars", _mean("prompt_chars", "prompt_count")),
            ("avg_response_chars", _mean("response_chars", "response_count")),
        ]

    table = Table(title="mlxsmith data stats")
    for header in ("metric", "value"):
        table.add_column(header)
    for metric, value in rows:
        table.add_row(metric, value)
    console.print(table)
|
|
202
|
+
|
|
203
|
+
|
|
204
|
+
@data_app.command("validate")
def data_validate(
    in_path: str = typer.Option(..., "--in"),
    kind: Optional[str] = typer.Option(None, "--kind", help="sft | prefs (auto if omitted)"),
    limit: Optional[int] = typer.Option(None, "--limit"),
    strict: bool = typer.Option(True, "--strict/--no-strict"),
):
    # Check a JSONL dataset for malformed or incomplete rows. Any non-zero
    # problem counter is reported; with --strict (the default) that exits 1,
    # with --no-strict it only warns and still prints the OK summary.
    root = project_root_from_cwd()
    source = Path(in_path)
    if not source.is_absolute():
        source = root / source
    stats = analyze_jsonl(source, kind=kind, limit=limit)

    checked = ["bad_json", "missing_prompt"]
    if stats.get("kind") == "prefs":
        checked += ["missing_chosen", "missing_rejected"]
    else:
        checked.append("missing_response")
    issues = [f"{counter}={stats.get(counter)}" for counter in checked if stats.get(counter, 0)]

    if issues:
        console.print(f"[yellow]Issues:[/yellow] {', '.join(issues)}")
        if strict:
            raise typer.Exit(code=1)
    console.print(f"[green]OK[/green] kind={stats.get('kind')} rows={stats.get('rows')}")
|
|
234
|
+
|
|
235
|
+
|
|
236
|
+
@data_app.command("presets")
def data_presets():
    # List the dataset presets known to list_presets() as a table; each preset
    # is a mapping whose missing fields are shown as empty cells.
    presets = list_presets()
    if not presets:
        console.print("[yellow]No presets defined[/yellow]")
        return

    fields = ("dataset", "kind", "split", "license")
    table = Table(title="mlxsmith data presets")
    for header in ("name", *fields):
        table.add_column(header)
    for preset_name, preset_cfg in presets.items():
        table.add_row(preset_name, *(str(preset_cfg.get(field, "")) for field in fields))
    console.print(table)
|
|
257
|
+
|
|
258
|
+
|
|
259
|
+
@data_app.command("pull")
def data_pull(
    dataset: Optional[str] = typer.Option(None, "--dataset", help="HF dataset name"),
    preset: Optional[str] = typer.Option(None, "--preset", help="Preset name"),
    split: Optional[str] = typer.Option(None, "--split", help="Dataset split (default: preset's split, else train)"),
    out_dir: str = typer.Option("data/sft", "--out-dir"),
    kind: Optional[str] = typer.Option(
        None, "--kind", help="Dataset kind: sft or prefs (default: preset's kind, else sft)"
    ),
    limit: Optional[int] = typer.Option(None, "--limit"),
    prompt_field: Optional[str] = typer.Option(None, "--prompt-field"),
    response_field: Optional[str] = typer.Option(None, "--response-field"),
    chosen_field: Optional[str] = typer.Option(None, "--chosen-field"),
    rejected_field: Optional[str] = typer.Option(None, "--rejected-field"),
    config: Optional[str] = typer.Option(None, "--config"),
    revision: Optional[str] = typer.Option(None, "--revision"),
):
    # Pull a Hugging Face dataset (named directly or through a preset) into
    # `out_dir` via pull_hf_dataset and print the returned stats.
    #
    # Precedence for every setting is: explicit CLI value > preset value >
    # built-in default ("sft" for kind, "train" for split). --kind and --split
    # default to None so "not given" can be told apart from "given"; otherwise
    # a preset would silently override a value the user typed explicitly.
    root = project_root_from_cwd()
    outd = Path(out_dir)
    if not outd.is_absolute():
        outd = root / outd

    preset_cfg = resolve_preset(preset) if preset else {}
    dataset = dataset or preset_cfg.get("dataset")
    kind = kind or preset_cfg.get("kind") or "sft"
    split = split or preset_cfg.get("split") or "train"
    config = config or preset_cfg.get("config")
    revision = revision or preset_cfg.get("revision")
    prompt_field = prompt_field or preset_cfg.get("prompt_field")
    response_field = response_field or preset_cfg.get("response_field")
    chosen_field = chosen_field or preset_cfg.get("chosen_field")
    rejected_field = rejected_field or preset_cfg.get("rejected_field")
    # License/notes metadata only ever comes from a preset.
    license_name = preset_cfg.get("license")
    notes = preset_cfg.get("notes")

    if not dataset:
        raise typer.BadParameter("Missing --dataset (or use --preset)")

    stats = pull_hf_dataset(
        dataset,
        outd,
        split=split,
        limit=limit,
        prompt_field=prompt_field,
        response_field=response_field,
        chosen_field=chosen_field,
        rejected_field=rejected_field,
        config=config,
        revision=revision,
        kind=kind,
        license=license_name,
        notes=notes,
        preset=preset,
    )
    console.print(f"[green]Pulled[/green] {stats}")
|
|
312
|
+
|
|
313
|
+
|
|
314
|
+
@app.command()
def sft(
    config: str = typer.Option("mlxsmith.yaml", "-c", "--config", help="Config file path"),
    data: str = typer.Option("data/sft", "--data"),
    model: Optional[str] = typer.Option(None, "--model", help="Override model.id"),
    accel: Optional[str] = typer.Option(None, "--accel", help="Override accel.backend"),
    lr: Optional[float] = typer.Option(None, "--lr", help="Override train.lr (learning rate)"),
    iters: Optional[int] = typer.Option(None, "--iters", help="Override train.iters"),
    batch_size: Optional[int] = typer.Option(None, "--batch-size", help="Override train.batch_size"),
):
    # Supervised fine-tuning: merge CLI overrides into the config, then train
    # on the data directory (relative paths are joined onto the project root).
    root = project_root_from_cwd()
    overrides = dict(model_id=model, accel_backend=accel, lr=lr, iters=iters, batch_size=batch_size)
    cfg = get_config(config_path=config, root=root, **overrides)
    result = run_sft(root, cfg, root / data, cfg.model.id, cfg.accel.backend)
    console.print(f"[bold]Run:[/bold] {result.run_dir}")
|
|
337
|
+
|
|
338
|
+
|
|
339
|
+
@app.command()
def pref(
    config: str = typer.Option("mlxsmith.yaml", "-c", "--config", help="Config file path"),
    data: str = typer.Option("data/prefs", "--data"),
    model: str = typer.Option(..., "--model", help="Base adapter or model path (e.g., runs/sft_0001/adapter)"),
    accel: Optional[str] = typer.Option(None, "--accel", help="Override accel.backend"),
    algo: Optional[str] = typer.Option(None, "--algo", help="Override pref.algo (dpo|orpo|grpo)"),
):
    # Preference tuning starting from `model` (an adapter or model path) on
    # the preference data directory under the project root.
    root = project_root_from_cwd()
    cfg = get_config(config_path=config, root=root, accel_backend=accel, algo=algo)
    base = Path(model)
    outcome = run_pref(root, cfg, root / data, base, cfg.accel.backend)
    console.print(f"[bold]Run:[/bold] {outcome.run_dir}")
|
|
357
|
+
|
|
358
|
+
|
|
359
|
+
@app.command()
def rft(
    config: str = typer.Option("mlxsmith.yaml", "-c", "--config", help="Config file path"),
    env: str = typer.Option("envs/coding.yaml", "--env"),
    verifier: str = typer.Option("verifiers/regex.py", "--verifier"),
    model: str = typer.Option(..., "--model"),
    accel: Optional[str] = typer.Option(None, "--accel", help="Override accel.backend"),
    rollouts: Optional[int] = typer.Option(None, "--rollouts", help="Override rft.rollouts"),
):
    # Reinforcement fine-tuning of `model` against an env spec scored by a
    # verifier script; env and verifier paths are joined onto the project root.
    root = project_root_from_cwd()
    cfg = get_config(config_path=config, root=root, accel_backend=accel, rollouts=rollouts)
    env_path, verifier_path = root / env, root / verifier
    outcome = run_rft(root, cfg, env_path, verifier_path, Path(model), cfg.accel.backend)
    console.print(f"[bold]Run:[/bold] {outcome.run_dir}")
|
|
377
|
+
|
|
378
|
+
|
|
379
|
+
@app.command()
def pipeline(
    config: str = typer.Option("mlxsmith.yaml", "-c", "--config", help="Config file path"),
    model: Optional[str] = typer.Option(None, "--model", help="Override model.id"),
    data_sft: str = typer.Option("data/sft", "--data-sft"),
    data_pref: str = typer.Option("data/prefs", "--data-pref"),
    env: str = typer.Option("envs/coding.yaml", "--env"),
    verifier: str = typer.Option("verifiers/regex.py", "--verifier"),
    iterations: Optional[int] = typer.Option(None, "--iterations", help="Override rlm.iterations"),
    resume: bool = typer.Option(False, "--resume"),
    orchestrated: bool = typer.Option(False, "--orchestrated", help="Use multi-process orchestrator mode"),
):
    """Run SFT -> Pref -> RFT -> RLM in one command."""
    # Each stage starts from the adapter directory produced by the previous one.
    root = project_root_from_cwd()
    cfg = get_config(config_path=config, root=root, model_id=model, iterations=iterations)

    # Stage 1 — SFT from the configured base model.
    sft_run = run_sft(root, cfg, root / data_sft, cfg.model.id, cfg.accel.backend)
    console.print(f"[bold]SFT[/bold] {sft_run.run_dir}")

    # Stage 2 — preference tuning (DPO/ORPO) on top of the SFT adapter.
    pref_run = run_pref(root, cfg, root / data_pref, sft_run.adapter_dir, cfg.accel.backend)
    console.print(f"[bold]PREF[/bold] {pref_run.run_dir}")

    # Stage 3 — RFT (GRPO) on top of the preference adapter.
    rft_run = run_rft(root, cfg, root / env, root / verifier, pref_run.adapter_dir, cfg.accel.backend)
    console.print(f"[bold]RFT[/bold] {rft_run.run_dir}")

    # Stage 4 — RLM loop, in-process or via the multi-process orchestrator.
    rlm_runner = run_rlm_orchestrated if orchestrated else run_rlm
    rlm_runner(root, cfg, model_spec=str(rft_run.adapter_dir), iterations=iterations, resume=resume)
|
|
412
|
+
|
|
413
|
+
|
|
414
|
+
@app.command()
def distill(
    data: str = typer.Option(..., "--data", help="JSONL with prompts"),
    teacher: str = typer.Option(..., "--teacher"),
    student: str = typer.Option(..., "--student"),
    mode: str = typer.Option("offline", "--mode", help="offline | opd"),
    config: str = typer.Option("mlxsmith.yaml", "-c", "--config", help="Config file path"),
    accel: Optional[str] = typer.Option(None, "--accel", help="Override accel.backend"),
    max_new_tokens: int = typer.Option(256, "--max-new-tokens"),
    temperature: float = typer.Option(0.7, "--temperature"),
):
    # Distill a teacher model into a student over the prompts in --data.
    root = project_root_from_cwd()
    cfg = get_config(config_path=config, root=root, accel_backend=accel)
    generation = {"max_new_tokens": max_new_tokens, "temperature": temperature}
    result = run_distill(
        root,
        cfg,
        Path(data),
        teacher_model=teacher,
        student_model=student,
        accel=cfg.accel.backend,
        mode=mode,
        **generation,
    )
    console.print(f"[bold]Run:[/bold] {result.run_dir}")
|
|
439
|
+
|
|
440
|
+
|
|
441
|
+
@app.command()
def eval(
    suite: str = typer.Option("eval/suites/coding.yaml", "--suite"),
    model: str = typer.Option(..., "--model"),
):
    # Evaluate `model` against an eval suite file under the project root.
    # NOTE: this function shadows the builtin `eval` at module scope; the name
    # is kept because Typer derives the `eval` command name from it.
    root = project_root_from_cwd()
    suite_path = root / suite
    report = run_eval(root, suite_path, Path(model))
    console.print(f"[bold]Eval:[/bold] {report}")
|
|
449
|
+
|
|
450
|
+
|
|
451
|
+
@app.command()
def serve(
    config: str = typer.Option("mlxsmith.yaml", "-c", "--config", help="Config file path"),
    model: str = typer.Option(..., "--model"),
    host: Optional[str] = typer.Option(None, "--host", help="Override serve.host"),
    port: Optional[int] = typer.Option(None, "--port", help="Override serve.port"),
    # An explicit on/off pair is required for a tri-state override: with only
    # "--ui" the value could be True or None but never False, so the UI could
    # not be disabled from the CLI. Omitting both keeps the config value.
    ui: Optional[bool] = typer.Option(None, "--ui/--no-ui", help="Override serve.ui (--ui / --no-ui)"),
):
    # Serve `model` over HTTP with uvicorn, using host/port from the merged
    # config (CLI flags override the config file).
    root = project_root_from_cwd()
    cfg = get_config(
        config_path=config,
        root=root,
        host=host,
        port=port,
        ui=ui,
    )
    h = cfg.serve.host
    p = cfg.serve.port

    # Deferred imports — presumably so other commands don't load the server stack.
    import uvicorn
    from .server import create_app

    server_app = create_app(model, cfg)
    console.print(f"[green]Serving[/green] model={model} on http://{h}:{p}")
    uvicorn.run(server_app, host=h, port=p)
|
|
476
|
+
|
|
477
|
+
|
|
478
|
+
@app.command()
def bench(
    config: str = typer.Option("mlxsmith.yaml", "-c", "--config", help="Config file path"),
    model: Optional[str] = typer.Option(None, "--model", help="Override model.id"),
    accel: Optional[str] = typer.Option(None, "--accel", help="Override accel.backend"),
    prompt: str = typer.Option("Hello", "--prompt"),
    max_tokens: int = typer.Option(128, "--max-tokens"),
    reps: int = typer.Option(3, "--reps"),
    mode: str = typer.Option("inference", "--mode", help="inference | trainer | end_to_end"),
    steps: int = typer.Option(5, "--steps", help="trainer mode steps per rep"),
):
    # Benchmark the configured model/backend in the selected mode and print
    # whatever summary run_bench returns.
    root = project_root_from_cwd()
    cfg = get_config(config_path=config, root=root, model_id=model, accel_backend=accel)
    bench_opts = dict(prompt=prompt, max_tokens=max_tokens, reps=reps, mode=mode, steps=steps)
    summary = run_bench(root, cfg, cfg.model.id, cfg.accel.backend, **bench_opts)
    console.print(f"[bold]Bench:[/bold] {summary}")
|
|
508
|
+
|
|
509
|
+
|
|
510
|
+
# Config subcommand
# `mlxsmith config ...` sub-application; groups the show / init / validate /
# env commands defined below.
config_app = typer.Typer(help="Configuration management")
app.add_typer(config_app, name="config")
|
|
513
|
+
|
|
514
|
+
|
|
515
|
+
@config_app.command("show")
def config_show(
    config: str = typer.Option("mlxsmith.yaml", "-c", "--config", help="Config file path"),
    format: str = typer.Option("yaml", "-f", "--format", help="Output format: yaml, json, toml"),
    sources: bool = typer.Option(False, "--sources", help="Show value sources (default, env, file, cli)"),
    root: Optional[str] = typer.Option(None, "--root", help="Project root directory"),
):
    """Show merged configuration with all overrides applied."""
    # A config path that doesn't exist is passed to the loaders as None.
    # With --sources the annotated view is printed and --format is not used.
    project_root = Path(root) if root else project_root_from_cwd()
    cfg_path = resolve_config_path(config, root=project_root)
    existing_path = cfg_path if cfg_path.exists() else None

    if not sources:
        cfg = load_config(existing_path)
        try:
            console.print(dump_config(cfg, format=format))
        except ValueError as e:
            # Presumably raised for an unsupported --format; report as a usage error.
            raise typer.BadParameter(str(e))
        return

    from .config import get_config_sources

    cfg, srcs = get_config_sources(existing_path)
    console.print(show_merged_config(cfg, show_sources=True, sources=srcs))
|
|
538
|
+
|
|
539
|
+
|
|
540
|
+
@config_app.command("init")
def config_init(
    path: str = typer.Argument("mlxsmith.yaml", help="Output config file path"),
    format: str = typer.Option("yaml", "-f", "--format", help="Output format: yaml, json, toml"),
):
    """Initialize a new configuration file with defaults."""
    # An existing file is only replaced after interactive confirmation;
    # declining exits without writing anything.
    out_path = Path(path)
    if out_path.exists() and not typer.confirm(f"File {path} already exists. Overwrite?"):
        raise typer.Exit()
    write_default_config(out_path, format=format)
    console.print(f"[green]Created config file:[/green] {out_path.resolve()}")
|
|
554
|
+
|
|
555
|
+
|
|
556
|
+
@config_app.command("validate")
def config_validate(
    config: str = typer.Argument(..., help="Config file path to validate"),
    root: Optional[str] = typer.Option(None, "--root", help="Project root directory"),
):
    """Validate a configuration file."""
    # Load the file (require=True, so a missing file is an error), then print
    # a per-section summary showing the first three keys of each dict section.
    project_root = Path(root) if root else project_root_from_cwd()
    cfg_path = resolve_config_path(config, root=project_root)

    # Only loading belongs in the try: a failure while rendering the summary
    # below is a bug, not an invalid config, and must not be reported as one.
    # Loading can fail in many ways (missing file, parse error, schema error),
    # hence the broad catch at this CLI boundary.
    try:
        cfg = load_config(cfg_path, require=True)
    except Exception as e:
        console.print(f"[red]✗ Configuration validation failed:[/red] {e}")
        raise typer.Exit(code=1) from e

    console.print("[green]✓ Configuration is valid[/green]")

    table = Table(title="Configuration Summary")
    table.add_column("Section")
    table.add_column("Key Settings")
    for section, values in cfg.model_dump().items():
        if not isinstance(values, dict):
            continue
        summary = ", ".join(f"{k}={v}" for k, v in list(values.items())[:3])
        if len(values) > 3:
            summary += f" ... ({len(values) - 3} more)"
        table.add_row(section, summary)
    console.print(table)
|
|
587
|
+
|
|
588
|
+
|
|
589
|
+
@config_app.command("env")
def config_env(
    prefix: str = typer.Option("MLXSMITH__", "--prefix", help="Environment variable prefix"),
):
    """Show available environment variables."""
    # One row per field of each dict-valued section of a default ProjectConfig;
    # non-dict top-level entries are skipped. Long defaults are truncated to 40 chars.
    defaults = ProjectConfig().model_dump()

    console.print("\n[bold]Environment Variable Configuration[/bold]")
    console.print(f"Prefix: [cyan]{prefix}[/cyan]")
    console.print("Nested delimiter: [cyan]__[/cyan] (double underscore)\n")

    table = Table(title=f"Available {prefix}* Environment Variables")
    for header in ("Environment Variable", "Default Value", "Description"):
        table.add_column(header)

    for section_name, section_data in defaults.items():
        if not isinstance(section_data, dict):
            continue
        for key, value in section_data.items():
            env_var = f"{prefix}{section_name.upper()}__{key.upper()}"
            value_str = str(value)  # str(None) is already "None"
            if len(value_str) > 40:
                value_str = value_str[:37] + "..."
            table.add_row(env_var, value_str, f"{section_name}.{key}")

    console.print(table)
    # Build the example from the active prefix so it matches the table above
    # (it previously always showed MLXSMITH__ regardless of --prefix).
    console.print(f"\n[dim]Example: {prefix}MODEL__ID=custom/model mlxsmith sft[/dim]")
|
|
618
|
+
|
|
619
|
+
|
|
620
|
+
# `mlxsmith auth ...` sub-application; groups the login / status / logout
# commands defined below.
auth_app = typer.Typer(help="Hugging Face authentication")
app.add_typer(auth_app, name="auth")
|
|
622
|
+
|
|
623
|
+
|
|
624
|
+
@auth_app.command("login")
def auth_login(
    token: Optional[str] = typer.Option(None, "--token", envvar="HF_TOKEN"),
    # The on/off pair makes both states reachable. A True-default bool declared
    # with only "--validate" can never be turned off from the CLI; depending on
    # the Click version, the lone flag may even set it to False. This matches
    # `data validate --strict/--no-strict`.
    validate: bool = typer.Option(True, "--validate/--no-validate", help="Validate token with HF API"),
):
    # Store a Hugging Face token via hf_login, prompting (hidden input) when
    # neither --token nor HF_TOKEN provided one, then report the outcome.
    if not token:
        token = typer.prompt("Hugging Face token", hide_input=True)
    status = hf_login(token, validate=validate)
    if status.user:
        console.print(f"[green]Logged in[/green] as {status.user}")
    else:
        hint = f" ({status.token_hint})" if status.token_hint else ""
        console.print(f"[green]Token saved[/green]{hint}")
    for warning in status.warnings:
        console.print(f"[yellow]{warning}[/yellow]")
|
|
639
|
+
|
|
640
|
+
|
|
641
|
+
@auth_app.command("status")
def auth_status(validate: bool = typer.Option(False, "--validate", help="Validate token with HF API")):
    # Report whether a token is stored and, when known, which user it maps to.
    info = get_auth_status(validate=validate)
    if not info.token_present:
        console.print("[yellow]No token found[/yellow]")
        return

    if info.user:
        headline = f"[green]Logged in[/green] as {info.user}"
    else:
        suffix = f" ({info.token_hint})" if info.token_hint else ""
        headline = f"[green]Token present[/green]{suffix}"
    console.print(headline)

    for message in info.warnings:
        console.print(f"[yellow]{message}[/yellow]")
|
|
654
|
+
|
|
655
|
+
|
|
656
|
+
@auth_app.command("logout")
def auth_logout():
    # hf_logout() is treated as truthy when a stored token was removed and
    # falsy when there was nothing to remove.
    removed = hf_logout()
    message = "[green]Logged out[/green]" if removed else "[yellow]No token found[/yellow]"
    console.print(message)
|
|
662
|
+
|
|
663
|
+
|
|
664
|
+
# Sub-command groups attached to the root CLI. Each `*_app` object is used by
# the `@<group>.command(...)` / `@<group>.callback(...)` decorators below, so
# these statements must run first; their order also sets the `--help` order.
accel_app = typer.Typer(help="Acceleration utilities")
app.add_typer(accel_app, name="accel")

rlm_app = typer.Typer(help="Recursive Language Model (RLM) loop")
app.add_typer(rlm_app, name="rlm")

env_app = typer.Typer(help="Environment plugins")
app.add_typer(env_app, name="env")

# NOTE: the Python name is `adapter_app`, but the CLI group is `adapters`.
adapter_app = typer.Typer(help="Adapter utilities")
app.add_typer(adapter_app, name="adapters")
|
|
675
|
+
|
|
676
|
+
|
|
677
|
+
@rlm_app.callback(invoke_without_command=True)
def rlm_callback(
    ctx: typer.Context,
    config: str = typer.Option("mlxsmith.yaml", "-c", "--config", help="Config file path"),
    model: Optional[str] = typer.Option(None, "--model", help="Override model.id"),
    iterations: Optional[int] = typer.Option(None, "--iterations", help="Override rlm.iterations"),
    resume: bool = typer.Option(False, "--resume"),
    orchestrated: bool = typer.Option(False, "--orchestrated", help="Use multi-process orchestrator mode"),
):
    # Bare `mlxsmith rlm` runs the loop; when a sub-command (status/history)
    # was requested, this callback does nothing and lets it run.
    if ctx.invoked_subcommand is None:
        project_root = project_root_from_cwd()
        project_cfg = get_config(
            config_path=config,
            root=project_root,
            model_id=model,
            iterations=iterations,
        )
        # Single-process loop by default; multi-process orchestrator on request.
        runner = run_rlm_orchestrated if orchestrated else run_rlm
        runner(project_root, project_cfg, model_spec=model, iterations=iterations, resume=resume)
|
|
699
|
+
|
|
700
|
+
|
|
701
|
+
@rlm_app.command("status")
def rlm_status():
    # Summarise the persisted RLM state stored at runs/rlm_state.json.
    state = _load_rlm_state(project_root_from_cwd() / "runs" / "rlm_state.json")

    def _score_text(score):
        # Scores may legitimately be 0.0, so only None maps to "n/a".
        return "n/a" if score is None else str(score)

    rows = (
        ("last_iteration", str(state.last_iteration)),
        ("current_adapter", state.current_adapter or "n/a"),
        ("best_adapter", state.best_adapter or "n/a"),
        ("best_score", _score_text(state.best_score)),
        ("ema_score", _score_text(state.ema_score)),
    )
    table = Table(title="mlxsmith rlm status")
    for heading in ("item", "value"):
        table.add_column(heading)
    for item, value in rows:
        table.add_row(item, value)
    console.print(table)
|
|
715
|
+
|
|
716
|
+
|
|
717
|
+
@rlm_app.command("history")
def rlm_history(limit: int = typer.Option(10, "--limit")):
    # Print the last `limit` records of runs/rlm_history.jsonl; a limit of
    # zero or less prints every record.
    from collections import deque

    root = project_root_from_cwd()
    history_path = root / "runs" / "rlm_history.jsonl"
    if not history_path.exists():
        console.print("[yellow]No history found[/yellow]")
        return

    with history_path.open(encoding="utf-8") as fh:
        # Stream the file so only the requested tail is held in memory (the
        # history grows with every iteration). Blank lines are skipped so they
        # do not count toward `limit`.
        records = (line.rstrip("\n") for line in fh if line.strip())
        tail = deque(records, maxlen=limit) if limit > 0 else list(records)

    for record in tail:
        # Records are raw logged data: disable Rich markup/emoji parsing so
        # "[...]" or ":name:" sequences inside them are printed verbatim.
        console.print(record, markup=False, emoji=False)
|
|
728
|
+
|
|
729
|
+
|
|
730
|
+
@accel_app.command("status")
def accel_status():
    # Probe each known acceleration backend and tabulate its reported stats.
    table = Table(title="mlxsmith accel status")
    for heading in ("backend", "available", "notes"):
        table.add_column(heading)

    for backend_name in ("none", "zmlx"):
        backend_stats = get_backend(backend_name).stats()
        notes = backend_stats.notes or {}
        # A backend counts as available unless its notes carry an "error" key.
        available = "no" if "error" in notes else "yes"
        table.add_row(backend_stats.backend, available, json.dumps(notes))

    console.print(table)
|
|
742
|
+
|
|
743
|
+
|
|
744
|
+
@env_app.command("init")
def env_init(name: str = typer.Argument(..., help="Environment name")):
    # Create an env plugin called `name` and report the directory returned.
    created_dir = init_env_plugin(project_root_from_cwd(), name)
    console.print(f"[green]Initialized env[/green] {created_dir}")
|
|
749
|
+
|
|
750
|
+
|
|
751
|
+
@env_app.command("list")
def env_list(
    name: Optional[str] = typer.Argument(None, help="Filter by env name"),
    all_versions: bool = typer.Option(False, "--all", help="Show all versions"),
):
    # Tabulate registry entries, optionally filtered by name / expanded to all versions.
    entries = list_registry_packages(project_root_from_cwd(), name=name, all_versions=all_versions)
    if not entries:
        console.print("[yellow]No registry entries found[/yellow]")
        return

    fields = ("name", "version", "description")
    table = Table(title="mlxsmith env registry")
    for field in fields:
        table.add_column(field)
    for entry in entries:
        # Missing or falsy values render as empty cells.
        table.add_row(*(str(entry.get(field) or "") for field in fields))
    console.print(table)
|
|
772
|
+
|
|
773
|
+
|
|
774
|
+
@env_app.command("info")
def env_info(
    env: str = typer.Argument(..., help="Env name (optionally name@version or name==version)"),
    version: Optional[str] = typer.Option(None, "--version", help="Pin to a specific version"),
):
    # Show manifest fields of a registry env and its recorded registry path.
    pkg, manifest = registry_info_plugin(project_root_from_cwd(), env, version=version)
    details = [
        ("name", manifest.name),
        ("version", manifest.version),
        ("description", manifest.description or "n/a"),
        ("verifier", manifest.verifier or "n/a"),
        ("tasks", str(len(manifest.tasks or []))),
        ("token_env", "yes" if manifest.token_env else "no"),
        ("registry_path", str(pkg.get("path") or "")),
    ]
    table = Table(title=f"mlxsmith env info: {manifest.name}")
    table.add_column("field")
    table.add_column("value")
    for field, value in details:
        table.add_row(field, value)
    console.print(table)
|
|
792
|
+
|
|
793
|
+
|
|
794
|
+
@env_app.command("install")
def env_install(
    source: str = typer.Argument(..., help="Env dir, package path, or registry name"),
    version: Optional[str] = typer.Option(None, "--version", help="Pin to a specific version when using registry"),
):
    # Install an env from a directory, package path, or registry name.
    installed_dir = install_env_plugin(project_root_from_cwd(), source, version=version)
    console.print(f"[green]Installed env[/green] {installed_dir}")
|
|
802
|
+
|
|
803
|
+
|
|
804
|
+
@env_app.command("package")
def env_package(
    name: str = typer.Argument(..., help="Env name (directory under envs/)"),
    # Optional[str]: the default is None when --out is omitted (the previous
    # `str` annotation was inaccurate; matches `env_pull`'s --out).
    out: Optional[str] = typer.Option(None, "--out", help="Output directory for package"),
):
    # Package envs/<name> via package_env_plugin and report the resulting path.
    # When `out` is None the destination is chosen by package_env_plugin
    # (presumably a default location — confirm in the plugin code).
    root = project_root_from_cwd()
    out_path = package_env_plugin(root, name, out_path=out)
    console.print(f"[green]Packaged env[/green] {out_path}")
|
|
812
|
+
|
|
813
|
+
|
|
814
|
+
@env_app.command("publish")
def env_publish(
    package: str = typer.Argument(..., help="Path to .tar.gz package"),
):
    # Publish a packaged env via publish_env_plugin and report where it went.
    published_to = publish_env_plugin(project_root_from_cwd(), package)
    console.print(f"[green]Published env[/green] {published_to}")
|
|
821
|
+
|
|
822
|
+
|
|
823
|
+
@env_app.command("pull")
def env_pull(
    env: str = typer.Argument(..., help="Env name (optionally name@version or name==version)"),
    out: Optional[str] = typer.Option(None, "--out", help="Output directory"),
    version: Optional[str] = typer.Option(None, "--version", help="Pin to a specific version"),
    force: bool = typer.Option(False, "--force", help="Overwrite destination if it exists"),
):
    # Fetch an env (optionally version-pinned) into `out`; --force permits overwriting.
    pulled_to = pull_env_plugin(
        project_root_from_cwd(),
        env,
        out_dir=out,
        version=version,
        force=force,
    )
    console.print(f"[green]Pulled env[/green] {pulled_to}")
|
|
833
|
+
|
|
834
|
+
|
|
835
|
+
@env_app.command("run")
def env_run(
    env: str = typer.Argument(..., help="Env name or path to env.yaml"),
    model: str = typer.Option(..., "--model"),
    accel: Optional[str] = typer.Option(None, "--accel"),
    verifier: Optional[str] = typer.Option(None, "--verifier"),
    config: str = typer.Option("mlxsmith.yaml", "-c", "--config", help="Config file path"),
):
    # Resolve the env manifest and verifier, then launch an RFT run with them.
    root = project_root_from_cwd()
    cfg = get_config(config_path=config, root=root, accel_backend=accel)

    manifest_path = resolve_env_path_plugin(root, env)
    if manifest_path.is_dir():
        # A directory is accepted as shorthand for its env.yaml.
        manifest_path = manifest_path / "env.yaml"
    if not manifest_path.exists():
        raise typer.BadParameter(f"Env not found: {env}")
    manifest = load_env_manifest(manifest_path)

    # Precedence: --verifier, then the manifest's verifier, then
    # verifiers/regex.py; relative paths are resolved against the project root.
    verifier_file = Path(verifier or manifest.verifier or "verifiers/regex.py")
    if not verifier_file.is_absolute():
        verifier_file = root / verifier_file

    run = run_rft(root, cfg, manifest_path, verifier_file, Path(model), cfg.accel.backend)
    console.print(f"[bold]Run:[/bold] {run.run_dir}")
|
|
857
|
+
|
|
858
|
+
|
|
859
|
+
@env_app.command("registry")
def env_registry():
    # Dump the raw contents of envs/registry.json when the project has one.
    registry_file = project_root_from_cwd() / "envs" / "registry.json"
    if registry_file.exists():
        console.print(registry_file.read_text(encoding="utf-8"))
    else:
        console.print("[yellow]No registry found[/yellow]")
|
|
867
|
+
|
|
868
|
+
|
|
869
|
+
@adapter_app.command("merge")
def adapters_merge(
    base: str = typer.Option(..., "--base", help="Base model id or path"),
    adapters: str = typer.Option(..., "--adapters", help="Comma-separated adapter paths"),
    out: str = typer.Option("models/merged_adapter", "--out"),
    weights: Optional[str] = typer.Option(None, "--weights", help="Comma-separated weights"),
):
    """Merge one or more adapters for a base model and write the result to --out."""
    root = project_root_from_cwd()
    adapter_paths = [Path(p.strip()) for p in adapters.split(",") if p.strip()]
    if not adapter_paths:
        raise typer.BadParameter("At least one adapter path is required", param_hint="--adapters")

    # Relative output paths are anchored at the project root.
    out_path = Path(out)
    if not out_path.is_absolute():
        out_path = root / out_path

    w = None
    if weights:
        try:
            w = [float(x) for x in weights.split(",") if x.strip()]
        except ValueError as err:
            # Report a clean CLI usage error instead of a raw traceback.
            raise typer.BadParameter(
                f"Weights must be comma-separated numbers: {err}", param_hint="--weights"
            ) from err
        if len(w) != len(adapter_paths):
            # One weight per adapter; a mismatch would otherwise be silently
            # misapplied or fail deep inside the merge.
            raise typer.BadParameter(
                f"Expected {len(adapter_paths)} weights (one per adapter), got {len(w)}",
                param_hint="--weights",
            )

    result = merge_adapters(base, adapter_paths, out_path, weights=w)
    console.print(f"[green]Merged adapters[/green] {result}")
|
|
886
|
+
|
|
887
|
+
|
|
888
|
+
def _sample_env_yaml() -> str:
|
|
889
|
+
return """name: coding-sample
|
|
890
|
+
tasks:
|
|
891
|
+
- id: hello
|
|
892
|
+
prompt: |
|
|
893
|
+
Write a Python function `add(a, b)` that returns the sum.
|
|
894
|
+
gold: |
|
|
895
|
+
def add(a, b):
|
|
896
|
+
return a + b
|
|
897
|
+
verifier_kwargs:
|
|
898
|
+
pattern: "def\\s+add\\("
|
|
899
|
+
- id: pytest_task
|
|
900
|
+
prompt: |
|
|
901
|
+
Implement `mul(a,b)` in main.py. Tests provided.
|
|
902
|
+
tests: |
|
|
903
|
+
from main import mul
|
|
904
|
+
def test_mul():
|
|
905
|
+
assert mul(2,3) == 6
|
|
906
|
+
assert mul(-1,5) == -5
|
|
907
|
+
gold: |
|
|
908
|
+
def mul(a,b):
|
|
909
|
+
return a*b
|
|
910
|
+
"""
|
|
911
|
+
|
|
912
|
+
|
|
913
|
+
def _sample_verifier_regex() -> str:
|
|
914
|
+
return """from mlxsmith.verifiers.regex import verify as _verify
|
|
915
|
+
|
|
916
|
+
def verify(prompt: str, completion: str, workdir: str, **kwargs):
|
|
917
|
+
return _verify(prompt, completion, workdir, **kwargs)
|
|
918
|
+
"""
|
|
919
|
+
|
|
920
|
+
|
|
921
|
+
def _sample_verifier_pytest() -> str:
|
|
922
|
+
return """from mlxsmith.verifiers.pytest_verifier import verify as _verify
|
|
923
|
+
|
|
924
|
+
def verify(prompt: str, completion: str, workdir: str, **kwargs):
|
|
925
|
+
return _verify(prompt, completion, workdir, **kwargs)
|
|
926
|
+
"""
|
|
927
|
+
|
|
928
|
+
|
|
929
|
+
def _sample_verifier_jsonschema() -> str:
|
|
930
|
+
return """from mlxsmith.verifiers.jsonschema import verify as _verify
|
|
931
|
+
|
|
932
|
+
def verify(prompt: str, completion: str, workdir: str, **kwargs):
|
|
933
|
+
return _verify(prompt, completion, workdir, **kwargs)
|
|
934
|
+
"""
|
|
935
|
+
|
|
936
|
+
|
|
937
|
+
def _sample_eval_suite() -> str:
|
|
938
|
+
return """name: coding-eval-sample
|
|
939
|
+
notes: |
|
|
940
|
+
Minimal eval suite for smoke testing.
|
|
941
|
+
tasks:
|
|
942
|
+
- id: add
|
|
943
|
+
prompt: |
|
|
944
|
+
Write a Python function `add(a, b)` that returns the sum.
|
|
945
|
+
k: 2
|
|
946
|
+
max_new_tokens: 128
|
|
947
|
+
verifier: verifiers/regex.py
|
|
948
|
+
verifier_kwargs:
|
|
949
|
+
pattern: "def\\s+add\\("
|
|
950
|
+
"""
|