faiss 0.2.0 → 0.2.4

Sign up to get free protection for your applications and to get access to all the features.
Files changed (215) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +16 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +7 -7
  5. data/ext/faiss/extconf.rb +6 -3
  6. data/ext/faiss/numo.hpp +4 -4
  7. data/ext/faiss/utils.cpp +1 -1
  8. data/ext/faiss/utils.h +1 -1
  9. data/lib/faiss/version.rb +1 -1
  10. data/vendor/faiss/faiss/AutoTune.cpp +292 -291
  11. data/vendor/faiss/faiss/AutoTune.h +55 -56
  12. data/vendor/faiss/faiss/Clustering.cpp +365 -194
  13. data/vendor/faiss/faiss/Clustering.h +102 -35
  14. data/vendor/faiss/faiss/IVFlib.cpp +171 -195
  15. data/vendor/faiss/faiss/IVFlib.h +48 -51
  16. data/vendor/faiss/faiss/Index.cpp +85 -103
  17. data/vendor/faiss/faiss/Index.h +54 -48
  18. data/vendor/faiss/faiss/Index2Layer.cpp +126 -224
  19. data/vendor/faiss/faiss/Index2Layer.h +22 -36
  20. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +407 -0
  21. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +195 -0
  22. data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
  23. data/vendor/faiss/faiss/IndexBinary.h +140 -132
  24. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
  25. data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
  26. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
  27. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
  28. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
  29. data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
  30. data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
  31. data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
  32. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
  33. data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
  34. data/vendor/faiss/faiss/IndexFlat.cpp +115 -176
  35. data/vendor/faiss/faiss/IndexFlat.h +42 -59
  36. data/vendor/faiss/faiss/IndexFlatCodes.cpp +67 -0
  37. data/vendor/faiss/faiss/IndexFlatCodes.h +47 -0
  38. data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
  39. data/vendor/faiss/faiss/IndexHNSW.h +57 -41
  40. data/vendor/faiss/faiss/IndexIVF.cpp +545 -453
  41. data/vendor/faiss/faiss/IndexIVF.h +169 -118
  42. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +316 -0
  43. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +121 -0
  44. data/vendor/faiss/faiss/IndexIVFFlat.cpp +247 -252
  45. data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
  46. data/vendor/faiss/faiss/IndexIVFPQ.cpp +459 -517
  47. data/vendor/faiss/faiss/IndexIVFPQ.h +75 -67
  48. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
  49. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
  50. data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
  51. data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
  52. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +163 -150
  53. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +38 -25
  54. data/vendor/faiss/faiss/IndexLSH.cpp +66 -113
  55. data/vendor/faiss/faiss/IndexLSH.h +20 -38
  56. data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
  57. data/vendor/faiss/faiss/IndexLattice.h +11 -16
  58. data/vendor/faiss/faiss/IndexNNDescent.cpp +229 -0
  59. data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
  60. data/vendor/faiss/faiss/IndexNSG.cpp +301 -0
  61. data/vendor/faiss/faiss/IndexNSG.h +85 -0
  62. data/vendor/faiss/faiss/IndexPQ.cpp +387 -495
  63. data/vendor/faiss/faiss/IndexPQ.h +64 -82
  64. data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
  65. data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
  66. data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
  67. data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
  68. data/vendor/faiss/faiss/IndexRefine.cpp +139 -127
  69. data/vendor/faiss/faiss/IndexRefine.h +32 -23
  70. data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
  71. data/vendor/faiss/faiss/IndexReplicas.h +62 -56
  72. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +111 -172
  73. data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -59
  74. data/vendor/faiss/faiss/IndexShards.cpp +256 -240
  75. data/vendor/faiss/faiss/IndexShards.h +85 -73
  76. data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
  77. data/vendor/faiss/faiss/MatrixStats.h +7 -10
  78. data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
  79. data/vendor/faiss/faiss/MetaIndexes.h +40 -34
  80. data/vendor/faiss/faiss/MetricType.h +7 -7
  81. data/vendor/faiss/faiss/VectorTransform.cpp +654 -475
  82. data/vendor/faiss/faiss/VectorTransform.h +64 -89
  83. data/vendor/faiss/faiss/clone_index.cpp +78 -73
  84. data/vendor/faiss/faiss/clone_index.h +4 -9
  85. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
  86. data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
  87. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +198 -171
  88. data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
  89. data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
  90. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
  91. data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
  92. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
  93. data/vendor/faiss/faiss/gpu/GpuIcmEncoder.h +60 -0
  94. data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
  95. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
  96. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
  97. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
  98. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
  99. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
  100. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
  101. data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
  102. data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
  103. data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
  104. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
  105. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
  106. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
  107. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
  108. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
  109. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
  110. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
  111. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
  112. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
  113. data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
  114. data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
  115. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
  116. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
  117. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
  118. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
  119. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
  120. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
  121. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
  122. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
  123. data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
  124. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
  125. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
  126. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
  127. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
  128. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
  129. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
  130. data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
  131. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +503 -0
  132. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +175 -0
  133. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
  134. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
  135. data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
  136. data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
  137. data/vendor/faiss/faiss/impl/FaissException.h +41 -29
  138. data/vendor/faiss/faiss/impl/HNSW.cpp +606 -617
  139. data/vendor/faiss/faiss/impl/HNSW.h +179 -200
  140. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +855 -0
  141. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +244 -0
  142. data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
  143. data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
  144. data/vendor/faiss/faiss/impl/NSG.cpp +679 -0
  145. data/vendor/faiss/faiss/impl/NSG.h +199 -0
  146. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
  147. data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
  148. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
  149. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
  150. data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
  151. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +758 -0
  152. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +188 -0
  153. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
  154. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +647 -707
  155. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
  156. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
  157. data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
  158. data/vendor/faiss/faiss/impl/index_read.cpp +631 -480
  159. data/vendor/faiss/faiss/impl/index_write.cpp +547 -407
  160. data/vendor/faiss/faiss/impl/io.cpp +76 -95
  161. data/vendor/faiss/faiss/impl/io.h +31 -41
  162. data/vendor/faiss/faiss/impl/io_macros.h +60 -29
  163. data/vendor/faiss/faiss/impl/kmeans1d.cpp +301 -0
  164. data/vendor/faiss/faiss/impl/kmeans1d.h +48 -0
  165. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
  166. data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
  167. data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
  168. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
  169. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
  170. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
  171. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
  172. data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
  173. data/vendor/faiss/faiss/index_factory.cpp +619 -397
  174. data/vendor/faiss/faiss/index_factory.h +8 -6
  175. data/vendor/faiss/faiss/index_io.h +23 -26
  176. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
  177. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
  178. data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
  179. data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
  180. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
  181. data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
  182. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
  183. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
  184. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
  185. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
  186. data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
  187. data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
  188. data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
  189. data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
  190. data/vendor/faiss/faiss/utils/Heap.h +186 -209
  191. data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
  192. data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
  193. data/vendor/faiss/faiss/utils/distances.cpp +305 -312
  194. data/vendor/faiss/faiss/utils/distances.h +170 -122
  195. data/vendor/faiss/faiss/utils/distances_simd.cpp +498 -508
  196. data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
  197. data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
  198. data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
  199. data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
  200. data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
  201. data/vendor/faiss/faiss/utils/hamming.h +62 -85
  202. data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
  203. data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
  204. data/vendor/faiss/faiss/utils/partitioning.h +26 -21
  205. data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
  206. data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
  207. data/vendor/faiss/faiss/utils/random.cpp +39 -63
  208. data/vendor/faiss/faiss/utils/random.h +13 -16
  209. data/vendor/faiss/faiss/utils/simdlib.h +4 -2
  210. data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
  211. data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
  212. data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
  213. data/vendor/faiss/faiss/utils/utils.cpp +304 -287
  214. data/vendor/faiss/faiss/utils/utils.h +54 -49
  215. metadata +29 -4
@@ -15,275 +15,254 @@
15
15
 
16
16
  namespace faiss {
17
17
 
18
-
19
18
  /**************************************************************
20
19
  * HNSW structure implementation
21
20
  **************************************************************/
22
21
 
23
- int HNSW::nb_neighbors(int layer_no) const
24
- {
25
- return cum_nneighbor_per_level[layer_no + 1] -
26
- cum_nneighbor_per_level[layer_no];
22
+ int HNSW::nb_neighbors(int layer_no) const {
23
+ return cum_nneighbor_per_level[layer_no + 1] -
24
+ cum_nneighbor_per_level[layer_no];
27
25
  }
28
26
 
29
- void HNSW::set_nb_neighbors(int level_no, int n)
30
- {
31
- FAISS_THROW_IF_NOT(levels.size() == 0);
32
- int cur_n = nb_neighbors(level_no);
33
- for (int i = level_no + 1; i < cum_nneighbor_per_level.size(); i++) {
34
- cum_nneighbor_per_level[i] += n - cur_n;
35
- }
27
+ void HNSW::set_nb_neighbors(int level_no, int n) {
28
+ FAISS_THROW_IF_NOT(levels.size() == 0);
29
+ int cur_n = nb_neighbors(level_no);
30
+ for (int i = level_no + 1; i < cum_nneighbor_per_level.size(); i++) {
31
+ cum_nneighbor_per_level[i] += n - cur_n;
32
+ }
36
33
  }
37
34
 
38
- int HNSW::cum_nb_neighbors(int layer_no) const
39
- {
40
- return cum_nneighbor_per_level[layer_no];
35
+ int HNSW::cum_nb_neighbors(int layer_no) const {
36
+ return cum_nneighbor_per_level[layer_no];
41
37
  }
42
38
 
43
- void HNSW::neighbor_range(idx_t no, int layer_no,
44
- size_t * begin, size_t * end) const
45
- {
46
- size_t o = offsets[no];
47
- *begin = o + cum_nb_neighbors(layer_no);
48
- *end = o + cum_nb_neighbors(layer_no + 1);
39
+ void HNSW::neighbor_range(idx_t no, int layer_no, size_t* begin, size_t* end)
40
+ const {
41
+ size_t o = offsets[no];
42
+ *begin = o + cum_nb_neighbors(layer_no);
43
+ *end = o + cum_nb_neighbors(layer_no + 1);
49
44
  }
50
45
 
51
-
52
-
53
46
  HNSW::HNSW(int M) : rng(12345) {
54
- set_default_probas(M, 1.0 / log(M));
55
- max_level = -1;
56
- entry_point = -1;
57
- efSearch = 16;
58
- efConstruction = 40;
59
- upper_beam = 1;
60
- offsets.push_back(0);
47
+ set_default_probas(M, 1.0 / log(M));
48
+ max_level = -1;
49
+ entry_point = -1;
50
+ efSearch = 16;
51
+ efConstruction = 40;
52
+ upper_beam = 1;
53
+ offsets.push_back(0);
61
54
  }
62
55
 
63
-
64
- int HNSW::random_level()
65
- {
66
- double f = rng.rand_float();
67
- // could be a bit faster with bissection
68
- for (int level = 0; level < assign_probas.size(); level++) {
69
- if (f < assign_probas[level]) {
70
- return level;
56
+ int HNSW::random_level() {
57
+ double f = rng.rand_float();
58
+ // could be a bit faster with bissection
59
+ for (int level = 0; level < assign_probas.size(); level++) {
60
+ if (f < assign_probas[level]) {
61
+ return level;
62
+ }
63
+ f -= assign_probas[level];
71
64
  }
72
- f -= assign_probas[level];
73
- }
74
- // happens with exponentially low probability
75
- return assign_probas.size() - 1;
65
+ // happens with exponentially low probability
66
+ return assign_probas.size() - 1;
76
67
  }
77
68
 
78
- void HNSW::set_default_probas(int M, float levelMult)
79
- {
80
- int nn = 0;
81
- cum_nneighbor_per_level.push_back (0);
82
- for (int level = 0; ;level++) {
83
- float proba = exp(-level / levelMult) * (1 - exp(-1 / levelMult));
84
- if (proba < 1e-9) break;
85
- assign_probas.push_back(proba);
86
- nn += level == 0 ? M * 2 : M;
87
- cum_nneighbor_per_level.push_back (nn);
88
- }
69
+ void HNSW::set_default_probas(int M, float levelMult) {
70
+ int nn = 0;
71
+ cum_nneighbor_per_level.push_back(0);
72
+ for (int level = 0;; level++) {
73
+ float proba = exp(-level / levelMult) * (1 - exp(-1 / levelMult));
74
+ if (proba < 1e-9)
75
+ break;
76
+ assign_probas.push_back(proba);
77
+ nn += level == 0 ? M * 2 : M;
78
+ cum_nneighbor_per_level.push_back(nn);
79
+ }
89
80
  }
90
81
 
91
- void HNSW::clear_neighbor_tables(int level)
92
- {
93
- for (int i = 0; i < levels.size(); i++) {
94
- size_t begin, end;
95
- neighbor_range(i, level, &begin, &end);
96
- for (size_t j = begin; j < end; j++) {
97
- neighbors[j] = -1;
82
+ void HNSW::clear_neighbor_tables(int level) {
83
+ for (int i = 0; i < levels.size(); i++) {
84
+ size_t begin, end;
85
+ neighbor_range(i, level, &begin, &end);
86
+ for (size_t j = begin; j < end; j++) {
87
+ neighbors[j] = -1;
88
+ }
98
89
  }
99
- }
100
90
  }
101
91
 
102
-
103
92
  void HNSW::reset() {
104
- max_level = -1;
105
- entry_point = -1;
106
- offsets.clear();
107
- offsets.push_back(0);
108
- levels.clear();
109
- neighbors.clear();
93
+ max_level = -1;
94
+ entry_point = -1;
95
+ offsets.clear();
96
+ offsets.push_back(0);
97
+ levels.clear();
98
+ neighbors.clear();
110
99
  }
111
100
 
112
-
113
-
114
- void HNSW::print_neighbor_stats(int level) const
115
- {
116
- FAISS_THROW_IF_NOT (level < cum_nneighbor_per_level.size());
117
- printf("stats on level %d, max %d neighbors per vertex:\n",
118
- level, nb_neighbors(level));
119
- size_t tot_neigh = 0, tot_common = 0, tot_reciprocal = 0, n_node = 0;
101
+ void HNSW::print_neighbor_stats(int level) const {
102
+ FAISS_THROW_IF_NOT(level < cum_nneighbor_per_level.size());
103
+ printf("stats on level %d, max %d neighbors per vertex:\n",
104
+ level,
105
+ nb_neighbors(level));
106
+ size_t tot_neigh = 0, tot_common = 0, tot_reciprocal = 0, n_node = 0;
120
107
  #pragma omp parallel for reduction(+: tot_neigh) reduction(+: tot_common) \
121
108
  reduction(+: tot_reciprocal) reduction(+: n_node)
122
- for (int i = 0; i < levels.size(); i++) {
123
- if (levels[i] > level) {
124
- n_node++;
125
- size_t begin, end;
126
- neighbor_range(i, level, &begin, &end);
127
- std::unordered_set<int> neighset;
128
- for (size_t j = begin; j < end; j++) {
129
- if (neighbors [j] < 0) break;
130
- neighset.insert(neighbors[j]);
131
- }
132
- int n_neigh = neighset.size();
133
- int n_common = 0;
134
- int n_reciprocal = 0;
135
- for (size_t j = begin; j < end; j++) {
136
- storage_idx_t i2 = neighbors[j];
137
- if (i2 < 0) break;
138
- FAISS_ASSERT(i2 != i);
139
- size_t begin2, end2;
140
- neighbor_range(i2, level, &begin2, &end2);
141
- for (size_t j2 = begin2; j2 < end2; j2++) {
142
- storage_idx_t i3 = neighbors[j2];
143
- if (i3 < 0) break;
144
- if (i3 == i) {
145
- n_reciprocal++;
146
- continue;
147
- }
148
- if (neighset.count(i3)) {
149
- neighset.erase(i3);
150
- n_common++;
151
- }
109
+ for (int i = 0; i < levels.size(); i++) {
110
+ if (levels[i] > level) {
111
+ n_node++;
112
+ size_t begin, end;
113
+ neighbor_range(i, level, &begin, &end);
114
+ std::unordered_set<int> neighset;
115
+ for (size_t j = begin; j < end; j++) {
116
+ if (neighbors[j] < 0)
117
+ break;
118
+ neighset.insert(neighbors[j]);
119
+ }
120
+ int n_neigh = neighset.size();
121
+ int n_common = 0;
122
+ int n_reciprocal = 0;
123
+ for (size_t j = begin; j < end; j++) {
124
+ storage_idx_t i2 = neighbors[j];
125
+ if (i2 < 0)
126
+ break;
127
+ FAISS_ASSERT(i2 != i);
128
+ size_t begin2, end2;
129
+ neighbor_range(i2, level, &begin2, &end2);
130
+ for (size_t j2 = begin2; j2 < end2; j2++) {
131
+ storage_idx_t i3 = neighbors[j2];
132
+ if (i3 < 0)
133
+ break;
134
+ if (i3 == i) {
135
+ n_reciprocal++;
136
+ continue;
137
+ }
138
+ if (neighset.count(i3)) {
139
+ neighset.erase(i3);
140
+ n_common++;
141
+ }
142
+ }
143
+ }
144
+ tot_neigh += n_neigh;
145
+ tot_common += n_common;
146
+ tot_reciprocal += n_reciprocal;
152
147
  }
153
- }
154
- tot_neigh += n_neigh;
155
- tot_common += n_common;
156
- tot_reciprocal += n_reciprocal;
157
148
  }
158
- }
159
- float normalizer = n_node;
160
- printf(" nb of nodes at that level %zd\n", n_node);
161
- printf(" neighbors per node: %.2f (%zd)\n",
162
- tot_neigh / normalizer, tot_neigh);
163
- printf(" nb of reciprocal neighbors: %.2f\n", tot_reciprocal / normalizer);
164
- printf(" nb of neighbors that are also neighbor-of-neighbors: %.2f (%zd)\n",
165
- tot_common / normalizer, tot_common);
166
-
167
-
168
-
149
+ float normalizer = n_node;
150
+ printf(" nb of nodes at that level %zd\n", n_node);
151
+ printf(" neighbors per node: %.2f (%zd)\n",
152
+ tot_neigh / normalizer,
153
+ tot_neigh);
154
+ printf(" nb of reciprocal neighbors: %.2f\n",
155
+ tot_reciprocal / normalizer);
156
+ printf(" nb of neighbors that are also neighbor-of-neighbors: %.2f (%zd)\n",
157
+ tot_common / normalizer,
158
+ tot_common);
169
159
  }
170
160
 
161
+ void HNSW::fill_with_random_links(size_t n) {
162
+ int max_level = prepare_level_tab(n);
163
+ RandomGenerator rng2(456);
171
164
 
172
- void HNSW::fill_with_random_links(size_t n)
173
- {
174
- int max_level = prepare_level_tab(n);
175
- RandomGenerator rng2(456);
176
-
177
- for (int level = max_level - 1; level >= 0; --level) {
178
- std::vector<int> elts;
179
- for (int i = 0; i < n; i++) {
180
- if (levels[i] > level) {
181
- elts.push_back(i);
182
- }
183
- }
184
- printf ("linking %zd elements in level %d\n",
185
- elts.size(), level);
186
-
187
- if (elts.size() == 1) continue;
165
+ for (int level = max_level - 1; level >= 0; --level) {
166
+ std::vector<int> elts;
167
+ for (int i = 0; i < n; i++) {
168
+ if (levels[i] > level) {
169
+ elts.push_back(i);
170
+ }
171
+ }
172
+ printf("linking %zd elements in level %d\n", elts.size(), level);
188
173
 
189
- for (int ii = 0; ii < elts.size(); ii++) {
190
- int i = elts[ii];
191
- size_t begin, end;
192
- neighbor_range(i, 0, &begin, &end);
193
- for (size_t j = begin; j < end; j++) {
194
- int other = 0;
195
- do {
196
- other = elts[rng2.rand_int(elts.size())];
197
- } while(other == i);
174
+ if (elts.size() == 1)
175
+ continue;
198
176
 
199
- neighbors[j] = other;
200
- }
177
+ for (int ii = 0; ii < elts.size(); ii++) {
178
+ int i = elts[ii];
179
+ size_t begin, end;
180
+ neighbor_range(i, 0, &begin, &end);
181
+ for (size_t j = begin; j < end; j++) {
182
+ int other = 0;
183
+ do {
184
+ other = elts[rng2.rand_int(elts.size())];
185
+ } while (other == i);
186
+
187
+ neighbors[j] = other;
188
+ }
189
+ }
201
190
  }
202
- }
203
191
  }
204
192
 
193
+ int HNSW::prepare_level_tab(size_t n, bool preset_levels) {
194
+ size_t n0 = offsets.size() - 1;
205
195
 
206
- int HNSW::prepare_level_tab(size_t n, bool preset_levels)
207
- {
208
- size_t n0 = offsets.size() - 1;
196
+ if (preset_levels) {
197
+ FAISS_ASSERT(n0 + n == levels.size());
198
+ } else {
199
+ FAISS_ASSERT(n0 == levels.size());
200
+ for (int i = 0; i < n; i++) {
201
+ int pt_level = random_level();
202
+ levels.push_back(pt_level + 1);
203
+ }
204
+ }
209
205
 
210
- if (preset_levels) {
211
- FAISS_ASSERT (n0 + n == levels.size());
212
- } else {
213
- FAISS_ASSERT (n0 == levels.size());
206
+ int max_level = 0;
214
207
  for (int i = 0; i < n; i++) {
215
- int pt_level = random_level();
216
- levels.push_back(pt_level + 1);
208
+ int pt_level = levels[i + n0] - 1;
209
+ if (pt_level > max_level)
210
+ max_level = pt_level;
211
+ offsets.push_back(offsets.back() + cum_nb_neighbors(pt_level + 1));
212
+ neighbors.resize(offsets.back(), -1);
217
213
  }
218
- }
219
214
 
220
- int max_level = 0;
221
- for (int i = 0; i < n; i++) {
222
- int pt_level = levels[i + n0] - 1;
223
- if (pt_level > max_level) max_level = pt_level;
224
- offsets.push_back(offsets.back() +
225
- cum_nb_neighbors(pt_level + 1));
226
- neighbors.resize(offsets.back(), -1);
227
- }
228
-
229
- return max_level;
215
+ return max_level;
230
216
  }
231
217
 
232
-
233
218
  /** Enumerate vertices from farthest to nearest from query, keep a
234
219
  * neighbor only if there is no previous neighbor that is closer to
235
220
  * that vertex than the query.
236
221
  */
237
222
  void HNSW::shrink_neighbor_list(
238
- DistanceComputer& qdis,
239
- std::priority_queue<NodeDistFarther>& input,
240
- std::vector<NodeDistFarther>& output,
241
- int max_size)
242
- {
243
- while (input.size() > 0) {
244
- NodeDistFarther v1 = input.top();
245
- input.pop();
246
- float dist_v1_q = v1.d;
247
-
248
- bool good = true;
249
- for (NodeDistFarther v2 : output) {
250
- float dist_v1_v2 = qdis.symmetric_dis(v2.id, v1.id);
251
-
252
- if (dist_v1_v2 < dist_v1_q) {
253
- good = false;
254
- break;
255
- }
256
- }
257
-
258
- if (good) {
259
- output.push_back(v1);
260
- if (output.size() >= max_size) {
261
- return;
262
- }
223
+ DistanceComputer& qdis,
224
+ std::priority_queue<NodeDistFarther>& input,
225
+ std::vector<NodeDistFarther>& output,
226
+ int max_size) {
227
+ while (input.size() > 0) {
228
+ NodeDistFarther v1 = input.top();
229
+ input.pop();
230
+ float dist_v1_q = v1.d;
231
+
232
+ bool good = true;
233
+ for (NodeDistFarther v2 : output) {
234
+ float dist_v1_v2 = qdis.symmetric_dis(v2.id, v1.id);
235
+
236
+ if (dist_v1_v2 < dist_v1_q) {
237
+ good = false;
238
+ break;
239
+ }
240
+ }
241
+
242
+ if (good) {
243
+ output.push_back(v1);
244
+ if (output.size() >= max_size) {
245
+ return;
246
+ }
247
+ }
263
248
  }
264
- }
265
249
  }
266
250
 
267
-
268
251
  namespace {
269
252
 
270
-
271
253
  using storage_idx_t = HNSW::storage_idx_t;
272
254
  using NodeDistCloser = HNSW::NodeDistCloser;
273
255
  using NodeDistFarther = HNSW::NodeDistFarther;
274
256
 
275
-
276
257
  /**************************************************************
277
258
  * Addition subroutines
278
259
  **************************************************************/
279
260
 
280
-
281
261
  /// remove neighbors from the list to make it smaller than max_size
282
262
  void shrink_neighbor_list(
283
- DistanceComputer& qdis,
284
- std::priority_queue<NodeDistCloser>& resultSet1,
285
- int max_size)
286
- {
263
+ DistanceComputer& qdis,
264
+ std::priority_queue<NodeDistCloser>& resultSet1,
265
+ int max_size) {
287
266
  if (resultSet1.size() < max_size) {
288
267
  return;
289
268
  }
@@ -300,516 +279,526 @@ void shrink_neighbor_list(
300
279
  for (NodeDistFarther curen2 : returnlist) {
301
280
  resultSet1.emplace(curen2.d, curen2.id);
302
281
  }
303
-
304
282
  }
305
283
 
306
-
307
284
  /// add a link between two elements, possibly shrinking the list
308
285
  /// of links to make room for it.
309
- void add_link(HNSW& hnsw,
310
- DistanceComputer& qdis,
311
- storage_idx_t src, storage_idx_t dest,
312
- int level)
313
- {
314
- size_t begin, end;
315
- hnsw.neighbor_range(src, level, &begin, &end);
316
- if (hnsw.neighbors[end - 1] == -1) {
317
- // there is enough room, find a slot to add it
318
- size_t i = end;
319
- while(i > begin) {
320
- if (hnsw.neighbors[i - 1] != -1) break;
321
- i--;
322
- }
323
- hnsw.neighbors[i] = dest;
324
- return;
325
- }
326
-
327
- // otherwise we let them fight out which to keep
328
-
329
- // copy to resultSet...
330
- std::priority_queue<NodeDistCloser> resultSet;
331
- resultSet.emplace(qdis.symmetric_dis(src, dest), dest);
332
- for (size_t i = begin; i < end; i++) { // HERE WAS THE BUG
333
- storage_idx_t neigh = hnsw.neighbors[i];
334
- resultSet.emplace(qdis.symmetric_dis(src, neigh), neigh);
335
- }
336
-
337
- shrink_neighbor_list(qdis, resultSet, end - begin);
338
-
339
- // ...and back
340
- size_t i = begin;
341
- while (resultSet.size()) {
342
- hnsw.neighbors[i++] = resultSet.top().id;
343
- resultSet.pop();
344
- }
345
- // they may have shrunk more than just by 1 element
346
- while(i < end) {
347
- hnsw.neighbors[i++] = -1;
348
- }
286
+ void add_link(
287
+ HNSW& hnsw,
288
+ DistanceComputer& qdis,
289
+ storage_idx_t src,
290
+ storage_idx_t dest,
291
+ int level) {
292
+ size_t begin, end;
293
+ hnsw.neighbor_range(src, level, &begin, &end);
294
+ if (hnsw.neighbors[end - 1] == -1) {
295
+ // there is enough room, find a slot to add it
296
+ size_t i = end;
297
+ while (i > begin) {
298
+ if (hnsw.neighbors[i - 1] != -1)
299
+ break;
300
+ i--;
301
+ }
302
+ hnsw.neighbors[i] = dest;
303
+ return;
304
+ }
305
+
306
+ // otherwise we let them fight out which to keep
307
+
308
+ // copy to resultSet...
309
+ std::priority_queue<NodeDistCloser> resultSet;
310
+ resultSet.emplace(qdis.symmetric_dis(src, dest), dest);
311
+ for (size_t i = begin; i < end; i++) { // HERE WAS THE BUG
312
+ storage_idx_t neigh = hnsw.neighbors[i];
313
+ resultSet.emplace(qdis.symmetric_dis(src, neigh), neigh);
314
+ }
315
+
316
+ shrink_neighbor_list(qdis, resultSet, end - begin);
317
+
318
+ // ...and back
319
+ size_t i = begin;
320
+ while (resultSet.size()) {
321
+ hnsw.neighbors[i++] = resultSet.top().id;
322
+ resultSet.pop();
323
+ }
324
+ // they may have shrunk more than just by 1 element
325
+ while (i < end) {
326
+ hnsw.neighbors[i++] = -1;
327
+ }
349
328
  }
350
329
 
351
330
  /// search neighbors on a single level, starting from an entry point
352
331
  void search_neighbors_to_add(
353
- HNSW& hnsw,
354
- DistanceComputer& qdis,
355
- std::priority_queue<NodeDistCloser>& results,
356
- int entry_point,
357
- float d_entry_point,
358
- int level,
359
- VisitedTable &vt)
360
- {
361
- // top is nearest candidate
362
- std::priority_queue<NodeDistFarther> candidates;
363
-
364
- NodeDistFarther ev(d_entry_point, entry_point);
365
- candidates.push(ev);
366
- results.emplace(d_entry_point, entry_point);
367
- vt.set(entry_point);
368
-
369
- while (!candidates.empty()) {
370
- // get nearest
371
- const NodeDistFarther &currEv = candidates.top();
372
-
373
- if (currEv.d > results.top().d) {
374
- break;
375
- }
376
- int currNode = currEv.id;
377
- candidates.pop();
378
-
379
- // loop over neighbors
380
- size_t begin, end;
381
- hnsw.neighbor_range(currNode, level, &begin, &end);
382
- for(size_t i = begin; i < end; i++) {
383
- storage_idx_t nodeId = hnsw.neighbors[i];
384
- if (nodeId < 0) break;
385
- if (vt.get(nodeId)) continue;
386
- vt.set(nodeId);
387
-
388
- float dis = qdis(nodeId);
389
- NodeDistFarther evE1(dis, nodeId);
390
-
391
- if (results.size() < hnsw.efConstruction ||
392
- results.top().d > dis) {
393
-
394
- results.emplace(dis, nodeId);
395
- candidates.emplace(dis, nodeId);
396
- if (results.size() > hnsw.efConstruction) {
397
- results.pop();
332
+ HNSW& hnsw,
333
+ DistanceComputer& qdis,
334
+ std::priority_queue<NodeDistCloser>& results,
335
+ int entry_point,
336
+ float d_entry_point,
337
+ int level,
338
+ VisitedTable& vt) {
339
+ // top is nearest candidate
340
+ std::priority_queue<NodeDistFarther> candidates;
341
+
342
+ NodeDistFarther ev(d_entry_point, entry_point);
343
+ candidates.push(ev);
344
+ results.emplace(d_entry_point, entry_point);
345
+ vt.set(entry_point);
346
+
347
+ while (!candidates.empty()) {
348
+ // get nearest
349
+ const NodeDistFarther& currEv = candidates.top();
350
+
351
+ if (currEv.d > results.top().d) {
352
+ break;
353
+ }
354
+ int currNode = currEv.id;
355
+ candidates.pop();
356
+
357
+ // loop over neighbors
358
+ size_t begin, end;
359
+ hnsw.neighbor_range(currNode, level, &begin, &end);
360
+ for (size_t i = begin; i < end; i++) {
361
+ storage_idx_t nodeId = hnsw.neighbors[i];
362
+ if (nodeId < 0)
363
+ break;
364
+ if (vt.get(nodeId))
365
+ continue;
366
+ vt.set(nodeId);
367
+
368
+ float dis = qdis(nodeId);
369
+ NodeDistFarther evE1(dis, nodeId);
370
+
371
+ if (results.size() < hnsw.efConstruction || results.top().d > dis) {
372
+ results.emplace(dis, nodeId);
373
+ candidates.emplace(dis, nodeId);
374
+ if (results.size() > hnsw.efConstruction) {
375
+ results.pop();
376
+ }
377
+ }
398
378
  }
399
- }
400
379
  }
401
- }
402
- vt.advance();
380
+ vt.advance();
403
381
  }
404
382
 
405
-
406
383
  /**************************************************************
407
384
  * Searching subroutines
408
385
  **************************************************************/
409
386
 
410
387
  /// greedily update a nearest vector at a given level
411
- void greedy_update_nearest(const HNSW& hnsw,
412
- DistanceComputer& qdis,
413
- int level,
414
- storage_idx_t& nearest,
415
- float& d_nearest)
416
- {
417
- for(;;) {
418
- storage_idx_t prev_nearest = nearest;
419
-
420
- size_t begin, end;
421
- hnsw.neighbor_range(nearest, level, &begin, &end);
422
- for(size_t i = begin; i < end; i++) {
423
- storage_idx_t v = hnsw.neighbors[i];
424
- if (v < 0) break;
425
- float dis = qdis(v);
426
- if (dis < d_nearest) {
427
- nearest = v;
428
- d_nearest = dis;
429
- }
430
- }
431
- if (nearest == prev_nearest) {
432
- return;
433
- }
434
- }
388
+ void greedy_update_nearest(
389
+ const HNSW& hnsw,
390
+ DistanceComputer& qdis,
391
+ int level,
392
+ storage_idx_t& nearest,
393
+ float& d_nearest) {
394
+ for (;;) {
395
+ storage_idx_t prev_nearest = nearest;
396
+
397
+ size_t begin, end;
398
+ hnsw.neighbor_range(nearest, level, &begin, &end);
399
+ for (size_t i = begin; i < end; i++) {
400
+ storage_idx_t v = hnsw.neighbors[i];
401
+ if (v < 0)
402
+ break;
403
+ float dis = qdis(v);
404
+ if (dis < d_nearest) {
405
+ nearest = v;
406
+ d_nearest = dis;
407
+ }
408
+ }
409
+ if (nearest == prev_nearest) {
410
+ return;
411
+ }
412
+ }
435
413
  }
436
414
 
437
-
438
- } // namespace
439
-
415
+ } // namespace
440
416
 
441
417
  /// Finds neighbors and builds links with them, starting from an entry
442
418
  /// point. The own neighbor list is assumed to be locked.
443
- void HNSW::add_links_starting_from(DistanceComputer& ptdis,
444
- storage_idx_t pt_id,
445
- storage_idx_t nearest,
446
- float d_nearest,
447
- int level,
448
- omp_lock_t *locks,
449
- VisitedTable &vt)
450
- {
451
- std::priority_queue<NodeDistCloser> link_targets;
452
-
453
- search_neighbors_to_add(*this, ptdis, link_targets, nearest, d_nearest,
454
- level, vt);
455
-
456
- // but we can afford only this many neighbors
457
- int M = nb_neighbors(level);
458
-
459
- ::faiss::shrink_neighbor_list(ptdis, link_targets, M);
460
-
461
- while (!link_targets.empty()) {
462
- int other_id = link_targets.top().id;
463
-
464
- omp_set_lock(&locks[other_id]);
465
- add_link(*this, ptdis, other_id, pt_id, level);
466
- omp_unset_lock(&locks[other_id]);
467
-
468
- add_link(*this, ptdis, pt_id, other_id, level);
419
+ void HNSW::add_links_starting_from(
420
+ DistanceComputer& ptdis,
421
+ storage_idx_t pt_id,
422
+ storage_idx_t nearest,
423
+ float d_nearest,
424
+ int level,
425
+ omp_lock_t* locks,
426
+ VisitedTable& vt) {
427
+ std::priority_queue<NodeDistCloser> link_targets;
428
+
429
+ search_neighbors_to_add(
430
+ *this, ptdis, link_targets, nearest, d_nearest, level, vt);
431
+
432
+ // but we can afford only this many neighbors
433
+ int M = nb_neighbors(level);
434
+
435
+ ::faiss::shrink_neighbor_list(ptdis, link_targets, M);
436
+
437
+ std::vector<storage_idx_t> neighbors;
438
+ neighbors.reserve(link_targets.size());
439
+ while (!link_targets.empty()) {
440
+ storage_idx_t other_id = link_targets.top().id;
441
+ add_link(*this, ptdis, pt_id, other_id, level);
442
+ neighbors.push_back(other_id);
443
+ link_targets.pop();
444
+ }
469
445
 
470
- link_targets.pop();
471
- }
446
+ omp_unset_lock(&locks[pt_id]);
447
+ for (storage_idx_t other_id : neighbors) {
448
+ omp_set_lock(&locks[other_id]);
449
+ add_link(*this, ptdis, other_id, pt_id, level);
450
+ omp_unset_lock(&locks[other_id]);
451
+ }
452
+ omp_set_lock(&locks[pt_id]);
472
453
  }
473
454
 
474
-
475
455
  /**************************************************************
476
456
  * Building, parallel
477
457
  **************************************************************/
478
458
 
479
- void HNSW::add_with_locks(DistanceComputer& ptdis, int pt_level, int pt_id,
480
- std::vector<omp_lock_t>& locks,
481
- VisitedTable& vt)
482
- {
483
- // greedy search on upper levels
459
+ void HNSW::add_with_locks(
460
+ DistanceComputer& ptdis,
461
+ int pt_level,
462
+ int pt_id,
463
+ std::vector<omp_lock_t>& locks,
464
+ VisitedTable& vt) {
465
+ // greedy search on upper levels
484
466
 
485
- storage_idx_t nearest;
467
+ storage_idx_t nearest;
486
468
  #pragma omp critical
487
- {
488
- nearest = entry_point;
469
+ {
470
+ nearest = entry_point;
489
471
 
490
- if (nearest == -1) {
491
- max_level = pt_level;
492
- entry_point = pt_id;
472
+ if (nearest == -1) {
473
+ max_level = pt_level;
474
+ entry_point = pt_id;
475
+ }
493
476
  }
494
- }
495
477
 
496
- if (nearest < 0) {
497
- return;
498
- }
478
+ if (nearest < 0) {
479
+ return;
480
+ }
499
481
 
500
- omp_set_lock(&locks[pt_id]);
482
+ omp_set_lock(&locks[pt_id]);
501
483
 
502
- int level = max_level; // level at which we start adding neighbors
503
- float d_nearest = ptdis(nearest);
484
+ int level = max_level; // level at which we start adding neighbors
485
+ float d_nearest = ptdis(nearest);
504
486
 
505
- for(; level > pt_level; level--) {
506
- greedy_update_nearest(*this, ptdis, level, nearest, d_nearest);
507
- }
487
+ for (; level > pt_level; level--) {
488
+ greedy_update_nearest(*this, ptdis, level, nearest, d_nearest);
489
+ }
508
490
 
509
- for(; level >= 0; level--) {
510
- add_links_starting_from(ptdis, pt_id, nearest, d_nearest,
511
- level, locks.data(), vt);
512
- }
491
+ for (; level >= 0; level--) {
492
+ add_links_starting_from(
493
+ ptdis, pt_id, nearest, d_nearest, level, locks.data(), vt);
494
+ }
513
495
 
514
- omp_unset_lock(&locks[pt_id]);
496
+ omp_unset_lock(&locks[pt_id]);
515
497
 
516
- if (pt_level > max_level) {
517
- max_level = pt_level;
518
- entry_point = pt_id;
519
- }
498
+ if (pt_level > max_level) {
499
+ max_level = pt_level;
500
+ entry_point = pt_id;
501
+ }
520
502
  }
521
503
 
522
-
523
504
  /** Do a BFS on the candidates list */
524
505
 
525
506
  int HNSW::search_from_candidates(
526
- DistanceComputer& qdis, int k,
527
- idx_t *I, float *D,
528
- MinimaxHeap& candidates,
529
- VisitedTable& vt,
530
- HNSWStats& stats,
531
- int level, int nres_in) const
532
- {
533
- int nres = nres_in;
534
- int ndis = 0;
535
- for (int i = 0; i < candidates.size(); i++) {
536
- idx_t v1 = candidates.ids[i];
537
- float d = candidates.dis[i];
538
- FAISS_ASSERT(v1 >= 0);
539
- if (nres < k) {
540
- faiss::maxheap_push(++nres, D, I, d, v1);
541
- } else if (d < D[0]) {
542
- faiss::maxheap_replace_top(nres, D, I, d, v1);
543
- }
544
- vt.set(v1);
545
- }
546
-
547
- bool do_dis_check = check_relative_distance;
548
- int nstep = 0;
549
-
550
- while (candidates.size() > 0) {
551
- float d0 = 0;
552
- int v0 = candidates.pop_min(&d0);
553
-
554
- if (do_dis_check) {
555
- // tricky stopping condition: there are more that ef
556
- // distances that are processed already that are smaller
557
- // than d0
558
-
559
- int n_dis_below = candidates.count_below(d0);
560
- if(n_dis_below >= efSearch) {
561
- break;
562
- }
507
+ DistanceComputer& qdis,
508
+ int k,
509
+ idx_t* I,
510
+ float* D,
511
+ MinimaxHeap& candidates,
512
+ VisitedTable& vt,
513
+ HNSWStats& stats,
514
+ int level,
515
+ int nres_in) const {
516
+ int nres = nres_in;
517
+ int ndis = 0;
518
+ for (int i = 0; i < candidates.size(); i++) {
519
+ idx_t v1 = candidates.ids[i];
520
+ float d = candidates.dis[i];
521
+ FAISS_ASSERT(v1 >= 0);
522
+ if (nres < k) {
523
+ faiss::maxheap_push(++nres, D, I, d, v1);
524
+ } else if (d < D[0]) {
525
+ faiss::maxheap_replace_top(nres, D, I, d, v1);
526
+ }
527
+ vt.set(v1);
563
528
  }
564
529
 
565
- size_t begin, end;
566
- neighbor_range(v0, level, &begin, &end);
567
-
568
- for (size_t j = begin; j < end; j++) {
569
- int v1 = neighbors[j];
570
- if (v1 < 0) break;
571
- if (vt.get(v1)) {
572
- continue;
573
- }
574
- vt.set(v1);
575
- ndis++;
576
- float d = qdis(v1);
577
- if (nres < k) {
578
- faiss::maxheap_push(++nres, D, I, d, v1);
579
- } else if (d < D[0]) {
580
- faiss::maxheap_replace_top(nres, D, I, d, v1);
581
- }
582
- candidates.push(v1, d);
583
- }
584
-
585
- nstep++;
586
- if (!do_dis_check && nstep > efSearch) {
587
- break;
588
- }
589
- }
590
-
591
- if (level == 0) {
592
- stats.n1 ++;
593
- if (candidates.size() == 0) {
594
- stats.n2 ++;
595
- }
596
- stats.n3 += ndis;
597
- }
530
+ bool do_dis_check = check_relative_distance;
531
+ int nstep = 0;
598
532
 
599
- return nres;
600
- }
533
+ while (candidates.size() > 0) {
534
+ float d0 = 0;
535
+ int v0 = candidates.pop_min(&d0);
601
536
 
537
+ if (do_dis_check) {
538
+ // tricky stopping condition: there are more that ef
539
+ // distances that are processed already that are smaller
540
+ // than d0
602
541
 
603
- /**************************************************************
604
- * Searching
605
- **************************************************************/
606
-
607
- std::priority_queue<HNSW::Node> HNSW::search_from_candidate_unbounded(
608
- const Node& node,
609
- DistanceComputer& qdis,
610
- int ef,
611
- VisitedTable *vt,
612
- HNSWStats& stats) const
613
- {
614
- int ndis = 0;
615
- std::priority_queue<Node> top_candidates;
616
- std::priority_queue<Node, std::vector<Node>, std::greater<Node>> candidates;
617
-
618
- top_candidates.push(node);
619
- candidates.push(node);
620
-
621
- vt->set(node.second);
542
+ int n_dis_below = candidates.count_below(d0);
543
+ if (n_dis_below >= efSearch) {
544
+ break;
545
+ }
546
+ }
622
547
 
623
- while (!candidates.empty()) {
624
- float d0;
625
- storage_idx_t v0;
626
- std::tie(d0, v0) = candidates.top();
548
+ size_t begin, end;
549
+ neighbor_range(v0, level, &begin, &end);
550
+
551
+ for (size_t j = begin; j < end; j++) {
552
+ int v1 = neighbors[j];
553
+ if (v1 < 0)
554
+ break;
555
+ if (vt.get(v1)) {
556
+ continue;
557
+ }
558
+ vt.set(v1);
559
+ ndis++;
560
+ float d = qdis(v1);
561
+ if (nres < k) {
562
+ faiss::maxheap_push(++nres, D, I, d, v1);
563
+ } else if (d < D[0]) {
564
+ faiss::maxheap_replace_top(nres, D, I, d, v1);
565
+ }
566
+ candidates.push(v1, d);
567
+ }
627
568
 
628
- if (d0 > top_candidates.top().first) {
629
- break;
569
+ nstep++;
570
+ if (!do_dis_check && nstep > efSearch) {
571
+ break;
572
+ }
630
573
  }
631
574
 
632
- candidates.pop();
633
-
634
- size_t begin, end;
635
- neighbor_range(v0, 0, &begin, &end);
636
-
637
- for (size_t j = begin; j < end; ++j) {
638
- int v1 = neighbors[j];
639
-
640
- if (v1 < 0) {
641
- break;
642
- }
643
- if (vt->get(v1)) {
644
- continue;
645
- }
646
-
647
- vt->set(v1);
648
-
649
- float d1 = qdis(v1);
650
- ++ndis;
651
-
652
- if (top_candidates.top().first > d1 || top_candidates.size() < ef) {
653
- candidates.emplace(d1, v1);
654
- top_candidates.emplace(d1, v1);
655
-
656
- if (top_candidates.size() > ef) {
657
- top_candidates.pop();
575
+ if (level == 0) {
576
+ stats.n1++;
577
+ if (candidates.size() == 0) {
578
+ stats.n2++;
658
579
  }
659
- }
580
+ stats.n3 += ndis;
660
581
  }
661
- }
662
-
663
- ++stats.n1;
664
- if (candidates.size() == 0) {
665
- ++stats.n2;
666
- }
667
- stats.n3 += ndis;
668
582
 
669
- return top_candidates;
583
+ return nres;
670
584
  }
671
585
 
672
- HNSWStats HNSW::search(DistanceComputer& qdis, int k,
673
- idx_t *I, float *D,
674
- VisitedTable& vt) const
675
- {
676
- HNSWStats stats;
677
-
678
- if (upper_beam == 1) {
679
-
680
- // greedy search on upper levels
681
- storage_idx_t nearest = entry_point;
682
- float d_nearest = qdis(nearest);
586
+ /**************************************************************
587
+ * Searching
588
+ **************************************************************/
683
589
 
684
- for(int level = max_level; level >= 1; level--) {
685
- greedy_update_nearest(*this, qdis, level, nearest, d_nearest);
686
- }
590
+ std::priority_queue<HNSW::Node> HNSW::search_from_candidate_unbounded(
591
+ const Node& node,
592
+ DistanceComputer& qdis,
593
+ int ef,
594
+ VisitedTable* vt,
595
+ HNSWStats& stats) const {
596
+ int ndis = 0;
597
+ std::priority_queue<Node> top_candidates;
598
+ std::priority_queue<Node, std::vector<Node>, std::greater<Node>> candidates;
599
+
600
+ top_candidates.push(node);
601
+ candidates.push(node);
602
+
603
+ vt->set(node.second);
604
+
605
+ while (!candidates.empty()) {
606
+ float d0;
607
+ storage_idx_t v0;
608
+ std::tie(d0, v0) = candidates.top();
609
+
610
+ if (d0 > top_candidates.top().first) {
611
+ break;
612
+ }
687
613
 
688
- int ef = std::max(efSearch, k);
689
- if (search_bounded_queue) {
690
- MinimaxHeap candidates(ef);
614
+ candidates.pop();
691
615
 
692
- candidates.push(nearest, d_nearest);
616
+ size_t begin, end;
617
+ neighbor_range(v0, 0, &begin, &end);
693
618
 
694
- search_from_candidates(qdis, k, I, D, candidates, vt, stats, 0);
695
- } else {
696
- std::priority_queue<Node> top_candidates =
697
- search_from_candidate_unbounded(Node(d_nearest, nearest),
698
- qdis, ef, &vt, stats);
619
+ for (size_t j = begin; j < end; ++j) {
620
+ int v1 = neighbors[j];
699
621
 
700
- while (top_candidates.size() > k) {
701
- top_candidates.pop();
702
- }
622
+ if (v1 < 0) {
623
+ break;
624
+ }
625
+ if (vt->get(v1)) {
626
+ continue;
627
+ }
703
628
 
704
- int nres = 0;
705
- while (!top_candidates.empty()) {
706
- float d;
707
- storage_idx_t label;
708
- std::tie(d, label) = top_candidates.top();
709
- faiss::maxheap_push(++nres, D, I, d, label);
710
- top_candidates.pop();
711
- }
712
- }
629
+ vt->set(v1);
713
630
 
714
- vt.advance();
631
+ float d1 = qdis(v1);
632
+ ++ndis;
715
633
 
716
- } else {
717
- int candidates_size = upper_beam;
718
- MinimaxHeap candidates(candidates_size);
634
+ if (top_candidates.top().first > d1 || top_candidates.size() < ef) {
635
+ candidates.emplace(d1, v1);
636
+ top_candidates.emplace(d1, v1);
719
637
 
720
- std::vector<idx_t> I_to_next(candidates_size);
721
- std::vector<float> D_to_next(candidates_size);
638
+ if (top_candidates.size() > ef) {
639
+ top_candidates.pop();
640
+ }
641
+ }
642
+ }
643
+ }
722
644
 
723
- int nres = 1;
724
- I_to_next[0] = entry_point;
725
- D_to_next[0] = qdis(entry_point);
645
+ ++stats.n1;
646
+ if (candidates.size() == 0) {
647
+ ++stats.n2;
648
+ }
649
+ stats.n3 += ndis;
726
650
 
727
- for(int level = max_level; level >= 0; level--) {
651
+ return top_candidates;
652
+ }
728
653
 
729
- // copy I, D -> candidates
654
+ HNSWStats HNSW::search(
655
+ DistanceComputer& qdis,
656
+ int k,
657
+ idx_t* I,
658
+ float* D,
659
+ VisitedTable& vt) const {
660
+ HNSWStats stats;
661
+
662
+ if (upper_beam == 1) {
663
+ // greedy search on upper levels
664
+ storage_idx_t nearest = entry_point;
665
+ float d_nearest = qdis(nearest);
666
+
667
+ for (int level = max_level; level >= 1; level--) {
668
+ greedy_update_nearest(*this, qdis, level, nearest, d_nearest);
669
+ }
730
670
 
731
- candidates.clear();
671
+ int ef = std::max(efSearch, k);
672
+ if (search_bounded_queue) {
673
+ MinimaxHeap candidates(ef);
674
+
675
+ candidates.push(nearest, d_nearest);
676
+
677
+ search_from_candidates(qdis, k, I, D, candidates, vt, stats, 0);
678
+ } else {
679
+ std::priority_queue<Node> top_candidates =
680
+ search_from_candidate_unbounded(
681
+ Node(d_nearest, nearest), qdis, ef, &vt, stats);
682
+
683
+ while (top_candidates.size() > k) {
684
+ top_candidates.pop();
685
+ }
686
+
687
+ int nres = 0;
688
+ while (!top_candidates.empty()) {
689
+ float d;
690
+ storage_idx_t label;
691
+ std::tie(d, label) = top_candidates.top();
692
+ faiss::maxheap_push(++nres, D, I, d, label);
693
+ top_candidates.pop();
694
+ }
695
+ }
732
696
 
733
- for (int i = 0; i < nres; i++) {
734
- candidates.push(I_to_next[i], D_to_next[i]);
735
- }
697
+ vt.advance();
736
698
 
737
- if (level == 0) {
738
- nres = search_from_candidates(qdis, k, I, D, candidates, vt, stats, 0);
739
- } else {
740
- nres = search_from_candidates(
741
- qdis, candidates_size,
742
- I_to_next.data(), D_to_next.data(),
743
- candidates, vt, stats, level
744
- );
745
- }
746
- vt.advance();
699
+ } else {
700
+ int candidates_size = upper_beam;
701
+ MinimaxHeap candidates(candidates_size);
702
+
703
+ std::vector<idx_t> I_to_next(candidates_size);
704
+ std::vector<float> D_to_next(candidates_size);
705
+
706
+ int nres = 1;
707
+ I_to_next[0] = entry_point;
708
+ D_to_next[0] = qdis(entry_point);
709
+
710
+ for (int level = max_level; level >= 0; level--) {
711
+ // copy I, D -> candidates
712
+
713
+ candidates.clear();
714
+
715
+ for (int i = 0; i < nres; i++) {
716
+ candidates.push(I_to_next[i], D_to_next[i]);
717
+ }
718
+
719
+ if (level == 0) {
720
+ nres = search_from_candidates(
721
+ qdis, k, I, D, candidates, vt, stats, 0);
722
+ } else {
723
+ nres = search_from_candidates(
724
+ qdis,
725
+ candidates_size,
726
+ I_to_next.data(),
727
+ D_to_next.data(),
728
+ candidates,
729
+ vt,
730
+ stats,
731
+ level);
732
+ }
733
+ vt.advance();
734
+ }
747
735
  }
748
- }
749
736
 
750
- return stats;
737
+ return stats;
751
738
  }
752
739
 
753
-
754
740
  void HNSW::MinimaxHeap::push(storage_idx_t i, float v) {
755
- if (k == n) {
756
- if (v >= dis[0]) return;
757
- faiss::heap_pop<HC> (k--, dis.data(), ids.data());
758
- --nvalid;
759
- }
760
- faiss::heap_push<HC> (++k, dis.data(), ids.data(), v, i);
761
- ++nvalid;
741
+ if (k == n) {
742
+ if (v >= dis[0])
743
+ return;
744
+ faiss::heap_pop<HC>(k--, dis.data(), ids.data());
745
+ --nvalid;
746
+ }
747
+ faiss::heap_push<HC>(++k, dis.data(), ids.data(), v, i);
748
+ ++nvalid;
762
749
  }
763
750
 
764
751
  float HNSW::MinimaxHeap::max() const {
765
- return dis[0];
752
+ return dis[0];
766
753
  }
767
754
 
768
755
  int HNSW::MinimaxHeap::size() const {
769
- return nvalid;
756
+ return nvalid;
770
757
  }
771
758
 
772
759
  void HNSW::MinimaxHeap::clear() {
773
- nvalid = k = 0;
760
+ nvalid = k = 0;
774
761
  }
775
762
 
776
- int HNSW::MinimaxHeap::pop_min(float *vmin_out) {
777
- assert(k > 0);
778
- // returns min. This is an O(n) operation
779
- int i = k - 1;
780
- while (i >= 0) {
781
- if (ids[i] != -1) break;
782
- i--;
783
- }
784
- if (i == -1) return -1;
785
- int imin = i;
786
- float vmin = dis[i];
787
- i--;
788
- while(i >= 0) {
789
- if (ids[i] != -1 && dis[i] < vmin) {
790
- vmin = dis[i];
791
- imin = i;
763
+ int HNSW::MinimaxHeap::pop_min(float* vmin_out) {
764
+ assert(k > 0);
765
+ // returns min. This is an O(n) operation
766
+ int i = k - 1;
767
+ while (i >= 0) {
768
+ if (ids[i] != -1)
769
+ break;
770
+ i--;
792
771
  }
772
+ if (i == -1)
773
+ return -1;
774
+ int imin = i;
775
+ float vmin = dis[i];
793
776
  i--;
794
- }
795
- if (vmin_out) *vmin_out = vmin;
796
- int ret = ids[imin];
797
- ids[imin] = -1;
798
- --nvalid;
777
+ while (i >= 0) {
778
+ if (ids[i] != -1 && dis[i] < vmin) {
779
+ vmin = dis[i];
780
+ imin = i;
781
+ }
782
+ i--;
783
+ }
784
+ if (vmin_out)
785
+ *vmin_out = vmin;
786
+ int ret = ids[imin];
787
+ ids[imin] = -1;
788
+ --nvalid;
799
789
 
800
- return ret;
790
+ return ret;
801
791
  }
802
792
 
803
793
  int HNSW::MinimaxHeap::count_below(float thresh) {
804
- int n_below = 0;
805
- for(int i = 0; i < k; i++) {
806
- if (dis[i] < thresh) {
807
- n_below++;
794
+ int n_below = 0;
795
+ for (int i = 0; i < k; i++) {
796
+ if (dis[i] < thresh) {
797
+ n_below++;
798
+ }
808
799
  }
809
- }
810
800
 
811
- return n_below;
801
+ return n_below;
812
802
  }
813
803
 
814
-
815
- } // namespace faiss
804
+ } // namespace faiss