mlx-recurrence 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mlx_recurrence/__init__.py +90 -0
- mlx_recurrence/_chassis.py +213 -0
- mlx_recurrence/gla.py +346 -0
- mlx_recurrence/legacy/__init__.py +41 -0
- mlx_recurrence/legacy/_utils.py +61 -0
- mlx_recurrence/legacy/gla_scan.py +373 -0
- mlx_recurrence/legacy/ssm_scan.py +430 -0
- mlx_recurrence/rglru.py +282 -0
- mlx_recurrence/rotlru.py +351 -0
- mlx_recurrence/ssd.py +394 -0
- mlx_recurrence-0.3.0.dist-info/METADATA +299 -0
- mlx_recurrence-0.3.0.dist-info/RECORD +15 -0
- mlx_recurrence-0.3.0.dist-info/WHEEL +5 -0
- mlx_recurrence-0.3.0.dist-info/licenses/LICENSE +21 -0
- mlx_recurrence-0.3.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,90 @@
|
|
|
1
|
+
"""mlx_recurrence — A plug-in framework for linear-recurrence Metal kernels
|
|
2
|
+
on Apple Silicon ("flash-linear-attention for MLX").
|
|
3
|
+
|
|
4
|
+
Each kernel is a self-contained plug-in built on a shared chassis
|
|
5
|
+
(:mod:`mlx_recurrence._chassis`) that supplies the segment-checkpoint +
|
|
6
|
+
recompute backward pattern, shape validation, and a parity-test helper. The
|
|
7
|
+
Metal source for each recurrence stays in its own module, readable per-kernel.
|
|
8
|
+
|
|
9
|
+
v2 kernels (checkpoint + recompute, fused simd reductions, chunked-prefill
|
|
10
|
+
final-state variants):
|
|
11
|
+
ssd — Mamba-2-style head-wise SSD selective scan
|
|
12
|
+
gla — Gated Linear Attention recurrence
|
|
13
|
+
rglru — RG-LRU diagonal recurrence (Griffin / RecurrentGemma)
|
|
14
|
+
rotlru — rotational LRU: complex-diagonal scan over (u, w) pairs
|
|
15
|
+
|
|
16
|
+
The original v0.1 token-loop kernels remain available under
|
|
17
|
+
``mlx_recurrence.legacy`` and are re-exported at top level for backwards
|
|
18
|
+
compatibility (``selective_scan_metal``, ``gla_scan_metal``, ...).
|
|
19
|
+
"""
|
|
20
|
+
|
|
21
|
+
# --- v2 chassis-based kernels ---------------------------------------------
|
|
22
|
+
from .ssd import (
|
|
23
|
+
ssd_scan,
|
|
24
|
+
ssd_scan_with_state,
|
|
25
|
+
ssd_scan_reference,
|
|
26
|
+
)
|
|
27
|
+
from .gla import (
|
|
28
|
+
gla_scan,
|
|
29
|
+
gla_scan_with_state,
|
|
30
|
+
gla_scan_reference,
|
|
31
|
+
)
|
|
32
|
+
from .rglru import (
|
|
33
|
+
rglru_scan,
|
|
34
|
+
rglru_scan_with_state,
|
|
35
|
+
rglru_scan_reference,
|
|
36
|
+
)
|
|
37
|
+
from .rotlru import (
|
|
38
|
+
rotlru_scan,
|
|
39
|
+
rotlru_scan_with_state,
|
|
40
|
+
rotlru_scan_reference,
|
|
41
|
+
)
|
|
42
|
+
|
|
43
|
+
# --- shared chassis (public for building new plug-in kernels) -------------
|
|
44
|
+
from ._chassis import (
|
|
45
|
+
DEFAULT_SEG,
|
|
46
|
+
get_or_build_kernel,
|
|
47
|
+
check_segment_shape,
|
|
48
|
+
parity_check,
|
|
49
|
+
)
|
|
50
|
+
|
|
51
|
+
# --- legacy v0.1 kernels (backwards compatibility) ------------------------
|
|
52
|
+
from . import legacy
|
|
53
|
+
from .legacy import (
|
|
54
|
+
selective_scan_metal,
|
|
55
|
+
selective_scan_chunked,
|
|
56
|
+
gla_scan_metal,
|
|
57
|
+
gla_scan_chunked,
|
|
58
|
+
)
|
|
59
|
+
|
|
60
|
+
# Public API. Each group below is concatenated in the same order as the
# imports above, so the resulting list is identical element-for-element.
__all__ = (
    # v2 SSD
    ["ssd_scan", "ssd_scan_with_state", "ssd_scan_reference"]
    # v2 GLA
    + ["gla_scan", "gla_scan_with_state", "gla_scan_reference"]
    # v2 RG-LRU
    + ["rglru_scan", "rglru_scan_with_state", "rglru_scan_reference"]
    # v2 rotational LRU
    + ["rotlru_scan", "rotlru_scan_with_state", "rotlru_scan_reference"]
    # chassis
    + ["DEFAULT_SEG", "get_or_build_kernel", "check_segment_shape", "parity_check"]
    # legacy subpackage + re-exports
    + [
        "legacy",
        "selective_scan_metal",
        "selective_scan_chunked",
        "gla_scan_metal",
        "gla_scan_chunked",
    ]
)

__version__ = "0.3.0"
|
|
@@ -0,0 +1,213 @@
|
|
|
1
|
+
"""_chassis.py — Shared infrastructure for chassis-based recurrence kernels.
|
|
2
|
+
|
|
3
|
+
This module factors out the machinery common to every v2 plug-in kernel
|
|
4
|
+
(SSD, GLA, RG-LRU, rotational LRU) so each kernel file only has to supply its own Metal
|
|
5
|
+
source strings and gradient wiring. It deliberately does NOT abstract the
|
|
6
|
+
Metal source itself: every recurrence has a different state shape and
|
|
7
|
+
update rule, and the kernel bodies are meant to stay readable per-kernel.
|
|
8
|
+
|
|
9
|
+
What lives here
|
|
10
|
+
---------------
|
|
11
|
+
1. Kernel cache / builder around ``mx.fast.metal_kernel`` so each unique
|
|
12
|
+
shape configuration is JIT-compiled exactly once per process.
|
|
13
|
+
2. Shape validation shared by the segment-checkpoint + recompute pattern:
|
|
14
|
+
the sequence length must tile evenly into segments (``L % seg == 0``)
|
|
15
|
+
and the simd-reduced lane dimension must be a multiple of 32.
|
|
16
|
+
3. A reusable parity-test helper that compares forward output plus every
|
|
17
|
+
gradient against a pure-MLX reference and reports max abs / rel diffs.
|
|
18
|
+
|
|
19
|
+
The segment-checkpoint + recompute pattern (shared design)
|
|
20
|
+
----------------------------------------------------------
|
|
21
|
+
Every kernel here follows the same playbook, tuned to the Apple Silicon
|
|
22
|
+
unified-memory hierarchy:
|
|
23
|
+
|
|
24
|
+
Forward: run the recurrence once, write only the state at each
|
|
25
|
+
segment boundary -> ``h_ckpt`` (SEG=32 => ~1/32 the writes of
|
|
26
|
+
saving every timestep). The last checkpoint doubles as the
|
|
27
|
+
chunk's final state, enabling chunked prefill.
|
|
28
|
+
|
|
29
|
+
Backward: walk segments newest -> oldest. For each segment, recompute
|
|
30
|
+
its per-timestep states from the preceding checkpoint into a
|
|
31
|
+
small scratch buffer (one segment's worth, stays resident in
|
|
32
|
+
the system-level cache instead of streaming the full state
|
|
33
|
+
history through DRAM), then run the adjoint sweep over that
|
|
34
|
+
segment. Cross-lane gradient reductions are fused in-kernel
|
|
35
|
+
with ``simd_sum`` over 32-lane simdgroups; the remaining
|
|
36
|
+
sum-over-simdgroups is a single cheap MLX reduction.
|
|
37
|
+
|
|
38
|
+
All kernels keep fp32 state and accumulation regardless of input dtype,
|
|
39
|
+
and reproduce the forward states bit-exactly on recompute (same fp32 ops,
|
|
40
|
+
same order, from the same checkpoint).
|
|
41
|
+
"""
|
|
42
|
+
|
|
43
|
+
from __future__ import annotations
|
|
44
|
+
|
|
45
|
+
import mlx.core as mx
|
|
46
|
+
|
|
47
|
+
# Number of timesteps per checkpoint segment used by the v2 kernels unless a
# caller passes its own ``seg``. 32 was picked on M3 Max: it lines up with the
# 32-wide simdgroups behind the fused gradient reductions, and one segment's
# scratch buffer stays small enough to remain SLC-resident at training shapes.
DEFAULT_SEG = 32

# Lanes per simdgroup on Apple GPUs. Any dimension reduced with simd_sum has
# to split evenly into groups of this many lanes.
SIMD_WIDTH = 32
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
# ---------------------------------------------------------------------------
|
|
59
|
+
# Kernel cache / builder
|
|
60
|
+
# ---------------------------------------------------------------------------
|
|
61
|
+
|
|
62
|
+
# Compiled kernels for this process, keyed by the caller-chosen ``name``.
_kernel_cache: dict = {}


def get_or_build_kernel(name, input_names, output_names, source, header=""):
    """Return the Metal kernel cached under ``name``, building it on first use.

    The cache is keyed on ``name`` alone. On a hit, ``source`` is not
    inspected, so ``name`` must encode every shape/template constant
    interpolated into ``source`` (e.g. ``f"ssd_fwd_{B}_{L}_{H}_{Dh}_{N}_{seg}"``).
    Distinct shapes then get distinct kernels, and repeated shapes reuse the
    cached one.

    ``source`` holds only the kernel BODY. MLX derives the ``kernel void``
    signature from ``input_names`` / ``output_names``. Helper functions and
    includes belong in ``header``.
    """
    try:
        return _kernel_cache[name]
    except KeyError:
        pass  # first request for this name: fall through and build it

    kernel = mx.fast.metal_kernel(
        name=name,
        input_names=input_names,
        output_names=output_names,
        source=source,
        header=header,
    )
    _kernel_cache[name] = kernel
    return kernel
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
# ---------------------------------------------------------------------------
|
|
89
|
+
# Shape validation
|
|
90
|
+
# ---------------------------------------------------------------------------
|
|
91
|
+
|
|
92
|
+
def check_segment_shape(L, seg, lane_dim, lane_name="lane dimension"):
    """Check the shape constraints shared by every segment-checkpoint kernel.

    The sequence must split into whole segments, and the dimension handled
    by simdgroup lanes must fill those simdgroups exactly. Constraints are
    checked in order, and the first violation is reported.

    Args:
        L: sequence length.
        seg: checkpoint segment length; must be positive and divide ``L``.
        lane_dim: size of the dimension reduced with ``simd_sum``; must be a
            multiple of ``SIMD_WIDTH``.
        lane_name: label for ``lane_dim`` used in the error message.

    Raises:
        ValueError: if ``seg`` is not positive, ``L`` is not a multiple of
            ``seg``, or ``lane_dim`` is not a multiple of ``SIMD_WIDTH``.
    """
    if seg <= 0:
        raise ValueError(f"seg must be positive, got {seg}")

    leftover_steps = L % seg
    if leftover_steps != 0:
        raise ValueError(
            f"sequence length L={L} must be divisible by seg={seg} "
            "(segment-checkpoint pattern tiles L into L/seg segments)"
        )

    leftover_lanes = lane_dim % SIMD_WIDTH
    if leftover_lanes != 0:
        raise ValueError(
            f"{lane_name}={lane_dim} must be a multiple of {SIMD_WIDTH} "
            f"(fused gradient reductions use {SIMD_WIDTH}-lane simdgroups)"
        )
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
# ---------------------------------------------------------------------------
|
|
119
|
+
# Parity-test helper
|
|
120
|
+
# ---------------------------------------------------------------------------
|
|
121
|
+
|
|
122
|
+
def parity_check(
    kernel_fn,
    reference_fn,
    inputs,
    arg_names,
    grad_argnums,
    *,
    w_out=None,
    y_tol=1e-3,
    grad_rtol=1e-3,
    label="",
    verbose=True,
):
    """Compare a kernel against a pure-MLX reference: forward + all grads.

    Both ``kernel_fn`` and ``reference_fn`` take the positional ``inputs``
    and return the forward output ``y``. This helper builds a scalar loss
    ``sum(y * w_out)`` and compares ``mx.grad`` of that loss w.r.t. every
    argument in ``grad_argnums``.

    Args:
        kernel_fn: the kernel under test, ``fn(*inputs) -> y``.
        reference_fn: pure-MLX reference, ``fn(*inputs) -> y``.
        inputs: tuple/list of input arrays (positional).
        arg_names: names for each input (for readable output), same length
            as ``inputs``.
        grad_argnums: a single argument index, or a sequence of argument
            indices, to differentiate.
        w_out: output weighting for the scalar loss. Defaults to a random
            normal tensor shaped like ``y``, drawn from a fixed PRNG key. The
            global MLX random state is not modified.
        y_tol: absolute tolerance on the forward output diff.
        grad_rtol: relative tolerance on each gradient diff. The relative
            diff is the max abs diff divided by the reference gradient's max
            magnitude.
        label: prefix printed before results.
        verbose: print per-tensor diffs.

    Returns:
        (ok: bool, report: dict) where ``report`` maps each compared tensor
        name to ``{"abs": max_abs_diff, "rel": max_rel_diff}``.
    """
    inputs = list(inputs)

    # Accept a bare int, as mx.grad does. The name lookup below iterates
    # over the indices and used to raise TypeError for an int.
    if isinstance(grad_argnums, int):
        grad_argnums = [grad_argnums]
    else:
        grad_argnums = list(grad_argnums)

    y_k = kernel_fn(*inputs)
    y_r = reference_fn(*inputs)
    mx.eval(y_k, y_r)

    if w_out is None:
        # A private PRNG key keeps the weighting reproducible without
        # reseeding (and thereby clobbering) the caller's global random stream.
        w_out = mx.random.normal(y_r.shape, key=mx.random.key(1234))
        mx.eval(w_out)

    def loss_kernel(*args):
        return mx.sum(kernel_fn(*args) * w_out)

    def loss_ref(*args):
        return mx.sum(reference_fn(*args) * w_out)

    report = {}

    y_abs = float(mx.max(mx.abs(y_k - y_r)))
    y_scale = float(mx.max(mx.abs(y_r))) + 1e-8  # avoid dividing by zero
    y_rel = y_abs / y_scale
    report["y"] = {"abs": y_abs, "rel": y_rel}
    ok = y_abs < y_tol

    g_k = mx.grad(loss_kernel, argnums=grad_argnums)(*inputs)
    g_r = mx.grad(loss_ref, argnums=grad_argnums)(*inputs)
    mx.eval(g_k, g_r)

    # mx.grad returns a bare array when exactly one index is requested and
    # a tuple otherwise; normalise to sequences for the zip below.
    if not isinstance(g_k, (tuple, list)):
        g_k = (g_k,)
        g_r = (g_r,)

    grad_names = [arg_names[i] for i in grad_argnums]
    for name, gk, gr in zip(grad_names, g_k, g_r):
        abs_diff = float(mx.max(mx.abs(gk - gr)))
        scale = float(mx.max(mx.abs(gr))) + 1e-8
        rel = abs_diff / scale
        report[f"grad_{name}"] = {"abs": abs_diff, "rel": rel}
        ok = ok and (rel < grad_rtol)

    if verbose:
        prefix = f"{label} " if label else ""
        print(f"{prefix}y max|diff| = {y_abs:.2e} (rel {y_rel:.2e})")
        for name in grad_names:
            r = report[f"grad_{name}"]
            print(
                f"{prefix}grad_{name:<8} max|diff| = {r['abs']:.2e}"
                f" (rel {r['rel']:.2e})"
            )
        print(f"{prefix}-> {'PASS' if ok else 'FAIL'}")

    return ok, report
|
mlx_recurrence/gla.py
ADDED
|
@@ -0,0 +1,346 @@
|
|
|
1
|
+
"""gla.py — Gated Linear Attention recurrence (checkpoint + recompute).
|
|
2
|
+
|
|
3
|
+
v2 chassis port of the GLA recurrence kernel. Same segment-checkpoint +
|
|
4
|
+
recompute backward pattern as :mod:`mlx_recurrence.ssd`, applied to the
|
|
5
|
+
matrix-valued GLA state.
|
|
6
|
+
|
|
7
|
+
Recurrence
|
|
8
|
+
----------
|
|
9
|
+
Per head ``head`` the state is the ``Dh x Dh`` matrix ``h[b, head, i, j]``
|
|
10
|
+
with a single scalar forget gate per token::
|
|
11
|
+
|
|
12
|
+
h[i, j] = gate[b,t,head] * h[i, j] + k[b,t,head,i] * v[b,t,head,j]
|
|
13
|
+
o[b,t,head,j] = sum_i q[b,t,head,i] * h[i, j]
|
|
14
|
+
|
|
15
|
+
i.e. ``h_t = gate_t * h_{t-1} + k_t (outer) v_t`` and ``o_t = q_t @ h_t``
|
|
16
|
+
(output uses the post-update state). ``gate`` is typically a sigmoid output
|
|
17
|
+
in ``(0, 1)``; ``q`` is assumed pre-scaled / post-RoPE.
|
|
18
|
+
|
|
19
|
+
The conceptual state tensor is ``[B, H, Dh, Dh]``. The checkpoint is laid
|
|
20
|
+
out ``[B, nSeg, H, Dh, Dh]`` with ``j`` fastest-moving so the 32 lanes of a
|
|
21
|
+
simdgroup own 32 contiguous ``j`` columns — coalesced reads/writes.
|
|
22
|
+
|
|
23
|
+
Numerics: fp32 state and accumulation regardless of input dtype; identical
|
|
24
|
+
update order (output uses post-update ``h``) and gradient formulas to the
|
|
25
|
+
validated D-CSIL-3 training kernel. grad_v is exact per-thread; grad_q,
|
|
26
|
+
grad_k, grad_gates are fused in-kernel via 32-lane ``simd_sum``.
|
|
27
|
+
|
|
28
|
+
Constraints:
|
|
29
|
+
L % seg == 0
|
|
30
|
+
Dh % 32 == 0 (Dh is the simd-reduced lane dimension)
|
|
31
|
+
|
|
32
|
+
Public API:
|
|
33
|
+
gla_scan(q, k, v, gates, seg=32) -> y
|
|
34
|
+
gla_scan_with_state(q, k, v, gates, seg=32) -> (y, final_state)
|
|
35
|
+
gla_scan_reference(q, k, v, gates) -> y (pure MLX)
|
|
36
|
+
"""
|
|
37
|
+
|
|
38
|
+
from __future__ import annotations
|
|
39
|
+
|
|
40
|
+
import mlx.core as mx
|
|
41
|
+
|
|
42
|
+
from ._chassis import DEFAULT_SEG, get_or_build_kernel, check_segment_shape
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
# ---------------------------------------------------------------------------
|
|
46
|
+
# Metal forward: GLA scan with segment checkpoints (no full state history)
|
|
47
|
+
# ---------------------------------------------------------------------------
|
|
48
|
+
|
|
49
|
+
def _gla_forward_kernel(q, k, v, gates, seg):
    """Forward GLA scan writing only segment-boundary state.

    One thread per (batch*head, j) owns the ``h[:, j]`` state column in
    registers across all L timesteps.

    Args:
        q, k, v: [B, L, H, Dh]
        gates: [B, L, H]
        seg: segment length (L % seg == 0)

    Returns:
        y: [B, L, H, Dh]
        h_ckpt: [B, nSeg, H, Dh, Dh] state at the END of each segment

    Both outputs are float32 (see ``output_dtypes`` below), whatever the
    input dtype. No shape validation happens here; the caller is expected
    to have checked ``L % seg == 0`` already.

    NOTE(review): only ``q.shape`` is read, so ``k``, ``v`` and ``gates``
    are assumed to share its B/L/H/Dh layout. TODO: confirm that all callers
    guarantee this.
    """
    # These sizes are interpolated into the Metal source as compile-time
    # constants.
    B_batch, L, H, Dh = q.shape
    n_seg = L // seg

    source = f"""
uint j = thread_position_in_grid.x;
uint bh = thread_position_in_grid.y;
if (j >= {Dh}u || bh >= {B_batch * H}u) return;

uint b = bh / {H}u;
uint head = bh % {H}u;

float h[{Dh}];
for (int i = 0; i < {Dh}; i++) h[i] = 0.0f;

for (int t = 0; t < {L}; t++) {{
int g_idx = (b * {L} + t) * {H} + head;
int kv_base = ((b * {L} + t) * {H} + head) * {Dh};

float gate = (float)gates[g_idx];
float v_j = (float)v[kv_base + j];
float o_j = 0.0f;

for (int i = 0; i < {Dh}; i++) {{
h[i] = gate * h[i] + (float)k[kv_base + i] * v_j;
o_j += (float)q[kv_base + i] * h[i];
}}
output[kv_base + j] = o_j;

// checkpoint at end of each segment (j fastest -> coalesced)
if (((t + 1) % {seg}) == 0) {{
int s = t / {seg};
for (int i = 0; i < {Dh}; i++) {{
int ck_idx = ((((b * {n_seg} + s) * {H} + head) * {Dh} + i) * {Dh} + j);
h_ckpt[ck_idx] = h[i];
}}
}}
}}
"""

    # The cache key carries every constant interpolated into ``source``
    # (n_seg is derived from L and seg), so each shape compiles once.
    kernel = get_or_build_kernel(
        f"gla_fwd_{B_batch}_{L}_{H}_{Dh}_{seg}",
        input_names=["q", "k", "v", "gates"],
        output_names=["output", "h_ckpt"],
        source=source,
    )

    results = kernel(
        # The kernel indexes every buffer as a flat 1-D array.
        inputs=[q.reshape(-1), k.reshape(-1), v.reshape(-1), gates.reshape(-1)],
        output_shapes=[
            (B_batch * L * H * Dh,),
            (B_batch * n_seg * H * Dh * Dh,),
        ],
        output_dtypes=[mx.float32, mx.float32],
        # grid.x spans the Dh state columns (j), grid.y spans batch*head (bh).
        grid=(Dh, B_batch * H, 1),
        # Threadgroup width is capped at 256 threads along j.
        threadgroup=(min(Dh, 256), 1, 1),
    )

    y = results[0].reshape(B_batch, L, H, Dh)
    h_ckpt = results[1].reshape(B_batch, n_seg, H, Dh, Dh)
    return y, h_ckpt
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
# ---------------------------------------------------------------------------
|
|
127
|
+
# Metal backward: segment recompute + fused simd-reduced gradient partials
|
|
128
|
+
# ---------------------------------------------------------------------------
|
|
129
|
+
|
|
130
|
+
def _gla_backward_kernel(grad_y, h_ckpt, q, k, v, gates, seg, grad_h_ckpt=None):
    """Recompute states from checkpoints, then adjoint sweep with fused reductions.

    Segments are processed newest first. For each segment, the kernel:

    1. adds the direct cotangent of that segment's checkpoint to the running
       adjoint;
    2. recomputes the segment's per-timestep states from the previous
       checkpoint into ``scratch``;
    3. walks the segment backwards, accumulating gradients.

    Args:
        grad_y: [B, L, H, Dh] cotangent of the forward output ``y``.
        h_ckpt: [B, nSeg, H, Dh, Dh] segment-end states from the forward.
        q, k, v: [B, L, H, Dh] forward inputs.
        gates: [B, L, H] forward gates.
        seg: segment length used by the forward (L % seg == 0).
        grad_h_ckpt: optional [B, nSeg, H, Dh, Dh] cotangent of ``h_ckpt``
            itself. It is non-zero when a loss depends on a checkpoint, e.g.
            the final state returned by ``gla_scan_with_state``. ``None`` is
            treated as all zeros.

    Returns:
        (grad_q, grad_k, grad_v, grad_gates), all float32, shaped like
        (q, k, v, gates). Internally the kernel writes:
            grad_v:   [B, L, H, Dh]      (exact, per-thread)
            grad_q_p: [B, L, H, nW, Dh]  (sum over nW -> grad_q)
            grad_k_p: [B, L, H, nW, Dh]  (sum over nW -> grad_k)
            grad_g_p: [B, L, H, nW]      (sum over nW -> grad_gates)
    """
    B_batch, L, H, Dh = q.shape
    n_seg = L // seg
    n_w = Dh // 32  # simdgroups (j-lane groups) per head

    if grad_h_ckpt is None:
        grad_h_ckpt = mx.zeros(h_ckpt.shape, dtype=mx.float32)

    source = f"""
        uint j = thread_position_in_grid.x;
        uint bh = thread_position_in_grid.y;
        if (j >= {Dh}u || bh >= {B_batch * H}u) return;

        uint b = bh / {H}u;
        uint head = bh % {H}u;
        // threadgroup x is a multiple of 32 and x-major, so lanes of a
        // simdgroup are 32 consecutive j values within one head.
        uint lane = j % 32u;
        uint w = j / 32u;

        // Running adjoint of h[i, j] for this thread's column j, carried
        // from newer to older timesteps.
        float adj[{Dh}];
        for (int i = 0; i < {Dh}; i++) adj[i] = 0.0f;

        for (int s = {n_seg - 1}; s >= 0; s--) {{
            // ---- phase 0: direct cotangent of this segment's checkpoint ----
            // h_ckpt[s] is the state after the segment's last timestep, so
            // its cotangent joins the adjoint before that step is swept.
            for (int i = 0; i < {Dh}; i++) {{
                int gh_idx = ((((b * {n_seg} + s) * {H} + head) * {Dh} + i) * {Dh} + j);
                adj[i] += (float)grad_h[gh_idx];
            }}

            // ---- phase 1: recompute states for this segment ----
            float h[{Dh}];
            if (s > 0) {{
                for (int i = 0; i < {Dh}; i++) {{
                    int ck_idx = ((((b * {n_seg} + (s - 1)) * {H} + head) * {Dh} + i) * {Dh} + j);
                    h[i] = h_ckpt[ck_idx];
                }}
            }} else {{
                for (int i = 0; i < {Dh}; i++) h[i] = 0.0f;
            }}

            for (int tl = 0; tl < {seg}; tl++) {{
                int t = s * {seg} + tl;
                int g_idx = (b * {L} + t) * {H} + head;
                int kv_base = ((b * {L} + t) * {H} + head) * {Dh};
                float gate = (float)gates[g_idx];
                float v_j = (float)v[kv_base + j];

                for (int i = 0; i < {Dh}; i++) {{
                    h[i] = gate * h[i] + (float)k[kv_base + i] * v_j;
                    int sc_idx = ((((b * {H} + head) * {seg} + tl) * {Dh} + i) * {Dh} + j);
                    scratch[sc_idx] = h[i];
                }}
            }}

            // ---- phase 2: adjoint sweep, newest -> oldest ----
            for (int tl = {seg - 1}; tl >= 0; tl--) {{
                int t = s * {seg} + tl;
                int g_idx = (b * {L} + t) * {H} + head;
                int kv_base = ((b * {L} + t) * {H} + head) * {Dh};

                float gate = (float)gates[g_idx];
                float v_j = (float)v[kv_base + j];
                float go_j = (float)grad_y[kv_base + j];

                float gv_j = 0.0f;
                float gg = 0.0f;

                for (int i = 0; i < {Dh}; i++) {{
                    float ki = (float)k[kv_base + i];

                    int sc_idx = ((((b * {H} + head) * {seg} + tl) * {Dh} + i) * {Dh} + j);
                    float h_cur = scratch[sc_idx];
                    float h_prev;
                    if (tl > 0) {{
                        h_prev = scratch[sc_idx - {Dh * Dh}];
                    }} else if (s > 0) {{
                        int ck_idx = ((((b * {n_seg} + (s - 1)) * {H} + head) * {Dh} + i) * {Dh} + j);
                        h_prev = h_ckpt[ck_idx];
                    }} else {{
                        h_prev = 0.0f;
                    }}

                    // driving term (adj sampled after this, before gate multiply)
                    adj[i] += (float)q[kv_base + i] * go_j;

                    gv_j += adj[i] * ki;
                    gg += adj[i] * h_prev;

                    // fused reductions over the 32 j-lanes of this simdgroup
                    float gq_l = simd_sum(go_j * h_cur);
                    float gk_l = simd_sum(adj[i] * v_j);
                    if (lane == 0u) {{
                        int p_idx = ((((b * {L} + t) * {H} + head) * {n_w} + w) * {Dh} + i);
                        grad_q_p[p_idx] = gq_l;
                        grad_k_p[p_idx] = gk_l;
                    }}

                    adj[i] *= gate;
                }}

                grad_v[kv_base + j] = gv_j;
                float gg_l = simd_sum(gg);
                if (lane == 0u) {{
                    grad_g_p[(((b * {L} + t) * {H} + head) * {n_w} + w)] = gg_l;
                }}
            }}
        }}
    """

    # The "_gh" tag marks this variant, which takes the extra grad_h input;
    # the shape constants keep each configuration's compiled kernel distinct.
    kernel = get_or_build_kernel(
        f"gla_bwd_gh_{B_batch}_{L}_{H}_{Dh}_{seg}",
        input_names=["grad_y", "h_ckpt", "grad_h", "q", "k", "v", "gates"],
        output_names=["grad_v", "grad_q_p", "grad_k_p", "grad_g_p", "scratch"],
        source=source,
    )

    results = kernel(
        inputs=[
            grad_y.reshape(-1),
            h_ckpt.reshape(-1),
            grad_h_ckpt.reshape(-1),
            q.reshape(-1),
            k.reshape(-1),
            v.reshape(-1),
            gates.reshape(-1),
        ],
        output_shapes=[
            (B_batch * L * H * Dh,),
            (B_batch * L * H * n_w * Dh,),
            (B_batch * L * H * n_w * Dh,),
            (B_batch * L * H * n_w,),
            (B_batch * H * seg * Dh * Dh,),  # scratch, discarded
        ],
        output_dtypes=[mx.float32] * 5,
        grid=(Dh, B_batch * H, 1),
        threadgroup=(min(Dh, 256), 1, 1),
    )

    grad_v = results[0].reshape(B_batch, L, H, Dh)
    grad_q_p = results[1].reshape(B_batch, L, H, n_w, Dh)
    grad_k_p = results[2].reshape(B_batch, L, H, n_w, Dh)
    grad_g_p = results[3].reshape(B_batch, L, H, n_w)

    # Finish the cross-simdgroup part of the reductions in MLX.
    grad_q = mx.sum(grad_q_p, axis=3)  # [B, L, H, Dh]
    grad_k = mx.sum(grad_k_p, axis=3)  # [B, L, H, Dh]
    grad_gates = mx.sum(grad_g_p, axis=3)  # [B, L, H]

    return grad_q, grad_k, grad_v, grad_gates


# ---------------------------------------------------------------------------
# Custom function + VJP (one cached impl per seg, so seg can be a Python arg)
# ---------------------------------------------------------------------------

_impl_cache: dict = {}


def _make_impl(seg):
    """Build (and cache) an ``mx.custom_function`` GLA impl bound to ``seg``.

    The returned function maps ``(q, k, v, gates)`` to ``(y, h_ckpt)``. Its
    VJP propagates the cotangents of both outputs.
    """
    if seg in _impl_cache:
        return _impl_cache[seg]

    @mx.custom_function
    def _impl(q, k, v, gates):
        check_segment_shape(q.shape[1], seg, q.shape[3], "Dh")
        return _gla_forward_kernel(q, k, v, gates, seg)

    @_impl.vjp
    def _vjp(primals, cotangents, outputs):
        q, k, v, gates = primals
        grad_y = cotangents[0]
        # h_ckpt is a genuine output: gla_scan_with_state returns its last
        # slice as the final state. Its cotangent must be propagated too;
        # otherwise any gradient flowing through that state is silently
        # dropped.
        grad_h_ckpt = cotangents[1] if len(cotangents) > 1 else None
        _y, h_ckpt = outputs
        return _gla_backward_kernel(
            grad_y, h_ckpt, q, k, v, gates, seg, grad_h_ckpt
        )

    _impl_cache[seg] = _impl
    return _impl
|
|
299
|
+
|
|
300
|
+
|
|
301
|
+
def gla_scan(q, k, v, gates, seg=DEFAULT_SEG):
    """Run the GLA recurrence with the fused Metal forward and backward.

    Per batch and head, this computes
    ``h_t = gate_t * h_{t-1} + outer(k_t, v_t)`` and ``y_t = q_t @ h_t``,
    i.e. the output reads the post-update state.

    Args:
        q: [B, L, H, Dh] queries (pre-scaled, post-RoPE)
        k: [B, L, H, Dh] keys
        v: [B, L, H, Dh] values
        gates: [B, L, H] scalar forget gate per head, typically in (0, 1)
        seg: checkpoint segment length; L must be divisible by it (default 32)

    Returns:
        y: [B, L, H, Dh]

    ``Dh`` has to be a multiple of 32 because it is the dimension reduced
    across simdgroup lanes. State and accumulation are fp32 internally, so
    bf16 inputs are widened.
    """
    scan = _make_impl(seg)
    # The custom function also yields the segment checkpoints; only y is
    # part of this API.
    y, _checkpoints = scan(q, k, v, gates)
    return y
|
|
319
|
+
|
|
320
|
+
|
|
321
|
+
def gla_scan_with_state(q, k, v, gates, seg=DEFAULT_SEG):
    """Run the GLA scan and also return the state after the last timestep.

    Intended for chunked prefill.

    Args:
        q, k, v: [B, L, H, Dh]
        gates: [B, L, H]
        seg: checkpoint segment length (L % seg == 0)

    Returns:
        y: [B, L, H, Dh]
        final_state: [B, H, Dh, Dh] (matches the GLA conceptual state)
    """
    # NOTE(review): the forward kernel always starts from a zero state, and
    # no entry point in this module accepts an initial state. Confirm how
    # callers seed the next chunk with final_state.
    y, checkpoints = _make_impl(seg)(q, k, v, gates)
    # The checkpoint taken at the end of the last segment is exactly the
    # state after timestep L - 1.
    final_state = checkpoints[:, -1]
    return y, final_state
|
|
330
|
+
|
|
331
|
+
|
|
332
|
+
# ---------------------------------------------------------------------------
|
|
333
|
+
# Pure-MLX reference (slow, for parity testing only)
|
|
334
|
+
# ---------------------------------------------------------------------------
|
|
335
|
+
|
|
336
|
+
def gla_scan_reference(q, k, v, gates):
    """Slow pure-MLX reference for :func:`gla_scan`, stepping one token at a time.

    Built only from differentiable MLX ops, so ``mx.grad`` works through it.
    Meant for parity tests, not production use.

    Args:
        q, k, v: [B, L, H, Dh]
        gates: [B, L, H]

    Returns:
        y: [B, L, H, Dh]
    """
    batch, seq_len, n_heads, head_dim = q.shape
    state = mx.zeros((batch, n_heads, head_dim, head_dim))  # [B,H,Dh,Dh]
    per_step = []
    for step in range(seq_len):
        q_t = q[:, step]  # [B,H,Dh]
        k_t = k[:, step]
        v_t = v[:, step]
        decay = gates[:, step][:, :, None, None]  # [B,H,1,1]
        update = k_t[:, :, :, None] * v_t[:, :, None, :]  # outer(k_t, v_t)
        state = decay * state + update
        # y_t[j] = sum_i q_t[i] * state[i, j]
        per_step.append(mx.sum(q_t[:, :, :, None] * state, axis=-2))
    return mx.stack(per_step, axis=1)  # [B,L,H,Dh]
|
|
@@ -0,0 +1,41 @@
|
|
|
1
|
+
"""mlx_recurrence.legacy — v1 token-loop Metal kernels.
|
|
2
|
+
|
|
3
|
+
These are the original (v0.1) kernels: a Metal forward that materialises
|
|
4
|
+
the full per-timestep state tensor h_all and a backward that reads it back.
|
|
5
|
+
They are kept for backwards compatibility and as a simple, readable
|
|
6
|
+
reference. New work should prefer the v2 chassis-based kernels
|
|
7
|
+
(``mlx_recurrence.ssd``, ``mlx_recurrence.gla``, ``mlx_recurrence.rglru``, ``mlx_recurrence.rotlru``),
|
|
8
|
+
which use segment checkpointing + recompute to slash DRAM traffic.
|
|
9
|
+
|
|
10
|
+
Public API (unchanged from v0.1):
|
|
11
|
+
selective_scan_metal, selective_scan_chunked
|
|
12
|
+
gla_scan_metal, gla_scan_chunked
|
|
13
|
+
"""
|
|
14
|
+
|
|
15
|
+
from .ssm_scan import (
|
|
16
|
+
selective_scan_metal,
|
|
17
|
+
selective_scan_chunked,
|
|
18
|
+
_ssm_forward_kernel,
|
|
19
|
+
_ssm_backward_chunked,
|
|
20
|
+
_ssm_backward_metal,
|
|
21
|
+
)
|
|
22
|
+
from .gla_scan import (
|
|
23
|
+
gla_scan_metal,
|
|
24
|
+
gla_scan_chunked,
|
|
25
|
+
_gla_forward_kernel,
|
|
26
|
+
_gla_backward_chunked,
|
|
27
|
+
_gla_backward_metal,
|
|
28
|
+
)
|
|
29
|
+
|
|
30
|
+
# Exported names, in the original order. The public v0.1 entry points come
# first, followed by the private kernel helpers (presumably listed for
# backwards compatibility; confirm before pruning).
__all__ = (
    ["selective_scan_metal", "selective_scan_chunked"]
    + ["gla_scan_metal", "gla_scan_chunked"]
    + ["_ssm_forward_kernel", "_ssm_backward_chunked", "_ssm_backward_metal"]
    + ["_gla_forward_kernel", "_gla_backward_chunked", "_gla_backward_metal"]
)
|