faiss 0.4.3 → 0.5.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +10 -0
- data/README.md +2 -0
- data/ext/faiss/index.cpp +33 -6
- data/ext/faiss/index_binary.cpp +17 -4
- data/ext/faiss/kmeans.cpp +6 -6
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/faiss/AutoTune.cpp +2 -3
- data/vendor/faiss/faiss/AutoTune.h +1 -1
- data/vendor/faiss/faiss/Clustering.cpp +2 -2
- data/vendor/faiss/faiss/Clustering.h +2 -2
- data/vendor/faiss/faiss/IVFlib.cpp +26 -51
- data/vendor/faiss/faiss/IVFlib.h +1 -1
- data/vendor/faiss/faiss/Index.cpp +11 -0
- data/vendor/faiss/faiss/Index.h +34 -11
- data/vendor/faiss/faiss/Index2Layer.cpp +1 -1
- data/vendor/faiss/faiss/Index2Layer.h +2 -2
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +1 -0
- data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +9 -4
- data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +5 -1
- data/vendor/faiss/faiss/IndexBinary.h +7 -7
- data/vendor/faiss/faiss/IndexBinaryFromFloat.h +1 -1
- data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +8 -2
- data/vendor/faiss/faiss/IndexBinaryHNSW.h +1 -1
- data/vendor/faiss/faiss/IndexBinaryHash.cpp +3 -3
- data/vendor/faiss/faiss/IndexBinaryHash.h +5 -5
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +7 -6
- data/vendor/faiss/faiss/IndexFastScan.cpp +125 -49
- data/vendor/faiss/faiss/IndexFastScan.h +102 -7
- data/vendor/faiss/faiss/IndexFlat.cpp +374 -4
- data/vendor/faiss/faiss/IndexFlat.h +81 -1
- data/vendor/faiss/faiss/IndexHNSW.cpp +93 -2
- data/vendor/faiss/faiss/IndexHNSW.h +58 -2
- data/vendor/faiss/faiss/IndexIDMap.cpp +14 -13
- data/vendor/faiss/faiss/IndexIDMap.h +6 -6
- data/vendor/faiss/faiss/IndexIVF.cpp +1 -1
- data/vendor/faiss/faiss/IndexIVF.h +5 -5
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +1 -1
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +9 -3
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +3 -1
- data/vendor/faiss/faiss/IndexIVFFastScan.cpp +176 -90
- data/vendor/faiss/faiss/IndexIVFFastScan.h +173 -18
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +1 -0
- data/vendor/faiss/faiss/IndexIVFFlatPanorama.cpp +251 -0
- data/vendor/faiss/faiss/IndexIVFFlatPanorama.h +64 -0
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +3 -1
- data/vendor/faiss/faiss/IndexIVFPQ.h +1 -1
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +134 -2
- data/vendor/faiss/faiss/IndexIVFPQFastScan.h +7 -1
- data/vendor/faiss/faiss/IndexIVFRaBitQ.cpp +99 -8
- data/vendor/faiss/faiss/IndexIVFRaBitQ.h +4 -1
- data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.cpp +828 -0
- data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.h +252 -0
- data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +1 -1
- data/vendor/faiss/faiss/IndexIVFSpectralHash.h +1 -1
- data/vendor/faiss/faiss/IndexNNDescent.cpp +1 -1
- data/vendor/faiss/faiss/IndexNSG.cpp +1 -1
- data/vendor/faiss/faiss/IndexNeuralNetCodec.h +1 -1
- data/vendor/faiss/faiss/IndexPQ.cpp +4 -1
- data/vendor/faiss/faiss/IndexPQ.h +1 -1
- data/vendor/faiss/faiss/IndexPQFastScan.cpp +6 -2
- data/vendor/faiss/faiss/IndexPQFastScan.h +5 -1
- data/vendor/faiss/faiss/IndexPreTransform.cpp +14 -0
- data/vendor/faiss/faiss/IndexPreTransform.h +9 -0
- data/vendor/faiss/faiss/IndexRaBitQ.cpp +96 -13
- data/vendor/faiss/faiss/IndexRaBitQ.h +11 -2
- data/vendor/faiss/faiss/IndexRaBitQFastScan.cpp +731 -0
- data/vendor/faiss/faiss/IndexRaBitQFastScan.h +175 -0
- data/vendor/faiss/faiss/IndexRefine.cpp +49 -0
- data/vendor/faiss/faiss/IndexRefine.h +17 -0
- data/vendor/faiss/faiss/IndexShards.cpp +1 -1
- data/vendor/faiss/faiss/MatrixStats.cpp +3 -3
- data/vendor/faiss/faiss/MetricType.h +1 -1
- data/vendor/faiss/faiss/VectorTransform.h +2 -2
- data/vendor/faiss/faiss/clone_index.cpp +5 -1
- data/vendor/faiss/faiss/gpu/GpuCloner.cpp +1 -1
- data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +3 -1
- data/vendor/faiss/faiss/gpu/GpuIndex.h +11 -11
- data/vendor/faiss/faiss/gpu/GpuIndexBinaryCagra.h +1 -1
- data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +1 -1
- data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +11 -7
- data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +1 -1
- data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +2 -0
- data/vendor/faiss/faiss/gpu/test/TestGpuIcmEncoder.cpp +7 -0
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +1 -1
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +1 -1
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +1 -1
- data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +2 -2
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +1 -1
- data/vendor/faiss/faiss/impl/CodePacker.h +2 -2
- data/vendor/faiss/faiss/impl/DistanceComputer.h +77 -6
- data/vendor/faiss/faiss/impl/FastScanDistancePostProcessing.h +53 -0
- data/vendor/faiss/faiss/impl/HNSW.cpp +295 -16
- data/vendor/faiss/faiss/impl/HNSW.h +35 -6
- data/vendor/faiss/faiss/impl/IDSelector.cpp +2 -2
- data/vendor/faiss/faiss/impl/IDSelector.h +4 -4
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +4 -4
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +1 -1
- data/vendor/faiss/faiss/impl/LookupTableScaler.h +1 -1
- data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -1
- data/vendor/faiss/faiss/impl/NNDescent.h +2 -2
- data/vendor/faiss/faiss/impl/NSG.cpp +1 -1
- data/vendor/faiss/faiss/impl/Panorama.cpp +193 -0
- data/vendor/faiss/faiss/impl/Panorama.h +204 -0
- data/vendor/faiss/faiss/impl/PanoramaStats.cpp +33 -0
- data/vendor/faiss/faiss/impl/PanoramaStats.h +38 -0
- data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +5 -5
- data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +1 -1
- data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +1 -1
- data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +2 -0
- data/vendor/faiss/faiss/impl/ProductQuantizer.h +1 -1
- data/vendor/faiss/faiss/impl/RaBitQStats.cpp +29 -0
- data/vendor/faiss/faiss/impl/RaBitQStats.h +56 -0
- data/vendor/faiss/faiss/impl/RaBitQUtils.cpp +294 -0
- data/vendor/faiss/faiss/impl/RaBitQUtils.h +330 -0
- data/vendor/faiss/faiss/impl/RaBitQuantizer.cpp +304 -223
- data/vendor/faiss/faiss/impl/RaBitQuantizer.h +72 -4
- data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.cpp +362 -0
- data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.h +112 -0
- data/vendor/faiss/faiss/impl/ResidualQuantizer.h +1 -1
- data/vendor/faiss/faiss/impl/ResultHandler.h +4 -4
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +7 -10
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +2 -4
- data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +7 -4
- data/vendor/faiss/faiss/impl/index_read.cpp +238 -10
- data/vendor/faiss/faiss/impl/index_write.cpp +212 -19
- data/vendor/faiss/faiss/impl/io.cpp +2 -2
- data/vendor/faiss/faiss/impl/io.h +4 -4
- data/vendor/faiss/faiss/impl/kmeans1d.cpp +1 -1
- data/vendor/faiss/faiss/impl/kmeans1d.h +1 -1
- data/vendor/faiss/faiss/impl/lattice_Zn.h +2 -2
- data/vendor/faiss/faiss/impl/mapped_io.cpp +2 -2
- data/vendor/faiss/faiss/impl/mapped_io.h +4 -3
- data/vendor/faiss/faiss/impl/maybe_owned_vector.h +8 -1
- data/vendor/faiss/faiss/impl/platform_macros.h +12 -0
- data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +30 -4
- data/vendor/faiss/faiss/impl/pq4_fast_scan.h +14 -8
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +5 -6
- data/vendor/faiss/faiss/impl/simd_result_handlers.h +55 -11
- data/vendor/faiss/faiss/impl/svs_io.cpp +86 -0
- data/vendor/faiss/faiss/impl/svs_io.h +67 -0
- data/vendor/faiss/faiss/impl/zerocopy_io.h +1 -1
- data/vendor/faiss/faiss/index_factory.cpp +217 -8
- data/vendor/faiss/faiss/index_factory.h +1 -1
- data/vendor/faiss/faiss/index_io.h +1 -1
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +1 -1
- data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -1
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +115 -1
- data/vendor/faiss/faiss/invlists/InvertedLists.h +46 -0
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +1 -1
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +1 -1
- data/vendor/faiss/faiss/svs/IndexSVSFaissUtils.h +261 -0
- data/vendor/faiss/faiss/svs/IndexSVSFlat.cpp +117 -0
- data/vendor/faiss/faiss/svs/IndexSVSFlat.h +66 -0
- data/vendor/faiss/faiss/svs/IndexSVSVamana.cpp +245 -0
- data/vendor/faiss/faiss/svs/IndexSVSVamana.h +137 -0
- data/vendor/faiss/faiss/svs/IndexSVSVamanaLVQ.cpp +39 -0
- data/vendor/faiss/faiss/svs/IndexSVSVamanaLVQ.h +42 -0
- data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.cpp +149 -0
- data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.h +58 -0
- data/vendor/faiss/faiss/utils/AlignedTable.h +1 -1
- data/vendor/faiss/faiss/utils/Heap.cpp +2 -2
- data/vendor/faiss/faiss/utils/Heap.h +3 -3
- data/vendor/faiss/faiss/utils/NeuralNet.cpp +1 -1
- data/vendor/faiss/faiss/utils/NeuralNet.h +3 -3
- data/vendor/faiss/faiss/utils/approx_topk/approx_topk.h +2 -2
- data/vendor/faiss/faiss/utils/approx_topk/avx2-inl.h +2 -2
- data/vendor/faiss/faiss/utils/approx_topk/mode.h +1 -1
- data/vendor/faiss/faiss/utils/distances.cpp +0 -3
- data/vendor/faiss/faiss/utils/distances.h +2 -2
- data/vendor/faiss/faiss/utils/extra_distances-inl.h +3 -1
- data/vendor/faiss/faiss/utils/hamming-inl.h +2 -0
- data/vendor/faiss/faiss/utils/hamming.cpp +7 -6
- data/vendor/faiss/faiss/utils/hamming.h +1 -1
- data/vendor/faiss/faiss/utils/hamming_distance/common.h +1 -2
- data/vendor/faiss/faiss/utils/partitioning.cpp +5 -5
- data/vendor/faiss/faiss/utils/partitioning.h +2 -2
- data/vendor/faiss/faiss/utils/rabitq_simd.h +222 -336
- data/vendor/faiss/faiss/utils/random.cpp +1 -1
- data/vendor/faiss/faiss/utils/simdlib_avx2.h +1 -1
- data/vendor/faiss/faiss/utils/simdlib_avx512.h +1 -1
- data/vendor/faiss/faiss/utils/simdlib_neon.h +2 -2
- data/vendor/faiss/faiss/utils/transpose/transpose-avx512-inl.h +1 -1
- data/vendor/faiss/faiss/utils/utils.cpp +9 -2
- data/vendor/faiss/faiss/utils/utils.h +2 -2
- metadata +29 -1
|
@@ -8,39 +8,61 @@
|
|
|
8
8
|
#include <faiss/impl/RaBitQuantizer.h>
|
|
9
9
|
|
|
10
10
|
#include <faiss/impl/FaissAssert.h>
|
|
11
|
+
#include <faiss/impl/RaBitQUtils.h>
|
|
12
|
+
#include <faiss/impl/RaBitQuantizerMultiBit.h>
|
|
11
13
|
#include <faiss/utils/distances.h>
|
|
12
14
|
#include <faiss/utils/rabitq_simd.h>
|
|
13
15
|
#include <algorithm>
|
|
14
16
|
#include <cmath>
|
|
15
17
|
#include <cstring>
|
|
16
|
-
#include <limits>
|
|
17
18
|
#include <memory>
|
|
18
19
|
#include <vector>
|
|
19
20
|
|
|
20
21
|
namespace faiss {
|
|
21
22
|
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
23
|
+
// Import shared utilities from RaBitQUtils
|
|
24
|
+
using rabitq_utils::ExtraBitsFactors;
|
|
25
|
+
using rabitq_utils::QueryFactorsData;
|
|
26
|
+
using rabitq_utils::SignBitFactors;
|
|
27
|
+
using rabitq_utils::SignBitFactorsWithError;
|
|
28
|
+
|
|
29
|
+
RaBitQuantizer::RaBitQuantizer(size_t d, MetricType metric, size_t nb_bits)
|
|
30
|
+
: Quantizer(d, 0), // code_size will be set below
|
|
31
|
+
metric_type{metric},
|
|
32
|
+
nb_bits{nb_bits} {
|
|
33
|
+
// Validate nb_bits range
|
|
34
|
+
FAISS_THROW_IF_NOT(nb_bits >= 1 && nb_bits <= 9);
|
|
35
|
+
|
|
36
|
+
// Set code_size using compute_code_size
|
|
37
|
+
code_size = compute_code_size(d, nb_bits);
|
|
38
|
+
}
|
|
32
39
|
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
40
|
+
size_t RaBitQuantizer::compute_code_size(size_t d, size_t num_bits) const {
|
|
41
|
+
// Validate inputs
|
|
42
|
+
FAISS_THROW_IF_NOT(num_bits >= 1 && num_bits <= 9);
|
|
43
|
+
|
|
44
|
+
size_t ex_bits = num_bits - 1;
|
|
45
|
+
|
|
46
|
+
// Base: 1-bit codes + base factors
|
|
47
|
+
// Layout for 1-bit: [binary_code: (d+7)/8 bytes][SignBitFactors: 8 bytes]
|
|
48
|
+
// base_factors = or_minus_c_l2sqr (4) + dp_multiplier (4)
|
|
49
|
+
// Layout for multi-bit: [binary_code: (d+7)/8
|
|
50
|
+
// bytes][SignBitFactorsWithError: 12 bytes]
|
|
51
|
+
// factors = or_minus_c_l2sqr (4) + dp_multiplier (4) + f_error (4)
|
|
52
|
+
size_t base_size = (d + 7) / 8 +
|
|
53
|
+
(ex_bits == 0 ? sizeof(SignBitFactors)
|
|
54
|
+
: sizeof(SignBitFactorsWithError));
|
|
55
|
+
|
|
56
|
+
// Extra: ex-bit codes + ex factors (only if ex_bits > 0)
|
|
57
|
+
// Layout: [ex_code: (d*ex_bits+7)/8 bytes][ex_factors: 8 bytes]
|
|
58
|
+
size_t ex_size = 0;
|
|
59
|
+
if (ex_bits > 0) {
|
|
60
|
+
ex_size = (d * ex_bits + 7) / 8 + sizeof(ExtraBitsFactors);
|
|
61
|
+
}
|
|
36
62
|
|
|
37
|
-
|
|
38
|
-
return (d + 7) / 8 + sizeof(FactorsData);
|
|
63
|
+
return base_size + ex_size;
|
|
39
64
|
}
|
|
40
65
|
|
|
41
|
-
RaBitQuantizer::RaBitQuantizer(size_t d, MetricType metric)
|
|
42
|
-
: Quantizer(d, get_code_size(d)), metric_type{metric} {}
|
|
43
|
-
|
|
44
66
|
void RaBitQuantizer::train(size_t n, const float* x) {
|
|
45
67
|
// does nothing
|
|
46
68
|
}
|
|
@@ -65,68 +87,85 @@ void RaBitQuantizer::compute_codes_core(
|
|
|
65
87
|
return;
|
|
66
88
|
}
|
|
67
89
|
|
|
68
|
-
|
|
69
|
-
const float inv_d_sqrt = (d == 0) ? 1.0f : (1.0f / std::sqrt((float)d));
|
|
90
|
+
const size_t ex_bits = nb_bits - 1;
|
|
70
91
|
|
|
71
|
-
//
|
|
92
|
+
// Compute codes
|
|
72
93
|
#pragma omp parallel for if (n > 1000)
|
|
73
94
|
for (int64_t i = 0; i < n; i++) {
|
|
74
|
-
//
|
|
75
|
-
float norm_L2sqr = 0;
|
|
76
|
-
// ||or||^2, which is equal to ||P(or)||^2 and ||P^(-1)(or)||^2
|
|
77
|
-
float or_L2sqr = 0;
|
|
78
|
-
// dot product
|
|
79
|
-
float dp_oO = 0;
|
|
80
|
-
|
|
81
|
-
// the code
|
|
95
|
+
// Pointer to this vector's code
|
|
82
96
|
uint8_t* code = codes + i * code_size;
|
|
83
|
-
FactorsData* fac = reinterpret_cast<FactorsData*>(code + (d + 7) / 8);
|
|
84
97
|
|
|
85
|
-
//
|
|
86
|
-
|
|
87
|
-
|
|
98
|
+
// Clear code memory
|
|
99
|
+
memset(code, 0, code_size);
|
|
100
|
+
|
|
101
|
+
const float* x_row = x + i * d;
|
|
102
|
+
|
|
103
|
+
// Pointer arithmetic for code layout:
|
|
104
|
+
// For 1-bit: [binary_code: (d+7)/8 bytes][SignBitFactors: 8 bytes]
|
|
105
|
+
// For multi-bit: [binary_code: (d+7)/8 bytes][SignBitFactorsWithError:
|
|
106
|
+
// 12 bytes]
|
|
107
|
+
// [ex_code: (d*ex_bits+7)/8 bytes][ex_factors: 8 bytes]
|
|
108
|
+
uint8_t* binary_code = code;
|
|
109
|
+
|
|
110
|
+
// Step 1: Compute 1-bit quantization and base factors
|
|
111
|
+
// Store residual for potential ex-bits quantization
|
|
112
|
+
std::vector<float> residual(d);
|
|
113
|
+
|
|
114
|
+
// Use shared utilities for computing factors
|
|
115
|
+
SignBitFactorsWithError factors_data =
|
|
116
|
+
rabitq_utils::compute_vector_factors(
|
|
117
|
+
x_row, d, centroid_in, metric_type, ex_bits > 0);
|
|
118
|
+
|
|
119
|
+
// Write appropriate factors based on nb_bits
|
|
120
|
+
if (ex_bits == 0) {
|
|
121
|
+
// For 1-bit: write only SignBitFactors (8 bytes)
|
|
122
|
+
SignBitFactors* base_factors =
|
|
123
|
+
reinterpret_cast<SignBitFactors*>(code + (d + 7) / 8);
|
|
124
|
+
base_factors->or_minus_c_l2sqr = factors_data.or_minus_c_l2sqr;
|
|
125
|
+
base_factors->dp_multiplier = factors_data.dp_multiplier;
|
|
126
|
+
} else {
|
|
127
|
+
// For multi-bit: write full SignBitFactorsWithError (12 bytes)
|
|
128
|
+
SignBitFactorsWithError* full_factors =
|
|
129
|
+
reinterpret_cast<SignBitFactorsWithError*>(
|
|
130
|
+
code + (d + 7) / 8);
|
|
131
|
+
*full_factors = factors_data;
|
|
88
132
|
}
|
|
89
133
|
|
|
134
|
+
// Pack bits into standard RaBitQ format
|
|
90
135
|
for (size_t j = 0; j < d; j++) {
|
|
91
|
-
const float
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
const bool xb = (or_minus_c > 0);
|
|
136
|
+
const float x_val = x_row[j];
|
|
137
|
+
const float centroid_val =
|
|
138
|
+
(centroid_in == nullptr) ? 0.0f : centroid_in[j];
|
|
139
|
+
const float or_minus_c = x_val - centroid_val;
|
|
140
|
+
residual[j] = or_minus_c;
|
|
97
141
|
|
|
98
|
-
|
|
142
|
+
const bool xb = (or_minus_c > 0.0f);
|
|
99
143
|
|
|
100
|
-
//
|
|
101
|
-
if (
|
|
102
|
-
|
|
103
|
-
// enable a particular bit
|
|
104
|
-
code[j / 8] |= (1 << (j % 8));
|
|
105
|
-
}
|
|
144
|
+
// Store the 1-bit sign code
|
|
145
|
+
if (xb) {
|
|
146
|
+
rabitq_utils::set_bit_standard(binary_code, j);
|
|
106
147
|
}
|
|
107
148
|
}
|
|
108
149
|
|
|
109
|
-
//
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
150
|
+
// Step 2: Compute ex-bits quantization (if nb_bits > 1)
|
|
151
|
+
if (ex_bits > 0) {
|
|
152
|
+
// Pointer to ex-bit code section
|
|
153
|
+
uint8_t* ex_code =
|
|
154
|
+
code + (d + 7) / 8 + sizeof(SignBitFactorsWithError);
|
|
155
|
+
// Pointer to ex-factors section
|
|
156
|
+
ExtraBitsFactors* ex_factors = reinterpret_cast<ExtraBitsFactors*>(
|
|
157
|
+
ex_code + (d * ex_bits + 7) / 8);
|
|
158
|
+
|
|
159
|
+
// Quantize residual to ex-bits (pass centroid for IP metric)
|
|
160
|
+
rabitq_multibit::quantize_ex_bits(
|
|
161
|
+
residual.data(),
|
|
162
|
+
d,
|
|
163
|
+
nb_bits,
|
|
164
|
+
ex_code,
|
|
165
|
+
*ex_factors,
|
|
166
|
+
metric_type,
|
|
167
|
+
centroid_in);
|
|
127
168
|
}
|
|
128
|
-
|
|
129
|
-
fac->dp_multiplier = inv_dp_oO * std::sqrt(norm_L2sqr);
|
|
130
169
|
}
|
|
131
170
|
}
|
|
132
171
|
|
|
@@ -143,6 +182,7 @@ void RaBitQuantizer::decode_core(
|
|
|
143
182
|
FAISS_ASSERT(x != nullptr);
|
|
144
183
|
|
|
145
184
|
const float inv_d_sqrt = (d == 0) ? 1.0f : (1.0f / std::sqrt((float)d));
|
|
185
|
+
const size_t ex_bits = nb_bits - 1;
|
|
146
186
|
|
|
147
187
|
#pragma omp parallel for if (n > 1000)
|
|
148
188
|
for (int64_t i = 0; i < n; i++) {
|
|
@@ -150,10 +190,19 @@ void RaBitQuantizer::decode_core(
|
|
|
150
190
|
|
|
151
191
|
// split the code into parts
|
|
152
192
|
const uint8_t* binary_data = code;
|
|
153
|
-
const FactorsData* fac =
|
|
154
|
-
reinterpret_cast<const FactorsData*>(code + (d + 7) / 8);
|
|
155
193
|
|
|
194
|
+
// Cast to appropriate type based on nb_bits
|
|
195
|
+
// For 1-bit: use SignBitFactors (8 bytes)
|
|
196
|
+
// For multi-bit: use SignBitFactorsWithError (12 bytes, but only first
|
|
197
|
+
// 8 bytes used for decode)
|
|
198
|
+
const SignBitFactors* fac = (ex_bits == 0)
|
|
199
|
+
? reinterpret_cast<const SignBitFactors*>(code + (d + 7) / 8)
|
|
200
|
+
: reinterpret_cast<const SignBitFactorsWithError*>(
|
|
201
|
+
code + (d + 7) / 8);
|
|
202
|
+
|
|
203
|
+
// this is the baseline code
|
|
156
204
|
//
|
|
205
|
+
// compute <q,o> using floats
|
|
157
206
|
for (size_t j = 0; j < d; j++) {
|
|
158
207
|
// extract i-th bit
|
|
159
208
|
const uint8_t masker = (1 << (j % 8));
|
|
@@ -166,51 +215,69 @@ void RaBitQuantizer::decode_core(
|
|
|
166
215
|
}
|
|
167
216
|
}
|
|
168
217
|
|
|
169
|
-
|
|
170
|
-
// dimensionality
|
|
171
|
-
size_t d = 0;
|
|
172
|
-
// a centroid to use
|
|
173
|
-
const float* centroid = nullptr;
|
|
218
|
+
// Implementation of RaBitQDistanceComputer (declared in header)
|
|
174
219
|
|
|
175
|
-
|
|
176
|
-
|
|
220
|
+
float RaBitQDistanceComputer::lower_bound_distance(const uint8_t* code) {
|
|
221
|
+
FAISS_ASSERT(code != nullptr);
|
|
177
222
|
|
|
178
|
-
|
|
223
|
+
// Compute estimated distance using 1-bit codes
|
|
224
|
+
float est_distance = distance_to_code_1bit(code);
|
|
179
225
|
|
|
180
|
-
|
|
181
|
-
|
|
226
|
+
// Extract f_error from the code
|
|
227
|
+
size_t size = (d + 7) / 8;
|
|
228
|
+
const SignBitFactorsWithError* base_fac =
|
|
229
|
+
reinterpret_cast<const SignBitFactorsWithError*>(code + size);
|
|
230
|
+
float f_error = base_fac->f_error;
|
|
182
231
|
|
|
183
|
-
|
|
232
|
+
// Compute proper lower bound using RaBitQ error formula:
|
|
233
|
+
// lower_bound = est_distance - f_error * g_error
|
|
234
|
+
// This guarantees: lower_bound ≤ true_distance
|
|
235
|
+
float lower_bound = est_distance - (f_error * g_error);
|
|
184
236
|
|
|
185
|
-
|
|
186
|
-
|
|
237
|
+
// Distance cannot be negative
|
|
238
|
+
return std::max(0.0f, lower_bound);
|
|
187
239
|
}
|
|
188
240
|
|
|
189
|
-
|
|
241
|
+
namespace {
|
|
242
|
+
|
|
243
|
+
struct RaBitQDistanceComputerNotQ : RaBitQDistanceComputer {
|
|
190
244
|
// the rotated query (qr - c)
|
|
191
245
|
std::vector<float> rotated_q;
|
|
192
246
|
// some additional numbers for the query
|
|
193
247
|
QueryFactorsData query_fac;
|
|
194
248
|
|
|
195
|
-
|
|
249
|
+
RaBitQDistanceComputerNotQ();
|
|
196
250
|
|
|
197
|
-
|
|
251
|
+
// Compute distance using only 1-bit codes (fast)
|
|
252
|
+
float distance_to_code_1bit(const uint8_t* code) override;
|
|
253
|
+
|
|
254
|
+
// Compute full distance using 1-bit + ex-bits (accurate)
|
|
255
|
+
float distance_to_code_full(const uint8_t* code) override;
|
|
198
256
|
|
|
199
257
|
void set_query(const float* x) override;
|
|
200
258
|
};
|
|
201
259
|
|
|
202
|
-
|
|
260
|
+
RaBitQDistanceComputerNotQ::RaBitQDistanceComputerNotQ() = default;
|
|
203
261
|
|
|
204
|
-
float
|
|
262
|
+
float RaBitQDistanceComputerNotQ::distance_to_code_1bit(const uint8_t* code) {
|
|
205
263
|
FAISS_ASSERT(code != nullptr);
|
|
206
264
|
FAISS_ASSERT(
|
|
207
265
|
(metric_type == MetricType::METRIC_L2 ||
|
|
208
266
|
metric_type == MetricType::METRIC_INNER_PRODUCT));
|
|
267
|
+
FAISS_ASSERT(rotated_q.size() == d);
|
|
209
268
|
|
|
210
269
|
// split the code into parts
|
|
211
270
|
const uint8_t* binary_data = code;
|
|
212
|
-
|
|
213
|
-
|
|
271
|
+
|
|
272
|
+
// Cast to appropriate type based on nb_bits
|
|
273
|
+
// For 1-bit: use SignBitFactors (8 bytes)
|
|
274
|
+
// For multi-bit: use SignBitFactorsWithError (12 bytes) which includes
|
|
275
|
+
// f_error
|
|
276
|
+
size_t ex_bits = nb_bits - 1;
|
|
277
|
+
const SignBitFactors* base_fac = (ex_bits == 0)
|
|
278
|
+
? reinterpret_cast<const SignBitFactors*>(code + (d + 7) / 8)
|
|
279
|
+
: reinterpret_cast<const SignBitFactorsWithError*>(
|
|
280
|
+
code + (d + 7) / 8);
|
|
214
281
|
|
|
215
282
|
// this is the baseline code
|
|
216
283
|
//
|
|
@@ -219,48 +286,70 @@ float RaBitDistanceComputerNotQ::distance_to_code(const uint8_t* code) {
|
|
|
219
286
|
// It was a willful decision (after the discussion) to not to pre-cache
|
|
220
287
|
// the sum of all bits, just in order to reduce the overhead per vector.
|
|
221
288
|
uint64_t sum_q = 0;
|
|
222
|
-
for (size_t i = 0; i < d; i++) {
|
|
223
|
-
// extract i-th bit
|
|
224
|
-
const uint8_t masker = (1 << (i % 8));
|
|
225
|
-
const bool b_bit = ((binary_data[i / 8] & masker) == masker);
|
|
226
289
|
|
|
290
|
+
for (size_t i = 0; i < d; i++) {
|
|
291
|
+
// Extract i-th bit
|
|
292
|
+
bool bit = rabitq_utils::extract_bit_standard(binary_data, i);
|
|
227
293
|
// accumulate dp
|
|
228
|
-
dot_qo +=
|
|
294
|
+
dot_qo += bit ? rotated_q[i] : 0;
|
|
229
295
|
// accumulate sum-of-bits
|
|
230
|
-
sum_q +=
|
|
296
|
+
sum_q += bit ? 1 : 0;
|
|
231
297
|
}
|
|
232
298
|
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
// normalizer coefficients
|
|
237
|
-
final_dot += query_fac.c2 * sum_q;
|
|
238
|
-
// normalizer coefficients
|
|
239
|
-
final_dot -= query_fac.c34;
|
|
240
|
-
|
|
241
|
-
// this is ||or - c||^2 - (IP ? ||or||^2 : 0)
|
|
242
|
-
const float or_c_l2sqr = fac->or_minus_c_l2sqr;
|
|
299
|
+
// Apply query factors
|
|
300
|
+
float final_dot =
|
|
301
|
+
query_fac.c1 * dot_qo + query_fac.c2 * sum_q - query_fac.c34;
|
|
243
302
|
|
|
244
303
|
// pre_dist = ||or - c||^2 + ||qr - c||^2 -
|
|
245
304
|
// 2 * ||or - c|| * ||qr - c|| * <q,o> - (IP ? ||or||^2 : 0)
|
|
246
|
-
|
|
247
|
-
2 *
|
|
305
|
+
float pre_dist = base_fac->or_minus_c_l2sqr + query_fac.qr_to_c_L2sqr -
|
|
306
|
+
2 * base_fac->dp_multiplier * final_dot;
|
|
248
307
|
|
|
249
308
|
if (metric_type == MetricType::METRIC_L2) {
|
|
250
309
|
// ||or - q||^ 2
|
|
251
310
|
return pre_dist;
|
|
252
311
|
} else {
|
|
253
312
|
// metric == MetricType::METRIC_INNER_PRODUCT
|
|
313
|
+
return -0.5f * (pre_dist - query_fac.qr_norm_L2sqr);
|
|
314
|
+
}
|
|
315
|
+
}
|
|
254
316
|
|
|
255
|
-
|
|
256
|
-
|
|
317
|
+
float RaBitQDistanceComputerNotQ::distance_to_code_full(const uint8_t* code) {
|
|
318
|
+
FAISS_ASSERT(code != nullptr);
|
|
319
|
+
FAISS_ASSERT(
|
|
320
|
+
(metric_type == MetricType::METRIC_L2 ||
|
|
321
|
+
metric_type == MetricType::METRIC_INNER_PRODUCT));
|
|
322
|
+
FAISS_ASSERT(rotated_q.size() == d);
|
|
257
323
|
|
|
258
|
-
|
|
259
|
-
|
|
324
|
+
size_t ex_bits = nb_bits - 1;
|
|
325
|
+
|
|
326
|
+
if (ex_bits == 0) {
|
|
327
|
+
// No ex-bits, just return 1-bit distance
|
|
328
|
+
return distance_to_code_1bit(code);
|
|
260
329
|
}
|
|
330
|
+
|
|
331
|
+
// Extract pointers to code sections
|
|
332
|
+
const uint8_t* binary_data = code;
|
|
333
|
+
size_t offset = (d + 7) / 8 + sizeof(SignBitFactorsWithError);
|
|
334
|
+
const uint8_t* ex_code = code + offset;
|
|
335
|
+
const ExtraBitsFactors* ex_fac = reinterpret_cast<const ExtraBitsFactors*>(
|
|
336
|
+
ex_code + (d * ex_bits + 7) / 8);
|
|
337
|
+
|
|
338
|
+
// Call shared utility directly with rotated_q pointer
|
|
339
|
+
return rabitq_utils::compute_full_multibit_distance(
|
|
340
|
+
binary_data,
|
|
341
|
+
ex_code,
|
|
342
|
+
*ex_fac,
|
|
343
|
+
rotated_q.data(),
|
|
344
|
+
query_fac.qr_to_c_L2sqr,
|
|
345
|
+
query_fac.qr_norm_L2sqr,
|
|
346
|
+
d,
|
|
347
|
+
ex_bits,
|
|
348
|
+
metric_type);
|
|
261
349
|
}
|
|
262
350
|
|
|
263
|
-
void
|
|
351
|
+
void RaBitQDistanceComputerNotQ::set_query(const float* x) {
|
|
352
|
+
q = x;
|
|
264
353
|
FAISS_ASSERT(x != nullptr);
|
|
265
354
|
FAISS_ASSERT(
|
|
266
355
|
(metric_type == MetricType::METRIC_L2 ||
|
|
@@ -279,6 +368,10 @@ void RaBitDistanceComputerNotQ::set_query(const float* x) {
|
|
|
279
368
|
rotated_q[i] = x[i] - ((centroid == nullptr) ? 0 : centroid[i]);
|
|
280
369
|
}
|
|
281
370
|
|
|
371
|
+
// Compute g_error (query norm for lower bound computation)
|
|
372
|
+
// g_error = ||qr - c|| (L2 norm of rotated query)
|
|
373
|
+
g_error = std::sqrt(query_fac.qr_to_c_L2sqr);
|
|
374
|
+
|
|
282
375
|
// compute some numbers
|
|
283
376
|
const float inv_d = (d == 0) ? 1.0f : (1.0f / std::sqrt((float)d));
|
|
284
377
|
|
|
@@ -299,8 +392,10 @@ void RaBitDistanceComputerNotQ::set_query(const float* x) {
|
|
|
299
392
|
}
|
|
300
393
|
|
|
301
394
|
//
|
|
302
|
-
struct
|
|
395
|
+
struct RaBitQDistanceComputerQ : RaBitQDistanceComputer {
|
|
303
396
|
// the rotated and quantized query (qr - c)
|
|
397
|
+
std::vector<float> rotated_q;
|
|
398
|
+
// the rotated and quantized query (qr - c) for fast 1-bit computation
|
|
304
399
|
std::vector<uint8_t> rotated_qq;
|
|
305
400
|
// we're using the proposed relayout-ed scheme from 3.3 that allows
|
|
306
401
|
// using popcounts for computing the distance.
|
|
@@ -310,149 +405,138 @@ struct RaBitDistanceComputerQ : RaBitDistanceComputer {
|
|
|
310
405
|
|
|
311
406
|
// the number of bits for SQ quantization of the query (qb > 0)
|
|
312
407
|
uint8_t qb = 8;
|
|
408
|
+
bool centered = false;
|
|
313
409
|
// the smallest value divisible by 8 that is not smaller than dim
|
|
314
410
|
size_t popcount_aligned_dim = 0;
|
|
315
411
|
|
|
316
|
-
|
|
412
|
+
RaBitQDistanceComputerQ();
|
|
317
413
|
|
|
318
|
-
|
|
414
|
+
// Compute distance using only 1-bit codes (fast)
|
|
415
|
+
float distance_to_code_1bit(const uint8_t* code) override;
|
|
416
|
+
|
|
417
|
+
// Compute full distance using 1-bit + ex-bits (accurate)
|
|
418
|
+
float distance_to_code_full(const uint8_t* code) override;
|
|
319
419
|
|
|
320
420
|
void set_query(const float* x) override;
|
|
321
421
|
};
|
|
322
422
|
|
|
323
|
-
|
|
423
|
+
RaBitQDistanceComputerQ::RaBitQDistanceComputerQ() = default;
|
|
324
424
|
|
|
325
|
-
float
|
|
425
|
+
float RaBitQDistanceComputerQ::distance_to_code_1bit(const uint8_t* code) {
|
|
326
426
|
FAISS_ASSERT(code != nullptr);
|
|
327
427
|
FAISS_ASSERT(
|
|
328
428
|
(metric_type == MetricType::METRIC_L2 ||
|
|
329
429
|
metric_type == MetricType::METRIC_INNER_PRODUCT));
|
|
330
430
|
|
|
331
431
|
// split the code into parts
|
|
432
|
+
size_t size = (d + 7) / 8;
|
|
332
433
|
const uint8_t* binary_data = code;
|
|
333
|
-
const FactorsData* fac =
|
|
334
|
-
reinterpret_cast<const FactorsData*>(code + (d + 7) / 8);
|
|
335
|
-
|
|
336
|
-
// // this is the baseline code
|
|
337
|
-
// //
|
|
338
|
-
// // compute <q,o> using integers
|
|
339
|
-
// size_t dot_qo = 0;
|
|
340
|
-
// for (size_t i = 0; i < d; i++) {
|
|
341
|
-
// // extract i-th bit
|
|
342
|
-
// const uint8_t masker = (1 << (i % 8));
|
|
343
|
-
// const uint8_t bit = ((binary_data[i / 8] & masker) == masker) ? 1 :
|
|
344
|
-
// 0;
|
|
345
|
-
//
|
|
346
|
-
// // accumulate dp
|
|
347
|
-
// dot_qo += bit * rotated_qq[i];
|
|
348
|
-
// }
|
|
349
434
|
|
|
350
|
-
//
|
|
351
|
-
|
|
352
|
-
|
|
435
|
+
// Cast to appropriate type based on nb_bits
|
|
436
|
+
// For 1-bit: use SignBitFactors (8 bytes)
|
|
437
|
+
// For multi-bit: use SignBitFactorsWithError (12 bytes) which includes
|
|
438
|
+
// f_error
|
|
439
|
+
size_t ex_bits = nb_bits - 1;
|
|
440
|
+
const SignBitFactors* base_fac = (ex_bits == 0)
|
|
441
|
+
? reinterpret_cast<const SignBitFactors*>(code + size)
|
|
442
|
+
: reinterpret_cast<const SignBitFactorsWithError*>(code + size);
|
|
353
443
|
|
|
354
|
-
//
|
|
355
|
-
float
|
|
356
|
-
|
|
357
|
-
|
|
358
|
-
|
|
359
|
-
|
|
360
|
-
|
|
361
|
-
|
|
444
|
+
// this is ||or - c||^2 - (IP ? ||or||^2 : 0)
|
|
445
|
+
float final_dot = 0;
|
|
446
|
+
if (centered) {
|
|
447
|
+
int64_t int_dot = ((1 << qb) - 1) * d;
|
|
448
|
+
// See RaBitDistanceComputerNotQ::distance_to_code() for baseline code.
|
|
449
|
+
int_dot -= 2 *
|
|
450
|
+
rabitq::bitwise_xor_dot_product(
|
|
451
|
+
rearranged_rotated_qq.data(), binary_data, size, qb);
|
|
452
|
+
final_dot += int_dot * query_fac.int_dot_scale;
|
|
453
|
+
} else {
|
|
454
|
+
auto dot_qo = rabitq::bitwise_and_dot_product(
|
|
455
|
+
rearranged_rotated_qq.data(), binary_data, size, qb);
|
|
456
|
+
// It was a willful decision (after the discussion) to not to pre-cache
|
|
457
|
+
// the sum of all bits, just in order to reduce the overhead per vector.
|
|
362
458
|
// process 64-bit popcounts
|
|
363
|
-
|
|
364
|
-
|
|
365
|
-
|
|
366
|
-
|
|
367
|
-
|
|
368
|
-
//
|
|
369
|
-
|
|
370
|
-
const auto yv = *(binary_data + i);
|
|
371
|
-
sum_q += __builtin_popcount(yv);
|
|
372
|
-
}
|
|
459
|
+
auto sum_q = rabitq::popcount(binary_data, size);
|
|
460
|
+
// dot-product itself
|
|
461
|
+
final_dot += query_fac.c1 * dot_qo;
|
|
462
|
+
// normalizer coefficients
|
|
463
|
+
final_dot += query_fac.c2 * sum_q;
|
|
464
|
+
// normalizer coefficients
|
|
465
|
+
final_dot -= query_fac.c34;
|
|
373
466
|
}
|
|
374
467
|
|
|
375
|
-
float final_dot = 0;
|
|
376
|
-
// dot-product itself
|
|
377
|
-
final_dot += query_fac.c1 * dot_qo;
|
|
378
|
-
// normalizer coefficients
|
|
379
|
-
final_dot += query_fac.c2 * sum_q;
|
|
380
|
-
// normalizer coefficients
|
|
381
|
-
final_dot -= query_fac.c34;
|
|
382
|
-
|
|
383
|
-
// this is ||or - c||^2 - (IP ? ||or||^2 : 0)
|
|
384
|
-
const float or_c_l2sqr = fac->or_minus_c_l2sqr;
|
|
385
|
-
|
|
386
468
|
// pre_dist = ||or - c||^2 + ||qr - c||^2 -
|
|
387
469
|
// 2 * ||or - c|| * ||qr - c|| * <q,o> - (IP ? ||or||^2 : 0)
|
|
388
|
-
const float pre_dist =
|
|
389
|
-
2 *
|
|
470
|
+
const float pre_dist = base_fac->or_minus_c_l2sqr +
|
|
471
|
+
query_fac.qr_to_c_L2sqr - 2 * base_fac->dp_multiplier * final_dot;
|
|
390
472
|
|
|
391
473
|
if (metric_type == MetricType::METRIC_L2) {
|
|
392
474
|
// ||or - q||^ 2
|
|
393
475
|
return pre_dist;
|
|
394
476
|
} else {
|
|
395
477
|
// metric == MetricType::METRIC_INNER_PRODUCT
|
|
396
|
-
|
|
397
|
-
// this is ||q||^2
|
|
398
|
-
const float query_norm_sqr = query_fac.qr_norm_L2sqr;
|
|
399
|
-
|
|
400
478
|
// 2 * (or, q) = (||or - q||^2 - ||q||^2 - ||or||^2)
|
|
401
|
-
return -0.5f * (pre_dist -
|
|
479
|
+
return -0.5f * (pre_dist - query_fac.qr_norm_L2sqr);
|
|
402
480
|
}
|
|
403
481
|
}
|
|
404
482
|
|
|
405
|
-
|
|
406
|
-
FAISS_ASSERT(
|
|
483
|
+
float RaBitQDistanceComputerQ::distance_to_code_full(const uint8_t* code) {
|
|
484
|
+
FAISS_ASSERT(code != nullptr);
|
|
407
485
|
FAISS_ASSERT(
|
|
408
486
|
(metric_type == MetricType::METRIC_L2 ||
|
|
409
487
|
metric_type == MetricType::METRIC_INNER_PRODUCT));
|
|
488
|
+
FAISS_ASSERT(rotated_q.size() == d);
|
|
410
489
|
|
|
411
|
-
|
|
412
|
-
if (centroid != nullptr) {
|
|
413
|
-
query_fac.qr_to_c_L2sqr = fvec_L2sqr(x, centroid, d);
|
|
414
|
-
} else {
|
|
415
|
-
query_fac.qr_to_c_L2sqr = fvec_norm_L2sqr(x, d);
|
|
416
|
-
}
|
|
417
|
-
|
|
418
|
-
// allocate space
|
|
419
|
-
rotated_qq.resize(d);
|
|
490
|
+
size_t ex_bits = nb_bits - 1;
|
|
420
491
|
|
|
421
|
-
|
|
422
|
-
|
|
423
|
-
|
|
424
|
-
rotated_q[i] = x[i] - ((centroid == nullptr) ? 0 : centroid[i]);
|
|
492
|
+
if (ex_bits == 0) {
|
|
493
|
+
// No ex-bits, just return 1-bit distance
|
|
494
|
+
return distance_to_code_1bit(code);
|
|
425
495
|
}
|
|
426
496
|
|
|
427
|
-
//
|
|
428
|
-
const
|
|
429
|
-
|
|
430
|
-
|
|
431
|
-
|
|
432
|
-
|
|
433
|
-
|
|
434
|
-
|
|
435
|
-
|
|
436
|
-
|
|
437
|
-
|
|
438
|
-
|
|
439
|
-
|
|
497
|
+
// Extract pointers to code sections
|
|
498
|
+
const uint8_t* binary_data = code;
|
|
499
|
+
size_t offset = (d + 7) / 8 + sizeof(SignBitFactorsWithError);
|
|
500
|
+
const uint8_t* ex_code = code + offset;
|
|
501
|
+
const ExtraBitsFactors* ex_fac = reinterpret_cast<const ExtraBitsFactors*>(
|
|
502
|
+
ex_code + (d * ex_bits + 7) / 8);
|
|
503
|
+
|
|
504
|
+
// Call shared utility directly with rotated_q pointer
|
|
505
|
+
return rabitq_utils::compute_full_multibit_distance(
|
|
506
|
+
binary_data,
|
|
507
|
+
ex_code,
|
|
508
|
+
*ex_fac,
|
|
509
|
+
rotated_q.data(),
|
|
510
|
+
query_fac.qr_to_c_L2sqr,
|
|
511
|
+
query_fac.qr_norm_L2sqr,
|
|
512
|
+
d,
|
|
513
|
+
ex_bits,
|
|
514
|
+
metric_type);
|
|
515
|
+
}
|
|
440
516
|
|
|
441
|
-
|
|
442
|
-
|
|
517
|
+
// Use shared constant from RaBitQUtils
|
|
518
|
+
using rabitq_utils::Z_MAX_BY_QB;
|
|
443
519
|
|
|
444
|
-
|
|
445
|
-
|
|
446
|
-
|
|
520
|
+
void RaBitQDistanceComputerQ::set_query(const float* x) {
|
|
521
|
+
q = x;
|
|
522
|
+
FAISS_ASSERT(x != nullptr);
|
|
523
|
+
FAISS_ASSERT(
|
|
524
|
+
(metric_type == MetricType::METRIC_L2 ||
|
|
525
|
+
metric_type == MetricType::METRIC_INNER_PRODUCT));
|
|
526
|
+
FAISS_THROW_IF_NOT(qb <= 8);
|
|
527
|
+
FAISS_THROW_IF_NOT(qb > 0);
|
|
447
528
|
|
|
448
|
-
|
|
449
|
-
|
|
529
|
+
// Use shared utilities for core query factor computation
|
|
530
|
+
// rotated_q is populated directly by compute_query_factors as an output
|
|
531
|
+
// parameter
|
|
532
|
+
query_fac = rabitq_utils::compute_query_factors(
|
|
533
|
+
x, d, centroid, qb, centered, metric_type, rotated_q, rotated_qq);
|
|
450
534
|
|
|
451
|
-
|
|
452
|
-
|
|
453
|
-
|
|
535
|
+
// Compute g_error (query norm for lower bound computation)
|
|
536
|
+
// g_error = ||qr - c|| (L2 norm of rotated query)
|
|
537
|
+
g_error = std::sqrt(query_fac.qr_to_c_L2sqr);
|
|
454
538
|
|
|
455
|
-
//
|
|
539
|
+
// Rearrange the query vector for SIMD operations (RaBitQuantizer-specific)
|
|
456
540
|
popcount_aligned_dim = ((d + 7) / 8) * 8;
|
|
457
541
|
size_t offset = (d + 7) / 8;
|
|
458
542
|
|
|
@@ -466,33 +550,30 @@ void RaBitDistanceComputerQ::set_query(const float* x) {
|
|
|
466
550
|
bit ? (1 << (idim % 8)) : 0;
|
|
467
551
|
}
|
|
468
552
|
}
|
|
469
|
-
|
|
470
|
-
query_fac.c1 = 2 * delta * inv_d;
|
|
471
|
-
query_fac.c2 = 2 * v_min * inv_d;
|
|
472
|
-
query_fac.c34 = inv_d * (delta * sum_qq + d * v_min);
|
|
473
|
-
|
|
474
|
-
if (metric_type == MetricType::METRIC_INNER_PRODUCT) {
|
|
475
|
-
// precompute if needed
|
|
476
|
-
query_fac.qr_norm_L2sqr = fvec_norm_L2sqr(x, d);
|
|
477
|
-
}
|
|
478
553
|
}
|
|
479
554
|
|
|
555
|
+
} // anonymous namespace
|
|
556
|
+
|
|
480
557
|
FlatCodesDistanceComputer* RaBitQuantizer::get_distance_computer(
|
|
481
558
|
uint8_t qb,
|
|
482
|
-
const float* centroid_in
|
|
559
|
+
const float* centroid_in,
|
|
560
|
+
bool centered) const {
|
|
483
561
|
if (qb == 0) {
|
|
484
|
-
auto dc = std::make_unique<
|
|
562
|
+
auto dc = std::make_unique<RaBitQDistanceComputerNotQ>();
|
|
485
563
|
dc->metric_type = metric_type;
|
|
486
564
|
dc->d = d;
|
|
487
565
|
dc->centroid = centroid_in;
|
|
566
|
+
dc->nb_bits = nb_bits;
|
|
488
567
|
|
|
489
568
|
return dc.release();
|
|
490
569
|
} else {
|
|
491
|
-
auto dc = std::make_unique<
|
|
570
|
+
auto dc = std::make_unique<RaBitQDistanceComputerQ>();
|
|
492
571
|
dc->metric_type = metric_type;
|
|
493
572
|
dc->d = d;
|
|
494
573
|
dc->centroid = centroid_in;
|
|
495
574
|
dc->qb = qb;
|
|
575
|
+
dc->centered = centered;
|
|
576
|
+
dc->nb_bits = nb_bits;
|
|
496
577
|
|
|
497
578
|
return dc.release();
|
|
498
579
|
}
|