faiss 0.3.0 → 0.3.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (171) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +5 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +1 -1
  5. data/ext/faiss/extconf.rb +9 -2
  6. data/ext/faiss/index.cpp +1 -1
  7. data/ext/faiss/index_binary.cpp +2 -2
  8. data/ext/faiss/product_quantizer.cpp +1 -1
  9. data/lib/faiss/version.rb +1 -1
  10. data/vendor/faiss/faiss/AutoTune.cpp +7 -7
  11. data/vendor/faiss/faiss/AutoTune.h +0 -1
  12. data/vendor/faiss/faiss/Clustering.cpp +4 -18
  13. data/vendor/faiss/faiss/Clustering.h +31 -21
  14. data/vendor/faiss/faiss/IVFlib.cpp +22 -11
  15. data/vendor/faiss/faiss/Index.cpp +1 -1
  16. data/vendor/faiss/faiss/Index.h +20 -5
  17. data/vendor/faiss/faiss/Index2Layer.cpp +7 -7
  18. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +176 -166
  19. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +15 -15
  20. data/vendor/faiss/faiss/IndexBinary.cpp +9 -4
  21. data/vendor/faiss/faiss/IndexBinary.h +8 -19
  22. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +2 -1
  23. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +24 -31
  24. data/vendor/faiss/faiss/IndexBinaryHash.cpp +25 -50
  25. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +106 -187
  26. data/vendor/faiss/faiss/IndexFastScan.cpp +90 -159
  27. data/vendor/faiss/faiss/IndexFastScan.h +9 -8
  28. data/vendor/faiss/faiss/IndexFlat.cpp +195 -3
  29. data/vendor/faiss/faiss/IndexFlat.h +20 -1
  30. data/vendor/faiss/faiss/IndexFlatCodes.cpp +11 -0
  31. data/vendor/faiss/faiss/IndexFlatCodes.h +3 -1
  32. data/vendor/faiss/faiss/IndexHNSW.cpp +112 -316
  33. data/vendor/faiss/faiss/IndexHNSW.h +12 -48
  34. data/vendor/faiss/faiss/IndexIDMap.cpp +69 -28
  35. data/vendor/faiss/faiss/IndexIDMap.h +24 -2
  36. data/vendor/faiss/faiss/IndexIVF.cpp +159 -53
  37. data/vendor/faiss/faiss/IndexIVF.h +37 -5
  38. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +18 -26
  39. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +3 -2
  40. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +19 -46
  41. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +4 -3
  42. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +433 -405
  43. data/vendor/faiss/faiss/IndexIVFFastScan.h +56 -26
  44. data/vendor/faiss/faiss/IndexIVFFlat.cpp +15 -5
  45. data/vendor/faiss/faiss/IndexIVFFlat.h +3 -2
  46. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +172 -0
  47. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.h +56 -0
  48. data/vendor/faiss/faiss/IndexIVFPQ.cpp +78 -122
  49. data/vendor/faiss/faiss/IndexIVFPQ.h +6 -7
  50. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +18 -50
  51. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +4 -3
  52. data/vendor/faiss/faiss/IndexIVFPQR.cpp +45 -29
  53. data/vendor/faiss/faiss/IndexIVFPQR.h +5 -2
  54. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +25 -27
  55. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +6 -6
  56. data/vendor/faiss/faiss/IndexLSH.cpp +14 -16
  57. data/vendor/faiss/faiss/IndexNNDescent.cpp +3 -4
  58. data/vendor/faiss/faiss/IndexNSG.cpp +11 -27
  59. data/vendor/faiss/faiss/IndexNSG.h +10 -10
  60. data/vendor/faiss/faiss/IndexPQ.cpp +72 -88
  61. data/vendor/faiss/faiss/IndexPQ.h +1 -4
  62. data/vendor/faiss/faiss/IndexPQFastScan.cpp +1 -1
  63. data/vendor/faiss/faiss/IndexPreTransform.cpp +25 -31
  64. data/vendor/faiss/faiss/IndexRefine.cpp +49 -19
  65. data/vendor/faiss/faiss/IndexRefine.h +7 -0
  66. data/vendor/faiss/faiss/IndexReplicas.cpp +23 -26
  67. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +22 -16
  68. data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -4
  69. data/vendor/faiss/faiss/IndexShards.cpp +21 -29
  70. data/vendor/faiss/faiss/IndexShardsIVF.cpp +1 -2
  71. data/vendor/faiss/faiss/MatrixStats.cpp +17 -32
  72. data/vendor/faiss/faiss/MatrixStats.h +21 -9
  73. data/vendor/faiss/faiss/MetaIndexes.cpp +35 -35
  74. data/vendor/faiss/faiss/VectorTransform.cpp +13 -26
  75. data/vendor/faiss/faiss/VectorTransform.h +7 -7
  76. data/vendor/faiss/faiss/clone_index.cpp +15 -10
  77. data/vendor/faiss/faiss/clone_index.h +3 -0
  78. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +87 -4
  79. data/vendor/faiss/faiss/gpu/GpuCloner.h +22 -0
  80. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +7 -0
  81. data/vendor/faiss/faiss/gpu/GpuDistance.h +46 -38
  82. data/vendor/faiss/faiss/gpu/GpuIndex.h +28 -4
  83. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +4 -4
  84. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +8 -9
  85. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +18 -3
  86. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -11
  87. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +1 -3
  88. data/vendor/faiss/faiss/gpu/GpuResources.cpp +24 -3
  89. data/vendor/faiss/faiss/gpu/GpuResources.h +39 -11
  90. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +117 -17
  91. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +57 -3
  92. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +1 -1
  93. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +25 -0
  94. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +129 -9
  95. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +267 -40
  96. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +299 -208
  97. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +1 -0
  98. data/vendor/faiss/faiss/gpu/utils/RaftUtils.h +75 -0
  99. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +3 -1
  100. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +5 -5
  101. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +1 -1
  102. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +1 -2
  103. data/vendor/faiss/faiss/impl/DistanceComputer.h +24 -1
  104. data/vendor/faiss/faiss/impl/FaissException.h +13 -34
  105. data/vendor/faiss/faiss/impl/HNSW.cpp +321 -70
  106. data/vendor/faiss/faiss/impl/HNSW.h +9 -8
  107. data/vendor/faiss/faiss/impl/IDSelector.h +4 -4
  108. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +3 -1
  109. data/vendor/faiss/faiss/impl/NNDescent.cpp +29 -19
  110. data/vendor/faiss/faiss/impl/NSG.h +1 -1
  111. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +14 -12
  112. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +1 -1
  113. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +24 -22
  114. data/vendor/faiss/faiss/impl/ProductQuantizer.h +1 -1
  115. data/vendor/faiss/faiss/impl/Quantizer.h +1 -1
  116. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +27 -1015
  117. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +5 -63
  118. data/vendor/faiss/faiss/impl/ResultHandler.h +232 -176
  119. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +444 -104
  120. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +0 -8
  121. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +280 -42
  122. data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +21 -14
  123. data/vendor/faiss/faiss/impl/code_distance/code_distance.h +22 -12
  124. data/vendor/faiss/faiss/impl/index_read.cpp +45 -19
  125. data/vendor/faiss/faiss/impl/index_write.cpp +60 -41
  126. data/vendor/faiss/faiss/impl/io.cpp +10 -10
  127. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -1
  128. data/vendor/faiss/faiss/impl/platform_macros.h +18 -1
  129. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +3 -0
  130. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +7 -6
  131. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +52 -38
  132. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +40 -49
  133. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +960 -0
  134. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +176 -0
  135. data/vendor/faiss/faiss/impl/simd_result_handlers.h +374 -202
  136. data/vendor/faiss/faiss/index_factory.cpp +10 -7
  137. data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -1
  138. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +27 -9
  139. data/vendor/faiss/faiss/invlists/InvertedLists.h +12 -3
  140. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +3 -3
  141. data/vendor/faiss/faiss/python/python_callbacks.cpp +1 -1
  142. data/vendor/faiss/faiss/utils/Heap.cpp +3 -1
  143. data/vendor/faiss/faiss/utils/WorkerThread.h +1 -0
  144. data/vendor/faiss/faiss/utils/distances.cpp +128 -74
  145. data/vendor/faiss/faiss/utils/distances.h +81 -4
  146. data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +5 -5
  147. data/vendor/faiss/faiss/utils/distances_fused/avx512.h +2 -2
  148. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +2 -2
  149. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +1 -1
  150. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +5 -5
  151. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.h +1 -1
  152. data/vendor/faiss/faiss/utils/distances_simd.cpp +428 -70
  153. data/vendor/faiss/faiss/utils/fp16-arm.h +29 -0
  154. data/vendor/faiss/faiss/utils/fp16.h +2 -0
  155. data/vendor/faiss/faiss/utils/hamming.cpp +162 -110
  156. data/vendor/faiss/faiss/utils/hamming.h +58 -0
  157. data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +16 -89
  158. data/vendor/faiss/faiss/utils/hamming_distance/common.h +1 -0
  159. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +15 -87
  160. data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +57 -0
  161. data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +14 -104
  162. data/vendor/faiss/faiss/utils/partitioning.cpp +3 -4
  163. data/vendor/faiss/faiss/utils/prefetch.h +77 -0
  164. data/vendor/faiss/faiss/utils/quantize_lut.cpp +0 -14
  165. data/vendor/faiss/faiss/utils/simdlib_avx2.h +0 -6
  166. data/vendor/faiss/faiss/utils/simdlib_neon.h +72 -77
  167. data/vendor/faiss/faiss/utils/sorting.cpp +140 -5
  168. data/vendor/faiss/faiss/utils/sorting.h +27 -0
  169. data/vendor/faiss/faiss/utils/utils.cpp +112 -6
  170. data/vendor/faiss/faiss/utils/utils.h +57 -20
  171. metadata +10 -3
@@ -5,8 +5,6 @@
5
5
  * LICENSE file in the root directory of this source tree.
6
6
  */
7
7
 
8
- // -*- c++ -*-
9
-
10
8
  #include <faiss/impl/HNSW.h>
11
9
 
12
10
  #include <string>
@@ -14,6 +12,17 @@
14
12
  #include <faiss/impl/AuxIndexStructures.h>
15
13
  #include <faiss/impl/DistanceComputer.h>
16
14
  #include <faiss/impl/IDSelector.h>
15
+ #include <faiss/impl/ResultHandler.h>
16
+ #include <faiss/utils/prefetch.h>
17
+
18
+ #include <faiss/impl/platform_macros.h>
19
+
20
+ #ifdef __AVX2__
21
+ #include <immintrin.h>
22
+
23
+ #include <limits>
24
+ #include <type_traits>
25
+ #endif
17
26
 
18
27
  namespace faiss {
19
28
 
@@ -212,7 +221,7 @@ int HNSW::prepare_level_tab(size_t n, bool preset_levels) {
212
221
  return max_level;
213
222
  }
214
223
 
215
- /** Enumerate vertices from farthest to nearest from query, keep a
224
+ /** Enumerate vertices from nearest to farthest from query, keep a
216
225
  * neighbor only if there is no previous neighbor that is closer to
217
226
  * that vertex than the query.
218
227
  */
@@ -503,17 +512,15 @@ void HNSW::add_with_locks(
503
512
  **************************************************************/
504
513
 
505
514
  namespace {
506
-
507
515
  using MinimaxHeap = HNSW::MinimaxHeap;
508
516
  using Node = HNSW::Node;
517
+ using C = HNSW::C;
509
518
  /** Do a BFS on the candidates list */
510
519
 
511
520
  int search_from_candidates(
512
521
  const HNSW& hnsw,
513
522
  DistanceComputer& qdis,
514
- int k,
515
- idx_t* I,
516
- float* D,
523
+ ResultHandler<C>& res,
517
524
  MinimaxHeap& candidates,
518
525
  VisitedTable& vt,
519
526
  HNSWStats& stats,
@@ -529,15 +536,16 @@ int search_from_candidates(
529
536
  int efSearch = params ? params->efSearch : hnsw.efSearch;
530
537
  const IDSelector* sel = params ? params->sel : nullptr;
531
538
 
539
+ C::T threshold = res.threshold;
532
540
  for (int i = 0; i < candidates.size(); i++) {
533
541
  idx_t v1 = candidates.ids[i];
534
542
  float d = candidates.dis[i];
535
543
  FAISS_ASSERT(v1 >= 0);
536
544
  if (!sel || sel->is_member(v1)) {
537
- if (nres < k) {
538
- faiss::maxheap_push(++nres, D, I, d, v1);
539
- } else if (d < D[0]) {
540
- faiss::maxheap_replace_top(nres, D, I, d, v1);
545
+ if (d < threshold) {
546
+ if (res.add_result(d, v1)) {
547
+ threshold = res.threshold;
548
+ }
541
549
  }
542
550
  }
543
551
  vt.set(v1);
@@ -563,24 +571,86 @@ int search_from_candidates(
563
571
  size_t begin, end;
564
572
  hnsw.neighbor_range(v0, level, &begin, &end);
565
573
 
574
+ // // baseline version
575
+ // for (size_t j = begin; j < end; j++) {
576
+ // int v1 = hnsw.neighbors[j];
577
+ // if (v1 < 0)
578
+ // break;
579
+ // if (vt.get(v1)) {
580
+ // continue;
581
+ // }
582
+ // vt.set(v1);
583
+ // ndis++;
584
+ // float d = qdis(v1);
585
+ // if (!sel || sel->is_member(v1)) {
586
+ // if (nres < k) {
587
+ // faiss::maxheap_push(++nres, D, I, d, v1);
588
+ // } else if (d < D[0]) {
589
+ // faiss::maxheap_replace_top(nres, D, I, d, v1);
590
+ // }
591
+ // }
592
+ // candidates.push(v1, d);
593
+ // }
594
+
595
+ // the following version processes 4 neighbors at a time
596
+ size_t jmax = begin;
566
597
  for (size_t j = begin; j < end; j++) {
567
598
  int v1 = hnsw.neighbors[j];
568
599
  if (v1 < 0)
569
600
  break;
570
- if (vt.get(v1)) {
571
- continue;
601
+
602
+ prefetch_L2(vt.visited.data() + v1);
603
+ jmax += 1;
604
+ }
605
+
606
+ int counter = 0;
607
+ size_t saved_j[4];
608
+
609
+ ndis += jmax - begin;
610
+ threshold = res.threshold;
611
+
612
+ auto add_to_heap = [&](const size_t idx, const float dis) {
613
+ if (!sel || sel->is_member(idx)) {
614
+ if (dis < threshold) {
615
+ if (res.add_result(dis, idx)) {
616
+ threshold = res.threshold;
617
+ }
618
+ }
572
619
  }
620
+ candidates.push(idx, dis);
621
+ };
622
+
623
+ for (size_t j = begin; j < jmax; j++) {
624
+ int v1 = hnsw.neighbors[j];
625
+
626
+ bool vget = vt.get(v1);
573
627
  vt.set(v1);
574
- ndis++;
575
- float d = qdis(v1);
576
- if (!sel || sel->is_member(v1)) {
577
- if (nres < k) {
578
- faiss::maxheap_push(++nres, D, I, d, v1);
579
- } else if (d < D[0]) {
580
- faiss::maxheap_replace_top(nres, D, I, d, v1);
628
+ saved_j[counter] = v1;
629
+ counter += vget ? 0 : 1;
630
+
631
+ if (counter == 4) {
632
+ float dis[4];
633
+ qdis.distances_batch_4(
634
+ saved_j[0],
635
+ saved_j[1],
636
+ saved_j[2],
637
+ saved_j[3],
638
+ dis[0],
639
+ dis[1],
640
+ dis[2],
641
+ dis[3]);
642
+
643
+ for (size_t id4 = 0; id4 < 4; id4++) {
644
+ add_to_heap(saved_j[id4], dis[id4]);
581
645
  }
646
+
647
+ counter = 0;
582
648
  }
583
- candidates.push(v1, d);
649
+ }
650
+
651
+ for (size_t icnt = 0; icnt < counter; icnt++) {
652
+ float dis = qdis(saved_j[icnt]);
653
+ add_to_heap(saved_j[icnt], dis);
584
654
  }
585
655
 
586
656
  nstep++;
@@ -630,29 +700,92 @@ std::priority_queue<HNSW::Node> search_from_candidate_unbounded(
630
700
  size_t begin, end;
631
701
  hnsw.neighbor_range(v0, 0, &begin, &end);
632
702
 
633
- for (size_t j = begin; j < end; ++j) {
703
+ // // baseline version
704
+ // for (size_t j = begin; j < end; ++j) {
705
+ // int v1 = hnsw.neighbors[j];
706
+ //
707
+ // if (v1 < 0) {
708
+ // break;
709
+ // }
710
+ // if (vt->get(v1)) {
711
+ // continue;
712
+ // }
713
+ //
714
+ // vt->set(v1);
715
+ //
716
+ // float d1 = qdis(v1);
717
+ // ++ndis;
718
+ //
719
+ // if (top_candidates.top().first > d1 ||
720
+ // top_candidates.size() < ef) {
721
+ // candidates.emplace(d1, v1);
722
+ // top_candidates.emplace(d1, v1);
723
+ //
724
+ // if (top_candidates.size() > ef) {
725
+ // top_candidates.pop();
726
+ // }
727
+ // }
728
+ // }
729
+
730
+ // the following version processes 4 neighbors at a time
731
+ size_t jmax = begin;
732
+ for (size_t j = begin; j < end; j++) {
634
733
  int v1 = hnsw.neighbors[j];
635
-
636
- if (v1 < 0) {
734
+ if (v1 < 0)
637
735
  break;
638
- }
639
- if (vt->get(v1)) {
640
- continue;
641
- }
642
736
 
643
- vt->set(v1);
737
+ prefetch_L2(vt->visited.data() + v1);
738
+ jmax += 1;
739
+ }
740
+
741
+ int counter = 0;
742
+ size_t saved_j[4];
644
743
 
645
- float d1 = qdis(v1);
646
- ++ndis;
744
+ ndis += jmax - begin;
647
745
 
648
- if (top_candidates.top().first > d1 || top_candidates.size() < ef) {
649
- candidates.emplace(d1, v1);
650
- top_candidates.emplace(d1, v1);
746
+ auto add_to_heap = [&](const size_t idx, const float dis) {
747
+ if (top_candidates.top().first > dis ||
748
+ top_candidates.size() < ef) {
749
+ candidates.emplace(dis, idx);
750
+ top_candidates.emplace(dis, idx);
651
751
 
652
752
  if (top_candidates.size() > ef) {
653
753
  top_candidates.pop();
654
754
  }
655
755
  }
756
+ };
757
+
758
+ for (size_t j = begin; j < jmax; j++) {
759
+ int v1 = hnsw.neighbors[j];
760
+
761
+ bool vget = vt->get(v1);
762
+ vt->set(v1);
763
+ saved_j[counter] = v1;
764
+ counter += vget ? 0 : 1;
765
+
766
+ if (counter == 4) {
767
+ float dis[4];
768
+ qdis.distances_batch_4(
769
+ saved_j[0],
770
+ saved_j[1],
771
+ saved_j[2],
772
+ saved_j[3],
773
+ dis[0],
774
+ dis[1],
775
+ dis[2],
776
+ dis[3]);
777
+
778
+ for (size_t id4 = 0; id4 < 4; id4++) {
779
+ add_to_heap(saved_j[id4], dis[id4]);
780
+ }
781
+
782
+ counter = 0;
783
+ }
784
+ }
785
+
786
+ for (size_t icnt = 0; icnt < counter; icnt++) {
787
+ float dis = qdis(saved_j[icnt]);
788
+ add_to_heap(saved_j[icnt], dis);
656
789
  }
657
790
  }
658
791
 
@@ -665,19 +798,28 @@ std::priority_queue<HNSW::Node> search_from_candidate_unbounded(
665
798
  return top_candidates;
666
799
  }
667
800
 
801
+ // just used as a lower bound for the minmaxheap, but it is set for heap search
802
+ int extract_k_from_ResultHandler(ResultHandler<C>& res) {
803
+ using RH = HeapBlockResultHandler<C>;
804
+ if (auto hres = dynamic_cast<RH::SingleResultHandler*>(&res)) {
805
+ return hres->k;
806
+ }
807
+ return 1;
808
+ }
809
+
668
810
  } // anonymous namespace
669
811
 
670
812
  HNSWStats HNSW::search(
671
813
  DistanceComputer& qdis,
672
- int k,
673
- idx_t* I,
674
- float* D,
814
+ ResultHandler<C>& res,
675
815
  VisitedTable& vt,
676
816
  const SearchParametersHNSW* params) const {
677
817
  HNSWStats stats;
678
818
  if (entry_point == -1) {
679
819
  return stats;
680
820
  }
821
+ int k = extract_k_from_ResultHandler(res);
822
+
681
823
  if (upper_beam == 1) {
682
824
  // greedy search on upper levels
683
825
  storage_idx_t nearest = entry_point;
@@ -687,14 +829,14 @@ HNSWStats HNSW::search(
687
829
  greedy_update_nearest(*this, qdis, level, nearest, d_nearest);
688
830
  }
689
831
 
690
- int ef = std::max(efSearch, k);
832
+ int ef = std::max(params ? params->efSearch : efSearch, k);
691
833
  if (search_bounded_queue) { // this is the most common branch
692
834
  MinimaxHeap candidates(ef);
693
835
 
694
836
  candidates.push(nearest, d_nearest);
695
837
 
696
838
  search_from_candidates(
697
- *this, qdis, k, I, D, candidates, vt, stats, 0, 0, params);
839
+ *this, qdis, res, candidates, vt, stats, 0, 0, params);
698
840
  } else {
699
841
  std::priority_queue<Node> top_candidates =
700
842
  search_from_candidate_unbounded(
@@ -709,12 +851,11 @@ HNSWStats HNSW::search(
709
851
  top_candidates.pop();
710
852
  }
711
853
 
712
- int nres = 0;
713
854
  while (!top_candidates.empty()) {
714
855
  float d;
715
856
  storage_idx_t label;
716
857
  std::tie(d, label) = top_candidates.top();
717
- faiss::maxheap_push(++nres, D, I, d, label);
858
+ res.add_result(d, label);
718
859
  top_candidates.pop();
719
860
  }
720
861
  }
@@ -728,6 +869,10 @@ HNSWStats HNSW::search(
728
869
  std::vector<idx_t> I_to_next(candidates_size);
729
870
  std::vector<float> D_to_next(candidates_size);
730
871
 
872
+ HeapBlockResultHandler<C> block_resh(
873
+ 1, D_to_next.data(), I_to_next.data(), candidates_size);
874
+ HeapBlockResultHandler<C>::SingleResultHandler resh(block_resh);
875
+
731
876
  int nres = 1;
732
877
  I_to_next[0] = entry_point;
733
878
  D_to_next[0] = qdis(entry_point);
@@ -743,18 +888,12 @@ HNSWStats HNSW::search(
743
888
 
744
889
  if (level == 0) {
745
890
  nres = search_from_candidates(
746
- *this, qdis, k, I, D, candidates, vt, stats, 0);
891
+ *this, qdis, res, candidates, vt, stats, 0);
747
892
  } else {
893
+ resh.begin(0);
748
894
  nres = search_from_candidates(
749
- *this,
750
- qdis,
751
- candidates_size,
752
- I_to_next.data(),
753
- D_to_next.data(),
754
- candidates,
755
- vt,
756
- stats,
757
- level);
895
+ *this, qdis, resh, candidates, vt, stats, level);
896
+ resh.end();
758
897
  }
759
898
  vt.advance();
760
899
  }
@@ -765,9 +904,7 @@ HNSWStats HNSW::search(
765
904
 
766
905
  void HNSW::search_level_0(
767
906
  DistanceComputer& qdis,
768
- int k,
769
- idx_t* idxi,
770
- float* simi,
907
+ ResultHandler<C>& res,
771
908
  idx_t nprobe,
772
909
  const storage_idx_t* nearest_i,
773
910
  const float* nearest_d,
@@ -775,7 +912,7 @@ void HNSW::search_level_0(
775
912
  HNSWStats& search_stats,
776
913
  VisitedTable& vt) const {
777
914
  const HNSW& hnsw = *this;
778
-
915
+ int k = extract_k_from_ResultHandler(res);
779
916
  if (search_type == 1) {
780
917
  int nres = 0;
781
918
 
@@ -788,22 +925,13 @@ void HNSW::search_level_0(
788
925
  if (vt.get(cj))
789
926
  continue;
790
927
 
791
- int candidates_size = std::max(hnsw.efSearch, int(k));
928
+ int candidates_size = std::max(hnsw.efSearch, k);
792
929
  MinimaxHeap candidates(candidates_size);
793
930
 
794
931
  candidates.push(cj, nearest_d[j]);
795
932
 
796
933
  nres = search_from_candidates(
797
- hnsw,
798
- qdis,
799
- k,
800
- idxi,
801
- simi,
802
- candidates,
803
- vt,
804
- search_stats,
805
- 0,
806
- nres);
934
+ hnsw, qdis, res, candidates, vt, search_stats, 0, nres);
807
935
  }
808
936
  } else if (search_type == 2) {
809
937
  int candidates_size = std::max(hnsw.efSearch, int(k));
@@ -819,10 +947,43 @@ void HNSW::search_level_0(
819
947
  }
820
948
 
821
949
  search_from_candidates(
822
- hnsw, qdis, k, idxi, simi, candidates, vt, search_stats, 0);
950
+ hnsw, qdis, res, candidates, vt, search_stats, 0);
823
951
  }
824
952
  }
825
953
 
954
+ void HNSW::permute_entries(const idx_t* map) {
955
+ // remap levels
956
+ storage_idx_t ntotal = levels.size();
957
+ std::vector<storage_idx_t> imap(ntotal); // inverse mapping
958
+ // map: new index -> old index
959
+ // imap: old index -> new index
960
+ for (int i = 0; i < ntotal; i++) {
961
+ assert(map[i] >= 0 && map[i] < ntotal);
962
+ imap[map[i]] = i;
963
+ }
964
+ if (entry_point != -1) {
965
+ entry_point = imap[entry_point];
966
+ }
967
+ std::vector<int> new_levels(ntotal);
968
+ std::vector<size_t> new_offsets(ntotal + 1);
969
+ std::vector<storage_idx_t> new_neighbors(neighbors.size());
970
+ size_t no = 0;
971
+ for (int i = 0; i < ntotal; i++) {
972
+ storage_idx_t o = map[i]; // corresponding "old" index
973
+ new_levels[i] = levels[o];
974
+ for (size_t j = offsets[o]; j < offsets[o + 1]; j++) {
975
+ storage_idx_t neigh = neighbors[j];
976
+ new_neighbors[no++] = neigh >= 0 ? imap[neigh] : neigh;
977
+ }
978
+ new_offsets[i + 1] = no;
979
+ }
980
+ assert(new_offsets[ntotal] == offsets[ntotal]);
981
+ // swap everyone
982
+ std::swap(levels, new_levels);
983
+ std::swap(offsets, new_offsets);
984
+ std::swap(neighbors, new_neighbors);
985
+ }
986
+
826
987
  /**************************************************************
827
988
  * MinimaxHeap
828
989
  **************************************************************/
@@ -852,17 +1013,105 @@ void HNSW::MinimaxHeap::clear() {
852
1013
  nvalid = k = 0;
853
1014
  }
854
1015
 
1016
+ #ifdef __AVX2__
1017
+ int HNSW::MinimaxHeap::pop_min(float* vmin_out) {
1018
+ assert(k > 0);
1019
+ static_assert(
1020
+ std::is_same<storage_idx_t, int32_t>::value,
1021
+ "This code expects storage_idx_t to be int32_t");
1022
+
1023
+ int32_t min_idx = -1;
1024
+ float min_dis = std::numeric_limits<float>::infinity();
1025
+
1026
+ size_t iii = 0;
1027
+
1028
+ __m256i min_indices = _mm256_setr_epi32(-1, -1, -1, -1, -1, -1, -1, -1);
1029
+ __m256 min_distances =
1030
+ _mm256_set1_ps(std::numeric_limits<float>::infinity());
1031
+ __m256i current_indices = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
1032
+ __m256i offset = _mm256_set1_epi32(8);
1033
+
1034
+ // The baseline version is available in non-AVX2 branch.
1035
+
1036
+ // The following loop tracks the rightmost index with the min distance.
1037
+ // -1 index values are ignored.
1038
+ const int k8 = (k / 8) * 8;
1039
+ for (; iii < k8; iii += 8) {
1040
+ __m256i indices =
1041
+ _mm256_loadu_si256((const __m256i*)(ids.data() + iii));
1042
+ __m256 distances = _mm256_loadu_ps(dis.data() + iii);
1043
+
1044
+ // This mask filters out -1 values among indices.
1045
+ __m256i m1mask = _mm256_cmpgt_epi32(_mm256_setzero_si256(), indices);
1046
+
1047
+ __m256i dmask = _mm256_castps_si256(
1048
+ _mm256_cmp_ps(min_distances, distances, _CMP_LT_OS));
1049
+ __m256 finalmask = _mm256_castsi256_ps(_mm256_or_si256(m1mask, dmask));
1050
+
1051
+ const __m256i min_indices_new = _mm256_castps_si256(_mm256_blendv_ps(
1052
+ _mm256_castsi256_ps(current_indices),
1053
+ _mm256_castsi256_ps(min_indices),
1054
+ finalmask));
1055
+
1056
+ const __m256 min_distances_new =
1057
+ _mm256_blendv_ps(distances, min_distances, finalmask);
1058
+
1059
+ min_indices = min_indices_new;
1060
+ min_distances = min_distances_new;
1061
+
1062
+ current_indices = _mm256_add_epi32(current_indices, offset);
1063
+ }
1064
+
1065
+ // Vectorizing is doable, but is not practical
1066
+ int32_t vidx8[8];
1067
+ float vdis8[8];
1068
+ _mm256_storeu_ps(vdis8, min_distances);
1069
+ _mm256_storeu_si256((__m256i*)vidx8, min_indices);
1070
+
1071
+ for (size_t j = 0; j < 8; j++) {
1072
+ if (min_dis > vdis8[j] || (min_dis == vdis8[j] && min_idx < vidx8[j])) {
1073
+ min_idx = vidx8[j];
1074
+ min_dis = vdis8[j];
1075
+ }
1076
+ }
1077
+
1078
+ // process last values. Vectorizing is doable, but is not practical
1079
+ for (; iii < k; iii++) {
1080
+ if (ids[iii] != -1 && dis[iii] <= min_dis) {
1081
+ min_dis = dis[iii];
1082
+ min_idx = iii;
1083
+ }
1084
+ }
1085
+
1086
+ if (min_idx == -1) {
1087
+ return -1;
1088
+ }
1089
+
1090
+ if (vmin_out) {
1091
+ *vmin_out = min_dis;
1092
+ }
1093
+ int ret = ids[min_idx];
1094
+ ids[min_idx] = -1;
1095
+ --nvalid;
1096
+ return ret;
1097
+ }
1098
+
1099
+ #else
1100
+
1101
+ // baseline non-vectorized version
855
1102
  int HNSW::MinimaxHeap::pop_min(float* vmin_out) {
856
1103
  assert(k > 0);
857
1104
  // returns min. This is an O(n) operation
858
1105
  int i = k - 1;
859
1106
  while (i >= 0) {
860
- if (ids[i] != -1)
1107
+ if (ids[i] != -1) {
861
1108
  break;
1109
+ }
862
1110
  i--;
863
1111
  }
864
- if (i == -1)
1112
+ if (i == -1) {
865
1113
  return -1;
1114
+ }
866
1115
  int imin = i;
867
1116
  float vmin = dis[i];
868
1117
  i--;
@@ -873,14 +1122,16 @@ int HNSW::MinimaxHeap::pop_min(float* vmin_out) {
873
1122
  }
874
1123
  i--;
875
1124
  }
876
- if (vmin_out)
1125
+ if (vmin_out) {
877
1126
  *vmin_out = vmin;
1127
+ }
878
1128
  int ret = ids[imin];
879
1129
  ids[imin] = -1;
880
1130
  --nvalid;
881
1131
 
882
1132
  return ret;
883
1133
  }
1134
+ #endif
884
1135
 
885
1136
  int HNSW::MinimaxHeap::count_below(float thresh) {
886
1137
  int n_below = 0;
@@ -5,8 +5,6 @@
5
5
  * LICENSE file in the root directory of this source tree.
6
6
  */
7
7
 
8
- // -*- c++ -*-
9
-
10
8
  #pragma once
11
9
 
12
10
  #include <queue>
@@ -42,6 +40,8 @@ namespace faiss {
42
40
  struct VisitedTable;
43
41
  struct DistanceComputer; // from AuxIndexStructures
44
42
  struct HNSWStats;
43
+ template <class C>
44
+ struct ResultHandler;
45
45
 
46
46
  struct SearchParametersHNSW : SearchParameters {
47
47
  int efSearch = 16;
@@ -54,6 +54,9 @@ struct HNSW {
54
54
  /// internal storage of vectors (32 bits: this is expensive)
55
55
  using storage_idx_t = int32_t;
56
56
 
57
+ // for now we do only these distances
58
+ using C = CMax<float, int64_t>;
59
+
57
60
  typedef std::pair<float, storage_idx_t> Node;
58
61
 
59
62
  /** Heap structure that allows fast
@@ -195,18 +198,14 @@ struct HNSW {
195
198
  /// search interface for 1 point, single thread
196
199
  HNSWStats search(
197
200
  DistanceComputer& qdis,
198
- int k,
199
- idx_t* I,
200
- float* D,
201
+ ResultHandler<C>& res,
201
202
  VisitedTable& vt,
202
203
  const SearchParametersHNSW* params = nullptr) const;
203
204
 
204
205
  /// search only in level 0 from a given vertex
205
206
  void search_level_0(
206
207
  DistanceComputer& qdis,
207
- int k,
208
- idx_t* idxi,
209
- float* simi,
208
+ ResultHandler<C>& res,
210
209
  idx_t nprobe,
211
210
  const storage_idx_t* nearest_i,
212
211
  const float* nearest_d,
@@ -226,6 +225,8 @@ struct HNSW {
226
225
  std::priority_queue<NodeDistFarther>& input,
227
226
  std::vector<NodeDistFarther>& output,
228
227
  int max_size);
228
+
229
+ void permute_entries(const idx_t* map);
229
230
  };
230
231
 
231
232
  struct HNSWStats {
@@ -10,7 +10,7 @@
10
10
  #include <unordered_set>
11
11
  #include <vector>
12
12
 
13
- #include <faiss/Index.h>
13
+ #include <faiss/MetricType.h>
14
14
 
15
15
  /** IDSelector is intended to define a subset of vectors to handle (for removal
16
16
  * or as subset to search) */
@@ -140,7 +140,7 @@ struct IDSelectorAnd : IDSelector {
140
140
  : lhs(lhs), rhs(rhs) {}
141
141
  bool is_member(idx_t id) const final {
142
142
  return lhs->is_member(id) && rhs->is_member(id);
143
- };
143
+ }
144
144
  virtual ~IDSelectorAnd() {}
145
145
  };
146
146
 
@@ -153,7 +153,7 @@ struct IDSelectorOr : IDSelector {
153
153
  : lhs(lhs), rhs(rhs) {}
154
154
  bool is_member(idx_t id) const final {
155
155
  return lhs->is_member(id) || rhs->is_member(id);
156
- };
156
+ }
157
157
  virtual ~IDSelectorOr() {}
158
158
  };
159
159
 
@@ -166,7 +166,7 @@ struct IDSelectorXOr : IDSelector {
166
166
  : lhs(lhs), rhs(rhs) {}
167
167
  bool is_member(idx_t id) const final {
168
168
  return lhs->is_member(id) ^ rhs->is_member(id);
169
- };
169
+ }
170
170
  virtual ~IDSelectorXOr() {}
171
171
  };
172
172
 
@@ -628,7 +628,9 @@ void LocalSearchQuantizer::icm_encode_step(
628
628
  {
629
629
  size_t binary_idx = (other_m + 1) * M * K * K +
630
630
  m * K * K + code2 * K + code;
631
- _mm_prefetch(binaries + binary_idx, _MM_HINT_T0);
631
+ _mm_prefetch(
632
+ (const char*)(binaries + binary_idx),
633
+ _MM_HINT_T0);
632
634
  }
633
635
  }
634
636
  #endif