faiss 0.6.1 → 0.6.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (93)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +4 -0
  3. data/lib/faiss/version.rb +1 -1
  4. data/vendor/faiss/faiss/Index.h +1 -1
  5. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +6 -7
  6. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +3 -3
  7. data/vendor/faiss/faiss/IndexHNSW.cpp +173 -143
  8. data/vendor/faiss/faiss/IndexIVF.cpp +2 -2
  9. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +2 -2
  10. data/vendor/faiss/faiss/IndexIVFFlat.cpp +3 -1
  11. data/vendor/faiss/faiss/IndexIVFFlatPanorama.cpp +3 -3
  12. data/vendor/faiss/faiss/IndexIVFPQ.cpp +2 -3
  13. data/vendor/faiss/faiss/IndexIVFPQR.cpp +2 -3
  14. data/vendor/faiss/faiss/IndexIVFRaBitQ.cpp +4 -13
  15. data/vendor/faiss/faiss/IndexNNDescent.cpp +1 -1
  16. data/vendor/faiss/faiss/IndexNSG.cpp +1 -2
  17. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +68 -6
  18. data/vendor/faiss/faiss/IndexScalarQuantizer.h +10 -0
  19. data/vendor/faiss/faiss/cppcontrib/SaDecodeKernels.h +1 -1
  20. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-neon-inl.h +902 -12
  21. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-neon-inl.h +702 -10
  22. data/vendor/faiss/faiss/factory_tools.cpp +4 -0
  23. data/vendor/faiss/faiss/gpu/GpuResources.h +3 -2
  24. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +11 -12
  25. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +3 -3
  26. data/vendor/faiss/faiss/gpu_metal/MetalDistance.h +87 -0
  27. data/vendor/faiss/faiss/gpu_metal/MetalIndex.h +7 -0
  28. data/vendor/faiss/faiss/gpu_metal/MetalIndexIVFFlat.h +181 -0
  29. data/vendor/faiss/faiss/gpu_metal/MetalKernels.h +48 -3
  30. data/vendor/faiss/faiss/gpu_metal/MetalPythonBridge.h +45 -0
  31. data/vendor/faiss/faiss/gpu_metal/impl/MetalIVFFlat.h +193 -0
  32. data/vendor/faiss/faiss/impl/HNSW.cpp +556 -199
  33. data/vendor/faiss/faiss/impl/HNSW.h +51 -13
  34. data/vendor/faiss/faiss/impl/NSG.cpp +15 -11
  35. data/vendor/faiss/faiss/impl/Panorama.h +11 -0
  36. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +25 -2
  37. data/vendor/faiss/faiss/impl/RaBitQUtils.cpp +1 -1
  38. data/vendor/faiss/faiss/impl/RaBitQuantizer.cpp +7 -1
  39. data/vendor/faiss/faiss/impl/ResultHandler.h +1 -0
  40. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +271 -8
  41. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +50 -0
  42. data/vendor/faiss/faiss/impl/VisitedTable.cpp +10 -10
  43. data/vendor/faiss/faiss/impl/VisitedTable.h +69 -34
  44. data/vendor/faiss/faiss/impl/fast_scan/dispatching.h +3 -1
  45. data/vendor/faiss/faiss/impl/hnsw/MinimaxHeap.cpp +35 -43
  46. data/vendor/faiss/faiss/impl/hnsw/MinimaxHeap.h +64 -15
  47. data/vendor/faiss/faiss/impl/hnsw/avx2.cpp +86 -40
  48. data/vendor/faiss/faiss/impl/hnsw/avx512.cpp +81 -50
  49. data/vendor/faiss/faiss/impl/index_read.cpp +100 -39
  50. data/vendor/faiss/faiss/impl/index_write.cpp +1 -0
  51. data/vendor/faiss/faiss/impl/io_macros.h +25 -0
  52. data/vendor/faiss/faiss/impl/platform_macros.h +12 -8
  53. data/vendor/faiss/faiss/impl/pq_code_distance/avx2.cpp +2 -0
  54. data/vendor/faiss/faiss/impl/pq_code_distance/avx512.cpp +2 -0
  55. data/vendor/faiss/faiss/impl/pq_code_distance/neon.cpp +2 -0
  56. data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-generic.cpp +20 -0
  57. data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-inl.h +36 -0
  58. data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-sve.cpp +5 -0
  59. data/vendor/faiss/faiss/impl/pq_code_distance/pq_scan_impl.h +105 -0
  60. data/vendor/faiss/faiss/impl/pq_code_distance/rvv.cpp +2 -0
  61. data/vendor/faiss/faiss/impl/scalar_quantizer/distance_computers.h +6 -0
  62. data/vendor/faiss/faiss/impl/scalar_quantizer/quantizers.h +327 -18
  63. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx2.cpp +264 -27
  64. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx512-impl.h +553 -0
  65. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx512-spr.cpp +559 -0
  66. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx512.cpp +199 -27
  67. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-dispatch.h +366 -3
  68. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-neon.cpp +144 -19
  69. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-rvv.cpp +26 -0
  70. data/vendor/faiss/faiss/impl/simd_dispatch.h +65 -8
  71. data/vendor/faiss/faiss/index_factory.cpp +5 -1
  72. data/vendor/faiss/faiss/index_io.h +16 -0
  73. data/vendor/faiss/faiss/invlists/DirectMap.cpp +4 -1
  74. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +13 -13
  75. data/vendor/faiss/faiss/invlists/InvertedLists.h +2 -2
  76. data/vendor/faiss/faiss/svs/IndexSVSVamana.cpp +119 -22
  77. data/vendor/faiss/faiss/svs/IndexSVSVamana.h +15 -5
  78. data/vendor/faiss/faiss/svs/IndexSVSVamanaLVQ.cpp +3 -2
  79. data/vendor/faiss/faiss/svs/IndexSVSVamanaLVQ.h +2 -1
  80. data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.cpp +65 -24
  81. data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.h +3 -2
  82. data/vendor/faiss/faiss/utils/bf16.h +34 -0
  83. data/vendor/faiss/faiss/utils/distances_simd.cpp +0 -1
  84. data/vendor/faiss/faiss/utils/hamming.cpp +8 -8
  85. data/vendor/faiss/faiss/utils/hamming_distance/hamming_avx2.cpp +2 -1
  86. data/vendor/faiss/faiss/utils/hamming_distance/hamming_avx512_spr.cpp +15 -0
  87. data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-avx512.h +6 -30
  88. data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-avx512_spr.h +171 -0
  89. data/vendor/faiss/faiss/utils/partitioning.cpp +0 -2
  90. data/vendor/faiss/faiss/utils/simd_impl/partitioning_simdlib256.h +14 -68
  91. data/vendor/faiss/faiss/utils/simd_impl/rabitq_avx512_spr.cpp +343 -0
  92. data/vendor/faiss/faiss/utils/simd_levels.cpp +12 -2
  93. metadata +12 -2
@@ -25,7 +25,9 @@ namespace faiss {
25
25
  // Forward declarations to avoid circular dependency.
26
26
  struct IndexHNSW;
27
27
  struct IndexHNSWFlatPanorama;
28
- struct MinimaxHeap;
28
+ template <class HC_>
29
+ struct MinimaxHeapT;
30
+ using MinimaxHeap = MinimaxHeapT<CMax<float, int32_t>>;
29
31
  class LockVector;
30
32
 
31
33
  /** Implementation of the Hierarchical Navigable Small World
@@ -60,30 +62,52 @@ struct HNSW {
60
62
  /// internal storage of vectors (32 bits: this is expensive)
61
63
  using storage_idx_t = int32_t;
62
64
 
63
- // for now we do only these distances
64
- using C = CMax<float, int64_t>;
65
+ // The two comparator flavors HNSW supports. CMax (smaller-is-better)
66
+ // is the default; CMin (larger-is-better) is used when `is_similarity`
67
+ // is set on the owning index.
68
+ using C_distance = CMax<float, int64_t>;
69
+ using C_similarity = CMin<float, int64_t>;
70
+
71
+ // Back-compat alias: keeps `HNSW::C` resolving to the distance
72
+ // (CMax) comparator everywhere the type is referenced directly.
73
+ using C = C_distance;
65
74
 
66
75
  typedef std::pair<float, storage_idx_t> Node;
67
76
 
68
77
  /// to sort pairs of (id, distance) from nearest to farthest or the reverse
69
- struct NodeDistCloser {
78
+ template <class CT>
79
+ struct NodeDistCloserT {
70
80
  float d;
71
81
  int id;
72
- NodeDistCloser(float d_in, int id_in) : d(d_in), id(id_in) {}
73
- bool operator<(const NodeDistCloser& obj1) const {
74
- return d < obj1.d;
82
+ NodeDistCloserT(float d_in, int id_in) : d(d_in), id(id_in) {}
83
+ bool operator<(const NodeDistCloserT& obj1) const {
84
+ // priority_queue keeps the "worst" element at the top so that
85
+ // when the queue is full we can pop it. For CMax (distance) the
86
+ // worst element is the largest d; for CMin (similarity) it is
87
+ // the smallest d. Equivalent to: obj1.d "better than" d.
88
+ return CT::cmp(obj1.d, d);
75
89
  }
76
90
  };
77
91
 
78
- struct NodeDistFarther {
92
+ template <class CT>
93
+ struct NodeDistFartherT {
79
94
  float d;
80
95
  int id;
81
- NodeDistFarther(float d_in, int id_in) : d(d_in), id(id_in) {}
82
- bool operator<(const NodeDistFarther& obj1) const {
83
- return d > obj1.d;
96
+ NodeDistFartherT(float d_in, int id_in) : d(d_in), id(id_in) {}
97
+ bool operator<(const NodeDistFartherT& obj1) const {
98
+ // priority_queue here keeps the "best" element at the top so we
99
+ // can process the nearest candidate first. For CMax (distance)
100
+ // the best is the smallest d; for CMin (similarity) the best is
101
+ // the largest d. Equivalent to: d "better than" obj1.d.
102
+ return CT::cmp(d, obj1.d);
84
103
  }
85
104
  };
86
105
 
106
+ // Back-compat aliases: default to the distance (CMax) comparator so
107
+ // existing call sites that mention `HNSW::NodeDist*` keep working.
108
+ using NodeDistCloser = NodeDistCloserT<C_distance>;
109
+ using NodeDistFarther = NodeDistFartherT<C_distance>;
110
+
87
111
  /// assignment probability to each layer (sum=1)
88
112
  std::vector<double> assign_probas;
89
113
 
@@ -131,6 +155,12 @@ struct HNSW {
131
155
  /// use Panorama progressive pruning in search
132
156
  bool is_panorama = false;
133
157
 
158
+ /// distance comparison semantics: when true, distances are treated as
159
+ /// similarity scores (larger is better). Default false matches the
160
+ /// historical L2/Hamming behavior (smaller is better).
161
+ /// Not serialized: must be re-set by the owning Index after loading.
162
+ bool is_similarity = false;
163
+
134
164
  // See impl/VisitedTable.h.
135
165
  std::optional<bool> use_visited_hashset;
136
166
 
@@ -216,10 +246,11 @@ struct HNSW {
216
246
 
217
247
  int prepare_level_tab(size_t n, bool preset_levels = false);
218
248
 
249
+ template <class C = C_distance>
219
250
  static void shrink_neighbor_list(
220
251
  DistanceComputer& qdis,
221
- std::priority_queue<NodeDistFarther>& input,
222
- std::vector<NodeDistFarther>& output,
252
+ std::priority_queue<NodeDistFartherT<C>>& input,
253
+ std::vector<NodeDistFartherT<C>>& output,
223
254
  size_t max_size,
224
255
  bool keep_max_size_level0 = false);
225
256
 
@@ -250,6 +281,11 @@ struct HNSWStats {
250
281
  // global var that collects them all
251
282
  FAISS_API extern HNSWStats hnsw_stats;
252
283
 
284
+ /// Internal HNSW algorithm helpers. These are not part of the public API; they
285
+ /// are exposed here only so that unit tests (and a few cross-TU callers such as
286
+ /// the Panorama search variant) can reach them.
287
+ namespace hnsw_detail {
288
+
253
289
  int search_from_candidates(
254
290
  const HNSW& hnsw,
255
291
  DistanceComputer& qdis,
@@ -302,4 +338,6 @@ void search_neighbors_to_add(
302
338
  VisitedTable& vt,
303
339
  bool reference_version = false);
304
340
 
341
+ } // namespace hnsw_detail
342
+
305
343
  } // namespace faiss
@@ -234,10 +234,11 @@ void NSG::init_graph(Index* storage, const nsg::Graph<idx_t>& knn_graph) {
234
234
  std::unique_ptr<DistanceComputer> dis(storage_distance_computer(storage));
235
235
 
236
236
  dis->set_query(center.get());
237
- VisitedTable vt(ntotal, use_visited_hashset);
237
+ std::unique_ptr<VisitedTable> vt =
238
+ VisitedTable::create(ntotal, use_visited_hashset);
238
239
 
239
240
  // Do not collect the visited nodes
240
- search_on_graph<false>(knn_graph, *dis, vt, ep, L, retset, tmpset);
241
+ search_on_graph<false>(knn_graph, *dis, *vt, ep, L, retset, tmpset);
241
242
 
242
243
  // set enterpoint
243
244
  enterpoint = retset[0].id;
@@ -344,7 +345,8 @@ void NSG::link(
344
345
  std::vector<Node> pool;
345
346
  std::vector<Neighbor> tmp;
346
347
 
347
- VisitedTable vt(ntotal, use_visited_hashset);
348
+ std::unique_ptr<VisitedTable> vt =
349
+ VisitedTable::create(ntotal, use_visited_hashset);
348
350
  std::unique_ptr<DistanceComputer> dis(
349
351
  storage_distance_computer(storage));
350
352
 
@@ -355,13 +357,13 @@ void NSG::link(
355
357
 
356
358
  // Collect the visited nodes into pool
357
359
  search_on_graph<true>(
358
- knn_graph, *dis, vt, enterpoint, L, tmp, pool);
360
+ knn_graph, *dis, *vt, enterpoint, L, tmp, pool);
359
361
 
360
- sync_prune(i, pool, *dis, vt, knn_graph, graph);
362
+ sync_prune(i, pool, *dis, *vt, knn_graph, graph);
361
363
 
362
364
  pool.clear();
363
365
  tmp.clear();
364
- vt.advance();
366
+ vt->advance();
365
367
  }
366
368
  } // omp parallel
367
369
 
@@ -531,19 +533,21 @@ void NSG::add_reverse_links(
531
533
 
532
534
  int NSG::tree_grow(Index* storage, std::vector<int>& degrees) {
533
535
  int root = enterpoint;
534
- VisitedTable vt(ntotal, use_visited_hashset);
535
- VisitedTable vt2(ntotal, use_visited_hashset);
536
+ std::unique_ptr<VisitedTable> vt =
537
+ VisitedTable::create(ntotal, use_visited_hashset);
538
+ std::unique_ptr<VisitedTable> vt2 =
539
+ VisitedTable::create(ntotal, use_visited_hashset);
536
540
 
537
541
  int num_attached = 0;
538
542
  int cnt = 0;
539
543
  while (true) {
540
- cnt = dfs(vt, root, cnt);
544
+ cnt = dfs(*vt, root, cnt);
541
545
  if (cnt >= ntotal) {
542
546
  break;
543
547
  }
544
548
 
545
- root = attach_unlinked(storage, vt, vt2, degrees);
546
- vt2.advance();
549
+ root = attach_unlinked(storage, *vt, *vt2, degrees);
550
+ vt2->advance();
547
551
  num_attached += 1;
548
552
  }
549
553
 
@@ -36,7 +36,12 @@ namespace faiss {
36
36
  /// from active_indices (subsequent levels after pruning).
37
37
  /// @tparam LevelWidth Compile-time level width in floats (0 = use runtime
38
38
  /// level_width_dims). Enables full loop unrolling.
39
+ // Skip pragmas under nvcc: its EDG frontend warns on `#pragma GCC optimize`
40
+ // (#1675-D) for every `.cu` that transitively includes this header. These
41
+ // templates are CPU-only, so the hint is irrelevant during nvcc parse.
42
+ #if !defined(__NVCC__)
39
43
  FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN
44
+ #endif
40
45
  template <bool AllActive = false, size_t LevelWidth = 0>
41
46
  static inline void compute_level_dot_kernel(
42
47
  const float* FAISS_RESTRICT query_level,
@@ -83,7 +88,9 @@ static inline void compute_level_dot_kernel(
83
88
  dot_products[i] = dp;
84
89
  }
85
90
  }
91
+ #if !defined(__NVCC__)
86
92
  FAISS_PRAGMA_IMPRECISE_FUNCTION_END
93
+ #endif
87
94
 
88
95
  /// Update exact distances with the current level's dot products, then apply
89
96
  /// Panorama pruning: for each active vector, compute a lower bound on
@@ -92,7 +99,9 @@ FAISS_PRAGMA_IMPRECISE_FUNCTION_END
92
99
  ///
93
100
  /// Uses `if constexpr` on C::is_max rather than C::cmp() to ensure the
94
101
  /// comparison autovectorizes (C::cmp generates scalar function calls).
102
+ #if !defined(__NVCC__)
95
103
  FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN
104
+ #endif
96
105
  template <bool AllActive, typename C, MetricType M>
97
106
  static inline void prune_kernel(
98
107
  float* FAISS_RESTRICT exact_distances,
@@ -128,7 +137,9 @@ static inline void prune_kernel(
128
137
  }
129
138
  }
130
139
  }
140
+ #if !defined(__NVCC__)
131
141
  FAISS_PRAGMA_IMPRECISE_FUNCTION_END
142
+ #endif
132
143
 
133
144
  /// Compact active_indices in-place, removing entries where active_byteset[i]
134
145
  /// is zero. Returns the new count of active elements. Uses a branchless BMI2 +
@@ -13,12 +13,15 @@
13
13
  #include <cstdio>
14
14
  #include <cstring>
15
15
  #include <memory>
16
+ #include <type_traits>
16
17
 
17
18
  #include <algorithm>
18
19
 
19
20
  #include <faiss/IndexFlat.h>
20
21
  #include <faiss/VectorTransform.h>
21
22
  #include <faiss/impl/FaissAssert.h>
23
+ // NOLINTNEXTLINE(facebook-hte-InlineHeader)
24
+ #include <faiss/impl/pq_code_distance/pq_code_distance-inl.h>
22
25
  #include <faiss/impl/simd_dispatch.h>
23
26
  #include <faiss/utils/distances.h>
24
27
 
@@ -719,8 +722,28 @@ void pq_knn_search_with_tables(
719
722
 
720
723
  switch (nbits) {
721
724
  case 8:
722
- pq_estimators_from_tables<uint8_t, C>(
723
- pq, codes, ncodes, dis_table, k, heap_dis, heap_ids);
725
+ if (ksub == 256) {
726
+ constexpr bool max_heap =
727
+ std::is_same_v<C, CMax<float, int64_t>>;
728
+ pq_code_distance::pq_scan_8bit(
729
+ M,
730
+ dis_table,
731
+ codes,
732
+ ncodes,
733
+ k,
734
+ heap_dis,
735
+ heap_ids,
736
+ max_heap);
737
+ } else {
738
+ pq_estimators_from_tables<uint8_t, C>(
739
+ pq,
740
+ codes,
741
+ ncodes,
742
+ dis_table,
743
+ k,
744
+ heap_dis,
745
+ heap_ids);
746
+ }
724
747
  break;
725
748
 
726
749
  case 16:
@@ -321,7 +321,7 @@ float compute_full_multibit_distance(
321
321
  size_t d,
322
322
  size_t ex_bits,
323
323
  MetricType metric_type) {
324
- return with_selected_simd_levels<AVAILABLE_SIMD_LEVELS_A0>(
324
+ return with_selected_simd_levels<AVAILABLE_SIMD_LEVELS_A0_SPR>(
325
325
  [&]<SIMDLevel SL>() {
326
326
  return compute_full_multibit_distance<SL>(
327
327
  sign_bits,
@@ -551,7 +551,13 @@ FlatCodesDistanceComputer* RaBitQuantizer::get_distance_computer(
551
551
  // Dispatch on SIMDLevel once here so the distance computer methods
552
552
  // call the SIMD-specialized rabitq functions directly (no per-call
553
553
  // with_simd_level overhead).
554
- return with_selected_simd_levels<AVAILABLE_SIMD_LEVELS_A0>(
554
+ //
555
+ // Use A0_SPR (which includes AVX512_SPR) so that on Sapphire Rapids
556
+ // and later x86 microarchitectures the VPOPCNTDQ-based RaBitQ
557
+ // specialization in rabitq_avx512_spr.cpp is selected. On AVX-512
558
+ // CPUs without VPOPCNTDQ, dispatch falls through to the AVX512
559
+ // specialization in rabitq_avx512.cpp.
560
+ return with_selected_simd_levels<AVAILABLE_SIMD_LEVELS_A0_SPR>(
555
561
  [&]<SIMDLevel SL>() -> FlatCodesDistanceComputer* {
556
562
  if (qb == 0) {
557
563
  auto dc =
@@ -229,6 +229,7 @@ struct Top1BlockResultHandler : TopkBlockResultHandler<C, use_sel> {
229
229
 
230
230
  for (size_t i = i0; i < i1; i++) {
231
231
  this->dis_tab[i] = C::neutral();
232
+ this->ids_tab[i] = -1;
232
233
  }
233
234
  }
234
235
 
@@ -25,6 +25,218 @@
25
25
 
26
26
  namespace faiss {
27
27
 
28
+ namespace {
29
+
30
+ // Gaussian Lloyd-Max optimal quantizer centroids and boundaries for N(0,1).
31
+ // clang-format off
32
+ const float kLloydMaxCentroids1[] = {
33
+ -0.797884560802865f, 0.797884560802865f
34
+ };
35
+ const float kLloydMaxBoundaries1[] = {
36
+ 0.000000000000000f
37
+ };
38
+ const float kLloydMaxCentroids2[] = {
39
+ -1.510417608499078f, -0.452780034636484f,
40
+ 0.452780034636483f, 1.510417608499078f
41
+ };
42
+ const float kLloydMaxBoundaries2[] = {
43
+ -0.981598821567781f, 0.000000000000000f, 0.981598821567781f
44
+ };
45
+ const float kLloydMaxCentroids3[] = {
46
+ -2.151945704536914f, -1.343909278504930f,
47
+ -0.756005281205826f, -0.245094178944203f,
48
+ 0.245094178944203f, 0.756005281205825f,
49
+ 1.343909278504930f, 2.151945704536914f
50
+ };
51
+ const float kLloydMaxBoundaries3[] = {
52
+ -1.747927491520922f, -1.049957279855378f,
53
+ -0.500549730075014f, 0.000000000000000f,
54
+ 0.500549730075014f, 1.049957279855378f,
55
+ 1.747927491520922f
56
+ };
57
+ const float kLloydMaxCentroids4[] = {
58
+ -2.732589570994957f, -2.069017226531159f,
59
+ -1.618046386021649f, -1.256231197346957f,
60
+ -0.942340456486774f, -0.656759118532318f,
61
+ -0.388048299490198f, -0.128395029851116f,
62
+ 0.128395029851116f, 0.388048299490198f,
63
+ 0.656759118532318f, 0.942340456486773f,
64
+ 1.256231197346959f, 1.618046386021649f,
65
+ 2.069017226531160f, 2.732589570994943f
66
+ };
67
+ const float kLloydMaxBoundaries4[] = {
68
+ -2.400803398763058f, -1.843531806276404f,
69
+ -1.437138791684303f, -1.099285826916865f,
70
+ -0.799549787509546f, -0.522403709011258f,
71
+ -0.258221664670657f, 0.000000000000000f,
72
+ 0.258221664670657f, 0.522403709011258f,
73
+ 0.799549787509546f, 1.099285826916866f,
74
+ 1.437138791684304f, 1.843531806276404f,
75
+ 2.400803398763051f
76
+ };
77
+ const float kLloydMaxCentroids8[] = {
78
+ -4.2734901319f, -3.8270895246f, -3.5457169520f, -3.3354593381f,
79
+ -3.1655721017f, -3.0219515320f, -2.8969009924f, -2.7857394515f,
80
+ -2.6853990170f, -2.5937556343f, -2.5092755166f, -2.4308135619f,
81
+ -2.3574913691f, -2.2886197969f, -2.2236478246f, -2.1621276457f,
82
+ -2.1036901632f, -2.0480273642f, -1.9948793740f, -1.9440247677f,
83
+ -1.8952732015f, -1.8484597247f, -1.8034403315f, -1.7600884415f,
84
+ -1.7182920846f, -1.6779516274f, -1.6389779215f, -1.6012907825f,
85
+ -1.5648177311f, -1.5294929453f, -1.4952563823f, -1.4620530375f,
86
+ -1.4298323186f, -1.3985475108f, -1.3681553217f, -1.3386154890f,
87
+ -1.3098904444f, -1.2819450217f, -1.2547462051f, -1.2282629097f,
88
+ -1.2024657910f, -1.1773270781f, -1.1528204287f, -1.1289208010f,
89
+ -1.1056043421f, -1.0828482901f, -1.0606308873f, -1.0389313043f,
90
+ -1.0177295729f, -0.9970065268f, -0.9767437492f, -0.9569235264f,
91
+ -0.9375288069f, -0.9185431646f, -0.8999507663f, -0.8817363426f,
92
+ -0.8638851621f, -0.8463830081f, -0.8292161569f, -0.8123713596f,
93
+ -0.7958358242f, -0.7795971999f, -0.7636435625f, -0.7479634007f,
94
+ -0.7325456038f, -0.7173794494f, -0.7024545929f, -0.6877610560f,
95
+ -0.6732892172f, -0.6590298016f, -0.6449738716f, -0.6311128174f,
96
+ -0.6174383481f, -0.6039424829f, -0.5906175419f, -0.5774561379f,
97
+ -0.5644511676f, -0.5515958029f, -0.5388834832f, -0.5263079060f,
98
+ -0.5138630194f, -0.5015430136f, -0.4893423125f, -0.4772555660f,
99
+ -0.4652776416f, -0.4534036165f, -0.4416287701f, -0.4299485757f,
100
+ -0.4183586932f, -0.4068549615f, -0.3954333909f, -0.3840901561f,
101
+ -0.3728215889f, -0.3616241712f, -0.3504945283f, -0.3394294221f,
102
+ -0.3284257446f, -0.3174805116f, -0.3065908567f, -0.2957540250f,
103
+ -0.2849673675f, -0.2742283355f, -0.2635344752f, -0.2528834222f,
104
+ -0.2422728967f, -0.2317006985f, -0.2211647022f, -0.2106628526f,
105
+ -0.2001931607f, -0.1897536989f, -0.1793425974f, -0.1689580400f,
106
+ -0.1585982605f, -0.1482615390f, -0.1379461985f, -0.1276506012f,
107
+ -0.1173731457f, -0.1071122637f, -0.0968664166f, -0.0866340933f,
108
+ -0.0764138065f, -0.0662040909f, -0.0560034994f, -0.0458106014f,
109
+ -0.0356239797f, -0.0254422284f, -0.0152639496f, -0.0050877521f,
110
+ 0.0050877521f, 0.0152639496f, 0.0254422284f, 0.0356239797f,
111
+ 0.0458106014f, 0.0560034994f, 0.0662040909f, 0.0764138065f,
112
+ 0.0866340933f, 0.0968664166f, 0.1071122637f, 0.1173731457f,
113
+ 0.1276506012f, 0.1379461985f, 0.1482615390f, 0.1585982605f,
114
+ 0.1689580400f, 0.1793425974f, 0.1897536989f, 0.2001931607f,
115
+ 0.2106628526f, 0.2211647022f, 0.2317006985f, 0.2422728967f,
116
+ 0.2528834222f, 0.2635344752f, 0.2742283355f, 0.2849673675f,
117
+ 0.2957540250f, 0.3065908567f, 0.3174805116f, 0.3284257446f,
118
+ 0.3394294221f, 0.3504945283f, 0.3616241712f, 0.3728215889f,
119
+ 0.3840901561f, 0.3954333909f, 0.4068549615f, 0.4183586932f,
120
+ 0.4299485757f, 0.4416287701f, 0.4534036165f, 0.4652776416f,
121
+ 0.4772555660f, 0.4893423125f, 0.5015430136f, 0.5138630194f,
122
+ 0.5263079060f, 0.5388834832f, 0.5515958029f, 0.5644511676f,
123
+ 0.5774561379f, 0.5906175419f, 0.6039424829f, 0.6174383481f,
124
+ 0.6311128174f, 0.6449738716f, 0.6590298016f, 0.6732892172f,
125
+ 0.6877610560f, 0.7024545929f, 0.7173794494f, 0.7325456038f,
126
+ 0.7479634007f, 0.7636435625f, 0.7795971999f, 0.7958358242f,
127
+ 0.8123713596f, 0.8292161569f, 0.8463830081f, 0.8638851621f,
128
+ 0.8817363426f, 0.8999507663f, 0.9185431646f, 0.9375288069f,
129
+ 0.9569235264f, 0.9767437492f, 0.9970065268f, 1.0177295729f,
130
+ 1.0389313043f, 1.0606308873f, 1.0828482901f, 1.1056043421f,
131
+ 1.1289208010f, 1.1528204287f, 1.1773270781f, 1.2024657910f,
132
+ 1.2282629097f, 1.2547462051f, 1.2819450217f, 1.3098904444f,
133
+ 1.3386154890f, 1.3681553217f, 1.3985475108f, 1.4298323186f,
134
+ 1.4620530375f, 1.4952563823f, 1.5294929453f, 1.5648177311f,
135
+ 1.6012907825f, 1.6389779215f, 1.6779516274f, 1.7182920846f,
136
+ 1.7600884415f, 1.8034403315f, 1.8484597247f, 1.8952732015f,
137
+ 1.9440247677f, 1.9948793740f, 2.0480273642f, 2.1036901632f,
138
+ 2.1621276457f, 2.2236478246f, 2.2886197969f, 2.3574913691f,
139
+ 2.4308135619f, 2.5092755166f, 2.5937556343f, 2.6853990170f,
140
+ 2.7857394515f, 2.8969009924f, 3.0219515320f, 3.1655721017f,
141
+ 3.3354593381f, 3.5457169520f, 3.8270895246f, 4.2734901319f
142
+ };
143
+ const float kLloydMaxBoundaries8[] = {
144
+ -4.0502898282f, -3.6864032383f, -3.4405881450f, -3.2505157199f,
145
+ -3.0937618168f, -2.9594262622f, -2.8413202220f, -2.7355692343f,
146
+ -2.6395773257f, -2.5515155755f, -2.4700445392f, -2.3941524655f,
147
+ -2.3230555830f, -2.2561338107f, -2.1928877352f, -2.1329089044f,
148
+ -2.0758587637f, -2.0214533691f, -1.9694520708f, -1.9196489846f,
149
+ -1.8718664631f, -1.8259500281f, -1.7817643865f, -1.7391902630f,
150
+ -1.6981218560f, -1.6584647744f, -1.6201343520f, -1.5830542568f,
151
+ -1.5471553382f, -1.5123746638f, -1.4786547099f, -1.4459426781f,
152
+ -1.4141899147f, -1.3833514163f, -1.3533854053f, -1.3242529667f,
153
+ -1.2959177331f, -1.2683456134f, -1.2415045574f, -1.2153643503f,
154
+ -1.1898964346f, -1.1650737534f, -1.1408706148f, -1.1172625715f,
155
+ -1.0942263161f, -1.0717395887f, -1.0497810958f, -1.0283304386f,
156
+ -1.0073680499f, -0.9868751380f, -0.9668336378f, -0.9472261667f,
157
+ -0.9280359858f, -0.9092469654f, -0.8908435544f, -0.8728107524f,
158
+ -0.8551340851f, -0.8377995825f, -0.8207937582f, -0.8041035919f,
159
+ -0.7877165121f, -0.7716203812f, -0.7558034816f, -0.7402545023f,
160
+ -0.7249625266f, -0.7099170212f, -0.6951078244f, -0.6805251366f,
161
+ -0.6661595094f, -0.6520018366f, -0.6380433445f, -0.6242755828f,
162
+ -0.6106904155f, -0.5972800124f, -0.5840368399f, -0.5709536527f,
163
+ -0.5580234853f, -0.5452396431f, -0.5325956946f, -0.5200854627f,
164
+ -0.5077030165f, -0.4954426631f, -0.4832989393f, -0.4712666038f,
165
+ -0.4593406291f, -0.4475161933f, -0.4357886729f, -0.4241536345f,
166
+ -0.4126068274f, -0.4011441762f, -0.3897617735f, -0.3784558725f,
167
+ -0.3672228800f, -0.3560593498f, -0.3449619752f, -0.3339275834f,
168
+ -0.3229531281f, -0.3120356842f, -0.3011724408f, -0.2903606962f,
169
+ -0.2795978515f, -0.2688814053f, -0.2582089487f, -0.2475781595f,
170
+ -0.2369867976f, -0.2264327004f, -0.2159137774f, -0.2054280067f,
171
+ -0.1949734298f, -0.1845481481f, -0.1741503187f, -0.1637781502f,
172
+ -0.1534298998f, -0.1431038688f, -0.1327983999f, -0.1225118735f,
173
+ -0.1122427047f, -0.1019893401f, -0.0917502549f, -0.0815239499f,
174
+ -0.0713089487f, -0.0611037951f, -0.0509070504f, -0.0407172906f,
175
+ -0.0305331041f, -0.0203530890f, -0.0101758509f, 0.0000000000f,
176
+ 0.0101758509f, 0.0203530890f, 0.0305331041f, 0.0407172906f,
177
+ 0.0509070504f, 0.0611037951f, 0.0713089487f, 0.0815239499f,
178
+ 0.0917502549f, 0.1019893401f, 0.1122427047f, 0.1225118735f,
179
+ 0.1327983999f, 0.1431038688f, 0.1534298998f, 0.1637781502f,
180
+ 0.1741503187f, 0.1845481481f, 0.1949734298f, 0.2054280067f,
181
+ 0.2159137774f, 0.2264327004f, 0.2369867976f, 0.2475781595f,
182
+ 0.2582089487f, 0.2688814053f, 0.2795978515f, 0.2903606962f,
183
+ 0.3011724408f, 0.3120356842f, 0.3229531281f, 0.3339275834f,
184
+ 0.3449619752f, 0.3560593498f, 0.3672228800f, 0.3784558725f,
185
+ 0.3897617735f, 0.4011441762f, 0.4126068274f, 0.4241536345f,
186
+ 0.4357886729f, 0.4475161933f, 0.4593406291f, 0.4712666038f,
187
+ 0.4832989393f, 0.4954426631f, 0.5077030165f, 0.5200854627f,
188
+ 0.5325956946f, 0.5452396431f, 0.5580234853f, 0.5709536527f,
189
+ 0.5840368399f, 0.5972800124f, 0.6106904155f, 0.6242755828f,
190
+ 0.6380433445f, 0.6520018366f, 0.6661595094f, 0.6805251366f,
191
+ 0.6951078244f, 0.7099170212f, 0.7249625266f, 0.7402545023f,
192
+ 0.7558034816f, 0.7716203812f, 0.7877165121f, 0.8041035919f,
193
+ 0.8207937582f, 0.8377995825f, 0.8551340851f, 0.8728107524f,
194
+ 0.8908435544f, 0.9092469654f, 0.9280359858f, 0.9472261667f,
195
+ 0.9668336378f, 0.9868751380f, 1.0073680499f, 1.0283304386f,
196
+ 1.0497810958f, 1.0717395887f, 1.0942263161f, 1.1172625715f,
197
+ 1.1408706148f, 1.1650737534f, 1.1898964346f, 1.2153643503f,
198
+ 1.2415045574f, 1.2683456134f, 1.2959177331f, 1.3242529667f,
199
+ 1.3533854053f, 1.3833514163f, 1.4141899147f, 1.4459426781f,
200
+ 1.4786547099f, 1.5123746638f, 1.5471553382f, 1.5830542568f,
201
+ 1.6201343520f, 1.6584647744f, 1.6981218560f, 1.7391902630f,
202
+ 1.7817643865f, 1.8259500281f, 1.8718664631f, 1.9196489846f,
203
+ 1.9694520708f, 2.0214533691f, 2.0758587637f, 2.1329089044f,
204
+ 2.1928877352f, 2.2561338107f, 2.3230555830f, 2.3941524655f,
205
+ 2.4700445392f, 2.5515155755f, 2.6395773257f, 2.7355692343f,
206
+ 2.8413202220f, 2.9594262622f, 3.0937618168f, 3.2505157199f,
207
+ 3.4405881450f, 3.6864032383f, 4.0502898282f
208
+ };
209
+ // clang-format on
210
+
211
+ struct LloydMaxTable {
212
+ const float* centroids;
213
+ const float* boundaries;
214
+ };
215
+
216
+ const LloydMaxTable kLloydMaxTables[] = {
217
+ {nullptr, nullptr}, // 0
218
+ {kLloydMaxCentroids1, kLloydMaxBoundaries1}, // 1
219
+ {kLloydMaxCentroids2, kLloydMaxBoundaries2}, // 2
220
+ {kLloydMaxCentroids3, kLloydMaxBoundaries3}, // 3
221
+ {kLloydMaxCentroids4, kLloydMaxBoundaries4}, // 4
222
+ {nullptr, nullptr}, // 5 (unused)
223
+ {nullptr, nullptr}, // 6 (unused)
224
+ {nullptr, nullptr}, // 7 (unused)
225
+ {kLloydMaxCentroids8, kLloydMaxBoundaries8}, // 8
226
+ };
227
+
228
+ void populate_lloyd_max_trained(size_t mse_bits, std::vector<float>& trained) {
229
+ FAISS_THROW_IF_NOT(mse_bits >= 1 && mse_bits <= 8);
230
+ FAISS_THROW_IF_NOT(kLloydMaxTables[mse_bits].centroids != nullptr);
231
+ size_t k = size_t(1) << mse_bits;
232
+ const auto& t = kLloydMaxTables[mse_bits];
233
+ trained.resize(k + (k - 1));
234
+ std::copy(t.centroids, t.centroids + k, trained.begin());
235
+ std::copy(t.boundaries, t.boundaries + k - 1, trained.begin() + k);
236
+ }
237
+
238
+ } // namespace
239
+
28
240
  /*******************************************************************
29
241
  * ScalarQuantizer implementation
30
242
  ********************************************************************/
@@ -80,6 +292,24 @@ void ScalarQuantizer::set_derived_sizes() {
80
292
  code_size = 0;
81
293
  bits = 0;
82
294
  break;
295
+ case QT_2bit_tq:
296
+ case QT_3bit_tq:
297
+ case QT_4bit_tq:
298
+ case QT_5bit_tq: {
299
+ size_t nb_bits = (qtype == QT_2bit_tq) ? 2
300
+ : (qtype == QT_3bit_tq) ? 3
301
+ : (qtype == QT_4bit_tq) ? 4
302
+ : (qtype == QT_5bit_tq) ? 5
303
+ : 0;
304
+ FAISS_THROW_IF_NOT_MSG(nb_bits > 0, "unexpected TurboQ qtype");
305
+ size_t mse_bits = nb_bits - 1;
306
+ size_t mse_bytes = mse_bits * ((d + 7) / 8);
307
+ size_t qjl_bytes = (d + 7) / 8;
308
+ code_size = mse_bytes + qjl_bytes +
309
+ sizeof(scalar_quantizer::SQTurboQFactors);
310
+ bits = nb_bits;
311
+ break;
312
+ }
83
313
  default:
84
314
  break;
85
315
  }
@@ -134,27 +364,60 @@ void ScalarQuantizer::train(size_t n, const float* x) {
134
364
  // no training necessary
135
365
  break;
136
366
  case QT_1bit_tqmse:
137
- scalar_quantizer::train_TurboQuantMSE(d, 1, trained);
367
+ populate_lloyd_max_trained(1, trained);
138
368
  break;
139
369
  case QT_2bit_tqmse:
140
- scalar_quantizer::train_TurboQuantMSE(d, 2, trained);
370
+ populate_lloyd_max_trained(2, trained);
141
371
  break;
142
372
  case QT_3bit_tqmse:
143
- scalar_quantizer::train_TurboQuantMSE(d, 3, trained);
373
+ populate_lloyd_max_trained(3, trained);
144
374
  break;
145
375
  case QT_4bit_tqmse:
146
- scalar_quantizer::train_TurboQuantMSE(d, 4, trained);
376
+ populate_lloyd_max_trained(4, trained);
147
377
  break;
148
378
  case QT_8bit_tqmse:
149
- scalar_quantizer::train_TurboQuantMSE(d, 8, trained);
379
+ populate_lloyd_max_trained(8, trained);
380
+ break;
381
+ case QT_2bit_tq:
382
+ case QT_3bit_tq:
383
+ case QT_4bit_tq:
384
+ case QT_5bit_tq: {
385
+ size_t mse_bits = bits - 1;
386
+ populate_lloyd_max_trained(mse_bits, trained);
387
+ // Pack seed and qjl_type at end of trained for dispatch
388
+ float seed_f[2];
389
+ TurboQuantRefine::pack_seed(turboq_refine.seed, seed_f);
390
+ trained.push_back(seed_f[0]);
391
+ trained.push_back(seed_f[1]);
392
+ trained.push_back(static_cast<float>(turboq_refine.qjl_type));
393
+ turboq_refine.init_projection(d);
150
394
  break;
395
+ }
151
396
  default:
152
397
  break;
153
398
  }
154
399
  }
155
400
 
401
+ void ScalarQuantizer::TurboQuantRefine::init_projection(size_t d) {
402
+ if (use_fwht()) {
403
+ padded_d = 1;
404
+ while (padded_d < d) {
405
+ padded_d <<= 1;
406
+ }
407
+ fwht_signs.resize(padded_d);
408
+ RandomGenerator rng(seed);
409
+ for (size_t i = 0; i < padded_d; i++) {
410
+ fwht_signs[i] = (rng.rand_int(2) == 0) ? 1.0f : -1.0f;
411
+ }
412
+ } else {
413
+ rr_matrix.resize(d * d);
414
+ float_randn(rr_matrix.data(), d * d, seed);
415
+ matrix_qr(static_cast<int>(d), static_cast<int>(d), rr_matrix.data());
416
+ }
417
+ }
418
+
156
419
  ScalarQuantizer::SQuantizer* ScalarQuantizer::select_quantizer() const {
157
- return with_simd_level([&]<SIMDLevel SL>() -> SQuantizer* {
420
+ return with_simd_level_spr([&]<SIMDLevel SL>() -> SQuantizer* {
158
421
  if constexpr (SL != SIMDLevel::NONE) {
159
422
  auto* q = scalar_quantizer::sq_select_quantizer<SL>(
160
423
  qtype, d, trained);
@@ -197,7 +460,7 @@ void ScalarQuantizer::decode(const uint8_t* codes, float* x, size_t n) const {
197
460
  ScalarQuantizer::SQDistanceComputer* ScalarQuantizer::get_distance_computer(
198
461
  MetricType metric) const {
199
462
  FAISS_THROW_IF_NOT(metric == METRIC_L2 || metric == METRIC_INNER_PRODUCT);
200
- return with_simd_level([&]<SIMDLevel SL>() -> SQDistanceComputer* {
463
+ return with_simd_level_spr([&]<SIMDLevel SL>() -> SQDistanceComputer* {
201
464
  if constexpr (SL != SIMDLevel::NONE) {
202
465
  auto* dc = scalar_quantizer::sq_select_distance_computer<SL>(
203
466
  metric, qtype, d, trained);
@@ -216,7 +479,7 @@ InvertedListScanner* ScalarQuantizer::select_InvertedListScanner(
216
479
  bool store_pairs,
217
480
  const IDSelector* sel,
218
481
  bool by_residual) const {
219
- return with_simd_level([&]<SIMDLevel SL>() -> InvertedListScanner* {
482
+ return with_simd_level_spr([&]<SIMDLevel SL>() -> InvertedListScanner* {
220
483
  if constexpr (SL != SIMDLevel::NONE) {
221
484
  auto* s = scalar_quantizer::sq_select_InvertedListScanner<SL>(
222
485
  qtype,