faiss 0.5.3 → 0.6.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (167)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +8 -0
  3. data/ext/faiss/ext.cpp +1 -1
  4. data/ext/faiss/extconf.rb +5 -6
  5. data/ext/faiss/index_binary.cpp +38 -28
  6. data/ext/faiss/{index.cpp → index_rb.cpp} +64 -46
  7. data/ext/faiss/kmeans.cpp +10 -9
  8. data/ext/faiss/pca_matrix.cpp +10 -8
  9. data/ext/faiss/product_quantizer.cpp +14 -12
  10. data/ext/faiss/{utils.cpp → utils_rb.cpp} +5 -3
  11. data/ext/faiss/{utils.h → utils_rb.h} +4 -0
  12. data/lib/faiss/version.rb +1 -1
  13. data/lib/faiss.rb +1 -1
  14. data/vendor/faiss/faiss/AutoTune.cpp +130 -11
  15. data/vendor/faiss/faiss/AutoTune.h +14 -1
  16. data/vendor/faiss/faiss/Clustering.cpp +59 -10
  17. data/vendor/faiss/faiss/Clustering.h +12 -0
  18. data/vendor/faiss/faiss/IVFlib.cpp +31 -28
  19. data/vendor/faiss/faiss/Index.cpp +20 -8
  20. data/vendor/faiss/faiss/Index.h +25 -3
  21. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +19 -24
  22. data/vendor/faiss/faiss/IndexBinary.cpp +1 -0
  23. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +9 -4
  24. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +45 -11
  25. data/vendor/faiss/faiss/IndexFastScan.cpp +35 -22
  26. data/vendor/faiss/faiss/IndexFastScan.h +10 -1
  27. data/vendor/faiss/faiss/IndexFlat.cpp +193 -136
  28. data/vendor/faiss/faiss/IndexFlat.h +16 -1
  29. data/vendor/faiss/faiss/IndexFlatCodes.cpp +46 -22
  30. data/vendor/faiss/faiss/IndexFlatCodes.h +7 -1
  31. data/vendor/faiss/faiss/IndexHNSW.cpp +24 -50
  32. data/vendor/faiss/faiss/IndexHNSW.h +14 -12
  33. data/vendor/faiss/faiss/IndexIDMap.cpp +1 -1
  34. data/vendor/faiss/faiss/IndexIVF.cpp +76 -49
  35. data/vendor/faiss/faiss/IndexIVF.h +14 -4
  36. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +11 -8
  37. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +2 -2
  38. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +25 -14
  39. data/vendor/faiss/faiss/IndexIVFFastScan.h +26 -22
  40. data/vendor/faiss/faiss/IndexIVFFlat.cpp +10 -61
  41. data/vendor/faiss/faiss/IndexIVFFlatPanorama.cpp +39 -111
  42. data/vendor/faiss/faiss/IndexIVFPQ.cpp +89 -147
  43. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +37 -5
  44. data/vendor/faiss/faiss/IndexIVFPQR.cpp +2 -1
  45. data/vendor/faiss/faiss/IndexIVFRaBitQ.cpp +42 -30
  46. data/vendor/faiss/faiss/IndexIVFRaBitQ.h +2 -2
  47. data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.cpp +246 -97
  48. data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.h +32 -29
  49. data/vendor/faiss/faiss/IndexLSH.cpp +8 -6
  50. data/vendor/faiss/faiss/IndexLattice.cpp +29 -24
  51. data/vendor/faiss/faiss/IndexNNDescent.cpp +1 -0
  52. data/vendor/faiss/faiss/IndexNSG.cpp +2 -1
  53. data/vendor/faiss/faiss/IndexNSG.h +0 -2
  54. data/vendor/faiss/faiss/IndexNeuralNetCodec.cpp +1 -1
  55. data/vendor/faiss/faiss/IndexPQ.cpp +19 -10
  56. data/vendor/faiss/faiss/IndexRaBitQ.cpp +26 -13
  57. data/vendor/faiss/faiss/IndexRaBitQ.h +2 -2
  58. data/vendor/faiss/faiss/IndexRaBitQFastScan.cpp +132 -78
  59. data/vendor/faiss/faiss/IndexRaBitQFastScan.h +14 -12
  60. data/vendor/faiss/faiss/IndexRefine.cpp +0 -30
  61. data/vendor/faiss/faiss/IndexShards.cpp +3 -4
  62. data/vendor/faiss/faiss/MetricType.h +16 -0
  63. data/vendor/faiss/faiss/VectorTransform.cpp +120 -0
  64. data/vendor/faiss/faiss/VectorTransform.h +23 -0
  65. data/vendor/faiss/faiss/clone_index.cpp +7 -4
  66. data/vendor/faiss/faiss/{cppcontrib/factory_tools.cpp → factory_tools.cpp} +1 -1
  67. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +1 -1
  68. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +37 -11
  69. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +0 -28
  70. data/vendor/faiss/faiss/impl/ClusteringInitialization.cpp +367 -0
  71. data/vendor/faiss/faiss/impl/ClusteringInitialization.h +107 -0
  72. data/vendor/faiss/faiss/impl/CodePacker.cpp +4 -0
  73. data/vendor/faiss/faiss/impl/CodePacker.h +11 -3
  74. data/vendor/faiss/faiss/impl/CodePackerRaBitQ.cpp +83 -0
  75. data/vendor/faiss/faiss/impl/CodePackerRaBitQ.h +47 -0
  76. data/vendor/faiss/faiss/impl/FaissAssert.h +60 -2
  77. data/vendor/faiss/faiss/impl/HNSW.cpp +25 -34
  78. data/vendor/faiss/faiss/impl/HNSW.h +8 -6
  79. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +34 -27
  80. data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -1
  81. data/vendor/faiss/faiss/impl/NSG.cpp +6 -5
  82. data/vendor/faiss/faiss/impl/NSG.h +17 -7
  83. data/vendor/faiss/faiss/impl/Panorama.cpp +53 -46
  84. data/vendor/faiss/faiss/impl/Panorama.h +22 -6
  85. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +16 -5
  86. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +70 -58
  87. data/vendor/faiss/faiss/impl/RaBitQUtils.cpp +92 -0
  88. data/vendor/faiss/faiss/impl/RaBitQUtils.h +93 -31
  89. data/vendor/faiss/faiss/impl/RaBitQuantizer.cpp +12 -28
  90. data/vendor/faiss/faiss/impl/RaBitQuantizer.h +3 -10
  91. data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.cpp +15 -41
  92. data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.h +0 -4
  93. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +14 -9
  94. data/vendor/faiss/faiss/impl/ResultHandler.h +131 -50
  95. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +67 -2358
  96. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +0 -2
  97. data/vendor/faiss/faiss/impl/VisitedTable.cpp +42 -0
  98. data/vendor/faiss/faiss/impl/VisitedTable.h +69 -0
  99. data/vendor/faiss/faiss/impl/expanded_scanners.h +158 -0
  100. data/vendor/faiss/faiss/impl/index_read.cpp +829 -471
  101. data/vendor/faiss/faiss/impl/index_read_utils.h +0 -1
  102. data/vendor/faiss/faiss/impl/index_write.cpp +17 -8
  103. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +47 -20
  104. data/vendor/faiss/faiss/impl/mapped_io.cpp +9 -2
  105. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +7 -2
  106. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +11 -3
  107. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +19 -13
  108. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +29 -21
  109. data/vendor/faiss/faiss/impl/{code_distance/code_distance-avx2.h → pq_code_distance/pq_code_distance-avx2.cpp} +42 -215
  110. data/vendor/faiss/faiss/impl/{code_distance/code_distance-avx512.h → pq_code_distance/pq_code_distance-avx512.cpp} +68 -107
  111. data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-generic.cpp +141 -0
  112. data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-inl.h +23 -0
  113. data/vendor/faiss/faiss/impl/{code_distance/code_distance-sve.h → pq_code_distance/pq_code_distance-sve.cpp} +57 -144
  114. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +9 -6
  115. data/vendor/faiss/faiss/impl/scalar_quantizer/codecs.h +121 -0
  116. data/vendor/faiss/faiss/impl/scalar_quantizer/distance_computers.h +136 -0
  117. data/vendor/faiss/faiss/impl/scalar_quantizer/quantizers.h +280 -0
  118. data/vendor/faiss/faiss/impl/scalar_quantizer/scanners.h +164 -0
  119. data/vendor/faiss/faiss/impl/scalar_quantizer/similarities.h +94 -0
  120. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx2.cpp +455 -0
  121. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx512.cpp +430 -0
  122. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-dispatch.h +329 -0
  123. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-neon.cpp +467 -0
  124. data/vendor/faiss/faiss/impl/scalar_quantizer/training.cpp +203 -0
  125. data/vendor/faiss/faiss/impl/scalar_quantizer/training.h +42 -0
  126. data/vendor/faiss/faiss/impl/simd_dispatch.h +139 -0
  127. data/vendor/faiss/faiss/impl/simd_result_handlers.h +18 -18
  128. data/vendor/faiss/faiss/index_factory.cpp +35 -16
  129. data/vendor/faiss/faiss/index_io.h +29 -3
  130. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +7 -4
  131. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +1 -1
  132. data/vendor/faiss/faiss/svs/IndexSVSFaissUtils.h +9 -19
  133. data/vendor/faiss/faiss/svs/IndexSVSFlat.h +2 -0
  134. data/vendor/faiss/faiss/svs/IndexSVSVamana.h +2 -1
  135. data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.cpp +9 -1
  136. data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.h +9 -0
  137. data/vendor/faiss/faiss/utils/Heap.cpp +46 -0
  138. data/vendor/faiss/faiss/utils/Heap.h +21 -0
  139. data/vendor/faiss/faiss/utils/NeuralNet.cpp +10 -7
  140. data/vendor/faiss/faiss/utils/distances.cpp +141 -23
  141. data/vendor/faiss/faiss/utils/distances.h +98 -0
  142. data/vendor/faiss/faiss/utils/distances_dispatch.h +170 -0
  143. data/vendor/faiss/faiss/utils/distances_simd.cpp +74 -3511
  144. data/vendor/faiss/faiss/utils/extra_distances-inl.h +164 -157
  145. data/vendor/faiss/faiss/utils/extra_distances.cpp +52 -95
  146. data/vendor/faiss/faiss/utils/extra_distances.h +47 -1
  147. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +0 -1
  148. data/vendor/faiss/faiss/utils/partitioning.cpp +1 -1
  149. data/vendor/faiss/faiss/utils/pq_code_distance.h +251 -0
  150. data/vendor/faiss/faiss/utils/rabitq_simd.h +260 -0
  151. data/vendor/faiss/faiss/utils/simd_impl/distances_aarch64.cpp +150 -0
  152. data/vendor/faiss/faiss/utils/simd_impl/distances_arm_sve.cpp +568 -0
  153. data/vendor/faiss/faiss/utils/simd_impl/distances_autovec-inl.h +153 -0
  154. data/vendor/faiss/faiss/utils/simd_impl/distances_avx2.cpp +1185 -0
  155. data/vendor/faiss/faiss/utils/simd_impl/distances_avx512.cpp +1092 -0
  156. data/vendor/faiss/faiss/utils/simd_impl/distances_sse-inl.h +391 -0
  157. data/vendor/faiss/faiss/utils/simd_levels.cpp +322 -0
  158. data/vendor/faiss/faiss/utils/simd_levels.h +91 -0
  159. data/vendor/faiss/faiss/utils/simdlib_avx2.h +12 -1
  160. data/vendor/faiss/faiss/utils/simdlib_avx512.h +69 -0
  161. data/vendor/faiss/faiss/utils/simdlib_neon.h +6 -0
  162. data/vendor/faiss/faiss/utils/sorting.cpp +4 -4
  163. data/vendor/faiss/faiss/utils/utils.cpp +16 -9
  164. metadata +47 -18
  165. data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +0 -81
  166. data/vendor/faiss/faiss/impl/code_distance/code_distance.h +0 -186
  167. data/vendor/faiss/faiss/{cppcontrib/factory_tools.h → factory_tools.h} +0 -0
@@ -9,6 +9,7 @@
9
9
 
10
10
  #include <faiss/IndexLattice.h>
11
11
  #include <faiss/impl/FaissAssert.h>
12
+ #include <faiss/impl/simd_dispatch.h>
12
13
  #include <faiss/utils/distances.h>
13
14
  #include <faiss/utils/hamming.h> // for the bitstring routines
14
15
 
@@ -44,17 +45,19 @@ void IndexLattice::train(idx_t n, const float* x) {
44
45
  maxs[sq] = -1;
45
46
  }
46
47
 
47
- for (idx_t i = 0; i < n; i++) {
48
- for (int sq = 0; sq < nsq; sq++) {
49
- float norm2 = fvec_norm_L2sqr(x + i * d + sq * dsq, dsq);
50
- if (norm2 > maxs[sq]) {
51
- maxs[sq] = norm2;
52
- }
53
- if (norm2 < mins[sq]) {
54
- mins[sq] = norm2;
48
+ with_simd_level([&]<SIMDLevel SL>() {
49
+ for (idx_t i = 0; i < n; i++) {
50
+ for (int sq = 0; sq < nsq; sq++) {
51
+ float norm2 = fvec_norm_L2sqr<SL>(x + i * d + sq * dsq, dsq);
52
+ if (norm2 > maxs[sq]) {
53
+ maxs[sq] = norm2;
54
+ }
55
+ if (norm2 < mins[sq]) {
56
+ mins[sq] = norm2;
57
+ }
55
58
  }
56
59
  }
57
- }
60
+ });
58
61
 
59
62
  for (int sq = 0; sq < nsq; sq++) {
60
63
  mins[sq] = sqrtf(mins[sq]);
@@ -74,24 +77,26 @@ void IndexLattice::sa_encode(idx_t n, const float* x, uint8_t* codes) const {
74
77
  const float* maxs = mins + nsq;
75
78
  int64_t sc = int64_t(1) << scale_nbit;
76
79
 
80
+ with_simd_level([&]<SIMDLevel SL>() {
77
81
  #pragma omp parallel for
78
- for (idx_t i = 0; i < n; i++) {
79
- BitstringWriter wr(codes + i * code_size, code_size);
80
- const float* xi = x + i * d;
81
- for (int j = 0; j < nsq; j++) {
82
- float nj = (sqrtf(fvec_norm_L2sqr(xi, dsq)) - mins[j]) * sc /
83
- (maxs[j] - mins[j]);
84
- if (nj < 0) {
85
- nj = 0;
86
- }
87
- if (nj >= sc) {
88
- nj = sc - 1;
82
+ for (idx_t i = 0; i < n; i++) {
83
+ BitstringWriter wr(codes + i * code_size, code_size);
84
+ const float* xi = x + i * d;
85
+ for (int j = 0; j < nsq; j++) {
86
+ float nj = (sqrtf(fvec_norm_L2sqr<SL>(xi, dsq)) - mins[j]) *
87
+ sc / (maxs[j] - mins[j]);
88
+ if (nj < 0) {
89
+ nj = 0;
90
+ }
91
+ if (nj >= sc) {
92
+ nj = sc - 1;
93
+ }
94
+ wr.write((int64_t)nj, scale_nbit);
95
+ wr.write(zn_sphere_codec.encode(xi), lattice_nbit);
96
+ xi += dsq;
89
97
  }
90
- wr.write((int64_t)nj, scale_nbit);
91
- wr.write(zn_sphere_codec.encode(xi), lattice_nbit);
92
- xi += dsq;
93
98
  }
94
- }
99
+ });
95
100
  }
96
101
 
97
102
  void IndexLattice::sa_decode(idx_t n, const uint8_t* codes, float* x) const {
@@ -19,6 +19,7 @@
19
19
  #include <faiss/IndexFlat.h>
20
20
  #include <faiss/impl/AuxIndexStructures.h>
21
21
  #include <faiss/impl/FaissAssert.h>
22
+ #include <faiss/impl/VisitedTable.h>
22
23
  #include <faiss/utils/distances.h>
23
24
 
24
25
  extern "C" {
@@ -16,6 +16,7 @@
16
16
  #include <faiss/IndexNNDescent.h>
17
17
  #include <faiss/impl/AuxIndexStructures.h>
18
18
  #include <faiss/impl/FaissAssert.h>
19
+ #include <faiss/impl/VisitedTable.h>
19
20
  #include <faiss/utils/distances.h>
20
21
 
21
22
  namespace faiss {
@@ -74,7 +75,7 @@ void IndexNSG::search(
74
75
 
75
76
  #pragma omp parallel
76
77
  {
77
- VisitedTable vt(ntotal);
78
+ VisitedTable vt(ntotal, nsg.use_visited_hashset);
78
79
 
79
80
  std::unique_ptr<DistanceComputer> dis(
80
81
  storage_distance_computer(storage));
@@ -9,8 +9,6 @@
9
9
 
10
10
  #pragma once
11
11
 
12
- #include <vector>
13
-
14
12
  #include <faiss/IndexFlat.h>
15
13
  #include <faiss/IndexNNDescent.h>
16
14
  #include <faiss/IndexPQ.h>
@@ -24,7 +24,7 @@ IndexNeuralNetCodec::IndexNeuralNetCodec(
24
24
  is_trained = false;
25
25
  }
26
26
 
27
- void IndexNeuralNetCodec::train(idx_t n, const float* x) {
27
+ void IndexNeuralNetCodec::train(idx_t /*n*/, const float* /*x*/) {
28
28
  FAISS_THROW_MSG("Training not implemented in C++, use Pytorch");
29
29
  }
30
30
 
@@ -19,7 +19,8 @@
19
19
  #include <faiss/impl/FaissAssert.h>
20
20
  #include <faiss/utils/hamming.h>
21
21
 
22
- #include <faiss/impl/code_distance/code_distance.h>
22
+ #include <faiss/impl/pq_code_distance/pq_code_distance-inl.h>
23
+ #include <faiss/impl/simd_dispatch.h>
23
24
 
24
25
  namespace faiss {
25
26
 
@@ -72,8 +73,9 @@ void IndexPQ::train(idx_t n, const float* x) {
72
73
 
73
74
  namespace {
74
75
 
75
- template <class PQDecoder>
76
+ template <class PQCodeDist>
76
77
  struct PQDistanceComputer : FlatCodesDistanceComputer {
78
+ using PQDecoder = typename PQCodeDist::PQDecoder;
77
79
  size_t d;
78
80
  MetricType metric;
79
81
  idx_t nb;
@@ -86,7 +88,7 @@ struct PQDistanceComputer : FlatCodesDistanceComputer {
86
88
  float distance_to_code(const uint8_t* code) final {
87
89
  ndis++;
88
90
 
89
- float dis = distance_single_code<PQDecoder>(
91
+ float dis = PQCodeDist::distance_single_code(
90
92
  pq.M, pq.nbits, precomputed_table.data(), code);
91
93
  return dis;
92
94
  }
@@ -134,16 +136,23 @@ struct PQDistanceComputer : FlatCodesDistanceComputer {
134
136
  }
135
137
  };
136
138
 
139
+ template <SIMDLevel SL>
140
+ FlatCodesDistanceComputer* get_FlatCodesDistanceComputer1(
141
+ const IndexPQ& index) {
142
+ if (index.pq.nbits == 8) {
143
+ return new PQDistanceComputer<PQCodeDistance<PQDecoder8, SL>>(index);
144
+ } else if (index.pq.nbits == 16) {
145
+ return new PQDistanceComputer<PQCodeDistance<PQDecoder16, SL>>(index);
146
+ } else {
147
+ return new PQDistanceComputer<PQCodeDistance<PQDecoderGeneric, SL>>(
148
+ index);
149
+ }
150
+ }
151
+
137
152
  } // namespace
138
153
 
139
154
  FlatCodesDistanceComputer* IndexPQ::get_FlatCodesDistanceComputer() const {
140
- if (pq.nbits == 8) {
141
- return new PQDistanceComputer<PQDecoder8>(*this);
142
- } else if (pq.nbits == 16) {
143
- return new PQDistanceComputer<PQDecoder16>(*this);
144
- } else {
145
- return new PQDistanceComputer<PQDecoderGeneric>(*this);
146
- }
155
+ DISPATCH_SIMDLevel(get_FlatCodesDistanceComputer1, *this);
147
156
  }
148
157
 
149
158
  /*****************************************
@@ -8,6 +8,7 @@
8
8
  #include <faiss/IndexRaBitQ.h>
9
9
 
10
10
  #include <faiss/impl/FaissAssert.h>
11
+ #include <faiss/impl/RaBitQUtils.h>
11
12
  #include <faiss/impl/ResultHandler.h>
12
13
  #include <memory>
13
14
 
@@ -16,6 +17,8 @@ namespace faiss {
16
17
  // Forward declaration from RaBitQuantizer.cpp
17
18
  struct RaBitQDistanceComputer;
18
19
 
20
+ using rabitq_utils::SignBitFactorsWithError;
21
+
19
22
  IndexRaBitQ::IndexRaBitQ() = default;
20
23
 
21
24
  IndexRaBitQ::IndexRaBitQ(idx_t d, MetricType metric, uint8_t nb_bits_in)
@@ -141,19 +144,29 @@ struct Run_search_with_dc_res {
141
144
 
142
145
  local_1bit_evaluations++;
143
146
 
144
- // Stage 1: Compute 1-bit lower bound
145
- float lower_bound = dc->lower_bound_distance(code);
146
-
147
- // Stage 2: Adaptive filtering using threshold
148
- // For L2 (min-heap): filter if lower_bound <
149
- // resi.threshold For IP (max-heap): filter if
150
- // lower_bound > resi.threshold Note: Using
151
- // resi.threshold directly (not cached) enables more
152
- // aggressive filtering as the heap is updated
153
- bool should_refine = is_similarity
154
- ? (lower_bound > resi.threshold)
155
- : (lower_bound < resi.threshold);
156
-
147
+ // Stage 1: Compute distance bound using 1-bit codes
148
+ // For L2 (min-heap): use lower_bound (est -
149
+ // error) For IP (max-heap): use upper_bound (est
150
+ // + error)
151
+ float est_distance =
152
+ dc->distance_to_code_1bit(code);
153
+
154
+ // Extract f_error for filtering
155
+ size_t code_size_base = (index->d + 7) / 8;
156
+ const rabitq_utils::SignBitFactorsWithError*
157
+ base_fac = reinterpret_cast<
158
+ const rabitq_utils::
159
+ SignBitFactorsWithError*>(
160
+ code + code_size_base);
161
+
162
+ // Stage 2: Adaptive filtering
163
+ bool should_refine =
164
+ rabitq_utils::should_refine_candidate(
165
+ est_distance,
166
+ base_fac->f_error,
167
+ dc->g_error,
168
+ resi.threshold,
169
+ is_similarity);
157
170
  if (should_refine) {
158
171
  local_multibit_evaluations++;
159
172
  // Compute full multi-bit distance
@@ -14,7 +14,7 @@
14
14
  namespace faiss {
15
15
 
16
16
  struct RaBitQSearchParameters : SearchParameters {
17
- uint8_t qb = 0;
17
+ uint8_t qb = 4;
18
18
  bool centered = false;
19
19
  };
20
20
 
@@ -26,7 +26,7 @@ struct IndexRaBitQ : IndexFlatCodes {
26
26
 
27
27
  // the default number of bits to quantize a query with.
28
28
  // use '0' to disable quantization and use raw fp32 values.
29
- uint8_t qb = 0;
29
+ uint8_t qb = 4;
30
30
 
31
31
  // quantize the query with a zero-centered scalar quantizer.
32
32
  bool centered = false;
@@ -6,6 +6,7 @@
6
6
  */
7
7
 
8
8
  #include <faiss/IndexRaBitQFastScan.h>
9
+ #include <faiss/impl/CodePackerRaBitQ.h>
9
10
  #include <faiss/impl/FastScanDistancePostProcessing.h>
10
11
  #include <faiss/impl/RaBitQUtils.h>
11
12
  #include <faiss/impl/RaBitQuantizerMultiBit.h>
@@ -21,17 +22,7 @@ static inline size_t roundup(size_t a, size_t b) {
21
22
  }
22
23
 
23
24
  size_t IndexRaBitQFastScan::compute_per_vector_storage_size() const {
24
- const size_t ex_bits = rabitq.nb_bits - 1;
25
-
26
- if (ex_bits == 0) {
27
- // 1-bit: only SignBitFactors
28
- return sizeof(rabitq_utils::SignBitFactors);
29
- } else {
30
- // Multi-bit: SignBitFactorsWithError + ExtraBitsFactors +
31
- // mag-codes
32
- return sizeof(SignBitFactorsWithError) + sizeof(ExtraBitsFactors) +
33
- (d * ex_bits + 7) / 8;
34
- }
25
+ return rabitq_utils::compute_per_vector_storage_size(rabitq.nb_bits, d);
35
26
  }
36
27
 
37
28
  IndexRaBitQFastScan::IndexRaBitQFastScan() = default;
@@ -64,9 +55,51 @@ IndexRaBitQFastScan::IndexRaBitQFastScan(
64
55
  // Set RaBitQ-specific parameters
65
56
  qb = 8;
66
57
  center.resize(d, 0.0f);
58
+ }
67
59
 
68
- // Initialize empty flat storage
69
- flat_storage.clear();
60
+ CodePacker* IndexRaBitQFastScan::get_CodePacker() const {
61
+ return new CodePackerRaBitQ(M2, bbs, compute_per_vector_storage_size());
62
+ }
63
+
64
+ size_t IndexRaBitQFastScan::remove_ids(const IDSelector& sel) {
65
+ const size_t block_stride = get_block_stride();
66
+
67
+ idx_t j = 0;
68
+ std::vector<uint8_t> buffer(code_size);
69
+ std::unique_ptr<CodePacker> packer(get_CodePacker());
70
+ for (idx_t i = 0; i < ntotal; i++) {
71
+ if (sel.is_member(i)) {
72
+ } else {
73
+ if (i > j) {
74
+ packer->unpack_1(codes.data(), i, buffer.data());
75
+ packer->pack_1(buffer.data(), j, codes.data());
76
+ }
77
+ j++;
78
+ }
79
+ }
80
+ size_t nremove = ntotal - j;
81
+ if (nremove > 0) {
82
+ ntotal = j;
83
+ ntotal2 = roundup(ntotal, bbs);
84
+ size_t new_size = ntotal2 / bbs * block_stride;
85
+
86
+ // Zero out stale data in the last block beyond the retained vectors.
87
+ // This is necessary because pq4_pack_codes_range uses |= to write
88
+ // new codes, so any stale non-zero nibbles would corrupt future adds.
89
+ // pack_1 with a zero buffer zeroes both PQ4 codes and aux data.
90
+ const size_t last_pos = ntotal % bbs;
91
+ if (last_pos > 0) {
92
+ const size_t last_block = ntotal / bbs;
93
+ std::vector<uint8_t> zero_code(code_size, 0);
94
+ for (size_t pos = last_pos; pos < bbs; pos++) {
95
+ packer->pack_1(
96
+ zero_code.data(), last_block * bbs + pos, codes.data());
97
+ }
98
+ }
99
+
100
+ codes.resize(new_size);
101
+ }
102
+ return nremove;
70
103
  }
71
104
 
72
105
  IndexRaBitQFastScan::IndexRaBitQFastScan(const IndexRaBitQ& orig, int bbs)
@@ -104,58 +137,59 @@ IndexRaBitQFastScan::IndexRaBitQFastScan(const IndexRaBitQ& orig, int bbs)
104
137
 
105
138
  // If the original index has data, extract factors and pack codes
106
139
  if (ntotal > 0) {
107
- // Compute per-vector storage size for flat storage
108
140
  const size_t storage_size = compute_per_vector_storage_size();
109
-
110
- // Allocate flat storage
111
- flat_storage.resize(ntotal * storage_size);
112
-
113
- // Copy factors directly from original codes
114
141
  const size_t bit_pattern_size = (d + 7) / 8;
115
- for (idx_t i = 0; i < ntotal; i++) {
116
- const uint8_t* orig_code = orig.codes.data() + i * orig.code_size;
117
- const uint8_t* source_factors_ptr = orig_code + bit_pattern_size;
118
- uint8_t* storage = flat_storage.data() + i * storage_size;
119
- memcpy(storage, source_factors_ptr, storage_size);
120
- }
121
142
 
122
143
  // Convert RaBitQ bit format to FastScan 4-bit sub-quantizer format
123
- // This follows the same pattern as IndexPQFastScan constructor
124
144
  AlignedTable<uint8_t> fastscan_codes(ntotal * code_size);
125
145
  memset(fastscan_codes.get(), 0, ntotal * code_size);
126
146
 
127
- // Convert from RaBitQ 1-bit-per-dimension to FastScan
128
- // 4-bit-per-sub-quantizer
129
147
  for (idx_t i = 0; i < ntotal; i++) {
130
148
  const uint8_t* orig_code = orig.codes.data() + i * orig.code_size;
131
149
  uint8_t* fs_code = fastscan_codes.get() + i * code_size;
132
150
 
133
- // Convert each dimension's bit (same logic as compute_codes)
134
151
  for (size_t j = 0; j < orig.d; j++) {
135
- // Extract bit from original RaBitQ format
136
152
  const size_t orig_byte_idx = j / 8;
137
153
  const size_t orig_bit_offset = j % 8;
138
154
  const bool bit_value =
139
155
  (orig_code[orig_byte_idx] >> orig_bit_offset) & 1;
140
156
 
141
- // Use RaBitQUtils for consistent bit setting
142
157
  if (bit_value) {
143
158
  rabitq_utils::set_bit_fastscan(fs_code, j);
144
159
  }
145
160
  }
146
161
  }
147
162
 
148
- // Pack the converted codes using pq4_pack_codes with custom stride
149
- codes.resize(ntotal2 * M2 / 2);
150
- pq4_pack_codes(
163
+ // Pack the converted codes using enlarged block layout
164
+ const size_t block_stride = get_block_stride();
165
+ const size_t n_blocks = ntotal2 / bbs;
166
+ codes.resize(n_blocks * block_stride);
167
+ memset(codes.get(), 0, n_blocks * block_stride);
168
+ pq4_pack_codes_range(
151
169
  fastscan_codes.get(),
152
- ntotal,
153
170
  M,
154
- ntotal2,
171
+ 0,
172
+ ntotal,
155
173
  bbs,
156
174
  M2,
157
175
  codes.get(),
158
- code_size);
176
+ code_size,
177
+ block_stride);
178
+
179
+ // Copy auxiliary data from original codes into block aux region
180
+ const size_t packed_block_size = ((M2 + 1) / 2) * bbs;
181
+ for (idx_t i = 0; i < ntotal; i++) {
182
+ const uint8_t* src =
183
+ orig.codes.data() + i * orig.code_size + bit_pattern_size;
184
+ uint8_t* dst = rabitq_utils::get_block_aux_ptr(
185
+ codes.get(),
186
+ i,
187
+ bbs,
188
+ packed_block_size,
189
+ block_stride,
190
+ storage_size);
191
+ memcpy(dst, src, storage_size);
192
+ }
159
193
  }
160
194
  }
161
195
 
@@ -204,23 +238,13 @@ void IndexRaBitQFastScan::add(idx_t n, const float* x) {
204
238
  compute_codes(tmp_codes.get(), n, x);
205
239
 
206
240
  const size_t storage_size = compute_per_vector_storage_size();
207
- flat_storage.resize((ntotal + n) * storage_size);
208
-
209
- // Populate flat storage (no sign bits copying needed!)
210
241
  const size_t bit_pattern_size = (d + 7) / 8;
211
- for (idx_t i = 0; i < n; i++) {
212
- const uint8_t* code = tmp_codes.get() + i * code_size;
213
- const idx_t vec_idx = ntotal + i;
214
-
215
- // Copy factors data directly to flat storage (no reordering needed)
216
- const uint8_t* source_factors_ptr = code + bit_pattern_size;
217
- uint8_t* storage = flat_storage.data() + vec_idx * storage_size;
218
- memcpy(storage, source_factors_ptr, storage_size);
219
- }
220
242
 
221
- // Resize main storage (same logic as parent)
243
+ // Resize main storage with enlarged block layout
222
244
  ntotal2 = roundup(ntotal + n, bbs);
223
- size_t new_size = ntotal2 * M2 / 2; // assume nbits = 4
245
+ const size_t block_stride = get_block_stride();
246
+ const size_t n_blocks = ntotal2 / bbs;
247
+ size_t new_size = n_blocks * block_stride;
224
248
  size_t old_size = codes.size();
225
249
  if (new_size > old_size) {
226
250
  codes.resize(new_size);
@@ -230,13 +254,27 @@ void IndexRaBitQFastScan::add(idx_t n, const float* x) {
230
254
  // Use our custom packing function with correct stride
231
255
  pq4_pack_codes_range(
232
256
  tmp_codes.get(),
233
- M, // Number of sub-quantizers (bit patterns only)
257
+ M,
234
258
  ntotal,
235
- ntotal + n, // Range to pack
259
+ ntotal + n,
236
260
  bbs,
237
- M2, // Block parameters
238
- codes.get(), // Output
239
- code_size); // CUSTOM STRIDE: includes factor space
261
+ M2,
262
+ codes.get(),
263
+ code_size,
264
+ block_stride);
265
+
266
+ const size_t packed_block_size = ((M2 + 1) / 2) * bbs;
267
+ for (idx_t i = 0; i < n; i++) {
268
+ const uint8_t* src = tmp_codes.get() + i * code_size + bit_pattern_size;
269
+ uint8_t* dst = rabitq_utils::get_block_aux_ptr(
270
+ codes.get(),
271
+ ntotal + i,
272
+ bbs,
273
+ packed_block_size,
274
+ block_stride,
275
+ storage_size);
276
+ memcpy(dst, src, storage_size);
277
+ }
240
278
 
241
279
  ntotal += n;
242
280
  }
@@ -502,7 +540,11 @@ RaBitQHeapHandler<C, with_id_map>::RaBitQHeapHandler(
502
540
  nq(nq_val),
503
541
  k(k_val),
504
542
  context(ctx),
505
- is_multi_bit(multi_bit) {
543
+ is_multi_bit(multi_bit),
544
+ storage_size(index->compute_per_vector_storage_size()),
545
+ packed_block_size(((index->M2 + 1) / 2) * index->bbs),
546
+ full_block_size(index->get_block_stride()),
547
+ packer(index->get_CodePacker()) {
506
548
  // Initialize heaps for all queries in constructor
507
549
  // This allows us to support direct normalizer assignment
508
550
  #pragma omp parallel for if (nq > 100)
@@ -543,8 +585,11 @@ void RaBitQHeapHandler<C, with_id_map>::handle(
543
585
  ? std::min<size_t>(32, rabitq_index->ntotal - base_db_idx)
544
586
  : 0;
545
587
 
546
- // Get storage size once
547
- const size_t storage_size = rabitq_index->compute_per_vector_storage_size();
588
+ // Compute block auxiliary region base pointer once per batch.
589
+ // Since bbs=32, each batch of 32 vectors aligns to one block.
590
+ const size_t block_idx = base_db_idx / rabitq_index->bbs;
591
+ const uint8_t* aux_base = rabitq_index->codes.get() +
592
+ block_idx * full_block_size + packed_block_size;
548
593
 
549
594
  // Stats tracking for multi-bit two-stage search only
550
595
  // n_1bit_evaluations: candidates evaluated using 1-bit lower bound
@@ -559,9 +604,8 @@ void RaBitQHeapHandler<C, with_id_map>::handle(
559
604
  // Normalize distance from LUT lookup
560
605
  const float normalized_distance = d32tab[i] * one_a + bias;
561
606
 
562
- // Access factors from flat storage
563
- const uint8_t* base_ptr =
564
- rabitq_index->flat_storage.data() + db_idx * storage_size;
607
+ // Access factors from block auxiliary region
608
+ const uint8_t* base_ptr = aux_base + i * storage_size;
565
609
 
566
610
  if (is_multi_bit) {
567
611
  // Track candidates actually considered for two-stage filtering
@@ -578,14 +622,16 @@ void RaBitQHeapHandler<C, with_id_map>::handle(
578
622
  rabitq_index->qb,
579
623
  rabitq_index->d);
580
624
 
581
- float lower_bound = compute_lower_bound(dist_1bit, db_idx, q);
582
-
583
625
  // Adaptive filtering: decide whether to compute full distance
584
626
  const bool is_similarity = rabitq_index->metric_type ==
585
627
  MetricType::METRIC_INNER_PRODUCT;
586
- bool should_refine = is_similarity
587
- ? (lower_bound > heap_dis[0]) // IP: keep if better
588
- : (lower_bound < heap_dis[0]); // L2: keep if better
628
+ bool should_refine = rabitq_utils::should_refine_candidate(
629
+ dist_1bit,
630
+ full_factors.f_error,
631
+ context.query_factors ? context.query_factors[q].g_error
632
+ : 0.0f,
633
+ heap_dis[0],
634
+ is_similarity);
589
635
 
590
636
  if (should_refine) {
591
637
  local_multibit_evaluations++;
@@ -647,10 +693,14 @@ float RaBitQHeapHandler<C, with_id_map>::compute_lower_bound(
647
693
  float dist_1bit,
648
694
  size_t db_idx,
649
695
  size_t q) const {
650
- // Access f_error directly from SignBitFactorsWithError in flat storage
651
- const size_t storage_size = rabitq_index->compute_per_vector_storage_size();
652
- const uint8_t* base_ptr =
653
- rabitq_index->flat_storage.data() + db_idx * storage_size;
696
+ // Access f_error from block auxiliary region
697
+ const uint8_t* base_ptr = rabitq_utils::get_block_aux_ptr(
698
+ rabitq_index->codes.get(),
699
+ db_idx,
700
+ rabitq_index->bbs,
701
+ packed_block_size,
702
+ full_block_size,
703
+ storage_size);
654
704
  const SignBitFactorsWithError& db_factors =
655
705
  *reinterpret_cast<const SignBitFactorsWithError*>(base_ptr);
656
706
  float f_error = db_factors.f_error;
@@ -674,9 +724,13 @@ float RaBitQHeapHandler<C, with_id_map>::compute_full_multibit_distance(
674
724
  const size_t ex_bits = rabitq_index->rabitq.nb_bits - 1;
675
725
  const size_t dim = rabitq_index->d;
676
726
 
677
- const size_t storage_size = rabitq_index->compute_per_vector_storage_size();
678
- const uint8_t* base_ptr =
679
- rabitq_index->flat_storage.data() + db_idx * storage_size;
727
+ const uint8_t* base_ptr = rabitq_utils::get_block_aux_ptr(
728
+ rabitq_index->codes.get(),
729
+ db_idx,
730
+ rabitq_index->bbs,
731
+ packed_block_size,
732
+ full_block_size,
733
+ storage_size);
680
734
 
681
735
  const size_t ex_code_size = (dim * ex_bits + 7) / 8;
682
736
  const uint8_t* ex_code = base_ptr + sizeof(SignBitFactorsWithError);
@@ -689,8 +743,7 @@ float RaBitQHeapHandler<C, with_id_map>::compute_full_multibit_distance(
689
743
 
690
744
  // Get sign bits from FastScan packed format
691
745
  std::vector<uint8_t> unpacked_code(rabitq_index->code_size);
692
- CodePackerPQ4 packer(rabitq_index->M2, rabitq_index->bbs);
693
- packer.unpack_1(rabitq_index->codes.get(), db_idx, unpacked_code.data());
746
+ packer->unpack_1(rabitq_index->codes.get(), db_idx, unpacked_code.data());
694
747
  const uint8_t* sign_bits = unpacked_code.data();
695
748
 
696
749
  return rabitq_utils::compute_full_multibit_distance(
@@ -698,8 +751,9 @@ float RaBitQHeapHandler<C, with_id_map>::compute_full_multibit_distance(
698
751
  ex_code,
699
752
  ex_fac,
700
753
  query_factors.rotated_q.data(),
701
- query_factors.qr_to_c_L2sqr,
702
- query_factors.qr_norm_L2sqr,
754
+ (rabitq_index->metric_type == MetricType::METRIC_INNER_PRODUCT)
755
+ ? query_factors.q_dot_c
756
+ : query_factors.qr_to_c_L2sqr,
703
757
  dim,
704
758
  ex_bits,
705
759
  rabitq_index->metric_type);
@@ -7,6 +7,7 @@
7
7
 
8
8
  #pragma once
9
9
 
10
+ #include <memory>
10
11
  #include <vector>
11
12
 
12
13
  #include <faiss/IndexFastScan.h>
@@ -43,17 +44,6 @@ struct IndexRaBitQFastScan : IndexFastScan {
43
44
  /// Center of all points (same as IndexRaBitQ)
44
45
  std::vector<float> center;
45
46
 
46
- /// Per-vector auxiliary data (1-bit codes stored separately in `codes`)
47
- ///
48
- /// 1-bit codes (sign bits) are stored in the inherited `codes` array from
49
- /// IndexFastScan in packed FastScan format for SIMD processing.
50
- ///
51
- /// This flat_storage holds per-vector factors and refinement-bit codes:
52
- /// Layout for 1-bit: [SignBitFactors (8 bytes)]
53
- /// Layout for multi-bit: [SignBitFactorsWithError
54
- /// (12B)][ref_codes][ExtraBitsFactors (8B)]
55
- std::vector<uint8_t> flat_storage;
56
-
57
47
  /// Default number of bits to quantize a query with
58
48
  uint8_t qb = 8;
59
49
 
@@ -77,7 +67,7 @@ struct IndexRaBitQFastScan : IndexFastScan {
77
67
 
78
68
  void compute_codes(uint8_t* codes, idx_t n, const float* x) const override;
79
69
 
80
- /// Compute storage size per vector in flat_storage
70
+ /// Compute per-vector auxiliary data size in block aux region
81
71
  size_t compute_per_vector_storage_size() const;
82
72
 
83
73
  void compute_float_LUT(
@@ -88,6 +78,12 @@ struct IndexRaBitQFastScan : IndexFastScan {
88
78
 
89
79
  void sa_decode(idx_t n, const uint8_t* bytes, float* x) const override;
90
80
 
81
+ /// Return CodePackerRaBitQ with enlarged block size
82
+ CodePacker* get_CodePacker() const override;
83
+
84
+ /// Remove vectors and compact both PQ4 codes and auxiliary data
85
+ size_t remove_ids(const IDSelector& sel) override;
86
+
91
87
  void search(
92
88
  idx_t n,
93
89
  const float* x,
@@ -141,6 +137,12 @@ struct RaBitQHeapHandler
141
137
  context; // Processing context with query offset
142
138
  const bool is_multi_bit; // Runtime flag for multi-bit mode
143
139
 
140
+ // Cached block-layout constants (invariant for handler lifetime)
141
+ const size_t storage_size;
142
+ const size_t packed_block_size;
143
+ const size_t full_block_size;
144
+ std::unique_ptr<CodePacker> packer; // cached for unpack in hot path
145
+
144
146
  // Use float-based comparator for heap operations
145
147
  using Cfloat = typename std::conditional<
146
148
  C::is_max,