faiss 0.6.1 → 0.6.2
This diff shows the changes between two publicly available versions of this package, as released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +4 -0
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/faiss/Index.h +1 -1
- data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +6 -7
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +3 -3
- data/vendor/faiss/faiss/IndexHNSW.cpp +173 -143
- data/vendor/faiss/faiss/IndexIVF.cpp +2 -2
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +2 -2
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +3 -1
- data/vendor/faiss/faiss/IndexIVFFlatPanorama.cpp +3 -3
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +2 -3
- data/vendor/faiss/faiss/IndexIVFPQR.cpp +2 -3
- data/vendor/faiss/faiss/IndexIVFRaBitQ.cpp +4 -13
- data/vendor/faiss/faiss/IndexNNDescent.cpp +1 -1
- data/vendor/faiss/faiss/IndexNSG.cpp +1 -2
- data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +68 -6
- data/vendor/faiss/faiss/IndexScalarQuantizer.h +10 -0
- data/vendor/faiss/faiss/cppcontrib/SaDecodeKernels.h +1 -1
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-neon-inl.h +902 -12
- data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-neon-inl.h +702 -10
- data/vendor/faiss/faiss/factory_tools.cpp +4 -0
- data/vendor/faiss/faiss/gpu/GpuResources.h +3 -2
- data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +11 -12
- data/vendor/faiss/faiss/gpu/StandardGpuResources.h +3 -3
- data/vendor/faiss/faiss/gpu_metal/MetalDistance.h +87 -0
- data/vendor/faiss/faiss/gpu_metal/MetalIndex.h +7 -0
- data/vendor/faiss/faiss/gpu_metal/MetalIndexIVFFlat.h +181 -0
- data/vendor/faiss/faiss/gpu_metal/MetalKernels.h +48 -3
- data/vendor/faiss/faiss/gpu_metal/MetalPythonBridge.h +45 -0
- data/vendor/faiss/faiss/gpu_metal/impl/MetalIVFFlat.h +193 -0
- data/vendor/faiss/faiss/impl/HNSW.cpp +556 -199
- data/vendor/faiss/faiss/impl/HNSW.h +51 -13
- data/vendor/faiss/faiss/impl/NSG.cpp +15 -11
- data/vendor/faiss/faiss/impl/Panorama.h +11 -0
- data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +25 -2
- data/vendor/faiss/faiss/impl/RaBitQUtils.cpp +1 -1
- data/vendor/faiss/faiss/impl/RaBitQuantizer.cpp +7 -1
- data/vendor/faiss/faiss/impl/ResultHandler.h +1 -0
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +271 -8
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +50 -0
- data/vendor/faiss/faiss/impl/VisitedTable.cpp +10 -10
- data/vendor/faiss/faiss/impl/VisitedTable.h +69 -34
- data/vendor/faiss/faiss/impl/fast_scan/dispatching.h +3 -1
- data/vendor/faiss/faiss/impl/hnsw/MinimaxHeap.cpp +35 -43
- data/vendor/faiss/faiss/impl/hnsw/MinimaxHeap.h +64 -15
- data/vendor/faiss/faiss/impl/hnsw/avx2.cpp +86 -40
- data/vendor/faiss/faiss/impl/hnsw/avx512.cpp +81 -50
- data/vendor/faiss/faiss/impl/index_read.cpp +100 -39
- data/vendor/faiss/faiss/impl/index_write.cpp +1 -0
- data/vendor/faiss/faiss/impl/io_macros.h +25 -0
- data/vendor/faiss/faiss/impl/platform_macros.h +12 -8
- data/vendor/faiss/faiss/impl/pq_code_distance/avx2.cpp +2 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/avx512.cpp +2 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/neon.cpp +2 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-generic.cpp +20 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-inl.h +36 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-sve.cpp +5 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/pq_scan_impl.h +105 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/rvv.cpp +2 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/distance_computers.h +6 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/quantizers.h +327 -18
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx2.cpp +264 -27
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx512-impl.h +553 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx512-spr.cpp +559 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx512.cpp +199 -27
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-dispatch.h +366 -3
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-neon.cpp +144 -19
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-rvv.cpp +26 -0
- data/vendor/faiss/faiss/impl/simd_dispatch.h +65 -8
- data/vendor/faiss/faiss/index_factory.cpp +5 -1
- data/vendor/faiss/faiss/index_io.h +16 -0
- data/vendor/faiss/faiss/invlists/DirectMap.cpp +4 -1
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +13 -13
- data/vendor/faiss/faiss/invlists/InvertedLists.h +2 -2
- data/vendor/faiss/faiss/svs/IndexSVSVamana.cpp +119 -22
- data/vendor/faiss/faiss/svs/IndexSVSVamana.h +15 -5
- data/vendor/faiss/faiss/svs/IndexSVSVamanaLVQ.cpp +3 -2
- data/vendor/faiss/faiss/svs/IndexSVSVamanaLVQ.h +2 -1
- data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.cpp +65 -24
- data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.h +3 -2
- data/vendor/faiss/faiss/utils/bf16.h +34 -0
- data/vendor/faiss/faiss/utils/distances_simd.cpp +0 -1
- data/vendor/faiss/faiss/utils/hamming.cpp +8 -8
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_avx2.cpp +2 -1
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_avx512_spr.cpp +15 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-avx512.h +6 -30
- data/vendor/faiss/faiss/utils/hamming_distance/hamming_computer-avx512_spr.h +171 -0
- data/vendor/faiss/faiss/utils/partitioning.cpp +0 -2
- data/vendor/faiss/faiss/utils/simd_impl/partitioning_simdlib256.h +14 -68
- data/vendor/faiss/faiss/utils/simd_impl/rabitq_avx512_spr.cpp +343 -0
- data/vendor/faiss/faiss/utils/simd_levels.cpp +12 -2
- metadata +12 -2
|
@@ -7,14 +7,44 @@
|
|
|
7
7
|
|
|
8
8
|
#pragma once
|
|
9
9
|
|
|
10
|
+
#include <cmath>
|
|
11
|
+
|
|
12
|
+
// Hack for MSVC
|
|
13
|
+
#ifndef M_PI
|
|
14
|
+
#define M_PI 3.14159265358979323846
|
|
15
|
+
#endif
|
|
16
|
+
|
|
10
17
|
#include <algorithm>
|
|
18
|
+
#include <cstring>
|
|
11
19
|
|
|
12
20
|
#include <faiss/impl/FaissAssert.h>
|
|
21
|
+
#include <faiss/impl/RaBitQUtils.h>
|
|
13
22
|
#include <faiss/impl/ScalarQuantizer.h>
|
|
23
|
+
#include <faiss/impl/platform_macros.h>
|
|
14
24
|
#include <faiss/impl/simdlib/simdlib_dispatch.h>
|
|
15
25
|
#include <faiss/utils/bf16.h>
|
|
26
|
+
#include <faiss/utils/distances.h>
|
|
16
27
|
#include <faiss/utils/fp16.h>
|
|
28
|
+
#include <faiss/utils/random.h>
|
|
17
29
|
#include <faiss/utils/simd_levels.h>
|
|
30
|
+
#include <faiss/utils/utils.h>
|
|
31
|
+
|
|
32
|
+
extern "C" {
|
|
33
|
+
int sgemm_(
|
|
34
|
+
const char* transa,
|
|
35
|
+
const char* transb,
|
|
36
|
+
int* m,
|
|
37
|
+
int* n,
|
|
38
|
+
int* k,
|
|
39
|
+
const float* alpha,
|
|
40
|
+
const float* a,
|
|
41
|
+
int* lda,
|
|
42
|
+
const float* b,
|
|
43
|
+
int* ldb,
|
|
44
|
+
float* beta,
|
|
45
|
+
float* c,
|
|
46
|
+
int* ldc);
|
|
47
|
+
}
|
|
18
48
|
|
|
19
49
|
namespace faiss {
|
|
20
50
|
|
|
@@ -142,15 +172,14 @@ struct QuantizerTurboQuantMSE<NBits, SIMDLevel::NONE>
|
|
|
142
172
|
boundaries = trained.data() + kCentroidsCount;
|
|
143
173
|
}
|
|
144
174
|
|
|
145
|
-
|
|
175
|
+
uint8_t select_index(float x) const {
|
|
146
176
|
return static_cast<uint8_t>(
|
|
147
177
|
std::upper_bound(
|
|
148
178
|
boundaries, boundaries + (kCentroidsCount - 1), x) -
|
|
149
179
|
boundaries);
|
|
150
180
|
}
|
|
151
181
|
|
|
152
|
-
|
|
153
|
-
const {
|
|
182
|
+
void encode_index(uint8_t idx, uint8_t* code, size_t i) const {
|
|
154
183
|
const size_t bit_offset = i * NBits;
|
|
155
184
|
const size_t byte_offset = bit_offset >> 3;
|
|
156
185
|
const size_t bit_shift = bit_offset & 7;
|
|
@@ -162,8 +191,7 @@ struct QuantizerTurboQuantMSE<NBits, SIMDLevel::NONE>
|
|
|
162
191
|
}
|
|
163
192
|
}
|
|
164
193
|
|
|
165
|
-
|
|
166
|
-
decode_index(const uint8_t* code, size_t i) const {
|
|
194
|
+
uint8_t decode_index(const uint8_t* code, size_t i) const {
|
|
167
195
|
const size_t bit_offset = i * NBits;
|
|
168
196
|
const size_t byte_offset = bit_offset >> 3;
|
|
169
197
|
const size_t bit_shift = bit_offset & 7;
|
|
@@ -175,21 +203,19 @@ struct QuantizerTurboQuantMSE<NBits, SIMDLevel::NONE>
|
|
|
175
203
|
return static_cast<uint8_t>((packed >> bit_shift) & kIndexMask);
|
|
176
204
|
}
|
|
177
205
|
|
|
178
|
-
void encode_vector(const float* x, uint8_t* code) const
|
|
206
|
+
void encode_vector(const float* x, uint8_t* code) const override {
|
|
179
207
|
for (size_t i = 0; i < d; i++) {
|
|
180
208
|
encode_index(select_index(x[i]), code, i);
|
|
181
209
|
}
|
|
182
210
|
}
|
|
183
211
|
|
|
184
|
-
void decode_vector(const uint8_t* code, float* x) const
|
|
212
|
+
void decode_vector(const uint8_t* code, float* x) const override {
|
|
185
213
|
for (size_t i = 0; i < d; i++) {
|
|
186
214
|
x[i] = centroids[decode_index(code, i)];
|
|
187
215
|
}
|
|
188
216
|
}
|
|
189
217
|
|
|
190
|
-
|
|
191
|
-
const uint8_t* code,
|
|
192
|
-
size_t i) const {
|
|
218
|
+
float reconstruct_component(const uint8_t* code, size_t i) const {
|
|
193
219
|
return centroids[decode_index(code, i)];
|
|
194
220
|
}
|
|
195
221
|
};
|
|
@@ -252,16 +278,12 @@ struct QuantizerBF16<SIMDLevel::NONE> : ScalarQuantizer::SQuantizer {
|
|
|
252
278
|
QuantizerBF16(size_t d_in, const std::vector<float>& /* unused */)
|
|
253
279
|
: d(d_in) {}
|
|
254
280
|
|
|
255
|
-
void encode_vector(const float* x, uint8_t* code) const
|
|
256
|
-
|
|
257
|
-
((uint16_t*)code)[i] = encode_bf16(x[i]);
|
|
258
|
-
}
|
|
281
|
+
void encode_vector(const float* x, uint8_t* code) const override {
|
|
282
|
+
encode_bf16_simd(x, (uint16_t*)code, d);
|
|
259
283
|
}
|
|
260
284
|
|
|
261
|
-
void decode_vector(const uint8_t* code, float* x) const
|
|
262
|
-
|
|
263
|
-
x[i] = decode_bf16(((uint16_t*)code)[i]);
|
|
264
|
-
}
|
|
285
|
+
void decode_vector(const uint8_t* code, float* x) const override {
|
|
286
|
+
decode_bf16_simd((const uint16_t*)code, x, d);
|
|
265
287
|
}
|
|
266
288
|
|
|
267
289
|
FAISS_ALWAYS_INLINE float reconstruct_component(
|
|
@@ -276,6 +298,11 @@ struct QuantizerBF16 : QuantizerBF16<SIMDLevel::NONE> {
|
|
|
276
298
|
using QuantizerBF16<SIMDLevel::NONE>::QuantizerBF16;
|
|
277
299
|
};
|
|
278
300
|
|
|
301
|
+
template <>
|
|
302
|
+
struct QuantizerBF16<SIMDLevel::AVX512>;
|
|
303
|
+
template <>
|
|
304
|
+
struct QuantizerBF16<SIMDLevel::AVX512_SPR>;
|
|
305
|
+
|
|
279
306
|
/*******************************************************************
|
|
280
307
|
* 8bit_direct quantizer
|
|
281
308
|
*******************************************************************/
|
|
@@ -355,6 +382,288 @@ struct Quantizer8bitDirectSigned : Quantizer8bitDirectSigned<SIMDLevel::NONE> {
|
|
|
355
382
|
using Quantizer8bitDirectSigned<SIMDLevel::NONE>::Quantizer8bitDirectSigned;
|
|
356
383
|
};
|
|
357
384
|
|
|
385
|
+
/*******************************************************************
|
|
386
|
+
* Full TurboQuant (MSE + QJL) quantizer
|
|
387
|
+
*
|
|
388
|
+
* NBits = total bits per dimension (2-5).
|
|
389
|
+
* MSE bits = NBits - 1, QJL bits = 1.
|
|
390
|
+
*
|
|
391
|
+
* Trained vector layout:
|
|
392
|
+
* [centroids (k floats), boundaries (k-1 floats),
|
|
393
|
+
* seed_lo (float), seed_hi (float), qjl_type (float)]
|
|
394
|
+
* where k = 2^(NBits-1).
|
|
395
|
+
*******************************************************************/
|
|
396
|
+
|
|
397
|
+
FAISS_PACK_STRUCTS_BEGIN
|
|
398
|
+
struct SQTurboQFactors {
|
|
399
|
+
float norm = 0;
|
|
400
|
+
float gamma = 0;
|
|
401
|
+
};
|
|
402
|
+
FAISS_PACK_STRUCTS_END
|
|
403
|
+
|
|
404
|
+
template <int NBits, SIMDLevel SL>
|
|
405
|
+
struct QuantizerTurboQuantFull;
|
|
406
|
+
|
|
407
|
+
template <int NBits>
|
|
408
|
+
struct QuantizerTurboQuantFull<NBits, SIMDLevel::NONE>
|
|
409
|
+
: ScalarQuantizer::SQuantizer {
|
|
410
|
+
static_assert(NBits >= 2 && NBits <= 5);
|
|
411
|
+
|
|
412
|
+
static constexpr int kMSEBits = NBits - 1;
|
|
413
|
+
static constexpr size_t kCentroidsCount = size_t(1) << kMSEBits;
|
|
414
|
+
|
|
415
|
+
const size_t d;
|
|
416
|
+
const float* centroids;
|
|
417
|
+
const float* boundaries;
|
|
418
|
+
|
|
419
|
+
// QJL projection type: 0 = FWHT, 2 = Random Rotation
|
|
420
|
+
uint8_t qjl_type;
|
|
421
|
+
|
|
422
|
+
// FWHT state (qjl_type == 0)
|
|
423
|
+
size_t padded_d;
|
|
424
|
+
std::vector<float> fwht_signs;
|
|
425
|
+
|
|
426
|
+
// Random Rotation state (qjl_type == 2)
|
|
427
|
+
std::vector<float> rr_matrix; // d x d orthogonal matrix (row-major)
|
|
428
|
+
|
|
429
|
+
size_t mse_plane_bytes; // bytes for one bit-plane of d bits
|
|
430
|
+
size_t mse_total_bytes; // kMSEBits * mse_plane_bytes
|
|
431
|
+
size_t qjl_plane_bytes;
|
|
432
|
+
|
|
433
|
+
QuantizerTurboQuantFull(size_t d_in, const std::vector<float>& trained)
|
|
434
|
+
: d(d_in),
|
|
435
|
+
centroids(trained.data()),
|
|
436
|
+
boundaries(trained.data() + kCentroidsCount) {
|
|
437
|
+
// trained = [centroids(k), boundaries(k-1), seed_lo, seed_hi, qjl_type]
|
|
438
|
+
size_t k = kCentroidsCount;
|
|
439
|
+
FAISS_THROW_IF_NOT(trained.size() == 2 * k - 1 + 3);
|
|
440
|
+
|
|
441
|
+
mse_plane_bytes = (d + 7) / 8;
|
|
442
|
+
mse_total_bytes = kMSEBits * mse_plane_bytes;
|
|
443
|
+
qjl_plane_bytes = (d + 7) / 8;
|
|
444
|
+
|
|
445
|
+
// Extract seed from trained
|
|
446
|
+
uint64_t seed = ScalarQuantizer::TurboQuantRefine::unpack_seed(
|
|
447
|
+
trained[2 * k - 1], trained[2 * k]);
|
|
448
|
+
qjl_type = static_cast<uint8_t>(trained[2 * k + 1]);
|
|
449
|
+
|
|
450
|
+
if (qjl_type == 0) {
|
|
451
|
+
// FWHT mode
|
|
452
|
+
padded_d = 1;
|
|
453
|
+
while (padded_d < d) {
|
|
454
|
+
padded_d <<= 1;
|
|
455
|
+
}
|
|
456
|
+
fwht_signs.resize(padded_d);
|
|
457
|
+
RandomGenerator rng(seed);
|
|
458
|
+
for (size_t i = 0; i < padded_d; i++) {
|
|
459
|
+
fwht_signs[i] = (rng.rand_int(2) == 0) ? 1.0f : -1.0f;
|
|
460
|
+
}
|
|
461
|
+
} else {
|
|
462
|
+
// Random Rotation mode
|
|
463
|
+
padded_d = d; // no padding needed for dense multiply
|
|
464
|
+
rr_matrix.resize(d * d);
|
|
465
|
+
float_randn(rr_matrix.data(), d * d, static_cast<int64_t>(seed));
|
|
466
|
+
matrix_qr(
|
|
467
|
+
static_cast<int>(d), static_cast<int>(d), rr_matrix.data());
|
|
468
|
+
}
|
|
469
|
+
}
|
|
470
|
+
|
|
471
|
+
void fwht_inplace(float* x, size_t n) const {
|
|
472
|
+
for (size_t h = 1; h < n; h <<= 1) {
|
|
473
|
+
for (size_t i = 0; i < n; i += h << 1) {
|
|
474
|
+
for (size_t j = i; j < i + h; j++) {
|
|
475
|
+
float a = x[j];
|
|
476
|
+
float b = x[j + h];
|
|
477
|
+
x[j] = a + b;
|
|
478
|
+
x[j + h] = a - b;
|
|
479
|
+
}
|
|
480
|
+
}
|
|
481
|
+
}
|
|
482
|
+
}
|
|
483
|
+
|
|
484
|
+
/// Forward QJL projection: residual -> projected (d outputs)
|
|
485
|
+
void project_forward(const float* residual, float* out) const {
|
|
486
|
+
if (qjl_type == 0) {
|
|
487
|
+
std::vector<float> fwht_buf(padded_d);
|
|
488
|
+
for (size_t j = 0; j < d; j++) {
|
|
489
|
+
fwht_buf[j] = residual[j] * fwht_signs[j];
|
|
490
|
+
}
|
|
491
|
+
for (size_t j = d; j < padded_d; j++) {
|
|
492
|
+
fwht_buf[j] = 0.0f;
|
|
493
|
+
}
|
|
494
|
+
fwht_inplace(fwht_buf.data(), padded_d);
|
|
495
|
+
for (size_t j = 0; j < d; j++) {
|
|
496
|
+
out[j] = fwht_buf[j];
|
|
497
|
+
}
|
|
498
|
+
} else {
|
|
499
|
+
rr_forward(residual, out);
|
|
500
|
+
}
|
|
501
|
+
}
|
|
502
|
+
|
|
503
|
+
/// Inverse QJL projection: signs_buf -> reconstructed (d outputs)
|
|
504
|
+
void project_inverse(float* signs_buf, float* out) const {
|
|
505
|
+
if (qjl_type == 0) {
|
|
506
|
+
fwht_inplace(signs_buf, padded_d);
|
|
507
|
+
for (size_t j = 0; j < d; j++) {
|
|
508
|
+
out[j] = signs_buf[j] * fwht_signs[j];
|
|
509
|
+
}
|
|
510
|
+
} else {
|
|
511
|
+
rr_inverse(signs_buf, out);
|
|
512
|
+
}
|
|
513
|
+
}
|
|
514
|
+
|
|
515
|
+
void rr_forward(const float* x, float* out) const {
|
|
516
|
+
float alpha = 1.0f;
|
|
517
|
+
float beta = 0.0f;
|
|
518
|
+
int di = static_cast<int>(d);
|
|
519
|
+
int one = 1;
|
|
520
|
+
sgemm_("T",
|
|
521
|
+
"N",
|
|
522
|
+
&di,
|
|
523
|
+
&one,
|
|
524
|
+
&di,
|
|
525
|
+
&alpha,
|
|
526
|
+
rr_matrix.data(),
|
|
527
|
+
&di,
|
|
528
|
+
x,
|
|
529
|
+
&di,
|
|
530
|
+
&beta,
|
|
531
|
+
out,
|
|
532
|
+
&di);
|
|
533
|
+
}
|
|
534
|
+
|
|
535
|
+
void rr_inverse(const float* x, float* out) const {
|
|
536
|
+
float alpha = 1.0f;
|
|
537
|
+
float beta = 0.0f;
|
|
538
|
+
int di = static_cast<int>(d);
|
|
539
|
+
int one = 1;
|
|
540
|
+
sgemm_("N",
|
|
541
|
+
"N",
|
|
542
|
+
&di,
|
|
543
|
+
&one,
|
|
544
|
+
&di,
|
|
545
|
+
&alpha,
|
|
546
|
+
rr_matrix.data(),
|
|
547
|
+
&di,
|
|
548
|
+
x,
|
|
549
|
+
&di,
|
|
550
|
+
&beta,
|
|
551
|
+
out,
|
|
552
|
+
&di);
|
|
553
|
+
}
|
|
554
|
+
|
|
555
|
+
/// Store MSE index for dimension j using BIT-PLANE layout.
|
|
556
|
+
/// Plane p stores bit p of every dimension's index.
|
|
557
|
+
void store_mse_index(uint8_t idx, uint8_t* code, size_t j) const {
|
|
558
|
+
for (int p = 0; p < kMSEBits; p++) {
|
|
559
|
+
if (idx & (1 << p)) {
|
|
560
|
+
code[p * mse_plane_bytes + j / 8] |= (1 << (j % 8));
|
|
561
|
+
}
|
|
562
|
+
}
|
|
563
|
+
}
|
|
564
|
+
|
|
565
|
+
/// Load MSE index for dimension j from BIT-PLANE layout.
|
|
566
|
+
uint8_t load_mse_index(const uint8_t* code, size_t j) const {
|
|
567
|
+
uint8_t idx = 0;
|
|
568
|
+
for (int p = 0; p < kMSEBits; p++) {
|
|
569
|
+
if (code[p * mse_plane_bytes + j / 8] & (1 << (j % 8))) {
|
|
570
|
+
idx |= (1 << p);
|
|
571
|
+
}
|
|
572
|
+
}
|
|
573
|
+
return idx;
|
|
574
|
+
}
|
|
575
|
+
|
|
576
|
+
void encode_vector(const float* x, uint8_t* code) const final {
|
|
577
|
+
float sqrt_d = std::sqrt(static_cast<float>(d));
|
|
578
|
+
float inv_sqrt_d = 1.0f / sqrt_d;
|
|
579
|
+
|
|
580
|
+
float x_norm = std::sqrt(fvec_norm_L2sqr(x, d));
|
|
581
|
+
if (x_norm < 1e-30f) {
|
|
582
|
+
x_norm = 1e-30f;
|
|
583
|
+
}
|
|
584
|
+
|
|
585
|
+
// MSE quantize in scaled space + compute residual
|
|
586
|
+
std::vector<float> residual(padded_d);
|
|
587
|
+
for (size_t j = 0; j < d; j++) {
|
|
588
|
+
float v = x[j] / x_norm; // unit-normalized
|
|
589
|
+
float val = v * sqrt_d; // scaled for MSE lookup
|
|
590
|
+
uint8_t idx = static_cast<uint8_t>(
|
|
591
|
+
std::upper_bound(
|
|
592
|
+
boundaries,
|
|
593
|
+
boundaries + (kCentroidsCount - 1),
|
|
594
|
+
val) -
|
|
595
|
+
boundaries);
|
|
596
|
+
store_mse_index(idx, code, j);
|
|
597
|
+
residual[j] = v - centroids[idx] * inv_sqrt_d;
|
|
598
|
+
}
|
|
599
|
+
|
|
600
|
+
// QJL: project residual, take signs
|
|
601
|
+
std::vector<float> proj(d);
|
|
602
|
+
project_forward(residual.data(), proj.data());
|
|
603
|
+
|
|
604
|
+
uint8_t* qjl_code = code + mse_total_bytes;
|
|
605
|
+
for (size_t j = 0; j < d; j++) {
|
|
606
|
+
if (proj[j] > 0.0f) {
|
|
607
|
+
rabitq_utils::set_bit_standard(qjl_code, j);
|
|
608
|
+
}
|
|
609
|
+
}
|
|
610
|
+
|
|
611
|
+
// Store per-vector factors
|
|
612
|
+
float gamma = std::sqrt(fvec_norm_L2sqr(residual.data(), d));
|
|
613
|
+
auto* factors = reinterpret_cast<SQTurboQFactors*>(
|
|
614
|
+
code + mse_total_bytes + qjl_plane_bytes);
|
|
615
|
+
factors->norm = x_norm;
|
|
616
|
+
factors->gamma = gamma;
|
|
617
|
+
}
|
|
618
|
+
|
|
619
|
+
void decode_vector(const uint8_t* code, float* x) const final {
|
|
620
|
+
float inv_sqrt_d = 1.0f / std::sqrt(static_cast<float>(d));
|
|
621
|
+
float inv_sqrt_pd = 1.0f / std::sqrt(static_cast<float>(padded_d));
|
|
622
|
+
|
|
623
|
+
const auto* factors = reinterpret_cast<const SQTurboQFactors*>(
|
|
624
|
+
code + mse_total_bytes + qjl_plane_bytes);
|
|
625
|
+
|
|
626
|
+
// MSE reconstruction
|
|
627
|
+
for (size_t j = 0; j < d; j++) {
|
|
628
|
+
uint8_t idx = load_mse_index(code, j);
|
|
629
|
+
x[j] = centroids[idx] * inv_sqrt_d;
|
|
630
|
+
}
|
|
631
|
+
|
|
632
|
+
// QJL reconstruction: coeff * gamma * S^T * signs
|
|
633
|
+
const uint8_t* qjl_code = code + mse_total_bytes;
|
|
634
|
+
float coeff =
|
|
635
|
+
std::sqrt(M_PI / 2.0f) / static_cast<float>(d) * factors->gamma;
|
|
636
|
+
|
|
637
|
+
std::vector<float> signs_buf(padded_d);
|
|
638
|
+
for (size_t j = 0; j < d; j++) {
|
|
639
|
+
signs_buf[j] = rabitq_utils::extract_bit_standard(qjl_code, j)
|
|
640
|
+
? inv_sqrt_pd
|
|
641
|
+
: -inv_sqrt_pd;
|
|
642
|
+
}
|
|
643
|
+
for (size_t j = d; j < padded_d; j++) {
|
|
644
|
+
signs_buf[j] = 0.0f;
|
|
645
|
+
}
|
|
646
|
+
|
|
647
|
+
std::vector<float> reconstructed(d);
|
|
648
|
+
project_inverse(signs_buf.data(), reconstructed.data());
|
|
649
|
+
for (size_t j = 0; j < d; j++) {
|
|
650
|
+
x[j] += coeff * reconstructed[j];
|
|
651
|
+
}
|
|
652
|
+
|
|
653
|
+
// Scale by norm
|
|
654
|
+
for (size_t j = 0; j < d; j++) {
|
|
655
|
+
x[j] *= factors->norm;
|
|
656
|
+
}
|
|
657
|
+
}
|
|
658
|
+
};
|
|
659
|
+
|
|
660
|
+
template <int NBits, SIMDLevel SL>
|
|
661
|
+
struct QuantizerTurboQuantFull
|
|
662
|
+
: QuantizerTurboQuantFull<NBits, SIMDLevel::NONE> {
|
|
663
|
+
using QuantizerTurboQuantFull<NBits, SIMDLevel::NONE>::
|
|
664
|
+
QuantizerTurboQuantFull;
|
|
665
|
+
};
|
|
666
|
+
|
|
358
667
|
/*******************************************************************
|
|
359
668
|
* Selection function
|
|
360
669
|
*******************************************************************/
|