faiss 0.4.3 → 0.5.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +10 -0
- data/README.md +2 -0
- data/ext/faiss/index.cpp +33 -6
- data/ext/faiss/index_binary.cpp +17 -4
- data/ext/faiss/kmeans.cpp +6 -6
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/faiss/AutoTune.cpp +2 -3
- data/vendor/faiss/faiss/AutoTune.h +1 -1
- data/vendor/faiss/faiss/Clustering.cpp +2 -2
- data/vendor/faiss/faiss/Clustering.h +2 -2
- data/vendor/faiss/faiss/IVFlib.cpp +26 -51
- data/vendor/faiss/faiss/IVFlib.h +1 -1
- data/vendor/faiss/faiss/Index.cpp +11 -0
- data/vendor/faiss/faiss/Index.h +34 -11
- data/vendor/faiss/faiss/Index2Layer.cpp +1 -1
- data/vendor/faiss/faiss/Index2Layer.h +2 -2
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +1 -0
- data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +9 -4
- data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +5 -1
- data/vendor/faiss/faiss/IndexBinary.h +7 -7
- data/vendor/faiss/faiss/IndexBinaryFromFloat.h +1 -1
- data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +8 -2
- data/vendor/faiss/faiss/IndexBinaryHNSW.h +1 -1
- data/vendor/faiss/faiss/IndexBinaryHash.cpp +3 -3
- data/vendor/faiss/faiss/IndexBinaryHash.h +5 -5
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +7 -6
- data/vendor/faiss/faiss/IndexFastScan.cpp +125 -49
- data/vendor/faiss/faiss/IndexFastScan.h +102 -7
- data/vendor/faiss/faiss/IndexFlat.cpp +374 -4
- data/vendor/faiss/faiss/IndexFlat.h +81 -1
- data/vendor/faiss/faiss/IndexHNSW.cpp +93 -2
- data/vendor/faiss/faiss/IndexHNSW.h +58 -2
- data/vendor/faiss/faiss/IndexIDMap.cpp +14 -13
- data/vendor/faiss/faiss/IndexIDMap.h +6 -6
- data/vendor/faiss/faiss/IndexIVF.cpp +1 -1
- data/vendor/faiss/faiss/IndexIVF.h +5 -5
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +1 -1
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +9 -3
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +3 -1
- data/vendor/faiss/faiss/IndexIVFFastScan.cpp +176 -90
- data/vendor/faiss/faiss/IndexIVFFastScan.h +173 -18
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +1 -0
- data/vendor/faiss/faiss/IndexIVFFlatPanorama.cpp +251 -0
- data/vendor/faiss/faiss/IndexIVFFlatPanorama.h +64 -0
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +3 -1
- data/vendor/faiss/faiss/IndexIVFPQ.h +1 -1
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +134 -2
- data/vendor/faiss/faiss/IndexIVFPQFastScan.h +7 -1
- data/vendor/faiss/faiss/IndexIVFRaBitQ.cpp +99 -8
- data/vendor/faiss/faiss/IndexIVFRaBitQ.h +4 -1
- data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.cpp +828 -0
- data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.h +252 -0
- data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +1 -1
- data/vendor/faiss/faiss/IndexIVFSpectralHash.h +1 -1
- data/vendor/faiss/faiss/IndexNNDescent.cpp +1 -1
- data/vendor/faiss/faiss/IndexNSG.cpp +1 -1
- data/vendor/faiss/faiss/IndexNeuralNetCodec.h +1 -1
- data/vendor/faiss/faiss/IndexPQ.cpp +4 -1
- data/vendor/faiss/faiss/IndexPQ.h +1 -1
- data/vendor/faiss/faiss/IndexPQFastScan.cpp +6 -2
- data/vendor/faiss/faiss/IndexPQFastScan.h +5 -1
- data/vendor/faiss/faiss/IndexPreTransform.cpp +14 -0
- data/vendor/faiss/faiss/IndexPreTransform.h +9 -0
- data/vendor/faiss/faiss/IndexRaBitQ.cpp +96 -13
- data/vendor/faiss/faiss/IndexRaBitQ.h +11 -2
- data/vendor/faiss/faiss/IndexRaBitQFastScan.cpp +731 -0
- data/vendor/faiss/faiss/IndexRaBitQFastScan.h +175 -0
- data/vendor/faiss/faiss/IndexRefine.cpp +49 -0
- data/vendor/faiss/faiss/IndexRefine.h +17 -0
- data/vendor/faiss/faiss/IndexShards.cpp +1 -1
- data/vendor/faiss/faiss/MatrixStats.cpp +3 -3
- data/vendor/faiss/faiss/MetricType.h +1 -1
- data/vendor/faiss/faiss/VectorTransform.h +2 -2
- data/vendor/faiss/faiss/clone_index.cpp +5 -1
- data/vendor/faiss/faiss/gpu/GpuCloner.cpp +1 -1
- data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +3 -1
- data/vendor/faiss/faiss/gpu/GpuIndex.h +11 -11
- data/vendor/faiss/faiss/gpu/GpuIndexBinaryCagra.h +1 -1
- data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +1 -1
- data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +11 -7
- data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +1 -1
- data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +2 -0
- data/vendor/faiss/faiss/gpu/test/TestGpuIcmEncoder.cpp +7 -0
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +1 -1
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +1 -1
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +1 -1
- data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +2 -2
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +1 -1
- data/vendor/faiss/faiss/impl/CodePacker.h +2 -2
- data/vendor/faiss/faiss/impl/DistanceComputer.h +77 -6
- data/vendor/faiss/faiss/impl/FastScanDistancePostProcessing.h +53 -0
- data/vendor/faiss/faiss/impl/HNSW.cpp +295 -16
- data/vendor/faiss/faiss/impl/HNSW.h +35 -6
- data/vendor/faiss/faiss/impl/IDSelector.cpp +2 -2
- data/vendor/faiss/faiss/impl/IDSelector.h +4 -4
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +4 -4
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +1 -1
- data/vendor/faiss/faiss/impl/LookupTableScaler.h +1 -1
- data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -1
- data/vendor/faiss/faiss/impl/NNDescent.h +2 -2
- data/vendor/faiss/faiss/impl/NSG.cpp +1 -1
- data/vendor/faiss/faiss/impl/Panorama.cpp +193 -0
- data/vendor/faiss/faiss/impl/Panorama.h +204 -0
- data/vendor/faiss/faiss/impl/PanoramaStats.cpp +33 -0
- data/vendor/faiss/faiss/impl/PanoramaStats.h +38 -0
- data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +5 -5
- data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +1 -1
- data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +1 -1
- data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +2 -0
- data/vendor/faiss/faiss/impl/ProductQuantizer.h +1 -1
- data/vendor/faiss/faiss/impl/RaBitQStats.cpp +29 -0
- data/vendor/faiss/faiss/impl/RaBitQStats.h +56 -0
- data/vendor/faiss/faiss/impl/RaBitQUtils.cpp +294 -0
- data/vendor/faiss/faiss/impl/RaBitQUtils.h +330 -0
- data/vendor/faiss/faiss/impl/RaBitQuantizer.cpp +304 -223
- data/vendor/faiss/faiss/impl/RaBitQuantizer.h +72 -4
- data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.cpp +362 -0
- data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.h +112 -0
- data/vendor/faiss/faiss/impl/ResidualQuantizer.h +1 -1
- data/vendor/faiss/faiss/impl/ResultHandler.h +4 -4
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +7 -10
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +2 -4
- data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +7 -4
- data/vendor/faiss/faiss/impl/index_read.cpp +238 -10
- data/vendor/faiss/faiss/impl/index_write.cpp +212 -19
- data/vendor/faiss/faiss/impl/io.cpp +2 -2
- data/vendor/faiss/faiss/impl/io.h +4 -4
- data/vendor/faiss/faiss/impl/kmeans1d.cpp +1 -1
- data/vendor/faiss/faiss/impl/kmeans1d.h +1 -1
- data/vendor/faiss/faiss/impl/lattice_Zn.h +2 -2
- data/vendor/faiss/faiss/impl/mapped_io.cpp +2 -2
- data/vendor/faiss/faiss/impl/mapped_io.h +4 -3
- data/vendor/faiss/faiss/impl/maybe_owned_vector.h +8 -1
- data/vendor/faiss/faiss/impl/platform_macros.h +12 -0
- data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +30 -4
- data/vendor/faiss/faiss/impl/pq4_fast_scan.h +14 -8
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +5 -6
- data/vendor/faiss/faiss/impl/simd_result_handlers.h +55 -11
- data/vendor/faiss/faiss/impl/svs_io.cpp +86 -0
- data/vendor/faiss/faiss/impl/svs_io.h +67 -0
- data/vendor/faiss/faiss/impl/zerocopy_io.h +1 -1
- data/vendor/faiss/faiss/index_factory.cpp +217 -8
- data/vendor/faiss/faiss/index_factory.h +1 -1
- data/vendor/faiss/faiss/index_io.h +1 -1
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +1 -1
- data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -1
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +115 -1
- data/vendor/faiss/faiss/invlists/InvertedLists.h +46 -0
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +1 -1
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +1 -1
- data/vendor/faiss/faiss/svs/IndexSVSFaissUtils.h +261 -0
- data/vendor/faiss/faiss/svs/IndexSVSFlat.cpp +117 -0
- data/vendor/faiss/faiss/svs/IndexSVSFlat.h +66 -0
- data/vendor/faiss/faiss/svs/IndexSVSVamana.cpp +245 -0
- data/vendor/faiss/faiss/svs/IndexSVSVamana.h +137 -0
- data/vendor/faiss/faiss/svs/IndexSVSVamanaLVQ.cpp +39 -0
- data/vendor/faiss/faiss/svs/IndexSVSVamanaLVQ.h +42 -0
- data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.cpp +149 -0
- data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.h +58 -0
- data/vendor/faiss/faiss/utils/AlignedTable.h +1 -1
- data/vendor/faiss/faiss/utils/Heap.cpp +2 -2
- data/vendor/faiss/faiss/utils/Heap.h +3 -3
- data/vendor/faiss/faiss/utils/NeuralNet.cpp +1 -1
- data/vendor/faiss/faiss/utils/NeuralNet.h +3 -3
- data/vendor/faiss/faiss/utils/approx_topk/approx_topk.h +2 -2
- data/vendor/faiss/faiss/utils/approx_topk/avx2-inl.h +2 -2
- data/vendor/faiss/faiss/utils/approx_topk/mode.h +1 -1
- data/vendor/faiss/faiss/utils/distances.cpp +0 -3
- data/vendor/faiss/faiss/utils/distances.h +2 -2
- data/vendor/faiss/faiss/utils/extra_distances-inl.h +3 -1
- data/vendor/faiss/faiss/utils/hamming-inl.h +2 -0
- data/vendor/faiss/faiss/utils/hamming.cpp +7 -6
- data/vendor/faiss/faiss/utils/hamming.h +1 -1
- data/vendor/faiss/faiss/utils/hamming_distance/common.h +1 -2
- data/vendor/faiss/faiss/utils/partitioning.cpp +5 -5
- data/vendor/faiss/faiss/utils/partitioning.h +2 -2
- data/vendor/faiss/faiss/utils/rabitq_simd.h +222 -336
- data/vendor/faiss/faiss/utils/random.cpp +1 -1
- data/vendor/faiss/faiss/utils/simdlib_avx2.h +1 -1
- data/vendor/faiss/faiss/utils/simdlib_avx512.h +1 -1
- data/vendor/faiss/faiss/utils/simdlib_neon.h +2 -2
- data/vendor/faiss/faiss/utils/transpose/transpose-avx512-inl.h +1 -1
- data/vendor/faiss/faiss/utils/utils.cpp +9 -2
- data/vendor/faiss/faiss/utils/utils.h +2 -2
- metadata +29 -1
|
@@ -0,0 +1,330 @@
|
|
|
1
|
+
/*
|
|
2
|
+
* Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
3
|
+
*
|
|
4
|
+
* This source code is licensed under the MIT license found in the
|
|
5
|
+
* LICENSE file in the root directory of this source tree.
|
|
6
|
+
*/
|
|
7
|
+
|
|
8
|
+
#pragma once
|
|
9
|
+
|
|
10
|
+
#include <faiss/MetricType.h>
#include <faiss/impl/platform_macros.h>
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <vector>
|
|
15
|
+
|
|
16
|
+
namespace faiss {
|
|
17
|
+
namespace rabitq_utils {
|
|
18
|
+
|
|
19
|
+
/** Base factors computed per database vector for RaBitQ distance computation.
|
|
20
|
+
* Used by both 1-bit and multi-bit RaBitQ variants.
|
|
21
|
+
* These can be stored either embedded in codes (IndexRaBitQ) or separately
|
|
22
|
+
* (IndexRaBitQFastScan).
|
|
23
|
+
*
|
|
24
|
+
* For 1-bit mode only - contains the minimal factors needed for distance
|
|
25
|
+
* estimation using just sign bits.
|
|
26
|
+
*/
|
|
27
|
+
FAISS_PACK_STRUCTS_BEGIN
|
|
28
|
+
struct FAISS_PACKED SignBitFactors {
|
|
29
|
+
// ||or - c||^2 - ((metric==IP) ? ||or||^2 : 0)
|
|
30
|
+
float or_minus_c_l2sqr = 0;
|
|
31
|
+
float dp_multiplier = 0;
|
|
32
|
+
};
|
|
33
|
+
|
|
34
|
+
/** Extended factors for multi-bit RaBitQ (nb_bits > 1).
|
|
35
|
+
* Includes error bound for lower bound computation in two-stage search.
|
|
36
|
+
* Inherits base factors to maintain layout compatibility.
|
|
37
|
+
*
|
|
38
|
+
* Used in multi-bit mode - the error bound enables quick filtering of
|
|
39
|
+
* unlikely candidates in the first stage of two-stage search.
|
|
40
|
+
*/
|
|
41
|
+
struct FAISS_PACKED SignBitFactorsWithError : SignBitFactors {
|
|
42
|
+
// Error bound for lower bound computation in two-stage search
|
|
43
|
+
// Used in formula: lower_bound = est_distance - f_error * g_error
|
|
44
|
+
// Only allocated when nb_bits > 1
|
|
45
|
+
float f_error = 0;
|
|
46
|
+
};
|
|
47
|
+
|
|
48
|
+
/** Additional factors for multi-bit RaBitQ (nb_bits > 1).
|
|
49
|
+
* Used to store normalization and scaling factors for the refinement bits
|
|
50
|
+
* that encode additional precision beyond the sign bit.
|
|
51
|
+
*/
|
|
52
|
+
struct FAISS_PACKED ExtraBitsFactors {
|
|
53
|
+
// Additive correction factor for refinement bit reconstruction
|
|
54
|
+
float f_add_ex = 0;
|
|
55
|
+
// Scaling/rescaling factor for refinement bit reconstruction
|
|
56
|
+
float f_rescale_ex = 0;
|
|
57
|
+
};
|
|
58
|
+
FAISS_PACK_STRUCTS_END
|
|
59
|
+
|
|
60
|
+
/** Per-query factors computed at search time for RaBitQ distance estimation.
 *
 * Shared by the IndexRaBitQ and IndexRaBitQFastScan implementations.
 */
struct QueryFactorsData {
    // Scalar coefficients of the distance estimate; in non-centered 1-bit
    // mode c34 is subtracted from the LUT distance.
    float c1{0.0f};
    float c2{0.0f};
    float c34{0.0f};

    // ||q_rotated - c||^2
    float qr_to_c_L2sqr{0.0f};
    // ||q_rotated||^2 (0 for the L2 metric)
    float qr_norm_L2sqr{0.0f};

    // Scale applied to the integer dot product in centered mode
    float int_dot_scale{1.0f};

    // Query-side error factor (lower_bound = est - f_error * g_error)
    float g_error{0.0f};
    // Rotated query vector
    std::vector<float> rotated_q{};
};
|
|
77
|
+
|
|
78
|
+
/** Ideal quantizer radii for quantizers of 1..8 bits, optimized to minimize
|
|
79
|
+
* L2 reconstruction error. Shared between all RaBitQ implementations.
|
|
80
|
+
*/
|
|
81
|
+
FAISS_API extern const float Z_MAX_BY_QB[8];
|
|
82
|
+
|
|
83
|
+
/** Compute factors for a single database vector using RaBitQ algorithm.
|
|
84
|
+
* This function consolidates the mathematical logic that was duplicated
|
|
85
|
+
* between IndexRaBitQ and IndexRaBitQFastScan.
|
|
86
|
+
*
|
|
87
|
+
* @param x input vector (d dimensions)
|
|
88
|
+
* @param d dimensionality
|
|
89
|
+
* @param centroid database centroid (nullptr if not used)
|
|
90
|
+
* @param metric_type distance metric (L2 or Inner Product)
|
|
91
|
+
* @param compute_error whether to compute f_error (false for 1-bit mode)
|
|
92
|
+
* @return computed factors for distance computation
|
|
93
|
+
*/
|
|
94
|
+
SignBitFactorsWithError compute_vector_factors(
|
|
95
|
+
const float* x,
|
|
96
|
+
size_t d,
|
|
97
|
+
const float* centroid,
|
|
98
|
+
MetricType metric_type,
|
|
99
|
+
bool compute_error = true);
|
|
100
|
+
|
|
101
|
+
/** Compute the intermediate quantities that vector factors are built from.
 *
 * Split out so that different bit-packing strategies can share the core
 * math.
 *
 * @param x          input vector of d floats
 * @param d          dimensionality
 * @param centroid   database centroid, or nullptr when none is used
 * @param norm_L2sqr [out] ||or - c||^2
 * @param or_L2sqr   [out] ||or||^2
 * @param dp_oO      [out] sum of |or_i - c_i| (absolute deviations)
 */
auto compute_vector_intermediate_values(
        const float* x,
        size_t d,
        const float* centroid,
        float& norm_L2sqr,
        float& or_L2sqr,
        float& dp_oO) -> void;
|
|
119
|
+
|
|
120
|
+
/** Compute final factors from intermediate values.
|
|
121
|
+
* @param norm_L2sqr ||or - c||^2
|
|
122
|
+
* @param or_L2sqr ||or||^2
|
|
123
|
+
* @param dp_oO sum of |or_i - c_i|
|
|
124
|
+
* @param d dimensionality
|
|
125
|
+
* @param metric_type distance metric
|
|
126
|
+
* @param compute_error whether to compute f_error (false for 1-bit mode)
|
|
127
|
+
* @return computed factors
|
|
128
|
+
*/
|
|
129
|
+
SignBitFactorsWithError compute_factors_from_intermediates(
|
|
130
|
+
float norm_L2sqr,
|
|
131
|
+
float or_L2sqr,
|
|
132
|
+
float dp_oO,
|
|
133
|
+
size_t d,
|
|
134
|
+
MetricType metric_type,
|
|
135
|
+
bool compute_error = true);
|
|
136
|
+
|
|
137
|
+
/** Compute query factors for RaBitQ distance computation.
|
|
138
|
+
* This consolidates the query processing logic shared between implementations.
|
|
139
|
+
*
|
|
140
|
+
* @param query query vector (d dimensions)
|
|
141
|
+
* @param d dimensionality
|
|
142
|
+
* @param centroid database centroid (nullptr if not used)
|
|
143
|
+
* @param qb number of quantization bits (1-8)
|
|
144
|
+
* @param centered whether to use centered quantization
|
|
145
|
+
* @param metric_type distance metric
|
|
146
|
+
* @param rotated_q output: query - centroid
|
|
147
|
+
* @param rotated_qq output: quantized query values
|
|
148
|
+
* @return computed query factors
|
|
149
|
+
*/
|
|
150
|
+
QueryFactorsData compute_query_factors(
|
|
151
|
+
const float* query,
|
|
152
|
+
size_t d,
|
|
153
|
+
const float* centroid,
|
|
154
|
+
uint8_t qb,
|
|
155
|
+
bool centered,
|
|
156
|
+
MetricType metric_type,
|
|
157
|
+
std::vector<float>& rotated_q,
|
|
158
|
+
std::vector<uint8_t>& rotated_qq);
|
|
159
|
+
|
|
160
|
+
/** Read one bit from a RaBitQ code in the standard (sequential) layout
 * used by IndexRaBitQ.
 *
 * @param code      RaBitQ code bytes
 * @param bit_index bit to read, 0 to d-1
 * @return the bit value
 */
auto extract_bit_standard(const uint8_t* code, size_t bit_index) -> bool;
|
|
168
|
+
|
|
169
|
+
/** Read one bit from a code in the FastScan layout, as used by
 * IndexRaBitQFastScan, which packs bits into 4-bit sub-quantizers.
 *
 * @param code      FastScan code bytes
 * @param bit_index bit to read, 0 to d-1
 * @return the bit value
 */
auto extract_bit_fastscan(const uint8_t* code, size_t bit_index) -> bool;
|
|
177
|
+
|
|
178
|
+
/** Set one bit of a RaBitQ code in the standard layout.
 *
 * @param code      RaBitQ code bytes to modify
 * @param bit_index bit to set, 0 to d-1
 */
auto set_bit_standard(uint8_t* code, size_t bit_index) -> void;
|
|
183
|
+
|
|
184
|
+
/** Set one bit of a code in the FastScan layout.
 *
 * @param code      FastScan code bytes to modify
 * @param bit_index bit to set, 0 to d-1
 */
auto set_bit_fastscan(uint8_t* code, size_t bit_index) -> void;
|
|
189
|
+
|
|
190
|
+
/** Compute adjusted 1-bit distance from normalized LUT distance.
|
|
191
|
+
* This is the core distance formula shared by all RaBitQ handlers.
|
|
192
|
+
*
|
|
193
|
+
* @param normalized_distance Distance from SIMD LUT lookup (after
|
|
194
|
+
* normalization)
|
|
195
|
+
* @param db_factors Database vector factors (SignBitFactors or
|
|
196
|
+
* SignBitFactorsWithError)
|
|
197
|
+
* @param query_factors Query factors computed during search
|
|
198
|
+
* @param centered Whether centered quantization is used
|
|
199
|
+
* @param qb Number of quantization bits
|
|
200
|
+
* @param d Dimensionality
|
|
201
|
+
* @return Adjusted distance value
|
|
202
|
+
*/
|
|
203
|
+
inline float compute_1bit_adjusted_distance(
|
|
204
|
+
float normalized_distance,
|
|
205
|
+
const SignBitFactors& db_factors,
|
|
206
|
+
const QueryFactorsData& query_factors,
|
|
207
|
+
bool centered,
|
|
208
|
+
size_t qb,
|
|
209
|
+
size_t d) {
|
|
210
|
+
float adjusted_distance;
|
|
211
|
+
|
|
212
|
+
if (centered) {
|
|
213
|
+
// For centered mode: normalized_distance contains the raw XOR
|
|
214
|
+
// contribution. Apply the signed odd integer quantization formula:
|
|
215
|
+
// int_dot = ((1 << qb) - 1) * d - 2 * xor_dot_product
|
|
216
|
+
int64_t int_dot = ((1 << qb) - 1) * d;
|
|
217
|
+
int_dot -= 2 * static_cast<int64_t>(normalized_distance);
|
|
218
|
+
|
|
219
|
+
adjusted_distance = query_factors.qr_to_c_L2sqr +
|
|
220
|
+
db_factors.or_minus_c_l2sqr -
|
|
221
|
+
2 * db_factors.dp_multiplier * int_dot *
|
|
222
|
+
query_factors.int_dot_scale;
|
|
223
|
+
} else {
|
|
224
|
+
// For non-centered quantization: use traditional formula
|
|
225
|
+
float final_dot = normalized_distance - query_factors.c34;
|
|
226
|
+
adjusted_distance = db_factors.or_minus_c_l2sqr +
|
|
227
|
+
query_factors.qr_to_c_L2sqr -
|
|
228
|
+
2 * db_factors.dp_multiplier * final_dot;
|
|
229
|
+
}
|
|
230
|
+
|
|
231
|
+
// Apply inner product correction if needed
|
|
232
|
+
if (query_factors.qr_norm_L2sqr != 0.0f) {
|
|
233
|
+
adjusted_distance =
|
|
234
|
+
-0.5f * (adjusted_distance - query_factors.qr_norm_L2sqr);
|
|
235
|
+
} else {
|
|
236
|
+
adjusted_distance = std::max(0.0f, adjusted_distance);
|
|
237
|
+
}
|
|
238
|
+
|
|
239
|
+
return adjusted_distance;
|
|
240
|
+
}
|
|
241
|
+
|
|
242
|
+
/** Extract one packed multi-bit code without unpacking the whole array.
 *
 * Codes are stored back to back, least-significant bit first. Code `index`
 * occupies stream bits [index * ex_bits, (index + 1) * ex_bits), and stream
 * bit k is bit (k % 8) of byte k / 8.
 *
 * Only the bytes that overlap the requested bit range are read (none when
 * ex_bits == 0). They are gathered into a 64-bit window and the code is
 * shifted/masked out, instead of decoding one bit at a time. This function
 * sits on the per-dimension hot path of multi-bit distance computation.
 *
 * @param ex_code packed ex-bit codes
 * @param index   which code to extract (0 to d-1)
 * @param ex_bits number of bits per code (1-8)
 * @return extracted code value in range [0, 2^ex_bits - 1]
 */
inline int extract_code_inline(
        const uint8_t* ex_code,
        size_t index,
        size_t ex_bits) {
    if (ex_bits == 0) {
        return 0;
    }

    const size_t bit_pos = index * ex_bits;
    const uint8_t* first_byte = ex_code + bit_pos / 8;
    const size_t shift = bit_pos % 8;

    // Bytes overlapping [shift, shift + ex_bits): at most 2 for ex_bits <= 8,
    // so a 64-bit window is always wide enough for the int-sized result.
    const size_t n_bytes = (shift + ex_bits + 7) / 8;
    uint64_t window = 0;
    for (size_t k = 0; k < n_bytes; k++) {
        window |= static_cast<uint64_t>(first_byte[k]) << (8 * k);
    }

    const uint64_t mask = (uint64_t{1} << ex_bits) - 1;
    return static_cast<int>((window >> shift) & mask);
}
|
|
273
|
+
|
|
274
|
+
/** Compute full multi-bit distance from sign bits and ex-bit codes.
|
|
275
|
+
* This is the core distance computation shared by RaBitQFastScan handlers.
|
|
276
|
+
*
|
|
277
|
+
* The multi-bit distance combines the sign bit (1-bit) with additional
|
|
278
|
+
* magnitude bits (ex_bits) to compute a more accurate distance estimate.
|
|
279
|
+
*
|
|
280
|
+
* @param sign_bits unpacked sign bits (1-bit codes in standard format)
|
|
281
|
+
* @param ex_code packed ex-bit codes
|
|
282
|
+
* @param ex_fac ex-bit factors (f_add_ex, f_rescale_ex)
|
|
283
|
+
* @param rotated_q rotated query vector
|
|
284
|
+
* @param qr_to_c_L2sqr precomputed ||query_rotated - centroid||^2
|
|
285
|
+
* @param qr_norm_L2sqr precomputed ||query_rotated||^2 (0 for L2 metric)
|
|
286
|
+
* @param d dimensionality
|
|
287
|
+
* @param ex_bits number of extra bits (nb_bits - 1)
|
|
288
|
+
* @param metric_type distance metric (L2 or Inner Product)
|
|
289
|
+
* @return computed full multi-bit distance
|
|
290
|
+
*/
|
|
291
|
+
inline float compute_full_multibit_distance(
|
|
292
|
+
const uint8_t* sign_bits,
|
|
293
|
+
const uint8_t* ex_code,
|
|
294
|
+
const ExtraBitsFactors& ex_fac,
|
|
295
|
+
const float* rotated_q,
|
|
296
|
+
float qr_to_c_L2sqr,
|
|
297
|
+
float qr_norm_L2sqr,
|
|
298
|
+
size_t d,
|
|
299
|
+
size_t ex_bits,
|
|
300
|
+
MetricType metric_type) {
|
|
301
|
+
float ex_ip = 0.0f;
|
|
302
|
+
const float cb = -(static_cast<float>(1 << ex_bits) - 0.5f);
|
|
303
|
+
|
|
304
|
+
for (size_t i = 0; i < d; i++) {
|
|
305
|
+
const size_t byte_idx = i / 8;
|
|
306
|
+
const size_t bit_offset = i % 8;
|
|
307
|
+
const bool sign_bit = (sign_bits[byte_idx] >> bit_offset) & 1;
|
|
308
|
+
|
|
309
|
+
int ex_code_val = extract_code_inline(ex_code, i, ex_bits);
|
|
310
|
+
|
|
311
|
+
int total_code = (sign_bit ? 1 : 0) << ex_bits;
|
|
312
|
+
total_code += ex_code_val;
|
|
313
|
+
float reconstructed = static_cast<float>(total_code) + cb;
|
|
314
|
+
|
|
315
|
+
ex_ip += rotated_q[i] * reconstructed;
|
|
316
|
+
}
|
|
317
|
+
|
|
318
|
+
float dist = qr_to_c_L2sqr + ex_fac.f_add_ex + ex_fac.f_rescale_ex * ex_ip;
|
|
319
|
+
|
|
320
|
+
if (metric_type == MetricType::METRIC_INNER_PRODUCT) {
|
|
321
|
+
dist = -0.5f * (dist - qr_norm_L2sqr);
|
|
322
|
+
} else {
|
|
323
|
+
dist = std::max(0.0f, dist);
|
|
324
|
+
}
|
|
325
|
+
|
|
326
|
+
return dist;
|
|
327
|
+
}
|
|
328
|
+
|
|
329
|
+
} // namespace rabitq_utils
|
|
330
|
+
} // namespace faiss
|