faiss 0.2.0 → 0.2.1

Sign up to get free protection for your applications and to get access to all the features.
Files changed (202) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +4 -0
  3. data/lib/faiss/version.rb +1 -1
  4. data/vendor/faiss/faiss/AutoTune.cpp +292 -291
  5. data/vendor/faiss/faiss/AutoTune.h +55 -56
  6. data/vendor/faiss/faiss/Clustering.cpp +334 -195
  7. data/vendor/faiss/faiss/Clustering.h +88 -35
  8. data/vendor/faiss/faiss/IVFlib.cpp +171 -195
  9. data/vendor/faiss/faiss/IVFlib.h +48 -51
  10. data/vendor/faiss/faiss/Index.cpp +85 -103
  11. data/vendor/faiss/faiss/Index.h +54 -48
  12. data/vendor/faiss/faiss/Index2Layer.cpp +139 -164
  13. data/vendor/faiss/faiss/Index2Layer.h +22 -22
  14. data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
  15. data/vendor/faiss/faiss/IndexBinary.h +140 -132
  16. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
  17. data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
  18. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
  19. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
  20. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
  21. data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
  22. data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
  23. data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
  24. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
  25. data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
  26. data/vendor/faiss/faiss/IndexFlat.cpp +116 -147
  27. data/vendor/faiss/faiss/IndexFlat.h +35 -46
  28. data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
  29. data/vendor/faiss/faiss/IndexHNSW.h +57 -41
  30. data/vendor/faiss/faiss/IndexIVF.cpp +474 -454
  31. data/vendor/faiss/faiss/IndexIVF.h +146 -113
  32. data/vendor/faiss/faiss/IndexIVFFlat.cpp +248 -250
  33. data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
  34. data/vendor/faiss/faiss/IndexIVFPQ.cpp +457 -516
  35. data/vendor/faiss/faiss/IndexIVFPQ.h +74 -66
  36. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
  37. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
  38. data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
  39. data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
  40. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +125 -133
  41. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +19 -21
  42. data/vendor/faiss/faiss/IndexLSH.cpp +75 -96
  43. data/vendor/faiss/faiss/IndexLSH.h +21 -26
  44. data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
  45. data/vendor/faiss/faiss/IndexLattice.h +11 -16
  46. data/vendor/faiss/faiss/IndexNNDescent.cpp +231 -0
  47. data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
  48. data/vendor/faiss/faiss/IndexNSG.cpp +303 -0
  49. data/vendor/faiss/faiss/IndexNSG.h +85 -0
  50. data/vendor/faiss/faiss/IndexPQ.cpp +405 -464
  51. data/vendor/faiss/faiss/IndexPQ.h +64 -67
  52. data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
  53. data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
  54. data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
  55. data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
  56. data/vendor/faiss/faiss/IndexRefine.cpp +115 -131
  57. data/vendor/faiss/faiss/IndexRefine.h +22 -23
  58. data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
  59. data/vendor/faiss/faiss/IndexReplicas.h +62 -56
  60. data/vendor/faiss/faiss/IndexResidual.cpp +291 -0
  61. data/vendor/faiss/faiss/IndexResidual.h +152 -0
  62. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +120 -155
  63. data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -45
  64. data/vendor/faiss/faiss/IndexShards.cpp +256 -240
  65. data/vendor/faiss/faiss/IndexShards.h +85 -73
  66. data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
  67. data/vendor/faiss/faiss/MatrixStats.h +7 -10
  68. data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
  69. data/vendor/faiss/faiss/MetaIndexes.h +40 -34
  70. data/vendor/faiss/faiss/MetricType.h +7 -7
  71. data/vendor/faiss/faiss/VectorTransform.cpp +652 -474
  72. data/vendor/faiss/faiss/VectorTransform.h +61 -89
  73. data/vendor/faiss/faiss/clone_index.cpp +77 -73
  74. data/vendor/faiss/faiss/clone_index.h +4 -9
  75. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
  76. data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
  77. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +197 -170
  78. data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
  79. data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
  80. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
  81. data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
  82. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
  83. data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
  84. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
  85. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
  86. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
  87. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
  88. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
  89. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
  90. data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
  91. data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
  92. data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
  93. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
  94. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
  95. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
  96. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
  97. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
  98. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
  99. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
  100. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
  101. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
  102. data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
  103. data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
  104. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
  105. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
  106. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
  107. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
  108. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
  109. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
  110. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
  111. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
  112. data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
  113. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
  114. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
  115. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
  116. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
  117. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
  118. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
  119. data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
  120. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +270 -0
  121. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +115 -0
  122. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
  123. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
  124. data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
  125. data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
  126. data/vendor/faiss/faiss/impl/FaissException.h +41 -29
  127. data/vendor/faiss/faiss/impl/HNSW.cpp +595 -611
  128. data/vendor/faiss/faiss/impl/HNSW.h +179 -200
  129. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +672 -0
  130. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +172 -0
  131. data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
  132. data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
  133. data/vendor/faiss/faiss/impl/NSG.cpp +682 -0
  134. data/vendor/faiss/faiss/impl/NSG.h +199 -0
  135. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
  136. data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
  137. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
  138. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
  139. data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
  140. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +448 -0
  141. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +130 -0
  142. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
  143. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +648 -701
  144. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
  145. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
  146. data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
  147. data/vendor/faiss/faiss/impl/index_read.cpp +547 -479
  148. data/vendor/faiss/faiss/impl/index_write.cpp +497 -407
  149. data/vendor/faiss/faiss/impl/io.cpp +75 -94
  150. data/vendor/faiss/faiss/impl/io.h +31 -41
  151. data/vendor/faiss/faiss/impl/io_macros.h +40 -29
  152. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
  153. data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
  154. data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
  155. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
  156. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
  157. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
  158. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
  159. data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
  160. data/vendor/faiss/faiss/index_factory.cpp +269 -218
  161. data/vendor/faiss/faiss/index_factory.h +6 -7
  162. data/vendor/faiss/faiss/index_io.h +23 -26
  163. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
  164. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
  165. data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
  166. data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
  167. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
  168. data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
  169. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
  170. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
  171. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
  172. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
  173. data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
  174. data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
  175. data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
  176. data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
  177. data/vendor/faiss/faiss/utils/Heap.h +186 -209
  178. data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
  179. data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
  180. data/vendor/faiss/faiss/utils/distances.cpp +301 -310
  181. data/vendor/faiss/faiss/utils/distances.h +133 -118
  182. data/vendor/faiss/faiss/utils/distances_simd.cpp +456 -516
  183. data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
  184. data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
  185. data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
  186. data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
  187. data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
  188. data/vendor/faiss/faiss/utils/hamming.h +62 -85
  189. data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
  190. data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
  191. data/vendor/faiss/faiss/utils/partitioning.h +26 -21
  192. data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
  193. data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
  194. data/vendor/faiss/faiss/utils/random.cpp +39 -63
  195. data/vendor/faiss/faiss/utils/random.h +13 -16
  196. data/vendor/faiss/faiss/utils/simdlib.h +4 -2
  197. data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
  198. data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
  199. data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
  200. data/vendor/faiss/faiss/utils/utils.cpp +304 -287
  201. data/vendor/faiss/faiss/utils/utils.h +53 -48
  202. metadata +20 -2
@@ -10,71 +10,81 @@
10
10
  #include <faiss/impl/FaissException.h>
11
11
  #include <sstream>
12
12
 
13
- #ifdef __GNUG__
13
+ #ifdef __GNUG__
14
14
  #include <cxxabi.h>
15
15
  #endif
16
16
 
17
17
  namespace faiss {
18
18
 
19
- FaissException::FaissException(const std::string& m)
20
- : msg(m) {
21
- }
22
-
23
- FaissException::FaissException(const std::string& m,
24
- const char* funcName,
25
- const char* file,
26
- int line) {
27
- int size = snprintf(nullptr, 0, "Error in %s at %s:%d: %s",
28
- funcName, file, line, m.c_str());
29
- msg.resize(size + 1);
30
- snprintf(&msg[0], msg.size(), "Error in %s at %s:%d: %s",
31
- funcName, file, line, m.c_str());
19
+ FaissException::FaissException(const std::string& m) : msg(m) {}
20
+
21
+ FaissException::FaissException(
22
+ const std::string& m,
23
+ const char* funcName,
24
+ const char* file,
25
+ int line) {
26
+ int size = snprintf(
27
+ nullptr,
28
+ 0,
29
+ "Error in %s at %s:%d: %s",
30
+ funcName,
31
+ file,
32
+ line,
33
+ m.c_str());
34
+ msg.resize(size + 1);
35
+ snprintf(
36
+ &msg[0],
37
+ msg.size(),
38
+ "Error in %s at %s:%d: %s",
39
+ funcName,
40
+ file,
41
+ line,
42
+ m.c_str());
32
43
  }
33
44
 
34
- const char*
35
- FaissException::what() const noexcept {
36
- return msg.c_str();
45
+ const char* FaissException::what() const noexcept {
46
+ return msg.c_str();
37
47
  }
38
48
 
39
49
  void handleExceptions(
40
- std::vector<std::pair<int, std::exception_ptr>>& exceptions) {
41
- if (exceptions.size() == 1) {
42
- // throw the single received exception directly
43
- std::rethrow_exception(exceptions.front().second);
44
-
45
- } else if (exceptions.size() > 1) {
46
- // multiple exceptions; aggregate them and return a single exception
47
- std::stringstream ss;
48
-
49
- for (auto& p : exceptions) {
50
- try {
51
- std::rethrow_exception(p.second);
52
- } catch (std::exception& ex) {
53
- if (ex.what()) {
54
- // exception message available
55
- ss << "Exception thrown from index " << p.first << ": "
56
- << ex.what() << "\n";
57
- } else {
58
- // No message available
59
- ss << "Unknown exception thrown from index " << p.first << "\n";
50
+ std::vector<std::pair<int, std::exception_ptr>>& exceptions) {
51
+ if (exceptions.size() == 1) {
52
+ // throw the single received exception directly
53
+ std::rethrow_exception(exceptions.front().second);
54
+
55
+ } else if (exceptions.size() > 1) {
56
+ // multiple exceptions; aggregate them and return a single exception
57
+ std::stringstream ss;
58
+
59
+ for (auto& p : exceptions) {
60
+ try {
61
+ std::rethrow_exception(p.second);
62
+ } catch (std::exception& ex) {
63
+ if (ex.what()) {
64
+ // exception message available
65
+ ss << "Exception thrown from index " << p.first << ": "
66
+ << ex.what() << "\n";
67
+ } else {
68
+ // No message available
69
+ ss << "Unknown exception thrown from index " << p.first
70
+ << "\n";
71
+ }
72
+ } catch (...) {
73
+ ss << "Unknown exception thrown from index " << p.first << "\n";
74
+ }
60
75
  }
61
- } catch (...) {
62
- ss << "Unknown exception thrown from index " << p.first << "\n";
63
- }
64
- }
65
76
 
66
- throw FaissException(ss.str());
67
- }
77
+ throw FaissException(ss.str());
78
+ }
68
79
  }
69
80
 
70
-
71
81
  // From
72
82
  // https://stackoverflow.com/questions/281818/unmangling-the-result-of-stdtype-infoname
73
83
 
74
84
  std::string demangle_cpp_symbol(const char* name) {
75
85
  #ifdef __GNUG__
76
86
  int status = -1;
77
- const char * res = abi::__cxa_demangle(name, nullptr, nullptr, &status);
87
+ const char* res = abi::__cxa_demangle(name, nullptr, nullptr, &status);
78
88
  std::string sres;
79
89
  if (status == 0) {
80
90
  sres = res;
@@ -87,6 +97,4 @@ std::string demangle_cpp_symbol(const char* name) {
87
97
  #endif
88
98
  }
89
99
 
90
-
91
-
92
- } // namespace
100
+ } // namespace faiss
@@ -12,56 +12,69 @@
12
12
 
13
13
  #include <exception>
14
14
  #include <string>
15
- #include <vector>
16
15
  #include <utility>
16
+ #include <vector>
17
17
 
18
18
  namespace faiss {
19
19
 
20
20
  /// Base class for Faiss exceptions
21
21
  class FaissException : public std::exception {
22
- public:
23
- explicit FaissException(const std::string& msg);
22
+ public:
23
+ explicit FaissException(const std::string& msg);
24
24
 
25
- FaissException(const std::string& msg,
26
- const char* funcName,
27
- const char* file,
28
- int line);
25
+ FaissException(
26
+ const std::string& msg,
27
+ const char* funcName,
28
+ const char* file,
29
+ int line);
29
30
 
30
- /// from std::exception
31
- const char* what() const noexcept override;
31
+ /// from std::exception
32
+ const char* what() const noexcept override;
32
33
 
33
- std::string msg;
34
+ std::string msg;
34
35
  };
35
36
 
36
37
  /// Handle multiple exceptions from worker threads, throwing an appropriate
37
38
  /// exception that aggregates the information
38
39
  /// The pair int is the thread that generated the exception
39
- void
40
- handleExceptions(std::vector<std::pair<int, std::exception_ptr>>& exceptions);
40
+ void handleExceptions(
41
+ std::vector<std::pair<int, std::exception_ptr>>& exceptions);
41
42
 
42
43
  /** bare-bones unique_ptr
43
44
  * this one deletes with delete [] */
44
- template<class T>
45
+ template <class T>
45
46
  struct ScopeDeleter {
46
- const T * ptr;
47
- explicit ScopeDeleter (const T* ptr = nullptr): ptr (ptr) {}
48
- void release () {ptr = nullptr; }
49
- void set (const T * ptr_in) { ptr = ptr_in; }
50
- void swap (ScopeDeleter<T> &other) {std::swap (ptr, other.ptr); }
51
- ~ScopeDeleter () {
52
- delete [] ptr;
47
+ const T* ptr;
48
+ explicit ScopeDeleter(const T* ptr = nullptr) : ptr(ptr) {}
49
+ void release() {
50
+ ptr = nullptr;
51
+ }
52
+ void set(const T* ptr_in) {
53
+ ptr = ptr_in;
54
+ }
55
+ void swap(ScopeDeleter<T>& other) {
56
+ std::swap(ptr, other.ptr);
57
+ }
58
+ ~ScopeDeleter() {
59
+ delete[] ptr;
53
60
  }
54
61
  };
55
62
 
56
63
  /** same but deletes with the simple delete (least common case) */
57
- template<class T>
64
+ template <class T>
58
65
  struct ScopeDeleter1 {
59
- const T * ptr;
60
- explicit ScopeDeleter1 (const T* ptr = nullptr): ptr (ptr) {}
61
- void release () {ptr = nullptr; }
62
- void set (const T * ptr_in) { ptr = ptr_in; }
63
- void swap (ScopeDeleter1<T> &other) {std::swap (ptr, other.ptr); }
64
- ~ScopeDeleter1 () {
66
+ const T* ptr;
67
+ explicit ScopeDeleter1(const T* ptr = nullptr) : ptr(ptr) {}
68
+ void release() {
69
+ ptr = nullptr;
70
+ }
71
+ void set(const T* ptr_in) {
72
+ ptr = ptr_in;
73
+ }
74
+ void swap(ScopeDeleter1<T>& other) {
75
+ std::swap(ptr, other.ptr);
76
+ }
77
+ ~ScopeDeleter1() {
65
78
  delete ptr;
66
79
  }
67
80
  };
@@ -69,7 +82,6 @@ struct ScopeDeleter1 {
69
82
  /// make typeids more readable
70
83
  std::string demangle_cpp_symbol(const char* name);
71
84
 
72
-
73
- }
85
+ } // namespace faiss
74
86
 
75
87
  #endif
@@ -15,275 +15,254 @@
15
15
 
16
16
  namespace faiss {
17
17
 
18
-
19
18
  /**************************************************************
20
19
  * HNSW structure implementation
21
20
  **************************************************************/
22
21
 
23
- int HNSW::nb_neighbors(int layer_no) const
24
- {
25
- return cum_nneighbor_per_level[layer_no + 1] -
26
- cum_nneighbor_per_level[layer_no];
22
+ int HNSW::nb_neighbors(int layer_no) const {
23
+ return cum_nneighbor_per_level[layer_no + 1] -
24
+ cum_nneighbor_per_level[layer_no];
27
25
  }
28
26
 
29
- void HNSW::set_nb_neighbors(int level_no, int n)
30
- {
31
- FAISS_THROW_IF_NOT(levels.size() == 0);
32
- int cur_n = nb_neighbors(level_no);
33
- for (int i = level_no + 1; i < cum_nneighbor_per_level.size(); i++) {
34
- cum_nneighbor_per_level[i] += n - cur_n;
35
- }
27
+ void HNSW::set_nb_neighbors(int level_no, int n) {
28
+ FAISS_THROW_IF_NOT(levels.size() == 0);
29
+ int cur_n = nb_neighbors(level_no);
30
+ for (int i = level_no + 1; i < cum_nneighbor_per_level.size(); i++) {
31
+ cum_nneighbor_per_level[i] += n - cur_n;
32
+ }
36
33
  }
37
34
 
38
- int HNSW::cum_nb_neighbors(int layer_no) const
39
- {
40
- return cum_nneighbor_per_level[layer_no];
35
+ int HNSW::cum_nb_neighbors(int layer_no) const {
36
+ return cum_nneighbor_per_level[layer_no];
41
37
  }
42
38
 
43
- void HNSW::neighbor_range(idx_t no, int layer_no,
44
- size_t * begin, size_t * end) const
45
- {
46
- size_t o = offsets[no];
47
- *begin = o + cum_nb_neighbors(layer_no);
48
- *end = o + cum_nb_neighbors(layer_no + 1);
39
+ void HNSW::neighbor_range(idx_t no, int layer_no, size_t* begin, size_t* end)
40
+ const {
41
+ size_t o = offsets[no];
42
+ *begin = o + cum_nb_neighbors(layer_no);
43
+ *end = o + cum_nb_neighbors(layer_no + 1);
49
44
  }
50
45
 
51
-
52
-
53
46
  HNSW::HNSW(int M) : rng(12345) {
54
- set_default_probas(M, 1.0 / log(M));
55
- max_level = -1;
56
- entry_point = -1;
57
- efSearch = 16;
58
- efConstruction = 40;
59
- upper_beam = 1;
60
- offsets.push_back(0);
47
+ set_default_probas(M, 1.0 / log(M));
48
+ max_level = -1;
49
+ entry_point = -1;
50
+ efSearch = 16;
51
+ efConstruction = 40;
52
+ upper_beam = 1;
53
+ offsets.push_back(0);
61
54
  }
62
55
 
63
-
64
- int HNSW::random_level()
65
- {
66
- double f = rng.rand_float();
67
- // could be a bit faster with bissection
68
- for (int level = 0; level < assign_probas.size(); level++) {
69
- if (f < assign_probas[level]) {
70
- return level;
56
+ int HNSW::random_level() {
57
+ double f = rng.rand_float();
58
+ // could be a bit faster with bissection
59
+ for (int level = 0; level < assign_probas.size(); level++) {
60
+ if (f < assign_probas[level]) {
61
+ return level;
62
+ }
63
+ f -= assign_probas[level];
71
64
  }
72
- f -= assign_probas[level];
73
- }
74
- // happens with exponentially low probability
75
- return assign_probas.size() - 1;
65
+ // happens with exponentially low probability
66
+ return assign_probas.size() - 1;
76
67
  }
77
68
 
78
- void HNSW::set_default_probas(int M, float levelMult)
79
- {
80
- int nn = 0;
81
- cum_nneighbor_per_level.push_back (0);
82
- for (int level = 0; ;level++) {
83
- float proba = exp(-level / levelMult) * (1 - exp(-1 / levelMult));
84
- if (proba < 1e-9) break;
85
- assign_probas.push_back(proba);
86
- nn += level == 0 ? M * 2 : M;
87
- cum_nneighbor_per_level.push_back (nn);
88
- }
69
+ void HNSW::set_default_probas(int M, float levelMult) {
70
+ int nn = 0;
71
+ cum_nneighbor_per_level.push_back(0);
72
+ for (int level = 0;; level++) {
73
+ float proba = exp(-level / levelMult) * (1 - exp(-1 / levelMult));
74
+ if (proba < 1e-9)
75
+ break;
76
+ assign_probas.push_back(proba);
77
+ nn += level == 0 ? M * 2 : M;
78
+ cum_nneighbor_per_level.push_back(nn);
79
+ }
89
80
  }
90
81
 
91
- void HNSW::clear_neighbor_tables(int level)
92
- {
93
- for (int i = 0; i < levels.size(); i++) {
94
- size_t begin, end;
95
- neighbor_range(i, level, &begin, &end);
96
- for (size_t j = begin; j < end; j++) {
97
- neighbors[j] = -1;
82
+ void HNSW::clear_neighbor_tables(int level) {
83
+ for (int i = 0; i < levels.size(); i++) {
84
+ size_t begin, end;
85
+ neighbor_range(i, level, &begin, &end);
86
+ for (size_t j = begin; j < end; j++) {
87
+ neighbors[j] = -1;
88
+ }
98
89
  }
99
- }
100
90
  }
101
91
 
102
-
103
92
  void HNSW::reset() {
104
- max_level = -1;
105
- entry_point = -1;
106
- offsets.clear();
107
- offsets.push_back(0);
108
- levels.clear();
109
- neighbors.clear();
93
+ max_level = -1;
94
+ entry_point = -1;
95
+ offsets.clear();
96
+ offsets.push_back(0);
97
+ levels.clear();
98
+ neighbors.clear();
110
99
  }
111
100
 
112
-
113
-
114
- void HNSW::print_neighbor_stats(int level) const
115
- {
116
- FAISS_THROW_IF_NOT (level < cum_nneighbor_per_level.size());
117
- printf("stats on level %d, max %d neighbors per vertex:\n",
118
- level, nb_neighbors(level));
119
- size_t tot_neigh = 0, tot_common = 0, tot_reciprocal = 0, n_node = 0;
101
+ void HNSW::print_neighbor_stats(int level) const {
102
+ FAISS_THROW_IF_NOT(level < cum_nneighbor_per_level.size());
103
+ printf("stats on level %d, max %d neighbors per vertex:\n",
104
+ level,
105
+ nb_neighbors(level));
106
+ size_t tot_neigh = 0, tot_common = 0, tot_reciprocal = 0, n_node = 0;
120
107
  #pragma omp parallel for reduction(+: tot_neigh) reduction(+: tot_common) \
121
108
  reduction(+: tot_reciprocal) reduction(+: n_node)
122
- for (int i = 0; i < levels.size(); i++) {
123
- if (levels[i] > level) {
124
- n_node++;
125
- size_t begin, end;
126
- neighbor_range(i, level, &begin, &end);
127
- std::unordered_set<int> neighset;
128
- for (size_t j = begin; j < end; j++) {
129
- if (neighbors [j] < 0) break;
130
- neighset.insert(neighbors[j]);
131
- }
132
- int n_neigh = neighset.size();
133
- int n_common = 0;
134
- int n_reciprocal = 0;
135
- for (size_t j = begin; j < end; j++) {
136
- storage_idx_t i2 = neighbors[j];
137
- if (i2 < 0) break;
138
- FAISS_ASSERT(i2 != i);
139
- size_t begin2, end2;
140
- neighbor_range(i2, level, &begin2, &end2);
141
- for (size_t j2 = begin2; j2 < end2; j2++) {
142
- storage_idx_t i3 = neighbors[j2];
143
- if (i3 < 0) break;
144
- if (i3 == i) {
145
- n_reciprocal++;
146
- continue;
147
- }
148
- if (neighset.count(i3)) {
149
- neighset.erase(i3);
150
- n_common++;
151
- }
109
+ for (int i = 0; i < levels.size(); i++) {
110
+ if (levels[i] > level) {
111
+ n_node++;
112
+ size_t begin, end;
113
+ neighbor_range(i, level, &begin, &end);
114
+ std::unordered_set<int> neighset;
115
+ for (size_t j = begin; j < end; j++) {
116
+ if (neighbors[j] < 0)
117
+ break;
118
+ neighset.insert(neighbors[j]);
119
+ }
120
+ int n_neigh = neighset.size();
121
+ int n_common = 0;
122
+ int n_reciprocal = 0;
123
+ for (size_t j = begin; j < end; j++) {
124
+ storage_idx_t i2 = neighbors[j];
125
+ if (i2 < 0)
126
+ break;
127
+ FAISS_ASSERT(i2 != i);
128
+ size_t begin2, end2;
129
+ neighbor_range(i2, level, &begin2, &end2);
130
+ for (size_t j2 = begin2; j2 < end2; j2++) {
131
+ storage_idx_t i3 = neighbors[j2];
132
+ if (i3 < 0)
133
+ break;
134
+ if (i3 == i) {
135
+ n_reciprocal++;
136
+ continue;
137
+ }
138
+ if (neighset.count(i3)) {
139
+ neighset.erase(i3);
140
+ n_common++;
141
+ }
142
+ }
143
+ }
144
+ tot_neigh += n_neigh;
145
+ tot_common += n_common;
146
+ tot_reciprocal += n_reciprocal;
152
147
  }
153
- }
154
- tot_neigh += n_neigh;
155
- tot_common += n_common;
156
- tot_reciprocal += n_reciprocal;
157
148
  }
158
- }
159
- float normalizer = n_node;
160
- printf(" nb of nodes at that level %zd\n", n_node);
161
- printf(" neighbors per node: %.2f (%zd)\n",
162
- tot_neigh / normalizer, tot_neigh);
163
- printf(" nb of reciprocal neighbors: %.2f\n", tot_reciprocal / normalizer);
164
- printf(" nb of neighbors that are also neighbor-of-neighbors: %.2f (%zd)\n",
165
- tot_common / normalizer, tot_common);
166
-
167
-
168
-
149
+ float normalizer = n_node;
150
+ printf(" nb of nodes at that level %zd\n", n_node);
151
+ printf(" neighbors per node: %.2f (%zd)\n",
152
+ tot_neigh / normalizer,
153
+ tot_neigh);
154
+ printf(" nb of reciprocal neighbors: %.2f\n",
155
+ tot_reciprocal / normalizer);
156
+ printf(" nb of neighbors that are also neighbor-of-neighbors: %.2f (%zd)\n",
157
+ tot_common / normalizer,
158
+ tot_common);
169
159
  }
170
160
 
161
+ void HNSW::fill_with_random_links(size_t n) {
162
+ int max_level = prepare_level_tab(n);
163
+ RandomGenerator rng2(456);
171
164
 
172
- void HNSW::fill_with_random_links(size_t n)
173
- {
174
- int max_level = prepare_level_tab(n);
175
- RandomGenerator rng2(456);
176
-
177
- for (int level = max_level - 1; level >= 0; --level) {
178
- std::vector<int> elts;
179
- for (int i = 0; i < n; i++) {
180
- if (levels[i] > level) {
181
- elts.push_back(i);
182
- }
183
- }
184
- printf ("linking %zd elements in level %d\n",
185
- elts.size(), level);
186
-
187
- if (elts.size() == 1) continue;
165
+ for (int level = max_level - 1; level >= 0; --level) {
166
+ std::vector<int> elts;
167
+ for (int i = 0; i < n; i++) {
168
+ if (levels[i] > level) {
169
+ elts.push_back(i);
170
+ }
171
+ }
172
+ printf("linking %zd elements in level %d\n", elts.size(), level);
188
173
 
189
- for (int ii = 0; ii < elts.size(); ii++) {
190
- int i = elts[ii];
191
- size_t begin, end;
192
- neighbor_range(i, 0, &begin, &end);
193
- for (size_t j = begin; j < end; j++) {
194
- int other = 0;
195
- do {
196
- other = elts[rng2.rand_int(elts.size())];
197
- } while(other == i);
174
+ if (elts.size() == 1)
175
+ continue;
198
176
 
199
- neighbors[j] = other;
200
- }
177
+ for (int ii = 0; ii < elts.size(); ii++) {
178
+ int i = elts[ii];
179
+ size_t begin, end;
180
+ neighbor_range(i, 0, &begin, &end);
181
+ for (size_t j = begin; j < end; j++) {
182
+ int other = 0;
183
+ do {
184
+ other = elts[rng2.rand_int(elts.size())];
185
+ } while (other == i);
186
+
187
+ neighbors[j] = other;
188
+ }
189
+ }
201
190
  }
202
- }
203
191
  }
204
192
 
193
+ int HNSW::prepare_level_tab(size_t n, bool preset_levels) {
194
+ size_t n0 = offsets.size() - 1;
205
195
 
206
- int HNSW::prepare_level_tab(size_t n, bool preset_levels)
207
- {
208
- size_t n0 = offsets.size() - 1;
196
+ if (preset_levels) {
197
+ FAISS_ASSERT(n0 + n == levels.size());
198
+ } else {
199
+ FAISS_ASSERT(n0 == levels.size());
200
+ for (int i = 0; i < n; i++) {
201
+ int pt_level = random_level();
202
+ levels.push_back(pt_level + 1);
203
+ }
204
+ }
209
205
 
210
- if (preset_levels) {
211
- FAISS_ASSERT (n0 + n == levels.size());
212
- } else {
213
- FAISS_ASSERT (n0 == levels.size());
206
+ int max_level = 0;
214
207
  for (int i = 0; i < n; i++) {
215
- int pt_level = random_level();
216
- levels.push_back(pt_level + 1);
208
+ int pt_level = levels[i + n0] - 1;
209
+ if (pt_level > max_level)
210
+ max_level = pt_level;
211
+ offsets.push_back(offsets.back() + cum_nb_neighbors(pt_level + 1));
212
+ neighbors.resize(offsets.back(), -1);
217
213
  }
218
- }
219
214
 
220
- int max_level = 0;
221
- for (int i = 0; i < n; i++) {
222
- int pt_level = levels[i + n0] - 1;
223
- if (pt_level > max_level) max_level = pt_level;
224
- offsets.push_back(offsets.back() +
225
- cum_nb_neighbors(pt_level + 1));
226
- neighbors.resize(offsets.back(), -1);
227
- }
228
-
229
- return max_level;
215
+ return max_level;
230
216
  }
231
217
 
232
-
233
218
  /** Enumerate vertices from farthest to nearest from query, keep a
234
219
  * neighbor only if there is no previous neighbor that is closer to
235
220
  * that vertex than the query.
236
221
  */
237
222
  void HNSW::shrink_neighbor_list(
238
- DistanceComputer& qdis,
239
- std::priority_queue<NodeDistFarther>& input,
240
- std::vector<NodeDistFarther>& output,
241
- int max_size)
242
- {
243
- while (input.size() > 0) {
244
- NodeDistFarther v1 = input.top();
245
- input.pop();
246
- float dist_v1_q = v1.d;
247
-
248
- bool good = true;
249
- for (NodeDistFarther v2 : output) {
250
- float dist_v1_v2 = qdis.symmetric_dis(v2.id, v1.id);
251
-
252
- if (dist_v1_v2 < dist_v1_q) {
253
- good = false;
254
- break;
255
- }
256
- }
257
-
258
- if (good) {
259
- output.push_back(v1);
260
- if (output.size() >= max_size) {
261
- return;
262
- }
223
+ DistanceComputer& qdis,
224
+ std::priority_queue<NodeDistFarther>& input,
225
+ std::vector<NodeDistFarther>& output,
226
+ int max_size) {
227
+ while (input.size() > 0) {
228
+ NodeDistFarther v1 = input.top();
229
+ input.pop();
230
+ float dist_v1_q = v1.d;
231
+
232
+ bool good = true;
233
+ for (NodeDistFarther v2 : output) {
234
+ float dist_v1_v2 = qdis.symmetric_dis(v2.id, v1.id);
235
+
236
+ if (dist_v1_v2 < dist_v1_q) {
237
+ good = false;
238
+ break;
239
+ }
240
+ }
241
+
242
+ if (good) {
243
+ output.push_back(v1);
244
+ if (output.size() >= max_size) {
245
+ return;
246
+ }
247
+ }
263
248
  }
264
- }
265
249
  }
266
250
 
267
-
268
251
  namespace {
269
252
 
270
-
271
253
  using storage_idx_t = HNSW::storage_idx_t;
272
254
  using NodeDistCloser = HNSW::NodeDistCloser;
273
255
  using NodeDistFarther = HNSW::NodeDistFarther;
274
256
 
275
-
276
257
  /**************************************************************
277
258
  * Addition subroutines
278
259
  **************************************************************/
279
260
 
280
-
281
261
  /// remove neighbors from the list to make it smaller than max_size
282
262
  void shrink_neighbor_list(
283
- DistanceComputer& qdis,
284
- std::priority_queue<NodeDistCloser>& resultSet1,
285
- int max_size)
286
- {
263
+ DistanceComputer& qdis,
264
+ std::priority_queue<NodeDistCloser>& resultSet1,
265
+ int max_size) {
287
266
  if (resultSet1.size() < max_size) {
288
267
  return;
289
268
  }
@@ -300,516 +279,521 @@ void shrink_neighbor_list(
300
279
  for (NodeDistFarther curen2 : returnlist) {
301
280
  resultSet1.emplace(curen2.d, curen2.id);
302
281
  }
303
-
304
282
  }
305
283
 
306
-
307
284
  /// add a link between two elements, possibly shrinking the list
308
285
  /// of links to make room for it.
309
- void add_link(HNSW& hnsw,
310
- DistanceComputer& qdis,
311
- storage_idx_t src, storage_idx_t dest,
312
- int level)
313
- {
314
- size_t begin, end;
315
- hnsw.neighbor_range(src, level, &begin, &end);
316
- if (hnsw.neighbors[end - 1] == -1) {
317
- // there is enough room, find a slot to add it
318
- size_t i = end;
319
- while(i > begin) {
320
- if (hnsw.neighbors[i - 1] != -1) break;
321
- i--;
322
- }
323
- hnsw.neighbors[i] = dest;
324
- return;
325
- }
326
-
327
- // otherwise we let them fight out which to keep
328
-
329
- // copy to resultSet...
330
- std::priority_queue<NodeDistCloser> resultSet;
331
- resultSet.emplace(qdis.symmetric_dis(src, dest), dest);
332
- for (size_t i = begin; i < end; i++) { // HERE WAS THE BUG
333
- storage_idx_t neigh = hnsw.neighbors[i];
334
- resultSet.emplace(qdis.symmetric_dis(src, neigh), neigh);
335
- }
336
-
337
- shrink_neighbor_list(qdis, resultSet, end - begin);
338
-
339
- // ...and back
340
- size_t i = begin;
341
- while (resultSet.size()) {
342
- hnsw.neighbors[i++] = resultSet.top().id;
343
- resultSet.pop();
344
- }
345
- // they may have shrunk more than just by 1 element
346
- while(i < end) {
347
- hnsw.neighbors[i++] = -1;
348
- }
286
+ void add_link(
287
+ HNSW& hnsw,
288
+ DistanceComputer& qdis,
289
+ storage_idx_t src,
290
+ storage_idx_t dest,
291
+ int level) {
292
+ size_t begin, end;
293
+ hnsw.neighbor_range(src, level, &begin, &end);
294
+ if (hnsw.neighbors[end - 1] == -1) {
295
+ // there is enough room, find a slot to add it
296
+ size_t i = end;
297
+ while (i > begin) {
298
+ if (hnsw.neighbors[i - 1] != -1)
299
+ break;
300
+ i--;
301
+ }
302
+ hnsw.neighbors[i] = dest;
303
+ return;
304
+ }
305
+
306
+ // otherwise we let them fight out which to keep
307
+
308
+ // copy to resultSet...
309
+ std::priority_queue<NodeDistCloser> resultSet;
310
+ resultSet.emplace(qdis.symmetric_dis(src, dest), dest);
311
+ for (size_t i = begin; i < end; i++) { // HERE WAS THE BUG
312
+ storage_idx_t neigh = hnsw.neighbors[i];
313
+ resultSet.emplace(qdis.symmetric_dis(src, neigh), neigh);
314
+ }
315
+
316
+ shrink_neighbor_list(qdis, resultSet, end - begin);
317
+
318
+ // ...and back
319
+ size_t i = begin;
320
+ while (resultSet.size()) {
321
+ hnsw.neighbors[i++] = resultSet.top().id;
322
+ resultSet.pop();
323
+ }
324
+ // they may have shrunk more than just by 1 element
325
+ while (i < end) {
326
+ hnsw.neighbors[i++] = -1;
327
+ }
349
328
  }
350
329
 
351
330
  /// search neighbors on a single level, starting from an entry point
352
331
  void search_neighbors_to_add(
353
- HNSW& hnsw,
354
- DistanceComputer& qdis,
355
- std::priority_queue<NodeDistCloser>& results,
356
- int entry_point,
357
- float d_entry_point,
358
- int level,
359
- VisitedTable &vt)
360
- {
361
- // top is nearest candidate
362
- std::priority_queue<NodeDistFarther> candidates;
363
-
364
- NodeDistFarther ev(d_entry_point, entry_point);
365
- candidates.push(ev);
366
- results.emplace(d_entry_point, entry_point);
367
- vt.set(entry_point);
368
-
369
- while (!candidates.empty()) {
370
- // get nearest
371
- const NodeDistFarther &currEv = candidates.top();
372
-
373
- if (currEv.d > results.top().d) {
374
- break;
375
- }
376
- int currNode = currEv.id;
377
- candidates.pop();
378
-
379
- // loop over neighbors
380
- size_t begin, end;
381
- hnsw.neighbor_range(currNode, level, &begin, &end);
382
- for(size_t i = begin; i < end; i++) {
383
- storage_idx_t nodeId = hnsw.neighbors[i];
384
- if (nodeId < 0) break;
385
- if (vt.get(nodeId)) continue;
386
- vt.set(nodeId);
387
-
388
- float dis = qdis(nodeId);
389
- NodeDistFarther evE1(dis, nodeId);
390
-
391
- if (results.size() < hnsw.efConstruction ||
392
- results.top().d > dis) {
393
-
394
- results.emplace(dis, nodeId);
395
- candidates.emplace(dis, nodeId);
396
- if (results.size() > hnsw.efConstruction) {
397
- results.pop();
332
+ HNSW& hnsw,
333
+ DistanceComputer& qdis,
334
+ std::priority_queue<NodeDistCloser>& results,
335
+ int entry_point,
336
+ float d_entry_point,
337
+ int level,
338
+ VisitedTable& vt) {
339
+ // top is nearest candidate
340
+ std::priority_queue<NodeDistFarther> candidates;
341
+
342
+ NodeDistFarther ev(d_entry_point, entry_point);
343
+ candidates.push(ev);
344
+ results.emplace(d_entry_point, entry_point);
345
+ vt.set(entry_point);
346
+
347
+ while (!candidates.empty()) {
348
+ // get nearest
349
+ const NodeDistFarther& currEv = candidates.top();
350
+
351
+ if (currEv.d > results.top().d) {
352
+ break;
353
+ }
354
+ int currNode = currEv.id;
355
+ candidates.pop();
356
+
357
+ // loop over neighbors
358
+ size_t begin, end;
359
+ hnsw.neighbor_range(currNode, level, &begin, &end);
360
+ for (size_t i = begin; i < end; i++) {
361
+ storage_idx_t nodeId = hnsw.neighbors[i];
362
+ if (nodeId < 0)
363
+ break;
364
+ if (vt.get(nodeId))
365
+ continue;
366
+ vt.set(nodeId);
367
+
368
+ float dis = qdis(nodeId);
369
+ NodeDistFarther evE1(dis, nodeId);
370
+
371
+ if (results.size() < hnsw.efConstruction || results.top().d > dis) {
372
+ results.emplace(dis, nodeId);
373
+ candidates.emplace(dis, nodeId);
374
+ if (results.size() > hnsw.efConstruction) {
375
+ results.pop();
376
+ }
377
+ }
398
378
  }
399
- }
400
379
  }
401
- }
402
- vt.advance();
380
+ vt.advance();
403
381
  }
404
382
 
405
-
406
383
  /**************************************************************
407
384
  * Searching subroutines
408
385
  **************************************************************/
409
386
 
410
387
  /// greedily update a nearest vector at a given level
411
- void greedy_update_nearest(const HNSW& hnsw,
412
- DistanceComputer& qdis,
413
- int level,
414
- storage_idx_t& nearest,
415
- float& d_nearest)
416
- {
417
- for(;;) {
418
- storage_idx_t prev_nearest = nearest;
419
-
420
- size_t begin, end;
421
- hnsw.neighbor_range(nearest, level, &begin, &end);
422
- for(size_t i = begin; i < end; i++) {
423
- storage_idx_t v = hnsw.neighbors[i];
424
- if (v < 0) break;
425
- float dis = qdis(v);
426
- if (dis < d_nearest) {
427
- nearest = v;
428
- d_nearest = dis;
429
- }
430
- }
431
- if (nearest == prev_nearest) {
432
- return;
433
- }
434
- }
388
+ void greedy_update_nearest(
389
+ const HNSW& hnsw,
390
+ DistanceComputer& qdis,
391
+ int level,
392
+ storage_idx_t& nearest,
393
+ float& d_nearest) {
394
+ for (;;) {
395
+ storage_idx_t prev_nearest = nearest;
396
+
397
+ size_t begin, end;
398
+ hnsw.neighbor_range(nearest, level, &begin, &end);
399
+ for (size_t i = begin; i < end; i++) {
400
+ storage_idx_t v = hnsw.neighbors[i];
401
+ if (v < 0)
402
+ break;
403
+ float dis = qdis(v);
404
+ if (dis < d_nearest) {
405
+ nearest = v;
406
+ d_nearest = dis;
407
+ }
408
+ }
409
+ if (nearest == prev_nearest) {
410
+ return;
411
+ }
412
+ }
435
413
  }
436
414
 
437
-
438
- } // namespace
439
-
415
+ } // namespace
440
416
 
441
417
  /// Finds neighbors and builds links with them, starting from an entry
442
418
  /// point. The own neighbor list is assumed to be locked.
443
- void HNSW::add_links_starting_from(DistanceComputer& ptdis,
444
- storage_idx_t pt_id,
445
- storage_idx_t nearest,
446
- float d_nearest,
447
- int level,
448
- omp_lock_t *locks,
449
- VisitedTable &vt)
450
- {
451
- std::priority_queue<NodeDistCloser> link_targets;
419
+ void HNSW::add_links_starting_from(
420
+ DistanceComputer& ptdis,
421
+ storage_idx_t pt_id,
422
+ storage_idx_t nearest,
423
+ float d_nearest,
424
+ int level,
425
+ omp_lock_t* locks,
426
+ VisitedTable& vt) {
427
+ std::priority_queue<NodeDistCloser> link_targets;
452
428
 
453
- search_neighbors_to_add(*this, ptdis, link_targets, nearest, d_nearest,
454
- level, vt);
429
+ search_neighbors_to_add(
430
+ *this, ptdis, link_targets, nearest, d_nearest, level, vt);
455
431
 
456
- // but we can afford only this many neighbors
457
- int M = nb_neighbors(level);
432
+ // but we can afford only this many neighbors
433
+ int M = nb_neighbors(level);
458
434
 
459
- ::faiss::shrink_neighbor_list(ptdis, link_targets, M);
435
+ ::faiss::shrink_neighbor_list(ptdis, link_targets, M);
460
436
 
461
- while (!link_targets.empty()) {
462
- int other_id = link_targets.top().id;
437
+ while (!link_targets.empty()) {
438
+ int other_id = link_targets.top().id;
463
439
 
464
- omp_set_lock(&locks[other_id]);
465
- add_link(*this, ptdis, other_id, pt_id, level);
466
- omp_unset_lock(&locks[other_id]);
440
+ omp_set_lock(&locks[other_id]);
441
+ add_link(*this, ptdis, other_id, pt_id, level);
442
+ omp_unset_lock(&locks[other_id]);
467
443
 
468
- add_link(*this, ptdis, pt_id, other_id, level);
444
+ add_link(*this, ptdis, pt_id, other_id, level);
469
445
 
470
- link_targets.pop();
471
- }
446
+ link_targets.pop();
447
+ }
472
448
  }
473
449
 
474
-
475
450
  /**************************************************************
476
451
  * Building, parallel
477
452
  **************************************************************/
478
453
 
479
- void HNSW::add_with_locks(DistanceComputer& ptdis, int pt_level, int pt_id,
480
- std::vector<omp_lock_t>& locks,
481
- VisitedTable& vt)
482
- {
483
- // greedy search on upper levels
454
+ void HNSW::add_with_locks(
455
+ DistanceComputer& ptdis,
456
+ int pt_level,
457
+ int pt_id,
458
+ std::vector<omp_lock_t>& locks,
459
+ VisitedTable& vt) {
460
+ // greedy search on upper levels
484
461
 
485
- storage_idx_t nearest;
462
+ storage_idx_t nearest;
486
463
  #pragma omp critical
487
- {
488
- nearest = entry_point;
464
+ {
465
+ nearest = entry_point;
489
466
 
490
- if (nearest == -1) {
491
- max_level = pt_level;
492
- entry_point = pt_id;
467
+ if (nearest == -1) {
468
+ max_level = pt_level;
469
+ entry_point = pt_id;
470
+ }
493
471
  }
494
- }
495
472
 
496
- if (nearest < 0) {
497
- return;
498
- }
473
+ if (nearest < 0) {
474
+ return;
475
+ }
499
476
 
500
- omp_set_lock(&locks[pt_id]);
477
+ omp_set_lock(&locks[pt_id]);
501
478
 
502
- int level = max_level; // level at which we start adding neighbors
503
- float d_nearest = ptdis(nearest);
479
+ int level = max_level; // level at which we start adding neighbors
480
+ float d_nearest = ptdis(nearest);
504
481
 
505
- for(; level > pt_level; level--) {
506
- greedy_update_nearest(*this, ptdis, level, nearest, d_nearest);
507
- }
482
+ for (; level > pt_level; level--) {
483
+ greedy_update_nearest(*this, ptdis, level, nearest, d_nearest);
484
+ }
508
485
 
509
- for(; level >= 0; level--) {
510
- add_links_starting_from(ptdis, pt_id, nearest, d_nearest,
511
- level, locks.data(), vt);
512
- }
486
+ for (; level >= 0; level--) {
487
+ add_links_starting_from(
488
+ ptdis, pt_id, nearest, d_nearest, level, locks.data(), vt);
489
+ }
513
490
 
514
- omp_unset_lock(&locks[pt_id]);
491
+ omp_unset_lock(&locks[pt_id]);
515
492
 
516
- if (pt_level > max_level) {
517
- max_level = pt_level;
518
- entry_point = pt_id;
519
- }
493
+ if (pt_level > max_level) {
494
+ max_level = pt_level;
495
+ entry_point = pt_id;
496
+ }
520
497
  }
521
498
 
522
-
523
499
  /** Do a BFS on the candidates list */
524
500
 
525
501
  int HNSW::search_from_candidates(
526
- DistanceComputer& qdis, int k,
527
- idx_t *I, float *D,
528
- MinimaxHeap& candidates,
529
- VisitedTable& vt,
530
- HNSWStats& stats,
531
- int level, int nres_in) const
532
- {
533
- int nres = nres_in;
534
- int ndis = 0;
535
- for (int i = 0; i < candidates.size(); i++) {
536
- idx_t v1 = candidates.ids[i];
537
- float d = candidates.dis[i];
538
- FAISS_ASSERT(v1 >= 0);
539
- if (nres < k) {
540
- faiss::maxheap_push(++nres, D, I, d, v1);
541
- } else if (d < D[0]) {
542
- faiss::maxheap_replace_top(nres, D, I, d, v1);
543
- }
544
- vt.set(v1);
545
- }
546
-
547
- bool do_dis_check = check_relative_distance;
548
- int nstep = 0;
549
-
550
- while (candidates.size() > 0) {
551
- float d0 = 0;
552
- int v0 = candidates.pop_min(&d0);
553
-
554
- if (do_dis_check) {
555
- // tricky stopping condition: there are more that ef
556
- // distances that are processed already that are smaller
557
- // than d0
558
-
559
- int n_dis_below = candidates.count_below(d0);
560
- if(n_dis_below >= efSearch) {
561
- break;
562
- }
563
- }
564
-
565
- size_t begin, end;
566
- neighbor_range(v0, level, &begin, &end);
567
-
568
- for (size_t j = begin; j < end; j++) {
569
- int v1 = neighbors[j];
570
- if (v1 < 0) break;
571
- if (vt.get(v1)) {
572
- continue;
573
- }
574
- vt.set(v1);
575
- ndis++;
576
- float d = qdis(v1);
577
- if (nres < k) {
578
- faiss::maxheap_push(++nres, D, I, d, v1);
579
- } else if (d < D[0]) {
580
- faiss::maxheap_replace_top(nres, D, I, d, v1);
581
- }
582
- candidates.push(v1, d);
583
- }
584
-
585
- nstep++;
586
- if (!do_dis_check && nstep > efSearch) {
587
- break;
588
- }
589
- }
590
-
591
- if (level == 0) {
592
- stats.n1 ++;
593
- if (candidates.size() == 0) {
594
- stats.n2 ++;
502
+ DistanceComputer& qdis,
503
+ int k,
504
+ idx_t* I,
505
+ float* D,
506
+ MinimaxHeap& candidates,
507
+ VisitedTable& vt,
508
+ HNSWStats& stats,
509
+ int level,
510
+ int nres_in) const {
511
+ int nres = nres_in;
512
+ int ndis = 0;
513
+ for (int i = 0; i < candidates.size(); i++) {
514
+ idx_t v1 = candidates.ids[i];
515
+ float d = candidates.dis[i];
516
+ FAISS_ASSERT(v1 >= 0);
517
+ if (nres < k) {
518
+ faiss::maxheap_push(++nres, D, I, d, v1);
519
+ } else if (d < D[0]) {
520
+ faiss::maxheap_replace_top(nres, D, I, d, v1);
521
+ }
522
+ vt.set(v1);
595
523
  }
596
- stats.n3 += ndis;
597
- }
598
-
599
- return nres;
600
- }
601
-
602
524
 
603
- /**************************************************************
604
- * Searching
605
- **************************************************************/
525
+ bool do_dis_check = check_relative_distance;
526
+ int nstep = 0;
606
527
 
607
- std::priority_queue<HNSW::Node> HNSW::search_from_candidate_unbounded(
608
- const Node& node,
609
- DistanceComputer& qdis,
610
- int ef,
611
- VisitedTable *vt,
612
- HNSWStats& stats) const
613
- {
614
- int ndis = 0;
615
- std::priority_queue<Node> top_candidates;
616
- std::priority_queue<Node, std::vector<Node>, std::greater<Node>> candidates;
528
+ while (candidates.size() > 0) {
529
+ float d0 = 0;
530
+ int v0 = candidates.pop_min(&d0);
617
531
 
618
- top_candidates.push(node);
619
- candidates.push(node);
532
+ if (do_dis_check) {
533
+ // tricky stopping condition: there are more that ef
534
+ // distances that are processed already that are smaller
535
+ // than d0
620
536
 
621
- vt->set(node.second);
537
+ int n_dis_below = candidates.count_below(d0);
538
+ if (n_dis_below >= efSearch) {
539
+ break;
540
+ }
541
+ }
622
542
 
623
- while (!candidates.empty()) {
624
- float d0;
625
- storage_idx_t v0;
626
- std::tie(d0, v0) = candidates.top();
543
+ size_t begin, end;
544
+ neighbor_range(v0, level, &begin, &end);
545
+
546
+ for (size_t j = begin; j < end; j++) {
547
+ int v1 = neighbors[j];
548
+ if (v1 < 0)
549
+ break;
550
+ if (vt.get(v1)) {
551
+ continue;
552
+ }
553
+ vt.set(v1);
554
+ ndis++;
555
+ float d = qdis(v1);
556
+ if (nres < k) {
557
+ faiss::maxheap_push(++nres, D, I, d, v1);
558
+ } else if (d < D[0]) {
559
+ faiss::maxheap_replace_top(nres, D, I, d, v1);
560
+ }
561
+ candidates.push(v1, d);
562
+ }
627
563
 
628
- if (d0 > top_candidates.top().first) {
629
- break;
564
+ nstep++;
565
+ if (!do_dis_check && nstep > efSearch) {
566
+ break;
567
+ }
630
568
  }
631
569
 
632
- candidates.pop();
633
-
634
- size_t begin, end;
635
- neighbor_range(v0, 0, &begin, &end);
636
-
637
- for (size_t j = begin; j < end; ++j) {
638
- int v1 = neighbors[j];
639
-
640
- if (v1 < 0) {
641
- break;
642
- }
643
- if (vt->get(v1)) {
644
- continue;
645
- }
646
-
647
- vt->set(v1);
648
-
649
- float d1 = qdis(v1);
650
- ++ndis;
651
-
652
- if (top_candidates.top().first > d1 || top_candidates.size() < ef) {
653
- candidates.emplace(d1, v1);
654
- top_candidates.emplace(d1, v1);
655
-
656
- if (top_candidates.size() > ef) {
657
- top_candidates.pop();
570
+ if (level == 0) {
571
+ stats.n1++;
572
+ if (candidates.size() == 0) {
573
+ stats.n2++;
658
574
  }
659
- }
575
+ stats.n3 += ndis;
660
576
  }
661
- }
662
-
663
- ++stats.n1;
664
- if (candidates.size() == 0) {
665
- ++stats.n2;
666
- }
667
- stats.n3 += ndis;
668
577
 
669
- return top_candidates;
578
+ return nres;
670
579
  }
671
580
 
672
- HNSWStats HNSW::search(DistanceComputer& qdis, int k,
673
- idx_t *I, float *D,
674
- VisitedTable& vt) const
675
- {
676
- HNSWStats stats;
677
-
678
- if (upper_beam == 1) {
679
-
680
- // greedy search on upper levels
681
- storage_idx_t nearest = entry_point;
682
- float d_nearest = qdis(nearest);
581
+ /**************************************************************
582
+ * Searching
583
+ **************************************************************/
683
584
 
684
- for(int level = max_level; level >= 1; level--) {
685
- greedy_update_nearest(*this, qdis, level, nearest, d_nearest);
686
- }
585
+ std::priority_queue<HNSW::Node> HNSW::search_from_candidate_unbounded(
586
+ const Node& node,
587
+ DistanceComputer& qdis,
588
+ int ef,
589
+ VisitedTable* vt,
590
+ HNSWStats& stats) const {
591
+ int ndis = 0;
592
+ std::priority_queue<Node> top_candidates;
593
+ std::priority_queue<Node, std::vector<Node>, std::greater<Node>> candidates;
594
+
595
+ top_candidates.push(node);
596
+ candidates.push(node);
597
+
598
+ vt->set(node.second);
599
+
600
+ while (!candidates.empty()) {
601
+ float d0;
602
+ storage_idx_t v0;
603
+ std::tie(d0, v0) = candidates.top();
604
+
605
+ if (d0 > top_candidates.top().first) {
606
+ break;
607
+ }
687
608
 
688
- int ef = std::max(efSearch, k);
689
- if (search_bounded_queue) {
690
- MinimaxHeap candidates(ef);
609
+ candidates.pop();
691
610
 
692
- candidates.push(nearest, d_nearest);
611
+ size_t begin, end;
612
+ neighbor_range(v0, 0, &begin, &end);
693
613
 
694
- search_from_candidates(qdis, k, I, D, candidates, vt, stats, 0);
695
- } else {
696
- std::priority_queue<Node> top_candidates =
697
- search_from_candidate_unbounded(Node(d_nearest, nearest),
698
- qdis, ef, &vt, stats);
614
+ for (size_t j = begin; j < end; ++j) {
615
+ int v1 = neighbors[j];
699
616
 
700
- while (top_candidates.size() > k) {
701
- top_candidates.pop();
702
- }
617
+ if (v1 < 0) {
618
+ break;
619
+ }
620
+ if (vt->get(v1)) {
621
+ continue;
622
+ }
703
623
 
704
- int nres = 0;
705
- while (!top_candidates.empty()) {
706
- float d;
707
- storage_idx_t label;
708
- std::tie(d, label) = top_candidates.top();
709
- faiss::maxheap_push(++nres, D, I, d, label);
710
- top_candidates.pop();
711
- }
712
- }
624
+ vt->set(v1);
713
625
 
714
- vt.advance();
626
+ float d1 = qdis(v1);
627
+ ++ndis;
715
628
 
716
- } else {
717
- int candidates_size = upper_beam;
718
- MinimaxHeap candidates(candidates_size);
629
+ if (top_candidates.top().first > d1 || top_candidates.size() < ef) {
630
+ candidates.emplace(d1, v1);
631
+ top_candidates.emplace(d1, v1);
719
632
 
720
- std::vector<idx_t> I_to_next(candidates_size);
721
- std::vector<float> D_to_next(candidates_size);
633
+ if (top_candidates.size() > ef) {
634
+ top_candidates.pop();
635
+ }
636
+ }
637
+ }
638
+ }
722
639
 
723
- int nres = 1;
724
- I_to_next[0] = entry_point;
725
- D_to_next[0] = qdis(entry_point);
640
+ ++stats.n1;
641
+ if (candidates.size() == 0) {
642
+ ++stats.n2;
643
+ }
644
+ stats.n3 += ndis;
726
645
 
727
- for(int level = max_level; level >= 0; level--) {
646
+ return top_candidates;
647
+ }
728
648
 
729
- // copy I, D -> candidates
649
+ HNSWStats HNSW::search(
650
+ DistanceComputer& qdis,
651
+ int k,
652
+ idx_t* I,
653
+ float* D,
654
+ VisitedTable& vt) const {
655
+ HNSWStats stats;
656
+
657
+ if (upper_beam == 1) {
658
+ // greedy search on upper levels
659
+ storage_idx_t nearest = entry_point;
660
+ float d_nearest = qdis(nearest);
661
+
662
+ for (int level = max_level; level >= 1; level--) {
663
+ greedy_update_nearest(*this, qdis, level, nearest, d_nearest);
664
+ }
730
665
 
731
- candidates.clear();
666
+ int ef = std::max(efSearch, k);
667
+ if (search_bounded_queue) {
668
+ MinimaxHeap candidates(ef);
669
+
670
+ candidates.push(nearest, d_nearest);
671
+
672
+ search_from_candidates(qdis, k, I, D, candidates, vt, stats, 0);
673
+ } else {
674
+ std::priority_queue<Node> top_candidates =
675
+ search_from_candidate_unbounded(
676
+ Node(d_nearest, nearest), qdis, ef, &vt, stats);
677
+
678
+ while (top_candidates.size() > k) {
679
+ top_candidates.pop();
680
+ }
681
+
682
+ int nres = 0;
683
+ while (!top_candidates.empty()) {
684
+ float d;
685
+ storage_idx_t label;
686
+ std::tie(d, label) = top_candidates.top();
687
+ faiss::maxheap_push(++nres, D, I, d, label);
688
+ top_candidates.pop();
689
+ }
690
+ }
732
691
 
733
- for (int i = 0; i < nres; i++) {
734
- candidates.push(I_to_next[i], D_to_next[i]);
735
- }
692
+ vt.advance();
736
693
 
737
- if (level == 0) {
738
- nres = search_from_candidates(qdis, k, I, D, candidates, vt, stats, 0);
739
- } else {
740
- nres = search_from_candidates(
741
- qdis, candidates_size,
742
- I_to_next.data(), D_to_next.data(),
743
- candidates, vt, stats, level
744
- );
745
- }
746
- vt.advance();
694
+ } else {
695
+ int candidates_size = upper_beam;
696
+ MinimaxHeap candidates(candidates_size);
697
+
698
+ std::vector<idx_t> I_to_next(candidates_size);
699
+ std::vector<float> D_to_next(candidates_size);
700
+
701
+ int nres = 1;
702
+ I_to_next[0] = entry_point;
703
+ D_to_next[0] = qdis(entry_point);
704
+
705
+ for (int level = max_level; level >= 0; level--) {
706
+ // copy I, D -> candidates
707
+
708
+ candidates.clear();
709
+
710
+ for (int i = 0; i < nres; i++) {
711
+ candidates.push(I_to_next[i], D_to_next[i]);
712
+ }
713
+
714
+ if (level == 0) {
715
+ nres = search_from_candidates(
716
+ qdis, k, I, D, candidates, vt, stats, 0);
717
+ } else {
718
+ nres = search_from_candidates(
719
+ qdis,
720
+ candidates_size,
721
+ I_to_next.data(),
722
+ D_to_next.data(),
723
+ candidates,
724
+ vt,
725
+ stats,
726
+ level);
727
+ }
728
+ vt.advance();
729
+ }
747
730
  }
748
- }
749
731
 
750
- return stats;
732
+ return stats;
751
733
  }
752
734
 
753
-
754
735
  void HNSW::MinimaxHeap::push(storage_idx_t i, float v) {
755
- if (k == n) {
756
- if (v >= dis[0]) return;
757
- faiss::heap_pop<HC> (k--, dis.data(), ids.data());
758
- --nvalid;
759
- }
760
- faiss::heap_push<HC> (++k, dis.data(), ids.data(), v, i);
761
- ++nvalid;
736
+ if (k == n) {
737
+ if (v >= dis[0])
738
+ return;
739
+ faiss::heap_pop<HC>(k--, dis.data(), ids.data());
740
+ --nvalid;
741
+ }
742
+ faiss::heap_push<HC>(++k, dis.data(), ids.data(), v, i);
743
+ ++nvalid;
762
744
  }
763
745
 
764
746
  float HNSW::MinimaxHeap::max() const {
765
- return dis[0];
747
+ return dis[0];
766
748
  }
767
749
 
768
750
  int HNSW::MinimaxHeap::size() const {
769
- return nvalid;
751
+ return nvalid;
770
752
  }
771
753
 
772
754
  void HNSW::MinimaxHeap::clear() {
773
- nvalid = k = 0;
755
+ nvalid = k = 0;
774
756
  }
775
757
 
776
- int HNSW::MinimaxHeap::pop_min(float *vmin_out) {
777
- assert(k > 0);
778
- // returns min. This is an O(n) operation
779
- int i = k - 1;
780
- while (i >= 0) {
781
- if (ids[i] != -1) break;
782
- i--;
783
- }
784
- if (i == -1) return -1;
785
- int imin = i;
786
- float vmin = dis[i];
787
- i--;
788
- while(i >= 0) {
789
- if (ids[i] != -1 && dis[i] < vmin) {
790
- vmin = dis[i];
791
- imin = i;
758
+ int HNSW::MinimaxHeap::pop_min(float* vmin_out) {
759
+ assert(k > 0);
760
+ // returns min. This is an O(n) operation
761
+ int i = k - 1;
762
+ while (i >= 0) {
763
+ if (ids[i] != -1)
764
+ break;
765
+ i--;
792
766
  }
767
+ if (i == -1)
768
+ return -1;
769
+ int imin = i;
770
+ float vmin = dis[i];
793
771
  i--;
794
- }
795
- if (vmin_out) *vmin_out = vmin;
796
- int ret = ids[imin];
797
- ids[imin] = -1;
798
- --nvalid;
772
+ while (i >= 0) {
773
+ if (ids[i] != -1 && dis[i] < vmin) {
774
+ vmin = dis[i];
775
+ imin = i;
776
+ }
777
+ i--;
778
+ }
779
+ if (vmin_out)
780
+ *vmin_out = vmin;
781
+ int ret = ids[imin];
782
+ ids[imin] = -1;
783
+ --nvalid;
799
784
 
800
- return ret;
785
+ return ret;
801
786
  }
802
787
 
803
788
  int HNSW::MinimaxHeap::count_below(float thresh) {
804
- int n_below = 0;
805
- for(int i = 0; i < k; i++) {
806
- if (dis[i] < thresh) {
807
- n_below++;
789
+ int n_below = 0;
790
+ for (int i = 0; i < k; i++) {
791
+ if (dis[i] < thresh) {
792
+ n_below++;
793
+ }
808
794
  }
809
- }
810
795
 
811
- return n_below;
796
+ return n_below;
812
797
  }
813
798
 
814
-
815
- } // namespace faiss
799
+ } // namespace faiss