faiss 0.1.3 → 0.2.0

Sign up to get free protection for your applications and to get access to all the features.
Files changed (199) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +25 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +16 -4
  5. data/ext/faiss/ext.cpp +12 -308
  6. data/ext/faiss/extconf.rb +6 -3
  7. data/ext/faiss/index.cpp +189 -0
  8. data/ext/faiss/index_binary.cpp +75 -0
  9. data/ext/faiss/kmeans.cpp +40 -0
  10. data/ext/faiss/numo.hpp +867 -0
  11. data/ext/faiss/pca_matrix.cpp +33 -0
  12. data/ext/faiss/product_quantizer.cpp +53 -0
  13. data/ext/faiss/utils.cpp +13 -0
  14. data/ext/faiss/utils.h +5 -0
  15. data/lib/faiss.rb +0 -5
  16. data/lib/faiss/version.rb +1 -1
  17. data/vendor/faiss/faiss/AutoTune.cpp +36 -33
  18. data/vendor/faiss/faiss/AutoTune.h +6 -3
  19. data/vendor/faiss/faiss/Clustering.cpp +16 -12
  20. data/vendor/faiss/faiss/Index.cpp +3 -4
  21. data/vendor/faiss/faiss/Index.h +3 -3
  22. data/vendor/faiss/faiss/IndexBinary.cpp +3 -4
  23. data/vendor/faiss/faiss/IndexBinary.h +1 -1
  24. data/vendor/faiss/faiss/IndexBinaryHash.cpp +2 -12
  25. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +1 -2
  26. data/vendor/faiss/faiss/IndexFlat.cpp +0 -148
  27. data/vendor/faiss/faiss/IndexFlat.h +0 -51
  28. data/vendor/faiss/faiss/IndexHNSW.cpp +4 -5
  29. data/vendor/faiss/faiss/IndexIVF.cpp +118 -31
  30. data/vendor/faiss/faiss/IndexIVF.h +22 -15
  31. data/vendor/faiss/faiss/IndexIVFFlat.cpp +3 -3
  32. data/vendor/faiss/faiss/IndexIVFFlat.h +2 -1
  33. data/vendor/faiss/faiss/IndexIVFPQ.cpp +39 -15
  34. data/vendor/faiss/faiss/IndexIVFPQ.h +25 -9
  35. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +1116 -0
  36. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +166 -0
  37. data/vendor/faiss/faiss/IndexIVFPQR.cpp +8 -9
  38. data/vendor/faiss/faiss/IndexIVFPQR.h +2 -1
  39. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +1 -2
  40. data/vendor/faiss/faiss/IndexPQ.cpp +34 -18
  41. data/vendor/faiss/faiss/IndexPQFastScan.cpp +536 -0
  42. data/vendor/faiss/faiss/IndexPQFastScan.h +111 -0
  43. data/vendor/faiss/faiss/IndexPreTransform.cpp +47 -0
  44. data/vendor/faiss/faiss/IndexPreTransform.h +2 -0
  45. data/vendor/faiss/faiss/IndexRefine.cpp +256 -0
  46. data/vendor/faiss/faiss/IndexRefine.h +73 -0
  47. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +2 -2
  48. data/vendor/faiss/faiss/IndexScalarQuantizer.h +1 -1
  49. data/vendor/faiss/faiss/gpu/GpuDistance.h +1 -1
  50. data/vendor/faiss/faiss/gpu/GpuIndex.h +16 -9
  51. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +8 -1
  52. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +11 -11
  53. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +19 -2
  54. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +28 -2
  55. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +24 -14
  56. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +29 -2
  57. data/vendor/faiss/faiss/gpu/GpuResources.h +4 -0
  58. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +60 -27
  59. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +28 -6
  60. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +547 -0
  61. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +51 -0
  62. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +3 -3
  63. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +3 -2
  64. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +274 -0
  65. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +7 -2
  66. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +5 -1
  67. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +231 -0
  68. data/vendor/faiss/faiss/gpu/test/TestUtils.h +33 -0
  69. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +1 -0
  70. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +6 -0
  71. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +5 -6
  72. data/vendor/faiss/faiss/gpu/utils/Timer.h +2 -2
  73. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +5 -4
  74. data/vendor/faiss/faiss/impl/HNSW.cpp +2 -4
  75. data/vendor/faiss/faiss/impl/PolysemousTraining.h +4 -4
  76. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +22 -12
  77. data/vendor/faiss/faiss/impl/ProductQuantizer.h +2 -0
  78. data/vendor/faiss/faiss/impl/ResultHandler.h +452 -0
  79. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +29 -19
  80. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +6 -0
  81. data/vendor/faiss/faiss/impl/index_read.cpp +64 -96
  82. data/vendor/faiss/faiss/impl/index_write.cpp +34 -25
  83. data/vendor/faiss/faiss/impl/io.cpp +33 -2
  84. data/vendor/faiss/faiss/impl/io.h +7 -2
  85. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -15
  86. data/vendor/faiss/faiss/impl/platform_macros.h +44 -0
  87. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +272 -0
  88. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +169 -0
  89. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +180 -0
  90. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +354 -0
  91. data/vendor/faiss/faiss/impl/simd_result_handlers.h +559 -0
  92. data/vendor/faiss/faiss/index_factory.cpp +112 -7
  93. data/vendor/faiss/faiss/index_io.h +1 -48
  94. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +151 -0
  95. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +76 -0
  96. data/vendor/faiss/faiss/{DirectMap.cpp → invlists/DirectMap.cpp} +1 -1
  97. data/vendor/faiss/faiss/{DirectMap.h → invlists/DirectMap.h} +1 -1
  98. data/vendor/faiss/faiss/{InvertedLists.cpp → invlists/InvertedLists.cpp} +72 -1
  99. data/vendor/faiss/faiss/{InvertedLists.h → invlists/InvertedLists.h} +32 -1
  100. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +107 -0
  101. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +63 -0
  102. data/vendor/faiss/faiss/{OnDiskInvertedLists.cpp → invlists/OnDiskInvertedLists.cpp} +21 -6
  103. data/vendor/faiss/faiss/{OnDiskInvertedLists.h → invlists/OnDiskInvertedLists.h} +5 -2
  104. data/vendor/faiss/faiss/python/python_callbacks.h +8 -1
  105. data/vendor/faiss/faiss/utils/AlignedTable.h +141 -0
  106. data/vendor/faiss/faiss/utils/Heap.cpp +2 -4
  107. data/vendor/faiss/faiss/utils/Heap.h +61 -50
  108. data/vendor/faiss/faiss/utils/distances.cpp +164 -319
  109. data/vendor/faiss/faiss/utils/distances.h +28 -20
  110. data/vendor/faiss/faiss/utils/distances_simd.cpp +277 -49
  111. data/vendor/faiss/faiss/utils/extra_distances.cpp +1 -2
  112. data/vendor/faiss/faiss/utils/hamming-inl.h +4 -4
  113. data/vendor/faiss/faiss/utils/hamming.cpp +3 -6
  114. data/vendor/faiss/faiss/utils/hamming.h +2 -7
  115. data/vendor/faiss/faiss/utils/ordered_key_value.h +98 -0
  116. data/vendor/faiss/faiss/utils/partitioning.cpp +1256 -0
  117. data/vendor/faiss/faiss/utils/partitioning.h +69 -0
  118. data/vendor/faiss/faiss/utils/quantize_lut.cpp +277 -0
  119. data/vendor/faiss/faiss/utils/quantize_lut.h +80 -0
  120. data/vendor/faiss/faiss/utils/simdlib.h +31 -0
  121. data/vendor/faiss/faiss/utils/simdlib_avx2.h +461 -0
  122. data/vendor/faiss/faiss/utils/simdlib_emulated.h +589 -0
  123. metadata +54 -149
  124. data/lib/faiss/index.rb +0 -20
  125. data/lib/faiss/index_binary.rb +0 -20
  126. data/lib/faiss/kmeans.rb +0 -15
  127. data/lib/faiss/pca_matrix.rb +0 -15
  128. data/lib/faiss/product_quantizer.rb +0 -22
  129. data/vendor/faiss/benchs/bench_6bit_codec.cpp +0 -80
  130. data/vendor/faiss/c_api/AutoTune_c.cpp +0 -83
  131. data/vendor/faiss/c_api/AutoTune_c.h +0 -66
  132. data/vendor/faiss/c_api/Clustering_c.cpp +0 -145
  133. data/vendor/faiss/c_api/Clustering_c.h +0 -123
  134. data/vendor/faiss/c_api/IndexFlat_c.cpp +0 -140
  135. data/vendor/faiss/c_api/IndexFlat_c.h +0 -115
  136. data/vendor/faiss/c_api/IndexIVFFlat_c.cpp +0 -64
  137. data/vendor/faiss/c_api/IndexIVFFlat_c.h +0 -58
  138. data/vendor/faiss/c_api/IndexIVF_c.cpp +0 -99
  139. data/vendor/faiss/c_api/IndexIVF_c.h +0 -142
  140. data/vendor/faiss/c_api/IndexLSH_c.cpp +0 -37
  141. data/vendor/faiss/c_api/IndexLSH_c.h +0 -40
  142. data/vendor/faiss/c_api/IndexPreTransform_c.cpp +0 -21
  143. data/vendor/faiss/c_api/IndexPreTransform_c.h +0 -32
  144. data/vendor/faiss/c_api/IndexShards_c.cpp +0 -38
  145. data/vendor/faiss/c_api/IndexShards_c.h +0 -39
  146. data/vendor/faiss/c_api/Index_c.cpp +0 -105
  147. data/vendor/faiss/c_api/Index_c.h +0 -183
  148. data/vendor/faiss/c_api/MetaIndexes_c.cpp +0 -49
  149. data/vendor/faiss/c_api/MetaIndexes_c.h +0 -49
  150. data/vendor/faiss/c_api/clone_index_c.cpp +0 -23
  151. data/vendor/faiss/c_api/clone_index_c.h +0 -32
  152. data/vendor/faiss/c_api/error_c.h +0 -42
  153. data/vendor/faiss/c_api/error_impl.cpp +0 -27
  154. data/vendor/faiss/c_api/error_impl.h +0 -16
  155. data/vendor/faiss/c_api/faiss_c.h +0 -58
  156. data/vendor/faiss/c_api/gpu/GpuAutoTune_c.cpp +0 -98
  157. data/vendor/faiss/c_api/gpu/GpuAutoTune_c.h +0 -56
  158. data/vendor/faiss/c_api/gpu/GpuClonerOptions_c.cpp +0 -52
  159. data/vendor/faiss/c_api/gpu/GpuClonerOptions_c.h +0 -68
  160. data/vendor/faiss/c_api/gpu/GpuIndex_c.cpp +0 -17
  161. data/vendor/faiss/c_api/gpu/GpuIndex_c.h +0 -30
  162. data/vendor/faiss/c_api/gpu/GpuIndicesOptions_c.h +0 -38
  163. data/vendor/faiss/c_api/gpu/GpuResources_c.cpp +0 -86
  164. data/vendor/faiss/c_api/gpu/GpuResources_c.h +0 -66
  165. data/vendor/faiss/c_api/gpu/StandardGpuResources_c.cpp +0 -54
  166. data/vendor/faiss/c_api/gpu/StandardGpuResources_c.h +0 -53
  167. data/vendor/faiss/c_api/gpu/macros_impl.h +0 -42
  168. data/vendor/faiss/c_api/impl/AuxIndexStructures_c.cpp +0 -220
  169. data/vendor/faiss/c_api/impl/AuxIndexStructures_c.h +0 -149
  170. data/vendor/faiss/c_api/index_factory_c.cpp +0 -26
  171. data/vendor/faiss/c_api/index_factory_c.h +0 -30
  172. data/vendor/faiss/c_api/index_io_c.cpp +0 -42
  173. data/vendor/faiss/c_api/index_io_c.h +0 -50
  174. data/vendor/faiss/c_api/macros_impl.h +0 -110
  175. data/vendor/faiss/demos/demo_imi_flat.cpp +0 -154
  176. data/vendor/faiss/demos/demo_imi_pq.cpp +0 -203
  177. data/vendor/faiss/demos/demo_ivfpq_indexing.cpp +0 -151
  178. data/vendor/faiss/demos/demo_sift1M.cpp +0 -252
  179. data/vendor/faiss/demos/demo_weighted_kmeans.cpp +0 -185
  180. data/vendor/faiss/misc/test_blas.cpp +0 -87
  181. data/vendor/faiss/tests/test_binary_flat.cpp +0 -62
  182. data/vendor/faiss/tests/test_dealloc_invlists.cpp +0 -188
  183. data/vendor/faiss/tests/test_ivfpq_codec.cpp +0 -70
  184. data/vendor/faiss/tests/test_ivfpq_indexing.cpp +0 -100
  185. data/vendor/faiss/tests/test_lowlevel_ivf.cpp +0 -573
  186. data/vendor/faiss/tests/test_merge.cpp +0 -260
  187. data/vendor/faiss/tests/test_omp_threads.cpp +0 -14
  188. data/vendor/faiss/tests/test_ondisk_ivf.cpp +0 -225
  189. data/vendor/faiss/tests/test_pairs_decoding.cpp +0 -193
  190. data/vendor/faiss/tests/test_params_override.cpp +0 -236
  191. data/vendor/faiss/tests/test_pq_encoding.cpp +0 -98
  192. data/vendor/faiss/tests/test_sliding_ivf.cpp +0 -246
  193. data/vendor/faiss/tests/test_threaded_index.cpp +0 -253
  194. data/vendor/faiss/tests/test_transfer_invlists.cpp +0 -159
  195. data/vendor/faiss/tutorial/cpp/1-Flat.cpp +0 -104
  196. data/vendor/faiss/tutorial/cpp/2-IVFFlat.cpp +0 -85
  197. data/vendor/faiss/tutorial/cpp/3-IVFPQ.cpp +0 -98
  198. data/vendor/faiss/tutorial/cpp/4-GPU.cpp +0 -122
  199. data/vendor/faiss/tutorial/cpp/5-Multiple-GPUs.cpp +0 -104
@@ -0,0 +1,589 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ #pragma once
9
+
10
+ #include <string>
11
+ #include <cstdint>
12
+ #include <cstring>
13
+ #include <functional>
14
+ #include <algorithm>
15
+
16
+ namespace faiss {
17
+
18
+
19
+ struct simd256bit {
20
+
21
+ union {
22
+ uint8_t u8[32];
23
+ uint16_t u16[16];
24
+ uint32_t u32[8];
25
+ float f32[8];
26
+ };
27
+
28
+ simd256bit() {}
29
+
30
+ explicit simd256bit(const void *x)
31
+ {
32
+ memcpy(u8, x, 32);
33
+ }
34
+
35
+ void clear() {
36
+ memset(u8, 0, 32);
37
+ }
38
+
39
+ void storeu(void *ptr) const {
40
+ memcpy(ptr, u8, 32);
41
+ }
42
+
43
+ void loadu(const void *ptr) {
44
+ memcpy(u8, ptr, 32);
45
+ }
46
+
47
+ void store(void *ptr) const {
48
+ storeu(ptr);
49
+ }
50
+
51
+ void bin(char bits[257]) const {
52
+ const char *bytes = (char*)this->u8;
53
+ for (int i = 0; i < 256; i++) {
54
+ bits[i] = '0' + ((bytes[i / 8] >> (i % 8)) & 1);
55
+ }
56
+ bits[256] = 0;
57
+ }
58
+
59
+ std::string bin() const {
60
+ char bits[257];
61
+ bin(bits);
62
+ return std::string(bits);
63
+ }
64
+
65
+ };
66
+
67
+
68
+
69
+
70
+ /// vector of 16 elements in uint16
71
+ struct simd16uint16: simd256bit {
72
+ simd16uint16() {}
73
+
74
+ explicit simd16uint16(int x) {
75
+ set1(x);
76
+ }
77
+
78
+ explicit simd16uint16(uint16_t x) {
79
+ set1(x);
80
+ }
81
+
82
+ explicit simd16uint16(simd256bit x): simd256bit(x) {}
83
+
84
+ explicit simd16uint16(const uint16_t *x): simd256bit((const void*)x) {}
85
+
86
+ std::string elements_to_string(const char * fmt) const {
87
+ char res[1000], *ptr = res;
88
+ for(int i = 0; i < 16; i++) {
89
+ ptr += sprintf(ptr, fmt, u16[i]);
90
+ }
91
+ // strip last ,
92
+ ptr[-1] = 0;
93
+ return std::string(res);
94
+ }
95
+
96
+ std::string hex() const {
97
+ return elements_to_string("%02x,");
98
+ }
99
+
100
+ std::string dec() const {
101
+ return elements_to_string("%3d,");
102
+ }
103
+
104
+ static simd16uint16 unary_func(
105
+ simd16uint16 a, std::function<uint16_t (uint16_t)> f)
106
+ {
107
+ simd16uint16 c;
108
+ for(int j = 0; j < 16; j++) {
109
+ c.u16[j] = f(a.u16[j]);
110
+ }
111
+ return c;
112
+ }
113
+
114
+
115
+ static simd16uint16 binary_func(
116
+ simd16uint16 a, simd16uint16 b,
117
+ std::function<uint16_t (uint16_t, uint16_t)> f)
118
+ {
119
+ simd16uint16 c;
120
+ for(int j = 0; j < 16; j++) {
121
+ c.u16[j] = f(a.u16[j], b.u16[j]);
122
+ }
123
+ return c;
124
+ }
125
+
126
+ void set1(uint16_t x) {
127
+ for(int i = 0; i < 16; i++) {
128
+ u16[i] = x;
129
+ }
130
+ }
131
+
132
+ // shift must be known at compile time
133
+ simd16uint16 operator >> (const int shift) const {
134
+ return unary_func(*this, [shift](uint16_t a) {return a >> shift; });
135
+ }
136
+
137
+
138
+ // shift must be known at compile time
139
+ simd16uint16 operator << (const int shift) const {
140
+ return unary_func(*this, [shift](uint16_t a) {return a << shift; });
141
+ }
142
+
143
+ simd16uint16 operator += (simd16uint16 other) {
144
+ *this = *this + other;
145
+ return *this;
146
+ }
147
+
148
+ simd16uint16 operator -= (simd16uint16 other) {
149
+ *this = *this - other;
150
+ return *this;
151
+ }
152
+
153
+ simd16uint16 operator + (simd16uint16 other) const {
154
+ return binary_func(*this, other,
155
+ [](uint16_t a, uint16_t b) {return a + b; }
156
+ );
157
+ }
158
+
159
+ simd16uint16 operator - (simd16uint16 other) const {
160
+ return binary_func(*this, other,
161
+ [](uint16_t a, uint16_t b) {return a - b; }
162
+ );
163
+ }
164
+
165
+ simd16uint16 operator & (simd256bit other) const {
166
+ return binary_func(*this, simd16uint16(other),
167
+ [](uint16_t a, uint16_t b) {return a & b; }
168
+ );
169
+ }
170
+
171
+ simd16uint16 operator | (simd256bit other) const {
172
+ return binary_func(*this, simd16uint16(other),
173
+ [](uint16_t a, uint16_t b) {return a | b; }
174
+ );
175
+ }
176
+
177
+ // returns binary masks
178
+ simd16uint16 operator == (simd16uint16 other) const {
179
+ return binary_func(*this, other,
180
+ [](uint16_t a, uint16_t b) {return a == b ? 0xffff : 0; }
181
+ );
182
+ }
183
+
184
+ simd16uint16 operator ~() const {
185
+ return unary_func(*this, [](uint16_t a) {return ~a; });
186
+ }
187
+
188
+ // get scalar at index 0
189
+ uint16_t get_scalar_0() const {
190
+ return u16[0];
191
+ }
192
+
193
+ // mask of elements where this >= thresh
194
+ // 2 bit per component: 16 * 2 = 32 bit
195
+ uint32_t ge_mask(simd16uint16 thresh) const {
196
+ uint32_t gem = 0;
197
+ for(int j = 0; j < 16; j++) {
198
+ if (u16[j] >= thresh.u16[j]) {
199
+ gem |= 3 << (j * 2);
200
+ }
201
+ }
202
+ return gem;
203
+ }
204
+
205
+ uint32_t le_mask(simd16uint16 thresh) const {
206
+ return thresh.ge_mask(*this);
207
+ }
208
+
209
+ uint32_t gt_mask(simd16uint16 thresh) const {
210
+ return ~le_mask(thresh);
211
+ }
212
+
213
+ bool all_gt(simd16uint16 thresh) const {
214
+ return le_mask(thresh) == 0;
215
+ }
216
+
217
+ // for debugging only
218
+ uint16_t operator [] (int i) const {
219
+ return u16[i];
220
+ }
221
+
222
+ void accu_min(simd16uint16 incoming) {
223
+ for(int j = 0; j < 16; j++) {
224
+ if (incoming.u16[j] < u16[j]) {
225
+ u16[j] = incoming.u16[j];
226
+ }
227
+ }
228
+ }
229
+
230
+ void accu_max(simd16uint16 incoming) {
231
+ for(int j = 0; j < 16; j++) {
232
+ if (incoming.u16[j] > u16[j]) {
233
+ u16[j] = incoming.u16[j];
234
+ }
235
+ }
236
+ }
237
+
238
+ };
239
+
240
+
241
+ // not really a std::min because it returns an elementwise min
242
+ inline simd16uint16 min(simd16uint16 av, simd16uint16 bv) {
243
+ return simd16uint16::binary_func(av, bv,
244
+ [](uint16_t a, uint16_t b) {return std::min(a, b); }
245
+ );
246
+ }
247
+
248
+ inline simd16uint16 max(simd16uint16 av, simd16uint16 bv) {
249
+ return simd16uint16::binary_func(av, bv,
250
+ [](uint16_t a, uint16_t b) {return std::max(a, b); }
251
+ );
252
+ }
253
+
254
+ // decompose in 128-lanes: a = (a0, a1), b = (b0, b1)
255
+ // return (a0 + a1, b0 + b1)
256
+ // TODO find a better name
257
+ inline simd16uint16 combine2x2(simd16uint16 a, simd16uint16 b) {
258
+ simd16uint16 c;
259
+ for(int j = 0; j < 8; j++) {
260
+ c.u16[j] = a.u16[j] + a.u16[j + 8];
261
+ c.u16[j + 8] = b.u16[j] + b.u16[j + 8];
262
+ }
263
+ return c;
264
+ }
265
+
266
+ // compare d0 and d1 to thr, return 32 bits corresponding to the concatenation
267
+ // of d0 and d1 with thr
268
+ inline uint32_t cmp_ge32(simd16uint16 d0, simd16uint16 d1, simd16uint16 thr) {
269
+ uint32_t gem = 0;
270
+ for(int j = 0; j < 16; j++) {
271
+ if (d0.u16[j] >= thr.u16[j]) {
272
+ gem |= 1 << j;
273
+ }
274
+ if (d1.u16[j] >= thr.u16[j]) {
275
+ gem |= 1 << (j + 16);
276
+ }
277
+ }
278
+ return gem;
279
+ }
280
+
281
+
282
+ inline uint32_t cmp_le32(simd16uint16 d0, simd16uint16 d1, simd16uint16 thr) {
283
+ uint32_t gem = 0;
284
+ for(int j = 0; j < 16; j++) {
285
+ if (d0.u16[j] <= thr.u16[j]) {
286
+ gem |= 1 << j;
287
+ }
288
+ if (d1.u16[j] <= thr.u16[j]) {
289
+ gem |= 1 << (j + 16);
290
+ }
291
+ }
292
+ return gem;
293
+ }
294
+
295
+
296
+
297
+ // vector of 32 unsigned 8-bit integers
298
+ struct simd32uint8: simd256bit {
299
+
300
+ simd32uint8() {}
301
+
302
+ explicit simd32uint8(int x) {set1(x); }
303
+
304
+ explicit simd32uint8(uint8_t x) {set1(x); }
305
+
306
+ explicit simd32uint8(simd256bit x): simd256bit(x) {}
307
+
308
+ explicit simd32uint8(const uint8_t *x): simd256bit((const void*)x) {}
309
+
310
+ std::string elements_to_string(const char * fmt) const {
311
+ char res[1000], *ptr = res;
312
+ for(int i = 0; i < 32; i++) {
313
+ ptr += sprintf(ptr, fmt, u8[i]);
314
+ }
315
+ // strip last ,
316
+ ptr[-1] = 0;
317
+ return std::string(res);
318
+ }
319
+
320
+ std::string hex() const {
321
+ return elements_to_string("%02x,");
322
+ }
323
+
324
+ std::string dec() const {
325
+ return elements_to_string("%3d,");
326
+ }
327
+
328
+ void set1(uint8_t x) {
329
+ for(int j = 0; j < 32; j++) {
330
+ u8[j] = x;
331
+ }
332
+ }
333
+
334
+ static simd32uint8 binary_func(
335
+ simd32uint8 a, simd32uint8 b,
336
+ std::function<uint8_t (uint8_t, uint8_t)> f)
337
+ {
338
+ simd32uint8 c;
339
+ for(int j = 0; j < 32; j++) {
340
+ c.u8[j] = f(a.u8[j], b.u8[j]);
341
+ }
342
+ return c;
343
+ }
344
+
345
+
346
+ simd32uint8 operator & (simd256bit other) const {
347
+ return binary_func(*this, simd32uint8(other),
348
+ [](uint8_t a, uint8_t b) {return a & b; }
349
+ );
350
+ }
351
+
352
+ simd32uint8 operator + (simd32uint8 other) const {
353
+ return binary_func(*this, other,
354
+ [](uint8_t a, uint8_t b) {return a + b; }
355
+ );
356
+ }
357
+
358
+ // The very important operation that everything relies on
359
+ simd32uint8 lookup_2_lanes(simd32uint8 idx) const {
360
+ simd32uint8 c;
361
+ for(int j = 0; j < 32; j++) {
362
+ if (idx.u8[j] & 0x80) {
363
+ c.u8[j] = 0;
364
+ } else {
365
+ uint8_t i = idx.u8[j] & 15;
366
+ if (j < 16) {
367
+ c.u8[j] = u8[i];
368
+ } else {
369
+ c.u8[j] = u8[16 + i];
370
+ }
371
+ }
372
+ }
373
+ return c;
374
+ }
375
+
376
+ // extract + 0-extend lane
377
+ // this operation is slow (3 cycles)
378
+
379
+ simd32uint8 operator += (simd32uint8 other) {
380
+ *this = *this + other;
381
+ return *this;
382
+ }
383
+
384
+ // for debugging only
385
+ uint8_t operator [] (int i) const {
386
+ return u8[i];
387
+ }
388
+
389
+ };
390
+
391
+
392
+ // convert with saturation
393
+ // careful: this does not cross lanes, so the order is weird
394
+ inline simd32uint8 uint16_to_uint8_saturate(simd16uint16 a, simd16uint16 b) {
395
+ simd32uint8 c;
396
+
397
+ auto saturate_16_to_8 = [] (uint16_t x) {
398
+ return x >= 256 ? 0xff : x;
399
+ };
400
+
401
+ for (int i = 0; i < 8; i++) {
402
+ c.u8[ i] = saturate_16_to_8(a.u16[i]);
403
+ c.u8[8 + i] = saturate_16_to_8(b.u16[i]);
404
+ c.u8[16 + i] = saturate_16_to_8(a.u16[8 + i]);
405
+ c.u8[24 + i] = saturate_16_to_8(b.u16[8 + i]);
406
+ }
407
+ return c;
408
+ }
409
+
410
+ /// get most significant bit of each byte
411
+ inline uint32_t get_MSBs(simd32uint8 a) {
412
+ uint32_t res = 0;
413
+ for (int i = 0; i < 32; i++) {
414
+ if (a.u8[i] & 0x80) {
415
+ res |= 1 << i;
416
+ }
417
+ }
418
+ return res;
419
+ }
420
+
421
+ /// use MSB of each byte of mask to select a byte between a and b
422
+ inline simd32uint8 blendv(simd32uint8 a, simd32uint8 b, simd32uint8 mask) {
423
+ simd32uint8 c;
424
+ for (int i = 0; i < 32; i++) {
425
+ if (mask.u8[i] & 0x80) {
426
+ c.u8[i] = b.u8[i];
427
+ } else {
428
+ c.u8[i] = a.u8[i];
429
+ }
430
+ }
431
+ return c;
432
+ }
433
+
434
+
435
+
436
+
437
+ /// vector of 8 unsigned 32-bit integers
438
+ struct simd8uint32: simd256bit {
439
+ simd8uint32() {}
440
+
441
+
442
+ explicit simd8uint32(uint32_t x) {set1(x); }
443
+
444
+ explicit simd8uint32(simd256bit x): simd256bit(x) {}
445
+
446
+ explicit simd8uint32(const uint8_t *x): simd256bit((const void*)x) {}
447
+
448
+ std::string elements_to_string(const char * fmt) const {
449
+ char res[1000], *ptr = res;
450
+ for(int i = 0; i < 8; i++) {
451
+ ptr += sprintf(ptr, fmt, u32[i]);
452
+ }
453
+ // strip last ,
454
+ ptr[-1] = 0;
455
+ return std::string(res);
456
+ }
457
+
458
+ std::string hex() const {
459
+ return elements_to_string("%08x,");
460
+ }
461
+
462
+ std::string dec() const {
463
+ return elements_to_string("%10d,");
464
+ }
465
+
466
+ void set1(uint32_t x) {
467
+ for (int i = 0; i < 8; i++) {
468
+ u32[i] = x;
469
+ }
470
+ }
471
+
472
+ };
473
+
474
+ struct simd8float32: simd256bit {
475
+
476
+ simd8float32() {}
477
+
478
+ explicit simd8float32(simd256bit x): simd256bit(x) {}
479
+
480
+ explicit simd8float32(float x) {set1(x); }
481
+
482
+ explicit simd8float32(const float *x) {loadu((void*)x); }
483
+
484
+ void set1(float x) {
485
+ for(int i = 0; i < 8; i++) {
486
+ f32[i] = x;
487
+ }
488
+ }
489
+
490
+ static simd8float32 binary_func(
491
+ simd8float32 a, simd8float32 b,
492
+ std::function<float (float, float)> f)
493
+ {
494
+ simd8float32 c;
495
+ for(int j = 0; j < 8; j++) {
496
+ c.f32[j] = f(a.f32[j], b.f32[j]);
497
+ }
498
+ return c;
499
+ }
500
+
501
+ simd8float32 operator * (simd8float32 other) const {
502
+ return binary_func(*this, other,
503
+ [](float a, float b) {return a * b; }
504
+ );
505
+ }
506
+
507
+ simd8float32 operator + (simd8float32 other) const {
508
+ return binary_func(*this, other,
509
+ [](float a, float b) {return a + b; }
510
+ );
511
+ }
512
+
513
+ simd8float32 operator - (simd8float32 other) const {
514
+ return binary_func(*this, other,
515
+ [](float a, float b) {return a - b; }
516
+ );
517
+ }
518
+
519
+ std::string tostring() const {
520
+ char res[1000], *ptr = res;
521
+ for(int i = 0; i < 8; i++) {
522
+ ptr += sprintf(ptr, "%g,", f32[i]);
523
+ }
524
+ // strip last ,
525
+ ptr[-1] = 0;
526
+ return std::string(res);
527
+ }
528
+
529
+ };
530
+
531
+
532
+ // hadd does not cross lanes
533
+ inline simd8float32 hadd(simd8float32 a, simd8float32 b) {
534
+ simd8float32 c;
535
+ c.f32[0] = a.f32[0] + a.f32[1];
536
+ c.f32[1] = a.f32[2] + a.f32[3];
537
+ c.f32[2] = b.f32[0] + b.f32[1];
538
+ c.f32[3] = b.f32[2] + b.f32[3];
539
+
540
+ c.f32[4] = a.f32[4] + a.f32[5];
541
+ c.f32[5] = a.f32[6] + a.f32[7];
542
+ c.f32[6] = b.f32[4] + b.f32[5];
543
+ c.f32[7] = b.f32[6] + b.f32[7];
544
+
545
+ return c;
546
+ }
547
+
548
+ inline simd8float32 unpacklo(simd8float32 a, simd8float32 b) {
549
+ simd8float32 c;
550
+ c.f32[0] = a.f32[0];
551
+ c.f32[1] = b.f32[0];
552
+ c.f32[2] = a.f32[1];
553
+ c.f32[3] = b.f32[1];
554
+
555
+ c.f32[4] = a.f32[4];
556
+ c.f32[5] = b.f32[4];
557
+ c.f32[6] = a.f32[5];
558
+ c.f32[7] = b.f32[5];
559
+
560
+ return c;
561
+ }
562
+
563
+ inline simd8float32 unpackhi(simd8float32 a, simd8float32 b) {
564
+ simd8float32 c;
565
+ c.f32[0] = a.f32[2];
566
+ c.f32[1] = b.f32[2];
567
+ c.f32[2] = a.f32[3];
568
+ c.f32[3] = b.f32[3];
569
+
570
+ c.f32[4] = a.f32[6];
571
+ c.f32[5] = b.f32[6];
572
+ c.f32[6] = a.f32[7];
573
+ c.f32[7] = b.f32[7];
574
+
575
+ return c;
576
+ }
577
+
578
+ // compute a * b + c
579
+ inline simd8float32 fmadd(simd8float32 a, simd8float32 b, simd8float32 c) {
580
+ simd8float32 res;
581
+ for(int i = 0; i < 8; i++) {
582
+ res.f32[i] = a.f32[i] * b.f32[i] + c.f32[i];
583
+ }
584
+ return res;
585
+ }
586
+
587
+
588
+
589
+ } // namespace faiss