faiss 0.2.6 → 0.2.7

Sign up to get free protection for your applications and to get access to all the features.
Files changed (189) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +4 -0
  3. data/ext/faiss/extconf.rb +1 -1
  4. data/lib/faiss/version.rb +1 -1
  5. data/lib/faiss.rb +2 -2
  6. data/vendor/faiss/faiss/AutoTune.cpp +15 -4
  7. data/vendor/faiss/faiss/AutoTune.h +0 -1
  8. data/vendor/faiss/faiss/Clustering.cpp +1 -5
  9. data/vendor/faiss/faiss/Clustering.h +0 -2
  10. data/vendor/faiss/faiss/IVFlib.h +0 -2
  11. data/vendor/faiss/faiss/Index.h +1 -2
  12. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +17 -3
  13. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +10 -1
  14. data/vendor/faiss/faiss/IndexBinary.h +0 -1
  15. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +2 -1
  16. data/vendor/faiss/faiss/IndexBinaryFlat.h +4 -0
  17. data/vendor/faiss/faiss/IndexBinaryHash.cpp +1 -3
  18. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +273 -48
  19. data/vendor/faiss/faiss/IndexBinaryIVF.h +18 -11
  20. data/vendor/faiss/faiss/IndexFastScan.cpp +13 -10
  21. data/vendor/faiss/faiss/IndexFastScan.h +5 -1
  22. data/vendor/faiss/faiss/IndexFlat.cpp +16 -3
  23. data/vendor/faiss/faiss/IndexFlat.h +1 -1
  24. data/vendor/faiss/faiss/IndexFlatCodes.cpp +5 -0
  25. data/vendor/faiss/faiss/IndexFlatCodes.h +7 -2
  26. data/vendor/faiss/faiss/IndexHNSW.cpp +3 -6
  27. data/vendor/faiss/faiss/IndexHNSW.h +0 -1
  28. data/vendor/faiss/faiss/IndexIDMap.cpp +4 -4
  29. data/vendor/faiss/faiss/IndexIDMap.h +0 -2
  30. data/vendor/faiss/faiss/IndexIVF.cpp +155 -129
  31. data/vendor/faiss/faiss/IndexIVF.h +121 -61
  32. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +2 -2
  33. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +12 -11
  34. data/vendor/faiss/faiss/IndexIVFFastScan.h +6 -1
  35. data/vendor/faiss/faiss/IndexIVFPQ.cpp +221 -165
  36. data/vendor/faiss/faiss/IndexIVFPQ.h +1 -0
  37. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +6 -1
  38. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +0 -2
  39. data/vendor/faiss/faiss/IndexNNDescent.cpp +1 -2
  40. data/vendor/faiss/faiss/IndexNNDescent.h +0 -1
  41. data/vendor/faiss/faiss/IndexNSG.cpp +1 -2
  42. data/vendor/faiss/faiss/IndexPQ.cpp +7 -9
  43. data/vendor/faiss/faiss/IndexRefine.cpp +1 -1
  44. data/vendor/faiss/faiss/IndexReplicas.cpp +3 -4
  45. data/vendor/faiss/faiss/IndexReplicas.h +0 -1
  46. data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +8 -1
  47. data/vendor/faiss/faiss/IndexRowwiseMinMax.h +7 -0
  48. data/vendor/faiss/faiss/IndexShards.cpp +26 -109
  49. data/vendor/faiss/faiss/IndexShards.h +2 -3
  50. data/vendor/faiss/faiss/IndexShardsIVF.cpp +246 -0
  51. data/vendor/faiss/faiss/IndexShardsIVF.h +42 -0
  52. data/vendor/faiss/faiss/MetaIndexes.cpp +86 -0
  53. data/vendor/faiss/faiss/MetaIndexes.h +29 -0
  54. data/vendor/faiss/faiss/MetricType.h +14 -0
  55. data/vendor/faiss/faiss/VectorTransform.cpp +8 -10
  56. data/vendor/faiss/faiss/VectorTransform.h +1 -3
  57. data/vendor/faiss/faiss/clone_index.cpp +232 -18
  58. data/vendor/faiss/faiss/cppcontrib/SaDecodeKernels.h +25 -3
  59. data/vendor/faiss/faiss/cppcontrib/detail/CoarseBitType.h +7 -0
  60. data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +78 -0
  61. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +20 -6
  62. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +7 -1
  63. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-neon-inl.h +21 -7
  64. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMax-inl.h +7 -0
  65. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMaxFP16-inl.h +7 -0
  66. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +10 -3
  67. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +7 -1
  68. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-neon-inl.h +11 -3
  69. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +25 -2
  70. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +76 -29
  71. data/vendor/faiss/faiss/gpu/GpuCloner.h +2 -2
  72. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +14 -13
  73. data/vendor/faiss/faiss/gpu/GpuDistance.h +18 -6
  74. data/vendor/faiss/faiss/gpu/GpuIndex.h +23 -21
  75. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +10 -10
  76. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +11 -12
  77. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +29 -50
  78. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +3 -3
  79. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +8 -8
  80. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +4 -4
  81. data/vendor/faiss/faiss/gpu/impl/IndexUtils.h +2 -5
  82. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +9 -7
  83. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +4 -4
  84. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +2 -2
  85. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +1 -1
  86. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +55 -6
  87. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +20 -6
  88. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +95 -25
  89. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +67 -16
  90. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +4 -4
  91. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +7 -7
  92. data/vendor/faiss/faiss/gpu/test/TestUtils.h +4 -4
  93. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +1 -1
  94. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +6 -0
  95. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +0 -7
  96. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +9 -9
  97. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +1 -1
  98. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +2 -7
  99. data/vendor/faiss/faiss/impl/CodePacker.cpp +67 -0
  100. data/vendor/faiss/faiss/impl/CodePacker.h +71 -0
  101. data/vendor/faiss/faiss/impl/DistanceComputer.h +0 -2
  102. data/vendor/faiss/faiss/impl/HNSW.cpp +3 -7
  103. data/vendor/faiss/faiss/impl/HNSW.h +6 -9
  104. data/vendor/faiss/faiss/impl/IDSelector.cpp +1 -1
  105. data/vendor/faiss/faiss/impl/IDSelector.h +39 -1
  106. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +62 -51
  107. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +11 -12
  108. data/vendor/faiss/faiss/impl/NNDescent.cpp +3 -9
  109. data/vendor/faiss/faiss/impl/NNDescent.h +10 -10
  110. data/vendor/faiss/faiss/impl/NSG.cpp +1 -6
  111. data/vendor/faiss/faiss/impl/NSG.h +4 -7
  112. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +1 -15
  113. data/vendor/faiss/faiss/impl/PolysemousTraining.h +11 -10
  114. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +0 -7
  115. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +25 -12
  116. data/vendor/faiss/faiss/impl/ProductQuantizer.h +2 -4
  117. data/vendor/faiss/faiss/impl/Quantizer.h +6 -3
  118. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +796 -174
  119. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +16 -8
  120. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +3 -5
  121. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +4 -4
  122. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +3 -3
  123. data/vendor/faiss/faiss/impl/ThreadedIndex.h +4 -4
  124. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +291 -0
  125. data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +74 -0
  126. data/vendor/faiss/faiss/impl/code_distance/code_distance.h +123 -0
  127. data/vendor/faiss/faiss/impl/code_distance/code_distance_avx512.h +102 -0
  128. data/vendor/faiss/faiss/impl/index_read.cpp +13 -10
  129. data/vendor/faiss/faiss/impl/index_write.cpp +3 -4
  130. data/vendor/faiss/faiss/impl/kmeans1d.cpp +0 -1
  131. data/vendor/faiss/faiss/impl/kmeans1d.h +3 -3
  132. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -1
  133. data/vendor/faiss/faiss/impl/platform_macros.h +61 -0
  134. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +48 -4
  135. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +18 -4
  136. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +2 -2
  137. data/vendor/faiss/faiss/index_factory.cpp +8 -10
  138. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +29 -12
  139. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +8 -2
  140. data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -1
  141. data/vendor/faiss/faiss/invlists/DirectMap.h +2 -4
  142. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +118 -18
  143. data/vendor/faiss/faiss/invlists/InvertedLists.h +44 -4
  144. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +3 -3
  145. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +1 -1
  146. data/vendor/faiss/faiss/python/python_callbacks.cpp +1 -1
  147. data/vendor/faiss/faiss/python/python_callbacks.h +1 -1
  148. data/vendor/faiss/faiss/utils/AlignedTable.h +3 -1
  149. data/vendor/faiss/faiss/utils/Heap.cpp +139 -3
  150. data/vendor/faiss/faiss/utils/Heap.h +35 -1
  151. data/vendor/faiss/faiss/utils/approx_topk/approx_topk.h +84 -0
  152. data/vendor/faiss/faiss/utils/approx_topk/avx2-inl.h +196 -0
  153. data/vendor/faiss/faiss/utils/approx_topk/generic.h +138 -0
  154. data/vendor/faiss/faiss/utils/approx_topk/mode.h +34 -0
  155. data/vendor/faiss/faiss/utils/approx_topk_hamming/approx_topk_hamming.h +367 -0
  156. data/vendor/faiss/faiss/utils/distances.cpp +61 -7
  157. data/vendor/faiss/faiss/utils/distances.h +11 -0
  158. data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +346 -0
  159. data/vendor/faiss/faiss/utils/distances_fused/avx512.h +36 -0
  160. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +42 -0
  161. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +40 -0
  162. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +352 -0
  163. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.h +32 -0
  164. data/vendor/faiss/faiss/utils/distances_simd.cpp +515 -327
  165. data/vendor/faiss/faiss/utils/extra_distances-inl.h +17 -1
  166. data/vendor/faiss/faiss/utils/extra_distances.cpp +37 -8
  167. data/vendor/faiss/faiss/utils/extra_distances.h +2 -1
  168. data/vendor/faiss/faiss/utils/fp16-fp16c.h +7 -0
  169. data/vendor/faiss/faiss/utils/fp16-inl.h +7 -0
  170. data/vendor/faiss/faiss/utils/fp16.h +7 -0
  171. data/vendor/faiss/faiss/utils/hamming-inl.h +0 -456
  172. data/vendor/faiss/faiss/utils/hamming.cpp +104 -120
  173. data/vendor/faiss/faiss/utils/hamming.h +21 -10
  174. data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +535 -0
  175. data/vendor/faiss/faiss/utils/hamming_distance/common.h +48 -0
  176. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +519 -0
  177. data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +26 -0
  178. data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +614 -0
  179. data/vendor/faiss/faiss/utils/partitioning.cpp +21 -25
  180. data/vendor/faiss/faiss/utils/simdlib_avx2.h +344 -3
  181. data/vendor/faiss/faiss/utils/simdlib_emulated.h +390 -0
  182. data/vendor/faiss/faiss/utils/simdlib_neon.h +655 -130
  183. data/vendor/faiss/faiss/utils/sorting.cpp +692 -0
  184. data/vendor/faiss/faiss/utils/sorting.h +71 -0
  185. data/vendor/faiss/faiss/utils/transpose/transpose-avx2-inl.h +165 -0
  186. data/vendor/faiss/faiss/utils/utils.cpp +4 -176
  187. data/vendor/faiss/faiss/utils/utils.h +2 -9
  188. metadata +29 -3
  189. data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +0 -26
@@ -21,6 +21,7 @@
21
21
  #include <faiss/impl/AuxIndexStructures.h>
22
22
  #include <faiss/impl/FaissAssert.h>
23
23
  #include <faiss/utils/hamming.h>
24
+ #include <faiss/utils/sorting.h>
24
25
  #include <faiss/utils/utils.h>
25
26
 
26
27
  namespace faiss {
@@ -28,28 +29,14 @@ namespace faiss {
28
29
  IndexBinaryIVF::IndexBinaryIVF(IndexBinary* quantizer, size_t d, size_t nlist)
29
30
  : IndexBinary(d),
30
31
  invlists(new ArrayInvertedLists(nlist, code_size)),
31
- own_invlists(true),
32
- nprobe(1),
33
- max_codes(0),
34
32
  quantizer(quantizer),
35
- nlist(nlist),
36
- own_fields(false),
37
- clustering_index(nullptr) {
33
+ nlist(nlist) {
38
34
  FAISS_THROW_IF_NOT(d == quantizer->d);
39
35
  is_trained = quantizer->is_trained && (quantizer->ntotal == nlist);
40
-
41
36
  cp.niter = 10;
42
37
  }
43
38
 
44
- IndexBinaryIVF::IndexBinaryIVF()
45
- : invlists(nullptr),
46
- own_invlists(false),
47
- nprobe(1),
48
- max_codes(0),
49
- quantizer(nullptr),
50
- nlist(0),
51
- own_fields(false),
52
- clustering_index(nullptr) {}
39
+ IndexBinaryIVF::IndexBinaryIVF() {}
53
40
 
54
41
  void IndexBinaryIVF::add(idx_t n, const uint8_t* x) {
55
42
  add_with_ids(n, x, nullptr);
@@ -158,7 +145,7 @@ void IndexBinaryIVF::reconstruct_n(idx_t i0, idx_t ni, uint8_t* recons) const {
158
145
 
159
146
  for (idx_t list_no = 0; list_no < nlist; list_no++) {
160
147
  size_t list_size = invlists->list_size(list_no);
161
- const Index::idx_t* idlist = invlists->get_ids(list_no);
148
+ const idx_t* idlist = invlists->get_ids(list_no);
162
149
 
163
150
  for (idx_t offset = 0; offset < list_size; offset++) {
164
151
  idx_t id = idlist[offset];
@@ -174,11 +161,11 @@ void IndexBinaryIVF::reconstruct_n(idx_t i0, idx_t ni, uint8_t* recons) const {
174
161
 
175
162
  void IndexBinaryIVF::search_and_reconstruct(
176
163
  idx_t n,
177
- const uint8_t* x,
164
+ const uint8_t* __restrict x,
178
165
  idx_t k,
179
- int32_t* distances,
180
- idx_t* labels,
181
- uint8_t* recons,
166
+ int32_t* __restrict distances,
167
+ idx_t* __restrict labels,
168
+ uint8_t* __restrict recons,
182
169
  const SearchParameters* params) const {
183
170
  FAISS_THROW_IF_NOT_MSG(
184
171
  !params, "search params not supported for this index");
@@ -320,8 +307,6 @@ void IndexBinaryIVF::replace_invlists(InvertedLists* il, bool own) {
320
307
 
321
308
  namespace {
322
309
 
323
- using idx_t = Index::idx_t;
324
-
325
310
  template <class HammingComputer>
326
311
  struct IVFBinaryScannerL2 : BinaryInvertedListScanner {
327
312
  HammingComputer hc;
@@ -346,10 +331,10 @@ struct IVFBinaryScannerL2 : BinaryInvertedListScanner {
346
331
 
347
332
  size_t scan_codes(
348
333
  size_t n,
349
- const uint8_t* codes,
350
- const idx_t* ids,
351
- int32_t* simi,
352
- idx_t* idxi,
334
+ const uint8_t* __restrict codes,
335
+ const idx_t* __restrict ids,
336
+ int32_t* __restrict simi,
337
+ idx_t* __restrict idxi,
353
338
  size_t k) const override {
354
339
  using C = CMax<int32_t, idx_t>;
355
340
 
@@ -368,8 +353,8 @@ struct IVFBinaryScannerL2 : BinaryInvertedListScanner {
368
353
 
369
354
  void scan_codes_range(
370
355
  size_t n,
371
- const uint8_t* codes,
372
- const idx_t* ids,
356
+ const uint8_t* __restrict codes,
357
+ const idx_t* __restrict ids,
373
358
  int radius,
374
359
  RangeQueryResult& result) const override {
375
360
  size_t nup = 0;
@@ -387,12 +372,12 @@ struct IVFBinaryScannerL2 : BinaryInvertedListScanner {
387
372
  void search_knn_hamming_heap(
388
373
  const IndexBinaryIVF& ivf,
389
374
  size_t n,
390
- const uint8_t* x,
375
+ const uint8_t* __restrict x,
391
376
  idx_t k,
392
- const idx_t* keys,
393
- const int32_t* coarse_dis,
394
- int32_t* distances,
395
- idx_t* labels,
377
+ const idx_t* __restrict keys,
378
+ const int32_t* __restrict coarse_dis,
379
+ int32_t* __restrict distances,
380
+ idx_t* __restrict labels,
396
381
  bool store_pairs,
397
382
  const IVFSearchParameters* params) {
398
383
  idx_t nprobe = params ? params->nprobe : ivf.nprobe;
@@ -448,7 +433,7 @@ void search_knn_hamming_heap(
448
433
  size_t list_size = ivf.invlists->list_size(key);
449
434
  InvertedLists::ScopedCodes scodes(ivf.invlists, key);
450
435
  std::unique_ptr<InvertedLists::ScopedIds> sids;
451
- const Index::idx_t* ids = nullptr;
436
+ const idx_t* ids = nullptr;
452
437
 
453
438
  if (!store_pairs) {
454
439
  sids.reset(new InvertedLists::ScopedIds(ivf.invlists, key));
@@ -483,11 +468,11 @@ template <class HammingComputer, bool store_pairs>
483
468
  void search_knn_hamming_count(
484
469
  const IndexBinaryIVF& ivf,
485
470
  size_t nx,
486
- const uint8_t* x,
487
- const idx_t* keys,
471
+ const uint8_t* __restrict x,
472
+ const idx_t* __restrict keys,
488
473
  int k,
489
- int32_t* distances,
490
- idx_t* labels,
474
+ int32_t* __restrict distances,
475
+ idx_t* __restrict labels,
491
476
  const IVFSearchParameters* params) {
492
477
  const int nBuckets = ivf.d + 1;
493
478
  std::vector<int> all_counters(nx * nBuckets, 0);
@@ -533,7 +518,7 @@ void search_knn_hamming_count(
533
518
  size_t list_size = ivf.invlists->list_size(key);
534
519
  InvertedLists::ScopedCodes scodes(ivf.invlists, key);
535
520
  const uint8_t* list_vecs = scodes.get();
536
- const Index::idx_t* ids =
521
+ const idx_t* ids =
537
522
  store_pairs ? nullptr : ivf.invlists->get_ids(key);
538
523
 
539
524
  for (size_t j = 0; j < list_size; j++) {
@@ -571,6 +556,185 @@ void search_knn_hamming_count(
571
556
  indexIVF_stats.ndis += ndis;
572
557
  }
573
558
 
559
+ /* Manages NQ queries at a time, stores results */
560
+ template <class HammingComputer, int NQ, int K>
561
+ struct BlockSearch {
562
+ HammingComputer hcs[NQ];
563
+ // heaps to update for each query
564
+ int32_t* distances[NQ];
565
+ idx_t* labels[NQ];
566
+ // current top of heap
567
+ int32_t heap_tops[NQ];
568
+
569
+ BlockSearch(
570
+ size_t code_size,
571
+ const uint8_t* __restrict x,
572
+ const int32_t* __restrict keys,
573
+ int32_t* __restrict all_distances,
574
+ idx_t* __restrict all_labels) {
575
+ for (idx_t q = 0; q < NQ; q++) {
576
+ idx_t qno = keys[q];
577
+ hcs[q] = HammingComputer(x + qno * code_size, code_size);
578
+ distances[q] = all_distances + qno * K;
579
+ labels[q] = all_labels + qno * K;
580
+ heap_tops[q] = distances[q][0];
581
+ }
582
+ }
583
+
584
+ void add_bcode(const uint8_t* bcode, idx_t id) {
585
+ using C = CMax<int32_t, idx_t>;
586
+ for (int q = 0; q < NQ; q++) {
587
+ int dis = hcs[q].hamming(bcode);
588
+ if (dis < heap_tops[q]) {
589
+ heap_replace_top<C>(K, distances[q], labels[q], dis, id);
590
+ heap_tops[q] = distances[q][0];
591
+ }
592
+ }
593
+ }
594
+ };
595
+
596
+ template <class HammingComputer, int NQ>
597
+ struct BlockSearchVariableK {
598
+ int k;
599
+ HammingComputer hcs[NQ];
600
+ // heaps to update for each query
601
+ int32_t* distances[NQ];
602
+ idx_t* labels[NQ];
603
+ // current top of heap
604
+ int32_t heap_tops[NQ];
605
+
606
+ BlockSearchVariableK(
607
+ size_t code_size,
608
+ int k,
609
+ const uint8_t* __restrict x,
610
+ const int32_t* __restrict keys,
611
+ int32_t* __restrict all_distances,
612
+ idx_t* __restrict all_labels)
613
+ : k(k) {
614
+ for (idx_t q = 0; q < NQ; q++) {
615
+ idx_t qno = keys[q];
616
+ hcs[q] = HammingComputer(x + qno * code_size, code_size);
617
+ distances[q] = all_distances + qno * k;
618
+ labels[q] = all_labels + qno * k;
619
+ heap_tops[q] = distances[q][0];
620
+ }
621
+ }
622
+
623
+ void add_bcode(const uint8_t* bcode, idx_t id) {
624
+ using C = CMax<int32_t, idx_t>;
625
+ for (int q = 0; q < NQ; q++) {
626
+ int dis = hcs[q].hamming(bcode);
627
+ if (dis < heap_tops[q]) {
628
+ heap_replace_top<C>(k, distances[q], labels[q], dis, id);
629
+ heap_tops[q] = distances[q][0];
630
+ }
631
+ }
632
+ }
633
+ };
634
+
635
+ template <class HammingComputer>
636
+ void search_knn_hamming_per_invlist(
637
+ const IndexBinaryIVF& ivf,
638
+ size_t n,
639
+ const uint8_t* __restrict x,
640
+ idx_t k,
641
+ const idx_t* __restrict keys_in,
642
+ const int32_t* __restrict coarse_dis,
643
+ int32_t* __restrict distances,
644
+ idx_t* __restrict labels,
645
+ bool store_pairs,
646
+ const IVFSearchParameters* params) {
647
+ idx_t nprobe = params ? params->nprobe : ivf.nprobe;
648
+ nprobe = std::min((idx_t)ivf.nlist, nprobe);
649
+ idx_t max_codes = params ? params->max_codes : ivf.max_codes;
650
+ FAISS_THROW_IF_NOT(max_codes == 0);
651
+ FAISS_THROW_IF_NOT(!store_pairs);
652
+ MetricType metric_type = ivf.metric_type;
653
+
654
+ // reorder buckets
655
+ std::vector<int64_t> lims(n + 1);
656
+ int32_t* keys = new int32_t[n * nprobe];
657
+ std::unique_ptr<int32_t[]> delete_keys(keys);
658
+ for (idx_t i = 0; i < n * nprobe; i++) {
659
+ keys[i] = keys_in[i];
660
+ }
661
+ matrix_bucket_sort_inplace(n, nprobe, keys, ivf.nlist, lims.data(), 0);
662
+
663
+ using C = CMax<int32_t, idx_t>;
664
+ heap_heapify<C>(n * k, distances, labels);
665
+ const size_t code_size = ivf.code_size;
666
+
667
+ for (idx_t l = 0; l < ivf.nlist; l++) {
668
+ idx_t l0 = lims[l], nq = lims[l + 1] - l0;
669
+
670
+ InvertedLists::ScopedCodes scodes(ivf.invlists, l);
671
+ InvertedLists::ScopedIds sidx(ivf.invlists, l);
672
+ idx_t nb = ivf.invlists->list_size(l);
673
+ const uint8_t* bcodes = scodes.get();
674
+ const idx_t* ids = sidx.get();
675
+
676
+ idx_t i = 0;
677
+
678
+ // process as much as possible by blocks
679
+ constexpr int BS = 4;
680
+
681
+ if (k == 1) {
682
+ for (; i + BS <= nq; i += BS) {
683
+ BlockSearch<HammingComputer, BS, 1> bc(
684
+ code_size, x, keys + l0 + i, distances, labels);
685
+ for (idx_t j = 0; j < nb; j++) {
686
+ bc.add_bcode(bcodes + j * code_size, ids[j]);
687
+ }
688
+ }
689
+ } else if (k == 2) {
690
+ for (; i + BS <= nq; i += BS) {
691
+ BlockSearch<HammingComputer, BS, 2> bc(
692
+ code_size, x, keys + l0 + i, distances, labels);
693
+ for (idx_t j = 0; j < nb; j++) {
694
+ bc.add_bcode(bcodes + j * code_size, ids[j]);
695
+ }
696
+ }
697
+ } else if (k == 4) {
698
+ for (; i + BS <= nq; i += BS) {
699
+ BlockSearch<HammingComputer, BS, 4> bc(
700
+ code_size, x, keys + l0 + i, distances, labels);
701
+ for (idx_t j = 0; j < nb; j++) {
702
+ bc.add_bcode(bcodes + j * code_size, ids[j]);
703
+ }
704
+ }
705
+ } else {
706
+ for (; i + BS <= nq; i += BS) {
707
+ BlockSearchVariableK<HammingComputer, BS> bc(
708
+ code_size, k, x, keys + l0 + i, distances, labels);
709
+ for (idx_t j = 0; j < nb; j++) {
710
+ bc.add_bcode(bcodes + j * code_size, ids[j]);
711
+ }
712
+ }
713
+ }
714
+
715
+ // leftovers
716
+ for (; i < nq; i++) {
717
+ idx_t qno = keys[l0 + i];
718
+ HammingComputer hc(x + qno * code_size, code_size);
719
+ idx_t* __restrict idxi = labels + qno * k;
720
+ int32_t* __restrict simi = distances + qno * k;
721
+ int32_t simi0 = simi[0];
722
+ for (idx_t j = 0; j < nb; j++) {
723
+ int dis = hc.hamming(bcodes + j * code_size);
724
+
725
+ if (dis < simi0) {
726
+ idx_t id = store_pairs ? lo_build(l, j) : ids[j];
727
+ heap_replace_top<C>(k, simi, idxi, dis, id);
728
+ simi0 = simi[0];
729
+ }
730
+ }
731
+ }
732
+ }
733
+ for (idx_t i = 0; i < n; i++) {
734
+ heap_reorder<C>(k, distances + i * k, labels + i * k);
735
+ }
736
+ }
737
+
574
738
  template <bool store_pairs>
575
739
  void search_knn_hamming_count_1(
576
740
  const IndexBinaryIVF& ivf,
@@ -601,7 +765,56 @@ void search_knn_hamming_count_1(
601
765
  }
602
766
  }
603
767
 
604
- } // namespace
768
+ void search_knn_hamming_per_invlist_1(
769
+ const IndexBinaryIVF& ivf,
770
+ size_t n,
771
+ const uint8_t* x,
772
+ idx_t k,
773
+ const idx_t* keys,
774
+ const int32_t* coarse_dis,
775
+ int32_t* distances,
776
+ idx_t* labels,
777
+ bool store_pairs,
778
+ const IVFSearchParameters* params) {
779
+ switch (ivf.code_size) {
780
+ #define HANDLE_CS(cs) \
781
+ case cs: \
782
+ search_knn_hamming_per_invlist<HammingComputer##cs>( \
783
+ ivf, \
784
+ n, \
785
+ x, \
786
+ k, \
787
+ keys, \
788
+ coarse_dis, \
789
+ distances, \
790
+ labels, \
791
+ store_pairs, \
792
+ params); \
793
+ break;
794
+ HANDLE_CS(4);
795
+ HANDLE_CS(8);
796
+ HANDLE_CS(16);
797
+ HANDLE_CS(20);
798
+ HANDLE_CS(32);
799
+ HANDLE_CS(64);
800
+ #undef HANDLE_CS
801
+ default:
802
+ search_knn_hamming_per_invlist<HammingComputerDefault>(
803
+ ivf,
804
+ n,
805
+ x,
806
+ k,
807
+ keys,
808
+ coarse_dis,
809
+ distances,
810
+ labels,
811
+ store_pairs,
812
+ params);
813
+ break;
814
+ }
815
+ }
816
+
817
+ } // anonymous namespace
605
818
 
606
819
  BinaryInvertedListScanner* IndexBinaryIVF::get_InvertedListScanner(
607
820
  bool store_pairs) const {
@@ -635,7 +848,19 @@ void IndexBinaryIVF::search_preassigned(
635
848
  idx_t* labels,
636
849
  bool store_pairs,
637
850
  const IVFSearchParameters* params) const {
638
- if (use_heap) {
851
+ if (per_invlist_search) {
852
+ search_knn_hamming_per_invlist_1(
853
+ *this,
854
+ n,
855
+ x,
856
+ k,
857
+ idx,
858
+ coarse_dis,
859
+ distances,
860
+ labels,
861
+ store_pairs,
862
+ params);
863
+ } else if (use_heap) {
639
864
  search_knn_hamming_heap(
640
865
  *this,
641
866
  n,
@@ -660,9 +885,9 @@ void IndexBinaryIVF::search_preassigned(
660
885
 
661
886
  void IndexBinaryIVF::range_search(
662
887
  idx_t n,
663
- const uint8_t* x,
888
+ const uint8_t* __restrict x,
664
889
  int radius,
665
- RangeSearchResult* res,
890
+ RangeSearchResult* __restrict res,
666
891
  const SearchParameters* params) const {
667
892
  FAISS_THROW_IF_NOT_MSG(
668
893
  !params, "search params not supported for this index");
@@ -684,11 +909,11 @@ void IndexBinaryIVF::range_search(
684
909
 
685
910
  void IndexBinaryIVF::range_search_preassigned(
686
911
  idx_t n,
687
- const uint8_t* x,
912
+ const uint8_t* __restrict x,
688
913
  int radius,
689
- const idx_t* assign,
690
- const int32_t* centroid_dis,
691
- RangeSearchResult* res) const {
914
+ const idx_t* __restrict assign,
915
+ const int32_t* __restrict centroid_dis,
916
+ RangeSearchResult* __restrict res) const {
692
917
  const size_t nprobe = std::min(nlist, this->nprobe);
693
918
  bool store_pairs = false;
694
919
  size_t nlistv = 0, ndis = 0;
@@ -32,27 +32,36 @@ struct BinaryInvertedListScanner;
32
32
  */
33
33
  struct IndexBinaryIVF : IndexBinary {
34
34
  /// Access to the actual data
35
- InvertedLists* invlists;
36
- bool own_invlists;
35
+ InvertedLists* invlists = nullptr;
36
+ bool own_invlists = true;
37
37
 
38
- size_t nprobe; ///< number of probes at query time
39
- size_t max_codes; ///< max nb of codes to visit to do a query
38
+ size_t nprobe = 1; ///< number of probes at query time
39
+ size_t max_codes = 0; ///< max nb of codes to visit to do a query
40
40
 
41
41
  /** Select between using a heap or counting to select the k smallest values
42
42
  * when scanning inverted lists.
43
43
  */
44
44
  bool use_heap = true;
45
45
 
46
+ /** group queries by inverted list and scan each list once for its batch of queries */
47
+ bool per_invlist_search = false;
48
+
46
49
  /// map for direct access to the elements. Enables reconstruct().
47
50
  DirectMap direct_map;
48
51
 
49
- IndexBinary* quantizer; ///< quantizer that maps vectors to inverted lists
50
- size_t nlist; ///< number of possible key values
52
+ /// quantizer that maps vectors to inverted lists
53
+ IndexBinary* quantizer = nullptr;
54
+
55
+ /// number of possible key values
56
+ size_t nlist = 0;
51
57
 
52
- bool own_fields; ///< whether object owns the quantizer
58
+ /// whether object owns the quantizer
59
+ bool own_fields = false;
53
60
 
54
61
  ClusteringParameters cp; ///< to override default clustering params
55
- Index* clustering_index; ///< to override index used during clustering
62
+
63
+ /// to override index used during clustering
64
+ Index* clustering_index = nullptr;
56
65
 
57
66
  /** The Inverted file takes a quantizer (an IndexBinary) on input,
58
67
  * which implements the function mapping a vector to a list
@@ -196,7 +205,7 @@ struct IndexBinaryIVF : IndexBinary {
196
205
  return invlists->list_size(list_no);
197
206
  }
198
207
 
199
- /** intialize a direct map
208
+ /** initialize a direct map
200
209
  *
201
210
  * @param new_maintain_direct_map if true, create a direct map,
202
211
  * else clear it
@@ -209,8 +218,6 @@ struct IndexBinaryIVF : IndexBinary {
209
218
  };
210
219
 
211
220
  struct BinaryInvertedListScanner {
212
- using idx_t = Index::idx_t;
213
-
214
221
  /// from now on we handle this query.
215
222
  virtual void set_query(const uint8_t* query_vector) = 0;
216
223
 
@@ -98,18 +98,21 @@ void IndexFastScan::add(idx_t n, const float* x) {
98
98
  ntotal += n;
99
99
  }
100
100
 
101
+ CodePacker* IndexFastScan::get_CodePacker() const {
102
+ return new CodePackerPQ4(M, bbs);
103
+ }
104
+
101
105
  size_t IndexFastScan::remove_ids(const IDSelector& sel) {
102
106
  idx_t j = 0;
107
+ std::vector<uint8_t> buffer(code_size);
108
+ CodePackerPQ4 packer(M, bbs);
103
109
  for (idx_t i = 0; i < ntotal; i++) {
104
110
  if (sel.is_member(i)) {
105
111
  // should be removed
106
112
  } else {
107
113
  if (i > j) {
108
- for (int sq = 0; sq < M; sq++) {
109
- uint8_t code =
110
- pq4_get_packed_element(codes.data(), bbs, M, i, sq);
111
- pq4_set_packed_element(codes.data(), code, bbs, M, j, sq);
112
- }
114
+ packer.unpack_1(codes.data(), i, buffer.data());
115
+ packer.pack_1(buffer.data(), j, codes.data());
113
116
  }
114
117
  j++;
115
118
  }
@@ -142,12 +145,12 @@ void IndexFastScan::merge_from(Index& otherIndex, idx_t add_id) {
142
145
  IndexFastScan* other = static_cast<IndexFastScan*>(&otherIndex);
143
146
  ntotal2 = roundup(ntotal + other->ntotal, bbs);
144
147
  codes.resize(ntotal2 * M2 / 2);
148
+ std::vector<uint8_t> buffer(code_size);
149
+ CodePackerPQ4 packer(M, bbs);
150
+
145
151
  for (int i = 0; i < other->ntotal; i++) {
146
- for (int sq = 0; sq < M; sq++) {
147
- uint8_t code =
148
- pq4_get_packed_element(other->codes.data(), bbs, M, i, sq);
149
- pq4_set_packed_element(codes.data(), code, bbs, M, ntotal + i, sq);
150
- }
152
+ packer.unpack_1(other->codes.data(), i, buffer.data());
153
+ packer.pack_1(buffer.data(), ntotal + i, codes.data());
151
154
  }
152
155
  ntotal += other->ntotal;
153
156
  other->reset();
@@ -12,6 +12,8 @@
12
12
 
13
13
  namespace faiss {
14
14
 
15
+ struct CodePacker;
16
+
15
17
  /** Fast scan version of IndexPQ and IndexAQ. Works for 4-bit PQ and AQ for now.
16
18
  *
17
19
  * The codes are not stored sequentially but grouped in blocks of size bbs.
@@ -25,7 +27,6 @@ namespace faiss {
25
27
  * 14: no qbs with heap accumulator
26
28
  * 15: no qbs with reservoir accumulator
27
29
  */
28
-
29
30
  struct IndexFastScan : Index {
30
31
  // implementation to select
31
32
  int implem = 0;
@@ -126,6 +127,9 @@ struct IndexFastScan : Index {
126
127
 
127
128
  void reconstruct(idx_t key, float* recons) const override;
128
129
  size_t remove_ids(const IDSelector& sel) override;
130
+
131
+ CodePacker* get_CodePacker() const;
132
+
129
133
  void merge_from(Index& otherIndex, idx_t add_id = 0) override;
130
134
  void check_compatible_for_merge(const Index& otherIndex) const override;
131
135
  };
@@ -14,6 +14,7 @@
14
14
  #include <faiss/utils/Heap.h>
15
15
  #include <faiss/utils/distances.h>
16
16
  #include <faiss/utils/extra_distances.h>
17
+ #include <faiss/utils/sorting.h>
17
18
  #include <faiss/utils/utils.h>
18
19
  #include <cstring>
19
20
 
@@ -39,6 +40,10 @@ void IndexFlat::search(
39
40
  } else if (metric_type == METRIC_L2) {
40
41
  float_maxheap_array_t res = {size_t(n), size_t(k), labels, distances};
41
42
  knn_L2sqr(x, get_xb(), d, n, ntotal, &res, nullptr, sel);
43
+ } else if (is_similarity_metric(metric_type)) {
44
+ float_minheap_array_t res = {size_t(n), size_t(k), labels, distances};
45
+ knn_extra_metrics(
46
+ x, get_xb(), d, n, ntotal, metric_type, metric_arg, &res);
42
47
  } else {
43
48
  FAISS_THROW_IF_NOT(!sel);
44
49
  float_maxheap_array_t res = {size_t(n), size_t(k), labels, distances};
@@ -90,7 +95,7 @@ namespace {
90
95
 
91
96
  struct FlatL2Dis : FlatCodesDistanceComputer {
92
97
  size_t d;
93
- Index::idx_t nb;
98
+ idx_t nb;
94
99
  const float* q;
95
100
  const float* b;
96
101
  size_t ndis;
@@ -121,7 +126,7 @@ struct FlatL2Dis : FlatCodesDistanceComputer {
121
126
 
122
127
  struct FlatIPDis : FlatCodesDistanceComputer {
123
128
  size_t d;
124
- Index::idx_t nb;
129
+ idx_t nb;
125
130
  const float* q;
126
131
  const float* b;
127
132
  size_t ndis;
@@ -222,7 +227,7 @@ void IndexFlat1D::search(
222
227
  perm.size() == ntotal, "Call update_permutation before search");
223
228
  const float* xb = get_xb();
224
229
 
225
- #pragma omp parallel for
230
+ #pragma omp parallel for if (n > 10000)
226
231
  for (idx_t i = 0; i < n; i++) {
227
232
  float q = x[i]; // query
228
233
  float* D = distances + i * k;
@@ -232,6 +237,14 @@ void IndexFlat1D::search(
232
237
  idx_t i0 = 0, i1 = ntotal;
233
238
  idx_t wp = 0;
234
239
 
240
+ if (ntotal == 0) {
241
+ for (idx_t j = 0; j < k; j++) {
242
+ I[j] = -1;
243
+ D[j] = HUGE_VAL;
244
+ }
245
+ goto done;
246
+ }
247
+
235
248
  if (xb[perm[i0]] > q) {
236
249
  i1 = 0;
237
250
  goto finish_right;
@@ -82,7 +82,7 @@ struct IndexFlatL2 : IndexFlat {
82
82
 
83
83
  /// optimized version for 1D "vectors".
84
84
  struct IndexFlat1D : IndexFlatL2 {
85
- bool continuous_update; ///< is the permutation updated continuously?
85
+ bool continuous_update = true; ///< is the permutation updated continuously?
86
86
 
87
87
  std::vector<idx_t> perm; ///< sorted database indices
88
88
 
@@ -8,6 +8,7 @@
8
8
  #include <faiss/IndexFlatCodes.h>
9
9
 
10
10
  #include <faiss/impl/AuxIndexStructures.h>
11
+ #include <faiss/impl/CodePacker.h>
11
12
  #include <faiss/impl/DistanceComputer.h>
12
13
  #include <faiss/impl/FaissAssert.h>
13
14
  #include <faiss/impl/IDSelector.h>
@@ -98,4 +99,8 @@ void IndexFlatCodes::merge_from(Index& otherIndex, idx_t add_id) {
98
99
  other->reset();
99
100
  }
100
101
 
102
+ CodePacker* IndexFlatCodes::get_CodePacker() const {
103
+ return new CodePackerFlat(code_size);
104
+ }
105
+
101
106
  } // namespace faiss
@@ -15,6 +15,8 @@
15
15
 
16
16
  namespace faiss {
17
17
 
18
+ struct CodePacker;
19
+
18
20
  /** Index that encodes all vectors as fixed-size codes (size code_size). Storage
19
21
  * is in the codes vector */
20
22
  struct IndexFlatCodes : Index {
@@ -39,8 +41,8 @@ struct IndexFlatCodes : Index {
39
41
 
40
42
  size_t sa_code_size() const override;
41
43
 
42
- /** remove some ids. NB that Because of the structure of the
43
- * indexing structure, the semantics of this operation are
44
+ /** remove some ids. NB that because of the structure of the
45
+ * index, the semantics of this operation are
44
46
  * different from the usual ones: the new ids are shifted */
45
47
  size_t remove_ids(const IDSelector& sel) override;
46
48
 
@@ -51,6 +53,9 @@ struct IndexFlatCodes : Index {
51
53
  return get_FlatCodesDistanceComputer();
52
54
  }
53
55
 
56
+ // returns a new instance of a CodePacker
57
+ CodePacker* get_CodePacker() const;
58
+
54
59
  void check_compatible_for_merge(const Index& otherIndex) const override;
55
60
 
56
61
  virtual void merge_from(Index& otherIndex, idx_t add_id = 0) override;