faiss 0.3.0 → 0.3.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (216)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +9 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +1 -1
  5. data/ext/faiss/extconf.rb +9 -2
  6. data/ext/faiss/index.cpp +1 -1
  7. data/ext/faiss/index_binary.cpp +2 -2
  8. data/ext/faiss/product_quantizer.cpp +1 -1
  9. data/lib/faiss/version.rb +1 -1
  10. data/vendor/faiss/faiss/AutoTune.cpp +7 -7
  11. data/vendor/faiss/faiss/AutoTune.h +1 -2
  12. data/vendor/faiss/faiss/Clustering.cpp +39 -22
  13. data/vendor/faiss/faiss/Clustering.h +40 -21
  14. data/vendor/faiss/faiss/IVFlib.cpp +26 -12
  15. data/vendor/faiss/faiss/Index.cpp +1 -1
  16. data/vendor/faiss/faiss/Index.h +40 -10
  17. data/vendor/faiss/faiss/Index2Layer.cpp +7 -7
  18. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +176 -166
  19. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +15 -15
  20. data/vendor/faiss/faiss/IndexBinary.cpp +9 -4
  21. data/vendor/faiss/faiss/IndexBinary.h +8 -19
  22. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +2 -1
  23. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +24 -31
  24. data/vendor/faiss/faiss/IndexBinaryHNSW.h +1 -1
  25. data/vendor/faiss/faiss/IndexBinaryHash.cpp +25 -50
  26. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +107 -188
  27. data/vendor/faiss/faiss/IndexFastScan.cpp +95 -146
  28. data/vendor/faiss/faiss/IndexFastScan.h +9 -8
  29. data/vendor/faiss/faiss/IndexFlat.cpp +206 -10
  30. data/vendor/faiss/faiss/IndexFlat.h +20 -1
  31. data/vendor/faiss/faiss/IndexFlatCodes.cpp +170 -5
  32. data/vendor/faiss/faiss/IndexFlatCodes.h +23 -4
  33. data/vendor/faiss/faiss/IndexHNSW.cpp +231 -382
  34. data/vendor/faiss/faiss/IndexHNSW.h +62 -49
  35. data/vendor/faiss/faiss/IndexIDMap.cpp +69 -28
  36. data/vendor/faiss/faiss/IndexIDMap.h +24 -2
  37. data/vendor/faiss/faiss/IndexIVF.cpp +162 -56
  38. data/vendor/faiss/faiss/IndexIVF.h +46 -6
  39. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +33 -26
  40. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +6 -2
  41. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +19 -46
  42. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +4 -3
  43. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +502 -401
  44. data/vendor/faiss/faiss/IndexIVFFastScan.h +63 -26
  45. data/vendor/faiss/faiss/IndexIVFFlat.cpp +15 -5
  46. data/vendor/faiss/faiss/IndexIVFFlat.h +3 -2
  47. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +172 -0
  48. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.h +56 -0
  49. data/vendor/faiss/faiss/IndexIVFPQ.cpp +79 -125
  50. data/vendor/faiss/faiss/IndexIVFPQ.h +6 -7
  51. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +39 -52
  52. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +4 -3
  53. data/vendor/faiss/faiss/IndexIVFPQR.cpp +45 -29
  54. data/vendor/faiss/faiss/IndexIVFPQR.h +5 -2
  55. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +25 -27
  56. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +6 -6
  57. data/vendor/faiss/faiss/IndexLSH.cpp +14 -16
  58. data/vendor/faiss/faiss/IndexLattice.cpp +1 -19
  59. data/vendor/faiss/faiss/IndexLattice.h +3 -22
  60. data/vendor/faiss/faiss/IndexNNDescent.cpp +3 -33
  61. data/vendor/faiss/faiss/IndexNNDescent.h +1 -1
  62. data/vendor/faiss/faiss/IndexNSG.cpp +11 -27
  63. data/vendor/faiss/faiss/IndexNSG.h +11 -11
  64. data/vendor/faiss/faiss/IndexNeuralNetCodec.cpp +56 -0
  65. data/vendor/faiss/faiss/IndexNeuralNetCodec.h +49 -0
  66. data/vendor/faiss/faiss/IndexPQ.cpp +72 -88
  67. data/vendor/faiss/faiss/IndexPQ.h +1 -4
  68. data/vendor/faiss/faiss/IndexPQFastScan.cpp +1 -1
  69. data/vendor/faiss/faiss/IndexPreTransform.cpp +25 -31
  70. data/vendor/faiss/faiss/IndexPreTransform.h +1 -1
  71. data/vendor/faiss/faiss/IndexRefine.cpp +54 -24
  72. data/vendor/faiss/faiss/IndexRefine.h +7 -0
  73. data/vendor/faiss/faiss/IndexReplicas.cpp +23 -26
  74. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +25 -17
  75. data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -4
  76. data/vendor/faiss/faiss/IndexShards.cpp +21 -29
  77. data/vendor/faiss/faiss/IndexShardsIVF.cpp +1 -2
  78. data/vendor/faiss/faiss/MatrixStats.cpp +17 -32
  79. data/vendor/faiss/faiss/MatrixStats.h +21 -9
  80. data/vendor/faiss/faiss/MetaIndexes.cpp +35 -35
  81. data/vendor/faiss/faiss/MetricType.h +7 -2
  82. data/vendor/faiss/faiss/VectorTransform.cpp +13 -26
  83. data/vendor/faiss/faiss/VectorTransform.h +7 -7
  84. data/vendor/faiss/faiss/clone_index.cpp +15 -10
  85. data/vendor/faiss/faiss/clone_index.h +3 -0
  86. data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +95 -17
  87. data/vendor/faiss/faiss/cppcontrib/factory_tools.cpp +152 -0
  88. data/vendor/faiss/faiss/cppcontrib/factory_tools.h +24 -0
  89. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +83 -30
  90. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +123 -8
  91. data/vendor/faiss/faiss/gpu/GpuCloner.h +22 -0
  92. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +13 -0
  93. data/vendor/faiss/faiss/gpu/GpuDistance.h +46 -38
  94. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -1
  95. data/vendor/faiss/faiss/gpu/GpuIndex.h +30 -12
  96. data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +282 -0
  97. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +4 -4
  98. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +14 -9
  99. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +20 -3
  100. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -11
  101. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +1 -3
  102. data/vendor/faiss/faiss/gpu/GpuResources.cpp +24 -3
  103. data/vendor/faiss/faiss/gpu/GpuResources.h +39 -11
  104. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +142 -17
  105. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +57 -3
  106. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +26 -21
  107. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +7 -1
  108. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +8 -5
  109. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +25 -0
  110. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +129 -9
  111. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +332 -40
  112. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +299 -208
  113. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +1 -0
  114. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +1 -1
  115. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +6 -0
  116. data/vendor/faiss/faiss/gpu/utils/RaftUtils.h +75 -0
  117. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +4 -1
  118. data/vendor/faiss/faiss/gpu/utils/Timer.h +1 -1
  119. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +3 -1
  120. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +5 -5
  121. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +26 -1
  122. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +10 -3
  123. data/vendor/faiss/faiss/impl/DistanceComputer.h +70 -1
  124. data/vendor/faiss/faiss/impl/FaissAssert.h +4 -2
  125. data/vendor/faiss/faiss/impl/FaissException.h +13 -34
  126. data/vendor/faiss/faiss/impl/HNSW.cpp +605 -186
  127. data/vendor/faiss/faiss/impl/HNSW.h +52 -30
  128. data/vendor/faiss/faiss/impl/IDSelector.h +4 -4
  129. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +11 -9
  130. data/vendor/faiss/faiss/impl/LookupTableScaler.h +34 -0
  131. data/vendor/faiss/faiss/impl/NNDescent.cpp +42 -27
  132. data/vendor/faiss/faiss/impl/NSG.cpp +0 -29
  133. data/vendor/faiss/faiss/impl/NSG.h +1 -1
  134. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +14 -12
  135. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +1 -1
  136. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +25 -22
  137. data/vendor/faiss/faiss/impl/ProductQuantizer.h +6 -2
  138. data/vendor/faiss/faiss/impl/Quantizer.h +1 -1
  139. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +27 -1015
  140. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +5 -63
  141. data/vendor/faiss/faiss/impl/ResultHandler.h +347 -172
  142. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +1104 -147
  143. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +3 -8
  144. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +285 -42
  145. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx512.h +248 -0
  146. data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +21 -14
  147. data/vendor/faiss/faiss/impl/code_distance/code_distance.h +22 -12
  148. data/vendor/faiss/faiss/impl/index_read.cpp +74 -34
  149. data/vendor/faiss/faiss/impl/index_read_utils.h +37 -0
  150. data/vendor/faiss/faiss/impl/index_write.cpp +88 -51
  151. data/vendor/faiss/faiss/impl/io.cpp +23 -15
  152. data/vendor/faiss/faiss/impl/io.h +4 -4
  153. data/vendor/faiss/faiss/impl/io_macros.h +6 -0
  154. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -1
  155. data/vendor/faiss/faiss/impl/platform_macros.h +40 -1
  156. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +14 -0
  157. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +7 -6
  158. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +52 -38
  159. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +487 -49
  160. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +960 -0
  161. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +176 -0
  162. data/vendor/faiss/faiss/impl/simd_result_handlers.h +481 -225
  163. data/vendor/faiss/faiss/index_factory.cpp +41 -20
  164. data/vendor/faiss/faiss/index_io.h +12 -5
  165. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +28 -8
  166. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +3 -0
  167. data/vendor/faiss/faiss/invlists/DirectMap.cpp +10 -2
  168. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +73 -17
  169. data/vendor/faiss/faiss/invlists/InvertedLists.h +26 -8
  170. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +24 -9
  171. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +2 -1
  172. data/vendor/faiss/faiss/python/python_callbacks.cpp +4 -4
  173. data/vendor/faiss/faiss/utils/Heap.cpp +3 -1
  174. data/vendor/faiss/faiss/utils/Heap.h +105 -0
  175. data/vendor/faiss/faiss/utils/NeuralNet.cpp +342 -0
  176. data/vendor/faiss/faiss/utils/NeuralNet.h +147 -0
  177. data/vendor/faiss/faiss/utils/WorkerThread.h +1 -0
  178. data/vendor/faiss/faiss/utils/bf16.h +36 -0
  179. data/vendor/faiss/faiss/utils/distances.cpp +147 -123
  180. data/vendor/faiss/faiss/utils/distances.h +86 -9
  181. data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +5 -5
  182. data/vendor/faiss/faiss/utils/distances_fused/avx512.h +2 -2
  183. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +2 -2
  184. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +1 -1
  185. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +5 -5
  186. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.h +1 -1
  187. data/vendor/faiss/faiss/utils/distances_simd.cpp +1589 -243
  188. data/vendor/faiss/faiss/utils/extra_distances-inl.h +70 -0
  189. data/vendor/faiss/faiss/utils/extra_distances.cpp +85 -137
  190. data/vendor/faiss/faiss/utils/extra_distances.h +3 -2
  191. data/vendor/faiss/faiss/utils/fp16-arm.h +29 -0
  192. data/vendor/faiss/faiss/utils/fp16.h +2 -0
  193. data/vendor/faiss/faiss/utils/hamming.cpp +163 -111
  194. data/vendor/faiss/faiss/utils/hamming.h +58 -0
  195. data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +16 -89
  196. data/vendor/faiss/faiss/utils/hamming_distance/common.h +1 -0
  197. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +19 -88
  198. data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +58 -0
  199. data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +14 -104
  200. data/vendor/faiss/faiss/utils/partitioning.cpp +3 -4
  201. data/vendor/faiss/faiss/utils/prefetch.h +77 -0
  202. data/vendor/faiss/faiss/utils/quantize_lut.cpp +0 -14
  203. data/vendor/faiss/faiss/utils/random.cpp +43 -0
  204. data/vendor/faiss/faiss/utils/random.h +25 -0
  205. data/vendor/faiss/faiss/utils/simdlib.h +10 -1
  206. data/vendor/faiss/faiss/utils/simdlib_avx2.h +0 -6
  207. data/vendor/faiss/faiss/utils/simdlib_avx512.h +296 -0
  208. data/vendor/faiss/faiss/utils/simdlib_neon.h +77 -79
  209. data/vendor/faiss/faiss/utils/simdlib_ppc64.h +1084 -0
  210. data/vendor/faiss/faiss/utils/sorting.cpp +140 -5
  211. data/vendor/faiss/faiss/utils/sorting.h +27 -0
  212. data/vendor/faiss/faiss/utils/transpose/transpose-avx512-inl.h +176 -0
  213. data/vendor/faiss/faiss/utils/utils.cpp +120 -7
  214. data/vendor/faiss/faiss/utils/utils.h +60 -20
  215. metadata +23 -4
  216. data/vendor/faiss/faiss/impl/code_distance/code_distance_avx512.h +0 -102
@@ -0,0 +1,1084 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ #pragma once
9
+
10
+ #include <algorithm>
11
+ #include <cstdint>
12
+ #include <cstring>
13
+ #include <string>
14
+
15
+ namespace faiss {
16
+
17
+ struct simd256bit {
18
+ union {
19
+ uint8_t u8[32];
20
+ uint16_t u16[16];
21
+ uint32_t u32[8];
22
+ float f32[8];
23
+ };
24
+
25
+ simd256bit() {}
26
+
27
+ explicit simd256bit(const void* x) {
28
+ memcpy(u8, x, 32);
29
+ }
30
+
31
+ void clear() {
32
+ memset(u8, 0, 32);
33
+ }
34
+
35
+ void storeu(void* ptr) const {
36
+ memcpy(ptr, u8, 32);
37
+ }
38
+
39
+ void loadu(const void* ptr) {
40
+ memcpy(u8, ptr, 32);
41
+ }
42
+
43
+ void store(void* ptr) const {
44
+ storeu(ptr);
45
+ }
46
+
47
+ void bin(char bits[257]) const {
48
+ const char* bytes = (char*)this->u8;
49
+ for (int i = 0; i < 256; i++) {
50
+ bits[i] = '0' + ((bytes[i / 8] >> (i % 8)) & 1);
51
+ }
52
+ bits[256] = 0;
53
+ }
54
+
55
+ std::string bin() const {
56
+ char bits[257];
57
+ bin(bits);
58
+ return std::string(bits);
59
+ }
60
+
61
+ // Checks whether the other holds exactly the same bytes.
62
+ bool is_same_as(simd256bit other) const {
63
+ for (size_t i = 0; i < 8; i++) {
64
+ if (u32[i] != other.u32[i]) {
65
+ return false;
66
+ }
67
+ }
68
+
69
+ return true;
70
+ }
71
+ };
72
+
73
+ /// vector of 16 elements in uint16
74
+ struct simd16uint16 : simd256bit {
75
+ simd16uint16() {}
76
+
77
+ explicit simd16uint16(int x) {
78
+ set1(x);
79
+ }
80
+
81
+ explicit simd16uint16(uint16_t x) {
82
+ set1(x);
83
+ }
84
+
85
+ explicit simd16uint16(const simd256bit& x) : simd256bit(x) {}
86
+
87
+ explicit simd16uint16(const uint16_t* x) : simd256bit((const void*)x) {}
88
+
89
+ explicit simd16uint16(
90
+ uint16_t u0,
91
+ uint16_t u1,
92
+ uint16_t u2,
93
+ uint16_t u3,
94
+ uint16_t u4,
95
+ uint16_t u5,
96
+ uint16_t u6,
97
+ uint16_t u7,
98
+ uint16_t u8,
99
+ uint16_t u9,
100
+ uint16_t u10,
101
+ uint16_t u11,
102
+ uint16_t u12,
103
+ uint16_t u13,
104
+ uint16_t u14,
105
+ uint16_t u15) {
106
+ this->u16[0] = u0;
107
+ this->u16[1] = u1;
108
+ this->u16[2] = u2;
109
+ this->u16[3] = u3;
110
+ this->u16[4] = u4;
111
+ this->u16[5] = u5;
112
+ this->u16[6] = u6;
113
+ this->u16[7] = u7;
114
+ this->u16[8] = u8;
115
+ this->u16[9] = u9;
116
+ this->u16[10] = u10;
117
+ this->u16[11] = u11;
118
+ this->u16[12] = u12;
119
+ this->u16[13] = u13;
120
+ this->u16[14] = u14;
121
+ this->u16[15] = u15;
122
+ }
123
+
124
+ std::string elements_to_string(const char* fmt) const {
125
+ char res[1000], *ptr = res;
126
+ for (int i = 0; i < 16; i++) {
127
+ ptr += sprintf(ptr, fmt, u16[i]);
128
+ }
129
+ // strip last ,
130
+ ptr[-1] = 0;
131
+ return std::string(res);
132
+ }
133
+
134
+ std::string hex() const {
135
+ return elements_to_string("%02x,");
136
+ }
137
+
138
+ std::string dec() const {
139
+ return elements_to_string("%3d,");
140
+ }
141
+
142
+ template <typename F>
143
+ static simd16uint16 unary_func(const simd16uint16& a, F&& f) {
144
+ simd16uint16 c;
145
+ for (int j = 0; j < 16; j++) {
146
+ c.u16[j] = f(a.u16[j]);
147
+ }
148
+ return c;
149
+ }
150
+
151
+ template <typename F>
152
+ static simd16uint16 binary_func(
153
+ const simd16uint16& a,
154
+ const simd16uint16& b,
155
+ F&& f) {
156
+ simd16uint16 c;
157
+ for (int j = 0; j < 16; j++) {
158
+ c.u16[j] = f(a.u16[j], b.u16[j]);
159
+ }
160
+ return c;
161
+ }
162
+
163
+ void set1(uint16_t x) {
164
+ for (int i = 0; i < 16; i++) {
165
+ u16[i] = x;
166
+ }
167
+ }
168
+
169
+ simd16uint16 operator*(const simd16uint16& other) const {
170
+ return binary_func(
171
+ *this, other, [](uint16_t a, uint16_t b) { return a * b; });
172
+ }
173
+
174
+ // shift must be known at compile time
175
+ simd16uint16 operator>>(const int shift) const {
176
+ return unary_func(*this, [shift](uint16_t a) { return a >> shift; });
177
+ }
178
+
179
+ // shift must be known at compile time
180
+ simd16uint16 operator<<(const int shift) const {
181
+ return unary_func(*this, [shift](uint16_t a) { return a << shift; });
182
+ }
183
+
184
+ simd16uint16 operator+=(const simd16uint16& other) {
185
+ *this = *this + other;
186
+ return *this;
187
+ }
188
+
189
+ simd16uint16 operator-=(const simd16uint16& other) {
190
+ *this = *this - other;
191
+ return *this;
192
+ }
193
+
194
+ simd16uint16 operator+(const simd16uint16& other) const {
195
+ return binary_func(
196
+ *this, other, [](uint16_t a, uint16_t b) { return a + b; });
197
+ }
198
+
199
+ simd16uint16 operator-(const simd16uint16& other) const {
200
+ return binary_func(
201
+ *this, other, [](uint16_t a, uint16_t b) { return a - b; });
202
+ }
203
+
204
+ simd16uint16 operator&(const simd256bit& other) const {
205
+ return binary_func(
206
+ *this, simd16uint16(other), [](uint16_t a, uint16_t b) {
207
+ return a & b;
208
+ });
209
+ }
210
+
211
+ simd16uint16 operator|(const simd256bit& other) const {
212
+ return binary_func(
213
+ *this, simd16uint16(other), [](uint16_t a, uint16_t b) {
214
+ return a | b;
215
+ });
216
+ }
217
+
218
+ simd16uint16 operator^(const simd256bit& other) const {
219
+ return binary_func(
220
+ *this, simd16uint16(other), [](uint16_t a, uint16_t b) {
221
+ return a ^ b;
222
+ });
223
+ }
224
+
225
+ // returns binary masks
226
+ simd16uint16 operator==(const simd16uint16& other) const {
227
+ return binary_func(*this, other, [](uint16_t a, uint16_t b) {
228
+ return a == b ? 0xffff : 0;
229
+ });
230
+ }
231
+
232
+ simd16uint16 operator~() const {
233
+ return unary_func(*this, [](uint16_t a) { return ~a; });
234
+ }
235
+
236
+ // get scalar at index 0
237
+ uint16_t get_scalar_0() const {
238
+ return u16[0];
239
+ }
240
+
241
+ // mask of elements where this >= thresh
242
+ // 2 bit per component: 16 * 2 = 32 bit
243
+ uint32_t ge_mask(const simd16uint16& thresh) const {
244
+ uint32_t gem = 0;
245
+ for (int j = 0; j < 16; j++) {
246
+ if (u16[j] >= thresh.u16[j]) {
247
+ gem |= 3 << (j * 2);
248
+ }
249
+ }
250
+ return gem;
251
+ }
252
+
253
+ uint32_t le_mask(const simd16uint16& thresh) const {
254
+ return thresh.ge_mask(*this);
255
+ }
256
+
257
+ uint32_t gt_mask(const simd16uint16& thresh) const {
258
+ return ~le_mask(thresh);
259
+ }
260
+
261
+ bool all_gt(const simd16uint16& thresh) const {
262
+ return le_mask(thresh) == 0;
263
+ }
264
+
265
+ // for debugging only
266
+ uint16_t operator[](int i) const {
267
+ return u16[i];
268
+ }
269
+
270
+ void accu_min(const simd16uint16& incoming) {
271
+ for (int j = 0; j < 16; j++) {
272
+ if (incoming.u16[j] < u16[j]) {
273
+ u16[j] = incoming.u16[j];
274
+ }
275
+ }
276
+ }
277
+
278
+ void accu_max(const simd16uint16& incoming) {
279
+ for (int j = 0; j < 16; j++) {
280
+ if (incoming.u16[j] > u16[j]) {
281
+ u16[j] = incoming.u16[j];
282
+ }
283
+ }
284
+ }
285
+ };
286
+
287
+ // not really a std::min because it returns an elementwise min
288
+ inline simd16uint16 min(const simd16uint16& av, const simd16uint16& bv) {
289
+ return simd16uint16::binary_func(
290
+ av, bv, [](uint16_t a, uint16_t b) { return std::min(a, b); });
291
+ }
292
+
293
+ inline simd16uint16 max(const simd16uint16& av, const simd16uint16& bv) {
294
+ return simd16uint16::binary_func(
295
+ av, bv, [](uint16_t a, uint16_t b) { return std::max(a, b); });
296
+ }
297
+
298
+ // decompose in 128-lanes: a = (a0, a1), b = (b0, b1)
299
+ // return (a0 + a1, b0 + b1)
300
+ // TODO find a better name
301
+ inline simd16uint16 combine2x2(const simd16uint16& a, const simd16uint16& b) {
302
+ simd16uint16 c;
303
+ for (int j = 0; j < 8; j++) {
304
+ c.u16[j] = a.u16[j] + a.u16[j + 8];
305
+ c.u16[j + 8] = b.u16[j] + b.u16[j + 8];
306
+ }
307
+ return c;
308
+ }
309
+
310
+ // compare d0 and d1 to thr, return 32 bits corresponding to the concatenation
311
+ // of d0 and d1 with thr
312
+ inline uint32_t cmp_ge32(
313
+ const simd16uint16& d0,
314
+ const simd16uint16& d1,
315
+ const simd16uint16& thr) {
316
+ uint32_t gem = 0;
317
+ for (int j = 0; j < 16; j++) {
318
+ if (d0.u16[j] >= thr.u16[j]) {
319
+ gem |= 1 << j;
320
+ }
321
+ if (d1.u16[j] >= thr.u16[j]) {
322
+ gem |= 1 << (j + 16);
323
+ }
324
+ }
325
+ return gem;
326
+ }
327
+
328
+ inline uint32_t cmp_le32(
329
+ const simd16uint16& d0,
330
+ const simd16uint16& d1,
331
+ const simd16uint16& thr) {
332
+ uint32_t gem = 0;
333
+ for (int j = 0; j < 16; j++) {
334
+ if (d0.u16[j] <= thr.u16[j]) {
335
+ gem |= 1 << j;
336
+ }
337
+ if (d1.u16[j] <= thr.u16[j]) {
338
+ gem |= 1 << (j + 16);
339
+ }
340
+ }
341
+ return gem;
342
+ }
343
+
344
+ // hadd does not cross lanes
345
+ inline simd16uint16 hadd(const simd16uint16& a, const simd16uint16& b) {
346
+ simd16uint16 c;
347
+ c.u16[0] = a.u16[0] + a.u16[1];
348
+ c.u16[1] = a.u16[2] + a.u16[3];
349
+ c.u16[2] = a.u16[4] + a.u16[5];
350
+ c.u16[3] = a.u16[6] + a.u16[7];
351
+ c.u16[4] = b.u16[0] + b.u16[1];
352
+ c.u16[5] = b.u16[2] + b.u16[3];
353
+ c.u16[6] = b.u16[4] + b.u16[5];
354
+ c.u16[7] = b.u16[6] + b.u16[7];
355
+
356
+ c.u16[8] = a.u16[8] + a.u16[9];
357
+ c.u16[9] = a.u16[10] + a.u16[11];
358
+ c.u16[10] = a.u16[12] + a.u16[13];
359
+ c.u16[11] = a.u16[14] + a.u16[15];
360
+ c.u16[12] = b.u16[8] + b.u16[9];
361
+ c.u16[13] = b.u16[10] + b.u16[11];
362
+ c.u16[14] = b.u16[12] + b.u16[13];
363
+ c.u16[15] = b.u16[14] + b.u16[15];
364
+
365
+ return c;
366
+ }
367
+
368
+ // Vectorized version of the following code:
369
+ // for (size_t i = 0; i < n; i++) {
370
+ // bool flag = (candidateValues[i] < currentValues[i]);
371
+ // minValues[i] = flag ? candidateValues[i] : currentValues[i];
372
+ // minIndices[i] = flag ? candidateIndices[i] : currentIndices[i];
373
+ // maxValues[i] = !flag ? candidateValues[i] : currentValues[i];
374
+ // maxIndices[i] = !flag ? candidateIndices[i] : currentIndices[i];
375
+ // }
376
+ // Max indices evaluation is inaccurate in case of equal values (the index of
377
+ // the last equal value is saved instead of the first one), but this behavior
378
+ // saves instructions.
379
+ inline void cmplt_min_max_fast(
380
+ const simd16uint16 candidateValues,
381
+ const simd16uint16 candidateIndices,
382
+ const simd16uint16 currentValues,
383
+ const simd16uint16 currentIndices,
384
+ simd16uint16& minValues,
385
+ simd16uint16& minIndices,
386
+ simd16uint16& maxValues,
387
+ simd16uint16& maxIndices) {
388
+ for (size_t i = 0; i < 16; i++) {
389
+ bool flag = (candidateValues.u16[i] < currentValues.u16[i]);
390
+ minValues.u16[i] = flag ? candidateValues.u16[i] : currentValues.u16[i];
391
+ minIndices.u16[i] =
392
+ flag ? candidateIndices.u16[i] : currentIndices.u16[i];
393
+ maxValues.u16[i] =
394
+ !flag ? candidateValues.u16[i] : currentValues.u16[i];
395
+ maxIndices.u16[i] =
396
+ !flag ? candidateIndices.u16[i] : currentIndices.u16[i];
397
+ }
398
+ }
399
+
400
+ // vector of 32 unsigned 8-bit integers
401
+ struct simd32uint8 : simd256bit {
402
+ simd32uint8() {}
403
+
404
+ explicit simd32uint8(int x) {
405
+ set1(x);
406
+ }
407
+
408
+ explicit simd32uint8(uint8_t x) {
409
+ set1(x);
410
+ }
411
+ template <
412
+ uint8_t _0,
413
+ uint8_t _1,
414
+ uint8_t _2,
415
+ uint8_t _3,
416
+ uint8_t _4,
417
+ uint8_t _5,
418
+ uint8_t _6,
419
+ uint8_t _7,
420
+ uint8_t _8,
421
+ uint8_t _9,
422
+ uint8_t _10,
423
+ uint8_t _11,
424
+ uint8_t _12,
425
+ uint8_t _13,
426
+ uint8_t _14,
427
+ uint8_t _15,
428
+ uint8_t _16,
429
+ uint8_t _17,
430
+ uint8_t _18,
431
+ uint8_t _19,
432
+ uint8_t _20,
433
+ uint8_t _21,
434
+ uint8_t _22,
435
+ uint8_t _23,
436
+ uint8_t _24,
437
+ uint8_t _25,
438
+ uint8_t _26,
439
+ uint8_t _27,
440
+ uint8_t _28,
441
+ uint8_t _29,
442
+ uint8_t _30,
443
+ uint8_t _31>
444
+ static simd32uint8 create() {
445
+ simd32uint8 ret;
446
+ ret.u8[0] = _0;
447
+ ret.u8[1] = _1;
448
+ ret.u8[2] = _2;
449
+ ret.u8[3] = _3;
450
+ ret.u8[4] = _4;
451
+ ret.u8[5] = _5;
452
+ ret.u8[6] = _6;
453
+ ret.u8[7] = _7;
454
+ ret.u8[8] = _8;
455
+ ret.u8[9] = _9;
456
+ ret.u8[10] = _10;
457
+ ret.u8[11] = _11;
458
+ ret.u8[12] = _12;
459
+ ret.u8[13] = _13;
460
+ ret.u8[14] = _14;
461
+ ret.u8[15] = _15;
462
+ ret.u8[16] = _16;
463
+ ret.u8[17] = _17;
464
+ ret.u8[18] = _18;
465
+ ret.u8[19] = _19;
466
+ ret.u8[20] = _20;
467
+ ret.u8[21] = _21;
468
+ ret.u8[22] = _22;
469
+ ret.u8[23] = _23;
470
+ ret.u8[24] = _24;
471
+ ret.u8[25] = _25;
472
+ ret.u8[26] = _26;
473
+ ret.u8[27] = _27;
474
+ ret.u8[28] = _28;
475
+ ret.u8[29] = _29;
476
+ ret.u8[30] = _30;
477
+ ret.u8[31] = _31;
478
+ return ret;
479
+ }
480
+
481
+ explicit simd32uint8(const simd256bit& x) : simd256bit(x) {}
482
+
483
+ explicit simd32uint8(const uint8_t* x) : simd256bit((const void*)x) {}
484
+
485
+ std::string elements_to_string(const char* fmt) const {
486
+ char res[1000], *ptr = res;
487
+ for (int i = 0; i < 32; i++) {
488
+ ptr += sprintf(ptr, fmt, u8[i]);
489
+ }
490
+ // strip last ,
491
+ ptr[-1] = 0;
492
+ return std::string(res);
493
+ }
494
+
495
+ std::string hex() const {
496
+ return elements_to_string("%02x,");
497
+ }
498
+
499
+ std::string dec() const {
500
+ return elements_to_string("%3d,");
501
+ }
502
+
503
+ void set1(uint8_t x) {
504
+ for (int j = 0; j < 32; j++) {
505
+ u8[j] = x;
506
+ }
507
+ }
508
+
509
+ template <typename F>
510
+ static simd32uint8 binary_func(
511
+ const simd32uint8& a,
512
+ const simd32uint8& b,
513
+ F&& f) {
514
+ simd32uint8 c;
515
+ for (int j = 0; j < 32; j++) {
516
+ c.u8[j] = f(a.u8[j], b.u8[j]);
517
+ }
518
+ return c;
519
+ }
520
+
521
+ simd32uint8 operator&(const simd256bit& other) const {
522
+ return binary_func(*this, simd32uint8(other), [](uint8_t a, uint8_t b) {
523
+ return a & b;
524
+ });
525
+ }
526
+
527
+ simd32uint8 operator+(const simd32uint8& other) const {
528
+ return binary_func(
529
+ *this, other, [](uint8_t a, uint8_t b) { return a + b; });
530
+ }
531
+
532
+ // The very important operation that everything relies on
533
+ simd32uint8 lookup_2_lanes(const simd32uint8& idx) const {
534
+ simd32uint8 c;
535
+ // The original for loop:
536
+ // for (int j = 0; j < 32; j++) {
537
+ // if (idx.u8[j] & 0x80) {
538
+ // c.u8[j] = 0;
539
+ // } else {
540
+ // uint8_t i = idx.u8[j] & 15;
541
+ // if (j < 16) {
542
+ // c.u8[j] = u8[i];
543
+ // } else {
544
+ // c.u8[j] = u8[16 + i];
545
+ // }
546
+ // }
547
+
548
+ // The following function was re-written for Power 10
549
+ // The loop was unrolled to remove the if (j < 16) statement by doing
550
+ // the j and j + 16 iterations in parallel. The additional unrolling
551
+ // for j + 1 and j + 17, reduces the execution time on Power 10 by
552
+ // about 50% as the instruction scheduling allows on average 2X more
553
+ // instructions to be issued per cycle.
554
+
555
+ for (int j = 0; j < 16; j = j + 2) {
556
+ // j < 16, unrolled to depth of 2
557
+ if (idx.u8[j] & 0x80) {
558
+ c.u8[j] = 0;
559
+ } else {
560
+ uint8_t i = idx.u8[j] & 15;
561
+ c.u8[j] = u8[i];
562
+ }
563
+
564
+ if (idx.u8[j + 1] & 0x80) {
565
+ c.u8[j + 1] = 0;
566
+ } else {
567
+ uint8_t i = idx.u8[j + 1] & 15;
568
+ c.u8[j + 1] = u8[i];
569
+ }
570
+
571
+ // j >= 16, unrolled to depth of 2
572
+ if (idx.u8[j + 16] & 0x80) {
573
+ c.u8[j + 16] = 0;
574
+ } else {
575
+ uint8_t i = idx.u8[j + 16] & 15;
576
+ c.u8[j + 16] = u8[i + 16];
577
+ }
578
+
579
+ if (idx.u8[j + 17] & 0x80) {
580
+ c.u8[j + 17] = 0;
581
+ } else {
582
+ uint8_t i = idx.u8[j + 17] & 15;
583
+ c.u8[j + 17] = u8[i + 16];
584
+ }
585
+ }
586
+ return c;
587
+ }
588
+
589
+ // extract + 0-extend lane
590
+ // this operation is slow (3 cycles)
591
+
592
+ simd32uint8 operator+=(const simd32uint8& other) {
593
+ *this = *this + other;
594
+ return *this;
595
+ }
596
+
597
+ // for debugging only
598
+ uint8_t operator[](int i) const {
599
+ return u8[i];
600
+ }
601
+ };
602
+
603
+ // convert with saturation
604
+ // careful: this does not cross lanes, so the order is weird
605
+ inline simd32uint8 uint16_to_uint8_saturate(
606
+ const simd16uint16& a,
607
+ const simd16uint16& b) {
608
+ simd32uint8 c;
609
+
610
+ auto saturate_16_to_8 = [](uint16_t x) { return x >= 256 ? 0xff : x; };
611
+
612
+ for (int i = 0; i < 8; i++) {
613
+ c.u8[i] = saturate_16_to_8(a.u16[i]);
614
+ c.u8[8 + i] = saturate_16_to_8(b.u16[i]);
615
+ c.u8[16 + i] = saturate_16_to_8(a.u16[8 + i]);
616
+ c.u8[24 + i] = saturate_16_to_8(b.u16[8 + i]);
617
+ }
618
+ return c;
619
+ }
620
+
621
+ /// get most significant bit of each byte
622
+ inline uint32_t get_MSBs(const simd32uint8& a) {
623
+ uint32_t res = 0;
624
+ for (int i = 0; i < 32; i++) {
625
+ if (a.u8[i] & 0x80) {
626
+ res |= 1 << i;
627
+ }
628
+ }
629
+ return res;
630
+ }
631
+
632
+ /// use MSB of each byte of mask to select a byte between a and b
633
+ inline simd32uint8 blendv(
634
+ const simd32uint8& a,
635
+ const simd32uint8& b,
636
+ const simd32uint8& mask) {
637
+ simd32uint8 c;
638
+ for (int i = 0; i < 32; i++) {
639
+ if (mask.u8[i] & 0x80) {
640
+ c.u8[i] = b.u8[i];
641
+ } else {
642
+ c.u8[i] = a.u8[i];
643
+ }
644
+ }
645
+ return c;
646
+ }
647
+
648
+ /// vector of 8 unsigned 32-bit integers
649
+ struct simd8uint32 : simd256bit {
650
+ simd8uint32() {}
651
+
652
+ explicit simd8uint32(uint32_t x) {
653
+ set1(x);
654
+ }
655
+
656
+ explicit simd8uint32(const simd256bit& x) : simd256bit(x) {}
657
+
658
+ explicit simd8uint32(const uint32_t* x) : simd256bit((const void*)x) {}
659
+
660
+ explicit simd8uint32(
661
+ uint32_t u0,
662
+ uint32_t u1,
663
+ uint32_t u2,
664
+ uint32_t u3,
665
+ uint32_t u4,
666
+ uint32_t u5,
667
+ uint32_t u6,
668
+ uint32_t u7) {
669
+ u32[0] = u0;
670
+ u32[1] = u1;
671
+ u32[2] = u2;
672
+ u32[3] = u3;
673
+ u32[4] = u4;
674
+ u32[5] = u5;
675
+ u32[6] = u6;
676
+ u32[7] = u7;
677
+ }
678
+
679
+ simd8uint32 operator+(simd8uint32 other) const {
680
+ simd8uint32 result;
681
+ for (int i = 0; i < 8; i++) {
682
+ result.u32[i] = u32[i] + other.u32[i];
683
+ }
684
+ return result;
685
+ }
686
+
687
+ simd8uint32 operator-(simd8uint32 other) const {
688
+ simd8uint32 result;
689
+ for (int i = 0; i < 8; i++) {
690
+ result.u32[i] = u32[i] - other.u32[i];
691
+ }
692
+ return result;
693
+ }
694
+
695
+ simd8uint32& operator+=(const simd8uint32& other) {
696
+ for (int i = 0; i < 8; i++) {
697
+ u32[i] += other.u32[i];
698
+ }
699
+ return *this;
700
+ }
701
+
702
+ bool operator==(simd8uint32 other) const {
703
+ for (size_t i = 0; i < 8; i++) {
704
+ if (u32[i] != other.u32[i]) {
705
+ return false;
706
+ }
707
+ }
708
+
709
+ return true;
710
+ }
711
+
712
+ bool operator!=(simd8uint32 other) const {
713
+ return !(*this == other);
714
+ }
715
+
716
+ std::string elements_to_string(const char* fmt) const {
717
+ char res[1000], *ptr = res;
718
+ for (int i = 0; i < 8; i++) {
719
+ ptr += sprintf(ptr, fmt, u32[i]);
720
+ }
721
+ // strip last ,
722
+ ptr[-1] = 0;
723
+ return std::string(res);
724
+ }
725
+
726
+ std::string hex() const {
727
+ return elements_to_string("%08x,");
728
+ }
729
+
730
+ std::string dec() const {
731
+ return elements_to_string("%10d,");
732
+ }
733
+
734
+ void set1(uint32_t x) {
735
+ for (int i = 0; i < 8; i++) {
736
+ u32[i] = x;
737
+ }
738
+ }
739
+
740
+ simd8uint32 unzip() const {
741
+ const uint32_t ret[] = {
742
+ u32[0], u32[2], u32[4], u32[6], u32[1], u32[3], u32[5], u32[7]};
743
+ return simd8uint32{ret};
744
+ }
745
+ };
746
+
747
+ // Vectorized version of the following code:
748
+ // for (size_t i = 0; i < n; i++) {
749
+ // bool flag = (candidateValues[i] < currentValues[i]);
750
+ // minValues[i] = flag ? candidateValues[i] : currentValues[i];
751
+ // minIndices[i] = flag ? candidateIndices[i] : currentIndices[i];
752
+ // maxValues[i] = !flag ? candidateValues[i] : currentValues[i];
753
+ // maxIndices[i] = !flag ? candidateIndices[i] : currentIndices[i];
754
+ // }
755
+ // Max indices evaluation is inaccurate in case of equal values (the index of
756
+ // the last equal value is saved instead of the first one), but this behavior
757
+ // saves instructions.
758
+ inline void cmplt_min_max_fast(
759
+ const simd8uint32 candidateValues,
760
+ const simd8uint32 candidateIndices,
761
+ const simd8uint32 currentValues,
762
+ const simd8uint32 currentIndices,
763
+ simd8uint32& minValues,
764
+ simd8uint32& minIndices,
765
+ simd8uint32& maxValues,
766
+ simd8uint32& maxIndices) {
767
+ for (size_t i = 0; i < 8; i++) {
768
+ bool flag = (candidateValues.u32[i] < currentValues.u32[i]);
769
+ minValues.u32[i] = flag ? candidateValues.u32[i] : currentValues.u32[i];
770
+ minIndices.u32[i] =
771
+ flag ? candidateIndices.u32[i] : currentIndices.u32[i];
772
+ maxValues.u32[i] =
773
+ !flag ? candidateValues.u32[i] : currentValues.u32[i];
774
+ maxIndices.u32[i] =
775
+ !flag ? candidateIndices.u32[i] : currentIndices.u32[i];
776
+ }
777
+ }
778
+
779
+ struct simd8float32 : simd256bit {
780
+ simd8float32() {}
781
+
782
+ explicit simd8float32(const simd256bit& x) : simd256bit(x) {}
783
+
784
+ explicit simd8float32(float x) {
785
+ set1(x);
786
+ }
787
+
788
+ explicit simd8float32(const float* x) {
789
+ loadu((void*)x);
790
+ }
791
+
792
+ void set1(float x) {
793
+ for (int i = 0; i < 8; i++) {
794
+ f32[i] = x;
795
+ }
796
+ }
797
+
798
+ explicit simd8float32(
799
+ float f0,
800
+ float f1,
801
+ float f2,
802
+ float f3,
803
+ float f4,
804
+ float f5,
805
+ float f6,
806
+ float f7) {
807
+ f32[0] = f0;
808
+ f32[1] = f1;
809
+ f32[2] = f2;
810
+ f32[3] = f3;
811
+ f32[4] = f4;
812
+ f32[5] = f5;
813
+ f32[6] = f6;
814
+ f32[7] = f7;
815
+ }
816
+
817
+ template <typename F>
818
+ static simd8float32 binary_func(
819
+ const simd8float32& a,
820
+ const simd8float32& b,
821
+ F&& f) {
822
+ simd8float32 c;
823
+ for (int j = 0; j < 8; j++) {
824
+ c.f32[j] = f(a.f32[j], b.f32[j]);
825
+ }
826
+ return c;
827
+ }
828
+
829
+ simd8float32 operator*(const simd8float32& other) const {
830
+ return binary_func(
831
+ *this, other, [](float a, float b) { return a * b; });
832
+ }
833
+
834
+ simd8float32 operator+(const simd8float32& other) const {
835
+ return binary_func(
836
+ *this, other, [](float a, float b) { return a + b; });
837
+ }
838
+
839
+ simd8float32 operator-(const simd8float32& other) const {
840
+ return binary_func(
841
+ *this, other, [](float a, float b) { return a - b; });
842
+ }
843
+
844
+ simd8float32& operator+=(const simd8float32& other) {
845
+ for (size_t i = 0; i < 8; i++) {
846
+ f32[i] += other.f32[i];
847
+ }
848
+
849
+ return *this;
850
+ }
851
+
852
+ bool operator==(simd8float32 other) const {
853
+ for (size_t i = 0; i < 8; i++) {
854
+ if (f32[i] != other.f32[i]) {
855
+ return false;
856
+ }
857
+ }
858
+
859
+ return true;
860
+ }
861
+
862
+ bool operator!=(simd8float32 other) const {
863
+ return !(*this == other);
864
+ }
865
+
866
+ std::string tostring() const {
867
+ char res[1000], *ptr = res;
868
+ for (int i = 0; i < 8; i++) {
869
+ ptr += sprintf(ptr, "%g,", f32[i]);
870
+ }
871
+ // strip last ,
872
+ ptr[-1] = 0;
873
+ return std::string(res);
874
+ }
875
+ };
876
+
877
+ // hadd does not cross lanes
878
+ inline simd8float32 hadd(const simd8float32& a, const simd8float32& b) {
879
+ simd8float32 c;
880
+ c.f32[0] = a.f32[0] + a.f32[1];
881
+ c.f32[1] = a.f32[2] + a.f32[3];
882
+ c.f32[2] = b.f32[0] + b.f32[1];
883
+ c.f32[3] = b.f32[2] + b.f32[3];
884
+
885
+ c.f32[4] = a.f32[4] + a.f32[5];
886
+ c.f32[5] = a.f32[6] + a.f32[7];
887
+ c.f32[6] = b.f32[4] + b.f32[5];
888
+ c.f32[7] = b.f32[6] + b.f32[7];
889
+
890
+ return c;
891
+ }
892
+
893
+ inline simd8float32 unpacklo(const simd8float32& a, const simd8float32& b) {
894
+ simd8float32 c;
895
+ c.f32[0] = a.f32[0];
896
+ c.f32[1] = b.f32[0];
897
+ c.f32[2] = a.f32[1];
898
+ c.f32[3] = b.f32[1];
899
+
900
+ c.f32[4] = a.f32[4];
901
+ c.f32[5] = b.f32[4];
902
+ c.f32[6] = a.f32[5];
903
+ c.f32[7] = b.f32[5];
904
+
905
+ return c;
906
+ }
907
+
908
+ inline simd8float32 unpackhi(const simd8float32& a, const simd8float32& b) {
909
+ simd8float32 c;
910
+ c.f32[0] = a.f32[2];
911
+ c.f32[1] = b.f32[2];
912
+ c.f32[2] = a.f32[3];
913
+ c.f32[3] = b.f32[3];
914
+
915
+ c.f32[4] = a.f32[6];
916
+ c.f32[5] = b.f32[6];
917
+ c.f32[6] = a.f32[7];
918
+ c.f32[7] = b.f32[7];
919
+
920
+ return c;
921
+ }
922
+
923
+ // compute a * b + c
924
+ inline simd8float32 fmadd(
925
+ const simd8float32& a,
926
+ const simd8float32& b,
927
+ const simd8float32& c) {
928
+ simd8float32 res;
929
+ for (int i = 0; i < 8; i++) {
930
+ res.f32[i] = a.f32[i] * b.f32[i] + c.f32[i];
931
+ }
932
+ return res;
933
+ }
934
+
935
+ namespace {
936
+
937
+ // get even float32's of a and b, interleaved
938
+ simd8float32 geteven(const simd8float32& a, const simd8float32& b) {
939
+ simd8float32 c;
940
+
941
+ c.f32[0] = a.f32[0];
942
+ c.f32[1] = a.f32[2];
943
+ c.f32[2] = b.f32[0];
944
+ c.f32[3] = b.f32[2];
945
+
946
+ c.f32[4] = a.f32[4];
947
+ c.f32[5] = a.f32[6];
948
+ c.f32[6] = b.f32[4];
949
+ c.f32[7] = b.f32[6];
950
+
951
+ return c;
952
+ }
953
+
954
+ // get odd float32's of a and b, interleaved
955
+ simd8float32 getodd(const simd8float32& a, const simd8float32& b) {
956
+ simd8float32 c;
957
+
958
+ c.f32[0] = a.f32[1];
959
+ c.f32[1] = a.f32[3];
960
+ c.f32[2] = b.f32[1];
961
+ c.f32[3] = b.f32[3];
962
+
963
+ c.f32[4] = a.f32[5];
964
+ c.f32[5] = a.f32[7];
965
+ c.f32[6] = b.f32[5];
966
+ c.f32[7] = b.f32[7];
967
+
968
+ return c;
969
+ }
970
+
971
+ // 3 cycles
972
+ // if the lanes are a = [a0 a1] and b = [b0 b1], return [a0 b0]
973
+ simd8float32 getlow128(const simd8float32& a, const simd8float32& b) {
974
+ simd8float32 c;
975
+
976
+ c.f32[0] = a.f32[0];
977
+ c.f32[1] = a.f32[1];
978
+ c.f32[2] = a.f32[2];
979
+ c.f32[3] = a.f32[3];
980
+
981
+ c.f32[4] = b.f32[0];
982
+ c.f32[5] = b.f32[1];
983
+ c.f32[6] = b.f32[2];
984
+ c.f32[7] = b.f32[3];
985
+
986
+ return c;
987
+ }
988
+
989
+ simd8float32 gethigh128(const simd8float32& a, const simd8float32& b) {
990
+ simd8float32 c;
991
+
992
+ c.f32[0] = a.f32[4];
993
+ c.f32[1] = a.f32[5];
994
+ c.f32[2] = a.f32[6];
995
+ c.f32[3] = a.f32[7];
996
+
997
+ c.f32[4] = b.f32[4];
998
+ c.f32[5] = b.f32[5];
999
+ c.f32[6] = b.f32[6];
1000
+ c.f32[7] = b.f32[7];
1001
+
1002
+ return c;
1003
+ }
1004
+
1005
+ // The following primitive is a vectorized version of the following code
1006
+ // snippet:
1007
+ // float lowestValue = HUGE_VAL;
1008
+ // uint lowestIndex = 0;
1009
+ // for (size_t i = 0; i < n; i++) {
1010
+ // if (values[i] < lowestValue) {
1011
+ // lowestValue = values[i];
1012
+ // lowestIndex = i;
1013
+ // }
1014
+ // }
1015
+ // Vectorized version can be implemented via two operations: cmp and blend
1016
+ // with something like this:
1017
+ // lowestValues = [HUGE_VAL; 8];
1018
+ // lowestIndices = {0, 1, 2, 3, 4, 5, 6, 7};
1019
+ // for (size_t i = 0; i < n; i += 8) {
1020
+ // auto comparison = cmp(values + i, lowestValues);
1021
+ // lowestValues = blend(
1022
+ // comparison,
1023
+ // values + i,
1024
+ // lowestValues);
1025
+ // lowestIndices = blend(
1026
+ // comparison,
1027
+ // i + {0, 1, 2, 3, 4, 5, 6, 7},
1028
+ // lowestIndices);
1029
+ // lowestIndices += {8, 8, 8, 8, 8, 8, 8, 8};
1030
+ // }
1031
+ // The problem is that blend primitive needs very different instruction
1032
+ // order for AVX and ARM.
1033
+ // So, let's introduce a combination of these two in order to avoid
1034
+ // confusion for ppl who write in low-level SIMD instructions. Additionally,
1035
+ // these two ops (cmp and blend) are very often used together.
1036
+ inline void cmplt_and_blend_inplace(
1037
+ const simd8float32 candidateValues,
1038
+ const simd8uint32 candidateIndices,
1039
+ simd8float32& lowestValues,
1040
+ simd8uint32& lowestIndices) {
1041
+ for (size_t j = 0; j < 8; j++) {
1042
+ bool comparison = (candidateValues.f32[j] < lowestValues.f32[j]);
1043
+ if (comparison) {
1044
+ lowestValues.f32[j] = candidateValues.f32[j];
1045
+ lowestIndices.u32[j] = candidateIndices.u32[j];
1046
+ }
1047
+ }
1048
+ }
1049
+
1050
+ // Vectorized version of the following code:
1051
+ // for (size_t i = 0; i < n; i++) {
1052
+ // bool flag = (candidateValues[i] < currentValues[i]);
1053
+ // minValues[i] = flag ? candidateValues[i] : currentValues[i];
1054
+ // minIndices[i] = flag ? candidateIndices[i] : currentIndices[i];
1055
+ // maxValues[i] = !flag ? candidateValues[i] : currentValues[i];
1056
+ // maxIndices[i] = !flag ? candidateIndices[i] : currentIndices[i];
1057
+ // }
1058
+ // Max indices evaluation is inaccurate in case of equal values (the index of
1059
+ // the last equal value is saved instead of the first one), but this behavior
1060
+ // saves instructions.
1061
+ inline void cmplt_min_max_fast(
1062
+ const simd8float32 candidateValues,
1063
+ const simd8uint32 candidateIndices,
1064
+ const simd8float32 currentValues,
1065
+ const simd8uint32 currentIndices,
1066
+ simd8float32& minValues,
1067
+ simd8uint32& minIndices,
1068
+ simd8float32& maxValues,
1069
+ simd8uint32& maxIndices) {
1070
+ for (size_t i = 0; i < 8; i++) {
1071
+ bool flag = (candidateValues.f32[i] < currentValues.f32[i]);
1072
+ minValues.f32[i] = flag ? candidateValues.f32[i] : currentValues.f32[i];
1073
+ minIndices.u32[i] =
1074
+ flag ? candidateIndices.u32[i] : currentIndices.u32[i];
1075
+ maxValues.f32[i] =
1076
+ !flag ? candidateValues.f32[i] : currentValues.f32[i];
1077
+ maxIndices.u32[i] =
1078
+ !flag ? candidateIndices.u32[i] : currentIndices.u32[i];
1079
+ }
1080
+ }
1081
+
1082
+ } // namespace
1083
+
1084
+ } // namespace faiss