faiss 0.6.1 → 0.6.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (93)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +4 -0
  3. data/lib/faiss/version.rb +1 -1
  4. data/vendor/faiss/faiss/Index.h +1 -1
  5. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +6 -7
  6. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +3 -3
  7. data/vendor/faiss/faiss/IndexHNSW.cpp +173 -143
  8. data/vendor/faiss/faiss/IndexIVF.cpp +2 -2
  9. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +2 -2
  10. data/vendor/faiss/faiss/IndexIVFFlat.cpp +3 -1
  11. data/vendor/faiss/faiss/IndexIVFFlatPanorama.cpp +3 -3
  12. data/vendor/faiss/faiss/IndexIVFPQ.cpp +2 -3
  13. data/vendor/faiss/faiss/IndexIVFPQR.cpp +2 -3
  14. data/vendor/faiss/faiss/IndexIVFRaBitQ.cpp +4 -13
  15. data/vendor/faiss/faiss/IndexNNDescent.cpp +1 -1
  16. data/vendor/faiss/faiss/IndexNSG.cpp +1 -2
  17. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +68 -6
  18. data/vendor/faiss/faiss/IndexScalarQuantizer.h +10 -0
  19. data/vendor/faiss/faiss/cppcontrib/SaDecodeKernels.h +1 -1
  20. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-neon-inl.h +902 -12
  21. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-neon-inl.h +702 -10
  22. data/vendor/faiss/faiss/factory_tools.cpp +4 -0
  23. data/vendor/faiss/faiss/gpu/GpuResources.h +3 -2
  24. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +11 -12
  25. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +3 -3
  26. data/vendor/faiss/faiss/gpu_metal/MetalDistance.h +87 -0
  27. data/vendor/faiss/faiss/gpu_metal/MetalIndex.h +7 -0
  28. data/vendor/faiss/faiss/gpu_metal/MetalIndexIVFFlat.h +181 -0
  29. data/vendor/faiss/faiss/gpu_metal/MetalKernels.h +48 -3
  30. data/vendor/faiss/faiss/gpu_metal/MetalPythonBridge.h +45 -0
  31. data/vendor/faiss/faiss/gpu_metal/impl/MetalIVFFlat.h +193 -0
  32. data/vendor/faiss/faiss/impl/HNSW.cpp +556 -199
  33. data/vendor/faiss/faiss/impl/HNSW.h +51 -13
  34. data/vendor/faiss/faiss/impl/NSG.cpp +15 -11
  35. data/vendor/faiss/faiss/impl/Panorama.h +11 -0
  36. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +25 -2
  37. data/vendor/faiss/faiss/impl/RaBitQUtils.cpp +1 -1
  38. data/vendor/faiss/faiss/impl/RaBitQuantizer.cpp +7 -1
  39. data/vendor/faiss/faiss/impl/ResultHandler.h +1 -0
  40. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +271 -8
  41. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +50 -0
  42. data/vendor/faiss/faiss/impl/VisitedTable.cpp +10 -10
  43. data/vendor/faiss/faiss/impl/VisitedTable.h +69 -34
  44. data/vendor/faiss/faiss/impl/fast_scan/dispatching.h +3 -1
  45. data/vendor/faiss/faiss/impl/hnsw/MinimaxHeap.cpp +35 -43
  46. data/vendor/faiss/faiss/impl/hnsw/MinimaxHeap.h +64 -15
  47. data/vendor/faiss/faiss/impl/hnsw/avx2.cpp +86 -40
  48. data/vendor/faiss/faiss/impl/hnsw/avx512.cpp +81 -50
  49. data/vendor/faiss/faiss/impl/index_read.cpp +100 -39
  50. data/vendor/faiss/faiss/impl/index_write.cpp +1 -0
  51. data/vendor/faiss/faiss/impl/io_macros.h +25 -0
  52. data/vendor/faiss/faiss/impl/platform_macros.h +12 -8
  53. data/vendor/faiss/faiss/impl/pq_code_distance/avx2.cpp +2 -0
  54. data/vendor/faiss/faiss/impl/pq_code_distance/avx512.cpp +2 -0
  55. data/vendor/faiss/faiss/impl/pq_code_distance/neon.cpp +2 -0
  56. data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-generic.cpp +20 -0
  57. data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-inl.h +36 -0
  58. data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-sve.cpp +5 -0
  59. data/vendor/faiss/faiss/impl/pq_code_distance/pq_scan_impl.h +105 -0
  60. data/vendor/faiss/faiss/impl/pq_code_distance/rvv.cpp +2 -0
  61. data/vendor/faiss/faiss/impl/scalar_quantizer/distance_computers.h +6 -0
  62. data/vendor/faiss/faiss/impl/scalar_quantizer/quantizers.h +327 -18
  63. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx2.cpp +264 -27
  64. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx512-impl.h +553 -0
  65. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx512-spr.cpp +559 -0
  66. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx512.cpp +199 -27
  67. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-dispatch.h +366 -3
  68. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-neon.cpp +144 -19
  69. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-rvv.cpp +26 -0
  70. data/vendor/faiss/faiss/impl/simd_dispatch.h +65 -8
  71. data/vendor/faiss/faiss/index_factory.cpp +5 -1
  72. data/vendor/faiss/faiss/index_io.h +16 -0
  73. data/vendor/faiss/faiss/invlists/DirectMap.cpp +4 -1
  74. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +13 -13
  75. data/vendor/faiss/faiss/invlists/InvertedLists.h +2 -2
  76. data/vendor/faiss/faiss/svs/IndexSVSVamana.cpp +119 -22
  77. data/vendor/faiss/faiss/svs/IndexSVSVamana.h +15 -5
  78. data/vendor/faiss/faiss/svs/IndexSVSVamanaLVQ.cpp +3 -2
  79. data/vendor/faiss/faiss/svs/IndexSVSVamanaLVQ.h +2 -1
  80. data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.cpp +65 -24
  81. data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.h +3 -2
  82. data/vendor/faiss/faiss/utils/bf16.h +34 -0
  83. data/vendor/faiss/faiss/utils/distances_simd.cpp +0 -1
  84. data/vendor/faiss/faiss/utils/hamming.cpp +8 -8
  85. data/vendor/faiss/faiss/utils/hamming_distance/hamming_avx2.cpp +2 -1
  86. data/vendor/faiss/faiss/utils/hamming_distance/hamming_avx512_spr.cpp +15 -0
  87. data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-avx512.h +6 -30
  88. data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-avx512_spr.h +171 -0
  89. data/vendor/faiss/faiss/utils/partitioning.cpp +0 -2
  90. data/vendor/faiss/faiss/utils/simd_impl/partitioning_simdlib256.h +14 -68
  91. data/vendor/faiss/faiss/utils/simd_impl/rabitq_avx512_spr.cpp +343 -0
  92. data/vendor/faiss/faiss/utils/simd_levels.cpp +12 -2
  93. metadata +12 -2
@@ -0,0 +1,559 @@
1
+ /*
2
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ #ifdef COMPILE_SIMD_AVX512_SPR
9
+
10
+ #include <immintrin.h>
11
+
12
+ #include <faiss/impl/scalar_quantizer/codecs.h>
13
+ #include <faiss/impl/scalar_quantizer/distance_computers.h>
14
+ #include <faiss/impl/scalar_quantizer/quantizers.h>
15
+ #include <faiss/impl/scalar_quantizer/scanners.h>
16
+ #include <faiss/impl/scalar_quantizer/similarities.h>
17
+ #include <faiss/impl/simdlib/simdlib_avx512.h>
18
+
19
+ #include <faiss/impl/scalar_quantizer/sq-avx512-impl.h>
20
+
21
+ namespace faiss {
22
+ namespace scalar_quantizer {
23
+
24
+ /**********************************************************
25
+ * Codecs — inherit AVX512 implementations
26
+ **********************************************************/
27
+
28
+ template <>
29
+ struct Codec8bit<SIMDLevel::AVX512_SPR> : Codec8bit<SIMDLevel::AVX512> {};
30
+
31
+ template <>
32
+ struct Codec4bit<SIMDLevel::AVX512_SPR> : Codec4bit<SIMDLevel::AVX512> {};
33
+
34
+ template <>
35
+ struct Codec6bit<SIMDLevel::AVX512_SPR> : Codec6bit<SIMDLevel::AVX512> {};
36
+
37
+ /**********************************************************
38
+ * Quantizers — inherit AVX512 implementations
39
+ **********************************************************/
40
+
41
+ template <class Codec>
42
+ struct QuantizerTemplate<
43
+ Codec,
44
+ QuantizerTemplateScaling::UNIFORM,
45
+ SIMDLevel::AVX512_SPR>
46
+ : QuantizerTemplate<
47
+ Codec,
48
+ QuantizerTemplateScaling::UNIFORM,
49
+ SIMDLevel::AVX512> {
50
+ using QuantizerTemplate<
51
+ Codec,
52
+ QuantizerTemplateScaling::UNIFORM,
53
+ SIMDLevel::AVX512>::QuantizerTemplate;
54
+ };
55
+
56
+ template <class Codec>
57
+ struct QuantizerTemplate<
58
+ Codec,
59
+ QuantizerTemplateScaling::NON_UNIFORM,
60
+ SIMDLevel::AVX512_SPR>
61
+ : QuantizerTemplate<
62
+ Codec,
63
+ QuantizerTemplateScaling::NON_UNIFORM,
64
+ SIMDLevel::AVX512> {
65
+ using QuantizerTemplate<
66
+ Codec,
67
+ QuantizerTemplateScaling::NON_UNIFORM,
68
+ SIMDLevel::AVX512>::QuantizerTemplate;
69
+ };
70
+
71
+ template <>
72
+ struct QuantizerFP16<SIMDLevel::AVX512_SPR> : QuantizerFP16<SIMDLevel::AVX512> {
73
+ using QuantizerFP16<SIMDLevel::AVX512>::QuantizerFP16;
74
+ };
75
+
76
+ template <>
77
+ struct QuantizerBF16<SIMDLevel::AVX512_SPR> : QuantizerBF16<SIMDLevel::AVX512> {
78
+ using QuantizerBF16<SIMDLevel::AVX512>::QuantizerBF16;
79
+
80
+ void encode_vector(const float* x, uint8_t* code) const override {
81
+ encode_bf16_simd(x, (uint16_t*)code, this->d);
82
+ }
83
+
84
+ void decode_vector(const uint8_t* code, float* x) const override {
85
+ decode_bf16_simd((const uint16_t*)code, x, this->d);
86
+ }
87
+ };
88
+
89
+ template <>
90
+ struct Quantizer8bitDirect<SIMDLevel::AVX512_SPR>
91
+ : Quantizer8bitDirect<SIMDLevel::AVX512> {
92
+ using Quantizer8bitDirect<SIMDLevel::AVX512>::Quantizer8bitDirect;
93
+ };
94
+
95
+ template <>
96
+ struct Quantizer8bitDirectSigned<SIMDLevel::AVX512_SPR>
97
+ : Quantizer8bitDirectSigned<SIMDLevel::AVX512> {
98
+ using Quantizer8bitDirectSigned<
99
+ SIMDLevel::AVX512>::Quantizer8bitDirectSigned;
100
+ };
101
+
102
+ /**********************************************************
103
+ * TurboQuant MSE — inherit AVX512 implementations
104
+ **********************************************************/
105
+
106
+ template <int NBits>
107
+ struct QuantizerTurboQuantMSE<NBits, SIMDLevel::AVX512_SPR>
108
+ : QuantizerTurboQuantMSE<NBits, SIMDLevel::AVX512> {
109
+ using QuantizerTurboQuantMSE<NBits, SIMDLevel::AVX512>::
110
+ QuantizerTurboQuantMSE;
111
+ };
112
+
113
+ /**********************************************************
114
+ * Similarities — inherit AVX512 implementations
115
+ **********************************************************/
116
+
117
+ template <>
118
+ struct SimilarityL2<SIMDLevel::AVX512_SPR> : SimilarityL2<SIMDLevel::AVX512> {
119
+ using SimilarityL2<SIMDLevel::AVX512>::SimilarityL2;
120
+ static constexpr SIMDLevel simd_level = SIMDLevel::AVX512_SPR;
121
+ };
122
+
123
+ template <>
124
+ struct SimilarityIP<SIMDLevel::AVX512_SPR> : SimilarityIP<SIMDLevel::AVX512> {
125
+ using SimilarityIP<SIMDLevel::AVX512>::SimilarityIP;
126
+ static constexpr SIMDLevel simd_level = SIMDLevel::AVX512_SPR;
127
+ };
128
+
129
+ /**********************************************************
130
+ * Generic DCTemplate — delegate to AVX512 implementations
131
+ **********************************************************/
132
+
133
+ template <class Quantizer, class Similarity>
134
+ struct DCTemplate<Quantizer, Similarity, SIMDLevel::AVX512_SPR>
135
+ : DCTemplate<Quantizer, Similarity, SIMDLevel::AVX512> {
136
+ using DCTemplate<Quantizer, Similarity, SIMDLevel::AVX512>::DCTemplate;
137
+ };
138
+
139
+ /**********************************************************
140
+ * DistanceComputerByte: AVX512-VNNI
141
+ *
142
+ * Uses _mm512_dpbusd_epi32 to compute dot products of uint8 vectors
143
+ * at 64 bytes per instruction (4x throughput vs generic AVX512).
144
+ **********************************************************/
145
+
146
+ template <class Similarity>
147
+ struct DistanceComputerByte<Similarity, SIMDLevel::AVX512_SPR>
148
+ : SQDistanceComputer {
149
+ using Sim = Similarity;
150
+
151
+ int d;
152
+ std::vector<uint8_t> tmp;
153
+
154
+ DistanceComputerByte(int d, const std::vector<float>&) : d(d), tmp(d) {}
155
+
156
+ int compute_code_distance(const uint8_t* code1, const uint8_t* code2)
157
+ const {
158
+ if constexpr (Sim::metric_type == METRIC_INNER_PRODUCT) {
159
+ __m512i accu = _mm512_setzero_si512();
160
+ int i = 0;
161
+ for (; i + 64 <= d; i += 64) {
162
+ __m512i c1 = _mm512_loadu_si512(code1 + i);
163
+ __m512i c2 = _mm512_loadu_si512(code2 + i);
164
+
165
+ __m512i c2_signed = _mm512_sub_epi8(c2, _mm512_set1_epi8(-128));
166
+ accu = _mm512_dpbusd_epi32(accu, c1, c2_signed);
167
+ }
168
+ int32_t sum_c1 = 0;
169
+ for (int j = 0; j < i; j++) {
170
+ sum_c1 += code1[j];
171
+ }
172
+ int32_t result = _mm512_reduce_add_epi32(accu) + 128 * sum_c1;
173
+
174
+ for (; i < d; i++) {
175
+ result += int(code1[i]) * code2[i];
176
+ }
177
+ return result;
178
+ } else {
179
+ __m512i accu = _mm512_setzero_si512();
180
+ int i = 0;
181
+ for (; i + 64 <= d; i += 64) {
182
+ __m256i c1_lo = _mm256_loadu_si256((const __m256i*)(code1 + i));
183
+ __m256i c2_lo = _mm256_loadu_si256((const __m256i*)(code2 + i));
184
+ __m256i c1_hi =
185
+ _mm256_loadu_si256((const __m256i*)(code1 + i + 32));
186
+ __m256i c2_hi =
187
+ _mm256_loadu_si256((const __m256i*)(code2 + i + 32));
188
+
189
+ __m512i c1_16_lo = _mm512_cvtepu8_epi16(c1_lo);
190
+ __m512i c2_16_lo = _mm512_cvtepu8_epi16(c2_lo);
191
+ __m512i diff_lo = _mm512_sub_epi16(c1_16_lo, c2_16_lo);
192
+
193
+ __m512i c1_16_hi = _mm512_cvtepu8_epi16(c1_hi);
194
+ __m512i c2_16_hi = _mm512_cvtepu8_epi16(c2_hi);
195
+ __m512i diff_hi = _mm512_sub_epi16(c1_16_hi, c2_16_hi);
196
+
197
+ accu = _mm512_add_epi32(
198
+ accu, _mm512_madd_epi16(diff_lo, diff_lo));
199
+ accu = _mm512_add_epi32(
200
+ accu, _mm512_madd_epi16(diff_hi, diff_hi));
201
+ }
202
+ for (; i + 32 <= d; i += 32) {
203
+ __m256i c1v = _mm256_loadu_si256((const __m256i*)(code1 + i));
204
+ __m256i c2v = _mm256_loadu_si256((const __m256i*)(code2 + i));
205
+ __m512i c1_16 = _mm512_cvtepu8_epi16(c1v);
206
+ __m512i c2_16 = _mm512_cvtepu8_epi16(c2v);
207
+ __m512i diff = _mm512_sub_epi16(c1_16, c2_16);
208
+ accu = _mm512_add_epi32(accu, _mm512_madd_epi16(diff, diff));
209
+ }
210
+ int32_t result = _mm512_reduce_add_epi32(accu);
211
+
212
+ for (; i < d; i++) {
213
+ int diff = int(code1[i]) - code2[i];
214
+ result += diff * diff;
215
+ }
216
+ return result;
217
+ }
218
+ }
219
+
220
+ void set_query(const float* x) final {
221
+ for (int i = 0; i < d; i++) {
222
+ tmp[i] = int(x[i]);
223
+ }
224
+ }
225
+
226
+ int compute_distance(const float* x, const uint8_t* code) {
227
+ set_query(x);
228
+ return compute_code_distance(tmp.data(), code);
229
+ }
230
+
231
+ float symmetric_dis(idx_t i, idx_t j) override {
232
+ return compute_code_distance(
233
+ codes + i * code_size, codes + j * code_size);
234
+ }
235
+
236
+ float query_to_code(const uint8_t* code) const final {
237
+ return compute_code_distance(tmp.data(), code);
238
+ }
239
+ };
240
+
241
+ /**********************************************************
242
+ * DistanceComputerByteSigned: AVX512_SPR specialization for
243
+ * QT_8bit_direct_signed.
244
+ *
245
+ * Storage convention (see Quantizer8bitDirectSigned):
246
+ * stored_byte = value + 128, i.e. value = stored_byte - 128
247
+ *
248
+ * L2: (s_a - 128) - (s_b - 128) == s_a - s_b, so the unsigned
249
+ * widened-madd kernel is bit-exact for the signed variant.
250
+ *
251
+ * IP: (s_a - 128) * (s_b - 128)
252
+ * = s_a*s_b - 128*(s_a + s_b) + 16384
253
+ * summed over d components:
254
+ * sum_ip_signed = sum_ip_unsigned
255
+ * - 128 * (sum(s_a) + sum(s_b))
256
+ * + 16384 * d
257
+ * sum(s_a), sum(s_b) are cheap via _mm512_sad_epu8 against zero.
258
+ **********************************************************/
259
+
260
+ template <class Similarity>
261
+ struct DistanceComputerByteSigned<Similarity, SIMDLevel::AVX512_SPR>
262
+ : SQDistanceComputer {
263
+ using Sim = Similarity;
264
+
265
+ int d;
266
+ std::vector<uint8_t> tmp;
267
+
268
+ DistanceComputerByteSigned(int d, const std::vector<float>&)
269
+ : d(d), tmp(d) {}
270
+
271
+ int compute_code_distance(const uint8_t* code1, const uint8_t* code2)
272
+ const {
273
+ if constexpr (Sim::metric_type == METRIC_INNER_PRODUCT) {
274
+ __m512i accu = _mm512_setzero_si512();
275
+ __m512i sum_a = _mm512_setzero_si512();
276
+ __m512i sum_b = _mm512_setzero_si512();
277
+ const __m512i zero = _mm512_setzero_si512();
278
+ const __m512i bias = _mm512_set1_epi8(-128);
279
+
280
+ int i = 0;
281
+ for (; i + 64 <= d; i += 64) {
282
+ __m512i c1 = _mm512_loadu_si512(code1 + i);
283
+ __m512i c2 = _mm512_loadu_si512(code2 + i);
284
+
285
+ sum_a = _mm512_add_epi64(sum_a, _mm512_sad_epu8(c1, zero));
286
+ sum_b = _mm512_add_epi64(sum_b, _mm512_sad_epu8(c2, zero));
287
+
288
+ __m512i c2_signed = _mm512_sub_epi8(c2, bias);
289
+ accu = _mm512_dpbusd_epi32(accu, c1, c2_signed);
290
+ }
291
+ int32_t sum_c1_for_bias = int32_t(_mm512_reduce_add_epi64(sum_a));
292
+ int32_t result =
293
+ _mm512_reduce_add_epi32(accu) + 128 * sum_c1_for_bias;
294
+
295
+ int32_t tail_sum_a = 0, tail_sum_b = 0;
296
+ for (; i < d; ++i) {
297
+ result += int32_t(code1[i]) * int32_t(code2[i]);
298
+ tail_sum_a += code1[i];
299
+ tail_sum_b += code2[i];
300
+ }
301
+
302
+ int32_t total_sum_a = sum_c1_for_bias + tail_sum_a;
303
+ int32_t total_sum_b =
304
+ int32_t(_mm512_reduce_add_epi64(sum_b)) + tail_sum_b;
305
+ result -= 128 * (total_sum_a + total_sum_b);
306
+ result += 16384 * d;
307
+ return result;
308
+ } else {
309
+ __m512i accu = _mm512_setzero_si512();
310
+ int i = 0;
311
+ for (; i + 64 <= d; i += 64) {
312
+ __m256i c1_lo = _mm256_loadu_si256((const __m256i*)(code1 + i));
313
+ __m256i c2_lo = _mm256_loadu_si256((const __m256i*)(code2 + i));
314
+ __m256i c1_hi =
315
+ _mm256_loadu_si256((const __m256i*)(code1 + i + 32));
316
+ __m256i c2_hi =
317
+ _mm256_loadu_si256((const __m256i*)(code2 + i + 32));
318
+ __m512i diff_lo = _mm512_sub_epi16(
319
+ _mm512_cvtepu8_epi16(c1_lo),
320
+ _mm512_cvtepu8_epi16(c2_lo));
321
+ __m512i diff_hi = _mm512_sub_epi16(
322
+ _mm512_cvtepu8_epi16(c1_hi),
323
+ _mm512_cvtepu8_epi16(c2_hi));
324
+ accu = _mm512_add_epi32(
325
+ accu, _mm512_madd_epi16(diff_lo, diff_lo));
326
+ accu = _mm512_add_epi32(
327
+ accu, _mm512_madd_epi16(diff_hi, diff_hi));
328
+ }
329
+ for (; i + 32 <= d; i += 32) {
330
+ __m256i c1v = _mm256_loadu_si256((const __m256i*)(code1 + i));
331
+ __m256i c2v = _mm256_loadu_si256((const __m256i*)(code2 + i));
332
+ __m512i diff = _mm512_sub_epi16(
333
+ _mm512_cvtepu8_epi16(c1v), _mm512_cvtepu8_epi16(c2v));
334
+ accu = _mm512_add_epi32(accu, _mm512_madd_epi16(diff, diff));
335
+ }
336
+ int32_t result = _mm512_reduce_add_epi32(accu);
337
+ for (; i < d; ++i) {
338
+ int32_t diff = int32_t(code1[i]) - int32_t(code2[i]);
339
+ result += diff * diff;
340
+ }
341
+ return result;
342
+ }
343
+ }
344
+
345
+ void set_query(const float* x) final {
346
+ for (int i = 0; i < d; ++i) {
347
+ tmp[i] = uint8_t(int(x[i]) + 128);
348
+ }
349
+ }
350
+
351
+ int compute_distance(const float* x, const uint8_t* code) {
352
+ set_query(x);
353
+ return compute_code_distance(tmp.data(), code);
354
+ }
355
+
356
+ float symmetric_dis(idx_t i, idx_t j) override {
357
+ return compute_code_distance(
358
+ codes + i * code_size, codes + j * code_size);
359
+ }
360
+
361
+ float query_to_code(const uint8_t* code) const final {
362
+ return compute_code_distance(tmp.data(), code);
363
+ }
364
+ };
365
+
366
+ /**********************************************************
367
+ * BF16 native distance helpers using VDPBF16PS
368
+ **********************************************************/
369
+
370
+ static FAISS_ALWAYS_INLINE float bf16_vdpbf16ps(
371
+ const uint16_t* a,
372
+ const uint16_t* b,
373
+ size_t d) {
374
+ __m512 acc = _mm512_setzero_ps();
375
+ size_t i = 0;
376
+ for (; i + 32 <= d; i += 32) {
377
+ __m512bh va = (__m512bh)_mm512_loadu_epi16(a + i);
378
+ __m512bh vb = (__m512bh)_mm512_loadu_epi16(b + i);
379
+ acc = _mm512_dpbf16_ps(acc, va, vb);
380
+ }
381
+ // Remainder: 16 elements (d % 16 == 0 but may not be % 32)
382
+ if (i < d) {
383
+ __m256i a_lo = _mm256_loadu_epi16(a + i);
384
+ __m256i b_lo = _mm256_loadu_epi16(b + i);
385
+ __m512bh va =
386
+ (__m512bh)_mm512_inserti64x4(_mm512_setzero_si512(), a_lo, 0);
387
+ __m512bh vb =
388
+ (__m512bh)_mm512_inserti64x4(_mm512_setzero_si512(), b_lo, 0);
389
+ acc = _mm512_dpbf16_ps(acc, va, vb);
390
+ }
391
+ return _mm512_reduce_add_ps(acc);
392
+ }
393
+
394
+ static FAISS_ALWAYS_INLINE float bf16_L2_asymmetric(
395
+ const uint16_t* query_bf16,
396
+ const uint16_t* code,
397
+ size_t d) {
398
+ __m512 acc_qc = _mm512_setzero_ps();
399
+ __m512 acc_cc = _mm512_setzero_ps();
400
+ size_t i = 0;
401
+ for (; i + 32 <= d; i += 32) {
402
+ __m512bh vq = (__m512bh)_mm512_loadu_epi16(query_bf16 + i);
403
+ __m512bh vc = (__m512bh)_mm512_loadu_epi16(code + i);
404
+ acc_qc = _mm512_dpbf16_ps(acc_qc, vq, vc);
405
+ acc_cc = _mm512_dpbf16_ps(acc_cc, vc, vc);
406
+ }
407
+ if (i < d) {
408
+ __m256i q_lo = _mm256_loadu_epi16(query_bf16 + i);
409
+ __m256i c_lo = _mm256_loadu_epi16(code + i);
410
+ __m512bh vq =
411
+ (__m512bh)_mm512_inserti64x4(_mm512_setzero_si512(), q_lo, 0);
412
+ __m512bh vc =
413
+ (__m512bh)_mm512_inserti64x4(_mm512_setzero_si512(), c_lo, 0);
414
+ acc_qc = _mm512_dpbf16_ps(acc_qc, vq, vc);
415
+ acc_cc = _mm512_dpbf16_ps(acc_cc, vc, vc);
416
+ }
417
+ float dot_qc = _mm512_reduce_add_ps(acc_qc);
418
+ float norm_c = _mm512_reduce_add_ps(acc_cc);
419
+ return -2.0f * dot_qc + norm_c;
420
+ }
421
+
422
+ static FAISS_ALWAYS_INLINE float bf16_L2_symmetric(
423
+ const uint16_t* a,
424
+ const uint16_t* b,
425
+ size_t d) {
426
+ __m512 acc_ab = _mm512_setzero_ps();
427
+ __m512 acc_aa = _mm512_setzero_ps();
428
+ __m512 acc_bb = _mm512_setzero_ps();
429
+ size_t i = 0;
430
+ for (; i + 32 <= d; i += 32) {
431
+ __m512bh va = (__m512bh)_mm512_loadu_epi16(a + i);
432
+ __m512bh vb = (__m512bh)_mm512_loadu_epi16(b + i);
433
+ acc_ab = _mm512_dpbf16_ps(acc_ab, va, vb);
434
+ acc_aa = _mm512_dpbf16_ps(acc_aa, va, va);
435
+ acc_bb = _mm512_dpbf16_ps(acc_bb, vb, vb);
436
+ }
437
+ if (i < d) {
438
+ __m256i a_lo = _mm256_loadu_epi16(a + i);
439
+ __m256i b_lo = _mm256_loadu_epi16(b + i);
440
+ __m512bh va =
441
+ (__m512bh)_mm512_inserti64x4(_mm512_setzero_si512(), a_lo, 0);
442
+ __m512bh vb =
443
+ (__m512bh)_mm512_inserti64x4(_mm512_setzero_si512(), b_lo, 0);
444
+ acc_ab = _mm512_dpbf16_ps(acc_ab, va, vb);
445
+ acc_aa = _mm512_dpbf16_ps(acc_aa, va, va);
446
+ acc_bb = _mm512_dpbf16_ps(acc_bb, vb, vb);
447
+ }
448
+ return _mm512_reduce_add_ps(acc_aa) - 2.0f * _mm512_reduce_add_ps(acc_ab) +
449
+ _mm512_reduce_add_ps(acc_bb);
450
+ }
451
+
452
+ /**********************************************************
453
+ * BF16 + Inner Product distance computer (SPR)
454
+ **********************************************************/
455
+
456
+ struct DCBF16_IP : SQDistanceComputer {
457
+ using Sim = SimilarityIP<SIMDLevel::AVX512_SPR>;
458
+
459
+ size_t d;
460
+ std::vector<uint16_t> query_bf16;
461
+
462
+ DCBF16_IP(size_t d, const std::vector<float>&) : d(d), query_bf16(d) {}
463
+
464
+ void set_query(const float* x) final {
465
+ q = x;
466
+ encode_bf16_simd(x, query_bf16.data(), d);
467
+ }
468
+
469
+ float query_to_code(const uint8_t* code) const final {
470
+ return bf16_vdpbf16ps(query_bf16.data(), (const uint16_t*)code, d);
471
+ }
472
+
473
+ float symmetric_dis(idx_t i, idx_t j) override {
474
+ return bf16_vdpbf16ps(
475
+ (const uint16_t*)(codes + i * code_size),
476
+ (const uint16_t*)(codes + j * code_size),
477
+ d);
478
+ }
479
+ };
480
+
481
+ /**********************************************************
482
+ * BF16 + L2 distance computer (SPR)
483
+ **********************************************************/
484
+
485
+ struct DCBF16_L2 : SQDistanceComputer {
486
+ using Sim = SimilarityL2<SIMDLevel::AVX512_SPR>;
487
+
488
+ size_t d;
489
+ std::vector<uint16_t> query_bf16;
490
+ float query_norm_sq;
491
+
492
+ DCBF16_L2(size_t d, const std::vector<float>&)
493
+ : d(d), query_bf16(d), query_norm_sq(0) {}
494
+
495
+ void set_query(const float* x) final {
496
+ q = x;
497
+ encode_bf16_simd(x, query_bf16.data(), d);
498
+ query_norm_sq = bf16_vdpbf16ps(query_bf16.data(), query_bf16.data(), d);
499
+ }
500
+
501
+ float query_to_code(const uint8_t* code) const final {
502
+ return query_norm_sq +
503
+ bf16_L2_asymmetric(query_bf16.data(), (const uint16_t*)code, d);
504
+ }
505
+
506
+ float symmetric_dis(idx_t i, idx_t j) override {
507
+ return bf16_L2_symmetric(
508
+ (const uint16_t*)(codes + i * code_size),
509
+ (const uint16_t*)(codes + j * code_size),
510
+ d);
511
+ }
512
+ };
513
+
514
+ template <>
515
+ struct DCTemplate<
516
+ QuantizerBF16<SIMDLevel::AVX512_SPR>,
517
+ SimilarityIP<SIMDLevel::AVX512_SPR>,
518
+ SIMDLevel::AVX512_SPR> : DCBF16_IP {
519
+ using Sim = SimilarityIP<SIMDLevel::AVX512_SPR>;
520
+ using DCBF16_IP::DCBF16_IP;
521
+ };
522
+
523
+ template <>
524
+ struct DCTemplate<
525
+ QuantizerBF16<SIMDLevel::AVX512_SPR>,
526
+ SimilarityL2<SIMDLevel::AVX512_SPR>,
527
+ SIMDLevel::AVX512_SPR> : DCBF16_L2 {
528
+ using Sim = SimilarityL2<SIMDLevel::AVX512_SPR>;
529
+ using DCBF16_L2::DCBF16_L2;
530
+ };
531
+
532
+ /**********************************************************
533
+ * turboq_masked_sum — delegate to AVX512 implementation
534
+ **********************************************************/
535
+
536
+ template <SIMDLevel SL0>
537
+ float turboq_masked_sum(const float* arr, const uint8_t* bits, size_t d);
538
+
539
+ template <>
540
+ float turboq_masked_sum<SIMDLevel::AVX512>(
541
+ const float* arr,
542
+ const uint8_t* bits,
543
+ size_t d);
544
+
545
+ template <>
546
+ float turboq_masked_sum<SIMDLevel::AVX512_SPR>(
547
+ const float* arr,
548
+ const uint8_t* bits,
549
+ size_t d) {
550
+ return turboq_masked_sum<SIMDLevel::AVX512>(arr, bits, d);
551
+ }
552
+
553
+ } // namespace scalar_quantizer
554
+ } // namespace faiss
555
+
556
+ #define THE_LEVEL_TO_DISPATCH SIMDLevel::AVX512_SPR
557
+ #include <faiss/impl/scalar_quantizer/sq-dispatch.h>
558
+
559
+ #endif // COMPILE_SIMD_AVX512_SPR