faiss 0.5.0 → 0.5.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +5 -0
- data/README.md +2 -0
- data/ext/faiss/index.cpp +8 -0
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/faiss/IVFlib.cpp +25 -49
- data/vendor/faiss/faiss/Index.cpp +11 -0
- data/vendor/faiss/faiss/Index.h +24 -1
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +1 -0
- data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +5 -1
- data/vendor/faiss/faiss/IndexFastScan.cpp +1 -1
- data/vendor/faiss/faiss/IndexFastScan.h +3 -8
- data/vendor/faiss/faiss/IndexFlat.cpp +374 -4
- data/vendor/faiss/faiss/IndexFlat.h +80 -0
- data/vendor/faiss/faiss/IndexHNSW.cpp +90 -1
- data/vendor/faiss/faiss/IndexHNSW.h +57 -1
- data/vendor/faiss/faiss/IndexIVFFlatPanorama.cpp +34 -149
- data/vendor/faiss/faiss/IndexIVFRaBitQ.cpp +86 -2
- data/vendor/faiss/faiss/IndexIVFRaBitQ.h +3 -1
- data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.cpp +293 -115
- data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.h +52 -16
- data/vendor/faiss/faiss/IndexPQ.cpp +4 -1
- data/vendor/faiss/faiss/IndexPreTransform.cpp +14 -0
- data/vendor/faiss/faiss/IndexPreTransform.h +9 -0
- data/vendor/faiss/faiss/IndexRaBitQ.cpp +96 -16
- data/vendor/faiss/faiss/IndexRaBitQ.h +5 -1
- data/vendor/faiss/faiss/IndexRaBitQFastScan.cpp +238 -93
- data/vendor/faiss/faiss/IndexRaBitQFastScan.h +35 -9
- data/vendor/faiss/faiss/IndexRefine.cpp +49 -0
- data/vendor/faiss/faiss/IndexRefine.h +17 -0
- data/vendor/faiss/faiss/clone_index.cpp +2 -0
- data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +3 -1
- data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +1 -1
- data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +1 -1
- data/vendor/faiss/faiss/impl/DistanceComputer.h +74 -3
- data/vendor/faiss/faiss/impl/HNSW.cpp +294 -15
- data/vendor/faiss/faiss/impl/HNSW.h +31 -2
- data/vendor/faiss/faiss/impl/IDSelector.h +3 -3
- data/vendor/faiss/faiss/impl/Panorama.cpp +193 -0
- data/vendor/faiss/faiss/impl/Panorama.h +204 -0
- data/vendor/faiss/faiss/impl/RaBitQStats.cpp +29 -0
- data/vendor/faiss/faiss/impl/RaBitQStats.h +56 -0
- data/vendor/faiss/faiss/impl/RaBitQUtils.cpp +54 -6
- data/vendor/faiss/faiss/impl/RaBitQUtils.h +183 -6
- data/vendor/faiss/faiss/impl/RaBitQuantizer.cpp +269 -84
- data/vendor/faiss/faiss/impl/RaBitQuantizer.h +71 -4
- data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.cpp +362 -0
- data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.h +112 -0
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +6 -9
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +1 -3
- data/vendor/faiss/faiss/impl/index_read.cpp +156 -12
- data/vendor/faiss/faiss/impl/index_write.cpp +142 -19
- data/vendor/faiss/faiss/impl/platform_macros.h +12 -0
- data/vendor/faiss/faiss/impl/svs_io.cpp +86 -0
- data/vendor/faiss/faiss/impl/svs_io.h +67 -0
- data/vendor/faiss/faiss/index_factory.cpp +182 -15
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +1 -1
- data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -1
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +18 -109
- data/vendor/faiss/faiss/invlists/InvertedLists.h +2 -18
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +1 -1
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +1 -1
- data/vendor/faiss/faiss/svs/IndexSVSFaissUtils.h +261 -0
- data/vendor/faiss/faiss/svs/IndexSVSFlat.cpp +117 -0
- data/vendor/faiss/faiss/svs/IndexSVSFlat.h +66 -0
- data/vendor/faiss/faiss/svs/IndexSVSVamana.cpp +245 -0
- data/vendor/faiss/faiss/svs/IndexSVSVamana.h +137 -0
- data/vendor/faiss/faiss/svs/IndexSVSVamanaLVQ.cpp +39 -0
- data/vendor/faiss/faiss/svs/IndexSVSVamanaLVQ.h +42 -0
- data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.cpp +149 -0
- data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.h +58 -0
- data/vendor/faiss/faiss/utils/distances.cpp +0 -3
- data/vendor/faiss/faiss/utils/utils.cpp +4 -0
- metadata +18 -1
|
@@ -0,0 +1,362 @@
|
|
|
1
|
+
/*
|
|
2
|
+
* Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
3
|
+
*
|
|
4
|
+
* This source code is licensed under the MIT license found in the
|
|
5
|
+
* LICENSE file in the root directory of this source tree.
|
|
6
|
+
*/
|
|
7
|
+
|
|
8
|
+
// NOTE: Parts of this implementation are adapted from:
|
|
9
|
+
// RaBitQ-Library/include/rabitqlib/quantization/rabitq_impl.hpp
|
|
10
|
+
// https://github.com/VectorDB-NTU/RaBitQ-Library
|
|
11
|
+
|
|
12
|
+
#include <faiss/impl/FaissAssert.h>
|
|
13
|
+
#include <faiss/impl/RaBitQUtils.h>
|
|
14
|
+
#include <faiss/utils/distances.h>
|
|
15
|
+
|
|
16
|
+
#include <algorithm>
|
|
17
|
+
#include <cmath>
|
|
18
|
+
#include <cstring>
|
|
19
|
+
#include <queue>
|
|
20
|
+
#include <vector>
|
|
21
|
+
|
|
22
|
+
namespace faiss {
|
|
23
|
+
namespace rabitq_multibit {
|
|
24
|
+
|
|
25
|
+
using rabitq_utils::ExtraBitsFactors;
|
|
26
|
+
using rabitq_utils::SignBitFactorsWithError;
|
|
27
|
+
|
|
28
|
+
constexpr float kTightStart[9] =
|
|
29
|
+
{0.0f, 0.15f, 0.20f, 0.52f, 0.59f, 0.71f, 0.75f, 0.77f, 0.81f};
|
|
30
|
+
|
|
31
|
+
constexpr double kEps = 1e-5;
|
|
32
|
+
|
|
33
|
+
/**
|
|
34
|
+
* Compute optimal scaling factor for ex-bits quantization using priority
|
|
35
|
+
* queue-based search.
|
|
36
|
+
*
|
|
37
|
+
* This function finds the optimal scaling factor 't' that maximizes the
|
|
38
|
+
* inner product between the normalized quantized vector and the normalized
|
|
39
|
+
* absolute residual. The algorithm uses a priority queue to efficiently
|
|
40
|
+
* explore different quantization levels.
|
|
41
|
+
*
|
|
42
|
+
*
|
|
43
|
+
* @param o_abs Normalized absolute residual vector (must be positive, length
|
|
44
|
+
* d)
|
|
45
|
+
* @param d Dimensionality of the vector
|
|
46
|
+
* @param nb_bits Number of bits per dimension (2-9)
|
|
47
|
+
* @return Optimal scaling factor 't'
|
|
48
|
+
*/
|
|
49
|
+
float compute_optimal_scaling_factor(
|
|
50
|
+
const float* o_abs,
|
|
51
|
+
size_t d,
|
|
52
|
+
size_t nb_bits) {
|
|
53
|
+
const size_t ex_bits = nb_bits - 1;
|
|
54
|
+
FAISS_THROW_IF_NOT_MSG(
|
|
55
|
+
ex_bits >= 1 && ex_bits <= 8, "ex_bits must be in range [1, 8]");
|
|
56
|
+
|
|
57
|
+
const int kNEnum = 10;
|
|
58
|
+
const int max_code = (1 << ex_bits) - 1;
|
|
59
|
+
|
|
60
|
+
float max_o = *std::max_element(o_abs, o_abs + d);
|
|
61
|
+
|
|
62
|
+
// Determine search range [t_start, t_end]
|
|
63
|
+
float t_end = static_cast<float>(max_code + kNEnum) / max_o;
|
|
64
|
+
float t_start = t_end * kTightStart[ex_bits];
|
|
65
|
+
|
|
66
|
+
std::vector<float> inv_o_abs(d);
|
|
67
|
+
for (size_t i = 0; i < d; ++i) {
|
|
68
|
+
inv_o_abs[i] = 1.0f / o_abs[i];
|
|
69
|
+
}
|
|
70
|
+
|
|
71
|
+
std::vector<int> cur_o_bar(d);
|
|
72
|
+
float sqr_denominator = static_cast<float>(d) * 0.25f;
|
|
73
|
+
float numerator = 0.0f;
|
|
74
|
+
|
|
75
|
+
for (size_t i = 0; i < d; ++i) {
|
|
76
|
+
int cur = static_cast<int>((t_start * o_abs[i]) + kEps);
|
|
77
|
+
cur_o_bar[i] = cur;
|
|
78
|
+
sqr_denominator += static_cast<float>(cur * cur + cur);
|
|
79
|
+
numerator += (cur + 0.5f) * o_abs[i];
|
|
80
|
+
}
|
|
81
|
+
|
|
82
|
+
float inv_sqrt_denom = 1.0f / std::sqrt(sqr_denominator);
|
|
83
|
+
|
|
84
|
+
// Pair: (next_t, dimension_index)
|
|
85
|
+
// Maximum size is d (one entry per dimension), so reserve exactly d
|
|
86
|
+
std::vector<std::pair<float, size_t>> pq_storage;
|
|
87
|
+
pq_storage.reserve(d);
|
|
88
|
+
std::priority_queue<
|
|
89
|
+
std::pair<float, size_t>,
|
|
90
|
+
std::vector<std::pair<float, size_t>>,
|
|
91
|
+
std::greater<>>
|
|
92
|
+
next_t(std::greater<>(), std::move(pq_storage));
|
|
93
|
+
|
|
94
|
+
// Initialize queue with next quantization level for each dimension
|
|
95
|
+
for (size_t i = 0; i < d; ++i) {
|
|
96
|
+
float t_next = static_cast<float>(cur_o_bar[i] + 1) * inv_o_abs[i];
|
|
97
|
+
if (t_next < t_end) {
|
|
98
|
+
next_t.emplace(t_next, i);
|
|
99
|
+
}
|
|
100
|
+
}
|
|
101
|
+
|
|
102
|
+
float max_ip = 0.0f;
|
|
103
|
+
float t = 0.0f;
|
|
104
|
+
|
|
105
|
+
while (!next_t.empty()) {
|
|
106
|
+
float cur_t = next_t.top().first;
|
|
107
|
+
size_t update_id = next_t.top().second;
|
|
108
|
+
next_t.pop();
|
|
109
|
+
|
|
110
|
+
cur_o_bar[update_id]++;
|
|
111
|
+
int update_o_bar = cur_o_bar[update_id];
|
|
112
|
+
|
|
113
|
+
float delta = 2.0f * update_o_bar;
|
|
114
|
+
sqr_denominator += delta;
|
|
115
|
+
numerator += o_abs[update_id];
|
|
116
|
+
|
|
117
|
+
float old_denom = sqr_denominator - delta;
|
|
118
|
+
inv_sqrt_denom = inv_sqrt_denom *
|
|
119
|
+
(1.0f - 0.5f * delta / (old_denom + delta * 0.5f));
|
|
120
|
+
|
|
121
|
+
float cur_ip = numerator * inv_sqrt_denom;
|
|
122
|
+
|
|
123
|
+
if (cur_ip > max_ip) {
|
|
124
|
+
max_ip = cur_ip;
|
|
125
|
+
t = cur_t;
|
|
126
|
+
}
|
|
127
|
+
|
|
128
|
+
if (update_o_bar < max_code) {
|
|
129
|
+
float t_next =
|
|
130
|
+
static_cast<float>(update_o_bar + 1) * inv_o_abs[update_id];
|
|
131
|
+
if (t_next < t_end) {
|
|
132
|
+
next_t.emplace(t_next, update_id);
|
|
133
|
+
}
|
|
134
|
+
}
|
|
135
|
+
}
|
|
136
|
+
|
|
137
|
+
return t;
|
|
138
|
+
}
|
|
139
|
+
|
|
140
|
+
/**
|
|
141
|
+
* Pack multi-bit codes from integer array to byte array.
|
|
142
|
+
*
|
|
143
|
+
* @param tmp_code Integer codes (length d), each value in [0, 2^ex_bits - 1]
|
|
144
|
+
* @param ex_code Output packed byte array
|
|
145
|
+
* @param d Dimensionality
|
|
146
|
+
* @param nb_bits Number of bits per dimension (2-9)
|
|
147
|
+
*/
|
|
148
|
+
void pack_multibit_codes(
|
|
149
|
+
const int* tmp_code,
|
|
150
|
+
uint8_t* ex_code,
|
|
151
|
+
size_t d,
|
|
152
|
+
size_t nb_bits) {
|
|
153
|
+
const size_t ex_bits = nb_bits - 1;
|
|
154
|
+
FAISS_THROW_IF_NOT_MSG(
|
|
155
|
+
ex_bits >= 1 && ex_bits <= 8, "ex_bits must be in range [1, 8]");
|
|
156
|
+
|
|
157
|
+
size_t total_bits = d * ex_bits;
|
|
158
|
+
size_t output_size = (total_bits + 7) / 8;
|
|
159
|
+
memset(ex_code, 0, output_size);
|
|
160
|
+
|
|
161
|
+
size_t bit_pos = 0;
|
|
162
|
+
for (size_t i = 0; i < d; i++) {
|
|
163
|
+
int code_value = tmp_code[i];
|
|
164
|
+
|
|
165
|
+
for (size_t bit = 0; bit < ex_bits; bit++) {
|
|
166
|
+
size_t byte_idx = bit_pos / 8;
|
|
167
|
+
size_t bit_idx = bit_pos % 8;
|
|
168
|
+
|
|
169
|
+
if (code_value & (1 << bit)) {
|
|
170
|
+
ex_code[byte_idx] |= (1 << bit_idx);
|
|
171
|
+
}
|
|
172
|
+
|
|
173
|
+
bit_pos++;
|
|
174
|
+
}
|
|
175
|
+
}
|
|
176
|
+
}
|
|
177
|
+
|
|
178
|
+
/**
|
|
179
|
+
* Compute ex-bits factors for distance computation.
|
|
180
|
+
*
|
|
181
|
+
* @param residual Original residual vector (data - centroid)
|
|
182
|
+
* @param centroid Centroid vector (can be nullptr for zero centroid)
|
|
183
|
+
* @param tmp_code Quantized ex-bit codes (before packing, after bit flipping)
|
|
184
|
+
* @param d Dimensionality
|
|
185
|
+
* @param ex_bits Number of extra bits
|
|
186
|
+
* @param norm L2 norm of residual
|
|
187
|
+
* @param ipnorm Unnormalized inner product between quantized and normalized
|
|
188
|
+
* residual
|
|
189
|
+
* @param ex_factors Output factors structure
|
|
190
|
+
* @param metric_type Distance metric (L2 or Inner Product)
|
|
191
|
+
*/
|
|
192
|
+
void compute_ex_factors(
|
|
193
|
+
const float* residual,
|
|
194
|
+
const float* centroid,
|
|
195
|
+
const int* tmp_code,
|
|
196
|
+
size_t d,
|
|
197
|
+
size_t ex_bits,
|
|
198
|
+
float norm,
|
|
199
|
+
double ipnorm,
|
|
200
|
+
ExtraBitsFactors& ex_factors,
|
|
201
|
+
MetricType metric_type) {
|
|
202
|
+
FAISS_THROW_IF_NOT_MSG(
|
|
203
|
+
metric_type == MetricType::METRIC_L2 ||
|
|
204
|
+
metric_type == MetricType::METRIC_INNER_PRODUCT,
|
|
205
|
+
"Unsupported metric type");
|
|
206
|
+
|
|
207
|
+
// Compute ipnorm_inv = 1 / ipnorm
|
|
208
|
+
float ipnorm_inv = static_cast<float>(1.0 / ipnorm);
|
|
209
|
+
if (!std::isnormal(ipnorm_inv)) {
|
|
210
|
+
ipnorm_inv = 1.0f;
|
|
211
|
+
}
|
|
212
|
+
|
|
213
|
+
// Reconstruct xu_cb from total_code
|
|
214
|
+
// total_code was formed from: total_code[i] = (sign << ex_bits) +
|
|
215
|
+
// ex_code[i] Reconstruction: xu_cb[i] = total_code[i] + cb
|
|
216
|
+
const float cb = -(static_cast<float>(1 << ex_bits) - 0.5f);
|
|
217
|
+
std::vector<float> xu_cb(d);
|
|
218
|
+
for (size_t i = 0; i < d; i++) {
|
|
219
|
+
xu_cb[i] = static_cast<float>(tmp_code[i]) + cb;
|
|
220
|
+
}
|
|
221
|
+
|
|
222
|
+
// Compute inner products needed for factors
|
|
223
|
+
float l2_sqr = norm * norm;
|
|
224
|
+
float ip_resi_xucb = fvec_inner_product(residual, xu_cb.data(), d);
|
|
225
|
+
|
|
226
|
+
// Compute factors
|
|
227
|
+
if (metric_type == MetricType::METRIC_L2) {
|
|
228
|
+
// For L2, no centroid correction needed in IVF setting
|
|
229
|
+
// because residual = x - centroid, distance computed in residual space
|
|
230
|
+
ex_factors.f_add_ex = l2_sqr;
|
|
231
|
+
ex_factors.f_rescale_ex = ipnorm_inv * -2.0f * norm;
|
|
232
|
+
} else {
|
|
233
|
+
// For IP, centroid correction is needed
|
|
234
|
+
float ip_resi_cent = 0;
|
|
235
|
+
if (centroid != nullptr) {
|
|
236
|
+
ip_resi_cent = fvec_inner_product(residual, centroid, d);
|
|
237
|
+
}
|
|
238
|
+
|
|
239
|
+
float ip_cent_xucb = 0;
|
|
240
|
+
if (centroid != nullptr) {
|
|
241
|
+
ip_cent_xucb = fvec_inner_product(centroid, xu_cb.data(), d);
|
|
242
|
+
}
|
|
243
|
+
|
|
244
|
+
// When ip_resi_xucb is zero, the correction term should be zero
|
|
245
|
+
float correction_term = 0.0f;
|
|
246
|
+
if (ip_resi_xucb != 0.0f) {
|
|
247
|
+
correction_term = l2_sqr * ip_cent_xucb / ip_resi_xucb;
|
|
248
|
+
}
|
|
249
|
+
|
|
250
|
+
ex_factors.f_add_ex = 1 - ip_resi_cent + correction_term;
|
|
251
|
+
ex_factors.f_rescale_ex = ipnorm_inv * -norm;
|
|
252
|
+
}
|
|
253
|
+
}
|
|
254
|
+
|
|
255
|
+
/**
|
|
256
|
+
* Quantize residual vector to ex-bits.
|
|
257
|
+
*
|
|
258
|
+
* This is the main quantization function that:
|
|
259
|
+
* 1. Normalizes the residual
|
|
260
|
+
* 2. Takes absolute value
|
|
261
|
+
* 3. Finds optimal scaling factor
|
|
262
|
+
* 4. Quantizes to ex_bits
|
|
263
|
+
* 5. Handles negative dimensions by flipping bits
|
|
264
|
+
* 6. Packs codes into byte array
|
|
265
|
+
* 7. Computes factors for distance computation
|
|
266
|
+
*
|
|
267
|
+
* @param residual Input residual vector (data - centroid), length d
|
|
268
|
+
* @param d Dimensionality
|
|
269
|
+
* @param nb_bits Number of bits per dimension (2-9)
|
|
270
|
+
* @param ex_code Output packed ex-bit codes
|
|
271
|
+
* @param ex_factors Output ex-bits factors
|
|
272
|
+
* @param metric_type Distance metric (L2 or Inner Product)
|
|
273
|
+
* @param centroid Optional centroid vector (needed for IP metric)
|
|
274
|
+
*/
|
|
275
|
+
void quantize_ex_bits(
|
|
276
|
+
const float* residual,
|
|
277
|
+
size_t d,
|
|
278
|
+
size_t nb_bits,
|
|
279
|
+
uint8_t* ex_code,
|
|
280
|
+
ExtraBitsFactors& ex_factors,
|
|
281
|
+
MetricType metric_type,
|
|
282
|
+
const float* centroid) {
|
|
283
|
+
const size_t ex_bits = nb_bits - 1;
|
|
284
|
+
FAISS_THROW_IF_NOT_MSG(
|
|
285
|
+
ex_bits >= 1 && ex_bits <= 8, "ex_bits must be in range [1, 8]");
|
|
286
|
+
FAISS_THROW_IF_NOT_MSG(residual != nullptr, "residual cannot be null");
|
|
287
|
+
FAISS_THROW_IF_NOT_MSG(ex_code != nullptr, "ex_code cannot be null");
|
|
288
|
+
|
|
289
|
+
// Step 1: Compute L2 norm of residual
|
|
290
|
+
float norm_sqr = fvec_norm_L2sqr(residual, d);
|
|
291
|
+
float norm = std::sqrt(norm_sqr);
|
|
292
|
+
|
|
293
|
+
// Handle degenerate case
|
|
294
|
+
if (norm < 1e-10f) {
|
|
295
|
+
size_t code_size = (d * ex_bits + 7) / 8;
|
|
296
|
+
memset(ex_code, 0, code_size);
|
|
297
|
+
ex_factors.f_add_ex = 0.0f;
|
|
298
|
+
ex_factors.f_rescale_ex = 1.0f;
|
|
299
|
+
return;
|
|
300
|
+
}
|
|
301
|
+
|
|
302
|
+
// Step 2: Normalize residual
|
|
303
|
+
std::vector<float> normalized_residual(d);
|
|
304
|
+
for (size_t i = 0; i < d; i++) {
|
|
305
|
+
normalized_residual[i] = residual[i] / norm;
|
|
306
|
+
}
|
|
307
|
+
|
|
308
|
+
// Step 3: Take absolute value
|
|
309
|
+
std::vector<float> o_abs(d);
|
|
310
|
+
for (size_t i = 0; i < d; i++) {
|
|
311
|
+
o_abs[i] = std::abs(normalized_residual[i]);
|
|
312
|
+
}
|
|
313
|
+
|
|
314
|
+
// Step 4: Find optimal scaling factor
|
|
315
|
+
float t = compute_optimal_scaling_factor(o_abs.data(), d, nb_bits);
|
|
316
|
+
|
|
317
|
+
// Step 5: Quantize to ex_bits
|
|
318
|
+
std::vector<int> tmp_code(d);
|
|
319
|
+
double ipnorm = 0;
|
|
320
|
+
int max_code = (1 << ex_bits) - 1;
|
|
321
|
+
|
|
322
|
+
for (size_t i = 0; i < d; i++) {
|
|
323
|
+
tmp_code[i] = std::min(static_cast<int>(t * o_abs[i] + kEps), max_code);
|
|
324
|
+
// Compute unnormalized inner product
|
|
325
|
+
ipnorm += (tmp_code[i] + 0.5) * o_abs[i];
|
|
326
|
+
}
|
|
327
|
+
|
|
328
|
+
// Step 6: Handle negative dimensions (flip bits)
|
|
329
|
+
// For negative residuals, flip all bits: code' = ~code & max_code
|
|
330
|
+
for (size_t i = 0; i < d; i++) {
|
|
331
|
+
if (residual[i] < 0) {
|
|
332
|
+
tmp_code[i] = (~tmp_code[i]) & max_code;
|
|
333
|
+
}
|
|
334
|
+
}
|
|
335
|
+
|
|
336
|
+
// Step 7: Pack codes into byte array
|
|
337
|
+
pack_multibit_codes(tmp_code.data(), ex_code, d, nb_bits);
|
|
338
|
+
|
|
339
|
+
// Step 8: Compute factors for distance computation
|
|
340
|
+
// Reconstruct total_code for factor computation
|
|
341
|
+
std::vector<int> total_code(d);
|
|
342
|
+
for (size_t i = 0; i < d; i++) {
|
|
343
|
+
// Form total_code = (sign << ex_bits) + ex_code
|
|
344
|
+
bool sign_bit = (residual[i] >= 0);
|
|
345
|
+
total_code[i] = tmp_code[i] + ((sign_bit ? 1 : 0) << ex_bits);
|
|
346
|
+
}
|
|
347
|
+
|
|
348
|
+
// Compute ex-factors; centroid is needed for IP metric correction
|
|
349
|
+
compute_ex_factors(
|
|
350
|
+
residual,
|
|
351
|
+
centroid, // Pass centroid for IP metric factor computation
|
|
352
|
+
total_code.data(),
|
|
353
|
+
d,
|
|
354
|
+
ex_bits,
|
|
355
|
+
norm,
|
|
356
|
+
ipnorm,
|
|
357
|
+
ex_factors,
|
|
358
|
+
metric_type);
|
|
359
|
+
}
|
|
360
|
+
|
|
361
|
+
} // namespace rabitq_multibit
|
|
362
|
+
} // namespace faiss
|
|
@@ -0,0 +1,112 @@
|
|
|
1
|
+
/*
|
|
2
|
+
* Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
3
|
+
*
|
|
4
|
+
* This source code is licensed under the MIT license found in the
|
|
5
|
+
* LICENSE file in the root directory of this source tree.
|
|
6
|
+
*/
|
|
7
|
+
|
|
8
|
+
// Reference:
|
|
9
|
+
// "Practical and asymptotically optimal quantization of high-dimensional
|
|
10
|
+
// vectors in euclidean space for approximate nearest neighbor search"
|
|
11
|
+
// Jianyang Gao, Yutong Gou, Yuexuan Xu, Yongyi Yang, Cheng Long, Raymond
|
|
12
|
+
// Chi-Wing Wong https://dl.acm.org/doi/pdf/10.1145/3725413
|
|
13
|
+
//
|
|
14
|
+
// Reference implementation: https://github.com/VectorDB-NTU/RaBitQ-Library
|
|
15
|
+
// NOTE: Parts of this implementation are adapted from
|
|
16
|
+
// rabitqlib/quantization/rabitq_impl.hpp in the above repository.
|
|
17
|
+
|
|
18
|
+
#pragma once
|
|
19
|
+
|
|
20
|
+
#include <faiss/MetricType.h>
|
|
21
|
+
#include <faiss/impl/RaBitQUtils.h>
|
|
22
|
+
#include <cstddef>
|
|
23
|
+
#include <cstdint>
|
|
24
|
+
|
|
25
|
+
namespace faiss {
|
|
26
|
+
namespace rabitq_multibit {
|
|
27
|
+
|
|
28
|
+
/**
|
|
29
|
+
* Compute optimal scaling factor for ex-bits quantization.
|
|
30
|
+
*
|
|
31
|
+
* Uses priority queue-based search to find the scaling factor that
|
|
32
|
+
* maximizes the inner product between quantized and original vectors.
|
|
33
|
+
*
|
|
34
|
+
* @param o_abs Normalized absolute residual vector (positive values)
|
|
35
|
+
* @param d Dimensionality
|
|
36
|
+
* @param nb_bits Number of bits per dimension (2-9)
|
|
37
|
+
* @return Optimal scaling factor 't'
|
|
38
|
+
*/
|
|
39
|
+
float compute_optimal_scaling_factor(
|
|
40
|
+
const float* o_abs,
|
|
41
|
+
size_t d,
|
|
42
|
+
size_t nb_bits);
|
|
43
|
+
|
|
44
|
+
/**
|
|
45
|
+
* Pack multi-bit codes from integer array to byte array.
|
|
46
|
+
*
|
|
47
|
+
* @param tmp_code Integer codes (length d), values in [0, 2^ex_bits - 1]
|
|
48
|
+
* @param ex_code Output packed byte array
|
|
49
|
+
* @param d Dimensionality
|
|
50
|
+
* @param nb_bits Number of bits per dimension (2-9)
|
|
51
|
+
*/
|
|
52
|
+
void pack_multibit_codes(
|
|
53
|
+
const int* tmp_code,
|
|
54
|
+
uint8_t* ex_code,
|
|
55
|
+
size_t d,
|
|
56
|
+
size_t nb_bits);
|
|
57
|
+
|
|
58
|
+
/**
|
|
59
|
+
* Compute ex-bits factors for distance computation.
|
|
60
|
+
*
|
|
61
|
+
* @param residual Original residual vector (data - centroid)
|
|
62
|
+
* @param centroid Centroid vector (can be nullptr for zero centroid)
|
|
63
|
+
* @param tmp_code Unpacked integer codes including the sign bit: (sign << ex_bits) + ex_code
|
|
64
|
+
* @param d Dimensionality
|
|
65
|
+
* @param ex_bits Number of extra bits
|
|
66
|
+
* @param norm L2 norm of residual
|
|
67
|
+
* @param ipnorm Unnormalized inner product
|
|
68
|
+
* @param ex_factors Output factors structure
|
|
69
|
+
* @param metric_type Distance metric (L2 or IP)
|
|
70
|
+
*/
|
|
71
|
+
void compute_ex_factors(
|
|
72
|
+
const float* residual,
|
|
73
|
+
const float* centroid,
|
|
74
|
+
const int* tmp_code,
|
|
75
|
+
size_t d,
|
|
76
|
+
size_t ex_bits,
|
|
77
|
+
float norm,
|
|
78
|
+
double ipnorm,
|
|
79
|
+
rabitq_utils::ExtraBitsFactors& ex_factors,
|
|
80
|
+
MetricType metric_type);
|
|
81
|
+
|
|
82
|
+
/**
|
|
83
|
+
* Main quantization function: quantize residual vector to ex-bits.
|
|
84
|
+
*
|
|
85
|
+
* Performs the complete multi-bit quantization pipeline:
|
|
86
|
+
* 1. Normalize residual
|
|
87
|
+
* 2. Take absolute value
|
|
88
|
+
* 3. Find optimal scaling factor
|
|
89
|
+
* 4. Quantize to ex_bits
|
|
90
|
+
* 5. Handle negative dimensions by bit flipping
|
|
91
|
+
* 6. Pack codes into byte array
|
|
92
|
+
* 7. Compute factors for distance computation
|
|
93
|
+
*
|
|
94
|
+
* @param residual Input residual vector (data - centroid), length d
|
|
95
|
+
* @param d Dimensionality
|
|
96
|
+
* @param nb_bits Number of bits per dimension (2-9)
|
|
97
|
+
* @param ex_code Output packed ex-bit codes
|
|
98
|
+
* @param ex_factors Output ex-bits factors
|
|
99
|
+
* @param metric_type Distance metric (L2 or Inner Product)
|
|
100
|
+
* @param centroid Optional centroid vector (needed for IP metric)
|
|
101
|
+
*/
|
|
102
|
+
void quantize_ex_bits(
|
|
103
|
+
const float* residual,
|
|
104
|
+
size_t d,
|
|
105
|
+
size_t nb_bits,
|
|
106
|
+
uint8_t* ex_code,
|
|
107
|
+
rabitq_utils::ExtraBitsFactors& ex_factors,
|
|
108
|
+
MetricType metric_type,
|
|
109
|
+
const float* centroid = nullptr);
|
|
110
|
+
|
|
111
|
+
} // namespace rabitq_multibit
|
|
112
|
+
} // namespace faiss
|
|
@@ -1009,16 +1009,13 @@ void train_Uniform(
|
|
|
1009
1009
|
} else if (rs == ScalarQuantizer::RS_quantiles) {
|
|
1010
1010
|
std::vector<float> x_copy(n);
|
|
1011
1011
|
memcpy(x_copy.data(), x, n * sizeof(*x));
|
|
1012
|
-
|
|
1013
|
-
|
|
1014
|
-
|
|
1015
|
-
|
|
1016
|
-
o = 0;
|
|
1017
|
-
}
|
|
1018
|
-
if (o > n - o) {
|
|
1019
|
-
o = n / 2;
|
|
1020
|
-
}
|
|
1012
|
+
int temp = int(rs_arg * n);
|
|
1013
|
+
int o = temp < 0 ? 0 : (temp > n / 2 ? n / 2 : temp);
|
|
1014
|
+
|
|
1015
|
+
std::nth_element(x_copy.begin(), x_copy.begin() + o, x_copy.end());
|
|
1021
1016
|
vmin = x_copy[o];
|
|
1017
|
+
std::nth_element(
|
|
1018
|
+
x_copy.begin(), x_copy.begin() + (n - 1 - o), x_copy.end());
|
|
1022
1019
|
vmax = x_copy[n - 1 - o];
|
|
1023
1020
|
|
|
1024
1021
|
} else if (rs == ScalarQuantizer::RS_optim) {
|
|
@@ -98,9 +98,7 @@ struct ScalarQuantizer : Quantizer {
|
|
|
98
98
|
SQuantizer* select_quantizer() const;
|
|
99
99
|
|
|
100
100
|
struct SQDistanceComputer : FlatCodesDistanceComputer {
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
SQDistanceComputer() : q(nullptr) {}
|
|
101
|
+
SQDistanceComputer() : FlatCodesDistanceComputer(nullptr) {}
|
|
104
102
|
|
|
105
103
|
virtual float query_to_code(const uint8_t* code) const = 0;
|
|
106
104
|
|