faiss 0.3.0 → 0.3.2

Sign up to get free protection for your applications and to get access to all the features.
Files changed (216) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +9 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +1 -1
  5. data/ext/faiss/extconf.rb +9 -2
  6. data/ext/faiss/index.cpp +1 -1
  7. data/ext/faiss/index_binary.cpp +2 -2
  8. data/ext/faiss/product_quantizer.cpp +1 -1
  9. data/lib/faiss/version.rb +1 -1
  10. data/vendor/faiss/faiss/AutoTune.cpp +7 -7
  11. data/vendor/faiss/faiss/AutoTune.h +1 -2
  12. data/vendor/faiss/faiss/Clustering.cpp +39 -22
  13. data/vendor/faiss/faiss/Clustering.h +40 -21
  14. data/vendor/faiss/faiss/IVFlib.cpp +26 -12
  15. data/vendor/faiss/faiss/Index.cpp +1 -1
  16. data/vendor/faiss/faiss/Index.h +40 -10
  17. data/vendor/faiss/faiss/Index2Layer.cpp +7 -7
  18. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +176 -166
  19. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +15 -15
  20. data/vendor/faiss/faiss/IndexBinary.cpp +9 -4
  21. data/vendor/faiss/faiss/IndexBinary.h +8 -19
  22. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +2 -1
  23. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +24 -31
  24. data/vendor/faiss/faiss/IndexBinaryHNSW.h +1 -1
  25. data/vendor/faiss/faiss/IndexBinaryHash.cpp +25 -50
  26. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +107 -188
  27. data/vendor/faiss/faiss/IndexFastScan.cpp +95 -146
  28. data/vendor/faiss/faiss/IndexFastScan.h +9 -8
  29. data/vendor/faiss/faiss/IndexFlat.cpp +206 -10
  30. data/vendor/faiss/faiss/IndexFlat.h +20 -1
  31. data/vendor/faiss/faiss/IndexFlatCodes.cpp +170 -5
  32. data/vendor/faiss/faiss/IndexFlatCodes.h +23 -4
  33. data/vendor/faiss/faiss/IndexHNSW.cpp +231 -382
  34. data/vendor/faiss/faiss/IndexHNSW.h +62 -49
  35. data/vendor/faiss/faiss/IndexIDMap.cpp +69 -28
  36. data/vendor/faiss/faiss/IndexIDMap.h +24 -2
  37. data/vendor/faiss/faiss/IndexIVF.cpp +162 -56
  38. data/vendor/faiss/faiss/IndexIVF.h +46 -6
  39. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +33 -26
  40. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +6 -2
  41. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +19 -46
  42. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +4 -3
  43. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +502 -401
  44. data/vendor/faiss/faiss/IndexIVFFastScan.h +63 -26
  45. data/vendor/faiss/faiss/IndexIVFFlat.cpp +15 -5
  46. data/vendor/faiss/faiss/IndexIVFFlat.h +3 -2
  47. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +172 -0
  48. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.h +56 -0
  49. data/vendor/faiss/faiss/IndexIVFPQ.cpp +79 -125
  50. data/vendor/faiss/faiss/IndexIVFPQ.h +6 -7
  51. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +39 -52
  52. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +4 -3
  53. data/vendor/faiss/faiss/IndexIVFPQR.cpp +45 -29
  54. data/vendor/faiss/faiss/IndexIVFPQR.h +5 -2
  55. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +25 -27
  56. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +6 -6
  57. data/vendor/faiss/faiss/IndexLSH.cpp +14 -16
  58. data/vendor/faiss/faiss/IndexLattice.cpp +1 -19
  59. data/vendor/faiss/faiss/IndexLattice.h +3 -22
  60. data/vendor/faiss/faiss/IndexNNDescent.cpp +3 -33
  61. data/vendor/faiss/faiss/IndexNNDescent.h +1 -1
  62. data/vendor/faiss/faiss/IndexNSG.cpp +11 -27
  63. data/vendor/faiss/faiss/IndexNSG.h +11 -11
  64. data/vendor/faiss/faiss/IndexNeuralNetCodec.cpp +56 -0
  65. data/vendor/faiss/faiss/IndexNeuralNetCodec.h +49 -0
  66. data/vendor/faiss/faiss/IndexPQ.cpp +72 -88
  67. data/vendor/faiss/faiss/IndexPQ.h +1 -4
  68. data/vendor/faiss/faiss/IndexPQFastScan.cpp +1 -1
  69. data/vendor/faiss/faiss/IndexPreTransform.cpp +25 -31
  70. data/vendor/faiss/faiss/IndexPreTransform.h +1 -1
  71. data/vendor/faiss/faiss/IndexRefine.cpp +54 -24
  72. data/vendor/faiss/faiss/IndexRefine.h +7 -0
  73. data/vendor/faiss/faiss/IndexReplicas.cpp +23 -26
  74. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +25 -17
  75. data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -4
  76. data/vendor/faiss/faiss/IndexShards.cpp +21 -29
  77. data/vendor/faiss/faiss/IndexShardsIVF.cpp +1 -2
  78. data/vendor/faiss/faiss/MatrixStats.cpp +17 -32
  79. data/vendor/faiss/faiss/MatrixStats.h +21 -9
  80. data/vendor/faiss/faiss/MetaIndexes.cpp +35 -35
  81. data/vendor/faiss/faiss/MetricType.h +7 -2
  82. data/vendor/faiss/faiss/VectorTransform.cpp +13 -26
  83. data/vendor/faiss/faiss/VectorTransform.h +7 -7
  84. data/vendor/faiss/faiss/clone_index.cpp +15 -10
  85. data/vendor/faiss/faiss/clone_index.h +3 -0
  86. data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +95 -17
  87. data/vendor/faiss/faiss/cppcontrib/factory_tools.cpp +152 -0
  88. data/vendor/faiss/faiss/cppcontrib/factory_tools.h +24 -0
  89. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +83 -30
  90. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +123 -8
  91. data/vendor/faiss/faiss/gpu/GpuCloner.h +22 -0
  92. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +13 -0
  93. data/vendor/faiss/faiss/gpu/GpuDistance.h +46 -38
  94. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -1
  95. data/vendor/faiss/faiss/gpu/GpuIndex.h +30 -12
  96. data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +282 -0
  97. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +4 -4
  98. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +14 -9
  99. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +20 -3
  100. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -11
  101. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +1 -3
  102. data/vendor/faiss/faiss/gpu/GpuResources.cpp +24 -3
  103. data/vendor/faiss/faiss/gpu/GpuResources.h +39 -11
  104. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +142 -17
  105. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +57 -3
  106. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +26 -21
  107. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +7 -1
  108. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +8 -5
  109. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +25 -0
  110. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +129 -9
  111. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +332 -40
  112. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +299 -208
  113. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +1 -0
  114. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +1 -1
  115. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +6 -0
  116. data/vendor/faiss/faiss/gpu/utils/RaftUtils.h +75 -0
  117. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +4 -1
  118. data/vendor/faiss/faiss/gpu/utils/Timer.h +1 -1
  119. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +3 -1
  120. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +5 -5
  121. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +26 -1
  122. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +10 -3
  123. data/vendor/faiss/faiss/impl/DistanceComputer.h +70 -1
  124. data/vendor/faiss/faiss/impl/FaissAssert.h +4 -2
  125. data/vendor/faiss/faiss/impl/FaissException.h +13 -34
  126. data/vendor/faiss/faiss/impl/HNSW.cpp +605 -186
  127. data/vendor/faiss/faiss/impl/HNSW.h +52 -30
  128. data/vendor/faiss/faiss/impl/IDSelector.h +4 -4
  129. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +11 -9
  130. data/vendor/faiss/faiss/impl/LookupTableScaler.h +34 -0
  131. data/vendor/faiss/faiss/impl/NNDescent.cpp +42 -27
  132. data/vendor/faiss/faiss/impl/NSG.cpp +0 -29
  133. data/vendor/faiss/faiss/impl/NSG.h +1 -1
  134. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +14 -12
  135. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +1 -1
  136. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +25 -22
  137. data/vendor/faiss/faiss/impl/ProductQuantizer.h +6 -2
  138. data/vendor/faiss/faiss/impl/Quantizer.h +1 -1
  139. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +27 -1015
  140. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +5 -63
  141. data/vendor/faiss/faiss/impl/ResultHandler.h +347 -172
  142. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +1104 -147
  143. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +3 -8
  144. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +285 -42
  145. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx512.h +248 -0
  146. data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +21 -14
  147. data/vendor/faiss/faiss/impl/code_distance/code_distance.h +22 -12
  148. data/vendor/faiss/faiss/impl/index_read.cpp +74 -34
  149. data/vendor/faiss/faiss/impl/index_read_utils.h +37 -0
  150. data/vendor/faiss/faiss/impl/index_write.cpp +88 -51
  151. data/vendor/faiss/faiss/impl/io.cpp +23 -15
  152. data/vendor/faiss/faiss/impl/io.h +4 -4
  153. data/vendor/faiss/faiss/impl/io_macros.h +6 -0
  154. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -1
  155. data/vendor/faiss/faiss/impl/platform_macros.h +40 -1
  156. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +14 -0
  157. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +7 -6
  158. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +52 -38
  159. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +487 -49
  160. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +960 -0
  161. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +176 -0
  162. data/vendor/faiss/faiss/impl/simd_result_handlers.h +481 -225
  163. data/vendor/faiss/faiss/index_factory.cpp +41 -20
  164. data/vendor/faiss/faiss/index_io.h +12 -5
  165. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +28 -8
  166. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +3 -0
  167. data/vendor/faiss/faiss/invlists/DirectMap.cpp +10 -2
  168. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +73 -17
  169. data/vendor/faiss/faiss/invlists/InvertedLists.h +26 -8
  170. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +24 -9
  171. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +2 -1
  172. data/vendor/faiss/faiss/python/python_callbacks.cpp +4 -4
  173. data/vendor/faiss/faiss/utils/Heap.cpp +3 -1
  174. data/vendor/faiss/faiss/utils/Heap.h +105 -0
  175. data/vendor/faiss/faiss/utils/NeuralNet.cpp +342 -0
  176. data/vendor/faiss/faiss/utils/NeuralNet.h +147 -0
  177. data/vendor/faiss/faiss/utils/WorkerThread.h +1 -0
  178. data/vendor/faiss/faiss/utils/bf16.h +36 -0
  179. data/vendor/faiss/faiss/utils/distances.cpp +147 -123
  180. data/vendor/faiss/faiss/utils/distances.h +86 -9
  181. data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +5 -5
  182. data/vendor/faiss/faiss/utils/distances_fused/avx512.h +2 -2
  183. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +2 -2
  184. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +1 -1
  185. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +5 -5
  186. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.h +1 -1
  187. data/vendor/faiss/faiss/utils/distances_simd.cpp +1589 -243
  188. data/vendor/faiss/faiss/utils/extra_distances-inl.h +70 -0
  189. data/vendor/faiss/faiss/utils/extra_distances.cpp +85 -137
  190. data/vendor/faiss/faiss/utils/extra_distances.h +3 -2
  191. data/vendor/faiss/faiss/utils/fp16-arm.h +29 -0
  192. data/vendor/faiss/faiss/utils/fp16.h +2 -0
  193. data/vendor/faiss/faiss/utils/hamming.cpp +163 -111
  194. data/vendor/faiss/faiss/utils/hamming.h +58 -0
  195. data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +16 -89
  196. data/vendor/faiss/faiss/utils/hamming_distance/common.h +1 -0
  197. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +19 -88
  198. data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +58 -0
  199. data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +14 -104
  200. data/vendor/faiss/faiss/utils/partitioning.cpp +3 -4
  201. data/vendor/faiss/faiss/utils/prefetch.h +77 -0
  202. data/vendor/faiss/faiss/utils/quantize_lut.cpp +0 -14
  203. data/vendor/faiss/faiss/utils/random.cpp +43 -0
  204. data/vendor/faiss/faiss/utils/random.h +25 -0
  205. data/vendor/faiss/faiss/utils/simdlib.h +10 -1
  206. data/vendor/faiss/faiss/utils/simdlib_avx2.h +0 -6
  207. data/vendor/faiss/faiss/utils/simdlib_avx512.h +296 -0
  208. data/vendor/faiss/faiss/utils/simdlib_neon.h +77 -79
  209. data/vendor/faiss/faiss/utils/simdlib_ppc64.h +1084 -0
  210. data/vendor/faiss/faiss/utils/sorting.cpp +140 -5
  211. data/vendor/faiss/faiss/utils/sorting.h +27 -0
  212. data/vendor/faiss/faiss/utils/transpose/transpose-avx512-inl.h +176 -0
  213. data/vendor/faiss/faiss/utils/utils.cpp +120 -7
  214. data/vendor/faiss/faiss/utils/utils.h +60 -20
  215. metadata +23 -4
  216. data/vendor/faiss/faiss/impl/code_distance/code_distance_avx512.h +0 -102
@@ -0,0 +1,1084 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
#pragma once

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <string>
14
+
15
+ namespace faiss {
16
+
17
+ struct simd256bit {
18
+ union {
19
+ uint8_t u8[32];
20
+ uint16_t u16[16];
21
+ uint32_t u32[8];
22
+ float f32[8];
23
+ };
24
+
25
+ simd256bit() {}
26
+
27
+ explicit simd256bit(const void* x) {
28
+ memcpy(u8, x, 32);
29
+ }
30
+
31
+ void clear() {
32
+ memset(u8, 0, 32);
33
+ }
34
+
35
+ void storeu(void* ptr) const {
36
+ memcpy(ptr, u8, 32);
37
+ }
38
+
39
+ void loadu(const void* ptr) {
40
+ memcpy(u8, ptr, 32);
41
+ }
42
+
43
+ void store(void* ptr) const {
44
+ storeu(ptr);
45
+ }
46
+
47
+ void bin(char bits[257]) const {
48
+ const char* bytes = (char*)this->u8;
49
+ for (int i = 0; i < 256; i++) {
50
+ bits[i] = '0' + ((bytes[i / 8] >> (i % 8)) & 1);
51
+ }
52
+ bits[256] = 0;
53
+ }
54
+
55
+ std::string bin() const {
56
+ char bits[257];
57
+ bin(bits);
58
+ return std::string(bits);
59
+ }
60
+
61
+ // Checks whether the other holds exactly the same bytes.
62
+ bool is_same_as(simd256bit other) const {
63
+ for (size_t i = 0; i < 8; i++) {
64
+ if (u32[i] != other.u32[i]) {
65
+ return false;
66
+ }
67
+ }
68
+
69
+ return true;
70
+ }
71
+ };
72
+
73
+ /// vector of 16 elements in uint16
74
+ struct simd16uint16 : simd256bit {
75
+ simd16uint16() {}
76
+
77
+ explicit simd16uint16(int x) {
78
+ set1(x);
79
+ }
80
+
81
+ explicit simd16uint16(uint16_t x) {
82
+ set1(x);
83
+ }
84
+
85
+ explicit simd16uint16(const simd256bit& x) : simd256bit(x) {}
86
+
87
+ explicit simd16uint16(const uint16_t* x) : simd256bit((const void*)x) {}
88
+
89
+ explicit simd16uint16(
90
+ uint16_t u0,
91
+ uint16_t u1,
92
+ uint16_t u2,
93
+ uint16_t u3,
94
+ uint16_t u4,
95
+ uint16_t u5,
96
+ uint16_t u6,
97
+ uint16_t u7,
98
+ uint16_t u8,
99
+ uint16_t u9,
100
+ uint16_t u10,
101
+ uint16_t u11,
102
+ uint16_t u12,
103
+ uint16_t u13,
104
+ uint16_t u14,
105
+ uint16_t u15) {
106
+ this->u16[0] = u0;
107
+ this->u16[1] = u1;
108
+ this->u16[2] = u2;
109
+ this->u16[3] = u3;
110
+ this->u16[4] = u4;
111
+ this->u16[5] = u5;
112
+ this->u16[6] = u6;
113
+ this->u16[7] = u7;
114
+ this->u16[8] = u8;
115
+ this->u16[9] = u9;
116
+ this->u16[10] = u10;
117
+ this->u16[11] = u11;
118
+ this->u16[12] = u12;
119
+ this->u16[13] = u13;
120
+ this->u16[14] = u14;
121
+ this->u16[15] = u15;
122
+ }
123
+
124
+ std::string elements_to_string(const char* fmt) const {
125
+ char res[1000], *ptr = res;
126
+ for (int i = 0; i < 16; i++) {
127
+ ptr += sprintf(ptr, fmt, u16[i]);
128
+ }
129
+ // strip last ,
130
+ ptr[-1] = 0;
131
+ return std::string(res);
132
+ }
133
+
134
+ std::string hex() const {
135
+ return elements_to_string("%02x,");
136
+ }
137
+
138
+ std::string dec() const {
139
+ return elements_to_string("%3d,");
140
+ }
141
+
142
+ template <typename F>
143
+ static simd16uint16 unary_func(const simd16uint16& a, F&& f) {
144
+ simd16uint16 c;
145
+ for (int j = 0; j < 16; j++) {
146
+ c.u16[j] = f(a.u16[j]);
147
+ }
148
+ return c;
149
+ }
150
+
151
+ template <typename F>
152
+ static simd16uint16 binary_func(
153
+ const simd16uint16& a,
154
+ const simd16uint16& b,
155
+ F&& f) {
156
+ simd16uint16 c;
157
+ for (int j = 0; j < 16; j++) {
158
+ c.u16[j] = f(a.u16[j], b.u16[j]);
159
+ }
160
+ return c;
161
+ }
162
+
163
+ void set1(uint16_t x) {
164
+ for (int i = 0; i < 16; i++) {
165
+ u16[i] = x;
166
+ }
167
+ }
168
+
169
+ simd16uint16 operator*(const simd16uint16& other) const {
170
+ return binary_func(
171
+ *this, other, [](uint16_t a, uint16_t b) { return a * b; });
172
+ }
173
+
174
+ // shift must be known at compile time
175
+ simd16uint16 operator>>(const int shift) const {
176
+ return unary_func(*this, [shift](uint16_t a) { return a >> shift; });
177
+ }
178
+
179
+ // shift must be known at compile time
180
+ simd16uint16 operator<<(const int shift) const {
181
+ return unary_func(*this, [shift](uint16_t a) { return a << shift; });
182
+ }
183
+
184
+ simd16uint16 operator+=(const simd16uint16& other) {
185
+ *this = *this + other;
186
+ return *this;
187
+ }
188
+
189
+ simd16uint16 operator-=(const simd16uint16& other) {
190
+ *this = *this - other;
191
+ return *this;
192
+ }
193
+
194
+ simd16uint16 operator+(const simd16uint16& other) const {
195
+ return binary_func(
196
+ *this, other, [](uint16_t a, uint16_t b) { return a + b; });
197
+ }
198
+
199
+ simd16uint16 operator-(const simd16uint16& other) const {
200
+ return binary_func(
201
+ *this, other, [](uint16_t a, uint16_t b) { return a - b; });
202
+ }
203
+
204
+ simd16uint16 operator&(const simd256bit& other) const {
205
+ return binary_func(
206
+ *this, simd16uint16(other), [](uint16_t a, uint16_t b) {
207
+ return a & b;
208
+ });
209
+ }
210
+
211
+ simd16uint16 operator|(const simd256bit& other) const {
212
+ return binary_func(
213
+ *this, simd16uint16(other), [](uint16_t a, uint16_t b) {
214
+ return a | b;
215
+ });
216
+ }
217
+
218
+ simd16uint16 operator^(const simd256bit& other) const {
219
+ return binary_func(
220
+ *this, simd16uint16(other), [](uint16_t a, uint16_t b) {
221
+ return a ^ b;
222
+ });
223
+ }
224
+
225
+ // returns binary masks
226
+ simd16uint16 operator==(const simd16uint16& other) const {
227
+ return binary_func(*this, other, [](uint16_t a, uint16_t b) {
228
+ return a == b ? 0xffff : 0;
229
+ });
230
+ }
231
+
232
+ simd16uint16 operator~() const {
233
+ return unary_func(*this, [](uint16_t a) { return ~a; });
234
+ }
235
+
236
+ // get scalar at index 0
237
+ uint16_t get_scalar_0() const {
238
+ return u16[0];
239
+ }
240
+
241
+ // mask of elements where this >= thresh
242
+ // 2 bit per component: 16 * 2 = 32 bit
243
+ uint32_t ge_mask(const simd16uint16& thresh) const {
244
+ uint32_t gem = 0;
245
+ for (int j = 0; j < 16; j++) {
246
+ if (u16[j] >= thresh.u16[j]) {
247
+ gem |= 3 << (j * 2);
248
+ }
249
+ }
250
+ return gem;
251
+ }
252
+
253
+ uint32_t le_mask(const simd16uint16& thresh) const {
254
+ return thresh.ge_mask(*this);
255
+ }
256
+
257
+ uint32_t gt_mask(const simd16uint16& thresh) const {
258
+ return ~le_mask(thresh);
259
+ }
260
+
261
+ bool all_gt(const simd16uint16& thresh) const {
262
+ return le_mask(thresh) == 0;
263
+ }
264
+
265
+ // for debugging only
266
+ uint16_t operator[](int i) const {
267
+ return u16[i];
268
+ }
269
+
270
+ void accu_min(const simd16uint16& incoming) {
271
+ for (int j = 0; j < 16; j++) {
272
+ if (incoming.u16[j] < u16[j]) {
273
+ u16[j] = incoming.u16[j];
274
+ }
275
+ }
276
+ }
277
+
278
+ void accu_max(const simd16uint16& incoming) {
279
+ for (int j = 0; j < 16; j++) {
280
+ if (incoming.u16[j] > u16[j]) {
281
+ u16[j] = incoming.u16[j];
282
+ }
283
+ }
284
+ }
285
+ };
286
+
287
+ // not really a std::min because it returns an elementwise min
288
+ inline simd16uint16 min(const simd16uint16& av, const simd16uint16& bv) {
289
+ return simd16uint16::binary_func(
290
+ av, bv, [](uint16_t a, uint16_t b) { return std::min(a, b); });
291
+ }
292
+
293
+ inline simd16uint16 max(const simd16uint16& av, const simd16uint16& bv) {
294
+ return simd16uint16::binary_func(
295
+ av, bv, [](uint16_t a, uint16_t b) { return std::max(a, b); });
296
+ }
297
+
298
+ // decompose in 128-lanes: a = (a0, a1), b = (b0, b1)
299
+ // return (a0 + a1, b0 + b1)
300
+ // TODO find a better name
301
+ inline simd16uint16 combine2x2(const simd16uint16& a, const simd16uint16& b) {
302
+ simd16uint16 c;
303
+ for (int j = 0; j < 8; j++) {
304
+ c.u16[j] = a.u16[j] + a.u16[j + 8];
305
+ c.u16[j + 8] = b.u16[j] + b.u16[j + 8];
306
+ }
307
+ return c;
308
+ }
309
+
310
+ // compare d0 and d1 to thr, return 32 bits corresponding to the concatenation
311
+ // of d0 and d1 with thr
312
+ inline uint32_t cmp_ge32(
313
+ const simd16uint16& d0,
314
+ const simd16uint16& d1,
315
+ const simd16uint16& thr) {
316
+ uint32_t gem = 0;
317
+ for (int j = 0; j < 16; j++) {
318
+ if (d0.u16[j] >= thr.u16[j]) {
319
+ gem |= 1 << j;
320
+ }
321
+ if (d1.u16[j] >= thr.u16[j]) {
322
+ gem |= 1 << (j + 16);
323
+ }
324
+ }
325
+ return gem;
326
+ }
327
+
328
+ inline uint32_t cmp_le32(
329
+ const simd16uint16& d0,
330
+ const simd16uint16& d1,
331
+ const simd16uint16& thr) {
332
+ uint32_t gem = 0;
333
+ for (int j = 0; j < 16; j++) {
334
+ if (d0.u16[j] <= thr.u16[j]) {
335
+ gem |= 1 << j;
336
+ }
337
+ if (d1.u16[j] <= thr.u16[j]) {
338
+ gem |= 1 << (j + 16);
339
+ }
340
+ }
341
+ return gem;
342
+ }
343
+
344
+ // hadd does not cross lanes
345
+ inline simd16uint16 hadd(const simd16uint16& a, const simd16uint16& b) {
346
+ simd16uint16 c;
347
+ c.u16[0] = a.u16[0] + a.u16[1];
348
+ c.u16[1] = a.u16[2] + a.u16[3];
349
+ c.u16[2] = a.u16[4] + a.u16[5];
350
+ c.u16[3] = a.u16[6] + a.u16[7];
351
+ c.u16[4] = b.u16[0] + b.u16[1];
352
+ c.u16[5] = b.u16[2] + b.u16[3];
353
+ c.u16[6] = b.u16[4] + b.u16[5];
354
+ c.u16[7] = b.u16[6] + b.u16[7];
355
+
356
+ c.u16[8] = a.u16[8] + a.u16[9];
357
+ c.u16[9] = a.u16[10] + a.u16[11];
358
+ c.u16[10] = a.u16[12] + a.u16[13];
359
+ c.u16[11] = a.u16[14] + a.u16[15];
360
+ c.u16[12] = b.u16[8] + b.u16[9];
361
+ c.u16[13] = b.u16[10] + b.u16[11];
362
+ c.u16[14] = b.u16[12] + b.u16[13];
363
+ c.u16[15] = b.u16[14] + b.u16[15];
364
+
365
+ return c;
366
+ }
367
+
368
+ // Vectorized version of the following code:
369
+ // for (size_t i = 0; i < n; i++) {
370
+ // bool flag = (candidateValues[i] < currentValues[i]);
371
+ // minValues[i] = flag ? candidateValues[i] : currentValues[i];
372
+ // minIndices[i] = flag ? candidateIndices[i] : currentIndices[i];
373
+ // maxValues[i] = !flag ? candidateValues[i] : currentValues[i];
374
+ // maxIndices[i] = !flag ? candidateIndices[i] : currentIndices[i];
375
+ // }
376
+ // Max indices evaluation is inaccurate in case of equal values (the index of
377
+ // the last equal value is saved instead of the first one), but this behavior
378
+ // saves instructions.
379
+ inline void cmplt_min_max_fast(
380
+ const simd16uint16 candidateValues,
381
+ const simd16uint16 candidateIndices,
382
+ const simd16uint16 currentValues,
383
+ const simd16uint16 currentIndices,
384
+ simd16uint16& minValues,
385
+ simd16uint16& minIndices,
386
+ simd16uint16& maxValues,
387
+ simd16uint16& maxIndices) {
388
+ for (size_t i = 0; i < 16; i++) {
389
+ bool flag = (candidateValues.u16[i] < currentValues.u16[i]);
390
+ minValues.u16[i] = flag ? candidateValues.u16[i] : currentValues.u16[i];
391
+ minIndices.u16[i] =
392
+ flag ? candidateIndices.u16[i] : currentIndices.u16[i];
393
+ maxValues.u16[i] =
394
+ !flag ? candidateValues.u16[i] : currentValues.u16[i];
395
+ maxIndices.u16[i] =
396
+ !flag ? candidateIndices.u16[i] : currentIndices.u16[i];
397
+ }
398
+ }
399
+
400
+ // vector of 32 unsigned 8-bit integers
401
+ struct simd32uint8 : simd256bit {
402
+ simd32uint8() {}
403
+
404
+ explicit simd32uint8(int x) {
405
+ set1(x);
406
+ }
407
+
408
+ explicit simd32uint8(uint8_t x) {
409
+ set1(x);
410
+ }
411
+ template <
412
+ uint8_t _0,
413
+ uint8_t _1,
414
+ uint8_t _2,
415
+ uint8_t _3,
416
+ uint8_t _4,
417
+ uint8_t _5,
418
+ uint8_t _6,
419
+ uint8_t _7,
420
+ uint8_t _8,
421
+ uint8_t _9,
422
+ uint8_t _10,
423
+ uint8_t _11,
424
+ uint8_t _12,
425
+ uint8_t _13,
426
+ uint8_t _14,
427
+ uint8_t _15,
428
+ uint8_t _16,
429
+ uint8_t _17,
430
+ uint8_t _18,
431
+ uint8_t _19,
432
+ uint8_t _20,
433
+ uint8_t _21,
434
+ uint8_t _22,
435
+ uint8_t _23,
436
+ uint8_t _24,
437
+ uint8_t _25,
438
+ uint8_t _26,
439
+ uint8_t _27,
440
+ uint8_t _28,
441
+ uint8_t _29,
442
+ uint8_t _30,
443
+ uint8_t _31>
444
+ static simd32uint8 create() {
445
+ simd32uint8 ret;
446
+ ret.u8[0] = _0;
447
+ ret.u8[1] = _1;
448
+ ret.u8[2] = _2;
449
+ ret.u8[3] = _3;
450
+ ret.u8[4] = _4;
451
+ ret.u8[5] = _5;
452
+ ret.u8[6] = _6;
453
+ ret.u8[7] = _7;
454
+ ret.u8[8] = _8;
455
+ ret.u8[9] = _9;
456
+ ret.u8[10] = _10;
457
+ ret.u8[11] = _11;
458
+ ret.u8[12] = _12;
459
+ ret.u8[13] = _13;
460
+ ret.u8[14] = _14;
461
+ ret.u8[15] = _15;
462
+ ret.u8[16] = _16;
463
+ ret.u8[17] = _17;
464
+ ret.u8[18] = _18;
465
+ ret.u8[19] = _19;
466
+ ret.u8[20] = _20;
467
+ ret.u8[21] = _21;
468
+ ret.u8[22] = _22;
469
+ ret.u8[23] = _23;
470
+ ret.u8[24] = _24;
471
+ ret.u8[25] = _25;
472
+ ret.u8[26] = _26;
473
+ ret.u8[27] = _27;
474
+ ret.u8[28] = _28;
475
+ ret.u8[29] = _29;
476
+ ret.u8[30] = _30;
477
+ ret.u8[31] = _31;
478
+ return ret;
479
+ }
480
+
481
+ explicit simd32uint8(const simd256bit& x) : simd256bit(x) {}
482
+
483
+ explicit simd32uint8(const uint8_t* x) : simd256bit((const void*)x) {}
484
+
485
+ std::string elements_to_string(const char* fmt) const {
486
+ char res[1000], *ptr = res;
487
+ for (int i = 0; i < 32; i++) {
488
+ ptr += sprintf(ptr, fmt, u8[i]);
489
+ }
490
+ // strip last ,
491
+ ptr[-1] = 0;
492
+ return std::string(res);
493
+ }
494
+
495
+ std::string hex() const {
496
+ return elements_to_string("%02x,");
497
+ }
498
+
499
+ std::string dec() const {
500
+ return elements_to_string("%3d,");
501
+ }
502
+
503
+ void set1(uint8_t x) {
504
+ for (int j = 0; j < 32; j++) {
505
+ u8[j] = x;
506
+ }
507
+ }
508
+
509
+ template <typename F>
510
+ static simd32uint8 binary_func(
511
+ const simd32uint8& a,
512
+ const simd32uint8& b,
513
+ F&& f) {
514
+ simd32uint8 c;
515
+ for (int j = 0; j < 32; j++) {
516
+ c.u8[j] = f(a.u8[j], b.u8[j]);
517
+ }
518
+ return c;
519
+ }
520
+
521
+ simd32uint8 operator&(const simd256bit& other) const {
522
+ return binary_func(*this, simd32uint8(other), [](uint8_t a, uint8_t b) {
523
+ return a & b;
524
+ });
525
+ }
526
+
527
+ simd32uint8 operator+(const simd32uint8& other) const {
528
+ return binary_func(
529
+ *this, other, [](uint8_t a, uint8_t b) { return a + b; });
530
+ }
531
+
532
+ // The very important operation that everything relies on
533
+ simd32uint8 lookup_2_lanes(const simd32uint8& idx) const {
534
+ simd32uint8 c;
535
+ // The original for loop:
536
+ // for (int j = 0; j < 32; j++) {
537
+ // if (idx.u8[j] & 0x80) {
538
+ // c.u8[j] = 0;
539
+ // } else {
540
+ // uint8_t i = idx.u8[j] & 15;
541
+ // if (j < 16) {
542
+ // c.u8[j] = u8[i];
543
+ // } else {
544
+ // c.u8[j] = u8[16 + i];
545
+ // }
546
+ // }
547
+
548
+ // The following function was re-written for Power 10
549
+ // The loop was unrolled to remove the if (j < 16) statement by doing
550
+ // the j and j + 16 iterations in parallel. The additional unrolling
551
+ // for j + 1 and j + 17, reduces the execution time on Power 10 by
552
+ // about 50% as the instruction scheduling allows on average 2X more
553
+ // instructions to be issued per cycle.
554
+
555
+ for (int j = 0; j < 16; j = j + 2) {
556
+ // j < 16, unrolled to depth of 2
557
+ if (idx.u8[j] & 0x80) {
558
+ c.u8[j] = 0;
559
+ } else {
560
+ uint8_t i = idx.u8[j] & 15;
561
+ c.u8[j] = u8[i];
562
+ }
563
+
564
+ if (idx.u8[j + 1] & 0x80) {
565
+ c.u8[j + 1] = 0;
566
+ } else {
567
+ uint8_t i = idx.u8[j + 1] & 15;
568
+ c.u8[j + 1] = u8[i];
569
+ }
570
+
571
+ // j >= 16, unrolled to depth of 2
572
+ if (idx.u8[j + 16] & 0x80) {
573
+ c.u8[j + 16] = 0;
574
+ } else {
575
+ uint8_t i = idx.u8[j + 16] & 15;
576
+ c.u8[j + 16] = u8[i + 16];
577
+ }
578
+
579
+ if (idx.u8[j + 17] & 0x80) {
580
+ c.u8[j + 17] = 0;
581
+ } else {
582
+ uint8_t i = idx.u8[j + 17] & 15;
583
+ c.u8[j + 17] = u8[i + 16];
584
+ }
585
+ }
586
+ return c;
587
+ }
588
+
589
+ // extract + 0-extend lane
590
+ // this operation is slow (3 cycles)
591
+
592
+ simd32uint8 operator+=(const simd32uint8& other) {
593
+ *this = *this + other;
594
+ return *this;
595
+ }
596
+
597
+ // for debugging only
598
+ uint8_t operator[](int i) const {
599
+ return u8[i];
600
+ }
601
+ };
602
+
603
+ // convert with saturation
604
+ // careful: this does not cross lanes, so the order is weird
605
+ inline simd32uint8 uint16_to_uint8_saturate(
606
+ const simd16uint16& a,
607
+ const simd16uint16& b) {
608
+ simd32uint8 c;
609
+
610
+ auto saturate_16_to_8 = [](uint16_t x) { return x >= 256 ? 0xff : x; };
611
+
612
+ for (int i = 0; i < 8; i++) {
613
+ c.u8[i] = saturate_16_to_8(a.u16[i]);
614
+ c.u8[8 + i] = saturate_16_to_8(b.u16[i]);
615
+ c.u8[16 + i] = saturate_16_to_8(a.u16[8 + i]);
616
+ c.u8[24 + i] = saturate_16_to_8(b.u16[8 + i]);
617
+ }
618
+ return c;
619
+ }
620
+
621
+ /// get most significant bit of each byte
622
+ inline uint32_t get_MSBs(const simd32uint8& a) {
623
+ uint32_t res = 0;
624
+ for (int i = 0; i < 32; i++) {
625
+ if (a.u8[i] & 0x80) {
626
+ res |= 1 << i;
627
+ }
628
+ }
629
+ return res;
630
+ }
631
+
632
+ /// use MSB of each byte of mask to select a byte between a and b
633
+ inline simd32uint8 blendv(
634
+ const simd32uint8& a,
635
+ const simd32uint8& b,
636
+ const simd32uint8& mask) {
637
+ simd32uint8 c;
638
+ for (int i = 0; i < 32; i++) {
639
+ if (mask.u8[i] & 0x80) {
640
+ c.u8[i] = b.u8[i];
641
+ } else {
642
+ c.u8[i] = a.u8[i];
643
+ }
644
+ }
645
+ return c;
646
+ }
647
+
648
+ /// vector of 8 unsigned 32-bit integers
649
+ struct simd8uint32 : simd256bit {
650
+ simd8uint32() {}
651
+
652
+ explicit simd8uint32(uint32_t x) {
653
+ set1(x);
654
+ }
655
+
656
+ explicit simd8uint32(const simd256bit& x) : simd256bit(x) {}
657
+
658
+ explicit simd8uint32(const uint32_t* x) : simd256bit((const void*)x) {}
659
+
660
+ explicit simd8uint32(
661
+ uint32_t u0,
662
+ uint32_t u1,
663
+ uint32_t u2,
664
+ uint32_t u3,
665
+ uint32_t u4,
666
+ uint32_t u5,
667
+ uint32_t u6,
668
+ uint32_t u7) {
669
+ u32[0] = u0;
670
+ u32[1] = u1;
671
+ u32[2] = u2;
672
+ u32[3] = u3;
673
+ u32[4] = u4;
674
+ u32[5] = u5;
675
+ u32[6] = u6;
676
+ u32[7] = u7;
677
+ }
678
+
679
+ simd8uint32 operator+(simd8uint32 other) const {
680
+ simd8uint32 result;
681
+ for (int i = 0; i < 8; i++) {
682
+ result.u32[i] = u32[i] + other.u32[i];
683
+ }
684
+ return result;
685
+ }
686
+
687
+ simd8uint32 operator-(simd8uint32 other) const {
688
+ simd8uint32 result;
689
+ for (int i = 0; i < 8; i++) {
690
+ result.u32[i] = u32[i] - other.u32[i];
691
+ }
692
+ return result;
693
+ }
694
+
695
+ simd8uint32& operator+=(const simd8uint32& other) {
696
+ for (int i = 0; i < 8; i++) {
697
+ u32[i] += other.u32[i];
698
+ }
699
+ return *this;
700
+ }
701
+
702
+ bool operator==(simd8uint32 other) const {
703
+ for (size_t i = 0; i < 8; i++) {
704
+ if (u32[i] != other.u32[i]) {
705
+ return false;
706
+ }
707
+ }
708
+
709
+ return true;
710
+ }
711
+
712
+ bool operator!=(simd8uint32 other) const {
713
+ return !(*this == other);
714
+ }
715
+
716
+ std::string elements_to_string(const char* fmt) const {
717
+ char res[1000], *ptr = res;
718
+ for (int i = 0; i < 8; i++) {
719
+ ptr += sprintf(ptr, fmt, u32[i]);
720
+ }
721
+ // strip last ,
722
+ ptr[-1] = 0;
723
+ return std::string(res);
724
+ }
725
+
726
+ std::string hex() const {
727
+ return elements_to_string("%08x,");
728
+ }
729
+
730
+ std::string dec() const {
731
+ return elements_to_string("%10d,");
732
+ }
733
+
734
+ void set1(uint32_t x) {
735
+ for (int i = 0; i < 8; i++) {
736
+ u32[i] = x;
737
+ }
738
+ }
739
+
740
+ simd8uint32 unzip() const {
741
+ const uint32_t ret[] = {
742
+ u32[0], u32[2], u32[4], u32[6], u32[1], u32[3], u32[5], u32[7]};
743
+ return simd8uint32{ret};
744
+ }
745
+ };
746
+
747
+ // Vectorized version of the following code:
748
+ // for (size_t i = 0; i < n; i++) {
749
+ // bool flag = (candidateValues[i] < currentValues[i]);
750
+ // minValues[i] = flag ? candidateValues[i] : currentValues[i];
751
+ // minIndices[i] = flag ? candidateIndices[i] : currentIndices[i];
752
+ // maxValues[i] = !flag ? candidateValues[i] : currentValues[i];
753
+ // maxIndices[i] = !flag ? candidateIndices[i] : currentIndices[i];
754
+ // }
755
+ // Max indices evaluation is inaccurate in case of equal values (the index of
756
+ // the last equal value is saved instead of the first one), but this behavior
757
+ // saves instructions.
758
+ inline void cmplt_min_max_fast(
759
+ const simd8uint32 candidateValues,
760
+ const simd8uint32 candidateIndices,
761
+ const simd8uint32 currentValues,
762
+ const simd8uint32 currentIndices,
763
+ simd8uint32& minValues,
764
+ simd8uint32& minIndices,
765
+ simd8uint32& maxValues,
766
+ simd8uint32& maxIndices) {
767
+ for (size_t i = 0; i < 8; i++) {
768
+ bool flag = (candidateValues.u32[i] < currentValues.u32[i]);
769
+ minValues.u32[i] = flag ? candidateValues.u32[i] : currentValues.u32[i];
770
+ minIndices.u32[i] =
771
+ flag ? candidateIndices.u32[i] : currentIndices.u32[i];
772
+ maxValues.u32[i] =
773
+ !flag ? candidateValues.u32[i] : currentValues.u32[i];
774
+ maxIndices.u32[i] =
775
+ !flag ? candidateIndices.u32[i] : currentIndices.u32[i];
776
+ }
777
+ }
778
+
779
+ struct simd8float32 : simd256bit {
780
+ simd8float32() {}
781
+
782
+ explicit simd8float32(const simd256bit& x) : simd256bit(x) {}
783
+
784
+ explicit simd8float32(float x) {
785
+ set1(x);
786
+ }
787
+
788
+ explicit simd8float32(const float* x) {
789
+ loadu((void*)x);
790
+ }
791
+
792
+ void set1(float x) {
793
+ for (int i = 0; i < 8; i++) {
794
+ f32[i] = x;
795
+ }
796
+ }
797
+
798
+ explicit simd8float32(
799
+ float f0,
800
+ float f1,
801
+ float f2,
802
+ float f3,
803
+ float f4,
804
+ float f5,
805
+ float f6,
806
+ float f7) {
807
+ f32[0] = f0;
808
+ f32[1] = f1;
809
+ f32[2] = f2;
810
+ f32[3] = f3;
811
+ f32[4] = f4;
812
+ f32[5] = f5;
813
+ f32[6] = f6;
814
+ f32[7] = f7;
815
+ }
816
+
817
+ template <typename F>
818
+ static simd8float32 binary_func(
819
+ const simd8float32& a,
820
+ const simd8float32& b,
821
+ F&& f) {
822
+ simd8float32 c;
823
+ for (int j = 0; j < 8; j++) {
824
+ c.f32[j] = f(a.f32[j], b.f32[j]);
825
+ }
826
+ return c;
827
+ }
828
+
829
+ simd8float32 operator*(const simd8float32& other) const {
830
+ return binary_func(
831
+ *this, other, [](float a, float b) { return a * b; });
832
+ }
833
+
834
+ simd8float32 operator+(const simd8float32& other) const {
835
+ return binary_func(
836
+ *this, other, [](float a, float b) { return a + b; });
837
+ }
838
+
839
+ simd8float32 operator-(const simd8float32& other) const {
840
+ return binary_func(
841
+ *this, other, [](float a, float b) { return a - b; });
842
+ }
843
+
844
+ simd8float32& operator+=(const simd8float32& other) {
845
+ for (size_t i = 0; i < 8; i++) {
846
+ f32[i] += other.f32[i];
847
+ }
848
+
849
+ return *this;
850
+ }
851
+
852
+ bool operator==(simd8float32 other) const {
853
+ for (size_t i = 0; i < 8; i++) {
854
+ if (f32[i] != other.f32[i]) {
855
+ return false;
856
+ }
857
+ }
858
+
859
+ return true;
860
+ }
861
+
862
+ bool operator!=(simd8float32 other) const {
863
+ return !(*this == other);
864
+ }
865
+
866
+ std::string tostring() const {
867
+ char res[1000], *ptr = res;
868
+ for (int i = 0; i < 8; i++) {
869
+ ptr += sprintf(ptr, "%g,", f32[i]);
870
+ }
871
+ // strip last ,
872
+ ptr[-1] = 0;
873
+ return std::string(res);
874
+ }
875
+ };
876
+
877
+ // hadd does not cross lanes
878
+ inline simd8float32 hadd(const simd8float32& a, const simd8float32& b) {
879
+ simd8float32 c;
880
+ c.f32[0] = a.f32[0] + a.f32[1];
881
+ c.f32[1] = a.f32[2] + a.f32[3];
882
+ c.f32[2] = b.f32[0] + b.f32[1];
883
+ c.f32[3] = b.f32[2] + b.f32[3];
884
+
885
+ c.f32[4] = a.f32[4] + a.f32[5];
886
+ c.f32[5] = a.f32[6] + a.f32[7];
887
+ c.f32[6] = b.f32[4] + b.f32[5];
888
+ c.f32[7] = b.f32[6] + b.f32[7];
889
+
890
+ return c;
891
+ }
892
+
893
+ inline simd8float32 unpacklo(const simd8float32& a, const simd8float32& b) {
894
+ simd8float32 c;
895
+ c.f32[0] = a.f32[0];
896
+ c.f32[1] = b.f32[0];
897
+ c.f32[2] = a.f32[1];
898
+ c.f32[3] = b.f32[1];
899
+
900
+ c.f32[4] = a.f32[4];
901
+ c.f32[5] = b.f32[4];
902
+ c.f32[6] = a.f32[5];
903
+ c.f32[7] = b.f32[5];
904
+
905
+ return c;
906
+ }
907
+
908
+ inline simd8float32 unpackhi(const simd8float32& a, const simd8float32& b) {
909
+ simd8float32 c;
910
+ c.f32[0] = a.f32[2];
911
+ c.f32[1] = b.f32[2];
912
+ c.f32[2] = a.f32[3];
913
+ c.f32[3] = b.f32[3];
914
+
915
+ c.f32[4] = a.f32[6];
916
+ c.f32[5] = b.f32[6];
917
+ c.f32[6] = a.f32[7];
918
+ c.f32[7] = b.f32[7];
919
+
920
+ return c;
921
+ }
922
+
923
+ // compute a * b + c
924
+ inline simd8float32 fmadd(
925
+ const simd8float32& a,
926
+ const simd8float32& b,
927
+ const simd8float32& c) {
928
+ simd8float32 res;
929
+ for (int i = 0; i < 8; i++) {
930
+ res.f32[i] = a.f32[i] * b.f32[i] + c.f32[i];
931
+ }
932
+ return res;
933
+ }
934
+
935
+ namespace {
936
+
937
+ // get even float32's of a and b, interleaved
938
+ simd8float32 geteven(const simd8float32& a, const simd8float32& b) {
939
+ simd8float32 c;
940
+
941
+ c.f32[0] = a.f32[0];
942
+ c.f32[1] = a.f32[2];
943
+ c.f32[2] = b.f32[0];
944
+ c.f32[3] = b.f32[2];
945
+
946
+ c.f32[4] = a.f32[4];
947
+ c.f32[5] = a.f32[6];
948
+ c.f32[6] = b.f32[4];
949
+ c.f32[7] = b.f32[6];
950
+
951
+ return c;
952
+ }
953
+
954
+ // get odd float32's of a and b, interleaved
955
+ simd8float32 getodd(const simd8float32& a, const simd8float32& b) {
956
+ simd8float32 c;
957
+
958
+ c.f32[0] = a.f32[1];
959
+ c.f32[1] = a.f32[3];
960
+ c.f32[2] = b.f32[1];
961
+ c.f32[3] = b.f32[3];
962
+
963
+ c.f32[4] = a.f32[5];
964
+ c.f32[5] = a.f32[7];
965
+ c.f32[6] = b.f32[5];
966
+ c.f32[7] = b.f32[7];
967
+
968
+ return c;
969
+ }
970
+
971
+ // 3 cycles
972
+ // if the lanes are a = [a0 a1] and b = [b0 b1], return [a0 b0]
973
+ simd8float32 getlow128(const simd8float32& a, const simd8float32& b) {
974
+ simd8float32 c;
975
+
976
+ c.f32[0] = a.f32[0];
977
+ c.f32[1] = a.f32[1];
978
+ c.f32[2] = a.f32[2];
979
+ c.f32[3] = a.f32[3];
980
+
981
+ c.f32[4] = b.f32[0];
982
+ c.f32[5] = b.f32[1];
983
+ c.f32[6] = b.f32[2];
984
+ c.f32[7] = b.f32[3];
985
+
986
+ return c;
987
+ }
988
+
989
+ simd8float32 gethigh128(const simd8float32& a, const simd8float32& b) {
990
+ simd8float32 c;
991
+
992
+ c.f32[0] = a.f32[4];
993
+ c.f32[1] = a.f32[5];
994
+ c.f32[2] = a.f32[6];
995
+ c.f32[3] = a.f32[7];
996
+
997
+ c.f32[4] = b.f32[4];
998
+ c.f32[5] = b.f32[5];
999
+ c.f32[6] = b.f32[6];
1000
+ c.f32[7] = b.f32[7];
1001
+
1002
+ return c;
1003
+ }
1004
+
1005
+ // The following primitive is a vectorized version of the following code
1006
+ // snippet:
1007
+ // float lowestValue = HUGE_VAL;
1008
+ // uint lowestIndex = 0;
1009
+ // for (size_t i = 0; i < n; i++) {
1010
+ // if (values[i] < lowestValue) {
1011
+ // lowestValue = values[i];
1012
+ // lowestIndex = i;
1013
+ // }
1014
+ // }
1015
+ // Vectorized version can be implemented via two operations: cmp and blend
1016
+ // with something like this:
1017
+ // lowestValues = [HUGE_VAL; 8];
1018
+ // lowestIndices = {0, 1, 2, 3, 4, 5, 6, 7};
1019
+ // for (size_t i = 0; i < n; i += 8) {
1020
+ // auto comparison = cmp(values + i, lowestValues);
1021
+ // lowestValues = blend(
1022
+ // comparison,
1023
+ // values + i,
1024
+ // lowestValues);
1025
+ // lowestIndices = blend(
1026
+ // comparison,
1027
+ // i + {0, 1, 2, 3, 4, 5, 6, 7},
1028
+ // lowestIndices);
1029
+ // lowestIndices += {8, 8, 8, 8, 8, 8, 8, 8};
1030
+ // }
1031
+ // The problem is that blend primitive needs very different instruction
1032
+ // order for AVX and ARM.
1033
+ // So, let's introduce a combination of these two in order to avoid
1034
+ // confusion for ppl who write in low-level SIMD instructions. Additionally,
1035
+ // these two ops (cmp and blend) are very often used together.
1036
+ inline void cmplt_and_blend_inplace(
1037
+ const simd8float32 candidateValues,
1038
+ const simd8uint32 candidateIndices,
1039
+ simd8float32& lowestValues,
1040
+ simd8uint32& lowestIndices) {
1041
+ for (size_t j = 0; j < 8; j++) {
1042
+ bool comparison = (candidateValues.f32[j] < lowestValues.f32[j]);
1043
+ if (comparison) {
1044
+ lowestValues.f32[j] = candidateValues.f32[j];
1045
+ lowestIndices.u32[j] = candidateIndices.u32[j];
1046
+ }
1047
+ }
1048
+ }
1049
+
1050
+ // Vectorized version of the following code:
1051
+ // for (size_t i = 0; i < n; i++) {
1052
+ // bool flag = (candidateValues[i] < currentValues[i]);
1053
+ // minValues[i] = flag ? candidateValues[i] : currentValues[i];
1054
+ // minIndices[i] = flag ? candidateIndices[i] : currentIndices[i];
1055
+ // maxValues[i] = !flag ? candidateValues[i] : currentValues[i];
1056
+ // maxIndices[i] = !flag ? candidateIndices[i] : currentIndices[i];
1057
+ // }
1058
+ // Max indices evaluation is inaccurate in case of equal values (the index of
1059
+ // the last equal value is saved instead of the first one), but this behavior
1060
+ // saves instructions.
1061
+ inline void cmplt_min_max_fast(
1062
+ const simd8float32 candidateValues,
1063
+ const simd8uint32 candidateIndices,
1064
+ const simd8float32 currentValues,
1065
+ const simd8uint32 currentIndices,
1066
+ simd8float32& minValues,
1067
+ simd8uint32& minIndices,
1068
+ simd8float32& maxValues,
1069
+ simd8uint32& maxIndices) {
1070
+ for (size_t i = 0; i < 8; i++) {
1071
+ bool flag = (candidateValues.f32[i] < currentValues.f32[i]);
1072
+ minValues.f32[i] = flag ? candidateValues.f32[i] : currentValues.f32[i];
1073
+ minIndices.u32[i] =
1074
+ flag ? candidateIndices.u32[i] : currentIndices.u32[i];
1075
+ maxValues.f32[i] =
1076
+ !flag ? candidateValues.f32[i] : currentValues.f32[i];
1077
+ maxIndices.u32[i] =
1078
+ !flag ? candidateIndices.u32[i] : currentIndices.u32[i];
1079
+ }
1080
+ }
1081
+
1082
+ } // namespace
1083
+
1084
+ } // namespace faiss