faiss 0.2.0 → 0.2.4

Sign up to get free protection for your applications and to get access to all the features.
Files changed (215) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +16 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +7 -7
  5. data/ext/faiss/extconf.rb +6 -3
  6. data/ext/faiss/numo.hpp +4 -4
  7. data/ext/faiss/utils.cpp +1 -1
  8. data/ext/faiss/utils.h +1 -1
  9. data/lib/faiss/version.rb +1 -1
  10. data/vendor/faiss/faiss/AutoTune.cpp +292 -291
  11. data/vendor/faiss/faiss/AutoTune.h +55 -56
  12. data/vendor/faiss/faiss/Clustering.cpp +365 -194
  13. data/vendor/faiss/faiss/Clustering.h +102 -35
  14. data/vendor/faiss/faiss/IVFlib.cpp +171 -195
  15. data/vendor/faiss/faiss/IVFlib.h +48 -51
  16. data/vendor/faiss/faiss/Index.cpp +85 -103
  17. data/vendor/faiss/faiss/Index.h +54 -48
  18. data/vendor/faiss/faiss/Index2Layer.cpp +126 -224
  19. data/vendor/faiss/faiss/Index2Layer.h +22 -36
  20. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +407 -0
  21. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +195 -0
  22. data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
  23. data/vendor/faiss/faiss/IndexBinary.h +140 -132
  24. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
  25. data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
  26. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
  27. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
  28. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
  29. data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
  30. data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
  31. data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
  32. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
  33. data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
  34. data/vendor/faiss/faiss/IndexFlat.cpp +115 -176
  35. data/vendor/faiss/faiss/IndexFlat.h +42 -59
  36. data/vendor/faiss/faiss/IndexFlatCodes.cpp +67 -0
  37. data/vendor/faiss/faiss/IndexFlatCodes.h +47 -0
  38. data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
  39. data/vendor/faiss/faiss/IndexHNSW.h +57 -41
  40. data/vendor/faiss/faiss/IndexIVF.cpp +545 -453
  41. data/vendor/faiss/faiss/IndexIVF.h +169 -118
  42. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +316 -0
  43. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +121 -0
  44. data/vendor/faiss/faiss/IndexIVFFlat.cpp +247 -252
  45. data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
  46. data/vendor/faiss/faiss/IndexIVFPQ.cpp +459 -517
  47. data/vendor/faiss/faiss/IndexIVFPQ.h +75 -67
  48. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
  49. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
  50. data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
  51. data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
  52. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +163 -150
  53. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +38 -25
  54. data/vendor/faiss/faiss/IndexLSH.cpp +66 -113
  55. data/vendor/faiss/faiss/IndexLSH.h +20 -38
  56. data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
  57. data/vendor/faiss/faiss/IndexLattice.h +11 -16
  58. data/vendor/faiss/faiss/IndexNNDescent.cpp +229 -0
  59. data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
  60. data/vendor/faiss/faiss/IndexNSG.cpp +301 -0
  61. data/vendor/faiss/faiss/IndexNSG.h +85 -0
  62. data/vendor/faiss/faiss/IndexPQ.cpp +387 -495
  63. data/vendor/faiss/faiss/IndexPQ.h +64 -82
  64. data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
  65. data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
  66. data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
  67. data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
  68. data/vendor/faiss/faiss/IndexRefine.cpp +139 -127
  69. data/vendor/faiss/faiss/IndexRefine.h +32 -23
  70. data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
  71. data/vendor/faiss/faiss/IndexReplicas.h +62 -56
  72. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +111 -172
  73. data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -59
  74. data/vendor/faiss/faiss/IndexShards.cpp +256 -240
  75. data/vendor/faiss/faiss/IndexShards.h +85 -73
  76. data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
  77. data/vendor/faiss/faiss/MatrixStats.h +7 -10
  78. data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
  79. data/vendor/faiss/faiss/MetaIndexes.h +40 -34
  80. data/vendor/faiss/faiss/MetricType.h +7 -7
  81. data/vendor/faiss/faiss/VectorTransform.cpp +654 -475
  82. data/vendor/faiss/faiss/VectorTransform.h +64 -89
  83. data/vendor/faiss/faiss/clone_index.cpp +78 -73
  84. data/vendor/faiss/faiss/clone_index.h +4 -9
  85. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
  86. data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
  87. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +198 -171
  88. data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
  89. data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
  90. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
  91. data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
  92. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
  93. data/vendor/faiss/faiss/gpu/GpuIcmEncoder.h +60 -0
  94. data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
  95. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
  96. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
  97. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
  98. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
  99. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
  100. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
  101. data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
  102. data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
  103. data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
  104. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
  105. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
  106. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
  107. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
  108. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
  109. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
  110. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
  111. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
  112. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
  113. data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
  114. data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
  115. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
  116. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
  117. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
  118. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
  119. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
  120. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
  121. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
  122. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
  123. data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
  124. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
  125. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
  126. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
  127. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
  128. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
  129. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
  130. data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
  131. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +503 -0
  132. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +175 -0
  133. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
  134. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
  135. data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
  136. data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
  137. data/vendor/faiss/faiss/impl/FaissException.h +41 -29
  138. data/vendor/faiss/faiss/impl/HNSW.cpp +606 -617
  139. data/vendor/faiss/faiss/impl/HNSW.h +179 -200
  140. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +855 -0
  141. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +244 -0
  142. data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
  143. data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
  144. data/vendor/faiss/faiss/impl/NSG.cpp +679 -0
  145. data/vendor/faiss/faiss/impl/NSG.h +199 -0
  146. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
  147. data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
  148. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
  149. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
  150. data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
  151. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +758 -0
  152. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +188 -0
  153. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
  154. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +647 -707
  155. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
  156. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
  157. data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
  158. data/vendor/faiss/faiss/impl/index_read.cpp +631 -480
  159. data/vendor/faiss/faiss/impl/index_write.cpp +547 -407
  160. data/vendor/faiss/faiss/impl/io.cpp +76 -95
  161. data/vendor/faiss/faiss/impl/io.h +31 -41
  162. data/vendor/faiss/faiss/impl/io_macros.h +60 -29
  163. data/vendor/faiss/faiss/impl/kmeans1d.cpp +301 -0
  164. data/vendor/faiss/faiss/impl/kmeans1d.h +48 -0
  165. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
  166. data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
  167. data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
  168. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
  169. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
  170. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
  171. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
  172. data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
  173. data/vendor/faiss/faiss/index_factory.cpp +619 -397
  174. data/vendor/faiss/faiss/index_factory.h +8 -6
  175. data/vendor/faiss/faiss/index_io.h +23 -26
  176. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
  177. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
  178. data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
  179. data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
  180. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
  181. data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
  182. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
  183. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
  184. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
  185. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
  186. data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
  187. data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
  188. data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
  189. data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
  190. data/vendor/faiss/faiss/utils/Heap.h +186 -209
  191. data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
  192. data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
  193. data/vendor/faiss/faiss/utils/distances.cpp +305 -312
  194. data/vendor/faiss/faiss/utils/distances.h +170 -122
  195. data/vendor/faiss/faiss/utils/distances_simd.cpp +498 -508
  196. data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
  197. data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
  198. data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
  199. data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
  200. data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
  201. data/vendor/faiss/faiss/utils/hamming.h +62 -85
  202. data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
  203. data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
  204. data/vendor/faiss/faiss/utils/partitioning.h +26 -21
  205. data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
  206. data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
  207. data/vendor/faiss/faiss/utils/random.cpp +39 -63
  208. data/vendor/faiss/faiss/utils/random.h +13 -16
  209. data/vendor/faiss/faiss/utils/simdlib.h +4 -2
  210. data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
  211. data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
  212. data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
  213. data/vendor/faiss/faiss/utils/utils.cpp +304 -287
  214. data/vendor/faiss/faiss/utils/utils.h +54 -49
  215. metadata +29 -4
@@ -0,0 +1,487 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ // -*- c++ -*-
9
+
10
+ #include <faiss/impl/NNDescent.h>
11
+
12
+ #include <mutex>
13
+ #include <string>
14
+
15
+ #include <faiss/impl/AuxIndexStructures.h>
16
+
17
+ namespace faiss {
18
+
19
+ using LockGuard = std::lock_guard<std::mutex>;
20
+
21
+ namespace nndescent {
22
+
23
+ void gen_random(std::mt19937& rng, int* addr, const int size, const int N);
24
+
25
+ Nhood::Nhood(int l, int s, std::mt19937& rng, int N) {
26
+ M = s;
27
+ nn_new.resize(s * 2);
28
+ gen_random(rng, nn_new.data(), (int)nn_new.size(), N);
29
+ }
30
+
31
+ /// Copy operator
32
+ Nhood& Nhood::operator=(const Nhood& other) {
33
+ M = other.M;
34
+ std::copy(
35
+ other.nn_new.begin(),
36
+ other.nn_new.end(),
37
+ std::back_inserter(nn_new));
38
+ nn_new.reserve(other.nn_new.capacity());
39
+ pool.reserve(other.pool.capacity());
40
+ return *this;
41
+ }
42
+
43
+ /// Copy constructor
44
+ Nhood::Nhood(const Nhood& other) {
45
+ M = other.M;
46
+ std::copy(
47
+ other.nn_new.begin(),
48
+ other.nn_new.end(),
49
+ std::back_inserter(nn_new));
50
+ nn_new.reserve(other.nn_new.capacity());
51
+ pool.reserve(other.pool.capacity());
52
+ }
53
+
54
+ /// Insert a point into the candidate pool
55
+ void Nhood::insert(int id, float dist) {
56
+ LockGuard guard(lock);
57
+ if (dist > pool.front().distance)
58
+ return;
59
+ for (int i = 0; i < pool.size(); i++) {
60
+ if (id == pool[i].id)
61
+ return;
62
+ }
63
+ if (pool.size() < pool.capacity()) {
64
+ pool.push_back(Neighbor(id, dist, true));
65
+ std::push_heap(pool.begin(), pool.end());
66
+ } else {
67
+ std::pop_heap(pool.begin(), pool.end());
68
+ pool[pool.size() - 1] = Neighbor(id, dist, true);
69
+ std::push_heap(pool.begin(), pool.end());
70
+ }
71
+ }
72
+
73
+ /// In local join, two objects are compared only if at least
74
+ /// one of them is new.
75
+ template <typename C>
76
+ void Nhood::join(C callback) const {
77
+ for (int const i : nn_new) {
78
+ for (int const j : nn_new) {
79
+ if (i < j) {
80
+ callback(i, j);
81
+ }
82
+ }
83
+ for (int j : nn_old) {
84
+ callback(i, j);
85
+ }
86
+ }
87
+ }
88
+
89
/// Fill `addr[0 .. size)` with pseudo-random ids in [0, N).
///
/// When N > size the ids are pairwise distinct: `size` values are drawn
/// below N - size, sorted, bumped so they are strictly increasing, and then
/// rotated by a random offset modulo N.
///
/// When 0 < N <= size distinct ids cannot exist (and the draw above would
/// take a modulo by zero or by a wrapped negative value), so the output is
/// the sequence off, off + 1, ... taken modulo N for one random offset.
/// When N <= 0 there is no valid id at all and the output is zero-filled.
///
/// @param rng   generator to draw from
/// @param addr  output buffer with room for `size` ints
/// @param size  number of ids to produce
/// @param N     exclusive upper bound of the ids
void gen_random(std::mt19937& rng, int* addr, const int size, const int N) {
    if (N <= 0) {
        for (int i = 0; i < size; ++i) {
            addr[i] = 0;
        }
        return;
    }

    if (N <= size) {
        const int off = rng() % N;
        for (int i = 0; i < size; ++i) {
            addr[i] = (off + i) % N;
        }
        return;
    }

    for (int i = 0; i < size; ++i) {
        addr[i] = rng() % (N - size);
    }
    std::sort(addr, addr + size);
    // Make the sorted sequence strictly increasing; the maximum stays < N.
    for (int i = 1; i < size; ++i) {
        if (addr[i] <= addr[i - 1]) {
            addr[i] = addr[i - 1] + 1;
        }
    }
    const int off = rng() % N;
    for (int i = 0; i < size; ++i) {
        addr[i] = (addr[i] + off) % N;
    }
}
104
+
105
+ // Insert a new point into the candidate pool in ascending order
106
+ int insert_into_pool(Neighbor* addr, int size, Neighbor nn) {
107
+ // find the location to insert
108
+ int left = 0, right = size - 1;
109
+ if (addr[left].distance > nn.distance) {
110
+ memmove((char*)&addr[left + 1], &addr[left], size * sizeof(Neighbor));
111
+ addr[left] = nn;
112
+ return left;
113
+ }
114
+ if (addr[right].distance < nn.distance) {
115
+ addr[size] = nn;
116
+ return size;
117
+ }
118
+ while (left < right - 1) {
119
+ int mid = (left + right) / 2;
120
+ if (addr[mid].distance > nn.distance)
121
+ right = mid;
122
+ else
123
+ left = mid;
124
+ }
125
+ // check equal ID
126
+
127
+ while (left > 0) {
128
+ if (addr[left].distance < nn.distance)
129
+ break;
130
+ if (addr[left].id == nn.id)
131
+ return size + 1;
132
+ left--;
133
+ }
134
+ if (addr[left].id == nn.id || addr[right].id == nn.id)
135
+ return size + 1;
136
+ memmove((char*)&addr[right + 1],
137
+ &addr[right],
138
+ (size - right) * sizeof(Neighbor));
139
+ addr[right] = nn;
140
+ return right;
141
+ }
142
+
143
+ } // namespace nndescent
144
+
145
+ using namespace nndescent;
146
+
147
+ constexpr int NUM_EVAL_POINTS = 100;
148
+
149
+ NNDescent::NNDescent(const int d, const int K) : K(K), random_seed(2021), d(d) {
150
+ ntotal = 0;
151
+ has_built = false;
152
+ S = 10;
153
+ R = 100;
154
+ L = K + 50;
155
+ iter = 10;
156
+ search_L = 0;
157
+ }
158
+
159
+ NNDescent::~NNDescent() {}
160
+
161
+ void NNDescent::join(DistanceComputer& qdis) {
162
+ #pragma omp parallel for default(shared) schedule(dynamic, 100)
163
+ for (int n = 0; n < ntotal; n++) {
164
+ graph[n].join([&](int i, int j) {
165
+ if (i != j) {
166
+ float dist = qdis.symmetric_dis(i, j);
167
+ graph[i].insert(j, dist);
168
+ graph[j].insert(i, dist);
169
+ }
170
+ });
171
+ }
172
+ }
173
+
174
+ /// Sample neighbors for each node to peform local join later
175
+ /// Store them in nn_new and nn_old
176
+ void NNDescent::update() {
177
+ // Step 1.
178
+ // Clear all nn_new and nn_old
179
+ #pragma omp parallel for
180
+ for (int i = 0; i < ntotal; i++) {
181
+ std::vector<int>().swap(graph[i].nn_new);
182
+ std::vector<int>().swap(graph[i].nn_old);
183
+ }
184
+
185
+ // Step 2.
186
+ // Compute the number of neighbors which is new i.e. flag is true
187
+ // in the candidate pool. This must not exceed the sample number S.
188
+ // That means We only select S new neighbors.
189
+ #pragma omp parallel for
190
+ for (int n = 0; n < ntotal; ++n) {
191
+ auto& nn = graph[n];
192
+ std::sort(nn.pool.begin(), nn.pool.end());
193
+
194
+ if (nn.pool.size() > L)
195
+ nn.pool.resize(L);
196
+ nn.pool.reserve(L); // keep the pool size be L
197
+
198
+ int maxl = std::min(nn.M + S, (int)nn.pool.size());
199
+ int c = 0;
200
+ int l = 0;
201
+
202
+ while ((l < maxl) && (c < S)) {
203
+ if (nn.pool[l].flag)
204
+ ++c;
205
+ ++l;
206
+ }
207
+ nn.M = l;
208
+ }
209
+
210
+ // Step 3.
211
+ // Find reverse links for each node
212
+ // Randomly choose R reverse links.
213
+ #pragma omp parallel
214
+ {
215
+ std::mt19937 rng(random_seed * 5081 + omp_get_thread_num());
216
+ #pragma omp for
217
+ for (int n = 0; n < ntotal; ++n) {
218
+ auto& node = graph[n];
219
+ auto& nn_new = node.nn_new;
220
+ auto& nn_old = node.nn_old;
221
+
222
+ for (int l = 0; l < node.M; ++l) {
223
+ auto& nn = node.pool[l];
224
+ auto& other = graph[nn.id]; // the other side of the edge
225
+
226
+ if (nn.flag) { // the node is inserted newly
227
+ // push the neighbor into nn_new
228
+ nn_new.push_back(nn.id);
229
+ // push itself into other.rnn_new if it is not in
230
+ // the candidate pool of the other side
231
+ if (nn.distance > other.pool.back().distance) {
232
+ LockGuard guard(other.lock);
233
+ if (other.rnn_new.size() < R) {
234
+ other.rnn_new.push_back(n);
235
+ } else {
236
+ int pos = rng() % R;
237
+ other.rnn_new[pos] = n;
238
+ }
239
+ }
240
+ nn.flag = false;
241
+
242
+ } else { // the node is old
243
+ // push the neighbor into nn_old
244
+ nn_old.push_back(nn.id);
245
+ // push itself into other.rnn_old if it is not in
246
+ // the candidate pool of the other side
247
+ if (nn.distance > other.pool.back().distance) {
248
+ LockGuard guard(other.lock);
249
+ if (other.rnn_old.size() < R) {
250
+ other.rnn_old.push_back(n);
251
+ } else {
252
+ int pos = rng() % R;
253
+ other.rnn_old[pos] = n;
254
+ }
255
+ }
256
+ }
257
+ }
258
+ // make heap to join later (in join() function)
259
+ std::make_heap(node.pool.begin(), node.pool.end());
260
+ }
261
+ }
262
+
263
+ // Step 4.
264
+ // Combine the forward and the reverse links
265
+ // R = 0 means no reverse links are used.
266
+ #pragma omp parallel for
267
+ for (int i = 0; i < ntotal; ++i) {
268
+ auto& nn_new = graph[i].nn_new;
269
+ auto& nn_old = graph[i].nn_old;
270
+ auto& rnn_new = graph[i].rnn_new;
271
+ auto& rnn_old = graph[i].rnn_old;
272
+
273
+ nn_new.insert(nn_new.end(), rnn_new.begin(), rnn_new.end());
274
+ nn_old.insert(nn_old.end(), rnn_old.begin(), rnn_old.end());
275
+ if (nn_old.size() > R * 2) {
276
+ nn_old.resize(R * 2);
277
+ nn_old.reserve(R * 2);
278
+ }
279
+
280
+ std::vector<int>().swap(graph[i].rnn_new);
281
+ std::vector<int>().swap(graph[i].rnn_old);
282
+ }
283
+ }
284
+
285
+ void NNDescent::nndescent(DistanceComputer& qdis, bool verbose) {
286
+ int num_eval_points = std::min(NUM_EVAL_POINTS, ntotal);
287
+ std::vector<int> eval_points(num_eval_points);
288
+ std::vector<std::vector<int>> acc_eval_set(num_eval_points);
289
+ std::mt19937 rng(random_seed * 6577 + omp_get_thread_num());
290
+ gen_random(rng, eval_points.data(), eval_points.size(), ntotal);
291
+ generate_eval_set(qdis, eval_points, acc_eval_set, ntotal);
292
+ for (int it = 0; it < iter; it++) {
293
+ join(qdis);
294
+ update();
295
+
296
+ if (verbose) {
297
+ float recall = eval_recall(eval_points, acc_eval_set);
298
+ printf("Iter: %d, recall@%d: %lf\n", it, K, recall);
299
+ }
300
+ }
301
+ }
302
+
303
+ /// Sample a small number of points to evaluate the quality of KNNG built
304
+ void NNDescent::generate_eval_set(
305
+ DistanceComputer& qdis,
306
+ std::vector<int>& c,
307
+ std::vector<std::vector<int>>& v,
308
+ int N) {
309
+ #pragma omp parallel for
310
+ for (int i = 0; i < c.size(); i++) {
311
+ std::vector<Neighbor> tmp;
312
+ for (int j = 0; j < N; j++) {
313
+ if (i == j)
314
+ continue; // skip itself
315
+ float dist = qdis.symmetric_dis(c[i], j);
316
+ tmp.push_back(Neighbor(j, dist, true));
317
+ }
318
+
319
+ std::partial_sort(tmp.begin(), tmp.begin() + K, tmp.end());
320
+ for (int j = 0; j < K; j++) {
321
+ v[i].push_back(tmp[j].id);
322
+ }
323
+ }
324
+ }
325
+
326
+ /// Evaluate the quality of KNNG built
327
+ float NNDescent::eval_recall(
328
+ std::vector<int>& eval_points,
329
+ std::vector<std::vector<int>>& acc_eval_set) {
330
+ float mean_acc = 0.0f;
331
+ for (size_t i = 0; i < eval_points.size(); i++) {
332
+ float acc = 0;
333
+ std::vector<Neighbor>& g = graph[eval_points[i]].pool;
334
+ std::vector<int>& v = acc_eval_set[i];
335
+ for (size_t j = 0; j < g.size(); j++) {
336
+ for (size_t k = 0; k < v.size(); k++) {
337
+ if (g[j].id == v[k]) {
338
+ acc++;
339
+ break;
340
+ }
341
+ }
342
+ }
343
+ mean_acc += acc / v.size();
344
+ }
345
+ return mean_acc / eval_points.size();
346
+ }
347
+
348
+ /// Initialize the KNN graph randomly
349
+ void NNDescent::init_graph(DistanceComputer& qdis) {
350
+ graph.reserve(ntotal);
351
+ {
352
+ std::mt19937 rng(random_seed * 6007);
353
+ for (int i = 0; i < ntotal; i++) {
354
+ graph.push_back(Nhood(L, S, rng, (int)ntotal));
355
+ }
356
+ }
357
+ #pragma omp parallel
358
+ {
359
+ std::mt19937 rng(random_seed * 7741 + omp_get_thread_num());
360
+ #pragma omp for
361
+ for (int i = 0; i < ntotal; i++) {
362
+ std::vector<int> tmp(S);
363
+
364
+ gen_random(rng, tmp.data(), S, ntotal);
365
+
366
+ for (int j = 0; j < S; j++) {
367
+ int id = tmp[j];
368
+ if (id == i)
369
+ continue;
370
+ float dist = qdis.symmetric_dis(i, id);
371
+
372
+ graph[i].pool.push_back(Neighbor(id, dist, true));
373
+ }
374
+ std::make_heap(graph[i].pool.begin(), graph[i].pool.end());
375
+ graph[i].pool.reserve(L);
376
+ }
377
+ }
378
+ }
379
+
380
+ void NNDescent::build(DistanceComputer& qdis, const int n, bool verbose) {
381
+ FAISS_THROW_IF_NOT_MSG(L >= K, "L should be >= K in NNDescent.build");
382
+
383
+ if (verbose) {
384
+ printf("Parameters: K=%d, S=%d, R=%d, L=%d, iter=%d\n",
385
+ K,
386
+ S,
387
+ R,
388
+ L,
389
+ iter);
390
+ }
391
+
392
+ ntotal = n;
393
+ init_graph(qdis);
394
+ nndescent(qdis, verbose);
395
+
396
+ final_graph.resize(ntotal * K);
397
+
398
+ // Store the neighbor link structure into final_graph
399
+ // Clear the old graph
400
+ for (int i = 0; i < ntotal; i++) {
401
+ std::sort(graph[i].pool.begin(), graph[i].pool.end());
402
+ for (int j = 0; j < K; j++) {
403
+ FAISS_ASSERT(graph[i].pool[j].id < ntotal);
404
+ final_graph[i * K + j] = graph[i].pool[j].id;
405
+ }
406
+ }
407
+ std::vector<Nhood>().swap(graph);
408
+ has_built = true;
409
+
410
+ if (verbose) {
411
+ printf("Addes %d points into the index\n", ntotal);
412
+ }
413
+ }
414
+
415
+ void NNDescent::search(
416
+ DistanceComputer& qdis,
417
+ const int topk,
418
+ idx_t* indices,
419
+ float* dists,
420
+ VisitedTable& vt) const {
421
+ FAISS_THROW_IF_NOT_MSG(has_built, "The index is not build yet.");
422
+ int L = std::max(search_L, topk);
423
+
424
+ // candidate pool, the K best items is the result.
425
+ std::vector<Neighbor> retset(L + 1);
426
+
427
+ // Randomly choose L points to intialize the candidate pool
428
+ std::vector<int> init_ids(L);
429
+ std::mt19937 rng(random_seed);
430
+
431
+ gen_random(rng, init_ids.data(), L, ntotal);
432
+ for (int i = 0; i < L; i++) {
433
+ int id = init_ids[i];
434
+ float dist = qdis(id);
435
+ retset[i] = Neighbor(id, dist, true);
436
+ }
437
+
438
+ // Maintain the candidate pool in ascending order
439
+ std::sort(retset.begin(), retset.begin() + L);
440
+
441
+ int k = 0;
442
+
443
+ // Stop until the smallest position updated is >= L
444
+ while (k < L) {
445
+ int nk = L;
446
+
447
+ if (retset[k].flag) {
448
+ retset[k].flag = false;
449
+ int n = retset[k].id;
450
+
451
+ for (int m = 0; m < K; ++m) {
452
+ int id = final_graph[n * K + m];
453
+ if (vt.get(id))
454
+ continue;
455
+
456
+ vt.set(id);
457
+ float dist = qdis(id);
458
+ if (dist >= retset[L - 1].distance)
459
+ continue;
460
+
461
+ Neighbor nn(id, dist, true);
462
+ int r = insert_into_pool(retset.data(), L, nn);
463
+
464
+ if (r < nk)
465
+ nk = r;
466
+ }
467
+ }
468
+ if (nk <= k)
469
+ k = nk;
470
+ else
471
+ ++k;
472
+ }
473
+ for (size_t i = 0; i < topk; i++) {
474
+ indices[i] = retset[i].id;
475
+ dists[i] = retset[i].distance;
476
+ }
477
+
478
+ vt.advance();
479
+ }
480
+
481
+ void NNDescent::reset() {
482
+ has_built = false;
483
+ ntotal = 0;
484
+ final_graph.resize(0);
485
+ }
486
+
487
+ } // namespace faiss
@@ -0,0 +1,154 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ // -*- c++ -*-
9
+
10
+ #pragma once
11
+
12
+ #include <algorithm>
13
+ #include <mutex>
14
+ #include <queue>
15
+ #include <random>
16
+ #include <unordered_set>
17
+ #include <vector>
18
+
19
+ #include <omp.h>
20
+
21
+ #include <faiss/Index.h>
22
+ #include <faiss/impl/FaissAssert.h>
23
+ #include <faiss/impl/platform_macros.h>
24
+ #include <faiss/utils/Heap.h>
25
+ #include <faiss/utils/random.h>
26
+
27
+ namespace faiss {
28
+
29
+ /** Implementation of NNDescent which is one of the most popular
30
+ * KNN graph building algorithms
31
+ *
32
+ * Efficient K-Nearest Neighbor Graph Construction for Generic
33
+ * Similarity Measures
34
+ *
35
+ * Dong, Wei, Charikar Moses, and Kai Li, WWW 2011
36
+ *
37
+ * This implementation is heavily influenced by the efanna
38
+ * implementation by Cong Fu and the KGraph library by Wei Dong
39
+ * (https://github.com/ZJULearning/efanna_graph)
40
+ * (https://github.com/aaalgo/kgraph)
41
+ *
42
+ * The NNDescent object stores only the neighbor link structure,
43
+ * see IndexNNDescent.h for the full index object.
44
+ */
45
+
46
+ struct VisitedTable;
47
+ struct DistanceComputer;
48
+
49
+ namespace nndescent {
50
+
51
/// One entry of a candidate pool: a point id, its distance, and a flag
/// marking entries that are still new.
struct Neighbor {
    int id;
    float distance;
    bool flag;

    Neighbor() = default;

    Neighbor(int id, float distance, bool f) {
        this->id = id;
        this->distance = distance;
        this->flag = f;
    }

    /// Order by distance only; id and flag do not take part.
    bool operator<(const Neighbor& other) const {
        return other.distance > distance;
    }
};
64
+
65
+ struct Nhood {
66
+ std::mutex lock;
67
+ std::vector<Neighbor> pool; // candidate pool (a max heap)
68
+ int M; // number of new neighbors to be operated
69
+
70
+ std::vector<int> nn_old; // old neighbors
71
+ std::vector<int> nn_new; // new neighbors
72
+ std::vector<int> rnn_old; // reverse old neighbors
73
+ std::vector<int> rnn_new; // reverse new neighbors
74
+
75
+ Nhood() = default;
76
+
77
+ Nhood(int l, int s, std::mt19937& rng, int N);
78
+
79
+ Nhood& operator=(const Nhood& other);
80
+
81
+ Nhood(const Nhood& other);
82
+
83
+ void insert(int id, float dist);
84
+
85
+ template <typename C>
86
+ void join(C callback) const;
87
+ };
88
+
89
+ } // namespace nndescent
90
+
91
+ struct NNDescent {
92
+ using storage_idx_t = int;
93
+ using idx_t = Index::idx_t;
94
+
95
+ using KNNGraph = std::vector<nndescent::Nhood>;
96
+
97
+ explicit NNDescent(const int d, const int K);
98
+
99
+ ~NNDescent();
100
+
101
+ void build(DistanceComputer& qdis, const int n, bool verbose);
102
+
103
+ void search(
104
+ DistanceComputer& qdis,
105
+ const int topk,
106
+ idx_t* indices,
107
+ float* dists,
108
+ VisitedTable& vt) const;
109
+
110
+ void reset();
111
+
112
+ /// Initialize the KNN graph randomly
113
+ void init_graph(DistanceComputer& qdis);
114
+
115
+ /// Perform NNDescent algorithm
116
+ void nndescent(DistanceComputer& qdis, bool verbose);
117
+
118
+ /// Perform local join on each node
119
+ void join(DistanceComputer& qdis);
120
+
121
+ /// Sample new neighbors for each node to peform local join later
122
+ void update();
123
+
124
+ /// Sample a small number of points to evaluate the quality of KNNG built
125
+ void generate_eval_set(
126
+ DistanceComputer& qdis,
127
+ std::vector<int>& c,
128
+ std::vector<std::vector<int>>& v,
129
+ int N);
130
+
131
+ /// Evaluate the quality of KNNG built
132
+ float eval_recall(
133
+ std::vector<int>& ctrl_points,
134
+ std::vector<std::vector<int>>& acc_eval_set);
135
+
136
+ bool has_built;
137
+
138
+ int K; // K in KNN graph
139
+ int S; // number of sample neighbors to be updated for each node
140
+ int R; // size of reverse links, 0 means the reverse links will not be used
141
+ int L; // size of the candidate pool in building
142
+ int iter; // number of iterations to iterate over
143
+ int search_L; // size of candidate pool in searching
144
+ int random_seed; // random seed for generators
145
+
146
+ int d; // dimensions
147
+
148
+ int ntotal;
149
+
150
+ KNNGraph graph;
151
+ std::vector<int> final_graph;
152
+ };
153
+
154
+ } // namespace faiss