faiss 0.2.6 → 0.2.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (189)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +4 -0
  3. data/ext/faiss/extconf.rb +1 -1
  4. data/lib/faiss/version.rb +1 -1
  5. data/lib/faiss.rb +2 -2
  6. data/vendor/faiss/faiss/AutoTune.cpp +15 -4
  7. data/vendor/faiss/faiss/AutoTune.h +0 -1
  8. data/vendor/faiss/faiss/Clustering.cpp +1 -5
  9. data/vendor/faiss/faiss/Clustering.h +0 -2
  10. data/vendor/faiss/faiss/IVFlib.h +0 -2
  11. data/vendor/faiss/faiss/Index.h +1 -2
  12. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +17 -3
  13. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +10 -1
  14. data/vendor/faiss/faiss/IndexBinary.h +0 -1
  15. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +2 -1
  16. data/vendor/faiss/faiss/IndexBinaryFlat.h +4 -0
  17. data/vendor/faiss/faiss/IndexBinaryHash.cpp +1 -3
  18. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +273 -48
  19. data/vendor/faiss/faiss/IndexBinaryIVF.h +18 -11
  20. data/vendor/faiss/faiss/IndexFastScan.cpp +13 -10
  21. data/vendor/faiss/faiss/IndexFastScan.h +5 -1
  22. data/vendor/faiss/faiss/IndexFlat.cpp +16 -3
  23. data/vendor/faiss/faiss/IndexFlat.h +1 -1
  24. data/vendor/faiss/faiss/IndexFlatCodes.cpp +5 -0
  25. data/vendor/faiss/faiss/IndexFlatCodes.h +7 -2
  26. data/vendor/faiss/faiss/IndexHNSW.cpp +3 -6
  27. data/vendor/faiss/faiss/IndexHNSW.h +0 -1
  28. data/vendor/faiss/faiss/IndexIDMap.cpp +4 -4
  29. data/vendor/faiss/faiss/IndexIDMap.h +0 -2
  30. data/vendor/faiss/faiss/IndexIVF.cpp +155 -129
  31. data/vendor/faiss/faiss/IndexIVF.h +121 -61
  32. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +2 -2
  33. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +12 -11
  34. data/vendor/faiss/faiss/IndexIVFFastScan.h +6 -1
  35. data/vendor/faiss/faiss/IndexIVFPQ.cpp +221 -165
  36. data/vendor/faiss/faiss/IndexIVFPQ.h +1 -0
  37. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +6 -1
  38. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +0 -2
  39. data/vendor/faiss/faiss/IndexNNDescent.cpp +1 -2
  40. data/vendor/faiss/faiss/IndexNNDescent.h +0 -1
  41. data/vendor/faiss/faiss/IndexNSG.cpp +1 -2
  42. data/vendor/faiss/faiss/IndexPQ.cpp +7 -9
  43. data/vendor/faiss/faiss/IndexRefine.cpp +1 -1
  44. data/vendor/faiss/faiss/IndexReplicas.cpp +3 -4
  45. data/vendor/faiss/faiss/IndexReplicas.h +0 -1
  46. data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +8 -1
  47. data/vendor/faiss/faiss/IndexRowwiseMinMax.h +7 -0
  48. data/vendor/faiss/faiss/IndexShards.cpp +26 -109
  49. data/vendor/faiss/faiss/IndexShards.h +2 -3
  50. data/vendor/faiss/faiss/IndexShardsIVF.cpp +246 -0
  51. data/vendor/faiss/faiss/IndexShardsIVF.h +42 -0
  52. data/vendor/faiss/faiss/MetaIndexes.cpp +86 -0
  53. data/vendor/faiss/faiss/MetaIndexes.h +29 -0
  54. data/vendor/faiss/faiss/MetricType.h +14 -0
  55. data/vendor/faiss/faiss/VectorTransform.cpp +8 -10
  56. data/vendor/faiss/faiss/VectorTransform.h +1 -3
  57. data/vendor/faiss/faiss/clone_index.cpp +232 -18
  58. data/vendor/faiss/faiss/cppcontrib/SaDecodeKernels.h +25 -3
  59. data/vendor/faiss/faiss/cppcontrib/detail/CoarseBitType.h +7 -0
  60. data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +78 -0
  61. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +20 -6
  62. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +7 -1
  63. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-neon-inl.h +21 -7
  64. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMax-inl.h +7 -0
  65. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMaxFP16-inl.h +7 -0
  66. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +10 -3
  67. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +7 -1
  68. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-neon-inl.h +11 -3
  69. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +25 -2
  70. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +76 -29
  71. data/vendor/faiss/faiss/gpu/GpuCloner.h +2 -2
  72. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +14 -13
  73. data/vendor/faiss/faiss/gpu/GpuDistance.h +18 -6
  74. data/vendor/faiss/faiss/gpu/GpuIndex.h +23 -21
  75. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +10 -10
  76. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +11 -12
  77. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +29 -50
  78. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +3 -3
  79. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +8 -8
  80. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +4 -4
  81. data/vendor/faiss/faiss/gpu/impl/IndexUtils.h +2 -5
  82. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +9 -7
  83. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +4 -4
  84. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +2 -2
  85. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +1 -1
  86. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +55 -6
  87. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +20 -6
  88. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +95 -25
  89. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +67 -16
  90. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +4 -4
  91. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +7 -7
  92. data/vendor/faiss/faiss/gpu/test/TestUtils.h +4 -4
  93. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +1 -1
  94. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +6 -0
  95. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +0 -7
  96. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +9 -9
  97. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +1 -1
  98. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +2 -7
  99. data/vendor/faiss/faiss/impl/CodePacker.cpp +67 -0
  100. data/vendor/faiss/faiss/impl/CodePacker.h +71 -0
  101. data/vendor/faiss/faiss/impl/DistanceComputer.h +0 -2
  102. data/vendor/faiss/faiss/impl/HNSW.cpp +3 -7
  103. data/vendor/faiss/faiss/impl/HNSW.h +6 -9
  104. data/vendor/faiss/faiss/impl/IDSelector.cpp +1 -1
  105. data/vendor/faiss/faiss/impl/IDSelector.h +39 -1
  106. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +62 -51
  107. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +11 -12
  108. data/vendor/faiss/faiss/impl/NNDescent.cpp +3 -9
  109. data/vendor/faiss/faiss/impl/NNDescent.h +10 -10
  110. data/vendor/faiss/faiss/impl/NSG.cpp +1 -6
  111. data/vendor/faiss/faiss/impl/NSG.h +4 -7
  112. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +1 -15
  113. data/vendor/faiss/faiss/impl/PolysemousTraining.h +11 -10
  114. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +0 -7
  115. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +25 -12
  116. data/vendor/faiss/faiss/impl/ProductQuantizer.h +2 -4
  117. data/vendor/faiss/faiss/impl/Quantizer.h +6 -3
  118. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +796 -174
  119. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +16 -8
  120. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +3 -5
  121. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +4 -4
  122. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +3 -3
  123. data/vendor/faiss/faiss/impl/ThreadedIndex.h +4 -4
  124. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +291 -0
  125. data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +74 -0
  126. data/vendor/faiss/faiss/impl/code_distance/code_distance.h +123 -0
  127. data/vendor/faiss/faiss/impl/code_distance/code_distance_avx512.h +102 -0
  128. data/vendor/faiss/faiss/impl/index_read.cpp +13 -10
  129. data/vendor/faiss/faiss/impl/index_write.cpp +3 -4
  130. data/vendor/faiss/faiss/impl/kmeans1d.cpp +0 -1
  131. data/vendor/faiss/faiss/impl/kmeans1d.h +3 -3
  132. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -1
  133. data/vendor/faiss/faiss/impl/platform_macros.h +61 -0
  134. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +48 -4
  135. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +18 -4
  136. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +2 -2
  137. data/vendor/faiss/faiss/index_factory.cpp +8 -10
  138. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +29 -12
  139. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +8 -2
  140. data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -1
  141. data/vendor/faiss/faiss/invlists/DirectMap.h +2 -4
  142. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +118 -18
  143. data/vendor/faiss/faiss/invlists/InvertedLists.h +44 -4
  144. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +3 -3
  145. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +1 -1
  146. data/vendor/faiss/faiss/python/python_callbacks.cpp +1 -1
  147. data/vendor/faiss/faiss/python/python_callbacks.h +1 -1
  148. data/vendor/faiss/faiss/utils/AlignedTable.h +3 -1
  149. data/vendor/faiss/faiss/utils/Heap.cpp +139 -3
  150. data/vendor/faiss/faiss/utils/Heap.h +35 -1
  151. data/vendor/faiss/faiss/utils/approx_topk/approx_topk.h +84 -0
  152. data/vendor/faiss/faiss/utils/approx_topk/avx2-inl.h +196 -0
  153. data/vendor/faiss/faiss/utils/approx_topk/generic.h +138 -0
  154. data/vendor/faiss/faiss/utils/approx_topk/mode.h +34 -0
  155. data/vendor/faiss/faiss/utils/approx_topk_hamming/approx_topk_hamming.h +367 -0
  156. data/vendor/faiss/faiss/utils/distances.cpp +61 -7
  157. data/vendor/faiss/faiss/utils/distances.h +11 -0
  158. data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +346 -0
  159. data/vendor/faiss/faiss/utils/distances_fused/avx512.h +36 -0
  160. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +42 -0
  161. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +40 -0
  162. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +352 -0
  163. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.h +32 -0
  164. data/vendor/faiss/faiss/utils/distances_simd.cpp +515 -327
  165. data/vendor/faiss/faiss/utils/extra_distances-inl.h +17 -1
  166. data/vendor/faiss/faiss/utils/extra_distances.cpp +37 -8
  167. data/vendor/faiss/faiss/utils/extra_distances.h +2 -1
  168. data/vendor/faiss/faiss/utils/fp16-fp16c.h +7 -0
  169. data/vendor/faiss/faiss/utils/fp16-inl.h +7 -0
  170. data/vendor/faiss/faiss/utils/fp16.h +7 -0
  171. data/vendor/faiss/faiss/utils/hamming-inl.h +0 -456
  172. data/vendor/faiss/faiss/utils/hamming.cpp +104 -120
  173. data/vendor/faiss/faiss/utils/hamming.h +21 -10
  174. data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +535 -0
  175. data/vendor/faiss/faiss/utils/hamming_distance/common.h +48 -0
  176. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +519 -0
  177. data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +26 -0
  178. data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +614 -0
  179. data/vendor/faiss/faiss/utils/partitioning.cpp +21 -25
  180. data/vendor/faiss/faiss/utils/simdlib_avx2.h +344 -3
  181. data/vendor/faiss/faiss/utils/simdlib_emulated.h +390 -0
  182. data/vendor/faiss/faiss/utils/simdlib_neon.h +655 -130
  183. data/vendor/faiss/faiss/utils/sorting.cpp +692 -0
  184. data/vendor/faiss/faiss/utils/sorting.h +71 -0
  185. data/vendor/faiss/faiss/utils/transpose/transpose-avx2-inl.h +165 -0
  186. data/vendor/faiss/faiss/utils/utils.cpp +4 -176
  187. data/vendor/faiss/faiss/utils/utils.h +2 -9
  188. metadata +29 -3
  189. data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +0 -26
@@ -18,6 +18,8 @@
18
18
 
19
19
  #include <arm_neon.h>
20
20
 
21
+ #include <faiss/impl/FaissAssert.h>
22
+
21
23
  namespace faiss {
22
24
 
23
25
  namespace detail {
@@ -88,6 +90,23 @@ static inline float32x4x2_t reinterpret_f32(const float32x4x2_t& v) {
88
90
  return v;
89
91
  }
90
92
 
93
+ // Surprisingly, vdupq_n_u16 has the type of
94
+ // uint16x8_t (std::uint32_t) , and vdupq_n_u8 also has
95
+ // uint8x16_t (std::uint32_t) on **some environments**.
96
+ // We want argument type as same as the type of element
97
+ // of result vector type (std::uint16_t for uint16x8_t,
98
+ // and std::uint8_t for uint8x16_t) instead of
99
+ // std::uint32_t due to using set1 function templates,
100
+ // so let's fix the argument type here and use these
101
+ // overload below.
102
+ static inline ::uint16x8_t vdupq_n_u16(std::uint16_t v) {
103
+ return ::vdupq_n_u16(v);
104
+ }
105
+
106
+ static inline ::uint8x16_t vdupq_n_u8(std::uint8_t v) {
107
+ return ::vdupq_n_u8(v);
108
+ }
109
+
91
110
  template <
92
111
  typename T,
93
112
  typename U = decltype(reinterpret_u8(std::declval<T>().data))>
@@ -119,11 +138,25 @@ static inline std::string bin(const S& simd) {
119
138
  return std::string(bits);
120
139
  }
121
140
 
122
- template <typename D, typename F, typename T>
123
- static inline void set1(D& d, F&& f, T t) {
124
- const auto v = f(t);
125
- d.val[0] = v;
126
- d.val[1] = v;
141
+ template <typename T>
142
+ using remove_cv_ref_t =
143
+ typename std::remove_reference<typename std::remove_cv<T>::type>::type;
144
+
145
+ template <typename D, typename T>
146
+ struct set1_impl {
147
+ D& d;
148
+ T t;
149
+ template <remove_cv_ref_t<decltype(std::declval<D>().val[0])> (*F)(T)>
150
+ inline void call() {
151
+ const auto v = F(t);
152
+ d.val[0] = v;
153
+ d.val[1] = v;
154
+ }
155
+ };
156
+
157
+ template <typename D, typename T>
158
+ static inline set1_impl<remove_cv_ref_t<D>, T> set1(D& d, T t) {
159
+ return {d, t};
127
160
  }
128
161
 
129
162
  template <typename T, size_t N, typename S>
@@ -142,20 +175,57 @@ static inline std::string elements_to_string(const char* fmt, const S& simd) {
142
175
  return std::string(res);
143
176
  }
144
177
 
145
- template <typename T, typename F>
146
- static inline T unary_func(const T& a, F&& f) {
147
- T t;
148
- t.val[0] = f(a.val[0]);
149
- t.val[1] = f(a.val[1]);
150
- return t;
178
+ template <typename T, typename U>
179
+ struct unary_func_impl {
180
+ const U& a;
181
+ using Telem = remove_cv_ref_t<decltype(std::declval<T>().val[0])>;
182
+ using Uelem = remove_cv_ref_t<decltype(std::declval<U>().val[0])>;
183
+ template <Telem (*F)(Uelem)>
184
+ inline T call() {
185
+ T t;
186
+ t.val[0] = F(a.val[0]);
187
+ t.val[1] = F(a.val[1]);
188
+ return t;
189
+ }
190
+ };
191
+
192
+ template <typename T>
193
+ static inline unary_func_impl<remove_cv_ref_t<T>, remove_cv_ref_t<T>> unary_func(
194
+ const T& a) {
195
+ return {a};
151
196
  }
152
197
 
153
- template <typename T, typename F>
154
- static inline T binary_func(const T& a, const T& b, F&& f) {
155
- T t;
156
- t.val[0] = f(a.val[0], b.val[0]);
157
- t.val[1] = f(a.val[1], b.val[1]);
158
- return t;
198
+ template <typename T, typename U>
199
+ static inline unary_func_impl<remove_cv_ref_t<T>, remove_cv_ref_t<U>> unary_func(
200
+ const U& a) {
201
+ return {a};
202
+ }
203
+
204
+ template <typename T, typename U>
205
+ struct binary_func_impl {
206
+ const U& a;
207
+ const U& b;
208
+ using Telem = remove_cv_ref_t<decltype(std::declval<T>().val[0])>;
209
+ using Uelem = remove_cv_ref_t<decltype(std::declval<U>().val[0])>;
210
+ template <Telem (*F)(Uelem, Uelem)>
211
+ inline T call() {
212
+ T t;
213
+ t.val[0] = F(a.val[0], b.val[0]);
214
+ t.val[1] = F(a.val[1], b.val[1]);
215
+ return t;
216
+ }
217
+ };
218
+
219
+ template <typename T>
220
+ static inline binary_func_impl<remove_cv_ref_t<T>, remove_cv_ref_t<T>>
221
+ binary_func(const T& a, const T& b) {
222
+ return {a, b};
223
+ }
224
+
225
+ template <typename T, typename U>
226
+ static inline binary_func_impl<remove_cv_ref_t<T>, remove_cv_ref_t<U>>
227
+ binary_func(const U& a, const U& b) {
228
+ return {a, b};
159
229
  }
160
230
 
161
231
  static inline uint16_t vmovmask_u8(const uint8x16_t& v) {
@@ -172,8 +242,8 @@ static inline uint32_t cmp_xe32(
172
242
  const uint16x8x2_t& d0,
173
243
  const uint16x8x2_t& d1,
174
244
  const uint16x8x2_t& thr) {
175
- const auto d0_thr = detail::simdlib::binary_func(d0, thr, F);
176
- const auto d1_thr = detail::simdlib::binary_func(d1, thr, F);
245
+ const auto d0_thr = detail::simdlib::binary_func(d0, thr).call<F>();
246
+ const auto d1_thr = detail::simdlib::binary_func(d1, thr).call<F>();
177
247
  const auto d0_mask = vmovmask_u8(
178
248
  vmovn_high_u16(vmovn_u16(d0_thr.val[0]), d0_thr.val[1]));
179
249
  const auto d1_mask = vmovmask_u8(
@@ -207,6 +277,44 @@ struct simd16uint16 {
207
277
 
208
278
  explicit simd16uint16(const uint16x8x2_t& v) : data{v} {}
209
279
 
280
+ explicit simd16uint16(
281
+ uint16_t u0,
282
+ uint16_t u1,
283
+ uint16_t u2,
284
+ uint16_t u3,
285
+ uint16_t u4,
286
+ uint16_t u5,
287
+ uint16_t u6,
288
+ uint16_t u7,
289
+ uint16_t u8,
290
+ uint16_t u9,
291
+ uint16_t u10,
292
+ uint16_t u11,
293
+ uint16_t u12,
294
+ uint16_t u13,
295
+ uint16_t u14,
296
+ uint16_t u15) {
297
+ uint16_t temp[16] = {
298
+ u0,
299
+ u1,
300
+ u2,
301
+ u3,
302
+ u4,
303
+ u5,
304
+ u6,
305
+ u7,
306
+ u8,
307
+ u9,
308
+ u10,
309
+ u11,
310
+ u12,
311
+ u13,
312
+ u14,
313
+ u15};
314
+ data.val[0] = vld1q_u16(temp);
315
+ data.val[1] = vld1q_u16(temp + 8);
316
+ }
317
+
210
318
  template <
211
319
  typename T,
212
320
  typename std::enable_if<
@@ -219,7 +327,8 @@ struct simd16uint16 {
219
327
  : data{vld1q_u16(x), vld1q_u16(x + 8)} {}
220
328
 
221
329
  void clear() {
222
- detail::simdlib::set1(data, &vdupq_n_u16, static_cast<uint16_t>(0));
330
+ detail::simdlib::set1(data, static_cast<uint16_t>(0))
331
+ .call<&detail::simdlib::vdupq_n_u16>();
223
332
  }
224
333
 
225
334
  void storeu(uint16_t* ptr) const {
@@ -257,12 +366,12 @@ struct simd16uint16 {
257
366
  }
258
367
 
259
368
  void set1(uint16_t x) {
260
- detail::simdlib::set1(data, &vdupq_n_u16, x);
369
+ detail::simdlib::set1(data, x).call<&detail::simdlib::vdupq_n_u16>();
261
370
  }
262
371
 
263
372
  simd16uint16 operator*(const simd16uint16& other) const {
264
- return simd16uint16{
265
- detail::simdlib::binary_func(data, other.data, &vmulq_u16)};
373
+ return simd16uint16{detail::simdlib::binary_func(data, other.data)
374
+ .call<&vmulq_u16>()};
266
375
  }
267
376
 
268
377
  // shift must be known at compile time
@@ -271,50 +380,56 @@ struct simd16uint16 {
271
380
  case 0:
272
381
  return *this;
273
382
  case 1:
274
- return simd16uint16{detail::simdlib::unary_func(
275
- data, detail::simdlib::vshrq<1>)};
383
+ return simd16uint16{detail::simdlib::unary_func(data)
384
+ .call<detail::simdlib::vshrq<1>>()};
276
385
  case 2:
277
- return simd16uint16{detail::simdlib::unary_func(
278
- data, detail::simdlib::vshrq<2>)};
386
+ return simd16uint16{detail::simdlib::unary_func(data)
387
+ .call<detail::simdlib::vshrq<2>>()};
279
388
  case 3:
280
- return simd16uint16{detail::simdlib::unary_func(
281
- data, detail::simdlib::vshrq<3>)};
389
+ return simd16uint16{detail::simdlib::unary_func(data)
390
+ .call<detail::simdlib::vshrq<3>>()};
282
391
  case 4:
283
- return simd16uint16{detail::simdlib::unary_func(
284
- data, detail::simdlib::vshrq<4>)};
392
+ return simd16uint16{detail::simdlib::unary_func(data)
393
+ .call<detail::simdlib::vshrq<4>>()};
285
394
  case 5:
286
- return simd16uint16{detail::simdlib::unary_func(
287
- data, detail::simdlib::vshrq<5>)};
395
+ return simd16uint16{detail::simdlib::unary_func(data)
396
+ .call<detail::simdlib::vshrq<5>>()};
288
397
  case 6:
289
- return simd16uint16{detail::simdlib::unary_func(
290
- data, detail::simdlib::vshrq<6>)};
398
+ return simd16uint16{detail::simdlib::unary_func(data)
399
+ .call<detail::simdlib::vshrq<6>>()};
291
400
  case 7:
292
- return simd16uint16{detail::simdlib::unary_func(
293
- data, detail::simdlib::vshrq<7>)};
401
+ return simd16uint16{detail::simdlib::unary_func(data)
402
+ .call<detail::simdlib::vshrq<7>>()};
294
403
  case 8:
295
- return simd16uint16{detail::simdlib::unary_func(
296
- data, detail::simdlib::vshrq<8>)};
404
+ return simd16uint16{detail::simdlib::unary_func(data)
405
+ .call<detail::simdlib::vshrq<8>>()};
297
406
  case 9:
298
- return simd16uint16{detail::simdlib::unary_func(
299
- data, detail::simdlib::vshrq<9>)};
407
+ return simd16uint16{detail::simdlib::unary_func(data)
408
+ .call<detail::simdlib::vshrq<9>>()};
300
409
  case 10:
301
- return simd16uint16{detail::simdlib::unary_func(
302
- data, detail::simdlib::vshrq<10>)};
410
+ return simd16uint16{
411
+ detail::simdlib::unary_func(data)
412
+ .call<detail::simdlib::vshrq<10>>()};
303
413
  case 11:
304
- return simd16uint16{detail::simdlib::unary_func(
305
- data, detail::simdlib::vshrq<11>)};
414
+ return simd16uint16{
415
+ detail::simdlib::unary_func(data)
416
+ .call<detail::simdlib::vshrq<11>>()};
306
417
  case 12:
307
- return simd16uint16{detail::simdlib::unary_func(
308
- data, detail::simdlib::vshrq<12>)};
418
+ return simd16uint16{
419
+ detail::simdlib::unary_func(data)
420
+ .call<detail::simdlib::vshrq<12>>()};
309
421
  case 13:
310
- return simd16uint16{detail::simdlib::unary_func(
311
- data, detail::simdlib::vshrq<13>)};
422
+ return simd16uint16{
423
+ detail::simdlib::unary_func(data)
424
+ .call<detail::simdlib::vshrq<13>>()};
312
425
  case 14:
313
- return simd16uint16{detail::simdlib::unary_func(
314
- data, detail::simdlib::vshrq<14>)};
426
+ return simd16uint16{
427
+ detail::simdlib::unary_func(data)
428
+ .call<detail::simdlib::vshrq<14>>()};
315
429
  case 15:
316
- return simd16uint16{detail::simdlib::unary_func(
317
- data, detail::simdlib::vshrq<15>)};
430
+ return simd16uint16{
431
+ detail::simdlib::unary_func(data)
432
+ .call<detail::simdlib::vshrq<15>>()};
318
433
  default:
319
434
  FAISS_THROW_FMT("Invalid shift %d", shift);
320
435
  }
@@ -326,50 +441,56 @@ struct simd16uint16 {
326
441
  case 0:
327
442
  return *this;
328
443
  case 1:
329
- return simd16uint16{detail::simdlib::unary_func(
330
- data, detail::simdlib::vshlq<1>)};
444
+ return simd16uint16{detail::simdlib::unary_func(data)
445
+ .call<detail::simdlib::vshlq<1>>()};
331
446
  case 2:
332
- return simd16uint16{detail::simdlib::unary_func(
333
- data, detail::simdlib::vshlq<2>)};
447
+ return simd16uint16{detail::simdlib::unary_func(data)
448
+ .call<detail::simdlib::vshlq<2>>()};
334
449
  case 3:
335
- return simd16uint16{detail::simdlib::unary_func(
336
- data, detail::simdlib::vshlq<3>)};
450
+ return simd16uint16{detail::simdlib::unary_func(data)
451
+ .call<detail::simdlib::vshlq<3>>()};
337
452
  case 4:
338
- return simd16uint16{detail::simdlib::unary_func(
339
- data, detail::simdlib::vshlq<4>)};
453
+ return simd16uint16{detail::simdlib::unary_func(data)
454
+ .call<detail::simdlib::vshlq<4>>()};
340
455
  case 5:
341
- return simd16uint16{detail::simdlib::unary_func(
342
- data, detail::simdlib::vshlq<5>)};
456
+ return simd16uint16{detail::simdlib::unary_func(data)
457
+ .call<detail::simdlib::vshlq<5>>()};
343
458
  case 6:
344
- return simd16uint16{detail::simdlib::unary_func(
345
- data, detail::simdlib::vshlq<6>)};
459
+ return simd16uint16{detail::simdlib::unary_func(data)
460
+ .call<detail::simdlib::vshlq<6>>()};
346
461
  case 7:
347
- return simd16uint16{detail::simdlib::unary_func(
348
- data, detail::simdlib::vshlq<7>)};
462
+ return simd16uint16{detail::simdlib::unary_func(data)
463
+ .call<detail::simdlib::vshlq<7>>()};
349
464
  case 8:
350
- return simd16uint16{detail::simdlib::unary_func(
351
- data, detail::simdlib::vshlq<8>)};
465
+ return simd16uint16{detail::simdlib::unary_func(data)
466
+ .call<detail::simdlib::vshlq<8>>()};
352
467
  case 9:
353
- return simd16uint16{detail::simdlib::unary_func(
354
- data, detail::simdlib::vshlq<9>)};
468
+ return simd16uint16{detail::simdlib::unary_func(data)
469
+ .call<detail::simdlib::vshlq<9>>()};
355
470
  case 10:
356
- return simd16uint16{detail::simdlib::unary_func(
357
- data, detail::simdlib::vshlq<10>)};
471
+ return simd16uint16{
472
+ detail::simdlib::unary_func(data)
473
+ .call<detail::simdlib::vshlq<10>>()};
358
474
  case 11:
359
- return simd16uint16{detail::simdlib::unary_func(
360
- data, detail::simdlib::vshlq<11>)};
475
+ return simd16uint16{
476
+ detail::simdlib::unary_func(data)
477
+ .call<detail::simdlib::vshlq<11>>()};
361
478
  case 12:
362
- return simd16uint16{detail::simdlib::unary_func(
363
- data, detail::simdlib::vshlq<12>)};
479
+ return simd16uint16{
480
+ detail::simdlib::unary_func(data)
481
+ .call<detail::simdlib::vshlq<12>>()};
364
482
  case 13:
365
- return simd16uint16{detail::simdlib::unary_func(
366
- data, detail::simdlib::vshlq<13>)};
483
+ return simd16uint16{
484
+ detail::simdlib::unary_func(data)
485
+ .call<detail::simdlib::vshlq<13>>()};
367
486
  case 14:
368
- return simd16uint16{detail::simdlib::unary_func(
369
- data, detail::simdlib::vshlq<14>)};
487
+ return simd16uint16{
488
+ detail::simdlib::unary_func(data)
489
+ .call<detail::simdlib::vshlq<14>>()};
370
490
  case 15:
371
- return simd16uint16{detail::simdlib::unary_func(
372
- data, detail::simdlib::vshlq<15>)};
491
+ return simd16uint16{
492
+ detail::simdlib::unary_func(data)
493
+ .call<detail::simdlib::vshlq<15>>()};
373
494
  default:
374
495
  FAISS_THROW_FMT("Invalid shift %d", shift);
375
496
  }
@@ -386,13 +507,13 @@ struct simd16uint16 {
386
507
  }
387
508
 
388
509
  simd16uint16 operator+(const simd16uint16& other) const {
389
- return simd16uint16{
390
- detail::simdlib::binary_func(data, other.data, &vaddq_u16)};
510
+ return simd16uint16{detail::simdlib::binary_func(data, other.data)
511
+ .call<&vaddq_u16>()};
391
512
  }
392
513
 
393
514
  simd16uint16 operator-(const simd16uint16& other) const {
394
- return simd16uint16{
395
- detail::simdlib::binary_func(data, other.data, &vsubq_u16)};
515
+ return simd16uint16{detail::simdlib::binary_func(data, other.data)
516
+ .call<&vsubq_u16>()};
396
517
  }
397
518
 
398
519
  template <
@@ -401,10 +522,10 @@ struct simd16uint16 {
401
522
  detail::simdlib::is_simd256bit<T>::value,
402
523
  std::nullptr_t>::type = nullptr>
403
524
  simd16uint16 operator&(const T& other) const {
404
- return simd16uint16{detail::simdlib::binary_func(
405
- data,
406
- detail::simdlib::reinterpret_u16(other.data),
407
- &vandq_u16)};
525
+ return simd16uint16{
526
+ detail::simdlib::binary_func(
527
+ data, detail::simdlib::reinterpret_u16(other.data))
528
+ .template call<&vandq_u16>()};
408
529
  }
409
530
 
410
531
  template <
@@ -413,20 +534,45 @@ struct simd16uint16 {
413
534
  detail::simdlib::is_simd256bit<T>::value,
414
535
  std::nullptr_t>::type = nullptr>
415
536
  simd16uint16 operator|(const T& other) const {
416
- return simd16uint16{detail::simdlib::binary_func(
417
- data,
418
- detail::simdlib::reinterpret_u16(other.data),
419
- &vorrq_u16)};
537
+ return simd16uint16{
538
+ detail::simdlib::binary_func(
539
+ data, detail::simdlib::reinterpret_u16(other.data))
540
+ .template call<&vorrq_u16>()};
541
+ }
542
+
543
+ template <
544
+ typename T,
545
+ typename std::enable_if<
546
+ detail::simdlib::is_simd256bit<T>::value,
547
+ std::nullptr_t>::type = nullptr>
548
+ simd16uint16 operator^(const T& other) const {
549
+ return simd16uint16{
550
+ detail::simdlib::binary_func(
551
+ data, detail::simdlib::reinterpret_u16(other.data))
552
+ .template call<&veorq_u16>()};
420
553
  }
421
554
 
422
555
  // returns binary masks
423
556
  simd16uint16 operator==(const simd16uint16& other) const {
424
- return simd16uint16{
425
- detail::simdlib::binary_func(data, other.data, &vceqq_u16)};
557
+ return simd16uint16{detail::simdlib::binary_func(data, other.data)
558
+ .call<&vceqq_u16>()};
559
+ }
560
+
561
+ // Checks whether the other holds exactly the same bytes.
562
+ bool is_same_as(simd16uint16 other) const {
563
+ const bool equal0 =
564
+ (vminvq_u16(vceqq_u16(data.val[0], other.data.val[0])) ==
565
+ 0xffff);
566
+ const bool equal1 =
567
+ (vminvq_u16(vceqq_u16(data.val[1], other.data.val[1])) ==
568
+ 0xffff);
569
+
570
+ return equal0 && equal1;
426
571
  }
427
572
 
428
573
  simd16uint16 operator~() const {
429
- return simd16uint16{detail::simdlib::unary_func(data, &vmvnq_u16)};
574
+ return simd16uint16{
575
+ detail::simdlib::unary_func(data).call<&vmvnq_u16>()};
430
576
  }
431
577
 
432
578
  // get scalar at index 0
@@ -437,8 +583,8 @@ struct simd16uint16 {
437
583
  // mask of elements where this >= thresh
438
584
  // 2 bit per component: 16 * 2 = 32 bit
439
585
  uint32_t ge_mask(const simd16uint16& thresh) const {
440
- const auto input =
441
- detail::simdlib::binary_func(data, thresh.data, &vcgeq_u16);
586
+ const auto input = detail::simdlib::binary_func(data, thresh.data)
587
+ .call<&vcgeq_u16>();
442
588
  const auto vmovmask_u16 = [](uint16x8_t v) -> uint16_t {
443
589
  uint16_t d[8];
444
590
  const auto v2 = vreinterpretq_u32_u16(vshrq_n_u16(v, 14));
@@ -471,23 +617,25 @@ struct simd16uint16 {
471
617
  }
472
618
 
473
619
  void accu_min(const simd16uint16& incoming) {
474
- data = detail::simdlib::binary_func(incoming.data, data, &vminq_u16);
620
+ data = detail::simdlib::binary_func(incoming.data, data)
621
+ .call<&vminq_u16>();
475
622
  }
476
623
 
477
624
  void accu_max(const simd16uint16& incoming) {
478
- data = detail::simdlib::binary_func(incoming.data, data, &vmaxq_u16);
625
+ data = detail::simdlib::binary_func(incoming.data, data)
626
+ .call<&vmaxq_u16>();
479
627
  }
480
628
  };
481
629
 
482
630
  // not really a std::min because it returns an elementwise min
483
631
  inline simd16uint16 min(const simd16uint16& av, const simd16uint16& bv) {
484
632
  return simd16uint16{
485
- detail::simdlib::binary_func(av.data, bv.data, &vminq_u16)};
633
+ detail::simdlib::binary_func(av.data, bv.data).call<&vminq_u16>()};
486
634
  }
487
635
 
488
636
  inline simd16uint16 max(const simd16uint16& av, const simd16uint16& bv) {
489
637
  return simd16uint16{
490
- detail::simdlib::binary_func(av.data, bv.data, &vmaxq_u16)};
638
+ detail::simdlib::binary_func(av.data, bv.data).call<&vmaxq_u16>()};
491
639
  }
492
640
 
493
641
  // decompose in 128-lanes: a = (a0, a1), b = (b0, b1)
@@ -515,6 +663,63 @@ inline uint32_t cmp_le32(
515
663
  return detail::simdlib::cmp_xe32<&vcleq_u16>(d0.data, d1.data, thr.data);
516
664
  }
517
665
 
666
+ // hadd does not cross lanes
667
+ inline simd16uint16 hadd(const simd16uint16& a, const simd16uint16& b) {
668
+ return simd16uint16{
669
+ detail::simdlib::binary_func(a.data, b.data).call<&vpaddq_u16>()};
670
+ }
671
+
672
+ // Vectorized version of the following code:
673
+ // for (size_t i = 0; i < n; i++) {
674
+ // bool flag = (candidateValues[i] < currentValues[i]);
675
+ // minValues[i] = flag ? candidateValues[i] : currentValues[i];
676
+ // minIndices[i] = flag ? candidateIndices[i] : currentIndices[i];
677
+ // maxValues[i] = !flag ? candidateValues[i] : currentValues[i];
678
+ // maxIndices[i] = !flag ? candidateIndices[i] : currentIndices[i];
679
+ // }
680
+ // Max indices evaluation is inaccurate in case of equal values (the index of
681
+ // the last equal value is saved instead of the first one), but this behavior
682
+ // saves instructions.
683
+ inline void cmplt_min_max_fast(
684
+ const simd16uint16 candidateValues,
685
+ const simd16uint16 candidateIndices,
686
+ const simd16uint16 currentValues,
687
+ const simd16uint16 currentIndices,
688
+ simd16uint16& minValues,
689
+ simd16uint16& minIndices,
690
+ simd16uint16& maxValues,
691
+ simd16uint16& maxIndices) {
692
+ const uint16x8x2_t comparison = uint16x8x2_t{
693
+ vcltq_u16(candidateValues.data.val[0], currentValues.data.val[0]),
694
+ vcltq_u16(candidateValues.data.val[1], currentValues.data.val[1])};
695
+
696
+ minValues.data = uint16x8x2_t{
697
+ vminq_u16(candidateValues.data.val[0], currentValues.data.val[0]),
698
+ vminq_u16(candidateValues.data.val[1], currentValues.data.val[1])};
699
+ minIndices.data = uint16x8x2_t{
700
+ vbslq_u16(
701
+ comparison.val[0],
702
+ candidateIndices.data.val[0],
703
+ currentIndices.data.val[0]),
704
+ vbslq_u16(
705
+ comparison.val[1],
706
+ candidateIndices.data.val[1],
707
+ currentIndices.data.val[1])};
708
+
709
+ maxValues.data = uint16x8x2_t{
710
+ vmaxq_u16(candidateValues.data.val[0], currentValues.data.val[0]),
711
+ vmaxq_u16(candidateValues.data.val[1], currentValues.data.val[1])};
712
+ maxIndices.data = uint16x8x2_t{
713
+ vbslq_u16(
714
+ comparison.val[0],
715
+ currentIndices.data.val[0],
716
+ candidateIndices.data.val[0]),
717
+ vbslq_u16(
718
+ comparison.val[1],
719
+ currentIndices.data.val[1],
720
+ candidateIndices.data.val[1])};
721
+ }
722
+
518
723
  // vector of 32 unsigned 8-bit integers
519
724
  struct simd32uint8 {
520
725
  uint8x16x2_t data;
@@ -527,6 +732,47 @@ struct simd32uint8 {
527
732
 
528
733
  explicit simd32uint8(const uint8x16x2_t& v) : data{v} {}
529
734
 
735
+ template <
736
+ uint8_t _0,
737
+ uint8_t _1,
738
+ uint8_t _2,
739
+ uint8_t _3,
740
+ uint8_t _4,
741
+ uint8_t _5,
742
+ uint8_t _6,
743
+ uint8_t _7,
744
+ uint8_t _8,
745
+ uint8_t _9,
746
+ uint8_t _10,
747
+ uint8_t _11,
748
+ uint8_t _12,
749
+ uint8_t _13,
750
+ uint8_t _14,
751
+ uint8_t _15,
752
+ uint8_t _16,
753
+ uint8_t _17,
754
+ uint8_t _18,
755
+ uint8_t _19,
756
+ uint8_t _20,
757
+ uint8_t _21,
758
+ uint8_t _22,
759
+ uint8_t _23,
760
+ uint8_t _24,
761
+ uint8_t _25,
762
+ uint8_t _26,
763
+ uint8_t _27,
764
+ uint8_t _28,
765
+ uint8_t _29,
766
+ uint8_t _30,
767
+ uint8_t _31>
768
+ static simd32uint8 create() {
769
+ constexpr uint8_t ds[32] = {_0, _1, _2, _3, _4, _5, _6, _7,
770
+ _8, _9, _10, _11, _12, _13, _14, _15,
771
+ _16, _17, _18, _19, _20, _21, _22, _23,
772
+ _24, _25, _26, _27, _28, _29, _30, _31};
773
+ return simd32uint8{ds};
774
+ }
775
+
530
776
  template <
531
777
  typename T,
532
778
  typename std::enable_if<
@@ -539,7 +785,8 @@ struct simd32uint8 {
539
785
  : data{vld1q_u8(x), vld1q_u8(x + 16)} {}
540
786
 
541
787
  void clear() {
542
- detail::simdlib::set1(data, &vdupq_n_u8, static_cast<uint8_t>(0));
788
+ detail::simdlib::set1(data, static_cast<uint8_t>(0))
789
+ .call<&detail::simdlib::vdupq_n_u8>();
543
790
  }
544
791
 
545
792
  void storeu(uint8_t* ptr) const {
@@ -582,7 +829,7 @@ struct simd32uint8 {
582
829
  }
583
830
 
584
831
  void set1(uint8_t x) {
585
- detail::simdlib::set1(data, &vdupq_n_u8, x);
832
+ detail::simdlib::set1(data, x).call<&detail::simdlib::vdupq_n_u8>();
586
833
  }
587
834
 
588
835
  template <
@@ -591,19 +838,21 @@ struct simd32uint8 {
591
838
  detail::simdlib::is_simd256bit<T>::value,
592
839
  std::nullptr_t>::type = nullptr>
593
840
  simd32uint8 operator&(const T& other) const {
594
- return simd32uint8{detail::simdlib::binary_func(
595
- data, detail::simdlib::reinterpret_u8(other.data), &vandq_u8)};
841
+ return simd32uint8{
842
+ detail::simdlib::binary_func(
843
+ data, detail::simdlib::reinterpret_u8(other.data))
844
+ .template call<&vandq_u8>()};
596
845
  }
597
846
 
598
847
  simd32uint8 operator+(const simd32uint8& other) const {
599
- return simd32uint8{
600
- detail::simdlib::binary_func(data, other.data, &vaddq_u8)};
848
+ return simd32uint8{detail::simdlib::binary_func(data, other.data)
849
+ .call<&vaddq_u8>()};
601
850
  }
602
851
 
603
852
  // The very important operation that everything relies on
604
853
  simd32uint8 lookup_2_lanes(const simd32uint8& idx) const {
605
- return simd32uint8{
606
- detail::simdlib::binary_func(data, idx.data, &vqtbl1q_u8)};
854
+ return simd32uint8{detail::simdlib::binary_func(data, idx.data)
855
+ .call<&vqtbl1q_u8>()};
607
856
  }
608
857
 
609
858
  simd32uint8 operator+=(const simd32uint8& other) {
@@ -618,6 +867,16 @@ struct simd32uint8 {
618
867
  vst1q_u8(tab, data.val[high]);
619
868
  return tab[i - high * 16];
620
869
  }
870
+
871
+ // Checks whether the other holds exactly the same bytes.
872
+ bool is_same_as(simd32uint8 other) const {
873
+ const bool equal0 =
874
+ (vminvq_u8(vceqq_u8(data.val[0], other.data.val[0])) == 0xff);
875
+ const bool equal1 =
876
+ (vminvq_u8(vceqq_u8(data.val[1], other.data.val[1])) == 0xff);
877
+
878
+ return equal0 && equal1;
879
+ }
621
880
  };
622
881
 
623
882
  // convert with saturation
@@ -671,8 +930,62 @@ struct simd8uint32 {
671
930
 
672
931
  explicit simd8uint32(const uint8_t* x) : simd8uint32(simd32uint8(x)) {}
673
932
 
933
+ explicit simd8uint32(
934
+ uint32_t u0,
935
+ uint32_t u1,
936
+ uint32_t u2,
937
+ uint32_t u3,
938
+ uint32_t u4,
939
+ uint32_t u5,
940
+ uint32_t u6,
941
+ uint32_t u7) {
942
+ uint32_t temp[8] = {u0, u1, u2, u3, u4, u5, u6, u7};
943
+ data.val[0] = vld1q_u32(temp);
944
+ data.val[1] = vld1q_u32(temp + 4);
945
+ }
946
+
947
+ simd8uint32 operator+(simd8uint32 other) const {
948
+ return simd8uint32{detail::simdlib::binary_func(data, other.data)
949
+ .call<&vaddq_u32>()};
950
+ }
951
+
952
+ simd8uint32 operator-(simd8uint32 other) const {
953
+ return simd8uint32{detail::simdlib::binary_func(data, other.data)
954
+ .call<&vsubq_u32>()};
955
+ }
956
+
957
+ simd8uint32& operator+=(const simd8uint32& other) {
958
+ data.val[0] = vaddq_u32(data.val[0], other.data.val[0]);
959
+ data.val[1] = vaddq_u32(data.val[1], other.data.val[1]);
960
+ return *this;
961
+ }
962
+
963
+ bool operator==(simd8uint32 other) const {
964
+ const auto equals = detail::simdlib::binary_func(data, other.data)
965
+ .call<&vceqq_u32>();
966
+ const auto equal = vandq_u32(equals.val[0], equals.val[1]);
967
+ return vminvq_u32(equal) == 0xffffffff;
968
+ }
969
+
970
+ bool operator!=(simd8uint32 other) const {
971
+ return !(*this == other);
972
+ }
973
+
974
+ // Checks whether the other holds exactly the same bytes.
975
+ bool is_same_as(simd8uint32 other) const {
976
+ const bool equal0 =
977
+ (vminvq_u32(vceqq_u32(data.val[0], other.data.val[0])) ==
978
+ 0xffffffff);
979
+ const bool equal1 =
980
+ (vminvq_u32(vceqq_u32(data.val[1], other.data.val[1])) ==
981
+ 0xffffffff);
982
+
983
+ return equal0 && equal1;
984
+ }
985
+
674
986
  void clear() {
675
- detail::simdlib::set1(data, &vdupq_n_u32, static_cast<uint32_t>(0));
987
+ detail::simdlib::set1(data, static_cast<uint32_t>(0))
988
+ .call<&vdupq_n_u32>();
676
989
  }
677
990
 
678
991
  void storeu(uint32_t* ptr) const {
@@ -710,10 +1023,67 @@ struct simd8uint32 {
710
1023
  }
711
1024
 
712
1025
  void set1(uint32_t x) {
713
- detail::simdlib::set1(data, &vdupq_n_u32, x);
1026
+ detail::simdlib::set1(data, x).call<&vdupq_n_u32>();
1027
+ }
1028
+
1029
+ simd8uint32 unzip() const {
1030
+ return simd8uint32{uint32x4x2_t{
1031
+ vuzp1q_u32(data.val[0], data.val[1]),
1032
+ vuzp2q_u32(data.val[0], data.val[1])}};
714
1033
  }
715
1034
  };
716
1035
 
1036
+ // Vectorized version of the following code:
1037
+ // for (size_t i = 0; i < n; i++) {
1038
+ // bool flag = (candidateValues[i] < currentValues[i]);
1039
+ // minValues[i] = flag ? candidateValues[i] : currentValues[i];
1040
+ // minIndices[i] = flag ? candidateIndices[i] : currentIndices[i];
1041
+ // maxValues[i] = !flag ? candidateValues[i] : currentValues[i];
1042
+ // maxIndices[i] = !flag ? candidateIndices[i] : currentIndices[i];
1043
+ // }
1044
+ // Max indices evaluation is inaccurate in case of equal values (the index of
1045
+ // the last equal value is saved instead of the first one), but this behavior
1046
+ // saves instructions.
1047
+ inline void cmplt_min_max_fast(
1048
+ const simd8uint32 candidateValues,
1049
+ const simd8uint32 candidateIndices,
1050
+ const simd8uint32 currentValues,
1051
+ const simd8uint32 currentIndices,
1052
+ simd8uint32& minValues,
1053
+ simd8uint32& minIndices,
1054
+ simd8uint32& maxValues,
1055
+ simd8uint32& maxIndices) {
1056
+ const uint32x4x2_t comparison = uint32x4x2_t{
1057
+ vcltq_u32(candidateValues.data.val[0], currentValues.data.val[0]),
1058
+ vcltq_u32(candidateValues.data.val[1], currentValues.data.val[1])};
1059
+
1060
+ minValues.data = uint32x4x2_t{
1061
+ vminq_u32(candidateValues.data.val[0], currentValues.data.val[0]),
1062
+ vminq_u32(candidateValues.data.val[1], currentValues.data.val[1])};
1063
+ minIndices.data = uint32x4x2_t{
1064
+ vbslq_u32(
1065
+ comparison.val[0],
1066
+ candidateIndices.data.val[0],
1067
+ currentIndices.data.val[0]),
1068
+ vbslq_u32(
1069
+ comparison.val[1],
1070
+ candidateIndices.data.val[1],
1071
+ currentIndices.data.val[1])};
1072
+
1073
+ maxValues.data = uint32x4x2_t{
1074
+ vmaxq_u32(candidateValues.data.val[0], currentValues.data.val[0]),
1075
+ vmaxq_u32(candidateValues.data.val[1], currentValues.data.val[1])};
1076
+ maxIndices.data = uint32x4x2_t{
1077
+ vbslq_u32(
1078
+ comparison.val[0],
1079
+ currentIndices.data.val[0],
1080
+ candidateIndices.data.val[0]),
1081
+ vbslq_u32(
1082
+ comparison.val[1],
1083
+ currentIndices.data.val[1],
1084
+ candidateIndices.data.val[1])};
1085
+ }
1086
+
717
1087
  struct simd8float32 {
718
1088
  float32x4x2_t data;
719
1089
 
@@ -734,8 +1104,22 @@ struct simd8float32 {
734
1104
  explicit simd8float32(const float* x)
735
1105
  : data{vld1q_f32(x), vld1q_f32(x + 4)} {}
736
1106
 
1107
+ explicit simd8float32(
1108
+ float f0,
1109
+ float f1,
1110
+ float f2,
1111
+ float f3,
1112
+ float f4,
1113
+ float f5,
1114
+ float f6,
1115
+ float f7) {
1116
+ float temp[8] = {f0, f1, f2, f3, f4, f5, f6, f7};
1117
+ data.val[0] = vld1q_f32(temp);
1118
+ data.val[1] = vld1q_f32(temp + 4);
1119
+ }
1120
+
737
1121
  void clear() {
738
- detail::simdlib::set1(data, &vdupq_n_f32, 0.f);
1122
+ detail::simdlib::set1(data, 0.f).call<&vdupq_n_f32>();
739
1123
  }
740
1124
 
741
1125
  void storeu(float* ptr) const {
@@ -761,18 +1145,50 @@ struct simd8float32 {
761
1145
  }
762
1146
 
763
1147
  simd8float32 operator*(const simd8float32& other) const {
764
- return simd8float32{
765
- detail::simdlib::binary_func(data, other.data, &vmulq_f32)};
1148
+ return simd8float32{detail::simdlib::binary_func(data, other.data)
1149
+ .call<&vmulq_f32>()};
766
1150
  }
767
1151
 
768
1152
  simd8float32 operator+(const simd8float32& other) const {
769
- return simd8float32{
770
- detail::simdlib::binary_func(data, other.data, &vaddq_f32)};
1153
+ return simd8float32{detail::simdlib::binary_func(data, other.data)
1154
+ .call<&vaddq_f32>()};
771
1155
  }
772
1156
 
773
1157
  simd8float32 operator-(const simd8float32& other) const {
774
- return simd8float32{
775
- detail::simdlib::binary_func(data, other.data, &vsubq_f32)};
1158
+ return simd8float32{detail::simdlib::binary_func(data, other.data)
1159
+ .call<&vsubq_f32>()};
1160
+ }
1161
+
1162
+ simd8float32& operator+=(const simd8float32& other) {
1163
+ // In this context, it is more compiler friendly to write intrinsics
1164
+ // directly instead of using binary_func
1165
+ data.val[0] = vaddq_f32(data.val[0], other.data.val[0]);
1166
+ data.val[1] = vaddq_f32(data.val[1], other.data.val[1]);
1167
+ return *this;
1168
+ }
1169
+
1170
+ bool operator==(simd8float32 other) const {
1171
+ const auto equals =
1172
+ detail::simdlib::binary_func<::uint32x4x2_t>(data, other.data)
1173
+ .call<&vceqq_f32>();
1174
+ const auto equal = vandq_u32(equals.val[0], equals.val[1]);
1175
+ return vminvq_u32(equal) == 0xffffffff;
1176
+ }
1177
+
1178
+ bool operator!=(simd8float32 other) const {
1179
+ return !(*this == other);
1180
+ }
1181
+
1182
+ // Checks whether the other holds exactly the same bytes.
1183
+ bool is_same_as(simd8float32 other) const {
1184
+ const bool equal0 =
1185
+ (vminvq_u32(vceqq_f32(data.val[0], other.data.val[0])) ==
1186
+ 0xffffffff);
1187
+ const bool equal1 =
1188
+ (vminvq_u32(vceqq_f32(data.val[1], other.data.val[1])) ==
1189
+ 0xffffffff);
1190
+
1191
+ return equal0 && equal1;
776
1192
  }
777
1193
 
778
1194
  std::string tostring() const {
@@ -783,17 +1199,17 @@ struct simd8float32 {
783
1199
  // hadd does not cross lanes
784
1200
  inline simd8float32 hadd(const simd8float32& a, const simd8float32& b) {
785
1201
  return simd8float32{
786
- detail::simdlib::binary_func(a.data, b.data, &vpaddq_f32)};
1202
+ detail::simdlib::binary_func(a.data, b.data).call<&vpaddq_f32>()};
787
1203
  }
788
1204
 
789
1205
  inline simd8float32 unpacklo(const simd8float32& a, const simd8float32& b) {
790
1206
  return simd8float32{
791
- detail::simdlib::binary_func(a.data, b.data, &vzip1q_f32)};
1207
+ detail::simdlib::binary_func(a.data, b.data).call<&vzip1q_f32>()};
792
1208
  }
793
1209
 
794
1210
  inline simd8float32 unpackhi(const simd8float32& a, const simd8float32& b) {
795
1211
  return simd8float32{
796
- detail::simdlib::binary_func(a.data, b.data, &vzip2q_f32)};
1212
+ detail::simdlib::binary_func(a.data, b.data).call<&vzip2q_f32>()};
797
1213
  }
798
1214
 
799
1215
  // compute a * b + c
@@ -806,20 +1222,129 @@ inline simd8float32 fmadd(
806
1222
  vfmaq_f32(c.data.val[1], a.data.val[1], b.data.val[1])}};
807
1223
  }
808
1224
 
1225
+ // The following primitive is a vectorized version of the following code
1226
+ // snippet:
1227
+ // float lowestValue = HUGE_VAL;
1228
+ // uint lowestIndex = 0;
1229
+ // for (size_t i = 0; i < n; i++) {
1230
+ // if (values[i] < lowestValue) {
1231
+ // lowestValue = values[i];
1232
+ // lowestIndex = i;
1233
+ // }
1234
+ // }
1235
+ // Vectorized version can be implemented via two operations: cmp and blend
1236
+ // with something like this:
1237
+ // lowestValues = [HUGE_VAL; 8];
1238
+ // lowestIndices = {0, 1, 2, 3, 4, 5, 6, 7};
1239
+ // for (size_t i = 0; i < n; i += 8) {
1240
+ // auto comparison = cmp(values + i, lowestValues);
1241
+ // lowestValues = blend(
1242
+ // comparison,
1243
+ // values + i,
1244
+ // lowestValues);
1245
+ // lowestIndices = blend(
1246
+ // comparison,
1247
+ // i + {0, 1, 2, 3, 4, 5, 6, 7},
1248
+ // lowestIndices);
1249
+ // lowestIndices += {8, 8, 8, 8, 8, 8, 8, 8};
1250
+ // }
1251
+ // The problem is that blend primitive needs very different instruction
1252
+ // order for AVX and ARM.
1253
+ // So, let's introduce a combination of these two in order to avoid
1254
+ // confusion for ppl who write in low-level SIMD instructions. Additionally,
1255
+ // these two ops (cmp and blend) are very often used together.
1256
+ inline void cmplt_and_blend_inplace(
1257
+ const simd8float32 candidateValues,
1258
+ const simd8uint32 candidateIndices,
1259
+ simd8float32& lowestValues,
1260
+ simd8uint32& lowestIndices) {
1261
+ const auto comparison = detail::simdlib::binary_func<::uint32x4x2_t>(
1262
+ candidateValues.data, lowestValues.data)
1263
+ .call<&vcltq_f32>();
1264
+
1265
+ lowestValues.data = float32x4x2_t{
1266
+ vbslq_f32(
1267
+ comparison.val[0],
1268
+ candidateValues.data.val[0],
1269
+ lowestValues.data.val[0]),
1270
+ vbslq_f32(
1271
+ comparison.val[1],
1272
+ candidateValues.data.val[1],
1273
+ lowestValues.data.val[1])};
1274
+ lowestIndices.data = uint32x4x2_t{
1275
+ vbslq_u32(
1276
+ comparison.val[0],
1277
+ candidateIndices.data.val[0],
1278
+ lowestIndices.data.val[0]),
1279
+ vbslq_u32(
1280
+ comparison.val[1],
1281
+ candidateIndices.data.val[1],
1282
+ lowestIndices.data.val[1])};
1283
+ }
1284
+
1285
+ // Vectorized version of the following code:
1286
+ // for (size_t i = 0; i < n; i++) {
1287
+ // bool flag = (candidateValues[i] < currentValues[i]);
1288
+ // minValues[i] = flag ? candidateValues[i] : currentValues[i];
1289
+ // minIndices[i] = flag ? candidateIndices[i] : currentIndices[i];
1290
+ // maxValues[i] = !flag ? candidateValues[i] : currentValues[i];
1291
+ // maxIndices[i] = !flag ? candidateIndices[i] : currentIndices[i];
1292
+ // }
1293
+ // Max indices evaluation is inaccurate in case of equal values (the index of
1294
+ // the last equal value is saved instead of the first one), but this behavior
1295
+ // saves instructions.
1296
+ inline void cmplt_min_max_fast(
1297
+ const simd8float32 candidateValues,
1298
+ const simd8uint32 candidateIndices,
1299
+ const simd8float32 currentValues,
1300
+ const simd8uint32 currentIndices,
1301
+ simd8float32& minValues,
1302
+ simd8uint32& minIndices,
1303
+ simd8float32& maxValues,
1304
+ simd8uint32& maxIndices) {
1305
+ const uint32x4x2_t comparison = uint32x4x2_t{
1306
+ vcltq_f32(candidateValues.data.val[0], currentValues.data.val[0]),
1307
+ vcltq_f32(candidateValues.data.val[1], currentValues.data.val[1])};
1308
+
1309
+ minValues.data = float32x4x2_t{
1310
+ vminq_f32(candidateValues.data.val[0], currentValues.data.val[0]),
1311
+ vminq_f32(candidateValues.data.val[1], currentValues.data.val[1])};
1312
+ minIndices.data = uint32x4x2_t{
1313
+ vbslq_u32(
1314
+ comparison.val[0],
1315
+ candidateIndices.data.val[0],
1316
+ currentIndices.data.val[0]),
1317
+ vbslq_u32(
1318
+ comparison.val[1],
1319
+ candidateIndices.data.val[1],
1320
+ currentIndices.data.val[1])};
1321
+
1322
+ maxValues.data = float32x4x2_t{
1323
+ vmaxq_f32(candidateValues.data.val[0], currentValues.data.val[0]),
1324
+ vmaxq_f32(candidateValues.data.val[1], currentValues.data.val[1])};
1325
+ maxIndices.data = uint32x4x2_t{
1326
+ vbslq_u32(
1327
+ comparison.val[0],
1328
+ currentIndices.data.val[0],
1329
+ candidateIndices.data.val[0]),
1330
+ vbslq_u32(
1331
+ comparison.val[1],
1332
+ currentIndices.data.val[1],
1333
+ candidateIndices.data.val[1])};
1334
+ }
1335
+
809
1336
  namespace {
810
1337
 
811
1338
  // get even float32's of a and b, interleaved
812
1339
  simd8float32 geteven(const simd8float32& a, const simd8float32& b) {
813
- return simd8float32{float32x4x2_t{
814
- vuzp1q_f32(a.data.val[0], b.data.val[0]),
815
- vuzp1q_f32(a.data.val[1], b.data.val[1])}};
1340
+ return simd8float32{
1341
+ detail::simdlib::binary_func(a.data, b.data).call<&vuzp1q_f32>()};
816
1342
  }
817
1343
 
818
1344
  // get odd float32's of a and b, interleaved
819
1345
  simd8float32 getodd(const simd8float32& a, const simd8float32& b) {
820
- return simd8float32{float32x4x2_t{
821
- vuzp2q_f32(a.data.val[0], b.data.val[0]),
822
- vuzp2q_f32(a.data.val[1], b.data.val[1])}};
1346
+ return simd8float32{
1347
+ detail::simdlib::binary_func(a.data, b.data).call<&vuzp2q_f32>()};
823
1348
  }
824
1349
 
825
1350
  // 3 cycles