faiss 0.6.1 → 0.6.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (93):
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +4 -0
  3. data/lib/faiss/version.rb +1 -1
  4. data/vendor/faiss/faiss/Index.h +1 -1
  5. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +6 -7
  6. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +3 -3
  7. data/vendor/faiss/faiss/IndexHNSW.cpp +173 -143
  8. data/vendor/faiss/faiss/IndexIVF.cpp +2 -2
  9. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +2 -2
  10. data/vendor/faiss/faiss/IndexIVFFlat.cpp +3 -1
  11. data/vendor/faiss/faiss/IndexIVFFlatPanorama.cpp +3 -3
  12. data/vendor/faiss/faiss/IndexIVFPQ.cpp +2 -3
  13. data/vendor/faiss/faiss/IndexIVFPQR.cpp +2 -3
  14. data/vendor/faiss/faiss/IndexIVFRaBitQ.cpp +4 -13
  15. data/vendor/faiss/faiss/IndexNNDescent.cpp +1 -1
  16. data/vendor/faiss/faiss/IndexNSG.cpp +1 -2
  17. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +68 -6
  18. data/vendor/faiss/faiss/IndexScalarQuantizer.h +10 -0
  19. data/vendor/faiss/faiss/cppcontrib/SaDecodeKernels.h +1 -1
  20. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-neon-inl.h +902 -12
  21. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-neon-inl.h +702 -10
  22. data/vendor/faiss/faiss/factory_tools.cpp +4 -0
  23. data/vendor/faiss/faiss/gpu/GpuResources.h +3 -2
  24. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +11 -12
  25. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +3 -3
  26. data/vendor/faiss/faiss/gpu_metal/MetalDistance.h +87 -0
  27. data/vendor/faiss/faiss/gpu_metal/MetalIndex.h +7 -0
  28. data/vendor/faiss/faiss/gpu_metal/MetalIndexIVFFlat.h +181 -0
  29. data/vendor/faiss/faiss/gpu_metal/MetalKernels.h +48 -3
  30. data/vendor/faiss/faiss/gpu_metal/MetalPythonBridge.h +45 -0
  31. data/vendor/faiss/faiss/gpu_metal/impl/MetalIVFFlat.h +193 -0
  32. data/vendor/faiss/faiss/impl/HNSW.cpp +556 -199
  33. data/vendor/faiss/faiss/impl/HNSW.h +51 -13
  34. data/vendor/faiss/faiss/impl/NSG.cpp +15 -11
  35. data/vendor/faiss/faiss/impl/Panorama.h +11 -0
  36. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +25 -2
  37. data/vendor/faiss/faiss/impl/RaBitQUtils.cpp +1 -1
  38. data/vendor/faiss/faiss/impl/RaBitQuantizer.cpp +7 -1
  39. data/vendor/faiss/faiss/impl/ResultHandler.h +1 -0
  40. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +271 -8
  41. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +50 -0
  42. data/vendor/faiss/faiss/impl/VisitedTable.cpp +10 -10
  43. data/vendor/faiss/faiss/impl/VisitedTable.h +69 -34
  44. data/vendor/faiss/faiss/impl/fast_scan/dispatching.h +3 -1
  45. data/vendor/faiss/faiss/impl/hnsw/MinimaxHeap.cpp +35 -43
  46. data/vendor/faiss/faiss/impl/hnsw/MinimaxHeap.h +64 -15
  47. data/vendor/faiss/faiss/impl/hnsw/avx2.cpp +86 -40
  48. data/vendor/faiss/faiss/impl/hnsw/avx512.cpp +81 -50
  49. data/vendor/faiss/faiss/impl/index_read.cpp +100 -39
  50. data/vendor/faiss/faiss/impl/index_write.cpp +1 -0
  51. data/vendor/faiss/faiss/impl/io_macros.h +25 -0
  52. data/vendor/faiss/faiss/impl/platform_macros.h +12 -8
  53. data/vendor/faiss/faiss/impl/pq_code_distance/avx2.cpp +2 -0
  54. data/vendor/faiss/faiss/impl/pq_code_distance/avx512.cpp +2 -0
  55. data/vendor/faiss/faiss/impl/pq_code_distance/neon.cpp +2 -0
  56. data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-generic.cpp +20 -0
  57. data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-inl.h +36 -0
  58. data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-sve.cpp +5 -0
  59. data/vendor/faiss/faiss/impl/pq_code_distance/pq_scan_impl.h +105 -0
  60. data/vendor/faiss/faiss/impl/pq_code_distance/rvv.cpp +2 -0
  61. data/vendor/faiss/faiss/impl/scalar_quantizer/distance_computers.h +6 -0
  62. data/vendor/faiss/faiss/impl/scalar_quantizer/quantizers.h +327 -18
  63. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx2.cpp +264 -27
  64. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx512-impl.h +553 -0
  65. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx512-spr.cpp +559 -0
  66. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx512.cpp +199 -27
  67. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-dispatch.h +366 -3
  68. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-neon.cpp +144 -19
  69. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-rvv.cpp +26 -0
  70. data/vendor/faiss/faiss/impl/simd_dispatch.h +65 -8
  71. data/vendor/faiss/faiss/index_factory.cpp +5 -1
  72. data/vendor/faiss/faiss/index_io.h +16 -0
  73. data/vendor/faiss/faiss/invlists/DirectMap.cpp +4 -1
  74. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +13 -13
  75. data/vendor/faiss/faiss/invlists/InvertedLists.h +2 -2
  76. data/vendor/faiss/faiss/svs/IndexSVSVamana.cpp +119 -22
  77. data/vendor/faiss/faiss/svs/IndexSVSVamana.h +15 -5
  78. data/vendor/faiss/faiss/svs/IndexSVSVamanaLVQ.cpp +3 -2
  79. data/vendor/faiss/faiss/svs/IndexSVSVamanaLVQ.h +2 -1
  80. data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.cpp +65 -24
  81. data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.h +3 -2
  82. data/vendor/faiss/faiss/utils/bf16.h +34 -0
  83. data/vendor/faiss/faiss/utils/distances_simd.cpp +0 -1
  84. data/vendor/faiss/faiss/utils/hamming.cpp +8 -8
  85. data/vendor/faiss/faiss/utils/hamming_distance/hamming_avx2.cpp +2 -1
  86. data/vendor/faiss/faiss/utils/hamming_distance/hamming_avx512_spr.cpp +15 -0
  87. data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-avx512.h +6 -30
  88. data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-avx512_spr.h +171 -0
  89. data/vendor/faiss/faiss/utils/partitioning.cpp +0 -2
  90. data/vendor/faiss/faiss/utils/simd_impl/partitioning_simdlib256.h +14 -68
  91. data/vendor/faiss/faiss/utils/simd_impl/rabitq_avx512_spr.cpp +343 -0
  92. data/vendor/faiss/faiss/utils/simd_levels.cpp +12 -2
  93. metadata +12 -2
@@ -0,0 +1,553 @@
1
+ /*
2
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ #pragma once
9
+
10
+ #include <cstring>
11
+
12
+ namespace faiss {
13
+
14
+ namespace scalar_quantizer {
15
+
16
+ using simd16float32 = faiss::simd16float32_tpl<SIMDLevel::AVX512>; // AVX512-level vector type returned by every decode/reconstruct routine below; its `.f` member is used as a __m512
17
+ using simd512bit = faiss::simd512bit_tpl<SIMDLevel::AVX512>; // AVX512-level alias of the generic 512-bit SIMD wrapper
18
+
19
+ /**********************************************************
20
+ * TurboQuant bit-unpacking helpers
21
+ **********************************************************/
22
+
23
+ namespace {
24
+
25
+ FAISS_ALWAYS_INLINE uint16_t load_u16(const uint8_t* ptr) {
26
+ uint16_t value;
27
+ std::memcpy(&value, ptr, sizeof(value));
28
+ return value;
29
+ }
30
+
31
+ FAISS_ALWAYS_INLINE uint32_t load_u32(const uint8_t* ptr) {
32
+ uint32_t value;
33
+ std::memcpy(&value, ptr, sizeof(value));
34
+ return value;
35
+ }
36
+
37
+ FAISS_ALWAYS_INLINE uint64_t load_u64(const uint8_t* ptr) {
38
+ uint64_t value;
39
+ std::memcpy(&value, ptr, sizeof(value));
40
+ return value;
41
+ }
42
+
43
+ FAISS_ALWAYS_INLINE uint32_t load_u24(const uint8_t* ptr) {
44
+ return static_cast<uint32_t>(ptr[0]) |
45
+ (static_cast<uint32_t>(ptr[1]) << 8) |
46
+ (static_cast<uint32_t>(ptr[2]) << 16);
47
+ }
48
+
49
+ FAISS_ALWAYS_INLINE __m256i unpack_8x3bit_to_u32(uint32_t packed) {
50
+ const __m256i shifts = _mm256_setr_epi32(0, 3, 6, 9, 12, 15, 18, 21);
51
+ const __m256i indices =
52
+ _mm256_srlv_epi32(_mm256_set1_epi32(packed), shifts);
53
+ return _mm256_and_si256(indices, _mm256_set1_epi32(0x7));
54
+ }
55
+
56
+ FAISS_ALWAYS_INLINE __m512i unpack_16x1bit_to_u32(const uint8_t* code, int i) {
57
+ const uint32_t packed = load_u16(code + (static_cast<size_t>(i) >> 3));
58
+ const __m512i shifts = _mm512_setr_epi32(
59
+ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15);
60
+ const __m512i indices =
61
+ _mm512_srlv_epi32(_mm512_set1_epi32(packed), shifts);
62
+ return _mm512_and_si512(indices, _mm512_set1_epi32(0x1));
63
+ }
64
+
65
+ FAISS_ALWAYS_INLINE __m512i unpack_16x2bit_to_u32(const uint8_t* code, int i) {
66
+ const uint32_t packed = load_u32(code + (static_cast<size_t>(i) >> 2));
67
+ const __m512i shifts = _mm512_setr_epi32(
68
+ 0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30);
69
+ const __m512i indices =
70
+ _mm512_srlv_epi32(_mm512_set1_epi32(packed), shifts);
71
+ return _mm512_and_si512(indices, _mm512_set1_epi32(0x3));
72
+ }
73
+
74
+ FAISS_ALWAYS_INLINE __m512i unpack_16x3bit_to_u32(const uint8_t* code, int i) {
75
+ const size_t byte_offset = (static_cast<size_t>(i) >> 4) * 6;
76
+ const __m256i low = unpack_8x3bit_to_u32(load_u24(code + byte_offset));
77
+ const __m256i high = unpack_8x3bit_to_u32(load_u24(code + byte_offset + 3));
78
+ __m512i indices = _mm512_castsi256_si512(low);
79
+ return _mm512_inserti32x8(indices, high, 1);
80
+ }
81
+
82
+ FAISS_ALWAYS_INLINE __m512i unpack_16x4bit_to_u32(const uint8_t* code, int i) {
83
+ const uint64_t packed = load_u64(code + (static_cast<size_t>(i) >> 1));
84
+ const __m256i shifts = _mm256_setr_epi32(0, 4, 8, 12, 16, 20, 24, 28);
85
+ const __m256i low = _mm256_and_si256(
86
+ _mm256_srlv_epi32(_mm256_set1_epi32((uint32_t)packed), shifts),
87
+ _mm256_set1_epi32(0xf));
88
+ const __m256i high = _mm256_and_si256(
89
+ _mm256_srlv_epi32(
90
+ _mm256_set1_epi32((uint32_t)(packed >> 32)), shifts),
91
+ _mm256_set1_epi32(0xf));
92
+ __m512i indices = _mm512_castsi256_si512(low);
93
+ return _mm512_inserti32x8(indices, high, 1);
94
+ }
95
+
96
+ } // namespace
97
+
98
+ /**********************************************************
99
+ * Codecs
100
+ **********************************************************/
101
+
102
+ template <>
103
+ struct Codec8bit<SIMDLevel::AVX512> : Codec8bit<SIMDLevel::NONE> {
104
+ static FAISS_ALWAYS_INLINE simd16float32
105
+ decode_16_components(const uint8_t* code, size_t i) {
106
+ const __m128i c16 = _mm_loadu_si128((__m128i*)(code + i));
107
+ const __m512i i32 = _mm512_cvtepu8_epi32(c16);
108
+ const __m512 f16 = _mm512_cvtepi32_ps(i32);
109
+ const __m512 half_one_255 = _mm512_set1_ps(0.5f / 255.f);
110
+ const __m512 one_255 = _mm512_set1_ps(1.f / 255.f);
111
+ return simd16float32(_mm512_fmadd_ps(f16, one_255, half_one_255));
112
+ }
113
+ };
114
+
115
+ template <>
116
+ struct Codec4bit<SIMDLevel::AVX512> : Codec4bit<SIMDLevel::NONE> {
117
+ static FAISS_ALWAYS_INLINE simd16float32
118
+ decode_16_components(const uint8_t* code, size_t i) {
119
+ uint64_t c8 = *(uint64_t*)(code + (i >> 1));
120
+ uint64_t mask = 0x0f0f0f0f0f0f0f0f;
121
+ uint64_t c8ev = c8 & mask;
122
+ uint64_t c8od = (c8 >> 4) & mask;
123
+
124
+ __m128i c16 =
125
+ _mm_unpacklo_epi8(_mm_set1_epi64x(c8ev), _mm_set1_epi64x(c8od));
126
+ __m256i c8lo = _mm256_cvtepu8_epi32(c16);
127
+ __m256i c8hi = _mm256_cvtepu8_epi32(_mm_srli_si128(c16, 8));
128
+ __m512i i16 = _mm512_castsi256_si512(c8lo);
129
+ i16 = _mm512_inserti32x8(i16, c8hi, 1);
130
+ __m512 f16 = _mm512_cvtepi32_ps(i16);
131
+ const __m512 half_one_255 = _mm512_set1_ps(0.5f / 15.f);
132
+ const __m512 one_255 = _mm512_set1_ps(1.f / 15.f);
133
+ return simd16float32(_mm512_fmadd_ps(f16, one_255, half_one_255));
134
+ }
135
+ };
136
+
137
+ template <>
138
+ struct Codec6bit<SIMDLevel::AVX512> : Codec6bit<SIMDLevel::NONE> {
139
+ static FAISS_ALWAYS_INLINE simd16float32
140
+ decode_16_components(const uint8_t* code, size_t i) {
141
+ // pure AVX512 implementation (not necessarily the fastest).
142
+ // see:
143
+ // https://github.com/zilliztech/knowhere/blob/main/thirdparty/faiss/faiss/impl/ScalarQuantizerCodec_avx512.h
144
+
145
+ // clang-format off
146
+
147
+ // 16 components, 16x6 bit=12 bytes
148
+ const __m128i bit_6v =
149
+ _mm_maskz_loadu_epi8(0b0000111111111111, code + (i >> 2) * 3);
150
+ const __m256i bit_6v_256 = _mm256_broadcast_i32x4(bit_6v);
151
+
152
+ // 00 01 02 03 04 05 06 07 08 09 0A 0B 0C 0D 0E 0F
153
+ // 00 01 02 03
154
+ const __m256i shuffle_mask = _mm256_setr_epi16(
155
+ 0xFF00, 0x0100, 0x0201, 0xFF02,
156
+ 0xFF03, 0x0403, 0x0504, 0xFF05,
157
+ 0xFF06, 0x0706, 0x0807, 0xFF08,
158
+ 0xFF09, 0x0A09, 0x0B0A, 0xFF0B);
159
+ const __m256i shuffled = _mm256_shuffle_epi8(bit_6v_256, shuffle_mask);
160
+
161
+ // 0: xxxxxxxx xx543210
162
+ // 1: xxxx5432 10xxxxxx
163
+ // 2: xxxxxx54 3210xxxx
164
+ // 3: xxxxxxxx 543210xx
165
+ const __m256i shift_right_v = _mm256_setr_epi16(
166
+ 0x0U, 0x6U, 0x4U, 0x2U,
167
+ 0x0U, 0x6U, 0x4U, 0x2U,
168
+ 0x0U, 0x6U, 0x4U, 0x2U,
169
+ 0x0U, 0x6U, 0x4U, 0x2U);
170
+ __m256i shuffled_shifted = _mm256_srlv_epi16(shuffled, shift_right_v);
171
+
172
+ // remove unneeded bits
173
+ shuffled_shifted =
174
+ _mm256_and_si256(shuffled_shifted, _mm256_set1_epi16(0x003F));
175
+
176
+ // scale
177
+ const __m512 f8 =
178
+ _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(shuffled_shifted));
179
+ const __m512 half_one_255 = _mm512_set1_ps(0.5f / 63.f);
180
+ const __m512 one_255 = _mm512_set1_ps(1.f / 63.f);
181
+ return simd16float32(_mm512_fmadd_ps(f8, one_255, half_one_255));
182
+
183
+ // clang-format on
184
+ }
185
+ };
186
+
187
+ /**********************************************************
188
+ * Quantizers (uniform and non-uniform)
189
+ **********************************************************/
190
+
191
+ template <class Codec>
192
+ struct QuantizerTemplate<
193
+ Codec,
194
+ scalar_quantizer::QuantizerTemplateScaling::UNIFORM,
195
+ SIMDLevel::AVX512>
196
+ : QuantizerTemplate<
197
+ Codec,
198
+ scalar_quantizer::QuantizerTemplateScaling::UNIFORM,
199
+ SIMDLevel::NONE> {
200
+ QuantizerTemplate(size_t d, const std::vector<float>& trained)
201
+ : QuantizerTemplate<
202
+ Codec,
203
+ scalar_quantizer::QuantizerTemplateScaling::UNIFORM,
204
+ SIMDLevel::NONE>(d, trained) {
205
+ assert(d % 16 == 0);
206
+ }
207
+
208
+ FAISS_ALWAYS_INLINE simd16float32
209
+ reconstruct_16_components(const uint8_t* code, int i) const {
210
+ __m512 xi = Codec::decode_16_components(code, i).f;
211
+ return simd16float32(_mm512_fmadd_ps(
212
+ xi, _mm512_set1_ps(this->vdiff), _mm512_set1_ps(this->vmin)));
213
+ }
214
+ };
215
+
216
+ template <class Codec>
217
+ struct QuantizerTemplate<
218
+ Codec,
219
+ scalar_quantizer::QuantizerTemplateScaling::NON_UNIFORM,
220
+ SIMDLevel::AVX512>
221
+ : QuantizerTemplate<
222
+ Codec,
223
+ scalar_quantizer::QuantizerTemplateScaling::NON_UNIFORM,
224
+ SIMDLevel::NONE> {
225
+ QuantizerTemplate(size_t d, const std::vector<float>& trained)
226
+ : QuantizerTemplate<
227
+ Codec,
228
+ scalar_quantizer::QuantizerTemplateScaling::NON_UNIFORM,
229
+ SIMDLevel::NONE>(d, trained) {
230
+ assert(d % 16 == 0);
231
+ }
232
+
233
+ FAISS_ALWAYS_INLINE simd16float32
234
+ reconstruct_16_components(const uint8_t* code, int i) const {
235
+ __m512 xi = Codec::decode_16_components(code, i).f;
236
+ return simd16float32(_mm512_fmadd_ps(
237
+ xi,
238
+ _mm512_loadu_ps(this->vdiff + i),
239
+ _mm512_loadu_ps(this->vmin + i)));
240
+ }
241
+ };
242
+
243
+ /**********************************************************
244
+ * TurboQuant MSE quantizer
245
+ **********************************************************/
246
+
247
+ #define DEFINE_TQMSE_AVX512_SPECIALIZATION(NBITS, INDEX_EXPR) \
248
+ template <> \
249
+ struct QuantizerTurboQuantMSE<NBITS, SIMDLevel::AVX512> \
250
+ : QuantizerTurboQuantMSE<NBITS, SIMDLevel::NONE> { \
251
+ using Base = QuantizerTurboQuantMSE<NBITS, SIMDLevel::NONE>; \
252
+ \
253
+ QuantizerTurboQuantMSE(size_t d, const std::vector<float>& trained) \
254
+ : Base(d, trained) { \
255
+ assert(d % 16 == 0); \
256
+ } \
257
+ \
258
+ FAISS_ALWAYS_INLINE simd16float32 \
259
+ reconstruct_16_components(const uint8_t* code, int i) const { \
260
+ const __m512i indices = (INDEX_EXPR); \
261
+ return simd16float32(_mm512_i32gather_ps( \
262
+ indices, this->centroids, sizeof(float))); \
263
+ } \
264
+ }
265
+
266
+ DEFINE_TQMSE_AVX512_SPECIALIZATION(1, unpack_16x1bit_to_u32(code, i));
267
+ DEFINE_TQMSE_AVX512_SPECIALIZATION(2, unpack_16x2bit_to_u32(code, i));
268
+ DEFINE_TQMSE_AVX512_SPECIALIZATION(3, unpack_16x3bit_to_u32(code, i));
269
+ DEFINE_TQMSE_AVX512_SPECIALIZATION(4, unpack_16x4bit_to_u32(code, i));
270
+
271
+ #undef DEFINE_TQMSE_AVX512_SPECIALIZATION
272
+
273
+ template <>
274
+ struct QuantizerTurboQuantMSE<8, SIMDLevel::AVX512>
275
+ : QuantizerTurboQuantMSE<8, SIMDLevel::NONE> {
276
+ using Base = QuantizerTurboQuantMSE<8, SIMDLevel::NONE>;
277
+
278
+ QuantizerTurboQuantMSE(size_t d, const std::vector<float>& trained)
279
+ : Base(d, trained) {
280
+ assert(d % 16 == 0);
281
+ }
282
+
283
+ FAISS_ALWAYS_INLINE simd16float32
284
+ reconstruct_16_components(const uint8_t* code, int i) const {
285
+ const __m128i packed = _mm_loadu_si128(
286
+ (const __m128i*)(code + static_cast<size_t>(i)));
287
+ const __m512i indices = _mm512_cvtepu8_epi32(packed);
288
+ return simd16float32(
289
+ _mm512_i32gather_ps(indices, this->centroids, sizeof(float)));
290
+ }
291
+ };
292
+
293
+ /**********************************************************
294
+ * FP16 Quantizer
295
+ **********************************************************/
296
+
297
+ template <>
298
+ struct QuantizerFP16<SIMDLevel::AVX512> : QuantizerFP16<SIMDLevel::NONE> {
299
+ QuantizerFP16(size_t d, const std::vector<float>& trained)
300
+ : QuantizerFP16<SIMDLevel::NONE>(d, trained) {
301
+ assert(d % 16 == 0);
302
+ }
303
+
304
+ FAISS_ALWAYS_INLINE simd16float32
305
+ reconstruct_16_components(const uint8_t* code, int i) const {
306
+ __m256i codei = _mm256_loadu_si256((const __m256i*)(code + 2 * i));
307
+ return simd16float32(_mm512_cvtph_ps(codei));
308
+ }
309
+ };
310
+
311
+ /**********************************************************
312
+ * BF16 Quantizer
313
+ **********************************************************/
314
+
315
+ template <>
316
+ struct QuantizerBF16<SIMDLevel::AVX512> : QuantizerBF16<SIMDLevel::NONE> {
317
+ QuantizerBF16(size_t d, const std::vector<float>& trained)
318
+ : QuantizerBF16<SIMDLevel::NONE>(d, trained) {
319
+ assert(d % 16 == 0);
320
+ }
321
+
322
+ FAISS_ALWAYS_INLINE simd16float32
323
+ reconstruct_16_components(const uint8_t* code, int i) const {
324
+ __m256i code_256i = _mm256_loadu_si256((const __m256i*)(code + 2 * i));
325
+ __m512i code_512i = _mm512_cvtepu16_epi32(code_256i);
326
+ code_512i = _mm512_slli_epi32(code_512i, 16);
327
+ return simd16float32(_mm512_castsi512_ps(code_512i));
328
+ }
329
+ };
330
+
331
+ /**********************************************************
332
+ * 8bit Direct Quantizer
333
+ **********************************************************/
334
+
335
+ template <>
336
+ struct Quantizer8bitDirect<SIMDLevel::AVX512>
337
+ : Quantizer8bitDirect<SIMDLevel::NONE> {
338
+ Quantizer8bitDirect(size_t d, const std::vector<float>& trained)
339
+ : Quantizer8bitDirect<SIMDLevel::NONE>(d, trained) {
340
+ assert(d % 16 == 0);
341
+ }
342
+
343
+ FAISS_ALWAYS_INLINE simd16float32
344
+ reconstruct_16_components(const uint8_t* code, int i) const {
345
+ __m128i x16 = _mm_loadu_si128((__m128i*)(code + i)); // 16 * int8
346
+ __m512i y16 = _mm512_cvtepu8_epi32(x16); // 16 * int32
347
+ return simd16float32(_mm512_cvtepi32_ps(y16)); // 16 * float32
348
+ }
349
+ };
350
+
351
+ /**********************************************************
352
+ * 8bit Direct Signed Quantizer
353
+ **********************************************************/
354
+
355
+ template <>
356
+ struct Quantizer8bitDirectSigned<SIMDLevel::AVX512>
357
+ : Quantizer8bitDirectSigned<SIMDLevel::NONE> {
358
+ Quantizer8bitDirectSigned(size_t d, const std::vector<float>& trained)
359
+ : Quantizer8bitDirectSigned<SIMDLevel::NONE>(d, trained) {
360
+ assert(d % 16 == 0);
361
+ }
362
+
363
+ FAISS_ALWAYS_INLINE simd16float32
364
+ reconstruct_16_components(const uint8_t* code, int i) const {
365
+ __m128i x16 = _mm_loadu_si128((__m128i*)(code + i)); // 16 * int8
366
+ __m512i y16 = _mm512_cvtepu8_epi32(x16); // 16 * int32
367
+ __m512i c16 = _mm512_set1_epi32(128);
368
+ __m512i z16 = _mm512_sub_epi32(y16, c16); // subtract 128 from all lanes
369
+ return simd16float32(_mm512_cvtepi32_ps(z16)); // 16 * float32
370
+ }
371
+ };
372
+
373
+ /**********************************************************
374
+ * Similarities (L2 and IP)
375
+ **********************************************************/
376
+
377
+ template <>
378
+ struct SimilarityL2<SIMDLevel::AVX512> {
379
+ static constexpr int simdwidth = 16;
380
+ static constexpr SIMDLevel simd_level = SIMDLevel::AVX512;
381
+ static constexpr MetricType metric_type = METRIC_L2;
382
+
383
+ const float *y, *yi;
384
+
385
+ explicit SimilarityL2(const float* y) : y(y), yi(nullptr) {}
386
+
387
+ simd16float32 accu16;
388
+
389
+ FAISS_ALWAYS_INLINE void begin_16() {
390
+ accu16.clear();
391
+ yi = y;
392
+ }
393
+
394
+ FAISS_ALWAYS_INLINE void add_16_components(simd16float32 x) {
395
+ simd16float32 yiv(yi);
396
+ yi += 16;
397
+ simd16float32 tmp = yiv - x;
398
+ accu16 = accu16 + tmp * tmp;
399
+ }
400
+
401
+ FAISS_ALWAYS_INLINE void add_16_components_2(
402
+ simd16float32 x,
403
+ simd16float32 y_2) {
404
+ simd16float32 tmp = y_2 - x;
405
+ accu16 = accu16 + tmp * tmp;
406
+ }
407
+
408
+ FAISS_ALWAYS_INLINE float result_16() {
409
+ return horizontal_add(accu16);
410
+ }
411
+ };
412
+
413
+ template <>
414
+ struct SimilarityIP<SIMDLevel::AVX512> {
415
+ static constexpr int simdwidth = 16;
416
+ static constexpr SIMDLevel simd_level = SIMDLevel::AVX512;
417
+ static constexpr MetricType metric_type = METRIC_INNER_PRODUCT;
418
+
419
+ const float *y, *yi;
420
+
421
+ explicit SimilarityIP(const float* y) : y(y), yi(nullptr) {}
422
+
423
+ simd16float32 accu16;
424
+
425
+ FAISS_ALWAYS_INLINE void begin_16() {
426
+ accu16.clear();
427
+ yi = y;
428
+ }
429
+
430
+ FAISS_ALWAYS_INLINE void add_16_components(simd16float32 x) {
431
+ simd16float32 yiv(yi);
432
+ yi += 16;
433
+ accu16 = accu16 + yiv * x;
434
+ }
435
+
436
+ FAISS_ALWAYS_INLINE void add_16_components_2(
437
+ simd16float32 x1,
438
+ simd16float32 x2) {
439
+ accu16 = accu16 + x1 * x2;
440
+ }
441
+
442
+ FAISS_ALWAYS_INLINE float result_16() {
443
+ return horizontal_add(accu16);
444
+ }
445
+ };
446
+
447
+ /**********************************************************
448
+ * Distance Computers
449
+ **********************************************************/
450
+
451
+ template <class Quantizer, class Similarity>
452
+ struct DCTemplate<Quantizer, Similarity, SIMDLevel::AVX512>
453
+ : SQDistanceComputer {
454
+ using Sim = Similarity;
455
+
456
+ Quantizer quant;
457
+
458
+ DCTemplate(size_t d, const std::vector<float>& trained)
459
+ : quant(d, trained) {}
460
+
461
+ float compute_distance(const float* x, const uint8_t* code) const {
462
+ Similarity sim(x);
463
+ sim.begin_16();
464
+ for (size_t i = 0; i < quant.d; i += 16) {
465
+ simd16float32 xi = quant.reconstruct_16_components(code, i);
466
+ sim.add_16_components(xi);
467
+ }
468
+ return sim.result_16();
469
+ }
470
+
471
+ float compute_code_distance(const uint8_t* code1, const uint8_t* code2)
472
+ const {
473
+ Similarity sim(nullptr);
474
+ sim.begin_16();
475
+ for (size_t i = 0; i < quant.d; i += 16) {
476
+ simd16float32 x1 = quant.reconstruct_16_components(code1, i);
477
+ simd16float32 x2 = quant.reconstruct_16_components(code2, i);
478
+ sim.add_16_components_2(x1, x2);
479
+ }
480
+ return sim.result_16();
481
+ }
482
+
483
+ void set_query(const float* x) final {
484
+ q = x;
485
+ }
486
+
487
+ float symmetric_dis(idx_t i, idx_t j) override {
488
+ return compute_code_distance(
489
+ codes + i * code_size, codes + j * code_size);
490
+ }
491
+
492
+ float query_to_code(const uint8_t* code) const final {
493
+ return compute_distance(q, code);
494
+ }
495
+ };
496
+
497
+ template <class Similarity>
498
+ struct DistanceComputerByte<Similarity, SIMDLevel::AVX512>
499
+ : SQDistanceComputer {
500
+ using Sim = Similarity;
501
+
502
+ int d;
503
+ std::vector<uint8_t> tmp;
504
+
505
+ DistanceComputerByte(int d, const std::vector<float>&) : d(d), tmp(d) {}
506
+
507
+ int compute_code_distance(const uint8_t* code1, const uint8_t* code2)
508
+ const {
509
+ // compute 16 lanes of 32-bit products (16-bytes) at once for
510
+ // the supported metrics
511
+ __m512i accu = _mm512_setzero_si512();
512
+ constexpr int kLanes = 16;
513
+ for (int i = 0; i < d; i += kLanes) {
514
+ __m128i c1 = _mm_loadu_si128((__m128i*)(code1 + i));
515
+ __m128i c2 = _mm_loadu_si128((__m128i*)(code2 + i));
516
+ __m512i c1i = _mm512_cvtepu8_epi32(c1);
517
+ __m512i c2i = _mm512_cvtepu8_epi32(c2);
518
+
519
+ __m512i v;
520
+ if (Sim::metric_type == METRIC_INNER_PRODUCT) {
521
+ v = _mm512_mullo_epi32(c1i, c2i);
522
+ } else {
523
+ __m512i diff = _mm512_sub_epi32(c1i, c2i);
524
+ v = _mm512_mullo_epi32(diff, diff);
525
+ }
526
+ accu = _mm512_add_epi32(accu, v);
527
+ }
528
+ return _mm512_reduce_add_epi32(accu);
529
+ }
530
+
531
+ void set_query(const float* x) final {
532
+ for (int i = 0; i < d; i++) {
533
+ tmp[i] = int(x[i]);
534
+ }
535
+ }
536
+
537
+ int compute_distance(const float* x, const uint8_t* code) {
538
+ set_query(x);
539
+ return compute_code_distance(tmp.data(), code);
540
+ }
541
+
542
+ float symmetric_dis(idx_t i, idx_t j) override {
543
+ return compute_code_distance(
544
+ codes + i * code_size, codes + j * code_size);
545
+ }
546
+
547
+ float query_to_code(const uint8_t* code) const final {
548
+ return compute_code_distance(tmp.data(), code);
549
+ }
550
+ };
551
+
552
+ } // namespace scalar_quantizer
553
+ } // namespace faiss