faiss 0.2.3 → 0.2.5

Sign up to get free protection for your applications and to get access to all the features.
Files changed (189)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +9 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +23 -21
  5. data/ext/faiss/extconf.rb +11 -0
  6. data/ext/faiss/index.cpp +4 -4
  7. data/ext/faiss/index_binary.cpp +6 -6
  8. data/ext/faiss/product_quantizer.cpp +4 -4
  9. data/lib/faiss/version.rb +1 -1
  10. data/vendor/faiss/faiss/AutoTune.cpp +13 -0
  11. data/vendor/faiss/faiss/Clustering.cpp +32 -0
  12. data/vendor/faiss/faiss/Clustering.h +14 -0
  13. data/vendor/faiss/faiss/IVFlib.cpp +101 -2
  14. data/vendor/faiss/faiss/IVFlib.h +26 -2
  15. data/vendor/faiss/faiss/Index.cpp +36 -3
  16. data/vendor/faiss/faiss/Index.h +43 -6
  17. data/vendor/faiss/faiss/Index2Layer.cpp +24 -93
  18. data/vendor/faiss/faiss/Index2Layer.h +8 -17
  19. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +610 -0
  20. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +253 -0
  21. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +299 -0
  22. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +199 -0
  23. data/vendor/faiss/faiss/IndexBinary.cpp +20 -4
  24. data/vendor/faiss/faiss/IndexBinary.h +18 -3
  25. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +9 -2
  26. data/vendor/faiss/faiss/IndexBinaryFlat.h +4 -2
  27. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +4 -1
  28. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +2 -1
  29. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +5 -1
  30. data/vendor/faiss/faiss/IndexBinaryHNSW.h +2 -1
  31. data/vendor/faiss/faiss/IndexBinaryHash.cpp +17 -4
  32. data/vendor/faiss/faiss/IndexBinaryHash.h +8 -4
  33. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +28 -13
  34. data/vendor/faiss/faiss/IndexBinaryIVF.h +10 -7
  35. data/vendor/faiss/faiss/IndexFastScan.cpp +626 -0
  36. data/vendor/faiss/faiss/IndexFastScan.h +145 -0
  37. data/vendor/faiss/faiss/IndexFlat.cpp +52 -69
  38. data/vendor/faiss/faiss/IndexFlat.h +16 -19
  39. data/vendor/faiss/faiss/IndexFlatCodes.cpp +101 -0
  40. data/vendor/faiss/faiss/IndexFlatCodes.h +59 -0
  41. data/vendor/faiss/faiss/IndexHNSW.cpp +66 -138
  42. data/vendor/faiss/faiss/IndexHNSW.h +4 -2
  43. data/vendor/faiss/faiss/IndexIDMap.cpp +247 -0
  44. data/vendor/faiss/faiss/IndexIDMap.h +107 -0
  45. data/vendor/faiss/faiss/IndexIVF.cpp +200 -40
  46. data/vendor/faiss/faiss/IndexIVF.h +59 -22
  47. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +393 -0
  48. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +183 -0
  49. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +590 -0
  50. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +171 -0
  51. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +1290 -0
  52. data/vendor/faiss/faiss/IndexIVFFastScan.h +213 -0
  53. data/vendor/faiss/faiss/IndexIVFFlat.cpp +43 -26
  54. data/vendor/faiss/faiss/IndexIVFFlat.h +4 -2
  55. data/vendor/faiss/faiss/IndexIVFPQ.cpp +238 -53
  56. data/vendor/faiss/faiss/IndexIVFPQ.h +6 -2
  57. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +23 -852
  58. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +7 -112
  59. data/vendor/faiss/faiss/IndexIVFPQR.cpp +3 -3
  60. data/vendor/faiss/faiss/IndexIVFPQR.h +1 -1
  61. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +63 -40
  62. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +23 -7
  63. data/vendor/faiss/faiss/IndexLSH.cpp +8 -32
  64. data/vendor/faiss/faiss/IndexLSH.h +4 -16
  65. data/vendor/faiss/faiss/IndexLattice.cpp +7 -1
  66. data/vendor/faiss/faiss/IndexLattice.h +3 -1
  67. data/vendor/faiss/faiss/IndexNNDescent.cpp +4 -5
  68. data/vendor/faiss/faiss/IndexNNDescent.h +2 -1
  69. data/vendor/faiss/faiss/IndexNSG.cpp +37 -5
  70. data/vendor/faiss/faiss/IndexNSG.h +25 -1
  71. data/vendor/faiss/faiss/IndexPQ.cpp +108 -120
  72. data/vendor/faiss/faiss/IndexPQ.h +21 -22
  73. data/vendor/faiss/faiss/IndexPQFastScan.cpp +15 -450
  74. data/vendor/faiss/faiss/IndexPQFastScan.h +15 -78
  75. data/vendor/faiss/faiss/IndexPreTransform.cpp +47 -8
  76. data/vendor/faiss/faiss/IndexPreTransform.h +15 -3
  77. data/vendor/faiss/faiss/IndexRefine.cpp +36 -4
  78. data/vendor/faiss/faiss/IndexRefine.h +14 -2
  79. data/vendor/faiss/faiss/IndexReplicas.cpp +4 -2
  80. data/vendor/faiss/faiss/IndexReplicas.h +2 -1
  81. data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +438 -0
  82. data/vendor/faiss/faiss/IndexRowwiseMinMax.h +92 -0
  83. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +28 -43
  84. data/vendor/faiss/faiss/IndexScalarQuantizer.h +8 -23
  85. data/vendor/faiss/faiss/IndexShards.cpp +4 -1
  86. data/vendor/faiss/faiss/IndexShards.h +2 -1
  87. data/vendor/faiss/faiss/MetaIndexes.cpp +5 -178
  88. data/vendor/faiss/faiss/MetaIndexes.h +3 -81
  89. data/vendor/faiss/faiss/VectorTransform.cpp +45 -1
  90. data/vendor/faiss/faiss/VectorTransform.h +25 -4
  91. data/vendor/faiss/faiss/clone_index.cpp +26 -3
  92. data/vendor/faiss/faiss/clone_index.h +3 -0
  93. data/vendor/faiss/faiss/cppcontrib/SaDecodeKernels.h +300 -0
  94. data/vendor/faiss/faiss/cppcontrib/detail/CoarseBitType.h +24 -0
  95. data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +195 -0
  96. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +2058 -0
  97. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +408 -0
  98. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-neon-inl.h +2147 -0
  99. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMax-inl.h +460 -0
  100. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMaxFP16-inl.h +465 -0
  101. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +1618 -0
  102. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +251 -0
  103. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-neon-inl.h +1452 -0
  104. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +1 -0
  105. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +2 -6
  106. data/vendor/faiss/faiss/gpu/GpuIcmEncoder.h +60 -0
  107. data/vendor/faiss/faiss/gpu/GpuIndex.h +28 -4
  108. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +2 -1
  109. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +10 -8
  110. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +75 -14
  111. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +19 -32
  112. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -31
  113. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +22 -28
  114. data/vendor/faiss/faiss/gpu/GpuResources.cpp +14 -0
  115. data/vendor/faiss/faiss/gpu/GpuResources.h +16 -3
  116. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +3 -3
  117. data/vendor/faiss/faiss/gpu/impl/IndexUtils.h +32 -0
  118. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +1 -0
  119. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +311 -75
  120. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +10 -0
  121. data/vendor/faiss/faiss/gpu/test/TestUtils.h +3 -0
  122. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +2 -2
  123. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +5 -4
  124. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +331 -29
  125. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +110 -19
  126. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +0 -54
  127. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +0 -76
  128. data/vendor/faiss/faiss/impl/DistanceComputer.h +64 -0
  129. data/vendor/faiss/faiss/impl/HNSW.cpp +133 -32
  130. data/vendor/faiss/faiss/impl/HNSW.h +19 -16
  131. data/vendor/faiss/faiss/impl/IDSelector.cpp +125 -0
  132. data/vendor/faiss/faiss/impl/IDSelector.h +135 -0
  133. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +378 -217
  134. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +106 -29
  135. data/vendor/faiss/faiss/impl/LookupTableScaler.h +77 -0
  136. data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -0
  137. data/vendor/faiss/faiss/impl/NSG.cpp +1 -4
  138. data/vendor/faiss/faiss/impl/NSG.h +1 -1
  139. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +383 -0
  140. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +154 -0
  141. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +225 -145
  142. data/vendor/faiss/faiss/impl/ProductQuantizer.h +29 -10
  143. data/vendor/faiss/faiss/impl/Quantizer.h +43 -0
  144. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +521 -55
  145. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +94 -16
  146. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -0
  147. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +108 -191
  148. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +18 -18
  149. data/vendor/faiss/faiss/impl/index_read.cpp +338 -24
  150. data/vendor/faiss/faiss/impl/index_write.cpp +300 -18
  151. data/vendor/faiss/faiss/impl/io.cpp +1 -1
  152. data/vendor/faiss/faiss/impl/io_macros.h +20 -0
  153. data/vendor/faiss/faiss/impl/kmeans1d.cpp +303 -0
  154. data/vendor/faiss/faiss/impl/kmeans1d.h +48 -0
  155. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +56 -16
  156. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +25 -8
  157. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +66 -25
  158. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +75 -27
  159. data/vendor/faiss/faiss/index_factory.cpp +772 -412
  160. data/vendor/faiss/faiss/index_factory.h +3 -0
  161. data/vendor/faiss/faiss/index_io.h +5 -0
  162. data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -0
  163. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +4 -1
  164. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +2 -1
  165. data/vendor/faiss/faiss/python/python_callbacks.cpp +27 -0
  166. data/vendor/faiss/faiss/python/python_callbacks.h +15 -0
  167. data/vendor/faiss/faiss/utils/Heap.h +31 -15
  168. data/vendor/faiss/faiss/utils/distances.cpp +384 -58
  169. data/vendor/faiss/faiss/utils/distances.h +149 -18
  170. data/vendor/faiss/faiss/utils/distances_simd.cpp +776 -6
  171. data/vendor/faiss/faiss/utils/extra_distances.cpp +12 -7
  172. data/vendor/faiss/faiss/utils/extra_distances.h +3 -1
  173. data/vendor/faiss/faiss/utils/fp16-fp16c.h +21 -0
  174. data/vendor/faiss/faiss/utils/fp16-inl.h +101 -0
  175. data/vendor/faiss/faiss/utils/fp16.h +11 -0
  176. data/vendor/faiss/faiss/utils/hamming-inl.h +54 -0
  177. data/vendor/faiss/faiss/utils/hamming.cpp +0 -48
  178. data/vendor/faiss/faiss/utils/ordered_key_value.h +10 -0
  179. data/vendor/faiss/faiss/utils/quantize_lut.cpp +62 -0
  180. data/vendor/faiss/faiss/utils/quantize_lut.h +20 -0
  181. data/vendor/faiss/faiss/utils/random.cpp +53 -0
  182. data/vendor/faiss/faiss/utils/random.h +5 -0
  183. data/vendor/faiss/faiss/utils/simdlib_avx2.h +4 -0
  184. data/vendor/faiss/faiss/utils/simdlib_emulated.h +6 -1
  185. data/vendor/faiss/faiss/utils/simdlib_neon.h +7 -2
  186. data/vendor/faiss/faiss/utils/utils.h +1 -1
  187. metadata +46 -5
  188. data/vendor/faiss/faiss/IndexResidual.cpp +0 -291
  189. data/vendor/faiss/faiss/IndexResidual.h +0 -152
@@ -202,7 +202,10 @@ void hnsw_add_vertices(
202
202
  verbose && omp_get_thread_num() == 0 ? 0 : -1;
203
203
  size_t counter = 0;
204
204
 
205
- #pragma omp for schedule(dynamic)
205
+ // here we should do schedule(dynamic) but this segfaults for
206
+ // some versions of LLVM. The performance impact should not be
207
+ // too large when (i1 - i0) / num_threads >> 1
208
+ #pragma omp for schedule(static)
206
209
  for (int i = i0; i < i1; i++) {
207
210
  storage_idx_t pt_id = order[i];
208
211
  dis->set_query(x + (pt_id - n0) * d);
@@ -219,7 +222,6 @@ void hnsw_add_vertices(
219
222
  printf(" %d / %d\r", i - i0, i1 - i0);
220
223
  fflush(stdout);
221
224
  }
222
-
223
225
  if (counter % check_period == 0) {
224
226
  if (InterruptCallback::is_interrupted()) {
225
227
  interrupt = true;
@@ -284,18 +286,24 @@ void IndexHNSW::search(
284
286
  const float* x,
285
287
  idx_t k,
286
288
  float* distances,
287
- idx_t* labels) const
288
-
289
- {
289
+ idx_t* labels,
290
+ const SearchParameters* params_in) const {
290
291
  FAISS_THROW_IF_NOT(k > 0);
291
-
292
292
  FAISS_THROW_IF_NOT_MSG(
293
293
  storage,
294
294
  "Please use IndexHNSWFlat (or variants) instead of IndexHNSW directly");
295
+ const SearchParametersHNSW* params = nullptr;
296
+
297
+ int efSearch = hnsw.efSearch;
298
+ if (params_in) {
299
+ params = dynamic_cast<const SearchParametersHNSW*>(params_in);
300
+ FAISS_THROW_IF_NOT_MSG(params, "params type invalid");
301
+ efSearch = params->efSearch;
302
+ }
295
303
  size_t n1 = 0, n2 = 0, n3 = 0, ndis = 0, nreorder = 0;
296
304
 
297
- idx_t check_period = InterruptCallback::get_period_hint(
298
- hnsw.max_level * d * hnsw.efSearch);
305
+ idx_t check_period =
306
+ InterruptCallback::get_period_hint(hnsw.max_level * d * efSearch);
299
307
 
300
308
  for (idx_t i0 = 0; i0 < n; i0 += check_period) {
301
309
  idx_t i1 = std::min(i0 + check_period, n);
@@ -314,7 +322,7 @@ void IndexHNSW::search(
314
322
  dis->set_query(x + i * d);
315
323
 
316
324
  maxheap_heapify(k, simi, idxi);
317
- HNSWStats stats = hnsw.search(*dis, k, idxi, simi, vt);
325
+ HNSWStats stats = hnsw.search(*dis, k, idxi, simi, vt, params);
318
326
  n1 += stats.n1;
319
327
  n2 += stats.n2;
320
328
  n3 += stats.n3;
@@ -423,16 +431,15 @@ void IndexHNSW::search_level_0(
423
431
  FAISS_THROW_IF_NOT(nprobe > 0);
424
432
 
425
433
  storage_idx_t ntotal = hnsw.levels.size();
426
- size_t n1 = 0, n2 = 0, n3 = 0, ndis = 0, nreorder = 0;
427
434
 
428
435
  #pragma omp parallel
429
436
  {
430
- DistanceComputer* qdis = storage_distance_computer(storage);
431
- ScopeDeleter1<DistanceComputer> del(qdis);
432
-
437
+ std::unique_ptr<DistanceComputer> qdis(
438
+ storage_distance_computer(storage));
439
+ HNSWStats search_stats;
433
440
  VisitedTable vt(ntotal);
434
441
 
435
- #pragma omp for reduction(+ : n1, n2, n3, ndis, nreorder)
442
+ #pragma omp for
436
443
  for (idx_t i = 0; i < n; i++) {
437
444
  idx_t* idxi = labels + i * k;
438
445
  float* simi = distances + i * k;
@@ -440,69 +447,24 @@ void IndexHNSW::search_level_0(
440
447
  qdis->set_query(x + i * d);
441
448
  maxheap_heapify(k, simi, idxi);
442
449
 
443
- if (search_type == 1) {
444
- int nres = 0;
445
-
446
- for (int j = 0; j < nprobe; j++) {
447
- storage_idx_t cj = nearest[i * nprobe + j];
448
-
449
- if (cj < 0)
450
- break;
451
-
452
- if (vt.get(cj))
453
- continue;
454
-
455
- int candidates_size = std::max(hnsw.efSearch, int(k));
456
- MinimaxHeap candidates(candidates_size);
457
-
458
- candidates.push(cj, nearest_d[i * nprobe + j]);
459
-
460
- HNSWStats search_stats;
461
- nres = hnsw.search_from_candidates(
462
- *qdis,
463
- k,
464
- idxi,
465
- simi,
466
- candidates,
467
- vt,
468
- search_stats,
469
- 0,
470
- nres);
471
- n1 += search_stats.n1;
472
- n2 += search_stats.n2;
473
- n3 += search_stats.n3;
474
- ndis += search_stats.ndis;
475
- nreorder += search_stats.nreorder;
476
- }
477
- } else if (search_type == 2) {
478
- int candidates_size = std::max(hnsw.efSearch, int(k));
479
- candidates_size = std::max(candidates_size, nprobe);
480
-
481
- MinimaxHeap candidates(candidates_size);
482
- for (int j = 0; j < nprobe; j++) {
483
- storage_idx_t cj = nearest[i * nprobe + j];
484
-
485
- if (cj < 0)
486
- break;
487
- candidates.push(cj, nearest_d[i * nprobe + j]);
488
- }
450
+ hnsw.search_level_0(
451
+ *qdis.get(),
452
+ k,
453
+ idxi,
454
+ simi,
455
+ nprobe,
456
+ nearest + i * nprobe,
457
+ nearest_d + i * nprobe,
458
+ search_type,
459
+ search_stats,
460
+ vt);
489
461
 
490
- HNSWStats search_stats;
491
- hnsw.search_from_candidates(
492
- *qdis, k, idxi, simi, candidates, vt, search_stats, 0);
493
- n1 += search_stats.n1;
494
- n2 += search_stats.n2;
495
- n3 += search_stats.n3;
496
- ndis += search_stats.ndis;
497
- nreorder += search_stats.nreorder;
498
- }
499
462
  vt.advance();
500
-
501
463
  maxheap_reorder(k, simi, idxi);
502
464
  }
465
+ #pragma omp critical
466
+ { hnsw_stats.combine(search_stats); }
503
467
  }
504
-
505
- hnsw_stats.combine({n1, n2, n3, ndis, nreorder});
506
468
  }
507
469
 
508
470
  void IndexHNSW::init_level_0_from_knngraph(
@@ -1035,8 +997,11 @@ void IndexHNSW2Level::search(
1035
997
  const float* x,
1036
998
  idx_t k,
1037
999
  float* distances,
1038
- idx_t* labels) const {
1000
+ idx_t* labels,
1001
+ const SearchParameters* params) const {
1039
1002
  FAISS_THROW_IF_NOT(k > 0);
1003
+ FAISS_THROW_IF_NOT_MSG(
1004
+ !params, "search params not supported for this index");
1040
1005
 
1041
1006
  if (dynamic_cast<const Index2Layer*>(storage)) {
1042
1007
  IndexHNSW::search(n, x, k, distances, labels);
@@ -1095,74 +1060,37 @@ void IndexHNSW2Level::search(
1095
1060
  }
1096
1061
 
1097
1062
  candidates.clear();
1098
- // copy the upper_beam elements to candidates list
1099
-
1100
- int search_policy = 2;
1101
-
1102
- if (search_policy == 1) {
1103
- for (int j = 0; j < hnsw.upper_beam && j < k; j++) {
1104
- if (idxi[j] < 0)
1105
- break;
1106
- candidates.push(idxi[j], simi[j]);
1107
- // search_from_candidates adds them back
1108
- idxi[j] = -1;
1109
- simi[j] = HUGE_VAL;
1110
- }
1111
1063
 
1112
- // reorder from sorted to heap
1113
- maxheap_heapify(k, simi, idxi, simi, idxi, k);
1114
-
1115
- HNSWStats search_stats;
1116
- hnsw.search_from_candidates(
1117
- *dis,
1118
- k,
1119
- idxi,
1120
- simi,
1121
- candidates,
1122
- vt,
1123
- search_stats,
1124
- 0,
1125
- k);
1126
- n1 += search_stats.n1;
1127
- n2 += search_stats.n2;
1128
- n3 += search_stats.n3;
1129
- ndis += search_stats.ndis;
1130
- nreorder += search_stats.nreorder;
1131
-
1132
- vt.advance();
1133
-
1134
- } else if (search_policy == 2) {
1135
- for (int j = 0; j < hnsw.upper_beam && j < k; j++) {
1136
- if (idxi[j] < 0)
1137
- break;
1138
- candidates.push(idxi[j], simi[j]);
1139
- }
1140
-
1141
- // reorder from sorted to heap
1142
- maxheap_heapify(k, simi, idxi, simi, idxi, k);
1143
-
1144
- HNSWStats search_stats;
1145
- search_from_candidates_2(
1146
- hnsw,
1147
- *dis,
1148
- k,
1149
- idxi,
1150
- simi,
1151
- candidates,
1152
- vt,
1153
- search_stats,
1154
- 0,
1155
- k);
1156
- n1 += search_stats.n1;
1157
- n2 += search_stats.n2;
1158
- n3 += search_stats.n3;
1159
- ndis += search_stats.ndis;
1160
- nreorder += search_stats.nreorder;
1161
-
1162
- vt.advance();
1163
- vt.advance();
1064
+ for (int j = 0; j < hnsw.upper_beam && j < k; j++) {
1065
+ if (idxi[j] < 0)
1066
+ break;
1067
+ candidates.push(idxi[j], simi[j]);
1164
1068
  }
1165
1069
 
1070
+ // reorder from sorted to heap
1071
+ maxheap_heapify(k, simi, idxi, simi, idxi, k);
1072
+
1073
+ HNSWStats search_stats;
1074
+ search_from_candidates_2(
1075
+ hnsw,
1076
+ *dis,
1077
+ k,
1078
+ idxi,
1079
+ simi,
1080
+ candidates,
1081
+ vt,
1082
+ search_stats,
1083
+ 0,
1084
+ k);
1085
+ n1 += search_stats.n1;
1086
+ n2 += search_stats.n2;
1087
+ n3 += search_stats.n3;
1088
+ ndis += search_stats.ndis;
1089
+ nreorder += search_stats.nreorder;
1090
+
1091
+ vt.advance();
1092
+ vt.advance();
1093
+
1166
1094
  maxheap_reorder(k, simi, idxi);
1167
1095
  }
1168
1096
  }
@@ -96,7 +96,8 @@ struct IndexHNSW : Index {
96
96
  const float* x,
97
97
  idx_t k,
98
98
  float* distances,
99
- idx_t* labels) const override;
99
+ idx_t* labels,
100
+ const SearchParameters* params = nullptr) const override;
100
101
 
101
102
  void reconstruct(idx_t key, float* recons) const override;
102
103
 
@@ -180,7 +181,8 @@ struct IndexHNSW2Level : IndexHNSW {
180
181
  const float* x,
181
182
  idx_t k,
182
183
  float* distances,
183
- idx_t* labels) const override;
184
+ idx_t* labels,
185
+ const SearchParameters* params = nullptr) const override;
184
186
  };
185
187
 
186
188
  } // namespace faiss
@@ -0,0 +1,247 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ // -*- c++ -*-
9
+
10
+ #include <faiss/IndexIDMap.h>
11
+
12
+ #include <stdint.h>
13
+ #include <cinttypes>
14
+ #include <cstdio>
15
+ #include <limits>
16
+
17
+ #include <faiss/impl/AuxIndexStructures.h>
18
+ #include <faiss/impl/FaissAssert.h>
19
+ #include <faiss/impl/IDSelector.h>
20
+ #include <faiss/utils/Heap.h>
21
+ #include <faiss/utils/WorkerThread.h>
22
+
23
+ namespace faiss {
24
+
25
+ /*****************************************************
26
+ * IndexIDMap implementation
27
+ *******************************************************/
28
+
29
+ template <typename IndexT>
30
+ IndexIDMapTemplate<IndexT>::IndexIDMapTemplate(IndexT* index)
31
+ : index(index), own_fields(false) {
32
+ FAISS_THROW_IF_NOT_MSG(index->ntotal == 0, "index must be empty on input");
33
+ this->is_trained = index->is_trained;
34
+ this->metric_type = index->metric_type;
35
+ this->verbose = index->verbose;
36
+ this->d = index->d;
37
+ }
38
+
39
+ template <typename IndexT>
40
+ void IndexIDMapTemplate<IndexT>::add(
41
+ idx_t,
42
+ const typename IndexT::component_t*) {
43
+ FAISS_THROW_MSG(
44
+ "add does not make sense with IndexIDMap, "
45
+ "use add_with_ids");
46
+ }
47
+
48
+ template <typename IndexT>
49
+ void IndexIDMapTemplate<IndexT>::train(
50
+ idx_t n,
51
+ const typename IndexT::component_t* x) {
52
+ index->train(n, x);
53
+ this->is_trained = index->is_trained;
54
+ }
55
+
56
+ template <typename IndexT>
57
+ void IndexIDMapTemplate<IndexT>::reset() {
58
+ index->reset();
59
+ id_map.clear();
60
+ this->ntotal = 0;
61
+ }
62
+
63
+ template <typename IndexT>
64
+ void IndexIDMapTemplate<IndexT>::add_with_ids(
65
+ idx_t n,
66
+ const typename IndexT::component_t* x,
67
+ const typename IndexT::idx_t* xids) {
68
+ index->add(n, x);
69
+ for (idx_t i = 0; i < n; i++)
70
+ id_map.push_back(xids[i]);
71
+ this->ntotal = index->ntotal;
72
+ }
73
+
74
+ template <typename IndexT>
75
+ void IndexIDMapTemplate<IndexT>::search(
76
+ idx_t n,
77
+ const typename IndexT::component_t* x,
78
+ idx_t k,
79
+ typename IndexT::distance_t* distances,
80
+ typename IndexT::idx_t* labels,
81
+ const SearchParameters* params) const {
82
+ FAISS_THROW_IF_NOT_MSG(
83
+ !params, "search params not supported for this index");
84
+ index->search(n, x, k, distances, labels);
85
+ idx_t* li = labels;
86
+ #pragma omp parallel for
87
+ for (idx_t i = 0; i < n * k; i++) {
88
+ li[i] = li[i] < 0 ? li[i] : id_map[li[i]];
89
+ }
90
+ }
91
+
92
+ template <typename IndexT>
93
+ void IndexIDMapTemplate<IndexT>::range_search(
94
+ typename IndexT::idx_t n,
95
+ const typename IndexT::component_t* x,
96
+ typename IndexT::distance_t radius,
97
+ RangeSearchResult* result,
98
+ const SearchParameters* params) const {
99
+ FAISS_THROW_IF_NOT_MSG(
100
+ !params, "search params not supported for this index");
101
+ index->range_search(n, x, radius, result);
102
+ #pragma omp parallel for
103
+ for (idx_t i = 0; i < result->lims[result->nq]; i++) {
104
+ result->labels[i] = result->labels[i] < 0 ? result->labels[i]
105
+ : id_map[result->labels[i]];
106
+ }
107
+ }
108
+
109
+ namespace {
110
+
111
+ struct IDTranslatedSelector : IDSelector {
112
+ const std::vector<int64_t>& id_map;
113
+ const IDSelector& sel;
114
+ IDTranslatedSelector(
115
+ const std::vector<int64_t>& id_map,
116
+ const IDSelector& sel)
117
+ : id_map(id_map), sel(sel) {}
118
+ bool is_member(idx_t id) const override {
119
+ return sel.is_member(id_map[id]);
120
+ }
121
+ };
122
+
123
+ } // namespace
124
+
125
+ template <typename IndexT>
126
+ size_t IndexIDMapTemplate<IndexT>::remove_ids(const IDSelector& sel) {
127
+ // remove in sub-index first
128
+ IDTranslatedSelector sel2(id_map, sel);
129
+ size_t nremove = index->remove_ids(sel2);
130
+
131
+ int64_t j = 0;
132
+ for (idx_t i = 0; i < this->ntotal; i++) {
133
+ if (sel.is_member(id_map[i])) {
134
+ // remove
135
+ } else {
136
+ id_map[j] = id_map[i];
137
+ j++;
138
+ }
139
+ }
140
+ FAISS_ASSERT(j == index->ntotal);
141
+ this->ntotal = j;
142
+ id_map.resize(this->ntotal);
143
+ return nremove;
144
+ }
145
+
146
+ template <typename IndexT>
147
+ void IndexIDMapTemplate<IndexT>::check_compatible_for_merge(
148
+ const IndexT& otherIndex) const {
149
+ auto other = dynamic_cast<const IndexIDMapTemplate<IndexT>*>(&otherIndex);
150
+ FAISS_THROW_IF_NOT(other);
151
+ index->check_compatible_for_merge(*other->index);
152
+ }
153
+
154
+ template <typename IndexT>
155
+ void IndexIDMapTemplate<IndexT>::merge_from(IndexT& otherIndex, idx_t add_id) {
156
+ check_compatible_for_merge(otherIndex);
157
+ auto other = static_cast<IndexIDMapTemplate<IndexT>*>(&otherIndex);
158
+ index->merge_from(*other->index);
159
+ for (size_t i = 0; i < other->id_map.size(); i++) {
160
+ id_map.push_back(other->id_map[i] + add_id);
161
+ }
162
+ other->id_map.resize(0);
163
+ this->ntotal = index->ntotal;
164
+ other->ntotal = 0;
165
+ }
166
+
167
+ template <typename IndexT>
168
+ IndexIDMapTemplate<IndexT>::~IndexIDMapTemplate() {
169
+ if (own_fields)
170
+ delete index;
171
+ }
172
+
173
+ /*****************************************************
174
+ * IndexIDMap2 implementation
175
+ *******************************************************/
176
+
177
+ template <typename IndexT>
178
+ IndexIDMap2Template<IndexT>::IndexIDMap2Template(IndexT* index)
179
+ : IndexIDMapTemplate<IndexT>(index) {}
180
+
181
+ template <typename IndexT>
182
+ void IndexIDMap2Template<IndexT>::add_with_ids(
183
+ idx_t n,
184
+ const typename IndexT::component_t* x,
185
+ const typename IndexT::idx_t* xids) {
186
+ size_t prev_ntotal = this->ntotal;
187
+ IndexIDMapTemplate<IndexT>::add_with_ids(n, x, xids);
188
+ for (size_t i = prev_ntotal; i < this->ntotal; i++) {
189
+ rev_map[this->id_map[i]] = i;
190
+ }
191
+ }
192
+
193
+ template <typename IndexT>
194
+ void IndexIDMap2Template<IndexT>::check_consistency() const {
195
+ FAISS_THROW_IF_NOT(rev_map.size() == this->id_map.size());
196
+ FAISS_THROW_IF_NOT(this->id_map.size() == this->ntotal);
197
+ for (size_t i = 0; i < this->ntotal; i++) {
198
+ idx_t ii = rev_map.at(this->id_map[i]);
199
+ FAISS_THROW_IF_NOT(ii == i);
200
+ }
201
+ }
202
+
203
+ template <typename IndexT>
204
+ void IndexIDMap2Template<IndexT>::merge_from(IndexT& otherIndex, idx_t add_id) {
205
+ size_t prev_ntotal = this->ntotal;
206
+ IndexIDMapTemplate<IndexT>::merge_from(otherIndex, add_id);
207
+ for (size_t i = prev_ntotal; i < this->ntotal; i++) {
208
+ rev_map[this->id_map[i]] = i;
209
+ }
210
+ static_cast<IndexIDMap2Template<IndexT>&>(otherIndex).rev_map.clear();
211
+ }
212
+
213
+ template <typename IndexT>
214
+ void IndexIDMap2Template<IndexT>::construct_rev_map() {
215
+ rev_map.clear();
216
+ for (size_t i = 0; i < this->ntotal; i++) {
217
+ rev_map[this->id_map[i]] = i;
218
+ }
219
+ }
220
+
221
+ template <typename IndexT>
222
+ size_t IndexIDMap2Template<IndexT>::remove_ids(const IDSelector& sel) {
223
+ // This is quite inefficient
224
+ size_t nremove = IndexIDMapTemplate<IndexT>::remove_ids(sel);
225
+ construct_rev_map();
226
+ return nremove;
227
+ }
228
+
229
+ template <typename IndexT>
230
+ void IndexIDMap2Template<IndexT>::reconstruct(
231
+ idx_t key,
232
+ typename IndexT::component_t* recons) const {
233
+ try {
234
+ this->index->reconstruct(rev_map.at(key), recons);
235
+ } catch (const std::out_of_range& e) {
236
+ FAISS_THROW_FMT("key %" PRId64 " not found", key);
237
+ }
238
+ }
239
+
240
+ // explicit template instantiations
241
+
242
+ template struct IndexIDMapTemplate<Index>;
243
+ template struct IndexIDMapTemplate<IndexBinary>;
244
+ template struct IndexIDMap2Template<Index>;
245
+ template struct IndexIDMap2Template<IndexBinary>;
246
+
247
+ } // namespace faiss
@@ -0,0 +1,107 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ #pragma once
9
+
10
+ #include <faiss/Index.h>
11
+ #include <faiss/IndexBinary.h>
12
+
13
+ #include <unordered_map>
14
+ #include <vector>
15
+
16
+ namespace faiss {
17
+
18
+ /** Index that translates search results to ids */
19
+ template <typename IndexT>
20
+ struct IndexIDMapTemplate : IndexT {
21
+ using idx_t = typename IndexT::idx_t;
22
+ using component_t = typename IndexT::component_t;
23
+ using distance_t = typename IndexT::distance_t;
24
+
25
+ IndexT* index; ///! the sub-index
26
+ bool own_fields; ///! whether pointers are deleted in destructo
27
+ std::vector<idx_t> id_map;
28
+
29
+ explicit IndexIDMapTemplate(IndexT* index);
30
+
31
+ /// @param xids if non-null, ids to store for the vectors (size n)
32
+ void add_with_ids(idx_t n, const component_t* x, const idx_t* xids)
33
+ override;
34
+
35
+ /// this will fail. Use add_with_ids
36
+ void add(idx_t n, const component_t* x) override;
37
+
38
+ void search(
39
+ idx_t n,
40
+ const component_t* x,
41
+ idx_t k,
42
+ distance_t* distances,
43
+ idx_t* labels,
44
+ const SearchParameters* params = nullptr) const override;
45
+
46
+ void train(idx_t n, const component_t* x) override;
47
+
48
+ void reset() override;
49
+
50
+ /// remove ids adapted to IndexFlat
51
+ size_t remove_ids(const IDSelector& sel) override;
52
+
53
+ void range_search(
54
+ idx_t n,
55
+ const component_t* x,
56
+ distance_t radius,
57
+ RangeSearchResult* result,
58
+ const SearchParameters* params = nullptr) const override;
59
+
60
+ void merge_from(IndexT& otherIndex, idx_t add_id = 0) override;
61
+ void check_compatible_for_merge(const IndexT& otherIndex) const override;
62
+
63
+ ~IndexIDMapTemplate() override;
64
+ IndexIDMapTemplate() {
65
+ own_fields = false;
66
+ index = nullptr;
67
+ }
68
+ };
69
+
70
+ using IndexIDMap = IndexIDMapTemplate<Index>;
71
+ using IndexBinaryIDMap = IndexIDMapTemplate<IndexBinary>;
72
+
73
+ /** same as IndexIDMap but also provides an efficient reconstruction
74
+ * implementation via a 2-way index */
75
+ template <typename IndexT>
76
+ struct IndexIDMap2Template : IndexIDMapTemplate<IndexT> {
77
+ using idx_t = typename IndexT::idx_t;
78
+ using component_t = typename IndexT::component_t;
79
+ using distance_t = typename IndexT::distance_t;
80
+
81
+ std::unordered_map<idx_t, idx_t> rev_map;
82
+
83
+ explicit IndexIDMap2Template(IndexT* index);
84
+
85
+ /// make the rev_map from scratch
86
+ void construct_rev_map();
87
+
88
+ void add_with_ids(idx_t n, const component_t* x, const idx_t* xids)
89
+ override;
90
+
91
+ size_t remove_ids(const IDSelector& sel) override;
92
+
93
+ void reconstruct(idx_t key, component_t* recons) const override;
94
+
95
+ /// check that the rev_map and the id_map are in sync
96
+ void check_consistency() const;
97
+
98
+ void merge_from(IndexT& otherIndex, idx_t add_id = 0) override;
99
+
100
+ ~IndexIDMap2Template() override {}
101
+ IndexIDMap2Template() {}
102
+ };
103
+
104
+ using IndexIDMap2 = IndexIDMap2Template<Index>;
105
+ using IndexBinaryIDMap2 = IndexIDMap2Template<IndexBinary>;
106
+
107
+ } // namespace faiss