faiss 0.1.0 → 0.1.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (226)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +5 -0
  3. data/README.md +103 -3
  4. data/ext/faiss/ext.cpp +99 -32
  5. data/ext/faiss/extconf.rb +12 -2
  6. data/lib/faiss/ext.bundle +0 -0
  7. data/lib/faiss/index.rb +3 -3
  8. data/lib/faiss/index_binary.rb +3 -3
  9. data/lib/faiss/kmeans.rb +1 -1
  10. data/lib/faiss/pca_matrix.rb +2 -2
  11. data/lib/faiss/product_quantizer.rb +3 -3
  12. data/lib/faiss/version.rb +1 -1
  13. data/vendor/faiss/AutoTune.cpp +719 -0
  14. data/vendor/faiss/AutoTune.h +212 -0
  15. data/vendor/faiss/Clustering.cpp +261 -0
  16. data/vendor/faiss/Clustering.h +101 -0
  17. data/vendor/faiss/IVFlib.cpp +339 -0
  18. data/vendor/faiss/IVFlib.h +132 -0
  19. data/vendor/faiss/Index.cpp +171 -0
  20. data/vendor/faiss/Index.h +261 -0
  21. data/vendor/faiss/Index2Layer.cpp +437 -0
  22. data/vendor/faiss/Index2Layer.h +85 -0
  23. data/vendor/faiss/IndexBinary.cpp +77 -0
  24. data/vendor/faiss/IndexBinary.h +163 -0
  25. data/vendor/faiss/IndexBinaryFlat.cpp +83 -0
  26. data/vendor/faiss/IndexBinaryFlat.h +54 -0
  27. data/vendor/faiss/IndexBinaryFromFloat.cpp +78 -0
  28. data/vendor/faiss/IndexBinaryFromFloat.h +52 -0
  29. data/vendor/faiss/IndexBinaryHNSW.cpp +325 -0
  30. data/vendor/faiss/IndexBinaryHNSW.h +56 -0
  31. data/vendor/faiss/IndexBinaryIVF.cpp +671 -0
  32. data/vendor/faiss/IndexBinaryIVF.h +211 -0
  33. data/vendor/faiss/IndexFlat.cpp +508 -0
  34. data/vendor/faiss/IndexFlat.h +175 -0
  35. data/vendor/faiss/IndexHNSW.cpp +1090 -0
  36. data/vendor/faiss/IndexHNSW.h +170 -0
  37. data/vendor/faiss/IndexIVF.cpp +909 -0
  38. data/vendor/faiss/IndexIVF.h +353 -0
  39. data/vendor/faiss/IndexIVFFlat.cpp +502 -0
  40. data/vendor/faiss/IndexIVFFlat.h +118 -0
  41. data/vendor/faiss/IndexIVFPQ.cpp +1207 -0
  42. data/vendor/faiss/IndexIVFPQ.h +161 -0
  43. data/vendor/faiss/IndexIVFPQR.cpp +219 -0
  44. data/vendor/faiss/IndexIVFPQR.h +65 -0
  45. data/vendor/faiss/IndexIVFSpectralHash.cpp +331 -0
  46. data/vendor/faiss/IndexIVFSpectralHash.h +75 -0
  47. data/vendor/faiss/IndexLSH.cpp +225 -0
  48. data/vendor/faiss/IndexLSH.h +87 -0
  49. data/vendor/faiss/IndexLattice.cpp +143 -0
  50. data/vendor/faiss/IndexLattice.h +68 -0
  51. data/vendor/faiss/IndexPQ.cpp +1188 -0
  52. data/vendor/faiss/IndexPQ.h +199 -0
  53. data/vendor/faiss/IndexPreTransform.cpp +288 -0
  54. data/vendor/faiss/IndexPreTransform.h +91 -0
  55. data/vendor/faiss/IndexReplicas.cpp +123 -0
  56. data/vendor/faiss/IndexReplicas.h +76 -0
  57. data/vendor/faiss/IndexScalarQuantizer.cpp +317 -0
  58. data/vendor/faiss/IndexScalarQuantizer.h +127 -0
  59. data/vendor/faiss/IndexShards.cpp +317 -0
  60. data/vendor/faiss/IndexShards.h +100 -0
  61. data/vendor/faiss/InvertedLists.cpp +623 -0
  62. data/vendor/faiss/InvertedLists.h +334 -0
  63. data/vendor/faiss/LICENSE +21 -0
  64. data/vendor/faiss/MatrixStats.cpp +252 -0
  65. data/vendor/faiss/MatrixStats.h +62 -0
  66. data/vendor/faiss/MetaIndexes.cpp +351 -0
  67. data/vendor/faiss/MetaIndexes.h +126 -0
  68. data/vendor/faiss/OnDiskInvertedLists.cpp +674 -0
  69. data/vendor/faiss/OnDiskInvertedLists.h +127 -0
  70. data/vendor/faiss/VectorTransform.cpp +1157 -0
  71. data/vendor/faiss/VectorTransform.h +322 -0
  72. data/vendor/faiss/c_api/AutoTune_c.cpp +83 -0
  73. data/vendor/faiss/c_api/AutoTune_c.h +64 -0
  74. data/vendor/faiss/c_api/Clustering_c.cpp +139 -0
  75. data/vendor/faiss/c_api/Clustering_c.h +117 -0
  76. data/vendor/faiss/c_api/IndexFlat_c.cpp +140 -0
  77. data/vendor/faiss/c_api/IndexFlat_c.h +115 -0
  78. data/vendor/faiss/c_api/IndexIVFFlat_c.cpp +64 -0
  79. data/vendor/faiss/c_api/IndexIVFFlat_c.h +58 -0
  80. data/vendor/faiss/c_api/IndexIVF_c.cpp +92 -0
  81. data/vendor/faiss/c_api/IndexIVF_c.h +135 -0
  82. data/vendor/faiss/c_api/IndexLSH_c.cpp +37 -0
  83. data/vendor/faiss/c_api/IndexLSH_c.h +40 -0
  84. data/vendor/faiss/c_api/IndexShards_c.cpp +44 -0
  85. data/vendor/faiss/c_api/IndexShards_c.h +42 -0
  86. data/vendor/faiss/c_api/Index_c.cpp +105 -0
  87. data/vendor/faiss/c_api/Index_c.h +183 -0
  88. data/vendor/faiss/c_api/MetaIndexes_c.cpp +49 -0
  89. data/vendor/faiss/c_api/MetaIndexes_c.h +49 -0
  90. data/vendor/faiss/c_api/clone_index_c.cpp +23 -0
  91. data/vendor/faiss/c_api/clone_index_c.h +32 -0
  92. data/vendor/faiss/c_api/error_c.h +42 -0
  93. data/vendor/faiss/c_api/error_impl.cpp +27 -0
  94. data/vendor/faiss/c_api/error_impl.h +16 -0
  95. data/vendor/faiss/c_api/faiss_c.h +58 -0
  96. data/vendor/faiss/c_api/gpu/GpuAutoTune_c.cpp +96 -0
  97. data/vendor/faiss/c_api/gpu/GpuAutoTune_c.h +56 -0
  98. data/vendor/faiss/c_api/gpu/GpuClonerOptions_c.cpp +52 -0
  99. data/vendor/faiss/c_api/gpu/GpuClonerOptions_c.h +68 -0
  100. data/vendor/faiss/c_api/gpu/GpuIndex_c.cpp +17 -0
  101. data/vendor/faiss/c_api/gpu/GpuIndex_c.h +30 -0
  102. data/vendor/faiss/c_api/gpu/GpuIndicesOptions_c.h +38 -0
  103. data/vendor/faiss/c_api/gpu/GpuResources_c.cpp +86 -0
  104. data/vendor/faiss/c_api/gpu/GpuResources_c.h +66 -0
  105. data/vendor/faiss/c_api/gpu/StandardGpuResources_c.cpp +54 -0
  106. data/vendor/faiss/c_api/gpu/StandardGpuResources_c.h +53 -0
  107. data/vendor/faiss/c_api/gpu/macros_impl.h +42 -0
  108. data/vendor/faiss/c_api/impl/AuxIndexStructures_c.cpp +220 -0
  109. data/vendor/faiss/c_api/impl/AuxIndexStructures_c.h +149 -0
  110. data/vendor/faiss/c_api/index_factory_c.cpp +26 -0
  111. data/vendor/faiss/c_api/index_factory_c.h +30 -0
  112. data/vendor/faiss/c_api/index_io_c.cpp +42 -0
  113. data/vendor/faiss/c_api/index_io_c.h +50 -0
  114. data/vendor/faiss/c_api/macros_impl.h +110 -0
  115. data/vendor/faiss/clone_index.cpp +147 -0
  116. data/vendor/faiss/clone_index.h +38 -0
  117. data/vendor/faiss/demos/demo_imi_flat.cpp +151 -0
  118. data/vendor/faiss/demos/demo_imi_pq.cpp +199 -0
  119. data/vendor/faiss/demos/demo_ivfpq_indexing.cpp +146 -0
  120. data/vendor/faiss/demos/demo_sift1M.cpp +252 -0
  121. data/vendor/faiss/gpu/GpuAutoTune.cpp +95 -0
  122. data/vendor/faiss/gpu/GpuAutoTune.h +27 -0
  123. data/vendor/faiss/gpu/GpuCloner.cpp +403 -0
  124. data/vendor/faiss/gpu/GpuCloner.h +82 -0
  125. data/vendor/faiss/gpu/GpuClonerOptions.cpp +28 -0
  126. data/vendor/faiss/gpu/GpuClonerOptions.h +53 -0
  127. data/vendor/faiss/gpu/GpuDistance.h +52 -0
  128. data/vendor/faiss/gpu/GpuFaissAssert.h +29 -0
  129. data/vendor/faiss/gpu/GpuIndex.h +148 -0
  130. data/vendor/faiss/gpu/GpuIndexBinaryFlat.h +89 -0
  131. data/vendor/faiss/gpu/GpuIndexFlat.h +190 -0
  132. data/vendor/faiss/gpu/GpuIndexIVF.h +89 -0
  133. data/vendor/faiss/gpu/GpuIndexIVFFlat.h +85 -0
  134. data/vendor/faiss/gpu/GpuIndexIVFPQ.h +143 -0
  135. data/vendor/faiss/gpu/GpuIndexIVFScalarQuantizer.h +100 -0
  136. data/vendor/faiss/gpu/GpuIndicesOptions.h +30 -0
  137. data/vendor/faiss/gpu/GpuResources.cpp +52 -0
  138. data/vendor/faiss/gpu/GpuResources.h +73 -0
  139. data/vendor/faiss/gpu/StandardGpuResources.cpp +295 -0
  140. data/vendor/faiss/gpu/StandardGpuResources.h +114 -0
  141. data/vendor/faiss/gpu/impl/RemapIndices.cpp +43 -0
  142. data/vendor/faiss/gpu/impl/RemapIndices.h +24 -0
  143. data/vendor/faiss/gpu/perf/IndexWrapper-inl.h +71 -0
  144. data/vendor/faiss/gpu/perf/IndexWrapper.h +39 -0
  145. data/vendor/faiss/gpu/perf/PerfClustering.cpp +115 -0
  146. data/vendor/faiss/gpu/perf/PerfIVFPQAdd.cpp +139 -0
  147. data/vendor/faiss/gpu/perf/WriteIndex.cpp +102 -0
  148. data/vendor/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +130 -0
  149. data/vendor/faiss/gpu/test/TestGpuIndexFlat.cpp +371 -0
  150. data/vendor/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +550 -0
  151. data/vendor/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +450 -0
  152. data/vendor/faiss/gpu/test/TestGpuMemoryException.cpp +84 -0
  153. data/vendor/faiss/gpu/test/TestUtils.cpp +315 -0
  154. data/vendor/faiss/gpu/test/TestUtils.h +93 -0
  155. data/vendor/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +159 -0
  156. data/vendor/faiss/gpu/utils/DeviceMemory.cpp +77 -0
  157. data/vendor/faiss/gpu/utils/DeviceMemory.h +71 -0
  158. data/vendor/faiss/gpu/utils/DeviceUtils.h +185 -0
  159. data/vendor/faiss/gpu/utils/MemorySpace.cpp +89 -0
  160. data/vendor/faiss/gpu/utils/MemorySpace.h +44 -0
  161. data/vendor/faiss/gpu/utils/StackDeviceMemory.cpp +239 -0
  162. data/vendor/faiss/gpu/utils/StackDeviceMemory.h +129 -0
  163. data/vendor/faiss/gpu/utils/StaticUtils.h +83 -0
  164. data/vendor/faiss/gpu/utils/Timer.cpp +60 -0
  165. data/vendor/faiss/gpu/utils/Timer.h +52 -0
  166. data/vendor/faiss/impl/AuxIndexStructures.cpp +305 -0
  167. data/vendor/faiss/impl/AuxIndexStructures.h +246 -0
  168. data/vendor/faiss/impl/FaissAssert.h +95 -0
  169. data/vendor/faiss/impl/FaissException.cpp +66 -0
  170. data/vendor/faiss/impl/FaissException.h +71 -0
  171. data/vendor/faiss/impl/HNSW.cpp +818 -0
  172. data/vendor/faiss/impl/HNSW.h +275 -0
  173. data/vendor/faiss/impl/PolysemousTraining.cpp +953 -0
  174. data/vendor/faiss/impl/PolysemousTraining.h +158 -0
  175. data/vendor/faiss/impl/ProductQuantizer.cpp +876 -0
  176. data/vendor/faiss/impl/ProductQuantizer.h +242 -0
  177. data/vendor/faiss/impl/ScalarQuantizer.cpp +1628 -0
  178. data/vendor/faiss/impl/ScalarQuantizer.h +120 -0
  179. data/vendor/faiss/impl/ThreadedIndex-inl.h +192 -0
  180. data/vendor/faiss/impl/ThreadedIndex.h +80 -0
  181. data/vendor/faiss/impl/index_read.cpp +793 -0
  182. data/vendor/faiss/impl/index_write.cpp +558 -0
  183. data/vendor/faiss/impl/io.cpp +142 -0
  184. data/vendor/faiss/impl/io.h +98 -0
  185. data/vendor/faiss/impl/lattice_Zn.cpp +712 -0
  186. data/vendor/faiss/impl/lattice_Zn.h +199 -0
  187. data/vendor/faiss/index_factory.cpp +392 -0
  188. data/vendor/faiss/index_factory.h +25 -0
  189. data/vendor/faiss/index_io.h +75 -0
  190. data/vendor/faiss/misc/test_blas.cpp +84 -0
  191. data/vendor/faiss/tests/test_binary_flat.cpp +64 -0
  192. data/vendor/faiss/tests/test_dealloc_invlists.cpp +183 -0
  193. data/vendor/faiss/tests/test_ivfpq_codec.cpp +67 -0
  194. data/vendor/faiss/tests/test_ivfpq_indexing.cpp +98 -0
  195. data/vendor/faiss/tests/test_lowlevel_ivf.cpp +566 -0
  196. data/vendor/faiss/tests/test_merge.cpp +258 -0
  197. data/vendor/faiss/tests/test_omp_threads.cpp +14 -0
  198. data/vendor/faiss/tests/test_ondisk_ivf.cpp +220 -0
  199. data/vendor/faiss/tests/test_pairs_decoding.cpp +189 -0
  200. data/vendor/faiss/tests/test_params_override.cpp +231 -0
  201. data/vendor/faiss/tests/test_pq_encoding.cpp +98 -0
  202. data/vendor/faiss/tests/test_sliding_ivf.cpp +240 -0
  203. data/vendor/faiss/tests/test_threaded_index.cpp +253 -0
  204. data/vendor/faiss/tests/test_transfer_invlists.cpp +159 -0
  205. data/vendor/faiss/tutorial/cpp/1-Flat.cpp +98 -0
  206. data/vendor/faiss/tutorial/cpp/2-IVFFlat.cpp +81 -0
  207. data/vendor/faiss/tutorial/cpp/3-IVFPQ.cpp +93 -0
  208. data/vendor/faiss/tutorial/cpp/4-GPU.cpp +119 -0
  209. data/vendor/faiss/tutorial/cpp/5-Multiple-GPUs.cpp +99 -0
  210. data/vendor/faiss/utils/Heap.cpp +122 -0
  211. data/vendor/faiss/utils/Heap.h +495 -0
  212. data/vendor/faiss/utils/WorkerThread.cpp +126 -0
  213. data/vendor/faiss/utils/WorkerThread.h +61 -0
  214. data/vendor/faiss/utils/distances.cpp +765 -0
  215. data/vendor/faiss/utils/distances.h +243 -0
  216. data/vendor/faiss/utils/distances_simd.cpp +809 -0
  217. data/vendor/faiss/utils/extra_distances.cpp +336 -0
  218. data/vendor/faiss/utils/extra_distances.h +54 -0
  219. data/vendor/faiss/utils/hamming-inl.h +472 -0
  220. data/vendor/faiss/utils/hamming.cpp +792 -0
  221. data/vendor/faiss/utils/hamming.h +220 -0
  222. data/vendor/faiss/utils/random.cpp +192 -0
  223. data/vendor/faiss/utils/random.h +60 -0
  224. data/vendor/faiss/utils/utils.cpp +783 -0
  225. data/vendor/faiss/utils/utils.h +181 -0
  226. metadata +216 -2
@@ -0,0 +1,175 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ // -*- c++ -*-
9
+
10
+ #ifndef INDEX_FLAT_H
11
+ #define INDEX_FLAT_H
12
+
13
+ #include <vector>
14
+
15
+ #include <faiss/Index.h>
16
+
17
+
18
+ namespace faiss {
19
+
20
+ /** Index that stores the full vectors and performs exhaustive search */
21
+ struct IndexFlat: Index {
22
+ /// database vectors, size ntotal * d
23
+ std::vector<float> xb;
24
+
25
+ explicit IndexFlat (idx_t d, MetricType metric = METRIC_L2);
26
+
27
+ void add(idx_t n, const float* x) override;
28
+
29
+ void reset() override;
30
+
31
+ void search(
32
+ idx_t n,
33
+ const float* x,
34
+ idx_t k,
35
+ float* distances,
36
+ idx_t* labels) const override;
37
+
38
+ void range_search(
39
+ idx_t n,
40
+ const float* x,
41
+ float radius,
42
+ RangeSearchResult* result) const override;
43
+
44
+ void reconstruct(idx_t key, float* recons) const override;
45
+
46
+ /** compute distance with a subset of vectors
47
+ *
48
+ * @param x query vectors, size n * d
49
+ * @param labels indices of the vectors that should be compared
50
+ * for each query vector, size n * k
51
+ * @param distances
52
+ * corresponding output distances, size n * k
53
+ */
54
+ void compute_distance_subset (
55
+ idx_t n,
56
+ const float *x,
57
+ idx_t k,
58
+ float *distances,
59
+ const idx_t *labels) const;
60
+
61
+ /** remove some ids. NB that Because of the structure of the
62
+ * indexing structure, the semantics of this operation are
63
+ * different from the usual ones: the new ids are shifted */
64
+ size_t remove_ids(const IDSelector& sel) override;
65
+
66
+ IndexFlat () {}
67
+
68
+ DistanceComputer * get_distance_computer() const override;
69
+
70
+ /* The stanadlone codec interface (just memcopies in this case) */
71
+ size_t sa_code_size () const override;
72
+
73
+ void sa_encode (idx_t n, const float *x,
74
+ uint8_t *bytes) const override;
75
+
76
+ void sa_decode (idx_t n, const uint8_t *bytes,
77
+ float *x) const override;
78
+
79
+ };
80
+
81
+
82
+
83
+ struct IndexFlatIP:IndexFlat {
84
+ explicit IndexFlatIP (idx_t d): IndexFlat (d, METRIC_INNER_PRODUCT) {}
85
+ IndexFlatIP () {}
86
+ };
87
+
88
+
89
+ struct IndexFlatL2:IndexFlat {
90
+ explicit IndexFlatL2 (idx_t d): IndexFlat (d, METRIC_L2) {}
91
+ IndexFlatL2 () {}
92
+ };
93
+
94
+
95
+ // same as an IndexFlatL2 but a value is subtracted from each distance
96
+ struct IndexFlatL2BaseShift: IndexFlatL2 {
97
+ std::vector<float> shift;
98
+
99
+ IndexFlatL2BaseShift (idx_t d, size_t nshift, const float *shift);
100
+
101
+ void search(
102
+ idx_t n,
103
+ const float* x,
104
+ idx_t k,
105
+ float* distances,
106
+ idx_t* labels) const override;
107
+ };
108
+
109
+
110
+ /** Index that queries in a base_index (a fast one) and refines the
111
+ * results with an exact search, hopefully improving the results.
112
+ */
113
+ struct IndexRefineFlat: Index {
114
+
115
+ /// storage for full vectors
116
+ IndexFlat refine_index;
117
+
118
+ /// faster index to pre-select the vectors that should be filtered
119
+ Index *base_index;
120
+ bool own_fields; ///< should the base index be deallocated?
121
+
122
+ /// factor between k requested in search and the k requested from
123
+ /// the base_index (should be >= 1)
124
+ float k_factor;
125
+
126
+ explicit IndexRefineFlat (Index *base_index);
127
+
128
+ IndexRefineFlat ();
129
+
130
+ void train(idx_t n, const float* x) override;
131
+
132
+ void add(idx_t n, const float* x) override;
133
+
134
+ void reset() override;
135
+
136
+ void search(
137
+ idx_t n,
138
+ const float* x,
139
+ idx_t k,
140
+ float* distances,
141
+ idx_t* labels) const override;
142
+
143
+ ~IndexRefineFlat() override;
144
+ };
145
+
146
+
147
+ /// optimized version for 1D "vectors"
148
+ struct IndexFlat1D:IndexFlatL2 {
149
+ bool continuous_update; ///< is the permutation updated continuously?
150
+
151
+ std::vector<idx_t> perm; ///< sorted database indices
152
+
153
+ explicit IndexFlat1D (bool continuous_update=true);
154
+
155
+ /// if not continuous_update, call this between the last add and
156
+ /// the first search
157
+ void update_permutation ();
158
+
159
+ void add(idx_t n, const float* x) override;
160
+
161
+ void reset() override;
162
+
163
+ /// Warn: the distances returned are L1 not L2
164
+ void search(
165
+ idx_t n,
166
+ const float* x,
167
+ idx_t k,
168
+ float* distances,
169
+ idx_t* labels) const override;
170
+ };
171
+
172
+
173
+ }
174
+
175
+ #endif
@@ -0,0 +1,1090 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ // -*- c++ -*-
9
+
10
+ #include <faiss/IndexHNSW.h>
11
+
12
+
13
+ #include <cstdlib>
14
+ #include <cassert>
15
+ #include <cstring>
16
+ #include <cstdio>
17
+ #include <cmath>
18
+ #include <omp.h>
19
+
20
+ #include <unordered_set>
21
+ #include <queue>
22
+
23
+ #include <sys/types.h>
24
+ #include <sys/stat.h>
25
+ #include <unistd.h>
26
+ #include <stdint.h>
27
+
28
+ #ifdef __SSE__
29
+ #include <immintrin.h>
30
+ #endif
31
+
32
+ #include <faiss/utils/distances.h>
33
+ #include <faiss/utils/random.h>
34
+ #include <faiss/utils/Heap.h>
35
+ #include <faiss/impl/FaissAssert.h>
36
+ #include <faiss/IndexFlat.h>
37
+ #include <faiss/IndexIVFPQ.h>
38
+ #include <faiss/Index2Layer.h>
39
+ #include <faiss/impl/AuxIndexStructures.h>
40
+
41
+
42
+ extern "C" {
43
+
44
+ /* Declare the Fortran-style BLAS single-precision matrix multiply
   (sgemm_; every argument, including scalars, is passed by pointer),
   see http://www.netlib.org/clapack/cblas/ .
   FINTEGER is not defined in this file -- presumably it comes from the
   build flags or an earlier include; TODO confirm.
   NOTE(review): sgemm_ is not referenced in the visible part of this file. */
45
+
46
+ int sgemm_ (const char *transa, const char *transb, FINTEGER *m, FINTEGER *
47
+ n, FINTEGER *k, const float *alpha, const float *a,
48
+ FINTEGER *lda, const float *b, FINTEGER *
49
+ ldb, float *beta, float *c, FINTEGER *ldc);
50
+
51
+ }
52
+
53
+ namespace faiss {
54
+
55
+ using idx_t = Index::idx_t;
56
+ using MinimaxHeap = HNSW::MinimaxHeap;
57
+ using storage_idx_t = HNSW::storage_idx_t;
58
+ using NodeDistCloser = HNSW::NodeDistCloser;
59
+ using NodeDistFarther = HNSW::NodeDistFarther;
60
+
61
+ HNSWStats hnsw_stats;
62
+
63
+ /**************************************************************
64
+ * add / search blocks of descriptors
65
+ **************************************************************/
66
+
67
+ namespace {
68
+
69
+
70
+ void hnsw_add_vertices(IndexHNSW &index_hnsw,
71
+ size_t n0,
72
+ size_t n, const float *x,
73
+ bool verbose,
74
+ bool preset_levels = false) {
75
+ size_t d = index_hnsw.d;
76
+ HNSW & hnsw = index_hnsw.hnsw;
77
+ size_t ntotal = n0 + n;
78
+ double t0 = getmillisecs();
79
+ if (verbose) {
80
+ printf("hnsw_add_vertices: adding %ld elements on top of %ld "
81
+ "(preset_levels=%d)\n",
82
+ n, n0, int(preset_levels));
83
+ }
84
+
85
+ if (n == 0) {
86
+ return;
87
+ }
88
+
89
+ int max_level = hnsw.prepare_level_tab(n, preset_levels);
90
+
91
+ if (verbose) {
92
+ printf(" max_level = %d\n", max_level);
93
+ }
94
+
95
+ std::vector<omp_lock_t> locks(ntotal);
96
+ for(int i = 0; i < ntotal; i++)
97
+ omp_init_lock(&locks[i]);
98
+
99
+ // add vectors from highest to lowest level
100
+ std::vector<int> hist;
101
+ std::vector<int> order(n);
102
+
103
+ { // make buckets with vectors of the same level
104
+
105
+ // build histogram
106
+ for (int i = 0; i < n; i++) {
107
+ storage_idx_t pt_id = i + n0;
108
+ int pt_level = hnsw.levels[pt_id] - 1;
109
+ while (pt_level >= hist.size())
110
+ hist.push_back(0);
111
+ hist[pt_level] ++;
112
+ }
113
+
114
+ // accumulate
115
+ std::vector<int> offsets(hist.size() + 1, 0);
116
+ for (int i = 0; i < hist.size() - 1; i++) {
117
+ offsets[i + 1] = offsets[i] + hist[i];
118
+ }
119
+
120
+ // bucket sort
121
+ for (int i = 0; i < n; i++) {
122
+ storage_idx_t pt_id = i + n0;
123
+ int pt_level = hnsw.levels[pt_id] - 1;
124
+ order[offsets[pt_level]++] = pt_id;
125
+ }
126
+ }
127
+
128
+ idx_t check_period = InterruptCallback::get_period_hint
129
+ (max_level * index_hnsw.d * hnsw.efConstruction);
130
+
131
+ { // perform add
132
+ RandomGenerator rng2(789);
133
+
134
+ int i1 = n;
135
+
136
+ for (int pt_level = hist.size() - 1; pt_level >= 0; pt_level--) {
137
+ int i0 = i1 - hist[pt_level];
138
+
139
+ if (verbose) {
140
+ printf("Adding %d elements at level %d\n",
141
+ i1 - i0, pt_level);
142
+ }
143
+
144
+ // random permutation to get rid of dataset order bias
145
+ for (int j = i0; j < i1; j++)
146
+ std::swap(order[j], order[j + rng2.rand_int(i1 - j)]);
147
+
148
+ bool interrupt = false;
149
+
150
+ #pragma omp parallel if(i1 > i0 + 100)
151
+ {
152
+ VisitedTable vt (ntotal);
153
+
154
+ DistanceComputer *dis =
155
+ index_hnsw.storage->get_distance_computer();
156
+ ScopeDeleter1<DistanceComputer> del(dis);
157
+ int prev_display = verbose && omp_get_thread_num() == 0 ? 0 : -1;
158
+ size_t counter = 0;
159
+
160
+ #pragma omp for schedule(dynamic)
161
+ for (int i = i0; i < i1; i++) {
162
+ storage_idx_t pt_id = order[i];
163
+ dis->set_query (x + (pt_id - n0) * d);
164
+
165
+ // cannot break
166
+ if (interrupt) {
167
+ continue;
168
+ }
169
+
170
+ hnsw.add_with_locks(*dis, pt_level, pt_id, locks, vt);
171
+
172
+ if (prev_display >= 0 && i - i0 > prev_display + 10000) {
173
+ prev_display = i - i0;
174
+ printf(" %d / %d\r", i - i0, i1 - i0);
175
+ fflush(stdout);
176
+ }
177
+
178
+ if (counter % check_period == 0) {
179
+ if (InterruptCallback::is_interrupted ()) {
180
+ interrupt = true;
181
+ }
182
+ }
183
+ counter++;
184
+ }
185
+
186
+ }
187
+ if (interrupt) {
188
+ FAISS_THROW_MSG ("computation interrupted");
189
+ }
190
+ i1 = i0;
191
+ }
192
+ FAISS_ASSERT(i1 == 0);
193
+ }
194
+ if (verbose) {
195
+ printf("Done in %.3f ms\n", getmillisecs() - t0);
196
+ }
197
+
198
+ for(int i = 0; i < ntotal; i++) {
199
+ omp_destroy_lock(&locks[i]);
200
+ }
201
+ }
202
+
203
+
204
+ } // namespace
205
+
206
+
207
+
208
+
209
+ /**************************************************************
210
+ * IndexHNSW implementation
211
+ **************************************************************/
212
+
213
+ IndexHNSW::IndexHNSW(int d, int M):
214
+ Index(d, METRIC_L2),
215
+ hnsw(M),
216
+ own_fields(false),
217
+ storage(nullptr),
218
+ reconstruct_from_neighbors(nullptr)
219
+ {}
220
+
221
+ IndexHNSW::IndexHNSW(Index *storage, int M):
222
+ Index(storage->d, storage->metric_type),
223
+ hnsw(M),
224
+ own_fields(false),
225
+ storage(storage),
226
+ reconstruct_from_neighbors(nullptr)
227
+ {}
228
+
229
+ IndexHNSW::~IndexHNSW() {
230
+ if (own_fields) {
231
+ delete storage;
232
+ }
233
+ }
234
+
235
+ void IndexHNSW::train(idx_t n, const float* x)
236
+ {
237
+ FAISS_THROW_IF_NOT_MSG(storage,
238
+ "Please use IndexHSNWFlat (or variants) instead of IndexHNSW directly");
239
+ // hnsw structure does not require training
240
+ storage->train (n, x);
241
+ is_trained = true;
242
+ }
243
+
244
+ void IndexHNSW::search (idx_t n, const float *x, idx_t k,
245
+ float *distances, idx_t *labels) const
246
+
247
+ {
248
+ FAISS_THROW_IF_NOT_MSG(storage,
249
+ "Please use IndexHSNWFlat (or variants) instead of IndexHNSW directly");
250
+ size_t nreorder = 0;
251
+
252
+ idx_t check_period = InterruptCallback::get_period_hint (
253
+ hnsw.max_level * d * hnsw.efSearch);
254
+
255
+ for (idx_t i0 = 0; i0 < n; i0 += check_period) {
256
+ idx_t i1 = std::min(i0 + check_period, n);
257
+
258
+ #pragma omp parallel reduction(+ : nreorder)
259
+ {
260
+ VisitedTable vt (ntotal);
261
+ DistanceComputer *dis = storage->get_distance_computer();
262
+ ScopeDeleter1<DistanceComputer> del(dis);
263
+
264
+ #pragma omp for
265
+ for(idx_t i = i0; i < i1; i++) {
266
+ idx_t * idxi = labels + i * k;
267
+ float * simi = distances + i * k;
268
+ dis->set_query(x + i * d);
269
+
270
+ maxheap_heapify (k, simi, idxi);
271
+ hnsw.search(*dis, k, idxi, simi, vt);
272
+
273
+ maxheap_reorder (k, simi, idxi);
274
+
275
+ if (reconstruct_from_neighbors &&
276
+ reconstruct_from_neighbors->k_reorder != 0) {
277
+ int k_reorder = reconstruct_from_neighbors->k_reorder;
278
+ if (k_reorder == -1 || k_reorder > k) k_reorder = k;
279
+
280
+ nreorder += reconstruct_from_neighbors->compute_distances(
281
+ k_reorder, idxi, x + i * d, simi);
282
+
283
+ // sort top k_reorder
284
+ maxheap_heapify (k_reorder, simi, idxi, simi, idxi, k_reorder);
285
+ maxheap_reorder (k_reorder, simi, idxi);
286
+ }
287
+
288
+ }
289
+
290
+ }
291
+ InterruptCallback::check ();
292
+ }
293
+ hnsw_stats.nreorder += nreorder;
294
+ }
295
+
296
+
297
+ void IndexHNSW::add(idx_t n, const float *x)
298
+ {
299
+ FAISS_THROW_IF_NOT_MSG(storage,
300
+ "Please use IndexHSNWFlat (or variants) instead of IndexHNSW directly");
301
+ FAISS_THROW_IF_NOT(is_trained);
302
+ int n0 = ntotal;
303
+ storage->add(n, x);
304
+ ntotal = storage->ntotal;
305
+
306
+ hnsw_add_vertices (*this, n0, n, x, verbose,
307
+ hnsw.levels.size() == ntotal);
308
+ }
309
+
310
+ void IndexHNSW::reset()
311
+ {
312
+ hnsw.reset();
313
+ storage->reset();
314
+ ntotal = 0;
315
+ }
316
+
317
+ void IndexHNSW::reconstruct (idx_t key, float* recons) const
318
+ {
319
+ storage->reconstruct(key, recons);
320
+ }
321
+
322
+ void IndexHNSW::shrink_level_0_neighbors(int new_size)
323
+ {
324
+ #pragma omp parallel
325
+ {
326
+ DistanceComputer *dis = storage->get_distance_computer();
327
+ ScopeDeleter1<DistanceComputer> del(dis);
328
+
329
+ #pragma omp for
330
+ for (idx_t i = 0; i < ntotal; i++) {
331
+
332
+ size_t begin, end;
333
+ hnsw.neighbor_range(i, 0, &begin, &end);
334
+
335
+ std::priority_queue<NodeDistFarther> initial_list;
336
+
337
+ for (size_t j = begin; j < end; j++) {
338
+ int v1 = hnsw.neighbors[j];
339
+ if (v1 < 0) break;
340
+ initial_list.emplace(dis->symmetric_dis(i, v1), v1);
341
+
342
+ // initial_list.emplace(qdis(v1), v1);
343
+ }
344
+
345
+ std::vector<NodeDistFarther> shrunk_list;
346
+ HNSW::shrink_neighbor_list(*dis, initial_list,
347
+ shrunk_list, new_size);
348
+
349
+ for (size_t j = begin; j < end; j++) {
350
+ if (j - begin < shrunk_list.size())
351
+ hnsw.neighbors[j] = shrunk_list[j - begin].id;
352
+ else
353
+ hnsw.neighbors[j] = -1;
354
+ }
355
+ }
356
+ }
357
+
358
+ }
359
+
360
+ void IndexHNSW::search_level_0(
361
+ idx_t n, const float *x, idx_t k,
362
+ const storage_idx_t *nearest, const float *nearest_d,
363
+ float *distances, idx_t *labels, int nprobe,
364
+ int search_type) const
365
+ {
366
+
367
+ storage_idx_t ntotal = hnsw.levels.size();
368
+ #pragma omp parallel
369
+ {
370
+ DistanceComputer *qdis = storage->get_distance_computer();
371
+ ScopeDeleter1<DistanceComputer> del(qdis);
372
+
373
+ VisitedTable vt (ntotal);
374
+
375
+ #pragma omp for
376
+ for(idx_t i = 0; i < n; i++) {
377
+ idx_t * idxi = labels + i * k;
378
+ float * simi = distances + i * k;
379
+
380
+ qdis->set_query(x + i * d);
381
+ maxheap_heapify (k, simi, idxi);
382
+
383
+ if (search_type == 1) {
384
+
385
+ int nres = 0;
386
+
387
+ for(int j = 0; j < nprobe; j++) {
388
+ storage_idx_t cj = nearest[i * nprobe + j];
389
+
390
+ if (cj < 0) break;
391
+
392
+ if (vt.get(cj)) continue;
393
+
394
+ int candidates_size = std::max(hnsw.efSearch, int(k));
395
+ MinimaxHeap candidates(candidates_size);
396
+
397
+ candidates.push(cj, nearest_d[i * nprobe + j]);
398
+
399
+ nres = hnsw.search_from_candidates(
400
+ *qdis, k, idxi, simi,
401
+ candidates, vt, 0, nres
402
+ );
403
+ }
404
+ } else if (search_type == 2) {
405
+
406
+ int candidates_size = std::max(hnsw.efSearch, int(k));
407
+ candidates_size = std::max(candidates_size, nprobe);
408
+
409
+ MinimaxHeap candidates(candidates_size);
410
+ for(int j = 0; j < nprobe; j++) {
411
+ storage_idx_t cj = nearest[i * nprobe + j];
412
+
413
+ if (cj < 0) break;
414
+ candidates.push(cj, nearest_d[i * nprobe + j]);
415
+ }
416
+ hnsw.search_from_candidates(
417
+ *qdis, k, idxi, simi,
418
+ candidates, vt, 0
419
+ );
420
+
421
+ }
422
+ vt.advance();
423
+
424
+ maxheap_reorder (k, simi, idxi);
425
+
426
+ }
427
+ }
428
+
429
+
430
+ }
431
+
432
+ void IndexHNSW::init_level_0_from_knngraph(
433
+ int k, const float *D, const idx_t *I)
434
+ {
435
+ int dest_size = hnsw.nb_neighbors (0);
436
+
437
+ #pragma omp parallel for
438
+ for (idx_t i = 0; i < ntotal; i++) {
439
+ DistanceComputer *qdis = storage->get_distance_computer();
440
+ float vec[d];
441
+ storage->reconstruct(i, vec);
442
+ qdis->set_query(vec);
443
+
444
+ std::priority_queue<NodeDistFarther> initial_list;
445
+
446
+ for (size_t j = 0; j < k; j++) {
447
+ int v1 = I[i * k + j];
448
+ if (v1 == i) continue;
449
+ if (v1 < 0) break;
450
+ initial_list.emplace(D[i * k + j], v1);
451
+ }
452
+
453
+ std::vector<NodeDistFarther> shrunk_list;
454
+ HNSW::shrink_neighbor_list(*qdis, initial_list, shrunk_list, dest_size);
455
+
456
+ size_t begin, end;
457
+ hnsw.neighbor_range(i, 0, &begin, &end);
458
+
459
+ for (size_t j = begin; j < end; j++) {
460
+ if (j - begin < shrunk_list.size())
461
+ hnsw.neighbors[j] = shrunk_list[j - begin].id;
462
+ else
463
+ hnsw.neighbors[j] = -1;
464
+ }
465
+ }
466
+ }
467
+
468
+
469
+
470
+ void IndexHNSW::init_level_0_from_entry_points(
471
+ int n, const storage_idx_t *points,
472
+ const storage_idx_t *nearests)
473
+ {
474
+
475
+ std::vector<omp_lock_t> locks(ntotal);
476
+ for(int i = 0; i < ntotal; i++)
477
+ omp_init_lock(&locks[i]);
478
+
479
+ #pragma omp parallel
480
+ {
481
+ VisitedTable vt (ntotal);
482
+
483
+ DistanceComputer *dis = storage->get_distance_computer();
484
+ ScopeDeleter1<DistanceComputer> del(dis);
485
+ float vec[storage->d];
486
+
487
+ #pragma omp for schedule(dynamic)
488
+ for (int i = 0; i < n; i++) {
489
+ storage_idx_t pt_id = points[i];
490
+ storage_idx_t nearest = nearests[i];
491
+ storage->reconstruct (pt_id, vec);
492
+ dis->set_query (vec);
493
+
494
+ hnsw.add_links_starting_from(*dis, pt_id,
495
+ nearest, (*dis)(nearest),
496
+ 0, locks.data(), vt);
497
+
498
+ if (verbose && i % 10000 == 0) {
499
+ printf(" %d / %d\r", i, n);
500
+ fflush(stdout);
501
+ }
502
+ }
503
+ }
504
+ if (verbose) {
505
+ printf("\n");
506
+ }
507
+
508
+ for(int i = 0; i < ntotal; i++)
509
+ omp_destroy_lock(&locks[i]);
510
+ }
511
+
512
+ void IndexHNSW::reorder_links()
513
+ {
514
+ int M = hnsw.nb_neighbors(0);
515
+
516
+ #pragma omp parallel
517
+ {
518
+ std::vector<float> distances (M);
519
+ std::vector<size_t> order (M);
520
+ std::vector<storage_idx_t> tmp (M);
521
+ DistanceComputer *dis = storage->get_distance_computer();
522
+ ScopeDeleter1<DistanceComputer> del(dis);
523
+
524
+ #pragma omp for
525
+ for(storage_idx_t i = 0; i < ntotal; i++) {
526
+
527
+ size_t begin, end;
528
+ hnsw.neighbor_range(i, 0, &begin, &end);
529
+
530
+ for (size_t j = begin; j < end; j++) {
531
+ storage_idx_t nj = hnsw.neighbors[j];
532
+ if (nj < 0) {
533
+ end = j;
534
+ break;
535
+ }
536
+ distances[j - begin] = dis->symmetric_dis(i, nj);
537
+ tmp [j - begin] = nj;
538
+ }
539
+
540
+ fvec_argsort (end - begin, distances.data(), order.data());
541
+ for (size_t j = begin; j < end; j++) {
542
+ hnsw.neighbors[j] = tmp[order[j - begin]];
543
+ }
544
+ }
545
+
546
+ }
547
+ }
548
+
549
+
550
+ void IndexHNSW::link_singletons()
551
+ {
552
+ printf("search for singletons\n");
553
+
554
+ std::vector<bool> seen(ntotal);
555
+
556
+ for (size_t i = 0; i < ntotal; i++) {
557
+ size_t begin, end;
558
+ hnsw.neighbor_range(i, 0, &begin, &end);
559
+ for (size_t j = begin; j < end; j++) {
560
+ storage_idx_t ni = hnsw.neighbors[j];
561
+ if (ni >= 0) seen[ni] = true;
562
+ }
563
+ }
564
+
565
+ int n_sing = 0, n_sing_l1 = 0;
566
+ std::vector<storage_idx_t> singletons;
567
+ for (storage_idx_t i = 0; i < ntotal; i++) {
568
+ if (!seen[i]) {
569
+ singletons.push_back(i);
570
+ n_sing++;
571
+ if (hnsw.levels[i] > 1)
572
+ n_sing_l1++;
573
+ }
574
+ }
575
+
576
+ printf(" Found %d / %ld singletons (%d appear in a level above)\n",
577
+ n_sing, ntotal, n_sing_l1);
578
+
579
+ std::vector<float>recons(singletons.size() * d);
580
+ for (int i = 0; i < singletons.size(); i++) {
581
+
582
+ FAISS_ASSERT(!"not implemented");
583
+
584
+ }
585
+
586
+
587
+ }
588
+
589
+
590
+ /**************************************************************
591
+ * ReconstructFromNeighbors implementation
592
+ **************************************************************/
593
+
594
+
595
+ ReconstructFromNeighbors::ReconstructFromNeighbors(
596
+ const IndexHNSW & index, size_t k, size_t nsq):
597
+ index(index), k(k), nsq(nsq) {
598
+ M = index.hnsw.nb_neighbors(0);
599
+ FAISS_ASSERT(k <= 256);
600
+ code_size = k == 1 ? 0 : nsq;
601
+ ntotal = 0;
602
+ d = index.d;
603
+ FAISS_ASSERT(d % nsq == 0);
604
+ dsub = d / nsq;
605
+ k_reorder = -1;
606
+ }
607
+
608
+ void ReconstructFromNeighbors::reconstruct(storage_idx_t i, float *x, float *tmp) const
609
+ {
610
+
611
+
612
+ const HNSW & hnsw = index.hnsw;
613
+ size_t begin, end;
614
+ hnsw.neighbor_range(i, 0, &begin, &end);
615
+
616
+ if (k == 1 || nsq == 1) {
617
+ const float * beta;
618
+ if (k == 1) {
619
+ beta = codebook.data();
620
+ } else {
621
+ int idx = codes[i];
622
+ beta = codebook.data() + idx * (M + 1);
623
+ }
624
+
625
+ float w0 = beta[0]; // weight of image itself
626
+ index.storage->reconstruct(i, tmp);
627
+
628
+ for (int l = 0; l < d; l++)
629
+ x[l] = w0 * tmp[l];
630
+
631
+ for (size_t j = begin; j < end; j++) {
632
+
633
+ storage_idx_t ji = hnsw.neighbors[j];
634
+ if (ji < 0) ji = i;
635
+ float w = beta[j - begin + 1];
636
+ index.storage->reconstruct(ji, tmp);
637
+ for (int l = 0; l < d; l++)
638
+ x[l] += w * tmp[l];
639
+ }
640
+ } else if (nsq == 2) {
641
+ int idx0 = codes[2 * i];
642
+ int idx1 = codes[2 * i + 1];
643
+
644
+ const float *beta0 = codebook.data() + idx0 * (M + 1);
645
+ const float *beta1 = codebook.data() + (idx1 + k) * (M + 1);
646
+
647
+ index.storage->reconstruct(i, tmp);
648
+
649
+ float w0;
650
+
651
+ w0 = beta0[0];
652
+ for (int l = 0; l < dsub; l++)
653
+ x[l] = w0 * tmp[l];
654
+
655
+ w0 = beta1[0];
656
+ for (int l = dsub; l < d; l++)
657
+ x[l] = w0 * tmp[l];
658
+
659
+ for (size_t j = begin; j < end; j++) {
660
+ storage_idx_t ji = hnsw.neighbors[j];
661
+ if (ji < 0) ji = i;
662
+ index.storage->reconstruct(ji, tmp);
663
+ float w;
664
+ w = beta0[j - begin + 1];
665
+ for (int l = 0; l < dsub; l++)
666
+ x[l] += w * tmp[l];
667
+
668
+ w = beta1[j - begin + 1];
669
+ for (int l = dsub; l < d; l++)
670
+ x[l] += w * tmp[l];
671
+ }
672
+ } else {
673
+ const float *betas[nsq];
674
+ {
675
+ const float *b = codebook.data();
676
+ const uint8_t *c = &codes[i * code_size];
677
+ for (int sq = 0; sq < nsq; sq++) {
678
+ betas[sq] = b + (*c++) * (M + 1);
679
+ b += (M + 1) * k;
680
+ }
681
+ }
682
+
683
+ index.storage->reconstruct(i, tmp);
684
+ {
685
+ int d0 = 0;
686
+ for (int sq = 0; sq < nsq; sq++) {
687
+ float w = *(betas[sq]++);
688
+ int d1 = d0 + dsub;
689
+ for (int l = d0; l < d1; l++) {
690
+ x[l] = w * tmp[l];
691
+ }
692
+ d0 = d1;
693
+ }
694
+ }
695
+
696
+ for (size_t j = begin; j < end; j++) {
697
+ storage_idx_t ji = hnsw.neighbors[j];
698
+ if (ji < 0) ji = i;
699
+
700
+ index.storage->reconstruct(ji, tmp);
701
+ int d0 = 0;
702
+ for (int sq = 0; sq < nsq; sq++) {
703
+ float w = *(betas[sq]++);
704
+ int d1 = d0 + dsub;
705
+ for (int l = d0; l < d1; l++) {
706
+ x[l] += w * tmp[l];
707
+ }
708
+ d0 = d1;
709
+ }
710
+ }
711
+ }
712
+ }
713
+
714
+ void ReconstructFromNeighbors::reconstruct_n(storage_idx_t n0,
715
+ storage_idx_t ni,
716
+ float *x) const
717
+ {
718
+ #pragma omp parallel
719
+ {
720
+ std::vector<float> tmp(index.d);
721
+ #pragma omp for
722
+ for (storage_idx_t i = 0; i < ni; i++) {
723
+ reconstruct(n0 + i, x + i * index.d, tmp.data());
724
+ }
725
+ }
726
+ }
727
+
728
+ size_t ReconstructFromNeighbors::compute_distances(
729
+ size_t n, const idx_t *shortlist,
730
+ const float *query, float *distances) const
731
+ {
732
+ std::vector<float> tmp(2 * index.d);
733
+ size_t ncomp = 0;
734
+ for (int i = 0; i < n; i++) {
735
+ if (shortlist[i] < 0) break;
736
+ reconstruct(shortlist[i], tmp.data(), tmp.data() + index.d);
737
+ distances[i] = fvec_L2sqr(query, tmp.data(), index.d);
738
+ ncomp++;
739
+ }
740
+ return ncomp;
741
+ }
742
+
743
+ void ReconstructFromNeighbors::get_neighbor_table(storage_idx_t i, float *tmp1) const
744
+ {
745
+ const HNSW & hnsw = index.hnsw;
746
+ size_t begin, end;
747
+ hnsw.neighbor_range(i, 0, &begin, &end);
748
+ size_t d = index.d;
749
+
750
+ index.storage->reconstruct(i, tmp1);
751
+
752
+ for (size_t j = begin; j < end; j++) {
753
+ storage_idx_t ji = hnsw.neighbors[j];
754
+ if (ji < 0) ji = i;
755
+ index.storage->reconstruct(ji, tmp1 + (j - begin + 1) * d);
756
+ }
757
+
758
+ }
759
+
760
+
761
+ /// called by add_codes
762
+ void ReconstructFromNeighbors::estimate_code(
763
+ const float *x, storage_idx_t i, uint8_t *code) const
764
+ {
765
+
766
+ // fill in tmp table with the neighbor values
767
+ float *tmp1 = new float[d * (M + 1) + (d * k)];
768
+ float *tmp2 = tmp1 + d * (M + 1);
769
+ ScopeDeleter<float> del(tmp1);
770
+
771
+ // collect coordinates of base
772
+ get_neighbor_table (i, tmp1);
773
+
774
+ for (size_t sq = 0; sq < nsq; sq++) {
775
+ int d0 = sq * dsub;
776
+
777
+ {
778
+ FINTEGER ki = k, di = d, m1 = M + 1;
779
+ FINTEGER dsubi = dsub;
780
+ float zero = 0, one = 1;
781
+
782
+ sgemm_ ("N", "N", &dsubi, &ki, &m1, &one,
783
+ tmp1 + d0, &di,
784
+ codebook.data() + sq * (m1 * k), &m1,
785
+ &zero, tmp2, &dsubi);
786
+ }
787
+
788
+ float min = HUGE_VAL;
789
+ int argmin = -1;
790
+ for (size_t j = 0; j < k; j++) {
791
+ float dis = fvec_L2sqr(x + d0, tmp2 + j * dsub, dsub);
792
+ if (dis < min) {
793
+ min = dis;
794
+ argmin = j;
795
+ }
796
+ }
797
+ code[sq] = argmin;
798
+ }
799
+
800
+ }
801
+
802
+ void ReconstructFromNeighbors::add_codes(size_t n, const float *x)
803
+ {
804
+ if (k == 1) { // nothing to encode
805
+ ntotal += n;
806
+ return;
807
+ }
808
+ codes.resize(codes.size() + code_size * n);
809
+ #pragma omp parallel for
810
+ for (int i = 0; i < n; i++) {
811
+ estimate_code(x + i * index.d, ntotal + i,
812
+ codes.data() + (ntotal + i) * code_size);
813
+ }
814
+ ntotal += n;
815
+ FAISS_ASSERT (codes.size() == ntotal * code_size);
816
+ }
817
+
818
+
819
+ /**************************************************************
820
+ * IndexHNSWFlat implementation
821
+ **************************************************************/
822
+
823
+
824
+ IndexHNSWFlat::IndexHNSWFlat()
825
+ {
826
+ is_trained = true;
827
+ }
828
+
829
+ IndexHNSWFlat::IndexHNSWFlat(int d, int M):
830
+ IndexHNSW(new IndexFlatL2(d), M)
831
+ {
832
+ own_fields = true;
833
+ is_trained = true;
834
+ }
835
+
836
+
837
+ /**************************************************************
838
+ * IndexHNSWPQ implementation
839
+ **************************************************************/
840
+
841
+
842
+ IndexHNSWPQ::IndexHNSWPQ() {}
843
+
844
+ IndexHNSWPQ::IndexHNSWPQ(int d, int pq_m, int M):
845
+ IndexHNSW(new IndexPQ(d, pq_m, 8), M)
846
+ {
847
+ own_fields = true;
848
+ is_trained = false;
849
+ }
850
+
851
+ void IndexHNSWPQ::train(idx_t n, const float* x)
852
+ {
853
+ IndexHNSW::train (n, x);
854
+ (dynamic_cast<IndexPQ*> (storage))->pq.compute_sdc_table();
855
+ }
856
+
857
+
858
+ /**************************************************************
859
+ * IndexHNSWSQ implementation
860
+ **************************************************************/
861
+
862
+
863
+ IndexHNSWSQ::IndexHNSWSQ(int d, ScalarQuantizer::QuantizerType qtype, int M):
864
+ IndexHNSW (new IndexScalarQuantizer (d, qtype), M)
865
+ {
866
+ is_trained = false;
867
+ own_fields = true;
868
+ }
869
+
870
+ IndexHNSWSQ::IndexHNSWSQ() {}
871
+
872
+
873
+ /**************************************************************
874
+ * IndexHNSW2Level implementation
875
+ **************************************************************/
876
+
877
+
878
+ IndexHNSW2Level::IndexHNSW2Level(Index *quantizer, size_t nlist, int m_pq, int M):
879
+ IndexHNSW (new Index2Layer (quantizer, nlist, m_pq), M)
880
+ {
881
+ own_fields = true;
882
+ is_trained = false;
883
+ }
884
+
885
+ IndexHNSW2Level::IndexHNSW2Level() {}
886
+
887
+
888
+ namespace {
889
+
890
+
891
+ // same as search_from_candidates but uses v
892
+ // visno -> is in result list
893
+ // visno + 1 -> in result list + in candidates
894
+ int search_from_candidates_2(const HNSW & hnsw,
895
+ DistanceComputer & qdis, int k,
896
+ idx_t *I, float * D,
897
+ MinimaxHeap &candidates,
898
+ VisitedTable &vt,
899
+ int level, int nres_in = 0)
900
+ {
901
+ int nres = nres_in;
902
+ int ndis = 0;
903
+ for (int i = 0; i < candidates.size(); i++) {
904
+ idx_t v1 = candidates.ids[i];
905
+ FAISS_ASSERT(v1 >= 0);
906
+ vt.visited[v1] = vt.visno + 1;
907
+ }
908
+
909
+ int nstep = 0;
910
+
911
+ while (candidates.size() > 0) {
912
+ float d0 = 0;
913
+ int v0 = candidates.pop_min(&d0);
914
+
915
+ size_t begin, end;
916
+ hnsw.neighbor_range(v0, level, &begin, &end);
917
+
918
+ for (size_t j = begin; j < end; j++) {
919
+ int v1 = hnsw.neighbors[j];
920
+ if (v1 < 0) break;
921
+ if (vt.visited[v1] == vt.visno + 1) {
922
+ // nothing to do
923
+ } else {
924
+ ndis++;
925
+ float d = qdis(v1);
926
+ candidates.push(v1, d);
927
+
928
+ // never seen before --> add to heap
929
+ if (vt.visited[v1] < vt.visno) {
930
+ if (nres < k) {
931
+ faiss::maxheap_push (++nres, D, I, d, v1);
932
+ } else if (d < D[0]) {
933
+ faiss::maxheap_pop (nres--, D, I);
934
+ faiss::maxheap_push (++nres, D, I, d, v1);
935
+ }
936
+ }
937
+ vt.visited[v1] = vt.visno + 1;
938
+ }
939
+ }
940
+
941
+ nstep++;
942
+ if (nstep > hnsw.efSearch) {
943
+ break;
944
+ }
945
+ }
946
+
947
+ if (level == 0) {
948
+ #pragma omp critical
949
+ {
950
+ hnsw_stats.n1 ++;
951
+ if (candidates.size() == 0)
952
+ hnsw_stats.n2 ++;
953
+ }
954
+ }
955
+
956
+
957
+ return nres;
958
+ }
959
+
960
+
961
+ } // namespace
962
+
963
+ void IndexHNSW2Level::search (idx_t n, const float *x, idx_t k,
964
+ float *distances, idx_t *labels) const
965
+ {
966
+ if (dynamic_cast<const Index2Layer*>(storage)) {
967
+ IndexHNSW::search (n, x, k, distances, labels);
968
+
969
+ } else { // "mixed" search
970
+
971
+ const IndexIVFPQ *index_ivfpq =
972
+ dynamic_cast<const IndexIVFPQ*>(storage);
973
+
974
+ int nprobe = index_ivfpq->nprobe;
975
+
976
+ std::unique_ptr<idx_t[]> coarse_assign(new idx_t[n * nprobe]);
977
+ std::unique_ptr<float[]> coarse_dis(new float[n * nprobe]);
978
+
979
+ index_ivfpq->quantizer->search (n, x, nprobe, coarse_dis.get(),
980
+ coarse_assign.get());
981
+
982
+ index_ivfpq->search_preassigned (n, x, k, coarse_assign.get(),
983
+ coarse_dis.get(), distances, labels,
984
+ false);
985
+
986
+ #pragma omp parallel
987
+ {
988
+ VisitedTable vt (ntotal);
989
+ DistanceComputer *dis = storage->get_distance_computer();
990
+ ScopeDeleter1<DistanceComputer> del(dis);
991
+
992
+ int candidates_size = hnsw.upper_beam;
993
+ MinimaxHeap candidates(candidates_size);
994
+
995
+ #pragma omp for
996
+ for(idx_t i = 0; i < n; i++) {
997
+ idx_t * idxi = labels + i * k;
998
+ float * simi = distances + i * k;
999
+ dis->set_query(x + i * d);
1000
+
1001
+ // mark all inverted list elements as visited
1002
+
1003
+ for (int j = 0; j < nprobe; j++) {
1004
+ idx_t key = coarse_assign[j + i * nprobe];
1005
+ if (key < 0) break;
1006
+ size_t list_length = index_ivfpq->get_list_size (key);
1007
+ const idx_t * ids = index_ivfpq->invlists->get_ids (key);
1008
+
1009
+ for (int jj = 0; jj < list_length; jj++) {
1010
+ vt.set (ids[jj]);
1011
+ }
1012
+ }
1013
+
1014
+ candidates.clear();
1015
+ // copy the upper_beam elements to candidates list
1016
+
1017
+ int search_policy = 2;
1018
+
1019
+ if (search_policy == 1) {
1020
+
1021
+ for (int j = 0 ; j < hnsw.upper_beam && j < k; j++) {
1022
+ if (idxi[j] < 0) break;
1023
+ candidates.push (idxi[j], simi[j]);
1024
+ // search_from_candidates adds them back
1025
+ idxi[j] = -1;
1026
+ simi[j] = HUGE_VAL;
1027
+ }
1028
+
1029
+ // reorder from sorted to heap
1030
+ maxheap_heapify (k, simi, idxi, simi, idxi, k);
1031
+
1032
+ hnsw.search_from_candidates(
1033
+ *dis, k, idxi, simi,
1034
+ candidates, vt, 0, k
1035
+ );
1036
+
1037
+ vt.advance();
1038
+
1039
+ } else if (search_policy == 2) {
1040
+
1041
+ for (int j = 0 ; j < hnsw.upper_beam && j < k; j++) {
1042
+ if (idxi[j] < 0) break;
1043
+ candidates.push (idxi[j], simi[j]);
1044
+ }
1045
+
1046
+ // reorder from sorted to heap
1047
+ maxheap_heapify (k, simi, idxi, simi, idxi, k);
1048
+
1049
+ search_from_candidates_2 (
1050
+ hnsw, *dis, k, idxi, simi,
1051
+ candidates, vt, 0, k);
1052
+ vt.advance ();
1053
+ vt.advance ();
1054
+
1055
+ }
1056
+
1057
+ maxheap_reorder (k, simi, idxi);
1058
+ }
1059
+ }
1060
+ }
1061
+
1062
+
1063
+ }
1064
+
1065
+
1066
+ void IndexHNSW2Level::flip_to_ivf ()
1067
+ {
1068
+ Index2Layer *storage2l =
1069
+ dynamic_cast<Index2Layer*>(storage);
1070
+
1071
+ FAISS_THROW_IF_NOT (storage2l);
1072
+
1073
+ IndexIVFPQ * index_ivfpq =
1074
+ new IndexIVFPQ (storage2l->q1.quantizer,
1075
+ d, storage2l->q1.nlist,
1076
+ storage2l->pq.M, 8);
1077
+ index_ivfpq->pq = storage2l->pq;
1078
+ index_ivfpq->is_trained = storage2l->is_trained;
1079
+ index_ivfpq->precompute_table();
1080
+ index_ivfpq->own_fields = storage2l->q1.own_fields;
1081
+ storage2l->transfer_to_IVFPQ(*index_ivfpq);
1082
+ index_ivfpq->make_direct_map (true);
1083
+
1084
+ storage = index_ivfpq;
1085
+ delete storage2l;
1086
+
1087
+ }
1088
+
1089
+
1090
+ } // namespace faiss