tiny-turboquant 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/__init__.py +0 -0
- benchmarks/bench_ann.py +188 -0
- benchmarks/bench_kv_real.py +192 -0
- demos/__init__.py +0 -0
- demos/demo1_distortion_vs_theory.py +75 -0
- demos/demo2_ann_vs_pq.py +133 -0
- demos/demo3_real_embeddings.py +151 -0
- demos/demo4_kv_cache.py +308 -0
- tiny_turboquant/__init__.py +28 -0
- tiny_turboquant/bitpack.py +98 -0
- tiny_turboquant/codebooks.py +89 -0
- tiny_turboquant/fwht.py +59 -0
- tiny_turboquant/kv_cache.py +337 -0
- tiny_turboquant/numpy_reference.py +271 -0
- tiny_turboquant/outlier_split.py +95 -0
- tiny_turboquant/quantizer.py +119 -0
- tiny_turboquant/rotation.py +80 -0
- tiny_turboquant-0.1.0.dist-info/METADATA +109 -0
- tiny_turboquant-0.1.0.dist-info/RECORD +22 -0
- tiny_turboquant-0.1.0.dist-info/WHEEL +5 -0
- tiny_turboquant-0.1.0.dist-info/licenses/LICENSE +21 -0
- tiny_turboquant-0.1.0.dist-info/top_level.txt +3 -0
benchmarks/__init__.py
ADDED
|
File without changes
|
benchmarks/bench_ann.py
ADDED
|
@@ -0,0 +1,188 @@
|
|
|
1
|
+
"""
|
|
2
|
+
ANN A/B benchmark: TurboQuant vs faiss PQ vs faiss IVFPQ vs RaBitQ (optional).
|
|
3
|
+
|
|
4
|
+
Reports recall@k, indexing time, and queries/second for each method on the
|
|
5
|
+
same dataset. The dataset can be loaded from a .npy file (most realistic)
|
|
6
|
+
or generated synthetically.
|
|
7
|
+
|
|
8
|
+
Usage:
|
|
9
|
+
python -m benchmarks.bench_ann --data path/to/embeddings.npy --bits 4 --k 10
|
|
10
|
+
python -m benchmarks.bench_ann --synthetic 50000 --d 768 --bits 4
|
|
11
|
+
"""
|
|
12
|
+
|
|
13
|
+
from __future__ import annotations
|
|
14
|
+
|
|
15
|
+
import argparse
|
|
16
|
+
import time
|
|
17
|
+
from pathlib import Path
|
|
18
|
+
|
|
19
|
+
import numpy as np
|
|
20
|
+
import torch
|
|
21
|
+
|
|
22
|
+
from tiny_turboquant import TurboQuantProd
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def parse_args():
    """Build and parse the CLI flags for the ANN benchmark."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--data", type=str, default=None,
                        help="Path to .npy of shape (N, D), float32.")
    # The integer knobs all share the same shape, so register them in one pass.
    for flag, default in (("--synthetic", 50_000), ("--d", 768),
                          ("--bits", 4), ("--k", 10), ("--n_queries", 500)):
        parser.add_argument(flag, type=int, default=default)
    parser.add_argument("--device",
                        default="cuda" if torch.cuda.is_available() else "cpu")
    return parser.parse_args()
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def load_data(args) -> tuple[np.ndarray, np.ndarray]:
    """Return (database, queries) as unit-normalised float32 row matrices.

    Loads ``args.data`` when given; otherwise samples a clustered synthetic
    set (real embedding stores are clustered, not i.i.d. Gaussian).
    Queries are ``args.n_queries`` distinct rows drawn from the database.
    """
    if args.data is None:
        gen = np.random.default_rng(0)
        # Cluster structure mimicking real embedding stores.
        n_clusters = max(50, args.synthetic // 500)
        centroids = gen.standard_normal((n_clusters, args.d)).astype(np.float32)
        centroids /= np.linalg.norm(centroids, axis=1, keepdims=True)
        labels = gen.integers(0, n_clusters, size=args.synthetic)
        noise = 0.25 * gen.standard_normal((args.synthetic, args.d)).astype(np.float32)
        X = centroids[labels] + noise
    else:
        X = np.load(args.data).astype(np.float32)

    # Normalise to the unit sphere; epsilon guards an all-zero row.
    X = X / (np.linalg.norm(X, axis=1, keepdims=True) + 1e-12)
    picker = np.random.default_rng(1)
    query_rows = picker.choice(X.shape[0], args.n_queries, replace=False)
    return X, X[query_rows].copy()
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def recall_at_k(true_top: np.ndarray, est_top: np.ndarray) -> float:
    """Fraction of true top-k ids recovered, averaged over all queries."""
    overlap = 0
    for truth_row, est_row in zip(true_top, est_top):
        overlap += len(set(truth_row.tolist()) & set(est_row.tolist()))
    return overlap / true_top.size
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
# ---- methods ----------------------------------------------------------
|
|
63
|
+
|
|
64
|
+
def turboquant_run(X, Q, k, bits, device):
    """Index X with TurboQuant, then answer Q by dequantise-and-matmul.

    Returns (top-k index matrix as numpy, indexing seconds, query seconds).
    """
    base = torch.from_numpy(X).to(device)
    queries = torch.from_numpy(Q).to(device)
    dim = X.shape[1]

    start = time.perf_counter()
    quantizer = TurboQuantProd.build(dim, bits, device=device, seed=0,
                                     dtype=torch.float32)
    codes, signs, gamma = quantizer.quant(base)
    if device.startswith("cuda"):
        torch.cuda.synchronize()  # make the indexing timing honest on GPU
    index_seconds = time.perf_counter() - start

    start = time.perf_counter()
    decoded = quantizer.dequant(codes, signs, gamma)
    scores = queries @ decoded.T
    top = scores.topk(k, dim=-1).indices.cpu().numpy()
    if device.startswith("cuda"):
        torch.cuda.synchronize()
    query_seconds = time.perf_counter() - start
    return top, index_seconds, query_seconds
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def faiss_pq_run(X, Q, k, bits):
    """faiss IndexPQ baseline at (roughly) the same bit budget.

    Returns (top-k ids, index seconds, query seconds), or (None, None, None)
    when faiss cannot be imported.
    """
    try:
        import faiss
    except Exception as e:  # pragma: no cover
        print(f" (faiss unavailable: {e}) — skipping PQ")
        return None, None, None
    d = X.shape[1]
    # Bit-budget matching: PQ spends M * log2(ks) bits per vector, i.e.
    # log2(ks) bits per coordinate when M == d. ks = 2**bits hits the
    # target exactly; the cap exists because IndexPQ stores 1-byte codes.
    M = d
    ks = min(2 ** bits, 256)
    index = faiss.IndexPQ(d, M, int(np.log2(ks)))
    index.metric_type = faiss.METRIC_INNER_PRODUCT
    start = time.perf_counter()
    index.train(X)
    index.add(X)
    t_index = time.perf_counter() - start
    start = time.perf_counter()
    _, top = index.search(Q, k)
    t_query = time.perf_counter() - start
    return top, t_index, t_query
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
def faiss_ivfpq_run(X, Q, k, bits):
    """faiss IVF-PQ baseline (nlist=100, nprobe=16).

    Returns (top-k ids, index seconds, query seconds), or (None, None, None)
    when faiss cannot be imported.
    """
    try:
        import faiss
    except Exception:
        return None, None, None
    d = X.shape[1]
    coarse = faiss.IndexFlatIP(d)
    # One sub-quantizer per coordinate, capped at faiss's 8-bit code limit.
    index = faiss.IndexIVFPQ(coarse, d, 100, d, min(bits, 8))
    index.metric_type = faiss.METRIC_INNER_PRODUCT
    start = time.perf_counter()
    index.train(X)
    index.add(X)
    t_index = time.perf_counter() - start
    index.nprobe = 16
    start = time.perf_counter()
    _, top = index.search(Q, k)
    t_query = time.perf_counter() - start
    return top, t_index, t_query
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
def rabitq_run(X, Q, k, bits):  # pragma: no cover
    """Optional RaBitQ baseline; skipped silently if package missing."""
    skipped = (None, None, None)
    try:
        import rabitqlib  # noqa: F401
    except Exception as e:
        print(f" (rabitq unavailable: {e}) — skipping RaBitQ")
        return skipped
    # Package is importable but no adapter is wired up yet.
    print(" TODO: integrate rabitqlib (left as exercise — official API churns)")
    return skipped
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
# ---- driver ----------------------------------------------------------
|
|
142
|
+
|
|
143
|
+
def main():
    """Run every available ANN method on one dataset and print a comparison table.

    Methods that return (None, None, None) — e.g. when faiss/rabitq are not
    installed — are silently dropped from the table.
    """
    args = parse_args()
    X, Q = load_data(args)
    n, d = X.shape

    # Ground truth: brute-force top-k inner product
    print(f"\nDataset: {n} vectors of dim {d}, {len(Q)} queries, k={args.k}, "
          f"bits={args.bits}, device={args.device}\n")
    t0 = time.perf_counter()
    sims = Q @ X.T
    # argpartition gives an unordered top-k set per row in O(n) ...
    true_top = np.argpartition(-sims, args.k, axis=1)[:, :args.k]
    # exact top-k order
    true_top = np.array([row[np.argsort(-sims[i, row])] for i, row in enumerate(true_top)])
    t_brute = time.perf_counter() - t0
    print(f"brute force baseline: {t_brute:.3f}s\n")

    rows = []
    # Lambdas defer execution so each method is timed inside its own runner.
    for name, fn in (
        ("TurboQuant", lambda: turboquant_run(X, Q, args.k, args.bits, args.device)),
        ("faiss PQ", lambda: faiss_pq_run(X, Q, args.k, args.bits)),
        ("faiss IVF-PQ", lambda: faiss_ivfpq_run(X, Q, args.k, args.bits)),
        ("RaBitQ (opt.)", lambda: rabitq_run(X, Q, args.k, args.bits)),
    ):
        top, t_index, t_query = fn()
        if top is None:
            continue
        rec = recall_at_k(true_top, top)
        # Guard against a zero-duration timing on very fast queries.
        qps = len(Q) / max(t_query, 1e-9)
        rows.append((name, t_index, t_query, qps, rec))

    print(f"{'method':16s} | {'index(s)':>9} | {'query(s)':>9} | {'qps':>10} | {'recall@k':>9}")
    print("-" * 72)
    for name, ti, tq, qps, rec in rows:
        print(f"{name:16s} | {ti:>9.3f} | {tq:>9.3f} | {qps:>10.0f} | {rec:>9.3f}")
    print()

    # Head-to-head summary against the most comparable baseline (plain PQ).
    if rows:
        tq_row = next((r for r in rows if r[0] == "TurboQuant"), None)
        pq_row = next((r for r in rows if r[0] == "faiss PQ"), None)
        if tq_row and pq_row:
            print(f"TurboQuant indexing speedup vs PQ: {pq_row[1] / tq_row[1]:.1f}×")
            print(f"TurboQuant recall delta vs PQ : {tq_row[4] - pq_row[4]:+.3f}")
|
|
185
|
+
|
|
186
|
+
|
|
187
|
+
if __name__ == "__main__":
|
|
188
|
+
main()
|
|
@@ -0,0 +1,192 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Real-LLM benchmark: TurboQuant KV cache vs fp16 baseline.
|
|
3
|
+
|
|
4
|
+
Runs a short prompt-completion task on a small open model and reports:
|
|
5
|
+
- actual packed KV cache memory
|
|
6
|
+
- per-token attention output cosine similarity
|
|
7
|
+
- generation perplexity / acceptance vs baseline
|
|
8
|
+
|
|
9
|
+
Default model: HuggingFaceTB/SmolLM2-360M-Instruct (small enough for
|
|
10
|
+
free Kaggle T4, big enough to be representative).
|
|
11
|
+
|
|
12
|
+
Usage:
|
|
13
|
+
python -m benchmarks.bench_kv_real --model SmolLM2-360M --bits 4
|
|
14
|
+
python -m benchmarks.bench_kv_real --bits 2 --bits_outlier 3 --n_outlier 32
|
|
15
|
+
"""
|
|
16
|
+
|
|
17
|
+
from __future__ import annotations
|
|
18
|
+
|
|
19
|
+
import argparse
|
|
20
|
+
import math
|
|
21
|
+
import os
|
|
22
|
+
import time
|
|
23
|
+
|
|
24
|
+
# Quieten the noisy weight-loading tqdm bar from transformers / accelerate.
|
|
25
|
+
# (When stdout is captured by a subprocess pipe, tqdm prints one line per
|
|
26
|
+
# update instead of redrawing in place, which floods the log.)
|
|
27
|
+
os.environ.setdefault("TRANSFORMERS_VERBOSITY", "error")
|
|
28
|
+
os.environ.setdefault("HF_HUB_DISABLE_PROGRESS_BARS", "1")
|
|
29
|
+
os.environ.setdefault("ACCELERATE_DISABLE_RICH", "1")
|
|
30
|
+
os.environ.setdefault("TQDM_DISABLE", "1")
|
|
31
|
+
|
|
32
|
+
import torch
|
|
33
|
+
|
|
34
|
+
from tiny_turboquant import TurboQuantKVCache
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def parse_args():
    """Build and parse the CLI flags for the real-model KV-cache benchmark."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", default="HuggingFaceTB/SmolLM2-360M-Instruct")
    parser.add_argument("--bits", type=int, default=4)
    parser.add_argument("--bits_outlier", type=int, default=None,
                        help="If set, enable outlier-channel split with this bit-width.")
    parser.add_argument("--n_outlier", type=int, default=32)
    parser.add_argument("--prompt", default=(
        "The capital of France is Paris. The capital of Germany is Berlin. "
        "The capital of Spain is Madrid. The capital of Italy is"))
    parser.add_argument("--max_new_tokens", type=int, default=64)
    parser.add_argument("--device",
                        default="cuda" if torch.cuda.is_available() else "cpu")
    parser.add_argument("--dtype", default="float16")
    return parser.parse_args()
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
@torch.no_grad()
def measure(model, tokenizer, prompt, cache, max_new_tokens):
    """Greedy-generate a completion of *prompt* through *cache*.

    Returns (decoded text, wall-clock seconds spent in generate, raw ids).
    """
    encoded = tokenizer(prompt, return_tensors="pt").to(model.device)
    start = time.perf_counter()
    generated = model.generate(**encoded,
                               max_new_tokens=max_new_tokens,
                               do_sample=False,
                               past_key_values=cache)
    seconds = time.perf_counter() - start
    decoded = tokenizer.decode(generated[0], skip_special_tokens=True)
    return decoded, seconds, generated
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
@torch.no_grad()
def measure_logit_kl(model, tokenizer, prompt, cache_factory, max_new_tokens=32):
    """Teacher-forced KL between fp16 baseline logits and TurboQuant logits.

    Both paths see identical input at every step. This isolates pure
    quantization error. Free-running generations are still produced for
    the "identical generation?" check, but they are NOT used for KL —
    otherwise the metric explodes the moment paths diverge by one token.

    Returns (baseline text, TQ text, per-token KL list, the TQ cache used
    for the free-running generation, first-divergence index, decode count).
    """
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    # 1) Free-running baseline + TQ generations (for identical-generation check only)
    out_ref = model.generate(
        **inputs, max_new_tokens=max_new_tokens, do_sample=False,
        return_dict_in_generate=True,
    )
    cache = cache_factory()
    out_tq = model.generate(
        **inputs, max_new_tokens=max_new_tokens, do_sample=False,
        return_dict_in_generate=True,
        past_key_values=cache,
    )

    text_ref = tokenizer.decode(out_ref.sequences[0], skip_special_tokens=True)
    text_tq = tokenizer.decode(out_tq.sequences[0], skip_special_tokens=True)

    # First-divergence index: how many decode steps stay greedy-identical?
    # This is the production-relevant signal — small KL means the answer
    # token is identical even when later filler tokens drift.
    prompt_len_for_div = inputs["input_ids"].shape[1]
    ref_new = out_ref.sequences[0, prompt_len_for_div:].tolist()
    tq_new = out_tq .sequences[0, prompt_len_for_div:].tolist()
    # min() because early EOS can make one sequence shorter than the other.
    n_new = min(len(ref_new), len(tq_new))
    first_diverge = next(
        (i for i in range(n_new) if ref_new[i] != tq_new[i]),
        n_new,
    )

    # 2) Teacher-forced KL: feed the SAME baseline-generated sequence into
    #    both fp16-cache and TQ-cache, compare per-position logits.
    seq = out_ref.sequences  # (1, prompt_len + new_tokens)
    prompt_len = inputs["input_ids"].shape[1]

    # fp16 reference: one forward pass, no cache compression.
    ref_logits = model(seq).logits  # (1, T, V)

    # TQ path: prefill + decode through TurboQuantKVCache so attention
    # is computed against quantised K/V at every step.
    tq_cache = cache_factory()
    # prefill on the prompt
    tq_logits_chunks = []
    out = model(seq[:, :prompt_len], past_key_values=tq_cache, use_cache=True)
    tq_logits_chunks.append(out.logits)
    # decode one token at a time, feeding the *baseline* tokens
    for t in range(prompt_len, seq.shape[1]):
        out = model(seq[:, t:t + 1], past_key_values=tq_cache, use_cache=True)
        tq_logits_chunks.append(out.logits)
    tq_logits = torch.cat(tq_logits_chunks, dim=1)  # (1, T, V)

    # Compare logits at positions [prompt_len-1 .. T-2] — the ones that
    # predict the new tokens. Both saw identical inputs.
    p = torch.softmax(ref_logits[:, prompt_len - 1:-1].float(), dim=-1)
    q = torch.softmax(tq_logits [:, prompt_len - 1:-1].float(), dim=-1).clamp_min_(1e-12)
    # NOTE(review): clamp_min_ mutates p in place before the multiply, so the
    # p factor is also clamped — negligible at 1e-12 but worth confirming intent.
    kls = (p * (p.clamp_min_(1e-12).log() - q.log())).sum(-1).squeeze(0).tolist()

    return text_ref, text_tq, kls, cache, first_diverge, n_new
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
def main() -> None:
    """Load a small HF model, run the KV-cache A/B, and print memory + KL metrics."""
    args = parse_args()
    # Imported lazily so --help works without transformers installed.
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from transformers.utils import logging as hf_logging
    import huggingface_hub.utils as hf_hub_utils

    # Hard-disable every progress bar transformers / huggingface_hub know
    # about. Env vars alone are unreliable when this script is launched as
    # a subprocess from kaggle_run.py (pipes confuse tqdm).
    hf_logging.disable_progress_bar()
    hf_logging.set_verbosity_error()
    hf_hub_utils.disable_progress_bars()

    dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16,
             "float32": torch.float32}[args.dtype]
    print(f"Loading {args.model} ({args.dtype} on {args.device})")
    tok = AutoTokenizer.from_pretrained(args.model)
    model = AutoModelForCausalLM.from_pretrained(
        args.model, torch_dtype=dtype
    ).to(args.device).eval()

    # Fresh cache per generation pass — caches are stateful and not reusable.
    def factory():
        return TurboQuantKVCache(
            bits=args.bits,
            bits_outlier=args.bits_outlier,
            n_outlier=args.n_outlier,
        )

    text_ref, text_tq, kls, cache, first_diverge, n_new = measure_logit_kl(
        model, tok, args.prompt, factory, max_new_tokens=args.max_new_tokens,
    )

    print("\n--- BASELINE (fp16, default DynamicCache) ---")
    print(text_ref)
    print("\n--- TURBOQUANT ---")
    print(text_tq)

    # Memory accounting comes from the cache used in the free-running pass.
    n_tokens = cache.get_seq_length(0)
    fp16_bytes = cache.fp16_baseline_bytes()
    actual_tq_bytes = cache.actual_memory_bytes()
    theoretical_tq_bytes = cache.theoretical_memory_bytes()

    print("\n--- METRICS ---")
    print(f"prompt + decode tokens : {n_tokens}")
    print(f"fp16 KV cache actual : {fp16_bytes / 1e6:.2f} MB")
    print(f"TurboQuant actual : {actual_tq_bytes / 1e6:.2f} MB "
          f"(× compression = {fp16_bytes / max(actual_tq_bytes, 1):.2f})")
    print(f"TurboQuant theoretical : {theoretical_tq_bytes / 1e6:.2f} MB")
    print(f"mean per-token logit KL: {sum(kls) / len(kls):.5f} (teacher-forced; lower is better)")
    print(f"max per-token logit KL: {max(kls):.5f}")
    print(f"first divergence at : {first_diverge}/{n_new} decode tokens "
          f"(higher is better; {n_new}/{n_new} = exact match)")
    print(f"identical generation? : {text_ref == text_tq}")
|
|
189
|
+
|
|
190
|
+
|
|
191
|
+
if __name__ == "__main__":
|
|
192
|
+
main()
|
demos/__init__.py
ADDED
|
File without changes
|
|
@@ -0,0 +1,75 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Demo 1 — Empirical distortion matches Shannon-theoretic bounds.
|
|
3
|
+
|
|
4
|
+
Story for the audience:
|
|
5
|
+
"We compress unit vectors to b bits per coordinate. The black dashed line
|
|
6
|
+
is the information-theoretic *lower* bound from Shannon source coding;
|
|
7
|
+
the orange line is what *any* possible algorithm could ever achieve
|
|
8
|
+
in the limit. TurboQuant (blue) sits within a small constant of optimum
|
|
9
|
+
for every bit-width, with zero tuning."
|
|
10
|
+
|
|
11
|
+
Run:
|
|
12
|
+
python -m demos.demo1_distortion_vs_theory
|
|
13
|
+
Outputs:
|
|
14
|
+
demos/out/distortion.png
|
|
15
|
+
"""
|
|
16
|
+
|
|
17
|
+
from __future__ import annotations
|
|
18
|
+
|
|
19
|
+
import os
|
|
20
|
+
import numpy as np
|
|
21
|
+
import matplotlib.pyplot as plt
|
|
22
|
+
from tiny_turboquant.numpy_reference import TurboQuantMSE, TurboQuantProd
|
|
23
|
+
|
|
24
|
+
OUT = os.path.join(os.path.dirname(__file__), "out")
|
|
25
|
+
os.makedirs(OUT, exist_ok=True)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def main() -> None:
    """Plot empirical MSE vs Shannon bounds, then print an inner-product bias table.

    Saves demos/out/distortion.png and writes a small text table to stdout.
    """
    rng = np.random.default_rng(0)
    d, n = 512, 4000
    # Unit vectors: the bounds below are stated for the unit sphere.
    X = rng.standard_normal((n, d))
    X /= np.linalg.norm(X, axis=1, keepdims=True)

    bits = [1, 2, 3, 4, 5, 6]
    mse_emp, mse_lb, mse_ub = [], [], []
    for b in bits:
        q = TurboQuantMSE(d, b, seed=0)
        # Round-trip: quantise then reconstruct, measure per-vector squared error.
        Xh = q.dequant(q.quant(X))
        mse_emp.append(float(np.mean(np.sum((X - Xh) ** 2, axis=1))))
        mse_lb.append(4.0 ** (-b))  # Shannon lower bound
        mse_ub.append(3 * np.pi / 2 * 4.0 ** (-b))  # paper Theorem 1

    fig, ax = plt.subplots(figsize=(7, 4.5))
    ax.plot(bits, mse_ub, "--", color="orange", label="Paper upper bound (3π/2·4⁻ᵇ)")
    ax.plot(bits, mse_lb, "--", color="black", label="Shannon lower bound (4⁻ᵇ)")
    ax.plot(bits, mse_emp, "o-", color="tab:blue", lw=2, label="TurboQuant empirical")
    ax.set_yscale("log")
    ax.set_xlabel("bits per coordinate")
    ax.set_ylabel("MSE E‖x − x̂‖²")
    ax.set_title(f"TurboQuant-MSE vs information-theoretic bounds (d={d}, n={n})")
    ax.grid(True, which="both", alpha=0.3)
    ax.legend()
    fig.tight_layout()
    out = os.path.join(OUT, "distortion.png")
    fig.savefig(out, dpi=140)
    print(f"saved {out}")

    # Also show inner-product unbiasedness — the *killer* feature for ANN/RAG.
    Y = rng.standard_normal((n, d)); Y /= np.linalg.norm(Y, axis=1, keepdims=True)
    print("\n bits | bias | ip-MSE | ratio vs full-precision")
    print(" -----+----------+-----------+--------------------------")
    for b in (2, 3, 4):
        qp = TurboQuantProd(d, b, seed=0)
        idx, s, g = qp.quant(X)
        Xh = qp.dequant(idx, s, g)
        # Compare true inner products <x, y> against quantised <x̂, y>.
        true = np.sum(X * Y, 1)
        est = np.sum(Xh * Y, 1)
        bias = float(np.mean(est - true))
        var = float(np.mean((est - true) ** 2))
        print(f" {b} | {bias:+.4f} | {var:.5f} | "
              f"{32/b:.0f}× smaller than fp32")
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
if __name__ == "__main__":
|
|
75
|
+
main()
|
demos/demo2_ann_vs_pq.py
ADDED
|
@@ -0,0 +1,133 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Demo 2 — ANN search: TurboQuant vs Product Quantization.
|
|
3
|
+
|
|
4
|
+
Story for the audience:
|
|
5
|
+
"We index 50k random embeddings and answer top-10 NN queries. TurboQuant
|
|
6
|
+
matches PQ's recall at the same bit budget while indexing in
|
|
7
|
+
~milliseconds — PQ needs to k-means-train per subspace."
|
|
8
|
+
|
|
9
|
+
Run:
|
|
10
|
+
python -m demos.demo2_ann_vs_pq
|
|
11
|
+
"""
|
|
12
|
+
|
|
13
|
+
from __future__ import annotations
|
|
14
|
+
|
|
15
|
+
import time
|
|
16
|
+
import numpy as np
|
|
17
|
+
|
|
18
|
+
from tiny_turboquant.numpy_reference import TurboQuantProd
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
# ----- toy product quantization baseline ----------------------------------
|
|
22
|
+
|
|
23
|
+
def pq_train(X: np.ndarray, n_sub: int, ks: int, seed: int = 0):
    """Train PQ codebooks by vectorised k-means per subspace.

    Returns (codebooks of shape (n_sub, ks, sub_dim), int32 codes of shape
    (n, n_sub), sub_dim). Ten fixed Lloyd iterations per subspace.
    """
    gen = np.random.default_rng(seed)
    n_rows, dim = X.shape
    sub_dim = dim // n_sub
    codebooks = np.empty((n_sub, ks, sub_dim))
    codes = np.empty((n_rows, n_sub), dtype=np.int32)

    for m in range(n_sub):
        block = X[:, m * sub_dim : (m + 1) * sub_dim]            # (n, sub_dim)
        # Initialise centres from ks distinct data rows.
        cents = block[gen.choice(n_rows, ks, replace=False)].copy()
        row_sq = (block * block).sum(1, keepdims=True)            # (n, 1), loop-invariant
        for _ in range(10):
            # ||a-b||^2 = ||a||^2 + ||b||^2 - 2 a·b, fully vectorised.
            cent_sq = (cents * cents).sum(1)                      # (ks,)
            assign = (row_sq + cent_sq[None, :] - 2.0 * block @ cents.T).argmin(1)
            # Mean of assigned points; empty clusters keep their old centre.
            sums = np.zeros_like(cents)
            np.add.at(sums, assign, block)
            counts = np.bincount(assign, minlength=ks)
            occupied = counts > 0
            updated = cents.copy()
            updated[occupied] = sums[occupied] / counts[occupied, None]
            cents = updated
        codebooks[m] = cents
        # Final assignment with the converged centres.
        cent_sq = (cents * cents).sum(1)
        codes[:, m] = (row_sq + cent_sq[None, :] - 2.0 * block @ cents.T).argmin(1)
    return codebooks, codes, sub_dim
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def pq_decode(codebooks: np.ndarray, codes: np.ndarray) -> np.ndarray:
    """Reconstruct vectors by concatenating each row's selected centroids."""
    n_sub, _ks, sub_dim = codebooks.shape
    pieces = [codebooks[m, codes[:, m]] for m in range(n_sub)]
    return np.concatenate(pieces, axis=1)
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
# ----- evaluation ---------------------------------------------------------
|
|
64
|
+
|
|
65
|
+
def recall_at_k(true_top: np.ndarray, est_top: np.ndarray) -> float:
    """Mean fraction of ground-truth neighbours present in the estimate."""
    matched = sum(len(set(t.tolist()) & set(e.tolist()))
                  for t, e in zip(true_top, est_top))
    return matched / true_top.size
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def main() -> None:
    """Head-to-head demo: TurboQuant vs toy PQ on clustered synthetic data.

    Prints indexing time, query time and recall@k for both methods.
    """
    rng = np.random.default_rng(0)
    n_db, n_q, d, k = 20_000, 200, 128, 10
    bits = 4

    # Use clustered data: real embeddings (BERT, OpenAI, etc.) are highly
    # structured, *not* i.i.d. Gaussian. On i.i.d. Gaussian PQ has a built-in
    # advantage because it gets to train on the same distribution. On
    # clustered/structured data, TurboQuant's data-oblivious random rotation
    # is competitive while keeping its huge indexing-time advantage.
    n_clusters = 50
    centers = rng.standard_normal((n_clusters, d))
    centers /= np.linalg.norm(centers, axis=1, keepdims=True)
    cluster_id = rng.integers(0, n_clusters, size=n_db)
    X = centers[cluster_id] + 0.25 * rng.standard_normal((n_db, d))
    X = X.astype(np.float64)
    X /= np.linalg.norm(X, axis=1, keepdims=True)

    # Queries drawn from the same cluster mixture, independently of X.
    q_cluster = rng.integers(0, n_clusters, size=n_q)
    Q = centers[q_cluster] + 0.25 * rng.standard_normal((n_q, d))
    Q = Q.astype(np.float64)
    Q /= np.linalg.norm(Q, axis=1, keepdims=True)

    # Ground truth
    true_top = np.argsort(-Q @ X.T, axis=1)[:, :k]

    # ----- TurboQuant -----
    t0 = time.perf_counter()
    tq = TurboQuantProd(d, bits=bits, seed=0)
    idx, signs, gamma = tq.quant(X)
    t_tq_index = time.perf_counter() - t0

    t0 = time.perf_counter()
    Xh = tq.dequant(idx, signs, gamma)
    est_top_tq = np.argsort(-Q @ Xh.T, axis=1)[:, :k]
    t_tq_query = time.perf_counter() - t0

    rec_tq = recall_at_k(true_top, est_top_tq)

    # ----- PQ at matched bit budget -----
    # bits/coord = log2(ks) / sub_dim => ks=256, sub_dim=2 gives 4 bits/coord
    n_sub, ks = d // 2, 256
    t0 = time.perf_counter()
    cb, codes, _ = pq_train(X, n_sub=n_sub, ks=ks, seed=0)
    t_pq_index = time.perf_counter() - t0

    t0 = time.perf_counter()
    Xh_pq = pq_decode(cb, codes)
    est_top_pq = np.argsort(-Q @ Xh_pq.T, axis=1)[:, :k]
    t_pq_query = time.perf_counter() - t0
    rec_pq = recall_at_k(true_top, est_top_pq)

    print(f"\nDataset: {n_db} unit vectors in R^{d}, top-{k} ANN, bit budget = {bits}\n")
    print(f"{'method':14s} | indexing(s) | query(s) | recall@{k}")
    print("-" * 56)
    print(f"{'TurboQuant':14s} | {t_tq_index:11.3f} | {t_tq_query:8.3f} | {rec_tq:.3f}")
    print(f"{'PQ (k-means)':14s} | {t_pq_index:11.3f} | {t_pq_query:8.3f} | {rec_pq:.3f}")
    print(f"\nIndexing speedup: {t_pq_index / t_tq_index:,.0f}× in TurboQuant's favour")
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
if __name__ == "__main__":
|
|
133
|
+
main()
|