faiss 0.2.0 → 0.2.4

Sign up to get free protection for your applications and to get access to all the features.
Files changed (215) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +16 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +7 -7
  5. data/ext/faiss/extconf.rb +6 -3
  6. data/ext/faiss/numo.hpp +4 -4
  7. data/ext/faiss/utils.cpp +1 -1
  8. data/ext/faiss/utils.h +1 -1
  9. data/lib/faiss/version.rb +1 -1
  10. data/vendor/faiss/faiss/AutoTune.cpp +292 -291
  11. data/vendor/faiss/faiss/AutoTune.h +55 -56
  12. data/vendor/faiss/faiss/Clustering.cpp +365 -194
  13. data/vendor/faiss/faiss/Clustering.h +102 -35
  14. data/vendor/faiss/faiss/IVFlib.cpp +171 -195
  15. data/vendor/faiss/faiss/IVFlib.h +48 -51
  16. data/vendor/faiss/faiss/Index.cpp +85 -103
  17. data/vendor/faiss/faiss/Index.h +54 -48
  18. data/vendor/faiss/faiss/Index2Layer.cpp +126 -224
  19. data/vendor/faiss/faiss/Index2Layer.h +22 -36
  20. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +407 -0
  21. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +195 -0
  22. data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
  23. data/vendor/faiss/faiss/IndexBinary.h +140 -132
  24. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
  25. data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
  26. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
  27. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
  28. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
  29. data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
  30. data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
  31. data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
  32. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
  33. data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
  34. data/vendor/faiss/faiss/IndexFlat.cpp +115 -176
  35. data/vendor/faiss/faiss/IndexFlat.h +42 -59
  36. data/vendor/faiss/faiss/IndexFlatCodes.cpp +67 -0
  37. data/vendor/faiss/faiss/IndexFlatCodes.h +47 -0
  38. data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
  39. data/vendor/faiss/faiss/IndexHNSW.h +57 -41
  40. data/vendor/faiss/faiss/IndexIVF.cpp +545 -453
  41. data/vendor/faiss/faiss/IndexIVF.h +169 -118
  42. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +316 -0
  43. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +121 -0
  44. data/vendor/faiss/faiss/IndexIVFFlat.cpp +247 -252
  45. data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
  46. data/vendor/faiss/faiss/IndexIVFPQ.cpp +459 -517
  47. data/vendor/faiss/faiss/IndexIVFPQ.h +75 -67
  48. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
  49. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
  50. data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
  51. data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
  52. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +163 -150
  53. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +38 -25
  54. data/vendor/faiss/faiss/IndexLSH.cpp +66 -113
  55. data/vendor/faiss/faiss/IndexLSH.h +20 -38
  56. data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
  57. data/vendor/faiss/faiss/IndexLattice.h +11 -16
  58. data/vendor/faiss/faiss/IndexNNDescent.cpp +229 -0
  59. data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
  60. data/vendor/faiss/faiss/IndexNSG.cpp +301 -0
  61. data/vendor/faiss/faiss/IndexNSG.h +85 -0
  62. data/vendor/faiss/faiss/IndexPQ.cpp +387 -495
  63. data/vendor/faiss/faiss/IndexPQ.h +64 -82
  64. data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
  65. data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
  66. data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
  67. data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
  68. data/vendor/faiss/faiss/IndexRefine.cpp +139 -127
  69. data/vendor/faiss/faiss/IndexRefine.h +32 -23
  70. data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
  71. data/vendor/faiss/faiss/IndexReplicas.h +62 -56
  72. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +111 -172
  73. data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -59
  74. data/vendor/faiss/faiss/IndexShards.cpp +256 -240
  75. data/vendor/faiss/faiss/IndexShards.h +85 -73
  76. data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
  77. data/vendor/faiss/faiss/MatrixStats.h +7 -10
  78. data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
  79. data/vendor/faiss/faiss/MetaIndexes.h +40 -34
  80. data/vendor/faiss/faiss/MetricType.h +7 -7
  81. data/vendor/faiss/faiss/VectorTransform.cpp +654 -475
  82. data/vendor/faiss/faiss/VectorTransform.h +64 -89
  83. data/vendor/faiss/faiss/clone_index.cpp +78 -73
  84. data/vendor/faiss/faiss/clone_index.h +4 -9
  85. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
  86. data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
  87. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +198 -171
  88. data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
  89. data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
  90. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
  91. data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
  92. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
  93. data/vendor/faiss/faiss/gpu/GpuIcmEncoder.h +60 -0
  94. data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
  95. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
  96. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
  97. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
  98. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
  99. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
  100. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
  101. data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
  102. data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
  103. data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
  104. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
  105. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
  106. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
  107. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
  108. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
  109. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
  110. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
  111. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
  112. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
  113. data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
  114. data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
  115. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
  116. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
  117. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
  118. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
  119. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
  120. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
  121. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
  122. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
  123. data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
  124. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
  125. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
  126. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
  127. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
  128. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
  129. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
  130. data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
  131. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +503 -0
  132. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +175 -0
  133. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
  134. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
  135. data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
  136. data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
  137. data/vendor/faiss/faiss/impl/FaissException.h +41 -29
  138. data/vendor/faiss/faiss/impl/HNSW.cpp +606 -617
  139. data/vendor/faiss/faiss/impl/HNSW.h +179 -200
  140. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +855 -0
  141. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +244 -0
  142. data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
  143. data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
  144. data/vendor/faiss/faiss/impl/NSG.cpp +679 -0
  145. data/vendor/faiss/faiss/impl/NSG.h +199 -0
  146. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
  147. data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
  148. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
  149. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
  150. data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
  151. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +758 -0
  152. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +188 -0
  153. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
  154. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +647 -707
  155. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
  156. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
  157. data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
  158. data/vendor/faiss/faiss/impl/index_read.cpp +631 -480
  159. data/vendor/faiss/faiss/impl/index_write.cpp +547 -407
  160. data/vendor/faiss/faiss/impl/io.cpp +76 -95
  161. data/vendor/faiss/faiss/impl/io.h +31 -41
  162. data/vendor/faiss/faiss/impl/io_macros.h +60 -29
  163. data/vendor/faiss/faiss/impl/kmeans1d.cpp +301 -0
  164. data/vendor/faiss/faiss/impl/kmeans1d.h +48 -0
  165. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
  166. data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
  167. data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
  168. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
  169. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
  170. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
  171. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
  172. data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
  173. data/vendor/faiss/faiss/index_factory.cpp +619 -397
  174. data/vendor/faiss/faiss/index_factory.h +8 -6
  175. data/vendor/faiss/faiss/index_io.h +23 -26
  176. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
  177. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
  178. data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
  179. data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
  180. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
  181. data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
  182. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
  183. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
  184. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
  185. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
  186. data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
  187. data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
  188. data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
  189. data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
  190. data/vendor/faiss/faiss/utils/Heap.h +186 -209
  191. data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
  192. data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
  193. data/vendor/faiss/faiss/utils/distances.cpp +305 -312
  194. data/vendor/faiss/faiss/utils/distances.h +170 -122
  195. data/vendor/faiss/faiss/utils/distances_simd.cpp +498 -508
  196. data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
  197. data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
  198. data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
  199. data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
  200. data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
  201. data/vendor/faiss/faiss/utils/hamming.h +62 -85
  202. data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
  203. data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
  204. data/vendor/faiss/faiss/utils/partitioning.h +26 -21
  205. data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
  206. data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
  207. data/vendor/faiss/faiss/utils/random.cpp +39 -63
  208. data/vendor/faiss/faiss/utils/random.h +13 -16
  209. data/vendor/faiss/faiss/utils/simdlib.h +4 -2
  210. data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
  211. data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
  212. data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
  213. data/vendor/faiss/faiss/utils/utils.cpp +304 -287
  214. data/vendor/faiss/faiss/utils/utils.h +54 -49
  215. metadata +29 -4
@@ -0,0 +1,679 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ // -*- c++ -*-
9
+
10
+ #include <faiss/impl/NSG.h>
11
+
12
+ #include <algorithm>
13
+ #include <memory>
14
+ #include <mutex>
15
+ #include <stack>
16
+
17
+ #include <faiss/impl/AuxIndexStructures.h>
18
+
19
+ namespace faiss {
20
+
21
+ namespace nsg {
22
+
23
+ namespace {
24
+
25
+ // It needs to be smaller than 0
26
+ constexpr int EMPTY_ID = -1;
27
+
28
+ /* Wrap the distance computer into one that negates the
29
+ distances. This makes supporting INNER_PRODUCE search easier */
30
+
31
+ struct NegativeDistanceComputer : DistanceComputer {
32
+ using idx_t = Index::idx_t;
33
+
34
+ /// owned by this
35
+ DistanceComputer* basedis;
36
+
37
+ explicit NegativeDistanceComputer(DistanceComputer* basedis)
38
+ : basedis(basedis) {}
39
+
40
+ void set_query(const float* x) override {
41
+ basedis->set_query(x);
42
+ }
43
+
44
+ /// compute distance of vector i to current query
45
+ float operator()(idx_t i) override {
46
+ return -(*basedis)(i);
47
+ }
48
+
49
+ /// compute distance between two stored vectors
50
+ float symmetric_dis(idx_t i, idx_t j) override {
51
+ return -basedis->symmetric_dis(i, j);
52
+ }
53
+
54
+ ~NegativeDistanceComputer() override {
55
+ delete basedis;
56
+ }
57
+ };
58
+
59
+ } // namespace
60
+
61
+ DistanceComputer* storage_distance_computer(const Index* storage) {
62
+ if (storage->metric_type == METRIC_INNER_PRODUCT) {
63
+ return new NegativeDistanceComputer(storage->get_distance_computer());
64
+ } else {
65
+ return storage->get_distance_computer();
66
+ }
67
+ }
68
+
69
+ } // namespace nsg
70
+
71
+ using namespace nsg;
72
+
73
+ using LockGuard = std::lock_guard<std::mutex>;
74
+
75
+ struct Neighbor {
76
+ int id;
77
+ float distance;
78
+ bool flag;
79
+
80
+ Neighbor() = default;
81
+ Neighbor(int id, float distance, bool f)
82
+ : id(id), distance(distance), flag(f) {}
83
+
84
+ inline bool operator<(const Neighbor& other) const {
85
+ return distance < other.distance;
86
+ }
87
+ };
88
+
89
+ struct Node {
90
+ int id;
91
+ float distance;
92
+
93
+ Node() = default;
94
+ Node(int id, float distance) : id(id), distance(distance) {}
95
+
96
+ inline bool operator<(const Node& other) const {
97
+ return distance < other.distance;
98
+ }
99
+ };
100
+
101
+ inline int insert_into_pool(Neighbor* addr, int K, Neighbor nn) {
102
+ // find the location to insert
103
+ int left = 0, right = K - 1;
104
+ if (addr[left].distance > nn.distance) {
105
+ memmove(&addr[left + 1], &addr[left], K * sizeof(Neighbor));
106
+ addr[left] = nn;
107
+ return left;
108
+ }
109
+ if (addr[right].distance < nn.distance) {
110
+ addr[K] = nn;
111
+ return K;
112
+ }
113
+ while (left < right - 1) {
114
+ int mid = (left + right) / 2;
115
+ if (addr[mid].distance > nn.distance) {
116
+ right = mid;
117
+ } else {
118
+ left = mid;
119
+ }
120
+ }
121
+ // check equal ID
122
+
123
+ while (left > 0) {
124
+ if (addr[left].distance < nn.distance) {
125
+ break;
126
+ }
127
+ if (addr[left].id == nn.id) {
128
+ return K + 1;
129
+ }
130
+ left--;
131
+ }
132
+ if (addr[left].id == nn.id || addr[right].id == nn.id) {
133
+ return K + 1;
134
+ }
135
+ memmove(&addr[right + 1], &addr[right], (K - right) * sizeof(Neighbor));
136
+ addr[right] = nn;
137
+ return right;
138
+ }
139
+
140
+ NSG::NSG(int R) : R(R), rng(0x0903) {
141
+ L = R + 32;
142
+ C = R + 100;
143
+ search_L = 16;
144
+ ntotal = 0;
145
+ is_built = false;
146
+ srand(0x1998);
147
+ }
148
+
149
+ void NSG::search(
150
+ DistanceComputer& dis,
151
+ int k,
152
+ idx_t* I,
153
+ float* D,
154
+ VisitedTable& vt) const {
155
+ FAISS_THROW_IF_NOT(is_built);
156
+ FAISS_THROW_IF_NOT(final_graph);
157
+
158
+ int pool_size = std::max(search_L, k);
159
+ std::vector<Neighbor> retset;
160
+ std::vector<Node> tmp;
161
+ search_on_graph<false>(
162
+ *final_graph, dis, vt, enterpoint, pool_size, retset, tmp);
163
+
164
+ for (size_t i = 0; i < k; i++) {
165
+ I[i] = retset[i].id;
166
+ D[i] = retset[i].distance;
167
+ }
168
+ }
169
+
170
+ void NSG::build(
171
+ Index* storage,
172
+ idx_t n,
173
+ const nsg::Graph<idx_t>& knn_graph,
174
+ bool verbose) {
175
+ FAISS_THROW_IF_NOT(!is_built && ntotal == 0);
176
+
177
+ if (verbose) {
178
+ printf("NSG::build R=%d, L=%d, C=%d\n", R, L, C);
179
+ }
180
+
181
+ ntotal = n;
182
+ init_graph(storage, knn_graph);
183
+
184
+ std::vector<int> degrees(n, 0);
185
+ {
186
+ nsg::Graph<Node> tmp_graph(n, R);
187
+
188
+ link(storage, knn_graph, tmp_graph, verbose);
189
+
190
+ final_graph = std::make_shared<nsg::Graph<int>>(n, R);
191
+ std::fill_n(final_graph->data, n * R, EMPTY_ID);
192
+
193
+ #pragma omp parallel for
194
+ for (int i = 0; i < n; i++) {
195
+ int cnt = 0;
196
+ for (int j = 0; j < R; j++) {
197
+ int id = tmp_graph.at(i, j).id;
198
+ if (id != EMPTY_ID) {
199
+ final_graph->at(i, cnt) = id;
200
+ cnt += 1;
201
+ }
202
+ degrees[i] = cnt;
203
+ }
204
+ }
205
+ }
206
+
207
+ int num_attached = tree_grow(storage, degrees);
208
+ check_graph();
209
+ is_built = true;
210
+
211
+ if (verbose) {
212
+ int max = 0, min = 1e6;
213
+ double avg = 0;
214
+
215
+ for (int i = 0; i < n; i++) {
216
+ int size = 0;
217
+ while (size < R && final_graph->at(i, size) != EMPTY_ID) {
218
+ size += 1;
219
+ }
220
+ max = std::max(size, max);
221
+ min = std::min(size, min);
222
+ avg += size;
223
+ }
224
+
225
+ avg = avg / n;
226
+ printf("Degree Statistics: Max = %d, Min = %d, Avg = %lf\n",
227
+ max,
228
+ min,
229
+ avg);
230
+ printf("Attached nodes: %d\n", num_attached);
231
+ }
232
+ }
233
+
234
+ void NSG::reset() {
235
+ final_graph.reset();
236
+ ntotal = 0;
237
+ is_built = false;
238
+ }
239
+
240
+ void NSG::init_graph(Index* storage, const nsg::Graph<idx_t>& knn_graph) {
241
+ int d = storage->d;
242
+ int n = storage->ntotal;
243
+
244
+ std::unique_ptr<float[]> center(new float[d]);
245
+ std::unique_ptr<float[]> tmp(new float[d]);
246
+ std::fill_n(center.get(), d, 0.0f);
247
+
248
+ for (int i = 0; i < n; i++) {
249
+ storage->reconstruct(i, tmp.get());
250
+ for (int j = 0; j < d; j++) {
251
+ center[j] += tmp[j];
252
+ }
253
+ }
254
+
255
+ for (int i = 0; i < d; i++) {
256
+ center[i] /= n;
257
+ }
258
+
259
+ std::vector<Neighbor> retset;
260
+ std::vector<Node> tmpset;
261
+
262
+ // random initialize navigating point
263
+ int ep = rng.rand_int(n);
264
+ std::unique_ptr<DistanceComputer> dis(storage_distance_computer(storage));
265
+
266
+ dis->set_query(center.get());
267
+ VisitedTable vt(ntotal);
268
+
269
+ // Do not collect the visited nodes
270
+ search_on_graph<false>(knn_graph, *dis, vt, ep, L, retset, tmpset);
271
+
272
+ // set enterpoint
273
+ enterpoint = retset[0].id;
274
+ }
275
+
276
+ template <bool collect_fullset, class index_t>
277
+ void NSG::search_on_graph(
278
+ const nsg::Graph<index_t>& graph,
279
+ DistanceComputer& dis,
280
+ VisitedTable& vt,
281
+ int ep,
282
+ int pool_size,
283
+ std::vector<Neighbor>& retset,
284
+ std::vector<Node>& fullset) const {
285
+ RandomGenerator gen(0x1234);
286
+ retset.resize(pool_size + 1);
287
+ std::vector<int> init_ids(pool_size);
288
+
289
+ int num_ids = 0;
290
+ for (int i = 0; i < init_ids.size() && i < graph.K; i++) {
291
+ int id = (int)graph.at(ep, i);
292
+ if (id < 0 || id >= ntotal) {
293
+ continue;
294
+ }
295
+
296
+ init_ids[i] = id;
297
+ vt.set(id);
298
+ num_ids += 1;
299
+ }
300
+
301
+ while (num_ids < pool_size) {
302
+ int id = gen.rand_int(ntotal);
303
+ if (vt.get(id)) {
304
+ continue;
305
+ }
306
+
307
+ init_ids[num_ids] = id;
308
+ num_ids++;
309
+ vt.set(id);
310
+ }
311
+
312
+ for (int i = 0; i < init_ids.size(); i++) {
313
+ int id = init_ids[i];
314
+
315
+ float dist = dis(id);
316
+ retset[i] = Neighbor(id, dist, true);
317
+
318
+ if (collect_fullset) {
319
+ fullset.emplace_back(retset[i].id, retset[i].distance);
320
+ }
321
+ }
322
+
323
+ std::sort(retset.begin(), retset.begin() + pool_size);
324
+
325
+ int k = 0;
326
+ while (k < pool_size) {
327
+ int updated_pos = pool_size;
328
+
329
+ if (retset[k].flag) {
330
+ retset[k].flag = false;
331
+ int n = retset[k].id;
332
+
333
+ for (int m = 0; m < graph.K; m++) {
334
+ int id = (int)graph.at(n, m);
335
+ if (id < 0 || id > ntotal || vt.get(id)) {
336
+ continue;
337
+ }
338
+ vt.set(id);
339
+
340
+ float dist = dis(id);
341
+ Neighbor nn(id, dist, true);
342
+ if (collect_fullset) {
343
+ fullset.emplace_back(id, dist);
344
+ }
345
+
346
+ if (dist >= retset[pool_size - 1].distance) {
347
+ continue;
348
+ }
349
+
350
+ int r = insert_into_pool(retset.data(), pool_size, nn);
351
+
352
+ updated_pos = std::min(updated_pos, r);
353
+ }
354
+ }
355
+
356
+ k = (updated_pos <= k) ? updated_pos : (k + 1);
357
+ }
358
+ }
359
+
360
+ void NSG::link(
361
+ Index* storage,
362
+ const nsg::Graph<idx_t>& knn_graph,
363
+ nsg::Graph<Node>& graph,
364
+ bool /* verbose */) {
365
+ #pragma omp parallel
366
+ {
367
+ std::unique_ptr<float[]> vec(new float[storage->d]);
368
+
369
+ std::vector<Node> pool;
370
+ std::vector<Neighbor> tmp;
371
+
372
+ VisitedTable vt(ntotal);
373
+ std::unique_ptr<DistanceComputer> dis(
374
+ storage_distance_computer(storage));
375
+
376
+ #pragma omp for schedule(dynamic, 100)
377
+ for (int i = 0; i < ntotal; i++) {
378
+ storage->reconstruct(i, vec.get());
379
+ dis->set_query(vec.get());
380
+
381
+ // Collect the visited nodes into pool
382
+ search_on_graph<true>(
383
+ knn_graph, *dis, vt, enterpoint, L, tmp, pool);
384
+
385
+ sync_prune(i, pool, *dis, vt, knn_graph, graph);
386
+
387
+ pool.clear();
388
+ tmp.clear();
389
+ vt.advance();
390
+ }
391
+ } // omp parallel
392
+
393
+ std::vector<std::mutex> locks(ntotal);
394
+ #pragma omp parallel
395
+ {
396
+ std::unique_ptr<DistanceComputer> dis(
397
+ storage_distance_computer(storage));
398
+
399
+ #pragma omp for schedule(dynamic, 100)
400
+ for (int i = 0; i < ntotal; ++i) {
401
+ add_reverse_links(i, locks, *dis, graph);
402
+ }
403
+ } // omp parallel
404
+ }
405
+
406
+ void NSG::sync_prune(
407
+ int q,
408
+ std::vector<Node>& pool,
409
+ DistanceComputer& dis,
410
+ VisitedTable& vt,
411
+ const nsg::Graph<idx_t>& knn_graph,
412
+ nsg::Graph<Node>& graph) {
413
+ for (int i = 0; i < knn_graph.K; i++) {
414
+ int id = knn_graph.at(q, i);
415
+ if (id < 0 || id >= ntotal || vt.get(id)) {
416
+ continue;
417
+ }
418
+
419
+ float dist = dis.symmetric_dis(q, id);
420
+ pool.emplace_back(id, dist);
421
+ }
422
+
423
+ std::sort(pool.begin(), pool.end());
424
+
425
+ std::vector<Node> result;
426
+
427
+ int start = 0;
428
+ if (pool[start].id == q) {
429
+ start++;
430
+ }
431
+ result.push_back(pool[start]);
432
+
433
+ while (result.size() < R && (++start) < pool.size() && start < C) {
434
+ auto& p = pool[start];
435
+ bool occlude = false;
436
+ for (int t = 0; t < result.size(); t++) {
437
+ if (p.id == result[t].id) {
438
+ occlude = true;
439
+ break;
440
+ }
441
+ float djk = dis.symmetric_dis(result[t].id, p.id);
442
+ if (djk < p.distance /* dik */) {
443
+ occlude = true;
444
+ break;
445
+ }
446
+ }
447
+ if (!occlude) {
448
+ result.push_back(p);
449
+ }
450
+ }
451
+
452
+ for (size_t i = 0; i < R; i++) {
453
+ if (i < result.size()) {
454
+ graph.at(q, i).id = result[i].id;
455
+ graph.at(q, i).distance = result[i].distance;
456
+ } else {
457
+ graph.at(q, i).id = EMPTY_ID;
458
+ }
459
+ }
460
+ }
461
+
462
+ void NSG::add_reverse_links(
463
+ int q,
464
+ std::vector<std::mutex>& locks,
465
+ DistanceComputer& dis,
466
+ nsg::Graph<Node>& graph) {
467
+ for (size_t i = 0; i < R; i++) {
468
+ if (graph.at(q, i).id == EMPTY_ID) {
469
+ break;
470
+ }
471
+
472
+ Node sn(q, graph.at(q, i).distance);
473
+ int des = graph.at(q, i).id;
474
+
475
+ std::vector<Node> tmp_pool;
476
+ int dup = 0;
477
+ {
478
+ LockGuard guard(locks[des]);
479
+ for (int j = 0; j < R; j++) {
480
+ if (graph.at(des, j).id == EMPTY_ID) {
481
+ break;
482
+ }
483
+ if (q == graph.at(des, j).id) {
484
+ dup = 1;
485
+ break;
486
+ }
487
+ tmp_pool.push_back(graph.at(des, j));
488
+ }
489
+ }
490
+
491
+ if (dup) {
492
+ continue;
493
+ }
494
+
495
+ tmp_pool.push_back(sn);
496
+ if (tmp_pool.size() > R) {
497
+ std::vector<Node> result;
498
+ int start = 0;
499
+ std::sort(tmp_pool.begin(), tmp_pool.end());
500
+ result.push_back(tmp_pool[start]);
501
+
502
+ while (result.size() < R && (++start) < tmp_pool.size()) {
503
+ auto& p = tmp_pool[start];
504
+ bool occlude = false;
505
+
506
+ for (int t = 0; t < result.size(); t++) {
507
+ if (p.id == result[t].id) {
508
+ occlude = true;
509
+ break;
510
+ }
511
+ float djk = dis.symmetric_dis(result[t].id, p.id);
512
+ if (djk < p.distance /* dik */) {
513
+ occlude = true;
514
+ break;
515
+ }
516
+ }
517
+
518
+ if (!occlude) {
519
+ result.push_back(p);
520
+ }
521
+ }
522
+
523
+ {
524
+ LockGuard guard(locks[des]);
525
+ for (int t = 0; t < result.size(); t++) {
526
+ graph.at(des, t) = result[t];
527
+ }
528
+ }
529
+
530
+ } else {
531
+ LockGuard guard(locks[des]);
532
+ for (int t = 0; t < R; t++) {
533
+ if (graph.at(des, t).id == EMPTY_ID) {
534
+ graph.at(des, t) = sn;
535
+ break;
536
+ }
537
+ }
538
+ }
539
+ }
540
+ }
541
+
542
+ int NSG::tree_grow(Index* storage, std::vector<int>& degrees) {
543
+ int root = enterpoint;
544
+ VisitedTable vt(ntotal);
545
+ VisitedTable vt2(ntotal);
546
+
547
+ int num_attached = 0;
548
+ int cnt = 0;
549
+ while (true) {
550
+ cnt = dfs(vt, root, cnt);
551
+ if (cnt >= ntotal) {
552
+ break;
553
+ }
554
+
555
+ root = attach_unlinked(storage, vt, vt2, degrees);
556
+ vt2.advance();
557
+ num_attached += 1;
558
+ }
559
+
560
+ return num_attached;
561
+ }
562
+
563
+ int NSG::dfs(VisitedTable& vt, int root, int cnt) const {
564
+ int node = root;
565
+ std::stack<int> stack;
566
+ stack.push(root);
567
+
568
+ if (!vt.get(root)) {
569
+ cnt++;
570
+ }
571
+ vt.set(root);
572
+
573
+ while (!stack.empty()) {
574
+ int next = EMPTY_ID;
575
+ for (int i = 0; i < R; i++) {
576
+ int id = final_graph->at(node, i);
577
+ if (id != EMPTY_ID && !vt.get(id)) {
578
+ next = id;
579
+ break;
580
+ }
581
+ }
582
+
583
+ if (next == EMPTY_ID) {
584
+ stack.pop();
585
+ if (stack.empty()) {
586
+ break;
587
+ }
588
+ node = stack.top();
589
+ continue;
590
+ }
591
+ node = next;
592
+ vt.set(node);
593
+ stack.push(node);
594
+ cnt++;
595
+ }
596
+
597
+ return cnt;
598
+ }
599
+
600
+ int NSG::attach_unlinked(
601
+ Index* storage,
602
+ VisitedTable& vt,
603
+ VisitedTable& vt2,
604
+ std::vector<int>& degrees) {
605
+ /* NOTE: This implementation is slightly different from the original paper.
606
+ *
607
+ * Instead of connecting the unlinked node to the nearest point in the
608
+ * spanning tree which will increase the maximum degree of the graph and
609
+ * also make the graph hard to maintain, this implementation links the
610
+ * unlinked node to the nearest node of which the degree is smaller than R.
611
+ * It will keep the degree of all nodes to be no more than `R`.
612
+ */
613
+
614
+ // find one unlinked node
615
+ int id = EMPTY_ID;
616
+ for (int i = 0; i < ntotal; i++) {
617
+ if (!vt.get(i)) {
618
+ id = i;
619
+ break;
620
+ }
621
+ }
622
+
623
+ if (id == EMPTY_ID) {
624
+ return EMPTY_ID; // No Unlinked Node
625
+ }
626
+
627
+ std::vector<Neighbor> tmp;
628
+ std::vector<Node> pool;
629
+
630
+ std::unique_ptr<DistanceComputer> dis(storage_distance_computer(storage));
631
+ std::unique_ptr<float[]> vec(new float[storage->d]);
632
+
633
+ storage->reconstruct(id, vec.get());
634
+ dis->set_query(vec.get());
635
+
636
+ // Collect the visited nodes into pool
637
+ search_on_graph<true>(
638
+ *final_graph, *dis, vt2, enterpoint, search_L, tmp, pool);
639
+
640
+ std::sort(pool.begin(), pool.end());
641
+
642
+ int node;
643
+ bool found = false;
644
+ for (int i = 0; i < pool.size(); i++) {
645
+ node = pool[i].id;
646
+ if (degrees[node] < R && node != id) {
647
+ found = true;
648
+ break;
649
+ }
650
+ }
651
+
652
+ // randomly choice annother node
653
+ if (!found) {
654
+ do {
655
+ node = rng.rand_int(ntotal);
656
+ if (vt.get(node) && degrees[node] < R && node != id) {
657
+ found = true;
658
+ }
659
+ } while (!found);
660
+ }
661
+
662
+ int pos = degrees[node];
663
+ final_graph->at(node, pos) = id; // replace
664
+ degrees[node] += 1;
665
+
666
+ return node;
667
+ }
668
+
669
+ void NSG::check_graph() const {
670
+ #pragma omp parallel for
671
+ for (int i = 0; i < ntotal; i++) {
672
+ for (int j = 0; j < R; j++) {
673
+ int id = final_graph->at(i, j);
674
+ FAISS_THROW_IF_NOT(id < ntotal && (id >= 0 || id == EMPTY_ID));
675
+ }
676
+ }
677
+ }
678
+
679
+ } // namespace faiss