faiss 0.3.1 → 0.3.2

Files changed (119)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +4 -0
  3. data/lib/faiss/version.rb +1 -1
  4. data/vendor/faiss/faiss/AutoTune.h +1 -1
  5. data/vendor/faiss/faiss/Clustering.cpp +35 -4
  6. data/vendor/faiss/faiss/Clustering.h +10 -1
  7. data/vendor/faiss/faiss/IVFlib.cpp +4 -1
  8. data/vendor/faiss/faiss/Index.h +21 -6
  9. data/vendor/faiss/faiss/IndexBinaryHNSW.h +1 -1
  10. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +1 -1
  11. data/vendor/faiss/faiss/IndexFastScan.cpp +22 -4
  12. data/vendor/faiss/faiss/IndexFlat.cpp +11 -7
  13. data/vendor/faiss/faiss/IndexFlatCodes.cpp +159 -5
  14. data/vendor/faiss/faiss/IndexFlatCodes.h +20 -3
  15. data/vendor/faiss/faiss/IndexHNSW.cpp +143 -90
  16. data/vendor/faiss/faiss/IndexHNSW.h +52 -3
  17. data/vendor/faiss/faiss/IndexIVF.cpp +3 -3
  18. data/vendor/faiss/faiss/IndexIVF.h +9 -1
  19. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +15 -0
  20. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +3 -0
  21. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +130 -57
  22. data/vendor/faiss/faiss/IndexIVFFastScan.h +14 -7
  23. data/vendor/faiss/faiss/IndexIVFPQ.cpp +1 -3
  24. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +21 -2
  25. data/vendor/faiss/faiss/IndexLattice.cpp +1 -19
  26. data/vendor/faiss/faiss/IndexLattice.h +3 -22
  27. data/vendor/faiss/faiss/IndexNNDescent.cpp +0 -29
  28. data/vendor/faiss/faiss/IndexNNDescent.h +1 -1
  29. data/vendor/faiss/faiss/IndexNSG.h +1 -1
  30. data/vendor/faiss/faiss/IndexNeuralNetCodec.cpp +56 -0
  31. data/vendor/faiss/faiss/IndexNeuralNetCodec.h +49 -0
  32. data/vendor/faiss/faiss/IndexPreTransform.h +1 -1
  33. data/vendor/faiss/faiss/IndexRefine.cpp +5 -5
  34. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +3 -1
  35. data/vendor/faiss/faiss/MetricType.h +7 -2
  36. data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +95 -17
  37. data/vendor/faiss/faiss/cppcontrib/factory_tools.cpp +152 -0
  38. data/vendor/faiss/faiss/cppcontrib/factory_tools.h +24 -0
  39. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +83 -30
  40. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +36 -4
  41. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +6 -0
  42. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -1
  43. data/vendor/faiss/faiss/gpu/GpuIndex.h +2 -8
  44. data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +282 -0
  45. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +6 -0
  46. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +2 -0
  47. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +25 -0
  48. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +26 -21
  49. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +6 -0
  50. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +8 -5
  51. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +65 -0
  52. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +1 -1
  53. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +6 -0
  54. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +4 -1
  55. data/vendor/faiss/faiss/gpu/utils/Timer.h +1 -1
  56. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +25 -0
  57. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +9 -1
  58. data/vendor/faiss/faiss/impl/DistanceComputer.h +46 -0
  59. data/vendor/faiss/faiss/impl/FaissAssert.h +4 -2
  60. data/vendor/faiss/faiss/impl/HNSW.cpp +358 -190
  61. data/vendor/faiss/faiss/impl/HNSW.h +43 -22
  62. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +8 -8
  63. data/vendor/faiss/faiss/impl/LookupTableScaler.h +34 -0
  64. data/vendor/faiss/faiss/impl/NNDescent.cpp +13 -8
  65. data/vendor/faiss/faiss/impl/NSG.cpp +0 -29
  66. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +1 -0
  67. data/vendor/faiss/faiss/impl/ProductQuantizer.h +5 -1
  68. data/vendor/faiss/faiss/impl/ResultHandler.h +151 -32
  69. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +719 -102
  70. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +3 -0
  71. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +5 -0
  72. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx512.h +248 -0
  73. data/vendor/faiss/faiss/impl/index_read.cpp +29 -15
  74. data/vendor/faiss/faiss/impl/index_read_utils.h +37 -0
  75. data/vendor/faiss/faiss/impl/index_write.cpp +28 -10
  76. data/vendor/faiss/faiss/impl/io.cpp +13 -5
  77. data/vendor/faiss/faiss/impl/io.h +4 -4
  78. data/vendor/faiss/faiss/impl/io_macros.h +6 -0
  79. data/vendor/faiss/faiss/impl/platform_macros.h +22 -0
  80. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +11 -0
  81. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +1 -1
  82. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +448 -1
  83. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +5 -5
  84. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +1 -1
  85. data/vendor/faiss/faiss/impl/simd_result_handlers.h +143 -59
  86. data/vendor/faiss/faiss/index_factory.cpp +31 -13
  87. data/vendor/faiss/faiss/index_io.h +12 -5
  88. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +28 -8
  89. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +3 -0
  90. data/vendor/faiss/faiss/invlists/DirectMap.cpp +9 -1
  91. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +55 -17
  92. data/vendor/faiss/faiss/invlists/InvertedLists.h +18 -9
  93. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +21 -6
  94. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +2 -1
  95. data/vendor/faiss/faiss/python/python_callbacks.cpp +3 -3
  96. data/vendor/faiss/faiss/utils/Heap.h +105 -0
  97. data/vendor/faiss/faiss/utils/NeuralNet.cpp +342 -0
  98. data/vendor/faiss/faiss/utils/NeuralNet.h +147 -0
  99. data/vendor/faiss/faiss/utils/bf16.h +36 -0
  100. data/vendor/faiss/faiss/utils/distances.cpp +58 -88
  101. data/vendor/faiss/faiss/utils/distances.h +5 -5
  102. data/vendor/faiss/faiss/utils/distances_simd.cpp +997 -9
  103. data/vendor/faiss/faiss/utils/extra_distances-inl.h +70 -0
  104. data/vendor/faiss/faiss/utils/extra_distances.cpp +85 -137
  105. data/vendor/faiss/faiss/utils/extra_distances.h +3 -2
  106. data/vendor/faiss/faiss/utils/hamming.cpp +1 -1
  107. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +4 -1
  108. data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +2 -1
  109. data/vendor/faiss/faiss/utils/random.cpp +43 -0
  110. data/vendor/faiss/faiss/utils/random.h +25 -0
  111. data/vendor/faiss/faiss/utils/simdlib.h +10 -1
  112. data/vendor/faiss/faiss/utils/simdlib_avx512.h +296 -0
  113. data/vendor/faiss/faiss/utils/simdlib_neon.h +5 -2
  114. data/vendor/faiss/faiss/utils/simdlib_ppc64.h +1084 -0
  115. data/vendor/faiss/faiss/utils/transpose/transpose-avx512-inl.h +176 -0
  116. data/vendor/faiss/faiss/utils/utils.cpp +10 -3
  117. data/vendor/faiss/faiss/utils/utils.h +3 -0
  118. metadata +16 -4
  119. data/vendor/faiss/faiss/impl/code_distance/code_distance_avx512.h +0 -102
data/vendor/faiss/faiss/IndexHNSW.cpp

@@ -5,8 +5,6 @@
   * LICENSE file in the root directory of this source tree.
   */
 
- // -*- c++ -*-
-
  #include <faiss/IndexHNSW.h>
 
  #include <omp.h>
@@ -17,7 +15,10 @@
  #include <cstdlib>
  #include <cstring>
 
+ #include <limits>
+ #include <memory>
  #include <queue>
+ #include <random>
  #include <unordered_set>
 
  #include <sys/stat.h>
@@ -34,26 +35,6 @@
  #include <faiss/utils/random.h>
  #include <faiss/utils/sorting.h>
 
- extern "C" {
-
- /* declare BLAS functions, see http://www.netlib.org/clapack/cblas/ */
-
- int sgemm_(
-         const char* transa,
-         const char* transb,
-         FINTEGER* m,
-         FINTEGER* n,
-         FINTEGER* k,
-         const float* alpha,
-         const float* a,
-         FINTEGER* lda,
-         const float* b,
-         FINTEGER* ldb,
-         float* beta,
-         float* c,
-         FINTEGER* ldc);
- }
-
  namespace faiss {
 
  using MinimaxHeap = HNSW::MinimaxHeap;
@@ -68,52 +49,6 @@ HNSWStats hnsw_stats;
 
  namespace {
 
- /* Wrap the distance computer into one that negates the
-    distances. This makes supporting INNER_PRODUCE search easier */
-
- struct NegativeDistanceComputer : DistanceComputer {
-     /// owned by this
-     DistanceComputer* basedis;
-
-     explicit NegativeDistanceComputer(DistanceComputer* basedis)
-             : basedis(basedis) {}
-
-     void set_query(const float* x) override {
-         basedis->set_query(x);
-     }
-
-     /// compute distance of vector i to current query
-     float operator()(idx_t i) override {
-         return -(*basedis)(i);
-     }
-
-     void distances_batch_4(
-             const idx_t idx0,
-             const idx_t idx1,
-             const idx_t idx2,
-             const idx_t idx3,
-             float& dis0,
-             float& dis1,
-             float& dis2,
-             float& dis3) override {
-         basedis->distances_batch_4(
-                 idx0, idx1, idx2, idx3, dis0, dis1, dis2, dis3);
-         dis0 = -dis0;
-         dis1 = -dis1;
-         dis2 = -dis2;
-         dis3 = -dis3;
-     }
-
-     /// compute distance between two stored vectors
-     float symmetric_dis(idx_t i, idx_t j) override {
-         return -basedis->symmetric_dis(i, j);
-     }
-
-     virtual ~NegativeDistanceComputer() {
-         delete basedis;
-     }
- };
-
  DistanceComputer* storage_distance_computer(const Index* storage) {
      if (is_similarity_metric(storage->metric_type)) {
          return new NegativeDistanceComputer(storage->get_distance_computer());
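The NegativeDistanceComputer block removed above is not dropped from the library: judging from the file list (faiss/impl/DistanceComputer.h, +46 lines), the wrapper appears to move into the shared DistanceComputer header. The trick it implements is worth spelling out: for similarity metrics such as inner product, negating every distance turns "larger is better" into "smaller is better", so the min-oriented heaps and result handlers can be reused unchanged. A standalone sketch of that idea (illustrative names, not faiss API):

    #include <cstddef>
    #include <vector>

    // Inner product of two d-dimensional vectors.
    float ip(const float* a, const float* b, size_t d) {
        float s = 0;
        for (size_t j = 0; j < d; j++)
            s += a[j] * b[j];
        return s;
    }

    // The best match under inner product is the vector whose *negated*
    // inner product is smallest, so "keep the smallest distance" code
    // works for both L2 and IP once distances are negated.
    size_t best_match(const float* q, const std::vector<float>& xb, size_t d) {
        size_t best = 0;
        float best_neg = -ip(q, xb.data(), d);
        for (size_t i = 1; i * d < xb.size(); i++) {
            float neg = -ip(q, xb.data() + i * d, d);
            if (neg < best_neg) {
                best_neg = neg;
                best = i;
            }
        }
        return best;
    }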
@@ -192,7 +127,9 @@ void hnsw_add_vertices(
 
  int i1 = n;
 
- for (int pt_level = hist.size() - 1; pt_level >= 0; pt_level--) {
+ for (int pt_level = hist.size() - 1;
+      pt_level >= !index_hnsw.init_level0;
+      pt_level--) {
      int i0 = i1 - hist[pt_level];
 
      if (verbose) {
@@ -228,7 +165,13 @@
          continue;
      }
 
-     hnsw.add_with_locks(*dis, pt_level, pt_id, locks, vt);
+     hnsw.add_with_locks(
+             *dis,
+             pt_level,
+             pt_id,
+             locks,
+             vt,
+             index_hnsw.keep_max_size_level0 && (pt_level == 0));
 
      if (prev_display >= 0 && i - i0 > prev_display + 10000) {
          prev_display = i - i0;
@@ -248,7 +191,11 @@
      }
      i1 = i0;
  }
- FAISS_ASSERT(i1 == 0);
+ if (index_hnsw.init_level0) {
+     FAISS_ASSERT(i1 == 0);
+ } else {
+     FAISS_ASSERT((i1 - hist[0]) == 0);
+ }
  }
  if (verbose) {
      printf("Done in %.3f ms\n", getmillisecs() - t0);
@@ -297,7 +244,8 @@ void hnsw_search(
          const SearchParameters* params_in) {
  FAISS_THROW_IF_NOT_MSG(
          index->storage,
-         "Please use IndexHNSWFlat (or variants) instead of IndexHNSW directly");
+         "No storage index, please use IndexHNSWFlat (or variants) "
+         "instead of IndexHNSW directly");
  const SearchParametersHNSW* params = nullptr;
  const HNSW& hnsw = index->hnsw;
 
@@ -307,7 +255,7 @@
      FAISS_THROW_IF_NOT_MSG(params, "params type invalid");
      efSearch = params->efSearch;
  }
- size_t n1 = 0, n2 = 0, n3 = 0, ndis = 0, nreorder = 0;
+ size_t n1 = 0, n2 = 0, ndis = 0, nhops = 0;
 
  idx_t check_period = InterruptCallback::get_period_hint(
          hnsw.max_level * index->d * efSearch);
@@ -315,7 +263,7 @@
  for (idx_t i0 = 0; i0 < n; i0 += check_period) {
      idx_t i1 = std::min(i0 + check_period, n);
 
- #pragma omp parallel
+ #pragma omp parallel if (i1 - i0 > 1)
      {
          VisitedTable vt(index->ntotal);
          typename BlockResultHandler::SingleResultHandler res(bres);
@@ -323,7 +271,7 @@
          std::unique_ptr<DistanceComputer> dis(
                  storage_distance_computer(index->storage));
 
- #pragma omp for reduction(+ : n1, n2, n3, ndis, nreorder) schedule(guided)
+ #pragma omp for reduction(+ : n1, n2, ndis, nhops) schedule(guided)
          for (idx_t i = i0; i < i1; i++) {
              res.begin(i);
              dis->set_query(x + i * index->d);
@@ -331,16 +279,15 @@
              HNSWStats stats = hnsw.search(*dis, res, vt, params);
              n1 += stats.n1;
              n2 += stats.n2;
-             n3 += stats.n3;
              ndis += stats.ndis;
-             nreorder += stats.nreorder;
+             nhops += stats.nhops;
              res.end();
          }
      }
      InterruptCallback::check();
  }
 
- hnsw_stats.combine({n1, n2, n3, ndis, nreorder});
+ hnsw_stats.combine({n1, n2, ndis, nhops});
  }
 
  } // anonymous namespace
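The per-query statistics change shape in this release: the n3 and nreorder counters disappear and an nhops counter is added (presumably the number of graph edges followed); totals are still accumulated into the global hnsw_stats object. A minimal sketch of reading those counters after a search, assuming the global faiss::hnsw_stats and the field names shown in this hunk:

    #include <faiss/IndexHNSW.h>
    #include <faiss/impl/HNSW.h>

    #include <cstdio>
    #include <vector>

    int main() {
        int d = 32;
        faiss::IndexHNSWFlat index(d, 16); // M = 16
        std::vector<float> xb(1000 * d);
        for (size_t i = 0; i < xb.size(); i++)
            xb[i] = (i % 97) * 0.01f;
        index.add(1000, xb.data());

        std::vector<float> xq(d, 0.5f);
        std::vector<faiss::idx_t> I(10);
        std::vector<float> D(10);

        faiss::hnsw_stats.reset(); // clear the global counters
        index.search(1, xq.data(), 10, D.data(), I.data());
        // ndis: distance evaluations, nhops: edges traversed (per this diff)
        printf("ndis=%zu nhops=%zu\n",
               faiss::hnsw_stats.ndis,
               faiss::hnsw_stats.nhops);
    }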
@@ -374,7 +321,7 @@ void IndexHNSW::range_search(
          RangeSearchResult* result,
          const SearchParameters* params) const {
  using RH = RangeSearchBlockResultHandler<HNSW::C>;
- RH bres(result, radius);
+ RH bres(result, is_similarity_metric(metric_type) ? -radius : radius);
 
  hnsw_search(this, n, x, bres, params);
 
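With the negated radius above, range_search on a similarity-metric HNSW index takes the radius on the caller's scale, i.e. a minimum inner product, and only matches above it are kept. A minimal usage sketch, assuming the public range_search / RangeSearchResult API:

    #include <faiss/IndexHNSW.h>
    #include <faiss/impl/AuxIndexStructures.h>

    #include <vector>

    int main() {
        int d = 16;
        faiss::IndexHNSWFlat index(d, 32, faiss::METRIC_INNER_PRODUCT);
        std::vector<float> xb(500 * d);
        for (size_t i = 0; i < xb.size(); i++)
            xb[i] = (i % 13) * 0.1f;
        index.add(500, xb.data());

        std::vector<float> q(d, 0.2f);
        faiss::RangeSearchResult res(1);
        // For METRIC_INNER_PRODUCT the radius acts as a *minimum* similarity;
        // the handler negates it internally (see the hunk above).
        index.range_search(1, q.data(), 3.0f, &res);
    }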
@@ -453,10 +400,18 @@ void IndexHNSW::search_level_0(
          float* distances,
          idx_t* labels,
          int nprobe,
-         int search_type) const {
+         int search_type,
+         const SearchParameters* params_in) const {
  FAISS_THROW_IF_NOT(k > 0);
  FAISS_THROW_IF_NOT(nprobe > 0);
 
+ const SearchParametersHNSW* params = nullptr;
+
+ if (params_in) {
+     params = dynamic_cast<const SearchParametersHNSW*>(params_in);
+     FAISS_THROW_IF_NOT_MSG(params, "params type invalid");
+ }
+
  storage_idx_t ntotal = hnsw.levels.size();
 
  using RH = HeapBlockResultHandler<HNSW::C>;
@@ -483,13 +438,21 @@
                  nearest_d + i * nprobe,
                  search_type,
                  search_stats,
-                 vt);
+                 vt,
+                 params);
          res.end();
          vt.advance();
      }
  #pragma omp critical
      { hnsw_stats.combine(search_stats); }
  }
+ if (is_similarity_metric(this->metric_type)) {
+     // we need to revert the negated distances
+ #pragma omp parallel for
+     for (int64_t i = 0; i < k * n; i++) {
+         distances[i] = -distances[i];
+     }
+ }
  }
 
  void IndexHNSW::init_level_0_from_knngraph(
@@ -650,6 +613,10 @@ void IndexHNSW::permute_entries(const idx_t* perm) {
  hnsw.permute_entries(perm);
  }
 
+ DistanceComputer* IndexHNSW::get_distance_computer() const {
+     return storage->get_distance_computer();
+ }
+
  /**************************************************************
   * IndexHNSWFlat implementation
   **************************************************************/
@@ -673,8 +640,13 @@ IndexHNSWFlat::IndexHNSWFlat(int d, int M, MetricType metric)
 
  IndexHNSWPQ::IndexHNSWPQ() = default;
 
- IndexHNSWPQ::IndexHNSWPQ(int d, int pq_m, int M, int pq_nbits)
-         : IndexHNSW(new IndexPQ(d, pq_m, pq_nbits), M) {
+ IndexHNSWPQ::IndexHNSWPQ(
+         int d,
+         int pq_m,
+         int M,
+         int pq_nbits,
+         MetricType metric)
+         : IndexHNSW(new IndexPQ(d, pq_m, pq_nbits, metric), M) {
      own_fields = true;
      is_trained = false;
  }
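The extra constructor argument means an HNSW-over-PQ index can now be built directly for inner-product data. A minimal sketch using the new signature from this hunk:

    #include <faiss/IndexHNSW.h>

    #include <vector>

    int main() {
        int d = 64;
        // d = 64, 8 PQ sub-quantizers, HNSW M = 32, 8 bits per code, IP metric
        faiss::IndexHNSWPQ index(d, 8, 32, 8, faiss::METRIC_INNER_PRODUCT);

        std::vector<float> xb(10000 * d);
        for (size_t i = 0; i < xb.size(); i++)
            xb[i] = (i % 31) * 0.03f;

        index.train(10000, xb.data()); // the PQ codebooks need training first
        index.add(10000, xb.data());
    }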
@@ -800,7 +772,7 @@ void IndexHNSW2Level::search(
          IndexHNSW::search(n, x, k, distances, labels);
 
  } else { // "mixed" search
-     size_t n1 = 0, n2 = 0, n3 = 0, ndis = 0, nreorder = 0;
+     size_t n1 = 0, n2 = 0, ndis = 0, nhops = 0;
 
      const IndexIVFPQ* index_ivfpq =
              dynamic_cast<const IndexIVFPQ*>(storage);
@@ -829,10 +801,10 @@
          std::unique_ptr<DistanceComputer> dis(
                  storage_distance_computer(storage));
 
-         int candidates_size = hnsw.upper_beam;
+         constexpr int candidates_size = 1;
          MinimaxHeap candidates(candidates_size);
 
- #pragma omp for reduction(+ : n1, n2, n3, ndis, nreorder)
+ #pragma omp for reduction(+ : n1, n2, ndis, nhops)
          for (idx_t i = 0; i < n; i++) {
              idx_t* idxi = labels + i * k;
              float* simi = distances + i * k;
@@ -854,7 +826,7 @@
 
              candidates.clear();
 
-             for (int j = 0; j < hnsw.upper_beam && j < k; j++) {
+             for (int j = 0; j < k; j++) {
                  if (idxi[j] < 0)
                      break;
                  candidates.push(idxi[j], simi[j]);
@@ -877,9 +849,8 @@
                      k);
              n1 += search_stats.n1;
              n2 += search_stats.n2;
-             n3 += search_stats.n3;
              ndis += search_stats.ndis;
-             nreorder += search_stats.nreorder;
+             nhops += search_stats.nhops;
 
              vt.advance();
              vt.advance();
@@ -888,7 +859,7 @@
          }
      }
 
-     hnsw_stats.combine({n1, n2, n3, ndis, nreorder});
+     hnsw_stats.combine({n1, n2, ndis, nhops});
  }
  }
 
@@ -914,4 +885,86 @@ void IndexHNSW2Level::flip_to_ivf() {
      delete storage2l;
  }
 
+ /**************************************************************
+  * IndexHNSWCagra implementation
+  **************************************************************/
+
+ IndexHNSWCagra::IndexHNSWCagra() {
+     is_trained = true;
+ }
+
+ IndexHNSWCagra::IndexHNSWCagra(int d, int M, MetricType metric)
+         : IndexHNSW(
+                   (metric == METRIC_L2)
+                           ? static_cast<IndexFlat*>(new IndexFlatL2(d))
+                           : static_cast<IndexFlat*>(new IndexFlatIP(d)),
+                   M) {
+     FAISS_THROW_IF_NOT_MSG(
+             ((metric == METRIC_L2) || (metric == METRIC_INNER_PRODUCT)),
+             "unsupported metric type for IndexHNSWCagra");
+     own_fields = true;
+     is_trained = true;
+     init_level0 = true;
+     keep_max_size_level0 = true;
+ }
+
+ void IndexHNSWCagra::add(idx_t n, const float* x) {
+     FAISS_THROW_IF_NOT_MSG(
+             !base_level_only,
+             "Cannot add vectors when base_level_only is set to True");
+
+     IndexHNSW::add(n, x);
+ }
+
+ void IndexHNSWCagra::search(
+         idx_t n,
+         const float* x,
+         idx_t k,
+         float* distances,
+         idx_t* labels,
+         const SearchParameters* params) const {
+     if (!base_level_only) {
+         IndexHNSW::search(n, x, k, distances, labels, params);
+     } else {
+         std::vector<storage_idx_t> nearest(n);
+         std::vector<float> nearest_d(n);
+
+ #pragma omp for
+         for (idx_t i = 0; i < n; i++) {
+             std::unique_ptr<DistanceComputer> dis(
+                     storage_distance_computer(this->storage));
+             dis->set_query(x + i * d);
+             nearest[i] = -1;
+             nearest_d[i] = std::numeric_limits<float>::max();
+
+             std::random_device rd;
+             std::mt19937 gen(rd());
+             std::uniform_int_distribution<idx_t> distrib(0, this->ntotal - 1);
+
+             for (idx_t j = 0; j < num_base_level_search_entrypoints; j++) {
+                 auto idx = distrib(gen);
+                 auto distance = (*dis)(idx);
+                 if (distance < nearest_d[i]) {
+                     nearest[i] = idx;
+                     nearest_d[i] = distance;
+                 }
+             }
+             FAISS_THROW_IF_NOT_MSG(
+                     nearest[i] >= 0, "Could not find a valid entrypoint.");
+         }
+
+         search_level_0(
+                 n,
+                 x,
+                 k,
+                 nearest.data(),
+                 nearest_d.data(),
+                 distances,
+                 labels,
+                 1, // n_probes
+                 1, // search_type
+                 params);
+     }
+ }
+
  } // namespace faiss
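IndexHNSWCagra is the CPU-side counterpart of the new GpuIndexCagra (see GpuIndexCagra.h in the file list). Used on its own it behaves like an HNSW index whose base layer is kept at full width; a minimal CPU-only sketch based on the constructor and flags added above:

    #include <faiss/IndexHNSW.h>

    #include <vector>

    int main() {
        int d = 96;
        faiss::IndexHNSWCagra index(d, 16, faiss::METRIC_L2); // M = 16
        std::vector<float> xb(5000 * d);
        for (size_t i = 0; i < xb.size(); i++)
            xb[i] = (i % 47) * 0.02f;
        index.add(5000, xb.data()); // builds all levels (init_level0 = true)

        std::vector<float> xq(d, 0.5f);
        std::vector<faiss::idx_t> I(5);
        std::vector<float> D(5);
        index.search(1, xq.data(), 5, D.data(), I.data());

        // When the graph instead comes from GpuIndexCagra::copyTo and
        // base_level_only is true, add() throws and search() probes
        // num_base_level_search_entrypoints random entry points on level 0,
        // as implemented in the hunk above.
        return 0;
    }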
data/vendor/faiss/faiss/IndexHNSW.h

@@ -27,13 +27,25 @@ struct IndexHNSW;
  struct IndexHNSW : Index {
      typedef HNSW::storage_idx_t storage_idx_t;
 
-     // the link strcuture
+     // the link structure
      HNSW hnsw;
 
      // the sequential storage
      bool own_fields = false;
      Index* storage = nullptr;
 
+     // When set to false, level 0 in the knn graph is not initialized.
+     // This option is used by GpuIndexCagra::copyTo(IndexHNSWCagra*)
+     // as level 0 knn graph is copied over from the index built by
+     // GpuIndexCagra.
+     bool init_level0 = true;
+
+     // When set to true, all neighbors in level 0 are filled up
+     // to the maximum size allowed (2 * M). This option is used by
+     // IndexHHNSWCagra to create a full base layer graph that is
+     // used when GpuIndexCagra::copyFrom(IndexHNSWCagra*) is invoked.
+     bool keep_max_size_level0 = false;
+
      explicit IndexHNSW(int d = 0, int M = 32, MetricType metric = METRIC_L2);
      explicit IndexHNSW(Index* storage, int M = 32);
 
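These two flags exist for the GPU round trip described in the comments: GpuIndexCagra::copyTo fills level 0 of an IndexHNSWCagra with the CAGRA graph, and copyFrom expects a full-width base layer when converting back. A heavily hedged sketch of that flow; the GpuIndexCagra constructor and training call are assumptions based on the new faiss/gpu/GpuIndexCagra.h (+282 lines), not shown in this diff:

    #include <faiss/IndexHNSW.h>
    #include <faiss/gpu/GpuIndexCagra.h>
    #include <faiss/gpu/StandardGpuResources.h>

    #include <vector>

    int main() {
        int d = 64;
        std::vector<float> xb(100000 * d);
        for (size_t i = 0; i < xb.size(); i++)
            xb[i] = (i % 101) * 0.01f;

        faiss::gpu::StandardGpuResources res;
        // assumed constructor: (provider, dims, metric)
        faiss::gpu::GpuIndexCagra gpu_index(&res, d, faiss::METRIC_L2);
        gpu_index.train(100000, xb.data()); // builds the CAGRA graph on GPU

        // Serve the graph on CPU: copyTo writes the CAGRA graph into level 0,
        // which is what init_level0 / keep_max_size_level0 are for.
        faiss::IndexHNSWCagra cpu_index(d, 16, faiss::METRIC_L2);
        cpu_index.base_level_only = true; // keep only the copied base layer
        gpu_index.copyTo(&cpu_index);
        return 0;
    }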
@@ -81,7 +93,8 @@ struct IndexHNSW : Index {
          float* distances,
          idx_t* labels,
          int nprobe = 1,
-         int search_type = 1) const;
+         int search_type = 1,
+         const SearchParameters* params = nullptr) const;
 
  /// alternative graph building
  void init_level_0_from_knngraph(int k, const float* D, const idx_t* I);
@@ -98,6 +111,8 @@
  void link_singletons();
 
  void permute_entries(const idx_t* perm);
+
+ DistanceComputer* get_distance_computer() const override;
  };
 
  /** Flat index topped with with a HNSW structure to access elements
@@ -114,7 +129,12 @@ struct IndexHNSWFlat : IndexHNSW {
   */
  struct IndexHNSWPQ : IndexHNSW {
      IndexHNSWPQ();
-     IndexHNSWPQ(int d, int pq_m, int M, int pq_nbits = 8);
+     IndexHNSWPQ(
+             int d,
+             int pq_m,
+             int M,
+             int pq_nbits = 8,
+             MetricType metric = METRIC_L2);
      void train(idx_t n, const float* x) override;
  };
 
@@ -148,4 +168,33 @@ struct IndexHNSW2Level : IndexHNSW {
          const SearchParameters* params = nullptr) const override;
  };
 
+ struct IndexHNSWCagra : IndexHNSW {
+     IndexHNSWCagra();
+     IndexHNSWCagra(int d, int M, MetricType metric = METRIC_L2);
+
+     /// When set to true, the index is immutable.
+     /// This option is used to copy the knn graph from GpuIndexCagra
+     /// to the base level of IndexHNSWCagra without adding upper levels.
+     /// Doing so enables to search the HNSW index, but removes the
+     /// ability to add vectors.
+     bool base_level_only = false;
+
+     /// When `base_level_only` is set to `True`, the search function
+     /// searches only the base level knn graph of the HNSW index.
+     /// This parameter selects the entry point by randomly selecting
+     /// some points and using the best one.
+     int num_base_level_search_entrypoints = 32;
+
+     void add(idx_t n, const float* x) override;
+
+     /// entry point for search
+     void search(
+             idx_t n,
+             const float* x,
+             idx_t k,
+             float* distances,
+             idx_t* labels,
+             const SearchParameters* params = nullptr) const override;
+ };
+
  } // namespace faiss
data/vendor/faiss/faiss/IndexIVF.cpp

@@ -66,8 +66,8 @@ void Level1Quantizer::train_q1(
  } else if (quantizer_trains_alone == 1) {
      if (verbose)
          printf("IVF quantizer trains alone...\n");
-     quantizer->train(n, x);
      quantizer->verbose = verbose;
+     quantizer->train(n, x);
      FAISS_THROW_IF_NOT_MSG(
              quantizer->ntotal == nlist,
              "nlist not consistent with quantizer size");
@@ -444,7 +444,7 @@ void IndexIVF::search_preassigned(
      max_codes = unlimited_list_size;
  }
 
- bool do_parallel = omp_get_max_threads() >= 2 &&
+ [[maybe_unused]] bool do_parallel = omp_get_max_threads() >= 2 &&
          (pmode == 0 ? false
           : pmode == 3 ? n > 1
           : pmode == 1 ? nprobe > 1
@@ -784,7 +784,7 @@ void IndexIVF::range_search_preassigned(
 
  int pmode = this->parallel_mode & ~PARALLEL_MODE_NO_HEAP_INIT;
  // don't start parallel section if single query
- bool do_parallel = omp_get_max_threads() >= 2 &&
+ [[maybe_unused]] bool do_parallel = omp_get_max_threads() >= 2 &&
          (pmode == 3 ? false
           : pmode == 0 ? nx > 1
           : pmode == 1 ? nprobe > 1
data/vendor/faiss/faiss/IndexIVF.h

@@ -433,6 +433,14 @@ struct IndexIVF : Index, IndexIVFInterface {
 
  /* The standalone codec interface (except sa_decode that is specific) */
  size_t sa_code_size() const override;
+
+ /** encode a set of vectors
+  * sa_encode will call encode_vector with include_listno=true
+  * @param n       nb of vectors to encode
+  * @param x       the vectors to encode
+  * @param bytes   output array for the codes
+  * @return nb of bytes written to codes
+  */
  void sa_encode(idx_t n, const float* x, uint8_t* bytes) const override;
 
  IndexIVF();
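The standalone codec interface documented here works without storing vectors in the index: sa_code_size() gives the per-vector code size (the coarse list number is included for IVF indexes), sa_encode writes codes and sa_decode approximately inverts them. A minimal round-trip sketch on an IVFPQ index:

    #include <faiss/IndexFlat.h>
    #include <faiss/IndexIVFPQ.h>

    #include <cstdint>
    #include <vector>

    int main() {
        int d = 32;
        faiss::IndexFlatL2 quantizer(d);
        faiss::IndexIVFPQ index(&quantizer, d, 64, 4, 8); // nlist=64, PQ 4x8

        std::vector<float> xt(5000 * d);
        for (size_t i = 0; i < xt.size(); i++)
            xt[i] = (i % 23) * 0.05f;
        index.train(5000, xt.data());

        std::vector<float> x(10 * d, 0.3f);
        std::vector<uint8_t> codes(10 * index.sa_code_size());
        index.sa_encode(10, x.data(), codes.data()); // list number + PQ code
        std::vector<float> decoded(10 * d);
        index.sa_decode(10, codes.data(), decoded.data()); // lossy round trip
        return 0;
    }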
@@ -471,7 +479,7 @@ struct InvertedListScanner {
  virtual float distance_to_code(const uint8_t* code) const = 0;
 
  /** scan a set of codes, compute distances to current query and
-  * update heap of results if necessary. Default implemetation
+  * update heap of results if necessary. Default implementation
   * calls distance_to_code.
   *
   * @param n      number of codes to scan
data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp

@@ -116,6 +116,21 @@ void IndexIVFAdditiveQuantizer::sa_decode(
      }
  }
 
+ void IndexIVFAdditiveQuantizer::reconstruct_from_offset(
+         int64_t list_no,
+         int64_t offset,
+         float* recons) const {
+     const uint8_t* code = invlists->get_single_code(list_no, offset);
+     aq->decode(code, recons, 1);
+     if (by_residual) {
+         std::vector<float> centroid(d);
+         quantizer->reconstruct(list_no, centroid.data());
+         for (int i = 0; i < d; ++i) {
+             recons[i] += centroid[i];
+         }
+     }
+ }
+
  IndexIVFAdditiveQuantizer::~IndexIVFAdditiveQuantizer() = default;
 
  /*********************************************
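reconstruct_from_offset is the per-vector hook behind IndexIVF::reconstruct and reconstruct_n; with this addition, additive-quantizer IVF indexes decode the stored code and, when by_residual is set, add the coarse centroid back. A sketch of the user-facing path, assuming the (quantizer, d, nlist, M, nbits) constructor of IndexIVFResidualQuantizer:

    #include <faiss/IndexFlat.h>
    #include <faiss/IndexIVFAdditiveQuantizer.h>

    #include <vector>

    int main() {
        int d = 32;
        faiss::IndexFlatL2 quantizer(d);
        // assumed ctor: nlist = 64, 4 residual quantizer steps of 8 bits each
        faiss::IndexIVFResidualQuantizer index(&quantizer, d, 64, 4, 8);

        std::vector<float> xb(5000 * d);
        for (size_t i = 0; i < xb.size(); i++)
            xb[i] = (i % 29) * 0.04f;
        index.train(5000, xb.data());
        index.add(5000, xb.data());

        index.make_direct_map(); // id -> (list_no, offset), needed by reconstruct
        std::vector<float> rec(d);
        index.reconstruct(123, rec.data()); // dispatches to reconstruct_from_offset
        return 0;
    }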
data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h

@@ -56,6 +56,9 @@ struct IndexIVFAdditiveQuantizer : IndexIVF {
 
  void sa_decode(idx_t n, const uint8_t* codes, float* x) const override;
 
+ void reconstruct_from_offset(int64_t list_no, int64_t offset, float* recons)
+         const override;
+
  ~IndexIVFAdditiveQuantizer() override;
  };