faiss 0.4.3 → 0.5.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +10 -0
- data/README.md +2 -0
- data/ext/faiss/index.cpp +33 -6
- data/ext/faiss/index_binary.cpp +17 -4
- data/ext/faiss/kmeans.cpp +6 -6
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/faiss/AutoTune.cpp +2 -3
- data/vendor/faiss/faiss/AutoTune.h +1 -1
- data/vendor/faiss/faiss/Clustering.cpp +2 -2
- data/vendor/faiss/faiss/Clustering.h +2 -2
- data/vendor/faiss/faiss/IVFlib.cpp +26 -51
- data/vendor/faiss/faiss/IVFlib.h +1 -1
- data/vendor/faiss/faiss/Index.cpp +11 -0
- data/vendor/faiss/faiss/Index.h +34 -11
- data/vendor/faiss/faiss/Index2Layer.cpp +1 -1
- data/vendor/faiss/faiss/Index2Layer.h +2 -2
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +1 -0
- data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +9 -4
- data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +5 -1
- data/vendor/faiss/faiss/IndexBinary.h +7 -7
- data/vendor/faiss/faiss/IndexBinaryFromFloat.h +1 -1
- data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +8 -2
- data/vendor/faiss/faiss/IndexBinaryHNSW.h +1 -1
- data/vendor/faiss/faiss/IndexBinaryHash.cpp +3 -3
- data/vendor/faiss/faiss/IndexBinaryHash.h +5 -5
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +7 -6
- data/vendor/faiss/faiss/IndexFastScan.cpp +125 -49
- data/vendor/faiss/faiss/IndexFastScan.h +102 -7
- data/vendor/faiss/faiss/IndexFlat.cpp +374 -4
- data/vendor/faiss/faiss/IndexFlat.h +81 -1
- data/vendor/faiss/faiss/IndexHNSW.cpp +93 -2
- data/vendor/faiss/faiss/IndexHNSW.h +58 -2
- data/vendor/faiss/faiss/IndexIDMap.cpp +14 -13
- data/vendor/faiss/faiss/IndexIDMap.h +6 -6
- data/vendor/faiss/faiss/IndexIVF.cpp +1 -1
- data/vendor/faiss/faiss/IndexIVF.h +5 -5
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +1 -1
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +9 -3
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +3 -1
- data/vendor/faiss/faiss/IndexIVFFastScan.cpp +176 -90
- data/vendor/faiss/faiss/IndexIVFFastScan.h +173 -18
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +1 -0
- data/vendor/faiss/faiss/IndexIVFFlatPanorama.cpp +251 -0
- data/vendor/faiss/faiss/IndexIVFFlatPanorama.h +64 -0
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +3 -1
- data/vendor/faiss/faiss/IndexIVFPQ.h +1 -1
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +134 -2
- data/vendor/faiss/faiss/IndexIVFPQFastScan.h +7 -1
- data/vendor/faiss/faiss/IndexIVFRaBitQ.cpp +99 -8
- data/vendor/faiss/faiss/IndexIVFRaBitQ.h +4 -1
- data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.cpp +828 -0
- data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.h +252 -0
- data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +1 -1
- data/vendor/faiss/faiss/IndexIVFSpectralHash.h +1 -1
- data/vendor/faiss/faiss/IndexNNDescent.cpp +1 -1
- data/vendor/faiss/faiss/IndexNSG.cpp +1 -1
- data/vendor/faiss/faiss/IndexNeuralNetCodec.h +1 -1
- data/vendor/faiss/faiss/IndexPQ.cpp +4 -1
- data/vendor/faiss/faiss/IndexPQ.h +1 -1
- data/vendor/faiss/faiss/IndexPQFastScan.cpp +6 -2
- data/vendor/faiss/faiss/IndexPQFastScan.h +5 -1
- data/vendor/faiss/faiss/IndexPreTransform.cpp +14 -0
- data/vendor/faiss/faiss/IndexPreTransform.h +9 -0
- data/vendor/faiss/faiss/IndexRaBitQ.cpp +96 -13
- data/vendor/faiss/faiss/IndexRaBitQ.h +11 -2
- data/vendor/faiss/faiss/IndexRaBitQFastScan.cpp +731 -0
- data/vendor/faiss/faiss/IndexRaBitQFastScan.h +175 -0
- data/vendor/faiss/faiss/IndexRefine.cpp +49 -0
- data/vendor/faiss/faiss/IndexRefine.h +17 -0
- data/vendor/faiss/faiss/IndexShards.cpp +1 -1
- data/vendor/faiss/faiss/MatrixStats.cpp +3 -3
- data/vendor/faiss/faiss/MetricType.h +1 -1
- data/vendor/faiss/faiss/VectorTransform.h +2 -2
- data/vendor/faiss/faiss/clone_index.cpp +5 -1
- data/vendor/faiss/faiss/gpu/GpuCloner.cpp +1 -1
- data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +3 -1
- data/vendor/faiss/faiss/gpu/GpuIndex.h +11 -11
- data/vendor/faiss/faiss/gpu/GpuIndexBinaryCagra.h +1 -1
- data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +1 -1
- data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +11 -7
- data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +1 -1
- data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +2 -0
- data/vendor/faiss/faiss/gpu/test/TestGpuIcmEncoder.cpp +7 -0
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +1 -1
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +1 -1
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +1 -1
- data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +2 -2
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +1 -1
- data/vendor/faiss/faiss/impl/CodePacker.h +2 -2
- data/vendor/faiss/faiss/impl/DistanceComputer.h +77 -6
- data/vendor/faiss/faiss/impl/FastScanDistancePostProcessing.h +53 -0
- data/vendor/faiss/faiss/impl/HNSW.cpp +295 -16
- data/vendor/faiss/faiss/impl/HNSW.h +35 -6
- data/vendor/faiss/faiss/impl/IDSelector.cpp +2 -2
- data/vendor/faiss/faiss/impl/IDSelector.h +4 -4
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +4 -4
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +1 -1
- data/vendor/faiss/faiss/impl/LookupTableScaler.h +1 -1
- data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -1
- data/vendor/faiss/faiss/impl/NNDescent.h +2 -2
- data/vendor/faiss/faiss/impl/NSG.cpp +1 -1
- data/vendor/faiss/faiss/impl/Panorama.cpp +193 -0
- data/vendor/faiss/faiss/impl/Panorama.h +204 -0
- data/vendor/faiss/faiss/impl/PanoramaStats.cpp +33 -0
- data/vendor/faiss/faiss/impl/PanoramaStats.h +38 -0
- data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +5 -5
- data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +1 -1
- data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +1 -1
- data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +2 -0
- data/vendor/faiss/faiss/impl/ProductQuantizer.h +1 -1
- data/vendor/faiss/faiss/impl/RaBitQStats.cpp +29 -0
- data/vendor/faiss/faiss/impl/RaBitQStats.h +56 -0
- data/vendor/faiss/faiss/impl/RaBitQUtils.cpp +294 -0
- data/vendor/faiss/faiss/impl/RaBitQUtils.h +330 -0
- data/vendor/faiss/faiss/impl/RaBitQuantizer.cpp +304 -223
- data/vendor/faiss/faiss/impl/RaBitQuantizer.h +72 -4
- data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.cpp +362 -0
- data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.h +112 -0
- data/vendor/faiss/faiss/impl/ResidualQuantizer.h +1 -1
- data/vendor/faiss/faiss/impl/ResultHandler.h +4 -4
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +7 -10
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +2 -4
- data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +7 -4
- data/vendor/faiss/faiss/impl/index_read.cpp +238 -10
- data/vendor/faiss/faiss/impl/index_write.cpp +212 -19
- data/vendor/faiss/faiss/impl/io.cpp +2 -2
- data/vendor/faiss/faiss/impl/io.h +4 -4
- data/vendor/faiss/faiss/impl/kmeans1d.cpp +1 -1
- data/vendor/faiss/faiss/impl/kmeans1d.h +1 -1
- data/vendor/faiss/faiss/impl/lattice_Zn.h +2 -2
- data/vendor/faiss/faiss/impl/mapped_io.cpp +2 -2
- data/vendor/faiss/faiss/impl/mapped_io.h +4 -3
- data/vendor/faiss/faiss/impl/maybe_owned_vector.h +8 -1
- data/vendor/faiss/faiss/impl/platform_macros.h +12 -0
- data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +30 -4
- data/vendor/faiss/faiss/impl/pq4_fast_scan.h +14 -8
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +5 -6
- data/vendor/faiss/faiss/impl/simd_result_handlers.h +55 -11
- data/vendor/faiss/faiss/impl/svs_io.cpp +86 -0
- data/vendor/faiss/faiss/impl/svs_io.h +67 -0
- data/vendor/faiss/faiss/impl/zerocopy_io.h +1 -1
- data/vendor/faiss/faiss/index_factory.cpp +217 -8
- data/vendor/faiss/faiss/index_factory.h +1 -1
- data/vendor/faiss/faiss/index_io.h +1 -1
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +1 -1
- data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -1
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +115 -1
- data/vendor/faiss/faiss/invlists/InvertedLists.h +46 -0
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +1 -1
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +1 -1
- data/vendor/faiss/faiss/svs/IndexSVSFaissUtils.h +261 -0
- data/vendor/faiss/faiss/svs/IndexSVSFlat.cpp +117 -0
- data/vendor/faiss/faiss/svs/IndexSVSFlat.h +66 -0
- data/vendor/faiss/faiss/svs/IndexSVSVamana.cpp +245 -0
- data/vendor/faiss/faiss/svs/IndexSVSVamana.h +137 -0
- data/vendor/faiss/faiss/svs/IndexSVSVamanaLVQ.cpp +39 -0
- data/vendor/faiss/faiss/svs/IndexSVSVamanaLVQ.h +42 -0
- data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.cpp +149 -0
- data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.h +58 -0
- data/vendor/faiss/faiss/utils/AlignedTable.h +1 -1
- data/vendor/faiss/faiss/utils/Heap.cpp +2 -2
- data/vendor/faiss/faiss/utils/Heap.h +3 -3
- data/vendor/faiss/faiss/utils/NeuralNet.cpp +1 -1
- data/vendor/faiss/faiss/utils/NeuralNet.h +3 -3
- data/vendor/faiss/faiss/utils/approx_topk/approx_topk.h +2 -2
- data/vendor/faiss/faiss/utils/approx_topk/avx2-inl.h +2 -2
- data/vendor/faiss/faiss/utils/approx_topk/mode.h +1 -1
- data/vendor/faiss/faiss/utils/distances.cpp +0 -3
- data/vendor/faiss/faiss/utils/distances.h +2 -2
- data/vendor/faiss/faiss/utils/extra_distances-inl.h +3 -1
- data/vendor/faiss/faiss/utils/hamming-inl.h +2 -0
- data/vendor/faiss/faiss/utils/hamming.cpp +7 -6
- data/vendor/faiss/faiss/utils/hamming.h +1 -1
- data/vendor/faiss/faiss/utils/hamming_distance/common.h +1 -2
- data/vendor/faiss/faiss/utils/partitioning.cpp +5 -5
- data/vendor/faiss/faiss/utils/partitioning.h +2 -2
- data/vendor/faiss/faiss/utils/rabitq_simd.h +222 -336
- data/vendor/faiss/faiss/utils/random.cpp +1 -1
- data/vendor/faiss/faiss/utils/simdlib_avx2.h +1 -1
- data/vendor/faiss/faiss/utils/simdlib_avx512.h +1 -1
- data/vendor/faiss/faiss/utils/simdlib_neon.h +2 -2
- data/vendor/faiss/faiss/utils/transpose/transpose-avx512-inl.h +1 -1
- data/vendor/faiss/faiss/utils/utils.cpp +9 -2
- data/vendor/faiss/faiss/utils/utils.h +2 -2
- metadata +29 -1
|
@@ -37,11 +37,28 @@ struct RaBitQuantizer : Quantizer {
|
|
|
37
37
|
// possible. Thus, a quantizer has to introduce a metric.
|
|
38
38
|
MetricType metric_type = MetricType::METRIC_L2;
|
|
39
39
|
|
|
40
|
-
|
|
40
|
+
// Number of bits per dimension (1-9). Default is 1 for backward
|
|
41
|
+
// compatibility.
|
|
42
|
+
// - nb_bits = 1: standard 1-bit RaBitQ (sign bits only)
|
|
43
|
+
// - nb_bits = 2-9: multi-bit RaBitQ (1 sign bit + ex_bits extra bits)
|
|
44
|
+
size_t nb_bits = 1;
|
|
45
|
+
|
|
46
|
+
RaBitQuantizer(
|
|
47
|
+
size_t d = 0,
|
|
48
|
+
MetricType metric = MetricType::METRIC_L2,
|
|
49
|
+
size_t nb_bits = 1);
|
|
50
|
+
|
|
51
|
+
// Compute code size based on dimensionality and number of bits
|
|
52
|
+
// Returns: size in bytes for one encoded vector
|
|
53
|
+
// - nb_bits=1: (d+7)/8 + 8 bytes (1-bit codes + base factors)
|
|
54
|
+
// - nb_bits>1: (d+7)/8 + 8 + (d*ex_bits+7)/8 + 8 bytes
|
|
55
|
+
// (1-bit codes + base factors + ex-bit codes + ex factors)
|
|
56
|
+
size_t compute_code_size(size_t d, size_t num_bits) const;
|
|
41
57
|
|
|
42
58
|
void train(size_t n, const float* x) override;
|
|
43
59
|
|
|
44
|
-
// every vector is expected to take (d + 7) / 8 + sizeof(
|
|
60
|
+
// every vector is expected to take (d + 7) / 8 + sizeof(SignBitFactors)
|
|
61
|
+
// bytes,
|
|
45
62
|
void compute_codes(const float* x, uint8_t* codes, size_t n) const override;
|
|
46
63
|
|
|
47
64
|
void compute_codes_core(
|
|
@@ -71,8 +88,59 @@ struct RaBitQuantizer : Quantizer {
|
|
|
71
88
|
// specify qb = 0 to get an DC that does not quantize a query
|
|
72
89
|
// specify qb > 0 to have SQ qb-bits query
|
|
73
90
|
FlatCodesDistanceComputer* get_distance_computer(
|
|
74
|
-
uint8_t qb,
|
|
75
|
-
const float*
|
|
91
|
+
uint8_t qb = 0,
|
|
92
|
+
const float* centroid = nullptr,
|
|
93
|
+
bool centered = false) const;
|
|
94
|
+
};
|
|
95
|
+
|
|
96
|
+
// RaBitQDistanceComputer: Base class for RaBitQ distance computers
|
|
97
|
+
//
|
|
98
|
+
// This intermediate class exists to provide a unified interface for
|
|
99
|
+
// two-stage multi-bit search. While most Faiss quantizers extend
|
|
100
|
+
// FlatCodesDistanceComputer directly, RaBitQ requires this additional
|
|
101
|
+
// abstraction layer due to its unique split encoding strategy
|
|
102
|
+
// (1 sign bit + magnitude bits) which enables:
|
|
103
|
+
//
|
|
104
|
+
// 1. distance_to_code_1bit() - Fast 1-bit filtering using only sign bits
|
|
105
|
+
// 2. distance_to_code_full() - Accurate multi-bit refinement using all bits
|
|
106
|
+
// 3. lower_bound_distance() - Error-bounded adaptive filtering
|
|
107
|
+
// (based on 1-bit estimator)
|
|
108
|
+
//
|
|
109
|
+
// These three methods implement RaBitQ's two-stage search pattern and are
|
|
110
|
+
// shared between the quantized (Q) and non-quantized (NotQ) query variants.
|
|
111
|
+
// The intermediate class allows two-stage search code to work with both
|
|
112
|
+
// variants via a single dynamic_cast.
|
|
113
|
+
struct RaBitQDistanceComputer : FlatCodesDistanceComputer {
|
|
114
|
+
size_t d = 0;
|
|
115
|
+
const float* centroid = nullptr;
|
|
116
|
+
MetricType metric_type = MetricType::METRIC_L2;
|
|
117
|
+
size_t nb_bits = 1;
|
|
118
|
+
|
|
119
|
+
// Query norm for lower bound computation (g_error in rabitq-library)
|
|
120
|
+
// This is the L2 norm of the rotated query: ||query - centroid||
|
|
121
|
+
float g_error = 0.0f;
|
|
122
|
+
|
|
123
|
+
float symmetric_dis(idx_t /*i*/, idx_t /*j*/) override {
|
|
124
|
+
// Not used for RaBitQ
|
|
125
|
+
FAISS_THROW_MSG("Not implemented");
|
|
126
|
+
}
|
|
127
|
+
|
|
128
|
+
// Compute 1-bit distance estimate (fast)
|
|
129
|
+
virtual float distance_to_code_1bit(const uint8_t* code) = 0;
|
|
130
|
+
|
|
131
|
+
// Compute full multi-bit distance (accurate)
|
|
132
|
+
virtual float distance_to_code_full(const uint8_t* code) = 0;
|
|
133
|
+
|
|
134
|
+
// Compute lower bound of distance using error bounds
|
|
135
|
+
// Guarantees: actual_distance >= lower_bound_distance
|
|
136
|
+
// Used for adaptive filtering in two-stage search
|
|
137
|
+
virtual float lower_bound_distance(const uint8_t* code);
|
|
138
|
+
|
|
139
|
+
// Override from FlatCodesDistanceComputer
|
|
140
|
+
// Delegates to distance_to_code_full() for multi-bit distance computation
|
|
141
|
+
float distance_to_code(const uint8_t* code) final {
|
|
142
|
+
return distance_to_code_full(code);
|
|
143
|
+
}
|
|
76
144
|
};
|
|
77
145
|
|
|
78
146
|
} // namespace faiss
|
|
@@ -0,0 +1,362 @@
|
|
|
1
|
+
/*
|
|
2
|
+
* Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
3
|
+
*
|
|
4
|
+
* This source code is licensed under the MIT license found in the
|
|
5
|
+
* LICENSE file in the root directory of this source tree.
|
|
6
|
+
*/
|
|
7
|
+
|
|
8
|
+
// NOTE: Parts of this implementation are adapted from:
|
|
9
|
+
// RaBitQ-Library/include/rabitqlib/quantization/rabitq_impl.hpp
|
|
10
|
+
// https://github.com/VectorDB-NTU/RaBitQ-Library
|
|
11
|
+
|
|
12
|
+
#include <faiss/impl/FaissAssert.h>
|
|
13
|
+
#include <faiss/impl/RaBitQUtils.h>
|
|
14
|
+
#include <faiss/utils/distances.h>
|
|
15
|
+
|
|
16
|
+
#include <algorithm>
|
|
17
|
+
#include <cmath>
|
|
18
|
+
#include <cstring>
|
|
19
|
+
#include <queue>
|
|
20
|
+
#include <vector>
|
|
21
|
+
|
|
22
|
+
namespace faiss {
|
|
23
|
+
namespace rabitq_multibit {
|
|
24
|
+
|
|
25
|
+
using rabitq_utils::ExtraBitsFactors;
|
|
26
|
+
using rabitq_utils::SignBitFactorsWithError;
|
|
27
|
+
|
|
28
|
+
constexpr float kTightStart[9] =
|
|
29
|
+
{0.0f, 0.15f, 0.20f, 0.52f, 0.59f, 0.71f, 0.75f, 0.77f, 0.81f};
|
|
30
|
+
|
|
31
|
+
constexpr double kEps = 1e-5;
|
|
32
|
+
|
|
33
|
+
/**
|
|
34
|
+
* Compute optimal scaling factor for ex-bits quantization using priority
|
|
35
|
+
* queue-based search.
|
|
36
|
+
*
|
|
37
|
+
* This function finds the optimal scaling factor 't' that maximizes the
|
|
38
|
+
* inner product between the normalized quantized vector and the normalized
|
|
39
|
+
* absolute residual. The algorithm uses a priority queue to efficiently
|
|
40
|
+
* explore different quantization levels.
|
|
41
|
+
*
|
|
42
|
+
*
|
|
43
|
+
* @param o_abs Normalized absolute residual vector (must be positive, length
|
|
44
|
+
* d)
|
|
45
|
+
* @param d Dimensionality of the vector
|
|
46
|
+
* @param nb_bits Number of bits per dimension (2-9)
|
|
47
|
+
* @return Optimal scaling factor 't'
|
|
48
|
+
*/
|
|
49
|
+
float compute_optimal_scaling_factor(
|
|
50
|
+
const float* o_abs,
|
|
51
|
+
size_t d,
|
|
52
|
+
size_t nb_bits) {
|
|
53
|
+
const size_t ex_bits = nb_bits - 1;
|
|
54
|
+
FAISS_THROW_IF_NOT_MSG(
|
|
55
|
+
ex_bits >= 1 && ex_bits <= 8, "ex_bits must be in range [1, 8]");
|
|
56
|
+
|
|
57
|
+
const int kNEnum = 10;
|
|
58
|
+
const int max_code = (1 << ex_bits) - 1;
|
|
59
|
+
|
|
60
|
+
float max_o = *std::max_element(o_abs, o_abs + d);
|
|
61
|
+
|
|
62
|
+
// Determine search range [t_start, t_end]
|
|
63
|
+
float t_end = static_cast<float>(max_code + kNEnum) / max_o;
|
|
64
|
+
float t_start = t_end * kTightStart[ex_bits];
|
|
65
|
+
|
|
66
|
+
std::vector<float> inv_o_abs(d);
|
|
67
|
+
for (size_t i = 0; i < d; ++i) {
|
|
68
|
+
inv_o_abs[i] = 1.0f / o_abs[i];
|
|
69
|
+
}
|
|
70
|
+
|
|
71
|
+
std::vector<int> cur_o_bar(d);
|
|
72
|
+
float sqr_denominator = static_cast<float>(d) * 0.25f;
|
|
73
|
+
float numerator = 0.0f;
|
|
74
|
+
|
|
75
|
+
for (size_t i = 0; i < d; ++i) {
|
|
76
|
+
int cur = static_cast<int>((t_start * o_abs[i]) + kEps);
|
|
77
|
+
cur_o_bar[i] = cur;
|
|
78
|
+
sqr_denominator += static_cast<float>(cur * cur + cur);
|
|
79
|
+
numerator += (cur + 0.5f) * o_abs[i];
|
|
80
|
+
}
|
|
81
|
+
|
|
82
|
+
float inv_sqrt_denom = 1.0f / std::sqrt(sqr_denominator);
|
|
83
|
+
|
|
84
|
+
// Pair: (next_t, dimension_index)
|
|
85
|
+
// Maximum size is d (one entry per dimension), so reserve exactly d
|
|
86
|
+
std::vector<std::pair<float, size_t>> pq_storage;
|
|
87
|
+
pq_storage.reserve(d);
|
|
88
|
+
std::priority_queue<
|
|
89
|
+
std::pair<float, size_t>,
|
|
90
|
+
std::vector<std::pair<float, size_t>>,
|
|
91
|
+
std::greater<>>
|
|
92
|
+
next_t(std::greater<>(), std::move(pq_storage));
|
|
93
|
+
|
|
94
|
+
// Initialize queue with next quantization level for each dimension
|
|
95
|
+
for (size_t i = 0; i < d; ++i) {
|
|
96
|
+
float t_next = static_cast<float>(cur_o_bar[i] + 1) * inv_o_abs[i];
|
|
97
|
+
if (t_next < t_end) {
|
|
98
|
+
next_t.emplace(t_next, i);
|
|
99
|
+
}
|
|
100
|
+
}
|
|
101
|
+
|
|
102
|
+
float max_ip = 0.0f;
|
|
103
|
+
float t = 0.0f;
|
|
104
|
+
|
|
105
|
+
while (!next_t.empty()) {
|
|
106
|
+
float cur_t = next_t.top().first;
|
|
107
|
+
size_t update_id = next_t.top().second;
|
|
108
|
+
next_t.pop();
|
|
109
|
+
|
|
110
|
+
cur_o_bar[update_id]++;
|
|
111
|
+
int update_o_bar = cur_o_bar[update_id];
|
|
112
|
+
|
|
113
|
+
float delta = 2.0f * update_o_bar;
|
|
114
|
+
sqr_denominator += delta;
|
|
115
|
+
numerator += o_abs[update_id];
|
|
116
|
+
|
|
117
|
+
float old_denom = sqr_denominator - delta;
|
|
118
|
+
inv_sqrt_denom = inv_sqrt_denom *
|
|
119
|
+
(1.0f - 0.5f * delta / (old_denom + delta * 0.5f));
|
|
120
|
+
|
|
121
|
+
float cur_ip = numerator * inv_sqrt_denom;
|
|
122
|
+
|
|
123
|
+
if (cur_ip > max_ip) {
|
|
124
|
+
max_ip = cur_ip;
|
|
125
|
+
t = cur_t;
|
|
126
|
+
}
|
|
127
|
+
|
|
128
|
+
if (update_o_bar < max_code) {
|
|
129
|
+
float t_next =
|
|
130
|
+
static_cast<float>(update_o_bar + 1) * inv_o_abs[update_id];
|
|
131
|
+
if (t_next < t_end) {
|
|
132
|
+
next_t.emplace(t_next, update_id);
|
|
133
|
+
}
|
|
134
|
+
}
|
|
135
|
+
}
|
|
136
|
+
|
|
137
|
+
return t;
|
|
138
|
+
}
|
|
139
|
+
|
|
140
|
+
/**
|
|
141
|
+
* Pack multi-bit codes from integer array to byte array.
|
|
142
|
+
*
|
|
143
|
+
* @param tmp_code Integer codes (length d), each value in [0, 2^ex_bits - 1]
|
|
144
|
+
* @param ex_code Output packed byte array
|
|
145
|
+
* @param d Dimensionality
|
|
146
|
+
* @param nb_bits Number of bits per dimension (2-9)
|
|
147
|
+
*/
|
|
148
|
+
void pack_multibit_codes(
|
|
149
|
+
const int* tmp_code,
|
|
150
|
+
uint8_t* ex_code,
|
|
151
|
+
size_t d,
|
|
152
|
+
size_t nb_bits) {
|
|
153
|
+
const size_t ex_bits = nb_bits - 1;
|
|
154
|
+
FAISS_THROW_IF_NOT_MSG(
|
|
155
|
+
ex_bits >= 1 && ex_bits <= 8, "ex_bits must be in range [1, 8]");
|
|
156
|
+
|
|
157
|
+
size_t total_bits = d * ex_bits;
|
|
158
|
+
size_t output_size = (total_bits + 7) / 8;
|
|
159
|
+
memset(ex_code, 0, output_size);
|
|
160
|
+
|
|
161
|
+
size_t bit_pos = 0;
|
|
162
|
+
for (size_t i = 0; i < d; i++) {
|
|
163
|
+
int code_value = tmp_code[i];
|
|
164
|
+
|
|
165
|
+
for (size_t bit = 0; bit < ex_bits; bit++) {
|
|
166
|
+
size_t byte_idx = bit_pos / 8;
|
|
167
|
+
size_t bit_idx = bit_pos % 8;
|
|
168
|
+
|
|
169
|
+
if (code_value & (1 << bit)) {
|
|
170
|
+
ex_code[byte_idx] |= (1 << bit_idx);
|
|
171
|
+
}
|
|
172
|
+
|
|
173
|
+
bit_pos++;
|
|
174
|
+
}
|
|
175
|
+
}
|
|
176
|
+
}
|
|
177
|
+
|
|
178
|
+
/**
|
|
179
|
+
* Compute ex-bits factors for distance computation.
|
|
180
|
+
*
|
|
181
|
+
* @param residual Original residual vector (data - centroid)
|
|
182
|
+
* @param centroid Centroid vector (can be nullptr for zero centroid)
|
|
183
|
+
* @param tmp_code Quantized ex-bit codes (before packing, after bit flipping)
|
|
184
|
+
* @param d Dimensionality
|
|
185
|
+
* @param ex_bits Number of extra bits
|
|
186
|
+
* @param norm L2 norm of residual
|
|
187
|
+
* @param ipnorm Unnormalized inner product between quantized and normalized
|
|
188
|
+
* residual
|
|
189
|
+
* @param ex_factors Output factors structure
|
|
190
|
+
* @param metric_type Distance metric (L2 or Inner Product)
|
|
191
|
+
*/
|
|
192
|
+
void compute_ex_factors(
|
|
193
|
+
const float* residual,
|
|
194
|
+
const float* centroid,
|
|
195
|
+
const int* tmp_code,
|
|
196
|
+
size_t d,
|
|
197
|
+
size_t ex_bits,
|
|
198
|
+
float norm,
|
|
199
|
+
double ipnorm,
|
|
200
|
+
ExtraBitsFactors& ex_factors,
|
|
201
|
+
MetricType metric_type) {
|
|
202
|
+
FAISS_THROW_IF_NOT_MSG(
|
|
203
|
+
metric_type == MetricType::METRIC_L2 ||
|
|
204
|
+
metric_type == MetricType::METRIC_INNER_PRODUCT,
|
|
205
|
+
"Unsupported metric type");
|
|
206
|
+
|
|
207
|
+
// Compute ipnorm_inv = 1 / ipnorm
|
|
208
|
+
float ipnorm_inv = static_cast<float>(1.0 / ipnorm);
|
|
209
|
+
if (!std::isnormal(ipnorm_inv)) {
|
|
210
|
+
ipnorm_inv = 1.0f;
|
|
211
|
+
}
|
|
212
|
+
|
|
213
|
+
// Reconstruct xu_cb from total_code
|
|
214
|
+
// total_code was formed from: total_code[i] = (sign << ex_bits) +
|
|
215
|
+
// ex_code[i] Reconstruction: xu_cb[i] = total_code[i] + cb
|
|
216
|
+
const float cb = -(static_cast<float>(1 << ex_bits) - 0.5f);
|
|
217
|
+
std::vector<float> xu_cb(d);
|
|
218
|
+
for (size_t i = 0; i < d; i++) {
|
|
219
|
+
xu_cb[i] = static_cast<float>(tmp_code[i]) + cb;
|
|
220
|
+
}
|
|
221
|
+
|
|
222
|
+
// Compute inner products needed for factors
|
|
223
|
+
float l2_sqr = norm * norm;
|
|
224
|
+
float ip_resi_xucb = fvec_inner_product(residual, xu_cb.data(), d);
|
|
225
|
+
|
|
226
|
+
// Compute factors
|
|
227
|
+
if (metric_type == MetricType::METRIC_L2) {
|
|
228
|
+
// For L2, no centroid correction needed in IVF setting
|
|
229
|
+
// because residual = x - centroid, distance computed in residual space
|
|
230
|
+
ex_factors.f_add_ex = l2_sqr;
|
|
231
|
+
ex_factors.f_rescale_ex = ipnorm_inv * -2.0f * norm;
|
|
232
|
+
} else {
|
|
233
|
+
// For IP, centroid correction is needed
|
|
234
|
+
float ip_resi_cent = 0;
|
|
235
|
+
if (centroid != nullptr) {
|
|
236
|
+
ip_resi_cent = fvec_inner_product(residual, centroid, d);
|
|
237
|
+
}
|
|
238
|
+
|
|
239
|
+
float ip_cent_xucb = 0;
|
|
240
|
+
if (centroid != nullptr) {
|
|
241
|
+
ip_cent_xucb = fvec_inner_product(centroid, xu_cb.data(), d);
|
|
242
|
+
}
|
|
243
|
+
|
|
244
|
+
// When ip_resi_xucb is zero, the correction term should be zero
|
|
245
|
+
float correction_term = 0.0f;
|
|
246
|
+
if (ip_resi_xucb != 0.0f) {
|
|
247
|
+
correction_term = l2_sqr * ip_cent_xucb / ip_resi_xucb;
|
|
248
|
+
}
|
|
249
|
+
|
|
250
|
+
ex_factors.f_add_ex = 1 - ip_resi_cent + correction_term;
|
|
251
|
+
ex_factors.f_rescale_ex = ipnorm_inv * -norm;
|
|
252
|
+
}
|
|
253
|
+
}
|
|
254
|
+
|
|
255
|
+
/**
|
|
256
|
+
* Quantize residual vector to ex-bits.
|
|
257
|
+
*
|
|
258
|
+
* This is the main quantization function that:
|
|
259
|
+
* 1. Normalizes the residual
|
|
260
|
+
* 2. Takes absolute value
|
|
261
|
+
* 3. Finds optimal scaling factor
|
|
262
|
+
* 4. Quantizes to ex_bits
|
|
263
|
+
* 5. Handles negative dimensions by flipping bits
|
|
264
|
+
* 6. Packs codes into byte array
|
|
265
|
+
* 7. Computes factors for distance computation
|
|
266
|
+
*
|
|
267
|
+
* @param residual Input residual vector (data - centroid), length d
|
|
268
|
+
* @param d Dimensionality
|
|
269
|
+
* @param nb_bits Number of bits per dimension (2-9)
|
|
270
|
+
* @param ex_code Output packed ex-bit codes
|
|
271
|
+
* @param ex_factors Output ex-bits factors
|
|
272
|
+
* @param metric_type Distance metric (L2 or Inner Product)
|
|
273
|
+
* @param centroid Optional centroid vector (needed for IP metric)
|
|
274
|
+
*/
|
|
275
|
+
void quantize_ex_bits(
|
|
276
|
+
const float* residual,
|
|
277
|
+
size_t d,
|
|
278
|
+
size_t nb_bits,
|
|
279
|
+
uint8_t* ex_code,
|
|
280
|
+
ExtraBitsFactors& ex_factors,
|
|
281
|
+
MetricType metric_type,
|
|
282
|
+
const float* centroid) {
|
|
283
|
+
const size_t ex_bits = nb_bits - 1;
|
|
284
|
+
FAISS_THROW_IF_NOT_MSG(
|
|
285
|
+
ex_bits >= 1 && ex_bits <= 8, "ex_bits must be in range [1, 8]");
|
|
286
|
+
FAISS_THROW_IF_NOT_MSG(residual != nullptr, "residual cannot be null");
|
|
287
|
+
FAISS_THROW_IF_NOT_MSG(ex_code != nullptr, "ex_code cannot be null");
|
|
288
|
+
|
|
289
|
+
// Step 1: Compute L2 norm of residual
|
|
290
|
+
float norm_sqr = fvec_norm_L2sqr(residual, d);
|
|
291
|
+
float norm = std::sqrt(norm_sqr);
|
|
292
|
+
|
|
293
|
+
// Handle degenerate case
|
|
294
|
+
if (norm < 1e-10f) {
|
|
295
|
+
size_t code_size = (d * ex_bits + 7) / 8;
|
|
296
|
+
memset(ex_code, 0, code_size);
|
|
297
|
+
ex_factors.f_add_ex = 0.0f;
|
|
298
|
+
ex_factors.f_rescale_ex = 1.0f;
|
|
299
|
+
return;
|
|
300
|
+
}
|
|
301
|
+
|
|
302
|
+
// Step 2: Normalize residual
|
|
303
|
+
std::vector<float> normalized_residual(d);
|
|
304
|
+
for (size_t i = 0; i < d; i++) {
|
|
305
|
+
normalized_residual[i] = residual[i] / norm;
|
|
306
|
+
}
|
|
307
|
+
|
|
308
|
+
// Step 3: Take absolute value
|
|
309
|
+
std::vector<float> o_abs(d);
|
|
310
|
+
for (size_t i = 0; i < d; i++) {
|
|
311
|
+
o_abs[i] = std::abs(normalized_residual[i]);
|
|
312
|
+
}
|
|
313
|
+
|
|
314
|
+
// Step 4: Find optimal scaling factor
|
|
315
|
+
float t = compute_optimal_scaling_factor(o_abs.data(), d, nb_bits);
|
|
316
|
+
|
|
317
|
+
// Step 5: Quantize to ex_bits
|
|
318
|
+
std::vector<int> tmp_code(d);
|
|
319
|
+
double ipnorm = 0;
|
|
320
|
+
int max_code = (1 << ex_bits) - 1;
|
|
321
|
+
|
|
322
|
+
for (size_t i = 0; i < d; i++) {
|
|
323
|
+
tmp_code[i] = std::min(static_cast<int>(t * o_abs[i] + kEps), max_code);
|
|
324
|
+
// Compute unnormalized inner product
|
|
325
|
+
ipnorm += (tmp_code[i] + 0.5) * o_abs[i];
|
|
326
|
+
}
|
|
327
|
+
|
|
328
|
+
// Step 6: Handle negative dimensions (flip bits)
|
|
329
|
+
// For negative residuals, flip all bits: code' = ~code & max_code
|
|
330
|
+
for (size_t i = 0; i < d; i++) {
|
|
331
|
+
if (residual[i] < 0) {
|
|
332
|
+
tmp_code[i] = (~tmp_code[i]) & max_code;
|
|
333
|
+
}
|
|
334
|
+
}
|
|
335
|
+
|
|
336
|
+
// Step 7: Pack codes into byte array
|
|
337
|
+
pack_multibit_codes(tmp_code.data(), ex_code, d, nb_bits);
|
|
338
|
+
|
|
339
|
+
// Step 8: Compute factors for distance computation
|
|
340
|
+
// Reconstruct total_code for factor computation
|
|
341
|
+
std::vector<int> total_code(d);
|
|
342
|
+
for (size_t i = 0; i < d; i++) {
|
|
343
|
+
// Form total_code = (sign << ex_bits) + ex_code
|
|
344
|
+
bool sign_bit = (residual[i] >= 0);
|
|
345
|
+
total_code[i] = tmp_code[i] + ((sign_bit ? 1 : 0) << ex_bits);
|
|
346
|
+
}
|
|
347
|
+
|
|
348
|
+
// Compute ex-factors; centroid is needed for IP metric correction
|
|
349
|
+
compute_ex_factors(
|
|
350
|
+
residual,
|
|
351
|
+
centroid, // Pass centroid for IP metric factor computation
|
|
352
|
+
total_code.data(),
|
|
353
|
+
d,
|
|
354
|
+
ex_bits,
|
|
355
|
+
norm,
|
|
356
|
+
ipnorm,
|
|
357
|
+
ex_factors,
|
|
358
|
+
metric_type);
|
|
359
|
+
}
|
|
360
|
+
|
|
361
|
+
} // namespace rabitq_multibit
|
|
362
|
+
} // namespace faiss
|
|
@@ -0,0 +1,112 @@
|
|
|
1
|
+
/*
|
|
2
|
+
* Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
3
|
+
*
|
|
4
|
+
* This source code is licensed under the MIT license found in the
|
|
5
|
+
* LICENSE file in the root directory of this source tree.
|
|
6
|
+
*/
|
|
7
|
+
|
|
8
|
+
// Reference:
|
|
9
|
+
// "Practical and asymptotically optimal quantization of high-dimensional
|
|
10
|
+
// vectors in euclidean space for approximate nearest neighbor search"
|
|
11
|
+
// Jianyang Gao, Yutong Gou, Yuexuan Xu, Yongyi Yang, Cheng Long, Raymond
|
|
12
|
+
// Chi-Wing Wong https://dl.acm.org/doi/pdf/10.1145/3725413
|
|
13
|
+
//
|
|
14
|
+
// Reference implementation: https://github.com/VectorDB-NTU/RaBitQ-Library
|
|
15
|
+
// NOTE: Parts of this implementation are adapted from
|
|
16
|
+
// rabitqlib/quantization/rabitq_impl.hpp in the above repository.
|
|
17
|
+
|
|
18
|
+
#pragma once
|
|
19
|
+
|
|
20
|
+
#include <faiss/MetricType.h>
|
|
21
|
+
#include <faiss/impl/RaBitQUtils.h>
|
|
22
|
+
#include <cstddef>
|
|
23
|
+
#include <cstdint>
|
|
24
|
+
|
|
25
|
+
namespace faiss {
|
|
26
|
+
namespace rabitq_multibit {
|
|
27
|
+
|
|
28
|
+
/**
|
|
29
|
+
* Compute optimal scaling factor for ex-bits quantization.
|
|
30
|
+
*
|
|
31
|
+
* Uses priority queue-based search to find the scaling factor that
|
|
32
|
+
* maximizes the inner product between quantized and original vectors.
|
|
33
|
+
*
|
|
34
|
+
* @param o_abs Normalized absolute residual vector (positive values)
|
|
35
|
+
* @param d Dimensionality
|
|
36
|
+
* @param nb_bits Number of bits per dimension (2-9)
|
|
37
|
+
* @return Optimal scaling factor 't'
|
|
38
|
+
*/
|
|
39
|
+
float compute_optimal_scaling_factor(
|
|
40
|
+
const float* o_abs,
|
|
41
|
+
size_t d,
|
|
42
|
+
size_t nb_bits);
|
|
43
|
+
|
|
44
|
+
/**
|
|
45
|
+
* Pack multi-bit codes from integer array to byte array.
|
|
46
|
+
*
|
|
47
|
+
* @param tmp_code Integer codes (length d), values in [0, 2^ex_bits - 1]
|
|
48
|
+
* @param ex_code Output packed byte array
|
|
49
|
+
* @param d Dimensionality
|
|
50
|
+
* @param nb_bits Number of bits per dimension (2-9)
|
|
51
|
+
*/
|
|
52
|
+
void pack_multibit_codes(
|
|
53
|
+
const int* tmp_code,
|
|
54
|
+
uint8_t* ex_code,
|
|
55
|
+
size_t d,
|
|
56
|
+
size_t nb_bits);
|
|
57
|
+
|
|
58
|
+
/**
|
|
59
|
+
* Compute ex-bits factors for distance computation.
|
|
60
|
+
*
|
|
61
|
+
* @param residual Original residual vector (data - centroid)
|
|
62
|
+
* @param centroid Centroid vector (can be nullptr for zero centroid)
|
|
63
|
+
* @param tmp_code Quantized ex-bit codes (unpacked integers)
|
|
64
|
+
* @param d Dimensionality
|
|
65
|
+
* @param ex_bits Number of extra bits
|
|
66
|
+
* @param norm L2 norm of residual
|
|
67
|
+
* @param ipnorm Unnormalized inner product
|
|
68
|
+
* @param ex_factors Output factors structure
|
|
69
|
+
* @param metric_type Distance metric (L2 or IP)
|
|
70
|
+
*/
|
|
71
|
+
void compute_ex_factors(
|
|
72
|
+
const float* residual,
|
|
73
|
+
const float* centroid,
|
|
74
|
+
const int* tmp_code,
|
|
75
|
+
size_t d,
|
|
76
|
+
size_t ex_bits,
|
|
77
|
+
float norm,
|
|
78
|
+
double ipnorm,
|
|
79
|
+
rabitq_utils::ExtraBitsFactors& ex_factors,
|
|
80
|
+
MetricType metric_type);
|
|
81
|
+
|
|
82
|
+
/**
|
|
83
|
+
* Main quantization function: quantize residual vector to ex-bits.
|
|
84
|
+
*
|
|
85
|
+
* Performs the complete multi-bit quantization pipeline:
|
|
86
|
+
* 1. Normalize residual
|
|
87
|
+
* 2. Take absolute value
|
|
88
|
+
* 3. Find optimal scaling factor
|
|
89
|
+
* 4. Quantize to ex_bits
|
|
90
|
+
* 5. Handle negative dimensions by bit flipping
|
|
91
|
+
* 6. Pack codes into byte array
|
|
92
|
+
* 7. Compute factors for distance computation
|
|
93
|
+
*
|
|
94
|
+
* @param residual Input residual vector (data - centroid), length d
|
|
95
|
+
* @param d Dimensionality
|
|
96
|
+
* @param nb_bits Number of bits per dimension (2-9)
|
|
97
|
+
* @param ex_code Output packed ex-bit codes
|
|
98
|
+
* @param ex_factors Output ex-bits factors
|
|
99
|
+
* @param metric_type Distance metric (L2 or Inner Product)
|
|
100
|
+
* @param centroid Optional centroid vector (needed for IP metric)
|
|
101
|
+
*/
|
|
102
|
+
void quantize_ex_bits(
|
|
103
|
+
const float* residual,
|
|
104
|
+
size_t d,
|
|
105
|
+
size_t nb_bits,
|
|
106
|
+
uint8_t* ex_code,
|
|
107
|
+
rabitq_utils::ExtraBitsFactors& ex_factors,
|
|
108
|
+
MetricType metric_type,
|
|
109
|
+
const float* centroid = nullptr);
|
|
110
|
+
|
|
111
|
+
} // namespace rabitq_multibit
|
|
112
|
+
} // namespace faiss
|
|
@@ -49,7 +49,7 @@ struct ResidualQuantizer : AdditiveQuantizer {
|
|
|
49
49
|
* first element of the beam (faster but less accurate) */
|
|
50
50
|
static const int Train_top_beam = 1024;
|
|
51
51
|
|
|
52
|
-
/** set this bit to *not*
|
|
52
|
+
/** set this bit to *not* automatically compute the codebook tables
|
|
53
53
|
* after training */
|
|
54
54
|
static const int Skip_codebook_tables = 2048;
|
|
55
55
|
|
|
@@ -26,11 +26,11 @@ namespace faiss {
|
|
|
26
26
|
* The classes below are intended to be used as template arguments
|
|
27
27
|
* they handle results for batches of queries (size nq).
|
|
28
28
|
* They can be called in two ways:
|
|
29
|
-
* - by
|
|
29
|
+
* - by instantiating a SingleResultHandler that tracks results for a single
|
|
30
30
|
* query
|
|
31
31
|
* - with begin_multiple/add_results/end_multiple calls where a whole block of
|
|
32
32
|
* results is submitted
|
|
33
|
-
* All classes are templated on C which to define
|
|
33
|
+
* All classes are templated on C which to define whether the min or the max of
|
|
34
34
|
* results is to be kept, and on sel, so that the codepaths for with / without
|
|
35
35
|
* selector can be separated at compile time.
|
|
36
36
|
*****************************************************************/
|
|
@@ -306,7 +306,7 @@ struct HeapBlockResultHandler : TopkBlockResultHandler<C, use_sel> {
|
|
|
306
306
|
*
|
|
307
307
|
* A reservoir is a result array of size capacity > n (number of requested
|
|
308
308
|
* results) all results below a threshold are stored in an arbitrary order.
|
|
309
|
-
*When the capacity is reached, a new threshold is chosen by
|
|
309
|
+
*When the capacity is reached, a new threshold is chosen by partitioning
|
|
310
310
|
*the distance array.
|
|
311
311
|
*****************************************************************/
|
|
312
312
|
|
|
@@ -572,7 +572,7 @@ struct RangeSearchBlockResultHandler : BlockResultHandler<C, use_sel> {
|
|
|
572
572
|
RangeSearchPartialResult* pres;
|
|
573
573
|
// there is one RangeSearchPartialResult structure per j0
|
|
574
574
|
// (= block of columns of the large distance matrix)
|
|
575
|
-
// it is a bit tricky to find the
|
|
575
|
+
// it is a bit tricky to find the proper PartialResult structure
|
|
576
576
|
// because the inner loop is on db not on queries.
|
|
577
577
|
|
|
578
578
|
if (pr < j0s.size() && j0 == j0s[pr]) {
|
|
@@ -321,7 +321,7 @@ struct Codec6bit {
|
|
|
321
321
|
static FAISS_ALWAYS_INLINE __m256
|
|
322
322
|
decode_8_components(const uint8_t* code, int i) {
|
|
323
323
|
// // Faster code for Intel CPUs or AMD Zen3+, just keeping it here
|
|
324
|
-
// // for the reference, maybe, it becomes used
|
|
324
|
+
// // for the reference, maybe, it becomes used one day.
|
|
325
325
|
// const uint16_t* data16 = (const uint16_t*)(code + (i >> 2) * 3);
|
|
326
326
|
// const uint32_t* data32 = (const uint32_t*)data16;
|
|
327
327
|
// const uint64_t val = *data32 + ((uint64_t)data16[2] << 32);
|
|
@@ -1009,16 +1009,13 @@ void train_Uniform(
|
|
|
1009
1009
|
} else if (rs == ScalarQuantizer::RS_quantiles) {
|
|
1010
1010
|
std::vector<float> x_copy(n);
|
|
1011
1011
|
memcpy(x_copy.data(), x, n * sizeof(*x));
|
|
1012
|
-
|
|
1013
|
-
|
|
1014
|
-
|
|
1015
|
-
|
|
1016
|
-
o = 0;
|
|
1017
|
-
}
|
|
1018
|
-
if (o > n - o) {
|
|
1019
|
-
o = n / 2;
|
|
1020
|
-
}
|
|
1012
|
+
int temp = int(rs_arg * n);
|
|
1013
|
+
int o = temp < 0 ? 0 : (temp > n / 2 ? n / 2 : temp);
|
|
1014
|
+
|
|
1015
|
+
std::nth_element(x_copy.begin(), x_copy.begin() + o, x_copy.end());
|
|
1021
1016
|
vmin = x_copy[o];
|
|
1017
|
+
std::nth_element(
|
|
1018
|
+
x_copy.begin(), x_copy.begin() + (n - 1 - o), x_copy.end());
|
|
1022
1019
|
vmax = x_copy[n - 1 - o];
|
|
1023
1020
|
|
|
1024
1021
|
} else if (rs == ScalarQuantizer::RS_optim) {
|
|
@@ -40,7 +40,7 @@ struct ScalarQuantizer : Quantizer {
|
|
|
40
40
|
QuantizerType qtype = QT_8bit;
|
|
41
41
|
|
|
42
42
|
/** The uniform encoder can estimate the range of representable
|
|
43
|
-
* values of the
|
|
43
|
+
* values of the uniform encoder using different statistics. Here
|
|
44
44
|
* rs = rangestat_arg */
|
|
45
45
|
|
|
46
46
|
// rangestat_arg.
|
|
@@ -98,9 +98,7 @@ struct ScalarQuantizer : Quantizer {
|
|
|
98
98
|
SQuantizer* select_quantizer() const;
|
|
99
99
|
|
|
100
100
|
struct SQDistanceComputer : FlatCodesDistanceComputer {
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
SQDistanceComputer() : q(nullptr) {}
|
|
101
|
+
SQDistanceComputer() : FlatCodesDistanceComputer(nullptr) {}
|
|
104
102
|
|
|
105
103
|
virtual float query_to_code(const uint8_t* code) const = 0;
|
|
106
104
|
|
|
@@ -5,6 +5,8 @@
|
|
|
5
5
|
* LICENSE file in the root directory of this source tree.
|
|
6
6
|
*/
|
|
7
7
|
|
|
8
|
+
#pragma once
|
|
9
|
+
|
|
8
10
|
#include <faiss/impl/FaissAssert.h>
|
|
9
11
|
#include <exception>
|
|
10
12
|
#include <iostream>
|
|
@@ -75,10 +77,11 @@ void ThreadedIndex<IndexT>::addIndex(IndexT* index) {
|
|
|
75
77
|
}
|
|
76
78
|
}
|
|
77
79
|
|
|
78
|
-
indices_.emplace_back(
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
80
|
+
indices_.emplace_back(
|
|
81
|
+
std::make_pair(
|
|
82
|
+
index,
|
|
83
|
+
std::unique_ptr<WorkerThread>(
|
|
84
|
+
isThreaded_ ? new WorkerThread : nullptr)));
|
|
82
85
|
|
|
83
86
|
onAfterAddIndex(index);
|
|
84
87
|
}
|