kvquant-plus-plus 0.1.0__py3-none-any.whl

Metadata-Version: 2.4
Name: kvquant-plus-plus
Version: 0.1.0
Summary: Attention-aware KV cache quantization for LLM inference (KVQuant++ extensions)
License: MIT
Requires-Python: >=3.10
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: torch>=2.1.0
Requires-Dist: transformers>=4.40.0
Requires-Dist: numpy>=1.24.0
Requires-Dist: matplotlib>=3.7.0
Provides-Extra: dev
Requires-Dist: pytest>=7.4.0; extra == "dev"
Dynamic: license-file

# TurboQuant++: Attention-Aware and Structure-Exploiting Extensions to Near-Optimal KV Cache Vector Quantization

---

## Abstract

TurboQuant (Zandieh et al., 2025) is a compelling approach to KV cache compression: rotate, then quantize with Lloyd-Max, and you get near-optimal MSE with provable bounds. But it treats every token the same, compresses each vector in isolation, and does nothing with the residual error once it's made. This paper asks what happens when you stop ignoring all of that.

We introduce five extensions (attention-weighted quantization, delta compression, adaptive bit allocation, low-rank error correction, and product quantization), each targeting a different structural property of transformer attention that the original method leaves on the table. Along the way, we also correct several issues in the original implementation:

- the codebook was fitted to a Gaussian approximation rather than the actual sphere marginal distribution;
- the QR decomposition could silently produce a reflection instead of a rotation;
- the nearest-centroid search was doing $O(N \cdot d \cdot k)$ work when a binary search suffices;
- the inner-product quantizer was applying a redundant second normalisation pass to vectors already on the unit sphere.

On distilgpt2, attention-weighted quantization cuts attention-weighted distortion by 47-70% per layer at the same average bit-width. Delta compression reduces MSE by 1.1-2.2x for correlated streams. Rank-4 error correction shaves off ~11% of the remaining MSE at 7.4% extra storage. Product quantization ($M=16$, $b=8$) produces coherent generation at 2 bits/dim. That is the same storage at which 2-bit scalar quantization collapses, and PQ matches 3-bit scalar quality there. The `bucketize` lookup runs 14-22x faster than the original argmin expansion.

Four additional improvements are described:

- k-means++ codebook initialisation (75% lower initialisation MSE at 1 bit);
- K-V asymmetric quantization (V MSE reduced 61.5% with 0.5 fewer bits/dim);
- delta+outlier combination (V MSE reduced 95.4% vs. same-budget plain quantization);
- Hadamard rotation exposed as a configurable parameter ($O(d \log d)$ vs $O(d^2)$).

---

## 1. Introduction

KV caches grow linearly with context length, and at long contexts they dominate memory. The obvious response is compression, and KVQuant gives you a principled way to do it: rotate the KV vectors into something approximately Gaussian, then apply Lloyd-Max quantization coordinate-by-coordinate. The MSE bound is $\frac{\sqrt{3}\,\pi}{2} \cdot 4^{-b}$, within 2.7x of the Shannon lower bound. That's a strong result.

What it doesn't do is think about which tokens matter. At the same bit-width, a token that receives 0.001% of the attention and one that receives 30% get identical treatment. It also compresses each token independently, even though consecutive tokens in a streaming KV cache tend to be highly correlated, so the delta is often much smaller than the vector itself. And once the quantization error is committed, there's no attempt to recover the structure in that error, even though quantization residuals tend to have low-rank structure.

These aren't obscure edge cases. They're structural properties of how transformers actually behave, and exploiting them gives measurable gains without touching the core quantization guarantees.

The paper is organized around these five extensions, preceded by a description of the implementation improvements we made to the baseline. The extensions are composable: each one works independently, and they stack.

---

## 2. Background

### 2.1 KVQuant

47
+ Given a vector $\mathbf{x} \in \mathbb{R}^d$ on the unit sphere, KVQuant applies two steps:
48
+
49
+ **Rotation.** Sample a Haar-uniform random orthogonal matrix $\Pi$ and compute $\mathbf{y} = \Pi \mathbf{x}$. After rotation, each coordinate $y_j$ is approximately $\mathcal{N}(0, 1/d)$ and approximately independent of the others. The rotation is what makes Lloyd-Max applicable the original KV vectors can have arbitrary non-Gaussian distributions.
50
+
51
+ **Quantization.** Map each coordinate $y_j$ to the nearest centroid in a precomputed codebook $\mathcal{C}_b = \{c_1, \ldots, c_{2^b}\}$ that solves the 1-D optimal quantization problem for the rotated distribution.
52
+
53
+ The MSE bound is:
54
+ $$D_{\mathrm{mse}} \;\leq\; \frac{\sqrt{3}\,\pi}{2} \cdot 4^{-b}.$$
55
+ So at 2 bits you get $D \leq 0.170$, at 4 bits $D \leq 0.011$.
56
+
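
The two numbers above follow directly from the bound. The short calculation below reproduces them for $b = 1, \ldots, 4$. It also shows that the leading constant $\sqrt{3}\pi/2 \approx 2.72$ matches the 2.7x factor quoted in the introduction. It only evaluates the formula; it does not run a quantizer, and the helper name `mse_bound` is ours.

```python
import math

def mse_bound(num_bits: int) -> float:
    """Evaluate D_mse <= (sqrt(3) * pi / 2) * 4**(-b) for unit-norm vectors."""
    return (math.sqrt(3) * math.pi / 2) * 4.0 ** (-num_bits)

print(f"leading constant: {math.sqrt(3) * math.pi / 2:.2f}")  # 2.72
for b in range(1, 5):
    # Each extra bit divides the bound by 4.
    print(f"b={b}: D_mse <= {mse_bound(b):.3f}")
```

Each additional bit buys a 4x reduction in the bound, which is why the gap between 2 and 4 bits is a factor of 16.
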
For inner products specifically, KVQuant has a two-stage variant. It spends $(b-1)$ bits on the MSE path, then applies 1-bit QJL to the residual $\mathbf{r} = \mathbf{x} - \hat{\mathbf{x}}$. This gives an unbiased estimator of $\langle \mathbf{y}, \mathbf{x} \rangle$ with variance:
$$\mathrm{Var}\!\left[\langle \mathbf{y},\, \tilde{\mathbf{x}} \rangle\right] \;\leq\; \frac{\sqrt{3}\,\pi^2\,\|\mathbf{y}\|^2}{d} \cdot 4^{-b},$$
where $\tilde{\mathbf{x}}$ is the two-stage reconstruction. In this bound, $\mathbf{y}$ denotes an arbitrary query vector rather than the rotated vector above. The bound directly limits the error in attention score computation.

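
The QJL stage can be sketched in a few lines of NumPy. This is a minimal illustration under stated assumptions, not the package's `KVQuantIP`:

- the Gaussian sketch matrix `S` is square ($m = d$);
- the first stage is a crude stand-in (rounding to a 0.25 grid) rather than a $(b-1)$-bit Lloyd-Max quantizer;
- the helper name `qjl_residual_estimate` is hypothetical.

The sketch stores only the sign of each projection of the residual, plus $\|\mathbf{r}\|$. It then rescales so that the estimate of $\mathbf{r}$ is unbiased. Averaging the inner-product estimate over many independent sketches recovers $\langle \mathbf{y}, \mathbf{x} \rangle$.

```python
import numpy as np

def qjl_residual_estimate(r: np.ndarray, S: np.ndarray) -> np.ndarray:
    """1-bit QJL: keep sign(S @ r) and ||r||; return an unbiased estimate of r.

    S has i.i.d. N(0, 1) entries and shape (m, d). For s ~ N(0, I),
    E[sign(<s, r>) s] = sqrt(2/pi) * r / ||r||, so rescaling the sum of m such
    terms by sqrt(pi/2) * ||r|| / m gives an unbiased estimate of r.
    """
    m = S.shape[0]
    signs = np.sign(S @ r)                       # the 1-bit sketch: one bit per row
    return np.sqrt(np.pi / 2) * np.linalg.norm(r) / m * (S.T @ signs)

rng = np.random.default_rng(0)
d = 64
x = rng.normal(size=d)
x /= np.linalg.norm(x)                           # unit vector to quantize
x_hat = np.round(x * 4) / 4                      # stand-in coarse MSE stage (NOT Lloyd-Max)
r = x - x_hat                                    # residual handed to QJL
y = rng.normal(size=d)                           # a query vector

# Average the two-stage inner-product estimate over many independent sketches.
est = np.mean([y @ (x_hat + qjl_residual_estimate(r, rng.normal(size=(d, d))))
               for _ in range(2000)])
print(f"true <y,x> = {y @ x:.4f}, mean estimate = {est:.4f}")
```

Any single sketch is noisy; the variance bound above is what controls that per-query noise. Unbiasedness is what keeps attention scores centred on average.
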
### 2.2 Implementation Improvements

#### 2.2.1 Codebook Distribution and k-means++ Initialisation

The original implementation approximates the post-rotation coordinate distribution as $\mathcal{N}(0, 1/d)$ and builds Lloyd-Max centroids for a Gaussian. But the true marginal, after rotating a unit-sphere vector, is:
$$f(t) \;=\; C_d \cdot (1 - t^2)^{(d-3)/2}, \qquad t \in [-1,\, 1]$$
where $C_d$ is a normalising constant.

![Figure 1: True sphere marginal vs Gaussian approximation at d=8 and d=64](figures/fig1_distribution.png)

This is a Beta-type distribution (equivalently, $t^2 \sim \mathrm{Beta}\!\left(\tfrac{1}{2}, \tfrac{d-1}{2}\right)$) that only converges to a Gaussian for large $d$. At small $d$ and low bit-widths the difference is meaningful. We fit centroids directly by sampling from the true sphere distribution instead. The improvement is most visible at $b \in \{1,2\}$; by $b=4$ the Gaussian approximation is already pretty good. Centroids are cached by `(num_bits, dim)` after first computation.

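
Sampling from the true marginal needs no special machinery. Normalise i.i.d. Gaussian vectors and read off one coordinate; that coordinate is exactly distributed as $f(t)$. The sketch below is NumPy, and the helper name `sample_sphere_marginal` is ours, not the package's. It checks the first two even moments against the closed forms $\mathbb{E}[t^2] = 1/d$ and $\mathbb{E}[t^4] = 3/(d(d+2))$. The Gaussian approximation matches the variance but predicts $\mathbb{E}[t^4] = 3/d^2$, so it overstates the tails, most visibly at small $d$.

```python
import numpy as np

def sample_sphere_marginal(d: int, n: int, rng: np.random.Generator) -> np.ndarray:
    """Draw n samples of one coordinate of a uniform random unit vector in R^d."""
    g = rng.standard_normal((n, d))
    g /= np.linalg.norm(g, axis=1, keepdims=True)  # rows are uniform on the sphere
    return g[:, 0]                                  # any fixed coordinate works

rng = np.random.default_rng(0)
for d in (8, 64):
    t = sample_sphere_marginal(d, 100_000, rng)
    print(f"d={d}: E[t^2]={np.mean(t**2):.5f} (1/d={1/d:.5f}), "
          f"E[t^4]={np.mean(t**4):.6f} "
          f"(sphere {3/(d*(d+2)):.6f}, Gaussian {3/d**2:.6f})")
```

These are the samples a Lloyd-Max solver would be fitted on in place of Gaussian draws.
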
**k-means++ initialisation.** The Lloyd-Max solver is an EM-style algorithm: it alternates between assigning samples to the nearest centroid and updating each centroid to the mean of its cluster. Like all such procedures it is sensitive to initialisation: a bad starting point leads to slow convergence or a suboptimal local solution.

The original implementation seeds centroids with `torch.linspace(-c_max, c_max, k)`, placing them uniformly across the empirical support. The sphere marginal is symmetric but non-uniform: for $d > 3$ it is peaked at the origin and has thin tails. For such a distribution, uniform spacing wastes centroids in low-density regions.

We replace this with k-means++ seeding (Arthur \& Vassilvitskii, 2007):

1. Choose the first centroid uniformly at random from the sample set.
2. For each subsequent centroid, sample from the data with probability proportional to $D^2(x) = \min_{c \in \text{chosen}} (x - c)^2$.

This $D^2$-weighted scheme is $O(\log k)$-competitive with the optimal clustering in expectation. It also spreads the initial centroids out, because each new seed is drawn where the existing seeds cover the data worst.

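
A compact 1-D version of the two seeding steps above, followed by Lloyd-Max refinement, is sketched below in NumPy. It is an illustration under our own assumptions, not the package's solver:

- the function names and the fixed iteration count are ours;
- the assignment step uses `np.searchsorted` on centroid midpoints, mirroring the `bucketize` trick of Section 2.2.3;
- the samples come from the sphere marginal at $d = 64$.

```python
import numpy as np

def kmeanspp_1d(x: np.ndarray, k: int, rng: np.random.Generator) -> np.ndarray:
    """k-means++ seeding on 1-D samples; returns k sorted seeds."""
    seeds = [rng.choice(x)]                                   # step 1: uniform pick
    for _ in range(k - 1):
        d2 = np.min((x[:, None] - np.array(seeds)[None, :]) ** 2, axis=1)
        seeds.append(rng.choice(x, p=d2 / d2.sum()))          # step 2: D^2 sampling
    return np.sort(np.array(seeds))

def assign(x: np.ndarray, c: np.ndarray) -> np.ndarray:
    """Nearest sorted centroid via binary search on the midpoints."""
    return np.searchsorted((c[:-1] + c[1:]) / 2, x)

def lloyd_max_1d(x: np.ndarray, c: np.ndarray, iters: int = 50) -> np.ndarray:
    """Lloyd-Max refinement: assign samples, then move centroids to cell means."""
    c = c.copy()
    for _ in range(iters):
        idx = assign(x, c)
        for j in range(len(c)):
            if np.any(idx == j):                              # leave empty cells alone
                c[j] = x[idx == j].mean()
        c.sort()
    return c

def mse(x: np.ndarray, c: np.ndarray) -> float:
    return float(np.mean((x - c[assign(x, c)]) ** 2))

rng = np.random.default_rng(0)
g = rng.standard_normal((20_000, 64))
x = g[:, 0] / np.linalg.norm(g, axis=1)                        # sphere marginal, d = 64
c_max = np.abs(x).max()

for bits in (1, 2):
    k = 2 ** bits
    lin = np.linspace(-c_max, c_max, k)                        # original linspace seeds
    pp = kmeanspp_1d(x, k, rng)
    print(f"{bits}-bit init MSE: linspace={mse(x, lin):.5f}  k-means++={mse(x, pp):.5f}  "
          f"after Lloyd-Max={mse(x, lloyd_max_1d(x, pp)):.5f}")
```

k-means++ is randomized, so individual seedings vary. Each Lloyd-Max iteration can only lower (or keep) the MSE, which is why both initialisations end up close after convergence.
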
**Empirical improvement at initialisation** ($d=64$, 100k samples, before Lloyd-Max iterations):

At 1-bit ($k=2$), k-means++ cuts the initial MSE from 0.097 to 0.024, a 75.1% reduction. At 2-bit ($k=4$) the gap is 58.5% (0.006 down to 0.0025). By 3-bit ($k=8$) it narrows to 27.6%, and at 4-bit ($k=16$) both initialisations land at essentially the same MSE (~0.000241 vs 0.000240). The gains are largest where $k$ is small and every centroid placement counts. After full Lloyd-Max convergence both schemes reach near-identical solutions. The practical benefit is fewer iterations to get there, and avoidance of rare degenerate local optima at $b \in \{1, 2\}$.

![Figure 7: Left: raw init MSE (log scale) for linspace vs k-means++ at each bit-width. Right: percentage MSE reduction from k-means++ seeding.](figures/fig7_kmeans_init.png)

#### 2.2.2 SO(d) Rotation

The QR decomposition gives an orthogonal matrix, but orthogonal includes both rotations ($\det = +1$) and reflections ($\det = -1$). About half the time you'll get a reflection. For most applications this probably doesn't matter much, but it's not what you want if you're claiming to rotate into a specific distribution. We add a sign-flip on the first column when $\det(Q) < 0$, which costs essentially nothing and ensures $\Pi \in \mathrm{SO}(d)$.

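
A NumPy sketch of the fix follows. Beyond the determinant correction described above, it also multiplies each column of $Q$ by the sign of the matching diagonal entry of $R$. That step is our addition and is not stated to be part of the package. It is a standard correction that makes the QR output Haar-distributed rather than biased by the QR routine's sign convention. The function name `random_rotation` is hypothetical.

```python
import numpy as np

def random_rotation(d: int, rng: np.random.Generator) -> np.ndarray:
    """Sample a Haar-uniform rotation from SO(d) via QR of a Gaussian matrix."""
    q, r = np.linalg.qr(rng.standard_normal((d, d)))
    q *= np.sign(np.diag(r))          # column j times sign(R[j, j]) -> Haar on O(d)
    if np.linalg.det(q) < 0:          # a reflection: flip the first column
        q[:, 0] = -q[:, 0]
    return q

rng = np.random.default_rng(0)
pi_mat = random_rotation(128, rng)
print(np.allclose(pi_mat.T @ pi_mat, np.eye(128)), round(float(np.linalg.det(pi_mat)), 6))
```

Flipping one column maps the $\det = -1$ half of $O(d)$ onto $SO(d)$ while preserving Haar measure, so the output stays uniformly distributed over rotations.
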
#### 2.2.3 Nearest-Centroid Lookup

The original approach expands $\mathbf{y}$ into an $(N, d, k)$ tensor and takes the argmin. For $b=4$, $k=16$, $N=4096$, $d=128$ that temporary holds about 8.4M elements (32 MiB in float32), and it gets thrown away immediately. Since Lloyd-Max centroids are always sorted, a binary search on the $k-1$ midpoints between adjacent centroids is enough:

```python
import torch

torch.manual_seed(0)
y = torch.randn(64, 128) / 128 ** 0.5                     # rotated coordinates, (N, d)
centroids = torch.tensor([-0.15, -0.05, 0.05, 0.15])      # toy sorted 2-bit codebook

# before: materialises an (N, d, k) tensor
diff = (y.unsqueeze(-1) - centroids.view(1, 1, -1)).abs()
indices_old = diff.argmin(dim=-1)                         # O(N*d*k) work

# after: binary search on the midpoints between adjacent centroids
boundaries = (centroids[:-1] + centroids[1:]) / 2
indices = torch.bucketize(y, boundaries)                  # O(N*d*log k), no (N, d, k) temp

assert torch.equal(indices, indices_old)
```

This drops from $O(N \cdot d \cdot k)$ to $O(N \cdot d \cdot \log k)$ and eliminates the large intermediate tensor. In practice: 14x faster at 2-bit, 22x faster at 4-bit (tested at $N=4096$, $d=128$). The same fix applies inside the Lloyd-Max solver's assignment step.

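#### 2.2.4 In-Place FWHT

The butterfly step in the Fast Walsh-Hadamard Transform was allocating two clones per level, i.e. $O(\log d)$ extra tensors over the whole transform. One temporary per level is enough:

```python
import torch

x = torch.randn(4, 8)        # toy input; the butterfly acts on the last dimension
h = x.shape[-1] // 2         # half-width of the butterfly at this level

a = x[..., :h] + x[..., h:]            # sums: the one temporary allocated at this level
x[..., h:].neg_().add_(x[..., :h])     # differences (first - second), written in place
x[..., :h] = a                          # copy the sums back into the first half
```
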
#### 2.2.5 Hadamard Rotation and Entropy Coding

We also swap the dense QR rotation for a structured Hadamard rotation:
$$\mathbf{y} \;=\; \frac{1}{\sqrt{d}}\, H(D \cdot \mathbf{x}),$$
where $D = \mathrm{diag}(\pm 1)$ is a random sign-flip matrix and $H$ is the (unnormalised) Walsh-Hadamard transform, which requires $d$ to be a power of two. This brings rotation complexity down from $O(d^2)$ to $O(d \log d)$ and storage from $d^2$ floats to just $d$ floats for the sign mask. The randomized Hadamard transform is not Haar-uniform. However, it spreads a vector's energy evenly across coordinates with high probability (Ailon & Chazelle, 2006), which is the property that coordinate-wise quantization relies on.

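
Below is a NumPy sketch of $\mathbf{y} = \frac{1}{\sqrt{d}} H(D\mathbf{x})$ built on an iterative FWHT. The names are ours, it is not the package's implementation, and it requires $d$ to be a power of two. Because $H^2 = dI$ and $D^2 = I$, the inverse is simply the same normalised transform followed by the same sign flip. The sign vector is the only state that needs storing.

```python
import numpy as np

def fwht(x: np.ndarray) -> np.ndarray:
    """Unnormalised fast Walsh-Hadamard transform over the last axis (length 2^m)."""
    x = np.array(x, dtype=float)               # contiguous working copy
    d = x.shape[-1]
    assert d & (d - 1) == 0, "length must be a power of two"
    h = 1
    while h < d:
        # View as (..., blocks, 2, h): pair element i with i + h in each 2h-block.
        y = x.reshape(*x.shape[:-1], d // (2 * h), 2, h)
        a, b = y[..., 0, :], y[..., 1, :]
        y[..., 0, :], y[..., 1, :] = a + b, a - b   # RHS is evaluated before writing
        h *= 2
    return x

def hadamard_rotate(x: np.ndarray, signs: np.ndarray) -> np.ndarray:
    """y = H(D x) / sqrt(d) with D = diag(signs)."""
    return fwht(x * signs) / np.sqrt(x.shape[-1])

def hadamard_unrotate(y: np.ndarray, signs: np.ndarray) -> np.ndarray:
    """Inverse: x = D H y / sqrt(d), since H @ H = d * I and D @ D = I."""
    return fwht(y) / np.sqrt(y.shape[-1]) * signs

rng = np.random.default_rng(0)
d = 128
signs = rng.choice([-1.0, 1.0], size=d)        # the only stored state: d signs
x = rng.standard_normal((4, d))
y = hadamard_rotate(x, signs)
print(np.allclose(np.linalg.norm(y, axis=1), np.linalg.norm(x, axis=1)),
      np.allclose(hadamard_unrotate(y, signs), x))
```

Each of the $\log_2 d$ levels touches every element once, which is where the $O(d \log d)$ cost comes from.
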
On top of that, codebook indices are non-uniformly distributed under the sphere marginal, so Huffman coding can compress them further toward the Shannon entropy. At $b=4$, $d=128$ the entropy is 3.816 bits vs 4 raw. That caps the possible saving at about 4.6%, and Huffman coding gets roughly 4%. It's not dramatic, but it's free.

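
The entropy figure and the Huffman rate come from the same input: the empirical frequency of each codebook index. The sketch below is pure Python, with helper names of our own, and computes both for an arbitrary probability vector. The example frequencies are hypothetical and do not come from the package. For an optimal prefix code, the expected length $L$ always satisfies $H \le L < H + 1$, which is why Huffman lands slightly above the entropy.

```python
import heapq
import math

def entropy_bits(p: list[float]) -> float:
    """Shannon entropy in bits of a discrete distribution."""
    return -sum(pi * math.log2(pi) for pi in p if pi > 0)

def huffman_lengths(p: list[float]) -> list[int]:
    """Code length of each symbol in an optimal (Huffman) prefix code."""
    lengths = [0] * len(p)
    # Heap items: (probability, unique tie-breaker, symbols in this subtree).
    heap = [(pi, i, [i]) for i, pi in enumerate(p)]
    heapq.heapify(heap)
    counter = len(p)
    while len(heap) > 1:
        p1, _, s1 = heapq.heappop(heap)
        p2, _, s2 = heapq.heappop(heap)
        for s in s1 + s2:          # each merge adds one bit to every symbol below it
            lengths[s] += 1
        heapq.heappush(heap, (p1 + p2, counter, s1 + s2))
        counter += 1
    return lengths

def huffman_rate(p: list[float]) -> float:
    """Expected Huffman code length in bits per symbol."""
    return sum(pi * li for pi, li in zip(p, huffman_lengths(p)))

p = [0.4, 0.3, 0.2, 0.1]   # hypothetical index frequencies for a 2-bit codebook
print(f"H = {entropy_bits(p):.3f} bits, Huffman = {huffman_rate(p):.3f} bits, raw = 2 bits")
```

For a real codebook, `p` would be the histogram of the stored indices, which is non-uniform because the sphere marginal puts more mass on the inner centroids.
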
#### 2.2.6 Unit-Norm Fast Path in `KVQuantIP.quantize()`

**Problem.** `KVQuantIP.quantize()` normalises `x` to unit norm before calling into `KVQuantMSE.quantize()`. But `KVQuantMSE.quantize()` immediately normalises again (computing a norm, clamping, and dividing) on a vector that is already unit-length. That second normalisation is numerically a no-op, but it costs three element-wise operations and a reduction over $(N, d)$.

**Fix.** Add a `_quantize_unit` fast path to `KVQuantMSE` that skips norm computation and the `QuantizedMSE` allocation:

```python
import torch
from torch import Tensor

# Excerpt: a method of KVQuantMSE (uses its self.rotation and self.boundaries).
def _quantize_unit(self, x_unit: Tensor) -> Tensor:
    """Fast path: quantize pre-normalised vectors and return raw indices.

    Skips norm computation and the QuantizedMSE allocation.
    """
    return torch.bucketize(self.rotation(x_unit), self.boundaries)  # (N, d)
```

`KVQuantIP.quantize()` calls this instead of the full path:

```python
# before: double-normalises x_unit
indices, x_hat_unit = self.mse_quantizer.quantize(x_unit), ...

# after: single normalisation, no QuantizedMSE allocation
indices = self.mse_quantizer._quantize_unit(x_unit)         # (N, d)
x_hat_unit = self.mse_quantizer._dequantize_unit(indices)   # (N, d)
```

The speedup from removing the second norm is small (~6% of total quantize time at $N=4096$, $d=128$) because the QJL projection `r @ S.T`, which is $O(N \cdot d^2)$, dominates. The correctness gain is more important. The old code was silently quantizing a non-unit vector through a path that assumed unit input. That gave slightly wrong centroids whenever `x_unit` had a floating-point norm deviating from 1.0.

#### 2.2.7 Batch-Size Product with `math.prod()`

**Problem.** `KVQuantIP.dequantize()` and `OutlierKVQuant.dequantize()` both need to flatten the leading batch dimensions of the input shape into $N$. The original code used a Python loop:

```python
N = 1
for s in q.shape[:-1]:  # q: quantized tensor of shape (..., d)
    N *= s
```

**Fix.** Replace it with `math.prod()`, a single call into the C-implemented standard library:

```python
import math

N = math.prod(q.shape[:-1])  # product of the leading dims; 1 if there are none
```

The performance difference is negligible either way, since the loop only runs $O(\text{ndim})$ iterations (typically 2-3). The change is a clarity improvement as much as a performance one: `math.prod` makes the intent immediately obvious.

#### 2.2.8 Codebook Clone Removal

**Problem.** `build_codebook()` returned `centroids.clone()` and `boundaries.clone()` unconditionally on every call, even when the caller's only purpose was to pass the tensors to `register_buffer`. The clone was a defensive copy to stop callers from mutating the cached tensors. But it also ran when `device is not None`, where `.to()` had already produced a fresh tensor.

**Fix.** Clone only on the CPU path (where the cache must be protected from device moves):

```python
# Excerpt from build_codebook(); `key` is the (num_bits, dim) cache key.
centroids, boundaries = _CACHE[key]
if device is not None:
    # .to() already copies when the device differs; copy=True also covers the
    # case where `device` is the cache's own device, in which plain .to()
    # would return the cached tensor itself.
    return centroids.to(device, copy=True), boundaries.to(device, copy=True)
# Clone so callers (register_buffer) get an independent tensor that can be
# moved to another device without corrupting the CPU cache entry.
return centroids.clone(), boundaries.clone()
```

When a device is requested, this halves the number of allocations: the `.to()` copy is kept and the extra clone is dropped. The CPU path still clones once.

#### 2.2.9 K-V Asymmetric Quantization

**Observation.** K and V tensors play different roles in attention:

$$\text{score}_t = \mathbf{q}^\top \hat{\mathbf{k}}_t, \qquad \text{output} = \sum_t a_t \, \hat{\mathbf{v}}_t, \qquad a_t = \mathrm{softmax}_t\!\left(\text{score}_t / \sqrt{d}\right)$$

The K cache enters only via inner products with the query. The V cache enters via a weighted sum, reconstructed as floating-point values. These roles call for different quantization objectives:

- **K** -> KVQuantIP: minimises inner-product error $\mathbb{E}[(\langle \mathbf{q}, \mathbf{k} \rangle - \langle \mathbf{q}, \hat{\mathbf{k}} \rangle)^2]$. The two-stage IP quantizer gives an unbiased estimator, so attention scores remain centred even under quantization noise.
- **V** -> KVQuantMSE: minimises reconstruction error $\mathbb{E}[\|\mathbf{v} - \hat{\mathbf{v}}\|^2]$. The output token is a linear combination of V vectors, so MSE-optimal quantization directly minimises the output corruption.

**Implementation.** `KVCacheQuantizer` previously used KVQuantIP for both K and V. We separate the backends:

```python
# K: inner-product optimal
self.k_quant = KVQuantIP(head_dim, num_bits, ...)
# V: MSE optimal
self.v_quant = KVQuantMSE(head_dim, num_bits, ...)
```

The same asymmetry propagates through `OutlierKVQuant` via a new `quantizer_cls` parameter, so the outlier-aware path also benefits.

**Empirical result** (delta cache, $d=64$, $T=50$ drifting sequence, 3-bit budget):

The 3-bit symmetric baseline (IP/IP) gives a K IP-error of 0.2815 and a V MSE of 0.005424. Switching to the asymmetric IP/MSE config at 2.5 bits/dim (half a bit less per dimension) drops V MSE to 0.002086. That is a **61.5% reduction** while spending fewer bits overall. For comparison, plain IP/IP at a 2-bit budget produces a V MSE of 0.045147, more than 20x worse. The K IP-error rises with the asymmetric config (0.2815 to 2.1816) because the per-dimension budget is lower. This is expected, and the IP quantizer's unbiasedness guarantee means the attention scores remain centred regardless.

![Figure 8: K IP-error (left) and V MSE (right) for symmetric IP/IP vs asymmetric IP/MSE quantization at matched and reduced bit budgets.](figures/fig8_kv_asymmetric.png)

---

## 3. Extensions

### 3.1 Attention-Weighted Quantization

KVQuant minimizes:
$$\mathcal{L}_{\mathrm{uniform}} \;=\; \mathbb{E}\!\left[\,\|\mathbf{k}_i - \hat{\mathbf{k}}_i\|^2\,\right].$$

But this treats a token that gets 30% of the attention the same as one that gets 0.01%. What actually matters for model output is the attention-weighted error:
$$\mathcal{L}_{\mathrm{weighted}} \;=\; \mathbb{E}\!\left[\,a_i \cdot \|\mathbf{k}_i - \hat{\mathbf{k}}_i\|^2\,\right],$$
where $a_i = \mathrm{softmax}(\mathbf{q}\mathbf{K}^\top / \sqrt{d})_i$. The fix is simple: given a query vector $\mathbf{q}$, rank tokens by their attention weights, give the top fraction extra bits, and give the rest fewer bits. The average bit-width stays the same; you're just redistributing it.

Concretely, for a 3-bit average with $b_{\mathrm{hi}}=4$, $b_{\mathrm{lo}}=2$, top 50%:

1. Compute $\mathbf{a} = \mathrm{softmax}(\mathbf{q}\mathbf{K}^\top / \sqrt{d})$
2. Top 50% of tokens -> 4-bit quantizer
3. Bottom 50% -> 2-bit quantizer
4. Average: 0.5 x 4 + 0.5 x 2 = 3 bits

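
The four steps above can be sketched as follows. This is a NumPy illustration of the bit-assignment rule only, and the function name `assign_bits` is hypothetical. It returns a per-token bit-width and leaves the actual quantization to whichever KVQuant backend is in use.

```python
import numpy as np

def assign_bits(q: np.ndarray, K: np.ndarray, b_hi: int = 4, b_lo: int = 2,
                top_frac: float = 0.5) -> np.ndarray:
    """Give the most-attended tokens b_hi bits and the rest b_lo bits."""
    d = K.shape[-1]
    scores = K @ q / np.sqrt(d)                     # (T,)
    a = np.exp(scores - scores.max())
    a /= a.sum()                                    # step 1: softmax(q K^T / sqrt(d))
    n_hi = int(round(top_frac * len(a)))
    bits = np.full(len(a), b_lo)                    # step 3: everyone starts at b_lo
    bits[np.argsort(-a)[:n_hi]] = b_hi              # step 2: top tokens get b_hi
    return bits

rng = np.random.default_rng(0)
T, d = 12, 64
K = rng.standard_normal((T, d))
q = rng.standard_normal(d)
bits = assign_bits(q, K)
print(bits, bits.mean())                            # step 4: average stays at 3.0
```

Because softmax is monotone in the raw scores, ranking by $\mathbf{q}^\top\mathbf{k}_i$ would give the same split. The softmax is kept here only to mirror step 1.
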
Results on distilgpt2 (3-bit avg):

Across all six layers, attention-weighted quantization (AWQ) cuts attention-weighted distortion by 47.5% to 70.1% versus uniform quantization at the same average bit-width. Layer 2 sees the largest gain (0.184 down to 0.055, 70.1%). That makes sense: it tends to have the most peaked attention distributions, so the high-attention tokens benefit most from the extra bits. The average reduction across all layers is 56.5%, and this is the quantity most directly tied to how much the model's outputs change.

![Figure 2: Attention-weighted bit assignment and per-layer distortion reduction](figures/fig2_awq.png)

### 3.2 Delta Compression

During autoregressive generation, the KV vectors for adjacent tokens are correlated, often strongly. The delta norm $\|\mathbf{k}_t - \mathbf{k}_{t-1}\|$ is typically much smaller than $\|\mathbf{k}_t\|$. Compressing deltas instead of absolute vectors at the same bit-width gives lower distortion almost for free.

The scheme is straightforward. Store $\mathbf{k}_0$ at full float32 precision as an anchor. Then, for each subsequent token, compress $\boldsymbol{\delta}_t = \mathbf{k}_t - \hat{\mathbf{k}}_{t-1}$ with KVQuantIP. Reconstruction accumulates:
$$\hat{\mathbf{k}}_t \;=\; \hat{\mathbf{k}}_{t-1} + \mathrm{decompress}(\boldsymbol{\delta}_t).$$

One thing to watch: over long sequences, reconstruction error can build up. The quantizer's error grows with $\|\boldsymbol{\delta}_t\|$, and $\boldsymbol{\delta}_t$ is taken against the previous *reconstruction*, so earlier error feeds into later deltas. For most use cases this isn't a problem, but two anchor strategies are available:

- The `anchor_every` parameter re-anchors at fixed intervals (e.g. every 128 tokens).
- The `anchor_threshold` parameter re-anchors adaptively when $\|\boldsymbol{\delta}_t\| / \|\mathbf{k}_t\| > \tau$. It triggers exactly when the sequence changes rapidly and error would grow most, without wasting anchors on stable regions.

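
A minimal closed-loop version of this scheme is sketched below in NumPy. To keep it self-contained, it swaps KVQuantIP for a stand-in quantizer, `toy_quantize`. This is a uniform grid scaled to the vector's largest magnitude; it is our assumption and not part of the package. The arguments `anchor_every` and `anchor_threshold` mirror the two parameters described above, but `delta_compress` itself is hypothetical. Each delta is taken against the previous reconstruction, so the encoder and decoder stay in lockstep.

```python
import numpy as np

def toy_quantize(v: np.ndarray, bits: int) -> np.ndarray:
    """Stand-in for KVQuantIP: 2**bits uniform levels spanning [-max|v|, max|v|]."""
    scale = np.abs(v).max() + 1e-12
    levels = 2 ** bits - 1
    return (2 * np.round((v / scale + 1) / 2 * levels) / levels - 1) * scale

def delta_compress(K: np.ndarray, bits: int = 3, anchor_every: int = 0,
                   anchor_threshold: float = 0.0) -> tuple[np.ndarray, list[int]]:
    """Return reconstructions k_hat_t and the positions stored as float anchors."""
    recon, anchors = [K[0].copy()], [0]              # k_0 stored at full precision
    for t in range(1, len(K)):
        delta = K[t] - recon[-1]                     # against the reconstruction
        rel = np.linalg.norm(delta) / (np.linalg.norm(K[t]) + 1e-12)
        fixed = anchor_every > 0 and t % anchor_every == 0
        adaptive = anchor_threshold > 0 and rel > anchor_threshold
        if fixed or adaptive:                        # re-anchor: store k_t exactly
            recon.append(K[t].copy())
            anchors.append(t)
        else:
            recon.append(recon[-1] + toy_quantize(delta, bits))
    return np.stack(recon), anchors

rng = np.random.default_rng(0)
T, d = 50, 64
K = np.cumsum(0.15 * rng.standard_normal((T, d)), axis=0) + rng.standard_normal(d)
K_hat, anchors = delta_compress(K, bits=3, anchor_threshold=0.5)
direct = np.stack([toy_quantize(k, 3) for k in K])
print(f"delta MSE={np.mean((K_hat - K) ** 2):.5f}  "
      f"direct MSE={np.mean((direct - K) ** 2):.5f}  anchors={anchors}")
```

With a slowly drifting sequence, the deltas span a much smaller range than the vectors themselves. The same 3-bit grid is therefore far finer on the delta path.
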
**Implementation optimisations.** Three improvements were made to the naive implementation:

- **Anchor set.** The anchor set was changed from a Python list ($O(T)$ membership test) to a hash set ($O(1)$), eliminating a quadratic scan.
- **Incremental reconstruction.** Previously the cache was rebuilt from all deltas on every `get()` call ($O(T^2)$ total). Now each `push()` appends the current running reconstruction, so `get()` only stacks the stored list: a single $O(T \cdot d)$ copy with no re-decoding. The cost is storing $T$ float32 reconstructions alongside the compressed deltas.
- **Adaptive anchors.** Anchor placement was extended with an adaptive mode that fires when $\|\boldsymbol{\delta}\|/\|\mathbf{k}\| > \tau$, triggering at actual change-points rather than fixed intervals. Setting $\tau=0$ disables it for full backwards compatibility.

Results on distilgpt2 (3-bit):

Delta compression reduces MSE across all six layers, with earlier layers benefiting most. Layers 1 and 2 see a 2.2x improvement (MSE cut to under half), while layer 5, the deepest, shows only a 1.1x gain. This gradient makes sense: early layers tend to have smoother, more predictable KV trajectories, so the deltas are smaller relative to the vectors. Later layers develop more complex, rapidly shifting representations where consecutive tokens diverge more.

**Delta + outlier combination.** Delta compression and outlier-aware quantization are complementary and can be stacked. Outlier channels (those with disproportionately high variance) also tend to be the channels with the largest delta magnitudes. Allocating extra bits to these channels at the delta compression stage targets the dominant sources of reconstruction error.

`DeltaKVCache` accepts `use_outlier=True`, which replaces the internal KVQuantIP/KVQuantMSE pair with `OutlierKVQuant` instances calibrated on the delta distribution. The asymmetric rule from Section 2.2.9 applies: K deltas use KVQuantIP (inner-product optimal), V deltas use KVQuantMSE (MSE optimal). A `calibrate(k_samples, v_samples)` method computes consecutive differences from a sample sequence. It then calibrates the outlier detectors on the resulting delta distribution rather than the raw KV distribution.

**Empirical result** ($d=64$, $T=50$, 3-bit budget, slow drift $\|\boldsymbol{\delta}_t\| \approx 0.15 \|\mathbf{k}_t\|$):

At 2.5 bits/dim, the delta+outlier config (IP/MSE) achieves a V MSE of 0.002086. That is **61.5% lower** than the 3-bit plain baseline (0.005424) while using half a bit less, and **95.4% lower** than 2-bit plain at the same budget (0.045147). K IP-error rises to 2.1816 at 2.5 bits, comparable to 2-bit plain. This is expected, and the IP quantizer remains unbiased regardless of bit-width.

![Figure 9: K IP-error (left) and V MSE (right) for plain vs delta+outlier quantization. The green bar at 2.5 bpw beats both the 3-bit and same-budget 2-bit baselines on V MSE.](figures/fig9_delta_outlier.png)

### 3.3 Adaptive Bit Allocation

Static bit allocation has an obvious limitation: you don't know which tokens will matter when you compress them. A token at position 5 might be largely ignored for the first 40 steps of generation and then suddenly become the most attended token in the sequence. Allocating bits based on importance at compression time misses this.

The fix is an EMA over attention scores. Each token maintains a score:
$$s_t \;\leftarrow\; \alpha \cdot s_{t-1} \;+\; (1-\alpha) \cdot a_t,$$
where $a_t$ is the attention weight it receives at step $t$ and $\alpha \in [0, 1)$ is the decay factor. As scores evolve, tokens move between four bit-width tiers:

- scores above $\tau_{\mathrm{hi}}$ stay at 4-bit;
- scores between $\tau_{\mathrm{mid}}$ and $\tau_{\mathrm{hi}}$ drop to 3-bit;
- scores between $\tau_{\mathrm{lo}}$ and $\tau_{\mathrm{mid}}$ drop to 2-bit;
- anything below $\tau_{\mathrm{lo}}$ is evicted to 1-bit.

Recompression happens when a token crosses a threshold. The tier assignment is reversible: a demoted token can be promoted again if its score recovers. However, re-quantizing at a higher bit-width only restores precision if the original vector is still available.

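
The tier logic can be sketched as below. The threshold values and $\alpha$ are placeholders of our own choosing, since the text does not fix them, and the function name `update_tiers` is hypothetical. The sketch only tracks scores and tiers; it does not recompress anything. The scenario replays the example from the paragraph above: a token ignored for 40 steps that then becomes the most attended.

```python
import numpy as np

def update_tiers(scores: np.ndarray, attn: np.ndarray, alpha: float = 0.9,
                 tau_hi: float = 0.10, tau_mid: float = 0.05,
                 tau_lo: float = 0.01) -> tuple[np.ndarray, np.ndarray]:
    """One EMA step over per-token attention, then map each score to a bit tier."""
    scores = alpha * scores + (1 - alpha) * attn          # s <- alpha*s + (1-alpha)*a_t
    bits = np.select([scores > tau_hi, scores > tau_mid, scores > tau_lo],
                     [4, 3, 2], default=1)                # below tau_lo: 1-bit
    return scores, bits

# Token 0 is ignored for 40 steps, then becomes the most-attended token.
scores = np.full(4, 0.25)
history = []
for step in range(80):
    if step < 40:
        attn = np.array([0.0, 0.4, 0.3, 0.3])
    else:
        attn = np.array([0.9, 0.05, 0.03, 0.02])
    scores, bits = update_tiers(scores, attn)
    history.append(bits.copy())
print(history[39], history[-1])   # token 0 demoted to 1-bit, then promoted back to 4-bit
```

The decay factor sets how quickly a token can move between tiers. With $\alpha = 0.9$, old attention is mostly forgotten after a few tens of steps.
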
As an illustrative example, consider short sequence lengths (12 tokens, distilgpt2 layer 0). No tokens get demoted there, since attention is fairly spread, and all 12 end up at 4-bit. That gives an MSE of 0.017 vs 0.064 for uniform 3-bit, so this comparison is not at a matched bit budget. This is a single-layer observation under low attention peakedness. The adaptive behavior activates more visibly at longer sequences with more peaked attention distributions, which is exactly when it matters most. A full multi-layer, multi-length evaluation is left for future work.

### 3.4 Low-Rank Error Correction

Quantization error $\mathbf{R} = \mathbf{K} - \hat{\mathbf{K}}$ isn't pure random noise: it has structure. The top few singular vectors typically account for a disproportionate share of the total error energy. This means a low-rank approximation of $\mathbf{R}$ can recover a meaningful part of the distortion cheaply.

Given $\hat{\mathbf{K}}$ from any KVQuant variant, the correction is:

1. Compute $\mathbf{R} = \mathbf{K} - \hat{\mathbf{K}}$
2. Truncated SVD: $\mathbf{R} \approx \mathbf{U}_r \boldsymbol{\Sigma}_r \mathbf{V}_r^\top$
3. Store $\mathbf{U}_s = \mathbf{U}_r \boldsymbol{\Sigma}_r$ and $\mathbf{V}_r$
4. Corrected reconstruction: $\hat{\mathbf{K}} + \mathbf{U}_s \mathbf{V}_r^\top$

Storage cost: for $T=360$, $d=64$, $r=4$ you're storing $r(T+d) = 4 \times (360+64) = 1{,}696$ floats vs $T \cdot d = 360 \times 64 = 23{,}040$ for the full residual. That is 7.4% of the cost of storing the residual exactly.

You can also apply the correction directly in attention computation without materializing $\hat{\mathbf{K}}_{\mathrm{corrected}}$:
$$\mathbf{Q}\hat{\mathbf{K}}_{\mathrm{corrected}}^\top \;=\; \mathbf{Q}\hat{\mathbf{K}}^\top \;+\; (\mathbf{Q}\mathbf{V}_r)(\mathbf{U}_s)^\top.$$
For a single query, the extra term costs $O(r(T + d))$ FLOPs: first $\mathbf{q}^\top\mathbf{V}_r$, then a product with $\mathbf{U}_s^\top$. The main score computation costs $O(T \cdot d)$, so for small $r$ the extra term is negligible.

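
The four steps and the attention-side identity can be checked numerically. The NumPy sketch below makes three substitutions, all ours:

- the first-stage quantizer is a stand-in (coarse uniform rounding), not a KVQuant variant;
- the SVD is a full `np.linalg.svd` rather than the randomized SVD mentioned in Section 5.1;
- the helper name `low_rank_correction` is hypothetical.

```python
import numpy as np

def low_rank_correction(K: np.ndarray, K_hat: np.ndarray,
                        r: int) -> tuple[np.ndarray, np.ndarray]:
    """Rank-r SVD of the residual; returns U_s = U_r Sigma_r (T, r) and V_r (d, r)."""
    U, S, Vt = np.linalg.svd(K - K_hat, full_matrices=False)   # steps 1-2
    return U[:, :r] * S[:r], Vt[:r].T                           # step 3

def mse(A: np.ndarray, B: np.ndarray) -> float:
    return float(np.mean((A - B) ** 2))

rng = np.random.default_rng(0)
T, d, r = 360, 64, 4
K = rng.standard_normal((T, d)) @ rng.standard_normal((d, d)) / np.sqrt(d)
K_hat = np.round(K * 2) / 2                                     # stand-in coarse quantizer
U_s, V_r = low_rank_correction(K, K_hat, r)
K_corr = K_hat + U_s @ V_r.T                                    # step 4

print(f"MSE before={mse(K_hat, K):.4f} after={mse(K_corr, K):.4f}  "
      f"storage={r * (T + d)} vs {T * d} floats ({r * (T + d) / (T * d):.1%})")

# Attention-side identity: Q K_corr^T == Q K_hat^T + (Q V_r) U_s^T
Q = rng.standard_normal((8, d))
assert np.allclose(Q @ K_corr.T, Q @ K_hat.T + (Q @ V_r) @ U_s.T)
```

By the Eckart-Young theorem, the error left after correction is exactly the energy in the discarded singular values. So the gain depends entirely on how concentrated the residual's spectrum is.
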
Results ($T=360$, $d=64$):

At 2-bit, rank-4 correction drops MSE from 0.252 to 0.220 (a 13% reduction), and rank-8 pushes it further to 0.198 (21%). At 3-bit the gains are 11% and 20% respectively. At 4-bit the residual is already small (0.019), yet rank-4 correction still shaves off about 11%, though the absolute benefit is modest. The improvement is consistent across bit-widths. This suggests that the low-rank structure of the quantization error is a property of the codebook rather than an artifact of aggressive compression.

![Figure 3: Low-rank correction stores rank-r SVD of the residual at 7.4% of full storage](figures/fig3_lowrank.png)

### 3.5 Product Quantization

Scalar Lloyd-Max quantization treats each post-rotation coordinate independently. This ignores residual correlations between adjacent dimensions that survive the Hadamard transform. Product Quantization (PQ; Jégou et al., 2011) exploits these correlations. It partitions the $d$-dimensional vector into $M$ subvectors of size $d^* = d / M$ and learns a separate $k$-means codebook per subspace:

$$\hat{\mathbf{k}} \;=\; \bigl[\,\hat{\mathbf{k}}^{(1)}\;\|\;\hat{\mathbf{k}}^{(2)}\;\|\;\cdots\;\|\;\hat{\mathbf{k}}^{(M)}\,\bigr], \qquad \hat{\mathbf{k}}^{(m)} = \arg\min_{\mathbf{c} \in \mathcal{C}_m} \|\mathbf{k}^{(m)} - \mathbf{c}\|_2^2$$

where $\mathcal{C}_m \subset \mathbb{R}^{d^*}$ is the $K$-entry codebook for subspace $m$, trained by $k$-means on calibration vectors.

**Storage.** Each vector is encoded as $M$ integer codes of $b = \log_2 K$ bits each:
$$\text{bits per vector} = M \cdot b \quad \Longleftrightarrow \quad \frac{M \cdot b}{d} \;\text{ bits per dimension}.$$
For $d=64$, $M=16$, $b=8$: $16 \times 8 = 128$ bits/vector $= 2$ bits/dim, identical to 2-bit scalar. PQ, however, has $K=256$ centroids per subspace (each a $d^* = 4$-dimensional vector) vs 4 levels per coordinate for 2-bit scalar. The codebooks themselves ($M \cdot K \cdot d^* = 16 \times 256 \times 4 = 16{,}384$ floats per set) are a fixed overhead not included in the per-vector figure.

**Initialisation.** Codebooks are seeded with $k$-means++ (Arthur \& Vassilvitskii, 2007), which selects initial centroids with probability proportional to $\|\mathbf{x} - \mathbf{c}_{\text{nearest}}\|_2^2$. This gives an expected $O(\log K)$ approximation to the optimal $k$-means cost and typically needs fewer iterations than random initialisation.

**Integration with KVQuant.** The Hadamard rotation is applied before the subspace split so that information is spread uniformly across subspaces. Each subvector then has approximately isotropic variance, making all $M$ codebooks equally important. Codebooks are calibrated on the actual prefill KV vectors. The centroid distribution therefore matches the true per-model, per-layer KV distribution rather than the theoretical sphere marginal.

**Attention computation.** At inference, each attention head computes:
$$\text{score}_{t} = \mathbf{q}^\top \hat{\mathbf{k}}_t = \sum_{m=1}^{M} \mathbf{q}^{(m)\top} \hat{\mathbf{k}}_t^{(m)}$$
Here $\mathbf{q}^{(m)}$ is the $m$-th subvector of the query after the same Hadamard rotation. The codeword $\hat{\mathbf{k}}_t^{(m)}$ is looked up from codebook $\mathcal{C}_m$ using stored code $c_t^{(m)}$. No full KV reconstruction is needed. The $M \cdot K$ partial products $\mathbf{q}^{(m)\top}\mathbf{c}$ can be computed once per query, after which each score is a sum of $M$ table lookups.

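
A compact PQ sketch in NumPy follows. It does three things:

- trains one codebook per subspace with plain Lloyd iterations;
- encodes vectors to $M$ one-byte codes;
- scores a query through per-subspace lookup tables.

Seeding here samples calibration points directly, a simplification of the $k$-means++ seeding described above. All names are hypothetical, the data are random stand-ins for rotated KV vectors, and the rotation step is omitted.

```python
import numpy as np

def train_pq(X: np.ndarray, M: int, K: int, iters: int,
             rng: np.random.Generator) -> np.ndarray:
    """Return codebooks of shape (M, K, d // M), one k-means codebook per subspace."""
    N, d = X.shape
    sub = X.reshape(N, M, d // M)
    books = np.empty((M, K, d // M))
    for m in range(M):
        Xm = sub[:, m]
        C = Xm[rng.choice(N, K, replace=False)]                # seed from the data
        for _ in range(iters):
            assign = ((Xm[:, None] - C[None]) ** 2).sum(-1).argmin(1)
            counts = np.bincount(assign, minlength=K)
            sums = np.zeros_like(C)
            np.add.at(sums, assign, Xm)                         # per-cluster sums
            keep = counts > 0                                   # leave empty cells alone
            C[keep] = sums[keep] / counts[keep, None]
        books[m] = C
    return books

def encode(X: np.ndarray, books: np.ndarray) -> np.ndarray:
    """Nearest centroid per subspace -> (N, M) uint8 codes (requires K <= 256)."""
    M, K, ds = books.shape
    sub = X.reshape(len(X), M, ds)
    dist = ((sub[:, :, None, :] - books[None]) ** 2).sum(-1)    # (N, M, K)
    return dist.argmin(-1).astype(np.uint8)

def scores_adc(q: np.ndarray, codes: np.ndarray, books: np.ndarray) -> np.ndarray:
    """q^T k_hat_t for every stored token via M table lookups per token."""
    M, K, ds = books.shape
    lut = np.einsum("mkd,md->mk", books, q.reshape(M, ds))      # (M, K) partial dots
    return lut[np.arange(M), codes].sum(-1)                     # (T,)

rng = np.random.default_rng(0)
d, M, K = 64, 16, 256                                           # 16 x 8 bits = 2 bits/dim
X = rng.standard_normal((2000, d)) / np.sqrt(d)                 # stand-in calibration data
books = train_pq(X, M, K, iters=10, rng=rng)
codes = encode(X[:100], books)
q = rng.standard_normal(d)
print(codes.shape, codes.dtype, M * 8 / d, "bits/dim")
```

The lookup table costs $M \cdot K \cdot d^* = K \cdot d$ multiply-adds per query, independent of sequence length. After that, each cached token costs only $M$ additions.
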
**Generation quality results** (TinyLlama-1.1B-Chat, $d=64$, prompt "What is Nihilism?", 100 tokens):

- **4-bit scalar** produces an excellent definition with named philosophers.
- **3-bit scalar** stays on-topic with minor drift.
- **2-bit scalar** collapses entirely: the output becomes incoherent within about 30 tokens, drifting into unrelated content.
- **PQ** at $M=16$, $b=8$ uses the same 128 bits/vector as 2-bit scalar but produces a coherent, correct definition across the full 100 tokens, matching 3-bit scalar quality.

The difference comes down to codebook expressiveness. 2-bit scalar has 4 levels per coordinate, while PQ has $K=256$ centroids per subspace. PQ can therefore capture the inter-dimension correlations that scalar quantization ignores entirely.

**Notable result.** PQ ($M=16$, $b=8$) matches 3-bit scalar generation quality at 2-bit scalar storage. That is a 33% storage reduction (2 vs 3 bits/dim) with no perceptible quality loss on this prompt, at a budget where 2-bit scalar quantization collapses.

![Figure 4: Left: PQ encoding schematic (split into M subvectors, each assigned to its codebook centroid). Right: storage vs generation quality; PQ at 128 bits/vector is coherent while same-budget 2-bit scalar collapses.](figures/fig4_pq.png)

---

331
+ ## 4. Full Pipeline
332
+
333
+ Each new KV pair passes through four stages before it is stored. The stages are independent each targets a different source of inefficiency so their gains compound.
334
+
335
+ **Stage 1 Delta compression.** Rather than compressing each token's key and value vectors in isolation, we compress the *change* from the previous token. Because adjacent KV vectors in a real generation stream are highly correlated, the delta is typically much smaller in magnitude than the absolute vector. The same bit-width therefore achieves lower distortion: 1.1--2.2x lower MSE across distilgpt2 layers (Section 3.2).
336
+
337
+ **Stage 2 Attention-weighted bit assignment.** Before committing to a quantizer, we rank the tokens by how much attention the current query places on them. The top half get one extra bit; the bottom half give one bit back. The average bit-width is unchanged, but the bits go where the model actually looks. This cuts attention-weighted distortion by 47--70% per layer with no storage overhead (Section 3.1).
338
+
339
+ **Stage 3 Quantization backend (choose one).** Two backends are available and are mutually exclusive per layer:
340
+
341
+ - *KVQuantIP (default)* scalar Lloyd-Max quantization with inner-product-optimal K encoding and MSE-optimal V encoding. Low-rank error correction is applied on top: the quantization residual is approximated with a rank-4 SVD and added back, recovering ~11% of the remaining MSE at 7.4% extra storage. Huffman coding of the codebook indices is available as a final lossless step, saving ~4% at 4-bit.
342
+
343
+ - *ProductKVCache (alternative):* Product Quantization splits each vector into M subvectors and encodes each with its own k-means codebook. At M=16, b=8 this matches 3-bit scalar quality at 2-bit scalar storage. No low-rank correction is needed at this operating point because PQ already captures inter-dimension correlations that scalar quantization misses.
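The NumPy sketch below makes the low-rank correction step concrete on synthetic data. It uses a per-row uniform quantizer as a stand-in for Lloyd-Max and an exact SVD in place of the randomized SVD the package uses. All names are hypothetical. The truncated SVD is the best rank-$r$ approximation of the residual (Eckart-Young), so adding it back can never increase the reconstruction error:

```python
import numpy as np


def uniform_quantize(x: np.ndarray, bits: int) -> np.ndarray:
    """Stand-in scalar quantizer: 2**bits uniform levels over each row's [min, max]."""
    lo = x.min(axis=1, keepdims=True)
    hi = x.max(axis=1, keepdims=True)
    scale = np.where(hi > lo, (hi - lo) / (2 ** bits - 1), 1.0)
    return np.round((x - lo) / scale) * scale + lo


def low_rank_correction(x: np.ndarray, x_hat: np.ndarray, rank: int):
    """Best rank-`rank` factorization (left, right) of the residual x - x_hat."""
    u, s, vh = np.linalg.svd(x - x_hat, full_matrices=False)
    left = u[:, :rank] * s[:rank]  # (T, rank), singular values folded into the left factor
    right = vh[:rank, :]           # (rank, d)
    return left, right


# Synthetic tokens with shared low-dimensional structure plus noise.
rng = np.random.default_rng(0)
T, d = 128, 64
x = rng.standard_normal((T, 8)) @ rng.standard_normal((8, d)) + 0.1 * rng.standard_normal((T, d))

x_hat = uniform_quantize(x, bits=2)
left, right = low_rank_correction(x, x_hat, rank=4)
x_corr = x_hat + left @ right

mse_plain = float(np.mean((x - x_hat) ** 2))
mse_corr = float(np.mean((x - x_corr) ** 2))
```

The two factors `left` and `right` must be stored next to the quantized codes; they are the extra storage this correction costs.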
344
+
345
+ **Stage 4: Adaptive reallocation.** During generation, each token's importance is tracked via an EMA over the attention weights it receives. As the sequence evolves, tokens cross bit-width thresholds and are recompressed up or down accordingly. This handles the common case where a token that seemed unimportant at compression time becomes critical steps later.
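Here is a small sketch of EMA-based importance tracking and re-tiering. The class name, decay, and tier thresholds are illustrative assumptions, not the package's defaults. The sketch only decides *which* tokens change tier. It leaves out how a token is recompressed to a higher bit-width, which requires a higher-precision source for that token:

```python
from dataclasses import dataclass, field
from typing import Dict, List, Tuple


@dataclass
class ImportanceTracker:
    """Hypothetical helper: EMA of the attention each token receives, mapped to bit tiers."""
    decay: float = 0.9
    # (minimum EMA score, bits), checked from the highest threshold down
    tiers: List[Tuple[float, int]] = field(
        default_factory=lambda: [(0.20, 4), (0.05, 3), (0.0, 2)])
    ema: Dict[int, float] = field(default_factory=dict)

    def update(self, attn: Dict[int, float]) -> None:
        """Fold one decoding step's attention weights (token -> weight) into the EMA."""
        for tok, w in attn.items():
            prev = self.ema.get(tok, w)  # the first observation initialises the EMA
            self.ema[tok] = self.decay * prev + (1.0 - self.decay) * w

    def bits_for(self, tok: int) -> int:
        score = self.ema.get(tok, 0.0)
        for threshold, bits in self.tiers:
            if score >= threshold:
                return bits
        return self.tiers[-1][1]

    def retier(self, current_bits: Dict[int, int]) -> Dict[int, int]:
        """Return {token: new_bits} for every token whose tier has changed."""
        changes: Dict[int, int] = {}
        for tok, bits in current_bits.items():
            new_bits = self.bits_for(tok)
            if new_bits != bits:
                changes[tok] = new_bits
        return changes


# Usage: token 0 starts unimportant, token 1 starts important, then they swap roles.
tracker = ImportanceTracker()
tracker.update({0: 0.01, 1: 0.5})
initial = {0: tracker.bits_for(0), 1: tracker.bits_for(1)}  # {0: 2, 1: 4}
for _ in range(30):
    tracker.update({0: 0.6, 1: 0.01})
changes = tracker.retier(initial)                           # {0: 4, 1: 2}
```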
346
+
347
+ ---
348
+
349
+ ## 5. Experiments
350
+
351
+ ### 5.1 Setup
352
+
353
+ We evaluate perplexity (PPL) under KV cache quantization using the generation scenario the method is designed for: a full-precision prefill populates the KV cache, the cache is then quantized, and token generation continues from the quantized cache. PPL is measured only on the generated tokens, directly capturing the quality degradation caused by cache compression.
354
+
355
+ **Models.** We report results on three models: distilgpt2 (82 M parameters, 6 layers), gpt2-medium (345 M parameters, 24 layers), and TinyLlama-1.1B-Chat-v1.0 (1.1 B parameters, 22 layers). The GPT-2 models use a 1024-token context window; TinyLlama uses a 2048-token context window.
356
+
357
+ **Protocol.** Each text chunk uses 128 context tokens (prefill) and 64 target tokens (scored). We evaluate on 50 non-overlapping chunks. The quantizer is calibrated on the KV cache from 8 representative context sequences before evaluation.
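For reference, the perplexity used here is the exponential of the mean negative log-likelihood over the scored target tokens only; prefill tokens contribute no terms. A minimal sketch, with hypothetical helper names:

```python
import math
from typing import Sequence


def perplexity(target_logprobs: Sequence[float]) -> float:
    """exp(mean negative log-likelihood) over the scored target tokens only."""
    if not target_logprobs:
        raise ValueError("need at least one scored token")
    return math.exp(-sum(target_logprobs) / len(target_logprobs))


def delta_ppl(ppl_quant: float, ppl_fp32: float) -> float:
    """dPPL = PPL_quant - PPL_fp32."""
    return ppl_quant - ppl_fp32


# Usage: 64 scored tokens, each predicted with probability 1/4, give a PPL of 4.
ppl = perplexity([math.log(0.25)] * 64)
```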
358
+
359
+ **Quantizer.** We use `OutlierKVQuant` with automatic outlier detection (`n_outlier = head_dim / 4`): outlier channels are stored at `min(bits+1, 4)` bits and regular channels at `max(bits-1, 1)` bits. Low-rank correction uses randomized SVD with rank 4.
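One consequence of this rule is that the nominal bit-width differs from the mean bit-width per coordinate. The calculation below assumes `n_outlier = head_dim // 4` and ignores norms and outlier-index metadata. The helper name is illustrative:

```python
from typing import Tuple


def outlier_bit_budget(head_dim: int, bits: int) -> Tuple[int, int, float]:
    """Return (outlier_bits, regular_bits, mean bits per coordinate) for the rule above."""
    n_outlier = head_dim // 4
    outlier_bits = min(bits + 1, 4)
    regular_bits = max(bits - 1, 1)
    total = n_outlier * outlier_bits + (head_dim - n_outlier) * regular_bits
    return outlier_bits, regular_bits, total / head_dim


budgets = {b: outlier_bit_budget(64, b) for b in (2, 3, 4)}
# {2: (3, 1, 1.5), 3: (4, 2, 2.5), 4: (4, 3, 3.25)}
```

Under these assumptions, at $d = 64$ a nominal 2-, 3-, or 4-bit configuration spends on average 1.5, 2.5, and 3.25 bits per coordinate.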
360
+
361
+ ---
362
+
363
+ ### 5.2 Perplexity vs. Bit-width
364
+
365
+ We report PPL degradation as dPPL = PPL_quant - PPL_fp32 at each bit-width.
366
+
367
+ At 4-bit, all three models hold up well: distilgpt2 adds only +1.71 PPL over its FP32 baseline of 33.51, gpt2-medium adds +1.10 over 13.38, and TinyLlama adds just +0.25 over 4.78. 3-bit is more model-dependent: TinyLlama (LlamaAttention architecture) tolerates it gracefully (+0.87 dPPL), while gpt2-medium (+6.02) and especially distilgpt2 (+21.44) degrade more steeply. At 2-bit the degradation is severe across all models (+276, +173, and +274 for distilgpt2, gpt2-medium, and TinyLlama respectively): without correction, 2-bit scalar quantization essentially breaks generation.
368
+
369
+ ---
370
+
371
+ ### 5.3 Effect of Low-Rank Correction (rank = 4)
372
+
373
+ We then repeat the evaluation with rank-4 low-rank correction applied to the quantized cache.
374
+
375
+ On distilgpt2, rank-4 correction brings 2-bit dPPL from +276.64 down to +10.89, recovering **96% of the degradation**. On gpt2-medium the recovery is **97%** (+173.61 down to +5.95). At 3-bit, the corrected cache lands within 1.6--4.3 PPL of FP32 on both models. At 4-bit the gains are smaller (0.5--0.7 PPL). The 4-bit residual is already small ($D_{\text{mse}} \leq 0.011$), and a rank-4 SVD of such a residual risks fitting numerical noise rather than real signal, so in practice we apply correction only when bits $< 4$. TinyLlama is omitted here: its 3-bit and 4-bit degradation is already below measurement noise at 50 chunks, so correction is not meaningful to report.
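The recovery percentages follow directly from the dPPL values, as the fraction of the uncorrected degradation that the correction removes:

```python
def recovery(dppl_uncorrected: float, dppl_corrected: float) -> float:
    """Fraction of the PPL degradation removed by the correction."""
    return (dppl_uncorrected - dppl_corrected) / dppl_uncorrected


distil = recovery(276.64, 10.89)  # about 0.961 (distilgpt2, 2-bit)
medium = recovery(173.61, 5.95)   # about 0.966 (gpt2-medium, 2-bit)
```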
376
+
377
+ **Notable result.** For gpt2-medium, 2-bit + rank-4 (dPPL = +5.95) is within 0.07 PPL of (and slightly better than) plain 3-bit without correction (+6.02). Rank-4 correction therefore effectively turns 2-bit storage into 3-bit quality, reducing storage by ~25% with no perceptible quality loss.
378
+
379
+ ![Figure 5: PPL degradation by bit-width (left) and effect of rank-4 correction (right)](figures/fig5_ppl.png)
380
+
381
+ ---
382
+
383
+ ## 6. Related Work
384
+
385
+ **KV cache compression.** The most common approaches evict tokens entirely. H2O (Zhang et al., 2023) drops low-attention tokens; ScissorHands (Liu et al., 2023) uses historical attention patterns to decide what to evict; StreamingLLM (Xiao et al., 2024) keeps only recent and initial tokens. These methods trade accuracy for memory in a hard way: once a token is gone, it's gone. Our approach keeps all tokens but at variable precision.
386
+
387
+ **Quantization for LLMs.** GPTQ (Frantar et al., 2023) quantizes weights using second-order error correction; AWQ (Lin et al., 2024) identifies and protects salient weight channels. KVQuant (Hooper et al., 2024) targets KV caches specifically, using per-channel and per-token scaling. Our work shares KVQuant's focus on the KV cache but targets the streaming setting and builds on the information-theoretic framework of TurboQuant (Zandieh et al., 2025).
388
+
389
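+ **Structured random projections.** QuIP (Chee et al., 2023) and QuaRot (Ashkboos et al., 2024) apply randomized orthogonal transforms before quantization, for similar reasons as TurboQuant's rotation step. QuaRot uses randomized Hadamard matrices and applies them to activations and the KV cache as well as to weights. The technique traces back to the fast Johnson-Lindenstrauss transform of Ailon & Chazelle (2006).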
390
+
391
+ **Delta coding.** Frame-differencing is fundamental in video compression (H.264 P-frames, HEVC). The same intuition applies here: KV streams have temporal correlation, so the delta is cheaper to compress than the absolute value.
392
+
393
+ ---
394
+
395
+ ## 7. Discussion and Limitations
396
+
397
+ **Composability.** The five extensions are designed to be stacked, but not all combinations are equally useful. Delta compression and attention-weighted bit assignment are complementary: delta compression exploits temporal correlation between consecutive vectors, while attention-weighted quantization exploits differences in token importance. Low-rank correction is orthogonal to both. PQ, however, is a full replacement for the scalar KVQuantIP path and cannot be trivially combined with it at the same layer; it is best treated as an alternative quantization backend.
398
+
399
+ **GQA amplification and compensation.** Models with Grouped Query Attention (GQA) share KV heads across multiple query heads. With a grouping factor $g = \text{num\_heads} / \text{kv\_heads}$, the effective per-attention-head distortion is:
400
+ $$D_{\text{eff}} \approx g \cdot D_{\text{mse}}.$$
401
+ For Qwen2.5-1.5B ($g=6$, $d=128$) even 4-bit quantization (theoretical $D \leq 0.011$) gives $D_{\text{eff}} \approx 6 \times 0.011 = 0.066$, which is large enough to corrupt generation. TinyLlama ($g=8$, $d=64$) nonetheless survives; we attribute this to its smaller head dimension, which yields a smaller per-element error.
402
+
403
+ To compensate, we solve for the effective bit-width $b_{\text{eff}}$ such that the amplified distortion matches the target $D_{\text{target}}$ of a standard $b$-bit MHA model:
404
+ $$g \cdot \frac{\sqrt{3}\,\pi}{2} \cdot 4^{-b_{\text{eff}}} \;=\; \frac{\sqrt{3}\,\pi}{2} \cdot 4^{-b}$$
405
+ $$\Rightarrow \quad b_{\text{eff}} \;=\; b + \log_4(g) \;=\; b + \frac{\log g}{\log 4}.$$
406
+ Since bit-widths must be integers, we round up: $b_{\text{eff}} = b + \lceil \log_4 g \rceil$. For $g=4$ this adds $+1$ bit; for $g=6$ or $g=7$ it adds $+2$ bits; for $g=32$ it adds $+3$ bits. Effective bits are capped at 8, which is the maximum supported by the Lloyd-Max solver (256 centroids). The same adjustment is applied to the outlier and regular channel bit-widths in `OutlierKVQuant` to ensure GQA compensation is not bypassed by caller-supplied explicit bit-widths.
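The rounding rule can also be computed without floating-point logarithms. The sketch below (hypothetical helper names) finds the smallest $e$ with $4^e \geq g$. For integer $g \geq 1$ this equals $\lceil \log_4 g \rceil$, and it avoids relying on `math.log` returning an exact integer when $g$ is a power of 4:

```python
def gqa_extra_bits(gqa_factor: int) -> int:
    """ceil(log4(gqa_factor)) for integer gqa_factor >= 1, using only integer arithmetic."""
    extra = 0
    while 4 ** extra < gqa_factor:
        extra += 1
    return extra


def effective_bits(num_bits: int, gqa_factor: int, max_bits: int = 8) -> int:
    """Base bit-width plus GQA compensation, capped at the Lloyd-Max solver's maximum."""
    return min(num_bits + gqa_extra_bits(gqa_factor), max_bits)


table = {g: gqa_extra_bits(g) for g in (1, 4, 6, 7, 8, 32)}
# {1: 0, 4: 1, 6: 2, 7: 2, 8: 2, 32: 3}
```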
407
+
408
+ This is a fundamental limitation of scalar quantization for high-GQA models: requiring $g \cdot \frac{\sqrt{3}\,\pi}{2} \cdot 4^{-b} \leq D_{\max}$ gives a minimum safe bit-width of approximately $b \geq \log_4 \!\left(\frac{\sqrt{3}\,\pi\,g}{2\,D_{\max}}\right) = \frac{1}{2} \log_2 \!\left(\frac{\sqrt{3}\,\pi\,g}{2\,D_{\max}}\right)$. The GQA compensation is implemented in `KVCacheQuantizer` via a `gqa_factor` parameter:
409
+
410
+ ```python
411
+ import math
+
+ gqa_extra = math.ceil(math.log(max(gqa_factor, 1), 4)) if gqa_factor > 1 else 0
412
+ effective_bits = min(num_bits + gqa_extra, 8)
413
+ ```
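Plugging numbers into the minimum-bit-width bound for Qwen2.5-1.5B ($g = 6$), with the target set to the 4-bit MHA bound, reproduces the $b + \lceil \log_4 g \rceil = 6$ result. The exact 4-bit bound is $\frac{\sqrt{3}\pi}{2} \cdot 4^{-4} \approx 0.0106$, which the text rounds to 0.011. The helper names are illustrative:

```python
import math

C = math.sqrt(3) * math.pi / 2  # constant in the bound D_mse <= C * 4**(-b)


def mse_bound(bits: float) -> float:
    """Theoretical D_mse bound at `bits` bits, as used in the text."""
    return C * 4.0 ** (-bits)


def min_safe_bits(gqa_factor: int, d_max: float) -> float:
    """Smallest real b with gqa_factor * C * 4**(-b) <= d_max."""
    return math.log(C * gqa_factor / d_max, 4)


d_target = mse_bound(4)             # 4-bit MHA target, about 0.0106
b_min = min_safe_bits(6, d_target)  # Qwen2.5-1.5B (g = 6): 4 + log4(6), about 5.29
bits_needed = math.ceil(b_min)      # 6
```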
414
+
415
+ **Per-layer calibration.** The original implementation used a single `KVCacheQuantizer` instance calibrated on KV data pooled from all transformer layers. This is incorrect for `OutlierKVQuant`: outlier channels are defined as those with the highest variance in the calibration data, and different transformer layers have substantially different KV distributions. Layer 0 might have outlier variance concentrated in dimensions 12, 47, and 93; layer 15 might have it in dimensions 5, 61, and 120. Pooling calibration data across layers averages out these layer-specific patterns, causing the outlier detector to misidentify channels for every individual layer.
416
+
417
+ The correct approach is one `KVCacheQuantizer` per transformer layer, each calibrated independently on its own layer's KV data. The per-layer calibration loop in `demo_llm.py` is:
418
+
419
+ ```python
420
+ kvc_layers = []
421
+ for lk, lv in cal_kvs: # cal_kvs[i] = (k_layer_i, v_layer_i)
422
+ kvc_l = KVCacheQuantizer(head_dim=head_dim, num_bits=bits,
423
+ use_outlier=True, gqa_factor=gqa_factor, ...)
424
+ kvc_l.calibrate(lk, lv) # calibrate on this layer's data only
425
+ kvc_layers.append(kvc_l)
426
+ ```
427
+
428
+ This is consistent with the method's framework: outlier channel detection should use each layer's own KV statistics. The underlying Lloyd-Max quantization theory is unchanged; only the channel identification step is now layer-specific. For MHA models with uniform KV distributions across layers, a single shared quantizer remains acceptable, but for GQA models with deep per-layer specialisation, per-layer calibration is essential for correct outlier identification.
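The failure mode of pooled calibration is easy to reproduce on synthetic data. The sketch assumes outliers are detected as the top-variance channels, as described above for `OutlierKVQuant`; the helpers themselves are hypothetical. Two layers with disjoint outlier channels are each detected correctly on their own. Pooled detection recovers only half of the eight true outlier channels:

```python
import random
from statistics import pvariance
from typing import List, Set


def top_variance_channels(rows: List[List[float]], n_outlier: int) -> List[int]:
    """Indices of the n_outlier highest-variance channels, in ascending order."""
    d = len(rows[0])
    variances = [pvariance([r[c] for r in rows]) for c in range(d)]
    ranked = sorted(range(d), key=lambda c: variances[c], reverse=True)
    return sorted(ranked[:n_outlier])


def synthetic_layer(outliers: Set[int], seed: int,
                    n_tokens: int = 256, d: int = 16) -> List[List[float]]:
    """Synthetic KV rows: unit-variance Gaussian, with std 10 on the outlier channels."""
    rng = random.Random(seed)
    return [[rng.gauss(0.0, 10.0 if c in outliers else 1.0) for c in range(d)]
            for _ in range(n_tokens)]


true_a, true_b = {1, 5, 9, 13}, {2, 6, 10, 14}
layer_a = synthetic_layer(true_a, seed=0)
layer_b = synthetic_layer(true_b, seed=1)

per_layer = [top_variance_channels(layer_a, 4), top_variance_channels(layer_b, 4)]
pooled = top_variance_channels(layer_a + layer_b, 4)
missed = (true_a | true_b) - set(pooled)  # outlier channels the pooled detector misses
```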
429
+
430
+ **PQ encode speed.** The current PQ implementation encodes new tokens with $M$ sequential `cdist` calls in Python. For $M=16$, $K=256$, this is roughly 16x slower per append than scalar `bucketize`. At short contexts the overhead is acceptable; at very long contexts (thousands of tokens appended during generation) it becomes the bottleneck. A batched CUDA kernel that performs all $M$ nearest-centroid lookups in one pass would close most of this gap.
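The batching idea can be sketched in NumPy. Instead of $M$ separate `cdist` calls, it computes all subspace distances at once via $\|x - c\|^2 = \|x\|^2 - 2\,x \cdot c + \|c\|^2$ and a single `einsum`. This is an illustrative CPU sketch with hypothetical function names, not the proposed CUDA kernel. The $\|x\|^2$ term is dropped because it does not affect the argmin within a subspace:

```python
import numpy as np


def pq_encode_batched(x: np.ndarray, codebooks: np.ndarray) -> np.ndarray:
    """Nearest-centroid PQ encoding for all M subspaces in one vectorized pass.

    x:         (N, D) vectors, with D = M * s
    codebooks: (M, K, s), one codebook of K centroids per subspace
    returns:   (N, M) integer codes
    """
    M, K, s = codebooks.shape
    sub = x.reshape(x.shape[0], M, s)                       # (N, M, s)
    cross = np.einsum("nms,mks->nmk", sub, codebooks)      # x.c for every subspace/centroid
    c_sq = np.einsum("mks,mks->mk", codebooks, codebooks)  # ||c||^2, shape (M, K)
    return np.argmin(c_sq[None, :, :] - 2.0 * cross, axis=-1)


def pq_decode(codes: np.ndarray, codebooks: np.ndarray) -> np.ndarray:
    """Look up each code in its subspace codebook and concatenate the subvectors."""
    M = codebooks.shape[0]
    parts = codebooks[np.arange(M)[None, :], codes]  # (N, M, s)
    return parts.reshape(codes.shape[0], -1)


# Usage on random data: M = 16 subspaces, K = 256 centroids, 4 dims each (D = 64).
rng = np.random.default_rng(0)
M, K, s, N = 16, 256, 4, 32
codebooks = rng.standard_normal((M, K, s))
codes0 = rng.integers(0, K, size=(N, M))
x = pq_decode(codes0, codebooks)         # vectors that sit exactly on centroids
codes = pq_encode_batched(x, codebooks)  # should recover codes0
```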
431
+
432
+ **Calibration dependency.** PQ codebooks are trained on prefill KV vectors and held fixed during generation. If the generation distribution drifts significantly from the prefill distribution (e.g., very different topic or language), codebook quality degrades. Periodic re-calibration or online codebook updates are not implemented.
433
+
434
+ **Scope of PPL evaluation.** Perplexity numbers are reported on a fixed 50-chunk evaluation with 128-token prefill and 64 target tokens. This captures steady-state quantization quality but does not measure latency, throughput, or memory footprint directly. Real deployment decisions require profiling on target hardware.
435
+
436
+ ---
437
+
438
+ ## 8. Conclusion
439
+
440
+ The core insight behind TurboQuant (rotate into an approximately isotropic distribution, then apply optimal 1-D quantization) is sound and gives strong theoretical guarantees. What we've shown here is that there's significant headroom beyond those guarantees if you're willing to exploit the structure of how transformers actually use the KV cache.
441
+
442
+ Attention-weighted quantization aligns the bit budget with what the model actually attends to. Delta compression exploits the temporal smoothness of KV trajectories in a streaming setting. Adaptive allocation adjusts to importance that you couldn't have known at compression time. Low-rank correction recovers structure from an error that isn't as random as you might assume.
443
+
444
+ None of these require modifying the model or changing the training procedure. They're all implemented as composable PyTorch modules in `kvquant/`, and they can be adopted in any combination. Four further improvements strengthen the implementation: k-means++ seeding reduces Lloyd-Max initialisation MSE by up to 75% at low bit-widths; K-V asymmetric quantization cuts V reconstruction error by 61.5% at a lower bit budget; combining delta compression with outlier-aware quantization reduces V MSE by 95.4% versus same-budget scalar; and Hadamard rotation is now a configurable parameter throughout the stack.
445
+
446
+ Two additional fixes address non-MHA architectures. GQA models amplify effective distortion by $g$ (query heads per KV head); compensating with $\lceil \log_4 g \rceil$ extra bits per coordinate, capped at 8, restores generation quality on Qwen2.5-1.5B ($g=6$) and Qwen2.5-7B ($g=7$). Per-layer calibration of the outlier detector, rather than pooling KV data across all transformer layers, correctly identifies the layer-specific channels that carry anomalous variance. The full test suite (88 tests) passes cleanly.
447
+
448
+ ---
449
+
450
+ ## 9. Implementation Notes
451
+
452
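+ Most of these are runtime optimizations applied after the paper's algorithms were finalized. They do not change the quality numbers in Sections 3-3.4, but they reduce wall-clock time substantially. The exceptions are Section 9.5, a correctness fix that changes the first generated token, and the adaptive anchoring in Section 9.6, which changes delta-cache behaviour when enabled.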
453
+
454
+ ### 9.1 Batched `get()` in `AdaptiveKVCache`
455
+
456
+ **Problem.** The original `get()` called `dequantize()` once per cached token, resulting in $T$ sequential Python dispatch calls and $T$ small matrix multiplies. For $T=128$ this was 5.4 ms, growing linearly with sequence length.
457
+
458
+ **Fix.** Group token indices by bit-width tier, then dequantize all tokens in each tier with a single batched call:
459
+
460
+ ```python
461
+ from collections import defaultdict
+
+ # before: T individual dequantize calls
462
+ k_out = [self._dequantize(self._k_entries[t]) for t in range(T)]
463
+
464
+ # after: one call per tier (typically 2-4 calls total)
465
+ tier_idx = defaultdict(list)
466
+ for t, e in enumerate(self._k_entries):
467
+ tier_idx[e.bits].append(t)
468
+
469
+ for bits, idxs in tier_idx.items():
470
+ quantizer = self._quantizers[str(bits)]
471
+ k_batch = self._batch_dequantize(
472
+ quantizer, [self._k_entries[t].q for t in idxs]
473
+ )
474
+ ```
475
+
476
+ `_batch_dequantize` concatenates all indices and norms along the batch axis, calls `dequantize` once, then splits the result:
477
+
478
+ ```python
479
+ def _batch_dequantize(self, quantizer, qs):
480
+ n = len(qs)
481
+ BH = qs[0].indices.reshape(-1, self.head_dim).shape[0]
482
+ all_idx = torch.cat([q.indices.reshape(-1, self.head_dim) for q in qs])
483
+ all_norms = torch.cat([q.norms.reshape(-1, 1) for q in qs])
484
+ combined = QuantizedMSE(all_idx, all_norms, (n * BH, self.head_dim))
485
+ result = quantizer.dequantize(combined) # one BLAS call
486
+ return result.reshape(n, BH, self.head_dim)
487
+ ```
488
+
489
+ For $T=128$ with 4 tiers this reduces from 128 Python dispatch calls to ~4, giving roughly 4x reduction in `get()` overhead and better BLAS utilisation.
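The snippets above leave out how the per-tier results are put back in token order. Below is a minimal, framework-free sketch of the full group, batch, and scatter pattern. A fake dequantizer stands in for the real one, and all names are hypothetical:

```python
from collections import defaultdict
from typing import Any, Callable, Dict, List, Sequence, Tuple


def batched_get(entries: Sequence[Tuple[int, Any]],
                dequantize_batch: Callable[[int, List[Any]], List[Any]]) -> List[Any]:
    """entries[t] = (bits, payload) for token t. Calls dequantize_batch once per
    bit-width tier and returns the results in the original token order."""
    tiers: Dict[int, List[int]] = defaultdict(list)
    for t, (bits, _) in enumerate(entries):
        tiers[bits].append(t)

    out: List[Any] = [None] * len(entries)
    for bits, idxs in tiers.items():
        decoded = dequantize_batch(bits, [entries[t][1] for t in idxs])  # one call per tier
        for t, value in zip(idxs, decoded):
            out[t] = value  # scatter back to the token's original position
    return out


# Usage with a fake dequantizer that records how often it is called.
calls: List[int] = []


def fake_dequantize(bits: int, payloads: List[int]) -> List[Tuple[int, int]]:
    calls.append(bits)
    return [(bits, 10 * p) for p in payloads]


result = batched_get([(4, 1), (2, 2), (4, 3), (3, 4), (2, 5)], fake_dequantize)
```

Five tokens in three tiers cost three dequantizer calls, and the outputs come back in the original order.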
490
+
491
+ ### 9.2 Randomized SVD with Power Iteration in `LowRankCorrection`
492
+
493
+ **Problem.** The full `torch.linalg.svd` in `LowRankCorrection.quantize()` is $O(T \cdot d^2)$ and allocates a $(T, d)$ temporary. For $T=512$, $d=128$ this dominates the quantize step.
494
+
495
+ **Fix.** Replace with a randomized SVD (Halko et al., 2011, Algorithm 4.4) that only computes the top-$r$ singular vectors:
496
+
497
+ ```python
498
+ import torch
+
+ def _randomized_svd(A, rank, n_oversampling=10, n_power_iter=2):
499
+ N, m, n = A.shape
500
+ k = min(rank + n_oversampling, min(m, n))
501
+ # Random Gaussian sketch
502
+ Omega = torch.randn(N, n, k, device=A.device, dtype=A.dtype)
503
+ Y = A @ Omega
504
+ # Power iteration: refine the range estimate
505
+ for _ in range(n_power_iter):
506
+ Q, _ = torch.linalg.qr(Y)
507
+ Z, _ = torch.linalg.qr(A.transpose(-2, -1) @ Q)
508
+ Y = A @ Z
509
+ # Small exact SVD in the sketched subspace
510
+ Q, _ = torch.linalg.qr(Y)
511
+ B = Q.transpose(-2, -1) @ A # (N, k, n)
512
+ U_hat, S, Vh = torch.linalg.svd(B, full_matrices=False)
513
+ U = Q @ U_hat
514
+ return U[..., :rank], S[..., :rank], Vh[..., :rank, :]
515
+ ```
516
+
517
+ `n_oversampling=10` and `n_power_iter=2` bring approximation error to within 1% of the full SVD, and the randomized version is faster for $T \geq 64$ (from $1.2\times$ at $T=64$ up to $2.6\times$ at $T=512$). For short sequences ($T < 64$) the full SVD has lower fixed overhead and is used instead:
518
+
519
+ ```python
520
+ if T_seq >= 64:
521
+ U, S, Vh = _randomized_svd(residual_flat, rank=r)
522
+ else:
523
+ U, S, Vh = torch.linalg.svd(residual_flat, full_matrices=False)
524
+ U, S, Vh = U[..., :r], S[..., :r], Vh[..., :r, :]
525
+ ```
526
+
527
+ At short sequences (T=32) the randomized SVD is actually slower (0.67 ms vs 0.41 ms for full SVD) due to fixed sketch overhead, which is why the full SVD is used below T=64. From T=64 upward the randomized approach wins: 1.2x faster at T=64, 1.8x at T=128, 2.5x at T=256, and 2.6x at T=512 (2.44 ms vs 6.37 ms). Approximation error stays within 1% of the full SVD result across all sequence lengths tested.
528
+
529
+ ### 9.3 `_dequantize_unit` Fast Path in `KVQuantIP`
530
+
531
+ **Problem.** `KVQuantIP.dequantize()` calls `self.mse_quantizer.dequantize(q_mse)` to recover the MSE component. But since the input to the MSE stage is already unit-normalised, the stored norms are always 1.0, so allocating a `(N, 1)` ones tensor and multiplying by it on every call is pure overhead.
532
+
533
+ **Fix.** Add a `_dequantize_unit` path to `KVQuantMSE` that skips the norm multiply entirely:
534
+
535
+ ```python
536
+ def _dequantize_unit(self, idx_flat: Tensor) -> Tensor:
537
+ """Fast path for unit-norm vectors skips norm restore."""
538
+ y_tilde = self.centroids[idx_flat] # (N, d)
539
+ return self.rotation.inverse(y_tilde) # (N, d)
540
+ ```
541
+
542
+ `KVQuantIP.dequantize()` calls this instead of the full path:
543
+
544
+ ```python
545
+ # before
546
+ x_hat_unit = self.mse_quantizer.dequantize(q_mse) # allocates dummy norms
547
+
548
+ # after
549
+ x_hat_unit = self.mse_quantizer._dequantize_unit(idx_flat) # no alloc
550
+ ```
551
+
552
+ The saving is modest for large batches (1.08x at $N=4096$, $d=128$) but eliminates one unnecessary allocation per call.
553
+
554
+ ### 9.4 Boundary Caching in `build_codebook`
555
+
556
+ The $k-1$ centroid midpoints (quantization boundaries used by `torch.bucketize`) were previously recomputed on every `KVQuantMSE` instantiation. They are now computed once and cached alongside the centroids:
557
+
558
+ ```python
559
+ # codebook.py
560
+ _CACHE: dict[tuple[int, int], tuple[Tensor, Tensor]] = {}
561
+
562
+ def build_codebook(num_bits, dim=1, device=None) -> tuple[Tensor, Tensor]:
563
+ key = (num_bits, dim)
564
+ if key not in _CACHE:
565
+ c = _lloyd_max(num_bits, dim)
566
+ b = ((c[:-1] + c[1:]) / 2).contiguous()
567
+ _CACHE[key] = (c, b)
568
+ centroids, boundaries = _CACHE[key]
569
+ ...
570
+ return centroids, boundaries
571
+ ```
572
+
573
+ `KVQuantMSE.__init__` now registers both as buffers:
574
+
575
+ ```python
576
+ centroids, boundaries = build_codebook(num_bits, dim)
577
+ self.register_buffer("centroids", centroids)
578
+ self.register_buffer("boundaries", boundaries)
579
+ ```
580
+
581
+ The saving is negligible in practice (the $k-1$ additions are trivial), but it removes a recompute and makes the caching contract explicit.
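Cached midpoints suffice because, for a sorted 1-D codebook, a binary search over the midpoints finds the nearest centroid. The stdlib sketch below uses a toy uniform codebook, not the Lloyd-Max one, memoized with `functools.lru_cache`. `bisect.bisect_left` follows the same convention as `torch.bucketize` with `right=False`. All names are hypothetical:

```python
import random
from bisect import bisect_left
from functools import lru_cache
from typing import Tuple


@lru_cache(maxsize=None)
def toy_codebook(num_bits: int) -> Tuple[Tuple[float, ...], Tuple[float, ...]]:
    """Stand-in codebook (uniform on [-1, 1], NOT Lloyd-Max): sorted centroids plus
    the k-1 midpoints between neighbours, computed once per num_bits and cached."""
    k = 2 ** num_bits
    centroids = tuple(-1.0 + (2 * i + 1) / k for i in range(k))
    boundaries = tuple((a + b) / 2 for a, b in zip(centroids, centroids[1:]))
    return centroids, boundaries


def nearest_index(x: float, num_bits: int) -> int:
    """Binary search over the cached midpoints: O(log k) instead of O(k) per coordinate."""
    _, boundaries = toy_codebook(num_bits)
    return bisect_left(boundaries, x)


# Check against brute-force nearest-centroid search on random inputs.
centroids, _ = toy_codebook(3)
rng = random.Random(0)
samples = [rng.uniform(-1.2, 1.2) for _ in range(1000)]
mismatches = sum(
    nearest_index(x, 3) != min(range(len(centroids)), key=lambda i: abs(x - centroids[i]))
    for x in samples
)
```

Returning immutable tuples from the cache also means no caller can accidentally modify a shared codebook. With cached tensors, the same guarantee relies on callers never modifying them in place.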
582
+
583
+ ### 9.5 First-Token Accuracy in Quantized Generation
584
+
585
+ **Problem.** During quantized generation in `demo_llm.py`, token 1 (the first generated token) was produced from the unquantized prefill logit, the same logit that every bit-width variant saw, so all quantized modes produced identical first tokens regardless of quantization quality. Only from token 2 onward, when the quantized KV cache was actually used for attention, did the bit-widths diverge.
586
+
587
+ Root cause: `first_logits = prefill_out.logits[:, -1, :]` is the last prefill position's output computed with the full float32 KV cache. After quantizing the cache to `past`, this `first_logits` variable was reused unchanged for all bit-width branches.
588
+
589
+ **Fix.** Crop the quantized cache to $T_p - 1$ positions, then re-run the last prompt token through the model with that cropped cache to obtain a logit that reflects the quantized state:
590
+
591
+ ```python
592
+ import copy
+
+ import torch
+
+ def _crop_cache(native_cache, seq_len: int):
593
+ """Return a deep copy of native_cache with KV tensors truncated to seq_len."""
594
+ cache = copy.deepcopy(native_cache)
595
+ if hasattr(cache, "key_cache"): # DynamicCache / HybridCache
596
+ for i in range(len(cache.key_cache)):
597
+ k = cache.key_cache[i]
598
+ if isinstance(k, torch.Tensor) and k.shape[-2] > seq_len:
599
+ cache.key_cache[i] = k[..., :seq_len, :]
600
+ cache.value_cache[i] = cache.value_cache[i][..., :seq_len, :]
601
+ return cache
602
+
603
+ # --- inside the generation loop ---
604
+ past = _quantize_cache(native_cache_orig, kvc, correction_rank=args.correction_rank)
605
+ past_crop = _crop_cache(past, T_p - 1) # crop to T_p-1 positions
606
+ with torch.no_grad():
607
+ q1_out = model(input_ids[:, -1:], past_key_values=past_crop, use_cache=False)
608
+ first_logits_m = q1_out.logits[:, -1, :].clone() # quantized first-token logit
609
+ ```
610
+
611
+ The crop-and-rerun costs one extra forward pass. It runs through all layers, but with a length-1 input, so its attention cost is $O(T_p \cdot d \cdot \text{layers})$, which is small relative to the $O(T_g \cdot \ldots)$ generation loop for any reasonable $T_g$.
612
+
613
+ **Effect.** Before the fix, `demo_llm.py` with a 3-bit Qwen2.5-1.5B-Instruct (1.5 B parameters, GQA with $g=6$, $d=128$) on `"France Capital City :"` produced:
614
+
615
+ ```
616
+ 3-bit: France Capital City : [Chinese characters "Faguo Bali, Paris, is the capital of France..."]
617
+ ```
618
+
619
+ (mixing Chinese and English mid-sentence, corrupted by the token-1 misalignment). After:
620
+
621
+ ```
622
+ 3-bit: France Capital City : Paris
623
+ ```
624
+
625
+ The first generated token is now `Paris` at all quantized bit-widths, matching the float32 reference. This indicates that the corruption came from the first-token logit selection rather than from the quantized cache itself.
626
+
627
+ ![Figure 6: Before and after the first-token fix: crop the cache to T_p-1 and re-run the last prompt token](figures/fig6_firsttoken.png)
628
+
629
+ ### 9.6 Three Optimisations in `DeltaKVCache`
630
+
631
+ Three performance and correctness issues were identified and fixed in the delta compression implementation after initial deployment.
632
+
633
+ **Fix 1: $O(T^2)$ reconstruction cost.** The original `get()` method rebuilt the full cache from scratch on every call by looping over all $T$ tokens and dequantizing each stored delta. Since `get()` is called at every attention step during generation, the total reconstruction cost was $O(T^2)$. The fix maintains two lists, `_k_reconstructed` and `_v_reconstructed`, incrementally inside `push()`: after each token is compressed, the current running reconstruction is appended. `get()` then just calls `torch.stack()` and returns, with no dequantization at read time (the stack itself is still a linear-time copy). Trade-off: the reconstructed float32 vectors are stored permanently alongside the compressed deltas, increasing persistent memory by $4Td$ bytes for each of the K and V lists.
634
+
635
+ **Fix 2: $O(T)$ anchor lookup.** `_anchors` was a `list[int]`; Python's `in` operator on a list is $O(n)$. With $T$ membership checks per `get()` call this was $O(T^2)$ just for anchor lookups. Changing `_anchors` to a `set[int]` (hash set, $O(1)$ average lookup) and `.append()` to `.add()` eliminates this with no other trade-off.
636
+
637
+ **Fix 3: Adaptive anchor placement.** The original `anchor_every=N` parameter re-anchors at fixed positions regardless of whether the sequence is actually drifting. A sudden large change at step 20 with `anchor_every=32` accumulates error until step 32, while scheduled anchors are wasted on stable regions. The new `anchor_threshold` parameter re-anchors when $\|\boldsymbol{\delta}_t\| / \|\mathbf{k}_t\| > \tau$, triggering exactly at change-points. The default $\tau = 0$ is treated as "off": it disables adaptive anchoring for full backwards compatibility.
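Below is a sketch of threshold-based anchor selection, storing anchors in a set as in Fix 2. It simplifies under stated assumptions. The names are hypothetical, and $\boldsymbol{\delta}_t$ is taken as the difference between consecutive true vectors rather than the difference against the running reconstruction. The example stream mimics a sudden drift at position 15 of a 30-token sequence:

```python
import math
from typing import List, Set


def choose_anchors(stream: List[List[float]], anchor_threshold: float) -> Set[int]:
    """Position 0 is always an anchor; afterwards re-anchor at t when
    ||x_t - x_{t-1}|| / ||x_t|| > anchor_threshold. A threshold of 0 means 'off'."""
    anchors: Set[int] = {0}  # a set gives O(1) average membership tests
    if anchor_threshold <= 0:
        return anchors
    for t in range(1, len(stream)):
        delta_norm = math.dist(stream[t], stream[t - 1])
        norm = math.hypot(*stream[t])
        if norm > 0 and delta_norm / norm > anchor_threshold:
            anchors.add(t)
    return anchors


# Usage: slow drift along one axis, then a jump to another axis at t = 15.
stream = ([[1.0 + 0.001 * t, 0.0, 0.0] for t in range(15)]
          + [[0.0, 3.0 + 0.001 * t, 0.0] for t in range(15, 30)])
adaptive = choose_anchors(stream, anchor_threshold=0.4)  # {0, 15}
disabled = choose_anchors(stream, anchor_threshold=0.0)  # {0}
```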
638
+
639
+ **Empirical result** ($T=30$, 3-bit, sudden drift at $t=15$):
640
+
641
+ With no anchor beyond the initial one, MSE accumulates to 0.116. A fixed `anchor_every=32` schedule would place its second anchor at position 32. That is after the drift at position 15 and beyond the end of the 30-token sequence, so it does not help (MSE 0.116, essentially the same). Adaptive anchoring with `anchor_threshold=0.4` fires its second anchor at position 15, exactly where the drift happens, reducing MSE to 0.00126. That is a **98.9% reduction** with only one anchor beyond the initial one.
642
+
643
+ All three fixes are covered by 10 new tests in `TestDeltaKVCache`; the full suite of 88 tests passes.
644
+
645
+ ---
646
+
647
+ ## References
648
+
649
+ 1. Zandieh, A. et al. "TurboQuant: Online Vector Quantization with Near-optimal Distortion Rate." arXiv:2504.19874 (2025).
650
+ 2. Ailon, N. & Chazelle, B. "Approximate nearest neighbors and the fast Johnson-Lindenstrauss transform." STOC (2006).
651
+ 3. Zhang, Z. et al. "H2O: Heavy-Hitter Oracle for Efficient Generative Inference of Large Language Models." NeurIPS (2023).
652
+ 4. Frantar, E. et al. "GPTQ: Accurate Post-Training Quantization for Generative Pre-trained Transformers." ICLR (2023).
653
+ 5. Lin, J. et al. "AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration." MLSys (2024).
654
+ 6. Hooper, C. et al. "KVQuant: Towards 10 Million Context Length LLM Inference with KV Cache Quantization." NeurIPS (2024).
655
+ 7. Chee, J. et al. "QuIP: 2-Bit Quantization of Large Language Models with Guarantees." NeurIPS (2023).
656
+ 8. Ashkboos, S. et al. "QuaRot: Outlier-Free 4-Bit Inference in Rotated LLMs." arXiv:2404.00456 (2024).
657
+ 9. Xiao, G. et al. "Efficient Streaming Language Models with Attention Sinks." ICLR (2024).
658
+ 10. Liu, Z. et al. "ScissorHands: Exploiting the Persistence of Importance Hypothesis for LLM KV Cache Compression at Test Time." NeurIPS (2023).
659
+ 11. Halko, N., Martinsson, P.-G. & Tropp, J. "Finding Structure with Randomness: Probabilistic Algorithms for Constructing Approximate Matrix Decompositions." SIAM Review (2011).
660
+ 12. Max, J. "Quantizing for minimum distortion." IRE Transactions on Information Theory (1960).
661
+ 13. Jégou, H., Douze, M. & Schmid, C. "Product Quantization for Nearest Neighbor Search." IEEE Transactions on Pattern Analysis and Machine Intelligence 33(1):117-128 (2011).
662
+ 14. Arthur, D. & Vassilvitskii, S. "k-means++: The Advantages of Careful Seeding." Proceedings of the 18th Annual ACM-SIAM Symposium on Discrete Algorithms (SODA), 1027-1035 (2007).
@@ -0,0 +1,5 @@
1
+ kvquant_plus_plus-0.1.0.dist-info/licenses/LICENSE,sha256=w5MmkL55AXQuC-89qYG9BAas1F7O9PWjd6AbW1_KPHo,11548
2
+ kvquant_plus_plus-0.1.0.dist-info/METADATA,sha256=GueZhLkw_Hflk6ACaB7M2oQiuig--ilLV28i8rWuAmo,55715
3
+ kvquant_plus_plus-0.1.0.dist-info/WHEEL,sha256=aeYiig01lYGDzBgS8HxWXOg3uV61G9ijOsup-k9o1sk,91
4
+ kvquant_plus_plus-0.1.0.dist-info/top_level.txt,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
5
+ kvquant_plus_plus-0.1.0.dist-info/RECORD,,
@@ -0,0 +1,5 @@
1
+ Wheel-Version: 1.0
2
+ Generator: setuptools (82.0.1)
3
+ Root-Is-Purelib: true
4
+ Tag: py3-none-any
5
+
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright 2026 Syed Muheeb Uddin
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.