faiss 0.5.2 → 0.6.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +16 -0
- data/LICENSE.txt +1 -1
- data/ext/faiss/ext.cpp +1 -1
- data/ext/faiss/extconf.rb +5 -6
- data/ext/faiss/index_binary.cpp +76 -17
- data/ext/faiss/{index.cpp → index_rb.cpp} +108 -35
- data/ext/faiss/kmeans.cpp +12 -9
- data/ext/faiss/numo.hpp +11 -9
- data/ext/faiss/pca_matrix.cpp +10 -8
- data/ext/faiss/product_quantizer.cpp +14 -12
- data/ext/faiss/{utils.cpp → utils_rb.cpp} +10 -3
- data/ext/faiss/{utils.h → utils_rb.h} +6 -0
- data/lib/faiss/version.rb +1 -1
- data/lib/faiss.rb +1 -1
- data/vendor/faiss/faiss/AutoTune.cpp +130 -11
- data/vendor/faiss/faiss/AutoTune.h +14 -1
- data/vendor/faiss/faiss/Clustering.cpp +59 -10
- data/vendor/faiss/faiss/Clustering.h +12 -0
- data/vendor/faiss/faiss/IVFlib.cpp +31 -28
- data/vendor/faiss/faiss/Index.cpp +20 -8
- data/vendor/faiss/faiss/Index.h +25 -3
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +19 -24
- data/vendor/faiss/faiss/IndexBinary.cpp +1 -0
- data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +9 -4
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +45 -11
- data/vendor/faiss/faiss/IndexFastScan.cpp +35 -22
- data/vendor/faiss/faiss/IndexFastScan.h +10 -1
- data/vendor/faiss/faiss/IndexFlat.cpp +193 -136
- data/vendor/faiss/faiss/IndexFlat.h +16 -1
- data/vendor/faiss/faiss/IndexFlatCodes.cpp +46 -22
- data/vendor/faiss/faiss/IndexFlatCodes.h +7 -1
- data/vendor/faiss/faiss/IndexHNSW.cpp +24 -50
- data/vendor/faiss/faiss/IndexHNSW.h +14 -12
- data/vendor/faiss/faiss/IndexIDMap.cpp +1 -1
- data/vendor/faiss/faiss/IndexIVF.cpp +76 -49
- data/vendor/faiss/faiss/IndexIVF.h +14 -4
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +11 -8
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +2 -2
- data/vendor/faiss/faiss/IndexIVFFastScan.cpp +25 -14
- data/vendor/faiss/faiss/IndexIVFFastScan.h +26 -22
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +10 -61
- data/vendor/faiss/faiss/IndexIVFFlatPanorama.cpp +39 -111
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +89 -147
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +37 -5
- data/vendor/faiss/faiss/IndexIVFPQR.cpp +2 -1
- data/vendor/faiss/faiss/IndexIVFRaBitQ.cpp +42 -30
- data/vendor/faiss/faiss/IndexIVFRaBitQ.h +2 -2
- data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.cpp +246 -97
- data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.h +32 -29
- data/vendor/faiss/faiss/IndexLSH.cpp +8 -6
- data/vendor/faiss/faiss/IndexLattice.cpp +29 -24
- data/vendor/faiss/faiss/IndexNNDescent.cpp +1 -0
- data/vendor/faiss/faiss/IndexNSG.cpp +2 -1
- data/vendor/faiss/faiss/IndexNSG.h +0 -2
- data/vendor/faiss/faiss/IndexNeuralNetCodec.cpp +1 -1
- data/vendor/faiss/faiss/IndexPQ.cpp +19 -10
- data/vendor/faiss/faiss/IndexRaBitQ.cpp +26 -13
- data/vendor/faiss/faiss/IndexRaBitQ.h +2 -2
- data/vendor/faiss/faiss/IndexRaBitQFastScan.cpp +132 -78
- data/vendor/faiss/faiss/IndexRaBitQFastScan.h +14 -12
- data/vendor/faiss/faiss/IndexRefine.cpp +0 -30
- data/vendor/faiss/faiss/IndexShards.cpp +3 -4
- data/vendor/faiss/faiss/MetricType.h +16 -0
- data/vendor/faiss/faiss/VectorTransform.cpp +120 -0
- data/vendor/faiss/faiss/VectorTransform.h +23 -0
- data/vendor/faiss/faiss/clone_index.cpp +7 -4
- data/vendor/faiss/faiss/{cppcontrib/factory_tools.cpp → factory_tools.cpp} +1 -1
- data/vendor/faiss/faiss/gpu/GpuCloner.cpp +1 -1
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +37 -11
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +0 -28
- data/vendor/faiss/faiss/impl/ClusteringInitialization.cpp +367 -0
- data/vendor/faiss/faiss/impl/ClusteringInitialization.h +107 -0
- data/vendor/faiss/faiss/impl/CodePacker.cpp +4 -0
- data/vendor/faiss/faiss/impl/CodePacker.h +11 -3
- data/vendor/faiss/faiss/impl/CodePackerRaBitQ.cpp +83 -0
- data/vendor/faiss/faiss/impl/CodePackerRaBitQ.h +47 -0
- data/vendor/faiss/faiss/impl/FaissAssert.h +60 -2
- data/vendor/faiss/faiss/impl/HNSW.cpp +25 -34
- data/vendor/faiss/faiss/impl/HNSW.h +8 -6
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +34 -27
- data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -1
- data/vendor/faiss/faiss/impl/NSG.cpp +6 -5
- data/vendor/faiss/faiss/impl/NSG.h +17 -7
- data/vendor/faiss/faiss/impl/Panorama.cpp +53 -46
- data/vendor/faiss/faiss/impl/Panorama.h +22 -6
- data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +16 -5
- data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +70 -58
- data/vendor/faiss/faiss/impl/RaBitQUtils.cpp +92 -0
- data/vendor/faiss/faiss/impl/RaBitQUtils.h +93 -31
- data/vendor/faiss/faiss/impl/RaBitQuantizer.cpp +12 -28
- data/vendor/faiss/faiss/impl/RaBitQuantizer.h +3 -10
- data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.cpp +15 -41
- data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.h +0 -4
- data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +14 -9
- data/vendor/faiss/faiss/impl/ResultHandler.h +131 -50
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +67 -2358
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +0 -2
- data/vendor/faiss/faiss/impl/VisitedTable.cpp +42 -0
- data/vendor/faiss/faiss/impl/VisitedTable.h +69 -0
- data/vendor/faiss/faiss/impl/expanded_scanners.h +158 -0
- data/vendor/faiss/faiss/impl/index_read.cpp +829 -471
- data/vendor/faiss/faiss/impl/index_read_utils.h +0 -1
- data/vendor/faiss/faiss/impl/index_write.cpp +17 -8
- data/vendor/faiss/faiss/impl/lattice_Zn.cpp +47 -20
- data/vendor/faiss/faiss/impl/mapped_io.cpp +9 -2
- data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +7 -2
- data/vendor/faiss/faiss/impl/pq4_fast_scan.h +11 -3
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +19 -13
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +29 -21
- data/vendor/faiss/faiss/impl/{code_distance/code_distance-avx2.h → pq_code_distance/pq_code_distance-avx2.cpp} +42 -215
- data/vendor/faiss/faiss/impl/{code_distance/code_distance-avx512.h → pq_code_distance/pq_code_distance-avx512.cpp} +68 -107
- data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-generic.cpp +141 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-inl.h +23 -0
- data/vendor/faiss/faiss/impl/{code_distance/code_distance-sve.h → pq_code_distance/pq_code_distance-sve.cpp} +57 -144
- data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +9 -6
- data/vendor/faiss/faiss/impl/scalar_quantizer/codecs.h +121 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/distance_computers.h +136 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/quantizers.h +280 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/scanners.h +164 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/similarities.h +94 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx2.cpp +455 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx512.cpp +430 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-dispatch.h +329 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-neon.cpp +467 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/training.cpp +203 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/training.h +42 -0
- data/vendor/faiss/faiss/impl/simd_dispatch.h +139 -0
- data/vendor/faiss/faiss/impl/simd_result_handlers.h +18 -18
- data/vendor/faiss/faiss/index_factory.cpp +35 -16
- data/vendor/faiss/faiss/index_io.h +29 -3
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +7 -4
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +1 -1
- data/vendor/faiss/faiss/svs/IndexSVSFaissUtils.h +9 -19
- data/vendor/faiss/faiss/svs/IndexSVSFlat.h +2 -0
- data/vendor/faiss/faiss/svs/IndexSVSVamana.h +2 -1
- data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.cpp +9 -1
- data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.h +9 -0
- data/vendor/faiss/faiss/utils/Heap.cpp +46 -0
- data/vendor/faiss/faiss/utils/Heap.h +21 -0
- data/vendor/faiss/faiss/utils/NeuralNet.cpp +10 -7
- data/vendor/faiss/faiss/utils/distances.cpp +141 -23
- data/vendor/faiss/faiss/utils/distances.h +98 -0
- data/vendor/faiss/faiss/utils/distances_dispatch.h +170 -0
- data/vendor/faiss/faiss/utils/distances_simd.cpp +74 -3511
- data/vendor/faiss/faiss/utils/extra_distances-inl.h +164 -157
- data/vendor/faiss/faiss/utils/extra_distances.cpp +52 -95
- data/vendor/faiss/faiss/utils/extra_distances.h +47 -1
- data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +0 -1
- data/vendor/faiss/faiss/utils/partitioning.cpp +1 -1
- data/vendor/faiss/faiss/utils/pq_code_distance.h +251 -0
- data/vendor/faiss/faiss/utils/rabitq_simd.h +260 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_aarch64.cpp +150 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_arm_sve.cpp +568 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_autovec-inl.h +153 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_avx2.cpp +1185 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_avx512.cpp +1092 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_sse-inl.h +391 -0
- data/vendor/faiss/faiss/utils/simd_levels.cpp +322 -0
- data/vendor/faiss/faiss/utils/simd_levels.h +91 -0
- data/vendor/faiss/faiss/utils/simdlib_avx2.h +12 -1
- data/vendor/faiss/faiss/utils/simdlib_avx512.h +69 -0
- data/vendor/faiss/faiss/utils/simdlib_neon.h +6 -0
- data/vendor/faiss/faiss/utils/sorting.cpp +4 -4
- data/vendor/faiss/faiss/utils/utils.cpp +16 -9
- metadata +47 -18
- data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +0 -81
- data/vendor/faiss/faiss/impl/code_distance/code_distance.h +0 -186
- /data/vendor/faiss/faiss/{cppcontrib/factory_tools.h → factory_tools.h} +0 -0
|
@@ -9,8 +9,10 @@
|
|
|
9
9
|
|
|
10
10
|
#include <faiss/MetricType.h>
|
|
11
11
|
#include <faiss/impl/platform_macros.h>
|
|
12
|
+
#include <faiss/utils/AlignedTable.h>
|
|
12
13
|
#include <cstddef>
|
|
13
14
|
#include <cstdint>
|
|
15
|
+
#include <cstring>
|
|
14
16
|
#include <vector>
|
|
15
17
|
|
|
16
18
|
namespace faiss {
|
|
@@ -68,6 +70,7 @@ struct QueryFactorsData {
|
|
|
68
70
|
|
|
69
71
|
float qr_to_c_L2sqr = 0;
|
|
70
72
|
float qr_norm_L2sqr = 0;
|
|
73
|
+
float q_dot_c = 0; // <query, centroid> for IP metric; 0 for L2
|
|
71
74
|
|
|
72
75
|
float int_dot_scale = 1;
|
|
73
76
|
|
|
@@ -239,6 +242,41 @@ inline float compute_1bit_adjusted_distance(
|
|
|
239
242
|
return adjusted_distance;
|
|
240
243
|
}
|
|
241
244
|
|
|
245
|
+
/** Determine whether a candidate should be refined in two-stage search.
|
|
246
|
+
* Consolidates the filtering logic for both L2 and IP metrics.
|
|
247
|
+
*
|
|
248
|
+
* For L2 (min-heap): uses lower_bound = est_distance - error_adjustment
|
|
249
|
+
* - Skip if lower_bound >= threshold (can't beat current worst)
|
|
250
|
+
* For IP (max-heap): uses upper_bound = est_distance + error_adjustment
|
|
251
|
+
* - Skip if upper_bound <= threshold (can't beat current best)
|
|
252
|
+
*
|
|
253
|
+
* @param est_distance Estimated 1-bit distance
|
|
254
|
+
* @param f_error Database vector error factor
|
|
255
|
+
* @param g_error Query vector error factor
|
|
256
|
+
* @param threshold Current heap threshold (worst result in heap)
|
|
257
|
+
* @param is_similarity True for IP metric (max-heap), false for L2
|
|
258
|
+
* (min-heap)
|
|
259
|
+
* @return True if candidate should be refined with full
|
|
260
|
+
* multi-bit distance
|
|
261
|
+
*/
|
|
262
|
+
inline bool should_refine_candidate(
|
|
263
|
+
float est_distance,
|
|
264
|
+
float f_error,
|
|
265
|
+
float g_error,
|
|
266
|
+
float threshold,
|
|
267
|
+
bool is_similarity) {
|
|
268
|
+
float error_adjustment = f_error * g_error;
|
|
269
|
+
if (is_similarity) {
|
|
270
|
+
// IP (max-heap): use upper bound for filtering
|
|
271
|
+
float upper_bound = est_distance + error_adjustment;
|
|
272
|
+
return upper_bound > threshold;
|
|
273
|
+
} else {
|
|
274
|
+
// L2 (min-heap): use lower bound for filtering
|
|
275
|
+
float lower_bound = std::max(0.0f, est_distance - error_adjustment);
|
|
276
|
+
return lower_bound < threshold;
|
|
277
|
+
}
|
|
278
|
+
}
|
|
279
|
+
|
|
242
280
|
/** Extract multi-bit code on-the-fly from packed ex-bit codes.
|
|
243
281
|
* This inline function extracts a single code value without unpacking the
|
|
244
282
|
* entire array, enabling efficient on-the-fly decoding during distance
|
|
@@ -276,55 +314,79 @@ inline int extract_code_inline(
|
|
|
276
314
|
*
|
|
277
315
|
* The multi-bit distance combines the sign bit (1-bit) with additional
|
|
278
316
|
* magnitude bits (ex_bits) to compute a more accurate distance estimate.
|
|
317
|
+
* Uses SIMD-optimized bit-plane decomposition (AVX2+BMI2) for ex_bits 1-7,
|
|
318
|
+
* with scalar fallback for non-x86 or non-BMI2 platforms.
|
|
279
319
|
*
|
|
280
320
|
* @param sign_bits unpacked sign bits (1-bit codes in standard format)
|
|
281
321
|
* @param ex_code packed ex-bit codes
|
|
282
322
|
* @param ex_fac ex-bit factors (f_add_ex, f_rescale_ex)
|
|
283
323
|
* @param rotated_q rotated query vector
|
|
284
|
-
* @param
|
|
285
|
-
* @param qr_norm_L2sqr precomputed ||query_rotated||^2 (0 for L2 metric)
|
|
324
|
+
* @param qr_base precomputed base term: ||q-c||^2 for L2, <q,c> for IP
|
|
286
325
|
* @param d dimensionality
|
|
287
326
|
* @param ex_bits number of extra bits (nb_bits - 1)
|
|
288
327
|
* @param metric_type distance metric (L2 or Inner Product)
|
|
289
328
|
* @return computed full multi-bit distance
|
|
290
329
|
*/
|
|
291
|
-
|
|
330
|
+
float compute_full_multibit_distance(
|
|
292
331
|
const uint8_t* sign_bits,
|
|
293
332
|
const uint8_t* ex_code,
|
|
294
333
|
const ExtraBitsFactors& ex_fac,
|
|
295
334
|
const float* rotated_q,
|
|
296
|
-
float
|
|
297
|
-
float qr_norm_L2sqr,
|
|
335
|
+
float qr_base,
|
|
298
336
|
size_t d,
|
|
299
337
|
size_t ex_bits,
|
|
300
|
-
MetricType metric_type)
|
|
301
|
-
|
|
302
|
-
|
|
303
|
-
|
|
304
|
-
|
|
305
|
-
|
|
306
|
-
|
|
307
|
-
|
|
308
|
-
|
|
309
|
-
|
|
310
|
-
|
|
311
|
-
|
|
312
|
-
|
|
313
|
-
|
|
314
|
-
|
|
315
|
-
ex_ip += rotated_q[i] * reconstructed;
|
|
316
|
-
}
|
|
317
|
-
|
|
318
|
-
float dist = qr_to_c_L2sqr + ex_fac.f_add_ex + ex_fac.f_rescale_ex * ex_ip;
|
|
338
|
+
MetricType metric_type);
|
|
339
|
+
|
|
340
|
+
/** Compute pointer to a vector's auxiliary data within block layout. */
|
|
341
|
+
template <typename T>
|
|
342
|
+
inline T* get_block_aux_ptr(
|
|
343
|
+
T* block_data,
|
|
344
|
+
size_t vec_pos,
|
|
345
|
+
size_t bbs,
|
|
346
|
+
size_t packed_block_size,
|
|
347
|
+
size_t full_block_size,
|
|
348
|
+
size_t storage_size) {
|
|
349
|
+
return block_data + (vec_pos / bbs) * full_block_size + packed_block_size +
|
|
350
|
+
(vec_pos % bbs) * storage_size;
|
|
351
|
+
}
|
|
319
352
|
|
|
320
|
-
|
|
321
|
-
|
|
322
|
-
|
|
323
|
-
|
|
324
|
-
|
|
353
|
+
/** Compute per-vector auxiliary storage size.
|
|
354
|
+
*
|
|
355
|
+
* @param nb_bits number of quantization bits (1 = sign-bit only)
|
|
356
|
+
* @param d dimensionality
|
|
357
|
+
* @return storage size in bytes
|
|
358
|
+
*/
|
|
359
|
+
size_t compute_per_vector_storage_size(size_t nb_bits, size_t d);
|
|
325
360
|
|
|
326
|
-
|
|
327
|
-
|
|
361
|
+
/** [LEGACY FORMAT SUPPORT] Migrate block data from old I/O format to new
|
|
362
|
+
* format.
|
|
363
|
+
*
|
|
364
|
+
* This function is used only when reading indexes saved with the legacy format
|
|
365
|
+
* (fourcc "Irfs"/"Iwrf") to convert them to the new embedded auxiliary data
|
|
366
|
+
* format. Not needed for indexes saved with the new format ("Irfn"/"Iwrn").
|
|
367
|
+
*
|
|
368
|
+
* Re-layouts blocks in-place and copies aux data from flat_storage.
|
|
369
|
+
*
|
|
370
|
+
* @param flat_storage legacy per-vector aux data indexed by global ID
|
|
371
|
+
* @param codes block data (will be resized and re-laid out)
|
|
372
|
+
* @param num_vectors number of vectors in this segment
|
|
373
|
+
* @param bbs block batch size (vectors per block)
|
|
374
|
+
* @param M2 rounded sub-quantizer count
|
|
375
|
+
* @param old_block_stride old block size (packed codes only, or current)
|
|
376
|
+
* @param new_block_stride new block size (packed codes + aux region)
|
|
377
|
+
* @param storage_size per-vector aux storage size in bytes
|
|
378
|
+
* @param id_map maps local offset to global ID; null = sequential
|
|
379
|
+
*/
|
|
380
|
+
void populate_block_aux_from_flat_storage(
|
|
381
|
+
const std::vector<uint8_t>& flat_storage,
|
|
382
|
+
AlignedTable<uint8_t>& codes,
|
|
383
|
+
size_t num_vectors,
|
|
384
|
+
size_t bbs,
|
|
385
|
+
size_t M2,
|
|
386
|
+
size_t old_block_stride,
|
|
387
|
+
size_t new_block_stride,
|
|
388
|
+
size_t storage_size,
|
|
389
|
+
const int64_t* id_map = nullptr);
|
|
328
390
|
|
|
329
391
|
} // namespace rabitq_utils
|
|
330
392
|
} // namespace faiss
|
|
@@ -12,6 +12,7 @@
|
|
|
12
12
|
#include <faiss/impl/RaBitQuantizerMultiBit.h>
|
|
13
13
|
#include <faiss/utils/distances.h>
|
|
14
14
|
#include <faiss/utils/rabitq_simd.h>
|
|
15
|
+
|
|
15
16
|
#include <algorithm>
|
|
16
17
|
#include <cmath>
|
|
17
18
|
#include <cstring>
|
|
@@ -63,7 +64,7 @@ size_t RaBitQuantizer::compute_code_size(size_t d, size_t num_bits) const {
|
|
|
63
64
|
return base_size + ex_size;
|
|
64
65
|
}
|
|
65
66
|
|
|
66
|
-
void RaBitQuantizer::train(size_t n
|
|
67
|
+
void RaBitQuantizer::train(size_t /*n*/, const float* /*x*/) {
|
|
67
68
|
// does nothing
|
|
68
69
|
}
|
|
69
70
|
|
|
@@ -215,29 +216,6 @@ void RaBitQuantizer::decode_core(
|
|
|
215
216
|
}
|
|
216
217
|
}
|
|
217
218
|
|
|
218
|
-
// Implementation of RaBitQDistanceComputer (declared in header)
|
|
219
|
-
|
|
220
|
-
float RaBitQDistanceComputer::lower_bound_distance(const uint8_t* code) {
|
|
221
|
-
FAISS_ASSERT(code != nullptr);
|
|
222
|
-
|
|
223
|
-
// Compute estimated distance using 1-bit codes
|
|
224
|
-
float est_distance = distance_to_code_1bit(code);
|
|
225
|
-
|
|
226
|
-
// Extract f_error from the code
|
|
227
|
-
size_t size = (d + 7) / 8;
|
|
228
|
-
const SignBitFactorsWithError* base_fac =
|
|
229
|
-
reinterpret_cast<const SignBitFactorsWithError*>(code + size);
|
|
230
|
-
float f_error = base_fac->f_error;
|
|
231
|
-
|
|
232
|
-
// Compute proper lower bound using RaBitQ error formula:
|
|
233
|
-
// lower_bound = est_distance - f_error * g_error
|
|
234
|
-
// This guarantees: lower_bound ≤ true_distance
|
|
235
|
-
float lower_bound = est_distance - (f_error * g_error);
|
|
236
|
-
|
|
237
|
-
// Distance cannot be negative
|
|
238
|
-
return std::max(0.0f, lower_bound);
|
|
239
|
-
}
|
|
240
|
-
|
|
241
219
|
namespace {
|
|
242
220
|
|
|
243
221
|
struct RaBitQDistanceComputerNotQ : RaBitQDistanceComputer {
|
|
@@ -336,13 +314,15 @@ float RaBitQDistanceComputerNotQ::distance_to_code_full(const uint8_t* code) {
|
|
|
336
314
|
ex_code + (d * ex_bits + 7) / 8);
|
|
337
315
|
|
|
338
316
|
// Call shared utility directly with rotated_q pointer
|
|
317
|
+
float qr_base = (metric_type == MetricType::METRIC_INNER_PRODUCT)
|
|
318
|
+
? query_fac.q_dot_c
|
|
319
|
+
: query_fac.qr_to_c_L2sqr;
|
|
339
320
|
return rabitq_utils::compute_full_multibit_distance(
|
|
340
321
|
binary_data,
|
|
341
322
|
ex_code,
|
|
342
323
|
*ex_fac,
|
|
343
324
|
rotated_q.data(),
|
|
344
|
-
|
|
345
|
-
query_fac.qr_norm_L2sqr,
|
|
325
|
+
qr_base,
|
|
346
326
|
d,
|
|
347
327
|
ex_bits,
|
|
348
328
|
metric_type);
|
|
@@ -388,6 +368,8 @@ void RaBitQDistanceComputerNotQ::set_query(const float* x) {
|
|
|
388
368
|
if (metric_type == MetricType::METRIC_INNER_PRODUCT) {
|
|
389
369
|
// precompute if needed
|
|
390
370
|
query_fac.qr_norm_L2sqr = fvec_norm_L2sqr(x, d);
|
|
371
|
+
query_fac.q_dot_c =
|
|
372
|
+
centroid ? fvec_inner_product(x, centroid, d) : 0.0f;
|
|
391
373
|
}
|
|
392
374
|
}
|
|
393
375
|
|
|
@@ -502,13 +484,15 @@ float RaBitQDistanceComputerQ::distance_to_code_full(const uint8_t* code) {
|
|
|
502
484
|
ex_code + (d * ex_bits + 7) / 8);
|
|
503
485
|
|
|
504
486
|
// Call shared utility directly with rotated_q pointer
|
|
487
|
+
float qr_base = (metric_type == MetricType::METRIC_INNER_PRODUCT)
|
|
488
|
+
? query_fac.q_dot_c
|
|
489
|
+
: query_fac.qr_to_c_L2sqr;
|
|
505
490
|
return rabitq_utils::compute_full_multibit_distance(
|
|
506
491
|
binary_data,
|
|
507
492
|
ex_code,
|
|
508
493
|
*ex_fac,
|
|
509
494
|
rotated_q.data(),
|
|
510
|
-
|
|
511
|
-
query_fac.qr_norm_L2sqr,
|
|
495
|
+
qr_base,
|
|
512
496
|
d,
|
|
513
497
|
ex_bits,
|
|
514
498
|
metric_type);
|
|
@@ -103,10 +103,8 @@ struct RaBitQuantizer : Quantizer {
|
|
|
103
103
|
//
|
|
104
104
|
// 1. distance_to_code_1bit() - Fast 1-bit filtering using only sign bits
|
|
105
105
|
// 2. distance_to_code_full() - Accurate multi-bit refinement using all bits
|
|
106
|
-
// 3. lower_bound_distance() - Error-bounded adaptive filtering
|
|
107
|
-
// (based on 1-bit estimator)
|
|
108
106
|
//
|
|
109
|
-
// These
|
|
107
|
+
// These methods implement RaBitQ's two-stage search pattern and are
|
|
110
108
|
// shared between the quantized (Q) and non-quantized (NotQ) query variants.
|
|
111
109
|
// The intermediate class allows two-stage search code to work with both
|
|
112
110
|
// variants via a single dynamic_cast.
|
|
@@ -116,8 +114,8 @@ struct RaBitQDistanceComputer : FlatCodesDistanceComputer {
|
|
|
116
114
|
MetricType metric_type = MetricType::METRIC_L2;
|
|
117
115
|
size_t nb_bits = 1;
|
|
118
116
|
|
|
119
|
-
// Query
|
|
120
|
-
//
|
|
117
|
+
// Query error factor for bound computation (g_error in rabitq-library)
|
|
118
|
+
// Used with f_error to compute error bounds for two-stage filtering
|
|
121
119
|
float g_error = 0.0f;
|
|
122
120
|
|
|
123
121
|
float symmetric_dis(idx_t /*i*/, idx_t /*j*/) override {
|
|
@@ -131,11 +129,6 @@ struct RaBitQDistanceComputer : FlatCodesDistanceComputer {
|
|
|
131
129
|
// Compute full multi-bit distance (accurate)
|
|
132
130
|
virtual float distance_to_code_full(const uint8_t* code) = 0;
|
|
133
131
|
|
|
134
|
-
// Compute lower bound of distance using error bounds
|
|
135
|
-
// Guarantees: actual_distance >= lower_bound_distance
|
|
136
|
-
// Used for adaptive filtering in two-stage search
|
|
137
|
-
virtual float lower_bound_distance(const uint8_t* code);
|
|
138
|
-
|
|
139
132
|
// Override from FlatCodesDistanceComputer
|
|
140
133
|
// Delegates to distance_to_code_full() for multi-bit distance computation
|
|
141
134
|
float distance_to_code(const uint8_t* code) final {
|
|
@@ -180,9 +180,7 @@ void pack_multibit_codes(
|
|
|
180
180
|
*
|
|
181
181
|
* @param residual Original residual vector (data - centroid)
|
|
182
182
|
* @param centroid Centroid vector (can be nullptr for zero centroid)
|
|
183
|
-
* @param tmp_code Quantized ex-bit codes (before packing, after bit flipping)
|
|
184
183
|
* @param d Dimensionality
|
|
185
|
-
* @param ex_bits Number of extra bits
|
|
186
184
|
* @param norm L2 norm of residual
|
|
187
185
|
* @param ipnorm Unnormalized inner product between quantized and normalized
|
|
188
186
|
* residual
|
|
@@ -192,9 +190,7 @@ void pack_multibit_codes(
|
|
|
192
190
|
void compute_ex_factors(
|
|
193
191
|
const float* residual,
|
|
194
192
|
const float* centroid,
|
|
195
|
-
const int* tmp_code,
|
|
196
193
|
size_t d,
|
|
197
|
-
size_t ex_bits,
|
|
198
194
|
float norm,
|
|
199
195
|
double ipnorm,
|
|
200
196
|
ExtraBitsFactors& ex_factors,
|
|
@@ -210,45 +206,23 @@ void compute_ex_factors(
|
|
|
210
206
|
ipnorm_inv = 1.0f;
|
|
211
207
|
}
|
|
212
208
|
|
|
213
|
-
// Reconstruct xu_cb from total_code
|
|
214
|
-
// total_code was formed from: total_code[i] = (sign << ex_bits) +
|
|
215
|
-
// ex_code[i] Reconstruction: xu_cb[i] = total_code[i] + cb
|
|
216
|
-
const float cb = -(static_cast<float>(1 << ex_bits) - 0.5f);
|
|
217
|
-
std::vector<float> xu_cb(d);
|
|
218
|
-
for (size_t i = 0; i < d; i++) {
|
|
219
|
-
xu_cb[i] = static_cast<float>(tmp_code[i]) + cb;
|
|
220
|
-
}
|
|
221
|
-
|
|
222
209
|
// Compute inner products needed for factors
|
|
223
|
-
float l2_sqr = norm * norm;
|
|
224
|
-
float ip_resi_xucb = fvec_inner_product(residual, xu_cb.data(), d);
|
|
210
|
+
float l2_sqr = norm * norm; // ||residual||^2 = ||x - c||^2
|
|
225
211
|
|
|
226
|
-
// Compute factors
|
|
227
212
|
if (metric_type == MetricType::METRIC_L2) {
|
|
228
|
-
// For L2
|
|
229
|
-
//
|
|
213
|
+
// For L2: f_add_ex = ||residual||^2
|
|
214
|
+
// No centroid correction needed in IVF setting because
|
|
215
|
+
// residual = x - centroid, distance computed in residual space
|
|
230
216
|
ex_factors.f_add_ex = l2_sqr;
|
|
231
217
|
ex_factors.f_rescale_ex = ipnorm_inv * -2.0f * norm;
|
|
232
218
|
} else {
|
|
233
|
-
// For IP
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
if (centroid != nullptr) {
|
|
241
|
-
ip_cent_xucb = fvec_inner_product(centroid, xu_cb.data(), d);
|
|
242
|
-
}
|
|
243
|
-
|
|
244
|
-
// When ip_resi_xucb is zero, the correction term should be zero
|
|
245
|
-
float correction_term = 0.0f;
|
|
246
|
-
if (ip_resi_xucb != 0.0f) {
|
|
247
|
-
correction_term = l2_sqr * ip_cent_xucb / ip_resi_xucb;
|
|
248
|
-
}
|
|
249
|
-
|
|
250
|
-
ex_factors.f_add_ex = 1 - ip_resi_cent + correction_term;
|
|
251
|
-
ex_factors.f_rescale_ex = ipnorm_inv * -norm;
|
|
219
|
+
// For IP: direct dot-product formulation
|
|
220
|
+
// f_add_ex = <c, r> (dot product of centroid and residual)
|
|
221
|
+
// f_rescale_ex = ||r|| / ipnorm (positive scaling)
|
|
222
|
+
float c_dot_r =
|
|
223
|
+
centroid ? fvec_inner_product(residual, centroid, d) : 0.0f;
|
|
224
|
+
ex_factors.f_add_ex = c_dot_r;
|
|
225
|
+
ex_factors.f_rescale_ex = ipnorm_inv * norm;
|
|
252
226
|
}
|
|
253
227
|
}
|
|
254
228
|
|
|
@@ -290,12 +264,14 @@ void quantize_ex_bits(
|
|
|
290
264
|
float norm_sqr = fvec_norm_L2sqr(residual, d);
|
|
291
265
|
float norm = std::sqrt(norm_sqr);
|
|
292
266
|
|
|
293
|
-
// Handle degenerate case
|
|
267
|
+
// Handle degenerate case: residual is (near-)zero, meaning x ≈ centroid.
|
|
268
|
+
// For both L2 and IP, f_add_ex and f_rescale_ex are trivially zero:
|
|
269
|
+
// L2: ||r||² ≈ 0, IP: <c,r> ≈ 0 and ||r||/ipnorm ≈ 0
|
|
294
270
|
if (norm < 1e-10f) {
|
|
295
271
|
size_t code_size = (d * ex_bits + 7) / 8;
|
|
296
272
|
memset(ex_code, 0, code_size);
|
|
297
273
|
ex_factors.f_add_ex = 0.0f;
|
|
298
|
-
ex_factors.f_rescale_ex =
|
|
274
|
+
ex_factors.f_rescale_ex = 0.0f;
|
|
299
275
|
return;
|
|
300
276
|
}
|
|
301
277
|
|
|
@@ -349,9 +325,7 @@ void quantize_ex_bits(
|
|
|
349
325
|
compute_ex_factors(
|
|
350
326
|
residual,
|
|
351
327
|
centroid, // Pass centroid for IP metric factor computation
|
|
352
|
-
total_code.data(),
|
|
353
328
|
d,
|
|
354
|
-
ex_bits,
|
|
355
329
|
norm,
|
|
356
330
|
ipnorm,
|
|
357
331
|
ex_factors,
|
|
@@ -60,9 +60,7 @@ void pack_multibit_codes(
|
|
|
60
60
|
*
|
|
61
61
|
* @param residual Original residual vector (data - centroid)
|
|
62
62
|
* @param centroid Centroid vector (can be nullptr for zero centroid)
|
|
63
|
-
* @param tmp_code Quantized ex-bit codes (unpacked integers)
|
|
64
63
|
* @param d Dimensionality
|
|
65
|
-
* @param ex_bits Number of extra bits
|
|
66
64
|
* @param norm L2 norm of residual
|
|
67
65
|
* @param ipnorm Unnormalized inner product
|
|
68
66
|
* @param ex_factors Output factors structure
|
|
@@ -71,9 +69,7 @@ void pack_multibit_codes(
|
|
|
71
69
|
void compute_ex_factors(
|
|
72
70
|
const float* residual,
|
|
73
71
|
const float* centroid,
|
|
74
|
-
const int* tmp_code,
|
|
75
72
|
size_t d,
|
|
76
|
-
size_t ex_bits,
|
|
77
73
|
float norm,
|
|
78
74
|
double ipnorm,
|
|
79
75
|
rabitq_utils::ExtraBitsFactors& ex_factors,
|
|
@@ -18,6 +18,7 @@
|
|
|
18
18
|
#include <faiss/VectorTransform.h>
|
|
19
19
|
#include <faiss/impl/FaissAssert.h>
|
|
20
20
|
#include <faiss/impl/residual_quantizer_encode_steps.h>
|
|
21
|
+
#include <faiss/impl/simd_dispatch.h>
|
|
21
22
|
#include <faiss/utils/distances.h>
|
|
22
23
|
#include <faiss/utils/hamming.h>
|
|
23
24
|
#include <faiss/utils/utils.h>
|
|
@@ -274,10 +275,12 @@ void ResidualQuantizer::train(size_t n, const float* x) {
|
|
|
274
275
|
// find min and max norms
|
|
275
276
|
std::vector<float> norms(n);
|
|
276
277
|
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
|
|
280
|
-
|
|
278
|
+
with_simd_level([&]<SIMDLevel SL>() {
|
|
279
|
+
for (size_t i = 0; i < n; i++) {
|
|
280
|
+
norms[i] = fvec_L2sqr<SL>(
|
|
281
|
+
x + i * d, residuals.data() + i * cur_beam_size * d, d);
|
|
282
|
+
}
|
|
283
|
+
});
|
|
281
284
|
|
|
282
285
|
// fvec_norms_L2sqr(norms.data(), x, d, n);
|
|
283
286
|
train_norm(n, norms.data());
|
|
@@ -393,11 +396,13 @@ float ResidualQuantizer::retrain_AQ_codebook(size_t n, const float* x) {
|
|
|
393
396
|
}
|
|
394
397
|
|
|
395
398
|
float output_recons_error = 0;
|
|
396
|
-
|
|
397
|
-
|
|
398
|
-
|
|
399
|
-
|
|
400
|
-
|
|
399
|
+
with_simd_level([&]<SIMDLevel SL>() {
|
|
400
|
+
for (size_t j = 0; j < d; j++) {
|
|
401
|
+
output_recons_error += fvec_norm_L2sqr<SL>(
|
|
402
|
+
xt.data() + total_codebook_size + n * j,
|
|
403
|
+
n - total_codebook_size);
|
|
404
|
+
}
|
|
405
|
+
});
|
|
401
406
|
if (verbose) {
|
|
402
407
|
printf(" output quantization error %g\n", output_recons_error);
|
|
403
408
|
}
|