faiss 0.1.3 → 0.1.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (184)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +4 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +1 -1
  5. data/ext/faiss/extconf.rb +1 -1
  6. data/lib/faiss/version.rb +1 -1
  7. data/vendor/faiss/faiss/AutoTune.cpp +36 -33
  8. data/vendor/faiss/faiss/AutoTune.h +6 -3
  9. data/vendor/faiss/faiss/Clustering.cpp +16 -12
  10. data/vendor/faiss/faiss/Index.cpp +3 -4
  11. data/vendor/faiss/faiss/Index.h +3 -3
  12. data/vendor/faiss/faiss/IndexBinary.cpp +3 -4
  13. data/vendor/faiss/faiss/IndexBinary.h +1 -1
  14. data/vendor/faiss/faiss/IndexBinaryHash.cpp +2 -12
  15. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +1 -2
  16. data/vendor/faiss/faiss/IndexFlat.cpp +0 -148
  17. data/vendor/faiss/faiss/IndexFlat.h +0 -51
  18. data/vendor/faiss/faiss/IndexHNSW.cpp +4 -5
  19. data/vendor/faiss/faiss/IndexIVF.cpp +118 -31
  20. data/vendor/faiss/faiss/IndexIVF.h +22 -15
  21. data/vendor/faiss/faiss/IndexIVFFlat.cpp +3 -3
  22. data/vendor/faiss/faiss/IndexIVFFlat.h +2 -1
  23. data/vendor/faiss/faiss/IndexIVFPQ.cpp +39 -15
  24. data/vendor/faiss/faiss/IndexIVFPQ.h +25 -9
  25. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +1116 -0
  26. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +166 -0
  27. data/vendor/faiss/faiss/IndexIVFPQR.cpp +8 -9
  28. data/vendor/faiss/faiss/IndexIVFPQR.h +2 -1
  29. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +1 -2
  30. data/vendor/faiss/faiss/IndexPQ.cpp +34 -18
  31. data/vendor/faiss/faiss/IndexPQFastScan.cpp +536 -0
  32. data/vendor/faiss/faiss/IndexPQFastScan.h +111 -0
  33. data/vendor/faiss/faiss/IndexPreTransform.cpp +47 -0
  34. data/vendor/faiss/faiss/IndexPreTransform.h +2 -0
  35. data/vendor/faiss/faiss/IndexRefine.cpp +256 -0
  36. data/vendor/faiss/faiss/IndexRefine.h +73 -0
  37. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +2 -2
  38. data/vendor/faiss/faiss/IndexScalarQuantizer.h +1 -1
  39. data/vendor/faiss/faiss/gpu/GpuDistance.h +1 -1
  40. data/vendor/faiss/faiss/gpu/GpuIndex.h +16 -9
  41. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +8 -1
  42. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +11 -11
  43. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +19 -2
  44. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +28 -2
  45. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +24 -14
  46. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +29 -2
  47. data/vendor/faiss/faiss/gpu/GpuResources.h +4 -0
  48. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +60 -27
  49. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +28 -6
  50. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +547 -0
  51. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +51 -0
  52. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +3 -3
  53. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +3 -2
  54. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +274 -0
  55. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +7 -2
  56. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +5 -1
  57. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +231 -0
  58. data/vendor/faiss/faiss/gpu/test/TestUtils.h +33 -0
  59. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +1 -0
  60. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +6 -0
  61. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +5 -6
  62. data/vendor/faiss/faiss/gpu/utils/Timer.h +2 -2
  63. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +5 -4
  64. data/vendor/faiss/faiss/impl/HNSW.cpp +2 -4
  65. data/vendor/faiss/faiss/impl/PolysemousTraining.h +4 -4
  66. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +22 -12
  67. data/vendor/faiss/faiss/impl/ProductQuantizer.h +2 -0
  68. data/vendor/faiss/faiss/impl/ResultHandler.h +452 -0
  69. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +29 -19
  70. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +6 -0
  71. data/vendor/faiss/faiss/impl/index_read.cpp +64 -96
  72. data/vendor/faiss/faiss/impl/index_write.cpp +34 -25
  73. data/vendor/faiss/faiss/impl/io.cpp +33 -2
  74. data/vendor/faiss/faiss/impl/io.h +7 -2
  75. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -15
  76. data/vendor/faiss/faiss/impl/platform_macros.h +44 -0
  77. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +272 -0
  78. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +169 -0
  79. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +180 -0
  80. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +354 -0
  81. data/vendor/faiss/faiss/impl/simd_result_handlers.h +559 -0
  82. data/vendor/faiss/faiss/index_factory.cpp +112 -7
  83. data/vendor/faiss/faiss/index_io.h +1 -48
  84. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +151 -0
  85. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +76 -0
  86. data/vendor/faiss/faiss/{DirectMap.cpp → invlists/DirectMap.cpp} +1 -1
  87. data/vendor/faiss/faiss/{DirectMap.h → invlists/DirectMap.h} +1 -1
  88. data/vendor/faiss/faiss/{InvertedLists.cpp → invlists/InvertedLists.cpp} +72 -1
  89. data/vendor/faiss/faiss/{InvertedLists.h → invlists/InvertedLists.h} +32 -1
  90. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +107 -0
  91. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +63 -0
  92. data/vendor/faiss/faiss/{OnDiskInvertedLists.cpp → invlists/OnDiskInvertedLists.cpp} +21 -6
  93. data/vendor/faiss/faiss/{OnDiskInvertedLists.h → invlists/OnDiskInvertedLists.h} +5 -2
  94. data/vendor/faiss/faiss/python/python_callbacks.h +8 -1
  95. data/vendor/faiss/faiss/utils/AlignedTable.h +141 -0
  96. data/vendor/faiss/faiss/utils/Heap.cpp +2 -4
  97. data/vendor/faiss/faiss/utils/Heap.h +61 -50
  98. data/vendor/faiss/faiss/utils/distances.cpp +164 -319
  99. data/vendor/faiss/faiss/utils/distances.h +28 -20
  100. data/vendor/faiss/faiss/utils/distances_simd.cpp +277 -49
  101. data/vendor/faiss/faiss/utils/extra_distances.cpp +1 -2
  102. data/vendor/faiss/faiss/utils/hamming-inl.h +4 -4
  103. data/vendor/faiss/faiss/utils/hamming.cpp +3 -6
  104. data/vendor/faiss/faiss/utils/hamming.h +2 -7
  105. data/vendor/faiss/faiss/utils/ordered_key_value.h +98 -0
  106. data/vendor/faiss/faiss/utils/partitioning.cpp +1256 -0
  107. data/vendor/faiss/faiss/utils/partitioning.h +69 -0
  108. data/vendor/faiss/faiss/utils/quantize_lut.cpp +277 -0
  109. data/vendor/faiss/faiss/utils/quantize_lut.h +80 -0
  110. data/vendor/faiss/faiss/utils/simdlib.h +31 -0
  111. data/vendor/faiss/faiss/utils/simdlib_avx2.h +461 -0
  112. data/vendor/faiss/faiss/utils/simdlib_emulated.h +589 -0
  113. metadata +43 -141
  114. data/vendor/faiss/benchs/bench_6bit_codec.cpp +0 -80
  115. data/vendor/faiss/c_api/AutoTune_c.cpp +0 -83
  116. data/vendor/faiss/c_api/AutoTune_c.h +0 -66
  117. data/vendor/faiss/c_api/Clustering_c.cpp +0 -145
  118. data/vendor/faiss/c_api/Clustering_c.h +0 -123
  119. data/vendor/faiss/c_api/IndexFlat_c.cpp +0 -140
  120. data/vendor/faiss/c_api/IndexFlat_c.h +0 -115
  121. data/vendor/faiss/c_api/IndexIVFFlat_c.cpp +0 -64
  122. data/vendor/faiss/c_api/IndexIVFFlat_c.h +0 -58
  123. data/vendor/faiss/c_api/IndexIVF_c.cpp +0 -99
  124. data/vendor/faiss/c_api/IndexIVF_c.h +0 -142
  125. data/vendor/faiss/c_api/IndexLSH_c.cpp +0 -37
  126. data/vendor/faiss/c_api/IndexLSH_c.h +0 -40
  127. data/vendor/faiss/c_api/IndexPreTransform_c.cpp +0 -21
  128. data/vendor/faiss/c_api/IndexPreTransform_c.h +0 -32
  129. data/vendor/faiss/c_api/IndexShards_c.cpp +0 -38
  130. data/vendor/faiss/c_api/IndexShards_c.h +0 -39
  131. data/vendor/faiss/c_api/Index_c.cpp +0 -105
  132. data/vendor/faiss/c_api/Index_c.h +0 -183
  133. data/vendor/faiss/c_api/MetaIndexes_c.cpp +0 -49
  134. data/vendor/faiss/c_api/MetaIndexes_c.h +0 -49
  135. data/vendor/faiss/c_api/clone_index_c.cpp +0 -23
  136. data/vendor/faiss/c_api/clone_index_c.h +0 -32
  137. data/vendor/faiss/c_api/error_c.h +0 -42
  138. data/vendor/faiss/c_api/error_impl.cpp +0 -27
  139. data/vendor/faiss/c_api/error_impl.h +0 -16
  140. data/vendor/faiss/c_api/faiss_c.h +0 -58
  141. data/vendor/faiss/c_api/gpu/GpuAutoTune_c.cpp +0 -98
  142. data/vendor/faiss/c_api/gpu/GpuAutoTune_c.h +0 -56
  143. data/vendor/faiss/c_api/gpu/GpuClonerOptions_c.cpp +0 -52
  144. data/vendor/faiss/c_api/gpu/GpuClonerOptions_c.h +0 -68
  145. data/vendor/faiss/c_api/gpu/GpuIndex_c.cpp +0 -17
  146. data/vendor/faiss/c_api/gpu/GpuIndex_c.h +0 -30
  147. data/vendor/faiss/c_api/gpu/GpuIndicesOptions_c.h +0 -38
  148. data/vendor/faiss/c_api/gpu/GpuResources_c.cpp +0 -86
  149. data/vendor/faiss/c_api/gpu/GpuResources_c.h +0 -66
  150. data/vendor/faiss/c_api/gpu/StandardGpuResources_c.cpp +0 -54
  151. data/vendor/faiss/c_api/gpu/StandardGpuResources_c.h +0 -53
  152. data/vendor/faiss/c_api/gpu/macros_impl.h +0 -42
  153. data/vendor/faiss/c_api/impl/AuxIndexStructures_c.cpp +0 -220
  154. data/vendor/faiss/c_api/impl/AuxIndexStructures_c.h +0 -149
  155. data/vendor/faiss/c_api/index_factory_c.cpp +0 -26
  156. data/vendor/faiss/c_api/index_factory_c.h +0 -30
  157. data/vendor/faiss/c_api/index_io_c.cpp +0 -42
  158. data/vendor/faiss/c_api/index_io_c.h +0 -50
  159. data/vendor/faiss/c_api/macros_impl.h +0 -110
  160. data/vendor/faiss/demos/demo_imi_flat.cpp +0 -154
  161. data/vendor/faiss/demos/demo_imi_pq.cpp +0 -203
  162. data/vendor/faiss/demos/demo_ivfpq_indexing.cpp +0 -151
  163. data/vendor/faiss/demos/demo_sift1M.cpp +0 -252
  164. data/vendor/faiss/demos/demo_weighted_kmeans.cpp +0 -185
  165. data/vendor/faiss/misc/test_blas.cpp +0 -87
  166. data/vendor/faiss/tests/test_binary_flat.cpp +0 -62
  167. data/vendor/faiss/tests/test_dealloc_invlists.cpp +0 -188
  168. data/vendor/faiss/tests/test_ivfpq_codec.cpp +0 -70
  169. data/vendor/faiss/tests/test_ivfpq_indexing.cpp +0 -100
  170. data/vendor/faiss/tests/test_lowlevel_ivf.cpp +0 -573
  171. data/vendor/faiss/tests/test_merge.cpp +0 -260
  172. data/vendor/faiss/tests/test_omp_threads.cpp +0 -14
  173. data/vendor/faiss/tests/test_ondisk_ivf.cpp +0 -225
  174. data/vendor/faiss/tests/test_pairs_decoding.cpp +0 -193
  175. data/vendor/faiss/tests/test_params_override.cpp +0 -236
  176. data/vendor/faiss/tests/test_pq_encoding.cpp +0 -98
  177. data/vendor/faiss/tests/test_sliding_ivf.cpp +0 -246
  178. data/vendor/faiss/tests/test_threaded_index.cpp +0 -253
  179. data/vendor/faiss/tests/test_transfer_invlists.cpp +0 -159
  180. data/vendor/faiss/tutorial/cpp/1-Flat.cpp +0 -104
  181. data/vendor/faiss/tutorial/cpp/2-IVFFlat.cpp +0 -85
  182. data/vendor/faiss/tutorial/cpp/3-IVFPQ.cpp +0 -98
  183. data/vendor/faiss/tutorial/cpp/4-GPU.cpp +0 -122
  184. data/vendor/faiss/tutorial/cpp/5-Multiple-GPUs.cpp +0 -104
@@ -0,0 +1,559 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ #pragma once
9
+
10
+ #include <vector>
11
+ #include <algorithm>
12
+ #include <type_traits>
13
+
14
+ #include <faiss/utils/Heap.h>
15
+ #include <faiss/utils/simdlib.h>
16
+
17
+ #include <faiss/utils/AlignedTable.h>
18
+ #include <faiss/utils/partitioning.h>
19
+ #include <faiss/impl/platform_macros.h>
20
+
21
/** This file contains callbacks for kernels that compute distances.
 *
 * The SIMDResultHandler object is intended to be templated and inlined.
 * Methods:
 * - handle(): called when 32 distances have been computed; they are passed
 *   as two simd16uint16 values. (q, b) locate them in the current block:
 *   q is the query index and b the index of the 32-element database chunk.
 * - set_block_origin(): set the sub-matrix that is being computed
 */
29
+
30
+ namespace faiss {
31
+
32
+ namespace simd_result_handlers {
33
+
34
+
35
+ /** Dummy structure that just computes a checksum on results
36
+ * (to avoid the computation to be optimized away) */
37
+ struct DummyResultHandler {
38
+ size_t cs = 0;
39
+
40
+ void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) {
41
+ cs += q * 123 + b * 789 + d0.get_scalar_0() + d1.get_scalar_0();
42
+ }
43
+
44
+ void set_block_origin(size_t, size_t) {
45
+ }
46
+ };
47
+
48
+ /** memorize results in a nq-by-nb matrix.
49
+ *
50
+ * j0 is the current upper-left block of the matrix
51
+ */
52
+ struct StoreResultHandler {
53
+ uint16_t *data;
54
+ size_t ld; // total number of columns
55
+ size_t i0 = 0;
56
+ size_t j0 = 0;
57
+
58
+ StoreResultHandler(uint16_t *data, size_t ld):
59
+ data(data), ld(ld) {
60
+ }
61
+
62
+ void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) {
63
+ size_t ofs = (q + i0) * ld + j0 + b * 32;
64
+ d0.store(data + ofs);
65
+ d1.store(data + ofs + 16);
66
+ }
67
+
68
+ void set_block_origin(size_t i0, size_t j0) {
69
+ this->i0 = i0;
70
+ this->j0 = j0;
71
+ }
72
+ };
73
+
74
+
75
+ /** stores results in fixed-size matrix. */
76
+ template<int NQ, int BB>
77
+ struct FixedStorageHandler {
78
+ simd16uint16 dis[NQ][BB];
79
+ int i0 = 0;
80
+
81
+ void handle(int q, int b, simd16uint16 d0, simd16uint16 d1) {
82
+ dis[q + i0][2 * b] = d0;
83
+ dis[q + i0][2 * b + 1] = d1;
84
+ }
85
+
86
+ void set_block_origin(size_t i0, size_t j0) {
87
+ this->i0 = i0;
88
+ assert(j0 == 0);
89
+ }
90
+
91
+ template<class OtherResultHandler>
92
+ void to_other_handler(OtherResultHandler & other) const {
93
+ for (int q = 0; q < NQ; q++) {
94
+ for(int b = 0; b < BB; b += 2) {
95
+ other.handle(q, b / 2, dis[q][b], dis[q][b + 1]);
96
+ }
97
+ }
98
+ }
99
+
100
+ };
101
+
102
+
103
+ /** Record origin of current block */
104
+ template<class C, bool with_id_map>
105
+ struct SIMDResultHandler {
106
+ using TI = typename C::TI;
107
+
108
+ bool disable = false;
109
+
110
+ int64_t i0 = 0; // query origin
111
+ int64_t j0 = 0; // db origin
112
+ size_t ntotal; // ignore excess elements after ntotal
113
+
114
+ /// these fields are used mainly for the IVF variants (with_id_map=true)
115
+ const TI *id_map; // map offset in invlist to vector id
116
+ const int *q_map; // map q to global query
117
+ const uint16_t *dbias; // table of biases to add to each query
118
+
119
+ explicit SIMDResultHandler(size_t ntotal):
120
+ ntotal(ntotal), id_map(nullptr), q_map(nullptr), dbias(nullptr)
121
+ {}
122
+
123
+ void set_block_origin(size_t i0, size_t j0) {
124
+ this->i0 = i0;
125
+ this->j0 = j0;
126
+ }
127
+
128
+
129
+ // adjust handler data for IVF.
130
+ void adjust_with_origin(size_t & q, simd16uint16 & d0, simd16uint16 & d1)
131
+ {
132
+ q += i0;
133
+
134
+ if (dbias) {
135
+ simd16uint16 dbias16(dbias[q]);
136
+ d0 += dbias16;
137
+ d1 += dbias16;
138
+ }
139
+
140
+ if (with_id_map) { // FIXME test on q_map instead
141
+ q = q_map[q];
142
+ }
143
+ }
144
+
145
+ // compute and adjust idx
146
+ int64_t adjust_id(size_t b, size_t j) {
147
+ int64_t idx = j0 + 32 * b + j;
148
+ if (with_id_map) {
149
+ idx = id_map[idx];
150
+ }
151
+ return idx;
152
+ }
153
+
154
+ /// return binary mask of elements below thr in (d0, d1)
155
+ /// inverse_test returns elements above
156
+ uint32_t get_lt_mask(
157
+ uint16_t thr, size_t b,
158
+ simd16uint16 d0, simd16uint16 d1
159
+ ) {
160
+ simd16uint16 thr16(thr);
161
+ uint32_t lt_mask;
162
+
163
+ constexpr bool keep_min = C::is_max;
164
+ if (keep_min) {
165
+ lt_mask = ~cmp_ge32(d0, d1, thr16);
166
+ } else {
167
+ lt_mask = ~cmp_le32(d0, d1, thr16);
168
+ }
169
+
170
+ if (lt_mask == 0) {
171
+ return 0;
172
+ }
173
+ uint64_t idx = j0 + b * 32;
174
+ if (idx + 32 > ntotal) {
175
+ if (idx >= ntotal) {
176
+ return 0;
177
+ }
178
+ int nbit = (ntotal - idx);
179
+ lt_mask &= (uint32_t(1) << nbit) - 1;
180
+ }
181
+ return lt_mask;
182
+ }
183
+
184
+ virtual void to_flat_arrays(
185
+ float *distances, int64_t *labels,
186
+ const float *normalizers = nullptr
187
+ ) = 0;
188
+
189
+ virtual ~SIMDResultHandler() {}
190
+
191
+ };
192
+
193
+
194
+ /** Special version for k=1 */
195
+ template<class C, bool with_id_map = false>
196
+ struct SingleResultHandler: SIMDResultHandler<C, with_id_map> {
197
+ using T = typename C::T;
198
+ using TI = typename C::TI;
199
+
200
+ struct Result {
201
+ T val;
202
+ TI id;
203
+ };
204
+ std::vector<Result> results;
205
+
206
+ SingleResultHandler(size_t nq, size_t ntotal):
207
+ SIMDResultHandler<C, with_id_map>(ntotal), results(nq)
208
+ {
209
+ for (int i = 0; i < nq; i++) {
210
+ Result res = {C::neutral(), -1};
211
+ results[i] = res;
212
+ }
213
+ }
214
+
215
+ void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) {
216
+ if(this->disable) {
217
+ return;
218
+ }
219
+
220
+ this->adjust_with_origin(q, d0, d1);
221
+
222
+ Result & res = results[q];
223
+ uint32_t lt_mask = this->get_lt_mask(res.val, b, d0, d1);
224
+ if (!lt_mask) {
225
+ return;
226
+ }
227
+
228
+ ALIGNED(32) uint16_t d32tab[32];
229
+ d0.store(d32tab);
230
+ d1.store(d32tab + 16);
231
+
232
+ while (lt_mask) {
233
+ // find first non-zero
234
+ int j = __builtin_ctz(lt_mask);
235
+ lt_mask -= 1 << j;
236
+ T dis = d32tab[j];
237
+ if (C::cmp(res.val, dis)) {
238
+ res.val = dis;
239
+ res.id = this->adjust_id(b, j);
240
+ }
241
+ }
242
+ }
243
+
244
+ void to_flat_arrays(
245
+ float *distances, int64_t *labels,
246
+ const float *normalizers = nullptr
247
+ ) override {
248
+ for (int q = 0; q < results.size(); q++) {
249
+ if (!normalizers) {
250
+ distances[q] = results[q].val;
251
+ } else {
252
+ float one_a = 1 / normalizers[2 * q];
253
+ float b = normalizers[2 * q + 1];
254
+ distances[q] = b + results[q].val * one_a;
255
+ }
256
+ labels[q] = results[q].id;
257
+ }
258
+ }
259
+
260
+ };
261
+
262
+ /** Structure that collects results in a min- or max-heap */
263
+ template<class C, bool with_id_map = false>
264
+ struct HeapHandler: SIMDResultHandler<C, with_id_map> {
265
+ using T = typename C::T;
266
+ using TI = typename C::TI;
267
+
268
+ int nq;
269
+ T *heap_dis_tab;
270
+ TI *heap_ids_tab;
271
+
272
+ int64_t k; // number of results to keep
273
+
274
+ HeapHandler(
275
+ int nq,
276
+ T * heap_dis_tab, TI * heap_ids_tab,
277
+ size_t k, size_t ntotal
278
+ ):
279
+ SIMDResultHandler<C, with_id_map>(ntotal), nq(nq),
280
+ heap_dis_tab(heap_dis_tab), heap_ids_tab(heap_ids_tab), k(k)
281
+ {
282
+ for (int q = 0; q < nq; q++) {
283
+ T *heap_dis_in = heap_dis_tab + q * k;
284
+ TI *heap_ids_in = heap_ids_tab + q * k;
285
+ heap_heapify<C> (k, heap_dis_in, heap_ids_in);
286
+ }
287
+ }
288
+
289
+ void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) {
290
+ if(this->disable) {
291
+ return;
292
+ }
293
+
294
+ this->adjust_with_origin(q, d0, d1);
295
+
296
+ T *heap_dis = heap_dis_tab + q * k;
297
+ TI *heap_ids = heap_ids_tab + q * k;
298
+
299
+ uint16_t cur_thresh = heap_dis[0] < 65536 ? (uint16_t)(heap_dis[0]) :
300
+ 0xffff;
301
+
302
+ // here we handle the reverse comparison case as well
303
+ uint32_t lt_mask = this->get_lt_mask(cur_thresh, b, d0, d1);
304
+
305
+ if (!lt_mask) {
306
+ return;
307
+ }
308
+
309
+ ALIGNED(32) uint16_t d32tab[32] ;
310
+ d0.store(d32tab);
311
+ d1.store(d32tab + 16);
312
+
313
+ while (lt_mask) {
314
+ // find first non-zero
315
+ int j = __builtin_ctz(lt_mask);
316
+ lt_mask -= 1 << j;
317
+ T dis = d32tab[j];
318
+ if (C::cmp(heap_dis[0], dis)) {
319
+ int64_t idx = this->adjust_id(b, j);
320
+ heap_pop<C>(k, heap_dis, heap_ids);
321
+ heap_push<C>(k, heap_dis, heap_ids, dis, idx);
322
+ }
323
+ }
324
+
325
+ }
326
+
327
+ void to_flat_arrays(
328
+ float *distances, int64_t *labels,
329
+ const float *normalizers = nullptr
330
+ ) override {
331
+
332
+ for (int q = 0; q < nq; q++) {
333
+ T *heap_dis_in = heap_dis_tab + q * k;
334
+ TI *heap_ids_in = heap_ids_tab + q * k;
335
+ heap_reorder<C> (k, heap_dis_in, heap_ids_in);
336
+ int64_t *heap_ids = labels + q * k;
337
+ float *heap_dis = distances + q * k;
338
+
339
+ float one_a = 1.0, b = 0.0;
340
+ if (normalizers) {
341
+ one_a = 1 / normalizers[2 * q];
342
+ b = normalizers[2 * q + 1];
343
+ }
344
+ for (int j = 0; j < k; j++) {
345
+ heap_ids[j] = heap_ids_in[j];
346
+ heap_dis[j] = heap_dis_in[j] * one_a + b;
347
+ }
348
+ }
349
+ }
350
+
351
+ };
352
+
353
+
354
/** Read the CPU timestamp counter, used to micro-benchmark the handlers.
 *
 * Only active when MICRO_BENCHMARK is defined and the target is x86 (it uses
 * the `rdtsc` instruction). Otherwise it returns 0, so every cycle counter
 * accumulated from it stays at 0.
 *
 * This is declared `inline` instead of being placed in an anonymous
 * namespace. In a header, an internal-linkage function referenced from
 * external-linkage templates becomes a different entity in every translation
 * unit, which is an ODR hazard. It also triggers "defined but not used"
 * warnings in translation units that do not call it.
 */
inline uint64_t get_cy() {
#if defined(MICRO_BENCHMARK) && (defined(__x86_64__) || defined(__i386__))
    uint32_t high, low;
    asm volatile("rdtsc \n\t"
                 : "=a" (low),
                   "=d" (high));
    return (static_cast<uint64_t>(high) << 32) | low;
#else
    return 0;
#endif
}

/** Simple top-N implementation using a reservoir.
 *
 * Results are stored when they are below the threshold until the capacity is
 * reached. Then a partition sort is used to update the threshold. */
374
+
375
+ template<class C>
376
+ struct ReservoirTopN {
377
+ using T = typename C::T;
378
+ using TI = typename C::TI;
379
+
380
+ T *vals;
381
+ TI *ids;
382
+
383
+ size_t i; // number of stored elements
384
+ size_t n; // number of requested elements
385
+ size_t capacity; // size of storage
386
+ size_t cycles = 0;
387
+
388
+ T threshold; // current threshold
389
+
390
+ ReservoirTopN(
391
+ size_t n, size_t capacity,
392
+ T *vals, TI *ids
393
+ ):
394
+ vals(vals), ids(ids),
395
+ i(0), n(n), capacity(capacity) {
396
+ assert(n < capacity);
397
+ threshold = C::neutral();
398
+ }
399
+
400
+ void add(T val, TI id) {
401
+ if (C::cmp(threshold, val)) {
402
+ if (i == capacity) {
403
+ shrink_fuzzy();
404
+ }
405
+ vals[i] = val;
406
+ ids[i] = id;
407
+ i++;
408
+ }
409
+ }
410
+
411
+ /// shrink number of stored elements to n
412
+ void shrink_xx() {
413
+ uint64_t t0 = get_cy();
414
+ qselect (vals, ids, i, n);
415
+ i = n; // forget all elements above i = n
416
+ threshold = C::Crev::neutral();
417
+ for(size_t j = 0; j < n; j++) {
418
+ if(C::cmp(vals[j], threshold)) {
419
+ threshold = vals[j];
420
+ }
421
+ }
422
+ cycles += get_cy() - t0;
423
+ }
424
+
425
+ void shrink() {
426
+ uint64_t t0 = get_cy();
427
+ threshold = partition<C>(vals, ids, i, n);
428
+ i = n;
429
+ cycles += get_cy() - t0;
430
+ }
431
+
432
+ void shrink_fuzzy() {
433
+ uint64_t t0 = get_cy();
434
+ assert(i == capacity);
435
+ threshold = partition_fuzzy<C>(
436
+ vals, ids, capacity, n, (capacity + n) / 2,
437
+ &i);
438
+ cycles += get_cy() - t0;
439
+ }
440
+ };
441
+
442
+
443
+ /** Handler built from several ReservoirTopN (one per query) */
444
+ template<class C, bool with_id_map = false>
445
+ struct ReservoirHandler: SIMDResultHandler<C, with_id_map> {
446
+ using T = typename C::T;
447
+ using TI = typename C::TI;
448
+
449
+ size_t capacity; // rounded up to multiple of 16
450
+ std::vector<TI> all_ids;
451
+ AlignedTable<T> all_vals;
452
+
453
+ std::vector<ReservoirTopN<C>> reservoirs;
454
+
455
+ uint64_t times[4];
456
+
457
+ ReservoirHandler(size_t nq, size_t ntotal, size_t n, size_t capacity_in):
458
+ SIMDResultHandler<C, with_id_map>(ntotal), capacity((capacity_in + 15) & ~15),
459
+ all_ids(nq * capacity), all_vals(nq * capacity)
460
+ {
461
+ assert(capacity % 16 == 0);
462
+ for (size_t i = 0; i < nq; i++) {
463
+ reservoirs.emplace_back(
464
+ n, capacity,
465
+ all_vals.get() + i * capacity,
466
+ all_ids.data() + i * capacity
467
+ );
468
+ }
469
+ times[0] = times[1] = times[2] = times[3] = 0;
470
+ }
471
+
472
+
473
+ void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) {
474
+ uint64_t t0 = get_cy();
475
+ if(this->disable) {
476
+ return;
477
+ }
478
+ this->adjust_with_origin(q, d0, d1);
479
+
480
+ ReservoirTopN<C> & res = reservoirs[q];
481
+ uint32_t lt_mask = this->get_lt_mask(res.threshold, b, d0, d1);
482
+ uint64_t t1 = get_cy();
483
+ times[0] += t1 - t0;
484
+
485
+ if (!lt_mask) {
486
+ return;
487
+ }
488
+ ALIGNED(32) uint16_t d32tab[32];
489
+ d0.store(d32tab);
490
+ d1.store(d32tab + 16);
491
+
492
+ while (lt_mask) {
493
+ // find first non-zero
494
+ int j = __builtin_ctz(lt_mask);
495
+ lt_mask -= 1 << j;
496
+ T dis = d32tab[j];
497
+ res.add(dis, this->adjust_id(b, j));
498
+ }
499
+ times[1] += get_cy() - t1;
500
+ }
501
+
502
+
503
+ void to_flat_arrays(
504
+ float *distances, int64_t *labels,
505
+ const float *normalizers = nullptr
506
+ ) override {
507
+ using Cf = typename std::conditional<
508
+ C::is_max,
509
+ CMax<float, int64_t>, CMin<float, int64_t>>::type;
510
+
511
+ uint64_t t0 = get_cy();
512
+ uint64_t t3 = 0;
513
+ std::vector<int> perm(reservoirs[0].n);
514
+ for (int q = 0; q < reservoirs.size(); q++) {
515
+ ReservoirTopN<C> & res = reservoirs[q];
516
+ size_t n = res.n;
517
+
518
+ if (res.i > res.n) {
519
+ res.shrink();
520
+ }
521
+ int64_t *heap_ids = labels + q * n;
522
+ float *heap_dis = distances + q * n;
523
+
524
+ float one_a = 1.0, b = 0.0;
525
+ if (normalizers) {
526
+ one_a = 1 / normalizers[2 * q];
527
+ b = normalizers[2 * q + 1];
528
+ }
529
+ for (int i = 0; i < res.i; i++) {
530
+ perm[i] = i;
531
+ }
532
+ // indirect sort of result arrays
533
+ std::sort(
534
+ perm.begin(), perm.begin() + res.i,
535
+ [&res](int i, int j) {
536
+ return C::cmp(res.vals[j], res.vals[i]);
537
+ }
538
+ );
539
+ for (int i = 0; i < res.i; i++) {
540
+ heap_dis[i] = res.vals[perm[i]] * one_a + b;
541
+ heap_ids[i] = res.ids[perm[i]];
542
+ }
543
+
544
+ // possibly add empty results
545
+ heap_heapify<Cf> (n - res.i, heap_dis + res.i, heap_ids + res.i);
546
+
547
+ t3 += res.cycles;
548
+ }
549
+ times[2] += get_cy() - t0;
550
+ times[3] += t3;
551
+ }
552
+
553
+ };
554
+
555
+
556
+ } // namespace simd_result_handlers
557
+
558
+
559
+ } // namespace faiss