inductive-mlxrl 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- inductive_mlxrl-0.1.0.dist-info/METADATA +388 -0
- inductive_mlxrl-0.1.0.dist-info/RECORD +24 -0
- inductive_mlxrl-0.1.0.dist-info/WHEEL +4 -0
- inductive_mlxrl-0.1.0.dist-info/entry_points.txt +2 -0
- inductive_mlxrl-0.1.0.dist-info/licenses/LICENSE +21 -0
- inductive_mlxrl-0.1.0.dist-info/licenses/THIRD_PARTY_LICENSES.md +46 -0
- mlxrl/__init__.py +6 -0
- mlxrl/algo/__init__.py +43 -0
- mlxrl/algo/grpo.py +568 -0
- mlxrl/algorithm.py +63 -0
- mlxrl/cli.py +793 -0
- mlxrl/config.py +397 -0
- mlxrl/data/__init__.py +31 -0
- mlxrl/data/gsm8k.py +60 -0
- mlxrl/data/rewards.py +101 -0
- mlxrl/policy/__init__.py +41 -0
- mlxrl/policy/logprobs.py +299 -0
- mlxrl/policy/model.py +300 -0
- mlxrl/py.typed +1 -0
- mlxrl/rollout/__init__.py +29 -0
- mlxrl/rollout/naive.py +186 -0
- mlxrl/rollout/optimized.py +644 -0
- mlxrl/train/__init__.py +5 -0
- mlxrl/train/grpo.py +304 -0
|
@@ -0,0 +1,388 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: inductive-mlxrl
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: Small single-process RL post-training for LLMs on Apple Silicon with MLX.
|
|
5
|
+
Project-URL: Homepage, https://github.com/inductiveML/mlxrl
|
|
6
|
+
Project-URL: Repository, https://github.com/inductiveML/mlxrl
|
|
7
|
+
Project-URL: Issues, https://github.com/inductiveML/mlxrl/issues
|
|
8
|
+
Project-URL: Changelog, https://github.com/inductiveML/mlxrl/blob/main/CHANGELOG.md
|
|
9
|
+
License-Expression: MIT
|
|
10
|
+
License-File: LICENSE
|
|
11
|
+
License-File: THIRD_PARTY_LICENSES.md
|
|
12
|
+
Keywords: apple-silicon,grpo,mlx,qlora,rl
|
|
13
|
+
Classifier: Development Status :: 3 - Alpha
|
|
14
|
+
Classifier: Intended Audience :: Developers
|
|
15
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
16
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
17
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
18
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
19
|
+
Requires-Python: >=3.11
|
|
20
|
+
Requires-Dist: mlx-lm>=0.31.0
|
|
21
|
+
Requires-Dist: mlx>=0.31.0
|
|
22
|
+
Description-Content-Type: text/markdown
|
|
23
|
+
|
|
24
|
+
# mlxrl
|
|
25
|
+
|
|
26
|
+
Fast on-policy MLX RL for Apple Silicon; not a general RL framework, not
|
|
27
|
+
preference tuning, and not distributed training.
|
|
28
|
+
|
|
29
|
+
`mlxrl` is a small, single-process RL post-training library for LLMs on Apple
|
|
30
|
+
Silicon. It is built around one idea: GRPO on MLX should be a fast batched
|
|
31
|
+
rollout path with a thin loss and optimizer step on top, not a framework.
|
|
32
|
+
|
|
33
|
+
The current implementation targets QLoRA GRPO on local 4-bit MLX models. It
|
|
34
|
+
reuses `mlx-lm` model loading, LoRA layers, KV caches, and sampling utilities,
|
|
35
|
+
and keeps generation and training in one Python process with one model object.
|
|
36
|
+
|
|
37
|
+
`mlxrl` is pre-1.0. The correctness gates are stable, but import APIs and config
|
|
38
|
+
fields may change before a 1.0 release.
|
|
39
|
+
|
|
40
|
+
## Quickstart
|
|
41
|
+
|
|
42
|
+
```bash
|
|
43
|
+
git clone https://github.com/inductiveML/mlxrl.git
|
|
44
|
+
cd mlxrl
|
|
45
|
+
UV_CACHE_DIR=.uv-cache uv sync --all-groups
|
|
46
|
+
UV_CACHE_DIR=.uv-cache uv run mlxrl train \
|
|
47
|
+
--config examples/qwen3_0_6b_grpo.toml \
|
|
48
|
+
--available-memory-gb 48
|
|
49
|
+
```
|
|
50
|
+
|
|
51
|
+
For the measured 9B-on-48GB shape, use the checkpointed G=2 config:
|
|
52
|
+
|
|
53
|
+
```bash
|
|
54
|
+
UV_CACHE_DIR=.uv-cache uv run mlxrl train \
|
|
55
|
+
--config examples/qwen35_9b_g2_checkpoint.toml \
|
|
56
|
+
--available-memory-gb 48 \
|
|
57
|
+
--dry-run
|
|
58
|
+
```
|
|
59
|
+
|
|
60
|
+
## What Works
|
|
61
|
+
|
|
62
|
+
- Batched group rollouts with MLX-LM KV caches and sampling.
|
|
63
|
+
- Full-forward old-policy logprob recompute for training-time `pi_old`.
|
|
64
|
+
- Adapter-disabled reference policy on the same model object.
|
|
65
|
+
- GRPO, Dr. GRPO, DAPO, and GSPO loss variants.
|
|
66
|
+
- RLOO (REINFORCE Leave-One-Out) as a critic-free rollout objective.
|
|
67
|
+
- QLoRA injection on dense and heterogeneous/hybrid layer stacks.
|
|
68
|
+
- Qwen3.5-style hybrid support via MLX-LM auto LoRA targeting, including
|
|
69
|
+
DeltaNet `linear_attn.in_proj_*` and dense attention `q/k/v/o_proj`.
|
|
70
|
+
- Per-layer gradient checkpointing through `mlx_lm.tuner.trainer.grad_checkpoint`
|
|
71
|
+
for linear-attention/DeltaNet backward memory.
|
|
72
|
+
- Micro-batched gradient accumulation for token-mean policy losses.
|
|
73
|
+
- `beta == 0` reference-forward skip.
|
|
74
|
+
- Phase 4 benchmark harness for `mlxrl`, `mlx-tune`, `mlx-lm-lora`, and `mlx-lm`.
|
|
75
|
+
|
|
76
|
+
## Install
|
|
77
|
+
|
|
78
|
+
Install the PyPI distribution:
|
|
79
|
+
|
|
80
|
+
```bash
|
|
81
|
+
pip install inductive-mlxrl
|
|
82
|
+
```
|
|
83
|
+
|
|
84
|
+
The Python import package and CLI command are still `mlxrl`:
|
|
85
|
+
|
|
86
|
+
```bash
|
|
87
|
+
mlxrl --help
|
|
88
|
+
```
|
|
89
|
+
|
|
90
|
+
Source installs are also supported:
|
|
91
|
+
|
|
92
|
+
```bash
|
|
93
|
+
UV_CACHE_DIR=.uv-cache uv sync --all-groups
|
|
94
|
+
```
|
|
95
|
+
|
|
96
|
+
Run commands through the local environment:
|
|
97
|
+
|
|
98
|
+
```bash
|
|
99
|
+
UV_CACHE_DIR=.uv-cache uv run mlxrl --help
|
|
100
|
+
```
|
|
101
|
+
|
|
102
|
+
Python 3.11+ is required. Runtime dependencies are intentionally small:
|
|
103
|
+
`mlx` and `mlx-lm`. Development dependencies include `pytest`, `ruff`,
|
|
104
|
+
`pyright`, `mlx-tune`, and `mlx-lm-lora` for comparison benchmarks.
|
|
105
|
+
The PyPI distribution name is `inductive-mlxrl`; the import package and console
|
|
106
|
+
script remain `mlxrl`.
|
|
107
|
+
|
|
108
|
+
## Quick Smoke Tests
|
|
109
|
+
|
|
110
|
+
Dense Qwen3 0.6B:
|
|
111
|
+
|
|
112
|
+
```bash
|
|
113
|
+
UV_CACHE_DIR=.uv-cache uv run mlxrl phase0-smoke \
|
|
114
|
+
--model mlx-community/Qwen3-0.6B-4bit \
|
|
115
|
+
--prompt "What is 2+2?"
|
|
116
|
+
```
|
|
117
|
+
|
|
118
|
+
Hybrid Qwen3.5 9B with rank-16 LoRA:
|
|
119
|
+
|
|
120
|
+
```bash
|
|
121
|
+
UV_CACHE_DIR=.uv-cache uv run mlxrl phase0-smoke \
|
|
122
|
+
--model mlx-community/Qwen3.5-9B-MLX-4bit \
|
|
123
|
+
--rank 16 \
|
|
124
|
+
--scale 2.0 \
|
|
125
|
+
--prompt "What is 2+2?"
|
|
126
|
+
```
|
|
127
|
+
|
|
128
|
+
The smoke gate prints the model id, layer count, LoRA target keys, per-layer
|
|
129
|
+
LoRA module counts, total/trainable parameter counts, and logits shape. It
|
|
130
|
+
fails if any trainable leaf is not `lora_a` or `lora_b`.
|
|
131
|
+
|
|
132
|
+
## Training Commands
|
|
133
|
+
|
|
134
|
+
Toy hand-computed GRPO math gate:
|
|
135
|
+
|
|
136
|
+
```bash
|
|
137
|
+
UV_CACHE_DIR=.uv-cache uv run mlxrl phase1-toy-gate
|
|
138
|
+
```
|
|
139
|
+
|
|
140
|
+
Small built-in GSM8K-style run:
|
|
141
|
+
|
|
142
|
+
```bash
|
|
143
|
+
UV_CACHE_DIR=.uv-cache uv run mlxrl phase1-gsm8k \
|
|
144
|
+
--model mlx-community/Qwen3-0.6B-4bit \
|
|
145
|
+
--steps 20 \
|
|
146
|
+
--group-size 4 \
|
|
147
|
+
--max-tokens 64
|
|
148
|
+
```
|
|
149
|
+
|
|
150
|
+
Config-driven run:
|
|
151
|
+
|
|
152
|
+
```bash
|
|
153
|
+
UV_CACHE_DIR=.uv-cache uv run mlxrl train \
|
|
154
|
+
--config examples/qwen3_0_6b_grpo.toml \
|
|
155
|
+
--available-memory-gb 48
|
|
156
|
+
```
|
|
157
|
+
|
|
158
|
+
The config schema validates model id, quant bits, group size, completion/prompt
|
|
159
|
+
lengths, checkpointing granularity, `iogpu.wired_limit_mb`, optimizer settings,
|
|
160
|
+
algorithm hyperparameters, KL beta, and seed before a model is loaded. CLI
|
|
161
|
+
overrides such as `--steps`, `--group-size`, `--max-tokens`, `--algorithm`,
|
|
162
|
+
`--beta`, and `--seed` apply on top of the file.
|
|
163
|
+
|
|
164
|
+
For DeltaNet / linear-attention models, enable per-layer checkpointing:
|
|
165
|
+
|
|
166
|
+
```bash
|
|
167
|
+
UV_CACHE_DIR=.uv-cache uv run mlxrl phase1-gsm8k \
|
|
168
|
+
--model mlx-community/Qwen3.5-9B-MLX-4bit \
|
|
169
|
+
--rank 16 \
|
|
170
|
+
--scale 2.0 \
|
|
171
|
+
--checkpoint-completion-forward \
|
|
172
|
+
--steps 1 \
|
|
173
|
+
--group-size 2 \
|
|
174
|
+
--max-tokens 256
|
|
175
|
+
```
|
|
176
|
+
|
|
177
|
+
Despite the historical CLI name, `--checkpoint-completion-forward` now enables
|
|
178
|
+
per-transformer-block checkpointing at model setup. The old whole-model
|
|
179
|
+
`mx.checkpoint(...)` wrapper was removed because it does not cap DeltaNet's
|
|
180
|
+
per-layer scan memory.
|
|
181
|
+
|
|
182
|
+
Phase 2 rollout equivalence check:
|
|
183
|
+
|
|
184
|
+
```bash
|
|
185
|
+
UV_CACHE_DIR=.uv-cache uv run mlxrl phase2-equivalence \
|
|
186
|
+
--model mlx-community/Qwen3-0.6B-4bit \
|
|
187
|
+
--group-size 4 \
|
|
188
|
+
--max-tokens 32 \
|
|
189
|
+
--compile-decode-step \
|
|
190
|
+
--batch-groups
|
|
191
|
+
```
|
|
192
|
+
|
|
193
|
+
## Import API
|
|
194
|
+
|
|
195
|
+
Minimal model setup:
|
|
196
|
+
|
|
197
|
+
```python
|
|
198
|
+
from mlxrl.policy import LoRAConfig, load_policy_with_lora
|
|
199
|
+
|
|
200
|
+
model, tokenizer, report = load_policy_with_lora(
|
|
201
|
+
model_id="mlx-community/Qwen3.5-9B-MLX-4bit",
|
|
202
|
+
config=LoRAConfig(
|
|
203
|
+
rank=16,
|
|
204
|
+
scale=2.0,
|
|
205
|
+
dropout=0.0,
|
|
206
|
+
grad_checkpoint=True,
|
|
207
|
+
),
|
|
208
|
+
)
|
|
209
|
+
```
|
|
210
|
+
|
|
211
|
+
One optimizer step:
|
|
212
|
+
|
|
213
|
+
```python
|
|
214
|
+
import mlx.optimizers as optim
|
|
215
|
+
|
|
216
|
+
from mlxrl.algo import GRPOAlgorithm
|
|
217
|
+
from mlxrl.train import batch_from_rollouts, optimizer_step
|
|
218
|
+
|
|
219
|
+
optimizer = optim.Adam(learning_rate=1e-5)
|
|
220
|
+
algorithm = GRPOAlgorithm()
|
|
221
|
+
batch = batch_from_rollouts(
|
|
222
|
+
model=model,
|
|
223
|
+
completions=completions,
|
|
224
|
+
rewards=rewards,
|
|
225
|
+
group_size=4,
|
|
226
|
+
pad_token_id=pad_token_id,
|
|
227
|
+
algorithm=algorithm,
|
|
228
|
+
compute_reference=beta != 0.0,
|
|
229
|
+
)
|
|
230
|
+
metrics = optimizer_step(
|
|
231
|
+
model=model,
|
|
232
|
+
optimizer=optimizer,
|
|
233
|
+
batch=batch,
|
|
234
|
+
beta=beta,
|
|
235
|
+
pad_token_id=pad_token_id,
|
|
236
|
+
algorithm=algorithm,
|
|
237
|
+
use_checkpoint=True,
|
|
238
|
+
micro_batch_size=2,
|
|
239
|
+
)
|
|
240
|
+
```
|
|
241
|
+
|
|
242
|
+
`micro_batch_size=0` keeps the original whole-batch path. Micro-batching is
|
|
243
|
+
currently exact for token-mean policy losses: base GRPO, DAPO, GSPO token mode,
|
|
244
|
+
RLOO, and Dr. GRPO with `loss_reduction="token_mean"`. Sequence-reduced losses
|
|
245
|
+
should keep `micro_batch_size=0`.
|
|
246
|
+
|
|
247
|
+
## Policy Semantics
|
|
248
|
+
|
|
249
|
+
- The base model is frozen before LoRA injection.
|
|
250
|
+
- Only LoRA adapter leaves are trainable.
|
|
251
|
+
- Reference logprobs are computed by temporarily disabling adapters on the same
|
|
252
|
+
model object; there is no second reference model in memory.
|
|
253
|
+
- Old-policy logprobs are recomputed with a full forward for the training batch.
|
|
254
|
+
Rollout-time logprobs are captured for inspection, but 4-bit sequential decode
|
|
255
|
+
and full-forward prefill are not numerically identical on hybrid/quantized
|
|
256
|
+
models, so recompute remains the default training semantics.
|
|
257
|
+
- When `beta == 0`, the reference forward is skipped and the policy logprobs are
|
|
258
|
+
used as a zero-KL placeholder.
|
|
259
|
+
- PPO, DPO, and ORPO are intentionally out of scope. PPO needs a separate critic
|
|
260
|
+
and value forward; DPO/ORPO are offline preference objectives with no rollout
|
|
261
|
+
phase. `mlxrl` is critic-free, on-policy, and rollout-based by design.
|
|
262
|
+
|
|
263
|
+
## Algorithms
|
|
264
|
+
|
|
265
|
+
Concrete algorithms implement the small `Algorithm` protocol: compute
|
|
266
|
+
advantages, optionally filter a prepared batch, then compute a loss from policy,
|
|
267
|
+
old-policy, and reference logprobs. `rollout/`, `policy/`, and `train/` do not
|
|
268
|
+
import concrete algorithm implementations.
|
|
269
|
+
|
|
270
|
+
| algorithm | defining behavior |
|
|
271
|
+
| --- | --- |
|
|
272
|
+
| GRPO | group-normalized rewards, token-level importance ratio |
|
|
273
|
+
| Dr. GRPO | centered or normalized rewards with decoupled length reduction |
|
|
274
|
+
| DAPO | asymmetric low/high clipping plus optional dynamic zero-advantage group filtering |
|
|
275
|
+
| GSPO | sequence-level, length-normalized importance ratio and clipping |
|
|
276
|
+
| RLOO | leave-one-out group baseline, no critic, no std-normalized advantage |
|
|
277
|
+
|
|
278
|
+
## Memory Preflight
|
|
279
|
+
|
|
280
|
+
`mlxrl train` can estimate memory before loading the model:
|
|
281
|
+
|
|
282
|
+
```bash
|
|
283
|
+
UV_CACHE_DIR=.uv-cache uv run mlxrl train \
|
|
284
|
+
--config examples/qwen3_0_6b_grpo.toml \
|
|
285
|
+
--available-memory-gb 48 \
|
|
286
|
+
--dry-run
|
|
287
|
+
```
|
|
288
|
+
|
|
289
|
+
The estimator is calibrated to measured anchors: `6.245 GB` for
|
|
290
|
+
Qwen3-0.6B/G4/prompt≈19/T256, `25.9 GB` for
|
|
291
|
+
Qwen3.5-9B/G2/seq609/per-layer-checkpointed, `45.9 GB` for
|
|
292
|
+
Qwen3.5-9B/G4/seq609/per-layer-checkpointed, and `36 GB` for
|
|
293
|
+
Qwen3.5-9B/G2/seq128/no-checkpoint. For hybrid 9B no-checkpoint long-sequence
|
|
294
|
+
configs, it reports an OOM-risk lower bound rather than a falsely precise peak estimate.
|
|
295
|
+
For an obviously too-large Qwen3.5-9B/G8/prompt97/T512/no-checkpoint config on
|
|
296
|
+
48 GB, it flags the run and suggests the measured-boundary fallback around
|
|
297
|
+
G4/T512/checkpointed.
|
|
298
|
+
|
|
299
|
+
## Benchmarks
|
|
300
|
+
|
|
301
|
+
Local M4 Max Phase 4 snapshot:
|
|
302
|
+
|
|
303
|
+
- `454` rollout tok/s on Qwen3-0.6B GRPO with G=4 and 256-token completions.
|
|
304
|
+
- `0.283` end-to-end it/s with full `mlxrl` training semantics.
|
|
305
|
+
- `3.2x` faster rollout and `2.2x` higher end-to-end it/s than `mlx-tune`
|
|
306
|
+
v0.5.1 on the same run shape.
|
|
307
|
+
- `1.3x` faster rollout than sequential `mlx-lm` generation at G=4.
|
|
308
|
+
|
|
309
|
+
These are the two-pass means from
|
|
310
|
+
`benchmarks/results/gate5_full_reconciled.md`, run with MLX 0.31.2,
|
|
311
|
+
MLX-LM 0.31.3, `mlx-community/Qwen3-0.6B-4bit`, 100 measured steps with
|
|
312
|
+
5 warmup steps discarded:
|
|
313
|
+
|
|
314
|
+
| target | comparison | rollout tok/s | grad s/step | samples/s | it/s | peak GB |
|
|
315
|
+
| --- | --- | ---: | ---: | ---: | ---: | ---: |
|
|
316
|
+
| `mlxrl` | apples-to-apples GRPO | 454.1 | 1.282 | 1.133 | 0.283 | 6.25 |
|
|
317
|
+
| `mlx-lm` | generation-only, G=1 | 347.0 | - | 1.355 | - | 0.52 |
|
|
318
|
+
| `mlx-lm-g4` | generation-only, sequential G=4 | 349.7 | - | 1.366 | - | 0.52 |
|
|
319
|
+
| `mlx-tune` | package-speed reference | 142.2 | 0.502 | 0.519 | 0.130 | 6.16 |
|
|
320
|
+
| `mlx-lm-lora` | package-speed reference | 557.9 | 0.592 | 1.648 | 0.412 | 5.32 |
|
|
321
|
+
|
|
322
|
+
`mlx-lm-lora` reports higher raw package-speed throughput in this snapshot, but
|
|
323
|
+
its benchmarked path does not match `mlxrl`'s live
|
|
324
|
+
old-policy/reference semantics and completion-loss masking. To be transparent, this is a
|
|
325
|
+
case where `mlxrl` is not faster; the apples-to-apples comparison label is
|
|
326
|
+
reserved for `mlxrl`'s own semantic path. On the real 9B Noether workload, the
|
|
327
|
+
checkpointed MLX path measured about 6x faster than the previous torch-MPS path;
|
|
328
|
+
that workload is separate from the public Phase 4 package-speed harness.
|
|
329
|
+
|
|
330
|
+
Run the Phase 4 harness:
|
|
331
|
+
|
|
332
|
+
```bash
|
|
333
|
+
UV_CACHE_DIR=.uv-cache uv run python benchmarks/run_phase4.py run \
|
|
334
|
+
--targets mlxrl,mlx-lm,mlx-tune,mlx-lm-lora \
|
|
335
|
+
--model mlx-community/Qwen3-0.6B-4bit \
|
|
336
|
+
--steps 100 \
|
|
337
|
+
--warmup-steps 5 \
|
|
338
|
+
--group-size 4 \
|
|
339
|
+
--max-tokens 256 \
|
|
340
|
+
--passes 2 \
|
|
341
|
+
--output benchmarks/results/phase4.jsonl \
|
|
342
|
+
--summary benchmarks/results/phase4.md \
|
|
343
|
+
--allow-missing-baselines
|
|
344
|
+
```
|
|
345
|
+
|
|
346
|
+
The harness reports synchronized rollout tok/s, gradient seconds per step,
|
|
347
|
+
samples/s, it/s, and peak MLX memory. `mlx-lm` targets are generation-only;
|
|
348
|
+
external package targets are useful speed references but may not match `mlxrl`
|
|
349
|
+
training semantics.
|
|
350
|
+
|
|
351
|
+
## Development
|
|
352
|
+
|
|
353
|
+
See [CONTRIBUTING.md](https://github.com/inductiveML/mlxrl/blob/main/CONTRIBUTING.md) and [DESIGN.md](https://github.com/inductiveML/mlxrl/blob/main/DESIGN.md) before adding
|
|
354
|
+
algorithms or changing rollout/logprob semantics.
|
|
355
|
+
|
|
356
|
+
Run the quality gates:
|
|
357
|
+
|
|
358
|
+
```bash
|
|
359
|
+
UV_CACHE_DIR=.uv-cache uv run pytest
|
|
360
|
+
UV_CACHE_DIR=.uv-cache uv run ruff check .
|
|
361
|
+
UV_CACHE_DIR=.uv-cache uv run pyright
|
|
362
|
+
```
|
|
363
|
+
|
|
364
|
+
MLX lazy evaluation matters. Any `mx.eval(...)` or `mx.synchronize()` in this
|
|
365
|
+
repo should mark a real boundary: sampled token append/EOS checks, logprob
|
|
366
|
+
freezing before adapter mutation, per-micro-batch graph release, optimizer
|
|
367
|
+
updates, or benchmark timing boundaries.
|
|
368
|
+
|
|
369
|
+
## Layout
|
|
370
|
+
|
|
371
|
+
```text
|
|
372
|
+
mlxrl/
|
|
373
|
+
rollout/ # batched group generation
|
|
374
|
+
policy/ # model loading, LoRA setup, logprob passes
|
|
375
|
+
algo/ # GRPO-family advantages and losses
|
|
376
|
+
train/ # value_and_grad and optimizer integration
|
|
377
|
+
data/ # toy GSM8K data and rewards
|
|
378
|
+
cli.py
|
|
379
|
+
tests/
|
|
380
|
+
benchmarks/
|
|
381
|
+
```
|
|
382
|
+
|
|
383
|
+
## Non-Goals
|
|
384
|
+
|
|
385
|
+
- No inference server or second model copy.
|
|
386
|
+
- No CUDA or torch fallback.
|
|
387
|
+
- No distributed training.
|
|
388
|
+
- No broad RL framework abstractions beyond the small algorithm interface.
|
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
mlxrl/__init__.py,sha256=OmZqyDOcKBeRwSNfiu_wdm_zZOvPdooMOaYTnzoHLtU,122
|
|
2
|
+
mlxrl/algorithm.py,sha256=hBhTlIVOnL69_CoE7pxkt_zCk63ImxIJHi2VG7KSAYE,1778
|
|
3
|
+
mlxrl/cli.py,sha256=6aovbx1Sv4r02RLrmexTDsbQGxXpcAy8rv8LDuIqrxo,29934
|
|
4
|
+
mlxrl/config.py,sha256=RiWu-pWl8u8OGJrI8Rr2ozIiA-ecu5fwntM3G9cQZ1g,14675
|
|
5
|
+
mlxrl/py.typed,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
|
|
6
|
+
mlxrl/algo/__init__.py,sha256=VblutTkTpF4PQ-qKiCYUQ9EizJ3cUPGixqsQsWbbpPM,942
|
|
7
|
+
mlxrl/algo/grpo.py,sha256=SQfY38IY_zr33S2Ph49xS4HN3LylWQYh2qUejw1jJMg,18824
|
|
8
|
+
mlxrl/data/__init__.py,sha256=udtJ0TQfbEEj1LELw5NynXpsOv2_afFft54dKBgZCZ0,589
|
|
9
|
+
mlxrl/data/gsm8k.py,sha256=SBc6SIRq1TlPOBe9KhIJP0KXxn8Gmay2f1yEFpyDoE8,2075
|
|
10
|
+
mlxrl/data/rewards.py,sha256=IT6N07DkXgQidMI9BJbGG9KkiprGaa6QQpnIaLxDWKo,2988
|
|
11
|
+
mlxrl/policy/__init__.py,sha256=ieu_rA9aq6BQdDXohixt02i1Ka3wqeTT7JLeBWHFy9s,1025
|
|
12
|
+
mlxrl/policy/logprobs.py,sha256=nAFjIE7MUWUAQ_7faFdF3okOjNlLT1eagJ4S8BA7IzI,10411
|
|
13
|
+
mlxrl/policy/model.py,sha256=DA_MWaugxsIl_s0QMD3bzwEwN9rxmZG6yoc7PbKAX6M,10015
|
|
14
|
+
mlxrl/rollout/__init__.py,sha256=MCSESkZR7FDIZwj50J_ThA_RInquCQHB2Jh6OnZM9NM,847
|
|
15
|
+
mlxrl/rollout/naive.py,sha256=Jo3B7RChEJI3ty0GC3mRHyVVzCgCWLiFEPRpnrQWlQs,5843
|
|
16
|
+
mlxrl/rollout/optimized.py,sha256=MDaDVBHdqF4FeXGp6r4N8umDnOFFii-hc3JIJXd2p-g,21274
|
|
17
|
+
mlxrl/train/__init__.py,sha256=oSs0IZBzMjw_eJgUBlyfd0XisIMz0SCvJWXjupE5q6M,219
|
|
18
|
+
mlxrl/train/grpo.py,sha256=rgWEq16D07jfFKjLmJszlS016ViUmLoc5Ece0gszxxk,10486
|
|
19
|
+
inductive_mlxrl-0.1.0.dist-info/METADATA,sha256=a6PV1SFsPyB98ZZTyBLiVuRXLDLBmnbsReLO8QQb0Ec,12932
|
|
20
|
+
inductive_mlxrl-0.1.0.dist-info/WHEEL,sha256=mffPy8wBnZQn2VnJUU5jE99KsxaSfiyMHV9Yt0aLVxs,87
|
|
21
|
+
inductive_mlxrl-0.1.0.dist-info/entry_points.txt,sha256=IyD-wJ1EUMhrd40AjCGTEmDyMMb9M3lNyRuzB3Vd63w,41
|
|
22
|
+
inductive_mlxrl-0.1.0.dist-info/licenses/LICENSE,sha256=Q2qW7SWCM5UjB3Rxs2rmhMAhhk5IbsI2sYuhmn5mgUo,1075
|
|
23
|
+
inductive_mlxrl-0.1.0.dist-info/licenses/THIRD_PARTY_LICENSES.md,sha256=X5OXyGhte6HfWf3B1B4VwjoL8e6OoXpFEfHF5v9CjrE,2067
|
|
24
|
+
inductive_mlxrl-0.1.0.dist-info/RECORD,,
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2026 mlxrl contributors
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
|
@@ -0,0 +1,46 @@
|
|
|
1
|
+
# Third-Party Notices
|
|
2
|
+
|
|
3
|
+
`mlxrl` reuses and adapts APIs, cache semantics, and sampling utilities from
|
|
4
|
+
upstream MLX projects. The repository keeps local attribution headers on files
|
|
5
|
+
that adapt those patterns.
|
|
6
|
+
|
|
7
|
+
## mlx-lm
|
|
8
|
+
|
|
9
|
+
- Source: https://github.com/ml-explore/mlx-lm
|
|
10
|
+
- License: MIT
|
|
11
|
+
- Used for: model loading, LoRA utilities, gradient checkpoint helper,
|
|
12
|
+
prompt/KV-cache construction, cache state conventions, and sampling filters.
|
|
13
|
+
|
|
14
|
+
Upstream notice:
|
|
15
|
+
|
|
16
|
+
```text
|
|
17
|
+
MIT License Copyright © 2023 Apple Inc.
|
|
18
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy of
|
|
19
|
+
this software and associated documentation files (the "Software"), to deal in
|
|
20
|
+
the Software without restriction, including without limitation the rights to use,
|
|
21
|
+
copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
|
|
22
|
+
Software, and to permit persons to whom the Software is furnished to do so,
|
|
23
|
+
subject to the following conditions: The above copyright notice and this
|
|
24
|
+
permission notice shall be included in all copies or substantial portions of the
|
|
25
|
+
Software.
|
|
26
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
27
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
|
|
28
|
+
FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
|
|
29
|
+
COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
|
|
30
|
+
IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
|
|
31
|
+
CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
|
32
|
+
```
|
|
33
|
+
|
|
34
|
+
## mlx-tune
|
|
35
|
+
|
|
36
|
+
- Source: https://github.com/ARahim3/mlx-tune
|
|
37
|
+
- License: Apache-2.0 per installed package metadata.
|
|
38
|
+
- Used for: benchmark configuration inspection and external baseline adapter
|
|
39
|
+
patterns only. No `mlx-tune` implementation code is vendored in `mlxrl`.
|
|
40
|
+
|
|
41
|
+
## mlx-lm-lora
|
|
42
|
+
|
|
43
|
+
- Source: https://github.com/Goekdeniz-Guelmez/mlx-lm-lora
|
|
44
|
+
- License: MIT per installed package metadata.
|
|
45
|
+
- Used for: benchmark configuration inspection and external baseline adapter
|
|
46
|
+
patterns only. No `mlx-lm-lora` implementation code is vendored in `mlxrl`.
|
mlxrl/__init__.py
ADDED
mlxrl/algo/__init__.py
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
1
|
+
"""Algorithm-specific advantage and loss code."""
|
|
2
|
+
|
|
3
|
+
from mlxrl.algo.grpo import (
|
|
4
|
+
AlgorithmLossMetrics,
|
|
5
|
+
DAPOAlgorithm,
|
|
6
|
+
DrGRPOAlgorithm,
|
|
7
|
+
GRPOAlgorithm,
|
|
8
|
+
GRPOLossMetrics,
|
|
9
|
+
GSPOAlgorithm,
|
|
10
|
+
PolicyAlgorithm,
|
|
11
|
+
RLOOAlgorithm,
|
|
12
|
+
algorithm_by_name,
|
|
13
|
+
approximate_kl,
|
|
14
|
+
clip_ratio,
|
|
15
|
+
filter_zero_advantage_groups,
|
|
16
|
+
group_center_rewards,
|
|
17
|
+
group_normalize_rewards,
|
|
18
|
+
grpo_loss,
|
|
19
|
+
masked_mean,
|
|
20
|
+
sequence_importance_ratio,
|
|
21
|
+
token_importance_ratio,
|
|
22
|
+
)
|
|
23
|
+
|
|
24
|
+
__all__ = [
|
|
25
|
+
"AlgorithmLossMetrics",
|
|
26
|
+
"DAPOAlgorithm",
|
|
27
|
+
"DrGRPOAlgorithm",
|
|
28
|
+
"GRPOAlgorithm",
|
|
29
|
+
"GRPOLossMetrics",
|
|
30
|
+
"GSPOAlgorithm",
|
|
31
|
+
"PolicyAlgorithm",
|
|
32
|
+
"RLOOAlgorithm",
|
|
33
|
+
"algorithm_by_name",
|
|
34
|
+
"approximate_kl",
|
|
35
|
+
"clip_ratio",
|
|
36
|
+
"filter_zero_advantage_groups",
|
|
37
|
+
"group_center_rewards",
|
|
38
|
+
"group_normalize_rewards",
|
|
39
|
+
"grpo_loss",
|
|
40
|
+
"masked_mean",
|
|
41
|
+
"sequence_importance_ratio",
|
|
42
|
+
"token_importance_ratio",
|
|
43
|
+
]
|