faiss 0.4.3 → 0.5.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (186)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +10 -0
  3. data/README.md +2 -0
  4. data/ext/faiss/index.cpp +33 -6
  5. data/ext/faiss/index_binary.cpp +17 -4
  6. data/ext/faiss/kmeans.cpp +6 -6
  7. data/lib/faiss/version.rb +1 -1
  8. data/vendor/faiss/faiss/AutoTune.cpp +2 -3
  9. data/vendor/faiss/faiss/AutoTune.h +1 -1
  10. data/vendor/faiss/faiss/Clustering.cpp +2 -2
  11. data/vendor/faiss/faiss/Clustering.h +2 -2
  12. data/vendor/faiss/faiss/IVFlib.cpp +26 -51
  13. data/vendor/faiss/faiss/IVFlib.h +1 -1
  14. data/vendor/faiss/faiss/Index.cpp +11 -0
  15. data/vendor/faiss/faiss/Index.h +34 -11
  16. data/vendor/faiss/faiss/Index2Layer.cpp +1 -1
  17. data/vendor/faiss/faiss/Index2Layer.h +2 -2
  18. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +1 -0
  19. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +9 -4
  20. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +5 -1
  21. data/vendor/faiss/faiss/IndexBinary.h +7 -7
  22. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +1 -1
  23. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +8 -2
  24. data/vendor/faiss/faiss/IndexBinaryHNSW.h +1 -1
  25. data/vendor/faiss/faiss/IndexBinaryHash.cpp +3 -3
  26. data/vendor/faiss/faiss/IndexBinaryHash.h +5 -5
  27. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +7 -6
  28. data/vendor/faiss/faiss/IndexFastScan.cpp +125 -49
  29. data/vendor/faiss/faiss/IndexFastScan.h +102 -7
  30. data/vendor/faiss/faiss/IndexFlat.cpp +374 -4
  31. data/vendor/faiss/faiss/IndexFlat.h +81 -1
  32. data/vendor/faiss/faiss/IndexHNSW.cpp +93 -2
  33. data/vendor/faiss/faiss/IndexHNSW.h +58 -2
  34. data/vendor/faiss/faiss/IndexIDMap.cpp +14 -13
  35. data/vendor/faiss/faiss/IndexIDMap.h +6 -6
  36. data/vendor/faiss/faiss/IndexIVF.cpp +1 -1
  37. data/vendor/faiss/faiss/IndexIVF.h +5 -5
  38. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +1 -1
  39. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +9 -3
  40. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +3 -1
  41. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +176 -90
  42. data/vendor/faiss/faiss/IndexIVFFastScan.h +173 -18
  43. data/vendor/faiss/faiss/IndexIVFFlat.cpp +1 -0
  44. data/vendor/faiss/faiss/IndexIVFFlatPanorama.cpp +251 -0
  45. data/vendor/faiss/faiss/IndexIVFFlatPanorama.h +64 -0
  46. data/vendor/faiss/faiss/IndexIVFPQ.cpp +3 -1
  47. data/vendor/faiss/faiss/IndexIVFPQ.h +1 -1
  48. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +134 -2
  49. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +7 -1
  50. data/vendor/faiss/faiss/IndexIVFRaBitQ.cpp +99 -8
  51. data/vendor/faiss/faiss/IndexIVFRaBitQ.h +4 -1
  52. data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.cpp +828 -0
  53. data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.h +252 -0
  54. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +1 -1
  55. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +1 -1
  56. data/vendor/faiss/faiss/IndexNNDescent.cpp +1 -1
  57. data/vendor/faiss/faiss/IndexNSG.cpp +1 -1
  58. data/vendor/faiss/faiss/IndexNeuralNetCodec.h +1 -1
  59. data/vendor/faiss/faiss/IndexPQ.cpp +4 -1
  60. data/vendor/faiss/faiss/IndexPQ.h +1 -1
  61. data/vendor/faiss/faiss/IndexPQFastScan.cpp +6 -2
  62. data/vendor/faiss/faiss/IndexPQFastScan.h +5 -1
  63. data/vendor/faiss/faiss/IndexPreTransform.cpp +14 -0
  64. data/vendor/faiss/faiss/IndexPreTransform.h +9 -0
  65. data/vendor/faiss/faiss/IndexRaBitQ.cpp +96 -13
  66. data/vendor/faiss/faiss/IndexRaBitQ.h +11 -2
  67. data/vendor/faiss/faiss/IndexRaBitQFastScan.cpp +731 -0
  68. data/vendor/faiss/faiss/IndexRaBitQFastScan.h +175 -0
  69. data/vendor/faiss/faiss/IndexRefine.cpp +49 -0
  70. data/vendor/faiss/faiss/IndexRefine.h +17 -0
  71. data/vendor/faiss/faiss/IndexShards.cpp +1 -1
  72. data/vendor/faiss/faiss/MatrixStats.cpp +3 -3
  73. data/vendor/faiss/faiss/MetricType.h +1 -1
  74. data/vendor/faiss/faiss/VectorTransform.h +2 -2
  75. data/vendor/faiss/faiss/clone_index.cpp +5 -1
  76. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +1 -1
  77. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +3 -1
  78. data/vendor/faiss/faiss/gpu/GpuIndex.h +11 -11
  79. data/vendor/faiss/faiss/gpu/GpuIndexBinaryCagra.h +1 -1
  80. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +1 -1
  81. data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +11 -7
  82. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +1 -1
  83. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +2 -0
  84. data/vendor/faiss/faiss/gpu/test/TestGpuIcmEncoder.cpp +7 -0
  85. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +1 -1
  86. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +1 -1
  87. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +1 -1
  88. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +2 -2
  89. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +1 -1
  90. data/vendor/faiss/faiss/impl/CodePacker.h +2 -2
  91. data/vendor/faiss/faiss/impl/DistanceComputer.h +77 -6
  92. data/vendor/faiss/faiss/impl/FastScanDistancePostProcessing.h +53 -0
  93. data/vendor/faiss/faiss/impl/HNSW.cpp +295 -16
  94. data/vendor/faiss/faiss/impl/HNSW.h +35 -6
  95. data/vendor/faiss/faiss/impl/IDSelector.cpp +2 -2
  96. data/vendor/faiss/faiss/impl/IDSelector.h +4 -4
  97. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +4 -4
  98. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +1 -1
  99. data/vendor/faiss/faiss/impl/LookupTableScaler.h +1 -1
  100. data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -1
  101. data/vendor/faiss/faiss/impl/NNDescent.h +2 -2
  102. data/vendor/faiss/faiss/impl/NSG.cpp +1 -1
  103. data/vendor/faiss/faiss/impl/Panorama.cpp +193 -0
  104. data/vendor/faiss/faiss/impl/Panorama.h +204 -0
  105. data/vendor/faiss/faiss/impl/PanoramaStats.cpp +33 -0
  106. data/vendor/faiss/faiss/impl/PanoramaStats.h +38 -0
  107. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +5 -5
  108. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +1 -1
  109. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +1 -1
  110. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +2 -0
  111. data/vendor/faiss/faiss/impl/ProductQuantizer.h +1 -1
  112. data/vendor/faiss/faiss/impl/RaBitQStats.cpp +29 -0
  113. data/vendor/faiss/faiss/impl/RaBitQStats.h +56 -0
  114. data/vendor/faiss/faiss/impl/RaBitQUtils.cpp +294 -0
  115. data/vendor/faiss/faiss/impl/RaBitQUtils.h +330 -0
  116. data/vendor/faiss/faiss/impl/RaBitQuantizer.cpp +304 -223
  117. data/vendor/faiss/faiss/impl/RaBitQuantizer.h +72 -4
  118. data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.cpp +362 -0
  119. data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.h +112 -0
  120. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +1 -1
  121. data/vendor/faiss/faiss/impl/ResultHandler.h +4 -4
  122. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +7 -10
  123. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +2 -4
  124. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +7 -4
  125. data/vendor/faiss/faiss/impl/index_read.cpp +238 -10
  126. data/vendor/faiss/faiss/impl/index_write.cpp +212 -19
  127. data/vendor/faiss/faiss/impl/io.cpp +2 -2
  128. data/vendor/faiss/faiss/impl/io.h +4 -4
  129. data/vendor/faiss/faiss/impl/kmeans1d.cpp +1 -1
  130. data/vendor/faiss/faiss/impl/kmeans1d.h +1 -1
  131. data/vendor/faiss/faiss/impl/lattice_Zn.h +2 -2
  132. data/vendor/faiss/faiss/impl/mapped_io.cpp +2 -2
  133. data/vendor/faiss/faiss/impl/mapped_io.h +4 -3
  134. data/vendor/faiss/faiss/impl/maybe_owned_vector.h +8 -1
  135. data/vendor/faiss/faiss/impl/platform_macros.h +12 -0
  136. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +30 -4
  137. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +14 -8
  138. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +5 -6
  139. data/vendor/faiss/faiss/impl/simd_result_handlers.h +55 -11
  140. data/vendor/faiss/faiss/impl/svs_io.cpp +86 -0
  141. data/vendor/faiss/faiss/impl/svs_io.h +67 -0
  142. data/vendor/faiss/faiss/impl/zerocopy_io.h +1 -1
  143. data/vendor/faiss/faiss/index_factory.cpp +217 -8
  144. data/vendor/faiss/faiss/index_factory.h +1 -1
  145. data/vendor/faiss/faiss/index_io.h +1 -1
  146. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +1 -1
  147. data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -1
  148. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +115 -1
  149. data/vendor/faiss/faiss/invlists/InvertedLists.h +46 -0
  150. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +1 -1
  151. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +1 -1
  152. data/vendor/faiss/faiss/svs/IndexSVSFaissUtils.h +261 -0
  153. data/vendor/faiss/faiss/svs/IndexSVSFlat.cpp +117 -0
  154. data/vendor/faiss/faiss/svs/IndexSVSFlat.h +66 -0
  155. data/vendor/faiss/faiss/svs/IndexSVSVamana.cpp +245 -0
  156. data/vendor/faiss/faiss/svs/IndexSVSVamana.h +137 -0
  157. data/vendor/faiss/faiss/svs/IndexSVSVamanaLVQ.cpp +39 -0
  158. data/vendor/faiss/faiss/svs/IndexSVSVamanaLVQ.h +42 -0
  159. data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.cpp +149 -0
  160. data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.h +58 -0
  161. data/vendor/faiss/faiss/utils/AlignedTable.h +1 -1
  162. data/vendor/faiss/faiss/utils/Heap.cpp +2 -2
  163. data/vendor/faiss/faiss/utils/Heap.h +3 -3
  164. data/vendor/faiss/faiss/utils/NeuralNet.cpp +1 -1
  165. data/vendor/faiss/faiss/utils/NeuralNet.h +3 -3
  166. data/vendor/faiss/faiss/utils/approx_topk/approx_topk.h +2 -2
  167. data/vendor/faiss/faiss/utils/approx_topk/avx2-inl.h +2 -2
  168. data/vendor/faiss/faiss/utils/approx_topk/mode.h +1 -1
  169. data/vendor/faiss/faiss/utils/distances.cpp +0 -3
  170. data/vendor/faiss/faiss/utils/distances.h +2 -2
  171. data/vendor/faiss/faiss/utils/extra_distances-inl.h +3 -1
  172. data/vendor/faiss/faiss/utils/hamming-inl.h +2 -0
  173. data/vendor/faiss/faiss/utils/hamming.cpp +7 -6
  174. data/vendor/faiss/faiss/utils/hamming.h +1 -1
  175. data/vendor/faiss/faiss/utils/hamming_distance/common.h +1 -2
  176. data/vendor/faiss/faiss/utils/partitioning.cpp +5 -5
  177. data/vendor/faiss/faiss/utils/partitioning.h +2 -2
  178. data/vendor/faiss/faiss/utils/rabitq_simd.h +222 -336
  179. data/vendor/faiss/faiss/utils/random.cpp +1 -1
  180. data/vendor/faiss/faiss/utils/simdlib_avx2.h +1 -1
  181. data/vendor/faiss/faiss/utils/simdlib_avx512.h +1 -1
  182. data/vendor/faiss/faiss/utils/simdlib_neon.h +2 -2
  183. data/vendor/faiss/faiss/utils/transpose/transpose-avx512-inl.h +1 -1
  184. data/vendor/faiss/faiss/utils/utils.cpp +9 -2
  185. data/vendor/faiss/faiss/utils/utils.h +2 -2
  186. metadata +29 -1
@@ -8,39 +8,61 @@
8
8
  #include <faiss/impl/RaBitQuantizer.h>
9
9
 
10
10
  #include <faiss/impl/FaissAssert.h>
11
+ #include <faiss/impl/RaBitQUtils.h>
12
+ #include <faiss/impl/RaBitQuantizerMultiBit.h>
11
13
  #include <faiss/utils/distances.h>
12
14
  #include <faiss/utils/rabitq_simd.h>
13
15
  #include <algorithm>
14
16
  #include <cmath>
15
17
  #include <cstring>
16
- #include <limits>
17
18
  #include <memory>
18
19
  #include <vector>
19
20
 
20
21
  namespace faiss {
21
22
 
22
- struct FactorsData {
23
- // ||or - c||^2 - ((metric==IP) ? ||or||^2 : 0)
24
- float or_minus_c_l2sqr = 0;
25
- float dp_multiplier = 0;
26
- };
27
-
28
- struct QueryFactorsData {
29
- float c1 = 0;
30
- float c2 = 0;
31
- float c34 = 0;
23
+ // Import shared utilities from RaBitQUtils
24
+ using rabitq_utils::ExtraBitsFactors;
25
+ using rabitq_utils::QueryFactorsData;
26
+ using rabitq_utils::SignBitFactors;
27
+ using rabitq_utils::SignBitFactorsWithError;
28
+
29
+ RaBitQuantizer::RaBitQuantizer(size_t d, MetricType metric, size_t nb_bits)
30
+ : Quantizer(d, 0), // code_size will be set below
31
+ metric_type{metric},
32
+ nb_bits{nb_bits} {
33
+ // Validate nb_bits range
34
+ FAISS_THROW_IF_NOT(nb_bits >= 1 && nb_bits <= 9);
35
+
36
+ // Set code_size using compute_code_size
37
+ code_size = compute_code_size(d, nb_bits);
38
+ }
32
39
 
33
- float qr_to_c_L2sqr = 0;
34
- float qr_norm_L2sqr = 0;
35
- };
40
+ size_t RaBitQuantizer::compute_code_size(size_t d, size_t num_bits) const {
41
+ // Validate inputs
42
+ FAISS_THROW_IF_NOT(num_bits >= 1 && num_bits <= 9);
43
+
44
+ size_t ex_bits = num_bits - 1;
45
+
46
+ // Base: 1-bit codes + base factors
47
+ // Layout for 1-bit: [binary_code: (d+7)/8 bytes][SignBitFactors: 8 bytes]
48
+ // base_factors = or_minus_c_l2sqr (4) + dp_multiplier (4)
49
+ // Layout for multi-bit: [binary_code: (d+7)/8
50
+ // bytes][SignBitFactorsWithError: 12 bytes]
51
+ // factors = or_minus_c_l2sqr (4) + dp_multiplier (4) + f_error (4)
52
+ size_t base_size = (d + 7) / 8 +
53
+ (ex_bits == 0 ? sizeof(SignBitFactors)
54
+ : sizeof(SignBitFactorsWithError));
55
+
56
+ // Extra: ex-bit codes + ex factors (only if ex_bits > 0)
57
+ // Layout: [ex_code: (d*ex_bits+7)/8 bytes][ex_factors: 8 bytes]
58
+ size_t ex_size = 0;
59
+ if (ex_bits > 0) {
60
+ ex_size = (d * ex_bits + 7) / 8 + sizeof(ExtraBitsFactors);
61
+ }
36
62
 
37
- static size_t get_code_size(const size_t d) {
38
- return (d + 7) / 8 + sizeof(FactorsData);
63
+ return base_size + ex_size;
39
64
  }
40
65
 
41
- RaBitQuantizer::RaBitQuantizer(size_t d, MetricType metric)
42
- : Quantizer(d, get_code_size(d)), metric_type{metric} {}
43
-
44
66
  void RaBitQuantizer::train(size_t n, const float* x) {
45
67
  // does nothing
46
68
  }
@@ -65,68 +87,85 @@ void RaBitQuantizer::compute_codes_core(
65
87
  return;
66
88
  }
67
89
 
68
- // compute some helper constants
69
- const float inv_d_sqrt = (d == 0) ? 1.0f : (1.0f / std::sqrt((float)d));
90
+ const size_t ex_bits = nb_bits - 1;
70
91
 
71
- // compute codes
92
+ // Compute codes
72
93
  #pragma omp parallel for if (n > 1000)
73
94
  for (int64_t i = 0; i < n; i++) {
74
- // ||or - c||^2
75
- float norm_L2sqr = 0;
76
- // ||or||^2, which is equal to ||P(or)||^2 and ||P^(-1)(or)||^2
77
- float or_L2sqr = 0;
78
- // dot product
79
- float dp_oO = 0;
80
-
81
- // the code
95
+ // Pointer to this vector's code
82
96
  uint8_t* code = codes + i * code_size;
83
- FactorsData* fac = reinterpret_cast<FactorsData*>(code + (d + 7) / 8);
84
97
 
85
- // cleanup it
86
- if (code != nullptr) {
87
- memset(code, 0, code_size);
98
+ // Clear code memory
99
+ memset(code, 0, code_size);
100
+
101
+ const float* x_row = x + i * d;
102
+
103
+ // Pointer arithmetic for code layout:
104
+ // For 1-bit: [binary_code: (d+7)/8 bytes][SignBitFactors: 8 bytes]
105
+ // For multi-bit: [binary_code: (d+7)/8 bytes][SignBitFactorsWithError:
106
+ // 12 bytes]
107
+ // [ex_code: (d*ex_bits+7)/8 bytes][ex_factors: 8 bytes]
108
+ uint8_t* binary_code = code;
109
+
110
+ // Step 1: Compute 1-bit quantization and base factors
111
+ // Store residual for potential ex-bits quantization
112
+ std::vector<float> residual(d);
113
+
114
+ // Use shared utilities for computing factors
115
+ SignBitFactorsWithError factors_data =
116
+ rabitq_utils::compute_vector_factors(
117
+ x_row, d, centroid_in, metric_type, ex_bits > 0);
118
+
119
+ // Write appropriate factors based on nb_bits
120
+ if (ex_bits == 0) {
121
+ // For 1-bit: write only SignBitFactors (8 bytes)
122
+ SignBitFactors* base_factors =
123
+ reinterpret_cast<SignBitFactors*>(code + (d + 7) / 8);
124
+ base_factors->or_minus_c_l2sqr = factors_data.or_minus_c_l2sqr;
125
+ base_factors->dp_multiplier = factors_data.dp_multiplier;
126
+ } else {
127
+ // For multi-bit: write full SignBitFactorsWithError (12 bytes)
128
+ SignBitFactorsWithError* full_factors =
129
+ reinterpret_cast<SignBitFactorsWithError*>(
130
+ code + (d + 7) / 8);
131
+ *full_factors = factors_data;
88
132
  }
89
133
 
134
+ // Pack bits into standard RaBitQ format
90
135
  for (size_t j = 0; j < d; j++) {
91
- const float or_minus_c = x[i * d + j] -
92
- ((centroid_in == nullptr) ? 0 : centroid_in[j]);
93
- norm_L2sqr += or_minus_c * or_minus_c;
94
- or_L2sqr += x[i * d + j] * x[i * d + j];
95
-
96
- const bool xb = (or_minus_c > 0);
136
+ const float x_val = x_row[j];
137
+ const float centroid_val =
138
+ (centroid_in == nullptr) ? 0.0f : centroid_in[j];
139
+ const float or_minus_c = x_val - centroid_val;
140
+ residual[j] = or_minus_c;
97
141
 
98
- dp_oO += xb ? or_minus_c : (-or_minus_c);
142
+ const bool xb = (or_minus_c > 0.0f);
99
143
 
100
- // store the output data
101
- if (code != nullptr) {
102
- if (xb) {
103
- // enable a particular bit
104
- code[j / 8] |= (1 << (j % 8));
105
- }
144
+ // Store the 1-bit sign code
145
+ if (xb) {
146
+ rabitq_utils::set_bit_standard(binary_code, j);
106
147
  }
107
148
  }
108
149
 
109
- // compute factors
110
-
111
- // compute the inverse norm
112
- const float inv_norm_L2 =
113
- (std::abs(norm_L2sqr) < std::numeric_limits<float>::epsilon())
114
- ? 1.0f
115
- : (1.0f / std::sqrt(norm_L2sqr));
116
- dp_oO *= inv_norm_L2;
117
- dp_oO *= inv_d_sqrt;
118
-
119
- const float inv_dp_oO =
120
- (std::abs(dp_oO) < std::numeric_limits<float>::epsilon())
121
- ? 1.0f
122
- : (1.0f / dp_oO);
123
-
124
- fac->or_minus_c_l2sqr = norm_L2sqr;
125
- if (metric_type == MetricType::METRIC_INNER_PRODUCT) {
126
- fac->or_minus_c_l2sqr -= or_L2sqr;
150
+ // Step 2: Compute ex-bits quantization (if nb_bits > 1)
151
+ if (ex_bits > 0) {
152
+ // Pointer to ex-bit code section
153
+ uint8_t* ex_code =
154
+ code + (d + 7) / 8 + sizeof(SignBitFactorsWithError);
155
+ // Pointer to ex-factors section
156
+ ExtraBitsFactors* ex_factors = reinterpret_cast<ExtraBitsFactors*>(
157
+ ex_code + (d * ex_bits + 7) / 8);
158
+
159
+ // Quantize residual to ex-bits (pass centroid for IP metric)
160
+ rabitq_multibit::quantize_ex_bits(
161
+ residual.data(),
162
+ d,
163
+ nb_bits,
164
+ ex_code,
165
+ *ex_factors,
166
+ metric_type,
167
+ centroid_in);
127
168
  }
128
-
129
- fac->dp_multiplier = inv_dp_oO * std::sqrt(norm_L2sqr);
130
169
  }
131
170
  }
132
171
 
@@ -143,6 +182,7 @@ void RaBitQuantizer::decode_core(
143
182
  FAISS_ASSERT(x != nullptr);
144
183
 
145
184
  const float inv_d_sqrt = (d == 0) ? 1.0f : (1.0f / std::sqrt((float)d));
185
+ const size_t ex_bits = nb_bits - 1;
146
186
 
147
187
  #pragma omp parallel for if (n > 1000)
148
188
  for (int64_t i = 0; i < n; i++) {
@@ -150,10 +190,19 @@ void RaBitQuantizer::decode_core(
150
190
 
151
191
  // split the code into parts
152
192
  const uint8_t* binary_data = code;
153
- const FactorsData* fac =
154
- reinterpret_cast<const FactorsData*>(code + (d + 7) / 8);
155
193
 
194
+ // Cast to appropriate type based on nb_bits
195
+ // For 1-bit: use SignBitFactors (8 bytes)
196
+ // For multi-bit: use SignBitFactorsWithError (12 bytes, but only first
197
+ // 8 bytes used for decode)
198
+ const SignBitFactors* fac = (ex_bits == 0)
199
+ ? reinterpret_cast<const SignBitFactors*>(code + (d + 7) / 8)
200
+ : reinterpret_cast<const SignBitFactorsWithError*>(
201
+ code + (d + 7) / 8);
202
+
203
+ // this is the baseline code
156
204
  //
205
+ // compute <q,o> using floats
157
206
  for (size_t j = 0; j < d; j++) {
158
207
  // extract i-th bit
159
208
  const uint8_t masker = (1 << (j % 8));
@@ -166,51 +215,69 @@ void RaBitQuantizer::decode_core(
166
215
  }
167
216
  }
168
217
 
169
- struct RaBitDistanceComputer : FlatCodesDistanceComputer {
170
- // dimensionality
171
- size_t d = 0;
172
- // a centroid to use
173
- const float* centroid = nullptr;
218
+ // Implementation of RaBitQDistanceComputer (declared in header)
174
219
 
175
- // the metric
176
- MetricType metric_type = MetricType::METRIC_L2;
220
+ float RaBitQDistanceComputer::lower_bound_distance(const uint8_t* code) {
221
+ FAISS_ASSERT(code != nullptr);
177
222
 
178
- RaBitDistanceComputer();
223
+ // Compute estimated distance using 1-bit codes
224
+ float est_distance = distance_to_code_1bit(code);
179
225
 
180
- float symmetric_dis(idx_t i, idx_t j) override;
181
- };
226
+ // Extract f_error from the code
227
+ size_t size = (d + 7) / 8;
228
+ const SignBitFactorsWithError* base_fac =
229
+ reinterpret_cast<const SignBitFactorsWithError*>(code + size);
230
+ float f_error = base_fac->f_error;
182
231
 
183
- RaBitDistanceComputer::RaBitDistanceComputer() = default;
232
+ // Compute proper lower bound using RaBitQ error formula:
233
+ // lower_bound = est_distance - f_error * g_error
234
+ // This guarantees: lower_bound ≤ true_distance
235
+ float lower_bound = est_distance - (f_error * g_error);
184
236
 
185
- float RaBitDistanceComputer::symmetric_dis(idx_t i, idx_t j) {
186
- FAISS_THROW_MSG("Not implemented");
237
+ // Distance cannot be negative
238
+ return std::max(0.0f, lower_bound);
187
239
  }
188
240
 
189
- struct RaBitDistanceComputerNotQ : RaBitDistanceComputer {
241
+ namespace {
242
+
243
+ struct RaBitQDistanceComputerNotQ : RaBitQDistanceComputer {
190
244
  // the rotated query (qr - c)
191
245
  std::vector<float> rotated_q;
192
246
  // some additional numbers for the query
193
247
  QueryFactorsData query_fac;
194
248
 
195
- RaBitDistanceComputerNotQ();
249
+ RaBitQDistanceComputerNotQ();
196
250
 
197
- float distance_to_code(const uint8_t* code) override;
251
+ // Compute distance using only 1-bit codes (fast)
252
+ float distance_to_code_1bit(const uint8_t* code) override;
253
+
254
+ // Compute full distance using 1-bit + ex-bits (accurate)
255
+ float distance_to_code_full(const uint8_t* code) override;
198
256
 
199
257
  void set_query(const float* x) override;
200
258
  };
201
259
 
202
- RaBitDistanceComputerNotQ::RaBitDistanceComputerNotQ() = default;
260
+ RaBitQDistanceComputerNotQ::RaBitQDistanceComputerNotQ() = default;
203
261
 
204
- float RaBitDistanceComputerNotQ::distance_to_code(const uint8_t* code) {
262
+ float RaBitQDistanceComputerNotQ::distance_to_code_1bit(const uint8_t* code) {
205
263
  FAISS_ASSERT(code != nullptr);
206
264
  FAISS_ASSERT(
207
265
  (metric_type == MetricType::METRIC_L2 ||
208
266
  metric_type == MetricType::METRIC_INNER_PRODUCT));
267
+ FAISS_ASSERT(rotated_q.size() == d);
209
268
 
210
269
  // split the code into parts
211
270
  const uint8_t* binary_data = code;
212
- const FactorsData* fac =
213
- reinterpret_cast<const FactorsData*>(code + (d + 7) / 8);
271
+
272
+ // Cast to appropriate type based on nb_bits
273
+ // For 1-bit: use SignBitFactors (8 bytes)
274
+ // For multi-bit: use SignBitFactorsWithError (12 bytes) which includes
275
+ // f_error
276
+ size_t ex_bits = nb_bits - 1;
277
+ const SignBitFactors* base_fac = (ex_bits == 0)
278
+ ? reinterpret_cast<const SignBitFactors*>(code + (d + 7) / 8)
279
+ : reinterpret_cast<const SignBitFactorsWithError*>(
280
+ code + (d + 7) / 8);
214
281
 
215
282
  // this is the baseline code
216
283
  //
@@ -219,48 +286,70 @@ float RaBitDistanceComputerNotQ::distance_to_code(const uint8_t* code) {
219
286
  // It was a willful decision (after the discussion) to not to pre-cache
220
287
  // the sum of all bits, just in order to reduce the overhead per vector.
221
288
  uint64_t sum_q = 0;
222
- for (size_t i = 0; i < d; i++) {
223
- // extract i-th bit
224
- const uint8_t masker = (1 << (i % 8));
225
- const bool b_bit = ((binary_data[i / 8] & masker) == masker);
226
289
 
290
+ for (size_t i = 0; i < d; i++) {
291
+ // Extract i-th bit
292
+ bool bit = rabitq_utils::extract_bit_standard(binary_data, i);
227
293
  // accumulate dp
228
- dot_qo += (b_bit) ? rotated_q[i] : 0;
294
+ dot_qo += bit ? rotated_q[i] : 0;
229
295
  // accumulate sum-of-bits
230
- sum_q += (b_bit) ? 1 : 0;
296
+ sum_q += bit ? 1 : 0;
231
297
  }
232
298
 
233
- float final_dot = 0;
234
- // dot-product itself
235
- final_dot += query_fac.c1 * dot_qo;
236
- // normalizer coefficients
237
- final_dot += query_fac.c2 * sum_q;
238
- // normalizer coefficients
239
- final_dot -= query_fac.c34;
240
-
241
- // this is ||or - c||^2 - (IP ? ||or||^2 : 0)
242
- const float or_c_l2sqr = fac->or_minus_c_l2sqr;
299
+ // Apply query factors
300
+ float final_dot =
301
+ query_fac.c1 * dot_qo + query_fac.c2 * sum_q - query_fac.c34;
243
302
 
244
303
  // pre_dist = ||or - c||^2 + ||qr - c||^2 -
245
304
  // 2 * ||or - c|| * ||qr - c|| * <q,o> - (IP ? ||or||^2 : 0)
246
- const float pre_dist = or_c_l2sqr + query_fac.qr_to_c_L2sqr -
247
- 2 * fac->dp_multiplier * final_dot;
305
+ float pre_dist = base_fac->or_minus_c_l2sqr + query_fac.qr_to_c_L2sqr -
306
+ 2 * base_fac->dp_multiplier * final_dot;
248
307
 
249
308
  if (metric_type == MetricType::METRIC_L2) {
250
309
  // ||or - q||^ 2
251
310
  return pre_dist;
252
311
  } else {
253
312
  // metric == MetricType::METRIC_INNER_PRODUCT
313
+ return -0.5f * (pre_dist - query_fac.qr_norm_L2sqr);
314
+ }
315
+ }
254
316
 
255
- // this is ||q||^2
256
- const float query_norm_sqr = query_fac.qr_norm_L2sqr;
317
+ float RaBitQDistanceComputerNotQ::distance_to_code_full(const uint8_t* code) {
318
+ FAISS_ASSERT(code != nullptr);
319
+ FAISS_ASSERT(
320
+ (metric_type == MetricType::METRIC_L2 ||
321
+ metric_type == MetricType::METRIC_INNER_PRODUCT));
322
+ FAISS_ASSERT(rotated_q.size() == d);
257
323
 
258
- // 2 * (or, q) = (||or - q||^2 - ||q||^2 - ||or||^2)
259
- return -0.5f * (pre_dist - query_norm_sqr);
324
+ size_t ex_bits = nb_bits - 1;
325
+
326
+ if (ex_bits == 0) {
327
+ // No ex-bits, just return 1-bit distance
328
+ return distance_to_code_1bit(code);
260
329
  }
330
+
331
+ // Extract pointers to code sections
332
+ const uint8_t* binary_data = code;
333
+ size_t offset = (d + 7) / 8 + sizeof(SignBitFactorsWithError);
334
+ const uint8_t* ex_code = code + offset;
335
+ const ExtraBitsFactors* ex_fac = reinterpret_cast<const ExtraBitsFactors*>(
336
+ ex_code + (d * ex_bits + 7) / 8);
337
+
338
+ // Call shared utility directly with rotated_q pointer
339
+ return rabitq_utils::compute_full_multibit_distance(
340
+ binary_data,
341
+ ex_code,
342
+ *ex_fac,
343
+ rotated_q.data(),
344
+ query_fac.qr_to_c_L2sqr,
345
+ query_fac.qr_norm_L2sqr,
346
+ d,
347
+ ex_bits,
348
+ metric_type);
261
349
  }
262
350
 
263
- void RaBitDistanceComputerNotQ::set_query(const float* x) {
351
+ void RaBitQDistanceComputerNotQ::set_query(const float* x) {
352
+ q = x;
264
353
  FAISS_ASSERT(x != nullptr);
265
354
  FAISS_ASSERT(
266
355
  (metric_type == MetricType::METRIC_L2 ||
@@ -279,6 +368,10 @@ void RaBitDistanceComputerNotQ::set_query(const float* x) {
279
368
  rotated_q[i] = x[i] - ((centroid == nullptr) ? 0 : centroid[i]);
280
369
  }
281
370
 
371
+ // Compute g_error (query norm for lower bound computation)
372
+ // g_error = ||qr - c|| (L2 norm of rotated query)
373
+ g_error = std::sqrt(query_fac.qr_to_c_L2sqr);
374
+
282
375
  // compute some numbers
283
376
  const float inv_d = (d == 0) ? 1.0f : (1.0f / std::sqrt((float)d));
284
377
 
@@ -299,8 +392,10 @@ void RaBitDistanceComputerNotQ::set_query(const float* x) {
299
392
  }
300
393
 
301
394
  //
302
- struct RaBitDistanceComputerQ : RaBitDistanceComputer {
395
+ struct RaBitQDistanceComputerQ : RaBitQDistanceComputer {
303
396
  // the rotated and quantized query (qr - c)
397
+ std::vector<float> rotated_q;
398
+ // the rotated and quantized query (qr - c) for fast 1-bit computation
304
399
  std::vector<uint8_t> rotated_qq;
305
400
  // we're using the proposed relayout-ed scheme from 3.3 that allows
306
401
  // using popcounts for computing the distance.
@@ -310,149 +405,138 @@ struct RaBitDistanceComputerQ : RaBitDistanceComputer {
310
405
 
311
406
  // the number of bits for SQ quantization of the query (qb > 0)
312
407
  uint8_t qb = 8;
408
+ bool centered = false;
313
409
  // the smallest value divisible by 8 that is not smaller than dim
314
410
  size_t popcount_aligned_dim = 0;
315
411
 
316
- RaBitDistanceComputerQ();
412
+ RaBitQDistanceComputerQ();
317
413
 
318
- float distance_to_code(const uint8_t* code) override;
414
+ // Compute distance using only 1-bit codes (fast)
415
+ float distance_to_code_1bit(const uint8_t* code) override;
416
+
417
+ // Compute full distance using 1-bit + ex-bits (accurate)
418
+ float distance_to_code_full(const uint8_t* code) override;
319
419
 
320
420
  void set_query(const float* x) override;
321
421
  };
322
422
 
323
- RaBitDistanceComputerQ::RaBitDistanceComputerQ() = default;
423
+ RaBitQDistanceComputerQ::RaBitQDistanceComputerQ() = default;
324
424
 
325
- float RaBitDistanceComputerQ::distance_to_code(const uint8_t* code) {
425
+ float RaBitQDistanceComputerQ::distance_to_code_1bit(const uint8_t* code) {
326
426
  FAISS_ASSERT(code != nullptr);
327
427
  FAISS_ASSERT(
328
428
  (metric_type == MetricType::METRIC_L2 ||
329
429
  metric_type == MetricType::METRIC_INNER_PRODUCT));
330
430
 
331
431
  // split the code into parts
432
+ size_t size = (d + 7) / 8;
332
433
  const uint8_t* binary_data = code;
333
- const FactorsData* fac =
334
- reinterpret_cast<const FactorsData*>(code + (d + 7) / 8);
335
-
336
- // // this is the baseline code
337
- // //
338
- // // compute <q,o> using integers
339
- // size_t dot_qo = 0;
340
- // for (size_t i = 0; i < d; i++) {
341
- // // extract i-th bit
342
- // const uint8_t masker = (1 << (i % 8));
343
- // const uint8_t bit = ((binary_data[i / 8] & masker) == masker) ? 1 :
344
- // 0;
345
- //
346
- // // accumulate dp
347
- // dot_qo += bit * rotated_qq[i];
348
- // }
349
434
 
350
- // this is the scheme for popcount
351
- const size_t di_8b = (d + 7) / 8;
352
- const size_t di_64b = (di_8b / 8) * 8;
435
+ // Cast to appropriate type based on nb_bits
436
+ // For 1-bit: use SignBitFactors (8 bytes)
437
+ // For multi-bit: use SignBitFactorsWithError (12 bytes) which includes
438
+ // f_error
439
+ size_t ex_bits = nb_bits - 1;
440
+ const SignBitFactors* base_fac = (ex_bits == 0)
441
+ ? reinterpret_cast<const SignBitFactors*>(code + size)
442
+ : reinterpret_cast<const SignBitFactorsWithError*>(code + size);
353
443
 
354
- // Use the optimized popcount function from rabitq_simd.h
355
- float dot_qo =
356
- rabitq_dp_popcnt(rearranged_rotated_qq.data(), binary_data, d, qb);
357
-
358
- // It was a willful decision (after the discussion) to not to pre-cache
359
- // the sum of all bits, just in order to reduce the overhead per vector.
360
- uint64_t sum_q = 0;
361
- {
444
+ // this is ||or - c||^2 - (IP ? ||or||^2 : 0)
445
+ float final_dot = 0;
446
+ if (centered) {
447
+ int64_t int_dot = ((1 << qb) - 1) * d;
448
+ // See RaBitDistanceComputerNotQ::distance_to_code() for baseline code.
449
+ int_dot -= 2 *
450
+ rabitq::bitwise_xor_dot_product(
451
+ rearranged_rotated_qq.data(), binary_data, size, qb);
452
+ final_dot += int_dot * query_fac.int_dot_scale;
453
+ } else {
454
+ auto dot_qo = rabitq::bitwise_and_dot_product(
455
+ rearranged_rotated_qq.data(), binary_data, size, qb);
456
+ // It was a willful decision (after the discussion) to not to pre-cache
457
+ // the sum of all bits, just in order to reduce the overhead per vector.
362
458
  // process 64-bit popcounts
363
- for (size_t i = 0; i < di_64b; i += 8) {
364
- const auto yv = *(const uint64_t*)(binary_data + i);
365
- sum_q += __builtin_popcountll(yv);
366
- }
367
-
368
- // process leftovers
369
- for (size_t i = di_64b; i < di_8b; i++) {
370
- const auto yv = *(binary_data + i);
371
- sum_q += __builtin_popcount(yv);
372
- }
459
+ auto sum_q = rabitq::popcount(binary_data, size);
460
+ // dot-product itself
461
+ final_dot += query_fac.c1 * dot_qo;
462
+ // normalizer coefficients
463
+ final_dot += query_fac.c2 * sum_q;
464
+ // normalizer coefficients
465
+ final_dot -= query_fac.c34;
373
466
  }
374
467
 
375
- float final_dot = 0;
376
- // dot-product itself
377
- final_dot += query_fac.c1 * dot_qo;
378
- // normalizer coefficients
379
- final_dot += query_fac.c2 * sum_q;
380
- // normalizer coefficients
381
- final_dot -= query_fac.c34;
382
-
383
- // this is ||or - c||^2 - (IP ? ||or||^2 : 0)
384
- const float or_c_l2sqr = fac->or_minus_c_l2sqr;
385
-
386
468
  // pre_dist = ||or - c||^2 + ||qr - c||^2 -
387
469
  // 2 * ||or - c|| * ||qr - c|| * <q,o> - (IP ? ||or||^2 : 0)
388
- const float pre_dist = or_c_l2sqr + query_fac.qr_to_c_L2sqr -
389
- 2 * fac->dp_multiplier * final_dot;
470
+ const float pre_dist = base_fac->or_minus_c_l2sqr +
471
+ query_fac.qr_to_c_L2sqr - 2 * base_fac->dp_multiplier * final_dot;
390
472
 
391
473
  if (metric_type == MetricType::METRIC_L2) {
392
474
  // ||or - q||^ 2
393
475
  return pre_dist;
394
476
  } else {
395
477
  // metric == MetricType::METRIC_INNER_PRODUCT
396
-
397
- // this is ||q||^2
398
- const float query_norm_sqr = query_fac.qr_norm_L2sqr;
399
-
400
478
  // 2 * (or, q) = (||or - q||^2 - ||q||^2 - ||or||^2)
401
- return -0.5f * (pre_dist - query_norm_sqr);
479
+ return -0.5f * (pre_dist - query_fac.qr_norm_L2sqr);
402
480
  }
403
481
  }
404
482
 
405
- void RaBitDistanceComputerQ::set_query(const float* x) {
406
- FAISS_ASSERT(x != nullptr);
483
+ float RaBitQDistanceComputerQ::distance_to_code_full(const uint8_t* code) {
484
+ FAISS_ASSERT(code != nullptr);
407
485
  FAISS_ASSERT(
408
486
  (metric_type == MetricType::METRIC_L2 ||
409
487
  metric_type == MetricType::METRIC_INNER_PRODUCT));
488
+ FAISS_ASSERT(rotated_q.size() == d);
410
489
 
411
- // compute the distance from the query to the centroid
412
- if (centroid != nullptr) {
413
- query_fac.qr_to_c_L2sqr = fvec_L2sqr(x, centroid, d);
414
- } else {
415
- query_fac.qr_to_c_L2sqr = fvec_norm_L2sqr(x, d);
416
- }
417
-
418
- // allocate space
419
- rotated_qq.resize(d);
490
+ size_t ex_bits = nb_bits - 1;
420
491
 
421
- // rotate the query
422
- std::vector<float> rotated_q(d);
423
- for (size_t i = 0; i < d; i++) {
424
- rotated_q[i] = x[i] - ((centroid == nullptr) ? 0 : centroid[i]);
492
+ if (ex_bits == 0) {
493
+ // No ex-bits, just return 1-bit distance
494
+ return distance_to_code_1bit(code);
425
495
  }
426
496
 
427
- // compute some numbers
428
- const float inv_d = (d == 0) ? 1.0f : (1.0f / std::sqrt((float)d));
429
-
430
- // quantize the query. compute min and max
431
- float v_min = std::numeric_limits<float>::max();
432
- float v_max = std::numeric_limits<float>::lowest();
433
- for (size_t i = 0; i < d; i++) {
434
- const float v_q = rotated_q[i];
435
- v_min = std::min(v_min, v_q);
436
- v_max = std::max(v_max, v_q);
437
- }
438
-
439
- const float pow_2_qb = 1 << qb;
497
+ // Extract pointers to code sections
498
+ const uint8_t* binary_data = code;
499
+ size_t offset = (d + 7) / 8 + sizeof(SignBitFactorsWithError);
500
+ const uint8_t* ex_code = code + offset;
501
+ const ExtraBitsFactors* ex_fac = reinterpret_cast<const ExtraBitsFactors*>(
502
+ ex_code + (d * ex_bits + 7) / 8);
503
+
504
+ // Call shared utility directly with rotated_q pointer
505
+ return rabitq_utils::compute_full_multibit_distance(
506
+ binary_data,
507
+ ex_code,
508
+ *ex_fac,
509
+ rotated_q.data(),
510
+ query_fac.qr_to_c_L2sqr,
511
+ query_fac.qr_norm_L2sqr,
512
+ d,
513
+ ex_bits,
514
+ metric_type);
515
+ }
440
516
 
441
- const float delta = (v_max - v_min) / (pow_2_qb - 1);
442
- const float inv_delta = 1.0f / delta;
517
+ // Use shared constant from RaBitQUtils
518
+ using rabitq_utils::Z_MAX_BY_QB;
443
519
 
444
- size_t sum_qq = 0;
445
- for (int32_t i = 0; i < d; i++) {
446
- const float v_q = rotated_q[i];
520
+ void RaBitQDistanceComputerQ::set_query(const float* x) {
521
+ q = x;
522
+ FAISS_ASSERT(x != nullptr);
523
+ FAISS_ASSERT(
524
+ (metric_type == MetricType::METRIC_L2 ||
525
+ metric_type == MetricType::METRIC_INNER_PRODUCT));
526
+ FAISS_THROW_IF_NOT(qb <= 8);
527
+ FAISS_THROW_IF_NOT(qb > 0);
447
528
 
448
- // a default non-randomized SQ
449
- const int v_qq = std::round((v_q - v_min) * inv_delta);
529
+ // Use shared utilities for core query factor computation
530
+ // rotated_q is populated directly by compute_query_factors as an output
531
+ // parameter
532
+ query_fac = rabitq_utils::compute_query_factors(
533
+ x, d, centroid, qb, centered, metric_type, rotated_q, rotated_qq);
450
534
 
451
- rotated_qq[i] = std::min(255, std::max(0, v_qq));
452
- sum_qq += v_qq;
453
- }
535
+ // Compute g_error (query norm for lower bound computation)
536
+ // g_error = ||qr - c|| (L2 norm of rotated query)
537
+ g_error = std::sqrt(query_fac.qr_to_c_L2sqr);
454
538
 
455
- // rearrange the query vector
539
+ // Rearrange the query vector for SIMD operations (RaBitQuantizer-specific)
456
540
  popcount_aligned_dim = ((d + 7) / 8) * 8;
457
541
  size_t offset = (d + 7) / 8;
458
542
 
@@ -466,33 +550,30 @@ void RaBitDistanceComputerQ::set_query(const float* x) {
466
550
  bit ? (1 << (idim % 8)) : 0;
467
551
  }
468
552
  }
469
-
470
- query_fac.c1 = 2 * delta * inv_d;
471
- query_fac.c2 = 2 * v_min * inv_d;
472
- query_fac.c34 = inv_d * (delta * sum_qq + d * v_min);
473
-
474
- if (metric_type == MetricType::METRIC_INNER_PRODUCT) {
475
- // precompute if needed
476
- query_fac.qr_norm_L2sqr = fvec_norm_L2sqr(x, d);
477
- }
478
553
  }
479
554
 
555
+ } // anonymous namespace
556
+
480
557
  FlatCodesDistanceComputer* RaBitQuantizer::get_distance_computer(
481
558
  uint8_t qb,
482
- const float* centroid_in) const {
559
+ const float* centroid_in,
560
+ bool centered) const {
483
561
  if (qb == 0) {
484
- auto dc = std::make_unique<RaBitDistanceComputerNotQ>();
562
+ auto dc = std::make_unique<RaBitQDistanceComputerNotQ>();
485
563
  dc->metric_type = metric_type;
486
564
  dc->d = d;
487
565
  dc->centroid = centroid_in;
566
+ dc->nb_bits = nb_bits;
488
567
 
489
568
  return dc.release();
490
569
  } else {
491
- auto dc = std::make_unique<RaBitDistanceComputerQ>();
570
+ auto dc = std::make_unique<RaBitQDistanceComputerQ>();
492
571
  dc->metric_type = metric_type;
493
572
  dc->d = d;
494
573
  dc->centroid = centroid_in;
495
574
  dc->qb = qb;
575
+ dc->centered = centered;
576
+ dc->nb_bits = nb_bits;
496
577
 
497
578
  return dc.release();
498
579
  }