faiss 0.2.4 → 0.2.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (177) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +5 -0
  3. data/README.md +23 -21
  4. data/ext/faiss/extconf.rb +11 -0
  5. data/ext/faiss/index.cpp +4 -4
  6. data/ext/faiss/index_binary.cpp +6 -6
  7. data/ext/faiss/product_quantizer.cpp +4 -4
  8. data/lib/faiss/version.rb +1 -1
  9. data/vendor/faiss/faiss/AutoTune.cpp +13 -0
  10. data/vendor/faiss/faiss/IVFlib.cpp +101 -2
  11. data/vendor/faiss/faiss/IVFlib.h +26 -2
  12. data/vendor/faiss/faiss/Index.cpp +36 -3
  13. data/vendor/faiss/faiss/Index.h +43 -6
  14. data/vendor/faiss/faiss/Index2Layer.cpp +6 -2
  15. data/vendor/faiss/faiss/Index2Layer.h +6 -1
  16. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +219 -16
  17. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +63 -5
  18. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +299 -0
  19. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +199 -0
  20. data/vendor/faiss/faiss/IndexBinary.cpp +20 -4
  21. data/vendor/faiss/faiss/IndexBinary.h +18 -3
  22. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +9 -2
  23. data/vendor/faiss/faiss/IndexBinaryFlat.h +4 -2
  24. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +4 -1
  25. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +2 -1
  26. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +5 -1
  27. data/vendor/faiss/faiss/IndexBinaryHNSW.h +2 -1
  28. data/vendor/faiss/faiss/IndexBinaryHash.cpp +17 -4
  29. data/vendor/faiss/faiss/IndexBinaryHash.h +8 -4
  30. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +28 -13
  31. data/vendor/faiss/faiss/IndexBinaryIVF.h +10 -7
  32. data/vendor/faiss/faiss/IndexFastScan.cpp +626 -0
  33. data/vendor/faiss/faiss/IndexFastScan.h +145 -0
  34. data/vendor/faiss/faiss/IndexFlat.cpp +34 -21
  35. data/vendor/faiss/faiss/IndexFlat.h +7 -4
  36. data/vendor/faiss/faiss/IndexFlatCodes.cpp +35 -1
  37. data/vendor/faiss/faiss/IndexFlatCodes.h +12 -0
  38. data/vendor/faiss/faiss/IndexHNSW.cpp +66 -138
  39. data/vendor/faiss/faiss/IndexHNSW.h +4 -2
  40. data/vendor/faiss/faiss/IndexIDMap.cpp +247 -0
  41. data/vendor/faiss/faiss/IndexIDMap.h +107 -0
  42. data/vendor/faiss/faiss/IndexIVF.cpp +121 -33
  43. data/vendor/faiss/faiss/IndexIVF.h +35 -16
  44. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +84 -7
  45. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +63 -1
  46. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +590 -0
  47. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +171 -0
  48. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +1290 -0
  49. data/vendor/faiss/faiss/IndexIVFFastScan.h +213 -0
  50. data/vendor/faiss/faiss/IndexIVFFlat.cpp +37 -17
  51. data/vendor/faiss/faiss/IndexIVFFlat.h +4 -2
  52. data/vendor/faiss/faiss/IndexIVFPQ.cpp +234 -50
  53. data/vendor/faiss/faiss/IndexIVFPQ.h +5 -1
  54. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +23 -852
  55. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +7 -112
  56. data/vendor/faiss/faiss/IndexIVFPQR.cpp +3 -3
  57. data/vendor/faiss/faiss/IndexIVFPQR.h +1 -1
  58. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +3 -1
  59. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +2 -1
  60. data/vendor/faiss/faiss/IndexLSH.cpp +4 -2
  61. data/vendor/faiss/faiss/IndexLSH.h +2 -1
  62. data/vendor/faiss/faiss/IndexLattice.cpp +7 -1
  63. data/vendor/faiss/faiss/IndexLattice.h +3 -1
  64. data/vendor/faiss/faiss/IndexNNDescent.cpp +4 -3
  65. data/vendor/faiss/faiss/IndexNNDescent.h +2 -1
  66. data/vendor/faiss/faiss/IndexNSG.cpp +37 -3
  67. data/vendor/faiss/faiss/IndexNSG.h +25 -1
  68. data/vendor/faiss/faiss/IndexPQ.cpp +106 -69
  69. data/vendor/faiss/faiss/IndexPQ.h +19 -5
  70. data/vendor/faiss/faiss/IndexPQFastScan.cpp +15 -450
  71. data/vendor/faiss/faiss/IndexPQFastScan.h +15 -78
  72. data/vendor/faiss/faiss/IndexPreTransform.cpp +47 -8
  73. data/vendor/faiss/faiss/IndexPreTransform.h +15 -3
  74. data/vendor/faiss/faiss/IndexRefine.cpp +8 -4
  75. data/vendor/faiss/faiss/IndexRefine.h +4 -2
  76. data/vendor/faiss/faiss/IndexReplicas.cpp +4 -2
  77. data/vendor/faiss/faiss/IndexReplicas.h +2 -1
  78. data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +438 -0
  79. data/vendor/faiss/faiss/IndexRowwiseMinMax.h +92 -0
  80. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +26 -15
  81. data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -7
  82. data/vendor/faiss/faiss/IndexShards.cpp +4 -1
  83. data/vendor/faiss/faiss/IndexShards.h +2 -1
  84. data/vendor/faiss/faiss/MetaIndexes.cpp +5 -178
  85. data/vendor/faiss/faiss/MetaIndexes.h +3 -81
  86. data/vendor/faiss/faiss/VectorTransform.cpp +43 -0
  87. data/vendor/faiss/faiss/VectorTransform.h +22 -4
  88. data/vendor/faiss/faiss/clone_index.cpp +23 -1
  89. data/vendor/faiss/faiss/clone_index.h +3 -0
  90. data/vendor/faiss/faiss/cppcontrib/SaDecodeKernels.h +300 -0
  91. data/vendor/faiss/faiss/cppcontrib/detail/CoarseBitType.h +24 -0
  92. data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +195 -0
  93. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +2058 -0
  94. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +408 -0
  95. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-neon-inl.h +2147 -0
  96. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMax-inl.h +460 -0
  97. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMaxFP16-inl.h +465 -0
  98. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +1618 -0
  99. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +251 -0
  100. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-neon-inl.h +1452 -0
  101. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +1 -0
  102. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +0 -4
  103. data/vendor/faiss/faiss/gpu/GpuIndex.h +28 -4
  104. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +2 -1
  105. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +10 -8
  106. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +75 -14
  107. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +19 -32
  108. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -31
  109. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +22 -28
  110. data/vendor/faiss/faiss/gpu/GpuResources.cpp +14 -0
  111. data/vendor/faiss/faiss/gpu/GpuResources.h +16 -3
  112. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +3 -3
  113. data/vendor/faiss/faiss/gpu/impl/IndexUtils.h +32 -0
  114. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +1 -0
  115. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +311 -75
  116. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +10 -0
  117. data/vendor/faiss/faiss/gpu/test/TestUtils.h +3 -0
  118. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +2 -2
  119. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +5 -4
  120. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +116 -47
  121. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +44 -13
  122. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +0 -54
  123. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +0 -76
  124. data/vendor/faiss/faiss/impl/DistanceComputer.h +64 -0
  125. data/vendor/faiss/faiss/impl/HNSW.cpp +123 -27
  126. data/vendor/faiss/faiss/impl/HNSW.h +19 -16
  127. data/vendor/faiss/faiss/impl/IDSelector.cpp +125 -0
  128. data/vendor/faiss/faiss/impl/IDSelector.h +135 -0
  129. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +6 -28
  130. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +6 -1
  131. data/vendor/faiss/faiss/impl/LookupTableScaler.h +77 -0
  132. data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -0
  133. data/vendor/faiss/faiss/impl/NSG.cpp +1 -1
  134. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +383 -0
  135. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +154 -0
  136. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +225 -145
  137. data/vendor/faiss/faiss/impl/ProductQuantizer.h +29 -10
  138. data/vendor/faiss/faiss/impl/Quantizer.h +43 -0
  139. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +192 -36
  140. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +40 -20
  141. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -0
  142. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +97 -173
  143. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +18 -18
  144. data/vendor/faiss/faiss/impl/index_read.cpp +240 -9
  145. data/vendor/faiss/faiss/impl/index_write.cpp +237 -5
  146. data/vendor/faiss/faiss/impl/kmeans1d.cpp +6 -4
  147. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +56 -16
  148. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +25 -8
  149. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +66 -25
  150. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +75 -27
  151. data/vendor/faiss/faiss/index_factory.cpp +196 -7
  152. data/vendor/faiss/faiss/index_io.h +5 -0
  153. data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -0
  154. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +4 -1
  155. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +2 -1
  156. data/vendor/faiss/faiss/python/python_callbacks.cpp +27 -0
  157. data/vendor/faiss/faiss/python/python_callbacks.h +15 -0
  158. data/vendor/faiss/faiss/utils/Heap.h +31 -15
  159. data/vendor/faiss/faiss/utils/distances.cpp +380 -56
  160. data/vendor/faiss/faiss/utils/distances.h +113 -15
  161. data/vendor/faiss/faiss/utils/distances_simd.cpp +726 -6
  162. data/vendor/faiss/faiss/utils/extra_distances.cpp +12 -7
  163. data/vendor/faiss/faiss/utils/extra_distances.h +3 -1
  164. data/vendor/faiss/faiss/utils/fp16-fp16c.h +21 -0
  165. data/vendor/faiss/faiss/utils/fp16-inl.h +101 -0
  166. data/vendor/faiss/faiss/utils/fp16.h +11 -0
  167. data/vendor/faiss/faiss/utils/hamming-inl.h +54 -0
  168. data/vendor/faiss/faiss/utils/hamming.cpp +0 -48
  169. data/vendor/faiss/faiss/utils/ordered_key_value.h +10 -0
  170. data/vendor/faiss/faiss/utils/quantize_lut.cpp +62 -0
  171. data/vendor/faiss/faiss/utils/quantize_lut.h +20 -0
  172. data/vendor/faiss/faiss/utils/random.cpp +53 -0
  173. data/vendor/faiss/faiss/utils/random.h +5 -0
  174. data/vendor/faiss/faiss/utils/simdlib_avx2.h +4 -0
  175. data/vendor/faiss/faiss/utils/simdlib_emulated.h +6 -1
  176. data/vendor/faiss/faiss/utils/simdlib_neon.h +7 -2
  177. metadata +37 -3
@@ -12,6 +12,8 @@
12
12
  #include <string>
13
13
 
14
14
  #include <faiss/impl/AuxIndexStructures.h>
15
+ #include <faiss/impl/DistanceComputer.h>
16
+ #include <faiss/impl/IDSelector.h>
15
17
 
16
18
  namespace faiss {
17
19
 
@@ -501,9 +503,19 @@ void HNSW::add_with_locks(
501
503
  }
502
504
  }
503
505
 
506
+ /**************************************************************
507
+ * Searching
508
+ **************************************************************/
509
+
510
+ namespace {
511
+
512
+ using idx_t = HNSW::idx_t;
513
+ using MinimaxHeap = HNSW::MinimaxHeap;
514
+ using Node = HNSW::Node;
504
515
  /** Do a BFS on the candidates list */
505
516
 
506
- int HNSW::search_from_candidates(
517
+ int search_from_candidates(
518
+ const HNSW& hnsw,
507
519
  DistanceComputer& qdis,
508
520
  int k,
509
521
  idx_t* I,
@@ -512,22 +524,31 @@ int HNSW::search_from_candidates(
512
524
  VisitedTable& vt,
513
525
  HNSWStats& stats,
514
526
  int level,
515
- int nres_in) const {
527
+ int nres_in = 0,
528
+ const SearchParametersHNSW* params = nullptr) {
516
529
  int nres = nres_in;
517
530
  int ndis = 0;
531
+
532
+ // can be overridden by search params
533
+ bool do_dis_check = params ? params->check_relative_distance
534
+ : hnsw.check_relative_distance;
535
+ int efSearch = params ? params->efSearch : hnsw.efSearch;
536
+ const IDSelector* sel = params ? params->sel : nullptr;
537
+
518
538
  for (int i = 0; i < candidates.size(); i++) {
519
539
  idx_t v1 = candidates.ids[i];
520
540
  float d = candidates.dis[i];
521
541
  FAISS_ASSERT(v1 >= 0);
522
- if (nres < k) {
523
- faiss::maxheap_push(++nres, D, I, d, v1);
524
- } else if (d < D[0]) {
525
- faiss::maxheap_replace_top(nres, D, I, d, v1);
542
+ if (!sel || sel->is_member(v1)) {
543
+ if (nres < k) {
544
+ faiss::maxheap_push(++nres, D, I, d, v1);
545
+ } else if (d < D[0]) {
546
+ faiss::maxheap_replace_top(nres, D, I, d, v1);
547
+ }
526
548
  }
527
549
  vt.set(v1);
528
550
  }
529
551
 
530
- bool do_dis_check = check_relative_distance;
531
552
  int nstep = 0;
532
553
 
533
554
  while (candidates.size() > 0) {
@@ -546,10 +567,10 @@ int HNSW::search_from_candidates(
546
567
  }
547
568
 
548
569
  size_t begin, end;
549
- neighbor_range(v0, level, &begin, &end);
570
+ hnsw.neighbor_range(v0, level, &begin, &end);
550
571
 
551
572
  for (size_t j = begin; j < end; j++) {
552
- int v1 = neighbors[j];
573
+ int v1 = hnsw.neighbors[j];
553
574
  if (v1 < 0)
554
575
  break;
555
576
  if (vt.get(v1)) {
@@ -558,10 +579,12 @@ int HNSW::search_from_candidates(
558
579
  vt.set(v1);
559
580
  ndis++;
560
581
  float d = qdis(v1);
561
- if (nres < k) {
562
- faiss::maxheap_push(++nres, D, I, d, v1);
563
- } else if (d < D[0]) {
564
- faiss::maxheap_replace_top(nres, D, I, d, v1);
582
+ if (!sel || sel->is_member(v1)) {
583
+ if (nres < k) {
584
+ faiss::maxheap_push(++nres, D, I, d, v1);
585
+ } else if (d < D[0]) {
586
+ faiss::maxheap_replace_top(nres, D, I, d, v1);
587
+ }
565
588
  }
566
589
  candidates.push(v1, d);
567
590
  }
@@ -583,16 +606,13 @@ int HNSW::search_from_candidates(
583
606
  return nres;
584
607
  }
585
608
 
586
- /**************************************************************
587
- * Searching
588
- **************************************************************/
589
-
590
- std::priority_queue<HNSW::Node> HNSW::search_from_candidate_unbounded(
609
+ std::priority_queue<HNSW::Node> search_from_candidate_unbounded(
610
+ const HNSW& hnsw,
591
611
  const Node& node,
592
612
  DistanceComputer& qdis,
593
613
  int ef,
594
614
  VisitedTable* vt,
595
- HNSWStats& stats) const {
615
+ HNSWStats& stats) {
596
616
  int ndis = 0;
597
617
  std::priority_queue<Node> top_candidates;
598
618
  std::priority_queue<Node, std::vector<Node>, std::greater<Node>> candidates;
@@ -614,10 +634,10 @@ std::priority_queue<HNSW::Node> HNSW::search_from_candidate_unbounded(
614
634
  candidates.pop();
615
635
 
616
636
  size_t begin, end;
617
- neighbor_range(v0, 0, &begin, &end);
637
+ hnsw.neighbor_range(v0, 0, &begin, &end);
618
638
 
619
639
  for (size_t j = begin; j < end; ++j) {
620
- int v1 = neighbors[j];
640
+ int v1 = hnsw.neighbors[j];
621
641
 
622
642
  if (v1 < 0) {
623
643
  break;
@@ -651,14 +671,19 @@ std::priority_queue<HNSW::Node> HNSW::search_from_candidate_unbounded(
651
671
  return top_candidates;
652
672
  }
653
673
 
674
+ } // anonymous namespace
675
+
654
676
  HNSWStats HNSW::search(
655
677
  DistanceComputer& qdis,
656
678
  int k,
657
679
  idx_t* I,
658
680
  float* D,
659
- VisitedTable& vt) const {
681
+ VisitedTable& vt,
682
+ const SearchParametersHNSW* params) const {
660
683
  HNSWStats stats;
661
-
684
+ if (entry_point == -1) {
685
+ return stats;
686
+ }
662
687
  if (upper_beam == 1) {
663
688
  // greedy search on upper levels
664
689
  storage_idx_t nearest = entry_point;
@@ -669,16 +694,22 @@ HNSWStats HNSW::search(
669
694
  }
670
695
 
671
696
  int ef = std::max(efSearch, k);
672
- if (search_bounded_queue) {
697
+ if (search_bounded_queue) { // this is the most common branch
673
698
  MinimaxHeap candidates(ef);
674
699
 
675
700
  candidates.push(nearest, d_nearest);
676
701
 
677
- search_from_candidates(qdis, k, I, D, candidates, vt, stats, 0);
702
+ search_from_candidates(
703
+ *this, qdis, k, I, D, candidates, vt, stats, 0, 0, params);
678
704
  } else {
679
705
  std::priority_queue<Node> top_candidates =
680
706
  search_from_candidate_unbounded(
681
- Node(d_nearest, nearest), qdis, ef, &vt, stats);
707
+ *this,
708
+ Node(d_nearest, nearest),
709
+ qdis,
710
+ ef,
711
+ &vt,
712
+ stats);
682
713
 
683
714
  while (top_candidates.size() > k) {
684
715
  top_candidates.pop();
@@ -718,9 +749,10 @@ HNSWStats HNSW::search(
718
749
 
719
750
  if (level == 0) {
720
751
  nres = search_from_candidates(
721
- qdis, k, I, D, candidates, vt, stats, 0);
752
+ *this, qdis, k, I, D, candidates, vt, stats, 0);
722
753
  } else {
723
754
  nres = search_from_candidates(
755
+ *this,
724
756
  qdis,
725
757
  candidates_size,
726
758
  I_to_next.data(),
@@ -737,6 +769,70 @@ HNSWStats HNSW::search(
737
769
  return stats;
738
770
  }
739
771
 
772
+ void HNSW::search_level_0(
773
+ DistanceComputer& qdis,
774
+ int k,
775
+ idx_t* idxi,
776
+ float* simi,
777
+ idx_t nprobe,
778
+ const storage_idx_t* nearest_i,
779
+ const float* nearest_d,
780
+ int search_type,
781
+ HNSWStats& search_stats,
782
+ VisitedTable& vt) const {
783
+ const HNSW& hnsw = *this;
784
+
785
+ if (search_type == 1) {
786
+ int nres = 0;
787
+
788
+ for (int j = 0; j < nprobe; j++) {
789
+ storage_idx_t cj = nearest_i[j];
790
+
791
+ if (cj < 0)
792
+ break;
793
+
794
+ if (vt.get(cj))
795
+ continue;
796
+
797
+ int candidates_size = std::max(hnsw.efSearch, int(k));
798
+ MinimaxHeap candidates(candidates_size);
799
+
800
+ candidates.push(cj, nearest_d[j]);
801
+
802
+ nres = search_from_candidates(
803
+ hnsw,
804
+ qdis,
805
+ k,
806
+ idxi,
807
+ simi,
808
+ candidates,
809
+ vt,
810
+ search_stats,
811
+ 0,
812
+ nres);
813
+ }
814
+ } else if (search_type == 2) {
815
+ int candidates_size = std::max(hnsw.efSearch, int(k));
816
+ candidates_size = std::max(candidates_size, int(nprobe));
817
+
818
+ MinimaxHeap candidates(candidates_size);
819
+ for (int j = 0; j < nprobe; j++) {
820
+ storage_idx_t cj = nearest_i[j];
821
+
822
+ if (cj < 0)
823
+ break;
824
+ candidates.push(cj, nearest_d[j]);
825
+ }
826
+
827
+ search_from_candidates(
828
+ hnsw, qdis, k, idxi, simi, candidates, vt, search_stats, 0);
829
+ }
830
+ }
831
+
832
+ /**************************************************************
833
+ * MinimaxHeap
834
+ **************************************************************/
835
+
740
836
  void HNSW::MinimaxHeap::push(storage_idx_t i, float v) {
741
837
  if (k == n) {
742
838
  if (v >= dis[0])
@@ -43,6 +43,13 @@ struct VisitedTable;
43
43
  struct DistanceComputer; // from AuxIndexStructures
44
44
  struct HNSWStats;
45
45
 
46
+ struct SearchParametersHNSW : SearchParameters {
47
+ int efSearch = 16;
48
+ bool check_relative_distance = true;
49
+
50
+ ~SearchParametersHNSW() {}
51
+ };
52
+
46
53
  struct HNSW {
47
54
  /// internal storage of vectors (32 bits: this is expensive)
48
55
  typedef int storage_idx_t;
@@ -188,30 +195,26 @@ struct HNSW {
188
195
  std::vector<omp_lock_t>& locks,
189
196
  VisitedTable& vt);
190
197
 
191
- int search_from_candidates(
198
+ /// search interface for 1 point, single thread
199
+ HNSWStats search(
192
200
  DistanceComputer& qdis,
193
201
  int k,
194
202
  idx_t* I,
195
203
  float* D,
196
- MinimaxHeap& candidates,
197
204
  VisitedTable& vt,
198
- HNSWStats& stats,
199
- int level,
200
- int nres_in = 0) const;
205
+ const SearchParametersHNSW* params = nullptr) const;
201
206
 
202
- std::priority_queue<Node> search_from_candidate_unbounded(
203
- const Node& node,
204
- DistanceComputer& qdis,
205
- int ef,
206
- VisitedTable* vt,
207
- HNSWStats& stats) const;
208
-
209
- /// search interface
210
- HNSWStats search(
207
+ /// search only in level 0 from a given vertex
208
+ void search_level_0(
211
209
  DistanceComputer& qdis,
212
210
  int k,
213
- idx_t* I,
214
- float* D,
211
+ idx_t* idxi,
212
+ float* simi,
213
+ idx_t nprobe,
214
+ const storage_idx_t* nearest_i,
215
+ const float* nearest_d,
216
+ int search_type,
217
+ HNSWStats& search_stats,
215
218
  VisitedTable& vt) const;
216
219
 
217
220
  void reset();
@@ -0,0 +1,125 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ #include <faiss/impl/FaissAssert.h>
9
+ #include <faiss/impl/IDSelector.h>
10
+
11
+ namespace faiss {
12
+
13
+ /***********************************************************************
14
+ * IDSelectorRange
15
+ ***********************************************************************/
16
+
17
+ IDSelectorRange::IDSelectorRange(idx_t imin, idx_t imax, bool assume_sorted)
18
+ : imin(imin), imax(imax), assume_sorted(assume_sorted) {}
19
+
20
+ bool IDSelectorRange::is_member(idx_t id) const {
21
+ return id >= imin && id < imax;
22
+ }
23
+
24
+ void IDSelectorRange::find_sorted_ids_bounds(
25
+ size_t list_size,
26
+ const idx_t* ids,
27
+ size_t* jmin_out,
28
+ size_t* jmax_out) const {
29
+ FAISS_ASSERT(assume_sorted);
30
+ if (list_size == 0 || imax <= ids[0] || imin > ids[list_size - 1]) {
31
+ *jmin_out = *jmax_out = 0;
32
+ return;
33
+ }
34
+ // bisection to find imin
35
+ if (ids[0] >= imin) {
36
+ *jmin_out = 0;
37
+ } else {
38
+ size_t j0 = 0, j1 = list_size;
39
+ while (j1 > j0 + 1) {
40
+ size_t jmed = (j0 + j1) / 2;
41
+ if (ids[jmed] >= imin) {
42
+ j1 = jmed;
43
+ } else {
44
+ j0 = jmed;
45
+ }
46
+ }
47
+ *jmin_out = j1;
48
+ }
49
+ // bisection to find imax
50
+ if (*jmin_out == list_size || ids[*jmin_out] >= imax) {
51
+ *jmax_out = *jmin_out;
52
+ } else {
53
+ size_t j0 = *jmin_out, j1 = list_size;
54
+ while (j1 > j0 + 1) {
55
+ size_t jmed = (j0 + j1) / 2;
56
+ if (ids[jmed] >= imax) {
57
+ j1 = jmed;
58
+ } else {
59
+ j0 = jmed;
60
+ }
61
+ }
62
+ *jmax_out = j1;
63
+ }
64
+ }
65
+
66
+ /***********************************************************************
67
+ * IDSelectorArray
68
+ ***********************************************************************/
69
+
70
+ IDSelectorArray::IDSelectorArray(size_t n, const idx_t* ids) : n(n), ids(ids) {}
71
+
72
+ bool IDSelectorArray::is_member(idx_t id) const {
73
+ for (idx_t i = 0; i < n; i++) {
74
+ if (ids[i] == id)
75
+ return true;
76
+ }
77
+ return false;
78
+ }
79
+
80
+ /***********************************************************************
81
+ * IDSelectorBatch
82
+ ***********************************************************************/
83
+
84
+ IDSelectorBatch::IDSelectorBatch(size_t n, const idx_t* indices) {
85
+ nbits = 0;
86
+ while (n > ((idx_t)1 << nbits)) {
87
+ nbits++;
88
+ }
89
+ nbits += 5;
90
+ // for n = 1M, nbits = 25 is optimal, see P56659518
91
+
92
+ mask = ((idx_t)1 << nbits) - 1;
93
+ bloom.resize((idx_t)1 << (nbits - 3), 0);
94
+ for (idx_t i = 0; i < n; i++) {
95
+ Index::idx_t id = indices[i];
96
+ set.insert(id);
97
+ id &= mask;
98
+ bloom[id >> 3] |= 1 << (id & 7);
99
+ }
100
+ }
101
+
102
+ bool IDSelectorBatch::is_member(idx_t i) const {
103
+ long im = i & mask;
104
+ if (!(bloom[im >> 3] & (1 << (im & 7)))) {
105
+ return 0;
106
+ }
107
+ return set.count(i);
108
+ }
109
+
110
+ /***********************************************************************
111
+ * IDSelectorBitmap
112
+ ***********************************************************************/
113
+
114
+ IDSelectorBitmap::IDSelectorBitmap(size_t n, const uint8_t* bitmap)
115
+ : n(n), bitmap(bitmap) {}
116
+
117
+ bool IDSelectorBitmap::is_member(idx_t ii) const {
118
+ uint64_t i = ii;
119
+ if ((i >> 3) >= n) {
120
+ return false;
121
+ }
122
+ return (bitmap[i >> 3] >> (i & 7)) & 1;
123
+ }
124
+
125
+ } // namespace faiss
@@ -0,0 +1,135 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ #pragma once
9
+
10
+ #include <unordered_set>
11
+ #include <vector>
12
+
13
+ #include <faiss/Index.h>
14
+
15
+ /** IDSelector is intended to define a subset of vectors to handle (for removal
16
+ * or as subset to search) */
17
+
18
+ namespace faiss {
19
+
20
+ /** Encapsulates a set of ids to handle. */
21
+ struct IDSelector {
22
+ using idx_t = Index::idx_t;
23
+ virtual bool is_member(idx_t id) const = 0;
24
+ virtual ~IDSelector() {}
25
+ };
26
+
27
+ /** ids between [imin, imax) */
28
+ struct IDSelectorRange : IDSelector {
29
+ idx_t imin, imax;
30
+
31
+ /// Assume that the ids to handle are sorted. In some cases this can speed
32
+ /// up processing
33
+ bool assume_sorted;
34
+
35
+ IDSelectorRange(idx_t imin, idx_t imax, bool assume_sorted = false);
36
+
37
+ bool is_member(idx_t id) const final;
38
+
39
+ /// for sorted ids, find the range of list indices where the valid ids are
40
+ /// stored
41
+ void find_sorted_ids_bounds(
42
+ size_t list_size,
43
+ const idx_t* ids,
44
+ size_t* jmin,
45
+ size_t* jmax) const;
46
+
47
+ ~IDSelectorRange() override {}
48
+ };
49
+
50
+ /** Simple array of elements
51
+ *
52
+ * is_member calls are very inefficient, but some operations can use the ids
53
+ * directly.
54
+ */
55
+ struct IDSelectorArray : IDSelector {
56
+ size_t n;
57
+ const idx_t* ids;
58
+
59
+ /** Construct with an array of ids to process
60
+ *
61
+ * @param n number of ids to store
62
+ * @param ids elements to store. The pointer should remain valid during
63
+ * IDSelectorArray's lifetime
64
+ */
65
+ IDSelectorArray(size_t n, const idx_t* ids);
66
+ bool is_member(idx_t id) const final;
67
+ ~IDSelectorArray() override {}
68
+ };
69
+
70
+ /** Ids from a set.
71
+ *
72
+ * Repetitions of ids in the indices set passed to the constructor do not hurt
73
+ * performance.
74
+ *
75
+ * The hash function used for the bloom filter and GCC's implementation of
76
+ * unordered_set are just the least significant bits of the id. This works fine
77
+ * for random ids or ids in sequences but will produce many hash collisions if
78
+ * lsb's are always the same
79
+ */
80
+ struct IDSelectorBatch : IDSelector {
81
+ std::unordered_set<idx_t> set;
82
+
83
+ // Bloom filter to avoid accessing the unordered set if it is unlikely
84
+ // to be true
85
+ std::vector<uint8_t> bloom;
86
+ int nbits;
87
+ idx_t mask;
88
+
89
+ /** Construct with an array of ids to process
90
+ *
91
+ * @param n number of ids to store
92
+ * @param indices elements to store. The pointer can be released after
93
+ * construction
94
+ */
95
+ IDSelectorBatch(size_t n, const idx_t* indices);
96
+ bool is_member(idx_t id) const final;
97
+ ~IDSelectorBatch() override {}
98
+ };
99
+
100
+ /** One bit per element. Constructed with a bitmap of n bytes; ids outside [0, 8 * n) are never selected.
101
+ */
102
+ struct IDSelectorBitmap : IDSelector {
103
+ size_t n;
104
+ const uint8_t* bitmap;
105
+
106
+ /** Construct with a binary mask
107
+ *
108
+ * @param n size of the bitmap array
109
+ * @param bitmap id will be selected iff id / 8 < n and bit number
110
+ * (id % 8) of bitmap[floor(id / 8)] is 1.
111
+ */
112
+ IDSelectorBitmap(size_t n, const uint8_t* bitmap);
113
+ bool is_member(idx_t id) const final;
114
+ ~IDSelectorBitmap() override {}
115
+ };
116
+
117
+ /** inverts the membership test of another selector */
118
+ struct IDSelectorNot : IDSelector {
119
+ const IDSelector* sel;
120
+ IDSelectorNot(const IDSelector* sel) : sel(sel) {}
121
+ bool is_member(idx_t id) const final {
122
+ return !sel->is_member(id);
123
+ }
124
+ virtual ~IDSelectorNot() {}
125
+ };
126
+
127
+ /// selects all entries (useful for benchmarking)
128
+ struct IDSelectorAll : IDSelector {
129
+ bool is_member(idx_t id) const final {
130
+ return true;
131
+ }
132
+ virtual ~IDSelectorAll() {}
133
+ };
134
+
135
+ } // namespace faiss
@@ -15,7 +15,6 @@
15
15
 
16
16
  #include <algorithm>
17
17
 
18
- #include <faiss/Clustering.h>
19
18
  #include <faiss/impl/AuxIndexStructures.h>
20
19
  #include <faiss/impl/FaissAssert.h>
21
20
  #include <faiss/utils/distances.h>
@@ -151,9 +150,6 @@ LocalSearchQuantizer::LocalSearchQuantizer(
151
150
  size_t nbits,
152
151
  Search_type_t search_type)
153
152
  : AdditiveQuantizer(d, std::vector<size_t>(M, nbits), search_type) {
154
- is_trained = false;
155
- verbose = false;
156
-
157
153
  K = (1 << nbits);
158
154
 
159
155
  train_iters = 25;
@@ -182,7 +178,7 @@ LocalSearchQuantizer::LocalSearchQuantizer() : LocalSearchQuantizer(0, 0, 0) {}
182
178
 
183
179
  void LocalSearchQuantizer::train(size_t n, const float* x) {
184
180
  FAISS_THROW_IF_NOT(K == (1 << nbits[0]));
185
- FAISS_THROW_IF_NOT(nperts <= M);
181
+ nperts = std::min(nperts, M);
186
182
 
187
183
  lsq_timer.reset();
188
184
  LSQTimerScope scope(&lsq_timer, "train");
@@ -264,26 +260,7 @@ void LocalSearchQuantizer::train(size_t n, const float* x) {
264
260
  decode_unpacked(codes.data(), x_recons.data(), n);
265
261
  fvec_norms_L2sqr(norms.data(), x_recons.data(), d, n);
266
262
 
267
- norm_min = HUGE_VALF;
268
- norm_max = -HUGE_VALF;
269
- for (idx_t i = 0; i < n; i++) {
270
- if (norms[i] < norm_min) {
271
- norm_min = norms[i];
272
- }
273
- if (norms[i] > norm_max) {
274
- norm_max = norms[i];
275
- }
276
- }
277
-
278
- if (search_type == ST_norm_cqint8 || search_type == ST_norm_cqint4) {
279
- size_t k = (1 << 8);
280
- if (search_type == ST_norm_cqint4) {
281
- k = (1 << 4);
282
- }
283
- Clustering1D clus(k);
284
- clus.train_exact(n, norms.data());
285
- qnorm.add(clus.k, clus.centroids.data());
286
- }
263
+ train_norm(n, norms.data());
287
264
  }
288
265
 
289
266
  if (verbose) {
@@ -318,10 +295,11 @@ void LocalSearchQuantizer::perturb_codebooks(
318
295
  }
319
296
  }
320
297
 
321
- void LocalSearchQuantizer::compute_codes(
298
+ void LocalSearchQuantizer::compute_codes_add_centroids(
322
299
  const float* x,
323
300
  uint8_t* codes_out,
324
- size_t n) const {
301
+ size_t n,
302
+ const float* centroids) const {
325
303
  FAISS_THROW_IF_NOT_MSG(is_trained, "LSQ is not trained yet.");
326
304
 
327
305
  lsq_timer.reset();
@@ -335,7 +313,7 @@ void LocalSearchQuantizer::compute_codes(
335
313
  random_int32(codes, 0, K - 1, gen);
336
314
 
337
315
  icm_encode(codes.data(), x, n, encode_ils_iters, gen);
338
- pack_codes(n, codes.data(), codes_out);
316
+ pack_codes(n, codes.data(), codes_out, -1, nullptr, centroids);
339
317
 
340
318
  if (verbose) {
341
319
  scope.finish();
@@ -83,8 +83,13 @@ struct LocalSearchQuantizer : AdditiveQuantizer {
83
83
  * @param x vectors to encode, size n * d
84
84
  * @param codes output codes, size n * code_size
85
85
  * @param n number of vectors
86
+ * @param centroids centroids to be added to x, size n * d
86
87
  */
87
- void compute_codes(const float* x, uint8_t* codes, size_t n) const override;
88
+ void compute_codes_add_centroids(
89
+ const float* x,
90
+ uint8_t* codes,
91
+ size_t n,
92
+ const float* centroids = nullptr) const override;
88
93
 
89
94
  /** Update codebooks given encodings
90
95
  *