faiss 0.5.0 → 0.5.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +5 -0
- data/README.md +2 -0
- data/ext/faiss/index.cpp +8 -0
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/faiss/IVFlib.cpp +25 -49
- data/vendor/faiss/faiss/Index.cpp +11 -0
- data/vendor/faiss/faiss/Index.h +24 -1
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +1 -0
- data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +5 -1
- data/vendor/faiss/faiss/IndexFastScan.cpp +1 -1
- data/vendor/faiss/faiss/IndexFastScan.h +3 -8
- data/vendor/faiss/faiss/IndexFlat.cpp +374 -4
- data/vendor/faiss/faiss/IndexFlat.h +80 -0
- data/vendor/faiss/faiss/IndexHNSW.cpp +90 -1
- data/vendor/faiss/faiss/IndexHNSW.h +57 -1
- data/vendor/faiss/faiss/IndexIVFFlatPanorama.cpp +34 -149
- data/vendor/faiss/faiss/IndexIVFRaBitQ.cpp +86 -2
- data/vendor/faiss/faiss/IndexIVFRaBitQ.h +3 -1
- data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.cpp +293 -115
- data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.h +52 -16
- data/vendor/faiss/faiss/IndexPQ.cpp +4 -1
- data/vendor/faiss/faiss/IndexPreTransform.cpp +14 -0
- data/vendor/faiss/faiss/IndexPreTransform.h +9 -0
- data/vendor/faiss/faiss/IndexRaBitQ.cpp +96 -16
- data/vendor/faiss/faiss/IndexRaBitQ.h +5 -1
- data/vendor/faiss/faiss/IndexRaBitQFastScan.cpp +238 -93
- data/vendor/faiss/faiss/IndexRaBitQFastScan.h +35 -9
- data/vendor/faiss/faiss/IndexRefine.cpp +49 -0
- data/vendor/faiss/faiss/IndexRefine.h +17 -0
- data/vendor/faiss/faiss/clone_index.cpp +2 -0
- data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +3 -1
- data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +1 -1
- data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +1 -1
- data/vendor/faiss/faiss/impl/DistanceComputer.h +74 -3
- data/vendor/faiss/faiss/impl/HNSW.cpp +294 -15
- data/vendor/faiss/faiss/impl/HNSW.h +31 -2
- data/vendor/faiss/faiss/impl/IDSelector.h +3 -3
- data/vendor/faiss/faiss/impl/Panorama.cpp +193 -0
- data/vendor/faiss/faiss/impl/Panorama.h +204 -0
- data/vendor/faiss/faiss/impl/RaBitQStats.cpp +29 -0
- data/vendor/faiss/faiss/impl/RaBitQStats.h +56 -0
- data/vendor/faiss/faiss/impl/RaBitQUtils.cpp +54 -6
- data/vendor/faiss/faiss/impl/RaBitQUtils.h +183 -6
- data/vendor/faiss/faiss/impl/RaBitQuantizer.cpp +269 -84
- data/vendor/faiss/faiss/impl/RaBitQuantizer.h +71 -4
- data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.cpp +362 -0
- data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.h +112 -0
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +6 -9
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +1 -3
- data/vendor/faiss/faiss/impl/index_read.cpp +156 -12
- data/vendor/faiss/faiss/impl/index_write.cpp +142 -19
- data/vendor/faiss/faiss/impl/platform_macros.h +12 -0
- data/vendor/faiss/faiss/impl/svs_io.cpp +86 -0
- data/vendor/faiss/faiss/impl/svs_io.h +67 -0
- data/vendor/faiss/faiss/index_factory.cpp +182 -15
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +1 -1
- data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -1
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +18 -109
- data/vendor/faiss/faiss/invlists/InvertedLists.h +2 -18
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +1 -1
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +1 -1
- data/vendor/faiss/faiss/svs/IndexSVSFaissUtils.h +261 -0
- data/vendor/faiss/faiss/svs/IndexSVSFlat.cpp +117 -0
- data/vendor/faiss/faiss/svs/IndexSVSFlat.h +66 -0
- data/vendor/faiss/faiss/svs/IndexSVSVamana.cpp +245 -0
- data/vendor/faiss/faiss/svs/IndexSVSVamana.h +137 -0
- data/vendor/faiss/faiss/svs/IndexSVSVamanaLVQ.cpp +39 -0
- data/vendor/faiss/faiss/svs/IndexSVSVamanaLVQ.h +42 -0
- data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.cpp +149 -0
- data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.h +58 -0
- data/vendor/faiss/faiss/utils/distances.cpp +0 -3
- data/vendor/faiss/faiss/utils/utils.cpp +4 -0
- metadata +18 -1
The hunks below are from data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.cpp.

@@ -13,6 +13,7 @@
 #include <faiss/impl/FaissAssert.h>
 #include <faiss/impl/FastScanDistancePostProcessing.h>
 #include <faiss/impl/RaBitQUtils.h>
+#include <faiss/impl/RaBitQuantizerMultiBit.h>
 #include <faiss/impl/pq4_fast_scan.h>
 #include <faiss/impl/simd_result_handlers.h>
 #include <faiss/invlists/BlockInvertedLists.h>
@@ -22,8 +23,10 @@
 namespace faiss {

 // Import shared utilities from RaBitQUtils
-using rabitq_utils::
+using rabitq_utils::ExtraBitsFactors;
 using rabitq_utils::QueryFactorsData;
+using rabitq_utils::SignBitFactors;
+using rabitq_utils::SignBitFactorsWithError;

 inline size_t roundup(size_t a, size_t b) {
     return (a + b - 1) / b * b;
@@ -41,9 +44,10 @@ IndexIVFRaBitQFastScan::IndexIVFRaBitQFastScan(
         size_t nlist,
         MetricType metric,
         int bbs,
-        bool own_invlists
+        bool own_invlists,
+        uint8_t nb_bits)
         : IndexIVFFastScan(quantizer, d, nlist, 0, metric, own_invlists),
-          rabitq(d, metric) {
+          rabitq(d, metric, nb_bits) {
     FAISS_THROW_IF_NOT_MSG(d > 0, "Dimension must be positive");
     FAISS_THROW_IF_NOT_MSG(
            metric == METRIC_L2 || metric == METRIC_INNER_PRODUCT,
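The constructor change above means the number of RaBitQ quantization bits is now chosen at construction time. A minimal usage sketch, assuming the leading parameters keep the usual faiss IVF order (coarse quantizer, d, nlist), as the initializer list suggests, and default bbs/own_invlists values:

    #include <faiss/IndexFlat.h>
    #include <faiss/IndexIVFRaBitQFastScan.h>

    int d = 128, nlist = 1024;
    faiss::IndexFlatL2 coarse(d);
    // nb_bits = 4 -> 1 sign bit per dimension plus 3 "ex" bits
    faiss::IndexIVFRaBitQFastScan index(
            &coarse, d, nlist, faiss::METRIC_L2,
            /*bbs=*/32, /*own_invlists=*/true, /*nb_bits=*/4);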
@@ -66,9 +70,9 @@ IndexIVFRaBitQFastScan::IndexIVFRaBitQFastScan(
     this->ksub = (1 << nbits_fastscan);
     this->M2 = roundup(M_fastscan, 2);

-    //
+    // Compute code_size: bit_pattern + per-vector storage (factors/ex-codes)
     const size_t bit_pattern_size = (d + 7) / 8;
-    this->code_size = bit_pattern_size +
+    this->code_size = bit_pattern_size + compute_per_vector_storage_size();

     is_trained = false;

@@ -76,7 +80,7 @@ IndexIVFRaBitQFastScan::IndexIVFRaBitQFastScan(
         replace_invlists(new BlockInvertedLists(nlist, get_CodePacker()), true);
     }

-
+    flat_storage.clear();
 }

 // Constructor that converts an existing IndexIVFRaBitQ to FastScan format
@@ -92,20 +96,35 @@ IndexIVFRaBitQFastScan::IndexIVFRaBitQFastScan(
                   false),
           rabitq(orig.rabitq) {}

+size_t IndexIVFRaBitQFastScan::compute_per_vector_storage_size() const {
+    const size_t ex_bits = rabitq.nb_bits - 1;
+
+    if (ex_bits == 0) {
+        // 1-bit: only SignBitFactors (8 bytes)
+        return sizeof(SignBitFactors);
+    } else {
+        // Multi-bit: SignBitFactorsWithError + ExtraBitsFactors + ex-codes
+        return sizeof(SignBitFactorsWithError) + sizeof(ExtraBitsFactors) +
+                (d * ex_bits + 7) / 8;
+    }
+}
+
 void IndexIVFRaBitQFastScan::preprocess_code_metadata(
         idx_t n,
         const uint8_t* flat_codes,
         idx_t start_global_idx) {
-    //
-    const size_t
-
+    // Unified approach: always use flat_storage for both 1-bit and multi-bit
+    const size_t storage_size = compute_per_vector_storage_size();
+    flat_storage.resize((start_global_idx + n) * storage_size);

+    // Copy factors data directly to flat storage (no reordering needed)
+    const size_t bit_pattern_size = (d + 7) / 8;
     for (idx_t i = 0; i < n; i++) {
         const uint8_t* code = flat_codes + i * code_size;
-        const uint8_t*
-
-
-
+        const uint8_t* source_factors_ptr = code + bit_pattern_size;
+        uint8_t* storage =
+                flat_storage.data() + (start_global_idx + i) * storage_size;
+        memcpy(storage, source_factors_ptr, storage_size);
     }
 }

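A worked example of the storage-size computation above (a sketch; sizeof(ExtraBitsFactors) is not spelled out in this excerpt): for d = 128 and nb_bits = 1, each vector keeps only the 8-byte SignBitFactors; for nb_bits = 4, it keeps the 12-byte SignBitFactorsWithError, one ExtraBitsFactors block, and (128 * 3 + 7) / 8 = 48 bytes of packed ex-bit codes. The (d + 7) / 8 = 16-byte sign-bit pattern is accounted for separately in code_size.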
@@ -143,7 +162,7 @@ void IndexIVFRaBitQFastScan::encode_vectors(
     size_t total_code_size = code_size + coarse_size;
     memset(codes, 0, total_code_size * n);

-    const size_t
+    const size_t ex_bits = rabitq.nb_bits - 1;

 #pragma omp parallel if (n > 1000)
     {
@@ -161,16 +180,61 @@ void IndexIVFRaBitQFastScan::encode_vectors(
             // Reconstruct centroid for residual computation
             quantizer->reconstruct(list_no, centroid.data());

-
-            encode_vector_to_fastscan(xi, centroid.data(), fastscan_code);
+            const size_t bit_pattern_size = (d + 7) / 8;

-            //
-
-
-
+            // Pack sign bits directly into FastScan format (inline)
+            for (size_t j = 0; j < d; j++) {
+                const float or_minus_c = xi[j] - centroid[j];
+                if (or_minus_c > 0.0f) {
+                    rabitq_utils::set_bit_fastscan(fastscan_code, j);
+                }
+            }
+
+            // Compute factors (with or without f_error depending on mode)
+            SignBitFactorsWithError factors =
+                    rabitq_utils::compute_vector_factors(
+                            xi,
+                            d,
+                            centroid.data(),
+                            rabitq.metric_type,
+                            ex_bits > 0);
+
+            if (ex_bits == 0) {
+                // 1-bit: store only SignBitFactors (8 bytes)
+                memcpy(fastscan_code + bit_pattern_size,
+                       &factors,
+                       sizeof(SignBitFactors));
+            } else {
+                // Multi-bit: store full SignBitFactorsWithError (12 bytes)
+                memcpy(fastscan_code + bit_pattern_size,
+                       &factors,
+                       sizeof(SignBitFactorsWithError));
+
+                // Compute residual (needed for quantize_ex_bits)
+                std::vector<float> residual(d);
+                for (size_t j = 0; j < d; j++) {
+                    residual[j] = xi[j] - centroid[j];
+                }

-
-
+                // Quantize ex-bits
+                const size_t ex_code_size = (d * ex_bits + 7) / 8;
+                uint8_t* ex_code = fastscan_code + bit_pattern_size +
+                        sizeof(SignBitFactorsWithError);
+                ExtraBitsFactors ex_factors_temp;
+
+                rabitq_multibit::quantize_ex_bits(
+                        residual.data(),
+                        d,
+                        rabitq.nb_bits,
+                        ex_code,
+                        ex_factors_temp,
+                        rabitq.metric_type,
+                        centroid.data());
+
+                memcpy(ex_code + ex_code_size,
+                       &ex_factors_temp,
+                       sizeof(ExtraBitsFactors));
+            }

             // Include coarse codes if requested
             if (include_listnos) {
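Putting the pieces of encode_vectors together, each per-vector code now has the following layout (a sketch based on the offsets used above; only the 8- and 12-byte factor sizes are stated in the diff):

    // sign bits     : (d + 7) / 8 bytes
    // nb_bits == 1  : + SignBitFactors (8 bytes)
    // nb_bits  > 1  : + SignBitFactorsWithError (12 bytes)
    //                 + ex-codes, (d * (nb_bits - 1) + 7) / 8 bytes
    //                 + ExtraBitsFactors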
@@ -181,24 +245,6 @@ void IndexIVFRaBitQFastScan::encode_vectors(
     }
 }

-void IndexIVFRaBitQFastScan::encode_vector_to_fastscan(
-        const float* xi,
-        const float* centroid,
-        uint8_t* fastscan_code) const {
-    memset(fastscan_code, 0, code_size);
-
-    for (size_t j = 0; j < d; j++) {
-        const float x_val = xi[j];
-        const float centroid_val = (centroid != nullptr) ? centroid[j] : 0.0f;
-        const float or_minus_c = x_val - centroid_val;
-        const bool xb = (or_minus_c > 0.0f);
-
-        if (xb) {
-            rabitq_utils::set_bit_fastscan(fastscan_code, j);
-        }
-    }
-}
-
 bool IndexIVFRaBitQFastScan::lookup_table_is_3d() const {
     return true;
 }
@@ -231,6 +277,11 @@ void IndexIVFRaBitQFastScan::compute_residual_LUT(
         query_factors.qr_norm_L2sqr = fvec_norm_L2sqr(original_query, d);
     }

+    const size_t ex_bits = rabitq.nb_bits - 1;
+    if (ex_bits > 0) {
+        query_factors.rotated_q = rotated_q;
+    }
+
     if (centered) {
         const float max_code_value = (1 << qb) - 1;

@@ -352,7 +403,7 @@ void IndexIVFRaBitQFastScan::compute_LUT(
                 x + i * d);

         // Store query factors using compact indexing (ij directly)
-        if (context.query_factors) {
+        if (context.query_factors != nullptr) {
             context.query_factors[ij] = query_factors_data;
         }

@@ -367,52 +418,56 @@ void IndexIVFRaBitQFastScan::reconstruct_from_offset(
         int64_t list_no,
         int64_t offset,
         float* recons) const {
-    //
-
+    // Get centroid for this list
+    std::vector<float> centroid(d);
+    quantizer->reconstruct(list_no, centroid.data());
+
+    // Unpack bit pattern from packed format
     const size_t bit_pattern_size = (d + 7) / 8;
-    std::vector<uint8_t>
-            coarse_size + bit_pattern_size + sizeof(FactorsData), 0);
+    std::vector<uint8_t> fastscan_code(bit_pattern_size, 0);

-    encode_listno(list_no, code.data());
     InvertedLists::ScopedCodes list_codes(invlists, list_no);
-
-    // Unpack the bit pattern from packed format to FastScan layout
-    uint8_t* fastscan_code = code.data() + coarse_size;
     for (size_t m = 0; m < M; m++) {
         uint8_t c =
                 pq4_get_packed_element(list_codes.get(), bbs, M2, offset, m);

-        // Write the 4-bit code value to FastScan format
-        // Each byte stores two 4-bit codes (lower and upper nibbles)
         size_t byte_idx = m / 2;
         if (m % 2 == 0) {
-            // Even m: write to lower 4 bits
             fastscan_code[byte_idx] =
                     (fastscan_code[byte_idx] & 0xF0) | (c & 0x0F);
         } else {
-            // Odd m: write to upper 4 bits
             fastscan_code[byte_idx] =
                     (fastscan_code[byte_idx] & 0x0F) | ((c & 0x0F) << 4);
         }
     }

-    // Get
-    // Need to look up the ID from inverted lists
+    // Get dp_multiplier directly from flat_storage
     InvertedLists::ScopedIds list_ids(invlists, list_no);
     idx_t global_id = list_ids[offset];

-
-    if (global_id >= 0
-
-    const
-
-
-
-
+    float dp_multiplier = 1.0f;
+    if (global_id >= 0) {
+        const size_t storage_size = compute_per_vector_storage_size();
+        const size_t storage_capacity = flat_storage.size() / storage_size;
+
+        if (static_cast<size_t>(global_id) < storage_capacity) {
+            const uint8_t* base_ptr =
+                    flat_storage.data() + global_id * storage_size;
+            const auto& base_factors =
+                    *reinterpret_cast<const SignBitFactors*>(base_ptr);
+            dp_multiplier = base_factors.dp_multiplier;
+        }
     }

-    //
-
+    // Decode residual directly using dp_multiplier
+    std::vector<float> residual(d);
+    decode_fastscan_to_residual(
+            fastscan_code.data(), residual.data(), dp_multiplier);
+
+    // Reconstruct: x = centroid + residual
+    for (size_t j = 0; j < d; j++) {
+        recons[j] = centroid[j] + residual[j];
+    }
 }

 void IndexIVFRaBitQFastScan::sa_decode(idx_t n, const uint8_t* bytes, float* x)
@@ -426,6 +481,7 @@ void IndexIVFRaBitQFastScan::sa_decode(idx_t n, const uint8_t* bytes, float* x)
     size_t total_code_size = code_size + coarse_size;
     std::vector<float> centroid(d);
     std::vector<float> residual(d);
+    const size_t bit_pattern_size = (d + 7) / 8;

 #pragma omp parallel for if (n > 1000)
     for (idx_t i = 0; i < n; i++) {
@@ -439,7 +495,12 @@ void IndexIVFRaBitQFastScan::sa_decode(idx_t n, const uint8_t* bytes, float* x)

         const uint8_t* fastscan_code = code_i + coarse_size;

-
+        const uint8_t* factors_ptr = fastscan_code + bit_pattern_size;
+        const auto& base_factors =
+                *reinterpret_cast<const SignBitFactors*>(factors_ptr);
+
+        decode_fastscan_to_residual(
+                fastscan_code, residual.data(), base_factors.dp_multiplier);

         for (size_t j = 0; j < d; j++) {
             x_i[j] = centroid[j] + residual[j];
@@ -452,23 +513,17 @@ void IndexIVFRaBitQFastScan::sa_decode(idx_t n, const uint8_t* bytes, float* x)

 void IndexIVFRaBitQFastScan::decode_fastscan_to_residual(
         const uint8_t* fastscan_code,
-        float* residual
+        float* residual,
+        float dp_multiplier) const {
     memset(residual, 0, sizeof(float) * d);

     const float inv_d_sqrt = (d == 0) ? 1.0f : (1.0f / std::sqrt((float)d));
-    const size_t bit_pattern_size = (d + 7) / 8;
-
-    // Extract factors directly from embedded codes
-    const uint8_t* factors_ptr = fastscan_code + bit_pattern_size;
-    const FactorsData& fac = *reinterpret_cast<const FactorsData*>(factors_ptr);

     for (size_t j = 0; j < d; j++) {
-        // Use RaBitQUtils for consistent bit extraction
         bool bit_value = rabitq_utils::extract_bit_fastscan(fastscan_code, j);

         float bit_as_float = bit_value ? 1.0f : 0.0f;
-        residual[j] =
-                (bit_as_float - 0.5f) * fac.dp_multiplier * 2 * inv_d_sqrt;
+        residual[j] = (bit_as_float - 0.5f) * dp_multiplier * 2 * inv_d_sqrt;
     }
 }

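The decode step above inverts the 1-bit code per dimension: with sign bit b_j in {0, 1}, residual[j] = (b_j - 0.5) * dp_multiplier * 2 / sqrt(d), i.e. +/- dp_multiplier / sqrt(d). As a worked example, for d = 128 and dp_multiplier = 1.5 each coordinate decodes to roughly +/- 0.133; reconstruct_from_offset and sa_decode then add the list centroid back on top.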
@@ -483,12 +538,15 @@ SIMDResultHandlerToFloat* IndexIVFRaBitQFastScan::make_knn_handler(
         const IDSelector* /* sel */,
         const FastScanDistancePostProcessing& context,
         const float* /* normalizers */) const {
+    const size_t ex_bits = rabitq.nb_bits - 1;
+    const bool is_multibit = ex_bits > 0;
+
     if (is_max) {
         return new IVFRaBitQHeapHandler<CMax<uint16_t, int64_t>>(
-                this, n, k, distances, labels, &context);
+                this, n, k, distances, labels, &context, is_multibit);
     } else {
         return new IVFRaBitQHeapHandler<CMin<uint16_t, int64_t>>(
-                this, n, k, distances, labels, &context);
+                this, n, k, distances, labels, &context, is_multibit);
     }
 }

@@ -503,7 +561,8 @@ IndexIVFRaBitQFastScan::IVFRaBitQHeapHandler<C>::IVFRaBitQHeapHandler(
         size_t k_val,
         float* distances,
         int64_t* labels,
-        const FastScanDistancePostProcessing* ctx
+        const FastScanDistancePostProcessing* ctx,
+        bool multibit)
         : simd_result_handlers::ResultHandlerCompare<C, true>(
                 nq_val,
                 0,
@@ -513,7 +572,8 @@ IndexIVFRaBitQFastScan::IVFRaBitQHeapHandler<C>::IVFRaBitQHeapHandler(
           heap_labels(labels),
           nq(nq_val),
           k(k_val),
-          context(ctx)
+          context(ctx),
+          is_multibit(multibit) {
     current_list_no = 0;
     probe_indices.clear();

@@ -572,8 +632,15 @@ void IndexIVFRaBitQFastScan::IVFRaBitQHeapHandler<C>::handle(
     }

     size_t max_positions = std::min<size_t>(32, this->ntotal - idx_base);
+
+    // Stats tracking for two-stage search
+    // n_1bit_evaluations: candidates evaluated using 1-bit lower bound
+    // n_multibit_evaluations: candidates requiring full multi-bit distance
+    size_t local_1bit_evaluations = 0;
+    size_t local_multibit_evaluations = 0;
+
     // Process each candidate vector in the SIMD batch
-    for (
+    for (size_t j = 0; j < max_positions; j++) {
         const int64_t result_id = this->adjust_id(b, j);

         if (result_id < 0) {
@@ -582,39 +649,81 @@ void IndexIVFRaBitQFastScan::IVFRaBitQHeapHandler<C>::handle(

         const float normalized_distance = d32tab[j] * one_a + bias;

-        // Get database factors
-
-        const
-
-
-
-
-
-
-
-
-
-
-
+        // Get database factors from flat_storage
+        const size_t storage_size = index->compute_per_vector_storage_size();
+        const uint8_t* base_ptr =
+                index->flat_storage.data() + result_id * storage_size;
+
+        if (is_multibit) {
+            // Track candidates actually considered for two-stage filtering
+            local_1bit_evaluations++;
+
+            // Multi-bit: use SignBitFactorsWithError and two-stage search
+            const SignBitFactorsWithError& full_factors =
+                    *reinterpret_cast<const SignBitFactorsWithError*>(base_ptr);
+
+            // Compute 1-bit adjusted distance using shared helper
+            float dist_1bit = rabitq_utils::compute_1bit_adjusted_distance(
+                    normalized_distance,
+                    full_factors,
+                    query_factors,
+                    index->centered,
+                    index->qb,
+                    index->d);
+
+            // Compute lower bound using error bound
+            float lower_bound =
+                    compute_lower_bound(dist_1bit, result_id, local_q, q);
+
+            // Adaptive filtering: decide whether to compute full distance
+            const bool is_similarity =
+                    index->metric_type == MetricType::METRIC_INNER_PRODUCT;
+            bool should_refine = is_similarity
+                    ? (lower_bound > heap_dis[0]) // IP: keep if better
+                    : (lower_bound < heap_dis[0]); // L2: keep if better
+
+            if (should_refine) {
+                local_multibit_evaluations++;
+
+                // Compute local_offset: position within current inverted list
+                size_t local_offset = this->j0 + b * 32 + j;
+
+                // Compute full multi-bit distance
+                float dist_full = compute_full_multibit_distance(
+                        result_id, local_q, q, local_offset);
+
+                // Update heap if this distance is better
+                if (Cfloat::cmp(heap_dis[0], dist_full)) {
+                    heap_replace_top<Cfloat>(
+                            k, heap_dis, heap_ids, dist_full, result_id);
+                }
+            }
         } else {
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+            const auto& db_factors =
+                    *reinterpret_cast<const SignBitFactors*>(base_ptr);
+
+            // Compute adjusted distance using shared helper
+            float adjusted_distance =
+                    rabitq_utils::compute_1bit_adjusted_distance(
+                            normalized_distance,
+                            db_factors,
+                            query_factors,
+                            index->centered,
+                            index->qb,
+                            index->d);
+
+            if (Cfloat::cmp(heap_dis[0], adjusted_distance)) {
+                heap_replace_top<Cfloat>(
+                        k, heap_dis, heap_ids, adjusted_distance, result_id);
+            }
         }
     }
+
+    // Update global stats atomically
+#pragma omp atomic
+    rabitq_stats.n_1bit_evaluations += local_1bit_evaluations;
+#pragma omp atomic
+    rabitq_stats.n_multibit_evaluations += local_multibit_evaluations;
 }

 template <class C>
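The handler above implements a two-stage RaBitQ search: the cheap 1-bit estimate minus an error bound gives an optimistic lower bound, and the exact multi-bit distance is only computed when that bound could still beat the current heap top. A minimal sketch of the decision logic, as a hypothetical free-standing helper rather than the faiss API:

    // Sketch: adaptive refinement decision per candidate. est_1bit and err are
    // the 1-bit distance estimate and its error bound (f_error * g_error above).
    bool needs_refinement(
            float est_1bit, float err, float heap_top, bool inner_product) {
        float lower_bound = est_1bit - err;
        return inner_product ? lower_bound > heap_top  // IP: larger is better
                             : lower_bound < heap_top; // L2: smaller is better
    }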
@@ -641,10 +750,79 @@ void IndexIVFRaBitQFastScan::IVFRaBitQHeapHandler<C>::end() {
     }
 }

-
-
-
-
-
+template <class C>
+float IndexIVFRaBitQFastScan::IVFRaBitQHeapHandler<C>::compute_lower_bound(
+        float dist_1bit,
+        size_t db_idx,
+        size_t local_q,
+        size_t global_q) const {
+    // Access f_error from SignBitFactorsWithError in flat storage
+    const size_t storage_size = index->compute_per_vector_storage_size();
+    const uint8_t* base_ptr =
+            index->flat_storage.data() + db_idx * storage_size;
+    const SignBitFactorsWithError& db_factors =
+            *reinterpret_cast<const SignBitFactorsWithError*>(base_ptr);
+    float f_error = db_factors.f_error;
+
+    // Get g_error from query factors
+    // Use local_q to access probe_indices (batch-local), global_q for storage
+    float g_error = 0.0f;
+    if (context && context->query_factors) {
+        size_t probe_rank = probe_indices[local_q];
+        size_t nprobe = context->nprobe > 0 ? context->nprobe : index->nprobe;
+        size_t storage_idx = global_q * nprobe + probe_rank;
+        g_error = context->query_factors[storage_idx].g_error;
+    }
+
+    // Compute error adjustment: f_error * g_error
+    float error_adjustment = f_error * g_error;
+
+    return dist_1bit - error_adjustment;
+}
+
+template <class C>
+float IndexIVFRaBitQFastScan::IVFRaBitQHeapHandler<C>::
+        compute_full_multibit_distance(
+                size_t db_idx,
+                size_t local_q,
+                size_t global_q,
+                size_t local_offset) const {
+    const size_t ex_bits = index->rabitq.nb_bits - 1;
+    const size_t dim = index->d;
+
+    const size_t storage_size = index->compute_per_vector_storage_size();
+    const uint8_t* base_ptr =
+            index->flat_storage.data() + db_idx * storage_size;
+
+    const size_t ex_code_size = (dim * ex_bits + 7) / 8;
+    const uint8_t* ex_code = base_ptr + sizeof(SignBitFactorsWithError);
+    const ExtraBitsFactors& ex_fac = *reinterpret_cast<const ExtraBitsFactors*>(
+            base_ptr + sizeof(SignBitFactorsWithError) + ex_code_size);
+
+    // Use local_q to access probe_indices (batch-local), global_q for storage
+    size_t probe_rank = probe_indices[local_q];
+    size_t nprobe = context->nprobe > 0 ? context->nprobe : index->nprobe;
+    size_t storage_idx = global_q * nprobe + probe_rank;
+    const auto& query_factors = context->query_factors[storage_idx];
+
+    size_t list_no = current_list_no;
+    InvertedLists::ScopedCodes list_codes(index->invlists, list_no);
+
+    std::vector<uint8_t> unpacked_code(index->code_size);
+    CodePackerPQ4 packer(index->M2, index->bbs);
+    packer.unpack_1(list_codes.get(), local_offset, unpacked_code.data());
+    const uint8_t* sign_bits = unpacked_code.data();
+
+    return rabitq_utils::compute_full_multibit_distance(
+            sign_bits,
+            ex_code,
+            ex_fac,
+            query_factors.rotated_q.data(),
+            query_factors.qr_to_c_L2sqr,
+            query_factors.qr_norm_L2sqr,
+            dim,
+            ex_bits,
+            index->metric_type);
+}

 } // namespace faiss
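For users profiling the new two-stage path: the handler increments counters declared in the new data/vendor/faiss/faiss/impl/RaBitQStats.h. A hedged sketch of reading them after a search, assuming the counters are exposed as a global faiss::rabitq_stats object as the handler code above suggests:

    #include <faiss/impl/RaBitQStats.h>
    #include <cstdio>

    // Sketch: inspect how often the full multi-bit distance was needed.
    void print_rabitq_filter_ratio() {
        const auto& s = faiss::rabitq_stats;
        std::printf("1-bit evaluations: %zu, multi-bit refinements: %zu\n",
                    (size_t)s.n_1bit_evaluations,
                    (size_t)s.n_multibit_evaluations);
    }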