faiss 0.2.0 → 0.2.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (202) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +4 -0
  3. data/lib/faiss/version.rb +1 -1
  4. data/vendor/faiss/faiss/AutoTune.cpp +292 -291
  5. data/vendor/faiss/faiss/AutoTune.h +55 -56
  6. data/vendor/faiss/faiss/Clustering.cpp +334 -195
  7. data/vendor/faiss/faiss/Clustering.h +88 -35
  8. data/vendor/faiss/faiss/IVFlib.cpp +171 -195
  9. data/vendor/faiss/faiss/IVFlib.h +48 -51
  10. data/vendor/faiss/faiss/Index.cpp +85 -103
  11. data/vendor/faiss/faiss/Index.h +54 -48
  12. data/vendor/faiss/faiss/Index2Layer.cpp +139 -164
  13. data/vendor/faiss/faiss/Index2Layer.h +22 -22
  14. data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
  15. data/vendor/faiss/faiss/IndexBinary.h +140 -132
  16. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
  17. data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
  18. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
  19. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
  20. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
  21. data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
  22. data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
  23. data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
  24. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
  25. data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
  26. data/vendor/faiss/faiss/IndexFlat.cpp +116 -147
  27. data/vendor/faiss/faiss/IndexFlat.h +35 -46
  28. data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
  29. data/vendor/faiss/faiss/IndexHNSW.h +57 -41
  30. data/vendor/faiss/faiss/IndexIVF.cpp +474 -454
  31. data/vendor/faiss/faiss/IndexIVF.h +146 -113
  32. data/vendor/faiss/faiss/IndexIVFFlat.cpp +248 -250
  33. data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
  34. data/vendor/faiss/faiss/IndexIVFPQ.cpp +457 -516
  35. data/vendor/faiss/faiss/IndexIVFPQ.h +74 -66
  36. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
  37. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
  38. data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
  39. data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
  40. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +125 -133
  41. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +19 -21
  42. data/vendor/faiss/faiss/IndexLSH.cpp +75 -96
  43. data/vendor/faiss/faiss/IndexLSH.h +21 -26
  44. data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
  45. data/vendor/faiss/faiss/IndexLattice.h +11 -16
  46. data/vendor/faiss/faiss/IndexNNDescent.cpp +231 -0
  47. data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
  48. data/vendor/faiss/faiss/IndexNSG.cpp +303 -0
  49. data/vendor/faiss/faiss/IndexNSG.h +85 -0
  50. data/vendor/faiss/faiss/IndexPQ.cpp +405 -464
  51. data/vendor/faiss/faiss/IndexPQ.h +64 -67
  52. data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
  53. data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
  54. data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
  55. data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
  56. data/vendor/faiss/faiss/IndexRefine.cpp +115 -131
  57. data/vendor/faiss/faiss/IndexRefine.h +22 -23
  58. data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
  59. data/vendor/faiss/faiss/IndexReplicas.h +62 -56
  60. data/vendor/faiss/faiss/IndexResidual.cpp +291 -0
  61. data/vendor/faiss/faiss/IndexResidual.h +152 -0
  62. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +120 -155
  63. data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -45
  64. data/vendor/faiss/faiss/IndexShards.cpp +256 -240
  65. data/vendor/faiss/faiss/IndexShards.h +85 -73
  66. data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
  67. data/vendor/faiss/faiss/MatrixStats.h +7 -10
  68. data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
  69. data/vendor/faiss/faiss/MetaIndexes.h +40 -34
  70. data/vendor/faiss/faiss/MetricType.h +7 -7
  71. data/vendor/faiss/faiss/VectorTransform.cpp +652 -474
  72. data/vendor/faiss/faiss/VectorTransform.h +61 -89
  73. data/vendor/faiss/faiss/clone_index.cpp +77 -73
  74. data/vendor/faiss/faiss/clone_index.h +4 -9
  75. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
  76. data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
  77. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +197 -170
  78. data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
  79. data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
  80. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
  81. data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
  82. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
  83. data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
  84. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
  85. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
  86. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
  87. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
  88. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
  89. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
  90. data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
  91. data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
  92. data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
  93. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
  94. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
  95. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
  96. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
  97. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
  98. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
  99. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
  100. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
  101. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
  102. data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
  103. data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
  104. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
  105. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
  106. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
  107. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
  108. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
  109. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
  110. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
  111. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
  112. data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
  113. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
  114. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
  115. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
  116. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
  117. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
  118. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
  119. data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
  120. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +270 -0
  121. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +115 -0
  122. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
  123. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
  124. data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
  125. data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
  126. data/vendor/faiss/faiss/impl/FaissException.h +41 -29
  127. data/vendor/faiss/faiss/impl/HNSW.cpp +595 -611
  128. data/vendor/faiss/faiss/impl/HNSW.h +179 -200
  129. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +672 -0
  130. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +172 -0
  131. data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
  132. data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
  133. data/vendor/faiss/faiss/impl/NSG.cpp +682 -0
  134. data/vendor/faiss/faiss/impl/NSG.h +199 -0
  135. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
  136. data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
  137. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
  138. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
  139. data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
  140. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +448 -0
  141. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +130 -0
  142. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
  143. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +648 -701
  144. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
  145. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
  146. data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
  147. data/vendor/faiss/faiss/impl/index_read.cpp +547 -479
  148. data/vendor/faiss/faiss/impl/index_write.cpp +497 -407
  149. data/vendor/faiss/faiss/impl/io.cpp +75 -94
  150. data/vendor/faiss/faiss/impl/io.h +31 -41
  151. data/vendor/faiss/faiss/impl/io_macros.h +40 -29
  152. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
  153. data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
  154. data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
  155. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
  156. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
  157. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
  158. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
  159. data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
  160. data/vendor/faiss/faiss/index_factory.cpp +269 -218
  161. data/vendor/faiss/faiss/index_factory.h +6 -7
  162. data/vendor/faiss/faiss/index_io.h +23 -26
  163. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
  164. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
  165. data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
  166. data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
  167. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
  168. data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
  169. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
  170. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
  171. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
  172. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
  173. data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
  174. data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
  175. data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
  176. data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
  177. data/vendor/faiss/faiss/utils/Heap.h +186 -209
  178. data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
  179. data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
  180. data/vendor/faiss/faiss/utils/distances.cpp +301 -310
  181. data/vendor/faiss/faiss/utils/distances.h +133 -118
  182. data/vendor/faiss/faiss/utils/distances_simd.cpp +456 -516
  183. data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
  184. data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
  185. data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
  186. data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
  187. data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
  188. data/vendor/faiss/faiss/utils/hamming.h +62 -85
  189. data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
  190. data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
  191. data/vendor/faiss/faiss/utils/partitioning.h +26 -21
  192. data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
  193. data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
  194. data/vendor/faiss/faiss/utils/random.cpp +39 -63
  195. data/vendor/faiss/faiss/utils/random.h +13 -16
  196. data/vendor/faiss/faiss/utils/simdlib.h +4 -2
  197. data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
  198. data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
  199. data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
  200. data/vendor/faiss/faiss/utils/utils.cpp +304 -287
  201. data/vendor/faiss/faiss/utils/utils.h +53 -48
  202. metadata +20 -2
@@ -0,0 +1,154 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ // -*- c++ -*-
9
+
10
+ #pragma once
11
+
12
+ #include <algorithm>
13
+ #include <mutex>
14
+ #include <queue>
15
+ #include <random>
16
+ #include <unordered_set>
17
+ #include <vector>
18
+
19
+ #include <omp.h>
20
+
21
+ #include <faiss/Index.h>
22
+ #include <faiss/impl/FaissAssert.h>
23
+ #include <faiss/impl/platform_macros.h>
24
+ #include <faiss/utils/Heap.h>
25
+ #include <faiss/utils/random.h>
26
+
27
+ namespace faiss {
28
+
29
+ /** Implementation of NNDescent which is one of the most popular
30
+ * KNN graph building algorithms
31
+ *
32
+ * Efficient K-Nearest Neighbor Graph Construction for Generic
33
+ * Similarity Measures
34
+ *
35
+ * Dong, Wei, Charikar Moses, and Kai Li, WWW 2011
36
+ *
37
+ * This implementation is heavily influenced by the efanna
38
+ * implementation by Cong Fu and the KGraph library by Wei Dong
39
+ * (https://github.com/ZJULearning/efanna_graph)
40
+ * (https://github.com/aaalgo/kgraph)
41
+ *
42
+ * The NNDescent object stores only the neighbor link structure,
43
+ * see IndexNNDescent.h for the full index object.
44
+ */
45
+
46
+ struct VisitedTable;
47
+ struct DistanceComputer;
48
+
49
+ namespace nndescent {
50
+
51
/// One entry of a candidate pool: a point id, its distance, and a boolean
/// flag whose meaning is defined by the code that uses the pool.
struct Neighbor {
    int id;
    float distance;
    bool flag;

    Neighbor() = default;

    Neighbor(int id, float distance, bool f)
            : id{id}, distance{distance}, flag{f} {}

    /// Neighbors are ordered by distance only; id and flag are ignored.
    bool operator<(const Neighbor& rhs) const {
        return rhs.distance > distance;
    }
};
64
+
65
+ struct Nhood {
66
+ std::mutex lock;
67
+ std::vector<Neighbor> pool; // candidate pool (a max heap)
68
+ int M; // number of new neighbors to be operated
69
+
70
+ std::vector<int> nn_old; // old neighbors
71
+ std::vector<int> nn_new; // new neighbors
72
+ std::vector<int> rnn_old; // reverse old neighbors
73
+ std::vector<int> rnn_new; // reverse new neighbors
74
+
75
+ Nhood() = default;
76
+
77
+ Nhood(int l, int s, std::mt19937& rng, int N);
78
+
79
+ Nhood& operator=(const Nhood& other);
80
+
81
+ Nhood(const Nhood& other);
82
+
83
+ void insert(int id, float dist);
84
+
85
+ template <typename C>
86
+ void join(C callback) const;
87
+ };
88
+
89
+ } // namespace nndescent
90
+
91
+ struct NNDescent {
92
+ using storage_idx_t = int;
93
+ using idx_t = Index::idx_t;
94
+
95
+ using KNNGraph = std::vector<nndescent::Nhood>;
96
+
97
+ explicit NNDescent(const int d, const int K);
98
+
99
+ ~NNDescent();
100
+
101
+ void build(DistanceComputer& qdis, const int n, bool verbose);
102
+
103
+ void search(
104
+ DistanceComputer& qdis,
105
+ const int topk,
106
+ idx_t* indices,
107
+ float* dists,
108
+ VisitedTable& vt) const;
109
+
110
+ void reset();
111
+
112
+ /// Initialize the KNN graph randomly
113
+ void init_graph(DistanceComputer& qdis);
114
+
115
+ /// Perform NNDescent algorithm
116
+ void nndescent(DistanceComputer& qdis, bool verbose);
117
+
118
+ /// Perform local join on each node
119
+ void join(DistanceComputer& qdis);
120
+
121
+ /// Sample new neighbors for each node to peform local join later
122
+ void update();
123
+
124
+ /// Sample a small number of points to evaluate the quality of KNNG built
125
+ void generate_eval_set(
126
+ DistanceComputer& qdis,
127
+ std::vector<int>& c,
128
+ std::vector<std::vector<int>>& v,
129
+ int N);
130
+
131
+ /// Evaluate the quality of KNNG built
132
+ float eval_recall(
133
+ std::vector<int>& ctrl_points,
134
+ std::vector<std::vector<int>>& acc_eval_set);
135
+
136
+ bool has_built;
137
+
138
+ int K; // K in KNN graph
139
+ int S; // number of sample neighbors to be updated for each node
140
+ int R; // size of reverse links, 0 means the reverse links will not be used
141
+ int L; // size of the candidate pool in building
142
+ int iter; // number of iterations to iterate over
143
+ int search_L; // size of candidate pool in searching
144
+ int random_seed; // random seed for generators
145
+
146
+ int d; // dimensions
147
+
148
+ int ntotal;
149
+
150
+ KNNGraph graph;
151
+ std::vector<int> final_graph;
152
+ };
153
+
154
+ } // namespace faiss
@@ -0,0 +1,682 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ // -*- c++ -*-
9
+
10
+ #include <faiss/impl/NSG.h>
11
+
12
+ #include <algorithm>
13
+ #include <memory>
14
+ #include <mutex>
15
+ #include <stack>
16
+
17
+ #include <faiss/impl/AuxIndexStructures.h>
18
+
19
+ namespace faiss {
20
+
21
+ namespace nsg {
22
+
23
+ namespace {
24
+
25
+ // It needs to be smaller than 0
26
+ constexpr int EMPTY_ID = -1;
27
+
28
+ /* Wrap the distance computer into one that negates the
29
+ distances. This makes supporting INNER_PRODUCE search easier */
30
+
31
+ struct NegativeDistanceComputer : DistanceComputer {
32
+ using idx_t = Index::idx_t;
33
+
34
+ /// owned by this
35
+ DistanceComputer* basedis;
36
+
37
+ explicit NegativeDistanceComputer(DistanceComputer* basedis)
38
+ : basedis(basedis) {}
39
+
40
+ void set_query(const float* x) override {
41
+ basedis->set_query(x);
42
+ }
43
+
44
+ /// compute distance of vector i to current query
45
+ float operator()(idx_t i) override {
46
+ return -(*basedis)(i);
47
+ }
48
+
49
+ /// compute distance between two stored vectors
50
+ float symmetric_dis(idx_t i, idx_t j) override {
51
+ return -basedis->symmetric_dis(i, j);
52
+ }
53
+
54
+ ~NegativeDistanceComputer() override {
55
+ delete basedis;
56
+ }
57
+ };
58
+
59
+ } // namespace
60
+
61
+ DistanceComputer* storage_distance_computer(const Index* storage) {
62
+ if (storage->metric_type == METRIC_INNER_PRODUCT) {
63
+ return new NegativeDistanceComputer(storage->get_distance_computer());
64
+ } else {
65
+ return storage->get_distance_computer();
66
+ }
67
+ }
68
+
69
+ } // namespace nsg
70
+
71
+ using namespace nsg;
72
+
73
+ using LockGuard = std::lock_guard<std::mutex>;
74
+
75
/// Entry of the search candidate pool: a node id, its distance to the
/// query, and a flag. search_on_graph() creates entries with flag == true
/// and clears the flag once the node's out-edges have been expanded.
struct Neighbor {
    int id;
    float distance;
    bool flag;

    Neighbor() = default;

    Neighbor(int id, float distance, bool f)
            : id{id}, distance{distance}, flag{f} {}

    /// Neighbors are ordered by distance only; id and flag are ignored.
    bool operator<(const Neighbor& rhs) const {
        return rhs.distance > distance;
    }
};
88
+
89
/// A graph node reference together with a distance value.
struct Node {
    int id;
    float distance;

    Node() = default;

    Node(int id, float distance) : id{id}, distance{distance} {}

    /// Nodes are ordered by distance only; the id is ignored.
    bool operator<(const Node& rhs) const {
        return rhs.distance > distance;
    }
};
100
+
101
+ inline int insert_into_pool(Neighbor* addr, int K, Neighbor nn) {
102
+ // find the location to insert
103
+ int left = 0, right = K - 1;
104
+ if (addr[left].distance > nn.distance) {
105
+ memmove(&addr[left + 1], &addr[left], K * sizeof(Neighbor));
106
+ addr[left] = nn;
107
+ return left;
108
+ }
109
+ if (addr[right].distance < nn.distance) {
110
+ addr[K] = nn;
111
+ return K;
112
+ }
113
+ while (left < right - 1) {
114
+ int mid = (left + right) / 2;
115
+ if (addr[mid].distance > nn.distance) {
116
+ right = mid;
117
+ } else {
118
+ left = mid;
119
+ }
120
+ }
121
+ // check equal ID
122
+
123
+ while (left > 0) {
124
+ if (addr[left].distance < nn.distance) {
125
+ break;
126
+ }
127
+ if (addr[left].id == nn.id) {
128
+ return K + 1;
129
+ }
130
+ left--;
131
+ }
132
+ if (addr[left].id == nn.id || addr[right].id == nn.id) {
133
+ return K + 1;
134
+ }
135
+ memmove(&addr[right + 1], &addr[right], (K - right) * sizeof(Neighbor));
136
+ addr[right] = nn;
137
+ return right;
138
+ }
139
+
140
+ NSG::NSG(int R) : R(R), rng(0x0903) {
141
+ L = R + 32;
142
+ C = R + 100;
143
+ search_L = 16;
144
+ ntotal = 0;
145
+ is_built = false;
146
+ srand(0x1998);
147
+ }
148
+
149
+ void NSG::search(
150
+ DistanceComputer& dis,
151
+ int k,
152
+ idx_t* I,
153
+ float* D,
154
+ VisitedTable& vt) const {
155
+ FAISS_THROW_IF_NOT(is_built);
156
+ FAISS_THROW_IF_NOT(final_graph);
157
+
158
+ int pool_size = std::max(search_L, k);
159
+ std::vector<Neighbor> retset;
160
+ std::vector<Node> tmp;
161
+ search_on_graph<false>(
162
+ *final_graph, dis, vt, enterpoint, pool_size, retset, tmp);
163
+
164
+ std::partial_sort(
165
+ retset.begin(), retset.begin() + k, retset.begin() + pool_size);
166
+
167
+ for (size_t i = 0; i < k; i++) {
168
+ I[i] = retset[i].id;
169
+ D[i] = retset[i].distance;
170
+ }
171
+ }
172
+
173
+ void NSG::build(
174
+ Index* storage,
175
+ idx_t n,
176
+ const nsg::Graph<idx_t>& knn_graph,
177
+ bool verbose) {
178
+ FAISS_THROW_IF_NOT(!is_built && ntotal == 0);
179
+
180
+ if (verbose) {
181
+ printf("NSG::build R=%d, L=%d, C=%d\n", R, L, C);
182
+ }
183
+
184
+ ntotal = n;
185
+ init_graph(storage, knn_graph);
186
+
187
+ std::vector<int> degrees(n, 0);
188
+ {
189
+ nsg::Graph<Node> tmp_graph(n, R);
190
+
191
+ link(storage, knn_graph, tmp_graph, verbose);
192
+
193
+ final_graph = std::make_shared<nsg::Graph<int>>(n, R);
194
+ std::fill_n(final_graph->data, n * R, EMPTY_ID);
195
+
196
+ #pragma omp parallel for
197
+ for (int i = 0; i < n; i++) {
198
+ int cnt = 0;
199
+ for (int j = 0; j < R; j++) {
200
+ int id = tmp_graph.at(i, j).id;
201
+ if (id != EMPTY_ID) {
202
+ final_graph->at(i, cnt) = id;
203
+ cnt += 1;
204
+ }
205
+ degrees[i] = cnt;
206
+ }
207
+ }
208
+ }
209
+
210
+ int num_attached = tree_grow(storage, degrees);
211
+ check_graph();
212
+ is_built = true;
213
+
214
+ if (verbose) {
215
+ int max = 0, min = 1e6;
216
+ double avg = 0;
217
+
218
+ for (int i = 0; i < n; i++) {
219
+ int size = 0;
220
+ while (size < R && final_graph->at(i, size) != EMPTY_ID) {
221
+ size += 1;
222
+ }
223
+ max = std::max(size, max);
224
+ min = std::min(size, min);
225
+ avg += size;
226
+ }
227
+
228
+ avg = avg / n;
229
+ printf("Degree Statistics: Max = %d, Min = %d, Avg = %lf\n",
230
+ max,
231
+ min,
232
+ avg);
233
+ printf("Attached nodes: %d\n", num_attached);
234
+ }
235
+ }
236
+
237
+ void NSG::reset() {
238
+ final_graph.reset();
239
+ ntotal = 0;
240
+ is_built = false;
241
+ }
242
+
243
+ void NSG::init_graph(Index* storage, const nsg::Graph<idx_t>& knn_graph) {
244
+ int d = storage->d;
245
+ int n = storage->ntotal;
246
+
247
+ std::unique_ptr<float[]> center(new float[d]);
248
+ std::unique_ptr<float[]> tmp(new float[d]);
249
+ std::fill_n(center.get(), d, 0.0f);
250
+
251
+ for (int i = 0; i < n; i++) {
252
+ storage->reconstruct(i, tmp.get());
253
+ for (int j = 0; j < d; j++) {
254
+ center[j] += tmp[j];
255
+ }
256
+ }
257
+
258
+ for (int i = 0; i < d; i++) {
259
+ center[i] /= n;
260
+ }
261
+
262
+ std::vector<Neighbor> retset;
263
+ std::vector<Node> tmpset;
264
+
265
+ // random initialize navigating point
266
+ int ep = rng.rand_int(n);
267
+ std::unique_ptr<DistanceComputer> dis(storage_distance_computer(storage));
268
+
269
+ dis->set_query(center.get());
270
+ VisitedTable vt(ntotal);
271
+
272
+ // Do not collect the visited nodes
273
+ search_on_graph<false>(knn_graph, *dis, vt, ep, L, retset, tmpset);
274
+
275
+ // set enterpoint
276
+ enterpoint = retset[0].id;
277
+ }
278
+
279
+ template <bool collect_fullset, class index_t>
280
+ void NSG::search_on_graph(
281
+ const nsg::Graph<index_t>& graph,
282
+ DistanceComputer& dis,
283
+ VisitedTable& vt,
284
+ int ep,
285
+ int pool_size,
286
+ std::vector<Neighbor>& retset,
287
+ std::vector<Node>& fullset) const {
288
+ RandomGenerator gen(0x1234);
289
+ retset.resize(pool_size + 1);
290
+ std::vector<int> init_ids(pool_size);
291
+
292
+ int num_ids = 0;
293
+ for (int i = 0; i < init_ids.size() && i < graph.K; i++) {
294
+ int id = (int)graph.at(ep, i);
295
+ if (id < 0 || id >= ntotal) {
296
+ continue;
297
+ }
298
+
299
+ init_ids[i] = id;
300
+ vt.set(id);
301
+ num_ids += 1;
302
+ }
303
+
304
+ while (num_ids < pool_size) {
305
+ int id = gen.rand_int(ntotal);
306
+ if (vt.get(id)) {
307
+ continue;
308
+ }
309
+
310
+ init_ids[num_ids] = id;
311
+ num_ids++;
312
+ vt.set(id);
313
+ }
314
+
315
+ for (int i = 0; i < init_ids.size(); i++) {
316
+ int id = init_ids[i];
317
+
318
+ float dist = dis(id);
319
+ retset[i] = Neighbor(id, dist, true);
320
+
321
+ if (collect_fullset) {
322
+ fullset.emplace_back(retset[i].id, retset[i].distance);
323
+ }
324
+ }
325
+
326
+ std::sort(retset.begin(), retset.begin() + pool_size);
327
+
328
+ int k = 0;
329
+ while (k < pool_size) {
330
+ int updated_pos = pool_size;
331
+
332
+ if (retset[k].flag) {
333
+ retset[k].flag = false;
334
+ int n = retset[k].id;
335
+
336
+ for (int m = 0; m < graph.K; m++) {
337
+ int id = (int)graph.at(n, m);
338
+ if (id < 0 || id > ntotal || vt.get(id)) {
339
+ continue;
340
+ }
341
+ vt.set(id);
342
+
343
+ float dist = dis(id);
344
+ Neighbor nn(id, dist, true);
345
+ if (collect_fullset) {
346
+ fullset.emplace_back(id, dist);
347
+ }
348
+
349
+ if (dist >= retset[pool_size - 1].distance) {
350
+ continue;
351
+ }
352
+
353
+ int r = insert_into_pool(retset.data(), pool_size, nn);
354
+
355
+ updated_pos = std::min(updated_pos, r);
356
+ }
357
+ }
358
+
359
+ k = (updated_pos <= k) ? updated_pos : (k + 1);
360
+ }
361
+ }
362
+
363
+ void NSG::link(
364
+ Index* storage,
365
+ const nsg::Graph<idx_t>& knn_graph,
366
+ nsg::Graph<Node>& graph,
367
+ bool /* verbose */) {
368
+ #pragma omp parallel
369
+ {
370
+ std::unique_ptr<float[]> vec(new float[storage->d]);
371
+
372
+ std::vector<Node> pool;
373
+ std::vector<Neighbor> tmp;
374
+
375
+ VisitedTable vt(ntotal);
376
+ std::unique_ptr<DistanceComputer> dis(
377
+ storage_distance_computer(storage));
378
+
379
+ #pragma omp for schedule(dynamic, 100)
380
+ for (int i = 0; i < ntotal; i++) {
381
+ storage->reconstruct(i, vec.get());
382
+ dis->set_query(vec.get());
383
+
384
+ // Collect the visited nodes into pool
385
+ search_on_graph<true>(
386
+ knn_graph, *dis, vt, enterpoint, L, tmp, pool);
387
+
388
+ sync_prune(i, pool, *dis, vt, knn_graph, graph);
389
+
390
+ pool.clear();
391
+ tmp.clear();
392
+ vt.advance();
393
+ }
394
+ } // omp parallel
395
+
396
+ std::vector<std::mutex> locks(ntotal);
397
+ #pragma omp parallel
398
+ {
399
+ std::unique_ptr<DistanceComputer> dis(
400
+ storage_distance_computer(storage));
401
+
402
+ #pragma omp for schedule(dynamic, 100)
403
+ for (int i = 0; i < ntotal; ++i) {
404
+ add_reverse_links(i, locks, *dis, graph);
405
+ }
406
+ } // omp parallel
407
+ }
408
+
409
+ void NSG::sync_prune(
410
+ int q,
411
+ std::vector<Node>& pool,
412
+ DistanceComputer& dis,
413
+ VisitedTable& vt,
414
+ const nsg::Graph<idx_t>& knn_graph,
415
+ nsg::Graph<Node>& graph) {
416
+ for (int i = 0; i < knn_graph.K; i++) {
417
+ int id = knn_graph.at(q, i);
418
+ if (id < 0 || id >= ntotal || vt.get(id)) {
419
+ continue;
420
+ }
421
+
422
+ float dist = dis.symmetric_dis(q, id);
423
+ pool.emplace_back(id, dist);
424
+ }
425
+
426
+ std::sort(pool.begin(), pool.end());
427
+
428
+ std::vector<Node> result;
429
+
430
+ int start = 0;
431
+ if (pool[start].id == q) {
432
+ start++;
433
+ }
434
+ result.push_back(pool[start]);
435
+
436
+ while (result.size() < R && (++start) < pool.size() && start < C) {
437
+ auto& p = pool[start];
438
+ bool occlude = false;
439
+ for (int t = 0; t < result.size(); t++) {
440
+ if (p.id == result[t].id) {
441
+ occlude = true;
442
+ break;
443
+ }
444
+ float djk = dis.symmetric_dis(result[t].id, p.id);
445
+ if (djk < p.distance /* dik */) {
446
+ occlude = true;
447
+ break;
448
+ }
449
+ }
450
+ if (!occlude) {
451
+ result.push_back(p);
452
+ }
453
+ }
454
+
455
+ for (size_t i = 0; i < R; i++) {
456
+ if (i < result.size()) {
457
+ graph.at(q, i).id = result[i].id;
458
+ graph.at(q, i).distance = result[i].distance;
459
+ } else {
460
+ graph.at(q, i).id = EMPTY_ID;
461
+ }
462
+ }
463
+ }
464
+
465
+ void NSG::add_reverse_links(
466
+ int q,
467
+ std::vector<std::mutex>& locks,
468
+ DistanceComputer& dis,
469
+ nsg::Graph<Node>& graph) {
470
+ for (size_t i = 0; i < R; i++) {
471
+ if (graph.at(q, i).id == EMPTY_ID) {
472
+ break;
473
+ }
474
+
475
+ Node sn(q, graph.at(q, i).distance);
476
+ int des = graph.at(q, i).id;
477
+
478
+ std::vector<Node> tmp_pool;
479
+ int dup = 0;
480
+ {
481
+ LockGuard guard(locks[des]);
482
+ for (int j = 0; j < R; j++) {
483
+ if (graph.at(des, j).id == EMPTY_ID) {
484
+ break;
485
+ }
486
+ if (q == graph.at(des, j).id) {
487
+ dup = 1;
488
+ break;
489
+ }
490
+ tmp_pool.push_back(graph.at(des, j));
491
+ }
492
+ }
493
+
494
+ if (dup) {
495
+ continue;
496
+ }
497
+
498
+ tmp_pool.push_back(sn);
499
+ if (tmp_pool.size() > R) {
500
+ std::vector<Node> result;
501
+ int start = 0;
502
+ std::sort(tmp_pool.begin(), tmp_pool.end());
503
+ result.push_back(tmp_pool[start]);
504
+
505
+ while (result.size() < R && (++start) < tmp_pool.size()) {
506
+ auto& p = tmp_pool[start];
507
+ bool occlude = false;
508
+
509
+ for (int t = 0; t < result.size(); t++) {
510
+ if (p.id == result[t].id) {
511
+ occlude = true;
512
+ break;
513
+ }
514
+ float djk = dis.symmetric_dis(result[t].id, p.id);
515
+ if (djk < p.distance /* dik */) {
516
+ occlude = true;
517
+ break;
518
+ }
519
+ }
520
+
521
+ if (!occlude) {
522
+ result.push_back(p);
523
+ }
524
+ }
525
+
526
+ {
527
+ LockGuard guard(locks[des]);
528
+ for (int t = 0; t < result.size(); t++) {
529
+ graph.at(des, t) = result[t];
530
+ }
531
+ }
532
+
533
+ } else {
534
+ LockGuard guard(locks[des]);
535
+ for (int t = 0; t < R; t++) {
536
+ if (graph.at(des, t).id == EMPTY_ID) {
537
+ graph.at(des, t) = sn;
538
+ break;
539
+ }
540
+ }
541
+ }
542
+ }
543
+ }
544
+
545
+ int NSG::tree_grow(Index* storage, std::vector<int>& degrees) {
546
+ int root = enterpoint;
547
+ VisitedTable vt(ntotal);
548
+ VisitedTable vt2(ntotal);
549
+
550
+ int num_attached = 0;
551
+ int cnt = 0;
552
+ while (true) {
553
+ cnt = dfs(vt, root, cnt);
554
+ if (cnt >= ntotal) {
555
+ break;
556
+ }
557
+
558
+ root = attach_unlinked(storage, vt, vt2, degrees);
559
+ vt2.advance();
560
+ num_attached += 1;
561
+ }
562
+
563
+ return num_attached;
564
+ }
565
+
566
+ int NSG::dfs(VisitedTable& vt, int root, int cnt) const {
567
+ int node = root;
568
+ std::stack<int> stack;
569
+ stack.push(root);
570
+
571
+ if (!vt.get(root)) {
572
+ cnt++;
573
+ }
574
+ vt.set(root);
575
+
576
+ while (!stack.empty()) {
577
+ int next = EMPTY_ID;
578
+ for (int i = 0; i < R; i++) {
579
+ int id = final_graph->at(node, i);
580
+ if (id != EMPTY_ID && !vt.get(id)) {
581
+ next = id;
582
+ break;
583
+ }
584
+ }
585
+
586
+ if (next == EMPTY_ID) {
587
+ stack.pop();
588
+ if (stack.empty()) {
589
+ break;
590
+ }
591
+ node = stack.top();
592
+ continue;
593
+ }
594
+ node = next;
595
+ vt.set(node);
596
+ stack.push(node);
597
+ cnt++;
598
+ }
599
+
600
+ return cnt;
601
+ }
602
+
603
+ int NSG::attach_unlinked(
604
+ Index* storage,
605
+ VisitedTable& vt,
606
+ VisitedTable& vt2,
607
+ std::vector<int>& degrees) {
608
+ /* NOTE: This implementation is slightly different from the original paper.
609
+ *
610
+ * Instead of connecting the unlinked node to the nearest point in the
611
+ * spanning tree which will increase the maximum degree of the graph and
612
+ * also make the graph hard to maintain, this implementation links the
613
+ * unlinked node to the nearest node of which the degree is smaller than R.
614
+ * It will keep the degree of all nodes to be no more than `R`.
615
+ */
616
+
617
+ // find one unlinked node
618
+ int id = EMPTY_ID;
619
+ for (int i = 0; i < ntotal; i++) {
620
+ if (!vt.get(i)) {
621
+ id = i;
622
+ break;
623
+ }
624
+ }
625
+
626
+ if (id == EMPTY_ID) {
627
+ return EMPTY_ID; // No Unlinked Node
628
+ }
629
+
630
+ std::vector<Neighbor> tmp;
631
+ std::vector<Node> pool;
632
+
633
+ std::unique_ptr<DistanceComputer> dis(storage_distance_computer(storage));
634
+ std::unique_ptr<float[]> vec(new float[storage->d]);
635
+
636
+ storage->reconstruct(id, vec.get());
637
+ dis->set_query(vec.get());
638
+
639
+ // Collect the visited nodes into pool
640
+ search_on_graph<true>(
641
+ *final_graph, *dis, vt2, enterpoint, search_L, tmp, pool);
642
+
643
+ std::sort(pool.begin(), pool.end());
644
+
645
+ int node;
646
+ bool found = false;
647
+ for (int i = 0; i < pool.size(); i++) {
648
+ node = pool[i].id;
649
+ if (degrees[node] < R && node != id) {
650
+ found = true;
651
+ break;
652
+ }
653
+ }
654
+
655
+ // randomly choice annother node
656
+ if (!found) {
657
+ do {
658
+ node = rng.rand_int(ntotal);
659
+ if (vt.get(node) && degrees[node] < R && node != id) {
660
+ found = true;
661
+ }
662
+ } while (!found);
663
+ }
664
+
665
+ int pos = degrees[node];
666
+ final_graph->at(node, pos) = id; // replace
667
+ degrees[node] += 1;
668
+
669
+ return node;
670
+ }
671
+
672
+ void NSG::check_graph() const {
673
+ #pragma omp parallel for
674
+ for (int i = 0; i < ntotal; i++) {
675
+ for (int j = 0; j < R; j++) {
676
+ int id = final_graph->at(i, j);
677
+ FAISS_THROW_IF_NOT(id < ntotal && (id >= 0 || id == EMPTY_ID));
678
+ }
679
+ }
680
+ }
681
+
682
+ } // namespace faiss