faiss 0.4.3 → 0.5.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +10 -0
- data/README.md +2 -0
- data/ext/faiss/index.cpp +33 -6
- data/ext/faiss/index_binary.cpp +17 -4
- data/ext/faiss/kmeans.cpp +6 -6
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/faiss/AutoTune.cpp +2 -3
- data/vendor/faiss/faiss/AutoTune.h +1 -1
- data/vendor/faiss/faiss/Clustering.cpp +2 -2
- data/vendor/faiss/faiss/Clustering.h +2 -2
- data/vendor/faiss/faiss/IVFlib.cpp +26 -51
- data/vendor/faiss/faiss/IVFlib.h +1 -1
- data/vendor/faiss/faiss/Index.cpp +11 -0
- data/vendor/faiss/faiss/Index.h +34 -11
- data/vendor/faiss/faiss/Index2Layer.cpp +1 -1
- data/vendor/faiss/faiss/Index2Layer.h +2 -2
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +1 -0
- data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +9 -4
- data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +5 -1
- data/vendor/faiss/faiss/IndexBinary.h +7 -7
- data/vendor/faiss/faiss/IndexBinaryFromFloat.h +1 -1
- data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +8 -2
- data/vendor/faiss/faiss/IndexBinaryHNSW.h +1 -1
- data/vendor/faiss/faiss/IndexBinaryHash.cpp +3 -3
- data/vendor/faiss/faiss/IndexBinaryHash.h +5 -5
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +7 -6
- data/vendor/faiss/faiss/IndexFastScan.cpp +125 -49
- data/vendor/faiss/faiss/IndexFastScan.h +102 -7
- data/vendor/faiss/faiss/IndexFlat.cpp +374 -4
- data/vendor/faiss/faiss/IndexFlat.h +81 -1
- data/vendor/faiss/faiss/IndexHNSW.cpp +93 -2
- data/vendor/faiss/faiss/IndexHNSW.h +58 -2
- data/vendor/faiss/faiss/IndexIDMap.cpp +14 -13
- data/vendor/faiss/faiss/IndexIDMap.h +6 -6
- data/vendor/faiss/faiss/IndexIVF.cpp +1 -1
- data/vendor/faiss/faiss/IndexIVF.h +5 -5
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +1 -1
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +9 -3
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +3 -1
- data/vendor/faiss/faiss/IndexIVFFastScan.cpp +176 -90
- data/vendor/faiss/faiss/IndexIVFFastScan.h +173 -18
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +1 -0
- data/vendor/faiss/faiss/IndexIVFFlatPanorama.cpp +251 -0
- data/vendor/faiss/faiss/IndexIVFFlatPanorama.h +64 -0
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +3 -1
- data/vendor/faiss/faiss/IndexIVFPQ.h +1 -1
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +134 -2
- data/vendor/faiss/faiss/IndexIVFPQFastScan.h +7 -1
- data/vendor/faiss/faiss/IndexIVFRaBitQ.cpp +99 -8
- data/vendor/faiss/faiss/IndexIVFRaBitQ.h +4 -1
- data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.cpp +828 -0
- data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.h +252 -0
- data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +1 -1
- data/vendor/faiss/faiss/IndexIVFSpectralHash.h +1 -1
- data/vendor/faiss/faiss/IndexNNDescent.cpp +1 -1
- data/vendor/faiss/faiss/IndexNSG.cpp +1 -1
- data/vendor/faiss/faiss/IndexNeuralNetCodec.h +1 -1
- data/vendor/faiss/faiss/IndexPQ.cpp +4 -1
- data/vendor/faiss/faiss/IndexPQ.h +1 -1
- data/vendor/faiss/faiss/IndexPQFastScan.cpp +6 -2
- data/vendor/faiss/faiss/IndexPQFastScan.h +5 -1
- data/vendor/faiss/faiss/IndexPreTransform.cpp +14 -0
- data/vendor/faiss/faiss/IndexPreTransform.h +9 -0
- data/vendor/faiss/faiss/IndexRaBitQ.cpp +96 -13
- data/vendor/faiss/faiss/IndexRaBitQ.h +11 -2
- data/vendor/faiss/faiss/IndexRaBitQFastScan.cpp +731 -0
- data/vendor/faiss/faiss/IndexRaBitQFastScan.h +175 -0
- data/vendor/faiss/faiss/IndexRefine.cpp +49 -0
- data/vendor/faiss/faiss/IndexRefine.h +17 -0
- data/vendor/faiss/faiss/IndexShards.cpp +1 -1
- data/vendor/faiss/faiss/MatrixStats.cpp +3 -3
- data/vendor/faiss/faiss/MetricType.h +1 -1
- data/vendor/faiss/faiss/VectorTransform.h +2 -2
- data/vendor/faiss/faiss/clone_index.cpp +5 -1
- data/vendor/faiss/faiss/gpu/GpuCloner.cpp +1 -1
- data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +3 -1
- data/vendor/faiss/faiss/gpu/GpuIndex.h +11 -11
- data/vendor/faiss/faiss/gpu/GpuIndexBinaryCagra.h +1 -1
- data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +1 -1
- data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +11 -7
- data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +1 -1
- data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +2 -0
- data/vendor/faiss/faiss/gpu/test/TestGpuIcmEncoder.cpp +7 -0
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +1 -1
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +1 -1
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +1 -1
- data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +2 -2
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +1 -1
- data/vendor/faiss/faiss/impl/CodePacker.h +2 -2
- data/vendor/faiss/faiss/impl/DistanceComputer.h +77 -6
- data/vendor/faiss/faiss/impl/FastScanDistancePostProcessing.h +53 -0
- data/vendor/faiss/faiss/impl/HNSW.cpp +295 -16
- data/vendor/faiss/faiss/impl/HNSW.h +35 -6
- data/vendor/faiss/faiss/impl/IDSelector.cpp +2 -2
- data/vendor/faiss/faiss/impl/IDSelector.h +4 -4
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +4 -4
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +1 -1
- data/vendor/faiss/faiss/impl/LookupTableScaler.h +1 -1
- data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -1
- data/vendor/faiss/faiss/impl/NNDescent.h +2 -2
- data/vendor/faiss/faiss/impl/NSG.cpp +1 -1
- data/vendor/faiss/faiss/impl/Panorama.cpp +193 -0
- data/vendor/faiss/faiss/impl/Panorama.h +204 -0
- data/vendor/faiss/faiss/impl/PanoramaStats.cpp +33 -0
- data/vendor/faiss/faiss/impl/PanoramaStats.h +38 -0
- data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +5 -5
- data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +1 -1
- data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +1 -1
- data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +2 -0
- data/vendor/faiss/faiss/impl/ProductQuantizer.h +1 -1
- data/vendor/faiss/faiss/impl/RaBitQStats.cpp +29 -0
- data/vendor/faiss/faiss/impl/RaBitQStats.h +56 -0
- data/vendor/faiss/faiss/impl/RaBitQUtils.cpp +294 -0
- data/vendor/faiss/faiss/impl/RaBitQUtils.h +330 -0
- data/vendor/faiss/faiss/impl/RaBitQuantizer.cpp +304 -223
- data/vendor/faiss/faiss/impl/RaBitQuantizer.h +72 -4
- data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.cpp +362 -0
- data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.h +112 -0
- data/vendor/faiss/faiss/impl/ResidualQuantizer.h +1 -1
- data/vendor/faiss/faiss/impl/ResultHandler.h +4 -4
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +7 -10
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +2 -4
- data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +7 -4
- data/vendor/faiss/faiss/impl/index_read.cpp +238 -10
- data/vendor/faiss/faiss/impl/index_write.cpp +212 -19
- data/vendor/faiss/faiss/impl/io.cpp +2 -2
- data/vendor/faiss/faiss/impl/io.h +4 -4
- data/vendor/faiss/faiss/impl/kmeans1d.cpp +1 -1
- data/vendor/faiss/faiss/impl/kmeans1d.h +1 -1
- data/vendor/faiss/faiss/impl/lattice_Zn.h +2 -2
- data/vendor/faiss/faiss/impl/mapped_io.cpp +2 -2
- data/vendor/faiss/faiss/impl/mapped_io.h +4 -3
- data/vendor/faiss/faiss/impl/maybe_owned_vector.h +8 -1
- data/vendor/faiss/faiss/impl/platform_macros.h +12 -0
- data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +30 -4
- data/vendor/faiss/faiss/impl/pq4_fast_scan.h +14 -8
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +5 -6
- data/vendor/faiss/faiss/impl/simd_result_handlers.h +55 -11
- data/vendor/faiss/faiss/impl/svs_io.cpp +86 -0
- data/vendor/faiss/faiss/impl/svs_io.h +67 -0
- data/vendor/faiss/faiss/impl/zerocopy_io.h +1 -1
- data/vendor/faiss/faiss/index_factory.cpp +217 -8
- data/vendor/faiss/faiss/index_factory.h +1 -1
- data/vendor/faiss/faiss/index_io.h +1 -1
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +1 -1
- data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -1
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +115 -1
- data/vendor/faiss/faiss/invlists/InvertedLists.h +46 -0
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +1 -1
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +1 -1
- data/vendor/faiss/faiss/svs/IndexSVSFaissUtils.h +261 -0
- data/vendor/faiss/faiss/svs/IndexSVSFlat.cpp +117 -0
- data/vendor/faiss/faiss/svs/IndexSVSFlat.h +66 -0
- data/vendor/faiss/faiss/svs/IndexSVSVamana.cpp +245 -0
- data/vendor/faiss/faiss/svs/IndexSVSVamana.h +137 -0
- data/vendor/faiss/faiss/svs/IndexSVSVamanaLVQ.cpp +39 -0
- data/vendor/faiss/faiss/svs/IndexSVSVamanaLVQ.h +42 -0
- data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.cpp +149 -0
- data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.h +58 -0
- data/vendor/faiss/faiss/utils/AlignedTable.h +1 -1
- data/vendor/faiss/faiss/utils/Heap.cpp +2 -2
- data/vendor/faiss/faiss/utils/Heap.h +3 -3
- data/vendor/faiss/faiss/utils/NeuralNet.cpp +1 -1
- data/vendor/faiss/faiss/utils/NeuralNet.h +3 -3
- data/vendor/faiss/faiss/utils/approx_topk/approx_topk.h +2 -2
- data/vendor/faiss/faiss/utils/approx_topk/avx2-inl.h +2 -2
- data/vendor/faiss/faiss/utils/approx_topk/mode.h +1 -1
- data/vendor/faiss/faiss/utils/distances.cpp +0 -3
- data/vendor/faiss/faiss/utils/distances.h +2 -2
- data/vendor/faiss/faiss/utils/extra_distances-inl.h +3 -1
- data/vendor/faiss/faiss/utils/hamming-inl.h +2 -0
- data/vendor/faiss/faiss/utils/hamming.cpp +7 -6
- data/vendor/faiss/faiss/utils/hamming.h +1 -1
- data/vendor/faiss/faiss/utils/hamming_distance/common.h +1 -2
- data/vendor/faiss/faiss/utils/partitioning.cpp +5 -5
- data/vendor/faiss/faiss/utils/partitioning.h +2 -2
- data/vendor/faiss/faiss/utils/rabitq_simd.h +222 -336
- data/vendor/faiss/faiss/utils/random.cpp +1 -1
- data/vendor/faiss/faiss/utils/simdlib_avx2.h +1 -1
- data/vendor/faiss/faiss/utils/simdlib_avx512.h +1 -1
- data/vendor/faiss/faiss/utils/simdlib_neon.h +2 -2
- data/vendor/faiss/faiss/utils/transpose/transpose-avx512-inl.h +1 -1
- data/vendor/faiss/faiss/utils/utils.cpp +9 -2
- data/vendor/faiss/faiss/utils/utils.h +2 -2
- metadata +29 -1
|
@@ -0,0 +1,731 @@
|
|
|
1
|
+
/*
|
|
2
|
+
* Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
3
|
+
*
|
|
4
|
+
* This source code is licensed under the MIT license found in the
|
|
5
|
+
* LICENSE file in the root directory of this source tree.
|
|
6
|
+
*/
|
|
7
|
+
|
|
8
|
+
#include <faiss/IndexRaBitQFastScan.h>
#include <faiss/impl/FastScanDistancePostProcessing.h>
#include <faiss/impl/RaBitQUtils.h>
#include <faiss/impl/RaBitQuantizerMultiBit.h>
#include <faiss/impl/pq4_fast_scan.h>
#include <faiss/utils/utils.h>
#include <algorithm>
#include <cmath>
#include <cstring>
|
|
16
|
+
|
|
17
|
+
namespace faiss {
|
|
18
|
+
|
|
19
|
+
// Smallest multiple of `b` that is >= `a` (b must be non-zero).
static inline size_t roundup(size_t a, size_t b) {
    const size_t n_chunks = (a + b - 1) / b;
    return n_chunks * b;
}
|
|
22
|
+
|
|
23
|
+
size_t IndexRaBitQFastScan::compute_per_vector_storage_size() const {
|
|
24
|
+
const size_t ex_bits = rabitq.nb_bits - 1;
|
|
25
|
+
|
|
26
|
+
if (ex_bits == 0) {
|
|
27
|
+
// 1-bit: only SignBitFactors
|
|
28
|
+
return sizeof(rabitq_utils::SignBitFactors);
|
|
29
|
+
} else {
|
|
30
|
+
// Multi-bit: SignBitFactorsWithError + ExtraBitsFactors +
|
|
31
|
+
// mag-codes
|
|
32
|
+
return sizeof(SignBitFactorsWithError) + sizeof(ExtraBitsFactors) +
|
|
33
|
+
(d * ex_bits + 7) / 8;
|
|
34
|
+
}
|
|
35
|
+
}
|
|
36
|
+
|
|
37
|
+
// Default constructor. Every member keeps the default initializer declared in
// the class definition (not visible in this file). NOTE(review): presumably
// this yields an empty, untrained index to be filled in later (e.g. by
// deserialization) — confirm against the header.
IndexRaBitQFastScan::IndexRaBitQFastScan() = default;
|
|
38
|
+
|
|
39
|
+
// Builds an empty RaBitQ FastScan index.
//
// d       - vector dimension (must be > 0)
// metric  - METRIC_L2 or METRIC_INNER_PRODUCT
// bbs     - FastScan block size, checked by init_fastscan
// nb_bits - bits per dimension in [1, 9]; 1 = sign bit only, more bits add
//           per-vector "extra bit" codes kept in flat_storage
IndexRaBitQFastScan::IndexRaBitQFastScan(
        idx_t d,
        MetricType metric,
        int bbs,
        uint8_t nb_bits)
        : rabitq(d, metric, nb_bits) {
    FAISS_THROW_IF_NOT_MSG(d > 0, "Dimension must be positive");
    FAISS_THROW_IF_NOT_MSG(
            metric == METRIC_L2 || metric == METRIC_INNER_PRODUCT,
            "RaBitQ FastScan only supports L2 and Inner Product metrics");
    FAISS_THROW_IF_NOT_MSG(
            nb_bits >= 1 && nb_bits <= 9, "nb_bits must be between 1 and 9");

    // Only the sign bit of each dimension goes through the FastScan kernels:
    // four consecutive dimensions form one 4-bit sub-quantizer, so there are
    // ceil(d / 4) of them.
    const size_t n_subquantizers = (d + 3) / 4;
    constexpr size_t bits_per_subquantizer = 4;
    init_fastscan(
            static_cast<int>(d),
            n_subquantizers,
            bits_per_subquantizer,
            metric,
            bbs);

    // Assigned after init_fastscan: the per-code size follows the RaBitQ
    // layout as computed by the quantizer.
    code_size = rabitq.compute_code_size(d, nb_bits);

    // Bits used to quantize queries (compute_float_LUT uses (1 << qb) - 1 as
    // the maximum query code value).
    qb = 8;
    // Zero centroid until train() computes the real one.
    center.resize(d, 0.0f);
    flat_storage.clear();
}
|
|
71
|
+
|
|
72
|
+
// Converts an existing IndexRaBitQ into the FastScan layout.
//
// Each source code is expected to hold the sign bits ((d + 7) / 8 bytes, bit j
// of the vector at byte j / 8, bit j % 8) followed by the per-vector factor
// block of compute_per_vector_storage_size() bytes.
//  - The factor block is copied verbatim into flat_storage.
//  - The sign bits are re-laid out into FastScan 4-bit sub-quantizer form
//    and packed into `codes`.
//
// Note: orig_codes keeps a raw pointer into orig.codes, so `orig` must
// outlive any use of orig_codes.
IndexRaBitQFastScan::IndexRaBitQFastScan(const IndexRaBitQ& orig, int bbs)
        : rabitq(orig.rabitq) {
    // RaBitQ-specific validation
    FAISS_THROW_IF_NOT_MSG(orig.d > 0, "Dimension must be positive");
    FAISS_THROW_IF_NOT_MSG(
            orig.metric_type == METRIC_L2 ||
                    orig.metric_type == METRIC_INNER_PRODUCT,
            "RaBitQ FastScan only supports L2 and Inner Product metrics");

    // Sign bits are packed four dimensions per 4-bit FastScan sub-quantizer.
    const size_t M_fastscan = (orig.d + 3) / 4;
    constexpr size_t nbits_fastscan = 4;

    // init_fastscan validates bbs and nbits_fastscan.
    init_fastscan(
            static_cast<int>(orig.d),
            M_fastscan,
            nbits_fastscan,
            orig.metric_type,
            bbs);

    code_size = rabitq.compute_code_size(d, rabitq.nb_bits);

    // Copy properties from the original index.
    ntotal = orig.ntotal;
    ntotal2 = roundup(ntotal, bbs);
    is_trained = orig.is_trained;
    orig_codes = orig.codes.data();
    qb = orig.qb;
    centered = orig.centered;
    center = orig.center;

    if (ntotal == 0) {
        return;
    }

    const size_t storage_size = compute_per_vector_storage_size();
    const size_t bit_pattern_size = (d + 7) / 8;

    // Every source code must contain the sign bits followed by a full factor
    // block; otherwise the memcpy below would read past the end of the code.
    FAISS_THROW_IF_NOT_MSG(
            orig.code_size >= bit_pattern_size + storage_size,
            "IndexRaBitQ code_size too small for the RaBitQ FastScan layout");

    flat_storage.resize(ntotal * storage_size);

    // Unpacked FastScan codes (sign bits only), one code_size-strided row per
    // vector, zeroed so only the set bits need to be written.
    AlignedTable<uint8_t> fastscan_codes(ntotal * code_size);
    memset(fastscan_codes.get(), 0, ntotal * code_size);

    const size_t dim = static_cast<size_t>(orig.d);
    for (idx_t i = 0; i < ntotal; i++) {
        const uint8_t* orig_code = orig.codes.data() + i * orig.code_size;

        // Factors live right after the sign bits in the source code.
        memcpy(flat_storage.data() + i * storage_size,
               orig_code + bit_pattern_size,
               storage_size);

        // Re-lay out the sign bits (same bit convention as compute_codes).
        uint8_t* fs_code = fastscan_codes.get() + i * code_size;
        for (size_t j = 0; j < dim; j++) {
            const bool bit_value = (orig_code[j / 8] >> (j % 8)) & 1;
            if (bit_value) {
                rabitq_utils::set_bit_fastscan(fs_code, j);
            }
        }
    }

    // Pack into the blocked FastScan layout; code_size is the source stride.
    codes.resize(ntotal2 * M2 / 2);
    pq4_pack_codes(
            fastscan_codes.get(),
            ntotal,
            M,
            ntotal2,
            bbs,
            M2,
            codes.get(),
            code_size);
}
|
|
161
|
+
|
|
162
|
+
// Sets `center` to the arithmetic mean of the n training vectors (all zeros
// when n == 0), trains the underlying RaBitQuantizer on the same data, and
// marks the index as trained.
void IndexRaBitQFastScan::train(idx_t n, const float* x) {
    std::vector<float> mean(d, 0);

    for (int64_t row = 0; row < static_cast<int64_t>(n); row++) {
        const float* vec = x + row * d;
        for (size_t col = 0; col < d; col++) {
            mean[col] += vec[col];
        }
    }

    if (n != 0) {
        const float count = static_cast<float>(n);
        for (float& component : mean) {
            component /= count;
        }
    }

    center = std::move(mean);

    rabitq.train(n, x);
    is_trained = true;
}
|
|
182
|
+
|
|
183
|
+
// Appends n vectors to the index. Requires a trained index.
//
// Each vector is encoded with compute_codes.
//  - The bytes after the sign bits (factors, plus extra-bit data when
//    nb_bits > 1) go into flat_storage at the vector's global position.
//  - The sign-bit sub-quantizers are packed into the blocked `codes` array.
void IndexRaBitQFastScan::add(idx_t n, const float* x) {
    FAISS_THROW_IF_NOT(is_trained);

    // Large inputs are fed back through add() in chunks so the temporary
    // code buffer stays bounded.
    constexpr idx_t bs = 65536;
    if (n > bs) {
        for (idx_t begin = 0; begin < n; begin += bs) {
            const idx_t stop = std::min(n, begin + bs);
            if (verbose) {
                printf("IndexRaBitQFastScan::add %zd/%zd\n",
                       size_t(stop),
                       size_t(n));
            }
            add(stop - begin, x + begin * d);
        }
        return;
    }
    InterruptCallback::check();

    // Encoded codes: FastScan sign bits followed by the per-vector payload.
    AlignedTable<uint8_t> encoded(n * code_size);
    compute_codes(encoded.get(), n, x);

    const size_t per_vector = compute_per_vector_storage_size();
    const size_t sign_bytes = (d + 7) / 8;
    flat_storage.resize((ntotal + n) * per_vector);

    // Payload of new vector i lands in slot ntotal + i.
    for (idx_t i = 0; i < n; i++) {
        memcpy(flat_storage.data() + (ntotal + i) * per_vector,
               encoded.get() + i * code_size + sign_bytes,
               per_vector);
    }

    // Grow the packed code array to the new padded size; new bytes start at 0.
    ntotal2 = roundup(ntotal + n, bbs);
    const size_t packed_bytes = ntotal2 * M2 / 2; // assume nbits = 4
    const size_t prev_bytes = codes.size();
    if (packed_bytes > prev_bytes) {
        codes.resize(packed_bytes);
        memset(codes.get() + prev_bytes, 0, packed_bytes - prev_bytes);
    }

    // Pack only the M sign-bit sub-quantizers of the new rows. The source
    // stride is code_size because each encoded row also carries the payload.
    pq4_pack_codes_range(
            encoded.get(),
            M,
            ntotal,
            ntotal + n,
            bbs,
            M2,
            codes.get(),
            code_size);

    ntotal += n;
}
|
|
243
|
+
|
|
244
|
+
void IndexRaBitQFastScan::compute_codes(uint8_t* codes, idx_t n, const float* x)
|
|
245
|
+
const {
|
|
246
|
+
FAISS_ASSERT(codes != nullptr);
|
|
247
|
+
FAISS_ASSERT(x != nullptr);
|
|
248
|
+
FAISS_ASSERT(
|
|
249
|
+
(metric_type == MetricType::METRIC_L2 ||
|
|
250
|
+
metric_type == MetricType::METRIC_INNER_PRODUCT));
|
|
251
|
+
if (n == 0) {
|
|
252
|
+
return;
|
|
253
|
+
}
|
|
254
|
+
|
|
255
|
+
// Hoist loop-invariant computations
|
|
256
|
+
const float* centroid_data = center.data();
|
|
257
|
+
const size_t bit_pattern_size = (d + 7) / 8;
|
|
258
|
+
const size_t ex_bits = rabitq.nb_bits - 1;
|
|
259
|
+
const size_t ex_code_size = (d * ex_bits + 7) / 8;
|
|
260
|
+
|
|
261
|
+
memset(codes, 0, n * code_size);
|
|
262
|
+
|
|
263
|
+
#pragma omp parallel for if (n > 1000)
|
|
264
|
+
for (int64_t i = 0; i < n; i++) {
|
|
265
|
+
uint8_t* const code = codes + i * code_size;
|
|
266
|
+
const float* const x_row = x + i * d;
|
|
267
|
+
|
|
268
|
+
// Compute residual once, reuse for both sign bits and ex-bits
|
|
269
|
+
std::vector<float> residual(d);
|
|
270
|
+
for (size_t j = 0; j < d; j++) {
|
|
271
|
+
const float centroid_val = centroid_data ? centroid_data[j] : 0.0f;
|
|
272
|
+
residual[j] = x_row[j] - centroid_val;
|
|
273
|
+
}
|
|
274
|
+
|
|
275
|
+
// Pack sign bits directly into FastScan format using precomputed
|
|
276
|
+
// residual
|
|
277
|
+
for (size_t j = 0; j < d; j++) {
|
|
278
|
+
if (residual[j] > 0.0f) {
|
|
279
|
+
rabitq_utils::set_bit_fastscan(code, j);
|
|
280
|
+
}
|
|
281
|
+
}
|
|
282
|
+
|
|
283
|
+
SignBitFactorsWithError factors = rabitq_utils::compute_vector_factors(
|
|
284
|
+
x_row, d, centroid_data, metric_type, ex_bits > 0);
|
|
285
|
+
|
|
286
|
+
if (ex_bits == 0) {
|
|
287
|
+
// 1-bit: store only SignBitFactors (8 bytes)
|
|
288
|
+
memcpy(code + bit_pattern_size, &factors, sizeof(SignBitFactors));
|
|
289
|
+
} else {
|
|
290
|
+
// Multi-bit: store full SignBitFactorsWithError (12 bytes)
|
|
291
|
+
memcpy(code + bit_pattern_size,
|
|
292
|
+
&factors,
|
|
293
|
+
sizeof(SignBitFactorsWithError));
|
|
294
|
+
|
|
295
|
+
// Add mag-codes and ExtraBitsFactors using precomputed
|
|
296
|
+
// residual
|
|
297
|
+
uint8_t* ex_code =
|
|
298
|
+
code + bit_pattern_size + sizeof(SignBitFactorsWithError);
|
|
299
|
+
ExtraBitsFactors ex_factors_temp;
|
|
300
|
+
|
|
301
|
+
rabitq_multibit::quantize_ex_bits(
|
|
302
|
+
residual.data(),
|
|
303
|
+
d,
|
|
304
|
+
rabitq.nb_bits,
|
|
305
|
+
ex_code,
|
|
306
|
+
ex_factors_temp,
|
|
307
|
+
metric_type,
|
|
308
|
+
centroid_data);
|
|
309
|
+
|
|
310
|
+
memcpy(ex_code + ex_code_size,
|
|
311
|
+
&ex_factors_temp,
|
|
312
|
+
sizeof(ExtraBitsFactors));
|
|
313
|
+
}
|
|
314
|
+
}
|
|
315
|
+
}
|
|
316
|
+
|
|
317
|
+
void IndexRaBitQFastScan::compute_float_LUT(
|
|
318
|
+
float* lut,
|
|
319
|
+
idx_t n,
|
|
320
|
+
const float* x,
|
|
321
|
+
const FastScanDistancePostProcessing& context) const {
|
|
322
|
+
FAISS_THROW_IF_NOT(is_trained);
|
|
323
|
+
|
|
324
|
+
// Pre-allocate working buffers to avoid repeated allocations
|
|
325
|
+
std::vector<float> rotated_q(d);
|
|
326
|
+
std::vector<uint8_t> rotated_qq(d);
|
|
327
|
+
|
|
328
|
+
// Compute lookup tables for FastScan SIMD operations
|
|
329
|
+
// For each query vector, computes distance contributions for all
|
|
330
|
+
// possible 4-bit codes per sub-quantizer. Also computes and stores
|
|
331
|
+
// query factors for distance reconstruction.
|
|
332
|
+
for (idx_t i = 0; i < n; i++) {
|
|
333
|
+
const float* query = x + i * d;
|
|
334
|
+
|
|
335
|
+
// Compute query factors and store in array if available
|
|
336
|
+
rabitq_utils::QueryFactorsData query_factors_data =
|
|
337
|
+
rabitq_utils::compute_query_factors(
|
|
338
|
+
query,
|
|
339
|
+
d,
|
|
340
|
+
center.data(),
|
|
341
|
+
qb,
|
|
342
|
+
centered,
|
|
343
|
+
metric_type,
|
|
344
|
+
rotated_q,
|
|
345
|
+
rotated_qq);
|
|
346
|
+
|
|
347
|
+
// Store query factors in context array if provided
|
|
348
|
+
if (context.query_factors != nullptr) {
|
|
349
|
+
query_factors_data.rotated_q = rotated_q;
|
|
350
|
+
context.query_factors[i] = query_factors_data;
|
|
351
|
+
}
|
|
352
|
+
|
|
353
|
+
// Create lookup table storing distance contributions for all possible
|
|
354
|
+
// 4-bit codes per sub-quantizer for FastScan SIMD operations
|
|
355
|
+
float* query_lut = lut + i * M * 16;
|
|
356
|
+
|
|
357
|
+
if (centered) {
|
|
358
|
+
// For centered mode, we use the signed odd integer quantization
|
|
359
|
+
// scheme.
|
|
360
|
+
// Formula:
|
|
361
|
+
// int_dot = ((1 << qb) - 1) * d - 2 * xor_dot_product
|
|
362
|
+
// We precompute the XOR contribution for each
|
|
363
|
+
// sub-quantizer
|
|
364
|
+
|
|
365
|
+
const float max_code_value = (1 << qb) - 1;
|
|
366
|
+
|
|
367
|
+
for (size_t m = 0; m < M; m++) {
|
|
368
|
+
const size_t dim_start = m * 4;
|
|
369
|
+
|
|
370
|
+
for (int code_val = 0; code_val < 16; code_val++) {
|
|
371
|
+
float xor_contribution = 0.0f;
|
|
372
|
+
|
|
373
|
+
// Process 4 bits per sub-quantizer
|
|
374
|
+
for (size_t dim_offset = 0; dim_offset < 4; dim_offset++) {
|
|
375
|
+
const size_t dim_idx = dim_start + dim_offset;
|
|
376
|
+
|
|
377
|
+
if (dim_idx < d) {
|
|
378
|
+
const bool db_bit = (code_val >> dim_offset) & 1;
|
|
379
|
+
const float query_value = rotated_qq[dim_idx];
|
|
380
|
+
|
|
381
|
+
// XOR contribution:
|
|
382
|
+
// If db_bit == 0: XOR result = query_value
|
|
383
|
+
// If db_bit == 1: XOR result = (2^qb - 1) -
|
|
384
|
+
// query_value
|
|
385
|
+
xor_contribution += db_bit
|
|
386
|
+
? (max_code_value - query_value)
|
|
387
|
+
: query_value;
|
|
388
|
+
}
|
|
389
|
+
}
|
|
390
|
+
|
|
391
|
+
// Store the XOR contribution (will be scaled by -2 *
|
|
392
|
+
// int_dot_scale during distance computation)
|
|
393
|
+
query_lut[m * 16 + code_val] = xor_contribution;
|
|
394
|
+
}
|
|
395
|
+
}
|
|
396
|
+
|
|
397
|
+
} else {
|
|
398
|
+
// For non-centered quantization, use traditional AND dot
|
|
399
|
+
// product Compute lookup table entries by processing popcount
|
|
400
|
+
// and inner product together
|
|
401
|
+
for (size_t m = 0; m < M; m++) {
|
|
402
|
+
const size_t dim_start = m * 4;
|
|
403
|
+
|
|
404
|
+
for (int code_val = 0; code_val < 16; code_val++) {
|
|
405
|
+
float inner_product = 0.0f;
|
|
406
|
+
int popcount = 0;
|
|
407
|
+
|
|
408
|
+
// Process 4 bits per sub-quantizer
|
|
409
|
+
for (size_t dim_offset = 0; dim_offset < 4; dim_offset++) {
|
|
410
|
+
const size_t dim_idx = dim_start + dim_offset;
|
|
411
|
+
|
|
412
|
+
if (dim_idx < d && ((code_val >> dim_offset) & 1)) {
|
|
413
|
+
inner_product += rotated_qq[dim_idx];
|
|
414
|
+
popcount++;
|
|
415
|
+
}
|
|
416
|
+
}
|
|
417
|
+
|
|
418
|
+
// Store pre-computed distance contribution
|
|
419
|
+
query_lut[m * 16 + code_val] =
|
|
420
|
+
query_factors_data.c1 * inner_product +
|
|
421
|
+
query_factors_data.c2 * popcount;
|
|
422
|
+
}
|
|
423
|
+
}
|
|
424
|
+
}
|
|
425
|
+
}
|
|
426
|
+
}
|
|
427
|
+
|
|
428
|
+
// Approximately reconstructs n vectors from their codes (code_size bytes
// each, laid out as produced by compute_codes).
//
// Each component is (bit - 0.5) * 2 * dp_multiplier / sqrt(d), plus the
// centroid component when a centroid is set. Only the sign bits and the
// SignBitFactors prefix are read; extra-bit codes (nb_bits > 1) are not used.
void IndexRaBitQFastScan::sa_decode(idx_t n, const uint8_t* bytes, float* x)
        const {
    // An empty std::vector may still return a non-null data() (e.g. after
    // clear()), so test emptiness to avoid reading past its end.
    const float* centroid_in = center.empty() ? nullptr : center.data();
    const uint8_t* codes = bytes;
    FAISS_ASSERT(codes != nullptr);
    FAISS_ASSERT(x != nullptr);

    const float inv_d_sqrt = (d == 0) ? 1.0f : (1.0f / std::sqrt((float)d));
    const size_t bit_pattern_size = (d + 7) / 8;

#pragma omp parallel for if (n > 1000)
    for (int64_t i = 0; i < n; i++) {
        const uint8_t* code = codes + i * code_size;

        // The factors start at byte offset (d + 7) / 8, which need not be
        // suitably aligned for their float fields; copy them out rather
        // than dereferencing a reinterpret_cast pointer.
        rabitq_utils::SignBitFactors fac;
        memcpy(&fac, code + bit_pattern_size, sizeof(fac));

        float* out = x + i * d;
        for (size_t j = 0; j < d; j++) {
            const bool bit_value = rabitq_utils::extract_bit_fastscan(code, j);
            const float bit = bit_value ? 1.0f : 0.0f;

            out[j] = (bit - 0.5f) * fac.dp_multiplier * 2 * inv_d_sqrt +
                    ((centroid_in == nullptr) ? 0 : centroid_in[j]);
        }
    }
}
|
|
461
|
+
|
|
462
|
+
// k-NN search over the FastScan codes. SearchParameters are not supported.
// Per-query RaBitQ factors are produced while building the lookup tables and
// are made available to distance post-processing through the context.
void IndexRaBitQFastScan::search(
        idx_t n,
        const float* x,
        idx_t k,
        float* distances,
        idx_t* labels,
        const SearchParameters* params) const {
    FAISS_THROW_IF_NOT_MSG(
            !params, "search params not supported for this index");
    // Mirrors the k > 0 precondition of IndexFastScan::search; a
    // non-positive k would otherwise reach the per-query result heaps.
    FAISS_THROW_IF_NOT(k > 0);

    // One QueryFactorsData slot per query. This vector owns the (heap)
    // storage for the duration of the call. compute_float_LUT writes slot i
    // through context.query_factors, and the result handler reads it back —
    // presumably via the context forwarded by search_dispatch_implem.
    std::vector<rabitq_utils::QueryFactorsData> query_factors_storage(n);

    FastScanDistancePostProcessing context;
    context.query_factors = query_factors_storage.data();
    if (metric_type == METRIC_L2) {
        search_dispatch_implem<true>(n, x, k, distances, labels, context);
    } else {
        search_dispatch_implem<false>(n, x, k, distances, labels, context);
    }
}
|
|
486
|
+
|
|
487
|
+
// Template implementations for RaBitQHeapHandler
|
|
488
|
+
template <class C, bool with_id_map>
|
|
489
|
+
RaBitQHeapHandler<C, with_id_map>::RaBitQHeapHandler(
|
|
490
|
+
const IndexRaBitQFastScan* index,
|
|
491
|
+
size_t nq_val,
|
|
492
|
+
size_t k_val,
|
|
493
|
+
float* distances,
|
|
494
|
+
int64_t* labels,
|
|
495
|
+
const IDSelector* sel_in,
|
|
496
|
+
const FastScanDistancePostProcessing& ctx,
|
|
497
|
+
bool multi_bit)
|
|
498
|
+
: RHC(nq_val, index->ntotal, sel_in),
|
|
499
|
+
rabitq_index(index),
|
|
500
|
+
heap_distances(distances),
|
|
501
|
+
heap_labels(labels),
|
|
502
|
+
nq(nq_val),
|
|
503
|
+
k(k_val),
|
|
504
|
+
context(ctx),
|
|
505
|
+
is_multi_bit(multi_bit) {
|
|
506
|
+
// Initialize heaps for all queries in constructor
|
|
507
|
+
// This allows us to support direct normalizer assignment
|
|
508
|
+
#pragma omp parallel for if (nq > 100)
|
|
509
|
+
for (int64_t q = 0; q < static_cast<int64_t>(nq); q++) {
|
|
510
|
+
float* heap_dis = heap_distances + q * k;
|
|
511
|
+
int64_t* heap_ids = heap_labels + q * k;
|
|
512
|
+
heap_heapify<Cfloat>(k, heap_dis, heap_ids);
|
|
513
|
+
}
|
|
514
|
+
}
|
|
515
|
+
|
|
516
|
+
template <class C, bool with_id_map>
|
|
517
|
+
void RaBitQHeapHandler<C, with_id_map>::handle(
|
|
518
|
+
size_t q,
|
|
519
|
+
size_t b,
|
|
520
|
+
simd16uint16 d0,
|
|
521
|
+
simd16uint16 d1) {
|
|
522
|
+
ALIGNED(32) uint16_t d32tab[32];
|
|
523
|
+
d0.store(d32tab);
|
|
524
|
+
d1.store(d32tab + 16);
|
|
525
|
+
|
|
526
|
+
// Get heap pointers and query factors (computed once per batch)
|
|
527
|
+
float* const heap_dis = heap_distances + q * k;
|
|
528
|
+
int64_t* const heap_ids = heap_labels + q * k;
|
|
529
|
+
|
|
530
|
+
// Access query factors from query_factors pointer
|
|
531
|
+
rabitq_utils::QueryFactorsData query_factors_data = {};
|
|
532
|
+
if (context.query_factors != nullptr) {
|
|
533
|
+
query_factors_data = context.query_factors[q];
|
|
534
|
+
}
|
|
535
|
+
|
|
536
|
+
// Compute normalizers once per batch
|
|
537
|
+
const float one_a = normalizers ? (1.0f / normalizers[2 * q]) : 1.0f;
|
|
538
|
+
const float bias = normalizers ? normalizers[2 * q + 1] : 0.0f;
|
|
539
|
+
|
|
540
|
+
// Compute loop bounds to avoid redundant bounds checking
|
|
541
|
+
const size_t base_db_idx = this->j0 + b * 32;
|
|
542
|
+
const size_t max_vectors = (base_db_idx < rabitq_index->ntotal)
|
|
543
|
+
? std::min<size_t>(32, rabitq_index->ntotal - base_db_idx)
|
|
544
|
+
: 0;
|
|
545
|
+
|
|
546
|
+
// Get storage size once
|
|
547
|
+
const size_t storage_size = rabitq_index->compute_per_vector_storage_size();
|
|
548
|
+
|
|
549
|
+
// Stats tracking for multi-bit two-stage search only
|
|
550
|
+
// n_1bit_evaluations: candidates evaluated using 1-bit lower bound
|
|
551
|
+
// n_multibit_evaluations: candidates requiring full multi-bit distance
|
|
552
|
+
size_t local_1bit_evaluations = 0;
|
|
553
|
+
size_t local_multibit_evaluations = 0;
|
|
554
|
+
|
|
555
|
+
// Process distances in batch
|
|
556
|
+
for (size_t i = 0; i < max_vectors; i++) {
|
|
557
|
+
const size_t db_idx = base_db_idx + i;
|
|
558
|
+
|
|
559
|
+
// Normalize distance from LUT lookup
|
|
560
|
+
const float normalized_distance = d32tab[i] * one_a + bias;
|
|
561
|
+
|
|
562
|
+
// Access factors from flat storage
|
|
563
|
+
const uint8_t* base_ptr =
|
|
564
|
+
rabitq_index->flat_storage.data() + db_idx * storage_size;
|
|
565
|
+
|
|
566
|
+
if (is_multi_bit) {
|
|
567
|
+
// Track candidates actually considered for two-stage filtering
|
|
568
|
+
local_1bit_evaluations++;
|
|
569
|
+
|
|
570
|
+
const SignBitFactorsWithError& full_factors =
|
|
571
|
+
*reinterpret_cast<const SignBitFactorsWithError*>(base_ptr);
|
|
572
|
+
|
|
573
|
+
float dist_1bit = rabitq_utils::compute_1bit_adjusted_distance(
|
|
574
|
+
normalized_distance,
|
|
575
|
+
full_factors,
|
|
576
|
+
query_factors_data,
|
|
577
|
+
rabitq_index->centered,
|
|
578
|
+
rabitq_index->qb,
|
|
579
|
+
rabitq_index->d);
|
|
580
|
+
|
|
581
|
+
float lower_bound = compute_lower_bound(dist_1bit, db_idx, q);
|
|
582
|
+
|
|
583
|
+
// Adaptive filtering: decide whether to compute full distance
|
|
584
|
+
const bool is_similarity = rabitq_index->metric_type ==
|
|
585
|
+
MetricType::METRIC_INNER_PRODUCT;
|
|
586
|
+
bool should_refine = is_similarity
|
|
587
|
+
? (lower_bound > heap_dis[0]) // IP: keep if better
|
|
588
|
+
: (lower_bound < heap_dis[0]); // L2: keep if better
|
|
589
|
+
|
|
590
|
+
if (should_refine) {
|
|
591
|
+
local_multibit_evaluations++;
|
|
592
|
+
float dist_full = compute_full_multibit_distance(db_idx, q);
|
|
593
|
+
|
|
594
|
+
if (Cfloat::cmp(heap_dis[0], dist_full)) {
|
|
595
|
+
heap_replace_top<Cfloat>(
|
|
596
|
+
k, heap_dis, heap_ids, dist_full, db_idx);
|
|
597
|
+
}
|
|
598
|
+
}
|
|
599
|
+
} else {
|
|
600
|
+
const rabitq_utils::SignBitFactors& db_factors =
|
|
601
|
+
*reinterpret_cast<const rabitq_utils::SignBitFactors*>(
|
|
602
|
+
base_ptr);
|
|
603
|
+
|
|
604
|
+
float adjusted_distance =
|
|
605
|
+
rabitq_utils::compute_1bit_adjusted_distance(
|
|
606
|
+
normalized_distance,
|
|
607
|
+
db_factors,
|
|
608
|
+
query_factors_data,
|
|
609
|
+
rabitq_index->centered,
|
|
610
|
+
rabitq_index->qb,
|
|
611
|
+
rabitq_index->d);
|
|
612
|
+
|
|
613
|
+
// Add to heap if better than current worst
|
|
614
|
+
if (Cfloat::cmp(heap_dis[0], adjusted_distance)) {
|
|
615
|
+
heap_replace_top<Cfloat>(
|
|
616
|
+
k, heap_dis, heap_ids, adjusted_distance, db_idx);
|
|
617
|
+
}
|
|
618
|
+
}
|
|
619
|
+
}
|
|
620
|
+
|
|
621
|
+
// Update global stats atomically
|
|
622
|
+
#pragma omp atomic
|
|
623
|
+
rabitq_stats.n_1bit_evaluations += local_1bit_evaluations;
|
|
624
|
+
#pragma omp atomic
|
|
625
|
+
rabitq_stats.n_multibit_evaluations += local_multibit_evaluations;
|
|
626
|
+
}
|
|
627
|
+
|
|
628
|
+
template <class C, bool with_id_map>
|
|
629
|
+
void RaBitQHeapHandler<C, with_id_map>::begin(const float* norms) {
|
|
630
|
+
normalizers = norms;
|
|
631
|
+
// Heap initialization is now done in constructor
|
|
632
|
+
}
|
|
633
|
+
|
|
634
|
+
template <class C, bool with_id_map>
|
|
635
|
+
void RaBitQHeapHandler<C, with_id_map>::end() {
|
|
636
|
+
// Reorder final results
|
|
637
|
+
#pragma omp parallel for if (nq > 100)
|
|
638
|
+
for (int64_t q = 0; q < static_cast<int64_t>(nq); q++) {
|
|
639
|
+
float* heap_dis = heap_distances + q * k;
|
|
640
|
+
int64_t* heap_ids = heap_labels + q * k;
|
|
641
|
+
heap_reorder<Cfloat>(k, heap_dis, heap_ids);
|
|
642
|
+
}
|
|
643
|
+
}
|
|
644
|
+
|
|
645
|
+
template <class C, bool with_id_map>
|
|
646
|
+
float RaBitQHeapHandler<C, with_id_map>::compute_lower_bound(
|
|
647
|
+
float dist_1bit,
|
|
648
|
+
size_t db_idx,
|
|
649
|
+
size_t q) const {
|
|
650
|
+
// Access f_error directly from SignBitFactorsWithError in flat storage
|
|
651
|
+
const size_t storage_size = rabitq_index->compute_per_vector_storage_size();
|
|
652
|
+
const uint8_t* base_ptr =
|
|
653
|
+
rabitq_index->flat_storage.data() + db_idx * storage_size;
|
|
654
|
+
const SignBitFactorsWithError& db_factors =
|
|
655
|
+
*reinterpret_cast<const SignBitFactorsWithError*>(base_ptr);
|
|
656
|
+
float f_error = db_factors.f_error;
|
|
657
|
+
|
|
658
|
+
// Get g_error from query factors (query-dependent error term)
|
|
659
|
+
float g_error = 0.0f;
|
|
660
|
+
if (context.query_factors != nullptr) {
|
|
661
|
+
g_error = context.query_factors[q].g_error;
|
|
662
|
+
}
|
|
663
|
+
|
|
664
|
+
// Compute error adjustment: f_error * g_error
|
|
665
|
+
float error_adjustment = f_error * g_error;
|
|
666
|
+
|
|
667
|
+
return dist_1bit - error_adjustment;
|
|
668
|
+
}
|
|
669
|
+
|
|
670
|
+
template <class C, bool with_id_map>
|
|
671
|
+
float RaBitQHeapHandler<C, with_id_map>::compute_full_multibit_distance(
|
|
672
|
+
size_t db_idx,
|
|
673
|
+
size_t q) const {
|
|
674
|
+
const size_t ex_bits = rabitq_index->rabitq.nb_bits - 1;
|
|
675
|
+
const size_t dim = rabitq_index->d;
|
|
676
|
+
|
|
677
|
+
const size_t storage_size = rabitq_index->compute_per_vector_storage_size();
|
|
678
|
+
const uint8_t* base_ptr =
|
|
679
|
+
rabitq_index->flat_storage.data() + db_idx * storage_size;
|
|
680
|
+
|
|
681
|
+
const size_t ex_code_size = (dim * ex_bits + 7) / 8;
|
|
682
|
+
const uint8_t* ex_code = base_ptr + sizeof(SignBitFactorsWithError);
|
|
683
|
+
const ExtraBitsFactors& ex_fac = *reinterpret_cast<const ExtraBitsFactors*>(
|
|
684
|
+
base_ptr + sizeof(SignBitFactorsWithError) + ex_code_size);
|
|
685
|
+
|
|
686
|
+
// Get query factors reference (avoid copying)
|
|
687
|
+
const rabitq_utils::QueryFactorsData& query_factors =
|
|
688
|
+
context.query_factors[q];
|
|
689
|
+
|
|
690
|
+
// Get sign bits from FastScan packed format
|
|
691
|
+
std::vector<uint8_t> unpacked_code(rabitq_index->code_size);
|
|
692
|
+
CodePackerPQ4 packer(rabitq_index->M2, rabitq_index->bbs);
|
|
693
|
+
packer.unpack_1(rabitq_index->codes.get(), db_idx, unpacked_code.data());
|
|
694
|
+
const uint8_t* sign_bits = unpacked_code.data();
|
|
695
|
+
|
|
696
|
+
return rabitq_utils::compute_full_multibit_distance(
|
|
697
|
+
sign_bits,
|
|
698
|
+
ex_code,
|
|
699
|
+
ex_fac,
|
|
700
|
+
query_factors.rotated_q.data(),
|
|
701
|
+
query_factors.qr_to_c_L2sqr,
|
|
702
|
+
query_factors.qr_norm_L2sqr,
|
|
703
|
+
dim,
|
|
704
|
+
ex_bits,
|
|
705
|
+
rabitq_index->metric_type);
|
|
706
|
+
}
|
|
707
|
+
|
|
708
|
+
// Implementation of virtual make_knn_handler method
|
|
709
|
+
SIMDResultHandlerToFloat* IndexRaBitQFastScan::make_knn_handler(
|
|
710
|
+
bool is_max,
|
|
711
|
+
int /*impl*/,
|
|
712
|
+
idx_t n,
|
|
713
|
+
idx_t k,
|
|
714
|
+
size_t /*ntotal*/,
|
|
715
|
+
float* distances,
|
|
716
|
+
idx_t* labels,
|
|
717
|
+
const IDSelector* sel,
|
|
718
|
+
const FastScanDistancePostProcessing& context) const {
|
|
719
|
+
// Use runtime boolean for multi-bit mode
|
|
720
|
+
const bool multi_bit = rabitq.nb_bits > 1;
|
|
721
|
+
|
|
722
|
+
if (is_max) {
|
|
723
|
+
return new RaBitQHeapHandler<CMax<uint16_t, int>, false>(
|
|
724
|
+
this, n, k, distances, labels, sel, context, multi_bit);
|
|
725
|
+
} else {
|
|
726
|
+
return new RaBitQHeapHandler<CMin<uint16_t, int>, false>(
|
|
727
|
+
this, n, k, distances, labels, sel, context, multi_bit);
|
|
728
|
+
}
|
|
729
|
+
}
|
|
730
|
+
|
|
731
|
+
} // namespace faiss
|