faiss 0.4.2 → 0.5.0

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those package versions as they appear in their respective public registries.
Files changed (153)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +9 -0
  3. data/ext/faiss/index.cpp +36 -10
  4. data/ext/faiss/index_binary.cpp +19 -6
  5. data/ext/faiss/kmeans.cpp +6 -6
  6. data/ext/faiss/numo.hpp +273 -123
  7. data/lib/faiss/version.rb +1 -1
  8. data/vendor/faiss/faiss/AutoTune.cpp +2 -3
  9. data/vendor/faiss/faiss/AutoTune.h +1 -1
  10. data/vendor/faiss/faiss/Clustering.cpp +2 -2
  11. data/vendor/faiss/faiss/Clustering.h +2 -2
  12. data/vendor/faiss/faiss/IVFlib.cpp +1 -2
  13. data/vendor/faiss/faiss/IVFlib.h +1 -1
  14. data/vendor/faiss/faiss/Index.h +10 -10
  15. data/vendor/faiss/faiss/Index2Layer.cpp +1 -1
  16. data/vendor/faiss/faiss/Index2Layer.h +2 -2
  17. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +9 -4
  18. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +5 -1
  19. data/vendor/faiss/faiss/IndexBinary.h +7 -7
  20. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +1 -1
  21. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +3 -1
  22. data/vendor/faiss/faiss/IndexBinaryHNSW.h +1 -1
  23. data/vendor/faiss/faiss/IndexBinaryHash.cpp +3 -3
  24. data/vendor/faiss/faiss/IndexBinaryHash.h +5 -5
  25. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +7 -6
  26. data/vendor/faiss/faiss/IndexFastScan.cpp +125 -49
  27. data/vendor/faiss/faiss/IndexFastScan.h +107 -7
  28. data/vendor/faiss/faiss/IndexFlat.h +1 -1
  29. data/vendor/faiss/faiss/IndexHNSW.cpp +3 -1
  30. data/vendor/faiss/faiss/IndexHNSW.h +1 -1
  31. data/vendor/faiss/faiss/IndexIDMap.cpp +14 -13
  32. data/vendor/faiss/faiss/IndexIDMap.h +6 -6
  33. data/vendor/faiss/faiss/IndexIVF.cpp +1 -1
  34. data/vendor/faiss/faiss/IndexIVF.h +5 -5
  35. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +1 -1
  36. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +9 -3
  37. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +3 -1
  38. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +176 -90
  39. data/vendor/faiss/faiss/IndexIVFFastScan.h +173 -18
  40. data/vendor/faiss/faiss/IndexIVFFlat.cpp +1 -0
  41. data/vendor/faiss/faiss/IndexIVFFlatPanorama.cpp +366 -0
  42. data/vendor/faiss/faiss/IndexIVFFlatPanorama.h +64 -0
  43. data/vendor/faiss/faiss/IndexIVFPQ.cpp +3 -1
  44. data/vendor/faiss/faiss/IndexIVFPQ.h +1 -1
  45. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +134 -2
  46. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +7 -1
  47. data/vendor/faiss/faiss/IndexIVFRaBitQ.cpp +13 -6
  48. data/vendor/faiss/faiss/IndexIVFRaBitQ.h +1 -0
  49. data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.cpp +650 -0
  50. data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.h +216 -0
  51. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +1 -1
  52. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +1 -1
  53. data/vendor/faiss/faiss/IndexNNDescent.cpp +1 -1
  54. data/vendor/faiss/faiss/IndexNSG.cpp +1 -1
  55. data/vendor/faiss/faiss/IndexNeuralNetCodec.h +1 -1
  56. data/vendor/faiss/faiss/IndexPQ.h +1 -1
  57. data/vendor/faiss/faiss/IndexPQFastScan.cpp +6 -2
  58. data/vendor/faiss/faiss/IndexPQFastScan.h +5 -1
  59. data/vendor/faiss/faiss/IndexRaBitQ.cpp +13 -10
  60. data/vendor/faiss/faiss/IndexRaBitQ.h +7 -2
  61. data/vendor/faiss/faiss/IndexRaBitQFastScan.cpp +586 -0
  62. data/vendor/faiss/faiss/IndexRaBitQFastScan.h +149 -0
  63. data/vendor/faiss/faiss/IndexShards.cpp +1 -1
  64. data/vendor/faiss/faiss/MatrixStats.cpp +3 -3
  65. data/vendor/faiss/faiss/MetricType.h +1 -1
  66. data/vendor/faiss/faiss/VectorTransform.h +2 -2
  67. data/vendor/faiss/faiss/clone_index.cpp +3 -1
  68. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +1 -1
  69. data/vendor/faiss/faiss/gpu/GpuIndex.h +11 -11
  70. data/vendor/faiss/faiss/gpu/GpuIndexBinaryCagra.h +1 -1
  71. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +1 -1
  72. data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +10 -6
  73. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +2 -0
  74. data/vendor/faiss/faiss/gpu/test/TestGpuIcmEncoder.cpp +7 -0
  75. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +1 -1
  76. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +1 -1
  77. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +1 -1
  78. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +2 -2
  79. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +1 -1
  80. data/vendor/faiss/faiss/impl/CodePacker.h +2 -2
  81. data/vendor/faiss/faiss/impl/DistanceComputer.h +3 -3
  82. data/vendor/faiss/faiss/impl/FastScanDistancePostProcessing.h +53 -0
  83. data/vendor/faiss/faiss/impl/HNSW.cpp +1 -1
  84. data/vendor/faiss/faiss/impl/HNSW.h +4 -4
  85. data/vendor/faiss/faiss/impl/IDSelector.cpp +2 -2
  86. data/vendor/faiss/faiss/impl/IDSelector.h +1 -1
  87. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +4 -4
  88. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +1 -1
  89. data/vendor/faiss/faiss/impl/LookupTableScaler.h +1 -1
  90. data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -1
  91. data/vendor/faiss/faiss/impl/NNDescent.h +2 -2
  92. data/vendor/faiss/faiss/impl/NSG.cpp +1 -1
  93. data/vendor/faiss/faiss/impl/PanoramaStats.cpp +33 -0
  94. data/vendor/faiss/faiss/impl/PanoramaStats.h +38 -0
  95. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +5 -5
  96. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +1 -1
  97. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +1 -1
  98. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +2 -0
  99. data/vendor/faiss/faiss/impl/ProductQuantizer.h +1 -1
  100. data/vendor/faiss/faiss/impl/RaBitQUtils.cpp +246 -0
  101. data/vendor/faiss/faiss/impl/RaBitQUtils.h +153 -0
  102. data/vendor/faiss/faiss/impl/RaBitQuantizer.cpp +54 -158
  103. data/vendor/faiss/faiss/impl/RaBitQuantizer.h +2 -1
  104. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +1 -1
  105. data/vendor/faiss/faiss/impl/ResultHandler.h +4 -4
  106. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +1 -1
  107. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +1 -1
  108. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +7 -4
  109. data/vendor/faiss/faiss/impl/index_read.cpp +87 -3
  110. data/vendor/faiss/faiss/impl/index_write.cpp +73 -3
  111. data/vendor/faiss/faiss/impl/io.cpp +2 -2
  112. data/vendor/faiss/faiss/impl/io.h +4 -4
  113. data/vendor/faiss/faiss/impl/kmeans1d.cpp +1 -1
  114. data/vendor/faiss/faiss/impl/kmeans1d.h +1 -1
  115. data/vendor/faiss/faiss/impl/lattice_Zn.h +2 -2
  116. data/vendor/faiss/faiss/impl/mapped_io.cpp +2 -2
  117. data/vendor/faiss/faiss/impl/mapped_io.h +4 -3
  118. data/vendor/faiss/faiss/impl/maybe_owned_vector.h +8 -1
  119. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +30 -4
  120. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +14 -8
  121. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +5 -6
  122. data/vendor/faiss/faiss/impl/simd_result_handlers.h +55 -11
  123. data/vendor/faiss/faiss/impl/zerocopy_io.h +1 -1
  124. data/vendor/faiss/faiss/index_factory.cpp +43 -1
  125. data/vendor/faiss/faiss/index_factory.h +1 -1
  126. data/vendor/faiss/faiss/index_io.h +1 -1
  127. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +205 -0
  128. data/vendor/faiss/faiss/invlists/InvertedLists.h +62 -0
  129. data/vendor/faiss/faiss/utils/AlignedTable.h +1 -1
  130. data/vendor/faiss/faiss/utils/Heap.cpp +2 -2
  131. data/vendor/faiss/faiss/utils/Heap.h +3 -3
  132. data/vendor/faiss/faiss/utils/NeuralNet.cpp +1 -1
  133. data/vendor/faiss/faiss/utils/NeuralNet.h +3 -3
  134. data/vendor/faiss/faiss/utils/approx_topk/approx_topk.h +2 -2
  135. data/vendor/faiss/faiss/utils/approx_topk/avx2-inl.h +2 -2
  136. data/vendor/faiss/faiss/utils/approx_topk/mode.h +1 -1
  137. data/vendor/faiss/faiss/utils/distances.h +2 -2
  138. data/vendor/faiss/faiss/utils/extra_distances-inl.h +3 -1
  139. data/vendor/faiss/faiss/utils/hamming-inl.h +2 -0
  140. data/vendor/faiss/faiss/utils/hamming.cpp +7 -6
  141. data/vendor/faiss/faiss/utils/hamming.h +1 -1
  142. data/vendor/faiss/faiss/utils/hamming_distance/common.h +1 -2
  143. data/vendor/faiss/faiss/utils/partitioning.cpp +5 -5
  144. data/vendor/faiss/faiss/utils/partitioning.h +2 -2
  145. data/vendor/faiss/faiss/utils/rabitq_simd.h +222 -336
  146. data/vendor/faiss/faiss/utils/random.cpp +1 -1
  147. data/vendor/faiss/faiss/utils/simdlib_avx2.h +1 -1
  148. data/vendor/faiss/faiss/utils/simdlib_avx512.h +1 -1
  149. data/vendor/faiss/faiss/utils/simdlib_neon.h +2 -2
  150. data/vendor/faiss/faiss/utils/transpose/transpose-avx512-inl.h +1 -1
  151. data/vendor/faiss/faiss/utils/utils.cpp +5 -2
  152. data/vendor/faiss/faiss/utils/utils.h +2 -2
  153. metadata +14 -3
@@ -116,7 +116,7 @@ struct IDSelectorBitmap : IDSelector {
116
116
  /** reverts the membership test of another selector */
117
117
  struct IDSelectorNot : IDSelector {
118
118
  const IDSelector* sel;
119
- IDSelectorNot(const IDSelector* sel) : sel(sel) {}
119
+ explicit IDSelectorNot(const IDSelector* sel) : sel(sel) {}
120
120
  bool is_member(idx_t id) const final {
121
121
  return !sel->is_member(id);
122
122
  }
@@ -30,7 +30,7 @@
30
30
  #endif
31
31
 
32
32
  extern "C" {
33
- // LU decomoposition of a general matrix
33
+ // LU decomposition of a general matrix
34
34
  void sgetrf_(
35
35
  FINTEGER* m,
36
36
  FINTEGER* n,
@@ -65,7 +65,7 @@ int sgemm_(
65
65
  float* c,
66
66
  FINTEGER* ldc);
67
67
 
68
- // LU decomoposition of a general matrix
68
+ // LU decomposition of a general matrix
69
69
  void dgetrf_(
70
70
  FINTEGER* m,
71
71
  FINTEGER* n,
@@ -189,7 +189,7 @@ void LocalSearchQuantizer::train(size_t n, const float* x) {
189
189
  std::vector<int32_t> codes(n * M); // [n, M]
190
190
  random_int32(codes, 0, K - 1, gen);
191
191
 
192
- // compute standard derivations of each dimension
192
+ // compute standard deviations of each dimension
193
193
  std::vector<float> stddev(d, 0);
194
194
 
195
195
  #pragma omp parallel for
@@ -487,7 +487,7 @@ void LocalSearchQuantizer::update_codebooks(
487
487
  * L = (X - \sum cj)^2, j = 1, ..., M
488
488
  * L = X^2 - 2X * \sum cj + (\sum cj)^2
489
489
  *
490
- * X^2 is negligable since it is the same for all possible value
490
+ * X^2 is negligible since it is the same for all possible value
491
491
  * k of the m-th subcode.
492
492
  *
493
493
  * 2X * \sum cj is the unary term
@@ -138,7 +138,7 @@ struct LocalSearchQuantizer : AdditiveQuantizer {
138
138
  /** Add some perturbation to codebooks
139
139
  *
140
140
  * @param T temperature of simulated annealing
141
- * @param stddev standard derivations of each dimension in training data
141
+ * @param stddev standard deviations of each dimension in training data
142
142
  */
143
143
  void perturb_codebooks(
144
144
  float T,
@@ -63,7 +63,7 @@ struct DummyScaler {
63
63
  };
64
64
 
65
65
  /// consumes 2x4 bits to encode a norm as a scalar additive quantizer
66
- /// the norm is scaled because its range if larger than other components
66
+ /// the norm is scaled because its range is larger than other components
67
67
  struct NormTableScaler {
68
68
  static constexpr int nscale = 2;
69
69
  int scale_int;
@@ -177,7 +177,7 @@ void NNDescent::join(DistanceComputer& qdis) {
177
177
  }
178
178
  }
179
179
 
180
- /// Sample neighbors for each node to peform local join later
180
+ /// Sample neighbors for each node to perform local join later
181
181
  /// Store them in nn_new and nn_old
182
182
  void NNDescent::update() {
183
183
  // Step 1.
@@ -34,7 +34,7 @@ namespace faiss {
34
34
  *
35
35
  * Dong, Wei, Charikar Moses, and Kai Li, WWW 2011
36
36
  *
37
- * This implmentation is heavily influenced by the efanna
37
+ * This implementation is heavily influenced by the efanna
38
38
  * implementation by Cong Fu and the KGraph library by Wei Dong
39
39
  * (https://github.com/ZJULearning/efanna_graph)
40
40
  * (https://github.com/aaalgo/kgraph)
@@ -117,7 +117,7 @@ struct NNDescent {
117
117
  /// Perform local join on each node
118
118
  void join(DistanceComputer& qdis);
119
119
 
120
- /// Sample new neighbors for each node to peform local join later
120
+ /// Sample new neighbors for each node to perform local join later
121
121
  void update();
122
122
 
123
123
  /// Sample a small number of points to evaluate the quality of KNNG built
@@ -621,7 +621,7 @@ int NSG::attach_unlinked(
621
621
  }
622
622
  }
623
623
 
624
- // randomly choice annother node
624
+ // randomly choice another node
625
625
  if (!found) {
626
626
  do {
627
627
  node = rng.rand_int(ntotal);
@@ -0,0 +1,33 @@
1
+ /*
2
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ // -*- c++ -*-
9
+
10
+ #include <faiss/impl/PanoramaStats.h>
11
+
12
+ namespace faiss {
13
+
14
+ void PanoramaStats::reset() {
15
+ total_dims_scanned = 0;
16
+ total_dims = 0;
17
+ ratio_dims_scanned = 1.0f;
18
+ }
19
+
20
+ void PanoramaStats::add(const PanoramaStats& other) {
21
+ total_dims_scanned += other.total_dims_scanned;
22
+ total_dims += other.total_dims;
23
+ if (total_dims > 0) {
24
+ ratio_dims_scanned =
25
+ static_cast<float>(total_dims_scanned) / total_dims;
26
+ } else {
27
+ ratio_dims_scanned = 1.0f;
28
+ }
29
+ }
30
+
31
+ PanoramaStats indexPanorama_stats;
32
+
33
+ } // namespace faiss
@@ -0,0 +1,38 @@
1
+ /*
2
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ // -*- c++ -*-
9
+
10
+ #ifndef FAISS_PANORAMA_STATS_H
11
+ #define FAISS_PANORAMA_STATS_H
12
+
13
+ #include <faiss/impl/platform_macros.h>
14
+
15
+ namespace faiss {
16
+
17
+ /// Statistics are not robust to internal threading nor to
18
+ /// concurrent Panorama searches. Use these values in a
19
+ /// single-threaded context to accurately gauge Panorama's
20
+ /// pruning effectiveness.
21
+ struct PanoramaStats {
22
+ uint64_t total_dims_scanned = 0; // total dimensions scanned
23
+ uint64_t total_dims = 0; // total dimensions
24
+ float ratio_dims_scanned = 1.0f; // fraction of dimensions actually scanned
25
+
26
+ PanoramaStats() {
27
+ reset();
28
+ }
29
+ void reset();
30
+ void add(const PanoramaStats& other);
31
+ };
32
+
33
+ // Single global var for all Panorama indexes
34
+ FAISS_API extern PanoramaStats indexPanorama_stats;
35
+
36
+ } // namespace faiss
37
+
38
+ #endif
@@ -178,7 +178,7 @@ struct ReproduceWithHammingObjective : PermutationObjective {
178
178
  return x * x;
179
179
  }
180
180
 
181
- // weihgting of distances: it is more important to reproduce small
181
+ // weighting of distances: it is more important to reproduce small
182
182
  // distances well
183
183
  double dis_weight(double x) const {
184
184
  return exp(-dis_weight_factor * x);
@@ -295,7 +295,7 @@ struct ReproduceWithHammingObjective : PermutationObjective {
295
295
 
296
296
  } // anonymous namespace
297
297
 
298
- // weihgting of distances: it is more important to reproduce small
298
+ // weighting of distances: it is more important to reproduce small
299
299
  // distances well
300
300
  double ReproduceDistancesObjective::dis_weight(double x) const {
301
301
  return exp(-dis_weight_factor * x);
@@ -636,7 +636,7 @@ struct Score3Computer : PermutationObjective {
636
636
  return accu;
637
637
  }
638
638
 
639
- /// PermutationObjective implementeation (just negates the scores
639
+ /// PermutationObjective implementation (just negates the scores
640
640
  /// for minimization)
641
641
 
642
642
  double compute_cost(const int* perm) const override {
@@ -689,7 +689,7 @@ struct RankingScore2 : Score3Computer<float, double> {
689
689
  /// count nb of i, j in a x b st. i < j
690
690
  /// a and b should be sorted on input
691
691
  /// they are the ranks of j and k respectively.
692
- /// specific version for diff-of-rank weighting, cannot optimized
692
+ /// specific version for diff-of-rank weighting, cannot optimize
693
693
  /// with a cumulative table
694
694
  double accum_gt_weight_diff(
695
695
  const std::vector<int>& a,
@@ -985,7 +985,7 @@ size_t PolysemousTraining::memory_usage_per_thread(
985
985
  return n * n * n * sizeof(float);
986
986
  }
987
987
 
988
- FAISS_THROW_MSG("Invalid optmization type");
988
+ FAISS_THROW_MSG("Invalid optimization type");
989
989
  return 0;
990
990
  }
991
991
 
@@ -154,7 +154,7 @@ void ProductAdditiveQuantizer::compute_unpacked_codes(
154
154
  int32_t* unpacked_codes,
155
155
  size_t n,
156
156
  const float* centroids) const {
157
- /// TODO: actuallly we do not need to unpack and pack
157
+ /// TODO: actually we do not need to unpack and pack
158
158
  size_t offset_d = 0, offset_m = 0;
159
159
  std::vector<float> xsub;
160
160
  std::vector<uint8_t> codes;
@@ -46,7 +46,7 @@ struct ProductAdditiveQuantizer : AdditiveQuantizer {
46
46
 
47
47
  ProductAdditiveQuantizer();
48
48
 
49
- virtual ~ProductAdditiveQuantizer();
49
+ virtual ~ProductAdditiveQuantizer() override;
50
50
 
51
51
  void init(
52
52
  size_t d,
@@ -5,6 +5,8 @@
5
5
  * LICENSE file in the root directory of this source tree.
6
6
  */
7
7
 
8
+ #pragma once
9
+
8
10
  namespace faiss {
9
11
 
10
12
  inline PQEncoderGeneric::PQEncoderGeneric(
@@ -166,7 +166,7 @@ struct ProductQuantizer : Quantizer {
166
166
  /// Symmetric Distance Table
167
167
  std::vector<float> sdc_table;
168
168
 
169
- // intitialize the SDC table from the centroids
169
+ // initialize the SDC table from the centroids
170
170
  void compute_sdc_table();
171
171
 
172
172
  void search_sdc(
@@ -0,0 +1,246 @@
1
+ /*
2
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ #include <faiss/impl/RaBitQUtils.h>
9
+
10
+ #include <faiss/impl/FaissAssert.h>
11
+ #include <faiss/utils/distances.h>
12
+ #include <algorithm>
13
+ #include <cmath>
14
+ #include <limits>
15
+
16
+ namespace faiss {
17
+ namespace rabitq_utils {
18
+
19
+ // Ideal quantizer radii for quantizers of 1..8 bits, optimized to minimize
20
+ // L2 reconstruction error.
21
+ const float Z_MAX_BY_QB[8] = {
22
+ 0.79688, // qb = 1.
23
+ 1.49375,
24
+ 2.05078,
25
+ 2.50938,
26
+ 2.91250,
27
+ 3.26406,
28
+ 3.59844,
29
+ 3.91016, // qb = 8.
30
+ };
31
+
32
+ void compute_vector_intermediate_values(
33
+ const float* x,
34
+ size_t d,
35
+ const float* centroid,
36
+ float& norm_L2sqr,
37
+ float& or_L2sqr,
38
+ float& dp_oO) {
39
+ norm_L2sqr = 0.0f;
40
+ or_L2sqr = 0.0f;
41
+ dp_oO = 0.0f;
42
+
43
+ for (size_t j = 0; j < d; j++) {
44
+ const float x_val = x[j];
45
+ const float centroid_val = (centroid != nullptr) ? centroid[j] : 0.0f;
46
+ const float or_minus_c = x_val - centroid_val;
47
+
48
+ const float or_minus_c_sq = or_minus_c * or_minus_c;
49
+ norm_L2sqr += or_minus_c_sq;
50
+ or_L2sqr += x_val * x_val;
51
+
52
+ const bool xb = (or_minus_c > 0.0f);
53
+ dp_oO += xb ? or_minus_c : -or_minus_c;
54
+ }
55
+ }
56
+
57
+ FactorsData compute_factors_from_intermediates(
58
+ float norm_L2sqr,
59
+ float or_L2sqr,
60
+ float dp_oO,
61
+ size_t d,
62
+ MetricType metric_type) {
63
+ constexpr float epsilon = std::numeric_limits<float>::epsilon();
64
+ const float inv_d_sqrt =
65
+ (d == 0) ? 1.0f : (1.0f / std::sqrt(static_cast<float>(d)));
66
+
67
+ const float sqrt_norm_L2 = std::sqrt(norm_L2sqr);
68
+ const float inv_norm_L2 =
69
+ (norm_L2sqr < epsilon) ? 1.0f : (1.0f / sqrt_norm_L2);
70
+
71
+ const float normalized_dp = dp_oO * inv_norm_L2 * inv_d_sqrt;
72
+ const float inv_dp_oO =
73
+ (std::abs(normalized_dp) < epsilon) ? 1.0f : (1.0f / normalized_dp);
74
+
75
+ FactorsData factors;
76
+ factors.or_minus_c_l2sqr = (metric_type == MetricType::METRIC_INNER_PRODUCT)
77
+ ? (norm_L2sqr - or_L2sqr)
78
+ : norm_L2sqr;
79
+ factors.dp_multiplier = inv_dp_oO * sqrt_norm_L2;
80
+
81
+ return factors;
82
+ }
83
+
84
+ FactorsData compute_vector_factors(
85
+ const float* x,
86
+ size_t d,
87
+ const float* centroid,
88
+ MetricType metric_type) {
89
+ float norm_L2sqr, or_L2sqr, dp_oO;
90
+ compute_vector_intermediate_values(
91
+ x, d, centroid, norm_L2sqr, or_L2sqr, dp_oO);
92
+ return compute_factors_from_intermediates(
93
+ norm_L2sqr, or_L2sqr, dp_oO, d, metric_type);
94
+ }
95
+
96
+ QueryFactorsData compute_query_factors(
97
+ const float* query,
98
+ size_t d,
99
+ const float* centroid,
100
+ uint8_t qb,
101
+ bool centered,
102
+ MetricType metric_type,
103
+ std::vector<float>& rotated_q,
104
+ std::vector<uint8_t>& rotated_qq) {
105
+ FAISS_THROW_IF_NOT(qb <= 8);
106
+ FAISS_THROW_IF_NOT(qb > 0);
107
+
108
+ QueryFactorsData query_factors;
109
+
110
+ // Compute distance from query to centroid
111
+ if (centroid != nullptr) {
112
+ query_factors.qr_to_c_L2sqr = fvec_L2sqr(query, centroid, d);
113
+ } else {
114
+ query_factors.qr_to_c_L2sqr = fvec_norm_L2sqr(query, d);
115
+ }
116
+
117
+ // Rotate the query (subtract centroid)
118
+ rotated_q.resize(d);
119
+ for (size_t i = 0; i < d; i++) {
120
+ if (i < rotated_q.size()) {
121
+ rotated_q[i] =
122
+ query[i] - ((centroid == nullptr) ? 0.0f : centroid[i]);
123
+ }
124
+ }
125
+
126
+ const float inv_d_sqrt =
127
+ (d == 0) ? 1.0f : (1.0f / std::sqrt(static_cast<float>(d)));
128
+
129
+ // Compute quantization range
130
+ float v_min = std::numeric_limits<float>::max();
131
+ float v_max = std::numeric_limits<float>::lowest();
132
+
133
+ if (centered) {
134
+ float z_max = Z_MAX_BY_QB[qb - 1];
135
+ float v_radius = z_max * std::sqrt(query_factors.qr_to_c_L2sqr / d);
136
+ v_min = -v_radius;
137
+ v_max = v_radius;
138
+ } else {
139
+ // Only compute min/max if we have dimensions to process
140
+ if (d > 0 && !rotated_q.empty()) {
141
+ for (size_t i = 0; i < d; i++) {
142
+ const float v_q = rotated_q[i];
143
+ v_min = std::min(v_min, v_q);
144
+ v_max = std::max(v_max, v_q);
145
+ }
146
+ } else {
147
+ // For empty dimensions, use default range
148
+ v_min = 0.0f;
149
+ v_max = 1.0f;
150
+ }
151
+ }
152
+
153
+ // Quantize the query
154
+ const uint8_t max_code = (1 << qb) - 1;
155
+ const float delta = (v_max - v_min) / max_code;
156
+ const float inv_delta = 1.0f / delta;
157
+
158
+ rotated_qq.resize(d);
159
+ size_t sum_qq = 0;
160
+ int64_t sum2_signed_odd_int = 0;
161
+
162
+ // Process arrays - throw error if they are unexpectedly empty
163
+ if (d > 0 && !rotated_q.empty() && !rotated_qq.empty()) {
164
+ for (size_t i = 0; i < d; i++) {
165
+ const float v_q = rotated_q[i];
166
+ // Non-randomized scalar quantization
167
+ const uint8_t v_qq = std::clamp<float>(
168
+ std::round((v_q - v_min) * inv_delta), 0, max_code);
169
+ rotated_qq[i] = v_qq;
170
+ sum_qq += v_qq;
171
+
172
+ if (centered) {
173
+ int64_t signed_odd_int = int64_t(v_qq) * 2 - max_code;
174
+ sum2_signed_odd_int += signed_odd_int * signed_odd_int;
175
+ }
176
+ }
177
+ } else {
178
+ FAISS_THROW_MSG(
179
+ "Arrays unexpectedly empty when d=" + std::to_string(d) +
180
+ "or d is incorrectly set");
181
+ }
182
+
183
+ // Compute query factors
184
+ query_factors.c1 = 2 * delta * inv_d_sqrt;
185
+ query_factors.c2 = 2 * v_min * inv_d_sqrt;
186
+ query_factors.c34 = inv_d_sqrt * (delta * sum_qq + d * v_min);
187
+
188
+ if (centered) {
189
+ query_factors.int_dot_scale = std::sqrt(
190
+ query_factors.qr_to_c_L2sqr / (sum2_signed_odd_int * d));
191
+ } else {
192
+ query_factors.int_dot_scale = 1.0f;
193
+ }
194
+
195
+ // Compute query norm for inner product metric
196
+ query_factors.qr_norm_L2sqr = 0.0f;
197
+ if (metric_type == MetricType::METRIC_INNER_PRODUCT) {
198
+ query_factors.qr_norm_L2sqr = fvec_norm_L2sqr(query, d);
199
+ }
200
+
201
+ return query_factors;
202
+ }
203
+
204
+ bool extract_bit_standard(const uint8_t* code, size_t bit_index) {
205
+ const size_t byte_idx = bit_index / 8;
206
+ const size_t bit_offset = bit_index % 8;
207
+ return (code[byte_idx] >> bit_offset) & 1;
208
+ }
209
+
210
+ bool extract_bit_fastscan(const uint8_t* code, size_t bit_index) {
211
+ const size_t m = bit_index / 4; // Sub-quantizer index
212
+ const size_t dim_offset =
213
+ bit_index % 4; // Bit position within sub-quantizer
214
+ const size_t byte_idx = m / 2; // Byte index (2 sub-quantizers per byte)
215
+ const uint8_t bit_mask = static_cast<uint8_t>(1 << dim_offset);
216
+
217
+ if (m % 2 == 0) {
218
+ // Lower 4 bits of byte
219
+ return (code[byte_idx] & bit_mask) != 0;
220
+ } else {
221
+ // Upper 4 bits of byte (shifted)
222
+ return (code[byte_idx] & (bit_mask << 4)) != 0;
223
+ }
224
+ }
225
+
226
+ void set_bit_standard(uint8_t* code, size_t bit_index) {
227
+ const size_t byte_idx = bit_index / 8;
228
+ const size_t bit_offset = bit_index % 8;
229
+ code[byte_idx] |= (1 << bit_offset);
230
+ }
231
+
232
+ void set_bit_fastscan(uint8_t* code, size_t bit_index) {
233
+ const size_t m = bit_index / 4;
234
+ const size_t dim_offset = bit_index % 4;
235
+ const uint8_t bit_mask = static_cast<uint8_t>(1 << dim_offset);
236
+ const size_t byte_idx = m / 2;
237
+
238
+ if (m % 2 == 0) {
239
+ code[byte_idx] |= bit_mask;
240
+ } else {
241
+ code[byte_idx] |= (bit_mask << 4);
242
+ }
243
+ }
244
+
245
+ } // namespace rabitq_utils
246
+ } // namespace faiss
@@ -0,0 +1,153 @@
1
+ /*
2
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ #pragma once
9
+
10
+ #include <faiss/MetricType.h>
11
+ #include <faiss/impl/platform_macros.h>
12
+ #include <cstddef>
13
+ #include <cstdint>
14
+ #include <vector>
15
+
16
+ namespace faiss {
17
+ namespace rabitq_utils {
18
+
19
+ /** Factors computed per database vector for RaBitQ distance computation.
20
+ * These can be stored either embedded in codes (IndexRaBitQ) or separately
21
+ * (IndexRaBitQFastScan).
22
+ */
23
+ struct FactorsData {
24
+ // ||or - c||^2 - ((metric==IP) ? ||or||^2 : 0)
25
+ float or_minus_c_l2sqr = 0;
26
+ float dp_multiplier = 0;
27
+ };
28
+
29
+ /** Query-specific factors computed during search for RaBitQ distance
30
+ * computation. Used by both IndexRaBitQ and IndexRaBitQFastScan
31
+ * implementations.
32
+ */
33
+ struct QueryFactorsData {
34
+ float c1 = 0;
35
+ float c2 = 0;
36
+ float c34 = 0;
37
+
38
+ float qr_to_c_L2sqr = 0;
39
+ float qr_norm_L2sqr = 0;
40
+
41
+ float int_dot_scale = 1;
42
+ };
43
+
44
+ /** Ideal quantizer radii for quantizers of 1..8 bits, optimized to minimize
45
+ * L2 reconstruction error. Shared between all RaBitQ implementations.
46
+ */
47
+ FAISS_API extern const float Z_MAX_BY_QB[8];
48
+
49
+ /** Compute factors for a single database vector using RaBitQ algorithm.
50
+ * This function consolidates the mathematical logic that was duplicated
51
+ * between IndexRaBitQ and IndexRaBitQFastScan.
52
+ *
53
+ * @param x input vector (d dimensions)
54
+ * @param d dimensionality
55
+ * @param centroid database centroid (nullptr if not used)
56
+ * @param metric_type distance metric (L2 or Inner Product)
57
+ * @return computed factors for distance computation
58
+ */
59
+ FactorsData compute_vector_factors(
60
+ const float* x,
61
+ size_t d,
62
+ const float* centroid,
63
+ MetricType metric_type);
64
+
65
+ /** Compute intermediate values needed for vector factor computation.
66
+ * Separated out to allow different bit packing strategies while sharing
67
+ * the core mathematical computation.
68
+ *
69
+ * @param x input vector (d dimensions)
70
+ * @param d dimensionality
71
+ * @param centroid database centroid (nullptr if not used)
72
+ * @param norm_L2sqr output: ||or - c||^2
73
+ * @param or_L2sqr output: ||or||^2
74
+ * @param dp_oO output: sum of |or_i - c_i| (absolute deviations)
75
+ */
76
+ void compute_vector_intermediate_values(
77
+ const float* x,
78
+ size_t d,
79
+ const float* centroid,
80
+ float& norm_L2sqr,
81
+ float& or_L2sqr,
82
+ float& dp_oO);
83
+
84
+ /** Compute final factors from intermediate values.
85
+ * @param norm_L2sqr ||or - c||^2
86
+ * @param or_L2sqr ||or||^2
87
+ * @param dp_oO sum of |or_i - c_i|
88
+ * @param d dimensionality
89
+ * @param metric_type distance metric
90
+ * @return computed factors
91
+ */
92
+ FactorsData compute_factors_from_intermediates(
93
+ float norm_L2sqr,
94
+ float or_L2sqr,
95
+ float dp_oO,
96
+ size_t d,
97
+ MetricType metric_type);
98
+
99
+ /** Compute query factors for RaBitQ distance computation.
100
+ * This consolidates the query processing logic shared between implementations.
101
+ *
102
+ * @param query query vector (d dimensions)
103
+ * @param d dimensionality
104
+ * @param centroid database centroid (nullptr if not used)
105
+ * @param qb number of quantization bits (1-8)
106
+ * @param centered whether to use centered quantization
107
+ * @param metric_type distance metric
108
+ * @param rotated_q output: query - centroid
109
+ * @param rotated_qq output: quantized query values
110
+ * @return computed query factors
111
+ */
112
+ QueryFactorsData compute_query_factors(
113
+ const float* query,
114
+ size_t d,
115
+ const float* centroid,
116
+ uint8_t qb,
117
+ bool centered,
118
+ MetricType metric_type,
119
+ std::vector<float>& rotated_q,
120
+ std::vector<uint8_t>& rotated_qq);
121
+
122
+ /** Extract bit value from RaBitQ code in standard format.
123
+ * Used by IndexRaBitQ which stores bits sequentially.
124
+ *
125
+ * @param code RaBitQ code data
126
+ * @param bit_index which bit to extract (0 to d-1)
127
+ * @return bit value (true/false)
128
+ */
129
+ bool extract_bit_standard(const uint8_t* code, size_t bit_index);
130
+
131
+ /** Extract bit value from FastScan code format.
132
+ * Used by IndexRaBitQFastScan which packs bits into 4-bit sub-quantizers.
133
+ *
134
+ * @param code FastScan code data
135
+ * @param bit_index which bit to extract (0 to d-1)
136
+ * @return bit value (true/false)
137
+ */
138
+ bool extract_bit_fastscan(const uint8_t* code, size_t bit_index);
139
+
140
+ /** Set bit value in standard RaBitQ code format.
141
+ * @param code RaBitQ code data to modify
142
+ * @param bit_index which bit to set (0 to d-1)
143
+ */
144
+ void set_bit_standard(uint8_t* code, size_t bit_index);
145
+
146
+ /** Set bit value in FastScan code format.
147
+ * @param code FastScan code data to modify
148
+ * @param bit_index which bit to set (0 to d-1)
149
+ */
150
+ void set_bit_fastscan(uint8_t* code, size_t bit_index);
151
+
152
+ } // namespace rabitq_utils
153
+ } // namespace faiss