faiss 0.1.3 → 0.2.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (199) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +25 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +16 -4
  5. data/ext/faiss/ext.cpp +12 -308
  6. data/ext/faiss/extconf.rb +6 -3
  7. data/ext/faiss/index.cpp +189 -0
  8. data/ext/faiss/index_binary.cpp +75 -0
  9. data/ext/faiss/kmeans.cpp +40 -0
  10. data/ext/faiss/numo.hpp +867 -0
  11. data/ext/faiss/pca_matrix.cpp +33 -0
  12. data/ext/faiss/product_quantizer.cpp +53 -0
  13. data/ext/faiss/utils.cpp +13 -0
  14. data/ext/faiss/utils.h +5 -0
  15. data/lib/faiss.rb +0 -5
  16. data/lib/faiss/version.rb +1 -1
  17. data/vendor/faiss/faiss/AutoTune.cpp +36 -33
  18. data/vendor/faiss/faiss/AutoTune.h +6 -3
  19. data/vendor/faiss/faiss/Clustering.cpp +16 -12
  20. data/vendor/faiss/faiss/Index.cpp +3 -4
  21. data/vendor/faiss/faiss/Index.h +3 -3
  22. data/vendor/faiss/faiss/IndexBinary.cpp +3 -4
  23. data/vendor/faiss/faiss/IndexBinary.h +1 -1
  24. data/vendor/faiss/faiss/IndexBinaryHash.cpp +2 -12
  25. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +1 -2
  26. data/vendor/faiss/faiss/IndexFlat.cpp +0 -148
  27. data/vendor/faiss/faiss/IndexFlat.h +0 -51
  28. data/vendor/faiss/faiss/IndexHNSW.cpp +4 -5
  29. data/vendor/faiss/faiss/IndexIVF.cpp +118 -31
  30. data/vendor/faiss/faiss/IndexIVF.h +22 -15
  31. data/vendor/faiss/faiss/IndexIVFFlat.cpp +3 -3
  32. data/vendor/faiss/faiss/IndexIVFFlat.h +2 -1
  33. data/vendor/faiss/faiss/IndexIVFPQ.cpp +39 -15
  34. data/vendor/faiss/faiss/IndexIVFPQ.h +25 -9
  35. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +1116 -0
  36. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +166 -0
  37. data/vendor/faiss/faiss/IndexIVFPQR.cpp +8 -9
  38. data/vendor/faiss/faiss/IndexIVFPQR.h +2 -1
  39. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +1 -2
  40. data/vendor/faiss/faiss/IndexPQ.cpp +34 -18
  41. data/vendor/faiss/faiss/IndexPQFastScan.cpp +536 -0
  42. data/vendor/faiss/faiss/IndexPQFastScan.h +111 -0
  43. data/vendor/faiss/faiss/IndexPreTransform.cpp +47 -0
  44. data/vendor/faiss/faiss/IndexPreTransform.h +2 -0
  45. data/vendor/faiss/faiss/IndexRefine.cpp +256 -0
  46. data/vendor/faiss/faiss/IndexRefine.h +73 -0
  47. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +2 -2
  48. data/vendor/faiss/faiss/IndexScalarQuantizer.h +1 -1
  49. data/vendor/faiss/faiss/gpu/GpuDistance.h +1 -1
  50. data/vendor/faiss/faiss/gpu/GpuIndex.h +16 -9
  51. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +8 -1
  52. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +11 -11
  53. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +19 -2
  54. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +28 -2
  55. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +24 -14
  56. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +29 -2
  57. data/vendor/faiss/faiss/gpu/GpuResources.h +4 -0
  58. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +60 -27
  59. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +28 -6
  60. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +547 -0
  61. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +51 -0
  62. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +3 -3
  63. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +3 -2
  64. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +274 -0
  65. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +7 -2
  66. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +5 -1
  67. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +231 -0
  68. data/vendor/faiss/faiss/gpu/test/TestUtils.h +33 -0
  69. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +1 -0
  70. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +6 -0
  71. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +5 -6
  72. data/vendor/faiss/faiss/gpu/utils/Timer.h +2 -2
  73. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +5 -4
  74. data/vendor/faiss/faiss/impl/HNSW.cpp +2 -4
  75. data/vendor/faiss/faiss/impl/PolysemousTraining.h +4 -4
  76. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +22 -12
  77. data/vendor/faiss/faiss/impl/ProductQuantizer.h +2 -0
  78. data/vendor/faiss/faiss/impl/ResultHandler.h +452 -0
  79. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +29 -19
  80. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +6 -0
  81. data/vendor/faiss/faiss/impl/index_read.cpp +64 -96
  82. data/vendor/faiss/faiss/impl/index_write.cpp +34 -25
  83. data/vendor/faiss/faiss/impl/io.cpp +33 -2
  84. data/vendor/faiss/faiss/impl/io.h +7 -2
  85. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -15
  86. data/vendor/faiss/faiss/impl/platform_macros.h +44 -0
  87. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +272 -0
  88. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +169 -0
  89. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +180 -0
  90. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +354 -0
  91. data/vendor/faiss/faiss/impl/simd_result_handlers.h +559 -0
  92. data/vendor/faiss/faiss/index_factory.cpp +112 -7
  93. data/vendor/faiss/faiss/index_io.h +1 -48
  94. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +151 -0
  95. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +76 -0
  96. data/vendor/faiss/faiss/{DirectMap.cpp → invlists/DirectMap.cpp} +1 -1
  97. data/vendor/faiss/faiss/{DirectMap.h → invlists/DirectMap.h} +1 -1
  98. data/vendor/faiss/faiss/{InvertedLists.cpp → invlists/InvertedLists.cpp} +72 -1
  99. data/vendor/faiss/faiss/{InvertedLists.h → invlists/InvertedLists.h} +32 -1
  100. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +107 -0
  101. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +63 -0
  102. data/vendor/faiss/faiss/{OnDiskInvertedLists.cpp → invlists/OnDiskInvertedLists.cpp} +21 -6
  103. data/vendor/faiss/faiss/{OnDiskInvertedLists.h → invlists/OnDiskInvertedLists.h} +5 -2
  104. data/vendor/faiss/faiss/python/python_callbacks.h +8 -1
  105. data/vendor/faiss/faiss/utils/AlignedTable.h +141 -0
  106. data/vendor/faiss/faiss/utils/Heap.cpp +2 -4
  107. data/vendor/faiss/faiss/utils/Heap.h +61 -50
  108. data/vendor/faiss/faiss/utils/distances.cpp +164 -319
  109. data/vendor/faiss/faiss/utils/distances.h +28 -20
  110. data/vendor/faiss/faiss/utils/distances_simd.cpp +277 -49
  111. data/vendor/faiss/faiss/utils/extra_distances.cpp +1 -2
  112. data/vendor/faiss/faiss/utils/hamming-inl.h +4 -4
  113. data/vendor/faiss/faiss/utils/hamming.cpp +3 -6
  114. data/vendor/faiss/faiss/utils/hamming.h +2 -7
  115. data/vendor/faiss/faiss/utils/ordered_key_value.h +98 -0
  116. data/vendor/faiss/faiss/utils/partitioning.cpp +1256 -0
  117. data/vendor/faiss/faiss/utils/partitioning.h +69 -0
  118. data/vendor/faiss/faiss/utils/quantize_lut.cpp +277 -0
  119. data/vendor/faiss/faiss/utils/quantize_lut.h +80 -0
  120. data/vendor/faiss/faiss/utils/simdlib.h +31 -0
  121. data/vendor/faiss/faiss/utils/simdlib_avx2.h +461 -0
  122. data/vendor/faiss/faiss/utils/simdlib_emulated.h +589 -0
  123. metadata +54 -149
  124. data/lib/faiss/index.rb +0 -20
  125. data/lib/faiss/index_binary.rb +0 -20
  126. data/lib/faiss/kmeans.rb +0 -15
  127. data/lib/faiss/pca_matrix.rb +0 -15
  128. data/lib/faiss/product_quantizer.rb +0 -22
  129. data/vendor/faiss/benchs/bench_6bit_codec.cpp +0 -80
  130. data/vendor/faiss/c_api/AutoTune_c.cpp +0 -83
  131. data/vendor/faiss/c_api/AutoTune_c.h +0 -66
  132. data/vendor/faiss/c_api/Clustering_c.cpp +0 -145
  133. data/vendor/faiss/c_api/Clustering_c.h +0 -123
  134. data/vendor/faiss/c_api/IndexFlat_c.cpp +0 -140
  135. data/vendor/faiss/c_api/IndexFlat_c.h +0 -115
  136. data/vendor/faiss/c_api/IndexIVFFlat_c.cpp +0 -64
  137. data/vendor/faiss/c_api/IndexIVFFlat_c.h +0 -58
  138. data/vendor/faiss/c_api/IndexIVF_c.cpp +0 -99
  139. data/vendor/faiss/c_api/IndexIVF_c.h +0 -142
  140. data/vendor/faiss/c_api/IndexLSH_c.cpp +0 -37
  141. data/vendor/faiss/c_api/IndexLSH_c.h +0 -40
  142. data/vendor/faiss/c_api/IndexPreTransform_c.cpp +0 -21
  143. data/vendor/faiss/c_api/IndexPreTransform_c.h +0 -32
  144. data/vendor/faiss/c_api/IndexShards_c.cpp +0 -38
  145. data/vendor/faiss/c_api/IndexShards_c.h +0 -39
  146. data/vendor/faiss/c_api/Index_c.cpp +0 -105
  147. data/vendor/faiss/c_api/Index_c.h +0 -183
  148. data/vendor/faiss/c_api/MetaIndexes_c.cpp +0 -49
  149. data/vendor/faiss/c_api/MetaIndexes_c.h +0 -49
  150. data/vendor/faiss/c_api/clone_index_c.cpp +0 -23
  151. data/vendor/faiss/c_api/clone_index_c.h +0 -32
  152. data/vendor/faiss/c_api/error_c.h +0 -42
  153. data/vendor/faiss/c_api/error_impl.cpp +0 -27
  154. data/vendor/faiss/c_api/error_impl.h +0 -16
  155. data/vendor/faiss/c_api/faiss_c.h +0 -58
  156. data/vendor/faiss/c_api/gpu/GpuAutoTune_c.cpp +0 -98
  157. data/vendor/faiss/c_api/gpu/GpuAutoTune_c.h +0 -56
  158. data/vendor/faiss/c_api/gpu/GpuClonerOptions_c.cpp +0 -52
  159. data/vendor/faiss/c_api/gpu/GpuClonerOptions_c.h +0 -68
  160. data/vendor/faiss/c_api/gpu/GpuIndex_c.cpp +0 -17
  161. data/vendor/faiss/c_api/gpu/GpuIndex_c.h +0 -30
  162. data/vendor/faiss/c_api/gpu/GpuIndicesOptions_c.h +0 -38
  163. data/vendor/faiss/c_api/gpu/GpuResources_c.cpp +0 -86
  164. data/vendor/faiss/c_api/gpu/GpuResources_c.h +0 -66
  165. data/vendor/faiss/c_api/gpu/StandardGpuResources_c.cpp +0 -54
  166. data/vendor/faiss/c_api/gpu/StandardGpuResources_c.h +0 -53
  167. data/vendor/faiss/c_api/gpu/macros_impl.h +0 -42
  168. data/vendor/faiss/c_api/impl/AuxIndexStructures_c.cpp +0 -220
  169. data/vendor/faiss/c_api/impl/AuxIndexStructures_c.h +0 -149
  170. data/vendor/faiss/c_api/index_factory_c.cpp +0 -26
  171. data/vendor/faiss/c_api/index_factory_c.h +0 -30
  172. data/vendor/faiss/c_api/index_io_c.cpp +0 -42
  173. data/vendor/faiss/c_api/index_io_c.h +0 -50
  174. data/vendor/faiss/c_api/macros_impl.h +0 -110
  175. data/vendor/faiss/demos/demo_imi_flat.cpp +0 -154
  176. data/vendor/faiss/demos/demo_imi_pq.cpp +0 -203
  177. data/vendor/faiss/demos/demo_ivfpq_indexing.cpp +0 -151
  178. data/vendor/faiss/demos/demo_sift1M.cpp +0 -252
  179. data/vendor/faiss/demos/demo_weighted_kmeans.cpp +0 -185
  180. data/vendor/faiss/misc/test_blas.cpp +0 -87
  181. data/vendor/faiss/tests/test_binary_flat.cpp +0 -62
  182. data/vendor/faiss/tests/test_dealloc_invlists.cpp +0 -188
  183. data/vendor/faiss/tests/test_ivfpq_codec.cpp +0 -70
  184. data/vendor/faiss/tests/test_ivfpq_indexing.cpp +0 -100
  185. data/vendor/faiss/tests/test_lowlevel_ivf.cpp +0 -573
  186. data/vendor/faiss/tests/test_merge.cpp +0 -260
  187. data/vendor/faiss/tests/test_omp_threads.cpp +0 -14
  188. data/vendor/faiss/tests/test_ondisk_ivf.cpp +0 -225
  189. data/vendor/faiss/tests/test_pairs_decoding.cpp +0 -193
  190. data/vendor/faiss/tests/test_params_override.cpp +0 -236
  191. data/vendor/faiss/tests/test_pq_encoding.cpp +0 -98
  192. data/vendor/faiss/tests/test_sliding_ivf.cpp +0 -246
  193. data/vendor/faiss/tests/test_threaded_index.cpp +0 -253
  194. data/vendor/faiss/tests/test_transfer_invlists.cpp +0 -159
  195. data/vendor/faiss/tutorial/cpp/1-Flat.cpp +0 -104
  196. data/vendor/faiss/tutorial/cpp/2-IVFFlat.cpp +0 -85
  197. data/vendor/faiss/tutorial/cpp/3-IVFPQ.cpp +0 -98
  198. data/vendor/faiss/tutorial/cpp/4-GPU.cpp +0 -122
  199. data/vendor/faiss/tutorial/cpp/5-Multiple-GPUs.cpp +0 -104
@@ -0,0 +1,589 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ #pragma once
9
+
10
+ #include <string>
11
+ #include <cstdint>
12
+ #include <cstring>
13
+ #include <functional>
14
+ #include <algorithm>
15
+
16
+ namespace faiss {
17
+
18
+
19
+ struct simd256bit {
20
+
21
+ union {
22
+ uint8_t u8[32];
23
+ uint16_t u16[16];
24
+ uint32_t u32[8];
25
+ float f32[8];
26
+ };
27
+
28
+ simd256bit() {}
29
+
30
+ explicit simd256bit(const void *x)
31
+ {
32
+ memcpy(u8, x, 32);
33
+ }
34
+
35
+ void clear() {
36
+ memset(u8, 0, 32);
37
+ }
38
+
39
+ void storeu(void *ptr) const {
40
+ memcpy(ptr, u8, 32);
41
+ }
42
+
43
+ void loadu(const void *ptr) {
44
+ memcpy(u8, ptr, 32);
45
+ }
46
+
47
+ void store(void *ptr) const {
48
+ storeu(ptr);
49
+ }
50
+
51
+ void bin(char bits[257]) const {
52
+ const char *bytes = (char*)this->u8;
53
+ for (int i = 0; i < 256; i++) {
54
+ bits[i] = '0' + ((bytes[i / 8] >> (i % 8)) & 1);
55
+ }
56
+ bits[256] = 0;
57
+ }
58
+
59
+ std::string bin() const {
60
+ char bits[257];
61
+ bin(bits);
62
+ return std::string(bits);
63
+ }
64
+
65
+ };
66
+
67
+
68
+
69
+
70
+ /// vector of 16 elements in uint16
71
+ struct simd16uint16: simd256bit {
72
+ simd16uint16() {}
73
+
74
+ explicit simd16uint16(int x) {
75
+ set1(x);
76
+ }
77
+
78
+ explicit simd16uint16(uint16_t x) {
79
+ set1(x);
80
+ }
81
+
82
+ explicit simd16uint16(simd256bit x): simd256bit(x) {}
83
+
84
+ explicit simd16uint16(const uint16_t *x): simd256bit((const void*)x) {}
85
+
86
+ std::string elements_to_string(const char * fmt) const {
87
+ char res[1000], *ptr = res;
88
+ for(int i = 0; i < 16; i++) {
89
+ ptr += sprintf(ptr, fmt, u16[i]);
90
+ }
91
+ // strip last ,
92
+ ptr[-1] = 0;
93
+ return std::string(res);
94
+ }
95
+
96
+ std::string hex() const {
97
+ return elements_to_string("%02x,");
98
+ }
99
+
100
+ std::string dec() const {
101
+ return elements_to_string("%3d,");
102
+ }
103
+
104
+ static simd16uint16 unary_func(
105
+ simd16uint16 a, std::function<uint16_t (uint16_t)> f)
106
+ {
107
+ simd16uint16 c;
108
+ for(int j = 0; j < 16; j++) {
109
+ c.u16[j] = f(a.u16[j]);
110
+ }
111
+ return c;
112
+ }
113
+
114
+
115
+ static simd16uint16 binary_func(
116
+ simd16uint16 a, simd16uint16 b,
117
+ std::function<uint16_t (uint16_t, uint16_t)> f)
118
+ {
119
+ simd16uint16 c;
120
+ for(int j = 0; j < 16; j++) {
121
+ c.u16[j] = f(a.u16[j], b.u16[j]);
122
+ }
123
+ return c;
124
+ }
125
+
126
+ void set1(uint16_t x) {
127
+ for(int i = 0; i < 16; i++) {
128
+ u16[i] = x;
129
+ }
130
+ }
131
+
132
+ // shift must be known at compile time
133
+ simd16uint16 operator >> (const int shift) const {
134
+ return unary_func(*this, [shift](uint16_t a) {return a >> shift; });
135
+ }
136
+
137
+
138
+ // shift must be known at compile time
139
+ simd16uint16 operator << (const int shift) const {
140
+ return unary_func(*this, [shift](uint16_t a) {return a << shift; });
141
+ }
142
+
143
+ simd16uint16 operator += (simd16uint16 other) {
144
+ *this = *this + other;
145
+ return *this;
146
+ }
147
+
148
+ simd16uint16 operator -= (simd16uint16 other) {
149
+ *this = *this - other;
150
+ return *this;
151
+ }
152
+
153
+ simd16uint16 operator + (simd16uint16 other) const {
154
+ return binary_func(*this, other,
155
+ [](uint16_t a, uint16_t b) {return a + b; }
156
+ );
157
+ }
158
+
159
+ simd16uint16 operator - (simd16uint16 other) const {
160
+ return binary_func(*this, other,
161
+ [](uint16_t a, uint16_t b) {return a - b; }
162
+ );
163
+ }
164
+
165
+ simd16uint16 operator & (simd256bit other) const {
166
+ return binary_func(*this, simd16uint16(other),
167
+ [](uint16_t a, uint16_t b) {return a & b; }
168
+ );
169
+ }
170
+
171
+ simd16uint16 operator | (simd256bit other) const {
172
+ return binary_func(*this, simd16uint16(other),
173
+ [](uint16_t a, uint16_t b) {return a | b; }
174
+ );
175
+ }
176
+
177
+ // returns binary masks
178
+ simd16uint16 operator == (simd16uint16 other) const {
179
+ return binary_func(*this, other,
180
+ [](uint16_t a, uint16_t b) {return a == b ? 0xffff : 0; }
181
+ );
182
+ }
183
+
184
+ simd16uint16 operator ~() const {
185
+ return unary_func(*this, [](uint16_t a) {return ~a; });
186
+ }
187
+
188
+ // get scalar at index 0
189
+ uint16_t get_scalar_0() const {
190
+ return u16[0];
191
+ }
192
+
193
+ // mask of elements where this >= thresh
194
+ // 2 bit per component: 16 * 2 = 32 bit
195
+ uint32_t ge_mask(simd16uint16 thresh) const {
196
+ uint32_t gem = 0;
197
+ for(int j = 0; j < 16; j++) {
198
+ if (u16[j] >= thresh.u16[j]) {
199
+ gem |= 3 << (j * 2);
200
+ }
201
+ }
202
+ return gem;
203
+ }
204
+
205
+ uint32_t le_mask(simd16uint16 thresh) const {
206
+ return thresh.ge_mask(*this);
207
+ }
208
+
209
+ uint32_t gt_mask(simd16uint16 thresh) const {
210
+ return ~le_mask(thresh);
211
+ }
212
+
213
+ bool all_gt(simd16uint16 thresh) const {
214
+ return le_mask(thresh) == 0;
215
+ }
216
+
217
+ // for debugging only
218
+ uint16_t operator [] (int i) const {
219
+ return u16[i];
220
+ }
221
+
222
+ void accu_min(simd16uint16 incoming) {
223
+ for(int j = 0; j < 16; j++) {
224
+ if (incoming.u16[j] < u16[j]) {
225
+ u16[j] = incoming.u16[j];
226
+ }
227
+ }
228
+ }
229
+
230
+ void accu_max(simd16uint16 incoming) {
231
+ for(int j = 0; j < 16; j++) {
232
+ if (incoming.u16[j] > u16[j]) {
233
+ u16[j] = incoming.u16[j];
234
+ }
235
+ }
236
+ }
237
+
238
+ };
239
+
240
+
241
+ // not really a std::min because it returns an elementwise min
242
+ inline simd16uint16 min(simd16uint16 av, simd16uint16 bv) {
243
+ return simd16uint16::binary_func(av, bv,
244
+ [](uint16_t a, uint16_t b) {return std::min(a, b); }
245
+ );
246
+ }
247
+
248
+ inline simd16uint16 max(simd16uint16 av, simd16uint16 bv) {
249
+ return simd16uint16::binary_func(av, bv,
250
+ [](uint16_t a, uint16_t b) {return std::max(a, b); }
251
+ );
252
+ }
253
+
254
+ // decompose in 128-lanes: a = (a0, a1), b = (b0, b1)
255
+ // return (a0 + a1, b0 + b1)
256
+ // TODO find a better name
257
+ inline simd16uint16 combine2x2(simd16uint16 a, simd16uint16 b) {
258
+ simd16uint16 c;
259
+ for(int j = 0; j < 8; j++) {
260
+ c.u16[j] = a.u16[j] + a.u16[j + 8];
261
+ c.u16[j + 8] = b.u16[j] + b.u16[j + 8];
262
+ }
263
+ return c;
264
+ }
265
+
266
+ // compare d0 and d1 to thr, return 32 bits corresponding to the concatenation
267
+ // of d0 and d1 with thr
268
+ inline uint32_t cmp_ge32(simd16uint16 d0, simd16uint16 d1, simd16uint16 thr) {
269
+ uint32_t gem = 0;
270
+ for(int j = 0; j < 16; j++) {
271
+ if (d0.u16[j] >= thr.u16[j]) {
272
+ gem |= 1 << j;
273
+ }
274
+ if (d1.u16[j] >= thr.u16[j]) {
275
+ gem |= 1 << (j + 16);
276
+ }
277
+ }
278
+ return gem;
279
+ }
280
+
281
+
282
+ inline uint32_t cmp_le32(simd16uint16 d0, simd16uint16 d1, simd16uint16 thr) {
283
+ uint32_t gem = 0;
284
+ for(int j = 0; j < 16; j++) {
285
+ if (d0.u16[j] <= thr.u16[j]) {
286
+ gem |= 1 << j;
287
+ }
288
+ if (d1.u16[j] <= thr.u16[j]) {
289
+ gem |= 1 << (j + 16);
290
+ }
291
+ }
292
+ return gem;
293
+ }
294
+
295
+
296
+
297
+ // vector of 32 unsigned 8-bit integers
298
+ struct simd32uint8: simd256bit {
299
+
300
+ simd32uint8() {}
301
+
302
+ explicit simd32uint8(int x) {set1(x); }
303
+
304
+ explicit simd32uint8(uint8_t x) {set1(x); }
305
+
306
+ explicit simd32uint8(simd256bit x): simd256bit(x) {}
307
+
308
+ explicit simd32uint8(const uint8_t *x): simd256bit((const void*)x) {}
309
+
310
+ std::string elements_to_string(const char * fmt) const {
311
+ char res[1000], *ptr = res;
312
+ for(int i = 0; i < 32; i++) {
313
+ ptr += sprintf(ptr, fmt, u8[i]);
314
+ }
315
+ // strip last ,
316
+ ptr[-1] = 0;
317
+ return std::string(res);
318
+ }
319
+
320
+ std::string hex() const {
321
+ return elements_to_string("%02x,");
322
+ }
323
+
324
+ std::string dec() const {
325
+ return elements_to_string("%3d,");
326
+ }
327
+
328
+ void set1(uint8_t x) {
329
+ for(int j = 0; j < 32; j++) {
330
+ u8[j] = x;
331
+ }
332
+ }
333
+
334
+ static simd32uint8 binary_func(
335
+ simd32uint8 a, simd32uint8 b,
336
+ std::function<uint8_t (uint8_t, uint8_t)> f)
337
+ {
338
+ simd32uint8 c;
339
+ for(int j = 0; j < 32; j++) {
340
+ c.u8[j] = f(a.u8[j], b.u8[j]);
341
+ }
342
+ return c;
343
+ }
344
+
345
+
346
+ simd32uint8 operator & (simd256bit other) const {
347
+ return binary_func(*this, simd32uint8(other),
348
+ [](uint8_t a, uint8_t b) {return a & b; }
349
+ );
350
+ }
351
+
352
+ simd32uint8 operator + (simd32uint8 other) const {
353
+ return binary_func(*this, other,
354
+ [](uint8_t a, uint8_t b) {return a + b; }
355
+ );
356
+ }
357
+
358
+ // The very important operation that everything relies on
359
+ simd32uint8 lookup_2_lanes(simd32uint8 idx) const {
360
+ simd32uint8 c;
361
+ for(int j = 0; j < 32; j++) {
362
+ if (idx.u8[j] & 0x80) {
363
+ c.u8[j] = 0;
364
+ } else {
365
+ uint8_t i = idx.u8[j] & 15;
366
+ if (j < 16) {
367
+ c.u8[j] = u8[i];
368
+ } else {
369
+ c.u8[j] = u8[16 + i];
370
+ }
371
+ }
372
+ }
373
+ return c;
374
+ }
375
+
376
+ // extract + 0-extend lane
377
+ // this operation is slow (3 cycles)
378
+
379
+ simd32uint8 operator += (simd32uint8 other) {
380
+ *this = *this + other;
381
+ return *this;
382
+ }
383
+
384
+ // for debugging only
385
+ uint8_t operator [] (int i) const {
386
+ return u8[i];
387
+ }
388
+
389
+ };
390
+
391
+
392
+ // convert with saturation
393
+ // careful: this does not cross lanes, so the order is weird
394
+ inline simd32uint8 uint16_to_uint8_saturate(simd16uint16 a, simd16uint16 b) {
395
+ simd32uint8 c;
396
+
397
+ auto saturate_16_to_8 = [] (uint16_t x) {
398
+ return x >= 256 ? 0xff : x;
399
+ };
400
+
401
+ for (int i = 0; i < 8; i++) {
402
+ c.u8[ i] = saturate_16_to_8(a.u16[i]);
403
+ c.u8[8 + i] = saturate_16_to_8(b.u16[i]);
404
+ c.u8[16 + i] = saturate_16_to_8(a.u16[8 + i]);
405
+ c.u8[24 + i] = saturate_16_to_8(b.u16[8 + i]);
406
+ }
407
+ return c;
408
+ }
409
+
410
+ /// get most significant bit of each byte
411
+ inline uint32_t get_MSBs(simd32uint8 a) {
412
+ uint32_t res = 0;
413
+ for (int i = 0; i < 32; i++) {
414
+ if (a.u8[i] & 0x80) {
415
+ res |= 1 << i;
416
+ }
417
+ }
418
+ return res;
419
+ }
420
+
421
+ /// use MSB of each byte of mask to select a byte between a and b
422
+ inline simd32uint8 blendv(simd32uint8 a, simd32uint8 b, simd32uint8 mask) {
423
+ simd32uint8 c;
424
+ for (int i = 0; i < 32; i++) {
425
+ if (mask.u8[i] & 0x80) {
426
+ c.u8[i] = b.u8[i];
427
+ } else {
428
+ c.u8[i] = a.u8[i];
429
+ }
430
+ }
431
+ return c;
432
+ }
433
+
434
+
435
+
436
+
437
+ /// vector of 8 unsigned 32-bit integers
438
+ struct simd8uint32: simd256bit {
439
+ simd8uint32() {}
440
+
441
+
442
+ explicit simd8uint32(uint32_t x) {set1(x); }
443
+
444
+ explicit simd8uint32(simd256bit x): simd256bit(x) {}
445
+
446
+ explicit simd8uint32(const uint8_t *x): simd256bit((const void*)x) {}
447
+
448
+ std::string elements_to_string(const char * fmt) const {
449
+ char res[1000], *ptr = res;
450
+ for(int i = 0; i < 8; i++) {
451
+ ptr += sprintf(ptr, fmt, u32[i]);
452
+ }
453
+ // strip last ,
454
+ ptr[-1] = 0;
455
+ return std::string(res);
456
+ }
457
+
458
+ std::string hex() const {
459
+ return elements_to_string("%08x,");
460
+ }
461
+
462
+ std::string dec() const {
463
+ return elements_to_string("%10d,");
464
+ }
465
+
466
+ void set1(uint32_t x) {
467
+ for (int i = 0; i < 8; i++) {
468
+ u32[i] = x;
469
+ }
470
+ }
471
+
472
+ };
473
+
474
+ struct simd8float32: simd256bit {
475
+
476
+ simd8float32() {}
477
+
478
+ explicit simd8float32(simd256bit x): simd256bit(x) {}
479
+
480
+ explicit simd8float32(float x) {set1(x); }
481
+
482
+ explicit simd8float32(const float *x) {loadu((void*)x); }
483
+
484
+ void set1(float x) {
485
+ for(int i = 0; i < 8; i++) {
486
+ f32[i] = x;
487
+ }
488
+ }
489
+
490
+ static simd8float32 binary_func(
491
+ simd8float32 a, simd8float32 b,
492
+ std::function<float (float, float)> f)
493
+ {
494
+ simd8float32 c;
495
+ for(int j = 0; j < 8; j++) {
496
+ c.f32[j] = f(a.f32[j], b.f32[j]);
497
+ }
498
+ return c;
499
+ }
500
+
501
+ simd8float32 operator * (simd8float32 other) const {
502
+ return binary_func(*this, other,
503
+ [](float a, float b) {return a * b; }
504
+ );
505
+ }
506
+
507
+ simd8float32 operator + (simd8float32 other) const {
508
+ return binary_func(*this, other,
509
+ [](float a, float b) {return a + b; }
510
+ );
511
+ }
512
+
513
+ simd8float32 operator - (simd8float32 other) const {
514
+ return binary_func(*this, other,
515
+ [](float a, float b) {return a - b; }
516
+ );
517
+ }
518
+
519
+ std::string tostring() const {
520
+ char res[1000], *ptr = res;
521
+ for(int i = 0; i < 8; i++) {
522
+ ptr += sprintf(ptr, "%g,", f32[i]);
523
+ }
524
+ // strip last ,
525
+ ptr[-1] = 0;
526
+ return std::string(res);
527
+ }
528
+
529
+ };
530
+
531
+
532
+ // hadd does not cross lanes
533
+ inline simd8float32 hadd(simd8float32 a, simd8float32 b) {
534
+ simd8float32 c;
535
+ c.f32[0] = a.f32[0] + a.f32[1];
536
+ c.f32[1] = a.f32[2] + a.f32[3];
537
+ c.f32[2] = b.f32[0] + b.f32[1];
538
+ c.f32[3] = b.f32[2] + b.f32[3];
539
+
540
+ c.f32[4] = a.f32[4] + a.f32[5];
541
+ c.f32[5] = a.f32[6] + a.f32[7];
542
+ c.f32[6] = b.f32[4] + b.f32[5];
543
+ c.f32[7] = b.f32[6] + b.f32[7];
544
+
545
+ return c;
546
+ }
547
+
548
+ inline simd8float32 unpacklo(simd8float32 a, simd8float32 b) {
549
+ simd8float32 c;
550
+ c.f32[0] = a.f32[0];
551
+ c.f32[1] = b.f32[0];
552
+ c.f32[2] = a.f32[1];
553
+ c.f32[3] = b.f32[1];
554
+
555
+ c.f32[4] = a.f32[4];
556
+ c.f32[5] = b.f32[4];
557
+ c.f32[6] = a.f32[5];
558
+ c.f32[7] = b.f32[5];
559
+
560
+ return c;
561
+ }
562
+
563
+ inline simd8float32 unpackhi(simd8float32 a, simd8float32 b) {
564
+ simd8float32 c;
565
+ c.f32[0] = a.f32[2];
566
+ c.f32[1] = b.f32[2];
567
+ c.f32[2] = a.f32[3];
568
+ c.f32[3] = b.f32[3];
569
+
570
+ c.f32[4] = a.f32[6];
571
+ c.f32[5] = b.f32[6];
572
+ c.f32[6] = a.f32[7];
573
+ c.f32[7] = b.f32[7];
574
+
575
+ return c;
576
+ }
577
+
578
+ // compute a * b + c
579
+ inline simd8float32 fmadd(simd8float32 a, simd8float32 b, simd8float32 c) {
580
+ simd8float32 res;
581
+ for(int i = 0; i < 8; i++) {
582
+ res.f32[i] = a.f32[i] * b.f32[i] + c.f32[i];
583
+ }
584
+ return res;
585
+ }
586
+
587
+
588
+
589
+ } // namespace faiss