mlxsmith 0.1.2__py3-none-any.whl → 0.1.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mlxsmith/bench.py +12 -2
- mlxsmith/cli.py +187 -1
- mlxsmith/config_models.py +15 -1
- mlxsmith/integrations/__init__.py +19 -0
- mlxsmith/integrations/mlx_lm_lora.py +117 -0
- mlxsmith/llm/backend.py +8 -1
- mlxsmith/llm/mlx_lm_backend.py +59 -2
- mlxsmith/llm/mock_backend.py +8 -1
- mlxsmith/optim/__init__.py +3 -0
- mlxsmith/optim/muon.py +93 -0
- mlxsmith/orchestrator/daemon.py +44 -377
- mlxsmith/orchestrator/trainer_worker.py +4 -0
- mlxsmith/rlm/loop.py +53 -92
- mlxsmith/sdk/__init__.py +18 -2
- mlxsmith/sdk/losses.py +102 -1
- mlxsmith/sdk/training_client.py +24 -5
- mlxsmith/train/distill.py +6 -1
- mlxsmith/train/online_dpo.py +249 -0
- mlxsmith/train/pref.py +31 -29
- mlxsmith/train/rft.py +123 -38
- mlxsmith/train/self_verify.py +199 -0
- mlxsmith/train/sft.py +13 -2
- mlxsmith/verifiers/llm_judge.py +278 -0
- mlxsmith/verifiers/prime.py +127 -0
- {mlxsmith-0.1.2.dist-info → mlxsmith-0.1.3.dist-info}/METADATA +27 -1
- {mlxsmith-0.1.2.dist-info → mlxsmith-0.1.3.dist-info}/RECORD +30 -22
- {mlxsmith-0.1.2.dist-info → mlxsmith-0.1.3.dist-info}/WHEEL +0 -0
- {mlxsmith-0.1.2.dist-info → mlxsmith-0.1.3.dist-info}/entry_points.txt +0 -0
- {mlxsmith-0.1.2.dist-info → mlxsmith-0.1.3.dist-info}/licenses/LICENSE +0 -0
- {mlxsmith-0.1.2.dist-info → mlxsmith-0.1.3.dist-info}/top_level.txt +0 -0
mlxsmith/bench.py
CHANGED
```diff
@@ -44,7 +44,12 @@ def run_bench(
     mode = (mode or "inference").lower()
 
     if mode == "trainer":
-        opt, _params = llm.optimizer_and_params(
+        opt, _params = llm.optimizer_and_params(
+            lr=cfg.train.lr,
+            weight_decay=cfg.train.weight_decay,
+            optimizer=cfg.train.optimizer,
+            optimizer_kwargs=cfg.train.optimizer_kwargs,
+        )
         prompt_ids = llm.encode(prompt)
         ids = llm.encode(prompt + " " + "x" * max_tokens)
         for i in range(max(1, reps)):
@@ -59,7 +64,12 @@ def run_bench(
             elapsed = max(time.time() - t0, 1e-6)
             results.append({"rep": i, "steps": steps, "time_s": elapsed, "steps_per_s": steps / elapsed})
     elif mode == "end_to_end":
-        opt, _params = llm.optimizer_and_params(
+        opt, _params = llm.optimizer_and_params(
+            lr=cfg.train.lr,
+            weight_decay=cfg.train.weight_decay,
+            optimizer=cfg.train.optimizer,
+            optimizer_kwargs=cfg.train.optimizer_kwargs,
+        )
         for i in range(max(1, reps)):
             t0 = time.time()
             gen = llm.generate(prompt, max_new_tokens=max_tokens, temperature=0.0)
```
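Both bench paths now thread the optimizer choice from `cfg.train` instead of assuming AdamW. A minimal sketch of the new keyword-only call, assuming `MockBackend` constructs without required arguments (its constructor is not shown in this diff; the mock returns a placeholder optimizer):

```python
from mlxsmith.llm.mock_backend import MockBackend

llm = MockBackend()  # assumption: no required constructor args
opt, params = llm.optimizer_and_params(
    lr=2e-4,                           # TrainConfig.lr default
    weight_decay=0.0,
    optimizer="muon",                  # new TrainConfig.optimizer field
    optimizer_kwargs={"ns_iters": 5},  # forwarded to the optimizer class
)
```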
mlxsmith/cli.py
CHANGED
```diff
@@ -24,6 +24,8 @@ from .train.sft import run_sft
 from .train.pref import run_pref
 from .train.rft import run_rft
 from .train.distill import run_distill
+from .train.online_dpo import run_online_dpo
+from .train.self_verify import run_self_verify
 from .eval import run_eval
 from .bench import run_bench
 from .rlm import run_rlm, run_rlm_orchestrated
@@ -40,6 +42,13 @@ from .envs import (
     resolve_env_path as resolve_env_path_plugin,
     load_manifest as load_env_manifest,
 )
+from .integrations.mlx_lm_lora import (
+    build_train_command as build_mlx_lm_lora_train_command,
+    build_synthetic_command as build_mlx_lm_lora_synth_command,
+    build_judge_command as build_mlx_lm_lora_judge_command,
+    build_reward_functions_command as build_mlx_lm_lora_reward_functions_command,
+    run_command as run_mlx_lm_lora_command,
+)
 
 app = typer.Typer(
     add_completion=False,
@@ -65,6 +74,9 @@ def init(path: str = typer.Argument(..., help="Project directory to create")):
     (p / "verifiers" / "regex.py").write_text(_sample_verifier_regex(), encoding="utf-8")
     (p / "verifiers" / "pytest.py").write_text(_sample_verifier_pytest(), encoding="utf-8")
     (p / "verifiers" / "jsonschema.py").write_text(_sample_verifier_jsonschema(), encoding="utf-8")
+    (p / "verifiers" / "llm_judge.py").write_text(_sample_verifier_llm_judge(), encoding="utf-8")
+    (p / "verifiers" / "rubrics").mkdir(parents=True, exist_ok=True)
+    (p / "verifiers" / "rubrics" / "coding.txt").write_text(_sample_judge_rubric(), encoding="utf-8")
     (p / "eval" / "suites" / "coding.yaml").write_text(_sample_eval_suite(), encoding="utf-8")
     console.print(f"[green]Initialized[/green] {p.resolve()}")
 
@@ -341,14 +353,19 @@ def pref(
     data: str = typer.Option("data/prefs", "--data"),
     model: str = typer.Option(..., "--model", help="Base adapter or model path (e.g., runs/sft_0001/adapter)"),
     accel: Optional[str] = typer.Option(None, "--accel", help="Override accel.backend"),
-    algo: Optional[str] = typer.Option(None, "--algo", help="Override pref.algo (
+    algo: Optional[str] = typer.Option(None, "--algo", help="Override pref.algo (legacy)"),
+    loss_type: Optional[str] = typer.Option(None, "--loss-type", help="dpo|cpo|orpo|ipo|hinge"),
 ):
     root = project_root_from_cwd()
+    overrides = {}
+    if loss_type is not None:
+        overrides["pref.loss_type"] = loss_type
     cfg = get_config(
         config_path=config,
         root=root,
         accel_backend=accel,
         algo=algo,
+        **overrides,
     )
     data_dir = root / data
     run = run_pref(root, cfg, data_dir, Path(model), cfg.accel.backend)
@@ -363,13 +380,27 @@ def rft(
     model: str = typer.Option(..., "--model"),
     accel: Optional[str] = typer.Option(None, "--accel", help="Override accel.backend"),
     rollouts: Optional[int] = typer.Option(None, "--rollouts", help="Override rft.rollouts"),
+    loss_type: Optional[str] = typer.Option(None, "--loss-type", help="grpo|dr_grpo|dapo"),
+    epsilon_low: Optional[float] = typer.Option(None, "--epsilon-low"),
+    epsilon_high: Optional[float] = typer.Option(None, "--epsilon-high"),
+    token_level_loss: Optional[bool] = typer.Option(None, "--token-level-loss/--sequence-level-loss"),
 ):
     root = project_root_from_cwd()
+    overrides = {}
+    if loss_type is not None:
+        overrides["rft.loss_type"] = loss_type
+    if epsilon_low is not None:
+        overrides["rft.epsilon_low"] = epsilon_low
+    if epsilon_high is not None:
+        overrides["rft.epsilon_high"] = epsilon_high
+    if token_level_loss is not None:
+        overrides["rft.token_level_loss"] = token_level_loss
     cfg = get_config(
         config_path=config,
         root=root,
         accel_backend=accel,
         rollouts=rollouts,
+        **overrides,
     )
     run = run_rft(root, cfg, root / env, root / verifier, Path(model), cfg.accel.backend)
     console.print(f"[bold]Run:[/bold] {run.run_dir}")
@@ -437,6 +468,142 @@ def distill(
     console.print(f"[bold]Run:[/bold] {run.run_dir}")
 
 
+@app.command("online-dpo")
+def online_dpo(
+    data: str = typer.Option(..., "--data", help="JSONL with prompts"),
+    model: str = typer.Option(..., "--model"),
+    judge_model: Optional[str] = typer.Option(None, "--judge-model"),
+    judge_backend: str = typer.Option("mlx-lm", "--judge-backend"),
+    rubric: Optional[str] = typer.Option(None, "--rubric"),
+    group_size: Optional[int] = typer.Option(None, "--group-size"),
+    max_new_tokens: Optional[int] = typer.Option(None, "--max-new-tokens"),
+    temperature: Optional[float] = typer.Option(None, "--temperature"),
+    config: str = typer.Option("mlxsmith.yaml", "-c", "--config", help="Config file path"),
+    accel: Optional[str] = typer.Option(None, "--accel", help="Override accel.backend"),
+):
+    root = project_root_from_cwd()
+    cfg = get_config(config_path=config, root=root, accel_backend=accel)
+    run = run_online_dpo(
+        root,
+        cfg,
+        Path(data),
+        model,
+        cfg.accel.backend,
+        judge_model=judge_model,
+        judge_backend=judge_backend,
+        rubric=rubric,
+        group_size=group_size,
+        max_new_tokens=max_new_tokens,
+        temperature=temperature,
+    )
+    console.print(f"[bold]Run:[/bold] {run.run_dir}")
+
+
+@app.command("self-verify")
+def self_verify(
+    data: str = typer.Option(..., "--data", help="JSONL with prompts"),
+    model: str = typer.Option(..., "--model"),
+    verifier_model: Optional[str] = typer.Option(None, "--verifier-model"),
+    verifier_backend: str = typer.Option("mlx-lm", "--verifier-backend"),
+    rubric: Optional[str] = typer.Option(None, "--rubric"),
+    max_new_tokens: Optional[int] = typer.Option(None, "--max-new-tokens"),
+    temperature: Optional[float] = typer.Option(None, "--temperature"),
+    config: str = typer.Option("mlxsmith.yaml", "-c", "--config", help="Config file path"),
+    accel: Optional[str] = typer.Option(None, "--accel", help="Override accel.backend"),
+):
+    root = project_root_from_cwd()
+    cfg = get_config(config_path=config, root=root, accel_backend=accel)
+    run = run_self_verify(
+        root,
+        cfg,
+        Path(data),
+        model,
+        cfg.accel.backend,
+        verifier_model=verifier_model,
+        verifier_backend=verifier_backend,
+        rubric=rubric,
+        max_new_tokens=max_new_tokens,
+        temperature=temperature,
+    )
+    console.print(f"[bold]Run:[/bold] {run.run_dir}")
+
+
+lora_app = typer.Typer(help="mlx-lm-lora passthrough commands")
+app.add_typer(lora_app, name="lora")
+
+
+@lora_app.command(
+    "train",
+    context_settings={"allow_extra_args": True, "ignore_unknown_options": True},
+)
+def lora_train(
+    ctx: typer.Context,
+    config: Optional[str] = typer.Option(None, "--config", help="mlx-lm-lora config path"),
+    model: Optional[str] = typer.Option(None, "--model", help="Model id or path"),
+    data: Optional[str] = typer.Option(None, "--data", help="Dataset path or HF dataset"),
+    train_mode: Optional[str] = typer.Option(None, "--train-mode", help="sft|dpo|orpo|grpo|ppo|..."),
+    train_type: Optional[str] = typer.Option(None, "--train-type", help="lora|dora|full"),
+    dry_run: bool = typer.Option(False, "--dry-run"),
+):
+    """Run mlx-lm-lora training with passthrough args.
+
+    Use `--` to pass through any additional mlx-lm-lora flags.
+    """
+    root = project_root_from_cwd()
+    cmd = build_mlx_lm_lora_train_command(
+        config=config,
+        model=model,
+        data=data,
+        train_mode=train_mode,
+        train_type=train_type,
+        extra_args=list(ctx.args),
+    )
+    run_mlx_lm_lora_command(cmd, dry_run=dry_run, cwd=root)
+
+
+@lora_app.command(
+    "synthetic",
+    context_settings={"allow_extra_args": True, "ignore_unknown_options": True},
+)
+def lora_synthetic(
+    ctx: typer.Context,
+    kind: str = typer.Argument(..., help="prompts|sft|dpo"),
+    dry_run: bool = typer.Option(False, "--dry-run"),
+):
+    """Run mlx-lm-lora synthetic dataset generation."""
+    root = project_root_from_cwd()
+    cmd = build_mlx_lm_lora_synth_command(kind, extra_args=list(ctx.args))
+    run_mlx_lm_lora_command(cmd, dry_run=dry_run, cwd=root)
+
+
+@lora_app.command(
+    "judge",
+    context_settings={"allow_extra_args": True, "ignore_unknown_options": True},
+)
+def lora_judge(
+    ctx: typer.Context,
+    dry_run: bool = typer.Option(False, "--dry-run"),
+):
+    """Run mlx-lm-lora judge model training."""
+    root = project_root_from_cwd()
+    cmd = build_mlx_lm_lora_judge_command(extra_args=list(ctx.args))
+    run_mlx_lm_lora_command(cmd, dry_run=dry_run, cwd=root)
+
+
+@lora_app.command(
+    "reward-functions",
+    context_settings={"allow_extra_args": True, "ignore_unknown_options": True},
+)
+def lora_reward_functions(
+    ctx: typer.Context,
+    dry_run: bool = typer.Option(False, "--dry-run"),
+):
+    """List mlx-lm-lora reward functions."""
+    root = project_root_from_cwd()
+    cmd = build_mlx_lm_lora_reward_functions_command(extra_args=list(ctx.args))
+    run_mlx_lm_lora_command(cmd, dry_run=dry_run, cwd=root)
+
+
 @app.command()
 def eval(
     suite: str = typer.Option("eval/suites/coding.yaml", "--suite"),
@@ -933,6 +1100,25 @@ def verify(prompt: str, completion: str, workdir: str, **kwargs):
     """
 
 
+def _sample_verifier_llm_judge() -> str:
+    return """from mlxsmith.verifiers.llm_judge import verify as _verify
+
+def verify(prompt: str, completion: str, workdir: str, **kwargs):
+    # Pass model=... or set MLXSMITH_JUDGE_MODEL for the judge model id.
+    return _verify(prompt, completion, workdir, **kwargs)
+"""
+
+
+def _sample_judge_rubric() -> str:
+    return """Score from 0.0 to 1.0.
+- 1.0: Correct, complete, and safe.
+- 0.7: Mostly correct with small issues.
+- 0.4: Partial correctness or unclear reasoning.
+- 0.0: Incorrect or unsafe.
+Return JSON only.
+"""
+
+
 def _sample_eval_suite() -> str:
     return """name: coding-eval-sample
 notes: |
```
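The `lora` sub-app sets `allow_extra_args` and `ignore_unknown_options`, so flags that mlxsmith itself does not declare are collected into `ctx.args` and forwarded to mlx-lm-lora unchanged, while `--dry-run` prints the composed command without executing it. A hedged sketch using Typer's test runner (the model id is hypothetical):

```python
from typer.testing import CliRunner
from mlxsmith.cli import app

runner = CliRunner()
result = runner.invoke(app, [
    "lora", "train",
    "--model", "my-org/my-model",  # hypothetical model id
    "--train-mode", "sft",
    "--dry-run",
    "--batch-size", "4",           # unknown to mlxsmith; passed through
])
print(result.output)               # shows the composed mlx_lm_lora.train command
```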
mlxsmith/config_models.py
CHANGED
```diff
@@ -47,6 +47,8 @@ class TrainConfig(BaseModel):
     grad_accum: int = 8
     lr: float = 2e-4
     weight_decay: float = 0.0
+    optimizer: str = "adamw"
+    optimizer_kwargs: Dict[str, Any] = Field(default_factory=dict)
     iters: int = 1000
     save_every: int = 100
     eval_every: int = 100
@@ -61,6 +63,11 @@ class TrainConfig(BaseModel):
             raise ValueError("value must be non-negative")
         return v
 
+    @field_validator("optimizer")
+    @classmethod
+    def normalize_optimizer(cls, v: str) -> str:
+        return v.strip().lower()
+
 
 class LoraConfig(BaseModel):
     """LoRA/DoRA adapter configuration."""
@@ -89,11 +96,13 @@ class LoraConfig(BaseModel):
 
 
 class PrefConfig(BaseModel):
-    """Preference tuning configuration (DPO
+    """Preference tuning configuration (DPO variants)."""
 
     algo: Literal["dpo", "orpo", "grpo"] = "dpo"
+    loss_type: Literal["dpo", "cpo", "orpo", "ipo", "hinge"] = "dpo"
     beta: float = 0.1
     kl_coeff: float = 0.0
+    delta: float = 0.0
     reference_model: Optional[str] = None
 
 
@@ -101,12 +110,16 @@ class RftConfig(BaseModel):
     """Reinforcement fine-tuning configuration."""
 
     algo: Literal["grpo"] = "grpo"
+    loss_type: Literal["grpo", "dr_grpo", "dapo"] = "grpo"
     rollouts: int = 8
     kl_coeff: float = 0.02
     max_steps_per_task: int = 1
     temperature: float = 0.8
     max_new_tokens: int = 256
     normalize_advantage: bool = True
+    epsilon_low: float = 0.2
+    epsilon_high: float = 0.2
+    token_level_loss: bool = False
     reference_model: Optional[str] = None
 
 
@@ -164,6 +177,7 @@ CLI_ALIASES: dict[str, tuple[str, ...]] = {
     "lr": ("train", "lr"),
     "batch_size": ("train", "batch_size"),
     "iters": ("train", "iters"),
+    "optimizer": ("train", "optimizer"),
     "model_id": ("model", "id"),
     "accel_backend": ("accel", "backend"),
     "host": ("serve", "host"),
```
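The new `optimizer` field is normalized by the `field_validator` above, so YAML and CLI spellings are case- and whitespace-insensitive. A small sketch, assuming TrainConfig's remaining fields all carry defaults as the surrounding context suggests:

```python
from mlxsmith.config_models import TrainConfig

cfg = TrainConfig(optimizer="  MuonClip ", optimizer_kwargs={"clip": 1.0})
assert cfg.optimizer == "muonclip"            # stripped and lowercased
assert cfg.optimizer_kwargs == {"clip": 1.0}  # passed through untouched
```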
mlxsmith/integrations/__init__.py
ADDED
```diff
@@ -0,0 +1,19 @@
+"""External integrations for mlxsmith."""
+
+from .mlx_lm_lora import (
+    build_train_command as build_mlx_lm_lora_train_command,
+    build_synthetic_command as build_mlx_lm_lora_synth_command,
+    build_judge_command as build_mlx_lm_lora_judge_command,
+    build_reward_functions_command as build_mlx_lm_lora_reward_functions_command,
+    run_command as run_mlx_lm_lora_command,
+    ensure_available as ensure_mlx_lm_lora_available,
+)
+
+__all__ = [
+    "build_mlx_lm_lora_train_command",
+    "build_mlx_lm_lora_synth_command",
+    "build_mlx_lm_lora_judge_command",
+    "build_mlx_lm_lora_reward_functions_command",
+    "run_mlx_lm_lora_command",
+    "ensure_mlx_lm_lora_available",
+]
```
mlxsmith/integrations/mlx_lm_lora.py
ADDED
```diff
@@ -0,0 +1,117 @@
+"""Passthrough helpers for mlx-lm-lora CLI integration."""
+
+from __future__ import annotations
+
+import importlib.util
+import os
+import shlex
+import subprocess
+import sys
+from pathlib import Path
+from typing import Optional, Sequence
+
+from rich.console import Console
+
+console = Console()
+
+
+def ensure_available() -> None:
+    if importlib.util.find_spec("mlx_lm_lora") is None:
+        raise RuntimeError(
+            "mlx-lm-lora is not installed. Install with: pip install 'mlxsmith[lora]' or 'mlx-lm-lora'"
+        )
+
+
+def _flag_present(args: Sequence[str], *flags: str) -> bool:
+    for flag in flags:
+        if flag in args:
+            return True
+        if flag.startswith("--"):
+            prefix = flag + "="
+            if any(a.startswith(prefix) for a in args):
+                return True
+    return False
+
+
+def _append_flag(cmd: list[str], args: Sequence[str], flag: str, value: Optional[str]) -> None:
+    if value is None:
+        return
+    if _flag_present(args, flag):
+        return
+    cmd.extend([flag, value])
+
+
+def _base_python_cmd(module: str) -> list[str]:
+    return [sys.executable, "-m", module]
+
+
+def build_train_command(
+    *,
+    config: Optional[str] = None,
+    model: Optional[str] = None,
+    data: Optional[str] = None,
+    train_mode: Optional[str] = None,
+    train_type: Optional[str] = None,
+    extra_args: Sequence[str] = (),
+) -> list[str]:
+    args = list(extra_args)
+    cmd = _base_python_cmd("mlx_lm_lora.train")
+    _append_flag(cmd, args, "--config", config)
+    _append_flag(cmd, args, "--model", model)
+    _append_flag(cmd, args, "--data", data)
+    _append_flag(cmd, args, "--train-mode", train_mode)
+    _append_flag(cmd, args, "--train-type", train_type)
+    cmd.extend(args)
+    return cmd
+
+
+def build_synthetic_command(
+    kind: str,
+    *,
+    extra_args: Sequence[str] = (),
+) -> list[str]:
+    kind = kind.strip().lower()
+    module = {
+        "prompts": "mlx_lm_lora.synthetic_prompts",
+        "sft": "mlx_lm_lora.synthetic_sft",
+        "dpo": "mlx_lm_lora.synthetic_dpo",
+    }.get(kind)
+    if module is None:
+        raise ValueError(f"Unknown synthetic kind: {kind}")
+    cmd = _base_python_cmd(module)
+    cmd.extend(list(extra_args))
+    return cmd
+
+
+def build_judge_command(*, extra_args: Sequence[str] = ()) -> list[str]:
+    cmd = _base_python_cmd("mlx_lm_lora.train_judge")
+    cmd.extend(list(extra_args))
+    return cmd
+
+
+def build_reward_functions_command(*, extra_args: Sequence[str] = ()) -> list[str]:
+    cmd = _base_python_cmd("mlx_lm_lora.train")
+    cmd.append("--list-reward-functions")
+    cmd.extend(list(extra_args))
+    return cmd
+
+
+def run_command(
+    cmd: Sequence[str],
+    *,
+    dry_run: bool = False,
+    cwd: Optional[Path] = None,
+    env: Optional[dict] = None,
+) -> int:
+    if dry_run:
+        console.print("[cyan]mlx-lm-lora cmd[/cyan]", shlex.join(list(cmd)))
+        return 0
+    ensure_available()
+    run_env = os.environ.copy()
+    if env:
+        run_env.update(env)
+    console.print("[cyan]mlx-lm-lora cmd[/cyan]", shlex.join(list(cmd)))
+    result = subprocess.run(list(cmd), cwd=str(cwd) if cwd else None, env=run_env, check=False)
+    if result.returncode != 0:
+        raise RuntimeError(f"mlx-lm-lora failed with exit code {result.returncode}")
+    return result.returncode
```
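Note that `_append_flag` only injects a flag when it is absent from `extra_args` (including the `--flag=value` spelling), so explicit passthrough arguments always win over the keyword parameters. A quick check of the composed command, with hypothetical model ids:

```python
import sys
from mlxsmith.integrations.mlx_lm_lora import build_train_command

cmd = build_train_command(
    model="my-org/my-model",  # hypothetical id
    train_mode="sft",
    extra_args=["--model=other", "--batch-size", "4"],
)
# "--model" already appears in extra_args, so it is not injected again:
assert cmd == [sys.executable, "-m", "mlx_lm_lora.train",
               "--train-mode", "sft",
               "--model=other", "--batch-size", "4"]
```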
mlxsmith/llm/backend.py
CHANGED
```diff
@@ -112,7 +112,14 @@ class LLMBackend(Protocol):
     def value_and_grad(self, loss_fn) -> tuple[Any, Any | None]:
         """Return (loss, grads) using backend autograd when available."""
 
-    def optimizer_and_params(
+    def optimizer_and_params(
+        self,
+        *,
+        lr: float,
+        weight_decay: float = 0.0,
+        optimizer: str | None = None,
+        optimizer_kwargs: dict | None = None,
+    ) -> tuple[Any, Any]:
         """Return (optimizer, trainable_params_tree)."""
 
     def apply_grads(self, optimizer: Any, grads: Any) -> None:
```
mlxsmith/llm/mlx_lm_backend.py
CHANGED
```diff
@@ -467,7 +467,54 @@ class MlxLMBackend:
             return vag(self.model, loss_fn)(self.model)
         return loss_fn(self.model), None
 
-    def optimizer_and_params(
+    def optimizer_and_params(
+        self,
+        *,
+        lr: float,
+        weight_decay: float = 0.0,
+        optimizer: str | None = None,
+        optimizer_kwargs: dict | None = None,
+    ) -> tuple[Any, Any]:
+        return self._optimizer_and_params(
+            lr=lr,
+            weight_decay=weight_decay,
+            optimizer=optimizer,
+            optimizer_kwargs=optimizer_kwargs,
+        )
+
+    def _resolve_optimizer(self, name: str):
+        assert self.optim is not None
+        name = name.strip().lower()
+        if name in ("muon", "muonclip", "muon_clip"):
+            from ..optim.muon import Muon, MuonClip
+
+            return MuonClip if name in ("muonclip", "muon_clip") else Muon
+        mapping = {
+            "adamw": ["AdamW", "Adamw", "adamw"],
+            "adam": ["Adam", "adam"],
+            "sgd": ["SGD", "Sgd", "sgd"],
+            "rmsprop": ["RMSprop", "RmsProp", "rmsprop"],
+            "qhadam": ["QHAdam", "Qhadam", "qhadam"],
+            "muon": ["Muon", "muon", "MuonW", "muonw"],
+        }
+        candidates = mapping.get(name, [name, name.capitalize(), name.upper()])
+        for cand in candidates:
+            opt_cls = getattr(self.optim, cand, None)
+            if opt_cls is not None:
+                return opt_cls
+        raise RuntimeError(
+            f"Optimizer '{name}' is not available in mlx.optimizers. "
+            "Install a build that provides it or choose a supported optimizer."
+        )
+
+    def _optimizer_and_params(
+        self,
+        *,
+        lr: float,
+        weight_decay: float = 0.0,
+        optimizer: str | None = None,
+        optimizer_kwargs: dict | None = None,
+    ) -> tuple[Any, Any]:
         assert self.optim is not None
         if self.model is None:
             raise RuntimeError("Backend not loaded")
@@ -483,7 +530,17 @@ class MlxLMBackend:
         if not params:
             params = getattr(self.model, "parameters", lambda: self.model)()
 
-
+        opt_name = optimizer or "adamw"
+        opt_cls = self._resolve_optimizer(opt_name)
+        kwargs = dict(optimizer_kwargs or {})
+        if "learning_rate" not in kwargs and "lr" not in kwargs:
+            kwargs["learning_rate"] = lr
+            kwargs["lr"] = lr
+        if "weight_decay" not in kwargs:
+            kwargs["weight_decay"] = weight_decay
+        opt = self._call_with_supported_kwargs(opt_cls, **kwargs)
+        if opt is None:
+            opt = opt_cls()
         opt.init(params)
         return opt, params
 
```
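`_optimizer_and_params` sets both `learning_rate` and `lr` because optimizer classes disagree on the keyword name; `_call_with_supported_kwargs` (a pre-existing helper not shown in this diff) evidently filters out whichever keyword a given constructor rejects. A hedged stand-in for that idea, not the actual helper:

```python
import inspect

def call_with_supported_kwargs(cls, **kwargs):
    """Instantiate cls with only the kwargs its __init__ accepts."""
    accepted = set(inspect.signature(cls.__init__).parameters) - {"self"}
    try:
        return cls(**{k: v for k, v in kwargs.items() if k in accepted})
    except TypeError:
        return None  # caller falls back to cls() with defaults
```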
mlxsmith/llm/mock_backend.py
CHANGED
```diff
@@ -200,7 +200,14 @@ class MockBackend:
     def value_and_grad(self, loss_fn):
         return loss_fn(self.model), None
 
-    def optimizer_and_params(
+    def optimizer_and_params(
+        self,
+        *,
+        lr: float,
+        weight_decay: float = 0.0,
+        optimizer: str | None = None,
+        optimizer_kwargs: dict | None = None,
+    ):
         return object(), {}
 
     def apply_grads(self, optimizer: Any, grads: Any) -> None:
```
mlxsmith/optim/muon.py
ADDED
```diff
@@ -0,0 +1,93 @@
+from __future__ import annotations
+
+from typing import Any, Optional
+
+from ..util import tree_map
+
+
+class Muon:
+    """Muon optimizer wrapper (AdamW with Newton-Schulz orthogonalization for 2D grads)."""
+
+    def __init__(
+        self,
+        learning_rate: Optional[float] = None,
+        lr: Optional[float] = None,
+        weight_decay: float = 0.0,
+        clip: Optional[float] = None,
+        ns_iters: int = 5,
+        a: float = 3.4445,
+        b: float = -4.7750,
+        c: float = 2.0315,
+    ) -> None:
+        import mlx.optimizers as optim  # type: ignore
+
+        self.learning_rate = float(lr if lr is not None else learning_rate if learning_rate is not None else 1e-4)
+        self.weight_decay = float(weight_decay)
+        self.clip = clip
+        self.ns_iters = int(ns_iters)
+        self.a = float(a)
+        self.b = float(b)
+        self.c = float(c)
+        self._base = optim.AdamW(learning_rate=self.learning_rate, weight_decay=self.weight_decay)
+        self.state: dict[str, Any] = {}
+
+    def init(self, params: Any) -> None:
+        self._base.init(params)
+        self.state = getattr(self._base, "state", {})
+
+    def _clip_grad(self, g):
+        if self.clip is None or self.clip <= 0:
+            return g
+        import mlx.core as mx  # type: ignore
+
+        norm = mx.sqrt((g * g).sum())
+        scale = mx.minimum(mx.array(1.0), mx.array(float(self.clip)) / mx.maximum(norm, mx.array(1e-8)))
+        return g * scale
+
+    def _orthogonalize(self, g):
+        import mlx.core as mx  # type: ignore
+
+        g = self._clip_grad(g)
+        if getattr(g, "ndim", None) != 2:
+            return g
+        in_dim = int(g.shape[1])
+        eye = mx.eye(in_dim, dtype=g.dtype)
+        out = g
+        for _ in range(self.ns_iters):
+            gtg = mx.matmul(out.T, out)
+            gtg2 = mx.matmul(gtg, gtg)
+            out = mx.matmul(out, self.a * eye + self.b * gtg + self.c * gtg2)
+        return out
+
+    def update(self, model: Any, grads: Any) -> None:
+        if grads is None:
+            return
+        transformed = tree_map(self._orthogonalize, grads)
+        self._base.update(model, transformed)
+        self.state = getattr(self._base, "state", self.state)
+
+
+class MuonClip(Muon):
+    """Muon with explicit gradient clipping before orthogonalization."""
+
+    def __init__(
+        self,
+        learning_rate: Optional[float] = None,
+        lr: Optional[float] = None,
+        weight_decay: float = 0.0,
+        clip: Optional[float] = 1.0,
+        ns_iters: int = 5,
+        a: float = 3.4445,
+        b: float = -4.7750,
+        c: float = 2.0315,
+    ) -> None:
+        super().__init__(
+            learning_rate=learning_rate,
+            lr=lr,
+            weight_decay=weight_decay,
+            clip=clip,
+            ns_iters=ns_iters,
+            a=a,
+            b=b,
+            c=c,
+        )
```