turboquant_gpu-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
turboquant_gpu/__init__.py ADDED
@@ -0,0 +1,9 @@
+ __version__ = "0.1.0"
+
+ def __getattr__(name):
+     if name == "TurboQuantEngine":
+         from .host import TurboQuantEngine
+         return TurboQuantEngine
+     raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
+
+ __all__ = ["__version__", "TurboQuantEngine"]
turboquant_gpu/attention.py ADDED
@@ -0,0 +1,386 @@
+ """
+ asymmetric attention kernels for compressed keys.
+
+ score(q, k) ~ <q, k_mse> + ||r|| * sqrt(pi/2)/m * <S*q, signs>
+
+ score-only variant outputs raw scores.
+ fused variant does online softmax + V accumulation in one pass.
+ USE_SWIZZLE enables exp2 ftz, approx div, prefetched loads, and block
+ interleaving on architectures that support them.
+ """
+
+ import math
+
+ try:
+     import cuda.tile as ct
+     from cuda.tile import RoundingMode as RMd
+     _HAS_APPROX = hasattr(RMd, "APPROX")
+ except ImportError:
+     import cutile as ct  # type: ignore
+     RMd = None
+     _HAS_APPROX = False
+
+ from .constants import BLOCK_Q, BLOCK_KV, HEAD_DIM
+
+ INV_LOG_2 = 1.0 / math.log(2)
+ ConstBool = ct.Constant[bool]
+
+
+ # ── score-only kernel ────────────────────────────────────────────────
+
+ @ct.kernel(occupancy=2)
+ def turboquant_attention_scores(
+     Q, K_mse, Signs, R_norms, Q_proj, Output,
+     scale: float,
+     correction_scale: float,
+     seq_k: int,
+     USE_SWIZZLE: ConstBool,
+ ):
+     """one program per query block, streams over all KV blocks."""
+     q_block = ct.bid(0)
+     zero_pad = ct.PaddingMode.ZERO
+
+     q_tile = ct.load(Q, index=(q_block, 0), shape=(BLOCK_Q, HEAD_DIM),
+                      padding_mode=zero_pad)
+     qp_tile = ct.load(Q_proj, index=(q_block, 0), shape=(BLOCK_Q, HEAD_DIM),
+                       padding_mode=zero_pad)
+
+     num_kv_blocks = ct.num_tiles(K_mse, axis=0, shape=(BLOCK_KV, HEAD_DIM))
+
+     for kv_block in range(num_kv_blocks):
+         if USE_SWIZZLE:
+             k_tile = ct.load(K_mse, index=(kv_block, 0),
+                              shape=(BLOCK_KV, HEAD_DIM), padding_mode=zero_pad,
+                              latency=2)
+         else:
+             k_tile = ct.load(K_mse, index=(kv_block, 0),
+                              shape=(BLOCK_KV, HEAD_DIM), padding_mode=zero_pad)
+
+         # mse term: Q @ K_mse^T
+         term1 = ct.mma(q_tile, ct.transpose(k_tile),
+                        ct.zeros((BLOCK_Q, BLOCK_KV), dtype=ct.float32))
+
+         # qjl bias-correction term
+         s_tile = ct.load(Signs, index=(kv_block, 0),
+                          shape=(BLOCK_KV, HEAD_DIM), padding_mode=zero_pad)
+         s_float = ct.astype(s_tile, ct.float16)
+
+         qjl_ip = ct.mma(qp_tile, ct.transpose(s_float),
+                         ct.zeros((BLOCK_Q, BLOCK_KV), dtype=ct.float32))
+
+         rn = ct.load(R_norms, index=(kv_block,), shape=(BLOCK_KV,),
+                      padding_mode=zero_pad)
+         rn_f32 = ct.astype(rn, ct.float32)
+
+         term2 = correction_scale * qjl_ip * ct.expand_dims(rn_f32, axis=0)
+
+         scores = (term1 + term2) * scale
+         ct.store(Output, index=(q_block, kv_block), tile=scores)
+
+
+ # ── fused attention (pre-decompressed V) ─────────────────────────────
+
+ @ct.kernel(occupancy=2)
+ def turboquant_fused_attention(
+     Q, K_mse, Signs, R_norms, Q_proj, V, Output,
+     scale: float,
+     correction_scale: float,
+     seq_k: int,
+     USE_SWIZZLE: ConstBool,
+ ):
+     """fused attention + online softmax. V is already fp16."""
+     zero_pad = ct.PaddingMode.ZERO
+
+     if USE_SWIZZLE:
+         half = ct.cdiv(ct.num_blocks(0), 2)
+         q_block = ct.bid(0) // 2 + (ct.bid(0) % 2) * half
+     else:
+         q_block = ct.bid(0)
+
+     q_tile = ct.load(Q, index=(q_block, 0), shape=(BLOCK_Q, HEAD_DIM),
+                      padding_mode=zero_pad)
+     qp_tile = ct.load(Q_proj, index=(q_block, 0), shape=(BLOCK_Q, HEAD_DIM),
+                       padding_mode=zero_pad)
+
+     # online softmax accumulators
+     m_i = ct.full((BLOCK_Q,), -1e30, dtype=ct.float32)
+     l_i = ct.zeros((BLOCK_Q,), dtype=ct.float32)
+     acc = ct.zeros((BLOCK_Q, HEAD_DIM), dtype=ct.float32)
+
+     scale_log2 = scale * INV_LOG_2
+     num_kv_blocks = ct.num_tiles(K_mse, axis=0, shape=(BLOCK_KV, HEAD_DIM))
+
+     for kv_block in range(num_kv_blocks):
+         if USE_SWIZZLE:
+             k_tile = ct.load(K_mse, index=(kv_block, 0),
+                              shape=(BLOCK_KV, HEAD_DIM), padding_mode=zero_pad,
+                              latency=2)
+             v_tile = ct.load(V, index=(kv_block, 0),
+                              shape=(BLOCK_KV, HEAD_DIM), padding_mode=zero_pad,
+                              latency=4)
+         else:
+             k_tile = ct.load(K_mse, index=(kv_block, 0),
+                              shape=(BLOCK_KV, HEAD_DIM), padding_mode=zero_pad)
+             v_tile = ct.load(V, index=(kv_block, 0),
+                              shape=(BLOCK_KV, HEAD_DIM), padding_mode=zero_pad)
+
+         s_tile = ct.load(Signs, index=(kv_block, 0),
+                          shape=(BLOCK_KV, HEAD_DIM), padding_mode=zero_pad)
+         rn = ct.load(R_norms, index=(kv_block,), shape=(BLOCK_KV,),
+                      padding_mode=zero_pad)
+
+         term1 = ct.mma(q_tile, ct.transpose(k_tile),
+                        ct.zeros((BLOCK_Q, BLOCK_KV), dtype=ct.float32))
+
+         s_float = ct.astype(s_tile, ct.float16)
+         qjl_ip = ct.mma(qp_tile, ct.transpose(s_float),
+                         ct.zeros((BLOCK_Q, BLOCK_KV), dtype=ct.float32))
+         rn_f32 = ct.astype(rn, ct.float32)
+         term2 = correction_scale * qjl_ip * ct.expand_dims(rn_f32, axis=0)
+
+         raw_scores = term1 + term2
+
+         # online softmax update
+         if USE_SWIZZLE:
+             m_new = ct.maximum(m_i, ct.max(raw_scores, axis=1) * scale_log2)
+             alpha = ct.exp2(m_i - m_new, flush_to_zero=True)
+             p = ct.exp2(raw_scores * scale_log2 - ct.expand_dims(m_new, axis=1),
+                         flush_to_zero=True)
+         else:
+             scores = raw_scores * scale
+             m_new = ct.maximum(m_i, ct.max(scores, axis=1))
+             alpha = ct.exp(m_i - m_new)
+             p = ct.exp(scores - ct.expand_dims(m_new, axis=1))
+
+         l_i = alpha * l_i + ct.sum(p, axis=1)
+
+         p_fp16 = ct.astype(p, ct.float16)
+         acc = ct.expand_dims(alpha, axis=1) * acc + ct.mma(
+             p_fp16, v_tile, ct.zeros((BLOCK_Q, HEAD_DIM), dtype=ct.float32))
+
+         m_i = m_new
+
+     if USE_SWIZZLE and _HAS_APPROX:
+         out = ct.truediv(acc, ct.expand_dims(l_i, axis=1),
+                          flush_to_zero=True, rounding_mode=RMd.APPROX)
+     else:
+         out = acc / ct.expand_dims(l_i, axis=1)
+
+     ct.store(Output, index=(q_block, 0), tile=out)
+
+
+ # ── fused attention (on-chip 3-bit V decompression) ──────────────────
+
+ @ct.kernel(occupancy=2)
+ def turboquant_fused_attention_vfused_3bit(
+     Q, K_mse, Signs, R_norms, Q_proj,
+     V_Indices, V_Norms, Pi, Output,
+     scale: float,
+     correction_scale: float,
+     seq_k: int,
+     vc0: float, vc1: float, vc2: float, vc3: float,
+     vc4: float, vc5: float, vc6: float, vc7: float,
+     USE_SWIZZLE: ConstBool,
+ ):
+     """fused attention that decompresses 3-bit V on-chip per block."""
+     zero_pad = ct.PaddingMode.ZERO
+
+     if USE_SWIZZLE:
+         half = ct.cdiv(ct.num_blocks(0), 2)
+         q_block = ct.bid(0) // 2 + (ct.bid(0) % 2) * half
+     else:
+         q_block = ct.bid(0)
+
+     q_tile = ct.load(Q, index=(q_block, 0), shape=(BLOCK_Q, HEAD_DIM),
+                      padding_mode=zero_pad)
+     qp_tile = ct.load(Q_proj, index=(q_block, 0), shape=(BLOCK_Q, HEAD_DIM),
+                       padding_mode=zero_pad)
+
+     pi_tile = ct.load(Pi, index=(0, 0), shape=(HEAD_DIM, HEAD_DIM))
+
+     m_i = ct.full((BLOCK_Q,), -1e30, dtype=ct.float32)
+     l_i = ct.zeros((BLOCK_Q,), dtype=ct.float32)
+     acc = ct.zeros((BLOCK_Q, HEAD_DIM), dtype=ct.float32)
+
+     scale_log2 = scale * INV_LOG_2
+     num_kv_blocks = ct.num_tiles(K_mse, axis=0, shape=(BLOCK_KV, HEAD_DIM))
+
+     for kv_block in range(num_kv_blocks):
+         if USE_SWIZZLE:
+             k_tile = ct.load(K_mse, index=(kv_block, 0),
+                              shape=(BLOCK_KV, HEAD_DIM), padding_mode=zero_pad,
+                              latency=2)
+         else:
+             k_tile = ct.load(K_mse, index=(kv_block, 0),
+                              shape=(BLOCK_KV, HEAD_DIM), padding_mode=zero_pad)
+
+         # decompress V tile on-chip
+         v_idx = ct.load(V_Indices, index=(kv_block, 0),
+                         shape=(BLOCK_KV, HEAD_DIM), padding_mode=zero_pad)
+         v_nrm = ct.load(V_Norms, index=(kv_block,),
+                         shape=(BLOCK_KV,), padding_mode=zero_pad)
+
+         vi_f32 = ct.astype(v_idx, ct.float32)
+         y_hat = ct.full((BLOCK_KV, HEAD_DIM), vc0, dtype=ct.float32)
+         y_hat = ct.where(vi_f32 > 0.5, vc1, y_hat)
+         y_hat = ct.where(vi_f32 > 1.5, vc2, y_hat)
+         y_hat = ct.where(vi_f32 > 2.5, vc3, y_hat)
+         y_hat = ct.where(vi_f32 > 3.5, vc4, y_hat)
+         y_hat = ct.where(vi_f32 > 4.5, vc5, y_hat)
+         y_hat = ct.where(vi_f32 > 5.5, vc6, y_hat)
+         y_hat = ct.where(vi_f32 > 6.5, vc7, y_hat)
+
+         v_recon = ct.mma(ct.astype(y_hat, ct.float16), pi_tile,
+                          ct.zeros((BLOCK_KV, HEAD_DIM), dtype=ct.float32))
+         v_nrm_f32 = ct.astype(v_nrm, ct.float32)
+         v_tile = ct.astype(
+             v_recon * ct.expand_dims(v_nrm_f32, axis=1), ct.float16)
+
+         # attention scores
+         s_tile = ct.load(Signs, index=(kv_block, 0),
+                          shape=(BLOCK_KV, HEAD_DIM), padding_mode=zero_pad)
+         rn = ct.load(R_norms, index=(kv_block,), shape=(BLOCK_KV,),
+                      padding_mode=zero_pad)
+
+         term1 = ct.mma(q_tile, ct.transpose(k_tile),
+                        ct.zeros((BLOCK_Q, BLOCK_KV), dtype=ct.float32))
+
+         s_float = ct.astype(s_tile, ct.float16)
+         qjl_ip = ct.mma(qp_tile, ct.transpose(s_float),
+                         ct.zeros((BLOCK_Q, BLOCK_KV), dtype=ct.float32))
+         rn_f32 = ct.astype(rn, ct.float32)
+         term2 = correction_scale * qjl_ip * ct.expand_dims(rn_f32, axis=0)
+
+         raw_scores = term1 + term2
+
+         if USE_SWIZZLE:
+             m_new = ct.maximum(m_i, ct.max(raw_scores, axis=1) * scale_log2)
+             alpha = ct.exp2(m_i - m_new, flush_to_zero=True)
+             p = ct.exp2(raw_scores * scale_log2 - ct.expand_dims(m_new, axis=1),
+                         flush_to_zero=True)
+         else:
+             scores = raw_scores * scale
+             m_new = ct.maximum(m_i, ct.max(scores, axis=1))
+             alpha = ct.exp(m_i - m_new)
+             p = ct.exp(scores - ct.expand_dims(m_new, axis=1))
+
+         l_i = alpha * l_i + ct.sum(p, axis=1)
+
+         p_fp16 = ct.astype(p, ct.float16)
+         acc = ct.expand_dims(alpha, axis=1) * acc + ct.mma(
+             p_fp16, v_tile, ct.zeros((BLOCK_Q, HEAD_DIM), dtype=ct.float32))
+
+         m_i = m_new
+
+     if USE_SWIZZLE and _HAS_APPROX:
+         out = ct.truediv(acc, ct.expand_dims(l_i, axis=1),
+                          flush_to_zero=True, rounding_mode=RMd.APPROX)
+     else:
+         out = acc / ct.expand_dims(l_i, axis=1)
+
+     ct.store(Output, index=(q_block, 0), tile=out)
+
+
+ # ── fused attention (on-chip 2-bit V decompression) ──────────────────
+
+ @ct.kernel(occupancy=2)
+ def turboquant_fused_attention_vfused_2bit(
+     Q, K_mse, Signs, R_norms, Q_proj,
+     V_Indices, V_Norms, Pi, Output,
+     scale: float,
+     correction_scale: float,
+     seq_k: int,
+     vc0: float, vc1: float, vc2: float, vc3: float,
+     USE_SWIZZLE: ConstBool,
+ ):
+     """fused attention that decompresses 2-bit V on-chip per block."""
+     zero_pad = ct.PaddingMode.ZERO
+
+     if USE_SWIZZLE:
+         half = ct.cdiv(ct.num_blocks(0), 2)
+         q_block = ct.bid(0) // 2 + (ct.bid(0) % 2) * half
+     else:
+         q_block = ct.bid(0)
+
+     q_tile = ct.load(Q, index=(q_block, 0), shape=(BLOCK_Q, HEAD_DIM),
+                      padding_mode=zero_pad)
+     qp_tile = ct.load(Q_proj, index=(q_block, 0), shape=(BLOCK_Q, HEAD_DIM),
+                       padding_mode=zero_pad)
+
+     pi_tile = ct.load(Pi, index=(0, 0), shape=(HEAD_DIM, HEAD_DIM))
+
+     m_i = ct.full((BLOCK_Q,), -1e30, dtype=ct.float32)
+     l_i = ct.zeros((BLOCK_Q,), dtype=ct.float32)
+     acc = ct.zeros((BLOCK_Q, HEAD_DIM), dtype=ct.float32)
+
+     scale_log2 = scale * INV_LOG_2
+     num_kv_blocks = ct.num_tiles(K_mse, axis=0, shape=(BLOCK_KV, HEAD_DIM))
+
+     for kv_block in range(num_kv_blocks):
+         if USE_SWIZZLE:
+             k_tile = ct.load(K_mse, index=(kv_block, 0),
+                              shape=(BLOCK_KV, HEAD_DIM), padding_mode=zero_pad,
+                              latency=2)
+         else:
+             k_tile = ct.load(K_mse, index=(kv_block, 0),
+                              shape=(BLOCK_KV, HEAD_DIM), padding_mode=zero_pad)
+
+         v_idx = ct.load(V_Indices, index=(kv_block, 0),
+                         shape=(BLOCK_KV, HEAD_DIM), padding_mode=zero_pad)
+         v_nrm = ct.load(V_Norms, index=(kv_block,),
+                         shape=(BLOCK_KV,), padding_mode=zero_pad)
+
+         vi_f32 = ct.astype(v_idx, ct.float32)
+         y_hat = ct.full((BLOCK_KV, HEAD_DIM), vc0, dtype=ct.float32)
+         y_hat = ct.where(vi_f32 > 0.5, vc1, y_hat)
+         y_hat = ct.where(vi_f32 > 1.5, vc2, y_hat)
+         y_hat = ct.where(vi_f32 > 2.5, vc3, y_hat)
+
+         v_recon = ct.mma(ct.astype(y_hat, ct.float16), pi_tile,
+                          ct.zeros((BLOCK_KV, HEAD_DIM), dtype=ct.float32))
+         v_nrm_f32 = ct.astype(v_nrm, ct.float32)
+         v_tile = ct.astype(
+             v_recon * ct.expand_dims(v_nrm_f32, axis=1), ct.float16)
+
+         s_tile = ct.load(Signs, index=(kv_block, 0),
+                          shape=(BLOCK_KV, HEAD_DIM), padding_mode=zero_pad)
+         rn = ct.load(R_norms, index=(kv_block,), shape=(BLOCK_KV,),
+                      padding_mode=zero_pad)
+
+         term1 = ct.mma(q_tile, ct.transpose(k_tile),
+                        ct.zeros((BLOCK_Q, BLOCK_KV), dtype=ct.float32))
+
+         s_float = ct.astype(s_tile, ct.float16)
+         qjl_ip = ct.mma(qp_tile, ct.transpose(s_float),
+                         ct.zeros((BLOCK_Q, BLOCK_KV), dtype=ct.float32))
+         rn_f32 = ct.astype(rn, ct.float32)
+         term2 = correction_scale * qjl_ip * ct.expand_dims(rn_f32, axis=0)
+
+         raw_scores = term1 + term2
+
+         if USE_SWIZZLE:
+             m_new = ct.maximum(m_i, ct.max(raw_scores, axis=1) * scale_log2)
+             alpha = ct.exp2(m_i - m_new, flush_to_zero=True)
+             p = ct.exp2(raw_scores * scale_log2 - ct.expand_dims(m_new, axis=1),
+                         flush_to_zero=True)
+         else:
+             scores = raw_scores * scale
+             m_new = ct.maximum(m_i, ct.max(scores, axis=1))
+             alpha = ct.exp(m_i - m_new)
+             p = ct.exp(scores - ct.expand_dims(m_new, axis=1))
+
+         l_i = alpha * l_i + ct.sum(p, axis=1)
+
+         p_fp16 = ct.astype(p, ct.float16)
+         acc = ct.expand_dims(alpha, axis=1) * acc + ct.mma(
+             p_fp16, v_tile, ct.zeros((BLOCK_Q, HEAD_DIM), dtype=ct.float32))
+
+         m_i = m_new
+
+     if USE_SWIZZLE and _HAS_APPROX:
+         out = ct.truediv(acc, ct.expand_dims(l_i, axis=1),
+                          flush_to_zero=True, rounding_mode=RMd.APPROX)
+     else:
+         out = acc / ct.expand_dims(l_i, axis=1)
+
+     ct.store(Output, index=(q_block, 0), tile=out)
turboquant_gpu/codebook.py ADDED
@@ -0,0 +1,70 @@
+ """
+ lloyd-max optimal scalar quantizer.
+
+ after rotation each coordinate ~ N(0, 1/d). we solve lloyd-max
+ (1d k-means against that pdf) to get optimal centroids + boundaries.
+ runs once on cpu at init — the codebook is tiny.
+ """
+
+ import math
+ import torch
+ from scipy import integrate
+
+
+ def _gaussian_pdf(x, sigma):
+     return (1.0 / (math.sqrt(2 * math.pi) * sigma)) * math.exp(
+         -x * x / (2 * sigma * sigma)
+     )
+
+
+ def solve_lloyd_max(d, bits, max_iter=200, tol=1e-10):
+     """return (centroids, boundaries) as sorted float32 tensors."""
+     n_levels = 1 << bits
+     sigma = 1.0 / math.sqrt(d)
+     pdf = lambda x: _gaussian_pdf(x, sigma)
+
+     lo, hi = -3.5 * sigma, 3.5 * sigma
+     centroids = [lo + (hi - lo) * (i + 0.5) / n_levels for i in range(n_levels)]
+
+     for _ in range(max_iter):
+         boundaries = [
+             (centroids[i] + centroids[i + 1]) / 2.0 for i in range(n_levels - 1)
+         ]
+         edges = [lo * 3] + boundaries + [hi * 3]
+         new_centroids = []
+         for i in range(n_levels):
+             a, b = edges[i], edges[i + 1]
+             num, _ = integrate.quad(lambda x: x * pdf(x), a, b)
+             den, _ = integrate.quad(pdf, a, b)
+             new_centroids.append(num / den if den > 1e-15 else centroids[i])
+         if max(abs(new_centroids[i] - centroids[i]) for i in range(n_levels)) < tol:
+             break
+         centroids = new_centroids
+
+     boundaries = [
+         (centroids[i] + centroids[i + 1]) / 2.0 for i in range(n_levels - 1)
+     ]
+     return (
+         torch.tensor(centroids, dtype=torch.float32),
+         torch.tensor(boundaries, dtype=torch.float32),
+     )
+
+
+ class LloydMaxCodebook:
+     """pre-solved lloyd-max codebook for a given (d, bits) pair."""
+
+     def __init__(self, d, bits):
+         self.d = d
+         self.bits = bits
+         self.n_levels = 1 << bits
+         self.centroids, self.boundaries = solve_lloyd_max(d, bits)
+
+     def quantize(self, x):
+         diffs = x.unsqueeze(-1) - self.centroids.to(x.device)
+         return diffs.abs().argmin(dim=-1)
+
+     def dequantize(self, indices):
+         return self.centroids.to(indices.device)[indices.long()]
+
+     def __repr__(self):
+         return f"LloydMaxCodebook(d={self.d}, bits={self.bits}, levels={self.n_levels})"
turboquant_gpu/compress.py ADDED
@@ -0,0 +1,210 @@
+ """
+ key and value compression kernels.
+
+ pipeline per key token:
+ normalize → rotate (Pi^T) → lloyd-max quantize → store indices + norms
+ residual → QJL project (S^T) → store signs + residual norms
+
+ value compression is the same minus the QJL step.
+ """
+
+ try:
+     import cuda.tile as ct
+ except ImportError:
+     import cutile as ct  # type: ignore
+
+ from .constants import BLOCK_S, HEAD_DIM
+
+
+ # ── key compression ──────────────────────────────────────────────────
+
+ @ct.kernel
+ def turboquant_compress_2bit(
+     K, Pi_T, Pi, S_T,
+     Indices, Signs, Norms, RNorms,
+     c0: float, c1: float, c2: float, c3: float,
+     b1: float, b2: float, b3: float,
+     seq_k: int,
+ ):
+     """2-bit mse (4 centroids) + 1-bit qjl. total_bits=3."""
+     block_id = ct.bid(0)
+     zero_pad = ct.PaddingMode.ZERO
+
+     k_tile = ct.load(K, index=(block_id, 0), shape=(BLOCK_S, HEAD_DIM),
+                      padding_mode=zero_pad)
+     pi_t = ct.load(Pi_T, index=(0, 0), shape=(HEAD_DIM, HEAD_DIM))
+     pi = ct.load(Pi, index=(0, 0), shape=(HEAD_DIM, HEAD_DIM))
+     s_t = ct.load(S_T, index=(0, 0), shape=(HEAD_DIM, HEAD_DIM))
+
+     k_f32 = ct.astype(k_tile, ct.float32)
+     norms = ct.sqrt(ct.sum(k_f32 * k_f32, axis=1))
+     safe_norms = ct.where(norms > 1e-8, norms, 1e-8)
+     k_normed = k_f32 / ct.expand_dims(safe_norms, axis=1)
+
+     # rotate into quantization basis
+     y = ct.mma(ct.astype(k_normed, ct.float16), pi_t,
+                ct.zeros((BLOCK_S, HEAD_DIM), dtype=ct.float32))
+
+     # lloyd-max assignment via boundary comparisons
+     idx = ct.zeros((BLOCK_S, HEAD_DIM), dtype=ct.float32)
+     idx = ct.where(y > b1, 1.0, idx)
+     idx = ct.where(y > b2, 2.0, idx)
+     idx = ct.where(y > b3, 3.0, idx)
+
+     # dequantize and un-rotate for mse reconstruction
+     y_hat = ct.full((BLOCK_S, HEAD_DIM), c0, dtype=ct.float32)
+     y_hat = ct.where(idx > 0.5, c1, y_hat)
+     y_hat = ct.where(idx > 1.5, c2, y_hat)
+     y_hat = ct.where(idx > 2.5, c3, y_hat)
+
+     k_bar_hat = ct.mma(ct.astype(y_hat, ct.float16), pi,
+                        ct.zeros((BLOCK_S, HEAD_DIM), dtype=ct.float32))
+     k_mse = k_bar_hat * ct.expand_dims(norms, axis=1)
+
+     # qjl: project residual for bias-correction signs
+     residual = k_f32 - k_mse
+     r_norms = ct.sqrt(ct.sum(residual * residual, axis=1))
+     projected = ct.mma(ct.astype(residual, ct.float16), s_t,
+                        ct.zeros((BLOCK_S, HEAD_DIM), dtype=ct.float32))
+     signs = ct.where(projected >= 0.0, 1.0, -1.0)
+
+     ct.store(Indices, index=(block_id, 0), tile=ct.astype(idx, ct.uint8))
+     ct.store(Signs, index=(block_id, 0), tile=ct.astype(signs, ct.int8))
+     ct.store(Norms, index=(block_id,), tile=ct.astype(norms, ct.float16))
+     ct.store(RNorms, index=(block_id,), tile=ct.astype(r_norms, ct.float16))
+
+
+ @ct.kernel
+ def turboquant_compress_3bit(
+     K, Pi_T, Pi, S_T,
+     Indices, Signs, Norms, RNorms,
+     c0: float, c1: float, c2: float, c3: float,
+     c4: float, c5: float, c6: float, c7: float,
+     b1: float, b2: float, b3: float, b4: float,
+     b5: float, b6: float, b7: float,
+     seq_k: int,
+ ):
+     """3-bit mse (8 centroids) + 1-bit qjl. total_bits=4."""
+     block_id = ct.bid(0)
+     zero_pad = ct.PaddingMode.ZERO
+
+     k_tile = ct.load(K, index=(block_id, 0), shape=(BLOCK_S, HEAD_DIM),
+                      padding_mode=zero_pad)
+     pi_t = ct.load(Pi_T, index=(0, 0), shape=(HEAD_DIM, HEAD_DIM))
+     pi = ct.load(Pi, index=(0, 0), shape=(HEAD_DIM, HEAD_DIM))
+     s_t = ct.load(S_T, index=(0, 0), shape=(HEAD_DIM, HEAD_DIM))
+
+     k_f32 = ct.astype(k_tile, ct.float32)
+     norms = ct.sqrt(ct.sum(k_f32 * k_f32, axis=1))
+     safe_norms = ct.where(norms > 1e-8, norms, 1e-8)
+     k_normed = k_f32 / ct.expand_dims(safe_norms, axis=1)
+
+     y = ct.mma(ct.astype(k_normed, ct.float16), pi_t,
+                ct.zeros((BLOCK_S, HEAD_DIM), dtype=ct.float32))
+
+     idx = ct.zeros((BLOCK_S, HEAD_DIM), dtype=ct.float32)
+     idx = ct.where(y > b1, 1.0, idx)
+     idx = ct.where(y > b2, 2.0, idx)
+     idx = ct.where(y > b3, 3.0, idx)
+     idx = ct.where(y > b4, 4.0, idx)
+     idx = ct.where(y > b5, 5.0, idx)
+     idx = ct.where(y > b6, 6.0, idx)
+     idx = ct.where(y > b7, 7.0, idx)
+
+     y_hat = ct.full((BLOCK_S, HEAD_DIM), c0, dtype=ct.float32)
+     y_hat = ct.where(idx > 0.5, c1, y_hat)
+     y_hat = ct.where(idx > 1.5, c2, y_hat)
+     y_hat = ct.where(idx > 2.5, c3, y_hat)
+     y_hat = ct.where(idx > 3.5, c4, y_hat)
+     y_hat = ct.where(idx > 4.5, c5, y_hat)
+     y_hat = ct.where(idx > 5.5, c6, y_hat)
+     y_hat = ct.where(idx > 6.5, c7, y_hat)
+
+     k_bar_hat = ct.mma(ct.astype(y_hat, ct.float16), pi,
+                        ct.zeros((BLOCK_S, HEAD_DIM), dtype=ct.float32))
+     k_mse = k_bar_hat * ct.expand_dims(norms, axis=1)
+
+     residual = k_f32 - k_mse
+     r_norms = ct.sqrt(ct.sum(residual * residual, axis=1))
+     projected = ct.mma(ct.astype(residual, ct.float16), s_t,
+                        ct.zeros((BLOCK_S, HEAD_DIM), dtype=ct.float32))
+     signs = ct.where(projected >= 0.0, 1.0, -1.0)
+
+     ct.store(Indices, index=(block_id, 0), tile=ct.astype(idx, ct.uint8))
+     ct.store(Signs, index=(block_id, 0), tile=ct.astype(signs, ct.int8))
+     ct.store(Norms, index=(block_id,), tile=ct.astype(norms, ct.float16))
+     ct.store(RNorms, index=(block_id,), tile=ct.astype(r_norms, ct.float16))
+
+
+ # ── value compression (mse only, no qjl) ────────────────────────────
+
+ @ct.kernel
+ def turboquant_compress_values_3bit(
+     V, Pi_T,
+     Indices, Norms,
+     c0: float, c1: float, c2: float, c3: float,
+     c4: float, c5: float, c6: float, c7: float,
+     b1: float, b2: float, b3: float, b4: float,
+     b5: float, b6: float, b7: float,
+     seq_v: int,
+ ):
+     """3-bit value compression (8 levels)."""
+     block_id = ct.bid(0)
+     zero_pad = ct.PaddingMode.ZERO
+
+     v_tile = ct.load(V, index=(block_id, 0), shape=(BLOCK_S, HEAD_DIM),
+                      padding_mode=zero_pad)
+     pi_t = ct.load(Pi_T, index=(0, 0), shape=(HEAD_DIM, HEAD_DIM))
+
+     v_f32 = ct.astype(v_tile, ct.float32)
+     norms = ct.sqrt(ct.sum(v_f32 * v_f32, axis=1))
+     safe_norms = ct.where(norms > 1e-8, norms, 1e-8)
+     v_normed = v_f32 / ct.expand_dims(safe_norms, axis=1)
+
+     y = ct.mma(ct.astype(v_normed, ct.float16), pi_t,
+                ct.zeros((BLOCK_S, HEAD_DIM), dtype=ct.float32))
+
+     idx = ct.zeros((BLOCK_S, HEAD_DIM), dtype=ct.float32)
+     idx = ct.where(y > b1, 1.0, idx)
+     idx = ct.where(y > b2, 2.0, idx)
+     idx = ct.where(y > b3, 3.0, idx)
+     idx = ct.where(y > b4, 4.0, idx)
+     idx = ct.where(y > b5, 5.0, idx)
+     idx = ct.where(y > b6, 6.0, idx)
+     idx = ct.where(y > b7, 7.0, idx)
+
+     ct.store(Indices, index=(block_id, 0), tile=ct.astype(idx, ct.uint8))
+     ct.store(Norms, index=(block_id,), tile=ct.astype(norms, ct.float16))
+
+
+ @ct.kernel
+ def turboquant_compress_values_2bit(
+     V, Pi_T,
+     Indices, Norms,
+     c0: float, c1: float, c2: float, c3: float,
+     b1: float, b2: float, b3: float,
+     seq_v: int,
+ ):
+     """2-bit value compression (4 levels)."""
+     block_id = ct.bid(0)
+     zero_pad = ct.PaddingMode.ZERO
+
+     v_tile = ct.load(V, index=(block_id, 0), shape=(BLOCK_S, HEAD_DIM),
+                      padding_mode=zero_pad)
+     pi_t = ct.load(Pi_T, index=(0, 0), shape=(HEAD_DIM, HEAD_DIM))
+
+     v_f32 = ct.astype(v_tile, ct.float32)
+     norms = ct.sqrt(ct.sum(v_f32 * v_f32, axis=1))
+     safe_norms = ct.where(norms > 1e-8, norms, 1e-8)
+     v_normed = v_f32 / ct.expand_dims(safe_norms, axis=1)
+
+     y = ct.mma(ct.astype(v_normed, ct.float16), pi_t,
+                ct.zeros((BLOCK_S, HEAD_DIM), dtype=ct.float32))
+
+     idx = ct.zeros((BLOCK_S, HEAD_DIM), dtype=ct.float32)
+     idx = ct.where(y > b1, 1.0, idx)
+     idx = ct.where(y > b2, 2.0, idx)
+     idx = ct.where(y > b3, 3.0, idx)
+
+     ct.store(Indices, index=(block_id, 0), tile=ct.astype(idx, ct.uint8))
+     ct.store(Norms, index=(block_id,), tile=ct.astype(norms, ct.float16))
turboquant_gpu/constants.py ADDED
@@ -0,0 +1,8 @@
+ HEAD_DIM = 128
+ BLOCK_Q = 16
+ BLOCK_KV = 64
+ BLOCK_S = 64
+
+ SUPPORTED_MSE_BITS = {1, 2, 3, 4}
+ DEFAULT_TOTAL_BITS = 3
+ DEFAULT_SEED = 42
turboquant_gpu/decompress.py ADDED
@@ -0,0 +1,77 @@
+ """
+ value decompression kernels.
+
+ indices → dequantize via centroids → un-rotate (Pi) → scale by norms.
+ """
+
+ try:
+     import cuda.tile as ct
+ except ImportError:
+     import cutile as ct  # type: ignore
+
+ from .constants import BLOCK_S, HEAD_DIM
+
+
+ @ct.kernel
+ def turboquant_decompress_3bit(
+     Indices, Norms, Pi,
+     Output,
+     c0: float, c1: float, c2: float, c3: float,
+     c4: float, c5: float, c6: float, c7: float,
+     seq_k: int,
+ ):
+     block_id = ct.bid(0)
+     zero_pad = ct.PaddingMode.ZERO
+
+     idx_tile = ct.load(Indices, index=(block_id, 0), shape=(BLOCK_S, HEAD_DIM),
+                        padding_mode=zero_pad)
+     norm_tile = ct.load(Norms, index=(block_id,), shape=(BLOCK_S,),
+                         padding_mode=zero_pad)
+     pi = ct.load(Pi, index=(0, 0), shape=(HEAD_DIM, HEAD_DIM))
+
+     idx_f32 = ct.astype(idx_tile, ct.float32)
+
+     y_hat = ct.full((BLOCK_S, HEAD_DIM), c0, dtype=ct.float32)
+     y_hat = ct.where(idx_f32 > 0.5, c1, y_hat)
+     y_hat = ct.where(idx_f32 > 1.5, c2, y_hat)
+     y_hat = ct.where(idx_f32 > 2.5, c3, y_hat)
+     y_hat = ct.where(idx_f32 > 3.5, c4, y_hat)
+     y_hat = ct.where(idx_f32 > 4.5, c5, y_hat)
+     y_hat = ct.where(idx_f32 > 5.5, c6, y_hat)
+     y_hat = ct.where(idx_f32 > 6.5, c7, y_hat)
+
+     x_hat = ct.mma(ct.astype(y_hat, ct.float16), pi,
+                    ct.zeros((BLOCK_S, HEAD_DIM), dtype=ct.float32))
+     result = x_hat * ct.expand_dims(ct.astype(norm_tile, ct.float32), axis=1)
+
+     ct.store(Output, index=(block_id, 0), tile=ct.astype(result, ct.float16))
+
+
+ @ct.kernel
+ def turboquant_decompress_2bit(
+     Indices, Norms, Pi,
+     Output,
+     c0: float, c1: float, c2: float, c3: float,
+     seq_k: int,
+ ):
+     block_id = ct.bid(0)
+     zero_pad = ct.PaddingMode.ZERO
+
+     idx_tile = ct.load(Indices, index=(block_id, 0), shape=(BLOCK_S, HEAD_DIM),
+                        padding_mode=zero_pad)
+     norm_tile = ct.load(Norms, index=(block_id,), shape=(BLOCK_S,),
+                         padding_mode=zero_pad)
+     pi = ct.load(Pi, index=(0, 0), shape=(HEAD_DIM, HEAD_DIM))
+
+     idx_f32 = ct.astype(idx_tile, ct.float32)
+
+     y_hat = ct.full((BLOCK_S, HEAD_DIM), c0, dtype=ct.float32)
+     y_hat = ct.where(idx_f32 > 0.5, c1, y_hat)
+     y_hat = ct.where(idx_f32 > 1.5, c2, y_hat)
+     y_hat = ct.where(idx_f32 > 2.5, c3, y_hat)
+
+     x_hat = ct.mma(ct.astype(y_hat, ct.float16), pi,
+                    ct.zeros((BLOCK_S, HEAD_DIM), dtype=ct.float32))
+     result = x_hat * ct.expand_dims(ct.astype(norm_tile, ct.float32), axis=1)
+
+     ct.store(Output, index=(block_id, 0), tile=ct.astype(result, ct.float16))
turboquant_gpu/host.py ADDED
@@ -0,0 +1,288 @@
+ """
+ host-side engine: compress / decompress KV caches, run generation.
+
+ tries cutile kernels first; falls back to pytorch if the driver or
+ import is unavailable.
+ """
+
+ import math
+ import torch
+
+ from .codebook import LloydMaxCodebook
+ from .constants import BLOCK_S, DEFAULT_SEED, DEFAULT_TOTAL_BITS, HEAD_DIM
+
+
+ def _rotation_matrix(d, seed, device="cpu"):
+     gen = torch.Generator(device="cpu")
+     gen.manual_seed(seed)
+     G = torch.randn(d, d, generator=gen)
+     Q, R = torch.linalg.qr(G)
+     diag_sign = torch.sign(torch.diag(R))
+     diag_sign[diag_sign == 0] = 1.0
+     return (Q * diag_sign.unsqueeze(0)).to(device)
+
+
+ def _qjl_matrix(d, seed, device="cpu"):
+     gen = torch.Generator(device="cpu")
+     gen.manual_seed(seed + 10000)
+     return torch.randn(d, d, generator=gen).to(device)
+
+
+ class TurboQuantEngine:
+     def __init__(self, head_dim=HEAD_DIM, total_bits=DEFAULT_TOTAL_BITS,
+                  seed=DEFAULT_SEED, device="cpu"):
+         self.head_dim = head_dim
+         self.total_bits = total_bits
+         self.mse_bits = max(total_bits - 1, 1)
+         self.device = device
+
+         self.Pi = _rotation_matrix(head_dim, seed, device)
+         self.PiT = self.Pi.T.contiguous()
+         self.S = _qjl_matrix(head_dim, seed, device)
+         self.ST = self.S.T.contiguous()
+
+         self.key_codebook = LloydMaxCodebook(head_dim, self.mse_bits)
+         self.val_codebook = LloydMaxCodebook(head_dim, total_bits)
+
+     # ── public api ───────────────────────────────────────────────────
+
+     @torch.no_grad()
+     def compress_kv_cache(self, past_key_values):
+         """compress a full KV cache from a huggingface model forward pass."""
+         kv_keys, kv_vals = self._extract_kv(past_key_values)
+
+         layers = []
+         for li in range(len(kv_keys)):
+             n_heads = kv_keys[li].shape[1]
+             ck = [self._compress_keys(kv_keys[li][0, h].half().contiguous())
+                   for h in range(n_heads)]
+             cv = [self._compress_values(kv_vals[li][0, h].half().contiguous())
+                   for h in range(n_heads)]
+             layers.append((ck, cv))
+
+         return {"layers": layers}
+
+     @torch.no_grad()
+     def build_cache(self, compressed):
+         """reconstruct a DynamicCache from compressed K (via k_mse) and V."""
+         from transformers import DynamicCache
+         cache = DynamicCache()
+         for li, (ck_list, cv_list) in enumerate(compressed["layers"]):
+             k_heads = [ck["k_mse"] for ck in ck_list]
+             k_layer = torch.stack(k_heads).unsqueeze(0)
+             v_heads = [self._decompress_values(cv) for cv in cv_list]
+             v_layer = torch.stack(v_heads).unsqueeze(0)
+             cache.update(k_layer, v_layer, li)
+         return cache
+
+     @torch.no_grad()
+     def generate(self, model, tokenizer, prompt, max_new_tokens=100,
+                  repetition_penalty=1.3):
+         """prefill → compress KV → decode with repetition penalty."""
+         inputs = tokenizer(prompt, return_tensors="pt").to(self.device)
+         out = model(**inputs, use_cache=True)
+
+         compressed = self.compress_kv_cache(out.past_key_values)
+         stats = self.compression_stats(out.past_key_values)
+         cache = self.build_cache(compressed)
+
+         seq_len = inputs["input_ids"].shape[1]
+         all_ids = inputs["input_ids"][0].tolist()
+
+         next_tok = out.logits[:, -1:].argmax(dim=-1)
+         all_ids.append(next_tok.item())
+
+         for step in range(max_new_tokens - 1):
+             o = model(input_ids=next_tok, past_key_values=cache,
+                       position_ids=torch.tensor([[seq_len + step]],
+                                                 device=self.device),
+                       use_cache=True)
+             cache = o.past_key_values
+             logits = o.logits[:, -1, :]
+
+             if repetition_penalty != 1.0:
+                 for tid in set(all_ids):
+                     if logits[0, tid] > 0:
+                         logits[0, tid] /= repetition_penalty
+                     else:
+                         logits[0, tid] *= repetition_penalty
+
+             next_tok = logits.argmax(dim=-1, keepdim=True)
+             all_ids.append(next_tok.item())
+             if next_tok.item() == tokenizer.eos_token_id:
+                 break
+
+         n_new = len(all_ids) - seq_len
+         text = tokenizer.decode(all_ids, skip_special_tokens=True)
+         return {"text": text, "tokens": n_new, "stats": stats}
+
+     def compression_stats(self, past_key_values):
+         """return compression ratio and byte counts."""
+         kv_keys, kv_vals = self._extract_kv(past_key_values)
+         n_layers = len(kv_keys)
+         n_heads = kv_keys[0].shape[1]
+         seq_len = kv_keys[0].shape[2]
+
+         fp16_bytes = sum(k.nbytes + v.nbytes for k, v in zip(kv_keys, kv_vals))
+         tq_bytes = self._compressed_bytes(seq_len) * n_heads * n_layers
+
+         return {
+             "seq_len": seq_len, "n_layers": n_layers, "n_heads": n_heads,
+             "fp16_bytes": fp16_bytes, "tq_bytes": tq_bytes,
+             "ratio": fp16_bytes / tq_bytes,
+         }
+
+     # ── internals ────────────────────────────────────────────────────
+
+     @staticmethod
+     def _extract_kv(past_key_values):
+         """handle both DynamicCache and list-of-tuples formats."""
+         try:
+             return past_key_values.key_cache, past_key_values.value_cache
+         except AttributeError:
+             keys = [kv[0] for kv in past_key_values]
+             vals = [kv[1] for kv in past_key_values]
+             return keys, vals
+
+     def _compress_keys(self, K):
+         try:
+             from .compress import turboquant_compress_2bit, turboquant_compress_3bit
+             import cuda.tile as ct
+
+             seq_k, d = K.shape
+             indices = torch.empty(seq_k, d, dtype=torch.uint8, device=K.device)
+             signs = torch.empty(seq_k, d, dtype=torch.int8, device=K.device)
+             norms = torch.empty(seq_k, dtype=torch.float16, device=K.device)
+             r_norms = torch.empty(seq_k, dtype=torch.float16, device=K.device)
+
+             grid = (self._cdiv(seq_k, BLOCK_S), 1, 1)
+             c = self.key_codebook.centroids.tolist()
+             b = self.key_codebook.boundaries.tolist()
+             stream = torch.cuda.current_stream()
+
+             if self.mse_bits == 2:
+                 ct.launch(stream, grid, turboquant_compress_2bit, (
+                     K, self.PiT.half(), self.Pi.half(), self.ST.half(),
+                     indices, signs, norms, r_norms, *c, *b, seq_k))
+             elif self.mse_bits == 3:
+                 ct.launch(stream, grid, turboquant_compress_3bit, (
+                     K, self.PiT.half(), self.Pi.half(), self.ST.half(),
+                     indices, signs, norms, r_norms, *c, *b, seq_k))
+             else:
+                 return self._compress_keys_pt(K)
+
+             k_mse = self._dequant_keys(indices, norms)
+             return {"indices": indices, "k_mse": k_mse, "qjl_signs": signs,
+                     "vec_norms": norms, "residual_norms": r_norms}
+         except (ImportError, RuntimeError):
+             return self._compress_keys_pt(K)
+
+     def _compress_values(self, V):
+         try:
+             from .compress import (turboquant_compress_values_3bit,
+                                    turboquant_compress_values_2bit)
+             import cuda.tile as ct
+
+             seq_v, d = V.shape
+             indices = torch.empty(seq_v, d, dtype=torch.uint8, device=V.device)
+             norms = torch.empty(seq_v, dtype=torch.float16, device=V.device)
+
+             grid = (self._cdiv(seq_v, BLOCK_S), 1, 1)
+             c = self.val_codebook.centroids.tolist()
+             b = self.val_codebook.boundaries.tolist()
+             stream = torch.cuda.current_stream()
+
+             if self.total_bits == 3:
+                 ct.launch(stream, grid, turboquant_compress_values_3bit, (
+                     V, self.PiT.half(), indices, norms, *c, *b, seq_v))
+             elif self.total_bits == 2:
+                 ct.launch(stream, grid, turboquant_compress_values_2bit, (
+                     V, self.PiT.half(), indices, norms, *c, *b, seq_v))
+             else:
+                 return self._compress_values_pt(V)
+
+             return {"indices": indices, "vec_norms": norms}
+         except (ImportError, RuntimeError):
+             return self._compress_values_pt(V)
+
+     def _decompress_values(self, cv):
+         try:
+             from .decompress import (turboquant_decompress_3bit,
+                                      turboquant_decompress_2bit)
+             import cuda.tile as ct
+
+             indices = cv["indices"]
+             norms = cv["vec_norms"]
+             seq_v = indices.shape[0]
+             output = torch.empty(seq_v, self.head_dim, dtype=torch.float16,
+                                  device=indices.device)
+
+             grid = (self._cdiv(seq_v, BLOCK_S), 1, 1)
+             c = self.val_codebook.centroids.tolist()
+             stream = torch.cuda.current_stream()
+
+             if self.total_bits == 3:
+                 ct.launch(stream, grid, turboquant_decompress_3bit, (
+                     indices, norms, self.Pi.half(), output, *c, seq_v))
+             elif self.total_bits == 2:
+                 ct.launch(stream, grid, turboquant_decompress_2bit, (
+                     indices, norms, self.Pi.half(), output, *c, seq_v))
+             else:
+                 return self._decompress_values_pt(cv)
+             return output
+         except (ImportError, RuntimeError):
+             return self._decompress_values_pt(cv)
+
+     # ── pytorch fallbacks ────────────────────────────────────────────
+
+     def _compress_keys_pt(self, K):
+         K_f = K.float()
+         norms = torch.norm(K_f, dim=-1, keepdim=True)
+         K_normed = K_f / (norms + 1e-8)
+         rotated = K_normed @ self.PiT.float()
+
+         c = self.key_codebook.centroids.to(K.device)
+         indices = (rotated.unsqueeze(-1) - c).abs().argmin(dim=-1).to(torch.uint8)
+
+         y_hat = c[indices.long()]
+         k_mse = (y_hat @ self.Pi.float()) * norms
+         residual = K_f - k_mse
+         r_norms = torch.norm(residual, dim=-1)
+
+         signs = torch.sign(residual @ self.ST.float()).to(torch.int8)
+         signs[signs == 0] = 1
+
+         return {"indices": indices, "k_mse": k_mse.half(), "qjl_signs": signs,
+                 "vec_norms": norms.squeeze(-1).half(),
+                 "residual_norms": r_norms.half()}
+
+     def _compress_values_pt(self, V):
+         V_f = V.float()
+         norms = torch.norm(V_f, dim=-1, keepdim=True)
+         V_normed = V_f / (norms + 1e-8)
+         rotated = V_normed @ self.PiT.float()
+
+         c = self.val_codebook.centroids.to(V.device)
+         indices = (rotated.unsqueeze(-1) - c).abs().argmin(dim=-1).to(torch.uint8)
+
+         return {"indices": indices, "vec_norms": norms.squeeze(-1).half()}
+
+     def _decompress_values_pt(self, cv):
+         c = self.val_codebook.centroids.to(cv["indices"].device)
+         y_hat = c[cv["indices"].long()]
+         norms = cv["vec_norms"].float().unsqueeze(-1)
+         return ((y_hat @ self.Pi.float()) * norms).half()
+
+     def _dequant_keys(self, indices, norms):
+         c = self.key_codebook.centroids.to(indices.device)
+         y_hat = c[indices.long()]
+         return ((y_hat.float() @ self.Pi.float()) * norms.float().unsqueeze(-1)).half()
+
+     def _cdiv(self, a, b):
+         return (a + b - 1) // b
+
+     def _compressed_bytes(self, seq_len):
+         d = self.head_dim
+         key_bytes = (seq_len * d * self.mse_bits + seq_len * d + seq_len * 32) / 8
+         val_bytes = (seq_len * d * self.total_bits + seq_len * 16) / 8
+         return key_bytes + val_bytes
turboquant_gpu-0.1.0.dist-info/METADATA ADDED
@@ -0,0 +1,96 @@
+ Metadata-Version: 2.4
+ Name: turboquant-gpu
+ Version: 0.1.0
+ Summary: TurboQuant KV cache compression for LLM inference — cuTile GPU kernels
+ Author: Anirudh Bharadwaj Vangara
+ License-Expression: MIT
+ Keywords: quantization,kv-cache,llm,inference,cutile,cuda,gpu,attention,blackwell,hopper,h100,b200
+ Classifier: Development Status :: 3 - Alpha
+ Classifier: Intended Audience :: Science/Research
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
+ Classifier: Programming Language :: Python :: 3
+ Classifier: Programming Language :: Python :: 3.10
+ Classifier: Programming Language :: Python :: 3.11
+ Classifier: Programming Language :: Python :: 3.12
+ Requires-Python: >=3.10
+ Description-Content-Type: text/markdown
+ License-File: LICENSE
+ Requires-Dist: torch
+ Requires-Dist: scipy
+ Provides-Extra: gpu
+ Requires-Dist: cuda-tile; extra == "gpu"
+ Dynamic: license-file
+
+ # turboquant-gpu
+
+ **5.02x KV cache compression for LLM inference** — GPU-accelerated cuTile kernels with PyTorch fallback.
+
+ ```
+ pip install turboquant-gpu
+ ```
+
+ ## quick start
+
+ ```python
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ from turboquant_gpu import TurboQuantEngine
+ import torch
+
+ model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-1.5B", torch_dtype=torch.float16, device_map="cuda")
+ tok = AutoTokenizer.from_pretrained("Qwen/Qwen2-1.5B")
+
+ engine = TurboQuantEngine(head_dim=128, total_bits=3, device="cuda")
+ result = engine.generate(model, tok, "The key to efficient LLM inference is")
+
+ print(result["text"])
+ print(f"{result['tokens']} tokens | {result['stats']['ratio']:.1f}x compression")
+ ```
+
+ ## how it works
+
+ Implements the [TurboQuant](https://arxiv.org/abs/2501.09747) algorithm:
+
+ 1. **normalize + rotate** — random orthogonal rotation (Pi) makes coordinates near-Gaussian
+ 2. **Lloyd-Max quantize** — optimal scalar quantization against N(0, 1/d)
+ 3. **QJL bias correction** — 1-bit sign sketch of the residual for unbiased key scores
+
+ At 3-bit (2-bit MSE + 1-bit QJL) this gives ~5x compression with negligible quality loss.
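To make the three steps concrete, the sketch below walks a single key vector through them in plain PyTorch and then scores it with the asymmetric estimator described in the `attention.py` docstring above. It is an illustration only, not the package's API: the evenly spaced `centroids` stand in for the real Lloyd-Max codebook solved in `codebook.py`, and all shapes and names here are assumptions.

```python
import math
import torch

def compress_key(k, Pi, S, centroids):
    """Toy per-vector version of steps 1-3 (k: (d,), Pi/S: (d, d), centroids: (L,))."""
    norm = k.norm()
    y = (k / norm) @ Pi.T                                      # 1. normalize + rotate
    idx = (y.unsqueeze(-1) - centroids).abs().argmin(dim=-1)   # 2. Lloyd-Max assignment
    k_mse = (centroids[idx] @ Pi) * norm                       #    MSE reconstruction
    residual = k - k_mse
    signs = torch.where(residual @ S.T >= 0, 1.0, -1.0)        # 3. QJL 1-bit sign sketch
    return k_mse, signs, residual.norm()

def approx_score(q, k_mse, signs, r_norm, S):
    """Docstring estimator: <q, k_mse> + ||r|| * sqrt(pi/2)/m * <S q, signs>."""
    m = S.shape[0]
    correction = math.sqrt(math.pi / 2) / m
    return q @ k_mse + r_norm * correction * ((q @ S.T) * signs).sum()

d, bits = 128, 2
Pi, _ = torch.linalg.qr(torch.randn(d, d))        # random orthogonal rotation
S = torch.randn(d, d)                             # QJL projection
centroids = torch.linspace(-2.5, 2.5, 1 << bits) / math.sqrt(d)  # stand-in codebook
q, k = torch.randn(d), torch.randn(d)
k_mse, signs, r_norm = compress_key(k, Pi, S, centroids)
print(float(q @ k), float(approx_score(q, k_mse, signs, r_norm, S)))
```

The cuTile kernels do the same per 64-token tile, storing `indices`, `signs`, and the two norms instead of the reconstructed vectors.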
+
+ ## step-by-step api
+
+ ```python
+ engine = TurboQuantEngine(head_dim=128, total_bits=3, device="cuda")
+
+ # after model prefill:
+ compressed = engine.compress_kv_cache(out.past_key_values)
+ cache = engine.build_cache(compressed)
+ stats = engine.compression_stats(out.past_key_values)
+
+ # or just do it all in one call:
+ result = engine.generate(model, tokenizer, "your prompt here")
+ ```
+
+ ## gpu support
+
+ Written in [cuTile](https://docs.nvidia.com/cuda/cutile-python/) for cross-architecture portability.
+ Falls back to PyTorch if cuTile or a compatible driver isn't available.
+
+ | GPU family | Architecture | Status |
+ |------------|-------------|--------|
+ | A100 | Ampere | supported (PyTorch fallback) |
+ | H100 | Hopper | supported |
+ | B200/B300 | Blackwell | supported + swizzle fast path |
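The fallback is import-driven rather than architecture-driven: the kernel modules guard the cuTile import with the pattern below (copied from `attention.py` above), and the host-side `_compress_*` / `_decompress_*` methods additionally wrap each kernel launch in `try/except (ImportError, RuntimeError)` so any failure drops to the `_*_pt` PyTorch paths.

```python
# import guard used by the kernel modules (trimmed copy from attention.py)
try:
    import cuda.tile as ct
    from cuda.tile import RoundingMode as RMd
    _HAS_APPROX = hasattr(RMd, "APPROX")   # swizzle fast path also wants approx division
except ImportError:
    import cutile as ct  # type: ignore
    RMd = None
    _HAS_APPROX = False
```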
+
+ ## kernels
+
+ | kernel | what it does |
+ |--------|-------------|
+ | `compress_keys` | normalize → rotate(Pi^T) → Lloyd-Max quantize → QJL signs |
+ | `compress_values` | normalize → rotate(Pi^T) → Lloyd-Max quantize |
+ | `decompress_values` | dequantize → un-rotate(Pi) → scale by norms |
+ | `attention_scores` | asymmetric dot product with QJL correction |
+ | `fused_attention` | scores + online softmax + V accumulation |
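For readers who want the `fused_attention` row without the tile DSL, here is a plain PyTorch sketch of the same online-softmax accumulation the fused kernels perform per KV block. The block size and names are illustrative; the real kernels add the QJL correction term to `scores` and use the `exp2` / flush-to-zero variant on the swizzle path.

```python
import torch

def fused_attention_reference(q, k_mse, v, scale):
    """Streaming softmax over KV blocks, mirroring the kernel's m_i / l_i / acc update."""
    BLOCK_KV = 64
    m_i = torch.full((q.shape[0],), -1e30)       # running row max
    l_i = torch.zeros(q.shape[0])                # running softmax normalizer
    acc = torch.zeros(q.shape[0], v.shape[1])    # running weighted-V sum
    for start in range(0, k_mse.shape[0], BLOCK_KV):
        k_blk = k_mse[start:start + BLOCK_KV]
        v_blk = v[start:start + BLOCK_KV]
        scores = (q @ k_blk.T) * scale           # kernels also add the QJL term here
        m_new = torch.maximum(m_i, scores.max(dim=1).values)
        alpha = torch.exp(m_i - m_new)           # rescale previous partial sums
        p = torch.exp(scores - m_new.unsqueeze(1))
        l_i = alpha * l_i + p.sum(dim=1)
        acc = alpha.unsqueeze(1) * acc + p @ v_blk
        m_i = m_new
    return acc / l_i.unsqueeze(1)

q, k, v = torch.randn(16, 128), torch.randn(256, 128), torch.randn(256, 128)
out = fused_attention_reference(q, k, v, scale=128 ** -0.5)
ref = torch.softmax((q @ k.T) * 128 ** -0.5, dim=-1) @ v
print(torch.allclose(out, ref, atol=1e-5))
```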
+
+ ## license
+
+ MIT
turboquant_gpu-0.1.0.dist-info/RECORD ADDED
@@ -0,0 +1,12 @@
+ turboquant_gpu/__init__.py,sha256=hVDM-NQiB_PzgSyqQXi2KoMEUspka_PKvw55yGPhODk,278
+ turboquant_gpu/attention.py,sha256=O13thg9bO7h9OIySgmhYnv61aWidWmC5wkCKKVuh4L8,15147
+ turboquant_gpu/codebook.py,sha256=cJAOa_qnUBAJqMwNrHIK2igbmX61cQ4gbVXGfr9hyRk,2240
+ turboquant_gpu/compress.py,sha256=yP6BYrk29yAG8dMxkT2rUet1Z8qThOLRb8AWpwwbu6s,8003
+ turboquant_gpu/constants.py,sha256=VSmWSTSlxO0e0QTfvtxso41EvJTR0gCKJK4QL7POXpM,133
+ turboquant_gpu/decompress.py,sha256=b8VKHC3ZmmzFC_CKgP_BmlL0OGooFI30ECYWM2DXbPo,2581
+ turboquant_gpu/host.py,sha256=mVzm4Vhl3NT0clN4XXR4Z3VpEZ9PyFRqwxtexmTTvng,11848
+ turboquant_gpu-0.1.0.dist-info/licenses/LICENSE,sha256=jdrItB2_D55fIzIsSe7Esi-l2ykHOQS1CmAr0Y_fBAA,1082
+ turboquant_gpu-0.1.0.dist-info/METADATA,sha256=9dqKsJja0ErrdAffd9rXw4r3e--y4SoSkGeHG2e2dKI,3314
+ turboquant_gpu-0.1.0.dist-info/WHEEL,sha256=aeYiig01lYGDzBgS8HxWXOg3uV61G9ijOsup-k9o1sk,91
+ turboquant_gpu-0.1.0.dist-info/top_level.txt,sha256=ZrloYBosuQLyIN7iBgLJS9NENKJMkYszJU3zi7aCIBI,15
+ turboquant_gpu-0.1.0.dist-info/RECORD,,
turboquant_gpu-0.1.0.dist-info/WHEEL ADDED
@@ -0,0 +1,5 @@
+ Wheel-Version: 1.0
+ Generator: setuptools (82.0.1)
+ Root-Is-Purelib: true
+ Tag: py3-none-any
+
turboquant_gpu-0.1.0.dist-info/licenses/LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2026 Anirudh Bharadwaj Vangara
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
turboquant_gpu-0.1.0.dist-info/top_level.txt ADDED
@@ -0,0 +1 @@
+ turboquant_gpu