mlxsmith-0.1.2-py3-none-any.whl → mlxsmith-0.1.3-py3-none-any.whl

mlxsmith/bench.py CHANGED
@@ -44,7 +44,12 @@ def run_bench(
     mode = (mode or "inference").lower()
 
     if mode == "trainer":
-        opt, _params = llm.optimizer_and_params(lr=cfg.train.lr, weight_decay=cfg.train.weight_decay)
+        opt, _params = llm.optimizer_and_params(
+            lr=cfg.train.lr,
+            weight_decay=cfg.train.weight_decay,
+            optimizer=cfg.train.optimizer,
+            optimizer_kwargs=cfg.train.optimizer_kwargs,
+        )
         prompt_ids = llm.encode(prompt)
         ids = llm.encode(prompt + " " + "x" * max_tokens)
         for i in range(max(1, reps)):
@@ -59,7 +64,12 @@ def run_bench(
             elapsed = max(time.time() - t0, 1e-6)
             results.append({"rep": i, "steps": steps, "time_s": elapsed, "steps_per_s": steps / elapsed})
     elif mode == "end_to_end":
-        opt, _params = llm.optimizer_and_params(lr=cfg.train.lr, weight_decay=cfg.train.weight_decay)
+        opt, _params = llm.optimizer_and_params(
+            lr=cfg.train.lr,
+            weight_decay=cfg.train.weight_decay,
+            optimizer=cfg.train.optimizer,
+            optimizer_kwargs=cfg.train.optimizer_kwargs,
+        )
        for i in range(max(1, reps)):
             t0 = time.time()
             gen = llm.generate(prompt, max_new_tokens=max_tokens, temperature=0.0)
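
Both benchmark modes now thread the optimizer choice from `cfg.train` into the backend. A minimal sketch of the resulting call shape, assuming a loaded backend `llm` and a parsed config `cfg`; the `muon_clip` value and its kwargs are illustrative, not defaults:

    # With mlxsmith.yaml containing, e.g.:
    #   train:
    #     optimizer: muon_clip                 # illustrative; the default is "adamw"
    #     optimizer_kwargs: {clip: 1.0, ns_iters: 5}
    opt, params = llm.optimizer_and_params(
        lr=cfg.train.lr,
        weight_decay=cfg.train.weight_decay,
        optimizer=cfg.train.optimizer,
        optimizer_kwargs=cfg.train.optimizer_kwargs,
    )
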
mlxsmith/cli.py CHANGED
@@ -24,6 +24,8 @@ from .train.sft import run_sft
 from .train.pref import run_pref
 from .train.rft import run_rft
 from .train.distill import run_distill
+from .train.online_dpo import run_online_dpo
+from .train.self_verify import run_self_verify
 from .eval import run_eval
 from .bench import run_bench
 from .rlm import run_rlm, run_rlm_orchestrated
@@ -40,6 +42,13 @@ from .envs import (
     resolve_env_path as resolve_env_path_plugin,
     load_manifest as load_env_manifest,
 )
+from .integrations.mlx_lm_lora import (
+    build_train_command as build_mlx_lm_lora_train_command,
+    build_synthetic_command as build_mlx_lm_lora_synth_command,
+    build_judge_command as build_mlx_lm_lora_judge_command,
+    build_reward_functions_command as build_mlx_lm_lora_reward_functions_command,
+    run_command as run_mlx_lm_lora_command,
+)
 
 app = typer.Typer(
     add_completion=False,
@@ -65,6 +74,9 @@ def init(path: str = typer.Argument(..., help="Project directory to create")):
     (p / "verifiers" / "regex.py").write_text(_sample_verifier_regex(), encoding="utf-8")
     (p / "verifiers" / "pytest.py").write_text(_sample_verifier_pytest(), encoding="utf-8")
     (p / "verifiers" / "jsonschema.py").write_text(_sample_verifier_jsonschema(), encoding="utf-8")
+    (p / "verifiers" / "llm_judge.py").write_text(_sample_verifier_llm_judge(), encoding="utf-8")
+    (p / "verifiers" / "rubrics").mkdir(parents=True, exist_ok=True)
+    (p / "verifiers" / "rubrics" / "coding.txt").write_text(_sample_judge_rubric(), encoding="utf-8")
     (p / "eval" / "suites" / "coding.yaml").write_text(_sample_eval_suite(), encoding="utf-8")
     console.print(f"[green]Initialized[/green] {p.resolve()}")
 
@@ -341,14 +353,19 @@ def pref(
     data: str = typer.Option("data/prefs", "--data"),
     model: str = typer.Option(..., "--model", help="Base adapter or model path (e.g., runs/sft_0001/adapter)"),
     accel: Optional[str] = typer.Option(None, "--accel", help="Override accel.backend"),
-    algo: Optional[str] = typer.Option(None, "--algo", help="Override pref.algo (dpo|orpo|grpo)"),
+    algo: Optional[str] = typer.Option(None, "--algo", help="Override pref.algo (legacy)"),
+    loss_type: Optional[str] = typer.Option(None, "--loss-type", help="dpo|cpo|orpo|ipo|hinge"),
 ):
     root = project_root_from_cwd()
+    overrides = {}
+    if loss_type is not None:
+        overrides["pref.loss_type"] = loss_type
     cfg = get_config(
         config_path=config,
         root=root,
         accel_backend=accel,
         algo=algo,
+        **overrides,
     )
     data_dir = root / data
     run = run_pref(root, cfg, data_dir, Path(model), cfg.accel.backend)
@@ -363,13 +380,27 @@ def rft(
     model: str = typer.Option(..., "--model"),
     accel: Optional[str] = typer.Option(None, "--accel", help="Override accel.backend"),
     rollouts: Optional[int] = typer.Option(None, "--rollouts", help="Override rft.rollouts"),
+    loss_type: Optional[str] = typer.Option(None, "--loss-type", help="grpo|dr_grpo|dapo"),
+    epsilon_low: Optional[float] = typer.Option(None, "--epsilon-low"),
+    epsilon_high: Optional[float] = typer.Option(None, "--epsilon-high"),
+    token_level_loss: Optional[bool] = typer.Option(None, "--token-level-loss/--sequence-level-loss"),
 ):
     root = project_root_from_cwd()
+    overrides = {}
+    if loss_type is not None:
+        overrides["rft.loss_type"] = loss_type
+    if epsilon_low is not None:
+        overrides["rft.epsilon_low"] = epsilon_low
+    if epsilon_high is not None:
+        overrides["rft.epsilon_high"] = epsilon_high
+    if token_level_loss is not None:
+        overrides["rft.token_level_loss"] = token_level_loss
     cfg = get_config(
         config_path=config,
         root=root,
         accel_backend=accel,
         rollouts=rollouts,
+        **overrides,
     )
     run = run_rft(root, cfg, root / env, root / verifier, Path(model), cfg.accel.backend)
     console.print(f"[bold]Run:[/bold] {run.run_dir}")
@@ -437,6 +468,142 @@ def distill(
     console.print(f"[bold]Run:[/bold] {run.run_dir}")
 
 
+@app.command("online-dpo")
+def online_dpo(
+    data: str = typer.Option(..., "--data", help="JSONL with prompts"),
+    model: str = typer.Option(..., "--model"),
+    judge_model: Optional[str] = typer.Option(None, "--judge-model"),
+    judge_backend: str = typer.Option("mlx-lm", "--judge-backend"),
+    rubric: Optional[str] = typer.Option(None, "--rubric"),
+    group_size: Optional[int] = typer.Option(None, "--group-size"),
+    max_new_tokens: Optional[int] = typer.Option(None, "--max-new-tokens"),
+    temperature: Optional[float] = typer.Option(None, "--temperature"),
+    config: str = typer.Option("mlxsmith.yaml", "-c", "--config", help="Config file path"),
+    accel: Optional[str] = typer.Option(None, "--accel", help="Override accel.backend"),
+):
+    root = project_root_from_cwd()
+    cfg = get_config(config_path=config, root=root, accel_backend=accel)
+    run = run_online_dpo(
+        root,
+        cfg,
+        Path(data),
+        model,
+        cfg.accel.backend,
+        judge_model=judge_model,
+        judge_backend=judge_backend,
+        rubric=rubric,
+        group_size=group_size,
+        max_new_tokens=max_new_tokens,
+        temperature=temperature,
+    )
+    console.print(f"[bold]Run:[/bold] {run.run_dir}")
+
+
+@app.command("self-verify")
+def self_verify(
+    data: str = typer.Option(..., "--data", help="JSONL with prompts"),
+    model: str = typer.Option(..., "--model"),
+    verifier_model: Optional[str] = typer.Option(None, "--verifier-model"),
+    verifier_backend: str = typer.Option("mlx-lm", "--verifier-backend"),
+    rubric: Optional[str] = typer.Option(None, "--rubric"),
+    max_new_tokens: Optional[int] = typer.Option(None, "--max-new-tokens"),
+    temperature: Optional[float] = typer.Option(None, "--temperature"),
+    config: str = typer.Option("mlxsmith.yaml", "-c", "--config", help="Config file path"),
+    accel: Optional[str] = typer.Option(None, "--accel", help="Override accel.backend"),
+):
+    root = project_root_from_cwd()
+    cfg = get_config(config_path=config, root=root, accel_backend=accel)
+    run = run_self_verify(
+        root,
+        cfg,
+        Path(data),
+        model,
+        cfg.accel.backend,
+        verifier_model=verifier_model,
+        verifier_backend=verifier_backend,
+        rubric=rubric,
+        max_new_tokens=max_new_tokens,
+        temperature=temperature,
+    )
+    console.print(f"[bold]Run:[/bold] {run.run_dir}")
+
+
+lora_app = typer.Typer(help="mlx-lm-lora passthrough commands")
+app.add_typer(lora_app, name="lora")
+
+
+@lora_app.command(
+    "train",
+    context_settings={"allow_extra_args": True, "ignore_unknown_options": True},
+)
+def lora_train(
+    ctx: typer.Context,
+    config: Optional[str] = typer.Option(None, "--config", help="mlx-lm-lora config path"),
+    model: Optional[str] = typer.Option(None, "--model", help="Model id or path"),
+    data: Optional[str] = typer.Option(None, "--data", help="Dataset path or HF dataset"),
+    train_mode: Optional[str] = typer.Option(None, "--train-mode", help="sft|dpo|orpo|grpo|ppo|..."),
+    train_type: Optional[str] = typer.Option(None, "--train-type", help="lora|dora|full"),
+    dry_run: bool = typer.Option(False, "--dry-run"),
+):
+    """Run mlx-lm-lora training with passthrough args.
+
+    Use `--` to pass through any additional mlx-lm-lora flags.
+    """
+    root = project_root_from_cwd()
+    cmd = build_mlx_lm_lora_train_command(
+        config=config,
+        model=model,
+        data=data,
+        train_mode=train_mode,
+        train_type=train_type,
+        extra_args=list(ctx.args),
+    )
+    run_mlx_lm_lora_command(cmd, dry_run=dry_run, cwd=root)
+
+
+@lora_app.command(
+    "synthetic",
+    context_settings={"allow_extra_args": True, "ignore_unknown_options": True},
+)
+def lora_synthetic(
+    ctx: typer.Context,
+    kind: str = typer.Argument(..., help="prompts|sft|dpo"),
+    dry_run: bool = typer.Option(False, "--dry-run"),
+):
+    """Run mlx-lm-lora synthetic dataset generation."""
+    root = project_root_from_cwd()
+    cmd = build_mlx_lm_lora_synth_command(kind, extra_args=list(ctx.args))
+    run_mlx_lm_lora_command(cmd, dry_run=dry_run, cwd=root)
+
+
+@lora_app.command(
+    "judge",
+    context_settings={"allow_extra_args": True, "ignore_unknown_options": True},
+)
+def lora_judge(
+    ctx: typer.Context,
+    dry_run: bool = typer.Option(False, "--dry-run"),
+):
+    """Run mlx-lm-lora judge model training."""
+    root = project_root_from_cwd()
+    cmd = build_mlx_lm_lora_judge_command(extra_args=list(ctx.args))
+    run_mlx_lm_lora_command(cmd, dry_run=dry_run, cwd=root)
+
+
+@lora_app.command(
+    "reward-functions",
+    context_settings={"allow_extra_args": True, "ignore_unknown_options": True},
+)
+def lora_reward_functions(
+    ctx: typer.Context,
+    dry_run: bool = typer.Option(False, "--dry-run"),
+):
+    """List mlx-lm-lora reward functions."""
+    root = project_root_from_cwd()
+    cmd = build_mlx_lm_lora_reward_functions_command(extra_args=list(ctx.args))
+    run_mlx_lm_lora_command(cmd, dry_run=dry_run, cwd=root)
+
+
 @app.command()
 def eval(
     suite: str = typer.Option("eval/suites/coding.yaml", "--suite"),
@@ -933,6 +1100,25 @@ def verify(prompt: str, completion: str, workdir: str, **kwargs):
 """
 
 
+def _sample_verifier_llm_judge() -> str:
+    return """from mlxsmith.verifiers.llm_judge import verify as _verify
+
+def verify(prompt: str, completion: str, workdir: str, **kwargs):
+    # Pass model=... or set MLXSMITH_JUDGE_MODEL for the judge model id.
+    return _verify(prompt, completion, workdir, **kwargs)
+"""
+
+
+def _sample_judge_rubric() -> str:
+    return """Score from 0.0 to 1.0.
+- 1.0: Correct, complete, and safe.
+- 0.7: Mostly correct with small issues.
+- 0.4: Partial correctness or unclear reasoning.
+- 0.0: Incorrect or unsafe.
+Return JSON only.
+"""
+
+
 def _sample_eval_suite() -> str:
     return """name: coding-eval-sample
 notes: |
mlxsmith/config_models.py CHANGED
@@ -47,6 +47,8 @@ class TrainConfig(BaseModel):
     grad_accum: int = 8
     lr: float = 2e-4
     weight_decay: float = 0.0
+    optimizer: str = "adamw"
+    optimizer_kwargs: Dict[str, Any] = Field(default_factory=dict)
     iters: int = 1000
     save_every: int = 100
     eval_every: int = 100
@@ -61,6 +63,11 @@ class TrainConfig(BaseModel):
             raise ValueError("value must be non-negative")
         return v
 
+    @field_validator("optimizer")
+    @classmethod
+    def normalize_optimizer(cls, v: str) -> str:
+        return v.strip().lower()
+
 
 class LoraConfig(BaseModel):
     """LoRA/DoRA adapter configuration."""
@@ -89,11 +96,13 @@ class LoraConfig(BaseModel):
 
 
 class PrefConfig(BaseModel):
-    """Preference tuning configuration (DPO, ORPO, GRPO)."""
+    """Preference tuning configuration (DPO variants)."""
 
     algo: Literal["dpo", "orpo", "grpo"] = "dpo"
+    loss_type: Literal["dpo", "cpo", "orpo", "ipo", "hinge"] = "dpo"
     beta: float = 0.1
     kl_coeff: float = 0.0
+    delta: float = 0.0
     reference_model: Optional[str] = None
 
 
@@ -101,12 +110,16 @@ class RftConfig(BaseModel):
     """Reinforcement fine-tuning configuration."""
 
     algo: Literal["grpo"] = "grpo"
+    loss_type: Literal["grpo", "dr_grpo", "dapo"] = "grpo"
     rollouts: int = 8
     kl_coeff: float = 0.02
     max_steps_per_task: int = 1
     temperature: float = 0.8
     max_new_tokens: int = 256
     normalize_advantage: bool = True
+    epsilon_low: float = 0.2
+    epsilon_high: float = 0.2
+    token_level_loss: bool = False
     reference_model: Optional[str] = None
 
 
@@ -164,6 +177,7 @@ CLI_ALIASES: dict[str, tuple[str, ...]] = {
     "lr": ("train", "lr"),
     "batch_size": ("train", "batch_size"),
     "iters": ("train", "iters"),
+    "optimizer": ("train", "optimizer"),
     "model_id": ("model", "id"),
     "accel_backend": ("accel", "backend"),
     "host": ("serve", "host"),
mlxsmith/integrations/__init__.py ADDED
@@ -0,0 +1,19 @@
+"""External integrations for mlxsmith."""
+
+from .mlx_lm_lora import (
+    build_train_command as build_mlx_lm_lora_train_command,
+    build_synthetic_command as build_mlx_lm_lora_synth_command,
+    build_judge_command as build_mlx_lm_lora_judge_command,
+    build_reward_functions_command as build_mlx_lm_lora_reward_functions_command,
+    run_command as run_mlx_lm_lora_command,
+    ensure_available as ensure_mlx_lm_lora_available,
+)
+
+__all__ = [
+    "build_mlx_lm_lora_train_command",
+    "build_mlx_lm_lora_synth_command",
+    "build_mlx_lm_lora_judge_command",
+    "build_mlx_lm_lora_reward_functions_command",
+    "run_mlx_lm_lora_command",
+    "ensure_mlx_lm_lora_available",
+]
mlxsmith/integrations/mlx_lm_lora.py ADDED
@@ -0,0 +1,117 @@
+"""Passthrough helpers for mlx-lm-lora CLI integration."""
+
+from __future__ import annotations
+
+import importlib.util
+import os
+import shlex
+import subprocess
+import sys
+from pathlib import Path
+from typing import Optional, Sequence
+
+from rich.console import Console
+
+console = Console()
+
+
+def ensure_available() -> None:
+    if importlib.util.find_spec("mlx_lm_lora") is None:
+        raise RuntimeError(
+            "mlx-lm-lora is not installed. Install with: pip install 'mlxsmith[lora]' or 'mlx-lm-lora'"
+        )
+
+
+def _flag_present(args: Sequence[str], *flags: str) -> bool:
+    for flag in flags:
+        if flag in args:
+            return True
+        if flag.startswith("--"):
+            prefix = flag + "="
+            if any(a.startswith(prefix) for a in args):
+                return True
+    return False
+
+
+def _append_flag(cmd: list[str], args: Sequence[str], flag: str, value: Optional[str]) -> None:
+    if value is None:
+        return
+    if _flag_present(args, flag):
+        return
+    cmd.extend([flag, value])
+
+
+def _base_python_cmd(module: str) -> list[str]:
+    return [sys.executable, "-m", module]
+
+
+def build_train_command(
+    *,
+    config: Optional[str] = None,
+    model: Optional[str] = None,
+    data: Optional[str] = None,
+    train_mode: Optional[str] = None,
+    train_type: Optional[str] = None,
+    extra_args: Sequence[str] = (),
+) -> list[str]:
+    args = list(extra_args)
+    cmd = _base_python_cmd("mlx_lm_lora.train")
+    _append_flag(cmd, args, "--config", config)
+    _append_flag(cmd, args, "--model", model)
+    _append_flag(cmd, args, "--data", data)
+    _append_flag(cmd, args, "--train-mode", train_mode)
+    _append_flag(cmd, args, "--train-type", train_type)
+    cmd.extend(args)
+    return cmd
+
+
+def build_synthetic_command(
+    kind: str,
+    *,
+    extra_args: Sequence[str] = (),
+) -> list[str]:
+    kind = kind.strip().lower()
+    module = {
+        "prompts": "mlx_lm_lora.synthetic_prompts",
+        "sft": "mlx_lm_lora.synthetic_sft",
+        "dpo": "mlx_lm_lora.synthetic_dpo",
+    }.get(kind)
+    if module is None:
+        raise ValueError(f"Unknown synthetic kind: {kind}")
+    cmd = _base_python_cmd(module)
+    cmd.extend(list(extra_args))
+    return cmd
+
+
+def build_judge_command(*, extra_args: Sequence[str] = ()) -> list[str]:
+    cmd = _base_python_cmd("mlx_lm_lora.train_judge")
+    cmd.extend(list(extra_args))
+    return cmd
+
+
+def build_reward_functions_command(*, extra_args: Sequence[str] = ()) -> list[str]:
+    cmd = _base_python_cmd("mlx_lm_lora.train")
+    cmd.append("--list-reward-functions")
+    cmd.extend(list(extra_args))
+    return cmd
+
+
+def run_command(
+    cmd: Sequence[str],
+    *,
+    dry_run: bool = False,
+    cwd: Optional[Path] = None,
+    env: Optional[dict] = None,
+) -> int:
+    if dry_run:
+        console.print("[cyan]mlx-lm-lora cmd[/cyan]", shlex.join(list(cmd)))
+        return 0
+    ensure_available()
+    run_env = os.environ.copy()
+    if env:
+        run_env.update(env)
+    console.print("[cyan]mlx-lm-lora cmd[/cyan]", shlex.join(list(cmd)))
+    result = subprocess.run(list(cmd), cwd=str(cwd) if cwd else None, env=run_env, check=False)
+    if result.returncode != 0:
+        raise RuntimeError(f"mlx-lm-lora failed with exit code {result.returncode}")
+    return result.returncode
mlxsmith/llm/backend.py CHANGED
@@ -112,7 +112,14 @@ class LLMBackend(Protocol):
     def value_and_grad(self, loss_fn) -> tuple[Any, Any | None]:
         """Return (loss, grads) using backend autograd when available."""
 
-    def optimizer_and_params(self, *, lr: float, weight_decay: float = 0.0) -> tuple[Any, Any]:
+    def optimizer_and_params(
+        self,
+        *,
+        lr: float,
+        weight_decay: float = 0.0,
+        optimizer: str | None = None,
+        optimizer_kwargs: dict | None = None,
+    ) -> tuple[Any, Any]:
         """Return (optimizer, trainable_params_tree)."""
 
     def apply_grads(self, optimizer: Any, grads: Any) -> None:
@@ -467,7 +467,54 @@ class MlxLMBackend:
             return vag(self.model, loss_fn)(self.model)
         return loss_fn(self.model), None
 
-    def optimizer_and_params(self, *, lr: float, weight_decay: float = 0.0) -> tuple[Any, Any]:
+    def optimizer_and_params(
+        self,
+        *,
+        lr: float,
+        weight_decay: float = 0.0,
+        optimizer: str | None = None,
+        optimizer_kwargs: dict | None = None,
+    ) -> tuple[Any, Any]:
+        return self._optimizer_and_params(
+            lr=lr,
+            weight_decay=weight_decay,
+            optimizer=optimizer,
+            optimizer_kwargs=optimizer_kwargs,
+        )
+
+    def _resolve_optimizer(self, name: str):
+        assert self.optim is not None
+        name = name.strip().lower()
+        if name in ("muon", "muonclip", "muon_clip"):
+            from ..optim.muon import Muon, MuonClip
+
+            return MuonClip if name in ("muonclip", "muon_clip") else Muon
+        mapping = {
+            "adamw": ["AdamW", "Adamw", "adamw"],
+            "adam": ["Adam", "adam"],
+            "sgd": ["SGD", "Sgd", "sgd"],
+            "rmsprop": ["RMSprop", "RmsProp", "rmsprop"],
+            "qhadam": ["QHAdam", "Qhadam", "qhadam"],
+            "muon": ["Muon", "muon", "MuonW", "muonw"],
+        }
+        candidates = mapping.get(name, [name, name.capitalize(), name.upper()])
+        for cand in candidates:
+            opt_cls = getattr(self.optim, cand, None)
+            if opt_cls is not None:
+                return opt_cls
+        raise RuntimeError(
+            f"Optimizer '{name}' is not available in mlx.optimizers. "
+            "Install a build that provides it or choose a supported optimizer."
+        )
+
+    def _optimizer_and_params(
+        self,
+        *,
+        lr: float,
+        weight_decay: float = 0.0,
+        optimizer: str | None = None,
+        optimizer_kwargs: dict | None = None,
+    ) -> tuple[Any, Any]:
         assert self.optim is not None
         if self.model is None:
             raise RuntimeError("Backend not loaded")
@@ -483,7 +530,17 @@ class MlxLMBackend:
         if not params:
             params = getattr(self.model, "parameters", lambda: self.model)()
 
-        opt = self.optim.AdamW(learning_rate=lr, weight_decay=weight_decay)
+        opt_name = optimizer or "adamw"
+        opt_cls = self._resolve_optimizer(opt_name)
+        kwargs = dict(optimizer_kwargs or {})
+        if "learning_rate" not in kwargs and "lr" not in kwargs:
+            kwargs["learning_rate"] = lr
+            kwargs["lr"] = lr
+        if "weight_decay" not in kwargs:
+            kwargs["weight_decay"] = weight_decay
+        opt = self._call_with_supported_kwargs(opt_cls, **kwargs)
+        if opt is None:
+            opt = opt_cls()
         opt.init(params)
         return opt, params
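
Both `learning_rate` and `lr` are injected because constructors disagree on the parameter name: MLX optimizers take `learning_rate`, while the bundled `Muon` accepts either. `_call_with_supported_kwargs` (an existing helper, not shown in this diff) presumably drops whichever spelling a given constructor does not declare; a sketch of that filtering idea, as an assumption rather than the package's actual implementation:

    import inspect

    def call_with_supported_kwargs(cls, **kwargs):
        try:
            accepted = set(inspect.signature(cls).parameters)
        except (TypeError, ValueError):
            return None  # not introspectable; the caller falls back to cls()
        return cls(**{k: v for k, v in kwargs.items() if k in accepted})
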
@@ -200,7 +200,14 @@ class MockBackend:
     def value_and_grad(self, loss_fn):
         return loss_fn(self.model), None
 
-    def optimizer_and_params(self, *, lr: float, weight_decay: float = 0.0):
+    def optimizer_and_params(
+        self,
+        *,
+        lr: float,
+        weight_decay: float = 0.0,
+        optimizer: str | None = None,
+        optimizer_kwargs: dict | None = None,
+    ):
         return object(), {}
 
     def apply_grads(self, optimizer: Any, grads: Any) -> None:
mlxsmith/optim/__init__.py ADDED
@@ -0,0 +1,3 @@
+from .muon import Muon, MuonClip
+
+__all__ = ["Muon", "MuonClip"]
mlxsmith/optim/muon.py ADDED
@@ -0,0 +1,93 @@
+from __future__ import annotations
+
+from typing import Any, Optional
+
+from ..util import tree_map
+
+
+class Muon:
+    """Muon optimizer wrapper (AdamW with Newton-Schulz orthogonalization for 2D grads)."""
+
+    def __init__(
+        self,
+        learning_rate: Optional[float] = None,
+        lr: Optional[float] = None,
+        weight_decay: float = 0.0,
+        clip: Optional[float] = None,
+        ns_iters: int = 5,
+        a: float = 3.4445,
+        b: float = -4.7750,
+        c: float = 2.0315,
+    ) -> None:
+        import mlx.optimizers as optim  # type: ignore
+
+        self.learning_rate = float(lr if lr is not None else learning_rate if learning_rate is not None else 1e-4)
+        self.weight_decay = float(weight_decay)
+        self.clip = clip
+        self.ns_iters = int(ns_iters)
+        self.a = float(a)
+        self.b = float(b)
+        self.c = float(c)
+        self._base = optim.AdamW(learning_rate=self.learning_rate, weight_decay=self.weight_decay)
+        self.state: dict[str, Any] = {}
+
+    def init(self, params: Any) -> None:
+        self._base.init(params)
+        self.state = getattr(self._base, "state", {})
+
+    def _clip_grad(self, g):
+        if self.clip is None or self.clip <= 0:
+            return g
+        import mlx.core as mx  # type: ignore
+
+        norm = mx.sqrt((g * g).sum())
+        scale = mx.minimum(mx.array(1.0), mx.array(float(self.clip)) / mx.maximum(norm, mx.array(1e-8)))
+        return g * scale
+
+    def _orthogonalize(self, g):
+        import mlx.core as mx  # type: ignore
+
+        g = self._clip_grad(g)
+        if getattr(g, "ndim", None) != 2:
+            return g
+        in_dim = int(g.shape[1])
+        eye = mx.eye(in_dim, dtype=g.dtype)
+        out = g
+        for _ in range(self.ns_iters):
+            gtg = mx.matmul(out.T, out)
+            gtg2 = mx.matmul(gtg, gtg)
+            out = mx.matmul(out, self.a * eye + self.b * gtg + self.c * gtg2)
+        return out
+
+    def update(self, model: Any, grads: Any) -> None:
+        if grads is None:
+            return
+        transformed = tree_map(self._orthogonalize, grads)
+        self._base.update(model, transformed)
+        self.state = getattr(self._base, "state", self.state)
+
+
+class MuonClip(Muon):
+    """Muon with explicit gradient clipping before orthogonalization."""
+
+    def __init__(
+        self,
+        learning_rate: Optional[float] = None,
+        lr: Optional[float] = None,
+        weight_decay: float = 0.0,
+        clip: Optional[float] = 1.0,
+        ns_iters: int = 5,
+        a: float = 3.4445,
+        b: float = -4.7750,
+        c: float = 2.0315,
+    ) -> None:
+        super().__init__(
+            learning_rate=learning_rate,
+            lr=lr,
+            weight_decay=weight_decay,
+            clip=clip,
+            ns_iters=ns_iters,
+            a=a,
+            b=b,
+            c=c,
+        )