inductive-mlxrl 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- inductive_mlxrl-0.1.0/.gitignore +14 -0
- inductive_mlxrl-0.1.0/CHANGELOG.md +29 -0
- inductive_mlxrl-0.1.0/CONTRIBUTING.md +107 -0
- inductive_mlxrl-0.1.0/DESIGN.md +72 -0
- inductive_mlxrl-0.1.0/LICENSE +21 -0
- inductive_mlxrl-0.1.0/PKG-INFO +388 -0
- inductive_mlxrl-0.1.0/README.md +365 -0
- inductive_mlxrl-0.1.0/THIRD_PARTY_LICENSES.md +46 -0
- inductive_mlxrl-0.1.0/benchmarks/__init__.py +1 -0
- inductive_mlxrl-0.1.0/benchmarks/audit_phase4_config.py +385 -0
- inductive_mlxrl-0.1.0/benchmarks/external_baseline_worker.py +553 -0
- inductive_mlxrl-0.1.0/benchmarks/results/gate1_mlxrl_after.jsonl +1 -0
- inductive_mlxrl-0.1.0/benchmarks/results/gate1_mlxrl_after.md +9 -0
- inductive_mlxrl-0.1.0/benchmarks/results/gate1_mlxrl_after_final.jsonl +1 -0
- inductive_mlxrl-0.1.0/benchmarks/results/gate1_mlxrl_after_final.md +9 -0
- inductive_mlxrl-0.1.0/benchmarks/results/gate1_mlxrl_after_noprefixsync.jsonl +1 -0
- inductive_mlxrl-0.1.0/benchmarks/results/gate1_mlxrl_after_noprefixsync.md +9 -0
- inductive_mlxrl-0.1.0/benchmarks/results/gate1_mlxrl_before.jsonl +1 -0
- inductive_mlxrl-0.1.0/benchmarks/results/gate1_mlxrl_before.md +9 -0
- inductive_mlxrl-0.1.0/benchmarks/results/gate2_unified_smoke.jsonl +5 -0
- inductive_mlxrl-0.1.0/benchmarks/results/gate2_unified_smoke.md +17 -0
- inductive_mlxrl-0.1.0/benchmarks/results/gate3_mlx_tune_sync_256.jsonl +1 -0
- inductive_mlxrl-0.1.0/benchmarks/results/gate3_mlx_tune_sync_256.md +9 -0
- inductive_mlxrl-0.1.0/benchmarks/results/gate3_sync_blocked.jsonl +5 -0
- inductive_mlxrl-0.1.0/benchmarks/results/gate3_sync_blocked.md +11 -0
- inductive_mlxrl-0.1.0/benchmarks/results/gate3_sync_smoke.jsonl +5 -0
- inductive_mlxrl-0.1.0/benchmarks/results/gate3_sync_smoke.md +17 -0
- inductive_mlxrl-0.1.0/benchmarks/results/gate4_config_audit.json +102 -0
- inductive_mlxrl-0.1.0/benchmarks/results/gate4_config_audit.md +9 -0
- inductive_mlxrl-0.1.0/benchmarks/results/gate4_sampler_smoke.jsonl +5 -0
- inductive_mlxrl-0.1.0/benchmarks/results/gate4_sampler_smoke.md +17 -0
- inductive_mlxrl-0.1.0/benchmarks/results/gate5_full_reconciled.jsonl +10 -0
- inductive_mlxrl-0.1.0/benchmarks/results/gate5_full_reconciled.md +28 -0
- inductive_mlxrl-0.1.0/benchmarks/results/phase4_external_smoke.jsonl +2 -0
- inductive_mlxrl-0.1.0/benchmarks/results/phase4_external_smoke.md +11 -0
- inductive_mlxrl-0.1.0/benchmarks/results/phase4_full_all_targets.jsonl +8 -0
- inductive_mlxrl-0.1.0/benchmarks/results/phase4_full_all_targets.md +19 -0
- inductive_mlxrl-0.1.0/benchmarks/results/phase4_local_mlxrl_vs_mlx_lm.jsonl +4 -0
- inductive_mlxrl-0.1.0/benchmarks/results/phase4_local_mlxrl_vs_mlx_lm.md +13 -0
- inductive_mlxrl-0.1.0/benchmarks/results/phase4_mlx_lm_g4_smoke.jsonl +2 -0
- inductive_mlxrl-0.1.0/benchmarks/results/phase4_mlx_lm_g4_smoke.md +11 -0
- inductive_mlxrl-0.1.0/benchmarks/run_phase4.py +945 -0
- inductive_mlxrl-0.1.0/mlxrl/__init__.py +6 -0
- inductive_mlxrl-0.1.0/mlxrl/algo/__init__.py +43 -0
- inductive_mlxrl-0.1.0/mlxrl/algo/grpo.py +568 -0
- inductive_mlxrl-0.1.0/mlxrl/algorithm.py +63 -0
- inductive_mlxrl-0.1.0/mlxrl/cli.py +793 -0
- inductive_mlxrl-0.1.0/mlxrl/config.py +397 -0
- inductive_mlxrl-0.1.0/mlxrl/data/__init__.py +31 -0
- inductive_mlxrl-0.1.0/mlxrl/data/gsm8k.py +60 -0
- inductive_mlxrl-0.1.0/mlxrl/data/rewards.py +101 -0
- inductive_mlxrl-0.1.0/mlxrl/policy/__init__.py +41 -0
- inductive_mlxrl-0.1.0/mlxrl/policy/logprobs.py +299 -0
- inductive_mlxrl-0.1.0/mlxrl/policy/model.py +300 -0
- inductive_mlxrl-0.1.0/mlxrl/py.typed +1 -0
- inductive_mlxrl-0.1.0/mlxrl/rollout/__init__.py +29 -0
- inductive_mlxrl-0.1.0/mlxrl/rollout/naive.py +186 -0
- inductive_mlxrl-0.1.0/mlxrl/rollout/optimized.py +644 -0
- inductive_mlxrl-0.1.0/mlxrl/train/__init__.py +5 -0
- inductive_mlxrl-0.1.0/mlxrl/train/grpo.py +304 -0
- inductive_mlxrl-0.1.0/pyproject.toml +83 -0
- inductive_mlxrl-0.1.0/tests/test_benchmark_metrics.py +150 -0
- inductive_mlxrl-0.1.0/tests/test_cli.py +111 -0
- inductive_mlxrl-0.1.0/tests/test_config.py +138 -0
- inductive_mlxrl-0.1.0/tests/test_grpo.py +324 -0
- inductive_mlxrl-0.1.0/tests/test_import_direction.py +27 -0
- inductive_mlxrl-0.1.0/tests/test_optimized_rollout.py +76 -0
- inductive_mlxrl-0.1.0/tests/test_policy_logprobs.py +232 -0
- inductive_mlxrl-0.1.0/tests/test_policy_model.py +180 -0
- inductive_mlxrl-0.1.0/tests/test_rewards.py +54 -0
- inductive_mlxrl-0.1.0/tests/test_train_grpo.py +141 -0
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
# Changelog
|
|
2
|
+
|
|
3
|
+
All notable changes to this project are documented here.
|
|
4
|
+
|
|
5
|
+
The format follows [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
|
|
6
|
+
and this project follows semantic versioning for its tagged releases.
|
|
7
|
+
|
|
8
|
+
## [0.1.0] - 2026-06-24
|
|
9
|
+
|
|
10
|
+
### Added
|
|
11
|
+
|
|
12
|
+
- Single-process MLX GRPO+QLoRA training path for Apple Silicon.
|
|
13
|
+
- Prefix-cached grouped rollout engine with compiled decode support.
|
|
14
|
+
- Adapter-disabled reference logprob path on one model object.
|
|
15
|
+
- Full-forward old-policy logprob recompute for 4-bit correctness.
|
|
16
|
+
- Algorithm protocol with GRPO, Dr. GRPO, DAPO, GSPO, and RLOO.
|
|
17
|
+
- DAPO `filter_batch` hook for dynamic zero-advantage group filtering.
|
|
18
|
+
- Per-layer MLX-LM gradient checkpointing for DeltaNet/linear-attention models.
|
|
19
|
+
- Micro-batched gradient accumulation for token-mean losses.
|
|
20
|
+
- Typed config schema and memory preflight helper calibrated to measured anchors.
|
|
21
|
+
- Phase 4 benchmark harness for `mlxrl`, `mlx-lm`, `mlx-tune`, and `mlx-lm-lora`.
|
|
22
|
+
- Linux CPU-safe CI for docs/config/import-direction checks.
|
|
23
|
+
|
|
24
|
+
### Notes
|
|
25
|
+
|
|
26
|
+
- `mlxrl` is pre-1.0; public APIs may change while the correctness gates settle.
|
|
27
|
+
- Full MLX/Metal tests require Apple Silicon or a self-hosted Mac runner.
|
|
28
|
+
- Published on PyPI as `inductive-mlxrl`; the import package and CLI command
|
|
29
|
+
remain `mlxrl`.
|
|
@@ -0,0 +1,107 @@
|
|
|
1
|
+
# Contributing
|
|
2
|
+
|
|
3
|
+
`mlxrl` is fast on-policy MLX RL for Apple Silicon. It is not a broad RL
|
|
4
|
+
framework, not preference tuning, and not multi-GPU or distributed training.
|
|
5
|
+
|
|
6
|
+
## Development Loop
|
|
7
|
+
|
|
8
|
+
Use `uv` from the repository root:
|
|
9
|
+
|
|
10
|
+
```bash
|
|
11
|
+
UV_CACHE_DIR=.uv-cache uv sync --all-groups
|
|
12
|
+
UV_CACHE_DIR=.uv-cache uv run pytest
|
|
13
|
+
UV_CACHE_DIR=.uv-cache uv run ruff check .
|
|
14
|
+
UV_CACHE_DIR=.uv-cache uv run pyright
|
|
15
|
+
```
|
|
16
|
+
|
|
17
|
+
Tests must stay green while refactoring. Correctness tests are not optional:
|
|
18
|
+
rollout equivalence, loss/gradient checks, import-direction guards, and memory
|
|
19
|
+
estimator anchor tests are part of the contract.
|
|
20
|
+
|
|
21
|
+
## Adding An Algorithm
|
|
22
|
+
|
|
23
|
+
Use `RLOOAlgorithm` in `mlxrl/algo/grpo.py` as the smallest template.
|
|
24
|
+
|
|
25
|
+
Required steps:
|
|
26
|
+
|
|
27
|
+
- implement the `Algorithm` protocol;
|
|
28
|
+
- add a hand-computed toy loss and gradient test;
|
|
29
|
+
- add a reduction or relationship test when the algorithm should collapse to an
|
|
30
|
+
existing objective under a degenerate config;
|
|
31
|
+
- keep algorithm code in `mlxrl/algo/`;
|
|
32
|
+
- do not import concrete algorithms into `rollout/`, `policy/`, or `train/`;
|
|
33
|
+
- add an end-to-end smoke test before relying on a new objective.
|
|
34
|
+
|
|
35
|
+
If an algorithm needs to drop or reshape examples before the loss, use the
|
|
36
|
+
`filter_batch` hook. Do not special-case it in the trainer.
|
|
37
|
+
|
|
38
|
+
## MLX And CI
|
|
39
|
+
|
|
40
|
+
GitHub-hosted Linux runners cannot run the MLX/Metal path. Linux CI gates the
|
|
41
|
+
CPU-safe subset: config validation, memory-estimator math, import-direction
|
|
42
|
+
checks, reward functions, and benchmark-result rendering.
|
|
43
|
+
|
|
44
|
+
Tests marked `@pytest.mark.metal` require Apple Silicon with MLX/Metal. Run the
|
|
45
|
+
full suite locally on a Mac before cutting releases. A self-hosted Mac runner
|
|
46
|
+
would close this coverage gap.
|
|
47
|
+
|
|
48
|
+
## Release Cycle
|
|
49
|
+
|
|
50
|
+
Releases are tag-driven and publish to PyPI through Trusted Publishing. Do not
|
|
51
|
+
store PyPI API tokens in GitHub secrets.
|
|
52
|
+
|
|
53
|
+
One-time PyPI setup:
|
|
54
|
+
|
|
55
|
+
- create a PyPI Trusted Publisher for project `inductive-mlxrl`;
|
|
56
|
+
- set owner to `inductiveML` and repository to `mlxrl`;
|
|
57
|
+
- set workflow name to `release.yml`;
|
|
58
|
+
- set environment name to `pypi`;
|
|
59
|
+
- require manual approval on the GitHub `pypi` environment before publishing.
|
|
60
|
+
|
|
61
|
+
For each release:
|
|
62
|
+
|
|
63
|
+
1. Update the version in `pyproject.toml`.
|
|
64
|
+
2. Add the release notes to `CHANGELOG.md`.
|
|
65
|
+
3. Run the local gates:
|
|
66
|
+
|
|
67
|
+
```bash
|
|
68
|
+
UV_CACHE_DIR=.uv-cache uv run pytest
|
|
69
|
+
UV_CACHE_DIR=.uv-cache uv run ruff check .
|
|
70
|
+
UV_CACHE_DIR=.uv-cache uv run pyright
|
|
71
|
+
UV_CACHE_DIR=.uv-cache uv build
|
|
72
|
+
UV_CACHE_DIR=.uv-cache uv run --no-project --python 3.11 --with twine twine check dist/*
|
|
73
|
+
```
|
|
74
|
+
|
|
75
|
+
4. Commit and merge the release prep.
|
|
76
|
+
5. Tag the release from `main`:
|
|
77
|
+
|
|
78
|
+
```bash
|
|
79
|
+
git tag -a vX.Y.Z -m "Release vX.Y.Z"
|
|
80
|
+
git push origin vX.Y.Z
|
|
81
|
+
```
|
|
82
|
+
|
|
83
|
+
The `Release` workflow builds the source distribution and wheel, checks package
|
|
84
|
+
metadata with Twine, stores the artifacts, publishes them to PyPI from the
|
|
85
|
+
`pypi` environment, and creates a GitHub Release with the built distributions
|
|
86
|
+
attached. Running the workflow manually builds and checks artifacts without
|
|
87
|
+
publishing because the publish and GitHub Release jobs only run for `v*.*.*`
|
|
88
|
+
tags.
|
|
89
|
+
|
|
90
|
+
The PyPI distribution name is `inductive-mlxrl`; the Python import package and
|
|
91
|
+
CLI command remain `mlxrl`.
|
|
92
|
+
|
|
93
|
+
## Scope
|
|
94
|
+
|
|
95
|
+
In scope:
|
|
96
|
+
|
|
97
|
+
- single-process Apple Silicon RL post-training;
|
|
98
|
+
- QLoRA on local MLX LLMs;
|
|
99
|
+
- critic-free on-policy algorithms;
|
|
100
|
+
- memory-conscious rollout and training paths.
|
|
101
|
+
|
|
102
|
+
Out of scope:
|
|
103
|
+
|
|
104
|
+
- PPO or other critic/value-model algorithms;
|
|
105
|
+
- DPO/ORPO and other offline preference objectives;
|
|
106
|
+
- CUDA, torch fallback, or distributed training;
|
|
107
|
+
- inference servers or second reference-model copies.
|
|
@@ -0,0 +1,72 @@
|
|
|
1
|
+
# mlxrl Design
|
|
2
|
+
|
|
3
|
+
`mlxrl` is a small, single-process RL post-training library for Apple Silicon.
|
|
4
|
+
The design goal is narrow: make on-policy, critic-free RL for local MLX LLMs
|
|
5
|
+
fast enough that rollout is the main problem and the loss stays thin.
|
|
6
|
+
|
|
7
|
+
## Critic-Free By Design
|
|
8
|
+
|
|
9
|
+
The library is intentionally limited to rollout-based policy-gradient methods:
|
|
10
|
+
GRPO, Dr. GRPO, DAPO, GSPO, and RLOO. PPO is out of scope because it needs a
|
|
11
|
+
critic/value model path, a different forward pass, and a different memory
|
|
12
|
+
profile. DPO and ORPO are also out of scope because they are offline preference
|
|
13
|
+
objectives rather than on-policy rollout algorithms.
|
|
14
|
+
|
|
15
|
+
That boundary matters. `mlxrl` keeps one policy model object in memory, attaches
|
|
16
|
+
QLoRA adapters to it, and computes reference logprobs by disabling those
|
|
17
|
+
adapters for a second pass. There is no inference server and no second reference
|
|
18
|
+
model copy.
|
|
19
|
+
|
|
20
|
+
## Algorithm Protocol
|
|
21
|
+
|
|
22
|
+
The engine does not know which algorithm it is training. Concrete algorithms
|
|
23
|
+
implement the `Algorithm` protocol:
|
|
24
|
+
|
|
25
|
+
- compute per-completion advantages;
|
|
26
|
+
- optionally filter a prepared batch;
|
|
27
|
+
- compute loss and diagnostics from policy, old-policy, and reference logprobs.
|
|
28
|
+
|
|
29
|
+
`rollout/`, `policy/`, and `train/` must not import from `algo/`. The import
|
|
30
|
+
direction test enforces this. The payoff is that rollout and logprob code stay
|
|
31
|
+
stable while algorithms change. DAPO's dynamic sampling is the proof that the
|
|
32
|
+
interface is more general than "GRPO with renamed constants": it drops
|
|
33
|
+
zero-advantage groups through `filter_batch` without special trainer branches.
|
|
34
|
+
|
|
35
|
+
## Correctness Gates
|
|
36
|
+
|
|
37
|
+
Speed changes are allowed only behind equivalence gates. The original Phase 1
|
|
38
|
+
path is simple and readable; optimized rollout variants must match it
|
|
39
|
+
token-for-token at fixed seed and match loss within tolerance. The protocol
|
|
40
|
+
refactor was checked against the pre-refactor GRPO and Dr. GRPO losses with
|
|
41
|
+
zero difference in both loss and adapter gradients.
|
|
42
|
+
|
|
43
|
+
Old-policy changes use a stronger gate than naive equality. At 4-bit, the
|
|
44
|
+
rollout-time cached decode realization and the full-forward realization are not
|
|
45
|
+
numerically identical. Stored rollout logprobs are attractive because they are
|
|
46
|
+
the behavior policy, but the actual gate is importance-ratio stability: on a
|
|
47
|
+
freshly sampled batch, `exp(logpi_current - logpi_old)` must stay near 1.0 and
|
|
48
|
+
short training runs must not show KL or gradient spikes. Until that gate justifies
|
|
49
|
+
a semantics change, `mlxrl` recomputes old-policy logprobs with a full forward.
|
|
50
|
+
|
|
51
|
+
## The 4-Bit KV Boundary
|
|
52
|
+
|
|
53
|
+
Quantized KV cache decode is not a drop-in numerical replacement for a full
|
|
54
|
+
forward over prompt+completion. That is fine for sampling, but it is not fine
|
|
55
|
+
for gradient-bearing or importance-weighting quantities. Anything used in the
|
|
56
|
+
importance-ratio denominator, KL, or adapter gradient must come from the full-forward path
|
|
57
|
+
unless a dedicated stability gate proves otherwise.
|
|
58
|
+
|
|
59
|
+
## Memory As A First-Class Constraint
|
|
60
|
+
|
|
61
|
+
Apple Silicon's unified memory is the deployment target, not an afterthought.
|
|
62
|
+
The code supports per-layer MLX-LM gradient checkpointing because DeltaNet and
|
|
63
|
+
linear-attention models otherwise keep O(sequence length) recurrent state live
|
|
64
|
+
through backward. The 9B fitting anchor is Qwen3.5-9B 4-bit at G=2, total
|
|
65
|
+
sequence length 609, with per-layer checkpointing at about 25.9 GB peak. G=4 at
|
|
66
|
+
the same shape is about 45.9 GB and tight on a 48 GB machine.
|
|
67
|
+
|
|
68
|
+
The memory estimator is deliberately conservative. It interpolates near measured
|
|
69
|
+
anchors and labels long-sequence uncheckpointed hybrid configs as OOM-risk
|
|
70
|
+
estimates, not measurements. Its purpose is to nudge users toward the knob that
|
|
71
|
+
will most likely make a run fit: enable checkpointing, reduce G, then reduce
|
|
72
|
+
completion length.
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2026 mlxrl contributors
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
|
@@ -0,0 +1,388 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: inductive-mlxrl
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: Small single-process RL post-training for LLMs on Apple Silicon with MLX.
|
|
5
|
+
Project-URL: Homepage, https://github.com/inductiveML/mlxrl
|
|
6
|
+
Project-URL: Repository, https://github.com/inductiveML/mlxrl
|
|
7
|
+
Project-URL: Issues, https://github.com/inductiveML/mlxrl/issues
|
|
8
|
+
Project-URL: Changelog, https://github.com/inductiveML/mlxrl/blob/main/CHANGELOG.md
|
|
9
|
+
License-Expression: MIT
|
|
10
|
+
License-File: LICENSE
|
|
11
|
+
License-File: THIRD_PARTY_LICENSES.md
|
|
12
|
+
Keywords: apple-silicon,grpo,mlx,qlora,rl
|
|
13
|
+
Classifier: Development Status :: 3 - Alpha
|
|
14
|
+
Classifier: Intended Audience :: Developers
|
|
15
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
16
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
17
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
18
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
19
|
+
Requires-Python: >=3.11
|
|
20
|
+
Requires-Dist: mlx-lm>=0.31.0
|
|
21
|
+
Requires-Dist: mlx>=0.31.0
|
|
22
|
+
Description-Content-Type: text/markdown
|
|
23
|
+
|
|
24
|
+
# mlxrl
|
|
25
|
+
|
|
26
|
+
Fast on-policy MLX RL for Apple Silicon; not a general RL framework, not
|
|
27
|
+
preference tuning, and not distributed training.
|
|
28
|
+
|
|
29
|
+
`mlxrl` is a small, single-process RL post-training library for LLMs on Apple
|
|
30
|
+
Silicon. It is built around one idea: GRPO on MLX should be a fast batched
|
|
31
|
+
rollout path with a thin loss and optimizer step on top, not a framework.
|
|
32
|
+
|
|
33
|
+
The current implementation targets QLoRA GRPO on local 4-bit MLX models. It
|
|
34
|
+
reuses `mlx-lm` model loading, LoRA layers, KV caches, and sampling utilities,
|
|
35
|
+
and keeps generation and training in one Python process with one model object.
|
|
36
|
+
|
|
37
|
+
`mlxrl` is pre-1.0. The correctness gates are stable, but import APIs and config
|
|
38
|
+
fields may change before a 1.0 release.
|
|
39
|
+
|
|
40
|
+
## Quickstart
|
|
41
|
+
|
|
42
|
+
```bash
|
|
43
|
+
git clone https://github.com/inductiveML/mlxrl.git
|
|
44
|
+
cd mlxrl
|
|
45
|
+
UV_CACHE_DIR=.uv-cache uv sync --all-groups
|
|
46
|
+
UV_CACHE_DIR=.uv-cache uv run mlxrl train \
|
|
47
|
+
--config examples/qwen3_0_6b_grpo.toml \
|
|
48
|
+
--available-memory-gb 48
|
|
49
|
+
```
|
|
50
|
+
|
|
51
|
+
For the measured 9B-on-48GB shape, use the checkpointed G=2 config:
|
|
52
|
+
|
|
53
|
+
```bash
|
|
54
|
+
UV_CACHE_DIR=.uv-cache uv run mlxrl train \
|
|
55
|
+
--config examples/qwen35_9b_g2_checkpoint.toml \
|
|
56
|
+
--available-memory-gb 48 \
|
|
57
|
+
--dry-run
|
|
58
|
+
```
|
|
59
|
+
|
|
60
|
+
## What Works
|
|
61
|
+
|
|
62
|
+
- Batched group rollouts with MLX-LM KV caches and sampling.
|
|
63
|
+
- Full-forward old-policy logprob recompute for training-time `pi_old`.
|
|
64
|
+
- Adapter-disabled reference policy on the same model object.
|
|
65
|
+
- GRPO, Dr. GRPO, DAPO, and GSPO loss variants.
|
|
66
|
+
- RLOO (REINFORCE Leave-One-Out) as a critic-free rollout objective.
|
|
67
|
+
- QLoRA injection on dense and heterogeneous/hybrid layer stacks.
|
|
68
|
+
- Qwen3.5-style hybrid support via MLX-LM auto LoRA targeting, including
|
|
69
|
+
DeltaNet `linear_attn.in_proj_*` and dense attention `q/k/v/o_proj`.
|
|
70
|
+
- Per-layer gradient checkpointing through `mlx_lm.tuner.trainer.grad_checkpoint`
|
|
71
|
+
for linear-attention/DeltaNet backward memory.
|
|
72
|
+
- Micro-batched gradient accumulation for token-mean policy losses.
|
|
73
|
+
- `beta == 0` reference-forward skip.
|
|
74
|
+
- Phase 4 benchmark harness for `mlxrl`, `mlx-tune`, `mlx-lm-lora`, and `mlx-lm`.
|
|
75
|
+
|
|
76
|
+
## Install
|
|
77
|
+
|
|
78
|
+
Install the PyPI distribution:
|
|
79
|
+
|
|
80
|
+
```bash
|
|
81
|
+
pip install inductive-mlxrl
|
|
82
|
+
```
|
|
83
|
+
|
|
84
|
+
The Python import package and CLI command are still `mlxrl`:
|
|
85
|
+
|
|
86
|
+
```bash
|
|
87
|
+
mlxrl --help
|
|
88
|
+
```
|
|
89
|
+
|
|
90
|
+
Source installs are also supported:
|
|
91
|
+
|
|
92
|
+
```bash
|
|
93
|
+
UV_CACHE_DIR=.uv-cache uv sync --all-groups
|
|
94
|
+
```
|
|
95
|
+
|
|
96
|
+
Run commands through the local environment:
|
|
97
|
+
|
|
98
|
+
```bash
|
|
99
|
+
UV_CACHE_DIR=.uv-cache uv run mlxrl --help
|
|
100
|
+
```
|
|
101
|
+
|
|
102
|
+
Python 3.11+ is required. Runtime dependencies are intentionally small:
|
|
103
|
+
`mlx` and `mlx-lm`. Development dependencies include `pytest`, `ruff`,
|
|
104
|
+
`pyright`, `mlx-tune`, and `mlx-lm-lora` for comparison benchmarks.
|
|
105
|
+
The PyPI distribution name is `inductive-mlxrl`; the import package and console
|
|
106
|
+
script remain `mlxrl`.
|
|
107
|
+
|
|
108
|
+
## Quick Smoke Tests
|
|
109
|
+
|
|
110
|
+
Dense Qwen3 0.6B:
|
|
111
|
+
|
|
112
|
+
```bash
|
|
113
|
+
UV_CACHE_DIR=.uv-cache uv run mlxrl phase0-smoke \
|
|
114
|
+
--model mlx-community/Qwen3-0.6B-4bit \
|
|
115
|
+
--prompt "What is 2+2?"
|
|
116
|
+
```
|
|
117
|
+
|
|
118
|
+
Hybrid Qwen3.5 9B with rank-16 LoRA:
|
|
119
|
+
|
|
120
|
+
```bash
|
|
121
|
+
UV_CACHE_DIR=.uv-cache uv run mlxrl phase0-smoke \
|
|
122
|
+
--model mlx-community/Qwen3.5-9B-MLX-4bit \
|
|
123
|
+
--rank 16 \
|
|
124
|
+
--scale 2.0 \
|
|
125
|
+
--prompt "What is 2+2?"
|
|
126
|
+
```
|
|
127
|
+
|
|
128
|
+
The smoke gate prints the model id, layer count, LoRA target keys, per-layer
|
|
129
|
+
LoRA module counts, total/trainable parameter counts, and logits shape. It
|
|
130
|
+
fails if any trainable leaf is not `lora_a` or `lora_b`.
|
|
131
|
+
|
|
132
|
+
## Training Commands
|
|
133
|
+
|
|
134
|
+
Toy hand-computed GRPO math gate:
|
|
135
|
+
|
|
136
|
+
```bash
|
|
137
|
+
UV_CACHE_DIR=.uv-cache uv run mlxrl phase1-toy-gate
|
|
138
|
+
```
|
|
139
|
+
|
|
140
|
+
Small built-in GSM8K-style run:
|
|
141
|
+
|
|
142
|
+
```bash
|
|
143
|
+
UV_CACHE_DIR=.uv-cache uv run mlxrl phase1-gsm8k \
|
|
144
|
+
--model mlx-community/Qwen3-0.6B-4bit \
|
|
145
|
+
--steps 20 \
|
|
146
|
+
--group-size 4 \
|
|
147
|
+
--max-tokens 64
|
|
148
|
+
```
|
|
149
|
+
|
|
150
|
+
Config-driven run:
|
|
151
|
+
|
|
152
|
+
```bash
|
|
153
|
+
UV_CACHE_DIR=.uv-cache uv run mlxrl train \
|
|
154
|
+
--config examples/qwen3_0_6b_grpo.toml \
|
|
155
|
+
--available-memory-gb 48
|
|
156
|
+
```
|
|
157
|
+
|
|
158
|
+
The config schema validates model id, quant bits, group size, completion/prompt
|
|
159
|
+
lengths, checkpointing granularity, `iogpu.wired_limit_mb`, optimizer settings,
|
|
160
|
+
algorithm hyperparameters, KL beta, and seed before a model is loaded. CLI
|
|
161
|
+
overrides such as `--steps`, `--group-size`, `--max-tokens`, `--algorithm`,
|
|
162
|
+
`--beta`, and `--seed` apply on top of the file.
|
|
163
|
+
|
|
164
|
+
For DeltaNet / linear-attention models, enable per-layer checkpointing:
|
|
165
|
+
|
|
166
|
+
```bash
|
|
167
|
+
UV_CACHE_DIR=.uv-cache uv run mlxrl phase1-gsm8k \
|
|
168
|
+
--model mlx-community/Qwen3.5-9B-MLX-4bit \
|
|
169
|
+
--rank 16 \
|
|
170
|
+
--scale 2.0 \
|
|
171
|
+
--checkpoint-completion-forward \
|
|
172
|
+
--steps 1 \
|
|
173
|
+
--group-size 2 \
|
|
174
|
+
--max-tokens 256
|
|
175
|
+
```
|
|
176
|
+
|
|
177
|
+
Despite the historical CLI name, `--checkpoint-completion-forward` now enables
|
|
178
|
+
per-transformer-block checkpointing at model setup. The old whole-model
|
|
179
|
+
`mx.checkpoint(...)` wrapper was removed because it does not cap DeltaNet's
|
|
180
|
+
per-layer scan memory.
|
|
181
|
+
|
|
182
|
+
Phase 2 rollout equivalence check:
|
|
183
|
+
|
|
184
|
+
```bash
|
|
185
|
+
UV_CACHE_DIR=.uv-cache uv run mlxrl phase2-equivalence \
|
|
186
|
+
--model mlx-community/Qwen3-0.6B-4bit \
|
|
187
|
+
--group-size 4 \
|
|
188
|
+
--max-tokens 32 \
|
|
189
|
+
--compile-decode-step \
|
|
190
|
+
--batch-groups
|
|
191
|
+
```
|
|
192
|
+
|
|
193
|
+
## Import API
|
|
194
|
+
|
|
195
|
+
Minimal model setup:
|
|
196
|
+
|
|
197
|
+
```python
|
|
198
|
+
from mlxrl.policy import LoRAConfig, load_policy_with_lora
|
|
199
|
+
|
|
200
|
+
model, tokenizer, report = load_policy_with_lora(
|
|
201
|
+
model_id="mlx-community/Qwen3.5-9B-MLX-4bit",
|
|
202
|
+
config=LoRAConfig(
|
|
203
|
+
rank=16,
|
|
204
|
+
scale=2.0,
|
|
205
|
+
dropout=0.0,
|
|
206
|
+
grad_checkpoint=True,
|
|
207
|
+
),
|
|
208
|
+
)
|
|
209
|
+
```
|
|
210
|
+
|
|
211
|
+
One optimizer step:
|
|
212
|
+
|
|
213
|
+
```python
|
|
214
|
+
import mlx.optimizers as optim
|
|
215
|
+
|
|
216
|
+
from mlxrl.algo import GRPOAlgorithm
|
|
217
|
+
from mlxrl.train import batch_from_rollouts, optimizer_step
|
|
218
|
+
|
|
219
|
+
optimizer = optim.Adam(learning_rate=1e-5)
|
|
220
|
+
algorithm = GRPOAlgorithm()
|
|
221
|
+
batch = batch_from_rollouts(
|
|
222
|
+
model=model,
|
|
223
|
+
completions=completions,
|
|
224
|
+
rewards=rewards,
|
|
225
|
+
group_size=4,
|
|
226
|
+
pad_token_id=pad_token_id,
|
|
227
|
+
algorithm=algorithm,
|
|
228
|
+
compute_reference=beta != 0.0,
|
|
229
|
+
)
|
|
230
|
+
metrics = optimizer_step(
|
|
231
|
+
model=model,
|
|
232
|
+
optimizer=optimizer,
|
|
233
|
+
batch=batch,
|
|
234
|
+
beta=beta,
|
|
235
|
+
pad_token_id=pad_token_id,
|
|
236
|
+
algorithm=algorithm,
|
|
237
|
+
use_checkpoint=True,
|
|
238
|
+
micro_batch_size=2,
|
|
239
|
+
)
|
|
240
|
+
```
|
|
241
|
+
|
|
242
|
+
`micro_batch_size=0` keeps the original whole-batch path. Micro-batching is
|
|
243
|
+
currently exact for token-mean policy losses: base GRPO, DAPO, GSPO token mode,
|
|
244
|
+
RLOO, and Dr. GRPO with `loss_reduction="token_mean"`. Sequence-reduced losses
|
|
245
|
+
should keep `micro_batch_size=0`.
|
|
246
|
+
|
|
247
|
+
## Policy Semantics
|
|
248
|
+
|
|
249
|
+
- The base model is frozen before LoRA injection.
|
|
250
|
+
- Only LoRA adapter leaves are trainable.
|
|
251
|
+
- Reference logprobs are computed by temporarily disabling adapters on the same
|
|
252
|
+
model object; there is no second reference model in memory.
|
|
253
|
+
- Old-policy logprobs are recomputed with a full forward for the training batch.
|
|
254
|
+
Rollout-time logprobs are captured for inspection, but 4-bit sequential decode
|
|
255
|
+
and full-forward prefill are not numerically identical on hybrid/quantized
|
|
256
|
+
models, so recompute remains the default training semantics.
|
|
257
|
+
- When `beta == 0`, the reference forward is skipped and the policy logprobs are
|
|
258
|
+
used as a zero-KL placeholder.
|
|
259
|
+
- PPO, DPO, and ORPO are intentionally out of scope. PPO needs a separate critic
|
|
260
|
+
and value forward; DPO/ORPO are offline preference objectives with no rollout
|
|
261
|
+
phase. `mlxrl` is critic-free, on-policy, and rollout-based by design.
|
|
262
|
+
|
|
263
|
+
## Algorithms
|
|
264
|
+
|
|
265
|
+
Concrete algorithms implement the small `Algorithm` protocol: compute
|
|
266
|
+
advantages, optionally filter a prepared batch, then compute a loss from policy,
|
|
267
|
+
old-policy, and reference logprobs. `rollout/`, `policy/`, and `train/` do not
|
|
268
|
+
import concrete algorithm implementations.
|
|
269
|
+
|
|
270
|
+
| algorithm | defining behavior |
|
|
271
|
+
| --- | --- |
|
|
272
|
+
| GRPO | group-normalized rewards, token-level importance ratio |
|
|
273
|
+
| Dr. GRPO | centered or normalized rewards with decoupled length reduction |
|
|
274
|
+
| DAPO | asymmetric low/high clipping plus optional dynamic zero-advantage group filtering |
|
|
275
|
+
| GSPO | sequence-level, length-normalized importance ratio and clipping |
|
|
276
|
+
| RLOO | leave-one-out group baseline, no critic, no std-normalized advantage |
|
|
277
|
+
|
|
278
|
+
## Memory Preflight
|
|
279
|
+
|
|
280
|
+
`mlxrl train` can estimate memory before loading the model:
|
|
281
|
+
|
|
282
|
+
```bash
|
|
283
|
+
UV_CACHE_DIR=.uv-cache uv run mlxrl train \
|
|
284
|
+
--config examples/qwen3_0_6b_grpo.toml \
|
|
285
|
+
--available-memory-gb 48 \
|
|
286
|
+
--dry-run
|
|
287
|
+
```
|
|
288
|
+
|
|
289
|
+
The estimator is calibrated to measured anchors: `6.245 GB` for
|
|
290
|
+
Qwen3-0.6B/G4/prompt≈19/T256, `25.9 GB` for
|
|
291
|
+
Qwen3.5-9B/G2/seq609/per-layer-checkpointed, `45.9 GB` for
|
|
292
|
+
Qwen3.5-9B/G4/seq609/per-layer-checkpointed, and `36 GB` for
|
|
293
|
+
Qwen3.5-9B/G2/seq128/no-checkpoint. For hybrid 9B no-checkpoint long-sequence
|
|
294
|
+
configs, it reports an OOM-risk lower bound rather than a fake precise peak.
|
|
295
|
+
For an obviously too-large Qwen3.5-9B/G8/prompt97/T512/no-checkpoint config on
|
|
296
|
+
48 GB, it flags the run and suggests the measured-boundary fallback around
|
|
297
|
+
G4/T512/checkpointed.
|
|
298
|
+
|
|
299
|
+
## Benchmarks
|
|
300
|
+
|
|
301
|
+
Local M4 Max Phase 4 snapshot:
|
|
302
|
+
|
|
303
|
+
- `454` rollout tok/s on Qwen3-0.6B GRPO with G=4 and 256-token completions.
|
|
304
|
+
- `0.283` end-to-end it/s with full `mlxrl` training semantics.
|
|
305
|
+
- `3.2x` faster rollout and `2.2x` higher end-to-end it/s than `mlx-tune`
|
|
306
|
+
v0.5.1 on the same run shape.
|
|
307
|
+
- `1.3x` faster rollout than sequential `mlx-lm` generation at G=4.
|
|
308
|
+
|
|
309
|
+
These are the two-pass means from
|
|
310
|
+
`benchmarks/results/gate5_full_reconciled.md`, run with MLX 0.31.2,
|
|
311
|
+
MLX-LM 0.31.3, `mlx-community/Qwen3-0.6B-4bit`, 100 measured steps with
|
|
312
|
+
5 warmup steps discarded:
|
|
313
|
+
|
|
314
|
+
| target | comparison | rollout tok/s | grad s/step | samples/s | it/s | peak GB |
|
|
315
|
+
| --- | --- | ---: | ---: | ---: | ---: | ---: |
|
|
316
|
+
| `mlxrl` | apples-to-apples GRPO | 454.1 | 1.282 | 1.133 | 0.283 | 6.25 |
|
|
317
|
+
| `mlx-lm` | generation-only, G=1 | 347.0 | - | 1.355 | - | 0.52 |
|
|
318
|
+
| `mlx-lm-g4` | generation-only, sequential G=4 | 349.7 | - | 1.366 | - | 0.52 |
|
|
319
|
+
| `mlx-tune` | package-speed reference | 142.2 | 0.502 | 0.519 | 0.130 | 6.16 |
|
|
320
|
+
| `mlx-lm-lora` | package-speed reference | 557.9 | 0.592 | 1.648 | 0.412 | 5.32 |
|
|
321
|
+
|
|
322
|
+
`mlx-lm-lora` reports higher raw package-speed throughput in this snapshot, but
|
|
323
|
+
its benchmarked path is not the same training problem as `mlxrl`'s live
|
|
324
|
+
old-policy/reference semantics and completion-loss masking. That is the honest
|
|
325
|
+
case where `mlxrl` is not faster; the apples-to-apples comparison label is
|
|
326
|
+
reserved for `mlxrl`'s own semantic path. On the 9B Noether real workload, the
|
|
327
|
+
checkpointed MLX path measured about 6x faster than the previous torch-MPS path;
|
|
328
|
+
that workload is separate from the public Phase 4 package-speed harness.
|
|
329
|
+
|
|
330
|
+
Run the Phase 4 harness:
|
|
331
|
+
|
|
332
|
+
```bash
|
|
333
|
+
UV_CACHE_DIR=.uv-cache uv run python benchmarks/run_phase4.py run \
|
|
334
|
+
--targets mlxrl,mlx-lm,mlx-tune,mlx-lm-lora \
|
|
335
|
+
--model mlx-community/Qwen3-0.6B-4bit \
|
|
336
|
+
--steps 100 \
|
|
337
|
+
--warmup-steps 5 \
|
|
338
|
+
--group-size 4 \
|
|
339
|
+
--max-tokens 256 \
|
|
340
|
+
--passes 2 \
|
|
341
|
+
--output benchmarks/results/phase4.jsonl \
|
|
342
|
+
--summary benchmarks/results/phase4.md \
|
|
343
|
+
--allow-missing-baselines
|
|
344
|
+
```
|
|
345
|
+
|
|
346
|
+
The harness reports synchronized rollout tok/s, gradient seconds per step,
|
|
347
|
+
samples/s, it/s, and peak MLX memory. `mlx-lm` targets are generation-only;
|
|
348
|
+
external package targets are useful speed references but may not match `mlxrl`
|
|
349
|
+
training semantics.
|
|
350
|
+
|
|
351
|
+
## Development
|
|
352
|
+
|
|
353
|
+
See [CONTRIBUTING.md](CONTRIBUTING.md) and [DESIGN.md](DESIGN.md) before adding
|
|
354
|
+
algorithms or changing rollout/logprob semantics.
|
|
355
|
+
|
|
356
|
+
Run the quality gates:
|
|
357
|
+
|
|
358
|
+
```bash
|
|
359
|
+
UV_CACHE_DIR=.uv-cache uv run pytest
|
|
360
|
+
UV_CACHE_DIR=.uv-cache uv run ruff check .
|
|
361
|
+
UV_CACHE_DIR=.uv-cache uv run pyright
|
|
362
|
+
```
|
|
363
|
+
|
|
364
|
+
MLX lazy evaluation matters. Any `mx.eval(...)` or `mx.synchronize()` in this
|
|
365
|
+
repo should mark a real boundary: sampled token append/EOS checks, logprob
|
|
366
|
+
freezing before adapter mutation, per-micro-batch graph release, optimizer
|
|
367
|
+
updates, or benchmark timing boundaries.
|
|
368
|
+
|
|
369
|
+
## Layout
|
|
370
|
+
|
|
371
|
+
```text
|
|
372
|
+
mlxrl/
|
|
373
|
+
  rollout/ # batched group generation
|
|
374
|
+
  policy/ # model loading, LoRA setup, logprob passes
|
|
375
|
+
  algo/ # GRPO-family advantages and losses
|
|
376
|
+
  train/ # value_and_grad and optimizer integration
|
|
377
|
+
  data/ # toy GSM8K data and rewards
|
|
378
|
+
  cli.py
|
|
379
|
+
tests/
|
|
380
|
+
benchmarks/
|
|
381
|
+
```
|
|
382
|
+
|
|
383
|
+
## Non-Goals
|
|
384
|
+
|
|
385
|
+
- No inference server or second model copy.
|
|
386
|
+
- No CUDA or torch fallback.
|
|
387
|
+
- No distributed training.
|
|
388
|
+
- No broad RL framework abstractions beyond the small algorithm interface.
|