faiss 0.4.3 → 0.5.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (186)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +10 -0
  3. data/README.md +2 -0
  4. data/ext/faiss/index.cpp +33 -6
  5. data/ext/faiss/index_binary.cpp +17 -4
  6. data/ext/faiss/kmeans.cpp +6 -6
  7. data/lib/faiss/version.rb +1 -1
  8. data/vendor/faiss/faiss/AutoTune.cpp +2 -3
  9. data/vendor/faiss/faiss/AutoTune.h +1 -1
  10. data/vendor/faiss/faiss/Clustering.cpp +2 -2
  11. data/vendor/faiss/faiss/Clustering.h +2 -2
  12. data/vendor/faiss/faiss/IVFlib.cpp +26 -51
  13. data/vendor/faiss/faiss/IVFlib.h +1 -1
  14. data/vendor/faiss/faiss/Index.cpp +11 -0
  15. data/vendor/faiss/faiss/Index.h +34 -11
  16. data/vendor/faiss/faiss/Index2Layer.cpp +1 -1
  17. data/vendor/faiss/faiss/Index2Layer.h +2 -2
  18. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +1 -0
  19. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +9 -4
  20. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +5 -1
  21. data/vendor/faiss/faiss/IndexBinary.h +7 -7
  22. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +1 -1
  23. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +8 -2
  24. data/vendor/faiss/faiss/IndexBinaryHNSW.h +1 -1
  25. data/vendor/faiss/faiss/IndexBinaryHash.cpp +3 -3
  26. data/vendor/faiss/faiss/IndexBinaryHash.h +5 -5
  27. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +7 -6
  28. data/vendor/faiss/faiss/IndexFastScan.cpp +125 -49
  29. data/vendor/faiss/faiss/IndexFastScan.h +102 -7
  30. data/vendor/faiss/faiss/IndexFlat.cpp +374 -4
  31. data/vendor/faiss/faiss/IndexFlat.h +81 -1
  32. data/vendor/faiss/faiss/IndexHNSW.cpp +93 -2
  33. data/vendor/faiss/faiss/IndexHNSW.h +58 -2
  34. data/vendor/faiss/faiss/IndexIDMap.cpp +14 -13
  35. data/vendor/faiss/faiss/IndexIDMap.h +6 -6
  36. data/vendor/faiss/faiss/IndexIVF.cpp +1 -1
  37. data/vendor/faiss/faiss/IndexIVF.h +5 -5
  38. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +1 -1
  39. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +9 -3
  40. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +3 -1
  41. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +176 -90
  42. data/vendor/faiss/faiss/IndexIVFFastScan.h +173 -18
  43. data/vendor/faiss/faiss/IndexIVFFlat.cpp +1 -0
  44. data/vendor/faiss/faiss/IndexIVFFlatPanorama.cpp +251 -0
  45. data/vendor/faiss/faiss/IndexIVFFlatPanorama.h +64 -0
  46. data/vendor/faiss/faiss/IndexIVFPQ.cpp +3 -1
  47. data/vendor/faiss/faiss/IndexIVFPQ.h +1 -1
  48. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +134 -2
  49. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +7 -1
  50. data/vendor/faiss/faiss/IndexIVFRaBitQ.cpp +99 -8
  51. data/vendor/faiss/faiss/IndexIVFRaBitQ.h +4 -1
  52. data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.cpp +828 -0
  53. data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.h +252 -0
  54. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +1 -1
  55. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +1 -1
  56. data/vendor/faiss/faiss/IndexNNDescent.cpp +1 -1
  57. data/vendor/faiss/faiss/IndexNSG.cpp +1 -1
  58. data/vendor/faiss/faiss/IndexNeuralNetCodec.h +1 -1
  59. data/vendor/faiss/faiss/IndexPQ.cpp +4 -1
  60. data/vendor/faiss/faiss/IndexPQ.h +1 -1
  61. data/vendor/faiss/faiss/IndexPQFastScan.cpp +6 -2
  62. data/vendor/faiss/faiss/IndexPQFastScan.h +5 -1
  63. data/vendor/faiss/faiss/IndexPreTransform.cpp +14 -0
  64. data/vendor/faiss/faiss/IndexPreTransform.h +9 -0
  65. data/vendor/faiss/faiss/IndexRaBitQ.cpp +96 -13
  66. data/vendor/faiss/faiss/IndexRaBitQ.h +11 -2
  67. data/vendor/faiss/faiss/IndexRaBitQFastScan.cpp +731 -0
  68. data/vendor/faiss/faiss/IndexRaBitQFastScan.h +175 -0
  69. data/vendor/faiss/faiss/IndexRefine.cpp +49 -0
  70. data/vendor/faiss/faiss/IndexRefine.h +17 -0
  71. data/vendor/faiss/faiss/IndexShards.cpp +1 -1
  72. data/vendor/faiss/faiss/MatrixStats.cpp +3 -3
  73. data/vendor/faiss/faiss/MetricType.h +1 -1
  74. data/vendor/faiss/faiss/VectorTransform.h +2 -2
  75. data/vendor/faiss/faiss/clone_index.cpp +5 -1
  76. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +1 -1
  77. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +3 -1
  78. data/vendor/faiss/faiss/gpu/GpuIndex.h +11 -11
  79. data/vendor/faiss/faiss/gpu/GpuIndexBinaryCagra.h +1 -1
  80. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +1 -1
  81. data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +11 -7
  82. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +1 -1
  83. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +2 -0
  84. data/vendor/faiss/faiss/gpu/test/TestGpuIcmEncoder.cpp +7 -0
  85. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +1 -1
  86. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +1 -1
  87. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +1 -1
  88. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +2 -2
  89. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +1 -1
  90. data/vendor/faiss/faiss/impl/CodePacker.h +2 -2
  91. data/vendor/faiss/faiss/impl/DistanceComputer.h +77 -6
  92. data/vendor/faiss/faiss/impl/FastScanDistancePostProcessing.h +53 -0
  93. data/vendor/faiss/faiss/impl/HNSW.cpp +295 -16
  94. data/vendor/faiss/faiss/impl/HNSW.h +35 -6
  95. data/vendor/faiss/faiss/impl/IDSelector.cpp +2 -2
  96. data/vendor/faiss/faiss/impl/IDSelector.h +4 -4
  97. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +4 -4
  98. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +1 -1
  99. data/vendor/faiss/faiss/impl/LookupTableScaler.h +1 -1
  100. data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -1
  101. data/vendor/faiss/faiss/impl/NNDescent.h +2 -2
  102. data/vendor/faiss/faiss/impl/NSG.cpp +1 -1
  103. data/vendor/faiss/faiss/impl/Panorama.cpp +193 -0
  104. data/vendor/faiss/faiss/impl/Panorama.h +204 -0
  105. data/vendor/faiss/faiss/impl/PanoramaStats.cpp +33 -0
  106. data/vendor/faiss/faiss/impl/PanoramaStats.h +38 -0
  107. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +5 -5
  108. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +1 -1
  109. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +1 -1
  110. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +2 -0
  111. data/vendor/faiss/faiss/impl/ProductQuantizer.h +1 -1
  112. data/vendor/faiss/faiss/impl/RaBitQStats.cpp +29 -0
  113. data/vendor/faiss/faiss/impl/RaBitQStats.h +56 -0
  114. data/vendor/faiss/faiss/impl/RaBitQUtils.cpp +294 -0
  115. data/vendor/faiss/faiss/impl/RaBitQUtils.h +330 -0
  116. data/vendor/faiss/faiss/impl/RaBitQuantizer.cpp +304 -223
  117. data/vendor/faiss/faiss/impl/RaBitQuantizer.h +72 -4
  118. data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.cpp +362 -0
  119. data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.h +112 -0
  120. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +1 -1
  121. data/vendor/faiss/faiss/impl/ResultHandler.h +4 -4
  122. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +7 -10
  123. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +2 -4
  124. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +7 -4
  125. data/vendor/faiss/faiss/impl/index_read.cpp +238 -10
  126. data/vendor/faiss/faiss/impl/index_write.cpp +212 -19
  127. data/vendor/faiss/faiss/impl/io.cpp +2 -2
  128. data/vendor/faiss/faiss/impl/io.h +4 -4
  129. data/vendor/faiss/faiss/impl/kmeans1d.cpp +1 -1
  130. data/vendor/faiss/faiss/impl/kmeans1d.h +1 -1
  131. data/vendor/faiss/faiss/impl/lattice_Zn.h +2 -2
  132. data/vendor/faiss/faiss/impl/mapped_io.cpp +2 -2
  133. data/vendor/faiss/faiss/impl/mapped_io.h +4 -3
  134. data/vendor/faiss/faiss/impl/maybe_owned_vector.h +8 -1
  135. data/vendor/faiss/faiss/impl/platform_macros.h +12 -0
  136. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +30 -4
  137. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +14 -8
  138. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +5 -6
  139. data/vendor/faiss/faiss/impl/simd_result_handlers.h +55 -11
  140. data/vendor/faiss/faiss/impl/svs_io.cpp +86 -0
  141. data/vendor/faiss/faiss/impl/svs_io.h +67 -0
  142. data/vendor/faiss/faiss/impl/zerocopy_io.h +1 -1
  143. data/vendor/faiss/faiss/index_factory.cpp +217 -8
  144. data/vendor/faiss/faiss/index_factory.h +1 -1
  145. data/vendor/faiss/faiss/index_io.h +1 -1
  146. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +1 -1
  147. data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -1
  148. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +115 -1
  149. data/vendor/faiss/faiss/invlists/InvertedLists.h +46 -0
  150. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +1 -1
  151. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +1 -1
  152. data/vendor/faiss/faiss/svs/IndexSVSFaissUtils.h +261 -0
  153. data/vendor/faiss/faiss/svs/IndexSVSFlat.cpp +117 -0
  154. data/vendor/faiss/faiss/svs/IndexSVSFlat.h +66 -0
  155. data/vendor/faiss/faiss/svs/IndexSVSVamana.cpp +245 -0
  156. data/vendor/faiss/faiss/svs/IndexSVSVamana.h +137 -0
  157. data/vendor/faiss/faiss/svs/IndexSVSVamanaLVQ.cpp +39 -0
  158. data/vendor/faiss/faiss/svs/IndexSVSVamanaLVQ.h +42 -0
  159. data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.cpp +149 -0
  160. data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.h +58 -0
  161. data/vendor/faiss/faiss/utils/AlignedTable.h +1 -1
  162. data/vendor/faiss/faiss/utils/Heap.cpp +2 -2
  163. data/vendor/faiss/faiss/utils/Heap.h +3 -3
  164. data/vendor/faiss/faiss/utils/NeuralNet.cpp +1 -1
  165. data/vendor/faiss/faiss/utils/NeuralNet.h +3 -3
  166. data/vendor/faiss/faiss/utils/approx_topk/approx_topk.h +2 -2
  167. data/vendor/faiss/faiss/utils/approx_topk/avx2-inl.h +2 -2
  168. data/vendor/faiss/faiss/utils/approx_topk/mode.h +1 -1
  169. data/vendor/faiss/faiss/utils/distances.cpp +0 -3
  170. data/vendor/faiss/faiss/utils/distances.h +2 -2
  171. data/vendor/faiss/faiss/utils/extra_distances-inl.h +3 -1
  172. data/vendor/faiss/faiss/utils/hamming-inl.h +2 -0
  173. data/vendor/faiss/faiss/utils/hamming.cpp +7 -6
  174. data/vendor/faiss/faiss/utils/hamming.h +1 -1
  175. data/vendor/faiss/faiss/utils/hamming_distance/common.h +1 -2
  176. data/vendor/faiss/faiss/utils/partitioning.cpp +5 -5
  177. data/vendor/faiss/faiss/utils/partitioning.h +2 -2
  178. data/vendor/faiss/faiss/utils/rabitq_simd.h +222 -336
  179. data/vendor/faiss/faiss/utils/random.cpp +1 -1
  180. data/vendor/faiss/faiss/utils/simdlib_avx2.h +1 -1
  181. data/vendor/faiss/faiss/utils/simdlib_avx512.h +1 -1
  182. data/vendor/faiss/faiss/utils/simdlib_neon.h +2 -2
  183. data/vendor/faiss/faiss/utils/transpose/transpose-avx512-inl.h +1 -1
  184. data/vendor/faiss/faiss/utils/utils.cpp +9 -2
  185. data/vendor/faiss/faiss/utils/utils.h +2 -2
  186. metadata +29 -1
data/vendor/faiss/faiss/IndexFastScan.h

@@ -8,12 +8,15 @@
  #pragma once

  #include <faiss/Index.h>
+ #include <faiss/impl/FastScanDistancePostProcessing.h>
  #include <faiss/utils/AlignedTable.h>

  namespace faiss {

  struct CodePacker;
  struct NormTableScaler;
+ struct IDSelector;
+ struct SIMDResultHandlerToFloat;

  /** Fast scan version of IndexPQ and IndexAQ. Works for 4-bit PQ and AQ for now.
  *
@@ -54,6 +57,14 @@ struct IndexFastScan : Index {
  // (set when initialized by IndexPQ or IndexAQ)
  const uint8_t* orig_codes = nullptr;

+ /** Initialize the fast scan index
+ *
+ * @param d dimensionality of vectors
+ * @param M number of subquantizers
+ * @param nbits number of bits per subquantizer
+ * @param metric distance metric to use
+ * @param bbs block size for SIMD processing
+ */
  void init_fastscan(
  int d,
  size_t M,
@@ -65,6 +76,15 @@ struct IndexFastScan : Index {

  void reset() override;

+ /** Search for k nearest neighbors
+ *
+ * @param n number of query vectors
+ * @param x query vectors (n * d)
+ * @param k number of nearest neighbors to find
+ * @param distances output distances (n * k)
+ * @param labels output labels/indices (n * k)
+ * @param params optional search parameters
+ */
  void search(
  idx_t n,
  const float* x,
@@ -73,20 +93,70 @@ struct IndexFastScan : Index {
  idx_t* labels,
  const SearchParameters* params = nullptr) const override;

+ /** Add vectors to the index
+ *
+ * @param n number of vectors to add
+ * @param x vectors to add (n * d)
+ */
  void add(idx_t n, const float* x) override;

+ /** Compute codes for vectors
+ *
+ * @param codes output codes
+ * @param n number of vectors to encode
+ * @param x vectors to encode (n * d)
+ */
  virtual void compute_codes(uint8_t* codes, idx_t n, const float* x)
  const = 0;

- virtual void compute_float_LUT(float* lut, idx_t n, const float* x)
- const = 0;
+ /** Compute floating-point lookup table for distance computation
+ *
+ * @param lut output lookup table
+ * @param n number of query vectors
+ * @param x query vectors (n * d)
+ * @param context processing context containing all processors
+ */
+ virtual void compute_float_LUT(
+ float* lut,
+ idx_t n,
+ const float* x,
+ const FastScanDistancePostProcessing& context) const = 0;
+
+ /** Create a KNN handler for this index type
+ *
+ * This method can be overridden by derived classes to provide
+ * specialized handlers (e.g., RaBitQHeapHandler for RaBitQ indexes).
+ * Base implementation creates standard handlers based on k and impl.
+ *
+ * @param is_max whether to use CMax comparator (true) or CMin (false)
+ * @param impl implementation number
+ * @param n number of queries
+ * @param k number of neighbors to find
+ * @param ntotal total number of vectors in database
+ * @param distances output distances array
+ * @param labels output labels array
+ * @param sel optional ID selector
+ * @param context processing context for distance post-processing
+ * @return pointer to created handler (never returns nullptr)
+ */
+ virtual SIMDResultHandlerToFloat* make_knn_handler(
+ bool is_max,
+ int impl,
+ idx_t n,
+ idx_t k,
+ size_t ntotal,
+ float* distances,
+ idx_t* labels,
+ const IDSelector* sel,
+ const FastScanDistancePostProcessing& context) const;

  // called by search function
  void compute_quantized_LUT(
  idx_t n,
  const float* x,
  uint8_t* lut,
- float* normalizers) const;
+ float* normalizers,
+ const FastScanDistancePostProcessing& context) const;

  template <bool is_max>
  void search_dispatch_implem(
@@ -95,7 +165,7 @@ struct IndexFastScan : Index {
  idx_t k,
  float* distances,
  idx_t* labels,
- const NormTableScaler* scaler) const;
+ const FastScanDistancePostProcessing& context) const;

  template <class Cfloat>
  void search_implem_234(
@@ -104,7 +174,7 @@ struct IndexFastScan : Index {
  idx_t k,
  float* distances,
  idx_t* labels,
- const NormTableScaler* scaler) const;
+ const FastScanDistancePostProcessing& context) const;

  template <class C>
  void search_implem_12(
@@ -114,7 +184,7 @@ struct IndexFastScan : Index {
  float* distances,
  idx_t* labels,
  int impl,
- const NormTableScaler* scaler) const;
+ const FastScanDistancePostProcessing& context) const;

  template <class C>
  void search_implem_14(
@@ -124,14 +194,39 @@ struct IndexFastScan : Index {
  float* distances,
  idx_t* labels,
  int impl,
- const NormTableScaler* scaler) const;
+ const FastScanDistancePostProcessing& context) const;

+ /** Reconstruct a vector from its code
+ *
+ * @param key index of vector to reconstruct
+ * @param recons output reconstructed vector
+ */
  void reconstruct(idx_t key, float* recons) const override;
+
+ /** Remove vectors by ID selector
+ *
+ * @param sel selector defining which vectors to remove
+ * @return number of vectors removed
+ */
  size_t remove_ids(const IDSelector& sel) override;

+ /** Get the code packer for this index
+ *
+ * @return pointer to the code packer
+ */
  CodePacker* get_CodePacker() const;

+ /** Merge another index into this one
+ *
+ * @param otherIndex index to merge from
+ * @param add_id ID offset to add to merged vectors
+ */
  void merge_from(Index& otherIndex, idx_t add_id = 0) override;
+
+ /** Check if another index is compatible for merging
+ *
+ * @param otherIndex index to check compatibility with
+ */
  void check_compatible_for_merge(const Index& otherIndex) const override;

  /// standalone codes interface (but the codes are flattened)
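
Note (not part of the diff): the doc comments above describe the caller-facing entry points (init_fastscan, add, search), while the new FastScanDistancePostProcessing context only appears on the internal hooks (compute_float_LUT, make_knn_handler, compute_quantized_LUT, the search_implem_* helpers), so the public search signature is unchanged. A minimal usage sketch, assuming the concrete subclass IndexPQFastScan with its usual (d, M, nbits) constructor:

#include <faiss/IndexPQFastScan.h>

#include <random>
#include <vector>

int main() {
    int d = 64;         // dimensionality of vectors
    size_t M = 16;      // number of subquantizers
    size_t nb = 10000;  // database size (illustrative)

    // 4-bit PQ fast-scan index; bbs (SIMD block size) keeps its default
    faiss::IndexPQFastScan index(d, M, 4);

    // random database vectors, for illustration only
    std::vector<float> xb(nb * d);
    std::mt19937 rng(123);
    std::uniform_real_distribution<float> u(0.0f, 1.0f);
    for (auto& v : xb) {
        v = u(rng);
    }

    index.train(nb, xb.data());
    index.add(nb, xb.data());

    // search: n queries, k neighbors; distances and labels are n * k outputs
    faiss::idx_t k = 5;
    std::vector<float> distances(k);
    std::vector<faiss::idx_t> labels(k);
    index.search(1, xb.data(), k, distances.data(), labels.data());
    return 0;
}
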
data/vendor/faiss/faiss/IndexFlat.cpp

@@ -11,12 +11,15 @@

  #include <faiss/impl/AuxIndexStructures.h>
  #include <faiss/impl/FaissAssert.h>
+ #include <faiss/impl/ResultHandler.h>
  #include <faiss/utils/Heap.h>
  #include <faiss/utils/distances.h>
  #include <faiss/utils/extra_distances.h>
  #include <faiss/utils/prefetch.h>
  #include <faiss/utils/sorting.h>
+ #include <omp.h>
  #include <cstring>
+ #include <numeric>

  namespace faiss {

@@ -100,15 +103,24 @@ namespace {
  struct FlatL2Dis : FlatCodesDistanceComputer {
  size_t d;
  idx_t nb;
- const float* q;
  const float* b;
  size_t ndis;
+ size_t npartial_dot_products;

  float distance_to_code(const uint8_t* code) final {
  ndis++;
  return fvec_L2sqr(q, (float*)code, d);
  }

+ float partial_dot_product(
+ const idx_t i,
+ const uint32_t offset,
+ const uint32_t num_components) final override {
+ npartial_dot_products++;
+ return fvec_inner_product(
+ q + offset, b + i * d + offset, num_components);
+ }
+
  float symmetric_dis(idx_t i, idx_t j) override {
  return fvec_L2sqr(b + j * d, b + i * d, d);
  }
@@ -116,12 +128,13 @@ struct FlatL2Dis : FlatCodesDistanceComputer {
  explicit FlatL2Dis(const IndexFlat& storage, const float* q = nullptr)
  : FlatCodesDistanceComputer(
  storage.codes.data(),
- storage.code_size),
+ storage.code_size,
+ q),
  d(storage.d),
  nb(storage.ntotal),
- q(q),
  b(storage.get_xb()),
- ndis(0) {}
+ ndis(0),
+ npartial_dot_products(0) {}

  void set_query(const float* x) override {
  q = x;
@@ -159,6 +172,50 @@ struct FlatL2Dis : FlatCodesDistanceComputer {
  dis2 = dp2;
  dis3 = dp3;
  }
+
+ void partial_dot_product_batch_4(
+ const idx_t idx0,
+ const idx_t idx1,
+ const idx_t idx2,
+ const idx_t idx3,
+ float& dp0,
+ float& dp1,
+ float& dp2,
+ float& dp3,
+ const uint32_t offset,
+ const uint32_t num_components) final override {
+ npartial_dot_products += 4;
+
+ // compute first, assign next
+ const float* __restrict y0 =
+ reinterpret_cast<const float*>(codes + idx0 * code_size);
+ const float* __restrict y1 =
+ reinterpret_cast<const float*>(codes + idx1 * code_size);
+ const float* __restrict y2 =
+ reinterpret_cast<const float*>(codes + idx2 * code_size);
+ const float* __restrict y3 =
+ reinterpret_cast<const float*>(codes + idx3 * code_size);
+
+ float dp0_ = 0;
+ float dp1_ = 0;
+ float dp2_ = 0;
+ float dp3_ = 0;
+ fvec_inner_product_batch_4(
+ q + offset,
+ y0 + offset,
+ y1 + offset,
+ y2 + offset,
+ y3 + offset,
+ num_components,
+ dp0_,
+ dp1_,
+ dp2_,
+ dp3_);
+ dp0 = dp0_;
+ dp1 = dp1_;
+ dp2 = dp2_;
+ dp3 = dp3_;
+ }
  };

  struct FlatIPDis : FlatCodesDistanceComputer {
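
Note (not part of the diff): partial_dot_product and partial_dot_product_batch_4 above restrict the inner product to the component range [offset, offset + num_components), so summing them over consecutive ranges reproduces the full inner product between the query q and stored vector i. A self-contained sketch with plain arrays and hypothetical names:

#include <cstdio>
#include <vector>

// Illustrative only: segment-wise ("partial") dot products over consecutive
// [offset, offset + width) ranges add up to the full inner product.
static float partial_dot(
        const float* q,
        const float* y,
        unsigned offset,
        unsigned width) {
    float acc = 0.0f;
    for (unsigned c = 0; c < width; c++) {
        acc += q[offset + c] * y[offset + c];
    }
    return acc;
}

int main() {
    const unsigned d = 8, width = 2;
    std::vector<float> q(d), y(d);
    for (unsigned c = 0; c < d; c++) {
        q[c] = 0.5f * c;
        y[c] = 1.0f - 0.1f * c;
    }

    float full = partial_dot(q.data(), y.data(), 0, d);

    float by_segments = 0.0f;
    for (unsigned off = 0; off < d; off += width) {
        by_segments += partial_dot(q.data(), y.data(), off, width);
    }

    // both values are identical
    std::printf("full = %f, segmented = %f\n", full, by_segments);
    return 0;
}

The same identity underlies the level-by-level distance refinement in the Panorama code further down.
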
@@ -519,4 +576,317 @@ void IndexFlat1D::search(
  done:;
  }
  }
+
+ /**************************************************************
+ * shared flat Panorama search code
+ **************************************************************/
+
+ namespace {
+
+ template <bool use_radius, typename BlockHandler>
+ inline void flat_pano_search_core(
+ const IndexFlatPanorama& index,
+ BlockHandler& handler,
+ idx_t n,
+ const float* x,
+ float radius,
+ const SearchParameters* params) {
+ using SingleResultHandler = typename BlockHandler::SingleResultHandler;
+
+ IDSelector* sel = params ? params->sel : nullptr;
+ bool use_sel = sel != nullptr;
+
+ [[maybe_unused]] int nt = std::min(int(n), omp_get_max_threads());
+ size_t n_batches = (index.ntotal + index.batch_size - 1) / index.batch_size;
+
+ #pragma omp parallel num_threads(nt)
+ {
+ SingleResultHandler res(handler);
+
+ std::vector<float> query_cum_norms(index.n_levels + 1);
+ std::vector<float> exact_distances(index.batch_size);
+ std::vector<uint32_t> active_indices(index.batch_size);
+
+ #pragma omp for
+ for (int64_t i = 0; i < n; i++) {
+ const float* xi = x + i * index.d;
+ index.pano.compute_query_cum_sums(xi, query_cum_norms.data());
+
+ PanoramaStats local_stats;
+ local_stats.reset();
+
+ res.begin(i);
+
+ for (size_t batch_no = 0; batch_no < n_batches; batch_no++) {
+ size_t batch_start = batch_no * index.batch_size;
+
+ float threshold;
+ if constexpr (use_radius) {
+ threshold = radius;
+ } else {
+ threshold = res.heap_dis[0];
+ }
+
+ size_t num_active =
+ index.pano
+ .progressive_filter_batch<CMax<float, int64_t>>(
+ index.codes.data(),
+ index.cum_sums.data(),
+ xi,
+ query_cum_norms.data(),
+ batch_no,
+ index.ntotal,
+ sel,
+ nullptr,
+ use_sel,
+ active_indices,
+ exact_distances,
+ threshold,
+ local_stats);
+
+ for (size_t j = 0; j < num_active; j++) {
+ res.add_result(
+ exact_distances[active_indices[j]],
+ batch_start + active_indices[j]);
+ }
+ }
+
+ res.end();
+ indexPanorama_stats.add(local_stats);
+ }
+ }
+ }
+
+ } // anonymous namespace
+
+ /***************************************************
+ * IndexFlatPanorama
+ ***************************************************/
+
+ void IndexFlatPanorama::add(idx_t n, const float* x) {
+ size_t offset = ntotal;
+ ntotal += n;
+ size_t num_batches = (ntotal + batch_size - 1) / batch_size;
+
+ codes.resize(num_batches * batch_size * code_size);
+ cum_sums.resize(num_batches * batch_size * (n_levels + 1));
+
+ const uint8_t* code = reinterpret_cast<const uint8_t*>(x);
+ pano.copy_codes_to_level_layout(codes.data(), offset, n, code);
+ pano.compute_cumulative_sums(cum_sums.data(), offset, n, x);
+ }
+
+ void IndexFlatPanorama::search(
+ idx_t n,
+ const float* x,
+ idx_t k,
+ float* distances,
+ idx_t* labels,
+ const SearchParameters* params) const {
+ FAISS_THROW_IF_NOT(k > 0);
+ FAISS_THROW_IF_NOT(batch_size >= k);
+
+ HeapBlockResultHandler<CMax<float, int64_t>, false> handler(
+ size_t(n), distances, labels, size_t(k), nullptr);
+
+ flat_pano_search_core<false>(*this, handler, n, x, 0.0f, params);
+ }
+
+ void IndexFlatPanorama::range_search(
+ idx_t n,
+ const float* x,
+ float radius,
+ RangeSearchResult* result,
+ const SearchParameters* params) const {
+ RangeSearchBlockResultHandler<CMax<float, int64_t>, false> handler(
+ result, radius, nullptr);
+
+ flat_pano_search_core<true>(*this, handler, n, x, radius, params);
+ }
+
+ void IndexFlatPanorama::reset() {
+ IndexFlat::reset();
+ cum_sums.clear();
+ }
+
+ void IndexFlatPanorama::reconstruct(idx_t key, float* recons) const {
+ pano.reconstruct(key, recons, codes.data());
+ }
+
+ void IndexFlatPanorama::reconstruct_n(idx_t i, idx_t n, float* recons) const {
+ Index::reconstruct_n(i, n, recons);
+ }
+
+ size_t IndexFlatPanorama::remove_ids(const IDSelector& sel) {
+ idx_t j = 0;
+ for (idx_t i = 0; i < ntotal; i++) {
+ if (sel.is_member(i)) {
+ // should be removed
+ } else {
+ if (i > j) {
+ pano.copy_entry(
+ codes.data(),
+ codes.data(),
+ cum_sums.data(),
+ cum_sums.data(),
+ j,
+ i);
+ }
+ j++;
+ }
+ }
+ size_t nremove = ntotal - j;
+ if (nremove > 0) {
+ ntotal = j;
+ size_t num_batches = (ntotal + batch_size - 1) / batch_size;
+ codes.resize(num_batches * batch_size * code_size);
+ cum_sums.resize(num_batches * batch_size * (n_levels + 1));
+ }
+ return nremove;
+ }
+
+ void IndexFlatPanorama::merge_from(Index& otherIndex, idx_t add_id) {
+ FAISS_THROW_IF_NOT_MSG(add_id == 0, "cannot set ids in FlatPanorama index");
+ check_compatible_for_merge(otherIndex);
+ IndexFlatPanorama* other = static_cast<IndexFlatPanorama*>(&otherIndex);
+
+ std::vector<float> buffer(other->ntotal * code_size);
+ otherIndex.reconstruct_n(0, other->ntotal, buffer.data());
+
+ add(other->ntotal, buffer.data());
+ other->reset();
+ }
+
+ void IndexFlatPanorama::add_sa_codes(
+ idx_t /* n */,
+ const uint8_t* /* codes_in */,
+ const idx_t* /* xids */) {
+ FAISS_THROW_MSG("add_sa_codes not implemented for IndexFlatPanorama");
+ }
+
+ void IndexFlatPanorama::permute_entries(const idx_t* perm) {
+ MaybeOwnedVector<uint8_t> new_codes(codes.size());
+ std::vector<float> new_cum_sums(cum_sums.size());
+
+ for (idx_t i = 0; i < ntotal; i++) {
+ pano.copy_entry(
+ new_codes.data(),
+ codes.data(),
+ new_cum_sums.data(),
+ cum_sums.data(),
+ i,
+ perm[i]);
+ }
+
+ std::swap(codes, new_codes);
+ std::swap(cum_sums, new_cum_sums);
+ }
+
+ void IndexFlatPanorama::search_subset(
+ idx_t n,
+ const float* x,
+ idx_t k_base,
+ const idx_t* base_labels,
+ idx_t k,
+ float* distances,
+ idx_t* labels) const {
+ using SingleResultHandler =
+ HeapBlockResultHandler<CMax<float, int64_t>, false>::
+ SingleResultHandler;
+ HeapBlockResultHandler<CMax<float, int64_t>, false> handler(
+ size_t(n), distances, labels, size_t(k), nullptr);
+
+ FAISS_THROW_IF_NOT(k > 0);
+ FAISS_THROW_IF_NOT(batch_size == 1);
+
+ [[maybe_unused]] int nt = std::min(int(n), omp_get_max_threads());
+
+ #pragma omp parallel num_threads(nt)
+ {
+ SingleResultHandler res(handler);
+
+ std::vector<float> query_cum_norms(n_levels + 1);
+
+ // Panorama's optimized point-wise refinement (Algorithm 2):
+ // Batch-wise Panorama, as implemented in Panorama.h, incurs overhead
+ // from maintaining active_indices and exact_distances. This optimized
+ // implementation has minimal overhead and is thus preferred for
+ // IndexRefine's use case.
+ // 1. Initialize exact distance as ||y||^2 + ||x||^2.
+ // 2. For each level, refine distance incrementally:
+ // - Compute dot product for current level: exact_dist -= 2*<x,y>.
+ // - Use Cauchy-Schwarz bound on remaining levels to get lower bound.
+ // - If there are less than k points in the heap, add the point to
+ // the heap.
+ // - Else, prune if lower bound exceeds k-th best distance.
+ // 3. After all levels, update heap if the point survived.
+ #pragma omp for
+ for (idx_t i = 0; i < n; i++) {
+ const idx_t* __restrict idsi = base_labels + i * k_base;
+ const float* xi = x + i * d;
+
+ PanoramaStats local_stats;
+ local_stats.reset();
+
+ pano.compute_query_cum_sums(xi, query_cum_norms.data());
+ float query_cum_norm = query_cum_norms[0] * query_cum_norms[0];
+
+ res.begin(i);
+
+ for (size_t j = 0; j < k_base; j++) {
+ idx_t idx = idsi[j];
+
+ if (idx < 0) {
+ continue;
+ }
+
+ size_t cum_sum_offset = (n_levels + 1) * idx;
+ float cum_sum = cum_sums[cum_sum_offset];
+ float exact_distance = cum_sum * cum_sum + query_cum_norm;
+ cum_sum_offset++;
+
+ const float* x_ptr = xi;
+ const float* p_ptr =
+ reinterpret_cast<const float*>(codes.data()) + d * idx;
+
+ local_stats.total_dims += d;
+
+ bool pruned = false;
+ for (size_t level = 0; level < n_levels; level++) {
+ local_stats.total_dims_scanned += pano.level_width_floats;
+
+ // Refine distance
+ size_t actual_level_width = std::min(
+ pano.level_width_floats,
+ d - level * pano.level_width_floats);
+ float dot_product = fvec_inner_product(
+ x_ptr, p_ptr, actual_level_width);
+ exact_distance -= 2 * dot_product;
+
+ float cum_sum = cum_sums[cum_sum_offset];
+ float cauchy_schwarz_bound =
+ 2.0f * cum_sum * query_cum_norms[level + 1];
+ float lower_bound = exact_distance - cauchy_schwarz_bound;
+
+ // Prune using Cauchy-Schwarz bound
+ if (lower_bound > res.heap_dis[0]) {
+ pruned = true;
+ break;
+ }
+
+ cum_sum_offset++;
+ x_ptr += pano.level_width_floats;
+ p_ptr += pano.level_width_floats;
+ }
+
+ if (!pruned) {
+ res.add_result(exact_distance, idx);
+ }
+ }
+
+ res.end();
+ indexPanorama_stats.add(local_stats);
+ }
+ }
+ }
  } // namespace faiss
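
For reference, the pruning rule in IndexFlatPanorama::search_subset above can be written out explicitly (this restates what the loop computes; it is not additional material from the release). With the query x and a candidate y split into level slices x_1..x_L and y_1..y_L, and d_l the partially refined distance after l levels:

\[
d_\ell = \|x\|^2 + \|y\|^2 - 2\sum_{j=1}^{\ell} \langle x_j, y_j \rangle,
\qquad
\|x - y\|^2 = d_\ell - 2\sum_{j=\ell+1}^{L} \langle x_j, y_j \rangle
\;\ge\; d_\ell - 2\,\|x_{>\ell}\|\,\|y_{>\ell}\|.
\]

The inequality is Cauchy-Schwarz applied to the components not yet scanned; the tail norms come from query_cum_norms[level + 1] and the stored cum_sums. A candidate is pruned as soon as this lower bound exceeds the current k-th smallest distance, which is exactly the lower_bound > res.heap_dis[0] test in the loop.
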