faiss 0.6.1 → 0.6.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (93)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +4 -0
  3. data/lib/faiss/version.rb +1 -1
  4. data/vendor/faiss/faiss/Index.h +1 -1
  5. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +6 -7
  6. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +3 -3
  7. data/vendor/faiss/faiss/IndexHNSW.cpp +173 -143
  8. data/vendor/faiss/faiss/IndexIVF.cpp +2 -2
  9. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +2 -2
  10. data/vendor/faiss/faiss/IndexIVFFlat.cpp +3 -1
  11. data/vendor/faiss/faiss/IndexIVFFlatPanorama.cpp +3 -3
  12. data/vendor/faiss/faiss/IndexIVFPQ.cpp +2 -3
  13. data/vendor/faiss/faiss/IndexIVFPQR.cpp +2 -3
  14. data/vendor/faiss/faiss/IndexIVFRaBitQ.cpp +4 -13
  15. data/vendor/faiss/faiss/IndexNNDescent.cpp +1 -1
  16. data/vendor/faiss/faiss/IndexNSG.cpp +1 -2
  17. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +68 -6
  18. data/vendor/faiss/faiss/IndexScalarQuantizer.h +10 -0
  19. data/vendor/faiss/faiss/cppcontrib/SaDecodeKernels.h +1 -1
  20. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-neon-inl.h +902 -12
  21. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-neon-inl.h +702 -10
  22. data/vendor/faiss/faiss/factory_tools.cpp +4 -0
  23. data/vendor/faiss/faiss/gpu/GpuResources.h +3 -2
  24. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +11 -12
  25. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +3 -3
  26. data/vendor/faiss/faiss/gpu_metal/MetalDistance.h +87 -0
  27. data/vendor/faiss/faiss/gpu_metal/MetalIndex.h +7 -0
  28. data/vendor/faiss/faiss/gpu_metal/MetalIndexIVFFlat.h +181 -0
  29. data/vendor/faiss/faiss/gpu_metal/MetalKernels.h +48 -3
  30. data/vendor/faiss/faiss/gpu_metal/MetalPythonBridge.h +45 -0
  31. data/vendor/faiss/faiss/gpu_metal/impl/MetalIVFFlat.h +193 -0
  32. data/vendor/faiss/faiss/impl/HNSW.cpp +556 -199
  33. data/vendor/faiss/faiss/impl/HNSW.h +51 -13
  34. data/vendor/faiss/faiss/impl/NSG.cpp +15 -11
  35. data/vendor/faiss/faiss/impl/Panorama.h +11 -0
  36. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +25 -2
  37. data/vendor/faiss/faiss/impl/RaBitQUtils.cpp +1 -1
  38. data/vendor/faiss/faiss/impl/RaBitQuantizer.cpp +7 -1
  39. data/vendor/faiss/faiss/impl/ResultHandler.h +1 -0
  40. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +271 -8
  41. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +50 -0
  42. data/vendor/faiss/faiss/impl/VisitedTable.cpp +10 -10
  43. data/vendor/faiss/faiss/impl/VisitedTable.h +69 -34
  44. data/vendor/faiss/faiss/impl/fast_scan/dispatching.h +3 -1
  45. data/vendor/faiss/faiss/impl/hnsw/MinimaxHeap.cpp +35 -43
  46. data/vendor/faiss/faiss/impl/hnsw/MinimaxHeap.h +64 -15
  47. data/vendor/faiss/faiss/impl/hnsw/avx2.cpp +86 -40
  48. data/vendor/faiss/faiss/impl/hnsw/avx512.cpp +81 -50
  49. data/vendor/faiss/faiss/impl/index_read.cpp +100 -39
  50. data/vendor/faiss/faiss/impl/index_write.cpp +1 -0
  51. data/vendor/faiss/faiss/impl/io_macros.h +25 -0
  52. data/vendor/faiss/faiss/impl/platform_macros.h +12 -8
  53. data/vendor/faiss/faiss/impl/pq_code_distance/avx2.cpp +2 -0
  54. data/vendor/faiss/faiss/impl/pq_code_distance/avx512.cpp +2 -0
  55. data/vendor/faiss/faiss/impl/pq_code_distance/neon.cpp +2 -0
  56. data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-generic.cpp +20 -0
  57. data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-inl.h +36 -0
  58. data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-sve.cpp +5 -0
  59. data/vendor/faiss/faiss/impl/pq_code_distance/pq_scan_impl.h +105 -0
  60. data/vendor/faiss/faiss/impl/pq_code_distance/rvv.cpp +2 -0
  61. data/vendor/faiss/faiss/impl/scalar_quantizer/distance_computers.h +6 -0
  62. data/vendor/faiss/faiss/impl/scalar_quantizer/quantizers.h +327 -18
  63. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx2.cpp +264 -27
  64. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx512-impl.h +553 -0
  65. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx512-spr.cpp +559 -0
  66. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx512.cpp +199 -27
  67. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-dispatch.h +366 -3
  68. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-neon.cpp +144 -19
  69. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-rvv.cpp +26 -0
  70. data/vendor/faiss/faiss/impl/simd_dispatch.h +65 -8
  71. data/vendor/faiss/faiss/index_factory.cpp +5 -1
  72. data/vendor/faiss/faiss/index_io.h +16 -0
  73. data/vendor/faiss/faiss/invlists/DirectMap.cpp +4 -1
  74. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +13 -13
  75. data/vendor/faiss/faiss/invlists/InvertedLists.h +2 -2
  76. data/vendor/faiss/faiss/svs/IndexSVSVamana.cpp +119 -22
  77. data/vendor/faiss/faiss/svs/IndexSVSVamana.h +15 -5
  78. data/vendor/faiss/faiss/svs/IndexSVSVamanaLVQ.cpp +3 -2
  79. data/vendor/faiss/faiss/svs/IndexSVSVamanaLVQ.h +2 -1
  80. data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.cpp +65 -24
  81. data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.h +3 -2
  82. data/vendor/faiss/faiss/utils/bf16.h +34 -0
  83. data/vendor/faiss/faiss/utils/distances_simd.cpp +0 -1
  84. data/vendor/faiss/faiss/utils/hamming.cpp +8 -8
  85. data/vendor/faiss/faiss/utils/hamming_distance/hamming_avx2.cpp +2 -1
  86. data/vendor/faiss/faiss/utils/hamming_distance/hamming_avx512_spr.cpp +15 -0
  87. data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-avx512.h +6 -30
  88. data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-avx512_spr.h +171 -0
  89. data/vendor/faiss/faiss/utils/partitioning.cpp +0 -2
  90. data/vendor/faiss/faiss/utils/simd_impl/partitioning_simdlib256.h +14 -68
  91. data/vendor/faiss/faiss/utils/simd_impl/rabitq_avx512_spr.cpp +343 -0
  92. data/vendor/faiss/faiss/utils/simd_levels.cpp +12 -2
  93. metadata +12 -2
@@ -7,14 +7,44 @@
7
7
 
8
8
  #pragma once
9
9
 
10
+ #include <cmath>
11
+
12
+ // Hack for MSVC
13
+ #ifndef M_PI
14
+ #define M_PI 3.14159265358979323846
15
+ #endif
16
+
10
17
  #include <algorithm>
18
+ #include <cstring>
11
19
 
12
20
  #include <faiss/impl/FaissAssert.h>
21
+ #include <faiss/impl/RaBitQUtils.h>
13
22
  #include <faiss/impl/ScalarQuantizer.h>
23
+ #include <faiss/impl/platform_macros.h>
14
24
  #include <faiss/impl/simdlib/simdlib_dispatch.h>
15
25
  #include <faiss/utils/bf16.h>
26
+ #include <faiss/utils/distances.h>
16
27
  #include <faiss/utils/fp16.h>
28
+ #include <faiss/utils/random.h>
17
29
  #include <faiss/utils/simd_levels.h>
30
+ #include <faiss/utils/utils.h>
31
+
32
// Single-precision GEMM from the BLAS library, declared with the
// trailing-underscore symbol and all-arguments-by-pointer interface of the
// reference (Fortran-style) BLAS. Used for the dense random-rotation
// projection of the TurboQuant QJL stage.
// NOTE(review): the integer arguments are declared as plain `int`; an ILP64
// BLAS build expects 64-bit integers here — confirm this matches the BLAS the
// library is linked against.
extern "C" {
int sgemm_(
        const char* trans_a,
        const char* trans_b,
        int* rows_c,
        int* cols_c,
        int* inner_dim,
        const float* alpha,
        const float* mat_a,
        int* ld_a,
        const float* mat_b,
        int* ld_b,
        float* beta,
        float* mat_c,
        int* ld_c);
}
18
48
 
19
49
  namespace faiss {
20
50
 
@@ -142,15 +172,14 @@ struct QuantizerTurboQuantMSE<NBits, SIMDLevel::NONE>
142
172
  boundaries = trained.data() + kCentroidsCount;
143
173
  }
144
174
 
145
- FAISS_ALWAYS_INLINE uint8_t select_index(float x) const {
175
+ uint8_t select_index(float x) const {
146
176
  return static_cast<uint8_t>(
147
177
  std::upper_bound(
148
178
  boundaries, boundaries + (kCentroidsCount - 1), x) -
149
179
  boundaries);
150
180
  }
151
181
 
152
- FAISS_ALWAYS_INLINE void encode_index(uint8_t idx, uint8_t* code, size_t i)
153
- const {
182
+ void encode_index(uint8_t idx, uint8_t* code, size_t i) const {
154
183
  const size_t bit_offset = i * NBits;
155
184
  const size_t byte_offset = bit_offset >> 3;
156
185
  const size_t bit_shift = bit_offset & 7;
@@ -162,8 +191,7 @@ struct QuantizerTurboQuantMSE<NBits, SIMDLevel::NONE>
162
191
  }
163
192
  }
164
193
 
165
- FAISS_ALWAYS_INLINE uint8_t
166
- decode_index(const uint8_t* code, size_t i) const {
194
+ uint8_t decode_index(const uint8_t* code, size_t i) const {
167
195
  const size_t bit_offset = i * NBits;
168
196
  const size_t byte_offset = bit_offset >> 3;
169
197
  const size_t bit_shift = bit_offset & 7;
@@ -175,21 +203,19 @@ struct QuantizerTurboQuantMSE<NBits, SIMDLevel::NONE>
175
203
  return static_cast<uint8_t>((packed >> bit_shift) & kIndexMask);
176
204
  }
177
205
 
178
- void encode_vector(const float* x, uint8_t* code) const final {
206
+ void encode_vector(const float* x, uint8_t* code) const override {
179
207
  for (size_t i = 0; i < d; i++) {
180
208
  encode_index(select_index(x[i]), code, i);
181
209
  }
182
210
  }
183
211
 
184
- void decode_vector(const uint8_t* code, float* x) const final {
212
+ void decode_vector(const uint8_t* code, float* x) const override {
185
213
  for (size_t i = 0; i < d; i++) {
186
214
  x[i] = centroids[decode_index(code, i)];
187
215
  }
188
216
  }
189
217
 
190
- FAISS_ALWAYS_INLINE float reconstruct_component(
191
- const uint8_t* code,
192
- size_t i) const {
218
+ float reconstruct_component(const uint8_t* code, size_t i) const {
193
219
  return centroids[decode_index(code, i)];
194
220
  }
195
221
  };
@@ -252,16 +278,12 @@ struct QuantizerBF16<SIMDLevel::NONE> : ScalarQuantizer::SQuantizer {
252
278
  QuantizerBF16(size_t d_in, const std::vector<float>& /* unused */)
253
279
  : d(d_in) {}
254
280
 
255
- void encode_vector(const float* x, uint8_t* code) const final {
256
- for (size_t i = 0; i < d; i++) {
257
- ((uint16_t*)code)[i] = encode_bf16(x[i]);
258
- }
281
+ void encode_vector(const float* x, uint8_t* code) const override {
282
+ encode_bf16_simd(x, (uint16_t*)code, d);
259
283
  }
260
284
 
261
- void decode_vector(const uint8_t* code, float* x) const final {
262
- for (size_t i = 0; i < d; i++) {
263
- x[i] = decode_bf16(((uint16_t*)code)[i]);
264
- }
285
+ void decode_vector(const uint8_t* code, float* x) const override {
286
+ decode_bf16_simd((const uint16_t*)code, x, d);
265
287
  }
266
288
 
267
289
  FAISS_ALWAYS_INLINE float reconstruct_component(
@@ -276,6 +298,11 @@ struct QuantizerBF16 : QuantizerBF16<SIMDLevel::NONE> {
276
298
  using QuantizerBF16<SIMDLevel::NONE>::QuantizerBF16;
277
299
  };
278
300
 
301
+ template <>
302
+ struct QuantizerBF16<SIMDLevel::AVX512>;
303
+ template <>
304
+ struct QuantizerBF16<SIMDLevel::AVX512_SPR>;
305
+
279
306
  /*******************************************************************
280
307
  * 8bit_direct quantizer
281
308
  *******************************************************************/
@@ -355,6 +382,288 @@ struct Quantizer8bitDirectSigned : Quantizer8bitDirectSigned<SIMDLevel::NONE> {
355
382
  using Quantizer8bitDirectSigned<SIMDLevel::NONE>::Quantizer8bitDirectSigned;
356
383
  };
357
384
 
385
+ /*******************************************************************
386
+ * Full TurboQuant (MSE + QJL) quantizer
387
+ *
388
+ * NBits = total bits per dimension (2-5).
389
+ * MSE bits = NBits - 1, QJL bits = 1.
390
+ *
391
+ * Trained vector layout:
392
+ * [centroids (k floats), boundaries (k-1 floats),
393
+ * seed_lo (float), seed_hi (float), qjl_type (float)]
394
+ * where k = 2^(NBits-1).
395
+ *******************************************************************/
396
+
397
+ FAISS_PACK_STRUCTS_BEGIN
398
+ struct SQTurboQFactors {
399
+ float norm = 0;
400
+ float gamma = 0;
401
+ };
402
+ FAISS_PACK_STRUCTS_END
403
+
404
+ template <int NBits, SIMDLevel SL>
405
+ struct QuantizerTurboQuantFull;
406
+
407
+ template <int NBits>
408
+ struct QuantizerTurboQuantFull<NBits, SIMDLevel::NONE>
409
+ : ScalarQuantizer::SQuantizer {
410
+ static_assert(NBits >= 2 && NBits <= 5);
411
+
412
+ static constexpr int kMSEBits = NBits - 1;
413
+ static constexpr size_t kCentroidsCount = size_t(1) << kMSEBits;
414
+
415
+ const size_t d;
416
+ const float* centroids;
417
+ const float* boundaries;
418
+
419
+ // QJL projection type: 0 = FWHT, 2 = Random Rotation
420
+ uint8_t qjl_type;
421
+
422
+ // FWHT state (qjl_type == 0)
423
+ size_t padded_d;
424
+ std::vector<float> fwht_signs;
425
+
426
+ // Random Rotation state (qjl_type == 2)
427
+ std::vector<float> rr_matrix; // d x d orthogonal matrix (row-major)
428
+
429
+ size_t mse_plane_bytes; // bytes for one bit-plane of d bits
430
+ size_t mse_total_bytes; // kMSEBits * mse_plane_bytes
431
+ size_t qjl_plane_bytes;
432
+
433
+ QuantizerTurboQuantFull(size_t d_in, const std::vector<float>& trained)
434
+ : d(d_in),
435
+ centroids(trained.data()),
436
+ boundaries(trained.data() + kCentroidsCount) {
437
+ // trained = [centroids(k), boundaries(k-1), seed_lo, seed_hi, qjl_type]
438
+ size_t k = kCentroidsCount;
439
+ FAISS_THROW_IF_NOT(trained.size() == 2 * k - 1 + 3);
440
+
441
+ mse_plane_bytes = (d + 7) / 8;
442
+ mse_total_bytes = kMSEBits * mse_plane_bytes;
443
+ qjl_plane_bytes = (d + 7) / 8;
444
+
445
+ // Extract seed from trained
446
+ uint64_t seed = ScalarQuantizer::TurboQuantRefine::unpack_seed(
447
+ trained[2 * k - 1], trained[2 * k]);
448
+ qjl_type = static_cast<uint8_t>(trained[2 * k + 1]);
449
+
450
+ if (qjl_type == 0) {
451
+ // FWHT mode
452
+ padded_d = 1;
453
+ while (padded_d < d) {
454
+ padded_d <<= 1;
455
+ }
456
+ fwht_signs.resize(padded_d);
457
+ RandomGenerator rng(seed);
458
+ for (size_t i = 0; i < padded_d; i++) {
459
+ fwht_signs[i] = (rng.rand_int(2) == 0) ? 1.0f : -1.0f;
460
+ }
461
+ } else {
462
+ // Random Rotation mode
463
+ padded_d = d; // no padding needed for dense multiply
464
+ rr_matrix.resize(d * d);
465
+ float_randn(rr_matrix.data(), d * d, static_cast<int64_t>(seed));
466
+ matrix_qr(
467
+ static_cast<int>(d), static_cast<int>(d), rr_matrix.data());
468
+ }
469
+ }
470
+
471
+ void fwht_inplace(float* x, size_t n) const {
472
+ for (size_t h = 1; h < n; h <<= 1) {
473
+ for (size_t i = 0; i < n; i += h << 1) {
474
+ for (size_t j = i; j < i + h; j++) {
475
+ float a = x[j];
476
+ float b = x[j + h];
477
+ x[j] = a + b;
478
+ x[j + h] = a - b;
479
+ }
480
+ }
481
+ }
482
+ }
483
+
484
+ /// Forward QJL projection: residual -> projected (d outputs)
485
+ void project_forward(const float* residual, float* out) const {
486
+ if (qjl_type == 0) {
487
+ std::vector<float> fwht_buf(padded_d);
488
+ for (size_t j = 0; j < d; j++) {
489
+ fwht_buf[j] = residual[j] * fwht_signs[j];
490
+ }
491
+ for (size_t j = d; j < padded_d; j++) {
492
+ fwht_buf[j] = 0.0f;
493
+ }
494
+ fwht_inplace(fwht_buf.data(), padded_d);
495
+ for (size_t j = 0; j < d; j++) {
496
+ out[j] = fwht_buf[j];
497
+ }
498
+ } else {
499
+ rr_forward(residual, out);
500
+ }
501
+ }
502
+
503
+ /// Inverse QJL projection: signs_buf -> reconstructed (d outputs)
504
+ void project_inverse(float* signs_buf, float* out) const {
505
+ if (qjl_type == 0) {
506
+ fwht_inplace(signs_buf, padded_d);
507
+ for (size_t j = 0; j < d; j++) {
508
+ out[j] = signs_buf[j] * fwht_signs[j];
509
+ }
510
+ } else {
511
+ rr_inverse(signs_buf, out);
512
+ }
513
+ }
514
+
515
+ void rr_forward(const float* x, float* out) const {
516
+ float alpha = 1.0f;
517
+ float beta = 0.0f;
518
+ int di = static_cast<int>(d);
519
+ int one = 1;
520
+ sgemm_("T",
521
+ "N",
522
+ &di,
523
+ &one,
524
+ &di,
525
+ &alpha,
526
+ rr_matrix.data(),
527
+ &di,
528
+ x,
529
+ &di,
530
+ &beta,
531
+ out,
532
+ &di);
533
+ }
534
+
535
+ void rr_inverse(const float* x, float* out) const {
536
+ float alpha = 1.0f;
537
+ float beta = 0.0f;
538
+ int di = static_cast<int>(d);
539
+ int one = 1;
540
+ sgemm_("N",
541
+ "N",
542
+ &di,
543
+ &one,
544
+ &di,
545
+ &alpha,
546
+ rr_matrix.data(),
547
+ &di,
548
+ x,
549
+ &di,
550
+ &beta,
551
+ out,
552
+ &di);
553
+ }
554
+
555
+ /// Store MSE index for dimension j using BIT-PLANE layout.
556
+ /// Plane p stores bit p of every dimension's index.
557
+ void store_mse_index(uint8_t idx, uint8_t* code, size_t j) const {
558
+ for (int p = 0; p < kMSEBits; p++) {
559
+ if (idx & (1 << p)) {
560
+ code[p * mse_plane_bytes + j / 8] |= (1 << (j % 8));
561
+ }
562
+ }
563
+ }
564
+
565
+ /// Load MSE index for dimension j from BIT-PLANE layout.
566
+ uint8_t load_mse_index(const uint8_t* code, size_t j) const {
567
+ uint8_t idx = 0;
568
+ for (int p = 0; p < kMSEBits; p++) {
569
+ if (code[p * mse_plane_bytes + j / 8] & (1 << (j % 8))) {
570
+ idx |= (1 << p);
571
+ }
572
+ }
573
+ return idx;
574
+ }
575
+
576
+ void encode_vector(const float* x, uint8_t* code) const final {
577
+ float sqrt_d = std::sqrt(static_cast<float>(d));
578
+ float inv_sqrt_d = 1.0f / sqrt_d;
579
+
580
+ float x_norm = std::sqrt(fvec_norm_L2sqr(x, d));
581
+ if (x_norm < 1e-30f) {
582
+ x_norm = 1e-30f;
583
+ }
584
+
585
+ // MSE quantize in scaled space + compute residual
586
+ std::vector<float> residual(padded_d);
587
+ for (size_t j = 0; j < d; j++) {
588
+ float v = x[j] / x_norm; // unit-normalized
589
+ float val = v * sqrt_d; // scaled for MSE lookup
590
+ uint8_t idx = static_cast<uint8_t>(
591
+ std::upper_bound(
592
+ boundaries,
593
+ boundaries + (kCentroidsCount - 1),
594
+ val) -
595
+ boundaries);
596
+ store_mse_index(idx, code, j);
597
+ residual[j] = v - centroids[idx] * inv_sqrt_d;
598
+ }
599
+
600
+ // QJL: project residual, take signs
601
+ std::vector<float> proj(d);
602
+ project_forward(residual.data(), proj.data());
603
+
604
+ uint8_t* qjl_code = code + mse_total_bytes;
605
+ for (size_t j = 0; j < d; j++) {
606
+ if (proj[j] > 0.0f) {
607
+ rabitq_utils::set_bit_standard(qjl_code, j);
608
+ }
609
+ }
610
+
611
+ // Store per-vector factors
612
+ float gamma = std::sqrt(fvec_norm_L2sqr(residual.data(), d));
613
+ auto* factors = reinterpret_cast<SQTurboQFactors*>(
614
+ code + mse_total_bytes + qjl_plane_bytes);
615
+ factors->norm = x_norm;
616
+ factors->gamma = gamma;
617
+ }
618
+
619
+ void decode_vector(const uint8_t* code, float* x) const final {
620
+ float inv_sqrt_d = 1.0f / std::sqrt(static_cast<float>(d));
621
+ float inv_sqrt_pd = 1.0f / std::sqrt(static_cast<float>(padded_d));
622
+
623
+ const auto* factors = reinterpret_cast<const SQTurboQFactors*>(
624
+ code + mse_total_bytes + qjl_plane_bytes);
625
+
626
+ // MSE reconstruction
627
+ for (size_t j = 0; j < d; j++) {
628
+ uint8_t idx = load_mse_index(code, j);
629
+ x[j] = centroids[idx] * inv_sqrt_d;
630
+ }
631
+
632
+ // QJL reconstruction: coeff * gamma * S^T * signs
633
+ const uint8_t* qjl_code = code + mse_total_bytes;
634
+ float coeff =
635
+ std::sqrt(M_PI / 2.0f) / static_cast<float>(d) * factors->gamma;
636
+
637
+ std::vector<float> signs_buf(padded_d);
638
+ for (size_t j = 0; j < d; j++) {
639
+ signs_buf[j] = rabitq_utils::extract_bit_standard(qjl_code, j)
640
+ ? inv_sqrt_pd
641
+ : -inv_sqrt_pd;
642
+ }
643
+ for (size_t j = d; j < padded_d; j++) {
644
+ signs_buf[j] = 0.0f;
645
+ }
646
+
647
+ std::vector<float> reconstructed(d);
648
+ project_inverse(signs_buf.data(), reconstructed.data());
649
+ for (size_t j = 0; j < d; j++) {
650
+ x[j] += coeff * reconstructed[j];
651
+ }
652
+
653
+ // Scale by norm
654
+ for (size_t j = 0; j < d; j++) {
655
+ x[j] *= factors->norm;
656
+ }
657
+ }
658
+ };
659
+
660
+ template <int NBits, SIMDLevel SL>
661
+ struct QuantizerTurboQuantFull
662
+ : QuantizerTurboQuantFull<NBits, SIMDLevel::NONE> {
663
+ using QuantizerTurboQuantFull<NBits, SIMDLevel::NONE>::
664
+ QuantizerTurboQuantFull;
665
+ };
666
+
358
667
  /*******************************************************************
359
668
  * Selection function
360
669
  *******************************************************************/