faiss 0.2.7 → 0.3.1

Sign up to get free protection for your applications and to get access to all the features.
Files changed (172) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +10 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +1 -1
  5. data/ext/faiss/extconf.rb +9 -2
  6. data/ext/faiss/index.cpp +1 -1
  7. data/ext/faiss/index_binary.cpp +2 -2
  8. data/ext/faiss/product_quantizer.cpp +1 -1
  9. data/lib/faiss/version.rb +1 -1
  10. data/lib/faiss.rb +1 -1
  11. data/vendor/faiss/faiss/AutoTune.cpp +7 -7
  12. data/vendor/faiss/faiss/AutoTune.h +0 -1
  13. data/vendor/faiss/faiss/Clustering.cpp +4 -18
  14. data/vendor/faiss/faiss/Clustering.h +31 -21
  15. data/vendor/faiss/faiss/IVFlib.cpp +22 -11
  16. data/vendor/faiss/faiss/Index.cpp +1 -1
  17. data/vendor/faiss/faiss/Index.h +20 -5
  18. data/vendor/faiss/faiss/Index2Layer.cpp +7 -7
  19. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +176 -166
  20. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +15 -15
  21. data/vendor/faiss/faiss/IndexBinary.cpp +9 -4
  22. data/vendor/faiss/faiss/IndexBinary.h +8 -19
  23. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +2 -1
  24. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +24 -31
  25. data/vendor/faiss/faiss/IndexBinaryHash.cpp +25 -50
  26. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +106 -187
  27. data/vendor/faiss/faiss/IndexFastScan.cpp +90 -159
  28. data/vendor/faiss/faiss/IndexFastScan.h +9 -8
  29. data/vendor/faiss/faiss/IndexFlat.cpp +195 -3
  30. data/vendor/faiss/faiss/IndexFlat.h +20 -1
  31. data/vendor/faiss/faiss/IndexFlatCodes.cpp +11 -0
  32. data/vendor/faiss/faiss/IndexFlatCodes.h +3 -1
  33. data/vendor/faiss/faiss/IndexHNSW.cpp +112 -316
  34. data/vendor/faiss/faiss/IndexHNSW.h +12 -48
  35. data/vendor/faiss/faiss/IndexIDMap.cpp +69 -28
  36. data/vendor/faiss/faiss/IndexIDMap.h +24 -2
  37. data/vendor/faiss/faiss/IndexIVF.cpp +159 -53
  38. data/vendor/faiss/faiss/IndexIVF.h +37 -5
  39. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +18 -26
  40. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +3 -2
  41. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +19 -46
  42. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +4 -3
  43. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +433 -405
  44. data/vendor/faiss/faiss/IndexIVFFastScan.h +56 -26
  45. data/vendor/faiss/faiss/IndexIVFFlat.cpp +15 -5
  46. data/vendor/faiss/faiss/IndexIVFFlat.h +3 -2
  47. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +172 -0
  48. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.h +56 -0
  49. data/vendor/faiss/faiss/IndexIVFPQ.cpp +78 -122
  50. data/vendor/faiss/faiss/IndexIVFPQ.h +6 -7
  51. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +18 -50
  52. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +4 -3
  53. data/vendor/faiss/faiss/IndexIVFPQR.cpp +45 -29
  54. data/vendor/faiss/faiss/IndexIVFPQR.h +5 -2
  55. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +25 -27
  56. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +6 -6
  57. data/vendor/faiss/faiss/IndexLSH.cpp +14 -16
  58. data/vendor/faiss/faiss/IndexNNDescent.cpp +3 -4
  59. data/vendor/faiss/faiss/IndexNSG.cpp +11 -27
  60. data/vendor/faiss/faiss/IndexNSG.h +10 -10
  61. data/vendor/faiss/faiss/IndexPQ.cpp +72 -88
  62. data/vendor/faiss/faiss/IndexPQ.h +1 -4
  63. data/vendor/faiss/faiss/IndexPQFastScan.cpp +1 -1
  64. data/vendor/faiss/faiss/IndexPreTransform.cpp +25 -31
  65. data/vendor/faiss/faiss/IndexRefine.cpp +49 -19
  66. data/vendor/faiss/faiss/IndexRefine.h +7 -0
  67. data/vendor/faiss/faiss/IndexReplicas.cpp +23 -26
  68. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +22 -16
  69. data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -4
  70. data/vendor/faiss/faiss/IndexShards.cpp +21 -29
  71. data/vendor/faiss/faiss/IndexShardsIVF.cpp +1 -2
  72. data/vendor/faiss/faiss/MatrixStats.cpp +17 -32
  73. data/vendor/faiss/faiss/MatrixStats.h +21 -9
  74. data/vendor/faiss/faiss/MetaIndexes.cpp +35 -35
  75. data/vendor/faiss/faiss/VectorTransform.cpp +13 -26
  76. data/vendor/faiss/faiss/VectorTransform.h +7 -7
  77. data/vendor/faiss/faiss/clone_index.cpp +15 -10
  78. data/vendor/faiss/faiss/clone_index.h +3 -0
  79. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +87 -4
  80. data/vendor/faiss/faiss/gpu/GpuCloner.h +22 -0
  81. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +7 -0
  82. data/vendor/faiss/faiss/gpu/GpuDistance.h +46 -38
  83. data/vendor/faiss/faiss/gpu/GpuIndex.h +28 -4
  84. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +4 -4
  85. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +8 -9
  86. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +18 -3
  87. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -11
  88. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +1 -3
  89. data/vendor/faiss/faiss/gpu/GpuResources.cpp +24 -3
  90. data/vendor/faiss/faiss/gpu/GpuResources.h +39 -11
  91. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +117 -17
  92. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +57 -3
  93. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +1 -1
  94. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +25 -0
  95. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +129 -9
  96. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +267 -40
  97. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +299 -208
  98. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +1 -0
  99. data/vendor/faiss/faiss/gpu/utils/RaftUtils.h +75 -0
  100. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +3 -1
  101. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +5 -5
  102. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +1 -1
  103. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +1 -2
  104. data/vendor/faiss/faiss/impl/DistanceComputer.h +24 -1
  105. data/vendor/faiss/faiss/impl/FaissException.h +13 -34
  106. data/vendor/faiss/faiss/impl/HNSW.cpp +321 -70
  107. data/vendor/faiss/faiss/impl/HNSW.h +9 -8
  108. data/vendor/faiss/faiss/impl/IDSelector.h +4 -4
  109. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +3 -1
  110. data/vendor/faiss/faiss/impl/NNDescent.cpp +29 -19
  111. data/vendor/faiss/faiss/impl/NSG.h +1 -1
  112. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +14 -12
  113. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +1 -1
  114. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +24 -22
  115. data/vendor/faiss/faiss/impl/ProductQuantizer.h +1 -1
  116. data/vendor/faiss/faiss/impl/Quantizer.h +1 -1
  117. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +27 -1015
  118. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +5 -63
  119. data/vendor/faiss/faiss/impl/ResultHandler.h +232 -176
  120. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +444 -104
  121. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +0 -8
  122. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +280 -42
  123. data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +21 -14
  124. data/vendor/faiss/faiss/impl/code_distance/code_distance.h +22 -12
  125. data/vendor/faiss/faiss/impl/index_read.cpp +45 -19
  126. data/vendor/faiss/faiss/impl/index_write.cpp +60 -41
  127. data/vendor/faiss/faiss/impl/io.cpp +10 -10
  128. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -1
  129. data/vendor/faiss/faiss/impl/platform_macros.h +18 -1
  130. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +3 -0
  131. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +7 -6
  132. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +52 -38
  133. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +40 -49
  134. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +960 -0
  135. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +176 -0
  136. data/vendor/faiss/faiss/impl/simd_result_handlers.h +374 -202
  137. data/vendor/faiss/faiss/index_factory.cpp +10 -7
  138. data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -1
  139. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +27 -9
  140. data/vendor/faiss/faiss/invlists/InvertedLists.h +12 -3
  141. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +3 -3
  142. data/vendor/faiss/faiss/python/python_callbacks.cpp +1 -1
  143. data/vendor/faiss/faiss/utils/Heap.cpp +3 -1
  144. data/vendor/faiss/faiss/utils/WorkerThread.h +1 -0
  145. data/vendor/faiss/faiss/utils/distances.cpp +128 -74
  146. data/vendor/faiss/faiss/utils/distances.h +81 -4
  147. data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +5 -5
  148. data/vendor/faiss/faiss/utils/distances_fused/avx512.h +2 -2
  149. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +2 -2
  150. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +1 -1
  151. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +5 -5
  152. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.h +1 -1
  153. data/vendor/faiss/faiss/utils/distances_simd.cpp +428 -70
  154. data/vendor/faiss/faiss/utils/fp16-arm.h +29 -0
  155. data/vendor/faiss/faiss/utils/fp16.h +2 -0
  156. data/vendor/faiss/faiss/utils/hamming.cpp +162 -110
  157. data/vendor/faiss/faiss/utils/hamming.h +58 -0
  158. data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +16 -89
  159. data/vendor/faiss/faiss/utils/hamming_distance/common.h +1 -0
  160. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +15 -87
  161. data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +57 -0
  162. data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +14 -104
  163. data/vendor/faiss/faiss/utils/partitioning.cpp +3 -4
  164. data/vendor/faiss/faiss/utils/prefetch.h +77 -0
  165. data/vendor/faiss/faiss/utils/quantize_lut.cpp +0 -14
  166. data/vendor/faiss/faiss/utils/simdlib_avx2.h +0 -6
  167. data/vendor/faiss/faiss/utils/simdlib_neon.h +72 -77
  168. data/vendor/faiss/faiss/utils/sorting.cpp +140 -5
  169. data/vendor/faiss/faiss/utils/sorting.h +27 -0
  170. data/vendor/faiss/faiss/utils/utils.cpp +112 -6
  171. data/vendor/faiss/faiss/utils/utils.h +57 -20
  172. metadata +11 -4
@@ -7,8 +7,8 @@
7
7
 
8
8
  #include <faiss/IndexFastScan.h>
9
9
 
10
- #include <limits.h>
11
10
  #include <cassert>
11
+ #include <climits>
12
12
  #include <memory>
13
13
 
14
14
  #include <omp.h>
@@ -37,22 +37,22 @@ inline size_t roundup(size_t a, size_t b) {
37
37
 
38
38
  void IndexFastScan::init_fastscan(
39
39
  int d,
40
- size_t M,
41
- size_t nbits,
40
+ size_t M_2,
41
+ size_t nbits_2,
42
42
  MetricType metric,
43
43
  int bbs) {
44
- FAISS_THROW_IF_NOT(nbits == 4);
44
+ FAISS_THROW_IF_NOT(nbits_2 == 4);
45
45
  FAISS_THROW_IF_NOT(bbs % 32 == 0);
46
46
  this->d = d;
47
- this->M = M;
48
- this->nbits = nbits;
47
+ this->M = M_2;
48
+ this->nbits = nbits_2;
49
49
  this->metric_type = metric;
50
50
  this->bbs = bbs;
51
- ksub = (1 << nbits);
51
+ ksub = (1 << nbits_2);
52
52
 
53
- code_size = (M * nbits + 7) / 8;
53
+ code_size = (M_2 * nbits_2 + 7) / 8;
54
54
  ntotal = ntotal2 = 0;
55
- M2 = roundup(M, 2);
55
+ M2 = roundup(M_2, 2);
56
56
  is_trained = false;
57
57
  }
58
58
 
@@ -158,7 +158,7 @@ void IndexFastScan::merge_from(Index& otherIndex, idx_t add_id) {
158
158
 
159
159
  namespace {
160
160
 
161
- template <class C, typename dis_t, class Scaler>
161
+ template <class C, typename dis_t>
162
162
  void estimators_from_tables_generic(
163
163
  const IndexFastScan& index,
164
164
  const uint8_t* codes,
@@ -167,25 +167,28 @@ void estimators_from_tables_generic(
167
167
  size_t k,
168
168
  typename C::T* heap_dis,
169
169
  int64_t* heap_ids,
170
- const Scaler& scaler) {
170
+ const NormTableScaler* scaler) {
171
171
  using accu_t = typename C::T;
172
172
 
173
173
  for (size_t j = 0; j < ncodes; ++j) {
174
174
  BitstringReader bsr(codes + j * index.code_size, index.code_size);
175
175
  accu_t dis = 0;
176
176
  const dis_t* dt = dis_table;
177
- for (size_t m = 0; m < index.M - scaler.nscale; m++) {
177
+ int nscale = scaler ? scaler->nscale : 0;
178
+
179
+ for (size_t m = 0; m < index.M - nscale; m++) {
178
180
  uint64_t c = bsr.read(index.nbits);
179
181
  dis += dt[c];
180
182
  dt += index.ksub;
181
183
  }
182
184
 
183
- for (size_t m = 0; m < scaler.nscale; m++) {
184
- uint64_t c = bsr.read(index.nbits);
185
- dis += scaler.scale_one(dt[c]);
186
- dt += index.ksub;
185
+ if (nscale) {
186
+ for (size_t m = 0; m < nscale; m++) {
187
+ uint64_t c = bsr.read(index.nbits);
188
+ dis += scaler->scale_one(dt[c]);
189
+ dt += index.ksub;
190
+ }
187
191
  }
188
-
189
192
  if (C::cmp(heap_dis[0], dis)) {
190
193
  heap_pop<C>(k, heap_dis, heap_ids);
191
194
  heap_push<C>(k, heap_dis, heap_ids, dis, j);
@@ -193,6 +196,27 @@ void estimators_from_tables_generic(
193
196
  }
194
197
  }
195
198
 
199
+ template <class C>
200
+ ResultHandlerCompare<C, false>* make_knn_handler(
201
+ int impl,
202
+ idx_t n,
203
+ idx_t k,
204
+ size_t ntotal,
205
+ float* distances,
206
+ idx_t* labels) {
207
+ using HeapHC = HeapHandler<C, false>;
208
+ using ReservoirHC = ReservoirHandler<C, false>;
209
+ using SingleResultHC = SingleResultHandler<C, false>;
210
+
211
+ if (k == 1) {
212
+ return new SingleResultHC(n, ntotal, distances, labels);
213
+ } else if (impl % 2 == 0) {
214
+ return new HeapHC(n, ntotal, k, distances, labels);
215
+ } else /* if (impl % 2 == 1) */ {
216
+ return new ReservoirHC(n, ntotal, k, 2 * k, distances, labels);
217
+ }
218
+ }
219
+
196
220
  } // anonymous namespace
197
221
 
198
222
  using namespace quantize_lut;
@@ -241,22 +265,21 @@ void IndexFastScan::search(
241
265
  !params, "search params not supported for this index");
242
266
  FAISS_THROW_IF_NOT(k > 0);
243
267
 
244
- DummyScaler scaler;
245
268
  if (metric_type == METRIC_L2) {
246
- search_dispatch_implem<true>(n, x, k, distances, labels, scaler);
269
+ search_dispatch_implem<true>(n, x, k, distances, labels, nullptr);
247
270
  } else {
248
- search_dispatch_implem<false>(n, x, k, distances, labels, scaler);
271
+ search_dispatch_implem<false>(n, x, k, distances, labels, nullptr);
249
272
  }
250
273
  }
251
274
 
252
- template <bool is_max, class Scaler>
275
+ template <bool is_max>
253
276
  void IndexFastScan::search_dispatch_implem(
254
277
  idx_t n,
255
278
  const float* x,
256
279
  idx_t k,
257
280
  float* distances,
258
281
  idx_t* labels,
259
- const Scaler& scaler) const {
282
+ const NormTableScaler* scaler) const {
260
283
  using Cfloat = typename std::conditional<
261
284
  is_max,
262
285
  CMax<float, int64_t>,
@@ -319,14 +342,14 @@ void IndexFastScan::search_dispatch_implem(
319
342
  }
320
343
  }
321
344
 
322
- template <class Cfloat, class Scaler>
345
+ template <class Cfloat>
323
346
  void IndexFastScan::search_implem_234(
324
347
  idx_t n,
325
348
  const float* x,
326
349
  idx_t k,
327
350
  float* distances,
328
351
  idx_t* labels,
329
- const Scaler& scaler) const {
352
+ const NormTableScaler* scaler) const {
330
353
  FAISS_THROW_IF_NOT(implem == 2 || implem == 3 || implem == 4);
331
354
 
332
355
  const size_t dim12 = ksub * M;
@@ -378,7 +401,7 @@ void IndexFastScan::search_implem_234(
378
401
  }
379
402
  }
380
403
 
381
- template <class C, class Scaler>
404
+ template <class C>
382
405
  void IndexFastScan::search_implem_12(
383
406
  idx_t n,
384
407
  const float* x,
@@ -386,7 +409,8 @@ void IndexFastScan::search_implem_12(
386
409
  float* distances,
387
410
  idx_t* labels,
388
411
  int impl,
389
- const Scaler& scaler) const {
412
+ const NormTableScaler* scaler) const {
413
+ using RH = ResultHandlerCompare<C, false>;
390
414
  FAISS_THROW_IF_NOT(bbs == 32);
391
415
 
392
416
  // handle qbs2 blocking by recursive call
@@ -432,63 +456,31 @@ void IndexFastScan::search_implem_12(
432
456
  pq4_pack_LUT_qbs(qbs, M2, quantized_dis_tables.get(), LUT.get());
433
457
  FAISS_THROW_IF_NOT(LUT_nq == n);
434
458
 
435
- if (k == 1) {
436
- SingleResultHandler<C> handler(n, ntotal);
437
- if (skip & 4) {
438
- // pass
439
- } else {
440
- handler.disable = bool(skip & 2);
441
- pq4_accumulate_loop_qbs(
442
- qbs, ntotal2, M2, codes.get(), LUT.get(), handler, scaler);
443
- }
459
+ std::unique_ptr<RH> handler(
460
+ make_knn_handler<C>(impl, n, k, ntotal, distances, labels));
461
+ handler->disable = bool(skip & 2);
462
+ handler->normalizers = normalizers.get();
444
463
 
445
- handler.to_flat_arrays(distances, labels, normalizers.get());
446
-
447
- } else if (impl == 12) {
448
- std::vector<uint16_t> tmp_dis(n * k);
449
- std::vector<int32_t> tmp_ids(n * k);
450
-
451
- if (skip & 4) {
452
- // skip
453
- } else {
454
- HeapHandler<C> handler(
455
- n, tmp_dis.data(), tmp_ids.data(), k, ntotal);
456
- handler.disable = bool(skip & 2);
457
-
458
- pq4_accumulate_loop_qbs(
459
- qbs, ntotal2, M2, codes.get(), LUT.get(), handler, scaler);
460
-
461
- if (!(skip & 8)) {
462
- handler.to_flat_arrays(distances, labels, normalizers.get());
463
- }
464
- }
465
-
466
- } else { // impl == 13
467
-
468
- ReservoirHandler<C> handler(n, ntotal, k, 2 * k);
469
- handler.disable = bool(skip & 2);
470
-
471
- if (skip & 4) {
472
- // skip
473
- } else {
474
- pq4_accumulate_loop_qbs(
475
- qbs, ntotal2, M2, codes.get(), LUT.get(), handler, scaler);
476
- }
477
-
478
- if (!(skip & 8)) {
479
- handler.to_flat_arrays(distances, labels, normalizers.get());
480
- }
481
-
482
- FastScan_stats.t0 += handler.times[0];
483
- FastScan_stats.t1 += handler.times[1];
484
- FastScan_stats.t2 += handler.times[2];
485
- FastScan_stats.t3 += handler.times[3];
464
+ if (skip & 4) {
465
+ // pass
466
+ } else {
467
+ pq4_accumulate_loop_qbs(
468
+ qbs,
469
+ ntotal2,
470
+ M2,
471
+ codes.get(),
472
+ LUT.get(),
473
+ *handler.get(),
474
+ scaler);
475
+ }
476
+ if (!(skip & 8)) {
477
+ handler->end();
486
478
  }
487
479
  }
488
480
 
489
481
  FastScanStats FastScan_stats;
490
482
 
491
- template <class C, class Scaler>
483
+ template <class C>
492
484
  void IndexFastScan::search_implem_14(
493
485
  idx_t n,
494
486
  const float* x,
@@ -496,7 +488,8 @@ void IndexFastScan::search_implem_14(
496
488
  float* distances,
497
489
  idx_t* labels,
498
490
  int impl,
499
- const Scaler& scaler) const {
491
+ const NormTableScaler* scaler) const {
492
+ using RH = ResultHandlerCompare<C, false>;
500
493
  FAISS_THROW_IF_NOT(bbs % 32 == 0);
501
494
 
502
495
  int qbs2 = qbs == 0 ? 4 : qbs;
@@ -531,91 +524,29 @@ void IndexFastScan::search_implem_14(
531
524
  AlignedTable<uint8_t> LUT(n * dim12);
532
525
  pq4_pack_LUT(n, M2, quantized_dis_tables.get(), LUT.get());
533
526
 
534
- if (k == 1) {
535
- SingleResultHandler<C> handler(n, ntotal);
536
- if (skip & 4) {
537
- // pass
538
- } else {
539
- handler.disable = bool(skip & 2);
540
- pq4_accumulate_loop(
541
- n,
542
- ntotal2,
543
- bbs,
544
- M2,
545
- codes.get(),
546
- LUT.get(),
547
- handler,
548
- scaler);
549
- }
550
- handler.to_flat_arrays(distances, labels, normalizers.get());
551
-
552
- } else if (impl == 14) {
553
- std::vector<uint16_t> tmp_dis(n * k);
554
- std::vector<int32_t> tmp_ids(n * k);
555
-
556
- if (skip & 4) {
557
- // skip
558
- } else if (k > 1) {
559
- HeapHandler<C> handler(
560
- n, tmp_dis.data(), tmp_ids.data(), k, ntotal);
561
- handler.disable = bool(skip & 2);
562
-
563
- pq4_accumulate_loop(
564
- n,
565
- ntotal2,
566
- bbs,
567
- M2,
568
- codes.get(),
569
- LUT.get(),
570
- handler,
571
- scaler);
572
-
573
- if (!(skip & 8)) {
574
- handler.to_flat_arrays(distances, labels, normalizers.get());
575
- }
576
- }
577
-
578
- } else { // impl == 15
579
-
580
- ReservoirHandler<C> handler(n, ntotal, k, 2 * k);
581
- handler.disable = bool(skip & 2);
582
-
583
- if (skip & 4) {
584
- // skip
585
- } else {
586
- pq4_accumulate_loop(
587
- n,
588
- ntotal2,
589
- bbs,
590
- M2,
591
- codes.get(),
592
- LUT.get(),
593
- handler,
594
- scaler);
595
- }
527
+ std::unique_ptr<RH> handler(
528
+ make_knn_handler<C>(impl, n, k, ntotal, distances, labels));
529
+ handler->disable = bool(skip & 2);
530
+ handler->normalizers = normalizers.get();
596
531
 
597
- if (!(skip & 8)) {
598
- handler.to_flat_arrays(distances, labels, normalizers.get());
599
- }
532
+ if (skip & 4) {
533
+ // pass
534
+ } else {
535
+ pq4_accumulate_loop(
536
+ n,
537
+ ntotal2,
538
+ bbs,
539
+ M2,
540
+ codes.get(),
541
+ LUT.get(),
542
+ *handler.get(),
543
+ scaler);
544
+ }
545
+ if (!(skip & 8)) {
546
+ handler->end();
600
547
  }
601
548
  }
602
549
 
603
- template void IndexFastScan::search_dispatch_implem<true, NormTableScaler>(
604
- idx_t n,
605
- const float* x,
606
- idx_t k,
607
- float* distances,
608
- idx_t* labels,
609
- const NormTableScaler& scaler) const;
610
-
611
- template void IndexFastScan::search_dispatch_implem<false, NormTableScaler>(
612
- idx_t n,
613
- const float* x,
614
- idx_t k,
615
- float* distances,
616
- idx_t* labels,
617
- const NormTableScaler& scaler) const;
618
-
619
550
  void IndexFastScan::reconstruct(idx_t key, float* recons) const {
620
551
  std::vector<uint8_t> code(code_size, 0);
621
552
  BitstringWriter bsw(code.data(), code_size);
@@ -13,6 +13,7 @@
13
13
  namespace faiss {
14
14
 
15
15
  struct CodePacker;
16
+ struct NormTableScaler;
16
17
 
17
18
  /** Fast scan version of IndexPQ and IndexAQ. Works for 4-bit PQ and AQ for now.
18
19
  *
@@ -87,25 +88,25 @@ struct IndexFastScan : Index {
87
88
  uint8_t* lut,
88
89
  float* normalizers) const;
89
90
 
90
- template <bool is_max, class Scaler>
91
+ template <bool is_max>
91
92
  void search_dispatch_implem(
92
93
  idx_t n,
93
94
  const float* x,
94
95
  idx_t k,
95
96
  float* distances,
96
97
  idx_t* labels,
97
- const Scaler& scaler) const;
98
+ const NormTableScaler* scaler) const;
98
99
 
99
- template <class Cfloat, class Scaler>
100
+ template <class Cfloat>
100
101
  void search_implem_234(
101
102
  idx_t n,
102
103
  const float* x,
103
104
  idx_t k,
104
105
  float* distances,
105
106
  idx_t* labels,
106
- const Scaler& scaler) const;
107
+ const NormTableScaler* scaler) const;
107
108
 
108
- template <class C, class Scaler>
109
+ template <class C>
109
110
  void search_implem_12(
110
111
  idx_t n,
111
112
  const float* x,
@@ -113,9 +114,9 @@ struct IndexFastScan : Index {
113
114
  float* distances,
114
115
  idx_t* labels,
115
116
  int impl,
116
- const Scaler& scaler) const;
117
+ const NormTableScaler* scaler) const;
117
118
 
118
- template <class C, class Scaler>
119
+ template <class C>
119
120
  void search_implem_14(
120
121
  idx_t n,
121
122
  const float* x,
@@ -123,7 +124,7 @@ struct IndexFastScan : Index {
123
124
  float* distances,
124
125
  idx_t* labels,
125
126
  int impl,
126
- const Scaler& scaler) const;
127
+ const NormTableScaler* scaler) const;
127
128
 
128
129
  void reconstruct(idx_t key, float* recons) const override;
129
130
  size_t remove_ids(const IDSelector& sel) override;
@@ -14,6 +14,7 @@
14
14
  #include <faiss/utils/Heap.h>
15
15
  #include <faiss/utils/distances.h>
16
16
  #include <faiss/utils/extra_distances.h>
17
+ #include <faiss/utils/prefetch.h>
17
18
  #include <faiss/utils/sorting.h>
18
19
  #include <faiss/utils/utils.h>
19
20
  #include <cstring>
@@ -122,6 +123,39 @@ struct FlatL2Dis : FlatCodesDistanceComputer {
122
123
  void set_query(const float* x) override {
123
124
  q = x;
124
125
  }
126
+
127
+ // compute four distances
128
+ void distances_batch_4(
129
+ const idx_t idx0,
130
+ const idx_t idx1,
131
+ const idx_t idx2,
132
+ const idx_t idx3,
133
+ float& dis0,
134
+ float& dis1,
135
+ float& dis2,
136
+ float& dis3) final override {
137
+ ndis += 4;
138
+
139
+ // compute first, assign next
140
+ const float* __restrict y0 =
141
+ reinterpret_cast<const float*>(codes + idx0 * code_size);
142
+ const float* __restrict y1 =
143
+ reinterpret_cast<const float*>(codes + idx1 * code_size);
144
+ const float* __restrict y2 =
145
+ reinterpret_cast<const float*>(codes + idx2 * code_size);
146
+ const float* __restrict y3 =
147
+ reinterpret_cast<const float*>(codes + idx3 * code_size);
148
+
149
+ float dp0 = 0;
150
+ float dp1 = 0;
151
+ float dp2 = 0;
152
+ float dp3 = 0;
153
+ fvec_L2sqr_batch_4(q, y0, y1, y2, y3, d, dp0, dp1, dp2, dp3);
154
+ dis0 = dp0;
155
+ dis1 = dp1;
156
+ dis2 = dp2;
157
+ dis3 = dp3;
158
+ }
125
159
  };
126
160
 
127
161
  struct FlatIPDis : FlatCodesDistanceComputer {
@@ -131,13 +165,13 @@ struct FlatIPDis : FlatCodesDistanceComputer {
131
165
  const float* b;
132
166
  size_t ndis;
133
167
 
134
- float symmetric_dis(idx_t i, idx_t j) override {
168
+ float symmetric_dis(idx_t i, idx_t j) final override {
135
169
  return fvec_inner_product(b + j * d, b + i * d, d);
136
170
  }
137
171
 
138
- float distance_to_code(const uint8_t* code) final {
172
+ float distance_to_code(const uint8_t* code) final override {
139
173
  ndis++;
140
- return fvec_inner_product(q, (float*)code, d);
174
+ return fvec_inner_product(q, (const float*)code, d);
141
175
  }
142
176
 
143
177
  explicit FlatIPDis(const IndexFlat& storage, const float* q = nullptr)
@@ -153,6 +187,39 @@ struct FlatIPDis : FlatCodesDistanceComputer {
153
187
  void set_query(const float* x) override {
154
188
  q = x;
155
189
  }
190
+
191
+ // compute four distances
192
+ void distances_batch_4(
193
+ const idx_t idx0,
194
+ const idx_t idx1,
195
+ const idx_t idx2,
196
+ const idx_t idx3,
197
+ float& dis0,
198
+ float& dis1,
199
+ float& dis2,
200
+ float& dis3) final override {
201
+ ndis += 4;
202
+
203
+ // compute first, assign next
204
+ const float* __restrict y0 =
205
+ reinterpret_cast<const float*>(codes + idx0 * code_size);
206
+ const float* __restrict y1 =
207
+ reinterpret_cast<const float*>(codes + idx1 * code_size);
208
+ const float* __restrict y2 =
209
+ reinterpret_cast<const float*>(codes + idx2 * code_size);
210
+ const float* __restrict y3 =
211
+ reinterpret_cast<const float*>(codes + idx3 * code_size);
212
+
213
+ float dp0 = 0;
214
+ float dp1 = 0;
215
+ float dp2 = 0;
216
+ float dp3 = 0;
217
+ fvec_inner_product_batch_4(q, y0, y1, y2, y3, d, dp0, dp1, dp2, dp3);
218
+ dis0 = dp0;
219
+ dis1 = dp1;
220
+ dis2 = dp2;
221
+ dis3 = dp3;
222
+ }
156
223
  };
157
224
 
158
225
  } // namespace
@@ -184,6 +251,131 @@ void IndexFlat::sa_decode(idx_t n, const uint8_t* bytes, float* x) const {
184
251
  }
185
252
  }
186
253
 
254
+ /***************************************************
255
+ * IndexFlatL2
256
+ ***************************************************/
257
+
258
+ namespace {
259
+ struct FlatL2WithNormsDis : FlatCodesDistanceComputer {
260
+ size_t d;
261
+ idx_t nb;
262
+ const float* q;
263
+ const float* b;
264
+ size_t ndis;
265
+
266
+ const float* l2norms;
267
+ float query_l2norm;
268
+
269
+ float distance_to_code(const uint8_t* code) final override {
270
+ ndis++;
271
+ return fvec_L2sqr(q, (float*)code, d);
272
+ }
273
+
274
+ float operator()(const idx_t i) final override {
275
+ const float* __restrict y =
276
+ reinterpret_cast<const float*>(codes + i * code_size);
277
+
278
+ prefetch_L2(l2norms + i);
279
+ const float dp0 = fvec_inner_product(q, y, d);
280
+ return query_l2norm + l2norms[i] - 2 * dp0;
281
+ }
282
+
283
+ float symmetric_dis(idx_t i, idx_t j) final override {
284
+ const float* __restrict yi =
285
+ reinterpret_cast<const float*>(codes + i * code_size);
286
+ const float* __restrict yj =
287
+ reinterpret_cast<const float*>(codes + j * code_size);
288
+
289
+ prefetch_L2(l2norms + i);
290
+ prefetch_L2(l2norms + j);
291
+ const float dp0 = fvec_inner_product(yi, yj, d);
292
+ return l2norms[i] + l2norms[j] - 2 * dp0;
293
+ }
294
+
295
+ explicit FlatL2WithNormsDis(
296
+ const IndexFlatL2& storage,
297
+ const float* q = nullptr)
298
+ : FlatCodesDistanceComputer(
299
+ storage.codes.data(),
300
+ storage.code_size),
301
+ d(storage.d),
302
+ nb(storage.ntotal),
303
+ q(q),
304
+ b(storage.get_xb()),
305
+ ndis(0),
306
+ l2norms(storage.cached_l2norms.data()),
307
+ query_l2norm(0) {}
308
+
309
+ void set_query(const float* x) override {
310
+ q = x;
311
+ query_l2norm = fvec_norm_L2sqr(q, d);
312
+ }
313
+
314
+ // compute four distances
315
+ void distances_batch_4(
316
+ const idx_t idx0,
317
+ const idx_t idx1,
318
+ const idx_t idx2,
319
+ const idx_t idx3,
320
+ float& dis0,
321
+ float& dis1,
322
+ float& dis2,
323
+ float& dis3) final override {
324
+ ndis += 4;
325
+
326
+ // compute first, assign next
327
+ const float* __restrict y0 =
328
+ reinterpret_cast<const float*>(codes + idx0 * code_size);
329
+ const float* __restrict y1 =
330
+ reinterpret_cast<const float*>(codes + idx1 * code_size);
331
+ const float* __restrict y2 =
332
+ reinterpret_cast<const float*>(codes + idx2 * code_size);
333
+ const float* __restrict y3 =
334
+ reinterpret_cast<const float*>(codes + idx3 * code_size);
335
+
336
+ prefetch_L2(l2norms + idx0);
337
+ prefetch_L2(l2norms + idx1);
338
+ prefetch_L2(l2norms + idx2);
339
+ prefetch_L2(l2norms + idx3);
340
+
341
+ float dp0 = 0;
342
+ float dp1 = 0;
343
+ float dp2 = 0;
344
+ float dp3 = 0;
345
+ fvec_inner_product_batch_4(q, y0, y1, y2, y3, d, dp0, dp1, dp2, dp3);
346
+ dis0 = query_l2norm + l2norms[idx0] - 2 * dp0;
347
+ dis1 = query_l2norm + l2norms[idx1] - 2 * dp1;
348
+ dis2 = query_l2norm + l2norms[idx2] - 2 * dp2;
349
+ dis3 = query_l2norm + l2norms[idx3] - 2 * dp3;
350
+ }
351
+ };
352
+
353
+ } // namespace
354
+
355
+ void IndexFlatL2::sync_l2norms() {
356
+ cached_l2norms.resize(ntotal);
357
+ fvec_norms_L2sqr(
358
+ cached_l2norms.data(),
359
+ reinterpret_cast<const float*>(codes.data()),
360
+ d,
361
+ ntotal);
362
+ }
363
+
364
+ void IndexFlatL2::clear_l2norms() {
365
+ cached_l2norms.clear();
366
+ cached_l2norms.shrink_to_fit();
367
+ }
368
+
369
+ FlatCodesDistanceComputer* IndexFlatL2::get_FlatCodesDistanceComputer() const {
370
+ if (metric_type == METRIC_L2) {
371
+ if (!cached_l2norms.empty()) {
372
+ return new FlatL2WithNormsDis(*this);
373
+ }
374
+ }
375
+
376
+ return IndexFlat::get_FlatCodesDistanceComputer();
377
+ }
378
+
187
379
  /***************************************************
188
380
  * IndexFlat1D
189
381
  ***************************************************/