inductive-mlxrl 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (71)
  1. inductive_mlxrl-0.1.0/.gitignore +14 -0
  2. inductive_mlxrl-0.1.0/CHANGELOG.md +29 -0
  3. inductive_mlxrl-0.1.0/CONTRIBUTING.md +107 -0
  4. inductive_mlxrl-0.1.0/DESIGN.md +72 -0
  5. inductive_mlxrl-0.1.0/LICENSE +21 -0
  6. inductive_mlxrl-0.1.0/PKG-INFO +388 -0
  7. inductive_mlxrl-0.1.0/README.md +365 -0
  8. inductive_mlxrl-0.1.0/THIRD_PARTY_LICENSES.md +46 -0
  9. inductive_mlxrl-0.1.0/benchmarks/__init__.py +1 -0
  10. inductive_mlxrl-0.1.0/benchmarks/audit_phase4_config.py +385 -0
  11. inductive_mlxrl-0.1.0/benchmarks/external_baseline_worker.py +553 -0
  12. inductive_mlxrl-0.1.0/benchmarks/results/gate1_mlxrl_after.jsonl +1 -0
  13. inductive_mlxrl-0.1.0/benchmarks/results/gate1_mlxrl_after.md +9 -0
  14. inductive_mlxrl-0.1.0/benchmarks/results/gate1_mlxrl_after_final.jsonl +1 -0
  15. inductive_mlxrl-0.1.0/benchmarks/results/gate1_mlxrl_after_final.md +9 -0
  16. inductive_mlxrl-0.1.0/benchmarks/results/gate1_mlxrl_after_noprefixsync.jsonl +1 -0
  17. inductive_mlxrl-0.1.0/benchmarks/results/gate1_mlxrl_after_noprefixsync.md +9 -0
  18. inductive_mlxrl-0.1.0/benchmarks/results/gate1_mlxrl_before.jsonl +1 -0
  19. inductive_mlxrl-0.1.0/benchmarks/results/gate1_mlxrl_before.md +9 -0
  20. inductive_mlxrl-0.1.0/benchmarks/results/gate2_unified_smoke.jsonl +5 -0
  21. inductive_mlxrl-0.1.0/benchmarks/results/gate2_unified_smoke.md +17 -0
  22. inductive_mlxrl-0.1.0/benchmarks/results/gate3_mlx_tune_sync_256.jsonl +1 -0
  23. inductive_mlxrl-0.1.0/benchmarks/results/gate3_mlx_tune_sync_256.md +9 -0
  24. inductive_mlxrl-0.1.0/benchmarks/results/gate3_sync_blocked.jsonl +5 -0
  25. inductive_mlxrl-0.1.0/benchmarks/results/gate3_sync_blocked.md +11 -0
  26. inductive_mlxrl-0.1.0/benchmarks/results/gate3_sync_smoke.jsonl +5 -0
  27. inductive_mlxrl-0.1.0/benchmarks/results/gate3_sync_smoke.md +17 -0
  28. inductive_mlxrl-0.1.0/benchmarks/results/gate4_config_audit.json +102 -0
  29. inductive_mlxrl-0.1.0/benchmarks/results/gate4_config_audit.md +9 -0
  30. inductive_mlxrl-0.1.0/benchmarks/results/gate4_sampler_smoke.jsonl +5 -0
  31. inductive_mlxrl-0.1.0/benchmarks/results/gate4_sampler_smoke.md +17 -0
  32. inductive_mlxrl-0.1.0/benchmarks/results/gate5_full_reconciled.jsonl +10 -0
  33. inductive_mlxrl-0.1.0/benchmarks/results/gate5_full_reconciled.md +28 -0
  34. inductive_mlxrl-0.1.0/benchmarks/results/phase4_external_smoke.jsonl +2 -0
  35. inductive_mlxrl-0.1.0/benchmarks/results/phase4_external_smoke.md +11 -0
  36. inductive_mlxrl-0.1.0/benchmarks/results/phase4_full_all_targets.jsonl +8 -0
  37. inductive_mlxrl-0.1.0/benchmarks/results/phase4_full_all_targets.md +19 -0
  38. inductive_mlxrl-0.1.0/benchmarks/results/phase4_local_mlxrl_vs_mlx_lm.jsonl +4 -0
  39. inductive_mlxrl-0.1.0/benchmarks/results/phase4_local_mlxrl_vs_mlx_lm.md +13 -0
  40. inductive_mlxrl-0.1.0/benchmarks/results/phase4_mlx_lm_g4_smoke.jsonl +2 -0
  41. inductive_mlxrl-0.1.0/benchmarks/results/phase4_mlx_lm_g4_smoke.md +11 -0
  42. inductive_mlxrl-0.1.0/benchmarks/run_phase4.py +945 -0
  43. inductive_mlxrl-0.1.0/mlxrl/__init__.py +6 -0
  44. inductive_mlxrl-0.1.0/mlxrl/algo/__init__.py +43 -0
  45. inductive_mlxrl-0.1.0/mlxrl/algo/grpo.py +568 -0
  46. inductive_mlxrl-0.1.0/mlxrl/algorithm.py +63 -0
  47. inductive_mlxrl-0.1.0/mlxrl/cli.py +793 -0
  48. inductive_mlxrl-0.1.0/mlxrl/config.py +397 -0
  49. inductive_mlxrl-0.1.0/mlxrl/data/__init__.py +31 -0
  50. inductive_mlxrl-0.1.0/mlxrl/data/gsm8k.py +60 -0
  51. inductive_mlxrl-0.1.0/mlxrl/data/rewards.py +101 -0
  52. inductive_mlxrl-0.1.0/mlxrl/policy/__init__.py +41 -0
  53. inductive_mlxrl-0.1.0/mlxrl/policy/logprobs.py +299 -0
  54. inductive_mlxrl-0.1.0/mlxrl/policy/model.py +300 -0
  55. inductive_mlxrl-0.1.0/mlxrl/py.typed +1 -0
  56. inductive_mlxrl-0.1.0/mlxrl/rollout/__init__.py +29 -0
  57. inductive_mlxrl-0.1.0/mlxrl/rollout/naive.py +186 -0
  58. inductive_mlxrl-0.1.0/mlxrl/rollout/optimized.py +644 -0
  59. inductive_mlxrl-0.1.0/mlxrl/train/__init__.py +5 -0
  60. inductive_mlxrl-0.1.0/mlxrl/train/grpo.py +304 -0
  61. inductive_mlxrl-0.1.0/pyproject.toml +83 -0
  62. inductive_mlxrl-0.1.0/tests/test_benchmark_metrics.py +150 -0
  63. inductive_mlxrl-0.1.0/tests/test_cli.py +111 -0
  64. inductive_mlxrl-0.1.0/tests/test_config.py +138 -0
  65. inductive_mlxrl-0.1.0/tests/test_grpo.py +324 -0
  66. inductive_mlxrl-0.1.0/tests/test_import_direction.py +27 -0
  67. inductive_mlxrl-0.1.0/tests/test_optimized_rollout.py +76 -0
  68. inductive_mlxrl-0.1.0/tests/test_policy_logprobs.py +232 -0
  69. inductive_mlxrl-0.1.0/tests/test_policy_model.py +180 -0
  70. inductive_mlxrl-0.1.0/tests/test_rewards.py +54 -0
  71. inductive_mlxrl-0.1.0/tests/test_train_grpo.py +141 -0
@@ -0,0 +1,14 @@
1
+ .venv/
2
+ .uv-cache/
3
+ .uv-python/
4
+ .pytest_cache/
5
+ .ruff_cache/
6
+ __pycache__/
7
+ *.py[cod]
8
+ dist/
9
+ build/
10
+ *.egg-info/
11
+ .pyright/
12
+ reference_outputs/
13
+ runs/
14
+ grpo_outputs/
@@ -0,0 +1,29 @@
1
+ # Changelog
2
+
3
+ All notable changes to this project are documented here.
4
+
5
+ The format follows [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
6
+ and this project uses semantic versioning once releases are cut.
7
+
8
+ ## [0.1.0] - 2026-06-24
9
+
10
+ ### Added
11
+
12
+ - Single-process MLX GRPO+QLoRA training path for Apple Silicon.
13
+ - Prefix-cached grouped rollout engine with compiled decode support.
14
+ - Adapter-disabled reference logprob path on one model object.
15
+ - Full-forward old-policy logprob recompute for 4-bit correctness.
16
+ - Algorithm protocol with GRPO, Dr. GRPO, DAPO, GSPO, and RLOO.
17
+ - DAPO `filter_batch` hook for dynamic zero-advantage group filtering.
18
+ - Per-layer MLX-LM gradient checkpointing for DeltaNet/linear-attention models.
19
+ - Micro-batched gradient accumulation for token-mean losses.
20
+ - Typed config schema and memory preflight helper calibrated to measured anchors.
21
+ - Phase 4 benchmark harness for `mlxrl`, `mlx-lm`, `mlx-tune`, and `mlx-lm-lora`.
22
+ - Linux CPU-safe CI for docs/config/import-direction checks.
23
+
24
+ ### Notes
25
+
26
+ - `mlxrl` is pre-1.0; public APIs may change while the correctness gates settle.
27
+ - Full MLX/Metal tests require Apple Silicon or a self-hosted Mac runner.
28
+ - Published on PyPI as `inductive-mlxrl`; the import package and CLI command
29
+ remain `mlxrl`.
@@ -0,0 +1,107 @@
1
+ # Contributing
2
+
3
+ `mlxrl` is fast on-policy MLX RL for Apple Silicon. It is not a broad RL
4
+ framework, not preference tuning, and not multi-GPU or distributed training.
5
+
6
+ ## Development Loop
7
+
8
+ Use `uv` from the repository root:
9
+
10
+ ```bash
11
+ UV_CACHE_DIR=.uv-cache uv sync --all-groups
12
+ UV_CACHE_DIR=.uv-cache uv run pytest
13
+ UV_CACHE_DIR=.uv-cache uv run ruff check .
14
+ UV_CACHE_DIR=.uv-cache uv run pyright
15
+ ```
16
+
17
+ Tests must stay green while refactoring. Correctness tests are not optional:
18
+ rollout equivalence, loss/gradient checks, import-direction guards, and memory
19
+ estimator anchor tests are part of the contract.
20
+
21
+ ## Adding An Algorithm
22
+
23
+ Use `RLOOAlgorithm` in `mlxrl/algo/grpo.py` as the smallest template.
24
+
25
+ Required steps:
26
+
27
+ - implement the `Algorithm` protocol;
28
+ - add a hand-computed toy loss and gradient test;
29
+ - add a reduction or relationship test when the algorithm should collapse to an
30
+ existing objective under a degenerate config;
31
+ - keep algorithm code in `mlxrl/algo/`;
32
+ - do not import concrete algorithms from `rollout/`, `policy/`, or `train/`;
33
+ - add an end-to-end smoke test before relying on a new objective.
34
+
35
+ If an algorithm needs to drop or reshape examples before the loss, use the
36
+ `filter_batch` hook. Do not special-case it in the trainer.
37
+
38
+ ## MLX And CI
39
+
40
+ GitHub-hosted Linux runners cannot run the MLX/Metal path. Linux CI gates the
41
+ CPU-safe subset: config validation, memory-estimator math, import-direction
42
+ checks, reward functions, and benchmark-result rendering.
43
+
44
+ Tests marked `@pytest.mark.metal` require Apple Silicon with MLX/Metal. Run the
45
+ full suite locally on a Mac before cutting releases. A self-hosted Mac runner
46
+ would close this coverage gap.
47
+
48
+ ## Release Cycle
49
+
50
+ Releases are tag-driven and publish to PyPI through Trusted Publishing. Do not
51
+ store PyPI API tokens in GitHub secrets.
52
+
53
+ One-time PyPI setup:
54
+
55
+ - create a PyPI Trusted Publisher for project `inductive-mlxrl`;
56
+ - set owner to `inductiveML` and repository to `mlxrl`;
57
+ - set workflow name to `release.yml`;
58
+ - set environment name to `pypi`;
59
+ - require manual approval on the GitHub `pypi` environment before publishing.
60
+
61
+ For each release:
62
+
63
+ 1. Update the version in `pyproject.toml`.
64
+ 2. Add the release notes to `CHANGELOG.md`.
65
+ 3. Run the local gates:
66
+
67
+ ```bash
68
+ UV_CACHE_DIR=.uv-cache uv run pytest
69
+ UV_CACHE_DIR=.uv-cache uv run ruff check .
70
+ UV_CACHE_DIR=.uv-cache uv run pyright
71
+ UV_CACHE_DIR=.uv-cache uv build
72
+ UV_CACHE_DIR=.uv-cache uv run --no-project --python 3.11 --with twine twine check dist/*
73
+ ```
74
+
75
+ 4. Commit and merge the release prep.
76
+ 5. Tag the release from `main`:
77
+
78
+ ```bash
79
+ git tag -a vX.Y.Z -m "Release vX.Y.Z"
80
+ git push origin vX.Y.Z
81
+ ```
82
+
83
+ The `Release` workflow builds the source distribution and wheel, checks package
84
+ metadata with Twine, stores the artifacts, publishes them to PyPI from the
85
+ `pypi` environment, and creates a GitHub Release with the built distributions
86
+ attached. Running the workflow manually builds and checks artifacts without
87
+ publishing because the publish and GitHub Release jobs only run for `v*.*.*`
88
+ tags.
89
+
90
+ The PyPI distribution name is `inductive-mlxrl`; the Python import package and
91
+ CLI command remain `mlxrl`.
92
+
93
+ ## Scope
94
+
95
+ In scope:
96
+
97
+ - single-process Apple Silicon RL post-training;
98
+ - QLoRA on local MLX LLMs;
99
+ - critic-free on-policy algorithms;
100
+ - memory-conscious rollout and training paths.
101
+
102
+ Out of scope:
103
+
104
+ - PPO or other critic/value-model algorithms;
105
+ - DPO/ORPO and other offline preference objectives;
106
+ - CUDA, torch fallback, or distributed training;
107
+ - inference servers or second reference-model copies.
@@ -0,0 +1,72 @@
1
+ # mlxrl Design
2
+
3
+ `mlxrl` is a small, single-process RL post-training library for Apple Silicon.
4
+ The design goal is narrow: make on-policy, critic-free RL for local MLX LLMs
5
+ fast enough that rollout is the main problem and the loss stays thin.
6
+
7
+ ## Critic-Free By Design
8
+
9
+ The library is intentionally limited to rollout-based policy-gradient methods:
10
+ GRPO, Dr. GRPO, DAPO, GSPO, and RLOO. PPO is out of scope because it needs a
11
+ critic/value model path, a different forward pass, and a different memory
12
+ profile. DPO and ORPO are also out of scope because they are offline preference
13
+ objectives rather than on-policy rollout algorithms.
14
+
15
+ That boundary matters. `mlxrl` keeps one policy model object in memory, attaches
16
+ QLoRA adapters to it, and computes reference logprobs by disabling those
17
+ adapters for a second pass. There is no inference server and no second reference
18
+ model copy.
19
+
20
+ ## Algorithm Protocol
21
+
22
+ The engine does not know which algorithm it is training. Concrete algorithms
23
+ implement the `Algorithm` protocol:
24
+
25
+ - compute per-completion advantages;
26
+ - optionally filter a prepared batch;
27
+ - compute loss and diagnostics from policy, old-policy, and reference logprobs.
28
+
29
+ `rollout/`, `policy/`, and `train/` must not import from `algo/`. The import
30
+ direction test enforces this. The payoff is that rollout and logprob code stay
31
+ stable while algorithms change. DAPO's dynamic sampling is the proof that the
32
+ interface is more general than "GRPO with renamed constants": it drops
33
+ zero-advantage groups through `filter_batch` without special trainer branches.
34
+
35
+ ## Correctness Gates
36
+
37
+ Speed changes are allowed only behind equivalence gates. The original Phase 1
38
+ path is simple and readable; optimized rollout variants must match it
39
+ token-for-token at fixed seed and match loss within tolerance. The protocol
40
+ refactor was checked against the pre-refactor GRPO and Dr. GRPO losses with
41
+ zero difference in loss and adapter gradients.
42
+
43
+ Old-policy changes use a stronger gate than naive equality. At 4-bit, the
44
+ rollout-time cached decode realization and the full-forward realization are not
45
+ numerically identical. Stored rollout logprobs are attractive because they are
46
+ the behavior policy, but the actual gate is importance-ratio stability: on a
47
+ freshly sampled batch, `exp(logpi_current - logpi_old)` must stay near 1.0 and
48
+ short training runs must not show KL or gradient spikes. Until that gate justifies
49
+ a semantics change, `mlxrl` recomputes old-policy logprobs with a full forward.
50
+
51
+ ## The 4-Bit KV Boundary
52
+
53
+ Quantized KV cache decode is not a drop-in numerical replacement for a full
54
+ forward over prompt+completion. That is fine for sampling, but it is not fine
55
+ for gradient-bearing or importance-weighting quantities. Anything used in the
56
+ loss denominator, KL, or adapter gradient must come from the full-forward path
57
+ unless a dedicated stability gate proves otherwise.
58
+
59
+ ## Memory As A First-Class Constraint
60
+
61
+ Apple Silicon's unified memory is the deployment target, not an afterthought.
62
+ The code supports per-layer MLX-LM gradient checkpointing because DeltaNet and
63
+ linear-attention models otherwise keep O(sequence length) recurrent state live
64
+ through backward. The 9B fitting anchor is Qwen3.5-9B 4-bit at G=2, total
65
+ sequence length 609, with per-layer checkpointing at about 25.9 GB peak. G=4 at
66
+ the same shape is about 45.9 GB and tight on a 48 GB machine.
67
+
68
+ The memory estimator is deliberately conservative. It interpolates near measured
69
+ anchors and labels long-sequence uncheckpointed hybrid configs as OOM-risk
70
+ estimates, not measurements. Its purpose is to nudge users toward the knob that
71
+ will most likely make a run fit: enable checkpointing, reduce G, then reduce
72
+ completion length.
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2026 mlxrl contributors
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
@@ -0,0 +1,388 @@
1
+ Metadata-Version: 2.4
2
+ Name: inductive-mlxrl
3
+ Version: 0.1.0
4
+ Summary: Small single-process RL post-training for LLMs on Apple Silicon with MLX.
5
+ Project-URL: Homepage, https://github.com/inductiveML/mlxrl
6
+ Project-URL: Repository, https://github.com/inductiveML/mlxrl
7
+ Project-URL: Issues, https://github.com/inductiveML/mlxrl/issues
8
+ Project-URL: Changelog, https://github.com/inductiveML/mlxrl/blob/main/CHANGELOG.md
9
+ License-Expression: MIT
10
+ License-File: LICENSE
11
+ License-File: THIRD_PARTY_LICENSES.md
12
+ Keywords: apple-silicon,grpo,mlx,qlora,rl
13
+ Classifier: Development Status :: 3 - Alpha
14
+ Classifier: Intended Audience :: Developers
15
+ Classifier: License :: OSI Approved :: MIT License
16
+ Classifier: Programming Language :: Python :: 3.11
17
+ Classifier: Programming Language :: Python :: 3.12
18
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
19
+ Requires-Python: >=3.11
20
+ Requires-Dist: mlx-lm>=0.31.0
21
+ Requires-Dist: mlx>=0.31.0
22
+ Description-Content-Type: text/markdown
23
+
24
+ # mlxrl
25
+
26
+ Fast on-policy MLX RL for Apple Silicon; not a general RL framework, not
27
+ preference tuning, and not distributed training.
28
+
29
+ `mlxrl` is a small, single-process RL post-training library for LLMs on Apple
30
+ Silicon. It is built around one idea: GRPO on MLX should be a fast batched
31
+ rollout path with a thin loss and optimizer step on top, not a framework.
32
+
33
+ The current implementation targets QLoRA GRPO on local 4-bit MLX models. It
34
+ reuses `mlx-lm` model loading, LoRA layers, KV caches, and sampling utilities,
35
+ and keeps generation and training in one Python process with one model object.
36
+
37
+ `mlxrl` is pre-1.0. The correctness gates are stable, but import APIs and config
38
+ fields may change before a 1.0 release.
39
+
40
+ ## Quickstart
41
+
42
+ ```bash
43
+ git clone https://github.com/inductiveML/mlxrl.git
44
+ cd mlxrl
45
+ UV_CACHE_DIR=.uv-cache uv sync --all-groups
46
+ UV_CACHE_DIR=.uv-cache uv run mlxrl train \
47
+ --config examples/qwen3_0_6b_grpo.toml \
48
+ --available-memory-gb 48
49
+ ```
50
+
51
+ For the measured 9B-on-48GB shape, use the checkpointed G=2 config:
52
+
53
+ ```bash
54
+ UV_CACHE_DIR=.uv-cache uv run mlxrl train \
55
+ --config examples/qwen35_9b_g2_checkpoint.toml \
56
+ --available-memory-gb 48 \
57
+ --dry-run
58
+ ```
59
+
60
+ ## What Works
61
+
62
+ - Batched group rollouts with MLX-LM KV caches and sampling.
63
+ - Full-forward old-policy logprob recompute for training-time `pi_old`.
64
+ - Adapter-disabled reference policy on the same model object.
65
+ - GRPO, Dr. GRPO, DAPO, and GSPO loss variants.
66
+ - RLOO (REINFORCE Leave-One-Out) as a critic-free rollout objective.
67
+ - QLoRA injection on dense and heterogeneous/hybrid layer stacks.
68
+ - Qwen3.5-style hybrid support via MLX-LM auto LoRA targeting, including
69
+ DeltaNet `linear_attn.in_proj_*` and dense attention `q/k/v/o_proj`.
70
+ - Per-layer gradient checkpointing through `mlx_lm.tuner.trainer.grad_checkpoint`
71
+ for linear-attention/DeltaNet backward memory.
72
+ - Micro-batched gradient accumulation for token-mean policy losses.
73
+ - `beta == 0` reference-forward skip.
74
+ - Phase 4 benchmark harness for `mlxrl`, `mlx-tune`, `mlx-lm-lora`, and `mlx-lm`.
75
+
76
+ ## Install
77
+
78
+ After the first tagged release, install the PyPI distribution:
79
+
80
+ ```bash
81
+ pip install inductive-mlxrl
82
+ ```
83
+
84
+ The Python import package and CLI command are still `mlxrl`:
85
+
86
+ ```bash
87
+ mlxrl --help
88
+ ```
89
+
90
+ Source installs are also supported:
91
+
92
+ ```bash
93
+ UV_CACHE_DIR=.uv-cache uv sync --all-groups
94
+ ```
95
+
96
+ Run commands through the local environment:
97
+
98
+ ```bash
99
+ UV_CACHE_DIR=.uv-cache uv run mlxrl --help
100
+ ```
101
+
102
+ Python 3.11+ is required. Runtime dependencies are intentionally small:
103
+ `mlx` and `mlx-lm`. Development dependencies include `pytest`, `ruff`,
104
+ `pyright`, `mlx-tune`, and `mlx-lm-lora` for comparison benchmarks.
105
+ The PyPI distribution name is `inductive-mlxrl`; the import package and console
106
+ script remain `mlxrl`.
107
+
108
+ ## Quick Smoke Tests
109
+
110
+ Dense Qwen3 0.6B:
111
+
112
+ ```bash
113
+ UV_CACHE_DIR=.uv-cache uv run mlxrl phase0-smoke \
114
+ --model mlx-community/Qwen3-0.6B-4bit \
115
+ --prompt "What is 2+2?"
116
+ ```
117
+
118
+ Hybrid Qwen3.5 9B with rank-16 LoRA:
119
+
120
+ ```bash
121
+ UV_CACHE_DIR=.uv-cache uv run mlxrl phase0-smoke \
122
+ --model mlx-community/Qwen3.5-9B-MLX-4bit \
123
+ --rank 16 \
124
+ --scale 2.0 \
125
+ --prompt "What is 2+2?"
126
+ ```
127
+
128
+ The smoke gate prints the model id, layer count, LoRA target keys, per-layer
129
+ LoRA module counts, total/trainable parameter counts, and logits shape. It
130
+ fails if any trainable leaf is not `lora_a` or `lora_b`.
131
+
132
+ ## Training Commands
133
+
134
+ Toy hand-computed GRPO math gate:
135
+
136
+ ```bash
137
+ UV_CACHE_DIR=.uv-cache uv run mlxrl phase1-toy-gate
138
+ ```
139
+
140
+ Small built-in GSM8K-style run:
141
+
142
+ ```bash
143
+ UV_CACHE_DIR=.uv-cache uv run mlxrl phase1-gsm8k \
144
+ --model mlx-community/Qwen3-0.6B-4bit \
145
+ --steps 20 \
146
+ --group-size 4 \
147
+ --max-tokens 64
148
+ ```
149
+
150
+ Config-driven run:
151
+
152
+ ```bash
153
+ UV_CACHE_DIR=.uv-cache uv run mlxrl train \
154
+ --config examples/qwen3_0_6b_grpo.toml \
155
+ --available-memory-gb 48
156
+ ```
157
+
158
+ The config schema validates model id, quant bits, group size, completion/prompt
159
+ lengths, checkpointing granularity, `iogpu.wired_limit_mb`, optimizer settings,
160
+ algorithm hyperparameters, KL beta, and seed before a model is loaded. CLI
161
+ overrides such as `--steps`, `--group-size`, `--max-tokens`, `--algorithm`,
162
+ `--beta`, and `--seed` apply on top of the file.
163
+
164
+ For DeltaNet / linear-attention models, enable per-layer checkpointing:
165
+
166
+ ```bash
167
+ UV_CACHE_DIR=.uv-cache uv run mlxrl phase1-gsm8k \
168
+ --model mlx-community/Qwen3.5-9B-MLX-4bit \
169
+ --rank 16 \
170
+ --scale 2.0 \
171
+ --checkpoint-completion-forward \
172
+ --steps 1 \
173
+ --group-size 2 \
174
+ --max-tokens 256
175
+ ```
176
+
177
+ Despite the historical CLI name, `--checkpoint-completion-forward` now enables
178
+ per-transformer-block checkpointing at model setup. The old whole-model
179
+ `mx.checkpoint(...)` wrapper was removed because it does not cap DeltaNet's
180
+ per-layer scan memory.
181
+
182
+ Phase 2 rollout equivalence check:
183
+
184
+ ```bash
185
+ UV_CACHE_DIR=.uv-cache uv run mlxrl phase2-equivalence \
186
+ --model mlx-community/Qwen3-0.6B-4bit \
187
+ --group-size 4 \
188
+ --max-tokens 32 \
189
+ --compile-decode-step \
190
+ --batch-groups
191
+ ```
192
+
193
+ ## Import API
194
+
195
+ Minimal model setup:
196
+
197
+ ```python
198
+ from mlxrl.policy import LoRAConfig, load_policy_with_lora
199
+
200
+ model, tokenizer, report = load_policy_with_lora(
201
+ model_id="mlx-community/Qwen3.5-9B-MLX-4bit",
202
+ config=LoRAConfig(
203
+ rank=16,
204
+ scale=2.0,
205
+ dropout=0.0,
206
+ grad_checkpoint=True,
207
+ ),
208
+ )
209
+ ```
210
+
211
+ One optimizer step:
212
+
213
+ ```python
214
+ import mlx.optimizers as optim
215
+
216
+ from mlxrl.algo import GRPOAlgorithm
217
+ from mlxrl.train import batch_from_rollouts, optimizer_step
218
+
219
+ optimizer = optim.Adam(learning_rate=1e-5)
220
+ algorithm = GRPOAlgorithm()
221
+ batch = batch_from_rollouts(
222
+ model=model,
223
+ completions=completions,
224
+ rewards=rewards,
225
+ group_size=4,
226
+ pad_token_id=pad_token_id,
227
+ algorithm=algorithm,
228
+ compute_reference=beta != 0.0,
229
+ )
230
+ metrics = optimizer_step(
231
+ model=model,
232
+ optimizer=optimizer,
233
+ batch=batch,
234
+ beta=beta,
235
+ pad_token_id=pad_token_id,
236
+ algorithm=algorithm,
237
+ use_checkpoint=True,
238
+ micro_batch_size=2,
239
+ )
240
+ ```
241
+
242
+ `micro_batch_size=0` keeps the original whole-batch path. Micro-batching is
243
+ currently exact for token-mean policy losses: base GRPO, DAPO, GSPO token mode,
244
+ RLOO, and Dr. GRPO with `loss_reduction="token_mean"`. Sequence-reduced losses
245
+ should keep `micro_batch_size=0`.
246
+
247
+ ## Policy Semantics
248
+
249
+ - The base model is frozen before LoRA injection.
250
+ - Only LoRA adapter leaves are trainable.
251
+ - Reference logprobs are computed by temporarily disabling adapters on the same
252
+ model object; there is no second reference model in memory.
253
+ - Old-policy logprobs are recomputed with a full forward for the training batch.
254
+ Rollout-time logprobs are captured for inspection, but 4-bit sequential decode
255
+ and full-forward prefill are not numerically identical on hybrid/quantized
256
+ models, so recompute remains the default training semantics.
257
+ - When `beta == 0`, the reference forward is skipped and the policy logprobs are
258
+ used as a zero-KL placeholder.
259
+ - PPO, DPO, and ORPO are intentionally out of scope. PPO needs a separate critic
260
+ and value forward; DPO/ORPO are offline preference objectives with no rollout
261
+ phase. `mlxrl` is critic-free, on-policy, and rollout-based by design.
262
+
263
+ ## Algorithms
264
+
265
+ Concrete algorithms implement the small `Algorithm` protocol: compute
266
+ advantages, optionally filter a prepared batch, then compute a loss from policy,
267
+ old-policy, and reference logprobs. `rollout/`, `policy/`, and `train/` do not
268
+ import concrete algorithm implementations.
269
+
270
+ | algorithm | defining behavior |
271
+ | --- | --- |
272
+ | GRPO | group-normalized rewards, token-level importance ratio |
273
+ | Dr. GRPO | centered or normalized rewards with decoupled length reduction |
274
+ | DAPO | asymmetric low/high clipping plus optional dynamic zero-advantage group filtering |
275
+ | GSPO | sequence-level, length-normalized importance ratio and clipping |
276
+ | RLOO | leave-one-out group baseline, no critic, no std-normalized advantage |
277
+
278
+ ## Memory Preflight
279
+
280
+ `mlxrl train` can estimate memory before loading the model:
281
+
282
+ ```bash
283
+ UV_CACHE_DIR=.uv-cache uv run mlxrl train \
284
+ --config examples/qwen3_0_6b_grpo.toml \
285
+ --available-memory-gb 48 \
286
+ --dry-run
287
+ ```
288
+
289
+ The estimator is calibrated to measured anchors: `6.245 GB` for
290
+ Qwen3-0.6B/G4/prompt≈19/T256, `25.9 GB` for
291
+ Qwen3.5-9B/G2/seq609/per-layer-checkpointed, `45.9 GB` for
292
+ Qwen3.5-9B/G4/seq609/per-layer-checkpointed, and `36 GB` for
293
+ Qwen3.5-9B/G2/seq128/no-checkpoint. For hybrid 9B no-checkpoint long-sequence
294
+ configs, it reports an OOM-risk lower bound rather than a fake precise peak.
295
+ For an obviously too-large Qwen3.5-9B/G8/prompt97/T512/no-checkpoint config on
296
+ 48 GB, it flags the run and suggests the measured-boundary fallback around
297
+ G4/T512/checkpointed.
298
+
299
+ ## Benchmarks
300
+
301
+ Local M4 Max Phase 4 snapshot:
302
+
303
+ - `454` rollout tok/s on Qwen3-0.6B GRPO with G=4 and 256-token completions.
304
+ - `0.283` end-to-end it/s with full `mlxrl` training semantics.
305
+ - `3.2x` faster rollout and `2.2x` higher end-to-end it/s than `mlx-tune`
306
+ v0.5.1 on the same run shape.
307
+ - `1.3x` faster rollout than sequential `mlx-lm` generation at G=4.
308
+
309
+ These are the two-pass means from
310
+ `benchmarks/results/gate5_full_reconciled.md`, run with MLX 0.31.2,
311
+ MLX-LM 0.31.3, `mlx-community/Qwen3-0.6B-4bit`, 100 measured steps with
312
+ 5 warmup steps discarded:
313
+
314
+ | target | comparison | rollout tok/s | grad s/step | samples/s | it/s | peak GB |
315
+ | --- | --- | ---: | ---: | ---: | ---: | ---: |
316
+ | `mlxrl` | apples-to-apples GRPO | 454.1 | 1.282 | 1.133 | 0.283 | 6.25 |
317
+ | `mlx-lm` | generation-only, G=1 | 347.0 | - | 1.355 | - | 0.52 |
318
+ | `mlx-lm-g4` | generation-only, sequential G=4 | 349.7 | - | 1.366 | - | 0.52 |
319
+ | `mlx-tune` | package-speed reference | 142.2 | 0.502 | 0.519 | 0.130 | 6.16 |
320
+ | `mlx-lm-lora` | package-speed reference | 557.9 | 0.592 | 1.648 | 0.412 | 5.32 |
321
+
322
+ `mlx-lm-lora` reports higher raw package-speed throughput in this snapshot, but
323
+ its benchmarked path is not the same training problem as `mlxrl`'s live
324
+ old-policy/reference semantics and completion-loss masking. That is the honest
325
+ case where `mlxrl` is not faster; the apples-to-apples comparison label is
326
+ reserved for `mlxrl`'s own semantic path. On the 9B Noether real workload, the
327
+ checkpointed MLX path measured about 6x faster than the previous torch-MPS path;
328
+ that workload is separate from the public Phase 4 package-speed harness.
329
+
330
+ Run the Phase 4 harness:
331
+
332
+ ```bash
333
+ UV_CACHE_DIR=.uv-cache uv run python benchmarks/run_phase4.py run \
334
+ --targets mlxrl,mlx-lm,mlx-tune,mlx-lm-lora \
335
+ --model mlx-community/Qwen3-0.6B-4bit \
336
+ --steps 100 \
337
+ --warmup-steps 5 \
338
+ --group-size 4 \
339
+ --max-tokens 256 \
340
+ --passes 2 \
341
+ --output benchmarks/results/phase4.jsonl \
342
+ --summary benchmarks/results/phase4.md \
343
+ --allow-missing-baselines
344
+ ```
345
+
346
+ The harness reports synchronized rollout tok/s, gradient seconds per step,
347
+ samples/s, it/s, and peak MLX memory. `mlx-lm` targets are generation-only;
348
+ external package targets are useful speed references but may not match `mlxrl`
349
+ training semantics.
350
+
351
+ ## Development
352
+
353
+ See [CONTRIBUTING.md](CONTRIBUTING.md) and [DESIGN.md](DESIGN.md) before adding
354
+ algorithms or changing rollout/logprob semantics.
355
+
356
+ Run the quality gates:
357
+
358
+ ```bash
359
+ UV_CACHE_DIR=.uv-cache uv run pytest
360
+ UV_CACHE_DIR=.uv-cache uv run ruff check .
361
+ UV_CACHE_DIR=.uv-cache uv run pyright
362
+ ```
363
+
364
+ MLX lazy evaluation matters. Any `mx.eval(...)` or `mx.synchronize()` in this
365
+ repo should mark a real boundary: sampled token append/EOS checks, logprob
366
+ freezing before adapter mutation, per-micro-batch graph release, optimizer
367
+ updates, or benchmark timing boundaries.
368
+
369
+ ## Layout
370
+
371
+ ```text
372
+ mlxrl/
373
+   rollout/ # batched group generation
374
+   policy/ # model loading, LoRA setup, logprob passes
375
+   algo/ # GRPO-family advantages and losses
376
+   train/ # value_and_grad and optimizer integration
377
+   data/ # toy GSM8K data and rewards
378
+   cli.py
379
+ tests/
380
+ benchmarks/
381
+ ```
382
+
383
+ ## Non-Goals
384
+
385
+ - No inference server or second model copy.
386
+ - No CUDA or torch fallback.
387
+ - No distributed training.
388
+ - No broad RL framework abstractions beyond the small algorithm interface.