faiss 0.2.0 → 0.2.1

Sign up to get free protection for your applications and to get access to all the features.
Files changed (202)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +4 -0
  3. data/lib/faiss/version.rb +1 -1
  4. data/vendor/faiss/faiss/AutoTune.cpp +292 -291
  5. data/vendor/faiss/faiss/AutoTune.h +55 -56
  6. data/vendor/faiss/faiss/Clustering.cpp +334 -195
  7. data/vendor/faiss/faiss/Clustering.h +88 -35
  8. data/vendor/faiss/faiss/IVFlib.cpp +171 -195
  9. data/vendor/faiss/faiss/IVFlib.h +48 -51
  10. data/vendor/faiss/faiss/Index.cpp +85 -103
  11. data/vendor/faiss/faiss/Index.h +54 -48
  12. data/vendor/faiss/faiss/Index2Layer.cpp +139 -164
  13. data/vendor/faiss/faiss/Index2Layer.h +22 -22
  14. data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
  15. data/vendor/faiss/faiss/IndexBinary.h +140 -132
  16. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
  17. data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
  18. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
  19. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
  20. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
  21. data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
  22. data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
  23. data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
  24. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
  25. data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
  26. data/vendor/faiss/faiss/IndexFlat.cpp +116 -147
  27. data/vendor/faiss/faiss/IndexFlat.h +35 -46
  28. data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
  29. data/vendor/faiss/faiss/IndexHNSW.h +57 -41
  30. data/vendor/faiss/faiss/IndexIVF.cpp +474 -454
  31. data/vendor/faiss/faiss/IndexIVF.h +146 -113
  32. data/vendor/faiss/faiss/IndexIVFFlat.cpp +248 -250
  33. data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
  34. data/vendor/faiss/faiss/IndexIVFPQ.cpp +457 -516
  35. data/vendor/faiss/faiss/IndexIVFPQ.h +74 -66
  36. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
  37. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
  38. data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
  39. data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
  40. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +125 -133
  41. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +19 -21
  42. data/vendor/faiss/faiss/IndexLSH.cpp +75 -96
  43. data/vendor/faiss/faiss/IndexLSH.h +21 -26
  44. data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
  45. data/vendor/faiss/faiss/IndexLattice.h +11 -16
  46. data/vendor/faiss/faiss/IndexNNDescent.cpp +231 -0
  47. data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
  48. data/vendor/faiss/faiss/IndexNSG.cpp +303 -0
  49. data/vendor/faiss/faiss/IndexNSG.h +85 -0
  50. data/vendor/faiss/faiss/IndexPQ.cpp +405 -464
  51. data/vendor/faiss/faiss/IndexPQ.h +64 -67
  52. data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
  53. data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
  54. data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
  55. data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
  56. data/vendor/faiss/faiss/IndexRefine.cpp +115 -131
  57. data/vendor/faiss/faiss/IndexRefine.h +22 -23
  58. data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
  59. data/vendor/faiss/faiss/IndexReplicas.h +62 -56
  60. data/vendor/faiss/faiss/IndexResidual.cpp +291 -0
  61. data/vendor/faiss/faiss/IndexResidual.h +152 -0
  62. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +120 -155
  63. data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -45
  64. data/vendor/faiss/faiss/IndexShards.cpp +256 -240
  65. data/vendor/faiss/faiss/IndexShards.h +85 -73
  66. data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
  67. data/vendor/faiss/faiss/MatrixStats.h +7 -10
  68. data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
  69. data/vendor/faiss/faiss/MetaIndexes.h +40 -34
  70. data/vendor/faiss/faiss/MetricType.h +7 -7
  71. data/vendor/faiss/faiss/VectorTransform.cpp +652 -474
  72. data/vendor/faiss/faiss/VectorTransform.h +61 -89
  73. data/vendor/faiss/faiss/clone_index.cpp +77 -73
  74. data/vendor/faiss/faiss/clone_index.h +4 -9
  75. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
  76. data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
  77. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +197 -170
  78. data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
  79. data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
  80. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
  81. data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
  82. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
  83. data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
  84. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
  85. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
  86. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
  87. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
  88. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
  89. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
  90. data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
  91. data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
  92. data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
  93. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
  94. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
  95. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
  96. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
  97. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
  98. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
  99. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
  100. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
  101. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
  102. data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
  103. data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
  104. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
  105. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
  106. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
  107. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
  108. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
  109. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
  110. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
  111. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
  112. data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
  113. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
  114. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
  115. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
  116. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
  117. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
  118. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
  119. data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
  120. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +270 -0
  121. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +115 -0
  122. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
  123. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
  124. data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
  125. data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
  126. data/vendor/faiss/faiss/impl/FaissException.h +41 -29
  127. data/vendor/faiss/faiss/impl/HNSW.cpp +595 -611
  128. data/vendor/faiss/faiss/impl/HNSW.h +179 -200
  129. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +672 -0
  130. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +172 -0
  131. data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
  132. data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
  133. data/vendor/faiss/faiss/impl/NSG.cpp +682 -0
  134. data/vendor/faiss/faiss/impl/NSG.h +199 -0
  135. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
  136. data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
  137. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
  138. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
  139. data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
  140. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +448 -0
  141. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +130 -0
  142. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
  143. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +648 -701
  144. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
  145. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
  146. data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
  147. data/vendor/faiss/faiss/impl/index_read.cpp +547 -479
  148. data/vendor/faiss/faiss/impl/index_write.cpp +497 -407
  149. data/vendor/faiss/faiss/impl/io.cpp +75 -94
  150. data/vendor/faiss/faiss/impl/io.h +31 -41
  151. data/vendor/faiss/faiss/impl/io_macros.h +40 -29
  152. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
  153. data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
  154. data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
  155. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
  156. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
  157. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
  158. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
  159. data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
  160. data/vendor/faiss/faiss/index_factory.cpp +269 -218
  161. data/vendor/faiss/faiss/index_factory.h +6 -7
  162. data/vendor/faiss/faiss/index_io.h +23 -26
  163. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
  164. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
  165. data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
  166. data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
  167. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
  168. data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
  169. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
  170. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
  171. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
  172. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
  173. data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
  174. data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
  175. data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
  176. data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
  177. data/vendor/faiss/faiss/utils/Heap.h +186 -209
  178. data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
  179. data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
  180. data/vendor/faiss/faiss/utils/distances.cpp +301 -310
  181. data/vendor/faiss/faiss/utils/distances.h +133 -118
  182. data/vendor/faiss/faiss/utils/distances_simd.cpp +456 -516
  183. data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
  184. data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
  185. data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
  186. data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
  187. data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
  188. data/vendor/faiss/faiss/utils/hamming.h +62 -85
  189. data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
  190. data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
  191. data/vendor/faiss/faiss/utils/partitioning.h +26 -21
  192. data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
  193. data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
  194. data/vendor/faiss/faiss/utils/random.cpp +39 -63
  195. data/vendor/faiss/faiss/utils/random.h +13 -16
  196. data/vendor/faiss/faiss/utils/simdlib.h +4 -2
  197. data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
  198. data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
  199. data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
  200. data/vendor/faiss/faiss/utils/utils.cpp +304 -287
  201. data/vendor/faiss/faiss/utils/utils.h +53 -48
  202. metadata +20 -2
@@ -0,0 +1,172 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ #pragma once
9
+
10
+ #include <stdint.h>
11
+
12
+ #include <random>
13
+ #include <string>
14
+ #include <unordered_map>
15
+ #include <vector>
16
+
17
+ #include <faiss/impl/AdditiveQuantizer.h>
18
+ #include <faiss/utils/utils.h>
19
+
20
+ namespace faiss {
21
+
22
/** Implementation of LSQ/LSQ++ described in the following two papers:
 *
 * Revisiting additive quantization
 * Julieta Martinez, et al. ECCV 2016
 *
 * LSQ++: Lower running time and higher recall in multi-codebook quantization
 * Julieta Martinez, et al. ECCV 2018
 *
 * This implementation is mostly translated from the Julia implementations
 * by Julieta Martinez:
 * (https://github.com/una-dinosauria/local-search-quantization,
 *  https://github.com/una-dinosauria/Rayuela.jl)
 *
 * The trained codebooks are stored in `codebooks`; the equivalent member is
 * called `centroids` in PQ and RQ.
 */

struct LocalSearchQuantizer : AdditiveQuantizer {
    size_t K; ///< number of codes per codebook

    size_t train_iters; ///< number of iterations in training

    size_t encode_ils_iters; ///< iterations of local search in encoding
    size_t train_ils_iters;  ///< iterations of local search in training
    size_t icm_iters;        ///< number of iterations in icm

    float p;     ///< temperature factor
    float lambd; ///< regularization factor

    size_t chunk_size; ///< nb of vectors to encode at a time

    int random_seed; ///< seed for random generator
    size_t nperts;   ///< number of perturbations in each code

    LocalSearchQuantizer(
            size_t d,      /* dimensionality of the input vectors */
            size_t M,      /* number of subquantizers */
            size_t nbits); /* number of bits per subvector index */

    /** Train the local search quantizer.
     *
     * @param n number of training vectors
     * @param x training vectors (presumably size n * d — confirm)
     */
    void train(size_t n, const float* x) override;

    /** Encode a set of vectors
     *
     * @param x      vectors to encode, size n * d
     * @param codes  output codes, size n * code_size
     */
    void compute_codes(const float* x, uint8_t* codes, size_t n) const override;

    /** Update codebooks given encodings
     *
     * @param x      training vectors, size n * d
     * @param codes  encoded training vectors, size n * M
     */
    void update_codebooks(const float* x, const int32_t* codes, size_t n);

    /** Encode vectors given codebooks using iterative conditional mode (icm).
     *
     * @param x          vectors to encode, size n * d
     * @param codes      output codes, size n * M
     * @param ils_iters  number of iterations of iterative local search
     * @param gen        random generator (presumably drives the
     *                   perturbations of the local search — confirm)
     */
    void icm_encode(
            const float* x,
            int32_t* codes,
            size_t n,
            size_t ils_iters,
            std::mt19937& gen) const;

    /** Helper of the icm encoding working on a subset of vectors.
     *
     * NOTE(review): from the signature only — `index` presumably identifies
     * the chunk being encoded and `binaries` holds precomputed binary terms
     * (see compute_binary_terms); confirm against LocalSearchQuantizer.cpp.
     */
    void icm_encode_partial(
            size_t index,
            const float* x,
            int32_t* codes,
            size_t n,
            const float* binaries,
            size_t ils_iters,
            std::mt19937& gen) const;

    /** One icm update of `codes` for n vectors from precomputed terms.
     *
     * @param unaries   unary terms (presumably size n * M * K, as produced
     *                  by compute_unary_terms — confirm)
     * @param binaries  binary terms (presumably size M * M * K * K)
     * @param codes     codes updated in place (presumably size n * M)
     */
    void icm_encode_step(
            const float* unaries,
            const float* binaries,
            int32_t* codes,
            size_t n) const;

    /** Add some perturbation to codebooks
     *
     * @param T       temperature of simulated annealing
     * @param stddev  standard deviations of each dimension in training data
     * @param gen     random generator used for the perturbation
     */
    void perturb_codebooks(
            float T,
            const std::vector<float>& stddev,
            std::mt19937& gen);

    /** Add some perturbation to codes
     *
     * @param codes  codes to be perturbed, size n * M
     * @param gen    random generator used for the perturbation
     */
    void perturb_codes(int32_t* codes, size_t n, std::mt19937& gen) const;

    /** Compute binary terms
     *
     * @param binaries  binary terms, size M * M * K * K
     */
    void compute_binary_terms(float* binaries) const;

    /** Compute unary terms
     *
     * @param x        vectors to encode, size n * d
     * @param unaries  unary terms, size n * M * K
     */
    void compute_unary_terms(const float* x, float* unaries, size_t n) const;

    /** Helper function to compute reconstruction error
     *
     * @param x      vectors to encode, size n * d
     * @param codes  encoded codes, size n * M
     * @param objs   if it is not null, store the reconstruction error of
     *               each vector into it, size n
     * @return       the reconstruction error (presumably aggregated over the
     *               n vectors — confirm)
     */
    float evaluate(
            const int32_t* codes,
            const float* x,
            size_t n,
            float* objs = nullptr) const;
};
148
+
149
/** A helper struct to measure the time spent in named phases during training.
 * It is NOT thread-safe.
 */
struct LSQTimer {
    /// time recorded per name (presumably accumulated by end() — confirm)
    std::unordered_map<std::string, double> duration;
    /// start timestamp per name (presumably set by start() — confirm)
    std::unordered_map<std::string, double> t0;
    /// whether a measurement is in progress per name (presumably)
    std::unordered_map<std::string, bool> started;

    LSQTimer() {
        reset();
    }

    /// Return the time recorded for `name`.
    double get(const std::string& name);

    /// Begin timing the phase `name`.
    void start(const std::string& name);

    /// Stop timing the phase `name`.
    void end(const std::string& name);

    /// Clear all recorded measurements.
    void reset();
};

/// Global timer for the LSQ training phases (not thread-safe, see LSQTimer).
FAISS_API extern LSQTimer lsq_timer;
171
+
172
+ } // namespace faiss
@@ -0,0 +1,487 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ // -*- c++ -*-
9
+
10
+ #include <faiss/impl/NNDescent.h>
11
+
12
+ #include <mutex>
13
+ #include <string>
14
+
15
+ #include <faiss/impl/AuxIndexStructures.h>
16
+
17
+ namespace faiss {
18
+
19
+ using LockGuard = std::lock_guard<std::mutex>;
20
+
21
+ namespace nndescent {
22
+
23
/// Fill addr[0..size) with pseudo-random node ids of a graph with N nodes.
/// Declared here because Nhood's constructor uses it; defined further below.
void gen_random(std::mt19937& rng, int* addr, const int size, const int N);
24
+
25
+ Nhood::Nhood(int l, int s, std::mt19937& rng, int N) {
26
+ M = s;
27
+ nn_new.resize(s * 2);
28
+ gen_random(rng, nn_new.data(), (int)nn_new.size(), N);
29
+ }
30
+
31
+ /// Copy operator
32
+ Nhood& Nhood::operator=(const Nhood& other) {
33
+ M = other.M;
34
+ std::copy(
35
+ other.nn_new.begin(),
36
+ other.nn_new.end(),
37
+ std::back_inserter(nn_new));
38
+ nn_new.reserve(other.nn_new.capacity());
39
+ pool.reserve(other.pool.capacity());
40
+ return *this;
41
+ }
42
+
43
+ /// Copy constructor
44
+ Nhood::Nhood(const Nhood& other) {
45
+ M = other.M;
46
+ std::copy(
47
+ other.nn_new.begin(),
48
+ other.nn_new.end(),
49
+ std::back_inserter(nn_new));
50
+ nn_new.reserve(other.nn_new.capacity());
51
+ pool.reserve(other.pool.capacity());
52
+ }
53
+
54
+ /// Insert a point into the candidate pool
55
+ void Nhood::insert(int id, float dist) {
56
+ LockGuard guard(lock);
57
+ if (dist > pool.front().distance)
58
+ return;
59
+ for (int i = 0; i < pool.size(); i++) {
60
+ if (id == pool[i].id)
61
+ return;
62
+ }
63
+ if (pool.size() < pool.capacity()) {
64
+ pool.push_back(Neighbor(id, dist, true));
65
+ std::push_heap(pool.begin(), pool.end());
66
+ } else {
67
+ std::pop_heap(pool.begin(), pool.end());
68
+ pool[pool.size() - 1] = Neighbor(id, dist, true);
69
+ std::push_heap(pool.begin(), pool.end());
70
+ }
71
+ }
72
+
73
+ /// In local join, two objects are compared only if at least
74
+ /// one of them is new.
75
+ template <typename C>
76
+ void Nhood::join(C callback) const {
77
+ for (int const i : nn_new) {
78
+ for (int const j : nn_new) {
79
+ if (i < j) {
80
+ callback(i, j);
81
+ }
82
+ }
83
+ for (int j : nn_old) {
84
+ callback(i, j);
85
+ }
86
+ }
87
+ }
88
+
89
/// Fill addr[0..size) with pseudo-random node ids in [0, N).
///
/// When size < N the ids are pairwise distinct. The function makes `size`
/// draws from [0, N - size), sorts them, bumps them to be strictly
/// increasing, and rotates them by a random offset modulo N.
///
/// When size >= N the draw range is empty, so the values start as
/// 0..size-1 and are rotated modulo N. That gives every id exactly once
/// when size == N, and ids with repeats, still in [0, N), when size > N.
/// Before this fix, size == N divided by zero (e.g. NNDescent on at most 100
/// points), and size > N produced ids outside [0, N).
///
/// When N <= 0 no valid id exists, and addr is left untouched.
void gen_random(std::mt19937& rng, int* addr, const int size, const int N) {
    if (N <= 0) {
        return;
    }
    const int range = N - size;
    for (int i = 0; i < size; ++i) {
        addr[i] = range > 0 ? static_cast<int>(rng() % range) : 0;
    }
    std::sort(addr, addr + size);
    // Make the values strictly increasing, hence distinct.
    for (int i = 1; i < size; ++i) {
        if (addr[i] <= addr[i - 1]) {
            addr[i] = addr[i - 1] + 1;
        }
    }
    const int off = rng() % N;
    for (int i = 0; i < size; ++i) {
        addr[i] = (addr[i] + off) % N;
    }
}
104
+
105
/// Insert `nn` into addr[0..size), an array sorted by ascending distance,
/// keeping it sorted.
///
/// The caller must provide room for one extra element: addr[size] may be
/// written. When `nn` lands before the end, the element previously at
/// addr[size - 1] is shifted into addr[size].
///
/// Returns the position where `nn` was stored. Returns size + 1 when an
/// entry with the same id is found among the entries whose distance equals
/// nn.distance; nothing is inserted in that case.
///
/// Assumes size >= 1: with size == 0 the code may read addr[-1].
int insert_into_pool(Neighbor* addr, int size, Neighbor nn) {
    // find the location to insert
    int left = 0, right = size - 1;
    // nn goes first: shift the whole array right by one slot.
    if (addr[left].distance > nn.distance) {
        memmove((char*)&addr[left + 1], &addr[left], size * sizeof(Neighbor));
        addr[left] = nn;
        return left;
    }
    // nn goes last: store it in the spare slot addr[size].
    if (addr[right].distance < nn.distance) {
        addr[size] = nn;
        return size;
    }
    // Binary search keeping
    // addr[left].distance <= nn.distance <= addr[right].distance.
    while (left < right - 1) {
        int mid = (left + right) / 2;
        if (addr[mid].distance > nn.distance)
            right = mid;
        else
            left = mid;
    }
    // check equal ID
    // Walk down over entries with the same distance and reject duplicates.

    while (left > 0) {
        if (addr[left].distance < nn.distance)
            break;
        if (addr[left].id == nn.id)
            return size + 1;
        left--;
    }
    if (addr[left].id == nn.id || addr[right].id == nn.id)
        return size + 1;
    // Shift addr[right..size-1] by one slot and place nn at `right`.
    memmove((char*)&addr[right + 1],
            &addr[right],
            (size - right) * sizeof(Neighbor));
    addr[right] = nn;
    return right;
}
142
+
143
+ } // namespace nndescent
144
+
145
+ using namespace nndescent;
146
+
147
/// Maximum number of nodes sampled by nndescent() to estimate recall@K
/// (the estimate is printed only in verbose mode).
constexpr int NUM_EVAL_POINTS = 100;
148
+
149
+ NNDescent::NNDescent(const int d, const int K) : K(K), random_seed(2021), d(d) {
150
+ ntotal = 0;
151
+ has_built = false;
152
+ S = 10;
153
+ R = 100;
154
+ L = K + 50;
155
+ iter = 10;
156
+ search_L = 0;
157
+ }
158
+
159
+ NNDescent::~NNDescent() {}
160
+
161
+ void NNDescent::join(DistanceComputer& qdis) {
162
+ #pragma omp parallel for default(shared) schedule(dynamic, 100)
163
+ for (int n = 0; n < ntotal; n++) {
164
+ graph[n].join([&](int i, int j) {
165
+ if (i != j) {
166
+ float dist = qdis.symmetric_dis(i, j);
167
+ graph[i].insert(j, dist);
168
+ graph[j].insert(i, dist);
169
+ }
170
+ });
171
+ }
172
+ }
173
+
174
+ /// Sample neighbors for each node to peform local join later
175
+ /// Store them in nn_new and nn_old
176
+ void NNDescent::update() {
177
+ // Step 1.
178
+ // Clear all nn_new and nn_old
179
+ #pragma omp parallel for
180
+ for (int i = 0; i < ntotal; i++) {
181
+ std::vector<int>().swap(graph[i].nn_new);
182
+ std::vector<int>().swap(graph[i].nn_old);
183
+ }
184
+
185
+ // Step 2.
186
+ // Compute the number of neighbors which is new i.e. flag is true
187
+ // in the candidate pool. This must not exceed the sample number S.
188
+ // That means We only select S new neighbors.
189
+ #pragma omp parallel for
190
+ for (int n = 0; n < ntotal; ++n) {
191
+ auto& nn = graph[n];
192
+ std::sort(nn.pool.begin(), nn.pool.end());
193
+
194
+ if (nn.pool.size() > L)
195
+ nn.pool.resize(L);
196
+ nn.pool.reserve(L); // keep the pool size be L
197
+
198
+ int maxl = std::min(nn.M + S, (int)nn.pool.size());
199
+ int c = 0;
200
+ int l = 0;
201
+
202
+ while ((l < maxl) && (c < S)) {
203
+ if (nn.pool[l].flag)
204
+ ++c;
205
+ ++l;
206
+ }
207
+ nn.M = l;
208
+ }
209
+
210
+ // Step 3.
211
+ // Find reverse links for each node
212
+ // Randomly choose R reverse links.
213
+ #pragma omp parallel
214
+ {
215
+ std::mt19937 rng(random_seed * 5081 + omp_get_thread_num());
216
+ #pragma omp for
217
+ for (int n = 0; n < ntotal; ++n) {
218
+ auto& node = graph[n];
219
+ auto& nn_new = node.nn_new;
220
+ auto& nn_old = node.nn_old;
221
+
222
+ for (int l = 0; l < node.M; ++l) {
223
+ auto& nn = node.pool[l];
224
+ auto& other = graph[nn.id]; // the other side of the edge
225
+
226
+ if (nn.flag) { // the node is inserted newly
227
+ // push the neighbor into nn_new
228
+ nn_new.push_back(nn.id);
229
+ // push itself into other.rnn_new if it is not in
230
+ // the candidate pool of the other side
231
+ if (nn.distance > other.pool.back().distance) {
232
+ LockGuard guard(other.lock);
233
+ if (other.rnn_new.size() < R) {
234
+ other.rnn_new.push_back(n);
235
+ } else {
236
+ int pos = rng() % R;
237
+ other.rnn_new[pos] = n;
238
+ }
239
+ }
240
+ nn.flag = false;
241
+
242
+ } else { // the node is old
243
+ // push the neighbor into nn_old
244
+ nn_old.push_back(nn.id);
245
+ // push itself into other.rnn_old if it is not in
246
+ // the candidate pool of the other side
247
+ if (nn.distance > other.pool.back().distance) {
248
+ LockGuard guard(other.lock);
249
+ if (other.rnn_old.size() < R) {
250
+ other.rnn_old.push_back(n);
251
+ } else {
252
+ int pos = rng() % R;
253
+ other.rnn_old[pos] = n;
254
+ }
255
+ }
256
+ }
257
+ }
258
+ // make heap to join later (in join() function)
259
+ std::make_heap(node.pool.begin(), node.pool.end());
260
+ }
261
+ }
262
+
263
+ // Step 4.
264
+ // Combine the forward and the reverse links
265
+ // R = 0 means no reverse links are used.
266
+ #pragma omp parallel for
267
+ for (int i = 0; i < ntotal; ++i) {
268
+ auto& nn_new = graph[i].nn_new;
269
+ auto& nn_old = graph[i].nn_old;
270
+ auto& rnn_new = graph[i].rnn_new;
271
+ auto& rnn_old = graph[i].rnn_old;
272
+
273
+ nn_new.insert(nn_new.end(), rnn_new.begin(), rnn_new.end());
274
+ nn_old.insert(nn_old.end(), rnn_old.begin(), rnn_old.end());
275
+ if (nn_old.size() > R * 2) {
276
+ nn_old.resize(R * 2);
277
+ nn_old.reserve(R * 2);
278
+ }
279
+
280
+ std::vector<int>().swap(graph[i].rnn_new);
281
+ std::vector<int>().swap(graph[i].rnn_old);
282
+ }
283
+ }
284
+
285
+ void NNDescent::nndescent(DistanceComputer& qdis, bool verbose) {
286
+ int num_eval_points = std::min(NUM_EVAL_POINTS, ntotal);
287
+ std::vector<int> eval_points(num_eval_points);
288
+ std::vector<std::vector<int>> acc_eval_set(num_eval_points);
289
+ std::mt19937 rng(random_seed * 6577 + omp_get_thread_num());
290
+ gen_random(rng, eval_points.data(), eval_points.size(), ntotal);
291
+ generate_eval_set(qdis, eval_points, acc_eval_set, ntotal);
292
+ for (int it = 0; it < iter; it++) {
293
+ join(qdis);
294
+ update();
295
+
296
+ if (verbose) {
297
+ float recall = eval_recall(eval_points, acc_eval_set);
298
+ printf("Iter: %d, recall@%d: %lf\n", it, K, recall);
299
+ }
300
+ }
301
+ }
302
+
303
+ /// Sample a small number of points to evaluate the quality of KNNG built
304
+ void NNDescent::generate_eval_set(
305
+ DistanceComputer& qdis,
306
+ std::vector<int>& c,
307
+ std::vector<std::vector<int>>& v,
308
+ int N) {
309
+ #pragma omp parallel for
310
+ for (int i = 0; i < c.size(); i++) {
311
+ std::vector<Neighbor> tmp;
312
+ for (int j = 0; j < N; j++) {
313
+ if (i == j)
314
+ continue; // skip itself
315
+ float dist = qdis.symmetric_dis(c[i], j);
316
+ tmp.push_back(Neighbor(j, dist, true));
317
+ }
318
+
319
+ std::partial_sort(tmp.begin(), tmp.begin() + K, tmp.end());
320
+ for (int j = 0; j < K; j++) {
321
+ v[i].push_back(tmp[j].id);
322
+ }
323
+ }
324
+ }
325
+
326
+ /// Evaluate the quality of KNNG built
327
+ float NNDescent::eval_recall(
328
+ std::vector<int>& eval_points,
329
+ std::vector<std::vector<int>>& acc_eval_set) {
330
+ float mean_acc = 0.0f;
331
+ for (size_t i = 0; i < eval_points.size(); i++) {
332
+ float acc = 0;
333
+ std::vector<Neighbor>& g = graph[eval_points[i]].pool;
334
+ std::vector<int>& v = acc_eval_set[i];
335
+ for (size_t j = 0; j < g.size(); j++) {
336
+ for (size_t k = 0; k < v.size(); k++) {
337
+ if (g[j].id == v[k]) {
338
+ acc++;
339
+ break;
340
+ }
341
+ }
342
+ }
343
+ mean_acc += acc / v.size();
344
+ }
345
+ return mean_acc / eval_points.size();
346
+ }
347
+
348
+ /// Initialize the KNN graph randomly
349
+ void NNDescent::init_graph(DistanceComputer& qdis) {
350
+ graph.reserve(ntotal);
351
+ {
352
+ std::mt19937 rng(random_seed * 6007);
353
+ for (int i = 0; i < ntotal; i++) {
354
+ graph.push_back(Nhood(L, S, rng, (int)ntotal));
355
+ }
356
+ }
357
+ #pragma omp parallel
358
+ {
359
+ std::mt19937 rng(random_seed * 7741 + omp_get_thread_num());
360
+ #pragma omp for
361
+ for (int i = 0; i < ntotal; i++) {
362
+ std::vector<int> tmp(S);
363
+
364
+ gen_random(rng, tmp.data(), S, ntotal);
365
+
366
+ for (int j = 0; j < S; j++) {
367
+ int id = tmp[j];
368
+ if (id == i)
369
+ continue;
370
+ float dist = qdis.symmetric_dis(i, id);
371
+
372
+ graph[i].pool.push_back(Neighbor(id, dist, true));
373
+ }
374
+ std::make_heap(graph[i].pool.begin(), graph[i].pool.end());
375
+ graph[i].pool.reserve(L);
376
+ }
377
+ }
378
+ }
379
+
380
+ void NNDescent::build(DistanceComputer& qdis, const int n, bool verbose) {
381
+ FAISS_THROW_IF_NOT_MSG(L >= K, "L should be >= K in NNDescent.build");
382
+
383
+ if (verbose) {
384
+ printf("Parameters: K=%d, S=%d, R=%d, L=%d, iter=%d\n",
385
+ K,
386
+ S,
387
+ R,
388
+ L,
389
+ iter);
390
+ }
391
+
392
+ ntotal = n;
393
+ init_graph(qdis);
394
+ nndescent(qdis, verbose);
395
+
396
+ final_graph.resize(ntotal * K);
397
+
398
+ // Store the neighbor link structure into final_graph
399
+ // Clear the old graph
400
+ for (int i = 0; i < ntotal; i++) {
401
+ std::sort(graph[i].pool.begin(), graph[i].pool.end());
402
+ for (int j = 0; j < K; j++) {
403
+ FAISS_ASSERT(graph[i].pool[j].id < ntotal);
404
+ final_graph[i * K + j] = graph[i].pool[j].id;
405
+ }
406
+ }
407
+ std::vector<Nhood>().swap(graph);
408
+ has_built = true;
409
+
410
+ if (verbose) {
411
+ printf("Addes %d points into the index\n", ntotal);
412
+ }
413
+ }
414
+
415
+ void NNDescent::search(
416
+ DistanceComputer& qdis,
417
+ const int topk,
418
+ idx_t* indices,
419
+ float* dists,
420
+ VisitedTable& vt) const {
421
+ FAISS_THROW_IF_NOT_MSG(has_built, "The index is not build yet.");
422
+ int L = std::max(search_L, topk);
423
+
424
+ // candidate pool, the K best items is the result.
425
+ std::vector<Neighbor> retset(L + 1);
426
+
427
+ // Randomly choose L points to intialize the candidate pool
428
+ std::vector<int> init_ids(L);
429
+ std::mt19937 rng(random_seed);
430
+
431
+ gen_random(rng, init_ids.data(), L, ntotal);
432
+ for (int i = 0; i < L; i++) {
433
+ int id = init_ids[i];
434
+ float dist = qdis(id);
435
+ retset[i] = Neighbor(id, dist, true);
436
+ }
437
+
438
+ // Maintain the candidate pool in ascending order
439
+ std::sort(retset.begin(), retset.begin() + L);
440
+
441
+ int k = 0;
442
+
443
+ // Stop until the smallest position updated is >= L
444
+ while (k < L) {
445
+ int nk = L;
446
+
447
+ if (retset[k].flag) {
448
+ retset[k].flag = false;
449
+ int n = retset[k].id;
450
+
451
+ for (int m = 0; m < K; ++m) {
452
+ int id = final_graph[n * K + m];
453
+ if (vt.get(id))
454
+ continue;
455
+
456
+ vt.set(id);
457
+ float dist = qdis(id);
458
+ if (dist >= retset[L - 1].distance)
459
+ continue;
460
+
461
+ Neighbor nn(id, dist, true);
462
+ int r = insert_into_pool(retset.data(), L, nn);
463
+
464
+ if (r < nk)
465
+ nk = r;
466
+ }
467
+ }
468
+ if (nk <= k)
469
+ k = nk;
470
+ else
471
+ ++k;
472
+ }
473
+ for (size_t i = 0; i < topk; i++) {
474
+ indices[i] = retset[i].id;
475
+ dists[i] = retset[i].distance;
476
+ }
477
+
478
+ vt.advance();
479
+ }
480
+
481
+ void NNDescent::reset() {
482
+ has_built = false;
483
+ ntotal = 0;
484
+ final_graph.resize(0);
485
+ }
486
+
487
+ } // namespace faiss