faiss 0.5.2 → 0.6.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +16 -0
- data/LICENSE.txt +1 -1
- data/ext/faiss/ext.cpp +1 -1
- data/ext/faiss/extconf.rb +5 -6
- data/ext/faiss/index_binary.cpp +76 -17
- data/ext/faiss/{index.cpp → index_rb.cpp} +108 -35
- data/ext/faiss/kmeans.cpp +12 -9
- data/ext/faiss/numo.hpp +11 -9
- data/ext/faiss/pca_matrix.cpp +10 -8
- data/ext/faiss/product_quantizer.cpp +14 -12
- data/ext/faiss/{utils.cpp → utils_rb.cpp} +10 -3
- data/ext/faiss/{utils.h → utils_rb.h} +6 -0
- data/lib/faiss/version.rb +1 -1
- data/lib/faiss.rb +1 -1
- data/vendor/faiss/faiss/AutoTune.cpp +130 -11
- data/vendor/faiss/faiss/AutoTune.h +14 -1
- data/vendor/faiss/faiss/Clustering.cpp +59 -10
- data/vendor/faiss/faiss/Clustering.h +12 -0
- data/vendor/faiss/faiss/IVFlib.cpp +31 -28
- data/vendor/faiss/faiss/Index.cpp +20 -8
- data/vendor/faiss/faiss/Index.h +25 -3
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +19 -24
- data/vendor/faiss/faiss/IndexBinary.cpp +1 -0
- data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +9 -4
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +45 -11
- data/vendor/faiss/faiss/IndexFastScan.cpp +35 -22
- data/vendor/faiss/faiss/IndexFastScan.h +10 -1
- data/vendor/faiss/faiss/IndexFlat.cpp +193 -136
- data/vendor/faiss/faiss/IndexFlat.h +16 -1
- data/vendor/faiss/faiss/IndexFlatCodes.cpp +46 -22
- data/vendor/faiss/faiss/IndexFlatCodes.h +7 -1
- data/vendor/faiss/faiss/IndexHNSW.cpp +24 -50
- data/vendor/faiss/faiss/IndexHNSW.h +14 -12
- data/vendor/faiss/faiss/IndexIDMap.cpp +1 -1
- data/vendor/faiss/faiss/IndexIVF.cpp +76 -49
- data/vendor/faiss/faiss/IndexIVF.h +14 -4
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +11 -8
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +2 -2
- data/vendor/faiss/faiss/IndexIVFFastScan.cpp +25 -14
- data/vendor/faiss/faiss/IndexIVFFastScan.h +26 -22
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +10 -61
- data/vendor/faiss/faiss/IndexIVFFlatPanorama.cpp +39 -111
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +89 -147
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +37 -5
- data/vendor/faiss/faiss/IndexIVFPQR.cpp +2 -1
- data/vendor/faiss/faiss/IndexIVFRaBitQ.cpp +42 -30
- data/vendor/faiss/faiss/IndexIVFRaBitQ.h +2 -2
- data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.cpp +246 -97
- data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.h +32 -29
- data/vendor/faiss/faiss/IndexLSH.cpp +8 -6
- data/vendor/faiss/faiss/IndexLattice.cpp +29 -24
- data/vendor/faiss/faiss/IndexNNDescent.cpp +1 -0
- data/vendor/faiss/faiss/IndexNSG.cpp +2 -1
- data/vendor/faiss/faiss/IndexNSG.h +0 -2
- data/vendor/faiss/faiss/IndexNeuralNetCodec.cpp +1 -1
- data/vendor/faiss/faiss/IndexPQ.cpp +19 -10
- data/vendor/faiss/faiss/IndexRaBitQ.cpp +26 -13
- data/vendor/faiss/faiss/IndexRaBitQ.h +2 -2
- data/vendor/faiss/faiss/IndexRaBitQFastScan.cpp +132 -78
- data/vendor/faiss/faiss/IndexRaBitQFastScan.h +14 -12
- data/vendor/faiss/faiss/IndexRefine.cpp +0 -30
- data/vendor/faiss/faiss/IndexShards.cpp +3 -4
- data/vendor/faiss/faiss/MetricType.h +16 -0
- data/vendor/faiss/faiss/VectorTransform.cpp +120 -0
- data/vendor/faiss/faiss/VectorTransform.h +23 -0
- data/vendor/faiss/faiss/clone_index.cpp +7 -4
- data/vendor/faiss/faiss/{cppcontrib/factory_tools.cpp → factory_tools.cpp} +1 -1
- data/vendor/faiss/faiss/gpu/GpuCloner.cpp +1 -1
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +37 -11
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +0 -28
- data/vendor/faiss/faiss/impl/ClusteringInitialization.cpp +367 -0
- data/vendor/faiss/faiss/impl/ClusteringInitialization.h +107 -0
- data/vendor/faiss/faiss/impl/CodePacker.cpp +4 -0
- data/vendor/faiss/faiss/impl/CodePacker.h +11 -3
- data/vendor/faiss/faiss/impl/CodePackerRaBitQ.cpp +83 -0
- data/vendor/faiss/faiss/impl/CodePackerRaBitQ.h +47 -0
- data/vendor/faiss/faiss/impl/FaissAssert.h +60 -2
- data/vendor/faiss/faiss/impl/HNSW.cpp +25 -34
- data/vendor/faiss/faiss/impl/HNSW.h +8 -6
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +34 -27
- data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -1
- data/vendor/faiss/faiss/impl/NSG.cpp +6 -5
- data/vendor/faiss/faiss/impl/NSG.h +17 -7
- data/vendor/faiss/faiss/impl/Panorama.cpp +53 -46
- data/vendor/faiss/faiss/impl/Panorama.h +22 -6
- data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +16 -5
- data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +70 -58
- data/vendor/faiss/faiss/impl/RaBitQUtils.cpp +92 -0
- data/vendor/faiss/faiss/impl/RaBitQUtils.h +93 -31
- data/vendor/faiss/faiss/impl/RaBitQuantizer.cpp +12 -28
- data/vendor/faiss/faiss/impl/RaBitQuantizer.h +3 -10
- data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.cpp +15 -41
- data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.h +0 -4
- data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +14 -9
- data/vendor/faiss/faiss/impl/ResultHandler.h +131 -50
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +67 -2358
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +0 -2
- data/vendor/faiss/faiss/impl/VisitedTable.cpp +42 -0
- data/vendor/faiss/faiss/impl/VisitedTable.h +69 -0
- data/vendor/faiss/faiss/impl/expanded_scanners.h +158 -0
- data/vendor/faiss/faiss/impl/index_read.cpp +829 -471
- data/vendor/faiss/faiss/impl/index_read_utils.h +0 -1
- data/vendor/faiss/faiss/impl/index_write.cpp +17 -8
- data/vendor/faiss/faiss/impl/lattice_Zn.cpp +47 -20
- data/vendor/faiss/faiss/impl/mapped_io.cpp +9 -2
- data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +7 -2
- data/vendor/faiss/faiss/impl/pq4_fast_scan.h +11 -3
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +19 -13
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +29 -21
- data/vendor/faiss/faiss/impl/{code_distance/code_distance-avx2.h → pq_code_distance/pq_code_distance-avx2.cpp} +42 -215
- data/vendor/faiss/faiss/impl/{code_distance/code_distance-avx512.h → pq_code_distance/pq_code_distance-avx512.cpp} +68 -107
- data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-generic.cpp +141 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-inl.h +23 -0
- data/vendor/faiss/faiss/impl/{code_distance/code_distance-sve.h → pq_code_distance/pq_code_distance-sve.cpp} +57 -144
- data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +9 -6
- data/vendor/faiss/faiss/impl/scalar_quantizer/codecs.h +121 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/distance_computers.h +136 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/quantizers.h +280 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/scanners.h +164 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/similarities.h +94 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx2.cpp +455 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx512.cpp +430 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-dispatch.h +329 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-neon.cpp +467 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/training.cpp +203 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/training.h +42 -0
- data/vendor/faiss/faiss/impl/simd_dispatch.h +139 -0
- data/vendor/faiss/faiss/impl/simd_result_handlers.h +18 -18
- data/vendor/faiss/faiss/index_factory.cpp +35 -16
- data/vendor/faiss/faiss/index_io.h +29 -3
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +7 -4
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +1 -1
- data/vendor/faiss/faiss/svs/IndexSVSFaissUtils.h +9 -19
- data/vendor/faiss/faiss/svs/IndexSVSFlat.h +2 -0
- data/vendor/faiss/faiss/svs/IndexSVSVamana.h +2 -1
- data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.cpp +9 -1
- data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.h +9 -0
- data/vendor/faiss/faiss/utils/Heap.cpp +46 -0
- data/vendor/faiss/faiss/utils/Heap.h +21 -0
- data/vendor/faiss/faiss/utils/NeuralNet.cpp +10 -7
- data/vendor/faiss/faiss/utils/distances.cpp +141 -23
- data/vendor/faiss/faiss/utils/distances.h +98 -0
- data/vendor/faiss/faiss/utils/distances_dispatch.h +170 -0
- data/vendor/faiss/faiss/utils/distances_simd.cpp +74 -3511
- data/vendor/faiss/faiss/utils/extra_distances-inl.h +164 -157
- data/vendor/faiss/faiss/utils/extra_distances.cpp +52 -95
- data/vendor/faiss/faiss/utils/extra_distances.h +47 -1
- data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +0 -1
- data/vendor/faiss/faiss/utils/partitioning.cpp +1 -1
- data/vendor/faiss/faiss/utils/pq_code_distance.h +251 -0
- data/vendor/faiss/faiss/utils/rabitq_simd.h +260 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_aarch64.cpp +150 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_arm_sve.cpp +568 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_autovec-inl.h +153 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_avx2.cpp +1185 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_avx512.cpp +1092 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_sse-inl.h +391 -0
- data/vendor/faiss/faiss/utils/simd_levels.cpp +322 -0
- data/vendor/faiss/faiss/utils/simd_levels.h +91 -0
- data/vendor/faiss/faiss/utils/simdlib_avx2.h +12 -1
- data/vendor/faiss/faiss/utils/simdlib_avx512.h +69 -0
- data/vendor/faiss/faiss/utils/simdlib_neon.h +6 -0
- data/vendor/faiss/faiss/utils/sorting.cpp +4 -4
- data/vendor/faiss/faiss/utils/utils.cpp +16 -9
- metadata +47 -18
- data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +0 -81
- data/vendor/faiss/faiss/impl/code_distance/code_distance.h +0 -186
- /data/vendor/faiss/faiss/{cppcontrib/factory_tools.h → factory_tools.h} +0 -0
|
@@ -8,8 +8,11 @@
|
|
|
8
8
|
#include <faiss/IndexIVFRaBitQFastScan.h>
|
|
9
9
|
|
|
10
10
|
#include <algorithm>
|
|
11
|
+
#include <array>
|
|
11
12
|
#include <cstdio>
|
|
13
|
+
#include <memory>
|
|
12
14
|
|
|
15
|
+
#include <faiss/impl/CodePackerRaBitQ.h>
|
|
13
16
|
#include <faiss/impl/FaissAssert.h>
|
|
14
17
|
#include <faiss/impl/FastScanDistancePostProcessing.h>
|
|
15
18
|
#include <faiss/impl/RaBitQUtils.h>
|
|
@@ -79,8 +82,6 @@ IndexIVFRaBitQFastScan::IndexIVFRaBitQFastScan(
|
|
|
79
82
|
if (own_invlists) {
|
|
80
83
|
replace_invlists(new BlockInvertedLists(nlist, get_CodePacker()), true);
|
|
81
84
|
}
|
|
82
|
-
|
|
83
|
-
flat_storage.clear();
|
|
84
85
|
}
|
|
85
86
|
|
|
86
87
|
// Constructor that converts an existing IndexIVFRaBitQ to FastScan format
|
|
@@ -97,41 +98,52 @@ IndexIVFRaBitQFastScan::IndexIVFRaBitQFastScan(
|
|
|
97
98
|
rabitq(orig.rabitq) {}
|
|
98
99
|
|
|
99
100
|
size_t IndexIVFRaBitQFastScan::compute_per_vector_storage_size() const {
|
|
100
|
-
|
|
101
|
+
return rabitq_utils::compute_per_vector_storage_size(rabitq.nb_bits, d);
|
|
102
|
+
}
|
|
101
103
|
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
} else {
|
|
106
|
-
// Multi-bit: SignBitFactorsWithError + ExtraBitsFactors + ex-codes
|
|
107
|
-
return sizeof(SignBitFactorsWithError) + sizeof(ExtraBitsFactors) +
|
|
108
|
-
(d * ex_bits + 7) / 8;
|
|
109
|
-
}
|
|
104
|
+
size_t IndexIVFRaBitQFastScan::code_packing_stride() const {
|
|
105
|
+
// Use code_size as stride to skip embedded factor data during packing
|
|
106
|
+
return code_size;
|
|
110
107
|
}
|
|
111
108
|
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
109
|
+
CodePacker* IndexIVFRaBitQFastScan::get_CodePacker() const {
|
|
110
|
+
return new CodePackerRaBitQ(M2, bbs, compute_per_vector_storage_size());
|
|
111
|
+
}
|
|
112
|
+
|
|
113
|
+
/*********************************************************
|
|
114
|
+
* postprocess_packed_codes: write auxiliary data into blocks
|
|
115
|
+
*********************************************************/
|
|
116
|
+
|
|
117
|
+
void IndexIVFRaBitQFastScan::postprocess_packed_codes(
|
|
118
|
+
idx_t list_no,
|
|
119
|
+
size_t list_offset,
|
|
120
|
+
size_t n_added,
|
|
121
|
+
const uint8_t* flat_codes) {
|
|
122
|
+
auto* bil = dynamic_cast<BlockInvertedLists*>(invlists);
|
|
123
|
+
FAISS_THROW_IF_NOT(bil);
|
|
119
124
|
|
|
120
|
-
|
|
125
|
+
uint8_t* block_data = bil->codes[list_no].data();
|
|
126
|
+
const size_t storage_size = compute_per_vector_storage_size();
|
|
121
127
|
const size_t bit_pattern_size = (d + 7) / 8;
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
+
const size_t packed_block_size = ((M2 + 1) / 2) * bbs;
|
|
129
|
+
const size_t full_block_size = get_block_stride();
|
|
130
|
+
|
|
131
|
+
for (size_t i = 0; i < n_added; i++) {
|
|
132
|
+
const uint8_t* src = flat_codes + i * code_size + bit_pattern_size;
|
|
133
|
+
uint8_t* dst = rabitq_utils::get_block_aux_ptr(
|
|
134
|
+
block_data,
|
|
135
|
+
list_offset + i,
|
|
136
|
+
bbs,
|
|
137
|
+
packed_block_size,
|
|
138
|
+
full_block_size,
|
|
139
|
+
storage_size);
|
|
140
|
+
memcpy(dst, src, storage_size);
|
|
128
141
|
}
|
|
129
142
|
}
|
|
130
143
|
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
}
|
|
144
|
+
/*********************************************************
|
|
145
|
+
* train_encoder
|
|
146
|
+
*********************************************************/
|
|
135
147
|
|
|
136
148
|
void IndexIVFRaBitQFastScan::train_encoder(
|
|
137
149
|
idx_t n,
|
|
@@ -271,10 +283,11 @@ void IndexIVFRaBitQFastScan::compute_residual_LUT(
|
|
|
271
283
|
rotated_q,
|
|
272
284
|
rotated_qq);
|
|
273
285
|
|
|
274
|
-
// Override query norm for inner product if original query is provided
|
|
275
286
|
if (metric_type == MetricType::METRIC_INNER_PRODUCT &&
|
|
276
287
|
original_query != nullptr) {
|
|
277
288
|
query_factors.qr_norm_L2sqr = fvec_norm_L2sqr(original_query, d);
|
|
289
|
+
query_factors.q_dot_c = query_factors.qr_norm_L2sqr -
|
|
290
|
+
fvec_inner_product(original_query, residual, d);
|
|
278
291
|
}
|
|
279
292
|
|
|
280
293
|
const size_t ex_bits = rabitq.nb_bits - 1;
|
|
@@ -441,23 +454,22 @@ void IndexIVFRaBitQFastScan::reconstruct_from_offset(
|
|
|
441
454
|
}
|
|
442
455
|
}
|
|
443
456
|
|
|
444
|
-
|
|
445
|
-
|
|
446
|
-
|
|
447
|
-
|
|
448
|
-
|
|
449
|
-
|
|
450
|
-
|
|
451
|
-
|
|
452
|
-
|
|
453
|
-
|
|
454
|
-
|
|
455
|
-
|
|
456
|
-
|
|
457
|
-
|
|
458
|
-
|
|
459
|
-
|
|
460
|
-
}
|
|
457
|
+
const size_t storage_size = compute_per_vector_storage_size();
|
|
458
|
+
const size_t packed_block_size = ((M2 + 1) / 2) * bbs;
|
|
459
|
+
const size_t full_block_size = get_block_stride();
|
|
460
|
+
|
|
461
|
+
InvertedLists::ScopedCodes list_block_codes(invlists, list_no);
|
|
462
|
+
const uint8_t* aux_ptr = rabitq_utils::get_block_aux_ptr(
|
|
463
|
+
list_block_codes.get(),
|
|
464
|
+
offset,
|
|
465
|
+
bbs,
|
|
466
|
+
packed_block_size,
|
|
467
|
+
full_block_size,
|
|
468
|
+
storage_size);
|
|
469
|
+
|
|
470
|
+
const auto& base_factors =
|
|
471
|
+
*reinterpret_cast<const SignBitFactors*>(aux_ptr);
|
|
472
|
+
const float dp_multiplier = base_factors.dp_multiplier;
|
|
461
473
|
|
|
462
474
|
// Decode residual directly using dp_multiplier
|
|
463
475
|
std::vector<float> residual(d);
|
|
@@ -573,7 +585,11 @@ IndexIVFRaBitQFastScan::IVFRaBitQHeapHandler<C>::IVFRaBitQHeapHandler(
|
|
|
573
585
|
nq(nq_val),
|
|
574
586
|
k(k_val),
|
|
575
587
|
context(ctx),
|
|
576
|
-
is_multibit(multibit)
|
|
588
|
+
is_multibit(multibit),
|
|
589
|
+
storage_size(idx->compute_per_vector_storage_size()),
|
|
590
|
+
packed_block_size(((idx->M2 + 1) / 2) * idx->bbs),
|
|
591
|
+
full_block_size(idx->get_block_stride()),
|
|
592
|
+
packer(idx->get_CodePacker()) {
|
|
577
593
|
current_list_no = 0;
|
|
578
594
|
probe_indices.clear();
|
|
579
595
|
|
|
@@ -649,10 +665,13 @@ void IndexIVFRaBitQFastScan::IVFRaBitQHeapHandler<C>::handle(
|
|
|
649
665
|
|
|
650
666
|
const float normalized_distance = d32tab[j] * one_a + bias;
|
|
651
667
|
|
|
652
|
-
|
|
653
|
-
|
|
654
|
-
|
|
655
|
-
index->
|
|
668
|
+
const uint8_t* base_ptr = rabitq_utils::get_block_aux_ptr(
|
|
669
|
+
list_codes_ptr,
|
|
670
|
+
idx_base + j,
|
|
671
|
+
index->bbs,
|
|
672
|
+
packed_block_size,
|
|
673
|
+
full_block_size,
|
|
674
|
+
storage_size);
|
|
656
675
|
|
|
657
676
|
if (is_multibit) {
|
|
658
677
|
// Track candidates actually considered for two-stage filtering
|
|
@@ -671,17 +690,18 @@ void IndexIVFRaBitQFastScan::IVFRaBitQHeapHandler<C>::handle(
|
|
|
671
690
|
index->qb,
|
|
672
691
|
index->d);
|
|
673
692
|
|
|
674
|
-
// Compute lower bound using error bound
|
|
675
|
-
float lower_bound =
|
|
676
|
-
compute_lower_bound(dist_1bit, result_id, local_q, q);
|
|
677
|
-
|
|
678
693
|
// Adaptive filtering: decide whether to compute full distance
|
|
679
694
|
const bool is_similarity =
|
|
680
695
|
index->metric_type == MetricType::METRIC_INNER_PRODUCT;
|
|
681
|
-
bool should_refine = is_similarity
|
|
682
|
-
? (lower_bound > heap_dis[0]) // IP: keep if better
|
|
683
|
-
: (lower_bound < heap_dis[0]); // L2: keep if better
|
|
684
696
|
|
|
697
|
+
float g_error = query_factors.g_error;
|
|
698
|
+
|
|
699
|
+
bool should_refine = rabitq_utils::should_refine_candidate(
|
|
700
|
+
dist_1bit,
|
|
701
|
+
full_factors.f_error,
|
|
702
|
+
g_error,
|
|
703
|
+
heap_dis[0],
|
|
704
|
+
is_similarity);
|
|
685
705
|
if (should_refine) {
|
|
686
706
|
local_multibit_evaluations++;
|
|
687
707
|
|
|
@@ -696,6 +716,7 @@ void IndexIVFRaBitQFastScan::IVFRaBitQHeapHandler<C>::handle(
|
|
|
696
716
|
if (Cfloat::cmp(heap_dis[0], dist_full)) {
|
|
697
717
|
heap_replace_top<Cfloat>(
|
|
698
718
|
k, heap_dis, heap_ids, dist_full, result_id);
|
|
719
|
+
nup++;
|
|
699
720
|
}
|
|
700
721
|
}
|
|
701
722
|
} else {
|
|
@@ -715,6 +736,7 @@ void IndexIVFRaBitQFastScan::IVFRaBitQHeapHandler<C>::handle(
|
|
|
715
736
|
if (Cfloat::cmp(heap_dis[0], adjusted_distance)) {
|
|
716
737
|
heap_replace_top<Cfloat>(
|
|
717
738
|
k, heap_dis, heap_ids, adjusted_distance, result_id);
|
|
739
|
+
nup++;
|
|
718
740
|
}
|
|
719
741
|
}
|
|
720
742
|
}
|
|
@@ -732,6 +754,7 @@ void IndexIVFRaBitQFastScan::IVFRaBitQHeapHandler<C>::set_list_context(
|
|
|
732
754
|
const std::vector<int>& probe_map) {
|
|
733
755
|
current_list_no = list_no;
|
|
734
756
|
probe_indices = probe_map;
|
|
757
|
+
list_codes_ptr = index->invlists->get_codes(list_no);
|
|
735
758
|
}
|
|
736
759
|
|
|
737
760
|
template <class C>
|
|
@@ -750,49 +773,23 @@ void IndexIVFRaBitQFastScan::IVFRaBitQHeapHandler<C>::end() {
|
|
|
750
773
|
}
|
|
751
774
|
}
|
|
752
775
|
|
|
753
|
-
template <class C>
|
|
754
|
-
float IndexIVFRaBitQFastScan::IVFRaBitQHeapHandler<C>::compute_lower_bound(
|
|
755
|
-
float dist_1bit,
|
|
756
|
-
size_t db_idx,
|
|
757
|
-
size_t local_q,
|
|
758
|
-
size_t global_q) const {
|
|
759
|
-
// Access f_error from SignBitFactorsWithError in flat storage
|
|
760
|
-
const size_t storage_size = index->compute_per_vector_storage_size();
|
|
761
|
-
const uint8_t* base_ptr =
|
|
762
|
-
index->flat_storage.data() + db_idx * storage_size;
|
|
763
|
-
const SignBitFactorsWithError& db_factors =
|
|
764
|
-
*reinterpret_cast<const SignBitFactorsWithError*>(base_ptr);
|
|
765
|
-
float f_error = db_factors.f_error;
|
|
766
|
-
|
|
767
|
-
// Get g_error from query factors
|
|
768
|
-
// Use local_q to access probe_indices (batch-local), global_q for storage
|
|
769
|
-
float g_error = 0.0f;
|
|
770
|
-
if (context && context->query_factors) {
|
|
771
|
-
size_t probe_rank = probe_indices[local_q];
|
|
772
|
-
size_t nprobe = context->nprobe > 0 ? context->nprobe : index->nprobe;
|
|
773
|
-
size_t storage_idx = global_q * nprobe + probe_rank;
|
|
774
|
-
g_error = context->query_factors[storage_idx].g_error;
|
|
775
|
-
}
|
|
776
|
-
|
|
777
|
-
// Compute error adjustment: f_error * g_error
|
|
778
|
-
float error_adjustment = f_error * g_error;
|
|
779
|
-
|
|
780
|
-
return dist_1bit - error_adjustment;
|
|
781
|
-
}
|
|
782
|
-
|
|
783
776
|
template <class C>
|
|
784
777
|
float IndexIVFRaBitQFastScan::IVFRaBitQHeapHandler<C>::
|
|
785
778
|
compute_full_multibit_distance(
|
|
786
|
-
size_t db_idx
|
|
779
|
+
size_t /*db_idx*/,
|
|
787
780
|
size_t local_q,
|
|
788
781
|
size_t global_q,
|
|
789
782
|
size_t local_offset) const {
|
|
790
783
|
const size_t ex_bits = index->rabitq.nb_bits - 1;
|
|
791
784
|
const size_t dim = index->d;
|
|
792
785
|
|
|
793
|
-
const
|
|
794
|
-
|
|
795
|
-
|
|
786
|
+
const uint8_t* base_ptr = rabitq_utils::get_block_aux_ptr(
|
|
787
|
+
list_codes_ptr,
|
|
788
|
+
local_offset,
|
|
789
|
+
index->bbs,
|
|
790
|
+
packed_block_size,
|
|
791
|
+
full_block_size,
|
|
792
|
+
storage_size);
|
|
796
793
|
|
|
797
794
|
const size_t ex_code_size = (dim * ex_bits + 7) / 8;
|
|
798
795
|
const uint8_t* ex_code = base_ptr + sizeof(SignBitFactorsWithError);
|
|
@@ -809,8 +806,7 @@ float IndexIVFRaBitQFastScan::IVFRaBitQHeapHandler<C>::
|
|
|
809
806
|
InvertedLists::ScopedCodes list_codes(index->invlists, list_no);
|
|
810
807
|
|
|
811
808
|
std::vector<uint8_t> unpacked_code(index->code_size);
|
|
812
|
-
|
|
813
|
-
packer.unpack_1(list_codes.get(), local_offset, unpacked_code.data());
|
|
809
|
+
packer->unpack_1(list_codes.get(), local_offset, unpacked_code.data());
|
|
814
810
|
const uint8_t* sign_bits = unpacked_code.data();
|
|
815
811
|
|
|
816
812
|
return rabitq_utils::compute_full_multibit_distance(
|
|
@@ -818,11 +814,164 @@ float IndexIVFRaBitQFastScan::IVFRaBitQHeapHandler<C>::
|
|
|
818
814
|
ex_code,
|
|
819
815
|
ex_fac,
|
|
820
816
|
query_factors.rotated_q.data(),
|
|
821
|
-
|
|
822
|
-
|
|
817
|
+
(index->metric_type == MetricType::METRIC_INNER_PRODUCT)
|
|
818
|
+
? query_factors.q_dot_c
|
|
819
|
+
: query_factors.qr_to_c_L2sqr,
|
|
823
820
|
dim,
|
|
824
821
|
ex_bits,
|
|
825
822
|
index->metric_type);
|
|
826
823
|
}
|
|
827
824
|
|
|
825
|
+
/*********************************************************
|
|
826
|
+
* IVFRaBitQFastScanScanner implementation
|
|
827
|
+
*********************************************************/
|
|
828
|
+
|
|
829
|
+
namespace {
|
|
830
|
+
|
|
831
|
+
/// Provides IVF scanner interface using FastScan's SIMD batch processing.
|
|
832
|
+
struct IVFRaBitQFastScanScanner : InvertedListScanner {
|
|
833
|
+
static constexpr int impl = 10;
|
|
834
|
+
static constexpr size_t nq = 1;
|
|
835
|
+
|
|
836
|
+
const IndexIVFRaBitQFastScan& index;
|
|
837
|
+
|
|
838
|
+
AlignedTable<uint8_t> dis_tables;
|
|
839
|
+
AlignedTable<uint16_t> biases;
|
|
840
|
+
/// [scale, offset] for converting uint16 to float
|
|
841
|
+
std::array<float, 2> normalizers{};
|
|
842
|
+
|
|
843
|
+
const float* xi = nullptr;
|
|
844
|
+
|
|
845
|
+
QueryFactorsData query_factors;
|
|
846
|
+
FastScanDistancePostProcessing context;
|
|
847
|
+
|
|
848
|
+
std::unique_ptr<FlatCodesDistanceComputer> dc;
|
|
849
|
+
std::vector<float> centroid;
|
|
850
|
+
|
|
851
|
+
IVFRaBitQFastScanScanner(
|
|
852
|
+
const IndexIVFRaBitQFastScan& index,
|
|
853
|
+
bool store_pairs,
|
|
854
|
+
const IDSelector* sel)
|
|
855
|
+
: InvertedListScanner(store_pairs, sel), index(index) {
|
|
856
|
+
this->keep_max = is_similarity_metric(index.metric_type);
|
|
857
|
+
}
|
|
858
|
+
|
|
859
|
+
void set_query(const float* query) override {
|
|
860
|
+
this->xi = query;
|
|
861
|
+
}
|
|
862
|
+
|
|
863
|
+
void set_list(idx_t list_no, float coarse_dis) override {
|
|
864
|
+
this->list_no = list_no;
|
|
865
|
+
|
|
866
|
+
IndexIVFFastScan::CoarseQuantized cq{
|
|
867
|
+
.nprobe = 1,
|
|
868
|
+
.dis = &coarse_dis,
|
|
869
|
+
.ids = &list_no,
|
|
870
|
+
};
|
|
871
|
+
|
|
872
|
+
// Set up context for use in scan_codes
|
|
873
|
+
context = FastScanDistancePostProcessing{};
|
|
874
|
+
context.query_factors = &query_factors;
|
|
875
|
+
context.nprobe = 1;
|
|
876
|
+
|
|
877
|
+
index.compute_LUT_uint8(
|
|
878
|
+
1, xi, cq, dis_tables, biases, &normalizers[0], context);
|
|
879
|
+
|
|
880
|
+
// Set up distance computer for distance_to_code
|
|
881
|
+
centroid.resize(index.d);
|
|
882
|
+
index.quantizer->reconstruct(list_no, centroid.data());
|
|
883
|
+
dc.reset(index.rabitq.get_distance_computer(
|
|
884
|
+
index.qb, centroid.data(), index.centered));
|
|
885
|
+
dc->set_query(xi);
|
|
886
|
+
}
|
|
887
|
+
|
|
888
|
+
float distance_to_code(const uint8_t* code) const override {
|
|
889
|
+
FAISS_THROW_IF_NOT_MSG(
|
|
890
|
+
dc,
|
|
891
|
+
"set_query and set_list must be called before distance_to_code");
|
|
892
|
+
return dc->distance_to_code(code);
|
|
893
|
+
}
|
|
894
|
+
|
|
895
|
+
public:
|
|
896
|
+
size_t scan_codes(
|
|
897
|
+
size_t ntotal,
|
|
898
|
+
const uint8_t* codes,
|
|
899
|
+
const idx_t* ids,
|
|
900
|
+
float* distances,
|
|
901
|
+
idx_t* labels,
|
|
902
|
+
size_t k) const override {
|
|
903
|
+
// initialize the current iteration heap to the worst possible value of
|
|
904
|
+
// the prior loop
|
|
905
|
+
std::vector<float> curr_dists(k, distances[0]);
|
|
906
|
+
std::vector<idx_t> curr_labels(k, labels[0]);
|
|
907
|
+
|
|
908
|
+
std::unique_ptr<SIMDResultHandlerToFloat> handler(
|
|
909
|
+
index.make_knn_handler(
|
|
910
|
+
!keep_max,
|
|
911
|
+
impl,
|
|
912
|
+
nq,
|
|
913
|
+
k,
|
|
914
|
+
curr_dists.data(),
|
|
915
|
+
curr_labels.data(),
|
|
916
|
+
sel,
|
|
917
|
+
context,
|
|
918
|
+
&normalizers[0]));
|
|
919
|
+
|
|
920
|
+
int qmap1[1] = {0};
|
|
921
|
+
handler->q_map = qmap1;
|
|
922
|
+
handler->begin(&normalizers[0]);
|
|
923
|
+
|
|
924
|
+
const uint8_t* LUT = dis_tables.get();
|
|
925
|
+
handler->dbias = biases.get();
|
|
926
|
+
handler->ntotal = ntotal;
|
|
927
|
+
handler->id_map = ids;
|
|
928
|
+
|
|
929
|
+
// RaBitQ needs list context for factor lookup
|
|
930
|
+
std::vector<int> probe_map = {0};
|
|
931
|
+
handler->set_list_context(list_no, probe_map);
|
|
932
|
+
|
|
933
|
+
pq4_accumulate_loop(
|
|
934
|
+
1,
|
|
935
|
+
roundup(ntotal, index.bbs),
|
|
936
|
+
index.bbs,
|
|
937
|
+
static_cast<int>(index.M2),
|
|
938
|
+
codes,
|
|
939
|
+
LUT,
|
|
940
|
+
*handler,
|
|
941
|
+
nullptr,
|
|
942
|
+
index.get_block_stride());
|
|
943
|
+
|
|
944
|
+
// Combine results across iterations
|
|
945
|
+
handler->end();
|
|
946
|
+
if (keep_max) {
|
|
947
|
+
minheap_addn(
|
|
948
|
+
k,
|
|
949
|
+
distances,
|
|
950
|
+
labels,
|
|
951
|
+
curr_dists.data(),
|
|
952
|
+
curr_labels.data(),
|
|
953
|
+
k);
|
|
954
|
+
} else {
|
|
955
|
+
maxheap_addn(
|
|
956
|
+
k,
|
|
957
|
+
distances,
|
|
958
|
+
labels,
|
|
959
|
+
curr_dists.data(),
|
|
960
|
+
curr_labels.data(),
|
|
961
|
+
k);
|
|
962
|
+
}
|
|
963
|
+
|
|
964
|
+
return handler->num_updates();
|
|
965
|
+
}
|
|
966
|
+
};
|
|
967
|
+
|
|
968
|
+
} // anonymous namespace
|
|
969
|
+
|
|
970
|
+
InvertedListScanner* IndexIVFRaBitQFastScan::get_InvertedListScanner(
|
|
971
|
+
bool store_pairs,
|
|
972
|
+
const IDSelector* sel,
|
|
973
|
+
const IVFSearchParameters*) const {
|
|
974
|
+
return new IVFRaBitQFastScanScanner(*this, store_pairs, sel);
|
|
975
|
+
}
|
|
976
|
+
|
|
828
977
|
} // namespace faiss
|
|
@@ -7,6 +7,7 @@
|
|
|
7
7
|
|
|
8
8
|
#pragma once
|
|
9
9
|
|
|
10
|
+
#include <memory>
|
|
10
11
|
#include <vector>
|
|
11
12
|
|
|
12
13
|
#include <faiss/IndexIVFFastScan.h>
|
|
@@ -55,17 +56,6 @@ struct IndexIVFRaBitQFastScan : IndexIVFFastScan {
|
|
|
55
56
|
/// Use zero-centered scalar quantizer for queries
|
|
56
57
|
bool centered = false;
|
|
57
58
|
|
|
58
|
-
/// Per-vector auxiliary data (1-bit codes stored separately in `codes`)
|
|
59
|
-
///
|
|
60
|
-
/// 1-bit codes (sign bits) are stored in the inherited `codes` array from
|
|
61
|
-
/// IndexFastScan in packed FastScan format for SIMD processing.
|
|
62
|
-
///
|
|
63
|
-
/// This flat_storage holds per-vector factors and refinement-bit codes:
|
|
64
|
-
/// Layout for 1-bit: [SignBitFactors (8 bytes)]
|
|
65
|
-
/// Layout for multi-bit: [SignBitFactorsWithError
|
|
66
|
-
/// (12B)][ref_codes][ExtraBitsFactors (8B)]
|
|
67
|
-
std::vector<uint8_t> flat_storage;
|
|
68
|
-
|
|
69
59
|
// Constructors
|
|
70
60
|
|
|
71
61
|
IndexIVFRaBitQFastScan();
|
|
@@ -94,16 +84,20 @@ struct IndexIVFRaBitQFastScan : IndexIVFFastScan {
|
|
|
94
84
|
bool include_listnos = false) const override;
|
|
95
85
|
|
|
96
86
|
protected:
|
|
97
|
-
/// Extract and store RaBitQ factors from encoded vectors
|
|
98
|
-
void preprocess_code_metadata(
|
|
99
|
-
idx_t n,
|
|
100
|
-
const uint8_t* flat_codes,
|
|
101
|
-
idx_t start_global_idx) override;
|
|
102
|
-
|
|
103
87
|
/// Return code_size as stride to skip embedded factor data during packing
|
|
104
88
|
size_t code_packing_stride() const override;
|
|
105
89
|
|
|
106
90
|
public:
|
|
91
|
+
/// Return CodePackerRaBitQ with enlarged block size
|
|
92
|
+
CodePacker* get_CodePacker() const override;
|
|
93
|
+
|
|
94
|
+
/// Write per-vector auxiliary data into block auxiliary region
|
|
95
|
+
void postprocess_packed_codes(
|
|
96
|
+
idx_t list_no,
|
|
97
|
+
size_t list_offset,
|
|
98
|
+
size_t n_added,
|
|
99
|
+
const uint8_t* flat_codes) override;
|
|
100
|
+
|
|
107
101
|
/// Reconstruct a single vector from an inverted list
|
|
108
102
|
void reconstruct_from_offset(int64_t list_no, int64_t offset, float* recons)
|
|
109
103
|
const override;
|
|
@@ -111,7 +105,7 @@ struct IndexIVFRaBitQFastScan : IndexIVFFastScan {
|
|
|
111
105
|
/// Override sa_decode to handle RaBitQ reconstruction
|
|
112
106
|
void sa_decode(idx_t n, const uint8_t* bytes, float* x) const override;
|
|
113
107
|
|
|
114
|
-
/// Compute
|
|
108
|
+
/// Compute per-vector auxiliary storage size based on nb_bits
|
|
115
109
|
size_t compute_per_vector_storage_size() const;
|
|
116
110
|
|
|
117
111
|
private:
|
|
@@ -166,6 +160,13 @@ struct IndexIVFRaBitQFastScan : IndexIVFFastScan {
|
|
|
166
160
|
const FastScanDistancePostProcessing& context,
|
|
167
161
|
const float* normalizers = nullptr) const override;
|
|
168
162
|
|
|
163
|
+
/// Get an InvertedListScanner for single-query scanning.
|
|
164
|
+
/// This provides compatibility with the standard IVF search interface
|
|
165
|
+
InvertedListScanner* get_InvertedListScanner(
|
|
166
|
+
bool store_pairs = false,
|
|
167
|
+
const IDSelector* sel = nullptr,
|
|
168
|
+
const IVFSearchParameters* params = nullptr) const override;
|
|
169
|
+
|
|
169
170
|
/** SIMD result handler for IndexIVFRaBitQFastScan that applies
|
|
170
171
|
* RaBitQ-specific distance corrections during batch processing.
|
|
171
172
|
*
|
|
@@ -192,11 +193,19 @@ struct IndexIVFRaBitQFastScan : IndexIVFFastScan {
|
|
|
192
193
|
int64_t* heap_labels; // [nq * k]
|
|
193
194
|
const size_t nq, k;
|
|
194
195
|
size_t current_list_no = 0;
|
|
196
|
+
const uint8_t* list_codes_ptr = nullptr; // raw block data for list
|
|
195
197
|
std::vector<int>
|
|
196
198
|
probe_indices; // probe index for each query in current batch
|
|
197
199
|
const FastScanDistancePostProcessing*
|
|
198
200
|
context; // Processing context with query factors
|
|
199
201
|
const bool is_multibit; // Whether to use multi-bit two-stage search
|
|
202
|
+
size_t nup = 0; // Number of heap updates
|
|
203
|
+
|
|
204
|
+
// Cached block-layout constants (invariant for handler lifetime)
|
|
205
|
+
const size_t storage_size;
|
|
206
|
+
const size_t packed_block_size;
|
|
207
|
+
const size_t full_block_size;
|
|
208
|
+
std::unique_ptr<CodePacker> packer; // cached for unpack in hot path
|
|
200
209
|
|
|
201
210
|
// Use float-based comparator for heap operations
|
|
202
211
|
using Cfloat = typename std::conditional<
|
|
@@ -224,6 +233,10 @@ struct IndexIVFRaBitQFastScan : IndexIVFFastScan {
|
|
|
224
233
|
|
|
225
234
|
void end() override;
|
|
226
235
|
|
|
236
|
+
size_t num_updates() override {
|
|
237
|
+
return nup;
|
|
238
|
+
}
|
|
239
|
+
|
|
227
240
|
private:
|
|
228
241
|
/// Compute full multi-bit distance for a candidate vector (multi-bit
|
|
229
242
|
/// only)
|
|
@@ -232,20 +245,10 @@ struct IndexIVFRaBitQFastScan : IndexIVFFastScan {
|
|
|
232
245
|
/// @param global_q Global query index (for storage indexing)
|
|
233
246
|
/// @param local_offset Offset within the current inverted list
|
|
234
247
|
float compute_full_multibit_distance(
|
|
235
|
-
size_t db_idx
|
|
248
|
+
size_t /*db_idx*/,
|
|
236
249
|
size_t local_q,
|
|
237
250
|
size_t global_q,
|
|
238
251
|
size_t local_offset) const;
|
|
239
|
-
|
|
240
|
-
/// Compute lower bound using 1-bit distance and error bound (multi-bit
|
|
241
|
-
/// only)
|
|
242
|
-
/// @param local_q Batch-local query index (for probe_indices access)
|
|
243
|
-
/// @param global_q Global query index (for storage indexing)
|
|
244
|
-
float compute_lower_bound(
|
|
245
|
-
float dist_1bit,
|
|
246
|
-
size_t db_idx,
|
|
247
|
-
size_t local_q,
|
|
248
|
-
size_t global_q) const;
|
|
249
252
|
};
|
|
250
253
|
};
|
|
251
254
|
|
|
@@ -86,12 +86,14 @@ void IndexLSH::train(idx_t n, const float* x) {
|
|
|
86
86
|
|
|
87
87
|
for (idx_t i = 0; i < nbits; i++) {
|
|
88
88
|
float* xi = transposed_x.get() + i * n;
|
|
89
|
-
//
|
|
90
|
-
std::
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
89
|
+
// Use nth_element (O(n)) instead of sort (O(n log n))
|
|
90
|
+
std::nth_element(xi, xi + n / 2, xi + n);
|
|
91
|
+
float median = xi[n / 2];
|
|
92
|
+
if (n % 2 == 0) {
|
|
93
|
+
std::nth_element(xi, xi + n / 2 - 1, xi + n);
|
|
94
|
+
median = (median + xi[n / 2 - 1]) / 2;
|
|
95
|
+
}
|
|
96
|
+
thresholds[i] = median;
|
|
95
97
|
}
|
|
96
98
|
}
|
|
97
99
|
is_trained = true;
|