tiny-turboquant 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (30) hide show
  1. tiny_turboquant-0.1.0/LICENSE +21 -0
  2. tiny_turboquant-0.1.0/MANIFEST.in +6 -0
  3. tiny_turboquant-0.1.0/PKG-INFO +109 -0
  4. tiny_turboquant-0.1.0/README.md +64 -0
  5. tiny_turboquant-0.1.0/benchmarks/__init__.py +0 -0
  6. tiny_turboquant-0.1.0/benchmarks/bench_ann.py +188 -0
  7. tiny_turboquant-0.1.0/benchmarks/bench_kv_real.py +192 -0
  8. tiny_turboquant-0.1.0/demos/__init__.py +0 -0
  9. tiny_turboquant-0.1.0/demos/demo1_distortion_vs_theory.py +75 -0
  10. tiny_turboquant-0.1.0/demos/demo2_ann_vs_pq.py +133 -0
  11. tiny_turboquant-0.1.0/demos/demo3_real_embeddings.py +151 -0
  12. tiny_turboquant-0.1.0/demos/demo4_kv_cache.py +308 -0
  13. tiny_turboquant-0.1.0/pyproject.toml +81 -0
  14. tiny_turboquant-0.1.0/setup.cfg +4 -0
  15. tiny_turboquant-0.1.0/setup.py +3 -0
  16. tiny_turboquant-0.1.0/tests/test_tiny_turboquant.py +114 -0
  17. tiny_turboquant-0.1.0/tiny_turboquant/__init__.py +28 -0
  18. tiny_turboquant-0.1.0/tiny_turboquant/bitpack.py +98 -0
  19. tiny_turboquant-0.1.0/tiny_turboquant/codebooks.py +89 -0
  20. tiny_turboquant-0.1.0/tiny_turboquant/fwht.py +59 -0
  21. tiny_turboquant-0.1.0/tiny_turboquant/kv_cache.py +337 -0
  22. tiny_turboquant-0.1.0/tiny_turboquant/numpy_reference.py +271 -0
  23. tiny_turboquant-0.1.0/tiny_turboquant/outlier_split.py +95 -0
  24. tiny_turboquant-0.1.0/tiny_turboquant/quantizer.py +119 -0
  25. tiny_turboquant-0.1.0/tiny_turboquant/rotation.py +80 -0
  26. tiny_turboquant-0.1.0/tiny_turboquant.egg-info/PKG-INFO +109 -0
  27. tiny_turboquant-0.1.0/tiny_turboquant.egg-info/SOURCES.txt +28 -0
  28. tiny_turboquant-0.1.0/tiny_turboquant.egg-info/dependency_links.txt +1 -0
  29. tiny_turboquant-0.1.0/tiny_turboquant.egg-info/requires.txt +24 -0
  30. tiny_turboquant-0.1.0/tiny_turboquant.egg-info/top_level.txt +3 -0
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2026 Pradeep Boopathy
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
@@ -0,0 +1,6 @@
1
+ include README.md
2
+ include LICENSE
3
+ recursive-include tiny_turboquant *.py
4
+ recursive-include tests *.py
5
+ recursive-include demos *.py
6
+ recursive-include benchmarks *.py
@@ -0,0 +1,109 @@
1
+ Metadata-Version: 2.4
2
+ Name: tiny-turboquant
3
+ Version: 0.1.0
4
+ Summary: Low-bit vector and KV-cache compression research toolkit for PyTorch
5
+ Author: Pradeep Boopathy
6
+ License-Expression: MIT
7
+ Project-URL: Homepage, https://github.com/pradeepboopathy/tiny-turboquant
8
+ Project-URL: Repository, https://github.com/pradeepboopathy/tiny-turboquant
9
+ Project-URL: Issues, https://github.com/pradeepboopathy/tiny-turboquant/issues
10
+ Keywords: quantization,kv-cache,llm,compression,vector-search,pytorch,rag,transformers
11
+ Classifier: Development Status :: 3 - Alpha
12
+ Classifier: Intended Audience :: Developers
13
+ Classifier: Intended Audience :: Science/Research
14
+ Classifier: Programming Language :: Python :: 3
15
+ Classifier: Programming Language :: Python :: 3.9
16
+ Classifier: Programming Language :: Python :: 3.10
17
+ Classifier: Programming Language :: Python :: 3.11
18
+ Classifier: Programming Language :: Python :: 3.12
19
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
20
+ Classifier: Topic :: Software Development :: Libraries :: Python Modules
21
+ Requires-Python: >=3.9
22
+ Description-Content-Type: text/markdown
23
+ License-File: LICENSE
24
+ Requires-Dist: torch>=2.1
25
+ Requires-Dist: numpy>=1.24
26
+ Provides-Extra: dev
27
+ Requires-Dist: pytest>=8; extra == "dev"
28
+ Requires-Dist: build>=1.2; extra == "dev"
29
+ Requires-Dist: twine>=5.0; extra == "dev"
30
+ Requires-Dist: ruff>=0.5; extra == "dev"
31
+ Provides-Extra: demos
32
+ Requires-Dist: matplotlib>=3.7; extra == "demos"
33
+ Requires-Dist: faiss-cpu>=1.7.4; extra == "demos"
34
+ Requires-Dist: sentence-transformers>=2.6; extra == "demos"
35
+ Provides-Extra: llm
36
+ Requires-Dist: transformers>=4.40; extra == "llm"
37
+ Requires-Dist: accelerate>=0.30; extra == "llm"
38
+ Provides-Extra: all
39
+ Requires-Dist: matplotlib>=3.7; extra == "all"
40
+ Requires-Dist: faiss-cpu>=1.7.4; extra == "all"
41
+ Requires-Dist: sentence-transformers>=2.6; extra == "all"
42
+ Requires-Dist: transformers>=4.40; extra == "all"
43
+ Requires-Dist: accelerate>=0.30; extra == "all"
44
+ Dynamic: license-file
45
+
46
+ # Tiny TurboQuant
47
+
48
+ Tiny TurboQuant is a lightweight PyTorch research toolkit for low-bit vector compression and KV-cache compression experiments.
49
+
50
+ It includes:
51
+
52
+ - real `uint8` bit-packing for low-bit indices
53
+ - MSE-style scalar quantization
54
+ - product / inner-product-oriented quantization experiments
55
+ - outlier-split quantization
56
+ - a Hugging Face-compatible KV-cache prototype
57
+ - demos for ANN search, embedding compression, and KV-cache memory measurement
58
+
59
+ ## Important limitation
60
+
61
+ This package demonstrates packed memory compression. It is **not** a production compressed-attention engine. The KV-cache wrapper still dequantizes tensors before attention. Real latency gains require fused CUDA/Triton kernels or integration with an inference engine such as vLLM or TensorRT-LLM.
62
+
63
+ ## Install from local wheel
64
+
65
+ ```bash
66
+ pip install tiny_turboquant-0.1.0-py3-none-any.whl
67
+ ```
68
+
69
+ ## Basic usage
70
+
71
+ ```python
72
+ import torch
73
+ from tiny_turboquant import TurboQuantMSE, TurboQuantKVCache
74
+
75
+ x = torch.randn(128, 64)
76
+ x = x / x.norm(dim=-1, keepdim=True)
77
+
78
+ q = TurboQuantMSE.build(d=64, bits=4)
79
+ idx = q.quant(x)
80
+ x_hat = q.dequant(idx)
81
+
82
+ cache = TurboQuantKVCache(bits=4)
83
+ ```
84
+
85
+ ## Run tests
86
+
87
+ ```bash
88
+ python -m pytest
89
+ ```
90
+
91
+ ## Run demos from source checkout
92
+
93
+ ```bash
94
+ python -m demos.demo1_distortion_vs_theory
95
+ python -m demos.demo2_ann_vs_pq
96
+ python -m demos.demo3_real_embeddings
97
+ python -m demos.demo4_kv_cache
98
+ ```
99
+
100
+ ## Scope
101
+
102
+ Use this package for:
103
+
104
+ - compressed vector-search experiments
105
+ - RAG embedding compression experiments
106
+ - KV-cache memory/quality tradeoff experiments
107
+ - educational or research benchmarking
108
+
109
+ Do not claim it provides production-speed LLM inference. It reduces packed storage; speed requires optimized compressed-attention kernels.
@@ -0,0 +1,64 @@
1
+ # Tiny TurboQuant
2
+
3
+ Tiny TurboQuant is a lightweight PyTorch research toolkit for low-bit vector compression and KV-cache compression experiments.
4
+
5
+ It includes:
6
+
7
+ - real `uint8` bit-packing for low-bit indices
8
+ - MSE-style scalar quantization
9
+ - product / inner-product-oriented quantization experiments
10
+ - outlier-split quantization
11
+ - a Hugging Face-compatible KV-cache prototype
12
+ - demos for ANN search, embedding compression, and KV-cache memory measurement
13
+
14
+ ## Important limitation
15
+
16
+ This package demonstrates packed memory compression. It is **not** a production compressed-attention engine. The KV-cache wrapper still dequantizes tensors before attention. Real latency gains require fused CUDA/Triton kernels or integration with an inference engine such as vLLM or TensorRT-LLM.
17
+
18
+ ## Install from local wheel
19
+
20
+ ```bash
21
+ pip install tiny_turboquant-0.1.0-py3-none-any.whl
22
+ ```
23
+
24
+ ## Basic usage
25
+
26
+ ```python
27
+ import torch
28
+ from tiny_turboquant import TurboQuantMSE, TurboQuantKVCache
29
+
30
+ x = torch.randn(128, 64)
31
+ x = x / x.norm(dim=-1, keepdim=True)
32
+
33
+ q = TurboQuantMSE.build(d=64, bits=4)
34
+ idx = q.quant(x)
35
+ x_hat = q.dequant(idx)
36
+
37
+ cache = TurboQuantKVCache(bits=4)
38
+ ```
39
+
40
+ ## Run tests
41
+
42
+ ```bash
43
+ python -m pytest
44
+ ```
45
+
46
+ ## Run demos from source checkout
47
+
48
+ ```bash
49
+ python -m demos.demo1_distortion_vs_theory
50
+ python -m demos.demo2_ann_vs_pq
51
+ python -m demos.demo3_real_embeddings
52
+ python -m demos.demo4_kv_cache
53
+ ```
54
+
55
+ ## Scope
56
+
57
+ Use this package for:
58
+
59
+ - compressed vector-search experiments
60
+ - RAG embedding compression experiments
61
+ - KV-cache memory/quality tradeoff experiments
62
+ - educational or research benchmarking
63
+
64
+ Do not claim it provides production-speed LLM inference. It reduces packed storage; speed requires optimized compressed-attention kernels.
File without changes
@@ -0,0 +1,188 @@
1
+ """
2
+ ANN A/B benchmark: TurboQuant vs faiss PQ vs faiss IVFPQ vs RaBitQ (optional).
3
+
4
+ Reports recall@k, indexing time, and queries/second for each method on the
5
+ same dataset. The dataset can be loaded from a .npy file (most realistic)
6
+ or generated synthetically.
7
+
8
+ Usage:
9
+ python -m benchmarks.bench_ann --data path/to/embeddings.npy --bits 4 --k 10
10
+ python -m benchmarks.bench_ann --synthetic 50000 --d 768 --bits 4
11
+ """
12
+
13
+ from __future__ import annotations
14
+
15
+ import argparse
16
+ import time
17
+ from pathlib import Path
18
+
19
+ import numpy as np
20
+ import torch
21
+
22
+ from tiny_turboquant import TurboQuantProd
23
+
24
+
25
def parse_args():
    """Build and parse the CLI options for the ANN A/B benchmark."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--data", type=str, default=None,
        help="Path to .npy of shape (N, D), float32.",
    )
    # Integer knobs: synthetic-dataset size/shape and search parameters.
    for flag, default in (
        ("--synthetic", 50_000),
        ("--d", 768),
        ("--bits", 4),
        ("--k", 10),
        ("--n_queries", 500),
    ):
        parser.add_argument(flag, type=int, default=default)
    # Prefer the GPU when one is visible; callers can still override.
    default_device = "cuda" if torch.cuda.is_available() else "cpu"
    parser.add_argument("--device", default=default_device)
    return parser.parse_args()
36
+
37
+
38
def load_data(args) -> tuple[np.ndarray, np.ndarray]:
    """Return ``(database, queries)`` as L2-normalised float32 arrays.

    Loads ``args.data`` when a path is given; otherwise synthesises a
    clustered dataset that mimics a real embedding store. Queries are
    drawn (without replacement) from the database rows themselves.
    """
    if args.data is None:
        gen = np.random.default_rng(0)
        # Cluster structure mimicking real embedding stores.
        k_clusters = max(50, args.synthetic // 500)
        prototypes = gen.standard_normal((k_clusters, args.d)).astype(np.float32)
        prototypes /= np.linalg.norm(prototypes, axis=1, keepdims=True)
        assignment = gen.integers(0, k_clusters, size=args.synthetic)
        noise = 0.25 * gen.standard_normal((args.synthetic, args.d)).astype(np.float32)
        X = prototypes[assignment] + noise
    else:
        X = np.load(args.data).astype(np.float32)

    # Unit-normalise rows; the epsilon guards against an all-zero row.
    X = X / (np.linalg.norm(X, axis=1, keepdims=True) + 1e-12)
    query_gen = np.random.default_rng(1)
    picks = query_gen.choice(X.shape[0], args.n_queries, replace=False)
    return X, X[picks].copy()
54
+
55
+
56
def recall_at_k(true_top: np.ndarray, est_top: np.ndarray) -> float:
    """Fraction of true top-k neighbours recovered, averaged over queries.

    Both arrays are (n_queries, k) index matrices; row order within a
    query does not matter — only set overlap counts.
    """
    total_hits = 0
    for truth_row, est_row in zip(true_top, est_top):
        total_hits += len(set(truth_row.tolist()) & set(est_row.tolist()))
    return total_hits / true_top.size
60
+
61
+
62
+ # ---- methods ----------------------------------------------------------
63
+
64
def turboquant_run(X, Q, k, bits, device):
    """Index and query with TurboQuantProd.

    Returns (top-k index matrix, indexing seconds, query seconds).
    The query phase dequantizes the whole base set and does a dense
    inner-product top-k — it measures quality, not an optimized ANN path.
    """
    base = torch.from_numpy(X).to(device)
    queries = torch.from_numpy(Q).to(device)
    dim = X.shape[1]
    on_cuda = device.startswith("cuda")

    started = time.perf_counter()
    quantizer = TurboQuantProd.build(dim, bits, device=device, seed=0,
                                     dtype=torch.float32)
    codes, signs, gamma = quantizer.quant(base)
    if on_cuda:
        torch.cuda.synchronize()  # make the wall-clock honest on CUDA
    t_index = time.perf_counter() - started

    started = time.perf_counter()
    recon = quantizer.dequant(codes, signs, gamma)
    scores = queries @ recon.T
    top = scores.topk(k, dim=-1).indices.cpu().numpy()
    if on_cuda:
        torch.cuda.synchronize()
    t_query = time.perf_counter() - started
    return top, t_index, t_query
85
+
86
+
87
def faiss_pq_run(X, Q, k, bits):
    """faiss IndexPQ baseline at a matched per-coordinate bit budget.

    Returns (top-k matrix, index seconds, query seconds), or a triple of
    ``None`` when faiss is not importable in this environment.
    """
    try:
        import faiss
    except Exception as e:  # pragma: no cover
        print(f" (faiss unavailable: {e}) — skipping PQ")
        return None, None, None

    dim = X.shape[1]
    # Match bit-budget: PQ stores log2(ks) bits per sub-quantizer;
    # M sub-vectors * log2(ks) / d total bits per coord.
    # Use ks=2**bits and M=d so PQ's bit-budget = bits per coord.
    n_sub = dim
    codebook = min(2 ** bits, 256)  # faiss IndexPQ stores 1 byte per code
    index = faiss.IndexPQ(dim, n_sub, int(np.log2(codebook)))
    index.metric_type = faiss.METRIC_INNER_PRODUCT

    t0 = time.perf_counter()
    index.train(X)
    index.add(X)
    t_index = time.perf_counter() - t0

    t0 = time.perf_counter()
    _, top = index.search(Q, k)
    t_query = time.perf_counter() - t0
    return top, t_index, t_query
108
+
109
+
110
def faiss_ivfpq_run(X, Q, k, bits):
    """faiss IVF-PQ baseline (coarse inner-product quantizer + PQ codes).

    Returns (top-k matrix, index seconds, query seconds), or all-``None``
    when faiss is not importable.
    """
    try:
        import faiss
    except Exception:
        return None, None, None

    dim = X.shape[1]
    n_lists = 100
    n_sub = dim
    code_bits = min(bits, 8)  # faiss sub-quantizer codes are at most 8 bits
    coarse = faiss.IndexFlatIP(dim)
    index = faiss.IndexIVFPQ(coarse, dim, n_lists, n_sub, code_bits)
    index.metric_type = faiss.METRIC_INNER_PRODUCT

    t0 = time.perf_counter()
    index.train(X)
    index.add(X)
    t_index = time.perf_counter() - t0

    index.nprobe = 16
    t0 = time.perf_counter()
    _, top = index.search(Q, k)
    t_query = time.perf_counter() - t0
    return top, t_index, t_query
128
+
129
+
130
def rabitq_run(X, Q, k, bits):  # pragma: no cover
    """Optional RaBitQ baseline; skipped silently if package missing."""
    skipped = (None, None, None)
    try:
        import rabitqlib  # noqa: F401
    except Exception as e:
        print(f" (rabitq unavailable: {e}) — skipping RaBitQ")
        return skipped
    # Import succeeded but integration is intentionally unfinished.
    print(" TODO: integrate rabitqlib (left as exercise — official API churns)")
    return skipped
139
+
140
+
141
+ # ---- driver ----------------------------------------------------------
142
+
143
def main():
    """Run every available method on one dataset and print a comparison table."""
    args = parse_args()
    X, Q = load_data(args)
    n, d = X.shape

    # Ground truth: brute-force top-k inner product
    print(f"\nDataset: {n} vectors of dim {d}, {len(Q)} queries, k={args.k}, "
          f"bits={args.bits}, device={args.device}\n")
    t0 = time.perf_counter()
    sims = Q @ X.T
    # argpartition yields an unordered candidate set; sort it per query
    # so recall comparisons use the exact top-k ordering.
    candidates = np.argpartition(-sims, args.k, axis=1)[:, :args.k]
    ordered = []
    for qi, cand in enumerate(candidates):
        ordered.append(cand[np.argsort(-sims[qi, cand])])
    true_top = np.array(ordered)
    t_brute = time.perf_counter() - t0
    print(f"brute force baseline: {t_brute:.3f}s\n")

    methods = [
        ("TurboQuant", lambda: turboquant_run(X, Q, args.k, args.bits, args.device)),
        ("faiss PQ", lambda: faiss_pq_run(X, Q, args.k, args.bits)),
        ("faiss IVF-PQ", lambda: faiss_ivfpq_run(X, Q, args.k, args.bits)),
        ("RaBitQ (opt.)", lambda: rabitq_run(X, Q, args.k, args.bits)),
    ]
    rows = []
    for name, runner in methods:
        top, t_index, t_query = runner()
        if top is None:
            continue  # method unavailable in this environment
        rec = recall_at_k(true_top, top)
        qps = len(Q) / max(t_query, 1e-9)
        rows.append((name, t_index, t_query, qps, rec))

    print(f"{'method':16s} | {'index(s)':>9} | {'query(s)':>9} | {'qps':>10} | {'recall@k':>9}")
    print("-" * 72)
    for name, ti, tq, qps, rec in rows:
        print(f"{name:16s} | {ti:>9.3f} | {tq:>9.3f} | {qps:>10.0f} | {rec:>9.3f}")
    print()

    # Head-to-head summary, only when both rows are present.
    if rows:
        tq_row = next((r for r in rows if r[0] == "TurboQuant"), None)
        pq_row = next((r for r in rows if r[0] == "faiss PQ"), None)
        if tq_row and pq_row:
            print(f"TurboQuant indexing speedup vs PQ: {pq_row[1] / tq_row[1]:.1f}×")
            print(f"TurboQuant recall delta vs PQ : {tq_row[4] - pq_row[4]:+.3f}")


if __name__ == "__main__":
    main()
@@ -0,0 +1,192 @@
1
+ """
2
+ Real-LLM benchmark: TurboQuant KV cache vs fp16 baseline.
3
+
4
+ Runs a short prompt-completion task on a small open model and reports:
5
+ - actual packed KV cache memory
6
+ - per-token attention output cosine similarity
7
+ - generation perplexity / acceptance vs baseline
8
+
9
+ Default model: HuggingFaceTB/SmolLM2-360M-Instruct (small enough for
10
+ free Kaggle T4, big enough to be representative).
11
+
12
+ Usage:
13
+ python -m benchmarks.bench_kv_real --model SmolLM2-360M --bits 4
14
+ python -m benchmarks.bench_kv_real --bits 2 --bits_outlier 3 --n_outlier 32
15
+ """
16
+
17
+ from __future__ import annotations
18
+
19
+ import argparse
20
+ import math
21
+ import os
22
+ import time
23
+
24
+ # Quieten the noisy weight-loading tqdm bar from transformers / accelerate.
25
+ # (When stdout is captured by a subprocess pipe, tqdm prints one line per
26
+ # update instead of redrawing in place, which floods the log.)
27
+ os.environ.setdefault("TRANSFORMERS_VERBOSITY", "error")
28
+ os.environ.setdefault("HF_HUB_DISABLE_PROGRESS_BARS", "1")
29
+ os.environ.setdefault("ACCELERATE_DISABLE_RICH", "1")
30
+ os.environ.setdefault("TQDM_DISABLE", "1")
31
+
32
+ import torch
33
+
34
+ from tiny_turboquant import TurboQuantKVCache
35
+
36
+
37
def parse_args():
    """Build and parse the CLI options for the real-LLM KV-cache benchmark."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", default="HuggingFaceTB/SmolLM2-360M-Instruct")
    parser.add_argument("--bits", type=int, default=4)
    parser.add_argument(
        "--bits_outlier", type=int, default=None,
        help="If set, enable outlier-channel split with this bit-width.",
    )
    parser.add_argument("--n_outlier", type=int, default=32)
    # Few-shot style prompt so greedy decoding has an unambiguous answer.
    parser.add_argument("--prompt", default=(
        "The capital of France is Paris. The capital of Germany is Berlin. "
        "The capital of Spain is Madrid. The capital of Italy is"))
    parser.add_argument("--max_new_tokens", type=int, default=64)
    default_device = "cuda" if torch.cuda.is_available() else "cpu"
    parser.add_argument("--device", default=default_device)
    parser.add_argument("--dtype", default="float16")
    return parser.parse_args()
51
+
52
+
53
@torch.no_grad()
def measure(model, tokenizer, prompt, cache, max_new_tokens):
    """Greedy-generate from *prompt* through *cache*.

    Returns (decoded text, elapsed seconds, raw output ids). The supplied
    ``cache`` object is handed to ``model.generate`` as ``past_key_values``.
    """
    encoded = tokenizer(prompt, return_tensors="pt").to(model.device)
    started = time.perf_counter()
    generated = model.generate(
        **encoded,
        max_new_tokens=max_new_tokens,
        do_sample=False,
        past_key_values=cache,
    )
    elapsed = time.perf_counter() - started
    decoded = tokenizer.decode(generated[0], skip_special_tokens=True)
    return decoded, elapsed, generated
66
+
67
+
68
@torch.no_grad()
def measure_logit_kl(model, tokenizer, prompt, cache_factory, max_new_tokens=32):
    """Teacher-forced KL between fp16 baseline logits and TurboQuant logits.

    Both paths see identical input at every step. This isolates pure
    quantization error. Free-running generations are still produced for
    the "identical generation?" check, but they are NOT used for KL —
    otherwise the metric explodes the moment paths diverge by one token.

    Returns ``(text_ref, text_tq, kls, cache, first_diverge, n_new)`` where
    ``kls`` is a per-position list of KL(p_ref || p_tq), ``cache`` is the
    quantized cache used by the free-running TQ generation (queried later
    for memory stats), and ``first_diverge``/``n_new`` describe how long
    the two greedy decodes stay token-identical.
    """
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    # 1) Free-running baseline + TQ generations (for identical-generation check only)
    out_ref = model.generate(
        **inputs, max_new_tokens=max_new_tokens, do_sample=False,
        return_dict_in_generate=True,
    )
    cache = cache_factory()
    out_tq = model.generate(
        **inputs, max_new_tokens=max_new_tokens, do_sample=False,
        return_dict_in_generate=True,
        past_key_values=cache,
    )

    text_ref = tokenizer.decode(out_ref.sequences[0], skip_special_tokens=True)
    text_tq = tokenizer.decode(out_tq.sequences[0], skip_special_tokens=True)

    # First-divergence index: how many decode steps stay greedy-identical?
    # This is the production-relevant signal — small KL means the answer
    # token is identical even when later filler tokens drift.
    prompt_len_for_div = inputs["input_ids"].shape[1]
    ref_new = out_ref.sequences[0, prompt_len_for_div:].tolist()
    tq_new = out_tq.sequences[0, prompt_len_for_div:].tolist()
    n_new = min(len(ref_new), len(tq_new))
    first_diverge = next(
        (i for i in range(n_new) if ref_new[i] != tq_new[i]),
        n_new,
    )

    # 2) Teacher-forced KL: feed the SAME baseline-generated sequence into
    #    both fp16-cache and TQ-cache, compare per-position logits.
    seq = out_ref.sequences  # (1, prompt_len + new_tokens)
    prompt_len = inputs["input_ids"].shape[1]

    # fp16 reference: one forward pass, no cache compression.
    ref_logits = model(seq).logits  # (1, T, V)

    # TQ path: prefill + decode through TurboQuantKVCache so attention
    # is computed against quantised K/V at every step.
    tq_cache = cache_factory()
    # prefill on the prompt
    tq_logits_chunks = []
    out = model(seq[:, :prompt_len], past_key_values=tq_cache, use_cache=True)
    tq_logits_chunks.append(out.logits)
    # decode one token at a time, feeding the *baseline* tokens
    for t in range(prompt_len, seq.shape[1]):
        out = model(seq[:, t:t + 1], past_key_values=tq_cache, use_cache=True)
        tq_logits_chunks.append(out.logits)
    tq_logits = torch.cat(tq_logits_chunks, dim=1)  # (1, T, V)

    # Compare logits at positions [prompt_len-1 .. T-2] — the ones that
    # predict the new tokens. Both saw identical inputs.
    p = torch.softmax(ref_logits[:, prompt_len - 1:-1].float(), dim=-1)
    q = torch.softmax(tq_logits[:, prompt_len - 1:-1].float(), dim=-1)
    # BUG FIX: the original used in-place clamp_min_() here, which mutated
    # `p` BEFORE the multiply in `p * (...)`, so the KL weights were the
    # clamped probabilities rather than the true softmax. The epsilon is
    # only needed inside the logs (to avoid log(0)), so clamp out of place.
    kls = (
        (p * (p.clamp_min(1e-12).log() - q.clamp_min(1e-12).log()))
        .sum(-1)
        .squeeze(0)
        .tolist()
    )

    return text_ref, text_tq, kls, cache, first_diverge, n_new
134
+
135
+
136
def main() -> None:
    """Load the model, run the KL benchmark, and print memory/quality metrics."""
    args = parse_args()
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from transformers.utils import logging as hf_logging
    import huggingface_hub.utils as hf_hub_utils

    # Hard-disable every progress bar transformers / huggingface_hub know
    # about. Env vars alone are unreliable when this script is launched as
    # a subprocess from kaggle_run.py (pipes confuse tqdm).
    hf_logging.disable_progress_bar()
    hf_logging.set_verbosity_error()
    hf_hub_utils.disable_progress_bars()

    dtype_by_name = {
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
        "float32": torch.float32,
    }
    dtype = dtype_by_name[args.dtype]
    print(f"Loading {args.model} ({args.dtype} on {args.device})")
    tokenizer = AutoTokenizer.from_pretrained(args.model)
    model = AutoModelForCausalLM.from_pretrained(
        args.model, torch_dtype=dtype
    ).to(args.device).eval()

    def make_cache():
        # Fresh quantized cache per run — caches are stateful.
        return TurboQuantKVCache(
            bits=args.bits,
            bits_outlier=args.bits_outlier,
            n_outlier=args.n_outlier,
        )

    text_ref, text_tq, kls, cache, first_diverge, n_new = measure_logit_kl(
        model, tokenizer, args.prompt, make_cache,
        max_new_tokens=args.max_new_tokens,
    )

    print("\n--- BASELINE (fp16, default DynamicCache) ---")
    print(text_ref)
    print("\n--- TURBOQUANT ---")
    print(text_tq)

    n_tokens = cache.get_seq_length(0)
    fp16_bytes = cache.fp16_baseline_bytes()
    actual_tq_bytes = cache.actual_memory_bytes()
    theoretical_tq_bytes = cache.theoretical_memory_bytes()

    print("\n--- METRICS ---")
    print(f"prompt + decode tokens : {n_tokens}")
    print(f"fp16 KV cache actual : {fp16_bytes / 1e6:.2f} MB")
    print(f"TurboQuant actual : {actual_tq_bytes / 1e6:.2f} MB "
          f"(× compression = {fp16_bytes / max(actual_tq_bytes, 1):.2f})")
    print(f"TurboQuant theoretical : {theoretical_tq_bytes / 1e6:.2f} MB")
    print(f"mean per-token logit KL: {sum(kls) / len(kls):.5f} (teacher-forced; lower is better)")
    print(f"max per-token logit KL: {max(kls):.5f}")
    print(f"first divergence at : {first_diverge}/{n_new} decode tokens "
          f"(higher is better; {n_new}/{n_new} = exact match)")
    print(f"identical generation? : {text_ref == text_tq}")


if __name__ == "__main__":
    main()
File without changes
@@ -0,0 +1,75 @@
1
+ """
2
+ Demo 1 — Empirical distortion matches Shannon-theoretic bounds.
3
+
4
+ Story for the audience:
5
+ "We compress unit vectors to b bits per coordinate. The black dashed line
6
+ is the information-theoretic *lower* bound from Shannon source coding;
7
+ the orange line is what *any* possible algorithm could ever achieve
8
+ in the limit. TurboQuant (blue) sits within a small constant of optimum
9
+ for every bit-width, with zero tuning."
10
+
11
+ Run:
12
+ python -m demos.demo1_distortion_vs_theory
13
+ Outputs:
14
+ demos/out/distortion.png
15
+ """
16
+
17
+ from __future__ import annotations
18
+
19
+ import os
20
+ import numpy as np
21
+ import matplotlib.pyplot as plt
22
+ from tiny_turboquant.numpy_reference import TurboQuantMSE, TurboQuantProd
23
+
24
# Demo artifacts land next to this file, under demos/out/; create the
# directory eagerly so savefig never fails later.
_PKG_DIR = os.path.dirname(__file__)
OUT = os.path.join(_PKG_DIR, "out")
os.makedirs(OUT, exist_ok=True)
26
+
27
+
28
def main() -> None:
    """Plot empirical MSE against the theoretical bounds, then print an
    inner-product bias/variance table for the product quantizer."""
    rng = np.random.default_rng(0)
    d, n = 512, 4000
    data = rng.standard_normal((n, d))
    data /= np.linalg.norm(data, axis=1, keepdims=True)

    bit_widths = [1, 2, 3, 4, 5, 6]
    empirical, lower, upper = [], [], []
    for b in bit_widths:
        quantizer = TurboQuantMSE(d, b, seed=0)
        recon = quantizer.dequant(quantizer.quant(data))
        empirical.append(float(np.mean(np.sum((data - recon) ** 2, axis=1))))
        lower.append(4.0 ** (-b))                 # Shannon lower bound
        upper.append(3 * np.pi / 2 * 4.0 ** (-b))  # paper Theorem 1

    fig, ax = plt.subplots(figsize=(7, 4.5))
    ax.plot(bit_widths, upper, "--", color="orange", label="Paper upper bound (3π/2·4⁻ᵇ)")
    ax.plot(bit_widths, lower, "--", color="black", label="Shannon lower bound (4⁻ᵇ)")
    ax.plot(bit_widths, empirical, "o-", color="tab:blue", lw=2, label="TurboQuant empirical")
    ax.set_yscale("log")
    ax.set_xlabel("bits per coordinate")
    ax.set_ylabel("MSE E‖x − x̂‖²")
    ax.set_title(f"TurboQuant-MSE vs information-theoretic bounds (d={d}, n={n})")
    ax.grid(True, which="both", alpha=0.3)
    ax.legend()
    fig.tight_layout()
    out = os.path.join(OUT, "distortion.png")
    fig.savefig(out, dpi=140)
    print(f"saved {out}")

    # Also show inner-product unbiasedness — the *killer* feature for ANN/RAG.
    probes = rng.standard_normal((n, d))
    probes /= np.linalg.norm(probes, axis=1, keepdims=True)
    print("\n bits | bias | ip-MSE | ratio vs full-precision")
    print(" -----+----------+-----------+--------------------------")
    for b in (2, 3, 4):
        prod_q = TurboQuantProd(d, b, seed=0)
        codes, signs, gamma = prod_q.quant(data)
        recon = prod_q.dequant(codes, signs, gamma)
        exact_ip = np.sum(data * probes, 1)
        approx_ip = np.sum(recon * probes, 1)
        bias = float(np.mean(approx_ip - exact_ip))
        var = float(np.mean((approx_ip - exact_ip) ** 2))
        print(f" {b} | {bias:+.4f} | {var:.5f} | "
              f"{32/b:.0f}× smaller than fp32")


if __name__ == "__main__":
    main()