vectlite 0.1.11 → 0.9.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +230 -1
- package/index.d.ts +55 -0
- package/index.js +171 -12
- package/native/Cargo.toml +1 -1
- package/native/src/lib.rs +733 -25
- package/native/vectlite-core/Cargo.toml +2 -1
- package/native/vectlite-core/src/lib.rs +6092 -1990
- package/native/vectlite-core/src/quantization.rs +1587 -0
- package/package.json +1 -1
- package/prebuilds/darwin-arm64/vectlite.node +0 -0
- package/prebuilds/darwin-x64/vectlite.node +0 -0
- package/prebuilds/linux-x64-gnu/vectlite.node +0 -0
- package/prebuilds/win32-x64-msvc/vectlite.node +0 -0
|
@@ -0,0 +1,1587 @@
|
|
|
1
|
+
//! Vector quantization module for memory-efficient similarity search.
|
|
2
|
+
//!
|
|
3
|
+
//! Supports three single-vector quantization strategies:
//! - **Scalar (8-bit)**: 4x memory reduction with minimal recall loss
//! - **Binary**: 32x memory reduction, uses Hamming distance for fast filtering
//! - **Product Quantization (PQ)**: Configurable compression for very large datasets
//!
//! plus a **two-bit (ColBERTv2-style)** quantizer (16x memory reduction) for
//! multi-vector (late-interaction) token embeddings.
|
|
7
|
+
//!
|
|
8
|
+
//! All strategies support a 2-stage pipeline: fast quantized search followed by
|
|
9
|
+
//! exact float32 rescoring of top candidates.
|
|
10
|
+
|
|
11
|
+
use std::io::{Read, Write};
|
|
12
|
+
|
|
13
|
+
// ---------------------------------------------------------------------------
|
|
14
|
+
// Public types
|
|
15
|
+
// ---------------------------------------------------------------------------
|
|
16
|
+
|
|
17
|
+
/// Configuration for enabling quantization on a database.
|
|
18
|
+
#[derive(Clone, Debug, PartialEq)]
|
|
19
|
+
pub enum QuantizationConfig {
|
|
20
|
+
/// Scalar quantization: maps each f32 dimension to int8 using per-dimension
|
|
21
|
+
/// min/max calibration. 4x memory reduction.
|
|
22
|
+
Scalar(ScalarQuantizationConfig),
|
|
23
|
+
/// Binary quantization: maps each f32 dimension to a single bit.
|
|
24
|
+
/// 32x memory reduction. Best for high-dimensional normalized embeddings.
|
|
25
|
+
Binary(BinaryQuantizationConfig),
|
|
26
|
+
/// Product quantization: splits vector into sub-vectors and quantizes each
|
|
27
|
+
/// to a centroid index. Highest compression for large datasets.
|
|
28
|
+
Product(ProductQuantizationConfig),
|
|
29
|
+
}
|
|
30
|
+
|
|
31
|
+
/// Tuning knobs for [`ScalarQuantizer`].
#[derive(Debug, Clone, PartialEq)]
pub struct ScalarQuantizationConfig {
    /// Multiplier on `top_k` that decides how many quantized candidates are kept
    /// for exact float32 rescoring. `ScalarQuantizer::search` keeps at least 100
    /// candidates, capped at the number of stored vectors. Default: 5.
    pub rescore_multiplier: usize,
}
|
|
37
|
+
|
|
38
|
+
impl Default for ScalarQuantizationConfig {
|
|
39
|
+
fn default() -> Self {
|
|
40
|
+
Self {
|
|
41
|
+
rescore_multiplier: 5,
|
|
42
|
+
}
|
|
43
|
+
}
|
|
44
|
+
}
|
|
45
|
+
|
|
46
|
+
/// Tuning knobs for [`BinaryQuantizer`].
#[derive(Debug, Clone, PartialEq)]
pub struct BinaryQuantizationConfig {
    /// Multiplier on `top_k` that decides how many Hamming-ranked candidates are
    /// kept for exact float32 rescoring. `BinaryQuantizer::search` keeps at least
    /// 100 candidates, capped at the number of stored vectors. Default: 10.
    pub rescore_multiplier: usize,
}
|
|
52
|
+
|
|
53
|
+
impl Default for BinaryQuantizationConfig {
|
|
54
|
+
fn default() -> Self {
|
|
55
|
+
Self {
|
|
56
|
+
rescore_multiplier: 10,
|
|
57
|
+
}
|
|
58
|
+
}
|
|
59
|
+
}
|
|
60
|
+
|
|
61
|
+
/// Tuning knobs for [`ProductQuantizer`].
#[derive(Debug, Clone, PartialEq)]
pub struct ProductQuantizationConfig {
    /// How many equal-width sub-vectors (sub-spaces) each vector is split into.
    /// The vector dimension must be a multiple of this value.
    /// Typical choices: 8, 16, 32, 64.
    pub num_sub_vectors: usize,
    /// Centroids learned per sub-space (the `k` of k-means). At most 256, so
    /// every code fits in a `u8`. Default: 256.
    pub num_centroids: usize,
    /// Iterations of k-means run when training each sub-space codebook.
    pub training_iterations: usize,
    /// Multiplier on `top_k` that decides how many PQ candidates are kept for
    /// exact float32 rescoring (at least 100, capped at the stored count).
    pub rescore_multiplier: usize,
}
|
|
74
|
+
|
|
75
|
+
impl Default for ProductQuantizationConfig {
|
|
76
|
+
fn default() -> Self {
|
|
77
|
+
Self {
|
|
78
|
+
num_sub_vectors: 16,
|
|
79
|
+
num_centroids: 256,
|
|
80
|
+
training_iterations: 20,
|
|
81
|
+
rescore_multiplier: 10,
|
|
82
|
+
}
|
|
83
|
+
}
|
|
84
|
+
}
|
|
85
|
+
|
|
86
|
+
// ---------------------------------------------------------------------------
|
|
87
|
+
// Scalar Quantization
|
|
88
|
+
// ---------------------------------------------------------------------------
|
|
89
|
+
|
|
90
|
+
/// Calibration parameters for scalar quantization (per-dimension min/max).
|
|
91
|
+
#[derive(Clone, Debug)]
|
|
92
|
+
pub struct ScalarQuantizer {
|
|
93
|
+
pub dimension: usize,
|
|
94
|
+
/// Per-dimension minimum values used for calibration.
|
|
95
|
+
pub mins: Vec<f32>,
|
|
96
|
+
/// Per-dimension maximum values used for calibration.
|
|
97
|
+
pub maxs: Vec<f32>,
|
|
98
|
+
/// Per-dimension scale: 255.0 / (max - min). Pre-computed for fast quantization.
|
|
99
|
+
scales: Vec<f32>,
|
|
100
|
+
/// Quantized vectors stored as flat u8 array (n * dimension).
|
|
101
|
+
pub codes: Vec<u8>,
|
|
102
|
+
/// Number of quantized vectors.
|
|
103
|
+
pub count: usize,
|
|
104
|
+
pub config: ScalarQuantizationConfig,
|
|
105
|
+
}
|
|
106
|
+
|
|
107
|
+
impl ScalarQuantizer {
|
|
108
|
+
/// Train a scalar quantizer by computing per-dimension min/max from training vectors.
|
|
109
|
+
pub fn train(vectors: &[&[f32]], dimension: usize, config: ScalarQuantizationConfig) -> Self {
|
|
110
|
+
assert!(!vectors.is_empty(), "need at least one vector to train");
|
|
111
|
+
assert!(vectors[0].len() == dimension);
|
|
112
|
+
|
|
113
|
+
let mut mins = vec![f32::INFINITY; dimension];
|
|
114
|
+
let mut maxs = vec![f32::NEG_INFINITY; dimension];
|
|
115
|
+
|
|
116
|
+
for vector in vectors {
|
|
117
|
+
for (i, &val) in vector.iter().enumerate() {
|
|
118
|
+
if val < mins[i] {
|
|
119
|
+
mins[i] = val;
|
|
120
|
+
}
|
|
121
|
+
if val > maxs[i] {
|
|
122
|
+
maxs[i] = val;
|
|
123
|
+
}
|
|
124
|
+
}
|
|
125
|
+
}
|
|
126
|
+
|
|
127
|
+
let scales: Vec<f32> = mins
|
|
128
|
+
.iter()
|
|
129
|
+
.zip(maxs.iter())
|
|
130
|
+
.map(|(&min, &max)| {
|
|
131
|
+
let range = max - min;
|
|
132
|
+
if range < 1e-10 { 0.0 } else { 255.0 / range }
|
|
133
|
+
})
|
|
134
|
+
.collect();
|
|
135
|
+
|
|
136
|
+
let mut codes = Vec::with_capacity(vectors.len() * dimension);
|
|
137
|
+
for vector in vectors {
|
|
138
|
+
for (i, &val) in vector.iter().enumerate() {
|
|
139
|
+
codes.push(quantize_scalar(val, mins[i], scales[i]));
|
|
140
|
+
}
|
|
141
|
+
}
|
|
142
|
+
|
|
143
|
+
Self {
|
|
144
|
+
dimension,
|
|
145
|
+
mins,
|
|
146
|
+
maxs,
|
|
147
|
+
scales,
|
|
148
|
+
codes,
|
|
149
|
+
count: vectors.len(),
|
|
150
|
+
config,
|
|
151
|
+
}
|
|
152
|
+
}
|
|
153
|
+
|
|
154
|
+
/// Add vectors to the quantized index (after initial training).
|
|
155
|
+
pub fn add_vectors(&mut self, vectors: &[&[f32]]) {
|
|
156
|
+
for vector in vectors {
|
|
157
|
+
assert_eq!(vector.len(), self.dimension);
|
|
158
|
+
for (i, &val) in vector.iter().enumerate() {
|
|
159
|
+
self.codes
|
|
160
|
+
.push(quantize_scalar(val, self.mins[i], self.scales[i]));
|
|
161
|
+
}
|
|
162
|
+
}
|
|
163
|
+
self.count += vectors.len();
|
|
164
|
+
}
|
|
165
|
+
|
|
166
|
+
/// Quantize a single query vector.
|
|
167
|
+
pub fn quantize_query(&self, query: &[f32]) -> Vec<u8> {
|
|
168
|
+
assert_eq!(query.len(), self.dimension);
|
|
169
|
+
query
|
|
170
|
+
.iter()
|
|
171
|
+
.enumerate()
|
|
172
|
+
.map(|(i, &val)| quantize_scalar(val, self.mins[i], self.scales[i]))
|
|
173
|
+
.collect()
|
|
174
|
+
}
|
|
175
|
+
|
|
176
|
+
/// Compute approximate cosine distance between a quantized query and all stored vectors.
|
|
177
|
+
/// Returns indices sorted by approximate similarity (best first).
|
|
178
|
+
pub fn search(&self, query: &[f32], top_k: usize) -> Vec<(usize, f32)> {
|
|
179
|
+
let rescore_count = (top_k * self.config.rescore_multiplier)
|
|
180
|
+
.max(100)
|
|
181
|
+
.min(self.count);
|
|
182
|
+
let query_quantized = self.quantize_query(query);
|
|
183
|
+
let mut scores: Vec<(usize, f32)> = (0..self.count)
|
|
184
|
+
.map(|idx| {
|
|
185
|
+
let offset = idx * self.dimension;
|
|
186
|
+
let code_slice = &self.codes[offset..offset + self.dimension];
|
|
187
|
+
let sim = scalar_quantized_dot(&query_quantized, code_slice);
|
|
188
|
+
(idx, sim)
|
|
189
|
+
})
|
|
190
|
+
.collect();
|
|
191
|
+
|
|
192
|
+
// Partial sort: get top rescore_count candidates
|
|
193
|
+
scores.sort_unstable_by(|a, b| b.1.total_cmp(&a.1));
|
|
194
|
+
scores.truncate(rescore_count);
|
|
195
|
+
scores
|
|
196
|
+
}
|
|
197
|
+
|
|
198
|
+
/// Rebuild codes from training vectors (used after deserialization with new vectors).
|
|
199
|
+
pub fn rebuild_codes(&mut self, vectors: &[&[f32]]) {
|
|
200
|
+
self.codes.clear();
|
|
201
|
+
self.codes.reserve(vectors.len() * self.dimension);
|
|
202
|
+
for vector in vectors {
|
|
203
|
+
for (i, &val) in vector.iter().enumerate() {
|
|
204
|
+
self.codes
|
|
205
|
+
.push(quantize_scalar(val, self.mins[i], self.scales[i]));
|
|
206
|
+
}
|
|
207
|
+
}
|
|
208
|
+
self.count = vectors.len();
|
|
209
|
+
}
|
|
210
|
+
|
|
211
|
+
/// Serialize the quantizer parameters (not the codes, which are rebuilt on load).
|
|
212
|
+
pub fn write_params(&self, writer: &mut impl Write) -> std::io::Result<()> {
|
|
213
|
+
// Tag byte: 1 = scalar
|
|
214
|
+
writer.write_all(&[1u8])?;
|
|
215
|
+
write_usize(writer, self.dimension)?;
|
|
216
|
+
write_usize(writer, self.config.rescore_multiplier)?;
|
|
217
|
+
for &v in &self.mins {
|
|
218
|
+
writer.write_all(&v.to_le_bytes())?;
|
|
219
|
+
}
|
|
220
|
+
for &v in &self.maxs {
|
|
221
|
+
writer.write_all(&v.to_le_bytes())?;
|
|
222
|
+
}
|
|
223
|
+
Ok(())
|
|
224
|
+
}
|
|
225
|
+
|
|
226
|
+
/// Deserialize quantizer parameters.
|
|
227
|
+
pub fn read_params(reader: &mut impl Read) -> std::io::Result<Self> {
|
|
228
|
+
let dimension = read_usize(reader)?;
|
|
229
|
+
let rescore_multiplier = read_usize(reader)?;
|
|
230
|
+
let mut mins = vec![0.0_f32; dimension];
|
|
231
|
+
let mut maxs = vec![0.0_f32; dimension];
|
|
232
|
+
for v in &mut mins {
|
|
233
|
+
let mut buf = [0u8; 4];
|
|
234
|
+
reader.read_exact(&mut buf)?;
|
|
235
|
+
*v = f32::from_le_bytes(buf);
|
|
236
|
+
}
|
|
237
|
+
for v in &mut maxs {
|
|
238
|
+
let mut buf = [0u8; 4];
|
|
239
|
+
reader.read_exact(&mut buf)?;
|
|
240
|
+
*v = f32::from_le_bytes(buf);
|
|
241
|
+
}
|
|
242
|
+
|
|
243
|
+
let scales: Vec<f32> = mins
|
|
244
|
+
.iter()
|
|
245
|
+
.zip(maxs.iter())
|
|
246
|
+
.map(|(&min, &max)| {
|
|
247
|
+
let range = max - min;
|
|
248
|
+
if range < 1e-10 { 0.0 } else { 255.0 / range }
|
|
249
|
+
})
|
|
250
|
+
.collect();
|
|
251
|
+
|
|
252
|
+
Ok(Self {
|
|
253
|
+
dimension,
|
|
254
|
+
mins,
|
|
255
|
+
maxs,
|
|
256
|
+
scales,
|
|
257
|
+
codes: Vec::new(),
|
|
258
|
+
count: 0,
|
|
259
|
+
config: ScalarQuantizationConfig { rescore_multiplier },
|
|
260
|
+
})
|
|
261
|
+
}
|
|
262
|
+
}
|
|
263
|
+
|
|
264
|
+
// ---------------------------------------------------------------------------
|
|
265
|
+
// Binary Quantization
|
|
266
|
+
// ---------------------------------------------------------------------------
|
|
267
|
+
|
|
268
|
+
/// Binary quantizer: each dimension is mapped to a single bit (sign of the value).
|
|
269
|
+
/// Uses Hamming distance for fast candidate selection.
|
|
270
|
+
#[derive(Clone, Debug)]
|
|
271
|
+
pub struct BinaryQuantizer {
|
|
272
|
+
pub dimension: usize,
|
|
273
|
+
/// Number of bytes per vector: ceil(dimension / 8).
|
|
274
|
+
pub bytes_per_vector: usize,
|
|
275
|
+
/// Binary codes stored as flat byte array (n * bytes_per_vector).
|
|
276
|
+
pub codes: Vec<u8>,
|
|
277
|
+
/// Number of quantized vectors.
|
|
278
|
+
pub count: usize,
|
|
279
|
+
pub config: BinaryQuantizationConfig,
|
|
280
|
+
}
|
|
281
|
+
|
|
282
|
+
impl BinaryQuantizer {
|
|
283
|
+
/// Create a binary quantizer for vectors of the given dimension.
|
|
284
|
+
pub fn new(dimension: usize, config: BinaryQuantizationConfig) -> Self {
|
|
285
|
+
let bytes_per_vector = (dimension + 7) / 8;
|
|
286
|
+
Self {
|
|
287
|
+
dimension,
|
|
288
|
+
bytes_per_vector,
|
|
289
|
+
codes: Vec::new(),
|
|
290
|
+
count: 0,
|
|
291
|
+
config,
|
|
292
|
+
}
|
|
293
|
+
}
|
|
294
|
+
|
|
295
|
+
/// Binarize vectors and add to the index.
|
|
296
|
+
pub fn add_vectors(&mut self, vectors: &[&[f32]]) {
|
|
297
|
+
for vector in vectors {
|
|
298
|
+
assert_eq!(vector.len(), self.dimension);
|
|
299
|
+
let binary = binarize_vector(vector);
|
|
300
|
+
self.codes.extend_from_slice(&binary);
|
|
301
|
+
}
|
|
302
|
+
self.count += vectors.len();
|
|
303
|
+
}
|
|
304
|
+
|
|
305
|
+
/// Binarize a query vector.
|
|
306
|
+
pub fn binarize_query(&self, query: &[f32]) -> Vec<u8> {
|
|
307
|
+
assert_eq!(query.len(), self.dimension);
|
|
308
|
+
binarize_vector(query)
|
|
309
|
+
}
|
|
310
|
+
|
|
311
|
+
/// Search using Hamming distance. Returns candidate indices sorted by
|
|
312
|
+
/// Hamming similarity (fewest differing bits first).
|
|
313
|
+
pub fn search(&self, query: &[f32], top_k: usize) -> Vec<(usize, u32)> {
|
|
314
|
+
let rescore_count = (top_k * self.config.rescore_multiplier)
|
|
315
|
+
.max(100)
|
|
316
|
+
.min(self.count);
|
|
317
|
+
let query_binary = self.binarize_query(query);
|
|
318
|
+
let mut distances: Vec<(usize, u32)> = (0..self.count)
|
|
319
|
+
.map(|idx| {
|
|
320
|
+
let offset = idx * self.bytes_per_vector;
|
|
321
|
+
let code_slice = &self.codes[offset..offset + self.bytes_per_vector];
|
|
322
|
+
let dist = hamming_distance(&query_binary, code_slice);
|
|
323
|
+
(idx, dist)
|
|
324
|
+
})
|
|
325
|
+
.collect();
|
|
326
|
+
|
|
327
|
+
// Sort by Hamming distance (ascending = most similar first)
|
|
328
|
+
distances.sort_unstable_by_key(|&(_, d)| d);
|
|
329
|
+
distances.truncate(rescore_count);
|
|
330
|
+
distances
|
|
331
|
+
}
|
|
332
|
+
|
|
333
|
+
/// Rebuild codes from vectors.
|
|
334
|
+
pub fn rebuild_codes(&mut self, vectors: &[&[f32]]) {
|
|
335
|
+
self.codes.clear();
|
|
336
|
+
self.codes.reserve(vectors.len() * self.bytes_per_vector);
|
|
337
|
+
for vector in vectors {
|
|
338
|
+
let binary = binarize_vector(vector);
|
|
339
|
+
self.codes.extend_from_slice(&binary);
|
|
340
|
+
}
|
|
341
|
+
self.count = vectors.len();
|
|
342
|
+
}
|
|
343
|
+
|
|
344
|
+
/// Serialize parameters.
|
|
345
|
+
pub fn write_params(&self, writer: &mut impl Write) -> std::io::Result<()> {
|
|
346
|
+
// Tag byte: 2 = binary
|
|
347
|
+
writer.write_all(&[2u8])?;
|
|
348
|
+
write_usize(writer, self.dimension)?;
|
|
349
|
+
write_usize(writer, self.config.rescore_multiplier)?;
|
|
350
|
+
Ok(())
|
|
351
|
+
}
|
|
352
|
+
|
|
353
|
+
/// Deserialize parameters.
|
|
354
|
+
pub fn read_params(reader: &mut impl Read) -> std::io::Result<Self> {
|
|
355
|
+
let dimension = read_usize(reader)?;
|
|
356
|
+
let rescore_multiplier = read_usize(reader)?;
|
|
357
|
+
let bytes_per_vector = (dimension + 7) / 8;
|
|
358
|
+
Ok(Self {
|
|
359
|
+
dimension,
|
|
360
|
+
bytes_per_vector,
|
|
361
|
+
codes: Vec::new(),
|
|
362
|
+
count: 0,
|
|
363
|
+
config: BinaryQuantizationConfig { rescore_multiplier },
|
|
364
|
+
})
|
|
365
|
+
}
|
|
366
|
+
}
|
|
367
|
+
|
|
368
|
+
// ---------------------------------------------------------------------------
|
|
369
|
+
// Product Quantization
|
|
370
|
+
// ---------------------------------------------------------------------------
|
|
371
|
+
|
|
372
|
+
/// Product quantizer: divides vector into sub-vectors and maps each to a centroid.
|
|
373
|
+
#[derive(Clone, Debug)]
|
|
374
|
+
pub struct ProductQuantizer {
|
|
375
|
+
pub dimension: usize,
|
|
376
|
+
pub num_sub_vectors: usize,
|
|
377
|
+
pub sub_dimension: usize,
|
|
378
|
+
pub num_centroids: usize,
|
|
379
|
+
/// Codebooks: shape [num_sub_vectors][num_centroids][sub_dimension].
|
|
380
|
+
pub codebooks: Vec<Vec<Vec<f32>>>,
|
|
381
|
+
/// PQ codes: flat array of (n * num_sub_vectors) u8 indices.
|
|
382
|
+
pub codes: Vec<u8>,
|
|
383
|
+
/// Number of quantized vectors.
|
|
384
|
+
pub count: usize,
|
|
385
|
+
pub config: ProductQuantizationConfig,
|
|
386
|
+
}
|
|
387
|
+
|
|
388
|
+
impl ProductQuantizer {
|
|
389
|
+
/// Train a product quantizer using k-means on sub-vectors.
|
|
390
|
+
pub fn train(vectors: &[&[f32]], dimension: usize, config: ProductQuantizationConfig) -> Self {
|
|
391
|
+
assert!(!vectors.is_empty(), "need at least one vector to train PQ");
|
|
392
|
+
assert!(
|
|
393
|
+
dimension % config.num_sub_vectors == 0,
|
|
394
|
+
"dimension ({dimension}) must be divisible by num_sub_vectors ({})",
|
|
395
|
+
config.num_sub_vectors
|
|
396
|
+
);
|
|
397
|
+
assert!(config.num_centroids <= 256, "num_centroids must be <= 256");
|
|
398
|
+
|
|
399
|
+
let sub_dimension = dimension / config.num_sub_vectors;
|
|
400
|
+
let mut codebooks = Vec::with_capacity(config.num_sub_vectors);
|
|
401
|
+
|
|
402
|
+
for sub_idx in 0..config.num_sub_vectors {
|
|
403
|
+
let offset = sub_idx * sub_dimension;
|
|
404
|
+
// Extract sub-vectors for this partition
|
|
405
|
+
let sub_vectors: Vec<&[f32]> = vectors
|
|
406
|
+
.iter()
|
|
407
|
+
.map(|v| &v[offset..offset + sub_dimension])
|
|
408
|
+
.collect();
|
|
409
|
+
|
|
410
|
+
let centroids = kmeans(
|
|
411
|
+
&sub_vectors,
|
|
412
|
+
sub_dimension,
|
|
413
|
+
config.num_centroids,
|
|
414
|
+
config.training_iterations,
|
|
415
|
+
);
|
|
416
|
+
codebooks.push(centroids);
|
|
417
|
+
}
|
|
418
|
+
|
|
419
|
+
// Encode all training vectors
|
|
420
|
+
let mut codes = Vec::with_capacity(vectors.len() * config.num_sub_vectors);
|
|
421
|
+
for vector in vectors {
|
|
422
|
+
for sub_idx in 0..config.num_sub_vectors {
|
|
423
|
+
let offset = sub_idx * sub_dimension;
|
|
424
|
+
let sub_vector = &vector[offset..offset + sub_dimension];
|
|
425
|
+
let nearest = find_nearest_centroid(sub_vector, &codebooks[sub_idx]);
|
|
426
|
+
codes.push(nearest as u8);
|
|
427
|
+
}
|
|
428
|
+
}
|
|
429
|
+
|
|
430
|
+
Self {
|
|
431
|
+
dimension,
|
|
432
|
+
num_sub_vectors: config.num_sub_vectors,
|
|
433
|
+
sub_dimension,
|
|
434
|
+
num_centroids: config.num_centroids,
|
|
435
|
+
codebooks,
|
|
436
|
+
codes,
|
|
437
|
+
count: vectors.len(),
|
|
438
|
+
config,
|
|
439
|
+
}
|
|
440
|
+
}
|
|
441
|
+
|
|
442
|
+
/// Add vectors to the PQ index.
|
|
443
|
+
pub fn add_vectors(&mut self, vectors: &[&[f32]]) {
|
|
444
|
+
for vector in vectors {
|
|
445
|
+
assert_eq!(vector.len(), self.dimension);
|
|
446
|
+
for sub_idx in 0..self.num_sub_vectors {
|
|
447
|
+
let offset = sub_idx * self.sub_dimension;
|
|
448
|
+
let sub_vector = &vector[offset..offset + self.sub_dimension];
|
|
449
|
+
let nearest = find_nearest_centroid(sub_vector, &self.codebooks[sub_idx]);
|
|
450
|
+
self.codes.push(nearest as u8);
|
|
451
|
+
}
|
|
452
|
+
}
|
|
453
|
+
self.count += vectors.len();
|
|
454
|
+
}
|
|
455
|
+
|
|
456
|
+
/// Compute asymmetric distance table for a query. This precomputes distances
|
|
457
|
+
/// from the query sub-vectors to all centroids, enabling fast approximate
|
|
458
|
+
/// distance computation.
|
|
459
|
+
pub fn compute_distance_table(&self, query: &[f32]) -> Vec<Vec<f32>> {
|
|
460
|
+
assert_eq!(query.len(), self.dimension);
|
|
461
|
+
let mut table = Vec::with_capacity(self.num_sub_vectors);
|
|
462
|
+
|
|
463
|
+
for sub_idx in 0..self.num_sub_vectors {
|
|
464
|
+
let offset = sub_idx * self.sub_dimension;
|
|
465
|
+
let query_sub = &query[offset..offset + self.sub_dimension];
|
|
466
|
+
let distances: Vec<f32> = self.codebooks[sub_idx]
|
|
467
|
+
.iter()
|
|
468
|
+
.map(|centroid| l2_distance_sq(query_sub, centroid))
|
|
469
|
+
.collect();
|
|
470
|
+
table.push(distances);
|
|
471
|
+
}
|
|
472
|
+
|
|
473
|
+
table
|
|
474
|
+
}
|
|
475
|
+
|
|
476
|
+
/// Search using asymmetric distance computation (ADC).
|
|
477
|
+
/// Returns candidate indices sorted by approximate L2 distance.
|
|
478
|
+
pub fn search(&self, query: &[f32], top_k: usize) -> Vec<(usize, f32)> {
|
|
479
|
+
let rescore_count = (top_k * self.config.rescore_multiplier)
|
|
480
|
+
.max(100)
|
|
481
|
+
.min(self.count);
|
|
482
|
+
let distance_table = self.compute_distance_table(query);
|
|
483
|
+
|
|
484
|
+
let mut distances: Vec<(usize, f32)> = (0..self.count)
|
|
485
|
+
.map(|idx| {
|
|
486
|
+
let code_offset = idx * self.num_sub_vectors;
|
|
487
|
+
let mut dist = 0.0_f32;
|
|
488
|
+
for sub_idx in 0..self.num_sub_vectors {
|
|
489
|
+
let centroid_idx = self.codes[code_offset + sub_idx] as usize;
|
|
490
|
+
dist += distance_table[sub_idx][centroid_idx];
|
|
491
|
+
}
|
|
492
|
+
(idx, dist)
|
|
493
|
+
})
|
|
494
|
+
.collect();
|
|
495
|
+
|
|
496
|
+
// Sort by distance (ascending)
|
|
497
|
+
distances.sort_unstable_by(|a, b| a.1.total_cmp(&b.1));
|
|
498
|
+
distances.truncate(rescore_count);
|
|
499
|
+
distances
|
|
500
|
+
}
|
|
501
|
+
|
|
502
|
+
/// Rebuild codes from vectors.
|
|
503
|
+
pub fn rebuild_codes(&mut self, vectors: &[&[f32]]) {
|
|
504
|
+
self.codes.clear();
|
|
505
|
+
self.codes.reserve(vectors.len() * self.num_sub_vectors);
|
|
506
|
+
for vector in vectors {
|
|
507
|
+
for sub_idx in 0..self.num_sub_vectors {
|
|
508
|
+
let offset = sub_idx * self.sub_dimension;
|
|
509
|
+
let sub_vector = &vector[offset..offset + self.sub_dimension];
|
|
510
|
+
let nearest = find_nearest_centroid(sub_vector, &self.codebooks[sub_idx]);
|
|
511
|
+
self.codes.push(nearest as u8);
|
|
512
|
+
}
|
|
513
|
+
}
|
|
514
|
+
self.count = vectors.len();
|
|
515
|
+
}
|
|
516
|
+
|
|
517
|
+
/// Serialize the codebooks and parameters.
|
|
518
|
+
pub fn write_params(&self, writer: &mut impl Write) -> std::io::Result<()> {
|
|
519
|
+
// Tag byte: 3 = product
|
|
520
|
+
writer.write_all(&[3u8])?;
|
|
521
|
+
write_usize(writer, self.dimension)?;
|
|
522
|
+
write_usize(writer, self.num_sub_vectors)?;
|
|
523
|
+
write_usize(writer, self.num_centroids)?;
|
|
524
|
+
write_usize(writer, self.config.training_iterations)?;
|
|
525
|
+
write_usize(writer, self.config.rescore_multiplier)?;
|
|
526
|
+
|
|
527
|
+
// Write codebooks
|
|
528
|
+
for sub_codebook in &self.codebooks {
|
|
529
|
+
for centroid in sub_codebook {
|
|
530
|
+
for &val in centroid {
|
|
531
|
+
writer.write_all(&val.to_le_bytes())?;
|
|
532
|
+
}
|
|
533
|
+
}
|
|
534
|
+
}
|
|
535
|
+
Ok(())
|
|
536
|
+
}
|
|
537
|
+
|
|
538
|
+
/// Deserialize codebooks and parameters.
|
|
539
|
+
pub fn read_params(reader: &mut impl Read) -> std::io::Result<Self> {
|
|
540
|
+
let dimension = read_usize(reader)?;
|
|
541
|
+
let num_sub_vectors = read_usize(reader)?;
|
|
542
|
+
let num_centroids = read_usize(reader)?;
|
|
543
|
+
let training_iterations = read_usize(reader)?;
|
|
544
|
+
let rescore_multiplier = read_usize(reader)?;
|
|
545
|
+
let sub_dimension = dimension / num_sub_vectors;
|
|
546
|
+
|
|
547
|
+
// Read codebooks
|
|
548
|
+
let mut codebooks = Vec::with_capacity(num_sub_vectors);
|
|
549
|
+
for _ in 0..num_sub_vectors {
|
|
550
|
+
let mut sub_codebook = Vec::with_capacity(num_centroids);
|
|
551
|
+
for _ in 0..num_centroids {
|
|
552
|
+
let mut centroid = vec![0.0_f32; sub_dimension];
|
|
553
|
+
for v in &mut centroid {
|
|
554
|
+
let mut buf = [0u8; 4];
|
|
555
|
+
reader.read_exact(&mut buf)?;
|
|
556
|
+
*v = f32::from_le_bytes(buf);
|
|
557
|
+
}
|
|
558
|
+
sub_codebook.push(centroid);
|
|
559
|
+
}
|
|
560
|
+
codebooks.push(sub_codebook);
|
|
561
|
+
}
|
|
562
|
+
|
|
563
|
+
Ok(Self {
|
|
564
|
+
dimension,
|
|
565
|
+
num_sub_vectors,
|
|
566
|
+
sub_dimension,
|
|
567
|
+
num_centroids,
|
|
568
|
+
codebooks,
|
|
569
|
+
codes: Vec::new(),
|
|
570
|
+
count: 0,
|
|
571
|
+
config: ProductQuantizationConfig {
|
|
572
|
+
num_sub_vectors,
|
|
573
|
+
num_centroids,
|
|
574
|
+
training_iterations,
|
|
575
|
+
rescore_multiplier,
|
|
576
|
+
},
|
|
577
|
+
})
|
|
578
|
+
}
|
|
579
|
+
}
|
|
580
|
+
|
|
581
|
+
// ---------------------------------------------------------------------------
|
|
582
|
+
// Two-Bit Quantization (ColBERTv2-style)
|
|
583
|
+
// ---------------------------------------------------------------------------
|
|
584
|
+
|
|
585
|
+
/// Tuning knobs for 2-bit multi-vector (ColBERTv2-style) quantization.
#[derive(Debug, Clone, PartialEq)]
pub struct TwoBitQuantizationConfig {
    /// Multiplier on `top_k` that decides how many quantized candidates are kept
    /// for exact float32 rescoring (MaxSim in multi-vector search).
    /// `TwoBitQuantizer::search` keeps at least 50, capped at the stored count.
    /// Default: 4.
    pub rescore_multiplier: usize,
}
|
|
592
|
+
|
|
593
|
+
impl Default for TwoBitQuantizationConfig {
|
|
594
|
+
fn default() -> Self {
|
|
595
|
+
Self {
|
|
596
|
+
rescore_multiplier: 4,
|
|
597
|
+
}
|
|
598
|
+
}
|
|
599
|
+
}
|
|
600
|
+
|
|
601
|
+
/// Two-bit quantizer: maps each dimension to 2 bits (4 levels) using
|
|
602
|
+
/// per-dimension quartile boundaries. ~16x compression vs float32.
|
|
603
|
+
/// Designed for ColBERT-style token-level vectors.
|
|
604
|
+
#[derive(Clone, Debug)]
|
|
605
|
+
pub struct TwoBitQuantizer {
|
|
606
|
+
pub dimension: usize,
|
|
607
|
+
/// Per-dimension boundary values: [q25, q50, q75] for each dimension.
|
|
608
|
+
/// Shape: dimension * 3.
|
|
609
|
+
pub boundaries: Vec<f32>,
|
|
610
|
+
/// Quantized codes: 2 bits per dimension, packed into bytes.
|
|
611
|
+
/// Each vector uses ceil(dimension / 4) bytes.
|
|
612
|
+
pub codes: Vec<u8>,
|
|
613
|
+
/// Number of quantized vectors.
|
|
614
|
+
pub count: usize,
|
|
615
|
+
/// Bytes per quantized vector.
|
|
616
|
+
pub bytes_per_vector: usize,
|
|
617
|
+
pub config: TwoBitQuantizationConfig,
|
|
618
|
+
}
|
|
619
|
+
|
|
620
|
+
impl TwoBitQuantizer {
|
|
621
|
+
/// Train a 2-bit quantizer by computing per-dimension quartiles.
|
|
622
|
+
pub fn train(
|
|
623
|
+
vectors: &[&[f32]],
|
|
624
|
+
dimension: usize,
|
|
625
|
+
config: TwoBitQuantizationConfig,
|
|
626
|
+
) -> Self {
|
|
627
|
+
assert!(!vectors.is_empty(), "need at least one vector to train");
|
|
628
|
+
|
|
629
|
+
// Collect values per dimension and compute quartile boundaries
|
|
630
|
+
let mut boundaries = Vec::with_capacity(dimension * 3);
|
|
631
|
+
for d in 0..dimension {
|
|
632
|
+
let mut values: Vec<f32> = vectors.iter().map(|v| v[d]).collect();
|
|
633
|
+
values.sort_unstable_by(|a, b| a.total_cmp(b));
|
|
634
|
+
let n = values.len();
|
|
635
|
+
let q25 = values[n / 4];
|
|
636
|
+
let q50 = values[n / 2];
|
|
637
|
+
let q75 = values[(3 * n) / 4];
|
|
638
|
+
boundaries.push(q25);
|
|
639
|
+
boundaries.push(q50);
|
|
640
|
+
boundaries.push(q75);
|
|
641
|
+
}
|
|
642
|
+
|
|
643
|
+
let bytes_per_vector = (dimension + 3) / 4;
|
|
644
|
+
let mut codes = Vec::with_capacity(vectors.len() * bytes_per_vector);
|
|
645
|
+
for vector in vectors {
|
|
646
|
+
codes.extend_from_slice(&quantize_two_bit(vector, &boundaries, bytes_per_vector));
|
|
647
|
+
}
|
|
648
|
+
|
|
649
|
+
Self {
|
|
650
|
+
dimension,
|
|
651
|
+
boundaries,
|
|
652
|
+
codes,
|
|
653
|
+
count: vectors.len(),
|
|
654
|
+
bytes_per_vector,
|
|
655
|
+
config,
|
|
656
|
+
}
|
|
657
|
+
}
|
|
658
|
+
|
|
659
|
+
/// Quantize a single vector to 2-bit codes.
|
|
660
|
+
pub fn quantize_vector(&self, vector: &[f32]) -> Vec<u8> {
|
|
661
|
+
quantize_two_bit(vector, &self.boundaries, self.bytes_per_vector)
|
|
662
|
+
}
|
|
663
|
+
|
|
664
|
+
/// Compute approximate dot product between a 2-bit quantized query and
|
|
665
|
+
/// a stored quantized vector. Returns a score where higher = more similar.
|
|
666
|
+
pub fn approx_dot(&self, query_codes: &[u8], idx: usize) -> i32 {
|
|
667
|
+
let offset = idx * self.bytes_per_vector;
|
|
668
|
+
let stored = &self.codes[offset..offset + self.bytes_per_vector];
|
|
669
|
+
two_bit_approx_dot(query_codes, stored, self.dimension)
|
|
670
|
+
}
|
|
671
|
+
|
|
672
|
+
/// Search for top-k candidates using approximate 2-bit dot products.
|
|
673
|
+
/// Returns (index, approx_score) pairs sorted best-first.
|
|
674
|
+
pub fn search(&self, query: &[f32], top_k: usize) -> Vec<(usize, i32)> {
|
|
675
|
+
let rescore_count = (top_k * self.config.rescore_multiplier)
|
|
676
|
+
.max(50)
|
|
677
|
+
.min(self.count);
|
|
678
|
+
let query_codes = self.quantize_vector(query);
|
|
679
|
+
|
|
680
|
+
let mut scores: Vec<(usize, i32)> = (0..self.count)
|
|
681
|
+
.map(|idx| (idx, self.approx_dot(&query_codes, idx)))
|
|
682
|
+
.collect();
|
|
683
|
+
|
|
684
|
+
scores.sort_unstable_by(|a, b| b.1.cmp(&a.1));
|
|
685
|
+
scores.truncate(rescore_count);
|
|
686
|
+
scores
|
|
687
|
+
}
|
|
688
|
+
|
|
689
|
+
/// Rebuild codes from vectors.
|
|
690
|
+
pub fn rebuild_codes(&mut self, vectors: &[&[f32]]) {
|
|
691
|
+
self.codes.clear();
|
|
692
|
+
self.codes.reserve(vectors.len() * self.bytes_per_vector);
|
|
693
|
+
for vector in vectors {
|
|
694
|
+
self.codes
|
|
695
|
+
.extend_from_slice(&quantize_two_bit(vector, &self.boundaries, self.bytes_per_vector));
|
|
696
|
+
}
|
|
697
|
+
self.count = vectors.len();
|
|
698
|
+
}
|
|
699
|
+
|
|
700
|
+
/// Serialize parameters (boundaries only, codes rebuilt on load).
|
|
701
|
+
pub fn write_params(&self, writer: &mut impl Write) -> std::io::Result<()> {
|
|
702
|
+
// Tag byte: 4 = two_bit
|
|
703
|
+
writer.write_all(&[4u8])?;
|
|
704
|
+
write_usize(writer, self.dimension)?;
|
|
705
|
+
write_usize(writer, self.config.rescore_multiplier)?;
|
|
706
|
+
// Write boundaries (dimension * 3 floats)
|
|
707
|
+
for &b in &self.boundaries {
|
|
708
|
+
writer.write_all(&b.to_le_bytes())?;
|
|
709
|
+
}
|
|
710
|
+
Ok(())
|
|
711
|
+
}
|
|
712
|
+
|
|
713
|
+
/// Deserialize parameters.
|
|
714
|
+
pub fn read_params(reader: &mut impl Read) -> std::io::Result<Self> {
|
|
715
|
+
let dimension = read_usize(reader)?;
|
|
716
|
+
let rescore_multiplier = read_usize(reader)?;
|
|
717
|
+
let mut boundaries = vec![0.0_f32; dimension * 3];
|
|
718
|
+
for b in &mut boundaries {
|
|
719
|
+
let mut buf = [0u8; 4];
|
|
720
|
+
reader.read_exact(&mut buf)?;
|
|
721
|
+
*b = f32::from_le_bytes(buf);
|
|
722
|
+
}
|
|
723
|
+
let bytes_per_vector = (dimension + 3) / 4;
|
|
724
|
+
Ok(Self {
|
|
725
|
+
dimension,
|
|
726
|
+
boundaries,
|
|
727
|
+
codes: Vec::new(),
|
|
728
|
+
count: 0,
|
|
729
|
+
bytes_per_vector,
|
|
730
|
+
config: TwoBitQuantizationConfig { rescore_multiplier },
|
|
731
|
+
})
|
|
732
|
+
}
|
|
733
|
+
}
|
|
734
|
+
|
|
735
|
+
// ---------------------------------------------------------------------------
|
|
736
|
+
// Multi-vector quantized index (for ColBERT token-level search)
|
|
737
|
+
// ---------------------------------------------------------------------------
|
|
738
|
+
|
|
739
|
+
/// Configuration for multi-vector quantization.
|
|
740
|
+
#[derive(Clone, Debug, PartialEq)]
|
|
741
|
+
pub enum MultiVectorQuantizationConfig {
|
|
742
|
+
TwoBit(TwoBitQuantizationConfig),
|
|
743
|
+
}
|
|
744
|
+
|
|
745
|
+
/// A quantized index for multi-vector (late interaction) search.
|
|
746
|
+
/// Stores all token vectors from all documents in a flat quantized array,
|
|
747
|
+
/// with a mapping from document index to token range.
|
|
748
|
+
#[derive(Clone, Debug)]
|
|
749
|
+
pub struct MultiVectorQuantizedIndex {
|
|
750
|
+
pub quantizer: TwoBitQuantizer,
|
|
751
|
+
/// For each document: (start_index, count) into the quantized vector array.
|
|
752
|
+
pub doc_ranges: Vec<(usize, usize)>,
|
|
753
|
+
}
|
|
754
|
+
|
|
755
|
+
impl MultiVectorQuantizedIndex {
|
|
756
|
+
/// Build a multi-vector quantized index from per-document token vectors.
|
|
757
|
+
/// `doc_token_vectors[i]` is a slice of token-level vectors for document i.
|
|
758
|
+
pub fn build(
|
|
759
|
+
doc_token_vectors: &[&[Vec<f32>]],
|
|
760
|
+
token_dimension: usize,
|
|
761
|
+
config: &MultiVectorQuantizationConfig,
|
|
762
|
+
) -> Self {
|
|
763
|
+
// Flatten all token vectors for training
|
|
764
|
+
let all_tokens: Vec<&[f32]> = doc_token_vectors
|
|
765
|
+
.iter()
|
|
766
|
+
.flat_map(|tokens| tokens.iter().map(|v| v.as_slice()))
|
|
767
|
+
.collect();
|
|
768
|
+
|
|
769
|
+
let MultiVectorQuantizationConfig::TwoBit(cfg) = config;
|
|
770
|
+
|
|
771
|
+
let quantizer = if all_tokens.is_empty() {
|
|
772
|
+
// Empty case: create minimal quantizer
|
|
773
|
+
TwoBitQuantizer {
|
|
774
|
+
dimension: token_dimension,
|
|
775
|
+
boundaries: vec![0.0; token_dimension * 3],
|
|
776
|
+
codes: Vec::new(),
|
|
777
|
+
count: 0,
|
|
778
|
+
bytes_per_vector: (token_dimension + 3) / 4,
|
|
779
|
+
config: cfg.clone(),
|
|
780
|
+
}
|
|
781
|
+
} else {
|
|
782
|
+
TwoBitQuantizer::train(&all_tokens, token_dimension, cfg.clone())
|
|
783
|
+
};
|
|
784
|
+
|
|
785
|
+
// Build doc_ranges
|
|
786
|
+
let mut doc_ranges = Vec::with_capacity(doc_token_vectors.len());
|
|
787
|
+
let mut offset = 0;
|
|
788
|
+
for tokens in doc_token_vectors {
|
|
789
|
+
doc_ranges.push((offset, tokens.len()));
|
|
790
|
+
offset += tokens.len();
|
|
791
|
+
}
|
|
792
|
+
|
|
793
|
+
Self {
|
|
794
|
+
quantizer,
|
|
795
|
+
doc_ranges,
|
|
796
|
+
}
|
|
797
|
+
}
|
|
798
|
+
|
|
799
|
+
/// Compute approximate MaxSim score for a document given query token codes.
|
|
800
|
+
/// For each query token, finds the max approximate dot with any document token.
|
|
801
|
+
pub fn approx_maxsim(&self, query_codes: &[Vec<u8>], doc_idx: usize) -> i32 {
|
|
802
|
+
let (start, count) = self.doc_ranges[doc_idx];
|
|
803
|
+
if count == 0 || query_codes.is_empty() {
|
|
804
|
+
return 0;
|
|
805
|
+
}
|
|
806
|
+
let mut total = 0i32;
|
|
807
|
+
for q_code in query_codes {
|
|
808
|
+
let mut best = i32::MIN;
|
|
809
|
+
for i in start..start + count {
|
|
810
|
+
let score = two_bit_approx_dot(
|
|
811
|
+
q_code,
|
|
812
|
+
&self.quantizer.codes[i * self.quantizer.bytes_per_vector
|
|
813
|
+
..(i + 1) * self.quantizer.bytes_per_vector],
|
|
814
|
+
self.quantizer.dimension,
|
|
815
|
+
);
|
|
816
|
+
if score > best {
|
|
817
|
+
best = score;
|
|
818
|
+
}
|
|
819
|
+
}
|
|
820
|
+
total += best;
|
|
821
|
+
}
|
|
822
|
+
total
|
|
823
|
+
}
|
|
824
|
+
|
|
825
|
+
/// Search: returns candidate document indices sorted by approximate MaxSim.
|
|
826
|
+
pub fn search(&self, query_tokens: &[&[f32]], top_k: usize) -> Vec<usize> {
|
|
827
|
+
let rescore_count = (top_k * self.quantizer.config.rescore_multiplier)
|
|
828
|
+
.max(50)
|
|
829
|
+
.min(self.doc_ranges.len());
|
|
830
|
+
if query_tokens.is_empty() || self.doc_ranges.is_empty() {
|
|
831
|
+
return Vec::new();
|
|
832
|
+
}
|
|
833
|
+
|
|
834
|
+
let query_codes: Vec<Vec<u8>> = query_tokens
|
|
835
|
+
.iter()
|
|
836
|
+
.map(|t| self.quantizer.quantize_vector(t))
|
|
837
|
+
.collect();
|
|
838
|
+
|
|
839
|
+
let mut scores: Vec<(usize, i32)> = (0..self.doc_ranges.len())
|
|
840
|
+
.map(|doc_idx| (doc_idx, self.approx_maxsim(&query_codes, doc_idx)))
|
|
841
|
+
.collect();
|
|
842
|
+
|
|
843
|
+
scores.sort_unstable_by(|a, b| b.1.cmp(&a.1));
|
|
844
|
+
scores.truncate(rescore_count);
|
|
845
|
+
scores.into_iter().map(|(idx, _)| idx).collect()
|
|
846
|
+
}
|
|
847
|
+
|
|
848
|
+
/// Rebuild from document token vectors (after loading parameters from disk).
|
|
849
|
+
pub fn rebuild(
|
|
850
|
+
&mut self,
|
|
851
|
+
doc_token_vectors: &[&[Vec<f32>]],
|
|
852
|
+
) {
|
|
853
|
+
let all_tokens: Vec<&[f32]> = doc_token_vectors
|
|
854
|
+
.iter()
|
|
855
|
+
.flat_map(|tokens| tokens.iter().map(|v| v.as_slice()))
|
|
856
|
+
.collect();
|
|
857
|
+
self.quantizer.rebuild_codes(&all_tokens);
|
|
858
|
+
|
|
859
|
+
self.doc_ranges.clear();
|
|
860
|
+
let mut offset = 0;
|
|
861
|
+
for tokens in doc_token_vectors {
|
|
862
|
+
self.doc_ranges.push((offset, tokens.len()));
|
|
863
|
+
offset += tokens.len();
|
|
864
|
+
}
|
|
865
|
+
}
|
|
866
|
+
|
|
867
|
+
/// Serialize parameters.
|
|
868
|
+
pub fn write_params(&self, writer: &mut impl Write) -> std::io::Result<()> {
|
|
869
|
+
self.quantizer.write_params(writer)?;
|
|
870
|
+
// Write doc_ranges
|
|
871
|
+
write_usize(writer, self.doc_ranges.len())?;
|
|
872
|
+
for &(start, count) in &self.doc_ranges {
|
|
873
|
+
write_usize(writer, start)?;
|
|
874
|
+
write_usize(writer, count)?;
|
|
875
|
+
}
|
|
876
|
+
Ok(())
|
|
877
|
+
}
|
|
878
|
+
|
|
879
|
+
/// Deserialize parameters.
|
|
880
|
+
pub fn read_params(reader: &mut impl Read) -> std::io::Result<Self> {
|
|
881
|
+
// Consume the tag byte written by TwoBitQuantizer::write_params
|
|
882
|
+
let mut tag = [0u8; 1];
|
|
883
|
+
reader.read_exact(&mut tag)?;
|
|
884
|
+
if tag[0] != 4 {
|
|
885
|
+
return Err(std::io::Error::new(
|
|
886
|
+
std::io::ErrorKind::InvalidData,
|
|
887
|
+
format!("expected two_bit tag (4), got {}", tag[0]),
|
|
888
|
+
));
|
|
889
|
+
}
|
|
890
|
+
let quantizer = TwoBitQuantizer::read_params(reader)?;
|
|
891
|
+
let num_docs = read_usize(reader)?;
|
|
892
|
+
let mut doc_ranges = Vec::with_capacity(num_docs);
|
|
893
|
+
for _ in 0..num_docs {
|
|
894
|
+
let start = read_usize(reader)?;
|
|
895
|
+
let count = read_usize(reader)?;
|
|
896
|
+
doc_ranges.push((start, count));
|
|
897
|
+
}
|
|
898
|
+
Ok(Self {
|
|
899
|
+
quantizer,
|
|
900
|
+
doc_ranges,
|
|
901
|
+
})
|
|
902
|
+
}
|
|
903
|
+
}
|
|
904
|
+
|
|
905
|
+
// ---------------------------------------------------------------------------
|
|
906
|
+
// Unified quantization index
|
|
907
|
+
// ---------------------------------------------------------------------------
|
|
908
|
+
|
|
909
|
+
/// A quantized index that wraps any of the three quantization strategies.
|
|
910
|
+
/// Used by the Database to accelerate search.
|
|
911
|
+
#[derive(Clone, Debug)]
|
|
912
|
+
pub enum QuantizedIndex {
|
|
913
|
+
Scalar(ScalarQuantizer),
|
|
914
|
+
Binary(BinaryQuantizer),
|
|
915
|
+
Product(ProductQuantizer),
|
|
916
|
+
}
|
|
917
|
+
|
|
918
|
+
impl QuantizedIndex {
|
|
919
|
+
/// Build a quantized index from vectors.
|
|
920
|
+
pub fn build(vectors: &[&[f32]], dimension: usize, config: &QuantizationConfig) -> Self {
|
|
921
|
+
match config {
|
|
922
|
+
QuantizationConfig::Scalar(cfg) => {
|
|
923
|
+
QuantizedIndex::Scalar(ScalarQuantizer::train(vectors, dimension, cfg.clone()))
|
|
924
|
+
}
|
|
925
|
+
QuantizationConfig::Binary(cfg) => {
|
|
926
|
+
let mut quantizer = BinaryQuantizer::new(dimension, cfg.clone());
|
|
927
|
+
quantizer.add_vectors(vectors);
|
|
928
|
+
QuantizedIndex::Binary(quantizer)
|
|
929
|
+
}
|
|
930
|
+
QuantizationConfig::Product(cfg) => {
|
|
931
|
+
QuantizedIndex::Product(ProductQuantizer::train(vectors, dimension, cfg.clone()))
|
|
932
|
+
}
|
|
933
|
+
}
|
|
934
|
+
}
|
|
935
|
+
|
|
936
|
+
/// Search the quantized index. Returns candidate indices sorted by
|
|
937
|
+
/// approximate similarity (best first), to be rescored with exact vectors.
|
|
938
|
+
pub fn search_candidates(&self, query: &[f32], top_k: usize) -> Vec<usize> {
|
|
939
|
+
match self {
|
|
940
|
+
QuantizedIndex::Scalar(q) => {
|
|
941
|
+
q.search(query, top_k).into_iter().map(|(i, _)| i).collect()
|
|
942
|
+
}
|
|
943
|
+
QuantizedIndex::Binary(q) => {
|
|
944
|
+
q.search(query, top_k).into_iter().map(|(i, _)| i).collect()
|
|
945
|
+
}
|
|
946
|
+
QuantizedIndex::Product(q) => {
|
|
947
|
+
q.search(query, top_k).into_iter().map(|(i, _)| i).collect()
|
|
948
|
+
}
|
|
949
|
+
}
|
|
950
|
+
}
|
|
951
|
+
|
|
952
|
+
/// Rebuild quantized codes from current vectors (after deserialization or updates).
|
|
953
|
+
pub fn rebuild_codes(&mut self, vectors: &[&[f32]]) {
|
|
954
|
+
match self {
|
|
955
|
+
QuantizedIndex::Scalar(q) => q.rebuild_codes(vectors),
|
|
956
|
+
QuantizedIndex::Binary(q) => q.rebuild_codes(vectors),
|
|
957
|
+
QuantizedIndex::Product(q) => q.rebuild_codes(vectors),
|
|
958
|
+
}
|
|
959
|
+
}
|
|
960
|
+
|
|
961
|
+
/// Get the vector count in the quantized index.
|
|
962
|
+
pub fn count(&self) -> usize {
|
|
963
|
+
match self {
|
|
964
|
+
QuantizedIndex::Scalar(q) => q.count,
|
|
965
|
+
QuantizedIndex::Binary(q) => q.count,
|
|
966
|
+
QuantizedIndex::Product(q) => q.count,
|
|
967
|
+
}
|
|
968
|
+
}
|
|
969
|
+
|
|
970
|
+
/// Get the rescore multiplier for this quantization strategy.
|
|
971
|
+
pub fn rescore_multiplier(&self) -> usize {
|
|
972
|
+
match self {
|
|
973
|
+
QuantizedIndex::Scalar(q) => q.config.rescore_multiplier,
|
|
974
|
+
QuantizedIndex::Binary(q) => q.config.rescore_multiplier,
|
|
975
|
+
QuantizedIndex::Product(q) => q.config.rescore_multiplier,
|
|
976
|
+
}
|
|
977
|
+
}
|
|
978
|
+
|
|
979
|
+
/// Serialize quantization parameters to a writer.
|
|
980
|
+
pub fn write_params(&self, writer: &mut impl Write) -> std::io::Result<()> {
|
|
981
|
+
match self {
|
|
982
|
+
QuantizedIndex::Scalar(q) => q.write_params(writer),
|
|
983
|
+
QuantizedIndex::Binary(q) => q.write_params(writer),
|
|
984
|
+
QuantizedIndex::Product(q) => q.write_params(writer),
|
|
985
|
+
}
|
|
986
|
+
}
|
|
987
|
+
|
|
988
|
+
/// Deserialize quantization parameters from a reader.
|
|
989
|
+
pub fn read_params(reader: &mut impl Read) -> std::io::Result<Self> {
|
|
990
|
+
let mut tag = [0u8; 1];
|
|
991
|
+
reader.read_exact(&mut tag)?;
|
|
992
|
+
match tag[0] {
|
|
993
|
+
1 => Ok(QuantizedIndex::Scalar(ScalarQuantizer::read_params(
|
|
994
|
+
reader,
|
|
995
|
+
)?)),
|
|
996
|
+
2 => Ok(QuantizedIndex::Binary(BinaryQuantizer::read_params(
|
|
997
|
+
reader,
|
|
998
|
+
)?)),
|
|
999
|
+
3 => Ok(QuantizedIndex::Product(ProductQuantizer::read_params(
|
|
1000
|
+
reader,
|
|
1001
|
+
)?)),
|
|
1002
|
+
other => Err(std::io::Error::new(
|
|
1003
|
+
std::io::ErrorKind::InvalidData,
|
|
1004
|
+
format!("unknown quantization tag: {other}"),
|
|
1005
|
+
)),
|
|
1006
|
+
}
|
|
1007
|
+
}
|
|
1008
|
+
|
|
1009
|
+
/// Get the config used to build this index.
|
|
1010
|
+
pub fn config(&self) -> QuantizationConfig {
|
|
1011
|
+
match self {
|
|
1012
|
+
QuantizedIndex::Scalar(q) => QuantizationConfig::Scalar(q.config.clone()),
|
|
1013
|
+
QuantizedIndex::Binary(q) => QuantizationConfig::Binary(q.config.clone()),
|
|
1014
|
+
QuantizedIndex::Product(q) => QuantizationConfig::Product(q.config.clone()),
|
|
1015
|
+
}
|
|
1016
|
+
}
|
|
1017
|
+
}
|
|
1018
|
+
|
|
1019
|
+
// ---------------------------------------------------------------------------
|
|
1020
|
+
// Internal helper functions
|
|
1021
|
+
// ---------------------------------------------------------------------------
|
|
1022
|
+
|
|
1023
|
+
/// Map one f32 onto the u8 code range as `(val - min) * scale`.
///
/// The result is clamped to `0.0..=255.0` and truncated toward zero. A
/// `scale` of zero marks a constant dimension and always yields the midpoint
/// code 128.
#[inline]
fn quantize_scalar(val: f32, min: f32, scale: f32) -> u8 {
    if scale == 0.0 {
        // Constant dimension: every value sits at the midpoint.
        return 128;
    }
    let scaled = (val - min) * scale;
    scaled.clamp(0.0, 255.0) as u8
}
|
|
1032
|
+
|
|
1033
|
+
/// Approximate dot product between two u8-quantized vectors.
///
/// Used as a ranking proxy: a higher value means more similar. Pairs beyond
/// the shorter slice are ignored.
///
/// The sum is accumulated in `u64`. Each term can be as large as
/// 255 * 255 = 65 025, so an `i32` accumulator would overflow beyond
/// roughly 33 000 dimensions (a panic in debug, a wrap in release).
#[inline]
fn scalar_quantized_dot(a: &[u8], b: &[u8]) -> f32 {
    let sum: u64 = a
        .iter()
        .zip(b)
        .map(|(&ai, &bi)| u64::from(ai) * u64::from(bi))
        .sum();
    sum as f32
}
|
|
1043
|
+
|
|
1044
|
+
/// Pack a float vector into a bitset with one bit per dimension.
///
/// Bit `i % 8` of byte `i / 8` is set exactly when `vector[i] > 0.0`. Zero,
/// negative and NaN components leave their bit cleared. The output is
/// `ceil(len / 8)` bytes long.
fn binarize_vector(vector: &[f32]) -> Vec<u8> {
    vector
        .chunks(8)
        .map(|octet| {
            octet
                .iter()
                .enumerate()
                .filter(|&(_, &component)| component > 0.0)
                .fold(0u8, |byte, (bit, _)| byte | (1u8 << bit))
        })
        .collect()
}
|
|
1056
|
+
|
|
1057
|
+
/// Count the differing bits between two packed bit vectors.
///
/// Bytes past the end of the shorter slice are not compared.
#[inline]
fn hamming_distance(a: &[u8], b: &[u8]) -> u32 {
    a.iter()
        .zip(b)
        .map(|(x, y)| (x ^ y).count_ones())
        .sum()
}
|
|
1066
|
+
|
|
1067
|
+
/// Encode a float vector as packed 2-bit codes (4 levels per dimension).
///
/// `boundaries` holds three thresholds per dimension, starting at offset
/// `3 * i` for dimension `i`. The level is the index of the first threshold
/// the value is `<=`, or 3 if it exceeds all of them (NaN also maps to 3).
/// Codes are packed four dimensions per byte, least-significant bits first.
///
/// # Panics
/// Panics (index out of bounds) if `boundaries` is too short for a threshold
/// it must consult, or if `vector` has more than `4 * bytes_per_vector`
/// dimensions.
fn quantize_two_bit(vector: &[f32], boundaries: &[f32], bytes_per_vector: usize) -> Vec<u8> {
    let mut packed = vec![0u8; bytes_per_vector];
    for (dim, &value) in vector.iter().enumerate() {
        let base = dim * 3;
        // `position` is lazy, so thresholds are consulted in order and only as needed.
        let level = (0..3)
            .position(|k| value <= boundaries[base + k])
            .unwrap_or(3) as u8;
        packed[dim / 4] |= level << (2 * (dim % 4));
    }
    packed
}
|
|
1089
|
+
|
|
1090
|
+
/// Approximate dot product of two packed 2-bit code vectors.
///
/// Each level (0..=3) stands in for the original float magnitude. The result
/// is the sum, over the first `dimension` dimensions, of the products of the
/// two levels. A higher score means more similar.
#[inline]
fn two_bit_approx_dot(a: &[u8], b: &[u8], dimension: usize) -> i32 {
    // Extract the 2-bit level of `dim` (four dimensions per byte, LSB first).
    let level_at = |codes: &[u8], dim: usize| -> i32 {
        i32::from((codes[dim / 4] >> (2 * (dim % 4))) & 0b11)
    };
    (0..dimension)
        .map(|dim| level_at(a, dim) * level_at(b, dim))
        .sum()
}
|
|
1105
|
+
|
|
1106
|
+
/// Squared Euclidean distance over the overlapping prefix of `a` and `b`.
#[inline]
fn l2_distance_sq(a: &[f32], b: &[f32]) -> f32 {
    let squared_diffs = a.iter().zip(b).map(|(x, y)| {
        let delta = x - y;
        delta * delta
    });
    squared_diffs.sum()
}
|
|
1117
|
+
|
|
1118
|
+
/// Find the nearest centroid index for a sub-vector.
|
|
1119
|
+
fn find_nearest_centroid(vector: &[f32], centroids: &[Vec<f32>]) -> usize {
|
|
1120
|
+
let mut best_idx = 0;
|
|
1121
|
+
let mut best_dist = f32::INFINITY;
|
|
1122
|
+
for (idx, centroid) in centroids.iter().enumerate() {
|
|
1123
|
+
let dist = l2_distance_sq(vector, centroid);
|
|
1124
|
+
if dist < best_dist {
|
|
1125
|
+
best_dist = dist;
|
|
1126
|
+
best_idx = idx;
|
|
1127
|
+
}
|
|
1128
|
+
}
|
|
1129
|
+
best_idx
|
|
1130
|
+
}
|
|
1131
|
+
|
|
1132
|
+
/// Simple k-means clustering for PQ training.
|
|
1133
|
+
fn kmeans(vectors: &[&[f32]], dimension: usize, k: usize, iterations: usize) -> Vec<Vec<f32>> {
|
|
1134
|
+
let n = vectors.len();
|
|
1135
|
+
let actual_k = k.min(n);
|
|
1136
|
+
|
|
1137
|
+
// Initialize centroids using first k vectors (or modular selection if k > n)
|
|
1138
|
+
let mut centroids: Vec<Vec<f32>> = (0..actual_k).map(|i| vectors[i % n].to_vec()).collect();
|
|
1139
|
+
|
|
1140
|
+
let mut assignments = vec![0usize; n];
|
|
1141
|
+
|
|
1142
|
+
for _ in 0..iterations {
|
|
1143
|
+
// Assign each vector to nearest centroid
|
|
1144
|
+
for (i, vector) in vectors.iter().enumerate() {
|
|
1145
|
+
assignments[i] = find_nearest_centroid(vector, ¢roids);
|
|
1146
|
+
}
|
|
1147
|
+
|
|
1148
|
+
// Update centroids
|
|
1149
|
+
let mut new_centroids = vec![vec![0.0_f32; dimension]; actual_k];
|
|
1150
|
+
let mut counts = vec![0usize; actual_k];
|
|
1151
|
+
|
|
1152
|
+
for (i, vector) in vectors.iter().enumerate() {
|
|
1153
|
+
let cluster = assignments[i];
|
|
1154
|
+
counts[cluster] += 1;
|
|
1155
|
+
for (j, &val) in vector.iter().enumerate() {
|
|
1156
|
+
new_centroids[cluster][j] += val;
|
|
1157
|
+
}
|
|
1158
|
+
}
|
|
1159
|
+
|
|
1160
|
+
for (cluster, centroid) in new_centroids.iter_mut().enumerate() {
|
|
1161
|
+
if counts[cluster] > 0 {
|
|
1162
|
+
for val in centroid.iter_mut() {
|
|
1163
|
+
*val /= counts[cluster] as f32;
|
|
1164
|
+
}
|
|
1165
|
+
} else {
|
|
1166
|
+
// Keep old centroid for empty clusters
|
|
1167
|
+
centroid.copy_from_slice(¢roids[cluster]);
|
|
1168
|
+
}
|
|
1169
|
+
}
|
|
1170
|
+
|
|
1171
|
+
centroids = new_centroids;
|
|
1172
|
+
}
|
|
1173
|
+
|
|
1174
|
+
// Pad with zeros if actual_k < k (shouldn't happen in practice)
|
|
1175
|
+
while centroids.len() < k {
|
|
1176
|
+
centroids.push(vec![0.0_f32; dimension]);
|
|
1177
|
+
}
|
|
1178
|
+
|
|
1179
|
+
centroids
|
|
1180
|
+
}
|
|
1181
|
+
|
|
1182
|
+
// ---------------------------------------------------------------------------
|
|
1183
|
+
// Serialization helpers
|
|
1184
|
+
// ---------------------------------------------------------------------------
|
|
1185
|
+
|
|
1186
|
+
/// Write `val` as a fixed-width 8-byte little-endian `u64`, independent of platform word size.
fn write_usize(writer: &mut impl Write, val: usize) -> std::io::Result<()> {
    let encoded: [u8; 8] = u64::to_le_bytes(val as u64);
    writer.write_all(&encoded)
}
|
|
1189
|
+
|
|
1190
|
+
/// Read an 8-byte little-endian `u64` (as written by `write_usize`) and return it as `usize`.
///
/// # Errors
/// Propagates I/O errors (e.g. `UnexpectedEof` on a short read). Returns
/// `InvalidData` if the stored value does not fit in this platform's `usize`,
/// rather than silently truncating it on 32-bit targets.
fn read_usize(reader: &mut impl Read) -> std::io::Result<usize> {
    let mut buf = [0u8; 8];
    reader.read_exact(&mut buf)?;
    let raw = u64::from_le_bytes(buf);
    // `usize` is at most 64 bits on supported targets, so this cast is lossless.
    if raw > usize::MAX as u64 {
        return Err(std::io::Error::new(
            std::io::ErrorKind::InvalidData,
            format!("stored value {raw} does not fit in usize"),
        ));
    }
    Ok(raw as usize)
}
|
|
1195
|
+
|
|
1196
|
+
// ---------------------------------------------------------------------------
|
|
1197
|
+
// Tests
|
|
1198
|
+
// ---------------------------------------------------------------------------
|
|
1199
|
+
|
|
1200
|
+
#[cfg(test)]
|
|
1201
|
+
mod tests {
|
|
1202
|
+
use super::*;
|
|
1203
|
+
|
|
1204
|
+
fn random_vectors(n: usize, dim: usize, seed: u64) -> Vec<Vec<f32>> {
|
|
1205
|
+
// Simple deterministic pseudo-random generator (xorshift64)
|
|
1206
|
+
let mut state = seed;
|
|
1207
|
+
(0..n)
|
|
1208
|
+
.map(|_| {
|
|
1209
|
+
(0..dim)
|
|
1210
|
+
.map(|_| {
|
|
1211
|
+
state ^= state << 13;
|
|
1212
|
+
state ^= state >> 7;
|
|
1213
|
+
state ^= state << 17;
|
|
1214
|
+
// Map to [-1, 1]
|
|
1215
|
+
(state as f32 / u64::MAX as f32) * 2.0 - 1.0
|
|
1216
|
+
})
|
|
1217
|
+
.collect()
|
|
1218
|
+
})
|
|
1219
|
+
.collect()
|
|
1220
|
+
}
|
|
1221
|
+
|
|
1222
|
+
#[test]
|
|
1223
|
+
fn scalar_quantization_basic() {
|
|
1224
|
+
let vectors = random_vectors(100, 128, 42);
|
|
1225
|
+
let refs: Vec<&[f32]> = vectors.iter().map(Vec::as_slice).collect();
|
|
1226
|
+
|
|
1227
|
+
let quantizer = ScalarQuantizer::train(&refs, 128, ScalarQuantizationConfig::default());
|
|
1228
|
+
|
|
1229
|
+
assert_eq!(quantizer.count, 100);
|
|
1230
|
+
assert_eq!(quantizer.codes.len(), 100 * 128);
|
|
1231
|
+
assert_eq!(quantizer.dimension, 128);
|
|
1232
|
+
|
|
1233
|
+
// Search should return candidates
|
|
1234
|
+
let query = &vectors[0];
|
|
1235
|
+
let results = quantizer.search(query, 10);
|
|
1236
|
+
assert!(!results.is_empty());
|
|
1237
|
+
// The query vector itself should be among the top results
|
|
1238
|
+
assert!(results.iter().take(5).any(|(idx, _)| *idx == 0));
|
|
1239
|
+
}
|
|
1240
|
+
|
|
1241
|
+
#[test]
|
|
1242
|
+
fn binary_quantization_basic() {
|
|
1243
|
+
let vectors = random_vectors(100, 128, 42);
|
|
1244
|
+
let refs: Vec<&[f32]> = vectors.iter().map(Vec::as_slice).collect();
|
|
1245
|
+
|
|
1246
|
+
let mut quantizer = BinaryQuantizer::new(128, BinaryQuantizationConfig::default());
|
|
1247
|
+
quantizer.add_vectors(&refs);
|
|
1248
|
+
|
|
1249
|
+
assert_eq!(quantizer.count, 100);
|
|
1250
|
+
assert_eq!(quantizer.bytes_per_vector, 16); // 128 / 8
|
|
1251
|
+
assert_eq!(quantizer.codes.len(), 100 * 16);
|
|
1252
|
+
|
|
1253
|
+
// Search should find the query itself at distance 0
|
|
1254
|
+
let query = &vectors[0];
|
|
1255
|
+
let results = quantizer.search(query, 10);
|
|
1256
|
+
assert!(!results.is_empty());
|
|
1257
|
+
assert_eq!(results[0].0, 0); // First result should be index 0
|
|
1258
|
+
assert_eq!(results[0].1, 0); // Hamming distance 0
|
|
1259
|
+
}
|
|
1260
|
+
|
|
1261
|
+
#[test]
|
|
1262
|
+
fn product_quantization_basic() {
|
|
1263
|
+
let vectors = random_vectors(200, 128, 42);
|
|
1264
|
+
let refs: Vec<&[f32]> = vectors.iter().map(Vec::as_slice).collect();
|
|
1265
|
+
|
|
1266
|
+
let config = ProductQuantizationConfig {
|
|
1267
|
+
num_sub_vectors: 8,
|
|
1268
|
+
num_centroids: 16, // Small for testing
|
|
1269
|
+
training_iterations: 5,
|
|
1270
|
+
rescore_multiplier: 10,
|
|
1271
|
+
};
|
|
1272
|
+
|
|
1273
|
+
let quantizer = ProductQuantizer::train(&refs, 128, config);
|
|
1274
|
+
|
|
1275
|
+
assert_eq!(quantizer.count, 200);
|
|
1276
|
+
assert_eq!(quantizer.num_sub_vectors, 8);
|
|
1277
|
+
assert_eq!(quantizer.sub_dimension, 16);
|
|
1278
|
+
assert_eq!(quantizer.codes.len(), 200 * 8);
|
|
1279
|
+
|
|
1280
|
+
// Search should return candidates
|
|
1281
|
+
let query = &vectors[0];
|
|
1282
|
+
let results = quantizer.search(query, 10);
|
|
1283
|
+
assert!(!results.is_empty());
|
|
1284
|
+
// Query itself should be the closest (distance 0 with its own centroids)
|
|
1285
|
+
assert_eq!(results[0].0, 0);
|
|
1286
|
+
}
|
|
1287
|
+
|
|
1288
|
+
#[test]
|
|
1289
|
+
fn quantized_index_build_and_search() {
|
|
1290
|
+
let vectors = random_vectors(100, 64, 123);
|
|
1291
|
+
let refs: Vec<&[f32]> = vectors.iter().map(Vec::as_slice).collect();
|
|
1292
|
+
|
|
1293
|
+
// Test scalar
|
|
1294
|
+
let idx = QuantizedIndex::build(
|
|
1295
|
+
&refs,
|
|
1296
|
+
64,
|
|
1297
|
+
&QuantizationConfig::Scalar(ScalarQuantizationConfig::default()),
|
|
1298
|
+
);
|
|
1299
|
+
let candidates = idx.search_candidates(&vectors[0], 10);
|
|
1300
|
+
assert!(!candidates.is_empty());
|
|
1301
|
+
assert!(candidates.contains(&0));
|
|
1302
|
+
|
|
1303
|
+
// Test binary
|
|
1304
|
+
let idx = QuantizedIndex::build(
|
|
1305
|
+
&refs,
|
|
1306
|
+
64,
|
|
1307
|
+
&QuantizationConfig::Binary(BinaryQuantizationConfig::default()),
|
|
1308
|
+
);
|
|
1309
|
+
let candidates = idx.search_candidates(&vectors[0], 10);
|
|
1310
|
+
assert!(!candidates.is_empty());
|
|
1311
|
+
assert_eq!(candidates[0], 0);
|
|
1312
|
+
|
|
1313
|
+
// Test product
|
|
1314
|
+
let idx = QuantizedIndex::build(
|
|
1315
|
+
&refs,
|
|
1316
|
+
64,
|
|
1317
|
+
&QuantizationConfig::Product(ProductQuantizationConfig {
|
|
1318
|
+
num_sub_vectors: 8,
|
|
1319
|
+
num_centroids: 16,
|
|
1320
|
+
training_iterations: 5,
|
|
1321
|
+
rescore_multiplier: 10,
|
|
1322
|
+
}),
|
|
1323
|
+
);
|
|
1324
|
+
let candidates = idx.search_candidates(&vectors[0], 10);
|
|
1325
|
+
assert!(!candidates.is_empty());
|
|
1326
|
+
assert_eq!(candidates[0], 0);
|
|
1327
|
+
}
|
|
1328
|
+
|
|
1329
|
+
#[test]
|
|
1330
|
+
fn serialization_roundtrip_scalar() {
|
|
1331
|
+
let vectors = random_vectors(50, 32, 99);
|
|
1332
|
+
let refs: Vec<&[f32]> = vectors.iter().map(Vec::as_slice).collect();
|
|
1333
|
+
|
|
1334
|
+
let original = QuantizedIndex::build(
|
|
1335
|
+
&refs,
|
|
1336
|
+
32,
|
|
1337
|
+
&QuantizationConfig::Scalar(ScalarQuantizationConfig {
|
|
1338
|
+
rescore_multiplier: 7,
|
|
1339
|
+
}),
|
|
1340
|
+
);
|
|
1341
|
+
|
|
1342
|
+
let mut buf = Vec::new();
|
|
1343
|
+
original.write_params(&mut buf).unwrap();
|
|
1344
|
+
|
|
1345
|
+
let mut cursor = std::io::Cursor::new(&buf);
|
|
1346
|
+
let restored = QuantizedIndex::read_params(&mut cursor).unwrap();
|
|
1347
|
+
|
|
1348
|
+
match (&original, &restored) {
|
|
1349
|
+
(QuantizedIndex::Scalar(a), QuantizedIndex::Scalar(b)) => {
|
|
1350
|
+
assert_eq!(a.dimension, b.dimension);
|
|
1351
|
+
assert_eq!(a.mins, b.mins);
|
|
1352
|
+
assert_eq!(a.maxs, b.maxs);
|
|
1353
|
+
assert_eq!(a.config, b.config);
|
|
1354
|
+
}
|
|
1355
|
+
_ => panic!("type mismatch"),
|
|
1356
|
+
}
|
|
1357
|
+
}
|
|
1358
|
+
|
|
1359
|
+
#[test]
|
|
1360
|
+
fn serialization_roundtrip_binary() {
|
|
1361
|
+
let original = QuantizedIndex::build(
|
|
1362
|
+
&[&[1.0, -1.0, 0.5, -0.3][..]],
|
|
1363
|
+
4,
|
|
1364
|
+
&QuantizationConfig::Binary(BinaryQuantizationConfig {
|
|
1365
|
+
rescore_multiplier: 12,
|
|
1366
|
+
}),
|
|
1367
|
+
);
|
|
1368
|
+
|
|
1369
|
+
let mut buf = Vec::new();
|
|
1370
|
+
original.write_params(&mut buf).unwrap();
|
|
1371
|
+
|
|
1372
|
+
let mut cursor = std::io::Cursor::new(&buf);
|
|
1373
|
+
let restored = QuantizedIndex::read_params(&mut cursor).unwrap();
|
|
1374
|
+
|
|
1375
|
+
match (&original, &restored) {
|
|
1376
|
+
(QuantizedIndex::Binary(a), QuantizedIndex::Binary(b)) => {
|
|
1377
|
+
assert_eq!(a.dimension, b.dimension);
|
|
1378
|
+
assert_eq!(a.config, b.config);
|
|
1379
|
+
}
|
|
1380
|
+
_ => panic!("type mismatch"),
|
|
1381
|
+
}
|
|
1382
|
+
}
|
|
1383
|
+
|
|
1384
|
+
#[test]
|
|
1385
|
+
fn serialization_roundtrip_product() {
|
|
1386
|
+
let vectors = random_vectors(50, 32, 77);
|
|
1387
|
+
let refs: Vec<&[f32]> = vectors.iter().map(Vec::as_slice).collect();
|
|
1388
|
+
|
|
1389
|
+
let original = QuantizedIndex::build(
|
|
1390
|
+
&refs,
|
|
1391
|
+
32,
|
|
1392
|
+
&QuantizationConfig::Product(ProductQuantizationConfig {
|
|
1393
|
+
num_sub_vectors: 4,
|
|
1394
|
+
num_centroids: 8,
|
|
1395
|
+
training_iterations: 3,
|
|
1396
|
+
rescore_multiplier: 5,
|
|
1397
|
+
}),
|
|
1398
|
+
);
|
|
1399
|
+
|
|
1400
|
+
let mut buf = Vec::new();
|
|
1401
|
+
original.write_params(&mut buf).unwrap();
|
|
1402
|
+
|
|
1403
|
+
let mut cursor = std::io::Cursor::new(&buf);
|
|
1404
|
+
let restored = QuantizedIndex::read_params(&mut cursor).unwrap();
|
|
1405
|
+
|
|
1406
|
+
match (&original, &restored) {
|
|
1407
|
+
(QuantizedIndex::Product(a), QuantizedIndex::Product(b)) => {
|
|
1408
|
+
assert_eq!(a.dimension, b.dimension);
|
|
1409
|
+
assert_eq!(a.num_sub_vectors, b.num_sub_vectors);
|
|
1410
|
+
assert_eq!(a.num_centroids, b.num_centroids);
|
|
1411
|
+
assert_eq!(a.codebooks.len(), b.codebooks.len());
|
|
1412
|
+
for (ca, cb) in a.codebooks.iter().zip(b.codebooks.iter()) {
|
|
1413
|
+
for (va, vb) in ca.iter().zip(cb.iter()) {
|
|
1414
|
+
assert_eq!(va.len(), vb.len());
|
|
1415
|
+
for (&fa, &fb) in va.iter().zip(vb.iter()) {
|
|
1416
|
+
assert!((fa - fb).abs() < 1e-6);
|
|
1417
|
+
}
|
|
1418
|
+
}
|
|
1419
|
+
}
|
|
1420
|
+
assert_eq!(a.config, b.config);
|
|
1421
|
+
}
|
|
1422
|
+
_ => panic!("type mismatch"),
|
|
1423
|
+
}
|
|
1424
|
+
}
|
|
1425
|
+
|
|
1426
|
+
#[test]
|
|
1427
|
+
fn hamming_distance_correctness() {
|
|
1428
|
+
assert_eq!(hamming_distance(&[0b00000000], &[0b00000000]), 0);
|
|
1429
|
+
assert_eq!(hamming_distance(&[0b11111111], &[0b00000000]), 8);
|
|
1430
|
+
assert_eq!(hamming_distance(&[0b10101010], &[0b01010101]), 8);
|
|
1431
|
+
assert_eq!(hamming_distance(&[0b10101010], &[0b10101010]), 0);
|
|
1432
|
+
assert_eq!(hamming_distance(&[0b10000000], &[0b00000000]), 1);
|
|
1433
|
+
}
|
|
1434
|
+
|
|
1435
|
+
#[test]
|
|
1436
|
+
fn binarize_vector_correctness() {
|
|
1437
|
+
let v = vec![1.0, -1.0, 0.5, -0.5, 0.0, 0.1, -0.1, 0.9];
|
|
1438
|
+
let binary = binarize_vector(&v);
|
|
1439
|
+
assert_eq!(binary.len(), 1);
|
|
1440
|
+
// Bit 0: 1.0 > 0 -> 1
|
|
1441
|
+
// Bit 1: -1.0 -> 0
|
|
1442
|
+
// Bit 2: 0.5 > 0 -> 1
|
|
1443
|
+
// Bit 3: -0.5 -> 0
|
|
1444
|
+
// Bit 4: 0.0 -> 0 (not strictly positive)
|
|
1445
|
+
// Bit 5: 0.1 > 0 -> 1
|
|
1446
|
+
// Bit 6: -0.1 -> 0
|
|
1447
|
+
// Bit 7: 0.9 > 0 -> 1
|
|
1448
|
+
assert_eq!(binary[0], 0b10100101);
|
|
1449
|
+
}
|
|
1450
|
+
|
|
1451
|
+
#[test]
|
|
1452
|
+
fn two_bit_quantization_basic() {
|
|
1453
|
+
let vectors = random_vectors(100, 64, 42);
|
|
1454
|
+
let refs: Vec<&[f32]> = vectors.iter().map(Vec::as_slice).collect();
|
|
1455
|
+
|
|
1456
|
+
let config = TwoBitQuantizationConfig {
|
|
1457
|
+
rescore_multiplier: 4,
|
|
1458
|
+
};
|
|
1459
|
+
let quantizer = TwoBitQuantizer::train(&refs, 64, config);
|
|
1460
|
+
|
|
1461
|
+
assert_eq!(quantizer.dimension, 64);
|
|
1462
|
+
assert_eq!(quantizer.count, 100);
|
|
1463
|
+
assert_eq!(quantizer.bytes_per_vector, 16); // 64 dims * 2 bits / 8 = 16
|
|
1464
|
+
assert_eq!(quantizer.boundaries.len(), 64 * 3);
|
|
1465
|
+
|
|
1466
|
+
// Search should return candidates including the query itself
|
|
1467
|
+
let results = quantizer.search(&vectors[0], 10);
|
|
1468
|
+
assert!(!results.is_empty());
|
|
1469
|
+
assert!(results.iter().take(5).any(|(idx, _)| *idx == 0));
|
|
1470
|
+
}
|
|
1471
|
+
|
|
1472
|
+
#[test]
|
|
1473
|
+
fn two_bit_quantize_and_approx_dot() {
|
|
1474
|
+
// Manually test quantization of a small vector
|
|
1475
|
+
let boundaries = vec![
|
|
1476
|
+
-0.5, 0.0, 0.5, // dim 0: quartiles
|
|
1477
|
+
-0.5, 0.0, 0.5, // dim 1
|
|
1478
|
+
-0.5, 0.0, 0.5, // dim 2
|
|
1479
|
+
-0.5, 0.0, 0.5, // dim 3
|
|
1480
|
+
];
|
|
1481
|
+
let bytes_per_vector = 1; // 4 dims * 2 bits = 8 bits = 1 byte
|
|
1482
|
+
|
|
1483
|
+
// Vector with values that map to different quantization levels
|
|
1484
|
+
let v1 = [-1.0, -0.25, 0.25, 1.0]; // levels: 0, 1, 2, 3
|
|
1485
|
+
let v2 = [-1.0, -0.25, 0.25, 1.0]; // levels: 0, 1, 2, 3
|
|
1486
|
+
|
|
1487
|
+
let q1 = quantize_two_bit(&v1, &boundaries, bytes_per_vector);
|
|
1488
|
+
let q2 = quantize_two_bit(&v2, &boundaries, bytes_per_vector);
|
|
1489
|
+
|
|
1490
|
+
// Same vectors should have the maximum approx dot product
|
|
1491
|
+
let dot = two_bit_approx_dot(&q1, &q2, 4);
|
|
1492
|
+
assert!(dot > 0); // 0*0 + 1*1 + 2*2 + 3*3 = 0 + 1 + 4 + 9 = 14
|
|
1493
|
+
assert_eq!(dot, 14);
|
|
1494
|
+
}
|
|
1495
|
+
|
|
1496
|
+
#[test]
|
|
1497
|
+
fn two_bit_serialization_roundtrip() {
|
|
1498
|
+
use std::io::Read;
|
|
1499
|
+
|
|
1500
|
+
let vectors = random_vectors(50, 32, 99);
|
|
1501
|
+
let refs: Vec<&[f32]> = vectors.iter().map(Vec::as_slice).collect();
|
|
1502
|
+
|
|
1503
|
+
let config = TwoBitQuantizationConfig {
|
|
1504
|
+
rescore_multiplier: 6,
|
|
1505
|
+
};
|
|
1506
|
+
let original = TwoBitQuantizer::train(&refs, 32, config);
|
|
1507
|
+
|
|
1508
|
+
let mut buf = Vec::new();
|
|
1509
|
+
original.write_params(&mut buf).unwrap();
|
|
1510
|
+
|
|
1511
|
+
let mut cursor = std::io::Cursor::new(&buf);
|
|
1512
|
+
// Consume the tag byte written by write_params
|
|
1513
|
+
let mut tag = [0u8; 1];
|
|
1514
|
+
cursor.read_exact(&mut tag).unwrap();
|
|
1515
|
+
assert_eq!(tag[0], 4);
|
|
1516
|
+
let restored = TwoBitQuantizer::read_params(&mut cursor).unwrap();
|
|
1517
|
+
|
|
1518
|
+
assert_eq!(original.dimension, restored.dimension);
|
|
1519
|
+
assert_eq!(original.boundaries.len(), restored.boundaries.len());
|
|
1520
|
+
for (a, b) in original.boundaries.iter().zip(restored.boundaries.iter()) {
|
|
1521
|
+
assert!((a - b).abs() < 1e-6);
|
|
1522
|
+
}
|
|
1523
|
+
assert_eq!(original.config.rescore_multiplier, restored.config.rescore_multiplier);
|
|
1524
|
+
}
|
|
1525
|
+
|
|
1526
|
+
#[test]
fn multi_vector_quantized_index_basic() {
    // Build five synthetic "documents". Each one is a bag of 16-dimensional
    // token vectors, and the token counts cycle 3, 4, 5, 3, 4 (19 in total).
    let doc_tokens: Vec<Vec<Vec<f32>>> = (0..5)
        .map(|doc_idx| random_vectors(3 + doc_idx % 3, 16, 100 + doc_idx as u64))
        .collect();
    let doc_refs: Vec<&[Vec<f32>]> = doc_tokens.iter().map(Vec::as_slice).collect();

    let config = MultiVectorQuantizationConfig::TwoBit(TwoBitQuantizationConfig {
        rescore_multiplier: 4,
    });
    let index = MultiVectorQuantizedIndex::build(&doc_refs, 16, &config);

    // There is one range entry per document. The second element of each
    // entry is that document's token count.
    assert_eq!(index.doc_ranges.len(), 5);
    let token_total: usize = index.doc_ranges.iter().map(|&(_, n)| n).sum();
    assert_eq!(token_total, 19);

    // Query with document 0's own tokens. Those tokens should score the
    // highest MaxSim against themselves, so document 0 should be in the
    // first three hits.
    let query: Vec<&[f32]> = doc_tokens[0].iter().map(|tok| tok.as_slice()).collect();
    let hits = index.search(&query, 3);
    assert!(!hits.is_empty());
    let top_window = &hits[..hits.len().min(3)];
    assert!(top_window.contains(&0));
}
|
|
1556
|
+
|
|
1557
|
+
#[test]
fn multi_vector_quantized_index_serialization_roundtrip() {
    // Three small documents, each with four 8-dimensional token vectors.
    let mut doc_tokens: Vec<Vec<Vec<f32>>> = Vec::new();
    for i in 0..3 {
        doc_tokens.push(random_vectors(4, 8, 200 + i));
    }
    let doc_refs: Vec<&[Vec<f32>]> = doc_tokens.iter().map(|v| v.as_slice()).collect();

    let config = MultiVectorQuantizationConfig::TwoBit(TwoBitQuantizationConfig {
        rescore_multiplier: 2,
    });
    let original = MultiVectorQuantizedIndex::build(&doc_refs, 8, &config);

    let mut buf = Vec::new();
    original.write_params(&mut buf).unwrap();

    let mut cursor = std::io::Cursor::new(&buf);
    let restored = MultiVectorQuantizedIndex::read_params(&mut cursor).unwrap();

    assert_eq!(original.doc_ranges, restored.doc_ranges);
    assert_eq!(original.quantizer.dimension, restored.quantizer.dimension);
    assert_eq!(
        original.quantizer.boundaries.len(),
        restored.quantizer.boundaries.len()
    );
    // Checking only the length would let a round-trip that corrupts the
    // boundary values pass. Compare every boundary with the same tolerance
    // the TwoBitQuantizer round-trip test uses.
    for (a, b) in original
        .quantizer
        .boundaries
        .iter()
        .zip(restored.quantizer.boundaries.iter())
    {
        assert!((a - b).abs() < 1e-6);
    }
    assert_eq!(
        original.quantizer.config.rescore_multiplier,
        restored.quantizer.config.rescore_multiplier
    );
}
|
|
1587
|
+
}
|