faiss 0.4.2 → 0.5.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (153)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +9 -0
  3. data/ext/faiss/index.cpp +36 -10
  4. data/ext/faiss/index_binary.cpp +19 -6
  5. data/ext/faiss/kmeans.cpp +6 -6
  6. data/ext/faiss/numo.hpp +273 -123
  7. data/lib/faiss/version.rb +1 -1
  8. data/vendor/faiss/faiss/AutoTune.cpp +2 -3
  9. data/vendor/faiss/faiss/AutoTune.h +1 -1
  10. data/vendor/faiss/faiss/Clustering.cpp +2 -2
  11. data/vendor/faiss/faiss/Clustering.h +2 -2
  12. data/vendor/faiss/faiss/IVFlib.cpp +1 -2
  13. data/vendor/faiss/faiss/IVFlib.h +1 -1
  14. data/vendor/faiss/faiss/Index.h +10 -10
  15. data/vendor/faiss/faiss/Index2Layer.cpp +1 -1
  16. data/vendor/faiss/faiss/Index2Layer.h +2 -2
  17. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +9 -4
  18. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +5 -1
  19. data/vendor/faiss/faiss/IndexBinary.h +7 -7
  20. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +1 -1
  21. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +3 -1
  22. data/vendor/faiss/faiss/IndexBinaryHNSW.h +1 -1
  23. data/vendor/faiss/faiss/IndexBinaryHash.cpp +3 -3
  24. data/vendor/faiss/faiss/IndexBinaryHash.h +5 -5
  25. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +7 -6
  26. data/vendor/faiss/faiss/IndexFastScan.cpp +125 -49
  27. data/vendor/faiss/faiss/IndexFastScan.h +107 -7
  28. data/vendor/faiss/faiss/IndexFlat.h +1 -1
  29. data/vendor/faiss/faiss/IndexHNSW.cpp +3 -1
  30. data/vendor/faiss/faiss/IndexHNSW.h +1 -1
  31. data/vendor/faiss/faiss/IndexIDMap.cpp +14 -13
  32. data/vendor/faiss/faiss/IndexIDMap.h +6 -6
  33. data/vendor/faiss/faiss/IndexIVF.cpp +1 -1
  34. data/vendor/faiss/faiss/IndexIVF.h +5 -5
  35. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +1 -1
  36. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +9 -3
  37. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +3 -1
  38. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +176 -90
  39. data/vendor/faiss/faiss/IndexIVFFastScan.h +173 -18
  40. data/vendor/faiss/faiss/IndexIVFFlat.cpp +1 -0
  41. data/vendor/faiss/faiss/IndexIVFFlatPanorama.cpp +366 -0
  42. data/vendor/faiss/faiss/IndexIVFFlatPanorama.h +64 -0
  43. data/vendor/faiss/faiss/IndexIVFPQ.cpp +3 -1
  44. data/vendor/faiss/faiss/IndexIVFPQ.h +1 -1
  45. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +134 -2
  46. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +7 -1
  47. data/vendor/faiss/faiss/IndexIVFRaBitQ.cpp +13 -6
  48. data/vendor/faiss/faiss/IndexIVFRaBitQ.h +1 -0
  49. data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.cpp +650 -0
  50. data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.h +216 -0
  51. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +1 -1
  52. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +1 -1
  53. data/vendor/faiss/faiss/IndexNNDescent.cpp +1 -1
  54. data/vendor/faiss/faiss/IndexNSG.cpp +1 -1
  55. data/vendor/faiss/faiss/IndexNeuralNetCodec.h +1 -1
  56. data/vendor/faiss/faiss/IndexPQ.h +1 -1
  57. data/vendor/faiss/faiss/IndexPQFastScan.cpp +6 -2
  58. data/vendor/faiss/faiss/IndexPQFastScan.h +5 -1
  59. data/vendor/faiss/faiss/IndexRaBitQ.cpp +13 -10
  60. data/vendor/faiss/faiss/IndexRaBitQ.h +7 -2
  61. data/vendor/faiss/faiss/IndexRaBitQFastScan.cpp +586 -0
  62. data/vendor/faiss/faiss/IndexRaBitQFastScan.h +149 -0
  63. data/vendor/faiss/faiss/IndexShards.cpp +1 -1
  64. data/vendor/faiss/faiss/MatrixStats.cpp +3 -3
  65. data/vendor/faiss/faiss/MetricType.h +1 -1
  66. data/vendor/faiss/faiss/VectorTransform.h +2 -2
  67. data/vendor/faiss/faiss/clone_index.cpp +3 -1
  68. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +1 -1
  69. data/vendor/faiss/faiss/gpu/GpuIndex.h +11 -11
  70. data/vendor/faiss/faiss/gpu/GpuIndexBinaryCagra.h +1 -1
  71. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +1 -1
  72. data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +10 -6
  73. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +2 -0
  74. data/vendor/faiss/faiss/gpu/test/TestGpuIcmEncoder.cpp +7 -0
  75. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +1 -1
  76. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +1 -1
  77. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +1 -1
  78. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +2 -2
  79. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +1 -1
  80. data/vendor/faiss/faiss/impl/CodePacker.h +2 -2
  81. data/vendor/faiss/faiss/impl/DistanceComputer.h +3 -3
  82. data/vendor/faiss/faiss/impl/FastScanDistancePostProcessing.h +53 -0
  83. data/vendor/faiss/faiss/impl/HNSW.cpp +1 -1
  84. data/vendor/faiss/faiss/impl/HNSW.h +4 -4
  85. data/vendor/faiss/faiss/impl/IDSelector.cpp +2 -2
  86. data/vendor/faiss/faiss/impl/IDSelector.h +1 -1
  87. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +4 -4
  88. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +1 -1
  89. data/vendor/faiss/faiss/impl/LookupTableScaler.h +1 -1
  90. data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -1
  91. data/vendor/faiss/faiss/impl/NNDescent.h +2 -2
  92. data/vendor/faiss/faiss/impl/NSG.cpp +1 -1
  93. data/vendor/faiss/faiss/impl/PanoramaStats.cpp +33 -0
  94. data/vendor/faiss/faiss/impl/PanoramaStats.h +38 -0
  95. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +5 -5
  96. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +1 -1
  97. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +1 -1
  98. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +2 -0
  99. data/vendor/faiss/faiss/impl/ProductQuantizer.h +1 -1
  100. data/vendor/faiss/faiss/impl/RaBitQUtils.cpp +246 -0
  101. data/vendor/faiss/faiss/impl/RaBitQUtils.h +153 -0
  102. data/vendor/faiss/faiss/impl/RaBitQuantizer.cpp +54 -158
  103. data/vendor/faiss/faiss/impl/RaBitQuantizer.h +2 -1
  104. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +1 -1
  105. data/vendor/faiss/faiss/impl/ResultHandler.h +4 -4
  106. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +1 -1
  107. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +1 -1
  108. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +7 -4
  109. data/vendor/faiss/faiss/impl/index_read.cpp +87 -3
  110. data/vendor/faiss/faiss/impl/index_write.cpp +73 -3
  111. data/vendor/faiss/faiss/impl/io.cpp +2 -2
  112. data/vendor/faiss/faiss/impl/io.h +4 -4
  113. data/vendor/faiss/faiss/impl/kmeans1d.cpp +1 -1
  114. data/vendor/faiss/faiss/impl/kmeans1d.h +1 -1
  115. data/vendor/faiss/faiss/impl/lattice_Zn.h +2 -2
  116. data/vendor/faiss/faiss/impl/mapped_io.cpp +2 -2
  117. data/vendor/faiss/faiss/impl/mapped_io.h +4 -3
  118. data/vendor/faiss/faiss/impl/maybe_owned_vector.h +8 -1
  119. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +30 -4
  120. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +14 -8
  121. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +5 -6
  122. data/vendor/faiss/faiss/impl/simd_result_handlers.h +55 -11
  123. data/vendor/faiss/faiss/impl/zerocopy_io.h +1 -1
  124. data/vendor/faiss/faiss/index_factory.cpp +43 -1
  125. data/vendor/faiss/faiss/index_factory.h +1 -1
  126. data/vendor/faiss/faiss/index_io.h +1 -1
  127. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +205 -0
  128. data/vendor/faiss/faiss/invlists/InvertedLists.h +62 -0
  129. data/vendor/faiss/faiss/utils/AlignedTable.h +1 -1
  130. data/vendor/faiss/faiss/utils/Heap.cpp +2 -2
  131. data/vendor/faiss/faiss/utils/Heap.h +3 -3
  132. data/vendor/faiss/faiss/utils/NeuralNet.cpp +1 -1
  133. data/vendor/faiss/faiss/utils/NeuralNet.h +3 -3
  134. data/vendor/faiss/faiss/utils/approx_topk/approx_topk.h +2 -2
  135. data/vendor/faiss/faiss/utils/approx_topk/avx2-inl.h +2 -2
  136. data/vendor/faiss/faiss/utils/approx_topk/mode.h +1 -1
  137. data/vendor/faiss/faiss/utils/distances.h +2 -2
  138. data/vendor/faiss/faiss/utils/extra_distances-inl.h +3 -1
  139. data/vendor/faiss/faiss/utils/hamming-inl.h +2 -0
  140. data/vendor/faiss/faiss/utils/hamming.cpp +7 -6
  141. data/vendor/faiss/faiss/utils/hamming.h +1 -1
  142. data/vendor/faiss/faiss/utils/hamming_distance/common.h +1 -2
  143. data/vendor/faiss/faiss/utils/partitioning.cpp +5 -5
  144. data/vendor/faiss/faiss/utils/partitioning.h +2 -2
  145. data/vendor/faiss/faiss/utils/rabitq_simd.h +222 -336
  146. data/vendor/faiss/faiss/utils/random.cpp +1 -1
  147. data/vendor/faiss/faiss/utils/simdlib_avx2.h +1 -1
  148. data/vendor/faiss/faiss/utils/simdlib_avx512.h +1 -1
  149. data/vendor/faiss/faiss/utils/simdlib_neon.h +2 -2
  150. data/vendor/faiss/faiss/utils/transpose/transpose-avx512-inl.h +1 -1
  151. data/vendor/faiss/faiss/utils/utils.cpp +5 -2
  152. data/vendor/faiss/faiss/utils/utils.h +2 -2
  153. metadata +14 -3
data/vendor/faiss/faiss/utils/rabitq_simd.h
@@ -14,9 +14,9 @@
  #if defined(__x86_64__) || defined(_M_X64) || defined(__i386__) || \
  defined(_M_IX86)
  #include <immintrin.h>
- #endif
+ #endif // defined(__x86_64__) || defined(_M_X64) || defined(__i386__) ||

- namespace faiss {
+ namespace faiss::rabitq {

  #if defined(__x86_64__) || defined(_M_X64) || defined(__i386__) || \
  defined(_M_IX86)
@@ -24,8 +24,11 @@ namespace faiss {
  * Returns the lookup table for AVX512 popcount operations.
  * This table is used for lookup-based popcount implementation.
  *
+ * Source: https://github.com/WojciechMula/sse-popcount.
+ *
  * @return Lookup table as __m512i register
  */
+ #if defined(__AVX512F__)
  inline __m512i get_lookup_512() {
  return _mm512_set_epi8(
  /* f */ 4,
@@ -93,7 +96,8 @@ inline __m512i get_lookup_512() {
  /* 1 */ 1,
  /* 0 */ 0);
  }
-
+ #endif // defined(__AVX512F__)
+ #if defined(__AVX2__)
  /**
  * Returns the lookup table for AVX2 popcount operations.
  * This table is used for lookup-based popcount implementation.
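The tables added above drive a nibble-based popcount: each byte is split into its low and high 4-bit halves, and both halves are looked up in a 16-entry table of bit counts. A minimal scalar sketch of that idea, which the vectorized popcount_512/popcount_256 helpers in the next hunk apply to 64/32 bytes per shuffle instruction (illustrative only, not part of the vendored sources):

#include <cstddef>
#include <cstdint>

// 16-entry table: popcount of every possible nibble value 0x0..0xf.
static const uint8_t kNibblePopcount[16] = {
        0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4};

// Scalar analogue of the lookup-based popcount: split each byte into its
// low and high nibble and add the two table entries.
inline uint64_t popcount_bytes_lookup(const uint8_t* data, size_t size) {
    uint64_t sum = 0;
    for (size_t i = 0; i < size; ++i) {
        sum += kNibblePopcount[data[i] & 0x0f];
        sum += kNibblePopcount[data[i] >> 4];
    }
    return sum;
}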
@@ -135,405 +139,287 @@ inline __m256i get_lookup_256() {
  /* e */ 3,
  /* f */ 4);
  }
+ #endif // defined(__AVX2__)

+ #if defined(__AVX512F__)
  /**
- * Performs lookup-based popcount on AVX512 registers.
+ * Popcount for a 512-bit register, using lookup tables if necessary.
  *
- * @param v_and Input vector to count bits in
- * @return Vector with popcount results
+ * @param v Input vector to count bits in
+ * @return Vector int32_t[16] with popcount results.
  */
- inline __m512i popcount_lookup_avx512(__m512i v_and) {
+ inline __m512i popcount_512(__m512i v) {
+ #if defined(__AVX512VPOPCNTDQ__)
+ return _mm512_popcnt_epi64(v);
+ #else
  const __m512i lookup = get_lookup_512();
  const __m512i low_mask = _mm512_set1_epi8(0x0f);

- const __m512i lo = _mm512_and_si512(v_and, low_mask);
- const __m512i hi = _mm512_and_si512(_mm512_srli_epi16(v_and, 4), low_mask);
- const __m512i popcnt1 = _mm512_shuffle_epi8(lookup, lo);
- const __m512i popcnt2 = _mm512_shuffle_epi8(lookup, hi);
- return _mm512_add_epi8(popcnt1, popcnt2);
+ const __m512i lo = _mm512_and_si512(v, low_mask);
+ const __m512i hi = _mm512_and_si512(_mm512_srli_epi16(v, 4), low_mask);
+ const __m512i popcnt_lo = _mm512_shuffle_epi8(lookup, lo);
+ const __m512i popcnt_hi = _mm512_shuffle_epi8(lookup, hi);
+ const __m512i popcnt = _mm512_add_epi8(popcnt_lo, popcnt_hi);
+ return _mm512_sad_epu8(_mm512_setzero_si512(), popcnt);
+ #endif // defined(__AVX512VPOPCNTDQ__)
  }
+ #endif // defined(__AVX512F__)

+ #if defined(__AVX2__)
  /**
- * Performs lookup-based popcount on AVX2 registers.
+ * Popcount for a 256-bit register, using lookup tables if necessary.
  *
- * @param v_and Input vector to count bits in
- * @return Vector with popcount results
+ * @param v Input vector to count bits in
+ * @return uint64_t[4] of popcounts for each portion of the input vector.
  */
- inline __m256i popcount_lookup_avx2(__m256i v_and) {
+ inline __m256i popcount_256(__m256i v) {
  const __m256i lookup = get_lookup_256();
  const __m256i low_mask = _mm256_set1_epi8(0x0f);

- const __m256i lo = _mm256_and_si256(v_and, low_mask);
- const __m256i hi = _mm256_and_si256(_mm256_srli_epi16(v_and, 4), low_mask);
- const __m256i popcnt1 = _mm256_shuffle_epi8(lookup, lo);
- const __m256i popcnt2 = _mm256_shuffle_epi8(lookup, hi);
- return _mm256_add_epi8(popcnt1, popcnt2);
+ const __m256i lo = _mm256_and_si256(v, low_mask);
+ const __m256i hi = _mm256_and_si256(_mm256_srli_epi16(v, 4), low_mask);
+ const __m256i popcnt_lo = _mm256_shuffle_epi8(lookup, lo);
+ const __m256i popcnt_hi = _mm256_shuffle_epi8(lookup, hi);
+ const __m256i popcnt = _mm256_add_epi8(popcnt_lo, popcnt_hi);
+ // Reduce uint8_t[32] into uint64_t[4] by addition.
+ return _mm256_sad_epu8(_mm256_setzero_si256(), popcnt);
+ }
+
+ inline uint64_t reduce_add_256(__m256i v) {
+ alignas(32) uint64_t lanes[4];
+ _mm256_store_si256((__m256i*)lanes, v);
+ return lanes[0] + lanes[1] + lanes[2] + lanes[3];
+ }
+ #endif // defined(__AVX2__)
+
+ #if defined(__SSE4_1__)
+ inline __m128i popcount_128(__m128i v) {
+ // Scalar popcount for each 64-bit lane
+ uint64_t lane0 = _mm_extract_epi64(v, 0);
+ uint64_t lane1 = _mm_extract_epi64(v, 1);
+ uint64_t pop0 = __builtin_popcountll(lane0);
+ uint64_t pop1 = __builtin_popcountll(lane1);
+ return _mm_set_epi64x(pop1, pop0);
  }
- #endif

- #if defined(__AVX512F__) && defined(__AVX512VPOPCNTDQ__)
+ inline uint64_t reduce_add_128(__m128i v) {
+ alignas(16) uint64_t lanes[2];
+ _mm_store_si128((__m128i*)lanes, v);
+ return lanes[0] + lanes[1];
+ }
+ #endif // defined(__SSE4_1__)
+ #endif // defined(__x86_64__) || defined(_M_X64) || defined(__i386__) ||

  /**
- * AVX512-optimized version of dot product computation between query and binary
- * data. Requires AVX512F and AVX512VPOPCNTDQ instruction sets.
+ * Compute dot product between query and binary data using popcount operations.
  *
  * @param query Pointer to rearranged rotated query data
- * @param binary_data Pointer to binary data
+ * @param data Pointer to binary data
  * @param d Dimension
  * @param qb Number of quantization bits
- * @return Dot product result as float
+ * @return Unsigned integer dot product
  */
- inline float rabitq_dp_popcnt_avx512(
+ inline uint64_t bitwise_and_dot_product(
  const uint8_t* query,
- const uint8_t* binary_data,
- size_t d,
+ const uint8_t* data,
+ size_t size,
  size_t qb) {
- __m512i sum_512 = _mm512_setzero_si512();
-
- const size_t di_8b = (d + 7) / 8;
-
- const size_t d_512 = (d / 512) * 512;
- const size_t d_256 = (d / 256) * 256;
- const size_t d_128 = (d / 128) * 128;
-
- for (size_t i = 0; i < d_512; i += 512) {
- __m512i v_x = _mm512_loadu_si512((const __m512i*)(binary_data + i / 8));
- for (size_t j = 0; j < qb; j++) {
- __m512i v_q = _mm512_loadu_si512(
- (const __m512i*)(query + j * di_8b + i / 8));
- __m512i v_and = _mm512_and_si512(v_q, v_x);
- __m512i v_popcnt = _mm512_popcnt_epi32(v_and);
- sum_512 = _mm512_add_epi32(sum_512, _mm512_slli_epi32(v_popcnt, j));
+ uint64_t sum = 0;
+ size_t offset = 0;
+ #if defined(__AVX512F__)
+ // Handle 512-bit chunks.
+ if (size_t step = 512 / 8; offset + step <= size) {
+ __m512i sum_512 = _mm512_setzero_si512();
+ for (; offset + step <= size; offset += step) {
+ __m512i v_x = _mm512_loadu_si512((const __m512i*)(data + offset));
+ for (int j = 0; j < qb; j++) {
+ __m512i v_q = _mm512_loadu_si512(
+ (const __m512i*)(query + j * size + offset));
+ __m512i v_and = _mm512_and_si512(v_q, v_x);
+ __m512i v_popcnt = popcount_512(v_and);
+ __m512i v_shifted = _mm512_slli_epi64(v_popcnt, j);
+ sum_512 = _mm512_add_epi64(sum_512, v_shifted);
+ }
  }
+ sum += _mm512_reduce_add_epi64(sum_512);
  }
-
- __m256i sum_256 = _mm256_add_epi32(
- _mm512_extracti32x8_epi32(sum_512, 0),
- _mm512_extracti32x8_epi32(sum_512, 1));
-
- if (d_256 != d_512) {
- __m256i v_x =
- _mm256_loadu_si256((const __m256i*)(binary_data + d_512 / 8));
- for (size_t j = 0; j < qb; j++) {
- __m256i v_q = _mm256_loadu_si256(
- (const __m256i*)(query + j * di_8b + d_512 / 8));
- __m256i v_and = _mm256_and_si256(v_q, v_x);
- __m256i v_popcnt = _mm256_popcnt_epi32(v_and);
- sum_256 = _mm256_add_epi32(sum_256, _mm256_slli_epi32(v_popcnt, j));
+ #endif // defined(__AVX512F__)
+ #if defined(__AVX2__)
+ if (size_t step = 256 / 8; offset + step <= size) {
+ __m256i sum_256 = _mm256_setzero_si256();
+ for (; offset + step <= size; offset += step) {
+ __m256i v_x = _mm256_loadu_si256((const __m256i*)(data + offset));
+ for (int j = 0; j < qb; j++) {
+ __m256i v_q = _mm256_loadu_si256(
+ (const __m256i*)(query + j * size + offset));
+ __m256i v_and = _mm256_and_si256(v_q, v_x);
+ __m256i v_popcnt = popcount_256(v_and);
+ __m256i v_shifted = _mm256_slli_epi64(v_popcnt, j);
+ sum_256 = _mm256_add_epi64(sum_256, v_shifted);
+ }
  }
+ sum += reduce_add_256(sum_256);
  }
-
- __m128i sum_128 = _mm_add_epi32(
- _mm256_extracti32x4_epi32(sum_256, 0),
- _mm256_extracti32x4_epi32(sum_256, 1));
-
- if (d_128 != d_256) {
- __m128i v_x =
- _mm_loadu_si128((const __m128i*)(binary_data + d_256 / 8));
- for (size_t j = 0; j < qb; j++) {
+ #endif // defined(__AVX2__)
+ #if defined(__SSE4_1__)
+ __m128i sum_128 = _mm_setzero_si128();
+ for (size_t step = 128 / 8; offset + step <= size; offset += step) {
+ __m128i v_x = _mm_loadu_si128((const __m128i*)(data + offset));
+ for (int j = 0; j < qb; j++) {
  __m128i v_q = _mm_loadu_si128(
- (const __m128i*)(query + j * di_8b + d_256 / 8));
+ (const __m128i*)(query + j * size + offset));
  __m128i v_and = _mm_and_si128(v_q, v_x);
- __m128i v_popcnt = _mm_popcnt_epi32(v_and);
- sum_128 = _mm_add_epi32(sum_128, _mm_slli_epi32(v_popcnt, j));
+ __m128i v_popcnt = popcount_128(v_and);
+ __m128i v_shifted = _mm_slli_epi64(v_popcnt, j);
+ sum_128 = _mm_add_epi64(sum_128, v_shifted);
  }
  }
-
- if (d != d_128) {
- const size_t leftovers = d - d_128;
- const __mmask16 mask = (1 << ((leftovers + 7) / 8)) - 1;
-
- __m128i v_x = _mm_maskz_loadu_epi8(
- mask, (const __m128i*)(binary_data + d_128 / 8));
- for (size_t j = 0; j < qb; j++) {
- __m128i v_q = _mm_maskz_loadu_epi8(
- mask, (const __m128i*)(query + j * di_8b + d_128 / 8));
- __m128i v_and = _mm_and_si128(v_q, v_x);
- __m128i v_popcnt = _mm_popcnt_epi32(v_and);
- sum_128 = _mm_add_epi32(sum_128, _mm_slli_epi32(v_popcnt, j));
+ sum += reduce_add_128(sum_128);
+ #endif // defined(__SSE4_1__)
+ for (size_t step = 64 / 8; offset + step <= size; offset += step) {
+ const auto yv = *(const uint64_t*)(data + offset);
+ for (int j = 0; j < qb; j++) {
+ const auto qv = *(const uint64_t*)(query + j * size + offset);
+ sum += __builtin_popcountll(qv & yv) << j;
  }
  }
-
- int sum_64le = 0;
- sum_64le += _mm_extract_epi32(sum_128, 0);
- sum_64le += _mm_extract_epi32(sum_128, 1);
- sum_64le += _mm_extract_epi32(sum_128, 2);
- sum_64le += _mm_extract_epi32(sum_128, 3);
-
- return static_cast<float>(sum_64le);
+ for (; offset < size; ++offset) {
+ const auto yv = *(data + offset);
+ for (int j = 0; j < qb; j++) {
+ const auto qv = *(query + j * size + offset);
+ sum += __builtin_popcount(qv & yv) << j;
+ }
+ }
+ return sum;
  }
- #endif

  /**
- * AVX512-optimized version of dot product computation between query and binary
- * data. Uses AVX512F instructions but does not require AVX512VPOPCNTDQ.
+ * Compute dot product between query and binary data using popcount operations.
  *
  * @param query Pointer to rearranged rotated query data
- * @param binary_data Pointer to binary data
+ * @param data Pointer to binary data
  * @param d Dimension
  * @param qb Number of quantization bits
- * @return Dot product result as float
+ * @return Unsigned integer dot product
  */
- inline float rabitq_dp_popcnt_avx512_fallback(
+ inline uint64_t bitwise_xor_dot_product(
  const uint8_t* query,
- const uint8_t* binary_data,
- size_t d,
+ const uint8_t* data,
+ size_t size,
  size_t qb) {
- const size_t di_8b = (d + 7) / 8;
- const size_t d_512 = (d / 512) * 512;
- const size_t d_256 = (d / 256) * 256;
- const size_t d_128 = (d / 128) * 128;
-
- // Use the lookup-based popcount helper function
-
- __m512i sum_512 = _mm512_setzero_si512();
-
- // Process 512 bits (64 bytes) at a time using lookup-based popcount
- for (size_t i = 0; i < d_512; i += 512) {
- __m512i v_x = _mm512_loadu_si512((const __m512i*)(binary_data + i / 8));
- for (size_t j = 0; j < qb; j++) {
- __m512i v_q = _mm512_loadu_si512(
- (const __m512i*)(query + j * di_8b + i / 8));
- __m512i v_and = _mm512_and_si512(v_q, v_x);
-
- // Use the popcount_lookup_avx512 helper function
- __m512i v_popcnt = popcount_lookup_avx512(v_and);
-
- // Sum bytes to 32-bit integers
- __m512i v_sad = _mm512_sad_epu8(v_popcnt, _mm512_setzero_si512());
-
- // Shift by j and add to sum
- __m512i v_shifted = _mm512_slli_epi64(v_sad, j);
- sum_512 = _mm512_add_epi64(sum_512, v_shifted);
+ uint64_t sum = 0;
+ size_t offset = 0;
+ #if defined(__AVX512F__)
+ // Handle 512-bit chunks.
+ if (size_t step = 512 / 8; offset + step <= size) {
+ __m512i sum_512 = _mm512_setzero_si512();
+ for (; offset + step <= size; offset += step) {
+ __m512i v_x = _mm512_loadu_si512((const __m512i*)(data + offset));
+ for (int j = 0; j < qb; j++) {
+ __m512i v_q = _mm512_loadu_si512(
+ (const __m512i*)(query + j * size + offset));
+ __m512i v_xor = _mm512_xor_si512(v_q, v_x);
+ __m512i v_popcnt = popcount_512(v_xor);
+ __m512i v_shifted = _mm512_slli_epi64(v_popcnt, j);
+ sum_512 = _mm512_add_epi64(sum_512, v_shifted);
+ }
  }
+ sum += _mm512_reduce_add_epi64(sum_512);
  }
-
- // Handle 256-bit section if needed
- __m256i sum_256 = _mm256_setzero_si256();
- if (d_256 != d_512) {
- __m256i v_x =
- _mm256_loadu_si256((const __m256i*)(binary_data + d_512 / 8));
- for (size_t j = 0; j < qb; j++) {
- __m256i v_q = _mm256_loadu_si256(
- (const __m256i*)(query + j * di_8b + d_512 / 8));
- __m256i v_and = _mm256_and_si256(v_q, v_x);
-
- // Use the popcount_lookup_avx2 helper function
- __m256i v_popcnt = popcount_lookup_avx2(v_and);
-
- // Sum bytes to 64-bit integers
- __m256i v_sad = _mm256_sad_epu8(v_popcnt, _mm256_setzero_si256());
-
- // Shift by j and add to sum
- __m256i v_shifted = _mm256_slli_epi64(v_sad, j);
- sum_256 = _mm256_add_epi64(sum_256, v_shifted);
+ #endif
+ #if defined(__AVX2__)
+ if (size_t step = 256 / 8; offset + step <= size) {
+ __m256i sum_256 = _mm256_setzero_si256();
+ for (; offset + step <= size; offset += step) {
+ __m256i v_x = _mm256_loadu_si256((const __m256i*)(data + offset));
+ for (int j = 0; j < qb; j++) {
+ __m256i v_q = _mm256_loadu_si256(
+ (const __m256i*)(query + j * size + offset));
+ __m256i v_xor = _mm256_xor_si256(v_q, v_x);
+ __m256i v_popcnt = popcount_256(v_xor);
+ __m256i v_shifted = _mm256_slli_epi64(v_popcnt, j);
+ sum_256 = _mm256_add_epi64(sum_256, v_shifted);
+ }
  }
+ sum += reduce_add_256(sum_256);
  }
-
- // Handle 128-bit section and leftovers
+ #endif
+ #if defined(__SSE4_1__)
  __m128i sum_128 = _mm_setzero_si128();
- if (d_128 != d_256) {
- __m128i v_x =
- _mm_loadu_si128((const __m128i*)(binary_data + d_256 / 8));
- for (size_t j = 0; j < qb; j++) {
+ for (size_t step = 128 / 8; offset + step <= size; offset += step) {
+ __m128i v_x = _mm_loadu_si128((const __m128i*)(data + offset));
+ for (int j = 0; j < qb; j++) {
  __m128i v_q = _mm_loadu_si128(
- (const __m128i*)(query + j * di_8b + d_256 / 8));
- __m128i v_and = _mm_and_si128(v_q, v_x);
-
- // Scalar popcount for each 64-bit lane
- uint64_t lane0 = _mm_extract_epi64(v_and, 0);
- uint64_t lane1 = _mm_extract_epi64(v_and, 1);
- uint64_t pop0 = __builtin_popcountll(lane0) << j;
- uint64_t pop1 = __builtin_popcountll(lane1) << j;
- sum_128 = _mm_add_epi64(sum_128, _mm_set_epi64x(pop1, pop0));
+ (const __m128i*)(query + j * size + offset));
+ __m128i v_xor = _mm_xor_si128(v_q, v_x);
+ __m128i v_popcnt = popcount_128(v_xor);
+ __m128i v_shifted = _mm_slli_epi64(v_popcnt, j);
+ sum_128 = _mm_add_epi64(sum_128, v_shifted);
  }
  }
-
- // Handle remaining bytes (less than 16)
- uint64_t sum_leftover = 0;
- size_t d_leftover = d - d_128;
- if (d_leftover > 0) {
- for (size_t j = 0; j < qb; j++) {
- for (size_t k = 0; k < (d_leftover + 7) / 8; ++k) {
- uint8_t qv = query[j * di_8b + d_128 / 8 + k];
- uint8_t yv = binary_data[d_128 / 8 + k];
- sum_leftover += (__builtin_popcount(qv & yv) << j);
- }
+ sum += reduce_add_128(sum_128);
+ #endif
+ for (size_t step = 64 / 8; offset + step <= size; offset += step) {
+ const auto yv = *(const uint64_t*)(data + offset);
+ for (int j = 0; j < qb; j++) {
+ const auto qv = *(const uint64_t*)(query + j * size + offset);
+ sum += __builtin_popcountll(qv ^ yv) << j;
  }
  }
-
- // Horizontal sum of all lanes
- uint64_t sum = 0;
-
- // Sum from 512-bit registers
- alignas(64) uint64_t lanes512[8];
- _mm512_store_si512((__m512i*)lanes512, sum_512);
- for (int i = 0; i < 8; ++i) {
- sum += lanes512[i];
- }
-
- // Sum from 256-bit registers
- alignas(32) uint64_t lanes256[4];
- _mm256_store_si256((__m256i*)lanes256, sum_256);
- for (int i = 0; i < 4; ++i) {
- sum += lanes256[i];
+ for (; offset < size; ++offset) {
+ const auto yv = *(data + offset);
+ for (int j = 0; j < qb; j++) {
+ const auto qv = *(query + j * size + offset);
+ sum += __builtin_popcount(qv ^ yv) << j;
+ }
  }
-
- // Sum from 128-bit registers
- alignas(16) uint64_t lanes128[2];
- _mm_store_si128((__m128i*)lanes128, sum_128);
- sum += lanes128[0] + lanes128[1];
-
- // Add leftovers
- sum += sum_leftover;
-
- return static_cast<float>(sum);
+ return sum;
  }
- #endif
-
- #ifdef __AVX2__
-
- /**
- * AVX2-optimized version of dot product computation between query and binary
- * data.
- *
- * @param query Pointer to rearranged rotated query data
- * @param binary_data Pointer to binary data
- * @param d Dimension
- * @param qb Number of quantization bits
- * @return Dot product result as float
- */
-
- inline float rabitq_dp_popcnt_avx2(
- const uint8_t* query,
- const uint8_t* binary_data,
- size_t d,
- size_t qb) {
- const size_t di_8b = (d + 7) / 8;
- const size_t d_256 = (d / 256) * 256;
- const size_t d_128 = (d / 128) * 128;
-
- // Use the lookup-based popcount helper function
-
- __m256i sum_256 = _mm256_setzero_si256();
-
- // Process 256 bits (32 bytes) at a time using lookup-based popcount
- for (size_t i = 0; i < d_256; i += 256) {
- __m256i v_x = _mm256_loadu_si256((const __m256i*)(binary_data + i / 8));
- for (size_t j = 0; j < qb; j++) {
- __m256i v_q = _mm256_loadu_si256(
- (const __m256i*)(query + j * di_8b + i / 8));
- __m256i v_and = _mm256_and_si256(v_q, v_x);

- // Use the popcount_lookup_avx2 helper function
- __m256i v_popcnt = popcount_lookup_avx2(v_and);
-
- // Convert byte counts to 64-bit lanes and shift by j
- __m256i v_sad = _mm256_sad_epu8(v_popcnt, _mm256_setzero_si256());
- __m256i v_shifted = _mm256_slli_epi64(v_sad, static_cast<int>(j));
- sum_256 = _mm256_add_epi64(sum_256, v_shifted);
+ inline uint64_t popcount(const uint8_t* data, size_t size) {
+ uint64_t sum = 0;
+ size_t offset = 0;
+ #if defined(__AVX512F__)
+ // Handle 512-bit chunks.
+ if (offset + 512 / 8 <= size) {
+ __m512i sum_512 = _mm512_setzero_si512();
+ for (size_t end; (end = offset + 512 / 8) <= size; offset = end) {
+ __m512i v_x = _mm512_loadu_si512((const __m512i*)(data + offset));
+ __m512i v_popcnt = popcount_512(v_x);
+ sum_512 = _mm512_add_epi64(sum_512, v_popcnt);
  }
+ sum += _mm512_reduce_add_epi64(sum_512);
  }
-
- // Handle leftovers with 128-bit SIMD
- __m128i sum_128 = _mm_setzero_si128();
- if (d_128 != d_256) {
- __m128i v_x =
- _mm_loadu_si128((const __m128i*)(binary_data + d_256 / 8));
- for (size_t j = 0; j < qb; j++) {
- __m128i v_q = _mm_loadu_si128(
- (const __m128i*)(query + j * di_8b + d_256 / 8));
- __m128i v_and = _mm_and_si128(v_q, v_x);
- // Scalar popcount for each 64-bit lane
- uint64_t lane0 = _mm_extract_epi64(v_and, 0);
- uint64_t lane1 = _mm_extract_epi64(v_and, 1);
- uint64_t pop0 = __builtin_popcountll(lane0) << j;
- uint64_t pop1 = __builtin_popcountll(lane1) << j;
- sum_128 = _mm_add_epi64(sum_128, _mm_set_epi64x(pop1, pop0));
+ #endif // defined(__AVX512F__)
+ #if defined(__AVX2__)
+ if (offset + 256 / 8 <= size) {
+ __m256i sum_256 = _mm256_setzero_si256();
+ for (size_t end; (end = offset + 256 / 8) <= size; offset = end) {
+ __m256i v_x = _mm256_loadu_si256((const __m256i*)(data + offset));
+ __m256i v_popcnt = popcount_256(v_x);
+ sum_256 = _mm256_add_epi64(sum_256, v_popcnt);
  }
+ sum += reduce_add_256(sum_256);
  }
-
- // Handle remaining bytes (less than 16)
- uint64_t sum_leftover = 0;
- size_t d_leftover = d - d_128;
- if (d_leftover > 0) {
- for (size_t j = 0; j < qb; j++) {
- for (size_t k = 0; k < (d_leftover + 7) / 8; ++k) {
- uint8_t qv = query[j * di_8b + d_128 / 8 + k];
- uint8_t yv = binary_data[d_128 / 8 + k];
- sum_leftover += (__builtin_popcount(qv & yv) << j);
- }
- }
+ #endif // defined(__AVX2__)
+ #if defined(__SSE4_1__)
+ __m128i sum_128 = _mm_setzero_si128();
+ for (size_t step = 128 / 8; offset + step <= size; offset += step) {
+ __m128i v_x = _mm_loadu_si128((const __m128i*)(data + offset));
+ sum_128 = _mm_add_epi64(sum_128, popcount_128(v_x));
  }
+ sum += reduce_add_128(sum_128);
+ #endif // defined(__SSE4_1__)

- // Horizontal sum of all lanes
- uint64_t sum = 0;
- // sum_256: 4 lanes of 64 bits
- alignas(32) uint64_t lanes[4];
- _mm256_store_si256((__m256i*)lanes, sum_256);
- for (int i = 0; i < 4; ++i) {
- sum += lanes[i];
+ for (size_t step = 64 / 8; offset + step <= size; offset += step) {
+ const auto yv = *(const uint64_t*)(data + offset);
+ sum += __builtin_popcountll(yv);
  }
- // sum_128: 2 lanes of 64 bits
- alignas(16) uint64_t lanes128[2];
- _mm_store_si128((__m128i*)lanes128, sum_128);
- sum += lanes128[0] + lanes128[1];
- // leftovers
- sum += sum_leftover;
-
- return static_cast<float>(sum);
- }
- #endif
-
- /**
- * Compute dot product between query and binary data using popcount operations.
- *
- * @param query Pointer to rearranged rotated query data
- * @param binary_data Pointer to binary data
- * @param d Dimension
- * @param qb Number of quantization bits
- * @return Dot product result as float
- */
- inline float rabitq_dp_popcnt(
- const uint8_t* query,
- const uint8_t* binary_data,
- size_t d,
- size_t qb) {
- #if defined(__AVX512F__) && defined(__AVX512VPOPCNTDQ__)
- return rabitq_dp_popcnt_avx512(query, binary_data, d, qb);
- #elif defined(__AVX512F__)
- return rabitq_dp_popcnt_avx512_fallback(query, binary_data, d, qb);
- #elif defined(__AVX2__)
- return rabitq_dp_popcnt_avx2(query, binary_data, d, qb);
- #else
- const size_t di_8b = (d + 7) / 8;
- const size_t di_64b = (di_8b / 8) * 8;
-
- uint64_t dot_qo = 0;
- for (size_t j = 0; j < qb; j++) {
- const uint8_t* query_j = query + j * di_8b;
-
- // process 64-bit popcounts
- uint64_t count_dot = 0;
- for (size_t i = 0; i < di_64b; i += 8) {
- const auto qv = *(const uint64_t*)(query_j + i);
- const auto yv = *(const uint64_t*)(binary_data + i);
- count_dot += __builtin_popcountll(qv & yv);
- }
-
- // process leftovers
- for (size_t i = di_64b; i < di_8b; i++) {
- const auto qv = *(query_j + i);
- const auto yv = *(binary_data + i);
- count_dot += __builtin_popcount(qv & yv);
- }
-
- dot_qo += (count_dot << j);
+ for (; offset < size; ++offset) {
+ const auto yv = *(data + offset);
+ sum += __builtin_popcount(yv);
  }
-
- return static_cast<float>(dot_qo);
- #endif
+ return sum;
  }

- } // namespace faiss
+ } // namespace faiss::rabitq
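
The rewritten kernels above replace the float-returning rabitq_dp_popcnt* variants with integer-valued helpers in faiss::rabitq that treat the query as qb bit-planes of `size` bytes each and weight the popcount of plane j by 2^j. A scalar reference for that semantics, written under those assumptions as an illustration (the helper name and loop structure here are not from the package):

#include <cstddef>
#include <cstdint>

// Reference for the AND-based weighted popcount dot product: plane j of the
// query is ANDed with the binary code byte by byte, its popcount is shifted
// left by j, and the results are accumulated. The XOR variant differs only
// in using '^' instead of '&'.
inline uint64_t and_dot_product_ref(
        const uint8_t* query, // qb bit-planes, each `size` bytes long
        const uint8_t* data,  // binary code, `size` bytes
        size_t size,
        size_t qb) {
    uint64_t sum = 0;
    for (size_t j = 0; j < qb; ++j) {
        const uint8_t* plane = query + j * size;
        uint64_t plane_dot = 0;
        for (size_t i = 0; i < size; ++i) {
            plane_dot += __builtin_popcount(plane[i] & data[i]);
        }
        sum += plane_dot << j;
    }
    return sum;
}

The SIMD paths in the diff compute the same quantity, only consuming 64, 32, 16, or 8 bytes per step and reducing the per-lane popcounts at the end of each block.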