faiss 0.3.0 → 0.3.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (216) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +9 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +1 -1
  5. data/ext/faiss/extconf.rb +9 -2
  6. data/ext/faiss/index.cpp +1 -1
  7. data/ext/faiss/index_binary.cpp +2 -2
  8. data/ext/faiss/product_quantizer.cpp +1 -1
  9. data/lib/faiss/version.rb +1 -1
  10. data/vendor/faiss/faiss/AutoTune.cpp +7 -7
  11. data/vendor/faiss/faiss/AutoTune.h +1 -2
  12. data/vendor/faiss/faiss/Clustering.cpp +39 -22
  13. data/vendor/faiss/faiss/Clustering.h +40 -21
  14. data/vendor/faiss/faiss/IVFlib.cpp +26 -12
  15. data/vendor/faiss/faiss/Index.cpp +1 -1
  16. data/vendor/faiss/faiss/Index.h +40 -10
  17. data/vendor/faiss/faiss/Index2Layer.cpp +7 -7
  18. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +176 -166
  19. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +15 -15
  20. data/vendor/faiss/faiss/IndexBinary.cpp +9 -4
  21. data/vendor/faiss/faiss/IndexBinary.h +8 -19
  22. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +2 -1
  23. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +24 -31
  24. data/vendor/faiss/faiss/IndexBinaryHNSW.h +1 -1
  25. data/vendor/faiss/faiss/IndexBinaryHash.cpp +25 -50
  26. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +107 -188
  27. data/vendor/faiss/faiss/IndexFastScan.cpp +95 -146
  28. data/vendor/faiss/faiss/IndexFastScan.h +9 -8
  29. data/vendor/faiss/faiss/IndexFlat.cpp +206 -10
  30. data/vendor/faiss/faiss/IndexFlat.h +20 -1
  31. data/vendor/faiss/faiss/IndexFlatCodes.cpp +170 -5
  32. data/vendor/faiss/faiss/IndexFlatCodes.h +23 -4
  33. data/vendor/faiss/faiss/IndexHNSW.cpp +231 -382
  34. data/vendor/faiss/faiss/IndexHNSW.h +62 -49
  35. data/vendor/faiss/faiss/IndexIDMap.cpp +69 -28
  36. data/vendor/faiss/faiss/IndexIDMap.h +24 -2
  37. data/vendor/faiss/faiss/IndexIVF.cpp +162 -56
  38. data/vendor/faiss/faiss/IndexIVF.h +46 -6
  39. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +33 -26
  40. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +6 -2
  41. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +19 -46
  42. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +4 -3
  43. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +502 -401
  44. data/vendor/faiss/faiss/IndexIVFFastScan.h +63 -26
  45. data/vendor/faiss/faiss/IndexIVFFlat.cpp +15 -5
  46. data/vendor/faiss/faiss/IndexIVFFlat.h +3 -2
  47. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +172 -0
  48. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.h +56 -0
  49. data/vendor/faiss/faiss/IndexIVFPQ.cpp +79 -125
  50. data/vendor/faiss/faiss/IndexIVFPQ.h +6 -7
  51. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +39 -52
  52. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +4 -3
  53. data/vendor/faiss/faiss/IndexIVFPQR.cpp +45 -29
  54. data/vendor/faiss/faiss/IndexIVFPQR.h +5 -2
  55. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +25 -27
  56. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +6 -6
  57. data/vendor/faiss/faiss/IndexLSH.cpp +14 -16
  58. data/vendor/faiss/faiss/IndexLattice.cpp +1 -19
  59. data/vendor/faiss/faiss/IndexLattice.h +3 -22
  60. data/vendor/faiss/faiss/IndexNNDescent.cpp +3 -33
  61. data/vendor/faiss/faiss/IndexNNDescent.h +1 -1
  62. data/vendor/faiss/faiss/IndexNSG.cpp +11 -27
  63. data/vendor/faiss/faiss/IndexNSG.h +11 -11
  64. data/vendor/faiss/faiss/IndexNeuralNetCodec.cpp +56 -0
  65. data/vendor/faiss/faiss/IndexNeuralNetCodec.h +49 -0
  66. data/vendor/faiss/faiss/IndexPQ.cpp +72 -88
  67. data/vendor/faiss/faiss/IndexPQ.h +1 -4
  68. data/vendor/faiss/faiss/IndexPQFastScan.cpp +1 -1
  69. data/vendor/faiss/faiss/IndexPreTransform.cpp +25 -31
  70. data/vendor/faiss/faiss/IndexPreTransform.h +1 -1
  71. data/vendor/faiss/faiss/IndexRefine.cpp +54 -24
  72. data/vendor/faiss/faiss/IndexRefine.h +7 -0
  73. data/vendor/faiss/faiss/IndexReplicas.cpp +23 -26
  74. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +25 -17
  75. data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -4
  76. data/vendor/faiss/faiss/IndexShards.cpp +21 -29
  77. data/vendor/faiss/faiss/IndexShardsIVF.cpp +1 -2
  78. data/vendor/faiss/faiss/MatrixStats.cpp +17 -32
  79. data/vendor/faiss/faiss/MatrixStats.h +21 -9
  80. data/vendor/faiss/faiss/MetaIndexes.cpp +35 -35
  81. data/vendor/faiss/faiss/MetricType.h +7 -2
  82. data/vendor/faiss/faiss/VectorTransform.cpp +13 -26
  83. data/vendor/faiss/faiss/VectorTransform.h +7 -7
  84. data/vendor/faiss/faiss/clone_index.cpp +15 -10
  85. data/vendor/faiss/faiss/clone_index.h +3 -0
  86. data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +95 -17
  87. data/vendor/faiss/faiss/cppcontrib/factory_tools.cpp +152 -0
  88. data/vendor/faiss/faiss/cppcontrib/factory_tools.h +24 -0
  89. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +83 -30
  90. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +123 -8
  91. data/vendor/faiss/faiss/gpu/GpuCloner.h +22 -0
  92. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +13 -0
  93. data/vendor/faiss/faiss/gpu/GpuDistance.h +46 -38
  94. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -1
  95. data/vendor/faiss/faiss/gpu/GpuIndex.h +30 -12
  96. data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +282 -0
  97. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +4 -4
  98. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +14 -9
  99. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +20 -3
  100. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -11
  101. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +1 -3
  102. data/vendor/faiss/faiss/gpu/GpuResources.cpp +24 -3
  103. data/vendor/faiss/faiss/gpu/GpuResources.h +39 -11
  104. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +142 -17
  105. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +57 -3
  106. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +26 -21
  107. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +7 -1
  108. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +8 -5
  109. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +25 -0
  110. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +129 -9
  111. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +332 -40
  112. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +299 -208
  113. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +1 -0
  114. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +1 -1
  115. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +6 -0
  116. data/vendor/faiss/faiss/gpu/utils/RaftUtils.h +75 -0
  117. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +4 -1
  118. data/vendor/faiss/faiss/gpu/utils/Timer.h +1 -1
  119. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +3 -1
  120. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +5 -5
  121. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +26 -1
  122. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +10 -3
  123. data/vendor/faiss/faiss/impl/DistanceComputer.h +70 -1
  124. data/vendor/faiss/faiss/impl/FaissAssert.h +4 -2
  125. data/vendor/faiss/faiss/impl/FaissException.h +13 -34
  126. data/vendor/faiss/faiss/impl/HNSW.cpp +605 -186
  127. data/vendor/faiss/faiss/impl/HNSW.h +52 -30
  128. data/vendor/faiss/faiss/impl/IDSelector.h +4 -4
  129. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +11 -9
  130. data/vendor/faiss/faiss/impl/LookupTableScaler.h +34 -0
  131. data/vendor/faiss/faiss/impl/NNDescent.cpp +42 -27
  132. data/vendor/faiss/faiss/impl/NSG.cpp +0 -29
  133. data/vendor/faiss/faiss/impl/NSG.h +1 -1
  134. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +14 -12
  135. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +1 -1
  136. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +25 -22
  137. data/vendor/faiss/faiss/impl/ProductQuantizer.h +6 -2
  138. data/vendor/faiss/faiss/impl/Quantizer.h +1 -1
  139. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +27 -1015
  140. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +5 -63
  141. data/vendor/faiss/faiss/impl/ResultHandler.h +347 -172
  142. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +1104 -147
  143. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +3 -8
  144. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +285 -42
  145. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx512.h +248 -0
  146. data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +21 -14
  147. data/vendor/faiss/faiss/impl/code_distance/code_distance.h +22 -12
  148. data/vendor/faiss/faiss/impl/index_read.cpp +74 -34
  149. data/vendor/faiss/faiss/impl/index_read_utils.h +37 -0
  150. data/vendor/faiss/faiss/impl/index_write.cpp +88 -51
  151. data/vendor/faiss/faiss/impl/io.cpp +23 -15
  152. data/vendor/faiss/faiss/impl/io.h +4 -4
  153. data/vendor/faiss/faiss/impl/io_macros.h +6 -0
  154. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -1
  155. data/vendor/faiss/faiss/impl/platform_macros.h +40 -1
  156. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +14 -0
  157. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +7 -6
  158. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +52 -38
  159. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +487 -49
  160. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +960 -0
  161. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +176 -0
  162. data/vendor/faiss/faiss/impl/simd_result_handlers.h +481 -225
  163. data/vendor/faiss/faiss/index_factory.cpp +41 -20
  164. data/vendor/faiss/faiss/index_io.h +12 -5
  165. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +28 -8
  166. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +3 -0
  167. data/vendor/faiss/faiss/invlists/DirectMap.cpp +10 -2
  168. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +73 -17
  169. data/vendor/faiss/faiss/invlists/InvertedLists.h +26 -8
  170. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +24 -9
  171. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +2 -1
  172. data/vendor/faiss/faiss/python/python_callbacks.cpp +4 -4
  173. data/vendor/faiss/faiss/utils/Heap.cpp +3 -1
  174. data/vendor/faiss/faiss/utils/Heap.h +105 -0
  175. data/vendor/faiss/faiss/utils/NeuralNet.cpp +342 -0
  176. data/vendor/faiss/faiss/utils/NeuralNet.h +147 -0
  177. data/vendor/faiss/faiss/utils/WorkerThread.h +1 -0
  178. data/vendor/faiss/faiss/utils/bf16.h +36 -0
  179. data/vendor/faiss/faiss/utils/distances.cpp +147 -123
  180. data/vendor/faiss/faiss/utils/distances.h +86 -9
  181. data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +5 -5
  182. data/vendor/faiss/faiss/utils/distances_fused/avx512.h +2 -2
  183. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +2 -2
  184. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +1 -1
  185. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +5 -5
  186. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.h +1 -1
  187. data/vendor/faiss/faiss/utils/distances_simd.cpp +1589 -243
  188. data/vendor/faiss/faiss/utils/extra_distances-inl.h +70 -0
  189. data/vendor/faiss/faiss/utils/extra_distances.cpp +85 -137
  190. data/vendor/faiss/faiss/utils/extra_distances.h +3 -2
  191. data/vendor/faiss/faiss/utils/fp16-arm.h +29 -0
  192. data/vendor/faiss/faiss/utils/fp16.h +2 -0
  193. data/vendor/faiss/faiss/utils/hamming.cpp +163 -111
  194. data/vendor/faiss/faiss/utils/hamming.h +58 -0
  195. data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +16 -89
  196. data/vendor/faiss/faiss/utils/hamming_distance/common.h +1 -0
  197. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +19 -88
  198. data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +58 -0
  199. data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +14 -104
  200. data/vendor/faiss/faiss/utils/partitioning.cpp +3 -4
  201. data/vendor/faiss/faiss/utils/prefetch.h +77 -0
  202. data/vendor/faiss/faiss/utils/quantize_lut.cpp +0 -14
  203. data/vendor/faiss/faiss/utils/random.cpp +43 -0
  204. data/vendor/faiss/faiss/utils/random.h +25 -0
  205. data/vendor/faiss/faiss/utils/simdlib.h +10 -1
  206. data/vendor/faiss/faiss/utils/simdlib_avx2.h +0 -6
  207. data/vendor/faiss/faiss/utils/simdlib_avx512.h +296 -0
  208. data/vendor/faiss/faiss/utils/simdlib_neon.h +77 -79
  209. data/vendor/faiss/faiss/utils/simdlib_ppc64.h +1084 -0
  210. data/vendor/faiss/faiss/utils/sorting.cpp +140 -5
  211. data/vendor/faiss/faiss/utils/sorting.h +27 -0
  212. data/vendor/faiss/faiss/utils/transpose/transpose-avx512-inl.h +176 -0
  213. data/vendor/faiss/faiss/utils/utils.cpp +120 -7
  214. data/vendor/faiss/faiss/utils/utils.h +60 -20
  215. metadata +23 -4
  216. data/vendor/faiss/faiss/impl/code_distance/code_distance_avx512.h +0 -102
@@ -5,15 +5,25 @@
5
5
  * LICENSE file in the root directory of this source tree.
6
6
  */
7
7
 
8
- // -*- c++ -*-
9
-
10
8
  #include <faiss/impl/HNSW.h>
11
9
 
10
+ #include <cstddef>
12
11
  #include <string>
13
12
 
14
13
  #include <faiss/impl/AuxIndexStructures.h>
15
14
  #include <faiss/impl/DistanceComputer.h>
16
15
  #include <faiss/impl/IDSelector.h>
16
+ #include <faiss/impl/ResultHandler.h>
17
+ #include <faiss/utils/prefetch.h>
18
+
19
+ #include <faiss/impl/platform_macros.h>
20
+
21
+ #ifdef __AVX2__
22
+ #include <immintrin.h>
23
+
24
+ #include <limits>
25
+ #include <type_traits>
26
+ #endif
17
27
 
18
28
  namespace faiss {
19
29
 
@@ -101,8 +111,8 @@ void HNSW::print_neighbor_stats(int level) const {
101
111
  level,
102
112
  nb_neighbors(level));
103
113
  size_t tot_neigh = 0, tot_common = 0, tot_reciprocal = 0, n_node = 0;
104
- #pragma omp parallel for reduction(+: tot_neigh) reduction(+: tot_common) \
105
- reduction(+: tot_reciprocal) reduction(+: n_node)
114
+ #pragma omp parallel for reduction(+ : tot_neigh) reduction(+ : tot_common) \
115
+ reduction(+ : tot_reciprocal) reduction(+ : n_node)
106
116
  for (int i = 0; i < levels.size(); i++) {
107
117
  if (levels[i] > level) {
108
118
  n_node++;
@@ -206,13 +216,13 @@ int HNSW::prepare_level_tab(size_t n, bool preset_levels) {
206
216
  if (pt_level > max_level)
207
217
  max_level = pt_level;
208
218
  offsets.push_back(offsets.back() + cum_nb_neighbors(pt_level + 1));
209
- neighbors.resize(offsets.back(), -1);
210
219
  }
220
+ neighbors.resize(offsets.back(), -1);
211
221
 
212
222
  return max_level;
213
223
  }
214
224
 
215
- /** Enumerate vertices from farthest to nearest from query, keep a
225
+ /** Enumerate vertices from nearest to farthest from query, keep a
216
226
  * neighbor only if there is no previous neighbor that is closer to
217
227
  * that vertex than the query.
218
228
  */
@@ -220,7 +230,14 @@ void HNSW::shrink_neighbor_list(
220
230
  DistanceComputer& qdis,
221
231
  std::priority_queue<NodeDistFarther>& input,
222
232
  std::vector<NodeDistFarther>& output,
223
- int max_size) {
233
+ int max_size,
234
+ bool keep_max_size_level0) {
235
+ // This prevents number of neighbors at
236
+ // level 0 from being shrunk to less than 2 * M.
237
+ // This is essential in making sure
238
+ // `faiss::gpu::GpuIndexCagra::copyFrom(IndexHNSWCagra*)` is functional
239
+ std::vector<NodeDistFarther> outsiders;
240
+
224
241
  while (input.size() > 0) {
225
242
  NodeDistFarther v1 = input.top();
226
243
  input.pop();
@@ -241,8 +258,15 @@ void HNSW::shrink_neighbor_list(
241
258
  if (output.size() >= max_size) {
242
259
  return;
243
260
  }
261
+ } else if (keep_max_size_level0) {
262
+ outsiders.push_back(v1);
244
263
  }
245
264
  }
265
+ size_t idx = 0;
266
+ while (keep_max_size_level0 && (output.size() < max_size) &&
267
+ (idx < outsiders.size())) {
268
+ output.push_back(outsiders[idx++]);
269
+ }
246
270
  }
247
271
 
248
272
  namespace {
@@ -259,7 +283,8 @@ using NodeDistFarther = HNSW::NodeDistFarther;
259
283
  void shrink_neighbor_list(
260
284
  DistanceComputer& qdis,
261
285
  std::priority_queue<NodeDistCloser>& resultSet1,
262
- int max_size) {
286
+ int max_size,
287
+ bool keep_max_size_level0 = false) {
263
288
  if (resultSet1.size() < max_size) {
264
289
  return;
265
290
  }
@@ -271,7 +296,8 @@ void shrink_neighbor_list(
271
296
  resultSet1.pop();
272
297
  }
273
298
 
274
- HNSW::shrink_neighbor_list(qdis, resultSet, returnlist, max_size);
299
+ HNSW::shrink_neighbor_list(
300
+ qdis, resultSet, returnlist, max_size, keep_max_size_level0);
275
301
 
276
302
  for (NodeDistFarther curen2 : returnlist) {
277
303
  resultSet1.emplace(curen2.d, curen2.id);
@@ -285,7 +311,8 @@ void add_link(
285
311
  DistanceComputer& qdis,
286
312
  storage_idx_t src,
287
313
  storage_idx_t dest,
288
- int level) {
314
+ int level,
315
+ bool keep_max_size_level0 = false) {
289
316
  size_t begin, end;
290
317
  hnsw.neighbor_range(src, level, &begin, &end);
291
318
  if (hnsw.neighbors[end - 1] == -1) {
@@ -310,7 +337,7 @@ void add_link(
310
337
  resultSet.emplace(qdis.symmetric_dis(src, neigh), neigh);
311
338
  }
312
339
 
313
- shrink_neighbor_list(qdis, resultSet, end - begin);
340
+ shrink_neighbor_list(qdis, resultSet, end - begin, keep_max_size_level0);
314
341
 
315
342
  // ...and back
316
343
  size_t i = begin;
@@ -333,6 +360,9 @@ void search_neighbors_to_add(
333
360
  float d_entry_point,
334
361
  int level,
335
362
  VisitedTable& vt) {
363
+ // selects a version
364
+ const bool reference_version = false;
365
+
336
366
  // top is nearest candidate
337
367
  std::priority_queue<NodeDistFarther> candidates;
338
368
 
@@ -354,59 +384,90 @@ void search_neighbors_to_add(
354
384
  // loop over neighbors
355
385
  size_t begin, end;
356
386
  hnsw.neighbor_range(currNode, level, &begin, &end);
357
- for (size_t i = begin; i < end; i++) {
358
- storage_idx_t nodeId = hnsw.neighbors[i];
359
- if (nodeId < 0)
360
- break;
361
- if (vt.get(nodeId))
362
- continue;
363
- vt.set(nodeId);
364
-
365
- float dis = qdis(nodeId);
366
- NodeDistFarther evE1(dis, nodeId);
367
387
 
368
- if (results.size() < hnsw.efConstruction || results.top().d > dis) {
369
- results.emplace(dis, nodeId);
370
- candidates.emplace(dis, nodeId);
371
- if (results.size() > hnsw.efConstruction) {
372
- results.pop();
388
+ // select a version, based on a flag
389
+ if (reference_version) {
390
+ // a reference version
391
+ for (size_t i = begin; i < end; i++) {
392
+ storage_idx_t nodeId = hnsw.neighbors[i];
393
+ if (nodeId < 0)
394
+ break;
395
+ if (vt.get(nodeId))
396
+ continue;
397
+ vt.set(nodeId);
398
+
399
+ float dis = qdis(nodeId);
400
+ NodeDistFarther evE1(dis, nodeId);
401
+
402
+ if (results.size() < hnsw.efConstruction ||
403
+ results.top().d > dis) {
404
+ results.emplace(dis, nodeId);
405
+ candidates.emplace(dis, nodeId);
406
+ if (results.size() > hnsw.efConstruction) {
407
+ results.pop();
408
+ }
373
409
  }
374
410
  }
375
- }
376
- }
377
- vt.advance();
378
- }
411
+ } else {
412
+ // a faster version
413
+
414
+ // the following version processes 4 neighbors at a time
415
+ auto update_with_candidate = [&](const storage_idx_t idx,
416
+ const float dis) {
417
+ if (results.size() < hnsw.efConstruction ||
418
+ results.top().d > dis) {
419
+ results.emplace(dis, idx);
420
+ candidates.emplace(dis, idx);
421
+ if (results.size() > hnsw.efConstruction) {
422
+ results.pop();
423
+ }
424
+ }
425
+ };
379
426
 
380
- /**************************************************************
381
- * Searching subroutines
382
- **************************************************************/
427
+ int n_buffered = 0;
428
+ storage_idx_t buffered_ids[4];
383
429
 
384
- /// greedily update a nearest vector at a given level
385
- void greedy_update_nearest(
386
- const HNSW& hnsw,
387
- DistanceComputer& qdis,
388
- int level,
389
- storage_idx_t& nearest,
390
- float& d_nearest) {
391
- for (;;) {
392
- storage_idx_t prev_nearest = nearest;
430
+ for (size_t j = begin; j < end; j++) {
431
+ storage_idx_t nodeId = hnsw.neighbors[j];
432
+ if (nodeId < 0)
433
+ break;
434
+ if (vt.get(nodeId)) {
435
+ continue;
436
+ }
437
+ vt.set(nodeId);
438
+
439
+ buffered_ids[n_buffered] = nodeId;
440
+ n_buffered += 1;
441
+
442
+ if (n_buffered == 4) {
443
+ float dis[4];
444
+ qdis.distances_batch_4(
445
+ buffered_ids[0],
446
+ buffered_ids[1],
447
+ buffered_ids[2],
448
+ buffered_ids[3],
449
+ dis[0],
450
+ dis[1],
451
+ dis[2],
452
+ dis[3]);
453
+
454
+ for (size_t id4 = 0; id4 < 4; id4++) {
455
+ update_with_candidate(buffered_ids[id4], dis[id4]);
456
+ }
393
457
 
394
- size_t begin, end;
395
- hnsw.neighbor_range(nearest, level, &begin, &end);
396
- for (size_t i = begin; i < end; i++) {
397
- storage_idx_t v = hnsw.neighbors[i];
398
- if (v < 0)
399
- break;
400
- float dis = qdis(v);
401
- if (dis < d_nearest) {
402
- nearest = v;
403
- d_nearest = dis;
458
+ n_buffered = 0;
459
+ }
460
+ }
461
+
462
+ // process leftovers
463
+ for (size_t icnt = 0; icnt < n_buffered; icnt++) {
464
+ float dis = qdis(buffered_ids[icnt]);
465
+ update_with_candidate(buffered_ids[icnt], dis);
404
466
  }
405
- }
406
- if (nearest == prev_nearest) {
407
- return;
408
467
  }
409
468
  }
469
+
470
+ vt.advance();
410
471
  }
411
472
 
412
473
  } // namespace
@@ -420,7 +481,8 @@ void HNSW::add_links_starting_from(
420
481
  float d_nearest,
421
482
  int level,
422
483
  omp_lock_t* locks,
423
- VisitedTable& vt) {
484
+ VisitedTable& vt,
485
+ bool keep_max_size_level0) {
424
486
  std::priority_queue<NodeDistCloser> link_targets;
425
487
 
426
488
  search_neighbors_to_add(
@@ -429,13 +491,13 @@ void HNSW::add_links_starting_from(
429
491
  // but we can afford only this many neighbors
430
492
  int M = nb_neighbors(level);
431
493
 
432
- ::faiss::shrink_neighbor_list(ptdis, link_targets, M);
494
+ ::faiss::shrink_neighbor_list(ptdis, link_targets, M, keep_max_size_level0);
433
495
 
434
496
  std::vector<storage_idx_t> neighbors;
435
497
  neighbors.reserve(link_targets.size());
436
498
  while (!link_targets.empty()) {
437
499
  storage_idx_t other_id = link_targets.top().id;
438
- add_link(*this, ptdis, pt_id, other_id, level);
500
+ add_link(*this, ptdis, pt_id, other_id, level, keep_max_size_level0);
439
501
  neighbors.push_back(other_id);
440
502
  link_targets.pop();
441
503
  }
@@ -443,7 +505,7 @@ void HNSW::add_links_starting_from(
443
505
  omp_unset_lock(&locks[pt_id]);
444
506
  for (storage_idx_t other_id : neighbors) {
445
507
  omp_set_lock(&locks[other_id]);
446
- add_link(*this, ptdis, other_id, pt_id, level);
508
+ add_link(*this, ptdis, other_id, pt_id, level, keep_max_size_level0);
447
509
  omp_unset_lock(&locks[other_id]);
448
510
  }
449
511
  omp_set_lock(&locks[pt_id]);
@@ -458,7 +520,8 @@ void HNSW::add_with_locks(
458
520
  int pt_level,
459
521
  int pt_id,
460
522
  std::vector<omp_lock_t>& locks,
461
- VisitedTable& vt) {
523
+ VisitedTable& vt,
524
+ bool keep_max_size_level0) {
462
525
  // greedy search on upper levels
463
526
 
464
527
  storage_idx_t nearest;
@@ -487,7 +550,14 @@ void HNSW::add_with_locks(
487
550
 
488
551
  for (; level >= 0; level--) {
489
552
  add_links_starting_from(
490
- ptdis, pt_id, nearest, d_nearest, level, locks.data(), vt);
553
+ ptdis,
554
+ pt_id,
555
+ nearest,
556
+ d_nearest,
557
+ level,
558
+ locks.data(),
559
+ vt,
560
+ keep_max_size_level0);
491
561
  }
492
562
 
493
563
  omp_unset_lock(&locks[pt_id]);
@@ -502,24 +572,20 @@ void HNSW::add_with_locks(
502
572
  * Searching
503
573
  **************************************************************/
504
574
 
505
- namespace {
506
-
507
575
  using MinimaxHeap = HNSW::MinimaxHeap;
508
576
  using Node = HNSW::Node;
577
+ using C = HNSW::C;
509
578
  /** Do a BFS on the candidates list */
510
-
511
579
  int search_from_candidates(
512
580
  const HNSW& hnsw,
513
581
  DistanceComputer& qdis,
514
- int k,
515
- idx_t* I,
516
- float* D,
582
+ ResultHandler<C>& res,
517
583
  MinimaxHeap& candidates,
518
584
  VisitedTable& vt,
519
585
  HNSWStats& stats,
520
586
  int level,
521
- int nres_in = 0,
522
- const SearchParametersHNSW* params = nullptr) {
587
+ int nres_in,
588
+ const SearchParametersHNSW* params) {
523
589
  int nres = nres_in;
524
590
  int ndis = 0;
525
591
 
@@ -529,15 +595,16 @@ int search_from_candidates(
529
595
  int efSearch = params ? params->efSearch : hnsw.efSearch;
530
596
  const IDSelector* sel = params ? params->sel : nullptr;
531
597
 
598
+ C::T threshold = res.threshold;
532
599
  for (int i = 0; i < candidates.size(); i++) {
533
600
  idx_t v1 = candidates.ids[i];
534
601
  float d = candidates.dis[i];
535
602
  FAISS_ASSERT(v1 >= 0);
536
603
  if (!sel || sel->is_member(v1)) {
537
- if (nres < k) {
538
- faiss::maxheap_push(++nres, D, I, d, v1);
539
- } else if (d < D[0]) {
540
- faiss::maxheap_replace_top(nres, D, I, d, v1);
604
+ if (d < threshold) {
605
+ if (res.add_result(d, v1)) {
606
+ threshold = res.threshold;
607
+ }
541
608
  }
542
609
  }
543
610
  vt.set(v1);
@@ -563,24 +630,70 @@ int search_from_candidates(
563
630
  size_t begin, end;
564
631
  hnsw.neighbor_range(v0, level, &begin, &end);
565
632
 
633
+ // a faster version: reference version in unit test test_hnsw.cpp
634
+ // the following version processes 4 neighbors at a time
635
+ size_t jmax = begin;
566
636
  for (size_t j = begin; j < end; j++) {
567
637
  int v1 = hnsw.neighbors[j];
568
638
  if (v1 < 0)
569
639
  break;
570
- if (vt.get(v1)) {
571
- continue;
640
+
641
+ prefetch_L2(vt.visited.data() + v1);
642
+ jmax += 1;
643
+ }
644
+
645
+ int counter = 0;
646
+ size_t saved_j[4];
647
+
648
+ threshold = res.threshold;
649
+
650
+ auto add_to_heap = [&](const size_t idx, const float dis) {
651
+ if (!sel || sel->is_member(idx)) {
652
+ if (dis < threshold) {
653
+ if (res.add_result(dis, idx)) {
654
+ threshold = res.threshold;
655
+ nres += 1;
656
+ }
657
+ }
572
658
  }
659
+ candidates.push(idx, dis);
660
+ };
661
+
662
+ for (size_t j = begin; j < jmax; j++) {
663
+ int v1 = hnsw.neighbors[j];
664
+
665
+ bool vget = vt.get(v1);
573
666
  vt.set(v1);
574
- ndis++;
575
- float d = qdis(v1);
576
- if (!sel || sel->is_member(v1)) {
577
- if (nres < k) {
578
- faiss::maxheap_push(++nres, D, I, d, v1);
579
- } else if (d < D[0]) {
580
- faiss::maxheap_replace_top(nres, D, I, d, v1);
667
+ saved_j[counter] = v1;
668
+ counter += vget ? 0 : 1;
669
+
670
+ if (counter == 4) {
671
+ float dis[4];
672
+ qdis.distances_batch_4(
673
+ saved_j[0],
674
+ saved_j[1],
675
+ saved_j[2],
676
+ saved_j[3],
677
+ dis[0],
678
+ dis[1],
679
+ dis[2],
680
+ dis[3]);
681
+
682
+ for (size_t id4 = 0; id4 < 4; id4++) {
683
+ add_to_heap(saved_j[id4], dis[id4]);
581
684
  }
685
+
686
+ ndis += 4;
687
+
688
+ counter = 0;
582
689
  }
583
- candidates.push(v1, d);
690
+ }
691
+
692
+ for (size_t icnt = 0; icnt < counter; icnt++) {
693
+ float dis = qdis(saved_j[icnt]);
694
+ add_to_heap(saved_j[icnt], dis);
695
+
696
+ ndis += 1;
584
697
  }
585
698
 
586
699
  nstep++;
@@ -594,7 +707,8 @@ int search_from_candidates(
594
707
  if (candidates.size() == 0) {
595
708
  stats.n2++;
596
709
  }
597
- stats.n3 += ndis;
710
+ stats.ndis += ndis;
711
+ stats.nhops += nstep;
598
712
  }
599
713
 
600
714
  return nres;
@@ -630,151 +744,241 @@ std::priority_queue<HNSW::Node> search_from_candidate_unbounded(
630
744
  size_t begin, end;
631
745
  hnsw.neighbor_range(v0, 0, &begin, &end);
632
746
 
633
- for (size_t j = begin; j < end; ++j) {
747
+ // a faster version: reference version in unit test test_hnsw.cpp
748
+ // the following version processes 4 neighbors at a time
749
+ size_t jmax = begin;
750
+ for (size_t j = begin; j < end; j++) {
634
751
  int v1 = hnsw.neighbors[j];
635
-
636
- if (v1 < 0) {
752
+ if (v1 < 0)
637
753
  break;
638
- }
639
- if (vt->get(v1)) {
640
- continue;
641
- }
642
754
 
643
- vt->set(v1);
755
+ prefetch_L2(vt->visited.data() + v1);
756
+ jmax += 1;
757
+ }
644
758
 
645
- float d1 = qdis(v1);
646
- ++ndis;
759
+ int counter = 0;
760
+ size_t saved_j[4];
647
761
 
648
- if (top_candidates.top().first > d1 || top_candidates.size() < ef) {
649
- candidates.emplace(d1, v1);
650
- top_candidates.emplace(d1, v1);
762
+ auto add_to_heap = [&](const size_t idx, const float dis) {
763
+ if (top_candidates.top().first > dis ||
764
+ top_candidates.size() < ef) {
765
+ candidates.emplace(dis, idx);
766
+ top_candidates.emplace(dis, idx);
651
767
 
652
768
  if (top_candidates.size() > ef) {
653
769
  top_candidates.pop();
654
770
  }
655
771
  }
772
+ };
773
+
774
+ for (size_t j = begin; j < jmax; j++) {
775
+ int v1 = hnsw.neighbors[j];
776
+
777
+ bool vget = vt->get(v1);
778
+ vt->set(v1);
779
+ saved_j[counter] = v1;
780
+ counter += vget ? 0 : 1;
781
+
782
+ if (counter == 4) {
783
+ float dis[4];
784
+ qdis.distances_batch_4(
785
+ saved_j[0],
786
+ saved_j[1],
787
+ saved_j[2],
788
+ saved_j[3],
789
+ dis[0],
790
+ dis[1],
791
+ dis[2],
792
+ dis[3]);
793
+
794
+ for (size_t id4 = 0; id4 < 4; id4++) {
795
+ add_to_heap(saved_j[id4], dis[id4]);
796
+ }
797
+
798
+ ndis += 4;
799
+
800
+ counter = 0;
801
+ }
802
+ }
803
+
804
+ for (size_t icnt = 0; icnt < counter; icnt++) {
805
+ float dis = qdis(saved_j[icnt]);
806
+ add_to_heap(saved_j[icnt], dis);
807
+
808
+ ndis += 1;
656
809
  }
810
+
811
+ stats.nhops += 1;
657
812
  }
658
813
 
659
814
  ++stats.n1;
660
815
  if (candidates.size() == 0) {
661
816
  ++stats.n2;
662
817
  }
663
- stats.n3 += ndis;
818
+ stats.ndis += ndis;
664
819
 
665
820
  return top_candidates;
666
821
  }
667
822
 
668
- } // anonymous namespace
669
-
670
- HNSWStats HNSW::search(
823
+ /// greedily update a nearest vector at a given level
824
+ HNSWStats greedy_update_nearest(
825
+ const HNSW& hnsw,
671
826
  DistanceComputer& qdis,
672
- int k,
673
- idx_t* I,
674
- float* D,
675
- VisitedTable& vt,
676
- const SearchParametersHNSW* params) const {
827
+ int level,
828
+ storage_idx_t& nearest,
829
+ float& d_nearest) {
677
830
  HNSWStats stats;
678
- if (entry_point == -1) {
679
- return stats;
680
- }
681
- if (upper_beam == 1) {
682
- // greedy search on upper levels
683
- storage_idx_t nearest = entry_point;
684
- float d_nearest = qdis(nearest);
685
831
 
686
- for (int level = max_level; level >= 1; level--) {
687
- greedy_update_nearest(*this, qdis, level, nearest, d_nearest);
688
- }
832
+ for (;;) {
833
+ storage_idx_t prev_nearest = nearest;
689
834
 
690
- int ef = std::max(efSearch, k);
691
- if (search_bounded_queue) { // this is the most common branch
692
- MinimaxHeap candidates(ef);
835
+ size_t begin, end;
836
+ hnsw.neighbor_range(nearest, level, &begin, &end);
693
837
 
694
- candidates.push(nearest, d_nearest);
838
+ size_t ndis = 0;
695
839
 
696
- search_from_candidates(
697
- *this, qdis, k, I, D, candidates, vt, stats, 0, 0, params);
698
- } else {
699
- std::priority_queue<Node> top_candidates =
700
- search_from_candidate_unbounded(
701
- *this,
702
- Node(d_nearest, nearest),
703
- qdis,
704
- ef,
705
- &vt,
706
- stats);
707
-
708
- while (top_candidates.size() > k) {
709
- top_candidates.pop();
840
+ // a faster version: reference version in unit test test_hnsw.cpp
841
+ // the following version processes 4 neighbors at a time
842
+ auto update_with_candidate = [&](const storage_idx_t idx,
843
+ const float dis) {
844
+ if (dis < d_nearest) {
845
+ nearest = idx;
846
+ d_nearest = dis;
710
847
  }
848
+ };
849
+
850
+ int n_buffered = 0;
851
+ storage_idx_t buffered_ids[4];
852
+
853
+ for (size_t j = begin; j < end; j++) {
854
+ storage_idx_t v = hnsw.neighbors[j];
855
+ if (v < 0)
856
+ break;
857
+ ndis += 1;
858
+
859
+ buffered_ids[n_buffered] = v;
860
+ n_buffered += 1;
861
+
862
+ if (n_buffered == 4) {
863
+ float dis[4];
864
+ qdis.distances_batch_4(
865
+ buffered_ids[0],
866
+ buffered_ids[1],
867
+ buffered_ids[2],
868
+ buffered_ids[3],
869
+ dis[0],
870
+ dis[1],
871
+ dis[2],
872
+ dis[3]);
873
+
874
+ for (size_t id4 = 0; id4 < 4; id4++) {
875
+ update_with_candidate(buffered_ids[id4], dis[id4]);
876
+ }
711
877
 
712
- int nres = 0;
713
- while (!top_candidates.empty()) {
714
- float d;
715
- storage_idx_t label;
716
- std::tie(d, label) = top_candidates.top();
717
- faiss::maxheap_push(++nres, D, I, d, label);
718
- top_candidates.pop();
878
+ n_buffered = 0;
719
879
  }
720
880
  }
721
881
 
722
- vt.advance();
882
+ // process leftovers
883
+ for (size_t icnt = 0; icnt < n_buffered; icnt++) {
884
+ float dis = qdis(buffered_ids[icnt]);
885
+ update_with_candidate(buffered_ids[icnt], dis);
886
+ }
723
887
 
724
- } else {
725
- int candidates_size = upper_beam;
726
- MinimaxHeap candidates(candidates_size);
888
+ // update stats
889
+ stats.ndis += ndis;
890
+ stats.nhops += 1;
727
891
 
728
- std::vector<idx_t> I_to_next(candidates_size);
729
- std::vector<float> D_to_next(candidates_size);
892
+ if (nearest == prev_nearest) {
893
+ return stats;
894
+ }
895
+ }
896
+ }
730
897
 
731
- int nres = 1;
732
- I_to_next[0] = entry_point;
733
- D_to_next[0] = qdis(entry_point);
898
+ namespace {
899
+ using MinimaxHeap = HNSW::MinimaxHeap;
900
+ using Node = HNSW::Node;
901
+ using C = HNSW::C;
734
902
 
735
- for (int level = max_level; level >= 0; level--) {
736
- // copy I, D -> candidates
903
+ // just used as a lower bound for the minmaxheap, but it is set for heap search
904
+ int extract_k_from_ResultHandler(ResultHandler<C>& res) {
905
+ using RH = HeapBlockResultHandler<C>;
906
+ if (auto hres = dynamic_cast<RH::SingleResultHandler*>(&res)) {
907
+ return hres->k;
908
+ }
909
+ return 1;
910
+ }
737
911
 
738
- candidates.clear();
912
+ } // namespace
739
913
 
740
- for (int i = 0; i < nres; i++) {
741
- candidates.push(I_to_next[i], D_to_next[i]);
742
- }
914
+ HNSWStats HNSW::search(
915
+ DistanceComputer& qdis,
916
+ ResultHandler<C>& res,
917
+ VisitedTable& vt,
918
+ const SearchParametersHNSW* params) const {
919
+ HNSWStats stats;
920
+ if (entry_point == -1) {
921
+ return stats;
922
+ }
923
+ int k = extract_k_from_ResultHandler(res);
743
924
 
744
- if (level == 0) {
745
- nres = search_from_candidates(
746
- *this, qdis, k, I, D, candidates, vt, stats, 0);
747
- } else {
748
- nres = search_from_candidates(
749
- *this,
750
- qdis,
751
- candidates_size,
752
- I_to_next.data(),
753
- D_to_next.data(),
754
- candidates,
755
- vt,
756
- stats,
757
- level);
758
- }
759
- vt.advance();
925
+ bool bounded_queue =
926
+ params ? params->bounded_queue : this->search_bounded_queue;
927
+
928
+ // greedy search on upper levels
929
+ storage_idx_t nearest = entry_point;
930
+ float d_nearest = qdis(nearest);
931
+
932
+ for (int level = max_level; level >= 1; level--) {
933
+ HNSWStats local_stats =
934
+ greedy_update_nearest(*this, qdis, level, nearest, d_nearest);
935
+ stats.combine(local_stats);
936
+ }
937
+
938
+ int ef = std::max(params ? params->efSearch : efSearch, k);
939
+ if (bounded_queue) { // this is the most common branch
940
+ MinimaxHeap candidates(ef);
941
+
942
+ candidates.push(nearest, d_nearest);
943
+
944
+ search_from_candidates(
945
+ *this, qdis, res, candidates, vt, stats, 0, 0, params);
946
+ } else {
947
+ std::priority_queue<Node> top_candidates =
948
+ search_from_candidate_unbounded(
949
+ *this, Node(d_nearest, nearest), qdis, ef, &vt, stats);
950
+
951
+ while (top_candidates.size() > k) {
952
+ top_candidates.pop();
953
+ }
954
+
955
+ while (!top_candidates.empty()) {
956
+ float d;
957
+ storage_idx_t label;
958
+ std::tie(d, label) = top_candidates.top();
959
+ res.add_result(d, label);
960
+ top_candidates.pop();
760
961
  }
761
962
  }
762
963
 
964
+ vt.advance();
965
+
763
966
  return stats;
764
967
  }
765
968
 
766
969
  void HNSW::search_level_0(
767
970
  DistanceComputer& qdis,
768
- int k,
769
- idx_t* idxi,
770
- float* simi,
971
+ ResultHandler<C>& res,
771
972
  idx_t nprobe,
772
973
  const storage_idx_t* nearest_i,
773
974
  const float* nearest_d,
774
975
  int search_type,
775
976
  HNSWStats& search_stats,
776
- VisitedTable& vt) const {
977
+ VisitedTable& vt,
978
+ const SearchParametersHNSW* params) const {
777
979
  const HNSW& hnsw = *this;
980
+ auto efSearch = params ? params->efSearch : hnsw.efSearch;
981
+ int k = extract_k_from_ResultHandler(res);
778
982
 
779
983
  if (search_type == 1) {
780
984
  int nres = 0;
@@ -788,7 +992,7 @@ void HNSW::search_level_0(
788
992
  if (vt.get(cj))
789
993
  continue;
790
994
 
791
- int candidates_size = std::max(hnsw.efSearch, int(k));
995
+ int candidates_size = std::max(efSearch, k);
792
996
  MinimaxHeap candidates(candidates_size);
793
997
 
794
998
  candidates.push(cj, nearest_d[j]);
@@ -796,17 +1000,17 @@ void HNSW::search_level_0(
796
1000
  nres = search_from_candidates(
797
1001
  hnsw,
798
1002
  qdis,
799
- k,
800
- idxi,
801
- simi,
1003
+ res,
802
1004
  candidates,
803
1005
  vt,
804
1006
  search_stats,
805
1007
  0,
806
- nres);
1008
+ nres,
1009
+ params);
1010
+ nres = std::min(nres, candidates_size);
807
1011
  }
808
1012
  } else if (search_type == 2) {
809
- int candidates_size = std::max(hnsw.efSearch, int(k));
1013
+ int candidates_size = std::max(efSearch, int(k));
810
1014
  candidates_size = std::max(candidates_size, int(nprobe));
811
1015
 
812
1016
  MinimaxHeap candidates(candidates_size);
@@ -819,10 +1023,43 @@ void HNSW::search_level_0(
819
1023
  }
820
1024
 
821
1025
  search_from_candidates(
822
- hnsw, qdis, k, idxi, simi, candidates, vt, search_stats, 0);
1026
+ hnsw, qdis, res, candidates, vt, search_stats, 0, 0, params);
823
1027
  }
824
1028
  }
825
1029
 
1030
+ void HNSW::permute_entries(const idx_t* map) {
1031
+ // remap levels
1032
+ storage_idx_t ntotal = levels.size();
1033
+ std::vector<storage_idx_t> imap(ntotal); // inverse mapping
1034
+ // map: new index -> old index
1035
+ // imap: old index -> new index
1036
+ for (int i = 0; i < ntotal; i++) {
1037
+ assert(map[i] >= 0 && map[i] < ntotal);
1038
+ imap[map[i]] = i;
1039
+ }
1040
+ if (entry_point != -1) {
1041
+ entry_point = imap[entry_point];
1042
+ }
1043
+ std::vector<int> new_levels(ntotal);
1044
+ std::vector<size_t> new_offsets(ntotal + 1);
1045
+ std::vector<storage_idx_t> new_neighbors(neighbors.size());
1046
+ size_t no = 0;
1047
+ for (int i = 0; i < ntotal; i++) {
1048
+ storage_idx_t o = map[i]; // corresponding "old" index
1049
+ new_levels[i] = levels[o];
1050
+ for (size_t j = offsets[o]; j < offsets[o + 1]; j++) {
1051
+ storage_idx_t neigh = neighbors[j];
1052
+ new_neighbors[no++] = neigh >= 0 ? imap[neigh] : neigh;
1053
+ }
1054
+ new_offsets[i + 1] = no;
1055
+ }
1056
+ assert(new_offsets[ntotal] == offsets[ntotal]);
1057
+ // swap everyone
1058
+ std::swap(levels, new_levels);
1059
+ std::swap(offsets, new_offsets);
1060
+ std::swap(neighbors, new_neighbors);
1061
+ }
1062
+
826
1063
  /**************************************************************
827
1064
  * MinimaxHeap
828
1065
  **************************************************************/
@@ -852,17 +1089,197 @@ void HNSW::MinimaxHeap::clear() {
852
1089
  nvalid = k = 0;
853
1090
  }
854
1091
 
1092
+ #ifdef __AVX512F__
1093
+
1094
+ int HNSW::MinimaxHeap::pop_min(float* vmin_out) {
1095
+ assert(k > 0);
1096
+ static_assert(
1097
+ std::is_same<storage_idx_t, int32_t>::value,
1098
+ "This code expects storage_idx_t to be int32_t");
1099
+
1100
+ int32_t min_idx = -1;
1101
+ float min_dis = std::numeric_limits<float>::infinity();
1102
+
1103
+ __m512i min_indices = _mm512_set1_epi32(-1);
1104
+ __m512 min_distances =
1105
+ _mm512_set1_ps(std::numeric_limits<float>::infinity());
1106
+ __m512i current_indices = _mm512_setr_epi32(
1107
+ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15);
1108
+ __m512i offset = _mm512_set1_epi32(16);
1109
+
1110
+ // The following loop tracks the rightmost index with the min distance.
1111
+ // -1 index values are ignored.
1112
+ const int k16 = (k / 16) * 16;
1113
+ for (size_t iii = 0; iii < k16; iii += 16) {
1114
+ __m512i indices =
1115
+ _mm512_loadu_si512((const __m512i*)(ids.data() + iii));
1116
+ __m512 distances = _mm512_loadu_ps(dis.data() + iii);
1117
+
1118
+ // This mask filters out -1 values among indices.
1119
+ __mmask16 m1mask =
1120
+ _mm512_cmpgt_epi32_mask(_mm512_setzero_si512(), indices);
1121
+
1122
+ __mmask16 dmask =
1123
+ _mm512_cmp_ps_mask(min_distances, distances, _CMP_LT_OS);
1124
+ __mmask16 finalmask = m1mask | dmask;
1125
+
1126
+ const __m512i min_indices_new = _mm512_mask_blend_epi32(
1127
+ finalmask, current_indices, min_indices);
1128
+ const __m512 min_distances_new =
1129
+ _mm512_mask_blend_ps(finalmask, distances, min_distances);
1130
+
1131
+ min_indices = min_indices_new;
1132
+ min_distances = min_distances_new;
1133
+
1134
+ current_indices = _mm512_add_epi32(current_indices, offset);
1135
+ }
1136
+
1137
+ // leftovers
1138
+ if (k16 != k) {
1139
+ const __mmask16 kmask = (1 << (k - k16)) - 1;
1140
+
1141
+ __m512i indices = _mm512_mask_loadu_epi32(
1142
+ _mm512_set1_epi32(-1), kmask, ids.data() + k16);
1143
+ __m512 distances = _mm512_maskz_loadu_ps(kmask, dis.data() + k16);
1144
+
1145
+ // This mask filters out -1 values among indices.
1146
+ __mmask16 m1mask =
1147
+ _mm512_cmpgt_epi32_mask(_mm512_setzero_si512(), indices);
1148
+
1149
+ __mmask16 dmask =
1150
+ _mm512_cmp_ps_mask(min_distances, distances, _CMP_LT_OS);
1151
+ __mmask16 finalmask = m1mask | dmask;
1152
+
1153
+ const __m512i min_indices_new = _mm512_mask_blend_epi32(
1154
+ finalmask, current_indices, min_indices);
1155
+ const __m512 min_distances_new =
1156
+ _mm512_mask_blend_ps(finalmask, distances, min_distances);
1157
+
1158
+ min_indices = min_indices_new;
1159
+ min_distances = min_distances_new;
1160
+ }
1161
+
1162
+ // grab min distance
1163
+ min_dis = _mm512_reduce_min_ps(min_distances);
1164
+ // blend
1165
+ __mmask16 mindmask =
1166
+ _mm512_cmpeq_ps_mask(min_distances, _mm512_set1_ps(min_dis));
1167
+ // pick the max one
1168
+ min_idx = _mm512_mask_reduce_max_epi32(mindmask, min_indices);
1169
+
1170
+ if (min_idx == -1) {
1171
+ return -1;
1172
+ }
1173
+
1174
+ if (vmin_out) {
1175
+ *vmin_out = min_dis;
1176
+ }
1177
+ int ret = ids[min_idx];
1178
+ ids[min_idx] = -1;
1179
+ --nvalid;
1180
+ return ret;
1181
+ }
1182
+
1183
+ #elif __AVX2__
1184
+
1185
+ int HNSW::MinimaxHeap::pop_min(float* vmin_out) {
1186
+ assert(k > 0);
1187
+ static_assert(
1188
+ std::is_same<storage_idx_t, int32_t>::value,
1189
+ "This code expects storage_idx_t to be int32_t");
1190
+
1191
+ int32_t min_idx = -1;
1192
+ float min_dis = std::numeric_limits<float>::infinity();
1193
+
1194
+ size_t iii = 0;
1195
+
1196
+ __m256i min_indices = _mm256_setr_epi32(-1, -1, -1, -1, -1, -1, -1, -1);
1197
+ __m256 min_distances =
1198
+ _mm256_set1_ps(std::numeric_limits<float>::infinity());
1199
+ __m256i current_indices = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
1200
+ __m256i offset = _mm256_set1_epi32(8);
1201
+
1202
+ // The baseline version is available in non-AVX2 branch.
1203
+
1204
+ // The following loop tracks the rightmost index with the min distance.
1205
+ // -1 index values are ignored.
1206
+ const int k8 = (k / 8) * 8;
1207
+ for (; iii < k8; iii += 8) {
1208
+ __m256i indices =
1209
+ _mm256_loadu_si256((const __m256i*)(ids.data() + iii));
1210
+ __m256 distances = _mm256_loadu_ps(dis.data() + iii);
1211
+
1212
+ // This mask filters out -1 values among indices.
1213
+ __m256i m1mask = _mm256_cmpgt_epi32(_mm256_setzero_si256(), indices);
1214
+
1215
+ __m256i dmask = _mm256_castps_si256(
1216
+ _mm256_cmp_ps(min_distances, distances, _CMP_LT_OS));
1217
+ __m256 finalmask = _mm256_castsi256_ps(_mm256_or_si256(m1mask, dmask));
1218
+
1219
+ const __m256i min_indices_new = _mm256_castps_si256(_mm256_blendv_ps(
1220
+ _mm256_castsi256_ps(current_indices),
1221
+ _mm256_castsi256_ps(min_indices),
1222
+ finalmask));
1223
+
1224
+ const __m256 min_distances_new =
1225
+ _mm256_blendv_ps(distances, min_distances, finalmask);
1226
+
1227
+ min_indices = min_indices_new;
1228
+ min_distances = min_distances_new;
1229
+
1230
+ current_indices = _mm256_add_epi32(current_indices, offset);
1231
+ }
1232
+
1233
+ // Vectorizing is doable, but is not practical
1234
+ int32_t vidx8[8];
1235
+ float vdis8[8];
1236
+ _mm256_storeu_ps(vdis8, min_distances);
1237
+ _mm256_storeu_si256((__m256i*)vidx8, min_indices);
1238
+
1239
+ for (size_t j = 0; j < 8; j++) {
1240
+ if (min_dis > vdis8[j] || (min_dis == vdis8[j] && min_idx < vidx8[j])) {
1241
+ min_idx = vidx8[j];
1242
+ min_dis = vdis8[j];
1243
+ }
1244
+ }
1245
+
1246
+ // process last values. Vectorizing is doable, but is not practical
1247
+ for (; iii < k; iii++) {
1248
+ if (ids[iii] != -1 && dis[iii] <= min_dis) {
1249
+ min_dis = dis[iii];
1250
+ min_idx = iii;
1251
+ }
1252
+ }
1253
+
1254
+ if (min_idx == -1) {
1255
+ return -1;
1256
+ }
1257
+
1258
+ if (vmin_out) {
1259
+ *vmin_out = min_dis;
1260
+ }
1261
+ int ret = ids[min_idx];
1262
+ ids[min_idx] = -1;
1263
+ --nvalid;
1264
+ return ret;
1265
+ }
1266
+
1267
+ #else
1268
+
1269
+ // baseline non-vectorized version
855
1270
  int HNSW::MinimaxHeap::pop_min(float* vmin_out) {
856
1271
  assert(k > 0);
857
1272
  // returns min. This is an O(n) operation
858
1273
  int i = k - 1;
859
1274
  while (i >= 0) {
860
- if (ids[i] != -1)
1275
+ if (ids[i] != -1) {
861
1276
  break;
1277
+ }
862
1278
  i--;
863
1279
  }
864
- if (i == -1)
1280
+ if (i == -1) {
865
1281
  return -1;
1282
+ }
866
1283
  int imin = i;
867
1284
  float vmin = dis[i];
868
1285
  i--;
@@ -873,14 +1290,16 @@ int HNSW::MinimaxHeap::pop_min(float* vmin_out) {
873
1290
  }
874
1291
  i--;
875
1292
  }
876
- if (vmin_out)
1293
+ if (vmin_out) {
877
1294
  *vmin_out = vmin;
1295
+ }
878
1296
  int ret = ids[imin];
879
1297
  ids[imin] = -1;
880
1298
  --nvalid;
881
1299
 
882
1300
  return ret;
883
1301
  }
1302
+ #endif
884
1303
 
885
1304
  int HNSW::MinimaxHeap::count_below(float thresh) {
886
1305
  int n_below = 0;