faiss 0.2.3 → 0.2.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (189) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +9 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +23 -21
  5. data/ext/faiss/extconf.rb +11 -0
  6. data/ext/faiss/index.cpp +4 -4
  7. data/ext/faiss/index_binary.cpp +6 -6
  8. data/ext/faiss/product_quantizer.cpp +4 -4
  9. data/lib/faiss/version.rb +1 -1
  10. data/vendor/faiss/faiss/AutoTune.cpp +13 -0
  11. data/vendor/faiss/faiss/Clustering.cpp +32 -0
  12. data/vendor/faiss/faiss/Clustering.h +14 -0
  13. data/vendor/faiss/faiss/IVFlib.cpp +101 -2
  14. data/vendor/faiss/faiss/IVFlib.h +26 -2
  15. data/vendor/faiss/faiss/Index.cpp +36 -3
  16. data/vendor/faiss/faiss/Index.h +43 -6
  17. data/vendor/faiss/faiss/Index2Layer.cpp +24 -93
  18. data/vendor/faiss/faiss/Index2Layer.h +8 -17
  19. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +610 -0
  20. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +253 -0
  21. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +299 -0
  22. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +199 -0
  23. data/vendor/faiss/faiss/IndexBinary.cpp +20 -4
  24. data/vendor/faiss/faiss/IndexBinary.h +18 -3
  25. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +9 -2
  26. data/vendor/faiss/faiss/IndexBinaryFlat.h +4 -2
  27. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +4 -1
  28. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +2 -1
  29. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +5 -1
  30. data/vendor/faiss/faiss/IndexBinaryHNSW.h +2 -1
  31. data/vendor/faiss/faiss/IndexBinaryHash.cpp +17 -4
  32. data/vendor/faiss/faiss/IndexBinaryHash.h +8 -4
  33. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +28 -13
  34. data/vendor/faiss/faiss/IndexBinaryIVF.h +10 -7
  35. data/vendor/faiss/faiss/IndexFastScan.cpp +626 -0
  36. data/vendor/faiss/faiss/IndexFastScan.h +145 -0
  37. data/vendor/faiss/faiss/IndexFlat.cpp +52 -69
  38. data/vendor/faiss/faiss/IndexFlat.h +16 -19
  39. data/vendor/faiss/faiss/IndexFlatCodes.cpp +101 -0
  40. data/vendor/faiss/faiss/IndexFlatCodes.h +59 -0
  41. data/vendor/faiss/faiss/IndexHNSW.cpp +66 -138
  42. data/vendor/faiss/faiss/IndexHNSW.h +4 -2
  43. data/vendor/faiss/faiss/IndexIDMap.cpp +247 -0
  44. data/vendor/faiss/faiss/IndexIDMap.h +107 -0
  45. data/vendor/faiss/faiss/IndexIVF.cpp +200 -40
  46. data/vendor/faiss/faiss/IndexIVF.h +59 -22
  47. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +393 -0
  48. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +183 -0
  49. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +590 -0
  50. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +171 -0
  51. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +1290 -0
  52. data/vendor/faiss/faiss/IndexIVFFastScan.h +213 -0
  53. data/vendor/faiss/faiss/IndexIVFFlat.cpp +43 -26
  54. data/vendor/faiss/faiss/IndexIVFFlat.h +4 -2
  55. data/vendor/faiss/faiss/IndexIVFPQ.cpp +238 -53
  56. data/vendor/faiss/faiss/IndexIVFPQ.h +6 -2
  57. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +23 -852
  58. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +7 -112
  59. data/vendor/faiss/faiss/IndexIVFPQR.cpp +3 -3
  60. data/vendor/faiss/faiss/IndexIVFPQR.h +1 -1
  61. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +63 -40
  62. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +23 -7
  63. data/vendor/faiss/faiss/IndexLSH.cpp +8 -32
  64. data/vendor/faiss/faiss/IndexLSH.h +4 -16
  65. data/vendor/faiss/faiss/IndexLattice.cpp +7 -1
  66. data/vendor/faiss/faiss/IndexLattice.h +3 -1
  67. data/vendor/faiss/faiss/IndexNNDescent.cpp +4 -5
  68. data/vendor/faiss/faiss/IndexNNDescent.h +2 -1
  69. data/vendor/faiss/faiss/IndexNSG.cpp +37 -5
  70. data/vendor/faiss/faiss/IndexNSG.h +25 -1
  71. data/vendor/faiss/faiss/IndexPQ.cpp +108 -120
  72. data/vendor/faiss/faiss/IndexPQ.h +21 -22
  73. data/vendor/faiss/faiss/IndexPQFastScan.cpp +15 -450
  74. data/vendor/faiss/faiss/IndexPQFastScan.h +15 -78
  75. data/vendor/faiss/faiss/IndexPreTransform.cpp +47 -8
  76. data/vendor/faiss/faiss/IndexPreTransform.h +15 -3
  77. data/vendor/faiss/faiss/IndexRefine.cpp +36 -4
  78. data/vendor/faiss/faiss/IndexRefine.h +14 -2
  79. data/vendor/faiss/faiss/IndexReplicas.cpp +4 -2
  80. data/vendor/faiss/faiss/IndexReplicas.h +2 -1
  81. data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +438 -0
  82. data/vendor/faiss/faiss/IndexRowwiseMinMax.h +92 -0
  83. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +28 -43
  84. data/vendor/faiss/faiss/IndexScalarQuantizer.h +8 -23
  85. data/vendor/faiss/faiss/IndexShards.cpp +4 -1
  86. data/vendor/faiss/faiss/IndexShards.h +2 -1
  87. data/vendor/faiss/faiss/MetaIndexes.cpp +5 -178
  88. data/vendor/faiss/faiss/MetaIndexes.h +3 -81
  89. data/vendor/faiss/faiss/VectorTransform.cpp +45 -1
  90. data/vendor/faiss/faiss/VectorTransform.h +25 -4
  91. data/vendor/faiss/faiss/clone_index.cpp +26 -3
  92. data/vendor/faiss/faiss/clone_index.h +3 -0
  93. data/vendor/faiss/faiss/cppcontrib/SaDecodeKernels.h +300 -0
  94. data/vendor/faiss/faiss/cppcontrib/detail/CoarseBitType.h +24 -0
  95. data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +195 -0
  96. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +2058 -0
  97. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +408 -0
  98. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-neon-inl.h +2147 -0
  99. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMax-inl.h +460 -0
  100. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMaxFP16-inl.h +465 -0
  101. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +1618 -0
  102. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +251 -0
  103. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-neon-inl.h +1452 -0
  104. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +1 -0
  105. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +2 -6
  106. data/vendor/faiss/faiss/gpu/GpuIcmEncoder.h +60 -0
  107. data/vendor/faiss/faiss/gpu/GpuIndex.h +28 -4
  108. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +2 -1
  109. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +10 -8
  110. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +75 -14
  111. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +19 -32
  112. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -31
  113. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +22 -28
  114. data/vendor/faiss/faiss/gpu/GpuResources.cpp +14 -0
  115. data/vendor/faiss/faiss/gpu/GpuResources.h +16 -3
  116. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +3 -3
  117. data/vendor/faiss/faiss/gpu/impl/IndexUtils.h +32 -0
  118. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +1 -0
  119. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +311 -75
  120. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +10 -0
  121. data/vendor/faiss/faiss/gpu/test/TestUtils.h +3 -0
  122. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +2 -2
  123. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +5 -4
  124. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +331 -29
  125. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +110 -19
  126. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +0 -54
  127. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +0 -76
  128. data/vendor/faiss/faiss/impl/DistanceComputer.h +64 -0
  129. data/vendor/faiss/faiss/impl/HNSW.cpp +133 -32
  130. data/vendor/faiss/faiss/impl/HNSW.h +19 -16
  131. data/vendor/faiss/faiss/impl/IDSelector.cpp +125 -0
  132. data/vendor/faiss/faiss/impl/IDSelector.h +135 -0
  133. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +378 -217
  134. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +106 -29
  135. data/vendor/faiss/faiss/impl/LookupTableScaler.h +77 -0
  136. data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -0
  137. data/vendor/faiss/faiss/impl/NSG.cpp +1 -4
  138. data/vendor/faiss/faiss/impl/NSG.h +1 -1
  139. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +383 -0
  140. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +154 -0
  141. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +225 -145
  142. data/vendor/faiss/faiss/impl/ProductQuantizer.h +29 -10
  143. data/vendor/faiss/faiss/impl/Quantizer.h +43 -0
  144. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +521 -55
  145. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +94 -16
  146. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -0
  147. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +108 -191
  148. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +18 -18
  149. data/vendor/faiss/faiss/impl/index_read.cpp +338 -24
  150. data/vendor/faiss/faiss/impl/index_write.cpp +300 -18
  151. data/vendor/faiss/faiss/impl/io.cpp +1 -1
  152. data/vendor/faiss/faiss/impl/io_macros.h +20 -0
  153. data/vendor/faiss/faiss/impl/kmeans1d.cpp +303 -0
  154. data/vendor/faiss/faiss/impl/kmeans1d.h +48 -0
  155. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +56 -16
  156. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +25 -8
  157. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +66 -25
  158. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +75 -27
  159. data/vendor/faiss/faiss/index_factory.cpp +772 -412
  160. data/vendor/faiss/faiss/index_factory.h +3 -0
  161. data/vendor/faiss/faiss/index_io.h +5 -0
  162. data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -0
  163. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +4 -1
  164. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +2 -1
  165. data/vendor/faiss/faiss/python/python_callbacks.cpp +27 -0
  166. data/vendor/faiss/faiss/python/python_callbacks.h +15 -0
  167. data/vendor/faiss/faiss/utils/Heap.h +31 -15
  168. data/vendor/faiss/faiss/utils/distances.cpp +384 -58
  169. data/vendor/faiss/faiss/utils/distances.h +149 -18
  170. data/vendor/faiss/faiss/utils/distances_simd.cpp +776 -6
  171. data/vendor/faiss/faiss/utils/extra_distances.cpp +12 -7
  172. data/vendor/faiss/faiss/utils/extra_distances.h +3 -1
  173. data/vendor/faiss/faiss/utils/fp16-fp16c.h +21 -0
  174. data/vendor/faiss/faiss/utils/fp16-inl.h +101 -0
  175. data/vendor/faiss/faiss/utils/fp16.h +11 -0
  176. data/vendor/faiss/faiss/utils/hamming-inl.h +54 -0
  177. data/vendor/faiss/faiss/utils/hamming.cpp +0 -48
  178. data/vendor/faiss/faiss/utils/ordered_key_value.h +10 -0
  179. data/vendor/faiss/faiss/utils/quantize_lut.cpp +62 -0
  180. data/vendor/faiss/faiss/utils/quantize_lut.h +20 -0
  181. data/vendor/faiss/faiss/utils/random.cpp +53 -0
  182. data/vendor/faiss/faiss/utils/random.h +5 -0
  183. data/vendor/faiss/faiss/utils/simdlib_avx2.h +4 -0
  184. data/vendor/faiss/faiss/utils/simdlib_emulated.h +6 -1
  185. data/vendor/faiss/faiss/utils/simdlib_neon.h +7 -2
  186. data/vendor/faiss/faiss/utils/utils.h +1 -1
  187. metadata +46 -5
  188. data/vendor/faiss/faiss/IndexResidual.cpp +0 -291
  189. data/vendor/faiss/faiss/IndexResidual.h +0 -152
@@ -199,60 +199,6 @@ void RangeSearchPartialResult::merge(
199
199
  result->lims[0] = 0;
200
200
  }
201
201
 
202
- /***********************************************************************
203
- * IDSelectorRange
204
- ***********************************************************************/
205
-
206
- IDSelectorRange::IDSelectorRange(idx_t imin, idx_t imax)
207
- : imin(imin), imax(imax) {}
208
-
209
- bool IDSelectorRange::is_member(idx_t id) const {
210
- return id >= imin && id < imax;
211
- }
212
-
213
- /***********************************************************************
214
- * IDSelectorArray
215
- ***********************************************************************/
216
-
217
- IDSelectorArray::IDSelectorArray(size_t n, const idx_t* ids) : n(n), ids(ids) {}
218
-
219
- bool IDSelectorArray::is_member(idx_t id) const {
220
- for (idx_t i = 0; i < n; i++) {
221
- if (ids[i] == id)
222
- return true;
223
- }
224
- return false;
225
- }
226
-
227
- /***********************************************************************
228
- * IDSelectorBatch
229
- ***********************************************************************/
230
-
231
- IDSelectorBatch::IDSelectorBatch(size_t n, const idx_t* indices) {
232
- nbits = 0;
233
- while (n > (1L << nbits))
234
- nbits++;
235
- nbits += 5;
236
- // for n = 1M, nbits = 25 is optimal, see P56659518
237
-
238
- mask = (1L << nbits) - 1;
239
- bloom.resize(1UL << (nbits - 3), 0);
240
- for (long i = 0; i < n; i++) {
241
- Index::idx_t id = indices[i];
242
- set.insert(id);
243
- id &= mask;
244
- bloom[id >> 3] |= 1 << (id & 7);
245
- }
246
- }
247
-
248
- bool IDSelectorBatch::is_member(idx_t i) const {
249
- long im = i & mask;
250
- if (!(bloom[im >> 3] & (1 << (im & 7)))) {
251
- return 0;
252
- }
253
- return set.count(i);
254
- }
255
-
256
202
  /***********************************************************
257
203
  * Interrupt callback
258
204
  ***********************************************************/
@@ -5,8 +5,6 @@
5
5
  * LICENSE file in the root directory of this source tree.
6
6
  */
7
7
 
8
- // -*- c++ -*-
9
-
10
8
  // Auxiliary index structures, that are used in indexes but that can
11
9
  // be forward-declared
12
10
 
@@ -18,7 +16,6 @@
18
16
  #include <cstring>
19
17
  #include <memory>
20
18
  #include <mutex>
21
- #include <unordered_set>
22
19
  #include <vector>
23
20
 
24
21
  #include <faiss/Index.h>
@@ -52,55 +49,6 @@ struct RangeSearchResult {
52
49
  virtual ~RangeSearchResult();
53
50
  };
54
51
 
55
- /** Encapsulates a set of ids to remove. */
56
- struct IDSelector {
57
- typedef Index::idx_t idx_t;
58
- virtual bool is_member(idx_t id) const = 0;
59
- virtual ~IDSelector() {}
60
- };
61
-
62
- /** remove ids between [imni, imax) */
63
- struct IDSelectorRange : IDSelector {
64
- idx_t imin, imax;
65
-
66
- IDSelectorRange(idx_t imin, idx_t imax);
67
- bool is_member(idx_t id) const override;
68
- ~IDSelectorRange() override {}
69
- };
70
-
71
- /** simple list of elements to remove
72
- *
73
- * this is inefficient in most cases, except for IndexIVF with
74
- * maintain_direct_map
75
- */
76
- struct IDSelectorArray : IDSelector {
77
- size_t n;
78
- const idx_t* ids;
79
-
80
- IDSelectorArray(size_t n, const idx_t* ids);
81
- bool is_member(idx_t id) const override;
82
- ~IDSelectorArray() override {}
83
- };
84
-
85
- /** Remove ids from a set. Repetitions of ids in the indices set
86
- * passed to the constructor does not hurt performance. The hash
87
- * function used for the bloom filter and GCC's implementation of
88
- * unordered_set are just the least significant bits of the id. This
89
- * works fine for random ids or ids in sequences but will produce many
90
- * hash collisions if lsb's are always the same */
91
- struct IDSelectorBatch : IDSelector {
92
- std::unordered_set<idx_t> set;
93
-
94
- typedef unsigned char uint8_t;
95
- std::vector<uint8_t> bloom; // assumes low bits of id are a good hash value
96
- int nbits;
97
- idx_t mask;
98
-
99
- IDSelectorBatch(size_t n, const idx_t* indices);
100
- bool is_member(idx_t id) const override;
101
- ~IDSelectorBatch() override {}
102
- };
103
-
104
52
  /****************************************************************
105
53
  * Result structures for range search.
106
54
  *
@@ -186,30 +134,6 @@ struct RangeSearchPartialResult : BufferList {
186
134
  bool do_delete = true);
187
135
  };
188
136
 
189
- /***********************************************************
190
- * The distance computer maintains a current query and computes
191
- * distances to elements in an index that supports random access.
192
- *
193
- * The DistanceComputer is not intended to be thread-safe (eg. because
194
- * it maintains counters) so the distance functions are not const,
195
- * instantiate one from each thread if needed.
196
- ***********************************************************/
197
- struct DistanceComputer {
198
- using idx_t = Index::idx_t;
199
-
200
- /// called before computing distances. Pointer x should remain valid
201
- /// while operator () is called
202
- virtual void set_query(const float* x) = 0;
203
-
204
- /// compute distance of vector i to current query
205
- virtual float operator()(idx_t i) = 0;
206
-
207
- /// compute distance between two stored vectors
208
- virtual float symmetric_dis(idx_t i, idx_t j) = 0;
209
-
210
- virtual ~DistanceComputer() {}
211
- };
212
-
213
137
  /***********************************************************
214
138
  * Interrupt callback
215
139
  ***********************************************************/
@@ -0,0 +1,64 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ #pragma once
9
+
10
+ #include <faiss/Index.h>
11
+
12
+ namespace faiss {
13
+
14
+ /***********************************************************
15
+ * The distance computer maintains a current query and computes
16
+ * distances to elements in an index that supports random access.
17
+ *
18
+ * The DistanceComputer is not intended to be thread-safe (eg. because
19
+ * it maintains counters) so the distance functions are not const,
20
+ * instantiate one from each thread if needed.
21
+ *
22
+ * Note that the equivalent for IVF indexes is the InvertedListScanner,
23
+ * that has additional methods to handle the inverted list context.
24
+ ***********************************************************/
25
+ struct DistanceComputer {
26
+ using idx_t = Index::idx_t;
27
+
28
+ /// called before computing distances. Pointer x should remain valid
29
+ /// while operator () is called
30
+ virtual void set_query(const float* x) = 0;
31
+
32
+ /// compute distance of vector i to current query
33
+ virtual float operator()(idx_t i) = 0;
34
+
35
+ /// compute distance between two stored vectors
36
+ virtual float symmetric_dis(idx_t i, idx_t j) = 0;
37
+
38
+ virtual ~DistanceComputer() {}
39
+ };
40
+
41
+ /*************************************************************
42
+ * Specialized version of the DistanceComputer when we know that codes are
43
+ * laid out in a flat index.
44
+ */
45
+ struct FlatCodesDistanceComputer : DistanceComputer {
46
+ const uint8_t* codes;
47
+ size_t code_size;
48
+
49
+ FlatCodesDistanceComputer(const uint8_t* codes, size_t code_size)
50
+ : codes(codes), code_size(code_size) {}
51
+
52
+ FlatCodesDistanceComputer() : codes(nullptr), code_size(0) {}
53
+
54
+ float operator()(idx_t i) final {
55
+ return distance_to_code(codes + i * code_size);
56
+ }
57
+
58
+ /// compute distance of current query to an encoded vector
59
+ virtual float distance_to_code(const uint8_t* code) = 0;
60
+
61
+ virtual ~FlatCodesDistanceComputer() {}
62
+ };
63
+
64
+ } // namespace faiss
@@ -12,6 +12,8 @@
12
12
  #include <string>
13
13
 
14
14
  #include <faiss/impl/AuxIndexStructures.h>
15
+ #include <faiss/impl/DistanceComputer.h>
16
+ #include <faiss/impl/IDSelector.h>
15
17
 
16
18
  namespace faiss {
17
19
 
@@ -434,17 +436,22 @@ void HNSW::add_links_starting_from(
434
436
 
435
437
  ::faiss::shrink_neighbor_list(ptdis, link_targets, M);
436
438
 
439
+ std::vector<storage_idx_t> neighbors;
440
+ neighbors.reserve(link_targets.size());
437
441
  while (!link_targets.empty()) {
438
- int other_id = link_targets.top().id;
442
+ storage_idx_t other_id = link_targets.top().id;
443
+ add_link(*this, ptdis, pt_id, other_id, level);
444
+ neighbors.push_back(other_id);
445
+ link_targets.pop();
446
+ }
439
447
 
448
+ omp_unset_lock(&locks[pt_id]);
449
+ for (storage_idx_t other_id : neighbors) {
440
450
  omp_set_lock(&locks[other_id]);
441
451
  add_link(*this, ptdis, other_id, pt_id, level);
442
452
  omp_unset_lock(&locks[other_id]);
443
-
444
- add_link(*this, ptdis, pt_id, other_id, level);
445
-
446
- link_targets.pop();
447
453
  }
454
+ omp_set_lock(&locks[pt_id]);
448
455
  }
449
456
 
450
457
  /**************************************************************
@@ -496,9 +503,19 @@ void HNSW::add_with_locks(
496
503
  }
497
504
  }
498
505
 
506
+ /**************************************************************
507
+ * Searching
508
+ **************************************************************/
509
+
510
+ namespace {
511
+
512
+ using idx_t = HNSW::idx_t;
513
+ using MinimaxHeap = HNSW::MinimaxHeap;
514
+ using Node = HNSW::Node;
499
515
  /** Do a BFS on the candidates list */
500
516
 
501
- int HNSW::search_from_candidates(
517
+ int search_from_candidates(
518
+ const HNSW& hnsw,
502
519
  DistanceComputer& qdis,
503
520
  int k,
504
521
  idx_t* I,
@@ -507,22 +524,31 @@ int HNSW::search_from_candidates(
507
524
  VisitedTable& vt,
508
525
  HNSWStats& stats,
509
526
  int level,
510
- int nres_in) const {
527
+ int nres_in = 0,
528
+ const SearchParametersHNSW* params = nullptr) {
511
529
  int nres = nres_in;
512
530
  int ndis = 0;
531
+
532
+ // can be overridden by search params
533
+ bool do_dis_check = params ? params->check_relative_distance
534
+ : hnsw.check_relative_distance;
535
+ int efSearch = params ? params->efSearch : hnsw.efSearch;
536
+ const IDSelector* sel = params ? params->sel : nullptr;
537
+
513
538
  for (int i = 0; i < candidates.size(); i++) {
514
539
  idx_t v1 = candidates.ids[i];
515
540
  float d = candidates.dis[i];
516
541
  FAISS_ASSERT(v1 >= 0);
517
- if (nres < k) {
518
- faiss::maxheap_push(++nres, D, I, d, v1);
519
- } else if (d < D[0]) {
520
- faiss::maxheap_replace_top(nres, D, I, d, v1);
542
+ if (!sel || sel->is_member(v1)) {
543
+ if (nres < k) {
544
+ faiss::maxheap_push(++nres, D, I, d, v1);
545
+ } else if (d < D[0]) {
546
+ faiss::maxheap_replace_top(nres, D, I, d, v1);
547
+ }
521
548
  }
522
549
  vt.set(v1);
523
550
  }
524
551
 
525
- bool do_dis_check = check_relative_distance;
526
552
  int nstep = 0;
527
553
 
528
554
  while (candidates.size() > 0) {
@@ -541,10 +567,10 @@ int HNSW::search_from_candidates(
541
567
  }
542
568
 
543
569
  size_t begin, end;
544
- neighbor_range(v0, level, &begin, &end);
570
+ hnsw.neighbor_range(v0, level, &begin, &end);
545
571
 
546
572
  for (size_t j = begin; j < end; j++) {
547
- int v1 = neighbors[j];
573
+ int v1 = hnsw.neighbors[j];
548
574
  if (v1 < 0)
549
575
  break;
550
576
  if (vt.get(v1)) {
@@ -553,10 +579,12 @@ int HNSW::search_from_candidates(
553
579
  vt.set(v1);
554
580
  ndis++;
555
581
  float d = qdis(v1);
556
- if (nres < k) {
557
- faiss::maxheap_push(++nres, D, I, d, v1);
558
- } else if (d < D[0]) {
559
- faiss::maxheap_replace_top(nres, D, I, d, v1);
582
+ if (!sel || sel->is_member(v1)) {
583
+ if (nres < k) {
584
+ faiss::maxheap_push(++nres, D, I, d, v1);
585
+ } else if (d < D[0]) {
586
+ faiss::maxheap_replace_top(nres, D, I, d, v1);
587
+ }
560
588
  }
561
589
  candidates.push(v1, d);
562
590
  }
@@ -578,16 +606,13 @@ int HNSW::search_from_candidates(
578
606
  return nres;
579
607
  }
580
608
 
581
- /**************************************************************
582
- * Searching
583
- **************************************************************/
584
-
585
- std::priority_queue<HNSW::Node> HNSW::search_from_candidate_unbounded(
609
+ std::priority_queue<HNSW::Node> search_from_candidate_unbounded(
610
+ const HNSW& hnsw,
586
611
  const Node& node,
587
612
  DistanceComputer& qdis,
588
613
  int ef,
589
614
  VisitedTable* vt,
590
- HNSWStats& stats) const {
615
+ HNSWStats& stats) {
591
616
  int ndis = 0;
592
617
  std::priority_queue<Node> top_candidates;
593
618
  std::priority_queue<Node, std::vector<Node>, std::greater<Node>> candidates;
@@ -609,10 +634,10 @@ std::priority_queue<HNSW::Node> HNSW::search_from_candidate_unbounded(
609
634
  candidates.pop();
610
635
 
611
636
  size_t begin, end;
612
- neighbor_range(v0, 0, &begin, &end);
637
+ hnsw.neighbor_range(v0, 0, &begin, &end);
613
638
 
614
639
  for (size_t j = begin; j < end; ++j) {
615
- int v1 = neighbors[j];
640
+ int v1 = hnsw.neighbors[j];
616
641
 
617
642
  if (v1 < 0) {
618
643
  break;
@@ -646,14 +671,19 @@ std::priority_queue<HNSW::Node> HNSW::search_from_candidate_unbounded(
646
671
  return top_candidates;
647
672
  }
648
673
 
674
+ } // anonymous namespace
675
+
649
676
  HNSWStats HNSW::search(
650
677
  DistanceComputer& qdis,
651
678
  int k,
652
679
  idx_t* I,
653
680
  float* D,
654
- VisitedTable& vt) const {
681
+ VisitedTable& vt,
682
+ const SearchParametersHNSW* params) const {
655
683
  HNSWStats stats;
656
-
684
+ if (entry_point == -1) {
685
+ return stats;
686
+ }
657
687
  if (upper_beam == 1) {
658
688
  // greedy search on upper levels
659
689
  storage_idx_t nearest = entry_point;
@@ -664,16 +694,22 @@ HNSWStats HNSW::search(
664
694
  }
665
695
 
666
696
  int ef = std::max(efSearch, k);
667
- if (search_bounded_queue) {
697
+ if (search_bounded_queue) { // this is the most common branch
668
698
  MinimaxHeap candidates(ef);
669
699
 
670
700
  candidates.push(nearest, d_nearest);
671
701
 
672
- search_from_candidates(qdis, k, I, D, candidates, vt, stats, 0);
702
+ search_from_candidates(
703
+ *this, qdis, k, I, D, candidates, vt, stats, 0, 0, params);
673
704
  } else {
674
705
  std::priority_queue<Node> top_candidates =
675
706
  search_from_candidate_unbounded(
676
- Node(d_nearest, nearest), qdis, ef, &vt, stats);
707
+ *this,
708
+ Node(d_nearest, nearest),
709
+ qdis,
710
+ ef,
711
+ &vt,
712
+ stats);
677
713
 
678
714
  while (top_candidates.size() > k) {
679
715
  top_candidates.pop();
@@ -713,9 +749,10 @@ HNSWStats HNSW::search(
713
749
 
714
750
  if (level == 0) {
715
751
  nres = search_from_candidates(
716
- qdis, k, I, D, candidates, vt, stats, 0);
752
+ *this, qdis, k, I, D, candidates, vt, stats, 0);
717
753
  } else {
718
754
  nres = search_from_candidates(
755
+ *this,
719
756
  qdis,
720
757
  candidates_size,
721
758
  I_to_next.data(),
@@ -732,6 +769,70 @@ HNSWStats HNSW::search(
732
769
  return stats;
733
770
  }
734
771
 
772
+ void HNSW::search_level_0(
773
+ DistanceComputer& qdis,
774
+ int k,
775
+ idx_t* idxi,
776
+ float* simi,
777
+ idx_t nprobe,
778
+ const storage_idx_t* nearest_i,
779
+ const float* nearest_d,
780
+ int search_type,
781
+ HNSWStats& search_stats,
782
+ VisitedTable& vt) const {
783
+ const HNSW& hnsw = *this;
784
+
785
+ if (search_type == 1) {
786
+ int nres = 0;
787
+
788
+ for (int j = 0; j < nprobe; j++) {
789
+ storage_idx_t cj = nearest_i[j];
790
+
791
+ if (cj < 0)
792
+ break;
793
+
794
+ if (vt.get(cj))
795
+ continue;
796
+
797
+ int candidates_size = std::max(hnsw.efSearch, int(k));
798
+ MinimaxHeap candidates(candidates_size);
799
+
800
+ candidates.push(cj, nearest_d[j]);
801
+
802
+ nres = search_from_candidates(
803
+ hnsw,
804
+ qdis,
805
+ k,
806
+ idxi,
807
+ simi,
808
+ candidates,
809
+ vt,
810
+ search_stats,
811
+ 0,
812
+ nres);
813
+ }
814
+ } else if (search_type == 2) {
815
+ int candidates_size = std::max(hnsw.efSearch, int(k));
816
+ candidates_size = std::max(candidates_size, int(nprobe));
817
+
818
+ MinimaxHeap candidates(candidates_size);
819
+ for (int j = 0; j < nprobe; j++) {
820
+ storage_idx_t cj = nearest_i[j];
821
+
822
+ if (cj < 0)
823
+ break;
824
+ candidates.push(cj, nearest_d[j]);
825
+ }
826
+
827
+ search_from_candidates(
828
+ hnsw, qdis, k, idxi, simi, candidates, vt, search_stats, 0);
829
+ }
830
+ }
831
+
832
+ /**************************************************************
833
+ * MinimaxHeap
834
+ **************************************************************/
835
+
735
836
  void HNSW::MinimaxHeap::push(storage_idx_t i, float v) {
736
837
  if (k == n) {
737
838
  if (v >= dis[0])
@@ -43,6 +43,13 @@ struct VisitedTable;
43
43
  struct DistanceComputer; // from AuxIndexStructures
44
44
  struct HNSWStats;
45
45
 
46
+ struct SearchParametersHNSW : SearchParameters {
47
+ int efSearch = 16;
48
+ bool check_relative_distance = true;
49
+
50
+ ~SearchParametersHNSW() {}
51
+ };
52
+
46
53
  struct HNSW {
47
54
  /// internal storage of vectors (32 bits: this is expensive)
48
55
  typedef int storage_idx_t;
@@ -188,30 +195,26 @@ struct HNSW {
188
195
  std::vector<omp_lock_t>& locks,
189
196
  VisitedTable& vt);
190
197
 
191
- int search_from_candidates(
198
+ /// search interface for 1 point, single thread
199
+ HNSWStats search(
192
200
  DistanceComputer& qdis,
193
201
  int k,
194
202
  idx_t* I,
195
203
  float* D,
196
- MinimaxHeap& candidates,
197
204
  VisitedTable& vt,
198
- HNSWStats& stats,
199
- int level,
200
- int nres_in = 0) const;
205
+ const SearchParametersHNSW* params = nullptr) const;
201
206
 
202
- std::priority_queue<Node> search_from_candidate_unbounded(
203
- const Node& node,
204
- DistanceComputer& qdis,
205
- int ef,
206
- VisitedTable* vt,
207
- HNSWStats& stats) const;
208
-
209
- /// search interface
210
- HNSWStats search(
207
+ /// search only in level 0 from a given vertex
208
+ void search_level_0(
211
209
  DistanceComputer& qdis,
212
210
  int k,
213
- idx_t* I,
214
- float* D,
211
+ idx_t* idxi,
212
+ float* simi,
213
+ idx_t nprobe,
214
+ const storage_idx_t* nearest_i,
215
+ const float* nearest_d,
216
+ int search_type,
217
+ HNSWStats& search_stats,
215
218
  VisitedTable& vt) const;
216
219
 
217
220
  void reset();
@@ -0,0 +1,125 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ #include <faiss/impl/FaissAssert.h>
9
+ #include <faiss/impl/IDSelector.h>
10
+
11
+ namespace faiss {
12
+
13
+ /***********************************************************************
14
+ * IDSelectorRange
15
+ ***********************************************************************/
16
+
17
+ IDSelectorRange::IDSelectorRange(idx_t imin, idx_t imax, bool assume_sorted)
18
+ : imin(imin), imax(imax), assume_sorted(assume_sorted) {}
19
+
20
+ bool IDSelectorRange::is_member(idx_t id) const {
21
+ return id >= imin && id < imax;
22
+ }
23
+
24
+ void IDSelectorRange::find_sorted_ids_bounds(
25
+ size_t list_size,
26
+ const idx_t* ids,
27
+ size_t* jmin_out,
28
+ size_t* jmax_out) const {
29
+ FAISS_ASSERT(assume_sorted);
30
+ if (list_size == 0 || imax <= ids[0] || imin > ids[list_size - 1]) {
31
+ *jmin_out = *jmax_out = 0;
32
+ return;
33
+ }
34
+ // bissection to find imin
35
+ if (ids[0] >= imin) {
36
+ *jmin_out = 0;
37
+ } else {
38
+ size_t j0 = 0, j1 = list_size;
39
+ while (j1 > j0 + 1) {
40
+ size_t jmed = (j0 + j1) / 2;
41
+ if (ids[jmed] >= imin) {
42
+ j1 = jmed;
43
+ } else {
44
+ j0 = jmed;
45
+ }
46
+ }
47
+ *jmin_out = j1;
48
+ }
49
+ // bissection to find imax
50
+ if (*jmin_out == list_size || ids[*jmin_out] >= imax) {
51
+ *jmax_out = *jmin_out;
52
+ } else {
53
+ size_t j0 = *jmin_out, j1 = list_size;
54
+ while (j1 > j0 + 1) {
55
+ size_t jmed = (j0 + j1) / 2;
56
+ if (ids[jmed] >= imax) {
57
+ j1 = jmed;
58
+ } else {
59
+ j0 = jmed;
60
+ }
61
+ }
62
+ *jmax_out = j1;
63
+ }
64
+ }
65
+
66
+ /***********************************************************************
67
+ * IDSelectorArray
68
+ ***********************************************************************/
69
+
70
+ IDSelectorArray::IDSelectorArray(size_t n, const idx_t* ids) : n(n), ids(ids) {}
71
+
72
+ bool IDSelectorArray::is_member(idx_t id) const {
73
+ for (idx_t i = 0; i < n; i++) {
74
+ if (ids[i] == id)
75
+ return true;
76
+ }
77
+ return false;
78
+ }
79
+
80
+ /***********************************************************************
81
+ * IDSelectorBatch
82
+ ***********************************************************************/
83
+
84
+ IDSelectorBatch::IDSelectorBatch(size_t n, const idx_t* indices) {
85
+ nbits = 0;
86
+ while (n > ((idx_t)1 << nbits)) {
87
+ nbits++;
88
+ }
89
+ nbits += 5;
90
+ // for n = 1M, nbits = 25 is optimal, see P56659518
91
+
92
+ mask = ((idx_t)1 << nbits) - 1;
93
+ bloom.resize((idx_t)1 << (nbits - 3), 0);
94
+ for (idx_t i = 0; i < n; i++) {
95
+ Index::idx_t id = indices[i];
96
+ set.insert(id);
97
+ id &= mask;
98
+ bloom[id >> 3] |= 1 << (id & 7);
99
+ }
100
+ }
101
+
102
+ bool IDSelectorBatch::is_member(idx_t i) const {
103
+ long im = i & mask;
104
+ if (!(bloom[im >> 3] & (1 << (im & 7)))) {
105
+ return 0;
106
+ }
107
+ return set.count(i);
108
+ }
109
+
110
+ /***********************************************************************
111
+ * IDSelectorBitmap
112
+ ***********************************************************************/
113
+
114
+ IDSelectorBitmap::IDSelectorBitmap(size_t n, const uint8_t* bitmap)
115
+ : n(n), bitmap(bitmap) {}
116
+
117
+ bool IDSelectorBitmap::is_member(idx_t ii) const {
118
+ uint64_t i = ii;
119
+ if ((i >> 3) >= n) {
120
+ return false;
121
+ }
122
+ return (bitmap[i >> 3] >> (i & 7)) & 1;
123
+ }
124
+
125
+ } // namespace faiss