faiss 0.3.0 → 0.3.1

Sign up to get free protection for your applications and to get access to all the features.
Files changed (171) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +5 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +1 -1
  5. data/ext/faiss/extconf.rb +9 -2
  6. data/ext/faiss/index.cpp +1 -1
  7. data/ext/faiss/index_binary.cpp +2 -2
  8. data/ext/faiss/product_quantizer.cpp +1 -1
  9. data/lib/faiss/version.rb +1 -1
  10. data/vendor/faiss/faiss/AutoTune.cpp +7 -7
  11. data/vendor/faiss/faiss/AutoTune.h +0 -1
  12. data/vendor/faiss/faiss/Clustering.cpp +4 -18
  13. data/vendor/faiss/faiss/Clustering.h +31 -21
  14. data/vendor/faiss/faiss/IVFlib.cpp +22 -11
  15. data/vendor/faiss/faiss/Index.cpp +1 -1
  16. data/vendor/faiss/faiss/Index.h +20 -5
  17. data/vendor/faiss/faiss/Index2Layer.cpp +7 -7
  18. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +176 -166
  19. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +15 -15
  20. data/vendor/faiss/faiss/IndexBinary.cpp +9 -4
  21. data/vendor/faiss/faiss/IndexBinary.h +8 -19
  22. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +2 -1
  23. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +24 -31
  24. data/vendor/faiss/faiss/IndexBinaryHash.cpp +25 -50
  25. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +106 -187
  26. data/vendor/faiss/faiss/IndexFastScan.cpp +90 -159
  27. data/vendor/faiss/faiss/IndexFastScan.h +9 -8
  28. data/vendor/faiss/faiss/IndexFlat.cpp +195 -3
  29. data/vendor/faiss/faiss/IndexFlat.h +20 -1
  30. data/vendor/faiss/faiss/IndexFlatCodes.cpp +11 -0
  31. data/vendor/faiss/faiss/IndexFlatCodes.h +3 -1
  32. data/vendor/faiss/faiss/IndexHNSW.cpp +112 -316
  33. data/vendor/faiss/faiss/IndexHNSW.h +12 -48
  34. data/vendor/faiss/faiss/IndexIDMap.cpp +69 -28
  35. data/vendor/faiss/faiss/IndexIDMap.h +24 -2
  36. data/vendor/faiss/faiss/IndexIVF.cpp +159 -53
  37. data/vendor/faiss/faiss/IndexIVF.h +37 -5
  38. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +18 -26
  39. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +3 -2
  40. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +19 -46
  41. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +4 -3
  42. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +433 -405
  43. data/vendor/faiss/faiss/IndexIVFFastScan.h +56 -26
  44. data/vendor/faiss/faiss/IndexIVFFlat.cpp +15 -5
  45. data/vendor/faiss/faiss/IndexIVFFlat.h +3 -2
  46. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +172 -0
  47. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.h +56 -0
  48. data/vendor/faiss/faiss/IndexIVFPQ.cpp +78 -122
  49. data/vendor/faiss/faiss/IndexIVFPQ.h +6 -7
  50. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +18 -50
  51. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +4 -3
  52. data/vendor/faiss/faiss/IndexIVFPQR.cpp +45 -29
  53. data/vendor/faiss/faiss/IndexIVFPQR.h +5 -2
  54. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +25 -27
  55. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +6 -6
  56. data/vendor/faiss/faiss/IndexLSH.cpp +14 -16
  57. data/vendor/faiss/faiss/IndexNNDescent.cpp +3 -4
  58. data/vendor/faiss/faiss/IndexNSG.cpp +11 -27
  59. data/vendor/faiss/faiss/IndexNSG.h +10 -10
  60. data/vendor/faiss/faiss/IndexPQ.cpp +72 -88
  61. data/vendor/faiss/faiss/IndexPQ.h +1 -4
  62. data/vendor/faiss/faiss/IndexPQFastScan.cpp +1 -1
  63. data/vendor/faiss/faiss/IndexPreTransform.cpp +25 -31
  64. data/vendor/faiss/faiss/IndexRefine.cpp +49 -19
  65. data/vendor/faiss/faiss/IndexRefine.h +7 -0
  66. data/vendor/faiss/faiss/IndexReplicas.cpp +23 -26
  67. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +22 -16
  68. data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -4
  69. data/vendor/faiss/faiss/IndexShards.cpp +21 -29
  70. data/vendor/faiss/faiss/IndexShardsIVF.cpp +1 -2
  71. data/vendor/faiss/faiss/MatrixStats.cpp +17 -32
  72. data/vendor/faiss/faiss/MatrixStats.h +21 -9
  73. data/vendor/faiss/faiss/MetaIndexes.cpp +35 -35
  74. data/vendor/faiss/faiss/VectorTransform.cpp +13 -26
  75. data/vendor/faiss/faiss/VectorTransform.h +7 -7
  76. data/vendor/faiss/faiss/clone_index.cpp +15 -10
  77. data/vendor/faiss/faiss/clone_index.h +3 -0
  78. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +87 -4
  79. data/vendor/faiss/faiss/gpu/GpuCloner.h +22 -0
  80. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +7 -0
  81. data/vendor/faiss/faiss/gpu/GpuDistance.h +46 -38
  82. data/vendor/faiss/faiss/gpu/GpuIndex.h +28 -4
  83. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +4 -4
  84. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +8 -9
  85. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +18 -3
  86. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -11
  87. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +1 -3
  88. data/vendor/faiss/faiss/gpu/GpuResources.cpp +24 -3
  89. data/vendor/faiss/faiss/gpu/GpuResources.h +39 -11
  90. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +117 -17
  91. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +57 -3
  92. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +1 -1
  93. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +25 -0
  94. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +129 -9
  95. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +267 -40
  96. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +299 -208
  97. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +1 -0
  98. data/vendor/faiss/faiss/gpu/utils/RaftUtils.h +75 -0
  99. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +3 -1
  100. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +5 -5
  101. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +1 -1
  102. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +1 -2
  103. data/vendor/faiss/faiss/impl/DistanceComputer.h +24 -1
  104. data/vendor/faiss/faiss/impl/FaissException.h +13 -34
  105. data/vendor/faiss/faiss/impl/HNSW.cpp +321 -70
  106. data/vendor/faiss/faiss/impl/HNSW.h +9 -8
  107. data/vendor/faiss/faiss/impl/IDSelector.h +4 -4
  108. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +3 -1
  109. data/vendor/faiss/faiss/impl/NNDescent.cpp +29 -19
  110. data/vendor/faiss/faiss/impl/NSG.h +1 -1
  111. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +14 -12
  112. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +1 -1
  113. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +24 -22
  114. data/vendor/faiss/faiss/impl/ProductQuantizer.h +1 -1
  115. data/vendor/faiss/faiss/impl/Quantizer.h +1 -1
  116. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +27 -1015
  117. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +5 -63
  118. data/vendor/faiss/faiss/impl/ResultHandler.h +232 -176
  119. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +444 -104
  120. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +0 -8
  121. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +280 -42
  122. data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +21 -14
  123. data/vendor/faiss/faiss/impl/code_distance/code_distance.h +22 -12
  124. data/vendor/faiss/faiss/impl/index_read.cpp +45 -19
  125. data/vendor/faiss/faiss/impl/index_write.cpp +60 -41
  126. data/vendor/faiss/faiss/impl/io.cpp +10 -10
  127. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -1
  128. data/vendor/faiss/faiss/impl/platform_macros.h +18 -1
  129. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +3 -0
  130. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +7 -6
  131. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +52 -38
  132. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +40 -49
  133. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +960 -0
  134. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +176 -0
  135. data/vendor/faiss/faiss/impl/simd_result_handlers.h +374 -202
  136. data/vendor/faiss/faiss/index_factory.cpp +10 -7
  137. data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -1
  138. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +27 -9
  139. data/vendor/faiss/faiss/invlists/InvertedLists.h +12 -3
  140. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +3 -3
  141. data/vendor/faiss/faiss/python/python_callbacks.cpp +1 -1
  142. data/vendor/faiss/faiss/utils/Heap.cpp +3 -1
  143. data/vendor/faiss/faiss/utils/WorkerThread.h +1 -0
  144. data/vendor/faiss/faiss/utils/distances.cpp +128 -74
  145. data/vendor/faiss/faiss/utils/distances.h +81 -4
  146. data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +5 -5
  147. data/vendor/faiss/faiss/utils/distances_fused/avx512.h +2 -2
  148. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +2 -2
  149. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +1 -1
  150. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +5 -5
  151. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.h +1 -1
  152. data/vendor/faiss/faiss/utils/distances_simd.cpp +428 -70
  153. data/vendor/faiss/faiss/utils/fp16-arm.h +29 -0
  154. data/vendor/faiss/faiss/utils/fp16.h +2 -0
  155. data/vendor/faiss/faiss/utils/hamming.cpp +162 -110
  156. data/vendor/faiss/faiss/utils/hamming.h +58 -0
  157. data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +16 -89
  158. data/vendor/faiss/faiss/utils/hamming_distance/common.h +1 -0
  159. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +15 -87
  160. data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +57 -0
  161. data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +14 -104
  162. data/vendor/faiss/faiss/utils/partitioning.cpp +3 -4
  163. data/vendor/faiss/faiss/utils/prefetch.h +77 -0
  164. data/vendor/faiss/faiss/utils/quantize_lut.cpp +0 -14
  165. data/vendor/faiss/faiss/utils/simdlib_avx2.h +0 -6
  166. data/vendor/faiss/faiss/utils/simdlib_neon.h +72 -77
  167. data/vendor/faiss/faiss/utils/sorting.cpp +140 -5
  168. data/vendor/faiss/faiss/utils/sorting.h +27 -0
  169. data/vendor/faiss/faiss/utils/utils.cpp +112 -6
  170. data/vendor/faiss/faiss/utils/utils.h +57 -20
  171. metadata +10 -3
@@ -5,8 +5,6 @@
5
5
  * LICENSE file in the root directory of this source tree.
6
6
  */
7
7
 
8
- // -*- c++ -*-
9
-
10
8
  #include <faiss/impl/HNSW.h>
11
9
 
12
10
  #include <string>
@@ -14,6 +12,17 @@
14
12
  #include <faiss/impl/AuxIndexStructures.h>
15
13
  #include <faiss/impl/DistanceComputer.h>
16
14
  #include <faiss/impl/IDSelector.h>
15
+ #include <faiss/impl/ResultHandler.h>
16
+ #include <faiss/utils/prefetch.h>
17
+
18
+ #include <faiss/impl/platform_macros.h>
19
+
20
+ #ifdef __AVX2__
21
+ #include <immintrin.h>
22
+
23
+ #include <limits>
24
+ #include <type_traits>
25
+ #endif
17
26
 
18
27
  namespace faiss {
19
28
 
@@ -212,7 +221,7 @@ int HNSW::prepare_level_tab(size_t n, bool preset_levels) {
212
221
  return max_level;
213
222
  }
214
223
 
215
- /** Enumerate vertices from farthest to nearest from query, keep a
224
+ /** Enumerate vertices from nearest to farthest from query, keep a
216
225
  * neighbor only if there is no previous neighbor that is closer to
217
226
  * that vertex than the query.
218
227
  */
@@ -503,17 +512,15 @@ void HNSW::add_with_locks(
503
512
  **************************************************************/
504
513
 
505
514
  namespace {
506
-
507
515
  using MinimaxHeap = HNSW::MinimaxHeap;
508
516
  using Node = HNSW::Node;
517
+ using C = HNSW::C;
509
518
  /** Do a BFS on the candidates list */
510
519
 
511
520
  int search_from_candidates(
512
521
  const HNSW& hnsw,
513
522
  DistanceComputer& qdis,
514
- int k,
515
- idx_t* I,
516
- float* D,
523
+ ResultHandler<C>& res,
517
524
  MinimaxHeap& candidates,
518
525
  VisitedTable& vt,
519
526
  HNSWStats& stats,
@@ -529,15 +536,16 @@ int search_from_candidates(
529
536
  int efSearch = params ? params->efSearch : hnsw.efSearch;
530
537
  const IDSelector* sel = params ? params->sel : nullptr;
531
538
 
539
+ C::T threshold = res.threshold;
532
540
  for (int i = 0; i < candidates.size(); i++) {
533
541
  idx_t v1 = candidates.ids[i];
534
542
  float d = candidates.dis[i];
535
543
  FAISS_ASSERT(v1 >= 0);
536
544
  if (!sel || sel->is_member(v1)) {
537
- if (nres < k) {
538
- faiss::maxheap_push(++nres, D, I, d, v1);
539
- } else if (d < D[0]) {
540
- faiss::maxheap_replace_top(nres, D, I, d, v1);
545
+ if (d < threshold) {
546
+ if (res.add_result(d, v1)) {
547
+ threshold = res.threshold;
548
+ }
541
549
  }
542
550
  }
543
551
  vt.set(v1);
@@ -563,24 +571,86 @@ int search_from_candidates(
563
571
  size_t begin, end;
564
572
  hnsw.neighbor_range(v0, level, &begin, &end);
565
573
 
574
+ // // baseline version
575
+ // for (size_t j = begin; j < end; j++) {
576
+ // int v1 = hnsw.neighbors[j];
577
+ // if (v1 < 0)
578
+ // break;
579
+ // if (vt.get(v1)) {
580
+ // continue;
581
+ // }
582
+ // vt.set(v1);
583
+ // ndis++;
584
+ // float d = qdis(v1);
585
+ // if (!sel || sel->is_member(v1)) {
586
+ // if (nres < k) {
587
+ // faiss::maxheap_push(++nres, D, I, d, v1);
588
+ // } else if (d < D[0]) {
589
+ // faiss::maxheap_replace_top(nres, D, I, d, v1);
590
+ // }
591
+ // }
592
+ // candidates.push(v1, d);
593
+ // }
594
+
595
+ // the following version processes 4 neighbors at a time
596
+ size_t jmax = begin;
566
597
  for (size_t j = begin; j < end; j++) {
567
598
  int v1 = hnsw.neighbors[j];
568
599
  if (v1 < 0)
569
600
  break;
570
- if (vt.get(v1)) {
571
- continue;
601
+
602
+ prefetch_L2(vt.visited.data() + v1);
603
+ jmax += 1;
604
+ }
605
+
606
+ int counter = 0;
607
+ size_t saved_j[4];
608
+
609
+ ndis += jmax - begin;
610
+ threshold = res.threshold;
611
+
612
+ auto add_to_heap = [&](const size_t idx, const float dis) {
613
+ if (!sel || sel->is_member(idx)) {
614
+ if (dis < threshold) {
615
+ if (res.add_result(dis, idx)) {
616
+ threshold = res.threshold;
617
+ }
618
+ }
572
619
  }
620
+ candidates.push(idx, dis);
621
+ };
622
+
623
+ for (size_t j = begin; j < jmax; j++) {
624
+ int v1 = hnsw.neighbors[j];
625
+
626
+ bool vget = vt.get(v1);
573
627
  vt.set(v1);
574
- ndis++;
575
- float d = qdis(v1);
576
- if (!sel || sel->is_member(v1)) {
577
- if (nres < k) {
578
- faiss::maxheap_push(++nres, D, I, d, v1);
579
- } else if (d < D[0]) {
580
- faiss::maxheap_replace_top(nres, D, I, d, v1);
628
+ saved_j[counter] = v1;
629
+ counter += vget ? 0 : 1;
630
+
631
+ if (counter == 4) {
632
+ float dis[4];
633
+ qdis.distances_batch_4(
634
+ saved_j[0],
635
+ saved_j[1],
636
+ saved_j[2],
637
+ saved_j[3],
638
+ dis[0],
639
+ dis[1],
640
+ dis[2],
641
+ dis[3]);
642
+
643
+ for (size_t id4 = 0; id4 < 4; id4++) {
644
+ add_to_heap(saved_j[id4], dis[id4]);
581
645
  }
646
+
647
+ counter = 0;
582
648
  }
583
- candidates.push(v1, d);
649
+ }
650
+
651
+ for (size_t icnt = 0; icnt < counter; icnt++) {
652
+ float dis = qdis(saved_j[icnt]);
653
+ add_to_heap(saved_j[icnt], dis);
584
654
  }
585
655
 
586
656
  nstep++;
@@ -630,29 +700,92 @@ std::priority_queue<HNSW::Node> search_from_candidate_unbounded(
630
700
  size_t begin, end;
631
701
  hnsw.neighbor_range(v0, 0, &begin, &end);
632
702
 
633
- for (size_t j = begin; j < end; ++j) {
703
+ // // baseline version
704
+ // for (size_t j = begin; j < end; ++j) {
705
+ // int v1 = hnsw.neighbors[j];
706
+ //
707
+ // if (v1 < 0) {
708
+ // break;
709
+ // }
710
+ // if (vt->get(v1)) {
711
+ // continue;
712
+ // }
713
+ //
714
+ // vt->set(v1);
715
+ //
716
+ // float d1 = qdis(v1);
717
+ // ++ndis;
718
+ //
719
+ // if (top_candidates.top().first > d1 ||
720
+ // top_candidates.size() < ef) {
721
+ // candidates.emplace(d1, v1);
722
+ // top_candidates.emplace(d1, v1);
723
+ //
724
+ // if (top_candidates.size() > ef) {
725
+ // top_candidates.pop();
726
+ // }
727
+ // }
728
+ // }
729
+
730
+ // the following version processes 4 neighbors at a time
731
+ size_t jmax = begin;
732
+ for (size_t j = begin; j < end; j++) {
634
733
  int v1 = hnsw.neighbors[j];
635
-
636
- if (v1 < 0) {
734
+ if (v1 < 0)
637
735
  break;
638
- }
639
- if (vt->get(v1)) {
640
- continue;
641
- }
642
736
 
643
- vt->set(v1);
737
+ prefetch_L2(vt->visited.data() + v1);
738
+ jmax += 1;
739
+ }
740
+
741
+ int counter = 0;
742
+ size_t saved_j[4];
644
743
 
645
- float d1 = qdis(v1);
646
- ++ndis;
744
+ ndis += jmax - begin;
647
745
 
648
- if (top_candidates.top().first > d1 || top_candidates.size() < ef) {
649
- candidates.emplace(d1, v1);
650
- top_candidates.emplace(d1, v1);
746
+ auto add_to_heap = [&](const size_t idx, const float dis) {
747
+ if (top_candidates.top().first > dis ||
748
+ top_candidates.size() < ef) {
749
+ candidates.emplace(dis, idx);
750
+ top_candidates.emplace(dis, idx);
651
751
 
652
752
  if (top_candidates.size() > ef) {
653
753
  top_candidates.pop();
654
754
  }
655
755
  }
756
+ };
757
+
758
+ for (size_t j = begin; j < jmax; j++) {
759
+ int v1 = hnsw.neighbors[j];
760
+
761
+ bool vget = vt->get(v1);
762
+ vt->set(v1);
763
+ saved_j[counter] = v1;
764
+ counter += vget ? 0 : 1;
765
+
766
+ if (counter == 4) {
767
+ float dis[4];
768
+ qdis.distances_batch_4(
769
+ saved_j[0],
770
+ saved_j[1],
771
+ saved_j[2],
772
+ saved_j[3],
773
+ dis[0],
774
+ dis[1],
775
+ dis[2],
776
+ dis[3]);
777
+
778
+ for (size_t id4 = 0; id4 < 4; id4++) {
779
+ add_to_heap(saved_j[id4], dis[id4]);
780
+ }
781
+
782
+ counter = 0;
783
+ }
784
+ }
785
+
786
+ for (size_t icnt = 0; icnt < counter; icnt++) {
787
+ float dis = qdis(saved_j[icnt]);
788
+ add_to_heap(saved_j[icnt], dis);
656
789
  }
657
790
  }
658
791
 
@@ -665,19 +798,28 @@ std::priority_queue<HNSW::Node> search_from_candidate_unbounded(
665
798
  return top_candidates;
666
799
  }
667
800
 
801
+ // just used as a lower bound for the minmaxheap, but it is set for heap search
802
+ int extract_k_from_ResultHandler(ResultHandler<C>& res) {
803
+ using RH = HeapBlockResultHandler<C>;
804
+ if (auto hres = dynamic_cast<RH::SingleResultHandler*>(&res)) {
805
+ return hres->k;
806
+ }
807
+ return 1;
808
+ }
809
+
668
810
  } // anonymous namespace
669
811
 
670
812
  HNSWStats HNSW::search(
671
813
  DistanceComputer& qdis,
672
- int k,
673
- idx_t* I,
674
- float* D,
814
+ ResultHandler<C>& res,
675
815
  VisitedTable& vt,
676
816
  const SearchParametersHNSW* params) const {
677
817
  HNSWStats stats;
678
818
  if (entry_point == -1) {
679
819
  return stats;
680
820
  }
821
+ int k = extract_k_from_ResultHandler(res);
822
+
681
823
  if (upper_beam == 1) {
682
824
  // greedy search on upper levels
683
825
  storage_idx_t nearest = entry_point;
@@ -687,14 +829,14 @@ HNSWStats HNSW::search(
687
829
  greedy_update_nearest(*this, qdis, level, nearest, d_nearest);
688
830
  }
689
831
 
690
- int ef = std::max(efSearch, k);
832
+ int ef = std::max(params ? params->efSearch : efSearch, k);
691
833
  if (search_bounded_queue) { // this is the most common branch
692
834
  MinimaxHeap candidates(ef);
693
835
 
694
836
  candidates.push(nearest, d_nearest);
695
837
 
696
838
  search_from_candidates(
697
- *this, qdis, k, I, D, candidates, vt, stats, 0, 0, params);
839
+ *this, qdis, res, candidates, vt, stats, 0, 0, params);
698
840
  } else {
699
841
  std::priority_queue<Node> top_candidates =
700
842
  search_from_candidate_unbounded(
@@ -709,12 +851,11 @@ HNSWStats HNSW::search(
709
851
  top_candidates.pop();
710
852
  }
711
853
 
712
- int nres = 0;
713
854
  while (!top_candidates.empty()) {
714
855
  float d;
715
856
  storage_idx_t label;
716
857
  std::tie(d, label) = top_candidates.top();
717
- faiss::maxheap_push(++nres, D, I, d, label);
858
+ res.add_result(d, label);
718
859
  top_candidates.pop();
719
860
  }
720
861
  }
@@ -728,6 +869,10 @@ HNSWStats HNSW::search(
728
869
  std::vector<idx_t> I_to_next(candidates_size);
729
870
  std::vector<float> D_to_next(candidates_size);
730
871
 
872
+ HeapBlockResultHandler<C> block_resh(
873
+ 1, D_to_next.data(), I_to_next.data(), candidates_size);
874
+ HeapBlockResultHandler<C>::SingleResultHandler resh(block_resh);
875
+
731
876
  int nres = 1;
732
877
  I_to_next[0] = entry_point;
733
878
  D_to_next[0] = qdis(entry_point);
@@ -743,18 +888,12 @@ HNSWStats HNSW::search(
743
888
 
744
889
  if (level == 0) {
745
890
  nres = search_from_candidates(
746
- *this, qdis, k, I, D, candidates, vt, stats, 0);
891
+ *this, qdis, res, candidates, vt, stats, 0);
747
892
  } else {
893
+ resh.begin(0);
748
894
  nres = search_from_candidates(
749
- *this,
750
- qdis,
751
- candidates_size,
752
- I_to_next.data(),
753
- D_to_next.data(),
754
- candidates,
755
- vt,
756
- stats,
757
- level);
895
+ *this, qdis, resh, candidates, vt, stats, level);
896
+ resh.end();
758
897
  }
759
898
  vt.advance();
760
899
  }
@@ -765,9 +904,7 @@ HNSWStats HNSW::search(
765
904
 
766
905
  void HNSW::search_level_0(
767
906
  DistanceComputer& qdis,
768
- int k,
769
- idx_t* idxi,
770
- float* simi,
907
+ ResultHandler<C>& res,
771
908
  idx_t nprobe,
772
909
  const storage_idx_t* nearest_i,
773
910
  const float* nearest_d,
@@ -775,7 +912,7 @@ void HNSW::search_level_0(
775
912
  HNSWStats& search_stats,
776
913
  VisitedTable& vt) const {
777
914
  const HNSW& hnsw = *this;
778
-
915
+ int k = extract_k_from_ResultHandler(res);
779
916
  if (search_type == 1) {
780
917
  int nres = 0;
781
918
 
@@ -788,22 +925,13 @@ void HNSW::search_level_0(
788
925
  if (vt.get(cj))
789
926
  continue;
790
927
 
791
- int candidates_size = std::max(hnsw.efSearch, int(k));
928
+ int candidates_size = std::max(hnsw.efSearch, k);
792
929
  MinimaxHeap candidates(candidates_size);
793
930
 
794
931
  candidates.push(cj, nearest_d[j]);
795
932
 
796
933
  nres = search_from_candidates(
797
- hnsw,
798
- qdis,
799
- k,
800
- idxi,
801
- simi,
802
- candidates,
803
- vt,
804
- search_stats,
805
- 0,
806
- nres);
934
+ hnsw, qdis, res, candidates, vt, search_stats, 0, nres);
807
935
  }
808
936
  } else if (search_type == 2) {
809
937
  int candidates_size = std::max(hnsw.efSearch, int(k));
@@ -819,10 +947,43 @@ void HNSW::search_level_0(
819
947
  }
820
948
 
821
949
  search_from_candidates(
822
- hnsw, qdis, k, idxi, simi, candidates, vt, search_stats, 0);
950
+ hnsw, qdis, res, candidates, vt, search_stats, 0);
823
951
  }
824
952
  }
825
953
 
954
+ void HNSW::permute_entries(const idx_t* map) {
955
+ // remap levels
956
+ storage_idx_t ntotal = levels.size();
957
+ std::vector<storage_idx_t> imap(ntotal); // inverse mapping
958
+ // map: new index -> old index
959
+ // imap: old index -> new index
960
+ for (int i = 0; i < ntotal; i++) {
961
+ assert(map[i] >= 0 && map[i] < ntotal);
962
+ imap[map[i]] = i;
963
+ }
964
+ if (entry_point != -1) {
965
+ entry_point = imap[entry_point];
966
+ }
967
+ std::vector<int> new_levels(ntotal);
968
+ std::vector<size_t> new_offsets(ntotal + 1);
969
+ std::vector<storage_idx_t> new_neighbors(neighbors.size());
970
+ size_t no = 0;
971
+ for (int i = 0; i < ntotal; i++) {
972
+ storage_idx_t o = map[i]; // corresponding "old" index
973
+ new_levels[i] = levels[o];
974
+ for (size_t j = offsets[o]; j < offsets[o + 1]; j++) {
975
+ storage_idx_t neigh = neighbors[j];
976
+ new_neighbors[no++] = neigh >= 0 ? imap[neigh] : neigh;
977
+ }
978
+ new_offsets[i + 1] = no;
979
+ }
980
+ assert(new_offsets[ntotal] == offsets[ntotal]);
981
+ // swap everyone
982
+ std::swap(levels, new_levels);
983
+ std::swap(offsets, new_offsets);
984
+ std::swap(neighbors, new_neighbors);
985
+ }
986
+
826
987
  /**************************************************************
827
988
  * MinimaxHeap
828
989
  **************************************************************/
@@ -852,17 +1013,105 @@ void HNSW::MinimaxHeap::clear() {
852
1013
  nvalid = k = 0;
853
1014
  }
854
1015
 
1016
+ #ifdef __AVX2__
1017
+ int HNSW::MinimaxHeap::pop_min(float* vmin_out) {
1018
+ assert(k > 0);
1019
+ static_assert(
1020
+ std::is_same<storage_idx_t, int32_t>::value,
1021
+ "This code expects storage_idx_t to be int32_t");
1022
+
1023
+ int32_t min_idx = -1;
1024
+ float min_dis = std::numeric_limits<float>::infinity();
1025
+
1026
+ size_t iii = 0;
1027
+
1028
+ __m256i min_indices = _mm256_setr_epi32(-1, -1, -1, -1, -1, -1, -1, -1);
1029
+ __m256 min_distances =
1030
+ _mm256_set1_ps(std::numeric_limits<float>::infinity());
1031
+ __m256i current_indices = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
1032
+ __m256i offset = _mm256_set1_epi32(8);
1033
+
1034
+ // The baseline version is available in non-AVX2 branch.
1035
+
1036
+ // The following loop tracks the rightmost index with the min distance.
1037
+ // -1 index values are ignored.
1038
+ const int k8 = (k / 8) * 8;
1039
+ for (; iii < k8; iii += 8) {
1040
+ __m256i indices =
1041
+ _mm256_loadu_si256((const __m256i*)(ids.data() + iii));
1042
+ __m256 distances = _mm256_loadu_ps(dis.data() + iii);
1043
+
1044
+ // This mask filters out -1 values among indices.
1045
+ __m256i m1mask = _mm256_cmpgt_epi32(_mm256_setzero_si256(), indices);
1046
+
1047
+ __m256i dmask = _mm256_castps_si256(
1048
+ _mm256_cmp_ps(min_distances, distances, _CMP_LT_OS));
1049
+ __m256 finalmask = _mm256_castsi256_ps(_mm256_or_si256(m1mask, dmask));
1050
+
1051
+ const __m256i min_indices_new = _mm256_castps_si256(_mm256_blendv_ps(
1052
+ _mm256_castsi256_ps(current_indices),
1053
+ _mm256_castsi256_ps(min_indices),
1054
+ finalmask));
1055
+
1056
+ const __m256 min_distances_new =
1057
+ _mm256_blendv_ps(distances, min_distances, finalmask);
1058
+
1059
+ min_indices = min_indices_new;
1060
+ min_distances = min_distances_new;
1061
+
1062
+ current_indices = _mm256_add_epi32(current_indices, offset);
1063
+ }
1064
+
1065
+ // Vectorizing is doable, but is not practical
1066
+ int32_t vidx8[8];
1067
+ float vdis8[8];
1068
+ _mm256_storeu_ps(vdis8, min_distances);
1069
+ _mm256_storeu_si256((__m256i*)vidx8, min_indices);
1070
+
1071
+ for (size_t j = 0; j < 8; j++) {
1072
+ if (min_dis > vdis8[j] || (min_dis == vdis8[j] && min_idx < vidx8[j])) {
1073
+ min_idx = vidx8[j];
1074
+ min_dis = vdis8[j];
1075
+ }
1076
+ }
1077
+
1078
+ // process last values. Vectorizing is doable, but is not practical
1079
+ for (; iii < k; iii++) {
1080
+ if (ids[iii] != -1 && dis[iii] <= min_dis) {
1081
+ min_dis = dis[iii];
1082
+ min_idx = iii;
1083
+ }
1084
+ }
1085
+
1086
+ if (min_idx == -1) {
1087
+ return -1;
1088
+ }
1089
+
1090
+ if (vmin_out) {
1091
+ *vmin_out = min_dis;
1092
+ }
1093
+ int ret = ids[min_idx];
1094
+ ids[min_idx] = -1;
1095
+ --nvalid;
1096
+ return ret;
1097
+ }
1098
+
1099
+ #else
1100
+
1101
+ // baseline non-vectorized version
855
1102
  int HNSW::MinimaxHeap::pop_min(float* vmin_out) {
856
1103
  assert(k > 0);
857
1104
  // returns min. This is an O(n) operation
858
1105
  int i = k - 1;
859
1106
  while (i >= 0) {
860
- if (ids[i] != -1)
1107
+ if (ids[i] != -1) {
861
1108
  break;
1109
+ }
862
1110
  i--;
863
1111
  }
864
- if (i == -1)
1112
+ if (i == -1) {
865
1113
  return -1;
1114
+ }
866
1115
  int imin = i;
867
1116
  float vmin = dis[i];
868
1117
  i--;
@@ -873,14 +1122,16 @@ int HNSW::MinimaxHeap::pop_min(float* vmin_out) {
873
1122
  }
874
1123
  i--;
875
1124
  }
876
- if (vmin_out)
1125
+ if (vmin_out) {
877
1126
  *vmin_out = vmin;
1127
+ }
878
1128
  int ret = ids[imin];
879
1129
  ids[imin] = -1;
880
1130
  --nvalid;
881
1131
 
882
1132
  return ret;
883
1133
  }
1134
+ #endif
884
1135
 
885
1136
  int HNSW::MinimaxHeap::count_below(float thresh) {
886
1137
  int n_below = 0;
@@ -5,8 +5,6 @@
5
5
  * LICENSE file in the root directory of this source tree.
6
6
  */
7
7
 
8
- // -*- c++ -*-
9
-
10
8
  #pragma once
11
9
 
12
10
  #include <queue>
@@ -42,6 +40,8 @@ namespace faiss {
42
40
  struct VisitedTable;
43
41
  struct DistanceComputer; // from AuxIndexStructures
44
42
  struct HNSWStats;
43
+ template <class C>
44
+ struct ResultHandler;
45
45
 
46
46
  struct SearchParametersHNSW : SearchParameters {
47
47
  int efSearch = 16;
@@ -54,6 +54,9 @@ struct HNSW {
54
54
  /// internal storage of vectors (32 bits: this is expensive)
55
55
  using storage_idx_t = int32_t;
56
56
 
57
+ // for now we do only these distances
58
+ using C = CMax<float, int64_t>;
59
+
57
60
  typedef std::pair<float, storage_idx_t> Node;
58
61
 
59
62
  /** Heap structure that allows fast
@@ -195,18 +198,14 @@ struct HNSW {
195
198
  /// search interface for 1 point, single thread
196
199
  HNSWStats search(
197
200
  DistanceComputer& qdis,
198
- int k,
199
- idx_t* I,
200
- float* D,
201
+ ResultHandler<C>& res,
201
202
  VisitedTable& vt,
202
203
  const SearchParametersHNSW* params = nullptr) const;
203
204
 
204
205
  /// search only in level 0 from a given vertex
205
206
  void search_level_0(
206
207
  DistanceComputer& qdis,
207
- int k,
208
- idx_t* idxi,
209
- float* simi,
208
+ ResultHandler<C>& res,
210
209
  idx_t nprobe,
211
210
  const storage_idx_t* nearest_i,
212
211
  const float* nearest_d,
@@ -226,6 +225,8 @@ struct HNSW {
226
225
  std::priority_queue<NodeDistFarther>& input,
227
226
  std::vector<NodeDistFarther>& output,
228
227
  int max_size);
228
+
229
+ void permute_entries(const idx_t* map);
229
230
  };
230
231
 
231
232
  struct HNSWStats {
@@ -10,7 +10,7 @@
10
10
  #include <unordered_set>
11
11
  #include <vector>
12
12
 
13
- #include <faiss/Index.h>
13
+ #include <faiss/MetricType.h>
14
14
 
15
15
  /** IDSelector is intended to define a subset of vectors to handle (for removal
16
16
  * or as subset to search) */
@@ -140,7 +140,7 @@ struct IDSelectorAnd : IDSelector {
140
140
  : lhs(lhs), rhs(rhs) {}
141
141
  bool is_member(idx_t id) const final {
142
142
  return lhs->is_member(id) && rhs->is_member(id);
143
- };
143
+ }
144
144
  virtual ~IDSelectorAnd() {}
145
145
  };
146
146
 
@@ -153,7 +153,7 @@ struct IDSelectorOr : IDSelector {
153
153
  : lhs(lhs), rhs(rhs) {}
154
154
  bool is_member(idx_t id) const final {
155
155
  return lhs->is_member(id) || rhs->is_member(id);
156
- };
156
+ }
157
157
  virtual ~IDSelectorOr() {}
158
158
  };
159
159
 
@@ -166,7 +166,7 @@ struct IDSelectorXOr : IDSelector {
166
166
  : lhs(lhs), rhs(rhs) {}
167
167
  bool is_member(idx_t id) const final {
168
168
  return lhs->is_member(id) ^ rhs->is_member(id);
169
- };
169
+ }
170
170
  virtual ~IDSelectorXOr() {}
171
171
  };
172
172
 
@@ -628,7 +628,9 @@ void LocalSearchQuantizer::icm_encode_step(
628
628
  {
629
629
  size_t binary_idx = (other_m + 1) * M * K * K +
630
630
  m * K * K + code2 * K + code;
631
- _mm_prefetch(binaries + binary_idx, _MM_HINT_T0);
631
+ _mm_prefetch(
632
+ (const char*)(binaries + binary_idx),
633
+ _MM_HINT_T0);
632
634
  }
633
635
  }
634
636
  #endif