faiss 0.6.1 → 0.6.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (93)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +4 -0
  3. data/lib/faiss/version.rb +1 -1
  4. data/vendor/faiss/faiss/Index.h +1 -1
  5. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +6 -7
  6. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +3 -3
  7. data/vendor/faiss/faiss/IndexHNSW.cpp +173 -143
  8. data/vendor/faiss/faiss/IndexIVF.cpp +2 -2
  9. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +2 -2
  10. data/vendor/faiss/faiss/IndexIVFFlat.cpp +3 -1
  11. data/vendor/faiss/faiss/IndexIVFFlatPanorama.cpp +3 -3
  12. data/vendor/faiss/faiss/IndexIVFPQ.cpp +2 -3
  13. data/vendor/faiss/faiss/IndexIVFPQR.cpp +2 -3
  14. data/vendor/faiss/faiss/IndexIVFRaBitQ.cpp +4 -13
  15. data/vendor/faiss/faiss/IndexNNDescent.cpp +1 -1
  16. data/vendor/faiss/faiss/IndexNSG.cpp +1 -2
  17. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +68 -6
  18. data/vendor/faiss/faiss/IndexScalarQuantizer.h +10 -0
  19. data/vendor/faiss/faiss/cppcontrib/SaDecodeKernels.h +1 -1
  20. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-neon-inl.h +902 -12
  21. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-neon-inl.h +702 -10
  22. data/vendor/faiss/faiss/factory_tools.cpp +4 -0
  23. data/vendor/faiss/faiss/gpu/GpuResources.h +3 -2
  24. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +11 -12
  25. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +3 -3
  26. data/vendor/faiss/faiss/gpu_metal/MetalDistance.h +87 -0
  27. data/vendor/faiss/faiss/gpu_metal/MetalIndex.h +7 -0
  28. data/vendor/faiss/faiss/gpu_metal/MetalIndexIVFFlat.h +181 -0
  29. data/vendor/faiss/faiss/gpu_metal/MetalKernels.h +48 -3
  30. data/vendor/faiss/faiss/gpu_metal/MetalPythonBridge.h +45 -0
  31. data/vendor/faiss/faiss/gpu_metal/impl/MetalIVFFlat.h +193 -0
  32. data/vendor/faiss/faiss/impl/HNSW.cpp +556 -199
  33. data/vendor/faiss/faiss/impl/HNSW.h +51 -13
  34. data/vendor/faiss/faiss/impl/NSG.cpp +15 -11
  35. data/vendor/faiss/faiss/impl/Panorama.h +11 -0
  36. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +25 -2
  37. data/vendor/faiss/faiss/impl/RaBitQUtils.cpp +1 -1
  38. data/vendor/faiss/faiss/impl/RaBitQuantizer.cpp +7 -1
  39. data/vendor/faiss/faiss/impl/ResultHandler.h +1 -0
  40. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +271 -8
  41. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +50 -0
  42. data/vendor/faiss/faiss/impl/VisitedTable.cpp +10 -10
  43. data/vendor/faiss/faiss/impl/VisitedTable.h +69 -34
  44. data/vendor/faiss/faiss/impl/fast_scan/dispatching.h +3 -1
  45. data/vendor/faiss/faiss/impl/hnsw/MinimaxHeap.cpp +35 -43
  46. data/vendor/faiss/faiss/impl/hnsw/MinimaxHeap.h +64 -15
  47. data/vendor/faiss/faiss/impl/hnsw/avx2.cpp +86 -40
  48. data/vendor/faiss/faiss/impl/hnsw/avx512.cpp +81 -50
  49. data/vendor/faiss/faiss/impl/index_read.cpp +100 -39
  50. data/vendor/faiss/faiss/impl/index_write.cpp +1 -0
  51. data/vendor/faiss/faiss/impl/io_macros.h +25 -0
  52. data/vendor/faiss/faiss/impl/platform_macros.h +12 -8
  53. data/vendor/faiss/faiss/impl/pq_code_distance/avx2.cpp +2 -0
  54. data/vendor/faiss/faiss/impl/pq_code_distance/avx512.cpp +2 -0
  55. data/vendor/faiss/faiss/impl/pq_code_distance/neon.cpp +2 -0
  56. data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-generic.cpp +20 -0
  57. data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-inl.h +36 -0
  58. data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-sve.cpp +5 -0
  59. data/vendor/faiss/faiss/impl/pq_code_distance/pq_scan_impl.h +105 -0
  60. data/vendor/faiss/faiss/impl/pq_code_distance/rvv.cpp +2 -0
  61. data/vendor/faiss/faiss/impl/scalar_quantizer/distance_computers.h +6 -0
  62. data/vendor/faiss/faiss/impl/scalar_quantizer/quantizers.h +327 -18
  63. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx2.cpp +264 -27
  64. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx512-impl.h +553 -0
  65. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx512-spr.cpp +559 -0
  66. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx512.cpp +199 -27
  67. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-dispatch.h +366 -3
  68. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-neon.cpp +144 -19
  69. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-rvv.cpp +26 -0
  70. data/vendor/faiss/faiss/impl/simd_dispatch.h +65 -8
  71. data/vendor/faiss/faiss/index_factory.cpp +5 -1
  72. data/vendor/faiss/faiss/index_io.h +16 -0
  73. data/vendor/faiss/faiss/invlists/DirectMap.cpp +4 -1
  74. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +13 -13
  75. data/vendor/faiss/faiss/invlists/InvertedLists.h +2 -2
  76. data/vendor/faiss/faiss/svs/IndexSVSVamana.cpp +119 -22
  77. data/vendor/faiss/faiss/svs/IndexSVSVamana.h +15 -5
  78. data/vendor/faiss/faiss/svs/IndexSVSVamanaLVQ.cpp +3 -2
  79. data/vendor/faiss/faiss/svs/IndexSVSVamanaLVQ.h +2 -1
  80. data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.cpp +65 -24
  81. data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.h +3 -2
  82. data/vendor/faiss/faiss/utils/bf16.h +34 -0
  83. data/vendor/faiss/faiss/utils/distances_simd.cpp +0 -1
  84. data/vendor/faiss/faiss/utils/hamming.cpp +8 -8
  85. data/vendor/faiss/faiss/utils/hamming_distance/hamming_avx2.cpp +2 -1
  86. data/vendor/faiss/faiss/utils/hamming_distance/hamming_avx512_spr.cpp +15 -0
  87. data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-avx512.h +6 -30
  88. data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-avx512_spr.h +171 -0
  89. data/vendor/faiss/faiss/utils/partitioning.cpp +0 -2
  90. data/vendor/faiss/faiss/utils/simd_impl/partitioning_simdlib256.h +14 -68
  91. data/vendor/faiss/faiss/utils/simd_impl/rabitq_avx512_spr.cpp +343 -0
  92. data/vendor/faiss/faiss/utils/simd_levels.cpp +12 -2
  93. metadata +12 -2
@@ -7,11 +7,15 @@
7
7
 
8
8
  #pragma once
9
9
 
10
+ #include <faiss/impl/RaBitQUtils.h>
10
11
  #include <faiss/impl/scalar_quantizer/codecs.h>
11
12
  #include <faiss/impl/scalar_quantizer/distance_computers.h>
12
13
  #include <faiss/impl/scalar_quantizer/quantizers.h>
13
14
  #include <faiss/impl/scalar_quantizer/scanners.h>
14
15
  #include <faiss/impl/scalar_quantizer/similarities.h>
16
+ #include <faiss/utils/distances.h>
17
+ #include <faiss/utils/rabitq_simd.h>
18
+ #include <limits>
15
19
 
16
20
  #ifndef THE_LEVEL_TO_DISPATCH
17
21
  #error "THE_LEVEL_TO_DISPATCH should be set on input to this header"
@@ -24,10 +28,324 @@ namespace scalar_quantizer {
24
28
  // Define SL as alias for THE_LEVEL_TO_DISPATCH for use in this file
25
29
  constexpr SIMDLevel SL = THE_LEVEL_TO_DISPATCH;
26
30
 
31
+ /*******************************************************************
32
+ * TurboQuant SIMD kernel: masked_sum
33
+ * Compute sum of arr[j] where bit j of the bitmask is set.
34
+ * NONE specialization is inline; AVX2/AVX512/NEON specializations
35
+ * live in sq-avx2.cpp / sq-avx512.cpp / sq-neon.cpp.
36
+ *******************************************************************/
37
+
38
+ template <SIMDLevel SL0>
39
+ float turboq_masked_sum(const float* arr, const uint8_t* bits, size_t d);
40
+
41
+ template <>
42
+ inline float turboq_masked_sum<SIMDLevel::NONE>(
43
+ const float* arr,
44
+ const uint8_t* bits,
45
+ size_t d) {
46
+ float result = 0;
47
+ for (size_t byte_idx = 0; byte_idx < (d + 7) / 8; byte_idx++) {
48
+ uint8_t b = bits[byte_idx];
49
+ size_t base = byte_idx * 8;
50
+ size_t end = std::min(base + 8, d);
51
+ for (size_t j = base; j < end; j++) {
52
+ if (b & (1 << (j - base))) {
53
+ result += arr[j];
54
+ }
55
+ }
56
+ }
57
+ return result;
58
+ }
59
+
60
+ /*******************************************************************
61
+ * Full TurboQuant DC — lives here because it needs both
62
+ * quantizers.h (QuantizerTurboQuantFull, SQTurboQFactors) and
63
+ * similarities.h (Similarity::metric_type). distance_computers.h
64
+ * can't include quantizers.h due to header ordering.
65
+ *******************************************************************/
66
+ template <int NBits, class Similarity, SIMDLevel SL2>
67
+ struct DCTurboQuantFull : ScalarQuantizer::TurboQuantRefine::DistanceComputer {
68
+ using Sim = Similarity;
69
+ QuantizerTurboQuantFull<NBits, SIMDLevel::NONE> quant;
70
+ std::vector<float> query;
71
+ std::vector<float> query_proj;
72
+ float q_norm_sq = 0;
73
+ float qjl_coeff = 0;
74
+ float total_qproj_sum = 0;
75
+
76
+ // Pre-screening state
77
+ const float* threshold_ptr = nullptr;
78
+ bool prescreen_l2 = false;
79
+ float qjl_error_coeff = 0;
80
+ mutable size_t n_total = 0;
81
+ mutable size_t n_skipped = 0;
82
+
83
+ // Integer popcount state
84
+ uint8_t qb = 0;
85
+ bool int_qjl = false;
86
+ std::vector<uint8_t> rearranged_q;
87
+ float mse_base = 0;
88
+ float mse_int_scale = 0;
89
+ float mse_popcnt_scale = 0;
90
+
91
+ // Integer QJL popcount state
92
+ std::vector<uint8_t> rearranged_qproj;
93
+ float qjl_int_scale = 0;
94
+ float qjl_popcnt_scale = 0;
95
+
96
+ // Scaled centroids for 1-bit MSE fast path (NBits==2)
97
+ float scaled_c0 = 0;
98
+ float scaled_c1 = 0;
99
+ float delta_centroid = 0;
100
+ float total_q_sum = 0;
101
+
102
+ // Multi-bit MSE decomposed coefficients (NBits==3, kMSEBits==2)
103
+ float mse_multi_base = 0;
104
+ float mse_coeff_s0 = 0;
105
+ float mse_coeff_s1 = 0;
106
+ float mse_coeff_s01 = 0;
107
+ mutable std::vector<uint8_t> scratch_and;
108
+
109
+ DCTurboQuantFull(size_t d, const std::vector<float>& trained)
110
+ : quant(d, trained) {
111
+ qjl_coeff = std::sqrt(M_PI / 2.0f) / static_cast<float>(d);
112
+ }
113
+
114
+ void configure(uint8_t qb_in, bool int_qjl_in) override {
115
+ qb = qb_in;
116
+ int_qjl = int_qjl_in;
117
+ }
118
+
119
+ void set_prescreen_threshold(const float* ptr, bool l2) override {
120
+ threshold_ptr = ptr;
121
+ prescreen_l2 = l2;
122
+ }
123
+
124
+ void clear_prescreen_threshold() override {
125
+ threshold_ptr = nullptr;
126
+ }
127
+
128
+ void set_query(const float* x) final {
129
+ q = x;
130
+ size_t d = quant.d;
131
+ query.assign(x, x + d);
132
+ q_norm_sq = fvec_norm_L2sqr(x, d);
133
+
134
+ // Project query
135
+ query_proj.resize(d);
136
+ quant.project_forward(x, query_proj.data());
137
+ float inv_sqrt_pd =
138
+ 1.0f / std::sqrt(static_cast<float>(quant.padded_d));
139
+ for (size_t j = 0; j < d; j++) {
140
+ query_proj[j] *= inv_sqrt_pd;
141
+ }
142
+
143
+ total_qproj_sum = 0;
144
+ for (size_t j = 0; j < d; j++) {
145
+ total_qproj_sum += query_proj[j];
146
+ }
147
+
148
+ // Pre-screening: worst-case L1 bound on QJL error
149
+ float qproj_l1 = 0;
150
+ for (size_t j = 0; j < d; j++) {
151
+ qproj_l1 += std::abs(query_proj[j]);
152
+ }
153
+ qjl_error_coeff = qjl_coeff * qproj_l1;
154
+
155
+ // Pre-compute for 1-bit MSE fast path
156
+ if constexpr (NBits == 2) {
157
+ float inv_sqrt_d = 1.0f / std::sqrt(static_cast<float>(d));
158
+ scaled_c0 = quant.centroids[0] * inv_sqrt_d;
159
+ scaled_c1 = quant.centroids[1] * inv_sqrt_d;
160
+ delta_centroid = scaled_c1 - scaled_c0;
161
+ total_q_sum = 0;
162
+ for (size_t j = 0; j < d; j++) {
163
+ total_q_sum += query[j];
164
+ }
165
+
166
+ // Integer popcount setup
167
+ if (qb > 0) {
168
+ size_t byte_size = (d + 7) / 8;
169
+ float q_min = *std::min_element(query.begin(), query.end());
170
+ float q_max = *std::max_element(query.begin(), query.end());
171
+ float q_range = q_max - q_min;
172
+ if (q_range < 1e-30f) {
173
+ q_range = 1e-30f;
174
+ }
175
+ float max_val = static_cast<float>((1 << qb) - 1);
176
+ float scale = max_val / q_range;
177
+ float delta_q = q_range / max_val;
178
+
179
+ rearranged_q.assign(byte_size * qb, 0);
180
+ for (size_t j = 0; j < d; j++) {
181
+ int qval = static_cast<int>(
182
+ std::round((query[j] - q_min) * scale));
183
+ qval = std::max(
184
+ 0, std::min(static_cast<int>(max_val), qval));
185
+ for (int b = 0; b < qb; b++) {
186
+ if (qval & (1 << b)) {
187
+ rearranged_q[b * byte_size + j / 8] |=
188
+ (1 << (j % 8));
189
+ }
190
+ }
191
+ }
192
+ mse_base = scaled_c0 * total_q_sum;
193
+ mse_int_scale = delta_centroid * delta_q;
194
+ mse_popcnt_scale = delta_centroid * q_min;
195
+ }
196
+ }
197
+
198
+ // Pre-compute for 2-bit MSE decomposed path (NBits==3)
199
+ if constexpr (NBits == 3) {
200
+ float inv_sqrt_d = 1.0f / std::sqrt(static_cast<float>(d));
201
+ const float* c = quant.centroids;
202
+ total_q_sum = 0;
203
+ for (size_t j = 0; j < d; j++) {
204
+ total_q_sum += query[j];
205
+ }
206
+ mse_multi_base = c[0] * inv_sqrt_d * total_q_sum;
207
+ mse_coeff_s0 = (c[1] - c[0]) * inv_sqrt_d;
208
+ mse_coeff_s1 = (c[2] - c[0]) * inv_sqrt_d;
209
+ mse_coeff_s01 = (c[3] - c[2] - c[1] + c[0]) * inv_sqrt_d;
210
+ scratch_and.resize((d + 7) / 8);
211
+ }
212
+
213
+ // Integer QJL: quantize projected query into bit-planes
214
+ if (qb > 0 && int_qjl) {
215
+ size_t byte_size = (d + 7) / 8;
216
+ float qp_min =
217
+ *std::min_element(query_proj.begin(), query_proj.end());
218
+ float qp_max =
219
+ *std::max_element(query_proj.begin(), query_proj.end());
220
+ float qp_range = qp_max - qp_min;
221
+ if (qp_range < 1e-30f) {
222
+ qp_range = 1e-30f;
223
+ }
224
+ float max_val = static_cast<float>((1 << qb) - 1);
225
+ float qp_scale = max_val / qp_range;
226
+ float delta_qp = qp_range / max_val;
227
+
228
+ rearranged_qproj.assign(byte_size * qb, 0);
229
+ for (size_t j = 0; j < d; j++) {
230
+ int qval = static_cast<int>(
231
+ std::round((query_proj[j] - qp_min) * qp_scale));
232
+ qval = std::max(0, std::min(static_cast<int>(max_val), qval));
233
+ for (int b = 0; b < qb; b++) {
234
+ if (qval & (1 << b)) {
235
+ rearranged_qproj[b * byte_size + j / 8] |=
236
+ (1 << (j % 8));
237
+ }
238
+ }
239
+ }
240
+ qjl_popcnt_scale = qp_min;
241
+ qjl_int_scale = delta_qp;
242
+ }
243
+
244
+ n_total = 0;
245
+ n_skipped = 0;
246
+ }
247
+
248
+ float query_to_code(const uint8_t* code) const final {
249
+ size_t d = quant.d;
250
+ float inv_sqrt_d = 1.0f / std::sqrt(static_cast<float>(d));
251
+ const auto* factors = reinterpret_cast<const SQTurboQFactors*>(
252
+ code + quant.mse_total_bytes + quant.qjl_plane_bytes);
253
+ float norm = factors->norm;
254
+ float gamma = factors->gamma;
255
+
256
+ // Stage 1: MSE dot product
257
+ float mse_dot = 0;
258
+ if constexpr (NBits == 2) {
259
+ if (qb > 0) {
260
+ // Integer popcount path for 1-bit MSE
261
+ size_t byte_size = (d + 7) / 8;
262
+ uint64_t and_result = rabitq::bitwise_and_dot_product<SL2>(
263
+ rearranged_q.data(), code, byte_size, qb);
264
+ uint64_t pop = rabitq::popcount<SL2>(code, byte_size);
265
+ mse_dot = mse_base +
266
+ mse_int_scale * static_cast<float>(and_result) +
267
+ mse_popcnt_scale * static_cast<float>(pop);
268
+ } else {
269
+ // Float path: masked accumulation
270
+ float pos_sum = turboq_masked_sum<SL2>(query.data(), code, d);
271
+ mse_dot = scaled_c0 * total_q_sum + delta_centroid * pos_sum;
272
+ }
273
+ } else if constexpr (NBits == 3) {
274
+ // 2-bit MSE: decompose into 3 masked sums over bit-planes.
275
+ size_t pb = quant.mse_plane_bytes;
276
+ float s0 = turboq_masked_sum<SL2>(query.data(), code, d);
277
+ float s1 = turboq_masked_sum<SL2>(query.data(), code + pb, d);
278
+ for (size_t i = 0; i < pb; i++) {
279
+ scratch_and[i] = code[i] & code[pb + i];
280
+ }
281
+ float s01 =
282
+ turboq_masked_sum<SL2>(query.data(), scratch_and.data(), d);
283
+ mse_dot = mse_multi_base + mse_coeff_s0 * s0 + mse_coeff_s1 * s1 +
284
+ mse_coeff_s01 * s01;
285
+ } else {
286
+ // kMSEBits > 2: per-dimension fallback
287
+ for (size_t j = 0; j < d; j++) {
288
+ uint8_t idx = quant.load_mse_index(code, j);
289
+ mse_dot += query[j] * quant.centroids[idx] * inv_sqrt_d;
290
+ }
291
+ }
292
+
293
+ // Pre-screening
294
+ if (threshold_ptr != nullptr) {
295
+ n_total++;
296
+ float bound = qjl_error_coeff * gamma * norm;
297
+ float mse_ip = norm * mse_dot;
298
+
299
+ if constexpr (Similarity::metric_type == METRIC_INNER_PRODUCT) {
300
+ if (mse_ip + bound <= *threshold_ptr) {
301
+ n_skipped++;
302
+ return -std::numeric_limits<float>::infinity();
303
+ }
304
+ } else {
305
+ float best_possible =
306
+ q_norm_sq + norm * norm - 2.0f * (mse_ip + bound);
307
+ if (best_possible >= *threshold_ptr) {
308
+ n_skipped++;
309
+ return std::numeric_limits<float>::infinity();
310
+ }
311
+ }
312
+ }
313
+
314
+ // Stage 2: QJL dot product
315
+ const uint8_t* qjl_code = code + quant.mse_total_bytes;
316
+ float qjl_dot;
317
+ if (qb > 0 && int_qjl) {
318
+ size_t byte_size = (d + 7) / 8;
319
+ uint64_t and_result = rabitq::bitwise_and_dot_product<SL2>(
320
+ rearranged_qproj.data(), qjl_code, byte_size, qb);
321
+ uint64_t pop = rabitq::popcount<SL2>(qjl_code, byte_size);
322
+ float pos_sum = qjl_popcnt_scale * static_cast<float>(pop) +
323
+ qjl_int_scale * static_cast<float>(and_result);
324
+ qjl_dot = qjl_coeff * gamma * (2.0f * pos_sum - total_qproj_sum);
325
+ } else {
326
+ float pos_sum =
327
+ turboq_masked_sum<SL2>(query_proj.data(), qjl_code, d);
328
+ qjl_dot = qjl_coeff * gamma * (2.0f * pos_sum - total_qproj_sum);
329
+ }
330
+
331
+ float estimated_ip = norm * (mse_dot + qjl_dot);
332
+
333
+ if constexpr (Similarity::metric_type == METRIC_INNER_PRODUCT) {
334
+ return estimated_ip;
335
+ } else {
336
+ return q_norm_sq + norm * norm - 2.0f * estimated_ip;
337
+ }
338
+ }
339
+
340
+ float symmetric_dis(idx_t, idx_t) override {
341
+ FAISS_THROW_MSG("Not implemented");
342
+ }
343
+ };
344
+
27
345
  // Returns true if dimension d is compatible with the given SIMD level
28
346
  template <SIMDLevel SL2>
29
347
  constexpr bool is_dimension_compatible(size_t d) {
30
- if constexpr (SL2 == SIMDLevel::AVX512) {
348
+ if constexpr (SL2 == SIMDLevel::AVX512 || SL2 == SIMDLevel::AVX512_SPR) {
31
349
  return d % 16 == 0;
32
350
  } else if constexpr (SL2 == SIMDLevel::AVX2 || SL2 == SIMDLevel::ARM_NEON) {
33
351
  return d % 8 == 0;
@@ -98,6 +416,14 @@ ScalarQuantizer::SQuantizer* sq_select_quantizer<THE_LEVEL_TO_DISPATCH>(
98
416
  return new QuantizerTurboQuantMSE<4, SL>(d, trained);
99
417
  case ScalarQuantizer::QT_8bit_tqmse:
100
418
  return new QuantizerTurboQuantMSE<8, SL>(d, trained);
419
+ case ScalarQuantizer::QT_2bit_tq:
420
+ return new QuantizerTurboQuantFull<2, SL>(d, trained);
421
+ case ScalarQuantizer::QT_3bit_tq:
422
+ return new QuantizerTurboQuantFull<3, SL>(d, trained);
423
+ case ScalarQuantizer::QT_4bit_tq:
424
+ return new QuantizerTurboQuantFull<4, SL>(d, trained);
425
+ case ScalarQuantizer::QT_5bit_tq:
426
+ return new QuantizerTurboQuantFull<5, SL>(d, trained);
101
427
  default:
102
428
  FAISS_THROW_MSG("unknown qtype");
103
429
  }
@@ -171,7 +497,8 @@ SQDistanceComputer* select_distance_computer_body(
171
497
  return new DCTemplate<QuantizerBF16<SL2>, Sim, SL2>(d, trained);
172
498
 
173
499
  case ScalarQuantizer::QT_8bit_direct:
174
- if constexpr (SL2 == SIMDLevel::AVX512) {
500
+ if constexpr (
501
+ SL2 == SIMDLevel::AVX512 || SL2 == SIMDLevel::AVX512_SPR) {
175
502
  if (d % 32 == 0) {
176
503
  return new DistanceComputerByte<Sim, SL2>(
177
504
  static_cast<int>(d), trained);
@@ -186,6 +513,12 @@ SQDistanceComputer* select_distance_computer_body(
186
513
  d, trained);
187
514
 
188
515
  case ScalarQuantizer::QT_8bit_direct_signed:
516
+ if constexpr (SL2 == SIMDLevel::AVX512_SPR) {
517
+ if (d % 64 == 0) {
518
+ return new DistanceComputerByteSigned<Sim, SL2>(
519
+ static_cast<int>(d), trained);
520
+ }
521
+ }
189
522
  return new DCTemplate<Quantizer8bitDirectSigned<SL2>, Sim, SL2>(
190
523
  d, trained);
191
524
  case ScalarQuantizer::QT_0bit:
@@ -206,6 +539,16 @@ SQDistanceComputer* select_distance_computer_body(
206
539
  case ScalarQuantizer::QT_8bit_tqmse:
207
540
  return new DCTemplate<QuantizerTurboQuantMSE<8, SL2>, Sim, SL2>(
208
541
  d, trained);
542
+ case ScalarQuantizer::QT_2bit_tq:
543
+ // FRICTION: bypasses DCTemplate entirely — custom DC
544
+ // that doesn't fit the Quantizer+Similarity decomposition
545
+ return new DCTurboQuantFull<2, Sim, SL2>(d, trained);
546
+ case ScalarQuantizer::QT_3bit_tq:
547
+ return new DCTurboQuantFull<3, Sim, SL2>(d, trained);
548
+ case ScalarQuantizer::QT_4bit_tq:
549
+ return new DCTurboQuantFull<4, Sim, SL2>(d, trained);
550
+ case ScalarQuantizer::QT_5bit_tq:
551
+ return new DCTurboQuantFull<5, Sim, SL2>(d, trained);
209
552
  default:
210
553
  FAISS_THROW_MSG("unknown qtype");
211
554
  }
@@ -320,7 +663,9 @@ InvertedListScanner* sq_select_InvertedListScanner<THE_LEVEL_TO_DISPATCH>(
320
663
  return scan.template
321
664
  operator()<DCTemplate<QuantizerBF16<SL2>, Similarity, SL2>>();
322
665
  case ScalarQuantizer::QT_8bit_direct:
323
- if constexpr (SL2 == SIMDLevel::AVX512) {
666
+ if constexpr (
667
+ SL2 == SIMDLevel::AVX512 ||
668
+ SL2 == SIMDLevel::AVX512_SPR) {
324
669
  if (d % 32 == 0) {
325
670
  return scan.template
326
671
  operator()<DistanceComputerByte<Similarity, SL2>>();
@@ -336,6 +681,12 @@ InvertedListScanner* sq_select_InvertedListScanner<THE_LEVEL_TO_DISPATCH>(
336
681
  Similarity,
337
682
  SL2>>();
338
683
  case ScalarQuantizer::QT_8bit_direct_signed:
684
+ if constexpr (SL2 == SIMDLevel::AVX512_SPR) {
685
+ if (d % 64 == 0) {
686
+ return scan.template operator()<
687
+ DistanceComputerByteSigned<Similarity, SL2>>();
688
+ }
689
+ }
339
690
  return scan.template operator()<DCTemplate<
340
691
  Quantizer8bitDirectSigned<SL2>,
341
692
  Similarity,
@@ -368,6 +719,18 @@ InvertedListScanner* sq_select_InvertedListScanner<THE_LEVEL_TO_DISPATCH>(
368
719
  QuantizerTurboQuantMSE<8, SL2>,
369
720
  Similarity,
370
721
  SL2>>();
722
+ case ScalarQuantizer::QT_2bit_tq:
723
+ return scan.template
724
+ operator()<DCTurboQuantFull<2, Similarity, SL2>>();
725
+ case ScalarQuantizer::QT_3bit_tq:
726
+ return scan.template
727
+ operator()<DCTurboQuantFull<3, Similarity, SL2>>();
728
+ case ScalarQuantizer::QT_4bit_tq:
729
+ return scan.template
730
+ operator()<DCTurboQuantFull<4, Similarity, SL2>>();
731
+ case ScalarQuantizer::QT_5bit_tq:
732
+ return scan.template
733
+ operator()<DCTurboQuantFull<5, Similarity, SL2>>();
371
734
  default:
372
735
  FAISS_THROW_MSG("unknown qtype");
373
736
  }
@@ -9,6 +9,7 @@
9
9
 
10
10
  #include <faiss/impl/simdlib/simdlib_neon.h>
11
11
 
12
+ #include <algorithm>
12
13
  #include <cstring>
13
14
 
14
15
  #include <faiss/impl/scalar_quantizer/codecs.h>
@@ -180,6 +181,12 @@ struct QuantizerTemplate<
180
181
  xi.data.val[1],
181
182
  this->vdiff)});
182
183
  }
184
+
185
+ /// Raw codec decode without denormalization (for pre-decode opt)
186
+ FAISS_ALWAYS_INLINE simd8float32
187
+ decode_8_raw(const uint8_t* code, int i) const {
188
+ return Codec::decode_8_components(code, i);
189
+ }
183
190
  };
184
191
 
185
192
  template <class Codec>
@@ -219,23 +226,34 @@ struct QuantizerTemplate<
219
226
  * TurboQuant MSE quantizer
220
227
  **********************************************************/
221
228
 
222
- #define DEFINE_TQMSE_NEON_SPECIALIZATION(NBITS, UNPACK_FN) \
223
- template <> \
224
- struct QuantizerTurboQuantMSE<NBITS, SIMDLevel::ARM_NEON> \
225
- : QuantizerTurboQuantMSE<NBITS, SIMDLevel::NONE> { \
226
- using Base = QuantizerTurboQuantMSE<NBITS, SIMDLevel::NONE>; \
227
- \
228
- QuantizerTurboQuantMSE(size_t d, const std::vector<float>& trained) \
229
- : Base(d, trained) { \
230
- assert(d % 8 == 0); \
231
- } \
232
- \
233
- FAISS_ALWAYS_INLINE simd8float32 \
234
- reconstruct_8_components(const uint8_t* code, int i) const { \
235
- uint8_t indices[8]; \
236
- UNPACK_FN(code, i, indices); \
237
- return gather_8_components(this->centroids, indices); \
238
- } \
229
// NEON TurboQuantMSE: decode via gather, encode stays scalar.
// NEON doesn't have movemask so 1-bit encode is also scalar.
//
// DEFINE_TQMSE_NEON_SPECIALIZATION(NBITS, UNPACK_FN) specializes
// QuantizerTurboQuantMSE<NBITS, ARM_NEON> on top of the scalar (NONE)
// version and overrides only the decode side:
//   - reconstruct_8_components: UNPACK_FN(code, i, indices) fills 8 uint8
//     centroid indices for components i..i+7 (presumably unpacking NBITS-bit
//     fields — see the unpack helpers), then gather_8_components looks them
//     up in `centroids`.
//   - decode_vector: reconstructs the whole vector 8 components at a time
//     and stores both NEON halves.
// The constructor asserts d % 8 == 0 because both paths step by 8.
// The macro body ends with `}`; each use site supplies the trailing `;`.
#define DEFINE_TQMSE_NEON_SPECIALIZATION(NBITS, UNPACK_FN)                    \
    template <>                                                              \
    struct QuantizerTurboQuantMSE<NBITS, SIMDLevel::ARM_NEON>                \
            : QuantizerTurboQuantMSE<NBITS, SIMDLevel::NONE> {               \
        using Base = QuantizerTurboQuantMSE<NBITS, SIMDLevel::NONE>;         \
                                                                             \
        QuantizerTurboQuantMSE(size_t d, const std::vector<float>& trained)  \
                : Base(d, trained) {                                         \
            assert(d % 8 == 0);                                              \
        }                                                                    \
                                                                             \
        FAISS_ALWAYS_INLINE simd8float32                                     \
        reconstruct_8_components(const uint8_t* code, int i) const {         \
            uint8_t indices[8];                                              \
            UNPACK_FN(code, i, indices);                                     \
            return gather_8_components(this->centroids, indices);            \
        }                                                                    \
                                                                             \
        void decode_vector(const uint8_t* code, float* x) const final {      \
            for (size_t i = 0; i < this->d; i += 8) {                        \
                simd8float32 xi =                                            \
                        reconstruct_8_components(code, static_cast<int>(i)); \
                vst1q_f32(x + i, xi.data.val[0]);                            \
                vst1q_f32(x + i + 4, xi.data.val[1]);                        \
            }                                                                \
        }                                                                    \
    }
240
258
 
241
259
  DEFINE_TQMSE_NEON_SPECIALIZATION(1, unpack_8x1bit_to_u8);
@@ -261,6 +279,15 @@ struct QuantizerTurboQuantMSE<8, SIMDLevel::ARM_NEON>
261
279
  std::memcpy(indices, code + static_cast<size_t>(i), sizeof(indices));
262
280
  return gather_8_components(this->centroids, indices);
263
281
  }
282
+
283
+ void decode_vector(const uint8_t* code, float* x) const final {
284
+ for (size_t i = 0; i < this->d; i += 8) {
285
+ simd8float32 xi =
286
+ reconstruct_8_components(code, static_cast<int>(i));
287
+ vst1q_f32(x + i, xi.data.val[0]);
288
+ vst1q_f32(x + i + 4, xi.data.val[1]);
289
+ }
290
+ }
264
291
  };
265
292
 
266
293
  /**********************************************************
@@ -397,6 +424,22 @@ struct SimilarityL2<SIMDLevel::ARM_NEON> {
397
424
  FAISS_ALWAYS_INLINE float result_8() {
398
425
  return horizontal_add(accu8);
399
426
  }
427
+
428
/// Prepare a query for L2 distances against raw codec outputs.
///
/// Assumes the uniform reconstruction y = vmin + vdiff * r, where r is the
/// raw codec output. Then
///     ||x - y||^2 = vdiff^2 * ||(x - vmin) / vdiff - r||^2,
/// so q_adj = (x - vmin) / vdiff, scale_factor = vdiff^2 and bias = 0. The
/// caller evaluates bias + scale_factor * ||q_adj - r||^2.
///
/// Degenerate range (vdiff == 0): every reconstruction equals vmin, so the
/// distance is the constant ||x - vmin||^2. It is carried entirely by
/// `bias`, with scale_factor = 0 and q_adj zeroed. A zero scale with a zero
/// bias would report distance 0 for every code.
///
/// @param x            query, d floats
/// @param q_adj        output buffer, d floats
/// @param d            dimension
/// @param vmin         uniform quantizer offset
/// @param vdiff        uniform quantizer range
/// @param scale_factor output: multiplier of the raw-space distance
/// @param bias         output: additive term
static void adjust_query_for_raw_decode(
        const float* x,
        float* q_adj,
        size_t d,
        float vmin,
        float vdiff,
        float& scale_factor,
        float& bias) {
    if (vdiff == 0) {
        float dist = 0;
        for (size_t i = 0; i < d; i++) {
            const float diff = x[i] - vmin;
            dist += diff * diff;
            q_adj[i] = 0;
        }
        scale_factor = 0;
        bias = dist;
        return;
    }
    const float inv_vdiff = 1.0f / vdiff;
    for (size_t i = 0; i < d; i++) {
        q_adj[i] = (x[i] - vmin) * inv_vdiff;
    }
    scale_factor = vdiff * vdiff;
    bias = 0;
}
400
443
  };
401
444
 
402
445
  template <>
@@ -431,6 +474,23 @@ struct SimilarityIP<SIMDLevel::ARM_NEON> {
431
474
  FAISS_ALWAYS_INLINE float result_8() {
432
475
  return horizontal_add(accu8);
433
476
  }
477
+
478
/// Prepare a query for inner products against raw codec outputs r, with
/// reconstruction vmin + vdiff * r:
///     <x, vmin + vdiff * r> = vmin * sum(x) + vdiff * <x, r>.
/// The query is copied unchanged into q_adj, scale_factor = vdiff and
/// bias = vmin * sum(x).
static void adjust_query_for_raw_decode(
        const float* x,
        float* q_adj,
        size_t d,
        float vmin,
        float vdiff,
        float& scale_factor,
        float& bias) {
    float total = 0;
    for (size_t k = 0; k != d; ++k) {
        const float v = x[k];
        q_adj[k] = v;
        total += v;
    }
    scale_factor = vdiff;
    bias = vmin * total;
}
434
494
  };
435
495
 
436
496
  /**********************************************************
@@ -444,8 +504,23 @@ struct DCTemplate<Quantizer, Similarity, SIMDLevel::ARM_NEON>
444
504
 
445
505
  Quantizer quant;
446
506
 
507
// Pre-adjusted query buffer for uniform quantizers.
// When has_decode_raw() holds, set_query fills q_adj (and scale_factor /
// bias) through Sim::adjust_query_for_raw_decode, and query_to_code returns
//     bias + scale_factor * sim(q_adj, decode_8_raw(code))
// instead of denormalizing every decoded component.
std::vector<float> q_adj;
// Multiplier applied to the similarity computed in raw codec space.
float scale_factor = 0;
// Additive term applied after scaling.
float bias = 0;

// Compile-time check (C++20 requires-expression): true when Quantizer
// exposes decode_8_raw(code, i). Selects the pre-decoded distance path.
static constexpr bool has_decode_raw() {
    return requires(const Quantizer& q, const uint8_t* c, int i) {
        { q.decode_8_raw(c, i) };
    };
}
517
+
447
518
  DCTemplate(size_t d, const std::vector<float>& trained)
448
- : quant(d, trained) {}
519
+ : quant(d, trained) {
520
+ if constexpr (has_decode_raw()) {
521
+ q_adj.resize(d);
522
+ }
523
+ }
449
524
 
450
525
  float compute_distance(const float* x, const uint8_t* code) const {
451
526
  Similarity sim(x);
@@ -471,6 +546,26 @@ struct DCTemplate<Quantizer, Similarity, SIMDLevel::ARM_NEON>
471
546
 
472
547
  void set_query(const float* x) final {
473
548
  q = x;
549
+ if constexpr (has_decode_raw()) {
550
+ Sim::adjust_query_for_raw_decode(
551
+ x,
552
+ q_adj.data(),
553
+ quant.d,
554
+ quant.vmin,
555
+ quant.vdiff,
556
+ scale_factor,
557
+ bias);
558
+ }
559
+ }
560
+
561
+ float query_to_code_predecoded(const uint8_t* code) const {
562
+ Similarity sim(q_adj.data());
563
+ sim.begin_8();
564
+ for (size_t i = 0; i < quant.d; i += 8) {
565
+ simd8float32 xi = quant.decode_8_raw(code, i);
566
+ sim.add_8_components(xi);
567
+ }
568
+ return bias + scale_factor * sim.result_8();
474
569
  }
475
570
 
476
571
  float symmetric_dis(idx_t i, idx_t j) override {
@@ -479,7 +574,11 @@ struct DCTemplate<Quantizer, Similarity, SIMDLevel::ARM_NEON>
479
574
  }
480
575
 
481
576
  float query_to_code(const uint8_t* code) const final {
482
- return compute_distance(q, code);
577
+ if constexpr (has_decode_raw()) {
578
+ return query_to_code_predecoded(code);
579
+ } else {
580
+ return compute_distance(q, code);
581
+ }
483
582
  }
484
583
 
485
584
  void query_to_codes_batch_4(
@@ -564,6 +663,32 @@ struct DistanceComputerByte<Similarity, SIMDLevel::ARM_NEON>
564
663
  }
565
664
  };
566
665
 
666
+ /**********************************************************
667
+ * TurboQuant masked_sum NEON specialization (scalar fallback)
668
+ **********************************************************/
669
+
670
+ template <SIMDLevel SL0>
671
+ float turboq_masked_sum(const float* arr, const uint8_t* bits, size_t d);
672
+
673
+ template <>
674
+ float turboq_masked_sum<SIMDLevel::ARM_NEON>(
675
+ const float* arr,
676
+ const uint8_t* bits,
677
+ size_t d) {
678
+ float result = 0;
679
+ for (size_t byte_idx = 0; byte_idx < (d + 7) / 8; byte_idx++) {
680
+ uint8_t b = bits[byte_idx];
681
+ size_t base = byte_idx * 8;
682
+ size_t end = std::min(base + 8, d);
683
+ for (size_t j = base; j < end; j++) {
684
+ if (b & (1 << (j - base))) {
685
+ result += arr[j];
686
+ }
687
+ }
688
+ }
689
+ return result;
690
+ }
691
+
567
692
  } // namespace scalar_quantizer
568
693
  } // namespace faiss
569
694