faiss 0.4.2 → 0.5.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +9 -0
- data/ext/faiss/index.cpp +36 -10
- data/ext/faiss/index_binary.cpp +19 -6
- data/ext/faiss/kmeans.cpp +6 -6
- data/ext/faiss/numo.hpp +273 -123
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/faiss/AutoTune.cpp +2 -3
- data/vendor/faiss/faiss/AutoTune.h +1 -1
- data/vendor/faiss/faiss/Clustering.cpp +2 -2
- data/vendor/faiss/faiss/Clustering.h +2 -2
- data/vendor/faiss/faiss/IVFlib.cpp +1 -2
- data/vendor/faiss/faiss/IVFlib.h +1 -1
- data/vendor/faiss/faiss/Index.h +10 -10
- data/vendor/faiss/faiss/Index2Layer.cpp +1 -1
- data/vendor/faiss/faiss/Index2Layer.h +2 -2
- data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +9 -4
- data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +5 -1
- data/vendor/faiss/faiss/IndexBinary.h +7 -7
- data/vendor/faiss/faiss/IndexBinaryFromFloat.h +1 -1
- data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +3 -1
- data/vendor/faiss/faiss/IndexBinaryHNSW.h +1 -1
- data/vendor/faiss/faiss/IndexBinaryHash.cpp +3 -3
- data/vendor/faiss/faiss/IndexBinaryHash.h +5 -5
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +7 -6
- data/vendor/faiss/faiss/IndexFastScan.cpp +125 -49
- data/vendor/faiss/faiss/IndexFastScan.h +107 -7
- data/vendor/faiss/faiss/IndexFlat.h +1 -1
- data/vendor/faiss/faiss/IndexHNSW.cpp +3 -1
- data/vendor/faiss/faiss/IndexHNSW.h +1 -1
- data/vendor/faiss/faiss/IndexIDMap.cpp +14 -13
- data/vendor/faiss/faiss/IndexIDMap.h +6 -6
- data/vendor/faiss/faiss/IndexIVF.cpp +1 -1
- data/vendor/faiss/faiss/IndexIVF.h +5 -5
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +1 -1
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +9 -3
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +3 -1
- data/vendor/faiss/faiss/IndexIVFFastScan.cpp +176 -90
- data/vendor/faiss/faiss/IndexIVFFastScan.h +173 -18
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +1 -0
- data/vendor/faiss/faiss/IndexIVFFlatPanorama.cpp +366 -0
- data/vendor/faiss/faiss/IndexIVFFlatPanorama.h +64 -0
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +3 -1
- data/vendor/faiss/faiss/IndexIVFPQ.h +1 -1
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +134 -2
- data/vendor/faiss/faiss/IndexIVFPQFastScan.h +7 -1
- data/vendor/faiss/faiss/IndexIVFRaBitQ.cpp +13 -6
- data/vendor/faiss/faiss/IndexIVFRaBitQ.h +1 -0
- data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.cpp +650 -0
- data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.h +216 -0
- data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +1 -1
- data/vendor/faiss/faiss/IndexIVFSpectralHash.h +1 -1
- data/vendor/faiss/faiss/IndexNNDescent.cpp +1 -1
- data/vendor/faiss/faiss/IndexNSG.cpp +1 -1
- data/vendor/faiss/faiss/IndexNeuralNetCodec.h +1 -1
- data/vendor/faiss/faiss/IndexPQ.h +1 -1
- data/vendor/faiss/faiss/IndexPQFastScan.cpp +6 -2
- data/vendor/faiss/faiss/IndexPQFastScan.h +5 -1
- data/vendor/faiss/faiss/IndexRaBitQ.cpp +13 -10
- data/vendor/faiss/faiss/IndexRaBitQ.h +7 -2
- data/vendor/faiss/faiss/IndexRaBitQFastScan.cpp +586 -0
- data/vendor/faiss/faiss/IndexRaBitQFastScan.h +149 -0
- data/vendor/faiss/faiss/IndexShards.cpp +1 -1
- data/vendor/faiss/faiss/MatrixStats.cpp +3 -3
- data/vendor/faiss/faiss/MetricType.h +1 -1
- data/vendor/faiss/faiss/VectorTransform.h +2 -2
- data/vendor/faiss/faiss/clone_index.cpp +3 -1
- data/vendor/faiss/faiss/gpu/GpuCloner.cpp +1 -1
- data/vendor/faiss/faiss/gpu/GpuIndex.h +11 -11
- data/vendor/faiss/faiss/gpu/GpuIndexBinaryCagra.h +1 -1
- data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +1 -1
- data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +10 -6
- data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +2 -0
- data/vendor/faiss/faiss/gpu/test/TestGpuIcmEncoder.cpp +7 -0
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +1 -1
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +1 -1
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +1 -1
- data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +2 -2
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +1 -1
- data/vendor/faiss/faiss/impl/CodePacker.h +2 -2
- data/vendor/faiss/faiss/impl/DistanceComputer.h +3 -3
- data/vendor/faiss/faiss/impl/FastScanDistancePostProcessing.h +53 -0
- data/vendor/faiss/faiss/impl/HNSW.cpp +1 -1
- data/vendor/faiss/faiss/impl/HNSW.h +4 -4
- data/vendor/faiss/faiss/impl/IDSelector.cpp +2 -2
- data/vendor/faiss/faiss/impl/IDSelector.h +1 -1
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +4 -4
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +1 -1
- data/vendor/faiss/faiss/impl/LookupTableScaler.h +1 -1
- data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -1
- data/vendor/faiss/faiss/impl/NNDescent.h +2 -2
- data/vendor/faiss/faiss/impl/NSG.cpp +1 -1
- data/vendor/faiss/faiss/impl/PanoramaStats.cpp +33 -0
- data/vendor/faiss/faiss/impl/PanoramaStats.h +38 -0
- data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +5 -5
- data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +1 -1
- data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +1 -1
- data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +2 -0
- data/vendor/faiss/faiss/impl/ProductQuantizer.h +1 -1
- data/vendor/faiss/faiss/impl/RaBitQUtils.cpp +246 -0
- data/vendor/faiss/faiss/impl/RaBitQUtils.h +153 -0
- data/vendor/faiss/faiss/impl/RaBitQuantizer.cpp +54 -158
- data/vendor/faiss/faiss/impl/RaBitQuantizer.h +2 -1
- data/vendor/faiss/faiss/impl/ResidualQuantizer.h +1 -1
- data/vendor/faiss/faiss/impl/ResultHandler.h +4 -4
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +1 -1
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +1 -1
- data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +7 -4
- data/vendor/faiss/faiss/impl/index_read.cpp +87 -3
- data/vendor/faiss/faiss/impl/index_write.cpp +73 -3
- data/vendor/faiss/faiss/impl/io.cpp +2 -2
- data/vendor/faiss/faiss/impl/io.h +4 -4
- data/vendor/faiss/faiss/impl/kmeans1d.cpp +1 -1
- data/vendor/faiss/faiss/impl/kmeans1d.h +1 -1
- data/vendor/faiss/faiss/impl/lattice_Zn.h +2 -2
- data/vendor/faiss/faiss/impl/mapped_io.cpp +2 -2
- data/vendor/faiss/faiss/impl/mapped_io.h +4 -3
- data/vendor/faiss/faiss/impl/maybe_owned_vector.h +8 -1
- data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +30 -4
- data/vendor/faiss/faiss/impl/pq4_fast_scan.h +14 -8
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +5 -6
- data/vendor/faiss/faiss/impl/simd_result_handlers.h +55 -11
- data/vendor/faiss/faiss/impl/zerocopy_io.h +1 -1
- data/vendor/faiss/faiss/index_factory.cpp +43 -1
- data/vendor/faiss/faiss/index_factory.h +1 -1
- data/vendor/faiss/faiss/index_io.h +1 -1
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +205 -0
- data/vendor/faiss/faiss/invlists/InvertedLists.h +62 -0
- data/vendor/faiss/faiss/utils/AlignedTable.h +1 -1
- data/vendor/faiss/faiss/utils/Heap.cpp +2 -2
- data/vendor/faiss/faiss/utils/Heap.h +3 -3
- data/vendor/faiss/faiss/utils/NeuralNet.cpp +1 -1
- data/vendor/faiss/faiss/utils/NeuralNet.h +3 -3
- data/vendor/faiss/faiss/utils/approx_topk/approx_topk.h +2 -2
- data/vendor/faiss/faiss/utils/approx_topk/avx2-inl.h +2 -2
- data/vendor/faiss/faiss/utils/approx_topk/mode.h +1 -1
- data/vendor/faiss/faiss/utils/distances.h +2 -2
- data/vendor/faiss/faiss/utils/extra_distances-inl.h +3 -1
- data/vendor/faiss/faiss/utils/hamming-inl.h +2 -0
- data/vendor/faiss/faiss/utils/hamming.cpp +7 -6
- data/vendor/faiss/faiss/utils/hamming.h +1 -1
- data/vendor/faiss/faiss/utils/hamming_distance/common.h +1 -2
- data/vendor/faiss/faiss/utils/partitioning.cpp +5 -5
- data/vendor/faiss/faiss/utils/partitioning.h +2 -2
- data/vendor/faiss/faiss/utils/rabitq_simd.h +222 -336
- data/vendor/faiss/faiss/utils/random.cpp +1 -1
- data/vendor/faiss/faiss/utils/simdlib_avx2.h +1 -1
- data/vendor/faiss/faiss/utils/simdlib_avx512.h +1 -1
- data/vendor/faiss/faiss/utils/simdlib_neon.h +2 -2
- data/vendor/faiss/faiss/utils/transpose/transpose-avx512-inl.h +1 -1
- data/vendor/faiss/faiss/utils/utils.cpp +5 -2
- data/vendor/faiss/faiss/utils/utils.h +2 -2
- metadata +14 -3
|
@@ -0,0 +1,216 @@
|
|
|
1
|
+
/*
|
|
2
|
+
* Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
3
|
+
*
|
|
4
|
+
* This source code is licensed under the MIT license found in the
|
|
5
|
+
* LICENSE file in the root directory of this source tree.
|
|
6
|
+
*/
|
|
7
|
+
|
|
8
|
+
#pragma once
|
|
9
|
+
|
|
10
|
+
#include <vector>
|
|
11
|
+
|
|
12
|
+
#include <faiss/IndexIVFFastScan.h>
|
|
13
|
+
#include <faiss/IndexIVFRaBitQ.h>
|
|
14
|
+
#include <faiss/IndexRaBitQFastScan.h>
|
|
15
|
+
#include <faiss/impl/RaBitQUtils.h>
|
|
16
|
+
#include <faiss/impl/RaBitQuantizer.h>
|
|
17
|
+
#include <faiss/impl/simd_result_handlers.h>
|
|
18
|
+
#include <faiss/utils/AlignedTable.h>
|
|
19
|
+
#include <faiss/utils/Heap.h>
|
|
20
|
+
|
|
21
|
+
namespace faiss {
|
|
22
|
+
|
|
23
|
+
// Forward declarations
|
|
24
|
+
struct FastScanDistancePostProcessing;
|
|
25
|
+
|
|
26
|
+
// Import shared utilities from RaBitQUtils
|
|
27
|
+
using rabitq_utils::FactorsData;
|
|
28
|
+
using rabitq_utils::QueryFactorsData;
|
|
29
|
+
|
|
30
|
+
/** Fast-scan version of IndexIVFRaBitQ that processes vectors in batches
|
|
31
|
+
* using SIMD operations. Combines the inverted file structure of IVF
|
|
32
|
+
* with RaBitQ's bit-level quantization and FastScan's batch processing.
|
|
33
|
+
*
|
|
34
|
+
* Key features:
|
|
35
|
+
* - Inherits from IndexIVFFastScan for IVF structure and search algorithms
|
|
36
|
+
* - Processes 32 database vectors at a time using SIMD
|
|
37
|
+
* - Separates factors from quantized bits for efficient processing
|
|
38
|
+
* - Supports both L2 and inner product metrics
|
|
39
|
+
* - Maintains compatibility with existing IVF search parameters
|
|
40
|
+
*
|
|
41
|
+
* Implementation details:
|
|
42
|
+
* - Batch size (bbs) is typically 32 for optimal SIMD performance
|
|
43
|
+
* - Factors are stored separately from packed codes for cache efficiency
|
|
44
|
+
* - Query factors are computed once per search and reused across lists
|
|
45
|
+
* - Uses specialized result handlers for RaBitQ distance corrections
|
|
46
|
+
*/
|
|
47
|
+
struct IndexIVFRaBitQFastScan : IndexIVFFastScan {
|
|
48
|
+
RaBitQuantizer rabitq;
|
|
49
|
+
|
|
50
|
+
/// Default number of bits to quantize a query with
|
|
51
|
+
uint8_t qb = 8;
|
|
52
|
+
|
|
53
|
+
/// Use zero-centered scalar quantizer for queries
|
|
54
|
+
bool centered = false;
|
|
55
|
+
|
|
56
|
+
/// Extracted factors storage for batch processing
|
|
57
|
+
/// Size: ntotal, stores factors separately from packed codes
|
|
58
|
+
std::vector<FactorsData> factors_storage;
|
|
59
|
+
|
|
60
|
+
// Constructors
|
|
61
|
+
|
|
62
|
+
IndexIVFRaBitQFastScan();
|
|
63
|
+
|
|
64
|
+
IndexIVFRaBitQFastScan(
|
|
65
|
+
Index* quantizer,
|
|
66
|
+
size_t d,
|
|
67
|
+
size_t nlist,
|
|
68
|
+
MetricType metric = METRIC_L2,
|
|
69
|
+
int bbs = 32,
|
|
70
|
+
bool own_invlists = true);
|
|
71
|
+
|
|
72
|
+
/// Build from an existing IndexIVFRaBitQ
|
|
73
|
+
explicit IndexIVFRaBitQFastScan(const IndexIVFRaBitQ& orig, int bbs = 32);
|
|
74
|
+
|
|
75
|
+
// Required overrides
|
|
76
|
+
|
|
77
|
+
void train_encoder(idx_t n, const float* x, const idx_t* assign) override;
|
|
78
|
+
|
|
79
|
+
void encode_vectors(
|
|
80
|
+
idx_t n,
|
|
81
|
+
const float* x,
|
|
82
|
+
const idx_t* list_nos,
|
|
83
|
+
uint8_t* codes,
|
|
84
|
+
bool include_listnos = false) const override;
|
|
85
|
+
|
|
86
|
+
protected:
|
|
87
|
+
/// Extract and store RaBitQ factors from encoded vectors
|
|
88
|
+
void preprocess_code_metadata(
|
|
89
|
+
idx_t n,
|
|
90
|
+
const uint8_t* flat_codes,
|
|
91
|
+
idx_t start_global_idx) override;
|
|
92
|
+
|
|
93
|
+
/// Return code_size as stride to skip embedded factor data during packing
|
|
94
|
+
size_t code_packing_stride() const override;
|
|
95
|
+
|
|
96
|
+
public:
|
|
97
|
+
/// Reconstruct a single vector from an inverted list
|
|
98
|
+
void reconstruct_from_offset(int64_t list_no, int64_t offset, float* recons)
|
|
99
|
+
const override;
|
|
100
|
+
|
|
101
|
+
/// Override sa_decode to handle RaBitQ reconstruction
|
|
102
|
+
void sa_decode(idx_t n, const uint8_t* bytes, float* x) const override;
|
|
103
|
+
|
|
104
|
+
private:
|
|
105
|
+
/// Encode a vector to FastScan format without computing factors
|
|
106
|
+
void encode_vector_to_fastscan(
|
|
107
|
+
const float* xi,
|
|
108
|
+
const float* centroid,
|
|
109
|
+
uint8_t* fastscan_code) const;
|
|
110
|
+
|
|
111
|
+
/// Compute query factors and lookup table for a residual vector
|
|
112
|
+
/// (similar to IndexRaBitQFastScan::compute_float_LUT)
|
|
113
|
+
void compute_residual_LUT(
|
|
114
|
+
const float* residual,
|
|
115
|
+
QueryFactorsData& query_factors,
|
|
116
|
+
float* lut_out,
|
|
117
|
+
const float* original_query = nullptr) const;
|
|
118
|
+
|
|
119
|
+
/// Decode FastScan code to RaBitQ residual vector
|
|
120
|
+
void decode_fastscan_to_residual(
|
|
121
|
+
const uint8_t* fastscan_code,
|
|
122
|
+
float* residual) const;
|
|
123
|
+
|
|
124
|
+
public:
|
|
125
|
+
/// Implementation methods for IVFRaBitQFastScan specialization
|
|
126
|
+
bool lookup_table_is_3d() const override;
|
|
127
|
+
|
|
128
|
+
void compute_LUT(
|
|
129
|
+
size_t n,
|
|
130
|
+
const float* x,
|
|
131
|
+
const CoarseQuantized& cq,
|
|
132
|
+
AlignedTable<float>& dis_tables,
|
|
133
|
+
AlignedTable<float>& biases,
|
|
134
|
+
const FastScanDistancePostProcessing& context) const override;
|
|
135
|
+
|
|
136
|
+
void search_preassigned(
|
|
137
|
+
idx_t n,
|
|
138
|
+
const float* x,
|
|
139
|
+
idx_t k,
|
|
140
|
+
const idx_t* assign,
|
|
141
|
+
const float* centroid_dis,
|
|
142
|
+
float* distances,
|
|
143
|
+
idx_t* labels,
|
|
144
|
+
bool store_pairs,
|
|
145
|
+
const IVFSearchParameters* params = nullptr,
|
|
146
|
+
IndexIVFStats* stats = nullptr) const override;
|
|
147
|
+
|
|
148
|
+
/// Override to create RaBitQ-specific handlers
|
|
149
|
+
SIMDResultHandlerToFloat* make_knn_handler(
|
|
150
|
+
bool is_max,
|
|
151
|
+
int /* impl */,
|
|
152
|
+
idx_t n,
|
|
153
|
+
idx_t k,
|
|
154
|
+
float* distances,
|
|
155
|
+
idx_t* labels,
|
|
156
|
+
const IDSelector* sel,
|
|
157
|
+
const FastScanDistancePostProcessing& context,
|
|
158
|
+
const float* normalizers = nullptr) const override;
|
|
159
|
+
|
|
160
|
+
/** SIMD result handler for IndexIVFRaBitQFastScan that applies
|
|
161
|
+
* RaBitQ-specific distance corrections during batch processing.
|
|
162
|
+
*
|
|
163
|
+
* This handler processes batches of 32 distance computations from SIMD
|
|
164
|
+
* kernels, applies RaBitQ distance formula adjustments (factors and
|
|
165
|
+
* normalizers), and immediately updates result heaps. This eliminates the
|
|
166
|
+
* need for post-processing and provides significant performance benefits.
|
|
167
|
+
*
|
|
168
|
+
* Key optimizations:
|
|
169
|
+
* - Direct heap integration with no intermediate result storage
|
|
170
|
+
* - Batch-level computation of normalizers and query factors
|
|
171
|
+
* - Specialized handling for both centered and non-centered quantization
|
|
172
|
+
* modes
|
|
173
|
+
* - Efficient inner product metric corrections
|
|
174
|
+
*
|
|
175
|
+
* @tparam C Comparator type (CMin/CMax) for heap operations
|
|
176
|
+
*/
|
|
177
|
+
template <class C>
|
|
178
|
+
struct IVFRaBitQHeapHandler
|
|
179
|
+
: simd_result_handlers::ResultHandlerCompare<C, true> {
|
|
180
|
+
const IndexIVFRaBitQFastScan* index;
|
|
181
|
+
float* heap_distances; // [nq * k]
|
|
182
|
+
int64_t* heap_labels; // [nq * k]
|
|
183
|
+
const size_t nq, k;
|
|
184
|
+
size_t current_list_no = 0;
|
|
185
|
+
std::vector<int>
|
|
186
|
+
probe_indices; // probe index for each query in current batch
|
|
187
|
+
const FastScanDistancePostProcessing*
|
|
188
|
+
context; // Processing context with query factors
|
|
189
|
+
|
|
190
|
+
// Use float-based comparator for heap operations
|
|
191
|
+
using Cfloat = typename std::conditional<
|
|
192
|
+
C::is_max,
|
|
193
|
+
CMax<float, int64_t>,
|
|
194
|
+
CMin<float, int64_t>>::type;
|
|
195
|
+
|
|
196
|
+
IVFRaBitQHeapHandler(
|
|
197
|
+
const IndexIVFRaBitQFastScan* idx,
|
|
198
|
+
size_t nq_val,
|
|
199
|
+
size_t k_val,
|
|
200
|
+
float* distances,
|
|
201
|
+
int64_t* labels,
|
|
202
|
+
const FastScanDistancePostProcessing* ctx = nullptr);
|
|
203
|
+
|
|
204
|
+
void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) final;
|
|
205
|
+
|
|
206
|
+
/// Override base class virtual method to receive context information
|
|
207
|
+
void set_list_context(size_t list_no, const std::vector<int>& probe_map)
|
|
208
|
+
override;
|
|
209
|
+
|
|
210
|
+
void begin(const float* norms) override;
|
|
211
|
+
|
|
212
|
+
void end() override;
|
|
213
|
+
};
|
|
214
|
+
};
|
|
215
|
+
|
|
216
|
+
} // namespace faiss
|
|
@@ -331,7 +331,7 @@ void IndexIVFSpectralHash::replace_vt(VectorTransform* vt_in, bool own) {
|
|
|
331
331
|
/*
|
|
332
332
|
Check that the encoder is a single vector transform followed by a LSH
|
|
333
333
|
that just does thresholding.
|
|
334
|
-
If this is not the case, the linear transform +
|
|
334
|
+
If this is not the case, the linear transform + thresholds of the IndexLSH
|
|
335
335
|
should be merged into the VectorTransform (which is feasible).
|
|
336
336
|
*/
|
|
337
337
|
|
|
@@ -79,7 +79,7 @@ struct IndexIVFSpectralHash : IndexIVF {
|
|
|
79
79
|
*/
|
|
80
80
|
void replace_vt(VectorTransform* vt, bool own = false);
|
|
81
81
|
|
|
82
|
-
/** convenience function to get the VT from an index
|
|
82
|
+
/** convenience function to get the VT from an index constructed by an
|
|
83
83
|
* index_factory (should end in "LSH") */
|
|
84
84
|
void replace_vt(IndexPreTransform* index, bool own = false);
|
|
85
85
|
|
|
@@ -154,7 +154,7 @@ void IndexNNDescent::add(idx_t n, const float* x) {
|
|
|
154
154
|
|
|
155
155
|
if (ntotal != 0) {
|
|
156
156
|
fprintf(stderr,
|
|
157
|
-
"WARNING NNDescent
|
|
157
|
+
"WARNING NNDescent does not support dynamic insertions,"
|
|
158
158
|
"multiple insertions would lead to re-building the index");
|
|
159
159
|
}
|
|
160
160
|
|
|
@@ -261,7 +261,7 @@ void IndexNSG::check_knn_graph(const idx_t* knn_graph, idx_t n, int K) const {
|
|
|
261
261
|
}
|
|
262
262
|
FAISS_THROW_IF_NOT_MSG(
|
|
263
263
|
total_count < n / 10,
|
|
264
|
-
"There are too
|
|
264
|
+
"There are too many invalid entries in the knn graph. "
|
|
265
265
|
"It may be an invalid knn graph.");
|
|
266
266
|
}
|
|
267
267
|
|
|
@@ -29,7 +29,7 @@ struct IndexNeuralNetCodec : IndexFlatCodes {
|
|
|
29
29
|
void sa_encode(idx_t n, const float* x, uint8_t* codes) const override;
|
|
30
30
|
void sa_decode(idx_t n, const uint8_t* codes, float* x) const override;
|
|
31
31
|
|
|
32
|
-
~IndexNeuralNetCodec() {}
|
|
32
|
+
~IndexNeuralNetCodec() override {}
|
|
33
33
|
};
|
|
34
34
|
|
|
35
35
|
struct IndexQINCo : IndexNeuralNetCodec {
|
|
@@ -164,7 +164,7 @@ struct MultiIndexQuantizer : Index {
|
|
|
164
164
|
// block size used in MultiIndexQuantizer::search
|
|
165
165
|
FAISS_API extern int multi_index_quantizer_search_bs;
|
|
166
166
|
|
|
167
|
-
/** MultiIndexQuantizer where the PQ
|
|
167
|
+
/** MultiIndexQuantizer where the PQ assignment is performed by sub-indexes
|
|
168
168
|
*/
|
|
169
169
|
struct MultiIndexQuantizer2 : MultiIndexQuantizer {
|
|
170
170
|
/// M Indexes on d / M dimensions
|
|
@@ -9,6 +9,7 @@
|
|
|
9
9
|
|
|
10
10
|
#include <memory>
|
|
11
11
|
|
|
12
|
+
#include <faiss/impl/FastScanDistancePostProcessing.h>
|
|
12
13
|
#include <faiss/impl/pq4_fast_scan.h>
|
|
13
14
|
#include <faiss/utils/utils.h>
|
|
14
15
|
|
|
@@ -53,8 +54,11 @@ void IndexPQFastScan::compute_codes(uint8_t* codes, idx_t n, const float* x)
|
|
|
53
54
|
pq.compute_codes(x, codes, n);
|
|
54
55
|
}
|
|
55
56
|
|
|
56
|
-
void IndexPQFastScan::compute_float_LUT(
|
|
57
|
-
|
|
57
|
+
void IndexPQFastScan::compute_float_LUT(
|
|
58
|
+
float* lut,
|
|
59
|
+
idx_t n,
|
|
60
|
+
const float* x,
|
|
61
|
+
const FastScanDistancePostProcessing&) const {
|
|
58
62
|
if (metric_type == METRIC_L2) {
|
|
59
63
|
pq.compute_distance_tables(n, x, lut);
|
|
60
64
|
} else {
|
|
@@ -45,7 +45,11 @@ struct IndexPQFastScan : IndexFastScan {
|
|
|
45
45
|
|
|
46
46
|
void compute_codes(uint8_t* codes, idx_t n, const float* x) const override;
|
|
47
47
|
|
|
48
|
-
void compute_float_LUT(
|
|
48
|
+
void compute_float_LUT(
|
|
49
|
+
float* lut,
|
|
50
|
+
idx_t n,
|
|
51
|
+
const float* x,
|
|
52
|
+
const FastScanDistancePostProcessing& context) const override;
|
|
49
53
|
|
|
50
54
|
void sa_decode(idx_t n, const uint8_t* bytes, float* x) const override;
|
|
51
55
|
};
|
|
@@ -55,16 +55,17 @@ void IndexRaBitQ::sa_decode(idx_t n, const uint8_t* bytes, float* x) const {
|
|
|
55
55
|
|
|
56
56
|
/// Distance computer over this index's codes, using the index-level query
/// settings `qb` and `centered`.
/// With qb == 0 the computer works on unquantized fp32 queries; otherwise the
/// query is quantized to `qb` bits. The caller takes ownership of the returned
/// pointer (the search path wraps it in a std::unique_ptr).
FlatCodesDistanceComputer* IndexRaBitQ::get_FlatCodesDistanceComputer() const {
    // Identical construction to an explicit request made with the index
    // defaults: same quantizer call, same code_size and codes wiring.
    return get_quantized_distance_computer(qb, centered);
}
|
|
63
63
|
|
|
64
64
|
FlatCodesDistanceComputer* IndexRaBitQ::get_quantized_distance_computer(
|
|
65
|
-
const uint8_t qb
|
|
65
|
+
const uint8_t qb,
|
|
66
|
+
bool centered) const {
|
|
66
67
|
FlatCodesDistanceComputer* dc =
|
|
67
|
-
rabitq.get_distance_computer(qb, center.data());
|
|
68
|
+
rabitq.get_distance_computer(qb, center.data(), centered);
|
|
68
69
|
dc->code_size = rabitq.code_size;
|
|
69
70
|
dc->codes = codes.data();
|
|
70
71
|
return dc;
|
|
@@ -76,6 +77,7 @@ struct Run_search_with_dc_res {
|
|
|
76
77
|
using T = void;
|
|
77
78
|
|
|
78
79
|
uint8_t qb = 0;
|
|
80
|
+
bool centered = false;
|
|
79
81
|
|
|
80
82
|
template <class BlockResultHandler>
|
|
81
83
|
void f(BlockResultHandler& res, const IndexRaBitQ* index, const float* xq) {
|
|
@@ -87,7 +89,7 @@ struct Run_search_with_dc_res {
|
|
|
87
89
|
#pragma omp parallel // if (res.nq > 100)
|
|
88
90
|
{
|
|
89
91
|
std::unique_ptr<FlatCodesDistanceComputer> dc(
|
|
90
|
-
index->get_quantized_distance_computer(qb));
|
|
92
|
+
index->get_quantized_distance_computer(qb, centered));
|
|
91
93
|
SingleResultHandler resi(res);
|
|
92
94
|
#pragma omp for
|
|
93
95
|
for (int64_t q = 0; q < res.nq; q++) {
|
|
@@ -114,14 +116,15 @@ void IndexRaBitQ::search(
|
|
|
114
116
|
float* distances,
|
|
115
117
|
idx_t* labels,
|
|
116
118
|
const SearchParameters* params_in) const {
|
|
117
|
-
uint8_t used_qb = qb;
|
|
118
|
-
if (auto params = dynamic_cast<const RaBitQSearchParameters*>(params_in)) {
|
|
119
|
-
used_qb = params->qb;
|
|
120
|
-
}
|
|
121
|
-
|
|
122
119
|
const IDSelector* sel = (params_in != nullptr) ? params_in->sel : nullptr;
|
|
123
120
|
Run_search_with_dc_res r;
|
|
124
|
-
|
|
121
|
+
if (auto params = dynamic_cast<const RaBitQSearchParameters*>(params_in)) {
|
|
122
|
+
r.qb = params->qb;
|
|
123
|
+
r.centered = params->centered;
|
|
124
|
+
} else {
|
|
125
|
+
r.qb = this->qb;
|
|
126
|
+
r.centered = this->centered;
|
|
127
|
+
}
|
|
125
128
|
|
|
126
129
|
dispatch_knn_ResultHandler(
|
|
127
130
|
n, distances, labels, k, metric_type, sel, r, this, x);
|
|
@@ -14,6 +14,7 @@ namespace faiss {
|
|
|
14
14
|
|
|
15
15
|
/// Search-time overrides for IndexRaBitQ.
/// When an instance of this type is passed to IndexRaBitQ::search, BOTH
/// fields replace the index's own `qb` / `centered` settings. The defaults
/// below therefore disable query quantization (qb == 0 -> fp32 queries), even
/// if the index itself was configured with a non-zero qb.
struct RaBitQSearchParameters : SearchParameters {
    /// Bits used to quantize the query; 0 means raw fp32 query values.
    uint8_t qb{0};
    /// Quantize the query with a zero-centered scalar quantizer.
    bool centered{false};
};
|
|
18
19
|
|
|
19
20
|
struct IndexRaBitQ : IndexFlatCodes {
|
|
@@ -26,9 +27,12 @@ struct IndexRaBitQ : IndexFlatCodes {
|
|
|
26
27
|
// use '0' to disable quantization and use raw fp32 values.
|
|
27
28
|
uint8_t qb = 0;
|
|
28
29
|
|
|
30
|
+
// quantize the query with a zero-centered scalar quantizer.
|
|
31
|
+
bool centered = false;
|
|
32
|
+
|
|
29
33
|
IndexRaBitQ();
|
|
30
34
|
|
|
31
|
-
IndexRaBitQ(idx_t d, MetricType metric = METRIC_L2);
|
|
35
|
+
explicit IndexRaBitQ(idx_t d, MetricType metric = METRIC_L2);
|
|
32
36
|
|
|
33
37
|
void train(idx_t n, const float* x) override;
|
|
34
38
|
|
|
@@ -42,7 +46,8 @@ struct IndexRaBitQ : IndexFlatCodes {
|
|
|
42
46
|
// returns a quantized-to-qb bits DC if qb_in > 0
|
|
43
47
|
// returns a default fp32-based DC if qb_in == 0
|
|
44
48
|
FlatCodesDistanceComputer* get_quantized_distance_computer(
|
|
45
|
-
const uint8_t qb_in
|
|
49
|
+
const uint8_t qb_in,
|
|
50
|
+
bool centered) const;
|
|
46
51
|
|
|
47
52
|
// Don't rely on sa_decode(), bcz it is good for IP, but not for L2.
|
|
48
53
|
// As a result, use get_FlatCodesDistanceComputer() for the search.
|