faiss 0.2.0 → 0.2.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (215) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +16 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +7 -7
  5. data/ext/faiss/extconf.rb +6 -3
  6. data/ext/faiss/numo.hpp +4 -4
  7. data/ext/faiss/utils.cpp +1 -1
  8. data/ext/faiss/utils.h +1 -1
  9. data/lib/faiss/version.rb +1 -1
  10. data/vendor/faiss/faiss/AutoTune.cpp +292 -291
  11. data/vendor/faiss/faiss/AutoTune.h +55 -56
  12. data/vendor/faiss/faiss/Clustering.cpp +365 -194
  13. data/vendor/faiss/faiss/Clustering.h +102 -35
  14. data/vendor/faiss/faiss/IVFlib.cpp +171 -195
  15. data/vendor/faiss/faiss/IVFlib.h +48 -51
  16. data/vendor/faiss/faiss/Index.cpp +85 -103
  17. data/vendor/faiss/faiss/Index.h +54 -48
  18. data/vendor/faiss/faiss/Index2Layer.cpp +126 -224
  19. data/vendor/faiss/faiss/Index2Layer.h +22 -36
  20. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +407 -0
  21. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +195 -0
  22. data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
  23. data/vendor/faiss/faiss/IndexBinary.h +140 -132
  24. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
  25. data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
  26. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
  27. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
  28. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
  29. data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
  30. data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
  31. data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
  32. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
  33. data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
  34. data/vendor/faiss/faiss/IndexFlat.cpp +115 -176
  35. data/vendor/faiss/faiss/IndexFlat.h +42 -59
  36. data/vendor/faiss/faiss/IndexFlatCodes.cpp +67 -0
  37. data/vendor/faiss/faiss/IndexFlatCodes.h +47 -0
  38. data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
  39. data/vendor/faiss/faiss/IndexHNSW.h +57 -41
  40. data/vendor/faiss/faiss/IndexIVF.cpp +545 -453
  41. data/vendor/faiss/faiss/IndexIVF.h +169 -118
  42. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +316 -0
  43. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +121 -0
  44. data/vendor/faiss/faiss/IndexIVFFlat.cpp +247 -252
  45. data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
  46. data/vendor/faiss/faiss/IndexIVFPQ.cpp +459 -517
  47. data/vendor/faiss/faiss/IndexIVFPQ.h +75 -67
  48. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
  49. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
  50. data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
  51. data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
  52. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +163 -150
  53. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +38 -25
  54. data/vendor/faiss/faiss/IndexLSH.cpp +66 -113
  55. data/vendor/faiss/faiss/IndexLSH.h +20 -38
  56. data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
  57. data/vendor/faiss/faiss/IndexLattice.h +11 -16
  58. data/vendor/faiss/faiss/IndexNNDescent.cpp +229 -0
  59. data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
  60. data/vendor/faiss/faiss/IndexNSG.cpp +301 -0
  61. data/vendor/faiss/faiss/IndexNSG.h +85 -0
  62. data/vendor/faiss/faiss/IndexPQ.cpp +387 -495
  63. data/vendor/faiss/faiss/IndexPQ.h +64 -82
  64. data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
  65. data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
  66. data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
  67. data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
  68. data/vendor/faiss/faiss/IndexRefine.cpp +139 -127
  69. data/vendor/faiss/faiss/IndexRefine.h +32 -23
  70. data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
  71. data/vendor/faiss/faiss/IndexReplicas.h +62 -56
  72. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +111 -172
  73. data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -59
  74. data/vendor/faiss/faiss/IndexShards.cpp +256 -240
  75. data/vendor/faiss/faiss/IndexShards.h +85 -73
  76. data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
  77. data/vendor/faiss/faiss/MatrixStats.h +7 -10
  78. data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
  79. data/vendor/faiss/faiss/MetaIndexes.h +40 -34
  80. data/vendor/faiss/faiss/MetricType.h +7 -7
  81. data/vendor/faiss/faiss/VectorTransform.cpp +654 -475
  82. data/vendor/faiss/faiss/VectorTransform.h +64 -89
  83. data/vendor/faiss/faiss/clone_index.cpp +78 -73
  84. data/vendor/faiss/faiss/clone_index.h +4 -9
  85. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
  86. data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
  87. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +198 -171
  88. data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
  89. data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
  90. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
  91. data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
  92. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
  93. data/vendor/faiss/faiss/gpu/GpuIcmEncoder.h +60 -0
  94. data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
  95. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
  96. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
  97. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
  98. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
  99. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
  100. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
  101. data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
  102. data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
  103. data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
  104. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
  105. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
  106. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
  107. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
  108. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
  109. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
  110. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
  111. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
  112. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
  113. data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
  114. data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
  115. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
  116. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
  117. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
  118. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
  119. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
  120. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
  121. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
  122. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
  123. data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
  124. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
  125. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
  126. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
  127. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
  128. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
  129. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
  130. data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
  131. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +503 -0
  132. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +175 -0
  133. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
  134. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
  135. data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
  136. data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
  137. data/vendor/faiss/faiss/impl/FaissException.h +41 -29
  138. data/vendor/faiss/faiss/impl/HNSW.cpp +606 -617
  139. data/vendor/faiss/faiss/impl/HNSW.h +179 -200
  140. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +855 -0
  141. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +244 -0
  142. data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
  143. data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
  144. data/vendor/faiss/faiss/impl/NSG.cpp +679 -0
  145. data/vendor/faiss/faiss/impl/NSG.h +199 -0
  146. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
  147. data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
  148. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
  149. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
  150. data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
  151. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +758 -0
  152. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +188 -0
  153. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
  154. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +647 -707
  155. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
  156. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
  157. data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
  158. data/vendor/faiss/faiss/impl/index_read.cpp +631 -480
  159. data/vendor/faiss/faiss/impl/index_write.cpp +547 -407
  160. data/vendor/faiss/faiss/impl/io.cpp +76 -95
  161. data/vendor/faiss/faiss/impl/io.h +31 -41
  162. data/vendor/faiss/faiss/impl/io_macros.h +60 -29
  163. data/vendor/faiss/faiss/impl/kmeans1d.cpp +301 -0
  164. data/vendor/faiss/faiss/impl/kmeans1d.h +48 -0
  165. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
  166. data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
  167. data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
  168. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
  169. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
  170. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
  171. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
  172. data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
  173. data/vendor/faiss/faiss/index_factory.cpp +619 -397
  174. data/vendor/faiss/faiss/index_factory.h +8 -6
  175. data/vendor/faiss/faiss/index_io.h +23 -26
  176. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
  177. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
  178. data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
  179. data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
  180. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
  181. data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
  182. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
  183. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
  184. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
  185. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
  186. data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
  187. data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
  188. data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
  189. data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
  190. data/vendor/faiss/faiss/utils/Heap.h +186 -209
  191. data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
  192. data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
  193. data/vendor/faiss/faiss/utils/distances.cpp +305 -312
  194. data/vendor/faiss/faiss/utils/distances.h +170 -122
  195. data/vendor/faiss/faiss/utils/distances_simd.cpp +498 -508
  196. data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
  197. data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
  198. data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
  199. data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
  200. data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
  201. data/vendor/faiss/faiss/utils/hamming.h +62 -85
  202. data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
  203. data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
  204. data/vendor/faiss/faiss/utils/partitioning.h +26 -21
  205. data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
  206. data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
  207. data/vendor/faiss/faiss/utils/random.cpp +39 -63
  208. data/vendor/faiss/faiss/utils/random.h +13 -16
  209. data/vendor/faiss/faiss/utils/simdlib.h +4 -2
  210. data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
  211. data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
  212. data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
  213. data/vendor/faiss/faiss/utils/utils.cpp +304 -287
  214. data/vendor/faiss/faiss/utils/utils.h +54 -49
  215. metadata +29 -4
@@ -0,0 +1,487 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ // -*- c++ -*-
9
+
10
+ #include <faiss/impl/NNDescent.h>
11
+
12
+ #include <mutex>
13
+ #include <string>
14
+
15
+ #include <faiss/impl/AuxIndexStructures.h>
16
+
17
+ namespace faiss {
18
+
19
+ using LockGuard = std::lock_guard<std::mutex>;
20
+
21
+ namespace nndescent {
22
+
23
+ void gen_random(std::mt19937& rng, int* addr, const int size, const int N);
24
+
25
+ Nhood::Nhood(int l, int s, std::mt19937& rng, int N) {
26
+ M = s;
27
+ nn_new.resize(s * 2);
28
+ gen_random(rng, nn_new.data(), (int)nn_new.size(), N);
29
+ }
30
+
31
+ /// Copy operator
32
+ Nhood& Nhood::operator=(const Nhood& other) {
33
+ M = other.M;
34
+ std::copy(
35
+ other.nn_new.begin(),
36
+ other.nn_new.end(),
37
+ std::back_inserter(nn_new));
38
+ nn_new.reserve(other.nn_new.capacity());
39
+ pool.reserve(other.pool.capacity());
40
+ return *this;
41
+ }
42
+
43
+ /// Copy constructor
44
+ Nhood::Nhood(const Nhood& other) {
45
+ M = other.M;
46
+ std::copy(
47
+ other.nn_new.begin(),
48
+ other.nn_new.end(),
49
+ std::back_inserter(nn_new));
50
+ nn_new.reserve(other.nn_new.capacity());
51
+ pool.reserve(other.pool.capacity());
52
+ }
53
+
54
+ /// Insert a point into the candidate pool
55
+ void Nhood::insert(int id, float dist) {
56
+ LockGuard guard(lock);
57
+ if (dist > pool.front().distance)
58
+ return;
59
+ for (int i = 0; i < pool.size(); i++) {
60
+ if (id == pool[i].id)
61
+ return;
62
+ }
63
+ if (pool.size() < pool.capacity()) {
64
+ pool.push_back(Neighbor(id, dist, true));
65
+ std::push_heap(pool.begin(), pool.end());
66
+ } else {
67
+ std::pop_heap(pool.begin(), pool.end());
68
+ pool[pool.size() - 1] = Neighbor(id, dist, true);
69
+ std::push_heap(pool.begin(), pool.end());
70
+ }
71
+ }
72
+
73
+ /// In local join, two objects are compared only if at least
74
+ /// one of them is new.
75
+ template <typename C>
76
+ void Nhood::join(C callback) const {
77
+ for (int const i : nn_new) {
78
+ for (int const j : nn_new) {
79
+ if (i < j) {
80
+ callback(i, j);
81
+ }
82
+ }
83
+ for (int j : nn_old) {
84
+ callback(i, j);
85
+ }
86
+ }
87
+ }
88
+
89
/// Fill addr[0..size) with ids in [0, N).
///
/// When size < N the ids are pairwise distinct: `size` values are drawn in
/// [0, N - size), sorted, made strictly increasing, and finally rotated by
/// a random offset modulo N.
///
/// Degenerate requests (N <= size, N <= 0 or size < 0) cannot be served by
/// that scheme: `N - size` would be zero or negative and the modulo would
/// divide by zero or wrap around. In that case no random number is drawn
/// and the output is filled deterministically with i % N (every id of
/// [0, N) appears when N == size), or with 0 when N <= 0.
///
/// @param rng   random generator (advanced only on the regular path)
/// @param addr  output array with room for `size` ints
/// @param size  number of ids to generate
/// @param N     exclusive upper bound of the ids
void gen_random(std::mt19937& rng, int* addr, const int size, const int N) {
    if (size < 0 || N <= 0 || N <= size) {
        for (int i = 0; i < size; ++i) {
            addr[i] = (N > 0) ? (i % N) : 0;
        }
        return;
    }

    for (int i = 0; i < size; ++i) {
        addr[i] = rng() % (N - size);
    }
    std::sort(addr, addr + size);
    // Bump duplicates so the sequence becomes strictly increasing; the
    // largest value stays below N because the draws were below N - size.
    for (int i = 1; i < size; ++i) {
        if (addr[i] <= addr[i - 1]) {
            addr[i] = addr[i - 1] + 1;
        }
    }
    int off = rng() % N;
    for (int i = 0; i < size; ++i) {
        addr[i] = (addr[i] + off) % N;
    }
}
104
+
105
+ // Insert a new point into the candidate pool in ascending order
106
+ int insert_into_pool(Neighbor* addr, int size, Neighbor nn) {
107
+ // find the location to insert
108
+ int left = 0, right = size - 1;
109
+ if (addr[left].distance > nn.distance) {
110
+ memmove((char*)&addr[left + 1], &addr[left], size * sizeof(Neighbor));
111
+ addr[left] = nn;
112
+ return left;
113
+ }
114
+ if (addr[right].distance < nn.distance) {
115
+ addr[size] = nn;
116
+ return size;
117
+ }
118
+ while (left < right - 1) {
119
+ int mid = (left + right) / 2;
120
+ if (addr[mid].distance > nn.distance)
121
+ right = mid;
122
+ else
123
+ left = mid;
124
+ }
125
+ // check equal ID
126
+
127
+ while (left > 0) {
128
+ if (addr[left].distance < nn.distance)
129
+ break;
130
+ if (addr[left].id == nn.id)
131
+ return size + 1;
132
+ left--;
133
+ }
134
+ if (addr[left].id == nn.id || addr[right].id == nn.id)
135
+ return size + 1;
136
+ memmove((char*)&addr[right + 1],
137
+ &addr[right],
138
+ (size - right) * sizeof(Neighbor));
139
+ addr[right] = nn;
140
+ return right;
141
+ }
142
+
143
+ } // namespace nndescent
144
+
145
+ using namespace nndescent;
146
+
147
+ constexpr int NUM_EVAL_POINTS = 100;
148
+
149
+ NNDescent::NNDescent(const int d, const int K) : K(K), random_seed(2021), d(d) {
150
+ ntotal = 0;
151
+ has_built = false;
152
+ S = 10;
153
+ R = 100;
154
+ L = K + 50;
155
+ iter = 10;
156
+ search_L = 0;
157
+ }
158
+
159
+ NNDescent::~NNDescent() {}
160
+
161
+ void NNDescent::join(DistanceComputer& qdis) {
162
+ #pragma omp parallel for default(shared) schedule(dynamic, 100)
163
+ for (int n = 0; n < ntotal; n++) {
164
+ graph[n].join([&](int i, int j) {
165
+ if (i != j) {
166
+ float dist = qdis.symmetric_dis(i, j);
167
+ graph[i].insert(j, dist);
168
+ graph[j].insert(i, dist);
169
+ }
170
+ });
171
+ }
172
+ }
173
+
174
+ /// Sample neighbors for each node to peform local join later
175
+ /// Store them in nn_new and nn_old
176
+ void NNDescent::update() {
177
+ // Step 1.
178
+ // Clear all nn_new and nn_old
179
+ #pragma omp parallel for
180
+ for (int i = 0; i < ntotal; i++) {
181
+ std::vector<int>().swap(graph[i].nn_new);
182
+ std::vector<int>().swap(graph[i].nn_old);
183
+ }
184
+
185
+ // Step 2.
186
+ // Compute the number of neighbors which is new i.e. flag is true
187
+ // in the candidate pool. This must not exceed the sample number S.
188
+ // That means We only select S new neighbors.
189
+ #pragma omp parallel for
190
+ for (int n = 0; n < ntotal; ++n) {
191
+ auto& nn = graph[n];
192
+ std::sort(nn.pool.begin(), nn.pool.end());
193
+
194
+ if (nn.pool.size() > L)
195
+ nn.pool.resize(L);
196
+ nn.pool.reserve(L); // keep the pool size be L
197
+
198
+ int maxl = std::min(nn.M + S, (int)nn.pool.size());
199
+ int c = 0;
200
+ int l = 0;
201
+
202
+ while ((l < maxl) && (c < S)) {
203
+ if (nn.pool[l].flag)
204
+ ++c;
205
+ ++l;
206
+ }
207
+ nn.M = l;
208
+ }
209
+
210
+ // Step 3.
211
+ // Find reverse links for each node
212
+ // Randomly choose R reverse links.
213
+ #pragma omp parallel
214
+ {
215
+ std::mt19937 rng(random_seed * 5081 + omp_get_thread_num());
216
+ #pragma omp for
217
+ for (int n = 0; n < ntotal; ++n) {
218
+ auto& node = graph[n];
219
+ auto& nn_new = node.nn_new;
220
+ auto& nn_old = node.nn_old;
221
+
222
+ for (int l = 0; l < node.M; ++l) {
223
+ auto& nn = node.pool[l];
224
+ auto& other = graph[nn.id]; // the other side of the edge
225
+
226
+ if (nn.flag) { // the node is inserted newly
227
+ // push the neighbor into nn_new
228
+ nn_new.push_back(nn.id);
229
+ // push itself into other.rnn_new if it is not in
230
+ // the candidate pool of the other side
231
+ if (nn.distance > other.pool.back().distance) {
232
+ LockGuard guard(other.lock);
233
+ if (other.rnn_new.size() < R) {
234
+ other.rnn_new.push_back(n);
235
+ } else {
236
+ int pos = rng() % R;
237
+ other.rnn_new[pos] = n;
238
+ }
239
+ }
240
+ nn.flag = false;
241
+
242
+ } else { // the node is old
243
+ // push the neighbor into nn_old
244
+ nn_old.push_back(nn.id);
245
+ // push itself into other.rnn_old if it is not in
246
+ // the candidate pool of the other side
247
+ if (nn.distance > other.pool.back().distance) {
248
+ LockGuard guard(other.lock);
249
+ if (other.rnn_old.size() < R) {
250
+ other.rnn_old.push_back(n);
251
+ } else {
252
+ int pos = rng() % R;
253
+ other.rnn_old[pos] = n;
254
+ }
255
+ }
256
+ }
257
+ }
258
+ // make heap to join later (in join() function)
259
+ std::make_heap(node.pool.begin(), node.pool.end());
260
+ }
261
+ }
262
+
263
+ // Step 4.
264
+ // Combine the forward and the reverse links
265
+ // R = 0 means no reverse links are used.
266
+ #pragma omp parallel for
267
+ for (int i = 0; i < ntotal; ++i) {
268
+ auto& nn_new = graph[i].nn_new;
269
+ auto& nn_old = graph[i].nn_old;
270
+ auto& rnn_new = graph[i].rnn_new;
271
+ auto& rnn_old = graph[i].rnn_old;
272
+
273
+ nn_new.insert(nn_new.end(), rnn_new.begin(), rnn_new.end());
274
+ nn_old.insert(nn_old.end(), rnn_old.begin(), rnn_old.end());
275
+ if (nn_old.size() > R * 2) {
276
+ nn_old.resize(R * 2);
277
+ nn_old.reserve(R * 2);
278
+ }
279
+
280
+ std::vector<int>().swap(graph[i].rnn_new);
281
+ std::vector<int>().swap(graph[i].rnn_old);
282
+ }
283
+ }
284
+
285
+ void NNDescent::nndescent(DistanceComputer& qdis, bool verbose) {
286
+ int num_eval_points = std::min(NUM_EVAL_POINTS, ntotal);
287
+ std::vector<int> eval_points(num_eval_points);
288
+ std::vector<std::vector<int>> acc_eval_set(num_eval_points);
289
+ std::mt19937 rng(random_seed * 6577 + omp_get_thread_num());
290
+ gen_random(rng, eval_points.data(), eval_points.size(), ntotal);
291
+ generate_eval_set(qdis, eval_points, acc_eval_set, ntotal);
292
+ for (int it = 0; it < iter; it++) {
293
+ join(qdis);
294
+ update();
295
+
296
+ if (verbose) {
297
+ float recall = eval_recall(eval_points, acc_eval_set);
298
+ printf("Iter: %d, recall@%d: %lf\n", it, K, recall);
299
+ }
300
+ }
301
+ }
302
+
303
+ /// Sample a small number of points to evaluate the quality of KNNG built
304
+ void NNDescent::generate_eval_set(
305
+ DistanceComputer& qdis,
306
+ std::vector<int>& c,
307
+ std::vector<std::vector<int>>& v,
308
+ int N) {
309
+ #pragma omp parallel for
310
+ for (int i = 0; i < c.size(); i++) {
311
+ std::vector<Neighbor> tmp;
312
+ for (int j = 0; j < N; j++) {
313
+ if (i == j)
314
+ continue; // skip itself
315
+ float dist = qdis.symmetric_dis(c[i], j);
316
+ tmp.push_back(Neighbor(j, dist, true));
317
+ }
318
+
319
+ std::partial_sort(tmp.begin(), tmp.begin() + K, tmp.end());
320
+ for (int j = 0; j < K; j++) {
321
+ v[i].push_back(tmp[j].id);
322
+ }
323
+ }
324
+ }
325
+
326
+ /// Evaluate the quality of KNNG built
327
+ float NNDescent::eval_recall(
328
+ std::vector<int>& eval_points,
329
+ std::vector<std::vector<int>>& acc_eval_set) {
330
+ float mean_acc = 0.0f;
331
+ for (size_t i = 0; i < eval_points.size(); i++) {
332
+ float acc = 0;
333
+ std::vector<Neighbor>& g = graph[eval_points[i]].pool;
334
+ std::vector<int>& v = acc_eval_set[i];
335
+ for (size_t j = 0; j < g.size(); j++) {
336
+ for (size_t k = 0; k < v.size(); k++) {
337
+ if (g[j].id == v[k]) {
338
+ acc++;
339
+ break;
340
+ }
341
+ }
342
+ }
343
+ mean_acc += acc / v.size();
344
+ }
345
+ return mean_acc / eval_points.size();
346
+ }
347
+
348
+ /// Initialize the KNN graph randomly
349
+ void NNDescent::init_graph(DistanceComputer& qdis) {
350
+ graph.reserve(ntotal);
351
+ {
352
+ std::mt19937 rng(random_seed * 6007);
353
+ for (int i = 0; i < ntotal; i++) {
354
+ graph.push_back(Nhood(L, S, rng, (int)ntotal));
355
+ }
356
+ }
357
+ #pragma omp parallel
358
+ {
359
+ std::mt19937 rng(random_seed * 7741 + omp_get_thread_num());
360
+ #pragma omp for
361
+ for (int i = 0; i < ntotal; i++) {
362
+ std::vector<int> tmp(S);
363
+
364
+ gen_random(rng, tmp.data(), S, ntotal);
365
+
366
+ for (int j = 0; j < S; j++) {
367
+ int id = tmp[j];
368
+ if (id == i)
369
+ continue;
370
+ float dist = qdis.symmetric_dis(i, id);
371
+
372
+ graph[i].pool.push_back(Neighbor(id, dist, true));
373
+ }
374
+ std::make_heap(graph[i].pool.begin(), graph[i].pool.end());
375
+ graph[i].pool.reserve(L);
376
+ }
377
+ }
378
+ }
379
+
380
+ void NNDescent::build(DistanceComputer& qdis, const int n, bool verbose) {
381
+ FAISS_THROW_IF_NOT_MSG(L >= K, "L should be >= K in NNDescent.build");
382
+
383
+ if (verbose) {
384
+ printf("Parameters: K=%d, S=%d, R=%d, L=%d, iter=%d\n",
385
+ K,
386
+ S,
387
+ R,
388
+ L,
389
+ iter);
390
+ }
391
+
392
+ ntotal = n;
393
+ init_graph(qdis);
394
+ nndescent(qdis, verbose);
395
+
396
+ final_graph.resize(ntotal * K);
397
+
398
+ // Store the neighbor link structure into final_graph
399
+ // Clear the old graph
400
+ for (int i = 0; i < ntotal; i++) {
401
+ std::sort(graph[i].pool.begin(), graph[i].pool.end());
402
+ for (int j = 0; j < K; j++) {
403
+ FAISS_ASSERT(graph[i].pool[j].id < ntotal);
404
+ final_graph[i * K + j] = graph[i].pool[j].id;
405
+ }
406
+ }
407
+ std::vector<Nhood>().swap(graph);
408
+ has_built = true;
409
+
410
+ if (verbose) {
411
+ printf("Addes %d points into the index\n", ntotal);
412
+ }
413
+ }
414
+
415
+ void NNDescent::search(
416
+ DistanceComputer& qdis,
417
+ const int topk,
418
+ idx_t* indices,
419
+ float* dists,
420
+ VisitedTable& vt) const {
421
+ FAISS_THROW_IF_NOT_MSG(has_built, "The index is not build yet.");
422
+ int L = std::max(search_L, topk);
423
+
424
+ // candidate pool, the K best items is the result.
425
+ std::vector<Neighbor> retset(L + 1);
426
+
427
+ // Randomly choose L points to intialize the candidate pool
428
+ std::vector<int> init_ids(L);
429
+ std::mt19937 rng(random_seed);
430
+
431
+ gen_random(rng, init_ids.data(), L, ntotal);
432
+ for (int i = 0; i < L; i++) {
433
+ int id = init_ids[i];
434
+ float dist = qdis(id);
435
+ retset[i] = Neighbor(id, dist, true);
436
+ }
437
+
438
+ // Maintain the candidate pool in ascending order
439
+ std::sort(retset.begin(), retset.begin() + L);
440
+
441
+ int k = 0;
442
+
443
+ // Stop until the smallest position updated is >= L
444
+ while (k < L) {
445
+ int nk = L;
446
+
447
+ if (retset[k].flag) {
448
+ retset[k].flag = false;
449
+ int n = retset[k].id;
450
+
451
+ for (int m = 0; m < K; ++m) {
452
+ int id = final_graph[n * K + m];
453
+ if (vt.get(id))
454
+ continue;
455
+
456
+ vt.set(id);
457
+ float dist = qdis(id);
458
+ if (dist >= retset[L - 1].distance)
459
+ continue;
460
+
461
+ Neighbor nn(id, dist, true);
462
+ int r = insert_into_pool(retset.data(), L, nn);
463
+
464
+ if (r < nk)
465
+ nk = r;
466
+ }
467
+ }
468
+ if (nk <= k)
469
+ k = nk;
470
+ else
471
+ ++k;
472
+ }
473
+ for (size_t i = 0; i < topk; i++) {
474
+ indices[i] = retset[i].id;
475
+ dists[i] = retset[i].distance;
476
+ }
477
+
478
+ vt.advance();
479
+ }
480
+
481
+ void NNDescent::reset() {
482
+ has_built = false;
483
+ ntotal = 0;
484
+ final_graph.resize(0);
485
+ }
486
+
487
+ } // namespace faiss
@@ -0,0 +1,154 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ // -*- c++ -*-
9
+
10
+ #pragma once
11
+
12
+ #include <algorithm>
13
+ #include <mutex>
14
+ #include <queue>
15
+ #include <random>
16
+ #include <unordered_set>
17
+ #include <vector>
18
+
19
+ #include <omp.h>
20
+
21
+ #include <faiss/Index.h>
22
+ #include <faiss/impl/FaissAssert.h>
23
+ #include <faiss/impl/platform_macros.h>
24
+ #include <faiss/utils/Heap.h>
25
+ #include <faiss/utils/random.h>
26
+
27
+ namespace faiss {
28
+
29
+ /** Implementation of NNDescent which is one of the most popular
30
+ * KNN graph building algorithms
31
+ *
32
+ * Efficient K-Nearest Neighbor Graph Construction for Generic
33
+ * Similarity Measures
34
+ *
35
+ * Dong, Wei, Charikar Moses, and Kai Li, WWW 2011
36
+ *
37
+ * This implementation is heavily influenced by the efanna
38
+ * implementation by Cong Fu and the KGraph library by Wei Dong
39
+ * (https://github.com/ZJULearning/efanna_graph)
40
+ * (https://github.com/aaalgo/kgraph)
41
+ *
42
+ * The NNDescent object stores only the neighbor link structure,
43
+ * see IndexNNDescent.h for the full index object.
44
+ */
45
+
46
+ struct VisitedTable;
47
+ struct DistanceComputer;
48
+
49
+ namespace nndescent {
50
+
51
/// One entry of a candidate pool: a point id, its distance, and a flag.
/// Entries are ordered by distance only.
struct Neighbor {
    int id;         ///< id of the neighbor
    float distance; ///< distance to the neighbor
    bool flag;      ///< set on freshly inserted entries, cleared once used

    Neighbor() = default;

    Neighbor(int node_id, float dist, bool is_new)
            : id{node_id}, distance{dist}, flag{is_new} {}

    /// Strict ordering on distance; id and flag are ignored.
    bool operator<(const Neighbor& rhs) const {
        return rhs.distance > distance;
    }
};
64
+
65
+ struct Nhood {
66
+ std::mutex lock;
67
+ std::vector<Neighbor> pool; // candidate pool (a max heap)
68
+ int M; // number of new neighbors to be operated
69
+
70
+ std::vector<int> nn_old; // old neighbors
71
+ std::vector<int> nn_new; // new neighbors
72
+ std::vector<int> rnn_old; // reverse old neighbors
73
+ std::vector<int> rnn_new; // reverse new neighbors
74
+
75
+ Nhood() = default;
76
+
77
+ Nhood(int l, int s, std::mt19937& rng, int N);
78
+
79
+ Nhood& operator=(const Nhood& other);
80
+
81
+ Nhood(const Nhood& other);
82
+
83
+ void insert(int id, float dist);
84
+
85
+ template <typename C>
86
+ void join(C callback) const;
87
+ };
88
+
89
+ } // namespace nndescent
90
+
91
+ struct NNDescent {
92
+ using storage_idx_t = int;
93
+ using idx_t = Index::idx_t;
94
+
95
+ using KNNGraph = std::vector<nndescent::Nhood>;
96
+
97
+ explicit NNDescent(const int d, const int K);
98
+
99
+ ~NNDescent();
100
+
101
+ void build(DistanceComputer& qdis, const int n, bool verbose);
102
+
103
+ void search(
104
+ DistanceComputer& qdis,
105
+ const int topk,
106
+ idx_t* indices,
107
+ float* dists,
108
+ VisitedTable& vt) const;
109
+
110
+ void reset();
111
+
112
+ /// Initialize the KNN graph randomly
113
+ void init_graph(DistanceComputer& qdis);
114
+
115
+ /// Perform NNDescent algorithm
116
+ void nndescent(DistanceComputer& qdis, bool verbose);
117
+
118
+ /// Perform local join on each node
119
+ void join(DistanceComputer& qdis);
120
+
121
+ /// Sample new neighbors for each node to peform local join later
122
+ void update();
123
+
124
+ /// Sample a small number of points to evaluate the quality of KNNG built
125
+ void generate_eval_set(
126
+ DistanceComputer& qdis,
127
+ std::vector<int>& c,
128
+ std::vector<std::vector<int>>& v,
129
+ int N);
130
+
131
+ /// Evaluate the quality of KNNG built
132
+ float eval_recall(
133
+ std::vector<int>& ctrl_points,
134
+ std::vector<std::vector<int>>& acc_eval_set);
135
+
136
+ bool has_built;
137
+
138
+ int K; // K in KNN graph
139
+ int S; // number of sample neighbors to be updated for each node
140
+ int R; // size of reverse links, 0 means the reverse links will not be used
141
+ int L; // size of the candidate pool in building
142
+ int iter; // number of iterations to iterate over
143
+ int search_L; // size of candidate pool in searching
144
+ int random_seed; // random seed for generators
145
+
146
+ int d; // dimensions
147
+
148
+ int ntotal;
149
+
150
+ KNNGraph graph;
151
+ std::vector<int> final_graph;
152
+ };
153
+
154
+ } // namespace faiss