faiss 0.4.3 → 0.5.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +10 -0
- data/README.md +2 -0
- data/ext/faiss/index.cpp +33 -6
- data/ext/faiss/index_binary.cpp +17 -4
- data/ext/faiss/kmeans.cpp +6 -6
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/faiss/AutoTune.cpp +2 -3
- data/vendor/faiss/faiss/AutoTune.h +1 -1
- data/vendor/faiss/faiss/Clustering.cpp +2 -2
- data/vendor/faiss/faiss/Clustering.h +2 -2
- data/vendor/faiss/faiss/IVFlib.cpp +26 -51
- data/vendor/faiss/faiss/IVFlib.h +1 -1
- data/vendor/faiss/faiss/Index.cpp +11 -0
- data/vendor/faiss/faiss/Index.h +34 -11
- data/vendor/faiss/faiss/Index2Layer.cpp +1 -1
- data/vendor/faiss/faiss/Index2Layer.h +2 -2
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +1 -0
- data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +9 -4
- data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +5 -1
- data/vendor/faiss/faiss/IndexBinary.h +7 -7
- data/vendor/faiss/faiss/IndexBinaryFromFloat.h +1 -1
- data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +8 -2
- data/vendor/faiss/faiss/IndexBinaryHNSW.h +1 -1
- data/vendor/faiss/faiss/IndexBinaryHash.cpp +3 -3
- data/vendor/faiss/faiss/IndexBinaryHash.h +5 -5
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +7 -6
- data/vendor/faiss/faiss/IndexFastScan.cpp +125 -49
- data/vendor/faiss/faiss/IndexFastScan.h +102 -7
- data/vendor/faiss/faiss/IndexFlat.cpp +374 -4
- data/vendor/faiss/faiss/IndexFlat.h +81 -1
- data/vendor/faiss/faiss/IndexHNSW.cpp +93 -2
- data/vendor/faiss/faiss/IndexHNSW.h +58 -2
- data/vendor/faiss/faiss/IndexIDMap.cpp +14 -13
- data/vendor/faiss/faiss/IndexIDMap.h +6 -6
- data/vendor/faiss/faiss/IndexIVF.cpp +1 -1
- data/vendor/faiss/faiss/IndexIVF.h +5 -5
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +1 -1
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +9 -3
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +3 -1
- data/vendor/faiss/faiss/IndexIVFFastScan.cpp +176 -90
- data/vendor/faiss/faiss/IndexIVFFastScan.h +173 -18
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +1 -0
- data/vendor/faiss/faiss/IndexIVFFlatPanorama.cpp +251 -0
- data/vendor/faiss/faiss/IndexIVFFlatPanorama.h +64 -0
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +3 -1
- data/vendor/faiss/faiss/IndexIVFPQ.h +1 -1
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +134 -2
- data/vendor/faiss/faiss/IndexIVFPQFastScan.h +7 -1
- data/vendor/faiss/faiss/IndexIVFRaBitQ.cpp +99 -8
- data/vendor/faiss/faiss/IndexIVFRaBitQ.h +4 -1
- data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.cpp +828 -0
- data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.h +252 -0
- data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +1 -1
- data/vendor/faiss/faiss/IndexIVFSpectralHash.h +1 -1
- data/vendor/faiss/faiss/IndexNNDescent.cpp +1 -1
- data/vendor/faiss/faiss/IndexNSG.cpp +1 -1
- data/vendor/faiss/faiss/IndexNeuralNetCodec.h +1 -1
- data/vendor/faiss/faiss/IndexPQ.cpp +4 -1
- data/vendor/faiss/faiss/IndexPQ.h +1 -1
- data/vendor/faiss/faiss/IndexPQFastScan.cpp +6 -2
- data/vendor/faiss/faiss/IndexPQFastScan.h +5 -1
- data/vendor/faiss/faiss/IndexPreTransform.cpp +14 -0
- data/vendor/faiss/faiss/IndexPreTransform.h +9 -0
- data/vendor/faiss/faiss/IndexRaBitQ.cpp +96 -13
- data/vendor/faiss/faiss/IndexRaBitQ.h +11 -2
- data/vendor/faiss/faiss/IndexRaBitQFastScan.cpp +731 -0
- data/vendor/faiss/faiss/IndexRaBitQFastScan.h +175 -0
- data/vendor/faiss/faiss/IndexRefine.cpp +49 -0
- data/vendor/faiss/faiss/IndexRefine.h +17 -0
- data/vendor/faiss/faiss/IndexShards.cpp +1 -1
- data/vendor/faiss/faiss/MatrixStats.cpp +3 -3
- data/vendor/faiss/faiss/MetricType.h +1 -1
- data/vendor/faiss/faiss/VectorTransform.h +2 -2
- data/vendor/faiss/faiss/clone_index.cpp +5 -1
- data/vendor/faiss/faiss/gpu/GpuCloner.cpp +1 -1
- data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +3 -1
- data/vendor/faiss/faiss/gpu/GpuIndex.h +11 -11
- data/vendor/faiss/faiss/gpu/GpuIndexBinaryCagra.h +1 -1
- data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +1 -1
- data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +11 -7
- data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +1 -1
- data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +2 -0
- data/vendor/faiss/faiss/gpu/test/TestGpuIcmEncoder.cpp +7 -0
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +1 -1
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +1 -1
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +1 -1
- data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +2 -2
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +1 -1
- data/vendor/faiss/faiss/impl/CodePacker.h +2 -2
- data/vendor/faiss/faiss/impl/DistanceComputer.h +77 -6
- data/vendor/faiss/faiss/impl/FastScanDistancePostProcessing.h +53 -0
- data/vendor/faiss/faiss/impl/HNSW.cpp +295 -16
- data/vendor/faiss/faiss/impl/HNSW.h +35 -6
- data/vendor/faiss/faiss/impl/IDSelector.cpp +2 -2
- data/vendor/faiss/faiss/impl/IDSelector.h +4 -4
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +4 -4
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +1 -1
- data/vendor/faiss/faiss/impl/LookupTableScaler.h +1 -1
- data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -1
- data/vendor/faiss/faiss/impl/NNDescent.h +2 -2
- data/vendor/faiss/faiss/impl/NSG.cpp +1 -1
- data/vendor/faiss/faiss/impl/Panorama.cpp +193 -0
- data/vendor/faiss/faiss/impl/Panorama.h +204 -0
- data/vendor/faiss/faiss/impl/PanoramaStats.cpp +33 -0
- data/vendor/faiss/faiss/impl/PanoramaStats.h +38 -0
- data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +5 -5
- data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +1 -1
- data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +1 -1
- data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +2 -0
- data/vendor/faiss/faiss/impl/ProductQuantizer.h +1 -1
- data/vendor/faiss/faiss/impl/RaBitQStats.cpp +29 -0
- data/vendor/faiss/faiss/impl/RaBitQStats.h +56 -0
- data/vendor/faiss/faiss/impl/RaBitQUtils.cpp +294 -0
- data/vendor/faiss/faiss/impl/RaBitQUtils.h +330 -0
- data/vendor/faiss/faiss/impl/RaBitQuantizer.cpp +304 -223
- data/vendor/faiss/faiss/impl/RaBitQuantizer.h +72 -4
- data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.cpp +362 -0
- data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.h +112 -0
- data/vendor/faiss/faiss/impl/ResidualQuantizer.h +1 -1
- data/vendor/faiss/faiss/impl/ResultHandler.h +4 -4
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +7 -10
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +2 -4
- data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +7 -4
- data/vendor/faiss/faiss/impl/index_read.cpp +238 -10
- data/vendor/faiss/faiss/impl/index_write.cpp +212 -19
- data/vendor/faiss/faiss/impl/io.cpp +2 -2
- data/vendor/faiss/faiss/impl/io.h +4 -4
- data/vendor/faiss/faiss/impl/kmeans1d.cpp +1 -1
- data/vendor/faiss/faiss/impl/kmeans1d.h +1 -1
- data/vendor/faiss/faiss/impl/lattice_Zn.h +2 -2
- data/vendor/faiss/faiss/impl/mapped_io.cpp +2 -2
- data/vendor/faiss/faiss/impl/mapped_io.h +4 -3
- data/vendor/faiss/faiss/impl/maybe_owned_vector.h +8 -1
- data/vendor/faiss/faiss/impl/platform_macros.h +12 -0
- data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +30 -4
- data/vendor/faiss/faiss/impl/pq4_fast_scan.h +14 -8
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +5 -6
- data/vendor/faiss/faiss/impl/simd_result_handlers.h +55 -11
- data/vendor/faiss/faiss/impl/svs_io.cpp +86 -0
- data/vendor/faiss/faiss/impl/svs_io.h +67 -0
- data/vendor/faiss/faiss/impl/zerocopy_io.h +1 -1
- data/vendor/faiss/faiss/index_factory.cpp +217 -8
- data/vendor/faiss/faiss/index_factory.h +1 -1
- data/vendor/faiss/faiss/index_io.h +1 -1
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +1 -1
- data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -1
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +115 -1
- data/vendor/faiss/faiss/invlists/InvertedLists.h +46 -0
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +1 -1
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +1 -1
- data/vendor/faiss/faiss/svs/IndexSVSFaissUtils.h +261 -0
- data/vendor/faiss/faiss/svs/IndexSVSFlat.cpp +117 -0
- data/vendor/faiss/faiss/svs/IndexSVSFlat.h +66 -0
- data/vendor/faiss/faiss/svs/IndexSVSVamana.cpp +245 -0
- data/vendor/faiss/faiss/svs/IndexSVSVamana.h +137 -0
- data/vendor/faiss/faiss/svs/IndexSVSVamanaLVQ.cpp +39 -0
- data/vendor/faiss/faiss/svs/IndexSVSVamanaLVQ.h +42 -0
- data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.cpp +149 -0
- data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.h +58 -0
- data/vendor/faiss/faiss/utils/AlignedTable.h +1 -1
- data/vendor/faiss/faiss/utils/Heap.cpp +2 -2
- data/vendor/faiss/faiss/utils/Heap.h +3 -3
- data/vendor/faiss/faiss/utils/NeuralNet.cpp +1 -1
- data/vendor/faiss/faiss/utils/NeuralNet.h +3 -3
- data/vendor/faiss/faiss/utils/approx_topk/approx_topk.h +2 -2
- data/vendor/faiss/faiss/utils/approx_topk/avx2-inl.h +2 -2
- data/vendor/faiss/faiss/utils/approx_topk/mode.h +1 -1
- data/vendor/faiss/faiss/utils/distances.cpp +0 -3
- data/vendor/faiss/faiss/utils/distances.h +2 -2
- data/vendor/faiss/faiss/utils/extra_distances-inl.h +3 -1
- data/vendor/faiss/faiss/utils/hamming-inl.h +2 -0
- data/vendor/faiss/faiss/utils/hamming.cpp +7 -6
- data/vendor/faiss/faiss/utils/hamming.h +1 -1
- data/vendor/faiss/faiss/utils/hamming_distance/common.h +1 -2
- data/vendor/faiss/faiss/utils/partitioning.cpp +5 -5
- data/vendor/faiss/faiss/utils/partitioning.h +2 -2
- data/vendor/faiss/faiss/utils/rabitq_simd.h +222 -336
- data/vendor/faiss/faiss/utils/random.cpp +1 -1
- data/vendor/faiss/faiss/utils/simdlib_avx2.h +1 -1
- data/vendor/faiss/faiss/utils/simdlib_avx512.h +1 -1
- data/vendor/faiss/faiss/utils/simdlib_neon.h +2 -2
- data/vendor/faiss/faiss/utils/transpose/transpose-avx512-inl.h +1 -1
- data/vendor/faiss/faiss/utils/utils.cpp +9 -2
- data/vendor/faiss/faiss/utils/utils.h +2 -2
- metadata +29 -1
|
@@ -7,6 +7,7 @@
|
|
|
7
7
|
|
|
8
8
|
#include <faiss/IndexIVFPQFastScan.h>
|
|
9
9
|
|
|
10
|
+
#include <array>
|
|
10
11
|
#include <cassert>
|
|
11
12
|
#include <cstdio>
|
|
12
13
|
|
|
@@ -14,6 +15,7 @@
|
|
|
14
15
|
|
|
15
16
|
#include <faiss/impl/AuxIndexStructures.h>
|
|
16
17
|
#include <faiss/impl/FaissAssert.h>
|
|
18
|
+
#include <faiss/utils/Heap.h>
|
|
17
19
|
#include <faiss/utils/distances.h>
|
|
18
20
|
#include <faiss/utils/simdlib.h>
|
|
19
21
|
|
|
@@ -210,10 +212,11 @@ void IndexIVFPQFastScan::compute_LUT(
|
|
|
210
212
|
const float* x,
|
|
211
213
|
const CoarseQuantized& cq,
|
|
212
214
|
AlignedTable<float>& dis_tables,
|
|
213
|
-
AlignedTable<float>& biases
|
|
215
|
+
AlignedTable<float>& biases,
|
|
216
|
+
const FastScanDistancePostProcessing&) const {
|
|
214
217
|
size_t dim12 = pq.ksub * pq.M;
|
|
215
218
|
size_t d = pq.d;
|
|
216
|
-
size_t nprobe =
|
|
219
|
+
size_t nprobe = cq.nprobe;
|
|
217
220
|
|
|
218
221
|
if (by_residual) {
|
|
219
222
|
if (metric_type == METRIC_L2) {
|
|
@@ -292,4 +295,133 @@ void IndexIVFPQFastScan::compute_LUT(
|
|
|
292
295
|
}
|
|
293
296
|
}
|
|
294
297
|
|
|
298
|
+
/*********************************************************
|
|
299
|
+
* InvertedListScanner for IVFPQFS
|
|
300
|
+
*********************************************************/
|
|
301
|
+
|
|
302
|
+
namespace {
|
|
303
|
+
|
|
304
|
+
struct IVFPQFastScanScanner : InvertedListScanner {
|
|
305
|
+
static constexpr int impl = 10; // based on search_implem_10
|
|
306
|
+
static constexpr size_t nq = 1; // 1 query at a time.
|
|
307
|
+
const IndexIVFPQFastScan& index;
|
|
308
|
+
AlignedTable<uint8_t> dis_tables;
|
|
309
|
+
AlignedTable<uint16_t> biases;
|
|
310
|
+
std::array<float, 2> normalizers{};
|
|
311
|
+
const float* xi = nullptr;
|
|
312
|
+
|
|
313
|
+
IVFPQFastScanScanner(
|
|
314
|
+
const IndexIVFPQFastScan& index,
|
|
315
|
+
bool store_pairs,
|
|
316
|
+
const IDSelector* sel)
|
|
317
|
+
: InvertedListScanner(store_pairs, sel), index(index) {
|
|
318
|
+
this->keep_max = is_similarity_metric(index.metric_type);
|
|
319
|
+
}
|
|
320
|
+
|
|
321
|
+
void set_query(const float* query) override {
|
|
322
|
+
this->xi = query;
|
|
323
|
+
}
|
|
324
|
+
|
|
325
|
+
void set_list(idx_t list_no, float coarse_dis) override {
|
|
326
|
+
this->list_no = list_no;
|
|
327
|
+
IndexIVFFastScan::CoarseQuantized cq{
|
|
328
|
+
.nprobe = 1, // 1 due to explicitly passing in list_no
|
|
329
|
+
.dis = &coarse_dis, // dis from query to list_no centroid.
|
|
330
|
+
.ids = &list_no, // id of the current list we are scanning
|
|
331
|
+
};
|
|
332
|
+
FastScanDistancePostProcessing empty_context{};
|
|
333
|
+
index.compute_LUT_uint8(
|
|
334
|
+
1, xi, cq, dis_tables, biases, &normalizers[0], empty_context);
|
|
335
|
+
}
|
|
336
|
+
|
|
337
|
+
float distance_to_code(const uint8_t* /* code */) const override {
|
|
338
|
+
// It's not really possible to implement a distance_to_code since codes
|
|
339
|
+
// for 32 database vectors are intermixed.
|
|
340
|
+
FAISS_THROW_MSG("not implemented");
|
|
341
|
+
}
|
|
342
|
+
|
|
343
|
+
// Based on IVFFastScan search_implem_10, since it also deals with 1 query
|
|
344
|
+
// at a time.
|
|
345
|
+
size_t scan_codes(
|
|
346
|
+
size_t ntotal,
|
|
347
|
+
const uint8_t* codes,
|
|
348
|
+
const idx_t* ids,
|
|
349
|
+
float* distances,
|
|
350
|
+
idx_t* labels,
|
|
351
|
+
size_t k) const override {
|
|
352
|
+
// initialize the current iteration heap to the worst possible value of
|
|
353
|
+
// the prior loop
|
|
354
|
+
std::vector<float> curr_dists(k, distances[0]);
|
|
355
|
+
std::vector<idx_t> curr_labels(k, labels[0]);
|
|
356
|
+
FastScanDistancePostProcessing empty_context{};
|
|
357
|
+
std::unique_ptr<SIMDResultHandlerToFloat> handler(
|
|
358
|
+
index.make_knn_handler(
|
|
359
|
+
!keep_max,
|
|
360
|
+
impl,
|
|
361
|
+
nq,
|
|
362
|
+
k,
|
|
363
|
+
curr_dists.data(),
|
|
364
|
+
curr_labels.data(),
|
|
365
|
+
sel,
|
|
366
|
+
empty_context,
|
|
367
|
+
&normalizers[0]));
|
|
368
|
+
|
|
369
|
+
// This does not quite match search_implem_10, but it is fine because
|
|
370
|
+
// the scanner operates on a single query at a time, and this value is
|
|
371
|
+
// used as the query index. For a single query, the value is always 0.
|
|
372
|
+
int qmap1[1] = {0};
|
|
373
|
+
|
|
374
|
+
handler->q_map = qmap1;
|
|
375
|
+
handler->begin(&normalizers[0]);
|
|
376
|
+
|
|
377
|
+
const uint8_t* LUT = dis_tables.get();
|
|
378
|
+
handler->dbias = biases.get();
|
|
379
|
+
|
|
380
|
+
handler->ntotal = ntotal;
|
|
381
|
+
handler->id_map = ids;
|
|
382
|
+
|
|
383
|
+
pq4_accumulate_loop(
|
|
384
|
+
1,
|
|
385
|
+
roundup(ntotal, index.bbs),
|
|
386
|
+
index.bbs,
|
|
387
|
+
static_cast<int>(index.M2),
|
|
388
|
+
codes,
|
|
389
|
+
LUT,
|
|
390
|
+
*handler,
|
|
391
|
+
nullptr);
|
|
392
|
+
|
|
393
|
+
// The handler is for the results of this iteration.
|
|
394
|
+
// Then we need a second heap to combine across iterations.
|
|
395
|
+
handler->end();
|
|
396
|
+
if (keep_max) {
|
|
397
|
+
minheap_addn(
|
|
398
|
+
k,
|
|
399
|
+
distances,
|
|
400
|
+
labels,
|
|
401
|
+
curr_dists.data(),
|
|
402
|
+
curr_labels.data(),
|
|
403
|
+
k);
|
|
404
|
+
} else {
|
|
405
|
+
maxheap_addn(
|
|
406
|
+
k,
|
|
407
|
+
distances,
|
|
408
|
+
labels,
|
|
409
|
+
curr_dists.data(),
|
|
410
|
+
curr_labels.data(),
|
|
411
|
+
k);
|
|
412
|
+
}
|
|
413
|
+
|
|
414
|
+
return handler->num_updates();
|
|
415
|
+
}
|
|
416
|
+
};
|
|
417
|
+
|
|
418
|
+
} // anonymous namespace
|
|
419
|
+
|
|
420
|
+
/// Create a scanner that evaluates one query against one inverted list at a
/// time for this index. The search-parameters argument is ignored.
/// The scanner is allocated with `new`; presumably the caller takes
/// ownership — confirm against the call sites.
InvertedListScanner* IndexIVFPQFastScan::get_InvertedListScanner(
        bool store_pairs,
        const IDSelector* sel,
        const IVFSearchParameters* /* unused */) const {
    InvertedListScanner* scanner =
            new IVFPQFastScanScanner(*this, store_pairs, sel);
    return scanner;
}
|
|
426
|
+
|
|
295
427
|
} // namespace faiss
|
|
@@ -80,7 +80,13 @@ struct IndexIVFPQFastScan : IndexIVFFastScan {
|
|
|
80
80
|
const float* x,
|
|
81
81
|
const CoarseQuantized& cq,
|
|
82
82
|
AlignedTable<float>& dis_tables,
|
|
83
|
-
AlignedTable<float>& biases
|
|
83
|
+
AlignedTable<float>& biases,
|
|
84
|
+
const FastScanDistancePostProcessing& context) const override;
|
|
85
|
+
|
|
86
|
+
InvertedListScanner* get_InvertedListScanner(
|
|
87
|
+
bool store_pairs,
|
|
88
|
+
const IDSelector* sel,
|
|
89
|
+
const IVFSearchParameters*) const override;
|
|
84
90
|
};
|
|
85
91
|
|
|
86
92
|
} // namespace faiss
|
|
@@ -24,9 +24,10 @@ IndexIVFRaBitQ::IndexIVFRaBitQ(
|
|
|
24
24
|
const size_t d,
|
|
25
25
|
const size_t nlist,
|
|
26
26
|
MetricType metric,
|
|
27
|
-
bool own_invlists
|
|
27
|
+
bool own_invlists,
|
|
28
|
+
uint8_t nb_bits_in)
|
|
28
29
|
: IndexIVF(quantizer, d, nlist, 0, metric, own_invlists),
|
|
29
|
-
rabitq(d, metric) {
|
|
30
|
+
rabitq(d, metric, nb_bits_in) {
|
|
30
31
|
code_size = rabitq.code_size;
|
|
31
32
|
if (own_invlists) {
|
|
32
33
|
invlists->code_size = code_size;
|
|
@@ -153,17 +154,22 @@ struct RaBitInvertedListScanner : InvertedListScanner {
|
|
|
153
154
|
std::vector<float> query_vector;
|
|
154
155
|
|
|
155
156
|
std::unique_ptr<FlatCodesDistanceComputer> dc;
|
|
157
|
+
RaBitQDistanceComputer* rabitq_dc =
|
|
158
|
+
nullptr; // For multi-bit adaptive filtering
|
|
156
159
|
|
|
157
160
|
uint8_t qb = 0;
|
|
161
|
+
bool centered = false;
|
|
158
162
|
|
|
159
|
-
RaBitInvertedListScanner(
|
|
163
|
+
explicit RaBitInvertedListScanner(
|
|
160
164
|
const IndexIVFRaBitQ& ivf_rabitq_in,
|
|
161
165
|
bool store_pairs = false,
|
|
162
166
|
const IDSelector* sel = nullptr,
|
|
163
|
-
uint8_t qb_in = 0
|
|
167
|
+
uint8_t qb_in = 0,
|
|
168
|
+
bool centered = false)
|
|
164
169
|
: InvertedListScanner(store_pairs, sel),
|
|
165
170
|
ivf_rabitq{ivf_rabitq_in},
|
|
166
|
-
qb{qb_in}
|
|
171
|
+
qb{qb_in},
|
|
172
|
+
centered(centered) {
|
|
167
173
|
keep_max = is_similarity_metric(ivf_rabitq.metric_type);
|
|
168
174
|
code_size = ivf_rabitq.code_size;
|
|
169
175
|
}
|
|
@@ -191,14 +197,95 @@ struct RaBitInvertedListScanner : InvertedListScanner {
|
|
|
191
197
|
return dc->distance_to_code(code);
|
|
192
198
|
}
|
|
193
199
|
|
|
200
|
+
/// Override scan_codes to implement adaptive filtering for multi-bit codes
|
|
201
|
+
size_t scan_codes(
|
|
202
|
+
size_t list_size,
|
|
203
|
+
const uint8_t* codes,
|
|
204
|
+
const idx_t* ids,
|
|
205
|
+
float* simi,
|
|
206
|
+
idx_t* idxi,
|
|
207
|
+
size_t k) const override {
|
|
208
|
+
size_t ex_bits = ivf_rabitq.rabitq.nb_bits - 1;
|
|
209
|
+
|
|
210
|
+
// For 1-bit codes, use default implementation
|
|
211
|
+
if (ex_bits == 0 || rabitq_dc == nullptr) {
|
|
212
|
+
return InvertedListScanner::scan_codes(
|
|
213
|
+
list_size, codes, ids, simi, idxi, k);
|
|
214
|
+
}
|
|
215
|
+
|
|
216
|
+
// Multi-bit: Two-stage search with adaptive filtering
|
|
217
|
+
size_t nup = 0;
|
|
218
|
+
|
|
219
|
+
// Stats tracking for multi-bit two-stage search
|
|
220
|
+
// n_1bit_evaluations: candidates evaluated using 1-bit lower bound
|
|
221
|
+
// n_multibit_evaluations: candidates requiring full multi-bit distance
|
|
222
|
+
size_t local_1bit_evaluations = 0;
|
|
223
|
+
size_t local_multibit_evaluations = 0;
|
|
224
|
+
|
|
225
|
+
for (size_t j = 0; j < list_size; j++) {
|
|
226
|
+
if (sel != nullptr) {
|
|
227
|
+
int64_t id = store_pairs ? lo_build(list_no, j) : ids[j];
|
|
228
|
+
if (!sel->is_member(id)) {
|
|
229
|
+
codes += code_size;
|
|
230
|
+
continue;
|
|
231
|
+
}
|
|
232
|
+
}
|
|
233
|
+
|
|
234
|
+
local_1bit_evaluations++;
|
|
235
|
+
|
|
236
|
+
// Stage 1: Compute lower bound using 1-bit codes
|
|
237
|
+
float lower_bound = rabitq_dc->lower_bound_distance(codes);
|
|
238
|
+
|
|
239
|
+
// Stage 2: Adaptive filtering
|
|
240
|
+
// L2 (min-heap): filter if lower_bound < simi[0]
|
|
241
|
+
// IP (max-heap): filter if lower_bound > simi[0]
|
|
242
|
+
// Note: Using simi[0] directly (not cached) enables more aggressive
|
|
243
|
+
// filtering as the heap is updated with better candidates
|
|
244
|
+
bool should_refine = keep_max ? (lower_bound > simi[0])
|
|
245
|
+
: (lower_bound < simi[0]);
|
|
246
|
+
|
|
247
|
+
if (should_refine) {
|
|
248
|
+
local_multibit_evaluations++;
|
|
249
|
+
// Lower bound is promising, compute full distance
|
|
250
|
+
float dis = distance_to_code(codes);
|
|
251
|
+
|
|
252
|
+
// Check if distance improves heap
|
|
253
|
+
bool improves_heap =
|
|
254
|
+
keep_max ? (dis > simi[0]) : (dis < simi[0]);
|
|
255
|
+
|
|
256
|
+
if (improves_heap) {
|
|
257
|
+
int64_t id = store_pairs ? lo_build(list_no, j) : ids[j];
|
|
258
|
+
if (keep_max) {
|
|
259
|
+
minheap_replace_top(k, simi, idxi, dis, id);
|
|
260
|
+
} else {
|
|
261
|
+
maxheap_replace_top(k, simi, idxi, dis, id);
|
|
262
|
+
}
|
|
263
|
+
nup++;
|
|
264
|
+
}
|
|
265
|
+
}
|
|
266
|
+
codes += code_size;
|
|
267
|
+
}
|
|
268
|
+
|
|
269
|
+
// Update global stats atomically
|
|
270
|
+
#pragma omp atomic
|
|
271
|
+
rabitq_stats.n_1bit_evaluations += local_1bit_evaluations;
|
|
272
|
+
#pragma omp atomic
|
|
273
|
+
rabitq_stats.n_multibit_evaluations += local_multibit_evaluations;
|
|
274
|
+
|
|
275
|
+
return nup;
|
|
276
|
+
}
|
|
277
|
+
|
|
194
278
|
void internal_try_setup_dc() {
|
|
195
279
|
if (!query_vector.empty() && !reconstructed_centroid.empty()) {
|
|
196
280
|
// both query_vector and centroid are available!
|
|
197
281
|
// set up DistanceComputer
|
|
198
282
|
dc.reset(ivf_rabitq.rabitq.get_distance_computer(
|
|
199
|
-
qb, reconstructed_centroid.data()));
|
|
283
|
+
qb, reconstructed_centroid.data(), centered));
|
|
200
284
|
|
|
201
285
|
dc->set_query(query_vector.data());
|
|
286
|
+
|
|
287
|
+
// Try to cast to RaBitQDistanceComputer for multi-bit support
|
|
288
|
+
rabitq_dc = dynamic_cast<RaBitQDistanceComputer*>(dc.get());
|
|
202
289
|
}
|
|
203
290
|
}
|
|
204
291
|
};
|
|
@@ -208,12 +295,15 @@ InvertedListScanner* IndexIVFRaBitQ::get_InvertedListScanner(
|
|
|
208
295
|
const IDSelector* sel,
|
|
209
296
|
const IVFSearchParameters* search_params_in) const {
|
|
210
297
|
uint8_t used_qb = qb;
|
|
298
|
+
bool centered = false;
|
|
211
299
|
if (auto params = dynamic_cast<const IVFRaBitQSearchParameters*>(
|
|
212
300
|
search_params_in)) {
|
|
213
301
|
used_qb = params->qb;
|
|
302
|
+
centered = params->centered;
|
|
214
303
|
}
|
|
215
304
|
|
|
216
|
-
return new RaBitInvertedListScanner(
|
|
305
|
+
return new RaBitInvertedListScanner(
|
|
306
|
+
*this, store_pairs, sel, used_qb, centered);
|
|
217
307
|
}
|
|
218
308
|
|
|
219
309
|
void IndexIVFRaBitQ::reconstruct_from_offset(
|
|
@@ -278,7 +368,8 @@ float IVFRaBitDistanceComputer::operator()(idx_t i) {
|
|
|
278
368
|
float distance = 0;
|
|
279
369
|
|
|
280
370
|
std::unique_ptr<FlatCodesDistanceComputer> dc(
|
|
281
|
-
parent->rabitq.get_distance_computer(
|
|
371
|
+
parent->rabitq.get_distance_computer(
|
|
372
|
+
parent->qb, centroid.data(), /*centered=*/false));
|
|
282
373
|
dc->set_query(q);
|
|
283
374
|
distance = dc->distance_to_code(code);
|
|
284
375
|
|
|
@@ -13,12 +13,14 @@
|
|
|
13
13
|
#include <faiss/Index.h>
|
|
14
14
|
#include <faiss/IndexIVF.h>
|
|
15
15
|
|
|
16
|
+
#include <faiss/impl/RaBitQStats.h>
|
|
16
17
|
#include <faiss/impl/RaBitQuantizer.h>
|
|
17
18
|
|
|
18
19
|
namespace faiss {
|
|
19
20
|
|
|
20
21
|
/// Per-search overrides for IndexIVFRaBitQ. When an instance is passed to
/// get_InvertedListScanner(), both fields are copied into the scanner and
/// later forwarded to the quantizer's get_distance_computer().
struct IVFRaBitQSearchParameters : IVFSearchParameters {
    /// Replaces the index-level `qb` for this search.
    /// NOTE(review): presumably the number of query quantization bits —
    /// confirm against RaBitQuantizer.
    uint8_t qb{0};
    /// Forwarded as the `centered` argument of get_distance_computer().
    /// NOTE(review): exact semantics are not visible here — confirm.
    bool centered{false};
};
|
|
23
25
|
|
|
24
26
|
// * by_residual is true, just by design
|
|
@@ -34,7 +36,8 @@ struct IndexIVFRaBitQ : IndexIVF {
|
|
|
34
36
|
const size_t d,
|
|
35
37
|
const size_t nlist,
|
|
36
38
|
MetricType metric = METRIC_L2,
|
|
37
|
-
bool own_invlists = true
|
|
39
|
+
bool own_invlists = true,
|
|
40
|
+
uint8_t nb_bits = 1);
|
|
38
41
|
|
|
39
42
|
IndexIVFRaBitQ();
|
|
40
43
|
|