faiss 0.2.0 → 0.2.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (202) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +4 -0
  3. data/lib/faiss/version.rb +1 -1
  4. data/vendor/faiss/faiss/AutoTune.cpp +292 -291
  5. data/vendor/faiss/faiss/AutoTune.h +55 -56
  6. data/vendor/faiss/faiss/Clustering.cpp +334 -195
  7. data/vendor/faiss/faiss/Clustering.h +88 -35
  8. data/vendor/faiss/faiss/IVFlib.cpp +171 -195
  9. data/vendor/faiss/faiss/IVFlib.h +48 -51
  10. data/vendor/faiss/faiss/Index.cpp +85 -103
  11. data/vendor/faiss/faiss/Index.h +54 -48
  12. data/vendor/faiss/faiss/Index2Layer.cpp +139 -164
  13. data/vendor/faiss/faiss/Index2Layer.h +22 -22
  14. data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
  15. data/vendor/faiss/faiss/IndexBinary.h +140 -132
  16. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
  17. data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
  18. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
  19. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
  20. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
  21. data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
  22. data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
  23. data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
  24. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
  25. data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
  26. data/vendor/faiss/faiss/IndexFlat.cpp +116 -147
  27. data/vendor/faiss/faiss/IndexFlat.h +35 -46
  28. data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
  29. data/vendor/faiss/faiss/IndexHNSW.h +57 -41
  30. data/vendor/faiss/faiss/IndexIVF.cpp +474 -454
  31. data/vendor/faiss/faiss/IndexIVF.h +146 -113
  32. data/vendor/faiss/faiss/IndexIVFFlat.cpp +248 -250
  33. data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
  34. data/vendor/faiss/faiss/IndexIVFPQ.cpp +457 -516
  35. data/vendor/faiss/faiss/IndexIVFPQ.h +74 -66
  36. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
  37. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
  38. data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
  39. data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
  40. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +125 -133
  41. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +19 -21
  42. data/vendor/faiss/faiss/IndexLSH.cpp +75 -96
  43. data/vendor/faiss/faiss/IndexLSH.h +21 -26
  44. data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
  45. data/vendor/faiss/faiss/IndexLattice.h +11 -16
  46. data/vendor/faiss/faiss/IndexNNDescent.cpp +231 -0
  47. data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
  48. data/vendor/faiss/faiss/IndexNSG.cpp +303 -0
  49. data/vendor/faiss/faiss/IndexNSG.h +85 -0
  50. data/vendor/faiss/faiss/IndexPQ.cpp +405 -464
  51. data/vendor/faiss/faiss/IndexPQ.h +64 -67
  52. data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
  53. data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
  54. data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
  55. data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
  56. data/vendor/faiss/faiss/IndexRefine.cpp +115 -131
  57. data/vendor/faiss/faiss/IndexRefine.h +22 -23
  58. data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
  59. data/vendor/faiss/faiss/IndexReplicas.h +62 -56
  60. data/vendor/faiss/faiss/IndexResidual.cpp +291 -0
  61. data/vendor/faiss/faiss/IndexResidual.h +152 -0
  62. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +120 -155
  63. data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -45
  64. data/vendor/faiss/faiss/IndexShards.cpp +256 -240
  65. data/vendor/faiss/faiss/IndexShards.h +85 -73
  66. data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
  67. data/vendor/faiss/faiss/MatrixStats.h +7 -10
  68. data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
  69. data/vendor/faiss/faiss/MetaIndexes.h +40 -34
  70. data/vendor/faiss/faiss/MetricType.h +7 -7
  71. data/vendor/faiss/faiss/VectorTransform.cpp +652 -474
  72. data/vendor/faiss/faiss/VectorTransform.h +61 -89
  73. data/vendor/faiss/faiss/clone_index.cpp +77 -73
  74. data/vendor/faiss/faiss/clone_index.h +4 -9
  75. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
  76. data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
  77. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +197 -170
  78. data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
  79. data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
  80. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
  81. data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
  82. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
  83. data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
  84. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
  85. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
  86. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
  87. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
  88. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
  89. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
  90. data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
  91. data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
  92. data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
  93. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
  94. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
  95. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
  96. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
  97. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
  98. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
  99. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
  100. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
  101. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
  102. data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
  103. data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
  104. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
  105. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
  106. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
  107. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
  108. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
  109. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
  110. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
  111. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
  112. data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
  113. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
  114. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
  115. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
  116. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
  117. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
  118. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
  119. data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
  120. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +270 -0
  121. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +115 -0
  122. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
  123. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
  124. data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
  125. data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
  126. data/vendor/faiss/faiss/impl/FaissException.h +41 -29
  127. data/vendor/faiss/faiss/impl/HNSW.cpp +595 -611
  128. data/vendor/faiss/faiss/impl/HNSW.h +179 -200
  129. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +672 -0
  130. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +172 -0
  131. data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
  132. data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
  133. data/vendor/faiss/faiss/impl/NSG.cpp +682 -0
  134. data/vendor/faiss/faiss/impl/NSG.h +199 -0
  135. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
  136. data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
  137. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
  138. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
  139. data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
  140. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +448 -0
  141. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +130 -0
  142. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
  143. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +648 -701
  144. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
  145. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
  146. data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
  147. data/vendor/faiss/faiss/impl/index_read.cpp +547 -479
  148. data/vendor/faiss/faiss/impl/index_write.cpp +497 -407
  149. data/vendor/faiss/faiss/impl/io.cpp +75 -94
  150. data/vendor/faiss/faiss/impl/io.h +31 -41
  151. data/vendor/faiss/faiss/impl/io_macros.h +40 -29
  152. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
  153. data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
  154. data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
  155. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
  156. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
  157. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
  158. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
  159. data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
  160. data/vendor/faiss/faiss/index_factory.cpp +269 -218
  161. data/vendor/faiss/faiss/index_factory.h +6 -7
  162. data/vendor/faiss/faiss/index_io.h +23 -26
  163. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
  164. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
  165. data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
  166. data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
  167. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
  168. data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
  169. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
  170. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
  171. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
  172. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
  173. data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
  174. data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
  175. data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
  176. data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
  177. data/vendor/faiss/faiss/utils/Heap.h +186 -209
  178. data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
  179. data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
  180. data/vendor/faiss/faiss/utils/distances.cpp +301 -310
  181. data/vendor/faiss/faiss/utils/distances.h +133 -118
  182. data/vendor/faiss/faiss/utils/distances_simd.cpp +456 -516
  183. data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
  184. data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
  185. data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
  186. data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
  187. data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
  188. data/vendor/faiss/faiss/utils/hamming.h +62 -85
  189. data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
  190. data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
  191. data/vendor/faiss/faiss/utils/partitioning.h +26 -21
  192. data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
  193. data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
  194. data/vendor/faiss/faiss/utils/random.cpp +39 -63
  195. data/vendor/faiss/faiss/utils/random.h +13 -16
  196. data/vendor/faiss/faiss/utils/simdlib.h +4 -2
  197. data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
  198. data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
  199. data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
  200. data/vendor/faiss/faiss/utils/utils.cpp +304 -287
  201. data/vendor/faiss/faiss/utils/utils.h +53 -48
  202. metadata +20 -2
@@ -0,0 +1,172 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ #pragma once
9
+
10
+ #include <stdint.h>
11
+
12
+ #include <random>
13
+ #include <string>
14
+ #include <unordered_map>
15
+ #include <vector>
16
+
17
+ #include <faiss/impl/AdditiveQuantizer.h>
18
+ #include <faiss/utils/utils.h>
19
+
20
+ namespace faiss {
21
+
22
+ /** Implementation of LSQ/LSQ++ described in the following two papers:
23
+ *
24
+ * Revisiting additive quantization
25
+ * Julieta Martinez, et al. ECCV 2016
26
+ *
27
+ * LSQ++: Lower running time and higher recall in multi-codebook quantization
28
+ * Julieta Martinez, et al. ECCV 2018
29
+ *
30
+ * This implementation is mostly translated from the Julia implementations
31
+ * by Julieta Martinez:
32
+ * (https://github.com/una-dinosauria/local-search-quantization,
33
+ * https://github.com/una-dinosauria/Rayuela.jl)
34
+ *
35
+ * The trained codes are stored in `codebooks` which is called
36
+ * `centroids` in PQ and RQ.
37
+ */
38
+
39
+ struct LocalSearchQuantizer : AdditiveQuantizer {
40
+ size_t K; ///< number of codes per codebook
41
+
42
+ size_t train_iters; ///< number of iterations in training
43
+
44
+ size_t encode_ils_iters; ///< iterations of local search in encoding
45
+ size_t train_ils_iters; ///< iterations of local search in training
46
+ size_t icm_iters; ///< number of iterations in icm
47
+
48
+ float p; ///< temperature factor
49
+ float lambd; ///< regularization factor
50
+
51
+ size_t chunk_size; ///< nb of vectors to encode at a time
52
+
53
+ int random_seed; ///< seed for random generator
54
+ size_t nperts; ///< number of perturbation in each code
55
+
56
+ LocalSearchQuantizer(
57
+ size_t d, /* dimensionality of the input vectors */
58
+ size_t M, /* number of subquantizers */
59
+ size_t nbits); /* number of bit per subvector index */
60
+
61
+ // Train the local search quantizer
62
+ void train(size_t n, const float* x) override;
63
+
64
+ /** Encode a set of vectors
65
+ *
66
+ * @param x vectors to encode, size n * d
67
+ * @param codes output codes, size n * code_size
68
+ */
69
+ void compute_codes(const float* x, uint8_t* codes, size_t n) const override;
70
+
71
+ /** Update codebooks given encodings
72
+ *
73
+ * @param x training vectors, size n * d
74
+ * @param codes encoded training vectors, size n * M
75
+ */
76
+ void update_codebooks(const float* x, const int32_t* codes, size_t n);
77
+
78
+ /** Encode vectors given codebooks using iterative conditional mode (icm).
79
+ *
80
+ * @param x vectors to encode, size n * d
81
+ * @param codes output codes, size n * M
82
+ * @param ils_iters number of iterations of iterative local search
83
+ */
84
+ void icm_encode(
85
+ const float* x,
86
+ int32_t* codes,
87
+ size_t n,
88
+ size_t ils_iters,
89
+ std::mt19937& gen) const;
90
+
91
+ void icm_encode_partial(
92
+ size_t index,
93
+ const float* x,
94
+ int32_t* codes,
95
+ size_t n,
96
+ const float* binaries,
97
+ size_t ils_iters,
98
+ std::mt19937& gen) const;
99
+
100
+ void icm_encode_step(
101
+ const float* unaries,
102
+ const float* binaries,
103
+ int32_t* codes,
104
+ size_t n) const;
105
+
106
+ /** Add some perturbation to codebooks
107
+ *
108
+ * @param T temperature of simulated annealing
109
+ * @param stddev standard derivations of each dimension in training data
110
+ */
111
+ void perturb_codebooks(
112
+ float T,
113
+ const std::vector<float>& stddev,
114
+ std::mt19937& gen);
115
+
116
+ /** Add some perturbation to codes
117
+ *
118
+ * @param codes codes to be perturbed, size n * M
119
+ */
120
+ void perturb_codes(int32_t* codes, size_t n, std::mt19937& gen) const;
121
+
122
+ /** Compute binary terms
123
+ *
124
+ * @param binaries binary terms, size M * M * K * K
125
+ */
126
+ void compute_binary_terms(float* binaries) const;
127
+
128
+ /** Compute unary terms
129
+ *
130
+ * @param x vectors to encode, size n * d
131
+ * @param unaries unary terms, size n * M * K
132
+ */
133
+ void compute_unary_terms(const float* x, float* unaries, size_t n) const;
134
+
135
+ /** Helper function to compute reconstruction error
136
+ *
137
+ * @param x vectors to encode, size n * d
138
+ * @param codes encoded codes, size n * M
139
+ * @param objs if it is not null, store reconstruction
140
+ error of each vector into it, size n
141
+ */
142
+ float evaluate(
143
+ const int32_t* codes,
144
+ const float* x,
145
+ size_t n,
146
+ float* objs = nullptr) const;
147
+ };
148
+
149
/** Helper that keeps track of the time consumed during training, keyed
 * by a section name.
 *
 * It is NOT thread-safe.
 */
struct LSQTimer {
    // Per-name bookkeeping. NOTE(review): the exact meaning of each map is
    // defined by the member functions in the .cpp -- presumably accumulated
    // time, last start time and "currently running" flag; confirm there.
    std::unordered_map<std::string, double> duration;
    std::unordered_map<std::string, double> t0;
    std::unordered_map<std::string, bool> started;

    /// A new timer starts from the state produced by reset().
    LSQTimer() {
        reset();
    }

    double get(const std::string& name);

    void start(const std::string& name);

    void end(const std::string& name);

    void reset();
};
169
+
170
+ FAISS_API extern LSQTimer lsq_timer; ///< timer to count consuming time
171
+
172
+ } // namespace faiss
@@ -0,0 +1,487 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ // -*- c++ -*-
9
+
10
+ #include <faiss/impl/NNDescent.h>
11
+
12
+ #include <mutex>
13
+ #include <string>
14
+
15
+ #include <faiss/impl/AuxIndexStructures.h>
16
+
17
+ namespace faiss {
18
+
19
+ using LockGuard = std::lock_guard<std::mutex>;
20
+
21
+ namespace nndescent {
22
+
23
+ void gen_random(std::mt19937& rng, int* addr, const int size, const int N);
24
+
25
+ Nhood::Nhood(int l, int s, std::mt19937& rng, int N) {
26
+ M = s;
27
+ nn_new.resize(s * 2);
28
+ gen_random(rng, nn_new.data(), (int)nn_new.size(), N);
29
+ }
30
+
31
+ /// Copy operator
32
+ Nhood& Nhood::operator=(const Nhood& other) {
33
+ M = other.M;
34
+ std::copy(
35
+ other.nn_new.begin(),
36
+ other.nn_new.end(),
37
+ std::back_inserter(nn_new));
38
+ nn_new.reserve(other.nn_new.capacity());
39
+ pool.reserve(other.pool.capacity());
40
+ return *this;
41
+ }
42
+
43
+ /// Copy constructor
44
+ Nhood::Nhood(const Nhood& other) {
45
+ M = other.M;
46
+ std::copy(
47
+ other.nn_new.begin(),
48
+ other.nn_new.end(),
49
+ std::back_inserter(nn_new));
50
+ nn_new.reserve(other.nn_new.capacity());
51
+ pool.reserve(other.pool.capacity());
52
+ }
53
+
54
+ /// Insert a point into the candidate pool
55
+ void Nhood::insert(int id, float dist) {
56
+ LockGuard guard(lock);
57
+ if (dist > pool.front().distance)
58
+ return;
59
+ for (int i = 0; i < pool.size(); i++) {
60
+ if (id == pool[i].id)
61
+ return;
62
+ }
63
+ if (pool.size() < pool.capacity()) {
64
+ pool.push_back(Neighbor(id, dist, true));
65
+ std::push_heap(pool.begin(), pool.end());
66
+ } else {
67
+ std::pop_heap(pool.begin(), pool.end());
68
+ pool[pool.size() - 1] = Neighbor(id, dist, true);
69
+ std::push_heap(pool.begin(), pool.end());
70
+ }
71
+ }
72
+
73
+ /// In local join, two objects are compared only if at least
74
+ /// one of them is new.
75
+ template <typename C>
76
+ void Nhood::join(C callback) const {
77
+ for (int const i : nn_new) {
78
+ for (int const j : nn_new) {
79
+ if (i < j) {
80
+ callback(i, j);
81
+ }
82
+ }
83
+ for (int j : nn_old) {
84
+ callback(i, j);
85
+ }
86
+ }
87
+ }
88
+
89
/// Fill addr[0 .. size) with pseudo-random ids in [0, N).
///
/// When N > size the ids are pairwise distinct: `size` values are drawn
/// from [0, N - size), sorted, made strictly increasing, and finally
/// rotated by a random offset modulo N.
///
/// When N <= size there are not enough distinct ids to go around (and
/// the modulus `N - size` used above would be zero or negative), so the
/// output simply cycles through [0, N) starting at a random offset.
///
/// @param rng   random generator, advanced by this call
/// @param addr  output array with room for `size` ints
/// @param size  number of ids to generate
/// @param N     number of points ids are drawn from; nothing is written
///              when N <= 0
void gen_random(std::mt19937& rng, int* addr, const int size, const int N) {
    if (N <= 0) {
        // no valid id exists
        return;
    }

    if (N <= size) {
        // degenerate case: more slots than ids
        const int start = rng() % N;
        for (int i = 0; i < size; ++i) {
            addr[i] = (start + i) % N;
        }
        return;
    }

    for (int i = 0; i < size; ++i) {
        addr[i] = rng() % (N - size);
    }
    if (size > 1) {
        std::sort(addr, addr + size);
    }
    // break ties upwards; the largest value stays below N
    for (int i = 1; i < size; ++i) {
        if (addr[i] <= addr[i - 1]) {
            addr[i] = addr[i - 1] + 1;
        }
    }
    const int off = rng() % N;
    for (int i = 0; i < size; ++i) {
        addr[i] = (addr[i] + off) % N;
    }
}
104
+
105
+ // Insert a new point into the candidate pool in ascending order
106
+ int insert_into_pool(Neighbor* addr, int size, Neighbor nn) {
107
+ // find the location to insert
108
+ int left = 0, right = size - 1;
109
+ if (addr[left].distance > nn.distance) {
110
+ memmove((char*)&addr[left + 1], &addr[left], size * sizeof(Neighbor));
111
+ addr[left] = nn;
112
+ return left;
113
+ }
114
+ if (addr[right].distance < nn.distance) {
115
+ addr[size] = nn;
116
+ return size;
117
+ }
118
+ while (left < right - 1) {
119
+ int mid = (left + right) / 2;
120
+ if (addr[mid].distance > nn.distance)
121
+ right = mid;
122
+ else
123
+ left = mid;
124
+ }
125
+ // check equal ID
126
+
127
+ while (left > 0) {
128
+ if (addr[left].distance < nn.distance)
129
+ break;
130
+ if (addr[left].id == nn.id)
131
+ return size + 1;
132
+ left--;
133
+ }
134
+ if (addr[left].id == nn.id || addr[right].id == nn.id)
135
+ return size + 1;
136
+ memmove((char*)&addr[right + 1],
137
+ &addr[right],
138
+ (size - right) * sizeof(Neighbor));
139
+ addr[right] = nn;
140
+ return right;
141
+ }
142
+
143
+ } // namespace nndescent
144
+
145
+ using namespace nndescent;
146
+
147
+ constexpr int NUM_EVAL_POINTS = 100;
148
+
149
+ NNDescent::NNDescent(const int d, const int K) : K(K), random_seed(2021), d(d) {
150
+ ntotal = 0;
151
+ has_built = false;
152
+ S = 10;
153
+ R = 100;
154
+ L = K + 50;
155
+ iter = 10;
156
+ search_L = 0;
157
+ }
158
+
159
+ NNDescent::~NNDescent() {}
160
+
161
+ void NNDescent::join(DistanceComputer& qdis) {
162
+ #pragma omp parallel for default(shared) schedule(dynamic, 100)
163
+ for (int n = 0; n < ntotal; n++) {
164
+ graph[n].join([&](int i, int j) {
165
+ if (i != j) {
166
+ float dist = qdis.symmetric_dis(i, j);
167
+ graph[i].insert(j, dist);
168
+ graph[j].insert(i, dist);
169
+ }
170
+ });
171
+ }
172
+ }
173
+
174
+ /// Sample neighbors for each node to peform local join later
175
+ /// Store them in nn_new and nn_old
176
+ void NNDescent::update() {
177
+ // Step 1.
178
+ // Clear all nn_new and nn_old
179
+ #pragma omp parallel for
180
+ for (int i = 0; i < ntotal; i++) {
181
+ std::vector<int>().swap(graph[i].nn_new);
182
+ std::vector<int>().swap(graph[i].nn_old);
183
+ }
184
+
185
+ // Step 2.
186
+ // Compute the number of neighbors which is new i.e. flag is true
187
+ // in the candidate pool. This must not exceed the sample number S.
188
+ // That means We only select S new neighbors.
189
+ #pragma omp parallel for
190
+ for (int n = 0; n < ntotal; ++n) {
191
+ auto& nn = graph[n];
192
+ std::sort(nn.pool.begin(), nn.pool.end());
193
+
194
+ if (nn.pool.size() > L)
195
+ nn.pool.resize(L);
196
+ nn.pool.reserve(L); // keep the pool size be L
197
+
198
+ int maxl = std::min(nn.M + S, (int)nn.pool.size());
199
+ int c = 0;
200
+ int l = 0;
201
+
202
+ while ((l < maxl) && (c < S)) {
203
+ if (nn.pool[l].flag)
204
+ ++c;
205
+ ++l;
206
+ }
207
+ nn.M = l;
208
+ }
209
+
210
+ // Step 3.
211
+ // Find reverse links for each node
212
+ // Randomly choose R reverse links.
213
+ #pragma omp parallel
214
+ {
215
+ std::mt19937 rng(random_seed * 5081 + omp_get_thread_num());
216
+ #pragma omp for
217
+ for (int n = 0; n < ntotal; ++n) {
218
+ auto& node = graph[n];
219
+ auto& nn_new = node.nn_new;
220
+ auto& nn_old = node.nn_old;
221
+
222
+ for (int l = 0; l < node.M; ++l) {
223
+ auto& nn = node.pool[l];
224
+ auto& other = graph[nn.id]; // the other side of the edge
225
+
226
+ if (nn.flag) { // the node is inserted newly
227
+ // push the neighbor into nn_new
228
+ nn_new.push_back(nn.id);
229
+ // push itself into other.rnn_new if it is not in
230
+ // the candidate pool of the other side
231
+ if (nn.distance > other.pool.back().distance) {
232
+ LockGuard guard(other.lock);
233
+ if (other.rnn_new.size() < R) {
234
+ other.rnn_new.push_back(n);
235
+ } else {
236
+ int pos = rng() % R;
237
+ other.rnn_new[pos] = n;
238
+ }
239
+ }
240
+ nn.flag = false;
241
+
242
+ } else { // the node is old
243
+ // push the neighbor into nn_old
244
+ nn_old.push_back(nn.id);
245
+ // push itself into other.rnn_old if it is not in
246
+ // the candidate pool of the other side
247
+ if (nn.distance > other.pool.back().distance) {
248
+ LockGuard guard(other.lock);
249
+ if (other.rnn_old.size() < R) {
250
+ other.rnn_old.push_back(n);
251
+ } else {
252
+ int pos = rng() % R;
253
+ other.rnn_old[pos] = n;
254
+ }
255
+ }
256
+ }
257
+ }
258
+ // make heap to join later (in join() function)
259
+ std::make_heap(node.pool.begin(), node.pool.end());
260
+ }
261
+ }
262
+
263
+ // Step 4.
264
+ // Combine the forward and the reverse links
265
+ // R = 0 means no reverse links are used.
266
+ #pragma omp parallel for
267
+ for (int i = 0; i < ntotal; ++i) {
268
+ auto& nn_new = graph[i].nn_new;
269
+ auto& nn_old = graph[i].nn_old;
270
+ auto& rnn_new = graph[i].rnn_new;
271
+ auto& rnn_old = graph[i].rnn_old;
272
+
273
+ nn_new.insert(nn_new.end(), rnn_new.begin(), rnn_new.end());
274
+ nn_old.insert(nn_old.end(), rnn_old.begin(), rnn_old.end());
275
+ if (nn_old.size() > R * 2) {
276
+ nn_old.resize(R * 2);
277
+ nn_old.reserve(R * 2);
278
+ }
279
+
280
+ std::vector<int>().swap(graph[i].rnn_new);
281
+ std::vector<int>().swap(graph[i].rnn_old);
282
+ }
283
+ }
284
+
285
+ void NNDescent::nndescent(DistanceComputer& qdis, bool verbose) {
286
+ int num_eval_points = std::min(NUM_EVAL_POINTS, ntotal);
287
+ std::vector<int> eval_points(num_eval_points);
288
+ std::vector<std::vector<int>> acc_eval_set(num_eval_points);
289
+ std::mt19937 rng(random_seed * 6577 + omp_get_thread_num());
290
+ gen_random(rng, eval_points.data(), eval_points.size(), ntotal);
291
+ generate_eval_set(qdis, eval_points, acc_eval_set, ntotal);
292
+ for (int it = 0; it < iter; it++) {
293
+ join(qdis);
294
+ update();
295
+
296
+ if (verbose) {
297
+ float recall = eval_recall(eval_points, acc_eval_set);
298
+ printf("Iter: %d, recall@%d: %lf\n", it, K, recall);
299
+ }
300
+ }
301
+ }
302
+
303
+ /// Sample a small number of points to evaluate the quality of KNNG built
304
+ void NNDescent::generate_eval_set(
305
+ DistanceComputer& qdis,
306
+ std::vector<int>& c,
307
+ std::vector<std::vector<int>>& v,
308
+ int N) {
309
+ #pragma omp parallel for
310
+ for (int i = 0; i < c.size(); i++) {
311
+ std::vector<Neighbor> tmp;
312
+ for (int j = 0; j < N; j++) {
313
+ if (i == j)
314
+ continue; // skip itself
315
+ float dist = qdis.symmetric_dis(c[i], j);
316
+ tmp.push_back(Neighbor(j, dist, true));
317
+ }
318
+
319
+ std::partial_sort(tmp.begin(), tmp.begin() + K, tmp.end());
320
+ for (int j = 0; j < K; j++) {
321
+ v[i].push_back(tmp[j].id);
322
+ }
323
+ }
324
+ }
325
+
326
+ /// Evaluate the quality of KNNG built
327
+ float NNDescent::eval_recall(
328
+ std::vector<int>& eval_points,
329
+ std::vector<std::vector<int>>& acc_eval_set) {
330
+ float mean_acc = 0.0f;
331
+ for (size_t i = 0; i < eval_points.size(); i++) {
332
+ float acc = 0;
333
+ std::vector<Neighbor>& g = graph[eval_points[i]].pool;
334
+ std::vector<int>& v = acc_eval_set[i];
335
+ for (size_t j = 0; j < g.size(); j++) {
336
+ for (size_t k = 0; k < v.size(); k++) {
337
+ if (g[j].id == v[k]) {
338
+ acc++;
339
+ break;
340
+ }
341
+ }
342
+ }
343
+ mean_acc += acc / v.size();
344
+ }
345
+ return mean_acc / eval_points.size();
346
+ }
347
+
348
+ /// Initialize the KNN graph randomly
349
+ void NNDescent::init_graph(DistanceComputer& qdis) {
350
+ graph.reserve(ntotal);
351
+ {
352
+ std::mt19937 rng(random_seed * 6007);
353
+ for (int i = 0; i < ntotal; i++) {
354
+ graph.push_back(Nhood(L, S, rng, (int)ntotal));
355
+ }
356
+ }
357
+ #pragma omp parallel
358
+ {
359
+ std::mt19937 rng(random_seed * 7741 + omp_get_thread_num());
360
+ #pragma omp for
361
+ for (int i = 0; i < ntotal; i++) {
362
+ std::vector<int> tmp(S);
363
+
364
+ gen_random(rng, tmp.data(), S, ntotal);
365
+
366
+ for (int j = 0; j < S; j++) {
367
+ int id = tmp[j];
368
+ if (id == i)
369
+ continue;
370
+ float dist = qdis.symmetric_dis(i, id);
371
+
372
+ graph[i].pool.push_back(Neighbor(id, dist, true));
373
+ }
374
+ std::make_heap(graph[i].pool.begin(), graph[i].pool.end());
375
+ graph[i].pool.reserve(L);
376
+ }
377
+ }
378
+ }
379
+
380
+ void NNDescent::build(DistanceComputer& qdis, const int n, bool verbose) {
381
+ FAISS_THROW_IF_NOT_MSG(L >= K, "L should be >= K in NNDescent.build");
382
+
383
+ if (verbose) {
384
+ printf("Parameters: K=%d, S=%d, R=%d, L=%d, iter=%d\n",
385
+ K,
386
+ S,
387
+ R,
388
+ L,
389
+ iter);
390
+ }
391
+
392
+ ntotal = n;
393
+ init_graph(qdis);
394
+ nndescent(qdis, verbose);
395
+
396
+ final_graph.resize(ntotal * K);
397
+
398
+ // Store the neighbor link structure into final_graph
399
+ // Clear the old graph
400
+ for (int i = 0; i < ntotal; i++) {
401
+ std::sort(graph[i].pool.begin(), graph[i].pool.end());
402
+ for (int j = 0; j < K; j++) {
403
+ FAISS_ASSERT(graph[i].pool[j].id < ntotal);
404
+ final_graph[i * K + j] = graph[i].pool[j].id;
405
+ }
406
+ }
407
+ std::vector<Nhood>().swap(graph);
408
+ has_built = true;
409
+
410
+ if (verbose) {
411
+ printf("Addes %d points into the index\n", ntotal);
412
+ }
413
+ }
414
+
415
+ void NNDescent::search(
416
+ DistanceComputer& qdis,
417
+ const int topk,
418
+ idx_t* indices,
419
+ float* dists,
420
+ VisitedTable& vt) const {
421
+ FAISS_THROW_IF_NOT_MSG(has_built, "The index is not build yet.");
422
+ int L = std::max(search_L, topk);
423
+
424
+ // candidate pool, the K best items is the result.
425
+ std::vector<Neighbor> retset(L + 1);
426
+
427
+ // Randomly choose L points to intialize the candidate pool
428
+ std::vector<int> init_ids(L);
429
+ std::mt19937 rng(random_seed);
430
+
431
+ gen_random(rng, init_ids.data(), L, ntotal);
432
+ for (int i = 0; i < L; i++) {
433
+ int id = init_ids[i];
434
+ float dist = qdis(id);
435
+ retset[i] = Neighbor(id, dist, true);
436
+ }
437
+
438
+ // Maintain the candidate pool in ascending order
439
+ std::sort(retset.begin(), retset.begin() + L);
440
+
441
+ int k = 0;
442
+
443
+ // Stop until the smallest position updated is >= L
444
+ while (k < L) {
445
+ int nk = L;
446
+
447
+ if (retset[k].flag) {
448
+ retset[k].flag = false;
449
+ int n = retset[k].id;
450
+
451
+ for (int m = 0; m < K; ++m) {
452
+ int id = final_graph[n * K + m];
453
+ if (vt.get(id))
454
+ continue;
455
+
456
+ vt.set(id);
457
+ float dist = qdis(id);
458
+ if (dist >= retset[L - 1].distance)
459
+ continue;
460
+
461
+ Neighbor nn(id, dist, true);
462
+ int r = insert_into_pool(retset.data(), L, nn);
463
+
464
+ if (r < nk)
465
+ nk = r;
466
+ }
467
+ }
468
+ if (nk <= k)
469
+ k = nk;
470
+ else
471
+ ++k;
472
+ }
473
+ for (size_t i = 0; i < topk; i++) {
474
+ indices[i] = retset[i].id;
475
+ dists[i] = retset[i].distance;
476
+ }
477
+
478
+ vt.advance();
479
+ }
480
+
481
+ void NNDescent::reset() {
482
+ has_built = false;
483
+ ntotal = 0;
484
+ final_graph.resize(0);
485
+ }
486
+
487
+ } // namespace faiss