faiss 0.1.3 → 0.1.4

Sign up to get free protection for your applications and to get access to all the features.
Files changed (184)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +4 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +1 -1
  5. data/ext/faiss/extconf.rb +1 -1
  6. data/lib/faiss/version.rb +1 -1
  7. data/vendor/faiss/faiss/AutoTune.cpp +36 -33
  8. data/vendor/faiss/faiss/AutoTune.h +6 -3
  9. data/vendor/faiss/faiss/Clustering.cpp +16 -12
  10. data/vendor/faiss/faiss/Index.cpp +3 -4
  11. data/vendor/faiss/faiss/Index.h +3 -3
  12. data/vendor/faiss/faiss/IndexBinary.cpp +3 -4
  13. data/vendor/faiss/faiss/IndexBinary.h +1 -1
  14. data/vendor/faiss/faiss/IndexBinaryHash.cpp +2 -12
  15. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +1 -2
  16. data/vendor/faiss/faiss/IndexFlat.cpp +0 -148
  17. data/vendor/faiss/faiss/IndexFlat.h +0 -51
  18. data/vendor/faiss/faiss/IndexHNSW.cpp +4 -5
  19. data/vendor/faiss/faiss/IndexIVF.cpp +118 -31
  20. data/vendor/faiss/faiss/IndexIVF.h +22 -15
  21. data/vendor/faiss/faiss/IndexIVFFlat.cpp +3 -3
  22. data/vendor/faiss/faiss/IndexIVFFlat.h +2 -1
  23. data/vendor/faiss/faiss/IndexIVFPQ.cpp +39 -15
  24. data/vendor/faiss/faiss/IndexIVFPQ.h +25 -9
  25. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +1116 -0
  26. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +166 -0
  27. data/vendor/faiss/faiss/IndexIVFPQR.cpp +8 -9
  28. data/vendor/faiss/faiss/IndexIVFPQR.h +2 -1
  29. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +1 -2
  30. data/vendor/faiss/faiss/IndexPQ.cpp +34 -18
  31. data/vendor/faiss/faiss/IndexPQFastScan.cpp +536 -0
  32. data/vendor/faiss/faiss/IndexPQFastScan.h +111 -0
  33. data/vendor/faiss/faiss/IndexPreTransform.cpp +47 -0
  34. data/vendor/faiss/faiss/IndexPreTransform.h +2 -0
  35. data/vendor/faiss/faiss/IndexRefine.cpp +256 -0
  36. data/vendor/faiss/faiss/IndexRefine.h +73 -0
  37. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +2 -2
  38. data/vendor/faiss/faiss/IndexScalarQuantizer.h +1 -1
  39. data/vendor/faiss/faiss/gpu/GpuDistance.h +1 -1
  40. data/vendor/faiss/faiss/gpu/GpuIndex.h +16 -9
  41. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +8 -1
  42. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +11 -11
  43. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +19 -2
  44. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +28 -2
  45. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +24 -14
  46. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +29 -2
  47. data/vendor/faiss/faiss/gpu/GpuResources.h +4 -0
  48. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +60 -27
  49. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +28 -6
  50. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +547 -0
  51. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +51 -0
  52. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +3 -3
  53. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +3 -2
  54. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +274 -0
  55. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +7 -2
  56. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +5 -1
  57. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +231 -0
  58. data/vendor/faiss/faiss/gpu/test/TestUtils.h +33 -0
  59. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +1 -0
  60. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +6 -0
  61. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +5 -6
  62. data/vendor/faiss/faiss/gpu/utils/Timer.h +2 -2
  63. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +5 -4
  64. data/vendor/faiss/faiss/impl/HNSW.cpp +2 -4
  65. data/vendor/faiss/faiss/impl/PolysemousTraining.h +4 -4
  66. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +22 -12
  67. data/vendor/faiss/faiss/impl/ProductQuantizer.h +2 -0
  68. data/vendor/faiss/faiss/impl/ResultHandler.h +452 -0
  69. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +29 -19
  70. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +6 -0
  71. data/vendor/faiss/faiss/impl/index_read.cpp +64 -96
  72. data/vendor/faiss/faiss/impl/index_write.cpp +34 -25
  73. data/vendor/faiss/faiss/impl/io.cpp +33 -2
  74. data/vendor/faiss/faiss/impl/io.h +7 -2
  75. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -15
  76. data/vendor/faiss/faiss/impl/platform_macros.h +44 -0
  77. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +272 -0
  78. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +169 -0
  79. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +180 -0
  80. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +354 -0
  81. data/vendor/faiss/faiss/impl/simd_result_handlers.h +559 -0
  82. data/vendor/faiss/faiss/index_factory.cpp +112 -7
  83. data/vendor/faiss/faiss/index_io.h +1 -48
  84. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +151 -0
  85. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +76 -0
  86. data/vendor/faiss/faiss/{DirectMap.cpp → invlists/DirectMap.cpp} +1 -1
  87. data/vendor/faiss/faiss/{DirectMap.h → invlists/DirectMap.h} +1 -1
  88. data/vendor/faiss/faiss/{InvertedLists.cpp → invlists/InvertedLists.cpp} +72 -1
  89. data/vendor/faiss/faiss/{InvertedLists.h → invlists/InvertedLists.h} +32 -1
  90. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +107 -0
  91. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +63 -0
  92. data/vendor/faiss/faiss/{OnDiskInvertedLists.cpp → invlists/OnDiskInvertedLists.cpp} +21 -6
  93. data/vendor/faiss/faiss/{OnDiskInvertedLists.h → invlists/OnDiskInvertedLists.h} +5 -2
  94. data/vendor/faiss/faiss/python/python_callbacks.h +8 -1
  95. data/vendor/faiss/faiss/utils/AlignedTable.h +141 -0
  96. data/vendor/faiss/faiss/utils/Heap.cpp +2 -4
  97. data/vendor/faiss/faiss/utils/Heap.h +61 -50
  98. data/vendor/faiss/faiss/utils/distances.cpp +164 -319
  99. data/vendor/faiss/faiss/utils/distances.h +28 -20
  100. data/vendor/faiss/faiss/utils/distances_simd.cpp +277 -49
  101. data/vendor/faiss/faiss/utils/extra_distances.cpp +1 -2
  102. data/vendor/faiss/faiss/utils/hamming-inl.h +4 -4
  103. data/vendor/faiss/faiss/utils/hamming.cpp +3 -6
  104. data/vendor/faiss/faiss/utils/hamming.h +2 -7
  105. data/vendor/faiss/faiss/utils/ordered_key_value.h +98 -0
  106. data/vendor/faiss/faiss/utils/partitioning.cpp +1256 -0
  107. data/vendor/faiss/faiss/utils/partitioning.h +69 -0
  108. data/vendor/faiss/faiss/utils/quantize_lut.cpp +277 -0
  109. data/vendor/faiss/faiss/utils/quantize_lut.h +80 -0
  110. data/vendor/faiss/faiss/utils/simdlib.h +31 -0
  111. data/vendor/faiss/faiss/utils/simdlib_avx2.h +461 -0
  112. data/vendor/faiss/faiss/utils/simdlib_emulated.h +589 -0
  113. metadata +43 -141
  114. data/vendor/faiss/benchs/bench_6bit_codec.cpp +0 -80
  115. data/vendor/faiss/c_api/AutoTune_c.cpp +0 -83
  116. data/vendor/faiss/c_api/AutoTune_c.h +0 -66
  117. data/vendor/faiss/c_api/Clustering_c.cpp +0 -145
  118. data/vendor/faiss/c_api/Clustering_c.h +0 -123
  119. data/vendor/faiss/c_api/IndexFlat_c.cpp +0 -140
  120. data/vendor/faiss/c_api/IndexFlat_c.h +0 -115
  121. data/vendor/faiss/c_api/IndexIVFFlat_c.cpp +0 -64
  122. data/vendor/faiss/c_api/IndexIVFFlat_c.h +0 -58
  123. data/vendor/faiss/c_api/IndexIVF_c.cpp +0 -99
  124. data/vendor/faiss/c_api/IndexIVF_c.h +0 -142
  125. data/vendor/faiss/c_api/IndexLSH_c.cpp +0 -37
  126. data/vendor/faiss/c_api/IndexLSH_c.h +0 -40
  127. data/vendor/faiss/c_api/IndexPreTransform_c.cpp +0 -21
  128. data/vendor/faiss/c_api/IndexPreTransform_c.h +0 -32
  129. data/vendor/faiss/c_api/IndexShards_c.cpp +0 -38
  130. data/vendor/faiss/c_api/IndexShards_c.h +0 -39
  131. data/vendor/faiss/c_api/Index_c.cpp +0 -105
  132. data/vendor/faiss/c_api/Index_c.h +0 -183
  133. data/vendor/faiss/c_api/MetaIndexes_c.cpp +0 -49
  134. data/vendor/faiss/c_api/MetaIndexes_c.h +0 -49
  135. data/vendor/faiss/c_api/clone_index_c.cpp +0 -23
  136. data/vendor/faiss/c_api/clone_index_c.h +0 -32
  137. data/vendor/faiss/c_api/error_c.h +0 -42
  138. data/vendor/faiss/c_api/error_impl.cpp +0 -27
  139. data/vendor/faiss/c_api/error_impl.h +0 -16
  140. data/vendor/faiss/c_api/faiss_c.h +0 -58
  141. data/vendor/faiss/c_api/gpu/GpuAutoTune_c.cpp +0 -98
  142. data/vendor/faiss/c_api/gpu/GpuAutoTune_c.h +0 -56
  143. data/vendor/faiss/c_api/gpu/GpuClonerOptions_c.cpp +0 -52
  144. data/vendor/faiss/c_api/gpu/GpuClonerOptions_c.h +0 -68
  145. data/vendor/faiss/c_api/gpu/GpuIndex_c.cpp +0 -17
  146. data/vendor/faiss/c_api/gpu/GpuIndex_c.h +0 -30
  147. data/vendor/faiss/c_api/gpu/GpuIndicesOptions_c.h +0 -38
  148. data/vendor/faiss/c_api/gpu/GpuResources_c.cpp +0 -86
  149. data/vendor/faiss/c_api/gpu/GpuResources_c.h +0 -66
  150. data/vendor/faiss/c_api/gpu/StandardGpuResources_c.cpp +0 -54
  151. data/vendor/faiss/c_api/gpu/StandardGpuResources_c.h +0 -53
  152. data/vendor/faiss/c_api/gpu/macros_impl.h +0 -42
  153. data/vendor/faiss/c_api/impl/AuxIndexStructures_c.cpp +0 -220
  154. data/vendor/faiss/c_api/impl/AuxIndexStructures_c.h +0 -149
  155. data/vendor/faiss/c_api/index_factory_c.cpp +0 -26
  156. data/vendor/faiss/c_api/index_factory_c.h +0 -30
  157. data/vendor/faiss/c_api/index_io_c.cpp +0 -42
  158. data/vendor/faiss/c_api/index_io_c.h +0 -50
  159. data/vendor/faiss/c_api/macros_impl.h +0 -110
  160. data/vendor/faiss/demos/demo_imi_flat.cpp +0 -154
  161. data/vendor/faiss/demos/demo_imi_pq.cpp +0 -203
  162. data/vendor/faiss/demos/demo_ivfpq_indexing.cpp +0 -151
  163. data/vendor/faiss/demos/demo_sift1M.cpp +0 -252
  164. data/vendor/faiss/demos/demo_weighted_kmeans.cpp +0 -185
  165. data/vendor/faiss/misc/test_blas.cpp +0 -87
  166. data/vendor/faiss/tests/test_binary_flat.cpp +0 -62
  167. data/vendor/faiss/tests/test_dealloc_invlists.cpp +0 -188
  168. data/vendor/faiss/tests/test_ivfpq_codec.cpp +0 -70
  169. data/vendor/faiss/tests/test_ivfpq_indexing.cpp +0 -100
  170. data/vendor/faiss/tests/test_lowlevel_ivf.cpp +0 -573
  171. data/vendor/faiss/tests/test_merge.cpp +0 -260
  172. data/vendor/faiss/tests/test_omp_threads.cpp +0 -14
  173. data/vendor/faiss/tests/test_ondisk_ivf.cpp +0 -225
  174. data/vendor/faiss/tests/test_pairs_decoding.cpp +0 -193
  175. data/vendor/faiss/tests/test_params_override.cpp +0 -236
  176. data/vendor/faiss/tests/test_pq_encoding.cpp +0 -98
  177. data/vendor/faiss/tests/test_sliding_ivf.cpp +0 -246
  178. data/vendor/faiss/tests/test_threaded_index.cpp +0 -253
  179. data/vendor/faiss/tests/test_transfer_invlists.cpp +0 -159
  180. data/vendor/faiss/tutorial/cpp/1-Flat.cpp +0 -104
  181. data/vendor/faiss/tutorial/cpp/2-IVFFlat.cpp +0 -85
  182. data/vendor/faiss/tutorial/cpp/3-IVFPQ.cpp +0 -98
  183. data/vendor/faiss/tutorial/cpp/4-GPU.cpp +0 -122
  184. data/vendor/faiss/tutorial/cpp/5-Multiple-GPUs.cpp +0 -104
@@ -0,0 +1,559 @@ data/vendor/faiss/faiss/impl/simd_result_handlers.h (new file)
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ #pragma once
9
+
10
+ #include <vector>
11
+ #include <algorithm>
12
+ #include <type_traits>
13
+
14
+ #include <faiss/utils/Heap.h>
15
+ #include <faiss/utils/simdlib.h>
16
+
17
+ #include <faiss/utils/AlignedTable.h>
18
+ #include <faiss/utils/partitioning.h>
19
+ #include <faiss/impl/platform_macros.h>
20
+
21
+ /** This file contains callbacks for kernels that compute distances.
22
+ *
23
+ * The SIMDResultHandler object is intended to be templated and inlined.
24
+ * Methods:
25
+ * - handle(): called when 32 distances are computed and provided in two
26
+ * simd16uint16. (q, b) indicate which entry it is in the block.
27
+ * - set_block_origin(): set the sub-matrix that is being computed
28
+ */
29
+
30
+ namespace faiss {
31
+
32
+ namespace simd_result_handlers {
33
+
34
+
35
+ /** Dummy structure that just computes a checksum on results
36
+ * (to avoid the computation to be optimized away) */
37
+ struct DummyResultHandler {
38
+ size_t cs = 0;
39
+
40
+ void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) {
41
+ cs += q * 123 + b * 789 + d0.get_scalar_0() + d1.get_scalar_0();
42
+ }
43
+
44
+ void set_block_origin(size_t, size_t) {
45
+ }
46
+ };
47
+
48
+ /** memorize results in a nq-by-nb matrix.
49
+ *
50
+ * j0 is the current upper-left block of the matrix
51
+ */
52
+ struct StoreResultHandler {
53
+ uint16_t *data;
54
+ size_t ld; // total number of columns
55
+ size_t i0 = 0;
56
+ size_t j0 = 0;
57
+
58
+ StoreResultHandler(uint16_t *data, size_t ld):
59
+ data(data), ld(ld) {
60
+ }
61
+
62
+ void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) {
63
+ size_t ofs = (q + i0) * ld + j0 + b * 32;
64
+ d0.store(data + ofs);
65
+ d1.store(data + ofs + 16);
66
+ }
67
+
68
+ void set_block_origin(size_t i0, size_t j0) {
69
+ this->i0 = i0;
70
+ this->j0 = j0;
71
+ }
72
+ };
73
+
74
+
75
+ /** stores results in fixed-size matrix. */
76
+ template<int NQ, int BB>
77
+ struct FixedStorageHandler {
78
+ simd16uint16 dis[NQ][BB];
79
+ int i0 = 0;
80
+
81
+ void handle(int q, int b, simd16uint16 d0, simd16uint16 d1) {
82
+ dis[q + i0][2 * b] = d0;
83
+ dis[q + i0][2 * b + 1] = d1;
84
+ }
85
+
86
+ void set_block_origin(size_t i0, size_t j0) {
87
+ this->i0 = i0;
88
+ assert(j0 == 0);
89
+ }
90
+
91
+ template<class OtherResultHandler>
92
+ void to_other_handler(OtherResultHandler & other) const {
93
+ for (int q = 0; q < NQ; q++) {
94
+ for(int b = 0; b < BB; b += 2) {
95
+ other.handle(q, b / 2, dis[q][b], dis[q][b + 1]);
96
+ }
97
+ }
98
+ }
99
+
100
+ };
101
+
102
+
103
+ /** Record origin of current block */
104
+ template<class C, bool with_id_map>
105
+ struct SIMDResultHandler {
106
+ using TI = typename C::TI;
107
+
108
+ bool disable = false;
109
+
110
+ int64_t i0 = 0; // query origin
111
+ int64_t j0 = 0; // db origin
112
+ size_t ntotal; // ignore excess elements after ntotal
113
+
114
+ /// these fields are used mainly for the IVF variants (with_id_map=true)
115
+ const TI *id_map; // map offset in invlist to vector id
116
+ const int *q_map; // map q to global query
117
+ const uint16_t *dbias; // table of biases to add to each query
118
+
119
+ explicit SIMDResultHandler(size_t ntotal):
120
+ ntotal(ntotal), id_map(nullptr), q_map(nullptr), dbias(nullptr)
121
+ {}
122
+
123
+ void set_block_origin(size_t i0, size_t j0) {
124
+ this->i0 = i0;
125
+ this->j0 = j0;
126
+ }
127
+
128
+
129
+ // adjust handler data for IVF.
130
+ void adjust_with_origin(size_t & q, simd16uint16 & d0, simd16uint16 & d1)
131
+ {
132
+ q += i0;
133
+
134
+ if (dbias) {
135
+ simd16uint16 dbias16(dbias[q]);
136
+ d0 += dbias16;
137
+ d1 += dbias16;
138
+ }
139
+
140
+ if (with_id_map) { // FIXME test on q_map instead
141
+ q = q_map[q];
142
+ }
143
+ }
144
+
145
+ // compute and adjust idx
146
+ int64_t adjust_id(size_t b, size_t j) {
147
+ int64_t idx = j0 + 32 * b + j;
148
+ if (with_id_map) {
149
+ idx = id_map[idx];
150
+ }
151
+ return idx;
152
+ }
153
+
154
+ /// return binary mask of elements below thr in (d0, d1)
155
+ /// inverse_test returns elements above
156
+ uint32_t get_lt_mask(
157
+ uint16_t thr, size_t b,
158
+ simd16uint16 d0, simd16uint16 d1
159
+ ) {
160
+ simd16uint16 thr16(thr);
161
+ uint32_t lt_mask;
162
+
163
+ constexpr bool keep_min = C::is_max;
164
+ if (keep_min) {
165
+ lt_mask = ~cmp_ge32(d0, d1, thr16);
166
+ } else {
167
+ lt_mask = ~cmp_le32(d0, d1, thr16);
168
+ }
169
+
170
+ if (lt_mask == 0) {
171
+ return 0;
172
+ }
173
+ uint64_t idx = j0 + b * 32;
174
+ if (idx + 32 > ntotal) {
175
+ if (idx >= ntotal) {
176
+ return 0;
177
+ }
178
+ int nbit = (ntotal - idx);
179
+ lt_mask &= (uint32_t(1) << nbit) - 1;
180
+ }
181
+ return lt_mask;
182
+ }
183
+
184
+ virtual void to_flat_arrays(
185
+ float *distances, int64_t *labels,
186
+ const float *normalizers = nullptr
187
+ ) = 0;
188
+
189
+ virtual ~SIMDResultHandler() {}
190
+
191
+ };
192
+
193
+
194
+ /** Special version for k=1 */
195
+ template<class C, bool with_id_map = false>
196
+ struct SingleResultHandler: SIMDResultHandler<C, with_id_map> {
197
+ using T = typename C::T;
198
+ using TI = typename C::TI;
199
+
200
+ struct Result {
201
+ T val;
202
+ TI id;
203
+ };
204
+ std::vector<Result> results;
205
+
206
+ SingleResultHandler(size_t nq, size_t ntotal):
207
+ SIMDResultHandler<C, with_id_map>(ntotal), results(nq)
208
+ {
209
+ for (int i = 0; i < nq; i++) {
210
+ Result res = {C::neutral(), -1};
211
+ results[i] = res;
212
+ }
213
+ }
214
+
215
+ void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) {
216
+ if(this->disable) {
217
+ return;
218
+ }
219
+
220
+ this->adjust_with_origin(q, d0, d1);
221
+
222
+ Result & res = results[q];
223
+ uint32_t lt_mask = this->get_lt_mask(res.val, b, d0, d1);
224
+ if (!lt_mask) {
225
+ return;
226
+ }
227
+
228
+ ALIGNED(32) uint16_t d32tab[32];
229
+ d0.store(d32tab);
230
+ d1.store(d32tab + 16);
231
+
232
+ while (lt_mask) {
233
+ // find first non-zero
234
+ int j = __builtin_ctz(lt_mask);
235
+ lt_mask -= 1 << j;
236
+ T dis = d32tab[j];
237
+ if (C::cmp(res.val, dis)) {
238
+ res.val = dis;
239
+ res.id = this->adjust_id(b, j);
240
+ }
241
+ }
242
+ }
243
+
244
+ void to_flat_arrays(
245
+ float *distances, int64_t *labels,
246
+ const float *normalizers = nullptr
247
+ ) override {
248
+ for (int q = 0; q < results.size(); q++) {
249
+ if (!normalizers) {
250
+ distances[q] = results[q].val;
251
+ } else {
252
+ float one_a = 1 / normalizers[2 * q];
253
+ float b = normalizers[2 * q + 1];
254
+ distances[q] = b + results[q].val * one_a;
255
+ }
256
+ labels[q] = results[q].id;
257
+ }
258
+ }
259
+
260
+ };
261
+
262
+ /** Structure that collects results in a min- or max-heap */
263
+ template<class C, bool with_id_map = false>
264
+ struct HeapHandler: SIMDResultHandler<C, with_id_map> {
265
+ using T = typename C::T;
266
+ using TI = typename C::TI;
267
+
268
+ int nq;
269
+ T *heap_dis_tab;
270
+ TI *heap_ids_tab;
271
+
272
+ int64_t k; // number of results to keep
273
+
274
+ HeapHandler(
275
+ int nq,
276
+ T * heap_dis_tab, TI * heap_ids_tab,
277
+ size_t k, size_t ntotal
278
+ ):
279
+ SIMDResultHandler<C, with_id_map>(ntotal), nq(nq),
280
+ heap_dis_tab(heap_dis_tab), heap_ids_tab(heap_ids_tab), k(k)
281
+ {
282
+ for (int q = 0; q < nq; q++) {
283
+ T *heap_dis_in = heap_dis_tab + q * k;
284
+ TI *heap_ids_in = heap_ids_tab + q * k;
285
+ heap_heapify<C> (k, heap_dis_in, heap_ids_in);
286
+ }
287
+ }
288
+
289
+ void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) {
290
+ if(this->disable) {
291
+ return;
292
+ }
293
+
294
+ this->adjust_with_origin(q, d0, d1);
295
+
296
+ T *heap_dis = heap_dis_tab + q * k;
297
+ TI *heap_ids = heap_ids_tab + q * k;
298
+
299
+ uint16_t cur_thresh = heap_dis[0] < 65536 ? (uint16_t)(heap_dis[0]) :
300
+ 0xffff;
301
+
302
+ // here we handle the reverse comparison case as well
303
+ uint32_t lt_mask = this->get_lt_mask(cur_thresh, b, d0, d1);
304
+
305
+ if (!lt_mask) {
306
+ return;
307
+ }
308
+
309
+ ALIGNED(32) uint16_t d32tab[32] ;
310
+ d0.store(d32tab);
311
+ d1.store(d32tab + 16);
312
+
313
+ while (lt_mask) {
314
+ // find first non-zero
315
+ int j = __builtin_ctz(lt_mask);
316
+ lt_mask -= 1 << j;
317
+ T dis = d32tab[j];
318
+ if (C::cmp(heap_dis[0], dis)) {
319
+ int64_t idx = this->adjust_id(b, j);
320
+ heap_pop<C>(k, heap_dis, heap_ids);
321
+ heap_push<C>(k, heap_dis, heap_ids, dis, idx);
322
+ }
323
+ }
324
+
325
+ }
326
+
327
+ void to_flat_arrays(
328
+ float *distances, int64_t *labels,
329
+ const float *normalizers = nullptr
330
+ ) override {
331
+
332
+ for (int q = 0; q < nq; q++) {
333
+ T *heap_dis_in = heap_dis_tab + q * k;
334
+ TI *heap_ids_in = heap_ids_tab + q * k;
335
+ heap_reorder<C> (k, heap_dis_in, heap_ids_in);
336
+ int64_t *heap_ids = labels + q * k;
337
+ float *heap_dis = distances + q * k;
338
+
339
+ float one_a = 1.0, b = 0.0;
340
+ if (normalizers) {
341
+ one_a = 1 / normalizers[2 * q];
342
+ b = normalizers[2 * q + 1];
343
+ }
344
+ for (int j = 0; j < k; j++) {
345
+ heap_ids[j] = heap_ids_in[j];
346
+ heap_dis[j] = heap_dis_in[j] * one_a + b;
347
+ }
348
+ }
349
+ }
350
+
351
+ };
352
+
353
+
354
+ /** Simple top-N implementation using a reservoir.
355
+ *
356
+ * Results are stored when they are below the threshold until the capacity is
357
+ * reached. Then a partition sort is used to update the threshold. */
358
+
359
+ namespace {
360
+
361
+ uint64_t get_cy () {
362
+ #ifdef MICRO_BENCHMARK
363
+ uint32_t high, low;
364
+ asm volatile("rdtsc \n\t"
365
+ : "=a" (low),
366
+ "=d" (high));
367
+ return ((uint64_t)high << 32) | (low);
368
+ #else
369
+ return 0;
370
+ #endif
371
+ }
372
+
373
+ } // anonymous namespace
374
+
375
+ template<class C>
376
+ struct ReservoirTopN {
377
+ using T = typename C::T;
378
+ using TI = typename C::TI;
379
+
380
+ T *vals;
381
+ TI *ids;
382
+
383
+ size_t i; // number of stored elements
384
+ size_t n; // number of requested elements
385
+ size_t capacity; // size of storage
386
+ size_t cycles = 0;
387
+
388
+ T threshold; // current threshold
389
+
390
+ ReservoirTopN(
391
+ size_t n, size_t capacity,
392
+ T *vals, TI *ids
393
+ ):
394
+ vals(vals), ids(ids),
395
+ i(0), n(n), capacity(capacity) {
396
+ assert(n < capacity);
397
+ threshold = C::neutral();
398
+ }
399
+
400
+ void add(T val, TI id) {
401
+ if (C::cmp(threshold, val)) {
402
+ if (i == capacity) {
403
+ shrink_fuzzy();
404
+ }
405
+ vals[i] = val;
406
+ ids[i] = id;
407
+ i++;
408
+ }
409
+ }
410
+
411
+ /// shrink number of stored elements to n
412
+ void shrink_xx() {
413
+ uint64_t t0 = get_cy();
414
+ qselect (vals, ids, i, n);
415
+ i = n; // forget all elements above i = n
416
+ threshold = C::Crev::neutral();
417
+ for(size_t j = 0; j < n; j++) {
418
+ if(C::cmp(vals[j], threshold)) {
419
+ threshold = vals[j];
420
+ }
421
+ }
422
+ cycles += get_cy() - t0;
423
+ }
424
+
425
+ void shrink() {
426
+ uint64_t t0 = get_cy();
427
+ threshold = partition<C>(vals, ids, i, n);
428
+ i = n;
429
+ cycles += get_cy() - t0;
430
+ }
431
+
432
+ void shrink_fuzzy() {
433
+ uint64_t t0 = get_cy();
434
+ assert(i == capacity);
435
+ threshold = partition_fuzzy<C>(
436
+ vals, ids, capacity, n, (capacity + n) / 2,
437
+ &i);
438
+ cycles += get_cy() - t0;
439
+ }
440
+ };
441
+
442
+
443
+ /** Handler built from several ReservoirTopN (one per query) */
444
+ template<class C, bool with_id_map = false>
445
+ struct ReservoirHandler: SIMDResultHandler<C, with_id_map> {
446
+ using T = typename C::T;
447
+ using TI = typename C::TI;
448
+
449
+ size_t capacity; // rounded up to multiple of 16
450
+ std::vector<TI> all_ids;
451
+ AlignedTable<T> all_vals;
452
+
453
+ std::vector<ReservoirTopN<C>> reservoirs;
454
+
455
+ uint64_t times[4];
456
+
457
+ ReservoirHandler(size_t nq, size_t ntotal, size_t n, size_t capacity_in):
458
+ SIMDResultHandler<C, with_id_map>(ntotal), capacity((capacity_in + 15) & ~15),
459
+ all_ids(nq * capacity), all_vals(nq * capacity)
460
+ {
461
+ assert(capacity % 16 == 0);
462
+ for (size_t i = 0; i < nq; i++) {
463
+ reservoirs.emplace_back(
464
+ n, capacity,
465
+ all_vals.get() + i * capacity,
466
+ all_ids.data() + i * capacity
467
+ );
468
+ }
469
+ times[0] = times[1] = times[2] = times[3] = 0;
470
+ }
471
+
472
+
473
+ void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) {
474
+ uint64_t t0 = get_cy();
475
+ if(this->disable) {
476
+ return;
477
+ }
478
+ this->adjust_with_origin(q, d0, d1);
479
+
480
+ ReservoirTopN<C> & res = reservoirs[q];
481
+ uint32_t lt_mask = this->get_lt_mask(res.threshold, b, d0, d1);
482
+ uint64_t t1 = get_cy();
483
+ times[0] += t1 - t0;
484
+
485
+ if (!lt_mask) {
486
+ return;
487
+ }
488
+ ALIGNED(32) uint16_t d32tab[32];
489
+ d0.store(d32tab);
490
+ d1.store(d32tab + 16);
491
+
492
+ while (lt_mask) {
493
+ // find first non-zero
494
+ int j = __builtin_ctz(lt_mask);
495
+ lt_mask -= 1 << j;
496
+ T dis = d32tab[j];
497
+ res.add(dis, this->adjust_id(b, j));
498
+ }
499
+ times[1] += get_cy() - t1;
500
+ }
501
+
502
+
503
+ void to_flat_arrays(
504
+ float *distances, int64_t *labels,
505
+ const float *normalizers = nullptr
506
+ ) override {
507
+ using Cf = typename std::conditional<
508
+ C::is_max,
509
+ CMax<float, int64_t>, CMin<float, int64_t>>::type;
510
+
511
+ uint64_t t0 = get_cy();
512
+ uint64_t t3 = 0;
513
+ std::vector<int> perm(reservoirs[0].n);
514
+ for (int q = 0; q < reservoirs.size(); q++) {
515
+ ReservoirTopN<C> & res = reservoirs[q];
516
+ size_t n = res.n;
517
+
518
+ if (res.i > res.n) {
519
+ res.shrink();
520
+ }
521
+ int64_t *heap_ids = labels + q * n;
522
+ float *heap_dis = distances + q * n;
523
+
524
+ float one_a = 1.0, b = 0.0;
525
+ if (normalizers) {
526
+ one_a = 1 / normalizers[2 * q];
527
+ b = normalizers[2 * q + 1];
528
+ }
529
+ for (int i = 0; i < res.i; i++) {
530
+ perm[i] = i;
531
+ }
532
+ // indirect sort of result arrays
533
+ std::sort(
534
+ perm.begin(), perm.begin() + res.i,
535
+ [&res](int i, int j) {
536
+ return C::cmp(res.vals[j], res.vals[i]);
537
+ }
538
+ );
539
+ for (int i = 0; i < res.i; i++) {
540
+ heap_dis[i] = res.vals[perm[i]] * one_a + b;
541
+ heap_ids[i] = res.ids[perm[i]];
542
+ }
543
+
544
+ // possibly add empty results
545
+ heap_heapify<Cf> (n - res.i, heap_dis + res.i, heap_ids + res.i);
546
+
547
+ t3 += res.cycles;
548
+ }
549
+ times[2] += get_cy() - t0;
550
+ times[3] += t3;
551
+ }
552
+
553
+ };
554
+
555
+
556
+ } // namespace simd_result_handlers
557
+
558
+
559
+ } // namespace faiss