mlx-recurrence 0.3.0__tar.gz

@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2026 Paul O. Derrington, Jr.
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
@@ -0,0 +1,299 @@
1
+ Metadata-Version: 2.4
2
+ Name: mlx-recurrence
3
+ Version: 0.3.0
4
+ Summary: A plug-in framework for linear-recurrence Metal kernels on Apple Silicon (SSD, GLA, RG-LRU)
5
+ Author-email: "Paul O. Derrington, Jr." <derrington.collaborative.ai@gmail.com>
6
+ License: MIT
7
+ Project-URL: Homepage, https://github.com/d-csil/mlx-recurrence
8
+ Project-URL: Repository, https://github.com/d-csil/mlx-recurrence
9
+ Project-URL: Issues, https://github.com/d-csil/mlx-recurrence/issues
10
+ Keywords: mlx,apple-silicon,metal,ssm,mamba,gla,rglru,griffin,recurrentgemma,flash-linear-attention,linear-recurrence,state-space-model,machine-learning
11
+ Classifier: Development Status :: 4 - Beta
12
+ Classifier: Intended Audience :: Science/Research
13
+ Classifier: License :: OSI Approved :: MIT License
14
+ Classifier: Programming Language :: Python :: 3
15
+ Classifier: Programming Language :: Python :: 3.10
16
+ Classifier: Programming Language :: Python :: 3.11
17
+ Classifier: Programming Language :: Python :: 3.12
18
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
19
+ Classifier: Operating System :: MacOS
20
+ Requires-Python: >=3.10
21
+ Description-Content-Type: text/markdown
22
+ License-File: LICENSE
23
+ Requires-Dist: mlx>=0.22.0
24
+ Provides-Extra: dev
25
+ Requires-Dist: pytest>=7.0; extra == "dev"
26
+ Requires-Dist: numpy>=1.24; extra == "dev"
27
+ Provides-Extra: benchmark
28
+ Requires-Dist: numpy>=1.24; extra == "benchmark"
29
+ Dynamic: license-file
30
+
31
+ # mlx-recurrence
32
+
33
+ A plug-in framework of **fused Metal GPU kernels for linear recurrence on Apple
34
+ Silicon** — think *flash-linear-attention for [MLX](https://github.com/ml-explore/mlx)*.
35
+
36
+ Sequential recurrences (SSMs, gated linear attention, diagonal RNNs) are the one
37
+ thing MLX cannot fuse for free: a Python loop over `L` timesteps costs `L`
38
+ Python→Metal dispatches. These kernels collapse the entire recurrence into a
39
+ single dispatch, with a **segment-checkpoint + recompute backward** that cuts
40
+ kernel peak memory by 12–18× at the benchmarked training shapes, compared with storing the full state history.
41
+
42
+ | Kernel | Recurrence | Used by | State |
43
+ |---|---|---|---|
44
+ | `ssd_scan` | Mamba-2-style head-wise SSD selective scan | Mamba-2 / SSM hybrids | `[B, H, Dh, N]` |
45
+ | `gla_scan` | Gated Linear Attention (scalar forget gate, outer-product write) | GLA / linear-attention hybrids | `[B, H, Dh, Dh]` |
46
+ | `rglru_scan` | RG-LRU diagonal scan | Griffin / RecurrentGemma-style | `[B, D]` |
47
+ | `rotlru_scan` | Rotational LRU: complex-diagonal scan (magnitude gate + per-step rotation of 2D channel pairs) | Complex-LRU / S4-style oscillatory memory | `[B, D]` (interleaved pairs) |
48
+
49
+ Each kernel is a self-contained plug-in on a shared chassis
50
+ (`mlx_recurrence._chassis`) that provides the checkpoint+recompute pattern,
51
+ shape validation, and a parity-test helper — adding a new recurrence means
52
+ writing one Metal source pair and its VJP wiring, not rebuilding the
53
+ infrastructure. The original v0.1 kernels remain available under
54
+ `mlx_recurrence.legacy` (and re-exported at top level) for backwards
55
+ compatibility.
56
+
57
+ ## Validated in production
58
+
59
+ These are not microbenchmark-only kernels. The v2 SSD and GLA kernels were
60
+ dropped into a live multi-week D-CSIL SSM+GLA hybrid training run mid-flight
61
+ (checkpoint pause → parity gates → resume), on an M3 Max, bf16, batch 3,
62
+ L=512:
63
+
64
+ | Gate | v0.1-style (full state history) | v2 (checkpoint + recompute) |
65
+ |---|---|---|
66
+ | Kernel parity (fwd + every gradient) | — | ~1e-7 rel, all gates PASS |
67
+ | Peak training memory | 23.88 GB | **10.34 GB** |
68
+ | Sustained tokens/sec | ~1,074 | **~1,481–1,540 (≈1.4×)** |
69
+ | Loss continuity across the swap | — | clean (no NaN/inf, same loss band) |
70
+
71
+ Full report: [`docs/validation/V3_VALIDATION_REPORT_20260610.md`](docs/validation/V3_VALIDATION_REPORT_20260610.md)
72
+ (the consuming training repo names these kernels "v3" in its shim — same code).
73
+ Note the baseline above is a run *already using fused v0.1-style kernels* —
74
+ the gains over having no custom kernels at all are far larger (next section).
75
+
76
+ ## Benchmarks
77
+
78
+ Two baselines matter, and they answer different questions:
79
+
80
+ 1. **vs. no custom kernels at all** — the Python per-step loop or chunked-MLX
81
+ fallback a user would otherwise write. This is the speedup you get by
82
+ adopting the package.
83
+ 2. **v2 vs. v0.1-style full-history kernels** — what the
84
+ checkpoint+recompute redesign adds on top, mainly for training memory.
85
+
86
+ ### 1. Fused kernels vs. no custom kernels (M3 Max, seq_len=2048)
87
+
88
+ | Pass | SSM | GLA |
89
+ |---|---|---|
90
+ | Forward — fused kernel vs Python per-step loop | **7.3×** | **9.1×** |
91
+ | Forward + backward — fused VJP vs chunked-MLX autograd | **19.0×** | **31.8×** |
92
+
93
+ (Measured for the v0.1 release; charts in `benchmarks/`. Without fused
94
+ kernels, training these recurrences on Apple Silicon is impractical above
95
+ seq_len ≈ 512 — the backward pass is the killer.)
96
+
97
+ The v2 kernels measured faster still at training shapes (next table), so
98
+ vs-no-kernels speedups for v2 are expected to be at least this large. A
99
+ direct single-shape v2-vs-no-kernels measurement is planned once the current
100
+ production run frees the GPU, and will replace this estimate.
101
+
102
+ ### 2. v2 vs. v0.1-style full-history kernels (training shapes: B=3, L=512, H=12, Dh=64)
103
+
104
+ | Kernel | fwd | fwd + bwd | peak memory |
105
+ |---|---|---|---|
106
+ | SSD, full-history baseline | 3.14 ms | 32.22 ms | 1,792 MB |
107
+ | **SSD v2** | **2.30 ms** | **17.34 ms (1.86×)** | **145 MB (12×)** |
108
+ | GLA, full-history baseline | 2.10 ms | 17.92 ms | 1,477 MB |
109
+ | **GLA v2** | **1.41 ms** | **12.06 ms (1.49×)** | **81 MB (18×)** |
110
+
111
+ The memory column is the one that matters for training: the baseline stores
112
+ every per-timestep state for the backward pass; the v2 kernels store only
113
+ segment-boundary checkpoints (1/32 of the writes) and recompute each segment
114
+ into a small scratch buffer that stays cache-resident during the adjoint sweep.
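As a rough cross-check of the table, the improvement factors can be recomputed from the published figures, and the 1/32 checkpointing arithmetic can be counted directly. This is a back-of-envelope sketch: `ratios` and `stored_states` are hypothetical helpers, and `stored_states` assumes one checkpoint per segment plus a one-segment scratch buffer. It counts states only and does not attempt to reproduce the measured peak-memory numbers.

```python
# Figures copied from the table above: (fwd ms, fwd+bwd ms, peak memory MB).
table = {
    "SSD": {"baseline": (3.14, 32.22, 1792), "v2": (2.30, 17.34, 145)},
    "GLA": {"baseline": (2.10, 17.92, 1477), "v2": (1.41, 12.06, 81)},
}


def ratios(kernel):
    """Return (fwd, fwd+bwd, memory) improvement factors as baseline / v2."""
    base, v2 = table[kernel]["baseline"], table[kernel]["v2"]
    return tuple(b / v for b, v in zip(base, v2))


def stored_states(L, seg=32):
    """Count per-timestep states held for the backward pass.

    Assumption: the checkpointed variant keeps one state per segment plus a
    scratch buffer holding one segment's worth of recomputed states.
    """
    assert L % seg == 0, "sequence must tile into segments"
    full_history = L
    checkpointed = L // seg + seg
    return full_history, checkpointed


for name in table:
    fwd, fwd_bwd, mem = ratios(name)
    print(f"{name}: fwd {fwd:.2f}x, fwd+bwd {fwd_bwd:.2f}x, memory {mem:.1f}x")

print(stored_states(512))
```

The measured peak-memory ratios need not match the state count exactly, since peak memory covers more than the state history.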
115
+
116
+ ## Installation
117
+
118
+ ```bash
119
+ # v0.3.0 framework (main branch)
120
+ pip install git+https://github.com/D-CSIL/mlx-recurrence.git
121
+
122
+ # or from a clone
123
+ git clone https://github.com/D-CSIL/mlx-recurrence.git
124
+ cd mlx-recurrence && pip install -e .
125
+ ```
126
+
127
+ (PyPI release planned; install from GitHub until then. The legacy v0.1-era
128
+ kernels need no separate install — they ship inside this package under
129
+ `mlx_recurrence.legacy` with top-level re-exports.)
130
+
131
+ Requires: Python >= 3.10, MLX >= 0.22.0, Apple Silicon Mac (Metal GPU).
132
+
133
+ ## Usage (v2 kernels)
134
+
135
+ All v2 kernels are fully differentiable (`mx.grad` / `mx.value_and_grad`
136
+ work through them via custom VJPs), keep **fp32 state and accumulation
137
+ regardless of input dtype** (bf16 inputs widen implicitly), and share two
138
+ shape constraints from the checkpoint + simd-reduction pattern:
139
+
140
+ ```
141
+ L % seg == 0 # sequence tiles into segments (seg defaults to 32)
142
+ lane_dim % 32 == 0 # Dh for ssd/gla, D for rglru (32-lane simdgroups)
143
+ ```
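The two constraints can be checked before calling a kernel. The sketch below uses hypothetical helpers (`satisfies_constraints`, `padded_length`); the package exports `check_segment_shape` and `DEFAULT_SEG`, but their signatures are not shown here, so this example does not call them.

```python
DEFAULT_SEG = 32  # segment length stated in the README
SIMD_WIDTH = 32   # lanes per simdgroup stated in the README


def satisfies_constraints(L, lane_dim, seg=DEFAULT_SEG):
    """True when L tiles into segments and lane_dim fills whole simdgroups."""
    return L % seg == 0 and lane_dim % SIMD_WIDTH == 0


def padded_length(L, seg=DEFAULT_SEG):
    """Smallest multiple of seg that is >= L (ceiling division)."""
    return -(-L // seg) * seg


print(satisfies_constraints(512, 64))   # README training shape
print(satisfies_constraints(500, 64))   # L does not tile into segments
print(padded_length(500))
```

Because the scan is causal, padding at the end of the sequence leaves the first `L` outputs unchanged, but a final state taken from a padded run would include the padding steps. Falling back to the `*_scan_reference` functions avoids that issue.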
144
+
145
+ ### SSD selective scan (Mamba-2 style)
146
+
147
+ ```python
148
+ import mlx.core as mx
149
+ from mlx_recurrence import ssd_scan, ssd_scan_with_state
150
+
151
+ B, L, H, Dh, N = 3, 512, 12, 64, 16
152
+
153
+ u = mx.random.normal((B, L, H, Dh)) # input
154
+ delta = mx.abs(mx.random.normal((B, L, H))) * 0.1 + 0.01 # per-token step size
155
+ B_in = mx.random.normal((B, L, H, N)) # input projection
156
+ C_in = mx.random.normal((B, L, H, N)) # output projection
157
+ A_neg = -mx.exp(mx.random.normal((H, N))) # decay rates, < 0
158
+
159
+ y = ssd_scan(u, delta, B_in, C_in, A_neg) # -> [B, L, H, Dh]
160
+ y, final_state = ssd_scan_with_state(u, delta, B_in, C_in, A_neg) # chunked prefill
161
+ ```
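The README does not spell out the SSD update rule, so the following pure-Python sketch is an assumption based on the commonly used Mamba-style form: decay `exp(delta_t * A)`, input `delta_t * B_t * u_t`, readout `C_t · h_t`. `ssd_reference_single` is a hypothetical name that handles a single (batch, head) slice; the package's own `ssd_scan_reference` may differ in discretization or layout.

```python
import math


def ssd_reference_single(u, delta, B_in, C_in, A_neg):
    """Assumed Mamba-style scan for one (batch, head) slice.

    u: [L][Dh], delta: [L], B_in and C_in: [L][N], A_neg: [N] (negative).
    Returns (y, h) with y: [L][Dh] and the final state h: [Dh][N].
    """
    L, Dh, N = len(u), len(u[0]), len(A_neg)
    h = [[0.0] * N for _ in range(Dh)]
    y = []
    for t in range(L):
        # Per-state decay factor in (0, 1) because A_neg < 0 and delta > 0.
        decay = [math.exp(delta[t] * A_neg[n]) for n in range(N)]
        y_t = []
        for d in range(Dh):
            for n in range(N):
                h[d][n] = decay[n] * h[d][n] + delta[t] * B_in[t][n] * u[t][d]
            y_t.append(sum(C_in[t][n] * h[d][n] for n in range(N)))
        y.append(y_t)
    return y, h


y, h = ssd_reference_single(
    u=[[2.0], [0.0]], delta=[0.5, 1.0],
    B_in=[[1.0], [1.0]], C_in=[[3.0], [1.0]], A_neg=[-1.0],
)
print(y, h)
```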
162
+
163
+ ### GLA recurrence
164
+
165
+ ```python
166
+ import mlx.core as mx
+ from mlx_recurrence import gla_scan, gla_scan_with_state
167
+
168
+ B, L, H, Dh = 3, 512, 12, 64
169
+
170
+ q = mx.random.normal((B, L, H, Dh)) * (Dh ** -0.5) # pre-scaled / post-RoPE
171
+ k = mx.random.normal((B, L, H, Dh))
172
+ v = mx.random.normal((B, L, H, Dh))
173
+ gates = mx.sigmoid(mx.random.normal((B, L, H))) # scalar forget gate, (0,1)
174
+
175
+ o = gla_scan(q, k, v, gates) # -> [B, L, H, Dh]
176
+ o, final_state = gla_scan_with_state(q, k, v, gates) # state: [B, H, Dh, Dh]
177
+ ```
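For orientation, the gated outer-product recurrence can be written as a plain loop. This is a minimal pure-Python sketch for one (batch, head) slice, assuming the update `S_t = g_t * S_{t-1} + k_t v_t^T` with readout `o_t = q_t^T S_t`. The name `gla_reference_single` and the index convention (rows indexed by key, columns by value) are assumptions, not the package's confirmed layout.

```python
def gla_reference_single(q, k, v, gates):
    """Assumed GLA scan. q, k, v: [L][Dh]; gates: [L] scalars in (0, 1).

    Returns (outputs [L][Dh], final state S [Dh][Dh]).
    """
    L, Dh = len(q), len(q[0])
    S = [[0.0] * Dh for _ in range(Dh)]  # S[i][j]: key index i, value index j
    out = []
    for t in range(L):
        for i in range(Dh):
            for j in range(Dh):
                # Scalar forget gate on the old state, outer-product write.
                S[i][j] = gates[t] * S[i][j] + k[t][i] * v[t][j]
        out.append([sum(q[t][i] * S[i][j] for i in range(Dh)) for j in range(Dh)])
    return out, S


out, S = gla_reference_single(
    q=[[1.0, 0.0], [1.0, 1.0]],
    k=[[1.0, 0.0], [0.0, 1.0]],
    v=[[1.0, 2.0], [3.0, 4.0]],
    gates=[0.5, 0.5],
)
print(out)
print(S)
```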
178
+
179
+ ### RG-LRU diagonal scan (Griffin / RecurrentGemma)
180
+
181
+ The kernel handles the inner linear scan `h_t = a_t ⊙ h_{t-1} + b_t`; compute
182
+ the gate `a` and the already-gated input `b` in pure MLX (cheap, elementwise,
183
+ auto-differentiable) and pass them in. The kernel only multiplies — `a` may be
184
+ any real value, not just `(0, 1)` (negative / oscillating gates are covered by
185
+ the test suite).
186
+
187
+ ```python
188
+ import mlx.core as mx
+ from mlx_recurrence import rglru_scan, rglru_scan_with_state
189
+
190
+ B, L, D = 3, 512, 1536
191
+
192
+ a = mx.sigmoid(mx.random.normal((B, L, D))) # per-channel gate
193
+ b = mx.random.normal((B, L, D)) # gated input
194
+
195
+ y = rglru_scan(a, b) # -> [B, L, D]
196
+ y, final_state = rglru_scan_with_state(a, b) # state: [B, D]
197
+ ```
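The scan `h_t = a_t ⊙ h_{t-1} + b_t` and the idea behind chunked prefill can be illustrated without MLX. In this pure-Python sketch, `rglru_reference` and its `h0` argument are hypothetical; how the package's `*_with_state` variants accept an initial state is not shown in this README.

```python
def rglru_reference(a, b, h0=None):
    """Diagonal scan for one batch element. a, b: [L][D] nested lists.

    Returns (outputs [L][D], final state [D]). h0 is an optional carried state.
    """
    D = len(a[0])
    h = list(h0) if h0 is not None else [0.0] * D
    ys = []
    for a_t, b_t in zip(a, b):
        h = [a_t[d] * h[d] + b_t[d] for d in range(D)]
        ys.append(h)
    return ys, h


# Gates may be negative or larger than 1; the scan only multiplies and adds.
a = [[0.5, -1.0], [0.25, 2.0], [0.5, 0.5], [1.0, -0.5]]
b = [[1.0, 1.0], [2.0, 0.0], [0.0, 3.0], [4.0, 1.0]]

full_y, full_state = rglru_reference(a, b)

# Chunked prefill: scan the first half, then resume from its final state.
y1, s1 = rglru_reference(a[:2], b[:2])
y2, s2 = rglru_reference(a[2:], b[2:], h0=s1)

print(y1 + y2 == full_y, s2 == full_state)
```

Carrying the final state across chunks reproduces the single-pass result because each chunk performs the same operations in the same order.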
198
+
199
+ ### Rotational LRU (complex-diagonal scan)
200
+
201
+ Generalizes `rglru_scan` from a real diagonal gate to a complex one: each
202
+ interleaved channel pair `(u, w)` is scaled by a magnitude gate AND rotated
203
+ by an angle every step — `h_t = a_t · e^{iθ_t} · h_{t-1} + b_t` in complex
204
+ form, the eigenvalue structure of the complex LRU and S4-style oscillatory
205
+ memory. Pass `cos(θ)`/`sin(θ)` computed in MLX host code; gradients w.r.t.
206
+ the angle chain through them automatically.
207
+
208
+ ```python
209
+ import mlx.core as mx
+ from mlx_recurrence import rotlru_scan, rotlru_scan_with_state
210
+
211
+ B, L, D = 3, 512, 1536
212
+ Dp = D // 2 # channel pairs
213
+
214
+ a = mx.sigmoid(mx.random.normal((B, L, Dp))) # magnitude gate per pair
215
+ theta = mx.random.uniform(0.0, 3.14, (B, L, Dp)) # rotation per step
216
+ b = mx.random.normal((B, L, D)) # drive, pairs interleaved
217
+
218
+ y = rotlru_scan(a, mx.cos(theta), mx.sin(theta), b) # -> [B, L, D]
219
+ y, final_state = rotlru_scan_with_state(a, mx.cos(theta), mx.sin(theta), b)
220
+ ```
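The complex form described above expands into a real 2×2 rotation per channel pair. The sketch below is a pure-Python illustration for one batch element. `rotlru_reference` is a hypothetical name, and the interleaving (u at even indices, w at odd indices) and the counter-clockwise rotation sign are assumptions chosen to match `e^{iθ}`.

```python
import math


def rotlru_reference(a, cos_t, sin_t, b):
    """Assumed rotational scan. a, cos_t, sin_t: [L][Dp]; b: [L][2*Dp].

    Each pair (u, w) is rotated by theta, scaled by a, then driven by b.
    """
    Dp = len(a[0])
    h = [0.0] * (2 * Dp)
    ys = []
    for a_t, c_t, s_t, b_t in zip(a, cos_t, sin_t, b):
        nxt = []
        for p in range(Dp):
            u, w = h[2 * p], h[2 * p + 1]
            # Real and imaginary parts of a * e^{i theta} * (u + i w) + b.
            nxt.append(a_t[p] * (c_t[p] * u - s_t[p] * w) + b_t[2 * p])
            nxt.append(a_t[p] * (s_t[p] * u + c_t[p] * w) + b_t[2 * p + 1])
        h = nxt
        ys.append(h)
    return ys, h


a = [[0.9], [0.8], [0.7]]
theta = [[0.3], [1.1], [2.0]]
b = [[1.0, 0.5], [0.0, -1.0], [2.0, 0.25]]
cos_t = [[math.cos(x) for x in row] for row in theta]
sin_t = [[math.sin(x) for x in row] for row in theta]

ys, h = rotlru_reference(a, cos_t, sin_t, b)

# Same recurrence with Python complex numbers, for comparison.
z = 0j
for t in range(len(a)):
    rot = complex(math.cos(theta[t][0]), math.sin(theta[t][0]))
    z = a[t][0] * rot * z + complex(b[t][0], b[t][1])

print(h, z)
```

With `cos(θ) = 1` and `sin(θ) = 0`, each pair decouples into two independent channels sharing one gate, which is the reduction to the real diagonal scan mentioned in the parity suite.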
221
+
222
+ Validated by the parity suite (forward + every gradient vs reference,
223
+ negative gates, θ=0 reduces exactly to `rglru_scan`, isometry check) and
224
+ exercised by a 10k-step training run; microbenchmarks pending.
225
+
226
+ Every kernel ships a pure-MLX reference (`*_scan_reference`) for parity
227
+ testing and as a fallback on shapes that violate the constraints.
228
+
229
+ ## Testing
230
+
231
+ ```bash
232
+ pytest tests/ # 45 tests, ~4 s, tiny shapes
233
+ ```
234
+
235
+ - `tests/test_v2_ssd.py`, `test_v2_gla.py`, `test_v2_rglru.py`,
236
+ `test_v2_rotlru.py` — framework parity suites: forward output **and every
237
+ gradient** compared against the pure-MLX reference (two shape configs per
238
+ kernel, multi-segment, plus final-state checks). Negative-gate coverage
239
+ for `rglru`/`rotlru`; θ=0→rglru reduction and isometry checks for `rotlru`.
240
+ - `tests/test_v2_legacy_compat.py` — the legacy top-level re-exports keep
241
+ working.
242
+ - `tests/test_kernels.py`, `test_backward_metal.py` — original v0.1 suites,
243
+ unchanged.
244
+
245
+ ## Implementation details
246
+
247
+ ### The chassis pattern (shared by all v2 kernels)
248
+
249
+ **Forward:** run the recurrence once; write only the state at each segment
250
+ boundary (`seg=32` → 1/32 the state writes). The last checkpoint doubles as
251
+ the chunk's final state, enabling chunked prefill via the `*_with_state`
252
+ variants.
253
+
254
+ **Backward:** walk segments newest → oldest. For each segment, recompute its
255
+ per-timestep states from the preceding checkpoint into a small scratch buffer
256
+ (one segment's worth — stays resident in the system-level cache instead of
257
+ streaming the full history through DRAM), then run the adjoint sweep.
258
+ Cross-lane gradient reductions are fused in-kernel with `simd_sum` over
259
+ 32-lane simdgroups; the remaining sum over simdgroups is one cheap MLX
260
+ reduction. Recompute runs the same fp32 ops in the same order from the same
261
+ checkpoint, so it reproduces the forward states bit-exactly.
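The forward and backward steps above can be sketched on a single scalar channel of the diagonal recurrence `h_t = a_t * h_{t-1} + b_t`. This is a pure-Python illustration, not the Metal implementation: all function names are hypothetical, and `grad_h[t]` stands for the upstream gradient arriving at each state.

```python
def forward_checkpoints(a, b, seg):
    """Run the scan once, keeping the state only at segment boundaries."""
    assert len(a) % seg == 0, "L must be a multiple of seg"
    h, ckpts = 0.0, []
    for t in range(len(a)):
        h = a[t] * h + b[t]
        if (t + 1) % seg == 0:
            ckpts.append(h)
    return ckpts


def backward_recompute(a, b, ckpts, grad_h, seg):
    """Gradients w.r.t. a and b, walking segments newest to oldest."""
    L = len(a)
    grad_a, grad_b = [0.0] * L, [0.0] * L
    dh = 0.0  # adjoint carried back from later timesteps
    for s in reversed(range(L // seg)):
        start = s * seg
        h = ckpts[s - 1] if s > 0 else 0.0
        scratch = []  # holds h_{t-1} for each t in this segment only
        for t in range(start, start + seg):
            scratch.append(h)
            h = a[t] * h + b[t]
        for t in reversed(range(start, start + seg)):
            dh = dh + grad_h[t]              # total gradient at h_t
            grad_b[t] = dh
            grad_a[t] = dh * scratch[t - start]
            dh = dh * a[t]                   # pass back to h_{t-1}
    return grad_a, grad_b


def backward_full_history(a, b, grad_h):
    """Baseline: same gradients, but every previous state is stored."""
    L = len(a)
    prev, h = [], 0.0
    for t in range(L):
        prev.append(h)
        h = a[t] * h + b[t]
    grad_a, grad_b = [0.0] * L, [0.0] * L
    dh = 0.0
    for t in reversed(range(L)):
        dh = dh + grad_h[t]
        grad_b[t] = dh
        grad_a[t] = dh * prev[t]
        dh = dh * a[t]
    return grad_a, grad_b


a = [0.5, 0.9, -0.3, 0.8, 0.7, 0.2, 1.1, 0.6]
b = [1.0, -2.0, 0.5, 0.0, 3.0, 1.5, -1.0, 0.25]
grad_h = [1.0, 0.0, 2.0, -1.0, 0.5, 0.0, 1.0, 3.0]

ckpts = forward_checkpoints(a, b, seg=4)
same = backward_recompute(a, b, ckpts, grad_h, seg=4) == backward_full_history(a, b, grad_h)
print(len(ckpts), same)
```

The recompute path repeats the forward arithmetic from the same starting value, which is why the two backward passes can be compared with exact equality rather than a tolerance.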
262
+
263
+ ### Per-kernel thread mapping
264
+
265
+ - **SSD** — one thread per `(batch, head, channel)`; the `N`-element state
266
+ lives in registers across all `L` steps. Checkpoints laid out `[B, nSeg, H,
267
+ N, Dh]` with `Dh` fastest so simdgroup lanes read/write coalesced.
268
+ - **GLA** — one thread per `(batch·head, j)`; each thread owns one column of
269
+ the `Dh×Dh` state matrix in registers. `grad_v` is exact per-thread;
270
+ `grad_q`/`grad_k`/`grad_gates` are j-lane `simd_sum` partials.
271
+ - **RG-LRU** — one thread per `(batch, channel)` owning the scalar `h[d]`.
272
+ Diagonal state means no cross-lane reductions at all — the simplest plug-in,
273
+ and the template to copy when adding a new diagonal recurrence.
274
+ - **Rotational LRU** — one thread per `(batch, pair)` owning the `(u, w)`
275
+ register pair; the 2×2 rotation is applied in-register. Pair-diagonal, so
276
+ like RG-LRU it needs no cross-lane reductions.
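The mappings above imply a thread count per kernel that follows directly from the shapes. The helper below is a hypothetical back-of-envelope calculation using the README's example shapes. It says nothing about how threads are grouped into threadgroups, which is not described here.

```python
SIMD_WIDTH = 32  # lanes per simdgroup stated in the README


def thread_counts(B, H, Dh, D):
    """Threads implied by the one-thread-per-owner mappings listed above."""
    return {
        "ssd": B * H * Dh,       # one per (batch, head, channel)
        "gla": B * H * Dh,       # one per (batch*head, column j)
        "rglru": B * D,          # one per (batch, channel)
        "rotlru": B * (D // 2),  # one per (batch, pair)
    }


counts = thread_counts(B=3, H=12, Dh=64, D=1536)
for name, n in counts.items():
    print(name, n, "threads,", n // SIMD_WIDTH, "simdgroups")
```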
277
+
278
+ ### Legacy v0.1 kernels
279
+
280
+ The original token-loop kernels (`selective_scan_metal`, `gla_scan_metal`,
281
+ and the chunked pure-MLX fallbacks) are unchanged under
282
+ `mlx_recurrence.legacy` and re-exported at top level. They store the full
283
+ state history for the backward pass (fine for inference and short-sequence
284
+ training) and have no shape constraints. Original benchmarks (M3 Max,
285
+ seq_len=2048): 7.3×/9.1× forward speedup over the Python loop and 19×/31.8×
286
+ fwd+bwd over chunked-MLX autograd for SSM/GLA respectively; charts in
287
+ `benchmarks/`.
288
+
289
+ ## Citation
290
+
291
+ If you use mlx-recurrence in your work, please credit:
292
+
293
+ > Paul O. Derrington, Jr. — Derrington Collaborative Synthetic Intelligence Labs (D-CSIL)
294
+
295
+ ## License
296
+
297
+ MIT License — Copyright (c) 2026 Paul O. Derrington, Jr.
298
+
299
+ Matches the MLX license. See [LICENSE](LICENSE).
@@ -0,0 +1,269 @@
1
+ # mlx-recurrence
2
+
3
+ A plug-in framework of **fused Metal GPU kernels for linear recurrence on Apple
4
+ Silicon** — think *flash-linear-attention for [MLX](https://github.com/ml-explore/mlx)*.
5
+
6
+ Sequential recurrences (SSMs, gated linear attention, diagonal RNNs) are the one
7
+ thing MLX cannot fuse for free: a Python loop over `L` timesteps costs `L`
8
+ Python→Metal dispatches. These kernels collapse the entire recurrence into a
9
+ single dispatch, with a **segment-checkpoint + recompute backward** that cuts
10
+ kernel peak memory by 12–18× at the benchmarked training shapes, compared with storing the full state history.
11
+
12
+ | Kernel | Recurrence | Used by | State |
13
+ |---|---|---|---|
14
+ | `ssd_scan` | Mamba-2-style head-wise SSD selective scan | Mamba-2 / SSM hybrids | `[B, H, Dh, N]` |
15
+ | `gla_scan` | Gated Linear Attention (scalar forget gate, outer-product write) | GLA / linear-attention hybrids | `[B, H, Dh, Dh]` |
16
+ | `rglru_scan` | RG-LRU diagonal scan | Griffin / RecurrentGemma-style | `[B, D]` |
17
+ | `rotlru_scan` | Rotational LRU: complex-diagonal scan (magnitude gate + per-step rotation of 2D channel pairs) | Complex-LRU / S4-style oscillatory memory | `[B, D]` (interleaved pairs) |
18
+
19
+ Each kernel is a self-contained plug-in on a shared chassis
20
+ (`mlx_recurrence._chassis`) that provides the checkpoint+recompute pattern,
21
+ shape validation, and a parity-test helper — adding a new recurrence means
22
+ writing one Metal source pair and its VJP wiring, not rebuilding the
23
+ infrastructure. The original v0.1 kernels remain available under
24
+ `mlx_recurrence.legacy` (and re-exported at top level) for backwards
25
+ compatibility.
26
+
27
+ ## Validated in production
28
+
29
+ These are not microbenchmark-only kernels. The v2 SSD and GLA kernels were
30
+ dropped into a live multi-week D-CSIL SSM+GLA hybrid training run mid-flight
31
+ (checkpoint pause → parity gates → resume), on an M3 Max, bf16, batch 3,
32
+ L=512:
33
+
34
+ | Gate | v0.1-style (full state history) | v2 (checkpoint + recompute) |
35
+ |---|---|---|
36
+ | Kernel parity (fwd + every gradient) | — | ~1e-7 rel, all gates PASS |
37
+ | Peak training memory | 23.88 GB | **10.34 GB** |
38
+ | Sustained tokens/sec | ~1,074 | **~1,481–1,540 (≈1.4×)** |
39
+ | Loss continuity across the swap | — | clean (no NaN/inf, same loss band) |
40
+
41
+ Full report: [`docs/validation/V3_VALIDATION_REPORT_20260610.md`](docs/validation/V3_VALIDATION_REPORT_20260610.md)
42
+ (the consuming training repo names these kernels "v3" in its shim — same code).
43
+ Note the baseline above is a run *already using fused v0.1-style kernels* —
44
+ the gains over having no custom kernels at all are far larger (next section).
45
+
46
+ ## Benchmarks
47
+
48
+ Two baselines matter, and they answer different questions:
49
+
50
+ 1. **vs. no custom kernels at all** — the Python per-step loop or chunked-MLX
51
+ fallback a user would otherwise write. This is the speedup you get by
52
+ adopting the package.
53
+ 2. **v2 vs. v0.1-style full-history kernels** — what the
54
+ checkpoint+recompute redesign adds on top, mainly for training memory.
55
+
56
+ ### 1. Fused kernels vs. no custom kernels (M3 Max, seq_len=2048)
57
+
58
+ | Pass | SSM | GLA |
59
+ |---|---|---|
60
+ | Forward — fused kernel vs Python per-step loop | **7.3×** | **9.1×** |
61
+ | Forward + backward — fused VJP vs chunked-MLX autograd | **19.0×** | **31.8×** |
62
+
63
+ (Measured for the v0.1 release; charts in `benchmarks/`. Without fused
64
+ kernels, training these recurrences on Apple Silicon is impractical above
65
+ seq_len ≈ 512 — the backward pass is the killer.)
66
+
67
+ The v2 kernels measured faster still at training shapes (next table), so
68
+ vs-no-kernels speedups for v2 are expected to be at least this large. A
69
+ direct single-shape v2-vs-no-kernels measurement is planned once the current
70
+ production run frees the GPU, and will replace this estimate.
71
+
72
+ ### 2. v2 vs. v0.1-style full-history kernels (training shapes: B=3, L=512, H=12, Dh=64)
73
+
74
+ | Kernel | fwd | fwd + bwd | peak memory |
75
+ |---|---|---|---|
76
+ | SSD, full-history baseline | 3.14 ms | 32.22 ms | 1,792 MB |
77
+ | **SSD v2** | **2.30 ms** | **17.34 ms (1.86×)** | **145 MB (12×)** |
78
+ | GLA, full-history baseline | 2.10 ms | 17.92 ms | 1,477 MB |
79
+ | **GLA v2** | **1.41 ms** | **12.06 ms (1.49×)** | **81 MB (18×)** |
80
+
81
+ The memory column is the one that matters for training: the baseline stores
82
+ every per-timestep state for the backward pass; the v2 kernels store only
83
+ segment-boundary checkpoints (1/32 of the writes) and recompute each segment
84
+ into a small scratch buffer that stays cache-resident during the adjoint sweep.
85
+
86
+ ## Installation
87
+
88
+ ```bash
89
+ # v0.3.0 framework (main branch)
90
+ pip install git+https://github.com/D-CSIL/mlx-recurrence.git
91
+
92
+ # or from a clone
93
+ git clone https://github.com/D-CSIL/mlx-recurrence.git
94
+ cd mlx-recurrence && pip install -e .
95
+ ```
96
+
97
+ (PyPI release planned; install from GitHub until then. The legacy v0.1-era
98
+ kernels need no separate install — they ship inside this package under
99
+ `mlx_recurrence.legacy` with top-level re-exports.)
100
+
101
+ Requires: Python >= 3.10, MLX >= 0.22.0, Apple Silicon Mac (Metal GPU).
102
+
103
+ ## Usage (v2 kernels)
104
+
105
+ All v2 kernels are fully differentiable (`mx.grad` / `mx.value_and_grad`
106
+ work through them via custom VJPs), keep **fp32 state and accumulation
107
+ regardless of input dtype** (bf16 inputs widen implicitly), and share two
108
+ shape constraints from the checkpoint + simd-reduction pattern:
109
+
110
+ ```
111
+ L % seg == 0 # sequence tiles into segments (seg defaults to 32)
112
+ lane_dim % 32 == 0 # Dh for ssd/gla, D for rglru (32-lane simdgroups)
113
+ ```
114
+
115
+ ### SSD selective scan (Mamba-2 style)
116
+
117
+ ```python
118
+ import mlx.core as mx
119
+ from mlx_recurrence import ssd_scan, ssd_scan_with_state
120
+
121
+ B, L, H, Dh, N = 3, 512, 12, 64, 16
122
+
123
+ u = mx.random.normal((B, L, H, Dh)) # input
124
+ delta = mx.abs(mx.random.normal((B, L, H))) * 0.1 + 0.01 # per-token step size
125
+ B_in = mx.random.normal((B, L, H, N)) # input projection
126
+ C_in = mx.random.normal((B, L, H, N)) # output projection
127
+ A_neg = -mx.exp(mx.random.normal((H, N))) # decay rates, < 0
128
+
129
+ y = ssd_scan(u, delta, B_in, C_in, A_neg) # -> [B, L, H, Dh]
130
+ y, final_state = ssd_scan_with_state(u, delta, B_in, C_in, A_neg) # chunked prefill
131
+ ```
132
+
133
+ ### GLA recurrence
134
+
135
+ ```python
136
+ import mlx.core as mx
+ from mlx_recurrence import gla_scan, gla_scan_with_state
137
+
138
+ B, L, H, Dh = 3, 512, 12, 64
139
+
140
+ q = mx.random.normal((B, L, H, Dh)) * (Dh ** -0.5) # pre-scaled / post-RoPE
141
+ k = mx.random.normal((B, L, H, Dh))
142
+ v = mx.random.normal((B, L, H, Dh))
143
+ gates = mx.sigmoid(mx.random.normal((B, L, H))) # scalar forget gate, (0,1)
144
+
145
+ o = gla_scan(q, k, v, gates) # -> [B, L, H, Dh]
146
+ o, final_state = gla_scan_with_state(q, k, v, gates) # state: [B, H, Dh, Dh]
147
+ ```
148
+
149
+ ### RG-LRU diagonal scan (Griffin / RecurrentGemma)
150
+
151
+ The kernel handles the inner linear scan `h_t = a_t ⊙ h_{t-1} + b_t`; compute
152
+ the gate `a` and the already-gated input `b` in pure MLX (cheap, elementwise,
153
+ auto-differentiable) and pass them in. The kernel only multiplies — `a` may be
154
+ any real value, not just `(0, 1)` (negative / oscillating gates are covered by
155
+ the test suite).
156
+
157
+ ```python
158
+ import mlx.core as mx
+ from mlx_recurrence import rglru_scan, rglru_scan_with_state
159
+
160
+ B, L, D = 3, 512, 1536
161
+
162
+ a = mx.sigmoid(mx.random.normal((B, L, D))) # per-channel gate
163
+ b = mx.random.normal((B, L, D)) # gated input
164
+
165
+ y = rglru_scan(a, b) # -> [B, L, D]
166
+ y, final_state = rglru_scan_with_state(a, b) # state: [B, D]
167
+ ```
168
+
169
+ ### Rotational LRU (complex-diagonal scan)
170
+
171
+ Generalizes `rglru_scan` from a real diagonal gate to a complex one: each
172
+ interleaved channel pair `(u, w)` is scaled by a magnitude gate AND rotated
173
+ by an angle every step — `h_t = a_t · e^{iθ_t} · h_{t-1} + b_t` in complex
174
+ form, the eigenvalue structure of the complex LRU and S4-style oscillatory
175
+ memory. Pass `cos(θ)`/`sin(θ)` computed in MLX host code; gradients w.r.t.
176
+ the angle chain through them automatically.
177
+
178
+ ```python
179
+ import mlx.core as mx
+ from mlx_recurrence import rotlru_scan, rotlru_scan_with_state
180
+
181
+ B, L, D = 3, 512, 1536
182
+ Dp = D // 2 # channel pairs
183
+
184
+ a = mx.sigmoid(mx.random.normal((B, L, Dp))) # magnitude gate per pair
185
+ theta = mx.random.uniform(0.0, 3.14, (B, L, Dp)) # rotation per step
186
+ b = mx.random.normal((B, L, D)) # drive, pairs interleaved
187
+
188
+ y = rotlru_scan(a, mx.cos(theta), mx.sin(theta), b) # -> [B, L, D]
189
+ y, final_state = rotlru_scan_with_state(a, mx.cos(theta), mx.sin(theta), b)
190
+ ```
191
+
192
+ Validated by the parity suite (forward + every gradient vs reference,
193
+ negative gates, θ=0 reduces exactly to `rglru_scan`, isometry check) and
194
+ exercised by a 10k-step training run; microbenchmarks pending.
195
+
196
+ Every kernel ships a pure-MLX reference (`*_scan_reference`) for parity
197
+ testing and as a fallback on shapes that violate the constraints.
198
+
199
+ ## Testing
200
+
201
+ ```bash
202
+ pytest tests/ # 45 tests, ~4 s, tiny shapes
203
+ ```
204
+
205
+ - `tests/test_v2_ssd.py`, `test_v2_gla.py`, `test_v2_rglru.py`,
206
+ `test_v2_rotlru.py` — framework parity suites: forward output **and every
207
+ gradient** compared against the pure-MLX reference (two shape configs per
208
+ kernel, multi-segment, plus final-state checks). Negative-gate coverage
209
+ for `rglru`/`rotlru`; θ=0→rglru reduction and isometry checks for `rotlru`.
210
+ - `tests/test_v2_legacy_compat.py` — the legacy top-level re-exports keep
211
+ working.
212
+ - `tests/test_kernels.py`, `test_backward_metal.py` — original v0.1 suites,
213
+ unchanged.
214
+
215
+ ## Implementation details
216
+
217
+ ### The chassis pattern (shared by all v2 kernels)
218
+
219
+ **Forward:** run the recurrence once; write only the state at each segment
220
+ boundary (`seg=32` → 1/32 the state writes). The last checkpoint doubles as
221
+ the chunk's final state, enabling chunked prefill via the `*_with_state`
222
+ variants.
223
+
224
+ **Backward:** walk segments newest → oldest. For each segment, recompute its
225
+ per-timestep states from the preceding checkpoint into a small scratch buffer
226
+ (one segment's worth — stays resident in the system-level cache instead of
227
+ streaming the full history through DRAM), then run the adjoint sweep.
228
+ Cross-lane gradient reductions are fused in-kernel with `simd_sum` over
229
+ 32-lane simdgroups; the remaining sum over simdgroups is one cheap MLX
230
+ reduction. Recompute runs the same fp32 ops in the same order from the same
231
+ checkpoint, so it reproduces the forward states bit-exactly.
232
+
233
+ ### Per-kernel thread mapping
234
+
235
+ - **SSD** — one thread per `(batch, head, channel)`; the `N`-element state
236
+ lives in registers across all `L` steps. Checkpoints laid out `[B, nSeg, H,
237
+ N, Dh]` with `Dh` fastest so simdgroup lanes read/write coalesced.
238
+ - **GLA** — one thread per `(batch·head, j)`; each thread owns one column of
239
+ the `Dh×Dh` state matrix in registers. `grad_v` is exact per-thread;
240
+ `grad_q`/`grad_k`/`grad_gates` are j-lane `simd_sum` partials.
241
+ - **RG-LRU** — one thread per `(batch, channel)` owning the scalar `h[d]`.
242
+ Diagonal state means no cross-lane reductions at all — the simplest plug-in,
243
+ and the template to copy when adding a new diagonal recurrence.
244
+ - **Rotational LRU** — one thread per `(batch, pair)` owning the `(u, w)`
245
+ register pair; the 2×2 rotation is applied in-register. Pair-diagonal, so
246
+ like RG-LRU it needs no cross-lane reductions.
247
+
248
+ ### Legacy v0.1 kernels
249
+
250
+ The original token-loop kernels (`selective_scan_metal`, `gla_scan_metal`,
251
+ and the chunked pure-MLX fallbacks) are unchanged under
252
+ `mlx_recurrence.legacy` and re-exported at top level. They store the full
253
+ state history for the backward pass (fine for inference and short-sequence
254
+ training) and have no shape constraints. Original benchmarks (M3 Max,
255
+ seq_len=2048): 7.3×/9.1× forward speedup over the Python loop and 19×/31.8×
256
+ fwd+bwd over chunked-MLX autograd for SSM/GLA respectively; charts in
257
+ `benchmarks/`.
258
+
259
+ ## Citation
260
+
261
+ If you use mlx-recurrence in your work, please credit:
262
+
263
+ > Paul O. Derrington, Jr. — Derrington Collaborative Synthetic Intelligence Labs (D-CSIL)
264
+
265
+ ## License
266
+
267
+ MIT License — Copyright (c) 2026 Paul O. Derrington, Jr.
268
+
269
+ Matches the MLX license. See [LICENSE](LICENSE).
@@ -0,0 +1,90 @@
1
+ """mlx_recurrence — A plug-in framework for linear-recurrence Metal kernels
2
+ on Apple Silicon ("flash-linear-attention for MLX").
3
+
4
+ Each kernel is a self-contained plug-in built on a shared chassis
5
+ (:mod:`mlx_recurrence._chassis`) that supplies the segment-checkpoint +
6
+ recompute backward pattern, shape validation, and a parity-test helper. The
7
+ Metal source for each recurrence stays in its own module, readable per-kernel.
8
+
9
+ v2 kernels (checkpoint + recompute, fused simd reductions, chunked-prefill
10
+ final-state variants):
11
+ ssd — Mamba-2-style head-wise SSD selective scan
12
+ gla — Gated Linear Attention recurrence
13
+ rglru — RG-LRU diagonal recurrence (Griffin / RecurrentGemma)
14
+ rotlru — rotational LRU: complex-diagonal scan over (u, w) pairs
15
+
16
+ The original v0.1 token-loop kernels remain available under
17
+ ``mlx_recurrence.legacy`` and are re-exported at top level for backwards
18
+ compatibility (``selective_scan_metal``, ``gla_scan_metal``, ...).
19
+ """
20
+
21
+ # --- v2 chassis-based kernels ---------------------------------------------
22
+ from .ssd import (
23
+ ssd_scan,
24
+ ssd_scan_with_state,
25
+ ssd_scan_reference,
26
+ )
27
+ from .gla import (
28
+ gla_scan,
29
+ gla_scan_with_state,
30
+ gla_scan_reference,
31
+ )
32
+ from .rglru import (
33
+ rglru_scan,
34
+ rglru_scan_with_state,
35
+ rglru_scan_reference,
36
+ )
37
+ from .rotlru import (
38
+ rotlru_scan,
39
+ rotlru_scan_with_state,
40
+ rotlru_scan_reference,
41
+ )
42
+
43
+ # --- shared chassis (public for building new plug-in kernels) -------------
44
+ from ._chassis import (
45
+ DEFAULT_SEG,
46
+ get_or_build_kernel,
47
+ check_segment_shape,
48
+ parity_check,
49
+ )
50
+
51
+ # --- legacy v0.1 kernels (backwards compatibility) ------------------------
52
+ from . import legacy
53
+ from .legacy import (
54
+ selective_scan_metal,
55
+ selective_scan_chunked,
56
+ gla_scan_metal,
57
+ gla_scan_chunked,
58
+ )
59
+
60
+ __all__ = [
61
+ # v2 SSD
62
+ "ssd_scan",
63
+ "ssd_scan_with_state",
64
+ "ssd_scan_reference",
65
+ # v2 GLA
66
+ "gla_scan",
67
+ "gla_scan_with_state",
68
+ "gla_scan_reference",
69
+ # v2 RG-LRU
70
+ "rglru_scan",
71
+ "rglru_scan_with_state",
72
+ "rglru_scan_reference",
73
+ # v2 rotational LRU
74
+ "rotlru_scan",
75
+ "rotlru_scan_with_state",
76
+ "rotlru_scan_reference",
77
+ # chassis
78
+ "DEFAULT_SEG",
79
+ "get_or_build_kernel",
80
+ "check_segment_shape",
81
+ "parity_check",
82
+ # legacy subpackage + re-exports
83
+ "legacy",
84
+ "selective_scan_metal",
85
+ "selective_scan_chunked",
86
+ "gla_scan_metal",
87
+ "gla_scan_chunked",
88
+ ]
89
+
90
+ __version__ = "0.3.0"