faiss 0.5.0 → 0.5.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +5 -0
- data/README.md +2 -0
- data/ext/faiss/index.cpp +8 -0
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/faiss/IVFlib.cpp +25 -49
- data/vendor/faiss/faiss/Index.cpp +11 -0
- data/vendor/faiss/faiss/Index.h +24 -1
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +1 -0
- data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +5 -1
- data/vendor/faiss/faiss/IndexFastScan.cpp +1 -1
- data/vendor/faiss/faiss/IndexFastScan.h +3 -8
- data/vendor/faiss/faiss/IndexFlat.cpp +374 -4
- data/vendor/faiss/faiss/IndexFlat.h +80 -0
- data/vendor/faiss/faiss/IndexHNSW.cpp +90 -1
- data/vendor/faiss/faiss/IndexHNSW.h +57 -1
- data/vendor/faiss/faiss/IndexIVFFlatPanorama.cpp +34 -149
- data/vendor/faiss/faiss/IndexIVFRaBitQ.cpp +86 -2
- data/vendor/faiss/faiss/IndexIVFRaBitQ.h +3 -1
- data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.cpp +293 -115
- data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.h +52 -16
- data/vendor/faiss/faiss/IndexPQ.cpp +4 -1
- data/vendor/faiss/faiss/IndexPreTransform.cpp +14 -0
- data/vendor/faiss/faiss/IndexPreTransform.h +9 -0
- data/vendor/faiss/faiss/IndexRaBitQ.cpp +96 -16
- data/vendor/faiss/faiss/IndexRaBitQ.h +5 -1
- data/vendor/faiss/faiss/IndexRaBitQFastScan.cpp +238 -93
- data/vendor/faiss/faiss/IndexRaBitQFastScan.h +35 -9
- data/vendor/faiss/faiss/IndexRefine.cpp +49 -0
- data/vendor/faiss/faiss/IndexRefine.h +17 -0
- data/vendor/faiss/faiss/clone_index.cpp +2 -0
- data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +3 -1
- data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +1 -1
- data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +1 -1
- data/vendor/faiss/faiss/impl/DistanceComputer.h +74 -3
- data/vendor/faiss/faiss/impl/HNSW.cpp +294 -15
- data/vendor/faiss/faiss/impl/HNSW.h +31 -2
- data/vendor/faiss/faiss/impl/IDSelector.h +3 -3
- data/vendor/faiss/faiss/impl/Panorama.cpp +193 -0
- data/vendor/faiss/faiss/impl/Panorama.h +204 -0
- data/vendor/faiss/faiss/impl/RaBitQStats.cpp +29 -0
- data/vendor/faiss/faiss/impl/RaBitQStats.h +56 -0
- data/vendor/faiss/faiss/impl/RaBitQUtils.cpp +54 -6
- data/vendor/faiss/faiss/impl/RaBitQUtils.h +183 -6
- data/vendor/faiss/faiss/impl/RaBitQuantizer.cpp +269 -84
- data/vendor/faiss/faiss/impl/RaBitQuantizer.h +71 -4
- data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.cpp +362 -0
- data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.h +112 -0
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +6 -9
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +1 -3
- data/vendor/faiss/faiss/impl/index_read.cpp +156 -12
- data/vendor/faiss/faiss/impl/index_write.cpp +142 -19
- data/vendor/faiss/faiss/impl/platform_macros.h +12 -0
- data/vendor/faiss/faiss/impl/svs_io.cpp +86 -0
- data/vendor/faiss/faiss/impl/svs_io.h +67 -0
- data/vendor/faiss/faiss/index_factory.cpp +182 -15
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +1 -1
- data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -1
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +18 -109
- data/vendor/faiss/faiss/invlists/InvertedLists.h +2 -18
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +1 -1
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +1 -1
- data/vendor/faiss/faiss/svs/IndexSVSFaissUtils.h +261 -0
- data/vendor/faiss/faiss/svs/IndexSVSFlat.cpp +117 -0
- data/vendor/faiss/faiss/svs/IndexSVSFlat.h +66 -0
- data/vendor/faiss/faiss/svs/IndexSVSVamana.cpp +245 -0
- data/vendor/faiss/faiss/svs/IndexSVSVamana.h +137 -0
- data/vendor/faiss/faiss/svs/IndexSVSVamanaLVQ.cpp +39 -0
- data/vendor/faiss/faiss/svs/IndexSVSVamanaLVQ.h +42 -0
- data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.cpp +149 -0
- data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.h +58 -0
- data/vendor/faiss/faiss/utils/distances.cpp +0 -3
- data/vendor/faiss/faiss/utils/utils.cpp +4 -0
- metadata +18 -1
|
@@ -9,6 +9,7 @@
|
|
|
9
9
|
|
|
10
10
|
#include <faiss/impl/FaissAssert.h>
|
|
11
11
|
#include <faiss/impl/RaBitQUtils.h>
|
|
12
|
+
#include <faiss/impl/RaBitQuantizerMultiBit.h>
|
|
12
13
|
#include <faiss/utils/distances.h>
|
|
13
14
|
#include <faiss/utils/rabitq_simd.h>
|
|
14
15
|
#include <algorithm>
|
|
@@ -20,15 +21,47 @@
|
|
|
20
21
|
namespace faiss {
|
|
21
22
|
|
|
22
23
|
// Import shared utilities from RaBitQUtils
|
|
23
|
-
using rabitq_utils::
|
|
24
|
+
using rabitq_utils::ExtraBitsFactors;
|
|
24
25
|
using rabitq_utils::QueryFactorsData;
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
26
|
+
using rabitq_utils::SignBitFactors;
|
|
27
|
+
using rabitq_utils::SignBitFactorsWithError;
|
|
28
|
+
|
|
29
|
+
RaBitQuantizer::RaBitQuantizer(size_t d, MetricType metric, size_t nb_bits)
|
|
30
|
+
: Quantizer(d, 0), // code_size will be set below
|
|
31
|
+
metric_type{metric},
|
|
32
|
+
nb_bits{nb_bits} {
|
|
33
|
+
// Validate nb_bits range
|
|
34
|
+
FAISS_THROW_IF_NOT(nb_bits >= 1 && nb_bits <= 9);
|
|
35
|
+
|
|
36
|
+
// Set code_size using compute_code_size
|
|
37
|
+
code_size = compute_code_size(d, nb_bits);
|
|
28
38
|
}
|
|
29
39
|
|
|
30
|
-
RaBitQuantizer::
|
|
31
|
-
|
|
40
|
+
size_t RaBitQuantizer::compute_code_size(size_t d, size_t num_bits) const {
|
|
41
|
+
// Validate inputs
|
|
42
|
+
FAISS_THROW_IF_NOT(num_bits >= 1 && num_bits <= 9);
|
|
43
|
+
|
|
44
|
+
size_t ex_bits = num_bits - 1;
|
|
45
|
+
|
|
46
|
+
// Base: 1-bit codes + base factors
|
|
47
|
+
// Layout for 1-bit: [binary_code: (d+7)/8 bytes][SignBitFactors: 8 bytes]
|
|
48
|
+
// base_factors = or_minus_c_l2sqr (4) + dp_multiplier (4)
|
|
49
|
+
// Layout for multi-bit: [binary_code: (d+7)/8
|
|
50
|
+
// bytes][SignBitFactorsWithError: 12 bytes]
|
|
51
|
+
// factors = or_minus_c_l2sqr (4) + dp_multiplier (4) + f_error (4)
|
|
52
|
+
size_t base_size = (d + 7) / 8 +
|
|
53
|
+
(ex_bits == 0 ? sizeof(SignBitFactors)
|
|
54
|
+
: sizeof(SignBitFactorsWithError));
|
|
55
|
+
|
|
56
|
+
// Extra: ex-bit codes + ex factors (only if ex_bits > 0)
|
|
57
|
+
// Layout: [ex_code: (d*ex_bits+7)/8 bytes][ex_factors: 8 bytes]
|
|
58
|
+
size_t ex_size = 0;
|
|
59
|
+
if (ex_bits > 0) {
|
|
60
|
+
ex_size = (d * ex_bits + 7) / 8 + sizeof(ExtraBitsFactors);
|
|
61
|
+
}
|
|
62
|
+
|
|
63
|
+
return base_size + ex_size;
|
|
64
|
+
}
|
|
32
65
|
|
|
33
66
|
void RaBitQuantizer::train(size_t n, const float* x) {
|
|
34
67
|
// does nothing
|
|
@@ -54,23 +87,49 @@ void RaBitQuantizer::compute_codes_core(
|
|
|
54
87
|
return;
|
|
55
88
|
}
|
|
56
89
|
|
|
57
|
-
|
|
90
|
+
const size_t ex_bits = nb_bits - 1;
|
|
91
|
+
|
|
92
|
+
// Compute codes
|
|
58
93
|
#pragma omp parallel for if (n > 1000)
|
|
59
94
|
for (int64_t i = 0; i < n; i++) {
|
|
60
|
-
//
|
|
95
|
+
// Pointer to this vector's code
|
|
61
96
|
uint8_t* code = codes + i * code_size;
|
|
62
|
-
FactorsData* fac = reinterpret_cast<FactorsData*>(code + (d + 7) / 8);
|
|
63
97
|
|
|
64
|
-
//
|
|
65
|
-
|
|
66
|
-
memset(code, 0, code_size);
|
|
67
|
-
}
|
|
98
|
+
// Clear code memory
|
|
99
|
+
memset(code, 0, code_size);
|
|
68
100
|
|
|
69
101
|
const float* x_row = x + i * d;
|
|
70
102
|
|
|
103
|
+
// Pointer arithmetic for code layout:
|
|
104
|
+
// For 1-bit: [binary_code: (d+7)/8 bytes][SignBitFactors: 8 bytes]
|
|
105
|
+
// For multi-bit: [binary_code: (d+7)/8 bytes][SignBitFactorsWithError:
|
|
106
|
+
// 12 bytes]
|
|
107
|
+
// [ex_code: (d*ex_bits+7)/8 bytes][ex_factors: 8 bytes]
|
|
108
|
+
uint8_t* binary_code = code;
|
|
109
|
+
|
|
110
|
+
// Step 1: Compute 1-bit quantization and base factors
|
|
111
|
+
// Store residual for potential ex-bits quantization
|
|
112
|
+
std::vector<float> residual(d);
|
|
113
|
+
|
|
71
114
|
// Use shared utilities for computing factors
|
|
72
|
-
|
|
73
|
-
|
|
115
|
+
SignBitFactorsWithError factors_data =
|
|
116
|
+
rabitq_utils::compute_vector_factors(
|
|
117
|
+
x_row, d, centroid_in, metric_type, ex_bits > 0);
|
|
118
|
+
|
|
119
|
+
// Write appropriate factors based on nb_bits
|
|
120
|
+
if (ex_bits == 0) {
|
|
121
|
+
// For 1-bit: write only SignBitFactors (8 bytes)
|
|
122
|
+
SignBitFactors* base_factors =
|
|
123
|
+
reinterpret_cast<SignBitFactors*>(code + (d + 7) / 8);
|
|
124
|
+
base_factors->or_minus_c_l2sqr = factors_data.or_minus_c_l2sqr;
|
|
125
|
+
base_factors->dp_multiplier = factors_data.dp_multiplier;
|
|
126
|
+
} else {
|
|
127
|
+
// For multi-bit: write full SignBitFactorsWithError (12 bytes)
|
|
128
|
+
SignBitFactorsWithError* full_factors =
|
|
129
|
+
reinterpret_cast<SignBitFactorsWithError*>(
|
|
130
|
+
code + (d + 7) / 8);
|
|
131
|
+
*full_factors = factors_data;
|
|
132
|
+
}
|
|
74
133
|
|
|
75
134
|
// Pack bits into standard RaBitQ format
|
|
76
135
|
for (size_t j = 0; j < d; j++) {
|
|
@@ -78,13 +137,35 @@ void RaBitQuantizer::compute_codes_core(
|
|
|
78
137
|
const float centroid_val =
|
|
79
138
|
(centroid_in == nullptr) ? 0.0f : centroid_in[j];
|
|
80
139
|
const float or_minus_c = x_val - centroid_val;
|
|
140
|
+
residual[j] = or_minus_c;
|
|
141
|
+
|
|
81
142
|
const bool xb = (or_minus_c > 0.0f);
|
|
82
143
|
|
|
83
|
-
//
|
|
84
|
-
if (
|
|
85
|
-
rabitq_utils::set_bit_standard(
|
|
144
|
+
// Store the 1-bit sign code
|
|
145
|
+
if (xb) {
|
|
146
|
+
rabitq_utils::set_bit_standard(binary_code, j);
|
|
86
147
|
}
|
|
87
148
|
}
|
|
149
|
+
|
|
150
|
+
// Step 2: Compute ex-bits quantization (if nb_bits > 1)
|
|
151
|
+
if (ex_bits > 0) {
|
|
152
|
+
// Pointer to ex-bit code section
|
|
153
|
+
uint8_t* ex_code =
|
|
154
|
+
code + (d + 7) / 8 + sizeof(SignBitFactorsWithError);
|
|
155
|
+
// Pointer to ex-factors section
|
|
156
|
+
ExtraBitsFactors* ex_factors = reinterpret_cast<ExtraBitsFactors*>(
|
|
157
|
+
ex_code + (d * ex_bits + 7) / 8);
|
|
158
|
+
|
|
159
|
+
// Quantize residual to ex-bits (pass centroid for IP metric)
|
|
160
|
+
rabitq_multibit::quantize_ex_bits(
|
|
161
|
+
residual.data(),
|
|
162
|
+
d,
|
|
163
|
+
nb_bits,
|
|
164
|
+
ex_code,
|
|
165
|
+
*ex_factors,
|
|
166
|
+
metric_type,
|
|
167
|
+
centroid_in);
|
|
168
|
+
}
|
|
88
169
|
}
|
|
89
170
|
}
|
|
90
171
|
|
|
@@ -101,6 +182,7 @@ void RaBitQuantizer::decode_core(
|
|
|
101
182
|
FAISS_ASSERT(x != nullptr);
|
|
102
183
|
|
|
103
184
|
const float inv_d_sqrt = (d == 0) ? 1.0f : (1.0f / std::sqrt((float)d));
|
|
185
|
+
const size_t ex_bits = nb_bits - 1;
|
|
104
186
|
|
|
105
187
|
#pragma omp parallel for if (n > 1000)
|
|
106
188
|
for (int64_t i = 0; i < n; i++) {
|
|
@@ -108,10 +190,19 @@ void RaBitQuantizer::decode_core(
|
|
|
108
190
|
|
|
109
191
|
// split the code into parts
|
|
110
192
|
const uint8_t* binary_data = code;
|
|
111
|
-
const FactorsData* fac =
|
|
112
|
-
reinterpret_cast<const FactorsData*>(code + (d + 7) / 8);
|
|
113
193
|
|
|
194
|
+
// Cast to appropriate type based on nb_bits
|
|
195
|
+
// For 1-bit: use SignBitFactors (8 bytes)
|
|
196
|
+
// For multi-bit: use SignBitFactorsWithError (12 bytes, but only first
|
|
197
|
+
// 8 bytes used for decode)
|
|
198
|
+
const SignBitFactors* fac = (ex_bits == 0)
|
|
199
|
+
? reinterpret_cast<const SignBitFactors*>(code + (d + 7) / 8)
|
|
200
|
+
: reinterpret_cast<const SignBitFactorsWithError*>(
|
|
201
|
+
code + (d + 7) / 8);
|
|
202
|
+
|
|
203
|
+
// this is the baseline code
|
|
114
204
|
//
|
|
205
|
+
// compute <q,o> using floats
|
|
115
206
|
for (size_t j = 0; j < d; j++) {
|
|
116
207
|
// extract i-th bit
|
|
117
208
|
const uint8_t masker = (1 << (j % 8));
|
|
@@ -124,51 +215,69 @@ void RaBitQuantizer::decode_core(
|
|
|
124
215
|
}
|
|
125
216
|
}
|
|
126
217
|
|
|
127
|
-
|
|
128
|
-
// dimensionality
|
|
129
|
-
size_t d = 0;
|
|
130
|
-
// a centroid to use
|
|
131
|
-
const float* centroid = nullptr;
|
|
218
|
+
// Implementation of RaBitQDistanceComputer (declared in header)
|
|
132
219
|
|
|
133
|
-
|
|
134
|
-
|
|
220
|
+
float RaBitQDistanceComputer::lower_bound_distance(const uint8_t* code) {
|
|
221
|
+
FAISS_ASSERT(code != nullptr);
|
|
135
222
|
|
|
136
|
-
|
|
223
|
+
// Compute estimated distance using 1-bit codes
|
|
224
|
+
float est_distance = distance_to_code_1bit(code);
|
|
137
225
|
|
|
138
|
-
|
|
139
|
-
|
|
226
|
+
// Extract f_error from the code (NOTE: f_error is only stored when nb_bits > 1; 1-bit codes carry only SignBitFactors)
|
|
227
|
+
size_t size = (d + 7) / 8;
|
|
228
|
+
const SignBitFactorsWithError* base_fac =
|
|
229
|
+
reinterpret_cast<const SignBitFactorsWithError*>(code + size);
|
|
230
|
+
float f_error = base_fac->f_error;
|
|
140
231
|
|
|
141
|
-
|
|
232
|
+
// Compute proper lower bound using RaBitQ error formula:
|
|
233
|
+
// lower_bound = est_distance - f_error * g_error
|
|
234
|
+
// This guarantees: lower_bound ≤ true_distance
|
|
235
|
+
float lower_bound = est_distance - (f_error * g_error);
|
|
142
236
|
|
|
143
|
-
|
|
144
|
-
|
|
237
|
+
// Distance cannot be negative
|
|
238
|
+
return std::max(0.0f, lower_bound);
|
|
145
239
|
}
|
|
146
240
|
|
|
147
|
-
|
|
241
|
+
namespace {
|
|
242
|
+
|
|
243
|
+
struct RaBitQDistanceComputerNotQ : RaBitQDistanceComputer {
|
|
148
244
|
// the rotated query (qr - c)
|
|
149
245
|
std::vector<float> rotated_q;
|
|
150
246
|
// some additional numbers for the query
|
|
151
247
|
QueryFactorsData query_fac;
|
|
152
248
|
|
|
153
|
-
|
|
249
|
+
RaBitQDistanceComputerNotQ();
|
|
154
250
|
|
|
155
|
-
|
|
251
|
+
// Compute distance using only 1-bit codes (fast)
|
|
252
|
+
float distance_to_code_1bit(const uint8_t* code) override;
|
|
253
|
+
|
|
254
|
+
// Compute full distance using 1-bit + ex-bits (accurate)
|
|
255
|
+
float distance_to_code_full(const uint8_t* code) override;
|
|
156
256
|
|
|
157
257
|
void set_query(const float* x) override;
|
|
158
258
|
};
|
|
159
259
|
|
|
160
|
-
|
|
260
|
+
RaBitQDistanceComputerNotQ::RaBitQDistanceComputerNotQ() = default;
|
|
161
261
|
|
|
162
|
-
float
|
|
262
|
+
float RaBitQDistanceComputerNotQ::distance_to_code_1bit(const uint8_t* code) {
|
|
163
263
|
FAISS_ASSERT(code != nullptr);
|
|
164
264
|
FAISS_ASSERT(
|
|
165
265
|
(metric_type == MetricType::METRIC_L2 ||
|
|
166
266
|
metric_type == MetricType::METRIC_INNER_PRODUCT));
|
|
267
|
+
FAISS_ASSERT(rotated_q.size() == d);
|
|
167
268
|
|
|
168
269
|
// split the code into parts
|
|
169
270
|
const uint8_t* binary_data = code;
|
|
170
|
-
|
|
171
|
-
|
|
271
|
+
|
|
272
|
+
// Cast to appropriate type based on nb_bits
|
|
273
|
+
// For 1-bit: use SignBitFactors (8 bytes)
|
|
274
|
+
// For multi-bit: use SignBitFactorsWithError (12 bytes) which includes
|
|
275
|
+
// f_error
|
|
276
|
+
size_t ex_bits = nb_bits - 1;
|
|
277
|
+
const SignBitFactors* base_fac = (ex_bits == 0)
|
|
278
|
+
? reinterpret_cast<const SignBitFactors*>(code + (d + 7) / 8)
|
|
279
|
+
: reinterpret_cast<const SignBitFactorsWithError*>(
|
|
280
|
+
code + (d + 7) / 8);
|
|
172
281
|
|
|
173
282
|
// this is the baseline code
|
|
174
283
|
//
|
|
@@ -177,48 +286,70 @@ float RaBitDistanceComputerNotQ::distance_to_code(const uint8_t* code) {
|
|
|
177
286
|
// It was a willful decision (after the discussion) to not to pre-cache
|
|
178
287
|
// the sum of all bits, just in order to reduce the overhead per vector.
|
|
179
288
|
uint64_t sum_q = 0;
|
|
180
|
-
for (size_t i = 0; i < d; i++) {
|
|
181
|
-
// extract i-th bit
|
|
182
|
-
const uint8_t masker = (1 << (i % 8));
|
|
183
|
-
const bool b_bit = ((binary_data[i / 8] & masker) == masker);
|
|
184
289
|
|
|
290
|
+
for (size_t i = 0; i < d; i++) {
|
|
291
|
+
// Extract i-th bit
|
|
292
|
+
bool bit = rabitq_utils::extract_bit_standard(binary_data, i);
|
|
185
293
|
// accumulate dp
|
|
186
|
-
dot_qo +=
|
|
294
|
+
dot_qo += bit ? rotated_q[i] : 0;
|
|
187
295
|
// accumulate sum-of-bits
|
|
188
|
-
sum_q +=
|
|
296
|
+
sum_q += bit ? 1 : 0;
|
|
189
297
|
}
|
|
190
298
|
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
// normalizer coefficients
|
|
195
|
-
final_dot += query_fac.c2 * sum_q;
|
|
196
|
-
// normalizer coefficients
|
|
197
|
-
final_dot -= query_fac.c34;
|
|
198
|
-
|
|
199
|
-
// this is ||or - c||^2 - (IP ? ||or||^2 : 0)
|
|
200
|
-
const float or_c_l2sqr = fac->or_minus_c_l2sqr;
|
|
299
|
+
// Apply query factors
|
|
300
|
+
float final_dot =
|
|
301
|
+
query_fac.c1 * dot_qo + query_fac.c2 * sum_q - query_fac.c34;
|
|
201
302
|
|
|
202
303
|
// pre_dist = ||or - c||^2 + ||qr - c||^2 -
|
|
203
304
|
// 2 * ||or - c|| * ||qr - c|| * <q,o> - (IP ? ||or||^2 : 0)
|
|
204
|
-
|
|
205
|
-
2 *
|
|
305
|
+
float pre_dist = base_fac->or_minus_c_l2sqr + query_fac.qr_to_c_L2sqr -
|
|
306
|
+
2 * base_fac->dp_multiplier * final_dot;
|
|
206
307
|
|
|
207
308
|
if (metric_type == MetricType::METRIC_L2) {
|
|
208
309
|
// ||or - q||^ 2
|
|
209
310
|
return pre_dist;
|
|
210
311
|
} else {
|
|
211
312
|
// metric == MetricType::METRIC_INNER_PRODUCT
|
|
313
|
+
return -0.5f * (pre_dist - query_fac.qr_norm_L2sqr);
|
|
314
|
+
}
|
|
315
|
+
}
|
|
212
316
|
|
|
213
|
-
|
|
214
|
-
|
|
317
|
+
float RaBitQDistanceComputerNotQ::distance_to_code_full(const uint8_t* code) {
|
|
318
|
+
FAISS_ASSERT(code != nullptr);
|
|
319
|
+
FAISS_ASSERT(
|
|
320
|
+
(metric_type == MetricType::METRIC_L2 ||
|
|
321
|
+
metric_type == MetricType::METRIC_INNER_PRODUCT));
|
|
322
|
+
FAISS_ASSERT(rotated_q.size() == d);
|
|
215
323
|
|
|
216
|
-
|
|
217
|
-
|
|
324
|
+
size_t ex_bits = nb_bits - 1;
|
|
325
|
+
|
|
326
|
+
if (ex_bits == 0) {
|
|
327
|
+
// No ex-bits, just return 1-bit distance
|
|
328
|
+
return distance_to_code_1bit(code);
|
|
218
329
|
}
|
|
330
|
+
|
|
331
|
+
// Extract pointers to code sections
|
|
332
|
+
const uint8_t* binary_data = code;
|
|
333
|
+
size_t offset = (d + 7) / 8 + sizeof(SignBitFactorsWithError);
|
|
334
|
+
const uint8_t* ex_code = code + offset;
|
|
335
|
+
const ExtraBitsFactors* ex_fac = reinterpret_cast<const ExtraBitsFactors*>(
|
|
336
|
+
ex_code + (d * ex_bits + 7) / 8);
|
|
337
|
+
|
|
338
|
+
// Call shared utility directly with rotated_q pointer
|
|
339
|
+
return rabitq_utils::compute_full_multibit_distance(
|
|
340
|
+
binary_data,
|
|
341
|
+
ex_code,
|
|
342
|
+
*ex_fac,
|
|
343
|
+
rotated_q.data(),
|
|
344
|
+
query_fac.qr_to_c_L2sqr,
|
|
345
|
+
query_fac.qr_norm_L2sqr,
|
|
346
|
+
d,
|
|
347
|
+
ex_bits,
|
|
348
|
+
metric_type);
|
|
219
349
|
}
|
|
220
350
|
|
|
221
|
-
void
|
|
351
|
+
void RaBitQDistanceComputerNotQ::set_query(const float* x) {
|
|
352
|
+
q = x;
|
|
222
353
|
FAISS_ASSERT(x != nullptr);
|
|
223
354
|
FAISS_ASSERT(
|
|
224
355
|
(metric_type == MetricType::METRIC_L2 ||
|
|
@@ -237,6 +368,10 @@ void RaBitDistanceComputerNotQ::set_query(const float* x) {
|
|
|
237
368
|
rotated_q[i] = x[i] - ((centroid == nullptr) ? 0 : centroid[i]);
|
|
238
369
|
}
|
|
239
370
|
|
|
371
|
+
// Compute g_error (query norm for lower bound computation)
|
|
372
|
+
// g_error = ||qr - c|| (L2 norm of rotated query)
|
|
373
|
+
g_error = std::sqrt(query_fac.qr_to_c_L2sqr);
|
|
374
|
+
|
|
240
375
|
// compute some numbers
|
|
241
376
|
const float inv_d = (d == 0) ? 1.0f : (1.0f / std::sqrt((float)d));
|
|
242
377
|
|
|
@@ -257,8 +392,10 @@ void RaBitDistanceComputerNotQ::set_query(const float* x) {
|
|
|
257
392
|
}
|
|
258
393
|
|
|
259
394
|
//
|
|
260
|
-
struct
|
|
395
|
+
struct RaBitQDistanceComputerQ : RaBitQDistanceComputer {
|
|
261
396
|
// the rotated (non-quantized) query (qr - c), used for full multi-bit distances
|
|
397
|
+
std::vector<float> rotated_q;
|
|
398
|
+
// the rotated and quantized query (qr - c) for fast 1-bit computation
|
|
262
399
|
std::vector<uint8_t> rotated_qq;
|
|
263
400
|
// we're using the proposed relayout-ed scheme from 3.3 that allows
|
|
264
401
|
// using popcounts for computing the distance.
|
|
@@ -272,16 +409,20 @@ struct RaBitDistanceComputerQ : RaBitDistanceComputer {
|
|
|
272
409
|
// the smallest value divisible by 8 that is not smaller than dim
|
|
273
410
|
size_t popcount_aligned_dim = 0;
|
|
274
411
|
|
|
275
|
-
|
|
412
|
+
RaBitQDistanceComputerQ();
|
|
276
413
|
|
|
277
|
-
|
|
414
|
+
// Compute distance using only 1-bit codes (fast)
|
|
415
|
+
float distance_to_code_1bit(const uint8_t* code) override;
|
|
416
|
+
|
|
417
|
+
// Compute full distance using 1-bit + ex-bits (accurate)
|
|
418
|
+
float distance_to_code_full(const uint8_t* code) override;
|
|
278
419
|
|
|
279
420
|
void set_query(const float* x) override;
|
|
280
421
|
};
|
|
281
422
|
|
|
282
|
-
|
|
423
|
+
RaBitQDistanceComputerQ::RaBitQDistanceComputerQ() = default;
|
|
283
424
|
|
|
284
|
-
float
|
|
425
|
+
float RaBitQDistanceComputerQ::distance_to_code_1bit(const uint8_t* code) {
|
|
285
426
|
FAISS_ASSERT(code != nullptr);
|
|
286
427
|
FAISS_ASSERT(
|
|
287
428
|
(metric_type == MetricType::METRIC_L2 ||
|
|
@@ -290,21 +431,28 @@ float RaBitDistanceComputerQ::distance_to_code(const uint8_t* code) {
|
|
|
290
431
|
// split the code into parts
|
|
291
432
|
size_t size = (d + 7) / 8;
|
|
292
433
|
const uint8_t* binary_data = code;
|
|
293
|
-
|
|
434
|
+
|
|
435
|
+
// Cast to appropriate type based on nb_bits
|
|
436
|
+
// For 1-bit: use SignBitFactors (8 bytes)
|
|
437
|
+
// For multi-bit: use SignBitFactorsWithError (12 bytes) which includes
|
|
438
|
+
// f_error
|
|
439
|
+
size_t ex_bits = nb_bits - 1;
|
|
440
|
+
const SignBitFactors* base_fac = (ex_bits == 0)
|
|
441
|
+
? reinterpret_cast<const SignBitFactors*>(code + size)
|
|
442
|
+
: reinterpret_cast<const SignBitFactorsWithError*>(code + size);
|
|
294
443
|
|
|
295
444
|
// this is ||or - c||^2 - (IP ? ||or||^2 : 0)
|
|
296
445
|
float final_dot = 0;
|
|
297
446
|
if (centered) {
|
|
298
447
|
int64_t int_dot = ((1 << qb) - 1) * d;
|
|
448
|
+
// See RaBitQDistanceComputerNotQ::distance_to_code_1bit() for baseline code.
|
|
299
449
|
int_dot -= 2 *
|
|
300
450
|
rabitq::bitwise_xor_dot_product(
|
|
301
451
|
rearranged_rotated_qq.data(), binary_data, size, qb);
|
|
302
452
|
final_dot += int_dot * query_fac.int_dot_scale;
|
|
303
453
|
} else {
|
|
304
|
-
// See RaBitDistanceComputerNotQ::distance_to_code() for baseline code.
|
|
305
454
|
auto dot_qo = rabitq::bitwise_and_dot_product(
|
|
306
455
|
rearranged_rotated_qq.data(), binary_data, size, qb);
|
|
307
|
-
|
|
308
456
|
// It was a willful decision (after the discussion) to not to pre-cache
|
|
309
457
|
// the sum of all bits, just in order to reduce the overhead per vector.
|
|
310
458
|
// process 64-bit popcounts
|
|
@@ -317,32 +465,60 @@ float RaBitDistanceComputerQ::distance_to_code(const uint8_t* code) {
|
|
|
317
465
|
final_dot -= query_fac.c34;
|
|
318
466
|
}
|
|
319
467
|
|
|
320
|
-
// this is ||or - c||^2 - (IP ? ||or||^2 : 0)
|
|
321
|
-
const float or_c_l2sqr = fac->or_minus_c_l2sqr;
|
|
322
|
-
|
|
323
468
|
// pre_dist = ||or - c||^2 + ||qr - c||^2 -
|
|
324
469
|
// 2 * ||or - c|| * ||qr - c|| * <q,o> - (IP ? ||or||^2 : 0)
|
|
325
|
-
const float pre_dist =
|
|
326
|
-
2 *
|
|
470
|
+
const float pre_dist = base_fac->or_minus_c_l2sqr +
|
|
471
|
+
query_fac.qr_to_c_L2sqr - 2 * base_fac->dp_multiplier * final_dot;
|
|
327
472
|
|
|
328
473
|
if (metric_type == MetricType::METRIC_L2) {
|
|
329
474
|
// ||or - q||^ 2
|
|
330
475
|
return pre_dist;
|
|
331
476
|
} else {
|
|
332
477
|
// metric == MetricType::METRIC_INNER_PRODUCT
|
|
478
|
+
// -2 * (or, q) = (||or - q||^2 - ||q||^2 - ||or||^2)
|
|
479
|
+
return -0.5f * (pre_dist - query_fac.qr_norm_L2sqr);
|
|
480
|
+
}
|
|
481
|
+
}
|
|
482
|
+
|
|
483
|
+
float RaBitQDistanceComputerQ::distance_to_code_full(const uint8_t* code) {
|
|
484
|
+
FAISS_ASSERT(code != nullptr);
|
|
485
|
+
FAISS_ASSERT(
|
|
486
|
+
(metric_type == MetricType::METRIC_L2 ||
|
|
487
|
+
metric_type == MetricType::METRIC_INNER_PRODUCT));
|
|
488
|
+
FAISS_ASSERT(rotated_q.size() == d);
|
|
333
489
|
|
|
334
|
-
|
|
335
|
-
const float query_norm_sqr = query_fac.qr_norm_L2sqr;
|
|
490
|
+
size_t ex_bits = nb_bits - 1;
|
|
336
491
|
|
|
337
|
-
|
|
338
|
-
|
|
492
|
+
if (ex_bits == 0) {
|
|
493
|
+
// No ex-bits, just return 1-bit distance
|
|
494
|
+
return distance_to_code_1bit(code);
|
|
339
495
|
}
|
|
496
|
+
|
|
497
|
+
// Extract pointers to code sections
|
|
498
|
+
const uint8_t* binary_data = code;
|
|
499
|
+
size_t offset = (d + 7) / 8 + sizeof(SignBitFactorsWithError);
|
|
500
|
+
const uint8_t* ex_code = code + offset;
|
|
501
|
+
const ExtraBitsFactors* ex_fac = reinterpret_cast<const ExtraBitsFactors*>(
|
|
502
|
+
ex_code + (d * ex_bits + 7) / 8);
|
|
503
|
+
|
|
504
|
+
// Call shared utility directly with rotated_q pointer
|
|
505
|
+
return rabitq_utils::compute_full_multibit_distance(
|
|
506
|
+
binary_data,
|
|
507
|
+
ex_code,
|
|
508
|
+
*ex_fac,
|
|
509
|
+
rotated_q.data(),
|
|
510
|
+
query_fac.qr_to_c_L2sqr,
|
|
511
|
+
query_fac.qr_norm_L2sqr,
|
|
512
|
+
d,
|
|
513
|
+
ex_bits,
|
|
514
|
+
metric_type);
|
|
340
515
|
}
|
|
341
516
|
|
|
342
517
|
// Use shared constant from RaBitQUtils
|
|
343
518
|
using rabitq_utils::Z_MAX_BY_QB;
|
|
344
519
|
|
|
345
|
-
void
|
|
520
|
+
void RaBitQDistanceComputerQ::set_query(const float* x) {
|
|
521
|
+
q = x;
|
|
346
522
|
FAISS_ASSERT(x != nullptr);
|
|
347
523
|
FAISS_ASSERT(
|
|
348
524
|
(metric_type == MetricType::METRIC_L2 ||
|
|
@@ -351,10 +527,15 @@ void RaBitDistanceComputerQ::set_query(const float* x) {
|
|
|
351
527
|
FAISS_THROW_IF_NOT(qb > 0);
|
|
352
528
|
|
|
353
529
|
// Use shared utilities for core query factor computation
|
|
354
|
-
|
|
530
|
+
// rotated_q is populated directly by compute_query_factors as an output
|
|
531
|
+
// parameter
|
|
355
532
|
query_fac = rabitq_utils::compute_query_factors(
|
|
356
533
|
x, d, centroid, qb, centered, metric_type, rotated_q, rotated_qq);
|
|
357
534
|
|
|
535
|
+
// Compute g_error (query norm for lower bound computation)
|
|
536
|
+
// g_error = ||qr - c|| (L2 norm of rotated query)
|
|
537
|
+
g_error = std::sqrt(query_fac.qr_to_c_L2sqr);
|
|
538
|
+
|
|
358
539
|
// Rearrange the query vector for SIMD operations (RaBitQuantizer-specific)
|
|
359
540
|
popcount_aligned_dim = ((d + 7) / 8) * 8;
|
|
360
541
|
size_t offset = (d + 7) / 8;
|
|
@@ -371,24 +552,28 @@ void RaBitDistanceComputerQ::set_query(const float* x) {
|
|
|
371
552
|
}
|
|
372
553
|
}
|
|
373
554
|
|
|
555
|
+
} // anonymous namespace
|
|
556
|
+
|
|
374
557
|
FlatCodesDistanceComputer* RaBitQuantizer::get_distance_computer(
|
|
375
558
|
uint8_t qb,
|
|
376
559
|
const float* centroid_in,
|
|
377
560
|
bool centered) const {
|
|
378
561
|
if (qb == 0) {
|
|
379
|
-
auto dc = std::make_unique<
|
|
562
|
+
auto dc = std::make_unique<RaBitQDistanceComputerNotQ>();
|
|
380
563
|
dc->metric_type = metric_type;
|
|
381
564
|
dc->d = d;
|
|
382
565
|
dc->centroid = centroid_in;
|
|
566
|
+
dc->nb_bits = nb_bits;
|
|
383
567
|
|
|
384
568
|
return dc.release();
|
|
385
569
|
} else {
|
|
386
|
-
auto dc = std::make_unique<
|
|
570
|
+
auto dc = std::make_unique<RaBitQDistanceComputerQ>();
|
|
387
571
|
dc->metric_type = metric_type;
|
|
388
572
|
dc->d = d;
|
|
389
573
|
dc->centroid = centroid_in;
|
|
390
574
|
dc->qb = qb;
|
|
391
575
|
dc->centered = centered;
|
|
576
|
+
dc->nb_bits = nb_bits;
|
|
392
577
|
|
|
393
578
|
return dc.release();
|
|
394
579
|
}
|
|
@@ -37,11 +37,28 @@ struct RaBitQuantizer : Quantizer {
|
|
|
37
37
|
// possible. Thus, a quantizer has to introduce a metric.
|
|
38
38
|
MetricType metric_type = MetricType::METRIC_L2;
|
|
39
39
|
|
|
40
|
-
|
|
40
|
+
// Number of bits per dimension (1-9). Default is 1 for backward
|
|
41
|
+
// compatibility.
|
|
42
|
+
// - nb_bits = 1: standard 1-bit RaBitQ (sign bits only)
|
|
43
|
+
// - nb_bits = 2-9: multi-bit RaBitQ (1 sign bit + ex_bits extra bits)
|
|
44
|
+
size_t nb_bits = 1;
|
|
45
|
+
|
|
46
|
+
RaBitQuantizer(
|
|
47
|
+
size_t d = 0,
|
|
48
|
+
MetricType metric = MetricType::METRIC_L2,
|
|
49
|
+
size_t nb_bits = 1);
|
|
50
|
+
|
|
51
|
+
// Compute code size based on dimensionality and number of bits
|
|
52
|
+
// Returns: size in bytes for one encoded vector
|
|
53
|
+
// - nb_bits=1: (d+7)/8 + 8 bytes (1-bit codes + base factors)
|
|
54
|
+
// - nb_bits>1: (d+7)/8 + 12 + (d*ex_bits+7)/8 + 8 bytes
|
|
55
|
+
// (1-bit codes + base factors incl. f_error + ex-bit codes + ex factors)
|
|
56
|
+
size_t compute_code_size(size_t d, size_t num_bits) const;
|
|
41
57
|
|
|
42
58
|
void train(size_t n, const float* x) override;
|
|
43
59
|
|
|
44
|
-
// every vector is expected to take (d + 7) / 8 + sizeof(
|
|
60
|
+
// every vector is expected to take (d + 7) / 8 + sizeof(SignBitFactors)
|
|
61
|
+
// bytes when nb_bits == 1; in general code_size bytes (see compute_code_size).
|
|
45
62
|
void compute_codes(const float* x, uint8_t* codes, size_t n) const override;
|
|
46
63
|
|
|
47
64
|
void compute_codes_core(
|
|
@@ -71,9 +88,59 @@ struct RaBitQuantizer : Quantizer {
|
|
|
71
88
|
// specify qb = 0 to get an DC that does not quantize a query
|
|
72
89
|
// specify qb > 0 to have SQ qb-bits query
|
|
73
90
|
FlatCodesDistanceComputer* get_distance_computer(
|
|
74
|
-
uint8_t qb,
|
|
75
|
-
const float*
|
|
91
|
+
uint8_t qb = 0,
|
|
92
|
+
const float* centroid = nullptr,
|
|
76
93
|
bool centered = false) const;
|
|
77
94
|
};
|
|
78
95
|
|
|
96
|
+
// RaBitQDistanceComputer: Base class for RaBitQ distance computers
|
|
97
|
+
//
|
|
98
|
+
// This intermediate class exists to provide a unified interface for
|
|
99
|
+
// two-stage multi-bit search. While most Faiss quantizers extend
|
|
100
|
+
// FlatCodesDistanceComputer directly, RaBitQ requires this additional
|
|
101
|
+
// abstraction layer due to its unique split encoding strategy
|
|
102
|
+
// (1 sign bit + magnitude bits) which enables:
|
|
103
|
+
//
|
|
104
|
+
// 1. distance_to_code_1bit() - Fast 1-bit filtering using only sign bits
|
|
105
|
+
// 2. distance_to_code_full() - Accurate multi-bit refinement using all bits
|
|
106
|
+
// 3. lower_bound_distance() - Error-bounded adaptive filtering
|
|
107
|
+
// (based on 1-bit estimator)
|
|
108
|
+
//
|
|
109
|
+
// These three methods implement RaBitQ's two-stage search pattern and are
|
|
110
|
+
// shared between the quantized (Q) and non-quantized (NotQ) query variants.
|
|
111
|
+
// The intermediate class allows two-stage search code to work with both
|
|
112
|
+
// variants via a single dynamic_cast.
|
|
113
|
+
struct RaBitQDistanceComputer : FlatCodesDistanceComputer {
|
|
114
|
+
size_t d = 0;
|
|
115
|
+
const float* centroid = nullptr;
|
|
116
|
+
MetricType metric_type = MetricType::METRIC_L2;
|
|
117
|
+
size_t nb_bits = 1;
|
|
118
|
+
|
|
119
|
+
// Query norm for lower bound computation (g_error in rabitq-library)
|
|
120
|
+
// This is the L2 norm of the rotated query: ||query - centroid||
|
|
121
|
+
float g_error = 0.0f;
|
|
122
|
+
|
|
123
|
+
float symmetric_dis(idx_t /*i*/, idx_t /*j*/) override {
|
|
124
|
+
// Not used for RaBitQ
|
|
125
|
+
FAISS_THROW_MSG("Not implemented");
|
|
126
|
+
}
|
|
127
|
+
|
|
128
|
+
// Compute 1-bit distance estimate (fast)
|
|
129
|
+
virtual float distance_to_code_1bit(const uint8_t* code) = 0;
|
|
130
|
+
|
|
131
|
+
// Compute full multi-bit distance (accurate)
|
|
132
|
+
virtual float distance_to_code_full(const uint8_t* code) = 0;
|
|
133
|
+
|
|
134
|
+
// Compute lower bound of distance using error bounds
|
|
135
|
+
// Guarantees: actual_distance >= lower_bound_distance
|
|
136
|
+
// Used for adaptive filtering in two-stage search
|
|
137
|
+
virtual float lower_bound_distance(const uint8_t* code);
|
|
138
|
+
|
|
139
|
+
// Override from FlatCodesDistanceComputer
|
|
140
|
+
// Delegates to distance_to_code_full() for multi-bit distance computation
|
|
141
|
+
float distance_to_code(const uint8_t* code) final {
|
|
142
|
+
return distance_to_code_full(code);
|
|
143
|
+
}
|
|
144
|
+
};
|
|
145
|
+
|
|
79
146
|
} // namespace faiss
|