faiss 0.1.0 → 0.1.1

Sign up to get free protection for your applications and to get access to all the features.
Files changed (226) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +5 -0
  3. data/README.md +103 -3
  4. data/ext/faiss/ext.cpp +99 -32
  5. data/ext/faiss/extconf.rb +12 -2
  6. data/lib/faiss/ext.bundle +0 -0
  7. data/lib/faiss/index.rb +3 -3
  8. data/lib/faiss/index_binary.rb +3 -3
  9. data/lib/faiss/kmeans.rb +1 -1
  10. data/lib/faiss/pca_matrix.rb +2 -2
  11. data/lib/faiss/product_quantizer.rb +3 -3
  12. data/lib/faiss/version.rb +1 -1
  13. data/vendor/faiss/AutoTune.cpp +719 -0
  14. data/vendor/faiss/AutoTune.h +212 -0
  15. data/vendor/faiss/Clustering.cpp +261 -0
  16. data/vendor/faiss/Clustering.h +101 -0
  17. data/vendor/faiss/IVFlib.cpp +339 -0
  18. data/vendor/faiss/IVFlib.h +132 -0
  19. data/vendor/faiss/Index.cpp +171 -0
  20. data/vendor/faiss/Index.h +261 -0
  21. data/vendor/faiss/Index2Layer.cpp +437 -0
  22. data/vendor/faiss/Index2Layer.h +85 -0
  23. data/vendor/faiss/IndexBinary.cpp +77 -0
  24. data/vendor/faiss/IndexBinary.h +163 -0
  25. data/vendor/faiss/IndexBinaryFlat.cpp +83 -0
  26. data/vendor/faiss/IndexBinaryFlat.h +54 -0
  27. data/vendor/faiss/IndexBinaryFromFloat.cpp +78 -0
  28. data/vendor/faiss/IndexBinaryFromFloat.h +52 -0
  29. data/vendor/faiss/IndexBinaryHNSW.cpp +325 -0
  30. data/vendor/faiss/IndexBinaryHNSW.h +56 -0
  31. data/vendor/faiss/IndexBinaryIVF.cpp +671 -0
  32. data/vendor/faiss/IndexBinaryIVF.h +211 -0
  33. data/vendor/faiss/IndexFlat.cpp +508 -0
  34. data/vendor/faiss/IndexFlat.h +175 -0
  35. data/vendor/faiss/IndexHNSW.cpp +1090 -0
  36. data/vendor/faiss/IndexHNSW.h +170 -0
  37. data/vendor/faiss/IndexIVF.cpp +909 -0
  38. data/vendor/faiss/IndexIVF.h +353 -0
  39. data/vendor/faiss/IndexIVFFlat.cpp +502 -0
  40. data/vendor/faiss/IndexIVFFlat.h +118 -0
  41. data/vendor/faiss/IndexIVFPQ.cpp +1207 -0
  42. data/vendor/faiss/IndexIVFPQ.h +161 -0
  43. data/vendor/faiss/IndexIVFPQR.cpp +219 -0
  44. data/vendor/faiss/IndexIVFPQR.h +65 -0
  45. data/vendor/faiss/IndexIVFSpectralHash.cpp +331 -0
  46. data/vendor/faiss/IndexIVFSpectralHash.h +75 -0
  47. data/vendor/faiss/IndexLSH.cpp +225 -0
  48. data/vendor/faiss/IndexLSH.h +87 -0
  49. data/vendor/faiss/IndexLattice.cpp +143 -0
  50. data/vendor/faiss/IndexLattice.h +68 -0
  51. data/vendor/faiss/IndexPQ.cpp +1188 -0
  52. data/vendor/faiss/IndexPQ.h +199 -0
  53. data/vendor/faiss/IndexPreTransform.cpp +288 -0
  54. data/vendor/faiss/IndexPreTransform.h +91 -0
  55. data/vendor/faiss/IndexReplicas.cpp +123 -0
  56. data/vendor/faiss/IndexReplicas.h +76 -0
  57. data/vendor/faiss/IndexScalarQuantizer.cpp +317 -0
  58. data/vendor/faiss/IndexScalarQuantizer.h +127 -0
  59. data/vendor/faiss/IndexShards.cpp +317 -0
  60. data/vendor/faiss/IndexShards.h +100 -0
  61. data/vendor/faiss/InvertedLists.cpp +623 -0
  62. data/vendor/faiss/InvertedLists.h +334 -0
  63. data/vendor/faiss/LICENSE +21 -0
  64. data/vendor/faiss/MatrixStats.cpp +252 -0
  65. data/vendor/faiss/MatrixStats.h +62 -0
  66. data/vendor/faiss/MetaIndexes.cpp +351 -0
  67. data/vendor/faiss/MetaIndexes.h +126 -0
  68. data/vendor/faiss/OnDiskInvertedLists.cpp +674 -0
  69. data/vendor/faiss/OnDiskInvertedLists.h +127 -0
  70. data/vendor/faiss/VectorTransform.cpp +1157 -0
  71. data/vendor/faiss/VectorTransform.h +322 -0
  72. data/vendor/faiss/c_api/AutoTune_c.cpp +83 -0
  73. data/vendor/faiss/c_api/AutoTune_c.h +64 -0
  74. data/vendor/faiss/c_api/Clustering_c.cpp +139 -0
  75. data/vendor/faiss/c_api/Clustering_c.h +117 -0
  76. data/vendor/faiss/c_api/IndexFlat_c.cpp +140 -0
  77. data/vendor/faiss/c_api/IndexFlat_c.h +115 -0
  78. data/vendor/faiss/c_api/IndexIVFFlat_c.cpp +64 -0
  79. data/vendor/faiss/c_api/IndexIVFFlat_c.h +58 -0
  80. data/vendor/faiss/c_api/IndexIVF_c.cpp +92 -0
  81. data/vendor/faiss/c_api/IndexIVF_c.h +135 -0
  82. data/vendor/faiss/c_api/IndexLSH_c.cpp +37 -0
  83. data/vendor/faiss/c_api/IndexLSH_c.h +40 -0
  84. data/vendor/faiss/c_api/IndexShards_c.cpp +44 -0
  85. data/vendor/faiss/c_api/IndexShards_c.h +42 -0
  86. data/vendor/faiss/c_api/Index_c.cpp +105 -0
  87. data/vendor/faiss/c_api/Index_c.h +183 -0
  88. data/vendor/faiss/c_api/MetaIndexes_c.cpp +49 -0
  89. data/vendor/faiss/c_api/MetaIndexes_c.h +49 -0
  90. data/vendor/faiss/c_api/clone_index_c.cpp +23 -0
  91. data/vendor/faiss/c_api/clone_index_c.h +32 -0
  92. data/vendor/faiss/c_api/error_c.h +42 -0
  93. data/vendor/faiss/c_api/error_impl.cpp +27 -0
  94. data/vendor/faiss/c_api/error_impl.h +16 -0
  95. data/vendor/faiss/c_api/faiss_c.h +58 -0
  96. data/vendor/faiss/c_api/gpu/GpuAutoTune_c.cpp +96 -0
  97. data/vendor/faiss/c_api/gpu/GpuAutoTune_c.h +56 -0
  98. data/vendor/faiss/c_api/gpu/GpuClonerOptions_c.cpp +52 -0
  99. data/vendor/faiss/c_api/gpu/GpuClonerOptions_c.h +68 -0
  100. data/vendor/faiss/c_api/gpu/GpuIndex_c.cpp +17 -0
  101. data/vendor/faiss/c_api/gpu/GpuIndex_c.h +30 -0
  102. data/vendor/faiss/c_api/gpu/GpuIndicesOptions_c.h +38 -0
  103. data/vendor/faiss/c_api/gpu/GpuResources_c.cpp +86 -0
  104. data/vendor/faiss/c_api/gpu/GpuResources_c.h +66 -0
  105. data/vendor/faiss/c_api/gpu/StandardGpuResources_c.cpp +54 -0
  106. data/vendor/faiss/c_api/gpu/StandardGpuResources_c.h +53 -0
  107. data/vendor/faiss/c_api/gpu/macros_impl.h +42 -0
  108. data/vendor/faiss/c_api/impl/AuxIndexStructures_c.cpp +220 -0
  109. data/vendor/faiss/c_api/impl/AuxIndexStructures_c.h +149 -0
  110. data/vendor/faiss/c_api/index_factory_c.cpp +26 -0
  111. data/vendor/faiss/c_api/index_factory_c.h +30 -0
  112. data/vendor/faiss/c_api/index_io_c.cpp +42 -0
  113. data/vendor/faiss/c_api/index_io_c.h +50 -0
  114. data/vendor/faiss/c_api/macros_impl.h +110 -0
  115. data/vendor/faiss/clone_index.cpp +147 -0
  116. data/vendor/faiss/clone_index.h +38 -0
  117. data/vendor/faiss/demos/demo_imi_flat.cpp +151 -0
  118. data/vendor/faiss/demos/demo_imi_pq.cpp +199 -0
  119. data/vendor/faiss/demos/demo_ivfpq_indexing.cpp +146 -0
  120. data/vendor/faiss/demos/demo_sift1M.cpp +252 -0
  121. data/vendor/faiss/gpu/GpuAutoTune.cpp +95 -0
  122. data/vendor/faiss/gpu/GpuAutoTune.h +27 -0
  123. data/vendor/faiss/gpu/GpuCloner.cpp +403 -0
  124. data/vendor/faiss/gpu/GpuCloner.h +82 -0
  125. data/vendor/faiss/gpu/GpuClonerOptions.cpp +28 -0
  126. data/vendor/faiss/gpu/GpuClonerOptions.h +53 -0
  127. data/vendor/faiss/gpu/GpuDistance.h +52 -0
  128. data/vendor/faiss/gpu/GpuFaissAssert.h +29 -0
  129. data/vendor/faiss/gpu/GpuIndex.h +148 -0
  130. data/vendor/faiss/gpu/GpuIndexBinaryFlat.h +89 -0
  131. data/vendor/faiss/gpu/GpuIndexFlat.h +190 -0
  132. data/vendor/faiss/gpu/GpuIndexIVF.h +89 -0
  133. data/vendor/faiss/gpu/GpuIndexIVFFlat.h +85 -0
  134. data/vendor/faiss/gpu/GpuIndexIVFPQ.h +143 -0
  135. data/vendor/faiss/gpu/GpuIndexIVFScalarQuantizer.h +100 -0
  136. data/vendor/faiss/gpu/GpuIndicesOptions.h +30 -0
  137. data/vendor/faiss/gpu/GpuResources.cpp +52 -0
  138. data/vendor/faiss/gpu/GpuResources.h +73 -0
  139. data/vendor/faiss/gpu/StandardGpuResources.cpp +295 -0
  140. data/vendor/faiss/gpu/StandardGpuResources.h +114 -0
  141. data/vendor/faiss/gpu/impl/RemapIndices.cpp +43 -0
  142. data/vendor/faiss/gpu/impl/RemapIndices.h +24 -0
  143. data/vendor/faiss/gpu/perf/IndexWrapper-inl.h +71 -0
  144. data/vendor/faiss/gpu/perf/IndexWrapper.h +39 -0
  145. data/vendor/faiss/gpu/perf/PerfClustering.cpp +115 -0
  146. data/vendor/faiss/gpu/perf/PerfIVFPQAdd.cpp +139 -0
  147. data/vendor/faiss/gpu/perf/WriteIndex.cpp +102 -0
  148. data/vendor/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +130 -0
  149. data/vendor/faiss/gpu/test/TestGpuIndexFlat.cpp +371 -0
  150. data/vendor/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +550 -0
  151. data/vendor/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +450 -0
  152. data/vendor/faiss/gpu/test/TestGpuMemoryException.cpp +84 -0
  153. data/vendor/faiss/gpu/test/TestUtils.cpp +315 -0
  154. data/vendor/faiss/gpu/test/TestUtils.h +93 -0
  155. data/vendor/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +159 -0
  156. data/vendor/faiss/gpu/utils/DeviceMemory.cpp +77 -0
  157. data/vendor/faiss/gpu/utils/DeviceMemory.h +71 -0
  158. data/vendor/faiss/gpu/utils/DeviceUtils.h +185 -0
  159. data/vendor/faiss/gpu/utils/MemorySpace.cpp +89 -0
  160. data/vendor/faiss/gpu/utils/MemorySpace.h +44 -0
  161. data/vendor/faiss/gpu/utils/StackDeviceMemory.cpp +239 -0
  162. data/vendor/faiss/gpu/utils/StackDeviceMemory.h +129 -0
  163. data/vendor/faiss/gpu/utils/StaticUtils.h +83 -0
  164. data/vendor/faiss/gpu/utils/Timer.cpp +60 -0
  165. data/vendor/faiss/gpu/utils/Timer.h +52 -0
  166. data/vendor/faiss/impl/AuxIndexStructures.cpp +305 -0
  167. data/vendor/faiss/impl/AuxIndexStructures.h +246 -0
  168. data/vendor/faiss/impl/FaissAssert.h +95 -0
  169. data/vendor/faiss/impl/FaissException.cpp +66 -0
  170. data/vendor/faiss/impl/FaissException.h +71 -0
  171. data/vendor/faiss/impl/HNSW.cpp +818 -0
  172. data/vendor/faiss/impl/HNSW.h +275 -0
  173. data/vendor/faiss/impl/PolysemousTraining.cpp +953 -0
  174. data/vendor/faiss/impl/PolysemousTraining.h +158 -0
  175. data/vendor/faiss/impl/ProductQuantizer.cpp +876 -0
  176. data/vendor/faiss/impl/ProductQuantizer.h +242 -0
  177. data/vendor/faiss/impl/ScalarQuantizer.cpp +1628 -0
  178. data/vendor/faiss/impl/ScalarQuantizer.h +120 -0
  179. data/vendor/faiss/impl/ThreadedIndex-inl.h +192 -0
  180. data/vendor/faiss/impl/ThreadedIndex.h +80 -0
  181. data/vendor/faiss/impl/index_read.cpp +793 -0
  182. data/vendor/faiss/impl/index_write.cpp +558 -0
  183. data/vendor/faiss/impl/io.cpp +142 -0
  184. data/vendor/faiss/impl/io.h +98 -0
  185. data/vendor/faiss/impl/lattice_Zn.cpp +712 -0
  186. data/vendor/faiss/impl/lattice_Zn.h +199 -0
  187. data/vendor/faiss/index_factory.cpp +392 -0
  188. data/vendor/faiss/index_factory.h +25 -0
  189. data/vendor/faiss/index_io.h +75 -0
  190. data/vendor/faiss/misc/test_blas.cpp +84 -0
  191. data/vendor/faiss/tests/test_binary_flat.cpp +64 -0
  192. data/vendor/faiss/tests/test_dealloc_invlists.cpp +183 -0
  193. data/vendor/faiss/tests/test_ivfpq_codec.cpp +67 -0
  194. data/vendor/faiss/tests/test_ivfpq_indexing.cpp +98 -0
  195. data/vendor/faiss/tests/test_lowlevel_ivf.cpp +566 -0
  196. data/vendor/faiss/tests/test_merge.cpp +258 -0
  197. data/vendor/faiss/tests/test_omp_threads.cpp +14 -0
  198. data/vendor/faiss/tests/test_ondisk_ivf.cpp +220 -0
  199. data/vendor/faiss/tests/test_pairs_decoding.cpp +189 -0
  200. data/vendor/faiss/tests/test_params_override.cpp +231 -0
  201. data/vendor/faiss/tests/test_pq_encoding.cpp +98 -0
  202. data/vendor/faiss/tests/test_sliding_ivf.cpp +240 -0
  203. data/vendor/faiss/tests/test_threaded_index.cpp +253 -0
  204. data/vendor/faiss/tests/test_transfer_invlists.cpp +159 -0
  205. data/vendor/faiss/tutorial/cpp/1-Flat.cpp +98 -0
  206. data/vendor/faiss/tutorial/cpp/2-IVFFlat.cpp +81 -0
  207. data/vendor/faiss/tutorial/cpp/3-IVFPQ.cpp +93 -0
  208. data/vendor/faiss/tutorial/cpp/4-GPU.cpp +119 -0
  209. data/vendor/faiss/tutorial/cpp/5-Multiple-GPUs.cpp +99 -0
  210. data/vendor/faiss/utils/Heap.cpp +122 -0
  211. data/vendor/faiss/utils/Heap.h +495 -0
  212. data/vendor/faiss/utils/WorkerThread.cpp +126 -0
  213. data/vendor/faiss/utils/WorkerThread.h +61 -0
  214. data/vendor/faiss/utils/distances.cpp +765 -0
  215. data/vendor/faiss/utils/distances.h +243 -0
  216. data/vendor/faiss/utils/distances_simd.cpp +809 -0
  217. data/vendor/faiss/utils/extra_distances.cpp +336 -0
  218. data/vendor/faiss/utils/extra_distances.h +54 -0
  219. data/vendor/faiss/utils/hamming-inl.h +472 -0
  220. data/vendor/faiss/utils/hamming.cpp +792 -0
  221. data/vendor/faiss/utils/hamming.h +220 -0
  222. data/vendor/faiss/utils/random.cpp +192 -0
  223. data/vendor/faiss/utils/random.h +60 -0
  224. data/vendor/faiss/utils/utils.cpp +783 -0
  225. data/vendor/faiss/utils/utils.h +181 -0
  226. metadata +216 -2
@@ -0,0 +1,95 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ // -*- c++ -*-
9
+
10
+ #ifndef FAISS_ASSERT_INCLUDED
11
+ #define FAISS_ASSERT_INCLUDED
12
+
13
+ #include <faiss/impl/FaissException.h>
14
+ #include <cstdlib>
15
+ #include <cstdio>
16
+ #include <string>
17
+
18
+ ///
19
+ /// Assertions
20
+ ///
21
+
22
/// Abort the process when condition X is false.
/// X is evaluated exactly once. On failure the stringified expression, the
/// enclosing function, the file and the line are written to stderr, then
/// abort() is called.
#define FAISS_ASSERT(X)                                                  \
    do {                                                                 \
        if (!(X)) {                                                      \
            fprintf(stderr,                                              \
                    "Faiss assertion '%s' failed in %s at %s:%d\n",      \
                    #X, __PRETTY_FUNCTION__, __FILE__, __LINE__);        \
            abort();                                                     \
        }                                                                \
    } while (false)
31
+
32
/// Abort the process with an extra human-readable message when X is false.
/// MSG is passed to fprintf as a "%s" argument rather than being pasted into
/// the format string. A '%' inside the message is therefore printed verbatim
/// instead of being read as a conversion specifier (which was undefined
/// behaviour). This also allows MSG to be any const char* expression, not
/// just a string literal.
#define FAISS_ASSERT_MSG(X, MSG)                                         \
    do {                                                                 \
        if (!(X)) {                                                      \
            fprintf(stderr,                                              \
                    "Faiss assertion '%s' failed in %s "                 \
                    "at %s:%d; details: %s\n",                           \
                    #X, __PRETTY_FUNCTION__, __FILE__, __LINE__,         \
                    (MSG));                                              \
            abort();                                                     \
        }                                                                \
    } while (false)
41
+
42
/// Abort the process with a printf-formatted message when X is false.
/// FMT must be a string literal (it is concatenated into the fprintf format);
/// the variadic arguments feed its conversion specifiers and are only
/// evaluated on failure.
#define FAISS_ASSERT_FMT(X, FMT, ...)                                    \
    do {                                                                 \
        if (!(X)) {                                                      \
            fprintf(stderr,                                              \
                    "Faiss assertion '%s' failed in %s "                 \
                    "at %s:%d; details: " FMT "\n",                      \
                    #X, __PRETTY_FUNCTION__, __FILE__, __LINE__,         \
                    __VA_ARGS__);                                        \
            abort();                                                     \
        }                                                                \
    } while (false)
51
+
52
+ ///
53
+ /// Exceptions for returning user errors
54
+ ///
55
+
56
/// Throw a faiss::FaissException carrying MSG together with the throwing
/// function, file and line (the exception formats them into its message).
#define FAISS_THROW_MSG(MSG)                                             \
    do {                                                                 \
        throw faiss::FaissException(                                     \
            MSG, __PRETTY_FUNCTION__, __FILE__, __LINE__);               \
    } while (false)
60
+
61
/// Format a printf-style message and throw it as a faiss::FaissException.
///
/// The text is sized with a first snprintf pass and written with a second one,
/// so the variadic arguments are evaluated twice. Do not pass expressions with
/// side effects.
///
/// The local names end in an underscore and carry a faiss_ prefix because:
///  - identifiers containing "__" are reserved for the implementation;
///  - a short name like "s" could shadow a caller variable used in
///    __VA_ARGS__.
///
/// After writing, the terminator snprintf emits is trimmed, so the message
/// string has no embedded trailing '\0'. A negative (error) result from
/// snprintf yields an empty message instead of a bogus resize.
#define FAISS_THROW_FMT(FMT, ...)                                        \
    do {                                                                 \
        std::string faiss_throw_fmt_msg_;                                \
        int faiss_throw_fmt_size_ =                                      \
            snprintf(nullptr, 0, FMT, __VA_ARGS__);                      \
        if (faiss_throw_fmt_size_ > 0) {                                 \
            /* +1 leaves room for the terminator snprintf writes */      \
            faiss_throw_fmt_msg_.resize(faiss_throw_fmt_size_ + 1);      \
            snprintf(&faiss_throw_fmt_msg_[0],                           \
                     faiss_throw_fmt_msg_.size(), FMT, __VA_ARGS__);     \
            /* drop the terminator so it is not part of the text */      \
            faiss_throw_fmt_msg_.resize(faiss_throw_fmt_size_);          \
        }                                                                \
        throw faiss::FaissException(faiss_throw_fmt_msg_,                \
                                    __PRETTY_FUNCTION__,                 \
                                    __FILE__, __LINE__);                 \
    } while (false)
69
+
70
+ ///
71
+ /// Exceptions thrown upon a conditional failure
72
+ ///
73
+
74
/// Throw a faiss::FaissException whose message names the failed condition
/// when X evaluates to false. X is evaluated exactly once.
#define FAISS_THROW_IF_NOT(X)                                            \
    do {                                                                 \
        if (!(X)) {                                                      \
            FAISS_THROW_FMT("Error: '%s' failed", #X);                   \
        }                                                                \
    } while (false)
80
+
81
/// Throw a faiss::FaissException naming the failed condition plus MSG when
/// X is false.
/// MSG is passed as a "%s" argument instead of being pasted into the format
/// string. A '%' in the message is therefore no longer interpreted as a
/// conversion specifier, and any const char* expression is accepted as MSG.
#define FAISS_THROW_IF_NOT_MSG(X, MSG)                                   \
    do {                                                                 \
        if (!(X)) {                                                      \
            FAISS_THROW_FMT("Error: '%s' failed: %s", #X, (MSG));        \
        }                                                                \
    } while (false)
87
+
88
/// Throw a faiss::FaissException naming the failed condition followed by a
/// printf-formatted detail message when X is false. FMT must be a string
/// literal; it is appended to the fixed prefix.
#define FAISS_THROW_IF_NOT_FMT(X, FMT, ...)                              \
    do {                                                                 \
        if (!(X)) {                                                      \
            FAISS_THROW_FMT("Error: '%s' failed: " FMT,                  \
                            #X, __VA_ARGS__);                            \
        }                                                                \
    } while (false)
94
+
95
+ #endif
@@ -0,0 +1,66 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ // -*- c++ -*-
9
+
10
+ #include <faiss/impl/FaissException.h>
11
+ #include <sstream>
12
+
13
+ namespace faiss {
14
+
15
+ FaissException::FaissException(const std::string& m)
16
+ : msg(m) {
17
+ }
18
+
19
+ FaissException::FaissException(const std::string& m,
20
+ const char* funcName,
21
+ const char* file,
22
+ int line) {
23
+ int size = snprintf(nullptr, 0, "Error in %s at %s:%d: %s",
24
+ funcName, file, line, m.c_str());
25
+ msg.resize(size + 1);
26
+ snprintf(&msg[0], msg.size(), "Error in %s at %s:%d: %s",
27
+ funcName, file, line, m.c_str());
28
+ }
29
+
30
+ const char*
31
+ FaissException::what() const noexcept {
32
+ return msg.c_str();
33
+ }
34
+
35
+ void handleExceptions(
36
+ std::vector<std::pair<int, std::exception_ptr>>& exceptions) {
37
+ if (exceptions.size() == 1) {
38
+ // throw the single received exception directly
39
+ std::rethrow_exception(exceptions.front().second);
40
+
41
+ } else if (exceptions.size() > 1) {
42
+ // multiple exceptions; aggregate them and return a single exception
43
+ std::stringstream ss;
44
+
45
+ for (auto& p : exceptions) {
46
+ try {
47
+ std::rethrow_exception(p.second);
48
+ } catch (std::exception& ex) {
49
+ if (ex.what()) {
50
+ // exception message available
51
+ ss << "Exception thrown from index " << p.first << ": "
52
+ << ex.what() << "\n";
53
+ } else {
54
+ // No message available
55
+ ss << "Unknown exception thrown from index " << p.first << "\n";
56
+ }
57
+ } catch (...) {
58
+ ss << "Unknown exception thrown from index " << p.first << "\n";
59
+ }
60
+ }
61
+
62
+ throw FaissException(ss.str());
63
+ }
64
+ }
65
+
66
+ }
@@ -0,0 +1,71 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ // -*- c++ -*-
9
+
10
+ #ifndef FAISS_EXCEPTION_INCLUDED
11
+ #define FAISS_EXCEPTION_INCLUDED
12
+
13
+ #include <exception>
14
+ #include <string>
15
+ #include <vector>
16
+ #include <utility>
17
+
18
namespace faiss {

/// Base class for Faiss exceptions
class FaissException : public std::exception {
 public:
  /// Use `msg` verbatim as the message.
  explicit FaissException(const std::string& msg);

  /// Prefix `msg` with the throwing function, file and line.
  FaissException(const std::string& msg,
                 const char* funcName,
                 const char* file,
                 int line);

  /// from std::exception
  const char* what() const noexcept override;

  std::string msg;
};

/// Handle multiple exceptions from worker threads, throwing an appropriate
/// exception that aggregates the information.
/// The int of each pair identifies the worker that generated the exception.
void
handleExceptions(std::vector<std::pair<int, std::exception_ptr>>& exceptions);

/** bare-bones unique_ptr
 * this one deletes with delete [] */
template<class T>
struct ScopeDeleter {
    const T * ptr;
    explicit ScopeDeleter (const T* ptr = nullptr): ptr (ptr) {}
    // Non-copyable: two copies would both delete [] the same array. Declaring
    // the copy operations deleted also suppresses the implicit move ones.
    ScopeDeleter (const ScopeDeleter &) = delete;
    ScopeDeleter & operator = (const ScopeDeleter &) = delete;
    /// give up ownership without deleting
    void release () {ptr = nullptr; }
    /// take ownership of ptr_in; the previously held pointer is NOT deleted
    void set (const T * ptr_in) { ptr = ptr_in; }
    void swap (ScopeDeleter<T> &other) {std::swap (ptr, other.ptr); }
    ~ScopeDeleter () {
        delete [] ptr;
    }
};

/** same but deletes with the simple delete (least common case) */
template<class T>
struct ScopeDeleter1 {
    const T * ptr;
    explicit ScopeDeleter1 (const T* ptr = nullptr): ptr (ptr) {}
    // Non-copyable for the same reason as ScopeDeleter (double delete).
    ScopeDeleter1 (const ScopeDeleter1 &) = delete;
    ScopeDeleter1 & operator = (const ScopeDeleter1 &) = delete;
    /// give up ownership without deleting
    void release () {ptr = nullptr; }
    /// take ownership of ptr_in; the previously held pointer is NOT deleted
    void set (const T * ptr_in) { ptr = ptr_in; }
    void swap (ScopeDeleter1<T> &other) {std::swap (ptr, other.ptr); }
    ~ScopeDeleter1 () {
        delete ptr;
    }
};

}
70
+
71
+ #endif
@@ -0,0 +1,818 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ // -*- c++ -*-
9
+
10
+ #include <faiss/impl/HNSW.h>
11
+
12
+ #include <string>
13
+
14
+ #include <faiss/impl/AuxIndexStructures.h>
15
+
16
+ namespace faiss {
17
+
18
+ using idx_t = Index::idx_t;
19
+
20
+ /**************************************************************
21
+ * HNSW structure implementation
22
+ **************************************************************/
23
+
24
+ int HNSW::nb_neighbors(int layer_no) const
25
+ {
26
+ return cum_nneighbor_per_level[layer_no + 1] -
27
+ cum_nneighbor_per_level[layer_no];
28
+ }
29
+
30
+ void HNSW::set_nb_neighbors(int level_no, int n)
31
+ {
32
+ FAISS_THROW_IF_NOT(levels.size() == 0);
33
+ int cur_n = nb_neighbors(level_no);
34
+ for (int i = level_no + 1; i < cum_nneighbor_per_level.size(); i++) {
35
+ cum_nneighbor_per_level[i] += n - cur_n;
36
+ }
37
+ }
38
+
39
+ int HNSW::cum_nb_neighbors(int layer_no) const
40
+ {
41
+ return cum_nneighbor_per_level[layer_no];
42
+ }
43
+
44
+ void HNSW::neighbor_range(idx_t no, int layer_no,
45
+ size_t * begin, size_t * end) const
46
+ {
47
+ size_t o = offsets[no];
48
+ *begin = o + cum_nb_neighbors(layer_no);
49
+ *end = o + cum_nb_neighbors(layer_no + 1);
50
+ }
51
+
52
+
53
+
54
+ HNSW::HNSW(int M) : rng(12345) {
55
+ set_default_probas(M, 1.0 / log(M));
56
+ max_level = -1;
57
+ entry_point = -1;
58
+ efSearch = 16;
59
+ efConstruction = 40;
60
+ upper_beam = 1;
61
+ offsets.push_back(0);
62
+ }
63
+
64
+
65
+ int HNSW::random_level()
66
+ {
67
+ double f = rng.rand_float();
68
+ // could be a bit faster with bissection
69
+ for (int level = 0; level < assign_probas.size(); level++) {
70
+ if (f < assign_probas[level]) {
71
+ return level;
72
+ }
73
+ f -= assign_probas[level];
74
+ }
75
+ // happens with exponentially low probability
76
+ return assign_probas.size() - 1;
77
+ }
78
+
79
+ void HNSW::set_default_probas(int M, float levelMult)
80
+ {
81
+ int nn = 0;
82
+ cum_nneighbor_per_level.push_back (0);
83
+ for (int level = 0; ;level++) {
84
+ float proba = exp(-level / levelMult) * (1 - exp(-1 / levelMult));
85
+ if (proba < 1e-9) break;
86
+ assign_probas.push_back(proba);
87
+ nn += level == 0 ? M * 2 : M;
88
+ cum_nneighbor_per_level.push_back (nn);
89
+ }
90
+ }
91
+
92
+ void HNSW::clear_neighbor_tables(int level)
93
+ {
94
+ for (int i = 0; i < levels.size(); i++) {
95
+ size_t begin, end;
96
+ neighbor_range(i, level, &begin, &end);
97
+ for (size_t j = begin; j < end; j++) {
98
+ neighbors[j] = -1;
99
+ }
100
+ }
101
+ }
102
+
103
+
104
+ void HNSW::reset() {
105
+ max_level = -1;
106
+ entry_point = -1;
107
+ offsets.clear();
108
+ offsets.push_back(0);
109
+ levels.clear();
110
+ neighbors.clear();
111
+ }
112
+
113
+
114
+
115
+ void HNSW::print_neighbor_stats(int level) const
116
+ {
117
+ FAISS_THROW_IF_NOT (level < cum_nneighbor_per_level.size());
118
+ printf("stats on level %d, max %d neighbors per vertex:\n",
119
+ level, nb_neighbors(level));
120
+ size_t tot_neigh = 0, tot_common = 0, tot_reciprocal = 0, n_node = 0;
121
+ #pragma omp parallel for reduction(+: tot_neigh) reduction(+: tot_common) \
122
+ reduction(+: tot_reciprocal) reduction(+: n_node)
123
+ for (int i = 0; i < levels.size(); i++) {
124
+ if (levels[i] > level) {
125
+ n_node++;
126
+ size_t begin, end;
127
+ neighbor_range(i, level, &begin, &end);
128
+ std::unordered_set<int> neighset;
129
+ for (size_t j = begin; j < end; j++) {
130
+ if (neighbors [j] < 0) break;
131
+ neighset.insert(neighbors[j]);
132
+ }
133
+ int n_neigh = neighset.size();
134
+ int n_common = 0;
135
+ int n_reciprocal = 0;
136
+ for (size_t j = begin; j < end; j++) {
137
+ storage_idx_t i2 = neighbors[j];
138
+ if (i2 < 0) break;
139
+ FAISS_ASSERT(i2 != i);
140
+ size_t begin2, end2;
141
+ neighbor_range(i2, level, &begin2, &end2);
142
+ for (size_t j2 = begin2; j2 < end2; j2++) {
143
+ storage_idx_t i3 = neighbors[j2];
144
+ if (i3 < 0) break;
145
+ if (i3 == i) {
146
+ n_reciprocal++;
147
+ continue;
148
+ }
149
+ if (neighset.count(i3)) {
150
+ neighset.erase(i3);
151
+ n_common++;
152
+ }
153
+ }
154
+ }
155
+ tot_neigh += n_neigh;
156
+ tot_common += n_common;
157
+ tot_reciprocal += n_reciprocal;
158
+ }
159
+ }
160
+ float normalizer = n_node;
161
+ printf(" nb of nodes at that level %ld\n", n_node);
162
+ printf(" neighbors per node: %.2f (%ld)\n",
163
+ tot_neigh / normalizer, tot_neigh);
164
+ printf(" nb of reciprocal neighbors: %.2f\n", tot_reciprocal / normalizer);
165
+ printf(" nb of neighbors that are also neighbor-of-neighbors: %.2f (%ld)\n",
166
+ tot_common / normalizer, tot_common);
167
+
168
+
169
+
170
+ }
171
+
172
+
173
+ void HNSW::fill_with_random_links(size_t n)
174
+ {
175
+ int max_level = prepare_level_tab(n);
176
+ RandomGenerator rng2(456);
177
+
178
+ for (int level = max_level - 1; level >= 0; --level) {
179
+ std::vector<int> elts;
180
+ for (int i = 0; i < n; i++) {
181
+ if (levels[i] > level) {
182
+ elts.push_back(i);
183
+ }
184
+ }
185
+ printf ("linking %ld elements in level %d\n",
186
+ elts.size(), level);
187
+
188
+ if (elts.size() == 1) continue;
189
+
190
+ for (int ii = 0; ii < elts.size(); ii++) {
191
+ int i = elts[ii];
192
+ size_t begin, end;
193
+ neighbor_range(i, 0, &begin, &end);
194
+ for (size_t j = begin; j < end; j++) {
195
+ int other = 0;
196
+ do {
197
+ other = elts[rng2.rand_int(elts.size())];
198
+ } while(other == i);
199
+
200
+ neighbors[j] = other;
201
+ }
202
+ }
203
+ }
204
+ }
205
+
206
+
207
+ int HNSW::prepare_level_tab(size_t n, bool preset_levels)
208
+ {
209
+ size_t n0 = offsets.size() - 1;
210
+
211
+ if (preset_levels) {
212
+ FAISS_ASSERT (n0 + n == levels.size());
213
+ } else {
214
+ FAISS_ASSERT (n0 == levels.size());
215
+ for (int i = 0; i < n; i++) {
216
+ int pt_level = random_level();
217
+ levels.push_back(pt_level + 1);
218
+ }
219
+ }
220
+
221
+ int max_level = 0;
222
+ for (int i = 0; i < n; i++) {
223
+ int pt_level = levels[i + n0] - 1;
224
+ if (pt_level > max_level) max_level = pt_level;
225
+ offsets.push_back(offsets.back() +
226
+ cum_nb_neighbors(pt_level + 1));
227
+ neighbors.resize(offsets.back(), -1);
228
+ }
229
+
230
+ return max_level;
231
+ }
232
+
233
+
234
+ /** Enumerate vertices from farthest to nearest from query, keep a
235
+ * neighbor only if there is no previous neighbor that is closer to
236
+ * that vertex than the query.
237
+ */
238
+ void HNSW::shrink_neighbor_list(
239
+ DistanceComputer& qdis,
240
+ std::priority_queue<NodeDistFarther>& input,
241
+ std::vector<NodeDistFarther>& output,
242
+ int max_size)
243
+ {
244
+ while (input.size() > 0) {
245
+ NodeDistFarther v1 = input.top();
246
+ input.pop();
247
+ float dist_v1_q = v1.d;
248
+
249
+ bool good = true;
250
+ for (NodeDistFarther v2 : output) {
251
+ float dist_v1_v2 = qdis.symmetric_dis(v2.id, v1.id);
252
+
253
+ if (dist_v1_v2 < dist_v1_q) {
254
+ good = false;
255
+ break;
256
+ }
257
+ }
258
+
259
+ if (good) {
260
+ output.push_back(v1);
261
+ if (output.size() >= max_size) {
262
+ return;
263
+ }
264
+ }
265
+ }
266
+ }
267
+
268
+
269
+ namespace {
270
+
271
+
272
+ using storage_idx_t = HNSW::storage_idx_t;
273
+ using NodeDistCloser = HNSW::NodeDistCloser;
274
+ using NodeDistFarther = HNSW::NodeDistFarther;
275
+
276
+
277
+ /**************************************************************
278
+ * Addition subroutines
279
+ **************************************************************/
280
+
281
+
282
+ /// remove neighbors from the list to make it smaller than max_size
283
+ void shrink_neighbor_list(
284
+ DistanceComputer& qdis,
285
+ std::priority_queue<NodeDistCloser>& resultSet1,
286
+ int max_size)
287
+ {
288
+ if (resultSet1.size() < max_size) {
289
+ return;
290
+ }
291
+ std::priority_queue<NodeDistFarther> resultSet;
292
+ std::vector<NodeDistFarther> returnlist;
293
+
294
+ while (resultSet1.size() > 0) {
295
+ resultSet.emplace(resultSet1.top().d, resultSet1.top().id);
296
+ resultSet1.pop();
297
+ }
298
+
299
+ HNSW::shrink_neighbor_list(qdis, resultSet, returnlist, max_size);
300
+
301
+ for (NodeDistFarther curen2 : returnlist) {
302
+ resultSet1.emplace(curen2.d, curen2.id);
303
+ }
304
+
305
+ }
306
+
307
+
308
+ /// add a link between two elements, possibly shrinking the list
309
+ /// of links to make room for it.
310
+ void add_link(HNSW& hnsw,
311
+ DistanceComputer& qdis,
312
+ storage_idx_t src, storage_idx_t dest,
313
+ int level)
314
+ {
315
+ size_t begin, end;
316
+ hnsw.neighbor_range(src, level, &begin, &end);
317
+ if (hnsw.neighbors[end - 1] == -1) {
318
+ // there is enough room, find a slot to add it
319
+ size_t i = end;
320
+ while(i > begin) {
321
+ if (hnsw.neighbors[i - 1] != -1) break;
322
+ i--;
323
+ }
324
+ hnsw.neighbors[i] = dest;
325
+ return;
326
+ }
327
+
328
+ // otherwise we let them fight out which to keep
329
+
330
+ // copy to resultSet...
331
+ std::priority_queue<NodeDistCloser> resultSet;
332
+ resultSet.emplace(qdis.symmetric_dis(src, dest), dest);
333
+ for (size_t i = begin; i < end; i++) { // HERE WAS THE BUG
334
+ storage_idx_t neigh = hnsw.neighbors[i];
335
+ resultSet.emplace(qdis.symmetric_dis(src, neigh), neigh);
336
+ }
337
+
338
+ shrink_neighbor_list(qdis, resultSet, end - begin);
339
+
340
+ // ...and back
341
+ size_t i = begin;
342
+ while (resultSet.size()) {
343
+ hnsw.neighbors[i++] = resultSet.top().id;
344
+ resultSet.pop();
345
+ }
346
+ // they may have shrunk more than just by 1 element
347
+ while(i < end) {
348
+ hnsw.neighbors[i++] = -1;
349
+ }
350
+ }
351
+
352
+ /// search neighbors on a single level, starting from an entry point
353
+ void search_neighbors_to_add(
354
+ HNSW& hnsw,
355
+ DistanceComputer& qdis,
356
+ std::priority_queue<NodeDistCloser>& results,
357
+ int entry_point,
358
+ float d_entry_point,
359
+ int level,
360
+ VisitedTable &vt)
361
+ {
362
+ // top is nearest candidate
363
+ std::priority_queue<NodeDistFarther> candidates;
364
+
365
+ NodeDistFarther ev(d_entry_point, entry_point);
366
+ candidates.push(ev);
367
+ results.emplace(d_entry_point, entry_point);
368
+ vt.set(entry_point);
369
+
370
+ while (!candidates.empty()) {
371
+ // get nearest
372
+ const NodeDistFarther &currEv = candidates.top();
373
+
374
+ if (currEv.d > results.top().d) {
375
+ break;
376
+ }
377
+ int currNode = currEv.id;
378
+ candidates.pop();
379
+
380
+ // loop over neighbors
381
+ size_t begin, end;
382
+ hnsw.neighbor_range(currNode, level, &begin, &end);
383
+ for(size_t i = begin; i < end; i++) {
384
+ storage_idx_t nodeId = hnsw.neighbors[i];
385
+ if (nodeId < 0) break;
386
+ if (vt.get(nodeId)) continue;
387
+ vt.set(nodeId);
388
+
389
+ float dis = qdis(nodeId);
390
+ NodeDistFarther evE1(dis, nodeId);
391
+
392
+ if (results.size() < hnsw.efConstruction ||
393
+ results.top().d > dis) {
394
+
395
+ results.emplace(dis, nodeId);
396
+ candidates.emplace(dis, nodeId);
397
+ if (results.size() > hnsw.efConstruction) {
398
+ results.pop();
399
+ }
400
+ }
401
+ }
402
+ }
403
+ vt.advance();
404
+ }
405
+
406
+
407
+ /**************************************************************
408
+ * Searching subroutines
409
+ **************************************************************/
410
+
411
+ /// greedily update a nearest vector at a given level
412
+ void greedy_update_nearest(const HNSW& hnsw,
413
+ DistanceComputer& qdis,
414
+ int level,
415
+ storage_idx_t& nearest,
416
+ float& d_nearest)
417
+ {
418
+ for(;;) {
419
+ storage_idx_t prev_nearest = nearest;
420
+
421
+ size_t begin, end;
422
+ hnsw.neighbor_range(nearest, level, &begin, &end);
423
+ for(size_t i = begin; i < end; i++) {
424
+ storage_idx_t v = hnsw.neighbors[i];
425
+ if (v < 0) break;
426
+ float dis = qdis(v);
427
+ if (dis < d_nearest) {
428
+ nearest = v;
429
+ d_nearest = dis;
430
+ }
431
+ }
432
+ if (nearest == prev_nearest) {
433
+ return;
434
+ }
435
+ }
436
+ }
437
+
438
+
439
+ } // namespace
440
+
441
+
442
+ /// Finds neighbors and builds links with them, starting from an entry
443
+ /// point. The own neighbor list is assumed to be locked.
444
+ void HNSW::add_links_starting_from(DistanceComputer& ptdis,
445
+ storage_idx_t pt_id,
446
+ storage_idx_t nearest,
447
+ float d_nearest,
448
+ int level,
449
+ omp_lock_t *locks,
450
+ VisitedTable &vt)
451
+ {
452
+ std::priority_queue<NodeDistCloser> link_targets;
453
+
454
+ search_neighbors_to_add(*this, ptdis, link_targets, nearest, d_nearest,
455
+ level, vt);
456
+
457
+ // but we can afford only this many neighbors
458
+ int M = nb_neighbors(level);
459
+
460
+ ::faiss::shrink_neighbor_list(ptdis, link_targets, M);
461
+
462
+ while (!link_targets.empty()) {
463
+ int other_id = link_targets.top().id;
464
+
465
+ omp_set_lock(&locks[other_id]);
466
+ add_link(*this, ptdis, other_id, pt_id, level);
467
+ omp_unset_lock(&locks[other_id]);
468
+
469
+ add_link(*this, ptdis, pt_id, other_id, level);
470
+
471
+ link_targets.pop();
472
+ }
473
+ }
474
+
475
+
476
+ /**************************************************************
477
+ * Building, parallel
478
+ **************************************************************/
479
+
480
+ void HNSW::add_with_locks(DistanceComputer& ptdis, int pt_level, int pt_id,
481
+ std::vector<omp_lock_t>& locks,
482
+ VisitedTable& vt)
483
+ {
484
+ // greedy search on upper levels
485
+
486
+ storage_idx_t nearest;
487
+ #pragma omp critical
488
+ {
489
+ nearest = entry_point;
490
+
491
+ if (nearest == -1) {
492
+ max_level = pt_level;
493
+ entry_point = pt_id;
494
+ }
495
+ }
496
+
497
+ if (nearest < 0) {
498
+ return;
499
+ }
500
+
501
+ omp_set_lock(&locks[pt_id]);
502
+
503
+ int level = max_level; // level at which we start adding neighbors
504
+ float d_nearest = ptdis(nearest);
505
+
506
+ for(; level > pt_level; level--) {
507
+ greedy_update_nearest(*this, ptdis, level, nearest, d_nearest);
508
+ }
509
+
510
+ for(; level >= 0; level--) {
511
+ add_links_starting_from(ptdis, pt_id, nearest, d_nearest,
512
+ level, locks.data(), vt);
513
+ }
514
+
515
+ omp_unset_lock(&locks[pt_id]);
516
+
517
+ if (pt_level > max_level) {
518
+ max_level = pt_level;
519
+ entry_point = pt_id;
520
+ }
521
+ }
522
+
523
+
524
+ /** Do a BFS on the candidates list */
525
+
526
+ int HNSW::search_from_candidates(
527
+ DistanceComputer& qdis, int k,
528
+ idx_t *I, float *D,
529
+ MinimaxHeap& candidates,
530
+ VisitedTable& vt,
531
+ int level, int nres_in) const
532
+ {
533
+ int nres = nres_in;
534
+ int ndis = 0;
535
+ for (int i = 0; i < candidates.size(); i++) {
536
+ idx_t v1 = candidates.ids[i];
537
+ float d = candidates.dis[i];
538
+ FAISS_ASSERT(v1 >= 0);
539
+ if (nres < k) {
540
+ faiss::maxheap_push(++nres, D, I, d, v1);
541
+ } else if (d < D[0]) {
542
+ faiss::maxheap_pop(nres--, D, I);
543
+ faiss::maxheap_push(++nres, D, I, d, v1);
544
+ }
545
+ vt.set(v1);
546
+ }
547
+
548
+ bool do_dis_check = check_relative_distance;
549
+ int nstep = 0;
550
+
551
+ while (candidates.size() > 0) {
552
+ float d0 = 0;
553
+ int v0 = candidates.pop_min(&d0);
554
+
555
+ if (do_dis_check) {
556
+ // tricky stopping condition: there are more that ef
557
+ // distances that are processed already that are smaller
558
+ // than d0
559
+
560
+ int n_dis_below = candidates.count_below(d0);
561
+ if(n_dis_below >= efSearch) {
562
+ break;
563
+ }
564
+ }
565
+
566
+ size_t begin, end;
567
+ neighbor_range(v0, level, &begin, &end);
568
+
569
+ for (size_t j = begin; j < end; j++) {
570
+ int v1 = neighbors[j];
571
+ if (v1 < 0) break;
572
+ if (vt.get(v1)) {
573
+ continue;
574
+ }
575
+ vt.set(v1);
576
+ ndis++;
577
+ float d = qdis(v1);
578
+ if (nres < k) {
579
+ faiss::maxheap_push(++nres, D, I, d, v1);
580
+ } else if (d < D[0]) {
581
+ faiss::maxheap_pop(nres--, D, I);
582
+ faiss::maxheap_push(++nres, D, I, d, v1);
583
+ }
584
+ candidates.push(v1, d);
585
+ }
586
+
587
+ nstep++;
588
+ if (!do_dis_check && nstep > efSearch) {
589
+ break;
590
+ }
591
+ }
592
+
593
+ if (level == 0) {
594
+ #pragma omp critical
595
+ {
596
+ hnsw_stats.n1 ++;
597
+ if (candidates.size() == 0) {
598
+ hnsw_stats.n2 ++;
599
+ }
600
+ hnsw_stats.n3 += ndis;
601
+ }
602
+ }
603
+
604
+ return nres;
605
+ }
606
+
607
+
608
+ /**************************************************************
609
+ * Searching
610
+ **************************************************************/
611
+
612
+ std::priority_queue<HNSW::Node> HNSW::search_from_candidate_unbounded(
613
+ const Node& node,
614
+ DistanceComputer& qdis,
615
+ int ef,
616
+ VisitedTable *vt) const
617
+ {
618
+ int ndis = 0;
619
+ std::priority_queue<Node> top_candidates;
620
+ std::priority_queue<Node, std::vector<Node>, std::greater<Node>> candidates;
621
+
622
+ top_candidates.push(node);
623
+ candidates.push(node);
624
+
625
+ vt->set(node.second);
626
+
627
+ while (!candidates.empty()) {
628
+ float d0;
629
+ storage_idx_t v0;
630
+ std::tie(d0, v0) = candidates.top();
631
+
632
+ if (d0 > top_candidates.top().first) {
633
+ break;
634
+ }
635
+
636
+ candidates.pop();
637
+
638
+ size_t begin, end;
639
+ neighbor_range(v0, 0, &begin, &end);
640
+
641
+ for (size_t j = begin; j < end; ++j) {
642
+ int v1 = neighbors[j];
643
+
644
+ if (v1 < 0) {
645
+ break;
646
+ }
647
+ if (vt->get(v1)) {
648
+ continue;
649
+ }
650
+
651
+ vt->set(v1);
652
+
653
+ float d1 = qdis(v1);
654
+ ++ndis;
655
+
656
+ if (top_candidates.top().first > d1 || top_candidates.size() < ef) {
657
+ candidates.emplace(d1, v1);
658
+ top_candidates.emplace(d1, v1);
659
+
660
+ if (top_candidates.size() > ef) {
661
+ top_candidates.pop();
662
+ }
663
+ }
664
+ }
665
+ }
666
+
667
+ #pragma omp critical
668
+ {
669
+ ++hnsw_stats.n1;
670
+ if (candidates.size() == 0) {
671
+ ++hnsw_stats.n2;
672
+ }
673
+ hnsw_stats.n3 += ndis;
674
+ }
675
+
676
+ return top_candidates;
677
+ }
678
+
679
+ void HNSW::search(DistanceComputer& qdis, int k,
680
+ idx_t *I, float *D,
681
+ VisitedTable& vt) const
682
+ {
683
+ if (upper_beam == 1) {
684
+
685
+ // greedy search on upper levels
686
+ storage_idx_t nearest = entry_point;
687
+ float d_nearest = qdis(nearest);
688
+
689
+ for(int level = max_level; level >= 1; level--) {
690
+ greedy_update_nearest(*this, qdis, level, nearest, d_nearest);
691
+ }
692
+
693
+ int ef = std::max(efSearch, k);
694
+ if (search_bounded_queue) {
695
+ MinimaxHeap candidates(ef);
696
+
697
+ candidates.push(nearest, d_nearest);
698
+
699
+ search_from_candidates(qdis, k, I, D, candidates, vt, 0);
700
+ } else {
701
+ std::priority_queue<Node> top_candidates =
702
+ search_from_candidate_unbounded(Node(d_nearest, nearest),
703
+ qdis, ef, &vt);
704
+
705
+ while (top_candidates.size() > k) {
706
+ top_candidates.pop();
707
+ }
708
+
709
+ int nres = 0;
710
+ while (!top_candidates.empty()) {
711
+ float d;
712
+ storage_idx_t label;
713
+ std::tie(d, label) = top_candidates.top();
714
+ faiss::maxheap_push(++nres, D, I, d, label);
715
+ top_candidates.pop();
716
+ }
717
+ }
718
+
719
+ vt.advance();
720
+
721
+ } else {
722
+ int candidates_size = upper_beam;
723
+ MinimaxHeap candidates(candidates_size);
724
+
725
+ std::vector<idx_t> I_to_next(candidates_size);
726
+ std::vector<float> D_to_next(candidates_size);
727
+
728
+ int nres = 1;
729
+ I_to_next[0] = entry_point;
730
+ D_to_next[0] = qdis(entry_point);
731
+
732
+ for(int level = max_level; level >= 0; level--) {
733
+
734
+ // copy I, D -> candidates
735
+
736
+ candidates.clear();
737
+
738
+ for (int i = 0; i < nres; i++) {
739
+ candidates.push(I_to_next[i], D_to_next[i]);
740
+ }
741
+
742
+ if (level == 0) {
743
+ nres = search_from_candidates(qdis, k, I, D, candidates, vt, 0);
744
+ } else {
745
+ nres = search_from_candidates(
746
+ qdis, candidates_size,
747
+ I_to_next.data(), D_to_next.data(),
748
+ candidates, vt, level
749
+ );
750
+ }
751
+ vt.advance();
752
+ }
753
+ }
754
+ }
755
+
756
+
757
+ void HNSW::MinimaxHeap::push(storage_idx_t i, float v) {
758
+ if (k == n) {
759
+ if (v >= dis[0]) return;
760
+ faiss::heap_pop<HC> (k--, dis.data(), ids.data());
761
+ --nvalid;
762
+ }
763
+ faiss::heap_push<HC> (++k, dis.data(), ids.data(), v, i);
764
+ ++nvalid;
765
+ }
766
+
767
+ float HNSW::MinimaxHeap::max() const {
768
+ return dis[0];
769
+ }
770
+
771
+ int HNSW::MinimaxHeap::size() const {
772
+ return nvalid;
773
+ }
774
+
775
+ void HNSW::MinimaxHeap::clear() {
776
+ nvalid = k = 0;
777
+ }
778
+
779
+ int HNSW::MinimaxHeap::pop_min(float *vmin_out) {
780
+ assert(k > 0);
781
+ // returns min. This is an O(n) operation
782
+ int i = k - 1;
783
+ while (i >= 0) {
784
+ if (ids[i] != -1) break;
785
+ i--;
786
+ }
787
+ if (i == -1) return -1;
788
+ int imin = i;
789
+ float vmin = dis[i];
790
+ i--;
791
+ while(i >= 0) {
792
+ if (ids[i] != -1 && dis[i] < vmin) {
793
+ vmin = dis[i];
794
+ imin = i;
795
+ }
796
+ i--;
797
+ }
798
+ if (vmin_out) *vmin_out = vmin;
799
+ int ret = ids[imin];
800
+ ids[imin] = -1;
801
+ --nvalid;
802
+
803
+ return ret;
804
+ }
805
+
806
+ int HNSW::MinimaxHeap::count_below(float thresh) {
807
+ int n_below = 0;
808
+ for(int i = 0; i < k; i++) {
809
+ if (dis[i] < thresh) {
810
+ n_below++;
811
+ }
812
+ }
813
+
814
+ return n_below;
815
+ }
816
+
817
+
818
+ } // namespace faiss