faiss 0.5.2 → 0.6.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (169)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +16 -0
  3. data/LICENSE.txt +1 -1
  4. data/ext/faiss/ext.cpp +1 -1
  5. data/ext/faiss/extconf.rb +5 -6
  6. data/ext/faiss/index_binary.cpp +76 -17
  7. data/ext/faiss/{index.cpp → index_rb.cpp} +108 -35
  8. data/ext/faiss/kmeans.cpp +12 -9
  9. data/ext/faiss/numo.hpp +11 -9
  10. data/ext/faiss/pca_matrix.cpp +10 -8
  11. data/ext/faiss/product_quantizer.cpp +14 -12
  12. data/ext/faiss/{utils.cpp → utils_rb.cpp} +10 -3
  13. data/ext/faiss/{utils.h → utils_rb.h} +6 -0
  14. data/lib/faiss/version.rb +1 -1
  15. data/lib/faiss.rb +1 -1
  16. data/vendor/faiss/faiss/AutoTune.cpp +130 -11
  17. data/vendor/faiss/faiss/AutoTune.h +14 -1
  18. data/vendor/faiss/faiss/Clustering.cpp +59 -10
  19. data/vendor/faiss/faiss/Clustering.h +12 -0
  20. data/vendor/faiss/faiss/IVFlib.cpp +31 -28
  21. data/vendor/faiss/faiss/Index.cpp +20 -8
  22. data/vendor/faiss/faiss/Index.h +25 -3
  23. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +19 -24
  24. data/vendor/faiss/faiss/IndexBinary.cpp +1 -0
  25. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +9 -4
  26. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +45 -11
  27. data/vendor/faiss/faiss/IndexFastScan.cpp +35 -22
  28. data/vendor/faiss/faiss/IndexFastScan.h +10 -1
  29. data/vendor/faiss/faiss/IndexFlat.cpp +193 -136
  30. data/vendor/faiss/faiss/IndexFlat.h +16 -1
  31. data/vendor/faiss/faiss/IndexFlatCodes.cpp +46 -22
  32. data/vendor/faiss/faiss/IndexFlatCodes.h +7 -1
  33. data/vendor/faiss/faiss/IndexHNSW.cpp +24 -50
  34. data/vendor/faiss/faiss/IndexHNSW.h +14 -12
  35. data/vendor/faiss/faiss/IndexIDMap.cpp +1 -1
  36. data/vendor/faiss/faiss/IndexIVF.cpp +76 -49
  37. data/vendor/faiss/faiss/IndexIVF.h +14 -4
  38. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +11 -8
  39. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +2 -2
  40. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +25 -14
  41. data/vendor/faiss/faiss/IndexIVFFastScan.h +26 -22
  42. data/vendor/faiss/faiss/IndexIVFFlat.cpp +10 -61
  43. data/vendor/faiss/faiss/IndexIVFFlatPanorama.cpp +39 -111
  44. data/vendor/faiss/faiss/IndexIVFPQ.cpp +89 -147
  45. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +37 -5
  46. data/vendor/faiss/faiss/IndexIVFPQR.cpp +2 -1
  47. data/vendor/faiss/faiss/IndexIVFRaBitQ.cpp +42 -30
  48. data/vendor/faiss/faiss/IndexIVFRaBitQ.h +2 -2
  49. data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.cpp +246 -97
  50. data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.h +32 -29
  51. data/vendor/faiss/faiss/IndexLSH.cpp +8 -6
  52. data/vendor/faiss/faiss/IndexLattice.cpp +29 -24
  53. data/vendor/faiss/faiss/IndexNNDescent.cpp +1 -0
  54. data/vendor/faiss/faiss/IndexNSG.cpp +2 -1
  55. data/vendor/faiss/faiss/IndexNSG.h +0 -2
  56. data/vendor/faiss/faiss/IndexNeuralNetCodec.cpp +1 -1
  57. data/vendor/faiss/faiss/IndexPQ.cpp +19 -10
  58. data/vendor/faiss/faiss/IndexRaBitQ.cpp +26 -13
  59. data/vendor/faiss/faiss/IndexRaBitQ.h +2 -2
  60. data/vendor/faiss/faiss/IndexRaBitQFastScan.cpp +132 -78
  61. data/vendor/faiss/faiss/IndexRaBitQFastScan.h +14 -12
  62. data/vendor/faiss/faiss/IndexRefine.cpp +0 -30
  63. data/vendor/faiss/faiss/IndexShards.cpp +3 -4
  64. data/vendor/faiss/faiss/MetricType.h +16 -0
  65. data/vendor/faiss/faiss/VectorTransform.cpp +120 -0
  66. data/vendor/faiss/faiss/VectorTransform.h +23 -0
  67. data/vendor/faiss/faiss/clone_index.cpp +7 -4
  68. data/vendor/faiss/faiss/{cppcontrib/factory_tools.cpp → factory_tools.cpp} +1 -1
  69. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +1 -1
  70. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +37 -11
  71. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +0 -28
  72. data/vendor/faiss/faiss/impl/ClusteringInitialization.cpp +367 -0
  73. data/vendor/faiss/faiss/impl/ClusteringInitialization.h +107 -0
  74. data/vendor/faiss/faiss/impl/CodePacker.cpp +4 -0
  75. data/vendor/faiss/faiss/impl/CodePacker.h +11 -3
  76. data/vendor/faiss/faiss/impl/CodePackerRaBitQ.cpp +83 -0
  77. data/vendor/faiss/faiss/impl/CodePackerRaBitQ.h +47 -0
  78. data/vendor/faiss/faiss/impl/FaissAssert.h +60 -2
  79. data/vendor/faiss/faiss/impl/HNSW.cpp +25 -34
  80. data/vendor/faiss/faiss/impl/HNSW.h +8 -6
  81. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +34 -27
  82. data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -1
  83. data/vendor/faiss/faiss/impl/NSG.cpp +6 -5
  84. data/vendor/faiss/faiss/impl/NSG.h +17 -7
  85. data/vendor/faiss/faiss/impl/Panorama.cpp +53 -46
  86. data/vendor/faiss/faiss/impl/Panorama.h +22 -6
  87. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +16 -5
  88. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +70 -58
  89. data/vendor/faiss/faiss/impl/RaBitQUtils.cpp +92 -0
  90. data/vendor/faiss/faiss/impl/RaBitQUtils.h +93 -31
  91. data/vendor/faiss/faiss/impl/RaBitQuantizer.cpp +12 -28
  92. data/vendor/faiss/faiss/impl/RaBitQuantizer.h +3 -10
  93. data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.cpp +15 -41
  94. data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.h +0 -4
  95. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +14 -9
  96. data/vendor/faiss/faiss/impl/ResultHandler.h +131 -50
  97. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +67 -2358
  98. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +0 -2
  99. data/vendor/faiss/faiss/impl/VisitedTable.cpp +42 -0
  100. data/vendor/faiss/faiss/impl/VisitedTable.h +69 -0
  101. data/vendor/faiss/faiss/impl/expanded_scanners.h +158 -0
  102. data/vendor/faiss/faiss/impl/index_read.cpp +829 -471
  103. data/vendor/faiss/faiss/impl/index_read_utils.h +0 -1
  104. data/vendor/faiss/faiss/impl/index_write.cpp +17 -8
  105. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +47 -20
  106. data/vendor/faiss/faiss/impl/mapped_io.cpp +9 -2
  107. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +7 -2
  108. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +11 -3
  109. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +19 -13
  110. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +29 -21
  111. data/vendor/faiss/faiss/impl/{code_distance/code_distance-avx2.h → pq_code_distance/pq_code_distance-avx2.cpp} +42 -215
  112. data/vendor/faiss/faiss/impl/{code_distance/code_distance-avx512.h → pq_code_distance/pq_code_distance-avx512.cpp} +68 -107
  113. data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-generic.cpp +141 -0
  114. data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-inl.h +23 -0
  115. data/vendor/faiss/faiss/impl/{code_distance/code_distance-sve.h → pq_code_distance/pq_code_distance-sve.cpp} +57 -144
  116. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +9 -6
  117. data/vendor/faiss/faiss/impl/scalar_quantizer/codecs.h +121 -0
  118. data/vendor/faiss/faiss/impl/scalar_quantizer/distance_computers.h +136 -0
  119. data/vendor/faiss/faiss/impl/scalar_quantizer/quantizers.h +280 -0
  120. data/vendor/faiss/faiss/impl/scalar_quantizer/scanners.h +164 -0
  121. data/vendor/faiss/faiss/impl/scalar_quantizer/similarities.h +94 -0
  122. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx2.cpp +455 -0
  123. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx512.cpp +430 -0
  124. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-dispatch.h +329 -0
  125. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-neon.cpp +467 -0
  126. data/vendor/faiss/faiss/impl/scalar_quantizer/training.cpp +203 -0
  127. data/vendor/faiss/faiss/impl/scalar_quantizer/training.h +42 -0
  128. data/vendor/faiss/faiss/impl/simd_dispatch.h +139 -0
  129. data/vendor/faiss/faiss/impl/simd_result_handlers.h +18 -18
  130. data/vendor/faiss/faiss/index_factory.cpp +35 -16
  131. data/vendor/faiss/faiss/index_io.h +29 -3
  132. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +7 -4
  133. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +1 -1
  134. data/vendor/faiss/faiss/svs/IndexSVSFaissUtils.h +9 -19
  135. data/vendor/faiss/faiss/svs/IndexSVSFlat.h +2 -0
  136. data/vendor/faiss/faiss/svs/IndexSVSVamana.h +2 -1
  137. data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.cpp +9 -1
  138. data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.h +9 -0
  139. data/vendor/faiss/faiss/utils/Heap.cpp +46 -0
  140. data/vendor/faiss/faiss/utils/Heap.h +21 -0
  141. data/vendor/faiss/faiss/utils/NeuralNet.cpp +10 -7
  142. data/vendor/faiss/faiss/utils/distances.cpp +141 -23
  143. data/vendor/faiss/faiss/utils/distances.h +98 -0
  144. data/vendor/faiss/faiss/utils/distances_dispatch.h +170 -0
  145. data/vendor/faiss/faiss/utils/distances_simd.cpp +74 -3511
  146. data/vendor/faiss/faiss/utils/extra_distances-inl.h +164 -157
  147. data/vendor/faiss/faiss/utils/extra_distances.cpp +52 -95
  148. data/vendor/faiss/faiss/utils/extra_distances.h +47 -1
  149. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +0 -1
  150. data/vendor/faiss/faiss/utils/partitioning.cpp +1 -1
  151. data/vendor/faiss/faiss/utils/pq_code_distance.h +251 -0
  152. data/vendor/faiss/faiss/utils/rabitq_simd.h +260 -0
  153. data/vendor/faiss/faiss/utils/simd_impl/distances_aarch64.cpp +150 -0
  154. data/vendor/faiss/faiss/utils/simd_impl/distances_arm_sve.cpp +568 -0
  155. data/vendor/faiss/faiss/utils/simd_impl/distances_autovec-inl.h +153 -0
  156. data/vendor/faiss/faiss/utils/simd_impl/distances_avx2.cpp +1185 -0
  157. data/vendor/faiss/faiss/utils/simd_impl/distances_avx512.cpp +1092 -0
  158. data/vendor/faiss/faiss/utils/simd_impl/distances_sse-inl.h +391 -0
  159. data/vendor/faiss/faiss/utils/simd_levels.cpp +322 -0
  160. data/vendor/faiss/faiss/utils/simd_levels.h +91 -0
  161. data/vendor/faiss/faiss/utils/simdlib_avx2.h +12 -1
  162. data/vendor/faiss/faiss/utils/simdlib_avx512.h +69 -0
  163. data/vendor/faiss/faiss/utils/simdlib_neon.h +6 -0
  164. data/vendor/faiss/faiss/utils/sorting.cpp +4 -4
  165. data/vendor/faiss/faiss/utils/utils.cpp +16 -9
  166. metadata +47 -18
  167. data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +0 -81
  168. data/vendor/faiss/faiss/impl/code_distance/code_distance.h +0 -186
  169. data/vendor/faiss/faiss/{cppcontrib/factory_tools.h → factory_tools.h} +0 -0
@@ -10,10 +10,46 @@
10
10
  #include <algorithm>
11
11
  #include <cmath>
12
12
  #include <cstring>
13
- #include <vector>
13
+
14
+ #include <faiss/impl/FaissAssert.h>
14
15
 
15
16
  namespace faiss {
16
17
 
18
+ namespace {
19
+
20
+ /// Helper function to compute cumulative sums by iterating backwards through
21
+ /// levels. This is the core logic shared by compute_cumulative_sums and
22
+ /// compute_query_cum_sums.
23
+ template <typename OffsetFunc>
24
+ inline void compute_cum_sums_impl(
25
+ const float* vector,
26
+ float* output,
27
+ size_t d,
28
+ size_t n_levels,
29
+ size_t level_width_floats,
30
+ OffsetFunc&& get_offset) {
31
+ // Iterate backwards through levels, accumulating sum as we go.
32
+ // This avoids computing the suffix sum for each vector, which takes
33
+ // extra memory.
34
+ float sum = 0.0f;
35
+
36
+ for (int level = n_levels - 1; level >= 0; level--) {
37
+ size_t start_idx = level * level_width_floats;
38
+ size_t end_idx = std::min(
39
+ (level + 1) * level_width_floats, static_cast<size_t>(d));
40
+
41
+ for (size_t j = start_idx; j < end_idx; j++) {
42
+ sum += vector[j] * vector[j];
43
+ }
44
+
45
+ output[get_offset(level)] = std::sqrt(sum);
46
+ }
47
+
48
+ output[get_offset(n_levels)] = 0.0f;
49
+ }
50
+
51
+ } // namespace
52
+
17
53
  /**************************************************************
18
54
  * Panorama structure implementation
19
55
  **************************************************************/
@@ -24,6 +60,7 @@ Panorama::Panorama(size_t code_size, size_t n_levels, size_t batch_size)
24
60
  }
25
61
 
26
62
  void Panorama::set_derived_values() {
63
+ FAISS_THROW_IF_NOT_MSG(n_levels > 0, "Panorama: n_levels must be > 0");
27
64
  this->d = code_size / sizeof(float);
28
65
  this->level_width_floats = ((d + n_levels - 1) / n_levels);
29
66
  this->level_width = this->level_width_floats * sizeof(float);
@@ -69,64 +106,34 @@ void Panorama::compute_cumulative_sums(
69
106
  float* cumsum_base,
70
107
  size_t offset,
71
108
  size_t n_entry,
72
- const float* vectors) {
73
- std::vector<float> suffix_sums(d + 1);
74
-
109
+ const float* vectors) const {
75
110
  for (size_t entry_idx = 0; entry_idx < n_entry; entry_idx++) {
76
111
  size_t current_pos = offset + entry_idx;
77
112
  size_t batch_no = current_pos / batch_size;
78
113
  size_t pos_in_batch = current_pos % batch_size;
79
114
 
80
115
  const float* vector = vectors + entry_idx * d;
81
-
82
- // Compute suffix sums of squared values.
83
- suffix_sums[d] = 0.0f;
84
- for (int j = d - 1; j >= 0; j--) {
85
- float squared_val = vector[j] * vector[j];
86
- suffix_sums[j] = suffix_sums[j + 1] + squared_val;
87
- }
88
-
89
- // Store cumulative sums in batch-oriented layout.
90
116
  size_t cumsum_batch_offset = batch_no * batch_size * (n_levels + 1);
91
117
 
92
- for (size_t level = 0; level < n_levels; level++) {
93
- size_t start_idx = level * level_width_floats;
94
- size_t cumsum_offset =
95
- cumsum_batch_offset + level * batch_size + pos_in_batch;
96
- if (start_idx < d) {
97
- cumsum_base[cumsum_offset] = std::sqrt(suffix_sums[start_idx]);
98
- } else {
99
- cumsum_base[cumsum_offset] = 0.0f;
100
- }
101
- }
102
-
103
- // Last level sum is always 0.
104
- size_t cumsum_offset =
105
- cumsum_batch_offset + n_levels * batch_size + pos_in_batch;
106
- cumsum_base[cumsum_offset] = 0.0f;
118
+ auto get_offset = [&](size_t level) {
119
+ return cumsum_batch_offset + level * batch_size + pos_in_batch;
120
+ };
121
+
122
+ compute_cum_sums_impl(
123
+ vector,
124
+ cumsum_base,
125
+ d,
126
+ n_levels,
127
+ level_width_floats,
128
+ get_offset);
107
129
  }
108
130
  }
109
131
 
110
132
  void Panorama::compute_query_cum_sums(const float* query, float* query_cum_sums)
111
133
  const {
112
- std::vector<float> suffix_sums(d + 1);
113
- suffix_sums[d] = 0.0f;
114
-
115
- for (int j = d - 1; j >= 0; j--) {
116
- float squared_val = query[j] * query[j];
117
- suffix_sums[j] = suffix_sums[j + 1] + squared_val;
118
- }
119
-
120
- for (size_t level = 0; level < n_levels; level++) {
121
- size_t start_idx = level * level_width_floats;
122
- if (start_idx < d) {
123
- query_cum_sums[level] = std::sqrt(suffix_sums[start_idx]);
124
- } else {
125
- query_cum_sums[level] = 0.0f;
126
- }
127
- }
128
-
129
- query_cum_sums[n_levels] = 0.0f;
134
+ auto get_offset = [](size_t level) { return level; };
135
+ compute_cum_sums_impl(
136
+ query, query_cum_sums, d, n_levels, level_width_floats, get_offset);
130
137
  }
131
138
 
132
139
  void Panorama::reconstruct(idx_t key, float* recons, const uint8_t* codes_base)
@@ -10,6 +10,7 @@
10
10
  #ifndef FAISS_PANORAMA_H
11
11
  #define FAISS_PANORAMA_H
12
12
 
13
+ #include <faiss/MetricType.h>
13
14
  #include <faiss/impl/IDSelector.h>
14
15
  #include <faiss/impl/PanoramaStats.h>
15
16
  #include <faiss/utils/distances.h>
@@ -67,7 +68,7 @@ struct Panorama {
67
68
  float* cumsum_base,
68
69
  size_t offset,
69
70
  size_t n_entry,
70
- const float* vectors);
71
+ const float* vectors) const;
71
72
 
72
73
  /// Compute the cumulative sums of the query vector.
73
74
  void compute_query_cum_sums(const float* query, float* query_cum_sums)
@@ -97,7 +98,7 @@ struct Panorama {
97
98
  /// 4. After all levels, survivors are exact distances; update heap.
98
99
  /// This achieves early termination while maintaining SIMD-friendly
99
100
  /// sequential access patterns in the level-oriented storage layout.
100
- template <typename C>
101
+ template <typename C, MetricType M>
101
102
  size_t progressive_filter_batch(
102
103
  const uint8_t* codes_base,
103
104
  const float* cum_sums,
@@ -116,7 +117,7 @@ struct Panorama {
116
117
  void reconstruct(idx_t key, float* recons, const uint8_t* codes_base) const;
117
118
  };
118
119
 
119
- template <typename C>
120
+ template <typename C, MetricType M>
120
121
  size_t Panorama::progressive_filter_batch(
121
122
  const uint8_t* codes_base,
122
123
  const float* cum_sums,
@@ -151,7 +152,12 @@ size_t Panorama::progressive_filter_batch(
151
152
 
152
153
  active_indices[num_active] = i;
153
154
  float cum_sum = batch_cum_sums[i];
154
- exact_distances[i] = cum_sum * cum_sum + q_norm;
155
+
156
+ if constexpr (M == METRIC_INNER_PRODUCT) {
157
+ exact_distances[i] = 0.0f;
158
+ } else {
159
+ exact_distances[i] = cum_sum * cum_sum + q_norm;
160
+ }
155
161
 
156
162
  num_active += include;
157
163
  }
@@ -183,10 +189,20 @@ size_t Panorama::progressive_filter_batch(
183
189
  float dot_product =
184
190
  fvec_inner_product(query_level, yj, actual_level_width);
185
191
 
186
- exact_distances[idx] -= 2.0f * dot_product;
192
+ if constexpr (M == METRIC_INNER_PRODUCT) {
193
+ exact_distances[idx] += dot_product;
194
+ } else {
195
+ exact_distances[idx] -= 2.0f * dot_product;
196
+ }
187
197
 
188
198
  float cum_sum = level_cum_sums[idx];
189
- float cauchy_schwarz_bound = 2.0f * cum_sum * query_cum_norm;
199
+ float cauchy_schwarz_bound;
200
+ if constexpr (M == METRIC_INNER_PRODUCT) {
201
+ cauchy_schwarz_bound = -cum_sum * query_cum_norm;
202
+ } else {
203
+ cauchy_schwarz_bound = 2.0f * cum_sum * query_cum_norm;
204
+ }
205
+
190
206
  float lower_bound = exact_distances[idx] - cauchy_schwarz_bound;
191
207
 
192
208
  active_indices[next_active] = idx;
@@ -18,6 +18,7 @@
18
18
  #include <cstring>
19
19
  #include <memory>
20
20
 
21
+ #include <faiss/impl/simd_dispatch.h>
21
22
  #include <faiss/utils/distances.h>
22
23
  #include <faiss/utils/hamming.h>
23
24
  #include <faiss/utils/random.h>
@@ -431,6 +432,8 @@ void ReproduceDistancesObjective::set_affine_target_dis(
431
432
  * Cost functions: RankingScore
432
433
  ****************************************************/
433
434
 
435
+ namespace {
436
+
434
437
  /// Maintains a 3D table of elementary costs.
435
438
  /// Accumulates elements based on Hamming distance comparisons
436
439
  template <typename Ttab, typename Taccu>
@@ -756,6 +759,8 @@ struct RankingScore2 : Score3Computer<float, double> {
756
759
  }
757
760
  };
758
761
 
762
+ } // namespace
763
+
759
764
  /*****************************************
760
765
  * PolysemousTraining
761
766
  ******************************************/
@@ -798,12 +803,18 @@ void PolysemousTraining::optimize_reproduce_distances(
798
803
 
799
804
  float* centroids = pq.get_centroids(m, 0);
800
805
 
801
- for (int i = 0; i < n; i++) {
802
- for (int j = 0; j < n; j++) {
803
- dis_table.push_back(fvec_L2sqr(
804
- centroids + i * dsub, centroids + j * dsub, dsub));
806
+ auto compute_dis_table = [&]<SIMDLevel SL>() {
807
+ for (int i = 0; i < n; i++) {
808
+ for (int j = 0; j < n; j++) {
809
+ dis_table.push_back(
810
+ fvec_L2sqr<SL>(
811
+ centroids + i * dsub,
812
+ centroids + j * dsub,
813
+ dsub));
814
+ }
805
815
  }
806
- }
816
+ };
817
+ with_simd_level(compute_dis_table);
807
818
 
808
819
  std::vector<int> perm(n);
809
820
  ReproduceWithHammingObjective obj(nbits, dis_table, dis_weight_factor);
@@ -19,6 +19,7 @@
19
19
  #include <faiss/IndexFlat.h>
20
20
  #include <faiss/VectorTransform.h>
21
21
  #include <faiss/impl/FaissAssert.h>
22
+ #include <faiss/impl/simd_dispatch.h>
22
23
  #include <faiss/utils/distances.h>
23
24
 
24
25
  extern "C" {
@@ -56,14 +57,15 @@ ProductQuantizer::ProductQuantizer() : ProductQuantizer(0, 1, 0) {}
56
57
 
57
58
  void ProductQuantizer::set_derived_values() {
58
59
  // quite a few derived values
60
+ FAISS_THROW_IF_NOT_MSG(M > 0, "M must be > 0");
59
61
  FAISS_THROW_IF_NOT_MSG(
60
62
  d % M == 0,
61
63
  "The dimension of the vector (d) should be a multiple of the number of subquantizers (M)");
62
64
  dsub = d / M;
63
- code_size = (nbits * M + 7) / 8;
64
65
  FAISS_THROW_IF_MSG(nbits > 24, "nbits larger than 24 is not practical.");
66
+ code_size = (nbits * M + 7) / 8;
65
67
  ksub = 1 << nbits;
66
- centroids.resize(d * ksub);
68
+ centroids.resize(mul_no_overflow(d, (size_t)ksub, "PQ centroids"));
67
69
  verbose = false;
68
70
  train_type = Train_default;
69
71
  }
@@ -201,8 +203,10 @@ void ProductQuantizer::train(size_t n, const float* x) {
201
203
  }
202
204
  }
203
205
 
204
- template <class PQEncoder>
205
- void compute_code(const ProductQuantizer& pq, const float* x, uint8_t* code) {
206
+ namespace {
207
+
208
+ template <class PQEncoder, SIMDLevel SL>
209
+ void compute_1_code(const ProductQuantizer& pq, const float* x, uint8_t* code) {
206
210
  std::vector<float> distances(pq.ksub);
207
211
 
208
212
  // It seems to be meaningless to allocate std::vector<float> distances.
@@ -248,7 +252,7 @@ void compute_code(const ProductQuantizer& pq, const float* x, uint8_t* code) {
248
252
  uint64_t idxm = 0;
249
253
  if (pq.transposed_centroids.empty()) {
250
254
  // the regular version
251
- idxm = fvec_L2sqr_ny_nearest(
255
+ idxm = fvec_L2sqr_ny_nearest<SL>(
252
256
  distances.data(),
253
257
  xsub,
254
258
  pq.get_centroids(m, 0),
@@ -256,7 +260,7 @@ void compute_code(const ProductQuantizer& pq, const float* x, uint8_t* code) {
256
260
  pq.ksub);
257
261
  } else {
258
262
  // transposed centroids are available, use'em
259
- idxm = fvec_L2sqr_ny_nearest_y_transposed(
263
+ idxm = fvec_L2sqr_ny_nearest_y_transposed<SL>(
260
264
  distances.data(),
261
265
  xsub,
262
266
  pq.transposed_centroids.data() + m * pq.ksub,
@@ -270,20 +274,24 @@ void compute_code(const ProductQuantizer& pq, const float* x, uint8_t* code) {
270
274
  }
271
275
  }
272
276
 
277
+ } // namespace
278
+
273
279
  void ProductQuantizer::compute_code(const float* x, uint8_t* code) const {
274
- switch (nbits) {
275
- case 8:
276
- faiss::compute_code<PQEncoder8>(*this, x, code);
277
- break;
280
+ with_simd_level([&]<SIMDLevel SL>() {
281
+ switch (nbits) {
282
+ case 8:
283
+ compute_1_code<PQEncoder8, SL>(*this, x, code);
284
+ break;
278
285
 
279
- case 16:
280
- faiss::compute_code<PQEncoder16>(*this, x, code);
281
- break;
286
+ case 16:
287
+ compute_1_code<PQEncoder16, SL>(*this, x, code);
288
+ break;
282
289
 
283
- default:
284
- faiss::compute_code<PQEncoderGeneric>(*this, x, code);
285
- break;
286
- }
290
+ default:
291
+ compute_1_code<PQEncoderGeneric, SL>(*this, x, code);
292
+ break;
293
+ }
294
+ }); // with_simd_level
287
295
  }
288
296
 
289
297
  template <class PQDecoder>
@@ -428,44 +436,46 @@ void ProductQuantizer::compute_codes(const float* x, uint8_t* codes, size_t n)
428
436
 
429
437
  void ProductQuantizer::compute_distance_table(const float* x, float* dis_table)
430
438
  const {
431
- if (transposed_centroids.empty()) {
432
- // use regular version
433
- for (size_t m = 0; m < M; m++) {
434
- fvec_L2sqr_ny(
435
- dis_table + m * ksub,
436
- x + m * dsub,
437
- get_centroids(m, 0),
438
- dsub,
439
- ksub);
439
+ with_simd_level([&]<SIMDLevel SL>() {
440
+ if (transposed_centroids.empty()) {
441
+ // use regular version
442
+ for (size_t m = 0; m < M; m++) {
443
+ fvec_L2sqr_ny<SL>(
444
+ dis_table + m * ksub,
445
+ x + m * dsub,
446
+ get_centroids(m, 0),
447
+ dsub,
448
+ ksub);
449
+ }
450
+ } else {
451
+ // transposed centroids are available, use'em
452
+ for (size_t m = 0; m < M; m++) {
453
+ fvec_L2sqr_ny_transposed<SL>(
454
+ dis_table + m * ksub,
455
+ x + m * dsub,
456
+ transposed_centroids.data() + m * ksub,
457
+ centroids_sq_lengths.data() + m * ksub,
458
+ dsub,
459
+ M * ksub,
460
+ ksub);
461
+ }
440
462
  }
441
- } else {
442
- // transposed centroids are available, use'em
463
+ });
464
+ }
465
+
466
+ void ProductQuantizer::compute_inner_prod_table(
467
+ const float* x,
468
+ float* dis_table) const {
469
+ with_simd_level([&]<SIMDLevel SL>() {
443
470
  for (size_t m = 0; m < M; m++) {
444
- fvec_L2sqr_ny_transposed(
471
+ fvec_inner_products_ny<SL>(
445
472
  dis_table + m * ksub,
446
473
  x + m * dsub,
447
- transposed_centroids.data() + m * ksub,
448
- centroids_sq_lengths.data() + m * ksub,
474
+ get_centroids(m, 0),
449
475
  dsub,
450
- M * ksub,
451
476
  ksub);
452
477
  }
453
- }
454
- }
455
-
456
- void ProductQuantizer::compute_inner_prod_table(
457
- const float* x,
458
- float* dis_table) const {
459
- size_t m;
460
-
461
- for (m = 0; m < M; m++) {
462
- fvec_inner_products_ny(
463
- dis_table + m * ksub,
464
- x + m * dsub,
465
- get_centroids(m, 0),
466
- dsub,
467
- ksub);
468
- }
478
+ });
469
479
  }
470
480
 
471
481
  void ProductQuantizer::compute_distance_tables(
@@ -785,17 +795,19 @@ void ProductQuantizer::compute_sdc_table() {
785
795
  sdc_table.resize(M * ksub * ksub);
786
796
 
787
797
  if (dsub < 4) {
798
+ with_simd_level([&]<SIMDLevel SL>() {
788
799
  #pragma omp parallel for
789
- for (int mk = 0; mk < M * ksub; mk++) {
790
- // allow omp to schedule in a more fine-grained way
791
- // `collapse` is not supported in OpenMP 2.x
792
- int m = mk / ksub;
793
- int k = mk % ksub;
794
- const float* cents = centroids.data() + m * ksub * dsub;
795
- const float* centi = cents + k * dsub;
796
- float* dis_tab = sdc_table.data() + m * ksub * ksub;
797
- fvec_L2sqr_ny(dis_tab + k * ksub, centi, cents, dsub, ksub);
798
- }
800
+ for (int mk = 0; mk < M * ksub; mk++) {
801
+ // allow omp to schedule in a more fine-grained way
802
+ // `collapse` is not supported in OpenMP 2.x
803
+ int m = mk / ksub;
804
+ int k = mk % ksub;
805
+ const float* cents = centroids.data() + m * ksub * dsub;
806
+ const float* centi = cents + k * dsub;
807
+ float* dis_tab = sdc_table.data() + m * ksub * ksub;
808
+ fvec_L2sqr_ny<SL>(dis_tab + k * ksub, centi, cents, dsub, ksub);
809
+ }
810
+ });
799
811
  } else {
800
812
  // NOTE: it would disable the omp loop in pairwise_L2sqr
801
813
  // but still accelerate especially when M >= 4
@@ -9,8 +9,10 @@
9
9
 
10
10
  #include <faiss/impl/FaissAssert.h>
11
11
  #include <faiss/utils/distances.h>
12
+ #include <faiss/utils/rabitq_simd.h>
12
13
  #include <algorithm>
13
14
  #include <cmath>
15
+ #include <cstring>
14
16
  #include <limits>
15
17
 
16
18
  namespace faiss {
@@ -242,8 +244,12 @@ QueryFactorsData compute_query_factors(
242
244
 
243
245
  // Compute query norm for inner product metric
244
246
  query_factors.qr_norm_L2sqr = 0.0f;
247
+ query_factors.q_dot_c = 0.0f;
245
248
  if (metric_type == MetricType::METRIC_INNER_PRODUCT) {
246
249
  query_factors.qr_norm_L2sqr = fvec_norm_L2sqr(query, d);
250
+ if (centroid != nullptr) {
251
+ query_factors.q_dot_c = fvec_inner_product(query, centroid, d);
252
+ }
247
253
  }
248
254
 
249
255
  return query_factors;
@@ -290,5 +296,91 @@ void set_bit_fastscan(uint8_t* code, size_t bit_index) {
290
296
  }
291
297
  }
292
298
 
299
+ size_t compute_per_vector_storage_size(size_t nb_bits, size_t d) {
300
+ const size_t ex_bits = nb_bits - 1;
301
+ if (ex_bits == 0) {
302
+ return sizeof(SignBitFactors);
303
+ } else {
304
+ return sizeof(SignBitFactorsWithError) + sizeof(ExtraBitsFactors) +
305
+ (d * ex_bits + 7) / 8;
306
+ }
307
+ }
308
+
309
+ float compute_full_multibit_distance(
310
+ const uint8_t* sign_bits,
311
+ const uint8_t* ex_code,
312
+ const ExtraBitsFactors& ex_fac,
313
+ const float* rotated_q,
314
+ float qr_base,
315
+ size_t d,
316
+ size_t ex_bits,
317
+ MetricType metric_type) {
318
+ const float cb = -(static_cast<float>(1 << ex_bits) - 0.5f);
319
+
320
+ float ex_ip = rabitq::multibit::compute_inner_product(
321
+ sign_bits, ex_code, rotated_q, d, ex_bits, cb);
322
+
323
+ float dist = qr_base + ex_fac.f_add_ex + ex_fac.f_rescale_ex * ex_ip;
324
+
325
+ if (metric_type == MetricType::METRIC_L2) {
326
+ dist = std::max(0.0f, dist);
327
+ }
328
+
329
+ return dist;
330
+ }
331
+
332
+ void populate_block_aux_from_flat_storage(
333
+ const std::vector<uint8_t>& flat_storage,
334
+ AlignedTable<uint8_t>& codes,
335
+ size_t num_vectors,
336
+ size_t bbs,
337
+ size_t M2,
338
+ size_t old_block_stride,
339
+ size_t new_block_stride,
340
+ size_t storage_size,
341
+ const int64_t* id_map) {
342
+ if (flat_storage.empty() || num_vectors == 0) {
343
+ return;
344
+ }
345
+
346
+ const size_t packed_block_size = ((M2 + 1) / 2) * bbs;
347
+ const size_t n_blocks = (num_vectors + bbs - 1) / bbs;
348
+
349
+ if (old_block_stride < new_block_stride) {
350
+ AlignedTable<uint8_t> old_data;
351
+ old_data.resize(codes.size());
352
+ memcpy(old_data.data(), codes.data(), codes.size());
353
+
354
+ codes.resize(n_blocks * new_block_stride);
355
+ memset(codes.data(), 0, n_blocks * new_block_stride);
356
+ for (size_t b = 0; b < n_blocks; b++) {
357
+ memcpy(codes.data() + b * new_block_stride,
358
+ old_data.data() + b * old_block_stride,
359
+ packed_block_size);
360
+ }
361
+ }
362
+
363
+ for (size_t offset = 0; offset < num_vectors; offset++) {
364
+ const int64_t global_id =
365
+ id_map ? id_map[offset] : static_cast<int64_t>(offset);
366
+ FAISS_THROW_IF_NOT_MSG(
367
+ global_id >= 0 &&
368
+ static_cast<size_t>(global_id) * storage_size +
369
+ storage_size <=
370
+ flat_storage.size(),
371
+ "global_id out of bounds for flat_storage during migration");
372
+
373
+ const uint8_t* src = flat_storage.data() + global_id * storage_size;
374
+ uint8_t* dst = get_block_aux_ptr(
375
+ codes.data(),
376
+ offset,
377
+ bbs,
378
+ packed_block_size,
379
+ new_block_stride,
380
+ storage_size);
381
+ memcpy(dst, src, storage_size);
382
+ }
383
+ }
384
+
293
385
  } // namespace rabitq_utils
294
386
  } // namespace faiss