faiss 0.5.2 → 0.6.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (169)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +16 -0
  3. data/LICENSE.txt +1 -1
  4. data/ext/faiss/ext.cpp +1 -1
  5. data/ext/faiss/extconf.rb +5 -6
  6. data/ext/faiss/index_binary.cpp +76 -17
  7. data/ext/faiss/{index.cpp → index_rb.cpp} +108 -35
  8. data/ext/faiss/kmeans.cpp +12 -9
  9. data/ext/faiss/numo.hpp +11 -9
  10. data/ext/faiss/pca_matrix.cpp +10 -8
  11. data/ext/faiss/product_quantizer.cpp +14 -12
  12. data/ext/faiss/{utils.cpp → utils_rb.cpp} +10 -3
  13. data/ext/faiss/{utils.h → utils_rb.h} +6 -0
  14. data/lib/faiss/version.rb +1 -1
  15. data/lib/faiss.rb +1 -1
  16. data/vendor/faiss/faiss/AutoTune.cpp +130 -11
  17. data/vendor/faiss/faiss/AutoTune.h +14 -1
  18. data/vendor/faiss/faiss/Clustering.cpp +59 -10
  19. data/vendor/faiss/faiss/Clustering.h +12 -0
  20. data/vendor/faiss/faiss/IVFlib.cpp +31 -28
  21. data/vendor/faiss/faiss/Index.cpp +20 -8
  22. data/vendor/faiss/faiss/Index.h +25 -3
  23. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +19 -24
  24. data/vendor/faiss/faiss/IndexBinary.cpp +1 -0
  25. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +9 -4
  26. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +45 -11
  27. data/vendor/faiss/faiss/IndexFastScan.cpp +35 -22
  28. data/vendor/faiss/faiss/IndexFastScan.h +10 -1
  29. data/vendor/faiss/faiss/IndexFlat.cpp +193 -136
  30. data/vendor/faiss/faiss/IndexFlat.h +16 -1
  31. data/vendor/faiss/faiss/IndexFlatCodes.cpp +46 -22
  32. data/vendor/faiss/faiss/IndexFlatCodes.h +7 -1
  33. data/vendor/faiss/faiss/IndexHNSW.cpp +24 -50
  34. data/vendor/faiss/faiss/IndexHNSW.h +14 -12
  35. data/vendor/faiss/faiss/IndexIDMap.cpp +1 -1
  36. data/vendor/faiss/faiss/IndexIVF.cpp +76 -49
  37. data/vendor/faiss/faiss/IndexIVF.h +14 -4
  38. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +11 -8
  39. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +2 -2
  40. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +25 -14
  41. data/vendor/faiss/faiss/IndexIVFFastScan.h +26 -22
  42. data/vendor/faiss/faiss/IndexIVFFlat.cpp +10 -61
  43. data/vendor/faiss/faiss/IndexIVFFlatPanorama.cpp +39 -111
  44. data/vendor/faiss/faiss/IndexIVFPQ.cpp +89 -147
  45. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +37 -5
  46. data/vendor/faiss/faiss/IndexIVFPQR.cpp +2 -1
  47. data/vendor/faiss/faiss/IndexIVFRaBitQ.cpp +42 -30
  48. data/vendor/faiss/faiss/IndexIVFRaBitQ.h +2 -2
  49. data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.cpp +246 -97
  50. data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.h +32 -29
  51. data/vendor/faiss/faiss/IndexLSH.cpp +8 -6
  52. data/vendor/faiss/faiss/IndexLattice.cpp +29 -24
  53. data/vendor/faiss/faiss/IndexNNDescent.cpp +1 -0
  54. data/vendor/faiss/faiss/IndexNSG.cpp +2 -1
  55. data/vendor/faiss/faiss/IndexNSG.h +0 -2
  56. data/vendor/faiss/faiss/IndexNeuralNetCodec.cpp +1 -1
  57. data/vendor/faiss/faiss/IndexPQ.cpp +19 -10
  58. data/vendor/faiss/faiss/IndexRaBitQ.cpp +26 -13
  59. data/vendor/faiss/faiss/IndexRaBitQ.h +2 -2
  60. data/vendor/faiss/faiss/IndexRaBitQFastScan.cpp +132 -78
  61. data/vendor/faiss/faiss/IndexRaBitQFastScan.h +14 -12
  62. data/vendor/faiss/faiss/IndexRefine.cpp +0 -30
  63. data/vendor/faiss/faiss/IndexShards.cpp +3 -4
  64. data/vendor/faiss/faiss/MetricType.h +16 -0
  65. data/vendor/faiss/faiss/VectorTransform.cpp +120 -0
  66. data/vendor/faiss/faiss/VectorTransform.h +23 -0
  67. data/vendor/faiss/faiss/clone_index.cpp +7 -4
  68. data/vendor/faiss/faiss/{cppcontrib/factory_tools.cpp → factory_tools.cpp} +1 -1
  69. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +1 -1
  70. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +37 -11
  71. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +0 -28
  72. data/vendor/faiss/faiss/impl/ClusteringInitialization.cpp +367 -0
  73. data/vendor/faiss/faiss/impl/ClusteringInitialization.h +107 -0
  74. data/vendor/faiss/faiss/impl/CodePacker.cpp +4 -0
  75. data/vendor/faiss/faiss/impl/CodePacker.h +11 -3
  76. data/vendor/faiss/faiss/impl/CodePackerRaBitQ.cpp +83 -0
  77. data/vendor/faiss/faiss/impl/CodePackerRaBitQ.h +47 -0
  78. data/vendor/faiss/faiss/impl/FaissAssert.h +60 -2
  79. data/vendor/faiss/faiss/impl/HNSW.cpp +25 -34
  80. data/vendor/faiss/faiss/impl/HNSW.h +8 -6
  81. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +34 -27
  82. data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -1
  83. data/vendor/faiss/faiss/impl/NSG.cpp +6 -5
  84. data/vendor/faiss/faiss/impl/NSG.h +17 -7
  85. data/vendor/faiss/faiss/impl/Panorama.cpp +53 -46
  86. data/vendor/faiss/faiss/impl/Panorama.h +22 -6
  87. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +16 -5
  88. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +70 -58
  89. data/vendor/faiss/faiss/impl/RaBitQUtils.cpp +92 -0
  90. data/vendor/faiss/faiss/impl/RaBitQUtils.h +93 -31
  91. data/vendor/faiss/faiss/impl/RaBitQuantizer.cpp +12 -28
  92. data/vendor/faiss/faiss/impl/RaBitQuantizer.h +3 -10
  93. data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.cpp +15 -41
  94. data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.h +0 -4
  95. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +14 -9
  96. data/vendor/faiss/faiss/impl/ResultHandler.h +131 -50
  97. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +67 -2358
  98. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +0 -2
  99. data/vendor/faiss/faiss/impl/VisitedTable.cpp +42 -0
  100. data/vendor/faiss/faiss/impl/VisitedTable.h +69 -0
  101. data/vendor/faiss/faiss/impl/expanded_scanners.h +158 -0
  102. data/vendor/faiss/faiss/impl/index_read.cpp +829 -471
  103. data/vendor/faiss/faiss/impl/index_read_utils.h +0 -1
  104. data/vendor/faiss/faiss/impl/index_write.cpp +17 -8
  105. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +47 -20
  106. data/vendor/faiss/faiss/impl/mapped_io.cpp +9 -2
  107. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +7 -2
  108. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +11 -3
  109. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +19 -13
  110. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +29 -21
  111. data/vendor/faiss/faiss/impl/{code_distance/code_distance-avx2.h → pq_code_distance/pq_code_distance-avx2.cpp} +42 -215
  112. data/vendor/faiss/faiss/impl/{code_distance/code_distance-avx512.h → pq_code_distance/pq_code_distance-avx512.cpp} +68 -107
  113. data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-generic.cpp +141 -0
  114. data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-inl.h +23 -0
  115. data/vendor/faiss/faiss/impl/{code_distance/code_distance-sve.h → pq_code_distance/pq_code_distance-sve.cpp} +57 -144
  116. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +9 -6
  117. data/vendor/faiss/faiss/impl/scalar_quantizer/codecs.h +121 -0
  118. data/vendor/faiss/faiss/impl/scalar_quantizer/distance_computers.h +136 -0
  119. data/vendor/faiss/faiss/impl/scalar_quantizer/quantizers.h +280 -0
  120. data/vendor/faiss/faiss/impl/scalar_quantizer/scanners.h +164 -0
  121. data/vendor/faiss/faiss/impl/scalar_quantizer/similarities.h +94 -0
  122. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx2.cpp +455 -0
  123. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx512.cpp +430 -0
  124. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-dispatch.h +329 -0
  125. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-neon.cpp +467 -0
  126. data/vendor/faiss/faiss/impl/scalar_quantizer/training.cpp +203 -0
  127. data/vendor/faiss/faiss/impl/scalar_quantizer/training.h +42 -0
  128. data/vendor/faiss/faiss/impl/simd_dispatch.h +139 -0
  129. data/vendor/faiss/faiss/impl/simd_result_handlers.h +18 -18
  130. data/vendor/faiss/faiss/index_factory.cpp +35 -16
  131. data/vendor/faiss/faiss/index_io.h +29 -3
  132. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +7 -4
  133. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +1 -1
  134. data/vendor/faiss/faiss/svs/IndexSVSFaissUtils.h +9 -19
  135. data/vendor/faiss/faiss/svs/IndexSVSFlat.h +2 -0
  136. data/vendor/faiss/faiss/svs/IndexSVSVamana.h +2 -1
  137. data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.cpp +9 -1
  138. data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.h +9 -0
  139. data/vendor/faiss/faiss/utils/Heap.cpp +46 -0
  140. data/vendor/faiss/faiss/utils/Heap.h +21 -0
  141. data/vendor/faiss/faiss/utils/NeuralNet.cpp +10 -7
  142. data/vendor/faiss/faiss/utils/distances.cpp +141 -23
  143. data/vendor/faiss/faiss/utils/distances.h +98 -0
  144. data/vendor/faiss/faiss/utils/distances_dispatch.h +170 -0
  145. data/vendor/faiss/faiss/utils/distances_simd.cpp +74 -3511
  146. data/vendor/faiss/faiss/utils/extra_distances-inl.h +164 -157
  147. data/vendor/faiss/faiss/utils/extra_distances.cpp +52 -95
  148. data/vendor/faiss/faiss/utils/extra_distances.h +47 -1
  149. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +0 -1
  150. data/vendor/faiss/faiss/utils/partitioning.cpp +1 -1
  151. data/vendor/faiss/faiss/utils/pq_code_distance.h +251 -0
  152. data/vendor/faiss/faiss/utils/rabitq_simd.h +260 -0
  153. data/vendor/faiss/faiss/utils/simd_impl/distances_aarch64.cpp +150 -0
  154. data/vendor/faiss/faiss/utils/simd_impl/distances_arm_sve.cpp +568 -0
  155. data/vendor/faiss/faiss/utils/simd_impl/distances_autovec-inl.h +153 -0
  156. data/vendor/faiss/faiss/utils/simd_impl/distances_avx2.cpp +1185 -0
  157. data/vendor/faiss/faiss/utils/simd_impl/distances_avx512.cpp +1092 -0
  158. data/vendor/faiss/faiss/utils/simd_impl/distances_sse-inl.h +391 -0
  159. data/vendor/faiss/faiss/utils/simd_levels.cpp +322 -0
  160. data/vendor/faiss/faiss/utils/simd_levels.h +91 -0
  161. data/vendor/faiss/faiss/utils/simdlib_avx2.h +12 -1
  162. data/vendor/faiss/faiss/utils/simdlib_avx512.h +69 -0
  163. data/vendor/faiss/faiss/utils/simdlib_neon.h +6 -0
  164. data/vendor/faiss/faiss/utils/sorting.cpp +4 -4
  165. data/vendor/faiss/faiss/utils/utils.cpp +16 -9
  166. metadata +47 -18
  167. data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +0 -81
  168. data/vendor/faiss/faiss/impl/code_distance/code_distance.h +0 -186
  169. data/vendor/faiss/faiss/{cppcontrib/factory_tools.h → factory_tools.h} +0 -0
@@ -12,6 +12,7 @@
12
12
  #include <faiss/impl/AuxIndexStructures.h>
13
13
  #include <faiss/impl/FaissAssert.h>
14
14
  #include <faiss/impl/ResultHandler.h>
15
+ #include <faiss/impl/simd_dispatch.h>
15
16
  #include <faiss/utils/Heap.h>
16
17
  #include <faiss/utils/distances.h>
17
18
  #include <faiss/utils/extra_distances.h>
@@ -19,7 +20,6 @@
19
20
  #include <faiss/utils/sorting.h>
20
21
  #include <omp.h>
21
22
  #include <cstring>
22
- #include <numeric>
23
23
 
24
24
  namespace faiss {
25
25
 
@@ -44,7 +44,6 @@ void IndexFlat::search(
44
44
  float_maxheap_array_t res = {size_t(n), size_t(k), labels, distances};
45
45
  knn_L2sqr(x, get_xb(), d, n, ntotal, &res, nullptr, sel);
46
46
  } else {
47
- FAISS_THROW_IF_NOT(!sel); // TODO implement with selector
48
47
  knn_extra_metrics(
49
48
  x,
50
49
  get_xb(),
@@ -55,7 +54,8 @@ void IndexFlat::search(
55
54
  metric_arg,
56
55
  k,
57
56
  distances,
58
- labels);
57
+ labels,
58
+ sel);
59
59
  }
60
60
  }
61
61
 
@@ -100,6 +100,7 @@ void IndexFlat::compute_distance_subset(
100
100
 
101
101
  namespace {
102
102
 
103
+ template <SIMDLevel SL>
103
104
  struct FlatL2Dis : FlatCodesDistanceComputer {
104
105
  size_t d;
105
106
  idx_t nb;
@@ -109,7 +110,7 @@ struct FlatL2Dis : FlatCodesDistanceComputer {
109
110
 
110
111
  float distance_to_code(const uint8_t* code) final {
111
112
  ndis++;
112
- return fvec_L2sqr(q, (float*)code, d);
113
+ return fvec_L2sqr<SL>(q, (float*)code, d);
113
114
  }
114
115
 
115
116
  float partial_dot_product(
@@ -117,12 +118,12 @@ struct FlatL2Dis : FlatCodesDistanceComputer {
117
118
  const uint32_t offset,
118
119
  const uint32_t num_components) final override {
119
120
  npartial_dot_products++;
120
- return fvec_inner_product(
121
+ return fvec_inner_product<SL>(
121
122
  q + offset, b + i * d + offset, num_components);
122
123
  }
123
124
 
124
125
  float symmetric_dis(idx_t i, idx_t j) override {
125
- return fvec_L2sqr(b + j * d, b + i * d, d);
126
+ return fvec_L2sqr<SL>(b + j * d, b + i * d, d);
126
127
  }
127
128
 
128
129
  explicit FlatL2Dis(const IndexFlat& storage, const float* q = nullptr)
@@ -166,7 +167,7 @@ struct FlatL2Dis : FlatCodesDistanceComputer {
166
167
  float dp1 = 0;
167
168
  float dp2 = 0;
168
169
  float dp3 = 0;
169
- fvec_L2sqr_batch_4(q, y0, y1, y2, y3, d, dp0, dp1, dp2, dp3);
170
+ fvec_L2sqr_batch_4<SL>(q, y0, y1, y2, y3, d, dp0, dp1, dp2, dp3);
170
171
  dis0 = dp0;
171
172
  dis1 = dp1;
172
173
  dis2 = dp2;
@@ -200,7 +201,7 @@ struct FlatL2Dis : FlatCodesDistanceComputer {
200
201
  float dp1_ = 0;
201
202
  float dp2_ = 0;
202
203
  float dp3_ = 0;
203
- fvec_inner_product_batch_4(
204
+ fvec_inner_product_batch_4<SL>(
204
205
  q + offset,
205
206
  y0 + offset,
206
207
  y1 + offset,
@@ -218,6 +219,7 @@ struct FlatL2Dis : FlatCodesDistanceComputer {
218
219
  }
219
220
  };
220
221
 
222
+ template <SIMDLevel SL>
221
223
  struct FlatIPDis : FlatCodesDistanceComputer {
222
224
  size_t d;
223
225
  idx_t nb;
@@ -226,12 +228,12 @@ struct FlatIPDis : FlatCodesDistanceComputer {
226
228
  size_t ndis;
227
229
 
228
230
  float symmetric_dis(idx_t i, idx_t j) final override {
229
- return fvec_inner_product(b + j * d, b + i * d, d);
231
+ return fvec_inner_product<SL>(b + j * d, b + i * d, d);
230
232
  }
231
233
 
232
234
  float distance_to_code(const uint8_t* code) final override {
233
235
  ndis++;
234
- return fvec_inner_product(q, (const float*)code, d);
236
+ return fvec_inner_product<SL>(q, (const float*)code, d);
235
237
  }
236
238
 
237
239
  explicit FlatIPDis(const IndexFlat& storage, const float* q = nullptr)
@@ -274,7 +276,8 @@ struct FlatIPDis : FlatCodesDistanceComputer {
274
276
  float dp1 = 0;
275
277
  float dp2 = 0;
276
278
  float dp3 = 0;
277
- fvec_inner_product_batch_4(q, y0, y1, y2, y3, d, dp0, dp1, dp2, dp3);
279
+ fvec_inner_product_batch_4<SL>(
280
+ q, y0, y1, y2, y3, d, dp0, dp1, dp2, dp3);
278
281
  dis0 = dp0;
279
282
  dis1 = dp1;
280
283
  dis2 = dp2;
@@ -285,14 +288,16 @@ struct FlatIPDis : FlatCodesDistanceComputer {
285
288
  } // namespace
286
289
 
287
290
  FlatCodesDistanceComputer* IndexFlat::get_FlatCodesDistanceComputer() const {
291
+ FlatCodesDistanceComputer* dc = nullptr;
288
292
  if (metric_type == METRIC_L2) {
289
- return new FlatL2Dis(*this);
293
+ with_simd_level([&]<SIMDLevel SL>() { dc = new FlatL2Dis<SL>(*this); });
290
294
  } else if (metric_type == METRIC_INNER_PRODUCT) {
291
- return new FlatIPDis(*this);
295
+ with_simd_level([&]<SIMDLevel SL>() { dc = new FlatIPDis<SL>(*this); });
292
296
  } else {
293
- return get_extra_distance_computer(
297
+ dc = get_extra_distance_computer(
294
298
  d, metric_type, metric_arg, ntotal, get_xb());
295
299
  }
300
+ return dc;
296
301
  }
297
302
 
298
303
  void IndexFlat::reconstruct(idx_t key, float* recons) const {
@@ -317,6 +322,7 @@ void IndexFlat::sa_decode(idx_t n, const uint8_t* bytes, float* x) const {
317
322
  ***************************************************/
318
323
 
319
324
  namespace {
325
+ template <SIMDLevel SL>
320
326
  struct FlatL2WithNormsDis : FlatCodesDistanceComputer {
321
327
  size_t d;
322
328
  idx_t nb;
@@ -329,7 +335,7 @@ struct FlatL2WithNormsDis : FlatCodesDistanceComputer {
329
335
 
330
336
  float distance_to_code(const uint8_t* code) final override {
331
337
  ndis++;
332
- return fvec_L2sqr(q, (float*)code, d);
338
+ return fvec_L2sqr<SL>(q, (float*)code, d);
333
339
  }
334
340
 
335
341
  float operator()(const idx_t i) final override {
@@ -337,7 +343,7 @@ struct FlatL2WithNormsDis : FlatCodesDistanceComputer {
337
343
  reinterpret_cast<const float*>(codes + i * code_size);
338
344
 
339
345
  prefetch_L2(l2norms + i);
340
- const float dp0 = fvec_inner_product(q, y, d);
346
+ const float dp0 = fvec_inner_product<SL>(q, y, d);
341
347
  return query_l2norm + l2norms[i] - 2 * dp0;
342
348
  }
343
349
 
@@ -349,7 +355,7 @@ struct FlatL2WithNormsDis : FlatCodesDistanceComputer {
349
355
 
350
356
  prefetch_L2(l2norms + i);
351
357
  prefetch_L2(l2norms + j);
352
- const float dp0 = fvec_inner_product(yi, yj, d);
358
+ const float dp0 = fvec_inner_product<SL>(yi, yj, d);
353
359
  return l2norms[i] + l2norms[j] - 2 * dp0;
354
360
  }
355
361
 
@@ -369,7 +375,7 @@ struct FlatL2WithNormsDis : FlatCodesDistanceComputer {
369
375
 
370
376
  void set_query(const float* x) override {
371
377
  q = x;
372
- query_l2norm = fvec_norm_L2sqr(q, d);
378
+ query_l2norm = fvec_norm_L2sqr<SL>(q, d);
373
379
  }
374
380
 
375
381
  // compute four distances
@@ -403,7 +409,8 @@ struct FlatL2WithNormsDis : FlatCodesDistanceComputer {
403
409
  float dp1 = 0;
404
410
  float dp2 = 0;
405
411
  float dp3 = 0;
406
- fvec_inner_product_batch_4(q, y0, y1, y2, y3, d, dp0, dp1, dp2, dp3);
412
+ fvec_inner_product_batch_4<SL>(
413
+ q, y0, y1, y2, y3, d, dp0, dp1, dp2, dp3);
407
414
  dis0 = query_l2norm + l2norms[idx0] - 2 * dp0;
408
415
  dis1 = query_l2norm + l2norms[idx1] - 2 * dp1;
409
416
  dis2 = query_l2norm + l2norms[idx2] - 2 * dp2;
@@ -430,7 +437,11 @@ void IndexFlatL2::clear_l2norms() {
430
437
  FlatCodesDistanceComputer* IndexFlatL2::get_FlatCodesDistanceComputer() const {
431
438
  if (metric_type == METRIC_L2) {
432
439
  if (!cached_l2norms.empty()) {
433
- return new FlatL2WithNormsDis(*this);
440
+ FlatCodesDistanceComputer* dc = nullptr;
441
+ with_simd_level([&]<SIMDLevel SL>() {
442
+ dc = new FlatL2WithNormsDis<SL>(*this);
443
+ });
444
+ return dc;
434
445
  }
435
446
  }
436
447
 
@@ -583,7 +594,17 @@ void IndexFlat1D::search(
583
594
 
584
595
  namespace {
585
596
 
586
- template <bool use_radius, typename BlockHandler>
597
+ template <typename Fn>
598
+ inline auto dispatch_metric_compare(MetricType metric, Fn&& fn) {
599
+ if (is_similarity_metric(metric)) {
600
+ using C = CMin<float, int64_t>;
601
+ return fn.template operator()<C>();
602
+ }
603
+ using C = CMax<float, int64_t>;
604
+ return fn.template operator()<C>();
605
+ }
606
+
607
+ template <bool use_radius, typename C, typename BlockHandler>
587
608
  inline void flat_pano_search_core(
588
609
  const IndexFlatPanorama& index,
589
610
  BlockHandler& handler,
@@ -627,22 +648,23 @@ inline void flat_pano_search_core(
627
648
  threshold = res.heap_dis[0];
628
649
  }
629
650
 
630
- size_t num_active =
631
- index.pano
632
- .progressive_filter_batch<CMax<float, int64_t>>(
633
- index.codes.data(),
634
- index.cum_sums.data(),
635
- xi,
636
- query_cum_norms.data(),
637
- batch_no,
638
- index.ntotal,
639
- sel,
640
- nullptr,
641
- use_sel,
642
- active_indices,
643
- exact_distances,
644
- threshold,
645
- local_stats);
651
+ size_t num_active = with_metric_type(
652
+ index.metric_type, [&]<MetricType M>() {
653
+ return index.pano.progressive_filter_batch<C, M>(
654
+ index.codes.data(),
655
+ index.cum_sums.data(),
656
+ xi,
657
+ query_cum_norms.data(),
658
+ batch_no,
659
+ index.ntotal,
660
+ sel,
661
+ nullptr,
662
+ use_sel,
663
+ active_indices,
664
+ exact_distances,
665
+ threshold,
666
+ local_stats);
667
+ });
646
668
 
647
669
  for (size_t j = 0; j < num_active; j++) {
648
670
  res.add_result(
@@ -686,10 +708,11 @@ void IndexFlatPanorama::search(
686
708
  FAISS_THROW_IF_NOT(k > 0);
687
709
  FAISS_THROW_IF_NOT(batch_size >= k);
688
710
 
689
- HeapBlockResultHandler<CMax<float, int64_t>, false> handler(
690
- size_t(n), distances, labels, size_t(k), nullptr);
691
-
692
- flat_pano_search_core<false>(*this, handler, n, x, 0.0f, params);
711
+ dispatch_metric_compare(metric_type, [&]<typename C>() {
712
+ HeapBlockResultHandler<C, false> handler(
713
+ size_t(n), distances, labels, size_t(k), nullptr);
714
+ flat_pano_search_core<false, C>(*this, handler, n, x, 0.0f, params);
715
+ });
693
716
  }
694
717
 
695
718
  void IndexFlatPanorama::range_search(
@@ -698,10 +721,11 @@ void IndexFlatPanorama::range_search(
698
721
  float radius,
699
722
  RangeSearchResult* result,
700
723
  const SearchParameters* params) const {
701
- RangeSearchBlockResultHandler<CMax<float, int64_t>, false> handler(
702
- result, radius, nullptr);
703
-
704
- flat_pano_search_core<true>(*this, handler, n, x, radius, params);
724
+ dispatch_metric_compare(metric_type, [&]<typename C>() {
725
+ RangeSearchBlockResultHandler<C, false> handler(
726
+ result, radius, nullptr);
727
+ flat_pano_search_core<true, C>(*this, handler, n, x, radius, params);
728
+ });
705
729
  }
706
730
 
707
731
  void IndexFlatPanorama::reset() {
@@ -790,103 +814,136 @@ void IndexFlatPanorama::search_subset(
790
814
  idx_t k,
791
815
  float* distances,
792
816
  idx_t* labels) const {
793
- using SingleResultHandler =
794
- HeapBlockResultHandler<CMax<float, int64_t>, false>::
795
- SingleResultHandler;
796
- HeapBlockResultHandler<CMax<float, int64_t>, false> handler(
797
- size_t(n), distances, labels, size_t(k), nullptr);
798
-
799
- FAISS_THROW_IF_NOT(k > 0);
800
- FAISS_THROW_IF_NOT(batch_size == 1);
801
-
802
- [[maybe_unused]] int nt = std::min(int(n), omp_get_max_threads());
817
+ with_simd_level([&]<SIMDLevel SL>() {
818
+ with_metric_type(metric_type, [&]<MetricType M>() {
819
+ constexpr bool is_sim = is_similarity_metric(M);
820
+ using C = std::conditional_t<
821
+ is_sim,
822
+ CMin<float, int64_t>,
823
+ CMax<float, int64_t>>;
824
+ using SingleResultHandler =
825
+ typename HeapBlockResultHandler<C, false>::
826
+ SingleResultHandler;
827
+ HeapBlockResultHandler<C, false> handler(
828
+ size_t(n), distances, labels, size_t(k), nullptr);
829
+
830
+ FAISS_THROW_IF_NOT(k > 0);
831
+ FAISS_THROW_IF_NOT(batch_size == 1);
832
+
833
+ [[maybe_unused]] int nt = std::min(int(n), omp_get_max_threads());
803
834
 
804
835
  #pragma omp parallel num_threads(nt)
805
- {
806
- SingleResultHandler res(handler);
807
-
808
- std::vector<float> query_cum_norms(n_levels + 1);
809
-
810
- // Panorama's optimized point-wise refinement (Algorithm 2):
811
- // Batch-wise Panorama, as implemented in Panorama.h, incurs overhead
812
- // from maintaining active_indices and exact_distances. This optimized
813
- // implementation has minimal overhead and is thus preferred for
814
- // IndexRefine's use case.
815
- // 1. Initialize exact distance as ||y||^2 + ||x||^2.
816
- // 2. For each level, refine distance incrementally:
817
- // - Compute dot product for current level: exact_dist -= 2*<x,y>.
818
- // - Use Cauchy-Schwarz bound on remaining levels to get lower bound.
819
- // - If there are less than k points in the heap, add the point to
820
- // the heap.
821
- // - Else, prune if lower bound exceeds k-th best distance.
822
- // 3. After all levels, update heap if the point survived.
836
+ {
837
+ SingleResultHandler res(handler);
838
+
839
+ std::vector<float> query_cum_norms(n_levels + 1);
840
+
841
+ // Panorama's optimized point-wise refinement (Algorithm 2):
842
+ // Batch-wise Panorama, as implemented in Panorama.h, incurs
843
+ // overhead from maintaining active_indices and exact_distances.
844
+ // This optimized implementation has minimal overhead and is
845
+ // thus preferred for IndexRefine's use case.
846
+ // 1. Initialize exact distance as ||y||^2 + ||x||^2.
847
+ // 2. For each level, refine distance incrementally:
848
+ // - Compute dot product for current level: exact_dist -=
849
+ // 2*<x,y>.
850
+ // - Use Cauchy-Schwarz bound on remaining levels to get
851
+ // lower bound.
852
+ // - If there are less than k points in the heap, add the
853
+ // point to the heap.
854
+ // - Else, prune if lower bound exceeds k-th best distance.
855
+ // 3. After all levels, update heap if the point survived.
823
856
  #pragma omp for
824
- for (idx_t i = 0; i < n; i++) {
825
- const idx_t* __restrict idsi = base_labels + i * k_base;
826
- const float* xi = x + i * d;
827
-
828
- PanoramaStats local_stats;
829
- local_stats.reset();
830
-
831
- pano.compute_query_cum_sums(xi, query_cum_norms.data());
832
- float query_cum_norm = query_cum_norms[0] * query_cum_norms[0];
833
-
834
- res.begin(i);
835
-
836
- for (size_t j = 0; j < k_base; j++) {
837
- idx_t idx = idsi[j];
838
-
839
- if (idx < 0) {
840
- continue;
841
- }
842
-
843
- size_t cum_sum_offset = (n_levels + 1) * idx;
844
- float cum_sum = cum_sums[cum_sum_offset];
845
- float exact_distance = cum_sum * cum_sum + query_cum_norm;
846
- cum_sum_offset++;
847
-
848
- const float* x_ptr = xi;
849
- const float* p_ptr =
850
- reinterpret_cast<const float*>(codes.data()) + d * idx;
851
-
852
- local_stats.total_dims += d;
853
-
854
- bool pruned = false;
855
- for (size_t level = 0; level < n_levels; level++) {
856
- local_stats.total_dims_scanned += pano.level_width_floats;
857
-
858
- // Refine distance
859
- size_t actual_level_width = std::min(
860
- pano.level_width_floats,
861
- d - level * pano.level_width_floats);
862
- float dot_product = fvec_inner_product(
863
- x_ptr, p_ptr, actual_level_width);
864
- exact_distance -= 2 * dot_product;
865
-
866
- float cum_sum = cum_sums[cum_sum_offset];
867
- float cauchy_schwarz_bound =
868
- 2.0f * cum_sum * query_cum_norms[level + 1];
869
- float lower_bound = exact_distance - cauchy_schwarz_bound;
870
-
871
- // Prune using Cauchy-Schwarz bound
872
- if (lower_bound > res.heap_dis[0]) {
873
- pruned = true;
874
- break;
857
+ for (idx_t i = 0; i < n; i++) {
858
+ const idx_t* __restrict idsi = base_labels + i * k_base;
859
+ const float* xi = x + i * d;
860
+
861
+ PanoramaStats local_stats;
862
+ local_stats.reset();
863
+
864
+ pano.compute_query_cum_sums(xi, query_cum_norms.data());
865
+ float query_cum_norm =
866
+ query_cum_norms[0] * query_cum_norms[0];
867
+
868
+ res.begin(i);
869
+
870
+ for (size_t j = 0; j < k_base; j++) {
871
+ idx_t idx = idsi[j];
872
+
873
+ if (idx < 0) {
874
+ continue;
875
+ }
876
+
877
+ size_t cum_sum_offset = (n_levels + 1) * idx;
878
+ float cum_sum = cum_sums[cum_sum_offset];
879
+ float exact_distance = 0.0f;
880
+ if constexpr (!is_sim) {
881
+ exact_distance = cum_sum * cum_sum + query_cum_norm;
882
+ }
883
+ cum_sum_offset++;
884
+
885
+ const float* x_ptr = xi;
886
+ const float* p_ptr =
887
+ reinterpret_cast<const float*>(codes.data()) +
888
+ d * idx;
889
+
890
+ local_stats.total_dims += d;
891
+
892
+ bool pruned = false;
893
+ for (size_t level = 0; level < n_levels; level++) {
894
+ local_stats.total_dims_scanned +=
895
+ pano.level_width_floats;
896
+
897
+ // Refine distance
898
+ size_t actual_level_width = std::min(
899
+ pano.level_width_floats,
900
+ d - level * pano.level_width_floats);
901
+ float dot_product = fvec_inner_product<SL>(
902
+ x_ptr, p_ptr, actual_level_width);
903
+ if constexpr (is_sim) {
904
+ exact_distance += dot_product;
905
+ } else {
906
+ exact_distance -= 2 * dot_product;
907
+ }
908
+
909
+ float level_cum_sum = cum_sums[cum_sum_offset];
910
+ float cauchy_schwarz_bound;
911
+ if constexpr (is_sim) {
912
+ cauchy_schwarz_bound = -level_cum_sum *
913
+ query_cum_norms[level + 1];
914
+ } else {
915
+ cauchy_schwarz_bound = 2.0f * level_cum_sum *
916
+ query_cum_norms[level + 1];
917
+ }
918
+ float bound = exact_distance - cauchy_schwarz_bound;
919
+
920
+ // Prune using Cauchy-Schwarz bound
921
+ bool should_prune = false;
922
+ if constexpr (is_sim) {
923
+ should_prune = bound < res.heap_dis[0];
924
+ } else {
925
+ should_prune = bound > res.heap_dis[0];
926
+ }
927
+ if (should_prune) {
928
+ pruned = true;
929
+ break;
930
+ }
931
+
932
+ cum_sum_offset++;
933
+ x_ptr += pano.level_width_floats;
934
+ p_ptr += pano.level_width_floats;
935
+ }
936
+
937
+ if (!pruned) {
938
+ res.add_result(exact_distance, idx);
939
+ }
875
940
  }
876
941
 
877
- cum_sum_offset++;
878
- x_ptr += pano.level_width_floats;
879
- p_ptr += pano.level_width_floats;
880
- }
881
-
882
- if (!pruned) {
883
- res.add_result(exact_distance, idx);
942
+ res.end();
943
+ indexPanorama_stats.add(local_stats);
884
944
  }
885
945
  }
886
-
887
- res.end();
888
- indexPanorama_stats.add(local_stats);
889
- }
890
- }
946
+ });
947
+ });
891
948
  }
892
949
  } // namespace faiss
@@ -121,7 +121,8 @@ struct IndexFlatPanorama : IndexFlat {
121
121
  batch_size(batch_size),
122
122
  n_levels(n_levels),
123
123
  pano(code_size, n_levels, batch_size) {
124
- FAISS_THROW_IF_NOT(metric == METRIC_L2);
124
+ FAISS_THROW_IF_NOT(
125
+ metric == METRIC_L2 || metric == METRIC_INNER_PRODUCT);
125
126
  }
126
127
 
127
128
  void add(idx_t n, const float* x) override;
@@ -179,6 +180,20 @@ struct IndexFlatL2Panorama : IndexFlatPanorama {
179
180
  : IndexFlatPanorama(d, METRIC_L2, n_levels, batch_size) {}
180
181
  };
181
182
 
183
+ struct IndexFlatIPPanorama : IndexFlatPanorama {
184
+ /**
185
+ * @param d dimensionality of the input vectors
186
+ * @param n_levels number of Panorama levels
187
+ * @param batch_size batch size for Panorama storage
188
+ */
189
+ explicit IndexFlatIPPanorama(
190
+ idx_t d,
191
+ size_t n_levels,
192
+ size_t batch_size = 512)
193
+ : IndexFlatPanorama(d, METRIC_INNER_PRODUCT, n_levels, batch_size) {
194
+ }
195
+ };
196
+
182
197
  /// optimized version for 1D "vectors".
183
198
  struct IndexFlat1D : IndexFlatL2 {
184
199
  bool continuous_update = true; ///< is the permutation updated continuously?
@@ -182,15 +182,6 @@ struct GenericFlatCodesDistanceComputer : FlatCodesDistanceComputer {
182
182
  }
183
183
  };
184
184
 
185
- struct Run_get_distance_computer {
186
- using T = FlatCodesDistanceComputer*;
187
-
188
- template <class VD>
189
- FlatCodesDistanceComputer* f(const VD& vd, const IndexFlatCodes* codec) {
190
- return new GenericFlatCodesDistanceComputer<VD>(codec, vd);
191
- }
192
- };
193
-
194
185
  template <class BlockResultHandler>
195
186
  struct Run_search_with_decompress {
196
187
  using T = void;
@@ -231,17 +222,15 @@ struct Run_search_with_decompress {
231
222
  struct Run_search_with_decompress_res {
232
223
  using T = void;
233
224
 
234
- template <class ResultHandler>
235
- void f(ResultHandler& res, const IndexFlatCodes* index, const float* xq) {
236
- Run_search_with_decompress<ResultHandler> r;
237
- dispatch_VectorDistance(
238
- index->d,
239
- index->metric_type,
240
- index->metric_arg,
241
- r,
242
- index,
243
- xq,
244
- res);
225
+ template <class BlockResultHandler>
226
+ void f(BlockResultHandler& res,
227
+ const IndexFlatCodes* index,
228
+ const float* xq) {
229
+ with_VectorDistance(
230
+ index->d, index->metric_type, index->metric_arg, [&](auto vd) {
231
+ Run_search_with_decompress<BlockResultHandler> r;
232
+ r.template f<decltype(vd)>(vd, index, xq, res);
233
+ });
245
234
  }
246
235
  };
247
236
 
@@ -249,8 +238,14 @@ struct Run_search_with_decompress_res {
249
238
 
250
239
  FlatCodesDistanceComputer* IndexFlatCodes::get_FlatCodesDistanceComputer()
251
240
  const {
252
- Run_get_distance_computer r;
253
- return dispatch_VectorDistance(d, metric_type, metric_arg, r, this);
241
+ return with_VectorDistance(
242
+ d,
243
+ metric_type,
244
+ metric_arg,
245
+ [&](auto vd) -> FlatCodesDistanceComputer* {
246
+ return new GenericFlatCodesDistanceComputer<decltype(vd)>(
247
+ this, vd);
248
+ });
254
249
  }
255
250
 
256
251
  void IndexFlatCodes::search(
@@ -277,4 +272,33 @@ void IndexFlatCodes::range_search(
277
272
  dispatch_range_ResultHandler(result, radius, metric_type, sel, r, this, x);
278
273
  }
279
274
 
275
+ void IndexFlatCodes::search1(
276
+ const float* x,
277
+ ResultHandler& handler,
278
+ SearchParameters* params) const {
279
+ const IDSelector* sel = params ? params->sel : nullptr;
280
+ Run_search_with_decompress_res r;
281
+ if (sel) {
282
+ if (is_similarity_metric(metric_type)) {
283
+ SingleQueryBlockResultHandler<CMin<float, idx_t>, true> res(
284
+ handler, sel);
285
+ r.f(res, this, x);
286
+ } else {
287
+ SingleQueryBlockResultHandler<CMax<float, idx_t>, true> res(
288
+ handler, sel);
289
+ r.f(res, this, x);
290
+ }
291
+ } else {
292
+ if (is_similarity_metric(metric_type)) {
293
+ SingleQueryBlockResultHandler<CMin<float, idx_t>, false> res(
294
+ handler);
295
+ r.f(res, this, x);
296
+ } else {
297
+ SingleQueryBlockResultHandler<CMax<float, idx_t>, false> res(
298
+ handler);
299
+ r.f(res, this, x);
300
+ }
301
+ }
302
+ }
303
+
280
304
  } // namespace faiss
@@ -55,7 +55,8 @@ struct IndexFlatCodes : Index {
55
55
  return get_FlatCodesDistanceComputer();
56
56
  }
57
57
 
58
- /** Search implemented by decoding */
58
+ /** Search implemented by decoding (most index types will have a faster
59
+ * implementation) */
59
60
  void search(
60
61
  idx_t n,
61
62
  const float* x,
@@ -71,6 +72,11 @@ struct IndexFlatCodes : Index {
71
72
  RangeSearchResult* result,
72
73
  const SearchParameters* params = nullptr) const override;
73
74
 
75
+ virtual void search1(
76
+ const float* x,
77
+ ResultHandler& handler,
78
+ SearchParameters* params = nullptr) const override;
79
+
74
80
  // returns a new instance of a CodePacker
75
81
  CodePacker* get_CodePacker() const;
76
82