faiss 0.2.0 → 0.2.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (215) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +16 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +7 -7
  5. data/ext/faiss/extconf.rb +6 -3
  6. data/ext/faiss/numo.hpp +4 -4
  7. data/ext/faiss/utils.cpp +1 -1
  8. data/ext/faiss/utils.h +1 -1
  9. data/lib/faiss/version.rb +1 -1
  10. data/vendor/faiss/faiss/AutoTune.cpp +292 -291
  11. data/vendor/faiss/faiss/AutoTune.h +55 -56
  12. data/vendor/faiss/faiss/Clustering.cpp +365 -194
  13. data/vendor/faiss/faiss/Clustering.h +102 -35
  14. data/vendor/faiss/faiss/IVFlib.cpp +171 -195
  15. data/vendor/faiss/faiss/IVFlib.h +48 -51
  16. data/vendor/faiss/faiss/Index.cpp +85 -103
  17. data/vendor/faiss/faiss/Index.h +54 -48
  18. data/vendor/faiss/faiss/Index2Layer.cpp +126 -224
  19. data/vendor/faiss/faiss/Index2Layer.h +22 -36
  20. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +407 -0
  21. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +195 -0
  22. data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
  23. data/vendor/faiss/faiss/IndexBinary.h +140 -132
  24. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
  25. data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
  26. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
  27. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
  28. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
  29. data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
  30. data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
  31. data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
  32. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
  33. data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
  34. data/vendor/faiss/faiss/IndexFlat.cpp +115 -176
  35. data/vendor/faiss/faiss/IndexFlat.h +42 -59
  36. data/vendor/faiss/faiss/IndexFlatCodes.cpp +67 -0
  37. data/vendor/faiss/faiss/IndexFlatCodes.h +47 -0
  38. data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
  39. data/vendor/faiss/faiss/IndexHNSW.h +57 -41
  40. data/vendor/faiss/faiss/IndexIVF.cpp +545 -453
  41. data/vendor/faiss/faiss/IndexIVF.h +169 -118
  42. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +316 -0
  43. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +121 -0
  44. data/vendor/faiss/faiss/IndexIVFFlat.cpp +247 -252
  45. data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
  46. data/vendor/faiss/faiss/IndexIVFPQ.cpp +459 -517
  47. data/vendor/faiss/faiss/IndexIVFPQ.h +75 -67
  48. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
  49. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
  50. data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
  51. data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
  52. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +163 -150
  53. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +38 -25
  54. data/vendor/faiss/faiss/IndexLSH.cpp +66 -113
  55. data/vendor/faiss/faiss/IndexLSH.h +20 -38
  56. data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
  57. data/vendor/faiss/faiss/IndexLattice.h +11 -16
  58. data/vendor/faiss/faiss/IndexNNDescent.cpp +229 -0
  59. data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
  60. data/vendor/faiss/faiss/IndexNSG.cpp +301 -0
  61. data/vendor/faiss/faiss/IndexNSG.h +85 -0
  62. data/vendor/faiss/faiss/IndexPQ.cpp +387 -495
  63. data/vendor/faiss/faiss/IndexPQ.h +64 -82
  64. data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
  65. data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
  66. data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
  67. data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
  68. data/vendor/faiss/faiss/IndexRefine.cpp +139 -127
  69. data/vendor/faiss/faiss/IndexRefine.h +32 -23
  70. data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
  71. data/vendor/faiss/faiss/IndexReplicas.h +62 -56
  72. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +111 -172
  73. data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -59
  74. data/vendor/faiss/faiss/IndexShards.cpp +256 -240
  75. data/vendor/faiss/faiss/IndexShards.h +85 -73
  76. data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
  77. data/vendor/faiss/faiss/MatrixStats.h +7 -10
  78. data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
  79. data/vendor/faiss/faiss/MetaIndexes.h +40 -34
  80. data/vendor/faiss/faiss/MetricType.h +7 -7
  81. data/vendor/faiss/faiss/VectorTransform.cpp +654 -475
  82. data/vendor/faiss/faiss/VectorTransform.h +64 -89
  83. data/vendor/faiss/faiss/clone_index.cpp +78 -73
  84. data/vendor/faiss/faiss/clone_index.h +4 -9
  85. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
  86. data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
  87. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +198 -171
  88. data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
  89. data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
  90. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
  91. data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
  92. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
  93. data/vendor/faiss/faiss/gpu/GpuIcmEncoder.h +60 -0
  94. data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
  95. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
  96. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
  97. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
  98. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
  99. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
  100. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
  101. data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
  102. data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
  103. data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
  104. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
  105. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
  106. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
  107. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
  108. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
  109. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
  110. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
  111. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
  112. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
  113. data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
  114. data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
  115. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
  116. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
  117. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
  118. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
  119. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
  120. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
  121. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
  122. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
  123. data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
  124. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
  125. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
  126. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
  127. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
  128. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
  129. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
  130. data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
  131. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +503 -0
  132. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +175 -0
  133. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
  134. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
  135. data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
  136. data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
  137. data/vendor/faiss/faiss/impl/FaissException.h +41 -29
  138. data/vendor/faiss/faiss/impl/HNSW.cpp +606 -617
  139. data/vendor/faiss/faiss/impl/HNSW.h +179 -200
  140. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +855 -0
  141. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +244 -0
  142. data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
  143. data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
  144. data/vendor/faiss/faiss/impl/NSG.cpp +679 -0
  145. data/vendor/faiss/faiss/impl/NSG.h +199 -0
  146. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
  147. data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
  148. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
  149. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
  150. data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
  151. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +758 -0
  152. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +188 -0
  153. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
  154. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +647 -707
  155. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
  156. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
  157. data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
  158. data/vendor/faiss/faiss/impl/index_read.cpp +631 -480
  159. data/vendor/faiss/faiss/impl/index_write.cpp +547 -407
  160. data/vendor/faiss/faiss/impl/io.cpp +76 -95
  161. data/vendor/faiss/faiss/impl/io.h +31 -41
  162. data/vendor/faiss/faiss/impl/io_macros.h +60 -29
  163. data/vendor/faiss/faiss/impl/kmeans1d.cpp +301 -0
  164. data/vendor/faiss/faiss/impl/kmeans1d.h +48 -0
  165. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
  166. data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
  167. data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
  168. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
  169. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
  170. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
  171. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
  172. data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
  173. data/vendor/faiss/faiss/index_factory.cpp +619 -397
  174. data/vendor/faiss/faiss/index_factory.h +8 -6
  175. data/vendor/faiss/faiss/index_io.h +23 -26
  176. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
  177. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
  178. data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
  179. data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
  180. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
  181. data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
  182. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
  183. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
  184. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
  185. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
  186. data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
  187. data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
  188. data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
  189. data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
  190. data/vendor/faiss/faiss/utils/Heap.h +186 -209
  191. data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
  192. data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
  193. data/vendor/faiss/faiss/utils/distances.cpp +305 -312
  194. data/vendor/faiss/faiss/utils/distances.h +170 -122
  195. data/vendor/faiss/faiss/utils/distances_simd.cpp +498 -508
  196. data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
  197. data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
  198. data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
  199. data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
  200. data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
  201. data/vendor/faiss/faiss/utils/hamming.h +62 -85
  202. data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
  203. data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
  204. data/vendor/faiss/faiss/utils/partitioning.h +26 -21
  205. data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
  206. data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
  207. data/vendor/faiss/faiss/utils/random.cpp +39 -63
  208. data/vendor/faiss/faiss/utils/random.h +13 -16
  209. data/vendor/faiss/faiss/utils/simdlib.h +4 -2
  210. data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
  211. data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
  212. data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
  213. data/vendor/faiss/faiss/utils/utils.cpp +304 -287
  214. data/vendor/faiss/faiss/utils/utils.h +54 -49
  215. metadata +29 -4
@@ -0,0 +1,679 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ // -*- c++ -*-
9
+
10
+ #include <faiss/impl/NSG.h>
11
+
12
+ #include <algorithm>
13
+ #include <memory>
14
+ #include <mutex>
15
+ #include <stack>
16
+
17
+ #include <faiss/impl/AuxIndexStructures.h>
18
+
19
+ namespace faiss {
20
+
21
+ namespace nsg {
22
+
23
+ namespace {
24
+
25
+ // It needs to be smaller than 0
26
+ constexpr int EMPTY_ID = -1;
27
+
28
+ /* Wrap the distance computer into one that negates the
29
+ distances. This makes supporting INNER_PRODUCE search easier */
30
+
31
+ struct NegativeDistanceComputer : DistanceComputer {
32
+ using idx_t = Index::idx_t;
33
+
34
+ /// owned by this
35
+ DistanceComputer* basedis;
36
+
37
+ explicit NegativeDistanceComputer(DistanceComputer* basedis)
38
+ : basedis(basedis) {}
39
+
40
+ void set_query(const float* x) override {
41
+ basedis->set_query(x);
42
+ }
43
+
44
+ /// compute distance of vector i to current query
45
+ float operator()(idx_t i) override {
46
+ return -(*basedis)(i);
47
+ }
48
+
49
+ /// compute distance between two stored vectors
50
+ float symmetric_dis(idx_t i, idx_t j) override {
51
+ return -basedis->symmetric_dis(i, j);
52
+ }
53
+
54
+ ~NegativeDistanceComputer() override {
55
+ delete basedis;
56
+ }
57
+ };
58
+
59
+ } // namespace
60
+
61
+ DistanceComputer* storage_distance_computer(const Index* storage) {
62
+ if (storage->metric_type == METRIC_INNER_PRODUCT) {
63
+ return new NegativeDistanceComputer(storage->get_distance_computer());
64
+ } else {
65
+ return storage->get_distance_computer();
66
+ }
67
+ }
68
+
69
+ } // namespace nsg
70
+
71
+ using namespace nsg;
72
+
73
+ using LockGuard = std::lock_guard<std::mutex>;
74
+
75
+ struct Neighbor {
76
+ int id;
77
+ float distance;
78
+ bool flag;
79
+
80
+ Neighbor() = default;
81
+ Neighbor(int id, float distance, bool f)
82
+ : id(id), distance(distance), flag(f) {}
83
+
84
+ inline bool operator<(const Neighbor& other) const {
85
+ return distance < other.distance;
86
+ }
87
+ };
88
+
89
+ struct Node {
90
+ int id;
91
+ float distance;
92
+
93
+ Node() = default;
94
+ Node(int id, float distance) : id(id), distance(distance) {}
95
+
96
+ inline bool operator<(const Node& other) const {
97
+ return distance < other.distance;
98
+ }
99
+ };
100
+
101
+ inline int insert_into_pool(Neighbor* addr, int K, Neighbor nn) {
102
+ // find the location to insert
103
+ int left = 0, right = K - 1;
104
+ if (addr[left].distance > nn.distance) {
105
+ memmove(&addr[left + 1], &addr[left], K * sizeof(Neighbor));
106
+ addr[left] = nn;
107
+ return left;
108
+ }
109
+ if (addr[right].distance < nn.distance) {
110
+ addr[K] = nn;
111
+ return K;
112
+ }
113
+ while (left < right - 1) {
114
+ int mid = (left + right) / 2;
115
+ if (addr[mid].distance > nn.distance) {
116
+ right = mid;
117
+ } else {
118
+ left = mid;
119
+ }
120
+ }
121
+ // check equal ID
122
+
123
+ while (left > 0) {
124
+ if (addr[left].distance < nn.distance) {
125
+ break;
126
+ }
127
+ if (addr[left].id == nn.id) {
128
+ return K + 1;
129
+ }
130
+ left--;
131
+ }
132
+ if (addr[left].id == nn.id || addr[right].id == nn.id) {
133
+ return K + 1;
134
+ }
135
+ memmove(&addr[right + 1], &addr[right], (K - right) * sizeof(Neighbor));
136
+ addr[right] = nn;
137
+ return right;
138
+ }
139
+
140
+ NSG::NSG(int R) : R(R), rng(0x0903) {
141
+ L = R + 32;
142
+ C = R + 100;
143
+ search_L = 16;
144
+ ntotal = 0;
145
+ is_built = false;
146
+ srand(0x1998);
147
+ }
148
+
149
+ void NSG::search(
150
+ DistanceComputer& dis,
151
+ int k,
152
+ idx_t* I,
153
+ float* D,
154
+ VisitedTable& vt) const {
155
+ FAISS_THROW_IF_NOT(is_built);
156
+ FAISS_THROW_IF_NOT(final_graph);
157
+
158
+ int pool_size = std::max(search_L, k);
159
+ std::vector<Neighbor> retset;
160
+ std::vector<Node> tmp;
161
+ search_on_graph<false>(
162
+ *final_graph, dis, vt, enterpoint, pool_size, retset, tmp);
163
+
164
+ for (size_t i = 0; i < k; i++) {
165
+ I[i] = retset[i].id;
166
+ D[i] = retset[i].distance;
167
+ }
168
+ }
169
+
170
+ void NSG::build(
171
+ Index* storage,
172
+ idx_t n,
173
+ const nsg::Graph<idx_t>& knn_graph,
174
+ bool verbose) {
175
+ FAISS_THROW_IF_NOT(!is_built && ntotal == 0);
176
+
177
+ if (verbose) {
178
+ printf("NSG::build R=%d, L=%d, C=%d\n", R, L, C);
179
+ }
180
+
181
+ ntotal = n;
182
+ init_graph(storage, knn_graph);
183
+
184
+ std::vector<int> degrees(n, 0);
185
+ {
186
+ nsg::Graph<Node> tmp_graph(n, R);
187
+
188
+ link(storage, knn_graph, tmp_graph, verbose);
189
+
190
+ final_graph = std::make_shared<nsg::Graph<int>>(n, R);
191
+ std::fill_n(final_graph->data, n * R, EMPTY_ID);
192
+
193
+ #pragma omp parallel for
194
+ for (int i = 0; i < n; i++) {
195
+ int cnt = 0;
196
+ for (int j = 0; j < R; j++) {
197
+ int id = tmp_graph.at(i, j).id;
198
+ if (id != EMPTY_ID) {
199
+ final_graph->at(i, cnt) = id;
200
+ cnt += 1;
201
+ }
202
+ degrees[i] = cnt;
203
+ }
204
+ }
205
+ }
206
+
207
+ int num_attached = tree_grow(storage, degrees);
208
+ check_graph();
209
+ is_built = true;
210
+
211
+ if (verbose) {
212
+ int max = 0, min = 1e6;
213
+ double avg = 0;
214
+
215
+ for (int i = 0; i < n; i++) {
216
+ int size = 0;
217
+ while (size < R && final_graph->at(i, size) != EMPTY_ID) {
218
+ size += 1;
219
+ }
220
+ max = std::max(size, max);
221
+ min = std::min(size, min);
222
+ avg += size;
223
+ }
224
+
225
+ avg = avg / n;
226
+ printf("Degree Statistics: Max = %d, Min = %d, Avg = %lf\n",
227
+ max,
228
+ min,
229
+ avg);
230
+ printf("Attached nodes: %d\n", num_attached);
231
+ }
232
+ }
233
+
234
+ void NSG::reset() {
235
+ final_graph.reset();
236
+ ntotal = 0;
237
+ is_built = false;
238
+ }
239
+
240
+ void NSG::init_graph(Index* storage, const nsg::Graph<idx_t>& knn_graph) {
241
+ int d = storage->d;
242
+ int n = storage->ntotal;
243
+
244
+ std::unique_ptr<float[]> center(new float[d]);
245
+ std::unique_ptr<float[]> tmp(new float[d]);
246
+ std::fill_n(center.get(), d, 0.0f);
247
+
248
+ for (int i = 0; i < n; i++) {
249
+ storage->reconstruct(i, tmp.get());
250
+ for (int j = 0; j < d; j++) {
251
+ center[j] += tmp[j];
252
+ }
253
+ }
254
+
255
+ for (int i = 0; i < d; i++) {
256
+ center[i] /= n;
257
+ }
258
+
259
+ std::vector<Neighbor> retset;
260
+ std::vector<Node> tmpset;
261
+
262
+ // random initialize navigating point
263
+ int ep = rng.rand_int(n);
264
+ std::unique_ptr<DistanceComputer> dis(storage_distance_computer(storage));
265
+
266
+ dis->set_query(center.get());
267
+ VisitedTable vt(ntotal);
268
+
269
+ // Do not collect the visited nodes
270
+ search_on_graph<false>(knn_graph, *dis, vt, ep, L, retset, tmpset);
271
+
272
+ // set enterpoint
273
+ enterpoint = retset[0].id;
274
+ }
275
+
276
+ template <bool collect_fullset, class index_t>
277
+ void NSG::search_on_graph(
278
+ const nsg::Graph<index_t>& graph,
279
+ DistanceComputer& dis,
280
+ VisitedTable& vt,
281
+ int ep,
282
+ int pool_size,
283
+ std::vector<Neighbor>& retset,
284
+ std::vector<Node>& fullset) const {
285
+ RandomGenerator gen(0x1234);
286
+ retset.resize(pool_size + 1);
287
+ std::vector<int> init_ids(pool_size);
288
+
289
+ int num_ids = 0;
290
+ for (int i = 0; i < init_ids.size() && i < graph.K; i++) {
291
+ int id = (int)graph.at(ep, i);
292
+ if (id < 0 || id >= ntotal) {
293
+ continue;
294
+ }
295
+
296
+ init_ids[i] = id;
297
+ vt.set(id);
298
+ num_ids += 1;
299
+ }
300
+
301
+ while (num_ids < pool_size) {
302
+ int id = gen.rand_int(ntotal);
303
+ if (vt.get(id)) {
304
+ continue;
305
+ }
306
+
307
+ init_ids[num_ids] = id;
308
+ num_ids++;
309
+ vt.set(id);
310
+ }
311
+
312
+ for (int i = 0; i < init_ids.size(); i++) {
313
+ int id = init_ids[i];
314
+
315
+ float dist = dis(id);
316
+ retset[i] = Neighbor(id, dist, true);
317
+
318
+ if (collect_fullset) {
319
+ fullset.emplace_back(retset[i].id, retset[i].distance);
320
+ }
321
+ }
322
+
323
+ std::sort(retset.begin(), retset.begin() + pool_size);
324
+
325
+ int k = 0;
326
+ while (k < pool_size) {
327
+ int updated_pos = pool_size;
328
+
329
+ if (retset[k].flag) {
330
+ retset[k].flag = false;
331
+ int n = retset[k].id;
332
+
333
+ for (int m = 0; m < graph.K; m++) {
334
+ int id = (int)graph.at(n, m);
335
+ if (id < 0 || id > ntotal || vt.get(id)) {
336
+ continue;
337
+ }
338
+ vt.set(id);
339
+
340
+ float dist = dis(id);
341
+ Neighbor nn(id, dist, true);
342
+ if (collect_fullset) {
343
+ fullset.emplace_back(id, dist);
344
+ }
345
+
346
+ if (dist >= retset[pool_size - 1].distance) {
347
+ continue;
348
+ }
349
+
350
+ int r = insert_into_pool(retset.data(), pool_size, nn);
351
+
352
+ updated_pos = std::min(updated_pos, r);
353
+ }
354
+ }
355
+
356
+ k = (updated_pos <= k) ? updated_pos : (k + 1);
357
+ }
358
+ }
359
+
360
+ void NSG::link(
361
+ Index* storage,
362
+ const nsg::Graph<idx_t>& knn_graph,
363
+ nsg::Graph<Node>& graph,
364
+ bool /* verbose */) {
365
+ #pragma omp parallel
366
+ {
367
+ std::unique_ptr<float[]> vec(new float[storage->d]);
368
+
369
+ std::vector<Node> pool;
370
+ std::vector<Neighbor> tmp;
371
+
372
+ VisitedTable vt(ntotal);
373
+ std::unique_ptr<DistanceComputer> dis(
374
+ storage_distance_computer(storage));
375
+
376
+ #pragma omp for schedule(dynamic, 100)
377
+ for (int i = 0; i < ntotal; i++) {
378
+ storage->reconstruct(i, vec.get());
379
+ dis->set_query(vec.get());
380
+
381
+ // Collect the visited nodes into pool
382
+ search_on_graph<true>(
383
+ knn_graph, *dis, vt, enterpoint, L, tmp, pool);
384
+
385
+ sync_prune(i, pool, *dis, vt, knn_graph, graph);
386
+
387
+ pool.clear();
388
+ tmp.clear();
389
+ vt.advance();
390
+ }
391
+ } // omp parallel
392
+
393
+ std::vector<std::mutex> locks(ntotal);
394
+ #pragma omp parallel
395
+ {
396
+ std::unique_ptr<DistanceComputer> dis(
397
+ storage_distance_computer(storage));
398
+
399
+ #pragma omp for schedule(dynamic, 100)
400
+ for (int i = 0; i < ntotal; ++i) {
401
+ add_reverse_links(i, locks, *dis, graph);
402
+ }
403
+ } // omp parallel
404
+ }
405
+
406
+ void NSG::sync_prune(
407
+ int q,
408
+ std::vector<Node>& pool,
409
+ DistanceComputer& dis,
410
+ VisitedTable& vt,
411
+ const nsg::Graph<idx_t>& knn_graph,
412
+ nsg::Graph<Node>& graph) {
413
+ for (int i = 0; i < knn_graph.K; i++) {
414
+ int id = knn_graph.at(q, i);
415
+ if (id < 0 || id >= ntotal || vt.get(id)) {
416
+ continue;
417
+ }
418
+
419
+ float dist = dis.symmetric_dis(q, id);
420
+ pool.emplace_back(id, dist);
421
+ }
422
+
423
+ std::sort(pool.begin(), pool.end());
424
+
425
+ std::vector<Node> result;
426
+
427
+ int start = 0;
428
+ if (pool[start].id == q) {
429
+ start++;
430
+ }
431
+ result.push_back(pool[start]);
432
+
433
+ while (result.size() < R && (++start) < pool.size() && start < C) {
434
+ auto& p = pool[start];
435
+ bool occlude = false;
436
+ for (int t = 0; t < result.size(); t++) {
437
+ if (p.id == result[t].id) {
438
+ occlude = true;
439
+ break;
440
+ }
441
+ float djk = dis.symmetric_dis(result[t].id, p.id);
442
+ if (djk < p.distance /* dik */) {
443
+ occlude = true;
444
+ break;
445
+ }
446
+ }
447
+ if (!occlude) {
448
+ result.push_back(p);
449
+ }
450
+ }
451
+
452
+ for (size_t i = 0; i < R; i++) {
453
+ if (i < result.size()) {
454
+ graph.at(q, i).id = result[i].id;
455
+ graph.at(q, i).distance = result[i].distance;
456
+ } else {
457
+ graph.at(q, i).id = EMPTY_ID;
458
+ }
459
+ }
460
+ }
461
+
462
+ void NSG::add_reverse_links(
463
+ int q,
464
+ std::vector<std::mutex>& locks,
465
+ DistanceComputer& dis,
466
+ nsg::Graph<Node>& graph) {
467
+ for (size_t i = 0; i < R; i++) {
468
+ if (graph.at(q, i).id == EMPTY_ID) {
469
+ break;
470
+ }
471
+
472
+ Node sn(q, graph.at(q, i).distance);
473
+ int des = graph.at(q, i).id;
474
+
475
+ std::vector<Node> tmp_pool;
476
+ int dup = 0;
477
+ {
478
+ LockGuard guard(locks[des]);
479
+ for (int j = 0; j < R; j++) {
480
+ if (graph.at(des, j).id == EMPTY_ID) {
481
+ break;
482
+ }
483
+ if (q == graph.at(des, j).id) {
484
+ dup = 1;
485
+ break;
486
+ }
487
+ tmp_pool.push_back(graph.at(des, j));
488
+ }
489
+ }
490
+
491
+ if (dup) {
492
+ continue;
493
+ }
494
+
495
+ tmp_pool.push_back(sn);
496
+ if (tmp_pool.size() > R) {
497
+ std::vector<Node> result;
498
+ int start = 0;
499
+ std::sort(tmp_pool.begin(), tmp_pool.end());
500
+ result.push_back(tmp_pool[start]);
501
+
502
+ while (result.size() < R && (++start) < tmp_pool.size()) {
503
+ auto& p = tmp_pool[start];
504
+ bool occlude = false;
505
+
506
+ for (int t = 0; t < result.size(); t++) {
507
+ if (p.id == result[t].id) {
508
+ occlude = true;
509
+ break;
510
+ }
511
+ float djk = dis.symmetric_dis(result[t].id, p.id);
512
+ if (djk < p.distance /* dik */) {
513
+ occlude = true;
514
+ break;
515
+ }
516
+ }
517
+
518
+ if (!occlude) {
519
+ result.push_back(p);
520
+ }
521
+ }
522
+
523
+ {
524
+ LockGuard guard(locks[des]);
525
+ for (int t = 0; t < result.size(); t++) {
526
+ graph.at(des, t) = result[t];
527
+ }
528
+ }
529
+
530
+ } else {
531
+ LockGuard guard(locks[des]);
532
+ for (int t = 0; t < R; t++) {
533
+ if (graph.at(des, t).id == EMPTY_ID) {
534
+ graph.at(des, t) = sn;
535
+ break;
536
+ }
537
+ }
538
+ }
539
+ }
540
+ }
541
+
542
+ int NSG::tree_grow(Index* storage, std::vector<int>& degrees) {
543
+ int root = enterpoint;
544
+ VisitedTable vt(ntotal);
545
+ VisitedTable vt2(ntotal);
546
+
547
+ int num_attached = 0;
548
+ int cnt = 0;
549
+ while (true) {
550
+ cnt = dfs(vt, root, cnt);
551
+ if (cnt >= ntotal) {
552
+ break;
553
+ }
554
+
555
+ root = attach_unlinked(storage, vt, vt2, degrees);
556
+ vt2.advance();
557
+ num_attached += 1;
558
+ }
559
+
560
+ return num_attached;
561
+ }
562
+
563
+ int NSG::dfs(VisitedTable& vt, int root, int cnt) const {
564
+ int node = root;
565
+ std::stack<int> stack;
566
+ stack.push(root);
567
+
568
+ if (!vt.get(root)) {
569
+ cnt++;
570
+ }
571
+ vt.set(root);
572
+
573
+ while (!stack.empty()) {
574
+ int next = EMPTY_ID;
575
+ for (int i = 0; i < R; i++) {
576
+ int id = final_graph->at(node, i);
577
+ if (id != EMPTY_ID && !vt.get(id)) {
578
+ next = id;
579
+ break;
580
+ }
581
+ }
582
+
583
+ if (next == EMPTY_ID) {
584
+ stack.pop();
585
+ if (stack.empty()) {
586
+ break;
587
+ }
588
+ node = stack.top();
589
+ continue;
590
+ }
591
+ node = next;
592
+ vt.set(node);
593
+ stack.push(node);
594
+ cnt++;
595
+ }
596
+
597
+ return cnt;
598
+ }
599
+
600
+ int NSG::attach_unlinked(
601
+ Index* storage,
602
+ VisitedTable& vt,
603
+ VisitedTable& vt2,
604
+ std::vector<int>& degrees) {
605
+ /* NOTE: This implementation is slightly different from the original paper.
606
+ *
607
+ * Instead of connecting the unlinked node to the nearest point in the
608
+ * spanning tree which will increase the maximum degree of the graph and
609
+ * also make the graph hard to maintain, this implementation links the
610
+ * unlinked node to the nearest node of which the degree is smaller than R.
611
+ * It will keep the degree of all nodes to be no more than `R`.
612
+ */
613
+
614
+ // find one unlinked node
615
+ int id = EMPTY_ID;
616
+ for (int i = 0; i < ntotal; i++) {
617
+ if (!vt.get(i)) {
618
+ id = i;
619
+ break;
620
+ }
621
+ }
622
+
623
+ if (id == EMPTY_ID) {
624
+ return EMPTY_ID; // No Unlinked Node
625
+ }
626
+
627
+ std::vector<Neighbor> tmp;
628
+ std::vector<Node> pool;
629
+
630
+ std::unique_ptr<DistanceComputer> dis(storage_distance_computer(storage));
631
+ std::unique_ptr<float[]> vec(new float[storage->d]);
632
+
633
+ storage->reconstruct(id, vec.get());
634
+ dis->set_query(vec.get());
635
+
636
+ // Collect the visited nodes into pool
637
+ search_on_graph<true>(
638
+ *final_graph, *dis, vt2, enterpoint, search_L, tmp, pool);
639
+
640
+ std::sort(pool.begin(), pool.end());
641
+
642
+ int node;
643
+ bool found = false;
644
+ for (int i = 0; i < pool.size(); i++) {
645
+ node = pool[i].id;
646
+ if (degrees[node] < R && node != id) {
647
+ found = true;
648
+ break;
649
+ }
650
+ }
651
+
652
+ // randomly choice annother node
653
+ if (!found) {
654
+ do {
655
+ node = rng.rand_int(ntotal);
656
+ if (vt.get(node) && degrees[node] < R && node != id) {
657
+ found = true;
658
+ }
659
+ } while (!found);
660
+ }
661
+
662
+ int pos = degrees[node];
663
+ final_graph->at(node, pos) = id; // replace
664
+ degrees[node] += 1;
665
+
666
+ return node;
667
+ }
668
+
669
+ void NSG::check_graph() const {
670
+ #pragma omp parallel for
671
+ for (int i = 0; i < ntotal; i++) {
672
+ for (int j = 0; j < R; j++) {
673
+ int id = final_graph->at(i, j);
674
+ FAISS_THROW_IF_NOT(id < ntotal && (id >= 0 || id == EMPTY_ID));
675
+ }
676
+ }
677
+ }
678
+
679
+ } // namespace faiss