faiss 0.2.3 → 0.2.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (189) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +9 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +23 -21
  5. data/ext/faiss/extconf.rb +11 -0
  6. data/ext/faiss/index.cpp +4 -4
  7. data/ext/faiss/index_binary.cpp +6 -6
  8. data/ext/faiss/product_quantizer.cpp +4 -4
  9. data/lib/faiss/version.rb +1 -1
  10. data/vendor/faiss/faiss/AutoTune.cpp +13 -0
  11. data/vendor/faiss/faiss/Clustering.cpp +32 -0
  12. data/vendor/faiss/faiss/Clustering.h +14 -0
  13. data/vendor/faiss/faiss/IVFlib.cpp +101 -2
  14. data/vendor/faiss/faiss/IVFlib.h +26 -2
  15. data/vendor/faiss/faiss/Index.cpp +36 -3
  16. data/vendor/faiss/faiss/Index.h +43 -6
  17. data/vendor/faiss/faiss/Index2Layer.cpp +24 -93
  18. data/vendor/faiss/faiss/Index2Layer.h +8 -17
  19. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +610 -0
  20. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +253 -0
  21. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +299 -0
  22. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +199 -0
  23. data/vendor/faiss/faiss/IndexBinary.cpp +20 -4
  24. data/vendor/faiss/faiss/IndexBinary.h +18 -3
  25. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +9 -2
  26. data/vendor/faiss/faiss/IndexBinaryFlat.h +4 -2
  27. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +4 -1
  28. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +2 -1
  29. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +5 -1
  30. data/vendor/faiss/faiss/IndexBinaryHNSW.h +2 -1
  31. data/vendor/faiss/faiss/IndexBinaryHash.cpp +17 -4
  32. data/vendor/faiss/faiss/IndexBinaryHash.h +8 -4
  33. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +28 -13
  34. data/vendor/faiss/faiss/IndexBinaryIVF.h +10 -7
  35. data/vendor/faiss/faiss/IndexFastScan.cpp +626 -0
  36. data/vendor/faiss/faiss/IndexFastScan.h +145 -0
  37. data/vendor/faiss/faiss/IndexFlat.cpp +52 -69
  38. data/vendor/faiss/faiss/IndexFlat.h +16 -19
  39. data/vendor/faiss/faiss/IndexFlatCodes.cpp +101 -0
  40. data/vendor/faiss/faiss/IndexFlatCodes.h +59 -0
  41. data/vendor/faiss/faiss/IndexHNSW.cpp +66 -138
  42. data/vendor/faiss/faiss/IndexHNSW.h +4 -2
  43. data/vendor/faiss/faiss/IndexIDMap.cpp +247 -0
  44. data/vendor/faiss/faiss/IndexIDMap.h +107 -0
  45. data/vendor/faiss/faiss/IndexIVF.cpp +200 -40
  46. data/vendor/faiss/faiss/IndexIVF.h +59 -22
  47. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +393 -0
  48. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +183 -0
  49. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +590 -0
  50. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +171 -0
  51. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +1290 -0
  52. data/vendor/faiss/faiss/IndexIVFFastScan.h +213 -0
  53. data/vendor/faiss/faiss/IndexIVFFlat.cpp +43 -26
  54. data/vendor/faiss/faiss/IndexIVFFlat.h +4 -2
  55. data/vendor/faiss/faiss/IndexIVFPQ.cpp +238 -53
  56. data/vendor/faiss/faiss/IndexIVFPQ.h +6 -2
  57. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +23 -852
  58. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +7 -112
  59. data/vendor/faiss/faiss/IndexIVFPQR.cpp +3 -3
  60. data/vendor/faiss/faiss/IndexIVFPQR.h +1 -1
  61. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +63 -40
  62. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +23 -7
  63. data/vendor/faiss/faiss/IndexLSH.cpp +8 -32
  64. data/vendor/faiss/faiss/IndexLSH.h +4 -16
  65. data/vendor/faiss/faiss/IndexLattice.cpp +7 -1
  66. data/vendor/faiss/faiss/IndexLattice.h +3 -1
  67. data/vendor/faiss/faiss/IndexNNDescent.cpp +4 -5
  68. data/vendor/faiss/faiss/IndexNNDescent.h +2 -1
  69. data/vendor/faiss/faiss/IndexNSG.cpp +37 -5
  70. data/vendor/faiss/faiss/IndexNSG.h +25 -1
  71. data/vendor/faiss/faiss/IndexPQ.cpp +108 -120
  72. data/vendor/faiss/faiss/IndexPQ.h +21 -22
  73. data/vendor/faiss/faiss/IndexPQFastScan.cpp +15 -450
  74. data/vendor/faiss/faiss/IndexPQFastScan.h +15 -78
  75. data/vendor/faiss/faiss/IndexPreTransform.cpp +47 -8
  76. data/vendor/faiss/faiss/IndexPreTransform.h +15 -3
  77. data/vendor/faiss/faiss/IndexRefine.cpp +36 -4
  78. data/vendor/faiss/faiss/IndexRefine.h +14 -2
  79. data/vendor/faiss/faiss/IndexReplicas.cpp +4 -2
  80. data/vendor/faiss/faiss/IndexReplicas.h +2 -1
  81. data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +438 -0
  82. data/vendor/faiss/faiss/IndexRowwiseMinMax.h +92 -0
  83. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +28 -43
  84. data/vendor/faiss/faiss/IndexScalarQuantizer.h +8 -23
  85. data/vendor/faiss/faiss/IndexShards.cpp +4 -1
  86. data/vendor/faiss/faiss/IndexShards.h +2 -1
  87. data/vendor/faiss/faiss/MetaIndexes.cpp +5 -178
  88. data/vendor/faiss/faiss/MetaIndexes.h +3 -81
  89. data/vendor/faiss/faiss/VectorTransform.cpp +45 -1
  90. data/vendor/faiss/faiss/VectorTransform.h +25 -4
  91. data/vendor/faiss/faiss/clone_index.cpp +26 -3
  92. data/vendor/faiss/faiss/clone_index.h +3 -0
  93. data/vendor/faiss/faiss/cppcontrib/SaDecodeKernels.h +300 -0
  94. data/vendor/faiss/faiss/cppcontrib/detail/CoarseBitType.h +24 -0
  95. data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +195 -0
  96. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +2058 -0
  97. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +408 -0
  98. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-neon-inl.h +2147 -0
  99. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMax-inl.h +460 -0
  100. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMaxFP16-inl.h +465 -0
  101. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +1618 -0
  102. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +251 -0
  103. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-neon-inl.h +1452 -0
  104. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +1 -0
  105. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +2 -6
  106. data/vendor/faiss/faiss/gpu/GpuIcmEncoder.h +60 -0
  107. data/vendor/faiss/faiss/gpu/GpuIndex.h +28 -4
  108. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +2 -1
  109. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +10 -8
  110. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +75 -14
  111. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +19 -32
  112. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -31
  113. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +22 -28
  114. data/vendor/faiss/faiss/gpu/GpuResources.cpp +14 -0
  115. data/vendor/faiss/faiss/gpu/GpuResources.h +16 -3
  116. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +3 -3
  117. data/vendor/faiss/faiss/gpu/impl/IndexUtils.h +32 -0
  118. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +1 -0
  119. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +311 -75
  120. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +10 -0
  121. data/vendor/faiss/faiss/gpu/test/TestUtils.h +3 -0
  122. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +2 -2
  123. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +5 -4
  124. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +331 -29
  125. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +110 -19
  126. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +0 -54
  127. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +0 -76
  128. data/vendor/faiss/faiss/impl/DistanceComputer.h +64 -0
  129. data/vendor/faiss/faiss/impl/HNSW.cpp +133 -32
  130. data/vendor/faiss/faiss/impl/HNSW.h +19 -16
  131. data/vendor/faiss/faiss/impl/IDSelector.cpp +125 -0
  132. data/vendor/faiss/faiss/impl/IDSelector.h +135 -0
  133. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +378 -217
  134. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +106 -29
  135. data/vendor/faiss/faiss/impl/LookupTableScaler.h +77 -0
  136. data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -0
  137. data/vendor/faiss/faiss/impl/NSG.cpp +1 -4
  138. data/vendor/faiss/faiss/impl/NSG.h +1 -1
  139. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +383 -0
  140. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +154 -0
  141. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +225 -145
  142. data/vendor/faiss/faiss/impl/ProductQuantizer.h +29 -10
  143. data/vendor/faiss/faiss/impl/Quantizer.h +43 -0
  144. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +521 -55
  145. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +94 -16
  146. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -0
  147. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +108 -191
  148. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +18 -18
  149. data/vendor/faiss/faiss/impl/index_read.cpp +338 -24
  150. data/vendor/faiss/faiss/impl/index_write.cpp +300 -18
  151. data/vendor/faiss/faiss/impl/io.cpp +1 -1
  152. data/vendor/faiss/faiss/impl/io_macros.h +20 -0
  153. data/vendor/faiss/faiss/impl/kmeans1d.cpp +303 -0
  154. data/vendor/faiss/faiss/impl/kmeans1d.h +48 -0
  155. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +56 -16
  156. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +25 -8
  157. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +66 -25
  158. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +75 -27
  159. data/vendor/faiss/faiss/index_factory.cpp +772 -412
  160. data/vendor/faiss/faiss/index_factory.h +3 -0
  161. data/vendor/faiss/faiss/index_io.h +5 -0
  162. data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -0
  163. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +4 -1
  164. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +2 -1
  165. data/vendor/faiss/faiss/python/python_callbacks.cpp +27 -0
  166. data/vendor/faiss/faiss/python/python_callbacks.h +15 -0
  167. data/vendor/faiss/faiss/utils/Heap.h +31 -15
  168. data/vendor/faiss/faiss/utils/distances.cpp +384 -58
  169. data/vendor/faiss/faiss/utils/distances.h +149 -18
  170. data/vendor/faiss/faiss/utils/distances_simd.cpp +776 -6
  171. data/vendor/faiss/faiss/utils/extra_distances.cpp +12 -7
  172. data/vendor/faiss/faiss/utils/extra_distances.h +3 -1
  173. data/vendor/faiss/faiss/utils/fp16-fp16c.h +21 -0
  174. data/vendor/faiss/faiss/utils/fp16-inl.h +101 -0
  175. data/vendor/faiss/faiss/utils/fp16.h +11 -0
  176. data/vendor/faiss/faiss/utils/hamming-inl.h +54 -0
  177. data/vendor/faiss/faiss/utils/hamming.cpp +0 -48
  178. data/vendor/faiss/faiss/utils/ordered_key_value.h +10 -0
  179. data/vendor/faiss/faiss/utils/quantize_lut.cpp +62 -0
  180. data/vendor/faiss/faiss/utils/quantize_lut.h +20 -0
  181. data/vendor/faiss/faiss/utils/random.cpp +53 -0
  182. data/vendor/faiss/faiss/utils/random.h +5 -0
  183. data/vendor/faiss/faiss/utils/simdlib_avx2.h +4 -0
  184. data/vendor/faiss/faiss/utils/simdlib_emulated.h +6 -1
  185. data/vendor/faiss/faiss/utils/simdlib_neon.h +7 -2
  186. data/vendor/faiss/faiss/utils/utils.h +1 -1
  187. metadata +46 -5
  188. data/vendor/faiss/faiss/IndexResidual.cpp +0 -291
  189. data/vendor/faiss/faiss/IndexResidual.h +0 -152
@@ -202,7 +202,10 @@ void hnsw_add_vertices(
202
202
  verbose && omp_get_thread_num() == 0 ? 0 : -1;
203
203
  size_t counter = 0;
204
204
 
205
- #pragma omp for schedule(dynamic)
205
+ // here we should do schedule(dynamic) but this segfaults for
206
+ // some versions of LLVM. The performance impact should not be
207
+ // too large when (i1 - i0) / num_threads >> 1
208
+ #pragma omp for schedule(static)
206
209
  for (int i = i0; i < i1; i++) {
207
210
  storage_idx_t pt_id = order[i];
208
211
  dis->set_query(x + (pt_id - n0) * d);
@@ -219,7 +222,6 @@ void hnsw_add_vertices(
219
222
  printf(" %d / %d\r", i - i0, i1 - i0);
220
223
  fflush(stdout);
221
224
  }
222
-
223
225
  if (counter % check_period == 0) {
224
226
  if (InterruptCallback::is_interrupted()) {
225
227
  interrupt = true;
@@ -284,18 +286,24 @@ void IndexHNSW::search(
284
286
  const float* x,
285
287
  idx_t k,
286
288
  float* distances,
287
- idx_t* labels) const
288
-
289
- {
289
+ idx_t* labels,
290
+ const SearchParameters* params_in) const {
290
291
  FAISS_THROW_IF_NOT(k > 0);
291
-
292
292
  FAISS_THROW_IF_NOT_MSG(
293
293
  storage,
294
294
  "Please use IndexHNSWFlat (or variants) instead of IndexHNSW directly");
295
+ const SearchParametersHNSW* params = nullptr;
296
+
297
+ int efSearch = hnsw.efSearch;
298
+ if (params_in) {
299
+ params = dynamic_cast<const SearchParametersHNSW*>(params_in);
300
+ FAISS_THROW_IF_NOT_MSG(params, "params type invalid");
301
+ efSearch = params->efSearch;
302
+ }
295
303
  size_t n1 = 0, n2 = 0, n3 = 0, ndis = 0, nreorder = 0;
296
304
 
297
- idx_t check_period = InterruptCallback::get_period_hint(
298
- hnsw.max_level * d * hnsw.efSearch);
305
+ idx_t check_period =
306
+ InterruptCallback::get_period_hint(hnsw.max_level * d * efSearch);
299
307
 
300
308
  for (idx_t i0 = 0; i0 < n; i0 += check_period) {
301
309
  idx_t i1 = std::min(i0 + check_period, n);
@@ -314,7 +322,7 @@ void IndexHNSW::search(
314
322
  dis->set_query(x + i * d);
315
323
 
316
324
  maxheap_heapify(k, simi, idxi);
317
- HNSWStats stats = hnsw.search(*dis, k, idxi, simi, vt);
325
+ HNSWStats stats = hnsw.search(*dis, k, idxi, simi, vt, params);
318
326
  n1 += stats.n1;
319
327
  n2 += stats.n2;
320
328
  n3 += stats.n3;
@@ -423,16 +431,15 @@ void IndexHNSW::search_level_0(
423
431
  FAISS_THROW_IF_NOT(nprobe > 0);
424
432
 
425
433
  storage_idx_t ntotal = hnsw.levels.size();
426
- size_t n1 = 0, n2 = 0, n3 = 0, ndis = 0, nreorder = 0;
427
434
 
428
435
  #pragma omp parallel
429
436
  {
430
- DistanceComputer* qdis = storage_distance_computer(storage);
431
- ScopeDeleter1<DistanceComputer> del(qdis);
432
-
437
+ std::unique_ptr<DistanceComputer> qdis(
438
+ storage_distance_computer(storage));
439
+ HNSWStats search_stats;
433
440
  VisitedTable vt(ntotal);
434
441
 
435
- #pragma omp for reduction(+ : n1, n2, n3, ndis, nreorder)
442
+ #pragma omp for
436
443
  for (idx_t i = 0; i < n; i++) {
437
444
  idx_t* idxi = labels + i * k;
438
445
  float* simi = distances + i * k;
@@ -440,69 +447,24 @@ void IndexHNSW::search_level_0(
440
447
  qdis->set_query(x + i * d);
441
448
  maxheap_heapify(k, simi, idxi);
442
449
 
443
- if (search_type == 1) {
444
- int nres = 0;
445
-
446
- for (int j = 0; j < nprobe; j++) {
447
- storage_idx_t cj = nearest[i * nprobe + j];
448
-
449
- if (cj < 0)
450
- break;
451
-
452
- if (vt.get(cj))
453
- continue;
454
-
455
- int candidates_size = std::max(hnsw.efSearch, int(k));
456
- MinimaxHeap candidates(candidates_size);
457
-
458
- candidates.push(cj, nearest_d[i * nprobe + j]);
459
-
460
- HNSWStats search_stats;
461
- nres = hnsw.search_from_candidates(
462
- *qdis,
463
- k,
464
- idxi,
465
- simi,
466
- candidates,
467
- vt,
468
- search_stats,
469
- 0,
470
- nres);
471
- n1 += search_stats.n1;
472
- n2 += search_stats.n2;
473
- n3 += search_stats.n3;
474
- ndis += search_stats.ndis;
475
- nreorder += search_stats.nreorder;
476
- }
477
- } else if (search_type == 2) {
478
- int candidates_size = std::max(hnsw.efSearch, int(k));
479
- candidates_size = std::max(candidates_size, nprobe);
480
-
481
- MinimaxHeap candidates(candidates_size);
482
- for (int j = 0; j < nprobe; j++) {
483
- storage_idx_t cj = nearest[i * nprobe + j];
484
-
485
- if (cj < 0)
486
- break;
487
- candidates.push(cj, nearest_d[i * nprobe + j]);
488
- }
450
+ hnsw.search_level_0(
451
+ *qdis.get(),
452
+ k,
453
+ idxi,
454
+ simi,
455
+ nprobe,
456
+ nearest + i * nprobe,
457
+ nearest_d + i * nprobe,
458
+ search_type,
459
+ search_stats,
460
+ vt);
489
461
 
490
- HNSWStats search_stats;
491
- hnsw.search_from_candidates(
492
- *qdis, k, idxi, simi, candidates, vt, search_stats, 0);
493
- n1 += search_stats.n1;
494
- n2 += search_stats.n2;
495
- n3 += search_stats.n3;
496
- ndis += search_stats.ndis;
497
- nreorder += search_stats.nreorder;
498
- }
499
462
  vt.advance();
500
-
501
463
  maxheap_reorder(k, simi, idxi);
502
464
  }
465
+ #pragma omp critical
466
+ { hnsw_stats.combine(search_stats); }
503
467
  }
504
-
505
- hnsw_stats.combine({n1, n2, n3, ndis, nreorder});
506
468
  }
507
469
 
508
470
  void IndexHNSW::init_level_0_from_knngraph(
@@ -1035,8 +997,11 @@ void IndexHNSW2Level::search(
1035
997
  const float* x,
1036
998
  idx_t k,
1037
999
  float* distances,
1038
- idx_t* labels) const {
1000
+ idx_t* labels,
1001
+ const SearchParameters* params) const {
1039
1002
  FAISS_THROW_IF_NOT(k > 0);
1003
+ FAISS_THROW_IF_NOT_MSG(
1004
+ !params, "search params not supported for this index");
1040
1005
 
1041
1006
  if (dynamic_cast<const Index2Layer*>(storage)) {
1042
1007
  IndexHNSW::search(n, x, k, distances, labels);
@@ -1095,74 +1060,37 @@ void IndexHNSW2Level::search(
1095
1060
  }
1096
1061
 
1097
1062
  candidates.clear();
1098
- // copy the upper_beam elements to candidates list
1099
-
1100
- int search_policy = 2;
1101
-
1102
- if (search_policy == 1) {
1103
- for (int j = 0; j < hnsw.upper_beam && j < k; j++) {
1104
- if (idxi[j] < 0)
1105
- break;
1106
- candidates.push(idxi[j], simi[j]);
1107
- // search_from_candidates adds them back
1108
- idxi[j] = -1;
1109
- simi[j] = HUGE_VAL;
1110
- }
1111
1063
 
1112
- // reorder from sorted to heap
1113
- maxheap_heapify(k, simi, idxi, simi, idxi, k);
1114
-
1115
- HNSWStats search_stats;
1116
- hnsw.search_from_candidates(
1117
- *dis,
1118
- k,
1119
- idxi,
1120
- simi,
1121
- candidates,
1122
- vt,
1123
- search_stats,
1124
- 0,
1125
- k);
1126
- n1 += search_stats.n1;
1127
- n2 += search_stats.n2;
1128
- n3 += search_stats.n3;
1129
- ndis += search_stats.ndis;
1130
- nreorder += search_stats.nreorder;
1131
-
1132
- vt.advance();
1133
-
1134
- } else if (search_policy == 2) {
1135
- for (int j = 0; j < hnsw.upper_beam && j < k; j++) {
1136
- if (idxi[j] < 0)
1137
- break;
1138
- candidates.push(idxi[j], simi[j]);
1139
- }
1140
-
1141
- // reorder from sorted to heap
1142
- maxheap_heapify(k, simi, idxi, simi, idxi, k);
1143
-
1144
- HNSWStats search_stats;
1145
- search_from_candidates_2(
1146
- hnsw,
1147
- *dis,
1148
- k,
1149
- idxi,
1150
- simi,
1151
- candidates,
1152
- vt,
1153
- search_stats,
1154
- 0,
1155
- k);
1156
- n1 += search_stats.n1;
1157
- n2 += search_stats.n2;
1158
- n3 += search_stats.n3;
1159
- ndis += search_stats.ndis;
1160
- nreorder += search_stats.nreorder;
1161
-
1162
- vt.advance();
1163
- vt.advance();
1064
+ for (int j = 0; j < hnsw.upper_beam && j < k; j++) {
1065
+ if (idxi[j] < 0)
1066
+ break;
1067
+ candidates.push(idxi[j], simi[j]);
1164
1068
  }
1165
1069
 
1070
+ // reorder from sorted to heap
1071
+ maxheap_heapify(k, simi, idxi, simi, idxi, k);
1072
+
1073
+ HNSWStats search_stats;
1074
+ search_from_candidates_2(
1075
+ hnsw,
1076
+ *dis,
1077
+ k,
1078
+ idxi,
1079
+ simi,
1080
+ candidates,
1081
+ vt,
1082
+ search_stats,
1083
+ 0,
1084
+ k);
1085
+ n1 += search_stats.n1;
1086
+ n2 += search_stats.n2;
1087
+ n3 += search_stats.n3;
1088
+ ndis += search_stats.ndis;
1089
+ nreorder += search_stats.nreorder;
1090
+
1091
+ vt.advance();
1092
+ vt.advance();
1093
+
1166
1094
  maxheap_reorder(k, simi, idxi);
1167
1095
  }
1168
1096
  }
@@ -96,7 +96,8 @@ struct IndexHNSW : Index {
96
96
  const float* x,
97
97
  idx_t k,
98
98
  float* distances,
99
- idx_t* labels) const override;
99
+ idx_t* labels,
100
+ const SearchParameters* params = nullptr) const override;
100
101
 
101
102
  void reconstruct(idx_t key, float* recons) const override;
102
103
 
@@ -180,7 +181,8 @@ struct IndexHNSW2Level : IndexHNSW {
180
181
  const float* x,
181
182
  idx_t k,
182
183
  float* distances,
183
- idx_t* labels) const override;
184
+ idx_t* labels,
185
+ const SearchParameters* params = nullptr) const override;
184
186
  };
185
187
 
186
188
  } // namespace faiss
@@ -0,0 +1,247 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ // -*- c++ -*-
9
+
10
+ #include <faiss/IndexIDMap.h>
11
+
12
+ #include <stdint.h>
13
+ #include <cinttypes>
14
+ #include <cstdio>
15
+ #include <limits>
16
+
17
+ #include <faiss/impl/AuxIndexStructures.h>
18
+ #include <faiss/impl/FaissAssert.h>
19
+ #include <faiss/impl/IDSelector.h>
20
+ #include <faiss/utils/Heap.h>
21
+ #include <faiss/utils/WorkerThread.h>
22
+
23
+ namespace faiss {
24
+
25
+ /*****************************************************
26
+ * IndexIDMap implementation
27
+ *******************************************************/
28
+
29
+ template <typename IndexT>
30
+ IndexIDMapTemplate<IndexT>::IndexIDMapTemplate(IndexT* index)
31
+ : index(index), own_fields(false) {
32
+ FAISS_THROW_IF_NOT_MSG(index->ntotal == 0, "index must be empty on input");
33
+ this->is_trained = index->is_trained;
34
+ this->metric_type = index->metric_type;
35
+ this->verbose = index->verbose;
36
+ this->d = index->d;
37
+ }
38
+
39
+ template <typename IndexT>
40
+ void IndexIDMapTemplate<IndexT>::add(
41
+ idx_t,
42
+ const typename IndexT::component_t*) {
43
+ FAISS_THROW_MSG(
44
+ "add does not make sense with IndexIDMap, "
45
+ "use add_with_ids");
46
+ }
47
+
48
+ template <typename IndexT>
49
+ void IndexIDMapTemplate<IndexT>::train(
50
+ idx_t n,
51
+ const typename IndexT::component_t* x) {
52
+ index->train(n, x);
53
+ this->is_trained = index->is_trained;
54
+ }
55
+
56
+ template <typename IndexT>
57
+ void IndexIDMapTemplate<IndexT>::reset() {
58
+ index->reset();
59
+ id_map.clear();
60
+ this->ntotal = 0;
61
+ }
62
+
63
+ template <typename IndexT>
64
+ void IndexIDMapTemplate<IndexT>::add_with_ids(
65
+ idx_t n,
66
+ const typename IndexT::component_t* x,
67
+ const typename IndexT::idx_t* xids) {
68
+ index->add(n, x);
69
+ for (idx_t i = 0; i < n; i++)
70
+ id_map.push_back(xids[i]);
71
+ this->ntotal = index->ntotal;
72
+ }
73
+
74
+ template <typename IndexT>
75
+ void IndexIDMapTemplate<IndexT>::search(
76
+ idx_t n,
77
+ const typename IndexT::component_t* x,
78
+ idx_t k,
79
+ typename IndexT::distance_t* distances,
80
+ typename IndexT::idx_t* labels,
81
+ const SearchParameters* params) const {
82
+ FAISS_THROW_IF_NOT_MSG(
83
+ !params, "search params not supported for this index");
84
+ index->search(n, x, k, distances, labels);
85
+ idx_t* li = labels;
86
+ #pragma omp parallel for
87
+ for (idx_t i = 0; i < n * k; i++) {
88
+ li[i] = li[i] < 0 ? li[i] : id_map[li[i]];
89
+ }
90
+ }
91
+
92
+ template <typename IndexT>
93
+ void IndexIDMapTemplate<IndexT>::range_search(
94
+ typename IndexT::idx_t n,
95
+ const typename IndexT::component_t* x,
96
+ typename IndexT::distance_t radius,
97
+ RangeSearchResult* result,
98
+ const SearchParameters* params) const {
99
+ FAISS_THROW_IF_NOT_MSG(
100
+ !params, "search params not supported for this index");
101
+ index->range_search(n, x, radius, result);
102
+ #pragma omp parallel for
103
+ for (idx_t i = 0; i < result->lims[result->nq]; i++) {
104
+ result->labels[i] = result->labels[i] < 0 ? result->labels[i]
105
+ : id_map[result->labels[i]];
106
+ }
107
+ }
108
+
109
+ namespace {
110
+
111
+ struct IDTranslatedSelector : IDSelector {
112
+ const std::vector<int64_t>& id_map;
113
+ const IDSelector& sel;
114
+ IDTranslatedSelector(
115
+ const std::vector<int64_t>& id_map,
116
+ const IDSelector& sel)
117
+ : id_map(id_map), sel(sel) {}
118
+ bool is_member(idx_t id) const override {
119
+ return sel.is_member(id_map[id]);
120
+ }
121
+ };
122
+
123
+ } // namespace
124
+
125
+ template <typename IndexT>
126
+ size_t IndexIDMapTemplate<IndexT>::remove_ids(const IDSelector& sel) {
127
+ // remove in sub-index first
128
+ IDTranslatedSelector sel2(id_map, sel);
129
+ size_t nremove = index->remove_ids(sel2);
130
+
131
+ int64_t j = 0;
132
+ for (idx_t i = 0; i < this->ntotal; i++) {
133
+ if (sel.is_member(id_map[i])) {
134
+ // remove
135
+ } else {
136
+ id_map[j] = id_map[i];
137
+ j++;
138
+ }
139
+ }
140
+ FAISS_ASSERT(j == index->ntotal);
141
+ this->ntotal = j;
142
+ id_map.resize(this->ntotal);
143
+ return nremove;
144
+ }
145
+
146
+ template <typename IndexT>
147
+ void IndexIDMapTemplate<IndexT>::check_compatible_for_merge(
148
+ const IndexT& otherIndex) const {
149
+ auto other = dynamic_cast<const IndexIDMapTemplate<IndexT>*>(&otherIndex);
150
+ FAISS_THROW_IF_NOT(other);
151
+ index->check_compatible_for_merge(*other->index);
152
+ }
153
+
154
+ template <typename IndexT>
155
+ void IndexIDMapTemplate<IndexT>::merge_from(IndexT& otherIndex, idx_t add_id) {
156
+ check_compatible_for_merge(otherIndex);
157
+ auto other = static_cast<IndexIDMapTemplate<IndexT>*>(&otherIndex);
158
+ index->merge_from(*other->index);
159
+ for (size_t i = 0; i < other->id_map.size(); i++) {
160
+ id_map.push_back(other->id_map[i] + add_id);
161
+ }
162
+ other->id_map.resize(0);
163
+ this->ntotal = index->ntotal;
164
+ other->ntotal = 0;
165
+ }
166
+
167
+ template <typename IndexT>
168
+ IndexIDMapTemplate<IndexT>::~IndexIDMapTemplate() {
169
+ if (own_fields)
170
+ delete index;
171
+ }
172
+
173
+ /*****************************************************
174
+ * IndexIDMap2 implementation
175
+ *******************************************************/
176
+
177
+ template <typename IndexT>
178
+ IndexIDMap2Template<IndexT>::IndexIDMap2Template(IndexT* index)
179
+ : IndexIDMapTemplate<IndexT>(index) {}
180
+
181
+ template <typename IndexT>
182
+ void IndexIDMap2Template<IndexT>::add_with_ids(
183
+ idx_t n,
184
+ const typename IndexT::component_t* x,
185
+ const typename IndexT::idx_t* xids) {
186
+ size_t prev_ntotal = this->ntotal;
187
+ IndexIDMapTemplate<IndexT>::add_with_ids(n, x, xids);
188
+ for (size_t i = prev_ntotal; i < this->ntotal; i++) {
189
+ rev_map[this->id_map[i]] = i;
190
+ }
191
+ }
192
+
193
+ template <typename IndexT>
194
+ void IndexIDMap2Template<IndexT>::check_consistency() const {
195
+ FAISS_THROW_IF_NOT(rev_map.size() == this->id_map.size());
196
+ FAISS_THROW_IF_NOT(this->id_map.size() == this->ntotal);
197
+ for (size_t i = 0; i < this->ntotal; i++) {
198
+ idx_t ii = rev_map.at(this->id_map[i]);
199
+ FAISS_THROW_IF_NOT(ii == i);
200
+ }
201
+ }
202
+
203
+ template <typename IndexT>
204
+ void IndexIDMap2Template<IndexT>::merge_from(IndexT& otherIndex, idx_t add_id) {
205
+ size_t prev_ntotal = this->ntotal;
206
+ IndexIDMapTemplate<IndexT>::merge_from(otherIndex, add_id);
207
+ for (size_t i = prev_ntotal; i < this->ntotal; i++) {
208
+ rev_map[this->id_map[i]] = i;
209
+ }
210
+ static_cast<IndexIDMap2Template<IndexT>&>(otherIndex).rev_map.clear();
211
+ }
212
+
213
+ template <typename IndexT>
214
+ void IndexIDMap2Template<IndexT>::construct_rev_map() {
215
+ rev_map.clear();
216
+ for (size_t i = 0; i < this->ntotal; i++) {
217
+ rev_map[this->id_map[i]] = i;
218
+ }
219
+ }
220
+
221
+ template <typename IndexT>
222
+ size_t IndexIDMap2Template<IndexT>::remove_ids(const IDSelector& sel) {
223
+ // This is quite inefficient
224
+ size_t nremove = IndexIDMapTemplate<IndexT>::remove_ids(sel);
225
+ construct_rev_map();
226
+ return nremove;
227
+ }
228
+
229
+ template <typename IndexT>
230
+ void IndexIDMap2Template<IndexT>::reconstruct(
231
+ idx_t key,
232
+ typename IndexT::component_t* recons) const {
233
+ try {
234
+ this->index->reconstruct(rev_map.at(key), recons);
235
+ } catch (const std::out_of_range& e) {
236
+ FAISS_THROW_FMT("key %" PRId64 " not found", key);
237
+ }
238
+ }
239
+
240
+ // explicit template instantiations
241
+
242
+ template struct IndexIDMapTemplate<Index>;
243
+ template struct IndexIDMapTemplate<IndexBinary>;
244
+ template struct IndexIDMap2Template<Index>;
245
+ template struct IndexIDMap2Template<IndexBinary>;
246
+
247
+ } // namespace faiss
@@ -0,0 +1,107 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ #pragma once
9
+
10
+ #include <faiss/Index.h>
11
+ #include <faiss/IndexBinary.h>
12
+
13
+ #include <unordered_map>
14
+ #include <vector>
15
+
16
+ namespace faiss {
17
+
18
+ /** Index that translates search results to ids */
19
+ template <typename IndexT>
20
+ struct IndexIDMapTemplate : IndexT {
21
+ using idx_t = typename IndexT::idx_t;
22
+ using component_t = typename IndexT::component_t;
23
+ using distance_t = typename IndexT::distance_t;
24
+
25
+ IndexT* index; ///! the sub-index
26
+ bool own_fields; ///! whether pointers are deleted in destructor
27
+ std::vector<idx_t> id_map;
28
+
29
+ explicit IndexIDMapTemplate(IndexT* index);
30
+
31
+ /// @param xids if non-null, ids to store for the vectors (size n)
32
+ void add_with_ids(idx_t n, const component_t* x, const idx_t* xids)
33
+ override;
34
+
35
+ /// this will fail. Use add_with_ids
36
+ void add(idx_t n, const component_t* x) override;
37
+
38
+ void search(
39
+ idx_t n,
40
+ const component_t* x,
41
+ idx_t k,
42
+ distance_t* distances,
43
+ idx_t* labels,
44
+ const SearchParameters* params = nullptr) const override;
45
+
46
+ void train(idx_t n, const component_t* x) override;
47
+
48
+ void reset() override;
49
+
50
+ /// remove ids adapted to IndexFlat
51
+ size_t remove_ids(const IDSelector& sel) override;
52
+
53
+ void range_search(
54
+ idx_t n,
55
+ const component_t* x,
56
+ distance_t radius,
57
+ RangeSearchResult* result,
58
+ const SearchParameters* params = nullptr) const override;
59
+
60
+ void merge_from(IndexT& otherIndex, idx_t add_id = 0) override;
61
+ void check_compatible_for_merge(const IndexT& otherIndex) const override;
62
+
63
+ ~IndexIDMapTemplate() override;
64
+ IndexIDMapTemplate() {
65
+ own_fields = false;
66
+ index = nullptr;
67
+ }
68
+ };
69
+
70
+ using IndexIDMap = IndexIDMapTemplate<Index>;
71
+ using IndexBinaryIDMap = IndexIDMapTemplate<IndexBinary>;
72
+
73
+ /** same as IndexIDMap but also provides an efficient reconstruction
74
+ * implementation via a 2-way index */
75
+ template <typename IndexT>
76
+ struct IndexIDMap2Template : IndexIDMapTemplate<IndexT> {
77
+ using idx_t = typename IndexT::idx_t;
78
+ using component_t = typename IndexT::component_t;
79
+ using distance_t = typename IndexT::distance_t;
80
+
81
+ std::unordered_map<idx_t, idx_t> rev_map;
82
+
83
+ explicit IndexIDMap2Template(IndexT* index);
84
+
85
+ /// make the rev_map from scratch
86
+ void construct_rev_map();
87
+
88
+ void add_with_ids(idx_t n, const component_t* x, const idx_t* xids)
89
+ override;
90
+
91
+ size_t remove_ids(const IDSelector& sel) override;
92
+
93
+ void reconstruct(idx_t key, component_t* recons) const override;
94
+
95
+ /// check that the rev_map and the id_map are in sync
96
+ void check_consistency() const;
97
+
98
+ void merge_from(IndexT& otherIndex, idx_t add_id = 0) override;
99
+
100
+ ~IndexIDMap2Template() override {}
101
+ IndexIDMap2Template() {}
102
+ };
103
+
104
+ using IndexIDMap2 = IndexIDMap2Template<Index>;
105
+ using IndexBinaryIDMap2 = IndexIDMap2Template<IndexBinary>;
106
+
107
+ } // namespace faiss