faiss 0.5.2 → 0.6.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (169) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +16 -0
  3. data/LICENSE.txt +1 -1
  4. data/ext/faiss/ext.cpp +1 -1
  5. data/ext/faiss/extconf.rb +5 -6
  6. data/ext/faiss/index_binary.cpp +76 -17
  7. data/ext/faiss/{index.cpp → index_rb.cpp} +108 -35
  8. data/ext/faiss/kmeans.cpp +12 -9
  9. data/ext/faiss/numo.hpp +11 -9
  10. data/ext/faiss/pca_matrix.cpp +10 -8
  11. data/ext/faiss/product_quantizer.cpp +14 -12
  12. data/ext/faiss/{utils.cpp → utils_rb.cpp} +10 -3
  13. data/ext/faiss/{utils.h → utils_rb.h} +6 -0
  14. data/lib/faiss/version.rb +1 -1
  15. data/lib/faiss.rb +1 -1
  16. data/vendor/faiss/faiss/AutoTune.cpp +130 -11
  17. data/vendor/faiss/faiss/AutoTune.h +14 -1
  18. data/vendor/faiss/faiss/Clustering.cpp +59 -10
  19. data/vendor/faiss/faiss/Clustering.h +12 -0
  20. data/vendor/faiss/faiss/IVFlib.cpp +31 -28
  21. data/vendor/faiss/faiss/Index.cpp +20 -8
  22. data/vendor/faiss/faiss/Index.h +25 -3
  23. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +19 -24
  24. data/vendor/faiss/faiss/IndexBinary.cpp +1 -0
  25. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +9 -4
  26. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +45 -11
  27. data/vendor/faiss/faiss/IndexFastScan.cpp +35 -22
  28. data/vendor/faiss/faiss/IndexFastScan.h +10 -1
  29. data/vendor/faiss/faiss/IndexFlat.cpp +193 -136
  30. data/vendor/faiss/faiss/IndexFlat.h +16 -1
  31. data/vendor/faiss/faiss/IndexFlatCodes.cpp +46 -22
  32. data/vendor/faiss/faiss/IndexFlatCodes.h +7 -1
  33. data/vendor/faiss/faiss/IndexHNSW.cpp +24 -50
  34. data/vendor/faiss/faiss/IndexHNSW.h +14 -12
  35. data/vendor/faiss/faiss/IndexIDMap.cpp +1 -1
  36. data/vendor/faiss/faiss/IndexIVF.cpp +76 -49
  37. data/vendor/faiss/faiss/IndexIVF.h +14 -4
  38. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +11 -8
  39. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +2 -2
  40. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +25 -14
  41. data/vendor/faiss/faiss/IndexIVFFastScan.h +26 -22
  42. data/vendor/faiss/faiss/IndexIVFFlat.cpp +10 -61
  43. data/vendor/faiss/faiss/IndexIVFFlatPanorama.cpp +39 -111
  44. data/vendor/faiss/faiss/IndexIVFPQ.cpp +89 -147
  45. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +37 -5
  46. data/vendor/faiss/faiss/IndexIVFPQR.cpp +2 -1
  47. data/vendor/faiss/faiss/IndexIVFRaBitQ.cpp +42 -30
  48. data/vendor/faiss/faiss/IndexIVFRaBitQ.h +2 -2
  49. data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.cpp +246 -97
  50. data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.h +32 -29
  51. data/vendor/faiss/faiss/IndexLSH.cpp +8 -6
  52. data/vendor/faiss/faiss/IndexLattice.cpp +29 -24
  53. data/vendor/faiss/faiss/IndexNNDescent.cpp +1 -0
  54. data/vendor/faiss/faiss/IndexNSG.cpp +2 -1
  55. data/vendor/faiss/faiss/IndexNSG.h +0 -2
  56. data/vendor/faiss/faiss/IndexNeuralNetCodec.cpp +1 -1
  57. data/vendor/faiss/faiss/IndexPQ.cpp +19 -10
  58. data/vendor/faiss/faiss/IndexRaBitQ.cpp +26 -13
  59. data/vendor/faiss/faiss/IndexRaBitQ.h +2 -2
  60. data/vendor/faiss/faiss/IndexRaBitQFastScan.cpp +132 -78
  61. data/vendor/faiss/faiss/IndexRaBitQFastScan.h +14 -12
  62. data/vendor/faiss/faiss/IndexRefine.cpp +0 -30
  63. data/vendor/faiss/faiss/IndexShards.cpp +3 -4
  64. data/vendor/faiss/faiss/MetricType.h +16 -0
  65. data/vendor/faiss/faiss/VectorTransform.cpp +120 -0
  66. data/vendor/faiss/faiss/VectorTransform.h +23 -0
  67. data/vendor/faiss/faiss/clone_index.cpp +7 -4
  68. data/vendor/faiss/faiss/{cppcontrib/factory_tools.cpp → factory_tools.cpp} +1 -1
  69. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +1 -1
  70. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +37 -11
  71. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +0 -28
  72. data/vendor/faiss/faiss/impl/ClusteringInitialization.cpp +367 -0
  73. data/vendor/faiss/faiss/impl/ClusteringInitialization.h +107 -0
  74. data/vendor/faiss/faiss/impl/CodePacker.cpp +4 -0
  75. data/vendor/faiss/faiss/impl/CodePacker.h +11 -3
  76. data/vendor/faiss/faiss/impl/CodePackerRaBitQ.cpp +83 -0
  77. data/vendor/faiss/faiss/impl/CodePackerRaBitQ.h +47 -0
  78. data/vendor/faiss/faiss/impl/FaissAssert.h +60 -2
  79. data/vendor/faiss/faiss/impl/HNSW.cpp +25 -34
  80. data/vendor/faiss/faiss/impl/HNSW.h +8 -6
  81. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +34 -27
  82. data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -1
  83. data/vendor/faiss/faiss/impl/NSG.cpp +6 -5
  84. data/vendor/faiss/faiss/impl/NSG.h +17 -7
  85. data/vendor/faiss/faiss/impl/Panorama.cpp +53 -46
  86. data/vendor/faiss/faiss/impl/Panorama.h +22 -6
  87. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +16 -5
  88. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +70 -58
  89. data/vendor/faiss/faiss/impl/RaBitQUtils.cpp +92 -0
  90. data/vendor/faiss/faiss/impl/RaBitQUtils.h +93 -31
  91. data/vendor/faiss/faiss/impl/RaBitQuantizer.cpp +12 -28
  92. data/vendor/faiss/faiss/impl/RaBitQuantizer.h +3 -10
  93. data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.cpp +15 -41
  94. data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.h +0 -4
  95. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +14 -9
  96. data/vendor/faiss/faiss/impl/ResultHandler.h +131 -50
  97. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +67 -2358
  98. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +0 -2
  99. data/vendor/faiss/faiss/impl/VisitedTable.cpp +42 -0
  100. data/vendor/faiss/faiss/impl/VisitedTable.h +69 -0
  101. data/vendor/faiss/faiss/impl/expanded_scanners.h +158 -0
  102. data/vendor/faiss/faiss/impl/index_read.cpp +829 -471
  103. data/vendor/faiss/faiss/impl/index_read_utils.h +0 -1
  104. data/vendor/faiss/faiss/impl/index_write.cpp +17 -8
  105. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +47 -20
  106. data/vendor/faiss/faiss/impl/mapped_io.cpp +9 -2
  107. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +7 -2
  108. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +11 -3
  109. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +19 -13
  110. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +29 -21
  111. data/vendor/faiss/faiss/impl/{code_distance/code_distance-avx2.h → pq_code_distance/pq_code_distance-avx2.cpp} +42 -215
  112. data/vendor/faiss/faiss/impl/{code_distance/code_distance-avx512.h → pq_code_distance/pq_code_distance-avx512.cpp} +68 -107
  113. data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-generic.cpp +141 -0
  114. data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-inl.h +23 -0
  115. data/vendor/faiss/faiss/impl/{code_distance/code_distance-sve.h → pq_code_distance/pq_code_distance-sve.cpp} +57 -144
  116. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +9 -6
  117. data/vendor/faiss/faiss/impl/scalar_quantizer/codecs.h +121 -0
  118. data/vendor/faiss/faiss/impl/scalar_quantizer/distance_computers.h +136 -0
  119. data/vendor/faiss/faiss/impl/scalar_quantizer/quantizers.h +280 -0
  120. data/vendor/faiss/faiss/impl/scalar_quantizer/scanners.h +164 -0
  121. data/vendor/faiss/faiss/impl/scalar_quantizer/similarities.h +94 -0
  122. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx2.cpp +455 -0
  123. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx512.cpp +430 -0
  124. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-dispatch.h +329 -0
  125. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-neon.cpp +467 -0
  126. data/vendor/faiss/faiss/impl/scalar_quantizer/training.cpp +203 -0
  127. data/vendor/faiss/faiss/impl/scalar_quantizer/training.h +42 -0
  128. data/vendor/faiss/faiss/impl/simd_dispatch.h +139 -0
  129. data/vendor/faiss/faiss/impl/simd_result_handlers.h +18 -18
  130. data/vendor/faiss/faiss/index_factory.cpp +35 -16
  131. data/vendor/faiss/faiss/index_io.h +29 -3
  132. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +7 -4
  133. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +1 -1
  134. data/vendor/faiss/faiss/svs/IndexSVSFaissUtils.h +9 -19
  135. data/vendor/faiss/faiss/svs/IndexSVSFlat.h +2 -0
  136. data/vendor/faiss/faiss/svs/IndexSVSVamana.h +2 -1
  137. data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.cpp +9 -1
  138. data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.h +9 -0
  139. data/vendor/faiss/faiss/utils/Heap.cpp +46 -0
  140. data/vendor/faiss/faiss/utils/Heap.h +21 -0
  141. data/vendor/faiss/faiss/utils/NeuralNet.cpp +10 -7
  142. data/vendor/faiss/faiss/utils/distances.cpp +141 -23
  143. data/vendor/faiss/faiss/utils/distances.h +98 -0
  144. data/vendor/faiss/faiss/utils/distances_dispatch.h +170 -0
  145. data/vendor/faiss/faiss/utils/distances_simd.cpp +74 -3511
  146. data/vendor/faiss/faiss/utils/extra_distances-inl.h +164 -157
  147. data/vendor/faiss/faiss/utils/extra_distances.cpp +52 -95
  148. data/vendor/faiss/faiss/utils/extra_distances.h +47 -1
  149. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +0 -1
  150. data/vendor/faiss/faiss/utils/partitioning.cpp +1 -1
  151. data/vendor/faiss/faiss/utils/pq_code_distance.h +251 -0
  152. data/vendor/faiss/faiss/utils/rabitq_simd.h +260 -0
  153. data/vendor/faiss/faiss/utils/simd_impl/distances_aarch64.cpp +150 -0
  154. data/vendor/faiss/faiss/utils/simd_impl/distances_arm_sve.cpp +568 -0
  155. data/vendor/faiss/faiss/utils/simd_impl/distances_autovec-inl.h +153 -0
  156. data/vendor/faiss/faiss/utils/simd_impl/distances_avx2.cpp +1185 -0
  157. data/vendor/faiss/faiss/utils/simd_impl/distances_avx512.cpp +1092 -0
  158. data/vendor/faiss/faiss/utils/simd_impl/distances_sse-inl.h +391 -0
  159. data/vendor/faiss/faiss/utils/simd_levels.cpp +322 -0
  160. data/vendor/faiss/faiss/utils/simd_levels.h +91 -0
  161. data/vendor/faiss/faiss/utils/simdlib_avx2.h +12 -1
  162. data/vendor/faiss/faiss/utils/simdlib_avx512.h +69 -0
  163. data/vendor/faiss/faiss/utils/simdlib_neon.h +6 -0
  164. data/vendor/faiss/faiss/utils/sorting.cpp +4 -4
  165. data/vendor/faiss/faiss/utils/utils.cpp +16 -9
  166. metadata +47 -18
  167. data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +0 -81
  168. data/vendor/faiss/faiss/impl/code_distance/code_distance.h +0 -186
  169. /data/vendor/faiss/faiss/{cppcontrib/factory_tools.h → factory_tools.h} +0 -0
@@ -9,8 +9,10 @@
9
9
 
10
10
  #include <faiss/MetricType.h>
11
11
  #include <faiss/impl/platform_macros.h>
12
+ #include <faiss/utils/AlignedTable.h>
12
13
  #include <cstddef>
13
14
  #include <cstdint>
15
+ #include <cstring>
14
16
  #include <vector>
15
17
 
16
18
  namespace faiss {
@@ -68,6 +70,7 @@ struct QueryFactorsData {
68
70
 
69
71
  float qr_to_c_L2sqr = 0;
70
72
  float qr_norm_L2sqr = 0;
73
+ float q_dot_c = 0; // <query, centroid> for IP metric; 0 for L2
71
74
 
72
75
  float int_dot_scale = 1;
73
76
 
@@ -239,6 +242,41 @@ inline float compute_1bit_adjusted_distance(
239
242
  return adjusted_distance;
240
243
  }
241
244
 
245
+ /** Determine whether a candidate should be refined in two-stage search.
246
+ * Consolidates the filtering logic for both L2 and IP metrics.
247
+ *
248
+ * For L2 (min-heap): uses lower_bound = est_distance - error_adjustment
249
+ * - Skip if lower_bound >= threshold (can't beat current worst)
250
+ * For IP (max-heap): uses upper_bound = est_distance + error_adjustment
251
+ * - Skip if upper_bound <= threshold (can't beat current best)
252
+ *
253
+ * @param est_distance Estimated 1-bit distance
254
+ * @param f_error Database vector error factor
255
+ * @param g_error Query vector error factor
256
+ * @param threshold Current heap threshold (worst result in heap)
257
+ * @param is_similarity True for IP metric (max-heap), false for L2
258
+ * (min-heap)
259
+ * @return True if candidate should be refined with full
260
+ * multi-bit distance
261
+ */
262
+ inline bool should_refine_candidate(
263
+ float est_distance,
264
+ float f_error,
265
+ float g_error,
266
+ float threshold,
267
+ bool is_similarity) {
268
+ float error_adjustment = f_error * g_error;
269
+ if (is_similarity) {
270
+ // IP (max-heap): use upper bound for filtering
271
+ float upper_bound = est_distance + error_adjustment;
272
+ return upper_bound > threshold;
273
+ } else {
274
+ // L2 (min-heap): use lower bound for filtering
275
+ float lower_bound = std::max(0.0f, est_distance - error_adjustment);
276
+ return lower_bound < threshold;
277
+ }
278
+ }
279
+
242
280
  /** Extract multi-bit code on-the-fly from packed ex-bit codes.
243
281
  * This inline function extracts a single code value without unpacking the
244
282
  * entire array, enabling efficient on-the-fly decoding during distance
@@ -276,55 +314,79 @@ inline int extract_code_inline(
276
314
  *
277
315
  * The multi-bit distance combines the sign bit (1-bit) with additional
278
316
  * magnitude bits (ex_bits) to compute a more accurate distance estimate.
317
+ * Uses SIMD-optimized bit-plane decomposition (AVX2+BMI2) for ex_bits 1-7,
318
+ * with scalar fallback for non-x86 or non-BMI2 platforms.
279
319
  *
280
320
  * @param sign_bits unpacked sign bits (1-bit codes in standard format)
281
321
  * @param ex_code packed ex-bit codes
282
322
  * @param ex_fac ex-bit factors (f_add_ex, f_rescale_ex)
283
323
  * @param rotated_q rotated query vector
284
- * @param qr_to_c_L2sqr precomputed ||query_rotated - centroid||^2
285
- * @param qr_norm_L2sqr precomputed ||query_rotated||^2 (0 for L2 metric)
324
+ * @param qr_base precomputed base term: ||q-c||^2 for L2, <q,c> for IP
286
325
  * @param d dimensionality
287
326
  * @param ex_bits number of extra bits (nb_bits - 1)
288
327
  * @param metric_type distance metric (L2 or Inner Product)
289
328
  * @return computed full multi-bit distance
290
329
  */
291
- inline float compute_full_multibit_distance(
330
+ float compute_full_multibit_distance(
292
331
  const uint8_t* sign_bits,
293
332
  const uint8_t* ex_code,
294
333
  const ExtraBitsFactors& ex_fac,
295
334
  const float* rotated_q,
296
- float qr_to_c_L2sqr,
297
- float qr_norm_L2sqr,
335
+ float qr_base,
298
336
  size_t d,
299
337
  size_t ex_bits,
300
- MetricType metric_type) {
301
- float ex_ip = 0.0f;
302
- const float cb = -(static_cast<float>(1 << ex_bits) - 0.5f);
303
-
304
- for (size_t i = 0; i < d; i++) {
305
- const size_t byte_idx = i / 8;
306
- const size_t bit_offset = i % 8;
307
- const bool sign_bit = (sign_bits[byte_idx] >> bit_offset) & 1;
308
-
309
- int ex_code_val = extract_code_inline(ex_code, i, ex_bits);
310
-
311
- int total_code = (sign_bit ? 1 : 0) << ex_bits;
312
- total_code += ex_code_val;
313
- float reconstructed = static_cast<float>(total_code) + cb;
314
-
315
- ex_ip += rotated_q[i] * reconstructed;
316
- }
317
-
318
- float dist = qr_to_c_L2sqr + ex_fac.f_add_ex + ex_fac.f_rescale_ex * ex_ip;
338
+ MetricType metric_type);
339
+
340
+ /** Compute pointer to a vector's auxiliary data within block layout. */
341
+ template <typename T>
342
+ inline T* get_block_aux_ptr(
343
+ T* block_data,
344
+ size_t vec_pos,
345
+ size_t bbs,
346
+ size_t packed_block_size,
347
+ size_t full_block_size,
348
+ size_t storage_size) {
349
+ return block_data + (vec_pos / bbs) * full_block_size + packed_block_size +
350
+ (vec_pos % bbs) * storage_size;
351
+ }
319
352
 
320
- if (metric_type == MetricType::METRIC_INNER_PRODUCT) {
321
- dist = -0.5f * (dist - qr_norm_L2sqr);
322
- } else {
323
- dist = std::max(0.0f, dist);
324
- }
353
+ /** Compute per-vector auxiliary storage size.
354
+ *
355
+ * @param nb_bits number of quantization bits (1 = sign-bit only)
356
+ * @param d dimensionality
357
+ * @return storage size in bytes
358
+ */
359
+ size_t compute_per_vector_storage_size(size_t nb_bits, size_t d);
325
360
 
326
- return dist;
327
- }
361
+ /** [LEGACY FORMAT SUPPORT] Migrate block data from old I/O format to new
362
+ * format.
363
+ *
364
+ * This function is used only when reading indexes saved with the legacy format
365
+ * (fourcc "Irfs"/"Iwrf") to convert them to the new embedded auxiliary data
366
+ * format. Not needed for indexes saved with the new format ("Irfn"/"Iwrn").
367
+ *
368
+ * Re-layouts blocks in-place and copies aux data from flat_storage.
369
+ *
370
+ * @param flat_storage legacy per-vector aux data indexed by global ID
371
+ * @param codes block data (will be resized and re-laid out)
372
+ * @param num_vectors number of vectors in this segment
373
+ * @param bbs block batch size (vectors per block)
374
+ * @param M2 rounded sub-quantizer count
375
+ * @param old_block_stride old block size (packed codes only, or current)
376
+ * @param new_block_stride new block size (packed codes + aux region)
377
+ * @param storage_size per-vector aux storage size in bytes
378
+ * @param id_map maps local offset to global ID; null = sequential
379
+ */
380
+ void populate_block_aux_from_flat_storage(
381
+ const std::vector<uint8_t>& flat_storage,
382
+ AlignedTable<uint8_t>& codes,
383
+ size_t num_vectors,
384
+ size_t bbs,
385
+ size_t M2,
386
+ size_t old_block_stride,
387
+ size_t new_block_stride,
388
+ size_t storage_size,
389
+ const int64_t* id_map = nullptr);
328
390
 
329
391
  } // namespace rabitq_utils
330
392
  } // namespace faiss
@@ -12,6 +12,7 @@
12
12
  #include <faiss/impl/RaBitQuantizerMultiBit.h>
13
13
  #include <faiss/utils/distances.h>
14
14
  #include <faiss/utils/rabitq_simd.h>
15
+
15
16
  #include <algorithm>
16
17
  #include <cmath>
17
18
  #include <cstring>
@@ -63,7 +64,7 @@ size_t RaBitQuantizer::compute_code_size(size_t d, size_t num_bits) const {
63
64
  return base_size + ex_size;
64
65
  }
65
66
 
66
- void RaBitQuantizer::train(size_t n, const float* x) {
67
+ void RaBitQuantizer::train(size_t /*n*/, const float* /*x*/) {
67
68
  // does nothing
68
69
  }
69
70
 
@@ -215,29 +216,6 @@ void RaBitQuantizer::decode_core(
215
216
  }
216
217
  }
217
218
 
218
- // Implementation of RaBitQDistanceComputer (declared in header)
219
-
220
- float RaBitQDistanceComputer::lower_bound_distance(const uint8_t* code) {
221
- FAISS_ASSERT(code != nullptr);
222
-
223
- // Compute estimated distance using 1-bit codes
224
- float est_distance = distance_to_code_1bit(code);
225
-
226
- // Extract f_error from the code
227
- size_t size = (d + 7) / 8;
228
- const SignBitFactorsWithError* base_fac =
229
- reinterpret_cast<const SignBitFactorsWithError*>(code + size);
230
- float f_error = base_fac->f_error;
231
-
232
- // Compute proper lower bound using RaBitQ error formula:
233
- // lower_bound = est_distance - f_error * g_error
234
- // This guarantees: lower_bound ≤ true_distance
235
- float lower_bound = est_distance - (f_error * g_error);
236
-
237
- // Distance cannot be negative
238
- return std::max(0.0f, lower_bound);
239
- }
240
-
241
219
  namespace {
242
220
 
243
221
  struct RaBitQDistanceComputerNotQ : RaBitQDistanceComputer {
@@ -336,13 +314,15 @@ float RaBitQDistanceComputerNotQ::distance_to_code_full(const uint8_t* code) {
336
314
  ex_code + (d * ex_bits + 7) / 8);
337
315
 
338
316
  // Call shared utility directly with rotated_q pointer
317
+ float qr_base = (metric_type == MetricType::METRIC_INNER_PRODUCT)
318
+ ? query_fac.q_dot_c
319
+ : query_fac.qr_to_c_L2sqr;
339
320
  return rabitq_utils::compute_full_multibit_distance(
340
321
  binary_data,
341
322
  ex_code,
342
323
  *ex_fac,
343
324
  rotated_q.data(),
344
- query_fac.qr_to_c_L2sqr,
345
- query_fac.qr_norm_L2sqr,
325
+ qr_base,
346
326
  d,
347
327
  ex_bits,
348
328
  metric_type);
@@ -388,6 +368,8 @@ void RaBitQDistanceComputerNotQ::set_query(const float* x) {
388
368
  if (metric_type == MetricType::METRIC_INNER_PRODUCT) {
389
369
  // precompute if needed
390
370
  query_fac.qr_norm_L2sqr = fvec_norm_L2sqr(x, d);
371
+ query_fac.q_dot_c =
372
+ centroid ? fvec_inner_product(x, centroid, d) : 0.0f;
391
373
  }
392
374
  }
393
375
 
@@ -502,13 +484,15 @@ float RaBitQDistanceComputerQ::distance_to_code_full(const uint8_t* code) {
502
484
  ex_code + (d * ex_bits + 7) / 8);
503
485
 
504
486
  // Call shared utility directly with rotated_q pointer
487
+ float qr_base = (metric_type == MetricType::METRIC_INNER_PRODUCT)
488
+ ? query_fac.q_dot_c
489
+ : query_fac.qr_to_c_L2sqr;
505
490
  return rabitq_utils::compute_full_multibit_distance(
506
491
  binary_data,
507
492
  ex_code,
508
493
  *ex_fac,
509
494
  rotated_q.data(),
510
- query_fac.qr_to_c_L2sqr,
511
- query_fac.qr_norm_L2sqr,
495
+ qr_base,
512
496
  d,
513
497
  ex_bits,
514
498
  metric_type);
@@ -103,10 +103,8 @@ struct RaBitQuantizer : Quantizer {
103
103
  //
104
104
  // 1. distance_to_code_1bit() - Fast 1-bit filtering using only sign bits
105
105
  // 2. distance_to_code_full() - Accurate multi-bit refinement using all bits
106
- // 3. lower_bound_distance() - Error-bounded adaptive filtering
107
- // (based on 1-bit estimator)
108
106
  //
109
- // These three methods implement RaBitQ's two-stage search pattern and are
107
+ // These methods implement RaBitQ's two-stage search pattern and are
110
108
  // shared between the quantized (Q) and non-quantized (NotQ) query variants.
111
109
  // The intermediate class allows two-stage search code to work with both
112
110
  // variants via a single dynamic_cast.
@@ -116,8 +114,8 @@ struct RaBitQDistanceComputer : FlatCodesDistanceComputer {
116
114
  MetricType metric_type = MetricType::METRIC_L2;
117
115
  size_t nb_bits = 1;
118
116
 
119
- // Query norm for lower bound computation (g_error in rabitq-library)
120
- // This is the L2 norm of the rotated query: ||query - centroid||
117
+ // Query error factor for bound computation (g_error in rabitq-library)
118
+ // Used with f_error to compute error bounds for two-stage filtering
121
119
  float g_error = 0.0f;
122
120
 
123
121
  float symmetric_dis(idx_t /*i*/, idx_t /*j*/) override {
@@ -131,11 +129,6 @@ struct RaBitQDistanceComputer : FlatCodesDistanceComputer {
131
129
  // Compute full multi-bit distance (accurate)
132
130
  virtual float distance_to_code_full(const uint8_t* code) = 0;
133
131
 
134
- // Compute lower bound of distance using error bounds
135
- // Guarantees: actual_distance >= lower_bound_distance
136
- // Used for adaptive filtering in two-stage search
137
- virtual float lower_bound_distance(const uint8_t* code);
138
-
139
132
  // Override from FlatCodesDistanceComputer
140
133
  // Delegates to distance_to_code_full() for multi-bit distance computation
141
134
  float distance_to_code(const uint8_t* code) final {
@@ -180,9 +180,7 @@ void pack_multibit_codes(
180
180
  *
181
181
  * @param residual Original residual vector (data - centroid)
182
182
  * @param centroid Centroid vector (can be nullptr for zero centroid)
183
- * @param tmp_code Quantized ex-bit codes (before packing, after bit flipping)
184
183
  * @param d Dimensionality
185
- * @param ex_bits Number of extra bits
186
184
  * @param norm L2 norm of residual
187
185
  * @param ipnorm Unnormalized inner product between quantized and normalized
188
186
  * residual
@@ -192,9 +190,7 @@ void pack_multibit_codes(
192
190
  void compute_ex_factors(
193
191
  const float* residual,
194
192
  const float* centroid,
195
- const int* tmp_code,
196
193
  size_t d,
197
- size_t ex_bits,
198
194
  float norm,
199
195
  double ipnorm,
200
196
  ExtraBitsFactors& ex_factors,
@@ -210,45 +206,23 @@ void compute_ex_factors(
210
206
  ipnorm_inv = 1.0f;
211
207
  }
212
208
 
213
- // Reconstruct xu_cb from total_code
214
- // total_code was formed from: total_code[i] = (sign << ex_bits) +
215
- // ex_code[i] Reconstruction: xu_cb[i] = total_code[i] + cb
216
- const float cb = -(static_cast<float>(1 << ex_bits) - 0.5f);
217
- std::vector<float> xu_cb(d);
218
- for (size_t i = 0; i < d; i++) {
219
- xu_cb[i] = static_cast<float>(tmp_code[i]) + cb;
220
- }
221
-
222
209
  // Compute inner products needed for factors
223
- float l2_sqr = norm * norm;
224
- float ip_resi_xucb = fvec_inner_product(residual, xu_cb.data(), d);
210
+ float l2_sqr = norm * norm; // ||residual||^2 = ||x - c||^2
225
211
 
226
- // Compute factors
227
212
  if (metric_type == MetricType::METRIC_L2) {
228
- // For L2, no centroid correction needed in IVF setting
229
- // because residual = x - centroid, distance computed in residual space
213
+ // For L2: f_add_ex = ||residual||^2
214
+ // No centroid correction needed in IVF setting because
215
+ // residual = x - centroid, distance computed in residual space
230
216
  ex_factors.f_add_ex = l2_sqr;
231
217
  ex_factors.f_rescale_ex = ipnorm_inv * -2.0f * norm;
232
218
  } else {
233
- // For IP, centroid correction is needed
234
- float ip_resi_cent = 0;
235
- if (centroid != nullptr) {
236
- ip_resi_cent = fvec_inner_product(residual, centroid, d);
237
- }
238
-
239
- float ip_cent_xucb = 0;
240
- if (centroid != nullptr) {
241
- ip_cent_xucb = fvec_inner_product(centroid, xu_cb.data(), d);
242
- }
243
-
244
- // When ip_resi_xucb is zero, the correction term should be zero
245
- float correction_term = 0.0f;
246
- if (ip_resi_xucb != 0.0f) {
247
- correction_term = l2_sqr * ip_cent_xucb / ip_resi_xucb;
248
- }
249
-
250
- ex_factors.f_add_ex = 1 - ip_resi_cent + correction_term;
251
- ex_factors.f_rescale_ex = ipnorm_inv * -norm;
219
+ // For IP: direct dot-product formulation
220
+ // f_add_ex = <c, r> (dot product of centroid and residual)
221
+ // f_rescale_ex = ||r|| / ipnorm (positive scaling)
222
+ float c_dot_r =
223
+ centroid ? fvec_inner_product(residual, centroid, d) : 0.0f;
224
+ ex_factors.f_add_ex = c_dot_r;
225
+ ex_factors.f_rescale_ex = ipnorm_inv * norm;
252
226
  }
253
227
  }
254
228
 
@@ -290,12 +264,14 @@ void quantize_ex_bits(
290
264
  float norm_sqr = fvec_norm_L2sqr(residual, d);
291
265
  float norm = std::sqrt(norm_sqr);
292
266
 
293
- // Handle degenerate case
267
+ // Handle degenerate case: residual is (near-)zero, meaning x ≈ centroid.
268
+ // For both L2 and IP, f_add_ex and f_rescale_ex are trivially zero:
269
+ // L2: ||r||² ≈ 0, IP: <c,r> ≈ 0 and ||r||/ipnorm ≈ 0
294
270
  if (norm < 1e-10f) {
295
271
  size_t code_size = (d * ex_bits + 7) / 8;
296
272
  memset(ex_code, 0, code_size);
297
273
  ex_factors.f_add_ex = 0.0f;
298
- ex_factors.f_rescale_ex = 1.0f;
274
+ ex_factors.f_rescale_ex = 0.0f;
299
275
  return;
300
276
  }
301
277
 
@@ -349,9 +325,7 @@ void quantize_ex_bits(
349
325
  compute_ex_factors(
350
326
  residual,
351
327
  centroid, // Pass centroid for IP metric factor computation
352
- total_code.data(),
353
328
  d,
354
- ex_bits,
355
329
  norm,
356
330
  ipnorm,
357
331
  ex_factors,
@@ -60,9 +60,7 @@ void pack_multibit_codes(
60
60
  *
61
61
  * @param residual Original residual vector (data - centroid)
62
62
  * @param centroid Centroid vector (can be nullptr for zero centroid)
63
- * @param tmp_code Quantized ex-bit codes (unpacked integers)
64
63
  * @param d Dimensionality
65
- * @param ex_bits Number of extra bits
66
64
  * @param norm L2 norm of residual
67
65
  * @param ipnorm Unnormalized inner product
68
66
  * @param ex_factors Output factors structure
@@ -71,9 +69,7 @@ void pack_multibit_codes(
71
69
  void compute_ex_factors(
72
70
  const float* residual,
73
71
  const float* centroid,
74
- const int* tmp_code,
75
72
  size_t d,
76
- size_t ex_bits,
77
73
  float norm,
78
74
  double ipnorm,
79
75
  rabitq_utils::ExtraBitsFactors& ex_factors,
@@ -18,6 +18,7 @@
18
18
  #include <faiss/VectorTransform.h>
19
19
  #include <faiss/impl/FaissAssert.h>
20
20
  #include <faiss/impl/residual_quantizer_encode_steps.h>
21
+ #include <faiss/impl/simd_dispatch.h>
21
22
  #include <faiss/utils/distances.h>
22
23
  #include <faiss/utils/hamming.h>
23
24
  #include <faiss/utils/utils.h>
@@ -274,10 +275,12 @@ void ResidualQuantizer::train(size_t n, const float* x) {
274
275
  // find min and max norms
275
276
  std::vector<float> norms(n);
276
277
 
277
- for (size_t i = 0; i < n; i++) {
278
- norms[i] = fvec_L2sqr(
279
- x + i * d, residuals.data() + i * cur_beam_size * d, d);
280
- }
278
+ with_simd_level([&]<SIMDLevel SL>() {
279
+ for (size_t i = 0; i < n; i++) {
280
+ norms[i] = fvec_L2sqr<SL>(
281
+ x + i * d, residuals.data() + i * cur_beam_size * d, d);
282
+ }
283
+ });
281
284
 
282
285
  // fvec_norms_L2sqr(norms.data(), x, d, n);
283
286
  train_norm(n, norms.data());
@@ -393,11 +396,13 @@ float ResidualQuantizer::retrain_AQ_codebook(size_t n, const float* x) {
393
396
  }
394
397
 
395
398
  float output_recons_error = 0;
396
- for (size_t j = 0; j < d; j++) {
397
- output_recons_error += fvec_norm_L2sqr(
398
- xt.data() + total_codebook_size + n * j,
399
- n - total_codebook_size);
400
- }
399
+ with_simd_level([&]<SIMDLevel SL>() {
400
+ for (size_t j = 0; j < d; j++) {
401
+ output_recons_error += fvec_norm_L2sqr<SL>(
402
+ xt.data() + total_codebook_size + n * j,
403
+ n - total_codebook_size);
404
+ }
405
+ });
401
406
  if (verbose) {
402
407
  printf(" output quantization error %g\n", output_recons_error);
403
408
  }