faiss 0.5.0 → 0.5.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +5 -0
- data/README.md +2 -0
- data/ext/faiss/index.cpp +8 -0
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/faiss/IVFlib.cpp +25 -49
- data/vendor/faiss/faiss/Index.cpp +11 -0
- data/vendor/faiss/faiss/Index.h +24 -1
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +1 -0
- data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +5 -1
- data/vendor/faiss/faiss/IndexFastScan.cpp +1 -1
- data/vendor/faiss/faiss/IndexFastScan.h +3 -8
- data/vendor/faiss/faiss/IndexFlat.cpp +374 -4
- data/vendor/faiss/faiss/IndexFlat.h +80 -0
- data/vendor/faiss/faiss/IndexHNSW.cpp +90 -1
- data/vendor/faiss/faiss/IndexHNSW.h +57 -1
- data/vendor/faiss/faiss/IndexIVFFlatPanorama.cpp +34 -149
- data/vendor/faiss/faiss/IndexIVFRaBitQ.cpp +86 -2
- data/vendor/faiss/faiss/IndexIVFRaBitQ.h +3 -1
- data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.cpp +293 -115
- data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.h +52 -16
- data/vendor/faiss/faiss/IndexPQ.cpp +4 -1
- data/vendor/faiss/faiss/IndexPreTransform.cpp +14 -0
- data/vendor/faiss/faiss/IndexPreTransform.h +9 -0
- data/vendor/faiss/faiss/IndexRaBitQ.cpp +96 -16
- data/vendor/faiss/faiss/IndexRaBitQ.h +5 -1
- data/vendor/faiss/faiss/IndexRaBitQFastScan.cpp +238 -93
- data/vendor/faiss/faiss/IndexRaBitQFastScan.h +35 -9
- data/vendor/faiss/faiss/IndexRefine.cpp +49 -0
- data/vendor/faiss/faiss/IndexRefine.h +17 -0
- data/vendor/faiss/faiss/clone_index.cpp +2 -0
- data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +3 -1
- data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +1 -1
- data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +1 -1
- data/vendor/faiss/faiss/impl/DistanceComputer.h +74 -3
- data/vendor/faiss/faiss/impl/HNSW.cpp +294 -15
- data/vendor/faiss/faiss/impl/HNSW.h +31 -2
- data/vendor/faiss/faiss/impl/IDSelector.h +3 -3
- data/vendor/faiss/faiss/impl/Panorama.cpp +193 -0
- data/vendor/faiss/faiss/impl/Panorama.h +204 -0
- data/vendor/faiss/faiss/impl/RaBitQStats.cpp +29 -0
- data/vendor/faiss/faiss/impl/RaBitQStats.h +56 -0
- data/vendor/faiss/faiss/impl/RaBitQUtils.cpp +54 -6
- data/vendor/faiss/faiss/impl/RaBitQUtils.h +183 -6
- data/vendor/faiss/faiss/impl/RaBitQuantizer.cpp +269 -84
- data/vendor/faiss/faiss/impl/RaBitQuantizer.h +71 -4
- data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.cpp +362 -0
- data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.h +112 -0
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +6 -9
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +1 -3
- data/vendor/faiss/faiss/impl/index_read.cpp +156 -12
- data/vendor/faiss/faiss/impl/index_write.cpp +142 -19
- data/vendor/faiss/faiss/impl/platform_macros.h +12 -0
- data/vendor/faiss/faiss/impl/svs_io.cpp +86 -0
- data/vendor/faiss/faiss/impl/svs_io.h +67 -0
- data/vendor/faiss/faiss/index_factory.cpp +182 -15
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +1 -1
- data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -1
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +18 -109
- data/vendor/faiss/faiss/invlists/InvertedLists.h +2 -18
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +1 -1
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +1 -1
- data/vendor/faiss/faiss/svs/IndexSVSFaissUtils.h +261 -0
- data/vendor/faiss/faiss/svs/IndexSVSFlat.cpp +117 -0
- data/vendor/faiss/faiss/svs/IndexSVSFlat.h +66 -0
- data/vendor/faiss/faiss/svs/IndexSVSVamana.cpp +245 -0
- data/vendor/faiss/faiss/svs/IndexSVSVamana.h +137 -0
- data/vendor/faiss/faiss/svs/IndexSVSVamanaLVQ.cpp +39 -0
- data/vendor/faiss/faiss/svs/IndexSVSVamanaLVQ.h +42 -0
- data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.cpp +149 -0
- data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.h +58 -0
- data/vendor/faiss/faiss/utils/distances.cpp +0 -3
- data/vendor/faiss/faiss/utils/utils.cpp +4 -0
- metadata +18 -1
|
@@ -12,6 +12,7 @@
|
|
|
12
12
|
#include <faiss/IndexIVFFastScan.h>
|
|
13
13
|
#include <faiss/IndexIVFRaBitQ.h>
|
|
14
14
|
#include <faiss/IndexRaBitQFastScan.h>
|
|
15
|
+
#include <faiss/impl/RaBitQStats.h>
|
|
15
16
|
#include <faiss/impl/RaBitQUtils.h>
|
|
16
17
|
#include <faiss/impl/RaBitQuantizer.h>
|
|
17
18
|
#include <faiss/impl/simd_result_handlers.h>
|
|
@@ -24,8 +25,9 @@ namespace faiss {
|
|
|
24
25
|
struct FastScanDistancePostProcessing;
|
|
25
26
|
|
|
26
27
|
// Import shared utilities from RaBitQUtils
|
|
27
|
-
using rabitq_utils::FactorsData;
|
|
28
28
|
using rabitq_utils::QueryFactorsData;
|
|
29
|
+
using rabitq_utils::SignBitFactors;
|
|
30
|
+
using rabitq_utils::SignBitFactorsWithError;
|
|
29
31
|
|
|
30
32
|
/** Fast-scan version of IndexIVFRaBitQ that processes vectors in batches
|
|
31
33
|
* using SIMD operations. Combines the inverted file structure of IVF
|
|
@@ -53,9 +55,16 @@ struct IndexIVFRaBitQFastScan : IndexIVFFastScan {
|
|
|
53
55
|
/// Use zero-centered scalar quantizer for queries
|
|
54
56
|
bool centered = false;
|
|
55
57
|
|
|
56
|
-
///
|
|
57
|
-
///
|
|
58
|
-
|
|
58
|
+
/// Per-vector auxiliary data (1-bit codes stored separately in `codes`)
|
|
59
|
+
///
|
|
60
|
+
/// 1-bit codes (sign bits) are stored in the inherited `codes` array from
|
|
61
|
+
/// IndexFastScan in packed FastScan format for SIMD processing.
|
|
62
|
+
///
|
|
63
|
+
/// This flat_storage holds per-vector factors and refinement-bit codes:
|
|
64
|
+
/// Layout for 1-bit: [SignBitFactors (8 bytes)]
|
|
65
|
+
/// Layout for multi-bit:
|
|
66
|
+
/// [SignBitFactorsWithError (12B)][ref_codes][ExtraBitsFactors (8B)]
|
|
67
|
+
std::vector<uint8_t> flat_storage;
|
|
59
68
|
|
|
60
69
|
// Constructors
|
|
61
70
|
|
|
@@ -67,7 +76,8 @@ struct IndexIVFRaBitQFastScan : IndexIVFFastScan {
|
|
|
67
76
|
size_t nlist,
|
|
68
77
|
MetricType metric = METRIC_L2,
|
|
69
78
|
int bbs = 32,
|
|
70
|
-
bool own_invlists = true
|
|
79
|
+
bool own_invlists = true,
|
|
80
|
+
uint8_t nb_bits = 1);
|
|
71
81
|
|
|
72
82
|
/// Build from an existing IndexIVFRaBitQ
|
|
73
83
|
explicit IndexIVFRaBitQFastScan(const IndexIVFRaBitQ& orig, int bbs = 32);
|
|
@@ -101,13 +111,10 @@ struct IndexIVFRaBitQFastScan : IndexIVFFastScan {
|
|
|
101
111
|
/// Override sa_decode to handle RaBitQ reconstruction
|
|
102
112
|
void sa_decode(idx_t n, const uint8_t* bytes, float* x) const override;
|
|
103
113
|
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
void encode_vector_to_fastscan(
|
|
107
|
-
const float* xi,
|
|
108
|
-
const float* centroid,
|
|
109
|
-
uint8_t* fastscan_code) const;
|
|
114
|
+
/// Compute storage size per vector in flat_storage based on nb_bits
|
|
115
|
+
size_t compute_per_vector_storage_size() const;
|
|
110
116
|
|
|
117
|
+
private:
|
|
111
118
|
/// Compute query factors and lookup table for a residual vector
|
|
112
119
|
/// (similar to IndexRaBitQFastScan::compute_float_LUT)
|
|
113
120
|
void compute_residual_LUT(
|
|
@@ -116,10 +123,12 @@ struct IndexIVFRaBitQFastScan : IndexIVFFastScan {
|
|
|
116
123
|
float* lut_out,
|
|
117
124
|
const float* original_query = nullptr) const;
|
|
118
125
|
|
|
119
|
-
/// Decode FastScan code to RaBitQ residual vector
|
|
126
|
+
/// Decode FastScan code to RaBitQ residual vector with explicit
|
|
127
|
+
/// dp_multiplier
|
|
120
128
|
void decode_fastscan_to_residual(
|
|
121
129
|
const uint8_t* fastscan_code,
|
|
122
|
-
float* residual
|
|
130
|
+
float* residual,
|
|
131
|
+
float dp_multiplier) const;
|
|
123
132
|
|
|
124
133
|
public:
|
|
125
134
|
/// Implementation methods for IVFRaBitQFastScan specialization
|
|
@@ -171,6 +180,7 @@ struct IndexIVFRaBitQFastScan : IndexIVFFastScan {
|
|
|
171
180
|
* - Specialized handling for both centered and non-centered quantization
|
|
172
181
|
* modes
|
|
173
182
|
* - Efficient inner product metric corrections
|
|
183
|
+
* - Uses runtime boolean for multi-bit mode
|
|
174
184
|
*
|
|
175
185
|
* @tparam C Comparator type (CMin/CMax) for heap operations
|
|
176
186
|
*/
|
|
@@ -185,7 +195,8 @@ struct IndexIVFRaBitQFastScan : IndexIVFFastScan {
|
|
|
185
195
|
std::vector<int>
|
|
186
196
|
probe_indices; // probe index for each query in current batch
|
|
187
197
|
const FastScanDistancePostProcessing*
|
|
188
|
-
context;
|
|
198
|
+
context; // Processing context with query factors
|
|
199
|
+
const bool is_multibit; // Whether to use multi-bit two-stage search
|
|
189
200
|
|
|
190
201
|
// Use float-based comparator for heap operations
|
|
191
202
|
using Cfloat = typename std::conditional<
|
|
@@ -199,9 +210,11 @@ struct IndexIVFRaBitQFastScan : IndexIVFFastScan {
|
|
|
199
210
|
size_t k_val,
|
|
200
211
|
float* distances,
|
|
201
212
|
int64_t* labels,
|
|
202
|
-
const FastScanDistancePostProcessing* ctx = nullptr
|
|
213
|
+
const FastScanDistancePostProcessing* ctx = nullptr,
|
|
214
|
+
bool multibit = false);
|
|
203
215
|
|
|
204
|
-
void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1)
|
|
216
|
+
void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1)
|
|
217
|
+
override;
|
|
205
218
|
|
|
206
219
|
/// Override base class virtual method to receive context information
|
|
207
220
|
void set_list_context(size_t list_no, const std::vector<int>& probe_map)
|
|
@@ -210,6 +223,29 @@ struct IndexIVFRaBitQFastScan : IndexIVFFastScan {
|
|
|
210
223
|
void begin(const float* norms) override;
|
|
211
224
|
|
|
212
225
|
void end() override;
|
|
226
|
+
|
|
227
|
+
private:
|
|
228
|
+
/// Compute full multi-bit distance for a candidate vector (multi-bit
|
|
229
|
+
/// only)
|
|
230
|
+
/// @param db_idx Global database vector index
|
|
231
|
+
/// @param local_q Batch-local query index (for probe_indices access)
|
|
232
|
+
/// @param global_q Global query index (for storage indexing)
|
|
233
|
+
/// @param local_offset Offset within the current inverted list
|
|
234
|
+
float compute_full_multibit_distance(
|
|
235
|
+
size_t db_idx,
|
|
236
|
+
size_t local_q,
|
|
237
|
+
size_t global_q,
|
|
238
|
+
size_t local_offset) const;
|
|
239
|
+
|
|
240
|
+
/// Compute lower bound using 1-bit distance and error bound (multi-bit
|
|
241
|
+
/// only)
|
|
242
|
+
/// @param local_q Batch-local query index (for probe_indices access)
|
|
243
|
+
/// @param global_q Global query index (for storage indexing)
|
|
244
|
+
float compute_lower_bound(
|
|
245
|
+
float dist_1bit,
|
|
246
|
+
size_t db_idx,
|
|
247
|
+
size_t local_q,
|
|
248
|
+
size_t global_q) const;
|
|
213
249
|
};
|
|
214
250
|
};
|
|
215
251
|
|
|
@@ -81,6 +81,7 @@ struct PQDistanceComputer : FlatCodesDistanceComputer {
|
|
|
81
81
|
const float* sdc;
|
|
82
82
|
std::vector<float> precomputed_table;
|
|
83
83
|
size_t ndis;
|
|
84
|
+
const float* q;
|
|
84
85
|
|
|
85
86
|
float distance_to_code(const uint8_t* code) final {
|
|
86
87
|
ndis++;
|
|
@@ -109,7 +110,8 @@ struct PQDistanceComputer : FlatCodesDistanceComputer {
|
|
|
109
110
|
: FlatCodesDistanceComputer(
|
|
110
111
|
storage.codes.data(),
|
|
111
112
|
storage.code_size),
|
|
112
|
-
pq(storage.pq)
|
|
113
|
+
pq(storage.pq),
|
|
114
|
+
q(nullptr) {
|
|
113
115
|
precomputed_table.resize(pq.M * pq.ksub);
|
|
114
116
|
nb = storage.ntotal;
|
|
115
117
|
d = storage.d;
|
|
@@ -123,6 +125,7 @@ struct PQDistanceComputer : FlatCodesDistanceComputer {
|
|
|
123
125
|
}
|
|
124
126
|
|
|
125
127
|
void set_query(const float* x) override {
|
|
128
|
+
q = x;
|
|
126
129
|
if (metric == METRIC_L2) {
|
|
127
130
|
pq.compute_distance_table(x, precomputed_table.data());
|
|
128
131
|
} else {
|
|
@@ -197,6 +197,20 @@ void IndexPreTransform::range_search(
|
|
|
197
197
|
n, tv.x, radius, result, extract_index_search_params(params));
|
|
198
198
|
}
|
|
199
199
|
|
|
200
|
+
void IndexPreTransform::search_subset(
|
|
201
|
+
idx_t n,
|
|
202
|
+
const float* x,
|
|
203
|
+
idx_t k_base,
|
|
204
|
+
const idx_t* base_labels,
|
|
205
|
+
idx_t k,
|
|
206
|
+
float* distances,
|
|
207
|
+
idx_t* labels) const {
|
|
208
|
+
FAISS_THROW_IF_NOT(k > 0);
|
|
209
|
+
FAISS_THROW_IF_NOT(is_trained);
|
|
210
|
+
TransformedVectors tv(x, apply_chain(n, x));
|
|
211
|
+
index->search_subset(n, tv.x, k_base, base_labels, k, distances, labels);
|
|
212
|
+
}
|
|
213
|
+
|
|
200
214
|
void IndexPreTransform::reset() {
|
|
201
215
|
index->reset();
|
|
202
216
|
ntotal = 0;
|
|
@@ -57,6 +57,15 @@ struct IndexPreTransform : Index {
|
|
|
57
57
|
idx_t* labels,
|
|
58
58
|
const SearchParameters* params = nullptr) const override;
|
|
59
59
|
|
|
60
|
+
void search_subset(
|
|
61
|
+
idx_t n,
|
|
62
|
+
const float* x,
|
|
63
|
+
idx_t k_base,
|
|
64
|
+
const idx_t* base_labels,
|
|
65
|
+
idx_t k,
|
|
66
|
+
float* distances,
|
|
67
|
+
idx_t* labels) const override;
|
|
68
|
+
|
|
60
69
|
/* range search, no attempt is done to change the radius */
|
|
61
70
|
void range_search(
|
|
62
71
|
idx_t n,
|
|
@@ -9,13 +9,18 @@
|
|
|
9
9
|
|
|
10
10
|
#include <faiss/impl/FaissAssert.h>
|
|
11
11
|
#include <faiss/impl/ResultHandler.h>
|
|
12
|
+
#include <memory>
|
|
12
13
|
|
|
13
14
|
namespace faiss {
|
|
14
15
|
|
|
16
|
+
// Forward declaration from RaBitQuantizer.cpp
|
|
17
|
+
struct RaBitQDistanceComputer;
|
|
18
|
+
|
|
15
19
|
IndexRaBitQ::IndexRaBitQ() = default;
|
|
16
20
|
|
|
17
|
-
IndexRaBitQ::IndexRaBitQ(idx_t d, MetricType metric)
|
|
18
|
-
: IndexFlatCodes(0, d, metric), rabitq(d, metric) {
|
|
21
|
+
IndexRaBitQ::IndexRaBitQ(idx_t d, MetricType metric, uint8_t nb_bits_in)
|
|
22
|
+
: IndexFlatCodes(0, d, metric), rabitq(d, metric, nb_bits_in) {
|
|
23
|
+
// Update code size based on nb_bits
|
|
19
24
|
code_size = rabitq.code_size;
|
|
20
25
|
|
|
21
26
|
is_trained = false;
|
|
@@ -78,6 +83,7 @@ struct Run_search_with_dc_res {
|
|
|
78
83
|
|
|
79
84
|
uint8_t qb = 0;
|
|
80
85
|
bool centered = false;
|
|
86
|
+
uint8_t nb_bits = 1; // Number of bits per dimension
|
|
81
87
|
|
|
82
88
|
template <class BlockResultHandler>
|
|
83
89
|
void f(BlockResultHandler& res, const IndexRaBitQ* index, const float* xq) {
|
|
@@ -85,22 +91,87 @@ struct Run_search_with_dc_res {
|
|
|
85
91
|
using SingleResultHandler =
|
|
86
92
|
typename BlockResultHandler::SingleResultHandler;
|
|
87
93
|
const int d = index->d;
|
|
94
|
+
size_t ex_bits = nb_bits - 1;
|
|
88
95
|
|
|
89
|
-
#pragma omp parallel
|
|
96
|
+
#pragma omp parallel
|
|
90
97
|
{
|
|
91
|
-
std::unique_ptr<FlatCodesDistanceComputer>
|
|
98
|
+
std::unique_ptr<FlatCodesDistanceComputer> dc_base(
|
|
92
99
|
index->get_quantized_distance_computer(qb, centered));
|
|
93
100
|
SingleResultHandler resi(res);
|
|
94
101
|
#pragma omp for
|
|
95
102
|
for (int64_t q = 0; q < res.nq; q++) {
|
|
96
103
|
resi.begin(q);
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
104
|
+
dc_base->set_query(xq + d * q);
|
|
105
|
+
|
|
106
|
+
// Stats tracking for multi-bit two-stage search only
|
|
107
|
+
// - n_1bit_evaluations: candidates evaluated using 1-bit lower bound
|
|
108
|
+
// - n_multibit_evaluations: candidates requiring full multi-bit
|
|
109
|
+
//   distance
|
|
110
|
+
size_t local_1bit_evaluations = 0;
|
|
111
|
+
size_t local_multibit_evaluations = 0;
|
|
112
|
+
|
|
113
|
+
if (ex_bits == 0) {
|
|
114
|
+
// 1-bit: Standard single-stage search (no stats tracking)
|
|
115
|
+
for (size_t i = 0; i < ntotal; i++) {
|
|
116
|
+
if (res.is_in_selection(i)) {
|
|
117
|
+
float dis = (*dc_base)(i);
|
|
118
|
+
resi.add_result(dis, i);
|
|
119
|
+
}
|
|
120
|
+
}
|
|
121
|
+
} else {
|
|
122
|
+
// Multi-bit: Two-stage search with adaptive filtering
|
|
123
|
+
// Note: Even with query quantization (qb > 0), ex-bits
|
|
124
|
+
// distance computation uses the float query to maintain
|
|
125
|
+
// consistency with encoding-time factor computation. See
|
|
126
|
+
// RaBitQuantizer.cpp for details.
|
|
127
|
+
auto* dc = dynamic_cast<RaBitQDistanceComputer*>(
|
|
128
|
+
dc_base.get());
|
|
129
|
+
FAISS_THROW_IF_NOT_MSG(
|
|
130
|
+
dc != nullptr,
|
|
131
|
+
"Failed to cast to RaBitQDistanceComputer for two-stage search");
|
|
132
|
+
|
|
133
|
+
// Use appropriate comparison based on metric type
|
|
134
|
+
bool is_similarity =
|
|
135
|
+
is_similarity_metric(index->metric_type);
|
|
136
|
+
|
|
137
|
+
for (size_t i = 0; i < ntotal; i++) {
|
|
138
|
+
if (res.is_in_selection(i)) {
|
|
139
|
+
const uint8_t* code =
|
|
140
|
+
index->codes.data() + i * index->code_size;
|
|
141
|
+
|
|
142
|
+
local_1bit_evaluations++;
|
|
143
|
+
|
|
144
|
+
// Stage 1: Compute 1-bit lower bound
|
|
145
|
+
float lower_bound = dc->lower_bound_distance(code);
|
|
146
|
+
|
|
147
|
+
// Stage 2: Adaptive filtering using threshold
|
|
148
|
+
// For L2 (min-heap): refine if lower_bound < resi.threshold.
|
|
149
|
+
// For IP (max-heap): refine if lower_bound > resi.threshold.
|
|
150
|
+
// Note: Using resi.threshold directly (not cached)
|
|
151
|
+
// enables more aggressive filtering as the heap is
|
|
152
|
+
// updated.
|
|
153
|
+
bool should_refine = is_similarity
|
|
154
|
+
? (lower_bound > resi.threshold)
|
|
155
|
+
: (lower_bound < resi.threshold);
|
|
156
|
+
|
|
157
|
+
if (should_refine) {
|
|
158
|
+
local_multibit_evaluations++;
|
|
159
|
+
// Compute full multi-bit distance
|
|
160
|
+
float dist_full =
|
|
161
|
+
dc->distance_to_code_full(code);
|
|
162
|
+
resi.add_result(dist_full, i);
|
|
163
|
+
}
|
|
164
|
+
}
|
|
102
165
|
}
|
|
103
166
|
}
|
|
167
|
+
|
|
168
|
+
// Update global stats atomically
|
|
169
|
+
#pragma omp atomic
|
|
170
|
+
rabitq_stats.n_1bit_evaluations += local_1bit_evaluations;
|
|
171
|
+
#pragma omp atomic
|
|
172
|
+
rabitq_stats.n_multibit_evaluations +=
|
|
173
|
+
local_multibit_evaluations;
|
|
174
|
+
|
|
104
175
|
resi.end();
|
|
105
176
|
}
|
|
106
177
|
}
|
|
@@ -116,16 +187,25 @@ void IndexRaBitQ::search(
|
|
|
116
187
|
float* distances,
|
|
117
188
|
idx_t* labels,
|
|
118
189
|
const SearchParameters* params_in) const {
|
|
119
|
-
|
|
120
|
-
|
|
190
|
+
FAISS_THROW_IF_NOT(is_trained);
|
|
191
|
+
|
|
192
|
+
// Extract search parameters
|
|
193
|
+
uint8_t used_qb = qb;
|
|
194
|
+
bool used_centered = centered;
|
|
121
195
|
if (auto params = dynamic_cast<const RaBitQSearchParameters*>(params_in)) {
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
} else {
|
|
125
|
-
r.qb = this->qb;
|
|
126
|
-
r.centered = this->centered;
|
|
196
|
+
used_qb = params->qb;
|
|
197
|
+
used_centered = params->centered;
|
|
127
198
|
}
|
|
128
199
|
|
|
200
|
+
const IDSelector* sel = (params_in != nullptr) ? params_in->sel : nullptr;
|
|
201
|
+
|
|
202
|
+
// Set up functor with all necessary parameters
|
|
203
|
+
Run_search_with_dc_res r;
|
|
204
|
+
r.qb = used_qb;
|
|
205
|
+
r.centered = used_centered;
|
|
206
|
+
r.nb_bits = rabitq.nb_bits; // Pass multi-bit info to functor
|
|
207
|
+
|
|
208
|
+
// Use Faiss framework for all cases (single-stage and two-stage)
|
|
129
209
|
dispatch_knn_ResultHandler(
|
|
130
210
|
n, distances, labels, k, metric_type, sel, r, this, x);
|
|
131
211
|
}
|
|
@@ -8,6 +8,7 @@
|
|
|
8
8
|
#pragma once
|
|
9
9
|
|
|
10
10
|
#include <faiss/IndexFlatCodes.h>
|
|
11
|
+
#include <faiss/impl/RaBitQStats.h>
|
|
11
12
|
#include <faiss/impl/RaBitQuantizer.h>
|
|
12
13
|
|
|
13
14
|
namespace faiss {
|
|
@@ -32,7 +33,10 @@ struct IndexRaBitQ : IndexFlatCodes {
|
|
|
32
33
|
|
|
33
34
|
IndexRaBitQ();
|
|
34
35
|
|
|
35
|
-
explicit IndexRaBitQ(
|
|
36
|
+
explicit IndexRaBitQ(
|
|
37
|
+
idx_t d,
|
|
38
|
+
MetricType metric = METRIC_L2,
|
|
39
|
+
uint8_t nb_bits = 1);
|
|
36
40
|
|
|
37
41
|
void train(idx_t n, const float* x) override;
|
|
38
42
|
|