faiss 0.4.3 → 0.5.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +10 -0
- data/README.md +2 -0
- data/ext/faiss/index.cpp +33 -6
- data/ext/faiss/index_binary.cpp +17 -4
- data/ext/faiss/kmeans.cpp +6 -6
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/faiss/AutoTune.cpp +2 -3
- data/vendor/faiss/faiss/AutoTune.h +1 -1
- data/vendor/faiss/faiss/Clustering.cpp +2 -2
- data/vendor/faiss/faiss/Clustering.h +2 -2
- data/vendor/faiss/faiss/IVFlib.cpp +26 -51
- data/vendor/faiss/faiss/IVFlib.h +1 -1
- data/vendor/faiss/faiss/Index.cpp +11 -0
- data/vendor/faiss/faiss/Index.h +34 -11
- data/vendor/faiss/faiss/Index2Layer.cpp +1 -1
- data/vendor/faiss/faiss/Index2Layer.h +2 -2
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +1 -0
- data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +9 -4
- data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +5 -1
- data/vendor/faiss/faiss/IndexBinary.h +7 -7
- data/vendor/faiss/faiss/IndexBinaryFromFloat.h +1 -1
- data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +8 -2
- data/vendor/faiss/faiss/IndexBinaryHNSW.h +1 -1
- data/vendor/faiss/faiss/IndexBinaryHash.cpp +3 -3
- data/vendor/faiss/faiss/IndexBinaryHash.h +5 -5
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +7 -6
- data/vendor/faiss/faiss/IndexFastScan.cpp +125 -49
- data/vendor/faiss/faiss/IndexFastScan.h +102 -7
- data/vendor/faiss/faiss/IndexFlat.cpp +374 -4
- data/vendor/faiss/faiss/IndexFlat.h +81 -1
- data/vendor/faiss/faiss/IndexHNSW.cpp +93 -2
- data/vendor/faiss/faiss/IndexHNSW.h +58 -2
- data/vendor/faiss/faiss/IndexIDMap.cpp +14 -13
- data/vendor/faiss/faiss/IndexIDMap.h +6 -6
- data/vendor/faiss/faiss/IndexIVF.cpp +1 -1
- data/vendor/faiss/faiss/IndexIVF.h +5 -5
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +1 -1
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +9 -3
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +3 -1
- data/vendor/faiss/faiss/IndexIVFFastScan.cpp +176 -90
- data/vendor/faiss/faiss/IndexIVFFastScan.h +173 -18
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +1 -0
- data/vendor/faiss/faiss/IndexIVFFlatPanorama.cpp +251 -0
- data/vendor/faiss/faiss/IndexIVFFlatPanorama.h +64 -0
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +3 -1
- data/vendor/faiss/faiss/IndexIVFPQ.h +1 -1
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +134 -2
- data/vendor/faiss/faiss/IndexIVFPQFastScan.h +7 -1
- data/vendor/faiss/faiss/IndexIVFRaBitQ.cpp +99 -8
- data/vendor/faiss/faiss/IndexIVFRaBitQ.h +4 -1
- data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.cpp +828 -0
- data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.h +252 -0
- data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +1 -1
- data/vendor/faiss/faiss/IndexIVFSpectralHash.h +1 -1
- data/vendor/faiss/faiss/IndexNNDescent.cpp +1 -1
- data/vendor/faiss/faiss/IndexNSG.cpp +1 -1
- data/vendor/faiss/faiss/IndexNeuralNetCodec.h +1 -1
- data/vendor/faiss/faiss/IndexPQ.cpp +4 -1
- data/vendor/faiss/faiss/IndexPQ.h +1 -1
- data/vendor/faiss/faiss/IndexPQFastScan.cpp +6 -2
- data/vendor/faiss/faiss/IndexPQFastScan.h +5 -1
- data/vendor/faiss/faiss/IndexPreTransform.cpp +14 -0
- data/vendor/faiss/faiss/IndexPreTransform.h +9 -0
- data/vendor/faiss/faiss/IndexRaBitQ.cpp +96 -13
- data/vendor/faiss/faiss/IndexRaBitQ.h +11 -2
- data/vendor/faiss/faiss/IndexRaBitQFastScan.cpp +731 -0
- data/vendor/faiss/faiss/IndexRaBitQFastScan.h +175 -0
- data/vendor/faiss/faiss/IndexRefine.cpp +49 -0
- data/vendor/faiss/faiss/IndexRefine.h +17 -0
- data/vendor/faiss/faiss/IndexShards.cpp +1 -1
- data/vendor/faiss/faiss/MatrixStats.cpp +3 -3
- data/vendor/faiss/faiss/MetricType.h +1 -1
- data/vendor/faiss/faiss/VectorTransform.h +2 -2
- data/vendor/faiss/faiss/clone_index.cpp +5 -1
- data/vendor/faiss/faiss/gpu/GpuCloner.cpp +1 -1
- data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +3 -1
- data/vendor/faiss/faiss/gpu/GpuIndex.h +11 -11
- data/vendor/faiss/faiss/gpu/GpuIndexBinaryCagra.h +1 -1
- data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +1 -1
- data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +11 -7
- data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +1 -1
- data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +2 -0
- data/vendor/faiss/faiss/gpu/test/TestGpuIcmEncoder.cpp +7 -0
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +1 -1
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +1 -1
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +1 -1
- data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +2 -2
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +1 -1
- data/vendor/faiss/faiss/impl/CodePacker.h +2 -2
- data/vendor/faiss/faiss/impl/DistanceComputer.h +77 -6
- data/vendor/faiss/faiss/impl/FastScanDistancePostProcessing.h +53 -0
- data/vendor/faiss/faiss/impl/HNSW.cpp +295 -16
- data/vendor/faiss/faiss/impl/HNSW.h +35 -6
- data/vendor/faiss/faiss/impl/IDSelector.cpp +2 -2
- data/vendor/faiss/faiss/impl/IDSelector.h +4 -4
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +4 -4
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +1 -1
- data/vendor/faiss/faiss/impl/LookupTableScaler.h +1 -1
- data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -1
- data/vendor/faiss/faiss/impl/NNDescent.h +2 -2
- data/vendor/faiss/faiss/impl/NSG.cpp +1 -1
- data/vendor/faiss/faiss/impl/Panorama.cpp +193 -0
- data/vendor/faiss/faiss/impl/Panorama.h +204 -0
- data/vendor/faiss/faiss/impl/PanoramaStats.cpp +33 -0
- data/vendor/faiss/faiss/impl/PanoramaStats.h +38 -0
- data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +5 -5
- data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +1 -1
- data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +1 -1
- data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +2 -0
- data/vendor/faiss/faiss/impl/ProductQuantizer.h +1 -1
- data/vendor/faiss/faiss/impl/RaBitQStats.cpp +29 -0
- data/vendor/faiss/faiss/impl/RaBitQStats.h +56 -0
- data/vendor/faiss/faiss/impl/RaBitQUtils.cpp +294 -0
- data/vendor/faiss/faiss/impl/RaBitQUtils.h +330 -0
- data/vendor/faiss/faiss/impl/RaBitQuantizer.cpp +304 -223
- data/vendor/faiss/faiss/impl/RaBitQuantizer.h +72 -4
- data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.cpp +362 -0
- data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.h +112 -0
- data/vendor/faiss/faiss/impl/ResidualQuantizer.h +1 -1
- data/vendor/faiss/faiss/impl/ResultHandler.h +4 -4
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +7 -10
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +2 -4
- data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +7 -4
- data/vendor/faiss/faiss/impl/index_read.cpp +238 -10
- data/vendor/faiss/faiss/impl/index_write.cpp +212 -19
- data/vendor/faiss/faiss/impl/io.cpp +2 -2
- data/vendor/faiss/faiss/impl/io.h +4 -4
- data/vendor/faiss/faiss/impl/kmeans1d.cpp +1 -1
- data/vendor/faiss/faiss/impl/kmeans1d.h +1 -1
- data/vendor/faiss/faiss/impl/lattice_Zn.h +2 -2
- data/vendor/faiss/faiss/impl/mapped_io.cpp +2 -2
- data/vendor/faiss/faiss/impl/mapped_io.h +4 -3
- data/vendor/faiss/faiss/impl/maybe_owned_vector.h +8 -1
- data/vendor/faiss/faiss/impl/platform_macros.h +12 -0
- data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +30 -4
- data/vendor/faiss/faiss/impl/pq4_fast_scan.h +14 -8
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +5 -6
- data/vendor/faiss/faiss/impl/simd_result_handlers.h +55 -11
- data/vendor/faiss/faiss/impl/svs_io.cpp +86 -0
- data/vendor/faiss/faiss/impl/svs_io.h +67 -0
- data/vendor/faiss/faiss/impl/zerocopy_io.h +1 -1
- data/vendor/faiss/faiss/index_factory.cpp +217 -8
- data/vendor/faiss/faiss/index_factory.h +1 -1
- data/vendor/faiss/faiss/index_io.h +1 -1
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +1 -1
- data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -1
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +115 -1
- data/vendor/faiss/faiss/invlists/InvertedLists.h +46 -0
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +1 -1
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +1 -1
- data/vendor/faiss/faiss/svs/IndexSVSFaissUtils.h +261 -0
- data/vendor/faiss/faiss/svs/IndexSVSFlat.cpp +117 -0
- data/vendor/faiss/faiss/svs/IndexSVSFlat.h +66 -0
- data/vendor/faiss/faiss/svs/IndexSVSVamana.cpp +245 -0
- data/vendor/faiss/faiss/svs/IndexSVSVamana.h +137 -0
- data/vendor/faiss/faiss/svs/IndexSVSVamanaLVQ.cpp +39 -0
- data/vendor/faiss/faiss/svs/IndexSVSVamanaLVQ.h +42 -0
- data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.cpp +149 -0
- data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.h +58 -0
- data/vendor/faiss/faiss/utils/AlignedTable.h +1 -1
- data/vendor/faiss/faiss/utils/Heap.cpp +2 -2
- data/vendor/faiss/faiss/utils/Heap.h +3 -3
- data/vendor/faiss/faiss/utils/NeuralNet.cpp +1 -1
- data/vendor/faiss/faiss/utils/NeuralNet.h +3 -3
- data/vendor/faiss/faiss/utils/approx_topk/approx_topk.h +2 -2
- data/vendor/faiss/faiss/utils/approx_topk/avx2-inl.h +2 -2
- data/vendor/faiss/faiss/utils/approx_topk/mode.h +1 -1
- data/vendor/faiss/faiss/utils/distances.cpp +0 -3
- data/vendor/faiss/faiss/utils/distances.h +2 -2
- data/vendor/faiss/faiss/utils/extra_distances-inl.h +3 -1
- data/vendor/faiss/faiss/utils/hamming-inl.h +2 -0
- data/vendor/faiss/faiss/utils/hamming.cpp +7 -6
- data/vendor/faiss/faiss/utils/hamming.h +1 -1
- data/vendor/faiss/faiss/utils/hamming_distance/common.h +1 -2
- data/vendor/faiss/faiss/utils/partitioning.cpp +5 -5
- data/vendor/faiss/faiss/utils/partitioning.h +2 -2
- data/vendor/faiss/faiss/utils/rabitq_simd.h +222 -336
- data/vendor/faiss/faiss/utils/random.cpp +1 -1
- data/vendor/faiss/faiss/utils/simdlib_avx2.h +1 -1
- data/vendor/faiss/faiss/utils/simdlib_avx512.h +1 -1
- data/vendor/faiss/faiss/utils/simdlib_neon.h +2 -2
- data/vendor/faiss/faiss/utils/transpose/transpose-avx512-inl.h +1 -1
- data/vendor/faiss/faiss/utils/utils.cpp +9 -2
- data/vendor/faiss/faiss/utils/utils.h +2 -2
- metadata +29 -1
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
/*
|
|
2
|
+
* Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
3
|
+
*
|
|
4
|
+
* This source code is licensed under the MIT license found in the
|
|
5
|
+
* LICENSE file in the root directory of this source tree.
|
|
6
|
+
*/
|
|
7
|
+
|
|
8
|
+
#include <faiss/impl/RaBitQStats.h>
|
|
9
|
+
|
|
10
|
+
namespace faiss {
|
|
11
|
+
|
|
12
|
+
// Definition of the global RaBitQ statistics object that RaBitQStats.h
// declares as extern.
// NOLINTNEXTLINE(facebook-avoid-non-const-global-variables)
RaBitQStats rabitq_stats;
|
|
14
|
+
|
|
15
|
+
void RaBitQStats::reset() {
|
|
16
|
+
n_1bit_evaluations = 0;
|
|
17
|
+
n_multibit_evaluations = 0;
|
|
18
|
+
}
|
|
19
|
+
|
|
20
|
+
double RaBitQStats::skip_percentage() const {
|
|
21
|
+
const size_t copy_n_1bit_evaluations = n_1bit_evaluations;
|
|
22
|
+
const size_t copy_n_multibit_evaluations = n_multibit_evaluations;
|
|
23
|
+
return copy_n_1bit_evaluations > 0
|
|
24
|
+
? 100.0 * (copy_n_1bit_evaluations - copy_n_multibit_evaluations) /
|
|
25
|
+
copy_n_1bit_evaluations
|
|
26
|
+
: 0.0;
|
|
27
|
+
}
|
|
28
|
+
|
|
29
|
+
} // namespace faiss
|
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
/*
|
|
2
|
+
* Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
3
|
+
*
|
|
4
|
+
* This source code is licensed under the MIT license found in the
|
|
5
|
+
* LICENSE file in the root directory of this source tree.
|
|
6
|
+
*/
|
|
7
|
+
|
|
8
|
+
#pragma once
|
|
9
|
+
|
|
10
|
+
#include <faiss/impl/platform_macros.h>
|
|
11
|
+
#include <cstddef>
|
|
12
|
+
|
|
13
|
+
namespace faiss {
|
|
14
|
+
|
|
15
|
+
/// Statistics for RaBitQ multi-bit two-stage search.
///
/// These stats are ONLY collected for multi-bit mode (nb_bits > 1).
/// In 1-bit mode there is no two-stage filtering: every candidate is
/// evaluated with a single distance computation, so there is nothing
/// meaningful to track and both counters remain 0.
///
/// Multi-bit mode uses a two-stage search:
///   Stage 1: compute the 1-bit lower bound distance for all candidates.
///   Stage 2: compute the full multi-bit distance only for promising
///            candidates.
///
/// The skip_percentage() metric measures filtering effectiveness: how many
/// candidates were filtered out by the 1-bit lower bound without needing the
/// more expensive multi-bit distance computation.
///
/// WARNING: Statistics are not robust to internal threading nor to
/// concurrent RaBitQ searches. Use these values in a single-threaded
/// context to accurately gauge RaBitQ's filtering effectiveness.
/// Call reset() before search, then read stats after search completes.
struct RaBitQStats {
    /// Number of candidates evaluated using the 1-bit (lower bound) distance.
    /// This is the first stage of the two-stage search in multi-bit mode.
    /// Always 0 in 1-bit mode (stats not tracked).
    size_t n_1bit_evaluations = 0;

    /// Number of candidates that passed 1-bit filtering and required the
    /// full multi-bit distance computation (second stage).
    /// Always 0 in 1-bit mode (stats not tracked).
    size_t n_multibit_evaluations = 0;

    /// Set both counters back to 0.
    void reset();

    /// Compute the percentage of candidates skipped (filtered out by the
    /// 1-bit stage), derived from the two counters above.
    /// Returns 0 if no candidates were evaluated (including 1-bit mode).
    double skip_percentage() const;
};
|
|
51
|
+
|
|
52
|
+
/// Global stats for RaBitQ indexes.
/// Declared here; the single definition lives in RaBitQStats.cpp.
// NOLINTNEXTLINE(facebook-avoid-non-const-global-variables)
FAISS_API extern RaBitQStats rabitq_stats;
|
|
55
|
+
|
|
56
|
+
} // namespace faiss
|
|
@@ -0,0 +1,294 @@
|
|
|
1
|
+
/*
|
|
2
|
+
* Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
3
|
+
*
|
|
4
|
+
* This source code is licensed under the MIT license found in the
|
|
5
|
+
* LICENSE file in the root directory of this source tree.
|
|
6
|
+
*/
|
|
7
|
+
|
|
8
|
+
#include <faiss/impl/RaBitQUtils.h>
|
|
9
|
+
|
|
10
|
+
#include <faiss/impl/FaissAssert.h>
|
|
11
|
+
#include <faiss/utils/distances.h>
|
|
12
|
+
#include <algorithm>
|
|
13
|
+
#include <cmath>
|
|
14
|
+
#include <limits>
|
|
15
|
+
|
|
16
|
+
namespace faiss {
|
|
17
|
+
namespace rabitq_utils {
|
|
18
|
+
|
|
19
|
+
// Verify no unexpected padding in structures used for per-vector storage.
// The expected sizes are 8 bytes for SignBitFactors and ExtraBitsFactors and
// 12 bytes for SignBitFactorsWithError; compilation fails if any of them
// differs. These checks ensure compute_per_vector_storage_size() remains
// accurate.
static_assert(
        sizeof(SignBitFactors) == 8,
        "SignBitFactors has unexpected padding");
static_assert(
        sizeof(SignBitFactorsWithError) == 12,
        "SignBitFactorsWithError has unexpected padding");
static_assert(
        sizeof(ExtraBitsFactors) == 8,
        "ExtraBitsFactors has unexpected padding");
|
|
30
|
+
|
|
31
|
+
// Ideal quantizer radii for quantizers of 1..8 bits, optimized to minimize
// L2 reconstruction error.
// The table is indexed by (qb - 1): entry 0 is for qb = 1, entry 7 for qb = 8.
const float Z_MAX_BY_QB[8] = {
        0.79688, // qb = 1.
        1.49375, // qb = 2.
        2.05078, // qb = 3.
        2.50938, // qb = 4.
        2.91250, // qb = 5.
        3.26406, // qb = 6.
        3.59844, // qb = 7.
        3.91016, // qb = 8.
};
|
|
43
|
+
|
|
44
|
+
/// Single pass over a vector that accumulates the three sums from which its
/// RaBitQ factors are later derived.
///
/// @param x           input vector (d floats)
/// @param d           number of dimensions
/// @param centroid    centroid to subtract (d floats), or nullptr to use the
///                    origin
/// @param norm_L2sqr  out: sum over j of (x[j] - centroid[j])^2
/// @param or_L2sqr    out: sum over j of x[j]^2
/// @param dp_oO       out: sum over j of |x[j] - centroid[j]|
void compute_vector_intermediate_values(
        const float* x,
        size_t d,
        const float* centroid,
        float& norm_L2sqr,
        float& or_L2sqr,
        float& dp_oO) {
    norm_L2sqr = 0.0f;
    or_L2sqr = 0.0f;
    dp_oO = 0.0f;

    for (size_t dim = 0; dim < d; ++dim) {
        const float value = x[dim];
        const float center = (centroid == nullptr) ? 0.0f : centroid[dim];
        const float residual = value - center;

        norm_L2sqr += residual * residual;
        or_L2sqr += value * value;

        // The sign of the residual picks +residual or -residual, so the
        // accumulated value is the residual's magnitude.
        if (residual > 0.0f) {
            dp_oO += residual;
        } else {
            dp_oO += -residual;
        }
    }
}
|
|
68
|
+
|
|
69
|
+
/// Convert the accumulated sums of one vector into its sign-bit factors.
///
/// @param norm_L2sqr     squared L2 norm of (x - centroid)
/// @param or_L2sqr       squared L2 norm of x
/// @param dp_oO          sum of |x - centroid| over all dimensions
/// @param d              number of dimensions
/// @param metric_type    selects how or_minus_c_l2sqr and f_error are scaled
/// @param compute_error  when false, f_error is not written by this function
/// @return factors with or_minus_c_l2sqr and dp_multiplier filled in, plus
///         f_error when compute_error is true
SignBitFactorsWithError compute_factors_from_intermediates(
        float norm_L2sqr,
        float or_L2sqr,
        float dp_oO,
        size_t d,
        MetricType metric_type,
        bool compute_error) {
    constexpr float kTiny = std::numeric_limits<float>::epsilon();
    // Error bound constant from RaBitQ paper
    constexpr float kErrorBoundConst = 1.9f;

    const float residual_norm = std::sqrt(norm_L2sqr);

    // Each reciprocal falls back to 1 when its denominator is zero or below
    // the float epsilon, so no division by (near) zero takes place.
    float inv_sqrt_d = 1.0f;
    if (d != 0) {
        inv_sqrt_d = 1.0f / std::sqrt(static_cast<float>(d));
    }

    float inv_residual_norm = 1.0f;
    if (!(norm_L2sqr < kTiny)) {
        inv_residual_norm = 1.0f / residual_norm;
    }

    const float scaled_dp = dp_oO * inv_residual_norm * inv_sqrt_d;
    float inv_scaled_dp = 1.0f;
    if (!(std::abs(scaled_dp) < kTiny)) {
        inv_scaled_dp = 1.0f / scaled_dp;
    }

    const bool is_inner_product =
            (metric_type == MetricType::METRIC_INNER_PRODUCT);

    SignBitFactorsWithError result;
    if (is_inner_product) {
        result.or_minus_c_l2sqr = norm_L2sqr - or_L2sqr;
    } else {
        result.or_minus_c_l2sqr = norm_L2sqr;
    }
    result.dp_multiplier = inv_scaled_dp * residual_norm;

    // Compute error bound only if needed (skip for 1-bit mode)
    if (!compute_error) {
        return result;
    }

    const float code_norm_sqr = static_cast<float>(d) * 0.25f;
    const float half_dp = 0.5f * dp_oO;

    float error_bound = 0.0f;
    if (std::abs(half_dp) > kTiny) {
        const float ratio_sq =
                (norm_L2sqr * code_norm_sqr) / (half_dp * half_dp);
        if (ratio_sq > 1.0f) {
            const float excess = ratio_sq - 1.0f;
            // For d == 1 the (d - 1) divisor would be zero, so it is skipped.
            const float radicand = (d == 1)
                    ? excess
                    : excess / static_cast<float>(d - 1);
            error_bound =
                    residual_norm * kErrorBoundConst * std::sqrt(radicand);
        }
    }

    // Apply metric-specific multiplier
    if (metric_type == MetricType::METRIC_L2) {
        result.f_error = 2.0f * error_bound;
    } else if (is_inner_product) {
        result.f_error = 1.0f * error_bound;
    } else {
        result.f_error = 0.0f;
    }

    return result;
}
|
|
129
|
+
|
|
130
|
+
/// Convenience wrapper: accumulate the intermediate sums for x with
/// compute_vector_intermediate_values() and turn them into factors with
/// compute_factors_from_intermediates().
///
/// @param x              input vector (d floats)
/// @param d              number of dimensions
/// @param centroid       centroid to subtract, or nullptr
/// @param metric_type    forwarded to compute_factors_from_intermediates()
/// @param compute_error  forwarded to compute_factors_from_intermediates()
SignBitFactorsWithError compute_vector_factors(
        const float* x,
        size_t d,
        const float* centroid,
        MetricType metric_type,
        bool compute_error) {
    // All three accumulators are overwritten by the call below.
    float residual_norm_sqr = 0.0f;
    float vector_norm_sqr = 0.0f;
    float abs_residual_sum = 0.0f;
    compute_vector_intermediate_values(
            x,
            d,
            centroid,
            residual_norm_sqr,
            vector_norm_sqr,
            abs_residual_sum);

    return compute_factors_from_intermediates(
            residual_norm_sqr,
            vector_norm_sqr,
            abs_residual_sum,
            d,
            metric_type,
            compute_error);
}
|
|
142
|
+
|
|
143
|
+
/// Prepare a query for RaBitQ distance estimation: subtract the centroid,
/// scalar-quantize the result to qb bits per dimension and derive the
/// per-query factors.
///
/// @param query        query vector (d floats)
/// @param d            number of dimensions; must be > 0
/// @param centroid     centroid to subtract (d floats), or nullptr to use
///                     the origin
/// @param qb           bits per quantized query component, in [1, 8]
/// @param centered     if true, quantize over the symmetric range
///                     [-r, r] with r = Z_MAX_BY_QB[qb - 1] *
///                     sqrt(||query - centroid||^2 / d); otherwise over the
///                     [min, max] range of the centered query
/// @param metric_type  qr_norm_L2sqr is only computed for
///                     METRIC_INNER_PRODUCT (it is 0 otherwise)
/// @param rotated_q    out: query minus centroid, resized to d
/// @param rotated_qq   out: quantized codes of rotated_q, resized to d
/// @return the filled-in QueryFactorsData
/// @throws via FAISS_THROW_* when qb is outside [1, 8] or d == 0
QueryFactorsData compute_query_factors(
        const float* query,
        size_t d,
        const float* centroid,
        uint8_t qb,
        bool centered,
        MetricType metric_type,
        std::vector<float>& rotated_q,
        std::vector<uint8_t>& rotated_qq) {
    FAISS_THROW_IF_NOT(qb <= 8);
    FAISS_THROW_IF_NOT(qb > 0);

    QueryFactorsData query_factors;

    // Compute distance from query to centroid
    if (centroid != nullptr) {
        query_factors.qr_to_c_L2sqr = fvec_L2sqr(query, centroid, d);
    } else {
        query_factors.qr_to_c_L2sqr = fvec_norm_L2sqr(query, d);
    }
    query_factors.g_error = std::sqrt(query_factors.qr_to_c_L2sqr);

    // Center the query on the centroid.
    rotated_q.resize(d);
    for (size_t i = 0; i < d; i++) {
        rotated_q[i] = query[i] - ((centroid == nullptr) ? 0.0f : centroid[i]);
    }
    rotated_qq.resize(d);

    // After the resizes the buffers can only be empty when d == 0. Reject
    // that here, before any of the divisions by d below.
    if (d == 0) {
        FAISS_THROW_MSG(
                "Arrays unexpectedly empty when d=" + std::to_string(d) +
                " or d is incorrectly set");
    }

    const float inv_d_sqrt = 1.0f / std::sqrt(static_cast<float>(d));

    // Compute quantization range
    float v_min = std::numeric_limits<float>::max();
    float v_max = std::numeric_limits<float>::lowest();

    if (centered) {
        const float z_max = Z_MAX_BY_QB[qb - 1];
        const float v_radius =
                z_max * std::sqrt(query_factors.qr_to_c_L2sqr / d);
        v_min = -v_radius;
        v_max = v_radius;
    } else {
        for (size_t i = 0; i < d; i++) {
            const float v_q = rotated_q[i];
            v_min = std::min(v_min, v_q);
            v_max = std::max(v_max, v_q);
        }
    }

    // Quantize the query
    const uint8_t max_code = (1 << qb) - 1;
    const float delta = (v_max - v_min) / max_code;
    // A degenerate range (v_max == v_min, e.g. the query equals the centroid
    // or all its components are equal) gives delta == 0. Using 1 / delta
    // would then produce inf and 0 * inf = NaN below; map everything to
    // code 0 instead.
    const float inv_delta = (delta > 0.0f) ? (1.0f / delta) : 0.0f;
    const float max_code_f = static_cast<float>(max_code);

    size_t sum_qq = 0;
    int64_t sum2_signed_odd_int = 0;

    for (size_t i = 0; i < d; i++) {
        const float v_q = rotated_q[i];
        // Non-randomized scalar quantization
        const float scaled = std::round((v_q - v_min) * inv_delta);
        // NaN-safe clamp to [0, max_code]: the comparison is false for NaN,
        // so the float -> uint8_t conversion always sees a representable
        // value (converting NaN or an out-of-range float is undefined).
        const uint8_t v_qq = (scaled >= 0.0f)
                ? static_cast<uint8_t>(std::min(scaled, max_code_f))
                : static_cast<uint8_t>(0);
        rotated_qq[i] = v_qq;
        sum_qq += v_qq;

        if (centered) {
            // Map the code 0..max_code to the odd integer
            // -max_code, ..., -1, 1, ..., max_code.
            const int64_t signed_odd_int =
                    static_cast<int64_t>(v_qq) * 2 - max_code;
            sum2_signed_odd_int += signed_odd_int * signed_odd_int;
        }
    }

    // Compute query factors
    query_factors.c1 = 2 * delta * inv_d_sqrt;
    query_factors.c2 = 2 * v_min * inv_d_sqrt;
    query_factors.c34 = inv_d_sqrt * (delta * sum_qq + d * v_min);

    if (centered) {
        // sum2_signed_odd_int >= d > 0 here: every term is the square of an
        // odd integer.
        query_factors.int_dot_scale = std::sqrt(
                query_factors.qr_to_c_L2sqr / (sum2_signed_odd_int * d));
    } else {
        query_factors.int_dot_scale = 1.0f;
    }

    // Compute query norm for inner product metric
    query_factors.qr_norm_L2sqr = 0.0f;
    if (metric_type == MetricType::METRIC_INNER_PRODUCT) {
        query_factors.qr_norm_L2sqr = fvec_norm_L2sqr(query, d);
    }

    return query_factors;
}
|
|
251
|
+
|
|
252
|
+
/// Read one bit from a densely packed bit array: bit_index / 8 selects the
/// byte and bit_index % 8 the bit inside it, counted from the least
/// significant bit.
///
/// @param code       packed bits
/// @param bit_index  index of the bit to read
/// @return true if the bit is set
bool extract_bit_standard(const uint8_t* code, size_t bit_index) {
    const uint8_t byte = code[bit_index >> 3];
    const unsigned shift = static_cast<unsigned>(bit_index & 7);
    return ((byte >> shift) & 1u) != 0;
}
|
|
257
|
+
|
|
258
|
+
/// Read one bit from a code stored in the FastScan layout, where each
/// sub-quantizer owns 4 bits and two sub-quantizers share one byte.
///
/// @param code       packed code bytes
/// @param bit_index  index of the bit to read
/// @return true if the bit is set
bool extract_bit_fastscan(const uint8_t* code, size_t bit_index) {
    const size_t sub_quantizer = bit_index / 4;
    const size_t bit_in_nibble = bit_index % 4;

    // Even sub-quantizers live in the lower 4 bits of the byte, odd ones in
    // the upper 4 bits.
    const size_t nibble_shift = (sub_quantizer % 2 == 0) ? 0 : 4;
    const uint8_t byte = code[sub_quantizer / 2];

    return ((byte >> (nibble_shift + bit_in_nibble)) & 1u) != 0;
}
|
|
273
|
+
|
|
274
|
+
/// Set one bit in a densely packed bit array: bit_index / 8 selects the byte
/// and bit_index % 8 the bit inside it, counted from the least significant
/// bit. Other bits are left unchanged.
///
/// @param code       packed bits, modified in place
/// @param bit_index  index of the bit to set
void set_bit_standard(uint8_t* code, size_t bit_index) {
    uint8_t& target = code[bit_index >> 3];
    const unsigned shift = static_cast<unsigned>(bit_index & 7);
    target = static_cast<uint8_t>(target | (1u << shift));
}
|
|
279
|
+
|
|
280
|
+
/// Set one bit in a code stored in the FastScan layout, where each
/// sub-quantizer owns 4 bits and two sub-quantizers share one byte. Other
/// bits are left unchanged.
///
/// @param code       packed code bytes, modified in place
/// @param bit_index  index of the bit to set
void set_bit_fastscan(uint8_t* code, size_t bit_index) {
    const size_t sub_quantizer = bit_index / 4;
    const size_t bit_in_nibble = bit_index % 4;

    // Even sub-quantizers use the lower nibble, odd ones the upper nibble.
    const size_t nibble_shift = (sub_quantizer % 2 == 0) ? 0 : 4;

    uint8_t& target = code[sub_quantizer / 2];
    target = static_cast<uint8_t>(
            target | (1u << (nibble_shift + bit_in_nibble)));
}
|
|
292
|
+
|
|
293
|
+
} // namespace rabitq_utils
|
|
294
|
+
} // namespace faiss
|