turboquant-gpu 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- turboquant_gpu/__init__.py +9 -0
- turboquant_gpu/attention.py +386 -0
- turboquant_gpu/codebook.py +70 -0
- turboquant_gpu/compress.py +210 -0
- turboquant_gpu/constants.py +8 -0
- turboquant_gpu/decompress.py +77 -0
- turboquant_gpu/host.py +288 -0
- turboquant_gpu-0.1.0.dist-info/METADATA +96 -0
- turboquant_gpu-0.1.0.dist-info/RECORD +12 -0
- turboquant_gpu-0.1.0.dist-info/WHEEL +5 -0
- turboquant_gpu-0.1.0.dist-info/licenses/LICENSE +21 -0
- turboquant_gpu-0.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,9 @@
|
|
|
1
|
+
__version__ = "0.1.0"
|
|
2
|
+
|
|
3
|
+
def __getattr__(name):
|
|
4
|
+
if name == "TurboQuantEngine":
|
|
5
|
+
from .host import TurboQuantEngine
|
|
6
|
+
return TurboQuantEngine
|
|
7
|
+
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
|
8
|
+
|
|
9
|
+
__all__ = ["__version__", "TurboQuantEngine"]
|
|
@@ -0,0 +1,386 @@
|
|
|
1
|
+
"""
|
|
2
|
+
asymmetric attention kernels for compressed keys.
|
|
3
|
+
|
|
4
|
+
score(q, k) ~ <q, k_mse> + ||r|| * sqrt(pi/2)/m * <S*q, signs>
|
|
5
|
+
|
|
6
|
+
score-only variant outputs raw scores.
|
|
7
|
+
fused variant does online softmax + V accumulation in one pass.
|
|
8
|
+
USE_SWIZZLE enables exp2 ftz, approx div, prefetched loads, and block
|
|
9
|
+
interleaving on architectures that support them.
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
import math
|
|
13
|
+
|
|
14
|
+
try:
|
|
15
|
+
import cuda.tile as ct
|
|
16
|
+
from cuda.tile import RoundingMode as RMd
|
|
17
|
+
_HAS_APPROX = hasattr(RMd, "APPROX")
|
|
18
|
+
except ImportError:
|
|
19
|
+
import cutile as ct # type: ignore
|
|
20
|
+
RMd = None
|
|
21
|
+
_HAS_APPROX = False
|
|
22
|
+
|
|
23
|
+
from .constants import BLOCK_Q, BLOCK_KV, HEAD_DIM
|
|
24
|
+
|
|
25
|
+
INV_LOG_2 = 1.0 / math.log(2)
|
|
26
|
+
ConstBool = ct.Constant[bool]
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
# ── score-only kernel ────────────────────────────────────────────────
|
|
30
|
+
|
|
31
|
+
@ct.kernel(occupancy=2)
def turboquant_attention_scores(
    Q, K_mse, Signs, R_norms, Q_proj, Output,
    scale: float,
    correction_scale: float,
    seq_k: int,
    USE_SWIZZLE: ConstBool,
):
    """one program per query block, streams over all KV blocks.

    Writes raw (pre-softmax) scores:
        Output[q, k] = (<q, k_mse> + correction_scale * <S q, signs> * ||r||) * scale
    NOTE(review): seq_k is not read here — tail handling appears to rely on
    zero padding of the loads; confirm against the host launcher.
    """
    q_block = ct.bid(0)
    zero_pad = ct.PaddingMode.ZERO

    q_tile = ct.load(Q, index=(q_block, 0), shape=(BLOCK_Q, HEAD_DIM),
                     padding_mode=zero_pad)
    qp_tile = ct.load(Q_proj, index=(q_block, 0), shape=(BLOCK_Q, HEAD_DIM),
                      padding_mode=zero_pad)

    num_kv_blocks = ct.num_tiles(K_mse, axis=0, shape=(BLOCK_KV, HEAD_DIM))

    for kv_block in range(num_kv_blocks):
        if USE_SWIZZLE:
            # latency hint lets the next K tile prefetch overlap compute
            k_tile = ct.load(K_mse, index=(kv_block, 0),
                             shape=(BLOCK_KV, HEAD_DIM), padding_mode=zero_pad,
                             latency=2)
        else:
            k_tile = ct.load(K_mse, index=(kv_block, 0),
                             shape=(BLOCK_KV, HEAD_DIM), padding_mode=zero_pad)

        # mse term: Q @ K_mse^T
        term1 = ct.mma(q_tile, ct.transpose(k_tile),
                       ct.zeros((BLOCK_Q, BLOCK_KV), dtype=ct.float32))

        # qjl bias-correction term
        s_tile = ct.load(Signs, index=(kv_block, 0),
                         shape=(BLOCK_KV, HEAD_DIM), padding_mode=zero_pad)
        # int8 signs -> fp16 so they can feed the tensor-core mma
        s_float = ct.astype(s_tile, ct.float16)

        qjl_ip = ct.mma(qp_tile, ct.transpose(s_float),
                        ct.zeros((BLOCK_Q, BLOCK_KV), dtype=ct.float32))

        rn = ct.load(R_norms, index=(kv_block,), shape=(BLOCK_KV,),
                     padding_mode=zero_pad)
        rn_f32 = ct.astype(rn, ct.float32)

        # broadcast per-key residual norms across the query rows
        term2 = correction_scale * qjl_ip * ct.expand_dims(rn_f32, axis=0)

        scores = (term1 + term2) * scale
        ct.store(Output, index=(q_block, kv_block), tile=scores)
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
# ── fused attention (pre-decompressed V) ─────────────────────────────
|
|
82
|
+
|
|
83
|
+
@ct.kernel(occupancy=2)
def turboquant_fused_attention(
    Q, K_mse, Signs, R_norms, Q_proj, V, Output,
    scale: float,
    correction_scale: float,
    seq_k: int,
    USE_SWIZZLE: ConstBool,
):
    """fused attention + online softmax. V is already fp16.

    Computes Output[q, :] = softmax(score(q, :)) @ V in one streaming
    pass over KV blocks using a flash-attention-style running max (m_i),
    running denominator (l_i) and value accumulator (acc).
    """
    zero_pad = ct.PaddingMode.ZERO

    if USE_SWIZZLE:
        # interleave query blocks across the grid (block swizzle) to
        # spread DRAM traffic; permutation is a bijection over blocks
        half = ct.cdiv(ct.num_blocks(0), 2)
        q_block = ct.bid(0) // 2 + (ct.bid(0) % 2) * half
    else:
        q_block = ct.bid(0)

    q_tile = ct.load(Q, index=(q_block, 0), shape=(BLOCK_Q, HEAD_DIM),
                     padding_mode=zero_pad)
    qp_tile = ct.load(Q_proj, index=(q_block, 0), shape=(BLOCK_Q, HEAD_DIM),
                      padding_mode=zero_pad)

    # online softmax accumulators
    m_i = ct.full((BLOCK_Q,), -1e30, dtype=ct.float32)
    l_i = ct.zeros((BLOCK_Q,), dtype=ct.float32)
    acc = ct.zeros((BLOCK_Q, HEAD_DIM), dtype=ct.float32)

    # exp(x*scale) == exp2(x*scale/ln2): fold scale into the exp2 path
    scale_log2 = scale * INV_LOG_2
    num_kv_blocks = ct.num_tiles(K_mse, axis=0, shape=(BLOCK_KV, HEAD_DIM))

    for kv_block in range(num_kv_blocks):
        if USE_SWIZZLE:
            k_tile = ct.load(K_mse, index=(kv_block, 0),
                             shape=(BLOCK_KV, HEAD_DIM), padding_mode=zero_pad,
                             latency=2)
            v_tile = ct.load(V, index=(kv_block, 0),
                             shape=(BLOCK_KV, HEAD_DIM), padding_mode=zero_pad,
                             latency=4)
        else:
            k_tile = ct.load(K_mse, index=(kv_block, 0),
                             shape=(BLOCK_KV, HEAD_DIM), padding_mode=zero_pad)
            v_tile = ct.load(V, index=(kv_block, 0),
                             shape=(BLOCK_KV, HEAD_DIM), padding_mode=zero_pad)

        s_tile = ct.load(Signs, index=(kv_block, 0),
                         shape=(BLOCK_KV, HEAD_DIM), padding_mode=zero_pad)
        rn = ct.load(R_norms, index=(kv_block,), shape=(BLOCK_KV,),
                     padding_mode=zero_pad)

        # mse term: Q @ K_mse^T
        term1 = ct.mma(q_tile, ct.transpose(k_tile),
                       ct.zeros((BLOCK_Q, BLOCK_KV), dtype=ct.float32))

        # qjl bias-correction term, scaled by per-key residual norms
        s_float = ct.astype(s_tile, ct.float16)
        qjl_ip = ct.mma(qp_tile, ct.transpose(s_float),
                        ct.zeros((BLOCK_Q, BLOCK_KV), dtype=ct.float32))
        rn_f32 = ct.astype(rn, ct.float32)
        term2 = correction_scale * qjl_ip * ct.expand_dims(rn_f32, axis=0)

        raw_scores = term1 + term2

        # online softmax update
        if USE_SWIZZLE:
            # base-2 fast path: m_new/alpha/p are kept in log2 domain
            m_new = ct.maximum(m_i, ct.max(raw_scores, axis=1) * scale_log2)
            alpha = ct.exp2(m_i - m_new, flush_to_zero=True)
            p = ct.exp2(raw_scores * scale_log2 - ct.expand_dims(m_new, axis=1),
                        flush_to_zero=True)
        else:
            scores = raw_scores * scale
            m_new = ct.maximum(m_i, ct.max(scores, axis=1))
            alpha = ct.exp(m_i - m_new)
            p = ct.exp(scores - ct.expand_dims(m_new, axis=1))

        # rescale running denominator by alpha, then add this block's mass
        l_i = alpha * l_i + ct.sum(p, axis=1)

        p_fp16 = ct.astype(p, ct.float16)
        acc = ct.expand_dims(alpha, axis=1) * acc + ct.mma(
            p_fp16, v_tile, ct.zeros((BLOCK_Q, HEAD_DIM), dtype=ct.float32))

        m_i = m_new

    # final normalization; approx-rounded division where supported
    if USE_SWIZZLE and _HAS_APPROX:
        out = ct.truediv(acc, ct.expand_dims(l_i, axis=1),
                         flush_to_zero=True, rounding_mode=RMd.APPROX)
    else:
        out = acc / ct.expand_dims(l_i, axis=1)

    ct.store(Output, index=(q_block, 0), tile=out)
|
|
170
|
+
|
|
171
|
+
|
|
172
|
+
# ── fused attention (on-chip 3-bit V decompression) ──────────────────
|
|
173
|
+
|
|
174
|
+
@ct.kernel(occupancy=2)
def turboquant_fused_attention_vfused_3bit(
    Q, K_mse, Signs, R_norms, Q_proj,
    V_Indices, V_Norms, Pi, Output,
    scale: float,
    correction_scale: float,
    seq_k: int,
    vc0: float, vc1: float, vc2: float, vc3: float,
    vc4: float, vc5: float, vc6: float, vc7: float,
    USE_SWIZZLE: ConstBool,
):
    """fused attention that decompresses 3-bit V on-chip per block.

    Same online-softmax loop as turboquant_fused_attention, but the V
    tile is reconstructed inside the loop: centroid lookup from the 8
    scalar levels vc0..vc7, un-rotation through Pi, then rescale by the
    stored per-row norms.
    """
    zero_pad = ct.PaddingMode.ZERO

    if USE_SWIZZLE:
        # interleave query blocks across the grid (block swizzle)
        half = ct.cdiv(ct.num_blocks(0), 2)
        q_block = ct.bid(0) // 2 + (ct.bid(0) % 2) * half
    else:
        q_block = ct.bid(0)

    q_tile = ct.load(Q, index=(q_block, 0), shape=(BLOCK_Q, HEAD_DIM),
                     padding_mode=zero_pad)
    qp_tile = ct.load(Q_proj, index=(q_block, 0), shape=(BLOCK_Q, HEAD_DIM),
                      padding_mode=zero_pad)

    # rotation matrix is loop-invariant: load it once
    pi_tile = ct.load(Pi, index=(0, 0), shape=(HEAD_DIM, HEAD_DIM))

    # online softmax accumulators
    m_i = ct.full((BLOCK_Q,), -1e30, dtype=ct.float32)
    l_i = ct.zeros((BLOCK_Q,), dtype=ct.float32)
    acc = ct.zeros((BLOCK_Q, HEAD_DIM), dtype=ct.float32)

    scale_log2 = scale * INV_LOG_2
    num_kv_blocks = ct.num_tiles(K_mse, axis=0, shape=(BLOCK_KV, HEAD_DIM))

    for kv_block in range(num_kv_blocks):
        if USE_SWIZZLE:
            k_tile = ct.load(K_mse, index=(kv_block, 0),
                             shape=(BLOCK_KV, HEAD_DIM), padding_mode=zero_pad,
                             latency=2)
        else:
            k_tile = ct.load(K_mse, index=(kv_block, 0),
                             shape=(BLOCK_KV, HEAD_DIM), padding_mode=zero_pad)

        # decompress V tile on-chip
        v_idx = ct.load(V_Indices, index=(kv_block, 0),
                        shape=(BLOCK_KV, HEAD_DIM), padding_mode=zero_pad)
        v_nrm = ct.load(V_Norms, index=(kv_block,),
                        shape=(BLOCK_KV,), padding_mode=zero_pad)

        # centroid lookup via a cascade of half-open threshold compares
        # (index k maps to vck; thresholds at k + 0.5)
        vi_f32 = ct.astype(v_idx, ct.float32)
        y_hat = ct.full((BLOCK_KV, HEAD_DIM), vc0, dtype=ct.float32)
        y_hat = ct.where(vi_f32 > 0.5, vc1, y_hat)
        y_hat = ct.where(vi_f32 > 1.5, vc2, y_hat)
        y_hat = ct.where(vi_f32 > 2.5, vc3, y_hat)
        y_hat = ct.where(vi_f32 > 3.5, vc4, y_hat)
        y_hat = ct.where(vi_f32 > 4.5, vc5, y_hat)
        y_hat = ct.where(vi_f32 > 5.5, vc6, y_hat)
        y_hat = ct.where(vi_f32 > 6.5, vc7, y_hat)

        # un-rotate and restore per-row magnitude
        v_recon = ct.mma(ct.astype(y_hat, ct.float16), pi_tile,
                         ct.zeros((BLOCK_KV, HEAD_DIM), dtype=ct.float32))
        v_nrm_f32 = ct.astype(v_nrm, ct.float32)
        v_tile = ct.astype(
            v_recon * ct.expand_dims(v_nrm_f32, axis=1), ct.float16)

        # attention scores
        s_tile = ct.load(Signs, index=(kv_block, 0),
                         shape=(BLOCK_KV, HEAD_DIM), padding_mode=zero_pad)
        rn = ct.load(R_norms, index=(kv_block,), shape=(BLOCK_KV,),
                     padding_mode=zero_pad)

        # mse term: Q @ K_mse^T
        term1 = ct.mma(q_tile, ct.transpose(k_tile),
                       ct.zeros((BLOCK_Q, BLOCK_KV), dtype=ct.float32))

        # qjl bias-correction term
        s_float = ct.astype(s_tile, ct.float16)
        qjl_ip = ct.mma(qp_tile, ct.transpose(s_float),
                        ct.zeros((BLOCK_Q, BLOCK_KV), dtype=ct.float32))
        rn_f32 = ct.astype(rn, ct.float32)
        term2 = correction_scale * qjl_ip * ct.expand_dims(rn_f32, axis=0)

        raw_scores = term1 + term2

        # online softmax update (exp2 fast path under USE_SWIZZLE)
        if USE_SWIZZLE:
            m_new = ct.maximum(m_i, ct.max(raw_scores, axis=1) * scale_log2)
            alpha = ct.exp2(m_i - m_new, flush_to_zero=True)
            p = ct.exp2(raw_scores * scale_log2 - ct.expand_dims(m_new, axis=1),
                        flush_to_zero=True)
        else:
            scores = raw_scores * scale
            m_new = ct.maximum(m_i, ct.max(scores, axis=1))
            alpha = ct.exp(m_i - m_new)
            p = ct.exp(scores - ct.expand_dims(m_new, axis=1))

        l_i = alpha * l_i + ct.sum(p, axis=1)

        p_fp16 = ct.astype(p, ct.float16)
        acc = ct.expand_dims(alpha, axis=1) * acc + ct.mma(
            p_fp16, v_tile, ct.zeros((BLOCK_Q, HEAD_DIM), dtype=ct.float32))

        m_i = m_new

    # final normalization; approx-rounded division where supported
    if USE_SWIZZLE and _HAS_APPROX:
        out = ct.truediv(acc, ct.expand_dims(l_i, axis=1),
                         flush_to_zero=True, rounding_mode=RMd.APPROX)
    else:
        out = acc / ct.expand_dims(l_i, axis=1)

    ct.store(Output, index=(q_block, 0), tile=out)
|
|
282
|
+
|
|
283
|
+
|
|
284
|
+
# ── fused attention (on-chip 2-bit V decompression) ──────────────────
|
|
285
|
+
|
|
286
|
+
@ct.kernel(occupancy=2)
def turboquant_fused_attention_vfused_2bit(
    Q, K_mse, Signs, R_norms, Q_proj,
    V_Indices, V_Norms, Pi, Output,
    scale: float,
    correction_scale: float,
    seq_k: int,
    vc0: float, vc1: float, vc2: float, vc3: float,
    USE_SWIZZLE: ConstBool,
):
    """fused attention that decompresses 2-bit V on-chip per block.

    Identical structure to the 3-bit variant but with a 4-level
    codebook (vc0..vc3) for the V reconstruction.
    """
    zero_pad = ct.PaddingMode.ZERO

    if USE_SWIZZLE:
        # interleave query blocks across the grid (block swizzle)
        half = ct.cdiv(ct.num_blocks(0), 2)
        q_block = ct.bid(0) // 2 + (ct.bid(0) % 2) * half
    else:
        q_block = ct.bid(0)

    q_tile = ct.load(Q, index=(q_block, 0), shape=(BLOCK_Q, HEAD_DIM),
                     padding_mode=zero_pad)
    qp_tile = ct.load(Q_proj, index=(q_block, 0), shape=(BLOCK_Q, HEAD_DIM),
                      padding_mode=zero_pad)

    # rotation matrix is loop-invariant: load it once
    pi_tile = ct.load(Pi, index=(0, 0), shape=(HEAD_DIM, HEAD_DIM))

    # online softmax accumulators
    m_i = ct.full((BLOCK_Q,), -1e30, dtype=ct.float32)
    l_i = ct.zeros((BLOCK_Q,), dtype=ct.float32)
    acc = ct.zeros((BLOCK_Q, HEAD_DIM), dtype=ct.float32)

    scale_log2 = scale * INV_LOG_2
    num_kv_blocks = ct.num_tiles(K_mse, axis=0, shape=(BLOCK_KV, HEAD_DIM))

    for kv_block in range(num_kv_blocks):
        if USE_SWIZZLE:
            k_tile = ct.load(K_mse, index=(kv_block, 0),
                             shape=(BLOCK_KV, HEAD_DIM), padding_mode=zero_pad,
                             latency=2)
        else:
            k_tile = ct.load(K_mse, index=(kv_block, 0),
                             shape=(BLOCK_KV, HEAD_DIM), padding_mode=zero_pad)

        # decompress V tile on-chip (4-level centroid lookup)
        v_idx = ct.load(V_Indices, index=(kv_block, 0),
                        shape=(BLOCK_KV, HEAD_DIM), padding_mode=zero_pad)
        v_nrm = ct.load(V_Norms, index=(kv_block,),
                        shape=(BLOCK_KV,), padding_mode=zero_pad)

        vi_f32 = ct.astype(v_idx, ct.float32)
        y_hat = ct.full((BLOCK_KV, HEAD_DIM), vc0, dtype=ct.float32)
        y_hat = ct.where(vi_f32 > 0.5, vc1, y_hat)
        y_hat = ct.where(vi_f32 > 1.5, vc2, y_hat)
        y_hat = ct.where(vi_f32 > 2.5, vc3, y_hat)

        # un-rotate and restore per-row magnitude
        v_recon = ct.mma(ct.astype(y_hat, ct.float16), pi_tile,
                         ct.zeros((BLOCK_KV, HEAD_DIM), dtype=ct.float32))
        v_nrm_f32 = ct.astype(v_nrm, ct.float32)
        v_tile = ct.astype(
            v_recon * ct.expand_dims(v_nrm_f32, axis=1), ct.float16)

        s_tile = ct.load(Signs, index=(kv_block, 0),
                         shape=(BLOCK_KV, HEAD_DIM), padding_mode=zero_pad)
        rn = ct.load(R_norms, index=(kv_block,), shape=(BLOCK_KV,),
                     padding_mode=zero_pad)

        # mse term: Q @ K_mse^T
        term1 = ct.mma(q_tile, ct.transpose(k_tile),
                       ct.zeros((BLOCK_Q, BLOCK_KV), dtype=ct.float32))

        # qjl bias-correction term
        s_float = ct.astype(s_tile, ct.float16)
        qjl_ip = ct.mma(qp_tile, ct.transpose(s_float),
                        ct.zeros((BLOCK_Q, BLOCK_KV), dtype=ct.float32))
        rn_f32 = ct.astype(rn, ct.float32)
        term2 = correction_scale * qjl_ip * ct.expand_dims(rn_f32, axis=0)

        raw_scores = term1 + term2

        # online softmax update (exp2 fast path under USE_SWIZZLE)
        if USE_SWIZZLE:
            m_new = ct.maximum(m_i, ct.max(raw_scores, axis=1) * scale_log2)
            alpha = ct.exp2(m_i - m_new, flush_to_zero=True)
            p = ct.exp2(raw_scores * scale_log2 - ct.expand_dims(m_new, axis=1),
                        flush_to_zero=True)
        else:
            scores = raw_scores * scale
            m_new = ct.maximum(m_i, ct.max(scores, axis=1))
            alpha = ct.exp(m_i - m_new)
            p = ct.exp(scores - ct.expand_dims(m_new, axis=1))

        l_i = alpha * l_i + ct.sum(p, axis=1)

        p_fp16 = ct.astype(p, ct.float16)
        acc = ct.expand_dims(alpha, axis=1) * acc + ct.mma(
            p_fp16, v_tile, ct.zeros((BLOCK_Q, HEAD_DIM), dtype=ct.float32))

        m_i = m_new

    # final normalization; approx-rounded division where supported
    if USE_SWIZZLE and _HAS_APPROX:
        out = ct.truediv(acc, ct.expand_dims(l_i, axis=1),
                         flush_to_zero=True, rounding_mode=RMd.APPROX)
    else:
        out = acc / ct.expand_dims(l_i, axis=1)

    ct.store(Output, index=(q_block, 0), tile=out)
|
|
@@ -0,0 +1,70 @@
|
|
|
1
|
+
"""
|
|
2
|
+
lloyd-max optimal scalar quantizer.
|
|
3
|
+
|
|
4
|
+
after rotation each coordinate ~ N(0, 1/d). we solve lloyd-max
|
|
5
|
+
(1d k-means against that pdf) to get optimal centroids + boundaries.
|
|
6
|
+
runs once on cpu at init — the codebook is tiny.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
import math
|
|
10
|
+
import torch
|
|
11
|
+
from scipy import integrate
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def _gaussian_pdf(x, sigma):
|
|
15
|
+
return (1.0 / (math.sqrt(2 * math.pi) * sigma)) * math.exp(
|
|
16
|
+
-x * x / (2 * sigma * sigma)
|
|
17
|
+
)
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def solve_lloyd_max(d, bits, max_iter=200, tol=1e-10):
    """Return (centroids, boundaries) as sorted float32 tensors.

    Solves the Lloyd-Max quantizer (1-D k-means against a continuous
    pdf) for a single coordinate ~ N(0, 1/d), alternating nearest-point
    boundary updates with conditional-mean centroid updates.

    Args:
        d: dimensionality; after rotation each coordinate ~ N(0, 1/d).
        bits: quantizer width; produces 2**bits centroid levels.
        max_iter: maximum Lloyd iterations.
        tol: convergence threshold on the largest centroid movement.

    Returns:
        Tuple of torch.float32 tensors: centroids (2**bits values,
        ascending) and the midpoint boundaries between them
        (2**bits - 1 values).
    """
    n_levels = 1 << bits
    sigma = 1.0 / math.sqrt(d)

    def pdf(x):
        # N(0, sigma^2) density, inlined so the solver is self-contained
        return math.exp(-x * x / (2 * sigma * sigma)) / (
            math.sqrt(2 * math.pi) * sigma
        )

    # uniform initialization over +/- 3.5 sigma
    lo, hi = -3.5 * sigma, 3.5 * sigma
    centroids = [lo + (hi - lo) * (i + 0.5) / n_levels for i in range(n_levels)]

    for _ in range(max_iter):
        boundaries = [
            (centroids[i] + centroids[i + 1]) / 2.0 for i in range(n_levels - 1)
        ]
        # outer edges pushed to +/- 10.5 sigma so the tails are
        # effectively fully integrated
        edges = [lo * 3] + boundaries + [hi * 3]
        new_centroids = []
        for i in range(n_levels):
            a, b = edges[i], edges[i + 1]
            num, _ = integrate.quad(lambda x: x * pdf(x), a, b)
            den, _ = integrate.quad(pdf, a, b)
            # guard against an (effectively) empty cell
            new_centroids.append(num / den if den > 1e-15 else centroids[i])
        shift = max(
            abs(new_centroids[i] - centroids[i]) for i in range(n_levels)
        )
        # bug fix: commit the update BEFORE the convergence check —
        # the original broke out first, discarding the final (converged)
        # centroid update, so the returned values were one step stale.
        centroids = new_centroids
        if shift < tol:
            break

    boundaries = [
        (centroids[i] + centroids[i + 1]) / 2.0 for i in range(n_levels - 1)
    ]
    return (
        torch.tensor(centroids, dtype=torch.float32),
        torch.tensor(boundaries, dtype=torch.float32),
    )
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
class LloydMaxCodebook:
    """pre-solved lloyd-max codebook for a given (d, bits) pair."""

    def __init__(self, d, bits):
        # solve once up front; the codebook is tiny (2**bits scalars)
        self.d = d
        self.bits = bits
        self.n_levels = 1 << bits
        self.centroids, self.boundaries = solve_lloyd_max(d, bits)

    def quantize(self, x):
        """Map each element of ``x`` to the index of its nearest centroid."""
        table = self.centroids.to(x.device)
        distances = (x.unsqueeze(-1) - table).abs()
        return distances.argmin(dim=-1)

    def dequantize(self, indices):
        """Look up the centroid value for each quantization index."""
        table = self.centroids.to(indices.device)
        return table[indices.long()]

    def __repr__(self):
        return f"LloydMaxCodebook(d={self.d}, bits={self.bits}, levels={self.n_levels})"
|
|
@@ -0,0 +1,210 @@
|
|
|
1
|
+
"""
|
|
2
|
+
key and value compression kernels.
|
|
3
|
+
|
|
4
|
+
pipeline per key token:
|
|
5
|
+
normalize → rotate (Pi^T) → lloyd-max quantize → store indices + norms
|
|
6
|
+
residual → QJL project (S^T) → store signs + residual norms
|
|
7
|
+
|
|
8
|
+
value compression is the same minus the QJL step.
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
try:
|
|
12
|
+
import cuda.tile as ct
|
|
13
|
+
except ImportError:
|
|
14
|
+
import cutile as ct # type: ignore
|
|
15
|
+
|
|
16
|
+
from .constants import BLOCK_S, HEAD_DIM
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
# ── key compression ──────────────────────────────────────────────────
|
|
20
|
+
|
|
21
|
+
@ct.kernel
def turboquant_compress_2bit(
    K, Pi_T, Pi, S_T,
    Indices, Signs, Norms, RNorms,
    c0: float, c1: float, c2: float, c3: float,
    b1: float, b2: float, b3: float,
    seq_k: int,
):
    """2-bit mse (4 centroids) + 1-bit qjl. total_bits=3.

    Per key row: normalize, rotate by Pi^T, assign to one of the 4
    Lloyd-Max levels (c0..c3, decision boundaries b1..b3), reconstruct,
    then project the reconstruction residual through S^T and keep only
    its signs + norm for the bias-correction term at attention time.
    Outputs: Indices (uint8), Signs (int8), Norms/RNorms (fp16).
    """
    block_id = ct.bid(0)
    zero_pad = ct.PaddingMode.ZERO

    k_tile = ct.load(K, index=(block_id, 0), shape=(BLOCK_S, HEAD_DIM),
                     padding_mode=zero_pad)
    pi_t = ct.load(Pi_T, index=(0, 0), shape=(HEAD_DIM, HEAD_DIM))
    pi = ct.load(Pi, index=(0, 0), shape=(HEAD_DIM, HEAD_DIM))
    s_t = ct.load(S_T, index=(0, 0), shape=(HEAD_DIM, HEAD_DIM))

    k_f32 = ct.astype(k_tile, ct.float32)
    norms = ct.sqrt(ct.sum(k_f32 * k_f32, axis=1))
    # clamp tiny norms so the division below cannot blow up
    safe_norms = ct.where(norms > 1e-8, norms, 1e-8)
    k_normed = k_f32 / ct.expand_dims(safe_norms, axis=1)

    # rotate into quantization basis
    y = ct.mma(ct.astype(k_normed, ct.float16), pi_t,
               ct.zeros((BLOCK_S, HEAD_DIM), dtype=ct.float32))

    # lloyd-max assignment via boundary comparisons
    idx = ct.zeros((BLOCK_S, HEAD_DIM), dtype=ct.float32)
    idx = ct.where(y > b1, 1.0, idx)
    idx = ct.where(y > b2, 2.0, idx)
    idx = ct.where(y > b3, 3.0, idx)

    # dequantize and un-rotate for mse reconstruction
    y_hat = ct.full((BLOCK_S, HEAD_DIM), c0, dtype=ct.float32)
    y_hat = ct.where(idx > 0.5, c1, y_hat)
    y_hat = ct.where(idx > 1.5, c2, y_hat)
    y_hat = ct.where(idx > 2.5, c3, y_hat)

    k_bar_hat = ct.mma(ct.astype(y_hat, ct.float16), pi,
                       ct.zeros((BLOCK_S, HEAD_DIM), dtype=ct.float32))
    k_mse = k_bar_hat * ct.expand_dims(norms, axis=1)

    # qjl: project residual for bias-correction signs
    residual = k_f32 - k_mse
    r_norms = ct.sqrt(ct.sum(residual * residual, axis=1))
    projected = ct.mma(ct.astype(residual, ct.float16), s_t,
                       ct.zeros((BLOCK_S, HEAD_DIM), dtype=ct.float32))
    signs = ct.where(projected >= 0.0, 1.0, -1.0)

    ct.store(Indices, index=(block_id, 0), tile=ct.astype(idx, ct.uint8))
    ct.store(Signs, index=(block_id, 0), tile=ct.astype(signs, ct.int8))
    ct.store(Norms, index=(block_id,), tile=ct.astype(norms, ct.float16))
    ct.store(RNorms, index=(block_id,), tile=ct.astype(r_norms, ct.float16))
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
@ct.kernel
def turboquant_compress_3bit(
    K, Pi_T, Pi, S_T,
    Indices, Signs, Norms, RNorms,
    c0: float, c1: float, c2: float, c3: float,
    c4: float, c5: float, c6: float, c7: float,
    b1: float, b2: float, b3: float, b4: float,
    b5: float, b6: float, b7: float,
    seq_k: int,
):
    """3-bit mse (8 centroids) + 1-bit qjl. total_bits=4.

    Same pipeline as the 2-bit key compressor but with an 8-level
    Lloyd-Max codebook (c0..c7, boundaries b1..b7).
    """
    block_id = ct.bid(0)
    zero_pad = ct.PaddingMode.ZERO

    k_tile = ct.load(K, index=(block_id, 0), shape=(BLOCK_S, HEAD_DIM),
                     padding_mode=zero_pad)
    pi_t = ct.load(Pi_T, index=(0, 0), shape=(HEAD_DIM, HEAD_DIM))
    pi = ct.load(Pi, index=(0, 0), shape=(HEAD_DIM, HEAD_DIM))
    s_t = ct.load(S_T, index=(0, 0), shape=(HEAD_DIM, HEAD_DIM))

    k_f32 = ct.astype(k_tile, ct.float32)
    norms = ct.sqrt(ct.sum(k_f32 * k_f32, axis=1))
    # clamp tiny norms so the division below cannot blow up
    safe_norms = ct.where(norms > 1e-8, norms, 1e-8)
    k_normed = k_f32 / ct.expand_dims(safe_norms, axis=1)

    # rotate into quantization basis
    y = ct.mma(ct.astype(k_normed, ct.float16), pi_t,
               ct.zeros((BLOCK_S, HEAD_DIM), dtype=ct.float32))

    # lloyd-max assignment via boundary comparisons (ascending cascade)
    idx = ct.zeros((BLOCK_S, HEAD_DIM), dtype=ct.float32)
    idx = ct.where(y > b1, 1.0, idx)
    idx = ct.where(y > b2, 2.0, idx)
    idx = ct.where(y > b3, 3.0, idx)
    idx = ct.where(y > b4, 4.0, idx)
    idx = ct.where(y > b5, 5.0, idx)
    idx = ct.where(y > b6, 6.0, idx)
    idx = ct.where(y > b7, 7.0, idx)

    # dequantize and un-rotate for mse reconstruction
    y_hat = ct.full((BLOCK_S, HEAD_DIM), c0, dtype=ct.float32)
    y_hat = ct.where(idx > 0.5, c1, y_hat)
    y_hat = ct.where(idx > 1.5, c2, y_hat)
    y_hat = ct.where(idx > 2.5, c3, y_hat)
    y_hat = ct.where(idx > 3.5, c4, y_hat)
    y_hat = ct.where(idx > 4.5, c5, y_hat)
    y_hat = ct.where(idx > 5.5, c6, y_hat)
    y_hat = ct.where(idx > 6.5, c7, y_hat)

    k_bar_hat = ct.mma(ct.astype(y_hat, ct.float16), pi,
                       ct.zeros((BLOCK_S, HEAD_DIM), dtype=ct.float32))
    k_mse = k_bar_hat * ct.expand_dims(norms, axis=1)

    # qjl: project residual for bias-correction signs
    residual = k_f32 - k_mse
    r_norms = ct.sqrt(ct.sum(residual * residual, axis=1))
    projected = ct.mma(ct.astype(residual, ct.float16), s_t,
                       ct.zeros((BLOCK_S, HEAD_DIM), dtype=ct.float32))
    signs = ct.where(projected >= 0.0, 1.0, -1.0)

    ct.store(Indices, index=(block_id, 0), tile=ct.astype(idx, ct.uint8))
    ct.store(Signs, index=(block_id, 0), tile=ct.astype(signs, ct.int8))
    ct.store(Norms, index=(block_id,), tile=ct.astype(norms, ct.float16))
    ct.store(RNorms, index=(block_id,), tile=ct.astype(r_norms, ct.float16))
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
# ── value compression (mse only, no qjl) ────────────────────────────
|
|
140
|
+
|
|
141
|
+
@ct.kernel
def turboquant_compress_values_3bit(
    V, Pi_T,
    Indices, Norms,
    c0: float, c1: float, c2: float, c3: float,
    c4: float, c5: float, c6: float, c7: float,
    b1: float, b2: float, b3: float, b4: float,
    b5: float, b6: float, b7: float,
    seq_v: int,
):
    """3-bit value compression (8 levels).

    Normalize -> rotate by Pi^T -> Lloyd-Max assign against boundaries
    b1..b7. Unlike key compression there is no QJL residual step; only
    the level indices (uint8) and per-row norms (fp16) are stored.
    NOTE(review): centroids c0..c7 are accepted but unused here — the
    decompressor applies them; boundaries alone fix the assignment.
    """
    block_id = ct.bid(0)
    zero_pad = ct.PaddingMode.ZERO

    v_tile = ct.load(V, index=(block_id, 0), shape=(BLOCK_S, HEAD_DIM),
                     padding_mode=zero_pad)
    pi_t = ct.load(Pi_T, index=(0, 0), shape=(HEAD_DIM, HEAD_DIM))

    v_f32 = ct.astype(v_tile, ct.float32)
    norms = ct.sqrt(ct.sum(v_f32 * v_f32, axis=1))
    # clamp tiny norms so the division below cannot blow up
    safe_norms = ct.where(norms > 1e-8, norms, 1e-8)
    v_normed = v_f32 / ct.expand_dims(safe_norms, axis=1)

    # rotate into quantization basis
    y = ct.mma(ct.astype(v_normed, ct.float16), pi_t,
               ct.zeros((BLOCK_S, HEAD_DIM), dtype=ct.float32))

    # lloyd-max assignment via boundary comparisons
    idx = ct.zeros((BLOCK_S, HEAD_DIM), dtype=ct.float32)
    idx = ct.where(y > b1, 1.0, idx)
    idx = ct.where(y > b2, 2.0, idx)
    idx = ct.where(y > b3, 3.0, idx)
    idx = ct.where(y > b4, 4.0, idx)
    idx = ct.where(y > b5, 5.0, idx)
    idx = ct.where(y > b6, 6.0, idx)
    idx = ct.where(y > b7, 7.0, idx)

    ct.store(Indices, index=(block_id, 0), tile=ct.astype(idx, ct.uint8))
    ct.store(Norms, index=(block_id,), tile=ct.astype(norms, ct.float16))
|
|
178
|
+
|
|
179
|
+
|
|
180
|
+
@ct.kernel
def turboquant_compress_values_2bit(
    V, Pi_T,
    Indices, Norms,
    c0: float, c1: float, c2: float, c3: float,
    b1: float, b2: float, b3: float,
    seq_v: int,
):
    """2-bit value compression (4 levels).

    Same as the 3-bit value compressor with a 4-level codebook.
    NOTE(review): centroids c0..c3 are accepted but unused here — only
    the boundaries b1..b3 determine the stored indices.
    """
    block_id = ct.bid(0)
    zero_pad = ct.PaddingMode.ZERO

    v_tile = ct.load(V, index=(block_id, 0), shape=(BLOCK_S, HEAD_DIM),
                     padding_mode=zero_pad)
    pi_t = ct.load(Pi_T, index=(0, 0), shape=(HEAD_DIM, HEAD_DIM))

    v_f32 = ct.astype(v_tile, ct.float32)
    norms = ct.sqrt(ct.sum(v_f32 * v_f32, axis=1))
    # clamp tiny norms so the division below cannot blow up
    safe_norms = ct.where(norms > 1e-8, norms, 1e-8)
    v_normed = v_f32 / ct.expand_dims(safe_norms, axis=1)

    # rotate into quantization basis
    y = ct.mma(ct.astype(v_normed, ct.float16), pi_t,
               ct.zeros((BLOCK_S, HEAD_DIM), dtype=ct.float32))

    # lloyd-max assignment via boundary comparisons
    idx = ct.zeros((BLOCK_S, HEAD_DIM), dtype=ct.float32)
    idx = ct.where(y > b1, 1.0, idx)
    idx = ct.where(y > b2, 2.0, idx)
    idx = ct.where(y > b3, 3.0, idx)

    ct.store(Indices, index=(block_id, 0), tile=ct.astype(idx, ct.uint8))
    ct.store(Norms, index=(block_id,), tile=ct.astype(norms, ct.float16))
|
|
@@ -0,0 +1,77 @@
|
|
|
1
|
+
"""
|
|
2
|
+
value decompression kernels.
|
|
3
|
+
|
|
4
|
+
indices → dequantize via centroids → un-rotate (Pi) → scale by norms.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
try:
|
|
8
|
+
import cuda.tile as ct
|
|
9
|
+
except ImportError:
|
|
10
|
+
import cutile as ct # type: ignore
|
|
11
|
+
|
|
12
|
+
from .constants import BLOCK_S, HEAD_DIM
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
@ct.kernel
def turboquant_decompress_3bit(
    Indices, Norms, Pi,
    Output,
    c0: float, c1: float, c2: float, c3: float,
    c4: float, c5: float, c6: float, c7: float,
    seq_k: int,
):
    """Reconstruct fp16 rows from 3-bit indices.

    indices -> centroid lookup (c0..c7) -> un-rotate through Pi ->
    rescale by the stored per-row norms -> store fp16.
    NOTE(review): seq_k is not read; the tail appears to be handled by
    zero-padded loads — confirm against the host launcher.
    """
    block_id = ct.bid(0)
    zero_pad = ct.PaddingMode.ZERO

    idx_tile = ct.load(Indices, index=(block_id, 0), shape=(BLOCK_S, HEAD_DIM),
                       padding_mode=zero_pad)
    norm_tile = ct.load(Norms, index=(block_id,), shape=(BLOCK_S,),
                        padding_mode=zero_pad)
    pi = ct.load(Pi, index=(0, 0), shape=(HEAD_DIM, HEAD_DIM))

    idx_f32 = ct.astype(idx_tile, ct.float32)

    # centroid lookup via threshold cascade (index k -> ck, cuts at k + 0.5)
    y_hat = ct.full((BLOCK_S, HEAD_DIM), c0, dtype=ct.float32)
    y_hat = ct.where(idx_f32 > 0.5, c1, y_hat)
    y_hat = ct.where(idx_f32 > 1.5, c2, y_hat)
    y_hat = ct.where(idx_f32 > 2.5, c3, y_hat)
    y_hat = ct.where(idx_f32 > 3.5, c4, y_hat)
    y_hat = ct.where(idx_f32 > 4.5, c5, y_hat)
    y_hat = ct.where(idx_f32 > 5.5, c6, y_hat)
    y_hat = ct.where(idx_f32 > 6.5, c7, y_hat)

    # un-rotate and restore per-row magnitude
    x_hat = ct.mma(ct.astype(y_hat, ct.float16), pi,
                   ct.zeros((BLOCK_S, HEAD_DIM), dtype=ct.float32))
    result = x_hat * ct.expand_dims(ct.astype(norm_tile, ct.float32), axis=1)

    ct.store(Output, index=(block_id, 0), tile=ct.astype(result, ct.float16))
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
@ct.kernel
def turboquant_decompress_2bit(
    Indices, Norms, Pi,
    Output,
    c0: float, c1: float, c2: float, c3: float,
    seq_k: int,
):
    """Reconstruct fp16 rows from 2-bit indices.

    Same pipeline as the 3-bit decompressor with a 4-level codebook:
    centroid lookup (c0..c3) -> un-rotate through Pi -> rescale by the
    stored per-row norms -> store fp16.
    """
    block_id = ct.bid(0)
    zero_pad = ct.PaddingMode.ZERO

    idx_tile = ct.load(Indices, index=(block_id, 0), shape=(BLOCK_S, HEAD_DIM),
                       padding_mode=zero_pad)
    norm_tile = ct.load(Norms, index=(block_id,), shape=(BLOCK_S,),
                        padding_mode=zero_pad)
    pi = ct.load(Pi, index=(0, 0), shape=(HEAD_DIM, HEAD_DIM))

    idx_f32 = ct.astype(idx_tile, ct.float32)

    # centroid lookup via threshold cascade (index k -> ck, cuts at k + 0.5)
    y_hat = ct.full((BLOCK_S, HEAD_DIM), c0, dtype=ct.float32)
    y_hat = ct.where(idx_f32 > 0.5, c1, y_hat)
    y_hat = ct.where(idx_f32 > 1.5, c2, y_hat)
    y_hat = ct.where(idx_f32 > 2.5, c3, y_hat)

    # un-rotate and restore per-row magnitude
    x_hat = ct.mma(ct.astype(y_hat, ct.float16), pi,
                   ct.zeros((BLOCK_S, HEAD_DIM), dtype=ct.float32))
    result = x_hat * ct.expand_dims(ct.astype(norm_tile, ct.float32), axis=1)

    ct.store(Output, index=(block_id, 0), tile=ct.astype(result, ct.float16))
|
turboquant_gpu/host.py
ADDED
|
@@ -0,0 +1,288 @@
|
|
|
1
|
+
"""
|
|
2
|
+
host-side engine: compress / decompress KV caches, run generation.
|
|
3
|
+
|
|
4
|
+
tries cutile kernels first; falls back to pytorch if the driver or
|
|
5
|
+
import is unavailable.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import math
|
|
9
|
+
import torch
|
|
10
|
+
|
|
11
|
+
from .codebook import LloydMaxCodebook
|
|
12
|
+
from .constants import BLOCK_S, DEFAULT_SEED, DEFAULT_TOTAL_BITS, HEAD_DIM
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def _rotation_matrix(d, seed, device="cpu"):
    """Sample a seeded random orthogonal d×d rotation matrix.

    Draws a Gaussian matrix on the CPU (so the result is reproducible
    regardless of target device), takes its QR factorization, and flips
    column signs so diag(R) is non-negative — this makes the orthogonal
    factor unique. The finished matrix is then moved to *device*.
    """
    rng = torch.Generator(device="cpu")
    rng.manual_seed(seed)
    gaussian = torch.randn(d, d, generator=rng)
    q, r = torch.linalg.qr(gaussian)
    # sign-fix: zero diagonal entries count as +1 so no column is zeroed
    signs = torch.sign(torch.diag(r))
    signs = torch.where(signs == 0, torch.ones_like(signs), signs)
    return (q * signs.unsqueeze(0)).to(device)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def _qjl_matrix(d, seed, device="cpu"):
    """Sample the seeded d×d Gaussian sketch matrix used for QJL signs.

    The seed is offset by 10000 so the sketch never shares a stream with
    the rotation matrix built from the same base seed. Generated on CPU
    for reproducibility, then moved to *device*.
    """
    rng = torch.Generator(device="cpu").manual_seed(seed + 10000)
    sketch = torch.randn(d, d, generator=rng)
    return sketch.to(device)
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class TurboQuantEngine:
    """Host-side orchestrator for TurboQuant KV-cache compression.

    Holds the shared rotation (Pi), QJL sketch (S), and Lloyd-Max
    codebooks, and dispatches compress / decompress calls to the cuTile
    kernels when `cuda.tile` imports and launches cleanly, falling back
    to the pure-PyTorch implementations otherwise.
    """

    def __init__(self, head_dim=HEAD_DIM, total_bits=DEFAULT_TOTAL_BITS,
                 seed=DEFAULT_SEED, device="cpu"):
        # total_bits is the per-coordinate key budget: one bit is reserved
        # for the QJL sign sketch, the remainder goes to the MSE quantizer
        # (floored at 1 bit).
        self.head_dim = head_dim
        self.total_bits = total_bits
        self.mse_bits = max(total_bits - 1, 1)
        self.device = device

        # Pi: random orthogonal rotation; S: Gaussian QJL sketch.
        # Transposes are cached contiguous for the kernel launches.
        self.Pi = _rotation_matrix(head_dim, seed, device)
        self.PiT = self.Pi.T.contiguous()
        self.S = _qjl_matrix(head_dim, seed, device)
        self.ST = self.S.T.contiguous()

        # keys quantize with one fewer bit (QJL takes the last bit);
        # values get the full budget.
        self.key_codebook = LloydMaxCodebook(head_dim, self.mse_bits)
        self.val_codebook = LloydMaxCodebook(head_dim, total_bits)

    # ── public api ───────────────────────────────────────────────────

    @torch.no_grad()
    def compress_kv_cache(self, past_key_values):
        """compress a full KV cache from a huggingface model forward pass.

        Returns {"layers": [(per-head key dicts, per-head value dicts), ...]}.
        NOTE(review): only batch index 0 is compressed — assumes batch
        size 1; confirm against callers.
        """
        kv_keys, kv_vals = self._extract_kv(past_key_values)

        layers = []
        for li in range(len(kv_keys)):
            n_heads = kv_keys[li].shape[1]
            # each head is compressed independently as a (seq, head_dim)
            # fp16 matrix
            ck = [self._compress_keys(kv_keys[li][0, h].half().contiguous())
                  for h in range(n_heads)]
            cv = [self._compress_values(kv_vals[li][0, h].half().contiguous())
                  for h in range(n_heads)]
            layers.append((ck, cv))

        return {"layers": layers}

    @torch.no_grad()
    def build_cache(self, compressed):
        """reconstruct a DynamicCache from compressed K (via k_mse) and V.

        Keys reuse the precomputed dequantized "k_mse" tensors; values are
        decompressed on the fly. Imported lazily so `transformers` is only
        required when this path is used.
        """
        from transformers import DynamicCache
        cache = DynamicCache()
        for li, (ck_list, cv_list) in enumerate(compressed["layers"]):
            k_heads = [ck["k_mse"] for ck in ck_list]
            # re-add the batch dimension dropped in compress_kv_cache
            k_layer = torch.stack(k_heads).unsqueeze(0)
            v_heads = [self._decompress_values(cv) for cv in cv_list]
            v_layer = torch.stack(v_heads).unsqueeze(0)
            cache.update(k_layer, v_layer, li)
        return cache

    @torch.no_grad()
    def generate(self, model, tokenizer, prompt, max_new_tokens=100,
                 repetition_penalty=1.3):
        """prefill → compress KV → decode with repetition penalty.

        Greedy decoding throughout (argmax, no sampling). Returns a dict
        with the decoded text, the number of new tokens, and the
        compression stats of the prefill cache.
        """
        inputs = tokenizer(prompt, return_tensors="pt").to(self.device)
        out = model(**inputs, use_cache=True)

        # round-trip the prefill cache through compression before decoding
        compressed = self.compress_kv_cache(out.past_key_values)
        stats = self.compression_stats(out.past_key_values)
        cache = self.build_cache(compressed)

        seq_len = inputs["input_ids"].shape[1]
        all_ids = inputs["input_ids"][0].tolist()

        # first new token comes straight from the prefill logits
        next_tok = out.logits[:, -1:].argmax(dim=-1)
        all_ids.append(next_tok.item())

        for step in range(max_new_tokens - 1):
            # position_ids supplied explicitly because the rebuilt cache
            # starts from the compressed prefill
            o = model(input_ids=next_tok, past_key_values=cache,
                      position_ids=torch.tensor([[seq_len + step]],
                                                device=self.device),
                      use_cache=True)
            cache = o.past_key_values
            logits = o.logits[:, -1, :]

            # repetition penalty: shrink positive logits of seen tokens,
            # push negative ones further down
            if repetition_penalty != 1.0:
                for tid in set(all_ids):
                    if logits[0, tid] > 0:
                        logits[0, tid] /= repetition_penalty
                    else:
                        logits[0, tid] *= repetition_penalty

            next_tok = logits.argmax(dim=-1, keepdim=True)
            all_ids.append(next_tok.item())
            if next_tok.item() == tokenizer.eos_token_id:
                break

        n_new = len(all_ids) - seq_len
        text = tokenizer.decode(all_ids, skip_special_tokens=True)
        return {"text": text, "tokens": n_new, "stats": stats}

    def compression_stats(self, past_key_values):
        """return compression ratio and byte counts.

        fp16_bytes is measured from the live tensors; tq_bytes is the
        analytic size of the compressed representation (see
        _compressed_bytes), not the size of the tensors actually held.
        """
        kv_keys, kv_vals = self._extract_kv(past_key_values)
        n_layers = len(kv_keys)
        n_heads = kv_keys[0].shape[1]
        seq_len = kv_keys[0].shape[2]

        fp16_bytes = sum(k.nbytes + v.nbytes for k, v in zip(kv_keys, kv_vals))
        tq_bytes = self._compressed_bytes(seq_len) * n_heads * n_layers

        return {
            "seq_len": seq_len, "n_layers": n_layers, "n_heads": n_heads,
            "fp16_bytes": fp16_bytes, "tq_bytes": tq_bytes,
            "ratio": fp16_bytes / tq_bytes,
        }

    # ── internals ────────────────────────────────────────────────────

    @staticmethod
    def _extract_kv(past_key_values):
        """handle both DynamicCache and list-of-tuples formats.

        EAFP: try the DynamicCache attributes first, fall back to the
        legacy ((key, value), ...) tuple layout.
        """
        try:
            return past_key_values.key_cache, past_key_values.value_cache
        except AttributeError:
            keys = [kv[0] for kv in past_key_values]
            vals = [kv[1] for kv in past_key_values]
            return keys, vals

    def _compress_keys(self, K):
        # GPU path: launch the cuTile key-compression kernel; any import
        # or launch failure drops through to the PyTorch fallback.
        try:
            from .compress import turboquant_compress_2bit, turboquant_compress_3bit
            import cuda.tile as ct

            seq_k, d = K.shape
            # kernel outputs: quantizer codes, QJL signs, vector norms,
            # and residual norms (all per row)
            indices = torch.empty(seq_k, d, dtype=torch.uint8, device=K.device)
            signs = torch.empty(seq_k, d, dtype=torch.int8, device=K.device)
            norms = torch.empty(seq_k, dtype=torch.float16, device=K.device)
            r_norms = torch.empty(seq_k, dtype=torch.float16, device=K.device)

            grid = (self._cdiv(seq_k, BLOCK_S), 1, 1)
            c = self.key_codebook.centroids.tolist()
            b = self.key_codebook.boundaries.tolist()
            stream = torch.cuda.current_stream()

            if self.mse_bits == 2:
                ct.launch(stream, grid, turboquant_compress_2bit, (
                    K, self.PiT.half(), self.Pi.half(), self.ST.half(),
                    indices, signs, norms, r_norms, *c, *b, seq_k))
            elif self.mse_bits == 3:
                ct.launch(stream, grid, turboquant_compress_3bit, (
                    K, self.PiT.half(), self.Pi.half(), self.ST.half(),
                    indices, signs, norms, r_norms, *c, *b, seq_k))
            else:
                # only 2- and 3-bit kernels exist; other budgets go to
                # the PyTorch path
                return self._compress_keys_pt(K)

            # precompute the dequantized keys once so build_cache can
            # reuse them directly
            k_mse = self._dequant_keys(indices, norms)
            return {"indices": indices, "k_mse": k_mse, "qjl_signs": signs,
                    "vec_norms": norms, "residual_norms": r_norms}
        except (ImportError, RuntimeError):
            return self._compress_keys_pt(K)

    def _compress_values(self, V):
        # GPU path for values (no QJL sketch — values only need the MSE
        # quantizer); falls back to PyTorch on import/launch failure.
        try:
            from .compress import (turboquant_compress_values_3bit,
                                   turboquant_compress_values_2bit)
            import cuda.tile as ct

            seq_v, d = V.shape
            indices = torch.empty(seq_v, d, dtype=torch.uint8, device=V.device)
            norms = torch.empty(seq_v, dtype=torch.float16, device=V.device)

            grid = (self._cdiv(seq_v, BLOCK_S), 1, 1)
            c = self.val_codebook.centroids.tolist()
            b = self.val_codebook.boundaries.tolist()
            stream = torch.cuda.current_stream()

            if self.total_bits == 3:
                ct.launch(stream, grid, turboquant_compress_values_3bit, (
                    V, self.PiT.half(), indices, norms, *c, *b, seq_v))
            elif self.total_bits == 2:
                ct.launch(stream, grid, turboquant_compress_values_2bit, (
                    V, self.PiT.half(), indices, norms, *c, *b, seq_v))
            else:
                return self._compress_values_pt(V)

            return {"indices": indices, "vec_norms": norms}
        except (ImportError, RuntimeError):
            return self._compress_values_pt(V)

    def _decompress_values(self, cv):
        # GPU path: dequantize codes → un-rotate → rescale, writing a
        # fresh fp16 (seq, head_dim) tensor; PyTorch fallback otherwise.
        try:
            from .decompress import (turboquant_decompress_3bit,
                                     turboquant_decompress_2bit)
            import cuda.tile as ct

            indices = cv["indices"]
            norms = cv["vec_norms"]
            seq_v = indices.shape[0]
            output = torch.empty(seq_v, self.head_dim, dtype=torch.float16,
                                 device=indices.device)

            grid = (self._cdiv(seq_v, BLOCK_S), 1, 1)
            c = self.val_codebook.centroids.tolist()
            stream = torch.cuda.current_stream()

            if self.total_bits == 3:
                ct.launch(stream, grid, turboquant_decompress_3bit, (
                    indices, norms, self.Pi.half(), output, *c, seq_v))
            elif self.total_bits == 2:
                ct.launch(stream, grid, turboquant_decompress_2bit, (
                    indices, norms, self.Pi.half(), output, *c, seq_v))
            else:
                return self._decompress_values_pt(cv)
            return output
        except (ImportError, RuntimeError):
            return self._decompress_values_pt(cv)

    # ── pytorch fallbacks ────────────────────────────────────────────

    def _compress_keys_pt(self, K):
        """CPU/PyTorch key compression: normalize → rotate → quantize →
        QJL sign sketch of the residual."""
        K_f = K.float()
        norms = torch.norm(K_f, dim=-1, keepdim=True)
        # epsilon guards against zero-norm rows
        K_normed = K_f / (norms + 1e-8)
        rotated = K_normed @ self.PiT.float()

        # nearest-centroid quantization per coordinate
        c = self.key_codebook.centroids.to(K.device)
        indices = (rotated.unsqueeze(-1) - c).abs().argmin(dim=-1).to(torch.uint8)

        # dequantize to get the MSE reconstruction and its residual
        y_hat = c[indices.long()]
        k_mse = (y_hat @ self.Pi.float()) * norms
        residual = K_f - k_mse
        r_norms = torch.norm(residual, dim=-1)

        # QJL: 1-bit signs of the sketched residual; zeros become +1 so
        # every sign is ±1
        signs = torch.sign(residual @ self.ST.float()).to(torch.int8)
        signs[signs == 0] = 1

        return {"indices": indices, "k_mse": k_mse.half(), "qjl_signs": signs,
                "vec_norms": norms.squeeze(-1).half(),
                "residual_norms": r_norms.half()}

    def _compress_values_pt(self, V):
        """CPU/PyTorch value compression: normalize → rotate → quantize
        (no QJL sketch for values)."""
        V_f = V.float()
        norms = torch.norm(V_f, dim=-1, keepdim=True)
        V_normed = V_f / (norms + 1e-8)
        rotated = V_normed @ self.PiT.float()

        c = self.val_codebook.centroids.to(V.device)
        indices = (rotated.unsqueeze(-1) - c).abs().argmin(dim=-1).to(torch.uint8)

        return {"indices": indices, "vec_norms": norms.squeeze(-1).half()}

    def _decompress_values_pt(self, cv):
        """CPU/PyTorch value decompression: centroid lookup → un-rotate →
        rescale by the stored norms."""
        c = self.val_codebook.centroids.to(cv["indices"].device)
        y_hat = c[cv["indices"].long()]
        norms = cv["vec_norms"].float().unsqueeze(-1)
        return ((y_hat @ self.Pi.float()) * norms).half()

    def _dequant_keys(self, indices, norms):
        """Dequantize key codes to fp16 (centroid lookup → un-rotate →
        rescale), used to precompute k_mse after a GPU compress."""
        c = self.key_codebook.centroids.to(indices.device)
        y_hat = c[indices.long()]
        return ((y_hat.float() @ self.Pi.float()) * norms.float().unsqueeze(-1)).half()

    def _cdiv(self, a, b):
        # ceiling division for grid sizing
        return (a + b - 1) // b

    def _compressed_bytes(self, seq_len):
        """Analytic compressed size in bytes for one head of length seq_len.

        Keys: mse_bits per coordinate + 1 QJL sign bit per coordinate
        + 32 bits per row (two fp16 norms: vector + residual).
        Values: total_bits per coordinate + 16 bits per row (fp16 norm).
        """
        d = self.head_dim
        key_bytes = (seq_len * d * self.mse_bits + seq_len * d + seq_len * 32) / 8
        val_bytes = (seq_len * d * self.total_bits + seq_len * 16) / 8
        return key_bytes + val_bytes
|
|
@@ -0,0 +1,96 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: turboquant-gpu
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: TurboQuant KV cache compression for LLM inference — cuTile GPU kernels
|
|
5
|
+
Author: Anirudh Bharadwaj Vangara
|
|
6
|
+
License-Expression: MIT
|
|
7
|
+
Keywords: quantization,kv-cache,llm,inference,cutile,cuda,gpu,attention,blackwell,hopper,h100,b200
|
|
8
|
+
Classifier: Development Status :: 3 - Alpha
|
|
9
|
+
Classifier: Intended Audience :: Science/Research
|
|
10
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
11
|
+
Classifier: Programming Language :: Python :: 3
|
|
12
|
+
Classifier: Programming Language :: Python :: 3.10
|
|
13
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
14
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
15
|
+
Requires-Python: >=3.10
|
|
16
|
+
Description-Content-Type: text/markdown
|
|
17
|
+
License-File: LICENSE
|
|
18
|
+
Requires-Dist: torch
|
|
19
|
+
Requires-Dist: scipy
|
|
20
|
+
Provides-Extra: gpu
|
|
21
|
+
Requires-Dist: cuda-tile; extra == "gpu"
|
|
22
|
+
Dynamic: license-file
|
|
23
|
+
|
|
24
|
+
# turboquant-gpu
|
|
25
|
+
|
|
26
|
+
**5.02x KV cache compression for LLM inference** — GPU-accelerated cuTile kernels with PyTorch fallback.
|
|
27
|
+
|
|
28
|
+
```
|
|
29
|
+
pip install turboquant-gpu
|
|
30
|
+
```
|
|
31
|
+
|
|
32
|
+
## quick start
|
|
33
|
+
|
|
34
|
+
```python
|
|
35
|
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
36
|
+
from turboquant_gpu import TurboQuantEngine
|
|
37
|
+
import torch
|
|
38
|
+
|
|
39
|
+
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-1.5B", torch_dtype=torch.float16, device_map="cuda")
|
|
40
|
+
tok = AutoTokenizer.from_pretrained("Qwen/Qwen2-1.5B")
|
|
41
|
+
|
|
42
|
+
engine = TurboQuantEngine(head_dim=128, total_bits=3, device="cuda")
|
|
43
|
+
result = engine.generate(model, tok, "The key to efficient LLM inference is")
|
|
44
|
+
|
|
45
|
+
print(result["text"])
|
|
46
|
+
print(f"{result['tokens']} tokens | {result['stats']['ratio']:.1f}x compression")
|
|
47
|
+
```
|
|
48
|
+
|
|
49
|
+
## how it works
|
|
50
|
+
|
|
51
|
+
Implements the [TurboQuant](https://arxiv.org/abs/2501.09747) algorithm:
|
|
52
|
+
|
|
53
|
+
1. **normalize + rotate** — random orthogonal rotation (Pi) makes coordinates near-Gaussian
|
|
54
|
+
2. **Lloyd-Max quantize** — optimal scalar quantization against N(0, 1/d)
|
|
55
|
+
3. **QJL bias correction** — 1-bit sign sketch of the residual for unbiased key scores
|
|
56
|
+
|
|
57
|
+
At 3-bit (2-bit MSE + 1-bit QJL) this gives ~5x compression with negligible quality loss.
|
|
58
|
+
|
|
59
|
+
## step-by-step api
|
|
60
|
+
|
|
61
|
+
```python
|
|
62
|
+
engine = TurboQuantEngine(head_dim=128, total_bits=3, device="cuda")
|
|
63
|
+
|
|
64
|
+
# after model prefill:
|
|
65
|
+
compressed = engine.compress_kv_cache(out.past_key_values)
|
|
66
|
+
cache = engine.build_cache(compressed)
|
|
67
|
+
stats = engine.compression_stats(out.past_key_values)
|
|
68
|
+
|
|
69
|
+
# or just do it all in one call:
|
|
70
|
+
result = engine.generate(model, tokenizer, "your prompt here")
|
|
71
|
+
```
|
|
72
|
+
|
|
73
|
+
## gpu support
|
|
74
|
+
|
|
75
|
+
Written in [cuTile](https://docs.nvidia.com/cuda/cutile-python/) for cross-architecture portability.
|
|
76
|
+
Falls back to PyTorch if cuTile or a compatible driver isn't available.
|
|
77
|
+
|
|
78
|
+
| GPU family | Architecture | Status |
|
|
79
|
+
|------------|-------------|--------|
|
|
80
|
+
| A100 | Ampere | supported (PyTorch fallback) |
|
|
81
|
+
| H100 | Hopper | supported |
|
|
82
|
+
| B200/B300 | Blackwell | supported + swizzle fast path |
|
|
83
|
+
|
|
84
|
+
## kernels
|
|
85
|
+
|
|
86
|
+
| kernel | what it does |
|
|
87
|
+
|--------|-------------|
|
|
88
|
+
| `compress_keys` | normalize → rotate(Pi^T) → Lloyd-Max quantize → QJL signs |
|
|
89
|
+
| `compress_values` | normalize → rotate(Pi^T) → Lloyd-Max quantize |
|
|
90
|
+
| `decompress_values` | dequantize → un-rotate(Pi) → scale by norms |
|
|
91
|
+
| `attention_scores` | asymmetric dot product with QJL correction |
|
|
92
|
+
| `fused_attention` | scores + online softmax + V accumulation |
|
|
93
|
+
|
|
94
|
+
## license
|
|
95
|
+
|
|
96
|
+
MIT
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
turboquant_gpu/__init__.py,sha256=hVDM-NQiB_PzgSyqQXi2KoMEUspka_PKvw55yGPhODk,278
|
|
2
|
+
turboquant_gpu/attention.py,sha256=O13thg9bO7h9OIySgmhYnv61aWidWmC5wkCKKVuh4L8,15147
|
|
3
|
+
turboquant_gpu/codebook.py,sha256=cJAOa_qnUBAJqMwNrHIK2igbmX61cQ4gbVXGfr9hyRk,2240
|
|
4
|
+
turboquant_gpu/compress.py,sha256=yP6BYrk29yAG8dMxkT2rUet1Z8qThOLRb8AWpwwbu6s,8003
|
|
5
|
+
turboquant_gpu/constants.py,sha256=VSmWSTSlxO0e0QTfvtxso41EvJTR0gCKJK4QL7POXpM,133
|
|
6
|
+
turboquant_gpu/decompress.py,sha256=b8VKHC3ZmmzFC_CKgP_BmlL0OGooFI30ECYWM2DXbPo,2581
|
|
7
|
+
turboquant_gpu/host.py,sha256=mVzm4Vhl3NT0clN4XXR4Z3VpEZ9PyFRqwxtexmTTvng,11848
|
|
8
|
+
turboquant_gpu-0.1.0.dist-info/licenses/LICENSE,sha256=jdrItB2_D55fIzIsSe7Esi-l2ykHOQS1CmAr0Y_fBAA,1082
|
|
9
|
+
turboquant_gpu-0.1.0.dist-info/METADATA,sha256=9dqKsJja0ErrdAffd9rXw4r3e--y4SoSkGeHG2e2dKI,3314
|
|
10
|
+
turboquant_gpu-0.1.0.dist-info/WHEEL,sha256=aeYiig01lYGDzBgS8HxWXOg3uV61G9ijOsup-k9o1sk,91
|
|
11
|
+
turboquant_gpu-0.1.0.dist-info/top_level.txt,sha256=ZrloYBosuQLyIN7iBgLJS9NENKJMkYszJU3zi7aCIBI,15
|
|
12
|
+
turboquant_gpu-0.1.0.dist-info/RECORD,,
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2026 Anirudh Bharadwaj Vangara
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
turboquant_gpu
|