faiss 0.4.2 → 0.5.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +9 -0
- data/ext/faiss/index.cpp +36 -10
- data/ext/faiss/index_binary.cpp +19 -6
- data/ext/faiss/kmeans.cpp +6 -6
- data/ext/faiss/numo.hpp +273 -123
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/faiss/AutoTune.cpp +2 -3
- data/vendor/faiss/faiss/AutoTune.h +1 -1
- data/vendor/faiss/faiss/Clustering.cpp +2 -2
- data/vendor/faiss/faiss/Clustering.h +2 -2
- data/vendor/faiss/faiss/IVFlib.cpp +1 -2
- data/vendor/faiss/faiss/IVFlib.h +1 -1
- data/vendor/faiss/faiss/Index.h +10 -10
- data/vendor/faiss/faiss/Index2Layer.cpp +1 -1
- data/vendor/faiss/faiss/Index2Layer.h +2 -2
- data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +9 -4
- data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +5 -1
- data/vendor/faiss/faiss/IndexBinary.h +7 -7
- data/vendor/faiss/faiss/IndexBinaryFromFloat.h +1 -1
- data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +3 -1
- data/vendor/faiss/faiss/IndexBinaryHNSW.h +1 -1
- data/vendor/faiss/faiss/IndexBinaryHash.cpp +3 -3
- data/vendor/faiss/faiss/IndexBinaryHash.h +5 -5
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +7 -6
- data/vendor/faiss/faiss/IndexFastScan.cpp +125 -49
- data/vendor/faiss/faiss/IndexFastScan.h +107 -7
- data/vendor/faiss/faiss/IndexFlat.h +1 -1
- data/vendor/faiss/faiss/IndexHNSW.cpp +3 -1
- data/vendor/faiss/faiss/IndexHNSW.h +1 -1
- data/vendor/faiss/faiss/IndexIDMap.cpp +14 -13
- data/vendor/faiss/faiss/IndexIDMap.h +6 -6
- data/vendor/faiss/faiss/IndexIVF.cpp +1 -1
- data/vendor/faiss/faiss/IndexIVF.h +5 -5
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +1 -1
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +9 -3
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +3 -1
- data/vendor/faiss/faiss/IndexIVFFastScan.cpp +176 -90
- data/vendor/faiss/faiss/IndexIVFFastScan.h +173 -18
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +1 -0
- data/vendor/faiss/faiss/IndexIVFFlatPanorama.cpp +366 -0
- data/vendor/faiss/faiss/IndexIVFFlatPanorama.h +64 -0
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +3 -1
- data/vendor/faiss/faiss/IndexIVFPQ.h +1 -1
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +134 -2
- data/vendor/faiss/faiss/IndexIVFPQFastScan.h +7 -1
- data/vendor/faiss/faiss/IndexIVFRaBitQ.cpp +13 -6
- data/vendor/faiss/faiss/IndexIVFRaBitQ.h +1 -0
- data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.cpp +650 -0
- data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.h +216 -0
- data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +1 -1
- data/vendor/faiss/faiss/IndexIVFSpectralHash.h +1 -1
- data/vendor/faiss/faiss/IndexNNDescent.cpp +1 -1
- data/vendor/faiss/faiss/IndexNSG.cpp +1 -1
- data/vendor/faiss/faiss/IndexNeuralNetCodec.h +1 -1
- data/vendor/faiss/faiss/IndexPQ.h +1 -1
- data/vendor/faiss/faiss/IndexPQFastScan.cpp +6 -2
- data/vendor/faiss/faiss/IndexPQFastScan.h +5 -1
- data/vendor/faiss/faiss/IndexRaBitQ.cpp +13 -10
- data/vendor/faiss/faiss/IndexRaBitQ.h +7 -2
- data/vendor/faiss/faiss/IndexRaBitQFastScan.cpp +586 -0
- data/vendor/faiss/faiss/IndexRaBitQFastScan.h +149 -0
- data/vendor/faiss/faiss/IndexShards.cpp +1 -1
- data/vendor/faiss/faiss/MatrixStats.cpp +3 -3
- data/vendor/faiss/faiss/MetricType.h +1 -1
- data/vendor/faiss/faiss/VectorTransform.h +2 -2
- data/vendor/faiss/faiss/clone_index.cpp +3 -1
- data/vendor/faiss/faiss/gpu/GpuCloner.cpp +1 -1
- data/vendor/faiss/faiss/gpu/GpuIndex.h +11 -11
- data/vendor/faiss/faiss/gpu/GpuIndexBinaryCagra.h +1 -1
- data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +1 -1
- data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +10 -6
- data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +2 -0
- data/vendor/faiss/faiss/gpu/test/TestGpuIcmEncoder.cpp +7 -0
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +1 -1
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +1 -1
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +1 -1
- data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +2 -2
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +1 -1
- data/vendor/faiss/faiss/impl/CodePacker.h +2 -2
- data/vendor/faiss/faiss/impl/DistanceComputer.h +3 -3
- data/vendor/faiss/faiss/impl/FastScanDistancePostProcessing.h +53 -0
- data/vendor/faiss/faiss/impl/HNSW.cpp +1 -1
- data/vendor/faiss/faiss/impl/HNSW.h +4 -4
- data/vendor/faiss/faiss/impl/IDSelector.cpp +2 -2
- data/vendor/faiss/faiss/impl/IDSelector.h +1 -1
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +4 -4
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +1 -1
- data/vendor/faiss/faiss/impl/LookupTableScaler.h +1 -1
- data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -1
- data/vendor/faiss/faiss/impl/NNDescent.h +2 -2
- data/vendor/faiss/faiss/impl/NSG.cpp +1 -1
- data/vendor/faiss/faiss/impl/PanoramaStats.cpp +33 -0
- data/vendor/faiss/faiss/impl/PanoramaStats.h +38 -0
- data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +5 -5
- data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +1 -1
- data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +1 -1
- data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +2 -0
- data/vendor/faiss/faiss/impl/ProductQuantizer.h +1 -1
- data/vendor/faiss/faiss/impl/RaBitQUtils.cpp +246 -0
- data/vendor/faiss/faiss/impl/RaBitQUtils.h +153 -0
- data/vendor/faiss/faiss/impl/RaBitQuantizer.cpp +54 -158
- data/vendor/faiss/faiss/impl/RaBitQuantizer.h +2 -1
- data/vendor/faiss/faiss/impl/ResidualQuantizer.h +1 -1
- data/vendor/faiss/faiss/impl/ResultHandler.h +4 -4
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +1 -1
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +1 -1
- data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +7 -4
- data/vendor/faiss/faiss/impl/index_read.cpp +87 -3
- data/vendor/faiss/faiss/impl/index_write.cpp +73 -3
- data/vendor/faiss/faiss/impl/io.cpp +2 -2
- data/vendor/faiss/faiss/impl/io.h +4 -4
- data/vendor/faiss/faiss/impl/kmeans1d.cpp +1 -1
- data/vendor/faiss/faiss/impl/kmeans1d.h +1 -1
- data/vendor/faiss/faiss/impl/lattice_Zn.h +2 -2
- data/vendor/faiss/faiss/impl/mapped_io.cpp +2 -2
- data/vendor/faiss/faiss/impl/mapped_io.h +4 -3
- data/vendor/faiss/faiss/impl/maybe_owned_vector.h +8 -1
- data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +30 -4
- data/vendor/faiss/faiss/impl/pq4_fast_scan.h +14 -8
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +5 -6
- data/vendor/faiss/faiss/impl/simd_result_handlers.h +55 -11
- data/vendor/faiss/faiss/impl/zerocopy_io.h +1 -1
- data/vendor/faiss/faiss/index_factory.cpp +43 -1
- data/vendor/faiss/faiss/index_factory.h +1 -1
- data/vendor/faiss/faiss/index_io.h +1 -1
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +205 -0
- data/vendor/faiss/faiss/invlists/InvertedLists.h +62 -0
- data/vendor/faiss/faiss/utils/AlignedTable.h +1 -1
- data/vendor/faiss/faiss/utils/Heap.cpp +2 -2
- data/vendor/faiss/faiss/utils/Heap.h +3 -3
- data/vendor/faiss/faiss/utils/NeuralNet.cpp +1 -1
- data/vendor/faiss/faiss/utils/NeuralNet.h +3 -3
- data/vendor/faiss/faiss/utils/approx_topk/approx_topk.h +2 -2
- data/vendor/faiss/faiss/utils/approx_topk/avx2-inl.h +2 -2
- data/vendor/faiss/faiss/utils/approx_topk/mode.h +1 -1
- data/vendor/faiss/faiss/utils/distances.h +2 -2
- data/vendor/faiss/faiss/utils/extra_distances-inl.h +3 -1
- data/vendor/faiss/faiss/utils/hamming-inl.h +2 -0
- data/vendor/faiss/faiss/utils/hamming.cpp +7 -6
- data/vendor/faiss/faiss/utils/hamming.h +1 -1
- data/vendor/faiss/faiss/utils/hamming_distance/common.h +1 -2
- data/vendor/faiss/faiss/utils/partitioning.cpp +5 -5
- data/vendor/faiss/faiss/utils/partitioning.h +2 -2
- data/vendor/faiss/faiss/utils/rabitq_simd.h +222 -336
- data/vendor/faiss/faiss/utils/random.cpp +1 -1
- data/vendor/faiss/faiss/utils/simdlib_avx2.h +1 -1
- data/vendor/faiss/faiss/utils/simdlib_avx512.h +1 -1
- data/vendor/faiss/faiss/utils/simdlib_neon.h +2 -2
- data/vendor/faiss/faiss/utils/transpose/transpose-avx512-inl.h +1 -1
- data/vendor/faiss/faiss/utils/utils.cpp +5 -2
- data/vendor/faiss/faiss/utils/utils.h +2 -2
- metadata +14 -3
|
@@ -0,0 +1,586 @@
|
|
|
1
|
+
/*
|
|
2
|
+
* Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
3
|
+
*
|
|
4
|
+
* This source code is licensed under the MIT license found in the
|
|
5
|
+
* LICENSE file in the root directory of this source tree.
|
|
6
|
+
*/
|
|
7
|
+
|
|
8
|
+
#include <faiss/IndexRaBitQFastScan.h>
#include <faiss/impl/FastScanDistancePostProcessing.h>
#include <faiss/impl/RaBitQUtils.h>
#include <faiss/impl/pq4_fast_scan.h>
#include <faiss/utils/utils.h>

#include <algorithm>
#include <cmath>
#include <cstring>
|
|
15
|
+
|
|
16
|
+
namespace faiss {
|
|
17
|
+
|
|
18
|
+
// Round `a` up to the nearest multiple of `b`.
// `b` must be non-zero; `a + b - 1` must not overflow size_t.
static inline size_t roundup(size_t a, size_t b) {
    const size_t nblocks = (a + b - 1) / b;
    return nblocks * b;
}
|
|
21
|
+
|
|
22
|
+
// Default constructor: every member keeps its in-class default value.
// NOTE(review): presumably used when deserializing an index (index_read) —
// confirm; nothing here sets d, M, code_size or the centroid.
IndexRaBitQFastScan::IndexRaBitQFastScan() = default;
|
|
23
|
+
|
|
24
|
+
/**
 * Build an empty RaBitQ FastScan index.
 *
 * @param d       vector dimension (must be > 0)
 * @param metric  METRIC_L2 or METRIC_INNER_PRODUCT
 * @param bbs     FastScan block size, forwarded to init_fastscan
 * @throws FaissException on a non-positive dimension or unsupported metric
 */
IndexRaBitQFastScan::IndexRaBitQFastScan(idx_t d, MetricType metric, int bbs)
        : rabitq(d, metric) {
    FAISS_THROW_IF_NOT_MSG(d > 0, "Dimension must be positive");
    FAISS_THROW_IF_NOT_MSG(
            metric == METRIC_L2 || metric == METRIC_INNER_PRODUCT,
            "RaBitQ FastScan only supports L2 and Inner Product metrics");

    // One RaBitQ sign bit per dimension; four consecutive bits form one
    // 4-bit FastScan sub-quantizer, hence ceil(d / 4) sub-quantizers.
    constexpr size_t nbits_fastscan = 4;
    const size_t M_fastscan = (d + 3) / 4;

    // init_fastscan validates bbs % 32 == 0 and nbits_fastscan == 4.
    init_fastscan(static_cast<int>(d), M_fastscan, nbits_fastscan, metric, bbs);

    // Replace the code size chosen by init_fastscan: each stored code holds
    // ceil(d / 8) bytes of sign bits followed by the per-vector FactorsData.
    const size_t sign_bytes = (d + 7) / 8;
    code_size = sign_bytes + sizeof(FactorsData);

    // RaBitQ defaults: 8-bit query quantization and a zero centroid that
    // train() later replaces with the training-set mean.
    qb = 8;
    center.resize(d, 0.0f);

    // Per-vector factors are appended by add(); start from an empty store.
    // (clear() does not reserve any capacity.)
    factors_storage.clear();
}
|
|
52
|
+
|
|
53
|
+
/**
 * Convert an existing IndexRaBitQ into the FastScan layout.
 *
 * Copies the training state (qb, centered, center) and, when `orig` holds
 * vectors, rebuilds the per-vector factors and repacks the sign bits into
 * 4-bit FastScan blocks.
 *
 * NOTE(review): factors are recomputed from orig.sa_decode() output, i.e.
 * from reconstructed vectors rather than the original inputs.
 * NOTE(review): orig_codes points into orig's buffer, so `orig` must outlive
 * any use of it.
 *
 * @param orig  source index (L2 or inner-product metric, d > 0)
 * @param bbs   FastScan block size
 */
IndexRaBitQFastScan::IndexRaBitQFastScan(const IndexRaBitQ& orig, int bbs)
        : rabitq(orig.rabitq) {
    FAISS_THROW_IF_NOT_MSG(orig.d > 0, "Dimension must be positive");
    FAISS_THROW_IF_NOT_MSG(
            orig.metric_type == METRIC_L2 ||
                    orig.metric_type == METRIC_INNER_PRODUCT,
            "RaBitQ FastScan only supports L2 and Inner Product metrics");

    const auto dim = orig.d;

    // Four sign bits per 4-bit FastScan sub-quantizer.
    constexpr size_t nbits_fastscan = 4;
    const size_t M_fastscan = (dim + 3) / 4;
    init_fastscan(
            static_cast<int>(dim),
            M_fastscan,
            nbits_fastscan,
            orig.metric_type,
            bbs);

    // Stored code = ceil(d / 8) sign-bit bytes + FactorsData.
    const size_t sign_bytes = (dim + 7) / 8;
    code_size = sign_bytes + sizeof(FactorsData);

    ntotal = orig.ntotal;
    ntotal2 = roundup(ntotal, bbs);
    is_trained = orig.is_trained;
    orig_codes = orig.codes.data();
    qb = orig.qb;
    centered = orig.centered;
    center = orig.center;

    if (ntotal <= 0) {
        return; // nothing to convert
    }

    // --- per-vector factors, recomputed from the decoded vectors ---
    factors_storage.resize(ntotal);
    const float* centroid_data = center.data();
    std::vector<float> decoded(ntotal * dim);
    orig.sa_decode(ntotal, orig.codes.data(), decoded.data());
    for (idx_t row = 0; row < ntotal; row++) {
        factors_storage[row] = rabitq_utils::compute_vector_factors(
                decoded.data() + row * dim,
                dim,
                centroid_data,
                orig.metric_type);
    }

    // --- move each sign bit from RaBitQ's bit order to FastScan's ---
    AlignedTable<uint8_t> fastscan_codes(ntotal * code_size);
    memset(fastscan_codes.get(), 0, ntotal * code_size);
    for (idx_t row = 0; row < ntotal; row++) {
        const uint8_t* src = orig.codes.data() + row * orig.code_size;
        uint8_t* dst = fastscan_codes.get() + row * code_size;
        for (size_t j = 0; j < dim; j++) {
            // RaBitQ stores dimension j at bit (j % 8) of byte (j / 8).
            if ((src[j / 8] >> (j % 8)) & 1) {
                rabitq_utils::set_bit_fastscan(dst, j);
            }
        }
    }

    // --- interleave into FastScan blocks; stride is the full code_size ---
    codes.resize(ntotal2 * M2 / 2);
    pq4_pack_codes(
            fastscan_codes.get(),
            ntotal,
            M,
            ntotal2,
            bbs,
            M2,
            codes.get(),
            code_size);
}
|
|
149
|
+
|
|
150
|
+
/**
 * Train the index: set `center` to the per-dimension mean of the n training
 * vectors (all zeros when n == 0), then train the underlying RaBitQuantizer
 * on the same data.
 *
 * @param n  number of training vectors
 * @param x  n * d floats, row-major
 */
void IndexRaBitQFastScan::train(idx_t n, const float* x) {
    std::vector<float> mean(d, 0);

    // Accumulate row by row so each dimension is summed in input order.
    for (int64_t row = 0; row < static_cast<int64_t>(n); row++) {
        const float* xr = x + row * d;
        for (size_t j = 0; j < d; j++) {
            mean[j] += xr[j];
        }
    }

    if (n != 0) {
        const float count = static_cast<float>(n);
        for (float& v : mean) {
            v /= count;
        }
    }

    center = std::move(mean);

    rabitq.train(n, x);
    is_trained = true;
}
|
|
170
|
+
|
|
171
|
+
/**
 * Encode and append n vectors.
 *
 * Inputs larger than 65536 vectors are processed in chunks through recursive
 * calls. For each chunk:
 *  1. compute_codes() produces code_size-byte codes (sign bits + factors);
 *  2. the FactorsData of every code is copied into factors_storage;
 *  3. the sign bits are packed into the FastScan block layout in `codes`.
 *
 * @param n  number of vectors
 * @param x  n * d floats, row-major
 * @throws FaissException if the index is not trained
 */
void IndexRaBitQFastScan::add(idx_t n, const float* x) {
    FAISS_THROW_IF_NOT(is_trained);

    // Bound the size of the temporary code buffer by chunking big inputs.
    constexpr idx_t bs = 65536;
    if (n > bs) {
        for (idx_t i0 = 0; i0 < n; i0 += bs) {
            idx_t i1 = std::min(n, i0 + bs);
            if (verbose) {
                printf("IndexRaBitQFastScan::add %zd/%zd\n",
                       size_t(i1),
                       size_t(n));
            }
            add(i1 - i0, x + i0 * d);
        }
        return;
    }
    InterruptCallback::check();

    // Codes with the factors embedded after the sign bits.
    AlignedTable<uint8_t> tmp_codes(n * code_size);
    compute_codes(tmp_codes.get(), n, x);

    // The factor block starts at byte offset ceil(d / 8) within a code of
    // code_size bytes. That address is generally not aligned for
    // FactorsData's float members, so it is copied out with memcpy instead
    // of being read through a reinterpret_cast pointer, which would be
    // undefined behaviour on a misaligned address.
    const size_t bit_pattern_size = (d + 7) / 8;
    factors_storage.resize(ntotal + n);
    for (idx_t i = 0; i < n; i++) {
        const uint8_t* code = tmp_codes.get() + i * code_size;
        FactorsData embedded_factors;
        std::memcpy(
                &embedded_factors,
                code + bit_pattern_size,
                sizeof(FactorsData));
        factors_storage[ntotal + i] = embedded_factors;
    }

    // Grow the packed storage to a whole number of bbs-sized blocks and
    // zero the new tail.
    ntotal2 = roundup(ntotal + n, bbs);
    size_t new_size = ntotal2 * M2 / 2; // assumes nbits == 4
    size_t old_size = codes.size();
    if (new_size > old_size) {
        codes.resize(new_size);
        memset(codes.get() + old_size, 0, new_size - old_size);
    }

    // Pack vectors [ntotal, ntotal + n); the input stride is code_size
    // because each temporary code also carries its factors.
    pq4_pack_codes_range(
            tmp_codes.get(),
            M,
            ntotal,
            ntotal + n,
            bbs,
            M2,
            codes.get(),
            code_size);

    ntotal += n;
}
|
|
227
|
+
|
|
228
|
+
void IndexRaBitQFastScan::compute_codes(uint8_t* codes, idx_t n, const float* x)
|
|
229
|
+
const {
|
|
230
|
+
FAISS_ASSERT(codes != nullptr);
|
|
231
|
+
FAISS_ASSERT(x != nullptr);
|
|
232
|
+
FAISS_ASSERT(
|
|
233
|
+
(metric_type == MetricType::METRIC_L2 ||
|
|
234
|
+
metric_type == MetricType::METRIC_INNER_PRODUCT));
|
|
235
|
+
if (n == 0) {
|
|
236
|
+
return;
|
|
237
|
+
}
|
|
238
|
+
|
|
239
|
+
// Hoist loop-invariant computations
|
|
240
|
+
const float* centroid_data = center.data();
|
|
241
|
+
const size_t bit_pattern_size = (d + 7) / 8;
|
|
242
|
+
|
|
243
|
+
memset(codes, 0, n * code_size);
|
|
244
|
+
|
|
245
|
+
#pragma omp parallel for if (n > 1000)
|
|
246
|
+
for (int64_t i = 0; i < n; i++) {
|
|
247
|
+
uint8_t* const code = codes + i * code_size;
|
|
248
|
+
const float* const x_row = x + i * d;
|
|
249
|
+
|
|
250
|
+
// Pack bits directly into FastScan format
|
|
251
|
+
for (size_t j = 0; j < d; j++) {
|
|
252
|
+
const float x_val = x_row[j];
|
|
253
|
+
const float centroid_val = centroid_data ? centroid_data[j] : 0.0f;
|
|
254
|
+
const float or_minus_c = x_val - centroid_val;
|
|
255
|
+
const bool xb = (or_minus_c > 0.0f);
|
|
256
|
+
|
|
257
|
+
if (xb) {
|
|
258
|
+
rabitq_utils::set_bit_fastscan(code, j);
|
|
259
|
+
}
|
|
260
|
+
}
|
|
261
|
+
|
|
262
|
+
// Calculate and append factors after the bit data
|
|
263
|
+
FactorsData factors = rabitq_utils::compute_vector_factors(
|
|
264
|
+
x_row, d, centroid_data, metric_type);
|
|
265
|
+
|
|
266
|
+
// Append factors at the end of the code
|
|
267
|
+
uint8_t* factors_ptr = code + bit_pattern_size;
|
|
268
|
+
*reinterpret_cast<FactorsData*>(factors_ptr) = factors;
|
|
269
|
+
}
|
|
270
|
+
}
|
|
271
|
+
|
|
272
|
+
void IndexRaBitQFastScan::compute_float_LUT(
|
|
273
|
+
float* lut,
|
|
274
|
+
idx_t n,
|
|
275
|
+
const float* x,
|
|
276
|
+
const FastScanDistancePostProcessing& context) const {
|
|
277
|
+
FAISS_THROW_IF_NOT(is_trained);
|
|
278
|
+
|
|
279
|
+
// Pre-allocate working buffers to avoid repeated allocations
|
|
280
|
+
std::vector<float> rotated_q(d);
|
|
281
|
+
std::vector<uint8_t> rotated_qq(d);
|
|
282
|
+
|
|
283
|
+
// Compute lookup tables for FastScan SIMD operations
|
|
284
|
+
// For each query vector, computes distance contributions for all
|
|
285
|
+
// possible 4-bit codes per sub-quantizer. Also computes and stores
|
|
286
|
+
// query factors for distance reconstruction.
|
|
287
|
+
for (idx_t i = 0; i < n; i++) {
|
|
288
|
+
const float* query = x + i * d;
|
|
289
|
+
|
|
290
|
+
// Compute query factors and store in array if available
|
|
291
|
+
rabitq_utils::QueryFactorsData query_factors_data =
|
|
292
|
+
rabitq_utils::compute_query_factors(
|
|
293
|
+
query,
|
|
294
|
+
d,
|
|
295
|
+
center.data(),
|
|
296
|
+
qb,
|
|
297
|
+
centered,
|
|
298
|
+
metric_type,
|
|
299
|
+
rotated_q,
|
|
300
|
+
rotated_qq);
|
|
301
|
+
|
|
302
|
+
// Store query factors in context array if provided
|
|
303
|
+
if (context.query_factors) {
|
|
304
|
+
context.query_factors[i] = query_factors_data;
|
|
305
|
+
}
|
|
306
|
+
|
|
307
|
+
// Create lookup table storing distance contributions for all possible
|
|
308
|
+
// 4-bit codes per sub-quantizer for FastScan SIMD operations
|
|
309
|
+
float* query_lut = lut + i * M * 16;
|
|
310
|
+
|
|
311
|
+
if (centered) {
|
|
312
|
+
// For centered mode, we use the signed odd integer quantization
|
|
313
|
+
// scheme.
|
|
314
|
+
// Formula:
|
|
315
|
+
// int_dot = ((1 << qb) - 1) * d - 2 * xor_dot_product
|
|
316
|
+
// We precompute the XOR contribution for each
|
|
317
|
+
// sub-quantizer
|
|
318
|
+
|
|
319
|
+
const float max_code_value = (1 << qb) - 1;
|
|
320
|
+
|
|
321
|
+
for (size_t m = 0; m < M; m++) {
|
|
322
|
+
const size_t dim_start = m * 4;
|
|
323
|
+
|
|
324
|
+
for (int code_val = 0; code_val < 16; code_val++) {
|
|
325
|
+
float xor_contribution = 0.0f;
|
|
326
|
+
|
|
327
|
+
// Process 4 bits per sub-quantizer
|
|
328
|
+
for (size_t dim_offset = 0; dim_offset < 4; dim_offset++) {
|
|
329
|
+
const size_t dim_idx = dim_start + dim_offset;
|
|
330
|
+
|
|
331
|
+
if (dim_idx < d) {
|
|
332
|
+
const bool db_bit = (code_val >> dim_offset) & 1;
|
|
333
|
+
const float query_value = rotated_qq[dim_idx];
|
|
334
|
+
|
|
335
|
+
// XOR contribution:
|
|
336
|
+
// If db_bit == 0: XOR result = query_value
|
|
337
|
+
// If db_bit == 1: XOR result = (2^qb - 1) -
|
|
338
|
+
// query_value
|
|
339
|
+
xor_contribution += db_bit
|
|
340
|
+
? (max_code_value - query_value)
|
|
341
|
+
: query_value;
|
|
342
|
+
}
|
|
343
|
+
}
|
|
344
|
+
|
|
345
|
+
// Store the XOR contribution (will be scaled by -2 *
|
|
346
|
+
// int_dot_scale during distance computation)
|
|
347
|
+
query_lut[m * 16 + code_val] = xor_contribution;
|
|
348
|
+
}
|
|
349
|
+
}
|
|
350
|
+
|
|
351
|
+
} else {
|
|
352
|
+
// For non-centered quantization, use traditional AND dot
|
|
353
|
+
// product Compute lookup table entries by processing popcount
|
|
354
|
+
// and inner product together
|
|
355
|
+
for (size_t m = 0; m < M; m++) {
|
|
356
|
+
const size_t dim_start = m * 4;
|
|
357
|
+
|
|
358
|
+
for (int code_val = 0; code_val < 16; code_val++) {
|
|
359
|
+
float inner_product = 0.0f;
|
|
360
|
+
int popcount = 0;
|
|
361
|
+
|
|
362
|
+
// Process 4 bits per sub-quantizer
|
|
363
|
+
for (size_t dim_offset = 0; dim_offset < 4; dim_offset++) {
|
|
364
|
+
const size_t dim_idx = dim_start + dim_offset;
|
|
365
|
+
|
|
366
|
+
if (dim_idx < d && ((code_val >> dim_offset) & 1)) {
|
|
367
|
+
inner_product += rotated_qq[dim_idx];
|
|
368
|
+
popcount++;
|
|
369
|
+
}
|
|
370
|
+
}
|
|
371
|
+
|
|
372
|
+
// Store pre-computed distance contribution
|
|
373
|
+
query_lut[m * 16 + code_val] =
|
|
374
|
+
query_factors_data.c1 * inner_product +
|
|
375
|
+
query_factors_data.c2 * popcount;
|
|
376
|
+
}
|
|
377
|
+
}
|
|
378
|
+
}
|
|
379
|
+
}
|
|
380
|
+
}
|
|
381
|
+
|
|
382
|
+
/**
 * Decode n codes (layout as produced by compute_codes) into approximate
 * vectors:
 *   x[j] = (bit_j - 0.5) * dp_multiplier * 2 / sqrt(d) + center[j]
 * where bit_j is the stored sign bit and dp_multiplier comes from the
 * FactorsData embedded after the sign bits. A null centroid counts as zero.
 *
 * @param n      number of codes
 * @param bytes  n * code_size bytes
 * @param x      output, n * d floats
 */
void IndexRaBitQFastScan::sa_decode(idx_t n, const uint8_t* bytes, float* x)
        const {
    const float* centroid_in = center.data();
    const uint8_t* codes = bytes;
    FAISS_ASSERT(codes != nullptr);
    FAISS_ASSERT(x != nullptr);

    const float inv_d_sqrt =
            (d == 0) ? 1.0f : (1.0f / std::sqrt(static_cast<float>(d)));
    const size_t bit_pattern_size = (d + 7) / 8;

#pragma omp parallel for if (n > 1000)
    for (int64_t i = 0; i < n; i++) {
        const uint8_t* code = codes + i * code_size;

        // The embedded factors sit at byte offset ceil(d / 8), which is
        // generally misaligned for FactorsData; copy them out with memcpy
        // rather than dereferencing a reinterpret_cast pointer.
        FactorsData fac;
        std::memcpy(&fac, code + bit_pattern_size, sizeof(FactorsData));

        float* out = x + i * d;
        for (size_t j = 0; j < d; j++) {
            const float bit =
                    rabitq_utils::extract_bit_fastscan(code, j) ? 1.0f : 0.0f;
            out[j] = (bit - 0.5f) * fac.dp_multiplier * 2 * inv_d_sqrt +
                    ((centroid_in == nullptr) ? 0 : centroid_in[j]);
        }
    }
}
|
|
414
|
+
|
|
415
|
+
/**
 * k-NN search through the generic FastScan dispatch.
 *
 * A per-query QueryFactorsData buffer is attached to the post-processing
 * context. compute_float_LUT writes slot i while building query i's table,
 * and RaBitQHeapHandler::handle reads it when converting LUT sums into
 * distances. The buffer is a std::vector owned by this call and released
 * on return.
 *
 * @throws FaissException if `params` is non-null (not supported)
 */
void IndexRaBitQFastScan::search(
        idx_t n,
        const float* x,
        idx_t k,
        float* distances,
        idx_t* labels,
        const SearchParameters* params) const {
    FAISS_THROW_IF_NOT_MSG(
            !params, "search params not supported for this index");

    std::vector<rabitq_utils::QueryFactorsData> per_query_factors(n);

    FastScanDistancePostProcessing context;
    context.query_factors = per_query_factors.data();

    // L2 keeps the smallest distances; inner product keeps the largest.
    const bool is_l2 = (metric_type == METRIC_L2);
    if (is_l2) {
        search_dispatch_implem<true>(n, x, k, distances, labels, context);
        return;
    }
    search_dispatch_implem<false>(n, x, k, distances, labels, context);
}
|
|
439
|
+
|
|
440
|
+
// Template implementations for RaBitQHeapHandler
|
|
441
|
+
template <class C, bool with_id_map>
|
|
442
|
+
RaBitQHeapHandler<C, with_id_map>::RaBitQHeapHandler(
|
|
443
|
+
const IndexRaBitQFastScan* index,
|
|
444
|
+
size_t nq_val,
|
|
445
|
+
size_t k_val,
|
|
446
|
+
float* distances,
|
|
447
|
+
int64_t* labels,
|
|
448
|
+
const IDSelector* sel_in,
|
|
449
|
+
const FastScanDistancePostProcessing& ctx)
|
|
450
|
+
: RHC(nq_val, index->ntotal, sel_in),
|
|
451
|
+
rabitq_index(index),
|
|
452
|
+
heap_distances(distances),
|
|
453
|
+
heap_labels(labels),
|
|
454
|
+
nq(nq_val),
|
|
455
|
+
k(k_val),
|
|
456
|
+
context(ctx) {
|
|
457
|
+
// Initialize heaps for all queries in constructor
|
|
458
|
+
// This allows us to support direct normalizer assignment
|
|
459
|
+
#pragma omp parallel for if (nq > 100)
|
|
460
|
+
for (int64_t q = 0; q < static_cast<int64_t>(nq); q++) {
|
|
461
|
+
float* heap_dis = heap_distances + q * k;
|
|
462
|
+
int64_t* heap_ids = heap_labels + q * k;
|
|
463
|
+
heap_heapify<Cfloat>(k, heap_dis, heap_ids);
|
|
464
|
+
}
|
|
465
|
+
}
|
|
466
|
+
|
|
467
|
+
template <class C, bool with_id_map>
|
|
468
|
+
void RaBitQHeapHandler<C, with_id_map>::handle(
|
|
469
|
+
size_t q,
|
|
470
|
+
size_t b,
|
|
471
|
+
simd16uint16 d0,
|
|
472
|
+
simd16uint16 d1) {
|
|
473
|
+
ALIGNED(32) uint16_t d32tab[32];
|
|
474
|
+
d0.store(d32tab);
|
|
475
|
+
d1.store(d32tab + 16);
|
|
476
|
+
|
|
477
|
+
// Get heap pointers and query factors (computed once per batch)
|
|
478
|
+
float* const heap_dis = heap_distances + q * k;
|
|
479
|
+
int64_t* const heap_ids = heap_labels + q * k;
|
|
480
|
+
|
|
481
|
+
// Access query factors from query_factors pointer
|
|
482
|
+
rabitq_utils::QueryFactorsData query_factors_data = {};
|
|
483
|
+
if (context.query_factors) {
|
|
484
|
+
query_factors_data = context.query_factors[q];
|
|
485
|
+
}
|
|
486
|
+
|
|
487
|
+
// Compute normalizers once per batch
|
|
488
|
+
const float one_a = normalizers ? (1.0f / normalizers[2 * q]) : 1.0f;
|
|
489
|
+
const float bias = normalizers ? normalizers[2 * q + 1] : 0.0f;
|
|
490
|
+
|
|
491
|
+
// Compute loop bounds to avoid redundant bounds checking
|
|
492
|
+
const size_t base_db_idx = this->j0 + b * 32;
|
|
493
|
+
const size_t max_vectors = (base_db_idx < rabitq_index->ntotal)
|
|
494
|
+
? std::min<size_t>(32, rabitq_index->ntotal - base_db_idx)
|
|
495
|
+
: 0;
|
|
496
|
+
|
|
497
|
+
// Process distances in batch
|
|
498
|
+
for (size_t i = 0; i < max_vectors; i++) {
|
|
499
|
+
const size_t db_idx = base_db_idx + i;
|
|
500
|
+
|
|
501
|
+
// Normalize distance from LUT lookup
|
|
502
|
+
const float normalized_distance = d32tab[i] * one_a + bias;
|
|
503
|
+
|
|
504
|
+
// Access factors from storage (populated from embedded codes during
|
|
505
|
+
// add())
|
|
506
|
+
const auto& db_factors = rabitq_index->factors_storage[db_idx];
|
|
507
|
+
|
|
508
|
+
float adjusted_distance;
|
|
509
|
+
|
|
510
|
+
if (rabitq_index->centered) {
|
|
511
|
+
// For centered mode: normalized_distance contains the raw XOR
|
|
512
|
+
// contribution. Apply the signed odd integer quantization formula:
|
|
513
|
+
// int_dot = ((1 << qb) - 1) * d - 2 * xor_dot_product
|
|
514
|
+
int64_t int_dot = ((1 << rabitq_index->qb) - 1) * rabitq_index->d;
|
|
515
|
+
int_dot -= 2 * static_cast<int64_t>(normalized_distance);
|
|
516
|
+
|
|
517
|
+
adjusted_distance = query_factors_data.qr_to_c_L2sqr +
|
|
518
|
+
db_factors.or_minus_c_l2sqr -
|
|
519
|
+
2 * db_factors.dp_multiplier * int_dot *
|
|
520
|
+
query_factors_data.int_dot_scale;
|
|
521
|
+
} else {
|
|
522
|
+
// For non-centered quantization: use traditional formula
|
|
523
|
+
float final_dot = normalized_distance - query_factors_data.c34;
|
|
524
|
+
adjusted_distance = db_factors.or_minus_c_l2sqr +
|
|
525
|
+
query_factors_data.qr_to_c_L2sqr -
|
|
526
|
+
2 * db_factors.dp_multiplier * final_dot;
|
|
527
|
+
}
|
|
528
|
+
|
|
529
|
+
// Apply inner product correction if needed
|
|
530
|
+
if (query_factors_data.qr_norm_L2sqr != 0.0f) {
|
|
531
|
+
adjusted_distance = -0.5f *
|
|
532
|
+
(adjusted_distance - query_factors_data.qr_norm_L2sqr);
|
|
533
|
+
}
|
|
534
|
+
|
|
535
|
+
// Add to heap if better than current worst
|
|
536
|
+
if (Cfloat::cmp(heap_dis[0], adjusted_distance)) {
|
|
537
|
+
heap_replace_top<Cfloat>(
|
|
538
|
+
k, heap_dis, heap_ids, adjusted_distance, db_idx);
|
|
539
|
+
}
|
|
540
|
+
}
|
|
541
|
+
}
|
|
542
|
+
|
|
543
|
+
template <class C, bool with_id_map>
|
|
544
|
+
void RaBitQHeapHandler<C, with_id_map>::begin(const float* norms) {
|
|
545
|
+
normalizers = norms;
|
|
546
|
+
// Heap initialization is now done in constructor
|
|
547
|
+
}
|
|
548
|
+
|
|
549
|
+
template <class C, bool with_id_map>
|
|
550
|
+
void RaBitQHeapHandler<C, with_id_map>::end() {
|
|
551
|
+
// Reorder final results
|
|
552
|
+
#pragma omp parallel for if (nq > 100)
|
|
553
|
+
for (int64_t q = 0; q < static_cast<int64_t>(nq); q++) {
|
|
554
|
+
float* heap_dis = heap_distances + q * k;
|
|
555
|
+
int64_t* heap_ids = heap_labels + q * k;
|
|
556
|
+
heap_reorder<Cfloat>(k, heap_dis, heap_ids);
|
|
557
|
+
}
|
|
558
|
+
}
|
|
559
|
+
|
|
560
|
+
// Implementation of virtual make_knn_handler method
|
|
561
|
+
void* IndexRaBitQFastScan::make_knn_handler(
|
|
562
|
+
bool is_max,
|
|
563
|
+
int /*impl*/,
|
|
564
|
+
idx_t n,
|
|
565
|
+
idx_t k,
|
|
566
|
+
size_t /*ntotal*/,
|
|
567
|
+
float* distances,
|
|
568
|
+
idx_t* labels,
|
|
569
|
+
const IDSelector* sel,
|
|
570
|
+
const FastScanDistancePostProcessing& context) const {
|
|
571
|
+
if (is_max) {
|
|
572
|
+
return new RaBitQHeapHandler<CMax<uint16_t, int>, false>(
|
|
573
|
+
this, n, k, distances, labels, sel, context);
|
|
574
|
+
} else {
|
|
575
|
+
return new RaBitQHeapHandler<CMin<uint16_t, int>, false>(
|
|
576
|
+
this, n, k, distances, labels, sel, context);
|
|
577
|
+
}
|
|
578
|
+
}
|
|
579
|
+
|
|
580
|
+
// Explicit template instantiations for the required comparator types
|
|
581
|
+
template struct RaBitQHeapHandler<CMin<uint16_t, int>, false>;
|
|
582
|
+
template struct RaBitQHeapHandler<CMax<uint16_t, int>, false>;
|
|
583
|
+
template struct RaBitQHeapHandler<CMin<uint16_t, int>, true>;
|
|
584
|
+
template struct RaBitQHeapHandler<CMax<uint16_t, int>, true>;
|
|
585
|
+
|
|
586
|
+
} // namespace faiss
|