turboquant-gpu 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2026 Anirudh Bharadwaj Vangara
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
@@ -0,0 +1,96 @@
1
+ Metadata-Version: 2.4
2
+ Name: turboquant-gpu
3
+ Version: 0.1.0
4
+ Summary: TurboQuant KV cache compression for LLM inference — cuTile GPU kernels
5
+ Author: Anirudh Bharadwaj Vangara
6
+ License-Expression: MIT
7
+ Keywords: quantization,kv-cache,llm,inference,cutile,cuda,gpu,attention,blackwell,hopper,h100,b200
8
+ Classifier: Development Status :: 3 - Alpha
9
+ Classifier: Intended Audience :: Science/Research
10
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
11
+ Classifier: Programming Language :: Python :: 3
12
+ Classifier: Programming Language :: Python :: 3.10
13
+ Classifier: Programming Language :: Python :: 3.11
14
+ Classifier: Programming Language :: Python :: 3.12
15
+ Requires-Python: >=3.10
16
+ Description-Content-Type: text/markdown
17
+ License-File: LICENSE
18
+ Requires-Dist: torch
19
+ Requires-Dist: scipy
20
+ Provides-Extra: gpu
21
+ Requires-Dist: cuda-tile; extra == "gpu"
22
+ Dynamic: license-file
23
+
24
+ # turboquant-gpu
25
+
26
+ **5.02x KV cache compression for LLM inference** — GPU-accelerated cuTile kernels with PyTorch fallback.
27
+
28
+ ```
29
+ pip install turboquant-gpu
30
+ ```
31
+
32
+ ## quick start
33
+
34
+ ```python
35
+ from transformers import AutoModelForCausalLM, AutoTokenizer
36
+ from turboquant_gpu import TurboQuantEngine
37
+ import torch
38
+
39
+ model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-1.5B", torch_dtype=torch.float16, device_map="cuda")
40
+ tok = AutoTokenizer.from_pretrained("Qwen/Qwen2-1.5B")
41
+
42
+ engine = TurboQuantEngine(head_dim=128, total_bits=3, device="cuda")
43
+ result = engine.generate(model, tok, "The key to efficient LLM inference is")
44
+
45
+ print(result["text"])
46
+ print(f"{result['tokens']} tokens | {result['stats']['ratio']:.1f}x compression")
47
+ ```
48
+
49
+ ## how it works
50
+
51
+ Implements the [TurboQuant](https://arxiv.org/abs/2501.09747) algorithm:
52
+
53
+ 1. **normalize + rotate** — random orthogonal rotation (Pi) makes coordinates near-Gaussian
54
+ 2. **Lloyd-Max quantize** — optimal scalar quantization against N(0, 1/d)
55
+ 3. **QJL bias correction** — 1-bit sign sketch of the residual for unbiased key scores
56
+
57
+ At 3-bit (2-bit MSE + 1-bit QJL) this gives ~5x compression with negligible quality loss.
58
+
59
+ ## step-by-step api
60
+
61
+ ```python
62
+ engine = TurboQuantEngine(head_dim=128, total_bits=3, device="cuda")
63
+
64
+ # after model prefill:
65
+ compressed = engine.compress_kv_cache(out.past_key_values)
66
+ cache = engine.build_cache(compressed)
67
+ stats = engine.compression_stats(out.past_key_values)
68
+
69
+ # or just do it all in one call:
70
+ result = engine.generate(model, tokenizer, "your prompt here")
71
+ ```
72
+
73
+ ## gpu support
74
+
75
+ Written in [cuTile](https://docs.nvidia.com/cuda/cutile-python/) for cross-architecture portability.
76
+ Falls back to PyTorch if cuTile or a compatible driver isn't available.
77
+
78
+ | GPU family | Architecture | Status |
79
+ |------------|-------------|--------|
80
+ | A100 | Ampere | supported (PyTorch fallback) |
81
+ | H100 | Hopper | supported |
82
+ | B200/B300 | Blackwell | supported + swizzle fast path |
83
+
84
+ ## kernels
85
+
86
+ | kernel | what it does |
87
+ |--------|-------------|
88
+ | `compress_keys` | normalize → rotate(Pi^T) → Lloyd-Max quantize → QJL signs |
89
+ | `compress_values` | normalize → rotate(Pi^T) → Lloyd-Max quantize |
90
+ | `decompress_values` | dequantize → un-rotate(Pi) → scale by norms |
91
+ | `attention_scores` | asymmetric dot product with QJL correction |
92
+ | `fused_attention` | scores + online softmax + V accumulation |
93
+
94
+ ## license
95
+
96
+ MIT
@@ -0,0 +1,73 @@
1
+ # turboquant-gpu
2
+
3
+ **5.02x KV cache compression for LLM inference** — GPU-accelerated cuTile kernels with PyTorch fallback.
4
+
5
+ ```
6
+ pip install turboquant-gpu
7
+ ```
8
+
9
+ ## quick start
10
+
11
+ ```python
12
+ from transformers import AutoModelForCausalLM, AutoTokenizer
13
+ from turboquant_gpu import TurboQuantEngine
14
+ import torch
15
+
16
+ model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-1.5B", torch_dtype=torch.float16, device_map="cuda")
17
+ tok = AutoTokenizer.from_pretrained("Qwen/Qwen2-1.5B")
18
+
19
+ engine = TurboQuantEngine(head_dim=128, total_bits=3, device="cuda")
20
+ result = engine.generate(model, tok, "The key to efficient LLM inference is")
21
+
22
+ print(result["text"])
23
+ print(f"{result['tokens']} tokens | {result['stats']['ratio']:.1f}x compression")
24
+ ```
25
+
26
+ ## how it works
27
+
28
+ Implements the [TurboQuant](https://arxiv.org/abs/2501.09747) algorithm:
29
+
30
+ 1. **normalize + rotate** — random orthogonal rotation (Pi) makes coordinates near-Gaussian
31
+ 2. **Lloyd-Max quantize** — optimal scalar quantization against N(0, 1/d)
32
+ 3. **QJL bias correction** — 1-bit sign sketch of the residual for unbiased key scores
33
+
34
+ At 3-bit (2-bit MSE + 1-bit QJL) this gives ~5x compression with negligible quality loss.
35
+
36
+ ## step-by-step api
37
+
38
+ ```python
39
+ engine = TurboQuantEngine(head_dim=128, total_bits=3, device="cuda")
40
+
41
+ # after model prefill:
42
+ compressed = engine.compress_kv_cache(out.past_key_values)
43
+ cache = engine.build_cache(compressed)
44
+ stats = engine.compression_stats(out.past_key_values)
45
+
46
+ # or just do it all in one call:
47
+ result = engine.generate(model, tokenizer, "your prompt here")
48
+ ```
49
+
50
+ ## gpu support
51
+
52
+ Written in [cuTile](https://docs.nvidia.com/cuda/cutile-python/) for cross-architecture portability.
53
+ Falls back to PyTorch if cuTile or a compatible driver isn't available.
54
+
55
+ | GPU family | Architecture | Status |
56
+ |------------|-------------|--------|
57
+ | A100 | Ampere | supported (PyTorch fallback) |
58
+ | H100 | Hopper | supported |
59
+ | B200/B300 | Blackwell | supported + swizzle fast path |
60
+
61
+ ## kernels
62
+
63
+ | kernel | what it does |
64
+ |--------|-------------|
65
+ | `compress_keys` | normalize → rotate(Pi^T) → Lloyd-Max quantize → QJL signs |
66
+ | `compress_values` | normalize → rotate(Pi^T) → Lloyd-Max quantize |
67
+ | `decompress_values` | dequantize → un-rotate(Pi) → scale by norms |
68
+ | `attention_scores` | asymmetric dot product with QJL correction |
69
+ | `fused_attention` | scores + online softmax + V accumulation |
70
+
71
+ ## license
72
+
73
+ MIT
@@ -0,0 +1,38 @@
1
+ [build-system]
2
+ requires = ["setuptools>=68.0", "wheel"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [project]
6
+ name = "turboquant-gpu"
7
+ version = "0.1.0"
8
+ description = "TurboQuant KV cache compression for LLM inference — cuTile GPU kernels"
9
+ readme = "README.md"
10
+ license = "MIT"
11
+ requires-python = ">=3.10"
12
+ authors = [
13
+ { name = "Anirudh Bharadwaj Vangara" },
14
+ ]
15
+ keywords = [
16
+ "quantization", "kv-cache", "llm", "inference",
17
+ "cutile", "cuda", "gpu", "attention",
18
+ "blackwell", "hopper", "h100", "b200",
19
+ ]
20
+ classifiers = [
21
+ "Development Status :: 3 - Alpha",
22
+ "Intended Audience :: Science/Research",
23
+ "Topic :: Scientific/Engineering :: Artificial Intelligence",
24
+ "Programming Language :: Python :: 3",
25
+ "Programming Language :: Python :: 3.10",
26
+ "Programming Language :: Python :: 3.11",
27
+ "Programming Language :: Python :: 3.12",
28
+ ]
29
+ dependencies = [
30
+ "torch",
31
+ "scipy",
32
+ ]
33
+
34
+ [project.optional-dependencies]
35
+ gpu = ["cuda-tile"]
36
+
37
+ [tool.setuptools.packages.find]
38
+ include = ["turboquant_gpu*"]
@@ -0,0 +1,4 @@
1
+ [egg_info]
2
+ tag_build =
3
+ tag_date = 0
4
+
@@ -0,0 +1,153 @@
1
+ """
2
+ Test attention kernel correctness against PyTorch reference and ground truth.
3
+
4
+ Key properties verified:
5
+ - Scores match PyTorch reference within FP16 tolerance
6
+ - Inner product estimator is unbiased (mean error ≈ 0)
7
+ - Correct 1/√d scaling
8
+ - Sweep over multiple bit widths and sequence lengths
9
+ """
10
+
11
+ import math
12
+ import sys
13
+ import os
14
+ import torch
15
+ import pytest
16
+
17
+ sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
18
+ from turboquant_gpu import TurboQuantEngine
19
+
20
+
21
@pytest.mark.parametrize("seq_q,seq_k", [(1, 64), (1, 256), (1, 1024), (16, 512)])
@pytest.mark.parametrize("total_bits", [3, 4])
def test_scores_shape(seq_q, seq_k, total_bits):
    """The score matrix must come back as (seq_q, seq_k) for every config."""
    engine = TurboQuantEngine(head_dim=128, total_bits=total_bits, device="cpu")

    queries = torch.randn(seq_q, 128).half()
    keys = torch.randn(seq_k, 128).half()
    packed = engine.compress_keys_pytorch(keys)
    scores = engine.attention_scores_pytorch(queries, packed)

    assert scores.shape == (seq_q, seq_k)
32
+
33
+
34
@pytest.mark.parametrize("total_bits", [3, 4])
def test_unbiasedness(total_bits):
    """The asymmetric estimator should be unbiased: E[estimated_ip] ≈ true_ip.

    Averaged over many random unit-vector pairs, the mean signed error must
    be close to zero.
    """
    d, n = 128, 2000
    engine = TurboQuantEngine(head_dim=d, total_bits=total_bits, seed=42, device="cpu")

    keys = torch.randn(n, d)
    keys = keys / keys.norm(dim=-1, keepdim=True)
    queries = torch.randn(n, d)
    queries = queries / queries.norm(dim=-1, keepdim=True)

    exact_ip = (queries * keys).sum(dim=-1)

    packed = engine.compress_keys_pytorch(keys.half())
    raw_scores = engine.attention_scores_pytorch(queries.half(), packed)

    # The engine returns scaled scores; undo the scale factor and keep only
    # the matched (q_i, k_i) pairs on the diagonal.
    per_pair = torch.diag(raw_scores) / engine.scale

    bias = (per_pair - exact_ip).mean().abs().item()
    assert bias < 0.05, f"Bias too large: {bias:.4f} (should be ~0)"
58
+
59
+
60
@pytest.mark.parametrize("total_bits", [2, 3, 4])
def test_correlation_with_true_scores(total_bits):
    """Estimated scores should correlate well with true Q·K^T scores."""
    d = 128
    seq_q, seq_k = 1, 512
    engine = TurboQuantEngine(head_dim=d, total_bits=total_bits, seed=42, device="cpu")

    queries = torch.randn(seq_q, d).half()
    keys = torch.randn(seq_k, d).half()

    reference = (queries.float() @ keys.float().T) * engine.scale
    estimate = engine.attention_scores_pytorch(
        queries, engine.compress_keys_pytorch(keys)
    )

    stacked = torch.stack([reference.flatten(), estimate.flatten()])
    corr = torch.corrcoef(stacked)[0, 1].item()

    # Looser floor at lower bit widths: quantization noise dominates at 2-bit.
    min_corr = {2: 0.45, 3: 0.75, 4: 0.90}[total_bits]
    assert corr > min_corr, (
        f"bits={total_bits}: correlation={corr:.4f} < {min_corr}"
    )
82
+
83
+
84
def test_scaling_correct():
    """Scores should be scaled by 1/√d."""
    d = 128
    engine = TurboQuantEngine(head_dim=d, total_bits=3, device="cpu")

    queries = torch.randn(1, d).half()
    keys = torch.randn(64, d).half()

    packed = engine.compress_keys_pytorch(keys)
    scaled = engine.attention_scores_pytorch(queries, packed)

    # Unscaled inner products against the dequantized keys are the baseline:
    # applying the scale factor must shrink the average magnitude.
    unscaled = queries.float() @ packed["k_mse"].float().T
    assert scaled.abs().mean().item() < unscaled.abs().mean().item(), (
        "Scaled scores should be smaller than raw inner products"
    )
99
+
100
+
101
def test_single_decode_token():
    """Typical decode scenario: seq_q=1, seq_k=large."""
    engine = TurboQuantEngine(head_dim=128, total_bits=3, device="cpu")
    query = torch.randn(1, 128).half()
    keys = torch.randn(2048, 128).half()

    scores = engine.attention_scores_pytorch(
        query, engine.compress_keys_pytorch(keys)
    )

    assert scores.shape == (1, 2048)
    # isfinite rejects both NaN and +/-Inf in one pass.
    assert torch.isfinite(scores).all()
113
+
114
+
115
def test_needle_in_haystack():
    """Can we still find the most-attended key after compression?"""
    d = 128
    seq_k = 1024
    engine = TurboQuantEngine(head_dim=d, total_bits=3, seed=42, device="cpu")

    keys = torch.randn(seq_k, d)
    keys = keys / keys.norm(dim=-1, keepdim=True)

    # The query is an exact copy of one key, so the true argmax is known.
    needle_pos = seq_k // 3
    query = keys[needle_pos].unsqueeze(0)

    true_top1 = (query @ keys.T).argmax(dim=-1).item()

    packed = engine.compress_keys_pytorch(keys.half())
    estimated_top1 = (
        engine.attention_scores_pytorch(query.half(), packed).argmax(dim=-1).item()
    )

    # Exact match, or within a small window of the planted position.
    assert estimated_top1 == true_top1 or abs(estimated_top1 - needle_pos) < 5, (
        f"Needle at {needle_pos}, true top1={true_top1}, estimated top1={estimated_top1}"
    )
137
+
138
+
139
@pytest.mark.parametrize("total_bits", [3, 4])
def test_bits_sweep_scores_reasonable(total_bits):
    """Score MSE against the exact scaled Q·K^T stays below a fixed ceiling.

    NOTE(review): the original docstring claimed "higher bits should produce
    lower MSE", but each parametrized run sees only a single bit width, so no
    cross-bit comparison ever happens; this test actually pins an absolute
    quality floor per bit width, which is what the docstring now says.
    """
    d = 128
    engine = TurboQuantEngine(head_dim=d, total_bits=total_bits, device="cpu")

    Q = torch.randn(4, d).half()
    K = torch.randn(256, d).half()

    true_scores = (Q.float() @ K.float().T) * engine.scale
    estimated_scores = engine.attention_scores_pytorch(
        Q, engine.compress_keys_pytorch(K)
    )

    score_mse = ((true_scores - estimated_scores) ** 2).mean().item()
    assert score_mse < 0.5, f"Score MSE too high: {score_mse:.6f}"
@@ -0,0 +1,92 @@
1
+ """
2
+ Test Lloyd-Max codebook correctness.
3
+
4
+ Verifies:
5
+ - Centroid symmetry around zero
6
+ - Correct number of levels per bit-width
7
+ - Centroids are monotonically sorted
8
+ - Distortion is within theoretical upper bound
9
+ - Quantize → dequantize round-trip consistency
10
+ """
11
+
12
+ import math
13
+ import sys
14
+ import os
15
+ import torch
16
+ import pytest
17
+
18
+ sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
19
+ from turboquant_gpu.codebook import LloydMaxCodebook, solve_lloyd_max
20
+
21
+
22
@pytest.mark.parametrize("d", [64, 128, 256])
@pytest.mark.parametrize("bits", [1, 2, 3, 4])
def test_correct_num_levels(d, bits):
    """A b-bit codebook has 2^b centroids and 2^b - 1 decision boundaries."""
    cb = LloydMaxCodebook(d, bits)
    levels = 2**bits
    assert cb.n_levels == levels
    assert cb.centroids.shape == (levels,)
    assert cb.boundaries.shape == (levels - 1,)
29
+
30
+
31
@pytest.mark.parametrize("d", [64, 128, 256])
@pytest.mark.parametrize("bits", [2, 3, 4])
def test_symmetry(d, bits):
    """Centroids for a symmetric source density should sum to ~zero."""
    cb = LloydMaxCodebook(d, bits)
    total = cb.centroids.sum()
    assert total.abs().item() < 1e-4, (
        f"Centroids should be symmetric: sum={cb.centroids.sum().item()}"
    )
38
+
39
+
40
@pytest.mark.parametrize("bits", [1, 2, 3, 4])
def test_sorted(bits):
    """Centroids and boundaries must each be strictly increasing."""
    cb = LloydMaxCodebook(128, bits)
    # Vectorized strict-monotonicity check instead of the index loops.
    # torch.diff on a length-<=1 tensor is empty, and .all() on an empty
    # tensor is True — matching the original loop's vacuous pass for the
    # single boundary at bits=1.
    assert (torch.diff(cb.centroids) > 0).all()
    assert (torch.diff(cb.boundaries) > 0).all()
47
+
48
+
49
@pytest.mark.parametrize("bits", [1, 2, 3, 4])
def test_boundaries_between_centroids(bits):
    """Each decision boundary lies strictly between its neighboring centroids."""
    cb = LloydMaxCodebook(128, bits)
    for i, boundary in enumerate(cb.boundaries):
        assert cb.centroids[i] < boundary < cb.centroids[i + 1]
54
+
55
+
56
@pytest.mark.parametrize("bits", [1, 2, 3, 4])
def test_distortion_within_paper_bound(bits):
    """MSE distortion per vector <= sqrt(3) * pi/2 * (1/4^b) for unit vectors.

    Fix: removed the unused local `sigma = 1/sqrt(d)` — the bound below is
    already expressed in unit-vector terms and never referenced it.
    """
    d = 128
    cb = LloydMaxCodebook(d, bits)

    n_samples = 5000
    x = torch.randn(n_samples, d)
    x = x / torch.norm(x, dim=-1, keepdim=True)

    rotated = x  # with identity rotation, still approximately N(0, 1/d)
    indices = cb.quantize(rotated)
    reconstructed = cb.dequantize(indices)
    mse = ((rotated - reconstructed) ** 2).sum(dim=-1).mean().item()

    # 2x slack over the asymptotic bound absorbs finite-sample noise.
    upper_bound = math.sqrt(3) * math.pi / 2 * (1.0 / (4**bits))
    assert mse < upper_bound * 2.0, (
        f"bits={bits}: MSE={mse:.6f} exceeds 2× paper bound {upper_bound:.6f}"
    )
76
+
77
+
78
def test_roundtrip_identity():
    """Quantize → dequantize should map each centroid exactly to itself."""
    cb = LloydMaxCodebook(128, 3)
    all_levels = torch.arange(cb.n_levels)
    assert torch.allclose(cb.dequantize(all_levels), cb.centroids, atol=1e-6)
84
+
85
+
86
def test_quantize_nearest_centroid():
    """Values close to a centroid should map to that centroid's index."""
    cb = LloydMaxCodebook(128, 2)
    for i, centroid in enumerate(cb.centroids):
        # Nudge slightly off the centroid; the nearest level must not change.
        nudged = (centroid + 1e-6).unsqueeze(0)
        idx = cb.quantize(nudged)
        assert idx.item() == i, f"Expected index {i}, got {idx.item()}"
@@ -0,0 +1,133 @@
1
+ """
2
+ Test compression kernel correctness against PyTorch reference.
3
+
4
+ Compares cuTile kernel output (via TurboQuantEngine) against the pure PyTorch
5
+ reference path for:
6
+ - MSE indices match
7
+ - QJL signs match
8
+ - Norms are within FP16 tolerance
9
+ - Various sequence lengths and bit widths
10
+ """
11
+
12
+ import sys
13
+ import os
14
+ import torch
15
+ import pytest
16
+
17
+ sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
18
+ from turboquant_gpu import TurboQuantEngine
19
+
20
+
21
+ def _make_random_keys(seq_k: int, head_dim: int = 128) -> torch.Tensor:
22
+ K = torch.randn(seq_k, head_dim)
23
+ return K.half()
24
+
25
+
26
@pytest.mark.parametrize("seq_k", [64, 128, 256, 512])
@pytest.mark.parametrize("total_bits", [3, 4])
def test_compress_keys_shapes(seq_k, total_bits):
    """Every field of the compressed-keys dict has the expected shape."""
    engine = TurboQuantEngine(head_dim=128, total_bits=total_bits, device="cpu")
    compressed = engine.compress_keys_pytorch(_make_random_keys(seq_k))

    expected_shapes = {
        "indices": (seq_k, 128),
        "k_mse": (seq_k, 128),
        "qjl_signs": (seq_k, 128),
        "vec_norms": (seq_k,),
        "residual_norms": (seq_k,),
    }
    for field, shape in expected_shapes.items():
        assert compressed[field].shape == shape
38
+
39
+
40
@pytest.mark.parametrize("total_bits", [3, 4])
def test_indices_in_valid_range(total_bits):
    """Key indices stay within [0, 2^mse_bits - 1] (one bit goes to QJL)."""
    engine = TurboQuantEngine(head_dim=128, total_bits=total_bits, device="cpu")
    compressed = engine.compress_keys_pytorch(_make_random_keys(256))

    mse_bits = max(total_bits - 1, 1)
    hi = (1 << mse_bits) - 1
    idx = compressed["indices"]
    assert idx.max().item() <= hi
    assert idx.min().item() >= 0
50
+
51
+
52
def test_signs_are_pm1():
    """The QJL sign sketch must contain only the values -1 and +1."""
    engine = TurboQuantEngine(head_dim=128, total_bits=3, device="cpu")
    compressed = engine.compress_keys_pytorch(_make_random_keys(256))

    unique_vals = set(compressed["qjl_signs"].unique().tolist())
    assert unique_vals <= {-1, 1}, f"Signs should be +/-1, got {unique_vals}"
59
+
60
+
61
def test_norms_positive():
    """Stored vector and residual norms are never negative."""
    engine = TurboQuantEngine(head_dim=128, total_bits=3, device="cpu")
    compressed = engine.compress_keys_pytorch(_make_random_keys(256))

    for field in ("vec_norms", "residual_norms"):
        assert (compressed[field] >= 0).all()
68
+
69
+
70
def test_compress_values_shapes():
    """Compressed values carry per-token indices and norms of the right shape."""
    engine = TurboQuantEngine(head_dim=128, total_bits=3, device="cpu")
    compressed = engine.compress_values_pytorch(_make_random_keys(256))

    assert compressed["indices"].shape == (256, 128)
    assert compressed["vec_norms"].shape == (256,)
77
+
78
+
79
def test_value_indices_in_valid_range():
    """Value indices span the full bit budget: values reserve no QJL bit."""
    engine = TurboQuantEngine(head_dim=128, total_bits=3, device="cpu")
    compressed = engine.compress_values_pytorch(_make_random_keys(256))

    hi = (1 << 3) - 1  # values use all bits (no QJL)
    assert compressed["indices"].max().item() <= hi
86
+
87
+
88
def test_compress_matches_reference():
    """Compare our engine against the original cutiledump compressors.py implementation."""
    engine = TurboQuantEngine(head_dim=128, total_bits=3, seed=42, device="cpu")

    K = torch.randn(64, 128).half()
    compressed = engine.compress_keys_pytorch(K)

    # Re-derive the expected indices step by step: normalize, rotate by Pi^T,
    # then a nearest-centroid search against the key codebook.
    K_f = K.float()
    unit = K_f / (torch.norm(K_f, dim=-1, keepdim=True) + 1e-8)
    rotated = unit @ engine.PiT.float()
    dist = (rotated.unsqueeze(-1) - engine.key_codebook.centroids).abs()
    expected_indices = dist.argmin(dim=-1).to(torch.uint8)

    assert torch.equal(compressed["indices"], expected_indices)
104
+
105
+
106
@pytest.mark.parametrize("total_bits", [3, 4])
def test_k_mse_reconstruction_quality(total_bits):
    """k_mse should be a reasonable approximation of K."""
    engine = TurboQuantEngine(head_dim=128, total_bits=total_bits, device="cpu")
    K = torch.randn(512, 128).half()
    compressed = engine.compress_keys_pytorch(K)

    # Directional agreement per row; norms are stored separately, so cosine
    # similarity is the right reconstruction metric here.
    cos_sim = torch.nn.functional.cosine_similarity(
        K.float(), compressed["k_mse"].float(), dim=-1
    )
    mean_sim = cos_sim.mean().item()
    assert mean_sim > 0.85, (
        f"Mean cosine similarity too low: {cos_sim.mean().item():.4f}"
    )
119
+
120
+
121
def test_edge_case_zero_vector():
    """Compressing an all-zero key must stay finite (no NaN/Inf from the norm)."""
    engine = TurboQuantEngine(head_dim=128, total_bits=3, device="cpu")
    compressed = engine.compress_keys_pytorch(torch.zeros(1, 128).half())
    recon = compressed["k_mse"]
    assert not torch.isnan(recon).any()
    assert not torch.isinf(recon).any()
127
+
128
+
129
def test_edge_case_large_values():
    """Large-magnitude inputs must compress to finite values.

    Fix: also assert no Inf, mirroring test_edge_case_zero_vector — FP16
    intermediates could overflow to Inf without ever tripping a NaN check.
    """
    engine = TurboQuantEngine(head_dim=128, total_bits=3, device="cpu")
    K = torch.randn(4, 128).half() * 100
    compressed = engine.compress_keys_pytorch(K)
    assert not torch.isnan(compressed["k_mse"]).any()
    assert not torch.isinf(compressed["k_mse"]).any()