faiss 0.5.3 → 0.6.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +8 -0
- data/ext/faiss/ext.cpp +1 -1
- data/ext/faiss/extconf.rb +5 -6
- data/ext/faiss/index_binary.cpp +38 -28
- data/ext/faiss/{index.cpp → index_rb.cpp} +64 -46
- data/ext/faiss/kmeans.cpp +10 -9
- data/ext/faiss/pca_matrix.cpp +10 -8
- data/ext/faiss/product_quantizer.cpp +14 -12
- data/ext/faiss/{utils.cpp → utils_rb.cpp} +5 -3
- data/ext/faiss/{utils.h → utils_rb.h} +4 -0
- data/lib/faiss/version.rb +1 -1
- data/lib/faiss.rb +1 -1
- data/vendor/faiss/faiss/AutoTune.cpp +130 -11
- data/vendor/faiss/faiss/AutoTune.h +14 -1
- data/vendor/faiss/faiss/Clustering.cpp +59 -10
- data/vendor/faiss/faiss/Clustering.h +12 -0
- data/vendor/faiss/faiss/IVFlib.cpp +31 -28
- data/vendor/faiss/faiss/Index.cpp +20 -8
- data/vendor/faiss/faiss/Index.h +25 -3
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +19 -24
- data/vendor/faiss/faiss/IndexBinary.cpp +1 -0
- data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +9 -4
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +45 -11
- data/vendor/faiss/faiss/IndexFastScan.cpp +35 -22
- data/vendor/faiss/faiss/IndexFastScan.h +10 -1
- data/vendor/faiss/faiss/IndexFlat.cpp +193 -136
- data/vendor/faiss/faiss/IndexFlat.h +16 -1
- data/vendor/faiss/faiss/IndexFlatCodes.cpp +46 -22
- data/vendor/faiss/faiss/IndexFlatCodes.h +7 -1
- data/vendor/faiss/faiss/IndexHNSW.cpp +24 -50
- data/vendor/faiss/faiss/IndexHNSW.h +14 -12
- data/vendor/faiss/faiss/IndexIDMap.cpp +1 -1
- data/vendor/faiss/faiss/IndexIVF.cpp +76 -49
- data/vendor/faiss/faiss/IndexIVF.h +14 -4
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +11 -8
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +2 -2
- data/vendor/faiss/faiss/IndexIVFFastScan.cpp +25 -14
- data/vendor/faiss/faiss/IndexIVFFastScan.h +26 -22
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +10 -61
- data/vendor/faiss/faiss/IndexIVFFlatPanorama.cpp +39 -111
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +89 -147
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +37 -5
- data/vendor/faiss/faiss/IndexIVFPQR.cpp +2 -1
- data/vendor/faiss/faiss/IndexIVFRaBitQ.cpp +42 -30
- data/vendor/faiss/faiss/IndexIVFRaBitQ.h +2 -2
- data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.cpp +246 -97
- data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.h +32 -29
- data/vendor/faiss/faiss/IndexLSH.cpp +8 -6
- data/vendor/faiss/faiss/IndexLattice.cpp +29 -24
- data/vendor/faiss/faiss/IndexNNDescent.cpp +1 -0
- data/vendor/faiss/faiss/IndexNSG.cpp +2 -1
- data/vendor/faiss/faiss/IndexNSG.h +0 -2
- data/vendor/faiss/faiss/IndexNeuralNetCodec.cpp +1 -1
- data/vendor/faiss/faiss/IndexPQ.cpp +19 -10
- data/vendor/faiss/faiss/IndexRaBitQ.cpp +26 -13
- data/vendor/faiss/faiss/IndexRaBitQ.h +2 -2
- data/vendor/faiss/faiss/IndexRaBitQFastScan.cpp +132 -78
- data/vendor/faiss/faiss/IndexRaBitQFastScan.h +14 -12
- data/vendor/faiss/faiss/IndexRefine.cpp +0 -30
- data/vendor/faiss/faiss/IndexShards.cpp +3 -4
- data/vendor/faiss/faiss/MetricType.h +16 -0
- data/vendor/faiss/faiss/VectorTransform.cpp +120 -0
- data/vendor/faiss/faiss/VectorTransform.h +23 -0
- data/vendor/faiss/faiss/clone_index.cpp +7 -4
- data/vendor/faiss/faiss/{cppcontrib/factory_tools.cpp → factory_tools.cpp} +1 -1
- data/vendor/faiss/faiss/gpu/GpuCloner.cpp +1 -1
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +37 -11
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +0 -28
- data/vendor/faiss/faiss/impl/ClusteringInitialization.cpp +367 -0
- data/vendor/faiss/faiss/impl/ClusteringInitialization.h +107 -0
- data/vendor/faiss/faiss/impl/CodePacker.cpp +4 -0
- data/vendor/faiss/faiss/impl/CodePacker.h +11 -3
- data/vendor/faiss/faiss/impl/CodePackerRaBitQ.cpp +83 -0
- data/vendor/faiss/faiss/impl/CodePackerRaBitQ.h +47 -0
- data/vendor/faiss/faiss/impl/FaissAssert.h +60 -2
- data/vendor/faiss/faiss/impl/HNSW.cpp +25 -34
- data/vendor/faiss/faiss/impl/HNSW.h +8 -6
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +34 -27
- data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -1
- data/vendor/faiss/faiss/impl/NSG.cpp +6 -5
- data/vendor/faiss/faiss/impl/NSG.h +17 -7
- data/vendor/faiss/faiss/impl/Panorama.cpp +53 -46
- data/vendor/faiss/faiss/impl/Panorama.h +22 -6
- data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +16 -5
- data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +70 -58
- data/vendor/faiss/faiss/impl/RaBitQUtils.cpp +92 -0
- data/vendor/faiss/faiss/impl/RaBitQUtils.h +93 -31
- data/vendor/faiss/faiss/impl/RaBitQuantizer.cpp +12 -28
- data/vendor/faiss/faiss/impl/RaBitQuantizer.h +3 -10
- data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.cpp +15 -41
- data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.h +0 -4
- data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +14 -9
- data/vendor/faiss/faiss/impl/ResultHandler.h +131 -50
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +67 -2358
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +0 -2
- data/vendor/faiss/faiss/impl/VisitedTable.cpp +42 -0
- data/vendor/faiss/faiss/impl/VisitedTable.h +69 -0
- data/vendor/faiss/faiss/impl/expanded_scanners.h +158 -0
- data/vendor/faiss/faiss/impl/index_read.cpp +829 -471
- data/vendor/faiss/faiss/impl/index_read_utils.h +0 -1
- data/vendor/faiss/faiss/impl/index_write.cpp +17 -8
- data/vendor/faiss/faiss/impl/lattice_Zn.cpp +47 -20
- data/vendor/faiss/faiss/impl/mapped_io.cpp +9 -2
- data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +7 -2
- data/vendor/faiss/faiss/impl/pq4_fast_scan.h +11 -3
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +19 -13
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +29 -21
- data/vendor/faiss/faiss/impl/{code_distance/code_distance-avx2.h → pq_code_distance/pq_code_distance-avx2.cpp} +42 -215
- data/vendor/faiss/faiss/impl/{code_distance/code_distance-avx512.h → pq_code_distance/pq_code_distance-avx512.cpp} +68 -107
- data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-generic.cpp +141 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-inl.h +23 -0
- data/vendor/faiss/faiss/impl/{code_distance/code_distance-sve.h → pq_code_distance/pq_code_distance-sve.cpp} +57 -144
- data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +9 -6
- data/vendor/faiss/faiss/impl/scalar_quantizer/codecs.h +121 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/distance_computers.h +136 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/quantizers.h +280 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/scanners.h +164 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/similarities.h +94 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx2.cpp +455 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx512.cpp +430 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-dispatch.h +329 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-neon.cpp +467 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/training.cpp +203 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/training.h +42 -0
- data/vendor/faiss/faiss/impl/simd_dispatch.h +139 -0
- data/vendor/faiss/faiss/impl/simd_result_handlers.h +18 -18
- data/vendor/faiss/faiss/index_factory.cpp +35 -16
- data/vendor/faiss/faiss/index_io.h +29 -3
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +7 -4
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +1 -1
- data/vendor/faiss/faiss/svs/IndexSVSFaissUtils.h +9 -19
- data/vendor/faiss/faiss/svs/IndexSVSFlat.h +2 -0
- data/vendor/faiss/faiss/svs/IndexSVSVamana.h +2 -1
- data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.cpp +9 -1
- data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.h +9 -0
- data/vendor/faiss/faiss/utils/Heap.cpp +46 -0
- data/vendor/faiss/faiss/utils/Heap.h +21 -0
- data/vendor/faiss/faiss/utils/NeuralNet.cpp +10 -7
- data/vendor/faiss/faiss/utils/distances.cpp +141 -23
- data/vendor/faiss/faiss/utils/distances.h +98 -0
- data/vendor/faiss/faiss/utils/distances_dispatch.h +170 -0
- data/vendor/faiss/faiss/utils/distances_simd.cpp +74 -3511
- data/vendor/faiss/faiss/utils/extra_distances-inl.h +164 -157
- data/vendor/faiss/faiss/utils/extra_distances.cpp +52 -95
- data/vendor/faiss/faiss/utils/extra_distances.h +47 -1
- data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +0 -1
- data/vendor/faiss/faiss/utils/partitioning.cpp +1 -1
- data/vendor/faiss/faiss/utils/pq_code_distance.h +251 -0
- data/vendor/faiss/faiss/utils/rabitq_simd.h +260 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_aarch64.cpp +150 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_arm_sve.cpp +568 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_autovec-inl.h +153 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_avx2.cpp +1185 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_avx512.cpp +1092 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_sse-inl.h +391 -0
- data/vendor/faiss/faiss/utils/simd_levels.cpp +322 -0
- data/vendor/faiss/faiss/utils/simd_levels.h +91 -0
- data/vendor/faiss/faiss/utils/simdlib_avx2.h +12 -1
- data/vendor/faiss/faiss/utils/simdlib_avx512.h +69 -0
- data/vendor/faiss/faiss/utils/simdlib_neon.h +6 -0
- data/vendor/faiss/faiss/utils/sorting.cpp +4 -4
- data/vendor/faiss/faiss/utils/utils.cpp +16 -9
- metadata +47 -18
- data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +0 -81
- data/vendor/faiss/faiss/impl/code_distance/code_distance.h +0 -186
- /data/vendor/faiss/faiss/{cppcontrib/factory_tools.h → factory_tools.h} +0 -0
|
@@ -9,6 +9,7 @@
|
|
|
9
9
|
|
|
10
10
|
#include <faiss/IndexLattice.h>
|
|
11
11
|
#include <faiss/impl/FaissAssert.h>
|
|
12
|
+
#include <faiss/impl/simd_dispatch.h>
|
|
12
13
|
#include <faiss/utils/distances.h>
|
|
13
14
|
#include <faiss/utils/hamming.h> // for the bitstring routines
|
|
14
15
|
|
|
@@ -44,17 +45,19 @@ void IndexLattice::train(idx_t n, const float* x) {
|
|
|
44
45
|
maxs[sq] = -1;
|
|
45
46
|
}
|
|
46
47
|
|
|
47
|
-
|
|
48
|
-
for (
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
maxs[sq]
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
mins[sq]
|
|
48
|
+
with_simd_level([&]<SIMDLevel SL>() {
|
|
49
|
+
for (idx_t i = 0; i < n; i++) {
|
|
50
|
+
for (int sq = 0; sq < nsq; sq++) {
|
|
51
|
+
float norm2 = fvec_norm_L2sqr<SL>(x + i * d + sq * dsq, dsq);
|
|
52
|
+
if (norm2 > maxs[sq]) {
|
|
53
|
+
maxs[sq] = norm2;
|
|
54
|
+
}
|
|
55
|
+
if (norm2 < mins[sq]) {
|
|
56
|
+
mins[sq] = norm2;
|
|
57
|
+
}
|
|
55
58
|
}
|
|
56
59
|
}
|
|
57
|
-
}
|
|
60
|
+
});
|
|
58
61
|
|
|
59
62
|
for (int sq = 0; sq < nsq; sq++) {
|
|
60
63
|
mins[sq] = sqrtf(mins[sq]);
|
|
@@ -74,24 +77,26 @@ void IndexLattice::sa_encode(idx_t n, const float* x, uint8_t* codes) const {
|
|
|
74
77
|
const float* maxs = mins + nsq;
|
|
75
78
|
int64_t sc = int64_t(1) << scale_nbit;
|
|
76
79
|
|
|
80
|
+
with_simd_level([&]<SIMDLevel SL>() {
|
|
77
81
|
#pragma omp parallel for
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
82
|
+
for (idx_t i = 0; i < n; i++) {
|
|
83
|
+
BitstringWriter wr(codes + i * code_size, code_size);
|
|
84
|
+
const float* xi = x + i * d;
|
|
85
|
+
for (int j = 0; j < nsq; j++) {
|
|
86
|
+
float nj = (sqrtf(fvec_norm_L2sqr<SL>(xi, dsq)) - mins[j]) *
|
|
87
|
+
sc / (maxs[j] - mins[j]);
|
|
88
|
+
if (nj < 0) {
|
|
89
|
+
nj = 0;
|
|
90
|
+
}
|
|
91
|
+
if (nj >= sc) {
|
|
92
|
+
nj = sc - 1;
|
|
93
|
+
}
|
|
94
|
+
wr.write((int64_t)nj, scale_nbit);
|
|
95
|
+
wr.write(zn_sphere_codec.encode(xi), lattice_nbit);
|
|
96
|
+
xi += dsq;
|
|
89
97
|
}
|
|
90
|
-
wr.write((int64_t)nj, scale_nbit);
|
|
91
|
-
wr.write(zn_sphere_codec.encode(xi), lattice_nbit);
|
|
92
|
-
xi += dsq;
|
|
93
98
|
}
|
|
94
|
-
}
|
|
99
|
+
});
|
|
95
100
|
}
|
|
96
101
|
|
|
97
102
|
void IndexLattice::sa_decode(idx_t n, const uint8_t* codes, float* x) const {
|
|
@@ -16,6 +16,7 @@
|
|
|
16
16
|
#include <faiss/IndexNNDescent.h>
|
|
17
17
|
#include <faiss/impl/AuxIndexStructures.h>
|
|
18
18
|
#include <faiss/impl/FaissAssert.h>
|
|
19
|
+
#include <faiss/impl/VisitedTable.h>
|
|
19
20
|
#include <faiss/utils/distances.h>
|
|
20
21
|
|
|
21
22
|
namespace faiss {
|
|
@@ -74,7 +75,7 @@ void IndexNSG::search(
|
|
|
74
75
|
|
|
75
76
|
#pragma omp parallel
|
|
76
77
|
{
|
|
77
|
-
VisitedTable vt(ntotal);
|
|
78
|
+
VisitedTable vt(ntotal, nsg.use_visited_hashset);
|
|
78
79
|
|
|
79
80
|
std::unique_ptr<DistanceComputer> dis(
|
|
80
81
|
storage_distance_computer(storage));
|
|
@@ -24,7 +24,7 @@ IndexNeuralNetCodec::IndexNeuralNetCodec(
|
|
|
24
24
|
is_trained = false;
|
|
25
25
|
}
|
|
26
26
|
|
|
27
|
-
void IndexNeuralNetCodec::train(idx_t n
|
|
27
|
+
void IndexNeuralNetCodec::train(idx_t /*n*/, const float* /*x*/) {
|
|
28
28
|
FAISS_THROW_MSG("Training not implemented in C++, use Pytorch");
|
|
29
29
|
}
|
|
30
30
|
|
|
@@ -19,7 +19,8 @@
|
|
|
19
19
|
#include <faiss/impl/FaissAssert.h>
|
|
20
20
|
#include <faiss/utils/hamming.h>
|
|
21
21
|
|
|
22
|
-
#include <faiss/impl/
|
|
22
|
+
#include <faiss/impl/pq_code_distance/pq_code_distance-inl.h>
|
|
23
|
+
#include <faiss/impl/simd_dispatch.h>
|
|
23
24
|
|
|
24
25
|
namespace faiss {
|
|
25
26
|
|
|
@@ -72,8 +73,9 @@ void IndexPQ::train(idx_t n, const float* x) {
|
|
|
72
73
|
|
|
73
74
|
namespace {
|
|
74
75
|
|
|
75
|
-
template <class
|
|
76
|
+
template <class PQCodeDist>
|
|
76
77
|
struct PQDistanceComputer : FlatCodesDistanceComputer {
|
|
78
|
+
using PQDecoder = typename PQCodeDist::PQDecoder;
|
|
77
79
|
size_t d;
|
|
78
80
|
MetricType metric;
|
|
79
81
|
idx_t nb;
|
|
@@ -86,7 +88,7 @@ struct PQDistanceComputer : FlatCodesDistanceComputer {
|
|
|
86
88
|
float distance_to_code(const uint8_t* code) final {
|
|
87
89
|
ndis++;
|
|
88
90
|
|
|
89
|
-
float dis = distance_single_code
|
|
91
|
+
float dis = PQCodeDist::distance_single_code(
|
|
90
92
|
pq.M, pq.nbits, precomputed_table.data(), code);
|
|
91
93
|
return dis;
|
|
92
94
|
}
|
|
@@ -134,16 +136,23 @@ struct PQDistanceComputer : FlatCodesDistanceComputer {
|
|
|
134
136
|
}
|
|
135
137
|
};
|
|
136
138
|
|
|
139
|
+
template <SIMDLevel SL>
|
|
140
|
+
FlatCodesDistanceComputer* get_FlatCodesDistanceComputer1(
|
|
141
|
+
const IndexPQ& index) {
|
|
142
|
+
if (index.pq.nbits == 8) {
|
|
143
|
+
return new PQDistanceComputer<PQCodeDistance<PQDecoder8, SL>>(index);
|
|
144
|
+
} else if (index.pq.nbits == 16) {
|
|
145
|
+
return new PQDistanceComputer<PQCodeDistance<PQDecoder16, SL>>(index);
|
|
146
|
+
} else {
|
|
147
|
+
return new PQDistanceComputer<PQCodeDistance<PQDecoderGeneric, SL>>(
|
|
148
|
+
index);
|
|
149
|
+
}
|
|
150
|
+
}
|
|
151
|
+
|
|
137
152
|
} // namespace
|
|
138
153
|
|
|
139
154
|
FlatCodesDistanceComputer* IndexPQ::get_FlatCodesDistanceComputer() const {
|
|
140
|
-
|
|
141
|
-
return new PQDistanceComputer<PQDecoder8>(*this);
|
|
142
|
-
} else if (pq.nbits == 16) {
|
|
143
|
-
return new PQDistanceComputer<PQDecoder16>(*this);
|
|
144
|
-
} else {
|
|
145
|
-
return new PQDistanceComputer<PQDecoderGeneric>(*this);
|
|
146
|
-
}
|
|
155
|
+
DISPATCH_SIMDLevel(get_FlatCodesDistanceComputer1, *this);
|
|
147
156
|
}
|
|
148
157
|
|
|
149
158
|
/*****************************************
|
|
@@ -8,6 +8,7 @@
|
|
|
8
8
|
#include <faiss/IndexRaBitQ.h>
|
|
9
9
|
|
|
10
10
|
#include <faiss/impl/FaissAssert.h>
|
|
11
|
+
#include <faiss/impl/RaBitQUtils.h>
|
|
11
12
|
#include <faiss/impl/ResultHandler.h>
|
|
12
13
|
#include <memory>
|
|
13
14
|
|
|
@@ -16,6 +17,8 @@ namespace faiss {
|
|
|
16
17
|
// Forward declaration from RaBitQuantizer.cpp
|
|
17
18
|
struct RaBitQDistanceComputer;
|
|
18
19
|
|
|
20
|
+
using rabitq_utils::SignBitFactorsWithError;
|
|
21
|
+
|
|
19
22
|
IndexRaBitQ::IndexRaBitQ() = default;
|
|
20
23
|
|
|
21
24
|
IndexRaBitQ::IndexRaBitQ(idx_t d, MetricType metric, uint8_t nb_bits_in)
|
|
@@ -141,19 +144,29 @@ struct Run_search_with_dc_res {
|
|
|
141
144
|
|
|
142
145
|
local_1bit_evaluations++;
|
|
143
146
|
|
|
144
|
-
// Stage 1: Compute 1-bit
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
//
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
//
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
147
|
+
// Stage 1: Compute distance bound using 1-bit codes
|
|
148
|
+
// For L2 (min-heap): use lower_bound (est -
|
|
149
|
+
// error) For IP (max-heap): use upper_bound (est
|
|
150
|
+
// + error)
|
|
151
|
+
float est_distance =
|
|
152
|
+
dc->distance_to_code_1bit(code);
|
|
153
|
+
|
|
154
|
+
// Extract f_error for filtering
|
|
155
|
+
size_t code_size_base = (index->d + 7) / 8;
|
|
156
|
+
const rabitq_utils::SignBitFactorsWithError*
|
|
157
|
+
base_fac = reinterpret_cast<
|
|
158
|
+
const rabitq_utils::
|
|
159
|
+
SignBitFactorsWithError*>(
|
|
160
|
+
code + code_size_base);
|
|
161
|
+
|
|
162
|
+
// Stage 2: Adaptive filtering
|
|
163
|
+
bool should_refine =
|
|
164
|
+
rabitq_utils::should_refine_candidate(
|
|
165
|
+
est_distance,
|
|
166
|
+
base_fac->f_error,
|
|
167
|
+
dc->g_error,
|
|
168
|
+
resi.threshold,
|
|
169
|
+
is_similarity);
|
|
157
170
|
if (should_refine) {
|
|
158
171
|
local_multibit_evaluations++;
|
|
159
172
|
// Compute full multi-bit distance
|
|
@@ -14,7 +14,7 @@
|
|
|
14
14
|
namespace faiss {
|
|
15
15
|
|
|
16
16
|
struct RaBitQSearchParameters : SearchParameters {
|
|
17
|
-
uint8_t qb =
|
|
17
|
+
uint8_t qb = 4;
|
|
18
18
|
bool centered = false;
|
|
19
19
|
};
|
|
20
20
|
|
|
@@ -26,7 +26,7 @@ struct IndexRaBitQ : IndexFlatCodes {
|
|
|
26
26
|
|
|
27
27
|
// the default number of bits to quantize a query with.
|
|
28
28
|
// use '0' to disable quantization and use raw fp32 values.
|
|
29
|
-
uint8_t qb =
|
|
29
|
+
uint8_t qb = 4;
|
|
30
30
|
|
|
31
31
|
// quantize the query with a zero-centered scalar quantizer.
|
|
32
32
|
bool centered = false;
|
|
@@ -6,6 +6,7 @@
|
|
|
6
6
|
*/
|
|
7
7
|
|
|
8
8
|
#include <faiss/IndexRaBitQFastScan.h>
|
|
9
|
+
#include <faiss/impl/CodePackerRaBitQ.h>
|
|
9
10
|
#include <faiss/impl/FastScanDistancePostProcessing.h>
|
|
10
11
|
#include <faiss/impl/RaBitQUtils.h>
|
|
11
12
|
#include <faiss/impl/RaBitQuantizerMultiBit.h>
|
|
@@ -21,17 +22,7 @@ static inline size_t roundup(size_t a, size_t b) {
|
|
|
21
22
|
}
|
|
22
23
|
|
|
23
24
|
size_t IndexRaBitQFastScan::compute_per_vector_storage_size() const {
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
if (ex_bits == 0) {
|
|
27
|
-
// 1-bit: only SignBitFactors
|
|
28
|
-
return sizeof(rabitq_utils::SignBitFactors);
|
|
29
|
-
} else {
|
|
30
|
-
// Multi-bit: SignBitFactorsWithError + ExtraBitsFactors +
|
|
31
|
-
// mag-codes
|
|
32
|
-
return sizeof(SignBitFactorsWithError) + sizeof(ExtraBitsFactors) +
|
|
33
|
-
(d * ex_bits + 7) / 8;
|
|
34
|
-
}
|
|
25
|
+
return rabitq_utils::compute_per_vector_storage_size(rabitq.nb_bits, d);
|
|
35
26
|
}
|
|
36
27
|
|
|
37
28
|
IndexRaBitQFastScan::IndexRaBitQFastScan() = default;
|
|
@@ -64,9 +55,51 @@ IndexRaBitQFastScan::IndexRaBitQFastScan(
|
|
|
64
55
|
// Set RaBitQ-specific parameters
|
|
65
56
|
qb = 8;
|
|
66
57
|
center.resize(d, 0.0f);
|
|
58
|
+
}
|
|
67
59
|
|
|
68
|
-
|
|
69
|
-
|
|
60
|
+
CodePacker* IndexRaBitQFastScan::get_CodePacker() const {
|
|
61
|
+
return new CodePackerRaBitQ(M2, bbs, compute_per_vector_storage_size());
|
|
62
|
+
}
|
|
63
|
+
|
|
64
|
+
size_t IndexRaBitQFastScan::remove_ids(const IDSelector& sel) {
|
|
65
|
+
const size_t block_stride = get_block_stride();
|
|
66
|
+
|
|
67
|
+
idx_t j = 0;
|
|
68
|
+
std::vector<uint8_t> buffer(code_size);
|
|
69
|
+
std::unique_ptr<CodePacker> packer(get_CodePacker());
|
|
70
|
+
for (idx_t i = 0; i < ntotal; i++) {
|
|
71
|
+
if (sel.is_member(i)) {
|
|
72
|
+
} else {
|
|
73
|
+
if (i > j) {
|
|
74
|
+
packer->unpack_1(codes.data(), i, buffer.data());
|
|
75
|
+
packer->pack_1(buffer.data(), j, codes.data());
|
|
76
|
+
}
|
|
77
|
+
j++;
|
|
78
|
+
}
|
|
79
|
+
}
|
|
80
|
+
size_t nremove = ntotal - j;
|
|
81
|
+
if (nremove > 0) {
|
|
82
|
+
ntotal = j;
|
|
83
|
+
ntotal2 = roundup(ntotal, bbs);
|
|
84
|
+
size_t new_size = ntotal2 / bbs * block_stride;
|
|
85
|
+
|
|
86
|
+
// Zero out stale data in the last block beyond the retained vectors.
|
|
87
|
+
// This is necessary because pq4_pack_codes_range uses |= to write
|
|
88
|
+
// new codes, so any stale non-zero nibbles would corrupt future adds.
|
|
89
|
+
// pack_1 with a zero buffer zeroes both PQ4 codes and aux data.
|
|
90
|
+
const size_t last_pos = ntotal % bbs;
|
|
91
|
+
if (last_pos > 0) {
|
|
92
|
+
const size_t last_block = ntotal / bbs;
|
|
93
|
+
std::vector<uint8_t> zero_code(code_size, 0);
|
|
94
|
+
for (size_t pos = last_pos; pos < bbs; pos++) {
|
|
95
|
+
packer->pack_1(
|
|
96
|
+
zero_code.data(), last_block * bbs + pos, codes.data());
|
|
97
|
+
}
|
|
98
|
+
}
|
|
99
|
+
|
|
100
|
+
codes.resize(new_size);
|
|
101
|
+
}
|
|
102
|
+
return nremove;
|
|
70
103
|
}
|
|
71
104
|
|
|
72
105
|
IndexRaBitQFastScan::IndexRaBitQFastScan(const IndexRaBitQ& orig, int bbs)
|
|
@@ -104,58 +137,59 @@ IndexRaBitQFastScan::IndexRaBitQFastScan(const IndexRaBitQ& orig, int bbs)
|
|
|
104
137
|
|
|
105
138
|
// If the original index has data, extract factors and pack codes
|
|
106
139
|
if (ntotal > 0) {
|
|
107
|
-
// Compute per-vector storage size for flat storage
|
|
108
140
|
const size_t storage_size = compute_per_vector_storage_size();
|
|
109
|
-
|
|
110
|
-
// Allocate flat storage
|
|
111
|
-
flat_storage.resize(ntotal * storage_size);
|
|
112
|
-
|
|
113
|
-
// Copy factors directly from original codes
|
|
114
141
|
const size_t bit_pattern_size = (d + 7) / 8;
|
|
115
|
-
for (idx_t i = 0; i < ntotal; i++) {
|
|
116
|
-
const uint8_t* orig_code = orig.codes.data() + i * orig.code_size;
|
|
117
|
-
const uint8_t* source_factors_ptr = orig_code + bit_pattern_size;
|
|
118
|
-
uint8_t* storage = flat_storage.data() + i * storage_size;
|
|
119
|
-
memcpy(storage, source_factors_ptr, storage_size);
|
|
120
|
-
}
|
|
121
142
|
|
|
122
143
|
// Convert RaBitQ bit format to FastScan 4-bit sub-quantizer format
|
|
123
|
-
// This follows the same pattern as IndexPQFastScan constructor
|
|
124
144
|
AlignedTable<uint8_t> fastscan_codes(ntotal * code_size);
|
|
125
145
|
memset(fastscan_codes.get(), 0, ntotal * code_size);
|
|
126
146
|
|
|
127
|
-
// Convert from RaBitQ 1-bit-per-dimension to FastScan
|
|
128
|
-
// 4-bit-per-sub-quantizer
|
|
129
147
|
for (idx_t i = 0; i < ntotal; i++) {
|
|
130
148
|
const uint8_t* orig_code = orig.codes.data() + i * orig.code_size;
|
|
131
149
|
uint8_t* fs_code = fastscan_codes.get() + i * code_size;
|
|
132
150
|
|
|
133
|
-
// Convert each dimension's bit (same logic as compute_codes)
|
|
134
151
|
for (size_t j = 0; j < orig.d; j++) {
|
|
135
|
-
// Extract bit from original RaBitQ format
|
|
136
152
|
const size_t orig_byte_idx = j / 8;
|
|
137
153
|
const size_t orig_bit_offset = j % 8;
|
|
138
154
|
const bool bit_value =
|
|
139
155
|
(orig_code[orig_byte_idx] >> orig_bit_offset) & 1;
|
|
140
156
|
|
|
141
|
-
// Use RaBitQUtils for consistent bit setting
|
|
142
157
|
if (bit_value) {
|
|
143
158
|
rabitq_utils::set_bit_fastscan(fs_code, j);
|
|
144
159
|
}
|
|
145
160
|
}
|
|
146
161
|
}
|
|
147
162
|
|
|
148
|
-
// Pack the converted codes using
|
|
149
|
-
|
|
150
|
-
|
|
163
|
+
// Pack the converted codes using enlarged block layout
|
|
164
|
+
const size_t block_stride = get_block_stride();
|
|
165
|
+
const size_t n_blocks = ntotal2 / bbs;
|
|
166
|
+
codes.resize(n_blocks * block_stride);
|
|
167
|
+
memset(codes.get(), 0, n_blocks * block_stride);
|
|
168
|
+
pq4_pack_codes_range(
|
|
151
169
|
fastscan_codes.get(),
|
|
152
|
-
ntotal,
|
|
153
170
|
M,
|
|
154
|
-
|
|
171
|
+
0,
|
|
172
|
+
ntotal,
|
|
155
173
|
bbs,
|
|
156
174
|
M2,
|
|
157
175
|
codes.get(),
|
|
158
|
-
code_size
|
|
176
|
+
code_size,
|
|
177
|
+
block_stride);
|
|
178
|
+
|
|
179
|
+
// Copy auxiliary data from original codes into block aux region
|
|
180
|
+
const size_t packed_block_size = ((M2 + 1) / 2) * bbs;
|
|
181
|
+
for (idx_t i = 0; i < ntotal; i++) {
|
|
182
|
+
const uint8_t* src =
|
|
183
|
+
orig.codes.data() + i * orig.code_size + bit_pattern_size;
|
|
184
|
+
uint8_t* dst = rabitq_utils::get_block_aux_ptr(
|
|
185
|
+
codes.get(),
|
|
186
|
+
i,
|
|
187
|
+
bbs,
|
|
188
|
+
packed_block_size,
|
|
189
|
+
block_stride,
|
|
190
|
+
storage_size);
|
|
191
|
+
memcpy(dst, src, storage_size);
|
|
192
|
+
}
|
|
159
193
|
}
|
|
160
194
|
}
|
|
161
195
|
|
|
@@ -204,23 +238,13 @@ void IndexRaBitQFastScan::add(idx_t n, const float* x) {
|
|
|
204
238
|
compute_codes(tmp_codes.get(), n, x);
|
|
205
239
|
|
|
206
240
|
const size_t storage_size = compute_per_vector_storage_size();
|
|
207
|
-
flat_storage.resize((ntotal + n) * storage_size);
|
|
208
|
-
|
|
209
|
-
// Populate flat storage (no sign bits copying needed!)
|
|
210
241
|
const size_t bit_pattern_size = (d + 7) / 8;
|
|
211
|
-
for (idx_t i = 0; i < n; i++) {
|
|
212
|
-
const uint8_t* code = tmp_codes.get() + i * code_size;
|
|
213
|
-
const idx_t vec_idx = ntotal + i;
|
|
214
|
-
|
|
215
|
-
// Copy factors data directly to flat storage (no reordering needed)
|
|
216
|
-
const uint8_t* source_factors_ptr = code + bit_pattern_size;
|
|
217
|
-
uint8_t* storage = flat_storage.data() + vec_idx * storage_size;
|
|
218
|
-
memcpy(storage, source_factors_ptr, storage_size);
|
|
219
|
-
}
|
|
220
242
|
|
|
221
|
-
// Resize main storage
|
|
243
|
+
// Resize main storage with enlarged block layout
|
|
222
244
|
ntotal2 = roundup(ntotal + n, bbs);
|
|
223
|
-
size_t
|
|
245
|
+
const size_t block_stride = get_block_stride();
|
|
246
|
+
const size_t n_blocks = ntotal2 / bbs;
|
|
247
|
+
size_t new_size = n_blocks * block_stride;
|
|
224
248
|
size_t old_size = codes.size();
|
|
225
249
|
if (new_size > old_size) {
|
|
226
250
|
codes.resize(new_size);
|
|
@@ -230,13 +254,27 @@ void IndexRaBitQFastScan::add(idx_t n, const float* x) {
|
|
|
230
254
|
// Use our custom packing function with correct stride
|
|
231
255
|
pq4_pack_codes_range(
|
|
232
256
|
tmp_codes.get(),
|
|
233
|
-
M,
|
|
257
|
+
M,
|
|
234
258
|
ntotal,
|
|
235
|
-
ntotal + n,
|
|
259
|
+
ntotal + n,
|
|
236
260
|
bbs,
|
|
237
|
-
M2,
|
|
238
|
-
codes.get(),
|
|
239
|
-
code_size
|
|
261
|
+
M2,
|
|
262
|
+
codes.get(),
|
|
263
|
+
code_size,
|
|
264
|
+
block_stride);
|
|
265
|
+
|
|
266
|
+
const size_t packed_block_size = ((M2 + 1) / 2) * bbs;
|
|
267
|
+
for (idx_t i = 0; i < n; i++) {
|
|
268
|
+
const uint8_t* src = tmp_codes.get() + i * code_size + bit_pattern_size;
|
|
269
|
+
uint8_t* dst = rabitq_utils::get_block_aux_ptr(
|
|
270
|
+
codes.get(),
|
|
271
|
+
ntotal + i,
|
|
272
|
+
bbs,
|
|
273
|
+
packed_block_size,
|
|
274
|
+
block_stride,
|
|
275
|
+
storage_size);
|
|
276
|
+
memcpy(dst, src, storage_size);
|
|
277
|
+
}
|
|
240
278
|
|
|
241
279
|
ntotal += n;
|
|
242
280
|
}
|
|
@@ -502,7 +540,11 @@ RaBitQHeapHandler<C, with_id_map>::RaBitQHeapHandler(
|
|
|
502
540
|
nq(nq_val),
|
|
503
541
|
k(k_val),
|
|
504
542
|
context(ctx),
|
|
505
|
-
is_multi_bit(multi_bit)
|
|
543
|
+
is_multi_bit(multi_bit),
|
|
544
|
+
storage_size(index->compute_per_vector_storage_size()),
|
|
545
|
+
packed_block_size(((index->M2 + 1) / 2) * index->bbs),
|
|
546
|
+
full_block_size(index->get_block_stride()),
|
|
547
|
+
packer(index->get_CodePacker()) {
|
|
506
548
|
// Initialize heaps for all queries in constructor
|
|
507
549
|
// This allows us to support direct normalizer assignment
|
|
508
550
|
#pragma omp parallel for if (nq > 100)
|
|
@@ -543,8 +585,11 @@ void RaBitQHeapHandler<C, with_id_map>::handle(
|
|
|
543
585
|
? std::min<size_t>(32, rabitq_index->ntotal - base_db_idx)
|
|
544
586
|
: 0;
|
|
545
587
|
|
|
546
|
-
//
|
|
547
|
-
|
|
588
|
+
// Compute block auxiliary region base pointer once per batch.
|
|
589
|
+
// Since bbs=32, each batch of 32 vectors aligns to one block.
|
|
590
|
+
const size_t block_idx = base_db_idx / rabitq_index->bbs;
|
|
591
|
+
const uint8_t* aux_base = rabitq_index->codes.get() +
|
|
592
|
+
block_idx * full_block_size + packed_block_size;
|
|
548
593
|
|
|
549
594
|
// Stats tracking for multi-bit two-stage search only
|
|
550
595
|
// n_1bit_evaluations: candidates evaluated using 1-bit lower bound
|
|
@@ -559,9 +604,8 @@ void RaBitQHeapHandler<C, with_id_map>::handle(
|
|
|
559
604
|
// Normalize distance from LUT lookup
|
|
560
605
|
const float normalized_distance = d32tab[i] * one_a + bias;
|
|
561
606
|
|
|
562
|
-
// Access factors from
|
|
563
|
-
const uint8_t* base_ptr =
|
|
564
|
-
rabitq_index->flat_storage.data() + db_idx * storage_size;
|
|
607
|
+
// Access factors from block auxiliary region
|
|
608
|
+
const uint8_t* base_ptr = aux_base + i * storage_size;
|
|
565
609
|
|
|
566
610
|
if (is_multi_bit) {
|
|
567
611
|
// Track candidates actually considered for two-stage filtering
|
|
@@ -578,14 +622,16 @@ void RaBitQHeapHandler<C, with_id_map>::handle(
|
|
|
578
622
|
rabitq_index->qb,
|
|
579
623
|
rabitq_index->d);
|
|
580
624
|
|
|
581
|
-
float lower_bound = compute_lower_bound(dist_1bit, db_idx, q);
|
|
582
|
-
|
|
583
625
|
// Adaptive filtering: decide whether to compute full distance
|
|
584
626
|
const bool is_similarity = rabitq_index->metric_type ==
|
|
585
627
|
MetricType::METRIC_INNER_PRODUCT;
|
|
586
|
-
bool should_refine =
|
|
587
|
-
|
|
588
|
-
|
|
628
|
+
bool should_refine = rabitq_utils::should_refine_candidate(
|
|
629
|
+
dist_1bit,
|
|
630
|
+
full_factors.f_error,
|
|
631
|
+
context.query_factors ? context.query_factors[q].g_error
|
|
632
|
+
: 0.0f,
|
|
633
|
+
heap_dis[0],
|
|
634
|
+
is_similarity);
|
|
589
635
|
|
|
590
636
|
if (should_refine) {
|
|
591
637
|
local_multibit_evaluations++;
|
|
@@ -647,10 +693,14 @@ float RaBitQHeapHandler<C, with_id_map>::compute_lower_bound(
|
|
|
647
693
|
float dist_1bit,
|
|
648
694
|
size_t db_idx,
|
|
649
695
|
size_t q) const {
|
|
650
|
-
// Access f_error
|
|
651
|
-
const
|
|
652
|
-
|
|
653
|
-
|
|
696
|
+
// Access f_error from block auxiliary region
|
|
697
|
+
const uint8_t* base_ptr = rabitq_utils::get_block_aux_ptr(
|
|
698
|
+
rabitq_index->codes.get(),
|
|
699
|
+
db_idx,
|
|
700
|
+
rabitq_index->bbs,
|
|
701
|
+
packed_block_size,
|
|
702
|
+
full_block_size,
|
|
703
|
+
storage_size);
|
|
654
704
|
const SignBitFactorsWithError& db_factors =
|
|
655
705
|
*reinterpret_cast<const SignBitFactorsWithError*>(base_ptr);
|
|
656
706
|
float f_error = db_factors.f_error;
|
|
@@ -674,9 +724,13 @@ float RaBitQHeapHandler<C, with_id_map>::compute_full_multibit_distance(
|
|
|
674
724
|
const size_t ex_bits = rabitq_index->rabitq.nb_bits - 1;
|
|
675
725
|
const size_t dim = rabitq_index->d;
|
|
676
726
|
|
|
677
|
-
const
|
|
678
|
-
|
|
679
|
-
|
|
727
|
+
const uint8_t* base_ptr = rabitq_utils::get_block_aux_ptr(
|
|
728
|
+
rabitq_index->codes.get(),
|
|
729
|
+
db_idx,
|
|
730
|
+
rabitq_index->bbs,
|
|
731
|
+
packed_block_size,
|
|
732
|
+
full_block_size,
|
|
733
|
+
storage_size);
|
|
680
734
|
|
|
681
735
|
const size_t ex_code_size = (dim * ex_bits + 7) / 8;
|
|
682
736
|
const uint8_t* ex_code = base_ptr + sizeof(SignBitFactorsWithError);
|
|
@@ -689,8 +743,7 @@ float RaBitQHeapHandler<C, with_id_map>::compute_full_multibit_distance(
|
|
|
689
743
|
|
|
690
744
|
// Get sign bits from FastScan packed format
|
|
691
745
|
std::vector<uint8_t> unpacked_code(rabitq_index->code_size);
|
|
692
|
-
|
|
693
|
-
packer.unpack_1(rabitq_index->codes.get(), db_idx, unpacked_code.data());
|
|
746
|
+
packer->unpack_1(rabitq_index->codes.get(), db_idx, unpacked_code.data());
|
|
694
747
|
const uint8_t* sign_bits = unpacked_code.data();
|
|
695
748
|
|
|
696
749
|
return rabitq_utils::compute_full_multibit_distance(
|
|
@@ -698,8 +751,9 @@ float RaBitQHeapHandler<C, with_id_map>::compute_full_multibit_distance(
|
|
|
698
751
|
ex_code,
|
|
699
752
|
ex_fac,
|
|
700
753
|
query_factors.rotated_q.data(),
|
|
701
|
-
|
|
702
|
-
|
|
754
|
+
(rabitq_index->metric_type == MetricType::METRIC_INNER_PRODUCT)
|
|
755
|
+
? query_factors.q_dot_c
|
|
756
|
+
: query_factors.qr_to_c_L2sqr,
|
|
703
757
|
dim,
|
|
704
758
|
ex_bits,
|
|
705
759
|
rabitq_index->metric_type);
|
|
@@ -7,6 +7,7 @@
|
|
|
7
7
|
|
|
8
8
|
#pragma once
|
|
9
9
|
|
|
10
|
+
#include <memory>
|
|
10
11
|
#include <vector>
|
|
11
12
|
|
|
12
13
|
#include <faiss/IndexFastScan.h>
|
|
@@ -43,17 +44,6 @@ struct IndexRaBitQFastScan : IndexFastScan {
|
|
|
43
44
|
/// Center of all points (same as IndexRaBitQ)
|
|
44
45
|
std::vector<float> center;
|
|
45
46
|
|
|
46
|
-
/// Per-vector auxiliary data (1-bit codes stored separately in `codes`)
|
|
47
|
-
///
|
|
48
|
-
/// 1-bit codes (sign bits) are stored in the inherited `codes` array from
|
|
49
|
-
/// IndexFastScan in packed FastScan format for SIMD processing.
|
|
50
|
-
///
|
|
51
|
-
/// This flat_storage holds per-vector factors and refinement-bit codes:
|
|
52
|
-
/// Layout for 1-bit: [SignBitFactors (8 bytes)]
|
|
53
|
-
/// Layout for multi-bit: [SignBitFactorsWithError
|
|
54
|
-
/// (12B)][ref_codes][ExtraBitsFactors (8B)]
|
|
55
|
-
std::vector<uint8_t> flat_storage;
|
|
56
|
-
|
|
57
47
|
/// Default number of bits to quantize a query with
|
|
58
48
|
uint8_t qb = 8;
|
|
59
49
|
|
|
@@ -77,7 +67,7 @@ struct IndexRaBitQFastScan : IndexFastScan {
|
|
|
77
67
|
|
|
78
68
|
void compute_codes(uint8_t* codes, idx_t n, const float* x) const override;
|
|
79
69
|
|
|
80
|
-
/// Compute
|
|
70
|
+
/// Compute per-vector auxiliary data size in block aux region
|
|
81
71
|
size_t compute_per_vector_storage_size() const;
|
|
82
72
|
|
|
83
73
|
void compute_float_LUT(
|
|
@@ -88,6 +78,12 @@ struct IndexRaBitQFastScan : IndexFastScan {
|
|
|
88
78
|
|
|
89
79
|
void sa_decode(idx_t n, const uint8_t* bytes, float* x) const override;
|
|
90
80
|
|
|
81
|
+
/// Return CodePackerRaBitQ with enlarged block size
|
|
82
|
+
CodePacker* get_CodePacker() const override;
|
|
83
|
+
|
|
84
|
+
/// Remove vectors and compact both PQ4 codes and auxiliary data
|
|
85
|
+
size_t remove_ids(const IDSelector& sel) override;
|
|
86
|
+
|
|
91
87
|
void search(
|
|
92
88
|
idx_t n,
|
|
93
89
|
const float* x,
|
|
@@ -141,6 +137,12 @@ struct RaBitQHeapHandler
|
|
|
141
137
|
context; // Processing context with query offset
|
|
142
138
|
const bool is_multi_bit; // Runtime flag for multi-bit mode
|
|
143
139
|
|
|
140
|
+
// Cached block-layout constants (invariant for handler lifetime)
|
|
141
|
+
const size_t storage_size;
|
|
142
|
+
const size_t packed_block_size;
|
|
143
|
+
const size_t full_block_size;
|
|
144
|
+
std::unique_ptr<CodePacker> packer; // cached for unpack in hot path
|
|
145
|
+
|
|
144
146
|
// Use float-based comparator for heap operations
|
|
145
147
|
using Cfloat = typename std::conditional<
|
|
146
148
|
C::is_max,
|