faiss 0.2.6 → 0.2.7
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +4 -0
- data/ext/faiss/extconf.rb +1 -1
- data/lib/faiss/version.rb +1 -1
- data/lib/faiss.rb +2 -2
- data/vendor/faiss/faiss/AutoTune.cpp +15 -4
- data/vendor/faiss/faiss/AutoTune.h +0 -1
- data/vendor/faiss/faiss/Clustering.cpp +1 -5
- data/vendor/faiss/faiss/Clustering.h +0 -2
- data/vendor/faiss/faiss/IVFlib.h +0 -2
- data/vendor/faiss/faiss/Index.h +1 -2
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +17 -3
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +10 -1
- data/vendor/faiss/faiss/IndexBinary.h +0 -1
- data/vendor/faiss/faiss/IndexBinaryFlat.cpp +2 -1
- data/vendor/faiss/faiss/IndexBinaryFlat.h +4 -0
- data/vendor/faiss/faiss/IndexBinaryHash.cpp +1 -3
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +273 -48
- data/vendor/faiss/faiss/IndexBinaryIVF.h +18 -11
- data/vendor/faiss/faiss/IndexFastScan.cpp +13 -10
- data/vendor/faiss/faiss/IndexFastScan.h +5 -1
- data/vendor/faiss/faiss/IndexFlat.cpp +16 -3
- data/vendor/faiss/faiss/IndexFlat.h +1 -1
- data/vendor/faiss/faiss/IndexFlatCodes.cpp +5 -0
- data/vendor/faiss/faiss/IndexFlatCodes.h +7 -2
- data/vendor/faiss/faiss/IndexHNSW.cpp +3 -6
- data/vendor/faiss/faiss/IndexHNSW.h +0 -1
- data/vendor/faiss/faiss/IndexIDMap.cpp +4 -4
- data/vendor/faiss/faiss/IndexIDMap.h +0 -2
- data/vendor/faiss/faiss/IndexIVF.cpp +155 -129
- data/vendor/faiss/faiss/IndexIVF.h +121 -61
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +2 -2
- data/vendor/faiss/faiss/IndexIVFFastScan.cpp +12 -11
- data/vendor/faiss/faiss/IndexIVFFastScan.h +6 -1
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +221 -165
- data/vendor/faiss/faiss/IndexIVFPQ.h +1 -0
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +6 -1
- data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +0 -2
- data/vendor/faiss/faiss/IndexNNDescent.cpp +1 -2
- data/vendor/faiss/faiss/IndexNNDescent.h +0 -1
- data/vendor/faiss/faiss/IndexNSG.cpp +1 -2
- data/vendor/faiss/faiss/IndexPQ.cpp +7 -9
- data/vendor/faiss/faiss/IndexRefine.cpp +1 -1
- data/vendor/faiss/faiss/IndexReplicas.cpp +3 -4
- data/vendor/faiss/faiss/IndexReplicas.h +0 -1
- data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +8 -1
- data/vendor/faiss/faiss/IndexRowwiseMinMax.h +7 -0
- data/vendor/faiss/faiss/IndexShards.cpp +26 -109
- data/vendor/faiss/faiss/IndexShards.h +2 -3
- data/vendor/faiss/faiss/IndexShardsIVF.cpp +246 -0
- data/vendor/faiss/faiss/IndexShardsIVF.h +42 -0
- data/vendor/faiss/faiss/MetaIndexes.cpp +86 -0
- data/vendor/faiss/faiss/MetaIndexes.h +29 -0
- data/vendor/faiss/faiss/MetricType.h +14 -0
- data/vendor/faiss/faiss/VectorTransform.cpp +8 -10
- data/vendor/faiss/faiss/VectorTransform.h +1 -3
- data/vendor/faiss/faiss/clone_index.cpp +232 -18
- data/vendor/faiss/faiss/cppcontrib/SaDecodeKernels.h +25 -3
- data/vendor/faiss/faiss/cppcontrib/detail/CoarseBitType.h +7 -0
- data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +78 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +20 -6
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +7 -1
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-neon-inl.h +21 -7
- data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMax-inl.h +7 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMaxFP16-inl.h +7 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +10 -3
- data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +7 -1
- data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-neon-inl.h +11 -3
- data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +25 -2
- data/vendor/faiss/faiss/gpu/GpuCloner.cpp +76 -29
- data/vendor/faiss/faiss/gpu/GpuCloner.h +2 -2
- data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +14 -13
- data/vendor/faiss/faiss/gpu/GpuDistance.h +18 -6
- data/vendor/faiss/faiss/gpu/GpuIndex.h +23 -21
- data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +10 -10
- data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +11 -12
- data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +29 -50
- data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +3 -3
- data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +8 -8
- data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +4 -4
- data/vendor/faiss/faiss/gpu/impl/IndexUtils.h +2 -5
- data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +9 -7
- data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +4 -4
- data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +2 -2
- data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +1 -1
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +55 -6
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +20 -6
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +95 -25
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +67 -16
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +4 -4
- data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +7 -7
- data/vendor/faiss/faiss/gpu/test/TestUtils.h +4 -4
- data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +1 -1
- data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +6 -0
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +0 -7
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +9 -9
- data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +1 -1
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +2 -7
- data/vendor/faiss/faiss/impl/CodePacker.cpp +67 -0
- data/vendor/faiss/faiss/impl/CodePacker.h +71 -0
- data/vendor/faiss/faiss/impl/DistanceComputer.h +0 -2
- data/vendor/faiss/faiss/impl/HNSW.cpp +3 -7
- data/vendor/faiss/faiss/impl/HNSW.h +6 -9
- data/vendor/faiss/faiss/impl/IDSelector.cpp +1 -1
- data/vendor/faiss/faiss/impl/IDSelector.h +39 -1
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +62 -51
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +11 -12
- data/vendor/faiss/faiss/impl/NNDescent.cpp +3 -9
- data/vendor/faiss/faiss/impl/NNDescent.h +10 -10
- data/vendor/faiss/faiss/impl/NSG.cpp +1 -6
- data/vendor/faiss/faiss/impl/NSG.h +4 -7
- data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +1 -15
- data/vendor/faiss/faiss/impl/PolysemousTraining.h +11 -10
- data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +0 -7
- data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +25 -12
- data/vendor/faiss/faiss/impl/ProductQuantizer.h +2 -4
- data/vendor/faiss/faiss/impl/Quantizer.h +6 -3
- data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +796 -174
- data/vendor/faiss/faiss/impl/ResidualQuantizer.h +16 -8
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +3 -5
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +4 -4
- data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +3 -3
- data/vendor/faiss/faiss/impl/ThreadedIndex.h +4 -4
- data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +291 -0
- data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +74 -0
- data/vendor/faiss/faiss/impl/code_distance/code_distance.h +123 -0
- data/vendor/faiss/faiss/impl/code_distance/code_distance_avx512.h +102 -0
- data/vendor/faiss/faiss/impl/index_read.cpp +13 -10
- data/vendor/faiss/faiss/impl/index_write.cpp +3 -4
- data/vendor/faiss/faiss/impl/kmeans1d.cpp +0 -1
- data/vendor/faiss/faiss/impl/kmeans1d.h +3 -3
- data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -1
- data/vendor/faiss/faiss/impl/platform_macros.h +61 -0
- data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +48 -4
- data/vendor/faiss/faiss/impl/pq4_fast_scan.h +18 -4
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +2 -2
- data/vendor/faiss/faiss/index_factory.cpp +8 -10
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +29 -12
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +8 -2
- data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -1
- data/vendor/faiss/faiss/invlists/DirectMap.h +2 -4
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +118 -18
- data/vendor/faiss/faiss/invlists/InvertedLists.h +44 -4
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +3 -3
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +1 -1
- data/vendor/faiss/faiss/python/python_callbacks.cpp +1 -1
- data/vendor/faiss/faiss/python/python_callbacks.h +1 -1
- data/vendor/faiss/faiss/utils/AlignedTable.h +3 -1
- data/vendor/faiss/faiss/utils/Heap.cpp +139 -3
- data/vendor/faiss/faiss/utils/Heap.h +35 -1
- data/vendor/faiss/faiss/utils/approx_topk/approx_topk.h +84 -0
- data/vendor/faiss/faiss/utils/approx_topk/avx2-inl.h +196 -0
- data/vendor/faiss/faiss/utils/approx_topk/generic.h +138 -0
- data/vendor/faiss/faiss/utils/approx_topk/mode.h +34 -0
- data/vendor/faiss/faiss/utils/approx_topk_hamming/approx_topk_hamming.h +367 -0
- data/vendor/faiss/faiss/utils/distances.cpp +61 -7
- data/vendor/faiss/faiss/utils/distances.h +11 -0
- data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +346 -0
- data/vendor/faiss/faiss/utils/distances_fused/avx512.h +36 -0
- data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +42 -0
- data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +40 -0
- data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +352 -0
- data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.h +32 -0
- data/vendor/faiss/faiss/utils/distances_simd.cpp +515 -327
- data/vendor/faiss/faiss/utils/extra_distances-inl.h +17 -1
- data/vendor/faiss/faiss/utils/extra_distances.cpp +37 -8
- data/vendor/faiss/faiss/utils/extra_distances.h +2 -1
- data/vendor/faiss/faiss/utils/fp16-fp16c.h +7 -0
- data/vendor/faiss/faiss/utils/fp16-inl.h +7 -0
- data/vendor/faiss/faiss/utils/fp16.h +7 -0
- data/vendor/faiss/faiss/utils/hamming-inl.h +0 -456
- data/vendor/faiss/faiss/utils/hamming.cpp +104 -120
- data/vendor/faiss/faiss/utils/hamming.h +21 -10
- data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +535 -0
- data/vendor/faiss/faiss/utils/hamming_distance/common.h +48 -0
- data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +519 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +26 -0
- data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +614 -0
- data/vendor/faiss/faiss/utils/partitioning.cpp +21 -25
- data/vendor/faiss/faiss/utils/simdlib_avx2.h +344 -3
- data/vendor/faiss/faiss/utils/simdlib_emulated.h +390 -0
- data/vendor/faiss/faiss/utils/simdlib_neon.h +655 -130
- data/vendor/faiss/faiss/utils/sorting.cpp +692 -0
- data/vendor/faiss/faiss/utils/sorting.h +71 -0
- data/vendor/faiss/faiss/utils/transpose/transpose-avx2-inl.h +165 -0
- data/vendor/faiss/faiss/utils/utils.cpp +4 -176
- data/vendor/faiss/faiss/utils/utils.h +2 -9
- metadata +29 -3
- data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +0 -26
@@ -18,6 +18,8 @@
|
|
18
18
|
|
19
19
|
#include <arm_neon.h>
|
20
20
|
|
21
|
+
#include <faiss/impl/FaissAssert.h>
|
22
|
+
|
21
23
|
namespace faiss {
|
22
24
|
|
23
25
|
namespace detail {
|
@@ -88,6 +90,23 @@ static inline float32x4x2_t reinterpret_f32(const float32x4x2_t& v) {
|
|
88
90
|
return v;
|
89
91
|
}
|
90
92
|
|
93
|
+
// Surprisingly, vdupq_n_u16 has the type of
|
94
|
+
// uint16x8_t (std::uint32_t) , and vdupq_n_u8 also has
|
95
|
+
// uint8x16_t (std::uint32_t) on **some environments**.
|
96
|
+
// We want argument type as same as the type of element
|
97
|
+
// of result vector type (std::uint16_t for uint16x8_t,
|
98
|
+
// and std::uint8_t for uint8x16_t) instead of
|
99
|
+
// std::uint32_t due to using set1 function templates,
|
100
|
+
// so let's fix the argument type here and use these
|
101
|
+
// overload below.
|
102
|
+
static inline ::uint16x8_t vdupq_n_u16(std::uint16_t v) {
|
103
|
+
return ::vdupq_n_u16(v);
|
104
|
+
}
|
105
|
+
|
106
|
+
static inline ::uint8x16_t vdupq_n_u8(std::uint8_t v) {
|
107
|
+
return ::vdupq_n_u8(v);
|
108
|
+
}
|
109
|
+
|
91
110
|
template <
|
92
111
|
typename T,
|
93
112
|
typename U = decltype(reinterpret_u8(std::declval<T>().data))>
|
@@ -119,11 +138,25 @@ static inline std::string bin(const S& simd) {
|
|
119
138
|
return std::string(bits);
|
120
139
|
}
|
121
140
|
|
122
|
-
template <typename
|
123
|
-
|
124
|
-
|
125
|
-
|
126
|
-
|
141
|
+
template <typename T>
|
142
|
+
using remove_cv_ref_t =
|
143
|
+
typename std::remove_reference<typename std::remove_cv<T>::type>::type;
|
144
|
+
|
145
|
+
template <typename D, typename T>
|
146
|
+
struct set1_impl {
|
147
|
+
D& d;
|
148
|
+
T t;
|
149
|
+
template <remove_cv_ref_t<decltype(std::declval<D>().val[0])> (*F)(T)>
|
150
|
+
inline void call() {
|
151
|
+
const auto v = F(t);
|
152
|
+
d.val[0] = v;
|
153
|
+
d.val[1] = v;
|
154
|
+
}
|
155
|
+
};
|
156
|
+
|
157
|
+
template <typename D, typename T>
|
158
|
+
static inline set1_impl<remove_cv_ref_t<D>, T> set1(D& d, T t) {
|
159
|
+
return {d, t};
|
127
160
|
}
|
128
161
|
|
129
162
|
template <typename T, size_t N, typename S>
|
@@ -142,20 +175,57 @@ static inline std::string elements_to_string(const char* fmt, const S& simd) {
|
|
142
175
|
return std::string(res);
|
143
176
|
}
|
144
177
|
|
145
|
-
template <typename T, typename
|
146
|
-
|
147
|
-
|
148
|
-
|
149
|
-
|
150
|
-
|
178
|
+
template <typename T, typename U>
|
179
|
+
struct unary_func_impl {
|
180
|
+
const U& a;
|
181
|
+
using Telem = remove_cv_ref_t<decltype(std::declval<T>().val[0])>;
|
182
|
+
using Uelem = remove_cv_ref_t<decltype(std::declval<U>().val[0])>;
|
183
|
+
template <Telem (*F)(Uelem)>
|
184
|
+
inline T call() {
|
185
|
+
T t;
|
186
|
+
t.val[0] = F(a.val[0]);
|
187
|
+
t.val[1] = F(a.val[1]);
|
188
|
+
return t;
|
189
|
+
}
|
190
|
+
};
|
191
|
+
|
192
|
+
template <typename T>
|
193
|
+
static inline unary_func_impl<remove_cv_ref_t<T>, remove_cv_ref_t<T>> unary_func(
|
194
|
+
const T& a) {
|
195
|
+
return {a};
|
151
196
|
}
|
152
197
|
|
153
|
-
template <typename T, typename
|
154
|
-
static inline T
|
155
|
-
|
156
|
-
|
157
|
-
|
158
|
-
|
198
|
+
template <typename T, typename U>
|
199
|
+
static inline unary_func_impl<remove_cv_ref_t<T>, remove_cv_ref_t<U>> unary_func(
|
200
|
+
const U& a) {
|
201
|
+
return {a};
|
202
|
+
}
|
203
|
+
|
204
|
+
template <typename T, typename U>
|
205
|
+
struct binary_func_impl {
|
206
|
+
const U& a;
|
207
|
+
const U& b;
|
208
|
+
using Telem = remove_cv_ref_t<decltype(std::declval<T>().val[0])>;
|
209
|
+
using Uelem = remove_cv_ref_t<decltype(std::declval<U>().val[0])>;
|
210
|
+
template <Telem (*F)(Uelem, Uelem)>
|
211
|
+
inline T call() {
|
212
|
+
T t;
|
213
|
+
t.val[0] = F(a.val[0], b.val[0]);
|
214
|
+
t.val[1] = F(a.val[1], b.val[1]);
|
215
|
+
return t;
|
216
|
+
}
|
217
|
+
};
|
218
|
+
|
219
|
+
template <typename T>
|
220
|
+
static inline binary_func_impl<remove_cv_ref_t<T>, remove_cv_ref_t<T>>
|
221
|
+
binary_func(const T& a, const T& b) {
|
222
|
+
return {a, b};
|
223
|
+
}
|
224
|
+
|
225
|
+
template <typename T, typename U>
|
226
|
+
static inline binary_func_impl<remove_cv_ref_t<T>, remove_cv_ref_t<U>>
|
227
|
+
binary_func(const U& a, const U& b) {
|
228
|
+
return {a, b};
|
159
229
|
}
|
160
230
|
|
161
231
|
static inline uint16_t vmovmask_u8(const uint8x16_t& v) {
|
@@ -172,8 +242,8 @@ static inline uint32_t cmp_xe32(
|
|
172
242
|
const uint16x8x2_t& d0,
|
173
243
|
const uint16x8x2_t& d1,
|
174
244
|
const uint16x8x2_t& thr) {
|
175
|
-
const auto d0_thr = detail::simdlib::binary_func(d0, thr
|
176
|
-
const auto d1_thr = detail::simdlib::binary_func(d1, thr
|
245
|
+
const auto d0_thr = detail::simdlib::binary_func(d0, thr).call<F>();
|
246
|
+
const auto d1_thr = detail::simdlib::binary_func(d1, thr).call<F>();
|
177
247
|
const auto d0_mask = vmovmask_u8(
|
178
248
|
vmovn_high_u16(vmovn_u16(d0_thr.val[0]), d0_thr.val[1]));
|
179
249
|
const auto d1_mask = vmovmask_u8(
|
@@ -207,6 +277,44 @@ struct simd16uint16 {
|
|
207
277
|
|
208
278
|
explicit simd16uint16(const uint16x8x2_t& v) : data{v} {}
|
209
279
|
|
280
|
+
explicit simd16uint16(
|
281
|
+
uint16_t u0,
|
282
|
+
uint16_t u1,
|
283
|
+
uint16_t u2,
|
284
|
+
uint16_t u3,
|
285
|
+
uint16_t u4,
|
286
|
+
uint16_t u5,
|
287
|
+
uint16_t u6,
|
288
|
+
uint16_t u7,
|
289
|
+
uint16_t u8,
|
290
|
+
uint16_t u9,
|
291
|
+
uint16_t u10,
|
292
|
+
uint16_t u11,
|
293
|
+
uint16_t u12,
|
294
|
+
uint16_t u13,
|
295
|
+
uint16_t u14,
|
296
|
+
uint16_t u15) {
|
297
|
+
uint16_t temp[16] = {
|
298
|
+
u0,
|
299
|
+
u1,
|
300
|
+
u2,
|
301
|
+
u3,
|
302
|
+
u4,
|
303
|
+
u5,
|
304
|
+
u6,
|
305
|
+
u7,
|
306
|
+
u8,
|
307
|
+
u9,
|
308
|
+
u10,
|
309
|
+
u11,
|
310
|
+
u12,
|
311
|
+
u13,
|
312
|
+
u14,
|
313
|
+
u15};
|
314
|
+
data.val[0] = vld1q_u16(temp);
|
315
|
+
data.val[1] = vld1q_u16(temp + 8);
|
316
|
+
}
|
317
|
+
|
210
318
|
template <
|
211
319
|
typename T,
|
212
320
|
typename std::enable_if<
|
@@ -219,7 +327,8 @@ struct simd16uint16 {
|
|
219
327
|
: data{vld1q_u16(x), vld1q_u16(x + 8)} {}
|
220
328
|
|
221
329
|
void clear() {
|
222
|
-
detail::simdlib::set1(data,
|
330
|
+
detail::simdlib::set1(data, static_cast<uint16_t>(0))
|
331
|
+
.call<&detail::simdlib::vdupq_n_u16>();
|
223
332
|
}
|
224
333
|
|
225
334
|
void storeu(uint16_t* ptr) const {
|
@@ -257,12 +366,12 @@ struct simd16uint16 {
|
|
257
366
|
}
|
258
367
|
|
259
368
|
void set1(uint16_t x) {
|
260
|
-
detail::simdlib::set1(data,
|
369
|
+
detail::simdlib::set1(data, x).call<&detail::simdlib::vdupq_n_u16>();
|
261
370
|
}
|
262
371
|
|
263
372
|
simd16uint16 operator*(const simd16uint16& other) const {
|
264
|
-
return simd16uint16{
|
265
|
-
|
373
|
+
return simd16uint16{detail::simdlib::binary_func(data, other.data)
|
374
|
+
.call<&vmulq_u16>()};
|
266
375
|
}
|
267
376
|
|
268
377
|
// shift must be known at compile time
|
@@ -271,50 +380,56 @@ struct simd16uint16 {
|
|
271
380
|
case 0:
|
272
381
|
return *this;
|
273
382
|
case 1:
|
274
|
-
return simd16uint16{detail::simdlib::unary_func(
|
275
|
-
|
383
|
+
return simd16uint16{detail::simdlib::unary_func(data)
|
384
|
+
.call<detail::simdlib::vshrq<1>>()};
|
276
385
|
case 2:
|
277
|
-
return simd16uint16{detail::simdlib::unary_func(
|
278
|
-
|
386
|
+
return simd16uint16{detail::simdlib::unary_func(data)
|
387
|
+
.call<detail::simdlib::vshrq<2>>()};
|
279
388
|
case 3:
|
280
|
-
return simd16uint16{detail::simdlib::unary_func(
|
281
|
-
|
389
|
+
return simd16uint16{detail::simdlib::unary_func(data)
|
390
|
+
.call<detail::simdlib::vshrq<3>>()};
|
282
391
|
case 4:
|
283
|
-
return simd16uint16{detail::simdlib::unary_func(
|
284
|
-
|
392
|
+
return simd16uint16{detail::simdlib::unary_func(data)
|
393
|
+
.call<detail::simdlib::vshrq<4>>()};
|
285
394
|
case 5:
|
286
|
-
return simd16uint16{detail::simdlib::unary_func(
|
287
|
-
|
395
|
+
return simd16uint16{detail::simdlib::unary_func(data)
|
396
|
+
.call<detail::simdlib::vshrq<5>>()};
|
288
397
|
case 6:
|
289
|
-
return simd16uint16{detail::simdlib::unary_func(
|
290
|
-
|
398
|
+
return simd16uint16{detail::simdlib::unary_func(data)
|
399
|
+
.call<detail::simdlib::vshrq<6>>()};
|
291
400
|
case 7:
|
292
|
-
return simd16uint16{detail::simdlib::unary_func(
|
293
|
-
|
401
|
+
return simd16uint16{detail::simdlib::unary_func(data)
|
402
|
+
.call<detail::simdlib::vshrq<7>>()};
|
294
403
|
case 8:
|
295
|
-
return simd16uint16{detail::simdlib::unary_func(
|
296
|
-
|
404
|
+
return simd16uint16{detail::simdlib::unary_func(data)
|
405
|
+
.call<detail::simdlib::vshrq<8>>()};
|
297
406
|
case 9:
|
298
|
-
return simd16uint16{detail::simdlib::unary_func(
|
299
|
-
|
407
|
+
return simd16uint16{detail::simdlib::unary_func(data)
|
408
|
+
.call<detail::simdlib::vshrq<9>>()};
|
300
409
|
case 10:
|
301
|
-
return simd16uint16{
|
302
|
-
|
410
|
+
return simd16uint16{
|
411
|
+
detail::simdlib::unary_func(data)
|
412
|
+
.call<detail::simdlib::vshrq<10>>()};
|
303
413
|
case 11:
|
304
|
-
return simd16uint16{
|
305
|
-
|
414
|
+
return simd16uint16{
|
415
|
+
detail::simdlib::unary_func(data)
|
416
|
+
.call<detail::simdlib::vshrq<11>>()};
|
306
417
|
case 12:
|
307
|
-
return simd16uint16{
|
308
|
-
|
418
|
+
return simd16uint16{
|
419
|
+
detail::simdlib::unary_func(data)
|
420
|
+
.call<detail::simdlib::vshrq<12>>()};
|
309
421
|
case 13:
|
310
|
-
return simd16uint16{
|
311
|
-
|
422
|
+
return simd16uint16{
|
423
|
+
detail::simdlib::unary_func(data)
|
424
|
+
.call<detail::simdlib::vshrq<13>>()};
|
312
425
|
case 14:
|
313
|
-
return simd16uint16{
|
314
|
-
|
426
|
+
return simd16uint16{
|
427
|
+
detail::simdlib::unary_func(data)
|
428
|
+
.call<detail::simdlib::vshrq<14>>()};
|
315
429
|
case 15:
|
316
|
-
return simd16uint16{
|
317
|
-
|
430
|
+
return simd16uint16{
|
431
|
+
detail::simdlib::unary_func(data)
|
432
|
+
.call<detail::simdlib::vshrq<15>>()};
|
318
433
|
default:
|
319
434
|
FAISS_THROW_FMT("Invalid shift %d", shift);
|
320
435
|
}
|
@@ -326,50 +441,56 @@ struct simd16uint16 {
|
|
326
441
|
case 0:
|
327
442
|
return *this;
|
328
443
|
case 1:
|
329
|
-
return simd16uint16{detail::simdlib::unary_func(
|
330
|
-
|
444
|
+
return simd16uint16{detail::simdlib::unary_func(data)
|
445
|
+
.call<detail::simdlib::vshlq<1>>()};
|
331
446
|
case 2:
|
332
|
-
return simd16uint16{detail::simdlib::unary_func(
|
333
|
-
|
447
|
+
return simd16uint16{detail::simdlib::unary_func(data)
|
448
|
+
.call<detail::simdlib::vshlq<2>>()};
|
334
449
|
case 3:
|
335
|
-
return simd16uint16{detail::simdlib::unary_func(
|
336
|
-
|
450
|
+
return simd16uint16{detail::simdlib::unary_func(data)
|
451
|
+
.call<detail::simdlib::vshlq<3>>()};
|
337
452
|
case 4:
|
338
|
-
return simd16uint16{detail::simdlib::unary_func(
|
339
|
-
|
453
|
+
return simd16uint16{detail::simdlib::unary_func(data)
|
454
|
+
.call<detail::simdlib::vshlq<4>>()};
|
340
455
|
case 5:
|
341
|
-
return simd16uint16{detail::simdlib::unary_func(
|
342
|
-
|
456
|
+
return simd16uint16{detail::simdlib::unary_func(data)
|
457
|
+
.call<detail::simdlib::vshlq<5>>()};
|
343
458
|
case 6:
|
344
|
-
return simd16uint16{detail::simdlib::unary_func(
|
345
|
-
|
459
|
+
return simd16uint16{detail::simdlib::unary_func(data)
|
460
|
+
.call<detail::simdlib::vshlq<6>>()};
|
346
461
|
case 7:
|
347
|
-
return simd16uint16{detail::simdlib::unary_func(
|
348
|
-
|
462
|
+
return simd16uint16{detail::simdlib::unary_func(data)
|
463
|
+
.call<detail::simdlib::vshlq<7>>()};
|
349
464
|
case 8:
|
350
|
-
return simd16uint16{detail::simdlib::unary_func(
|
351
|
-
|
465
|
+
return simd16uint16{detail::simdlib::unary_func(data)
|
466
|
+
.call<detail::simdlib::vshlq<8>>()};
|
352
467
|
case 9:
|
353
|
-
return simd16uint16{detail::simdlib::unary_func(
|
354
|
-
|
468
|
+
return simd16uint16{detail::simdlib::unary_func(data)
|
469
|
+
.call<detail::simdlib::vshlq<9>>()};
|
355
470
|
case 10:
|
356
|
-
return simd16uint16{
|
357
|
-
|
471
|
+
return simd16uint16{
|
472
|
+
detail::simdlib::unary_func(data)
|
473
|
+
.call<detail::simdlib::vshlq<10>>()};
|
358
474
|
case 11:
|
359
|
-
return simd16uint16{
|
360
|
-
|
475
|
+
return simd16uint16{
|
476
|
+
detail::simdlib::unary_func(data)
|
477
|
+
.call<detail::simdlib::vshlq<11>>()};
|
361
478
|
case 12:
|
362
|
-
return simd16uint16{
|
363
|
-
|
479
|
+
return simd16uint16{
|
480
|
+
detail::simdlib::unary_func(data)
|
481
|
+
.call<detail::simdlib::vshlq<12>>()};
|
364
482
|
case 13:
|
365
|
-
return simd16uint16{
|
366
|
-
|
483
|
+
return simd16uint16{
|
484
|
+
detail::simdlib::unary_func(data)
|
485
|
+
.call<detail::simdlib::vshlq<13>>()};
|
367
486
|
case 14:
|
368
|
-
return simd16uint16{
|
369
|
-
|
487
|
+
return simd16uint16{
|
488
|
+
detail::simdlib::unary_func(data)
|
489
|
+
.call<detail::simdlib::vshlq<14>>()};
|
370
490
|
case 15:
|
371
|
-
return simd16uint16{
|
372
|
-
|
491
|
+
return simd16uint16{
|
492
|
+
detail::simdlib::unary_func(data)
|
493
|
+
.call<detail::simdlib::vshlq<15>>()};
|
373
494
|
default:
|
374
495
|
FAISS_THROW_FMT("Invalid shift %d", shift);
|
375
496
|
}
|
@@ -386,13 +507,13 @@ struct simd16uint16 {
|
|
386
507
|
}
|
387
508
|
|
388
509
|
simd16uint16 operator+(const simd16uint16& other) const {
|
389
|
-
return simd16uint16{
|
390
|
-
|
510
|
+
return simd16uint16{detail::simdlib::binary_func(data, other.data)
|
511
|
+
.call<&vaddq_u16>()};
|
391
512
|
}
|
392
513
|
|
393
514
|
simd16uint16 operator-(const simd16uint16& other) const {
|
394
|
-
return simd16uint16{
|
395
|
-
|
515
|
+
return simd16uint16{detail::simdlib::binary_func(data, other.data)
|
516
|
+
.call<&vsubq_u16>()};
|
396
517
|
}
|
397
518
|
|
398
519
|
template <
|
@@ -401,10 +522,10 @@ struct simd16uint16 {
|
|
401
522
|
detail::simdlib::is_simd256bit<T>::value,
|
402
523
|
std::nullptr_t>::type = nullptr>
|
403
524
|
simd16uint16 operator&(const T& other) const {
|
404
|
-
return simd16uint16{
|
405
|
-
|
406
|
-
|
407
|
-
|
525
|
+
return simd16uint16{
|
526
|
+
detail::simdlib::binary_func(
|
527
|
+
data, detail::simdlib::reinterpret_u16(other.data))
|
528
|
+
.template call<&vandq_u16>()};
|
408
529
|
}
|
409
530
|
|
410
531
|
template <
|
@@ -413,20 +534,45 @@ struct simd16uint16 {
|
|
413
534
|
detail::simdlib::is_simd256bit<T>::value,
|
414
535
|
std::nullptr_t>::type = nullptr>
|
415
536
|
simd16uint16 operator|(const T& other) const {
|
416
|
-
return simd16uint16{
|
417
|
-
|
418
|
-
|
419
|
-
|
537
|
+
return simd16uint16{
|
538
|
+
detail::simdlib::binary_func(
|
539
|
+
data, detail::simdlib::reinterpret_u16(other.data))
|
540
|
+
.template call<&vorrq_u16>()};
|
541
|
+
}
|
542
|
+
|
543
|
+
template <
|
544
|
+
typename T,
|
545
|
+
typename std::enable_if<
|
546
|
+
detail::simdlib::is_simd256bit<T>::value,
|
547
|
+
std::nullptr_t>::type = nullptr>
|
548
|
+
simd16uint16 operator^(const T& other) const {
|
549
|
+
return simd16uint16{
|
550
|
+
detail::simdlib::binary_func(
|
551
|
+
data, detail::simdlib::reinterpret_u16(other.data))
|
552
|
+
.template call<&veorq_u16>()};
|
420
553
|
}
|
421
554
|
|
422
555
|
// returns binary masks
|
423
556
|
simd16uint16 operator==(const simd16uint16& other) const {
|
424
|
-
return simd16uint16{
|
425
|
-
|
557
|
+
return simd16uint16{detail::simdlib::binary_func(data, other.data)
|
558
|
+
.call<&vceqq_u16>()};
|
559
|
+
}
|
560
|
+
|
561
|
+
// Checks whether the other holds exactly the same bytes.
|
562
|
+
bool is_same_as(simd16uint16 other) const {
|
563
|
+
const bool equal0 =
|
564
|
+
(vminvq_u16(vceqq_u16(data.val[0], other.data.val[0])) ==
|
565
|
+
0xffff);
|
566
|
+
const bool equal1 =
|
567
|
+
(vminvq_u16(vceqq_u16(data.val[1], other.data.val[1])) ==
|
568
|
+
0xffff);
|
569
|
+
|
570
|
+
return equal0 && equal1;
|
426
571
|
}
|
427
572
|
|
428
573
|
simd16uint16 operator~() const {
|
429
|
-
return simd16uint16{
|
574
|
+
return simd16uint16{
|
575
|
+
detail::simdlib::unary_func(data).call<&vmvnq_u16>()};
|
430
576
|
}
|
431
577
|
|
432
578
|
// get scalar at index 0
|
@@ -437,8 +583,8 @@ struct simd16uint16 {
|
|
437
583
|
// mask of elements where this >= thresh
|
438
584
|
// 2 bit per component: 16 * 2 = 32 bit
|
439
585
|
uint32_t ge_mask(const simd16uint16& thresh) const {
|
440
|
-
const auto input =
|
441
|
-
|
586
|
+
const auto input = detail::simdlib::binary_func(data, thresh.data)
|
587
|
+
.call<&vcgeq_u16>();
|
442
588
|
const auto vmovmask_u16 = [](uint16x8_t v) -> uint16_t {
|
443
589
|
uint16_t d[8];
|
444
590
|
const auto v2 = vreinterpretq_u32_u16(vshrq_n_u16(v, 14));
|
@@ -471,23 +617,25 @@ struct simd16uint16 {
|
|
471
617
|
}
|
472
618
|
|
473
619
|
void accu_min(const simd16uint16& incoming) {
|
474
|
-
data = detail::simdlib::binary_func(incoming.data, data
|
620
|
+
data = detail::simdlib::binary_func(incoming.data, data)
|
621
|
+
.call<&vminq_u16>();
|
475
622
|
}
|
476
623
|
|
477
624
|
void accu_max(const simd16uint16& incoming) {
|
478
|
-
data = detail::simdlib::binary_func(incoming.data, data
|
625
|
+
data = detail::simdlib::binary_func(incoming.data, data)
|
626
|
+
.call<&vmaxq_u16>();
|
479
627
|
}
|
480
628
|
};
|
481
629
|
|
482
630
|
// not really a std::min because it returns an elementwise min
|
483
631
|
inline simd16uint16 min(const simd16uint16& av, const simd16uint16& bv) {
|
484
632
|
return simd16uint16{
|
485
|
-
detail::simdlib::binary_func(av.data, bv.data
|
633
|
+
detail::simdlib::binary_func(av.data, bv.data).call<&vminq_u16>()};
|
486
634
|
}
|
487
635
|
|
488
636
|
inline simd16uint16 max(const simd16uint16& av, const simd16uint16& bv) {
|
489
637
|
return simd16uint16{
|
490
|
-
detail::simdlib::binary_func(av.data, bv.data
|
638
|
+
detail::simdlib::binary_func(av.data, bv.data).call<&vmaxq_u16>()};
|
491
639
|
}
|
492
640
|
|
493
641
|
// decompose in 128-lanes: a = (a0, a1), b = (b0, b1)
|
@@ -515,6 +663,63 @@ inline uint32_t cmp_le32(
|
|
515
663
|
return detail::simdlib::cmp_xe32<&vcleq_u16>(d0.data, d1.data, thr.data);
|
516
664
|
}
|
517
665
|
|
666
|
+
// hadd does not cross lanes
|
667
|
+
inline simd16uint16 hadd(const simd16uint16& a, const simd16uint16& b) {
|
668
|
+
return simd16uint16{
|
669
|
+
detail::simdlib::binary_func(a.data, b.data).call<&vpaddq_u16>()};
|
670
|
+
}
|
671
|
+
|
672
|
+
// Vectorized version of the following code:
|
673
|
+
// for (size_t i = 0; i < n; i++) {
|
674
|
+
// bool flag = (candidateValues[i] < currentValues[i]);
|
675
|
+
// minValues[i] = flag ? candidateValues[i] : currentValues[i];
|
676
|
+
// minIndices[i] = flag ? candidateIndices[i] : currentIndices[i];
|
677
|
+
// maxValues[i] = !flag ? candidateValues[i] : currentValues[i];
|
678
|
+
// maxIndices[i] = !flag ? candidateIndices[i] : currentIndices[i];
|
679
|
+
// }
|
680
|
+
// Max indices evaluation is inaccurate in case of equal values (the index of
|
681
|
+
// the last equal value is saved instead of the first one), but this behavior
|
682
|
+
// saves instructions.
|
683
|
+
inline void cmplt_min_max_fast(
|
684
|
+
const simd16uint16 candidateValues,
|
685
|
+
const simd16uint16 candidateIndices,
|
686
|
+
const simd16uint16 currentValues,
|
687
|
+
const simd16uint16 currentIndices,
|
688
|
+
simd16uint16& minValues,
|
689
|
+
simd16uint16& minIndices,
|
690
|
+
simd16uint16& maxValues,
|
691
|
+
simd16uint16& maxIndices) {
|
692
|
+
const uint16x8x2_t comparison = uint16x8x2_t{
|
693
|
+
vcltq_u16(candidateValues.data.val[0], currentValues.data.val[0]),
|
694
|
+
vcltq_u16(candidateValues.data.val[1], currentValues.data.val[1])};
|
695
|
+
|
696
|
+
minValues.data = uint16x8x2_t{
|
697
|
+
vminq_u16(candidateValues.data.val[0], currentValues.data.val[0]),
|
698
|
+
vminq_u16(candidateValues.data.val[1], currentValues.data.val[1])};
|
699
|
+
minIndices.data = uint16x8x2_t{
|
700
|
+
vbslq_u16(
|
701
|
+
comparison.val[0],
|
702
|
+
candidateIndices.data.val[0],
|
703
|
+
currentIndices.data.val[0]),
|
704
|
+
vbslq_u16(
|
705
|
+
comparison.val[1],
|
706
|
+
candidateIndices.data.val[1],
|
707
|
+
currentIndices.data.val[1])};
|
708
|
+
|
709
|
+
maxValues.data = uint16x8x2_t{
|
710
|
+
vmaxq_u16(candidateValues.data.val[0], currentValues.data.val[0]),
|
711
|
+
vmaxq_u16(candidateValues.data.val[1], currentValues.data.val[1])};
|
712
|
+
maxIndices.data = uint16x8x2_t{
|
713
|
+
vbslq_u16(
|
714
|
+
comparison.val[0],
|
715
|
+
currentIndices.data.val[0],
|
716
|
+
candidateIndices.data.val[0]),
|
717
|
+
vbslq_u16(
|
718
|
+
comparison.val[1],
|
719
|
+
currentIndices.data.val[1],
|
720
|
+
candidateIndices.data.val[1])};
|
721
|
+
}
|
722
|
+
|
518
723
|
// vector of 32 unsigned 8-bit integers
|
519
724
|
struct simd32uint8 {
|
520
725
|
uint8x16x2_t data;
|
@@ -527,6 +732,47 @@ struct simd32uint8 {
|
|
527
732
|
|
528
733
|
explicit simd32uint8(const uint8x16x2_t& v) : data{v} {}
|
529
734
|
|
735
|
+
template <
|
736
|
+
uint8_t _0,
|
737
|
+
uint8_t _1,
|
738
|
+
uint8_t _2,
|
739
|
+
uint8_t _3,
|
740
|
+
uint8_t _4,
|
741
|
+
uint8_t _5,
|
742
|
+
uint8_t _6,
|
743
|
+
uint8_t _7,
|
744
|
+
uint8_t _8,
|
745
|
+
uint8_t _9,
|
746
|
+
uint8_t _10,
|
747
|
+
uint8_t _11,
|
748
|
+
uint8_t _12,
|
749
|
+
uint8_t _13,
|
750
|
+
uint8_t _14,
|
751
|
+
uint8_t _15,
|
752
|
+
uint8_t _16,
|
753
|
+
uint8_t _17,
|
754
|
+
uint8_t _18,
|
755
|
+
uint8_t _19,
|
756
|
+
uint8_t _20,
|
757
|
+
uint8_t _21,
|
758
|
+
uint8_t _22,
|
759
|
+
uint8_t _23,
|
760
|
+
uint8_t _24,
|
761
|
+
uint8_t _25,
|
762
|
+
uint8_t _26,
|
763
|
+
uint8_t _27,
|
764
|
+
uint8_t _28,
|
765
|
+
uint8_t _29,
|
766
|
+
uint8_t _30,
|
767
|
+
uint8_t _31>
|
768
|
+
static simd32uint8 create() {
|
769
|
+
constexpr uint8_t ds[32] = {_0, _1, _2, _3, _4, _5, _6, _7,
|
770
|
+
_8, _9, _10, _11, _12, _13, _14, _15,
|
771
|
+
_16, _17, _18, _19, _20, _21, _22, _23,
|
772
|
+
_24, _25, _26, _27, _28, _29, _30, _31};
|
773
|
+
return simd32uint8{ds};
|
774
|
+
}
|
775
|
+
|
530
776
|
template <
|
531
777
|
typename T,
|
532
778
|
typename std::enable_if<
|
@@ -539,7 +785,8 @@ struct simd32uint8 {
|
|
539
785
|
: data{vld1q_u8(x), vld1q_u8(x + 16)} {}
|
540
786
|
|
541
787
|
void clear() {
|
542
|
-
detail::simdlib::set1(data,
|
788
|
+
detail::simdlib::set1(data, static_cast<uint8_t>(0))
|
789
|
+
.call<&detail::simdlib::vdupq_n_u8>();
|
543
790
|
}
|
544
791
|
|
545
792
|
void storeu(uint8_t* ptr) const {
|
@@ -582,7 +829,7 @@ struct simd32uint8 {
|
|
582
829
|
}
|
583
830
|
|
584
831
|
void set1(uint8_t x) {
|
585
|
-
detail::simdlib::set1(data,
|
832
|
+
detail::simdlib::set1(data, x).call<&detail::simdlib::vdupq_n_u8>();
|
586
833
|
}
|
587
834
|
|
588
835
|
template <
|
@@ -591,19 +838,21 @@ struct simd32uint8 {
|
|
591
838
|
detail::simdlib::is_simd256bit<T>::value,
|
592
839
|
std::nullptr_t>::type = nullptr>
|
593
840
|
simd32uint8 operator&(const T& other) const {
    // View the other 256-bit vector as bytes, then AND half by half.
    const auto otherBytes = detail::simdlib::reinterpret_u8(other.data);
    return simd32uint8{detail::simdlib::binary_func(data, otherBytes)
                               .template call<&vandq_u8>()};
}
|
597
846
|
|
598
847
|
simd32uint8 operator+(const simd32uint8& other) const {
    // Byte-wise addition of the two vectors (vaddq_u8 on each half).
    const auto sums =
            detail::simdlib::binary_func(data, other.data).call<&vaddq_u8>();
    return simd32uint8{sums};
}
|
602
851
|
|
603
852
|
// The very important operation that everything relies on
|
604
853
|
simd32uint8 lookup_2_lanes(const simd32uint8& idx) const {
|
605
|
-
return simd32uint8{
|
606
|
-
|
854
|
+
return simd32uint8{detail::simdlib::binary_func(data, idx.data)
|
855
|
+
.call<&vqtbl1q_u8>()};
|
607
856
|
}
|
608
857
|
|
609
858
|
simd32uint8 operator+=(const simd32uint8& other) {
|
@@ -618,6 +867,16 @@ struct simd32uint8 {
|
|
618
867
|
vst1q_u8(tab, data.val[high]);
|
619
868
|
return tab[i - high * 16];
|
620
869
|
}
|
870
|
+
|
871
|
+
// Checks whether the other holds exactly the same bytes.
|
872
|
+
bool is_same_as(simd32uint8 other) const {
|
873
|
+
const bool equal0 =
|
874
|
+
(vminvq_u8(vceqq_u8(data.val[0], other.data.val[0])) == 0xff);
|
875
|
+
const bool equal1 =
|
876
|
+
(vminvq_u8(vceqq_u8(data.val[1], other.data.val[1])) == 0xff);
|
877
|
+
|
878
|
+
return equal0 && equal1;
|
879
|
+
}
|
621
880
|
};
|
622
881
|
|
623
882
|
// convert with saturation
|
@@ -671,8 +930,62 @@ struct simd8uint32 {
|
|
671
930
|
|
672
931
|
explicit simd8uint32(const uint8_t* x) : simd8uint32(simd32uint8(x)) {}
|
673
932
|
|
933
|
+
explicit simd8uint32(
|
934
|
+
uint32_t u0,
|
935
|
+
uint32_t u1,
|
936
|
+
uint32_t u2,
|
937
|
+
uint32_t u3,
|
938
|
+
uint32_t u4,
|
939
|
+
uint32_t u5,
|
940
|
+
uint32_t u6,
|
941
|
+
uint32_t u7) {
|
942
|
+
uint32_t temp[8] = {u0, u1, u2, u3, u4, u5, u6, u7};
|
943
|
+
data.val[0] = vld1q_u32(temp);
|
944
|
+
data.val[1] = vld1q_u32(temp + 4);
|
945
|
+
}
|
946
|
+
|
947
|
+
simd8uint32 operator+(simd8uint32 other) const {
    // Lane-wise 32-bit addition (vaddq_u32 on each half).
    const auto sums =
            detail::simdlib::binary_func(data, other.data).call<&vaddq_u32>();
    return simd8uint32{sums};
}
|
951
|
+
|
952
|
+
simd8uint32 operator-(simd8uint32 other) const {
    // Lane-wise 32-bit subtraction (vsubq_u32 on each half).
    const auto diffs =
            detail::simdlib::binary_func(data, other.data).call<&vsubq_u32>();
    return simd8uint32{diffs};
}
|
956
|
+
|
957
|
+
simd8uint32& operator+=(const simd8uint32& other) {
    // Each half only reads its own lanes, so computing both sums before
    // storing is equivalent even when `other` is `*this`.
    const uint32x4_t low = vaddq_u32(data.val[0], other.data.val[0]);
    const uint32x4_t high = vaddq_u32(data.val[1], other.data.val[1]);
    data.val[0] = low;
    data.val[1] = high;
    return *this;
}
|
962
|
+
|
963
|
+
bool operator==(simd8uint32 other) const {
    // Per-lane equality masks (all ones where the lanes match).
    const auto laneEq =
            detail::simdlib::binary_func(data, other.data).call<&vceqq_u32>();
    // Every lane matches iff the minimum of the ANDed masks is all ones.
    const uint32x4_t allEq = vandq_u32(laneEq.val[0], laneEq.val[1]);
    return vminvq_u32(allEq) == 0xffffffff;
}
|
969
|
+
|
970
|
+
bool operator!=(simd8uint32 other) const {
    // Negation of lane-wise equality.
    return !operator==(other);
}
|
973
|
+
|
974
|
+
// Checks whether the other holds exactly the same bytes.
|
975
|
+
bool is_same_as(simd8uint32 other) const {
|
976
|
+
const bool equal0 =
|
977
|
+
(vminvq_u32(vceqq_u32(data.val[0], other.data.val[0])) ==
|
978
|
+
0xffffffff);
|
979
|
+
const bool equal1 =
|
980
|
+
(vminvq_u32(vceqq_u32(data.val[1], other.data.val[1])) ==
|
981
|
+
0xffffffff);
|
982
|
+
|
983
|
+
return equal0 && equal1;
|
984
|
+
}
|
985
|
+
|
674
986
|
void clear() {
|
675
|
-
detail::simdlib::set1(data,
|
987
|
+
detail::simdlib::set1(data, static_cast<uint32_t>(0))
|
988
|
+
.call<&vdupq_n_u32>();
|
676
989
|
}
|
677
990
|
|
678
991
|
void storeu(uint32_t* ptr) const {
|
@@ -710,10 +1023,67 @@ struct simd8uint32 {
|
|
710
1023
|
}
|
711
1024
|
|
712
1025
|
void set1(uint32_t x) {
|
713
|
-
detail::simdlib::set1(data,
|
1026
|
+
detail::simdlib::set1(data, x).call<&vdupq_n_u32>();
|
1027
|
+
}
|
1028
|
+
|
1029
|
+
simd8uint32 unzip() const {
    // De-interleave the 8 lanes. Even-indexed lanes (vuzp1q) form the low
    // half of the result; odd-indexed lanes (vuzp2q) form the high half.
    const uint32x4_t evens = vuzp1q_u32(data.val[0], data.val[1]);
    const uint32x4_t odds = vuzp2q_u32(data.val[0], data.val[1]);
    return simd8uint32{uint32x4x2_t{evens, odds}};
}
|
715
1034
|
};
|
716
1035
|
|
1036
|
+
// Vectorized version of the following code:
|
1037
|
+
// for (size_t i = 0; i < n; i++) {
|
1038
|
+
// bool flag = (candidateValues[i] < currentValues[i]);
|
1039
|
+
// minValues[i] = flag ? candidateValues[i] : currentValues[i];
|
1040
|
+
// minIndices[i] = flag ? candidateIndices[i] : currentIndices[i];
|
1041
|
+
// maxValues[i] = !flag ? candidateValues[i] : currentValues[i];
|
1042
|
+
// maxIndices[i] = !flag ? candidateIndices[i] : currentIndices[i];
|
1043
|
+
// }
|
1044
|
+
// Max indices evaluation is inaccurate in case of equal values (the index of
|
1045
|
+
// the last equal value is saved instead of the first one), but this behavior
|
1046
|
+
// saves instructions.
|
1047
|
+
inline void cmplt_min_max_fast(
|
1048
|
+
const simd8uint32 candidateValues,
|
1049
|
+
const simd8uint32 candidateIndices,
|
1050
|
+
const simd8uint32 currentValues,
|
1051
|
+
const simd8uint32 currentIndices,
|
1052
|
+
simd8uint32& minValues,
|
1053
|
+
simd8uint32& minIndices,
|
1054
|
+
simd8uint32& maxValues,
|
1055
|
+
simd8uint32& maxIndices) {
|
1056
|
+
const uint32x4x2_t comparison = uint32x4x2_t{
|
1057
|
+
vcltq_u32(candidateValues.data.val[0], currentValues.data.val[0]),
|
1058
|
+
vcltq_u32(candidateValues.data.val[1], currentValues.data.val[1])};
|
1059
|
+
|
1060
|
+
minValues.data = uint32x4x2_t{
|
1061
|
+
vminq_u32(candidateValues.data.val[0], currentValues.data.val[0]),
|
1062
|
+
vminq_u32(candidateValues.data.val[1], currentValues.data.val[1])};
|
1063
|
+
minIndices.data = uint32x4x2_t{
|
1064
|
+
vbslq_u32(
|
1065
|
+
comparison.val[0],
|
1066
|
+
candidateIndices.data.val[0],
|
1067
|
+
currentIndices.data.val[0]),
|
1068
|
+
vbslq_u32(
|
1069
|
+
comparison.val[1],
|
1070
|
+
candidateIndices.data.val[1],
|
1071
|
+
currentIndices.data.val[1])};
|
1072
|
+
|
1073
|
+
maxValues.data = uint32x4x2_t{
|
1074
|
+
vmaxq_u32(candidateValues.data.val[0], currentValues.data.val[0]),
|
1075
|
+
vmaxq_u32(candidateValues.data.val[1], currentValues.data.val[1])};
|
1076
|
+
maxIndices.data = uint32x4x2_t{
|
1077
|
+
vbslq_u32(
|
1078
|
+
comparison.val[0],
|
1079
|
+
currentIndices.data.val[0],
|
1080
|
+
candidateIndices.data.val[0]),
|
1081
|
+
vbslq_u32(
|
1082
|
+
comparison.val[1],
|
1083
|
+
currentIndices.data.val[1],
|
1084
|
+
candidateIndices.data.val[1])};
|
1085
|
+
}
|
1086
|
+
|
717
1087
|
struct simd8float32 {
|
718
1088
|
float32x4x2_t data;
|
719
1089
|
|
@@ -734,8 +1104,22 @@ struct simd8float32 {
|
|
734
1104
|
explicit simd8float32(const float* x)
|
735
1105
|
: data{vld1q_f32(x), vld1q_f32(x + 4)} {}
|
736
1106
|
|
1107
|
+
explicit simd8float32(
|
1108
|
+
float f0,
|
1109
|
+
float f1,
|
1110
|
+
float f2,
|
1111
|
+
float f3,
|
1112
|
+
float f4,
|
1113
|
+
float f5,
|
1114
|
+
float f6,
|
1115
|
+
float f7) {
|
1116
|
+
float temp[8] = {f0, f1, f2, f3, f4, f5, f6, f7};
|
1117
|
+
data.val[0] = vld1q_f32(temp);
|
1118
|
+
data.val[1] = vld1q_f32(temp + 4);
|
1119
|
+
}
|
1120
|
+
|
737
1121
|
void clear() {
|
738
|
-
detail::simdlib::set1(data,
|
1122
|
+
detail::simdlib::set1(data, 0.f).call<&vdupq_n_f32>();
|
739
1123
|
}
|
740
1124
|
|
741
1125
|
void storeu(float* ptr) const {
|
@@ -761,18 +1145,50 @@ struct simd8float32 {
|
|
761
1145
|
}
|
762
1146
|
|
763
1147
|
simd8float32 operator*(const simd8float32& other) const {
    // Lane-wise product (vmulq_f32 on each half).
    const auto product =
            detail::simdlib::binary_func(data, other.data).call<&vmulq_f32>();
    return simd8float32{product};
}
|
767
1151
|
|
768
1152
|
simd8float32 operator+(const simd8float32& other) const {
    // Lane-wise sum (vaddq_f32 on each half).
    const auto sum =
            detail::simdlib::binary_func(data, other.data).call<&vaddq_f32>();
    return simd8float32{sum};
}
|
772
1156
|
|
773
1157
|
simd8float32 operator-(const simd8float32& other) const {
    // Lane-wise difference (vsubq_f32 on each half).
    const auto difference =
            detail::simdlib::binary_func(data, other.data).call<&vsubq_f32>();
    return simd8float32{difference};
}
|
1161
|
+
|
1162
|
+
simd8float32& operator+=(const simd8float32& other) {
    // Intrinsics are spelled out directly. Here that is friendlier to the
    // compiler than going through binary_func. The halves are independent,
    // so computing both sums before storing is safe even if other is *this.
    const float32x4_t low = vaddq_f32(data.val[0], other.data.val[0]);
    const float32x4_t high = vaddq_f32(data.val[1], other.data.val[1]);
    data.val[0] = low;
    data.val[1] = high;
    return *this;
}
|
1169
|
+
|
1170
|
+
bool operator==(simd8float32 other) const {
    // Floating-point (vceqq_f32) comparison per lane, producing u32 masks.
    // This is a value comparison, not a bitwise one.
    const auto laneEq =
            detail::simdlib::binary_func<::uint32x4x2_t>(data, other.data)
                    .call<&vceqq_f32>();
    // All lanes equal iff the minimum of the ANDed masks is all ones.
    const uint32x4_t allEq = vandq_u32(laneEq.val[0], laneEq.val[1]);
    return vminvq_u32(allEq) == 0xffffffff;
}
|
1177
|
+
|
1178
|
+
bool operator!=(simd8float32 other) const {
    // Negation of lane-wise float equality.
    return !operator==(other);
}
|
1181
|
+
|
1182
|
+
// Checks whether the other holds exactly the same bytes.
|
1183
|
+
bool is_same_as(simd8float32 other) const {
|
1184
|
+
const bool equal0 =
|
1185
|
+
(vminvq_u32(vceqq_f32(data.val[0], other.data.val[0])) ==
|
1186
|
+
0xffffffff);
|
1187
|
+
const bool equal1 =
|
1188
|
+
(vminvq_u32(vceqq_f32(data.val[1], other.data.val[1])) ==
|
1189
|
+
0xffffffff);
|
1190
|
+
|
1191
|
+
return equal0 && equal1;
|
776
1192
|
}
|
777
1193
|
|
778
1194
|
std::string tostring() const {
|
@@ -783,17 +1199,17 @@ struct simd8float32 {
|
|
783
1199
|
// hadd does not cross lanes
|
784
1200
|
inline simd8float32 hadd(const simd8float32& a, const simd8float32& b) {
|
785
1201
|
return simd8float32{
|
786
|
-
detail::simdlib::binary_func(a.data, b.data
|
1202
|
+
detail::simdlib::binary_func(a.data, b.data).call<&vpaddq_f32>()};
|
787
1203
|
}
|
788
1204
|
|
789
1205
|
inline simd8float32 unpacklo(const simd8float32& a, const simd8float32& b) {
    // Per 128-bit half, interleave the low two lanes of a and b (vzip1q_f32).
    const auto zipped =
            detail::simdlib::binary_func(a.data, b.data).call<&vzip1q_f32>();
    return simd8float32{zipped};
}
|
793
1209
|
|
794
1210
|
inline simd8float32 unpackhi(const simd8float32& a, const simd8float32& b) {
    // Per 128-bit half, interleave the high two lanes of a and b (vzip2q_f32).
    const auto zipped =
            detail::simdlib::binary_func(a.data, b.data).call<&vzip2q_f32>();
    return simd8float32{zipped};
}
|
798
1214
|
|
799
1215
|
// compute a * b + c
|
@@ -806,20 +1222,129 @@ inline simd8float32 fmadd(
|
|
806
1222
|
vfmaq_f32(c.data.val[1], a.data.val[1], b.data.val[1])}};
|
807
1223
|
}
|
808
1224
|
|
1225
|
+
// The following primitive is a vectorized version of the following code
|
1226
|
+
// snippet:
|
1227
|
+
// float lowestValue = HUGE_VAL;
|
1228
|
+
// uint lowestIndex = 0;
|
1229
|
+
// for (size_t i = 0; i < n; i++) {
|
1230
|
+
// if (values[i] < lowestValue) {
|
1231
|
+
// lowestValue = values[i];
|
1232
|
+
// lowestIndex = i;
|
1233
|
+
// }
|
1234
|
+
// }
|
1235
|
+
// Vectorized version can be implemented via two operations: cmp and blend
|
1236
|
+
// with something like this:
|
1237
|
+
// lowestValues = [HUGE_VAL; 8];
|
1238
|
+
// lowestIndices = {0, 1, 2, 3, 4, 5, 6, 7};
|
1239
|
+
// for (size_t i = 0; i < n; i += 8) {
|
1240
|
+
// auto comparison = cmp(values + i, lowestValues);
|
1241
|
+
// lowestValues = blend(
|
1242
|
+
// comparison,
|
1243
|
+
// values + i,
|
1244
|
+
// lowestValues);
|
1245
|
+
// lowestIndices = blend(
|
1246
|
+
// comparison,
|
1247
|
+
// i + {0, 1, 2, 3, 4, 5, 6, 7},
|
1248
|
+
// lowestIndices);
|
1249
|
+
// lowestIndices += {8, 8, 8, 8, 8, 8, 8, 8};
|
1250
|
+
// }
|
1251
|
+
// The problem is that blend primitive needs very different instruction
|
1252
|
+
// order for AVX and ARM.
|
1253
|
+
// So, let's introduce a combination of these two in order to avoid
|
1254
|
+
// confusion for ppl who write in low-level SIMD instructions. Additionally,
|
1255
|
+
// these two ops (cmp and blend) are very often used together.
|
1256
|
+
inline void cmplt_and_blend_inplace(
|
1257
|
+
const simd8float32 candidateValues,
|
1258
|
+
const simd8uint32 candidateIndices,
|
1259
|
+
simd8float32& lowestValues,
|
1260
|
+
simd8uint32& lowestIndices) {
|
1261
|
+
const auto comparison = detail::simdlib::binary_func<::uint32x4x2_t>(
|
1262
|
+
candidateValues.data, lowestValues.data)
|
1263
|
+
.call<&vcltq_f32>();
|
1264
|
+
|
1265
|
+
lowestValues.data = float32x4x2_t{
|
1266
|
+
vbslq_f32(
|
1267
|
+
comparison.val[0],
|
1268
|
+
candidateValues.data.val[0],
|
1269
|
+
lowestValues.data.val[0]),
|
1270
|
+
vbslq_f32(
|
1271
|
+
comparison.val[1],
|
1272
|
+
candidateValues.data.val[1],
|
1273
|
+
lowestValues.data.val[1])};
|
1274
|
+
lowestIndices.data = uint32x4x2_t{
|
1275
|
+
vbslq_u32(
|
1276
|
+
comparison.val[0],
|
1277
|
+
candidateIndices.data.val[0],
|
1278
|
+
lowestIndices.data.val[0]),
|
1279
|
+
vbslq_u32(
|
1280
|
+
comparison.val[1],
|
1281
|
+
candidateIndices.data.val[1],
|
1282
|
+
lowestIndices.data.val[1])};
|
1283
|
+
}
|
1284
|
+
|
1285
|
+
// Vectorized version of the following code:
|
1286
|
+
// for (size_t i = 0; i < n; i++) {
|
1287
|
+
// bool flag = (candidateValues[i] < currentValues[i]);
|
1288
|
+
// minValues[i] = flag ? candidateValues[i] : currentValues[i];
|
1289
|
+
// minIndices[i] = flag ? candidateIndices[i] : currentIndices[i];
|
1290
|
+
// maxValues[i] = !flag ? candidateValues[i] : currentValues[i];
|
1291
|
+
// maxIndices[i] = !flag ? candidateIndices[i] : currentIndices[i];
|
1292
|
+
// }
|
1293
|
+
// Max indices evaluation is inaccurate in case of equal values (the index of
|
1294
|
+
// the last equal value is saved instead of the first one), but this behavior
|
1295
|
+
// saves instructions.
|
1296
|
+
inline void cmplt_min_max_fast(
|
1297
|
+
const simd8float32 candidateValues,
|
1298
|
+
const simd8uint32 candidateIndices,
|
1299
|
+
const simd8float32 currentValues,
|
1300
|
+
const simd8uint32 currentIndices,
|
1301
|
+
simd8float32& minValues,
|
1302
|
+
simd8uint32& minIndices,
|
1303
|
+
simd8float32& maxValues,
|
1304
|
+
simd8uint32& maxIndices) {
|
1305
|
+
const uint32x4x2_t comparison = uint32x4x2_t{
|
1306
|
+
vcltq_f32(candidateValues.data.val[0], currentValues.data.val[0]),
|
1307
|
+
vcltq_f32(candidateValues.data.val[1], currentValues.data.val[1])};
|
1308
|
+
|
1309
|
+
minValues.data = float32x4x2_t{
|
1310
|
+
vminq_f32(candidateValues.data.val[0], currentValues.data.val[0]),
|
1311
|
+
vminq_f32(candidateValues.data.val[1], currentValues.data.val[1])};
|
1312
|
+
minIndices.data = uint32x4x2_t{
|
1313
|
+
vbslq_u32(
|
1314
|
+
comparison.val[0],
|
1315
|
+
candidateIndices.data.val[0],
|
1316
|
+
currentIndices.data.val[0]),
|
1317
|
+
vbslq_u32(
|
1318
|
+
comparison.val[1],
|
1319
|
+
candidateIndices.data.val[1],
|
1320
|
+
currentIndices.data.val[1])};
|
1321
|
+
|
1322
|
+
maxValues.data = float32x4x2_t{
|
1323
|
+
vmaxq_f32(candidateValues.data.val[0], currentValues.data.val[0]),
|
1324
|
+
vmaxq_f32(candidateValues.data.val[1], currentValues.data.val[1])};
|
1325
|
+
maxIndices.data = uint32x4x2_t{
|
1326
|
+
vbslq_u32(
|
1327
|
+
comparison.val[0],
|
1328
|
+
currentIndices.data.val[0],
|
1329
|
+
candidateIndices.data.val[0]),
|
1330
|
+
vbslq_u32(
|
1331
|
+
comparison.val[1],
|
1332
|
+
currentIndices.data.val[1],
|
1333
|
+
candidateIndices.data.val[1])};
|
1334
|
+
}
|
1335
|
+
|
809
1336
|
namespace {
|
810
1337
|
|
811
1338
|
// get even float32's of a and b, interleaved
|
812
1339
|
simd8float32 geteven(const simd8float32& a, const simd8float32& b) {
|
813
|
-
return simd8float32{
|
814
|
-
|
815
|
-
vuzp1q_f32(a.data.val[1], b.data.val[1])}};
|
1340
|
+
return simd8float32{
|
1341
|
+
detail::simdlib::binary_func(a.data, b.data).call<&vuzp1q_f32>()};
|
816
1342
|
}
|
817
1343
|
|
818
1344
|
// get odd float32's of a and b, interleaved
|
819
1345
|
simd8float32 getodd(const simd8float32& a, const simd8float32& b) {
|
820
|
-
return simd8float32{
|
821
|
-
|
822
|
-
vuzp2q_f32(a.data.val[1], b.data.val[1])}};
|
1346
|
+
return simd8float32{
|
1347
|
+
detail::simdlib::binary_func(a.data, b.data).call<&vuzp2q_f32>()};
|
823
1348
|
}
|
824
1349
|
|
825
1350
|
// 3 cycles
|