faiss 0.2.4 → 0.2.5

Files changed (177)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +5 -0
  3. data/README.md +23 -21
  4. data/ext/faiss/extconf.rb +11 -0
  5. data/ext/faiss/index.cpp +4 -4
  6. data/ext/faiss/index_binary.cpp +6 -6
  7. data/ext/faiss/product_quantizer.cpp +4 -4
  8. data/lib/faiss/version.rb +1 -1
  9. data/vendor/faiss/faiss/AutoTune.cpp +13 -0
  10. data/vendor/faiss/faiss/IVFlib.cpp +101 -2
  11. data/vendor/faiss/faiss/IVFlib.h +26 -2
  12. data/vendor/faiss/faiss/Index.cpp +36 -3
  13. data/vendor/faiss/faiss/Index.h +43 -6
  14. data/vendor/faiss/faiss/Index2Layer.cpp +6 -2
  15. data/vendor/faiss/faiss/Index2Layer.h +6 -1
  16. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +219 -16
  17. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +63 -5
  18. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +299 -0
  19. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +199 -0
  20. data/vendor/faiss/faiss/IndexBinary.cpp +20 -4
  21. data/vendor/faiss/faiss/IndexBinary.h +18 -3
  22. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +9 -2
  23. data/vendor/faiss/faiss/IndexBinaryFlat.h +4 -2
  24. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +4 -1
  25. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +2 -1
  26. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +5 -1
  27. data/vendor/faiss/faiss/IndexBinaryHNSW.h +2 -1
  28. data/vendor/faiss/faiss/IndexBinaryHash.cpp +17 -4
  29. data/vendor/faiss/faiss/IndexBinaryHash.h +8 -4
  30. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +28 -13
  31. data/vendor/faiss/faiss/IndexBinaryIVF.h +10 -7
  32. data/vendor/faiss/faiss/IndexFastScan.cpp +626 -0
  33. data/vendor/faiss/faiss/IndexFastScan.h +145 -0
  34. data/vendor/faiss/faiss/IndexFlat.cpp +34 -21
  35. data/vendor/faiss/faiss/IndexFlat.h +7 -4
  36. data/vendor/faiss/faiss/IndexFlatCodes.cpp +35 -1
  37. data/vendor/faiss/faiss/IndexFlatCodes.h +12 -0
  38. data/vendor/faiss/faiss/IndexHNSW.cpp +66 -138
  39. data/vendor/faiss/faiss/IndexHNSW.h +4 -2
  40. data/vendor/faiss/faiss/IndexIDMap.cpp +247 -0
  41. data/vendor/faiss/faiss/IndexIDMap.h +107 -0
  42. data/vendor/faiss/faiss/IndexIVF.cpp +121 -33
  43. data/vendor/faiss/faiss/IndexIVF.h +35 -16
  44. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +84 -7
  45. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +63 -1
  46. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +590 -0
  47. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +171 -0
  48. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +1290 -0
  49. data/vendor/faiss/faiss/IndexIVFFastScan.h +213 -0
  50. data/vendor/faiss/faiss/IndexIVFFlat.cpp +37 -17
  51. data/vendor/faiss/faiss/IndexIVFFlat.h +4 -2
  52. data/vendor/faiss/faiss/IndexIVFPQ.cpp +234 -50
  53. data/vendor/faiss/faiss/IndexIVFPQ.h +5 -1
  54. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +23 -852
  55. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +7 -112
  56. data/vendor/faiss/faiss/IndexIVFPQR.cpp +3 -3
  57. data/vendor/faiss/faiss/IndexIVFPQR.h +1 -1
  58. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +3 -1
  59. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +2 -1
  60. data/vendor/faiss/faiss/IndexLSH.cpp +4 -2
  61. data/vendor/faiss/faiss/IndexLSH.h +2 -1
  62. data/vendor/faiss/faiss/IndexLattice.cpp +7 -1
  63. data/vendor/faiss/faiss/IndexLattice.h +3 -1
  64. data/vendor/faiss/faiss/IndexNNDescent.cpp +4 -3
  65. data/vendor/faiss/faiss/IndexNNDescent.h +2 -1
  66. data/vendor/faiss/faiss/IndexNSG.cpp +37 -3
  67. data/vendor/faiss/faiss/IndexNSG.h +25 -1
  68. data/vendor/faiss/faiss/IndexPQ.cpp +106 -69
  69. data/vendor/faiss/faiss/IndexPQ.h +19 -5
  70. data/vendor/faiss/faiss/IndexPQFastScan.cpp +15 -450
  71. data/vendor/faiss/faiss/IndexPQFastScan.h +15 -78
  72. data/vendor/faiss/faiss/IndexPreTransform.cpp +47 -8
  73. data/vendor/faiss/faiss/IndexPreTransform.h +15 -3
  74. data/vendor/faiss/faiss/IndexRefine.cpp +8 -4
  75. data/vendor/faiss/faiss/IndexRefine.h +4 -2
  76. data/vendor/faiss/faiss/IndexReplicas.cpp +4 -2
  77. data/vendor/faiss/faiss/IndexReplicas.h +2 -1
  78. data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +438 -0
  79. data/vendor/faiss/faiss/IndexRowwiseMinMax.h +92 -0
  80. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +26 -15
  81. data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -7
  82. data/vendor/faiss/faiss/IndexShards.cpp +4 -1
  83. data/vendor/faiss/faiss/IndexShards.h +2 -1
  84. data/vendor/faiss/faiss/MetaIndexes.cpp +5 -178
  85. data/vendor/faiss/faiss/MetaIndexes.h +3 -81
  86. data/vendor/faiss/faiss/VectorTransform.cpp +43 -0
  87. data/vendor/faiss/faiss/VectorTransform.h +22 -4
  88. data/vendor/faiss/faiss/clone_index.cpp +23 -1
  89. data/vendor/faiss/faiss/clone_index.h +3 -0
  90. data/vendor/faiss/faiss/cppcontrib/SaDecodeKernels.h +300 -0
  91. data/vendor/faiss/faiss/cppcontrib/detail/CoarseBitType.h +24 -0
  92. data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +195 -0
  93. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +2058 -0
  94. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +408 -0
  95. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-neon-inl.h +2147 -0
  96. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMax-inl.h +460 -0
  97. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMaxFP16-inl.h +465 -0
  98. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +1618 -0
  99. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +251 -0
  100. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-neon-inl.h +1452 -0
  101. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +1 -0
  102. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +0 -4
  103. data/vendor/faiss/faiss/gpu/GpuIndex.h +28 -4
  104. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +2 -1
  105. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +10 -8
  106. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +75 -14
  107. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +19 -32
  108. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -31
  109. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +22 -28
  110. data/vendor/faiss/faiss/gpu/GpuResources.cpp +14 -0
  111. data/vendor/faiss/faiss/gpu/GpuResources.h +16 -3
  112. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +3 -3
  113. data/vendor/faiss/faiss/gpu/impl/IndexUtils.h +32 -0
  114. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +1 -0
  115. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +311 -75
  116. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +10 -0
  117. data/vendor/faiss/faiss/gpu/test/TestUtils.h +3 -0
  118. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +2 -2
  119. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +5 -4
  120. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +116 -47
  121. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +44 -13
  122. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +0 -54
  123. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +0 -76
  124. data/vendor/faiss/faiss/impl/DistanceComputer.h +64 -0
  125. data/vendor/faiss/faiss/impl/HNSW.cpp +123 -27
  126. data/vendor/faiss/faiss/impl/HNSW.h +19 -16
  127. data/vendor/faiss/faiss/impl/IDSelector.cpp +125 -0
  128. data/vendor/faiss/faiss/impl/IDSelector.h +135 -0
  129. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +6 -28
  130. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +6 -1
  131. data/vendor/faiss/faiss/impl/LookupTableScaler.h +77 -0
  132. data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -0
  133. data/vendor/faiss/faiss/impl/NSG.cpp +1 -1
  134. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +383 -0
  135. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +154 -0
  136. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +225 -145
  137. data/vendor/faiss/faiss/impl/ProductQuantizer.h +29 -10
  138. data/vendor/faiss/faiss/impl/Quantizer.h +43 -0
  139. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +192 -36
  140. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +40 -20
  141. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -0
  142. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +97 -173
  143. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +18 -18
  144. data/vendor/faiss/faiss/impl/index_read.cpp +240 -9
  145. data/vendor/faiss/faiss/impl/index_write.cpp +237 -5
  146. data/vendor/faiss/faiss/impl/kmeans1d.cpp +6 -4
  147. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +56 -16
  148. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +25 -8
  149. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +66 -25
  150. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +75 -27
  151. data/vendor/faiss/faiss/index_factory.cpp +196 -7
  152. data/vendor/faiss/faiss/index_io.h +5 -0
  153. data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -0
  154. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +4 -1
  155. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +2 -1
  156. data/vendor/faiss/faiss/python/python_callbacks.cpp +27 -0
  157. data/vendor/faiss/faiss/python/python_callbacks.h +15 -0
  158. data/vendor/faiss/faiss/utils/Heap.h +31 -15
  159. data/vendor/faiss/faiss/utils/distances.cpp +380 -56
  160. data/vendor/faiss/faiss/utils/distances.h +113 -15
  161. data/vendor/faiss/faiss/utils/distances_simd.cpp +726 -6
  162. data/vendor/faiss/faiss/utils/extra_distances.cpp +12 -7
  163. data/vendor/faiss/faiss/utils/extra_distances.h +3 -1
  164. data/vendor/faiss/faiss/utils/fp16-fp16c.h +21 -0
  165. data/vendor/faiss/faiss/utils/fp16-inl.h +101 -0
  166. data/vendor/faiss/faiss/utils/fp16.h +11 -0
  167. data/vendor/faiss/faiss/utils/hamming-inl.h +54 -0
  168. data/vendor/faiss/faiss/utils/hamming.cpp +0 -48
  169. data/vendor/faiss/faiss/utils/ordered_key_value.h +10 -0
  170. data/vendor/faiss/faiss/utils/quantize_lut.cpp +62 -0
  171. data/vendor/faiss/faiss/utils/quantize_lut.h +20 -0
  172. data/vendor/faiss/faiss/utils/random.cpp +53 -0
  173. data/vendor/faiss/faiss/utils/random.h +5 -0
  174. data/vendor/faiss/faiss/utils/simdlib_avx2.h +4 -0
  175. data/vendor/faiss/faiss/utils/simdlib_emulated.h +6 -1
  176. data/vendor/faiss/faiss/utils/simdlib_neon.h +7 -2
  177. metadata +37 -3
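The bulk of this release is the fast-scan refactor pulled in from upstream FAISS: the new shared base classes IndexFastScan (+626 lines) and IndexIVFFastScan (+1290 lines) absorb code that was previously duplicated in IndexPQFastScan.cpp (+15 -450) and IndexIVFPQFastScan.cpp (+23 -852). For orientation, the sketch below shows how a 4-bit fast-scan IVF-PQ index is typically built and searched against the vendored C++ library. It is an illustration only: the IndexIVFPQFastScan constructor arguments (quantizer, d, nlist, M, nbits, metric, bbs) and the defaults used here are assumed from upstream FAISS and are not part of this diff.

// Minimal sketch under the assumptions stated above, not taken from this diff.
#include <cstdio>
#include <random>
#include <vector>

#include <faiss/IndexFlat.h>
#include <faiss/IndexIVFPQFastScan.h>

int main() {
    const int d = 64;       // vector dimension
    const int nlist = 100;  // number of inverted lists (coarse centroids)
    const int M = 32;       // PQ sub-quantizers
    const int nbits = 4;    // fast-scan indexes require 4-bit codes

    faiss::IndexFlatL2 quantizer(d);  // coarse quantizer
    faiss::IndexIVFPQFastScan index(&quantizer, d, nlist, M, nbits);

    // random training / database data, for illustration only
    std::mt19937 rng(123);
    std::uniform_real_distribution<float> u(0.0f, 1.0f);
    const size_t nb = 10000;
    std::vector<float> xb(nb * d);
    for (auto& v : xb) {
        v = u(rng);
    }

    index.train(nb, xb.data());
    index.add(nb, xb.data());
    index.nprobe = 16;  // inverted lists visited per query

    const int k = 5;
    std::vector<faiss::Index::idx_t> labels(k);
    std::vector<float> distances(k);
    index.search(1, xb.data(), k, distances.data(), labels.data());

    for (int i = 0; i < k; i++) {
        printf("%zd %.3f\n", size_t(labels[i]), distances[i]);
    }
    return 0;
}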
data/vendor/faiss/faiss/IndexIVFFastScan.cpp (new file)
@@ -0,0 +1,1290 @@
+ /**
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ *
+ * This source code is licensed under the MIT license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+ #include <faiss/IndexIVFFastScan.h>
+
+ #include <cassert>
+ #include <cinttypes>
+ #include <cstdio>
+ #include <set>
+
+ #include <omp.h>
+
+ #include <memory>
+
+ #include <faiss/IndexIVFPQ.h>
+ #include <faiss/impl/AuxIndexStructures.h>
+ #include <faiss/impl/FaissAssert.h>
+ #include <faiss/impl/LookupTableScaler.h>
+ #include <faiss/impl/pq4_fast_scan.h>
+ #include <faiss/impl/simd_result_handlers.h>
+ #include <faiss/invlists/BlockInvertedLists.h>
+ #include <faiss/utils/distances.h>
+ #include <faiss/utils/hamming.h>
+ #include <faiss/utils/quantize_lut.h>
+ #include <faiss/utils/utils.h>
+
+ namespace faiss {
+
+ using namespace simd_result_handlers;
+
+ inline size_t roundup(size_t a, size_t b) {
+ return (a + b - 1) / b * b;
+ }
+
+ IndexIVFFastScan::IndexIVFFastScan(
+ Index* quantizer,
+ size_t d,
+ size_t nlist,
+ size_t code_size,
+ MetricType metric)
+ : IndexIVF(quantizer, d, nlist, code_size, metric) {
+ FAISS_THROW_IF_NOT(metric == METRIC_L2 || metric == METRIC_INNER_PRODUCT);
+ }
+
+ IndexIVFFastScan::IndexIVFFastScan() {
+ bbs = 0;
+ M2 = 0;
+ is_trained = false;
+ }
+
+ void IndexIVFFastScan::init_fastscan(
+ size_t M,
+ size_t nbits,
+ size_t nlist,
+ MetricType /* metric */,
+ int bbs) {
+ FAISS_THROW_IF_NOT(bbs % 32 == 0);
+ FAISS_THROW_IF_NOT(nbits == 4);
+
+ this->M = M;
+ this->nbits = nbits;
+ this->bbs = bbs;
+ ksub = (1 << nbits);
+ M2 = roundup(M, 2);
+ code_size = M2 / 2;
+
+ is_trained = false;
+ replace_invlists(new BlockInvertedLists(nlist, bbs, bbs * M2 / 2), true);
+ }
+
+ IndexIVFFastScan::~IndexIVFFastScan() {}
+
+ /*********************************************************
+ * Code management functions
+ *********************************************************/
+
+ void IndexIVFFastScan::add_with_ids(
+ idx_t n,
+ const float* x,
+ const idx_t* xids) {
+ FAISS_THROW_IF_NOT(is_trained);
+
+ // do some blocking to avoid excessive allocs
+ constexpr idx_t bs = 65536;
+ if (n > bs) {
+ double t0 = getmillisecs();
+ for (idx_t i0 = 0; i0 < n; i0 += bs) {
+ idx_t i1 = std::min(n, i0 + bs);
+ if (verbose) {
+ double t1 = getmillisecs();
+ double elapsed_time = (t1 - t0) / 1000;
+ double total_time = 0;
+ if (i0 != 0) {
+ total_time = elapsed_time / i0 * n;
+ }
+ size_t mem = get_mem_usage_kb() / (1 << 10);
+
+ printf("IndexIVFFastScan::add_with_ids %zd/%zd, time %.2f/%.2f, RSS %zdMB\n",
+ size_t(i1),
+ size_t(n),
+ elapsed_time,
+ total_time,
+ mem);
+ }
+ add_with_ids(i1 - i0, x + i0 * d, xids ? xids + i0 : nullptr);
+ }
+ return;
+ }
+ InterruptCallback::check();
+
+ AlignedTable<uint8_t> codes(n * code_size);
+ direct_map.check_can_add(xids);
+ std::unique_ptr<idx_t[]> idx(new idx_t[n]);
+ quantizer->assign(n, x, idx.get());
+ size_t nadd = 0, nminus1 = 0;
+
+ for (size_t i = 0; i < n; i++) {
+ if (idx[i] < 0) {
+ nminus1++;
+ }
+ }
+
+ AlignedTable<uint8_t> flat_codes(n * code_size);
+ encode_vectors(n, x, idx.get(), flat_codes.get());
+
+ DirectMapAdd dm_adder(direct_map, n, xids);
+ BlockInvertedLists* bil = dynamic_cast<BlockInvertedLists*>(invlists);
+ FAISS_THROW_IF_NOT_MSG(bil, "only block inverted lists supported");
+
+ // prepare batches
+ std::vector<idx_t> order(n);
+ for (idx_t i = 0; i < n; i++) {
+ order[i] = i;
+ }
+
+ // TODO should not need stable
+ std::stable_sort(order.begin(), order.end(), [&idx](idx_t a, idx_t b) {
+ return idx[a] < idx[b];
+ });
+
+ // TODO parallelize
+ idx_t i0 = 0;
+ while (i0 < n) {
+ idx_t list_no = idx[order[i0]];
+ idx_t i1 = i0 + 1;
+ while (i1 < n && idx[order[i1]] == list_no) {
+ i1++;
+ }
+
+ if (list_no == -1) {
+ i0 = i1;
+ continue;
+ }
+
+ // make linear array
+ AlignedTable<uint8_t> list_codes((i1 - i0) * code_size);
+ size_t list_size = bil->list_size(list_no);
+
+ bil->resize(list_no, list_size + i1 - i0);
+
+ for (idx_t i = i0; i < i1; i++) {
+ size_t ofs = list_size + i - i0;
+ idx_t id = xids ? xids[order[i]] : ntotal + order[i];
+ dm_adder.add(order[i], list_no, ofs);
+ bil->ids[list_no][ofs] = id;
+ memcpy(list_codes.data() + (i - i0) * code_size,
+ flat_codes.data() + order[i] * code_size,
+ code_size);
+ nadd++;
+ }
+ pq4_pack_codes_range(
+ list_codes.data(),
+ M,
+ list_size,
+ list_size + i1 - i0,
+ bbs,
+ M2,
+ bil->codes[list_no].data());
+
+ i0 = i1;
+ }
+
+ ntotal += n;
+ }
+
+ /*********************************************************
+ * search
+ *********************************************************/
+
+ namespace {
+
+ template <class C, typename dis_t, class Scaler>
+ void estimators_from_tables_generic(
+ const IndexIVFFastScan& index,
+ const uint8_t* codes,
+ size_t ncodes,
+ const dis_t* dis_table,
+ const int64_t* ids,
+ float bias,
+ size_t k,
+ typename C::T* heap_dis,
+ int64_t* heap_ids,
+ const Scaler& scaler) {
+ using accu_t = typename C::T;
+ for (size_t j = 0; j < ncodes; ++j) {
+ BitstringReader bsr(codes + j * index.code_size, index.code_size);
+ accu_t dis = bias;
+ const dis_t* __restrict dt = dis_table;
+ for (size_t m = 0; m < index.M - scaler.nscale; m++) {
+ uint64_t c = bsr.read(index.nbits);
+ dis += dt[c];
+ dt += index.ksub;
+ }
+
+ for (size_t m = 0; m < scaler.nscale; m++) {
+ uint64_t c = bsr.read(index.nbits);
+ dis += scaler.scale_one(dt[c]);
+ dt += index.ksub;
+ }
+
+ if (C::cmp(heap_dis[0], dis)) {
+ heap_pop<C>(k, heap_dis, heap_ids);
+ heap_push<C>(k, heap_dis, heap_ids, dis, ids[j]);
+ }
+ }
+ }
+
+ using idx_t = Index::idx_t;
+ using namespace quantize_lut;
+
+ } // anonymous namespace
+
+ /*********************************************************
+ * Look-Up Table functions
+ *********************************************************/
+
+ void IndexIVFFastScan::compute_LUT_uint8(
+ size_t n,
+ const float* x,
+ const idx_t* coarse_ids,
+ const float* coarse_dis,
+ AlignedTable<uint8_t>& dis_tables,
+ AlignedTable<uint16_t>& biases,
+ float* normalizers) const {
+ AlignedTable<float> dis_tables_float;
+ AlignedTable<float> biases_float;
+
+ uint64_t t0 = get_cy();
+ compute_LUT(n, x, coarse_ids, coarse_dis, dis_tables_float, biases_float);
+ IVFFastScan_stats.t_compute_distance_tables += get_cy() - t0;
+
+ bool lut_is_3d = lookup_table_is_3d();
+ size_t dim123 = ksub * M;
+ size_t dim123_2 = ksub * M2;
+ if (lut_is_3d) {
+ dim123 *= nprobe;
+ dim123_2 *= nprobe;
+ }
+ dis_tables.resize(n * dim123_2);
+ if (biases_float.get()) {
+ biases.resize(n * nprobe);
+ }
+ uint64_t t1 = get_cy();
+
+ #pragma omp parallel for if (n > 100)
+ for (int64_t i = 0; i < n; i++) {
+ const float* t_in = dis_tables_float.get() + i * dim123;
+ const float* b_in = nullptr;
+ uint8_t* t_out = dis_tables.get() + i * dim123_2;
+ uint16_t* b_out = nullptr;
+ if (biases_float.get()) {
+ b_in = biases_float.get() + i * nprobe;
+ b_out = biases.get() + i * nprobe;
+ }
+
+ quantize_LUT_and_bias(
+ nprobe,
+ M,
+ ksub,
+ lut_is_3d,
+ t_in,
+ b_in,
+ t_out,
+ M2,
+ b_out,
+ normalizers + 2 * i,
+ normalizers + 2 * i + 1);
+ }
+ IVFFastScan_stats.t_round += get_cy() - t1;
+ }
+
+ /*********************************************************
+ * Search functions
+ *********************************************************/
+
+ void IndexIVFFastScan::search(
+ idx_t n,
+ const float* x,
+ idx_t k,
+ float* distances,
+ idx_t* labels,
+ const SearchParameters* params) const {
+ FAISS_THROW_IF_NOT_MSG(
+ !params, "search params not supported for this index");
+ FAISS_THROW_IF_NOT(k > 0);
+
+ DummyScaler scaler;
+ if (metric_type == METRIC_L2) {
+ search_dispatch_implem<true>(n, x, k, distances, labels, scaler);
+ } else {
+ search_dispatch_implem<false>(n, x, k, distances, labels, scaler);
+ }
+ }
+
+ void IndexIVFFastScan::range_search(
+ idx_t,
+ const float*,
+ float,
+ RangeSearchResult*,
+ const SearchParameters*) const {
+ FAISS_THROW_MSG("not implemented");
+ }
+
+ template <bool is_max, class Scaler>
+ void IndexIVFFastScan::search_dispatch_implem(
+ idx_t n,
+ const float* x,
+ idx_t k,
+ float* distances,
+ idx_t* labels,
+ const Scaler& scaler) const {
+ using Cfloat = typename std::conditional<
+ is_max,
+ CMax<float, int64_t>,
+ CMin<float, int64_t>>::type;
+
+ using C = typename std::conditional<
+ is_max,
+ CMax<uint16_t, int64_t>,
+ CMin<uint16_t, int64_t>>::type;
+
+ if (n == 0) {
+ return;
+ }
+
+ // actual implementation used
+ int impl = implem;
+
+ if (impl == 0) {
+ if (bbs == 32) {
+ impl = 12;
+ } else {
+ impl = 10;
+ }
+ if (k > 20) {
+ impl++;
+ }
+ }
+
+ if (impl == 1) {
+ search_implem_1<Cfloat>(n, x, k, distances, labels, scaler);
+ } else if (impl == 2) {
+ search_implem_2<C>(n, x, k, distances, labels, scaler);
+
+ } else if (impl >= 10 && impl <= 15) {
+ size_t ndis = 0, nlist_visited = 0;
+
+ if (n < 2) {
+ if (impl == 12 || impl == 13) {
+ search_implem_12<C>(
+ n,
+ x,
+ k,
+ distances,
+ labels,
+ impl,
+ &ndis,
+ &nlist_visited,
+ scaler);
+ } else if (impl == 14 || impl == 15) {
+ search_implem_14<C>(n, x, k, distances, labels, impl, scaler);
+ } else {
+ search_implem_10<C>(
+ n,
+ x,
+ k,
+ distances,
+ labels,
+ impl,
+ &ndis,
+ &nlist_visited,
+ scaler);
+ }
+ } else {
+ // explicitly slice over threads
+ int nslice;
+ if (n <= omp_get_max_threads()) {
+ nslice = n;
+ } else if (lookup_table_is_3d()) {
+ // make sure we don't make too big LUT tables
+ size_t lut_size_per_query =
+ M * ksub * nprobe * (sizeof(float) + sizeof(uint8_t));
+
+ size_t max_lut_size = precomputed_table_max_bytes;
+ // how many queries we can handle within mem budget
+ size_t nq_ok =
+ std::max(max_lut_size / lut_size_per_query, size_t(1));
+ nslice =
+ roundup(std::max(size_t(n / nq_ok), size_t(1)),
+ omp_get_max_threads());
+ } else {
+ // LUTs unlikely to be a limiting factor
+ nslice = omp_get_max_threads();
+ }
+ if (impl == 14 ||
+ impl == 15) { // this might require slicing if there are too
+ // many queries (for now we keep this simple)
+ search_implem_14<C>(n, x, k, distances, labels, impl, scaler);
+ } else {
+ #pragma omp parallel for reduction(+ : ndis, nlist_visited)
+ for (int slice = 0; slice < nslice; slice++) {
+ idx_t i0 = n * slice / nslice;
+ idx_t i1 = n * (slice + 1) / nslice;
+ float* dis_i = distances + i0 * k;
+ idx_t* lab_i = labels + i0 * k;
+ if (impl == 12 || impl == 13) {
+ search_implem_12<C>(
+ i1 - i0,
+ x + i0 * d,
+ k,
+ dis_i,
+ lab_i,
+ impl,
+ &ndis,
+ &nlist_visited,
+ scaler);
+ } else {
+ search_implem_10<C>(
+ i1 - i0,
+ x + i0 * d,
+ k,
+ dis_i,
+ lab_i,
+ impl,
+ &ndis,
+ &nlist_visited,
+ scaler);
+ }
+ }
+ }
+ }
+ indexIVF_stats.nq += n;
+ indexIVF_stats.ndis += ndis;
+ indexIVF_stats.nlist += nlist_visited;
+ } else {
+ FAISS_THROW_FMT("implem %d does not exist", implem);
+ }
+ }
+
+ template <class C, class Scaler>
+ void IndexIVFFastScan::search_implem_1(
+ idx_t n,
+ const float* x,
+ idx_t k,
+ float* distances,
+ idx_t* labels,
+ const Scaler& scaler) const {
+ FAISS_THROW_IF_NOT(orig_invlists);
+
+ std::unique_ptr<idx_t[]> coarse_ids(new idx_t[n * nprobe]);
+ std::unique_ptr<float[]> coarse_dis(new float[n * nprobe]);
+
+ quantizer->search(n, x, nprobe, coarse_dis.get(), coarse_ids.get());
+
+ size_t dim12 = ksub * M;
+ AlignedTable<float> dis_tables;
+ AlignedTable<float> biases;
+
+ compute_LUT(n, x, coarse_ids.get(), coarse_dis.get(), dis_tables, biases);
+
+ bool single_LUT = !lookup_table_is_3d();
+
+ size_t ndis = 0, nlist_visited = 0;
+
+ #pragma omp parallel for reduction(+ : ndis, nlist_visited)
+ for (idx_t i = 0; i < n; i++) {
+ int64_t* heap_ids = labels + i * k;
+ float* heap_dis = distances + i * k;
+ heap_heapify<C>(k, heap_dis, heap_ids);
+ float* LUT = nullptr;
+
+ if (single_LUT) {
+ LUT = dis_tables.get() + i * dim12;
+ }
+ for (idx_t j = 0; j < nprobe; j++) {
+ if (!single_LUT) {
+ LUT = dis_tables.get() + (i * nprobe + j) * dim12;
+ }
+ idx_t list_no = coarse_ids[i * nprobe + j];
+ if (list_no < 0)
+ continue;
+ size_t ls = orig_invlists->list_size(list_no);
+ if (ls == 0)
+ continue;
+ InvertedLists::ScopedCodes codes(orig_invlists, list_no);
+ InvertedLists::ScopedIds ids(orig_invlists, list_no);
+
+ float bias = biases.get() ? biases[i * nprobe + j] : 0;
+
+ estimators_from_tables_generic<C>(
+ *this,
+ codes.get(),
+ ls,
+ LUT,
+ ids.get(),
+ bias,
+ k,
+ heap_dis,
+ heap_ids,
+ scaler);
+ nlist_visited++;
+ ndis++;
+ }
+ heap_reorder<C>(k, heap_dis, heap_ids);
+ }
+ indexIVF_stats.nq += n;
+ indexIVF_stats.ndis += ndis;
+ indexIVF_stats.nlist += nlist_visited;
+ }
+
+ template <class C, class Scaler>
+ void IndexIVFFastScan::search_implem_2(
+ idx_t n,
+ const float* x,
+ idx_t k,
+ float* distances,
+ idx_t* labels,
+ const Scaler& scaler) const {
+ FAISS_THROW_IF_NOT(orig_invlists);
+
+ std::unique_ptr<idx_t[]> coarse_ids(new idx_t[n * nprobe]);
+ std::unique_ptr<float[]> coarse_dis(new float[n * nprobe]);
+
+ quantizer->search(n, x, nprobe, coarse_dis.get(), coarse_ids.get());
+
+ size_t dim12 = ksub * M2;
+ AlignedTable<uint8_t> dis_tables;
+ AlignedTable<uint16_t> biases;
+ std::unique_ptr<float[]> normalizers(new float[2 * n]);
+
+ compute_LUT_uint8(
+ n,
+ x,
+ coarse_ids.get(),
+ coarse_dis.get(),
+ dis_tables,
+ biases,
+ normalizers.get());
+
+ bool single_LUT = !lookup_table_is_3d();
+
+ size_t ndis = 0, nlist_visited = 0;
+
+ #pragma omp parallel for reduction(+ : ndis, nlist_visited)
+ for (idx_t i = 0; i < n; i++) {
+ std::vector<uint16_t> tmp_dis(k);
+ int64_t* heap_ids = labels + i * k;
+ uint16_t* heap_dis = tmp_dis.data();
+ heap_heapify<C>(k, heap_dis, heap_ids);
+ const uint8_t* LUT = nullptr;
+
+ if (single_LUT) {
+ LUT = dis_tables.get() + i * dim12;
+ }
+ for (idx_t j = 0; j < nprobe; j++) {
+ if (!single_LUT) {
+ LUT = dis_tables.get() + (i * nprobe + j) * dim12;
+ }
+ idx_t list_no = coarse_ids[i * nprobe + j];
+ if (list_no < 0)
+ continue;
+ size_t ls = orig_invlists->list_size(list_no);
+ if (ls == 0)
+ continue;
+ InvertedLists::ScopedCodes codes(orig_invlists, list_no);
+ InvertedLists::ScopedIds ids(orig_invlists, list_no);
+
+ uint16_t bias = biases.get() ? biases[i * nprobe + j] : 0;
+
+ estimators_from_tables_generic<C>(
+ *this,
+ codes.get(),
+ ls,
+ LUT,
+ ids.get(),
+ bias,
+ k,
+ heap_dis,
+ heap_ids,
+ scaler);
+
+ nlist_visited++;
+ ndis += ls;
+ }
+ heap_reorder<C>(k, heap_dis, heap_ids);
+ // convert distances to float
+ {
+ float one_a = 1 / normalizers[2 * i], b = normalizers[2 * i + 1];
+ if (skip & 16) {
+ one_a = 1;
+ b = 0;
+ }
+ float* heap_dis_float = distances + i * k;
+ for (int j = 0; j < k; j++) {
+ heap_dis_float[j] = b + heap_dis[j] * one_a;
+ }
+ }
+ }
+ indexIVF_stats.nq += n;
+ indexIVF_stats.ndis += ndis;
+ indexIVF_stats.nlist += nlist_visited;
+ }
+
+ template <class C, class Scaler>
+ void IndexIVFFastScan::search_implem_10(
+ idx_t n,
+ const float* x,
+ idx_t k,
+ float* distances,
+ idx_t* labels,
+ int impl,
+ size_t* ndis_out,
+ size_t* nlist_out,
+ const Scaler& scaler) const {
+ memset(distances, -1, sizeof(float) * k * n);
+ memset(labels, -1, sizeof(idx_t) * k * n);
+
+ using HeapHC = HeapHandler<C, true>;
+ using ReservoirHC = ReservoirHandler<C, true>;
+ using SingleResultHC = SingleResultHandler<C, true>;
+
+ std::unique_ptr<idx_t[]> coarse_ids(new idx_t[n * nprobe]);
+ std::unique_ptr<float[]> coarse_dis(new float[n * nprobe]);
+
+ uint64_t times[10];
+ memset(times, 0, sizeof(times));
+ int ti = 0;
+ #define TIC times[ti++] = get_cy()
+ TIC;
+
+ quantizer->search(n, x, nprobe, coarse_dis.get(), coarse_ids.get());
+
+ TIC;
+
+ size_t dim12 = ksub * M2;
+ AlignedTable<uint8_t> dis_tables;
+ AlignedTable<uint16_t> biases;
+ std::unique_ptr<float[]> normalizers(new float[2 * n]);
+
+ compute_LUT_uint8(
+ n,
+ x,
+ coarse_ids.get(),
+ coarse_dis.get(),
+ dis_tables,
+ biases,
+ normalizers.get());
+
+ TIC;
+
+ bool single_LUT = !lookup_table_is_3d();
+
+ TIC;
+ size_t ndis = 0, nlist_visited = 0;
+
+ {
+ AlignedTable<uint16_t> tmp_distances(k);
+ for (idx_t i = 0; i < n; i++) {
+ const uint8_t* LUT = nullptr;
+ int qmap1[1] = {0};
+ std::unique_ptr<SIMDResultHandler<C, true>> handler;
+
+ if (k == 1) {
+ handler.reset(new SingleResultHC(1, 0));
+ } else if (impl == 10) {
+ handler.reset(new HeapHC(
+ 1, tmp_distances.get(), labels + i * k, k, 0));
+ } else if (impl == 11) {
+ handler.reset(new ReservoirHC(1, 0, k, 2 * k));
+ } else {
+ FAISS_THROW_MSG("invalid");
+ }
+
+ handler->q_map = qmap1;
+
+ if (single_LUT) {
+ LUT = dis_tables.get() + i * dim12;
+ }
+ for (idx_t j = 0; j < nprobe; j++) {
+ size_t ij = i * nprobe + j;
+ if (!single_LUT) {
+ LUT = dis_tables.get() + ij * dim12;
+ }
+ if (biases.get()) {
+ handler->dbias = biases.get() + ij;
+ }
+
+ idx_t list_no = coarse_ids[ij];
+ if (list_no < 0)
+ continue;
+ size_t ls = invlists->list_size(list_no);
+ if (ls == 0)
+ continue;
+
+ InvertedLists::ScopedCodes codes(invlists, list_no);
+ InvertedLists::ScopedIds ids(invlists, list_no);
+
+ handler->ntotal = ls;
+ handler->id_map = ids.get();
+
+ #define DISPATCH(classHC) \
+ if (dynamic_cast<classHC*>(handler.get())) { \
+ auto* res = static_cast<classHC*>(handler.get()); \
+ pq4_accumulate_loop( \
+ 1, roundup(ls, bbs), bbs, M2, codes.get(), LUT, *res, scaler); \
+ }
+ DISPATCH(HeapHC)
+ else DISPATCH(ReservoirHC) else DISPATCH(SingleResultHC)
+ #undef DISPATCH
+
+ nlist_visited++;
+ ndis++;
+ }
+
+ handler->to_flat_arrays(
+ distances + i * k,
+ labels + i * k,
+ skip & 16 ? nullptr : normalizers.get() + i * 2);
+ }
+ }
+ *ndis_out = ndis;
+ *nlist_out = nlist;
+ }
+
+ template <class C, class Scaler>
+ void IndexIVFFastScan::search_implem_12(
+ idx_t n,
+ const float* x,
+ idx_t k,
+ float* distances,
+ idx_t* labels,
+ int impl,
+ size_t* ndis_out,
+ size_t* nlist_out,
+ const Scaler& scaler) const {
+ if (n == 0) { // does not work well with reservoir
+ return;
+ }
+ FAISS_THROW_IF_NOT(bbs == 32);
+
+ std::unique_ptr<idx_t[]> coarse_ids(new idx_t[n * nprobe]);
+ std::unique_ptr<float[]> coarse_dis(new float[n * nprobe]);
+
+ uint64_t times[10];
+ memset(times, 0, sizeof(times));
+ int ti = 0;
+ #define TIC times[ti++] = get_cy()
+ TIC;
+
+ quantizer->search(n, x, nprobe, coarse_dis.get(), coarse_ids.get());
+
+ TIC;
+
+ size_t dim12 = ksub * M2;
+ AlignedTable<uint8_t> dis_tables;
+ AlignedTable<uint16_t> biases;
+ std::unique_ptr<float[]> normalizers(new float[2 * n]);
+
+ compute_LUT_uint8(
+ n,
+ x,
+ coarse_ids.get(),
+ coarse_dis.get(),
+ dis_tables,
+ biases,
+ normalizers.get());
+
+ TIC;
+
+ struct QC {
+ int qno; // sequence number of the query
+ int list_no; // list to visit
+ int rank; // this is the rank'th result of the coarse quantizer
+ };
+ bool single_LUT = !lookup_table_is_3d();
+
+ std::vector<QC> qcs;
+ {
+ int ij = 0;
+ for (int i = 0; i < n; i++) {
+ for (int j = 0; j < nprobe; j++) {
+ if (coarse_ids[ij] >= 0) {
+ qcs.push_back(QC{i, int(coarse_ids[ij]), int(j)});
+ }
+ ij++;
+ }
+ }
+ std::sort(qcs.begin(), qcs.end(), [](const QC& a, const QC& b) {
+ return a.list_no < b.list_no;
+ });
+ }
+ TIC;
+
+ // prepare the result handlers
+
+ std::unique_ptr<SIMDResultHandler<C, true>> handler;
+ AlignedTable<uint16_t> tmp_distances;
+
+ using HeapHC = HeapHandler<C, true>;
+ using ReservoirHC = ReservoirHandler<C, true>;
+ using SingleResultHC = SingleResultHandler<C, true>;
+
+ if (k == 1) {
+ handler.reset(new SingleResultHC(n, 0));
+ } else if (impl == 12) {
+ tmp_distances.resize(n * k);
+ handler.reset(new HeapHC(n, tmp_distances.get(), labels, k, 0));
+ } else if (impl == 13) {
+ handler.reset(new ReservoirHC(n, 0, k, 2 * k));
+ }
+
+ int qbs2 = this->qbs2 ? this->qbs2 : 11;
+
+ std::vector<uint16_t> tmp_bias;
+ if (biases.get()) {
+ tmp_bias.resize(qbs2);
+ handler->dbias = tmp_bias.data();
+ }
+ TIC;
+
+ size_t ndis = 0;
+
+ size_t i0 = 0;
+ uint64_t t_copy_pack = 0, t_scan = 0;
+ while (i0 < qcs.size()) {
+ uint64_t tt0 = get_cy();
+
+ // find all queries that access this inverted list
+ int list_no = qcs[i0].list_no;
+ size_t i1 = i0 + 1;
+
+ while (i1 < qcs.size() && i1 < i0 + qbs2) {
+ if (qcs[i1].list_no != list_no) {
+ break;
+ }
+ i1++;
+ }
+
+ size_t list_size = invlists->list_size(list_no);
+
+ if (list_size == 0) {
+ i0 = i1;
+ continue;
+ }
+
+ // re-organize LUTs and biases into the right order
+ int nc = i1 - i0;
+
+ std::vector<int> q_map(nc), lut_entries(nc);
+ AlignedTable<uint8_t> LUT(nc * dim12);
+ memset(LUT.get(), -1, nc * dim12);
+ int qbs = pq4_preferred_qbs(nc);
+
+ for (size_t i = i0; i < i1; i++) {
+ const QC& qc = qcs[i];
+ q_map[i - i0] = qc.qno;
+ int ij = qc.qno * nprobe + qc.rank;
+ lut_entries[i - i0] = single_LUT ? qc.qno : ij;
+ if (biases.get()) {
+ tmp_bias[i - i0] = biases[ij];
+ }
+ }
+ pq4_pack_LUT_qbs_q_map(
+ qbs, M2, dis_tables.get(), lut_entries.data(), LUT.get());
+
+ // access the inverted list
+
+ ndis += (i1 - i0) * list_size;
+
+ InvertedLists::ScopedCodes codes(invlists, list_no);
+ InvertedLists::ScopedIds ids(invlists, list_no);
+
+ // prepare the handler
+
+ handler->ntotal = list_size;
+ handler->q_map = q_map.data();
+ handler->id_map = ids.get();
+ uint64_t tt1 = get_cy();
+
+ #define DISPATCH(classHC) \
+ if (dynamic_cast<classHC*>(handler.get())) { \
+ auto* res = static_cast<classHC*>(handler.get()); \
+ pq4_accumulate_loop_qbs( \
+ qbs, list_size, M2, codes.get(), LUT.get(), *res, scaler); \
+ }
+ DISPATCH(HeapHC)
+ else DISPATCH(ReservoirHC) else DISPATCH(SingleResultHC)
+
+ // prepare for next loop
+ i0 = i1;
+
+ uint64_t tt2 = get_cy();
+ t_copy_pack += tt1 - tt0;
+ t_scan += tt2 - tt1;
+ }
+ TIC;
+
+ // labels is in-place for HeapHC
+ handler->to_flat_arrays(
+ distances, labels, skip & 16 ? nullptr : normalizers.get());
+
+ TIC;
+
+ // these stats are not thread-safe
+
+ for (int i = 1; i < ti; i++) {
+ IVFFastScan_stats.times[i] += times[i] - times[i - 1];
+ }
+ IVFFastScan_stats.t_copy_pack += t_copy_pack;
+ IVFFastScan_stats.t_scan += t_scan;
+
+ if (auto* rh = dynamic_cast<ReservoirHC*>(handler.get())) {
+ for (int i = 0; i < 4; i++) {
+ IVFFastScan_stats.reservoir_times[i] += rh->times[i];
+ }
+ }
+
+ *ndis_out = ndis;
+ *nlist_out = nlist;
+ }
+
+ template <class C, class Scaler>
+ void IndexIVFFastScan::search_implem_14(
+ idx_t n,
+ const float* x,
+ idx_t k,
+ float* distances,
+ idx_t* labels,
+ int impl,
+ const Scaler& scaler) const {
+ if (n == 0) { // does not work well with reservoir
+ return;
+ }
+ FAISS_THROW_IF_NOT(bbs == 32);
+
+ std::unique_ptr<idx_t[]> coarse_ids(new idx_t[n * nprobe]);
+ std::unique_ptr<float[]> coarse_dis(new float[n * nprobe]);
+
+ uint64_t ttg0 = get_cy();
+
+ quantizer->search(n, x, nprobe, coarse_dis.get(), coarse_ids.get());
+
+ uint64_t ttg1 = get_cy();
+ uint64_t coarse_search_tt = ttg1 - ttg0;
+
+ size_t dim12 = ksub * M2;
+ AlignedTable<uint8_t> dis_tables;
+ AlignedTable<uint16_t> biases;
+ std::unique_ptr<float[]> normalizers(new float[2 * n]);
+
+ compute_LUT_uint8(
+ n,
+ x,
+ coarse_ids.get(),
+ coarse_dis.get(),
+ dis_tables,
+ biases,
+ normalizers.get());
+
+ uint64_t ttg2 = get_cy();
+ uint64_t lut_compute_tt = ttg2 - ttg1;
+
+ struct QC {
+ int qno; // sequence number of the query
+ int list_no; // list to visit
+ int rank; // this is the rank'th result of the coarse quantizer
+ };
+ bool single_LUT = !lookup_table_is_3d();
+
+ std::vector<QC> qcs;
+ {
+ int ij = 0;
+ for (int i = 0; i < n; i++) {
+ for (int j = 0; j < nprobe; j++) {
+ if (coarse_ids[ij] >= 0) {
+ qcs.push_back(QC{i, int(coarse_ids[ij]), int(j)});
+ }
+ ij++;
+ }
+ }
+ std::sort(qcs.begin(), qcs.end(), [](const QC& a, const QC& b) {
+ return a.list_no < b.list_no;
+ });
+ }
+
+ struct SE {
+ size_t start; // start in the QC vector
+ size_t end; // end in the QC vector
+ size_t list_size;
+ };
+ std::vector<SE> ses;
+ size_t i0_l = 0;
+ while (i0_l < qcs.size()) {
+ // find all queries that access this inverted list
+ int list_no = qcs[i0_l].list_no;
+ size_t i1 = i0_l + 1;
+
+ while (i1 < qcs.size() && i1 < i0_l + qbs2) {
+ if (qcs[i1].list_no != list_no) {
+ break;
+ }
+ i1++;
+ }
+
+ size_t list_size = invlists->list_size(list_no);
+
+ if (list_size == 0) {
+ i0_l = i1;
+ continue;
+ }
+ ses.push_back(SE{i0_l, i1, list_size});
+ i0_l = i1;
+ }
+ uint64_t ttg3 = get_cy();
+ uint64_t compute_clusters_tt = ttg3 - ttg2;
+
+ // function to handle the global heap
+ using HeapForIP = CMin<float, idx_t>;
+ using HeapForL2 = CMax<float, idx_t>;
+ auto init_result = [&](float* simi, idx_t* idxi) {
+ if (metric_type == METRIC_INNER_PRODUCT) {
+ heap_heapify<HeapForIP>(k, simi, idxi);
+ } else {
+ heap_heapify<HeapForL2>(k, simi, idxi);
+ }
+ };
+
+ auto add_local_results = [&](const float* local_dis,
+ const idx_t* local_idx,
+ float* simi,
+ idx_t* idxi) {
+ if (metric_type == METRIC_INNER_PRODUCT) {
+ heap_addn<HeapForIP>(k, simi, idxi, local_dis, local_idx, k);
+ } else {
+ heap_addn<HeapForL2>(k, simi, idxi, local_dis, local_idx, k);
+ }
+ };
+
+ auto reorder_result = [&](float* simi, idx_t* idxi) {
+ if (metric_type == METRIC_INNER_PRODUCT) {
+ heap_reorder<HeapForIP>(k, simi, idxi);
+ } else {
+ heap_reorder<HeapForL2>(k, simi, idxi);
+ }
+ };
+ uint64_t ttg4 = get_cy();
+ uint64_t fn_tt = ttg4 - ttg3;
+
+ size_t ndis = 0;
+ size_t nlist_visited = 0;
+
+ #pragma omp parallel reduction(+ : ndis, nlist_visited)
+ {
+ // storage for each thread
+ std::vector<idx_t> local_idx(k * n);
+ std::vector<float> local_dis(k * n);
+
+ // prepare the result handlers
+ std::unique_ptr<SIMDResultHandler<C, true>> handler;
+ AlignedTable<uint16_t> tmp_distances;
+
+ using HeapHC = HeapHandler<C, true>;
+ using ReservoirHC = ReservoirHandler<C, true>;
+ using SingleResultHC = SingleResultHandler<C, true>;
+
+ if (k == 1) {
+ handler.reset(new SingleResultHC(n, 0));
+ } else if (impl == 14) {
+ tmp_distances.resize(n * k);
+ handler.reset(
+ new HeapHC(n, tmp_distances.get(), local_idx.data(), k, 0));
+ } else if (impl == 15) {
+ handler.reset(new ReservoirHC(n, 0, k, 2 * k));
+ }
+
+ int qbs2 = this->qbs2 ? this->qbs2 : 11;
+
+ std::vector<uint16_t> tmp_bias;
+ if (biases.get()) {
+ tmp_bias.resize(qbs2);
+ handler->dbias = tmp_bias.data();
+ }
+
+ uint64_t ttg5 = get_cy();
+ uint64_t handler_tt = ttg5 - ttg4;
+
+ std::set<int> q_set;
+ uint64_t t_copy_pack = 0, t_scan = 0;
+ #pragma omp for schedule(dynamic)
+ for (idx_t cluster = 0; cluster < ses.size(); cluster++) {
+ uint64_t tt0 = get_cy();
+ size_t i0 = ses[cluster].start;
+ size_t i1 = ses[cluster].end;
+ size_t list_size = ses[cluster].list_size;
+ nlist_visited++;
+ int list_no = qcs[i0].list_no;
+
+ // re-organize LUTs and biases into the right order
+ int nc = i1 - i0;
+
+ std::vector<int> q_map(nc), lut_entries(nc);
+ AlignedTable<uint8_t> LUT(nc * dim12);
+ memset(LUT.get(), -1, nc * dim12);
+ int qbs = pq4_preferred_qbs(nc);
+
+ for (size_t i = i0; i < i1; i++) {
+ const QC& qc = qcs[i];
+ q_map[i - i0] = qc.qno;
+ q_set.insert(qc.qno);
+ int ij = qc.qno * nprobe + qc.rank;
+ lut_entries[i - i0] = single_LUT ? qc.qno : ij;
+ if (biases.get()) {
+ tmp_bias[i - i0] = biases[ij];
+ }
+ }
+ pq4_pack_LUT_qbs_q_map(
+ qbs, M2, dis_tables.get(), lut_entries.data(), LUT.get());
+
+ // access the inverted list
+
+ ndis += (i1 - i0) * list_size;
+
+ InvertedLists::ScopedCodes codes(invlists, list_no);
+ InvertedLists::ScopedIds ids(invlists, list_no);
+
+ // prepare the handler
+
+ handler->ntotal = list_size;
+ handler->q_map = q_map.data();
+ handler->id_map = ids.get();
+ uint64_t tt1 = get_cy();
+
+ #define DISPATCH(classHC) \
+ if (dynamic_cast<classHC*>(handler.get())) { \
+ auto* res = static_cast<classHC*>(handler.get()); \
+ pq4_accumulate_loop_qbs( \
+ qbs, list_size, M2, codes.get(), LUT.get(), *res, scaler); \
+ }
+ DISPATCH(HeapHC)
+ else DISPATCH(ReservoirHC) else DISPATCH(SingleResultHC)
+
+ uint64_t tt2 = get_cy();
+ t_copy_pack += tt1 - tt0;
+ t_scan += tt2 - tt1;
+ }
+
+ // labels is in-place for HeapHC
+ handler->to_flat_arrays(
+ local_dis.data(),
+ local_idx.data(),
+ skip & 16 ? nullptr : normalizers.get());
+
+ #pragma omp single
+ {
+ // we init the results as a heap
+ for (idx_t i = 0; i < n; i++) {
+ init_result(distances + i * k, labels + i * k);
+ }
+ }
+ #pragma omp barrier
+ #pragma omp critical
+ {
+ // write to global heap #go over only the queries
+ for (std::set<int>::iterator it = q_set.begin(); it != q_set.end();
+ ++it) {
+ add_local_results(
+ local_dis.data() + *it * k,
+ local_idx.data() + *it * k,
+ distances + *it * k,
+ labels + *it * k);
+ }
+
+ IVFFastScan_stats.t_copy_pack += t_copy_pack;
+ IVFFastScan_stats.t_scan += t_scan;
+
+ if (auto* rh = dynamic_cast<ReservoirHC*>(handler.get())) {
+ for (int i = 0; i < 4; i++) {
+ IVFFastScan_stats.reservoir_times[i] += rh->times[i];
+ }
+ }
+ }
+ #pragma omp barrier
+ #pragma omp single
+ {
+ for (idx_t i = 0; i < n; i++) {
+ reorder_result(distances + i * k, labels + i * k);
+ }
+ }
+ }
+
+ indexIVF_stats.nq += n;
+ indexIVF_stats.ndis += ndis;
+ indexIVF_stats.nlist += nlist_visited;
+ }
+
+ void IndexIVFFastScan::reconstruct_from_offset(
+ int64_t list_no,
+ int64_t offset,
+ float* recons) const {
+ // unpack codes
+ InvertedLists::ScopedCodes list_codes(invlists, list_no);
+ std::vector<uint8_t> code(code_size, 0);
+ BitstringWriter bsw(code.data(), code_size);
+ for (size_t m = 0; m < M; m++) {
+ uint8_t c =
+ pq4_get_packed_element(list_codes.get(), bbs, M2, offset, m);
+ bsw.write(c, nbits);
+ }
+ sa_decode(1, code.data(), recons);
+
+ // add centroid to it
+ if (by_residual) {
+ std::vector<float> centroid(d);
+ quantizer->reconstruct(list_no, centroid.data());
+ for (int i = 0; i < d; ++i) {
+ recons[i] += centroid[i];
+ }
+ }
+ }
+
+ void IndexIVFFastScan::reconstruct_orig_invlists() {
+ FAISS_THROW_IF_NOT(orig_invlists != nullptr);
+ FAISS_THROW_IF_NOT(orig_invlists->list_size(0) == 0);
+
+ for (size_t list_no = 0; list_no < nlist; list_no++) {
+ InvertedLists::ScopedCodes codes(invlists, list_no);
+ InvertedLists::ScopedIds ids(invlists, list_no);
+ size_t list_size = orig_invlists->list_size(list_no);
+ std::vector<uint8_t> code(code_size, 0);
+
+ for (size_t offset = 0; offset < list_size; offset++) {
+ // unpack codes
+ BitstringWriter bsw(code.data(), code_size);
+ for (size_t m = 0; m < M; m++) {
+ uint8_t c =
+ pq4_get_packed_element(codes.get(), bbs, M2, offset, m);
+ bsw.write(c, nbits);
+ }
+
+ // get id
+ idx_t id = ids.get()[offset];
+
+ orig_invlists->add_entry(list_no, id, code.data());
+ }
+ }
+ }
+
+ IVFFastScanStats IVFFastScan_stats;
+
+ template void IndexIVFFastScan::search_dispatch_implem<true, NormTableScaler>(
+ idx_t n,
+ const float* x,
+ idx_t k,
+ float* distances,
+ idx_t* labels,
+ const NormTableScaler& scaler) const;
+
+ template void IndexIVFFastScan::search_dispatch_implem<false, NormTableScaler>(
+ idx_t n,
+ const float* x,
+ idx_t k,
+ float* distances,
+ idx_t* labels,
+ const NormTableScaler& scaler) const;
+
+ } // namespace faiss