faiss 0.1.3 → 0.2.0

Sign up to get free protection for your applications and to get access to all the features.
Files changed (199) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +25 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +16 -4
  5. data/ext/faiss/ext.cpp +12 -308
  6. data/ext/faiss/extconf.rb +6 -3
  7. data/ext/faiss/index.cpp +189 -0
  8. data/ext/faiss/index_binary.cpp +75 -0
  9. data/ext/faiss/kmeans.cpp +40 -0
  10. data/ext/faiss/numo.hpp +867 -0
  11. data/ext/faiss/pca_matrix.cpp +33 -0
  12. data/ext/faiss/product_quantizer.cpp +53 -0
  13. data/ext/faiss/utils.cpp +13 -0
  14. data/ext/faiss/utils.h +5 -0
  15. data/lib/faiss.rb +0 -5
  16. data/lib/faiss/version.rb +1 -1
  17. data/vendor/faiss/faiss/AutoTune.cpp +36 -33
  18. data/vendor/faiss/faiss/AutoTune.h +6 -3
  19. data/vendor/faiss/faiss/Clustering.cpp +16 -12
  20. data/vendor/faiss/faiss/Index.cpp +3 -4
  21. data/vendor/faiss/faiss/Index.h +3 -3
  22. data/vendor/faiss/faiss/IndexBinary.cpp +3 -4
  23. data/vendor/faiss/faiss/IndexBinary.h +1 -1
  24. data/vendor/faiss/faiss/IndexBinaryHash.cpp +2 -12
  25. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +1 -2
  26. data/vendor/faiss/faiss/IndexFlat.cpp +0 -148
  27. data/vendor/faiss/faiss/IndexFlat.h +0 -51
  28. data/vendor/faiss/faiss/IndexHNSW.cpp +4 -5
  29. data/vendor/faiss/faiss/IndexIVF.cpp +118 -31
  30. data/vendor/faiss/faiss/IndexIVF.h +22 -15
  31. data/vendor/faiss/faiss/IndexIVFFlat.cpp +3 -3
  32. data/vendor/faiss/faiss/IndexIVFFlat.h +2 -1
  33. data/vendor/faiss/faiss/IndexIVFPQ.cpp +39 -15
  34. data/vendor/faiss/faiss/IndexIVFPQ.h +25 -9
  35. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +1116 -0
  36. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +166 -0
  37. data/vendor/faiss/faiss/IndexIVFPQR.cpp +8 -9
  38. data/vendor/faiss/faiss/IndexIVFPQR.h +2 -1
  39. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +1 -2
  40. data/vendor/faiss/faiss/IndexPQ.cpp +34 -18
  41. data/vendor/faiss/faiss/IndexPQFastScan.cpp +536 -0
  42. data/vendor/faiss/faiss/IndexPQFastScan.h +111 -0
  43. data/vendor/faiss/faiss/IndexPreTransform.cpp +47 -0
  44. data/vendor/faiss/faiss/IndexPreTransform.h +2 -0
  45. data/vendor/faiss/faiss/IndexRefine.cpp +256 -0
  46. data/vendor/faiss/faiss/IndexRefine.h +73 -0
  47. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +2 -2
  48. data/vendor/faiss/faiss/IndexScalarQuantizer.h +1 -1
  49. data/vendor/faiss/faiss/gpu/GpuDistance.h +1 -1
  50. data/vendor/faiss/faiss/gpu/GpuIndex.h +16 -9
  51. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +8 -1
  52. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +11 -11
  53. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +19 -2
  54. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +28 -2
  55. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +24 -14
  56. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +29 -2
  57. data/vendor/faiss/faiss/gpu/GpuResources.h +4 -0
  58. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +60 -27
  59. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +28 -6
  60. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +547 -0
  61. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +51 -0
  62. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +3 -3
  63. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +3 -2
  64. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +274 -0
  65. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +7 -2
  66. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +5 -1
  67. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +231 -0
  68. data/vendor/faiss/faiss/gpu/test/TestUtils.h +33 -0
  69. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +1 -0
  70. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +6 -0
  71. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +5 -6
  72. data/vendor/faiss/faiss/gpu/utils/Timer.h +2 -2
  73. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +5 -4
  74. data/vendor/faiss/faiss/impl/HNSW.cpp +2 -4
  75. data/vendor/faiss/faiss/impl/PolysemousTraining.h +4 -4
  76. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +22 -12
  77. data/vendor/faiss/faiss/impl/ProductQuantizer.h +2 -0
  78. data/vendor/faiss/faiss/impl/ResultHandler.h +452 -0
  79. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +29 -19
  80. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +6 -0
  81. data/vendor/faiss/faiss/impl/index_read.cpp +64 -96
  82. data/vendor/faiss/faiss/impl/index_write.cpp +34 -25
  83. data/vendor/faiss/faiss/impl/io.cpp +33 -2
  84. data/vendor/faiss/faiss/impl/io.h +7 -2
  85. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -15
  86. data/vendor/faiss/faiss/impl/platform_macros.h +44 -0
  87. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +272 -0
  88. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +169 -0
  89. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +180 -0
  90. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +354 -0
  91. data/vendor/faiss/faiss/impl/simd_result_handlers.h +559 -0
  92. data/vendor/faiss/faiss/index_factory.cpp +112 -7
  93. data/vendor/faiss/faiss/index_io.h +1 -48
  94. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +151 -0
  95. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +76 -0
  96. data/vendor/faiss/faiss/{DirectMap.cpp → invlists/DirectMap.cpp} +1 -1
  97. data/vendor/faiss/faiss/{DirectMap.h → invlists/DirectMap.h} +1 -1
  98. data/vendor/faiss/faiss/{InvertedLists.cpp → invlists/InvertedLists.cpp} +72 -1
  99. data/vendor/faiss/faiss/{InvertedLists.h → invlists/InvertedLists.h} +32 -1
  100. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +107 -0
  101. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +63 -0
  102. data/vendor/faiss/faiss/{OnDiskInvertedLists.cpp → invlists/OnDiskInvertedLists.cpp} +21 -6
  103. data/vendor/faiss/faiss/{OnDiskInvertedLists.h → invlists/OnDiskInvertedLists.h} +5 -2
  104. data/vendor/faiss/faiss/python/python_callbacks.h +8 -1
  105. data/vendor/faiss/faiss/utils/AlignedTable.h +141 -0
  106. data/vendor/faiss/faiss/utils/Heap.cpp +2 -4
  107. data/vendor/faiss/faiss/utils/Heap.h +61 -50
  108. data/vendor/faiss/faiss/utils/distances.cpp +164 -319
  109. data/vendor/faiss/faiss/utils/distances.h +28 -20
  110. data/vendor/faiss/faiss/utils/distances_simd.cpp +277 -49
  111. data/vendor/faiss/faiss/utils/extra_distances.cpp +1 -2
  112. data/vendor/faiss/faiss/utils/hamming-inl.h +4 -4
  113. data/vendor/faiss/faiss/utils/hamming.cpp +3 -6
  114. data/vendor/faiss/faiss/utils/hamming.h +2 -7
  115. data/vendor/faiss/faiss/utils/ordered_key_value.h +98 -0
  116. data/vendor/faiss/faiss/utils/partitioning.cpp +1256 -0
  117. data/vendor/faiss/faiss/utils/partitioning.h +69 -0
  118. data/vendor/faiss/faiss/utils/quantize_lut.cpp +277 -0
  119. data/vendor/faiss/faiss/utils/quantize_lut.h +80 -0
  120. data/vendor/faiss/faiss/utils/simdlib.h +31 -0
  121. data/vendor/faiss/faiss/utils/simdlib_avx2.h +461 -0
  122. data/vendor/faiss/faiss/utils/simdlib_emulated.h +589 -0
  123. metadata +54 -149
  124. data/lib/faiss/index.rb +0 -20
  125. data/lib/faiss/index_binary.rb +0 -20
  126. data/lib/faiss/kmeans.rb +0 -15
  127. data/lib/faiss/pca_matrix.rb +0 -15
  128. data/lib/faiss/product_quantizer.rb +0 -22
  129. data/vendor/faiss/benchs/bench_6bit_codec.cpp +0 -80
  130. data/vendor/faiss/c_api/AutoTune_c.cpp +0 -83
  131. data/vendor/faiss/c_api/AutoTune_c.h +0 -66
  132. data/vendor/faiss/c_api/Clustering_c.cpp +0 -145
  133. data/vendor/faiss/c_api/Clustering_c.h +0 -123
  134. data/vendor/faiss/c_api/IndexFlat_c.cpp +0 -140
  135. data/vendor/faiss/c_api/IndexFlat_c.h +0 -115
  136. data/vendor/faiss/c_api/IndexIVFFlat_c.cpp +0 -64
  137. data/vendor/faiss/c_api/IndexIVFFlat_c.h +0 -58
  138. data/vendor/faiss/c_api/IndexIVF_c.cpp +0 -99
  139. data/vendor/faiss/c_api/IndexIVF_c.h +0 -142
  140. data/vendor/faiss/c_api/IndexLSH_c.cpp +0 -37
  141. data/vendor/faiss/c_api/IndexLSH_c.h +0 -40
  142. data/vendor/faiss/c_api/IndexPreTransform_c.cpp +0 -21
  143. data/vendor/faiss/c_api/IndexPreTransform_c.h +0 -32
  144. data/vendor/faiss/c_api/IndexShards_c.cpp +0 -38
  145. data/vendor/faiss/c_api/IndexShards_c.h +0 -39
  146. data/vendor/faiss/c_api/Index_c.cpp +0 -105
  147. data/vendor/faiss/c_api/Index_c.h +0 -183
  148. data/vendor/faiss/c_api/MetaIndexes_c.cpp +0 -49
  149. data/vendor/faiss/c_api/MetaIndexes_c.h +0 -49
  150. data/vendor/faiss/c_api/clone_index_c.cpp +0 -23
  151. data/vendor/faiss/c_api/clone_index_c.h +0 -32
  152. data/vendor/faiss/c_api/error_c.h +0 -42
  153. data/vendor/faiss/c_api/error_impl.cpp +0 -27
  154. data/vendor/faiss/c_api/error_impl.h +0 -16
  155. data/vendor/faiss/c_api/faiss_c.h +0 -58
  156. data/vendor/faiss/c_api/gpu/GpuAutoTune_c.cpp +0 -98
  157. data/vendor/faiss/c_api/gpu/GpuAutoTune_c.h +0 -56
  158. data/vendor/faiss/c_api/gpu/GpuClonerOptions_c.cpp +0 -52
  159. data/vendor/faiss/c_api/gpu/GpuClonerOptions_c.h +0 -68
  160. data/vendor/faiss/c_api/gpu/GpuIndex_c.cpp +0 -17
  161. data/vendor/faiss/c_api/gpu/GpuIndex_c.h +0 -30
  162. data/vendor/faiss/c_api/gpu/GpuIndicesOptions_c.h +0 -38
  163. data/vendor/faiss/c_api/gpu/GpuResources_c.cpp +0 -86
  164. data/vendor/faiss/c_api/gpu/GpuResources_c.h +0 -66
  165. data/vendor/faiss/c_api/gpu/StandardGpuResources_c.cpp +0 -54
  166. data/vendor/faiss/c_api/gpu/StandardGpuResources_c.h +0 -53
  167. data/vendor/faiss/c_api/gpu/macros_impl.h +0 -42
  168. data/vendor/faiss/c_api/impl/AuxIndexStructures_c.cpp +0 -220
  169. data/vendor/faiss/c_api/impl/AuxIndexStructures_c.h +0 -149
  170. data/vendor/faiss/c_api/index_factory_c.cpp +0 -26
  171. data/vendor/faiss/c_api/index_factory_c.h +0 -30
  172. data/vendor/faiss/c_api/index_io_c.cpp +0 -42
  173. data/vendor/faiss/c_api/index_io_c.h +0 -50
  174. data/vendor/faiss/c_api/macros_impl.h +0 -110
  175. data/vendor/faiss/demos/demo_imi_flat.cpp +0 -154
  176. data/vendor/faiss/demos/demo_imi_pq.cpp +0 -203
  177. data/vendor/faiss/demos/demo_ivfpq_indexing.cpp +0 -151
  178. data/vendor/faiss/demos/demo_sift1M.cpp +0 -252
  179. data/vendor/faiss/demos/demo_weighted_kmeans.cpp +0 -185
  180. data/vendor/faiss/misc/test_blas.cpp +0 -87
  181. data/vendor/faiss/tests/test_binary_flat.cpp +0 -62
  182. data/vendor/faiss/tests/test_dealloc_invlists.cpp +0 -188
  183. data/vendor/faiss/tests/test_ivfpq_codec.cpp +0 -70
  184. data/vendor/faiss/tests/test_ivfpq_indexing.cpp +0 -100
  185. data/vendor/faiss/tests/test_lowlevel_ivf.cpp +0 -573
  186. data/vendor/faiss/tests/test_merge.cpp +0 -260
  187. data/vendor/faiss/tests/test_omp_threads.cpp +0 -14
  188. data/vendor/faiss/tests/test_ondisk_ivf.cpp +0 -225
  189. data/vendor/faiss/tests/test_pairs_decoding.cpp +0 -193
  190. data/vendor/faiss/tests/test_params_override.cpp +0 -236
  191. data/vendor/faiss/tests/test_pq_encoding.cpp +0 -98
  192. data/vendor/faiss/tests/test_sliding_ivf.cpp +0 -246
  193. data/vendor/faiss/tests/test_threaded_index.cpp +0 -253
  194. data/vendor/faiss/tests/test_transfer_invlists.cpp +0 -159
  195. data/vendor/faiss/tutorial/cpp/1-Flat.cpp +0 -104
  196. data/vendor/faiss/tutorial/cpp/2-IVFFlat.cpp +0 -85
  197. data/vendor/faiss/tutorial/cpp/3-IVFPQ.cpp +0 -98
  198. data/vendor/faiss/tutorial/cpp/4-GPU.cpp +0 -122
  199. data/vendor/faiss/tutorial/cpp/5-Multiple-GPUs.cpp +0 -104
@@ -0,0 +1,461 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ #pragma once
9
+
10
+ #include <string>
11
+ #include <cstdint>
12
+
13
+ #include <immintrin.h>
14
+
15
+ #include <faiss/impl/platform_macros.h>
16
+
17
+ namespace faiss {
18
+
19
+
20
+ /** Simple wrapper around the AVX 256-bit registers
21
+ *
22
+ * The objective is to separate the different interpretations of the same
23
+ * registers (as a vector of uint8, uint16 or uint32), to provide printing
24
+ * functions, and to give more readable names to the AVX intrinsics. It does not
25
+ * pretend to be exhausitve, functions are added as needed.
26
+ */
27
+
28
+ /// 256-bit representation without interpretation as a vector
29
+ struct simd256bit {
30
+
31
+ union {
32
+ __m256i i;
33
+ __m256 f;
34
+ };
35
+
36
+ simd256bit() {}
37
+
38
+ explicit simd256bit(__m256i i): i(i) {}
39
+
40
+ explicit simd256bit(__m256 f): f(f) {}
41
+
42
+ explicit simd256bit(const void *x):
43
+ i(_mm256_load_si256((__m256i const *)x))
44
+ {}
45
+
46
+ void clear() {
47
+ i = _mm256_setzero_si256();
48
+ }
49
+
50
+ void storeu(void *ptr) const {
51
+ _mm256_storeu_si256((__m256i *)ptr, i);
52
+ }
53
+
54
+ void loadu(const void *ptr) {
55
+ i = _mm256_loadu_si256((__m256i*)ptr);
56
+ }
57
+
58
+ void store(void *ptr) const {
59
+ _mm256_store_si256((__m256i *)ptr, i);
60
+ }
61
+
62
+ void bin(char bits[257]) const {
63
+ char bytes[32];
64
+ storeu((void*)bytes);
65
+ for (int i = 0; i < 256; i++) {
66
+ bits[i] = '0' + ((bytes[i / 8] >> (i % 8)) & 1);
67
+ }
68
+ bits[256] = 0;
69
+ }
70
+
71
+ std::string bin() const {
72
+ char bits[257];
73
+ bin(bits);
74
+ return std::string(bits);
75
+ }
76
+
77
+ };
78
+
79
+
80
+ /// vector of 16 elements in uint16
81
+ struct simd16uint16: simd256bit {
82
+ simd16uint16() {}
83
+
84
+ explicit simd16uint16(__m256i i): simd256bit(i) {}
85
+
86
+ explicit simd16uint16(int x): simd256bit(_mm256_set1_epi16(x)) {}
87
+
88
+ explicit simd16uint16(uint16_t x): simd256bit(_mm256_set1_epi16(x)) {}
89
+
90
+ explicit simd16uint16(simd256bit x): simd256bit(x) {}
91
+
92
+ explicit simd16uint16(const uint16_t *x): simd256bit((const void*)x) {}
93
+
94
+ std::string elements_to_string(const char * fmt) const {
95
+ uint16_t bytes[16];
96
+ storeu((void*)bytes);
97
+ char res[1000];
98
+ char *ptr = res;
99
+ for(int i = 0; i < 16; i++) {
100
+ ptr += sprintf(ptr, fmt, bytes[i]);
101
+ }
102
+ // strip last ,
103
+ ptr[-1] = 0;
104
+ return std::string(res);
105
+ }
106
+
107
+ std::string hex() const {
108
+ return elements_to_string("%02x,");
109
+ }
110
+
111
+ std::string dec() const {
112
+ return elements_to_string("%3d,");
113
+ }
114
+
115
+ void set1(uint16_t x) {
116
+ i = _mm256_set1_epi16((short)x);
117
+ }
118
+
119
+ // shift must be known at compile time
120
+ simd16uint16 operator >> (const int shift) const {
121
+ return simd16uint16(_mm256_srli_epi16(i, shift));
122
+ }
123
+
124
+ // shift must be known at compile time
125
+ simd16uint16 operator << (const int shift) const {
126
+ return simd16uint16(_mm256_slli_epi16(i, shift));
127
+ }
128
+
129
+ simd16uint16 operator += (simd16uint16 other) {
130
+ i = _mm256_add_epi16(i, other.i);
131
+ return *this;
132
+ }
133
+
134
+ simd16uint16 operator -= (simd16uint16 other) {
135
+ i = _mm256_sub_epi16(i, other.i);
136
+ return *this;
137
+ }
138
+
139
+ simd16uint16 operator + (simd16uint16 other) const {
140
+ return simd16uint16(_mm256_add_epi16(i, other.i));
141
+ }
142
+
143
+ simd16uint16 operator - (simd16uint16 other) const {
144
+ return simd16uint16(_mm256_sub_epi16(i, other.i));
145
+ }
146
+
147
+ simd16uint16 operator & (simd256bit other) const {
148
+ return simd16uint16(_mm256_and_si256(i, other.i));
149
+ }
150
+
151
+ simd16uint16 operator | (simd256bit other) const {
152
+ return simd16uint16(_mm256_or_si256(i, other.i));
153
+ }
154
+
155
+ // returns binary masks
156
+ simd16uint16 operator == (simd256bit other) const {
157
+ return simd16uint16(_mm256_cmpeq_epi16(i, other.i));
158
+ }
159
+
160
+ simd16uint16 operator ~() const {
161
+ return simd16uint16(_mm256_xor_si256(i, _mm256_set1_epi32(-1)));
162
+ }
163
+
164
+ // get scalar at index 0
165
+ uint16_t get_scalar_0() const {
166
+ return _mm256_extract_epi16(i, 0);
167
+ }
168
+
169
+ // mask of elements where this >= thresh
170
+ // 2 bit per component: 16 * 2 = 32 bit
171
+ uint32_t ge_mask(simd16uint16 thresh) const {
172
+ __m256i j = thresh.i;
173
+ __m256i max = _mm256_max_epu16(i, j);
174
+ __m256i ge = _mm256_cmpeq_epi16(i, max);
175
+ return _mm256_movemask_epi8(ge);
176
+ }
177
+
178
+ uint32_t le_mask(simd16uint16 thresh) const {
179
+ return thresh.ge_mask(*this);
180
+ }
181
+
182
+ uint32_t gt_mask(simd16uint16 thresh) const {
183
+ return ~le_mask(thresh);
184
+ }
185
+
186
+ bool all_gt(simd16uint16 thresh) const {
187
+ return le_mask(thresh) == 0;
188
+ }
189
+
190
+ // for debugging only
191
+ uint16_t operator [] (int i) const {
192
+ ALIGNED(32) uint16_t tab[16];
193
+ store(tab);
194
+ return tab[i];
195
+ }
196
+
197
+ void accu_min(simd16uint16 incoming) {
198
+ i = _mm256_min_epu16(i, incoming.i);
199
+ }
200
+
201
+ void accu_max(simd16uint16 incoming) {
202
+ i = _mm256_max_epu16(i, incoming.i);
203
+ }
204
+
205
+ };
206
+
207
+ // not really a std::min because it returns an elementwise min
208
+ inline simd16uint16 min(simd16uint16 a, simd16uint16 b) {
209
+ return simd16uint16(_mm256_min_epu16(a.i, b.i));
210
+ }
211
+
212
+ inline simd16uint16 max(simd16uint16 a, simd16uint16 b) {
213
+ return simd16uint16(_mm256_max_epu16(a.i, b.i));
214
+ }
215
+
216
+
217
+
218
+ // decompose in 128-lanes: a = (a0, a1), b = (b0, b1)
219
+ // return (a0 + a1, b0 + b1)
220
+ // TODO find a better name
221
+ inline simd16uint16 combine2x2(simd16uint16 a, simd16uint16 b) {
222
+
223
+ __m256i a1b0 = _mm256_permute2f128_si256(a.i, b.i, 0x21);
224
+ __m256i a0b1 = _mm256_blend_epi32(a.i, b.i, 0xF0);
225
+
226
+ return simd16uint16(a1b0) + simd16uint16(a0b1);
227
+ }
228
+
229
+ // compare d0 and d1 to thr, return 32 bits corresponding to the concatenation
230
+ // of d0 and d1 with thr
231
+ inline uint32_t cmp_ge32(simd16uint16 d0, simd16uint16 d1, simd16uint16 thr) {
232
+
233
+ __m256i max0 = _mm256_max_epu16(d0.i, thr.i);
234
+ __m256i ge0 = _mm256_cmpeq_epi16(d0.i, max0);
235
+
236
+ __m256i max1 = _mm256_max_epu16(d1.i, thr.i);
237
+ __m256i ge1 = _mm256_cmpeq_epi16(d1.i, max1);
238
+
239
+ __m256i ge01 = _mm256_packs_epi16(ge0, ge1);
240
+
241
+ // easier than manipulating bit fields afterwards
242
+ ge01 = _mm256_permute4x64_epi64(ge01, 0 | (2 << 2) | (1 << 4) | (3 << 6));
243
+ uint32_t ge = _mm256_movemask_epi8(ge01);
244
+
245
+ return ge;
246
+ }
247
+
248
+
249
+ inline uint32_t cmp_le32(simd16uint16 d0, simd16uint16 d1, simd16uint16 thr) {
250
+
251
+ __m256i max0 = _mm256_min_epu16(d0.i, thr.i);
252
+ __m256i ge0 = _mm256_cmpeq_epi16(d0.i, max0);
253
+
254
+ __m256i max1 = _mm256_min_epu16(d1.i, thr.i);
255
+ __m256i ge1 = _mm256_cmpeq_epi16(d1.i, max1);
256
+
257
+ __m256i ge01 = _mm256_packs_epi16(ge0, ge1);
258
+
259
+ // easier than manipulating bit fields afterwards
260
+ ge01 = _mm256_permute4x64_epi64(ge01, 0 | (2 << 2) | (1 << 4) | (3 << 6));
261
+ uint32_t ge = _mm256_movemask_epi8(ge01);
262
+
263
+ return ge;
264
+ }
265
+
266
+
267
+ // vector of 32 unsigned 8-bit integers
268
+ struct simd32uint8: simd256bit {
269
+
270
+
271
+ simd32uint8() {}
272
+
273
+ explicit simd32uint8(__m256i i): simd256bit(i) {}
274
+
275
+ explicit simd32uint8(int x): simd256bit(_mm256_set1_epi8(x)) {}
276
+
277
+ explicit simd32uint8(uint8_t x): simd256bit(_mm256_set1_epi8(x)) {}
278
+
279
+ explicit simd32uint8(simd256bit x): simd256bit(x) {}
280
+
281
+ explicit simd32uint8(const uint8_t *x): simd256bit((const void*)x) {}
282
+
283
+ std::string elements_to_string(const char * fmt) const {
284
+ uint8_t bytes[32];
285
+ storeu((void*)bytes);
286
+ char res[1000];
287
+ char *ptr = res;
288
+ for(int i = 0; i < 32; i++) {
289
+ ptr += sprintf(ptr, fmt, bytes[i]);
290
+ }
291
+ // strip last ,
292
+ ptr[-1] = 0;
293
+ return std::string(res);
294
+ }
295
+
296
+ std::string hex() const {
297
+ return elements_to_string("%02x,");
298
+ }
299
+
300
+ std::string dec() const {
301
+ return elements_to_string("%3d,");
302
+ }
303
+
304
+ void set1(uint8_t x) {
305
+ i = _mm256_set1_epi8((char)x);
306
+ }
307
+
308
+ simd32uint8 operator & (simd256bit other) const {
309
+ return simd32uint8(_mm256_and_si256(i, other.i));
310
+ }
311
+
312
+ simd32uint8 operator + (simd32uint8 other) const {
313
+ return simd32uint8(_mm256_add_epi8(i, other.i));
314
+ }
315
+
316
+ simd32uint8 lookup_2_lanes(simd32uint8 idx) const {
317
+ return simd32uint8(_mm256_shuffle_epi8(i, idx.i));
318
+ }
319
+
320
+ // extract + 0-extend lane
321
+ // this operation is slow (3 cycles)
322
+ simd16uint16 lane0_as_uint16() const {
323
+ __m128i x = _mm256_extracti128_si256(i, 0);
324
+ return simd16uint16(_mm256_cvtepu8_epi16(x));
325
+ }
326
+
327
+ simd16uint16 lane1_as_uint16() const {
328
+ __m128i x = _mm256_extracti128_si256(i, 1);
329
+ return simd16uint16(_mm256_cvtepu8_epi16(x));
330
+ }
331
+
332
+ simd32uint8 operator += (simd32uint8 other) {
333
+ i = _mm256_add_epi8(i, other.i);
334
+ return *this;
335
+ }
336
+
337
+ // for debugging only
338
+ uint8_t operator [] (int i) const {
339
+ ALIGNED(32) uint8_t tab[32];
340
+ store(tab);
341
+ return tab[i];
342
+ }
343
+
344
+ };
345
+
346
+ // convert with saturation
347
+ // careful: this does not cross lanes, so the order is weird
348
+ inline simd32uint8 uint16_to_uint8_saturate(simd16uint16 a, simd16uint16 b) {
349
+ return simd32uint8(_mm256_packs_epi16(a.i, b.i));
350
+ }
351
+
352
+ /// get most significant bit of each byte
353
+ inline uint32_t get_MSBs(simd32uint8 a) {
354
+ return _mm256_movemask_epi8(a.i);
355
+ }
356
+
357
+ /// use MSB of each byte of mask to select a byte between a and b
358
+ inline simd32uint8 blendv(simd32uint8 a, simd32uint8 b, simd32uint8 mask) {
359
+ return simd32uint8(_mm256_blendv_epi8(a.i, b.i, mask.i));
360
+ }
361
+
362
+
363
+
364
+ /// vector of 8 unsigned 32-bit integers
365
+ struct simd8uint32: simd256bit {
366
+ simd8uint32() {}
367
+
368
+ explicit simd8uint32(__m256i i): simd256bit(i) {}
369
+
370
+ explicit simd8uint32(uint32_t x): simd256bit(_mm256_set1_epi32(x)) {}
371
+
372
+ explicit simd8uint32(simd256bit x): simd256bit(x) {}
373
+
374
+ explicit simd8uint32(const uint8_t *x): simd256bit((const void*)x) {}
375
+
376
+ std::string elements_to_string(const char * fmt) const {
377
+ uint32_t bytes[8];
378
+ storeu((void*)bytes);
379
+ char res[1000];
380
+ char *ptr = res;
381
+ for(int i = 0; i < 8; i++) {
382
+ ptr += sprintf(ptr, fmt, bytes[i]);
383
+ }
384
+ // strip last ,
385
+ ptr[-1] = 0;
386
+ return std::string(res);
387
+ }
388
+
389
+ std::string hex() const {
390
+ return elements_to_string("%08x,");
391
+ }
392
+
393
+ std::string dec() const {
394
+ return elements_to_string("%10d,");
395
+ }
396
+
397
+ void set1(uint32_t x) {
398
+ i = _mm256_set1_epi32((int)x);
399
+ }
400
+
401
+ };
402
+
403
+ struct simd8float32: simd256bit {
404
+
405
+ simd8float32() {}
406
+
407
+
408
+ explicit simd8float32(simd256bit x): simd256bit(x) {}
409
+
410
+ explicit simd8float32(__m256 x): simd256bit(x) {}
411
+
412
+ explicit simd8float32(float x): simd256bit(_mm256_set1_ps(x)) {}
413
+
414
+ explicit simd8float32(const float *x): simd256bit(_mm256_load_ps(x)) {}
415
+
416
+ simd8float32 operator * (simd8float32 other) const {
417
+ return simd8float32(_mm256_mul_ps(f, other.f));
418
+ }
419
+
420
+ simd8float32 operator + (simd8float32 other) const {
421
+ return simd8float32(_mm256_add_ps(f, other.f));
422
+ }
423
+
424
+ simd8float32 operator - (simd8float32 other) const {
425
+ return simd8float32(_mm256_sub_ps(f, other.f));
426
+ }
427
+
428
+ std::string tostring() const {
429
+ float tab[8];
430
+ storeu((void*)tab);
431
+ char res[1000];
432
+ char *ptr = res;
433
+ for(int i = 0; i < 8; i++) {
434
+ ptr += sprintf(ptr, "%g,", tab[i]);
435
+ }
436
+ // strip last ,
437
+ ptr[-1] = 0;
438
+ return std::string(res);
439
+ }
440
+
441
+ };
442
+
443
+ inline simd8float32 hadd(simd8float32 a, simd8float32 b) {
444
+ return simd8float32(_mm256_hadd_ps(a.f, b.f));
445
+ }
446
+
447
+ inline simd8float32 unpacklo(simd8float32 a, simd8float32 b) {
448
+ return simd8float32(_mm256_unpacklo_ps(a.f, b.f));
449
+ }
450
+
451
+ inline simd8float32 unpackhi(simd8float32 a, simd8float32 b) {
452
+ return simd8float32(_mm256_unpackhi_ps(a.f, b.f));
453
+ }
454
+
455
+ // compute a * b + c
456
+ inline simd8float32 fmadd(simd8float32 a, simd8float32 b, simd8float32 c) {
457
+ return simd8float32(_mm256_fmadd_ps(a.f, b.f, c.f));
458
+ }
459
+
460
+
461
+ } // namespace faiss