faiss 0.3.4 → 0.4.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +5 -0
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/faiss/AutoTune.cpp +11 -8
- data/vendor/faiss/faiss/Clustering.cpp +0 -16
- data/vendor/faiss/faiss/IVFlib.cpp +213 -0
- data/vendor/faiss/faiss/IVFlib.h +42 -0
- data/vendor/faiss/faiss/Index.h +1 -1
- data/vendor/faiss/faiss/IndexBinaryFlat.cpp +9 -7
- data/vendor/faiss/faiss/IndexBinaryFlat.h +2 -1
- data/vendor/faiss/faiss/IndexFlatCodes.cpp +1 -1
- data/vendor/faiss/faiss/IndexFlatCodes.h +4 -2
- data/vendor/faiss/faiss/IndexHNSW.cpp +13 -20
- data/vendor/faiss/faiss/IndexHNSW.h +1 -1
- data/vendor/faiss/faiss/IndexIVF.cpp +20 -3
- data/vendor/faiss/faiss/IndexIVF.h +5 -2
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +2 -1
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +2 -1
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +2 -1
- data/vendor/faiss/faiss/IndexIVFFlat.h +2 -1
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +2 -1
- data/vendor/faiss/faiss/IndexIVFPQ.h +2 -1
- data/vendor/faiss/faiss/IndexIVFRaBitQ.cpp +277 -0
- data/vendor/faiss/faiss/IndexIVFRaBitQ.h +70 -0
- data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +2 -1
- data/vendor/faiss/faiss/IndexIVFSpectralHash.h +2 -1
- data/vendor/faiss/faiss/IndexRaBitQ.cpp +148 -0
- data/vendor/faiss/faiss/IndexRaBitQ.h +65 -0
- data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +2 -1
- data/vendor/faiss/faiss/IndexScalarQuantizer.h +2 -1
- data/vendor/faiss/faiss/clone_index.cpp +38 -3
- data/vendor/faiss/faiss/cppcontrib/factory_tools.cpp +19 -0
- data/vendor/faiss/faiss/cppcontrib/factory_tools.h +4 -11
- data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +2 -1
- data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +13 -3
- data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +1 -1
- data/vendor/faiss/faiss/gpu/StandardGpuResources.h +1 -1
- data/vendor/faiss/faiss/gpu/test/TestGpuIcmEncoder.cpp +112 -0
- data/vendor/faiss/faiss/impl/HNSW.cpp +35 -13
- data/vendor/faiss/faiss/impl/HNSW.h +5 -4
- data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -1
- data/vendor/faiss/faiss/impl/RaBitQuantizer.cpp +519 -0
- data/vendor/faiss/faiss/impl/RaBitQuantizer.h +78 -0
- data/vendor/faiss/faiss/impl/ResultHandler.h +2 -2
- data/vendor/faiss/faiss/impl/code_distance/code_distance-sve.h +3 -4
- data/vendor/faiss/faiss/impl/index_read.cpp +220 -25
- data/vendor/faiss/faiss/impl/index_write.cpp +29 -0
- data/vendor/faiss/faiss/impl/io.h +2 -2
- data/vendor/faiss/faiss/impl/io_macros.h +2 -0
- data/vendor/faiss/faiss/impl/mapped_io.cpp +313 -0
- data/vendor/faiss/faiss/impl/mapped_io.h +51 -0
- data/vendor/faiss/faiss/impl/maybe_owned_vector.h +316 -0
- data/vendor/faiss/faiss/impl/platform_macros.h +7 -3
- data/vendor/faiss/faiss/impl/simd_result_handlers.h +1 -1
- data/vendor/faiss/faiss/impl/zerocopy_io.cpp +67 -0
- data/vendor/faiss/faiss/impl/zerocopy_io.h +32 -0
- data/vendor/faiss/faiss/index_factory.cpp +16 -5
- data/vendor/faiss/faiss/index_io.h +4 -0
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +3 -3
- data/vendor/faiss/faiss/invlists/InvertedLists.h +5 -3
- data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +3 -3
- data/vendor/faiss/faiss/python/python_callbacks.cpp +24 -0
- data/vendor/faiss/faiss/python/python_callbacks.h +22 -0
- data/vendor/faiss/faiss/utils/approx_topk_hamming/approx_topk_hamming.h +30 -12
- data/vendor/faiss/faiss/utils/hamming.cpp +45 -21
- data/vendor/faiss/faiss/utils/hamming.h +7 -3
- data/vendor/faiss/faiss/utils/hamming_distance/avx512-inl.h +1 -1
- data/vendor/faiss/faiss/utils/utils.cpp +4 -4
- data/vendor/faiss/faiss/utils/utils.h +3 -3
- metadata +16 -4
@@ -312,11 +312,14 @@ struct IndexIVF : Index, IndexIVFInterface {
|
|
312
312
|
|
313
313
|
/** Get a scanner for this index (store_pairs means ignore labels)
|
314
314
|
*
|
315
|
-
* The default search implementation uses this to compute the distances
|
315
|
+
* The default search implementation uses this to compute the distances.
|
316
|
+
* Use sel instead of params->sel, because sel is initialized with
|
317
|
+
* params->sel, but may get overridden by IndexIVF's internal logic.
|
316
318
|
*/
|
317
319
|
virtual InvertedListScanner* get_InvertedListScanner(
|
318
320
|
bool store_pairs = false,
|
319
|
-
const IDSelector* sel = nullptr
|
321
|
+
const IDSelector* sel = nullptr,
|
322
|
+
const IVFSearchParameters* params = nullptr) const;
|
320
323
|
|
321
324
|
/** reconstruct a vector. Works only if maintain_direct_map is set to 1 or 2
|
322
325
|
*/
|
@@ -253,7 +253,8 @@ struct AQInvertedListScannerLUT : AQInvertedListScanner {
|
|
253
253
|
|
254
254
|
InvertedListScanner* IndexIVFAdditiveQuantizer::get_InvertedListScanner(
|
255
255
|
bool store_pairs,
|
256
|
-
const IDSelector* sel
|
256
|
+
const IDSelector* sel,
|
257
|
+
const IVFSearchParameters*) const {
|
257
258
|
FAISS_THROW_IF_NOT(!sel);
|
258
259
|
if (metric_type == METRIC_INNER_PRODUCT) {
|
259
260
|
if (aq->search_type == AdditiveQuantizer::ST_decompress) {
|
@@ -52,7 +52,8 @@ struct IndexIVFAdditiveQuantizer : IndexIVF {
|
|
52
52
|
|
53
53
|
InvertedListScanner* get_InvertedListScanner(
|
54
54
|
bool store_pairs,
|
55
|
-
const IDSelector* sel
|
55
|
+
const IDSelector* sel,
|
56
|
+
const IVFSearchParameters* params) const override;
|
56
57
|
|
57
58
|
void sa_decode(idx_t n, const uint8_t* codes, float* x) const override;
|
58
59
|
|
@@ -223,7 +223,8 @@ InvertedListScanner* get_InvertedListScanner1(
|
|
223
223
|
|
224
224
|
InvertedListScanner* IndexIVFFlat::get_InvertedListScanner(
|
225
225
|
bool store_pairs,
|
226
|
-
const IDSelector* sel
|
226
|
+
const IDSelector* sel,
|
227
|
+
const IVFSearchParameters*) const {
|
227
228
|
if (sel) {
|
228
229
|
return get_InvertedListScanner1<true>(this, store_pairs, sel);
|
229
230
|
} else {
|
@@ -44,7 +44,8 @@ struct IndexIVFFlat : IndexIVF {
|
|
44
44
|
|
45
45
|
InvertedListScanner* get_InvertedListScanner(
|
46
46
|
bool store_pairs,
|
47
|
-
const IDSelector* sel
|
47
|
+
const IDSelector* sel,
|
48
|
+
const IVFSearchParameters* params) const override;
|
48
49
|
|
49
50
|
void reconstruct_from_offset(int64_t list_no, int64_t offset, float* recons)
|
50
51
|
const override;
|
@@ -1321,7 +1321,8 @@ InvertedListScanner* get_InvertedListScanner2(
|
|
1321
1321
|
|
1322
1322
|
InvertedListScanner* IndexIVFPQ::get_InvertedListScanner(
|
1323
1323
|
bool store_pairs,
|
1324
|
-
const IDSelector* sel
|
1324
|
+
const IDSelector* sel,
|
1325
|
+
const IVFSearchParameters*) const {
|
1325
1326
|
if (sel) {
|
1326
1327
|
return get_InvertedListScanner2<true>(*this, store_pairs, sel);
|
1327
1328
|
} else {
|
@@ -134,7 +134,8 @@ struct IndexIVFPQ : IndexIVF {
|
|
134
134
|
|
135
135
|
InvertedListScanner* get_InvertedListScanner(
|
136
136
|
bool store_pairs,
|
137
|
-
const IDSelector* sel
|
137
|
+
const IDSelector* sel,
|
138
|
+
const IVFSearchParameters* params) const override;
|
138
139
|
|
139
140
|
/// build precomputed table
|
140
141
|
void precompute_table();
|
@@ -0,0 +1,277 @@
|
|
1
|
+
/*
|
2
|
+
* Copyright (c) Meta Platforms, Inc. and affiliates.
|
3
|
+
*
|
4
|
+
* This source code is licensed under the MIT license found in the
|
5
|
+
* LICENSE file in the root directory of this source tree.
|
6
|
+
*/
|
7
|
+
|
8
|
+
#include <faiss/IndexIVFRaBitQ.h>
|
9
|
+
|
10
|
+
#include <omp.h>
|
11
|
+
|
12
|
+
#include <cstddef>
|
13
|
+
#include <cstdint>
|
14
|
+
#include <memory>
|
15
|
+
#include <vector>
|
16
|
+
|
17
|
+
#include <faiss/impl/FaissAssert.h>
|
18
|
+
#include <faiss/impl/RaBitQuantizer.h>
|
19
|
+
|
20
|
+
namespace faiss {
|
21
|
+
|
22
|
+
/// Build an untrained IVF index whose inverted lists hold RaBitQ codes.
///
/// The IndexIVF base is constructed with a code size of 0; the actual code
/// size is read from the RaBitQuantizer member once it exists and is then
/// copied to the inverted lists.
///
/// @param quantizer  coarse quantizer forwarded to IndexIVF
/// @param d          vector dimensionality
/// @param nlist      number of inverted lists
/// @param metric     metric forwarded to both IndexIVF and the RaBitQuantizer
IndexIVFRaBitQ::IndexIVFRaBitQ(
        Index* quantizer,
        const size_t d,
        const size_t nlist,
        MetricType metric)
        : IndexIVF(quantizer, d, nlist, 0, metric), rabitq(d, metric) {
    // flags first: they do not depend on the code size
    by_residual = true;
    is_trained = false;

    // the code size is dictated by the quantizer, then mirrored in the lists
    code_size = rabitq.code_size;
    invlists->code_size = code_size;
}
|
34
|
+
|
35
|
+
/// Default constructor: only switches by_residual on; all other members
/// keep the defaults provided by their declarations.
IndexIVFRaBitQ::IndexIVFRaBitQ() {
    this->by_residual = true;
}
|
38
|
+
|
39
|
+
/// Train the RaBitQ encoder on the n vectors stored contiguously in x.
///
/// @param n       number of training vectors
/// @param x       training vectors
/// @param assign  not consulted by this implementation
void IndexIVFRaBitQ::train_encoder(
        idx_t n,
        const float* x,
        const idx_t* assign) {
    (void)assign; // intentionally unused
    rabitq.train(n, x);
}
|
45
|
+
|
46
|
+
/// Encode n vectors into RaBitQ codes, optionally prefixed by their list id.
///
/// The whole output buffer is zeroed first, so entries whose list number is
/// negative are left as all-zero codes. The work is parallelised with OpenMP
/// only when n exceeds 1000.
///
/// @param n                number of vectors
/// @param x                input vectors, d floats each
/// @param list_nos         inverted-list number of each vector
/// @param codes            output buffer, n entries of
///                         (code_size + coarse prefix) bytes
/// @param include_listnos  when true, each entry starts with the encoded
///                         list number
void IndexIVFRaBitQ::encode_vectors(
        idx_t n,
        const float* x,
        const idx_t* list_nos,
        uint8_t* codes,
        bool include_listnos) const {
    const size_t coarse_size = include_listnos ? coarse_code_size() : 0;
    // size in bytes of one output entry (optional prefix + RaBitQ code)
    const size_t stride = code_size + coarse_size;
    memset(codes, 0, stride * n);

#pragma omp parallel if (n > 1000)
    {
        // per-thread scratch buffer for the reconstructed centroid
        std::vector<float> centroid(d);

#pragma omp for
        for (idx_t i = 0; i < n; i++) {
            const int64_t list_no = list_nos[i];
            if (list_no < 0) {
                // unassigned vector: keep the zero-filled entry
                continue;
            }

            uint8_t* entry = codes + i * stride;

            // both by_residual and !by_residual lead to the same code
            quantizer->reconstruct(list_no, centroid.data());
            rabitq.compute_codes_core(
                    x + i * d, entry + coarse_size, 1, centroid.data());

            if (coarse_size) {
                encode_listno(list_no, entry);
            }
        }
    }
}
|
78
|
+
|
79
|
+
/// Encode n vectors and append them to their inverted lists.
///
/// Every OpenMP thread walks the full input but only handles the vectors
/// whose list number satisfies `list_no % nt == rank`, so a given list is
/// only ever appended to by one thread. Vectors with list number -1 are
/// recorded in the direct map (by thread 0) without being stored.
///
/// @param n                      number of vectors to add
/// @param x                      input vectors, d floats each
/// @param xids                   ids of the vectors, or nullptr to use
///                               sequential ids starting at ntotal
/// @param precomputed_idx        inverted-list number of each vector
/// @param inverted_list_context  opaque context forwarded to
///                               InvertedLists::add_entry
/// @throws FaissException if the index is not trained
void IndexIVFRaBitQ::add_core(
        idx_t n,
        const float* x,
        const idx_t* xids,
        const idx_t* precomputed_idx,
        void* inverted_list_context) {
    FAISS_THROW_IF_NOT(is_trained);

    DirectMapAdd dm_add(direct_map, n, xids);

#pragma omp parallel
    {
        // per-thread scratch buffers
        std::vector<uint8_t> one_code(code_size);
        std::vector<float> centroid(d);

        const int nt = omp_get_num_threads();
        const int rank = omp_get_thread_num();

        // each thread takes care of a subset of lists.
        // The counter is signed like n: an unsigned counter would turn a
        // negative n into a huge trip count and read out of bounds.
        for (idx_t i = 0; i < n; i++) {
            const int64_t list_no = precomputed_idx[i];
            if (list_no >= 0 && list_no % nt == rank) {
                const int64_t id = xids ? xids[i] : ntotal + i;

                const float* xi = x + i * d;

                // both by_residual and !by_residual lead to the same code
                quantizer->reconstruct(list_no, centroid.data());
                rabitq.compute_codes_core(
                        xi, one_code.data(), 1, centroid.data());

                const size_t ofs = invlists->add_entry(
                        list_no, id, one_code.data(), inverted_list_context);

                dm_add.add(i, list_no, ofs);

            } else if (rank == 0 && list_no == -1) {
                // unassigned vector: only the direct map is updated
                dm_add.add(i, -1, 0);
            }
        }
    }

    ntotal += n;
}
|
123
|
+
|
124
|
+
struct RaBitInvertedListScanner : InvertedListScanner {
|
125
|
+
const IndexIVFRaBitQ& ivf_rabitq;
|
126
|
+
|
127
|
+
std::vector<float> reconstructed_centroid;
|
128
|
+
std::vector<float> query_vector;
|
129
|
+
|
130
|
+
std::unique_ptr<FlatCodesDistanceComputer> dc;
|
131
|
+
|
132
|
+
uint8_t qb = 0;
|
133
|
+
|
134
|
+
RaBitInvertedListScanner(
|
135
|
+
const IndexIVFRaBitQ& ivf_rabitq_in,
|
136
|
+
bool store_pairs = false,
|
137
|
+
const IDSelector* sel = nullptr,
|
138
|
+
uint8_t qb_in = 0)
|
139
|
+
: InvertedListScanner(store_pairs, sel),
|
140
|
+
ivf_rabitq{ivf_rabitq_in},
|
141
|
+
qb{qb_in} {
|
142
|
+
keep_max = is_similarity_metric(ivf_rabitq.metric_type);
|
143
|
+
code_size = ivf_rabitq.code_size;
|
144
|
+
}
|
145
|
+
|
146
|
+
/// from now on we handle this query.
|
147
|
+
void set_query(const float* query_vector_in) override {
|
148
|
+
query_vector.assign(query_vector_in, query_vector_in + ivf_rabitq.d);
|
149
|
+
|
150
|
+
internal_try_setup_dc();
|
151
|
+
}
|
152
|
+
|
153
|
+
/// following codes come from this inverted list
|
154
|
+
void set_list(idx_t list_no, float coarse_dis) override {
|
155
|
+
this->list_no = list_no;
|
156
|
+
|
157
|
+
reconstructed_centroid.resize(ivf_rabitq.d);
|
158
|
+
ivf_rabitq.quantizer->reconstruct(
|
159
|
+
list_no, reconstructed_centroid.data());
|
160
|
+
|
161
|
+
internal_try_setup_dc();
|
162
|
+
}
|
163
|
+
|
164
|
+
/// compute a single query-to-code distance
|
165
|
+
float distance_to_code(const uint8_t* code) const override {
|
166
|
+
return dc->distance_to_code(code);
|
167
|
+
}
|
168
|
+
|
169
|
+
void internal_try_setup_dc() {
|
170
|
+
if (!query_vector.empty() && !reconstructed_centroid.empty()) {
|
171
|
+
// both query_vector and centroid are available!
|
172
|
+
// set up DistanceComputer
|
173
|
+
dc.reset(ivf_rabitq.rabitq.get_distance_computer(
|
174
|
+
qb, reconstructed_centroid.data()));
|
175
|
+
|
176
|
+
dc->set_query(query_vector.data());
|
177
|
+
}
|
178
|
+
}
|
179
|
+
};
|
180
|
+
|
181
|
+
/// Create a scanner for this index.
///
/// The query quantization setting is taken from the search parameters when
/// they are an IVFRaBitQSearchParameters instance, and from the index's own
/// qb member otherwise. The caller receives a newly allocated scanner.
InvertedListScanner* IndexIVFRaBitQ::get_InvertedListScanner(
        bool store_pairs,
        const IDSelector* sel,
        const IVFSearchParameters* search_params_in) const {
    const auto* rabitq_params =
            dynamic_cast<const IVFRaBitQSearchParameters*>(search_params_in);
    const uint8_t used_qb = rabitq_params ? rabitq_params->qb : qb;

    return new RaBitInvertedListScanner(*this, store_pairs, sel, used_qb);
}
|
193
|
+
|
194
|
+
void IndexIVFRaBitQ::reconstruct_from_offset(
|
195
|
+
int64_t list_no,
|
196
|
+
int64_t offset,
|
197
|
+
float* recons) const {
|
198
|
+
const uint8_t* code = invlists->get_single_code(list_no, offset);
|
199
|
+
|
200
|
+
std::vector<float> centroid(d);
|
201
|
+
quantizer->reconstruct(list_no, centroid.data());
|
202
|
+
|
203
|
+
rabitq.decode_core(code, recons, 1, centroid.data());
|
204
|
+
}
|
205
|
+
|
206
|
+
/// Decode n standalone codes (list-number prefix followed by the RaBitQ
/// code) back into float vectors.
///
/// @param n      number of codes
/// @param codes  input, n entries of (coarse_code_size() + code_size) bytes
/// @param x      output, n vectors of d floats
void IndexIVFRaBitQ::sa_decode(idx_t n, const uint8_t* codes, float* x) const {
    const size_t coarse_size = coarse_code_size();
    // size in bytes of one input entry
    const size_t stride = code_size + coarse_size;

#pragma omp parallel
    {
        // per-thread scratch buffer for the reconstructed centroid
        std::vector<float> centroid(d);

#pragma omp for
        for (idx_t i = 0; i < n; i++) {
            const uint8_t* entry = codes + i * stride;
            const int64_t list_no = decode_listno(entry);

            quantizer->reconstruct(list_no, centroid.data());
            rabitq.decode_core(
                    entry + coarse_size, x + i * d, 1, centroid.data());
        }
    }
}
|
224
|
+
|
225
|
+
struct IVFRaBitDistanceComputer : DistanceComputer {
|
226
|
+
const float* q = nullptr;
|
227
|
+
const IndexIVFRaBitQ* parent = nullptr;
|
228
|
+
|
229
|
+
void set_query(const float* x) override;
|
230
|
+
|
231
|
+
float operator()(idx_t i) override;
|
232
|
+
|
233
|
+
float symmetric_dis(idx_t i, idx_t j) override;
|
234
|
+
};
|
235
|
+
|
236
|
+
void IVFRaBitDistanceComputer::set_query(const float* x) {
|
237
|
+
q = x;
|
238
|
+
}
|
239
|
+
|
240
|
+
/// Compute the distance between the current query and stored vector i.
///
/// The vector is located through the parent's direct map, and a distance
/// computer tied to the centroid of its list is created for this single
/// evaluation.
float IVFRaBitDistanceComputer::operator()(idx_t i) {
    // locate the code: the direct map yields a packed (list, offset) pair
    const idx_t lo = parent->direct_map.get(i);
    const uint64_t list_no = lo_listno(lo);
    const uint64_t offset = lo_offset(lo);

    const uint8_t* code = parent->invlists->get_single_code(list_no, offset);

    // the distance computer needs the centroid of that list
    std::vector<float> centroid(parent->d);
    parent->quantizer->reconstruct(list_no, centroid.data());

    std::unique_ptr<FlatCodesDistanceComputer> dc(
            parent->rabitq.get_distance_computer(parent->qb, centroid.data()));
    dc->set_query(q);
    const float distance = dc->distance_to_code(code);

    // give the code back to the inverted lists
    parent->invlists->release_codes(list_no, code);

    return distance;
}
|
266
|
+
|
267
|
+
/// Distance between two stored vectors: not available for this computer,
/// the call always throws.
float IVFRaBitDistanceComputer::symmetric_dis(idx_t /*i*/, idx_t /*j*/) {
    FAISS_THROW_MSG("Not implemented");
}
|
270
|
+
|
271
|
+
/// Allocate a by-id distance computer bound to this index.
/// The returned object is created with `new`; releasing it is up to the
/// caller.
DistanceComputer* IndexIVFRaBitQ::get_distance_computer() const {
    auto* computer = new IVFRaBitDistanceComputer;
    computer->parent = this;
    return computer;
}
|
276
|
+
|
277
|
+
} // namespace faiss
|
@@ -0,0 +1,70 @@
|
|
1
|
+
/*
|
2
|
+
* Copyright (c) Meta Platforms, Inc. and affiliates.
|
3
|
+
*
|
4
|
+
* This source code is licensed under the MIT license found in the
|
5
|
+
* LICENSE file in the root directory of this source tree.
|
6
|
+
*/
|
7
|
+
|
8
|
+
#pragma once
|
9
|
+
|
10
|
+
#include <cstddef>
|
11
|
+
#include <cstdint>
|
12
|
+
|
13
|
+
#include <faiss/Index.h>
|
14
|
+
#include <faiss/IndexIVF.h>
|
15
|
+
|
16
|
+
#include <faiss/impl/RaBitQuantizer.h>
|
17
|
+
|
18
|
+
namespace faiss {
|
19
|
+
|
20
|
+
/// Search-time parameters for IndexIVFRaBitQ.
struct IVFRaBitQSearchParameters : IVFSearchParameters {
    /// overrides the index's own qb for the duration of the search
    uint8_t qb = 0;
};
|
23
|
+
|
24
|
+
// * by_residual is true, just by design
|
25
|
+
struct IndexIVFRaBitQ : IndexIVF {
|
26
|
+
RaBitQuantizer rabitq;
|
27
|
+
|
28
|
+
// the default number of bits to quantize a query with.
|
29
|
+
// use '0' to disable quantization and use raw fp32 values.
|
30
|
+
uint8_t qb = 0;
|
31
|
+
|
32
|
+
IndexIVFRaBitQ(
|
33
|
+
Index* quantizer,
|
34
|
+
const size_t d,
|
35
|
+
const size_t nlist,
|
36
|
+
MetricType metric = METRIC_L2);
|
37
|
+
|
38
|
+
IndexIVFRaBitQ();
|
39
|
+
|
40
|
+
void train_encoder(idx_t n, const float* x, const idx_t* assign) override;
|
41
|
+
|
42
|
+
void encode_vectors(
|
43
|
+
idx_t n,
|
44
|
+
const float* x,
|
45
|
+
const idx_t* list_nos,
|
46
|
+
uint8_t* codes,
|
47
|
+
bool include_listnos = false) const override;
|
48
|
+
|
49
|
+
void add_core(
|
50
|
+
idx_t n,
|
51
|
+
const float* x,
|
52
|
+
const idx_t* xids,
|
53
|
+
const idx_t* precomputed_idx,
|
54
|
+
void* inverted_list_context = nullptr) override;
|
55
|
+
|
56
|
+
InvertedListScanner* get_InvertedListScanner(
|
57
|
+
bool store_pairs,
|
58
|
+
const IDSelector* sel,
|
59
|
+
const IVFSearchParameters* params) const override;
|
60
|
+
|
61
|
+
void reconstruct_from_offset(int64_t list_no, int64_t offset, float* recons)
|
62
|
+
const override;
|
63
|
+
|
64
|
+
void sa_decode(idx_t n, const uint8_t* bytes, float* x) const override;
|
65
|
+
|
66
|
+
// unfortunately
|
67
|
+
DistanceComputer* get_distance_computer() const override;
|
68
|
+
};
|
69
|
+
|
70
|
+
} // namespace faiss
|
@@ -301,7 +301,8 @@ struct BuildScanner {
|
|
301
301
|
|
302
302
|
InvertedListScanner* IndexIVFSpectralHash::get_InvertedListScanner(
|
303
303
|
bool store_pairs,
|
304
|
-
const IDSelector* sel
|
304
|
+
const IDSelector* sel,
|
305
|
+
const IVFSearchParameters*) const {
|
305
306
|
FAISS_THROW_IF_NOT(!sel);
|
306
307
|
BuildScanner bs;
|
307
308
|
return dispatch_HammingComputer(code_size, bs, this, store_pairs);
|
@@ -71,7 +71,8 @@ struct IndexIVFSpectralHash : IndexIVF {
|
|
71
71
|
|
72
72
|
InvertedListScanner* get_InvertedListScanner(
|
73
73
|
bool store_pairs,
|
74
|
-
const IDSelector* sel
|
74
|
+
const IDSelector* sel,
|
75
|
+
const IVFSearchParameters* params) const override;
|
75
76
|
|
76
77
|
/** replace the vector transform for an empty (and possibly untrained) index
|
77
78
|
*/
|
@@ -0,0 +1,148 @@
|
|
1
|
+
/*
|
2
|
+
* Copyright (c) Meta Platforms, Inc. and affiliates.
|
3
|
+
*
|
4
|
+
* This source code is licensed under the MIT license found in the
|
5
|
+
* LICENSE file in the root directory of this source tree.
|
6
|
+
*/
|
7
|
+
|
8
|
+
#include <faiss/IndexRaBitQ.h>
|
9
|
+
|
10
|
+
#include <faiss/impl/FaissAssert.h>
|
11
|
+
#include <faiss/impl/ResultHandler.h>
|
12
|
+
|
13
|
+
namespace faiss {
|
14
|
+
|
15
|
+
/// Default constructor: every member keeps the default given in its
/// declaration.
IndexRaBitQ::IndexRaBitQ() = default;
|
16
|
+
|
17
|
+
/// Build an untrained flat RaBitQ index for d-dimensional vectors.
///
/// IndexFlatCodes is constructed with a code size of 0; the actual size is
/// read from the RaBitQuantizer member once it has been constructed.
IndexRaBitQ::IndexRaBitQ(idx_t d, MetricType metric)
        : IndexFlatCodes(0, d, metric), rabitq(d, metric) {
    is_trained = false;
    code_size = rabitq.code_size;
}
|
23
|
+
|
24
|
+
/// Train the index: store the mean of the training vectors as `center`
/// and train the RaBitQ quantizer.
///
/// With n <= 0 the center is left at all zeros.
///
/// @param n  number of training vectors
/// @param x  training vectors, d floats each
void IndexRaBitQ::train(idx_t n, const float* x) {
    // Accumulate the per-dimension sums. The counters are signed like n
    // and d: unsigned counters would turn a negative n into a huge trip
    // count and read out of bounds.
    std::vector<float> centroid(d, 0);
    for (idx_t i = 0; i < n; i++) {
        const float* xi = x + i * d;
        for (idx_t j = 0; j < d; j++) {
            centroid[j] += xi[j];
        }
    }

    // turn the sums into means; skipped when there is nothing to average
    if (n > 0) {
        for (idx_t j = 0; j < d; j++) {
            centroid[j] /= (float)n;
        }
    }

    center = std::move(centroid);

    rabitq.train(n, x);
    is_trained = true;
}
|
45
|
+
|
46
|
+
/// Encode n vectors relative to the stored center.
///
/// @throws FaissException if the index is not trained
void IndexRaBitQ::sa_encode(idx_t n, const float* x, uint8_t* bytes) const {
    FAISS_THROW_IF_NOT(is_trained);

    const float* centroid = center.data();
    rabitq.compute_codes_core(x, bytes, n, centroid);
}
|
50
|
+
|
51
|
+
/// Decode n codes back into float vectors, relative to the stored center.
///
/// @throws FaissException if the index is not trained
void IndexRaBitQ::sa_decode(idx_t n, const uint8_t* bytes, float* x) const {
    FAISS_THROW_IF_NOT(is_trained);

    const float* centroid = center.data();
    rabitq.decode_core(bytes, x, n, centroid);
}
|
55
|
+
|
56
|
+
/// Distance computer using the index's default qb setting.
/// Same construction as get_quantized_distance_computer(), which does the
/// actual work.
FlatCodesDistanceComputer* IndexRaBitQ::get_FlatCodesDistanceComputer() const {
    return get_quantized_distance_computer(qb);
}
|
63
|
+
|
64
|
+
/// Allocate a distance computer for the given qb value, centred on the
/// index's center and pointed at the index's stored codes.
/// Releasing the returned object is up to the caller.
FlatCodesDistanceComputer* IndexRaBitQ::get_quantized_distance_computer(
        const uint8_t qb) const {
    FlatCodesDistanceComputer* computer =
            rabitq.get_distance_computer(qb, center.data());

    // attach the code storage so that codes can be addressed by id
    computer->codes = codes.data();
    computer->code_size = rabitq.code_size;
    return computer;
}
|
72
|
+
|
73
|
+
namespace {
|
74
|
+
|
75
|
+
struct Run_search_with_dc_res {
|
76
|
+
using T = void;
|
77
|
+
|
78
|
+
uint8_t qb = 0;
|
79
|
+
|
80
|
+
template <class BlockResultHandler>
|
81
|
+
void f(BlockResultHandler& res, const IndexRaBitQ* index, const float* xq) {
|
82
|
+
size_t ntotal = index->ntotal;
|
83
|
+
using SingleResultHandler =
|
84
|
+
typename BlockResultHandler::SingleResultHandler;
|
85
|
+
const int d = index->d;
|
86
|
+
|
87
|
+
#pragma omp parallel // if (res.nq > 100)
|
88
|
+
{
|
89
|
+
std::unique_ptr<FlatCodesDistanceComputer> dc(
|
90
|
+
index->get_quantized_distance_computer(qb));
|
91
|
+
SingleResultHandler resi(res);
|
92
|
+
#pragma omp for
|
93
|
+
for (int64_t q = 0; q < res.nq; q++) {
|
94
|
+
resi.begin(q);
|
95
|
+
dc->set_query(xq + d * q);
|
96
|
+
for (size_t i = 0; i < ntotal; i++) {
|
97
|
+
if (res.is_in_selection(i)) {
|
98
|
+
float dis = (*dc)(i);
|
99
|
+
resi.add_result(dis, i);
|
100
|
+
}
|
101
|
+
}
|
102
|
+
resi.end();
|
103
|
+
}
|
104
|
+
}
|
105
|
+
}
|
106
|
+
};
|
107
|
+
|
108
|
+
} // namespace
|
109
|
+
|
110
|
+
/// k-nearest-neighbour search by exhaustive scan.
///
/// The qb setting comes from params_in when it is a RaBitQSearchParameters
/// instance, and from the index's own qb otherwise. The ID selector, if
/// any, is taken from params_in.
void IndexRaBitQ::search(
        idx_t n,
        const float* x,
        idx_t k,
        float* distances,
        idx_t* labels,
        const SearchParameters* params_in) const {
    Run_search_with_dc_res r;
    r.qb = qb;
    if (const auto* rabitq_params =
                dynamic_cast<const RaBitQSearchParameters*>(params_in)) {
        r.qb = rabitq_params->qb;
    }

    const IDSelector* sel = params_in ? params_in->sel : nullptr;

    dispatch_knn_ResultHandler(
            n, distances, labels, k, metric_type, sel, r, this, x);
}
|
129
|
+
|
130
|
+
/// Range search by exhaustive scan.
///
/// The qb setting comes from params_in when it is a RaBitQSearchParameters
/// instance, and from the index's own qb otherwise. The ID selector, if
/// any, is taken from params_in.
void IndexRaBitQ::range_search(
        idx_t n,
        const float* x,
        float radius,
        RangeSearchResult* result,
        const SearchParameters* params_in) const {
    Run_search_with_dc_res r;
    r.qb = qb;
    if (const auto* rabitq_params =
                dynamic_cast<const RaBitQSearchParameters*>(params_in)) {
        r.qb = rabitq_params->qb;
    }

    const IDSelector* sel = params_in ? params_in->sel : nullptr;

    dispatch_range_ResultHandler(result, radius, metric_type, sel, r, this, x);
}
|
147
|
+
|
148
|
+
} // namespace faiss
|
@@ -0,0 +1,65 @@
|
|
1
|
+
/*
|
2
|
+
* Copyright (c) Meta Platforms, Inc. and affiliates.
|
3
|
+
*
|
4
|
+
* This source code is licensed under the MIT license found in the
|
5
|
+
* LICENSE file in the root directory of this source tree.
|
6
|
+
*/
|
7
|
+
|
8
|
+
#pragma once
|
9
|
+
|
10
|
+
#include <faiss/IndexFlatCodes.h>
|
11
|
+
#include <faiss/impl/RaBitQuantizer.h>
|
12
|
+
|
13
|
+
namespace faiss {
|
14
|
+
|
15
|
+
/// Search-time parameters for IndexRaBitQ.
struct RaBitQSearchParameters : SearchParameters {
    /// overrides the index's own qb for the duration of the search
    uint8_t qb = 0;
};
|
18
|
+
|
19
|
+
struct IndexRaBitQ : IndexFlatCodes {
|
20
|
+
RaBitQuantizer rabitq;
|
21
|
+
|
22
|
+
// center of all points
|
23
|
+
std::vector<float> center;
|
24
|
+
|
25
|
+
// the default number of bits to quantize a query with.
|
26
|
+
// use '0' to disable quantization and use raw fp32 values.
|
27
|
+
uint8_t qb = 0;
|
28
|
+
|
29
|
+
IndexRaBitQ();
|
30
|
+
|
31
|
+
IndexRaBitQ(idx_t d, MetricType metric = METRIC_L2);
|
32
|
+
|
33
|
+
void train(idx_t n, const float* x) override;
|
34
|
+
|
35
|
+
void sa_encode(idx_t n, const float* x, uint8_t* bytes) const override;
|
36
|
+
void sa_decode(idx_t n, const uint8_t* bytes, float* x) const override;
|
37
|
+
|
38
|
+
// returns a quantized-to-qb bits DC if qb > 0
|
39
|
+
// returns a default fp32-based DC if qb == 0
|
40
|
+
FlatCodesDistanceComputer* get_FlatCodesDistanceComputer() const override;
|
41
|
+
|
42
|
+
// returns a quantized-to-qb bits DC if qb_in > 0
|
43
|
+
// returns a default fp32-based DC if qb_in == 0
|
44
|
+
FlatCodesDistanceComputer* get_quantized_distance_computer(
|
45
|
+
const uint8_t qb_in) const;
|
46
|
+
|
47
|
+
// Don't rely on sa_decode(), bcz it is good for IP, but not for L2.
|
48
|
+
// As a result, use get_FlatCodesDistanceComputer() for the search.
|
49
|
+
void search(
|
50
|
+
idx_t n,
|
51
|
+
const float* x,
|
52
|
+
idx_t k,
|
53
|
+
float* distances,
|
54
|
+
idx_t* labels,
|
55
|
+
const SearchParameters* params = nullptr) const override;
|
56
|
+
|
57
|
+
void range_search(
|
58
|
+
idx_t n,
|
59
|
+
const float* x,
|
60
|
+
float radius,
|
61
|
+
RangeSearchResult* result,
|
62
|
+
const SearchParameters* params = nullptr) const override;
|
63
|
+
};
|
64
|
+
|
65
|
+
} // namespace faiss
|
@@ -254,7 +254,8 @@ void IndexIVFScalarQuantizer::add_core(
|
|
254
254
|
|
255
255
|
InvertedListScanner* IndexIVFScalarQuantizer::get_InvertedListScanner(
|
256
256
|
bool store_pairs,
|
257
|
-
const IDSelector* sel
|
257
|
+
const IDSelector* sel,
|
258
|
+
const IVFSearchParameters*) const {
|
258
259
|
return sq.select_InvertedListScanner(
|
259
260
|
metric_type, quantizer, store_pairs, sel, by_residual);
|
260
261
|
}
|
@@ -96,7 +96,8 @@ struct IndexIVFScalarQuantizer : IndexIVF {
|
|
96
96
|
|
97
97
|
InvertedListScanner* get_InvertedListScanner(
|
98
98
|
bool store_pairs,
|
99
|
-
const IDSelector* sel
|
99
|
+
const IDSelector* sel,
|
100
|
+
const IVFSearchParameters* params) const override;
|
100
101
|
|
101
102
|
void reconstruct_from_offset(int64_t list_no, int64_t offset, float* recons)
|
102
103
|
const override;
|