faiss 0.1.7 → 0.2.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (219)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +18 -0
  3. data/README.md +7 -7
  4. data/ext/faiss/ext.cpp +1 -1
  5. data/ext/faiss/extconf.rb +8 -2
  6. data/ext/faiss/index.cpp +102 -69
  7. data/ext/faiss/index_binary.cpp +24 -30
  8. data/ext/faiss/kmeans.cpp +20 -16
  9. data/ext/faiss/numo.hpp +867 -0
  10. data/ext/faiss/pca_matrix.cpp +13 -14
  11. data/ext/faiss/product_quantizer.cpp +23 -24
  12. data/ext/faiss/utils.cpp +10 -37
  13. data/ext/faiss/utils.h +2 -13
  14. data/lib/faiss/version.rb +1 -1
  15. data/lib/faiss.rb +0 -5
  16. data/vendor/faiss/faiss/AutoTune.cpp +292 -291
  17. data/vendor/faiss/faiss/AutoTune.h +55 -56
  18. data/vendor/faiss/faiss/Clustering.cpp +334 -195
  19. data/vendor/faiss/faiss/Clustering.h +88 -35
  20. data/vendor/faiss/faiss/IVFlib.cpp +171 -195
  21. data/vendor/faiss/faiss/IVFlib.h +48 -51
  22. data/vendor/faiss/faiss/Index.cpp +85 -103
  23. data/vendor/faiss/faiss/Index.h +54 -48
  24. data/vendor/faiss/faiss/Index2Layer.cpp +139 -164
  25. data/vendor/faiss/faiss/Index2Layer.h +22 -22
  26. data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
  27. data/vendor/faiss/faiss/IndexBinary.h +140 -132
  28. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
  29. data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
  30. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
  31. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
  32. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
  33. data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
  34. data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
  35. data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
  36. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
  37. data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
  38. data/vendor/faiss/faiss/IndexFlat.cpp +116 -147
  39. data/vendor/faiss/faiss/IndexFlat.h +35 -46
  40. data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
  41. data/vendor/faiss/faiss/IndexHNSW.h +57 -41
  42. data/vendor/faiss/faiss/IndexIVF.cpp +474 -454
  43. data/vendor/faiss/faiss/IndexIVF.h +146 -113
  44. data/vendor/faiss/faiss/IndexIVFFlat.cpp +248 -250
  45. data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
  46. data/vendor/faiss/faiss/IndexIVFPQ.cpp +457 -516
  47. data/vendor/faiss/faiss/IndexIVFPQ.h +74 -66
  48. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
  49. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
  50. data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
  51. data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
  52. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +125 -133
  53. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +19 -21
  54. data/vendor/faiss/faiss/IndexLSH.cpp +75 -96
  55. data/vendor/faiss/faiss/IndexLSH.h +21 -26
  56. data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
  57. data/vendor/faiss/faiss/IndexLattice.h +11 -16
  58. data/vendor/faiss/faiss/IndexNNDescent.cpp +231 -0
  59. data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
  60. data/vendor/faiss/faiss/IndexNSG.cpp +303 -0
  61. data/vendor/faiss/faiss/IndexNSG.h +85 -0
  62. data/vendor/faiss/faiss/IndexPQ.cpp +405 -464
  63. data/vendor/faiss/faiss/IndexPQ.h +64 -67
  64. data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
  65. data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
  66. data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
  67. data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
  68. data/vendor/faiss/faiss/IndexRefine.cpp +115 -131
  69. data/vendor/faiss/faiss/IndexRefine.h +22 -23
  70. data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
  71. data/vendor/faiss/faiss/IndexReplicas.h +62 -56
  72. data/vendor/faiss/faiss/IndexResidual.cpp +291 -0
  73. data/vendor/faiss/faiss/IndexResidual.h +152 -0
  74. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +120 -155
  75. data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -45
  76. data/vendor/faiss/faiss/IndexShards.cpp +256 -240
  77. data/vendor/faiss/faiss/IndexShards.h +85 -73
  78. data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
  79. data/vendor/faiss/faiss/MatrixStats.h +7 -10
  80. data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
  81. data/vendor/faiss/faiss/MetaIndexes.h +40 -34
  82. data/vendor/faiss/faiss/MetricType.h +7 -7
  83. data/vendor/faiss/faiss/VectorTransform.cpp +652 -474
  84. data/vendor/faiss/faiss/VectorTransform.h +61 -89
  85. data/vendor/faiss/faiss/clone_index.cpp +77 -73
  86. data/vendor/faiss/faiss/clone_index.h +4 -9
  87. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
  88. data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
  89. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +197 -170
  90. data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
  91. data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
  92. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
  93. data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
  94. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
  95. data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
  96. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
  97. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
  98. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
  99. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
  100. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
  101. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
  102. data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
  103. data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
  104. data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
  105. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
  106. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
  107. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
  108. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
  109. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
  110. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
  111. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
  112. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
  113. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
  114. data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
  115. data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
  116. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
  117. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
  118. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
  119. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
  120. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
  121. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
  122. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
  123. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
  124. data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
  125. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
  126. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
  127. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
  128. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
  129. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
  130. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
  131. data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
  132. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +270 -0
  133. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +115 -0
  134. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
  135. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
  136. data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
  137. data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
  138. data/vendor/faiss/faiss/impl/FaissException.h +41 -29
  139. data/vendor/faiss/faiss/impl/HNSW.cpp +595 -611
  140. data/vendor/faiss/faiss/impl/HNSW.h +179 -200
  141. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +672 -0
  142. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +172 -0
  143. data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
  144. data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
  145. data/vendor/faiss/faiss/impl/NSG.cpp +682 -0
  146. data/vendor/faiss/faiss/impl/NSG.h +199 -0
  147. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
  148. data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
  149. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
  150. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
  151. data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
  152. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +448 -0
  153. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +130 -0
  154. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
  155. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +648 -701
  156. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
  157. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
  158. data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
  159. data/vendor/faiss/faiss/impl/index_read.cpp +547 -479
  160. data/vendor/faiss/faiss/impl/index_write.cpp +497 -407
  161. data/vendor/faiss/faiss/impl/io.cpp +75 -94
  162. data/vendor/faiss/faiss/impl/io.h +31 -41
  163. data/vendor/faiss/faiss/impl/io_macros.h +40 -29
  164. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
  165. data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
  166. data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
  167. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
  168. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
  169. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
  170. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
  171. data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
  172. data/vendor/faiss/faiss/index_factory.cpp +269 -218
  173. data/vendor/faiss/faiss/index_factory.h +6 -7
  174. data/vendor/faiss/faiss/index_io.h +23 -26
  175. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
  176. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
  177. data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
  178. data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
  179. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
  180. data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
  181. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
  182. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
  183. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
  184. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
  185. data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
  186. data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
  187. data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
  188. data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
  189. data/vendor/faiss/faiss/utils/Heap.h +186 -209
  190. data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
  191. data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
  192. data/vendor/faiss/faiss/utils/distances.cpp +301 -310
  193. data/vendor/faiss/faiss/utils/distances.h +133 -118
  194. data/vendor/faiss/faiss/utils/distances_simd.cpp +456 -516
  195. data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
  196. data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
  197. data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
  198. data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
  199. data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
  200. data/vendor/faiss/faiss/utils/hamming.h +62 -85
  201. data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
  202. data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
  203. data/vendor/faiss/faiss/utils/partitioning.h +26 -21
  204. data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
  205. data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
  206. data/vendor/faiss/faiss/utils/random.cpp +39 -63
  207. data/vendor/faiss/faiss/utils/random.h +13 -16
  208. data/vendor/faiss/faiss/utils/simdlib.h +4 -2
  209. data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
  210. data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
  211. data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
  212. data/vendor/faiss/faiss/utils/utils.cpp +304 -287
  213. data/vendor/faiss/faiss/utils/utils.h +53 -48
  214. metadata +26 -12
  215. data/lib/faiss/index.rb +0 -20
  216. data/lib/faiss/index_binary.rb +0 -20
  217. data/lib/faiss/kmeans.rb +0 -15
  218. data/lib/faiss/pca_matrix.rb +0 -15
  219. data/lib/faiss/product_quantizer.rb +0 -22
@@ -0,0 +1,682 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ // -*- c++ -*-
9
+
10
#include <faiss/impl/NSG.h>

#include <algorithm>
#include <cstdio>
#include <cstring>
#include <limits>
#include <memory>
#include <mutex>
#include <stack>
#include <vector>

#include <faiss/impl/AuxIndexStructures.h>
18
+
19
+ namespace faiss {
20
+
21
+ namespace nsg {
22
+
23
+ namespace {
24
+
25
+ // It needs to be smaller than 0
26
+ constexpr int EMPTY_ID = -1;
27
+
28
+ /* Wrap the distance computer into one that negates the
29
+ distances. This makes supporting INNER_PRODUCE search easier */
30
+
31
+ struct NegativeDistanceComputer : DistanceComputer {
32
+ using idx_t = Index::idx_t;
33
+
34
+ /// owned by this
35
+ DistanceComputer* basedis;
36
+
37
+ explicit NegativeDistanceComputer(DistanceComputer* basedis)
38
+ : basedis(basedis) {}
39
+
40
+ void set_query(const float* x) override {
41
+ basedis->set_query(x);
42
+ }
43
+
44
+ /// compute distance of vector i to current query
45
+ float operator()(idx_t i) override {
46
+ return -(*basedis)(i);
47
+ }
48
+
49
+ /// compute distance between two stored vectors
50
+ float symmetric_dis(idx_t i, idx_t j) override {
51
+ return -basedis->symmetric_dis(i, j);
52
+ }
53
+
54
+ ~NegativeDistanceComputer() override {
55
+ delete basedis;
56
+ }
57
+ };
58
+
59
+ } // namespace
60
+
61
+ DistanceComputer* storage_distance_computer(const Index* storage) {
62
+ if (storage->metric_type == METRIC_INNER_PRODUCT) {
63
+ return new NegativeDistanceComputer(storage->get_distance_computer());
64
+ } else {
65
+ return storage->get_distance_computer();
66
+ }
67
+ }
68
+
69
+ } // namespace nsg
70
+
71
+ using namespace nsg;
72
+
73
+ using LockGuard = std::lock_guard<std::mutex>;
74
+
75
/// A search-pool candidate: vertex id, its distance to the query, and whether
/// it still has to be expanded (true = not expanded yet).
struct Neighbor {
    int id;
    float distance;
    bool flag;

    Neighbor() = default;
    Neighbor(int id_, float dist_, bool expand)
            : id(id_), distance(dist_), flag(expand) {}

    /// Candidates are ordered by increasing distance only.
    bool operator<(const Neighbor& rhs) const {
        return distance < rhs.distance;
    }
};

/// An edge/candidate without the expansion flag: (vertex id, distance).
struct Node {
    int id;
    float distance;

    Node() = default;
    Node(int id_, float dist_) : id(id_), distance(dist_) {}

    /// Ordered by increasing distance only.
    bool operator<(const Node& rhs) const {
        return distance < rhs.distance;
    }
};

/** Insert `nn` into the pool `addr[0..K)`, which is sorted by distance.
 *
 * `addr` must have room for K + 1 entries: the element pushed off the end
 * lands in addr[K].
 *
 * @return the position where nn was stored (K if it was only appended into
 *         the spare slot), or K + 1 if an entry with the same id and an equal
 *         distance was found, in which case nothing is modified.
 */
inline int insert_into_pool(Neighbor* addr, int K, Neighbor nn) {
    // Closer than everything: shift the whole pool one slot to the right.
    if (nn.distance < addr[0].distance) {
        memmove(addr + 1, addr, K * sizeof(Neighbor));
        addr[0] = nn;
        return 0;
    }
    // Farther than everything: park it in the spare slot.
    if (nn.distance > addr[K - 1].distance) {
        addr[K] = nn;
        return K;
    }

    // Binary search for adjacent lo/hi with addr[lo] <= nn < addr[hi]
    // (hi stays at K - 1 when the last entry ties with nn).
    int lo = 0;
    int hi = K - 1;
    while (hi - lo > 1) {
        const int mid = (lo + hi) / 2;
        if (nn.distance < addr[mid].distance) {
            hi = mid;
        } else {
            lo = mid;
        }
    }

    // Walk left through entries tied with nn, looking for the same id.
    for (; lo > 0 && !(addr[lo].distance < nn.distance); --lo) {
        if (addr[lo].id == nn.id) {
            return K + 1;
        }
    }
    if (addr[lo].id == nn.id || addr[hi].id == nn.id) {
        return K + 1;
    }

    memmove(addr + hi + 1, addr + hi, (K - hi) * sizeof(Neighbor));
    addr[hi] = nn;
    return hi;
}
139
+
140
+ NSG::NSG(int R) : R(R), rng(0x0903) {
141
+ L = R + 32;
142
+ C = R + 100;
143
+ search_L = 16;
144
+ ntotal = 0;
145
+ is_built = false;
146
+ srand(0x1998);
147
+ }
148
+
149
+ void NSG::search(
150
+ DistanceComputer& dis,
151
+ int k,
152
+ idx_t* I,
153
+ float* D,
154
+ VisitedTable& vt) const {
155
+ FAISS_THROW_IF_NOT(is_built);
156
+ FAISS_THROW_IF_NOT(final_graph);
157
+
158
+ int pool_size = std::max(search_L, k);
159
+ std::vector<Neighbor> retset;
160
+ std::vector<Node> tmp;
161
+ search_on_graph<false>(
162
+ *final_graph, dis, vt, enterpoint, pool_size, retset, tmp);
163
+
164
+ std::partial_sort(
165
+ retset.begin(), retset.begin() + k, retset.begin() + pool_size);
166
+
167
+ for (size_t i = 0; i < k; i++) {
168
+ I[i] = retset[i].id;
169
+ D[i] = retset[i].distance;
170
+ }
171
+ }
172
+
173
+ void NSG::build(
174
+ Index* storage,
175
+ idx_t n,
176
+ const nsg::Graph<idx_t>& knn_graph,
177
+ bool verbose) {
178
+ FAISS_THROW_IF_NOT(!is_built && ntotal == 0);
179
+
180
+ if (verbose) {
181
+ printf("NSG::build R=%d, L=%d, C=%d\n", R, L, C);
182
+ }
183
+
184
+ ntotal = n;
185
+ init_graph(storage, knn_graph);
186
+
187
+ std::vector<int> degrees(n, 0);
188
+ {
189
+ nsg::Graph<Node> tmp_graph(n, R);
190
+
191
+ link(storage, knn_graph, tmp_graph, verbose);
192
+
193
+ final_graph = std::make_shared<nsg::Graph<int>>(n, R);
194
+ std::fill_n(final_graph->data, n * R, EMPTY_ID);
195
+
196
+ #pragma omp parallel for
197
+ for (int i = 0; i < n; i++) {
198
+ int cnt = 0;
199
+ for (int j = 0; j < R; j++) {
200
+ int id = tmp_graph.at(i, j).id;
201
+ if (id != EMPTY_ID) {
202
+ final_graph->at(i, cnt) = id;
203
+ cnt += 1;
204
+ }
205
+ degrees[i] = cnt;
206
+ }
207
+ }
208
+ }
209
+
210
+ int num_attached = tree_grow(storage, degrees);
211
+ check_graph();
212
+ is_built = true;
213
+
214
+ if (verbose) {
215
+ int max = 0, min = 1e6;
216
+ double avg = 0;
217
+
218
+ for (int i = 0; i < n; i++) {
219
+ int size = 0;
220
+ while (size < R && final_graph->at(i, size) != EMPTY_ID) {
221
+ size += 1;
222
+ }
223
+ max = std::max(size, max);
224
+ min = std::min(size, min);
225
+ avg += size;
226
+ }
227
+
228
+ avg = avg / n;
229
+ printf("Degree Statistics: Max = %d, Min = %d, Avg = %lf\n",
230
+ max,
231
+ min,
232
+ avg);
233
+ printf("Attached nodes: %d\n", num_attached);
234
+ }
235
+ }
236
+
237
+ void NSG::reset() {
238
+ final_graph.reset();
239
+ ntotal = 0;
240
+ is_built = false;
241
+ }
242
+
243
+ void NSG::init_graph(Index* storage, const nsg::Graph<idx_t>& knn_graph) {
244
+ int d = storage->d;
245
+ int n = storage->ntotal;
246
+
247
+ std::unique_ptr<float[]> center(new float[d]);
248
+ std::unique_ptr<float[]> tmp(new float[d]);
249
+ std::fill_n(center.get(), d, 0.0f);
250
+
251
+ for (int i = 0; i < n; i++) {
252
+ storage->reconstruct(i, tmp.get());
253
+ for (int j = 0; j < d; j++) {
254
+ center[j] += tmp[j];
255
+ }
256
+ }
257
+
258
+ for (int i = 0; i < d; i++) {
259
+ center[i] /= n;
260
+ }
261
+
262
+ std::vector<Neighbor> retset;
263
+ std::vector<Node> tmpset;
264
+
265
+ // random initialize navigating point
266
+ int ep = rng.rand_int(n);
267
+ std::unique_ptr<DistanceComputer> dis(storage_distance_computer(storage));
268
+
269
+ dis->set_query(center.get());
270
+ VisitedTable vt(ntotal);
271
+
272
+ // Do not collect the visited nodes
273
+ search_on_graph<false>(knn_graph, *dis, vt, ep, L, retset, tmpset);
274
+
275
+ // set enterpoint
276
+ enterpoint = retset[0].id;
277
+ }
278
+
279
+ template <bool collect_fullset, class index_t>
280
+ void NSG::search_on_graph(
281
+ const nsg::Graph<index_t>& graph,
282
+ DistanceComputer& dis,
283
+ VisitedTable& vt,
284
+ int ep,
285
+ int pool_size,
286
+ std::vector<Neighbor>& retset,
287
+ std::vector<Node>& fullset) const {
288
+ RandomGenerator gen(0x1234);
289
+ retset.resize(pool_size + 1);
290
+ std::vector<int> init_ids(pool_size);
291
+
292
+ int num_ids = 0;
293
+ for (int i = 0; i < init_ids.size() && i < graph.K; i++) {
294
+ int id = (int)graph.at(ep, i);
295
+ if (id < 0 || id >= ntotal) {
296
+ continue;
297
+ }
298
+
299
+ init_ids[i] = id;
300
+ vt.set(id);
301
+ num_ids += 1;
302
+ }
303
+
304
+ while (num_ids < pool_size) {
305
+ int id = gen.rand_int(ntotal);
306
+ if (vt.get(id)) {
307
+ continue;
308
+ }
309
+
310
+ init_ids[num_ids] = id;
311
+ num_ids++;
312
+ vt.set(id);
313
+ }
314
+
315
+ for (int i = 0; i < init_ids.size(); i++) {
316
+ int id = init_ids[i];
317
+
318
+ float dist = dis(id);
319
+ retset[i] = Neighbor(id, dist, true);
320
+
321
+ if (collect_fullset) {
322
+ fullset.emplace_back(retset[i].id, retset[i].distance);
323
+ }
324
+ }
325
+
326
+ std::sort(retset.begin(), retset.begin() + pool_size);
327
+
328
+ int k = 0;
329
+ while (k < pool_size) {
330
+ int updated_pos = pool_size;
331
+
332
+ if (retset[k].flag) {
333
+ retset[k].flag = false;
334
+ int n = retset[k].id;
335
+
336
+ for (int m = 0; m < graph.K; m++) {
337
+ int id = (int)graph.at(n, m);
338
+ if (id < 0 || id > ntotal || vt.get(id)) {
339
+ continue;
340
+ }
341
+ vt.set(id);
342
+
343
+ float dist = dis(id);
344
+ Neighbor nn(id, dist, true);
345
+ if (collect_fullset) {
346
+ fullset.emplace_back(id, dist);
347
+ }
348
+
349
+ if (dist >= retset[pool_size - 1].distance) {
350
+ continue;
351
+ }
352
+
353
+ int r = insert_into_pool(retset.data(), pool_size, nn);
354
+
355
+ updated_pos = std::min(updated_pos, r);
356
+ }
357
+ }
358
+
359
+ k = (updated_pos <= k) ? updated_pos : (k + 1);
360
+ }
361
+ }
362
+
363
+ void NSG::link(
364
+ Index* storage,
365
+ const nsg::Graph<idx_t>& knn_graph,
366
+ nsg::Graph<Node>& graph,
367
+ bool /* verbose */) {
368
+ #pragma omp parallel
369
+ {
370
+ std::unique_ptr<float[]> vec(new float[storage->d]);
371
+
372
+ std::vector<Node> pool;
373
+ std::vector<Neighbor> tmp;
374
+
375
+ VisitedTable vt(ntotal);
376
+ std::unique_ptr<DistanceComputer> dis(
377
+ storage_distance_computer(storage));
378
+
379
+ #pragma omp for schedule(dynamic, 100)
380
+ for (int i = 0; i < ntotal; i++) {
381
+ storage->reconstruct(i, vec.get());
382
+ dis->set_query(vec.get());
383
+
384
+ // Collect the visited nodes into pool
385
+ search_on_graph<true>(
386
+ knn_graph, *dis, vt, enterpoint, L, tmp, pool);
387
+
388
+ sync_prune(i, pool, *dis, vt, knn_graph, graph);
389
+
390
+ pool.clear();
391
+ tmp.clear();
392
+ vt.advance();
393
+ }
394
+ } // omp parallel
395
+
396
+ std::vector<std::mutex> locks(ntotal);
397
+ #pragma omp parallel
398
+ {
399
+ std::unique_ptr<DistanceComputer> dis(
400
+ storage_distance_computer(storage));
401
+
402
+ #pragma omp for schedule(dynamic, 100)
403
+ for (int i = 0; i < ntotal; ++i) {
404
+ add_reverse_links(i, locks, *dis, graph);
405
+ }
406
+ } // omp parallel
407
+ }
408
+
409
+ void NSG::sync_prune(
410
+ int q,
411
+ std::vector<Node>& pool,
412
+ DistanceComputer& dis,
413
+ VisitedTable& vt,
414
+ const nsg::Graph<idx_t>& knn_graph,
415
+ nsg::Graph<Node>& graph) {
416
+ for (int i = 0; i < knn_graph.K; i++) {
417
+ int id = knn_graph.at(q, i);
418
+ if (id < 0 || id >= ntotal || vt.get(id)) {
419
+ continue;
420
+ }
421
+
422
+ float dist = dis.symmetric_dis(q, id);
423
+ pool.emplace_back(id, dist);
424
+ }
425
+
426
+ std::sort(pool.begin(), pool.end());
427
+
428
+ std::vector<Node> result;
429
+
430
+ int start = 0;
431
+ if (pool[start].id == q) {
432
+ start++;
433
+ }
434
+ result.push_back(pool[start]);
435
+
436
+ while (result.size() < R && (++start) < pool.size() && start < C) {
437
+ auto& p = pool[start];
438
+ bool occlude = false;
439
+ for (int t = 0; t < result.size(); t++) {
440
+ if (p.id == result[t].id) {
441
+ occlude = true;
442
+ break;
443
+ }
444
+ float djk = dis.symmetric_dis(result[t].id, p.id);
445
+ if (djk < p.distance /* dik */) {
446
+ occlude = true;
447
+ break;
448
+ }
449
+ }
450
+ if (!occlude) {
451
+ result.push_back(p);
452
+ }
453
+ }
454
+
455
+ for (size_t i = 0; i < R; i++) {
456
+ if (i < result.size()) {
457
+ graph.at(q, i).id = result[i].id;
458
+ graph.at(q, i).distance = result[i].distance;
459
+ } else {
460
+ graph.at(q, i).id = EMPTY_ID;
461
+ }
462
+ }
463
+ }
464
+
465
+ void NSG::add_reverse_links(
466
+ int q,
467
+ std::vector<std::mutex>& locks,
468
+ DistanceComputer& dis,
469
+ nsg::Graph<Node>& graph) {
470
+ for (size_t i = 0; i < R; i++) {
471
+ if (graph.at(q, i).id == EMPTY_ID) {
472
+ break;
473
+ }
474
+
475
+ Node sn(q, graph.at(q, i).distance);
476
+ int des = graph.at(q, i).id;
477
+
478
+ std::vector<Node> tmp_pool;
479
+ int dup = 0;
480
+ {
481
+ LockGuard guard(locks[des]);
482
+ for (int j = 0; j < R; j++) {
483
+ if (graph.at(des, j).id == EMPTY_ID) {
484
+ break;
485
+ }
486
+ if (q == graph.at(des, j).id) {
487
+ dup = 1;
488
+ break;
489
+ }
490
+ tmp_pool.push_back(graph.at(des, j));
491
+ }
492
+ }
493
+
494
+ if (dup) {
495
+ continue;
496
+ }
497
+
498
+ tmp_pool.push_back(sn);
499
+ if (tmp_pool.size() > R) {
500
+ std::vector<Node> result;
501
+ int start = 0;
502
+ std::sort(tmp_pool.begin(), tmp_pool.end());
503
+ result.push_back(tmp_pool[start]);
504
+
505
+ while (result.size() < R && (++start) < tmp_pool.size()) {
506
+ auto& p = tmp_pool[start];
507
+ bool occlude = false;
508
+
509
+ for (int t = 0; t < result.size(); t++) {
510
+ if (p.id == result[t].id) {
511
+ occlude = true;
512
+ break;
513
+ }
514
+ float djk = dis.symmetric_dis(result[t].id, p.id);
515
+ if (djk < p.distance /* dik */) {
516
+ occlude = true;
517
+ break;
518
+ }
519
+ }
520
+
521
+ if (!occlude) {
522
+ result.push_back(p);
523
+ }
524
+ }
525
+
526
+ {
527
+ LockGuard guard(locks[des]);
528
+ for (int t = 0; t < result.size(); t++) {
529
+ graph.at(des, t) = result[t];
530
+ }
531
+ }
532
+
533
+ } else {
534
+ LockGuard guard(locks[des]);
535
+ for (int t = 0; t < R; t++) {
536
+ if (graph.at(des, t).id == EMPTY_ID) {
537
+ graph.at(des, t) = sn;
538
+ break;
539
+ }
540
+ }
541
+ }
542
+ }
543
+ }
544
+
545
+ int NSG::tree_grow(Index* storage, std::vector<int>& degrees) {
546
+ int root = enterpoint;
547
+ VisitedTable vt(ntotal);
548
+ VisitedTable vt2(ntotal);
549
+
550
+ int num_attached = 0;
551
+ int cnt = 0;
552
+ while (true) {
553
+ cnt = dfs(vt, root, cnt);
554
+ if (cnt >= ntotal) {
555
+ break;
556
+ }
557
+
558
+ root = attach_unlinked(storage, vt, vt2, degrees);
559
+ vt2.advance();
560
+ num_attached += 1;
561
+ }
562
+
563
+ return num_attached;
564
+ }
565
+
566
+ int NSG::dfs(VisitedTable& vt, int root, int cnt) const {
567
+ int node = root;
568
+ std::stack<int> stack;
569
+ stack.push(root);
570
+
571
+ if (!vt.get(root)) {
572
+ cnt++;
573
+ }
574
+ vt.set(root);
575
+
576
+ while (!stack.empty()) {
577
+ int next = EMPTY_ID;
578
+ for (int i = 0; i < R; i++) {
579
+ int id = final_graph->at(node, i);
580
+ if (id != EMPTY_ID && !vt.get(id)) {
581
+ next = id;
582
+ break;
583
+ }
584
+ }
585
+
586
+ if (next == EMPTY_ID) {
587
+ stack.pop();
588
+ if (stack.empty()) {
589
+ break;
590
+ }
591
+ node = stack.top();
592
+ continue;
593
+ }
594
+ node = next;
595
+ vt.set(node);
596
+ stack.push(node);
597
+ cnt++;
598
+ }
599
+
600
+ return cnt;
601
+ }
602
+
603
+ int NSG::attach_unlinked(
604
+ Index* storage,
605
+ VisitedTable& vt,
606
+ VisitedTable& vt2,
607
+ std::vector<int>& degrees) {
608
+ /* NOTE: This implementation is slightly different from the original paper.
609
+ *
610
+ * Instead of connecting the unlinked node to the nearest point in the
611
+ * spanning tree which will increase the maximum degree of the graph and
612
+ * also make the graph hard to maintain, this implementation links the
613
+ * unlinked node to the nearest node of which the degree is smaller than R.
614
+ * It will keep the degree of all nodes to be no more than `R`.
615
+ */
616
+
617
+ // find one unlinked node
618
+ int id = EMPTY_ID;
619
+ for (int i = 0; i < ntotal; i++) {
620
+ if (!vt.get(i)) {
621
+ id = i;
622
+ break;
623
+ }
624
+ }
625
+
626
+ if (id == EMPTY_ID) {
627
+ return EMPTY_ID; // No Unlinked Node
628
+ }
629
+
630
+ std::vector<Neighbor> tmp;
631
+ std::vector<Node> pool;
632
+
633
+ std::unique_ptr<DistanceComputer> dis(storage_distance_computer(storage));
634
+ std::unique_ptr<float[]> vec(new float[storage->d]);
635
+
636
+ storage->reconstruct(id, vec.get());
637
+ dis->set_query(vec.get());
638
+
639
+ // Collect the visited nodes into pool
640
+ search_on_graph<true>(
641
+ *final_graph, *dis, vt2, enterpoint, search_L, tmp, pool);
642
+
643
+ std::sort(pool.begin(), pool.end());
644
+
645
+ int node;
646
+ bool found = false;
647
+ for (int i = 0; i < pool.size(); i++) {
648
+ node = pool[i].id;
649
+ if (degrees[node] < R && node != id) {
650
+ found = true;
651
+ break;
652
+ }
653
+ }
654
+
655
+ // randomly choice annother node
656
+ if (!found) {
657
+ do {
658
+ node = rng.rand_int(ntotal);
659
+ if (vt.get(node) && degrees[node] < R && node != id) {
660
+ found = true;
661
+ }
662
+ } while (!found);
663
+ }
664
+
665
+ int pos = degrees[node];
666
+ final_graph->at(node, pos) = id; // replace
667
+ degrees[node] += 1;
668
+
669
+ return node;
670
+ }
671
+
672
+ void NSG::check_graph() const {
673
+ #pragma omp parallel for
674
+ for (int i = 0; i < ntotal; i++) {
675
+ for (int j = 0; j < R; j++) {
676
+ int id = final_graph->at(i, j);
677
+ FAISS_THROW_IF_NOT(id < ntotal && (id >= 0 || id == EMPTY_ID));
678
+ }
679
+ }
680
+ }
681
+
682
+ } // namespace faiss