faiss 0.4.3 → 0.5.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (186)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +10 -0
  3. data/README.md +2 -0
  4. data/ext/faiss/index.cpp +33 -6
  5. data/ext/faiss/index_binary.cpp +17 -4
  6. data/ext/faiss/kmeans.cpp +6 -6
  7. data/lib/faiss/version.rb +1 -1
  8. data/vendor/faiss/faiss/AutoTune.cpp +2 -3
  9. data/vendor/faiss/faiss/AutoTune.h +1 -1
  10. data/vendor/faiss/faiss/Clustering.cpp +2 -2
  11. data/vendor/faiss/faiss/Clustering.h +2 -2
  12. data/vendor/faiss/faiss/IVFlib.cpp +26 -51
  13. data/vendor/faiss/faiss/IVFlib.h +1 -1
  14. data/vendor/faiss/faiss/Index.cpp +11 -0
  15. data/vendor/faiss/faiss/Index.h +34 -11
  16. data/vendor/faiss/faiss/Index2Layer.cpp +1 -1
  17. data/vendor/faiss/faiss/Index2Layer.h +2 -2
  18. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +1 -0
  19. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +9 -4
  20. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +5 -1
  21. data/vendor/faiss/faiss/IndexBinary.h +7 -7
  22. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +1 -1
  23. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +8 -2
  24. data/vendor/faiss/faiss/IndexBinaryHNSW.h +1 -1
  25. data/vendor/faiss/faiss/IndexBinaryHash.cpp +3 -3
  26. data/vendor/faiss/faiss/IndexBinaryHash.h +5 -5
  27. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +7 -6
  28. data/vendor/faiss/faiss/IndexFastScan.cpp +125 -49
  29. data/vendor/faiss/faiss/IndexFastScan.h +102 -7
  30. data/vendor/faiss/faiss/IndexFlat.cpp +374 -4
  31. data/vendor/faiss/faiss/IndexFlat.h +81 -1
  32. data/vendor/faiss/faiss/IndexHNSW.cpp +93 -2
  33. data/vendor/faiss/faiss/IndexHNSW.h +58 -2
  34. data/vendor/faiss/faiss/IndexIDMap.cpp +14 -13
  35. data/vendor/faiss/faiss/IndexIDMap.h +6 -6
  36. data/vendor/faiss/faiss/IndexIVF.cpp +1 -1
  37. data/vendor/faiss/faiss/IndexIVF.h +5 -5
  38. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +1 -1
  39. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +9 -3
  40. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +3 -1
  41. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +176 -90
  42. data/vendor/faiss/faiss/IndexIVFFastScan.h +173 -18
  43. data/vendor/faiss/faiss/IndexIVFFlat.cpp +1 -0
  44. data/vendor/faiss/faiss/IndexIVFFlatPanorama.cpp +251 -0
  45. data/vendor/faiss/faiss/IndexIVFFlatPanorama.h +64 -0
  46. data/vendor/faiss/faiss/IndexIVFPQ.cpp +3 -1
  47. data/vendor/faiss/faiss/IndexIVFPQ.h +1 -1
  48. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +134 -2
  49. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +7 -1
  50. data/vendor/faiss/faiss/IndexIVFRaBitQ.cpp +99 -8
  51. data/vendor/faiss/faiss/IndexIVFRaBitQ.h +4 -1
  52. data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.cpp +828 -0
  53. data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.h +252 -0
  54. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +1 -1
  55. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +1 -1
  56. data/vendor/faiss/faiss/IndexNNDescent.cpp +1 -1
  57. data/vendor/faiss/faiss/IndexNSG.cpp +1 -1
  58. data/vendor/faiss/faiss/IndexNeuralNetCodec.h +1 -1
  59. data/vendor/faiss/faiss/IndexPQ.cpp +4 -1
  60. data/vendor/faiss/faiss/IndexPQ.h +1 -1
  61. data/vendor/faiss/faiss/IndexPQFastScan.cpp +6 -2
  62. data/vendor/faiss/faiss/IndexPQFastScan.h +5 -1
  63. data/vendor/faiss/faiss/IndexPreTransform.cpp +14 -0
  64. data/vendor/faiss/faiss/IndexPreTransform.h +9 -0
  65. data/vendor/faiss/faiss/IndexRaBitQ.cpp +96 -13
  66. data/vendor/faiss/faiss/IndexRaBitQ.h +11 -2
  67. data/vendor/faiss/faiss/IndexRaBitQFastScan.cpp +731 -0
  68. data/vendor/faiss/faiss/IndexRaBitQFastScan.h +175 -0
  69. data/vendor/faiss/faiss/IndexRefine.cpp +49 -0
  70. data/vendor/faiss/faiss/IndexRefine.h +17 -0
  71. data/vendor/faiss/faiss/IndexShards.cpp +1 -1
  72. data/vendor/faiss/faiss/MatrixStats.cpp +3 -3
  73. data/vendor/faiss/faiss/MetricType.h +1 -1
  74. data/vendor/faiss/faiss/VectorTransform.h +2 -2
  75. data/vendor/faiss/faiss/clone_index.cpp +5 -1
  76. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +1 -1
  77. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +3 -1
  78. data/vendor/faiss/faiss/gpu/GpuIndex.h +11 -11
  79. data/vendor/faiss/faiss/gpu/GpuIndexBinaryCagra.h +1 -1
  80. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +1 -1
  81. data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +11 -7
  82. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +1 -1
  83. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +2 -0
  84. data/vendor/faiss/faiss/gpu/test/TestGpuIcmEncoder.cpp +7 -0
  85. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +1 -1
  86. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +1 -1
  87. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +1 -1
  88. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +2 -2
  89. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +1 -1
  90. data/vendor/faiss/faiss/impl/CodePacker.h +2 -2
  91. data/vendor/faiss/faiss/impl/DistanceComputer.h +77 -6
  92. data/vendor/faiss/faiss/impl/FastScanDistancePostProcessing.h +53 -0
  93. data/vendor/faiss/faiss/impl/HNSW.cpp +295 -16
  94. data/vendor/faiss/faiss/impl/HNSW.h +35 -6
  95. data/vendor/faiss/faiss/impl/IDSelector.cpp +2 -2
  96. data/vendor/faiss/faiss/impl/IDSelector.h +4 -4
  97. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +4 -4
  98. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +1 -1
  99. data/vendor/faiss/faiss/impl/LookupTableScaler.h +1 -1
  100. data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -1
  101. data/vendor/faiss/faiss/impl/NNDescent.h +2 -2
  102. data/vendor/faiss/faiss/impl/NSG.cpp +1 -1
  103. data/vendor/faiss/faiss/impl/Panorama.cpp +193 -0
  104. data/vendor/faiss/faiss/impl/Panorama.h +204 -0
  105. data/vendor/faiss/faiss/impl/PanoramaStats.cpp +33 -0
  106. data/vendor/faiss/faiss/impl/PanoramaStats.h +38 -0
  107. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +5 -5
  108. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +1 -1
  109. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +1 -1
  110. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +2 -0
  111. data/vendor/faiss/faiss/impl/ProductQuantizer.h +1 -1
  112. data/vendor/faiss/faiss/impl/RaBitQStats.cpp +29 -0
  113. data/vendor/faiss/faiss/impl/RaBitQStats.h +56 -0
  114. data/vendor/faiss/faiss/impl/RaBitQUtils.cpp +294 -0
  115. data/vendor/faiss/faiss/impl/RaBitQUtils.h +330 -0
  116. data/vendor/faiss/faiss/impl/RaBitQuantizer.cpp +304 -223
  117. data/vendor/faiss/faiss/impl/RaBitQuantizer.h +72 -4
  118. data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.cpp +362 -0
  119. data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.h +112 -0
  120. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +1 -1
  121. data/vendor/faiss/faiss/impl/ResultHandler.h +4 -4
  122. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +7 -10
  123. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +2 -4
  124. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +7 -4
  125. data/vendor/faiss/faiss/impl/index_read.cpp +238 -10
  126. data/vendor/faiss/faiss/impl/index_write.cpp +212 -19
  127. data/vendor/faiss/faiss/impl/io.cpp +2 -2
  128. data/vendor/faiss/faiss/impl/io.h +4 -4
  129. data/vendor/faiss/faiss/impl/kmeans1d.cpp +1 -1
  130. data/vendor/faiss/faiss/impl/kmeans1d.h +1 -1
  131. data/vendor/faiss/faiss/impl/lattice_Zn.h +2 -2
  132. data/vendor/faiss/faiss/impl/mapped_io.cpp +2 -2
  133. data/vendor/faiss/faiss/impl/mapped_io.h +4 -3
  134. data/vendor/faiss/faiss/impl/maybe_owned_vector.h +8 -1
  135. data/vendor/faiss/faiss/impl/platform_macros.h +12 -0
  136. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +30 -4
  137. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +14 -8
  138. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +5 -6
  139. data/vendor/faiss/faiss/impl/simd_result_handlers.h +55 -11
  140. data/vendor/faiss/faiss/impl/svs_io.cpp +86 -0
  141. data/vendor/faiss/faiss/impl/svs_io.h +67 -0
  142. data/vendor/faiss/faiss/impl/zerocopy_io.h +1 -1
  143. data/vendor/faiss/faiss/index_factory.cpp +217 -8
  144. data/vendor/faiss/faiss/index_factory.h +1 -1
  145. data/vendor/faiss/faiss/index_io.h +1 -1
  146. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +1 -1
  147. data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -1
  148. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +115 -1
  149. data/vendor/faiss/faiss/invlists/InvertedLists.h +46 -0
  150. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +1 -1
  151. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +1 -1
  152. data/vendor/faiss/faiss/svs/IndexSVSFaissUtils.h +261 -0
  153. data/vendor/faiss/faiss/svs/IndexSVSFlat.cpp +117 -0
  154. data/vendor/faiss/faiss/svs/IndexSVSFlat.h +66 -0
  155. data/vendor/faiss/faiss/svs/IndexSVSVamana.cpp +245 -0
  156. data/vendor/faiss/faiss/svs/IndexSVSVamana.h +137 -0
  157. data/vendor/faiss/faiss/svs/IndexSVSVamanaLVQ.cpp +39 -0
  158. data/vendor/faiss/faiss/svs/IndexSVSVamanaLVQ.h +42 -0
  159. data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.cpp +149 -0
  160. data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.h +58 -0
  161. data/vendor/faiss/faiss/utils/AlignedTable.h +1 -1
  162. data/vendor/faiss/faiss/utils/Heap.cpp +2 -2
  163. data/vendor/faiss/faiss/utils/Heap.h +3 -3
  164. data/vendor/faiss/faiss/utils/NeuralNet.cpp +1 -1
  165. data/vendor/faiss/faiss/utils/NeuralNet.h +3 -3
  166. data/vendor/faiss/faiss/utils/approx_topk/approx_topk.h +2 -2
  167. data/vendor/faiss/faiss/utils/approx_topk/avx2-inl.h +2 -2
  168. data/vendor/faiss/faiss/utils/approx_topk/mode.h +1 -1
  169. data/vendor/faiss/faiss/utils/distances.cpp +0 -3
  170. data/vendor/faiss/faiss/utils/distances.h +2 -2
  171. data/vendor/faiss/faiss/utils/extra_distances-inl.h +3 -1
  172. data/vendor/faiss/faiss/utils/hamming-inl.h +2 -0
  173. data/vendor/faiss/faiss/utils/hamming.cpp +7 -6
  174. data/vendor/faiss/faiss/utils/hamming.h +1 -1
  175. data/vendor/faiss/faiss/utils/hamming_distance/common.h +1 -2
  176. data/vendor/faiss/faiss/utils/partitioning.cpp +5 -5
  177. data/vendor/faiss/faiss/utils/partitioning.h +2 -2
  178. data/vendor/faiss/faiss/utils/rabitq_simd.h +222 -336
  179. data/vendor/faiss/faiss/utils/random.cpp +1 -1
  180. data/vendor/faiss/faiss/utils/simdlib_avx2.h +1 -1
  181. data/vendor/faiss/faiss/utils/simdlib_avx512.h +1 -1
  182. data/vendor/faiss/faiss/utils/simdlib_neon.h +2 -2
  183. data/vendor/faiss/faiss/utils/transpose/transpose-avx512-inl.h +1 -1
  184. data/vendor/faiss/faiss/utils/utils.cpp +9 -2
  185. data/vendor/faiss/faiss/utils/utils.h +2 -2
  186. metadata +29 -1
data/vendor/faiss/faiss/impl/CodePacker.h
@@ -38,14 +38,14 @@ struct CodePacker {
  // code_size
  ) const = 0;
 
- // pack all code in a block
+ // pack all codes in a block
  virtual void pack_all(
  const uint8_t* flat_codes, // codes to write to the block, size
  // (nvec * code_size)
  uint8_t* block // block to write to (size block_size)
  ) const;
 
- // unpack all code in a block
+ // unpack all codes in a block
  virtual void unpack_all(
  const uint8_t* block, // block to read from (size block_size)
  uint8_t* flat_codes // where to write the resulting codes size (nvec
data/vendor/faiss/faiss/impl/DistanceComputer.h
@@ -60,7 +60,7 @@ struct DistanceComputer {
  };
 
  /* Wrap the distance computer into one that negates the
- distances. This makes supporting INNER_PRODUCE search easier */
+ distances. This makes supporting INNER_PRODUCT search easier */
 
  struct NegativeDistanceComputer : DistanceComputer {
  /// owned by this
@@ -100,7 +100,7 @@ struct NegativeDistanceComputer : DistanceComputer {
  return -basedis->symmetric_dis(i, j);
  }
 
- virtual ~NegativeDistanceComputer() {
+ virtual ~NegativeDistanceComputer() override {
  delete basedis;
  }
  };
@@ -113,19 +113,90 @@ struct FlatCodesDistanceComputer : DistanceComputer {
  const uint8_t* codes;
  size_t code_size;
 
- FlatCodesDistanceComputer(const uint8_t* codes, size_t code_size)
- : codes(codes), code_size(code_size) {}
+ const float* q = nullptr; // not used in all distance computers
 
- FlatCodesDistanceComputer() : codes(nullptr), code_size(0) {}
+ FlatCodesDistanceComputer(
+ const uint8_t* codes,
+ size_t code_size,
+ const float* q = nullptr)
+ : codes(codes), code_size(code_size), q(q) {}
+
+ explicit FlatCodesDistanceComputer(const float* q)
+ : codes(nullptr), code_size(0), q(q) {}
+
+ FlatCodesDistanceComputer() : codes(nullptr), code_size(0), q(nullptr) {}
 
  float operator()(idx_t i) override {
  return distance_to_code(codes + i * code_size);
  }
 
+ /// Computes a partial dot product over a slice of the query vector.
+ /// The slice is defined by the following parameters:
+ /// - `offset`: the starting index of the first component to include
+ /// - `num_components`: the number of consecutive components to include
+ ///
+ /// Components refer to raw dimensions of the flat (uncompressed) query
+ /// vector.
+ ///
+ /// By default, this method throws an error, as it is only implemented
+ /// in specific subclasses such as `FlatL2Dis`. Other flat distance
+ /// computers may override this when partial dot product support is needed.
+ ///
+ /// Over time, this method might be changed to a pure virtual function (`=
+ /// 0`) to enforce implementation in subclasses that require this
+ /// functionality.
+ ///
+ /// This method is not part of the generic `DistanceComputer` interface
+ /// because for compressed representations (e.g., product quantization),
+ /// calling `partial_dot_product` repeatedly is often less efficient than
+ /// computing the full distance at once.
+ ///
+ /// Supporting efficient partial scans generally requires a different memory
+ /// layout, such as interleaved blocks that keep SIMD lanes full. This is a
+ /// non-trivial change and not supported in the current flat layout.
+ ///
+ /// For more details on partial (or chunked) dot product computations and
+ /// the performance trade-offs involved, refer to the Panorama paper:
+ /// https://arxiv.org/pdf/2510.00566
+ virtual float partial_dot_product(
+ const idx_t /* i */,
+ const uint32_t /* offset */,
+ const uint32_t /* num_components */) {
+ FAISS_THROW_MSG("partial_dot_product not implemented");
+ }
+
  /// compute distance of current query to an encoded vector
  virtual float distance_to_code(const uint8_t* code) = 0;
 
- virtual ~FlatCodesDistanceComputer() {}
+ /// Compute partial dot products of current query to 4 stored vectors.
+ /// See `partial_dot_product` for more details.
+ virtual void partial_dot_product_batch_4(
+ const idx_t idx0,
+ const idx_t idx1,
+ const idx_t idx2,
+ const idx_t idx3,
+ float& dp0,
+ float& dp1,
+ float& dp2,
+ float& dp3,
+ const uint32_t offset,
+ const uint32_t num_components) {
+ // default implementation for correctness
+ const float d0 =
+ this->partial_dot_product(idx0, offset, num_components);
+ const float d1 =
+ this->partial_dot_product(idx1, offset, num_components);
+ const float d2 =
+ this->partial_dot_product(idx2, offset, num_components);
+ const float d3 =
+ this->partial_dot_product(idx3, offset, num_components);
+ dp0 = d0;
+ dp1 = d1;
+ dp2 = d2;
+ dp3 = d3;
+ }
+
+ virtual ~FlatCodesDistanceComputer() override {}
  };
 
  } // namespace faiss
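
The doc comment above specifies the slice semantics of the new partial_dot_product() hook but does not show an override in full. Below is a minimal sketch of a flat float-vector distance computer implementing it; the class name FlatFloatPartialDis is hypothetical, while the q member and the new FlatCodesDistanceComputer constructor come from the diff above, and fvec_L2sqr / fvec_inner_product are existing helpers declared in <faiss/utils/distances.h>.

    // Hypothetical illustration, not part of the release: raw float vectors
    // are stored as "codes", and the slice [offset, offset + num_components)
    // of the query is dotted against the same slice of a stored vector.
    #include <cstdint>
    #include <faiss/impl/DistanceComputer.h>
    #include <faiss/utils/distances.h>

    struct FlatFloatPartialDis : faiss::FlatCodesDistanceComputer {
        size_t d; // vector dimensionality

        FlatFloatPartialDis(const float* xb, size_t d, const float* query)
                : faiss::FlatCodesDistanceComputer(
                          reinterpret_cast<const uint8_t*>(xb),
                          d * sizeof(float),
                          query),
                  d(d) {}

        void set_query(const float* x) override {
            q = x; // q is the member introduced in this version
        }

        // full L2 distance from the current query to one stored vector
        float distance_to_code(const uint8_t* code) override {
            return faiss::fvec_L2sqr(
                    q, reinterpret_cast<const float*>(code), d);
        }

        // dot product restricted to dimensions [offset, offset + num_components)
        float partial_dot_product(
                const faiss::idx_t i,
                const uint32_t offset,
                const uint32_t num_components) override {
            const float* x =
                    reinterpret_cast<const float*>(codes + i * code_size);
            return faiss::fvec_inner_product(
                    q + offset, x + offset, num_components);
        }

        float symmetric_dis(faiss::idx_t i, faiss::idx_t j) override {
            const float* xi =
                    reinterpret_cast<const float*>(codes + i * code_size);
            const float* xj =
                    reinterpret_cast<const float*>(codes + j * code_size);
            return faiss::fvec_L2sqr(xi, xj, d);
        }
    };

Note that the default partial_dot_product_batch_4() above falls back to four scalar calls, so overriding only the scalar version is already enough for correctness.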
data/vendor/faiss/faiss/impl/FastScanDistancePostProcessing.h (new file)
@@ -0,0 +1,53 @@
+ /*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ *
+ * This source code is licensed under the MIT license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+ #pragma once
+
+ #include <cstddef>
+
+ namespace faiss {
+
+ // Forward declarations
+ struct NormTableScaler;
+
+ namespace rabitq_utils {
+ struct QueryFactorsData;
+ }
+
+ /**
+ * Simple context object that holds processors for FastScan operations.
+ * */
+ struct FastScanDistancePostProcessing {
+ /// Norm scaling processor for Additive Quantizers (nullptr if not needed)
+ const NormTableScaler* norm_scaler = nullptr;
+
+ /// Query factors data pointer for RaBitQ (nullptr if not needed)
+ /// This pointer should point to the beginning of the relevant
+ /// QueryFactorsData subset for this context.
+ rabitq_utils::QueryFactorsData* query_factors = nullptr;
+
+ /// The nprobe value used when allocating query_factors storage.
+ /// This is needed because the allocation size (n * nprobe) may use a
+ /// different nprobe than index->nprobe if search params override it.
+ /// Set to 0 to use index->nprobe as fallback.
+ size_t nprobe = 0;
+
+ /// Default constructor - no processing
+ FastScanDistancePostProcessing() = default;
+
+ /// Check if norm scaling is enabled
+ bool has_norm_scaling() const {
+ return norm_scaler != nullptr;
+ }
+
+ /// Check if query factors processing is enabled
+ bool has_query_processing() const {
+ return query_factors != nullptr;
+ }
+ };
+
+ } // namespace faiss
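
A small usage sketch for the new context struct (illustrative only: the helper function and the scaler argument are assumptions; the fields and accessors are the ones defined in the header above, and NormTableScaler is the existing scaler from <faiss/impl/LookupTableScaler.h>):

    #include <faiss/impl/FastScanDistancePostProcessing.h>
    #include <faiss/impl/LookupTableScaler.h> // defines NormTableScaler

    // Hypothetical helper: build a context that only enables norm scaling.
    faiss::FastScanDistancePostProcessing make_norm_scaling_context(
            const faiss::NormTableScaler& scaler) {
        faiss::FastScanDistancePostProcessing ctx; // default: no processing
        ctx.norm_scaler = &scaler;                 // enable AQ norm scaling
        // query_factors stays nullptr (no RaBitQ query processing) and
        // nprobe stays 0, so index->nprobe is used as the fallback.
        return ctx;
    }

Downstream FastScan code can then branch on ctx.has_norm_scaling() and ctx.has_query_processing().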
data/vendor/faiss/faiss/impl/HNSW.cpp
@@ -9,6 +9,8 @@
 
  #include <cstddef>
 
+ #include <faiss/IndexHNSW.h>
+
  #include <faiss/impl/AuxIndexStructures.h>
  #include <faiss/impl/DistanceComputer.h>
  #include <faiss/impl/IDSelector.h>
@@ -60,7 +62,7 @@ HNSW::HNSW(int M) : rng(12345) {
 
  int HNSW::random_level() {
  double f = rng.rand_float();
- // could be a bit faster with bissection
+ // could be a bit faster with bisection
  for (int level = 0; level < assign_probas.size(); level++) {
  if (f < assign_probas[level]) {
  return level;
@@ -588,6 +590,28 @@ void HNSW::add_with_locks(
  using MinimaxHeap = HNSW::MinimaxHeap;
  using Node = HNSW::Node;
  using C = HNSW::C;
+
+ /** Helper to extract search parameters from HNSW and SearchParameters */
+ static inline void extract_search_params(
+ const HNSW& hnsw,
+ const SearchParameters* params,
+ bool& do_dis_check,
+ int& efSearch,
+ const IDSelector*& sel) {
+ // can be overridden by search params
+ do_dis_check = hnsw.check_relative_distance;
+ efSearch = hnsw.efSearch;
+ sel = nullptr;
+ if (params) {
+ if (const SearchParametersHNSW* hnsw_params =
+ dynamic_cast<const SearchParametersHNSW*>(params)) {
+ do_dis_check = hnsw_params->check_relative_distance;
+ efSearch = hnsw_params->efSearch;
+ }
+ sel = params->sel;
+ }
+ }
+
  /** Do a BFS on the candidates list */
  int search_from_candidates(
  const HNSW& hnsw,
@@ -602,18 +626,10 @@ int search_from_candidates(
  int nres = nres_in;
  int ndis = 0;
 
- // can be overridden by search params
- bool do_dis_check = hnsw.check_relative_distance;
- int efSearch = hnsw.efSearch;
- const IDSelector* sel = nullptr;
- if (params) {
- if (const SearchParametersHNSW* hnsw_params =
- dynamic_cast<const SearchParametersHNSW*>(params)) {
- do_dis_check = hnsw_params->check_relative_distance;
- efSearch = hnsw_params->efSearch;
- }
- sel = params->sel;
- }
+ bool do_dis_check;
+ int efSearch;
+ const IDSelector* sel;
+ extract_search_params(hnsw, params, do_dis_check, efSearch, sel);
 
  C::T threshold = res.threshold;
  for (int i = 0; i < candidates.size(); i++) {
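
For context, the defaults extracted above are the per-index settings; a caller can shadow them for a single query batch by passing a SearchParametersHNSW. A minimal sketch of such a call site (the index and buffers are illustrative; efSearch and check_relative_distance are exactly the fields read by extract_search_params):

    #include <vector>
    #include <faiss/IndexHNSW.h>

    // Illustrative call site: override efSearch and the stopping rule for
    // one batch of queries without touching the index-level defaults.
    void query_with_overrides(
            faiss::IndexHNSWFlat& index,
            const float* queries,
            faiss::idx_t nq,
            faiss::idx_t k) {
        std::vector<float> distances(nq * k);
        std::vector<faiss::idx_t> labels(nq * k);

        faiss::SearchParametersHNSW params;
        params.efSearch = 128;                  // instead of index.hnsw.efSearch
        params.check_relative_distance = false; // instead of the index default

        index.search(nq, queries, k, distances.data(), labels.data(), &params);
    }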
@@ -735,6 +751,253 @@ int search_from_candidates(
  return nres;
  }
 
+
+ int search_from_candidates_panorama(
+ const HNSW& hnsw,
+ const IndexHNSW* index,
+ DistanceComputer& qdis,
+ ResultHandler<C>& res,
+ MinimaxHeap& candidates,
+ VisitedTable& vt,
+ HNSWStats& stats,
+ int level,
+ int nres_in,
+ const SearchParameters* params) {
+ int nres = nres_in;
+ int ndis = 0;
+
+ bool do_dis_check;
+ int efSearch;
+ const IDSelector* sel;
+ extract_search_params(hnsw, params, do_dis_check, efSearch, sel);
+
+ C::T threshold = res.threshold;
+ for (int i = 0; i < candidates.size(); i++) {
+ idx_t v1 = candidates.ids[i];
+ float d = candidates.dis[i];
+ FAISS_ASSERT(v1 >= 0);
+ if (!sel || sel->is_member(v1)) {
+ if (d < threshold) {
+ if (res.add_result(d, v1)) {
+ threshold = res.threshold;
+ }
+ }
+ }
+ vt.set(v1);
+ }
+
+ // Validate the index type so we can access cumulative sums, n_levels, and
+ // get the ability to compute partial dot products.
+ const auto* panorama_index =
+ dynamic_cast<const IndexHNSWFlatPanorama*>(index);
+ FAISS_THROW_IF_NOT_MSG(
+ panorama_index, "Index must be a IndexHNSWFlatPanorama");
+ auto* flat_codes_qdis = dynamic_cast<FlatCodesDistanceComputer*>(&qdis);
+ FAISS_THROW_IF_NOT_MSG(
+ flat_codes_qdis,
+ "DistanceComputer must be a FlatCodesDistanceComputer");
+
+ // Allocate space for the index array and exact distances.
+ size_t M = hnsw.nb_neighbors(0);
+ std::vector<idx_t> index_array(M);
+ std::vector<float> exact_distances(M);
+
+ const float* query = flat_codes_qdis->q;
+ std::vector<float> query_cum_sums(panorama_index->num_panorama_levels + 1);
+ IndexHNSWFlatPanorama::compute_cum_sums(
+ query,
+ query_cum_sums.data(),
+ panorama_index->d,
+ panorama_index->num_panorama_levels,
+ panorama_index->panorama_level_width);
+ float query_norm_sq = query_cum_sums[0] * query_cum_sums[0];
+
+ int nstep = 0;
+
+ while (candidates.size() > 0) {
+ float d0 = 0;
+ int v0 = candidates.pop_min(&d0);
+
+ if (do_dis_check) {
+ // tricky stopping condition: there are more than ef
+ // distances that are processed already that are smaller
+ // than d0
+
+ int n_dis_below = candidates.count_below(d0);
+ if (n_dis_below >= efSearch) {
+ break;
+ }
+ }
+
+ size_t begin, end;
+ hnsw.neighbor_range(v0, level, &begin, &end);
+
+ // Unlike the vanilla HNSW, we already remove (and compact) the visited
+ // nodes from the candidates list at this stage. We also remove nodes
+ // that are not selected.
+ size_t initial_size = 0;
+ for (size_t j = begin; j < end; j++) {
+ int v1 = hnsw.neighbors[j];
+ if (v1 < 0) {
+ break;
+ }
+
+ const float* cum_sums_v1 = panorama_index->get_cum_sum(v1);
+ index_array[initial_size] = v1;
+ exact_distances[initial_size] =
+ query_norm_sq + cum_sums_v1[0] * cum_sums_v1[0];
+
+ bool is_selected = !sel || sel->is_member(v1);
+ initial_size += is_selected && !vt.get(v1) ? 1 : 0;
+
+ vt.set(v1);
+ }
+
+ size_t batch_size = initial_size;
+ size_t curr_panorama_level = 0;
+ const size_t num_panorama_levels = panorama_index->num_panorama_levels;
+ while (curr_panorama_level < num_panorama_levels && batch_size > 0) {
+ float query_cum_norm = query_cum_sums[curr_panorama_level + 1];
+
+ const size_t panorama_level_width =
+ panorama_index->panorama_level_width;
+ size_t start_dim = curr_panorama_level * panorama_level_width;
+ size_t end_dim = (curr_panorama_level + 1) * panorama_level_width;
+ end_dim = std::min(end_dim, static_cast<size_t>(panorama_index->d));
+
+ size_t i = 0;
+ size_t next_batch_size = 0;
+ for (; i + 3 < batch_size; i += 4) {
+ idx_t idx_0 = index_array[i];
+ idx_t idx_1 = index_array[i + 1];
+ idx_t idx_2 = index_array[i + 2];
+ idx_t idx_3 = index_array[i + 3];
+
+ float dp[4];
+ flat_codes_qdis->partial_dot_product_batch_4(
+ idx_0,
+ idx_1,
+ idx_2,
+ idx_3,
+ dp[0],
+ dp[1],
+ dp[2],
+ dp[3],
+ start_dim,
+ end_dim - start_dim);
+ ndis += 4;
+
+ float new_exact_0 = exact_distances[i + 0] - 2 * dp[0];
+ float new_exact_1 = exact_distances[i + 1] - 2 * dp[1];
+ float new_exact_2 = exact_distances[i + 2] - 2 * dp[2];
+ float new_exact_3 = exact_distances[i + 3] - 2 * dp[3];
+
+ float cum_sum_0 = panorama_index->get_cum_sum(
+ idx_0)[curr_panorama_level + 1];
+ float cum_sum_1 = panorama_index->get_cum_sum(
+ idx_1)[curr_panorama_level + 1];
+ float cum_sum_2 = panorama_index->get_cum_sum(
+ idx_2)[curr_panorama_level + 1];
+ float cum_sum_3 = panorama_index->get_cum_sum(
+ idx_3)[curr_panorama_level + 1];
+
+ float cs_bound_0 = 2.0f * cum_sum_0 * query_cum_norm;
+ float cs_bound_1 = 2.0f * cum_sum_1 * query_cum_norm;
+ float cs_bound_2 = 2.0f * cum_sum_2 * query_cum_norm;
+ float cs_bound_3 = 2.0f * cum_sum_3 * query_cum_norm;
+
+ float lower_bound_0 = new_exact_0 - cs_bound_0;
+ float lower_bound_1 = new_exact_1 - cs_bound_1;
+ float lower_bound_2 = new_exact_2 - cs_bound_2;
+ float lower_bound_3 = new_exact_3 - cs_bound_3;
+
+ // The following code is not the most branch friendly (due to
+ // the maintenance of the candidate heap), but micro-benchmarks
+ // have shown that it is not worth it to write horrible code to
+ // squeeze out those cycles.
+ if (lower_bound_0 <= threshold) {
+ exact_distances[next_batch_size] = new_exact_0;
+ index_array[next_batch_size] = idx_0;
+ next_batch_size += 1;
+ } else {
+ candidates.push(idx_0, new_exact_0);
+ }
+ if (lower_bound_1 <= threshold) {
+ exact_distances[next_batch_size] = new_exact_1;
+ index_array[next_batch_size] = idx_1;
+ next_batch_size += 1;
+ } else {
+ candidates.push(idx_1, new_exact_1);
+ }
+ if (lower_bound_2 <= threshold) {
+ exact_distances[next_batch_size] = new_exact_2;
+ index_array[next_batch_size] = idx_2;
+ next_batch_size += 1;
+ } else {
+ candidates.push(idx_2, new_exact_2);
+ }
+ if (lower_bound_3 <= threshold) {
+ exact_distances[next_batch_size] = new_exact_3;
+ index_array[next_batch_size] = idx_3;
+ next_batch_size += 1;
+ } else {
+ candidates.push(idx_3, new_exact_3);
+ }
+ }
+
+ // Process the remaining candidates.
+ for (; i < batch_size; i++) {
+ idx_t idx = index_array[i];
+
+ float dp = flat_codes_qdis->partial_dot_product(
+ idx, start_dim, end_dim - start_dim);
+ ndis += 1;
+ float new_exact = exact_distances[i] - 2.0f * dp;
+
+ float cum_sum = panorama_index->get_cum_sum(
+ idx)[curr_panorama_level + 1];
+ float cs_bound = 2.0f * cum_sum * query_cum_norm;
+ float lower_bound = new_exact - cs_bound;
+
+ if (lower_bound <= threshold) {
+ exact_distances[next_batch_size] = new_exact;
+ index_array[next_batch_size] = idx;
+ next_batch_size += 1;
+ } else {
+ candidates.push(idx, new_exact);
+ }
+ }
+
+ batch_size = next_batch_size;
+ curr_panorama_level++;
+ }
+
+ // Add surviving candidates to the result handler.
+ for (size_t i = 0; i < batch_size; i++) {
+ idx_t idx = index_array[i];
+ if (res.add_result(exact_distances[i], idx)) {
+ nres += 1;
+ }
+ candidates.push(idx, exact_distances[i]);
+ }
+
+ nstep++;
+ if (!do_dis_check && nstep > efSearch) {
+ break;
+ }
+ }
+
+ if (level == 0) {
+ stats.n1++;
+ if (candidates.size() == 0) {
+ stats.n2++;
+ }
+ stats.ndis += ndis;
+ stats.nhops += nstep;
+ }
+
+ return nres;
+ }
+
  std::priority_queue<HNSW::Node> search_from_candidate_unbounded(
  const HNSW& hnsw,
  const Node& node,
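
The pruning rule implemented by the loop above can be written compactly. This is a sketch under the assumption, implied by compute_cum_sums and get_cum_sum, that entry l of each cumulative-sum array holds the L2 norm of the vector's components from the start of panorama level l onward:

    % Running estimate after processing the first L levels (dimensions 0 .. D_L - 1):
    \[
    \texttt{new\_exact} \;=\; \|q\|^2 + \|x\|^2 \;-\; 2\sum_{j < D_L} q_j x_j .
    \]
    % Cauchy-Schwarz on the remaining dimensions bounds the unseen cross term,
    % giving the lower_bound compared against the result-heap threshold:
    \[
    \|q - x\|^2 \;=\; \texttt{new\_exact} \;-\; 2\sum_{j \ge D_L} q_j x_j
    \;\ge\; \texttt{new\_exact} \;-\; 2\,\|q_{\ge D_L}\|\,\|x_{\ge D_L}\|
    \;=\; \texttt{lower\_bound} .
    \]

Candidates whose lower_bound stays at or below the current threshold carry their refined new_exact into the next level; the others are pushed back onto the candidate heap with the partial estimate, which is exactly what the if/else branches in the 4-way unrolled and scalar loops do.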
@@ -936,6 +1199,7 @@ int extract_k_from_ResultHandler(ResultHandler<C>& res) {
 
  HNSWStats HNSW::search(
  DistanceComputer& qdis,
+ const IndexHNSW* index,
  ResultHandler<C>& res,
  VisitedTable& vt,
  const SearchParameters* params) const {
@@ -966,13 +1230,28 @@ HNSWStats HNSW::search(
  }
 
  int ef = std::max(efSearch, k);
- if (bounded_queue) { // this is the most common branch
+ if (bounded_queue) { // this is the most common branch, for now we only
+ // support Panorama search in this branch
  MinimaxHeap candidates(ef);
 
  candidates.push(nearest, d_nearest);
 
- search_from_candidates(
- *this, qdis, res, candidates, vt, stats, 0, 0, params);
+ if (!is_panorama) {
+ search_from_candidates(
+ *this, qdis, res, candidates, vt, stats, 0, 0, params);
+ } else {
+ search_from_candidates_panorama(
+ *this,
+ index,
+ qdis,
+ res,
+ candidates,
+ vt,
+ stats,
+ 0,
+ 0,
+ params);
+ }
  } else {
  std::priority_queue<Node> top_candidates =
  search_from_candidate_unbounded(
data/vendor/faiss/faiss/impl/HNSW.h
@@ -8,12 +8,12 @@
  #pragma once
 
  #include <queue>
- #include <unordered_set>
  #include <vector>
 
  #include <omp.h>
 
  #include <faiss/Index.h>
+ #include <faiss/impl/DistanceComputer.h>
  #include <faiss/impl/FaissAssert.h>
  #include <faiss/impl/maybe_owned_vector.h>
  #include <faiss/impl/platform_macros.h>
@@ -22,6 +22,10 @@
 
  namespace faiss {
 
+ // Forward declarations to avoid circular dependency.
+ struct IndexHNSW;
+ struct IndexHNSWFlatPanorama;
+
  /** Implementation of the Hierarchical Navigable Small World
  * datastructure.
  *
@@ -31,7 +35,7 @@ namespace faiss {
  * Yu. A. Malkov, D. A. Yashunin, arXiv 2017
  *
  * This implementation is heavily influenced by the NMSlib
- * implementation by Yury Malkov and Leonid Boystov
+ * implementation by Yury Malkov and Leonid Boytsov
  * (https://github.com/searchivarius/nmslib)
  *
  * The HNSW object stores only the neighbor link structure, see
@@ -61,7 +65,7 @@ struct HNSW {
 
  typedef std::pair<float, storage_idx_t> Node;
 
- /** Heap structure that allows fast
+ /** Heap structure that allows fast access and updates.
  */
  struct MinimaxHeap {
  int n;
@@ -87,7 +91,7 @@ struct HNSW {
  int count_below(float thresh);
  };
 
- /// to sort pairs of (id, distance) from nearest to fathest or the reverse
+ /// to sort pairs of (id, distance) from nearest to farthest or the reverse
  struct NodeDistCloser {
  float d;
  int id;
@@ -146,6 +150,9 @@ struct HNSW {
  /// use bounded queue during exploration
  bool search_bounded_queue = true;
 
+ /// use Panorama progressive pruning in search
+ bool is_panorama = false;
+
  // methods that initialize the tree sizes
 
  /// initialize the assign_probas and cum_nneighbor_per_level to
@@ -160,7 +167,7 @@ struct HNSW {
  /// nb of neighbors for this level
  int nb_neighbors(int layer_no) const;
 
- /// cumumlative nb up to (and excluding) this level
+ /// cumulative nb up to (and excluding) this level
  int cum_nb_neighbors(int layer_no) const;
 
  /// range of entries in the neighbors table of vertex no at layer_no
@@ -196,9 +203,15 @@ struct HNSW {
  VisitedTable& vt,
  bool keep_max_size_level0 = false);
 
- /// search interface for 1 point, single thread
+ /// Search interface for 1 point, single thread
+ ///
+ /// NOTE: We pass a reference to the index itself to allow for additional
+ /// state information to be passed (used for Panorama progressive pruning).
+ /// The alternative would be to override both HNSW::search and
+ /// HNSWIndex::search, which would be a nuisance of code duplication.
  HNSWStats search(
  DistanceComputer& qdis,
+ const IndexHNSW* index,
  ResultHandler<C>& res,
  VisitedTable& vt,
  const SearchParameters* params = nullptr) const;
@@ -267,6 +280,22 @@
  int nres_in = 0,
  const SearchParameters* params = nullptr);
 
+ /// Equivalent to `search_from_candidates`, but applies pruning with progressive
+ /// refinement bounds.
+ /// This is used in `IndexHNSWFlatPanorama` to improve the search performance
+ /// for higher dimensional vectors.
+ int search_from_candidates_panorama(
+ const HNSW& hnsw,
+ const IndexHNSW* index,
+ DistanceComputer& qdis,
+ ResultHandler<HNSW::C>& res,
+ HNSW::MinimaxHeap& candidates,
+ VisitedTable& vt,
+ HNSWStats& stats,
+ int level,
+ int nres_in = 0,
+ const SearchParameters* params = nullptr);
+
  HNSWStats greedy_update_nearest(
  const HNSW& hnsw,
  DistanceComputer& qdis,
data/vendor/faiss/faiss/impl/IDSelector.cpp
@@ -31,7 +31,7 @@ void IDSelectorRange::find_sorted_ids_bounds(
  *jmin_out = *jmax_out = 0;
  return;
  }
- // bissection to find imin
+ // bisection to find imin
  if (ids[0] >= imin) {
  *jmin_out = 0;
  } else {
@@ -46,7 +46,7 @@ void IDSelectorRange::find_sorted_ids_bounds(
  }
  *jmin_out = j1;
  }
- // bissection to find imax
+ // bisection to find imax
  if (*jmin_out == list_size || ids[*jmin_out] >= imax) {
  *jmax_out = *jmin_out;
  } else {