faiss 0.3.4 → 0.4.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +9 -0
- data/ext/faiss/ext.cpp +2 -3
- data/ext/faiss/index.cpp +13 -14
- data/ext/faiss/index_binary.cpp +2 -0
- data/ext/faiss/kmeans.cpp +2 -0
- data/ext/faiss/pca_matrix.cpp +2 -0
- data/ext/faiss/product_quantizer.cpp +2 -0
- data/ext/faiss/utils.cpp +3 -0
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/faiss/AutoTune.cpp +11 -8
- data/vendor/faiss/faiss/Clustering.cpp +0 -16
- data/vendor/faiss/faiss/IVFlib.cpp +213 -0
- data/vendor/faiss/faiss/IVFlib.h +42 -0
- data/vendor/faiss/faiss/Index.h +1 -1
- data/vendor/faiss/faiss/IndexBinaryFlat.cpp +9 -7
- data/vendor/faiss/faiss/IndexBinaryFlat.h +2 -1
- data/vendor/faiss/faiss/IndexFlatCodes.cpp +1 -1
- data/vendor/faiss/faiss/IndexFlatCodes.h +4 -2
- data/vendor/faiss/faiss/IndexHNSW.cpp +13 -20
- data/vendor/faiss/faiss/IndexHNSW.h +1 -1
- data/vendor/faiss/faiss/IndexIVF.cpp +20 -3
- data/vendor/faiss/faiss/IndexIVF.h +5 -2
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +2 -1
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +2 -1
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +2 -1
- data/vendor/faiss/faiss/IndexIVFFlat.h +2 -1
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +2 -1
- data/vendor/faiss/faiss/IndexIVFPQ.h +2 -1
- data/vendor/faiss/faiss/IndexIVFRaBitQ.cpp +277 -0
- data/vendor/faiss/faiss/IndexIVFRaBitQ.h +70 -0
- data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +2 -1
- data/vendor/faiss/faiss/IndexIVFSpectralHash.h +2 -1
- data/vendor/faiss/faiss/IndexRaBitQ.cpp +148 -0
- data/vendor/faiss/faiss/IndexRaBitQ.h +65 -0
- data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +2 -1
- data/vendor/faiss/faiss/IndexScalarQuantizer.h +2 -1
- data/vendor/faiss/faiss/clone_index.cpp +38 -3
- data/vendor/faiss/faiss/cppcontrib/factory_tools.cpp +19 -0
- data/vendor/faiss/faiss/cppcontrib/factory_tools.h +4 -11
- data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +2 -1
- data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +13 -3
- data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +1 -1
- data/vendor/faiss/faiss/gpu/StandardGpuResources.h +1 -1
- data/vendor/faiss/faiss/gpu/test/TestGpuIcmEncoder.cpp +112 -0
- data/vendor/faiss/faiss/impl/HNSW.cpp +35 -13
- data/vendor/faiss/faiss/impl/HNSW.h +5 -4
- data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -1
- data/vendor/faiss/faiss/impl/RaBitQuantizer.cpp +519 -0
- data/vendor/faiss/faiss/impl/RaBitQuantizer.h +78 -0
- data/vendor/faiss/faiss/impl/ResultHandler.h +2 -2
- data/vendor/faiss/faiss/impl/code_distance/code_distance-sve.h +3 -4
- data/vendor/faiss/faiss/impl/index_read.cpp +220 -25
- data/vendor/faiss/faiss/impl/index_write.cpp +29 -0
- data/vendor/faiss/faiss/impl/io.h +2 -2
- data/vendor/faiss/faiss/impl/io_macros.h +2 -0
- data/vendor/faiss/faiss/impl/mapped_io.cpp +313 -0
- data/vendor/faiss/faiss/impl/mapped_io.h +51 -0
- data/vendor/faiss/faiss/impl/maybe_owned_vector.h +316 -0
- data/vendor/faiss/faiss/impl/platform_macros.h +7 -3
- data/vendor/faiss/faiss/impl/simd_result_handlers.h +1 -1
- data/vendor/faiss/faiss/impl/zerocopy_io.cpp +67 -0
- data/vendor/faiss/faiss/impl/zerocopy_io.h +32 -0
- data/vendor/faiss/faiss/index_factory.cpp +16 -5
- data/vendor/faiss/faiss/index_io.h +4 -0
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +3 -3
- data/vendor/faiss/faiss/invlists/InvertedLists.h +5 -3
- data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +3 -3
- data/vendor/faiss/faiss/python/python_callbacks.cpp +24 -0
- data/vendor/faiss/faiss/python/python_callbacks.h +22 -0
- data/vendor/faiss/faiss/utils/approx_topk_hamming/approx_topk_hamming.h +30 -12
- data/vendor/faiss/faiss/utils/hamming.cpp +45 -21
- data/vendor/faiss/faiss/utils/hamming.h +7 -3
- data/vendor/faiss/faiss/utils/hamming_distance/avx512-inl.h +1 -1
- data/vendor/faiss/faiss/utils/utils.cpp +4 -4
- data/vendor/faiss/faiss/utils/utils.h +3 -3
- metadata +16 -4
@@ -8,9 +8,7 @@
|
|
8
8
|
#include <faiss/IndexHNSW.h>
|
9
9
|
|
10
10
|
#include <omp.h>
|
11
|
-
#include <cassert>
|
12
11
|
#include <cinttypes>
|
13
|
-
#include <cmath>
|
14
12
|
#include <cstdio>
|
15
13
|
#include <cstdlib>
|
16
14
|
#include <cstring>
|
@@ -124,7 +122,7 @@ void hnsw_add_vertices(
|
|
124
122
|
int i1 = n;
|
125
123
|
|
126
124
|
for (int pt_level = hist.size() - 1;
|
127
|
-
pt_level >= !index_hnsw.init_level0;
|
125
|
+
pt_level >= int(!index_hnsw.init_level0);
|
128
126
|
pt_level--) {
|
129
127
|
int i0 = i1 - hist[pt_level];
|
130
128
|
|
@@ -212,7 +210,9 @@ IndexHNSW::IndexHNSW(int d, int M, MetricType metric)
|
|
212
210
|
: Index(d, metric), hnsw(M) {}
|
213
211
|
|
214
212
|
IndexHNSW::IndexHNSW(Index* storage, int M)
|
215
|
-
: Index(storage->d, storage->metric_type), hnsw(M), storage(storage) {
|
213
|
+
: Index(storage->d, storage->metric_type), hnsw(M), storage(storage) {
|
214
|
+
metric_arg = storage->metric_arg;
|
215
|
+
}
|
216
216
|
|
217
217
|
IndexHNSW::~IndexHNSW() {
|
218
218
|
if (own_fields) {
|
@@ -237,19 +237,19 @@ void hnsw_search(
|
|
237
237
|
idx_t n,
|
238
238
|
const float* x,
|
239
239
|
BlockResultHandler& bres,
|
240
|
-
const SearchParameters*
|
240
|
+
const SearchParameters* params) {
|
241
241
|
FAISS_THROW_IF_NOT_MSG(
|
242
242
|
index->storage,
|
243
243
|
"No storage index, please use IndexHNSWFlat (or variants) "
|
244
244
|
"instead of IndexHNSW directly");
|
245
|
-
const SearchParametersHNSW* params = nullptr;
|
246
245
|
const HNSW& hnsw = index->hnsw;
|
247
246
|
|
248
247
|
int efSearch = hnsw.efSearch;
|
249
|
-
if (
|
250
|
-
|
251
|
-
|
252
|
-
|
248
|
+
if (params) {
|
249
|
+
if (const SearchParametersHNSW* hnsw_params =
|
250
|
+
dynamic_cast<const SearchParametersHNSW*>(params)) {
|
251
|
+
efSearch = hnsw_params->efSearch;
|
252
|
+
}
|
253
253
|
}
|
254
254
|
size_t n1 = 0, n2 = 0, ndis = 0, nhops = 0;
|
255
255
|
|
@@ -294,13 +294,13 @@ void IndexHNSW::search(
|
|
294
294
|
idx_t k,
|
295
295
|
float* distances,
|
296
296
|
idx_t* labels,
|
297
|
-
const SearchParameters*
|
297
|
+
const SearchParameters* params) const {
|
298
298
|
FAISS_THROW_IF_NOT(k > 0);
|
299
299
|
|
300
300
|
using RH = HeapBlockResultHandler<HNSW::C>;
|
301
301
|
RH bres(n, distances, labels, k);
|
302
302
|
|
303
|
-
hnsw_search(this, n, x, bres,
|
303
|
+
hnsw_search(this, n, x, bres, params);
|
304
304
|
|
305
305
|
if (is_similarity_metric(this->metric_type)) {
|
306
306
|
// we need to revert the negated distances
|
@@ -408,17 +408,10 @@ void IndexHNSW::search_level_0(
|
|
408
408
|
idx_t* labels,
|
409
409
|
int nprobe,
|
410
410
|
int search_type,
|
411
|
-
const SearchParameters*
|
411
|
+
const SearchParameters* params) const {
|
412
412
|
FAISS_THROW_IF_NOT(k > 0);
|
413
413
|
FAISS_THROW_IF_NOT(nprobe > 0);
|
414
414
|
|
415
|
-
const SearchParametersHNSW* params = nullptr;
|
416
|
-
|
417
|
-
if (params_in) {
|
418
|
-
params = dynamic_cast<const SearchParametersHNSW*>(params_in);
|
419
|
-
FAISS_THROW_IF_NOT_MSG(params, "params type invalid");
|
420
|
-
}
|
421
|
-
|
422
415
|
storage_idx_t ntotal = hnsw.levels.size();
|
423
416
|
|
424
417
|
using RH = HeapBlockResultHandler<HNSW::C>;
|
@@ -138,7 +138,7 @@ struct IndexHNSWPQ : IndexHNSW {
|
|
138
138
|
void train(idx_t n, const float* x) override;
|
139
139
|
};
|
140
140
|
|
141
|
-
/** SQ index topped with
|
141
|
+
/** SQ index topped with a HNSW structure to access elements
|
142
142
|
* more efficiently.
|
143
143
|
*/
|
144
144
|
struct IndexHNSWSQ : IndexHNSW {
|
@@ -455,7 +455,7 @@ void IndexIVF::search_preassigned(
|
|
455
455
|
#pragma omp parallel if (do_parallel) reduction(+ : nlistv, ndis, nheap)
|
456
456
|
{
|
457
457
|
std::unique_ptr<InvertedListScanner> scanner(
|
458
|
-
get_InvertedListScanner(store_pairs, sel));
|
458
|
+
get_InvertedListScanner(store_pairs, sel, params));
|
459
459
|
|
460
460
|
/*****************************************************
|
461
461
|
* Depending on parallel_mode, there are two possible ways
|
@@ -796,7 +796,7 @@ void IndexIVF::range_search_preassigned(
|
|
796
796
|
{
|
797
797
|
RangeSearchPartialResult pres(result);
|
798
798
|
std::unique_ptr<InvertedListScanner> scanner(
|
799
|
-
get_InvertedListScanner(store_pairs, sel));
|
799
|
+
get_InvertedListScanner(store_pairs, sel, params));
|
800
800
|
FAISS_THROW_IF_NOT(scanner.get());
|
801
801
|
all_pres[omp_get_thread_num()] = &pres;
|
802
802
|
|
@@ -912,7 +912,8 @@ void IndexIVF::range_search_preassigned(
|
|
912
912
|
|
913
913
|
InvertedListScanner* IndexIVF::get_InvertedListScanner(
|
914
914
|
bool /*store_pairs*/,
|
915
|
-
const IDSelector* /* sel
|
915
|
+
const IDSelector* /* sel */,
|
916
|
+
const IVFSearchParameters* /* params */) const {
|
916
917
|
FAISS_THROW_MSG("get_InvertedListScanner not implemented");
|
917
918
|
}
|
918
919
|
|
@@ -1290,6 +1291,14 @@ size_t InvertedListScanner::scan_codes(
|
|
1290
1291
|
|
1291
1292
|
if (!keep_max) {
|
1292
1293
|
for (size_t j = 0; j < list_size; j++) {
|
1294
|
+
if (sel != nullptr) {
|
1295
|
+
int64_t id = store_pairs ? lo_build(list_no, j) : ids[j];
|
1296
|
+
if (!sel->is_member(id)) {
|
1297
|
+
codes += code_size;
|
1298
|
+
continue;
|
1299
|
+
}
|
1300
|
+
}
|
1301
|
+
|
1293
1302
|
float dis = distance_to_code(codes);
|
1294
1303
|
if (dis < simi[0]) {
|
1295
1304
|
int64_t id = store_pairs ? lo_build(list_no, j) : ids[j];
|
@@ -1300,6 +1309,14 @@ size_t InvertedListScanner::scan_codes(
|
|
1300
1309
|
}
|
1301
1310
|
} else {
|
1302
1311
|
for (size_t j = 0; j < list_size; j++) {
|
1312
|
+
if (sel != nullptr) {
|
1313
|
+
int64_t id = store_pairs ? lo_build(list_no, j) : ids[j];
|
1314
|
+
if (!sel->is_member(id)) {
|
1315
|
+
codes += code_size;
|
1316
|
+
continue;
|
1317
|
+
}
|
1318
|
+
}
|
1319
|
+
|
1303
1320
|
float dis = distance_to_code(codes);
|
1304
1321
|
if (dis > simi[0]) {
|
1305
1322
|
int64_t id = store_pairs ? lo_build(list_no, j) : ids[j];
|
@@ -312,11 +312,14 @@ struct IndexIVF : Index, IndexIVFInterface {
|
|
312
312
|
|
313
313
|
/** Get a scanner for this index (store_pairs means ignore labels)
|
314
314
|
*
|
315
|
-
* The default search implementation uses this to compute the distances
|
315
|
+
* The default search implementation uses this to compute the distances.
|
316
|
+
* Use sel instead of params->sel, because sel is initialized with
|
317
|
+
* params->sel, but may get overridden by IndexIVF's internal logic.
|
316
318
|
*/
|
317
319
|
virtual InvertedListScanner* get_InvertedListScanner(
|
318
320
|
bool store_pairs = false,
|
319
|
-
const IDSelector* sel = nullptr
|
321
|
+
const IDSelector* sel = nullptr,
|
322
|
+
const IVFSearchParameters* params = nullptr) const;
|
320
323
|
|
321
324
|
/** reconstruct a vector. Works only if maintain_direct_map is set to 1 or 2
|
322
325
|
*/
|
@@ -253,7 +253,8 @@ struct AQInvertedListScannerLUT : AQInvertedListScanner {
|
|
253
253
|
|
254
254
|
InvertedListScanner* IndexIVFAdditiveQuantizer::get_InvertedListScanner(
|
255
255
|
bool store_pairs,
|
256
|
-
const IDSelector* sel
|
256
|
+
const IDSelector* sel,
|
257
|
+
const IVFSearchParameters*) const {
|
257
258
|
FAISS_THROW_IF_NOT(!sel);
|
258
259
|
if (metric_type == METRIC_INNER_PRODUCT) {
|
259
260
|
if (aq->search_type == AdditiveQuantizer::ST_decompress) {
|
@@ -52,7 +52,8 @@ struct IndexIVFAdditiveQuantizer : IndexIVF {
|
|
52
52
|
|
53
53
|
InvertedListScanner* get_InvertedListScanner(
|
54
54
|
bool store_pairs,
|
55
|
-
const IDSelector* sel
|
55
|
+
const IDSelector* sel,
|
56
|
+
const IVFSearchParameters* params) const override;
|
56
57
|
|
57
58
|
void sa_decode(idx_t n, const uint8_t* codes, float* x) const override;
|
58
59
|
|
@@ -223,7 +223,8 @@ InvertedListScanner* get_InvertedListScanner1(
|
|
223
223
|
|
224
224
|
InvertedListScanner* IndexIVFFlat::get_InvertedListScanner(
|
225
225
|
bool store_pairs,
|
226
|
-
const IDSelector* sel
|
226
|
+
const IDSelector* sel,
|
227
|
+
const IVFSearchParameters*) const {
|
227
228
|
if (sel) {
|
228
229
|
return get_InvertedListScanner1<true>(this, store_pairs, sel);
|
229
230
|
} else {
|
@@ -44,7 +44,8 @@ struct IndexIVFFlat : IndexIVF {
|
|
44
44
|
|
45
45
|
InvertedListScanner* get_InvertedListScanner(
|
46
46
|
bool store_pairs,
|
47
|
-
const IDSelector* sel
|
47
|
+
const IDSelector* sel,
|
48
|
+
const IVFSearchParameters* params) const override;
|
48
49
|
|
49
50
|
void reconstruct_from_offset(int64_t list_no, int64_t offset, float* recons)
|
50
51
|
const override;
|
@@ -1321,7 +1321,8 @@ InvertedListScanner* get_InvertedListScanner2(
|
|
1321
1321
|
|
1322
1322
|
InvertedListScanner* IndexIVFPQ::get_InvertedListScanner(
|
1323
1323
|
bool store_pairs,
|
1324
|
-
const IDSelector* sel
|
1324
|
+
const IDSelector* sel,
|
1325
|
+
const IVFSearchParameters*) const {
|
1325
1326
|
if (sel) {
|
1326
1327
|
return get_InvertedListScanner2<true>(*this, store_pairs, sel);
|
1327
1328
|
} else {
|
@@ -134,7 +134,8 @@ struct IndexIVFPQ : IndexIVF {
|
|
134
134
|
|
135
135
|
InvertedListScanner* get_InvertedListScanner(
|
136
136
|
bool store_pairs,
|
137
|
-
const IDSelector* sel
|
137
|
+
const IDSelector* sel,
|
138
|
+
const IVFSearchParameters* params) const override;
|
138
139
|
|
139
140
|
/// build precomputed table
|
140
141
|
void precompute_table();
|
@@ -0,0 +1,277 @@
|
|
1
|
+
/*
|
2
|
+
* Copyright (c) Meta Platforms, Inc. and affiliates.
|
3
|
+
*
|
4
|
+
* This source code is licensed under the MIT license found in the
|
5
|
+
* LICENSE file in the root directory of this source tree.
|
6
|
+
*/
|
7
|
+
|
8
|
+
#include <faiss/IndexIVFRaBitQ.h>
|
9
|
+
|
10
|
+
#include <omp.h>
|
11
|
+
|
12
|
+
#include <cstddef>
|
13
|
+
#include <cstdint>
|
14
|
+
#include <memory>
|
15
|
+
#include <vector>
|
16
|
+
|
17
|
+
#include <faiss/impl/FaissAssert.h>
|
18
|
+
#include <faiss/impl/RaBitQuantizer.h>
|
19
|
+
|
20
|
+
namespace faiss {
|
21
|
+
|
22
|
+
/// Build an IVF index whose inverted lists hold RaBitQ codes.
///
/// The IndexIVF base is created with a code size of 0 because the actual
/// size is only known once `rabitq` exists; it is then propagated to the
/// index and to its inverted lists. The index is marked untrained
/// (overriding whatever the base constructor decided) and residual encoding
/// is always enabled, since codes are computed against the list centroid.
IndexIVFRaBitQ::IndexIVFRaBitQ(
        Index* quantizer,
        const size_t d,
        const size_t nlist,
        MetricType metric)
        : IndexIVF(quantizer, d, nlist, /*code_size=*/0, metric),
          rabitq(d, metric) {
    by_residual = true;
    is_trained = false;

    // propagate the RaBitQ code size to the index and its storage
    const size_t rabitq_code_size = rabitq.code_size;
    code_size = rabitq_code_size;
    invlists->code_size = rabitq_code_size;
}
|
34
|
+
|
35
|
+
/// Default constructor; presumably used when an index is deserialized
/// (TODO confirm against index_read.cpp). Residual encoding is still forced
/// on, as for the main constructor.
IndexIVFRaBitQ::IndexIVFRaBitQ() : IndexIVF() {
    this->by_residual = true;
}
|
38
|
+
|
39
|
+
/// Train the RaBitQ encoder on the n vectors in `x`.
///
/// The coarse assignment is not used: only `rabitq` is trained here.
void IndexIVFRaBitQ::train_encoder(
        idx_t n,
        const float* x,
        const idx_t* /* assign: unused */) {
    rabitq.train(n, x);
}
|
45
|
+
|
46
|
+
/// Encode n vectors into `codes`.
///
/// Each output entry is `coarse_size + code_size` bytes.
/// - The first `coarse_size` bytes hold an optional list-number prefix,
///   present only when `include_listnos` is set.
/// - The remaining bytes hold the RaBitQ code, computed relative to the
///   centroid of the vector's assigned list.
///
/// The whole output buffer is zeroed first, so entries whose list number is
/// negative stay all-zero.
void IndexIVFRaBitQ::encode_vectors(
        idx_t n,
        const float* x,
        const idx_t* list_nos,
        uint8_t* codes,
        bool include_listnos) const {
    const size_t coarse_size = include_listnos ? coarse_code_size() : 0;
    const size_t entry_size = code_size + coarse_size;
    memset(codes, 0, entry_size * n);

#pragma omp parallel if (n > 1000)
    {
        // per-thread scratch buffer for the reconstructed centroid
        std::vector<float> centroid(d);

#pragma omp for
        for (idx_t i = 0; i < n; i++) {
            const int64_t list_no = list_nos[i];
            if (list_no < 0) {
                continue; // unassigned vector: keep the zeroed entry
            }

            uint8_t* entry = codes + i * entry_size;

            // by_residual or not, the code is computed against the centroid
            quantizer->reconstruct(list_no, centroid.data());
            rabitq.compute_codes_core(
                    x + i * d, entry + coarse_size, 1, centroid.data());

            if (coarse_size != 0) {
                encode_listno(list_no, entry);
            }
        }
    }
}
|
78
|
+
|
79
|
+
/// Add n vectors whose coarse assignment has already been computed.
///
/// @param n                     number of vectors
/// @param x                     input vectors, n * d floats
/// @param xids                  user ids; if null, ids ntotal .. ntotal+n-1
///                              are used
/// @param precomputed_idx       list number of each vector (must not be
///                              null); -1 marks a vector that goes to no list
/// @param inverted_list_context forwarded to InvertedLists::add_entry
///
/// Work is partitioned by list: thread `rank` only appends to lists whose
/// number satisfies `list_no % nt == rank`, so no two threads write to the
/// same inverted list. Entries with list number -1 are recorded in the
/// direct map by thread 0 only.
void IndexIVFRaBitQ::add_core(
        idx_t n,
        const float* x,
        const idx_t* xids,
        const idx_t* precomputed_idx,
        void* inverted_list_context) {
    FAISS_THROW_IF_NOT(is_trained);
    // add_core is public: reject a missing coarse assignment instead of
    // dereferencing a null pointer inside the parallel loop
    FAISS_THROW_IF_NOT(precomputed_idx);
    // fail early, with a clear message, if ids cannot be recorded in the
    // current direct map
    direct_map.check_can_add(xids);

    DirectMapAdd dm_add(direct_map, n, xids);

#pragma omp parallel
    {
        std::vector<uint8_t> one_code(code_size);
        std::vector<float> centroid(d);

        const int nt = omp_get_num_threads();
        const int rank = omp_get_thread_num();

        // each thread takes care of a subset of lists; the index is signed
        // like n to avoid a signed/unsigned comparison
        for (idx_t i = 0; i < n; i++) {
            const int64_t list_no = precomputed_idx[i];
            if (list_no >= 0 && list_no % nt == rank) {
                const int64_t id = xids ? xids[i] : ntotal + i;
                const float* xi = x + i * d;

                // both by_residual and !by_residual lead to the same code
                quantizer->reconstruct(list_no, centroid.data());
                rabitq.compute_codes_core(
                        xi, one_code.data(), 1, centroid.data());

                const size_t ofs = invlists->add_entry(
                        list_no, id, one_code.data(), inverted_list_context);

                dm_add.add(i, list_no, ofs);
            } else if (rank == 0 && list_no == -1) {
                dm_add.add(i, -1, 0);
            }
        }
    }

    ntotal += n;
}
|
123
|
+
|
124
|
+
struct RaBitInvertedListScanner : InvertedListScanner {
|
125
|
+
const IndexIVFRaBitQ& ivf_rabitq;
|
126
|
+
|
127
|
+
std::vector<float> reconstructed_centroid;
|
128
|
+
std::vector<float> query_vector;
|
129
|
+
|
130
|
+
std::unique_ptr<FlatCodesDistanceComputer> dc;
|
131
|
+
|
132
|
+
uint8_t qb = 0;
|
133
|
+
|
134
|
+
RaBitInvertedListScanner(
|
135
|
+
const IndexIVFRaBitQ& ivf_rabitq_in,
|
136
|
+
bool store_pairs = false,
|
137
|
+
const IDSelector* sel = nullptr,
|
138
|
+
uint8_t qb_in = 0)
|
139
|
+
: InvertedListScanner(store_pairs, sel),
|
140
|
+
ivf_rabitq{ivf_rabitq_in},
|
141
|
+
qb{qb_in} {
|
142
|
+
keep_max = is_similarity_metric(ivf_rabitq.metric_type);
|
143
|
+
code_size = ivf_rabitq.code_size;
|
144
|
+
}
|
145
|
+
|
146
|
+
/// from now on we handle this query.
|
147
|
+
void set_query(const float* query_vector_in) override {
|
148
|
+
query_vector.assign(query_vector_in, query_vector_in + ivf_rabitq.d);
|
149
|
+
|
150
|
+
internal_try_setup_dc();
|
151
|
+
}
|
152
|
+
|
153
|
+
/// following codes come from this inverted list
|
154
|
+
void set_list(idx_t list_no, float coarse_dis) override {
|
155
|
+
this->list_no = list_no;
|
156
|
+
|
157
|
+
reconstructed_centroid.resize(ivf_rabitq.d);
|
158
|
+
ivf_rabitq.quantizer->reconstruct(
|
159
|
+
list_no, reconstructed_centroid.data());
|
160
|
+
|
161
|
+
internal_try_setup_dc();
|
162
|
+
}
|
163
|
+
|
164
|
+
/// compute a single query-to-code distance
|
165
|
+
float distance_to_code(const uint8_t* code) const override {
|
166
|
+
return dc->distance_to_code(code);
|
167
|
+
}
|
168
|
+
|
169
|
+
void internal_try_setup_dc() {
|
170
|
+
if (!query_vector.empty() && !reconstructed_centroid.empty()) {
|
171
|
+
// both query_vector and centroid are available!
|
172
|
+
// set up DistanceComputer
|
173
|
+
dc.reset(ivf_rabitq.rabitq.get_distance_computer(
|
174
|
+
qb, reconstructed_centroid.data()));
|
175
|
+
|
176
|
+
dc->set_query(query_vector.data());
|
177
|
+
}
|
178
|
+
}
|
179
|
+
};
|
180
|
+
|
181
|
+
/// Create a scanner for this index; the caller owns the returned object.
///
/// If `search_params_in` is an IVFRaBitQSearchParameters, its `qb` replaces
/// the index-level `qb` for this scanner, even when left at its default.
InvertedListScanner* IndexIVFRaBitQ::get_InvertedListScanner(
        bool store_pairs,
        const IDSelector* sel,
        const IVFSearchParameters* search_params_in) const {
    const auto* rabitq_params =
            dynamic_cast<const IVFRaBitQSearchParameters*>(search_params_in);
    const uint8_t used_qb =
            (rabitq_params != nullptr) ? rabitq_params->qb : qb;

    return new RaBitInvertedListScanner(*this, store_pairs, sel, used_qb);
}
|
193
|
+
|
194
|
+
void IndexIVFRaBitQ::reconstruct_from_offset(
|
195
|
+
int64_t list_no,
|
196
|
+
int64_t offset,
|
197
|
+
float* recons) const {
|
198
|
+
const uint8_t* code = invlists->get_single_code(list_no, offset);
|
199
|
+
|
200
|
+
std::vector<float> centroid(d);
|
201
|
+
quantizer->reconstruct(list_no, centroid.data());
|
202
|
+
|
203
|
+
rabitq.decode_core(code, recons, 1, centroid.data());
|
204
|
+
}
|
205
|
+
|
206
|
+
/// Decode n standalone codes back into vectors.
///
/// Each entry holds a list-number prefix followed by a RaBitQ code. The
/// prefix selects the centroid the code is relative to.
void IndexIVFRaBitQ::sa_decode(idx_t n, const uint8_t* codes, float* x) const {
    const size_t coarse_size = coarse_code_size();
    const size_t entry_size = code_size + coarse_size;

#pragma omp parallel
    {
        // per-thread scratch buffer for the reconstructed centroid
        std::vector<float> centroid(d);

#pragma omp for
        for (idx_t i = 0; i < n; i++) {
            const uint8_t* entry = codes + i * entry_size;
            const int64_t list_no = decode_listno(entry);

            quantizer->reconstruct(list_no, centroid.data());
            rabitq.decode_core(
                    entry + coarse_size, x + i * d, 1, centroid.data());
        }
    }
}
|
224
|
+
|
225
|
+
struct IVFRaBitDistanceComputer : DistanceComputer {
|
226
|
+
const float* q = nullptr;
|
227
|
+
const IndexIVFRaBitQ* parent = nullptr;
|
228
|
+
|
229
|
+
void set_query(const float* x) override;
|
230
|
+
|
231
|
+
float operator()(idx_t i) override;
|
232
|
+
|
233
|
+
float symmetric_dis(idx_t i, idx_t j) override;
|
234
|
+
};
|
235
|
+
|
236
|
+
void IVFRaBitDistanceComputer::set_query(const float* x) {
|
237
|
+
q = x;
|
238
|
+
}
|
239
|
+
|
240
|
+
/// Distance between the current query and the stored vector with id `i`.
///
/// The vector's (list, offset) position is resolved through
/// parent->direct_map. Each call reconstructs the list centroid and
/// allocates a fresh RaBitQ distance computer, which is the per-call cost
/// of this lookup.
float IVFRaBitDistanceComputer::operator()(idx_t i) {
    // locate the (list, offset) pair that holds vector i
    const idx_t lo = parent->direct_map.get(i);
    const uint64_t list_no = lo_listno(lo);
    const uint64_t offset = lo_offset(lo);

    const uint8_t* code = parent->invlists->get_single_code(list_no, offset);

    // codes are relative to the centroid of their list
    std::vector<float> centroid(parent->d);
    parent->quantizer->reconstruct(list_no, centroid.data());

    std::unique_ptr<FlatCodesDistanceComputer> code_dc(
            parent->rabitq.get_distance_computer(parent->qb, centroid.data()));
    code_dc->set_query(q);
    const float dis = code_dc->distance_to_code(code);

    // hand the code back to the inverted lists
    parent->invlists->release_codes(list_no, code);

    return dis;
}
|
266
|
+
|
267
|
+
/// Distances between two stored vectors are not supported; always throws.
float IVFRaBitDistanceComputer::symmetric_dis(idx_t /*i*/, idx_t /*j*/) {
    FAISS_THROW_MSG("Not implemented");
}
|
270
|
+
|
271
|
+
/// Return a newly allocated distance computer bound to this index; the
/// caller takes ownership of the returned pointer.
DistanceComputer* IndexIVFRaBitQ::get_distance_computer() const {
    auto* computer = new IVFRaBitDistanceComputer();
    computer->parent = this;
    return computer;
}
|
276
|
+
|
277
|
+
} // namespace faiss
|
@@ -0,0 +1,70 @@
|
|
1
|
+
/*
|
2
|
+
* Copyright (c) Meta Platforms, Inc. and affiliates.
|
3
|
+
*
|
4
|
+
* This source code is licensed under the MIT license found in the
|
5
|
+
* LICENSE file in the root directory of this source tree.
|
6
|
+
*/
|
7
|
+
|
8
|
+
#pragma once
|
9
|
+
|
10
|
+
#include <cstddef>
|
11
|
+
#include <cstdint>
|
12
|
+
|
13
|
+
#include <faiss/Index.h>
|
14
|
+
#include <faiss/IndexIVF.h>
|
15
|
+
|
16
|
+
#include <faiss/impl/RaBitQuantizer.h>
|
17
|
+
|
18
|
+
namespace faiss {
|
19
|
+
|
20
|
+
/// IVF search parameters specific to IndexIVFRaBitQ.
///
/// When these parameters are passed to a search, `qb` replaces
/// IndexIVFRaBitQ::qb for that call. Leaving it at its default of 0
/// therefore requests unquantized (fp32) queries; it does not fall back to
/// the index-level setting.
struct IVFRaBitQSearchParameters : IVFSearchParameters {
    /// number of bits used to quantize the query; 0 = raw fp32 query
    uint8_t qb{0};
};
|
23
|
+
|
24
|
+
// * by_residual is true, just by design
|
25
|
+
struct IndexIVFRaBitQ : IndexIVF {
|
26
|
+
RaBitQuantizer rabitq;
|
27
|
+
|
28
|
+
// the default number of bits to quantize a query with.
|
29
|
+
// use '0' to disable quantization and use raw fp32 values.
|
30
|
+
uint8_t qb = 0;
|
31
|
+
|
32
|
+
IndexIVFRaBitQ(
|
33
|
+
Index* quantizer,
|
34
|
+
const size_t d,
|
35
|
+
const size_t nlist,
|
36
|
+
MetricType metric = METRIC_L2);
|
37
|
+
|
38
|
+
IndexIVFRaBitQ();
|
39
|
+
|
40
|
+
void train_encoder(idx_t n, const float* x, const idx_t* assign) override;
|
41
|
+
|
42
|
+
void encode_vectors(
|
43
|
+
idx_t n,
|
44
|
+
const float* x,
|
45
|
+
const idx_t* list_nos,
|
46
|
+
uint8_t* codes,
|
47
|
+
bool include_listnos = false) const override;
|
48
|
+
|
49
|
+
void add_core(
|
50
|
+
idx_t n,
|
51
|
+
const float* x,
|
52
|
+
const idx_t* xids,
|
53
|
+
const idx_t* precomputed_idx,
|
54
|
+
void* inverted_list_context = nullptr) override;
|
55
|
+
|
56
|
+
InvertedListScanner* get_InvertedListScanner(
|
57
|
+
bool store_pairs,
|
58
|
+
const IDSelector* sel,
|
59
|
+
const IVFSearchParameters* params) const override;
|
60
|
+
|
61
|
+
void reconstruct_from_offset(int64_t list_no, int64_t offset, float* recons)
|
62
|
+
const override;
|
63
|
+
|
64
|
+
void sa_decode(idx_t n, const uint8_t* bytes, float* x) const override;
|
65
|
+
|
66
|
+
// unfortunately
|
67
|
+
DistanceComputer* get_distance_computer() const override;
|
68
|
+
};
|
69
|
+
|
70
|
+
} // namespace faiss
|
@@ -301,7 +301,8 @@ struct BuildScanner {
|
|
301
301
|
|
302
302
|
InvertedListScanner* IndexIVFSpectralHash::get_InvertedListScanner(
|
303
303
|
bool store_pairs,
|
304
|
-
const IDSelector* sel
|
304
|
+
const IDSelector* sel,
|
305
|
+
const IVFSearchParameters*) const {
|
305
306
|
FAISS_THROW_IF_NOT(!sel);
|
306
307
|
BuildScanner bs;
|
307
308
|
return dispatch_HammingComputer(code_size, bs, this, store_pairs);
|
@@ -71,7 +71,8 @@ struct IndexIVFSpectralHash : IndexIVF {
|
|
71
71
|
|
72
72
|
InvertedListScanner* get_InvertedListScanner(
|
73
73
|
bool store_pairs,
|
74
|
-
const IDSelector* sel
|
74
|
+
const IDSelector* sel,
|
75
|
+
const IVFSearchParameters* params) const override;
|
75
76
|
|
76
77
|
/** replace the vector transform for an empty (and possibly untrained) index
|
77
78
|
*/
|