faiss 0.1.0 → 0.1.1

Sign up to get free protection for your applications and to get access to all the features.
Files changed (226)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +5 -0
  3. data/README.md +103 -3
  4. data/ext/faiss/ext.cpp +99 -32
  5. data/ext/faiss/extconf.rb +12 -2
  6. data/lib/faiss/ext.bundle +0 -0
  7. data/lib/faiss/index.rb +3 -3
  8. data/lib/faiss/index_binary.rb +3 -3
  9. data/lib/faiss/kmeans.rb +1 -1
  10. data/lib/faiss/pca_matrix.rb +2 -2
  11. data/lib/faiss/product_quantizer.rb +3 -3
  12. data/lib/faiss/version.rb +1 -1
  13. data/vendor/faiss/AutoTune.cpp +719 -0
  14. data/vendor/faiss/AutoTune.h +212 -0
  15. data/vendor/faiss/Clustering.cpp +261 -0
  16. data/vendor/faiss/Clustering.h +101 -0
  17. data/vendor/faiss/IVFlib.cpp +339 -0
  18. data/vendor/faiss/IVFlib.h +132 -0
  19. data/vendor/faiss/Index.cpp +171 -0
  20. data/vendor/faiss/Index.h +261 -0
  21. data/vendor/faiss/Index2Layer.cpp +437 -0
  22. data/vendor/faiss/Index2Layer.h +85 -0
  23. data/vendor/faiss/IndexBinary.cpp +77 -0
  24. data/vendor/faiss/IndexBinary.h +163 -0
  25. data/vendor/faiss/IndexBinaryFlat.cpp +83 -0
  26. data/vendor/faiss/IndexBinaryFlat.h +54 -0
  27. data/vendor/faiss/IndexBinaryFromFloat.cpp +78 -0
  28. data/vendor/faiss/IndexBinaryFromFloat.h +52 -0
  29. data/vendor/faiss/IndexBinaryHNSW.cpp +325 -0
  30. data/vendor/faiss/IndexBinaryHNSW.h +56 -0
  31. data/vendor/faiss/IndexBinaryIVF.cpp +671 -0
  32. data/vendor/faiss/IndexBinaryIVF.h +211 -0
  33. data/vendor/faiss/IndexFlat.cpp +508 -0
  34. data/vendor/faiss/IndexFlat.h +175 -0
  35. data/vendor/faiss/IndexHNSW.cpp +1090 -0
  36. data/vendor/faiss/IndexHNSW.h +170 -0
  37. data/vendor/faiss/IndexIVF.cpp +909 -0
  38. data/vendor/faiss/IndexIVF.h +353 -0
  39. data/vendor/faiss/IndexIVFFlat.cpp +502 -0
  40. data/vendor/faiss/IndexIVFFlat.h +118 -0
  41. data/vendor/faiss/IndexIVFPQ.cpp +1207 -0
  42. data/vendor/faiss/IndexIVFPQ.h +161 -0
  43. data/vendor/faiss/IndexIVFPQR.cpp +219 -0
  44. data/vendor/faiss/IndexIVFPQR.h +65 -0
  45. data/vendor/faiss/IndexIVFSpectralHash.cpp +331 -0
  46. data/vendor/faiss/IndexIVFSpectralHash.h +75 -0
  47. data/vendor/faiss/IndexLSH.cpp +225 -0
  48. data/vendor/faiss/IndexLSH.h +87 -0
  49. data/vendor/faiss/IndexLattice.cpp +143 -0
  50. data/vendor/faiss/IndexLattice.h +68 -0
  51. data/vendor/faiss/IndexPQ.cpp +1188 -0
  52. data/vendor/faiss/IndexPQ.h +199 -0
  53. data/vendor/faiss/IndexPreTransform.cpp +288 -0
  54. data/vendor/faiss/IndexPreTransform.h +91 -0
  55. data/vendor/faiss/IndexReplicas.cpp +123 -0
  56. data/vendor/faiss/IndexReplicas.h +76 -0
  57. data/vendor/faiss/IndexScalarQuantizer.cpp +317 -0
  58. data/vendor/faiss/IndexScalarQuantizer.h +127 -0
  59. data/vendor/faiss/IndexShards.cpp +317 -0
  60. data/vendor/faiss/IndexShards.h +100 -0
  61. data/vendor/faiss/InvertedLists.cpp +623 -0
  62. data/vendor/faiss/InvertedLists.h +334 -0
  63. data/vendor/faiss/LICENSE +21 -0
  64. data/vendor/faiss/MatrixStats.cpp +252 -0
  65. data/vendor/faiss/MatrixStats.h +62 -0
  66. data/vendor/faiss/MetaIndexes.cpp +351 -0
  67. data/vendor/faiss/MetaIndexes.h +126 -0
  68. data/vendor/faiss/OnDiskInvertedLists.cpp +674 -0
  69. data/vendor/faiss/OnDiskInvertedLists.h +127 -0
  70. data/vendor/faiss/VectorTransform.cpp +1157 -0
  71. data/vendor/faiss/VectorTransform.h +322 -0
  72. data/vendor/faiss/c_api/AutoTune_c.cpp +83 -0
  73. data/vendor/faiss/c_api/AutoTune_c.h +64 -0
  74. data/vendor/faiss/c_api/Clustering_c.cpp +139 -0
  75. data/vendor/faiss/c_api/Clustering_c.h +117 -0
  76. data/vendor/faiss/c_api/IndexFlat_c.cpp +140 -0
  77. data/vendor/faiss/c_api/IndexFlat_c.h +115 -0
  78. data/vendor/faiss/c_api/IndexIVFFlat_c.cpp +64 -0
  79. data/vendor/faiss/c_api/IndexIVFFlat_c.h +58 -0
  80. data/vendor/faiss/c_api/IndexIVF_c.cpp +92 -0
  81. data/vendor/faiss/c_api/IndexIVF_c.h +135 -0
  82. data/vendor/faiss/c_api/IndexLSH_c.cpp +37 -0
  83. data/vendor/faiss/c_api/IndexLSH_c.h +40 -0
  84. data/vendor/faiss/c_api/IndexShards_c.cpp +44 -0
  85. data/vendor/faiss/c_api/IndexShards_c.h +42 -0
  86. data/vendor/faiss/c_api/Index_c.cpp +105 -0
  87. data/vendor/faiss/c_api/Index_c.h +183 -0
  88. data/vendor/faiss/c_api/MetaIndexes_c.cpp +49 -0
  89. data/vendor/faiss/c_api/MetaIndexes_c.h +49 -0
  90. data/vendor/faiss/c_api/clone_index_c.cpp +23 -0
  91. data/vendor/faiss/c_api/clone_index_c.h +32 -0
  92. data/vendor/faiss/c_api/error_c.h +42 -0
  93. data/vendor/faiss/c_api/error_impl.cpp +27 -0
  94. data/vendor/faiss/c_api/error_impl.h +16 -0
  95. data/vendor/faiss/c_api/faiss_c.h +58 -0
  96. data/vendor/faiss/c_api/gpu/GpuAutoTune_c.cpp +96 -0
  97. data/vendor/faiss/c_api/gpu/GpuAutoTune_c.h +56 -0
  98. data/vendor/faiss/c_api/gpu/GpuClonerOptions_c.cpp +52 -0
  99. data/vendor/faiss/c_api/gpu/GpuClonerOptions_c.h +68 -0
  100. data/vendor/faiss/c_api/gpu/GpuIndex_c.cpp +17 -0
  101. data/vendor/faiss/c_api/gpu/GpuIndex_c.h +30 -0
  102. data/vendor/faiss/c_api/gpu/GpuIndicesOptions_c.h +38 -0
  103. data/vendor/faiss/c_api/gpu/GpuResources_c.cpp +86 -0
  104. data/vendor/faiss/c_api/gpu/GpuResources_c.h +66 -0
  105. data/vendor/faiss/c_api/gpu/StandardGpuResources_c.cpp +54 -0
  106. data/vendor/faiss/c_api/gpu/StandardGpuResources_c.h +53 -0
  107. data/vendor/faiss/c_api/gpu/macros_impl.h +42 -0
  108. data/vendor/faiss/c_api/impl/AuxIndexStructures_c.cpp +220 -0
  109. data/vendor/faiss/c_api/impl/AuxIndexStructures_c.h +149 -0
  110. data/vendor/faiss/c_api/index_factory_c.cpp +26 -0
  111. data/vendor/faiss/c_api/index_factory_c.h +30 -0
  112. data/vendor/faiss/c_api/index_io_c.cpp +42 -0
  113. data/vendor/faiss/c_api/index_io_c.h +50 -0
  114. data/vendor/faiss/c_api/macros_impl.h +110 -0
  115. data/vendor/faiss/clone_index.cpp +147 -0
  116. data/vendor/faiss/clone_index.h +38 -0
  117. data/vendor/faiss/demos/demo_imi_flat.cpp +151 -0
  118. data/vendor/faiss/demos/demo_imi_pq.cpp +199 -0
  119. data/vendor/faiss/demos/demo_ivfpq_indexing.cpp +146 -0
  120. data/vendor/faiss/demos/demo_sift1M.cpp +252 -0
  121. data/vendor/faiss/gpu/GpuAutoTune.cpp +95 -0
  122. data/vendor/faiss/gpu/GpuAutoTune.h +27 -0
  123. data/vendor/faiss/gpu/GpuCloner.cpp +403 -0
  124. data/vendor/faiss/gpu/GpuCloner.h +82 -0
  125. data/vendor/faiss/gpu/GpuClonerOptions.cpp +28 -0
  126. data/vendor/faiss/gpu/GpuClonerOptions.h +53 -0
  127. data/vendor/faiss/gpu/GpuDistance.h +52 -0
  128. data/vendor/faiss/gpu/GpuFaissAssert.h +29 -0
  129. data/vendor/faiss/gpu/GpuIndex.h +148 -0
  130. data/vendor/faiss/gpu/GpuIndexBinaryFlat.h +89 -0
  131. data/vendor/faiss/gpu/GpuIndexFlat.h +190 -0
  132. data/vendor/faiss/gpu/GpuIndexIVF.h +89 -0
  133. data/vendor/faiss/gpu/GpuIndexIVFFlat.h +85 -0
  134. data/vendor/faiss/gpu/GpuIndexIVFPQ.h +143 -0
  135. data/vendor/faiss/gpu/GpuIndexIVFScalarQuantizer.h +100 -0
  136. data/vendor/faiss/gpu/GpuIndicesOptions.h +30 -0
  137. data/vendor/faiss/gpu/GpuResources.cpp +52 -0
  138. data/vendor/faiss/gpu/GpuResources.h +73 -0
  139. data/vendor/faiss/gpu/StandardGpuResources.cpp +295 -0
  140. data/vendor/faiss/gpu/StandardGpuResources.h +114 -0
  141. data/vendor/faiss/gpu/impl/RemapIndices.cpp +43 -0
  142. data/vendor/faiss/gpu/impl/RemapIndices.h +24 -0
  143. data/vendor/faiss/gpu/perf/IndexWrapper-inl.h +71 -0
  144. data/vendor/faiss/gpu/perf/IndexWrapper.h +39 -0
  145. data/vendor/faiss/gpu/perf/PerfClustering.cpp +115 -0
  146. data/vendor/faiss/gpu/perf/PerfIVFPQAdd.cpp +139 -0
  147. data/vendor/faiss/gpu/perf/WriteIndex.cpp +102 -0
  148. data/vendor/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +130 -0
  149. data/vendor/faiss/gpu/test/TestGpuIndexFlat.cpp +371 -0
  150. data/vendor/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +550 -0
  151. data/vendor/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +450 -0
  152. data/vendor/faiss/gpu/test/TestGpuMemoryException.cpp +84 -0
  153. data/vendor/faiss/gpu/test/TestUtils.cpp +315 -0
  154. data/vendor/faiss/gpu/test/TestUtils.h +93 -0
  155. data/vendor/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +159 -0
  156. data/vendor/faiss/gpu/utils/DeviceMemory.cpp +77 -0
  157. data/vendor/faiss/gpu/utils/DeviceMemory.h +71 -0
  158. data/vendor/faiss/gpu/utils/DeviceUtils.h +185 -0
  159. data/vendor/faiss/gpu/utils/MemorySpace.cpp +89 -0
  160. data/vendor/faiss/gpu/utils/MemorySpace.h +44 -0
  161. data/vendor/faiss/gpu/utils/StackDeviceMemory.cpp +239 -0
  162. data/vendor/faiss/gpu/utils/StackDeviceMemory.h +129 -0
  163. data/vendor/faiss/gpu/utils/StaticUtils.h +83 -0
  164. data/vendor/faiss/gpu/utils/Timer.cpp +60 -0
  165. data/vendor/faiss/gpu/utils/Timer.h +52 -0
  166. data/vendor/faiss/impl/AuxIndexStructures.cpp +305 -0
  167. data/vendor/faiss/impl/AuxIndexStructures.h +246 -0
  168. data/vendor/faiss/impl/FaissAssert.h +95 -0
  169. data/vendor/faiss/impl/FaissException.cpp +66 -0
  170. data/vendor/faiss/impl/FaissException.h +71 -0
  171. data/vendor/faiss/impl/HNSW.cpp +818 -0
  172. data/vendor/faiss/impl/HNSW.h +275 -0
  173. data/vendor/faiss/impl/PolysemousTraining.cpp +953 -0
  174. data/vendor/faiss/impl/PolysemousTraining.h +158 -0
  175. data/vendor/faiss/impl/ProductQuantizer.cpp +876 -0
  176. data/vendor/faiss/impl/ProductQuantizer.h +242 -0
  177. data/vendor/faiss/impl/ScalarQuantizer.cpp +1628 -0
  178. data/vendor/faiss/impl/ScalarQuantizer.h +120 -0
  179. data/vendor/faiss/impl/ThreadedIndex-inl.h +192 -0
  180. data/vendor/faiss/impl/ThreadedIndex.h +80 -0
  181. data/vendor/faiss/impl/index_read.cpp +793 -0
  182. data/vendor/faiss/impl/index_write.cpp +558 -0
  183. data/vendor/faiss/impl/io.cpp +142 -0
  184. data/vendor/faiss/impl/io.h +98 -0
  185. data/vendor/faiss/impl/lattice_Zn.cpp +712 -0
  186. data/vendor/faiss/impl/lattice_Zn.h +199 -0
  187. data/vendor/faiss/index_factory.cpp +392 -0
  188. data/vendor/faiss/index_factory.h +25 -0
  189. data/vendor/faiss/index_io.h +75 -0
  190. data/vendor/faiss/misc/test_blas.cpp +84 -0
  191. data/vendor/faiss/tests/test_binary_flat.cpp +64 -0
  192. data/vendor/faiss/tests/test_dealloc_invlists.cpp +183 -0
  193. data/vendor/faiss/tests/test_ivfpq_codec.cpp +67 -0
  194. data/vendor/faiss/tests/test_ivfpq_indexing.cpp +98 -0
  195. data/vendor/faiss/tests/test_lowlevel_ivf.cpp +566 -0
  196. data/vendor/faiss/tests/test_merge.cpp +258 -0
  197. data/vendor/faiss/tests/test_omp_threads.cpp +14 -0
  198. data/vendor/faiss/tests/test_ondisk_ivf.cpp +220 -0
  199. data/vendor/faiss/tests/test_pairs_decoding.cpp +189 -0
  200. data/vendor/faiss/tests/test_params_override.cpp +231 -0
  201. data/vendor/faiss/tests/test_pq_encoding.cpp +98 -0
  202. data/vendor/faiss/tests/test_sliding_ivf.cpp +240 -0
  203. data/vendor/faiss/tests/test_threaded_index.cpp +253 -0
  204. data/vendor/faiss/tests/test_transfer_invlists.cpp +159 -0
  205. data/vendor/faiss/tutorial/cpp/1-Flat.cpp +98 -0
  206. data/vendor/faiss/tutorial/cpp/2-IVFFlat.cpp +81 -0
  207. data/vendor/faiss/tutorial/cpp/3-IVFPQ.cpp +93 -0
  208. data/vendor/faiss/tutorial/cpp/4-GPU.cpp +119 -0
  209. data/vendor/faiss/tutorial/cpp/5-Multiple-GPUs.cpp +99 -0
  210. data/vendor/faiss/utils/Heap.cpp +122 -0
  211. data/vendor/faiss/utils/Heap.h +495 -0
  212. data/vendor/faiss/utils/WorkerThread.cpp +126 -0
  213. data/vendor/faiss/utils/WorkerThread.h +61 -0
  214. data/vendor/faiss/utils/distances.cpp +765 -0
  215. data/vendor/faiss/utils/distances.h +243 -0
  216. data/vendor/faiss/utils/distances_simd.cpp +809 -0
  217. data/vendor/faiss/utils/extra_distances.cpp +336 -0
  218. data/vendor/faiss/utils/extra_distances.h +54 -0
  219. data/vendor/faiss/utils/hamming-inl.h +472 -0
  220. data/vendor/faiss/utils/hamming.cpp +792 -0
  221. data/vendor/faiss/utils/hamming.h +220 -0
  222. data/vendor/faiss/utils/random.cpp +192 -0
  223. data/vendor/faiss/utils/random.h +60 -0
  224. data/vendor/faiss/utils/utils.cpp +783 -0
  225. data/vendor/faiss/utils/utils.h +181 -0
  226. metadata +216 -2
@@ -0,0 +1,175 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ // -*- c++ -*-
9
+
10
+ #ifndef INDEX_FLAT_H
11
+ #define INDEX_FLAT_H
12
+
13
+ #include <vector>
14
+
15
+ #include <faiss/Index.h>
16
+
17
+
18
+ namespace faiss {
19
+
20
+ /** Index that stores the full vectors and performs exhaustive search */
21
+ struct IndexFlat: Index {
22
+ /// database vectors, size ntotal * d
23
+ std::vector<float> xb;
24
+
25
+ explicit IndexFlat (idx_t d, MetricType metric = METRIC_L2);
26
+
27
+ void add(idx_t n, const float* x) override;
28
+
29
+ void reset() override;
30
+
31
+ void search(
32
+ idx_t n,
33
+ const float* x,
34
+ idx_t k,
35
+ float* distances,
36
+ idx_t* labels) const override;
37
+
38
+ void range_search(
39
+ idx_t n,
40
+ const float* x,
41
+ float radius,
42
+ RangeSearchResult* result) const override;
43
+
44
+ void reconstruct(idx_t key, float* recons) const override;
45
+
46
+ /** compute distance with a subset of vectors
47
+ *
48
+ * @param x query vectors, size n * d
49
+ * @param labels indices of the vectors that should be compared
50
+ * for each query vector, size n * k
51
+ * @param distances
52
+ * corresponding output distances, size n * k
53
+ */
54
+ void compute_distance_subset (
55
+ idx_t n,
56
+ const float *x,
57
+ idx_t k,
58
+ float *distances,
59
+ const idx_t *labels) const;
60
+
61
+ /** remove some ids. NB that Because of the structure of the
62
+ * indexing structure, the semantics of this operation are
63
+ * different from the usual ones: the new ids are shifted */
64
+ size_t remove_ids(const IDSelector& sel) override;
65
+
66
+ IndexFlat () {}
67
+
68
+ DistanceComputer * get_distance_computer() const override;
69
+
70
+ /* The stanadlone codec interface (just memcopies in this case) */
71
+ size_t sa_code_size () const override;
72
+
73
+ void sa_encode (idx_t n, const float *x,
74
+ uint8_t *bytes) const override;
75
+
76
+ void sa_decode (idx_t n, const uint8_t *bytes,
77
+ float *x) const override;
78
+
79
+ };
80
+
81
+
82
+
83
+ struct IndexFlatIP:IndexFlat {
84
+ explicit IndexFlatIP (idx_t d): IndexFlat (d, METRIC_INNER_PRODUCT) {}
85
+ IndexFlatIP () {}
86
+ };
87
+
88
+
89
+ struct IndexFlatL2:IndexFlat {
90
+ explicit IndexFlatL2 (idx_t d): IndexFlat (d, METRIC_L2) {}
91
+ IndexFlatL2 () {}
92
+ };
93
+
94
+
95
+ // same as an IndexFlatL2 but a value is subtracted from each distance
96
+ struct IndexFlatL2BaseShift: IndexFlatL2 {
97
+ std::vector<float> shift;
98
+
99
+ IndexFlatL2BaseShift (idx_t d, size_t nshift, const float *shift);
100
+
101
+ void search(
102
+ idx_t n,
103
+ const float* x,
104
+ idx_t k,
105
+ float* distances,
106
+ idx_t* labels) const override;
107
+ };
108
+
109
+
110
+ /** Index that queries in a base_index (a fast one) and refines the
111
+ * results with an exact search, hopefully improving the results.
112
+ */
113
+ struct IndexRefineFlat: Index {
114
+
115
+ /// storage for full vectors
116
+ IndexFlat refine_index;
117
+
118
+ /// faster index to pre-select the vectors that should be filtered
119
+ Index *base_index;
120
+ bool own_fields; ///< should the base index be deallocated?
121
+
122
+ /// factor between k requested in search and the k requested from
123
+ /// the base_index (should be >= 1)
124
+ float k_factor;
125
+
126
+ explicit IndexRefineFlat (Index *base_index);
127
+
128
+ IndexRefineFlat ();
129
+
130
+ void train(idx_t n, const float* x) override;
131
+
132
+ void add(idx_t n, const float* x) override;
133
+
134
+ void reset() override;
135
+
136
+ void search(
137
+ idx_t n,
138
+ const float* x,
139
+ idx_t k,
140
+ float* distances,
141
+ idx_t* labels) const override;
142
+
143
+ ~IndexRefineFlat() override;
144
+ };
145
+
146
+
147
+ /// optimized version for 1D "vectors"
148
+ struct IndexFlat1D:IndexFlatL2 {
149
+ bool continuous_update; ///< is the permutation updated continuously?
150
+
151
+ std::vector<idx_t> perm; ///< sorted database indices
152
+
153
+ explicit IndexFlat1D (bool continuous_update=true);
154
+
155
+ /// if not continuous_update, call this between the last add and
156
+ /// the first search
157
+ void update_permutation ();
158
+
159
+ void add(idx_t n, const float* x) override;
160
+
161
+ void reset() override;
162
+
163
+ /// Warn: the distances returned are L1 not L2
164
+ void search(
165
+ idx_t n,
166
+ const float* x,
167
+ idx_t k,
168
+ float* distances,
169
+ idx_t* labels) const override;
170
+ };
171
+
172
+
173
+ }
174
+
175
+ #endif
@@ -0,0 +1,1090 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ // -*- c++ -*-
9
+
10
+ #include <faiss/IndexHNSW.h>
11
+
12
+
13
+ #include <cstdlib>
14
+ #include <cassert>
15
+ #include <cstring>
16
+ #include <cstdio>
17
+ #include <cmath>
18
+ #include <omp.h>
19
+
20
+ #include <unordered_set>
21
+ #include <queue>
22
+
23
+ #include <sys/types.h>
24
+ #include <sys/stat.h>
25
+ #include <unistd.h>
26
+ #include <stdint.h>
27
+
28
+ #ifdef __SSE__
29
+ #include <immintrin.h>
30
+ #endif
31
+
32
+ #include <faiss/utils/distances.h>
33
+ #include <faiss/utils/random.h>
34
+ #include <faiss/utils/Heap.h>
35
+ #include <faiss/impl/FaissAssert.h>
36
+ #include <faiss/IndexFlat.h>
37
+ #include <faiss/IndexIVFPQ.h>
38
+ #include <faiss/Index2Layer.h>
39
+ #include <faiss/impl/AuxIndexStructures.h>
40
+
41
+
42
+ extern "C" {
43
+
44
+ /* declare BLAS functions, see http://www.netlib.org/clapack/cblas/ */
45
+
46
+ int sgemm_ (const char *transa, const char *transb, FINTEGER *m, FINTEGER *
47
+ n, FINTEGER *k, const float *alpha, const float *a,
48
+ FINTEGER *lda, const float *b, FINTEGER *
49
+ ldb, float *beta, float *c, FINTEGER *ldc);
50
+
51
+ }
52
+
53
+ namespace faiss {
54
+
55
+ using idx_t = Index::idx_t;
56
+ using MinimaxHeap = HNSW::MinimaxHeap;
57
+ using storage_idx_t = HNSW::storage_idx_t;
58
+ using NodeDistCloser = HNSW::NodeDistCloser;
59
+ using NodeDistFarther = HNSW::NodeDistFarther;
60
+
61
+ HNSWStats hnsw_stats;
62
+
63
+ /**************************************************************
64
+ * add / search blocks of descriptors
65
+ **************************************************************/
66
+
67
+ namespace {
68
+
69
+
70
+ void hnsw_add_vertices(IndexHNSW &index_hnsw,
71
+ size_t n0,
72
+ size_t n, const float *x,
73
+ bool verbose,
74
+ bool preset_levels = false) {
75
+ size_t d = index_hnsw.d;
76
+ HNSW & hnsw = index_hnsw.hnsw;
77
+ size_t ntotal = n0 + n;
78
+ double t0 = getmillisecs();
79
+ if (verbose) {
80
+ printf("hnsw_add_vertices: adding %ld elements on top of %ld "
81
+ "(preset_levels=%d)\n",
82
+ n, n0, int(preset_levels));
83
+ }
84
+
85
+ if (n == 0) {
86
+ return;
87
+ }
88
+
89
+ int max_level = hnsw.prepare_level_tab(n, preset_levels);
90
+
91
+ if (verbose) {
92
+ printf(" max_level = %d\n", max_level);
93
+ }
94
+
95
+ std::vector<omp_lock_t> locks(ntotal);
96
+ for(int i = 0; i < ntotal; i++)
97
+ omp_init_lock(&locks[i]);
98
+
99
+ // add vectors from highest to lowest level
100
+ std::vector<int> hist;
101
+ std::vector<int> order(n);
102
+
103
+ { // make buckets with vectors of the same level
104
+
105
+ // build histogram
106
+ for (int i = 0; i < n; i++) {
107
+ storage_idx_t pt_id = i + n0;
108
+ int pt_level = hnsw.levels[pt_id] - 1;
109
+ while (pt_level >= hist.size())
110
+ hist.push_back(0);
111
+ hist[pt_level] ++;
112
+ }
113
+
114
+ // accumulate
115
+ std::vector<int> offsets(hist.size() + 1, 0);
116
+ for (int i = 0; i < hist.size() - 1; i++) {
117
+ offsets[i + 1] = offsets[i] + hist[i];
118
+ }
119
+
120
+ // bucket sort
121
+ for (int i = 0; i < n; i++) {
122
+ storage_idx_t pt_id = i + n0;
123
+ int pt_level = hnsw.levels[pt_id] - 1;
124
+ order[offsets[pt_level]++] = pt_id;
125
+ }
126
+ }
127
+
128
+ idx_t check_period = InterruptCallback::get_period_hint
129
+ (max_level * index_hnsw.d * hnsw.efConstruction);
130
+
131
+ { // perform add
132
+ RandomGenerator rng2(789);
133
+
134
+ int i1 = n;
135
+
136
+ for (int pt_level = hist.size() - 1; pt_level >= 0; pt_level--) {
137
+ int i0 = i1 - hist[pt_level];
138
+
139
+ if (verbose) {
140
+ printf("Adding %d elements at level %d\n",
141
+ i1 - i0, pt_level);
142
+ }
143
+
144
+ // random permutation to get rid of dataset order bias
145
+ for (int j = i0; j < i1; j++)
146
+ std::swap(order[j], order[j + rng2.rand_int(i1 - j)]);
147
+
148
+ bool interrupt = false;
149
+
150
+ #pragma omp parallel if(i1 > i0 + 100)
151
+ {
152
+ VisitedTable vt (ntotal);
153
+
154
+ DistanceComputer *dis =
155
+ index_hnsw.storage->get_distance_computer();
156
+ ScopeDeleter1<DistanceComputer> del(dis);
157
+ int prev_display = verbose && omp_get_thread_num() == 0 ? 0 : -1;
158
+ size_t counter = 0;
159
+
160
+ #pragma omp for schedule(dynamic)
161
+ for (int i = i0; i < i1; i++) {
162
+ storage_idx_t pt_id = order[i];
163
+ dis->set_query (x + (pt_id - n0) * d);
164
+
165
+ // cannot break
166
+ if (interrupt) {
167
+ continue;
168
+ }
169
+
170
+ hnsw.add_with_locks(*dis, pt_level, pt_id, locks, vt);
171
+
172
+ if (prev_display >= 0 && i - i0 > prev_display + 10000) {
173
+ prev_display = i - i0;
174
+ printf(" %d / %d\r", i - i0, i1 - i0);
175
+ fflush(stdout);
176
+ }
177
+
178
+ if (counter % check_period == 0) {
179
+ if (InterruptCallback::is_interrupted ()) {
180
+ interrupt = true;
181
+ }
182
+ }
183
+ counter++;
184
+ }
185
+
186
+ }
187
+ if (interrupt) {
188
+ FAISS_THROW_MSG ("computation interrupted");
189
+ }
190
+ i1 = i0;
191
+ }
192
+ FAISS_ASSERT(i1 == 0);
193
+ }
194
+ if (verbose) {
195
+ printf("Done in %.3f ms\n", getmillisecs() - t0);
196
+ }
197
+
198
+ for(int i = 0; i < ntotal; i++) {
199
+ omp_destroy_lock(&locks[i]);
200
+ }
201
+ }
202
+
203
+
204
+ } // namespace
205
+
206
+
207
+
208
+
209
+ /**************************************************************
210
+ * IndexHNSW implementation
211
+ **************************************************************/
212
+
213
+ IndexHNSW::IndexHNSW(int d, int M):
214
+ Index(d, METRIC_L2),
215
+ hnsw(M),
216
+ own_fields(false),
217
+ storage(nullptr),
218
+ reconstruct_from_neighbors(nullptr)
219
+ {}
220
+
221
+ IndexHNSW::IndexHNSW(Index *storage, int M):
222
+ Index(storage->d, storage->metric_type),
223
+ hnsw(M),
224
+ own_fields(false),
225
+ storage(storage),
226
+ reconstruct_from_neighbors(nullptr)
227
+ {}
228
+
229
+ IndexHNSW::~IndexHNSW() {
230
+ if (own_fields) {
231
+ delete storage;
232
+ }
233
+ }
234
+
235
+ void IndexHNSW::train(idx_t n, const float* x)
236
+ {
237
+ FAISS_THROW_IF_NOT_MSG(storage,
238
+ "Please use IndexHSNWFlat (or variants) instead of IndexHNSW directly");
239
+ // hnsw structure does not require training
240
+ storage->train (n, x);
241
+ is_trained = true;
242
+ }
243
+
244
+ void IndexHNSW::search (idx_t n, const float *x, idx_t k,
245
+ float *distances, idx_t *labels) const
246
+
247
+ {
248
+ FAISS_THROW_IF_NOT_MSG(storage,
249
+ "Please use IndexHSNWFlat (or variants) instead of IndexHNSW directly");
250
+ size_t nreorder = 0;
251
+
252
+ idx_t check_period = InterruptCallback::get_period_hint (
253
+ hnsw.max_level * d * hnsw.efSearch);
254
+
255
+ for (idx_t i0 = 0; i0 < n; i0 += check_period) {
256
+ idx_t i1 = std::min(i0 + check_period, n);
257
+
258
+ #pragma omp parallel reduction(+ : nreorder)
259
+ {
260
+ VisitedTable vt (ntotal);
261
+ DistanceComputer *dis = storage->get_distance_computer();
262
+ ScopeDeleter1<DistanceComputer> del(dis);
263
+
264
+ #pragma omp for
265
+ for(idx_t i = i0; i < i1; i++) {
266
+ idx_t * idxi = labels + i * k;
267
+ float * simi = distances + i * k;
268
+ dis->set_query(x + i * d);
269
+
270
+ maxheap_heapify (k, simi, idxi);
271
+ hnsw.search(*dis, k, idxi, simi, vt);
272
+
273
+ maxheap_reorder (k, simi, idxi);
274
+
275
+ if (reconstruct_from_neighbors &&
276
+ reconstruct_from_neighbors->k_reorder != 0) {
277
+ int k_reorder = reconstruct_from_neighbors->k_reorder;
278
+ if (k_reorder == -1 || k_reorder > k) k_reorder = k;
279
+
280
+ nreorder += reconstruct_from_neighbors->compute_distances(
281
+ k_reorder, idxi, x + i * d, simi);
282
+
283
+ // sort top k_reorder
284
+ maxheap_heapify (k_reorder, simi, idxi, simi, idxi, k_reorder);
285
+ maxheap_reorder (k_reorder, simi, idxi);
286
+ }
287
+
288
+ }
289
+
290
+ }
291
+ InterruptCallback::check ();
292
+ }
293
+ hnsw_stats.nreorder += nreorder;
294
+ }
295
+
296
+
297
+ void IndexHNSW::add(idx_t n, const float *x)
298
+ {
299
+ FAISS_THROW_IF_NOT_MSG(storage,
300
+ "Please use IndexHSNWFlat (or variants) instead of IndexHNSW directly");
301
+ FAISS_THROW_IF_NOT(is_trained);
302
+ int n0 = ntotal;
303
+ storage->add(n, x);
304
+ ntotal = storage->ntotal;
305
+
306
+ hnsw_add_vertices (*this, n0, n, x, verbose,
307
+ hnsw.levels.size() == ntotal);
308
+ }
309
+
310
+ void IndexHNSW::reset()
311
+ {
312
+ hnsw.reset();
313
+ storage->reset();
314
+ ntotal = 0;
315
+ }
316
+
317
+ void IndexHNSW::reconstruct (idx_t key, float* recons) const
318
+ {
319
+ storage->reconstruct(key, recons);
320
+ }
321
+
322
+ void IndexHNSW::shrink_level_0_neighbors(int new_size)
323
+ {
324
+ #pragma omp parallel
325
+ {
326
+ DistanceComputer *dis = storage->get_distance_computer();
327
+ ScopeDeleter1<DistanceComputer> del(dis);
328
+
329
+ #pragma omp for
330
+ for (idx_t i = 0; i < ntotal; i++) {
331
+
332
+ size_t begin, end;
333
+ hnsw.neighbor_range(i, 0, &begin, &end);
334
+
335
+ std::priority_queue<NodeDistFarther> initial_list;
336
+
337
+ for (size_t j = begin; j < end; j++) {
338
+ int v1 = hnsw.neighbors[j];
339
+ if (v1 < 0) break;
340
+ initial_list.emplace(dis->symmetric_dis(i, v1), v1);
341
+
342
+ // initial_list.emplace(qdis(v1), v1);
343
+ }
344
+
345
+ std::vector<NodeDistFarther> shrunk_list;
346
+ HNSW::shrink_neighbor_list(*dis, initial_list,
347
+ shrunk_list, new_size);
348
+
349
+ for (size_t j = begin; j < end; j++) {
350
+ if (j - begin < shrunk_list.size())
351
+ hnsw.neighbors[j] = shrunk_list[j - begin].id;
352
+ else
353
+ hnsw.neighbors[j] = -1;
354
+ }
355
+ }
356
+ }
357
+
358
+ }
359
+
360
+ void IndexHNSW::search_level_0(
361
+ idx_t n, const float *x, idx_t k,
362
+ const storage_idx_t *nearest, const float *nearest_d,
363
+ float *distances, idx_t *labels, int nprobe,
364
+ int search_type) const
365
+ {
366
+
367
+ storage_idx_t ntotal = hnsw.levels.size();
368
+ #pragma omp parallel
369
+ {
370
+ DistanceComputer *qdis = storage->get_distance_computer();
371
+ ScopeDeleter1<DistanceComputer> del(qdis);
372
+
373
+ VisitedTable vt (ntotal);
374
+
375
+ #pragma omp for
376
+ for(idx_t i = 0; i < n; i++) {
377
+ idx_t * idxi = labels + i * k;
378
+ float * simi = distances + i * k;
379
+
380
+ qdis->set_query(x + i * d);
381
+ maxheap_heapify (k, simi, idxi);
382
+
383
+ if (search_type == 1) {
384
+
385
+ int nres = 0;
386
+
387
+ for(int j = 0; j < nprobe; j++) {
388
+ storage_idx_t cj = nearest[i * nprobe + j];
389
+
390
+ if (cj < 0) break;
391
+
392
+ if (vt.get(cj)) continue;
393
+
394
+ int candidates_size = std::max(hnsw.efSearch, int(k));
395
+ MinimaxHeap candidates(candidates_size);
396
+
397
+ candidates.push(cj, nearest_d[i * nprobe + j]);
398
+
399
+ nres = hnsw.search_from_candidates(
400
+ *qdis, k, idxi, simi,
401
+ candidates, vt, 0, nres
402
+ );
403
+ }
404
+ } else if (search_type == 2) {
405
+
406
+ int candidates_size = std::max(hnsw.efSearch, int(k));
407
+ candidates_size = std::max(candidates_size, nprobe);
408
+
409
+ MinimaxHeap candidates(candidates_size);
410
+ for(int j = 0; j < nprobe; j++) {
411
+ storage_idx_t cj = nearest[i * nprobe + j];
412
+
413
+ if (cj < 0) break;
414
+ candidates.push(cj, nearest_d[i * nprobe + j]);
415
+ }
416
+ hnsw.search_from_candidates(
417
+ *qdis, k, idxi, simi,
418
+ candidates, vt, 0
419
+ );
420
+
421
+ }
422
+ vt.advance();
423
+
424
+ maxheap_reorder (k, simi, idxi);
425
+
426
+ }
427
+ }
428
+
429
+
430
+ }
431
+
432
+ void IndexHNSW::init_level_0_from_knngraph(
433
+ int k, const float *D, const idx_t *I)
434
+ {
435
+ int dest_size = hnsw.nb_neighbors (0);
436
+
437
+ #pragma omp parallel for
438
+ for (idx_t i = 0; i < ntotal; i++) {
439
+ DistanceComputer *qdis = storage->get_distance_computer();
440
+ float vec[d];
441
+ storage->reconstruct(i, vec);
442
+ qdis->set_query(vec);
443
+
444
+ std::priority_queue<NodeDistFarther> initial_list;
445
+
446
+ for (size_t j = 0; j < k; j++) {
447
+ int v1 = I[i * k + j];
448
+ if (v1 == i) continue;
449
+ if (v1 < 0) break;
450
+ initial_list.emplace(D[i * k + j], v1);
451
+ }
452
+
453
+ std::vector<NodeDistFarther> shrunk_list;
454
+ HNSW::shrink_neighbor_list(*qdis, initial_list, shrunk_list, dest_size);
455
+
456
+ size_t begin, end;
457
+ hnsw.neighbor_range(i, 0, &begin, &end);
458
+
459
+ for (size_t j = begin; j < end; j++) {
460
+ if (j - begin < shrunk_list.size())
461
+ hnsw.neighbors[j] = shrunk_list[j - begin].id;
462
+ else
463
+ hnsw.neighbors[j] = -1;
464
+ }
465
+ }
466
+ }
467
+
468
+
469
+
470
+ void IndexHNSW::init_level_0_from_entry_points(
471
+ int n, const storage_idx_t *points,
472
+ const storage_idx_t *nearests)
473
+ {
474
+
475
+ std::vector<omp_lock_t> locks(ntotal);
476
+ for(int i = 0; i < ntotal; i++)
477
+ omp_init_lock(&locks[i]);
478
+
479
+ #pragma omp parallel
480
+ {
481
+ VisitedTable vt (ntotal);
482
+
483
+ DistanceComputer *dis = storage->get_distance_computer();
484
+ ScopeDeleter1<DistanceComputer> del(dis);
485
+ float vec[storage->d];
486
+
487
+ #pragma omp for schedule(dynamic)
488
+ for (int i = 0; i < n; i++) {
489
+ storage_idx_t pt_id = points[i];
490
+ storage_idx_t nearest = nearests[i];
491
+ storage->reconstruct (pt_id, vec);
492
+ dis->set_query (vec);
493
+
494
+ hnsw.add_links_starting_from(*dis, pt_id,
495
+ nearest, (*dis)(nearest),
496
+ 0, locks.data(), vt);
497
+
498
+ if (verbose && i % 10000 == 0) {
499
+ printf(" %d / %d\r", i, n);
500
+ fflush(stdout);
501
+ }
502
+ }
503
+ }
504
+ if (verbose) {
505
+ printf("\n");
506
+ }
507
+
508
+ for(int i = 0; i < ntotal; i++)
509
+ omp_destroy_lock(&locks[i]);
510
+ }
511
+
512
+ void IndexHNSW::reorder_links()
513
+ {
514
+ int M = hnsw.nb_neighbors(0);
515
+
516
+ #pragma omp parallel
517
+ {
518
+ std::vector<float> distances (M);
519
+ std::vector<size_t> order (M);
520
+ std::vector<storage_idx_t> tmp (M);
521
+ DistanceComputer *dis = storage->get_distance_computer();
522
+ ScopeDeleter1<DistanceComputer> del(dis);
523
+
524
+ #pragma omp for
525
+ for(storage_idx_t i = 0; i < ntotal; i++) {
526
+
527
+ size_t begin, end;
528
+ hnsw.neighbor_range(i, 0, &begin, &end);
529
+
530
+ for (size_t j = begin; j < end; j++) {
531
+ storage_idx_t nj = hnsw.neighbors[j];
532
+ if (nj < 0) {
533
+ end = j;
534
+ break;
535
+ }
536
+ distances[j - begin] = dis->symmetric_dis(i, nj);
537
+ tmp [j - begin] = nj;
538
+ }
539
+
540
+ fvec_argsort (end - begin, distances.data(), order.data());
541
+ for (size_t j = begin; j < end; j++) {
542
+ hnsw.neighbors[j] = tmp[order[j - begin]];
543
+ }
544
+ }
545
+
546
+ }
547
+ }
548
+
549
+
550
+ void IndexHNSW::link_singletons()
551
+ {
552
+ printf("search for singletons\n");
553
+
554
+ std::vector<bool> seen(ntotal);
555
+
556
+ for (size_t i = 0; i < ntotal; i++) {
557
+ size_t begin, end;
558
+ hnsw.neighbor_range(i, 0, &begin, &end);
559
+ for (size_t j = begin; j < end; j++) {
560
+ storage_idx_t ni = hnsw.neighbors[j];
561
+ if (ni >= 0) seen[ni] = true;
562
+ }
563
+ }
564
+
565
+ int n_sing = 0, n_sing_l1 = 0;
566
+ std::vector<storage_idx_t> singletons;
567
+ for (storage_idx_t i = 0; i < ntotal; i++) {
568
+ if (!seen[i]) {
569
+ singletons.push_back(i);
570
+ n_sing++;
571
+ if (hnsw.levels[i] > 1)
572
+ n_sing_l1++;
573
+ }
574
+ }
575
+
576
+ printf(" Found %d / %ld singletons (%d appear in a level above)\n",
577
+ n_sing, ntotal, n_sing_l1);
578
+
579
+ std::vector<float>recons(singletons.size() * d);
580
+ for (int i = 0; i < singletons.size(); i++) {
581
+
582
+ FAISS_ASSERT(!"not implemented");
583
+
584
+ }
585
+
586
+
587
+ }
588
+
589
+
590
+ /**************************************************************
591
+ * ReconstructFromNeighbors implementation
592
+ **************************************************************/
593
+
594
+
595
+ ReconstructFromNeighbors::ReconstructFromNeighbors(
596
+ const IndexHNSW & index, size_t k, size_t nsq):
597
+ index(index), k(k), nsq(nsq) {
598
+ M = index.hnsw.nb_neighbors(0);
599
+ FAISS_ASSERT(k <= 256);
600
+ code_size = k == 1 ? 0 : nsq;
601
+ ntotal = 0;
602
+ d = index.d;
603
+ FAISS_ASSERT(d % nsq == 0);
604
+ dsub = d / nsq;
605
+ k_reorder = -1;
606
+ }
607
+
608
+ void ReconstructFromNeighbors::reconstruct(storage_idx_t i, float *x, float *tmp) const
609
+ {
610
+
611
+
612
+ const HNSW & hnsw = index.hnsw;
613
+ size_t begin, end;
614
+ hnsw.neighbor_range(i, 0, &begin, &end);
615
+
616
+ if (k == 1 || nsq == 1) {
617
+ const float * beta;
618
+ if (k == 1) {
619
+ beta = codebook.data();
620
+ } else {
621
+ int idx = codes[i];
622
+ beta = codebook.data() + idx * (M + 1);
623
+ }
624
+
625
+ float w0 = beta[0]; // weight of image itself
626
+ index.storage->reconstruct(i, tmp);
627
+
628
+ for (int l = 0; l < d; l++)
629
+ x[l] = w0 * tmp[l];
630
+
631
+ for (size_t j = begin; j < end; j++) {
632
+
633
+ storage_idx_t ji = hnsw.neighbors[j];
634
+ if (ji < 0) ji = i;
635
+ float w = beta[j - begin + 1];
636
+ index.storage->reconstruct(ji, tmp);
637
+ for (int l = 0; l < d; l++)
638
+ x[l] += w * tmp[l];
639
+ }
640
+ } else if (nsq == 2) {
641
+ int idx0 = codes[2 * i];
642
+ int idx1 = codes[2 * i + 1];
643
+
644
+ const float *beta0 = codebook.data() + idx0 * (M + 1);
645
+ const float *beta1 = codebook.data() + (idx1 + k) * (M + 1);
646
+
647
+ index.storage->reconstruct(i, tmp);
648
+
649
+ float w0;
650
+
651
+ w0 = beta0[0];
652
+ for (int l = 0; l < dsub; l++)
653
+ x[l] = w0 * tmp[l];
654
+
655
+ w0 = beta1[0];
656
+ for (int l = dsub; l < d; l++)
657
+ x[l] = w0 * tmp[l];
658
+
659
+ for (size_t j = begin; j < end; j++) {
660
+ storage_idx_t ji = hnsw.neighbors[j];
661
+ if (ji < 0) ji = i;
662
+ index.storage->reconstruct(ji, tmp);
663
+ float w;
664
+ w = beta0[j - begin + 1];
665
+ for (int l = 0; l < dsub; l++)
666
+ x[l] += w * tmp[l];
667
+
668
+ w = beta1[j - begin + 1];
669
+ for (int l = dsub; l < d; l++)
670
+ x[l] += w * tmp[l];
671
+ }
672
+ } else {
673
+ const float *betas[nsq];
674
+ {
675
+ const float *b = codebook.data();
676
+ const uint8_t *c = &codes[i * code_size];
677
+ for (int sq = 0; sq < nsq; sq++) {
678
+ betas[sq] = b + (*c++) * (M + 1);
679
+ b += (M + 1) * k;
680
+ }
681
+ }
682
+
683
+ index.storage->reconstruct(i, tmp);
684
+ {
685
+ int d0 = 0;
686
+ for (int sq = 0; sq < nsq; sq++) {
687
+ float w = *(betas[sq]++);
688
+ int d1 = d0 + dsub;
689
+ for (int l = d0; l < d1; l++) {
690
+ x[l] = w * tmp[l];
691
+ }
692
+ d0 = d1;
693
+ }
694
+ }
695
+
696
+ for (size_t j = begin; j < end; j++) {
697
+ storage_idx_t ji = hnsw.neighbors[j];
698
+ if (ji < 0) ji = i;
699
+
700
+ index.storage->reconstruct(ji, tmp);
701
+ int d0 = 0;
702
+ for (int sq = 0; sq < nsq; sq++) {
703
+ float w = *(betas[sq]++);
704
+ int d1 = d0 + dsub;
705
+ for (int l = d0; l < d1; l++) {
706
+ x[l] += w * tmp[l];
707
+ }
708
+ d0 = d1;
709
+ }
710
+ }
711
+ }
712
+ }
713
+
714
+ void ReconstructFromNeighbors::reconstruct_n(storage_idx_t n0,
715
+ storage_idx_t ni,
716
+ float *x) const
717
+ {
718
+ #pragma omp parallel
719
+ {
720
+ std::vector<float> tmp(index.d);
721
+ #pragma omp for
722
+ for (storage_idx_t i = 0; i < ni; i++) {
723
+ reconstruct(n0 + i, x + i * index.d, tmp.data());
724
+ }
725
+ }
726
+ }
727
+
728
+ size_t ReconstructFromNeighbors::compute_distances(
729
+ size_t n, const idx_t *shortlist,
730
+ const float *query, float *distances) const
731
+ {
732
+ std::vector<float> tmp(2 * index.d);
733
+ size_t ncomp = 0;
734
+ for (int i = 0; i < n; i++) {
735
+ if (shortlist[i] < 0) break;
736
+ reconstruct(shortlist[i], tmp.data(), tmp.data() + index.d);
737
+ distances[i] = fvec_L2sqr(query, tmp.data(), index.d);
738
+ ncomp++;
739
+ }
740
+ return ncomp;
741
+ }
742
+
743
+ void ReconstructFromNeighbors::get_neighbor_table(storage_idx_t i, float *tmp1) const
744
+ {
745
+ const HNSW & hnsw = index.hnsw;
746
+ size_t begin, end;
747
+ hnsw.neighbor_range(i, 0, &begin, &end);
748
+ size_t d = index.d;
749
+
750
+ index.storage->reconstruct(i, tmp1);
751
+
752
+ for (size_t j = begin; j < end; j++) {
753
+ storage_idx_t ji = hnsw.neighbors[j];
754
+ if (ji < 0) ji = i;
755
+ index.storage->reconstruct(ji, tmp1 + (j - begin + 1) * d);
756
+ }
757
+
758
+ }
759
+
760
+
761
+ /// called by add_codes
762
+ void ReconstructFromNeighbors::estimate_code(
763
+ const float *x, storage_idx_t i, uint8_t *code) const
764
+ {
765
+
766
+ // fill in tmp table with the neighbor values
767
+ float *tmp1 = new float[d * (M + 1) + (d * k)];
768
+ float *tmp2 = tmp1 + d * (M + 1);
769
+ ScopeDeleter<float> del(tmp1);
770
+
771
+ // collect coordinates of base
772
+ get_neighbor_table (i, tmp1);
773
+
774
+ for (size_t sq = 0; sq < nsq; sq++) {
775
+ int d0 = sq * dsub;
776
+
777
+ {
778
+ FINTEGER ki = k, di = d, m1 = M + 1;
779
+ FINTEGER dsubi = dsub;
780
+ float zero = 0, one = 1;
781
+
782
+ sgemm_ ("N", "N", &dsubi, &ki, &m1, &one,
783
+ tmp1 + d0, &di,
784
+ codebook.data() + sq * (m1 * k), &m1,
785
+ &zero, tmp2, &dsubi);
786
+ }
787
+
788
+ float min = HUGE_VAL;
789
+ int argmin = -1;
790
+ for (size_t j = 0; j < k; j++) {
791
+ float dis = fvec_L2sqr(x + d0, tmp2 + j * dsub, dsub);
792
+ if (dis < min) {
793
+ min = dis;
794
+ argmin = j;
795
+ }
796
+ }
797
+ code[sq] = argmin;
798
+ }
799
+
800
+ }
801
+
802
+ void ReconstructFromNeighbors::add_codes(size_t n, const float *x)
803
+ {
804
+ if (k == 1) { // nothing to encode
805
+ ntotal += n;
806
+ return;
807
+ }
808
+ codes.resize(codes.size() + code_size * n);
809
+ #pragma omp parallel for
810
+ for (int i = 0; i < n; i++) {
811
+ estimate_code(x + i * index.d, ntotal + i,
812
+ codes.data() + (ntotal + i) * code_size);
813
+ }
814
+ ntotal += n;
815
+ FAISS_ASSERT (codes.size() == ntotal * code_size);
816
+ }
817
+
818
+
819
+ /**************************************************************
820
+ * IndexHNSWFlat implementation
821
+ **************************************************************/
822
+
823
+
824
+ IndexHNSWFlat::IndexHNSWFlat()
825
+ {
826
+ is_trained = true;
827
+ }
828
+
829
+ IndexHNSWFlat::IndexHNSWFlat(int d, int M):
830
+ IndexHNSW(new IndexFlatL2(d), M)
831
+ {
832
+ own_fields = true;
833
+ is_trained = true;
834
+ }
835
+
836
+
837
+ /**************************************************************
838
+ * IndexHNSWPQ implementation
839
+ **************************************************************/
840
+
841
+
842
+ IndexHNSWPQ::IndexHNSWPQ() {}
843
+
844
+ IndexHNSWPQ::IndexHNSWPQ(int d, int pq_m, int M):
845
+ IndexHNSW(new IndexPQ(d, pq_m, 8), M)
846
+ {
847
+ own_fields = true;
848
+ is_trained = false;
849
+ }
850
+
851
+ void IndexHNSWPQ::train(idx_t n, const float* x)
852
+ {
853
+ IndexHNSW::train (n, x);
854
+ (dynamic_cast<IndexPQ*> (storage))->pq.compute_sdc_table();
855
+ }
856
+
857
+
858
+ /**************************************************************
859
+ * IndexHNSWSQ implementation
860
+ **************************************************************/
861
+
862
+
863
+ IndexHNSWSQ::IndexHNSWSQ(int d, ScalarQuantizer::QuantizerType qtype, int M):
864
+ IndexHNSW (new IndexScalarQuantizer (d, qtype), M)
865
+ {
866
+ is_trained = false;
867
+ own_fields = true;
868
+ }
869
+
870
+ IndexHNSWSQ::IndexHNSWSQ() {}
871
+
872
+
873
+ /**************************************************************
874
+ * IndexHNSW2Level implementation
875
+ **************************************************************/
876
+
877
+
878
+ IndexHNSW2Level::IndexHNSW2Level(Index *quantizer, size_t nlist, int m_pq, int M):
879
+ IndexHNSW (new Index2Layer (quantizer, nlist, m_pq), M)
880
+ {
881
+ own_fields = true;
882
+ is_trained = false;
883
+ }
884
+
885
+ IndexHNSW2Level::IndexHNSW2Level() {}
886
+
887
+
888
+ namespace {
889
+
890
+
891
+ // same as search_from_candidates but uses v
892
+ // visno -> is in result list
893
+ // visno + 1 -> in result list + in candidates
894
+ int search_from_candidates_2(const HNSW & hnsw,
895
+ DistanceComputer & qdis, int k,
896
+ idx_t *I, float * D,
897
+ MinimaxHeap &candidates,
898
+ VisitedTable &vt,
899
+ int level, int nres_in = 0)
900
+ {
901
+ int nres = nres_in;
902
+ int ndis = 0;
903
+ for (int i = 0; i < candidates.size(); i++) {
904
+ idx_t v1 = candidates.ids[i];
905
+ FAISS_ASSERT(v1 >= 0);
906
+ vt.visited[v1] = vt.visno + 1;
907
+ }
908
+
909
+ int nstep = 0;
910
+
911
+ while (candidates.size() > 0) {
912
+ float d0 = 0;
913
+ int v0 = candidates.pop_min(&d0);
914
+
915
+ size_t begin, end;
916
+ hnsw.neighbor_range(v0, level, &begin, &end);
917
+
918
+ for (size_t j = begin; j < end; j++) {
919
+ int v1 = hnsw.neighbors[j];
920
+ if (v1 < 0) break;
921
+ if (vt.visited[v1] == vt.visno + 1) {
922
+ // nothing to do
923
+ } else {
924
+ ndis++;
925
+ float d = qdis(v1);
926
+ candidates.push(v1, d);
927
+
928
+ // never seen before --> add to heap
929
+ if (vt.visited[v1] < vt.visno) {
930
+ if (nres < k) {
931
+ faiss::maxheap_push (++nres, D, I, d, v1);
932
+ } else if (d < D[0]) {
933
+ faiss::maxheap_pop (nres--, D, I);
934
+ faiss::maxheap_push (++nres, D, I, d, v1);
935
+ }
936
+ }
937
+ vt.visited[v1] = vt.visno + 1;
938
+ }
939
+ }
940
+
941
+ nstep++;
942
+ if (nstep > hnsw.efSearch) {
943
+ break;
944
+ }
945
+ }
946
+
947
+ if (level == 0) {
948
+ #pragma omp critical
949
+ {
950
+ hnsw_stats.n1 ++;
951
+ if (candidates.size() == 0)
952
+ hnsw_stats.n2 ++;
953
+ }
954
+ }
955
+
956
+
957
+ return nres;
958
+ }
959
+
960
+
961
+ } // namespace
962
+
963
+ void IndexHNSW2Level::search (idx_t n, const float *x, idx_t k,
964
+ float *distances, idx_t *labels) const
965
+ {
966
+ if (dynamic_cast<const Index2Layer*>(storage)) {
967
+ IndexHNSW::search (n, x, k, distances, labels);
968
+
969
+ } else { // "mixed" search
970
+
971
+ const IndexIVFPQ *index_ivfpq =
972
+ dynamic_cast<const IndexIVFPQ*>(storage);
973
+
974
+ int nprobe = index_ivfpq->nprobe;
975
+
976
+ std::unique_ptr<idx_t[]> coarse_assign(new idx_t[n * nprobe]);
977
+ std::unique_ptr<float[]> coarse_dis(new float[n * nprobe]);
978
+
979
+ index_ivfpq->quantizer->search (n, x, nprobe, coarse_dis.get(),
980
+ coarse_assign.get());
981
+
982
+ index_ivfpq->search_preassigned (n, x, k, coarse_assign.get(),
983
+ coarse_dis.get(), distances, labels,
984
+ false);
985
+
986
+ #pragma omp parallel
987
+ {
988
+ VisitedTable vt (ntotal);
989
+ DistanceComputer *dis = storage->get_distance_computer();
990
+ ScopeDeleter1<DistanceComputer> del(dis);
991
+
992
+ int candidates_size = hnsw.upper_beam;
993
+ MinimaxHeap candidates(candidates_size);
994
+
995
+ #pragma omp for
996
+ for(idx_t i = 0; i < n; i++) {
997
+ idx_t * idxi = labels + i * k;
998
+ float * simi = distances + i * k;
999
+ dis->set_query(x + i * d);
1000
+
1001
+ // mark all inverted list elements as visited
1002
+
1003
+ for (int j = 0; j < nprobe; j++) {
1004
+ idx_t key = coarse_assign[j + i * nprobe];
1005
+ if (key < 0) break;
1006
+ size_t list_length = index_ivfpq->get_list_size (key);
1007
+ const idx_t * ids = index_ivfpq->invlists->get_ids (key);
1008
+
1009
+ for (int jj = 0; jj < list_length; jj++) {
1010
+ vt.set (ids[jj]);
1011
+ }
1012
+ }
1013
+
1014
+ candidates.clear();
1015
+ // copy the upper_beam elements to candidates list
1016
+
1017
+ int search_policy = 2;
1018
+
1019
+ if (search_policy == 1) {
1020
+
1021
+ for (int j = 0 ; j < hnsw.upper_beam && j < k; j++) {
1022
+ if (idxi[j] < 0) break;
1023
+ candidates.push (idxi[j], simi[j]);
1024
+ // search_from_candidates adds them back
1025
+ idxi[j] = -1;
1026
+ simi[j] = HUGE_VAL;
1027
+ }
1028
+
1029
+ // reorder from sorted to heap
1030
+ maxheap_heapify (k, simi, idxi, simi, idxi, k);
1031
+
1032
+ hnsw.search_from_candidates(
1033
+ *dis, k, idxi, simi,
1034
+ candidates, vt, 0, k
1035
+ );
1036
+
1037
+ vt.advance();
1038
+
1039
+ } else if (search_policy == 2) {
1040
+
1041
+ for (int j = 0 ; j < hnsw.upper_beam && j < k; j++) {
1042
+ if (idxi[j] < 0) break;
1043
+ candidates.push (idxi[j], simi[j]);
1044
+ }
1045
+
1046
+ // reorder from sorted to heap
1047
+ maxheap_heapify (k, simi, idxi, simi, idxi, k);
1048
+
1049
+ search_from_candidates_2 (
1050
+ hnsw, *dis, k, idxi, simi,
1051
+ candidates, vt, 0, k);
1052
+ vt.advance ();
1053
+ vt.advance ();
1054
+
1055
+ }
1056
+
1057
+ maxheap_reorder (k, simi, idxi);
1058
+ }
1059
+ }
1060
+ }
1061
+
1062
+
1063
+ }
1064
+
1065
+
1066
+ void IndexHNSW2Level::flip_to_ivf ()
1067
+ {
1068
+ Index2Layer *storage2l =
1069
+ dynamic_cast<Index2Layer*>(storage);
1070
+
1071
+ FAISS_THROW_IF_NOT (storage2l);
1072
+
1073
+ IndexIVFPQ * index_ivfpq =
1074
+ new IndexIVFPQ (storage2l->q1.quantizer,
1075
+ d, storage2l->q1.nlist,
1076
+ storage2l->pq.M, 8);
1077
+ index_ivfpq->pq = storage2l->pq;
1078
+ index_ivfpq->is_trained = storage2l->is_trained;
1079
+ index_ivfpq->precompute_table();
1080
+ index_ivfpq->own_fields = storage2l->q1.own_fields;
1081
+ storage2l->transfer_to_IVFPQ(*index_ivfpq);
1082
+ index_ivfpq->make_direct_map (true);
1083
+
1084
+ storage = index_ivfpq;
1085
+ delete storage2l;
1086
+
1087
+ }
1088
+
1089
+
1090
+ } // namespace faiss