faiss 0.2.4 → 0.2.5

Sign up to get free protection for your applications and to get access to all the features.
Files changed (177) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +5 -0
  3. data/README.md +23 -21
  4. data/ext/faiss/extconf.rb +11 -0
  5. data/ext/faiss/index.cpp +4 -4
  6. data/ext/faiss/index_binary.cpp +6 -6
  7. data/ext/faiss/product_quantizer.cpp +4 -4
  8. data/lib/faiss/version.rb +1 -1
  9. data/vendor/faiss/faiss/AutoTune.cpp +13 -0
  10. data/vendor/faiss/faiss/IVFlib.cpp +101 -2
  11. data/vendor/faiss/faiss/IVFlib.h +26 -2
  12. data/vendor/faiss/faiss/Index.cpp +36 -3
  13. data/vendor/faiss/faiss/Index.h +43 -6
  14. data/vendor/faiss/faiss/Index2Layer.cpp +6 -2
  15. data/vendor/faiss/faiss/Index2Layer.h +6 -1
  16. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +219 -16
  17. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +63 -5
  18. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +299 -0
  19. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +199 -0
  20. data/vendor/faiss/faiss/IndexBinary.cpp +20 -4
  21. data/vendor/faiss/faiss/IndexBinary.h +18 -3
  22. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +9 -2
  23. data/vendor/faiss/faiss/IndexBinaryFlat.h +4 -2
  24. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +4 -1
  25. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +2 -1
  26. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +5 -1
  27. data/vendor/faiss/faiss/IndexBinaryHNSW.h +2 -1
  28. data/vendor/faiss/faiss/IndexBinaryHash.cpp +17 -4
  29. data/vendor/faiss/faiss/IndexBinaryHash.h +8 -4
  30. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +28 -13
  31. data/vendor/faiss/faiss/IndexBinaryIVF.h +10 -7
  32. data/vendor/faiss/faiss/IndexFastScan.cpp +626 -0
  33. data/vendor/faiss/faiss/IndexFastScan.h +145 -0
  34. data/vendor/faiss/faiss/IndexFlat.cpp +34 -21
  35. data/vendor/faiss/faiss/IndexFlat.h +7 -4
  36. data/vendor/faiss/faiss/IndexFlatCodes.cpp +35 -1
  37. data/vendor/faiss/faiss/IndexFlatCodes.h +12 -0
  38. data/vendor/faiss/faiss/IndexHNSW.cpp +66 -138
  39. data/vendor/faiss/faiss/IndexHNSW.h +4 -2
  40. data/vendor/faiss/faiss/IndexIDMap.cpp +247 -0
  41. data/vendor/faiss/faiss/IndexIDMap.h +107 -0
  42. data/vendor/faiss/faiss/IndexIVF.cpp +121 -33
  43. data/vendor/faiss/faiss/IndexIVF.h +35 -16
  44. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +84 -7
  45. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +63 -1
  46. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +590 -0
  47. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +171 -0
  48. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +1290 -0
  49. data/vendor/faiss/faiss/IndexIVFFastScan.h +213 -0
  50. data/vendor/faiss/faiss/IndexIVFFlat.cpp +37 -17
  51. data/vendor/faiss/faiss/IndexIVFFlat.h +4 -2
  52. data/vendor/faiss/faiss/IndexIVFPQ.cpp +234 -50
  53. data/vendor/faiss/faiss/IndexIVFPQ.h +5 -1
  54. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +23 -852
  55. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +7 -112
  56. data/vendor/faiss/faiss/IndexIVFPQR.cpp +3 -3
  57. data/vendor/faiss/faiss/IndexIVFPQR.h +1 -1
  58. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +3 -1
  59. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +2 -1
  60. data/vendor/faiss/faiss/IndexLSH.cpp +4 -2
  61. data/vendor/faiss/faiss/IndexLSH.h +2 -1
  62. data/vendor/faiss/faiss/IndexLattice.cpp +7 -1
  63. data/vendor/faiss/faiss/IndexLattice.h +3 -1
  64. data/vendor/faiss/faiss/IndexNNDescent.cpp +4 -3
  65. data/vendor/faiss/faiss/IndexNNDescent.h +2 -1
  66. data/vendor/faiss/faiss/IndexNSG.cpp +37 -3
  67. data/vendor/faiss/faiss/IndexNSG.h +25 -1
  68. data/vendor/faiss/faiss/IndexPQ.cpp +106 -69
  69. data/vendor/faiss/faiss/IndexPQ.h +19 -5
  70. data/vendor/faiss/faiss/IndexPQFastScan.cpp +15 -450
  71. data/vendor/faiss/faiss/IndexPQFastScan.h +15 -78
  72. data/vendor/faiss/faiss/IndexPreTransform.cpp +47 -8
  73. data/vendor/faiss/faiss/IndexPreTransform.h +15 -3
  74. data/vendor/faiss/faiss/IndexRefine.cpp +8 -4
  75. data/vendor/faiss/faiss/IndexRefine.h +4 -2
  76. data/vendor/faiss/faiss/IndexReplicas.cpp +4 -2
  77. data/vendor/faiss/faiss/IndexReplicas.h +2 -1
  78. data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +438 -0
  79. data/vendor/faiss/faiss/IndexRowwiseMinMax.h +92 -0
  80. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +26 -15
  81. data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -7
  82. data/vendor/faiss/faiss/IndexShards.cpp +4 -1
  83. data/vendor/faiss/faiss/IndexShards.h +2 -1
  84. data/vendor/faiss/faiss/MetaIndexes.cpp +5 -178
  85. data/vendor/faiss/faiss/MetaIndexes.h +3 -81
  86. data/vendor/faiss/faiss/VectorTransform.cpp +43 -0
  87. data/vendor/faiss/faiss/VectorTransform.h +22 -4
  88. data/vendor/faiss/faiss/clone_index.cpp +23 -1
  89. data/vendor/faiss/faiss/clone_index.h +3 -0
  90. data/vendor/faiss/faiss/cppcontrib/SaDecodeKernels.h +300 -0
  91. data/vendor/faiss/faiss/cppcontrib/detail/CoarseBitType.h +24 -0
  92. data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +195 -0
  93. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +2058 -0
  94. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +408 -0
  95. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-neon-inl.h +2147 -0
  96. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMax-inl.h +460 -0
  97. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMaxFP16-inl.h +465 -0
  98. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +1618 -0
  99. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +251 -0
  100. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-neon-inl.h +1452 -0
  101. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +1 -0
  102. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +0 -4
  103. data/vendor/faiss/faiss/gpu/GpuIndex.h +28 -4
  104. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +2 -1
  105. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +10 -8
  106. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +75 -14
  107. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +19 -32
  108. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -31
  109. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +22 -28
  110. data/vendor/faiss/faiss/gpu/GpuResources.cpp +14 -0
  111. data/vendor/faiss/faiss/gpu/GpuResources.h +16 -3
  112. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +3 -3
  113. data/vendor/faiss/faiss/gpu/impl/IndexUtils.h +32 -0
  114. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +1 -0
  115. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +311 -75
  116. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +10 -0
  117. data/vendor/faiss/faiss/gpu/test/TestUtils.h +3 -0
  118. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +2 -2
  119. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +5 -4
  120. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +116 -47
  121. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +44 -13
  122. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +0 -54
  123. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +0 -76
  124. data/vendor/faiss/faiss/impl/DistanceComputer.h +64 -0
  125. data/vendor/faiss/faiss/impl/HNSW.cpp +123 -27
  126. data/vendor/faiss/faiss/impl/HNSW.h +19 -16
  127. data/vendor/faiss/faiss/impl/IDSelector.cpp +125 -0
  128. data/vendor/faiss/faiss/impl/IDSelector.h +135 -0
  129. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +6 -28
  130. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +6 -1
  131. data/vendor/faiss/faiss/impl/LookupTableScaler.h +77 -0
  132. data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -0
  133. data/vendor/faiss/faiss/impl/NSG.cpp +1 -1
  134. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +383 -0
  135. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +154 -0
  136. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +225 -145
  137. data/vendor/faiss/faiss/impl/ProductQuantizer.h +29 -10
  138. data/vendor/faiss/faiss/impl/Quantizer.h +43 -0
  139. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +192 -36
  140. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +40 -20
  141. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -0
  142. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +97 -173
  143. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +18 -18
  144. data/vendor/faiss/faiss/impl/index_read.cpp +240 -9
  145. data/vendor/faiss/faiss/impl/index_write.cpp +237 -5
  146. data/vendor/faiss/faiss/impl/kmeans1d.cpp +6 -4
  147. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +56 -16
  148. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +25 -8
  149. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +66 -25
  150. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +75 -27
  151. data/vendor/faiss/faiss/index_factory.cpp +196 -7
  152. data/vendor/faiss/faiss/index_io.h +5 -0
  153. data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -0
  154. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +4 -1
  155. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +2 -1
  156. data/vendor/faiss/faiss/python/python_callbacks.cpp +27 -0
  157. data/vendor/faiss/faiss/python/python_callbacks.h +15 -0
  158. data/vendor/faiss/faiss/utils/Heap.h +31 -15
  159. data/vendor/faiss/faiss/utils/distances.cpp +380 -56
  160. data/vendor/faiss/faiss/utils/distances.h +113 -15
  161. data/vendor/faiss/faiss/utils/distances_simd.cpp +726 -6
  162. data/vendor/faiss/faiss/utils/extra_distances.cpp +12 -7
  163. data/vendor/faiss/faiss/utils/extra_distances.h +3 -1
  164. data/vendor/faiss/faiss/utils/fp16-fp16c.h +21 -0
  165. data/vendor/faiss/faiss/utils/fp16-inl.h +101 -0
  166. data/vendor/faiss/faiss/utils/fp16.h +11 -0
  167. data/vendor/faiss/faiss/utils/hamming-inl.h +54 -0
  168. data/vendor/faiss/faiss/utils/hamming.cpp +0 -48
  169. data/vendor/faiss/faiss/utils/ordered_key_value.h +10 -0
  170. data/vendor/faiss/faiss/utils/quantize_lut.cpp +62 -0
  171. data/vendor/faiss/faiss/utils/quantize_lut.h +20 -0
  172. data/vendor/faiss/faiss/utils/random.cpp +53 -0
  173. data/vendor/faiss/faiss/utils/random.h +5 -0
  174. data/vendor/faiss/faiss/utils/simdlib_avx2.h +4 -0
  175. data/vendor/faiss/faiss/utils/simdlib_emulated.h +6 -1
  176. data/vendor/faiss/faiss/utils/simdlib_neon.h +7 -2
  177. metadata +37 -3
@@ -40,19 +40,13 @@ IndexIVFPQFastScan::IndexIVFPQFastScan(
40
40
  size_t d,
41
41
  size_t nlist,
42
42
  size_t M,
43
- size_t nbits_per_idx,
43
+ size_t nbits,
44
44
  MetricType metric,
45
45
  int bbs)
46
- : IndexIVF(quantizer, d, nlist, 0, metric),
47
- pq(d, M, nbits_per_idx),
48
- bbs(bbs) {
49
- FAISS_THROW_IF_NOT(nbits_per_idx == 4);
50
- M2 = roundup(pq.M, 2);
46
+ : IndexIVFFastScan(quantizer, d, nlist, 0, metric), pq(d, M, nbits) {
51
47
  by_residual = false; // set to false by default because it's much faster
52
- is_trained = false;
53
- code_size = pq.code_size;
54
48
 
55
- replace_invlists(new BlockInvertedLists(nlist, bbs, bbs * M2 / 2), true);
49
+ init_fastscan(M, nbits, nlist, metric, bbs);
56
50
  }
57
51
 
58
52
  IndexIVFPQFastScan::IndexIVFPQFastScan() {
@@ -62,26 +56,21 @@ IndexIVFPQFastScan::IndexIVFPQFastScan() {
62
56
  }
63
57
 
64
58
  IndexIVFPQFastScan::IndexIVFPQFastScan(const IndexIVFPQ& orig, int bbs)
65
- : IndexIVF(
59
+ : IndexIVFFastScan(
66
60
  orig.quantizer,
67
61
  orig.d,
68
62
  orig.nlist,
69
63
  orig.pq.code_size,
70
64
  orig.metric_type),
71
- pq(orig.pq),
72
- bbs(bbs) {
65
+ pq(orig.pq) {
73
66
  FAISS_THROW_IF_NOT(orig.pq.nbits == 4);
74
67
 
68
+ init_fastscan(orig.pq.M, orig.pq.nbits, orig.nlist, orig.metric_type, bbs);
69
+
75
70
  by_residual = orig.by_residual;
76
71
  ntotal = orig.ntotal;
77
72
  is_trained = orig.is_trained;
78
73
  nprobe = orig.nprobe;
79
- size_t M = pq.M;
80
-
81
- M2 = roundup(M, 2);
82
-
83
- replace_invlists(
84
- new BlockInvertedLists(orig.nlist, bbs, bbs * M2 / 2), true);
85
74
 
86
75
  precomputed_table.resize(orig.precomputed_table.size());
87
76
 
@@ -205,150 +194,10 @@ void IndexIVFPQFastScan::encode_vectors(
205
194
  }
206
195
  }
207
196
 
208
- void IndexIVFPQFastScan::add_with_ids(
209
- idx_t n,
210
- const float* x,
211
- const idx_t* xids) {
212
- // copied from IndexIVF::add_with_ids --->
213
-
214
- // do some blocking to avoid excessive allocs
215
- idx_t bs = 65536;
216
- if (n > bs) {
217
- for (idx_t i0 = 0; i0 < n; i0 += bs) {
218
- idx_t i1 = std::min(n, i0 + bs);
219
- if (verbose) {
220
- printf(" IndexIVFPQFastScan::add_with_ids %zd: %zd",
221
- size_t(i0),
222
- size_t(i1));
223
- }
224
- add_with_ids(i1 - i0, x + i0 * d, xids ? xids + i0 : nullptr);
225
- }
226
- return;
227
- }
228
- InterruptCallback::check();
229
-
230
- AlignedTable<uint8_t> codes(n * code_size);
231
-
232
- FAISS_THROW_IF_NOT(is_trained);
233
- direct_map.check_can_add(xids);
234
-
235
- std::unique_ptr<idx_t[]> idx(new idx_t[n]);
236
- quantizer->assign(n, x, idx.get());
237
- size_t nadd = 0, nminus1 = 0;
238
-
239
- for (size_t i = 0; i < n; i++) {
240
- if (idx[i] < 0)
241
- nminus1++;
242
- }
243
-
244
- AlignedTable<uint8_t> flat_codes(n * code_size);
245
- encode_vectors(n, x, idx.get(), flat_codes.get());
246
-
247
- DirectMapAdd dm_adder(direct_map, n, xids);
248
-
249
- // <---
250
-
251
- BlockInvertedLists* bil = dynamic_cast<BlockInvertedLists*>(invlists);
252
- FAISS_THROW_IF_NOT_MSG(bil, "only block inverted lists supported");
253
-
254
- // prepare batches
255
- std::vector<idx_t> order(n);
256
- for (idx_t i = 0; i < n; i++) {
257
- order[i] = i;
258
- }
259
-
260
- // TODO should not need stable
261
- std::stable_sort(order.begin(), order.end(), [&idx](idx_t a, idx_t b) {
262
- return idx[a] < idx[b];
263
- });
264
-
265
- // TODO parallelize
266
- idx_t i0 = 0;
267
- while (i0 < n) {
268
- idx_t list_no = idx[order[i0]];
269
- idx_t i1 = i0 + 1;
270
- while (i1 < n && idx[order[i1]] == list_no) {
271
- i1++;
272
- }
273
-
274
- if (list_no == -1) {
275
- i0 = i1;
276
- continue;
277
- }
278
-
279
- // make linear array
280
- AlignedTable<uint8_t> list_codes((i1 - i0) * code_size);
281
- size_t list_size = bil->list_size(list_no);
282
-
283
- bil->resize(list_no, list_size + i1 - i0);
284
-
285
- for (idx_t i = i0; i < i1; i++) {
286
- size_t ofs = list_size + i - i0;
287
- idx_t id = xids ? xids[order[i]] : ntotal + order[i];
288
- dm_adder.add(order[i], list_no, ofs);
289
- bil->ids[list_no][ofs] = id;
290
- memcpy(list_codes.data() + (i - i0) * code_size,
291
- flat_codes.data() + order[i] * code_size,
292
- code_size);
293
- nadd++;
294
- }
295
- pq4_pack_codes_range(
296
- list_codes.data(),
297
- pq.M,
298
- list_size,
299
- list_size + i1 - i0,
300
- bbs,
301
- M2,
302
- bil->codes[list_no].data());
303
-
304
- i0 = i1;
305
- }
306
-
307
- ntotal += n;
308
- }
309
-
310
197
  /*********************************************************
311
- * search
198
+ * Look-Up Table functions
312
199
  *********************************************************/
313
200
 
314
- namespace {
315
-
316
- // from impl/ProductQuantizer.cpp
317
- template <class C, typename dis_t>
318
- void pq_estimators_from_tables_generic(
319
- const ProductQuantizer& pq,
320
- size_t nbits,
321
- const uint8_t* codes,
322
- size_t ncodes,
323
- const dis_t* dis_table,
324
- const int64_t* ids,
325
- float dis0,
326
- size_t k,
327
- typename C::T* heap_dis,
328
- int64_t* heap_ids) {
329
- using accu_t = typename C::T;
330
- const size_t M = pq.M;
331
- const size_t ksub = pq.ksub;
332
- for (size_t j = 0; j < ncodes; ++j) {
333
- PQDecoderGeneric decoder(codes + j * pq.code_size, nbits);
334
- accu_t dis = dis0;
335
- const dis_t* dt = dis_table;
336
- for (size_t m = 0; m < M; m++) {
337
- uint64_t c = decoder.decode();
338
- dis += dt[c];
339
- dt += ksub;
340
- }
341
-
342
- if (C::cmp(heap_dis[0], dis)) {
343
- heap_pop<C>(k, heap_dis, heap_ids);
344
- heap_push<C>(k, heap_dis, heap_ids, dis, ids[j]);
345
- }
346
- }
347
- }
348
-
349
- using idx_t = Index::idx_t;
350
- using namespace quantize_lut;
351
-
352
201
  void fvec_madd_avx(
353
202
  size_t n,
354
203
  const float* a,
@@ -373,11 +222,9 @@ void fvec_madd_avx(
373
222
  }
374
223
  }
375
224
 
376
- } // anonymous namespace
377
-
378
- /*********************************************************
379
- * Look-Up Table functions
380
- *********************************************************/
225
+ bool IndexIVFPQFastScan::lookup_table_is_3d() const {
226
+ return by_residual && metric_type == METRIC_L2;
227
+ }
381
228
 
382
229
  void IndexIVFPQFastScan::compute_LUT(
383
230
  size_t n,
@@ -386,16 +233,14 @@ void IndexIVFPQFastScan::compute_LUT(
386
233
  const float* coarse_dis,
387
234
  AlignedTable<float>& dis_tables,
388
235
  AlignedTable<float>& biases) const {
389
- const IndexIVFPQFastScan& ivfpq = *this;
390
236
  size_t dim12 = pq.ksub * pq.M;
391
237
  size_t d = pq.d;
392
- size_t nprobe = ivfpq.nprobe;
393
238
 
394
- if (ivfpq.by_residual) {
395
- if (ivfpq.metric_type == METRIC_L2) {
239
+ if (by_residual) {
240
+ if (metric_type == METRIC_L2) {
396
241
  dis_tables.resize(n * nprobe * dim12);
397
242
 
398
- if (ivfpq.use_precomputed_table == 1) {
243
+ if (use_precomputed_table == 1) {
399
244
  biases.resize(n * nprobe);
400
245
  memcpy(biases.get(), coarse_dis, sizeof(float) * n * nprobe);
401
246
 
@@ -434,7 +279,7 @@ void IndexIVFPQFastScan::compute_LUT(
434
279
  idx_t cij = coarse_ids[ij];
435
280
 
436
281
  if (cij >= 0) {
437
- ivfpq.quantizer->compute_residual(x + i * d, xij, cij);
282
+ quantizer->compute_residual(x + i * d, xij, cij);
438
283
  } else {
439
284
  // will fill with NaNs
440
285
  memset(xij, -1, sizeof(float) * d);
@@ -445,7 +290,7 @@ void IndexIVFPQFastScan::compute_LUT(
445
290
  n * nprobe, xrel.get(), dis_tables.get());
446
291
  }
447
292
 
448
- } else if (ivfpq.metric_type == METRIC_INNER_PRODUCT) {
293
+ } else if (metric_type == METRIC_INNER_PRODUCT) {
449
294
  dis_tables.resize(n * dim12);
450
295
  pq.compute_inner_prod_tables(n, x, dis_tables.get());
451
296
  // compute_inner_prod_tables(pq, n, x, dis_tables.get());
@@ -453,698 +298,24 @@ void IndexIVFPQFastScan::compute_LUT(
453
298
  biases.resize(n * nprobe);
454
299
  memcpy(biases.get(), coarse_dis, sizeof(float) * n * nprobe);
455
300
  } else {
456
- FAISS_THROW_FMT("metric %d not supported", ivfpq.metric_type);
301
+ FAISS_THROW_FMT("metric %d not supported", metric_type);
457
302
  }
458
303
 
459
304
  } else {
460
305
  dis_tables.resize(n * dim12);
461
- if (ivfpq.metric_type == METRIC_L2) {
306
+ if (metric_type == METRIC_L2) {
462
307
  pq.compute_distance_tables(n, x, dis_tables.get());
463
- } else if (ivfpq.metric_type == METRIC_INNER_PRODUCT) {
308
+ } else if (metric_type == METRIC_INNER_PRODUCT) {
464
309
  pq.compute_inner_prod_tables(n, x, dis_tables.get());
465
310
  } else {
466
- FAISS_THROW_FMT("metric %d not supported", ivfpq.metric_type);
467
- }
468
- }
469
- }
470
-
471
- void IndexIVFPQFastScan::compute_LUT_uint8(
472
- size_t n,
473
- const float* x,
474
- const idx_t* coarse_ids,
475
- const float* coarse_dis,
476
- AlignedTable<uint8_t>& dis_tables,
477
- AlignedTable<uint16_t>& biases,
478
- float* normalizers) const {
479
- const IndexIVFPQFastScan& ivfpq = *this;
480
- AlignedTable<float> dis_tables_float;
481
- AlignedTable<float> biases_float;
482
-
483
- uint64_t t0 = get_cy();
484
- compute_LUT(n, x, coarse_ids, coarse_dis, dis_tables_float, biases_float);
485
- IVFFastScan_stats.t_compute_distance_tables += get_cy() - t0;
486
-
487
- bool lut_is_3d = ivfpq.by_residual && ivfpq.metric_type == METRIC_L2;
488
- size_t dim123 = pq.ksub * pq.M;
489
- size_t dim123_2 = pq.ksub * M2;
490
- if (lut_is_3d) {
491
- dim123 *= nprobe;
492
- dim123_2 *= nprobe;
493
- }
494
- dis_tables.resize(n * dim123_2);
495
- if (biases_float.get()) {
496
- biases.resize(n * nprobe);
497
- }
498
- uint64_t t1 = get_cy();
499
-
500
- #pragma omp parallel for if (n > 100)
501
- for (int64_t i = 0; i < n; i++) {
502
- const float* t_in = dis_tables_float.get() + i * dim123;
503
- const float* b_in = nullptr;
504
- uint8_t* t_out = dis_tables.get() + i * dim123_2;
505
- uint16_t* b_out = nullptr;
506
- if (biases_float.get()) {
507
- b_in = biases_float.get() + i * nprobe;
508
- b_out = biases.get() + i * nprobe;
509
- }
510
-
511
- quantize_LUT_and_bias(
512
- nprobe,
513
- pq.M,
514
- pq.ksub,
515
- lut_is_3d,
516
- t_in,
517
- b_in,
518
- t_out,
519
- M2,
520
- b_out,
521
- normalizers + 2 * i,
522
- normalizers + 2 * i + 1);
523
- }
524
- IVFFastScan_stats.t_round += get_cy() - t1;
525
- }
526
-
527
- /*********************************************************
528
- * Search functions
529
- *********************************************************/
530
-
531
- template <bool is_max>
532
- void IndexIVFPQFastScan::search_dispatch_implem(
533
- idx_t n,
534
- const float* x,
535
- idx_t k,
536
- float* distances,
537
- idx_t* labels) const {
538
- using Cfloat = typename std::conditional<
539
- is_max,
540
- CMax<float, int64_t>,
541
- CMin<float, int64_t>>::type;
542
-
543
- using C = typename std::conditional<
544
- is_max,
545
- CMax<uint16_t, int64_t>,
546
- CMin<uint16_t, int64_t>>::type;
547
-
548
- if (n == 0) {
549
- return;
550
- }
551
-
552
- // actual implementation used
553
- int impl = implem;
554
-
555
- if (impl == 0) {
556
- if (bbs == 32) {
557
- impl = 12;
558
- } else {
559
- impl = 10;
560
- }
561
- if (k > 20) {
562
- impl++;
563
- }
564
- }
565
-
566
- if (impl == 1) {
567
- search_implem_1<Cfloat>(n, x, k, distances, labels);
568
- } else if (impl == 2) {
569
- search_implem_2<C>(n, x, k, distances, labels);
570
-
571
- } else if (impl >= 10 && impl <= 13) {
572
- size_t ndis = 0, nlist_visited = 0;
573
-
574
- if (n < 2) {
575
- if (impl == 12 || impl == 13) {
576
- search_implem_12<C>(
577
- n,
578
- x,
579
- k,
580
- distances,
581
- labels,
582
- impl,
583
- &ndis,
584
- &nlist_visited);
585
- } else {
586
- search_implem_10<C>(
587
- n,
588
- x,
589
- k,
590
- distances,
591
- labels,
592
- impl,
593
- &ndis,
594
- &nlist_visited);
595
- }
596
- } else {
597
- // explicitly slice over threads
598
- int nslice;
599
- if (n <= omp_get_max_threads()) {
600
- nslice = n;
601
- } else if (by_residual && metric_type == METRIC_L2) {
602
- // make sure we don't make too big LUT tables
603
- size_t lut_size_per_query = pq.M * pq.ksub * nprobe *
604
- (sizeof(float) + sizeof(uint8_t));
605
-
606
- size_t max_lut_size = precomputed_table_max_bytes;
607
- // how many queries we can handle within mem budget
608
- size_t nq_ok =
609
- std::max(max_lut_size / lut_size_per_query, size_t(1));
610
- nslice =
611
- roundup(std::max(size_t(n / nq_ok), size_t(1)),
612
- omp_get_max_threads());
613
- } else {
614
- // LUTs unlikely to be a limiting factor
615
- nslice = omp_get_max_threads();
616
- }
617
-
618
- #pragma omp parallel for reduction(+ : ndis, nlist_visited)
619
- for (int slice = 0; slice < nslice; slice++) {
620
- idx_t i0 = n * slice / nslice;
621
- idx_t i1 = n * (slice + 1) / nslice;
622
- float* dis_i = distances + i0 * k;
623
- idx_t* lab_i = labels + i0 * k;
624
- if (impl == 12 || impl == 13) {
625
- search_implem_12<C>(
626
- i1 - i0,
627
- x + i0 * d,
628
- k,
629
- dis_i,
630
- lab_i,
631
- impl,
632
- &ndis,
633
- &nlist_visited);
634
- } else {
635
- search_implem_10<C>(
636
- i1 - i0,
637
- x + i0 * d,
638
- k,
639
- dis_i,
640
- lab_i,
641
- impl,
642
- &ndis,
643
- &nlist_visited);
644
- }
645
- }
646
- }
647
- indexIVF_stats.nq += n;
648
- indexIVF_stats.ndis += ndis;
649
- indexIVF_stats.nlist += nlist_visited;
650
- } else {
651
- FAISS_THROW_FMT("implem %d does not exist", implem);
652
- }
653
- }
654
-
655
- void IndexIVFPQFastScan::search(
656
- idx_t n,
657
- const float* x,
658
- idx_t k,
659
- float* distances,
660
- idx_t* labels) const {
661
- FAISS_THROW_IF_NOT(k > 0);
662
-
663
- if (metric_type == METRIC_L2) {
664
- search_dispatch_implem<true>(n, x, k, distances, labels);
665
- } else {
666
- search_dispatch_implem<false>(n, x, k, distances, labels);
667
- }
668
- }
669
-
670
- template <class C>
671
- void IndexIVFPQFastScan::search_implem_1(
672
- idx_t n,
673
- const float* x,
674
- idx_t k,
675
- float* distances,
676
- idx_t* labels) const {
677
- FAISS_THROW_IF_NOT(orig_invlists);
678
-
679
- std::unique_ptr<idx_t[]> coarse_ids(new idx_t[n * nprobe]);
680
- std::unique_ptr<float[]> coarse_dis(new float[n * nprobe]);
681
-
682
- quantizer->search(n, x, nprobe, coarse_dis.get(), coarse_ids.get());
683
-
684
- size_t dim12 = pq.ksub * pq.M;
685
- AlignedTable<float> dis_tables;
686
- AlignedTable<float> biases;
687
-
688
- compute_LUT(n, x, coarse_ids.get(), coarse_dis.get(), dis_tables, biases);
689
-
690
- bool single_LUT = !(by_residual && metric_type == METRIC_L2);
691
-
692
- size_t ndis = 0, nlist_visited = 0;
693
-
694
- #pragma omp parallel for reduction(+ : ndis, nlist_visited)
695
- for (idx_t i = 0; i < n; i++) {
696
- int64_t* heap_ids = labels + i * k;
697
- float* heap_dis = distances + i * k;
698
- heap_heapify<C>(k, heap_dis, heap_ids);
699
- float* LUT = nullptr;
700
-
701
- if (single_LUT) {
702
- LUT = dis_tables.get() + i * dim12;
703
- }
704
- for (idx_t j = 0; j < nprobe; j++) {
705
- if (!single_LUT) {
706
- LUT = dis_tables.get() + (i * nprobe + j) * dim12;
707
- }
708
- idx_t list_no = coarse_ids[i * nprobe + j];
709
- if (list_no < 0)
710
- continue;
711
- size_t ls = orig_invlists->list_size(list_no);
712
- if (ls == 0)
713
- continue;
714
- InvertedLists::ScopedCodes codes(orig_invlists, list_no);
715
- InvertedLists::ScopedIds ids(orig_invlists, list_no);
716
-
717
- float bias = biases.get() ? biases[i * nprobe + j] : 0;
718
-
719
- pq_estimators_from_tables_generic<C>(
720
- pq,
721
- pq.nbits,
722
- codes.get(),
723
- ls,
724
- LUT,
725
- ids.get(),
726
- bias,
727
- k,
728
- heap_dis,
729
- heap_ids);
730
- nlist_visited++;
731
- ndis++;
732
- }
733
- heap_reorder<C>(k, heap_dis, heap_ids);
734
- }
735
- indexIVF_stats.nq += n;
736
- indexIVF_stats.ndis += ndis;
737
- indexIVF_stats.nlist += nlist_visited;
738
- }
739
-
740
- template <class C>
741
- void IndexIVFPQFastScan::search_implem_2(
742
- idx_t n,
743
- const float* x,
744
- idx_t k,
745
- float* distances,
746
- idx_t* labels) const {
747
- FAISS_THROW_IF_NOT(orig_invlists);
748
-
749
- std::unique_ptr<idx_t[]> coarse_ids(new idx_t[n * nprobe]);
750
- std::unique_ptr<float[]> coarse_dis(new float[n * nprobe]);
751
-
752
- quantizer->search(n, x, nprobe, coarse_dis.get(), coarse_ids.get());
753
-
754
- size_t dim12 = pq.ksub * M2;
755
- AlignedTable<uint8_t> dis_tables;
756
- AlignedTable<uint16_t> biases;
757
- std::unique_ptr<float[]> normalizers(new float[2 * n]);
758
-
759
- compute_LUT_uint8(
760
- n,
761
- x,
762
- coarse_ids.get(),
763
- coarse_dis.get(),
764
- dis_tables,
765
- biases,
766
- normalizers.get());
767
-
768
- bool single_LUT = !(by_residual && metric_type == METRIC_L2);
769
-
770
- size_t ndis = 0, nlist_visited = 0;
771
-
772
- #pragma omp parallel for reduction(+ : ndis, nlist_visited)
773
- for (idx_t i = 0; i < n; i++) {
774
- std::vector<uint16_t> tmp_dis(k);
775
- int64_t* heap_ids = labels + i * k;
776
- uint16_t* heap_dis = tmp_dis.data();
777
- heap_heapify<C>(k, heap_dis, heap_ids);
778
- const uint8_t* LUT = nullptr;
779
-
780
- if (single_LUT) {
781
- LUT = dis_tables.get() + i * dim12;
782
- }
783
- for (idx_t j = 0; j < nprobe; j++) {
784
- if (!single_LUT) {
785
- LUT = dis_tables.get() + (i * nprobe + j) * dim12;
786
- }
787
- idx_t list_no = coarse_ids[i * nprobe + j];
788
- if (list_no < 0)
789
- continue;
790
- size_t ls = orig_invlists->list_size(list_no);
791
- if (ls == 0)
792
- continue;
793
- InvertedLists::ScopedCodes codes(orig_invlists, list_no);
794
- InvertedLists::ScopedIds ids(orig_invlists, list_no);
795
-
796
- uint16_t bias = biases.get() ? biases[i * nprobe + j] : 0;
797
-
798
- pq_estimators_from_tables_generic<C>(
799
- pq,
800
- pq.nbits,
801
- codes.get(),
802
- ls,
803
- LUT,
804
- ids.get(),
805
- bias,
806
- k,
807
- heap_dis,
808
- heap_ids);
809
-
810
- nlist_visited++;
811
- ndis += ls;
812
- }
813
- heap_reorder<C>(k, heap_dis, heap_ids);
814
- // convert distances to float
815
- {
816
- float one_a = 1 / normalizers[2 * i], b = normalizers[2 * i + 1];
817
- if (skip & 16) {
818
- one_a = 1;
819
- b = 0;
820
- }
821
- float* heap_dis_float = distances + i * k;
822
- for (int j = 0; j < k; j++) {
823
- heap_dis_float[j] = b + heap_dis[j] * one_a;
824
- }
825
- }
826
- }
827
- indexIVF_stats.nq += n;
828
- indexIVF_stats.ndis += ndis;
829
- indexIVF_stats.nlist += nlist_visited;
830
- }
831
-
832
- template <class C>
833
- void IndexIVFPQFastScan::search_implem_10(
834
- idx_t n,
835
- const float* x,
836
- idx_t k,
837
- float* distances,
838
- idx_t* labels,
839
- int impl,
840
- size_t* ndis_out,
841
- size_t* nlist_out) const {
842
- memset(distances, -1, sizeof(float) * k * n);
843
- memset(labels, -1, sizeof(idx_t) * k * n);
844
-
845
- using HeapHC = HeapHandler<C, true>;
846
- using ReservoirHC = ReservoirHandler<C, true>;
847
- using SingleResultHC = SingleResultHandler<C, true>;
848
-
849
- std::unique_ptr<idx_t[]> coarse_ids(new idx_t[n * nprobe]);
850
- std::unique_ptr<float[]> coarse_dis(new float[n * nprobe]);
851
-
852
- uint64_t times[10];
853
- memset(times, 0, sizeof(times));
854
- int ti = 0;
855
- #define TIC times[ti++] = get_cy()
856
- TIC;
857
-
858
- quantizer->search(n, x, nprobe, coarse_dis.get(), coarse_ids.get());
859
-
860
- TIC;
861
-
862
- size_t dim12 = pq.ksub * M2;
863
- AlignedTable<uint8_t> dis_tables;
864
- AlignedTable<uint16_t> biases;
865
- std::unique_ptr<float[]> normalizers(new float[2 * n]);
866
-
867
- compute_LUT_uint8(
868
- n,
869
- x,
870
- coarse_ids.get(),
871
- coarse_dis.get(),
872
- dis_tables,
873
- biases,
874
- normalizers.get());
875
-
876
- TIC;
877
-
878
- bool single_LUT = !(by_residual && metric_type == METRIC_L2);
879
-
880
- TIC;
881
- size_t ndis = 0, nlist_visited = 0;
882
-
883
- {
884
- AlignedTable<uint16_t> tmp_distances(k);
885
- for (idx_t i = 0; i < n; i++) {
886
- const uint8_t* LUT = nullptr;
887
- int qmap1[1] = {0};
888
- std::unique_ptr<SIMDResultHandler<C, true>> handler;
889
-
890
- if (k == 1) {
891
- handler.reset(new SingleResultHC(1, 0));
892
- } else if (impl == 10) {
893
- handler.reset(new HeapHC(
894
- 1, tmp_distances.get(), labels + i * k, k, 0));
895
- } else if (impl == 11) {
896
- handler.reset(new ReservoirHC(1, 0, k, 2 * k));
897
- } else {
898
- FAISS_THROW_MSG("invalid");
899
- }
900
-
901
- handler->q_map = qmap1;
902
-
903
- if (single_LUT) {
904
- LUT = dis_tables.get() + i * dim12;
905
- }
906
- for (idx_t j = 0; j < nprobe; j++) {
907
- size_t ij = i * nprobe + j;
908
- if (!single_LUT) {
909
- LUT = dis_tables.get() + ij * dim12;
910
- }
911
- if (biases.get()) {
912
- handler->dbias = biases.get() + ij;
913
- }
914
-
915
- idx_t list_no = coarse_ids[ij];
916
- if (list_no < 0)
917
- continue;
918
- size_t ls = invlists->list_size(list_no);
919
- if (ls == 0)
920
- continue;
921
-
922
- InvertedLists::ScopedCodes codes(invlists, list_no);
923
- InvertedLists::ScopedIds ids(invlists, list_no);
924
-
925
- handler->ntotal = ls;
926
- handler->id_map = ids.get();
927
-
928
- #define DISPATCH(classHC) \
929
- if (dynamic_cast<classHC*>(handler.get())) { \
930
- auto* res = static_cast<classHC*>(handler.get()); \
931
- pq4_accumulate_loop( \
932
- 1, roundup(ls, bbs), bbs, M2, codes.get(), LUT, *res); \
933
- }
934
- DISPATCH(HeapHC)
935
- else DISPATCH(ReservoirHC) else DISPATCH(SingleResultHC)
936
- #undef DISPATCH
937
-
938
- nlist_visited++;
939
- ndis++;
940
- }
941
-
942
- handler->to_flat_arrays(
943
- distances + i * k,
944
- labels + i * k,
945
- skip & 16 ? nullptr : normalizers.get() + i * 2);
311
+ FAISS_THROW_FMT("metric %d not supported", metric_type);
946
312
  }
947
313
  }
948
- *ndis_out = ndis;
949
- *nlist_out = nlist;
950
314
  }
951
315
 
952
- template <class C>
953
- void IndexIVFPQFastScan::search_implem_12(
954
- idx_t n,
955
- const float* x,
956
- idx_t k,
957
- float* distances,
958
- idx_t* labels,
959
- int impl,
960
- size_t* ndis_out,
961
- size_t* nlist_out) const {
962
- if (n == 0) { // does not work well with reservoir
963
- return;
964
- }
965
- FAISS_THROW_IF_NOT(bbs == 32);
966
-
967
- std::unique_ptr<idx_t[]> coarse_ids(new idx_t[n * nprobe]);
968
- std::unique_ptr<float[]> coarse_dis(new float[n * nprobe]);
969
-
970
- uint64_t times[10];
971
- memset(times, 0, sizeof(times));
972
- int ti = 0;
973
- #define TIC times[ti++] = get_cy()
974
- TIC;
975
-
976
- quantizer->search(n, x, nprobe, coarse_dis.get(), coarse_ids.get());
977
-
978
- TIC;
979
-
980
- size_t dim12 = pq.ksub * M2;
981
- AlignedTable<uint8_t> dis_tables;
982
- AlignedTable<uint16_t> biases;
983
- std::unique_ptr<float[]> normalizers(new float[2 * n]);
984
-
985
- compute_LUT_uint8(
986
- n,
987
- x,
988
- coarse_ids.get(),
989
- coarse_dis.get(),
990
- dis_tables,
991
- biases,
992
- normalizers.get());
993
-
994
- TIC;
995
-
996
- struct QC {
997
- int qno; // sequence number of the query
998
- int list_no; // list to visit
999
- int rank; // this is the rank'th result of the coarse quantizer
1000
- };
1001
- bool single_LUT = !(by_residual && metric_type == METRIC_L2);
1002
-
1003
- std::vector<QC> qcs;
1004
- {
1005
- int ij = 0;
1006
- for (int i = 0; i < n; i++) {
1007
- for (int j = 0; j < nprobe; j++) {
1008
- if (coarse_ids[ij] >= 0) {
1009
- qcs.push_back(QC{i, int(coarse_ids[ij]), int(j)});
1010
- }
1011
- ij++;
1012
- }
1013
- }
1014
- std::sort(qcs.begin(), qcs.end(), [](const QC& a, const QC& b) {
1015
- return a.list_no < b.list_no;
1016
- });
1017
- }
1018
- TIC;
1019
-
1020
- // prepare the result handlers
1021
-
1022
- std::unique_ptr<SIMDResultHandler<C, true>> handler;
1023
- AlignedTable<uint16_t> tmp_distances;
1024
-
1025
- using HeapHC = HeapHandler<C, true>;
1026
- using ReservoirHC = ReservoirHandler<C, true>;
1027
- using SingleResultHC = SingleResultHandler<C, true>;
1028
-
1029
- if (k == 1) {
1030
- handler.reset(new SingleResultHC(n, 0));
1031
- } else if (impl == 12) {
1032
- tmp_distances.resize(n * k);
1033
- handler.reset(new HeapHC(n, tmp_distances.get(), labels, k, 0));
1034
- } else if (impl == 13) {
1035
- handler.reset(new ReservoirHC(n, 0, k, 2 * k));
1036
- }
1037
-
1038
- int qbs2 = this->qbs2 ? this->qbs2 : 11;
1039
-
1040
- std::vector<uint16_t> tmp_bias;
1041
- if (biases.get()) {
1042
- tmp_bias.resize(qbs2);
1043
- handler->dbias = tmp_bias.data();
1044
- }
1045
- TIC;
1046
-
1047
- size_t ndis = 0;
1048
-
1049
- size_t i0 = 0;
1050
- uint64_t t_copy_pack = 0, t_scan = 0;
1051
- while (i0 < qcs.size()) {
1052
- uint64_t tt0 = get_cy();
1053
-
1054
- // find all queries that access this inverted list
1055
- int list_no = qcs[i0].list_no;
1056
- size_t i1 = i0 + 1;
1057
-
1058
- while (i1 < qcs.size() && i1 < i0 + qbs2) {
1059
- if (qcs[i1].list_no != list_no) {
1060
- break;
1061
- }
1062
- i1++;
1063
- }
1064
-
1065
- size_t list_size = invlists->list_size(list_no);
1066
-
1067
- if (list_size == 0) {
1068
- i0 = i1;
1069
- continue;
1070
- }
1071
-
1072
- // re-organize LUTs and biases into the right order
1073
- int nc = i1 - i0;
1074
-
1075
- std::vector<int> q_map(nc), lut_entries(nc);
1076
- AlignedTable<uint8_t> LUT(nc * dim12);
1077
- memset(LUT.get(), -1, nc * dim12);
1078
- int qbs = pq4_preferred_qbs(nc);
1079
-
1080
- for (size_t i = i0; i < i1; i++) {
1081
- const QC& qc = qcs[i];
1082
- q_map[i - i0] = qc.qno;
1083
- int ij = qc.qno * nprobe + qc.rank;
1084
- lut_entries[i - i0] = single_LUT ? qc.qno : ij;
1085
- if (biases.get()) {
1086
- tmp_bias[i - i0] = biases[ij];
1087
- }
1088
- }
1089
- pq4_pack_LUT_qbs_q_map(
1090
- qbs, M2, dis_tables.get(), lut_entries.data(), LUT.get());
1091
-
1092
- // access the inverted list
1093
-
1094
- ndis += (i1 - i0) * list_size;
1095
-
1096
- InvertedLists::ScopedCodes codes(invlists, list_no);
1097
- InvertedLists::ScopedIds ids(invlists, list_no);
1098
-
1099
- // prepare the handler
1100
-
1101
- handler->ntotal = list_size;
1102
- handler->q_map = q_map.data();
1103
- handler->id_map = ids.get();
1104
- uint64_t tt1 = get_cy();
1105
-
1106
- #define DISPATCH(classHC) \
1107
- if (dynamic_cast<classHC*>(handler.get())) { \
1108
- auto* res = static_cast<classHC*>(handler.get()); \
1109
- pq4_accumulate_loop_qbs( \
1110
- qbs, list_size, M2, codes.get(), LUT.get(), *res); \
1111
- }
1112
- DISPATCH(HeapHC)
1113
- else DISPATCH(ReservoirHC) else DISPATCH(SingleResultHC)
1114
-
1115
- // prepare for next loop
1116
- i0 = i1;
1117
-
1118
- uint64_t tt2 = get_cy();
1119
- t_copy_pack += tt1 - tt0;
1120
- t_scan += tt2 - tt1;
1121
- }
1122
- TIC;
1123
-
1124
- // labels is in-place for HeapHC
1125
- handler->to_flat_arrays(
1126
- distances, labels, skip & 16 ? nullptr : normalizers.get());
1127
-
1128
- TIC;
1129
-
1130
- // these stats are not thread-safe
1131
-
1132
- for (int i = 1; i < ti; i++) {
1133
- IVFFastScan_stats.times[i] += times[i] - times[i - 1];
1134
- }
1135
- IVFFastScan_stats.t_copy_pack += t_copy_pack;
1136
- IVFFastScan_stats.t_scan += t_scan;
1137
-
1138
- if (auto* rh = dynamic_cast<ReservoirHC*>(handler.get())) {
1139
- for (int i = 0; i < 4; i++) {
1140
- IVFFastScan_stats.reservoir_times[i] += rh->times[i];
1141
- }
1142
- }
1143
-
1144
- *ndis_out = ndis;
1145
- *nlist_out = nlist;
316
+ void IndexIVFPQFastScan::sa_decode(idx_t n, const uint8_t* bytes, float* x)
317
+ const {
318
+ pq.decode(bytes, x, n);
1146
319
  }
1147
320
 
1148
- IVFFastScanStats IVFFastScan_stats;
1149
-
1150
321
  } // namespace faiss