faiss 0.1.5 → 0.2.2

Sign up to get free protection for your applications and to get access to all the features.
Files changed (219) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +24 -0
  3. data/README.md +12 -0
  4. data/ext/faiss/ext.cpp +1 -1
  5. data/ext/faiss/extconf.rb +6 -2
  6. data/ext/faiss/index.cpp +114 -43
  7. data/ext/faiss/index_binary.cpp +24 -30
  8. data/ext/faiss/kmeans.cpp +20 -16
  9. data/ext/faiss/numo.hpp +867 -0
  10. data/ext/faiss/pca_matrix.cpp +13 -14
  11. data/ext/faiss/product_quantizer.cpp +23 -24
  12. data/ext/faiss/utils.cpp +10 -37
  13. data/ext/faiss/utils.h +2 -13
  14. data/lib/faiss.rb +0 -5
  15. data/lib/faiss/version.rb +1 -1
  16. data/vendor/faiss/faiss/AutoTune.cpp +292 -291
  17. data/vendor/faiss/faiss/AutoTune.h +55 -56
  18. data/vendor/faiss/faiss/Clustering.cpp +334 -195
  19. data/vendor/faiss/faiss/Clustering.h +88 -35
  20. data/vendor/faiss/faiss/IVFlib.cpp +171 -195
  21. data/vendor/faiss/faiss/IVFlib.h +48 -51
  22. data/vendor/faiss/faiss/Index.cpp +85 -103
  23. data/vendor/faiss/faiss/Index.h +54 -48
  24. data/vendor/faiss/faiss/Index2Layer.cpp +139 -164
  25. data/vendor/faiss/faiss/Index2Layer.h +22 -22
  26. data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
  27. data/vendor/faiss/faiss/IndexBinary.h +140 -132
  28. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
  29. data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
  30. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
  31. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
  32. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
  33. data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
  34. data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
  35. data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
  36. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
  37. data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
  38. data/vendor/faiss/faiss/IndexFlat.cpp +116 -147
  39. data/vendor/faiss/faiss/IndexFlat.h +35 -46
  40. data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
  41. data/vendor/faiss/faiss/IndexHNSW.h +57 -41
  42. data/vendor/faiss/faiss/IndexIVF.cpp +474 -454
  43. data/vendor/faiss/faiss/IndexIVF.h +146 -113
  44. data/vendor/faiss/faiss/IndexIVFFlat.cpp +248 -250
  45. data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
  46. data/vendor/faiss/faiss/IndexIVFPQ.cpp +457 -516
  47. data/vendor/faiss/faiss/IndexIVFPQ.h +74 -66
  48. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
  49. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
  50. data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
  51. data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
  52. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +125 -133
  53. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +19 -21
  54. data/vendor/faiss/faiss/IndexLSH.cpp +75 -96
  55. data/vendor/faiss/faiss/IndexLSH.h +21 -26
  56. data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
  57. data/vendor/faiss/faiss/IndexLattice.h +11 -16
  58. data/vendor/faiss/faiss/IndexNNDescent.cpp +231 -0
  59. data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
  60. data/vendor/faiss/faiss/IndexNSG.cpp +303 -0
  61. data/vendor/faiss/faiss/IndexNSG.h +85 -0
  62. data/vendor/faiss/faiss/IndexPQ.cpp +405 -464
  63. data/vendor/faiss/faiss/IndexPQ.h +64 -67
  64. data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
  65. data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
  66. data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
  67. data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
  68. data/vendor/faiss/faiss/IndexRefine.cpp +115 -131
  69. data/vendor/faiss/faiss/IndexRefine.h +22 -23
  70. data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
  71. data/vendor/faiss/faiss/IndexReplicas.h +62 -56
  72. data/vendor/faiss/faiss/IndexResidual.cpp +291 -0
  73. data/vendor/faiss/faiss/IndexResidual.h +152 -0
  74. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +120 -155
  75. data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -45
  76. data/vendor/faiss/faiss/IndexShards.cpp +256 -240
  77. data/vendor/faiss/faiss/IndexShards.h +85 -73
  78. data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
  79. data/vendor/faiss/faiss/MatrixStats.h +7 -10
  80. data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
  81. data/vendor/faiss/faiss/MetaIndexes.h +40 -34
  82. data/vendor/faiss/faiss/MetricType.h +7 -7
  83. data/vendor/faiss/faiss/VectorTransform.cpp +652 -474
  84. data/vendor/faiss/faiss/VectorTransform.h +61 -89
  85. data/vendor/faiss/faiss/clone_index.cpp +77 -73
  86. data/vendor/faiss/faiss/clone_index.h +4 -9
  87. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
  88. data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
  89. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +197 -170
  90. data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
  91. data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
  92. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
  93. data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
  94. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
  95. data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
  96. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
  97. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
  98. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
  99. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
  100. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
  101. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
  102. data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
  103. data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
  104. data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
  105. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
  106. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
  107. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
  108. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
  109. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
  110. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
  111. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
  112. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
  113. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
  114. data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
  115. data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
  116. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
  117. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
  118. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
  119. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
  120. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
  121. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
  122. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
  123. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
  124. data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
  125. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
  126. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
  127. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
  128. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
  129. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
  130. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
  131. data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
  132. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +270 -0
  133. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +115 -0
  134. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
  135. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
  136. data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
  137. data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
  138. data/vendor/faiss/faiss/impl/FaissException.h +41 -29
  139. data/vendor/faiss/faiss/impl/HNSW.cpp +595 -611
  140. data/vendor/faiss/faiss/impl/HNSW.h +179 -200
  141. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +672 -0
  142. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +172 -0
  143. data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
  144. data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
  145. data/vendor/faiss/faiss/impl/NSG.cpp +682 -0
  146. data/vendor/faiss/faiss/impl/NSG.h +199 -0
  147. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
  148. data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
  149. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
  150. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
  151. data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
  152. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +448 -0
  153. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +130 -0
  154. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
  155. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +648 -701
  156. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
  157. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
  158. data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
  159. data/vendor/faiss/faiss/impl/index_read.cpp +547 -479
  160. data/vendor/faiss/faiss/impl/index_write.cpp +497 -407
  161. data/vendor/faiss/faiss/impl/io.cpp +75 -94
  162. data/vendor/faiss/faiss/impl/io.h +31 -41
  163. data/vendor/faiss/faiss/impl/io_macros.h +40 -29
  164. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
  165. data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
  166. data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
  167. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
  168. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
  169. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
  170. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
  171. data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
  172. data/vendor/faiss/faiss/index_factory.cpp +269 -218
  173. data/vendor/faiss/faiss/index_factory.h +6 -7
  174. data/vendor/faiss/faiss/index_io.h +23 -26
  175. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
  176. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
  177. data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
  178. data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
  179. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
  180. data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
  181. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
  182. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
  183. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
  184. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
  185. data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
  186. data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
  187. data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
  188. data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
  189. data/vendor/faiss/faiss/utils/Heap.h +186 -209
  190. data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
  191. data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
  192. data/vendor/faiss/faiss/utils/distances.cpp +301 -310
  193. data/vendor/faiss/faiss/utils/distances.h +133 -118
  194. data/vendor/faiss/faiss/utils/distances_simd.cpp +456 -516
  195. data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
  196. data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
  197. data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
  198. data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
  199. data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
  200. data/vendor/faiss/faiss/utils/hamming.h +62 -85
  201. data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
  202. data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
  203. data/vendor/faiss/faiss/utils/partitioning.h +26 -21
  204. data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
  205. data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
  206. data/vendor/faiss/faiss/utils/random.cpp +39 -63
  207. data/vendor/faiss/faiss/utils/random.h +13 -16
  208. data/vendor/faiss/faiss/utils/simdlib.h +4 -2
  209. data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
  210. data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
  211. data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
  212. data/vendor/faiss/faiss/utils/utils.cpp +304 -287
  213. data/vendor/faiss/faiss/utils/utils.h +53 -48
  214. metadata +24 -10
  215. data/lib/faiss/index.rb +0 -20
  216. data/lib/faiss/index_binary.rb +0 -20
  217. data/lib/faiss/kmeans.rb +0 -15
  218. data/lib/faiss/pca_matrix.rb +0 -15
  219. data/lib/faiss/product_quantizer.rb +0 -22
@@ -0,0 +1,154 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ // -*- c++ -*-
9
+
10
+ #pragma once
11
+
12
+ #include <algorithm>
13
+ #include <mutex>
14
+ #include <queue>
15
+ #include <random>
16
+ #include <unordered_set>
17
+ #include <vector>
18
+
19
+ #include <omp.h>
20
+
21
+ #include <faiss/Index.h>
22
+ #include <faiss/impl/FaissAssert.h>
23
+ #include <faiss/impl/platform_macros.h>
24
+ #include <faiss/utils/Heap.h>
25
+ #include <faiss/utils/random.h>
26
+
27
+ namespace faiss {
28
+
29
+ /** Implementation of NNDescent which is one of the most popular
30
+ * KNN graph building algorithms
31
+ *
32
+ * Efficient K-Nearest Neighbor Graph Construction for Generic
33
+ * Similarity Measures
34
+ *
35
+ * Dong, Wei, Charikar Moses, and Kai Li, WWW 2011
36
+ *
37
+ * This implementation is heavily influenced by the efanna
38
+ * implementation by Cong Fu and the KGraph library by Wei Dong
39
+ * (https://github.com/ZJULearning/efanna_graph)
40
+ * (https://github.com/aaalgo/kgraph)
41
+ *
42
+ * The NNDescent object stores only the neighbor link structure,
43
+ * see IndexNNDescent.h for the full index object.
44
+ */
45
+
46
+ struct VisitedTable;
47
+ struct DistanceComputer;
48
+
49
namespace nndescent {

/// One entry of a neighborhood pool: a node id, its distance, and a flag
/// (presumably marking a "new", not-yet-joined neighbor — confirm in
/// NNDescent.cpp). Entries compare by distance only.
struct Neighbor {
    int id;
    float distance;
    bool flag;

    Neighbor() = default;
    Neighbor(int nid, float dist, bool f) : id{nid}, distance{dist}, flag{f} {}

    /// Strict weak ordering on distance; ids do not break ties.
    inline bool operator<(const Neighbor& rhs) const {
        return rhs.distance > distance;
    }
};

/// Per-node neighborhood state used while building the KNN graph.
struct Nhood {
    std::mutex lock;            // per-node mutex (presumably guards pool updates)
    std::vector<Neighbor> pool; // candidate pool (a max heap)
    int M;                      // number of new neighbors to be operated

    std::vector<int> nn_old;  // old neighbors
    std::vector<int> nn_new;  // new neighbors
    std::vector<int> rnn_old; // reverse old neighbors
    std::vector<int> rnn_new; // reverse new neighbors

    Nhood() = default;

    Nhood(int l, int s, std::mt19937& rng, int N);

    // std::mutex is not copyable, so copy operations are user-provided.
    Nhood& operator=(const Nhood& other);

    Nhood(const Nhood& other);

    void insert(int id, float dist);

    template <typename C>
    void join(C callback) const;
};

} // namespace nndescent
90
+
91
+ struct NNDescent {
92
+ using storage_idx_t = int;
93
+ using idx_t = Index::idx_t;
94
+
95
+ using KNNGraph = std::vector<nndescent::Nhood>;
96
+
97
+ explicit NNDescent(const int d, const int K);
98
+
99
+ ~NNDescent();
100
+
101
+ void build(DistanceComputer& qdis, const int n, bool verbose);
102
+
103
+ void search(
104
+ DistanceComputer& qdis,
105
+ const int topk,
106
+ idx_t* indices,
107
+ float* dists,
108
+ VisitedTable& vt) const;
109
+
110
+ void reset();
111
+
112
+ /// Initialize the KNN graph randomly
113
+ void init_graph(DistanceComputer& qdis);
114
+
115
+ /// Perform NNDescent algorithm
116
+ void nndescent(DistanceComputer& qdis, bool verbose);
117
+
118
+ /// Perform local join on each node
119
+ void join(DistanceComputer& qdis);
120
+
121
+ /// Sample new neighbors for each node to peform local join later
122
+ void update();
123
+
124
+ /// Sample a small number of points to evaluate the quality of KNNG built
125
+ void generate_eval_set(
126
+ DistanceComputer& qdis,
127
+ std::vector<int>& c,
128
+ std::vector<std::vector<int>>& v,
129
+ int N);
130
+
131
+ /// Evaluate the quality of KNNG built
132
+ float eval_recall(
133
+ std::vector<int>& ctrl_points,
134
+ std::vector<std::vector<int>>& acc_eval_set);
135
+
136
+ bool has_built;
137
+
138
+ int K; // K in KNN graph
139
+ int S; // number of sample neighbors to be updated for each node
140
+ int R; // size of reverse links, 0 means the reverse links will not be used
141
+ int L; // size of the candidate pool in building
142
+ int iter; // number of iterations to iterate over
143
+ int search_L; // size of candidate pool in searching
144
+ int random_seed; // random seed for generators
145
+
146
+ int d; // dimensions
147
+
148
+ int ntotal;
149
+
150
+ KNNGraph graph;
151
+ std::vector<int> final_graph;
152
+ };
153
+
154
+ } // namespace faiss
@@ -0,0 +1,682 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ // -*- c++ -*-
9
+
10
+ #include <faiss/impl/NSG.h>
11
+
12
+ #include <algorithm>
13
+ #include <memory>
14
+ #include <mutex>
15
+ #include <stack>
16
+
17
+ #include <faiss/impl/AuxIndexStructures.h>
18
+
19
+ namespace faiss {
20
+
21
+ namespace nsg {
22
+
23
+ namespace {
24
+
25
+ // It needs to be smaller than 0
26
+ constexpr int EMPTY_ID = -1;
27
+
28
+ /* Wrap the distance computer into one that negates the
29
+ distances. This makes supporting INNER_PRODUCE search easier */
30
+
31
+ struct NegativeDistanceComputer : DistanceComputer {
32
+ using idx_t = Index::idx_t;
33
+
34
+ /// owned by this
35
+ DistanceComputer* basedis;
36
+
37
+ explicit NegativeDistanceComputer(DistanceComputer* basedis)
38
+ : basedis(basedis) {}
39
+
40
+ void set_query(const float* x) override {
41
+ basedis->set_query(x);
42
+ }
43
+
44
+ /// compute distance of vector i to current query
45
+ float operator()(idx_t i) override {
46
+ return -(*basedis)(i);
47
+ }
48
+
49
+ /// compute distance between two stored vectors
50
+ float symmetric_dis(idx_t i, idx_t j) override {
51
+ return -basedis->symmetric_dis(i, j);
52
+ }
53
+
54
+ ~NegativeDistanceComputer() override {
55
+ delete basedis;
56
+ }
57
+ };
58
+
59
+ } // namespace
60
+
61
+ DistanceComputer* storage_distance_computer(const Index* storage) {
62
+ if (storage->metric_type == METRIC_INNER_PRODUCT) {
63
+ return new NegativeDistanceComputer(storage->get_distance_computer());
64
+ } else {
65
+ return storage->get_distance_computer();
66
+ }
67
+ }
68
+
69
+ } // namespace nsg
70
+
71
+ using namespace nsg;
72
+
73
+ using LockGuard = std::lock_guard<std::mutex>;
74
+
75
/// Search-pool entry: node id, distance to the query, and a flag that is set
/// while the node's own neighbor list has not been expanded yet.
struct Neighbor {
    int id;
    float distance;
    bool flag;

    Neighbor() = default;
    Neighbor(int nid, float dist, bool unexpanded)
            : id{nid}, distance{dist}, flag{unexpanded} {}

    /// Ordered by distance alone; ids do not break ties.
    inline bool operator<(const Neighbor& rhs) const {
        return rhs.distance > distance;
    }
};
88
+
89
/// Plain (id, distance) pair used for candidate pools and for the temporary
/// graph built during NSG construction.
struct Node {
    int id;
    float distance;

    Node() = default;
    Node(int nid, float dist) : id{nid}, distance{dist} {}

    /// Ordered by distance alone.
    inline bool operator<(const Node& rhs) const {
        return rhs.distance > distance;
    }
};
100
+
101
+ inline int insert_into_pool(Neighbor* addr, int K, Neighbor nn) {
102
+ // find the location to insert
103
+ int left = 0, right = K - 1;
104
+ if (addr[left].distance > nn.distance) {
105
+ memmove(&addr[left + 1], &addr[left], K * sizeof(Neighbor));
106
+ addr[left] = nn;
107
+ return left;
108
+ }
109
+ if (addr[right].distance < nn.distance) {
110
+ addr[K] = nn;
111
+ return K;
112
+ }
113
+ while (left < right - 1) {
114
+ int mid = (left + right) / 2;
115
+ if (addr[mid].distance > nn.distance) {
116
+ right = mid;
117
+ } else {
118
+ left = mid;
119
+ }
120
+ }
121
+ // check equal ID
122
+
123
+ while (left > 0) {
124
+ if (addr[left].distance < nn.distance) {
125
+ break;
126
+ }
127
+ if (addr[left].id == nn.id) {
128
+ return K + 1;
129
+ }
130
+ left--;
131
+ }
132
+ if (addr[left].id == nn.id || addr[right].id == nn.id) {
133
+ return K + 1;
134
+ }
135
+ memmove(&addr[right + 1], &addr[right], (K - right) * sizeof(Neighbor));
136
+ addr[right] = nn;
137
+ return right;
138
+ }
139
+
140
+ NSG::NSG(int R) : R(R), rng(0x0903) {
141
+ L = R + 32;
142
+ C = R + 100;
143
+ search_L = 16;
144
+ ntotal = 0;
145
+ is_built = false;
146
+ srand(0x1998);
147
+ }
148
+
149
+ void NSG::search(
150
+ DistanceComputer& dis,
151
+ int k,
152
+ idx_t* I,
153
+ float* D,
154
+ VisitedTable& vt) const {
155
+ FAISS_THROW_IF_NOT(is_built);
156
+ FAISS_THROW_IF_NOT(final_graph);
157
+
158
+ int pool_size = std::max(search_L, k);
159
+ std::vector<Neighbor> retset;
160
+ std::vector<Node> tmp;
161
+ search_on_graph<false>(
162
+ *final_graph, dis, vt, enterpoint, pool_size, retset, tmp);
163
+
164
+ std::partial_sort(
165
+ retset.begin(), retset.begin() + k, retset.begin() + pool_size);
166
+
167
+ for (size_t i = 0; i < k; i++) {
168
+ I[i] = retset[i].id;
169
+ D[i] = retset[i].distance;
170
+ }
171
+ }
172
+
173
+ void NSG::build(
174
+ Index* storage,
175
+ idx_t n,
176
+ const nsg::Graph<idx_t>& knn_graph,
177
+ bool verbose) {
178
+ FAISS_THROW_IF_NOT(!is_built && ntotal == 0);
179
+
180
+ if (verbose) {
181
+ printf("NSG::build R=%d, L=%d, C=%d\n", R, L, C);
182
+ }
183
+
184
+ ntotal = n;
185
+ init_graph(storage, knn_graph);
186
+
187
+ std::vector<int> degrees(n, 0);
188
+ {
189
+ nsg::Graph<Node> tmp_graph(n, R);
190
+
191
+ link(storage, knn_graph, tmp_graph, verbose);
192
+
193
+ final_graph = std::make_shared<nsg::Graph<int>>(n, R);
194
+ std::fill_n(final_graph->data, n * R, EMPTY_ID);
195
+
196
+ #pragma omp parallel for
197
+ for (int i = 0; i < n; i++) {
198
+ int cnt = 0;
199
+ for (int j = 0; j < R; j++) {
200
+ int id = tmp_graph.at(i, j).id;
201
+ if (id != EMPTY_ID) {
202
+ final_graph->at(i, cnt) = id;
203
+ cnt += 1;
204
+ }
205
+ degrees[i] = cnt;
206
+ }
207
+ }
208
+ }
209
+
210
+ int num_attached = tree_grow(storage, degrees);
211
+ check_graph();
212
+ is_built = true;
213
+
214
+ if (verbose) {
215
+ int max = 0, min = 1e6;
216
+ double avg = 0;
217
+
218
+ for (int i = 0; i < n; i++) {
219
+ int size = 0;
220
+ while (size < R && final_graph->at(i, size) != EMPTY_ID) {
221
+ size += 1;
222
+ }
223
+ max = std::max(size, max);
224
+ min = std::min(size, min);
225
+ avg += size;
226
+ }
227
+
228
+ avg = avg / n;
229
+ printf("Degree Statistics: Max = %d, Min = %d, Avg = %lf\n",
230
+ max,
231
+ min,
232
+ avg);
233
+ printf("Attached nodes: %d\n", num_attached);
234
+ }
235
+ }
236
+
237
+ void NSG::reset() {
238
+ final_graph.reset();
239
+ ntotal = 0;
240
+ is_built = false;
241
+ }
242
+
243
+ void NSG::init_graph(Index* storage, const nsg::Graph<idx_t>& knn_graph) {
244
+ int d = storage->d;
245
+ int n = storage->ntotal;
246
+
247
+ std::unique_ptr<float[]> center(new float[d]);
248
+ std::unique_ptr<float[]> tmp(new float[d]);
249
+ std::fill_n(center.get(), d, 0.0f);
250
+
251
+ for (int i = 0; i < n; i++) {
252
+ storage->reconstruct(i, tmp.get());
253
+ for (int j = 0; j < d; j++) {
254
+ center[j] += tmp[j];
255
+ }
256
+ }
257
+
258
+ for (int i = 0; i < d; i++) {
259
+ center[i] /= n;
260
+ }
261
+
262
+ std::vector<Neighbor> retset;
263
+ std::vector<Node> tmpset;
264
+
265
+ // random initialize navigating point
266
+ int ep = rng.rand_int(n);
267
+ std::unique_ptr<DistanceComputer> dis(storage_distance_computer(storage));
268
+
269
+ dis->set_query(center.get());
270
+ VisitedTable vt(ntotal);
271
+
272
+ // Do not collect the visited nodes
273
+ search_on_graph<false>(knn_graph, *dis, vt, ep, L, retset, tmpset);
274
+
275
+ // set enterpoint
276
+ enterpoint = retset[0].id;
277
+ }
278
+
279
+ template <bool collect_fullset, class index_t>
280
+ void NSG::search_on_graph(
281
+ const nsg::Graph<index_t>& graph,
282
+ DistanceComputer& dis,
283
+ VisitedTable& vt,
284
+ int ep,
285
+ int pool_size,
286
+ std::vector<Neighbor>& retset,
287
+ std::vector<Node>& fullset) const {
288
+ RandomGenerator gen(0x1234);
289
+ retset.resize(pool_size + 1);
290
+ std::vector<int> init_ids(pool_size);
291
+
292
+ int num_ids = 0;
293
+ for (int i = 0; i < init_ids.size() && i < graph.K; i++) {
294
+ int id = (int)graph.at(ep, i);
295
+ if (id < 0 || id >= ntotal) {
296
+ continue;
297
+ }
298
+
299
+ init_ids[i] = id;
300
+ vt.set(id);
301
+ num_ids += 1;
302
+ }
303
+
304
+ while (num_ids < pool_size) {
305
+ int id = gen.rand_int(ntotal);
306
+ if (vt.get(id)) {
307
+ continue;
308
+ }
309
+
310
+ init_ids[num_ids] = id;
311
+ num_ids++;
312
+ vt.set(id);
313
+ }
314
+
315
+ for (int i = 0; i < init_ids.size(); i++) {
316
+ int id = init_ids[i];
317
+
318
+ float dist = dis(id);
319
+ retset[i] = Neighbor(id, dist, true);
320
+
321
+ if (collect_fullset) {
322
+ fullset.emplace_back(retset[i].id, retset[i].distance);
323
+ }
324
+ }
325
+
326
+ std::sort(retset.begin(), retset.begin() + pool_size);
327
+
328
+ int k = 0;
329
+ while (k < pool_size) {
330
+ int updated_pos = pool_size;
331
+
332
+ if (retset[k].flag) {
333
+ retset[k].flag = false;
334
+ int n = retset[k].id;
335
+
336
+ for (int m = 0; m < graph.K; m++) {
337
+ int id = (int)graph.at(n, m);
338
+ if (id < 0 || id > ntotal || vt.get(id)) {
339
+ continue;
340
+ }
341
+ vt.set(id);
342
+
343
+ float dist = dis(id);
344
+ Neighbor nn(id, dist, true);
345
+ if (collect_fullset) {
346
+ fullset.emplace_back(id, dist);
347
+ }
348
+
349
+ if (dist >= retset[pool_size - 1].distance) {
350
+ continue;
351
+ }
352
+
353
+ int r = insert_into_pool(retset.data(), pool_size, nn);
354
+
355
+ updated_pos = std::min(updated_pos, r);
356
+ }
357
+ }
358
+
359
+ k = (updated_pos <= k) ? updated_pos : (k + 1);
360
+ }
361
+ }
362
+
363
+ void NSG::link(
364
+ Index* storage,
365
+ const nsg::Graph<idx_t>& knn_graph,
366
+ nsg::Graph<Node>& graph,
367
+ bool /* verbose */) {
368
+ #pragma omp parallel
369
+ {
370
+ std::unique_ptr<float[]> vec(new float[storage->d]);
371
+
372
+ std::vector<Node> pool;
373
+ std::vector<Neighbor> tmp;
374
+
375
+ VisitedTable vt(ntotal);
376
+ std::unique_ptr<DistanceComputer> dis(
377
+ storage_distance_computer(storage));
378
+
379
+ #pragma omp for schedule(dynamic, 100)
380
+ for (int i = 0; i < ntotal; i++) {
381
+ storage->reconstruct(i, vec.get());
382
+ dis->set_query(vec.get());
383
+
384
+ // Collect the visited nodes into pool
385
+ search_on_graph<true>(
386
+ knn_graph, *dis, vt, enterpoint, L, tmp, pool);
387
+
388
+ sync_prune(i, pool, *dis, vt, knn_graph, graph);
389
+
390
+ pool.clear();
391
+ tmp.clear();
392
+ vt.advance();
393
+ }
394
+ } // omp parallel
395
+
396
+ std::vector<std::mutex> locks(ntotal);
397
+ #pragma omp parallel
398
+ {
399
+ std::unique_ptr<DistanceComputer> dis(
400
+ storage_distance_computer(storage));
401
+
402
+ #pragma omp for schedule(dynamic, 100)
403
+ for (int i = 0; i < ntotal; ++i) {
404
+ add_reverse_links(i, locks, *dis, graph);
405
+ }
406
+ } // omp parallel
407
+ }
408
+
409
+ void NSG::sync_prune(
410
+ int q,
411
+ std::vector<Node>& pool,
412
+ DistanceComputer& dis,
413
+ VisitedTable& vt,
414
+ const nsg::Graph<idx_t>& knn_graph,
415
+ nsg::Graph<Node>& graph) {
416
+ for (int i = 0; i < knn_graph.K; i++) {
417
+ int id = knn_graph.at(q, i);
418
+ if (id < 0 || id >= ntotal || vt.get(id)) {
419
+ continue;
420
+ }
421
+
422
+ float dist = dis.symmetric_dis(q, id);
423
+ pool.emplace_back(id, dist);
424
+ }
425
+
426
+ std::sort(pool.begin(), pool.end());
427
+
428
+ std::vector<Node> result;
429
+
430
+ int start = 0;
431
+ if (pool[start].id == q) {
432
+ start++;
433
+ }
434
+ result.push_back(pool[start]);
435
+
436
+ while (result.size() < R && (++start) < pool.size() && start < C) {
437
+ auto& p = pool[start];
438
+ bool occlude = false;
439
+ for (int t = 0; t < result.size(); t++) {
440
+ if (p.id == result[t].id) {
441
+ occlude = true;
442
+ break;
443
+ }
444
+ float djk = dis.symmetric_dis(result[t].id, p.id);
445
+ if (djk < p.distance /* dik */) {
446
+ occlude = true;
447
+ break;
448
+ }
449
+ }
450
+ if (!occlude) {
451
+ result.push_back(p);
452
+ }
453
+ }
454
+
455
+ for (size_t i = 0; i < R; i++) {
456
+ if (i < result.size()) {
457
+ graph.at(q, i).id = result[i].id;
458
+ graph.at(q, i).distance = result[i].distance;
459
+ } else {
460
+ graph.at(q, i).id = EMPTY_ID;
461
+ }
462
+ }
463
+ }
464
+
465
+ void NSG::add_reverse_links(
466
+ int q,
467
+ std::vector<std::mutex>& locks,
468
+ DistanceComputer& dis,
469
+ nsg::Graph<Node>& graph) {
470
+ for (size_t i = 0; i < R; i++) {
471
+ if (graph.at(q, i).id == EMPTY_ID) {
472
+ break;
473
+ }
474
+
475
+ Node sn(q, graph.at(q, i).distance);
476
+ int des = graph.at(q, i).id;
477
+
478
+ std::vector<Node> tmp_pool;
479
+ int dup = 0;
480
+ {
481
+ LockGuard guard(locks[des]);
482
+ for (int j = 0; j < R; j++) {
483
+ if (graph.at(des, j).id == EMPTY_ID) {
484
+ break;
485
+ }
486
+ if (q == graph.at(des, j).id) {
487
+ dup = 1;
488
+ break;
489
+ }
490
+ tmp_pool.push_back(graph.at(des, j));
491
+ }
492
+ }
493
+
494
+ if (dup) {
495
+ continue;
496
+ }
497
+
498
+ tmp_pool.push_back(sn);
499
+ if (tmp_pool.size() > R) {
500
+ std::vector<Node> result;
501
+ int start = 0;
502
+ std::sort(tmp_pool.begin(), tmp_pool.end());
503
+ result.push_back(tmp_pool[start]);
504
+
505
+ while (result.size() < R && (++start) < tmp_pool.size()) {
506
+ auto& p = tmp_pool[start];
507
+ bool occlude = false;
508
+
509
+ for (int t = 0; t < result.size(); t++) {
510
+ if (p.id == result[t].id) {
511
+ occlude = true;
512
+ break;
513
+ }
514
+ float djk = dis.symmetric_dis(result[t].id, p.id);
515
+ if (djk < p.distance /* dik */) {
516
+ occlude = true;
517
+ break;
518
+ }
519
+ }
520
+
521
+ if (!occlude) {
522
+ result.push_back(p);
523
+ }
524
+ }
525
+
526
+ {
527
+ LockGuard guard(locks[des]);
528
+ for (int t = 0; t < result.size(); t++) {
529
+ graph.at(des, t) = result[t];
530
+ }
531
+ }
532
+
533
+ } else {
534
+ LockGuard guard(locks[des]);
535
+ for (int t = 0; t < R; t++) {
536
+ if (graph.at(des, t).id == EMPTY_ID) {
537
+ graph.at(des, t) = sn;
538
+ break;
539
+ }
540
+ }
541
+ }
542
+ }
543
+ }
544
+
545
+ int NSG::tree_grow(Index* storage, std::vector<int>& degrees) {
546
+ int root = enterpoint;
547
+ VisitedTable vt(ntotal);
548
+ VisitedTable vt2(ntotal);
549
+
550
+ int num_attached = 0;
551
+ int cnt = 0;
552
+ while (true) {
553
+ cnt = dfs(vt, root, cnt);
554
+ if (cnt >= ntotal) {
555
+ break;
556
+ }
557
+
558
+ root = attach_unlinked(storage, vt, vt2, degrees);
559
+ vt2.advance();
560
+ num_attached += 1;
561
+ }
562
+
563
+ return num_attached;
564
+ }
565
+
566
+ int NSG::dfs(VisitedTable& vt, int root, int cnt) const {
567
+ int node = root;
568
+ std::stack<int> stack;
569
+ stack.push(root);
570
+
571
+ if (!vt.get(root)) {
572
+ cnt++;
573
+ }
574
+ vt.set(root);
575
+
576
+ while (!stack.empty()) {
577
+ int next = EMPTY_ID;
578
+ for (int i = 0; i < R; i++) {
579
+ int id = final_graph->at(node, i);
580
+ if (id != EMPTY_ID && !vt.get(id)) {
581
+ next = id;
582
+ break;
583
+ }
584
+ }
585
+
586
+ if (next == EMPTY_ID) {
587
+ stack.pop();
588
+ if (stack.empty()) {
589
+ break;
590
+ }
591
+ node = stack.top();
592
+ continue;
593
+ }
594
+ node = next;
595
+ vt.set(node);
596
+ stack.push(node);
597
+ cnt++;
598
+ }
599
+
600
+ return cnt;
601
+ }
602
+
603
+ int NSG::attach_unlinked(
604
+ Index* storage,
605
+ VisitedTable& vt,
606
+ VisitedTable& vt2,
607
+ std::vector<int>& degrees) {
608
+ /* NOTE: This implementation is slightly different from the original paper.
609
+ *
610
+ * Instead of connecting the unlinked node to the nearest point in the
611
+ * spanning tree which will increase the maximum degree of the graph and
612
+ * also make the graph hard to maintain, this implementation links the
613
+ * unlinked node to the nearest node of which the degree is smaller than R.
614
+ * It will keep the degree of all nodes to be no more than `R`.
615
+ */
616
+
617
+ // find one unlinked node
618
+ int id = EMPTY_ID;
619
+ for (int i = 0; i < ntotal; i++) {
620
+ if (!vt.get(i)) {
621
+ id = i;
622
+ break;
623
+ }
624
+ }
625
+
626
+ if (id == EMPTY_ID) {
627
+ return EMPTY_ID; // No Unlinked Node
628
+ }
629
+
630
+ std::vector<Neighbor> tmp;
631
+ std::vector<Node> pool;
632
+
633
+ std::unique_ptr<DistanceComputer> dis(storage_distance_computer(storage));
634
+ std::unique_ptr<float[]> vec(new float[storage->d]);
635
+
636
+ storage->reconstruct(id, vec.get());
637
+ dis->set_query(vec.get());
638
+
639
+ // Collect the visited nodes into pool
640
+ search_on_graph<true>(
641
+ *final_graph, *dis, vt2, enterpoint, search_L, tmp, pool);
642
+
643
+ std::sort(pool.begin(), pool.end());
644
+
645
+ int node;
646
+ bool found = false;
647
+ for (int i = 0; i < pool.size(); i++) {
648
+ node = pool[i].id;
649
+ if (degrees[node] < R && node != id) {
650
+ found = true;
651
+ break;
652
+ }
653
+ }
654
+
655
+ // randomly choice annother node
656
+ if (!found) {
657
+ do {
658
+ node = rng.rand_int(ntotal);
659
+ if (vt.get(node) && degrees[node] < R && node != id) {
660
+ found = true;
661
+ }
662
+ } while (!found);
663
+ }
664
+
665
+ int pos = degrees[node];
666
+ final_graph->at(node, pos) = id; // replace
667
+ degrees[node] += 1;
668
+
669
+ return node;
670
+ }
671
+
672
+ void NSG::check_graph() const {
673
+ #pragma omp parallel for
674
+ for (int i = 0; i < ntotal; i++) {
675
+ for (int j = 0; j < R; j++) {
676
+ int id = final_graph->at(i, j);
677
+ FAISS_THROW_IF_NOT(id < ntotal && (id >= 0 || id == EMPTY_ID));
678
+ }
679
+ }
680
+ }
681
+
682
+ } // namespace faiss