mlxsmith 0.1.1__tar.gz → 0.1.3__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (97)
  1. {mlxsmith-0.1.1/src/mlxsmith.egg-info → mlxsmith-0.1.3}/PKG-INFO +29 -13
  2. {mlxsmith-0.1.1 → mlxsmith-0.1.3}/README.md +25 -10
  3. {mlxsmith-0.1.1 → mlxsmith-0.1.3}/pyproject.toml +3 -2
  4. {mlxsmith-0.1.1 → mlxsmith-0.1.3}/src/mlxsmith/accel/__init__.py +0 -3
  5. {mlxsmith-0.1.1 → mlxsmith-0.1.3}/src/mlxsmith/bench.py +12 -2
  6. {mlxsmith-0.1.1 → mlxsmith-0.1.3}/src/mlxsmith/cli.py +188 -3
  7. {mlxsmith-0.1.1 → mlxsmith-0.1.3}/src/mlxsmith/config_models.py +16 -2
  8. mlxsmith-0.1.3/src/mlxsmith/integrations/__init__.py +19 -0
  9. mlxsmith-0.1.3/src/mlxsmith/integrations/mlx_lm_lora.py +117 -0
  10. {mlxsmith-0.1.1 → mlxsmith-0.1.3}/src/mlxsmith/llm/backend.py +8 -1
  11. {mlxsmith-0.1.1 → mlxsmith-0.1.3}/src/mlxsmith/llm/mlx_lm_backend.py +59 -2
  12. {mlxsmith-0.1.1 → mlxsmith-0.1.3}/src/mlxsmith/llm/mock_backend.py +8 -1
  13. mlxsmith-0.1.3/src/mlxsmith/optim/__init__.py +3 -0
  14. mlxsmith-0.1.3/src/mlxsmith/optim/muon.py +93 -0
  15. mlxsmith-0.1.3/src/mlxsmith/orchestrator/daemon.py +116 -0
  16. {mlxsmith-0.1.1 → mlxsmith-0.1.3}/src/mlxsmith/orchestrator/trainer_worker.py +4 -0
  17. {mlxsmith-0.1.1 → mlxsmith-0.1.3}/src/mlxsmith/rlm/loop.py +53 -92
  18. {mlxsmith-0.1.1 → mlxsmith-0.1.3}/src/mlxsmith/sdk/__init__.py +18 -2
  19. {mlxsmith-0.1.1 → mlxsmith-0.1.3}/src/mlxsmith/sdk/losses.py +102 -1
  20. {mlxsmith-0.1.1 → mlxsmith-0.1.3}/src/mlxsmith/sdk/training_client.py +24 -5
  21. {mlxsmith-0.1.1 → mlxsmith-0.1.3}/src/mlxsmith/train/distill.py +6 -1
  22. mlxsmith-0.1.3/src/mlxsmith/train/online_dpo.py +249 -0
  23. {mlxsmith-0.1.1 → mlxsmith-0.1.3}/src/mlxsmith/train/pref.py +31 -29
  24. {mlxsmith-0.1.1 → mlxsmith-0.1.3}/src/mlxsmith/train/rft.py +123 -38
  25. mlxsmith-0.1.3/src/mlxsmith/train/self_verify.py +199 -0
  26. {mlxsmith-0.1.1 → mlxsmith-0.1.3}/src/mlxsmith/train/sft.py +13 -2
  27. {mlxsmith-0.1.1 → mlxsmith-0.1.3}/src/mlxsmith/util.py +0 -6
  28. mlxsmith-0.1.3/src/mlxsmith/verifiers/llm_judge.py +278 -0
  29. mlxsmith-0.1.3/src/mlxsmith/verifiers/prime.py +127 -0
  30. {mlxsmith-0.1.1 → mlxsmith-0.1.3/src/mlxsmith.egg-info}/PKG-INFO +29 -13
  31. {mlxsmith-0.1.1 → mlxsmith-0.1.3}/src/mlxsmith.egg-info/SOURCES.txt +11 -1
  32. {mlxsmith-0.1.1 → mlxsmith-0.1.3}/src/mlxsmith.egg-info/requires.txt +4 -3
  33. mlxsmith-0.1.3/tests/test_lora_integration.py +47 -0
  34. mlxsmith-0.1.3/tests/test_online_dpo_self_verify.py +61 -0
  35. mlxsmith-0.1.3/tests/test_pref_variants.py +31 -0
  36. {mlxsmith-0.1.1 → mlxsmith-0.1.3}/tests/test_sdk.py +17 -0
  37. {mlxsmith-0.1.1 → mlxsmith-0.1.3}/tests/test_training_smoke.py +6 -0
  38. mlxsmith-0.1.3/tests/test_verifiers.py +73 -0
  39. mlxsmith-0.1.1/src/mlxsmith/accel/zmlx_backend.py +0 -42
  40. mlxsmith-0.1.1/src/mlxsmith/orchestrator/daemon.py +0 -449
  41. mlxsmith-0.1.1/tests/test_verifiers.py +0 -25
  42. {mlxsmith-0.1.1 → mlxsmith-0.1.3}/LICENSE +0 -0
  43. {mlxsmith-0.1.1 → mlxsmith-0.1.3}/setup.cfg +0 -0
  44. {mlxsmith-0.1.1 → mlxsmith-0.1.3}/src/mlxsmith/__init__.py +0 -0
  45. {mlxsmith-0.1.1 → mlxsmith-0.1.3}/src/mlxsmith/accel/base.py +0 -0
  46. {mlxsmith-0.1.1 → mlxsmith-0.1.3}/src/mlxsmith/accel/none.py +0 -0
  47. {mlxsmith-0.1.1 → mlxsmith-0.1.3}/src/mlxsmith/adapters.py +0 -0
  48. {mlxsmith-0.1.1 → mlxsmith-0.1.3}/src/mlxsmith/api/__init__.py +0 -0
  49. {mlxsmith-0.1.1 → mlxsmith-0.1.3}/src/mlxsmith/api/handlers.py +0 -0
  50. {mlxsmith-0.1.1 → mlxsmith-0.1.3}/src/mlxsmith/api/schemas.py +0 -0
  51. {mlxsmith-0.1.1 → mlxsmith-0.1.3}/src/mlxsmith/auth.py +0 -0
  52. {mlxsmith-0.1.1 → mlxsmith-0.1.3}/src/mlxsmith/config.py +0 -0
  53. {mlxsmith-0.1.1 → mlxsmith-0.1.3}/src/mlxsmith/data.py +0 -0
  54. {mlxsmith-0.1.1 → mlxsmith-0.1.3}/src/mlxsmith/envs/__init__.py +0 -0
  55. {mlxsmith-0.1.1 → mlxsmith-0.1.3}/src/mlxsmith/envs/system.py +0 -0
  56. {mlxsmith-0.1.1 → mlxsmith-0.1.3}/src/mlxsmith/envs/token_env.py +0 -0
  57. {mlxsmith-0.1.1 → mlxsmith-0.1.3}/src/mlxsmith/eval.py +0 -0
  58. {mlxsmith-0.1.1 → mlxsmith-0.1.3}/src/mlxsmith/infer.py +0 -0
  59. {mlxsmith-0.1.1 → mlxsmith-0.1.3}/src/mlxsmith/llm/__init__.py +0 -0
  60. {mlxsmith-0.1.1 → mlxsmith-0.1.3}/src/mlxsmith/llm/interface.py +0 -0
  61. {mlxsmith-0.1.1 → mlxsmith-0.1.3}/src/mlxsmith/llm/registry.py +0 -0
  62. {mlxsmith-0.1.1 → mlxsmith-0.1.3}/src/mlxsmith/models.py +0 -0
  63. {mlxsmith-0.1.1 → mlxsmith-0.1.3}/src/mlxsmith/orchestrator/__init__.py +0 -0
  64. {mlxsmith-0.1.1 → mlxsmith-0.1.3}/src/mlxsmith/orchestrator/inference_worker.py +0 -0
  65. {mlxsmith-0.1.1 → mlxsmith-0.1.3}/src/mlxsmith/orchestrator/queue.py +0 -0
  66. {mlxsmith-0.1.1 → mlxsmith-0.1.3}/src/mlxsmith/rlm/__init__.py +0 -0
  67. {mlxsmith-0.1.1 → mlxsmith-0.1.3}/src/mlxsmith/rlm/corpus.py +0 -0
  68. {mlxsmith-0.1.1 → mlxsmith-0.1.3}/src/mlxsmith/rlm/gating.py +0 -0
  69. {mlxsmith-0.1.1 → mlxsmith-0.1.3}/src/mlxsmith/rlm/generate.py +0 -0
  70. {mlxsmith-0.1.1 → mlxsmith-0.1.3}/src/mlxsmith/rlm/history.py +0 -0
  71. {mlxsmith-0.1.1 → mlxsmith-0.1.3}/src/mlxsmith/rlm/inference.py +0 -0
  72. {mlxsmith-0.1.1 → mlxsmith-0.1.3}/src/mlxsmith/rlm/mutate.py +0 -0
  73. {mlxsmith-0.1.1 → mlxsmith-0.1.3}/src/mlxsmith/rlm/trainer.py +0 -0
  74. {mlxsmith-0.1.1 → mlxsmith-0.1.3}/src/mlxsmith/rlm/weights.py +0 -0
  75. {mlxsmith-0.1.1 → mlxsmith-0.1.3}/src/mlxsmith/runs.py +0 -0
  76. {mlxsmith-0.1.1 → mlxsmith-0.1.3}/src/mlxsmith/sdk/future.py +0 -0
  77. {mlxsmith-0.1.1 → mlxsmith-0.1.3}/src/mlxsmith/sdk/sampling_client.py +0 -0
  78. {mlxsmith-0.1.1 → mlxsmith-0.1.3}/src/mlxsmith/server.py +0 -0
  79. {mlxsmith-0.1.1 → mlxsmith-0.1.3}/src/mlxsmith/train/__init__.py +0 -0
  80. {mlxsmith-0.1.1 → mlxsmith-0.1.3}/src/mlxsmith/train/lora.py +0 -0
  81. {mlxsmith-0.1.1 → mlxsmith-0.1.3}/src/mlxsmith/verifiers/__init__.py +0 -0
  82. {mlxsmith-0.1.1 → mlxsmith-0.1.3}/src/mlxsmith/verifiers/compose.py +0 -0
  83. {mlxsmith-0.1.1 → mlxsmith-0.1.3}/src/mlxsmith/verifiers/docker_verifier.py +0 -0
  84. {mlxsmith-0.1.1 → mlxsmith-0.1.3}/src/mlxsmith/verifiers/jsonschema.py +0 -0
  85. {mlxsmith-0.1.1 → mlxsmith-0.1.3}/src/mlxsmith/verifiers/pytest_verifier.py +0 -0
  86. {mlxsmith-0.1.1 → mlxsmith-0.1.3}/src/mlxsmith/verifiers/regex.py +0 -0
  87. {mlxsmith-0.1.1 → mlxsmith-0.1.3}/src/mlxsmith/verifiers/types.py +0 -0
  88. {mlxsmith-0.1.1 → mlxsmith-0.1.3}/src/mlxsmith.egg-info/dependency_links.txt +0 -0
  89. {mlxsmith-0.1.1 → mlxsmith-0.1.3}/src/mlxsmith.egg-info/entry_points.txt +0 -0
  90. {mlxsmith-0.1.1 → mlxsmith-0.1.3}/src/mlxsmith.egg-info/top_level.txt +0 -0
  91. {mlxsmith-0.1.1 → mlxsmith-0.1.3}/tests/test_api.py +0 -0
  92. {mlxsmith-0.1.1 → mlxsmith-0.1.3}/tests/test_auth.py +0 -0
  93. {mlxsmith-0.1.1 → mlxsmith-0.1.3}/tests/test_config.py +0 -0
  94. {mlxsmith-0.1.1 → mlxsmith-0.1.3}/tests/test_data.py +0 -0
  95. {mlxsmith-0.1.1 → mlxsmith-0.1.3}/tests/test_rlm.py +0 -0
  96. {mlxsmith-0.1.1 → mlxsmith-0.1.3}/tests/test_rlm_mutation.py +0 -0
  97. {mlxsmith-0.1.1 → mlxsmith-0.1.3}/tests/test_runs.py +0 -0
{mlxsmith-0.1.1/src/mlxsmith.egg-info → mlxsmith-0.1.3}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: mlxsmith
-Version: 0.1.1
+Version: 0.1.3
 Summary: Apple Silicon MLX fine-tuning toolkit — SFT, DPO/ORPO, GRPO, distillation, and OpenAI-compatible serving.
 Author-email: Shannon Labs <hmbown@gmail.com>
 License: MIT
@@ -36,18 +36,19 @@ Provides-Extra: llm
 Requires-Dist: mlx-lm>=0.30.5; extra == "llm"
 Requires-Dist: transformers>=5.0.0; extra == "llm"
 Requires-Dist: datasets>=3.0.0; extra == "llm"
+Provides-Extra: lora
+Requires-Dist: mlx-lm-lora>=1.0.0; extra == "lora"
 Provides-Extra: serve
 Requires-Dist: fastapi>=0.128.0; extra == "serve"
 Requires-Dist: uvicorn>=0.40.0; extra == "serve"
 Requires-Dist: httpx>=0.28.0; extra == "serve"
-Provides-Extra: zmlx
-Requires-Dist: zmlx; extra == "zmlx"
 Provides-Extra: dev
 Requires-Dist: pytest>=9.0.0; extra == "dev"
 Requires-Dist: ruff>=0.14.0; extra == "dev"
 Provides-Extra: all
 Requires-Dist: mlx>=0.30.4; extra == "all"
 Requires-Dist: mlx-lm>=0.30.5; extra == "all"
+Requires-Dist: mlx-lm-lora>=1.0.0; extra == "all"
 Requires-Dist: transformers>=5.0.0; extra == "all"
 Requires-Dist: datasets>=3.0.0; extra == "all"
 Requires-Dist: fastapi>=0.128.0; extra == "all"
@@ -59,7 +60,7 @@ Dynamic: license-file
 
 Apple Silicon MLX fine-tuning toolkit — SFT, DPO/ORPO, GRPO, distillation, and OpenAI-compatible serving.
 
-**Status:** alpha (v0.1.0). Full training pipeline validated on Qwen3-4B.
+**Status:** alpha (v0.1.2). Full training pipeline validated on Qwen3-4B.
 
 ## Install
 
@@ -76,6 +77,9 @@ pip install mlxsmith
 # Apple Silicon training + serving
 pip install "mlxsmith[mlx,llm,serve]"
 
+# mlx-lm-lora passthrough (advanced training methods)
+pip install "mlxsmith[lora]"
+
 # Everything
 pip install "mlxsmith[all]"
 ```
@@ -85,7 +89,7 @@ pip install "mlxsmith[all]"
 ```bash
 mlxsmith init myproj
 cd myproj
-mlxsmith doctor  # check Python, MLX, Metal, ZMLX
+mlxsmith doctor  # check Python, MLX, Metal
 ```
 
 ## Training
@@ -133,6 +137,22 @@ mlxsmith distill --teacher large-model --student small-model --mode opd
 mlxsmith pipeline
 ```
 
+### mlx-lm-lora parity (all methods)
+
+Use the passthrough to access mlx-lm-lora features (DPO variants, GRPO variants,
+PPO, synthetic datasets, judge training, etc.):
+
+```bash
+# Train with mlx-lm-lora directly
+mlxsmith lora train --model Qwen/Qwen3-4B-Instruct-2507 --data data/prefs --train-mode dpo -- --beta 0.1
+
+# Generate synthetic datasets
+mlxsmith lora synthetic prompts -- --model mlx-community/Qwen3-4B-Instruct-2507-4bit --num-samples 1000
+
+# Train judge model
+mlxsmith lora judge -- --model mlx-community/Qwen3-4B-Instruct-2507-4bit --data data/prefs
+```
+
 ## Serving
 
 OpenAI-compatible `/v1/chat/completions` endpoint.
@@ -204,6 +224,7 @@ Built-in verifiers for eval, RFT, and preference tuning:
 - **pytest** — sandboxed test execution
 - **docker** — containerized verification
 - **compose** — multi-verifier composition (AND/OR/weighted)
+- **llm_judge** — LLM-based self-verification / ThinkPRM-style verifier
 
 See `docs/VERIFIERS.md` for the verifier API.
 
@@ -232,6 +253,9 @@ mlxsmith config env # show environment variable mapping
 
 Config sources (in priority order): CLI flags > environment variables (`MLXSMITH__SECTION__KEY`) > config file > defaults.
 
+Training optimizers are configurable via `train.optimizer` and `train.optimizer_kwargs`
+(for example `adamw`, `adam`, `qhadam`, `muon` when available in MLX).
+
 ## SDK (programmatic API)
 
 For building custom training loops:
@@ -269,14 +293,6 @@ mlxsmith rlm history # view history
 
 Includes task generation, mutation for data diversity, corpus management, EMA-based gating, and weight pointer IPC for multi-process coordination. See `docs/orchestrator.md`.
 
-### ZMLX acceleration
-
-Optional zero-copy MLX acceleration backend.
-
-```bash
-mlxsmith accel status
-```
-
 ## Docs
 
 - `docs/PROJECT_FORMAT.md` — project layout and artifacts
{mlxsmith-0.1.1 → mlxsmith-0.1.3}/README.md
@@ -2,7 +2,7 @@
 
 Apple Silicon MLX fine-tuning toolkit — SFT, DPO/ORPO, GRPO, distillation, and OpenAI-compatible serving.
 
-**Status:** alpha (v0.1.0). Full training pipeline validated on Qwen3-4B.
+**Status:** alpha (v0.1.2). Full training pipeline validated on Qwen3-4B.
 
 ## Install
 
@@ -19,6 +19,9 @@ pip install mlxsmith
 # Apple Silicon training + serving
 pip install "mlxsmith[mlx,llm,serve]"
 
+# mlx-lm-lora passthrough (advanced training methods)
+pip install "mlxsmith[lora]"
+
 # Everything
 pip install "mlxsmith[all]"
 ```
@@ -28,7 +31,7 @@ pip install "mlxsmith[all]"
 ```bash
 mlxsmith init myproj
 cd myproj
-mlxsmith doctor  # check Python, MLX, Metal, ZMLX
+mlxsmith doctor  # check Python, MLX, Metal
 ```
 
 ## Training
@@ -76,6 +79,22 @@ mlxsmith distill --teacher large-model --student small-model --mode opd
 mlxsmith pipeline
 ```
 
+### mlx-lm-lora parity (all methods)
+
+Use the passthrough to access mlx-lm-lora features (DPO variants, GRPO variants,
+PPO, synthetic datasets, judge training, etc.):
+
+```bash
+# Train with mlx-lm-lora directly
+mlxsmith lora train --model Qwen/Qwen3-4B-Instruct-2507 --data data/prefs --train-mode dpo -- --beta 0.1
+
+# Generate synthetic datasets
+mlxsmith lora synthetic prompts -- --model mlx-community/Qwen3-4B-Instruct-2507-4bit --num-samples 1000
+
+# Train judge model
+mlxsmith lora judge -- --model mlx-community/Qwen3-4B-Instruct-2507-4bit --data data/prefs
+```
+
 ## Serving
 
 OpenAI-compatible `/v1/chat/completions` endpoint.
@@ -147,6 +166,7 @@ Built-in verifiers for eval, RFT, and preference tuning:
 - **pytest** — sandboxed test execution
 - **docker** — containerized verification
 - **compose** — multi-verifier composition (AND/OR/weighted)
+- **llm_judge** — LLM-based self-verification / ThinkPRM-style verifier
 
 See `docs/VERIFIERS.md` for the verifier API.
 
@@ -175,6 +195,9 @@ mlxsmith config env # show environment variable mapping
 
 Config sources (in priority order): CLI flags > environment variables (`MLXSMITH__SECTION__KEY`) > config file > defaults.
 
+Training optimizers are configurable via `train.optimizer` and `train.optimizer_kwargs`
+(for example `adamw`, `adam`, `qhadam`, `muon` when available in MLX).
+
 ## SDK (programmatic API)
 
 For building custom training loops:
@@ -212,14 +235,6 @@ mlxsmith rlm history # view history
 
 Includes task generation, mutation for data diversity, corpus management, EMA-based gating, and weight pointer IPC for multi-process coordination. See `docs/orchestrator.md`.
 
-### ZMLX acceleration
-
-Optional zero-copy MLX acceleration backend.
-
-```bash
-mlxsmith accel status
-```
-
 ## Docs
 
 - `docs/PROJECT_FORMAT.md` — project layout and artifacts
{mlxsmith-0.1.1 → mlxsmith-0.1.3}/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "mlxsmith"
-version = "0.1.1"
+version = "0.1.3"
 description = "Apple Silicon MLX fine-tuning toolkit — SFT, DPO/ORPO, GRPO, distillation, and OpenAI-compatible serving."
 readme = {file = "README.md", content-type = "text/markdown"}
 requires-python = ">=3.10"
@@ -47,16 +47,17 @@ llm = [
     "transformers>=5.0.0",
     "datasets>=3.0.0",
 ]
+lora = ["mlx-lm-lora>=1.0.0"]
 serve = [
     "fastapi>=0.128.0",
     "uvicorn>=0.40.0",
     "httpx>=0.28.0",
 ]
-zmlx = ["zmlx"]
 dev = ["pytest>=9.0.0", "ruff>=0.14.0"]
 all = [
     "mlx>=0.30.4",
     "mlx-lm>=0.30.5",
+    "mlx-lm-lora>=1.0.0",
     "transformers>=5.0.0",
     "datasets>=3.0.0",
     "fastapi>=0.128.0",
{mlxsmith-0.1.1 → mlxsmith-0.1.3}/src/mlxsmith/accel/__init__.py
@@ -1,10 +1,7 @@
 from __future__ import annotations
 from .none import NoneBackend
-from .zmlx_backend import ZMLXBackend
 
 def get_backend(name: str):
     if name == "none":
         return NoneBackend()
-    if name == "zmlx":
-        return ZMLXBackend()
     raise ValueError(f"Unknown accel backend: {name}")
{mlxsmith-0.1.1 → mlxsmith-0.1.3}/src/mlxsmith/bench.py
@@ -44,7 +44,12 @@ def run_bench(
     mode = (mode or "inference").lower()
 
     if mode == "trainer":
-        opt, _params = llm.optimizer_and_params(lr=cfg.train.lr, weight_decay=cfg.train.weight_decay)
+        opt, _params = llm.optimizer_and_params(
+            lr=cfg.train.lr,
+            weight_decay=cfg.train.weight_decay,
+            optimizer=cfg.train.optimizer,
+            optimizer_kwargs=cfg.train.optimizer_kwargs,
+        )
         prompt_ids = llm.encode(prompt)
         ids = llm.encode(prompt + " " + "x" * max_tokens)
         for i in range(max(1, reps)):
@@ -59,7 +64,12 @@
             elapsed = max(time.time() - t0, 1e-6)
             results.append({"rep": i, "steps": steps, "time_s": elapsed, "steps_per_s": steps / elapsed})
     elif mode == "end_to_end":
-        opt, _params = llm.optimizer_and_params(lr=cfg.train.lr, weight_decay=cfg.train.weight_decay)
+        opt, _params = llm.optimizer_and_params(
+            lr=cfg.train.lr,
+            weight_decay=cfg.train.weight_decay,
+            optimizer=cfg.train.optimizer,
+            optimizer_kwargs=cfg.train.optimizer_kwargs,
+        )
         for i in range(max(1, reps)):
             t0 = time.time()
             gen = llm.generate(prompt, max_new_tokens=max_tokens, temperature=0.0)
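`optimizer_and_params` now receives the optimizer name and kwargs from the train config. A minimal sketch of how a backend could resolve those names to MLX optimizers; this is an assumption for illustration, not the actual code in `mlx_lm_backend.py` (its changes are +59 −2 above), and the `Muon` import is a plausible export from the new `src/mlxsmith/optim/muon.py`:

```python
import mlx.optimizers as optim

def resolve_optimizer(name: str, lr: float, weight_decay: float, kwargs: dict):
    # Hypothetical dispatcher; the README notes names like "adamw", "adam",
    # "qhadam", "muon" are accepted "when available in MLX".
    name = name.strip().lower()
    if name == "adamw":
        return optim.AdamW(learning_rate=lr, weight_decay=weight_decay, **kwargs)
    if name == "adam":
        return optim.Adam(learning_rate=lr, **kwargs)
    if name == "muon":
        from mlxsmith.optim import Muon  # assumed export from the new optim package
        return Muon(learning_rate=lr, **kwargs)
    raise ValueError(f"Unknown optimizer: {name}")
```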
{mlxsmith-0.1.1 → mlxsmith-0.1.3}/src/mlxsmith/cli.py
@@ -24,6 +24,8 @@ from .train.sft import run_sft
 from .train.pref import run_pref
 from .train.rft import run_rft
 from .train.distill import run_distill
+from .train.online_dpo import run_online_dpo
+from .train.self_verify import run_self_verify
 from .eval import run_eval
 from .bench import run_bench
 from .rlm import run_rlm, run_rlm_orchestrated
@@ -40,6 +42,13 @@ from .envs import (
     resolve_env_path as resolve_env_path_plugin,
     load_manifest as load_env_manifest,
 )
+from .integrations.mlx_lm_lora import (
+    build_train_command as build_mlx_lm_lora_train_command,
+    build_synthetic_command as build_mlx_lm_lora_synth_command,
+    build_judge_command as build_mlx_lm_lora_judge_command,
+    build_reward_functions_command as build_mlx_lm_lora_reward_functions_command,
+    run_command as run_mlx_lm_lora_command,
+)
 
 app = typer.Typer(
     add_completion=False,
@@ -65,6 +74,9 @@ def init(path: str = typer.Argument(..., help="Project directory to create")):
     (p / "verifiers" / "regex.py").write_text(_sample_verifier_regex(), encoding="utf-8")
     (p / "verifiers" / "pytest.py").write_text(_sample_verifier_pytest(), encoding="utf-8")
     (p / "verifiers" / "jsonschema.py").write_text(_sample_verifier_jsonschema(), encoding="utf-8")
+    (p / "verifiers" / "llm_judge.py").write_text(_sample_verifier_llm_judge(), encoding="utf-8")
+    (p / "verifiers" / "rubrics").mkdir(parents=True, exist_ok=True)
+    (p / "verifiers" / "rubrics" / "coding.txt").write_text(_sample_judge_rubric(), encoding="utf-8")
     (p / "eval" / "suites" / "coding.yaml").write_text(_sample_eval_suite(), encoding="utf-8")
     console.print(f"[green]Initialized[/green] {p.resolve()}")
 
@@ -83,7 +95,6 @@ def doctor():
     table.add_row("cpu_count", str(info.cpu_count))
     table.add_row("metal", str(info.has_metal))
     table.add_row("mlx", f"{info.has_mlx} {info.mlx_version or ''}".strip())
-    table.add_row("zmlx", str(info.has_zmlx))
     console.print(table)
 
 
@@ -342,14 +353,19 @@ def pref(
     data: str = typer.Option("data/prefs", "--data"),
     model: str = typer.Option(..., "--model", help="Base adapter or model path (e.g., runs/sft_0001/adapter)"),
     accel: Optional[str] = typer.Option(None, "--accel", help="Override accel.backend"),
-    algo: Optional[str] = typer.Option(None, "--algo", help="Override pref.algo (dpo|orpo|grpo)"),
+    algo: Optional[str] = typer.Option(None, "--algo", help="Override pref.algo (legacy)"),
+    loss_type: Optional[str] = typer.Option(None, "--loss-type", help="dpo|cpo|orpo|ipo|hinge"),
 ):
     root = project_root_from_cwd()
+    overrides = {}
+    if loss_type is not None:
+        overrides["pref.loss_type"] = loss_type
     cfg = get_config(
         config_path=config,
         root=root,
         accel_backend=accel,
         algo=algo,
+        **overrides,
     )
     data_dir = root / data
     run = run_pref(root, cfg, data_dir, Path(model), cfg.accel.backend)
@@ -364,13 +380,27 @@ def rft(
     model: str = typer.Option(..., "--model"),
     accel: Optional[str] = typer.Option(None, "--accel", help="Override accel.backend"),
     rollouts: Optional[int] = typer.Option(None, "--rollouts", help="Override rft.rollouts"),
+    loss_type: Optional[str] = typer.Option(None, "--loss-type", help="grpo|dr_grpo|dapo"),
+    epsilon_low: Optional[float] = typer.Option(None, "--epsilon-low"),
+    epsilon_high: Optional[float] = typer.Option(None, "--epsilon-high"),
+    token_level_loss: Optional[bool] = typer.Option(None, "--token-level-loss/--sequence-level-loss"),
 ):
     root = project_root_from_cwd()
+    overrides = {}
+    if loss_type is not None:
+        overrides["rft.loss_type"] = loss_type
+    if epsilon_low is not None:
+        overrides["rft.epsilon_low"] = epsilon_low
+    if epsilon_high is not None:
+        overrides["rft.epsilon_high"] = epsilon_high
+    if token_level_loss is not None:
+        overrides["rft.token_level_loss"] = token_level_loss
     cfg = get_config(
         config_path=config,
         root=root,
         accel_backend=accel,
         rollouts=rollouts,
+        **overrides,
    )
     run = run_rft(root, cfg, root / env, root / verifier, Path(model), cfg.accel.backend)
     console.print(f"[bold]Run:[/bold] {run.run_dir}")
@@ -438,6 +468,142 @@ def distill(
     console.print(f"[bold]Run:[/bold] {run.run_dir}")
 
 
+@app.command("online-dpo")
+def online_dpo(
+    data: str = typer.Option(..., "--data", help="JSONL with prompts"),
+    model: str = typer.Option(..., "--model"),
+    judge_model: Optional[str] = typer.Option(None, "--judge-model"),
+    judge_backend: str = typer.Option("mlx-lm", "--judge-backend"),
+    rubric: Optional[str] = typer.Option(None, "--rubric"),
+    group_size: Optional[int] = typer.Option(None, "--group-size"),
+    max_new_tokens: Optional[int] = typer.Option(None, "--max-new-tokens"),
+    temperature: Optional[float] = typer.Option(None, "--temperature"),
+    config: str = typer.Option("mlxsmith.yaml", "-c", "--config", help="Config file path"),
+    accel: Optional[str] = typer.Option(None, "--accel", help="Override accel.backend"),
+):
+    root = project_root_from_cwd()
+    cfg = get_config(config_path=config, root=root, accel_backend=accel)
+    run = run_online_dpo(
+        root,
+        cfg,
+        Path(data),
+        model,
+        cfg.accel.backend,
+        judge_model=judge_model,
+        judge_backend=judge_backend,
+        rubric=rubric,
+        group_size=group_size,
+        max_new_tokens=max_new_tokens,
+        temperature=temperature,
+    )
+    console.print(f"[bold]Run:[/bold] {run.run_dir}")
+
+
+@app.command("self-verify")
+def self_verify(
+    data: str = typer.Option(..., "--data", help="JSONL with prompts"),
+    model: str = typer.Option(..., "--model"),
+    verifier_model: Optional[str] = typer.Option(None, "--verifier-model"),
+    verifier_backend: str = typer.Option("mlx-lm", "--verifier-backend"),
+    rubric: Optional[str] = typer.Option(None, "--rubric"),
+    max_new_tokens: Optional[int] = typer.Option(None, "--max-new-tokens"),
+    temperature: Optional[float] = typer.Option(None, "--temperature"),
+    config: str = typer.Option("mlxsmith.yaml", "-c", "--config", help="Config file path"),
+    accel: Optional[str] = typer.Option(None, "--accel", help="Override accel.backend"),
+):
+    root = project_root_from_cwd()
+    cfg = get_config(config_path=config, root=root, accel_backend=accel)
+    run = run_self_verify(
+        root,
+        cfg,
+        Path(data),
+        model,
+        cfg.accel.backend,
+        verifier_model=verifier_model,
+        verifier_backend=verifier_backend,
+        rubric=rubric,
+        max_new_tokens=max_new_tokens,
+        temperature=temperature,
+    )
+    console.print(f"[bold]Run:[/bold] {run.run_dir}")
+
+
+lora_app = typer.Typer(help="mlx-lm-lora passthrough commands")
+app.add_typer(lora_app, name="lora")
+
+
+@lora_app.command(
+    "train",
+    context_settings={"allow_extra_args": True, "ignore_unknown_options": True},
+)
+def lora_train(
+    ctx: typer.Context,
+    config: Optional[str] = typer.Option(None, "--config", help="mlx-lm-lora config path"),
+    model: Optional[str] = typer.Option(None, "--model", help="Model id or path"),
+    data: Optional[str] = typer.Option(None, "--data", help="Dataset path or HF dataset"),
+    train_mode: Optional[str] = typer.Option(None, "--train-mode", help="sft|dpo|orpo|grpo|ppo|..."),
+    train_type: Optional[str] = typer.Option(None, "--train-type", help="lora|dora|full"),
+    dry_run: bool = typer.Option(False, "--dry-run"),
+):
+    """Run mlx-lm-lora training with passthrough args.
+
+    Use `--` to pass through any additional mlx-lm-lora flags.
+    """
+    root = project_root_from_cwd()
+    cmd = build_mlx_lm_lora_train_command(
+        config=config,
+        model=model,
+        data=data,
+        train_mode=train_mode,
+        train_type=train_type,
+        extra_args=list(ctx.args),
+    )
+    run_mlx_lm_lora_command(cmd, dry_run=dry_run, cwd=root)
+
+
+@lora_app.command(
+    "synthetic",
+    context_settings={"allow_extra_args": True, "ignore_unknown_options": True},
+)
+def lora_synthetic(
+    ctx: typer.Context,
+    kind: str = typer.Argument(..., help="prompts|sft|dpo"),
+    dry_run: bool = typer.Option(False, "--dry-run"),
+):
+    """Run mlx-lm-lora synthetic dataset generation."""
+    root = project_root_from_cwd()
+    cmd = build_mlx_lm_lora_synth_command(kind, extra_args=list(ctx.args))
+    run_mlx_lm_lora_command(cmd, dry_run=dry_run, cwd=root)
+
+
+@lora_app.command(
+    "judge",
+    context_settings={"allow_extra_args": True, "ignore_unknown_options": True},
+)
+def lora_judge(
+    ctx: typer.Context,
+    dry_run: bool = typer.Option(False, "--dry-run"),
+):
+    """Run mlx-lm-lora judge model training."""
+    root = project_root_from_cwd()
+    cmd = build_mlx_lm_lora_judge_command(extra_args=list(ctx.args))
+    run_mlx_lm_lora_command(cmd, dry_run=dry_run, cwd=root)
+
+
+@lora_app.command(
+    "reward-functions",
+    context_settings={"allow_extra_args": True, "ignore_unknown_options": True},
+)
+def lora_reward_functions(
+    ctx: typer.Context,
+    dry_run: bool = typer.Option(False, "--dry-run"),
+):
+    """List mlx-lm-lora reward functions."""
+    root = project_root_from_cwd()
+    cmd = build_mlx_lm_lora_reward_functions_command(extra_args=list(ctx.args))
+    run_mlx_lm_lora_command(cmd, dry_run=dry_run, cwd=root)
+
+
 @app.command()
 def eval(
     suite: str = typer.Option("eval/suites/coding.yaml", "--suite"),
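For orientation, a plausible shape of the loop behind the `online-dpo` command above, inferred only from its options (`group_size` samples per prompt, an LLM judge scoring them); the real implementation is `src/mlxsmith/train/online_dpo.py` (+249 lines) and will differ in detail:

```python
def online_dpo_pairs(llm, judge, prompt: str, group_size: int = 4, temperature: float = 0.8):
    # Sample a group of completions, score each with the judge model, and keep
    # the best/worst as a synthetic (chosen, rejected) preference pair.
    completions = [llm.generate(prompt, temperature=temperature) for _ in range(group_size)]
    scores = [judge.score(prompt, c) for c in completions]  # judge.score is hypothetical
    chosen = completions[scores.index(max(scores))]
    rejected = completions[scores.index(min(scores))]
    return chosen, rejected  # fed to a standard DPO update
```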
@@ -729,7 +895,7 @@ def rlm_history(limit: int = typer.Option(10, "--limit")):
 
 @accel_app.command("status")
 def accel_status():
-    backends = ["none", "zmlx"]
+    backends = ["none"]
     table = Table(title="mlxsmith accel status")
     table.add_column("backend")
     table.add_column("available")
@@ -934,6 +1100,25 @@ def verify(prompt: str, completion: str, workdir: str, **kwargs):
     """
 
 
+def _sample_verifier_llm_judge() -> str:
+    return """from mlxsmith.verifiers.llm_judge import verify as _verify
+
+def verify(prompt: str, completion: str, workdir: str, **kwargs):
+    # Pass model=... or set MLXSMITH_JUDGE_MODEL for the judge model id.
+    return _verify(prompt, completion, workdir, **kwargs)
+"""
+
+
+def _sample_judge_rubric() -> str:
+    return """Score from 0.0 to 1.0.
+- 1.0: Correct, complete, and safe.
+- 0.7: Mostly correct with small issues.
+- 0.4: Partial correctness or unclear reasoning.
+- 0.0: Incorrect or unsafe.
+Return JSON only.
+"""
+
+
 def _sample_eval_suite() -> str:
     return """name: coding-eval-sample
 notes: |
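The generated `verifiers/llm_judge.py` keeps the standard `verify(prompt, completion, workdir, **kwargs)` signature. A usage sketch from inside a project created by `mlxsmith init`; the `model` kwarg and env var come from the sample's own comment, while the return shape is whatever `mlxsmith.verifiers.llm_judge.verify` produces:

```python
from verifiers.llm_judge import verify  # the file written by `mlxsmith init`

result = verify(
    prompt="Reverse a string in Python.",
    completion="def rev(s):\n    return s[::-1]",
    workdir=".",
    model="mlx-community/Qwen3-4B-Instruct-2507-4bit",  # or set MLXSMITH_JUDGE_MODEL
)
```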
{mlxsmith-0.1.1 → mlxsmith-0.1.3}/src/mlxsmith/config_models.py
@@ -6,7 +6,7 @@ from typing import Dict, List, Literal, Optional, Any
 
 from pydantic import BaseModel, Field, field_validator
 
-AccelBackendName = Literal["none", "zmlx"]
+AccelBackendName = Literal["none"]
 
 
 class ModelConfig(BaseModel):
@@ -47,6 +47,8 @@ class TrainConfig(BaseModel):
     grad_accum: int = 8
     lr: float = 2e-4
     weight_decay: float = 0.0
+    optimizer: str = "adamw"
+    optimizer_kwargs: Dict[str, Any] = Field(default_factory=dict)
     iters: int = 1000
     save_every: int = 100
     eval_every: int = 100
@@ -61,6 +63,11 @@ class TrainConfig(BaseModel):
             raise ValueError("value must be non-negative")
         return v
 
+    @field_validator("optimizer")
+    @classmethod
+    def normalize_optimizer(cls, v: str) -> str:
+        return v.strip().lower()
+
 
 class LoraConfig(BaseModel):
     """LoRA/DoRA adapter configuration."""
@@ -89,11 +96,13 @@ class LoraConfig(BaseModel):
 
 
 class PrefConfig(BaseModel):
-    """Preference tuning configuration (DPO, ORPO, GRPO)."""
+    """Preference tuning configuration (DPO variants)."""
 
     algo: Literal["dpo", "orpo", "grpo"] = "dpo"
+    loss_type: Literal["dpo", "cpo", "orpo", "ipo", "hinge"] = "dpo"
     beta: float = 0.1
     kl_coeff: float = 0.0
+    delta: float = 0.0
     reference_model: Optional[str] = None
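The new `loss_type` values name standard preference objectives. A sketch of the textbook forms for three of them, written against MLX; these are the conventional formulations (DPO's logistic loss, IPO's squared loss, a hinge margin where `delta` is assumed to act as the margin target), not code lifted from `train/pref.py`:

```python
import mlx.core as mx

def pref_loss(margin, loss_type="dpo", beta=0.1, delta=0.0):
    # margin = (logp_chosen - logp_rejected) minus the same quantity under
    # the reference model.
    if loss_type == "dpo":
        return -mx.log(mx.sigmoid(beta * margin))
    if loss_type == "ipo":
        return (margin - 1.0 / (2.0 * beta)) ** 2
    if loss_type == "hinge":
        return mx.maximum(0.0, delta - beta * margin)  # delta as margin target (assumption)
    raise ValueError(f"unhandled loss_type: {loss_type}")
```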
 
 
@@ -101,12 +110,16 @@ class RftConfig(BaseModel):
     """Reinforcement fine-tuning configuration."""
 
     algo: Literal["grpo"] = "grpo"
+    loss_type: Literal["grpo", "dr_grpo", "dapo"] = "grpo"
     rollouts: int = 8
     kl_coeff: float = 0.02
     max_steps_per_task: int = 1
     temperature: float = 0.8
     max_new_tokens: int = 256
     normalize_advantage: bool = True
+    epsilon_low: float = 0.2
+    epsilon_high: float = 0.2
+    token_level_loss: bool = False
     reference_model: Optional[str] = None
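The separate `epsilon_low`/`epsilon_high` fields suggest DAPO-style asymmetric ratio clipping (plain GRPO uses a symmetric ε, and `token_level_loss` toggles per-token versus per-sequence averaging). A sketch of that objective under the usual formulation, as an illustration rather than the body of `train/rft.py`:

```python
import mlx.core as mx

def clipped_surrogate(ratio, advantage, eps_low=0.2, eps_high=0.2):
    # PPO-style clip with independent lower/upper bounds, as in DAPO's
    # "clip-higher"; eps_low == eps_high recovers the symmetric GRPO clip.
    clipped = mx.clip(ratio, 1.0 - eps_low, 1.0 + eps_high)
    return mx.minimum(ratio * advantage, clipped * advantage)
```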
 
 
@@ -164,6 +177,7 @@ CLI_ALIASES: dict[str, tuple[str, ...]] = {
     "lr": ("train", "lr"),
     "batch_size": ("train", "batch_size"),
     "iters": ("train", "iters"),
+    "optimizer": ("train", "optimizer"),
     "model_id": ("model", "id"),
     "accel_backend": ("accel", "backend"),
     "host": ("serve", "host"),
mlxsmith-0.1.3/src/mlxsmith/integrations/__init__.py
@@ -0,0 +1,19 @@
+"""External integrations for mlxsmith."""
+
+from .mlx_lm_lora import (
+    build_train_command as build_mlx_lm_lora_train_command,
+    build_synthetic_command as build_mlx_lm_lora_synth_command,
+    build_judge_command as build_mlx_lm_lora_judge_command,
+    build_reward_functions_command as build_mlx_lm_lora_reward_functions_command,
+    run_command as run_mlx_lm_lora_command,
+    ensure_available as ensure_mlx_lm_lora_available,
+)
+
+__all__ = [
+    "build_mlx_lm_lora_train_command",
+    "build_mlx_lm_lora_synth_command",
+    "build_mlx_lm_lora_judge_command",
+    "build_mlx_lm_lora_reward_functions_command",
+    "run_mlx_lm_lora_command",
+    "ensure_mlx_lm_lora_available",
+]
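Finally, a dry-run sketch of the passthrough builders re-exported here. The argument names match the `lora_train` command above; omitting `config`/`train_type` assumes they default to `None`, and the exact argv that `build_train_command` produces lives in `mlx_lm_lora.py` (+117 lines):

```python
from mlxsmith.integrations import (
    build_mlx_lm_lora_train_command,
    run_mlx_lm_lora_command,
)

cmd = build_mlx_lm_lora_train_command(
    model="Qwen/Qwen3-4B-Instruct-2507",
    data="data/prefs",
    train_mode="dpo",
    extra_args=["--beta", "0.1"],
)
run_mlx_lm_lora_command(cmd, dry_run=True)  # show the command without executing it
```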