faiss 0.3.0 → 0.3.1

This diff shows the changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the package versions as they appear in their respective public registries.
Files changed (171) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +5 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +1 -1
  5. data/ext/faiss/extconf.rb +9 -2
  6. data/ext/faiss/index.cpp +1 -1
  7. data/ext/faiss/index_binary.cpp +2 -2
  8. data/ext/faiss/product_quantizer.cpp +1 -1
  9. data/lib/faiss/version.rb +1 -1
  10. data/vendor/faiss/faiss/AutoTune.cpp +7 -7
  11. data/vendor/faiss/faiss/AutoTune.h +0 -1
  12. data/vendor/faiss/faiss/Clustering.cpp +4 -18
  13. data/vendor/faiss/faiss/Clustering.h +31 -21
  14. data/vendor/faiss/faiss/IVFlib.cpp +22 -11
  15. data/vendor/faiss/faiss/Index.cpp +1 -1
  16. data/vendor/faiss/faiss/Index.h +20 -5
  17. data/vendor/faiss/faiss/Index2Layer.cpp +7 -7
  18. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +176 -166
  19. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +15 -15
  20. data/vendor/faiss/faiss/IndexBinary.cpp +9 -4
  21. data/vendor/faiss/faiss/IndexBinary.h +8 -19
  22. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +2 -1
  23. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +24 -31
  24. data/vendor/faiss/faiss/IndexBinaryHash.cpp +25 -50
  25. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +106 -187
  26. data/vendor/faiss/faiss/IndexFastScan.cpp +90 -159
  27. data/vendor/faiss/faiss/IndexFastScan.h +9 -8
  28. data/vendor/faiss/faiss/IndexFlat.cpp +195 -3
  29. data/vendor/faiss/faiss/IndexFlat.h +20 -1
  30. data/vendor/faiss/faiss/IndexFlatCodes.cpp +11 -0
  31. data/vendor/faiss/faiss/IndexFlatCodes.h +3 -1
  32. data/vendor/faiss/faiss/IndexHNSW.cpp +112 -316
  33. data/vendor/faiss/faiss/IndexHNSW.h +12 -48
  34. data/vendor/faiss/faiss/IndexIDMap.cpp +69 -28
  35. data/vendor/faiss/faiss/IndexIDMap.h +24 -2
  36. data/vendor/faiss/faiss/IndexIVF.cpp +159 -53
  37. data/vendor/faiss/faiss/IndexIVF.h +37 -5
  38. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +18 -26
  39. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +3 -2
  40. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +19 -46
  41. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +4 -3
  42. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +433 -405
  43. data/vendor/faiss/faiss/IndexIVFFastScan.h +56 -26
  44. data/vendor/faiss/faiss/IndexIVFFlat.cpp +15 -5
  45. data/vendor/faiss/faiss/IndexIVFFlat.h +3 -2
  46. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +172 -0
  47. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.h +56 -0
  48. data/vendor/faiss/faiss/IndexIVFPQ.cpp +78 -122
  49. data/vendor/faiss/faiss/IndexIVFPQ.h +6 -7
  50. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +18 -50
  51. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +4 -3
  52. data/vendor/faiss/faiss/IndexIVFPQR.cpp +45 -29
  53. data/vendor/faiss/faiss/IndexIVFPQR.h +5 -2
  54. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +25 -27
  55. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +6 -6
  56. data/vendor/faiss/faiss/IndexLSH.cpp +14 -16
  57. data/vendor/faiss/faiss/IndexNNDescent.cpp +3 -4
  58. data/vendor/faiss/faiss/IndexNSG.cpp +11 -27
  59. data/vendor/faiss/faiss/IndexNSG.h +10 -10
  60. data/vendor/faiss/faiss/IndexPQ.cpp +72 -88
  61. data/vendor/faiss/faiss/IndexPQ.h +1 -4
  62. data/vendor/faiss/faiss/IndexPQFastScan.cpp +1 -1
  63. data/vendor/faiss/faiss/IndexPreTransform.cpp +25 -31
  64. data/vendor/faiss/faiss/IndexRefine.cpp +49 -19
  65. data/vendor/faiss/faiss/IndexRefine.h +7 -0
  66. data/vendor/faiss/faiss/IndexReplicas.cpp +23 -26
  67. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +22 -16
  68. data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -4
  69. data/vendor/faiss/faiss/IndexShards.cpp +21 -29
  70. data/vendor/faiss/faiss/IndexShardsIVF.cpp +1 -2
  71. data/vendor/faiss/faiss/MatrixStats.cpp +17 -32
  72. data/vendor/faiss/faiss/MatrixStats.h +21 -9
  73. data/vendor/faiss/faiss/MetaIndexes.cpp +35 -35
  74. data/vendor/faiss/faiss/VectorTransform.cpp +13 -26
  75. data/vendor/faiss/faiss/VectorTransform.h +7 -7
  76. data/vendor/faiss/faiss/clone_index.cpp +15 -10
  77. data/vendor/faiss/faiss/clone_index.h +3 -0
  78. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +87 -4
  79. data/vendor/faiss/faiss/gpu/GpuCloner.h +22 -0
  80. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +7 -0
  81. data/vendor/faiss/faiss/gpu/GpuDistance.h +46 -38
  82. data/vendor/faiss/faiss/gpu/GpuIndex.h +28 -4
  83. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +4 -4
  84. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +8 -9
  85. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +18 -3
  86. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -11
  87. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +1 -3
  88. data/vendor/faiss/faiss/gpu/GpuResources.cpp +24 -3
  89. data/vendor/faiss/faiss/gpu/GpuResources.h +39 -11
  90. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +117 -17
  91. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +57 -3
  92. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +1 -1
  93. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +25 -0
  94. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +129 -9
  95. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +267 -40
  96. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +299 -208
  97. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +1 -0
  98. data/vendor/faiss/faiss/gpu/utils/RaftUtils.h +75 -0
  99. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +3 -1
  100. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +5 -5
  101. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +1 -1
  102. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +1 -2
  103. data/vendor/faiss/faiss/impl/DistanceComputer.h +24 -1
  104. data/vendor/faiss/faiss/impl/FaissException.h +13 -34
  105. data/vendor/faiss/faiss/impl/HNSW.cpp +321 -70
  106. data/vendor/faiss/faiss/impl/HNSW.h +9 -8
  107. data/vendor/faiss/faiss/impl/IDSelector.h +4 -4
  108. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +3 -1
  109. data/vendor/faiss/faiss/impl/NNDescent.cpp +29 -19
  110. data/vendor/faiss/faiss/impl/NSG.h +1 -1
  111. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +14 -12
  112. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +1 -1
  113. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +24 -22
  114. data/vendor/faiss/faiss/impl/ProductQuantizer.h +1 -1
  115. data/vendor/faiss/faiss/impl/Quantizer.h +1 -1
  116. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +27 -1015
  117. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +5 -63
  118. data/vendor/faiss/faiss/impl/ResultHandler.h +232 -176
  119. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +444 -104
  120. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +0 -8
  121. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +280 -42
  122. data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +21 -14
  123. data/vendor/faiss/faiss/impl/code_distance/code_distance.h +22 -12
  124. data/vendor/faiss/faiss/impl/index_read.cpp +45 -19
  125. data/vendor/faiss/faiss/impl/index_write.cpp +60 -41
  126. data/vendor/faiss/faiss/impl/io.cpp +10 -10
  127. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -1
  128. data/vendor/faiss/faiss/impl/platform_macros.h +18 -1
  129. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +3 -0
  130. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +7 -6
  131. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +52 -38
  132. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +40 -49
  133. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +960 -0
  134. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +176 -0
  135. data/vendor/faiss/faiss/impl/simd_result_handlers.h +374 -202
  136. data/vendor/faiss/faiss/index_factory.cpp +10 -7
  137. data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -1
  138. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +27 -9
  139. data/vendor/faiss/faiss/invlists/InvertedLists.h +12 -3
  140. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +3 -3
  141. data/vendor/faiss/faiss/python/python_callbacks.cpp +1 -1
  142. data/vendor/faiss/faiss/utils/Heap.cpp +3 -1
  143. data/vendor/faiss/faiss/utils/WorkerThread.h +1 -0
  144. data/vendor/faiss/faiss/utils/distances.cpp +128 -74
  145. data/vendor/faiss/faiss/utils/distances.h +81 -4
  146. data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +5 -5
  147. data/vendor/faiss/faiss/utils/distances_fused/avx512.h +2 -2
  148. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +2 -2
  149. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +1 -1
  150. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +5 -5
  151. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.h +1 -1
  152. data/vendor/faiss/faiss/utils/distances_simd.cpp +428 -70
  153. data/vendor/faiss/faiss/utils/fp16-arm.h +29 -0
  154. data/vendor/faiss/faiss/utils/fp16.h +2 -0
  155. data/vendor/faiss/faiss/utils/hamming.cpp +162 -110
  156. data/vendor/faiss/faiss/utils/hamming.h +58 -0
  157. data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +16 -89
  158. data/vendor/faiss/faiss/utils/hamming_distance/common.h +1 -0
  159. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +15 -87
  160. data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +57 -0
  161. data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +14 -104
  162. data/vendor/faiss/faiss/utils/partitioning.cpp +3 -4
  163. data/vendor/faiss/faiss/utils/prefetch.h +77 -0
  164. data/vendor/faiss/faiss/utils/quantize_lut.cpp +0 -14
  165. data/vendor/faiss/faiss/utils/simdlib_avx2.h +0 -6
  166. data/vendor/faiss/faiss/utils/simdlib_neon.h +72 -77
  167. data/vendor/faiss/faiss/utils/sorting.cpp +140 -5
  168. data/vendor/faiss/faiss/utils/sorting.h +27 -0
  169. data/vendor/faiss/faiss/utils/utils.cpp +112 -6
  170. data/vendor/faiss/faiss/utils/utils.h +57 -20
  171. metadata +10 -3
@@ -18,7 +18,9 @@ namespace faiss {
18
18
 
19
19
  /** Index that stores the full vectors and performs exhaustive search */
20
20
  struct IndexFlat : IndexFlatCodes {
21
- explicit IndexFlat(idx_t d, MetricType metric = METRIC_L2);
21
+ explicit IndexFlat(
22
+ idx_t d, ///< dimensionality of the input vectors
23
+ MetricType metric = METRIC_L2);
22
24
 
23
25
  void search(
24
26
  idx_t n,
@@ -76,8 +78,25 @@ struct IndexFlatIP : IndexFlat {
76
78
  };
77
79
 
78
80
  struct IndexFlatL2 : IndexFlat {
81
+ // Special cache for L2 norms.
82
+ // If this cache is set, then get_distance_computer() returns
83
+ // a special version that computes the distance using dot products
84
+ // and l2 norms.
85
+ std::vector<float> cached_l2norms;
86
+
87
+ /**
88
+ * @param d dimensionality of the input vectors
89
+ */
79
90
  explicit IndexFlatL2(idx_t d) : IndexFlat(d, METRIC_L2) {}
80
91
  IndexFlatL2() {}
92
+
93
+ // override for l2 norms cache.
94
+ FlatCodesDistanceComputer* get_FlatCodesDistanceComputer() const override;
95
+
96
+ // compute L2 norms
97
+ void sync_l2norms();
98
+ // clear L2 norms
99
+ void clear_l2norms();
81
100
  };
82
101
 
83
102
  /// optimized version for 1D "vectors".
@@ -103,4 +103,15 @@ CodePacker* IndexFlatCodes::get_CodePacker() const {
103
103
  return new CodePackerFlat(code_size);
104
104
  }
105
105
 
106
+ void IndexFlatCodes::permute_entries(const idx_t* perm) {
107
+ std::vector<uint8_t> new_codes(codes.size());
108
+
109
+ for (idx_t i = 0; i < ntotal; i++) {
110
+ memcpy(new_codes.data() + i * code_size,
111
+ codes.data() + perm[i] * code_size,
112
+ code_size);
113
+ }
114
+ std::swap(codes, new_codes);
115
+ }
116
+
106
117
  } // namespace faiss
@@ -34,7 +34,6 @@ struct IndexFlatCodes : Index {
34
34
 
35
35
  void reset() override;
36
36
 
37
- /// reconstruction using the codec interface
38
37
  void reconstruct_n(idx_t i0, idx_t ni, float* recons) const override;
39
38
 
40
39
  void reconstruct(idx_t key, float* recons) const override;
@@ -59,6 +58,9 @@ struct IndexFlatCodes : Index {
59
58
  void check_compatible_for_merge(const Index& otherIndex) const override;
60
59
 
61
60
  virtual void merge_from(Index& otherIndex, idx_t add_id = 0) override;
61
+
62
+ // permute_entries. perm of size ntotal maps new to old positions
63
+ void permute_entries(const idx_t* perm);
62
64
  };
63
65
 
64
66
  } // namespace faiss
@@ -20,16 +20,16 @@
20
20
  #include <queue>
21
21
  #include <unordered_set>
22
22
 
23
- #include <stdint.h>
24
23
  #include <sys/stat.h>
25
24
  #include <sys/types.h>
25
+ #include <cstdint>
26
26
 
27
27
  #include <faiss/Index2Layer.h>
28
28
  #include <faiss/IndexFlat.h>
29
29
  #include <faiss/IndexIVFPQ.h>
30
30
  #include <faiss/impl/AuxIndexStructures.h>
31
31
  #include <faiss/impl/FaissAssert.h>
32
- #include <faiss/utils/Heap.h>
32
+ #include <faiss/impl/ResultHandler.h>
33
33
  #include <faiss/utils/distances.h>
34
34
  #include <faiss/utils/random.h>
35
35
  #include <faiss/utils/sorting.h>
@@ -87,6 +87,23 @@ struct NegativeDistanceComputer : DistanceComputer {
87
87
  return -(*basedis)(i);
88
88
  }
89
89
 
90
+ void distances_batch_4(
91
+ const idx_t idx0,
92
+ const idx_t idx1,
93
+ const idx_t idx2,
94
+ const idx_t idx3,
95
+ float& dis0,
96
+ float& dis1,
97
+ float& dis2,
98
+ float& dis3) override {
99
+ basedis->distances_batch_4(
100
+ idx0, idx1, idx2, idx3, dis0, dis1, dis2, dis3);
101
+ dis0 = -dis0;
102
+ dis1 = -dis1;
103
+ dis2 = -dis2;
104
+ dis3 = -dis3;
105
+ }
106
+
90
107
  /// compute distance between two stored vectors
91
108
  float symmetric_dis(idx_t i, idx_t j) override {
92
109
  return -basedis->symmetric_dis(i, j);
@@ -192,9 +209,8 @@ void hnsw_add_vertices(
192
209
  {
193
210
  VisitedTable vt(ntotal);
194
211
 
195
- DistanceComputer* dis =
196
- storage_distance_computer(index_hnsw.storage);
197
- ScopeDeleter1<DistanceComputer> del(dis);
212
+ std::unique_ptr<DistanceComputer> dis(
213
+ storage_distance_computer(index_hnsw.storage));
198
214
  int prev_display =
199
215
  verbose && omp_get_thread_num() == 0 ? 0 : -1;
200
216
  size_t counter = 0;
@@ -250,18 +266,10 @@ void hnsw_add_vertices(
250
266
  **************************************************************/
251
267
 
252
268
  IndexHNSW::IndexHNSW(int d, int M, MetricType metric)
253
- : Index(d, metric),
254
- hnsw(M),
255
- own_fields(false),
256
- storage(nullptr),
257
- reconstruct_from_neighbors(nullptr) {}
269
+ : Index(d, metric), hnsw(M) {}
258
270
 
259
271
  IndexHNSW::IndexHNSW(Index* storage, int M)
260
- : Index(storage->d, storage->metric_type),
261
- hnsw(M),
262
- own_fields(false),
263
- storage(storage),
264
- reconstruct_from_neighbors(nullptr) {}
272
+ : Index(storage->d, storage->metric_type), hnsw(M), storage(storage) {}
265
273
 
266
274
  IndexHNSW::~IndexHNSW() {
267
275
  if (own_fields) {
@@ -278,18 +286,20 @@ void IndexHNSW::train(idx_t n, const float* x) {
278
286
  is_trained = true;
279
287
  }
280
288
 
281
- void IndexHNSW::search(
289
+ namespace {
290
+
291
+ template <class BlockResultHandler>
292
+ void hnsw_search(
293
+ const IndexHNSW* index,
282
294
  idx_t n,
283
295
  const float* x,
284
- idx_t k,
285
- float* distances,
286
- idx_t* labels,
287
- const SearchParameters* params_in) const {
288
- FAISS_THROW_IF_NOT(k > 0);
296
+ BlockResultHandler& bres,
297
+ const SearchParameters* params_in) {
289
298
  FAISS_THROW_IF_NOT_MSG(
290
- storage,
299
+ index->storage,
291
300
  "Please use IndexHNSWFlat (or variants) instead of IndexHNSW directly");
292
301
  const SearchParametersHNSW* params = nullptr;
302
+ const HNSW& hnsw = index->hnsw;
293
303
 
294
304
  int efSearch = hnsw.efSearch;
295
305
  if (params_in) {
@@ -299,61 +309,81 @@ void IndexHNSW::search(
299
309
  }
300
310
  size_t n1 = 0, n2 = 0, n3 = 0, ndis = 0, nreorder = 0;
301
311
 
302
- idx_t check_period =
303
- InterruptCallback::get_period_hint(hnsw.max_level * d * efSearch);
312
+ idx_t check_period = InterruptCallback::get_period_hint(
313
+ hnsw.max_level * index->d * efSearch);
304
314
 
305
315
  for (idx_t i0 = 0; i0 < n; i0 += check_period) {
306
316
  idx_t i1 = std::min(i0 + check_period, n);
307
317
 
308
318
  #pragma omp parallel
309
319
  {
310
- VisitedTable vt(ntotal);
320
+ VisitedTable vt(index->ntotal);
321
+ typename BlockResultHandler::SingleResultHandler res(bres);
311
322
 
312
- DistanceComputer* dis = storage_distance_computer(storage);
313
- ScopeDeleter1<DistanceComputer> del(dis);
323
+ std::unique_ptr<DistanceComputer> dis(
324
+ storage_distance_computer(index->storage));
314
325
 
315
- #pragma omp for reduction(+ : n1, n2, n3, ndis, nreorder)
326
+ #pragma omp for reduction(+ : n1, n2, n3, ndis, nreorder) schedule(guided)
316
327
  for (idx_t i = i0; i < i1; i++) {
317
- idx_t* idxi = labels + i * k;
318
- float* simi = distances + i * k;
319
- dis->set_query(x + i * d);
328
+ res.begin(i);
329
+ dis->set_query(x + i * index->d);
320
330
 
321
- maxheap_heapify(k, simi, idxi);
322
- HNSWStats stats = hnsw.search(*dis, k, idxi, simi, vt, params);
331
+ HNSWStats stats = hnsw.search(*dis, res, vt, params);
323
332
  n1 += stats.n1;
324
333
  n2 += stats.n2;
325
334
  n3 += stats.n3;
326
335
  ndis += stats.ndis;
327
336
  nreorder += stats.nreorder;
328
- maxheap_reorder(k, simi, idxi);
329
-
330
- if (reconstruct_from_neighbors &&
331
- reconstruct_from_neighbors->k_reorder != 0) {
332
- int k_reorder = reconstruct_from_neighbors->k_reorder;
333
- if (k_reorder == -1 || k_reorder > k)
334
- k_reorder = k;
335
-
336
- nreorder += reconstruct_from_neighbors->compute_distances(
337
- k_reorder, idxi, x + i * d, simi);
338
-
339
- // sort top k_reorder
340
- maxheap_heapify(
341
- k_reorder, simi, idxi, simi, idxi, k_reorder);
342
- maxheap_reorder(k_reorder, simi, idxi);
343
- }
337
+ res.end();
344
338
  }
345
339
  }
346
340
  InterruptCallback::check();
347
341
  }
348
342
 
349
- if (is_similarity_metric(metric_type)) {
343
+ hnsw_stats.combine({n1, n2, n3, ndis, nreorder});
344
+ }
345
+
346
+ } // anonymous namespace
347
+
348
+ void IndexHNSW::search(
349
+ idx_t n,
350
+ const float* x,
351
+ idx_t k,
352
+ float* distances,
353
+ idx_t* labels,
354
+ const SearchParameters* params_in) const {
355
+ FAISS_THROW_IF_NOT(k > 0);
356
+
357
+ using RH = HeapBlockResultHandler<HNSW::C>;
358
+ RH bres(n, distances, labels, k);
359
+
360
+ hnsw_search(this, n, x, bres, params_in);
361
+
362
+ if (is_similarity_metric(this->metric_type)) {
350
363
  // we need to revert the negated distances
351
364
  for (size_t i = 0; i < k * n; i++) {
352
365
  distances[i] = -distances[i];
353
366
  }
354
367
  }
368
+ }
355
369
 
356
- hnsw_stats.combine({n1, n2, n3, ndis, nreorder});
370
+ void IndexHNSW::range_search(
371
+ idx_t n,
372
+ const float* x,
373
+ float radius,
374
+ RangeSearchResult* result,
375
+ const SearchParameters* params) const {
376
+ using RH = RangeSearchBlockResultHandler<HNSW::C>;
377
+ RH bres(result, radius);
378
+
379
+ hnsw_search(this, n, x, bres, params);
380
+
381
+ if (is_similarity_metric(this->metric_type)) {
382
+ // we need to revert the negated distances
383
+ for (size_t i = 0; i < result->lims[result->nq]; i++) {
384
+ result->distances[i] = -result->distances[i];
385
+ }
386
+ }
357
387
  }
358
388
 
359
389
  void IndexHNSW::add(idx_t n, const float* x) {
@@ -381,8 +411,8 @@ void IndexHNSW::reconstruct(idx_t key, float* recons) const {
381
411
  void IndexHNSW::shrink_level_0_neighbors(int new_size) {
382
412
  #pragma omp parallel
383
413
  {
384
- DistanceComputer* dis = storage_distance_computer(storage);
385
- ScopeDeleter1<DistanceComputer> del(dis);
414
+ std::unique_ptr<DistanceComputer> dis(
415
+ storage_distance_computer(storage));
386
416
 
387
417
  #pragma omp for
388
418
  for (idx_t i = 0; i < ntotal; i++) {
@@ -429,35 +459,33 @@ void IndexHNSW::search_level_0(
429
459
 
430
460
  storage_idx_t ntotal = hnsw.levels.size();
431
461
 
462
+ using RH = HeapBlockResultHandler<HNSW::C>;
463
+ RH bres(n, distances, labels, k);
464
+
432
465
  #pragma omp parallel
433
466
  {
434
467
  std::unique_ptr<DistanceComputer> qdis(
435
468
  storage_distance_computer(storage));
436
469
  HNSWStats search_stats;
437
470
  VisitedTable vt(ntotal);
471
+ RH::SingleResultHandler res(bres);
438
472
 
439
473
  #pragma omp for
440
474
  for (idx_t i = 0; i < n; i++) {
441
- idx_t* idxi = labels + i * k;
442
- float* simi = distances + i * k;
443
-
475
+ res.begin(i);
444
476
  qdis->set_query(x + i * d);
445
- maxheap_heapify(k, simi, idxi);
446
477
 
447
478
  hnsw.search_level_0(
448
479
  *qdis.get(),
449
- k,
450
- idxi,
451
- simi,
480
+ res,
452
481
  nprobe,
453
482
  nearest + i * nprobe,
454
483
  nearest_d + i * nprobe,
455
484
  search_type,
456
485
  search_stats,
457
486
  vt);
458
-
487
+ res.end();
459
488
  vt.advance();
460
- maxheap_reorder(k, simi, idxi);
461
489
  }
462
490
  #pragma omp critical
463
491
  { hnsw_stats.combine(search_stats); }
@@ -515,8 +543,8 @@ void IndexHNSW::init_level_0_from_entry_points(
515
543
  {
516
544
  VisitedTable vt(ntotal);
517
545
 
518
- DistanceComputer* dis = storage_distance_computer(storage);
519
- ScopeDeleter1<DistanceComputer> del(dis);
546
+ std::unique_ptr<DistanceComputer> dis(
547
+ storage_distance_computer(storage));
520
548
  std::vector<float> vec(storage->d);
521
549
 
522
550
  #pragma omp for schedule(dynamic)
@@ -551,8 +579,8 @@ void IndexHNSW::reorder_links() {
551
579
  std::vector<float> distances(M);
552
580
  std::vector<size_t> order(M);
553
581
  std::vector<storage_idx_t> tmp(M);
554
- DistanceComputer* dis = storage_distance_computer(storage);
555
- ScopeDeleter1<DistanceComputer> del(dis);
582
+ std::unique_ptr<DistanceComputer> dis(
583
+ storage_distance_computer(storage));
556
584
 
557
585
  #pragma omp for
558
586
  for (storage_idx_t i = 0; i < ntotal; i++) {
@@ -614,245 +642,12 @@ void IndexHNSW::link_singletons() {
614
642
  }
615
643
  }
616
644
 
617
- /**************************************************************
618
- * ReconstructFromNeighbors implementation
619
- **************************************************************/
620
-
621
- ReconstructFromNeighbors::ReconstructFromNeighbors(
622
- const IndexHNSW& index,
623
- size_t k,
624
- size_t nsq)
625
- : index(index), k(k), nsq(nsq) {
626
- M = index.hnsw.nb_neighbors(0);
627
- FAISS_ASSERT(k <= 256);
628
- code_size = k == 1 ? 0 : nsq;
629
- ntotal = 0;
630
- d = index.d;
631
- FAISS_ASSERT(d % nsq == 0);
632
- dsub = d / nsq;
633
- k_reorder = -1;
634
- }
635
-
636
- void ReconstructFromNeighbors::reconstruct(
637
- storage_idx_t i,
638
- float* x,
639
- float* tmp) const {
640
- const HNSW& hnsw = index.hnsw;
641
- size_t begin, end;
642
- hnsw.neighbor_range(i, 0, &begin, &end);
643
-
644
- if (k == 1 || nsq == 1) {
645
- const float* beta;
646
- if (k == 1) {
647
- beta = codebook.data();
648
- } else {
649
- int idx = codes[i];
650
- beta = codebook.data() + idx * (M + 1);
651
- }
652
-
653
- float w0 = beta[0]; // weight of image itself
654
- index.storage->reconstruct(i, tmp);
655
-
656
- for (int l = 0; l < d; l++)
657
- x[l] = w0 * tmp[l];
658
-
659
- for (size_t j = begin; j < end; j++) {
660
- storage_idx_t ji = hnsw.neighbors[j];
661
- if (ji < 0)
662
- ji = i;
663
- float w = beta[j - begin + 1];
664
- index.storage->reconstruct(ji, tmp);
665
- for (int l = 0; l < d; l++)
666
- x[l] += w * tmp[l];
667
- }
668
- } else if (nsq == 2) {
669
- int idx0 = codes[2 * i];
670
- int idx1 = codes[2 * i + 1];
671
-
672
- const float* beta0 = codebook.data() + idx0 * (M + 1);
673
- const float* beta1 = codebook.data() + (idx1 + k) * (M + 1);
674
-
675
- index.storage->reconstruct(i, tmp);
676
-
677
- float w0;
678
-
679
- w0 = beta0[0];
680
- for (int l = 0; l < dsub; l++)
681
- x[l] = w0 * tmp[l];
682
-
683
- w0 = beta1[0];
684
- for (int l = dsub; l < d; l++)
685
- x[l] = w0 * tmp[l];
686
-
687
- for (size_t j = begin; j < end; j++) {
688
- storage_idx_t ji = hnsw.neighbors[j];
689
- if (ji < 0)
690
- ji = i;
691
- index.storage->reconstruct(ji, tmp);
692
- float w;
693
- w = beta0[j - begin + 1];
694
- for (int l = 0; l < dsub; l++)
695
- x[l] += w * tmp[l];
696
-
697
- w = beta1[j - begin + 1];
698
- for (int l = dsub; l < d; l++)
699
- x[l] += w * tmp[l];
700
- }
701
- } else {
702
- std::vector<const float*> betas(nsq);
703
- {
704
- const float* b = codebook.data();
705
- const uint8_t* c = &codes[i * code_size];
706
- for (int sq = 0; sq < nsq; sq++) {
707
- betas[sq] = b + (*c++) * (M + 1);
708
- b += (M + 1) * k;
709
- }
710
- }
711
-
712
- index.storage->reconstruct(i, tmp);
713
- {
714
- int d0 = 0;
715
- for (int sq = 0; sq < nsq; sq++) {
716
- float w = *(betas[sq]++);
717
- int d1 = d0 + dsub;
718
- for (int l = d0; l < d1; l++) {
719
- x[l] = w * tmp[l];
720
- }
721
- d0 = d1;
722
- }
723
- }
724
-
725
- for (size_t j = begin; j < end; j++) {
726
- storage_idx_t ji = hnsw.neighbors[j];
727
- if (ji < 0)
728
- ji = i;
729
-
730
- index.storage->reconstruct(ji, tmp);
731
- int d0 = 0;
732
- for (int sq = 0; sq < nsq; sq++) {
733
- float w = *(betas[sq]++);
734
- int d1 = d0 + dsub;
735
- for (int l = d0; l < d1; l++) {
736
- x[l] += w * tmp[l];
737
- }
738
- d0 = d1;
739
- }
740
- }
741
- }
742
- }
743
-
744
- void ReconstructFromNeighbors::reconstruct_n(
745
- storage_idx_t n0,
746
- storage_idx_t ni,
747
- float* x) const {
748
- #pragma omp parallel
749
- {
750
- std::vector<float> tmp(index.d);
751
- #pragma omp for
752
- for (storage_idx_t i = 0; i < ni; i++) {
753
- reconstruct(n0 + i, x + i * index.d, tmp.data());
754
- }
755
- }
756
- }
757
-
758
- size_t ReconstructFromNeighbors::compute_distances(
759
- size_t n,
760
- const idx_t* shortlist,
761
- const float* query,
762
- float* distances) const {
763
- std::vector<float> tmp(2 * index.d);
764
- size_t ncomp = 0;
765
- for (int i = 0; i < n; i++) {
766
- if (shortlist[i] < 0)
767
- break;
768
- reconstruct(shortlist[i], tmp.data(), tmp.data() + index.d);
769
- distances[i] = fvec_L2sqr(query, tmp.data(), index.d);
770
- ncomp++;
771
- }
772
- return ncomp;
773
- }
774
-
775
- void ReconstructFromNeighbors::get_neighbor_table(storage_idx_t i, float* tmp1)
776
- const {
777
- const HNSW& hnsw = index.hnsw;
778
- size_t begin, end;
779
- hnsw.neighbor_range(i, 0, &begin, &end);
780
- size_t d = index.d;
781
-
782
- index.storage->reconstruct(i, tmp1);
783
-
784
- for (size_t j = begin; j < end; j++) {
785
- storage_idx_t ji = hnsw.neighbors[j];
786
- if (ji < 0)
787
- ji = i;
788
- index.storage->reconstruct(ji, tmp1 + (j - begin + 1) * d);
789
- }
790
- }
791
-
792
- /// called by add_codes
793
- void ReconstructFromNeighbors::estimate_code(
794
- const float* x,
795
- storage_idx_t i,
796
- uint8_t* code) const {
797
- // fill in tmp table with the neighbor values
798
- float* tmp1 = new float[d * (M + 1) + (d * k)];
799
- float* tmp2 = tmp1 + d * (M + 1);
800
- ScopeDeleter<float> del(tmp1);
801
-
802
- // collect coordinates of base
803
- get_neighbor_table(i, tmp1);
804
-
805
- for (size_t sq = 0; sq < nsq; sq++) {
806
- int d0 = sq * dsub;
807
-
808
- {
809
- FINTEGER ki = k, di = d, m1 = M + 1;
810
- FINTEGER dsubi = dsub;
811
- float zero = 0, one = 1;
812
-
813
- sgemm_("N",
814
- "N",
815
- &dsubi,
816
- &ki,
817
- &m1,
818
- &one,
819
- tmp1 + d0,
820
- &di,
821
- codebook.data() + sq * (m1 * k),
822
- &m1,
823
- &zero,
824
- tmp2,
825
- &dsubi);
826
- }
827
-
828
- float min = HUGE_VAL;
829
- int argmin = -1;
830
- for (size_t j = 0; j < k; j++) {
831
- float dis = fvec_L2sqr(x + d0, tmp2 + j * dsub, dsub);
832
- if (dis < min) {
833
- min = dis;
834
- argmin = j;
835
- }
836
- }
837
- code[sq] = argmin;
838
- }
839
- }
840
-
841
- void ReconstructFromNeighbors::add_codes(size_t n, const float* x) {
842
- if (k == 1) { // nothing to encode
843
- ntotal += n;
844
- return;
845
- }
846
- codes.resize(codes.size() + code_size * n);
847
- #pragma omp parallel for
848
- for (int i = 0; i < n; i++) {
849
- estimate_code(
850
- x + i * index.d,
851
- ntotal + i,
852
- codes.data() + (ntotal + i) * code_size);
853
- }
854
- ntotal += n;
855
- FAISS_ASSERT(codes.size() == ntotal * code_size);
645
+ void IndexHNSW::permute_entries(const idx_t* perm) {
646
+ auto flat_storage = dynamic_cast<IndexFlatCodes*>(storage);
647
+ FAISS_THROW_IF_NOT_MSG(
648
+ flat_storage, "don't know how to permute this index");
649
+ flat_storage->permute_entries(perm);
650
+ hnsw.permute_entries(perm);
856
651
  }
857
652
 
858
653
  /**************************************************************
@@ -864,7 +659,10 @@ IndexHNSWFlat::IndexHNSWFlat() {
864
659
  }
865
660
 
866
661
  IndexHNSWFlat::IndexHNSWFlat(int d, int M, MetricType metric)
867
- : IndexHNSW(new IndexFlat(d, metric), M) {
662
+ : IndexHNSW(
663
+ (metric == METRIC_L2) ? new IndexFlatL2(d)
664
+ : new IndexFlat(d, metric),
665
+ M) {
868
666
  own_fields = true;
869
667
  is_trained = true;
870
668
  }
@@ -873,10 +671,10 @@ IndexHNSWFlat::IndexHNSWFlat(int d, int M, MetricType metric)
873
671
  * IndexHNSWPQ implementation
874
672
  **************************************************************/
875
673
 
876
- IndexHNSWPQ::IndexHNSWPQ() {}
674
+ IndexHNSWPQ::IndexHNSWPQ() = default;
877
675
 
878
- IndexHNSWPQ::IndexHNSWPQ(int d, int pq_m, int M)
879
- : IndexHNSW(new IndexPQ(d, pq_m, 8), M) {
676
+ IndexHNSWPQ::IndexHNSWPQ(int d, int pq_m, int M, int pq_nbits)
677
+ : IndexHNSW(new IndexPQ(d, pq_m, pq_nbits), M) {
880
678
  own_fields = true;
881
679
  is_trained = false;
882
680
  }
@@ -896,11 +694,11 @@ IndexHNSWSQ::IndexHNSWSQ(
896
694
  int M,
897
695
  MetricType metric)
898
696
  : IndexHNSW(new IndexScalarQuantizer(d, qtype, metric), M) {
899
- is_trained = false;
697
+ is_trained = this->storage->is_trained;
900
698
  own_fields = true;
901
699
  }
902
700
 
903
- IndexHNSWSQ::IndexHNSWSQ() {}
701
+ IndexHNSWSQ::IndexHNSWSQ() = default;
904
702
 
905
703
  /**************************************************************
906
704
  * IndexHNSW2Level implementation
@@ -916,7 +714,7 @@ IndexHNSW2Level::IndexHNSW2Level(
916
714
  is_trained = false;
917
715
  }
918
716
 
919
- IndexHNSW2Level::IndexHNSW2Level() {}
717
+ IndexHNSW2Level::IndexHNSW2Level() = default;
920
718
 
921
719
  namespace {
922
720
 
@@ -935,7 +733,6 @@ int search_from_candidates_2(
935
733
  int level,
936
734
  int nres_in = 0) {
937
735
  int nres = nres_in;
938
- int ndis = 0;
939
736
  for (int i = 0; i < candidates.size(); i++) {
940
737
  idx_t v1 = candidates.ids[i];
941
738
  FAISS_ASSERT(v1 >= 0);
@@ -958,7 +755,6 @@ int search_from_candidates_2(
958
755
  if (vt.visited[v1] == vt.visno + 1) {
959
756
  // nothing to do
960
757
  } else {
961
- ndis++;
962
758
  float d = qdis(v1);
963
759
  candidates.push(v1, d);
964
760
 
@@ -1030,8 +826,8 @@ void IndexHNSW2Level::search(
1030
826
  #pragma omp parallel
1031
827
  {
1032
828
  VisitedTable vt(ntotal);
1033
- DistanceComputer* dis = storage_distance_computer(storage);
1034
- ScopeDeleter1<DistanceComputer> del(dis);
829
+ std::unique_ptr<DistanceComputer> dis(
830
+ storage_distance_computer(storage));
1035
831
 
1036
832
  int candidates_size = hnsw.upper_beam;
1037
833
  MinimaxHeap candidates(candidates_size);