faiss 0.1.5 → 0.2.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (219) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +24 -0
  3. data/README.md +12 -0
  4. data/ext/faiss/ext.cpp +1 -1
  5. data/ext/faiss/extconf.rb +6 -2
  6. data/ext/faiss/index.cpp +114 -43
  7. data/ext/faiss/index_binary.cpp +24 -30
  8. data/ext/faiss/kmeans.cpp +20 -16
  9. data/ext/faiss/numo.hpp +867 -0
  10. data/ext/faiss/pca_matrix.cpp +13 -14
  11. data/ext/faiss/product_quantizer.cpp +23 -24
  12. data/ext/faiss/utils.cpp +10 -37
  13. data/ext/faiss/utils.h +2 -13
  14. data/lib/faiss.rb +0 -5
  15. data/lib/faiss/version.rb +1 -1
  16. data/vendor/faiss/faiss/AutoTune.cpp +292 -291
  17. data/vendor/faiss/faiss/AutoTune.h +55 -56
  18. data/vendor/faiss/faiss/Clustering.cpp +334 -195
  19. data/vendor/faiss/faiss/Clustering.h +88 -35
  20. data/vendor/faiss/faiss/IVFlib.cpp +171 -195
  21. data/vendor/faiss/faiss/IVFlib.h +48 -51
  22. data/vendor/faiss/faiss/Index.cpp +85 -103
  23. data/vendor/faiss/faiss/Index.h +54 -48
  24. data/vendor/faiss/faiss/Index2Layer.cpp +139 -164
  25. data/vendor/faiss/faiss/Index2Layer.h +22 -22
  26. data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
  27. data/vendor/faiss/faiss/IndexBinary.h +140 -132
  28. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
  29. data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
  30. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
  31. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
  32. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
  33. data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
  34. data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
  35. data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
  36. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
  37. data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
  38. data/vendor/faiss/faiss/IndexFlat.cpp +116 -147
  39. data/vendor/faiss/faiss/IndexFlat.h +35 -46
  40. data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
  41. data/vendor/faiss/faiss/IndexHNSW.h +57 -41
  42. data/vendor/faiss/faiss/IndexIVF.cpp +474 -454
  43. data/vendor/faiss/faiss/IndexIVF.h +146 -113
  44. data/vendor/faiss/faiss/IndexIVFFlat.cpp +248 -250
  45. data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
  46. data/vendor/faiss/faiss/IndexIVFPQ.cpp +457 -516
  47. data/vendor/faiss/faiss/IndexIVFPQ.h +74 -66
  48. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
  49. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
  50. data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
  51. data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
  52. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +125 -133
  53. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +19 -21
  54. data/vendor/faiss/faiss/IndexLSH.cpp +75 -96
  55. data/vendor/faiss/faiss/IndexLSH.h +21 -26
  56. data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
  57. data/vendor/faiss/faiss/IndexLattice.h +11 -16
  58. data/vendor/faiss/faiss/IndexNNDescent.cpp +231 -0
  59. data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
  60. data/vendor/faiss/faiss/IndexNSG.cpp +303 -0
  61. data/vendor/faiss/faiss/IndexNSG.h +85 -0
  62. data/vendor/faiss/faiss/IndexPQ.cpp +405 -464
  63. data/vendor/faiss/faiss/IndexPQ.h +64 -67
  64. data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
  65. data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
  66. data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
  67. data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
  68. data/vendor/faiss/faiss/IndexRefine.cpp +115 -131
  69. data/vendor/faiss/faiss/IndexRefine.h +22 -23
  70. data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
  71. data/vendor/faiss/faiss/IndexReplicas.h +62 -56
  72. data/vendor/faiss/faiss/IndexResidual.cpp +291 -0
  73. data/vendor/faiss/faiss/IndexResidual.h +152 -0
  74. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +120 -155
  75. data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -45
  76. data/vendor/faiss/faiss/IndexShards.cpp +256 -240
  77. data/vendor/faiss/faiss/IndexShards.h +85 -73
  78. data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
  79. data/vendor/faiss/faiss/MatrixStats.h +7 -10
  80. data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
  81. data/vendor/faiss/faiss/MetaIndexes.h +40 -34
  82. data/vendor/faiss/faiss/MetricType.h +7 -7
  83. data/vendor/faiss/faiss/VectorTransform.cpp +652 -474
  84. data/vendor/faiss/faiss/VectorTransform.h +61 -89
  85. data/vendor/faiss/faiss/clone_index.cpp +77 -73
  86. data/vendor/faiss/faiss/clone_index.h +4 -9
  87. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
  88. data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
  89. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +197 -170
  90. data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
  91. data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
  92. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
  93. data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
  94. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
  95. data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
  96. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
  97. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
  98. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
  99. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
  100. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
  101. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
  102. data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
  103. data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
  104. data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
  105. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
  106. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
  107. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
  108. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
  109. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
  110. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
  111. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
  112. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
  113. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
  114. data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
  115. data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
  116. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
  117. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
  118. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
  119. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
  120. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
  121. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
  122. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
  123. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
  124. data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
  125. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
  126. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
  127. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
  128. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
  129. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
  130. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
  131. data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
  132. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +270 -0
  133. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +115 -0
  134. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
  135. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
  136. data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
  137. data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
  138. data/vendor/faiss/faiss/impl/FaissException.h +41 -29
  139. data/vendor/faiss/faiss/impl/HNSW.cpp +595 -611
  140. data/vendor/faiss/faiss/impl/HNSW.h +179 -200
  141. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +672 -0
  142. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +172 -0
  143. data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
  144. data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
  145. data/vendor/faiss/faiss/impl/NSG.cpp +682 -0
  146. data/vendor/faiss/faiss/impl/NSG.h +199 -0
  147. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
  148. data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
  149. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
  150. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
  151. data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
  152. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +448 -0
  153. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +130 -0
  154. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
  155. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +648 -701
  156. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
  157. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
  158. data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
  159. data/vendor/faiss/faiss/impl/index_read.cpp +547 -479
  160. data/vendor/faiss/faiss/impl/index_write.cpp +497 -407
  161. data/vendor/faiss/faiss/impl/io.cpp +75 -94
  162. data/vendor/faiss/faiss/impl/io.h +31 -41
  163. data/vendor/faiss/faiss/impl/io_macros.h +40 -29
  164. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
  165. data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
  166. data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
  167. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
  168. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
  169. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
  170. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
  171. data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
  172. data/vendor/faiss/faiss/index_factory.cpp +269 -218
  173. data/vendor/faiss/faiss/index_factory.h +6 -7
  174. data/vendor/faiss/faiss/index_io.h +23 -26
  175. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
  176. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
  177. data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
  178. data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
  179. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
  180. data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
  181. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
  182. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
  183. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
  184. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
  185. data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
  186. data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
  187. data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
  188. data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
  189. data/vendor/faiss/faiss/utils/Heap.h +186 -209
  190. data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
  191. data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
  192. data/vendor/faiss/faiss/utils/distances.cpp +301 -310
  193. data/vendor/faiss/faiss/utils/distances.h +133 -118
  194. data/vendor/faiss/faiss/utils/distances_simd.cpp +456 -516
  195. data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
  196. data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
  197. data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
  198. data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
  199. data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
  200. data/vendor/faiss/faiss/utils/hamming.h +62 -85
  201. data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
  202. data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
  203. data/vendor/faiss/faiss/utils/partitioning.h +26 -21
  204. data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
  205. data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
  206. data/vendor/faiss/faiss/utils/random.cpp +39 -63
  207. data/vendor/faiss/faiss/utils/random.h +13 -16
  208. data/vendor/faiss/faiss/utils/simdlib.h +4 -2
  209. data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
  210. data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
  211. data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
  212. data/vendor/faiss/faiss/utils/utils.cpp +304 -287
  213. data/vendor/faiss/faiss/utils/utils.h +53 -48
  214. metadata +24 -10
  215. data/lib/faiss/index.rb +0 -20
  216. data/lib/faiss/index_binary.rb +0 -20
  217. data/lib/faiss/kmeans.rb +0 -15
  218. data/lib/faiss/pca_matrix.rb +0 -15
  219. data/lib/faiss/product_quantizer.rb +0 -22
@@ -0,0 +1,172 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ #pragma once
9
+
10
+ #include <stdint.h>
11
+
12
+ #include <random>
13
+ #include <string>
14
+ #include <unordered_map>
15
+ #include <vector>
16
+
17
+ #include <faiss/impl/AdditiveQuantizer.h>
18
+ #include <faiss/utils/utils.h>
19
+
20
+ namespace faiss {
21
+
22
+ /** Implementation of LSQ/LSQ++ described in the following two papers:
23
+ *
24
+ * Revisiting additive quantization
25
+ * Julieta Martinez, et al. ECCV 2016
26
+ *
27
+ * LSQ++: Lower running time and higher recall in multi-codebook quantization
28
+ * Julieta Martinez, et al. ECCV 2018
29
+ *
30
+ * This implementation is mostly translated from the Julia implementations
31
+ * by Julieta Martinez:
32
+ * (https://github.com/una-dinosauria/local-search-quantization,
33
+ * https://github.com/una-dinosauria/Rayuela.jl)
34
+ *
35
+ * The trained codes are stored in `codebooks` which is called
36
+ * `centroids` in PQ and RQ.
37
+ */
38
+
39
+ struct LocalSearchQuantizer : AdditiveQuantizer {
40
+ size_t K; ///< number of codes per codebook
41
+
42
+ size_t train_iters; ///< number of iterations in training
43
+
44
+ size_t encode_ils_iters; ///< iterations of local search in encoding
45
+ size_t train_ils_iters; ///< iterations of local search in training
46
+ size_t icm_iters; ///< number of iterations in icm
47
+
48
+ float p; ///< temperature factor
49
+ float lambd; ///< regularization factor
50
+
51
+ size_t chunk_size; ///< nb of vectors to encode at a time
52
+
53
+ int random_seed; ///< seed for random generator
54
+ size_t nperts; ///< number of perturbation in each code
55
+
56
+ LocalSearchQuantizer(
57
+ size_t d, /* dimensionality of the input vectors */
58
+ size_t M, /* number of subquantizers */
59
+ size_t nbits); /* number of bit per subvector index */
60
+
61
+ // Train the local search quantizer
62
+ void train(size_t n, const float* x) override;
63
+
64
+ /** Encode a set of vectors
65
+ *
66
+ * @param x vectors to encode, size n * d
67
+ * @param codes output codes, size n * code_size
68
+ */
69
+ void compute_codes(const float* x, uint8_t* codes, size_t n) const override;
70
+
71
+ /** Update codebooks given encodings
72
+ *
73
+ * @param x training vectors, size n * d
74
+ * @param codes encoded training vectors, size n * M
75
+ */
76
+ void update_codebooks(const float* x, const int32_t* codes, size_t n);
77
+
78
+ /** Encode vectors given codebooks using iterative conditional mode (icm).
79
+ *
80
+ * @param x vectors to encode, size n * d
81
+ * @param codes output codes, size n * M
82
+ * @param ils_iters number of iterations of iterative local search
83
+ */
84
+ void icm_encode(
85
+ const float* x,
86
+ int32_t* codes,
87
+ size_t n,
88
+ size_t ils_iters,
89
+ std::mt19937& gen) const;
90
+
91
+ void icm_encode_partial(
92
+ size_t index,
93
+ const float* x,
94
+ int32_t* codes,
95
+ size_t n,
96
+ const float* binaries,
97
+ size_t ils_iters,
98
+ std::mt19937& gen) const;
99
+
100
+ void icm_encode_step(
101
+ const float* unaries,
102
+ const float* binaries,
103
+ int32_t* codes,
104
+ size_t n) const;
105
+
106
+ /** Add some perturbation to codebooks
107
+ *
108
+ * @param T temperature of simulated annealing
109
+ * @param stddev standard derivations of each dimension in training data
110
+ */
111
+ void perturb_codebooks(
112
+ float T,
113
+ const std::vector<float>& stddev,
114
+ std::mt19937& gen);
115
+
116
+ /** Add some perturbation to codes
117
+ *
118
+ * @param codes codes to be perturbed, size n * M
119
+ */
120
+ void perturb_codes(int32_t* codes, size_t n, std::mt19937& gen) const;
121
+
122
+ /** Compute binary terms
123
+ *
124
+ * @param binaries binary terms, size M * M * K * K
125
+ */
126
+ void compute_binary_terms(float* binaries) const;
127
+
128
+ /** Compute unary terms
129
+ *
130
+ * @param x vectors to encode, size n * d
131
+ * @param unaries unary terms, size n * M * K
132
+ */
133
+ void compute_unary_terms(const float* x, float* unaries, size_t n) const;
134
+
135
+ /** Helper function to compute reconstruction error
136
+ *
137
+ * @param x vectors to encode, size n * d
138
+ * @param codes encoded codes, size n * M
139
+ * @param objs if it is not null, store reconstruction
140
+ error of each vector into it, size n
141
+ */
142
+ float evaluate(
143
+ const int32_t* codes,
144
+ const float* x,
145
+ size_t n,
146
+ float* objs = nullptr) const;
147
+ };
148
+
149
/** Helper struct that keeps named time measurements during training.
 *
 * It is NOT thread-safe.
 */
struct LSQTimer {
    // Three maps keyed by a name string.
    // NOTE(review): presumably accumulated time, last start timestamp and
    // running state per name -- confirm against the implementation file.
    std::unordered_map<std::string, double> duration;
    std::unordered_map<std::string, double> t0;
    std::unordered_map<std::string, bool> started;

    /// Construction simply delegates to reset().
    LSQTimer() {
        reset();
    }

    double get(const std::string& name);

    void start(const std::string& name);

    void end(const std::string& name);

    void reset();
};
169
+
170
+ FAISS_API extern LSQTimer lsq_timer; ///< timer to count consuming time
171
+
172
+ } // namespace faiss
@@ -0,0 +1,487 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ // -*- c++ -*-
9
+
10
+ #include <faiss/impl/NNDescent.h>
11
+
12
+ #include <mutex>
13
+ #include <string>
14
+
15
+ #include <faiss/impl/AuxIndexStructures.h>
16
+
17
+ namespace faiss {
18
+
19
+ using LockGuard = std::lock_guard<std::mutex>;
20
+
21
+ namespace nndescent {
22
+
23
+ void gen_random(std::mt19937& rng, int* addr, const int size, const int N);
24
+
25
+ Nhood::Nhood(int l, int s, std::mt19937& rng, int N) {
26
+ M = s;
27
+ nn_new.resize(s * 2);
28
+ gen_random(rng, nn_new.data(), (int)nn_new.size(), N);
29
+ }
30
+
31
+ /// Copy operator
32
+ Nhood& Nhood::operator=(const Nhood& other) {
33
+ M = other.M;
34
+ std::copy(
35
+ other.nn_new.begin(),
36
+ other.nn_new.end(),
37
+ std::back_inserter(nn_new));
38
+ nn_new.reserve(other.nn_new.capacity());
39
+ pool.reserve(other.pool.capacity());
40
+ return *this;
41
+ }
42
+
43
+ /// Copy constructor
44
+ Nhood::Nhood(const Nhood& other) {
45
+ M = other.M;
46
+ std::copy(
47
+ other.nn_new.begin(),
48
+ other.nn_new.end(),
49
+ std::back_inserter(nn_new));
50
+ nn_new.reserve(other.nn_new.capacity());
51
+ pool.reserve(other.pool.capacity());
52
+ }
53
+
54
+ /// Insert a point into the candidate pool
55
+ void Nhood::insert(int id, float dist) {
56
+ LockGuard guard(lock);
57
+ if (dist > pool.front().distance)
58
+ return;
59
+ for (int i = 0; i < pool.size(); i++) {
60
+ if (id == pool[i].id)
61
+ return;
62
+ }
63
+ if (pool.size() < pool.capacity()) {
64
+ pool.push_back(Neighbor(id, dist, true));
65
+ std::push_heap(pool.begin(), pool.end());
66
+ } else {
67
+ std::pop_heap(pool.begin(), pool.end());
68
+ pool[pool.size() - 1] = Neighbor(id, dist, true);
69
+ std::push_heap(pool.begin(), pool.end());
70
+ }
71
+ }
72
+
73
+ /// In local join, two objects are compared only if at least
74
+ /// one of them is new.
75
+ template <typename C>
76
+ void Nhood::join(C callback) const {
77
+ for (int const i : nn_new) {
78
+ for (int const j : nn_new) {
79
+ if (i < j) {
80
+ callback(i, j);
81
+ }
82
+ }
83
+ for (int j : nn_old) {
84
+ callback(i, j);
85
+ }
86
+ }
87
+ }
88
+
89
/// Fill addr[0..size) with ids from the range [0, N).
///
/// Regular case (N > size): the ids are pairwise distinct. They are drawn
/// from [0, N - size), sorted, made strictly increasing, and finally rotated
/// by a random offset modulo N.
///
/// Degenerate cases (previously a division by zero or garbage values):
///  - N <= 0 or size < 0: nothing is written.
///  - size >= N: `size` distinct ids do not exist (or, for size == N, every
///    id must be used), so the ids 0..N-1 are enumerated cyclically starting
///    at a random offset. For size == N each id appears exactly once.
///
/// @param rng   random generator, advanced by this call
/// @param addr  output buffer with room for `size` ints
/// @param size  number of ids to generate
/// @param N     exclusive upper bound of the generated ids
void gen_random(std::mt19937& rng, int* addr, const int size, const int N) {
    if (N <= 0 || size < 0) {
        return;
    }

    if (size >= N) {
        const int off = rng() % N;
        for (int i = 0; i < size; ++i) {
            addr[i] = (i % N + off) % N;
        }
        return;
    }

    for (int i = 0; i < size; ++i) {
        addr[i] = rng() % (N - size);
    }
    std::sort(addr, addr + size);
    // remove duplicates by bumping; the maximum stays below N - 1
    for (int i = 1; i < size; ++i) {
        if (addr[i] <= addr[i - 1]) {
            addr[i] = addr[i - 1] + 1;
        }
    }
    int off = rng() % N;
    for (int i = 0; i < size; ++i) {
        addr[i] = (addr[i] + off) % N;
    }
}
104
+
105
+ // Insert a new point into the candidate pool in ascending order
106
+ int insert_into_pool(Neighbor* addr, int size, Neighbor nn) {
107
+ // find the location to insert
108
+ int left = 0, right = size - 1;
109
+ if (addr[left].distance > nn.distance) {
110
+ memmove((char*)&addr[left + 1], &addr[left], size * sizeof(Neighbor));
111
+ addr[left] = nn;
112
+ return left;
113
+ }
114
+ if (addr[right].distance < nn.distance) {
115
+ addr[size] = nn;
116
+ return size;
117
+ }
118
+ while (left < right - 1) {
119
+ int mid = (left + right) / 2;
120
+ if (addr[mid].distance > nn.distance)
121
+ right = mid;
122
+ else
123
+ left = mid;
124
+ }
125
+ // check equal ID
126
+
127
+ while (left > 0) {
128
+ if (addr[left].distance < nn.distance)
129
+ break;
130
+ if (addr[left].id == nn.id)
131
+ return size + 1;
132
+ left--;
133
+ }
134
+ if (addr[left].id == nn.id || addr[right].id == nn.id)
135
+ return size + 1;
136
+ memmove((char*)&addr[right + 1],
137
+ &addr[right],
138
+ (size - right) * sizeof(Neighbor));
139
+ addr[right] = nn;
140
+ return right;
141
+ }
142
+
143
+ } // namespace nndescent
144
+
145
+ using namespace nndescent;
146
+
147
+ constexpr int NUM_EVAL_POINTS = 100;
148
+
149
+ NNDescent::NNDescent(const int d, const int K) : K(K), random_seed(2021), d(d) {
150
+ ntotal = 0;
151
+ has_built = false;
152
+ S = 10;
153
+ R = 100;
154
+ L = K + 50;
155
+ iter = 10;
156
+ search_L = 0;
157
+ }
158
+
159
+ NNDescent::~NNDescent() {}
160
+
161
+ void NNDescent::join(DistanceComputer& qdis) {
162
+ #pragma omp parallel for default(shared) schedule(dynamic, 100)
163
+ for (int n = 0; n < ntotal; n++) {
164
+ graph[n].join([&](int i, int j) {
165
+ if (i != j) {
166
+ float dist = qdis.symmetric_dis(i, j);
167
+ graph[i].insert(j, dist);
168
+ graph[j].insert(i, dist);
169
+ }
170
+ });
171
+ }
172
+ }
173
+
174
+ /// Sample neighbors for each node to peform local join later
175
+ /// Store them in nn_new and nn_old
176
+ void NNDescent::update() {
177
+ // Step 1.
178
+ // Clear all nn_new and nn_old
179
+ #pragma omp parallel for
180
+ for (int i = 0; i < ntotal; i++) {
181
+ std::vector<int>().swap(graph[i].nn_new);
182
+ std::vector<int>().swap(graph[i].nn_old);
183
+ }
184
+
185
+ // Step 2.
186
+ // Compute the number of neighbors which is new i.e. flag is true
187
+ // in the candidate pool. This must not exceed the sample number S.
188
+ // That means We only select S new neighbors.
189
+ #pragma omp parallel for
190
+ for (int n = 0; n < ntotal; ++n) {
191
+ auto& nn = graph[n];
192
+ std::sort(nn.pool.begin(), nn.pool.end());
193
+
194
+ if (nn.pool.size() > L)
195
+ nn.pool.resize(L);
196
+ nn.pool.reserve(L); // keep the pool size be L
197
+
198
+ int maxl = std::min(nn.M + S, (int)nn.pool.size());
199
+ int c = 0;
200
+ int l = 0;
201
+
202
+ while ((l < maxl) && (c < S)) {
203
+ if (nn.pool[l].flag)
204
+ ++c;
205
+ ++l;
206
+ }
207
+ nn.M = l;
208
+ }
209
+
210
+ // Step 3.
211
+ // Find reverse links for each node
212
+ // Randomly choose R reverse links.
213
+ #pragma omp parallel
214
+ {
215
+ std::mt19937 rng(random_seed * 5081 + omp_get_thread_num());
216
+ #pragma omp for
217
+ for (int n = 0; n < ntotal; ++n) {
218
+ auto& node = graph[n];
219
+ auto& nn_new = node.nn_new;
220
+ auto& nn_old = node.nn_old;
221
+
222
+ for (int l = 0; l < node.M; ++l) {
223
+ auto& nn = node.pool[l];
224
+ auto& other = graph[nn.id]; // the other side of the edge
225
+
226
+ if (nn.flag) { // the node is inserted newly
227
+ // push the neighbor into nn_new
228
+ nn_new.push_back(nn.id);
229
+ // push itself into other.rnn_new if it is not in
230
+ // the candidate pool of the other side
231
+ if (nn.distance > other.pool.back().distance) {
232
+ LockGuard guard(other.lock);
233
+ if (other.rnn_new.size() < R) {
234
+ other.rnn_new.push_back(n);
235
+ } else {
236
+ int pos = rng() % R;
237
+ other.rnn_new[pos] = n;
238
+ }
239
+ }
240
+ nn.flag = false;
241
+
242
+ } else { // the node is old
243
+ // push the neighbor into nn_old
244
+ nn_old.push_back(nn.id);
245
+ // push itself into other.rnn_old if it is not in
246
+ // the candidate pool of the other side
247
+ if (nn.distance > other.pool.back().distance) {
248
+ LockGuard guard(other.lock);
249
+ if (other.rnn_old.size() < R) {
250
+ other.rnn_old.push_back(n);
251
+ } else {
252
+ int pos = rng() % R;
253
+ other.rnn_old[pos] = n;
254
+ }
255
+ }
256
+ }
257
+ }
258
+ // make heap to join later (in join() function)
259
+ std::make_heap(node.pool.begin(), node.pool.end());
260
+ }
261
+ }
262
+
263
+ // Step 4.
264
+ // Combine the forward and the reverse links
265
+ // R = 0 means no reverse links are used.
266
+ #pragma omp parallel for
267
+ for (int i = 0; i < ntotal; ++i) {
268
+ auto& nn_new = graph[i].nn_new;
269
+ auto& nn_old = graph[i].nn_old;
270
+ auto& rnn_new = graph[i].rnn_new;
271
+ auto& rnn_old = graph[i].rnn_old;
272
+
273
+ nn_new.insert(nn_new.end(), rnn_new.begin(), rnn_new.end());
274
+ nn_old.insert(nn_old.end(), rnn_old.begin(), rnn_old.end());
275
+ if (nn_old.size() > R * 2) {
276
+ nn_old.resize(R * 2);
277
+ nn_old.reserve(R * 2);
278
+ }
279
+
280
+ std::vector<int>().swap(graph[i].rnn_new);
281
+ std::vector<int>().swap(graph[i].rnn_old);
282
+ }
283
+ }
284
+
285
+ void NNDescent::nndescent(DistanceComputer& qdis, bool verbose) {
286
+ int num_eval_points = std::min(NUM_EVAL_POINTS, ntotal);
287
+ std::vector<int> eval_points(num_eval_points);
288
+ std::vector<std::vector<int>> acc_eval_set(num_eval_points);
289
+ std::mt19937 rng(random_seed * 6577 + omp_get_thread_num());
290
+ gen_random(rng, eval_points.data(), eval_points.size(), ntotal);
291
+ generate_eval_set(qdis, eval_points, acc_eval_set, ntotal);
292
+ for (int it = 0; it < iter; it++) {
293
+ join(qdis);
294
+ update();
295
+
296
+ if (verbose) {
297
+ float recall = eval_recall(eval_points, acc_eval_set);
298
+ printf("Iter: %d, recall@%d: %lf\n", it, K, recall);
299
+ }
300
+ }
301
+ }
302
+
303
+ /// Sample a small number of points to evaluate the quality of KNNG built
304
+ void NNDescent::generate_eval_set(
305
+ DistanceComputer& qdis,
306
+ std::vector<int>& c,
307
+ std::vector<std::vector<int>>& v,
308
+ int N) {
309
+ #pragma omp parallel for
310
+ for (int i = 0; i < c.size(); i++) {
311
+ std::vector<Neighbor> tmp;
312
+ for (int j = 0; j < N; j++) {
313
+ if (i == j)
314
+ continue; // skip itself
315
+ float dist = qdis.symmetric_dis(c[i], j);
316
+ tmp.push_back(Neighbor(j, dist, true));
317
+ }
318
+
319
+ std::partial_sort(tmp.begin(), tmp.begin() + K, tmp.end());
320
+ for (int j = 0; j < K; j++) {
321
+ v[i].push_back(tmp[j].id);
322
+ }
323
+ }
324
+ }
325
+
326
+ /// Evaluate the quality of KNNG built
327
+ float NNDescent::eval_recall(
328
+ std::vector<int>& eval_points,
329
+ std::vector<std::vector<int>>& acc_eval_set) {
330
+ float mean_acc = 0.0f;
331
+ for (size_t i = 0; i < eval_points.size(); i++) {
332
+ float acc = 0;
333
+ std::vector<Neighbor>& g = graph[eval_points[i]].pool;
334
+ std::vector<int>& v = acc_eval_set[i];
335
+ for (size_t j = 0; j < g.size(); j++) {
336
+ for (size_t k = 0; k < v.size(); k++) {
337
+ if (g[j].id == v[k]) {
338
+ acc++;
339
+ break;
340
+ }
341
+ }
342
+ }
343
+ mean_acc += acc / v.size();
344
+ }
345
+ return mean_acc / eval_points.size();
346
+ }
347
+
348
+ /// Initialize the KNN graph randomly
349
+ void NNDescent::init_graph(DistanceComputer& qdis) {
350
+ graph.reserve(ntotal);
351
+ {
352
+ std::mt19937 rng(random_seed * 6007);
353
+ for (int i = 0; i < ntotal; i++) {
354
+ graph.push_back(Nhood(L, S, rng, (int)ntotal));
355
+ }
356
+ }
357
+ #pragma omp parallel
358
+ {
359
+ std::mt19937 rng(random_seed * 7741 + omp_get_thread_num());
360
+ #pragma omp for
361
+ for (int i = 0; i < ntotal; i++) {
362
+ std::vector<int> tmp(S);
363
+
364
+ gen_random(rng, tmp.data(), S, ntotal);
365
+
366
+ for (int j = 0; j < S; j++) {
367
+ int id = tmp[j];
368
+ if (id == i)
369
+ continue;
370
+ float dist = qdis.symmetric_dis(i, id);
371
+
372
+ graph[i].pool.push_back(Neighbor(id, dist, true));
373
+ }
374
+ std::make_heap(graph[i].pool.begin(), graph[i].pool.end());
375
+ graph[i].pool.reserve(L);
376
+ }
377
+ }
378
+ }
379
+
380
+ void NNDescent::build(DistanceComputer& qdis, const int n, bool verbose) {
381
+ FAISS_THROW_IF_NOT_MSG(L >= K, "L should be >= K in NNDescent.build");
382
+
383
+ if (verbose) {
384
+ printf("Parameters: K=%d, S=%d, R=%d, L=%d, iter=%d\n",
385
+ K,
386
+ S,
387
+ R,
388
+ L,
389
+ iter);
390
+ }
391
+
392
+ ntotal = n;
393
+ init_graph(qdis);
394
+ nndescent(qdis, verbose);
395
+
396
+ final_graph.resize(ntotal * K);
397
+
398
+ // Store the neighbor link structure into final_graph
399
+ // Clear the old graph
400
+ for (int i = 0; i < ntotal; i++) {
401
+ std::sort(graph[i].pool.begin(), graph[i].pool.end());
402
+ for (int j = 0; j < K; j++) {
403
+ FAISS_ASSERT(graph[i].pool[j].id < ntotal);
404
+ final_graph[i * K + j] = graph[i].pool[j].id;
405
+ }
406
+ }
407
+ std::vector<Nhood>().swap(graph);
408
+ has_built = true;
409
+
410
+ if (verbose) {
411
+ printf("Addes %d points into the index\n", ntotal);
412
+ }
413
+ }
414
+
415
+ void NNDescent::search(
416
+ DistanceComputer& qdis,
417
+ const int topk,
418
+ idx_t* indices,
419
+ float* dists,
420
+ VisitedTable& vt) const {
421
+ FAISS_THROW_IF_NOT_MSG(has_built, "The index is not build yet.");
422
+ int L = std::max(search_L, topk);
423
+
424
+ // candidate pool, the K best items is the result.
425
+ std::vector<Neighbor> retset(L + 1);
426
+
427
+ // Randomly choose L points to intialize the candidate pool
428
+ std::vector<int> init_ids(L);
429
+ std::mt19937 rng(random_seed);
430
+
431
+ gen_random(rng, init_ids.data(), L, ntotal);
432
+ for (int i = 0; i < L; i++) {
433
+ int id = init_ids[i];
434
+ float dist = qdis(id);
435
+ retset[i] = Neighbor(id, dist, true);
436
+ }
437
+
438
+ // Maintain the candidate pool in ascending order
439
+ std::sort(retset.begin(), retset.begin() + L);
440
+
441
+ int k = 0;
442
+
443
+ // Stop until the smallest position updated is >= L
444
+ while (k < L) {
445
+ int nk = L;
446
+
447
+ if (retset[k].flag) {
448
+ retset[k].flag = false;
449
+ int n = retset[k].id;
450
+
451
+ for (int m = 0; m < K; ++m) {
452
+ int id = final_graph[n * K + m];
453
+ if (vt.get(id))
454
+ continue;
455
+
456
+ vt.set(id);
457
+ float dist = qdis(id);
458
+ if (dist >= retset[L - 1].distance)
459
+ continue;
460
+
461
+ Neighbor nn(id, dist, true);
462
+ int r = insert_into_pool(retset.data(), L, nn);
463
+
464
+ if (r < nk)
465
+ nk = r;
466
+ }
467
+ }
468
+ if (nk <= k)
469
+ k = nk;
470
+ else
471
+ ++k;
472
+ }
473
+ for (size_t i = 0; i < topk; i++) {
474
+ indices[i] = retset[i].id;
475
+ dists[i] = retset[i].distance;
476
+ }
477
+
478
+ vt.advance();
479
+ }
480
+
481
+ void NNDescent::reset() {
482
+ has_built = false;
483
+ ntotal = 0;
484
+ final_graph.resize(0);
485
+ }
486
+
487
+ } // namespace faiss