tiny-turboquant 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
benchmarks/__init__.py ADDED
File without changes
@@ -0,0 +1,188 @@
1
+ """
2
+ ANN A/B benchmark: TurboQuant vs faiss PQ vs faiss IVFPQ vs RaBitQ (optional).
3
+
4
+ Reports recall@k, indexing time, and queries/second for each method on the
5
+ same dataset. The dataset can be loaded from a .npy file (most realistic)
6
+ or generated synthetically.
7
+
8
+ Usage:
9
+ python -m benchmarks.bench_ann --data path/to/embeddings.npy --bits 4 --k 10
10
+ python -m benchmarks.bench_ann --synthetic 50000 --d 768 --bits 4
11
+ """
12
+
13
+ from __future__ import annotations
14
+
15
+ import argparse
16
+ import time
17
+ from pathlib import Path
18
+
19
+ import numpy as np
20
+ import torch
21
+
22
+ from tiny_turboquant import TurboQuantProd
23
+
24
+
25
def parse_args(argv=None):
    """Parse benchmark CLI options.

    Args:
        argv: Optional list of argument strings (for tests / programmatic
            use). ``None`` — the default — falls back to ``sys.argv[1:]``,
            preserving the original behavior.

    Returns:
        argparse.Namespace with data/synthetic/d/bits/k/n_queries/device.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--data", type=str, default=None,
                   help="Path to .npy of shape (N, D), float32.")
    p.add_argument("--synthetic", type=int, default=50_000)
    p.add_argument("--d", type=int, default=768)
    p.add_argument("--bits", type=int, default=4)
    p.add_argument("--k", type=int, default=10)
    p.add_argument("--n_queries", type=int, default=500)
    # Device default is decided at parse time, so a CUDA machine benchmarks
    # on GPU without an explicit flag.
    p.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
    return p.parse_args(argv)
36
+
37
+
38
def load_data(args) -> tuple[np.ndarray, np.ndarray]:
    """Return ``(database, queries)`` as unit-norm float32 arrays.

    When ``args.data`` is given, the database is loaded from that .npy
    file; otherwise a clustered synthetic dataset is drawn (seed 0).
    Queries are ``args.n_queries`` database rows sampled without
    replacement (seed 1), so ground truth always contains the query itself.
    """
    if args.data is None:
        gen = np.random.default_rng(0)
        # Cluster structure mimicking real embedding stores.
        n_clusters = max(50, args.synthetic // 500)
        centroids = gen.standard_normal((n_clusters, args.d)).astype(np.float32)
        centroids /= np.linalg.norm(centroids, axis=1, keepdims=True)
        labels = gen.integers(0, n_clusters, size=args.synthetic)
        noise = 0.25 * gen.standard_normal((args.synthetic, args.d)).astype(np.float32)
        X = centroids[labels] + noise
    else:
        X = np.load(args.data).astype(np.float32)

    # Normalise to the unit sphere; epsilon guards an all-zero row.
    X = X / (np.linalg.norm(X, axis=1, keepdims=True) + 1e-12)
    picker = np.random.default_rng(1)
    chosen = picker.choice(X.shape[0], args.n_queries, replace=False)
    return X, X[chosen].copy()
54
+
55
+
56
def recall_at_k(true_top: np.ndarray, est_top: np.ndarray) -> float:
    """Average fraction of ground-truth neighbours recovered per query."""
    overlap = 0
    for truth_row, est_row in zip(true_top, est_top):
        overlap += len(set(truth_row.tolist()) & set(est_row.tolist()))
    # true_top.size == n_queries * k, so this averages over both axes.
    return overlap / true_top.size
60
+
61
+
62
+ # ---- methods ----------------------------------------------------------
63
+
64
def turboquant_run(X, Q, k, bits, device):
    """Index X with TurboQuantProd and answer top-k inner-product queries.

    Returns ``(top, t_index, t_query)``: an (n_queries, k) array of database
    row indices, the wall time to build+quantise, and the wall time to
    dequantise+search.
    """
    Xt = torch.from_numpy(X).to(device)
    Qt = torch.from_numpy(Q).to(device)
    d = X.shape[1]

    t0 = time.perf_counter()
    # NOTE(review): quant() presumably returns (code indices, sign bits,
    # per-vector scale gamma) — confirm against tiny_turboquant's API.
    q = TurboQuantProd.build(d, bits, device=device, seed=0,
                             dtype=torch.float32)
    idx, signs, gamma = q.quant(Xt)
    if device.startswith("cuda"):
        # CUDA kernels launch asynchronously; wait so the timer measures
        # the actual work, not just the launch.
        torch.cuda.synchronize()
    t_index = time.perf_counter() - t0

    t0 = time.perf_counter()
    # Search is exact top-k inner product against the *dequantised* store,
    # so recall loss comes purely from quantisation error.
    Xh = q.dequant(idx, signs, gamma)
    sims = Qt @ Xh.T
    top = sims.topk(k, dim=-1).indices.cpu().numpy()
    if device.startswith("cuda"):
        torch.cuda.synchronize()
    t_query = time.perf_counter() - t0
    return top, t_index, t_query
85
+
86
+
87
def faiss_pq_run(X, Q, k, bits):
    """faiss IndexPQ baseline at (approximately) the same bit budget.

    Returns ``(top, t_index, t_query)``, or ``(None, None, None)`` when
    faiss is not installed so the caller can skip the row.
    """
    try:
        import faiss
    except Exception as e:  # pragma: no cover
        print(f" (faiss unavailable: {e}) — skipping PQ")
        return None, None, None
    d = X.shape[1]
    # Match bit-budget: PQ stores log2(ks) bits per sub-quantizer;
    # M sub-vectors * log2(ks) / d total bits per coord.
    # Use ks=2**bits and M=d so PQ's bit-budget = bits per coord.
    M = d
    ks = 2 ** bits
    if ks > 256:  # faiss IndexPQ stores 1 byte per code
        ks = 256
    index = faiss.IndexPQ(d, M, int(np.log2(ks)))
    # NOTE(review): metric is assigned after construction rather than via the
    # constructor's metric argument — verify faiss honours this for both
    # train() and search().
    index.metric_type = faiss.METRIC_INNER_PRODUCT
    # Unlike TurboQuant, PQ must train k-means codebooks before adding data;
    # both phases are charged to indexing time.
    t0 = time.perf_counter(); index.train(X); index.add(X); t_index = time.perf_counter() - t0
    t0 = time.perf_counter()
    _, top = index.search(Q, k)
    t_query = time.perf_counter() - t0
    return top, t_index, t_query
108
+
109
+
110
def faiss_ivfpq_run(X, Q, k, bits):
    """faiss IVF-PQ baseline (coarse inverted lists + PQ residual codes).

    Returns ``(top, t_index, t_query)``, or ``(None, None, None)`` when
    faiss is not installed.
    """
    try:
        import faiss
    except Exception:
        return None, None, None
    d = X.shape[1]
    nlist = 100          # number of coarse Voronoi cells
    M = d                # one sub-quantizer per coordinate (matches faiss_pq_run)
    ks_bits = min(bits, 8)  # faiss caps PQ codes at 8 bits (1 byte) each
    quantizer = faiss.IndexFlatIP(d)
    index = faiss.IndexIVFPQ(quantizer, d, nlist, M, ks_bits)
    # NOTE(review): metric set post-construction; verify faiss applies it to
    # training as well as search (constructor also accepts a metric argument).
    index.metric_type = faiss.METRIC_INNER_PRODUCT
    t0 = time.perf_counter(); index.train(X); index.add(X); t_index = time.perf_counter() - t0
    # nprobe trades recall for speed: search 16 of the 100 cells per query.
    index.nprobe = 16
    t0 = time.perf_counter()
    _, top = index.search(Q, k)
    t_query = time.perf_counter() - t0
    return top, t_index, t_query
128
+
129
+
130
def rabitq_run(X, Q, k, bits):  # pragma: no cover
    """Placeholder RaBitQ baseline.

    Always returns ``(None, None, None)`` so the driver skips the row;
    prints whether the package was missing or integration is still pending.
    """
    try:
        import rabitqlib  # noqa: F401
    except Exception as e:
        print(f" (rabitq unavailable: {e}) — skipping RaBitQ")
    else:
        print(" TODO: integrate rabitqlib (left as exercise — official API churns)")
    return None, None, None
139
+
140
+
141
+ # ---- driver ----------------------------------------------------------
142
+
143
def main():
    """Run every available ANN method on one dataset and print a comparison.

    Methods returning ``(None, None, None)`` (missing optional deps) are
    silently skipped; the summary compares TurboQuant against faiss PQ when
    both actually ran.
    """
    args = parse_args()
    X, Q = load_data(args)
    n, d = X.shape

    # Ground truth: brute-force top-k inner product
    print(f"\nDataset: {n} vectors of dim {d}, {len(Q)} queries, k={args.k}, "
          f"bits={args.bits}, device={args.device}\n")
    t0 = time.perf_counter()
    sims = Q @ X.T
    # argpartition finds the k best per query in O(n), unordered...
    true_top = np.argpartition(-sims, args.k, axis=1)[:, :args.k]
    # exact top-k order (...then sort just those k per row)
    true_top = np.array([row[np.argsort(-sims[i, row])] for i, row in enumerate(true_top)])
    t_brute = time.perf_counter() - t0
    print(f"brute force baseline: {t_brute:.3f}s\n")

    rows = []
    # Each runner returns (top, t_index, t_query); top is None when skipped.
    for name, fn in (
        ("TurboQuant", lambda: turboquant_run(X, Q, args.k, args.bits, args.device)),
        ("faiss PQ", lambda: faiss_pq_run(X, Q, args.k, args.bits)),
        ("faiss IVF-PQ", lambda: faiss_ivfpq_run(X, Q, args.k, args.bits)),
        ("RaBitQ (opt.)", lambda: rabitq_run(X, Q, args.k, args.bits)),
    ):
        top, t_index, t_query = fn()
        if top is None:
            continue
        rec = recall_at_k(true_top, top)
        qps = len(Q) / max(t_query, 1e-9)  # guard against zero-duration timing
        rows.append((name, t_index, t_query, qps, rec))

    print(f"{'method':16s} | {'index(s)':>9} | {'query(s)':>9} | {'qps':>10} | {'recall@k':>9}")
    print("-" * 72)
    for name, ti, tq, qps, rec in rows:
        print(f"{name:16s} | {ti:>9.3f} | {tq:>9.3f} | {qps:>10.0f} | {rec:>9.3f}")
    print()

    # Headline comparison: only meaningful when both methods produced rows.
    if rows:
        tq_row = next((r for r in rows if r[0] == "TurboQuant"), None)
        pq_row = next((r for r in rows if r[0] == "faiss PQ"), None)
        if tq_row and pq_row:
            print(f"TurboQuant indexing speedup vs PQ: {pq_row[1] / tq_row[1]:.1f}×")
            print(f"TurboQuant recall delta vs PQ : {tq_row[4] - pq_row[4]:+.3f}")


if __name__ == "__main__":
    main()
@@ -0,0 +1,192 @@
1
+ """
2
+ Real-LLM benchmark: TurboQuant KV cache vs fp16 baseline.
3
+
4
+ Runs a short prompt-completion task on a small open model and reports:
5
+ - actual packed KV cache memory
6
+ - per-token attention output cosine similarity
7
+ - generation perplexity / acceptance vs baseline
8
+
9
+ Default model: HuggingFaceTB/SmolLM2-360M-Instruct (small enough for
10
+ free Kaggle T4, big enough to be representative).
11
+
12
+ Usage:
13
+ python -m benchmarks.bench_kv_real --model SmolLM2-360M --bits 4
14
+ python -m benchmarks.bench_kv_real --bits 2 --bits_outlier 3 --n_outlier 32
15
+ """
16
+
17
+ from __future__ import annotations
18
+
19
+ import argparse
20
+ import math
21
+ import os
22
+ import time
23
+
24
+ # Quieten the noisy weight-loading tqdm bar from transformers / accelerate.
25
+ # (When stdout is captured by a subprocess pipe, tqdm prints one line per
26
+ # update instead of redrawing in place, which floods the log.)
27
+ os.environ.setdefault("TRANSFORMERS_VERBOSITY", "error")
28
+ os.environ.setdefault("HF_HUB_DISABLE_PROGRESS_BARS", "1")
29
+ os.environ.setdefault("ACCELERATE_DISABLE_RICH", "1")
30
+ os.environ.setdefault("TQDM_DISABLE", "1")
31
+
32
+ import torch
33
+
34
+ from tiny_turboquant import TurboQuantKVCache
35
+
36
+
37
def parse_args(argv=None):
    """Parse KV-cache benchmark CLI options.

    Args:
        argv: Optional list of argument strings (for tests / programmatic
            use). ``None`` — the default — falls back to ``sys.argv[1:]``,
            preserving the original behavior.

    Returns:
        argparse.Namespace with model/bits/outlier/prompt/device settings.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--model", default="HuggingFaceTB/SmolLM2-360M-Instruct")
    p.add_argument("--bits", type=int, default=4)
    p.add_argument("--bits_outlier", type=int, default=None,
                   help="If set, enable outlier-channel split with this bit-width.")
    p.add_argument("--n_outlier", type=int, default=32)
    # Few-shot-style prompt: greedy completion has one obviously-correct token.
    p.add_argument("--prompt", default=(
        "The capital of France is Paris. The capital of Germany is Berlin. "
        "The capital of Spain is Madrid. The capital of Italy is"))
    p.add_argument("--max_new_tokens", type=int, default=64)
    p.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
    p.add_argument("--dtype", default="float16")
    return p.parse_args(argv)
51
+
52
+
53
@torch.no_grad()
def measure(model, tokenizer, prompt, cache, max_new_tokens):
    """Greedy-generate from *prompt* using *cache* as the past_key_values.

    Returns ``(decoded_text, elapsed_seconds, raw_output_ids)``. Passing
    ``cache=None`` lets the model use its default cache (fp16 baseline).
    """
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    t0 = time.perf_counter()
    out = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=False,  # greedy decoding: deterministic, comparable runs
        past_key_values=cache,
    )
    elapsed = time.perf_counter() - t0
    text = tokenizer.decode(out[0], skip_special_tokens=True)
    return text, elapsed, out
66
+
67
+
68
@torch.no_grad()
def measure_logit_kl(model, tokenizer, prompt, cache_factory, max_new_tokens=32):
    """Teacher-forced KL between fp16 baseline logits and TurboQuant logits.

    Both paths see identical input at every step. This isolates pure
    quantization error. Free-running generations are still produced for
    the "identical generation?" check, but they are NOT used for KL —
    otherwise the metric explodes the moment paths diverge by one token.

    Returns:
        ``(text_ref, text_tq, kls, cache, first_diverge, n_new)`` — the two
        decoded generations, the per-position KL list, the TurboQuant cache
        from the *free-running* generation (callers use it for memory
        accounting), the index of the first diverging decode token, and the
        number of compared decode tokens.
    """
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    # 1) Free-running baseline + TQ generations (for identical-generation check only)
    out_ref = model.generate(
        **inputs, max_new_tokens=max_new_tokens, do_sample=False,
        return_dict_in_generate=True,
    )
    cache = cache_factory()
    out_tq = model.generate(
        **inputs, max_new_tokens=max_new_tokens, do_sample=False,
        return_dict_in_generate=True,
        past_key_values=cache,
    )

    text_ref = tokenizer.decode(out_ref.sequences[0], skip_special_tokens=True)
    text_tq = tokenizer.decode(out_tq.sequences[0], skip_special_tokens=True)

    # First-divergence index: how many decode steps stay greedy-identical?
    # This is the production-relevant signal — small KL means the answer
    # token is identical even when later filler tokens drift.
    prompt_len_for_div = inputs["input_ids"].shape[1]
    ref_new = out_ref.sequences[0, prompt_len_for_div:].tolist()
    tq_new = out_tq .sequences[0, prompt_len_for_div:].tolist()
    n_new = min(len(ref_new), len(tq_new))
    first_diverge = next(
        (i for i in range(n_new) if ref_new[i] != tq_new[i]),
        n_new,  # no divergence: generations agree on every compared token
    )

    # 2) Teacher-forced KL: feed the SAME baseline-generated sequence into
    # both fp16-cache and TQ-cache, compare per-position logits.
    seq = out_ref.sequences  # (1, prompt_len + new_tokens)
    prompt_len = inputs["input_ids"].shape[1]

    # fp16 reference: one forward pass, no cache compression.
    ref_logits = model(seq).logits  # (1, T, V)

    # TQ path: prefill + decode through TurboQuantKVCache so attention
    # is computed against quantised K/V at every step.
    tq_cache = cache_factory()
    # prefill on the prompt
    tq_logits_chunks = []
    out = model(seq[:, :prompt_len], past_key_values=tq_cache, use_cache=True)
    tq_logits_chunks.append(out.logits)
    # decode one token at a time, feeding the *baseline* tokens
    for t in range(prompt_len, seq.shape[1]):
        out = model(seq[:, t:t + 1], past_key_values=tq_cache, use_cache=True)
        tq_logits_chunks.append(out.logits)
    tq_logits = torch.cat(tq_logits_chunks, dim=1)  # (1, T, V)

    # Compare logits at positions [prompt_len-1 .. T-2] — the ones that
    # predict the new tokens. Both saw identical inputs.
    # clamp_min_ before log() keeps KL finite when a probability underflows.
    p = torch.softmax(ref_logits[:, prompt_len - 1:-1].float(), dim=-1)
    q = torch.softmax(tq_logits [:, prompt_len - 1:-1].float(), dim=-1).clamp_min_(1e-12)
    kls = (p * (p.clamp_min_(1e-12).log() - q.log())).sum(-1).squeeze(0).tolist()

    return text_ref, text_tq, kls, cache, first_diverge, n_new
134
+
135
+
136
def main() -> None:
    """Load the model, run baseline vs TurboQuant generation, print metrics."""
    args = parse_args()
    # Imported lazily so `--help` works without transformers installed.
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from transformers.utils import logging as hf_logging
    import huggingface_hub.utils as hf_hub_utils

    # Hard-disable every progress bar transformers / huggingface_hub know
    # about. Env vars alone are unreliable when this script is launched as
    # a subprocess from kaggle_run.py (pipes confuse tqdm).
    hf_logging.disable_progress_bar()
    hf_logging.set_verbosity_error()
    hf_hub_utils.disable_progress_bars()

    # Raises KeyError for unsupported --dtype values (fail fast).
    dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16,
             "float32": torch.float32}[args.dtype]
    print(f"Loading {args.model} ({args.dtype} on {args.device})")
    tok = AutoTokenizer.from_pretrained(args.model)
    model = AutoModelForCausalLM.from_pretrained(
        args.model, torch_dtype=dtype
    ).to(args.device).eval()

    # Fresh quantised cache per generation pass — caches are stateful.
    def factory():
        return TurboQuantKVCache(
            bits=args.bits,
            bits_outlier=args.bits_outlier,
            n_outlier=args.n_outlier,
        )

    text_ref, text_tq, kls, cache, first_diverge, n_new = measure_logit_kl(
        model, tok, args.prompt, factory, max_new_tokens=args.max_new_tokens,
    )

    print("\n--- BASELINE (fp16, default DynamicCache) ---")
    print(text_ref)
    print("\n--- TURBOQUANT ---")
    print(text_tq)

    # Memory accounting uses the cache from the free-running TQ generation.
    n_tokens = cache.get_seq_length(0)
    fp16_bytes = cache.fp16_baseline_bytes()
    actual_tq_bytes = cache.actual_memory_bytes()
    theoretical_tq_bytes = cache.theoretical_memory_bytes()

    print("\n--- METRICS ---")
    print(f"prompt + decode tokens : {n_tokens}")
    print(f"fp16 KV cache actual : {fp16_bytes / 1e6:.2f} MB")
    # max(..., 1) avoids division by zero if the cache reports zero bytes.
    print(f"TurboQuant actual : {actual_tq_bytes / 1e6:.2f} MB "
          f"(× compression = {fp16_bytes / max(actual_tq_bytes, 1):.2f})")
    print(f"TurboQuant theoretical : {theoretical_tq_bytes / 1e6:.2f} MB")
    print(f"mean per-token logit KL: {sum(kls) / len(kls):.5f} (teacher-forced; lower is better)")
    print(f"max per-token logit KL: {max(kls):.5f}")
    print(f"first divergence at : {first_diverge}/{n_new} decode tokens "
          f"(higher is better; {n_new}/{n_new} = exact match)")
    print(f"identical generation? : {text_ref == text_tq}")


if __name__ == "__main__":
    main()
demos/__init__.py ADDED
File without changes
@@ -0,0 +1,75 @@
1
+ """
2
+ Demo 1 — Empirical distortion matches Shannon-theoretic bounds.
3
+
4
+ Story for the audience:
5
+ "We compress unit vectors to b bits per coordinate. The black dashed line
6
+ is the information-theoretic *lower* bound from Shannon source coding;
7
+ the orange line is what *any* possible algorithm could ever achieve
8
+ in the limit. TurboQuant (blue) sits within a small constant of optimum
9
+ for every bit-width, with zero tuning."
10
+
11
+ Run:
12
+ python -m demos.demo1_distortion_vs_theory
13
+ Outputs:
14
+ demos/out/distortion.png
15
+ """
16
+
17
+ from __future__ import annotations
18
+
19
+ import os
20
+ import numpy as np
21
+ import matplotlib.pyplot as plt
22
+ from tiny_turboquant.numpy_reference import TurboQuantMSE, TurboQuantProd
23
+
24
# Output directory for generated figures, next to this module. Created
# eagerly at import time so main() can save without checking.
OUT = os.path.join(os.path.dirname(__file__), "out")
os.makedirs(OUT, exist_ok=True)
26
+
27
+
28
def main() -> None:
    """Plot empirical MSE against Shannon-theoretic bounds, then print
    inner-product bias/MSE for the product-quantiser variant.

    Saves demos/out/distortion.png and prints a small table to stdout.
    """
    rng = np.random.default_rng(0)
    d, n = 512, 4000
    X = rng.standard_normal((n, d))
    X /= np.linalg.norm(X, axis=1, keepdims=True)

    bits = [1, 2, 3, 4, 5, 6]
    mse_emp, mse_lb, mse_ub = [], [], []
    for b in bits:
        q = TurboQuantMSE(d, b, seed=0)
        # Round-trip (quantise then dequantise) to measure distortion.
        Xh = q.dequant(q.quant(X))
        mse_emp.append(float(np.mean(np.sum((X - Xh) ** 2, axis=1))))
        mse_lb.append(4.0 ** (-b))  # Shannon lower bound
        mse_ub.append(3 * np.pi / 2 * 4.0 ** (-b))  # paper Theorem 1

    fig, ax = plt.subplots(figsize=(7, 4.5))
    ax.plot(bits, mse_ub, "--", color="orange", label="Paper upper bound (3π/2·4⁻ᵇ)")
    ax.plot(bits, mse_lb, "--", color="black", label="Shannon lower bound (4⁻ᵇ)")
    ax.plot(bits, mse_emp, "o-", color="tab:blue", lw=2, label="TurboQuant empirical")
    # Bounds decay geometrically in b, so log-y makes all three lines straight.
    ax.set_yscale("log")
    ax.set_xlabel("bits per coordinate")
    ax.set_ylabel("MSE E‖x − x̂‖²")
    ax.set_title(f"TurboQuant-MSE vs information-theoretic bounds (d={d}, n={n})")
    ax.grid(True, which="both", alpha=0.3)
    ax.legend()
    fig.tight_layout()
    out = os.path.join(OUT, "distortion.png")
    fig.savefig(out, dpi=140)
    print(f"saved {out}")

    # Also show inner-product unbiasedness — the *killer* feature for ANN/RAG.
    Y = rng.standard_normal((n, d)); Y /= np.linalg.norm(Y, axis=1, keepdims=True)
    print("\n bits | bias | ip-MSE | ratio vs full-precision")
    print(" -----+----------+-----------+--------------------------")
    for b in (2, 3, 4):
        qp = TurboQuantProd(d, b, seed=0)
        idx, s, g = qp.quant(X)
        Xh = qp.dequant(idx, s, g)
        # Compare true <x, y> against the estimate from dequantised x.
        true = np.sum(X * Y, 1)
        est = np.sum(Xh * Y, 1)
        bias = float(np.mean(est - true))
        var = float(np.mean((est - true) ** 2))  # mean squared IP error
        print(f" {b} | {bias:+.4f} | {var:.5f} | "
              f"{32/b:.0f}× smaller than fp32")


if __name__ == "__main__":
    main()
@@ -0,0 +1,133 @@
1
+ """
2
+ Demo 2 — ANN search: TurboQuant vs Product Quantization.
3
+
4
+ Story for the audience:
5
+ "We index 50k random embeddings and answer top-10 NN queries. TurboQuant
6
+ matches PQ's recall at the same bit budget while indexing in
7
+ ~milliseconds — PQ needs to k-means-train per subspace."
8
+
9
+ Run:
10
+ python -m demos.demo2_ann_vs_pq
11
+ """
12
+
13
+ from __future__ import annotations
14
+
15
+ import time
16
+ import numpy as np
17
+
18
+ from tiny_turboquant.numpy_reference import TurboQuantProd
19
+
20
+
21
+ # ----- toy product quantization baseline ----------------------------------
22
+
23
def pq_train(X: np.ndarray, n_sub: int, ks: int, seed: int = 0):
    """Train PQ codebooks with 10 rounds of vectorised k-means per subspace.

    Returns ``(codebooks, codes, sub_dim)`` where codebooks has shape
    (n_sub, ks, sub_dim) and codes holds each row's nearest-centroid index
    per subspace.
    """
    gen = np.random.default_rng(seed)
    n, d = X.shape
    sub_dim = d // n_sub
    codebooks = np.empty((n_sub, ks, sub_dim))
    codes = np.empty((n, n_sub), dtype=np.int32)

    for m in range(n_sub):
        chunk = X[:, m * sub_dim : (m + 1) * sub_dim]          # (n, sub_dim)
        # Initialise with ks distinct samples drawn from this subspace.
        cb = chunk[gen.choice(n, ks, replace=False)].copy()    # (ks, sub_dim)
        x_sq = (chunk * chunk).sum(1, keepdims=True)           # (n, 1)

        def dist2(centroids):
            # ||a-b||^2 = ||a||^2 + ||b||^2 - 2 a·b (vectorised)
            return x_sq + (centroids * centroids).sum(1)[None, :] - 2.0 * chunk @ centroids.T

        for _ in range(10):
            labels = dist2(cb).argmin(1)
            fresh = np.zeros_like(cb)
            np.add.at(fresh, labels, chunk)
            sizes = np.bincount(labels, minlength=ks)
            occupied = sizes > 0
            fresh[occupied] /= sizes[occupied, None]
            fresh[~occupied] = cb[~occupied]   # keep empty clusters in place
            cb = fresh
        codebooks[m] = cb
        # Final assignment against the converged codebook.
        codes[:, m] = dist2(cb).argmin(1)
    return codebooks, codes, sub_dim
52
+
53
+
54
def pq_decode(codebooks: np.ndarray, codes: np.ndarray) -> np.ndarray:
    """Reconstruct approximate vectors from PQ codes.

    Args:
        codebooks: (n_sub, ks, sub_dim) centroid table from ``pq_train``.
        codes: (n, n_sub) integer centroid index per row and subspace.

    Returns:
        (n, n_sub * sub_dim) float64 array of concatenated centroids.
    """
    n_sub, ks, sub_dim = codebooks.shape
    n = codes.shape[0]
    # One fancy-index gather replaces the original per-subspace Python loop:
    # picked[i, m] == codebooks[m, codes[i, m]], shape (n, n_sub, sub_dim).
    picked = codebooks[np.arange(n_sub)[None, :], codes]
    # astype keeps the original contract of a float64 result.
    return picked.reshape(n, n_sub * sub_dim).astype(np.float64, copy=False)
61
+
62
+
63
+ # ----- evaluation ---------------------------------------------------------
64
+
65
def recall_at_k(true_top: np.ndarray, est_top: np.ndarray) -> float:
    """Average fraction of true top-k ids recovered by the estimate."""
    matched = sum(
        len(set(t.tolist()) & set(e.tolist()))
        for t, e in zip(true_top, est_top)
    )
    # true_top.size == n_queries * k, averaging over queries and ranks.
    return matched / true_top.size
70
+
71
+
72
def main() -> None:
    """Index 20k clustered unit vectors with TurboQuant and toy PQ at the
    same bit budget, then compare indexing time, query time and recall@10.
    """
    rng = np.random.default_rng(0)
    n_db, n_q, d, k = 20_000, 200, 128, 10
    bits = 4

    # Use clustered data: real embeddings (BERT, OpenAI, etc.) are highly
    # structured, *not* i.i.d. Gaussian. On i.i.d. Gaussian PQ has a built-in
    # advantage because it gets to train on the same distribution. On
    # clustered/structured data, TurboQuant's data-oblivious random rotation
    # is competitive while keeping its huge indexing-time advantage.
    n_clusters = 50
    centers = rng.standard_normal((n_clusters, d))
    centers /= np.linalg.norm(centers, axis=1, keepdims=True)
    cluster_id = rng.integers(0, n_clusters, size=n_db)
    X = centers[cluster_id] + 0.25 * rng.standard_normal((n_db, d))
    X = X.astype(np.float64)
    X /= np.linalg.norm(X, axis=1, keepdims=True)

    # Queries drawn from the same cluster mixture (but not from X itself).
    q_cluster = rng.integers(0, n_clusters, size=n_q)
    Q = centers[q_cluster] + 0.25 * rng.standard_normal((n_q, d))
    Q = Q.astype(np.float64)
    Q /= np.linalg.norm(Q, axis=1, keepdims=True)

    # Ground truth
    true_top = np.argsort(-Q @ X.T, axis=1)[:, :k]

    # ----- TurboQuant -----
    t0 = time.perf_counter()
    tq = TurboQuantProd(d, bits=bits, seed=0)
    idx, signs, gamma = tq.quant(X)
    t_tq_index = time.perf_counter() - t0

    # Query phase: dequantise the whole store, then exact top-k on it.
    t0 = time.perf_counter()
    Xh = tq.dequant(idx, signs, gamma)
    est_top_tq = np.argsort(-Q @ Xh.T, axis=1)[:, :k]
    t_tq_query = time.perf_counter() - t0

    rec_tq = recall_at_k(true_top, est_top_tq)

    # ----- PQ at matched bit budget -----
    # bits/coord = log2(ks) / sub_dim => ks=256, sub_dim=2 gives 4 bits/coord
    n_sub, ks = d // 2, 256
    t0 = time.perf_counter()
    cb, codes, _ = pq_train(X, n_sub=n_sub, ks=ks, seed=0)
    t_pq_index = time.perf_counter() - t0

    t0 = time.perf_counter()
    Xh_pq = pq_decode(cb, codes)
    est_top_pq = np.argsort(-Q @ Xh_pq.T, axis=1)[:, :k]
    t_pq_query = time.perf_counter() - t0
    rec_pq = recall_at_k(true_top, est_top_pq)

    print(f"\nDataset: {n_db} unit vectors in R^{d}, top-{k} ANN, bit budget = {bits}\n")
    print(f"{'method':14s} | indexing(s) | query(s) | recall@{k}")
    print("-" * 56)
    print(f"{'TurboQuant':14s} | {t_tq_index:11.3f} | {t_tq_query:8.3f} | {rec_tq:.3f}")
    print(f"{'PQ (k-means)':14s} | {t_pq_index:11.3f} | {t_pq_query:8.3f} | {rec_pq:.3f}")
    print(f"\nIndexing speedup: {t_pq_index / t_tq_index:,.0f}× in TurboQuant's favour")


if __name__ == "__main__":
    main()