faiss 0.2.0 → 0.2.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (215) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +16 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +7 -7
  5. data/ext/faiss/extconf.rb +6 -3
  6. data/ext/faiss/numo.hpp +4 -4
  7. data/ext/faiss/utils.cpp +1 -1
  8. data/ext/faiss/utils.h +1 -1
  9. data/lib/faiss/version.rb +1 -1
  10. data/vendor/faiss/faiss/AutoTune.cpp +292 -291
  11. data/vendor/faiss/faiss/AutoTune.h +55 -56
  12. data/vendor/faiss/faiss/Clustering.cpp +365 -194
  13. data/vendor/faiss/faiss/Clustering.h +102 -35
  14. data/vendor/faiss/faiss/IVFlib.cpp +171 -195
  15. data/vendor/faiss/faiss/IVFlib.h +48 -51
  16. data/vendor/faiss/faiss/Index.cpp +85 -103
  17. data/vendor/faiss/faiss/Index.h +54 -48
  18. data/vendor/faiss/faiss/Index2Layer.cpp +126 -224
  19. data/vendor/faiss/faiss/Index2Layer.h +22 -36
  20. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +407 -0
  21. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +195 -0
  22. data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
  23. data/vendor/faiss/faiss/IndexBinary.h +140 -132
  24. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
  25. data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
  26. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
  27. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
  28. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
  29. data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
  30. data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
  31. data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
  32. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
  33. data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
  34. data/vendor/faiss/faiss/IndexFlat.cpp +115 -176
  35. data/vendor/faiss/faiss/IndexFlat.h +42 -59
  36. data/vendor/faiss/faiss/IndexFlatCodes.cpp +67 -0
  37. data/vendor/faiss/faiss/IndexFlatCodes.h +47 -0
  38. data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
  39. data/vendor/faiss/faiss/IndexHNSW.h +57 -41
  40. data/vendor/faiss/faiss/IndexIVF.cpp +545 -453
  41. data/vendor/faiss/faiss/IndexIVF.h +169 -118
  42. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +316 -0
  43. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +121 -0
  44. data/vendor/faiss/faiss/IndexIVFFlat.cpp +247 -252
  45. data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
  46. data/vendor/faiss/faiss/IndexIVFPQ.cpp +459 -517
  47. data/vendor/faiss/faiss/IndexIVFPQ.h +75 -67
  48. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
  49. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
  50. data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
  51. data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
  52. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +163 -150
  53. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +38 -25
  54. data/vendor/faiss/faiss/IndexLSH.cpp +66 -113
  55. data/vendor/faiss/faiss/IndexLSH.h +20 -38
  56. data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
  57. data/vendor/faiss/faiss/IndexLattice.h +11 -16
  58. data/vendor/faiss/faiss/IndexNNDescent.cpp +229 -0
  59. data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
  60. data/vendor/faiss/faiss/IndexNSG.cpp +301 -0
  61. data/vendor/faiss/faiss/IndexNSG.h +85 -0
  62. data/vendor/faiss/faiss/IndexPQ.cpp +387 -495
  63. data/vendor/faiss/faiss/IndexPQ.h +64 -82
  64. data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
  65. data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
  66. data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
  67. data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
  68. data/vendor/faiss/faiss/IndexRefine.cpp +139 -127
  69. data/vendor/faiss/faiss/IndexRefine.h +32 -23
  70. data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
  71. data/vendor/faiss/faiss/IndexReplicas.h +62 -56
  72. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +111 -172
  73. data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -59
  74. data/vendor/faiss/faiss/IndexShards.cpp +256 -240
  75. data/vendor/faiss/faiss/IndexShards.h +85 -73
  76. data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
  77. data/vendor/faiss/faiss/MatrixStats.h +7 -10
  78. data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
  79. data/vendor/faiss/faiss/MetaIndexes.h +40 -34
  80. data/vendor/faiss/faiss/MetricType.h +7 -7
  81. data/vendor/faiss/faiss/VectorTransform.cpp +654 -475
  82. data/vendor/faiss/faiss/VectorTransform.h +64 -89
  83. data/vendor/faiss/faiss/clone_index.cpp +78 -73
  84. data/vendor/faiss/faiss/clone_index.h +4 -9
  85. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
  86. data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
  87. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +198 -171
  88. data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
  89. data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
  90. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
  91. data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
  92. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
  93. data/vendor/faiss/faiss/gpu/GpuIcmEncoder.h +60 -0
  94. data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
  95. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
  96. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
  97. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
  98. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
  99. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
  100. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
  101. data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
  102. data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
  103. data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
  104. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
  105. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
  106. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
  107. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
  108. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
  109. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
  110. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
  111. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
  112. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
  113. data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
  114. data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
  115. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
  116. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
  117. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
  118. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
  119. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
  120. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
  121. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
  122. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
  123. data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
  124. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
  125. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
  126. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
  127. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
  128. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
  129. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
  130. data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
  131. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +503 -0
  132. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +175 -0
  133. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
  134. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
  135. data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
  136. data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
  137. data/vendor/faiss/faiss/impl/FaissException.h +41 -29
  138. data/vendor/faiss/faiss/impl/HNSW.cpp +606 -617
  139. data/vendor/faiss/faiss/impl/HNSW.h +179 -200
  140. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +855 -0
  141. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +244 -0
  142. data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
  143. data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
  144. data/vendor/faiss/faiss/impl/NSG.cpp +679 -0
  145. data/vendor/faiss/faiss/impl/NSG.h +199 -0
  146. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
  147. data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
  148. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
  149. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
  150. data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
  151. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +758 -0
  152. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +188 -0
  153. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
  154. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +647 -707
  155. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
  156. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
  157. data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
  158. data/vendor/faiss/faiss/impl/index_read.cpp +631 -480
  159. data/vendor/faiss/faiss/impl/index_write.cpp +547 -407
  160. data/vendor/faiss/faiss/impl/io.cpp +76 -95
  161. data/vendor/faiss/faiss/impl/io.h +31 -41
  162. data/vendor/faiss/faiss/impl/io_macros.h +60 -29
  163. data/vendor/faiss/faiss/impl/kmeans1d.cpp +301 -0
  164. data/vendor/faiss/faiss/impl/kmeans1d.h +48 -0
  165. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
  166. data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
  167. data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
  168. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
  169. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
  170. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
  171. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
  172. data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
  173. data/vendor/faiss/faiss/index_factory.cpp +619 -397
  174. data/vendor/faiss/faiss/index_factory.h +8 -6
  175. data/vendor/faiss/faiss/index_io.h +23 -26
  176. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
  177. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
  178. data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
  179. data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
  180. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
  181. data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
  182. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
  183. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
  184. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
  185. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
  186. data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
  187. data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
  188. data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
  189. data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
  190. data/vendor/faiss/faiss/utils/Heap.h +186 -209
  191. data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
  192. data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
  193. data/vendor/faiss/faiss/utils/distances.cpp +305 -312
  194. data/vendor/faiss/faiss/utils/distances.h +170 -122
  195. data/vendor/faiss/faiss/utils/distances_simd.cpp +498 -508
  196. data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
  197. data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
  198. data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
  199. data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
  200. data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
  201. data/vendor/faiss/faiss/utils/hamming.h +62 -85
  202. data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
  203. data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
  204. data/vendor/faiss/faiss/utils/partitioning.h +26 -21
  205. data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
  206. data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
  207. data/vendor/faiss/faiss/utils/random.cpp +39 -63
  208. data/vendor/faiss/faiss/utils/random.h +13 -16
  209. data/vendor/faiss/faiss/utils/simdlib.h +4 -2
  210. data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
  211. data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
  212. data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
  213. data/vendor/faiss/faiss/utils/utils.cpp +304 -287
  214. data/vendor/faiss/faiss/utils/utils.h +54 -49
  215. metadata +29 -4
@@ -0,0 +1,832 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ #pragma once
9
+
10
+ // TODO: Support big endian (currently supporting only little endian)
11
+
12
+ #include <algorithm>
13
+ #include <cstddef>
14
+ #include <cstdint>
15
+ #include <cstring>
16
+ #include <string>
17
+ #include <type_traits>
18
+
19
+ #include <arm_neon.h>
20
+
21
+ namespace faiss {
22
+
23
+ namespace detail {
24
+
25
+ namespace simdlib {
26
+
27
+ static inline uint8x16x2_t reinterpret_u8(const uint8x16x2_t& v) {
28
+ return v;
29
+ }
30
+
31
+ static inline uint8x16x2_t reinterpret_u8(const uint16x8x2_t& v) {
32
+ return {vreinterpretq_u8_u16(v.val[0]), vreinterpretq_u8_u16(v.val[1])};
33
+ }
34
+
35
+ static inline uint8x16x2_t reinterpret_u8(const uint32x4x2_t& v) {
36
+ return {vreinterpretq_u8_u32(v.val[0]), vreinterpretq_u8_u32(v.val[1])};
37
+ }
38
+
39
+ static inline uint8x16x2_t reinterpret_u8(const float32x4x2_t& v) {
40
+ return {vreinterpretq_u8_f32(v.val[0]), vreinterpretq_u8_f32(v.val[1])};
41
+ }
42
+
43
+ static inline uint16x8x2_t reinterpret_u16(const uint8x16x2_t& v) {
44
+ return {vreinterpretq_u16_u8(v.val[0]), vreinterpretq_u16_u8(v.val[1])};
45
+ }
46
+
47
+ static inline uint16x8x2_t reinterpret_u16(const uint16x8x2_t& v) {
48
+ return v;
49
+ }
50
+
51
+ static inline uint16x8x2_t reinterpret_u16(const uint32x4x2_t& v) {
52
+ return {vreinterpretq_u16_u32(v.val[0]), vreinterpretq_u16_u32(v.val[1])};
53
+ }
54
+
55
+ static inline uint16x8x2_t reinterpret_u16(const float32x4x2_t& v) {
56
+ return {vreinterpretq_u16_f32(v.val[0]), vreinterpretq_u16_f32(v.val[1])};
57
+ }
58
+
59
+ static inline uint32x4x2_t reinterpret_u32(const uint8x16x2_t& v) {
60
+ return {vreinterpretq_u32_u8(v.val[0]), vreinterpretq_u32_u8(v.val[1])};
61
+ }
62
+
63
+ static inline uint32x4x2_t reinterpret_u32(const uint16x8x2_t& v) {
64
+ return {vreinterpretq_u32_u16(v.val[0]), vreinterpretq_u32_u16(v.val[1])};
65
+ }
66
+
67
+ static inline uint32x4x2_t reinterpret_u32(const uint32x4x2_t& v) {
68
+ return v;
69
+ }
70
+
71
+ static inline uint32x4x2_t reinterpret_u32(const float32x4x2_t& v) {
72
+ return {vreinterpretq_u32_f32(v.val[0]), vreinterpretq_u32_f32(v.val[1])};
73
+ }
74
+
75
+ static inline float32x4x2_t reinterpret_f32(const uint8x16x2_t& v) {
76
+ return {vreinterpretq_f32_u8(v.val[0]), vreinterpretq_f32_u8(v.val[1])};
77
+ }
78
+
79
+ static inline float32x4x2_t reinterpret_f32(const uint16x8x2_t& v) {
80
+ return {vreinterpretq_f32_u16(v.val[0]), vreinterpretq_f32_u16(v.val[1])};
81
+ }
82
+
83
+ static inline float32x4x2_t reinterpret_f32(const uint32x4x2_t& v) {
84
+ return {vreinterpretq_f32_u32(v.val[0]), vreinterpretq_f32_u32(v.val[1])};
85
+ }
86
+
87
+ static inline float32x4x2_t reinterpret_f32(const float32x4x2_t& v) {
88
+ return v;
89
+ }
90
+
91
+ template <
92
+ typename T,
93
+ typename U = decltype(reinterpret_u8(std::declval<T>().data))>
94
+ struct is_simd256bit : std::is_same<U, uint8x16x2_t> {};
95
+
96
+ static inline void bin(const char (&bytes)[32], char bits[257]) {
97
+ for (int i = 0; i < 256; ++i) {
98
+ bits[i] = '0' + ((bytes[i / 8] >> (i % 8)) & 1);
99
+ }
100
+ bits[256] = 0;
101
+ }
102
+
103
+ template <typename T, size_t N, typename S>
104
+ static inline void bin(const S& simd, char bits[257]) {
105
+ static_assert(
106
+ std::is_same<void (S::*)(T*) const, decltype(&S::store)>::value,
107
+ "invalid T");
108
+ T ds[N];
109
+ simd.store(ds);
110
+ char bytes[32];
111
+ std::memcpy(bytes, ds, sizeof(char) * 32);
112
+ bin(bytes, bits);
113
+ }
114
+
115
+ template <typename S>
116
+ static inline std::string bin(const S& simd) {
117
+ char bits[257];
118
+ simd.bin(bits);
119
+ return std::string(bits);
120
+ }
121
+
122
+ template <typename D, typename F, typename T>
123
+ static inline void set1(D& d, F&& f, T t) {
124
+ const auto v = f(t);
125
+ d.val[0] = v;
126
+ d.val[1] = v;
127
+ }
128
+
129
+ template <typename T, size_t N, typename S>
130
+ static inline std::string elements_to_string(const char* fmt, const S& simd) {
131
+ static_assert(
132
+ std::is_same<void (S::*)(T*) const, decltype(&S::store)>::value,
133
+ "invalid T");
134
+ T bytes[N];
135
+ simd.store(bytes);
136
+ char res[1000], *ptr = res;
137
+ for (size_t i = 0; i < N; ++i) {
138
+ ptr += sprintf(ptr, fmt, bytes[i]);
139
+ }
140
+ // strip last ,
141
+ ptr[-1] = 0;
142
+ return std::string(res);
143
+ }
144
+
145
+ template <typename T, typename F>
146
+ static inline T unary_func(const T& a, F&& f) {
147
+ T t;
148
+ t.val[0] = f(a.val[0]);
149
+ t.val[1] = f(a.val[1]);
150
+ return t;
151
+ }
152
+
153
+ template <typename T, typename F>
154
+ static inline T binary_func(const T& a, const T& b, F&& f) {
155
+ T t;
156
+ t.val[0] = f(a.val[0], b.val[0]);
157
+ t.val[1] = f(a.val[1], b.val[1]);
158
+ return t;
159
+ }
160
+
161
+ static inline uint16_t vmovmask_u8(const uint8x16_t& v) {
162
+ uint8_t d[16];
163
+ const auto v2 = vreinterpretq_u16_u8(vshrq_n_u8(v, 7));
164
+ const auto v3 = vreinterpretq_u32_u16(vsraq_n_u16(v2, v2, 7));
165
+ const auto v4 = vreinterpretq_u64_u32(vsraq_n_u32(v3, v3, 14));
166
+ vst1q_u8(d, vreinterpretq_u8_u64(vsraq_n_u64(v4, v4, 28)));
167
+ return d[0] | static_cast<uint16_t>(d[8]) << 8u;
168
+ }
169
+
170
+ template <uint16x8_t (*F)(uint16x8_t, uint16x8_t)>
171
+ static inline uint32_t cmp_xe32(
172
+ const uint16x8x2_t& d0,
173
+ const uint16x8x2_t& d1,
174
+ const uint16x8x2_t& thr) {
175
+ const auto d0_thr = detail::simdlib::binary_func(d0, thr, F);
176
+ const auto d1_thr = detail::simdlib::binary_func(d1, thr, F);
177
+ const auto d0_mask = vmovmask_u8(
178
+ vmovn_high_u16(vmovn_u16(d0_thr.val[0]), d0_thr.val[1]));
179
+ const auto d1_mask = vmovmask_u8(
180
+ vmovn_high_u16(vmovn_u16(d1_thr.val[0]), d1_thr.val[1]));
181
+ return d0_mask | static_cast<uint32_t>(d1_mask) << 16;
182
+ }
183
+
184
+ template <std::uint8_t Shift>
185
+ static inline uint16x8_t vshlq(uint16x8_t vec) {
186
+ return vshlq_n_u16(vec, Shift);
187
+ }
188
+
189
+ template <std::uint8_t Shift>
190
+ static inline uint16x8_t vshrq(uint16x8_t vec) {
191
+ return vshrq_n_u16(vec, Shift);
192
+ }
193
+
194
+ } // namespace simdlib
195
+
196
+ } // namespace detail
197
+
198
+ /// vector of 16 elements in uint16
199
+ struct simd16uint16 {
200
+ uint16x8x2_t data;
201
+
202
+ simd16uint16() = default;
203
+
204
+ explicit simd16uint16(int x) : data{vdupq_n_u16(x), vdupq_n_u16(x)} {}
205
+
206
+ explicit simd16uint16(uint16_t x) : data{vdupq_n_u16(x), vdupq_n_u16(x)} {}
207
+
208
+ explicit simd16uint16(const uint16x8x2_t& v) : data{v} {}
209
+
210
+ template <
211
+ typename T,
212
+ typename std::enable_if<
213
+ detail::simdlib::is_simd256bit<T>::value,
214
+ std::nullptr_t>::type = nullptr>
215
+ explicit simd16uint16(const T& x)
216
+ : data{detail::simdlib::reinterpret_u16(x.data)} {}
217
+
218
+ explicit simd16uint16(const uint16_t* x)
219
+ : data{vld1q_u16(x), vld1q_u16(x + 8)} {}
220
+
221
+ void clear() {
222
+ detail::simdlib::set1(data, &vdupq_n_u16, static_cast<uint16_t>(0));
223
+ }
224
+
225
+ void storeu(uint16_t* ptr) const {
226
+ vst1q_u16(ptr, data.val[0]);
227
+ vst1q_u16(ptr + 8, data.val[1]);
228
+ }
229
+
230
+ void loadu(const uint16_t* ptr) {
231
+ data.val[0] = vld1q_u16(ptr);
232
+ data.val[1] = vld1q_u16(ptr + 8);
233
+ }
234
+
235
+ void store(uint16_t* ptr) const {
236
+ storeu(ptr);
237
+ }
238
+
239
+ void bin(char bits[257]) const {
240
+ detail::simdlib::bin<uint16_t, 16u>(*this, bits);
241
+ }
242
+
243
+ std::string bin() const {
244
+ return detail::simdlib::bin(*this);
245
+ }
246
+
247
+ std::string elements_to_string(const char* fmt) const {
248
+ return detail::simdlib::elements_to_string<uint16_t, 16u>(fmt, *this);
249
+ }
250
+
251
+ std::string hex() const {
252
+ return elements_to_string("%02x,");
253
+ }
254
+
255
+ std::string dec() const {
256
+ return elements_to_string("%3d,");
257
+ }
258
+
259
+ void set1(uint16_t x) {
260
+ detail::simdlib::set1(data, &vdupq_n_u16, x);
261
+ }
262
+
263
+ // shift must be known at compile time
264
+ simd16uint16 operator>>(const int shift) const {
265
+ switch (shift) {
266
+ case 0:
267
+ return *this;
268
+ case 1:
269
+ return simd16uint16{detail::simdlib::unary_func(
270
+ data, detail::simdlib::vshrq<1>)};
271
+ case 2:
272
+ return simd16uint16{detail::simdlib::unary_func(
273
+ data, detail::simdlib::vshrq<2>)};
274
+ case 3:
275
+ return simd16uint16{detail::simdlib::unary_func(
276
+ data, detail::simdlib::vshrq<3>)};
277
+ case 4:
278
+ return simd16uint16{detail::simdlib::unary_func(
279
+ data, detail::simdlib::vshrq<4>)};
280
+ case 5:
281
+ return simd16uint16{detail::simdlib::unary_func(
282
+ data, detail::simdlib::vshrq<5>)};
283
+ case 6:
284
+ return simd16uint16{detail::simdlib::unary_func(
285
+ data, detail::simdlib::vshrq<6>)};
286
+ case 7:
287
+ return simd16uint16{detail::simdlib::unary_func(
288
+ data, detail::simdlib::vshrq<7>)};
289
+ case 8:
290
+ return simd16uint16{detail::simdlib::unary_func(
291
+ data, detail::simdlib::vshrq<8>)};
292
+ case 9:
293
+ return simd16uint16{detail::simdlib::unary_func(
294
+ data, detail::simdlib::vshrq<9>)};
295
+ case 10:
296
+ return simd16uint16{detail::simdlib::unary_func(
297
+ data, detail::simdlib::vshrq<10>)};
298
+ case 11:
299
+ return simd16uint16{detail::simdlib::unary_func(
300
+ data, detail::simdlib::vshrq<11>)};
301
+ case 12:
302
+ return simd16uint16{detail::simdlib::unary_func(
303
+ data, detail::simdlib::vshrq<12>)};
304
+ case 13:
305
+ return simd16uint16{detail::simdlib::unary_func(
306
+ data, detail::simdlib::vshrq<13>)};
307
+ case 14:
308
+ return simd16uint16{detail::simdlib::unary_func(
309
+ data, detail::simdlib::vshrq<14>)};
310
+ case 15:
311
+ return simd16uint16{detail::simdlib::unary_func(
312
+ data, detail::simdlib::vshrq<15>)};
313
+ default:
314
+ FAISS_THROW_FMT("Invalid shift %d", shift);
315
+ }
316
+ }
317
+
318
+ // shift must be known at compile time
319
+ simd16uint16 operator<<(const int shift) const {
320
+ switch (shift) {
321
+ case 0:
322
+ return *this;
323
+ case 1:
324
+ return simd16uint16{detail::simdlib::unary_func(
325
+ data, detail::simdlib::vshlq<1>)};
326
+ case 2:
327
+ return simd16uint16{detail::simdlib::unary_func(
328
+ data, detail::simdlib::vshlq<2>)};
329
+ case 3:
330
+ return simd16uint16{detail::simdlib::unary_func(
331
+ data, detail::simdlib::vshlq<3>)};
332
+ case 4:
333
+ return simd16uint16{detail::simdlib::unary_func(
334
+ data, detail::simdlib::vshlq<4>)};
335
+ case 5:
336
+ return simd16uint16{detail::simdlib::unary_func(
337
+ data, detail::simdlib::vshlq<5>)};
338
+ case 6:
339
+ return simd16uint16{detail::simdlib::unary_func(
340
+ data, detail::simdlib::vshlq<6>)};
341
+ case 7:
342
+ return simd16uint16{detail::simdlib::unary_func(
343
+ data, detail::simdlib::vshlq<7>)};
344
+ case 8:
345
+ return simd16uint16{detail::simdlib::unary_func(
346
+ data, detail::simdlib::vshlq<8>)};
347
+ case 9:
348
+ return simd16uint16{detail::simdlib::unary_func(
349
+ data, detail::simdlib::vshlq<9>)};
350
+ case 10:
351
+ return simd16uint16{detail::simdlib::unary_func(
352
+ data, detail::simdlib::vshlq<10>)};
353
+ case 11:
354
+ return simd16uint16{detail::simdlib::unary_func(
355
+ data, detail::simdlib::vshlq<11>)};
356
+ case 12:
357
+ return simd16uint16{detail::simdlib::unary_func(
358
+ data, detail::simdlib::vshlq<12>)};
359
+ case 13:
360
+ return simd16uint16{detail::simdlib::unary_func(
361
+ data, detail::simdlib::vshlq<13>)};
362
+ case 14:
363
+ return simd16uint16{detail::simdlib::unary_func(
364
+ data, detail::simdlib::vshlq<14>)};
365
+ case 15:
366
+ return simd16uint16{detail::simdlib::unary_func(
367
+ data, detail::simdlib::vshlq<15>)};
368
+ default:
369
+ FAISS_THROW_FMT("Invalid shift %d", shift);
370
+ }
371
+ }
372
+
373
+ simd16uint16 operator+=(const simd16uint16& other) {
374
+ *this = *this + other;
375
+ return *this;
376
+ }
377
+
378
+ simd16uint16 operator-=(const simd16uint16& other) {
379
+ *this = *this - other;
380
+ return *this;
381
+ }
382
+
383
+ simd16uint16 operator+(const simd16uint16& other) const {
384
+ return simd16uint16{
385
+ detail::simdlib::binary_func(data, other.data, &vaddq_u16)};
386
+ }
387
+
388
+ simd16uint16 operator-(const simd16uint16& other) const {
389
+ return simd16uint16{
390
+ detail::simdlib::binary_func(data, other.data, &vsubq_u16)};
391
+ }
392
+
393
+ template <
394
+ typename T,
395
+ typename std::enable_if<
396
+ detail::simdlib::is_simd256bit<T>::value,
397
+ std::nullptr_t>::type = nullptr>
398
+ simd16uint16 operator&(const T& other) const {
399
+ return simd16uint16{detail::simdlib::binary_func(
400
+ data,
401
+ detail::simdlib::reinterpret_u16(other.data),
402
+ &vandq_u16)};
403
+ }
404
+
405
+ template <
406
+ typename T,
407
+ typename std::enable_if<
408
+ detail::simdlib::is_simd256bit<T>::value,
409
+ std::nullptr_t>::type = nullptr>
410
+ simd16uint16 operator|(const T& other) const {
411
+ return simd16uint16{detail::simdlib::binary_func(
412
+ data,
413
+ detail::simdlib::reinterpret_u16(other.data),
414
+ &vorrq_u16)};
415
+ }
416
+
417
+ // returns binary masks
418
+ simd16uint16 operator==(const simd16uint16& other) const {
419
+ return simd16uint16{
420
+ detail::simdlib::binary_func(data, other.data, &vceqq_u16)};
421
+ }
422
+
423
+ simd16uint16 operator~() const {
424
+ return simd16uint16{detail::simdlib::unary_func(data, &vmvnq_u16)};
425
+ }
426
+
427
+ // get scalar at index 0
428
+ uint16_t get_scalar_0() const {
429
+ return vgetq_lane_u16(data.val[0], 0);
430
+ }
431
+
432
+ // mask of elements where this >= thresh
433
+ // 2 bit per component: 16 * 2 = 32 bit
434
+ uint32_t ge_mask(const simd16uint16& thresh) const {
435
+ const auto input =
436
+ detail::simdlib::binary_func(data, thresh.data, &vcgeq_u16);
437
+ const auto vmovmask_u16 = [](uint16x8_t v) -> uint16_t {
438
+ uint16_t d[8];
439
+ const auto v2 = vreinterpretq_u32_u16(vshrq_n_u16(v, 14));
440
+ const auto v3 = vreinterpretq_u64_u32(vsraq_n_u32(v2, v2, 14));
441
+ vst1q_u16(d, vreinterpretq_u16_u64(vsraq_n_u64(v3, v3, 28)));
442
+ return d[0] | d[4] << 8u;
443
+ };
444
+ return static_cast<uint32_t>(vmovmask_u16(input.val[1])) << 16u |
445
+ vmovmask_u16(input.val[0]);
446
+ }
447
+
448
+ uint32_t le_mask(const simd16uint16& thresh) const {
449
+ return thresh.ge_mask(*this);
450
+ }
451
+
452
+ uint32_t gt_mask(const simd16uint16& thresh) const {
453
+ return ~le_mask(thresh);
454
+ }
455
+
456
+ bool all_gt(const simd16uint16& thresh) const {
457
+ return le_mask(thresh) == 0;
458
+ }
459
+
460
+ // for debugging only
461
+ uint16_t operator[](int i) const {
462
+ uint16_t tab[8];
463
+ const bool high = i >= 8;
464
+ vst1q_u16(tab, data.val[high]);
465
+ return tab[i - high * 8];
466
+ }
467
+
468
+ void accu_min(const simd16uint16& incoming) {
469
+ data = detail::simdlib::binary_func(incoming.data, data, &vminq_u16);
470
+ }
471
+
472
+ void accu_max(const simd16uint16& incoming) {
473
+ data = detail::simdlib::binary_func(incoming.data, data, &vmaxq_u16);
474
+ }
475
+ };
476
+
477
+ // not really a std::min because it returns an elementwise min
478
+ inline simd16uint16 min(const simd16uint16& av, const simd16uint16& bv) {
479
+ return simd16uint16{
480
+ detail::simdlib::binary_func(av.data, bv.data, &vminq_u16)};
481
+ }
482
+
483
+ inline simd16uint16 max(const simd16uint16& av, const simd16uint16& bv) {
484
+ return simd16uint16{
485
+ detail::simdlib::binary_func(av.data, bv.data, &vmaxq_u16)};
486
+ }
487
+
488
+ // decompose in 128-lanes: a = (a0, a1), b = (b0, b1)
489
+ // return (a0 + a1, b0 + b1)
490
+ // TODO find a better name
491
+ inline simd16uint16 combine2x2(const simd16uint16& a, const simd16uint16& b) {
492
+ return simd16uint16{uint16x8x2_t{
493
+ vaddq_u16(a.data.val[0], a.data.val[1]),
494
+ vaddq_u16(b.data.val[0], b.data.val[1])}};
495
+ }
496
+
497
+ // compare d0 and d1 to thr, return 32 bits corresponding to the concatenation
498
+ // of d0 and d1 with thr
499
+ inline uint32_t cmp_ge32(
500
+ const simd16uint16& d0,
501
+ const simd16uint16& d1,
502
+ const simd16uint16& thr) {
503
+ return detail::simdlib::cmp_xe32<&vcgeq_u16>(d0.data, d1.data, thr.data);
504
+ }
505
+
506
+ inline uint32_t cmp_le32(
507
+ const simd16uint16& d0,
508
+ const simd16uint16& d1,
509
+ const simd16uint16& thr) {
510
+ return detail::simdlib::cmp_xe32<&vcleq_u16>(d0.data, d1.data, thr.data);
511
+ }
512
+
513
+ // vector of 32 unsigned 8-bit integers
514
+ struct simd32uint8 {
515
+ uint8x16x2_t data;
516
+
517
+ simd32uint8() = default;
518
+
519
+ explicit simd32uint8(int x) : data{vdupq_n_u8(x), vdupq_n_u8(x)} {}
520
+
521
+ explicit simd32uint8(uint8_t x) : data{vdupq_n_u8(x), vdupq_n_u8(x)} {}
522
+
523
+ explicit simd32uint8(const uint8x16x2_t& v) : data{v} {}
524
+
525
+ template <
526
+ typename T,
527
+ typename std::enable_if<
528
+ detail::simdlib::is_simd256bit<T>::value,
529
+ std::nullptr_t>::type = nullptr>
530
+ explicit simd32uint8(const T& x)
531
+ : data{detail::simdlib::reinterpret_u8(x.data)} {}
532
+
533
+ explicit simd32uint8(const uint8_t* x)
534
+ : data{vld1q_u8(x), vld1q_u8(x + 16)} {}
535
+
536
+ void clear() {
537
+ detail::simdlib::set1(data, &vdupq_n_u8, static_cast<uint8_t>(0));
538
+ }
539
+
540
+ void storeu(uint8_t* ptr) const {
541
+ vst1q_u8(ptr, data.val[0]);
542
+ vst1q_u8(ptr + 16, data.val[1]);
543
+ }
544
+
545
+ void loadu(const uint8_t* ptr) {
546
+ data.val[0] = vld1q_u8(ptr);
547
+ data.val[1] = vld1q_u8(ptr + 16);
548
+ }
549
+
550
+ void store(uint8_t* ptr) const {
551
+ storeu(ptr);
552
+ }
553
+
554
+ void bin(char bits[257]) const {
555
+ uint8_t bytes[32];
556
+ store(bytes);
557
+ detail::simdlib::bin(
558
+ const_cast<const char(&)[32]>(
559
+ reinterpret_cast<char(&)[32]>(bytes)),
560
+ bits);
561
+ }
562
+
563
+ std::string bin() const {
564
+ return detail::simdlib::bin(*this);
565
+ }
566
+
567
+ std::string elements_to_string(const char* fmt) const {
568
+ return detail::simdlib::elements_to_string<uint8_t, 32u>(fmt, *this);
569
+ }
570
+
571
+ std::string hex() const {
572
+ return elements_to_string("%02x,");
573
+ }
574
+
575
+ std::string dec() const {
576
+ return elements_to_string("%3d,");
577
+ }
578
+
579
+ void set1(uint8_t x) {
580
+ detail::simdlib::set1(data, &vdupq_n_u8, x);
581
+ }
582
+
583
+ template <
584
+ typename T,
585
+ typename std::enable_if<
586
+ detail::simdlib::is_simd256bit<T>::value,
587
+ std::nullptr_t>::type = nullptr>
588
+ simd32uint8 operator&(const T& other) const {
589
+ return simd32uint8{detail::simdlib::binary_func(
590
+ data, detail::simdlib::reinterpret_u8(other.data), &vandq_u8)};
591
+ }
592
+
593
+ simd32uint8 operator+(const simd32uint8& other) const {
594
+ return simd32uint8{
595
+ detail::simdlib::binary_func(data, other.data, &vaddq_u8)};
596
+ }
597
+
598
+ // The very important operation that everything relies on
599
+ simd32uint8 lookup_2_lanes(const simd32uint8& idx) const {
600
+ return simd32uint8{
601
+ detail::simdlib::binary_func(data, idx.data, &vqtbl1q_u8)};
602
+ }
603
+
604
+ simd32uint8 operator+=(const simd32uint8& other) {
605
+ *this = *this + other;
606
+ return *this;
607
+ }
608
+
609
+ // for debugging only
610
+ uint8_t operator[](int i) const {
611
+ uint8_t tab[16];
612
+ const bool high = i >= 16;
613
+ vst1q_u8(tab, data.val[high]);
614
+ return tab[i - high * 16];
615
+ }
616
+ };
617
+
618
+ // convert with saturation
619
+ // careful: this does not cross lanes, so the order is weird
620
+ inline simd32uint8 uint16_to_uint8_saturate(
621
+ const simd16uint16& a,
622
+ const simd16uint16& b) {
623
+ return simd32uint8{uint8x16x2_t{
624
+ vqmovn_high_u16(vqmovn_u16(a.data.val[0]), b.data.val[0]),
625
+ vqmovn_high_u16(vqmovn_u16(a.data.val[1]), b.data.val[1])}};
626
+ }
627
+
628
+ /// get most significant bit of each byte
629
+ inline uint32_t get_MSBs(const simd32uint8& a) {
630
+ using detail::simdlib::vmovmask_u8;
631
+ return vmovmask_u8(a.data.val[0]) |
632
+ static_cast<uint32_t>(vmovmask_u8(a.data.val[1])) << 16u;
633
+ }
634
+
635
+ /// use MSB of each byte of mask to select a byte between a and b
636
+ inline simd32uint8 blendv(
637
+ const simd32uint8& a,
638
+ const simd32uint8& b,
639
+ const simd32uint8& mask) {
640
+ const auto msb = vdupq_n_u8(0x80);
641
+ const uint8x16x2_t msb_mask = {
642
+ vtstq_u8(mask.data.val[0], msb), vtstq_u8(mask.data.val[1], msb)};
643
+ const uint8x16x2_t selected = {
644
+ vbslq_u8(msb_mask.val[0], a.data.val[0], b.data.val[0]),
645
+ vbslq_u8(msb_mask.val[1], a.data.val[1], b.data.val[1])};
646
+ return simd32uint8{selected};
647
+ }
648
+
649
+ /// vector of 8 unsigned 32-bit integers
650
+ struct simd8uint32 {
651
+ uint32x4x2_t data;
652
+
653
+ simd8uint32() = default;
654
+
655
+ explicit simd8uint32(uint32_t x) : data{vdupq_n_u32(x), vdupq_n_u32(x)} {}
656
+
657
+ explicit simd8uint32(const uint32x4x2_t& v) : data{v} {}
658
+
659
+ template <
660
+ typename T,
661
+ typename std::enable_if<
662
+ detail::simdlib::is_simd256bit<T>::value,
663
+ std::nullptr_t>::type = nullptr>
664
+ explicit simd8uint32(const T& x)
665
+ : data{detail::simdlib::reinterpret_u32(x.data)} {}
666
+
667
+ explicit simd8uint32(const uint8_t* x) : simd8uint32(simd32uint8(x)) {}
668
+
669
+ void clear() {
670
+ detail::simdlib::set1(data, &vdupq_n_u32, static_cast<uint32_t>(0));
671
+ }
672
+
673
+ void storeu(uint32_t* ptr) const {
674
+ vst1q_u32(ptr, data.val[0]);
675
+ vst1q_u32(ptr + 4, data.val[1]);
676
+ }
677
+
678
+ void loadu(const uint32_t* ptr) {
679
+ data.val[0] = vld1q_u32(ptr);
680
+ data.val[1] = vld1q_u32(ptr + 4);
681
+ }
682
+
683
+ void store(uint32_t* ptr) const {
684
+ storeu(ptr);
685
+ }
686
+
687
+ void bin(char bits[257]) const {
688
+ detail::simdlib::bin<uint32_t, 8u>(*this, bits);
689
+ }
690
+
691
+ std::string bin() const {
692
+ return detail::simdlib::bin(*this);
693
+ }
694
+
695
+ std::string elements_to_string(const char* fmt) const {
696
+ return detail::simdlib::elements_to_string<uint32_t, 8u>(fmt, *this);
697
+ }
698
+
699
+ std::string hex() const {
700
+ return elements_to_string("%08x,");
701
+ }
702
+
703
+ std::string dec() const {
704
+ return elements_to_string("%10d,");
705
+ }
706
+
707
+ void set1(uint32_t x) {
708
+ detail::simdlib::set1(data, &vdupq_n_u32, x);
709
+ }
710
+ };
711
+
712
+ struct simd8float32 {
713
+ float32x4x2_t data;
714
+
715
+ simd8float32() = default;
716
+
717
+ explicit simd8float32(float x) : data{vdupq_n_f32(x), vdupq_n_f32(x)} {}
718
+
719
+ explicit simd8float32(const float32x4x2_t& v) : data{v} {}
720
+
721
+ template <
722
+ typename T,
723
+ typename std::enable_if<
724
+ detail::simdlib::is_simd256bit<T>::value,
725
+ std::nullptr_t>::type = nullptr>
726
+ explicit simd8float32(const T& x)
727
+ : data{detail::simdlib::reinterpret_f32(x.data)} {}
728
+
729
+ explicit simd8float32(const float* x)
730
+ : data{vld1q_f32(x), vld1q_f32(x + 4)} {}
731
+
732
+ void clear() {
733
+ detail::simdlib::set1(data, &vdupq_n_f32, 0.f);
734
+ }
735
+
736
+ void storeu(float* ptr) const {
737
+ vst1q_f32(ptr, data.val[0]);
738
+ vst1q_f32(ptr + 4, data.val[1]);
739
+ }
740
+
741
+ void loadu(const float* ptr) {
742
+ data.val[0] = vld1q_f32(ptr);
743
+ data.val[1] = vld1q_f32(ptr + 4);
744
+ }
745
+
746
+ void store(float* ptr) const {
747
+ storeu(ptr);
748
+ }
749
+
750
+ void bin(char bits[257]) const {
751
+ detail::simdlib::bin<float, 8u>(*this, bits);
752
+ }
753
+
754
+ std::string bin() const {
755
+ return detail::simdlib::bin(*this);
756
+ }
757
+
758
+ simd8float32 operator*(const simd8float32& other) const {
759
+ return simd8float32{
760
+ detail::simdlib::binary_func(data, other.data, &vmulq_f32)};
761
+ }
762
+
763
+ simd8float32 operator+(const simd8float32& other) const {
764
+ return simd8float32{
765
+ detail::simdlib::binary_func(data, other.data, &vaddq_f32)};
766
+ }
767
+
768
+ simd8float32 operator-(const simd8float32& other) const {
769
+ return simd8float32{
770
+ detail::simdlib::binary_func(data, other.data, &vsubq_f32)};
771
+ }
772
+
773
+ std::string tostring() const {
774
+ return detail::simdlib::elements_to_string<float, 8u>("%g,", *this);
775
+ }
776
+ };
777
+
778
+ // hadd does not cross lanes
779
+ inline simd8float32 hadd(const simd8float32& a, const simd8float32& b) {
780
+ return simd8float32{
781
+ detail::simdlib::binary_func(a.data, b.data, &vpaddq_f32)};
782
+ }
783
+
784
+ inline simd8float32 unpacklo(const simd8float32& a, const simd8float32& b) {
785
+ return simd8float32{
786
+ detail::simdlib::binary_func(a.data, b.data, &vzip1q_f32)};
787
+ }
788
+
789
+ inline simd8float32 unpackhi(const simd8float32& a, const simd8float32& b) {
790
+ return simd8float32{
791
+ detail::simdlib::binary_func(a.data, b.data, &vzip2q_f32)};
792
+ }
793
+
794
+ // compute a * b + c
795
+ inline simd8float32 fmadd(
796
+ const simd8float32& a,
797
+ const simd8float32& b,
798
+ const simd8float32& c) {
799
+ return simd8float32{float32x4x2_t{
800
+ vfmaq_f32(c.data.val[0], a.data.val[0], b.data.val[0]),
801
+ vfmaq_f32(c.data.val[1], a.data.val[1], b.data.val[1])}};
802
+ }
803
+
804
+ namespace {
805
+
806
+ // get even float32's of a and b, interleaved
807
+ simd8float32 geteven(const simd8float32& a, const simd8float32& b) {
808
+ return simd8float32{float32x4x2_t{
809
+ vuzp1q_f32(a.data.val[0], b.data.val[0]),
810
+ vuzp1q_f32(a.data.val[1], b.data.val[1])}};
811
+ }
812
+
813
+ // get odd float32's of a and b, interleaved
814
+ simd8float32 getodd(const simd8float32& a, const simd8float32& b) {
815
+ return simd8float32{float32x4x2_t{
816
+ vuzp2q_f32(a.data.val[0], b.data.val[0]),
817
+ vuzp2q_f32(a.data.val[1], b.data.val[1])}};
818
+ }
819
+
820
+ // 3 cycles
821
+ // if the lanes are a = [a0 a1] and b = [b0 b1], return [a0 b0]
822
+ simd8float32 getlow128(const simd8float32& a, const simd8float32& b) {
823
+ return simd8float32{float32x4x2_t{a.data.val[0], b.data.val[0]}};
824
+ }
825
+
826
+ simd8float32 gethigh128(const simd8float32& a, const simd8float32& b) {
827
+ return simd8float32{float32x4x2_t{a.data.val[1], b.data.val[1]}};
828
+ }
829
+
830
+ } // namespace
831
+
832
+ } // namespace faiss