mlx-recurrence 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,90 @@
1
+ """mlx_recurrence — A plug-in framework for linear-recurrence Metal kernels
2
+ on Apple Silicon ("flash-linear-attention for MLX").
3
+
4
+ Each kernel is a self-contained plug-in built on a shared chassis
5
+ (:mod:`mlx_recurrence._chassis`) that supplies the segment-checkpoint +
6
+ recompute backward pattern, shape validation, and a parity-test helper. The
7
+ Metal source for each recurrence stays in its own module, readable per-kernel.
8
+
9
+ v2 kernels (checkpoint + recompute, fused simd reductions, chunked-prefill
10
+ final-state variants):
11
+ ssd — Mamba-2-style head-wise SSD selective scan
12
+ gla — Gated Linear Attention recurrence
13
+ rglru — RG-LRU diagonal recurrence (Griffin / RecurrentGemma)
14
+ rotlru — rotational LRU: complex-diagonal scan over (u, w) pairs
15
+
16
+ The original v0.1 token-loop kernels remain available under
17
+ ``mlx_recurrence.legacy`` and are re-exported at top level for backwards
18
+ compatibility (``selective_scan_metal``, ``gla_scan_metal``, ...).
19
+ """
20
+
21
+ # --- v2 chassis-based kernels ---------------------------------------------
22
+ from .ssd import (
23
+ ssd_scan,
24
+ ssd_scan_with_state,
25
+ ssd_scan_reference,
26
+ )
27
+ from .gla import (
28
+ gla_scan,
29
+ gla_scan_with_state,
30
+ gla_scan_reference,
31
+ )
32
+ from .rglru import (
33
+ rglru_scan,
34
+ rglru_scan_with_state,
35
+ rglru_scan_reference,
36
+ )
37
+ from .rotlru import (
38
+ rotlru_scan,
39
+ rotlru_scan_with_state,
40
+ rotlru_scan_reference,
41
+ )
42
+
43
+ # --- shared chassis (public for building new plug-in kernels) -------------
44
+ from ._chassis import (
45
+ DEFAULT_SEG,
46
+ get_or_build_kernel,
47
+ check_segment_shape,
48
+ parity_check,
49
+ )
50
+
51
+ # --- legacy v0.1 kernels (backwards compatibility) ------------------------
52
+ from . import legacy
53
+ from .legacy import (
54
+ selective_scan_metal,
55
+ selective_scan_chunked,
56
+ gla_scan_metal,
57
+ gla_scan_chunked,
58
+ )
59
+
60
# Names exposed by ``from mlx_recurrence import *``. Every entry is bound by
# one of the import statements earlier in this module.
__all__ = [
    # v2 SSD
    "ssd_scan",
    "ssd_scan_with_state",
    "ssd_scan_reference",
    # v2 GLA
    "gla_scan",
    "gla_scan_with_state",
    "gla_scan_reference",
    # v2 RG-LRU
    "rglru_scan",
    "rglru_scan_with_state",
    "rglru_scan_reference",
    # v2 rotational LRU
    "rotlru_scan",
    "rotlru_scan_with_state",
    "rotlru_scan_reference",
    # chassis
    "DEFAULT_SEG",
    "get_or_build_kernel",
    "check_segment_shape",
    "parity_check",
    # legacy subpackage + re-exports
    "legacy",
    "selective_scan_metal",
    "selective_scan_chunked",
    "gla_scan_metal",
    "gla_scan_chunked",
]

# Package version string.
__version__ = "0.3.0"
@@ -0,0 +1,213 @@
1
+ """_chassis.py — Shared infrastructure for chassis-based recurrence kernels.
2
+
3
+ This module factors out the machinery common to every v2 plug-in kernel
4
+ (SSD, GLA, RG-LRU, rotational LRU) so each kernel file only has to supply its own Metal
5
+ source strings and gradient wiring. It deliberately does NOT abstract the
6
+ Metal source itself: every recurrence has a different state shape and
7
+ update rule, and the kernel bodies are meant to stay readable per-kernel.
8
+
9
+ What lives here
10
+ ---------------
11
+ 1. Kernel cache / builder around ``mx.fast.metal_kernel`` so each unique
12
+ shape configuration is JIT-compiled exactly once per process.
13
+ 2. Shape validation shared by the segment-checkpoint + recompute pattern:
14
+ the sequence length must tile evenly into segments (``L % seg == 0``)
15
+ and the simd-reduced lane dimension must be a multiple of 32.
16
+ 3. A reusable parity-test helper that compares forward output plus every
17
+ gradient against a pure-MLX reference and reports max abs / rel diffs.
18
+
19
+ The segment-checkpoint + recompute pattern (shared design)
20
+ ----------------------------------------------------------
21
+ Every kernel here follows the same playbook, tuned to the Apple Silicon
22
+ unified-memory hierarchy:
23
+
24
+ Forward: run the recurrence once, write only the state at each
25
+ segment boundary -> ``h_ckpt`` (SEG=32 => ~1/32 the writes of
26
+ saving every timestep). The last checkpoint doubles as the
27
+ chunk's final state, enabling chunked prefill.
28
+
29
+ Backward: walk segments newest -> oldest. For each segment, recompute
30
+ its per-timestep states from the preceding checkpoint into a
31
+ small scratch buffer (one segment's worth, stays resident in
32
+ the system-level cache instead of streaming the full state
33
+ history through DRAM), then run the adjoint sweep over that
34
+ segment. Cross-lane gradient reductions are fused in-kernel
35
+ with ``simd_sum`` over 32-lane simdgroups; the remaining
36
+ sum-over-simdgroups is a single cheap MLX reduction.
37
+
38
+ All kernels keep fp32 state and accumulation regardless of input dtype,
39
+ and reproduce the forward states bit-exactly on recompute (same fp32 ops,
40
+ same order, from the same checkpoint).
41
+ """
42
+
43
+ from __future__ import annotations
44
+
45
+ import mlx.core as mx
46
+
47
# Default segment length for the checkpoint+recompute pattern. SEG=32 is a
# sweet spot on M3 Max: matches the 32-lane simdgroup width used by the
# fused reductions and keeps the per-segment scratch buffer small enough to
# stay SLC-resident at training shapes.
# Kernel entry points (e.g. ``gla_scan``) use this as the default for ``seg``.
DEFAULT_SEG = 32

# Simd lane width on Apple GPUs. Lane dimensions reduced with simd_sum must
# be a multiple of this; ``check_segment_shape`` enforces the constraint.
SIMD_WIDTH = 32
56
+
57
+
58
+ # ---------------------------------------------------------------------------
59
+ # Kernel cache / builder
60
+ # ---------------------------------------------------------------------------
61
+
62
# Process-wide registry of kernel objects, keyed by kernel name only.
_kernel_cache: dict = {}


def get_or_build_kernel(name, input_names, output_names, source, header=""):
    """Return the Metal kernel registered under ``name``, creating it on first use.

    The registry is keyed on ``name`` alone: a later call that reuses a name
    gets the kernel built by the first call, even if ``source`` differs. The
    caller must therefore encode every shape/template constant that is baked
    into ``source`` in ``name`` (e.g.
    ``f"ssd_fwd_{B}_{L}_{H}_{Dh}_{N}_{seg}"``).

    Args:
        name: unique kernel name; doubles as the cache key.
        input_names: names of the kernel's input buffers.
        output_names: names of the kernel's output buffers.
        source: kernel BODY only (no ``kernel void`` signature) — MLX
            generates the signature from ``input_names`` / ``output_names``.
        header: helper functions / includes placed ahead of the kernel.

    Returns:
        The object returned by ``mx.fast.metal_kernel`` for this name.
    """
    try:
        # Fast path: this name has been built before.
        return _kernel_cache[name]
    except KeyError:
        pass

    built = mx.fast.metal_kernel(
        name=name,
        input_names=input_names,
        output_names=output_names,
        source=source,
        header=header,
    )
    _kernel_cache[name] = built
    return built
86
+
87
+
88
+ # ---------------------------------------------------------------------------
89
+ # Shape validation
90
+ # ---------------------------------------------------------------------------
91
+
92
def check_segment_shape(L, seg, lane_dim, lane_name="lane dimension"):
    """Validate the constraints of the segment-checkpoint + simd-reduce pattern.

    The checks are:

    * ``seg`` is positive,
    * ``L`` is non-negative and tiles evenly into segments (``L % seg == 0``),
    * ``lane_dim`` is a positive multiple of ``SIMD_WIDTH``.

    Args:
        L: sequence length.
        seg: segment length for checkpointing.
        lane_dim: the dimension mapped to 32-lane simdgroups (must tile by 32).
        lane_name: human-readable name of ``lane_dim`` for the error message.

    Raises:
        ValueError: if any constraint is violated.
    """
    if seg <= 0:
        raise ValueError(f"seg must be positive, got {seg}")
    # Python's modulo accepts negative operands (-32 % 32 == 0), so the
    # divisibility test alone would let a negative length through.
    if L < 0:
        raise ValueError(f"sequence length L must be non-negative, got {L}")
    if L % seg != 0:
        raise ValueError(
            f"sequence length L={L} must be divisible by seg={seg} "
            f"(segment-checkpoint pattern tiles L into L/seg segments)"
        )
    # 0 % SIMD_WIDTH == 0, so an empty (or negative) lane dimension must be
    # rejected explicitly: there would be no simdgroup to reduce over.
    if lane_dim <= 0:
        raise ValueError(
            f"{lane_name}={lane_dim} must be a positive multiple of {SIMD_WIDTH} "
            f"(fused gradient reductions use {SIMD_WIDTH}-lane simdgroups)"
        )
    if lane_dim % SIMD_WIDTH != 0:
        raise ValueError(
            f"{lane_name}={lane_dim} must be a multiple of {SIMD_WIDTH} "
            f"(fused gradient reductions use {SIMD_WIDTH}-lane simdgroups)"
        )
116
+
117
+
118
+ # ---------------------------------------------------------------------------
119
+ # Parity-test helper
120
+ # ---------------------------------------------------------------------------
121
+
122
def parity_check(
    kernel_fn,
    reference_fn,
    inputs,
    arg_names,
    grad_argnums,
    *,
    w_out=None,
    y_tol=1e-3,
    grad_rtol=1e-3,
    label="",
    verbose=True,
):
    """Compare a kernel against a pure-MLX reference: forward + all grads.

    Both ``kernel_fn`` and ``reference_fn`` take the positional ``inputs``
    and return the forward output ``y``. This helper builds a scalar loss
    ``sum(y * w_out)`` and compares ``mx.grad`` of that loss w.r.t. every
    argument in ``grad_argnums``.

    The forward check is absolute (``max|y_k - y_r| < y_tol``); each gradient
    check is relative to the largest reference magnitude
    (``max|g_k - g_r| / (max|g_r| + 1e-8) < grad_rtol``).

    Args:
        kernel_fn: the kernel under test, ``fn(*inputs) -> y``.
        reference_fn: pure-MLX reference, ``fn(*inputs) -> y``.
        inputs: tuple/list of input arrays (positional).
        arg_names: names for each input (for readable output), same length
            as ``inputs``.
        grad_argnums: argument index, or sequence of indices, to
            differentiate.
        w_out: output weighting for the scalar loss. Defaults to a
            fixed-seed random tensor shaped like ``y``; it is drawn from an
            explicit PRNG key so the global MLX random state is untouched.
        y_tol: absolute tolerance on the forward output diff.
        grad_rtol: relative tolerance on each gradient diff.
        label: prefix printed before results.
        verbose: print per-tensor diffs.

    Returns:
        (ok: bool, report: dict) where ``report`` maps each compared tensor
        name (``"y"`` and ``"grad_<name>"``) to
        ``{"abs": max_abs_diff, "rel": max_rel_diff}``.
    """
    inputs = list(inputs)

    # ``mx.grad`` accepts a bare int as well as a sequence of ints; normalise
    # so the name lookup and the zip over gradients work in both cases.
    if isinstance(grad_argnums, int):
        argnums = [grad_argnums]
    else:
        argnums = list(grad_argnums)
    # Resolve names before any computation so a bad index fails fast.
    grad_names = [arg_names[i] for i in argnums]

    y_k = kernel_fn(*inputs)
    y_r = reference_fn(*inputs)
    mx.eval(y_k, y_r)

    if w_out is None:
        # Use an explicit key rather than mx.random.seed(): reseeding would
        # clobber the caller's global RNG state as a hidden side effect.
        w_out = mx.random.normal(y_r.shape, key=mx.random.key(1234))
        mx.eval(w_out)

    def loss_kernel(*args):
        return mx.sum(kernel_fn(*args) * w_out)

    def loss_ref(*args):
        return mx.sum(reference_fn(*args) * w_out)

    report = {}
    ok = True

    # Forward comparison. A NaN diff makes the ``<`` test False, i.e. a fail.
    y_abs = float(mx.max(mx.abs(y_k - y_r)))
    y_scale = float(mx.max(mx.abs(y_r))) + 1e-8
    y_rel = y_abs / y_scale
    report["y"] = {"abs": y_abs, "rel": y_rel}
    ok = ok and (y_abs < y_tol)

    g_k = mx.grad(loss_kernel, argnums=argnums)(*inputs)
    g_r = mx.grad(loss_ref, argnums=argnums)(*inputs)
    mx.eval(g_k, g_r)

    # A single differentiated argument may come back as a bare array.
    if not isinstance(g_k, (tuple, list)):
        g_k = (g_k,)
        g_r = (g_r,)

    for name, gk, gr in zip(grad_names, g_k, g_r):
        abs_diff = float(mx.max(mx.abs(gk - gr)))
        scale = float(mx.max(mx.abs(gr))) + 1e-8
        rel = abs_diff / scale
        report[f"grad_{name}"] = {"abs": abs_diff, "rel": rel}
        ok = ok and (rel < grad_rtol)

    if verbose:
        prefix = f"{label} " if label else ""
        print(f"{prefix}y max|diff| = {y_abs:.2e} (rel {y_rel:.2e})")
        for name in grad_names:
            r = report[f"grad_{name}"]
            print(
                f"{prefix}grad_{name:<8} max|diff| = {r['abs']:.2e}"
                f" (rel {r['rel']:.2e})"
            )
        print(f"{prefix}-> {'PASS' if ok else 'FAIL'}")

    return ok, report
mlx_recurrence/gla.py ADDED
@@ -0,0 +1,346 @@
1
+ """gla.py — Gated Linear Attention recurrence (checkpoint + recompute).
2
+
3
+ v2 chassis port of the GLA recurrence kernel. Same segment-checkpoint +
4
+ recompute backward pattern as :mod:`mlx_recurrence.ssd`, applied to the
5
+ matrix-valued GLA state.
6
+
7
+ Recurrence
8
+ ----------
9
+ Per head ``head`` the state is the ``Dh x Dh`` matrix ``h[b, head, i, j]``
10
+ with a single scalar forget gate per token::
11
+
12
+ h[i, j] = gate[b,t,head] * h[i, j] + k[b,t,head,i] * v[b,t,head,j]
13
+ o[b,t,head,j] = sum_i q[b,t,head,i] * h[i, j]
14
+
15
+ i.e. ``h_t = gate_t * h_{t-1} + k_t (outer) v_t`` and ``o_t = q_t @ h_t``
16
+ (output uses the post-update state). ``gate`` is typically a sigmoid output
17
+ in ``(0, 1)``; ``q`` is assumed pre-scaled / post-RoPE.
18
+
19
+ The conceptual state tensor is ``[B, H, Dh, Dh]``. The checkpoint is laid
20
+ out ``[B, nSeg, H, Dh, Dh]`` with ``j`` fastest-moving so the 32 lanes of a
21
+ simdgroup own 32 contiguous ``j`` columns — coalesced reads/writes.
22
+
23
+ Numerics: fp32 state and accumulation regardless of input dtype; identical
24
+ update order (output uses post-update ``h``) and gradient formulas to the
25
+ validated D-CSIL-3 training kernel. grad_v is exact per-thread; grad_q,
26
+ grad_k, grad_gates are fused in-kernel via 32-lane ``simd_sum``.
27
+
28
+ Constraints:
29
+ L % seg == 0
30
+ Dh % 32 == 0 (Dh is the simd-reduced lane dimension)
31
+
32
+ Public API:
33
+ gla_scan(q, k, v, gates, seg=32) -> y
34
+ gla_scan_with_state(q, k, v, gates, seg=32) -> (y, final_state)
35
+ gla_scan_reference(q, k, v, gates) -> y (pure MLX)
36
+ """
37
+
38
+ from __future__ import annotations
39
+
40
+ import mlx.core as mx
41
+
42
+ from ._chassis import DEFAULT_SEG, get_or_build_kernel, check_segment_shape
43
+
44
+
45
+ # ---------------------------------------------------------------------------
46
+ # Metal forward: GLA scan with segment checkpoints (no full state history)
47
+ # ---------------------------------------------------------------------------
48
+
49
def _gla_forward_kernel(q, k, v, gates, seg):
    """Forward GLA scan writing only segment-boundary state.

    One thread per (batch*head, j) owns the ``h[:, j]`` state column in a
    thread-local ``float h[Dh]`` array across all L timesteps. The state is
    zero-initialised on every call; there is no initial-state input.

    All dimensions are read from ``q.shape``; ``k``, ``v`` and ``gates`` are
    flattened and indexed with those same dimensions, so they must match the
    shapes below (this function does not check them). ``seg`` is not
    validated here either.

    Args:
        q, k, v: [B, L, H, Dh]
        gates:   [B, L, H]
        seg:     segment length (L % seg == 0)

    Returns:
        y:       [B, L, H, Dh], float32
        h_ckpt:  [B, nSeg, H, Dh, Dh] state at the END of each segment,
                 float32
    """
    B_batch, L, H, Dh = q.shape
    n_seg = L // seg

    # Every shape constant is baked into the Metal source as a literal, which
    # is why the kernel name below encodes B, L, H, Dh and seg.
    source = f"""
        uint j = thread_position_in_grid.x;
        uint bh = thread_position_in_grid.y;
        if (j >= {Dh}u || bh >= {B_batch * H}u) return;

        uint b = bh / {H}u;
        uint head = bh % {H}u;

        float h[{Dh}];
        for (int i = 0; i < {Dh}; i++) h[i] = 0.0f;

        for (int t = 0; t < {L}; t++) {{
            int g_idx = (b * {L} + t) * {H} + head;
            int kv_base = ((b * {L} + t) * {H} + head) * {Dh};

            float gate = (float)gates[g_idx];
            float v_j = (float)v[kv_base + j];
            float o_j = 0.0f;

            for (int i = 0; i < {Dh}; i++) {{
                h[i] = gate * h[i] + (float)k[kv_base + i] * v_j;
                o_j += (float)q[kv_base + i] * h[i];
            }}
            output[kv_base + j] = o_j;

            // checkpoint at end of each segment (j fastest -> coalesced)
            if (((t + 1) % {seg}) == 0) {{
                int s = t / {seg};
                for (int i = 0; i < {Dh}; i++) {{
                    int ck_idx = ((((b * {n_seg} + s) * {H} + head) * {Dh} + i) * {Dh} + j);
                    h_ckpt[ck_idx] = h[i];
                }}
            }}
        }}
    """

    kernel = get_or_build_kernel(
        f"gla_fwd_{B_batch}_{L}_{H}_{Dh}_{seg}",
        input_names=["q", "k", "v", "gates"],
        output_names=["output", "h_ckpt"],
        source=source,
    )

    # Inputs and outputs are passed as flat 1-D buffers; the kernel does its
    # own index arithmetic. Grid is (Dh, B*H): x selects column j, y selects
    # the (batch, head) pair.
    results = kernel(
        inputs=[q.reshape(-1), k.reshape(-1), v.reshape(-1), gates.reshape(-1)],
        output_shapes=[
            (B_batch * L * H * Dh,),
            (B_batch * n_seg * H * Dh * Dh,),
        ],
        output_dtypes=[mx.float32, mx.float32],
        grid=(Dh, B_batch * H, 1),
        threadgroup=(min(Dh, 256), 1, 1),
    )

    y = results[0].reshape(B_batch, L, H, Dh)
    h_ckpt = results[1].reshape(B_batch, n_seg, H, Dh, Dh)
    return y, h_ckpt
124
+
125
+
126
+ # ---------------------------------------------------------------------------
127
+ # Metal backward: segment recompute + fused simd-reduced gradient partials
128
+ # ---------------------------------------------------------------------------
129
+
130
def _gla_backward_kernel(grad_y, h_ckpt, q, k, v, gates, seg):
    """Recompute states from checkpoints, then adjoint sweep with fused reductions.

    Segments are processed newest -> oldest. For each segment the kernel
    (1) replays the forward recurrence from the previous segment's checkpoint
    (or from zero for the first segment), storing every per-timestep state in
    ``scratch``, then (2) walks the segment backwards accumulating the state
    adjoint ``adj`` and emitting gradients.

    The kernel itself produces these intermediate buffers:
        grad_v:    [B, L, H, Dh]        (exact, per-thread)
        grad_q_p:  [B, L, H, nW, Dh]    (sum over nW -> grad_q)
        grad_k_p:  [B, L, H, nW, Dh]    (sum over nW -> grad_k)
        grad_g_p:  [B, L, H, nW]        (sum over nW -> grad_gates)
        scratch:   [B, H, seg, Dh, Dh]  (recomputed states, discarded)
    where ``nW = Dh // 32`` is the number of simdgroups per head.

    Args:
        grad_y:  [B, L, H, Dh] cotangent of the forward output.
        h_ckpt:  [B, nSeg, H, Dh, Dh] segment-end states from the forward.
        q, k, v: [B, L, H, Dh]
        gates:   [B, L, H]
        seg:     segment length used by the forward pass.

    Returns:
        (grad_q, grad_k, grad_v, grad_gates) with shapes
        [B, L, H, Dh], [B, L, H, Dh], [B, L, H, Dh], [B, L, H], all float32.
    """
    B_batch, L, H, Dh = q.shape
    n_seg = L // seg
    n_w = Dh // 32  # simdgroups (j-lane groups) per head

    source = f"""
        uint j = thread_position_in_grid.x;
        uint bh = thread_position_in_grid.y;
        if (j >= {Dh}u || bh >= {B_batch * H}u) return;

        uint b = bh / {H}u;
        uint head = bh % {H}u;
        // threadgroup x is a multiple of 32 and x-major, so lanes of a
        // simdgroup are 32 consecutive j values within one head.
        uint lane = j % 32u;
        uint w = j / 32u;

        float adj[{Dh}];
        for (int i = 0; i < {Dh}; i++) adj[i] = 0.0f;

        for (int s = {n_seg - 1}; s >= 0; s--) {{
            // ---- phase 1: recompute states for this segment ----
            float h[{Dh}];
            if (s > 0) {{
                for (int i = 0; i < {Dh}; i++) {{
                    int ck_idx = ((((b * {n_seg} + (s - 1)) * {H} + head) * {Dh} + i) * {Dh} + j);
                    h[i] = h_ckpt[ck_idx];
                }}
            }} else {{
                for (int i = 0; i < {Dh}; i++) h[i] = 0.0f;
            }}

            for (int tl = 0; tl < {seg}; tl++) {{
                int t = s * {seg} + tl;
                int g_idx = (b * {L} + t) * {H} + head;
                int kv_base = ((b * {L} + t) * {H} + head) * {Dh};
                float gate = (float)gates[g_idx];
                float v_j = (float)v[kv_base + j];

                for (int i = 0; i < {Dh}; i++) {{
                    h[i] = gate * h[i] + (float)k[kv_base + i] * v_j;
                    int sc_idx = ((((b * {H} + head) * {seg} + tl) * {Dh} + i) * {Dh} + j);
                    scratch[sc_idx] = h[i];
                }}
            }}

            // ---- phase 2: adjoint sweep, newest -> oldest ----
            for (int tl = {seg - 1}; tl >= 0; tl--) {{
                int t = s * {seg} + tl;
                int g_idx = (b * {L} + t) * {H} + head;
                int kv_base = ((b * {L} + t) * {H} + head) * {Dh};

                float gate = (float)gates[g_idx];
                float v_j = (float)v[kv_base + j];
                float go_j = (float)grad_y[kv_base + j];

                float gv_j = 0.0f;
                float gg = 0.0f;

                for (int i = 0; i < {Dh}; i++) {{
                    float ki = (float)k[kv_base + i];

                    int sc_idx = ((((b * {H} + head) * {seg} + tl) * {Dh} + i) * {Dh} + j);
                    float h_cur = scratch[sc_idx];
                    float h_prev;
                    if (tl > 0) {{
                        h_prev = scratch[sc_idx - {Dh * Dh}];
                    }} else if (s > 0) {{
                        int ck_idx = ((((b * {n_seg} + (s - 1)) * {H} + head) * {Dh} + i) * {Dh} + j);
                        h_prev = h_ckpt[ck_idx];
                    }} else {{
                        h_prev = 0.0f;
                    }}

                    // driving term (adj sampled after this, before gate multiply)
                    adj[i] += (float)q[kv_base + i] * go_j;

                    gv_j += adj[i] * ki;
                    gg += adj[i] * h_prev;

                    // fused reductions over the 32 j-lanes of this simdgroup
                    float gq_l = simd_sum(go_j * h_cur);
                    float gk_l = simd_sum(adj[i] * v_j);
                    if (lane == 0u) {{
                        int p_idx = ((((b * {L} + t) * {H} + head) * {n_w} + w) * {Dh} + i);
                        grad_q_p[p_idx] = gq_l;
                        grad_k_p[p_idx] = gk_l;
                    }}

                    adj[i] *= gate;
                }}

                grad_v[kv_base + j] = gv_j;
                float gg_l = simd_sum(gg);
                if (lane == 0u) {{
                    grad_g_p[(((b * {L} + t) * {H} + head) * {n_w} + w)] = gg_l;
                }}
            }}
        }}
    """

    kernel = get_or_build_kernel(
        f"gla_bwd_{B_batch}_{L}_{H}_{Dh}_{seg}",
        input_names=["grad_y", "h_ckpt", "q", "k", "v", "gates"],
        output_names=["grad_v", "grad_q_p", "grad_k_p", "grad_g_p", "scratch"],
        source=source,
    )

    # ``scratch`` is declared as a kernel output only so that a buffer is
    # allocated for the recomputed states; its contents are never read back.
    results = kernel(
        inputs=[grad_y.reshape(-1), h_ckpt.reshape(-1), q.reshape(-1),
                k.reshape(-1), v.reshape(-1), gates.reshape(-1)],
        output_shapes=[
            (B_batch * L * H * Dh,),
            (B_batch * L * H * n_w * Dh,),
            (B_batch * L * H * n_w * Dh,),
            (B_batch * L * H * n_w,),
            (B_batch * H * seg * Dh * Dh,),  # scratch, discarded
        ],
        output_dtypes=[mx.float32] * 5,
        grid=(Dh, B_batch * H, 1),
        threadgroup=(min(Dh, 256), 1, 1),
    )

    grad_v = results[0].reshape(B_batch, L, H, Dh)
    grad_q_p = results[1].reshape(B_batch, L, H, n_w, Dh)
    grad_k_p = results[2].reshape(B_batch, L, H, n_w, Dh)
    grad_g_p = results[3].reshape(B_batch, L, H, n_w)

    # Finish the reductions: each simdgroup wrote one partial (lane 0), so
    # summing over the nW axis yields the full sum over all j.
    grad_q = mx.sum(grad_q_p, axis=3)      # [B, L, H, Dh]
    grad_k = mx.sum(grad_k_p, axis=3)      # [B, L, H, Dh]
    grad_gates = mx.sum(grad_g_p, axis=3)  # [B, L, H]

    return grad_q, grad_k, grad_v, grad_gates
271
+
272
+
273
+ # ---------------------------------------------------------------------------
274
+ # Custom function + VJP (one cached impl per seg, so seg can be a Python arg)
275
+ # ---------------------------------------------------------------------------
276
+
277
# One ``mx.custom_function`` per distinct ``seg`` value, built lazily.
_impl_cache: dict = {}


def _make_impl(seg):
    """Build (and cache) an ``mx.custom_function`` GLA impl bound to ``seg``.

    The returned callable takes ``(q, k, v, gates)`` and returns
    ``(y, h_ckpt)`` from the forward Metal kernel. Its VJP runs the backward
    Metal kernel and returns ``(grad_q, grad_k, grad_v, grad_gates)``.

    Only the cotangent of ``y`` is consumed by the VJP; the cotangent of
    ``h_ckpt`` is ignored, so no gradient flows back through the checkpoint
    (and hence final-state) output.

    Args:
        seg: segment length baked into the forward/backward kernels.

    Returns:
        The cached custom function for this ``seg``.

    Raises:
        ValueError: (when the returned function is called) if ``q`` is not
            4-D, if ``k`` / ``v`` do not match ``q``'s shape, if ``gates``
            is not ``q.shape[:3]``, or if ``check_segment_shape`` rejects
            the sequence length / head dimension.
    """
    if seg in _impl_cache:
        return _impl_cache[seg]

    @mx.custom_function
    def _impl(q, k, v, gates):
        q_shape = tuple(q.shape)
        if len(q_shape) != 4:
            raise ValueError(
                f"q must have shape [B, L, H, Dh], got {q_shape}"
            )
        # The kernels flatten every input and index it with q's dimensions,
        # so a mismatched operand would be read out of bounds on the GPU.
        for arg_name, arg in (("k", k), ("v", v)):
            if tuple(arg.shape) != q_shape:
                raise ValueError(
                    f"{arg_name} must have the same shape as q {q_shape}, "
                    f"got {tuple(arg.shape)}"
                )
        if tuple(gates.shape) != q_shape[:3]:
            raise ValueError(
                f"gates must have shape [B, L, H] = {q_shape[:3]}, "
                f"got {tuple(gates.shape)}"
            )
        check_segment_shape(q_shape[1], seg, q_shape[3], "Dh")
        return _gla_forward_kernel(q, k, v, gates, seg)

    @_impl.vjp
    def _vjp(primals, cotangents, outputs):
        q, k, v, gates = primals
        # cotangents[1] (w.r.t. h_ckpt) is intentionally not used here.
        grad_y = cotangents[0]
        _y, h_ckpt = outputs
        return _gla_backward_kernel(grad_y, h_ckpt, q, k, v, gates, seg)

    _impl_cache[seg] = _impl
    return _impl
299
+
300
+
301
def gla_scan(q, k, v, gates, seg=DEFAULT_SEG):
    """Run the Gated Linear Attention recurrence with the fused Metal kernels.

    Dispatches to the custom function bound to ``seg`` (forward + backward
    Metal kernels) and returns only the per-token output; the segment
    checkpoints produced alongside it are dropped.

    Args:
        q:     [B, L, H, Dh]  queries (pre-scaled, post-RoPE)
        k:     [B, L, H, Dh]  keys
        v:     [B, L, H, Dh]  values
        gates: [B, L, H]      scalar forget gate per head, typically in (0, 1)
        seg:   segment length for checkpointing (L % seg == 0; default 32)

    Returns:
        y: [B, L, H, Dh]

    Note: ``Dh`` must be a multiple of 32 (it is the simd-reduced lane dim).
    fp32 state/accumulation internally; bf16 inputs widen implicitly.
    """
    scan_impl = _make_impl(seg)
    outputs = scan_impl(q, k, v, gates)
    return outputs[0]
319
+
320
+
321
def gla_scan_with_state(q, k, v, gates, seg=DEFAULT_SEG):
    """Run the GLA scan and also hand back the final state for chunked prefill.

    The final state is the last segment checkpoint written by the forward
    kernel, i.e. the state after the last timestep. The scan always starts
    from a zero state. The custom VJP only consumes the cotangent of ``y``,
    so gradients do not flow through ``final_state``.

    Args:
        q, k, v: [B, L, H, Dh]
        gates:   [B, L, H]
        seg:     segment length for checkpointing (L % seg == 0; default 32)

    Returns:
        y:            [B, L, H, Dh]
        final_state:  [B, H, Dh, Dh]  (matches the GLA conceptual state)
    """
    outputs = _make_impl(seg)(q, k, v, gates)
    y, checkpoints = outputs[0], outputs[1]
    # checkpoints is [B, nSeg, H, Dh, Dh]; take the last segment's entry.
    final_state = checkpoints[:, -1]
    return y, final_state
330
+
331
+
332
+ # ---------------------------------------------------------------------------
333
+ # Pure-MLX reference (slow, for parity testing only)
334
+ # ---------------------------------------------------------------------------
335
+
336
def gla_scan_reference(q, k, v, gates):
    """Pure-MLX token-loop reference for :func:`gla_scan`. Differentiable.

    Implements ``h_t = gate_t * h_{t-1} + k_t (outer) v_t`` followed by
    ``o_t = q_t @ h_t`` one timestep at a time, starting from a zero state.

    Args:
        q, k, v: [B, L, H, Dh]
        gates:   [B, L, H]

    Returns:
        y: [B, L, H, Dh]
    """
    n_batch, seq_len, n_heads, head_dim = q.shape
    state = mx.zeros((n_batch, n_heads, head_dim, head_dim))
    outputs = []
    for step in range(seq_len):
        q_t = q[:, step]  # [B,H,Dh]
        k_t = k[:, step]  # [B,H,Dh]
        v_t = v[:, step]  # [B,H,Dh]
        decay = gates[:, step, :, None, None]  # [B,H,1,1]
        outer = k_t[:, :, :, None] * v_t[:, :, None, :]  # [B,H,Dh,Dh]
        state = decay * state + outer
        # Contract over the key index i (second-to-last axis) -> [B,H,Dh]
        outputs.append(mx.sum(q_t[:, :, :, None] * state, axis=-2))
    return mx.stack(outputs, axis=1)  # [B,L,H,Dh]
@@ -0,0 +1,41 @@
1
+ """mlx_recurrence.legacy — v0.1 token-loop Metal kernels.
2
+
3
+ These are the original (v0.1) kernels: a Metal forward that materialises
4
+ the full per-timestep state tensor h_all and a backward that reads it back.
5
+ They are kept for backwards compatibility and as a simple, readable
6
+ reference. New work should prefer the v2 chassis-based kernels
7
+ (``mlx_recurrence.ssd``, ``mlx_recurrence.gla``, ``mlx_recurrence.rglru``, ``mlx_recurrence.rotlru``),
8
+ which use segment checkpointing + recompute to slash DRAM traffic.
9
+
10
+ Public API (unchanged from v0.1):
11
+ selective_scan_metal, selective_scan_chunked
12
+ gla_scan_metal, gla_scan_chunked
13
+ """
14
+
15
+ from .ssm_scan import (
16
+ selective_scan_metal,
17
+ selective_scan_chunked,
18
+ _ssm_forward_kernel,
19
+ _ssm_backward_chunked,
20
+ _ssm_backward_metal,
21
+ )
22
+ from .gla_scan import (
23
+ gla_scan_metal,
24
+ gla_scan_chunked,
25
+ _gla_forward_kernel,
26
+ _gla_backward_chunked,
27
+ _gla_backward_metal,
28
+ )
29
+
30
# Names exposed by ``from mlx_recurrence.legacy import *``. Besides the four
# public entry points, the underscore-prefixed kernel helpers imported above
# are listed as well, so a star-import picks them up too.
# NOTE(review): presumably kept for existing callers/tests that import the
# private helpers from this package — confirm before trimming.
__all__ = [
    "selective_scan_metal",
    "selective_scan_chunked",
    "gla_scan_metal",
    "gla_scan_chunked",
    "_ssm_forward_kernel",
    "_ssm_backward_chunked",
    "_ssm_backward_metal",
    "_gla_forward_kernel",
    "_gla_backward_chunked",
    "_gla_backward_metal",
]