faiss 0.1.7 → 0.2.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (219)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +18 -0
  3. data/README.md +7 -7
  4. data/ext/faiss/ext.cpp +1 -1
  5. data/ext/faiss/extconf.rb +8 -2
  6. data/ext/faiss/index.cpp +102 -69
  7. data/ext/faiss/index_binary.cpp +24 -30
  8. data/ext/faiss/kmeans.cpp +20 -16
  9. data/ext/faiss/numo.hpp +867 -0
  10. data/ext/faiss/pca_matrix.cpp +13 -14
  11. data/ext/faiss/product_quantizer.cpp +23 -24
  12. data/ext/faiss/utils.cpp +10 -37
  13. data/ext/faiss/utils.h +2 -13
  14. data/lib/faiss/version.rb +1 -1
  15. data/lib/faiss.rb +0 -5
  16. data/vendor/faiss/faiss/AutoTune.cpp +292 -291
  17. data/vendor/faiss/faiss/AutoTune.h +55 -56
  18. data/vendor/faiss/faiss/Clustering.cpp +334 -195
  19. data/vendor/faiss/faiss/Clustering.h +88 -35
  20. data/vendor/faiss/faiss/IVFlib.cpp +171 -195
  21. data/vendor/faiss/faiss/IVFlib.h +48 -51
  22. data/vendor/faiss/faiss/Index.cpp +85 -103
  23. data/vendor/faiss/faiss/Index.h +54 -48
  24. data/vendor/faiss/faiss/Index2Layer.cpp +139 -164
  25. data/vendor/faiss/faiss/Index2Layer.h +22 -22
  26. data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
  27. data/vendor/faiss/faiss/IndexBinary.h +140 -132
  28. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
  29. data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
  30. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
  31. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
  32. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
  33. data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
  34. data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
  35. data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
  36. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
  37. data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
  38. data/vendor/faiss/faiss/IndexFlat.cpp +116 -147
  39. data/vendor/faiss/faiss/IndexFlat.h +35 -46
  40. data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
  41. data/vendor/faiss/faiss/IndexHNSW.h +57 -41
  42. data/vendor/faiss/faiss/IndexIVF.cpp +474 -454
  43. data/vendor/faiss/faiss/IndexIVF.h +146 -113
  44. data/vendor/faiss/faiss/IndexIVFFlat.cpp +248 -250
  45. data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
  46. data/vendor/faiss/faiss/IndexIVFPQ.cpp +457 -516
  47. data/vendor/faiss/faiss/IndexIVFPQ.h +74 -66
  48. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
  49. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
  50. data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
  51. data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
  52. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +125 -133
  53. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +19 -21
  54. data/vendor/faiss/faiss/IndexLSH.cpp +75 -96
  55. data/vendor/faiss/faiss/IndexLSH.h +21 -26
  56. data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
  57. data/vendor/faiss/faiss/IndexLattice.h +11 -16
  58. data/vendor/faiss/faiss/IndexNNDescent.cpp +231 -0
  59. data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
  60. data/vendor/faiss/faiss/IndexNSG.cpp +303 -0
  61. data/vendor/faiss/faiss/IndexNSG.h +85 -0
  62. data/vendor/faiss/faiss/IndexPQ.cpp +405 -464
  63. data/vendor/faiss/faiss/IndexPQ.h +64 -67
  64. data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
  65. data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
  66. data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
  67. data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
  68. data/vendor/faiss/faiss/IndexRefine.cpp +115 -131
  69. data/vendor/faiss/faiss/IndexRefine.h +22 -23
  70. data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
  71. data/vendor/faiss/faiss/IndexReplicas.h +62 -56
  72. data/vendor/faiss/faiss/IndexResidual.cpp +291 -0
  73. data/vendor/faiss/faiss/IndexResidual.h +152 -0
  74. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +120 -155
  75. data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -45
  76. data/vendor/faiss/faiss/IndexShards.cpp +256 -240
  77. data/vendor/faiss/faiss/IndexShards.h +85 -73
  78. data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
  79. data/vendor/faiss/faiss/MatrixStats.h +7 -10
  80. data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
  81. data/vendor/faiss/faiss/MetaIndexes.h +40 -34
  82. data/vendor/faiss/faiss/MetricType.h +7 -7
  83. data/vendor/faiss/faiss/VectorTransform.cpp +652 -474
  84. data/vendor/faiss/faiss/VectorTransform.h +61 -89
  85. data/vendor/faiss/faiss/clone_index.cpp +77 -73
  86. data/vendor/faiss/faiss/clone_index.h +4 -9
  87. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
  88. data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
  89. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +197 -170
  90. data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
  91. data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
  92. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
  93. data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
  94. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
  95. data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
  96. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
  97. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
  98. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
  99. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
  100. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
  101. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
  102. data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
  103. data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
  104. data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
  105. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
  106. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
  107. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
  108. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
  109. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
  110. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
  111. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
  112. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
  113. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
  114. data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
  115. data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
  116. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
  117. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
  118. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
  119. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
  120. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
  121. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
  122. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
  123. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
  124. data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
  125. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
  126. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
  127. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
  128. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
  129. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
  130. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
  131. data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
  132. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +270 -0
  133. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +115 -0
  134. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
  135. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
  136. data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
  137. data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
  138. data/vendor/faiss/faiss/impl/FaissException.h +41 -29
  139. data/vendor/faiss/faiss/impl/HNSW.cpp +595 -611
  140. data/vendor/faiss/faiss/impl/HNSW.h +179 -200
  141. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +672 -0
  142. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +172 -0
  143. data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
  144. data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
  145. data/vendor/faiss/faiss/impl/NSG.cpp +682 -0
  146. data/vendor/faiss/faiss/impl/NSG.h +199 -0
  147. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
  148. data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
  149. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
  150. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
  151. data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
  152. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +448 -0
  153. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +130 -0
  154. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
  155. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +648 -701
  156. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
  157. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
  158. data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
  159. data/vendor/faiss/faiss/impl/index_read.cpp +547 -479
  160. data/vendor/faiss/faiss/impl/index_write.cpp +497 -407
  161. data/vendor/faiss/faiss/impl/io.cpp +75 -94
  162. data/vendor/faiss/faiss/impl/io.h +31 -41
  163. data/vendor/faiss/faiss/impl/io_macros.h +40 -29
  164. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
  165. data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
  166. data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
  167. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
  168. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
  169. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
  170. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
  171. data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
  172. data/vendor/faiss/faiss/index_factory.cpp +269 -218
  173. data/vendor/faiss/faiss/index_factory.h +6 -7
  174. data/vendor/faiss/faiss/index_io.h +23 -26
  175. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
  176. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
  177. data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
  178. data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
  179. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
  180. data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
  181. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
  182. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
  183. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
  184. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
  185. data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
  186. data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
  187. data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
  188. data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
  189. data/vendor/faiss/faiss/utils/Heap.h +186 -209
  190. data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
  191. data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
  192. data/vendor/faiss/faiss/utils/distances.cpp +301 -310
  193. data/vendor/faiss/faiss/utils/distances.h +133 -118
  194. data/vendor/faiss/faiss/utils/distances_simd.cpp +456 -516
  195. data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
  196. data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
  197. data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
  198. data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
  199. data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
  200. data/vendor/faiss/faiss/utils/hamming.h +62 -85
  201. data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
  202. data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
  203. data/vendor/faiss/faiss/utils/partitioning.h +26 -21
  204. data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
  205. data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
  206. data/vendor/faiss/faiss/utils/random.cpp +39 -63
  207. data/vendor/faiss/faiss/utils/random.h +13 -16
  208. data/vendor/faiss/faiss/utils/simdlib.h +4 -2
  209. data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
  210. data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
  211. data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
  212. data/vendor/faiss/faiss/utils/utils.cpp +304 -287
  213. data/vendor/faiss/faiss/utils/utils.h +53 -48
  214. metadata +26 -12
  215. data/lib/faiss/index.rb +0 -20
  216. data/lib/faiss/index_binary.rb +0 -20
  217. data/lib/faiss/kmeans.rb +0 -15
  218. data/lib/faiss/pca_matrix.rb +0 -15
  219. data/lib/faiss/product_quantizer.rb +0 -22
@@ -15,275 +15,254 @@
15
15
 
16
16
  namespace faiss {
17
17
 
18
-
19
18
  /**************************************************************
20
19
  * HNSW structure implementation
21
20
  **************************************************************/
22
21
 
23
- int HNSW::nb_neighbors(int layer_no) const
24
- {
25
- return cum_nneighbor_per_level[layer_no + 1] -
26
- cum_nneighbor_per_level[layer_no];
22
+ int HNSW::nb_neighbors(int layer_no) const {
23
+ return cum_nneighbor_per_level[layer_no + 1] -
24
+ cum_nneighbor_per_level[layer_no];
27
25
  }
28
26
 
29
- void HNSW::set_nb_neighbors(int level_no, int n)
30
- {
31
- FAISS_THROW_IF_NOT(levels.size() == 0);
32
- int cur_n = nb_neighbors(level_no);
33
- for (int i = level_no + 1; i < cum_nneighbor_per_level.size(); i++) {
34
- cum_nneighbor_per_level[i] += n - cur_n;
35
- }
27
+ void HNSW::set_nb_neighbors(int level_no, int n) {
28
+ FAISS_THROW_IF_NOT(levels.size() == 0);
29
+ int cur_n = nb_neighbors(level_no);
30
+ for (int i = level_no + 1; i < cum_nneighbor_per_level.size(); i++) {
31
+ cum_nneighbor_per_level[i] += n - cur_n;
32
+ }
36
33
  }
37
34
 
38
- int HNSW::cum_nb_neighbors(int layer_no) const
39
- {
40
- return cum_nneighbor_per_level[layer_no];
35
+ int HNSW::cum_nb_neighbors(int layer_no) const {
36
+ return cum_nneighbor_per_level[layer_no];
41
37
  }
42
38
 
43
- void HNSW::neighbor_range(idx_t no, int layer_no,
44
- size_t * begin, size_t * end) const
45
- {
46
- size_t o = offsets[no];
47
- *begin = o + cum_nb_neighbors(layer_no);
48
- *end = o + cum_nb_neighbors(layer_no + 1);
39
+ void HNSW::neighbor_range(idx_t no, int layer_no, size_t* begin, size_t* end)
40
+ const {
41
+ size_t o = offsets[no];
42
+ *begin = o + cum_nb_neighbors(layer_no);
43
+ *end = o + cum_nb_neighbors(layer_no + 1);
49
44
  }
50
45
 
51
-
52
-
53
46
  HNSW::HNSW(int M) : rng(12345) {
54
- set_default_probas(M, 1.0 / log(M));
55
- max_level = -1;
56
- entry_point = -1;
57
- efSearch = 16;
58
- efConstruction = 40;
59
- upper_beam = 1;
60
- offsets.push_back(0);
47
+ set_default_probas(M, 1.0 / log(M));
48
+ max_level = -1;
49
+ entry_point = -1;
50
+ efSearch = 16;
51
+ efConstruction = 40;
52
+ upper_beam = 1;
53
+ offsets.push_back(0);
61
54
  }
62
55
 
63
-
64
- int HNSW::random_level()
65
- {
66
- double f = rng.rand_float();
67
- // could be a bit faster with bissection
68
- for (int level = 0; level < assign_probas.size(); level++) {
69
- if (f < assign_probas[level]) {
70
- return level;
56
+ int HNSW::random_level() {
57
+ double f = rng.rand_float();
58
+ // could be a bit faster with bissection
59
+ for (int level = 0; level < assign_probas.size(); level++) {
60
+ if (f < assign_probas[level]) {
61
+ return level;
62
+ }
63
+ f -= assign_probas[level];
71
64
  }
72
- f -= assign_probas[level];
73
- }
74
- // happens with exponentially low probability
75
- return assign_probas.size() - 1;
65
+ // happens with exponentially low probability
66
+ return assign_probas.size() - 1;
76
67
  }
77
68
 
78
- void HNSW::set_default_probas(int M, float levelMult)
79
- {
80
- int nn = 0;
81
- cum_nneighbor_per_level.push_back (0);
82
- for (int level = 0; ;level++) {
83
- float proba = exp(-level / levelMult) * (1 - exp(-1 / levelMult));
84
- if (proba < 1e-9) break;
85
- assign_probas.push_back(proba);
86
- nn += level == 0 ? M * 2 : M;
87
- cum_nneighbor_per_level.push_back (nn);
88
- }
69
+ void HNSW::set_default_probas(int M, float levelMult) {
70
+ int nn = 0;
71
+ cum_nneighbor_per_level.push_back(0);
72
+ for (int level = 0;; level++) {
73
+ float proba = exp(-level / levelMult) * (1 - exp(-1 / levelMult));
74
+ if (proba < 1e-9)
75
+ break;
76
+ assign_probas.push_back(proba);
77
+ nn += level == 0 ? M * 2 : M;
78
+ cum_nneighbor_per_level.push_back(nn);
79
+ }
89
80
  }
90
81
 
91
- void HNSW::clear_neighbor_tables(int level)
92
- {
93
- for (int i = 0; i < levels.size(); i++) {
94
- size_t begin, end;
95
- neighbor_range(i, level, &begin, &end);
96
- for (size_t j = begin; j < end; j++) {
97
- neighbors[j] = -1;
82
+ void HNSW::clear_neighbor_tables(int level) {
83
+ for (int i = 0; i < levels.size(); i++) {
84
+ size_t begin, end;
85
+ neighbor_range(i, level, &begin, &end);
86
+ for (size_t j = begin; j < end; j++) {
87
+ neighbors[j] = -1;
88
+ }
98
89
  }
99
- }
100
90
  }
101
91
 
102
-
103
92
  void HNSW::reset() {
104
- max_level = -1;
105
- entry_point = -1;
106
- offsets.clear();
107
- offsets.push_back(0);
108
- levels.clear();
109
- neighbors.clear();
93
+ max_level = -1;
94
+ entry_point = -1;
95
+ offsets.clear();
96
+ offsets.push_back(0);
97
+ levels.clear();
98
+ neighbors.clear();
110
99
  }
111
100
 
112
-
113
-
114
- void HNSW::print_neighbor_stats(int level) const
115
- {
116
- FAISS_THROW_IF_NOT (level < cum_nneighbor_per_level.size());
117
- printf("stats on level %d, max %d neighbors per vertex:\n",
118
- level, nb_neighbors(level));
119
- size_t tot_neigh = 0, tot_common = 0, tot_reciprocal = 0, n_node = 0;
101
+ void HNSW::print_neighbor_stats(int level) const {
102
+ FAISS_THROW_IF_NOT(level < cum_nneighbor_per_level.size());
103
+ printf("stats on level %d, max %d neighbors per vertex:\n",
104
+ level,
105
+ nb_neighbors(level));
106
+ size_t tot_neigh = 0, tot_common = 0, tot_reciprocal = 0, n_node = 0;
120
107
  #pragma omp parallel for reduction(+: tot_neigh) reduction(+: tot_common) \
121
108
  reduction(+: tot_reciprocal) reduction(+: n_node)
122
- for (int i = 0; i < levels.size(); i++) {
123
- if (levels[i] > level) {
124
- n_node++;
125
- size_t begin, end;
126
- neighbor_range(i, level, &begin, &end);
127
- std::unordered_set<int> neighset;
128
- for (size_t j = begin; j < end; j++) {
129
- if (neighbors [j] < 0) break;
130
- neighset.insert(neighbors[j]);
131
- }
132
- int n_neigh = neighset.size();
133
- int n_common = 0;
134
- int n_reciprocal = 0;
135
- for (size_t j = begin; j < end; j++) {
136
- storage_idx_t i2 = neighbors[j];
137
- if (i2 < 0) break;
138
- FAISS_ASSERT(i2 != i);
139
- size_t begin2, end2;
140
- neighbor_range(i2, level, &begin2, &end2);
141
- for (size_t j2 = begin2; j2 < end2; j2++) {
142
- storage_idx_t i3 = neighbors[j2];
143
- if (i3 < 0) break;
144
- if (i3 == i) {
145
- n_reciprocal++;
146
- continue;
147
- }
148
- if (neighset.count(i3)) {
149
- neighset.erase(i3);
150
- n_common++;
151
- }
109
+ for (int i = 0; i < levels.size(); i++) {
110
+ if (levels[i] > level) {
111
+ n_node++;
112
+ size_t begin, end;
113
+ neighbor_range(i, level, &begin, &end);
114
+ std::unordered_set<int> neighset;
115
+ for (size_t j = begin; j < end; j++) {
116
+ if (neighbors[j] < 0)
117
+ break;
118
+ neighset.insert(neighbors[j]);
119
+ }
120
+ int n_neigh = neighset.size();
121
+ int n_common = 0;
122
+ int n_reciprocal = 0;
123
+ for (size_t j = begin; j < end; j++) {
124
+ storage_idx_t i2 = neighbors[j];
125
+ if (i2 < 0)
126
+ break;
127
+ FAISS_ASSERT(i2 != i);
128
+ size_t begin2, end2;
129
+ neighbor_range(i2, level, &begin2, &end2);
130
+ for (size_t j2 = begin2; j2 < end2; j2++) {
131
+ storage_idx_t i3 = neighbors[j2];
132
+ if (i3 < 0)
133
+ break;
134
+ if (i3 == i) {
135
+ n_reciprocal++;
136
+ continue;
137
+ }
138
+ if (neighset.count(i3)) {
139
+ neighset.erase(i3);
140
+ n_common++;
141
+ }
142
+ }
143
+ }
144
+ tot_neigh += n_neigh;
145
+ tot_common += n_common;
146
+ tot_reciprocal += n_reciprocal;
152
147
  }
153
- }
154
- tot_neigh += n_neigh;
155
- tot_common += n_common;
156
- tot_reciprocal += n_reciprocal;
157
148
  }
158
- }
159
- float normalizer = n_node;
160
- printf(" nb of nodes at that level %zd\n", n_node);
161
- printf(" neighbors per node: %.2f (%zd)\n",
162
- tot_neigh / normalizer, tot_neigh);
163
- printf(" nb of reciprocal neighbors: %.2f\n", tot_reciprocal / normalizer);
164
- printf(" nb of neighbors that are also neighbor-of-neighbors: %.2f (%zd)\n",
165
- tot_common / normalizer, tot_common);
166
-
167
-
168
-
149
+ float normalizer = n_node;
150
+ printf(" nb of nodes at that level %zd\n", n_node);
151
+ printf(" neighbors per node: %.2f (%zd)\n",
152
+ tot_neigh / normalizer,
153
+ tot_neigh);
154
+ printf(" nb of reciprocal neighbors: %.2f\n",
155
+ tot_reciprocal / normalizer);
156
+ printf(" nb of neighbors that are also neighbor-of-neighbors: %.2f (%zd)\n",
157
+ tot_common / normalizer,
158
+ tot_common);
169
159
  }
170
160
 
161
+ void HNSW::fill_with_random_links(size_t n) {
162
+ int max_level = prepare_level_tab(n);
163
+ RandomGenerator rng2(456);
171
164
 
172
- void HNSW::fill_with_random_links(size_t n)
173
- {
174
- int max_level = prepare_level_tab(n);
175
- RandomGenerator rng2(456);
176
-
177
- for (int level = max_level - 1; level >= 0; --level) {
178
- std::vector<int> elts;
179
- for (int i = 0; i < n; i++) {
180
- if (levels[i] > level) {
181
- elts.push_back(i);
182
- }
183
- }
184
- printf ("linking %zd elements in level %d\n",
185
- elts.size(), level);
186
-
187
- if (elts.size() == 1) continue;
165
+ for (int level = max_level - 1; level >= 0; --level) {
166
+ std::vector<int> elts;
167
+ for (int i = 0; i < n; i++) {
168
+ if (levels[i] > level) {
169
+ elts.push_back(i);
170
+ }
171
+ }
172
+ printf("linking %zd elements in level %d\n", elts.size(), level);
188
173
 
189
- for (int ii = 0; ii < elts.size(); ii++) {
190
- int i = elts[ii];
191
- size_t begin, end;
192
- neighbor_range(i, 0, &begin, &end);
193
- for (size_t j = begin; j < end; j++) {
194
- int other = 0;
195
- do {
196
- other = elts[rng2.rand_int(elts.size())];
197
- } while(other == i);
174
+ if (elts.size() == 1)
175
+ continue;
198
176
 
199
- neighbors[j] = other;
200
- }
177
+ for (int ii = 0; ii < elts.size(); ii++) {
178
+ int i = elts[ii];
179
+ size_t begin, end;
180
+ neighbor_range(i, 0, &begin, &end);
181
+ for (size_t j = begin; j < end; j++) {
182
+ int other = 0;
183
+ do {
184
+ other = elts[rng2.rand_int(elts.size())];
185
+ } while (other == i);
186
+
187
+ neighbors[j] = other;
188
+ }
189
+ }
201
190
  }
202
- }
203
191
  }
204
192
 
193
+ int HNSW::prepare_level_tab(size_t n, bool preset_levels) {
194
+ size_t n0 = offsets.size() - 1;
205
195
 
206
- int HNSW::prepare_level_tab(size_t n, bool preset_levels)
207
- {
208
- size_t n0 = offsets.size() - 1;
196
+ if (preset_levels) {
197
+ FAISS_ASSERT(n0 + n == levels.size());
198
+ } else {
199
+ FAISS_ASSERT(n0 == levels.size());
200
+ for (int i = 0; i < n; i++) {
201
+ int pt_level = random_level();
202
+ levels.push_back(pt_level + 1);
203
+ }
204
+ }
209
205
 
210
- if (preset_levels) {
211
- FAISS_ASSERT (n0 + n == levels.size());
212
- } else {
213
- FAISS_ASSERT (n0 == levels.size());
206
+ int max_level = 0;
214
207
  for (int i = 0; i < n; i++) {
215
- int pt_level = random_level();
216
- levels.push_back(pt_level + 1);
208
+ int pt_level = levels[i + n0] - 1;
209
+ if (pt_level > max_level)
210
+ max_level = pt_level;
211
+ offsets.push_back(offsets.back() + cum_nb_neighbors(pt_level + 1));
212
+ neighbors.resize(offsets.back(), -1);
217
213
  }
218
- }
219
214
 
220
- int max_level = 0;
221
- for (int i = 0; i < n; i++) {
222
- int pt_level = levels[i + n0] - 1;
223
- if (pt_level > max_level) max_level = pt_level;
224
- offsets.push_back(offsets.back() +
225
- cum_nb_neighbors(pt_level + 1));
226
- neighbors.resize(offsets.back(), -1);
227
- }
228
-
229
- return max_level;
215
+ return max_level;
230
216
  }
231
217
 
232
-
233
218
  /** Enumerate vertices from farthest to nearest from query, keep a
234
219
  * neighbor only if there is no previous neighbor that is closer to
235
220
  * that vertex than the query.
236
221
  */
237
222
  void HNSW::shrink_neighbor_list(
238
- DistanceComputer& qdis,
239
- std::priority_queue<NodeDistFarther>& input,
240
- std::vector<NodeDistFarther>& output,
241
- int max_size)
242
- {
243
- while (input.size() > 0) {
244
- NodeDistFarther v1 = input.top();
245
- input.pop();
246
- float dist_v1_q = v1.d;
247
-
248
- bool good = true;
249
- for (NodeDistFarther v2 : output) {
250
- float dist_v1_v2 = qdis.symmetric_dis(v2.id, v1.id);
251
-
252
- if (dist_v1_v2 < dist_v1_q) {
253
- good = false;
254
- break;
255
- }
256
- }
257
-
258
- if (good) {
259
- output.push_back(v1);
260
- if (output.size() >= max_size) {
261
- return;
262
- }
223
+ DistanceComputer& qdis,
224
+ std::priority_queue<NodeDistFarther>& input,
225
+ std::vector<NodeDistFarther>& output,
226
+ int max_size) {
227
+ while (input.size() > 0) {
228
+ NodeDistFarther v1 = input.top();
229
+ input.pop();
230
+ float dist_v1_q = v1.d;
231
+
232
+ bool good = true;
233
+ for (NodeDistFarther v2 : output) {
234
+ float dist_v1_v2 = qdis.symmetric_dis(v2.id, v1.id);
235
+
236
+ if (dist_v1_v2 < dist_v1_q) {
237
+ good = false;
238
+ break;
239
+ }
240
+ }
241
+
242
+ if (good) {
243
+ output.push_back(v1);
244
+ if (output.size() >= max_size) {
245
+ return;
246
+ }
247
+ }
263
248
  }
264
- }
265
249
  }
266
250
 
267
-
268
251
  namespace {
269
252
 
270
-
271
253
  using storage_idx_t = HNSW::storage_idx_t;
272
254
  using NodeDistCloser = HNSW::NodeDistCloser;
273
255
  using NodeDistFarther = HNSW::NodeDistFarther;
274
256
 
275
-
276
257
  /**************************************************************
277
258
  * Addition subroutines
278
259
  **************************************************************/
279
260
 
280
-
281
261
  /// remove neighbors from the list to make it smaller than max_size
282
262
  void shrink_neighbor_list(
283
- DistanceComputer& qdis,
284
- std::priority_queue<NodeDistCloser>& resultSet1,
285
- int max_size)
286
- {
263
+ DistanceComputer& qdis,
264
+ std::priority_queue<NodeDistCloser>& resultSet1,
265
+ int max_size) {
287
266
  if (resultSet1.size() < max_size) {
288
267
  return;
289
268
  }
@@ -300,516 +279,521 @@ void shrink_neighbor_list(
300
279
  for (NodeDistFarther curen2 : returnlist) {
301
280
  resultSet1.emplace(curen2.d, curen2.id);
302
281
  }
303
-
304
282
  }
305
283
 
306
-
307
284
  /// add a link between two elements, possibly shrinking the list
308
285
  /// of links to make room for it.
309
- void add_link(HNSW& hnsw,
310
- DistanceComputer& qdis,
311
- storage_idx_t src, storage_idx_t dest,
312
- int level)
313
- {
314
- size_t begin, end;
315
- hnsw.neighbor_range(src, level, &begin, &end);
316
- if (hnsw.neighbors[end - 1] == -1) {
317
- // there is enough room, find a slot to add it
318
- size_t i = end;
319
- while(i > begin) {
320
- if (hnsw.neighbors[i - 1] != -1) break;
321
- i--;
322
- }
323
- hnsw.neighbors[i] = dest;
324
- return;
325
- }
326
-
327
- // otherwise we let them fight out which to keep
328
-
329
- // copy to resultSet...
330
- std::priority_queue<NodeDistCloser> resultSet;
331
- resultSet.emplace(qdis.symmetric_dis(src, dest), dest);
332
- for (size_t i = begin; i < end; i++) { // HERE WAS THE BUG
333
- storage_idx_t neigh = hnsw.neighbors[i];
334
- resultSet.emplace(qdis.symmetric_dis(src, neigh), neigh);
335
- }
336
-
337
- shrink_neighbor_list(qdis, resultSet, end - begin);
338
-
339
- // ...and back
340
- size_t i = begin;
341
- while (resultSet.size()) {
342
- hnsw.neighbors[i++] = resultSet.top().id;
343
- resultSet.pop();
344
- }
345
- // they may have shrunk more than just by 1 element
346
- while(i < end) {
347
- hnsw.neighbors[i++] = -1;
348
- }
286
+ void add_link(
287
+ HNSW& hnsw,
288
+ DistanceComputer& qdis,
289
+ storage_idx_t src,
290
+ storage_idx_t dest,
291
+ int level) {
292
+ size_t begin, end;
293
+ hnsw.neighbor_range(src, level, &begin, &end);
294
+ if (hnsw.neighbors[end - 1] == -1) {
295
+ // there is enough room, find a slot to add it
296
+ size_t i = end;
297
+ while (i > begin) {
298
+ if (hnsw.neighbors[i - 1] != -1)
299
+ break;
300
+ i--;
301
+ }
302
+ hnsw.neighbors[i] = dest;
303
+ return;
304
+ }
305
+
306
+ // otherwise we let them fight out which to keep
307
+
308
+ // copy to resultSet...
309
+ std::priority_queue<NodeDistCloser> resultSet;
310
+ resultSet.emplace(qdis.symmetric_dis(src, dest), dest);
311
+ for (size_t i = begin; i < end; i++) { // HERE WAS THE BUG
312
+ storage_idx_t neigh = hnsw.neighbors[i];
313
+ resultSet.emplace(qdis.symmetric_dis(src, neigh), neigh);
314
+ }
315
+
316
+ shrink_neighbor_list(qdis, resultSet, end - begin);
317
+
318
+ // ...and back
319
+ size_t i = begin;
320
+ while (resultSet.size()) {
321
+ hnsw.neighbors[i++] = resultSet.top().id;
322
+ resultSet.pop();
323
+ }
324
+ // they may have shrunk more than just by 1 element
325
+ while (i < end) {
326
+ hnsw.neighbors[i++] = -1;
327
+ }
349
328
  }
350
329
 
351
330
  /// search neighbors on a single level, starting from an entry point
352
331
  void search_neighbors_to_add(
353
- HNSW& hnsw,
354
- DistanceComputer& qdis,
355
- std::priority_queue<NodeDistCloser>& results,
356
- int entry_point,
357
- float d_entry_point,
358
- int level,
359
- VisitedTable &vt)
360
- {
361
- // top is nearest candidate
362
- std::priority_queue<NodeDistFarther> candidates;
363
-
364
- NodeDistFarther ev(d_entry_point, entry_point);
365
- candidates.push(ev);
366
- results.emplace(d_entry_point, entry_point);
367
- vt.set(entry_point);
368
-
369
- while (!candidates.empty()) {
370
- // get nearest
371
- const NodeDistFarther &currEv = candidates.top();
372
-
373
- if (currEv.d > results.top().d) {
374
- break;
375
- }
376
- int currNode = currEv.id;
377
- candidates.pop();
378
-
379
- // loop over neighbors
380
- size_t begin, end;
381
- hnsw.neighbor_range(currNode, level, &begin, &end);
382
- for(size_t i = begin; i < end; i++) {
383
- storage_idx_t nodeId = hnsw.neighbors[i];
384
- if (nodeId < 0) break;
385
- if (vt.get(nodeId)) continue;
386
- vt.set(nodeId);
387
-
388
- float dis = qdis(nodeId);
389
- NodeDistFarther evE1(dis, nodeId);
390
-
391
- if (results.size() < hnsw.efConstruction ||
392
- results.top().d > dis) {
393
-
394
- results.emplace(dis, nodeId);
395
- candidates.emplace(dis, nodeId);
396
- if (results.size() > hnsw.efConstruction) {
397
- results.pop();
332
+ HNSW& hnsw,
333
+ DistanceComputer& qdis,
334
+ std::priority_queue<NodeDistCloser>& results,
335
+ int entry_point,
336
+ float d_entry_point,
337
+ int level,
338
+ VisitedTable& vt) {
339
+ // top is nearest candidate
340
+ std::priority_queue<NodeDistFarther> candidates;
341
+
342
+ NodeDistFarther ev(d_entry_point, entry_point);
343
+ candidates.push(ev);
344
+ results.emplace(d_entry_point, entry_point);
345
+ vt.set(entry_point);
346
+
347
+ while (!candidates.empty()) {
348
+ // get nearest
349
+ const NodeDistFarther& currEv = candidates.top();
350
+
351
+ if (currEv.d > results.top().d) {
352
+ break;
353
+ }
354
+ int currNode = currEv.id;
355
+ candidates.pop();
356
+
357
+ // loop over neighbors
358
+ size_t begin, end;
359
+ hnsw.neighbor_range(currNode, level, &begin, &end);
360
+ for (size_t i = begin; i < end; i++) {
361
+ storage_idx_t nodeId = hnsw.neighbors[i];
362
+ if (nodeId < 0)
363
+ break;
364
+ if (vt.get(nodeId))
365
+ continue;
366
+ vt.set(nodeId);
367
+
368
+ float dis = qdis(nodeId);
369
+ NodeDistFarther evE1(dis, nodeId);
370
+
371
+ if (results.size() < hnsw.efConstruction || results.top().d > dis) {
372
+ results.emplace(dis, nodeId);
373
+ candidates.emplace(dis, nodeId);
374
+ if (results.size() > hnsw.efConstruction) {
375
+ results.pop();
376
+ }
377
+ }
398
378
  }
399
- }
400
379
  }
401
- }
402
- vt.advance();
380
+ vt.advance();
403
381
  }
404
382
 
405
-
406
383
  /**************************************************************
407
384
  * Searching subroutines
408
385
  **************************************************************/
409
386
 
410
387
  /// greedily update a nearest vector at a given level
411
- void greedy_update_nearest(const HNSW& hnsw,
412
- DistanceComputer& qdis,
413
- int level,
414
- storage_idx_t& nearest,
415
- float& d_nearest)
416
- {
417
- for(;;) {
418
- storage_idx_t prev_nearest = nearest;
419
-
420
- size_t begin, end;
421
- hnsw.neighbor_range(nearest, level, &begin, &end);
422
- for(size_t i = begin; i < end; i++) {
423
- storage_idx_t v = hnsw.neighbors[i];
424
- if (v < 0) break;
425
- float dis = qdis(v);
426
- if (dis < d_nearest) {
427
- nearest = v;
428
- d_nearest = dis;
429
- }
430
- }
431
- if (nearest == prev_nearest) {
432
- return;
433
- }
434
- }
388
+ void greedy_update_nearest(
389
+ const HNSW& hnsw,
390
+ DistanceComputer& qdis,
391
+ int level,
392
+ storage_idx_t& nearest,
393
+ float& d_nearest) {
394
+ for (;;) {
395
+ storage_idx_t prev_nearest = nearest;
396
+
397
+ size_t begin, end;
398
+ hnsw.neighbor_range(nearest, level, &begin, &end);
399
+ for (size_t i = begin; i < end; i++) {
400
+ storage_idx_t v = hnsw.neighbors[i];
401
+ if (v < 0)
402
+ break;
403
+ float dis = qdis(v);
404
+ if (dis < d_nearest) {
405
+ nearest = v;
406
+ d_nearest = dis;
407
+ }
408
+ }
409
+ if (nearest == prev_nearest) {
410
+ return;
411
+ }
412
+ }
435
413
  }
436
414
 
437
-
438
- } // namespace
439
-
415
+ } // namespace
440
416
 
441
417
  /// Finds neighbors and builds links with them, starting from an entry
442
418
  /// point. The own neighbor list is assumed to be locked.
443
- void HNSW::add_links_starting_from(DistanceComputer& ptdis,
444
- storage_idx_t pt_id,
445
- storage_idx_t nearest,
446
- float d_nearest,
447
- int level,
448
- omp_lock_t *locks,
449
- VisitedTable &vt)
450
- {
451
- std::priority_queue<NodeDistCloser> link_targets;
419
+ void HNSW::add_links_starting_from(
420
+ DistanceComputer& ptdis,
421
+ storage_idx_t pt_id,
422
+ storage_idx_t nearest,
423
+ float d_nearest,
424
+ int level,
425
+ omp_lock_t* locks,
426
+ VisitedTable& vt) {
427
+ std::priority_queue<NodeDistCloser> link_targets;
452
428
 
453
- search_neighbors_to_add(*this, ptdis, link_targets, nearest, d_nearest,
454
- level, vt);
429
+ search_neighbors_to_add(
430
+ *this, ptdis, link_targets, nearest, d_nearest, level, vt);
455
431
 
456
- // but we can afford only this many neighbors
457
- int M = nb_neighbors(level);
432
+ // but we can afford only this many neighbors
433
+ int M = nb_neighbors(level);
458
434
 
459
- ::faiss::shrink_neighbor_list(ptdis, link_targets, M);
435
+ ::faiss::shrink_neighbor_list(ptdis, link_targets, M);
460
436
 
461
- while (!link_targets.empty()) {
462
- int other_id = link_targets.top().id;
437
+ while (!link_targets.empty()) {
438
+ int other_id = link_targets.top().id;
463
439
 
464
- omp_set_lock(&locks[other_id]);
465
- add_link(*this, ptdis, other_id, pt_id, level);
466
- omp_unset_lock(&locks[other_id]);
440
+ omp_set_lock(&locks[other_id]);
441
+ add_link(*this, ptdis, other_id, pt_id, level);
442
+ omp_unset_lock(&locks[other_id]);
467
443
 
468
- add_link(*this, ptdis, pt_id, other_id, level);
444
+ add_link(*this, ptdis, pt_id, other_id, level);
469
445
 
470
- link_targets.pop();
471
- }
446
+ link_targets.pop();
447
+ }
472
448
  }
473
449
 
474
-
475
450
  /**************************************************************
476
451
  * Building, parallel
477
452
  **************************************************************/
478
453
 
479
- void HNSW::add_with_locks(DistanceComputer& ptdis, int pt_level, int pt_id,
480
- std::vector<omp_lock_t>& locks,
481
- VisitedTable& vt)
482
- {
483
- // greedy search on upper levels
454
+ void HNSW::add_with_locks(
455
+ DistanceComputer& ptdis,
456
+ int pt_level,
457
+ int pt_id,
458
+ std::vector<omp_lock_t>& locks,
459
+ VisitedTable& vt) {
460
+ // greedy search on upper levels
484
461
 
485
- storage_idx_t nearest;
462
+ storage_idx_t nearest;
486
463
  #pragma omp critical
487
- {
488
- nearest = entry_point;
464
+ {
465
+ nearest = entry_point;
489
466
 
490
- if (nearest == -1) {
491
- max_level = pt_level;
492
- entry_point = pt_id;
467
+ if (nearest == -1) {
468
+ max_level = pt_level;
469
+ entry_point = pt_id;
470
+ }
493
471
  }
494
- }
495
472
 
496
- if (nearest < 0) {
497
- return;
498
- }
473
+ if (nearest < 0) {
474
+ return;
475
+ }
499
476
 
500
- omp_set_lock(&locks[pt_id]);
477
+ omp_set_lock(&locks[pt_id]);
501
478
 
502
- int level = max_level; // level at which we start adding neighbors
503
- float d_nearest = ptdis(nearest);
479
+ int level = max_level; // level at which we start adding neighbors
480
+ float d_nearest = ptdis(nearest);
504
481
 
505
- for(; level > pt_level; level--) {
506
- greedy_update_nearest(*this, ptdis, level, nearest, d_nearest);
507
- }
482
+ for (; level > pt_level; level--) {
483
+ greedy_update_nearest(*this, ptdis, level, nearest, d_nearest);
484
+ }
508
485
 
509
- for(; level >= 0; level--) {
510
- add_links_starting_from(ptdis, pt_id, nearest, d_nearest,
511
- level, locks.data(), vt);
512
- }
486
+ for (; level >= 0; level--) {
487
+ add_links_starting_from(
488
+ ptdis, pt_id, nearest, d_nearest, level, locks.data(), vt);
489
+ }
513
490
 
514
- omp_unset_lock(&locks[pt_id]);
491
+ omp_unset_lock(&locks[pt_id]);
515
492
 
516
- if (pt_level > max_level) {
517
- max_level = pt_level;
518
- entry_point = pt_id;
519
- }
493
+ if (pt_level > max_level) {
494
+ max_level = pt_level;
495
+ entry_point = pt_id;
496
+ }
520
497
  }
521
498
 
522
-
523
499
  /** Do a BFS on the candidates list */
524
500
 
525
501
  int HNSW::search_from_candidates(
526
- DistanceComputer& qdis, int k,
527
- idx_t *I, float *D,
528
- MinimaxHeap& candidates,
529
- VisitedTable& vt,
530
- HNSWStats& stats,
531
- int level, int nres_in) const
532
- {
533
- int nres = nres_in;
534
- int ndis = 0;
535
- for (int i = 0; i < candidates.size(); i++) {
536
- idx_t v1 = candidates.ids[i];
537
- float d = candidates.dis[i];
538
- FAISS_ASSERT(v1 >= 0);
539
- if (nres < k) {
540
- faiss::maxheap_push(++nres, D, I, d, v1);
541
- } else if (d < D[0]) {
542
- faiss::maxheap_replace_top(nres, D, I, d, v1);
543
- }
544
- vt.set(v1);
545
- }
546
-
547
- bool do_dis_check = check_relative_distance;
548
- int nstep = 0;
549
-
550
- while (candidates.size() > 0) {
551
- float d0 = 0;
552
- int v0 = candidates.pop_min(&d0);
553
-
554
- if (do_dis_check) {
555
- // tricky stopping condition: there are more that ef
556
- // distances that are processed already that are smaller
557
- // than d0
558
-
559
- int n_dis_below = candidates.count_below(d0);
560
- if(n_dis_below >= efSearch) {
561
- break;
562
- }
563
- }
564
-
565
- size_t begin, end;
566
- neighbor_range(v0, level, &begin, &end);
567
-
568
- for (size_t j = begin; j < end; j++) {
569
- int v1 = neighbors[j];
570
- if (v1 < 0) break;
571
- if (vt.get(v1)) {
572
- continue;
573
- }
574
- vt.set(v1);
575
- ndis++;
576
- float d = qdis(v1);
577
- if (nres < k) {
578
- faiss::maxheap_push(++nres, D, I, d, v1);
579
- } else if (d < D[0]) {
580
- faiss::maxheap_replace_top(nres, D, I, d, v1);
581
- }
582
- candidates.push(v1, d);
583
- }
584
-
585
- nstep++;
586
- if (!do_dis_check && nstep > efSearch) {
587
- break;
588
- }
589
- }
590
-
591
- if (level == 0) {
592
- stats.n1 ++;
593
- if (candidates.size() == 0) {
594
- stats.n2 ++;
502
+ DistanceComputer& qdis,
503
+ int k,
504
+ idx_t* I,
505
+ float* D,
506
+ MinimaxHeap& candidates,
507
+ VisitedTable& vt,
508
+ HNSWStats& stats,
509
+ int level,
510
+ int nres_in) const {
511
+ int nres = nres_in;
512
+ int ndis = 0;
513
+ for (int i = 0; i < candidates.size(); i++) {
514
+ idx_t v1 = candidates.ids[i];
515
+ float d = candidates.dis[i];
516
+ FAISS_ASSERT(v1 >= 0);
517
+ if (nres < k) {
518
+ faiss::maxheap_push(++nres, D, I, d, v1);
519
+ } else if (d < D[0]) {
520
+ faiss::maxheap_replace_top(nres, D, I, d, v1);
521
+ }
522
+ vt.set(v1);
595
523
  }
596
- stats.n3 += ndis;
597
- }
598
-
599
- return nres;
600
- }
601
-
602
524
 
603
- /**************************************************************
604
- * Searching
605
- **************************************************************/
525
+ bool do_dis_check = check_relative_distance;
526
+ int nstep = 0;
606
527
 
607
- std::priority_queue<HNSW::Node> HNSW::search_from_candidate_unbounded(
608
- const Node& node,
609
- DistanceComputer& qdis,
610
- int ef,
611
- VisitedTable *vt,
612
- HNSWStats& stats) const
613
- {
614
- int ndis = 0;
615
- std::priority_queue<Node> top_candidates;
616
- std::priority_queue<Node, std::vector<Node>, std::greater<Node>> candidates;
528
+ while (candidates.size() > 0) {
529
+ float d0 = 0;
530
+ int v0 = candidates.pop_min(&d0);
617
531
 
618
- top_candidates.push(node);
619
- candidates.push(node);
532
+ if (do_dis_check) {
533
+ // tricky stopping condition: there are more that ef
534
+ // distances that are processed already that are smaller
535
+ // than d0
620
536
 
621
- vt->set(node.second);
537
+ int n_dis_below = candidates.count_below(d0);
538
+ if (n_dis_below >= efSearch) {
539
+ break;
540
+ }
541
+ }
622
542
 
623
- while (!candidates.empty()) {
624
- float d0;
625
- storage_idx_t v0;
626
- std::tie(d0, v0) = candidates.top();
543
+ size_t begin, end;
544
+ neighbor_range(v0, level, &begin, &end);
545
+
546
+ for (size_t j = begin; j < end; j++) {
547
+ int v1 = neighbors[j];
548
+ if (v1 < 0)
549
+ break;
550
+ if (vt.get(v1)) {
551
+ continue;
552
+ }
553
+ vt.set(v1);
554
+ ndis++;
555
+ float d = qdis(v1);
556
+ if (nres < k) {
557
+ faiss::maxheap_push(++nres, D, I, d, v1);
558
+ } else if (d < D[0]) {
559
+ faiss::maxheap_replace_top(nres, D, I, d, v1);
560
+ }
561
+ candidates.push(v1, d);
562
+ }
627
563
 
628
- if (d0 > top_candidates.top().first) {
629
- break;
564
+ nstep++;
565
+ if (!do_dis_check && nstep > efSearch) {
566
+ break;
567
+ }
630
568
  }
631
569
 
632
- candidates.pop();
633
-
634
- size_t begin, end;
635
- neighbor_range(v0, 0, &begin, &end);
636
-
637
- for (size_t j = begin; j < end; ++j) {
638
- int v1 = neighbors[j];
639
-
640
- if (v1 < 0) {
641
- break;
642
- }
643
- if (vt->get(v1)) {
644
- continue;
645
- }
646
-
647
- vt->set(v1);
648
-
649
- float d1 = qdis(v1);
650
- ++ndis;
651
-
652
- if (top_candidates.top().first > d1 || top_candidates.size() < ef) {
653
- candidates.emplace(d1, v1);
654
- top_candidates.emplace(d1, v1);
655
-
656
- if (top_candidates.size() > ef) {
657
- top_candidates.pop();
570
+ if (level == 0) {
571
+ stats.n1++;
572
+ if (candidates.size() == 0) {
573
+ stats.n2++;
658
574
  }
659
- }
575
+ stats.n3 += ndis;
660
576
  }
661
- }
662
-
663
- ++stats.n1;
664
- if (candidates.size() == 0) {
665
- ++stats.n2;
666
- }
667
- stats.n3 += ndis;
668
577
 
669
- return top_candidates;
578
+ return nres;
670
579
  }
671
580
 
672
- HNSWStats HNSW::search(DistanceComputer& qdis, int k,
673
- idx_t *I, float *D,
674
- VisitedTable& vt) const
675
- {
676
- HNSWStats stats;
677
-
678
- if (upper_beam == 1) {
679
-
680
- // greedy search on upper levels
681
- storage_idx_t nearest = entry_point;
682
- float d_nearest = qdis(nearest);
581
+ /**************************************************************
582
+ * Searching
583
+ **************************************************************/
683
584
 
684
- for(int level = max_level; level >= 1; level--) {
685
- greedy_update_nearest(*this, qdis, level, nearest, d_nearest);
686
- }
585
+ std::priority_queue<HNSW::Node> HNSW::search_from_candidate_unbounded(
586
+ const Node& node,
587
+ DistanceComputer& qdis,
588
+ int ef,
589
+ VisitedTable* vt,
590
+ HNSWStats& stats) const {
591
+ int ndis = 0;
592
+ std::priority_queue<Node> top_candidates;
593
+ std::priority_queue<Node, std::vector<Node>, std::greater<Node>> candidates;
594
+
595
+ top_candidates.push(node);
596
+ candidates.push(node);
597
+
598
+ vt->set(node.second);
599
+
600
+ while (!candidates.empty()) {
601
+ float d0;
602
+ storage_idx_t v0;
603
+ std::tie(d0, v0) = candidates.top();
604
+
605
+ if (d0 > top_candidates.top().first) {
606
+ break;
607
+ }
687
608
 
688
- int ef = std::max(efSearch, k);
689
- if (search_bounded_queue) {
690
- MinimaxHeap candidates(ef);
609
+ candidates.pop();
691
610
 
692
- candidates.push(nearest, d_nearest);
611
+ size_t begin, end;
612
+ neighbor_range(v0, 0, &begin, &end);
693
613
 
694
- search_from_candidates(qdis, k, I, D, candidates, vt, stats, 0);
695
- } else {
696
- std::priority_queue<Node> top_candidates =
697
- search_from_candidate_unbounded(Node(d_nearest, nearest),
698
- qdis, ef, &vt, stats);
614
+ for (size_t j = begin; j < end; ++j) {
615
+ int v1 = neighbors[j];
699
616
 
700
- while (top_candidates.size() > k) {
701
- top_candidates.pop();
702
- }
617
+ if (v1 < 0) {
618
+ break;
619
+ }
620
+ if (vt->get(v1)) {
621
+ continue;
622
+ }
703
623
 
704
- int nres = 0;
705
- while (!top_candidates.empty()) {
706
- float d;
707
- storage_idx_t label;
708
- std::tie(d, label) = top_candidates.top();
709
- faiss::maxheap_push(++nres, D, I, d, label);
710
- top_candidates.pop();
711
- }
712
- }
624
+ vt->set(v1);
713
625
 
714
- vt.advance();
626
+ float d1 = qdis(v1);
627
+ ++ndis;
715
628
 
716
- } else {
717
- int candidates_size = upper_beam;
718
- MinimaxHeap candidates(candidates_size);
629
+ if (top_candidates.top().first > d1 || top_candidates.size() < ef) {
630
+ candidates.emplace(d1, v1);
631
+ top_candidates.emplace(d1, v1);
719
632
 
720
- std::vector<idx_t> I_to_next(candidates_size);
721
- std::vector<float> D_to_next(candidates_size);
633
+ if (top_candidates.size() > ef) {
634
+ top_candidates.pop();
635
+ }
636
+ }
637
+ }
638
+ }
722
639
 
723
- int nres = 1;
724
- I_to_next[0] = entry_point;
725
- D_to_next[0] = qdis(entry_point);
640
+ ++stats.n1;
641
+ if (candidates.size() == 0) {
642
+ ++stats.n2;
643
+ }
644
+ stats.n3 += ndis;
726
645
 
727
- for(int level = max_level; level >= 0; level--) {
646
+ return top_candidates;
647
+ }
728
648
 
729
- // copy I, D -> candidates
649
+ HNSWStats HNSW::search(
650
+ DistanceComputer& qdis,
651
+ int k,
652
+ idx_t* I,
653
+ float* D,
654
+ VisitedTable& vt) const {
655
+ HNSWStats stats;
656
+
657
+ if (upper_beam == 1) {
658
+ // greedy search on upper levels
659
+ storage_idx_t nearest = entry_point;
660
+ float d_nearest = qdis(nearest);
661
+
662
+ for (int level = max_level; level >= 1; level--) {
663
+ greedy_update_nearest(*this, qdis, level, nearest, d_nearest);
664
+ }
730
665
 
731
- candidates.clear();
666
+ int ef = std::max(efSearch, k);
667
+ if (search_bounded_queue) {
668
+ MinimaxHeap candidates(ef);
669
+
670
+ candidates.push(nearest, d_nearest);
671
+
672
+ search_from_candidates(qdis, k, I, D, candidates, vt, stats, 0);
673
+ } else {
674
+ std::priority_queue<Node> top_candidates =
675
+ search_from_candidate_unbounded(
676
+ Node(d_nearest, nearest), qdis, ef, &vt, stats);
677
+
678
+ while (top_candidates.size() > k) {
679
+ top_candidates.pop();
680
+ }
681
+
682
+ int nres = 0;
683
+ while (!top_candidates.empty()) {
684
+ float d;
685
+ storage_idx_t label;
686
+ std::tie(d, label) = top_candidates.top();
687
+ faiss::maxheap_push(++nres, D, I, d, label);
688
+ top_candidates.pop();
689
+ }
690
+ }
732
691
 
733
- for (int i = 0; i < nres; i++) {
734
- candidates.push(I_to_next[i], D_to_next[i]);
735
- }
692
+ vt.advance();
736
693
 
737
- if (level == 0) {
738
- nres = search_from_candidates(qdis, k, I, D, candidates, vt, stats, 0);
739
- } else {
740
- nres = search_from_candidates(
741
- qdis, candidates_size,
742
- I_to_next.data(), D_to_next.data(),
743
- candidates, vt, stats, level
744
- );
745
- }
746
- vt.advance();
694
+ } else {
695
+ int candidates_size = upper_beam;
696
+ MinimaxHeap candidates(candidates_size);
697
+
698
+ std::vector<idx_t> I_to_next(candidates_size);
699
+ std::vector<float> D_to_next(candidates_size);
700
+
701
+ int nres = 1;
702
+ I_to_next[0] = entry_point;
703
+ D_to_next[0] = qdis(entry_point);
704
+
705
+ for (int level = max_level; level >= 0; level--) {
706
+ // copy I, D -> candidates
707
+
708
+ candidates.clear();
709
+
710
+ for (int i = 0; i < nres; i++) {
711
+ candidates.push(I_to_next[i], D_to_next[i]);
712
+ }
713
+
714
+ if (level == 0) {
715
+ nres = search_from_candidates(
716
+ qdis, k, I, D, candidates, vt, stats, 0);
717
+ } else {
718
+ nres = search_from_candidates(
719
+ qdis,
720
+ candidates_size,
721
+ I_to_next.data(),
722
+ D_to_next.data(),
723
+ candidates,
724
+ vt,
725
+ stats,
726
+ level);
727
+ }
728
+ vt.advance();
729
+ }
747
730
  }
748
- }
749
731
 
750
- return stats;
732
+ return stats;
751
733
  }
752
734
 
753
-
754
735
  void HNSW::MinimaxHeap::push(storage_idx_t i, float v) {
755
- if (k == n) {
756
- if (v >= dis[0]) return;
757
- faiss::heap_pop<HC> (k--, dis.data(), ids.data());
758
- --nvalid;
759
- }
760
- faiss::heap_push<HC> (++k, dis.data(), ids.data(), v, i);
761
- ++nvalid;
736
+ if (k == n) {
737
+ if (v >= dis[0])
738
+ return;
739
+ faiss::heap_pop<HC>(k--, dis.data(), ids.data());
740
+ --nvalid;
741
+ }
742
+ faiss::heap_push<HC>(++k, dis.data(), ids.data(), v, i);
743
+ ++nvalid;
762
744
  }
763
745
 
764
746
  float HNSW::MinimaxHeap::max() const {
765
- return dis[0];
747
+ return dis[0];
766
748
  }
767
749
 
768
750
  int HNSW::MinimaxHeap::size() const {
769
- return nvalid;
751
+ return nvalid;
770
752
  }
771
753
 
772
754
  void HNSW::MinimaxHeap::clear() {
773
- nvalid = k = 0;
755
+ nvalid = k = 0;
774
756
  }
775
757
 
776
- int HNSW::MinimaxHeap::pop_min(float *vmin_out) {
777
- assert(k > 0);
778
- // returns min. This is an O(n) operation
779
- int i = k - 1;
780
- while (i >= 0) {
781
- if (ids[i] != -1) break;
782
- i--;
783
- }
784
- if (i == -1) return -1;
785
- int imin = i;
786
- float vmin = dis[i];
787
- i--;
788
- while(i >= 0) {
789
- if (ids[i] != -1 && dis[i] < vmin) {
790
- vmin = dis[i];
791
- imin = i;
758
+ int HNSW::MinimaxHeap::pop_min(float* vmin_out) {
759
+ assert(k > 0);
760
+ // returns min. This is an O(n) operation
761
+ int i = k - 1;
762
+ while (i >= 0) {
763
+ if (ids[i] != -1)
764
+ break;
765
+ i--;
792
766
  }
767
+ if (i == -1)
768
+ return -1;
769
+ int imin = i;
770
+ float vmin = dis[i];
793
771
  i--;
794
- }
795
- if (vmin_out) *vmin_out = vmin;
796
- int ret = ids[imin];
797
- ids[imin] = -1;
798
- --nvalid;
772
+ while (i >= 0) {
773
+ if (ids[i] != -1 && dis[i] < vmin) {
774
+ vmin = dis[i];
775
+ imin = i;
776
+ }
777
+ i--;
778
+ }
779
+ if (vmin_out)
780
+ *vmin_out = vmin;
781
+ int ret = ids[imin];
782
+ ids[imin] = -1;
783
+ --nvalid;
799
784
 
800
- return ret;
785
+ return ret;
801
786
  }
802
787
 
803
788
  int HNSW::MinimaxHeap::count_below(float thresh) {
804
- int n_below = 0;
805
- for(int i = 0; i < k; i++) {
806
- if (dis[i] < thresh) {
807
- n_below++;
789
+ int n_below = 0;
790
+ for (int i = 0; i < k; i++) {
791
+ if (dis[i] < thresh) {
792
+ n_below++;
793
+ }
808
794
  }
809
- }
810
795
 
811
- return n_below;
796
+ return n_below;
812
797
  }
813
798
 
814
-
815
- } // namespace faiss
799
+ } // namespace faiss