faiss 0.2.6 → 0.2.7

Sign up to get free protection for your applications and to get access to all the features.
Files changed (189) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +4 -0
  3. data/ext/faiss/extconf.rb +1 -1
  4. data/lib/faiss/version.rb +1 -1
  5. data/lib/faiss.rb +2 -2
  6. data/vendor/faiss/faiss/AutoTune.cpp +15 -4
  7. data/vendor/faiss/faiss/AutoTune.h +0 -1
  8. data/vendor/faiss/faiss/Clustering.cpp +1 -5
  9. data/vendor/faiss/faiss/Clustering.h +0 -2
  10. data/vendor/faiss/faiss/IVFlib.h +0 -2
  11. data/vendor/faiss/faiss/Index.h +1 -2
  12. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +17 -3
  13. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +10 -1
  14. data/vendor/faiss/faiss/IndexBinary.h +0 -1
  15. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +2 -1
  16. data/vendor/faiss/faiss/IndexBinaryFlat.h +4 -0
  17. data/vendor/faiss/faiss/IndexBinaryHash.cpp +1 -3
  18. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +273 -48
  19. data/vendor/faiss/faiss/IndexBinaryIVF.h +18 -11
  20. data/vendor/faiss/faiss/IndexFastScan.cpp +13 -10
  21. data/vendor/faiss/faiss/IndexFastScan.h +5 -1
  22. data/vendor/faiss/faiss/IndexFlat.cpp +16 -3
  23. data/vendor/faiss/faiss/IndexFlat.h +1 -1
  24. data/vendor/faiss/faiss/IndexFlatCodes.cpp +5 -0
  25. data/vendor/faiss/faiss/IndexFlatCodes.h +7 -2
  26. data/vendor/faiss/faiss/IndexHNSW.cpp +3 -6
  27. data/vendor/faiss/faiss/IndexHNSW.h +0 -1
  28. data/vendor/faiss/faiss/IndexIDMap.cpp +4 -4
  29. data/vendor/faiss/faiss/IndexIDMap.h +0 -2
  30. data/vendor/faiss/faiss/IndexIVF.cpp +155 -129
  31. data/vendor/faiss/faiss/IndexIVF.h +121 -61
  32. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +2 -2
  33. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +12 -11
  34. data/vendor/faiss/faiss/IndexIVFFastScan.h +6 -1
  35. data/vendor/faiss/faiss/IndexIVFPQ.cpp +221 -165
  36. data/vendor/faiss/faiss/IndexIVFPQ.h +1 -0
  37. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +6 -1
  38. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +0 -2
  39. data/vendor/faiss/faiss/IndexNNDescent.cpp +1 -2
  40. data/vendor/faiss/faiss/IndexNNDescent.h +0 -1
  41. data/vendor/faiss/faiss/IndexNSG.cpp +1 -2
  42. data/vendor/faiss/faiss/IndexPQ.cpp +7 -9
  43. data/vendor/faiss/faiss/IndexRefine.cpp +1 -1
  44. data/vendor/faiss/faiss/IndexReplicas.cpp +3 -4
  45. data/vendor/faiss/faiss/IndexReplicas.h +0 -1
  46. data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +8 -1
  47. data/vendor/faiss/faiss/IndexRowwiseMinMax.h +7 -0
  48. data/vendor/faiss/faiss/IndexShards.cpp +26 -109
  49. data/vendor/faiss/faiss/IndexShards.h +2 -3
  50. data/vendor/faiss/faiss/IndexShardsIVF.cpp +246 -0
  51. data/vendor/faiss/faiss/IndexShardsIVF.h +42 -0
  52. data/vendor/faiss/faiss/MetaIndexes.cpp +86 -0
  53. data/vendor/faiss/faiss/MetaIndexes.h +29 -0
  54. data/vendor/faiss/faiss/MetricType.h +14 -0
  55. data/vendor/faiss/faiss/VectorTransform.cpp +8 -10
  56. data/vendor/faiss/faiss/VectorTransform.h +1 -3
  57. data/vendor/faiss/faiss/clone_index.cpp +232 -18
  58. data/vendor/faiss/faiss/cppcontrib/SaDecodeKernels.h +25 -3
  59. data/vendor/faiss/faiss/cppcontrib/detail/CoarseBitType.h +7 -0
  60. data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +78 -0
  61. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +20 -6
  62. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +7 -1
  63. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-neon-inl.h +21 -7
  64. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMax-inl.h +7 -0
  65. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMaxFP16-inl.h +7 -0
  66. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +10 -3
  67. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +7 -1
  68. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-neon-inl.h +11 -3
  69. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +25 -2
  70. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +76 -29
  71. data/vendor/faiss/faiss/gpu/GpuCloner.h +2 -2
  72. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +14 -13
  73. data/vendor/faiss/faiss/gpu/GpuDistance.h +18 -6
  74. data/vendor/faiss/faiss/gpu/GpuIndex.h +23 -21
  75. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +10 -10
  76. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +11 -12
  77. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +29 -50
  78. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +3 -3
  79. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +8 -8
  80. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +4 -4
  81. data/vendor/faiss/faiss/gpu/impl/IndexUtils.h +2 -5
  82. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +9 -7
  83. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +4 -4
  84. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +2 -2
  85. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +1 -1
  86. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +55 -6
  87. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +20 -6
  88. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +95 -25
  89. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +67 -16
  90. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +4 -4
  91. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +7 -7
  92. data/vendor/faiss/faiss/gpu/test/TestUtils.h +4 -4
  93. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +1 -1
  94. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +6 -0
  95. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +0 -7
  96. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +9 -9
  97. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +1 -1
  98. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +2 -7
  99. data/vendor/faiss/faiss/impl/CodePacker.cpp +67 -0
  100. data/vendor/faiss/faiss/impl/CodePacker.h +71 -0
  101. data/vendor/faiss/faiss/impl/DistanceComputer.h +0 -2
  102. data/vendor/faiss/faiss/impl/HNSW.cpp +3 -7
  103. data/vendor/faiss/faiss/impl/HNSW.h +6 -9
  104. data/vendor/faiss/faiss/impl/IDSelector.cpp +1 -1
  105. data/vendor/faiss/faiss/impl/IDSelector.h +39 -1
  106. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +62 -51
  107. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +11 -12
  108. data/vendor/faiss/faiss/impl/NNDescent.cpp +3 -9
  109. data/vendor/faiss/faiss/impl/NNDescent.h +10 -10
  110. data/vendor/faiss/faiss/impl/NSG.cpp +1 -6
  111. data/vendor/faiss/faiss/impl/NSG.h +4 -7
  112. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +1 -15
  113. data/vendor/faiss/faiss/impl/PolysemousTraining.h +11 -10
  114. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +0 -7
  115. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +25 -12
  116. data/vendor/faiss/faiss/impl/ProductQuantizer.h +2 -4
  117. data/vendor/faiss/faiss/impl/Quantizer.h +6 -3
  118. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +796 -174
  119. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +16 -8
  120. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +3 -5
  121. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +4 -4
  122. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +3 -3
  123. data/vendor/faiss/faiss/impl/ThreadedIndex.h +4 -4
  124. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +291 -0
  125. data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +74 -0
  126. data/vendor/faiss/faiss/impl/code_distance/code_distance.h +123 -0
  127. data/vendor/faiss/faiss/impl/code_distance/code_distance_avx512.h +102 -0
  128. data/vendor/faiss/faiss/impl/index_read.cpp +13 -10
  129. data/vendor/faiss/faiss/impl/index_write.cpp +3 -4
  130. data/vendor/faiss/faiss/impl/kmeans1d.cpp +0 -1
  131. data/vendor/faiss/faiss/impl/kmeans1d.h +3 -3
  132. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -1
  133. data/vendor/faiss/faiss/impl/platform_macros.h +61 -0
  134. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +48 -4
  135. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +18 -4
  136. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +2 -2
  137. data/vendor/faiss/faiss/index_factory.cpp +8 -10
  138. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +29 -12
  139. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +8 -2
  140. data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -1
  141. data/vendor/faiss/faiss/invlists/DirectMap.h +2 -4
  142. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +118 -18
  143. data/vendor/faiss/faiss/invlists/InvertedLists.h +44 -4
  144. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +3 -3
  145. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +1 -1
  146. data/vendor/faiss/faiss/python/python_callbacks.cpp +1 -1
  147. data/vendor/faiss/faiss/python/python_callbacks.h +1 -1
  148. data/vendor/faiss/faiss/utils/AlignedTable.h +3 -1
  149. data/vendor/faiss/faiss/utils/Heap.cpp +139 -3
  150. data/vendor/faiss/faiss/utils/Heap.h +35 -1
  151. data/vendor/faiss/faiss/utils/approx_topk/approx_topk.h +84 -0
  152. data/vendor/faiss/faiss/utils/approx_topk/avx2-inl.h +196 -0
  153. data/vendor/faiss/faiss/utils/approx_topk/generic.h +138 -0
  154. data/vendor/faiss/faiss/utils/approx_topk/mode.h +34 -0
  155. data/vendor/faiss/faiss/utils/approx_topk_hamming/approx_topk_hamming.h +367 -0
  156. data/vendor/faiss/faiss/utils/distances.cpp +61 -7
  157. data/vendor/faiss/faiss/utils/distances.h +11 -0
  158. data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +346 -0
  159. data/vendor/faiss/faiss/utils/distances_fused/avx512.h +36 -0
  160. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +42 -0
  161. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +40 -0
  162. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +352 -0
  163. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.h +32 -0
  164. data/vendor/faiss/faiss/utils/distances_simd.cpp +515 -327
  165. data/vendor/faiss/faiss/utils/extra_distances-inl.h +17 -1
  166. data/vendor/faiss/faiss/utils/extra_distances.cpp +37 -8
  167. data/vendor/faiss/faiss/utils/extra_distances.h +2 -1
  168. data/vendor/faiss/faiss/utils/fp16-fp16c.h +7 -0
  169. data/vendor/faiss/faiss/utils/fp16-inl.h +7 -0
  170. data/vendor/faiss/faiss/utils/fp16.h +7 -0
  171. data/vendor/faiss/faiss/utils/hamming-inl.h +0 -456
  172. data/vendor/faiss/faiss/utils/hamming.cpp +104 -120
  173. data/vendor/faiss/faiss/utils/hamming.h +21 -10
  174. data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +535 -0
  175. data/vendor/faiss/faiss/utils/hamming_distance/common.h +48 -0
  176. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +519 -0
  177. data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +26 -0
  178. data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +614 -0
  179. data/vendor/faiss/faiss/utils/partitioning.cpp +21 -25
  180. data/vendor/faiss/faiss/utils/simdlib_avx2.h +344 -3
  181. data/vendor/faiss/faiss/utils/simdlib_emulated.h +390 -0
  182. data/vendor/faiss/faiss/utils/simdlib_neon.h +655 -130
  183. data/vendor/faiss/faiss/utils/sorting.cpp +692 -0
  184. data/vendor/faiss/faiss/utils/sorting.h +71 -0
  185. data/vendor/faiss/faiss/utils/transpose/transpose-avx2-inl.h +165 -0
  186. data/vendor/faiss/faiss/utils/utils.cpp +4 -176
  187. data/vendor/faiss/faiss/utils/utils.h +2 -9
  188. metadata +29 -3
  189. data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +0 -26
@@ -18,6 +18,8 @@
18
18
 
19
19
  #include <arm_neon.h>
20
20
 
21
+ #include <faiss/impl/FaissAssert.h>
22
+
21
23
  namespace faiss {
22
24
 
23
25
  namespace detail {
@@ -88,6 +90,23 @@ static inline float32x4x2_t reinterpret_f32(const float32x4x2_t& v) {
88
90
  return v;
89
91
  }
90
92
 
93
+ // Surprisingly, vdupq_n_u16 has the type of
94
+ // uint16x8_t (std::uint32_t) , and vdupq_n_u8 also has
95
+ // uint8x16_t (std::uint32_t) on **some environments**.
96
+ // We want argument type as same as the type of element
97
+ // of result vector type (std::uint16_t for uint16x8_t,
98
+ // and std::uint8_t for uint8x16_t) instead of
99
+ // std::uint32_t due to using set1 function templates,
100
+ // so let's fix the argument type here and use these
101
+ // overload below.
102
+ static inline ::uint16x8_t vdupq_n_u16(std::uint16_t v) {
103
+ return ::vdupq_n_u16(v);
104
+ }
105
+
106
+ static inline ::uint8x16_t vdupq_n_u8(std::uint8_t v) {
107
+ return ::vdupq_n_u8(v);
108
+ }
109
+
91
110
  template <
92
111
  typename T,
93
112
  typename U = decltype(reinterpret_u8(std::declval<T>().data))>
@@ -119,11 +138,25 @@ static inline std::string bin(const S& simd) {
119
138
  return std::string(bits);
120
139
  }
121
140
 
122
- template <typename D, typename F, typename T>
123
- static inline void set1(D& d, F&& f, T t) {
124
- const auto v = f(t);
125
- d.val[0] = v;
126
- d.val[1] = v;
141
+ template <typename T>
142
+ using remove_cv_ref_t =
143
+ typename std::remove_reference<typename std::remove_cv<T>::type>::type;
144
+
145
+ template <typename D, typename T>
146
+ struct set1_impl {
147
+ D& d;
148
+ T t;
149
+ template <remove_cv_ref_t<decltype(std::declval<D>().val[0])> (*F)(T)>
150
+ inline void call() {
151
+ const auto v = F(t);
152
+ d.val[0] = v;
153
+ d.val[1] = v;
154
+ }
155
+ };
156
+
157
+ template <typename D, typename T>
158
+ static inline set1_impl<remove_cv_ref_t<D>, T> set1(D& d, T t) {
159
+ return {d, t};
127
160
  }
128
161
 
129
162
  template <typename T, size_t N, typename S>
@@ -142,20 +175,57 @@ static inline std::string elements_to_string(const char* fmt, const S& simd) {
142
175
  return std::string(res);
143
176
  }
144
177
 
145
- template <typename T, typename F>
146
- static inline T unary_func(const T& a, F&& f) {
147
- T t;
148
- t.val[0] = f(a.val[0]);
149
- t.val[1] = f(a.val[1]);
150
- return t;
178
+ template <typename T, typename U>
179
+ struct unary_func_impl {
180
+ const U& a;
181
+ using Telem = remove_cv_ref_t<decltype(std::declval<T>().val[0])>;
182
+ using Uelem = remove_cv_ref_t<decltype(std::declval<U>().val[0])>;
183
+ template <Telem (*F)(Uelem)>
184
+ inline T call() {
185
+ T t;
186
+ t.val[0] = F(a.val[0]);
187
+ t.val[1] = F(a.val[1]);
188
+ return t;
189
+ }
190
+ };
191
+
192
+ template <typename T>
193
+ static inline unary_func_impl<remove_cv_ref_t<T>, remove_cv_ref_t<T>> unary_func(
194
+ const T& a) {
195
+ return {a};
151
196
  }
152
197
 
153
- template <typename T, typename F>
154
- static inline T binary_func(const T& a, const T& b, F&& f) {
155
- T t;
156
- t.val[0] = f(a.val[0], b.val[0]);
157
- t.val[1] = f(a.val[1], b.val[1]);
158
- return t;
198
+ template <typename T, typename U>
199
+ static inline unary_func_impl<remove_cv_ref_t<T>, remove_cv_ref_t<U>> unary_func(
200
+ const U& a) {
201
+ return {a};
202
+ }
203
+
204
+ template <typename T, typename U>
205
+ struct binary_func_impl {
206
+ const U& a;
207
+ const U& b;
208
+ using Telem = remove_cv_ref_t<decltype(std::declval<T>().val[0])>;
209
+ using Uelem = remove_cv_ref_t<decltype(std::declval<U>().val[0])>;
210
+ template <Telem (*F)(Uelem, Uelem)>
211
+ inline T call() {
212
+ T t;
213
+ t.val[0] = F(a.val[0], b.val[0]);
214
+ t.val[1] = F(a.val[1], b.val[1]);
215
+ return t;
216
+ }
217
+ };
218
+
219
+ template <typename T>
220
+ static inline binary_func_impl<remove_cv_ref_t<T>, remove_cv_ref_t<T>>
221
+ binary_func(const T& a, const T& b) {
222
+ return {a, b};
223
+ }
224
+
225
+ template <typename T, typename U>
226
+ static inline binary_func_impl<remove_cv_ref_t<T>, remove_cv_ref_t<U>>
227
+ binary_func(const U& a, const U& b) {
228
+ return {a, b};
159
229
  }
160
230
 
161
231
  static inline uint16_t vmovmask_u8(const uint8x16_t& v) {
@@ -172,8 +242,8 @@ static inline uint32_t cmp_xe32(
172
242
  const uint16x8x2_t& d0,
173
243
  const uint16x8x2_t& d1,
174
244
  const uint16x8x2_t& thr) {
175
- const auto d0_thr = detail::simdlib::binary_func(d0, thr, F);
176
- const auto d1_thr = detail::simdlib::binary_func(d1, thr, F);
245
+ const auto d0_thr = detail::simdlib::binary_func(d0, thr).call<F>();
246
+ const auto d1_thr = detail::simdlib::binary_func(d1, thr).call<F>();
177
247
  const auto d0_mask = vmovmask_u8(
178
248
  vmovn_high_u16(vmovn_u16(d0_thr.val[0]), d0_thr.val[1]));
179
249
  const auto d1_mask = vmovmask_u8(
@@ -207,6 +277,44 @@ struct simd16uint16 {
207
277
 
208
278
  explicit simd16uint16(const uint16x8x2_t& v) : data{v} {}
209
279
 
280
+ explicit simd16uint16(
281
+ uint16_t u0,
282
+ uint16_t u1,
283
+ uint16_t u2,
284
+ uint16_t u3,
285
+ uint16_t u4,
286
+ uint16_t u5,
287
+ uint16_t u6,
288
+ uint16_t u7,
289
+ uint16_t u8,
290
+ uint16_t u9,
291
+ uint16_t u10,
292
+ uint16_t u11,
293
+ uint16_t u12,
294
+ uint16_t u13,
295
+ uint16_t u14,
296
+ uint16_t u15) {
297
+ uint16_t temp[16] = {
298
+ u0,
299
+ u1,
300
+ u2,
301
+ u3,
302
+ u4,
303
+ u5,
304
+ u6,
305
+ u7,
306
+ u8,
307
+ u9,
308
+ u10,
309
+ u11,
310
+ u12,
311
+ u13,
312
+ u14,
313
+ u15};
314
+ data.val[0] = vld1q_u16(temp);
315
+ data.val[1] = vld1q_u16(temp + 8);
316
+ }
317
+
210
318
  template <
211
319
  typename T,
212
320
  typename std::enable_if<
@@ -219,7 +327,8 @@ struct simd16uint16 {
219
327
  : data{vld1q_u16(x), vld1q_u16(x + 8)} {}
220
328
 
221
329
  void clear() {
222
- detail::simdlib::set1(data, &vdupq_n_u16, static_cast<uint16_t>(0));
330
+ detail::simdlib::set1(data, static_cast<uint16_t>(0))
331
+ .call<&detail::simdlib::vdupq_n_u16>();
223
332
  }
224
333
 
225
334
  void storeu(uint16_t* ptr) const {
@@ -257,12 +366,12 @@ struct simd16uint16 {
257
366
  }
258
367
 
259
368
  void set1(uint16_t x) {
260
- detail::simdlib::set1(data, &vdupq_n_u16, x);
369
+ detail::simdlib::set1(data, x).call<&detail::simdlib::vdupq_n_u16>();
261
370
  }
262
371
 
263
372
  simd16uint16 operator*(const simd16uint16& other) const {
264
- return simd16uint16{
265
- detail::simdlib::binary_func(data, other.data, &vmulq_u16)};
373
+ return simd16uint16{detail::simdlib::binary_func(data, other.data)
374
+ .call<&vmulq_u16>()};
266
375
  }
267
376
 
268
377
  // shift must be known at compile time
@@ -271,50 +380,56 @@ struct simd16uint16 {
271
380
  case 0:
272
381
  return *this;
273
382
  case 1:
274
- return simd16uint16{detail::simdlib::unary_func(
275
- data, detail::simdlib::vshrq<1>)};
383
+ return simd16uint16{detail::simdlib::unary_func(data)
384
+ .call<detail::simdlib::vshrq<1>>()};
276
385
  case 2:
277
- return simd16uint16{detail::simdlib::unary_func(
278
- data, detail::simdlib::vshrq<2>)};
386
+ return simd16uint16{detail::simdlib::unary_func(data)
387
+ .call<detail::simdlib::vshrq<2>>()};
279
388
  case 3:
280
- return simd16uint16{detail::simdlib::unary_func(
281
- data, detail::simdlib::vshrq<3>)};
389
+ return simd16uint16{detail::simdlib::unary_func(data)
390
+ .call<detail::simdlib::vshrq<3>>()};
282
391
  case 4:
283
- return simd16uint16{detail::simdlib::unary_func(
284
- data, detail::simdlib::vshrq<4>)};
392
+ return simd16uint16{detail::simdlib::unary_func(data)
393
+ .call<detail::simdlib::vshrq<4>>()};
285
394
  case 5:
286
- return simd16uint16{detail::simdlib::unary_func(
287
- data, detail::simdlib::vshrq<5>)};
395
+ return simd16uint16{detail::simdlib::unary_func(data)
396
+ .call<detail::simdlib::vshrq<5>>()};
288
397
  case 6:
289
- return simd16uint16{detail::simdlib::unary_func(
290
- data, detail::simdlib::vshrq<6>)};
398
+ return simd16uint16{detail::simdlib::unary_func(data)
399
+ .call<detail::simdlib::vshrq<6>>()};
291
400
  case 7:
292
- return simd16uint16{detail::simdlib::unary_func(
293
- data, detail::simdlib::vshrq<7>)};
401
+ return simd16uint16{detail::simdlib::unary_func(data)
402
+ .call<detail::simdlib::vshrq<7>>()};
294
403
  case 8:
295
- return simd16uint16{detail::simdlib::unary_func(
296
- data, detail::simdlib::vshrq<8>)};
404
+ return simd16uint16{detail::simdlib::unary_func(data)
405
+ .call<detail::simdlib::vshrq<8>>()};
297
406
  case 9:
298
- return simd16uint16{detail::simdlib::unary_func(
299
- data, detail::simdlib::vshrq<9>)};
407
+ return simd16uint16{detail::simdlib::unary_func(data)
408
+ .call<detail::simdlib::vshrq<9>>()};
300
409
  case 10:
301
- return simd16uint16{detail::simdlib::unary_func(
302
- data, detail::simdlib::vshrq<10>)};
410
+ return simd16uint16{
411
+ detail::simdlib::unary_func(data)
412
+ .call<detail::simdlib::vshrq<10>>()};
303
413
  case 11:
304
- return simd16uint16{detail::simdlib::unary_func(
305
- data, detail::simdlib::vshrq<11>)};
414
+ return simd16uint16{
415
+ detail::simdlib::unary_func(data)
416
+ .call<detail::simdlib::vshrq<11>>()};
306
417
  case 12:
307
- return simd16uint16{detail::simdlib::unary_func(
308
- data, detail::simdlib::vshrq<12>)};
418
+ return simd16uint16{
419
+ detail::simdlib::unary_func(data)
420
+ .call<detail::simdlib::vshrq<12>>()};
309
421
  case 13:
310
- return simd16uint16{detail::simdlib::unary_func(
311
- data, detail::simdlib::vshrq<13>)};
422
+ return simd16uint16{
423
+ detail::simdlib::unary_func(data)
424
+ .call<detail::simdlib::vshrq<13>>()};
312
425
  case 14:
313
- return simd16uint16{detail::simdlib::unary_func(
314
- data, detail::simdlib::vshrq<14>)};
426
+ return simd16uint16{
427
+ detail::simdlib::unary_func(data)
428
+ .call<detail::simdlib::vshrq<14>>()};
315
429
  case 15:
316
- return simd16uint16{detail::simdlib::unary_func(
317
- data, detail::simdlib::vshrq<15>)};
430
+ return simd16uint16{
431
+ detail::simdlib::unary_func(data)
432
+ .call<detail::simdlib::vshrq<15>>()};
318
433
  default:
319
434
  FAISS_THROW_FMT("Invalid shift %d", shift);
320
435
  }
@@ -326,50 +441,56 @@ struct simd16uint16 {
326
441
  case 0:
327
442
  return *this;
328
443
  case 1:
329
- return simd16uint16{detail::simdlib::unary_func(
330
- data, detail::simdlib::vshlq<1>)};
444
+ return simd16uint16{detail::simdlib::unary_func(data)
445
+ .call<detail::simdlib::vshlq<1>>()};
331
446
  case 2:
332
- return simd16uint16{detail::simdlib::unary_func(
333
- data, detail::simdlib::vshlq<2>)};
447
+ return simd16uint16{detail::simdlib::unary_func(data)
448
+ .call<detail::simdlib::vshlq<2>>()};
334
449
  case 3:
335
- return simd16uint16{detail::simdlib::unary_func(
336
- data, detail::simdlib::vshlq<3>)};
450
+ return simd16uint16{detail::simdlib::unary_func(data)
451
+ .call<detail::simdlib::vshlq<3>>()};
337
452
  case 4:
338
- return simd16uint16{detail::simdlib::unary_func(
339
- data, detail::simdlib::vshlq<4>)};
453
+ return simd16uint16{detail::simdlib::unary_func(data)
454
+ .call<detail::simdlib::vshlq<4>>()};
340
455
  case 5:
341
- return simd16uint16{detail::simdlib::unary_func(
342
- data, detail::simdlib::vshlq<5>)};
456
+ return simd16uint16{detail::simdlib::unary_func(data)
457
+ .call<detail::simdlib::vshlq<5>>()};
343
458
  case 6:
344
- return simd16uint16{detail::simdlib::unary_func(
345
- data, detail::simdlib::vshlq<6>)};
459
+ return simd16uint16{detail::simdlib::unary_func(data)
460
+ .call<detail::simdlib::vshlq<6>>()};
346
461
  case 7:
347
- return simd16uint16{detail::simdlib::unary_func(
348
- data, detail::simdlib::vshlq<7>)};
462
+ return simd16uint16{detail::simdlib::unary_func(data)
463
+ .call<detail::simdlib::vshlq<7>>()};
349
464
  case 8:
350
- return simd16uint16{detail::simdlib::unary_func(
351
- data, detail::simdlib::vshlq<8>)};
465
+ return simd16uint16{detail::simdlib::unary_func(data)
466
+ .call<detail::simdlib::vshlq<8>>()};
352
467
  case 9:
353
- return simd16uint16{detail::simdlib::unary_func(
354
- data, detail::simdlib::vshlq<9>)};
468
+ return simd16uint16{detail::simdlib::unary_func(data)
469
+ .call<detail::simdlib::vshlq<9>>()};
355
470
  case 10:
356
- return simd16uint16{detail::simdlib::unary_func(
357
- data, detail::simdlib::vshlq<10>)};
471
+ return simd16uint16{
472
+ detail::simdlib::unary_func(data)
473
+ .call<detail::simdlib::vshlq<10>>()};
358
474
  case 11:
359
- return simd16uint16{detail::simdlib::unary_func(
360
- data, detail::simdlib::vshlq<11>)};
475
+ return simd16uint16{
476
+ detail::simdlib::unary_func(data)
477
+ .call<detail::simdlib::vshlq<11>>()};
361
478
  case 12:
362
- return simd16uint16{detail::simdlib::unary_func(
363
- data, detail::simdlib::vshlq<12>)};
479
+ return simd16uint16{
480
+ detail::simdlib::unary_func(data)
481
+ .call<detail::simdlib::vshlq<12>>()};
364
482
  case 13:
365
- return simd16uint16{detail::simdlib::unary_func(
366
- data, detail::simdlib::vshlq<13>)};
483
+ return simd16uint16{
484
+ detail::simdlib::unary_func(data)
485
+ .call<detail::simdlib::vshlq<13>>()};
367
486
  case 14:
368
- return simd16uint16{detail::simdlib::unary_func(
369
- data, detail::simdlib::vshlq<14>)};
487
+ return simd16uint16{
488
+ detail::simdlib::unary_func(data)
489
+ .call<detail::simdlib::vshlq<14>>()};
370
490
  case 15:
371
- return simd16uint16{detail::simdlib::unary_func(
372
- data, detail::simdlib::vshlq<15>)};
491
+ return simd16uint16{
492
+ detail::simdlib::unary_func(data)
493
+ .call<detail::simdlib::vshlq<15>>()};
373
494
  default:
374
495
  FAISS_THROW_FMT("Invalid shift %d", shift);
375
496
  }
@@ -386,13 +507,13 @@ struct simd16uint16 {
386
507
  }
387
508
 
388
509
  simd16uint16 operator+(const simd16uint16& other) const {
389
- return simd16uint16{
390
- detail::simdlib::binary_func(data, other.data, &vaddq_u16)};
510
+ return simd16uint16{detail::simdlib::binary_func(data, other.data)
511
+ .call<&vaddq_u16>()};
391
512
  }
392
513
 
393
514
  simd16uint16 operator-(const simd16uint16& other) const {
394
- return simd16uint16{
395
- detail::simdlib::binary_func(data, other.data, &vsubq_u16)};
515
+ return simd16uint16{detail::simdlib::binary_func(data, other.data)
516
+ .call<&vsubq_u16>()};
396
517
  }
397
518
 
398
519
  template <
@@ -401,10 +522,10 @@ struct simd16uint16 {
401
522
  detail::simdlib::is_simd256bit<T>::value,
402
523
  std::nullptr_t>::type = nullptr>
403
524
  simd16uint16 operator&(const T& other) const {
404
- return simd16uint16{detail::simdlib::binary_func(
405
- data,
406
- detail::simdlib::reinterpret_u16(other.data),
407
- &vandq_u16)};
525
+ return simd16uint16{
526
+ detail::simdlib::binary_func(
527
+ data, detail::simdlib::reinterpret_u16(other.data))
528
+ .template call<&vandq_u16>()};
408
529
  }
409
530
 
410
531
  template <
@@ -413,20 +534,45 @@ struct simd16uint16 {
413
534
  detail::simdlib::is_simd256bit<T>::value,
414
535
  std::nullptr_t>::type = nullptr>
415
536
  simd16uint16 operator|(const T& other) const {
416
- return simd16uint16{detail::simdlib::binary_func(
417
- data,
418
- detail::simdlib::reinterpret_u16(other.data),
419
- &vorrq_u16)};
537
+ return simd16uint16{
538
+ detail::simdlib::binary_func(
539
+ data, detail::simdlib::reinterpret_u16(other.data))
540
+ .template call<&vorrq_u16>()};
541
+ }
542
+
543
+ template <
544
+ typename T,
545
+ typename std::enable_if<
546
+ detail::simdlib::is_simd256bit<T>::value,
547
+ std::nullptr_t>::type = nullptr>
548
+ simd16uint16 operator^(const T& other) const {
549
+ return simd16uint16{
550
+ detail::simdlib::binary_func(
551
+ data, detail::simdlib::reinterpret_u16(other.data))
552
+ .template call<&veorq_u16>()};
420
553
  }
421
554
 
422
555
  // returns binary masks
423
556
  simd16uint16 operator==(const simd16uint16& other) const {
424
- return simd16uint16{
425
- detail::simdlib::binary_func(data, other.data, &vceqq_u16)};
557
+ return simd16uint16{detail::simdlib::binary_func(data, other.data)
558
+ .call<&vceqq_u16>()};
559
+ }
560
+
561
+ // Checks whether the other holds exactly the same bytes.
562
+ bool is_same_as(simd16uint16 other) const {
563
+ const bool equal0 =
564
+ (vminvq_u16(vceqq_u16(data.val[0], other.data.val[0])) ==
565
+ 0xffff);
566
+ const bool equal1 =
567
+ (vminvq_u16(vceqq_u16(data.val[1], other.data.val[1])) ==
568
+ 0xffff);
569
+
570
+ return equal0 && equal1;
426
571
  }
427
572
 
428
573
  simd16uint16 operator~() const {
429
- return simd16uint16{detail::simdlib::unary_func(data, &vmvnq_u16)};
574
+ return simd16uint16{
575
+ detail::simdlib::unary_func(data).call<&vmvnq_u16>()};
430
576
  }
431
577
 
432
578
  // get scalar at index 0
@@ -437,8 +583,8 @@ struct simd16uint16 {
437
583
  // mask of elements where this >= thresh
438
584
  // 2 bit per component: 16 * 2 = 32 bit
439
585
  uint32_t ge_mask(const simd16uint16& thresh) const {
440
- const auto input =
441
- detail::simdlib::binary_func(data, thresh.data, &vcgeq_u16);
586
+ const auto input = detail::simdlib::binary_func(data, thresh.data)
587
+ .call<&vcgeq_u16>();
442
588
  const auto vmovmask_u16 = [](uint16x8_t v) -> uint16_t {
443
589
  uint16_t d[8];
444
590
  const auto v2 = vreinterpretq_u32_u16(vshrq_n_u16(v, 14));
@@ -471,23 +617,25 @@ struct simd16uint16 {
471
617
  }
472
618
 
473
619
  void accu_min(const simd16uint16& incoming) {
474
- data = detail::simdlib::binary_func(incoming.data, data, &vminq_u16);
620
+ data = detail::simdlib::binary_func(incoming.data, data)
621
+ .call<&vminq_u16>();
475
622
  }
476
623
 
477
624
  void accu_max(const simd16uint16& incoming) {
478
- data = detail::simdlib::binary_func(incoming.data, data, &vmaxq_u16);
625
+ data = detail::simdlib::binary_func(incoming.data, data)
626
+ .call<&vmaxq_u16>();
479
627
  }
480
628
  };
481
629
 
482
630
  // not really a std::min because it returns an elementwise min
483
631
  inline simd16uint16 min(const simd16uint16& av, const simd16uint16& bv) {
484
632
  return simd16uint16{
485
- detail::simdlib::binary_func(av.data, bv.data, &vminq_u16)};
633
+ detail::simdlib::binary_func(av.data, bv.data).call<&vminq_u16>()};
486
634
  }
487
635
 
488
636
  inline simd16uint16 max(const simd16uint16& av, const simd16uint16& bv) {
489
637
  return simd16uint16{
490
- detail::simdlib::binary_func(av.data, bv.data, &vmaxq_u16)};
638
+ detail::simdlib::binary_func(av.data, bv.data).call<&vmaxq_u16>()};
491
639
  }
492
640
 
493
641
  // decompose in 128-lanes: a = (a0, a1), b = (b0, b1)
@@ -515,6 +663,63 @@ inline uint32_t cmp_le32(
515
663
  return detail::simdlib::cmp_xe32<&vcleq_u16>(d0.data, d1.data, thr.data);
516
664
  }
517
665
 
666
+ // hadd does not cross lanes
667
+ inline simd16uint16 hadd(const simd16uint16& a, const simd16uint16& b) {
668
+ return simd16uint16{
669
+ detail::simdlib::binary_func(a.data, b.data).call<&vpaddq_u16>()};
670
+ }
671
+
672
+ // Vectorized version of the following code:
673
+ // for (size_t i = 0; i < n; i++) {
674
+ // bool flag = (candidateValues[i] < currentValues[i]);
675
+ // minValues[i] = flag ? candidateValues[i] : currentValues[i];
676
+ // minIndices[i] = flag ? candidateIndices[i] : currentIndices[i];
677
+ // maxValues[i] = !flag ? candidateValues[i] : currentValues[i];
678
+ // maxIndices[i] = !flag ? candidateIndices[i] : currentIndices[i];
679
+ // }
680
+ // Max indices evaluation is inaccurate in case of equal values (the index of
681
+ // the last equal value is saved instead of the first one), but this behavior
682
+ // saves instructions.
683
+ inline void cmplt_min_max_fast(
684
+ const simd16uint16 candidateValues,
685
+ const simd16uint16 candidateIndices,
686
+ const simd16uint16 currentValues,
687
+ const simd16uint16 currentIndices,
688
+ simd16uint16& minValues,
689
+ simd16uint16& minIndices,
690
+ simd16uint16& maxValues,
691
+ simd16uint16& maxIndices) {
692
+ const uint16x8x2_t comparison = uint16x8x2_t{
693
+ vcltq_u16(candidateValues.data.val[0], currentValues.data.val[0]),
694
+ vcltq_u16(candidateValues.data.val[1], currentValues.data.val[1])};
695
+
696
+ minValues.data = uint16x8x2_t{
697
+ vminq_u16(candidateValues.data.val[0], currentValues.data.val[0]),
698
+ vminq_u16(candidateValues.data.val[1], currentValues.data.val[1])};
699
+ minIndices.data = uint16x8x2_t{
700
+ vbslq_u16(
701
+ comparison.val[0],
702
+ candidateIndices.data.val[0],
703
+ currentIndices.data.val[0]),
704
+ vbslq_u16(
705
+ comparison.val[1],
706
+ candidateIndices.data.val[1],
707
+ currentIndices.data.val[1])};
708
+
709
+ maxValues.data = uint16x8x2_t{
710
+ vmaxq_u16(candidateValues.data.val[0], currentValues.data.val[0]),
711
+ vmaxq_u16(candidateValues.data.val[1], currentValues.data.val[1])};
712
+ maxIndices.data = uint16x8x2_t{
713
+ vbslq_u16(
714
+ comparison.val[0],
715
+ currentIndices.data.val[0],
716
+ candidateIndices.data.val[0]),
717
+ vbslq_u16(
718
+ comparison.val[1],
719
+ currentIndices.data.val[1],
720
+ candidateIndices.data.val[1])};
721
+ }
722
+
518
723
  // vector of 32 unsigned 8-bit integers
519
724
  struct simd32uint8 {
520
725
  uint8x16x2_t data;
@@ -527,6 +732,47 @@ struct simd32uint8 {
527
732
 
528
733
  explicit simd32uint8(const uint8x16x2_t& v) : data{v} {}
529
734
 
735
+ template <
736
+ uint8_t _0,
737
+ uint8_t _1,
738
+ uint8_t _2,
739
+ uint8_t _3,
740
+ uint8_t _4,
741
+ uint8_t _5,
742
+ uint8_t _6,
743
+ uint8_t _7,
744
+ uint8_t _8,
745
+ uint8_t _9,
746
+ uint8_t _10,
747
+ uint8_t _11,
748
+ uint8_t _12,
749
+ uint8_t _13,
750
+ uint8_t _14,
751
+ uint8_t _15,
752
+ uint8_t _16,
753
+ uint8_t _17,
754
+ uint8_t _18,
755
+ uint8_t _19,
756
+ uint8_t _20,
757
+ uint8_t _21,
758
+ uint8_t _22,
759
+ uint8_t _23,
760
+ uint8_t _24,
761
+ uint8_t _25,
762
+ uint8_t _26,
763
+ uint8_t _27,
764
+ uint8_t _28,
765
+ uint8_t _29,
766
+ uint8_t _30,
767
+ uint8_t _31>
768
+ static simd32uint8 create() {
769
+ constexpr uint8_t ds[32] = {_0, _1, _2, _3, _4, _5, _6, _7,
770
+ _8, _9, _10, _11, _12, _13, _14, _15,
771
+ _16, _17, _18, _19, _20, _21, _22, _23,
772
+ _24, _25, _26, _27, _28, _29, _30, _31};
773
+ return simd32uint8{ds};
774
+ }
775
+
530
776
  template <
531
777
  typename T,
532
778
  typename std::enable_if<
@@ -539,7 +785,8 @@ struct simd32uint8 {
539
785
  : data{vld1q_u8(x), vld1q_u8(x + 16)} {}
540
786
 
541
787
    // Set all 32 byte lanes to zero.
    void clear() {
        detail::simdlib::set1(data, static_cast<uint8_t>(0))
                .call<&detail::simdlib::vdupq_n_u8>();
    }
544
791
 
545
792
  void storeu(uint8_t* ptr) const {
@@ -582,7 +829,7 @@ struct simd32uint8 {
582
829
  }
583
830
 
584
831
    // Broadcast x to all 32 byte lanes.
    void set1(uint8_t x) {
        detail::simdlib::set1(data, x).call<&detail::simdlib::vdupq_n_u8>();
    }
587
834
 
588
835
  template <
@@ -591,19 +838,21 @@ struct simd32uint8 {
591
838
  detail::simdlib::is_simd256bit<T>::value,
592
839
  std::nullptr_t>::type = nullptr>
593
840
  simd32uint8 operator&(const T& other) const {
594
- return simd32uint8{detail::simdlib::binary_func(
595
- data, detail::simdlib::reinterpret_u8(other.data), &vandq_u8)};
841
+ return simd32uint8{
842
+ detail::simdlib::binary_func(
843
+ data, detail::simdlib::reinterpret_u8(other.data))
844
+ .template call<&vandq_u8>()};
596
845
  }
597
846
 
598
847
    // Lane-wise (wrapping) addition of the 32 uint8 lanes.
    simd32uint8 operator+(const simd32uint8& other) const {
        return simd32uint8{detail::simdlib::binary_func(data, other.data)
                                   .call<&vaddq_u8>()};
    }
602
851
 
603
852
    // The very important operation that everything relies on:
    // per-byte table lookup (vqtbl1q_u8) applied to each 16-byte half.
    simd32uint8 lookup_2_lanes(const simd32uint8& idx) const {
        return simd32uint8{detail::simdlib::binary_func(data, idx.data)
                                   .call<&vqtbl1q_u8>()};
    }
608
857
 
609
858
  simd32uint8 operator+=(const simd32uint8& other) {
@@ -618,6 +867,16 @@ struct simd32uint8 {
618
867
  vst1q_u8(tab, data.val[high]);
619
868
  return tab[i - high * 16];
620
869
  }
870
+
871
+ // Checks whether the other holds exactly the same bytes.
872
+ bool is_same_as(simd32uint8 other) const {
873
+ const bool equal0 =
874
+ (vminvq_u8(vceqq_u8(data.val[0], other.data.val[0])) == 0xff);
875
+ const bool equal1 =
876
+ (vminvq_u8(vceqq_u8(data.val[1], other.data.val[1])) == 0xff);
877
+
878
+ return equal0 && equal1;
879
+ }
621
880
  };
622
881
 
623
882
  // convert with saturation
@@ -671,8 +930,62 @@ struct simd8uint32 {
671
930
 
672
931
  explicit simd8uint32(const uint8_t* x) : simd8uint32(simd32uint8(x)) {}
673
932
 
933
+ explicit simd8uint32(
934
+ uint32_t u0,
935
+ uint32_t u1,
936
+ uint32_t u2,
937
+ uint32_t u3,
938
+ uint32_t u4,
939
+ uint32_t u5,
940
+ uint32_t u6,
941
+ uint32_t u7) {
942
+ uint32_t temp[8] = {u0, u1, u2, u3, u4, u5, u6, u7};
943
+ data.val[0] = vld1q_u32(temp);
944
+ data.val[1] = vld1q_u32(temp + 4);
945
+ }
946
+
947
+ simd8uint32 operator+(simd8uint32 other) const {
948
+ return simd8uint32{detail::simdlib::binary_func(data, other.data)
949
+ .call<&vaddq_u32>()};
950
+ }
951
+
952
+ simd8uint32 operator-(simd8uint32 other) const {
953
+ return simd8uint32{detail::simdlib::binary_func(data, other.data)
954
+ .call<&vsubq_u32>()};
955
+ }
956
+
957
    // In-place lane-wise addition; returns *this.
    simd8uint32& operator+=(const simd8uint32& other) {
        data.val[0] = vaddq_u32(data.val[0], other.data.val[0]);
        data.val[1] = vaddq_u32(data.val[1], other.data.val[1]);
        return *this;
    }
962
+
963
    // True iff all 8 uint32 lanes are equal.
    bool operator==(simd8uint32 other) const {
        // vceqq_u32 yields an all-ones lane per equal pair; AND the two
        // halves together and require the minimum lane to still be all-ones.
        const auto equals = detail::simdlib::binary_func(data, other.data)
                                    .call<&vceqq_u32>();
        const auto equal = vandq_u32(equals.val[0], equals.val[1]);
        return vminvq_u32(equal) == 0xffffffff;
    }
969
+
970
    // Negation of operator==.
    bool operator!=(simd8uint32 other) const {
        return !(*this == other);
    }
973
+
974
+ // Checks whether the other holds exactly the same bytes.
975
+ bool is_same_as(simd8uint32 other) const {
976
+ const bool equal0 =
977
+ (vminvq_u32(vceqq_u32(data.val[0], other.data.val[0])) ==
978
+ 0xffffffff);
979
+ const bool equal1 =
980
+ (vminvq_u32(vceqq_u32(data.val[1], other.data.val[1])) ==
981
+ 0xffffffff);
982
+
983
+ return equal0 && equal1;
984
+ }
985
+
674
986
    // Set all 8 uint32 lanes to zero.
    void clear() {
        detail::simdlib::set1(data, static_cast<uint32_t>(0))
                .call<&vdupq_n_u32>();
    }
677
990
 
678
991
  void storeu(uint32_t* ptr) const {
@@ -710,10 +1023,67 @@ struct simd8uint32 {
710
1023
  }
711
1024
 
712
1025
    // Broadcast x to all 8 uint32 lanes.
    void set1(uint32_t x) {
        detail::simdlib::set1(data, x).call<&vdupq_n_u32>();
    }
1028
+
1029
+ simd8uint32 unzip() const {
1030
+ return simd8uint32{uint32x4x2_t{
1031
+ vuzp1q_u32(data.val[0], data.val[1]),
1032
+ vuzp2q_u32(data.val[0], data.val[1])}};
714
1033
  }
715
1034
  };
716
1035
 
1036
+ // Vectorized version of the following code:
1037
+ // for (size_t i = 0; i < n; i++) {
1038
+ // bool flag = (candidateValues[i] < currentValues[i]);
1039
+ // minValues[i] = flag ? candidateValues[i] : currentValues[i];
1040
+ // minIndices[i] = flag ? candidateIndices[i] : currentIndices[i];
1041
+ // maxValues[i] = !flag ? candidateValues[i] : currentValues[i];
1042
+ // maxIndices[i] = !flag ? candidateIndices[i] : currentIndices[i];
1043
+ // }
1044
+ // Max indices evaluation is inaccurate in case of equal values (the index of
1045
+ // the last equal value is saved instead of the first one), but this behavior
1046
+ // saves instructions.
1047
+ inline void cmplt_min_max_fast(
1048
+ const simd8uint32 candidateValues,
1049
+ const simd8uint32 candidateIndices,
1050
+ const simd8uint32 currentValues,
1051
+ const simd8uint32 currentIndices,
1052
+ simd8uint32& minValues,
1053
+ simd8uint32& minIndices,
1054
+ simd8uint32& maxValues,
1055
+ simd8uint32& maxIndices) {
1056
+ const uint32x4x2_t comparison = uint32x4x2_t{
1057
+ vcltq_u32(candidateValues.data.val[0], currentValues.data.val[0]),
1058
+ vcltq_u32(candidateValues.data.val[1], currentValues.data.val[1])};
1059
+
1060
+ minValues.data = uint32x4x2_t{
1061
+ vminq_u32(candidateValues.data.val[0], currentValues.data.val[0]),
1062
+ vminq_u32(candidateValues.data.val[1], currentValues.data.val[1])};
1063
+ minIndices.data = uint32x4x2_t{
1064
+ vbslq_u32(
1065
+ comparison.val[0],
1066
+ candidateIndices.data.val[0],
1067
+ currentIndices.data.val[0]),
1068
+ vbslq_u32(
1069
+ comparison.val[1],
1070
+ candidateIndices.data.val[1],
1071
+ currentIndices.data.val[1])};
1072
+
1073
+ maxValues.data = uint32x4x2_t{
1074
+ vmaxq_u32(candidateValues.data.val[0], currentValues.data.val[0]),
1075
+ vmaxq_u32(candidateValues.data.val[1], currentValues.data.val[1])};
1076
+ maxIndices.data = uint32x4x2_t{
1077
+ vbslq_u32(
1078
+ comparison.val[0],
1079
+ currentIndices.data.val[0],
1080
+ candidateIndices.data.val[0]),
1081
+ vbslq_u32(
1082
+ comparison.val[1],
1083
+ currentIndices.data.val[1],
1084
+ candidateIndices.data.val[1])};
1085
+ }
1086
+
717
1087
  struct simd8float32 {
718
1088
  float32x4x2_t data;
719
1089
 
@@ -734,8 +1104,22 @@ struct simd8float32 {
734
1104
  explicit simd8float32(const float* x)
735
1105
  : data{vld1q_f32(x), vld1q_f32(x + 4)} {}
736
1106
 
1107
+ explicit simd8float32(
1108
+ float f0,
1109
+ float f1,
1110
+ float f2,
1111
+ float f3,
1112
+ float f4,
1113
+ float f5,
1114
+ float f6,
1115
+ float f7) {
1116
+ float temp[8] = {f0, f1, f2, f3, f4, f5, f6, f7};
1117
+ data.val[0] = vld1q_f32(temp);
1118
+ data.val[1] = vld1q_f32(temp + 4);
1119
+ }
1120
+
737
1121
    // Set all 8 float lanes to 0.f.
    void clear() {
        detail::simdlib::set1(data, 0.f).call<&vdupq_n_f32>();
    }
740
1124
 
741
1125
  void storeu(float* ptr) const {
@@ -761,18 +1145,50 @@ struct simd8float32 {
761
1145
  }
762
1146
 
763
1147
    // Lane-wise multiplication.
    simd8float32 operator*(const simd8float32& other) const {
        return simd8float32{detail::simdlib::binary_func(data, other.data)
                                    .call<&vmulq_f32>()};
    }
767
1151
 
768
1152
    // Lane-wise addition.
    simd8float32 operator+(const simd8float32& other) const {
        return simd8float32{detail::simdlib::binary_func(data, other.data)
                                    .call<&vaddq_f32>()};
    }
772
1156
 
773
1157
    // Lane-wise subtraction.
    simd8float32 operator-(const simd8float32& other) const {
        return simd8float32{detail::simdlib::binary_func(data, other.data)
                                    .call<&vsubq_f32>()};
    }
1161
+
1162
    // In-place lane-wise addition; returns *this.
    simd8float32& operator+=(const simd8float32& other) {
        // In this context, it is more compiler friendly to write intrinsics
        // directly instead of using binary_func
        data.val[0] = vaddq_f32(data.val[0], other.data.val[0]);
        data.val[1] = vaddq_f32(data.val[1], other.data.val[1]);
        return *this;
    }
1169
+
1170
    // True iff all 8 float lanes compare equal under vceqq_f32 (IEEE `==`
    // semantics, so lanes holding NaN never compare equal).
    bool operator==(simd8float32 other) const {
        // vceqq_f32 produces a uint32x4 all-ones mask per equal lane; AND the
        // halves and require the minimum lane to still be all-ones.
        const auto equals =
                detail::simdlib::binary_func<::uint32x4x2_t>(data, other.data)
                        .call<&vceqq_f32>();
        const auto equal = vandq_u32(equals.val[0], equals.val[1]);
        return vminvq_u32(equal) == 0xffffffff;
    }
1177
+
1178
    // Negation of operator==.
    bool operator!=(simd8float32 other) const {
        return !(*this == other);
    }
1181
+
1182
+ // Checks whether the other holds exactly the same bytes.
1183
+ bool is_same_as(simd8float32 other) const {
1184
+ const bool equal0 =
1185
+ (vminvq_u32(vceqq_f32(data.val[0], other.data.val[0])) ==
1186
+ 0xffffffff);
1187
+ const bool equal1 =
1188
+ (vminvq_u32(vceqq_f32(data.val[1], other.data.val[1])) ==
1189
+ 0xffffffff);
1190
+
1191
+ return equal0 && equal1;
776
1192
  }
777
1193
 
778
1194
  std::string tostring() const {
@@ -783,17 +1199,17 @@ struct simd8float32 {
783
1199
// hadd does not cross lanes: pairwise addition (vpaddq_f32) applied to each
// 128-bit half of a and b.
inline simd8float32 hadd(const simd8float32& a, const simd8float32& b) {
    return simd8float32{
            detail::simdlib::binary_func(a.data, b.data).call<&vpaddq_f32>()};
}
788
1204
 
789
1205
// Interleave the low elements of each 128-bit half of a and b (vzip1q_f32
// applied per half).
inline simd8float32 unpacklo(const simd8float32& a, const simd8float32& b) {
    return simd8float32{
            detail::simdlib::binary_func(a.data, b.data).call<&vzip1q_f32>()};
}
793
1209
 
794
1210
// Interleave the high elements of each 128-bit half of a and b (vzip2q_f32
// applied per half).
inline simd8float32 unpackhi(const simd8float32& a, const simd8float32& b) {
    return simd8float32{
            detail::simdlib::binary_func(a.data, b.data).call<&vzip2q_f32>()};
}
798
1214
 
799
1215
  // compute a * b + c
@@ -806,20 +1222,129 @@ inline simd8float32 fmadd(
806
1222
  vfmaq_f32(c.data.val[1], a.data.val[1], b.data.val[1])}};
807
1223
  }
808
1224
 
1225
+ // The following primitive is a vectorized version of the following code
1226
+ // snippet:
1227
+ // float lowestValue = HUGE_VAL;
1228
+ // uint lowestIndex = 0;
1229
+ // for (size_t i = 0; i < n; i++) {
1230
+ // if (values[i] < lowestValue) {
1231
+ // lowestValue = values[i];
1232
+ // lowestIndex = i;
1233
+ // }
1234
+ // }
1235
+ // Vectorized version can be implemented via two operations: cmp and blend
1236
+ // with something like this:
1237
+ // lowestValues = [HUGE_VAL; 8];
1238
+ // lowestIndices = {0, 1, 2, 3, 4, 5, 6, 7};
1239
+ // for (size_t i = 0; i < n; i += 8) {
1240
+ // auto comparison = cmp(values + i, lowestValues);
1241
+ // lowestValues = blend(
1242
+ // comparison,
1243
+ // values + i,
1244
+ // lowestValues);
1245
+ // lowestIndices = blend(
1246
+ // comparison,
1247
+ // i + {0, 1, 2, 3, 4, 5, 6, 7},
1248
+ // lowestIndices);
1249
+ // lowestIndices += {8, 8, 8, 8, 8, 8, 8, 8};
1250
+ // }
1251
+ // The problem is that blend primitive needs very different instruction
1252
+ // order for AVX and ARM.
1253
+ // So, let's introduce a combination of these two in order to avoid
1254
+ // confusion for ppl who write in low-level SIMD instructions. Additionally,
1255
+ // these two ops (cmp and blend) are very often used together.
1256
+ inline void cmplt_and_blend_inplace(
1257
+ const simd8float32 candidateValues,
1258
+ const simd8uint32 candidateIndices,
1259
+ simd8float32& lowestValues,
1260
+ simd8uint32& lowestIndices) {
1261
+ const auto comparison = detail::simdlib::binary_func<::uint32x4x2_t>(
1262
+ candidateValues.data, lowestValues.data)
1263
+ .call<&vcltq_f32>();
1264
+
1265
+ lowestValues.data = float32x4x2_t{
1266
+ vbslq_f32(
1267
+ comparison.val[0],
1268
+ candidateValues.data.val[0],
1269
+ lowestValues.data.val[0]),
1270
+ vbslq_f32(
1271
+ comparison.val[1],
1272
+ candidateValues.data.val[1],
1273
+ lowestValues.data.val[1])};
1274
+ lowestIndices.data = uint32x4x2_t{
1275
+ vbslq_u32(
1276
+ comparison.val[0],
1277
+ candidateIndices.data.val[0],
1278
+ lowestIndices.data.val[0]),
1279
+ vbslq_u32(
1280
+ comparison.val[1],
1281
+ candidateIndices.data.val[1],
1282
+ lowestIndices.data.val[1])};
1283
+ }
1284
+
1285
+ // Vectorized version of the following code:
1286
+ // for (size_t i = 0; i < n; i++) {
1287
+ // bool flag = (candidateValues[i] < currentValues[i]);
1288
+ // minValues[i] = flag ? candidateValues[i] : currentValues[i];
1289
+ // minIndices[i] = flag ? candidateIndices[i] : currentIndices[i];
1290
+ // maxValues[i] = !flag ? candidateValues[i] : currentValues[i];
1291
+ // maxIndices[i] = !flag ? candidateIndices[i] : currentIndices[i];
1292
+ // }
1293
+ // Max indices evaluation is inaccurate in case of equal values (the index of
1294
+ // the last equal value is saved instead of the first one), but this behavior
1295
+ // saves instructions.
1296
+ inline void cmplt_min_max_fast(
1297
+ const simd8float32 candidateValues,
1298
+ const simd8uint32 candidateIndices,
1299
+ const simd8float32 currentValues,
1300
+ const simd8uint32 currentIndices,
1301
+ simd8float32& minValues,
1302
+ simd8uint32& minIndices,
1303
+ simd8float32& maxValues,
1304
+ simd8uint32& maxIndices) {
1305
+ const uint32x4x2_t comparison = uint32x4x2_t{
1306
+ vcltq_f32(candidateValues.data.val[0], currentValues.data.val[0]),
1307
+ vcltq_f32(candidateValues.data.val[1], currentValues.data.val[1])};
1308
+
1309
+ minValues.data = float32x4x2_t{
1310
+ vminq_f32(candidateValues.data.val[0], currentValues.data.val[0]),
1311
+ vminq_f32(candidateValues.data.val[1], currentValues.data.val[1])};
1312
+ minIndices.data = uint32x4x2_t{
1313
+ vbslq_u32(
1314
+ comparison.val[0],
1315
+ candidateIndices.data.val[0],
1316
+ currentIndices.data.val[0]),
1317
+ vbslq_u32(
1318
+ comparison.val[1],
1319
+ candidateIndices.data.val[1],
1320
+ currentIndices.data.val[1])};
1321
+
1322
+ maxValues.data = float32x4x2_t{
1323
+ vmaxq_f32(candidateValues.data.val[0], currentValues.data.val[0]),
1324
+ vmaxq_f32(candidateValues.data.val[1], currentValues.data.val[1])};
1325
+ maxIndices.data = uint32x4x2_t{
1326
+ vbslq_u32(
1327
+ comparison.val[0],
1328
+ currentIndices.data.val[0],
1329
+ candidateIndices.data.val[0]),
1330
+ vbslq_u32(
1331
+ comparison.val[1],
1332
+ currentIndices.data.val[1],
1333
+ candidateIndices.data.val[1])};
1334
+ }
1335
+
809
1336
  namespace {
810
1337
 
811
1338
// get even float32's of a and b, interleaved (vuzp1q_f32 applied to each
// 128-bit half pair).
simd8float32 geteven(const simd8float32& a, const simd8float32& b) {
    return simd8float32{
            detail::simdlib::binary_func(a.data, b.data).call<&vuzp1q_f32>()};
}
817
1343
 
818
1344
// get odd float32's of a and b, interleaved (vuzp2q_f32 applied to each
// 128-bit half pair).
simd8float32 getodd(const simd8float32& a, const simd8float32& b) {
    return simd8float32{
            detail::simdlib::binary_func(a.data, b.data).call<&vuzp2q_f32>()};
}
824
1349
 
825
1350
  // 3 cycles