faiss 0.3.0 → 0.3.1

Sign up to get free protection for your applications and to get access to all the features.
Files changed (171) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +5 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +1 -1
  5. data/ext/faiss/extconf.rb +9 -2
  6. data/ext/faiss/index.cpp +1 -1
  7. data/ext/faiss/index_binary.cpp +2 -2
  8. data/ext/faiss/product_quantizer.cpp +1 -1
  9. data/lib/faiss/version.rb +1 -1
  10. data/vendor/faiss/faiss/AutoTune.cpp +7 -7
  11. data/vendor/faiss/faiss/AutoTune.h +0 -1
  12. data/vendor/faiss/faiss/Clustering.cpp +4 -18
  13. data/vendor/faiss/faiss/Clustering.h +31 -21
  14. data/vendor/faiss/faiss/IVFlib.cpp +22 -11
  15. data/vendor/faiss/faiss/Index.cpp +1 -1
  16. data/vendor/faiss/faiss/Index.h +20 -5
  17. data/vendor/faiss/faiss/Index2Layer.cpp +7 -7
  18. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +176 -166
  19. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +15 -15
  20. data/vendor/faiss/faiss/IndexBinary.cpp +9 -4
  21. data/vendor/faiss/faiss/IndexBinary.h +8 -19
  22. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +2 -1
  23. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +24 -31
  24. data/vendor/faiss/faiss/IndexBinaryHash.cpp +25 -50
  25. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +106 -187
  26. data/vendor/faiss/faiss/IndexFastScan.cpp +90 -159
  27. data/vendor/faiss/faiss/IndexFastScan.h +9 -8
  28. data/vendor/faiss/faiss/IndexFlat.cpp +195 -3
  29. data/vendor/faiss/faiss/IndexFlat.h +20 -1
  30. data/vendor/faiss/faiss/IndexFlatCodes.cpp +11 -0
  31. data/vendor/faiss/faiss/IndexFlatCodes.h +3 -1
  32. data/vendor/faiss/faiss/IndexHNSW.cpp +112 -316
  33. data/vendor/faiss/faiss/IndexHNSW.h +12 -48
  34. data/vendor/faiss/faiss/IndexIDMap.cpp +69 -28
  35. data/vendor/faiss/faiss/IndexIDMap.h +24 -2
  36. data/vendor/faiss/faiss/IndexIVF.cpp +159 -53
  37. data/vendor/faiss/faiss/IndexIVF.h +37 -5
  38. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +18 -26
  39. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +3 -2
  40. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +19 -46
  41. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +4 -3
  42. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +433 -405
  43. data/vendor/faiss/faiss/IndexIVFFastScan.h +56 -26
  44. data/vendor/faiss/faiss/IndexIVFFlat.cpp +15 -5
  45. data/vendor/faiss/faiss/IndexIVFFlat.h +3 -2
  46. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +172 -0
  47. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.h +56 -0
  48. data/vendor/faiss/faiss/IndexIVFPQ.cpp +78 -122
  49. data/vendor/faiss/faiss/IndexIVFPQ.h +6 -7
  50. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +18 -50
  51. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +4 -3
  52. data/vendor/faiss/faiss/IndexIVFPQR.cpp +45 -29
  53. data/vendor/faiss/faiss/IndexIVFPQR.h +5 -2
  54. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +25 -27
  55. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +6 -6
  56. data/vendor/faiss/faiss/IndexLSH.cpp +14 -16
  57. data/vendor/faiss/faiss/IndexNNDescent.cpp +3 -4
  58. data/vendor/faiss/faiss/IndexNSG.cpp +11 -27
  59. data/vendor/faiss/faiss/IndexNSG.h +10 -10
  60. data/vendor/faiss/faiss/IndexPQ.cpp +72 -88
  61. data/vendor/faiss/faiss/IndexPQ.h +1 -4
  62. data/vendor/faiss/faiss/IndexPQFastScan.cpp +1 -1
  63. data/vendor/faiss/faiss/IndexPreTransform.cpp +25 -31
  64. data/vendor/faiss/faiss/IndexRefine.cpp +49 -19
  65. data/vendor/faiss/faiss/IndexRefine.h +7 -0
  66. data/vendor/faiss/faiss/IndexReplicas.cpp +23 -26
  67. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +22 -16
  68. data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -4
  69. data/vendor/faiss/faiss/IndexShards.cpp +21 -29
  70. data/vendor/faiss/faiss/IndexShardsIVF.cpp +1 -2
  71. data/vendor/faiss/faiss/MatrixStats.cpp +17 -32
  72. data/vendor/faiss/faiss/MatrixStats.h +21 -9
  73. data/vendor/faiss/faiss/MetaIndexes.cpp +35 -35
  74. data/vendor/faiss/faiss/VectorTransform.cpp +13 -26
  75. data/vendor/faiss/faiss/VectorTransform.h +7 -7
  76. data/vendor/faiss/faiss/clone_index.cpp +15 -10
  77. data/vendor/faiss/faiss/clone_index.h +3 -0
  78. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +87 -4
  79. data/vendor/faiss/faiss/gpu/GpuCloner.h +22 -0
  80. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +7 -0
  81. data/vendor/faiss/faiss/gpu/GpuDistance.h +46 -38
  82. data/vendor/faiss/faiss/gpu/GpuIndex.h +28 -4
  83. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +4 -4
  84. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +8 -9
  85. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +18 -3
  86. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -11
  87. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +1 -3
  88. data/vendor/faiss/faiss/gpu/GpuResources.cpp +24 -3
  89. data/vendor/faiss/faiss/gpu/GpuResources.h +39 -11
  90. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +117 -17
  91. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +57 -3
  92. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +1 -1
  93. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +25 -0
  94. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +129 -9
  95. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +267 -40
  96. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +299 -208
  97. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +1 -0
  98. data/vendor/faiss/faiss/gpu/utils/RaftUtils.h +75 -0
  99. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +3 -1
  100. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +5 -5
  101. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +1 -1
  102. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +1 -2
  103. data/vendor/faiss/faiss/impl/DistanceComputer.h +24 -1
  104. data/vendor/faiss/faiss/impl/FaissException.h +13 -34
  105. data/vendor/faiss/faiss/impl/HNSW.cpp +321 -70
  106. data/vendor/faiss/faiss/impl/HNSW.h +9 -8
  107. data/vendor/faiss/faiss/impl/IDSelector.h +4 -4
  108. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +3 -1
  109. data/vendor/faiss/faiss/impl/NNDescent.cpp +29 -19
  110. data/vendor/faiss/faiss/impl/NSG.h +1 -1
  111. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +14 -12
  112. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +1 -1
  113. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +24 -22
  114. data/vendor/faiss/faiss/impl/ProductQuantizer.h +1 -1
  115. data/vendor/faiss/faiss/impl/Quantizer.h +1 -1
  116. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +27 -1015
  117. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +5 -63
  118. data/vendor/faiss/faiss/impl/ResultHandler.h +232 -176
  119. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +444 -104
  120. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +0 -8
  121. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +280 -42
  122. data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +21 -14
  123. data/vendor/faiss/faiss/impl/code_distance/code_distance.h +22 -12
  124. data/vendor/faiss/faiss/impl/index_read.cpp +45 -19
  125. data/vendor/faiss/faiss/impl/index_write.cpp +60 -41
  126. data/vendor/faiss/faiss/impl/io.cpp +10 -10
  127. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -1
  128. data/vendor/faiss/faiss/impl/platform_macros.h +18 -1
  129. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +3 -0
  130. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +7 -6
  131. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +52 -38
  132. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +40 -49
  133. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +960 -0
  134. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +176 -0
  135. data/vendor/faiss/faiss/impl/simd_result_handlers.h +374 -202
  136. data/vendor/faiss/faiss/index_factory.cpp +10 -7
  137. data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -1
  138. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +27 -9
  139. data/vendor/faiss/faiss/invlists/InvertedLists.h +12 -3
  140. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +3 -3
  141. data/vendor/faiss/faiss/python/python_callbacks.cpp +1 -1
  142. data/vendor/faiss/faiss/utils/Heap.cpp +3 -1
  143. data/vendor/faiss/faiss/utils/WorkerThread.h +1 -0
  144. data/vendor/faiss/faiss/utils/distances.cpp +128 -74
  145. data/vendor/faiss/faiss/utils/distances.h +81 -4
  146. data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +5 -5
  147. data/vendor/faiss/faiss/utils/distances_fused/avx512.h +2 -2
  148. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +2 -2
  149. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +1 -1
  150. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +5 -5
  151. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.h +1 -1
  152. data/vendor/faiss/faiss/utils/distances_simd.cpp +428 -70
  153. data/vendor/faiss/faiss/utils/fp16-arm.h +29 -0
  154. data/vendor/faiss/faiss/utils/fp16.h +2 -0
  155. data/vendor/faiss/faiss/utils/hamming.cpp +162 -110
  156. data/vendor/faiss/faiss/utils/hamming.h +58 -0
  157. data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +16 -89
  158. data/vendor/faiss/faiss/utils/hamming_distance/common.h +1 -0
  159. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +15 -87
  160. data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +57 -0
  161. data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +14 -104
  162. data/vendor/faiss/faiss/utils/partitioning.cpp +3 -4
  163. data/vendor/faiss/faiss/utils/prefetch.h +77 -0
  164. data/vendor/faiss/faiss/utils/quantize_lut.cpp +0 -14
  165. data/vendor/faiss/faiss/utils/simdlib_avx2.h +0 -6
  166. data/vendor/faiss/faiss/utils/simdlib_neon.h +72 -77
  167. data/vendor/faiss/faiss/utils/sorting.cpp +140 -5
  168. data/vendor/faiss/faiss/utils/sorting.h +27 -0
  169. data/vendor/faiss/faiss/utils/utils.cpp +112 -6
  170. data/vendor/faiss/faiss/utils/utils.h +57 -20
  171. metadata +10 -3
@@ -18,7 +18,9 @@ namespace faiss {
18
18
 
19
19
  /** Index that stores the full vectors and performs exhaustive search */
20
20
  struct IndexFlat : IndexFlatCodes {
21
- explicit IndexFlat(idx_t d, MetricType metric = METRIC_L2);
21
+ explicit IndexFlat(
22
+ idx_t d, ///< dimensionality of the input vectors
23
+ MetricType metric = METRIC_L2);
22
24
 
23
25
  void search(
24
26
  idx_t n,
@@ -76,8 +78,25 @@ struct IndexFlatIP : IndexFlat {
76
78
  };
77
79
 
78
80
  struct IndexFlatL2 : IndexFlat {
81
+ // Special cache for L2 norms.
82
+ // If this cache is set, then get_distance_computer() returns
83
+ // a special version that computes the distance using dot products
84
+ // and l2 norms.
85
+ std::vector<float> cached_l2norms;
86
+
87
+ /**
88
+ * @param d dimensionality of the input vectors
89
+ */
79
90
  explicit IndexFlatL2(idx_t d) : IndexFlat(d, METRIC_L2) {}
80
91
  IndexFlatL2() {}
92
+
93
+ // override for l2 norms cache.
94
+ FlatCodesDistanceComputer* get_FlatCodesDistanceComputer() const override;
95
+
96
+ // compute L2 norms
97
+ void sync_l2norms();
98
+ // clear L2 norms
99
+ void clear_l2norms();
81
100
  };
82
101
 
83
102
  /// optimized version for 1D "vectors".
@@ -103,4 +103,15 @@ CodePacker* IndexFlatCodes::get_CodePacker() const {
103
103
  return new CodePackerFlat(code_size);
104
104
  }
105
105
 
106
+ void IndexFlatCodes::permute_entries(const idx_t* perm) {
107
+ std::vector<uint8_t> new_codes(codes.size());
108
+
109
+ for (idx_t i = 0; i < ntotal; i++) {
110
+ memcpy(new_codes.data() + i * code_size,
111
+ codes.data() + perm[i] * code_size,
112
+ code_size);
113
+ }
114
+ std::swap(codes, new_codes);
115
+ }
116
+
106
117
  } // namespace faiss
@@ -34,7 +34,6 @@ struct IndexFlatCodes : Index {
34
34
 
35
35
  void reset() override;
36
36
 
37
- /// reconstruction using the codec interface
38
37
  void reconstruct_n(idx_t i0, idx_t ni, float* recons) const override;
39
38
 
40
39
  void reconstruct(idx_t key, float* recons) const override;
@@ -59,6 +58,9 @@ struct IndexFlatCodes : Index {
59
58
  void check_compatible_for_merge(const Index& otherIndex) const override;
60
59
 
61
60
  virtual void merge_from(Index& otherIndex, idx_t add_id = 0) override;
61
+
62
+ // permute_entries. perm of size ntotal maps new to old positions
63
+ void permute_entries(const idx_t* perm);
62
64
  };
63
65
 
64
66
  } // namespace faiss
@@ -20,16 +20,16 @@
20
20
  #include <queue>
21
21
  #include <unordered_set>
22
22
 
23
- #include <stdint.h>
24
23
  #include <sys/stat.h>
25
24
  #include <sys/types.h>
25
+ #include <cstdint>
26
26
 
27
27
  #include <faiss/Index2Layer.h>
28
28
  #include <faiss/IndexFlat.h>
29
29
  #include <faiss/IndexIVFPQ.h>
30
30
  #include <faiss/impl/AuxIndexStructures.h>
31
31
  #include <faiss/impl/FaissAssert.h>
32
- #include <faiss/utils/Heap.h>
32
+ #include <faiss/impl/ResultHandler.h>
33
33
  #include <faiss/utils/distances.h>
34
34
  #include <faiss/utils/random.h>
35
35
  #include <faiss/utils/sorting.h>
@@ -87,6 +87,23 @@ struct NegativeDistanceComputer : DistanceComputer {
87
87
  return -(*basedis)(i);
88
88
  }
89
89
 
90
+ void distances_batch_4(
91
+ const idx_t idx0,
92
+ const idx_t idx1,
93
+ const idx_t idx2,
94
+ const idx_t idx3,
95
+ float& dis0,
96
+ float& dis1,
97
+ float& dis2,
98
+ float& dis3) override {
99
+ basedis->distances_batch_4(
100
+ idx0, idx1, idx2, idx3, dis0, dis1, dis2, dis3);
101
+ dis0 = -dis0;
102
+ dis1 = -dis1;
103
+ dis2 = -dis2;
104
+ dis3 = -dis3;
105
+ }
106
+
90
107
  /// compute distance between two stored vectors
91
108
  float symmetric_dis(idx_t i, idx_t j) override {
92
109
  return -basedis->symmetric_dis(i, j);
@@ -192,9 +209,8 @@ void hnsw_add_vertices(
192
209
  {
193
210
  VisitedTable vt(ntotal);
194
211
 
195
- DistanceComputer* dis =
196
- storage_distance_computer(index_hnsw.storage);
197
- ScopeDeleter1<DistanceComputer> del(dis);
212
+ std::unique_ptr<DistanceComputer> dis(
213
+ storage_distance_computer(index_hnsw.storage));
198
214
  int prev_display =
199
215
  verbose && omp_get_thread_num() == 0 ? 0 : -1;
200
216
  size_t counter = 0;
@@ -250,18 +266,10 @@ void hnsw_add_vertices(
250
266
  **************************************************************/
251
267
 
252
268
  IndexHNSW::IndexHNSW(int d, int M, MetricType metric)
253
- : Index(d, metric),
254
- hnsw(M),
255
- own_fields(false),
256
- storage(nullptr),
257
- reconstruct_from_neighbors(nullptr) {}
269
+ : Index(d, metric), hnsw(M) {}
258
270
 
259
271
  IndexHNSW::IndexHNSW(Index* storage, int M)
260
- : Index(storage->d, storage->metric_type),
261
- hnsw(M),
262
- own_fields(false),
263
- storage(storage),
264
- reconstruct_from_neighbors(nullptr) {}
272
+ : Index(storage->d, storage->metric_type), hnsw(M), storage(storage) {}
265
273
 
266
274
  IndexHNSW::~IndexHNSW() {
267
275
  if (own_fields) {
@@ -278,18 +286,20 @@ void IndexHNSW::train(idx_t n, const float* x) {
278
286
  is_trained = true;
279
287
  }
280
288
 
281
- void IndexHNSW::search(
289
+ namespace {
290
+
291
+ template <class BlockResultHandler>
292
+ void hnsw_search(
293
+ const IndexHNSW* index,
282
294
  idx_t n,
283
295
  const float* x,
284
- idx_t k,
285
- float* distances,
286
- idx_t* labels,
287
- const SearchParameters* params_in) const {
288
- FAISS_THROW_IF_NOT(k > 0);
296
+ BlockResultHandler& bres,
297
+ const SearchParameters* params_in) {
289
298
  FAISS_THROW_IF_NOT_MSG(
290
- storage,
299
+ index->storage,
291
300
  "Please use IndexHNSWFlat (or variants) instead of IndexHNSW directly");
292
301
  const SearchParametersHNSW* params = nullptr;
302
+ const HNSW& hnsw = index->hnsw;
293
303
 
294
304
  int efSearch = hnsw.efSearch;
295
305
  if (params_in) {
@@ -299,61 +309,81 @@ void IndexHNSW::search(
299
309
  }
300
310
  size_t n1 = 0, n2 = 0, n3 = 0, ndis = 0, nreorder = 0;
301
311
 
302
- idx_t check_period =
303
- InterruptCallback::get_period_hint(hnsw.max_level * d * efSearch);
312
+ idx_t check_period = InterruptCallback::get_period_hint(
313
+ hnsw.max_level * index->d * efSearch);
304
314
 
305
315
  for (idx_t i0 = 0; i0 < n; i0 += check_period) {
306
316
  idx_t i1 = std::min(i0 + check_period, n);
307
317
 
308
318
  #pragma omp parallel
309
319
  {
310
- VisitedTable vt(ntotal);
320
+ VisitedTable vt(index->ntotal);
321
+ typename BlockResultHandler::SingleResultHandler res(bres);
311
322
 
312
- DistanceComputer* dis = storage_distance_computer(storage);
313
- ScopeDeleter1<DistanceComputer> del(dis);
323
+ std::unique_ptr<DistanceComputer> dis(
324
+ storage_distance_computer(index->storage));
314
325
 
315
- #pragma omp for reduction(+ : n1, n2, n3, ndis, nreorder)
326
+ #pragma omp for reduction(+ : n1, n2, n3, ndis, nreorder) schedule(guided)
316
327
  for (idx_t i = i0; i < i1; i++) {
317
- idx_t* idxi = labels + i * k;
318
- float* simi = distances + i * k;
319
- dis->set_query(x + i * d);
328
+ res.begin(i);
329
+ dis->set_query(x + i * index->d);
320
330
 
321
- maxheap_heapify(k, simi, idxi);
322
- HNSWStats stats = hnsw.search(*dis, k, idxi, simi, vt, params);
331
+ HNSWStats stats = hnsw.search(*dis, res, vt, params);
323
332
  n1 += stats.n1;
324
333
  n2 += stats.n2;
325
334
  n3 += stats.n3;
326
335
  ndis += stats.ndis;
327
336
  nreorder += stats.nreorder;
328
- maxheap_reorder(k, simi, idxi);
329
-
330
- if (reconstruct_from_neighbors &&
331
- reconstruct_from_neighbors->k_reorder != 0) {
332
- int k_reorder = reconstruct_from_neighbors->k_reorder;
333
- if (k_reorder == -1 || k_reorder > k)
334
- k_reorder = k;
335
-
336
- nreorder += reconstruct_from_neighbors->compute_distances(
337
- k_reorder, idxi, x + i * d, simi);
338
-
339
- // sort top k_reorder
340
- maxheap_heapify(
341
- k_reorder, simi, idxi, simi, idxi, k_reorder);
342
- maxheap_reorder(k_reorder, simi, idxi);
343
- }
337
+ res.end();
344
338
  }
345
339
  }
346
340
  InterruptCallback::check();
347
341
  }
348
342
 
349
- if (is_similarity_metric(metric_type)) {
343
+ hnsw_stats.combine({n1, n2, n3, ndis, nreorder});
344
+ }
345
+
346
+ } // anonymous namespace
347
+
348
+ void IndexHNSW::search(
349
+ idx_t n,
350
+ const float* x,
351
+ idx_t k,
352
+ float* distances,
353
+ idx_t* labels,
354
+ const SearchParameters* params_in) const {
355
+ FAISS_THROW_IF_NOT(k > 0);
356
+
357
+ using RH = HeapBlockResultHandler<HNSW::C>;
358
+ RH bres(n, distances, labels, k);
359
+
360
+ hnsw_search(this, n, x, bres, params_in);
361
+
362
+ if (is_similarity_metric(this->metric_type)) {
350
363
  // we need to revert the negated distances
351
364
  for (size_t i = 0; i < k * n; i++) {
352
365
  distances[i] = -distances[i];
353
366
  }
354
367
  }
368
+ }
355
369
 
356
- hnsw_stats.combine({n1, n2, n3, ndis, nreorder});
370
+ void IndexHNSW::range_search(
371
+ idx_t n,
372
+ const float* x,
373
+ float radius,
374
+ RangeSearchResult* result,
375
+ const SearchParameters* params) const {
376
+ using RH = RangeSearchBlockResultHandler<HNSW::C>;
377
+ RH bres(result, radius);
378
+
379
+ hnsw_search(this, n, x, bres, params);
380
+
381
+ if (is_similarity_metric(this->metric_type)) {
382
+ // we need to revert the negated distances
383
+ for (size_t i = 0; i < result->lims[result->nq]; i++) {
384
+ result->distances[i] = -result->distances[i];
385
+ }
386
+ }
357
387
  }
358
388
 
359
389
  void IndexHNSW::add(idx_t n, const float* x) {
@@ -381,8 +411,8 @@ void IndexHNSW::reconstruct(idx_t key, float* recons) const {
381
411
  void IndexHNSW::shrink_level_0_neighbors(int new_size) {
382
412
  #pragma omp parallel
383
413
  {
384
- DistanceComputer* dis = storage_distance_computer(storage);
385
- ScopeDeleter1<DistanceComputer> del(dis);
414
+ std::unique_ptr<DistanceComputer> dis(
415
+ storage_distance_computer(storage));
386
416
 
387
417
  #pragma omp for
388
418
  for (idx_t i = 0; i < ntotal; i++) {
@@ -429,35 +459,33 @@ void IndexHNSW::search_level_0(
429
459
 
430
460
  storage_idx_t ntotal = hnsw.levels.size();
431
461
 
462
+ using RH = HeapBlockResultHandler<HNSW::C>;
463
+ RH bres(n, distances, labels, k);
464
+
432
465
  #pragma omp parallel
433
466
  {
434
467
  std::unique_ptr<DistanceComputer> qdis(
435
468
  storage_distance_computer(storage));
436
469
  HNSWStats search_stats;
437
470
  VisitedTable vt(ntotal);
471
+ RH::SingleResultHandler res(bres);
438
472
 
439
473
  #pragma omp for
440
474
  for (idx_t i = 0; i < n; i++) {
441
- idx_t* idxi = labels + i * k;
442
- float* simi = distances + i * k;
443
-
475
+ res.begin(i);
444
476
  qdis->set_query(x + i * d);
445
- maxheap_heapify(k, simi, idxi);
446
477
 
447
478
  hnsw.search_level_0(
448
479
  *qdis.get(),
449
- k,
450
- idxi,
451
- simi,
480
+ res,
452
481
  nprobe,
453
482
  nearest + i * nprobe,
454
483
  nearest_d + i * nprobe,
455
484
  search_type,
456
485
  search_stats,
457
486
  vt);
458
-
487
+ res.end();
459
488
  vt.advance();
460
- maxheap_reorder(k, simi, idxi);
461
489
  }
462
490
  #pragma omp critical
463
491
  { hnsw_stats.combine(search_stats); }
@@ -515,8 +543,8 @@ void IndexHNSW::init_level_0_from_entry_points(
515
543
  {
516
544
  VisitedTable vt(ntotal);
517
545
 
518
- DistanceComputer* dis = storage_distance_computer(storage);
519
- ScopeDeleter1<DistanceComputer> del(dis);
546
+ std::unique_ptr<DistanceComputer> dis(
547
+ storage_distance_computer(storage));
520
548
  std::vector<float> vec(storage->d);
521
549
 
522
550
  #pragma omp for schedule(dynamic)
@@ -551,8 +579,8 @@ void IndexHNSW::reorder_links() {
551
579
  std::vector<float> distances(M);
552
580
  std::vector<size_t> order(M);
553
581
  std::vector<storage_idx_t> tmp(M);
554
- DistanceComputer* dis = storage_distance_computer(storage);
555
- ScopeDeleter1<DistanceComputer> del(dis);
582
+ std::unique_ptr<DistanceComputer> dis(
583
+ storage_distance_computer(storage));
556
584
 
557
585
  #pragma omp for
558
586
  for (storage_idx_t i = 0; i < ntotal; i++) {
@@ -614,245 +642,12 @@ void IndexHNSW::link_singletons() {
614
642
  }
615
643
  }
616
644
 
617
- /**************************************************************
618
- * ReconstructFromNeighbors implementation
619
- **************************************************************/
620
-
621
- ReconstructFromNeighbors::ReconstructFromNeighbors(
622
- const IndexHNSW& index,
623
- size_t k,
624
- size_t nsq)
625
- : index(index), k(k), nsq(nsq) {
626
- M = index.hnsw.nb_neighbors(0);
627
- FAISS_ASSERT(k <= 256);
628
- code_size = k == 1 ? 0 : nsq;
629
- ntotal = 0;
630
- d = index.d;
631
- FAISS_ASSERT(d % nsq == 0);
632
- dsub = d / nsq;
633
- k_reorder = -1;
634
- }
635
-
636
- void ReconstructFromNeighbors::reconstruct(
637
- storage_idx_t i,
638
- float* x,
639
- float* tmp) const {
640
- const HNSW& hnsw = index.hnsw;
641
- size_t begin, end;
642
- hnsw.neighbor_range(i, 0, &begin, &end);
643
-
644
- if (k == 1 || nsq == 1) {
645
- const float* beta;
646
- if (k == 1) {
647
- beta = codebook.data();
648
- } else {
649
- int idx = codes[i];
650
- beta = codebook.data() + idx * (M + 1);
651
- }
652
-
653
- float w0 = beta[0]; // weight of image itself
654
- index.storage->reconstruct(i, tmp);
655
-
656
- for (int l = 0; l < d; l++)
657
- x[l] = w0 * tmp[l];
658
-
659
- for (size_t j = begin; j < end; j++) {
660
- storage_idx_t ji = hnsw.neighbors[j];
661
- if (ji < 0)
662
- ji = i;
663
- float w = beta[j - begin + 1];
664
- index.storage->reconstruct(ji, tmp);
665
- for (int l = 0; l < d; l++)
666
- x[l] += w * tmp[l];
667
- }
668
- } else if (nsq == 2) {
669
- int idx0 = codes[2 * i];
670
- int idx1 = codes[2 * i + 1];
671
-
672
- const float* beta0 = codebook.data() + idx0 * (M + 1);
673
- const float* beta1 = codebook.data() + (idx1 + k) * (M + 1);
674
-
675
- index.storage->reconstruct(i, tmp);
676
-
677
- float w0;
678
-
679
- w0 = beta0[0];
680
- for (int l = 0; l < dsub; l++)
681
- x[l] = w0 * tmp[l];
682
-
683
- w0 = beta1[0];
684
- for (int l = dsub; l < d; l++)
685
- x[l] = w0 * tmp[l];
686
-
687
- for (size_t j = begin; j < end; j++) {
688
- storage_idx_t ji = hnsw.neighbors[j];
689
- if (ji < 0)
690
- ji = i;
691
- index.storage->reconstruct(ji, tmp);
692
- float w;
693
- w = beta0[j - begin + 1];
694
- for (int l = 0; l < dsub; l++)
695
- x[l] += w * tmp[l];
696
-
697
- w = beta1[j - begin + 1];
698
- for (int l = dsub; l < d; l++)
699
- x[l] += w * tmp[l];
700
- }
701
- } else {
702
- std::vector<const float*> betas(nsq);
703
- {
704
- const float* b = codebook.data();
705
- const uint8_t* c = &codes[i * code_size];
706
- for (int sq = 0; sq < nsq; sq++) {
707
- betas[sq] = b + (*c++) * (M + 1);
708
- b += (M + 1) * k;
709
- }
710
- }
711
-
712
- index.storage->reconstruct(i, tmp);
713
- {
714
- int d0 = 0;
715
- for (int sq = 0; sq < nsq; sq++) {
716
- float w = *(betas[sq]++);
717
- int d1 = d0 + dsub;
718
- for (int l = d0; l < d1; l++) {
719
- x[l] = w * tmp[l];
720
- }
721
- d0 = d1;
722
- }
723
- }
724
-
725
- for (size_t j = begin; j < end; j++) {
726
- storage_idx_t ji = hnsw.neighbors[j];
727
- if (ji < 0)
728
- ji = i;
729
-
730
- index.storage->reconstruct(ji, tmp);
731
- int d0 = 0;
732
- for (int sq = 0; sq < nsq; sq++) {
733
- float w = *(betas[sq]++);
734
- int d1 = d0 + dsub;
735
- for (int l = d0; l < d1; l++) {
736
- x[l] += w * tmp[l];
737
- }
738
- d0 = d1;
739
- }
740
- }
741
- }
742
- }
743
-
744
- void ReconstructFromNeighbors::reconstruct_n(
745
- storage_idx_t n0,
746
- storage_idx_t ni,
747
- float* x) const {
748
- #pragma omp parallel
749
- {
750
- std::vector<float> tmp(index.d);
751
- #pragma omp for
752
- for (storage_idx_t i = 0; i < ni; i++) {
753
- reconstruct(n0 + i, x + i * index.d, tmp.data());
754
- }
755
- }
756
- }
757
-
758
- size_t ReconstructFromNeighbors::compute_distances(
759
- size_t n,
760
- const idx_t* shortlist,
761
- const float* query,
762
- float* distances) const {
763
- std::vector<float> tmp(2 * index.d);
764
- size_t ncomp = 0;
765
- for (int i = 0; i < n; i++) {
766
- if (shortlist[i] < 0)
767
- break;
768
- reconstruct(shortlist[i], tmp.data(), tmp.data() + index.d);
769
- distances[i] = fvec_L2sqr(query, tmp.data(), index.d);
770
- ncomp++;
771
- }
772
- return ncomp;
773
- }
774
-
775
- void ReconstructFromNeighbors::get_neighbor_table(storage_idx_t i, float* tmp1)
776
- const {
777
- const HNSW& hnsw = index.hnsw;
778
- size_t begin, end;
779
- hnsw.neighbor_range(i, 0, &begin, &end);
780
- size_t d = index.d;
781
-
782
- index.storage->reconstruct(i, tmp1);
783
-
784
- for (size_t j = begin; j < end; j++) {
785
- storage_idx_t ji = hnsw.neighbors[j];
786
- if (ji < 0)
787
- ji = i;
788
- index.storage->reconstruct(ji, tmp1 + (j - begin + 1) * d);
789
- }
790
- }
791
-
792
- /// called by add_codes
793
- void ReconstructFromNeighbors::estimate_code(
794
- const float* x,
795
- storage_idx_t i,
796
- uint8_t* code) const {
797
- // fill in tmp table with the neighbor values
798
- float* tmp1 = new float[d * (M + 1) + (d * k)];
799
- float* tmp2 = tmp1 + d * (M + 1);
800
- ScopeDeleter<float> del(tmp1);
801
-
802
- // collect coordinates of base
803
- get_neighbor_table(i, tmp1);
804
-
805
- for (size_t sq = 0; sq < nsq; sq++) {
806
- int d0 = sq * dsub;
807
-
808
- {
809
- FINTEGER ki = k, di = d, m1 = M + 1;
810
- FINTEGER dsubi = dsub;
811
- float zero = 0, one = 1;
812
-
813
- sgemm_("N",
814
- "N",
815
- &dsubi,
816
- &ki,
817
- &m1,
818
- &one,
819
- tmp1 + d0,
820
- &di,
821
- codebook.data() + sq * (m1 * k),
822
- &m1,
823
- &zero,
824
- tmp2,
825
- &dsubi);
826
- }
827
-
828
- float min = HUGE_VAL;
829
- int argmin = -1;
830
- for (size_t j = 0; j < k; j++) {
831
- float dis = fvec_L2sqr(x + d0, tmp2 + j * dsub, dsub);
832
- if (dis < min) {
833
- min = dis;
834
- argmin = j;
835
- }
836
- }
837
- code[sq] = argmin;
838
- }
839
- }
840
-
841
- void ReconstructFromNeighbors::add_codes(size_t n, const float* x) {
842
- if (k == 1) { // nothing to encode
843
- ntotal += n;
844
- return;
845
- }
846
- codes.resize(codes.size() + code_size * n);
847
- #pragma omp parallel for
848
- for (int i = 0; i < n; i++) {
849
- estimate_code(
850
- x + i * index.d,
851
- ntotal + i,
852
- codes.data() + (ntotal + i) * code_size);
853
- }
854
- ntotal += n;
855
- FAISS_ASSERT(codes.size() == ntotal * code_size);
645
+ void IndexHNSW::permute_entries(const idx_t* perm) {
646
+ auto flat_storage = dynamic_cast<IndexFlatCodes*>(storage);
647
+ FAISS_THROW_IF_NOT_MSG(
648
+ flat_storage, "don't know how to permute this index");
649
+ flat_storage->permute_entries(perm);
650
+ hnsw.permute_entries(perm);
856
651
  }
857
652
 
858
653
  /**************************************************************
@@ -864,7 +659,10 @@ IndexHNSWFlat::IndexHNSWFlat() {
864
659
  }
865
660
 
866
661
  IndexHNSWFlat::IndexHNSWFlat(int d, int M, MetricType metric)
867
- : IndexHNSW(new IndexFlat(d, metric), M) {
662
+ : IndexHNSW(
663
+ (metric == METRIC_L2) ? new IndexFlatL2(d)
664
+ : new IndexFlat(d, metric),
665
+ M) {
868
666
  own_fields = true;
869
667
  is_trained = true;
870
668
  }
@@ -873,10 +671,10 @@ IndexHNSWFlat::IndexHNSWFlat(int d, int M, MetricType metric)
873
671
  * IndexHNSWPQ implementation
874
672
  **************************************************************/
875
673
 
876
- IndexHNSWPQ::IndexHNSWPQ() {}
674
+ IndexHNSWPQ::IndexHNSWPQ() = default;
877
675
 
878
- IndexHNSWPQ::IndexHNSWPQ(int d, int pq_m, int M)
879
- : IndexHNSW(new IndexPQ(d, pq_m, 8), M) {
676
+ IndexHNSWPQ::IndexHNSWPQ(int d, int pq_m, int M, int pq_nbits)
677
+ : IndexHNSW(new IndexPQ(d, pq_m, pq_nbits), M) {
880
678
  own_fields = true;
881
679
  is_trained = false;
882
680
  }
@@ -896,11 +694,11 @@ IndexHNSWSQ::IndexHNSWSQ(
896
694
  int M,
897
695
  MetricType metric)
898
696
  : IndexHNSW(new IndexScalarQuantizer(d, qtype, metric), M) {
899
- is_trained = false;
697
+ is_trained = this->storage->is_trained;
900
698
  own_fields = true;
901
699
  }
902
700
 
903
- IndexHNSWSQ::IndexHNSWSQ() {}
701
+ IndexHNSWSQ::IndexHNSWSQ() = default;
904
702
 
905
703
  /**************************************************************
906
704
  * IndexHNSW2Level implementation
@@ -916,7 +714,7 @@ IndexHNSW2Level::IndexHNSW2Level(
916
714
  is_trained = false;
917
715
  }
918
716
 
919
- IndexHNSW2Level::IndexHNSW2Level() {}
717
+ IndexHNSW2Level::IndexHNSW2Level() = default;
920
718
 
921
719
  namespace {
922
720
 
@@ -935,7 +733,6 @@ int search_from_candidates_2(
935
733
  int level,
936
734
  int nres_in = 0) {
937
735
  int nres = nres_in;
938
- int ndis = 0;
939
736
  for (int i = 0; i < candidates.size(); i++) {
940
737
  idx_t v1 = candidates.ids[i];
941
738
  FAISS_ASSERT(v1 >= 0);
@@ -958,7 +755,6 @@ int search_from_candidates_2(
958
755
  if (vt.visited[v1] == vt.visno + 1) {
959
756
  // nothing to do
960
757
  } else {
961
- ndis++;
962
758
  float d = qdis(v1);
963
759
  candidates.push(v1, d);
964
760
 
@@ -1030,8 +826,8 @@ void IndexHNSW2Level::search(
1030
826
  #pragma omp parallel
1031
827
  {
1032
828
  VisitedTable vt(ntotal);
1033
- DistanceComputer* dis = storage_distance_computer(storage);
1034
- ScopeDeleter1<DistanceComputer> del(dis);
829
+ std::unique_ptr<DistanceComputer> dis(
830
+ storage_distance_computer(storage));
1035
831
 
1036
832
  int candidates_size = hnsw.upper_beam;
1037
833
  MinimaxHeap candidates(candidates_size);