faiss 0.5.2 → 0.6.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (169) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +16 -0
  3. data/LICENSE.txt +1 -1
  4. data/ext/faiss/ext.cpp +1 -1
  5. data/ext/faiss/extconf.rb +5 -6
  6. data/ext/faiss/index_binary.cpp +76 -17
  7. data/ext/faiss/{index.cpp → index_rb.cpp} +108 -35
  8. data/ext/faiss/kmeans.cpp +12 -9
  9. data/ext/faiss/numo.hpp +11 -9
  10. data/ext/faiss/pca_matrix.cpp +10 -8
  11. data/ext/faiss/product_quantizer.cpp +14 -12
  12. data/ext/faiss/{utils.cpp → utils_rb.cpp} +10 -3
  13. data/ext/faiss/{utils.h → utils_rb.h} +6 -0
  14. data/lib/faiss/version.rb +1 -1
  15. data/lib/faiss.rb +1 -1
  16. data/vendor/faiss/faiss/AutoTune.cpp +130 -11
  17. data/vendor/faiss/faiss/AutoTune.h +14 -1
  18. data/vendor/faiss/faiss/Clustering.cpp +59 -10
  19. data/vendor/faiss/faiss/Clustering.h +12 -0
  20. data/vendor/faiss/faiss/IVFlib.cpp +31 -28
  21. data/vendor/faiss/faiss/Index.cpp +20 -8
  22. data/vendor/faiss/faiss/Index.h +25 -3
  23. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +19 -24
  24. data/vendor/faiss/faiss/IndexBinary.cpp +1 -0
  25. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +9 -4
  26. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +45 -11
  27. data/vendor/faiss/faiss/IndexFastScan.cpp +35 -22
  28. data/vendor/faiss/faiss/IndexFastScan.h +10 -1
  29. data/vendor/faiss/faiss/IndexFlat.cpp +193 -136
  30. data/vendor/faiss/faiss/IndexFlat.h +16 -1
  31. data/vendor/faiss/faiss/IndexFlatCodes.cpp +46 -22
  32. data/vendor/faiss/faiss/IndexFlatCodes.h +7 -1
  33. data/vendor/faiss/faiss/IndexHNSW.cpp +24 -50
  34. data/vendor/faiss/faiss/IndexHNSW.h +14 -12
  35. data/vendor/faiss/faiss/IndexIDMap.cpp +1 -1
  36. data/vendor/faiss/faiss/IndexIVF.cpp +76 -49
  37. data/vendor/faiss/faiss/IndexIVF.h +14 -4
  38. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +11 -8
  39. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +2 -2
  40. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +25 -14
  41. data/vendor/faiss/faiss/IndexIVFFastScan.h +26 -22
  42. data/vendor/faiss/faiss/IndexIVFFlat.cpp +10 -61
  43. data/vendor/faiss/faiss/IndexIVFFlatPanorama.cpp +39 -111
  44. data/vendor/faiss/faiss/IndexIVFPQ.cpp +89 -147
  45. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +37 -5
  46. data/vendor/faiss/faiss/IndexIVFPQR.cpp +2 -1
  47. data/vendor/faiss/faiss/IndexIVFRaBitQ.cpp +42 -30
  48. data/vendor/faiss/faiss/IndexIVFRaBitQ.h +2 -2
  49. data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.cpp +246 -97
  50. data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.h +32 -29
  51. data/vendor/faiss/faiss/IndexLSH.cpp +8 -6
  52. data/vendor/faiss/faiss/IndexLattice.cpp +29 -24
  53. data/vendor/faiss/faiss/IndexNNDescent.cpp +1 -0
  54. data/vendor/faiss/faiss/IndexNSG.cpp +2 -1
  55. data/vendor/faiss/faiss/IndexNSG.h +0 -2
  56. data/vendor/faiss/faiss/IndexNeuralNetCodec.cpp +1 -1
  57. data/vendor/faiss/faiss/IndexPQ.cpp +19 -10
  58. data/vendor/faiss/faiss/IndexRaBitQ.cpp +26 -13
  59. data/vendor/faiss/faiss/IndexRaBitQ.h +2 -2
  60. data/vendor/faiss/faiss/IndexRaBitQFastScan.cpp +132 -78
  61. data/vendor/faiss/faiss/IndexRaBitQFastScan.h +14 -12
  62. data/vendor/faiss/faiss/IndexRefine.cpp +0 -30
  63. data/vendor/faiss/faiss/IndexShards.cpp +3 -4
  64. data/vendor/faiss/faiss/MetricType.h +16 -0
  65. data/vendor/faiss/faiss/VectorTransform.cpp +120 -0
  66. data/vendor/faiss/faiss/VectorTransform.h +23 -0
  67. data/vendor/faiss/faiss/clone_index.cpp +7 -4
  68. data/vendor/faiss/faiss/{cppcontrib/factory_tools.cpp → factory_tools.cpp} +1 -1
  69. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +1 -1
  70. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +37 -11
  71. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +0 -28
  72. data/vendor/faiss/faiss/impl/ClusteringInitialization.cpp +367 -0
  73. data/vendor/faiss/faiss/impl/ClusteringInitialization.h +107 -0
  74. data/vendor/faiss/faiss/impl/CodePacker.cpp +4 -0
  75. data/vendor/faiss/faiss/impl/CodePacker.h +11 -3
  76. data/vendor/faiss/faiss/impl/CodePackerRaBitQ.cpp +83 -0
  77. data/vendor/faiss/faiss/impl/CodePackerRaBitQ.h +47 -0
  78. data/vendor/faiss/faiss/impl/FaissAssert.h +60 -2
  79. data/vendor/faiss/faiss/impl/HNSW.cpp +25 -34
  80. data/vendor/faiss/faiss/impl/HNSW.h +8 -6
  81. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +34 -27
  82. data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -1
  83. data/vendor/faiss/faiss/impl/NSG.cpp +6 -5
  84. data/vendor/faiss/faiss/impl/NSG.h +17 -7
  85. data/vendor/faiss/faiss/impl/Panorama.cpp +53 -46
  86. data/vendor/faiss/faiss/impl/Panorama.h +22 -6
  87. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +16 -5
  88. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +70 -58
  89. data/vendor/faiss/faiss/impl/RaBitQUtils.cpp +92 -0
  90. data/vendor/faiss/faiss/impl/RaBitQUtils.h +93 -31
  91. data/vendor/faiss/faiss/impl/RaBitQuantizer.cpp +12 -28
  92. data/vendor/faiss/faiss/impl/RaBitQuantizer.h +3 -10
  93. data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.cpp +15 -41
  94. data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.h +0 -4
  95. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +14 -9
  96. data/vendor/faiss/faiss/impl/ResultHandler.h +131 -50
  97. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +67 -2358
  98. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +0 -2
  99. data/vendor/faiss/faiss/impl/VisitedTable.cpp +42 -0
  100. data/vendor/faiss/faiss/impl/VisitedTable.h +69 -0
  101. data/vendor/faiss/faiss/impl/expanded_scanners.h +158 -0
  102. data/vendor/faiss/faiss/impl/index_read.cpp +829 -471
  103. data/vendor/faiss/faiss/impl/index_read_utils.h +0 -1
  104. data/vendor/faiss/faiss/impl/index_write.cpp +17 -8
  105. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +47 -20
  106. data/vendor/faiss/faiss/impl/mapped_io.cpp +9 -2
  107. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +7 -2
  108. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +11 -3
  109. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +19 -13
  110. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +29 -21
  111. data/vendor/faiss/faiss/impl/{code_distance/code_distance-avx2.h → pq_code_distance/pq_code_distance-avx2.cpp} +42 -215
  112. data/vendor/faiss/faiss/impl/{code_distance/code_distance-avx512.h → pq_code_distance/pq_code_distance-avx512.cpp} +68 -107
  113. data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-generic.cpp +141 -0
  114. data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-inl.h +23 -0
  115. data/vendor/faiss/faiss/impl/{code_distance/code_distance-sve.h → pq_code_distance/pq_code_distance-sve.cpp} +57 -144
  116. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +9 -6
  117. data/vendor/faiss/faiss/impl/scalar_quantizer/codecs.h +121 -0
  118. data/vendor/faiss/faiss/impl/scalar_quantizer/distance_computers.h +136 -0
  119. data/vendor/faiss/faiss/impl/scalar_quantizer/quantizers.h +280 -0
  120. data/vendor/faiss/faiss/impl/scalar_quantizer/scanners.h +164 -0
  121. data/vendor/faiss/faiss/impl/scalar_quantizer/similarities.h +94 -0
  122. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx2.cpp +455 -0
  123. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx512.cpp +430 -0
  124. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-dispatch.h +329 -0
  125. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-neon.cpp +467 -0
  126. data/vendor/faiss/faiss/impl/scalar_quantizer/training.cpp +203 -0
  127. data/vendor/faiss/faiss/impl/scalar_quantizer/training.h +42 -0
  128. data/vendor/faiss/faiss/impl/simd_dispatch.h +139 -0
  129. data/vendor/faiss/faiss/impl/simd_result_handlers.h +18 -18
  130. data/vendor/faiss/faiss/index_factory.cpp +35 -16
  131. data/vendor/faiss/faiss/index_io.h +29 -3
  132. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +7 -4
  133. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +1 -1
  134. data/vendor/faiss/faiss/svs/IndexSVSFaissUtils.h +9 -19
  135. data/vendor/faiss/faiss/svs/IndexSVSFlat.h +2 -0
  136. data/vendor/faiss/faiss/svs/IndexSVSVamana.h +2 -1
  137. data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.cpp +9 -1
  138. data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.h +9 -0
  139. data/vendor/faiss/faiss/utils/Heap.cpp +46 -0
  140. data/vendor/faiss/faiss/utils/Heap.h +21 -0
  141. data/vendor/faiss/faiss/utils/NeuralNet.cpp +10 -7
  142. data/vendor/faiss/faiss/utils/distances.cpp +141 -23
  143. data/vendor/faiss/faiss/utils/distances.h +98 -0
  144. data/vendor/faiss/faiss/utils/distances_dispatch.h +170 -0
  145. data/vendor/faiss/faiss/utils/distances_simd.cpp +74 -3511
  146. data/vendor/faiss/faiss/utils/extra_distances-inl.h +164 -157
  147. data/vendor/faiss/faiss/utils/extra_distances.cpp +52 -95
  148. data/vendor/faiss/faiss/utils/extra_distances.h +47 -1
  149. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +0 -1
  150. data/vendor/faiss/faiss/utils/partitioning.cpp +1 -1
  151. data/vendor/faiss/faiss/utils/pq_code_distance.h +251 -0
  152. data/vendor/faiss/faiss/utils/rabitq_simd.h +260 -0
  153. data/vendor/faiss/faiss/utils/simd_impl/distances_aarch64.cpp +150 -0
  154. data/vendor/faiss/faiss/utils/simd_impl/distances_arm_sve.cpp +568 -0
  155. data/vendor/faiss/faiss/utils/simd_impl/distances_autovec-inl.h +153 -0
  156. data/vendor/faiss/faiss/utils/simd_impl/distances_avx2.cpp +1185 -0
  157. data/vendor/faiss/faiss/utils/simd_impl/distances_avx512.cpp +1092 -0
  158. data/vendor/faiss/faiss/utils/simd_impl/distances_sse-inl.h +391 -0
  159. data/vendor/faiss/faiss/utils/simd_levels.cpp +322 -0
  160. data/vendor/faiss/faiss/utils/simd_levels.h +91 -0
  161. data/vendor/faiss/faiss/utils/simdlib_avx2.h +12 -1
  162. data/vendor/faiss/faiss/utils/simdlib_avx512.h +69 -0
  163. data/vendor/faiss/faiss/utils/simdlib_neon.h +6 -0
  164. data/vendor/faiss/faiss/utils/sorting.cpp +4 -4
  165. data/vendor/faiss/faiss/utils/utils.cpp +16 -9
  166. metadata +47 -18
  167. data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +0 -81
  168. data/vendor/faiss/faiss/impl/code_distance/code_distance.h +0 -186
  169. /data/vendor/faiss/faiss/{cppcontrib/factory_tools.h → factory_tools.h} +0 -0
@@ -0,0 +1,430 @@
1
+ /*
2
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ #ifdef COMPILE_SIMD_AVX512
9
+
10
+ #include <faiss/impl/scalar_quantizer/codecs.h>
11
+ #include <faiss/impl/scalar_quantizer/distance_computers.h>
12
+ #include <faiss/impl/scalar_quantizer/quantizers.h>
13
+ #include <faiss/impl/scalar_quantizer/scanners.h>
14
+ #include <faiss/impl/scalar_quantizer/similarities.h>
15
+
16
+ namespace faiss {
17
+
18
+ namespace scalar_quantizer {
19
+
20
+ /**********************************************************
21
+ * Codecs
22
+ **********************************************************/
23
+
24
+ template <>
25
+ struct Codec8bit<SIMDLevel::AVX512> : Codec8bit<SIMDLevel::NONE> {
26
+ static FAISS_ALWAYS_INLINE simd16float32
27
+ decode_16_components(const uint8_t* code, size_t i) {
28
+ const __m128i c16 = _mm_loadu_si128((__m128i*)(code + i));
29
+ const __m512i i32 = _mm512_cvtepu8_epi32(c16);
30
+ const __m512 f16 = _mm512_cvtepi32_ps(i32);
31
+ const __m512 half_one_255 = _mm512_set1_ps(0.5f / 255.f);
32
+ const __m512 one_255 = _mm512_set1_ps(1.f / 255.f);
33
+ return simd16float32(_mm512_fmadd_ps(f16, one_255, half_one_255));
34
+ }
35
+ };
36
+
37
+ template <>
38
+ struct Codec4bit<SIMDLevel::AVX512> : Codec4bit<SIMDLevel::NONE> {
39
+ static FAISS_ALWAYS_INLINE simd16float32
40
+ decode_16_components(const uint8_t* code, size_t i) {
41
+ uint64_t c8 = *(uint64_t*)(code + (i >> 1));
42
+ uint64_t mask = 0x0f0f0f0f0f0f0f0f;
43
+ uint64_t c8ev = c8 & mask;
44
+ uint64_t c8od = (c8 >> 4) & mask;
45
+
46
+ __m128i c16 =
47
+ _mm_unpacklo_epi8(_mm_set1_epi64x(c8ev), _mm_set1_epi64x(c8od));
48
+ __m256i c8lo = _mm256_cvtepu8_epi32(c16);
49
+ __m256i c8hi = _mm256_cvtepu8_epi32(_mm_srli_si128(c16, 8));
50
+ __m512i i16 = _mm512_castsi256_si512(c8lo);
51
+ i16 = _mm512_inserti32x8(i16, c8hi, 1);
52
+ __m512 f16 = _mm512_cvtepi32_ps(i16);
53
+ const __m512 half_one_255 = _mm512_set1_ps(0.5f / 15.f);
54
+ const __m512 one_255 = _mm512_set1_ps(1.f / 15.f);
55
+ return simd16float32(_mm512_fmadd_ps(f16, one_255, half_one_255));
56
+ }
57
+ };
58
+
59
+ template <>
60
+ struct Codec6bit<SIMDLevel::AVX512> : Codec6bit<SIMDLevel::NONE> {
61
+ static FAISS_ALWAYS_INLINE simd16float32
62
+ decode_16_components(const uint8_t* code, size_t i) {
63
+ // pure AVX512 implementation (not necessarily the fastest).
64
+ // see:
65
+ // https://github.com/zilliztech/knowhere/blob/main/thirdparty/faiss/faiss/impl/ScalarQuantizerCodec_avx512.h
66
+
67
+ // clang-format off
68
+
69
+ // 16 components, 16x6 bit=12 bytes
70
+ const __m128i bit_6v =
71
+ _mm_maskz_loadu_epi8(0b0000111111111111, code + (i >> 2) * 3);
72
+ const __m256i bit_6v_256 = _mm256_broadcast_i32x4(bit_6v);
73
+
74
+ // 00 01 02 03 04 05 06 07 08 09 0A 0B 0C 0D 0E 0F
75
+ // 00 01 02 03
76
+ const __m256i shuffle_mask = _mm256_setr_epi16(
77
+ 0xFF00, 0x0100, 0x0201, 0xFF02,
78
+ 0xFF03, 0x0403, 0x0504, 0xFF05,
79
+ 0xFF06, 0x0706, 0x0807, 0xFF08,
80
+ 0xFF09, 0x0A09, 0x0B0A, 0xFF0B);
81
+ const __m256i shuffled = _mm256_shuffle_epi8(bit_6v_256, shuffle_mask);
82
+
83
+ // 0: xxxxxxxx xx543210
84
+ // 1: xxxx5432 10xxxxxx
85
+ // 2: xxxxxx54 3210xxxx
86
+ // 3: xxxxxxxx 543210xx
87
+ const __m256i shift_right_v = _mm256_setr_epi16(
88
+ 0x0U, 0x6U, 0x4U, 0x2U,
89
+ 0x0U, 0x6U, 0x4U, 0x2U,
90
+ 0x0U, 0x6U, 0x4U, 0x2U,
91
+ 0x0U, 0x6U, 0x4U, 0x2U);
92
+ __m256i shuffled_shifted = _mm256_srlv_epi16(shuffled, shift_right_v);
93
+
94
+ // remove unneeded bits
95
+ shuffled_shifted =
96
+ _mm256_and_si256(shuffled_shifted, _mm256_set1_epi16(0x003F));
97
+
98
+ // scale
99
+ const __m512 f8 =
100
+ _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(shuffled_shifted));
101
+ const __m512 half_one_255 = _mm512_set1_ps(0.5f / 63.f);
102
+ const __m512 one_255 = _mm512_set1_ps(1.f / 63.f);
103
+ return simd16float32(_mm512_fmadd_ps(f8, one_255, half_one_255));
104
+
105
+ // clang-format on
106
+ }
107
+ };
108
+
109
+ /**********************************************************
110
+ * Quantizers (uniform and non-uniform)
111
+ **********************************************************/
112
+
113
+ template <class Codec>
114
+ struct QuantizerTemplate<
115
+ Codec,
116
+ scalar_quantizer::QuantizerTemplateScaling::UNIFORM,
117
+ SIMDLevel::AVX512>
118
+ : QuantizerTemplate<
119
+ Codec,
120
+ scalar_quantizer::QuantizerTemplateScaling::UNIFORM,
121
+ SIMDLevel::NONE> {
122
+ QuantizerTemplate(size_t d, const std::vector<float>& trained)
123
+ : QuantizerTemplate<
124
+ Codec,
125
+ scalar_quantizer::QuantizerTemplateScaling::UNIFORM,
126
+ SIMDLevel::NONE>(d, trained) {
127
+ assert(d % 16 == 0);
128
+ }
129
+
130
+ FAISS_ALWAYS_INLINE simd16float32
131
+ reconstruct_16_components(const uint8_t* code, int i) const {
132
+ __m512 xi = Codec::decode_16_components(code, i).f;
133
+ return simd16float32(_mm512_fmadd_ps(
134
+ xi, _mm512_set1_ps(this->vdiff), _mm512_set1_ps(this->vmin)));
135
+ }
136
+ };
137
+
138
+ template <class Codec>
139
+ struct QuantizerTemplate<
140
+ Codec,
141
+ scalar_quantizer::QuantizerTemplateScaling::NON_UNIFORM,
142
+ SIMDLevel::AVX512>
143
+ : QuantizerTemplate<
144
+ Codec,
145
+ scalar_quantizer::QuantizerTemplateScaling::NON_UNIFORM,
146
+ SIMDLevel::NONE> {
147
+ QuantizerTemplate(size_t d, const std::vector<float>& trained)
148
+ : QuantizerTemplate<
149
+ Codec,
150
+ scalar_quantizer::QuantizerTemplateScaling::NON_UNIFORM,
151
+ SIMDLevel::NONE>(d, trained) {
152
+ assert(d % 16 == 0);
153
+ }
154
+
155
+ FAISS_ALWAYS_INLINE simd16float32
156
+ reconstruct_16_components(const uint8_t* code, int i) const {
157
+ __m512 xi = Codec::decode_16_components(code, i).f;
158
+ return simd16float32(_mm512_fmadd_ps(
159
+ xi,
160
+ _mm512_loadu_ps(this->vdiff + i),
161
+ _mm512_loadu_ps(this->vmin + i)));
162
+ }
163
+ };
164
+
165
+ /**********************************************************
166
+ * FP16 Quantizer
167
+ **********************************************************/
168
+
169
+ template <>
170
+ struct QuantizerFP16<SIMDLevel::AVX512> : QuantizerFP16<SIMDLevel::NONE> {
171
+ QuantizerFP16(size_t d, const std::vector<float>& trained)
172
+ : QuantizerFP16<SIMDLevel::NONE>(d, trained) {
173
+ assert(d % 16 == 0);
174
+ }
175
+
176
+ FAISS_ALWAYS_INLINE simd16float32
177
+ reconstruct_16_components(const uint8_t* code, int i) const {
178
+ __m256i codei = _mm256_loadu_si256((const __m256i*)(code + 2 * i));
179
+ return simd16float32(_mm512_cvtph_ps(codei));
180
+ }
181
+ };
182
+
183
+ /**********************************************************
184
+ * BF16 Quantizer
185
+ **********************************************************/
186
+
187
+ template <>
188
+ struct QuantizerBF16<SIMDLevel::AVX512> : QuantizerBF16<SIMDLevel::NONE> {
189
+ QuantizerBF16(size_t d, const std::vector<float>& trained)
190
+ : QuantizerBF16<SIMDLevel::NONE>(d, trained) {
191
+ assert(d % 16 == 0);
192
+ }
193
+
194
+ FAISS_ALWAYS_INLINE simd16float32
195
+ reconstruct_16_components(const uint8_t* code, int i) const {
196
+ __m256i code_256i = _mm256_loadu_si256((const __m256i*)(code + 2 * i));
197
+ __m512i code_512i = _mm512_cvtepu16_epi32(code_256i);
198
+ code_512i = _mm512_slli_epi32(code_512i, 16);
199
+ return simd16float32(_mm512_castsi512_ps(code_512i));
200
+ }
201
+ };
202
+
203
+ /**********************************************************
204
+ * 8bit Direct Quantizer
205
+ **********************************************************/
206
+
207
+ template <>
208
+ struct Quantizer8bitDirect<SIMDLevel::AVX512>
209
+ : Quantizer8bitDirect<SIMDLevel::NONE> {
210
+ Quantizer8bitDirect(size_t d, const std::vector<float>& trained)
211
+ : Quantizer8bitDirect<SIMDLevel::NONE>(d, trained) {
212
+ assert(d % 16 == 0);
213
+ }
214
+
215
+ FAISS_ALWAYS_INLINE simd16float32
216
+ reconstruct_16_components(const uint8_t* code, int i) const {
217
+ __m128i x16 = _mm_loadu_si128((__m128i*)(code + i)); // 16 * int8
218
+ __m512i y16 = _mm512_cvtepu8_epi32(x16); // 16 * int32
219
+ return simd16float32(_mm512_cvtepi32_ps(y16)); // 16 * float32
220
+ }
221
+ };
222
+
223
+ /**********************************************************
224
+ * 8bit Direct Signed Quantizer
225
+ **********************************************************/
226
+
227
+ template <>
228
+ struct Quantizer8bitDirectSigned<SIMDLevel::AVX512>
229
+ : Quantizer8bitDirectSigned<SIMDLevel::NONE> {
230
+ Quantizer8bitDirectSigned(size_t d, const std::vector<float>& trained)
231
+ : Quantizer8bitDirectSigned<SIMDLevel::NONE>(d, trained) {
232
+ assert(d % 16 == 0);
233
+ }
234
+
235
+ FAISS_ALWAYS_INLINE simd16float32
236
+ reconstruct_16_components(const uint8_t* code, int i) const {
237
+ __m128i x16 = _mm_loadu_si128((__m128i*)(code + i)); // 16 * int8
238
+ __m512i y16 = _mm512_cvtepu8_epi32(x16); // 16 * int32
239
+ __m512i c16 = _mm512_set1_epi32(128);
240
+ __m512i z16 = _mm512_sub_epi32(y16, c16); // subtract 128 from all lanes
241
+ return simd16float32(_mm512_cvtepi32_ps(z16)); // 16 * float32
242
+ }
243
+ };
244
+
245
+ /**********************************************************
246
+ * Similarities (L2 and IP)
247
+ **********************************************************/
248
+
249
+ template <>
250
+ struct SimilarityL2<SIMDLevel::AVX512> {
251
+ static constexpr int simdwidth = 16;
252
+ static constexpr SIMDLevel simd_level = SIMDLevel::AVX512;
253
+ static constexpr MetricType metric_type = METRIC_L2;
254
+
255
+ const float *y, *yi;
256
+
257
+ explicit SimilarityL2(const float* y) : y(y), yi(nullptr) {}
258
+
259
+ simd16float32 accu16;
260
+
261
+ FAISS_ALWAYS_INLINE void begin_16() {
262
+ accu16.clear();
263
+ yi = y;
264
+ }
265
+
266
+ FAISS_ALWAYS_INLINE void add_16_components(simd16float32 x) {
267
+ simd16float32 yiv(yi);
268
+ yi += 16;
269
+ simd16float32 tmp = yiv - x;
270
+ accu16 = accu16 + tmp * tmp;
271
+ }
272
+
273
+ FAISS_ALWAYS_INLINE void add_16_components_2(
274
+ simd16float32 x,
275
+ simd16float32 y_2) {
276
+ simd16float32 tmp = y_2 - x;
277
+ accu16 = accu16 + tmp * tmp;
278
+ }
279
+
280
+ FAISS_ALWAYS_INLINE float result_16() {
281
+ return horizontal_add(accu16);
282
+ }
283
+ };
284
+
285
+ template <>
286
+ struct SimilarityIP<SIMDLevel::AVX512> {
287
+ static constexpr int simdwidth = 16;
288
+ static constexpr SIMDLevel simd_level = SIMDLevel::AVX512;
289
+ static constexpr MetricType metric_type = METRIC_INNER_PRODUCT;
290
+
291
+ const float *y, *yi;
292
+
293
+ explicit SimilarityIP(const float* y) : y(y), yi(nullptr) {}
294
+
295
+ simd16float32 accu16;
296
+
297
+ FAISS_ALWAYS_INLINE void begin_16() {
298
+ accu16.clear();
299
+ yi = y;
300
+ }
301
+
302
+ FAISS_ALWAYS_INLINE void add_16_components(simd16float32 x) {
303
+ simd16float32 yiv(yi);
304
+ yi += 16;
305
+ accu16 = accu16 + yiv * x;
306
+ }
307
+
308
+ FAISS_ALWAYS_INLINE void add_16_components_2(
309
+ simd16float32 x1,
310
+ simd16float32 x2) {
311
+ accu16 = accu16 + x1 * x2;
312
+ }
313
+
314
+ FAISS_ALWAYS_INLINE float result_16() {
315
+ return horizontal_add(accu16);
316
+ }
317
+ };
318
+
319
+ /**********************************************************
320
+ * Distance Computers
321
+ **********************************************************/
322
+
323
+ template <class Quantizer, class Similarity>
324
+ struct DCTemplate<Quantizer, Similarity, SIMDLevel::AVX512>
325
+ : SQDistanceComputer {
326
+ using Sim = Similarity;
327
+
328
+ Quantizer quant;
329
+
330
+ DCTemplate(size_t d, const std::vector<float>& trained)
331
+ : quant(d, trained) {}
332
+
333
+ float compute_distance(const float* x, const uint8_t* code) const {
334
+ Similarity sim(x);
335
+ sim.begin_16();
336
+ for (size_t i = 0; i < quant.d; i += 16) {
337
+ simd16float32 xi = quant.reconstruct_16_components(code, i);
338
+ sim.add_16_components(xi);
339
+ }
340
+ return sim.result_16();
341
+ }
342
+
343
+ float compute_code_distance(const uint8_t* code1, const uint8_t* code2)
344
+ const {
345
+ Similarity sim(nullptr);
346
+ sim.begin_16();
347
+ for (size_t i = 0; i < quant.d; i += 16) {
348
+ simd16float32 x1 = quant.reconstruct_16_components(code1, i);
349
+ simd16float32 x2 = quant.reconstruct_16_components(code2, i);
350
+ sim.add_16_components_2(x1, x2);
351
+ }
352
+ return sim.result_16();
353
+ }
354
+
355
+ void set_query(const float* x) final {
356
+ q = x;
357
+ }
358
+
359
+ float symmetric_dis(idx_t i, idx_t j) override {
360
+ return compute_code_distance(
361
+ codes + i * code_size, codes + j * code_size);
362
+ }
363
+
364
+ float query_to_code(const uint8_t* code) const final {
365
+ return compute_distance(q, code);
366
+ }
367
+ };
368
+
369
+ template <class Similarity>
370
+ struct DistanceComputerByte<Similarity, SIMDLevel::AVX512>
371
+ : SQDistanceComputer {
372
+ using Sim = Similarity;
373
+
374
+ int d;
375
+ std::vector<uint8_t> tmp;
376
+
377
+ DistanceComputerByte(int d, const std::vector<float>&) : d(d), tmp(d) {}
378
+
379
+ int compute_code_distance(const uint8_t* code1, const uint8_t* code2)
380
+ const {
381
+ // compute 16 lanes of 32-bit products (16-bytes) at once for
382
+ // the supported metrics
383
+ __m512i accu = _mm512_setzero_si512();
384
+ constexpr int kLanes = 16;
385
+ for (int i = 0; i < d; i += kLanes) {
386
+ __m128i c1 = _mm_loadu_si128((__m128i*)(code1 + i));
387
+ __m128i c2 = _mm_loadu_si128((__m128i*)(code2 + i));
388
+ __m512i c1i = _mm512_cvtepu8_epi32(c1);
389
+ __m512i c2i = _mm512_cvtepu8_epi32(c2);
390
+
391
+ __m512i v;
392
+ if (Sim::metric_type == METRIC_INNER_PRODUCT) {
393
+ v = _mm512_mullo_epi32(c1i, c2i);
394
+ } else {
395
+ __m512i diff = _mm512_sub_epi32(c1i, c2i);
396
+ v = _mm512_mullo_epi32(diff, diff);
397
+ }
398
+ accu = _mm512_add_epi32(accu, v);
399
+ }
400
+ return _mm512_reduce_add_epi32(accu);
401
+ }
402
+
403
+ void set_query(const float* x) final {
404
+ for (int i = 0; i < d; i++) {
405
+ tmp[i] = int(x[i]);
406
+ }
407
+ }
408
+
409
+ int compute_distance(const float* x, const uint8_t* code) {
410
+ set_query(x);
411
+ return compute_code_distance(tmp.data(), code);
412
+ }
413
+
414
+ float symmetric_dis(idx_t i, idx_t j) override {
415
+ return compute_code_distance(
416
+ codes + i * code_size, codes + j * code_size);
417
+ }
418
+
419
+ float query_to_code(const uint8_t* code) const final {
420
+ return compute_code_distance(tmp.data(), code);
421
+ }
422
+ };
423
+
424
+ } // namespace scalar_quantizer
425
+ } // namespace faiss
426
+
427
+ #define THE_LEVEL_TO_DISPATCH SIMDLevel::AVX512
428
+ #include <faiss/impl/scalar_quantizer/sq-dispatch.h>
429
+
430
+ #endif // COMPILE_SIMD_AVX512