faiss 0.2.3 → 0.2.5

Sign up to get free protection for your applications and to get access to all the features.
Files changed (189) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +9 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +23 -21
  5. data/ext/faiss/extconf.rb +11 -0
  6. data/ext/faiss/index.cpp +4 -4
  7. data/ext/faiss/index_binary.cpp +6 -6
  8. data/ext/faiss/product_quantizer.cpp +4 -4
  9. data/lib/faiss/version.rb +1 -1
  10. data/vendor/faiss/faiss/AutoTune.cpp +13 -0
  11. data/vendor/faiss/faiss/Clustering.cpp +32 -0
  12. data/vendor/faiss/faiss/Clustering.h +14 -0
  13. data/vendor/faiss/faiss/IVFlib.cpp +101 -2
  14. data/vendor/faiss/faiss/IVFlib.h +26 -2
  15. data/vendor/faiss/faiss/Index.cpp +36 -3
  16. data/vendor/faiss/faiss/Index.h +43 -6
  17. data/vendor/faiss/faiss/Index2Layer.cpp +24 -93
  18. data/vendor/faiss/faiss/Index2Layer.h +8 -17
  19. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +610 -0
  20. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +253 -0
  21. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +299 -0
  22. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +199 -0
  23. data/vendor/faiss/faiss/IndexBinary.cpp +20 -4
  24. data/vendor/faiss/faiss/IndexBinary.h +18 -3
  25. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +9 -2
  26. data/vendor/faiss/faiss/IndexBinaryFlat.h +4 -2
  27. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +4 -1
  28. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +2 -1
  29. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +5 -1
  30. data/vendor/faiss/faiss/IndexBinaryHNSW.h +2 -1
  31. data/vendor/faiss/faiss/IndexBinaryHash.cpp +17 -4
  32. data/vendor/faiss/faiss/IndexBinaryHash.h +8 -4
  33. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +28 -13
  34. data/vendor/faiss/faiss/IndexBinaryIVF.h +10 -7
  35. data/vendor/faiss/faiss/IndexFastScan.cpp +626 -0
  36. data/vendor/faiss/faiss/IndexFastScan.h +145 -0
  37. data/vendor/faiss/faiss/IndexFlat.cpp +52 -69
  38. data/vendor/faiss/faiss/IndexFlat.h +16 -19
  39. data/vendor/faiss/faiss/IndexFlatCodes.cpp +101 -0
  40. data/vendor/faiss/faiss/IndexFlatCodes.h +59 -0
  41. data/vendor/faiss/faiss/IndexHNSW.cpp +66 -138
  42. data/vendor/faiss/faiss/IndexHNSW.h +4 -2
  43. data/vendor/faiss/faiss/IndexIDMap.cpp +247 -0
  44. data/vendor/faiss/faiss/IndexIDMap.h +107 -0
  45. data/vendor/faiss/faiss/IndexIVF.cpp +200 -40
  46. data/vendor/faiss/faiss/IndexIVF.h +59 -22
  47. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +393 -0
  48. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +183 -0
  49. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +590 -0
  50. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +171 -0
  51. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +1290 -0
  52. data/vendor/faiss/faiss/IndexIVFFastScan.h +213 -0
  53. data/vendor/faiss/faiss/IndexIVFFlat.cpp +43 -26
  54. data/vendor/faiss/faiss/IndexIVFFlat.h +4 -2
  55. data/vendor/faiss/faiss/IndexIVFPQ.cpp +238 -53
  56. data/vendor/faiss/faiss/IndexIVFPQ.h +6 -2
  57. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +23 -852
  58. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +7 -112
  59. data/vendor/faiss/faiss/IndexIVFPQR.cpp +3 -3
  60. data/vendor/faiss/faiss/IndexIVFPQR.h +1 -1
  61. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +63 -40
  62. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +23 -7
  63. data/vendor/faiss/faiss/IndexLSH.cpp +8 -32
  64. data/vendor/faiss/faiss/IndexLSH.h +4 -16
  65. data/vendor/faiss/faiss/IndexLattice.cpp +7 -1
  66. data/vendor/faiss/faiss/IndexLattice.h +3 -1
  67. data/vendor/faiss/faiss/IndexNNDescent.cpp +4 -5
  68. data/vendor/faiss/faiss/IndexNNDescent.h +2 -1
  69. data/vendor/faiss/faiss/IndexNSG.cpp +37 -5
  70. data/vendor/faiss/faiss/IndexNSG.h +25 -1
  71. data/vendor/faiss/faiss/IndexPQ.cpp +108 -120
  72. data/vendor/faiss/faiss/IndexPQ.h +21 -22
  73. data/vendor/faiss/faiss/IndexPQFastScan.cpp +15 -450
  74. data/vendor/faiss/faiss/IndexPQFastScan.h +15 -78
  75. data/vendor/faiss/faiss/IndexPreTransform.cpp +47 -8
  76. data/vendor/faiss/faiss/IndexPreTransform.h +15 -3
  77. data/vendor/faiss/faiss/IndexRefine.cpp +36 -4
  78. data/vendor/faiss/faiss/IndexRefine.h +14 -2
  79. data/vendor/faiss/faiss/IndexReplicas.cpp +4 -2
  80. data/vendor/faiss/faiss/IndexReplicas.h +2 -1
  81. data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +438 -0
  82. data/vendor/faiss/faiss/IndexRowwiseMinMax.h +92 -0
  83. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +28 -43
  84. data/vendor/faiss/faiss/IndexScalarQuantizer.h +8 -23
  85. data/vendor/faiss/faiss/IndexShards.cpp +4 -1
  86. data/vendor/faiss/faiss/IndexShards.h +2 -1
  87. data/vendor/faiss/faiss/MetaIndexes.cpp +5 -178
  88. data/vendor/faiss/faiss/MetaIndexes.h +3 -81
  89. data/vendor/faiss/faiss/VectorTransform.cpp +45 -1
  90. data/vendor/faiss/faiss/VectorTransform.h +25 -4
  91. data/vendor/faiss/faiss/clone_index.cpp +26 -3
  92. data/vendor/faiss/faiss/clone_index.h +3 -0
  93. data/vendor/faiss/faiss/cppcontrib/SaDecodeKernels.h +300 -0
  94. data/vendor/faiss/faiss/cppcontrib/detail/CoarseBitType.h +24 -0
  95. data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +195 -0
  96. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +2058 -0
  97. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +408 -0
  98. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-neon-inl.h +2147 -0
  99. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMax-inl.h +460 -0
  100. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMaxFP16-inl.h +465 -0
  101. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +1618 -0
  102. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +251 -0
  103. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-neon-inl.h +1452 -0
  104. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +1 -0
  105. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +2 -6
  106. data/vendor/faiss/faiss/gpu/GpuIcmEncoder.h +60 -0
  107. data/vendor/faiss/faiss/gpu/GpuIndex.h +28 -4
  108. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +2 -1
  109. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +10 -8
  110. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +75 -14
  111. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +19 -32
  112. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -31
  113. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +22 -28
  114. data/vendor/faiss/faiss/gpu/GpuResources.cpp +14 -0
  115. data/vendor/faiss/faiss/gpu/GpuResources.h +16 -3
  116. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +3 -3
  117. data/vendor/faiss/faiss/gpu/impl/IndexUtils.h +32 -0
  118. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +1 -0
  119. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +311 -75
  120. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +10 -0
  121. data/vendor/faiss/faiss/gpu/test/TestUtils.h +3 -0
  122. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +2 -2
  123. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +5 -4
  124. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +331 -29
  125. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +110 -19
  126. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +0 -54
  127. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +0 -76
  128. data/vendor/faiss/faiss/impl/DistanceComputer.h +64 -0
  129. data/vendor/faiss/faiss/impl/HNSW.cpp +133 -32
  130. data/vendor/faiss/faiss/impl/HNSW.h +19 -16
  131. data/vendor/faiss/faiss/impl/IDSelector.cpp +125 -0
  132. data/vendor/faiss/faiss/impl/IDSelector.h +135 -0
  133. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +378 -217
  134. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +106 -29
  135. data/vendor/faiss/faiss/impl/LookupTableScaler.h +77 -0
  136. data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -0
  137. data/vendor/faiss/faiss/impl/NSG.cpp +1 -4
  138. data/vendor/faiss/faiss/impl/NSG.h +1 -1
  139. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +383 -0
  140. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +154 -0
  141. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +225 -145
  142. data/vendor/faiss/faiss/impl/ProductQuantizer.h +29 -10
  143. data/vendor/faiss/faiss/impl/Quantizer.h +43 -0
  144. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +521 -55
  145. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +94 -16
  146. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -0
  147. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +108 -191
  148. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +18 -18
  149. data/vendor/faiss/faiss/impl/index_read.cpp +338 -24
  150. data/vendor/faiss/faiss/impl/index_write.cpp +300 -18
  151. data/vendor/faiss/faiss/impl/io.cpp +1 -1
  152. data/vendor/faiss/faiss/impl/io_macros.h +20 -0
  153. data/vendor/faiss/faiss/impl/kmeans1d.cpp +303 -0
  154. data/vendor/faiss/faiss/impl/kmeans1d.h +48 -0
  155. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +56 -16
  156. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +25 -8
  157. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +66 -25
  158. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +75 -27
  159. data/vendor/faiss/faiss/index_factory.cpp +772 -412
  160. data/vendor/faiss/faiss/index_factory.h +3 -0
  161. data/vendor/faiss/faiss/index_io.h +5 -0
  162. data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -0
  163. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +4 -1
  164. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +2 -1
  165. data/vendor/faiss/faiss/python/python_callbacks.cpp +27 -0
  166. data/vendor/faiss/faiss/python/python_callbacks.h +15 -0
  167. data/vendor/faiss/faiss/utils/Heap.h +31 -15
  168. data/vendor/faiss/faiss/utils/distances.cpp +384 -58
  169. data/vendor/faiss/faiss/utils/distances.h +149 -18
  170. data/vendor/faiss/faiss/utils/distances_simd.cpp +776 -6
  171. data/vendor/faiss/faiss/utils/extra_distances.cpp +12 -7
  172. data/vendor/faiss/faiss/utils/extra_distances.h +3 -1
  173. data/vendor/faiss/faiss/utils/fp16-fp16c.h +21 -0
  174. data/vendor/faiss/faiss/utils/fp16-inl.h +101 -0
  175. data/vendor/faiss/faiss/utils/fp16.h +11 -0
  176. data/vendor/faiss/faiss/utils/hamming-inl.h +54 -0
  177. data/vendor/faiss/faiss/utils/hamming.cpp +0 -48
  178. data/vendor/faiss/faiss/utils/ordered_key_value.h +10 -0
  179. data/vendor/faiss/faiss/utils/quantize_lut.cpp +62 -0
  180. data/vendor/faiss/faiss/utils/quantize_lut.h +20 -0
  181. data/vendor/faiss/faiss/utils/random.cpp +53 -0
  182. data/vendor/faiss/faiss/utils/random.h +5 -0
  183. data/vendor/faiss/faiss/utils/simdlib_avx2.h +4 -0
  184. data/vendor/faiss/faiss/utils/simdlib_emulated.h +6 -1
  185. data/vendor/faiss/faiss/utils/simdlib_neon.h +7 -2
  186. data/vendor/faiss/faiss/utils/utils.h +1 -1
  187. metadata +46 -5
  188. data/vendor/faiss/faiss/IndexResidual.cpp +0 -291
  189. data/vendor/faiss/faiss/IndexResidual.h +0 -152
@@ -199,60 +199,6 @@ void RangeSearchPartialResult::merge(
199
199
  result->lims[0] = 0;
200
200
  }
201
201
 
202
- /***********************************************************************
203
- * IDSelectorRange
204
- ***********************************************************************/
205
-
206
- IDSelectorRange::IDSelectorRange(idx_t imin, idx_t imax)
207
- : imin(imin), imax(imax) {}
208
-
209
- bool IDSelectorRange::is_member(idx_t id) const {
210
- return id >= imin && id < imax;
211
- }
212
-
213
- /***********************************************************************
214
- * IDSelectorArray
215
- ***********************************************************************/
216
-
217
- IDSelectorArray::IDSelectorArray(size_t n, const idx_t* ids) : n(n), ids(ids) {}
218
-
219
- bool IDSelectorArray::is_member(idx_t id) const {
220
- for (idx_t i = 0; i < n; i++) {
221
- if (ids[i] == id)
222
- return true;
223
- }
224
- return false;
225
- }
226
-
227
- /***********************************************************************
228
- * IDSelectorBatch
229
- ***********************************************************************/
230
-
231
- IDSelectorBatch::IDSelectorBatch(size_t n, const idx_t* indices) {
232
- nbits = 0;
233
- while (n > (1L << nbits))
234
- nbits++;
235
- nbits += 5;
236
- // for n = 1M, nbits = 25 is optimal, see P56659518
237
-
238
- mask = (1L << nbits) - 1;
239
- bloom.resize(1UL << (nbits - 3), 0);
240
- for (long i = 0; i < n; i++) {
241
- Index::idx_t id = indices[i];
242
- set.insert(id);
243
- id &= mask;
244
- bloom[id >> 3] |= 1 << (id & 7);
245
- }
246
- }
247
-
248
- bool IDSelectorBatch::is_member(idx_t i) const {
249
- long im = i & mask;
250
- if (!(bloom[im >> 3] & (1 << (im & 7)))) {
251
- return 0;
252
- }
253
- return set.count(i);
254
- }
255
-
256
202
  /***********************************************************
257
203
  * Interrupt callback
258
204
  ***********************************************************/
@@ -5,8 +5,6 @@
5
5
  * LICENSE file in the root directory of this source tree.
6
6
  */
7
7
 
8
- // -*- c++ -*-
9
-
10
8
  // Auxiliary index structures, that are used in indexes but that can
11
9
  // be forward-declared
12
10
 
@@ -18,7 +16,6 @@
18
16
  #include <cstring>
19
17
  #include <memory>
20
18
  #include <mutex>
21
- #include <unordered_set>
22
19
  #include <vector>
23
20
 
24
21
  #include <faiss/Index.h>
@@ -52,55 +49,6 @@ struct RangeSearchResult {
52
49
  virtual ~RangeSearchResult();
53
50
  };
54
51
 
55
- /** Encapsulates a set of ids to remove. */
56
- struct IDSelector {
57
- typedef Index::idx_t idx_t;
58
- virtual bool is_member(idx_t id) const = 0;
59
- virtual ~IDSelector() {}
60
- };
61
-
62
- /** remove ids between [imni, imax) */
63
- struct IDSelectorRange : IDSelector {
64
- idx_t imin, imax;
65
-
66
- IDSelectorRange(idx_t imin, idx_t imax);
67
- bool is_member(idx_t id) const override;
68
- ~IDSelectorRange() override {}
69
- };
70
-
71
- /** simple list of elements to remove
72
- *
73
- * this is inefficient in most cases, except for IndexIVF with
74
- * maintain_direct_map
75
- */
76
- struct IDSelectorArray : IDSelector {
77
- size_t n;
78
- const idx_t* ids;
79
-
80
- IDSelectorArray(size_t n, const idx_t* ids);
81
- bool is_member(idx_t id) const override;
82
- ~IDSelectorArray() override {}
83
- };
84
-
85
- /** Remove ids from a set. Repetitions of ids in the indices set
86
- * passed to the constructor does not hurt performance. The hash
87
- * function used for the bloom filter and GCC's implementation of
88
- * unordered_set are just the least significant bits of the id. This
89
- * works fine for random ids or ids in sequences but will produce many
90
- * hash collisions if lsb's are always the same */
91
- struct IDSelectorBatch : IDSelector {
92
- std::unordered_set<idx_t> set;
93
-
94
- typedef unsigned char uint8_t;
95
- std::vector<uint8_t> bloom; // assumes low bits of id are a good hash value
96
- int nbits;
97
- idx_t mask;
98
-
99
- IDSelectorBatch(size_t n, const idx_t* indices);
100
- bool is_member(idx_t id) const override;
101
- ~IDSelectorBatch() override {}
102
- };
103
-
104
52
  /****************************************************************
105
53
  * Result structures for range search.
106
54
  *
@@ -186,30 +134,6 @@ struct RangeSearchPartialResult : BufferList {
186
134
  bool do_delete = true);
187
135
  };
188
136
 
189
- /***********************************************************
190
- * The distance computer maintains a current query and computes
191
- * distances to elements in an index that supports random access.
192
- *
193
- * The DistanceComputer is not intended to be thread-safe (eg. because
194
- * it maintains counters) so the distance functions are not const,
195
- * instantiate one from each thread if needed.
196
- ***********************************************************/
197
- struct DistanceComputer {
198
- using idx_t = Index::idx_t;
199
-
200
- /// called before computing distances. Pointer x should remain valid
201
- /// while operator () is called
202
- virtual void set_query(const float* x) = 0;
203
-
204
- /// compute distance of vector i to current query
205
- virtual float operator()(idx_t i) = 0;
206
-
207
- /// compute distance between two stored vectors
208
- virtual float symmetric_dis(idx_t i, idx_t j) = 0;
209
-
210
- virtual ~DistanceComputer() {}
211
- };
212
-
213
137
  /***********************************************************
214
138
  * Interrupt callback
215
139
  ***********************************************************/
@@ -0,0 +1,64 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ #pragma once
9
+
10
+ #include <faiss/Index.h>
11
+
12
+ namespace faiss {
13
+
14
+ /***********************************************************
15
+ * The distance computer maintains a current query and computes
16
+ * distances to elements in an index that supports random access.
17
+ *
18
+ * The DistanceComputer is not intended to be thread-safe (eg. because
19
+ * it maintains counters) so the distance functions are not const,
20
+ * instantiate one from each thread if needed.
21
+ *
22
+ * Note that the equivalent for IVF indexes is the InvertedListScanner,
23
+ * that has additional methods to handle the inverted list context.
24
+ ***********************************************************/
25
+ struct DistanceComputer {
26
+ using idx_t = Index::idx_t;
27
+
28
+ /// called before computing distances. Pointer x should remain valid
29
+ /// while operator () is called
30
+ virtual void set_query(const float* x) = 0;
31
+
32
+ /// compute distance of vector i to current query
33
+ virtual float operator()(idx_t i) = 0;
34
+
35
+ /// compute distance between two stored vectors
36
+ virtual float symmetric_dis(idx_t i, idx_t j) = 0;
37
+
38
+ virtual ~DistanceComputer() {}
39
+ };
40
+
41
+ /*************************************************************
42
+ * Specialized version of the DistanceComputer when we know that codes are
43
+ * laid out in a flat index.
44
+ */
45
+ struct FlatCodesDistanceComputer : DistanceComputer {
46
+ const uint8_t* codes;
47
+ size_t code_size;
48
+
49
+ FlatCodesDistanceComputer(const uint8_t* codes, size_t code_size)
50
+ : codes(codes), code_size(code_size) {}
51
+
52
+ FlatCodesDistanceComputer() : codes(nullptr), code_size(0) {}
53
+
54
+ float operator()(idx_t i) final {
55
+ return distance_to_code(codes + i * code_size);
56
+ }
57
+
58
+ /// compute distance of current query to an encoded vector
59
+ virtual float distance_to_code(const uint8_t* code) = 0;
60
+
61
+ virtual ~FlatCodesDistanceComputer() {}
62
+ };
63
+
64
+ } // namespace faiss
@@ -12,6 +12,8 @@
12
12
  #include <string>
13
13
 
14
14
  #include <faiss/impl/AuxIndexStructures.h>
15
+ #include <faiss/impl/DistanceComputer.h>
16
+ #include <faiss/impl/IDSelector.h>
15
17
 
16
18
  namespace faiss {
17
19
 
@@ -434,17 +436,22 @@ void HNSW::add_links_starting_from(
434
436
 
435
437
  ::faiss::shrink_neighbor_list(ptdis, link_targets, M);
436
438
 
439
+ std::vector<storage_idx_t> neighbors;
440
+ neighbors.reserve(link_targets.size());
437
441
  while (!link_targets.empty()) {
438
- int other_id = link_targets.top().id;
442
+ storage_idx_t other_id = link_targets.top().id;
443
+ add_link(*this, ptdis, pt_id, other_id, level);
444
+ neighbors.push_back(other_id);
445
+ link_targets.pop();
446
+ }
439
447
 
448
+ omp_unset_lock(&locks[pt_id]);
449
+ for (storage_idx_t other_id : neighbors) {
440
450
  omp_set_lock(&locks[other_id]);
441
451
  add_link(*this, ptdis, other_id, pt_id, level);
442
452
  omp_unset_lock(&locks[other_id]);
443
-
444
- add_link(*this, ptdis, pt_id, other_id, level);
445
-
446
- link_targets.pop();
447
453
  }
454
+ omp_set_lock(&locks[pt_id]);
448
455
  }
449
456
 
450
457
  /**************************************************************
@@ -496,9 +503,19 @@ void HNSW::add_with_locks(
496
503
  }
497
504
  }
498
505
 
506
+ /**************************************************************
507
+ * Searching
508
+ **************************************************************/
509
+
510
+ namespace {
511
+
512
+ using idx_t = HNSW::idx_t;
513
+ using MinimaxHeap = HNSW::MinimaxHeap;
514
+ using Node = HNSW::Node;
499
515
  /** Do a BFS on the candidates list */
500
516
 
501
- int HNSW::search_from_candidates(
517
+ int search_from_candidates(
518
+ const HNSW& hnsw,
502
519
  DistanceComputer& qdis,
503
520
  int k,
504
521
  idx_t* I,
@@ -507,22 +524,31 @@ int HNSW::search_from_candidates(
507
524
  VisitedTable& vt,
508
525
  HNSWStats& stats,
509
526
  int level,
510
- int nres_in) const {
527
+ int nres_in = 0,
528
+ const SearchParametersHNSW* params = nullptr) {
511
529
  int nres = nres_in;
512
530
  int ndis = 0;
531
+
532
+ // can be overridden by search params
533
+ bool do_dis_check = params ? params->check_relative_distance
534
+ : hnsw.check_relative_distance;
535
+ int efSearch = params ? params->efSearch : hnsw.efSearch;
536
+ const IDSelector* sel = params ? params->sel : nullptr;
537
+
513
538
  for (int i = 0; i < candidates.size(); i++) {
514
539
  idx_t v1 = candidates.ids[i];
515
540
  float d = candidates.dis[i];
516
541
  FAISS_ASSERT(v1 >= 0);
517
- if (nres < k) {
518
- faiss::maxheap_push(++nres, D, I, d, v1);
519
- } else if (d < D[0]) {
520
- faiss::maxheap_replace_top(nres, D, I, d, v1);
542
+ if (!sel || sel->is_member(v1)) {
543
+ if (nres < k) {
544
+ faiss::maxheap_push(++nres, D, I, d, v1);
545
+ } else if (d < D[0]) {
546
+ faiss::maxheap_replace_top(nres, D, I, d, v1);
547
+ }
521
548
  }
522
549
  vt.set(v1);
523
550
  }
524
551
 
525
- bool do_dis_check = check_relative_distance;
526
552
  int nstep = 0;
527
553
 
528
554
  while (candidates.size() > 0) {
@@ -541,10 +567,10 @@ int HNSW::search_from_candidates(
541
567
  }
542
568
 
543
569
  size_t begin, end;
544
- neighbor_range(v0, level, &begin, &end);
570
+ hnsw.neighbor_range(v0, level, &begin, &end);
545
571
 
546
572
  for (size_t j = begin; j < end; j++) {
547
- int v1 = neighbors[j];
573
+ int v1 = hnsw.neighbors[j];
548
574
  if (v1 < 0)
549
575
  break;
550
576
  if (vt.get(v1)) {
@@ -553,10 +579,12 @@ int HNSW::search_from_candidates(
553
579
  vt.set(v1);
554
580
  ndis++;
555
581
  float d = qdis(v1);
556
- if (nres < k) {
557
- faiss::maxheap_push(++nres, D, I, d, v1);
558
- } else if (d < D[0]) {
559
- faiss::maxheap_replace_top(nres, D, I, d, v1);
582
+ if (!sel || sel->is_member(v1)) {
583
+ if (nres < k) {
584
+ faiss::maxheap_push(++nres, D, I, d, v1);
585
+ } else if (d < D[0]) {
586
+ faiss::maxheap_replace_top(nres, D, I, d, v1);
587
+ }
560
588
  }
561
589
  candidates.push(v1, d);
562
590
  }
@@ -578,16 +606,13 @@ int HNSW::search_from_candidates(
578
606
  return nres;
579
607
  }
580
608
 
581
- /**************************************************************
582
- * Searching
583
- **************************************************************/
584
-
585
- std::priority_queue<HNSW::Node> HNSW::search_from_candidate_unbounded(
609
+ std::priority_queue<HNSW::Node> search_from_candidate_unbounded(
610
+ const HNSW& hnsw,
586
611
  const Node& node,
587
612
  DistanceComputer& qdis,
588
613
  int ef,
589
614
  VisitedTable* vt,
590
- HNSWStats& stats) const {
615
+ HNSWStats& stats) {
591
616
  int ndis = 0;
592
617
  std::priority_queue<Node> top_candidates;
593
618
  std::priority_queue<Node, std::vector<Node>, std::greater<Node>> candidates;
@@ -609,10 +634,10 @@ std::priority_queue<HNSW::Node> HNSW::search_from_candidate_unbounded(
609
634
  candidates.pop();
610
635
 
611
636
  size_t begin, end;
612
- neighbor_range(v0, 0, &begin, &end);
637
+ hnsw.neighbor_range(v0, 0, &begin, &end);
613
638
 
614
639
  for (size_t j = begin; j < end; ++j) {
615
- int v1 = neighbors[j];
640
+ int v1 = hnsw.neighbors[j];
616
641
 
617
642
  if (v1 < 0) {
618
643
  break;
@@ -646,14 +671,19 @@ std::priority_queue<HNSW::Node> HNSW::search_from_candidate_unbounded(
646
671
  return top_candidates;
647
672
  }
648
673
 
674
+ } // anonymous namespace
675
+
649
676
  HNSWStats HNSW::search(
650
677
  DistanceComputer& qdis,
651
678
  int k,
652
679
  idx_t* I,
653
680
  float* D,
654
- VisitedTable& vt) const {
681
+ VisitedTable& vt,
682
+ const SearchParametersHNSW* params) const {
655
683
  HNSWStats stats;
656
-
684
+ if (entry_point == -1) {
685
+ return stats;
686
+ }
657
687
  if (upper_beam == 1) {
658
688
  // greedy search on upper levels
659
689
  storage_idx_t nearest = entry_point;
@@ -664,16 +694,22 @@ HNSWStats HNSW::search(
664
694
  }
665
695
 
666
696
  int ef = std::max(efSearch, k);
667
- if (search_bounded_queue) {
697
+ if (search_bounded_queue) { // this is the most common branch
668
698
  MinimaxHeap candidates(ef);
669
699
 
670
700
  candidates.push(nearest, d_nearest);
671
701
 
672
- search_from_candidates(qdis, k, I, D, candidates, vt, stats, 0);
702
+ search_from_candidates(
703
+ *this, qdis, k, I, D, candidates, vt, stats, 0, 0, params);
673
704
  } else {
674
705
  std::priority_queue<Node> top_candidates =
675
706
  search_from_candidate_unbounded(
676
- Node(d_nearest, nearest), qdis, ef, &vt, stats);
707
+ *this,
708
+ Node(d_nearest, nearest),
709
+ qdis,
710
+ ef,
711
+ &vt,
712
+ stats);
677
713
 
678
714
  while (top_candidates.size() > k) {
679
715
  top_candidates.pop();
@@ -713,9 +749,10 @@ HNSWStats HNSW::search(
713
749
 
714
750
  if (level == 0) {
715
751
  nres = search_from_candidates(
716
- qdis, k, I, D, candidates, vt, stats, 0);
752
+ *this, qdis, k, I, D, candidates, vt, stats, 0);
717
753
  } else {
718
754
  nres = search_from_candidates(
755
+ *this,
719
756
  qdis,
720
757
  candidates_size,
721
758
  I_to_next.data(),
@@ -732,6 +769,70 @@ HNSWStats HNSW::search(
732
769
  return stats;
733
770
  }
734
771
 
772
+ void HNSW::search_level_0(
773
+ DistanceComputer& qdis,
774
+ int k,
775
+ idx_t* idxi,
776
+ float* simi,
777
+ idx_t nprobe,
778
+ const storage_idx_t* nearest_i,
779
+ const float* nearest_d,
780
+ int search_type,
781
+ HNSWStats& search_stats,
782
+ VisitedTable& vt) const {
783
+ const HNSW& hnsw = *this;
784
+
785
+ if (search_type == 1) {
786
+ int nres = 0;
787
+
788
+ for (int j = 0; j < nprobe; j++) {
789
+ storage_idx_t cj = nearest_i[j];
790
+
791
+ if (cj < 0)
792
+ break;
793
+
794
+ if (vt.get(cj))
795
+ continue;
796
+
797
+ int candidates_size = std::max(hnsw.efSearch, int(k));
798
+ MinimaxHeap candidates(candidates_size);
799
+
800
+ candidates.push(cj, nearest_d[j]);
801
+
802
+ nres = search_from_candidates(
803
+ hnsw,
804
+ qdis,
805
+ k,
806
+ idxi,
807
+ simi,
808
+ candidates,
809
+ vt,
810
+ search_stats,
811
+ 0,
812
+ nres);
813
+ }
814
+ } else if (search_type == 2) {
815
+ int candidates_size = std::max(hnsw.efSearch, int(k));
816
+ candidates_size = std::max(candidates_size, int(nprobe));
817
+
818
+ MinimaxHeap candidates(candidates_size);
819
+ for (int j = 0; j < nprobe; j++) {
820
+ storage_idx_t cj = nearest_i[j];
821
+
822
+ if (cj < 0)
823
+ break;
824
+ candidates.push(cj, nearest_d[j]);
825
+ }
826
+
827
+ search_from_candidates(
828
+ hnsw, qdis, k, idxi, simi, candidates, vt, search_stats, 0);
829
+ }
830
+ }
831
+
832
+ /**************************************************************
833
+ * MinimaxHeap
834
+ **************************************************************/
835
+
735
836
  void HNSW::MinimaxHeap::push(storage_idx_t i, float v) {
736
837
  if (k == n) {
737
838
  if (v >= dis[0])
@@ -43,6 +43,13 @@ struct VisitedTable;
43
43
  struct DistanceComputer; // from AuxIndexStructures
44
44
  struct HNSWStats;
45
45
 
46
+ struct SearchParametersHNSW : SearchParameters {
47
+ int efSearch = 16;
48
+ bool check_relative_distance = true;
49
+
50
+ ~SearchParametersHNSW() {}
51
+ };
52
+
46
53
  struct HNSW {
47
54
  /// internal storage of vectors (32 bits: this is expensive)
48
55
  typedef int storage_idx_t;
@@ -188,30 +195,26 @@ struct HNSW {
188
195
  std::vector<omp_lock_t>& locks,
189
196
  VisitedTable& vt);
190
197
 
191
- int search_from_candidates(
198
+ /// search interface for 1 point, single thread
199
+ HNSWStats search(
192
200
  DistanceComputer& qdis,
193
201
  int k,
194
202
  idx_t* I,
195
203
  float* D,
196
- MinimaxHeap& candidates,
197
204
  VisitedTable& vt,
198
- HNSWStats& stats,
199
- int level,
200
- int nres_in = 0) const;
205
+ const SearchParametersHNSW* params = nullptr) const;
201
206
 
202
- std::priority_queue<Node> search_from_candidate_unbounded(
203
- const Node& node,
204
- DistanceComputer& qdis,
205
- int ef,
206
- VisitedTable* vt,
207
- HNSWStats& stats) const;
208
-
209
- /// search interface
210
- HNSWStats search(
207
+ /// search only in level 0 from a given vertex
208
+ void search_level_0(
211
209
  DistanceComputer& qdis,
212
210
  int k,
213
- idx_t* I,
214
- float* D,
211
+ idx_t* idxi,
212
+ float* simi,
213
+ idx_t nprobe,
214
+ const storage_idx_t* nearest_i,
215
+ const float* nearest_d,
216
+ int search_type,
217
+ HNSWStats& search_stats,
215
218
  VisitedTable& vt) const;
216
219
 
217
220
  void reset();
@@ -0,0 +1,125 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ #include <faiss/impl/FaissAssert.h>
9
+ #include <faiss/impl/IDSelector.h>
10
+
11
+ namespace faiss {
12
+
13
+ /***********************************************************************
14
+ * IDSelectorRange
15
+ ***********************************************************************/
16
+
17
+ IDSelectorRange::IDSelectorRange(idx_t imin, idx_t imax, bool assume_sorted)
18
+ : imin(imin), imax(imax), assume_sorted(assume_sorted) {}
19
+
20
+ bool IDSelectorRange::is_member(idx_t id) const {
21
+ return id >= imin && id < imax;
22
+ }
23
+
24
+ void IDSelectorRange::find_sorted_ids_bounds(
25
+ size_t list_size,
26
+ const idx_t* ids,
27
+ size_t* jmin_out,
28
+ size_t* jmax_out) const {
29
+ FAISS_ASSERT(assume_sorted);
30
+ if (list_size == 0 || imax <= ids[0] || imin > ids[list_size - 1]) {
31
+ *jmin_out = *jmax_out = 0;
32
+ return;
33
+ }
34
+ // bissection to find imin
35
+ if (ids[0] >= imin) {
36
+ *jmin_out = 0;
37
+ } else {
38
+ size_t j0 = 0, j1 = list_size;
39
+ while (j1 > j0 + 1) {
40
+ size_t jmed = (j0 + j1) / 2;
41
+ if (ids[jmed] >= imin) {
42
+ j1 = jmed;
43
+ } else {
44
+ j0 = jmed;
45
+ }
46
+ }
47
+ *jmin_out = j1;
48
+ }
49
+ // bissection to find imax
50
+ if (*jmin_out == list_size || ids[*jmin_out] >= imax) {
51
+ *jmax_out = *jmin_out;
52
+ } else {
53
+ size_t j0 = *jmin_out, j1 = list_size;
54
+ while (j1 > j0 + 1) {
55
+ size_t jmed = (j0 + j1) / 2;
56
+ if (ids[jmed] >= imax) {
57
+ j1 = jmed;
58
+ } else {
59
+ j0 = jmed;
60
+ }
61
+ }
62
+ *jmax_out = j1;
63
+ }
64
+ }
65
+
66
+ /***********************************************************************
67
+ * IDSelectorArray
68
+ ***********************************************************************/
69
+
70
+ IDSelectorArray::IDSelectorArray(size_t n, const idx_t* ids) : n(n), ids(ids) {}
71
+
72
+ bool IDSelectorArray::is_member(idx_t id) const {
73
+ for (idx_t i = 0; i < n; i++) {
74
+ if (ids[i] == id)
75
+ return true;
76
+ }
77
+ return false;
78
+ }
79
+
80
+ /***********************************************************************
81
+ * IDSelectorBatch
82
+ ***********************************************************************/
83
+
84
+ IDSelectorBatch::IDSelectorBatch(size_t n, const idx_t* indices) {
85
+ nbits = 0;
86
+ while (n > ((idx_t)1 << nbits)) {
87
+ nbits++;
88
+ }
89
+ nbits += 5;
90
+ // for n = 1M, nbits = 25 is optimal, see P56659518
91
+
92
+ mask = ((idx_t)1 << nbits) - 1;
93
+ bloom.resize((idx_t)1 << (nbits - 3), 0);
94
+ for (idx_t i = 0; i < n; i++) {
95
+ Index::idx_t id = indices[i];
96
+ set.insert(id);
97
+ id &= mask;
98
+ bloom[id >> 3] |= 1 << (id & 7);
99
+ }
100
+ }
101
+
102
+ bool IDSelectorBatch::is_member(idx_t i) const {
103
+ long im = i & mask;
104
+ if (!(bloom[im >> 3] & (1 << (im & 7)))) {
105
+ return 0;
106
+ }
107
+ return set.count(i);
108
+ }
109
+
110
+ /***********************************************************************
111
+ * IDSelectorBitmap
112
+ ***********************************************************************/
113
+
114
+ IDSelectorBitmap::IDSelectorBitmap(size_t n, const uint8_t* bitmap)
115
+ : n(n), bitmap(bitmap) {}
116
+
117
+ bool IDSelectorBitmap::is_member(idx_t ii) const {
118
+ uint64_t i = ii;
119
+ if ((i >> 3) >= n) {
120
+ return false;
121
+ }
122
+ return (bitmap[i >> 3] >> (i & 7)) & 1;
123
+ }
124
+
125
+ } // namespace faiss