faiss 0.5.0 → 0.5.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +5 -0
- data/README.md +2 -0
- data/ext/faiss/index.cpp +8 -0
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/faiss/IVFlib.cpp +25 -49
- data/vendor/faiss/faiss/Index.cpp +11 -0
- data/vendor/faiss/faiss/Index.h +24 -1
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +1 -0
- data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +5 -1
- data/vendor/faiss/faiss/IndexFastScan.cpp +1 -1
- data/vendor/faiss/faiss/IndexFastScan.h +3 -8
- data/vendor/faiss/faiss/IndexFlat.cpp +374 -4
- data/vendor/faiss/faiss/IndexFlat.h +80 -0
- data/vendor/faiss/faiss/IndexHNSW.cpp +90 -1
- data/vendor/faiss/faiss/IndexHNSW.h +57 -1
- data/vendor/faiss/faiss/IndexIVFFlatPanorama.cpp +34 -149
- data/vendor/faiss/faiss/IndexIVFRaBitQ.cpp +86 -2
- data/vendor/faiss/faiss/IndexIVFRaBitQ.h +3 -1
- data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.cpp +293 -115
- data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.h +52 -16
- data/vendor/faiss/faiss/IndexPQ.cpp +4 -1
- data/vendor/faiss/faiss/IndexPreTransform.cpp +14 -0
- data/vendor/faiss/faiss/IndexPreTransform.h +9 -0
- data/vendor/faiss/faiss/IndexRaBitQ.cpp +96 -16
- data/vendor/faiss/faiss/IndexRaBitQ.h +5 -1
- data/vendor/faiss/faiss/IndexRaBitQFastScan.cpp +238 -93
- data/vendor/faiss/faiss/IndexRaBitQFastScan.h +35 -9
- data/vendor/faiss/faiss/IndexRefine.cpp +49 -0
- data/vendor/faiss/faiss/IndexRefine.h +17 -0
- data/vendor/faiss/faiss/clone_index.cpp +2 -0
- data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +3 -1
- data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +1 -1
- data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +1 -1
- data/vendor/faiss/faiss/impl/DistanceComputer.h +74 -3
- data/vendor/faiss/faiss/impl/HNSW.cpp +294 -15
- data/vendor/faiss/faiss/impl/HNSW.h +31 -2
- data/vendor/faiss/faiss/impl/IDSelector.h +3 -3
- data/vendor/faiss/faiss/impl/Panorama.cpp +193 -0
- data/vendor/faiss/faiss/impl/Panorama.h +204 -0
- data/vendor/faiss/faiss/impl/RaBitQStats.cpp +29 -0
- data/vendor/faiss/faiss/impl/RaBitQStats.h +56 -0
- data/vendor/faiss/faiss/impl/RaBitQUtils.cpp +54 -6
- data/vendor/faiss/faiss/impl/RaBitQUtils.h +183 -6
- data/vendor/faiss/faiss/impl/RaBitQuantizer.cpp +269 -84
- data/vendor/faiss/faiss/impl/RaBitQuantizer.h +71 -4
- data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.cpp +362 -0
- data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.h +112 -0
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +6 -9
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +1 -3
- data/vendor/faiss/faiss/impl/index_read.cpp +156 -12
- data/vendor/faiss/faiss/impl/index_write.cpp +142 -19
- data/vendor/faiss/faiss/impl/platform_macros.h +12 -0
- data/vendor/faiss/faiss/impl/svs_io.cpp +86 -0
- data/vendor/faiss/faiss/impl/svs_io.h +67 -0
- data/vendor/faiss/faiss/index_factory.cpp +182 -15
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +1 -1
- data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -1
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +18 -109
- data/vendor/faiss/faiss/invlists/InvertedLists.h +2 -18
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +1 -1
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +1 -1
- data/vendor/faiss/faiss/svs/IndexSVSFaissUtils.h +261 -0
- data/vendor/faiss/faiss/svs/IndexSVSFlat.cpp +117 -0
- data/vendor/faiss/faiss/svs/IndexSVSFlat.h +66 -0
- data/vendor/faiss/faiss/svs/IndexSVSVamana.cpp +245 -0
- data/vendor/faiss/faiss/svs/IndexSVSVamana.h +137 -0
- data/vendor/faiss/faiss/svs/IndexSVSVamanaLVQ.cpp +39 -0
- data/vendor/faiss/faiss/svs/IndexSVSVamanaLVQ.h +42 -0
- data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.cpp +149 -0
- data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.h +58 -0
- data/vendor/faiss/faiss/utils/distances.cpp +0 -3
- data/vendor/faiss/faiss/utils/utils.cpp +4 -0
- metadata +18 -1
|
@@ -8,6 +8,7 @@
|
|
|
8
8
|
#include <faiss/IndexRaBitQFastScan.h>
|
|
9
9
|
#include <faiss/impl/FastScanDistancePostProcessing.h>
|
|
10
10
|
#include <faiss/impl/RaBitQUtils.h>
|
|
11
|
+
#include <faiss/impl/RaBitQuantizerMultiBit.h>
|
|
11
12
|
#include <faiss/impl/pq4_fast_scan.h>
|
|
12
13
|
#include <faiss/utils/utils.h>
|
|
13
14
|
#include <algorithm>
|
|
@@ -19,15 +20,35 @@ static inline size_t roundup(size_t a, size_t b) {
|
|
|
19
20
|
return (a + b - 1) / b * b;
|
|
20
21
|
}
|
|
21
22
|
|
|
23
|
+
size_t IndexRaBitQFastScan::compute_per_vector_storage_size() const {
|
|
24
|
+
const size_t ex_bits = rabitq.nb_bits - 1;
|
|
25
|
+
|
|
26
|
+
if (ex_bits == 0) {
|
|
27
|
+
// 1-bit: only SignBitFactors
|
|
28
|
+
return sizeof(rabitq_utils::SignBitFactors);
|
|
29
|
+
} else {
|
|
30
|
+
// Multi-bit: SignBitFactorsWithError + ExtraBitsFactors +
|
|
31
|
+
// mag-codes
|
|
32
|
+
return sizeof(SignBitFactorsWithError) + sizeof(ExtraBitsFactors) +
|
|
33
|
+
(d * ex_bits + 7) / 8;
|
|
34
|
+
}
|
|
35
|
+
}
|
|
36
|
+
|
|
22
37
|
IndexRaBitQFastScan::IndexRaBitQFastScan() = default;
|
|
23
38
|
|
|
24
|
-
IndexRaBitQFastScan::IndexRaBitQFastScan(
|
|
25
|
-
|
|
39
|
+
IndexRaBitQFastScan::IndexRaBitQFastScan(
|
|
40
|
+
idx_t d,
|
|
41
|
+
MetricType metric,
|
|
42
|
+
int bbs,
|
|
43
|
+
uint8_t nb_bits)
|
|
44
|
+
: rabitq(d, metric, nb_bits) {
|
|
26
45
|
// RaBitQ-specific validation
|
|
27
46
|
FAISS_THROW_IF_NOT_MSG(d > 0, "Dimension must be positive");
|
|
28
47
|
FAISS_THROW_IF_NOT_MSG(
|
|
29
48
|
metric == METRIC_L2 || metric == METRIC_INNER_PRODUCT,
|
|
30
49
|
"RaBitQ FastScan only supports L2 and Inner Product metrics");
|
|
50
|
+
FAISS_THROW_IF_NOT_MSG(
|
|
51
|
+
nb_bits >= 1 && nb_bits <= 9, "nb_bits must be between 1 and 9");
|
|
31
52
|
|
|
32
53
|
// RaBitQ uses 1 bit per dimension packed into 4-bit FastScan sub-quantizers
|
|
33
54
|
// Each FastScan sub-quantizer handles 4 RaBitQ dimensions
|
|
@@ -37,17 +58,15 @@ IndexRaBitQFastScan::IndexRaBitQFastScan(idx_t d, MetricType metric, int bbs)
|
|
|
37
58
|
// init_fastscan will validate bbs % 32 == 0 and nbits_fastscan == 4
|
|
38
59
|
init_fastscan(static_cast<int>(d), M_fastscan, nbits_fastscan, metric, bbs);
|
|
39
60
|
|
|
40
|
-
//
|
|
41
|
-
|
|
42
|
-
const size_t bit_pattern_size = (d + 7) / 8;
|
|
43
|
-
code_size = bit_pattern_size + sizeof(FactorsData);
|
|
61
|
+
// Compute code_size directly using RaBitQuantizer
|
|
62
|
+
code_size = rabitq.compute_code_size(d, nb_bits);
|
|
44
63
|
|
|
45
64
|
// Set RaBitQ-specific parameters
|
|
46
65
|
qb = 8;
|
|
47
66
|
center.resize(d, 0.0f);
|
|
48
67
|
|
|
49
|
-
//
|
|
50
|
-
|
|
68
|
+
// Initialize empty flat storage
|
|
69
|
+
flat_storage.clear();
|
|
51
70
|
}
|
|
52
71
|
|
|
53
72
|
IndexRaBitQFastScan::IndexRaBitQFastScan(const IndexRaBitQ& orig, int bbs)
|
|
@@ -72,10 +91,7 @@ IndexRaBitQFastScan::IndexRaBitQFastScan(const IndexRaBitQ& orig, int bbs)
|
|
|
72
91
|
orig.metric_type,
|
|
73
92
|
bbs);
|
|
74
93
|
|
|
75
|
-
|
|
76
|
-
// RaBitQ stores 1 bit per dimension, requiring (d + 7) / 8 bytes
|
|
77
|
-
const size_t bit_pattern_size = (orig.d + 7) / 8;
|
|
78
|
-
code_size = bit_pattern_size + sizeof(FactorsData);
|
|
94
|
+
code_size = rabitq.compute_code_size(d, rabitq.nb_bits);
|
|
79
95
|
|
|
80
96
|
// Copy properties from original index
|
|
81
97
|
ntotal = orig.ntotal;
|
|
@@ -88,23 +104,19 @@ IndexRaBitQFastScan::IndexRaBitQFastScan(const IndexRaBitQ& orig, int bbs)
|
|
|
88
104
|
|
|
89
105
|
// If the original index has data, extract factors and pack codes
|
|
90
106
|
if (ntotal > 0) {
|
|
91
|
-
//
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
// Extract factors from original codes for each vector
|
|
95
|
-
const float* centroid_data = center.data();
|
|
107
|
+
// Compute per-vector storage size for flat storage
|
|
108
|
+
const size_t storage_size = compute_per_vector_storage_size();
|
|
96
109
|
|
|
97
|
-
//
|
|
98
|
-
|
|
99
|
-
orig.sa_decode(ntotal, orig.codes.data(), decoded_vectors.data());
|
|
110
|
+
// Allocate flat storage
|
|
111
|
+
flat_storage.resize(ntotal * storage_size);
|
|
100
112
|
|
|
113
|
+
// Copy factors directly from original codes
|
|
114
|
+
const size_t bit_pattern_size = (d + 7) / 8;
|
|
101
115
|
for (idx_t i = 0; i < ntotal; i++) {
|
|
102
|
-
|
|
103
|
-
const
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
fac = rabitq_utils::compute_vector_factors(
|
|
107
|
-
x_row, orig.d, centroid_data, orig.metric_type);
|
|
116
|
+
const uint8_t* orig_code = orig.codes.data() + i * orig.code_size;
|
|
117
|
+
const uint8_t* source_factors_ptr = orig_code + bit_pattern_size;
|
|
118
|
+
uint8_t* storage = flat_storage.data() + i * storage_size;
|
|
119
|
+
memcpy(storage, source_factors_ptr, storage_size);
|
|
108
120
|
}
|
|
109
121
|
|
|
110
122
|
// Convert RaBitQ bit format to FastScan 4-bit sub-quantizer format
|
|
@@ -191,15 +203,19 @@ void IndexRaBitQFastScan::add(idx_t n, const float* x) {
|
|
|
191
203
|
AlignedTable<uint8_t> tmp_codes(n * code_size);
|
|
192
204
|
compute_codes(tmp_codes.get(), n, x);
|
|
193
205
|
|
|
194
|
-
|
|
206
|
+
const size_t storage_size = compute_per_vector_storage_size();
|
|
207
|
+
flat_storage.resize((ntotal + n) * storage_size);
|
|
208
|
+
|
|
209
|
+
// Populate flat storage (no sign bits copying needed!)
|
|
195
210
|
const size_t bit_pattern_size = (d + 7) / 8;
|
|
196
|
-
factors_storage.resize(ntotal + n);
|
|
197
211
|
for (idx_t i = 0; i < n; i++) {
|
|
198
212
|
const uint8_t* code = tmp_codes.get() + i * code_size;
|
|
199
|
-
const
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
213
|
+
const idx_t vec_idx = ntotal + i;
|
|
214
|
+
|
|
215
|
+
// Copy factors data directly to flat storage (no reordering needed)
|
|
216
|
+
const uint8_t* source_factors_ptr = code + bit_pattern_size;
|
|
217
|
+
uint8_t* storage = flat_storage.data() + vec_idx * storage_size;
|
|
218
|
+
memcpy(storage, source_factors_ptr, storage_size);
|
|
203
219
|
}
|
|
204
220
|
|
|
205
221
|
// Resize main storage (same logic as parent)
|
|
@@ -239,6 +255,8 @@ void IndexRaBitQFastScan::compute_codes(uint8_t* codes, idx_t n, const float* x)
|
|
|
239
255
|
// Hoist loop-invariant computations
|
|
240
256
|
const float* centroid_data = center.data();
|
|
241
257
|
const size_t bit_pattern_size = (d + 7) / 8;
|
|
258
|
+
const size_t ex_bits = rabitq.nb_bits - 1;
|
|
259
|
+
const size_t ex_code_size = (d * ex_bits + 7) / 8;
|
|
242
260
|
|
|
243
261
|
memset(codes, 0, n * code_size);
|
|
244
262
|
|
|
@@ -247,25 +265,52 @@ void IndexRaBitQFastScan::compute_codes(uint8_t* codes, idx_t n, const float* x)
|
|
|
247
265
|
uint8_t* const code = codes + i * code_size;
|
|
248
266
|
const float* const x_row = x + i * d;
|
|
249
267
|
|
|
250
|
-
//
|
|
268
|
+
// Compute residual once, reuse for both sign bits and ex-bits
|
|
269
|
+
std::vector<float> residual(d);
|
|
251
270
|
for (size_t j = 0; j < d; j++) {
|
|
252
|
-
const float x_val = x_row[j];
|
|
253
271
|
const float centroid_val = centroid_data ? centroid_data[j] : 0.0f;
|
|
254
|
-
|
|
255
|
-
|
|
272
|
+
residual[j] = x_row[j] - centroid_val;
|
|
273
|
+
}
|
|
256
274
|
|
|
257
|
-
|
|
275
|
+
// Pack sign bits directly into FastScan format using precomputed
|
|
276
|
+
// residual
|
|
277
|
+
for (size_t j = 0; j < d; j++) {
|
|
278
|
+
if (residual[j] > 0.0f) {
|
|
258
279
|
rabitq_utils::set_bit_fastscan(code, j);
|
|
259
280
|
}
|
|
260
281
|
}
|
|
261
282
|
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
x_row, d, centroid_data, metric_type);
|
|
283
|
+
SignBitFactorsWithError factors = rabitq_utils::compute_vector_factors(
|
|
284
|
+
x_row, d, centroid_data, metric_type, ex_bits > 0);
|
|
265
285
|
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
|
|
286
|
+
if (ex_bits == 0) {
|
|
287
|
+
// 1-bit: store only SignBitFactors (8 bytes)
|
|
288
|
+
memcpy(code + bit_pattern_size, &factors, sizeof(SignBitFactors));
|
|
289
|
+
} else {
|
|
290
|
+
// Multi-bit: store full SignBitFactorsWithError (12 bytes)
|
|
291
|
+
memcpy(code + bit_pattern_size,
|
|
292
|
+
&factors,
|
|
293
|
+
sizeof(SignBitFactorsWithError));
|
|
294
|
+
|
|
295
|
+
// Add mag-codes and ExtraBitsFactors using precomputed
|
|
296
|
+
// residual
|
|
297
|
+
uint8_t* ex_code =
|
|
298
|
+
code + bit_pattern_size + sizeof(SignBitFactorsWithError);
|
|
299
|
+
ExtraBitsFactors ex_factors_temp;
|
|
300
|
+
|
|
301
|
+
rabitq_multibit::quantize_ex_bits(
|
|
302
|
+
residual.data(),
|
|
303
|
+
d,
|
|
304
|
+
rabitq.nb_bits,
|
|
305
|
+
ex_code,
|
|
306
|
+
ex_factors_temp,
|
|
307
|
+
metric_type,
|
|
308
|
+
centroid_data);
|
|
309
|
+
|
|
310
|
+
memcpy(ex_code + ex_code_size,
|
|
311
|
+
&ex_factors_temp,
|
|
312
|
+
sizeof(ExtraBitsFactors));
|
|
313
|
+
}
|
|
269
314
|
}
|
|
270
315
|
}
|
|
271
316
|
|
|
@@ -300,7 +345,8 @@ void IndexRaBitQFastScan::compute_float_LUT(
|
|
|
300
345
|
rotated_qq);
|
|
301
346
|
|
|
302
347
|
// Store query factors in context array if provided
|
|
303
|
-
if (context.query_factors) {
|
|
348
|
+
if (context.query_factors != nullptr) {
|
|
349
|
+
query_factors_data.rotated_q = rotated_q;
|
|
304
350
|
context.query_factors[i] = query_factors_data;
|
|
305
351
|
}
|
|
306
352
|
|
|
@@ -397,8 +443,9 @@ void IndexRaBitQFastScan::sa_decode(idx_t n, const uint8_t* bytes, float* x)
|
|
|
397
443
|
|
|
398
444
|
// Extract factors directly from embedded codes
|
|
399
445
|
const uint8_t* factors_ptr = code + bit_pattern_size;
|
|
400
|
-
const
|
|
401
|
-
|
|
446
|
+
const rabitq_utils::SignBitFactors* fac =
|
|
447
|
+
reinterpret_cast<const rabitq_utils::SignBitFactors*>(
|
|
448
|
+
factors_ptr);
|
|
402
449
|
|
|
403
450
|
for (size_t j = 0; j < d; j++) {
|
|
404
451
|
// Use RaBitQUtils for consistent bit extraction
|
|
@@ -406,7 +453,7 @@ void IndexRaBitQFastScan::sa_decode(idx_t n, const uint8_t* bytes, float* x)
|
|
|
406
453
|
float bit = bit_value ? 1.0f : 0.0f;
|
|
407
454
|
|
|
408
455
|
// Compute the output using RaBitQ reconstruction formula
|
|
409
|
-
x[i * d + j] = (bit - 0.5f) * fac
|
|
456
|
+
x[i * d + j] = (bit - 0.5f) * fac->dp_multiplier * 2 * inv_d_sqrt +
|
|
410
457
|
((centroid_in == nullptr) ? 0 : centroid_in[j]);
|
|
411
458
|
}
|
|
412
459
|
}
|
|
@@ -446,14 +493,16 @@ RaBitQHeapHandler<C, with_id_map>::RaBitQHeapHandler(
|
|
|
446
493
|
float* distances,
|
|
447
494
|
int64_t* labels,
|
|
448
495
|
const IDSelector* sel_in,
|
|
449
|
-
const FastScanDistancePostProcessing& ctx
|
|
496
|
+
const FastScanDistancePostProcessing& ctx,
|
|
497
|
+
bool multi_bit)
|
|
450
498
|
: RHC(nq_val, index->ntotal, sel_in),
|
|
451
499
|
rabitq_index(index),
|
|
452
500
|
heap_distances(distances),
|
|
453
501
|
heap_labels(labels),
|
|
454
502
|
nq(nq_val),
|
|
455
503
|
k(k_val),
|
|
456
|
-
context(ctx)
|
|
504
|
+
context(ctx),
|
|
505
|
+
is_multi_bit(multi_bit) {
|
|
457
506
|
// Initialize heaps for all queries in constructor
|
|
458
507
|
// This allows us to support direct normalizer assignment
|
|
459
508
|
#pragma omp parallel for if (nq > 100)
|
|
@@ -480,7 +529,7 @@ void RaBitQHeapHandler<C, with_id_map>::handle(
|
|
|
480
529
|
|
|
481
530
|
// Access query factors from query_factors pointer
|
|
482
531
|
rabitq_utils::QueryFactorsData query_factors_data = {};
|
|
483
|
-
if (context.query_factors) {
|
|
532
|
+
if (context.query_factors != nullptr) {
|
|
484
533
|
query_factors_data = context.query_factors[q];
|
|
485
534
|
}
|
|
486
535
|
|
|
@@ -494,6 +543,15 @@ void RaBitQHeapHandler<C, with_id_map>::handle(
|
|
|
494
543
|
? std::min<size_t>(32, rabitq_index->ntotal - base_db_idx)
|
|
495
544
|
: 0;
|
|
496
545
|
|
|
546
|
+
// Get storage size once
|
|
547
|
+
const size_t storage_size = rabitq_index->compute_per_vector_storage_size();
|
|
548
|
+
|
|
549
|
+
// Stats tracking for multi-bit two-stage search only
|
|
550
|
+
// n_1bit_evaluations: candidates evaluated using 1-bit lower bound
|
|
551
|
+
// n_multibit_evaluations: candidates requiring full multi-bit distance
|
|
552
|
+
size_t local_1bit_evaluations = 0;
|
|
553
|
+
size_t local_multibit_evaluations = 0;
|
|
554
|
+
|
|
497
555
|
// Process distances in batch
|
|
498
556
|
for (size_t i = 0; i < max_vectors; i++) {
|
|
499
557
|
const size_t db_idx = base_db_idx + i;
|
|
@@ -501,43 +559,70 @@ void RaBitQHeapHandler<C, with_id_map>::handle(
|
|
|
501
559
|
// Normalize distance from LUT lookup
|
|
502
560
|
const float normalized_distance = d32tab[i] * one_a + bias;
|
|
503
561
|
|
|
504
|
-
// Access factors from storage
|
|
505
|
-
|
|
506
|
-
|
|
507
|
-
|
|
508
|
-
|
|
509
|
-
|
|
510
|
-
|
|
511
|
-
|
|
512
|
-
|
|
513
|
-
|
|
514
|
-
|
|
515
|
-
|
|
516
|
-
|
|
517
|
-
|
|
518
|
-
|
|
519
|
-
|
|
520
|
-
|
|
562
|
+
// Access factors from flat storage
|
|
563
|
+
const uint8_t* base_ptr =
|
|
564
|
+
rabitq_index->flat_storage.data() + db_idx * storage_size;
|
|
565
|
+
|
|
566
|
+
if (is_multi_bit) {
|
|
567
|
+
// Track candidates actually considered for two-stage filtering
|
|
568
|
+
local_1bit_evaluations++;
|
|
569
|
+
|
|
570
|
+
const SignBitFactorsWithError& full_factors =
|
|
571
|
+
*reinterpret_cast<const SignBitFactorsWithError*>(base_ptr);
|
|
572
|
+
|
|
573
|
+
float dist_1bit = rabitq_utils::compute_1bit_adjusted_distance(
|
|
574
|
+
normalized_distance,
|
|
575
|
+
full_factors,
|
|
576
|
+
query_factors_data,
|
|
577
|
+
rabitq_index->centered,
|
|
578
|
+
rabitq_index->qb,
|
|
579
|
+
rabitq_index->d);
|
|
580
|
+
|
|
581
|
+
float lower_bound = compute_lower_bound(dist_1bit, db_idx, q);
|
|
582
|
+
|
|
583
|
+
// Adaptive filtering: decide whether to compute full distance
|
|
584
|
+
const bool is_similarity = rabitq_index->metric_type ==
|
|
585
|
+
MetricType::METRIC_INNER_PRODUCT;
|
|
586
|
+
bool should_refine = is_similarity
|
|
587
|
+
? (lower_bound > heap_dis[0]) // IP: keep if better
|
|
588
|
+
: (lower_bound < heap_dis[0]); // L2: keep if better
|
|
589
|
+
|
|
590
|
+
if (should_refine) {
|
|
591
|
+
local_multibit_evaluations++;
|
|
592
|
+
float dist_full = compute_full_multibit_distance(db_idx, q);
|
|
593
|
+
|
|
594
|
+
if (Cfloat::cmp(heap_dis[0], dist_full)) {
|
|
595
|
+
heap_replace_top<Cfloat>(
|
|
596
|
+
k, heap_dis, heap_ids, dist_full, db_idx);
|
|
597
|
+
}
|
|
598
|
+
}
|
|
521
599
|
} else {
|
|
522
|
-
|
|
523
|
-
|
|
524
|
-
|
|
525
|
-
|
|
526
|
-
|
|
527
|
-
|
|
528
|
-
|
|
529
|
-
|
|
530
|
-
|
|
531
|
-
|
|
532
|
-
|
|
533
|
-
|
|
534
|
-
|
|
535
|
-
|
|
536
|
-
|
|
537
|
-
|
|
538
|
-
|
|
600
|
+
const rabitq_utils::SignBitFactors& db_factors =
|
|
601
|
+
*reinterpret_cast<const rabitq_utils::SignBitFactors*>(
|
|
602
|
+
base_ptr);
|
|
603
|
+
|
|
604
|
+
float adjusted_distance =
|
|
605
|
+
rabitq_utils::compute_1bit_adjusted_distance(
|
|
606
|
+
normalized_distance,
|
|
607
|
+
db_factors,
|
|
608
|
+
query_factors_data,
|
|
609
|
+
rabitq_index->centered,
|
|
610
|
+
rabitq_index->qb,
|
|
611
|
+
rabitq_index->d);
|
|
612
|
+
|
|
613
|
+
// Add to heap if better than current worst
|
|
614
|
+
if (Cfloat::cmp(heap_dis[0], adjusted_distance)) {
|
|
615
|
+
heap_replace_top<Cfloat>(
|
|
616
|
+
k, heap_dis, heap_ids, adjusted_distance, db_idx);
|
|
617
|
+
}
|
|
539
618
|
}
|
|
540
619
|
}
|
|
620
|
+
|
|
621
|
+
// Update global stats atomically
|
|
622
|
+
#pragma omp atomic
|
|
623
|
+
rabitq_stats.n_1bit_evaluations += local_1bit_evaluations;
|
|
624
|
+
#pragma omp atomic
|
|
625
|
+
rabitq_stats.n_multibit_evaluations += local_multibit_evaluations;
|
|
541
626
|
}
|
|
542
627
|
|
|
543
628
|
template <class C, bool with_id_map>
|
|
@@ -557,8 +642,71 @@ void RaBitQHeapHandler<C, with_id_map>::end() {
|
|
|
557
642
|
}
|
|
558
643
|
}
|
|
559
644
|
|
|
645
|
+
template <class C, bool with_id_map>
|
|
646
|
+
float RaBitQHeapHandler<C, with_id_map>::compute_lower_bound(
|
|
647
|
+
float dist_1bit,
|
|
648
|
+
size_t db_idx,
|
|
649
|
+
size_t q) const {
|
|
650
|
+
// Access f_error directly from SignBitFactorsWithError in flat storage
|
|
651
|
+
const size_t storage_size = rabitq_index->compute_per_vector_storage_size();
|
|
652
|
+
const uint8_t* base_ptr =
|
|
653
|
+
rabitq_index->flat_storage.data() + db_idx * storage_size;
|
|
654
|
+
const SignBitFactorsWithError& db_factors =
|
|
655
|
+
*reinterpret_cast<const SignBitFactorsWithError*>(base_ptr);
|
|
656
|
+
float f_error = db_factors.f_error;
|
|
657
|
+
|
|
658
|
+
// Get g_error from query factors (query-dependent error term)
|
|
659
|
+
float g_error = 0.0f;
|
|
660
|
+
if (context.query_factors != nullptr) {
|
|
661
|
+
g_error = context.query_factors[q].g_error;
|
|
662
|
+
}
|
|
663
|
+
|
|
664
|
+
// Compute error adjustment: f_error * g_error
|
|
665
|
+
float error_adjustment = f_error * g_error;
|
|
666
|
+
|
|
667
|
+
return dist_1bit - error_adjustment;
|
|
668
|
+
}
|
|
669
|
+
|
|
670
|
+
template <class C, bool with_id_map>
|
|
671
|
+
float RaBitQHeapHandler<C, with_id_map>::compute_full_multibit_distance(
|
|
672
|
+
size_t db_idx,
|
|
673
|
+
size_t q) const {
|
|
674
|
+
const size_t ex_bits = rabitq_index->rabitq.nb_bits - 1;
|
|
675
|
+
const size_t dim = rabitq_index->d;
|
|
676
|
+
|
|
677
|
+
const size_t storage_size = rabitq_index->compute_per_vector_storage_size();
|
|
678
|
+
const uint8_t* base_ptr =
|
|
679
|
+
rabitq_index->flat_storage.data() + db_idx * storage_size;
|
|
680
|
+
|
|
681
|
+
const size_t ex_code_size = (dim * ex_bits + 7) / 8;
|
|
682
|
+
const uint8_t* ex_code = base_ptr + sizeof(SignBitFactorsWithError);
|
|
683
|
+
const ExtraBitsFactors& ex_fac = *reinterpret_cast<const ExtraBitsFactors*>(
|
|
684
|
+
base_ptr + sizeof(SignBitFactorsWithError) + ex_code_size);
|
|
685
|
+
|
|
686
|
+
// Get query factors reference (avoid copying)
|
|
687
|
+
const rabitq_utils::QueryFactorsData& query_factors =
|
|
688
|
+
context.query_factors[q];
|
|
689
|
+
|
|
690
|
+
// Get sign bits from FastScan packed format
|
|
691
|
+
std::vector<uint8_t> unpacked_code(rabitq_index->code_size);
|
|
692
|
+
CodePackerPQ4 packer(rabitq_index->M2, rabitq_index->bbs);
|
|
693
|
+
packer.unpack_1(rabitq_index->codes.get(), db_idx, unpacked_code.data());
|
|
694
|
+
const uint8_t* sign_bits = unpacked_code.data();
|
|
695
|
+
|
|
696
|
+
return rabitq_utils::compute_full_multibit_distance(
|
|
697
|
+
sign_bits,
|
|
698
|
+
ex_code,
|
|
699
|
+
ex_fac,
|
|
700
|
+
query_factors.rotated_q.data(),
|
|
701
|
+
query_factors.qr_to_c_L2sqr,
|
|
702
|
+
query_factors.qr_norm_L2sqr,
|
|
703
|
+
dim,
|
|
704
|
+
ex_bits,
|
|
705
|
+
rabitq_index->metric_type);
|
|
706
|
+
}
|
|
707
|
+
|
|
560
708
|
// Implementation of virtual make_knn_handler method
|
|
561
|
-
|
|
709
|
+
SIMDResultHandlerToFloat* IndexRaBitQFastScan::make_knn_handler(
|
|
562
710
|
bool is_max,
|
|
563
711
|
int /*impl*/,
|
|
564
712
|
idx_t n,
|
|
@@ -568,19 +716,16 @@ void* IndexRaBitQFastScan::make_knn_handler(
|
|
|
568
716
|
idx_t* labels,
|
|
569
717
|
const IDSelector* sel,
|
|
570
718
|
const FastScanDistancePostProcessing& context) const {
|
|
719
|
+
// Use runtime boolean for multi-bit mode
|
|
720
|
+
const bool multi_bit = rabitq.nb_bits > 1;
|
|
721
|
+
|
|
571
722
|
if (is_max) {
|
|
572
723
|
return new RaBitQHeapHandler<CMax<uint16_t, int>, false>(
|
|
573
|
-
this, n, k, distances, labels, sel, context);
|
|
724
|
+
this, n, k, distances, labels, sel, context, multi_bit);
|
|
574
725
|
} else {
|
|
575
726
|
return new RaBitQHeapHandler<CMin<uint16_t, int>, false>(
|
|
576
|
-
this, n, k, distances, labels, sel, context);
|
|
727
|
+
this, n, k, distances, labels, sel, context, multi_bit);
|
|
577
728
|
}
|
|
578
729
|
}
|
|
579
730
|
|
|
580
|
-
// Explicit template instantiations for the required comparator types
|
|
581
|
-
template struct RaBitQHeapHandler<CMin<uint16_t, int>, false>;
|
|
582
|
-
template struct RaBitQHeapHandler<CMax<uint16_t, int>, false>;
|
|
583
|
-
template struct RaBitQHeapHandler<CMin<uint16_t, int>, true>;
|
|
584
|
-
template struct RaBitQHeapHandler<CMax<uint16_t, int>, true>;
|
|
585
|
-
|
|
586
731
|
} // namespace faiss
|
|
@@ -11,6 +11,7 @@
|
|
|
11
11
|
|
|
12
12
|
#include <faiss/IndexFastScan.h>
|
|
13
13
|
#include <faiss/IndexRaBitQ.h>
|
|
14
|
+
#include <faiss/impl/RaBitQStats.h>
|
|
14
15
|
#include <faiss/impl/RaBitQUtils.h>
|
|
15
16
|
#include <faiss/impl/RaBitQuantizer.h>
|
|
16
17
|
#include <faiss/impl/simd_result_handlers.h>
|
|
@@ -20,8 +21,10 @@
|
|
|
20
21
|
namespace faiss {
|
|
21
22
|
|
|
22
23
|
// Import shared utilities from RaBitQUtils
|
|
23
|
-
using rabitq_utils::
|
|
24
|
+
using rabitq_utils::ExtraBitsFactors;
|
|
24
25
|
using rabitq_utils::QueryFactorsData;
|
|
26
|
+
using rabitq_utils::SignBitFactors;
|
|
27
|
+
using rabitq_utils::SignBitFactorsWithError;
|
|
25
28
|
|
|
26
29
|
/** Fast-scan version of RaBitQ index that processes 32 database vectors at a
|
|
27
30
|
* time using SIMD operations. Similar to IndexPQFastScan but adapted for
|
|
@@ -40,9 +43,16 @@ struct IndexRaBitQFastScan : IndexFastScan {
|
|
|
40
43
|
/// Center of all points (same as IndexRaBitQ)
|
|
41
44
|
std::vector<float> center;
|
|
42
45
|
|
|
43
|
-
///
|
|
44
|
-
///
|
|
45
|
-
|
|
46
|
+
/// Per-vector auxiliary data (1-bit codes stored separately in `codes`)
|
|
47
|
+
///
|
|
48
|
+
/// 1-bit codes (sign bits) are stored in the inherited `codes` array from
|
|
49
|
+
/// IndexFastScan in packed FastScan format for SIMD processing.
|
|
50
|
+
///
|
|
51
|
+
/// This flat_storage holds per-vector factors and refinement-bit codes:
|
|
52
|
+
/// Layout for 1-bit: [SignBitFactors (8 bytes)]
|
|
53
|
+
/// Layout for multi-bit: [SignBitFactorsWithError
|
|
54
|
+
/// (12B)][ref_codes][ExtraBitsFactors (8B)]
|
|
55
|
+
std::vector<uint8_t> flat_storage;
|
|
46
56
|
|
|
47
57
|
/// Default number of bits to quantize a query with
|
|
48
58
|
uint8_t qb = 8;
|
|
@@ -55,7 +65,8 @@ struct IndexRaBitQFastScan : IndexFastScan {
|
|
|
55
65
|
explicit IndexRaBitQFastScan(
|
|
56
66
|
idx_t d,
|
|
57
67
|
MetricType metric = METRIC_L2,
|
|
58
|
-
int bbs = 32
|
|
68
|
+
int bbs = 32,
|
|
69
|
+
uint8_t nb_bits = 1);
|
|
59
70
|
|
|
60
71
|
/// build from an existing IndexRaBitQ
|
|
61
72
|
explicit IndexRaBitQFastScan(const IndexRaBitQ& orig, int bbs = 32);
|
|
@@ -66,6 +77,9 @@ struct IndexRaBitQFastScan : IndexFastScan {
|
|
|
66
77
|
|
|
67
78
|
void compute_codes(uint8_t* codes, idx_t n, const float* x) const override;
|
|
68
79
|
|
|
80
|
+
/// Compute storage size per vector in flat_storage
|
|
81
|
+
size_t compute_per_vector_storage_size() const;
|
|
82
|
+
|
|
69
83
|
void compute_float_LUT(
|
|
70
84
|
float* lut,
|
|
71
85
|
idx_t n,
|
|
@@ -83,7 +97,7 @@ struct IndexRaBitQFastScan : IndexFastScan {
|
|
|
83
97
|
const SearchParameters* params = nullptr) const override;
|
|
84
98
|
|
|
85
99
|
/// Override to create RaBitQ-specific handlers
|
|
86
|
-
|
|
100
|
+
SIMDResultHandlerToFloat* make_knn_handler(
|
|
87
101
|
bool is_max,
|
|
88
102
|
int /*impl*/,
|
|
89
103
|
idx_t n,
|
|
@@ -108,6 +122,8 @@ struct IndexRaBitQFastScan : IndexFastScan {
|
|
|
108
122
|
* - Direct heap integration (no intermediate result storage)
|
|
109
123
|
* - Batch-level computation of normalizers and query factors
|
|
110
124
|
* - Preserves exact mathematical equivalence to original RaBitQ distances
|
|
125
|
+
* - Runtime boolean for multi-bit support
|
|
126
|
+
*
|
|
111
127
|
* @tparam C Comparator type (CMin/CMax) for heap operations
|
|
112
128
|
* @tparam with_id_map Whether to use id mapping (similar to HeapHandler)
|
|
113
129
|
*/
|
|
@@ -122,7 +138,8 @@ struct RaBitQHeapHandler
|
|
|
122
138
|
int64_t* heap_labels; // [nq * k]
|
|
123
139
|
const size_t nq, k;
|
|
124
140
|
const FastScanDistancePostProcessing&
|
|
125
|
-
context;
|
|
141
|
+
context; // Processing context with query offset
|
|
142
|
+
const bool is_multi_bit; // Runtime flag for multi-bit mode
|
|
126
143
|
|
|
127
144
|
// Use float-based comparator for heap operations
|
|
128
145
|
using Cfloat = typename std::conditional<
|
|
@@ -137,13 +154,22 @@ struct RaBitQHeapHandler
|
|
|
137
154
|
float* distances,
|
|
138
155
|
int64_t* labels,
|
|
139
156
|
const IDSelector* sel_in,
|
|
140
|
-
const FastScanDistancePostProcessing& context
|
|
157
|
+
const FastScanDistancePostProcessing& context,
|
|
158
|
+
bool multi_bit);
|
|
141
159
|
|
|
142
|
-
void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1)
|
|
160
|
+
void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) override;
|
|
143
161
|
|
|
144
162
|
void begin(const float* norms);
|
|
145
163
|
|
|
146
164
|
void end();
|
|
165
|
+
|
|
166
|
+
private:
|
|
167
|
+
/// Compute full multi-bit distance for a candidate vector (multi-bit only)
|
|
168
|
+
float compute_full_multibit_distance(size_t db_idx, size_t q) const;
|
|
169
|
+
|
|
170
|
+
/// Compute lower bound using 1-bit distance and error bound (multi-bit
|
|
171
|
+
/// only)
|
|
172
|
+
float compute_lower_bound(float dist_1bit, size_t db_idx, size_t q) const;
|
|
147
173
|
};
|
|
148
174
|
|
|
149
175
|
} // namespace faiss
|
|
@@ -341,4 +341,53 @@ void IndexRefineFlat::search(
|
|
|
341
341
|
}
|
|
342
342
|
}
|
|
343
343
|
|
|
344
|
+
/***************************************************
|
|
345
|
+
* IndexRefinePanorama
|
|
346
|
+
***************************************************/
|
|
347
|
+
|
|
348
|
+
void IndexRefinePanorama::search(
|
|
349
|
+
idx_t n,
|
|
350
|
+
const float* x,
|
|
351
|
+
idx_t k,
|
|
352
|
+
float* distances,
|
|
353
|
+
idx_t* labels,
|
|
354
|
+
const SearchParameters* params_in) const {
|
|
355
|
+
const IndexRefineSearchParameters* params = nullptr;
|
|
356
|
+
if (params_in) {
|
|
357
|
+
params = dynamic_cast<const IndexRefineSearchParameters*>(params_in);
|
|
358
|
+
FAISS_THROW_IF_NOT_MSG(
|
|
359
|
+
params, "IndexRefineFlat params have incorrect type");
|
|
360
|
+
}
|
|
361
|
+
|
|
362
|
+
idx_t k_base = (params != nullptr) ? idx_t(k * params->k_factor)
|
|
363
|
+
: idx_t(k * k_factor);
|
|
364
|
+
SearchParameters* base_index_params =
|
|
365
|
+
(params != nullptr) ? params->base_index_params : nullptr;
|
|
366
|
+
|
|
367
|
+
FAISS_THROW_IF_NOT(k_base >= k);
|
|
368
|
+
|
|
369
|
+
FAISS_THROW_IF_NOT(base_index);
|
|
370
|
+
FAISS_THROW_IF_NOT(refine_index);
|
|
371
|
+
|
|
372
|
+
FAISS_THROW_IF_NOT(k > 0);
|
|
373
|
+
FAISS_THROW_IF_NOT(is_trained);
|
|
374
|
+
|
|
375
|
+
std::unique_ptr<idx_t[]> del1;
|
|
376
|
+
std::unique_ptr<float[]> del2;
|
|
377
|
+
idx_t* base_labels = new idx_t[n * k_base];
|
|
378
|
+
float* base_distances = new float[n * k_base];
|
|
379
|
+
del1.reset(base_labels);
|
|
380
|
+
del2.reset(base_distances);
|
|
381
|
+
|
|
382
|
+
base_index->search(
|
|
383
|
+
n, x, k_base, base_distances, base_labels, base_index_params);
|
|
384
|
+
|
|
385
|
+
for (int i = 0; i < n * k_base; i++) {
|
|
386
|
+
assert(base_labels[i] >= -1 && base_labels[i] < ntotal);
|
|
387
|
+
}
|
|
388
|
+
|
|
389
|
+
refine_index->search_subset(
|
|
390
|
+
n, x, k_base, base_labels, k, distances, labels);
|
|
391
|
+
}
|
|
392
|
+
|
|
344
393
|
} // namespace faiss
|
|
@@ -95,4 +95,21 @@ struct IndexRefineFlat : IndexRefine {
|
|
|
95
95
|
const SearchParameters* params = nullptr) const override;
|
|
96
96
|
};
|
|
97
97
|
|
|
98
|
+
/** Version where the search calls search_subset, allowing for Panorama
|
|
99
|
+
* refinement. */
|
|
100
|
+
struct IndexRefinePanorama : IndexRefine {
|
|
101
|
+
explicit IndexRefinePanorama(Index* base_index, Index* refine_index)
|
|
102
|
+
: IndexRefine(base_index, refine_index) {}
|
|
103
|
+
|
|
104
|
+
IndexRefinePanorama() : IndexRefine() {}
|
|
105
|
+
|
|
106
|
+
void search(
|
|
107
|
+
idx_t n,
|
|
108
|
+
const float* x,
|
|
109
|
+
idx_t k,
|
|
110
|
+
float* distances,
|
|
111
|
+
idx_t* labels,
|
|
112
|
+
const SearchParameters* params = nullptr) const override;
|
|
113
|
+
};
|
|
114
|
+
|
|
98
115
|
} // namespace faiss
|