faiss 0.3.0 → 0.3.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (171) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +5 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +1 -1
  5. data/ext/faiss/extconf.rb +9 -2
  6. data/ext/faiss/index.cpp +1 -1
  7. data/ext/faiss/index_binary.cpp +2 -2
  8. data/ext/faiss/product_quantizer.cpp +1 -1
  9. data/lib/faiss/version.rb +1 -1
  10. data/vendor/faiss/faiss/AutoTune.cpp +7 -7
  11. data/vendor/faiss/faiss/AutoTune.h +0 -1
  12. data/vendor/faiss/faiss/Clustering.cpp +4 -18
  13. data/vendor/faiss/faiss/Clustering.h +31 -21
  14. data/vendor/faiss/faiss/IVFlib.cpp +22 -11
  15. data/vendor/faiss/faiss/Index.cpp +1 -1
  16. data/vendor/faiss/faiss/Index.h +20 -5
  17. data/vendor/faiss/faiss/Index2Layer.cpp +7 -7
  18. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +176 -166
  19. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +15 -15
  20. data/vendor/faiss/faiss/IndexBinary.cpp +9 -4
  21. data/vendor/faiss/faiss/IndexBinary.h +8 -19
  22. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +2 -1
  23. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +24 -31
  24. data/vendor/faiss/faiss/IndexBinaryHash.cpp +25 -50
  25. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +106 -187
  26. data/vendor/faiss/faiss/IndexFastScan.cpp +90 -159
  27. data/vendor/faiss/faiss/IndexFastScan.h +9 -8
  28. data/vendor/faiss/faiss/IndexFlat.cpp +195 -3
  29. data/vendor/faiss/faiss/IndexFlat.h +20 -1
  30. data/vendor/faiss/faiss/IndexFlatCodes.cpp +11 -0
  31. data/vendor/faiss/faiss/IndexFlatCodes.h +3 -1
  32. data/vendor/faiss/faiss/IndexHNSW.cpp +112 -316
  33. data/vendor/faiss/faiss/IndexHNSW.h +12 -48
  34. data/vendor/faiss/faiss/IndexIDMap.cpp +69 -28
  35. data/vendor/faiss/faiss/IndexIDMap.h +24 -2
  36. data/vendor/faiss/faiss/IndexIVF.cpp +159 -53
  37. data/vendor/faiss/faiss/IndexIVF.h +37 -5
  38. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +18 -26
  39. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +3 -2
  40. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +19 -46
  41. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +4 -3
  42. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +433 -405
  43. data/vendor/faiss/faiss/IndexIVFFastScan.h +56 -26
  44. data/vendor/faiss/faiss/IndexIVFFlat.cpp +15 -5
  45. data/vendor/faiss/faiss/IndexIVFFlat.h +3 -2
  46. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +172 -0
  47. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.h +56 -0
  48. data/vendor/faiss/faiss/IndexIVFPQ.cpp +78 -122
  49. data/vendor/faiss/faiss/IndexIVFPQ.h +6 -7
  50. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +18 -50
  51. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +4 -3
  52. data/vendor/faiss/faiss/IndexIVFPQR.cpp +45 -29
  53. data/vendor/faiss/faiss/IndexIVFPQR.h +5 -2
  54. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +25 -27
  55. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +6 -6
  56. data/vendor/faiss/faiss/IndexLSH.cpp +14 -16
  57. data/vendor/faiss/faiss/IndexNNDescent.cpp +3 -4
  58. data/vendor/faiss/faiss/IndexNSG.cpp +11 -27
  59. data/vendor/faiss/faiss/IndexNSG.h +10 -10
  60. data/vendor/faiss/faiss/IndexPQ.cpp +72 -88
  61. data/vendor/faiss/faiss/IndexPQ.h +1 -4
  62. data/vendor/faiss/faiss/IndexPQFastScan.cpp +1 -1
  63. data/vendor/faiss/faiss/IndexPreTransform.cpp +25 -31
  64. data/vendor/faiss/faiss/IndexRefine.cpp +49 -19
  65. data/vendor/faiss/faiss/IndexRefine.h +7 -0
  66. data/vendor/faiss/faiss/IndexReplicas.cpp +23 -26
  67. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +22 -16
  68. data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -4
  69. data/vendor/faiss/faiss/IndexShards.cpp +21 -29
  70. data/vendor/faiss/faiss/IndexShardsIVF.cpp +1 -2
  71. data/vendor/faiss/faiss/MatrixStats.cpp +17 -32
  72. data/vendor/faiss/faiss/MatrixStats.h +21 -9
  73. data/vendor/faiss/faiss/MetaIndexes.cpp +35 -35
  74. data/vendor/faiss/faiss/VectorTransform.cpp +13 -26
  75. data/vendor/faiss/faiss/VectorTransform.h +7 -7
  76. data/vendor/faiss/faiss/clone_index.cpp +15 -10
  77. data/vendor/faiss/faiss/clone_index.h +3 -0
  78. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +87 -4
  79. data/vendor/faiss/faiss/gpu/GpuCloner.h +22 -0
  80. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +7 -0
  81. data/vendor/faiss/faiss/gpu/GpuDistance.h +46 -38
  82. data/vendor/faiss/faiss/gpu/GpuIndex.h +28 -4
  83. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +4 -4
  84. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +8 -9
  85. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +18 -3
  86. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -11
  87. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +1 -3
  88. data/vendor/faiss/faiss/gpu/GpuResources.cpp +24 -3
  89. data/vendor/faiss/faiss/gpu/GpuResources.h +39 -11
  90. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +117 -17
  91. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +57 -3
  92. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +1 -1
  93. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +25 -0
  94. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +129 -9
  95. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +267 -40
  96. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +299 -208
  97. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +1 -0
  98. data/vendor/faiss/faiss/gpu/utils/RaftUtils.h +75 -0
  99. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +3 -1
  100. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +5 -5
  101. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +1 -1
  102. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +1 -2
  103. data/vendor/faiss/faiss/impl/DistanceComputer.h +24 -1
  104. data/vendor/faiss/faiss/impl/FaissException.h +13 -34
  105. data/vendor/faiss/faiss/impl/HNSW.cpp +321 -70
  106. data/vendor/faiss/faiss/impl/HNSW.h +9 -8
  107. data/vendor/faiss/faiss/impl/IDSelector.h +4 -4
  108. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +3 -1
  109. data/vendor/faiss/faiss/impl/NNDescent.cpp +29 -19
  110. data/vendor/faiss/faiss/impl/NSG.h +1 -1
  111. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +14 -12
  112. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +1 -1
  113. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +24 -22
  114. data/vendor/faiss/faiss/impl/ProductQuantizer.h +1 -1
  115. data/vendor/faiss/faiss/impl/Quantizer.h +1 -1
  116. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +27 -1015
  117. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +5 -63
  118. data/vendor/faiss/faiss/impl/ResultHandler.h +232 -176
  119. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +444 -104
  120. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +0 -8
  121. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +280 -42
  122. data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +21 -14
  123. data/vendor/faiss/faiss/impl/code_distance/code_distance.h +22 -12
  124. data/vendor/faiss/faiss/impl/index_read.cpp +45 -19
  125. data/vendor/faiss/faiss/impl/index_write.cpp +60 -41
  126. data/vendor/faiss/faiss/impl/io.cpp +10 -10
  127. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -1
  128. data/vendor/faiss/faiss/impl/platform_macros.h +18 -1
  129. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +3 -0
  130. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +7 -6
  131. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +52 -38
  132. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +40 -49
  133. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +960 -0
  134. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +176 -0
  135. data/vendor/faiss/faiss/impl/simd_result_handlers.h +374 -202
  136. data/vendor/faiss/faiss/index_factory.cpp +10 -7
  137. data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -1
  138. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +27 -9
  139. data/vendor/faiss/faiss/invlists/InvertedLists.h +12 -3
  140. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +3 -3
  141. data/vendor/faiss/faiss/python/python_callbacks.cpp +1 -1
  142. data/vendor/faiss/faiss/utils/Heap.cpp +3 -1
  143. data/vendor/faiss/faiss/utils/WorkerThread.h +1 -0
  144. data/vendor/faiss/faiss/utils/distances.cpp +128 -74
  145. data/vendor/faiss/faiss/utils/distances.h +81 -4
  146. data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +5 -5
  147. data/vendor/faiss/faiss/utils/distances_fused/avx512.h +2 -2
  148. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +2 -2
  149. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +1 -1
  150. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +5 -5
  151. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.h +1 -1
  152. data/vendor/faiss/faiss/utils/distances_simd.cpp +428 -70
  153. data/vendor/faiss/faiss/utils/fp16-arm.h +29 -0
  154. data/vendor/faiss/faiss/utils/fp16.h +2 -0
  155. data/vendor/faiss/faiss/utils/hamming.cpp +162 -110
  156. data/vendor/faiss/faiss/utils/hamming.h +58 -0
  157. data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +16 -89
  158. data/vendor/faiss/faiss/utils/hamming_distance/common.h +1 -0
  159. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +15 -87
  160. data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +57 -0
  161. data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +14 -104
  162. data/vendor/faiss/faiss/utils/partitioning.cpp +3 -4
  163. data/vendor/faiss/faiss/utils/prefetch.h +77 -0
  164. data/vendor/faiss/faiss/utils/quantize_lut.cpp +0 -14
  165. data/vendor/faiss/faiss/utils/simdlib_avx2.h +0 -6
  166. data/vendor/faiss/faiss/utils/simdlib_neon.h +72 -77
  167. data/vendor/faiss/faiss/utils/sorting.cpp +140 -5
  168. data/vendor/faiss/faiss/utils/sorting.h +27 -0
  169. data/vendor/faiss/faiss/utils/utils.cpp +112 -6
  170. data/vendor/faiss/faiss/utils/utils.h +57 -20
  171. metadata +10 -3
@@ -7,8 +7,8 @@
7
7
 
8
8
  #include <faiss/IndexFastScan.h>
9
9
 
10
- #include <limits.h>
11
10
  #include <cassert>
11
+ #include <climits>
12
12
  #include <memory>
13
13
 
14
14
  #include <omp.h>
@@ -37,22 +37,22 @@ inline size_t roundup(size_t a, size_t b) {
37
37
 
38
38
  void IndexFastScan::init_fastscan(
39
39
  int d,
40
- size_t M,
41
- size_t nbits,
40
+ size_t M_2,
41
+ size_t nbits_2,
42
42
  MetricType metric,
43
43
  int bbs) {
44
- FAISS_THROW_IF_NOT(nbits == 4);
44
+ FAISS_THROW_IF_NOT(nbits_2 == 4);
45
45
  FAISS_THROW_IF_NOT(bbs % 32 == 0);
46
46
  this->d = d;
47
- this->M = M;
48
- this->nbits = nbits;
47
+ this->M = M_2;
48
+ this->nbits = nbits_2;
49
49
  this->metric_type = metric;
50
50
  this->bbs = bbs;
51
- ksub = (1 << nbits);
51
+ ksub = (1 << nbits_2);
52
52
 
53
- code_size = (M * nbits + 7) / 8;
53
+ code_size = (M_2 * nbits_2 + 7) / 8;
54
54
  ntotal = ntotal2 = 0;
55
- M2 = roundup(M, 2);
55
+ M2 = roundup(M_2, 2);
56
56
  is_trained = false;
57
57
  }
58
58
 
@@ -158,7 +158,7 @@ void IndexFastScan::merge_from(Index& otherIndex, idx_t add_id) {
158
158
 
159
159
  namespace {
160
160
 
161
- template <class C, typename dis_t, class Scaler>
161
+ template <class C, typename dis_t>
162
162
  void estimators_from_tables_generic(
163
163
  const IndexFastScan& index,
164
164
  const uint8_t* codes,
@@ -167,25 +167,28 @@ void estimators_from_tables_generic(
167
167
  size_t k,
168
168
  typename C::T* heap_dis,
169
169
  int64_t* heap_ids,
170
- const Scaler& scaler) {
170
+ const NormTableScaler* scaler) {
171
171
  using accu_t = typename C::T;
172
172
 
173
173
  for (size_t j = 0; j < ncodes; ++j) {
174
174
  BitstringReader bsr(codes + j * index.code_size, index.code_size);
175
175
  accu_t dis = 0;
176
176
  const dis_t* dt = dis_table;
177
- for (size_t m = 0; m < index.M - scaler.nscale; m++) {
177
+ int nscale = scaler ? scaler->nscale : 0;
178
+
179
+ for (size_t m = 0; m < index.M - nscale; m++) {
178
180
  uint64_t c = bsr.read(index.nbits);
179
181
  dis += dt[c];
180
182
  dt += index.ksub;
181
183
  }
182
184
 
183
- for (size_t m = 0; m < scaler.nscale; m++) {
184
- uint64_t c = bsr.read(index.nbits);
185
- dis += scaler.scale_one(dt[c]);
186
- dt += index.ksub;
185
+ if (nscale) {
186
+ for (size_t m = 0; m < nscale; m++) {
187
+ uint64_t c = bsr.read(index.nbits);
188
+ dis += scaler->scale_one(dt[c]);
189
+ dt += index.ksub;
190
+ }
187
191
  }
188
-
189
192
  if (C::cmp(heap_dis[0], dis)) {
190
193
  heap_pop<C>(k, heap_dis, heap_ids);
191
194
  heap_push<C>(k, heap_dis, heap_ids, dis, j);
@@ -193,6 +196,27 @@ void estimators_from_tables_generic(
193
196
  }
194
197
  }
195
198
 
199
+ template <class C>
200
+ ResultHandlerCompare<C, false>* make_knn_handler(
201
+ int impl,
202
+ idx_t n,
203
+ idx_t k,
204
+ size_t ntotal,
205
+ float* distances,
206
+ idx_t* labels) {
207
+ using HeapHC = HeapHandler<C, false>;
208
+ using ReservoirHC = ReservoirHandler<C, false>;
209
+ using SingleResultHC = SingleResultHandler<C, false>;
210
+
211
+ if (k == 1) {
212
+ return new SingleResultHC(n, ntotal, distances, labels);
213
+ } else if (impl % 2 == 0) {
214
+ return new HeapHC(n, ntotal, k, distances, labels);
215
+ } else /* if (impl % 2 == 1) */ {
216
+ return new ReservoirHC(n, ntotal, k, 2 * k, distances, labels);
217
+ }
218
+ }
219
+
196
220
  } // anonymous namespace
197
221
 
198
222
  using namespace quantize_lut;
@@ -241,22 +265,21 @@ void IndexFastScan::search(
241
265
  !params, "search params not supported for this index");
242
266
  FAISS_THROW_IF_NOT(k > 0);
243
267
 
244
- DummyScaler scaler;
245
268
  if (metric_type == METRIC_L2) {
246
- search_dispatch_implem<true>(n, x, k, distances, labels, scaler);
269
+ search_dispatch_implem<true>(n, x, k, distances, labels, nullptr);
247
270
  } else {
248
- search_dispatch_implem<false>(n, x, k, distances, labels, scaler);
271
+ search_dispatch_implem<false>(n, x, k, distances, labels, nullptr);
249
272
  }
250
273
  }
251
274
 
252
- template <bool is_max, class Scaler>
275
+ template <bool is_max>
253
276
  void IndexFastScan::search_dispatch_implem(
254
277
  idx_t n,
255
278
  const float* x,
256
279
  idx_t k,
257
280
  float* distances,
258
281
  idx_t* labels,
259
- const Scaler& scaler) const {
282
+ const NormTableScaler* scaler) const {
260
283
  using Cfloat = typename std::conditional<
261
284
  is_max,
262
285
  CMax<float, int64_t>,
@@ -319,14 +342,14 @@ void IndexFastScan::search_dispatch_implem(
319
342
  }
320
343
  }
321
344
 
322
- template <class Cfloat, class Scaler>
345
+ template <class Cfloat>
323
346
  void IndexFastScan::search_implem_234(
324
347
  idx_t n,
325
348
  const float* x,
326
349
  idx_t k,
327
350
  float* distances,
328
351
  idx_t* labels,
329
- const Scaler& scaler) const {
352
+ const NormTableScaler* scaler) const {
330
353
  FAISS_THROW_IF_NOT(implem == 2 || implem == 3 || implem == 4);
331
354
 
332
355
  const size_t dim12 = ksub * M;
@@ -378,7 +401,7 @@ void IndexFastScan::search_implem_234(
378
401
  }
379
402
  }
380
403
 
381
- template <class C, class Scaler>
404
+ template <class C>
382
405
  void IndexFastScan::search_implem_12(
383
406
  idx_t n,
384
407
  const float* x,
@@ -386,7 +409,8 @@ void IndexFastScan::search_implem_12(
386
409
  float* distances,
387
410
  idx_t* labels,
388
411
  int impl,
389
- const Scaler& scaler) const {
412
+ const NormTableScaler* scaler) const {
413
+ using RH = ResultHandlerCompare<C, false>;
390
414
  FAISS_THROW_IF_NOT(bbs == 32);
391
415
 
392
416
  // handle qbs2 blocking by recursive call
@@ -432,63 +456,31 @@ void IndexFastScan::search_implem_12(
432
456
  pq4_pack_LUT_qbs(qbs, M2, quantized_dis_tables.get(), LUT.get());
433
457
  FAISS_THROW_IF_NOT(LUT_nq == n);
434
458
 
435
- if (k == 1) {
436
- SingleResultHandler<C> handler(n, ntotal);
437
- if (skip & 4) {
438
- // pass
439
- } else {
440
- handler.disable = bool(skip & 2);
441
- pq4_accumulate_loop_qbs(
442
- qbs, ntotal2, M2, codes.get(), LUT.get(), handler, scaler);
443
- }
459
+ std::unique_ptr<RH> handler(
460
+ make_knn_handler<C>(impl, n, k, ntotal, distances, labels));
461
+ handler->disable = bool(skip & 2);
462
+ handler->normalizers = normalizers.get();
444
463
 
445
- handler.to_flat_arrays(distances, labels, normalizers.get());
446
-
447
- } else if (impl == 12) {
448
- std::vector<uint16_t> tmp_dis(n * k);
449
- std::vector<int32_t> tmp_ids(n * k);
450
-
451
- if (skip & 4) {
452
- // skip
453
- } else {
454
- HeapHandler<C> handler(
455
- n, tmp_dis.data(), tmp_ids.data(), k, ntotal);
456
- handler.disable = bool(skip & 2);
457
-
458
- pq4_accumulate_loop_qbs(
459
- qbs, ntotal2, M2, codes.get(), LUT.get(), handler, scaler);
460
-
461
- if (!(skip & 8)) {
462
- handler.to_flat_arrays(distances, labels, normalizers.get());
463
- }
464
- }
465
-
466
- } else { // impl == 13
467
-
468
- ReservoirHandler<C> handler(n, ntotal, k, 2 * k);
469
- handler.disable = bool(skip & 2);
470
-
471
- if (skip & 4) {
472
- // skip
473
- } else {
474
- pq4_accumulate_loop_qbs(
475
- qbs, ntotal2, M2, codes.get(), LUT.get(), handler, scaler);
476
- }
477
-
478
- if (!(skip & 8)) {
479
- handler.to_flat_arrays(distances, labels, normalizers.get());
480
- }
481
-
482
- FastScan_stats.t0 += handler.times[0];
483
- FastScan_stats.t1 += handler.times[1];
484
- FastScan_stats.t2 += handler.times[2];
485
- FastScan_stats.t3 += handler.times[3];
464
+ if (skip & 4) {
465
+ // pass
466
+ } else {
467
+ pq4_accumulate_loop_qbs(
468
+ qbs,
469
+ ntotal2,
470
+ M2,
471
+ codes.get(),
472
+ LUT.get(),
473
+ *handler.get(),
474
+ scaler);
475
+ }
476
+ if (!(skip & 8)) {
477
+ handler->end();
486
478
  }
487
479
  }
488
480
 
489
481
  FastScanStats FastScan_stats;
490
482
 
491
- template <class C, class Scaler>
483
+ template <class C>
492
484
  void IndexFastScan::search_implem_14(
493
485
  idx_t n,
494
486
  const float* x,
@@ -496,7 +488,8 @@ void IndexFastScan::search_implem_14(
496
488
  float* distances,
497
489
  idx_t* labels,
498
490
  int impl,
499
- const Scaler& scaler) const {
491
+ const NormTableScaler* scaler) const {
492
+ using RH = ResultHandlerCompare<C, false>;
500
493
  FAISS_THROW_IF_NOT(bbs % 32 == 0);
501
494
 
502
495
  int qbs2 = qbs == 0 ? 4 : qbs;
@@ -531,91 +524,29 @@ void IndexFastScan::search_implem_14(
531
524
  AlignedTable<uint8_t> LUT(n * dim12);
532
525
  pq4_pack_LUT(n, M2, quantized_dis_tables.get(), LUT.get());
533
526
 
534
- if (k == 1) {
535
- SingleResultHandler<C> handler(n, ntotal);
536
- if (skip & 4) {
537
- // pass
538
- } else {
539
- handler.disable = bool(skip & 2);
540
- pq4_accumulate_loop(
541
- n,
542
- ntotal2,
543
- bbs,
544
- M2,
545
- codes.get(),
546
- LUT.get(),
547
- handler,
548
- scaler);
549
- }
550
- handler.to_flat_arrays(distances, labels, normalizers.get());
551
-
552
- } else if (impl == 14) {
553
- std::vector<uint16_t> tmp_dis(n * k);
554
- std::vector<int32_t> tmp_ids(n * k);
555
-
556
- if (skip & 4) {
557
- // skip
558
- } else if (k > 1) {
559
- HeapHandler<C> handler(
560
- n, tmp_dis.data(), tmp_ids.data(), k, ntotal);
561
- handler.disable = bool(skip & 2);
562
-
563
- pq4_accumulate_loop(
564
- n,
565
- ntotal2,
566
- bbs,
567
- M2,
568
- codes.get(),
569
- LUT.get(),
570
- handler,
571
- scaler);
572
-
573
- if (!(skip & 8)) {
574
- handler.to_flat_arrays(distances, labels, normalizers.get());
575
- }
576
- }
577
-
578
- } else { // impl == 15
579
-
580
- ReservoirHandler<C> handler(n, ntotal, k, 2 * k);
581
- handler.disable = bool(skip & 2);
582
-
583
- if (skip & 4) {
584
- // skip
585
- } else {
586
- pq4_accumulate_loop(
587
- n,
588
- ntotal2,
589
- bbs,
590
- M2,
591
- codes.get(),
592
- LUT.get(),
593
- handler,
594
- scaler);
595
- }
527
+ std::unique_ptr<RH> handler(
528
+ make_knn_handler<C>(impl, n, k, ntotal, distances, labels));
529
+ handler->disable = bool(skip & 2);
530
+ handler->normalizers = normalizers.get();
596
531
 
597
- if (!(skip & 8)) {
598
- handler.to_flat_arrays(distances, labels, normalizers.get());
599
- }
532
+ if (skip & 4) {
533
+ // pass
534
+ } else {
535
+ pq4_accumulate_loop(
536
+ n,
537
+ ntotal2,
538
+ bbs,
539
+ M2,
540
+ codes.get(),
541
+ LUT.get(),
542
+ *handler.get(),
543
+ scaler);
544
+ }
545
+ if (!(skip & 8)) {
546
+ handler->end();
600
547
  }
601
548
  }
602
549
 
603
- template void IndexFastScan::search_dispatch_implem<true, NormTableScaler>(
604
- idx_t n,
605
- const float* x,
606
- idx_t k,
607
- float* distances,
608
- idx_t* labels,
609
- const NormTableScaler& scaler) const;
610
-
611
- template void IndexFastScan::search_dispatch_implem<false, NormTableScaler>(
612
- idx_t n,
613
- const float* x,
614
- idx_t k,
615
- float* distances,
616
- idx_t* labels,
617
- const NormTableScaler& scaler) const;
618
-
619
550
  void IndexFastScan::reconstruct(idx_t key, float* recons) const {
620
551
  std::vector<uint8_t> code(code_size, 0);
621
552
  BitstringWriter bsw(code.data(), code_size);
@@ -13,6 +13,7 @@
13
13
  namespace faiss {
14
14
 
15
15
  struct CodePacker;
16
+ struct NormTableScaler;
16
17
 
17
18
  /** Fast scan version of IndexPQ and IndexAQ. Works for 4-bit PQ and AQ for now.
18
19
  *
@@ -87,25 +88,25 @@ struct IndexFastScan : Index {
87
88
  uint8_t* lut,
88
89
  float* normalizers) const;
89
90
 
90
- template <bool is_max, class Scaler>
91
+ template <bool is_max>
91
92
  void search_dispatch_implem(
92
93
  idx_t n,
93
94
  const float* x,
94
95
  idx_t k,
95
96
  float* distances,
96
97
  idx_t* labels,
97
- const Scaler& scaler) const;
98
+ const NormTableScaler* scaler) const;
98
99
 
99
- template <class Cfloat, class Scaler>
100
+ template <class Cfloat>
100
101
  void search_implem_234(
101
102
  idx_t n,
102
103
  const float* x,
103
104
  idx_t k,
104
105
  float* distances,
105
106
  idx_t* labels,
106
- const Scaler& scaler) const;
107
+ const NormTableScaler* scaler) const;
107
108
 
108
- template <class C, class Scaler>
109
+ template <class C>
109
110
  void search_implem_12(
110
111
  idx_t n,
111
112
  const float* x,
@@ -113,9 +114,9 @@ struct IndexFastScan : Index {
113
114
  float* distances,
114
115
  idx_t* labels,
115
116
  int impl,
116
- const Scaler& scaler) const;
117
+ const NormTableScaler* scaler) const;
117
118
 
118
- template <class C, class Scaler>
119
+ template <class C>
119
120
  void search_implem_14(
120
121
  idx_t n,
121
122
  const float* x,
@@ -123,7 +124,7 @@ struct IndexFastScan : Index {
123
124
  float* distances,
124
125
  idx_t* labels,
125
126
  int impl,
126
- const Scaler& scaler) const;
127
+ const NormTableScaler* scaler) const;
127
128
 
128
129
  void reconstruct(idx_t key, float* recons) const override;
129
130
  size_t remove_ids(const IDSelector& sel) override;
@@ -14,6 +14,7 @@
14
14
  #include <faiss/utils/Heap.h>
15
15
  #include <faiss/utils/distances.h>
16
16
  #include <faiss/utils/extra_distances.h>
17
+ #include <faiss/utils/prefetch.h>
17
18
  #include <faiss/utils/sorting.h>
18
19
  #include <faiss/utils/utils.h>
19
20
  #include <cstring>
@@ -122,6 +123,39 @@ struct FlatL2Dis : FlatCodesDistanceComputer {
122
123
  void set_query(const float* x) override {
123
124
  q = x;
124
125
  }
126
+
127
+ // compute four distances
128
+ void distances_batch_4(
129
+ const idx_t idx0,
130
+ const idx_t idx1,
131
+ const idx_t idx2,
132
+ const idx_t idx3,
133
+ float& dis0,
134
+ float& dis1,
135
+ float& dis2,
136
+ float& dis3) final override {
137
+ ndis += 4;
138
+
139
+ // compute first, assign next
140
+ const float* __restrict y0 =
141
+ reinterpret_cast<const float*>(codes + idx0 * code_size);
142
+ const float* __restrict y1 =
143
+ reinterpret_cast<const float*>(codes + idx1 * code_size);
144
+ const float* __restrict y2 =
145
+ reinterpret_cast<const float*>(codes + idx2 * code_size);
146
+ const float* __restrict y3 =
147
+ reinterpret_cast<const float*>(codes + idx3 * code_size);
148
+
149
+ float dp0 = 0;
150
+ float dp1 = 0;
151
+ float dp2 = 0;
152
+ float dp3 = 0;
153
+ fvec_L2sqr_batch_4(q, y0, y1, y2, y3, d, dp0, dp1, dp2, dp3);
154
+ dis0 = dp0;
155
+ dis1 = dp1;
156
+ dis2 = dp2;
157
+ dis3 = dp3;
158
+ }
125
159
  };
126
160
 
127
161
  struct FlatIPDis : FlatCodesDistanceComputer {
@@ -131,13 +165,13 @@ struct FlatIPDis : FlatCodesDistanceComputer {
131
165
  const float* b;
132
166
  size_t ndis;
133
167
 
134
- float symmetric_dis(idx_t i, idx_t j) override {
168
+ float symmetric_dis(idx_t i, idx_t j) final override {
135
169
  return fvec_inner_product(b + j * d, b + i * d, d);
136
170
  }
137
171
 
138
- float distance_to_code(const uint8_t* code) final {
172
+ float distance_to_code(const uint8_t* code) final override {
139
173
  ndis++;
140
- return fvec_inner_product(q, (float*)code, d);
174
+ return fvec_inner_product(q, (const float*)code, d);
141
175
  }
142
176
 
143
177
  explicit FlatIPDis(const IndexFlat& storage, const float* q = nullptr)
@@ -153,6 +187,39 @@ struct FlatIPDis : FlatCodesDistanceComputer {
153
187
  void set_query(const float* x) override {
154
188
  q = x;
155
189
  }
190
+
191
+ // compute four distances
192
+ void distances_batch_4(
193
+ const idx_t idx0,
194
+ const idx_t idx1,
195
+ const idx_t idx2,
196
+ const idx_t idx3,
197
+ float& dis0,
198
+ float& dis1,
199
+ float& dis2,
200
+ float& dis3) final override {
201
+ ndis += 4;
202
+
203
+ // compute first, assign next
204
+ const float* __restrict y0 =
205
+ reinterpret_cast<const float*>(codes + idx0 * code_size);
206
+ const float* __restrict y1 =
207
+ reinterpret_cast<const float*>(codes + idx1 * code_size);
208
+ const float* __restrict y2 =
209
+ reinterpret_cast<const float*>(codes + idx2 * code_size);
210
+ const float* __restrict y3 =
211
+ reinterpret_cast<const float*>(codes + idx3 * code_size);
212
+
213
+ float dp0 = 0;
214
+ float dp1 = 0;
215
+ float dp2 = 0;
216
+ float dp3 = 0;
217
+ fvec_inner_product_batch_4(q, y0, y1, y2, y3, d, dp0, dp1, dp2, dp3);
218
+ dis0 = dp0;
219
+ dis1 = dp1;
220
+ dis2 = dp2;
221
+ dis3 = dp3;
222
+ }
156
223
  };
157
224
 
158
225
  } // namespace
@@ -184,6 +251,131 @@ void IndexFlat::sa_decode(idx_t n, const uint8_t* bytes, float* x) const {
184
251
  }
185
252
  }
186
253
 
254
+ /***************************************************
255
+ * IndexFlatL2
256
+ ***************************************************/
257
+
258
+ namespace {
259
+ struct FlatL2WithNormsDis : FlatCodesDistanceComputer {
260
+ size_t d;
261
+ idx_t nb;
262
+ const float* q;
263
+ const float* b;
264
+ size_t ndis;
265
+
266
+ const float* l2norms;
267
+ float query_l2norm;
268
+
269
+ float distance_to_code(const uint8_t* code) final override {
270
+ ndis++;
271
+ return fvec_L2sqr(q, (float*)code, d);
272
+ }
273
+
274
+ float operator()(const idx_t i) final override {
275
+ const float* __restrict y =
276
+ reinterpret_cast<const float*>(codes + i * code_size);
277
+
278
+ prefetch_L2(l2norms + i);
279
+ const float dp0 = fvec_inner_product(q, y, d);
280
+ return query_l2norm + l2norms[i] - 2 * dp0;
281
+ }
282
+
283
+ float symmetric_dis(idx_t i, idx_t j) final override {
284
+ const float* __restrict yi =
285
+ reinterpret_cast<const float*>(codes + i * code_size);
286
+ const float* __restrict yj =
287
+ reinterpret_cast<const float*>(codes + j * code_size);
288
+
289
+ prefetch_L2(l2norms + i);
290
+ prefetch_L2(l2norms + j);
291
+ const float dp0 = fvec_inner_product(yi, yj, d);
292
+ return l2norms[i] + l2norms[j] - 2 * dp0;
293
+ }
294
+
295
+ explicit FlatL2WithNormsDis(
296
+ const IndexFlatL2& storage,
297
+ const float* q = nullptr)
298
+ : FlatCodesDistanceComputer(
299
+ storage.codes.data(),
300
+ storage.code_size),
301
+ d(storage.d),
302
+ nb(storage.ntotal),
303
+ q(q),
304
+ b(storage.get_xb()),
305
+ ndis(0),
306
+ l2norms(storage.cached_l2norms.data()),
307
+ query_l2norm(0) {}
308
+
309
+ void set_query(const float* x) override {
310
+ q = x;
311
+ query_l2norm = fvec_norm_L2sqr(q, d);
312
+ }
313
+
314
+ // compute four distances
315
+ void distances_batch_4(
316
+ const idx_t idx0,
317
+ const idx_t idx1,
318
+ const idx_t idx2,
319
+ const idx_t idx3,
320
+ float& dis0,
321
+ float& dis1,
322
+ float& dis2,
323
+ float& dis3) final override {
324
+ ndis += 4;
325
+
326
+ // compute first, assign next
327
+ const float* __restrict y0 =
328
+ reinterpret_cast<const float*>(codes + idx0 * code_size);
329
+ const float* __restrict y1 =
330
+ reinterpret_cast<const float*>(codes + idx1 * code_size);
331
+ const float* __restrict y2 =
332
+ reinterpret_cast<const float*>(codes + idx2 * code_size);
333
+ const float* __restrict y3 =
334
+ reinterpret_cast<const float*>(codes + idx3 * code_size);
335
+
336
+ prefetch_L2(l2norms + idx0);
337
+ prefetch_L2(l2norms + idx1);
338
+ prefetch_L2(l2norms + idx2);
339
+ prefetch_L2(l2norms + idx3);
340
+
341
+ float dp0 = 0;
342
+ float dp1 = 0;
343
+ float dp2 = 0;
344
+ float dp3 = 0;
345
+ fvec_inner_product_batch_4(q, y0, y1, y2, y3, d, dp0, dp1, dp2, dp3);
346
+ dis0 = query_l2norm + l2norms[idx0] - 2 * dp0;
347
+ dis1 = query_l2norm + l2norms[idx1] - 2 * dp1;
348
+ dis2 = query_l2norm + l2norms[idx2] - 2 * dp2;
349
+ dis3 = query_l2norm + l2norms[idx3] - 2 * dp3;
350
+ }
351
+ };
352
+
353
+ } // namespace
354
+
355
+ void IndexFlatL2::sync_l2norms() {
356
+ cached_l2norms.resize(ntotal);
357
+ fvec_norms_L2sqr(
358
+ cached_l2norms.data(),
359
+ reinterpret_cast<const float*>(codes.data()),
360
+ d,
361
+ ntotal);
362
+ }
363
+
364
+ void IndexFlatL2::clear_l2norms() {
365
+ cached_l2norms.clear();
366
+ cached_l2norms.shrink_to_fit();
367
+ }
368
+
369
+ FlatCodesDistanceComputer* IndexFlatL2::get_FlatCodesDistanceComputer() const {
370
+ if (metric_type == METRIC_L2) {
371
+ if (!cached_l2norms.empty()) {
372
+ return new FlatL2WithNormsDis(*this);
373
+ }
374
+ }
375
+
376
+ return IndexFlat::get_FlatCodesDistanceComputer();
377
+ }
378
+
187
379
  /***************************************************
188
380
  * IndexFlat1D
189
381
  ***************************************************/