faiss 0.2.0 → 0.2.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (215) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +16 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +7 -7
  5. data/ext/faiss/extconf.rb +6 -3
  6. data/ext/faiss/numo.hpp +4 -4
  7. data/ext/faiss/utils.cpp +1 -1
  8. data/ext/faiss/utils.h +1 -1
  9. data/lib/faiss/version.rb +1 -1
  10. data/vendor/faiss/faiss/AutoTune.cpp +292 -291
  11. data/vendor/faiss/faiss/AutoTune.h +55 -56
  12. data/vendor/faiss/faiss/Clustering.cpp +365 -194
  13. data/vendor/faiss/faiss/Clustering.h +102 -35
  14. data/vendor/faiss/faiss/IVFlib.cpp +171 -195
  15. data/vendor/faiss/faiss/IVFlib.h +48 -51
  16. data/vendor/faiss/faiss/Index.cpp +85 -103
  17. data/vendor/faiss/faiss/Index.h +54 -48
  18. data/vendor/faiss/faiss/Index2Layer.cpp +126 -224
  19. data/vendor/faiss/faiss/Index2Layer.h +22 -36
  20. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +407 -0
  21. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +195 -0
  22. data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
  23. data/vendor/faiss/faiss/IndexBinary.h +140 -132
  24. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
  25. data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
  26. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
  27. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
  28. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
  29. data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
  30. data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
  31. data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
  32. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
  33. data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
  34. data/vendor/faiss/faiss/IndexFlat.cpp +115 -176
  35. data/vendor/faiss/faiss/IndexFlat.h +42 -59
  36. data/vendor/faiss/faiss/IndexFlatCodes.cpp +67 -0
  37. data/vendor/faiss/faiss/IndexFlatCodes.h +47 -0
  38. data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
  39. data/vendor/faiss/faiss/IndexHNSW.h +57 -41
  40. data/vendor/faiss/faiss/IndexIVF.cpp +545 -453
  41. data/vendor/faiss/faiss/IndexIVF.h +169 -118
  42. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +316 -0
  43. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +121 -0
  44. data/vendor/faiss/faiss/IndexIVFFlat.cpp +247 -252
  45. data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
  46. data/vendor/faiss/faiss/IndexIVFPQ.cpp +459 -517
  47. data/vendor/faiss/faiss/IndexIVFPQ.h +75 -67
  48. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
  49. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
  50. data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
  51. data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
  52. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +163 -150
  53. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +38 -25
  54. data/vendor/faiss/faiss/IndexLSH.cpp +66 -113
  55. data/vendor/faiss/faiss/IndexLSH.h +20 -38
  56. data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
  57. data/vendor/faiss/faiss/IndexLattice.h +11 -16
  58. data/vendor/faiss/faiss/IndexNNDescent.cpp +229 -0
  59. data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
  60. data/vendor/faiss/faiss/IndexNSG.cpp +301 -0
  61. data/vendor/faiss/faiss/IndexNSG.h +85 -0
  62. data/vendor/faiss/faiss/IndexPQ.cpp +387 -495
  63. data/vendor/faiss/faiss/IndexPQ.h +64 -82
  64. data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
  65. data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
  66. data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
  67. data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
  68. data/vendor/faiss/faiss/IndexRefine.cpp +139 -127
  69. data/vendor/faiss/faiss/IndexRefine.h +32 -23
  70. data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
  71. data/vendor/faiss/faiss/IndexReplicas.h +62 -56
  72. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +111 -172
  73. data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -59
  74. data/vendor/faiss/faiss/IndexShards.cpp +256 -240
  75. data/vendor/faiss/faiss/IndexShards.h +85 -73
  76. data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
  77. data/vendor/faiss/faiss/MatrixStats.h +7 -10
  78. data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
  79. data/vendor/faiss/faiss/MetaIndexes.h +40 -34
  80. data/vendor/faiss/faiss/MetricType.h +7 -7
  81. data/vendor/faiss/faiss/VectorTransform.cpp +654 -475
  82. data/vendor/faiss/faiss/VectorTransform.h +64 -89
  83. data/vendor/faiss/faiss/clone_index.cpp +78 -73
  84. data/vendor/faiss/faiss/clone_index.h +4 -9
  85. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
  86. data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
  87. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +198 -171
  88. data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
  89. data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
  90. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
  91. data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
  92. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
  93. data/vendor/faiss/faiss/gpu/GpuIcmEncoder.h +60 -0
  94. data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
  95. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
  96. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
  97. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
  98. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
  99. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
  100. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
  101. data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
  102. data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
  103. data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
  104. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
  105. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
  106. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
  107. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
  108. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
  109. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
  110. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
  111. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
  112. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
  113. data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
  114. data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
  115. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
  116. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
  117. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
  118. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
  119. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
  120. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
  121. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
  122. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
  123. data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
  124. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
  125. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
  126. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
  127. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
  128. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
  129. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
  130. data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
  131. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +503 -0
  132. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +175 -0
  133. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
  134. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
  135. data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
  136. data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
  137. data/vendor/faiss/faiss/impl/FaissException.h +41 -29
  138. data/vendor/faiss/faiss/impl/HNSW.cpp +606 -617
  139. data/vendor/faiss/faiss/impl/HNSW.h +179 -200
  140. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +855 -0
  141. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +244 -0
  142. data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
  143. data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
  144. data/vendor/faiss/faiss/impl/NSG.cpp +679 -0
  145. data/vendor/faiss/faiss/impl/NSG.h +199 -0
  146. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
  147. data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
  148. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
  149. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
  150. data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
  151. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +758 -0
  152. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +188 -0
  153. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
  154. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +647 -707
  155. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
  156. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
  157. data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
  158. data/vendor/faiss/faiss/impl/index_read.cpp +631 -480
  159. data/vendor/faiss/faiss/impl/index_write.cpp +547 -407
  160. data/vendor/faiss/faiss/impl/io.cpp +76 -95
  161. data/vendor/faiss/faiss/impl/io.h +31 -41
  162. data/vendor/faiss/faiss/impl/io_macros.h +60 -29
  163. data/vendor/faiss/faiss/impl/kmeans1d.cpp +301 -0
  164. data/vendor/faiss/faiss/impl/kmeans1d.h +48 -0
  165. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
  166. data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
  167. data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
  168. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
  169. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
  170. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
  171. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
  172. data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
  173. data/vendor/faiss/faiss/index_factory.cpp +619 -397
  174. data/vendor/faiss/faiss/index_factory.h +8 -6
  175. data/vendor/faiss/faiss/index_io.h +23 -26
  176. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
  177. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
  178. data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
  179. data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
  180. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
  181. data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
  182. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
  183. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
  184. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
  185. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
  186. data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
  187. data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
  188. data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
  189. data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
  190. data/vendor/faiss/faiss/utils/Heap.h +186 -209
  191. data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
  192. data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
  193. data/vendor/faiss/faiss/utils/distances.cpp +305 -312
  194. data/vendor/faiss/faiss/utils/distances.h +170 -122
  195. data/vendor/faiss/faiss/utils/distances_simd.cpp +498 -508
  196. data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
  197. data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
  198. data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
  199. data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
  200. data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
  201. data/vendor/faiss/faiss/utils/hamming.h +62 -85
  202. data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
  203. data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
  204. data/vendor/faiss/faiss/utils/partitioning.h +26 -21
  205. data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
  206. data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
  207. data/vendor/faiss/faiss/utils/random.cpp +39 -63
  208. data/vendor/faiss/faiss/utils/random.h +13 -16
  209. data/vendor/faiss/faiss/utils/simdlib.h +4 -2
  210. data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
  211. data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
  212. data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
  213. data/vendor/faiss/faiss/utils/utils.cpp +304 -287
  214. data/vendor/faiss/faiss/utils/utils.h +54 -49
  215. metadata +29 -4
@@ -15,275 +15,254 @@
15
15
 
16
16
  namespace faiss {
17
17
 
18
-
19
18
  /**************************************************************
20
19
  * HNSW structure implementation
21
20
  **************************************************************/
22
21
 
23
- int HNSW::nb_neighbors(int layer_no) const
24
- {
25
- return cum_nneighbor_per_level[layer_no + 1] -
26
- cum_nneighbor_per_level[layer_no];
22
+ int HNSW::nb_neighbors(int layer_no) const {
23
+ return cum_nneighbor_per_level[layer_no + 1] -
24
+ cum_nneighbor_per_level[layer_no];
27
25
  }
28
26
 
29
- void HNSW::set_nb_neighbors(int level_no, int n)
30
- {
31
- FAISS_THROW_IF_NOT(levels.size() == 0);
32
- int cur_n = nb_neighbors(level_no);
33
- for (int i = level_no + 1; i < cum_nneighbor_per_level.size(); i++) {
34
- cum_nneighbor_per_level[i] += n - cur_n;
35
- }
27
+ void HNSW::set_nb_neighbors(int level_no, int n) {
28
+ FAISS_THROW_IF_NOT(levels.size() == 0);
29
+ int cur_n = nb_neighbors(level_no);
30
+ for (int i = level_no + 1; i < cum_nneighbor_per_level.size(); i++) {
31
+ cum_nneighbor_per_level[i] += n - cur_n;
32
+ }
36
33
  }
37
34
 
38
- int HNSW::cum_nb_neighbors(int layer_no) const
39
- {
40
- return cum_nneighbor_per_level[layer_no];
35
+ int HNSW::cum_nb_neighbors(int layer_no) const {
36
+ return cum_nneighbor_per_level[layer_no];
41
37
  }
42
38
 
43
- void HNSW::neighbor_range(idx_t no, int layer_no,
44
- size_t * begin, size_t * end) const
45
- {
46
- size_t o = offsets[no];
47
- *begin = o + cum_nb_neighbors(layer_no);
48
- *end = o + cum_nb_neighbors(layer_no + 1);
39
+ void HNSW::neighbor_range(idx_t no, int layer_no, size_t* begin, size_t* end)
40
+ const {
41
+ size_t o = offsets[no];
42
+ *begin = o + cum_nb_neighbors(layer_no);
43
+ *end = o + cum_nb_neighbors(layer_no + 1);
49
44
  }
50
45
 
51
-
52
-
53
46
  HNSW::HNSW(int M) : rng(12345) {
54
- set_default_probas(M, 1.0 / log(M));
55
- max_level = -1;
56
- entry_point = -1;
57
- efSearch = 16;
58
- efConstruction = 40;
59
- upper_beam = 1;
60
- offsets.push_back(0);
47
+ set_default_probas(M, 1.0 / log(M));
48
+ max_level = -1;
49
+ entry_point = -1;
50
+ efSearch = 16;
51
+ efConstruction = 40;
52
+ upper_beam = 1;
53
+ offsets.push_back(0);
61
54
  }
62
55
 
63
-
64
- int HNSW::random_level()
65
- {
66
- double f = rng.rand_float();
67
- // could be a bit faster with bissection
68
- for (int level = 0; level < assign_probas.size(); level++) {
69
- if (f < assign_probas[level]) {
70
- return level;
56
+ int HNSW::random_level() {
57
+ double f = rng.rand_float();
58
+ // could be a bit faster with bissection
59
+ for (int level = 0; level < assign_probas.size(); level++) {
60
+ if (f < assign_probas[level]) {
61
+ return level;
62
+ }
63
+ f -= assign_probas[level];
71
64
  }
72
- f -= assign_probas[level];
73
- }
74
- // happens with exponentially low probability
75
- return assign_probas.size() - 1;
65
+ // happens with exponentially low probability
66
+ return assign_probas.size() - 1;
76
67
  }
77
68
 
78
- void HNSW::set_default_probas(int M, float levelMult)
79
- {
80
- int nn = 0;
81
- cum_nneighbor_per_level.push_back (0);
82
- for (int level = 0; ;level++) {
83
- float proba = exp(-level / levelMult) * (1 - exp(-1 / levelMult));
84
- if (proba < 1e-9) break;
85
- assign_probas.push_back(proba);
86
- nn += level == 0 ? M * 2 : M;
87
- cum_nneighbor_per_level.push_back (nn);
88
- }
69
+ void HNSW::set_default_probas(int M, float levelMult) {
70
+ int nn = 0;
71
+ cum_nneighbor_per_level.push_back(0);
72
+ for (int level = 0;; level++) {
73
+ float proba = exp(-level / levelMult) * (1 - exp(-1 / levelMult));
74
+ if (proba < 1e-9)
75
+ break;
76
+ assign_probas.push_back(proba);
77
+ nn += level == 0 ? M * 2 : M;
78
+ cum_nneighbor_per_level.push_back(nn);
79
+ }
89
80
  }
90
81
 
91
- void HNSW::clear_neighbor_tables(int level)
92
- {
93
- for (int i = 0; i < levels.size(); i++) {
94
- size_t begin, end;
95
- neighbor_range(i, level, &begin, &end);
96
- for (size_t j = begin; j < end; j++) {
97
- neighbors[j] = -1;
82
+ void HNSW::clear_neighbor_tables(int level) {
83
+ for (int i = 0; i < levels.size(); i++) {
84
+ size_t begin, end;
85
+ neighbor_range(i, level, &begin, &end);
86
+ for (size_t j = begin; j < end; j++) {
87
+ neighbors[j] = -1;
88
+ }
98
89
  }
99
- }
100
90
  }
101
91
 
102
-
103
92
  void HNSW::reset() {
104
- max_level = -1;
105
- entry_point = -1;
106
- offsets.clear();
107
- offsets.push_back(0);
108
- levels.clear();
109
- neighbors.clear();
93
+ max_level = -1;
94
+ entry_point = -1;
95
+ offsets.clear();
96
+ offsets.push_back(0);
97
+ levels.clear();
98
+ neighbors.clear();
110
99
  }
111
100
 
112
-
113
-
114
- void HNSW::print_neighbor_stats(int level) const
115
- {
116
- FAISS_THROW_IF_NOT (level < cum_nneighbor_per_level.size());
117
- printf("stats on level %d, max %d neighbors per vertex:\n",
118
- level, nb_neighbors(level));
119
- size_t tot_neigh = 0, tot_common = 0, tot_reciprocal = 0, n_node = 0;
101
+ void HNSW::print_neighbor_stats(int level) const {
102
+ FAISS_THROW_IF_NOT(level < cum_nneighbor_per_level.size());
103
+ printf("stats on level %d, max %d neighbors per vertex:\n",
104
+ level,
105
+ nb_neighbors(level));
106
+ size_t tot_neigh = 0, tot_common = 0, tot_reciprocal = 0, n_node = 0;
120
107
  #pragma omp parallel for reduction(+: tot_neigh) reduction(+: tot_common) \
121
108
  reduction(+: tot_reciprocal) reduction(+: n_node)
122
- for (int i = 0; i < levels.size(); i++) {
123
- if (levels[i] > level) {
124
- n_node++;
125
- size_t begin, end;
126
- neighbor_range(i, level, &begin, &end);
127
- std::unordered_set<int> neighset;
128
- for (size_t j = begin; j < end; j++) {
129
- if (neighbors [j] < 0) break;
130
- neighset.insert(neighbors[j]);
131
- }
132
- int n_neigh = neighset.size();
133
- int n_common = 0;
134
- int n_reciprocal = 0;
135
- for (size_t j = begin; j < end; j++) {
136
- storage_idx_t i2 = neighbors[j];
137
- if (i2 < 0) break;
138
- FAISS_ASSERT(i2 != i);
139
- size_t begin2, end2;
140
- neighbor_range(i2, level, &begin2, &end2);
141
- for (size_t j2 = begin2; j2 < end2; j2++) {
142
- storage_idx_t i3 = neighbors[j2];
143
- if (i3 < 0) break;
144
- if (i3 == i) {
145
- n_reciprocal++;
146
- continue;
147
- }
148
- if (neighset.count(i3)) {
149
- neighset.erase(i3);
150
- n_common++;
151
- }
109
+ for (int i = 0; i < levels.size(); i++) {
110
+ if (levels[i] > level) {
111
+ n_node++;
112
+ size_t begin, end;
113
+ neighbor_range(i, level, &begin, &end);
114
+ std::unordered_set<int> neighset;
115
+ for (size_t j = begin; j < end; j++) {
116
+ if (neighbors[j] < 0)
117
+ break;
118
+ neighset.insert(neighbors[j]);
119
+ }
120
+ int n_neigh = neighset.size();
121
+ int n_common = 0;
122
+ int n_reciprocal = 0;
123
+ for (size_t j = begin; j < end; j++) {
124
+ storage_idx_t i2 = neighbors[j];
125
+ if (i2 < 0)
126
+ break;
127
+ FAISS_ASSERT(i2 != i);
128
+ size_t begin2, end2;
129
+ neighbor_range(i2, level, &begin2, &end2);
130
+ for (size_t j2 = begin2; j2 < end2; j2++) {
131
+ storage_idx_t i3 = neighbors[j2];
132
+ if (i3 < 0)
133
+ break;
134
+ if (i3 == i) {
135
+ n_reciprocal++;
136
+ continue;
137
+ }
138
+ if (neighset.count(i3)) {
139
+ neighset.erase(i3);
140
+ n_common++;
141
+ }
142
+ }
143
+ }
144
+ tot_neigh += n_neigh;
145
+ tot_common += n_common;
146
+ tot_reciprocal += n_reciprocal;
152
147
  }
153
- }
154
- tot_neigh += n_neigh;
155
- tot_common += n_common;
156
- tot_reciprocal += n_reciprocal;
157
148
  }
158
- }
159
- float normalizer = n_node;
160
- printf(" nb of nodes at that level %zd\n", n_node);
161
- printf(" neighbors per node: %.2f (%zd)\n",
162
- tot_neigh / normalizer, tot_neigh);
163
- printf(" nb of reciprocal neighbors: %.2f\n", tot_reciprocal / normalizer);
164
- printf(" nb of neighbors that are also neighbor-of-neighbors: %.2f (%zd)\n",
165
- tot_common / normalizer, tot_common);
166
-
167
-
168
-
149
+ float normalizer = n_node;
150
+ printf(" nb of nodes at that level %zd\n", n_node);
151
+ printf(" neighbors per node: %.2f (%zd)\n",
152
+ tot_neigh / normalizer,
153
+ tot_neigh);
154
+ printf(" nb of reciprocal neighbors: %.2f\n",
155
+ tot_reciprocal / normalizer);
156
+ printf(" nb of neighbors that are also neighbor-of-neighbors: %.2f (%zd)\n",
157
+ tot_common / normalizer,
158
+ tot_common);
169
159
  }
170
160
 
161
+ void HNSW::fill_with_random_links(size_t n) {
162
+ int max_level = prepare_level_tab(n);
163
+ RandomGenerator rng2(456);
171
164
 
172
- void HNSW::fill_with_random_links(size_t n)
173
- {
174
- int max_level = prepare_level_tab(n);
175
- RandomGenerator rng2(456);
176
-
177
- for (int level = max_level - 1; level >= 0; --level) {
178
- std::vector<int> elts;
179
- for (int i = 0; i < n; i++) {
180
- if (levels[i] > level) {
181
- elts.push_back(i);
182
- }
183
- }
184
- printf ("linking %zd elements in level %d\n",
185
- elts.size(), level);
186
-
187
- if (elts.size() == 1) continue;
165
+ for (int level = max_level - 1; level >= 0; --level) {
166
+ std::vector<int> elts;
167
+ for (int i = 0; i < n; i++) {
168
+ if (levels[i] > level) {
169
+ elts.push_back(i);
170
+ }
171
+ }
172
+ printf("linking %zd elements in level %d\n", elts.size(), level);
188
173
 
189
- for (int ii = 0; ii < elts.size(); ii++) {
190
- int i = elts[ii];
191
- size_t begin, end;
192
- neighbor_range(i, 0, &begin, &end);
193
- for (size_t j = begin; j < end; j++) {
194
- int other = 0;
195
- do {
196
- other = elts[rng2.rand_int(elts.size())];
197
- } while(other == i);
174
+ if (elts.size() == 1)
175
+ continue;
198
176
 
199
- neighbors[j] = other;
200
- }
177
+ for (int ii = 0; ii < elts.size(); ii++) {
178
+ int i = elts[ii];
179
+ size_t begin, end;
180
+ neighbor_range(i, 0, &begin, &end);
181
+ for (size_t j = begin; j < end; j++) {
182
+ int other = 0;
183
+ do {
184
+ other = elts[rng2.rand_int(elts.size())];
185
+ } while (other == i);
186
+
187
+ neighbors[j] = other;
188
+ }
189
+ }
201
190
  }
202
- }
203
191
  }
204
192
 
193
+ int HNSW::prepare_level_tab(size_t n, bool preset_levels) {
194
+ size_t n0 = offsets.size() - 1;
205
195
 
206
- int HNSW::prepare_level_tab(size_t n, bool preset_levels)
207
- {
208
- size_t n0 = offsets.size() - 1;
196
+ if (preset_levels) {
197
+ FAISS_ASSERT(n0 + n == levels.size());
198
+ } else {
199
+ FAISS_ASSERT(n0 == levels.size());
200
+ for (int i = 0; i < n; i++) {
201
+ int pt_level = random_level();
202
+ levels.push_back(pt_level + 1);
203
+ }
204
+ }
209
205
 
210
- if (preset_levels) {
211
- FAISS_ASSERT (n0 + n == levels.size());
212
- } else {
213
- FAISS_ASSERT (n0 == levels.size());
206
+ int max_level = 0;
214
207
  for (int i = 0; i < n; i++) {
215
- int pt_level = random_level();
216
- levels.push_back(pt_level + 1);
208
+ int pt_level = levels[i + n0] - 1;
209
+ if (pt_level > max_level)
210
+ max_level = pt_level;
211
+ offsets.push_back(offsets.back() + cum_nb_neighbors(pt_level + 1));
212
+ neighbors.resize(offsets.back(), -1);
217
213
  }
218
- }
219
214
 
220
- int max_level = 0;
221
- for (int i = 0; i < n; i++) {
222
- int pt_level = levels[i + n0] - 1;
223
- if (pt_level > max_level) max_level = pt_level;
224
- offsets.push_back(offsets.back() +
225
- cum_nb_neighbors(pt_level + 1));
226
- neighbors.resize(offsets.back(), -1);
227
- }
228
-
229
- return max_level;
215
+ return max_level;
230
216
  }
231
217
 
232
-
233
218
  /** Enumerate vertices from farthest to nearest from query, keep a
234
219
  * neighbor only if there is no previous neighbor that is closer to
235
220
  * that vertex than the query.
236
221
  */
237
222
  void HNSW::shrink_neighbor_list(
238
- DistanceComputer& qdis,
239
- std::priority_queue<NodeDistFarther>& input,
240
- std::vector<NodeDistFarther>& output,
241
- int max_size)
242
- {
243
- while (input.size() > 0) {
244
- NodeDistFarther v1 = input.top();
245
- input.pop();
246
- float dist_v1_q = v1.d;
247
-
248
- bool good = true;
249
- for (NodeDistFarther v2 : output) {
250
- float dist_v1_v2 = qdis.symmetric_dis(v2.id, v1.id);
251
-
252
- if (dist_v1_v2 < dist_v1_q) {
253
- good = false;
254
- break;
255
- }
256
- }
257
-
258
- if (good) {
259
- output.push_back(v1);
260
- if (output.size() >= max_size) {
261
- return;
262
- }
223
+ DistanceComputer& qdis,
224
+ std::priority_queue<NodeDistFarther>& input,
225
+ std::vector<NodeDistFarther>& output,
226
+ int max_size) {
227
+ while (input.size() > 0) {
228
+ NodeDistFarther v1 = input.top();
229
+ input.pop();
230
+ float dist_v1_q = v1.d;
231
+
232
+ bool good = true;
233
+ for (NodeDistFarther v2 : output) {
234
+ float dist_v1_v2 = qdis.symmetric_dis(v2.id, v1.id);
235
+
236
+ if (dist_v1_v2 < dist_v1_q) {
237
+ good = false;
238
+ break;
239
+ }
240
+ }
241
+
242
+ if (good) {
243
+ output.push_back(v1);
244
+ if (output.size() >= max_size) {
245
+ return;
246
+ }
247
+ }
263
248
  }
264
- }
265
249
  }
266
250
 
267
-
268
251
  namespace {
269
252
 
270
-
271
253
  using storage_idx_t = HNSW::storage_idx_t;
272
254
  using NodeDistCloser = HNSW::NodeDistCloser;
273
255
  using NodeDistFarther = HNSW::NodeDistFarther;
274
256
 
275
-
276
257
  /**************************************************************
277
258
  * Addition subroutines
278
259
  **************************************************************/
279
260
 
280
-
281
261
  /// remove neighbors from the list to make it smaller than max_size
282
262
  void shrink_neighbor_list(
283
- DistanceComputer& qdis,
284
- std::priority_queue<NodeDistCloser>& resultSet1,
285
- int max_size)
286
- {
263
+ DistanceComputer& qdis,
264
+ std::priority_queue<NodeDistCloser>& resultSet1,
265
+ int max_size) {
287
266
  if (resultSet1.size() < max_size) {
288
267
  return;
289
268
  }
@@ -300,516 +279,526 @@ void shrink_neighbor_list(
300
279
  for (NodeDistFarther curen2 : returnlist) {
301
280
  resultSet1.emplace(curen2.d, curen2.id);
302
281
  }
303
-
304
282
  }
305
283
 
306
-
307
284
  /// add a link between two elements, possibly shrinking the list
308
285
  /// of links to make room for it.
309
- void add_link(HNSW& hnsw,
310
- DistanceComputer& qdis,
311
- storage_idx_t src, storage_idx_t dest,
312
- int level)
313
- {
314
- size_t begin, end;
315
- hnsw.neighbor_range(src, level, &begin, &end);
316
- if (hnsw.neighbors[end - 1] == -1) {
317
- // there is enough room, find a slot to add it
318
- size_t i = end;
319
- while(i > begin) {
320
- if (hnsw.neighbors[i - 1] != -1) break;
321
- i--;
322
- }
323
- hnsw.neighbors[i] = dest;
324
- return;
325
- }
326
-
327
- // otherwise we let them fight out which to keep
328
-
329
- // copy to resultSet...
330
- std::priority_queue<NodeDistCloser> resultSet;
331
- resultSet.emplace(qdis.symmetric_dis(src, dest), dest);
332
- for (size_t i = begin; i < end; i++) { // HERE WAS THE BUG
333
- storage_idx_t neigh = hnsw.neighbors[i];
334
- resultSet.emplace(qdis.symmetric_dis(src, neigh), neigh);
335
- }
336
-
337
- shrink_neighbor_list(qdis, resultSet, end - begin);
338
-
339
- // ...and back
340
- size_t i = begin;
341
- while (resultSet.size()) {
342
- hnsw.neighbors[i++] = resultSet.top().id;
343
- resultSet.pop();
344
- }
345
- // they may have shrunk more than just by 1 element
346
- while(i < end) {
347
- hnsw.neighbors[i++] = -1;
348
- }
286
+ void add_link(
287
+ HNSW& hnsw,
288
+ DistanceComputer& qdis,
289
+ storage_idx_t src,
290
+ storage_idx_t dest,
291
+ int level) {
292
+ size_t begin, end;
293
+ hnsw.neighbor_range(src, level, &begin, &end);
294
+ if (hnsw.neighbors[end - 1] == -1) {
295
+ // there is enough room, find a slot to add it
296
+ size_t i = end;
297
+ while (i > begin) {
298
+ if (hnsw.neighbors[i - 1] != -1)
299
+ break;
300
+ i--;
301
+ }
302
+ hnsw.neighbors[i] = dest;
303
+ return;
304
+ }
305
+
306
+ // otherwise we let them fight out which to keep
307
+
308
+ // copy to resultSet...
309
+ std::priority_queue<NodeDistCloser> resultSet;
310
+ resultSet.emplace(qdis.symmetric_dis(src, dest), dest);
311
+ for (size_t i = begin; i < end; i++) { // HERE WAS THE BUG
312
+ storage_idx_t neigh = hnsw.neighbors[i];
313
+ resultSet.emplace(qdis.symmetric_dis(src, neigh), neigh);
314
+ }
315
+
316
+ shrink_neighbor_list(qdis, resultSet, end - begin);
317
+
318
+ // ...and back
319
+ size_t i = begin;
320
+ while (resultSet.size()) {
321
+ hnsw.neighbors[i++] = resultSet.top().id;
322
+ resultSet.pop();
323
+ }
324
+ // they may have shrunk more than just by 1 element
325
+ while (i < end) {
326
+ hnsw.neighbors[i++] = -1;
327
+ }
349
328
  }
350
329
 
351
330
  /// search neighbors on a single level, starting from an entry point
352
331
  void search_neighbors_to_add(
353
- HNSW& hnsw,
354
- DistanceComputer& qdis,
355
- std::priority_queue<NodeDistCloser>& results,
356
- int entry_point,
357
- float d_entry_point,
358
- int level,
359
- VisitedTable &vt)
360
- {
361
- // top is nearest candidate
362
- std::priority_queue<NodeDistFarther> candidates;
363
-
364
- NodeDistFarther ev(d_entry_point, entry_point);
365
- candidates.push(ev);
366
- results.emplace(d_entry_point, entry_point);
367
- vt.set(entry_point);
368
-
369
- while (!candidates.empty()) {
370
- // get nearest
371
- const NodeDistFarther &currEv = candidates.top();
372
-
373
- if (currEv.d > results.top().d) {
374
- break;
375
- }
376
- int currNode = currEv.id;
377
- candidates.pop();
378
-
379
- // loop over neighbors
380
- size_t begin, end;
381
- hnsw.neighbor_range(currNode, level, &begin, &end);
382
- for(size_t i = begin; i < end; i++) {
383
- storage_idx_t nodeId = hnsw.neighbors[i];
384
- if (nodeId < 0) break;
385
- if (vt.get(nodeId)) continue;
386
- vt.set(nodeId);
387
-
388
- float dis = qdis(nodeId);
389
- NodeDistFarther evE1(dis, nodeId);
390
-
391
- if (results.size() < hnsw.efConstruction ||
392
- results.top().d > dis) {
393
-
394
- results.emplace(dis, nodeId);
395
- candidates.emplace(dis, nodeId);
396
- if (results.size() > hnsw.efConstruction) {
397
- results.pop();
332
+ HNSW& hnsw,
333
+ DistanceComputer& qdis,
334
+ std::priority_queue<NodeDistCloser>& results,
335
+ int entry_point,
336
+ float d_entry_point,
337
+ int level,
338
+ VisitedTable& vt) {
339
+ // top is nearest candidate
340
+ std::priority_queue<NodeDistFarther> candidates;
341
+
342
+ NodeDistFarther ev(d_entry_point, entry_point);
343
+ candidates.push(ev);
344
+ results.emplace(d_entry_point, entry_point);
345
+ vt.set(entry_point);
346
+
347
+ while (!candidates.empty()) {
348
+ // get nearest
349
+ const NodeDistFarther& currEv = candidates.top();
350
+
351
+ if (currEv.d > results.top().d) {
352
+ break;
353
+ }
354
+ int currNode = currEv.id;
355
+ candidates.pop();
356
+
357
+ // loop over neighbors
358
+ size_t begin, end;
359
+ hnsw.neighbor_range(currNode, level, &begin, &end);
360
+ for (size_t i = begin; i < end; i++) {
361
+ storage_idx_t nodeId = hnsw.neighbors[i];
362
+ if (nodeId < 0)
363
+ break;
364
+ if (vt.get(nodeId))
365
+ continue;
366
+ vt.set(nodeId);
367
+
368
+ float dis = qdis(nodeId);
369
+ NodeDistFarther evE1(dis, nodeId);
370
+
371
+ if (results.size() < hnsw.efConstruction || results.top().d > dis) {
372
+ results.emplace(dis, nodeId);
373
+ candidates.emplace(dis, nodeId);
374
+ if (results.size() > hnsw.efConstruction) {
375
+ results.pop();
376
+ }
377
+ }
398
378
  }
399
- }
400
379
  }
401
- }
402
- vt.advance();
380
+ vt.advance();
403
381
  }
404
382
 
405
-
406
383
  /**************************************************************
407
384
  * Searching subroutines
408
385
  **************************************************************/
409
386
 
410
387
  /// greedily update a nearest vector at a given level
411
- void greedy_update_nearest(const HNSW& hnsw,
412
- DistanceComputer& qdis,
413
- int level,
414
- storage_idx_t& nearest,
415
- float& d_nearest)
416
- {
417
- for(;;) {
418
- storage_idx_t prev_nearest = nearest;
419
-
420
- size_t begin, end;
421
- hnsw.neighbor_range(nearest, level, &begin, &end);
422
- for(size_t i = begin; i < end; i++) {
423
- storage_idx_t v = hnsw.neighbors[i];
424
- if (v < 0) break;
425
- float dis = qdis(v);
426
- if (dis < d_nearest) {
427
- nearest = v;
428
- d_nearest = dis;
429
- }
430
- }
431
- if (nearest == prev_nearest) {
432
- return;
433
- }
434
- }
388
+ void greedy_update_nearest(
389
+ const HNSW& hnsw,
390
+ DistanceComputer& qdis,
391
+ int level,
392
+ storage_idx_t& nearest,
393
+ float& d_nearest) {
394
+ for (;;) {
395
+ storage_idx_t prev_nearest = nearest;
396
+
397
+ size_t begin, end;
398
+ hnsw.neighbor_range(nearest, level, &begin, &end);
399
+ for (size_t i = begin; i < end; i++) {
400
+ storage_idx_t v = hnsw.neighbors[i];
401
+ if (v < 0)
402
+ break;
403
+ float dis = qdis(v);
404
+ if (dis < d_nearest) {
405
+ nearest = v;
406
+ d_nearest = dis;
407
+ }
408
+ }
409
+ if (nearest == prev_nearest) {
410
+ return;
411
+ }
412
+ }
435
413
  }
436
414
 
437
-
438
- } // namespace
439
-
415
+ } // namespace
440
416
 
441
417
  /// Finds neighbors and builds links with them, starting from an entry
442
418
  /// point. The own neighbor list is assumed to be locked.
443
- void HNSW::add_links_starting_from(DistanceComputer& ptdis,
444
- storage_idx_t pt_id,
445
- storage_idx_t nearest,
446
- float d_nearest,
447
- int level,
448
- omp_lock_t *locks,
449
- VisitedTable &vt)
450
- {
451
- std::priority_queue<NodeDistCloser> link_targets;
452
-
453
- search_neighbors_to_add(*this, ptdis, link_targets, nearest, d_nearest,
454
- level, vt);
455
-
456
- // but we can afford only this many neighbors
457
- int M = nb_neighbors(level);
458
-
459
- ::faiss::shrink_neighbor_list(ptdis, link_targets, M);
460
-
461
- while (!link_targets.empty()) {
462
- int other_id = link_targets.top().id;
463
-
464
- omp_set_lock(&locks[other_id]);
465
- add_link(*this, ptdis, other_id, pt_id, level);
466
- omp_unset_lock(&locks[other_id]);
467
-
468
- add_link(*this, ptdis, pt_id, other_id, level);
419
+ void HNSW::add_links_starting_from(
420
+ DistanceComputer& ptdis,
421
+ storage_idx_t pt_id,
422
+ storage_idx_t nearest,
423
+ float d_nearest,
424
+ int level,
425
+ omp_lock_t* locks,
426
+ VisitedTable& vt) {
427
+ std::priority_queue<NodeDistCloser> link_targets;
428
+
429
+ search_neighbors_to_add(
430
+ *this, ptdis, link_targets, nearest, d_nearest, level, vt);
431
+
432
+ // but we can afford only this many neighbors
433
+ int M = nb_neighbors(level);
434
+
435
+ ::faiss::shrink_neighbor_list(ptdis, link_targets, M);
436
+
437
+ std::vector<storage_idx_t> neighbors;
438
+ neighbors.reserve(link_targets.size());
439
+ while (!link_targets.empty()) {
440
+ storage_idx_t other_id = link_targets.top().id;
441
+ add_link(*this, ptdis, pt_id, other_id, level);
442
+ neighbors.push_back(other_id);
443
+ link_targets.pop();
444
+ }
469
445
 
470
- link_targets.pop();
471
- }
446
+ omp_unset_lock(&locks[pt_id]);
447
+ for (storage_idx_t other_id : neighbors) {
448
+ omp_set_lock(&locks[other_id]);
449
+ add_link(*this, ptdis, other_id, pt_id, level);
450
+ omp_unset_lock(&locks[other_id]);
451
+ }
452
+ omp_set_lock(&locks[pt_id]);
472
453
  }
473
454
 
474
-
475
455
  /**************************************************************
476
456
  * Building, parallel
477
457
  **************************************************************/
478
458
 
479
- void HNSW::add_with_locks(DistanceComputer& ptdis, int pt_level, int pt_id,
480
- std::vector<omp_lock_t>& locks,
481
- VisitedTable& vt)
482
- {
483
- // greedy search on upper levels
459
+ void HNSW::add_with_locks(
460
+ DistanceComputer& ptdis,
461
+ int pt_level,
462
+ int pt_id,
463
+ std::vector<omp_lock_t>& locks,
464
+ VisitedTable& vt) {
465
+ // greedy search on upper levels
484
466
 
485
- storage_idx_t nearest;
467
+ storage_idx_t nearest;
486
468
  #pragma omp critical
487
- {
488
- nearest = entry_point;
469
+ {
470
+ nearest = entry_point;
489
471
 
490
- if (nearest == -1) {
491
- max_level = pt_level;
492
- entry_point = pt_id;
472
+ if (nearest == -1) {
473
+ max_level = pt_level;
474
+ entry_point = pt_id;
475
+ }
493
476
  }
494
- }
495
477
 
496
- if (nearest < 0) {
497
- return;
498
- }
478
+ if (nearest < 0) {
479
+ return;
480
+ }
499
481
 
500
- omp_set_lock(&locks[pt_id]);
482
+ omp_set_lock(&locks[pt_id]);
501
483
 
502
- int level = max_level; // level at which we start adding neighbors
503
- float d_nearest = ptdis(nearest);
484
+ int level = max_level; // level at which we start adding neighbors
485
+ float d_nearest = ptdis(nearest);
504
486
 
505
- for(; level > pt_level; level--) {
506
- greedy_update_nearest(*this, ptdis, level, nearest, d_nearest);
507
- }
487
+ for (; level > pt_level; level--) {
488
+ greedy_update_nearest(*this, ptdis, level, nearest, d_nearest);
489
+ }
508
490
 
509
- for(; level >= 0; level--) {
510
- add_links_starting_from(ptdis, pt_id, nearest, d_nearest,
511
- level, locks.data(), vt);
512
- }
491
+ for (; level >= 0; level--) {
492
+ add_links_starting_from(
493
+ ptdis, pt_id, nearest, d_nearest, level, locks.data(), vt);
494
+ }
513
495
 
514
- omp_unset_lock(&locks[pt_id]);
496
+ omp_unset_lock(&locks[pt_id]);
515
497
 
516
- if (pt_level > max_level) {
517
- max_level = pt_level;
518
- entry_point = pt_id;
519
- }
498
+ if (pt_level > max_level) {
499
+ max_level = pt_level;
500
+ entry_point = pt_id;
501
+ }
520
502
  }
521
503
 
522
-
523
504
  /** Do a BFS on the candidates list */
524
505
 
525
506
  int HNSW::search_from_candidates(
526
- DistanceComputer& qdis, int k,
527
- idx_t *I, float *D,
528
- MinimaxHeap& candidates,
529
- VisitedTable& vt,
530
- HNSWStats& stats,
531
- int level, int nres_in) const
532
- {
533
- int nres = nres_in;
534
- int ndis = 0;
535
- for (int i = 0; i < candidates.size(); i++) {
536
- idx_t v1 = candidates.ids[i];
537
- float d = candidates.dis[i];
538
- FAISS_ASSERT(v1 >= 0);
539
- if (nres < k) {
540
- faiss::maxheap_push(++nres, D, I, d, v1);
541
- } else if (d < D[0]) {
542
- faiss::maxheap_replace_top(nres, D, I, d, v1);
543
- }
544
- vt.set(v1);
545
- }
546
-
547
- bool do_dis_check = check_relative_distance;
548
- int nstep = 0;
549
-
550
- while (candidates.size() > 0) {
551
- float d0 = 0;
552
- int v0 = candidates.pop_min(&d0);
553
-
554
- if (do_dis_check) {
555
- // tricky stopping condition: there are more that ef
556
- // distances that are processed already that are smaller
557
- // than d0
558
-
559
- int n_dis_below = candidates.count_below(d0);
560
- if(n_dis_below >= efSearch) {
561
- break;
562
- }
507
+ DistanceComputer& qdis,
508
+ int k,
509
+ idx_t* I,
510
+ float* D,
511
+ MinimaxHeap& candidates,
512
+ VisitedTable& vt,
513
+ HNSWStats& stats,
514
+ int level,
515
+ int nres_in) const {
516
+ int nres = nres_in;
517
+ int ndis = 0;
518
+ for (int i = 0; i < candidates.size(); i++) {
519
+ idx_t v1 = candidates.ids[i];
520
+ float d = candidates.dis[i];
521
+ FAISS_ASSERT(v1 >= 0);
522
+ if (nres < k) {
523
+ faiss::maxheap_push(++nres, D, I, d, v1);
524
+ } else if (d < D[0]) {
525
+ faiss::maxheap_replace_top(nres, D, I, d, v1);
526
+ }
527
+ vt.set(v1);
563
528
  }
564
529
 
565
- size_t begin, end;
566
- neighbor_range(v0, level, &begin, &end);
567
-
568
- for (size_t j = begin; j < end; j++) {
569
- int v1 = neighbors[j];
570
- if (v1 < 0) break;
571
- if (vt.get(v1)) {
572
- continue;
573
- }
574
- vt.set(v1);
575
- ndis++;
576
- float d = qdis(v1);
577
- if (nres < k) {
578
- faiss::maxheap_push(++nres, D, I, d, v1);
579
- } else if (d < D[0]) {
580
- faiss::maxheap_replace_top(nres, D, I, d, v1);
581
- }
582
- candidates.push(v1, d);
583
- }
584
-
585
- nstep++;
586
- if (!do_dis_check && nstep > efSearch) {
587
- break;
588
- }
589
- }
590
-
591
- if (level == 0) {
592
- stats.n1 ++;
593
- if (candidates.size() == 0) {
594
- stats.n2 ++;
595
- }
596
- stats.n3 += ndis;
597
- }
530
+ bool do_dis_check = check_relative_distance;
531
+ int nstep = 0;
598
532
 
599
- return nres;
600
- }
533
+ while (candidates.size() > 0) {
534
+ float d0 = 0;
535
+ int v0 = candidates.pop_min(&d0);
601
536
 
537
+ if (do_dis_check) {
538
+ // tricky stopping condition: there are more that ef
539
+ // distances that are processed already that are smaller
540
+ // than d0
602
541
 
603
- /**************************************************************
604
- * Searching
605
- **************************************************************/
606
-
607
- std::priority_queue<HNSW::Node> HNSW::search_from_candidate_unbounded(
608
- const Node& node,
609
- DistanceComputer& qdis,
610
- int ef,
611
- VisitedTable *vt,
612
- HNSWStats& stats) const
613
- {
614
- int ndis = 0;
615
- std::priority_queue<Node> top_candidates;
616
- std::priority_queue<Node, std::vector<Node>, std::greater<Node>> candidates;
617
-
618
- top_candidates.push(node);
619
- candidates.push(node);
620
-
621
- vt->set(node.second);
542
+ int n_dis_below = candidates.count_below(d0);
543
+ if (n_dis_below >= efSearch) {
544
+ break;
545
+ }
546
+ }
622
547
 
623
- while (!candidates.empty()) {
624
- float d0;
625
- storage_idx_t v0;
626
- std::tie(d0, v0) = candidates.top();
548
+ size_t begin, end;
549
+ neighbor_range(v0, level, &begin, &end);
550
+
551
+ for (size_t j = begin; j < end; j++) {
552
+ int v1 = neighbors[j];
553
+ if (v1 < 0)
554
+ break;
555
+ if (vt.get(v1)) {
556
+ continue;
557
+ }
558
+ vt.set(v1);
559
+ ndis++;
560
+ float d = qdis(v1);
561
+ if (nres < k) {
562
+ faiss::maxheap_push(++nres, D, I, d, v1);
563
+ } else if (d < D[0]) {
564
+ faiss::maxheap_replace_top(nres, D, I, d, v1);
565
+ }
566
+ candidates.push(v1, d);
567
+ }
627
568
 
628
- if (d0 > top_candidates.top().first) {
629
- break;
569
+ nstep++;
570
+ if (!do_dis_check && nstep > efSearch) {
571
+ break;
572
+ }
630
573
  }
631
574
 
632
- candidates.pop();
633
-
634
- size_t begin, end;
635
- neighbor_range(v0, 0, &begin, &end);
636
-
637
- for (size_t j = begin; j < end; ++j) {
638
- int v1 = neighbors[j];
639
-
640
- if (v1 < 0) {
641
- break;
642
- }
643
- if (vt->get(v1)) {
644
- continue;
645
- }
646
-
647
- vt->set(v1);
648
-
649
- float d1 = qdis(v1);
650
- ++ndis;
651
-
652
- if (top_candidates.top().first > d1 || top_candidates.size() < ef) {
653
- candidates.emplace(d1, v1);
654
- top_candidates.emplace(d1, v1);
655
-
656
- if (top_candidates.size() > ef) {
657
- top_candidates.pop();
575
+ if (level == 0) {
576
+ stats.n1++;
577
+ if (candidates.size() == 0) {
578
+ stats.n2++;
658
579
  }
659
- }
580
+ stats.n3 += ndis;
660
581
  }
661
- }
662
-
663
- ++stats.n1;
664
- if (candidates.size() == 0) {
665
- ++stats.n2;
666
- }
667
- stats.n3 += ndis;
668
582
 
669
- return top_candidates;
583
+ return nres;
670
584
  }
671
585
 
672
- HNSWStats HNSW::search(DistanceComputer& qdis, int k,
673
- idx_t *I, float *D,
674
- VisitedTable& vt) const
675
- {
676
- HNSWStats stats;
677
-
678
- if (upper_beam == 1) {
679
-
680
- // greedy search on upper levels
681
- storage_idx_t nearest = entry_point;
682
- float d_nearest = qdis(nearest);
586
+ /**************************************************************
587
+ * Searching
588
+ **************************************************************/
683
589
 
684
- for(int level = max_level; level >= 1; level--) {
685
- greedy_update_nearest(*this, qdis, level, nearest, d_nearest);
686
- }
590
+ std::priority_queue<HNSW::Node> HNSW::search_from_candidate_unbounded(
591
+ const Node& node,
592
+ DistanceComputer& qdis,
593
+ int ef,
594
+ VisitedTable* vt,
595
+ HNSWStats& stats) const {
596
+ int ndis = 0;
597
+ std::priority_queue<Node> top_candidates;
598
+ std::priority_queue<Node, std::vector<Node>, std::greater<Node>> candidates;
599
+
600
+ top_candidates.push(node);
601
+ candidates.push(node);
602
+
603
+ vt->set(node.second);
604
+
605
+ while (!candidates.empty()) {
606
+ float d0;
607
+ storage_idx_t v0;
608
+ std::tie(d0, v0) = candidates.top();
609
+
610
+ if (d0 > top_candidates.top().first) {
611
+ break;
612
+ }
687
613
 
688
- int ef = std::max(efSearch, k);
689
- if (search_bounded_queue) {
690
- MinimaxHeap candidates(ef);
614
+ candidates.pop();
691
615
 
692
- candidates.push(nearest, d_nearest);
616
+ size_t begin, end;
617
+ neighbor_range(v0, 0, &begin, &end);
693
618
 
694
- search_from_candidates(qdis, k, I, D, candidates, vt, stats, 0);
695
- } else {
696
- std::priority_queue<Node> top_candidates =
697
- search_from_candidate_unbounded(Node(d_nearest, nearest),
698
- qdis, ef, &vt, stats);
619
+ for (size_t j = begin; j < end; ++j) {
620
+ int v1 = neighbors[j];
699
621
 
700
- while (top_candidates.size() > k) {
701
- top_candidates.pop();
702
- }
622
+ if (v1 < 0) {
623
+ break;
624
+ }
625
+ if (vt->get(v1)) {
626
+ continue;
627
+ }
703
628
 
704
- int nres = 0;
705
- while (!top_candidates.empty()) {
706
- float d;
707
- storage_idx_t label;
708
- std::tie(d, label) = top_candidates.top();
709
- faiss::maxheap_push(++nres, D, I, d, label);
710
- top_candidates.pop();
711
- }
712
- }
629
+ vt->set(v1);
713
630
 
714
- vt.advance();
631
+ float d1 = qdis(v1);
632
+ ++ndis;
715
633
 
716
- } else {
717
- int candidates_size = upper_beam;
718
- MinimaxHeap candidates(candidates_size);
634
+ if (top_candidates.top().first > d1 || top_candidates.size() < ef) {
635
+ candidates.emplace(d1, v1);
636
+ top_candidates.emplace(d1, v1);
719
637
 
720
- std::vector<idx_t> I_to_next(candidates_size);
721
- std::vector<float> D_to_next(candidates_size);
638
+ if (top_candidates.size() > ef) {
639
+ top_candidates.pop();
640
+ }
641
+ }
642
+ }
643
+ }
722
644
 
723
- int nres = 1;
724
- I_to_next[0] = entry_point;
725
- D_to_next[0] = qdis(entry_point);
645
+ ++stats.n1;
646
+ if (candidates.size() == 0) {
647
+ ++stats.n2;
648
+ }
649
+ stats.n3 += ndis;
726
650
 
727
- for(int level = max_level; level >= 0; level--) {
651
+ return top_candidates;
652
+ }
728
653
 
729
- // copy I, D -> candidates
654
+ HNSWStats HNSW::search(
655
+ DistanceComputer& qdis,
656
+ int k,
657
+ idx_t* I,
658
+ float* D,
659
+ VisitedTable& vt) const {
660
+ HNSWStats stats;
661
+
662
+ if (upper_beam == 1) {
663
+ // greedy search on upper levels
664
+ storage_idx_t nearest = entry_point;
665
+ float d_nearest = qdis(nearest);
666
+
667
+ for (int level = max_level; level >= 1; level--) {
668
+ greedy_update_nearest(*this, qdis, level, nearest, d_nearest);
669
+ }
730
670
 
731
- candidates.clear();
671
+ int ef = std::max(efSearch, k);
672
+ if (search_bounded_queue) {
673
+ MinimaxHeap candidates(ef);
674
+
675
+ candidates.push(nearest, d_nearest);
676
+
677
+ search_from_candidates(qdis, k, I, D, candidates, vt, stats, 0);
678
+ } else {
679
+ std::priority_queue<Node> top_candidates =
680
+ search_from_candidate_unbounded(
681
+ Node(d_nearest, nearest), qdis, ef, &vt, stats);
682
+
683
+ while (top_candidates.size() > k) {
684
+ top_candidates.pop();
685
+ }
686
+
687
+ int nres = 0;
688
+ while (!top_candidates.empty()) {
689
+ float d;
690
+ storage_idx_t label;
691
+ std::tie(d, label) = top_candidates.top();
692
+ faiss::maxheap_push(++nres, D, I, d, label);
693
+ top_candidates.pop();
694
+ }
695
+ }
732
696
 
733
- for (int i = 0; i < nres; i++) {
734
- candidates.push(I_to_next[i], D_to_next[i]);
735
- }
697
+ vt.advance();
736
698
 
737
- if (level == 0) {
738
- nres = search_from_candidates(qdis, k, I, D, candidates, vt, stats, 0);
739
- } else {
740
- nres = search_from_candidates(
741
- qdis, candidates_size,
742
- I_to_next.data(), D_to_next.data(),
743
- candidates, vt, stats, level
744
- );
745
- }
746
- vt.advance();
699
+ } else {
700
+ int candidates_size = upper_beam;
701
+ MinimaxHeap candidates(candidates_size);
702
+
703
+ std::vector<idx_t> I_to_next(candidates_size);
704
+ std::vector<float> D_to_next(candidates_size);
705
+
706
+ int nres = 1;
707
+ I_to_next[0] = entry_point;
708
+ D_to_next[0] = qdis(entry_point);
709
+
710
+ for (int level = max_level; level >= 0; level--) {
711
+ // copy I, D -> candidates
712
+
713
+ candidates.clear();
714
+
715
+ for (int i = 0; i < nres; i++) {
716
+ candidates.push(I_to_next[i], D_to_next[i]);
717
+ }
718
+
719
+ if (level == 0) {
720
+ nres = search_from_candidates(
721
+ qdis, k, I, D, candidates, vt, stats, 0);
722
+ } else {
723
+ nres = search_from_candidates(
724
+ qdis,
725
+ candidates_size,
726
+ I_to_next.data(),
727
+ D_to_next.data(),
728
+ candidates,
729
+ vt,
730
+ stats,
731
+ level);
732
+ }
733
+ vt.advance();
734
+ }
747
735
  }
748
- }
749
736
 
750
- return stats;
737
+ return stats;
751
738
  }
752
739
 
753
-
754
740
  void HNSW::MinimaxHeap::push(storage_idx_t i, float v) {
755
- if (k == n) {
756
- if (v >= dis[0]) return;
757
- faiss::heap_pop<HC> (k--, dis.data(), ids.data());
758
- --nvalid;
759
- }
760
- faiss::heap_push<HC> (++k, dis.data(), ids.data(), v, i);
761
- ++nvalid;
741
+ if (k == n) {
742
+ if (v >= dis[0])
743
+ return;
744
+ faiss::heap_pop<HC>(k--, dis.data(), ids.data());
745
+ --nvalid;
746
+ }
747
+ faiss::heap_push<HC>(++k, dis.data(), ids.data(), v, i);
748
+ ++nvalid;
762
749
  }
763
750
 
764
751
  float HNSW::MinimaxHeap::max() const {
765
- return dis[0];
752
+ return dis[0];
766
753
  }
767
754
 
768
755
  int HNSW::MinimaxHeap::size() const {
769
- return nvalid;
756
+ return nvalid;
770
757
  }
771
758
 
772
759
  void HNSW::MinimaxHeap::clear() {
773
- nvalid = k = 0;
760
+ nvalid = k = 0;
774
761
  }
775
762
 
776
- int HNSW::MinimaxHeap::pop_min(float *vmin_out) {
777
- assert(k > 0);
778
- // returns min. This is an O(n) operation
779
- int i = k - 1;
780
- while (i >= 0) {
781
- if (ids[i] != -1) break;
782
- i--;
783
- }
784
- if (i == -1) return -1;
785
- int imin = i;
786
- float vmin = dis[i];
787
- i--;
788
- while(i >= 0) {
789
- if (ids[i] != -1 && dis[i] < vmin) {
790
- vmin = dis[i];
791
- imin = i;
763
+ int HNSW::MinimaxHeap::pop_min(float* vmin_out) {
764
+ assert(k > 0);
765
+ // returns min. This is an O(n) operation
766
+ int i = k - 1;
767
+ while (i >= 0) {
768
+ if (ids[i] != -1)
769
+ break;
770
+ i--;
792
771
  }
772
+ if (i == -1)
773
+ return -1;
774
+ int imin = i;
775
+ float vmin = dis[i];
793
776
  i--;
794
- }
795
- if (vmin_out) *vmin_out = vmin;
796
- int ret = ids[imin];
797
- ids[imin] = -1;
798
- --nvalid;
777
+ while (i >= 0) {
778
+ if (ids[i] != -1 && dis[i] < vmin) {
779
+ vmin = dis[i];
780
+ imin = i;
781
+ }
782
+ i--;
783
+ }
784
+ if (vmin_out)
785
+ *vmin_out = vmin;
786
+ int ret = ids[imin];
787
+ ids[imin] = -1;
788
+ --nvalid;
799
789
 
800
- return ret;
790
+ return ret;
801
791
  }
802
792
 
803
793
  int HNSW::MinimaxHeap::count_below(float thresh) {
804
- int n_below = 0;
805
- for(int i = 0; i < k; i++) {
806
- if (dis[i] < thresh) {
807
- n_below++;
794
+ int n_below = 0;
795
+ for (int i = 0; i < k; i++) {
796
+ if (dis[i] < thresh) {
797
+ n_below++;
798
+ }
808
799
  }
809
- }
810
800
 
811
- return n_below;
801
+ return n_below;
812
802
  }
813
803
 
814
-
815
- } // namespace faiss
804
+ } // namespace faiss