faiss 0.2.4 → 0.2.5

Sign up to get free protection for your applications and to get access to all the features.
Files changed (177) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +5 -0
  3. data/README.md +23 -21
  4. data/ext/faiss/extconf.rb +11 -0
  5. data/ext/faiss/index.cpp +4 -4
  6. data/ext/faiss/index_binary.cpp +6 -6
  7. data/ext/faiss/product_quantizer.cpp +4 -4
  8. data/lib/faiss/version.rb +1 -1
  9. data/vendor/faiss/faiss/AutoTune.cpp +13 -0
  10. data/vendor/faiss/faiss/IVFlib.cpp +101 -2
  11. data/vendor/faiss/faiss/IVFlib.h +26 -2
  12. data/vendor/faiss/faiss/Index.cpp +36 -3
  13. data/vendor/faiss/faiss/Index.h +43 -6
  14. data/vendor/faiss/faiss/Index2Layer.cpp +6 -2
  15. data/vendor/faiss/faiss/Index2Layer.h +6 -1
  16. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +219 -16
  17. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +63 -5
  18. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +299 -0
  19. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +199 -0
  20. data/vendor/faiss/faiss/IndexBinary.cpp +20 -4
  21. data/vendor/faiss/faiss/IndexBinary.h +18 -3
  22. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +9 -2
  23. data/vendor/faiss/faiss/IndexBinaryFlat.h +4 -2
  24. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +4 -1
  25. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +2 -1
  26. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +5 -1
  27. data/vendor/faiss/faiss/IndexBinaryHNSW.h +2 -1
  28. data/vendor/faiss/faiss/IndexBinaryHash.cpp +17 -4
  29. data/vendor/faiss/faiss/IndexBinaryHash.h +8 -4
  30. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +28 -13
  31. data/vendor/faiss/faiss/IndexBinaryIVF.h +10 -7
  32. data/vendor/faiss/faiss/IndexFastScan.cpp +626 -0
  33. data/vendor/faiss/faiss/IndexFastScan.h +145 -0
  34. data/vendor/faiss/faiss/IndexFlat.cpp +34 -21
  35. data/vendor/faiss/faiss/IndexFlat.h +7 -4
  36. data/vendor/faiss/faiss/IndexFlatCodes.cpp +35 -1
  37. data/vendor/faiss/faiss/IndexFlatCodes.h +12 -0
  38. data/vendor/faiss/faiss/IndexHNSW.cpp +66 -138
  39. data/vendor/faiss/faiss/IndexHNSW.h +4 -2
  40. data/vendor/faiss/faiss/IndexIDMap.cpp +247 -0
  41. data/vendor/faiss/faiss/IndexIDMap.h +107 -0
  42. data/vendor/faiss/faiss/IndexIVF.cpp +121 -33
  43. data/vendor/faiss/faiss/IndexIVF.h +35 -16
  44. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +84 -7
  45. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +63 -1
  46. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +590 -0
  47. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +171 -0
  48. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +1290 -0
  49. data/vendor/faiss/faiss/IndexIVFFastScan.h +213 -0
  50. data/vendor/faiss/faiss/IndexIVFFlat.cpp +37 -17
  51. data/vendor/faiss/faiss/IndexIVFFlat.h +4 -2
  52. data/vendor/faiss/faiss/IndexIVFPQ.cpp +234 -50
  53. data/vendor/faiss/faiss/IndexIVFPQ.h +5 -1
  54. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +23 -852
  55. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +7 -112
  56. data/vendor/faiss/faiss/IndexIVFPQR.cpp +3 -3
  57. data/vendor/faiss/faiss/IndexIVFPQR.h +1 -1
  58. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +3 -1
  59. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +2 -1
  60. data/vendor/faiss/faiss/IndexLSH.cpp +4 -2
  61. data/vendor/faiss/faiss/IndexLSH.h +2 -1
  62. data/vendor/faiss/faiss/IndexLattice.cpp +7 -1
  63. data/vendor/faiss/faiss/IndexLattice.h +3 -1
  64. data/vendor/faiss/faiss/IndexNNDescent.cpp +4 -3
  65. data/vendor/faiss/faiss/IndexNNDescent.h +2 -1
  66. data/vendor/faiss/faiss/IndexNSG.cpp +37 -3
  67. data/vendor/faiss/faiss/IndexNSG.h +25 -1
  68. data/vendor/faiss/faiss/IndexPQ.cpp +106 -69
  69. data/vendor/faiss/faiss/IndexPQ.h +19 -5
  70. data/vendor/faiss/faiss/IndexPQFastScan.cpp +15 -450
  71. data/vendor/faiss/faiss/IndexPQFastScan.h +15 -78
  72. data/vendor/faiss/faiss/IndexPreTransform.cpp +47 -8
  73. data/vendor/faiss/faiss/IndexPreTransform.h +15 -3
  74. data/vendor/faiss/faiss/IndexRefine.cpp +8 -4
  75. data/vendor/faiss/faiss/IndexRefine.h +4 -2
  76. data/vendor/faiss/faiss/IndexReplicas.cpp +4 -2
  77. data/vendor/faiss/faiss/IndexReplicas.h +2 -1
  78. data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +438 -0
  79. data/vendor/faiss/faiss/IndexRowwiseMinMax.h +92 -0
  80. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +26 -15
  81. data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -7
  82. data/vendor/faiss/faiss/IndexShards.cpp +4 -1
  83. data/vendor/faiss/faiss/IndexShards.h +2 -1
  84. data/vendor/faiss/faiss/MetaIndexes.cpp +5 -178
  85. data/vendor/faiss/faiss/MetaIndexes.h +3 -81
  86. data/vendor/faiss/faiss/VectorTransform.cpp +43 -0
  87. data/vendor/faiss/faiss/VectorTransform.h +22 -4
  88. data/vendor/faiss/faiss/clone_index.cpp +23 -1
  89. data/vendor/faiss/faiss/clone_index.h +3 -0
  90. data/vendor/faiss/faiss/cppcontrib/SaDecodeKernels.h +300 -0
  91. data/vendor/faiss/faiss/cppcontrib/detail/CoarseBitType.h +24 -0
  92. data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +195 -0
  93. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +2058 -0
  94. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +408 -0
  95. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-neon-inl.h +2147 -0
  96. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMax-inl.h +460 -0
  97. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMaxFP16-inl.h +465 -0
  98. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +1618 -0
  99. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +251 -0
  100. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-neon-inl.h +1452 -0
  101. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +1 -0
  102. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +0 -4
  103. data/vendor/faiss/faiss/gpu/GpuIndex.h +28 -4
  104. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +2 -1
  105. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +10 -8
  106. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +75 -14
  107. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +19 -32
  108. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -31
  109. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +22 -28
  110. data/vendor/faiss/faiss/gpu/GpuResources.cpp +14 -0
  111. data/vendor/faiss/faiss/gpu/GpuResources.h +16 -3
  112. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +3 -3
  113. data/vendor/faiss/faiss/gpu/impl/IndexUtils.h +32 -0
  114. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +1 -0
  115. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +311 -75
  116. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +10 -0
  117. data/vendor/faiss/faiss/gpu/test/TestUtils.h +3 -0
  118. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +2 -2
  119. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +5 -4
  120. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +116 -47
  121. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +44 -13
  122. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +0 -54
  123. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +0 -76
  124. data/vendor/faiss/faiss/impl/DistanceComputer.h +64 -0
  125. data/vendor/faiss/faiss/impl/HNSW.cpp +123 -27
  126. data/vendor/faiss/faiss/impl/HNSW.h +19 -16
  127. data/vendor/faiss/faiss/impl/IDSelector.cpp +125 -0
  128. data/vendor/faiss/faiss/impl/IDSelector.h +135 -0
  129. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +6 -28
  130. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +6 -1
  131. data/vendor/faiss/faiss/impl/LookupTableScaler.h +77 -0
  132. data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -0
  133. data/vendor/faiss/faiss/impl/NSG.cpp +1 -1
  134. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +383 -0
  135. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +154 -0
  136. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +225 -145
  137. data/vendor/faiss/faiss/impl/ProductQuantizer.h +29 -10
  138. data/vendor/faiss/faiss/impl/Quantizer.h +43 -0
  139. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +192 -36
  140. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +40 -20
  141. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -0
  142. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +97 -173
  143. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +18 -18
  144. data/vendor/faiss/faiss/impl/index_read.cpp +240 -9
  145. data/vendor/faiss/faiss/impl/index_write.cpp +237 -5
  146. data/vendor/faiss/faiss/impl/kmeans1d.cpp +6 -4
  147. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +56 -16
  148. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +25 -8
  149. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +66 -25
  150. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +75 -27
  151. data/vendor/faiss/faiss/index_factory.cpp +196 -7
  152. data/vendor/faiss/faiss/index_io.h +5 -0
  153. data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -0
  154. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +4 -1
  155. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +2 -1
  156. data/vendor/faiss/faiss/python/python_callbacks.cpp +27 -0
  157. data/vendor/faiss/faiss/python/python_callbacks.h +15 -0
  158. data/vendor/faiss/faiss/utils/Heap.h +31 -15
  159. data/vendor/faiss/faiss/utils/distances.cpp +380 -56
  160. data/vendor/faiss/faiss/utils/distances.h +113 -15
  161. data/vendor/faiss/faiss/utils/distances_simd.cpp +726 -6
  162. data/vendor/faiss/faiss/utils/extra_distances.cpp +12 -7
  163. data/vendor/faiss/faiss/utils/extra_distances.h +3 -1
  164. data/vendor/faiss/faiss/utils/fp16-fp16c.h +21 -0
  165. data/vendor/faiss/faiss/utils/fp16-inl.h +101 -0
  166. data/vendor/faiss/faiss/utils/fp16.h +11 -0
  167. data/vendor/faiss/faiss/utils/hamming-inl.h +54 -0
  168. data/vendor/faiss/faiss/utils/hamming.cpp +0 -48
  169. data/vendor/faiss/faiss/utils/ordered_key_value.h +10 -0
  170. data/vendor/faiss/faiss/utils/quantize_lut.cpp +62 -0
  171. data/vendor/faiss/faiss/utils/quantize_lut.h +20 -0
  172. data/vendor/faiss/faiss/utils/random.cpp +53 -0
  173. data/vendor/faiss/faiss/utils/random.h +5 -0
  174. data/vendor/faiss/faiss/utils/simdlib_avx2.h +4 -0
  175. data/vendor/faiss/faiss/utils/simdlib_emulated.h +6 -1
  176. data/vendor/faiss/faiss/utils/simdlib_neon.h +7 -2
  177. metadata +37 -3
@@ -12,6 +12,8 @@
12
12
  #include <string>
13
13
 
14
14
  #include <faiss/impl/AuxIndexStructures.h>
15
+ #include <faiss/impl/DistanceComputer.h>
16
+ #include <faiss/impl/IDSelector.h>
15
17
 
16
18
  namespace faiss {
17
19
 
@@ -501,9 +503,19 @@ void HNSW::add_with_locks(
501
503
  }
502
504
  }
503
505
 
506
+ /**************************************************************
507
+ * Searching
508
+ **************************************************************/
509
+
510
+ namespace {
511
+
512
+ using idx_t = HNSW::idx_t;
513
+ using MinimaxHeap = HNSW::MinimaxHeap;
514
+ using Node = HNSW::Node;
504
515
  /** Do a BFS on the candidates list */
505
516
 
506
- int HNSW::search_from_candidates(
517
+ int search_from_candidates(
518
+ const HNSW& hnsw,
507
519
  DistanceComputer& qdis,
508
520
  int k,
509
521
  idx_t* I,
@@ -512,22 +524,31 @@ int HNSW::search_from_candidates(
512
524
  VisitedTable& vt,
513
525
  HNSWStats& stats,
514
526
  int level,
515
- int nres_in) const {
527
+ int nres_in = 0,
528
+ const SearchParametersHNSW* params = nullptr) {
516
529
  int nres = nres_in;
517
530
  int ndis = 0;
531
+
532
+ // can be overridden by search params
533
+ bool do_dis_check = params ? params->check_relative_distance
534
+ : hnsw.check_relative_distance;
535
+ int efSearch = params ? params->efSearch : hnsw.efSearch;
536
+ const IDSelector* sel = params ? params->sel : nullptr;
537
+
518
538
  for (int i = 0; i < candidates.size(); i++) {
519
539
  idx_t v1 = candidates.ids[i];
520
540
  float d = candidates.dis[i];
521
541
  FAISS_ASSERT(v1 >= 0);
522
- if (nres < k) {
523
- faiss::maxheap_push(++nres, D, I, d, v1);
524
- } else if (d < D[0]) {
525
- faiss::maxheap_replace_top(nres, D, I, d, v1);
542
+ if (!sel || sel->is_member(v1)) {
543
+ if (nres < k) {
544
+ faiss::maxheap_push(++nres, D, I, d, v1);
545
+ } else if (d < D[0]) {
546
+ faiss::maxheap_replace_top(nres, D, I, d, v1);
547
+ }
526
548
  }
527
549
  vt.set(v1);
528
550
  }
529
551
 
530
- bool do_dis_check = check_relative_distance;
531
552
  int nstep = 0;
532
553
 
533
554
  while (candidates.size() > 0) {
@@ -546,10 +567,10 @@ int HNSW::search_from_candidates(
546
567
  }
547
568
 
548
569
  size_t begin, end;
549
- neighbor_range(v0, level, &begin, &end);
570
+ hnsw.neighbor_range(v0, level, &begin, &end);
550
571
 
551
572
  for (size_t j = begin; j < end; j++) {
552
- int v1 = neighbors[j];
573
+ int v1 = hnsw.neighbors[j];
553
574
  if (v1 < 0)
554
575
  break;
555
576
  if (vt.get(v1)) {
@@ -558,10 +579,12 @@ int HNSW::search_from_candidates(
558
579
  vt.set(v1);
559
580
  ndis++;
560
581
  float d = qdis(v1);
561
- if (nres < k) {
562
- faiss::maxheap_push(++nres, D, I, d, v1);
563
- } else if (d < D[0]) {
564
- faiss::maxheap_replace_top(nres, D, I, d, v1);
582
+ if (!sel || sel->is_member(v1)) {
583
+ if (nres < k) {
584
+ faiss::maxheap_push(++nres, D, I, d, v1);
585
+ } else if (d < D[0]) {
586
+ faiss::maxheap_replace_top(nres, D, I, d, v1);
587
+ }
565
588
  }
566
589
  candidates.push(v1, d);
567
590
  }
@@ -583,16 +606,13 @@ int HNSW::search_from_candidates(
583
606
  return nres;
584
607
  }
585
608
 
586
- /**************************************************************
587
- * Searching
588
- **************************************************************/
589
-
590
- std::priority_queue<HNSW::Node> HNSW::search_from_candidate_unbounded(
609
+ std::priority_queue<HNSW::Node> search_from_candidate_unbounded(
610
+ const HNSW& hnsw,
591
611
  const Node& node,
592
612
  DistanceComputer& qdis,
593
613
  int ef,
594
614
  VisitedTable* vt,
595
- HNSWStats& stats) const {
615
+ HNSWStats& stats) {
596
616
  int ndis = 0;
597
617
  std::priority_queue<Node> top_candidates;
598
618
  std::priority_queue<Node, std::vector<Node>, std::greater<Node>> candidates;
@@ -614,10 +634,10 @@ std::priority_queue<HNSW::Node> HNSW::search_from_candidate_unbounded(
614
634
  candidates.pop();
615
635
 
616
636
  size_t begin, end;
617
- neighbor_range(v0, 0, &begin, &end);
637
+ hnsw.neighbor_range(v0, 0, &begin, &end);
618
638
 
619
639
  for (size_t j = begin; j < end; ++j) {
620
- int v1 = neighbors[j];
640
+ int v1 = hnsw.neighbors[j];
621
641
 
622
642
  if (v1 < 0) {
623
643
  break;
@@ -651,14 +671,19 @@ std::priority_queue<HNSW::Node> HNSW::search_from_candidate_unbounded(
651
671
  return top_candidates;
652
672
  }
653
673
 
674
+ } // anonymous namespace
675
+
654
676
  HNSWStats HNSW::search(
655
677
  DistanceComputer& qdis,
656
678
  int k,
657
679
  idx_t* I,
658
680
  float* D,
659
- VisitedTable& vt) const {
681
+ VisitedTable& vt,
682
+ const SearchParametersHNSW* params) const {
660
683
  HNSWStats stats;
661
-
684
+ if (entry_point == -1) {
685
+ return stats;
686
+ }
662
687
  if (upper_beam == 1) {
663
688
  // greedy search on upper levels
664
689
  storage_idx_t nearest = entry_point;
@@ -669,16 +694,22 @@ HNSWStats HNSW::search(
669
694
  }
670
695
 
671
696
  int ef = std::max(efSearch, k);
672
- if (search_bounded_queue) {
697
+ if (search_bounded_queue) { // this is the most common branch
673
698
  MinimaxHeap candidates(ef);
674
699
 
675
700
  candidates.push(nearest, d_nearest);
676
701
 
677
- search_from_candidates(qdis, k, I, D, candidates, vt, stats, 0);
702
+ search_from_candidates(
703
+ *this, qdis, k, I, D, candidates, vt, stats, 0, 0, params);
678
704
  } else {
679
705
  std::priority_queue<Node> top_candidates =
680
706
  search_from_candidate_unbounded(
681
- Node(d_nearest, nearest), qdis, ef, &vt, stats);
707
+ *this,
708
+ Node(d_nearest, nearest),
709
+ qdis,
710
+ ef,
711
+ &vt,
712
+ stats);
682
713
 
683
714
  while (top_candidates.size() > k) {
684
715
  top_candidates.pop();
@@ -718,9 +749,10 @@ HNSWStats HNSW::search(
718
749
 
719
750
  if (level == 0) {
720
751
  nres = search_from_candidates(
721
- qdis, k, I, D, candidates, vt, stats, 0);
752
+ *this, qdis, k, I, D, candidates, vt, stats, 0);
722
753
  } else {
723
754
  nres = search_from_candidates(
755
+ *this,
724
756
  qdis,
725
757
  candidates_size,
726
758
  I_to_next.data(),
@@ -737,6 +769,70 @@ HNSWStats HNSW::search(
737
769
  return stats;
738
770
  }
739
771
 
772
+ void HNSW::search_level_0(
773
+ DistanceComputer& qdis,
774
+ int k,
775
+ idx_t* idxi,
776
+ float* simi,
777
+ idx_t nprobe,
778
+ const storage_idx_t* nearest_i,
779
+ const float* nearest_d,
780
+ int search_type,
781
+ HNSWStats& search_stats,
782
+ VisitedTable& vt) const {
783
+ const HNSW& hnsw = *this;
784
+
785
+ if (search_type == 1) {
786
+ int nres = 0;
787
+
788
+ for (int j = 0; j < nprobe; j++) {
789
+ storage_idx_t cj = nearest_i[j];
790
+
791
+ if (cj < 0)
792
+ break;
793
+
794
+ if (vt.get(cj))
795
+ continue;
796
+
797
+ int candidates_size = std::max(hnsw.efSearch, int(k));
798
+ MinimaxHeap candidates(candidates_size);
799
+
800
+ candidates.push(cj, nearest_d[j]);
801
+
802
+ nres = search_from_candidates(
803
+ hnsw,
804
+ qdis,
805
+ k,
806
+ idxi,
807
+ simi,
808
+ candidates,
809
+ vt,
810
+ search_stats,
811
+ 0,
812
+ nres);
813
+ }
814
+ } else if (search_type == 2) {
815
+ int candidates_size = std::max(hnsw.efSearch, int(k));
816
+ candidates_size = std::max(candidates_size, int(nprobe));
817
+
818
+ MinimaxHeap candidates(candidates_size);
819
+ for (int j = 0; j < nprobe; j++) {
820
+ storage_idx_t cj = nearest_i[j];
821
+
822
+ if (cj < 0)
823
+ break;
824
+ candidates.push(cj, nearest_d[j]);
825
+ }
826
+
827
+ search_from_candidates(
828
+ hnsw, qdis, k, idxi, simi, candidates, vt, search_stats, 0);
829
+ }
830
+ }
831
+
832
+ /**************************************************************
833
+ * MinimaxHeap
834
+ **************************************************************/
835
+
740
836
  void HNSW::MinimaxHeap::push(storage_idx_t i, float v) {
741
837
  if (k == n) {
742
838
  if (v >= dis[0])
@@ -43,6 +43,13 @@ struct VisitedTable;
43
43
  struct DistanceComputer; // from AuxIndexStructures
44
44
  struct HNSWStats;
45
45
 
46
+ struct SearchParametersHNSW : SearchParameters {
47
+ int efSearch = 16;
48
+ bool check_relative_distance = true;
49
+
50
+ ~SearchParametersHNSW() {}
51
+ };
52
+
46
53
  struct HNSW {
47
54
  /// internal storage of vectors (32 bits: this is expensive)
48
55
  typedef int storage_idx_t;
@@ -188,30 +195,26 @@ struct HNSW {
188
195
  std::vector<omp_lock_t>& locks,
189
196
  VisitedTable& vt);
190
197
 
191
- int search_from_candidates(
198
+ /// search interface for 1 point, single thread
199
+ HNSWStats search(
192
200
  DistanceComputer& qdis,
193
201
  int k,
194
202
  idx_t* I,
195
203
  float* D,
196
- MinimaxHeap& candidates,
197
204
  VisitedTable& vt,
198
- HNSWStats& stats,
199
- int level,
200
- int nres_in = 0) const;
205
+ const SearchParametersHNSW* params = nullptr) const;
201
206
 
202
- std::priority_queue<Node> search_from_candidate_unbounded(
203
- const Node& node,
204
- DistanceComputer& qdis,
205
- int ef,
206
- VisitedTable* vt,
207
- HNSWStats& stats) const;
208
-
209
- /// search interface
210
- HNSWStats search(
207
+ /// search only in level 0 from a given vertex
208
+ void search_level_0(
211
209
  DistanceComputer& qdis,
212
210
  int k,
213
- idx_t* I,
214
- float* D,
211
+ idx_t* idxi,
212
+ float* simi,
213
+ idx_t nprobe,
214
+ const storage_idx_t* nearest_i,
215
+ const float* nearest_d,
216
+ int search_type,
217
+ HNSWStats& search_stats,
215
218
  VisitedTable& vt) const;
216
219
 
217
220
  void reset();
@@ -0,0 +1,125 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ #include <faiss/impl/FaissAssert.h>
9
+ #include <faiss/impl/IDSelector.h>
10
+
11
+ namespace faiss {
12
+
13
+ /***********************************************************************
14
+ * IDSelectorRange
15
+ ***********************************************************************/
16
+
17
+ IDSelectorRange::IDSelectorRange(idx_t imin, idx_t imax, bool assume_sorted)
18
+ : imin(imin), imax(imax), assume_sorted(assume_sorted) {}
19
+
20
+ bool IDSelectorRange::is_member(idx_t id) const {
21
+ return id >= imin && id < imax;
22
+ }
23
+
24
+ void IDSelectorRange::find_sorted_ids_bounds(
25
+ size_t list_size,
26
+ const idx_t* ids,
27
+ size_t* jmin_out,
28
+ size_t* jmax_out) const {
29
+ FAISS_ASSERT(assume_sorted);
30
+ if (list_size == 0 || imax <= ids[0] || imin > ids[list_size - 1]) {
31
+ *jmin_out = *jmax_out = 0;
32
+ return;
33
+ }
34
+ // bissection to find imin
35
+ if (ids[0] >= imin) {
36
+ *jmin_out = 0;
37
+ } else {
38
+ size_t j0 = 0, j1 = list_size;
39
+ while (j1 > j0 + 1) {
40
+ size_t jmed = (j0 + j1) / 2;
41
+ if (ids[jmed] >= imin) {
42
+ j1 = jmed;
43
+ } else {
44
+ j0 = jmed;
45
+ }
46
+ }
47
+ *jmin_out = j1;
48
+ }
49
+ // bissection to find imax
50
+ if (*jmin_out == list_size || ids[*jmin_out] >= imax) {
51
+ *jmax_out = *jmin_out;
52
+ } else {
53
+ size_t j0 = *jmin_out, j1 = list_size;
54
+ while (j1 > j0 + 1) {
55
+ size_t jmed = (j0 + j1) / 2;
56
+ if (ids[jmed] >= imax) {
57
+ j1 = jmed;
58
+ } else {
59
+ j0 = jmed;
60
+ }
61
+ }
62
+ *jmax_out = j1;
63
+ }
64
+ }
65
+
66
+ /***********************************************************************
67
+ * IDSelectorArray
68
+ ***********************************************************************/
69
+
70
+ IDSelectorArray::IDSelectorArray(size_t n, const idx_t* ids) : n(n), ids(ids) {}
71
+
72
+ bool IDSelectorArray::is_member(idx_t id) const {
73
+ for (idx_t i = 0; i < n; i++) {
74
+ if (ids[i] == id)
75
+ return true;
76
+ }
77
+ return false;
78
+ }
79
+
80
+ /***********************************************************************
81
+ * IDSelectorBatch
82
+ ***********************************************************************/
83
+
84
+ IDSelectorBatch::IDSelectorBatch(size_t n, const idx_t* indices) {
85
+ nbits = 0;
86
+ while (n > ((idx_t)1 << nbits)) {
87
+ nbits++;
88
+ }
89
+ nbits += 5;
90
+ // for n = 1M, nbits = 25 is optimal, see P56659518
91
+
92
+ mask = ((idx_t)1 << nbits) - 1;
93
+ bloom.resize((idx_t)1 << (nbits - 3), 0);
94
+ for (idx_t i = 0; i < n; i++) {
95
+ Index::idx_t id = indices[i];
96
+ set.insert(id);
97
+ id &= mask;
98
+ bloom[id >> 3] |= 1 << (id & 7);
99
+ }
100
+ }
101
+
102
+ bool IDSelectorBatch::is_member(idx_t i) const {
103
+ long im = i & mask;
104
+ if (!(bloom[im >> 3] & (1 << (im & 7)))) {
105
+ return 0;
106
+ }
107
+ return set.count(i);
108
+ }
109
+
110
+ /***********************************************************************
111
+ * IDSelectorBitmap
112
+ ***********************************************************************/
113
+
114
+ IDSelectorBitmap::IDSelectorBitmap(size_t n, const uint8_t* bitmap)
115
+ : n(n), bitmap(bitmap) {}
116
+
117
+ bool IDSelectorBitmap::is_member(idx_t ii) const {
118
+ uint64_t i = ii;
119
+ if ((i >> 3) >= n) {
120
+ return false;
121
+ }
122
+ return (bitmap[i >> 3] >> (i & 7)) & 1;
123
+ }
124
+
125
+ } // namespace faiss
@@ -0,0 +1,135 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ #pragma once
9
+
10
+ #include <unordered_set>
11
+ #include <vector>
12
+
13
+ #include <faiss/Index.h>
14
+
15
+ /** IDSelector is intended to define a subset of vectors to handle (for removal
16
+ * or as subset to search) */
17
+
18
+ namespace faiss {
19
+
20
+ /** Encapsulates a set of ids to handle. */
21
+ struct IDSelector {
22
+ using idx_t = Index::idx_t;
23
+ virtual bool is_member(idx_t id) const = 0;
24
+ virtual ~IDSelector() {}
25
+ };
26
+
27
+ /** ids between [imin, imax) */
28
+ struct IDSelectorRange : IDSelector {
29
+ idx_t imin, imax;
30
+
31
+ /// Assume that the ids to handle are sorted. In some cases this can speed
32
+ /// up processing
33
+ bool assume_sorted;
34
+
35
+ IDSelectorRange(idx_t imin, idx_t imax, bool assume_sorted = false);
36
+
37
+ bool is_member(idx_t id) const final;
38
+
39
+ /// for sorted ids, find the range of list indices where the valid ids are
40
+ /// stored
41
+ void find_sorted_ids_bounds(
42
+ size_t list_size,
43
+ const idx_t* ids,
44
+ size_t* jmin,
45
+ size_t* jmax) const;
46
+
47
+ ~IDSelectorRange() override {}
48
+ };
49
+
50
+ /** Simple array of elements
51
+ *
52
+ * is_member calls are very inefficient, but some operations can use the ids
53
+ * directly.
54
+ */
55
+ struct IDSelectorArray : IDSelector {
56
+ size_t n;
57
+ const idx_t* ids;
58
+
59
+ /** Construct with an array of ids to process
60
+ *
61
+ * @param n number of ids to store
62
+ * @param ids elements to store. The pointer should remain valid during
63
+ * IDSelectorArray's lifetime
64
+ */
65
+ IDSelectorArray(size_t n, const idx_t* ids);
66
+ bool is_member(idx_t id) const final;
67
+ ~IDSelectorArray() override {}
68
+ };
69
+
70
+ /** Ids from a set.
71
+ *
72
+ * Repetitions of ids in the indices set passed to the constructor does not hurt
73
+ * performance.
74
+ *
75
+ * The hash function used for the bloom filter and GCC's implementation of
76
+ * unordered_set are just the least significant bits of the id. This works fine
77
+ * for random ids or ids in sequences but will produce many hash collisions if
78
+ * lsb's are always the same
79
+ */
80
+ struct IDSelectorBatch : IDSelector {
81
+ std::unordered_set<idx_t> set;
82
+
83
+ // Bloom filter to avoid accessing the unordered set if it is unlikely
84
+ // to be true
85
+ std::vector<uint8_t> bloom;
86
+ int nbits;
87
+ idx_t mask;
88
+
89
+ /** Construct with an array of ids to process
90
+ *
91
+ * @param n number of ids to store
92
+ * @param ids elements to store. The pointer can be released after
93
+ * construction
94
+ */
95
+ IDSelectorBatch(size_t n, const idx_t* indices);
96
+ bool is_member(idx_t id) const final;
97
+ ~IDSelectorBatch() override {}
98
+ };
99
+
100
+ /** One bit per element. Constructed with a bitmap, size ceil(n / 8).
101
+ */
102
+ struct IDSelectorBitmap : IDSelector {
103
+ size_t n;
104
+ const uint8_t* bitmap;
105
+
106
+ /** Construct with a binary mask
107
+ *
108
+ * @param n size of the bitmap array
109
+ * @param bitmap id will be selected iff id / 8 < n and bit number
110
+ * (i%8) of bitmap[floor(i / 8)] is 1.
111
+ */
112
+ IDSelectorBitmap(size_t n, const uint8_t* bitmap);
113
+ bool is_member(idx_t id) const final;
114
+ ~IDSelectorBitmap() override {}
115
+ };
116
+
117
+ /** reverts the membership test of another selector */
118
+ struct IDSelectorNot : IDSelector {
119
+ const IDSelector* sel;
120
+ IDSelectorNot(const IDSelector* sel) : sel(sel) {}
121
+ bool is_member(idx_t id) const final {
122
+ return !sel->is_member(id);
123
+ }
124
+ virtual ~IDSelectorNot() {}
125
+ };
126
+
127
+ /// selects all entries (useful for benchmarking)
128
+ struct IDSelectorAll : IDSelector {
129
+ bool is_member(idx_t id) const final {
130
+ return true;
131
+ }
132
+ virtual ~IDSelectorAll() {}
133
+ };
134
+
135
+ } // namespace faiss
@@ -15,7 +15,6 @@
15
15
 
16
16
  #include <algorithm>
17
17
 
18
- #include <faiss/Clustering.h>
19
18
  #include <faiss/impl/AuxIndexStructures.h>
20
19
  #include <faiss/impl/FaissAssert.h>
21
20
  #include <faiss/utils/distances.h>
@@ -151,9 +150,6 @@ LocalSearchQuantizer::LocalSearchQuantizer(
151
150
  size_t nbits,
152
151
  Search_type_t search_type)
153
152
  : AdditiveQuantizer(d, std::vector<size_t>(M, nbits), search_type) {
154
- is_trained = false;
155
- verbose = false;
156
-
157
153
  K = (1 << nbits);
158
154
 
159
155
  train_iters = 25;
@@ -182,7 +178,7 @@ LocalSearchQuantizer::LocalSearchQuantizer() : LocalSearchQuantizer(0, 0, 0) {}
182
178
 
183
179
  void LocalSearchQuantizer::train(size_t n, const float* x) {
184
180
  FAISS_THROW_IF_NOT(K == (1 << nbits[0]));
185
- FAISS_THROW_IF_NOT(nperts <= M);
181
+ nperts = std::min(nperts, M);
186
182
 
187
183
  lsq_timer.reset();
188
184
  LSQTimerScope scope(&lsq_timer, "train");
@@ -264,26 +260,7 @@ void LocalSearchQuantizer::train(size_t n, const float* x) {
264
260
  decode_unpacked(codes.data(), x_recons.data(), n);
265
261
  fvec_norms_L2sqr(norms.data(), x_recons.data(), d, n);
266
262
 
267
- norm_min = HUGE_VALF;
268
- norm_max = -HUGE_VALF;
269
- for (idx_t i = 0; i < n; i++) {
270
- if (norms[i] < norm_min) {
271
- norm_min = norms[i];
272
- }
273
- if (norms[i] > norm_max) {
274
- norm_max = norms[i];
275
- }
276
- }
277
-
278
- if (search_type == ST_norm_cqint8 || search_type == ST_norm_cqint4) {
279
- size_t k = (1 << 8);
280
- if (search_type == ST_norm_cqint4) {
281
- k = (1 << 4);
282
- }
283
- Clustering1D clus(k);
284
- clus.train_exact(n, norms.data());
285
- qnorm.add(clus.k, clus.centroids.data());
286
- }
263
+ train_norm(n, norms.data());
287
264
  }
288
265
 
289
266
  if (verbose) {
@@ -318,10 +295,11 @@ void LocalSearchQuantizer::perturb_codebooks(
318
295
  }
319
296
  }
320
297
 
321
- void LocalSearchQuantizer::compute_codes(
298
+ void LocalSearchQuantizer::compute_codes_add_centroids(
322
299
  const float* x,
323
300
  uint8_t* codes_out,
324
- size_t n) const {
301
+ size_t n,
302
+ const float* centroids) const {
325
303
  FAISS_THROW_IF_NOT_MSG(is_trained, "LSQ is not trained yet.");
326
304
 
327
305
  lsq_timer.reset();
@@ -335,7 +313,7 @@ void LocalSearchQuantizer::compute_codes(
335
313
  random_int32(codes, 0, K - 1, gen);
336
314
 
337
315
  icm_encode(codes.data(), x, n, encode_ils_iters, gen);
338
- pack_codes(n, codes.data(), codes_out);
316
+ pack_codes(n, codes.data(), codes_out, -1, nullptr, centroids);
339
317
 
340
318
  if (verbose) {
341
319
  scope.finish();
@@ -83,8 +83,13 @@ struct LocalSearchQuantizer : AdditiveQuantizer {
83
83
  * @param x vectors to encode, size n * d
84
84
  * @param codes output codes, size n * code_size
85
85
  * @param n number of vectors
86
+ * @param centroids centroids to be added to x, size n * d
86
87
  */
87
- void compute_codes(const float* x, uint8_t* codes, size_t n) const override;
88
+ void compute_codes_add_centroids(
89
+ const float* x,
90
+ uint8_t* codes,
91
+ size_t n,
92
+ const float* centroids = nullptr) const override;
88
93
 
89
94
  /** Update codebooks given encodings
90
95
  *