faiss 0.4.3 → 0.5.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +10 -0
- data/README.md +2 -0
- data/ext/faiss/index.cpp +33 -6
- data/ext/faiss/index_binary.cpp +17 -4
- data/ext/faiss/kmeans.cpp +6 -6
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/faiss/AutoTune.cpp +2 -3
- data/vendor/faiss/faiss/AutoTune.h +1 -1
- data/vendor/faiss/faiss/Clustering.cpp +2 -2
- data/vendor/faiss/faiss/Clustering.h +2 -2
- data/vendor/faiss/faiss/IVFlib.cpp +26 -51
- data/vendor/faiss/faiss/IVFlib.h +1 -1
- data/vendor/faiss/faiss/Index.cpp +11 -0
- data/vendor/faiss/faiss/Index.h +34 -11
- data/vendor/faiss/faiss/Index2Layer.cpp +1 -1
- data/vendor/faiss/faiss/Index2Layer.h +2 -2
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +1 -0
- data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +9 -4
- data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +5 -1
- data/vendor/faiss/faiss/IndexBinary.h +7 -7
- data/vendor/faiss/faiss/IndexBinaryFromFloat.h +1 -1
- data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +8 -2
- data/vendor/faiss/faiss/IndexBinaryHNSW.h +1 -1
- data/vendor/faiss/faiss/IndexBinaryHash.cpp +3 -3
- data/vendor/faiss/faiss/IndexBinaryHash.h +5 -5
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +7 -6
- data/vendor/faiss/faiss/IndexFastScan.cpp +125 -49
- data/vendor/faiss/faiss/IndexFastScan.h +102 -7
- data/vendor/faiss/faiss/IndexFlat.cpp +374 -4
- data/vendor/faiss/faiss/IndexFlat.h +81 -1
- data/vendor/faiss/faiss/IndexHNSW.cpp +93 -2
- data/vendor/faiss/faiss/IndexHNSW.h +58 -2
- data/vendor/faiss/faiss/IndexIDMap.cpp +14 -13
- data/vendor/faiss/faiss/IndexIDMap.h +6 -6
- data/vendor/faiss/faiss/IndexIVF.cpp +1 -1
- data/vendor/faiss/faiss/IndexIVF.h +5 -5
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +1 -1
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +9 -3
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +3 -1
- data/vendor/faiss/faiss/IndexIVFFastScan.cpp +176 -90
- data/vendor/faiss/faiss/IndexIVFFastScan.h +173 -18
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +1 -0
- data/vendor/faiss/faiss/IndexIVFFlatPanorama.cpp +251 -0
- data/vendor/faiss/faiss/IndexIVFFlatPanorama.h +64 -0
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +3 -1
- data/vendor/faiss/faiss/IndexIVFPQ.h +1 -1
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +134 -2
- data/vendor/faiss/faiss/IndexIVFPQFastScan.h +7 -1
- data/vendor/faiss/faiss/IndexIVFRaBitQ.cpp +99 -8
- data/vendor/faiss/faiss/IndexIVFRaBitQ.h +4 -1
- data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.cpp +828 -0
- data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.h +252 -0
- data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +1 -1
- data/vendor/faiss/faiss/IndexIVFSpectralHash.h +1 -1
- data/vendor/faiss/faiss/IndexNNDescent.cpp +1 -1
- data/vendor/faiss/faiss/IndexNSG.cpp +1 -1
- data/vendor/faiss/faiss/IndexNeuralNetCodec.h +1 -1
- data/vendor/faiss/faiss/IndexPQ.cpp +4 -1
- data/vendor/faiss/faiss/IndexPQ.h +1 -1
- data/vendor/faiss/faiss/IndexPQFastScan.cpp +6 -2
- data/vendor/faiss/faiss/IndexPQFastScan.h +5 -1
- data/vendor/faiss/faiss/IndexPreTransform.cpp +14 -0
- data/vendor/faiss/faiss/IndexPreTransform.h +9 -0
- data/vendor/faiss/faiss/IndexRaBitQ.cpp +96 -13
- data/vendor/faiss/faiss/IndexRaBitQ.h +11 -2
- data/vendor/faiss/faiss/IndexRaBitQFastScan.cpp +731 -0
- data/vendor/faiss/faiss/IndexRaBitQFastScan.h +175 -0
- data/vendor/faiss/faiss/IndexRefine.cpp +49 -0
- data/vendor/faiss/faiss/IndexRefine.h +17 -0
- data/vendor/faiss/faiss/IndexShards.cpp +1 -1
- data/vendor/faiss/faiss/MatrixStats.cpp +3 -3
- data/vendor/faiss/faiss/MetricType.h +1 -1
- data/vendor/faiss/faiss/VectorTransform.h +2 -2
- data/vendor/faiss/faiss/clone_index.cpp +5 -1
- data/vendor/faiss/faiss/gpu/GpuCloner.cpp +1 -1
- data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +3 -1
- data/vendor/faiss/faiss/gpu/GpuIndex.h +11 -11
- data/vendor/faiss/faiss/gpu/GpuIndexBinaryCagra.h +1 -1
- data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +1 -1
- data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +11 -7
- data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +1 -1
- data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +2 -0
- data/vendor/faiss/faiss/gpu/test/TestGpuIcmEncoder.cpp +7 -0
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +1 -1
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +1 -1
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +1 -1
- data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +2 -2
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +1 -1
- data/vendor/faiss/faiss/impl/CodePacker.h +2 -2
- data/vendor/faiss/faiss/impl/DistanceComputer.h +77 -6
- data/vendor/faiss/faiss/impl/FastScanDistancePostProcessing.h +53 -0
- data/vendor/faiss/faiss/impl/HNSW.cpp +295 -16
- data/vendor/faiss/faiss/impl/HNSW.h +35 -6
- data/vendor/faiss/faiss/impl/IDSelector.cpp +2 -2
- data/vendor/faiss/faiss/impl/IDSelector.h +4 -4
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +4 -4
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +1 -1
- data/vendor/faiss/faiss/impl/LookupTableScaler.h +1 -1
- data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -1
- data/vendor/faiss/faiss/impl/NNDescent.h +2 -2
- data/vendor/faiss/faiss/impl/NSG.cpp +1 -1
- data/vendor/faiss/faiss/impl/Panorama.cpp +193 -0
- data/vendor/faiss/faiss/impl/Panorama.h +204 -0
- data/vendor/faiss/faiss/impl/PanoramaStats.cpp +33 -0
- data/vendor/faiss/faiss/impl/PanoramaStats.h +38 -0
- data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +5 -5
- data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +1 -1
- data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +1 -1
- data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +2 -0
- data/vendor/faiss/faiss/impl/ProductQuantizer.h +1 -1
- data/vendor/faiss/faiss/impl/RaBitQStats.cpp +29 -0
- data/vendor/faiss/faiss/impl/RaBitQStats.h +56 -0
- data/vendor/faiss/faiss/impl/RaBitQUtils.cpp +294 -0
- data/vendor/faiss/faiss/impl/RaBitQUtils.h +330 -0
- data/vendor/faiss/faiss/impl/RaBitQuantizer.cpp +304 -223
- data/vendor/faiss/faiss/impl/RaBitQuantizer.h +72 -4
- data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.cpp +362 -0
- data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.h +112 -0
- data/vendor/faiss/faiss/impl/ResidualQuantizer.h +1 -1
- data/vendor/faiss/faiss/impl/ResultHandler.h +4 -4
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +7 -10
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +2 -4
- data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +7 -4
- data/vendor/faiss/faiss/impl/index_read.cpp +238 -10
- data/vendor/faiss/faiss/impl/index_write.cpp +212 -19
- data/vendor/faiss/faiss/impl/io.cpp +2 -2
- data/vendor/faiss/faiss/impl/io.h +4 -4
- data/vendor/faiss/faiss/impl/kmeans1d.cpp +1 -1
- data/vendor/faiss/faiss/impl/kmeans1d.h +1 -1
- data/vendor/faiss/faiss/impl/lattice_Zn.h +2 -2
- data/vendor/faiss/faiss/impl/mapped_io.cpp +2 -2
- data/vendor/faiss/faiss/impl/mapped_io.h +4 -3
- data/vendor/faiss/faiss/impl/maybe_owned_vector.h +8 -1
- data/vendor/faiss/faiss/impl/platform_macros.h +12 -0
- data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +30 -4
- data/vendor/faiss/faiss/impl/pq4_fast_scan.h +14 -8
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +5 -6
- data/vendor/faiss/faiss/impl/simd_result_handlers.h +55 -11
- data/vendor/faiss/faiss/impl/svs_io.cpp +86 -0
- data/vendor/faiss/faiss/impl/svs_io.h +67 -0
- data/vendor/faiss/faiss/impl/zerocopy_io.h +1 -1
- data/vendor/faiss/faiss/index_factory.cpp +217 -8
- data/vendor/faiss/faiss/index_factory.h +1 -1
- data/vendor/faiss/faiss/index_io.h +1 -1
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +1 -1
- data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -1
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +115 -1
- data/vendor/faiss/faiss/invlists/InvertedLists.h +46 -0
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +1 -1
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +1 -1
- data/vendor/faiss/faiss/svs/IndexSVSFaissUtils.h +261 -0
- data/vendor/faiss/faiss/svs/IndexSVSFlat.cpp +117 -0
- data/vendor/faiss/faiss/svs/IndexSVSFlat.h +66 -0
- data/vendor/faiss/faiss/svs/IndexSVSVamana.cpp +245 -0
- data/vendor/faiss/faiss/svs/IndexSVSVamana.h +137 -0
- data/vendor/faiss/faiss/svs/IndexSVSVamanaLVQ.cpp +39 -0
- data/vendor/faiss/faiss/svs/IndexSVSVamanaLVQ.h +42 -0
- data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.cpp +149 -0
- data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.h +58 -0
- data/vendor/faiss/faiss/utils/AlignedTable.h +1 -1
- data/vendor/faiss/faiss/utils/Heap.cpp +2 -2
- data/vendor/faiss/faiss/utils/Heap.h +3 -3
- data/vendor/faiss/faiss/utils/NeuralNet.cpp +1 -1
- data/vendor/faiss/faiss/utils/NeuralNet.h +3 -3
- data/vendor/faiss/faiss/utils/approx_topk/approx_topk.h +2 -2
- data/vendor/faiss/faiss/utils/approx_topk/avx2-inl.h +2 -2
- data/vendor/faiss/faiss/utils/approx_topk/mode.h +1 -1
- data/vendor/faiss/faiss/utils/distances.cpp +0 -3
- data/vendor/faiss/faiss/utils/distances.h +2 -2
- data/vendor/faiss/faiss/utils/extra_distances-inl.h +3 -1
- data/vendor/faiss/faiss/utils/hamming-inl.h +2 -0
- data/vendor/faiss/faiss/utils/hamming.cpp +7 -6
- data/vendor/faiss/faiss/utils/hamming.h +1 -1
- data/vendor/faiss/faiss/utils/hamming_distance/common.h +1 -2
- data/vendor/faiss/faiss/utils/partitioning.cpp +5 -5
- data/vendor/faiss/faiss/utils/partitioning.h +2 -2
- data/vendor/faiss/faiss/utils/rabitq_simd.h +222 -336
- data/vendor/faiss/faiss/utils/random.cpp +1 -1
- data/vendor/faiss/faiss/utils/simdlib_avx2.h +1 -1
- data/vendor/faiss/faiss/utils/simdlib_avx512.h +1 -1
- data/vendor/faiss/faiss/utils/simdlib_neon.h +2 -2
- data/vendor/faiss/faiss/utils/transpose/transpose-avx512-inl.h +1 -1
- data/vendor/faiss/faiss/utils/utils.cpp +9 -2
- data/vendor/faiss/faiss/utils/utils.h +2 -2
- metadata +29 -1
|
@@ -8,12 +8,15 @@
|
|
|
8
8
|
#pragma once
|
|
9
9
|
|
|
10
10
|
#include <faiss/Index.h>
|
|
11
|
+
#include <faiss/impl/FastScanDistancePostProcessing.h>
|
|
11
12
|
#include <faiss/utils/AlignedTable.h>
|
|
12
13
|
|
|
13
14
|
namespace faiss {
|
|
14
15
|
|
|
15
16
|
struct CodePacker;
|
|
16
17
|
struct NormTableScaler;
|
|
18
|
+
struct IDSelector;
|
|
19
|
+
struct SIMDResultHandlerToFloat;
|
|
17
20
|
|
|
18
21
|
/** Fast scan version of IndexPQ and IndexAQ. Works for 4-bit PQ and AQ for now.
|
|
19
22
|
*
|
|
@@ -54,6 +57,14 @@ struct IndexFastScan : Index {
|
|
|
54
57
|
// (set when initialized by IndexPQ or IndexAQ)
|
|
55
58
|
const uint8_t* orig_codes = nullptr;
|
|
56
59
|
|
|
60
|
+
/** Initialize the fast scan index
|
|
61
|
+
*
|
|
62
|
+
* @param d dimensionality of vectors
|
|
63
|
+
* @param M number of subquantizers
|
|
64
|
+
* @param nbits number of bits per subquantizer
|
|
65
|
+
* @param metric distance metric to use
|
|
66
|
+
* @param bbs block size for SIMD processing
|
|
67
|
+
*/
|
|
57
68
|
void init_fastscan(
|
|
58
69
|
int d,
|
|
59
70
|
size_t M,
|
|
@@ -65,6 +76,15 @@ struct IndexFastScan : Index {
|
|
|
65
76
|
|
|
66
77
|
void reset() override;
|
|
67
78
|
|
|
79
|
+
/** Search for k nearest neighbors
|
|
80
|
+
*
|
|
81
|
+
* @param n number of query vectors
|
|
82
|
+
* @param x query vectors (n * d)
|
|
83
|
+
* @param k number of nearest neighbors to find
|
|
84
|
+
* @param distances output distances (n * k)
|
|
85
|
+
* @param labels output labels/indices (n * k)
|
|
86
|
+
* @param params optional search parameters
|
|
87
|
+
*/
|
|
68
88
|
void search(
|
|
69
89
|
idx_t n,
|
|
70
90
|
const float* x,
|
|
@@ -73,20 +93,70 @@ struct IndexFastScan : Index {
|
|
|
73
93
|
idx_t* labels,
|
|
74
94
|
const SearchParameters* params = nullptr) const override;
|
|
75
95
|
|
|
96
|
+
/** Add vectors to the index
|
|
97
|
+
*
|
|
98
|
+
* @param n number of vectors to add
|
|
99
|
+
* @param x vectors to add (n * d)
|
|
100
|
+
*/
|
|
76
101
|
void add(idx_t n, const float* x) override;
|
|
77
102
|
|
|
103
|
+
/** Compute codes for vectors
|
|
104
|
+
*
|
|
105
|
+
* @param codes output codes
|
|
106
|
+
* @param n number of vectors to encode
|
|
107
|
+
* @param x vectors to encode (n * d)
|
|
108
|
+
*/
|
|
78
109
|
virtual void compute_codes(uint8_t* codes, idx_t n, const float* x)
|
|
79
110
|
const = 0;
|
|
80
111
|
|
|
81
|
-
|
|
82
|
-
|
|
112
|
+
/** Compute floating-point lookup table for distance computation
|
|
113
|
+
*
|
|
114
|
+
* @param lut output lookup table
|
|
115
|
+
* @param n number of query vectors
|
|
116
|
+
* @param x query vectors (n * d)
|
|
117
|
+
* @param context processing context containing all processors
|
|
118
|
+
*/
|
|
119
|
+
virtual void compute_float_LUT(
|
|
120
|
+
float* lut,
|
|
121
|
+
idx_t n,
|
|
122
|
+
const float* x,
|
|
123
|
+
const FastScanDistancePostProcessing& context) const = 0;
|
|
124
|
+
|
|
125
|
+
/** Create a KNN handler for this index type
|
|
126
|
+
*
|
|
127
|
+
* This method can be overridden by derived classes to provide
|
|
128
|
+
* specialized handlers (e.g., RaBitQHeapHandler for RaBitQ indexes).
|
|
129
|
+
* Base implementation creates standard handlers based on k and impl.
|
|
130
|
+
*
|
|
131
|
+
* @param is_max whether to use CMax comparator (true) or CMin (false)
|
|
132
|
+
* @param impl implementation number
|
|
133
|
+
* @param n number of queries
|
|
134
|
+
* @param k number of neighbors to find
|
|
135
|
+
* @param ntotal total number of vectors in database
|
|
136
|
+
* @param distances output distances array
|
|
137
|
+
* @param labels output labels array
|
|
138
|
+
* @param sel optional ID selector
|
|
139
|
+
* @param context processing context for distance post-processing
|
|
140
|
+
* @return pointer to created handler (never returns nullptr)
|
|
141
|
+
*/
|
|
142
|
+
virtual SIMDResultHandlerToFloat* make_knn_handler(
|
|
143
|
+
bool is_max,
|
|
144
|
+
int impl,
|
|
145
|
+
idx_t n,
|
|
146
|
+
idx_t k,
|
|
147
|
+
size_t ntotal,
|
|
148
|
+
float* distances,
|
|
149
|
+
idx_t* labels,
|
|
150
|
+
const IDSelector* sel,
|
|
151
|
+
const FastScanDistancePostProcessing& context) const;
|
|
83
152
|
|
|
84
153
|
// called by search function
|
|
85
154
|
void compute_quantized_LUT(
|
|
86
155
|
idx_t n,
|
|
87
156
|
const float* x,
|
|
88
157
|
uint8_t* lut,
|
|
89
|
-
float* normalizers
|
|
158
|
+
float* normalizers,
|
|
159
|
+
const FastScanDistancePostProcessing& context) const;
|
|
90
160
|
|
|
91
161
|
template <bool is_max>
|
|
92
162
|
void search_dispatch_implem(
|
|
@@ -95,7 +165,7 @@ struct IndexFastScan : Index {
|
|
|
95
165
|
idx_t k,
|
|
96
166
|
float* distances,
|
|
97
167
|
idx_t* labels,
|
|
98
|
-
const
|
|
168
|
+
const FastScanDistancePostProcessing& context) const;
|
|
99
169
|
|
|
100
170
|
template <class Cfloat>
|
|
101
171
|
void search_implem_234(
|
|
@@ -104,7 +174,7 @@ struct IndexFastScan : Index {
|
|
|
104
174
|
idx_t k,
|
|
105
175
|
float* distances,
|
|
106
176
|
idx_t* labels,
|
|
107
|
-
const
|
|
177
|
+
const FastScanDistancePostProcessing& context) const;
|
|
108
178
|
|
|
109
179
|
template <class C>
|
|
110
180
|
void search_implem_12(
|
|
@@ -114,7 +184,7 @@ struct IndexFastScan : Index {
|
|
|
114
184
|
float* distances,
|
|
115
185
|
idx_t* labels,
|
|
116
186
|
int impl,
|
|
117
|
-
const
|
|
187
|
+
const FastScanDistancePostProcessing& context) const;
|
|
118
188
|
|
|
119
189
|
template <class C>
|
|
120
190
|
void search_implem_14(
|
|
@@ -124,14 +194,39 @@ struct IndexFastScan : Index {
|
|
|
124
194
|
float* distances,
|
|
125
195
|
idx_t* labels,
|
|
126
196
|
int impl,
|
|
127
|
-
const
|
|
197
|
+
const FastScanDistancePostProcessing& context) const;
|
|
128
198
|
|
|
199
|
+
/** Reconstruct a vector from its code
|
|
200
|
+
*
|
|
201
|
+
* @param key index of vector to reconstruct
|
|
202
|
+
* @param recons output reconstructed vector
|
|
203
|
+
*/
|
|
129
204
|
void reconstruct(idx_t key, float* recons) const override;
|
|
205
|
+
|
|
206
|
+
/** Remove vectors by ID selector
|
|
207
|
+
*
|
|
208
|
+
* @param sel selector defining which vectors to remove
|
|
209
|
+
* @return number of vectors removed
|
|
210
|
+
*/
|
|
130
211
|
size_t remove_ids(const IDSelector& sel) override;
|
|
131
212
|
|
|
213
|
+
/** Get the code packer for this index
|
|
214
|
+
*
|
|
215
|
+
* @return pointer to the code packer
|
|
216
|
+
*/
|
|
132
217
|
CodePacker* get_CodePacker() const;
|
|
133
218
|
|
|
219
|
+
/** Merge another index into this one
|
|
220
|
+
*
|
|
221
|
+
* @param otherIndex index to merge from
|
|
222
|
+
* @param add_id ID offset to add to merged vectors
|
|
223
|
+
*/
|
|
134
224
|
void merge_from(Index& otherIndex, idx_t add_id = 0) override;
|
|
225
|
+
|
|
226
|
+
/** Check if another index is compatible for merging
|
|
227
|
+
*
|
|
228
|
+
* @param otherIndex index to check compatibility with
|
|
229
|
+
*/
|
|
135
230
|
void check_compatible_for_merge(const Index& otherIndex) const override;
|
|
136
231
|
|
|
137
232
|
/// standalone codes interface (but the codes are flattened)
|
|
@@ -11,12 +11,15 @@
|
|
|
11
11
|
|
|
12
12
|
#include <faiss/impl/AuxIndexStructures.h>
|
|
13
13
|
#include <faiss/impl/FaissAssert.h>
|
|
14
|
+
#include <faiss/impl/ResultHandler.h>
|
|
14
15
|
#include <faiss/utils/Heap.h>
|
|
15
16
|
#include <faiss/utils/distances.h>
|
|
16
17
|
#include <faiss/utils/extra_distances.h>
|
|
17
18
|
#include <faiss/utils/prefetch.h>
|
|
18
19
|
#include <faiss/utils/sorting.h>
|
|
20
|
+
#include <omp.h>
|
|
19
21
|
#include <cstring>
|
|
22
|
+
#include <numeric>
|
|
20
23
|
|
|
21
24
|
namespace faiss {
|
|
22
25
|
|
|
@@ -100,15 +103,24 @@ namespace {
|
|
|
100
103
|
struct FlatL2Dis : FlatCodesDistanceComputer {
|
|
101
104
|
size_t d;
|
|
102
105
|
idx_t nb;
|
|
103
|
-
const float* q;
|
|
104
106
|
const float* b;
|
|
105
107
|
size_t ndis;
|
|
108
|
+
size_t npartial_dot_products;
|
|
106
109
|
|
|
107
110
|
float distance_to_code(const uint8_t* code) final {
|
|
108
111
|
ndis++;
|
|
109
112
|
return fvec_L2sqr(q, (float*)code, d);
|
|
110
113
|
}
|
|
111
114
|
|
|
115
|
+
float partial_dot_product(
|
|
116
|
+
const idx_t i,
|
|
117
|
+
const uint32_t offset,
|
|
118
|
+
const uint32_t num_components) final override {
|
|
119
|
+
npartial_dot_products++;
|
|
120
|
+
return fvec_inner_product(
|
|
121
|
+
q + offset, b + i * d + offset, num_components);
|
|
122
|
+
}
|
|
123
|
+
|
|
112
124
|
float symmetric_dis(idx_t i, idx_t j) override {
|
|
113
125
|
return fvec_L2sqr(b + j * d, b + i * d, d);
|
|
114
126
|
}
|
|
@@ -116,12 +128,13 @@ struct FlatL2Dis : FlatCodesDistanceComputer {
|
|
|
116
128
|
explicit FlatL2Dis(const IndexFlat& storage, const float* q = nullptr)
|
|
117
129
|
: FlatCodesDistanceComputer(
|
|
118
130
|
storage.codes.data(),
|
|
119
|
-
storage.code_size
|
|
131
|
+
storage.code_size,
|
|
132
|
+
q),
|
|
120
133
|
d(storage.d),
|
|
121
134
|
nb(storage.ntotal),
|
|
122
|
-
q(q),
|
|
123
135
|
b(storage.get_xb()),
|
|
124
|
-
ndis(0)
|
|
136
|
+
ndis(0),
|
|
137
|
+
npartial_dot_products(0) {}
|
|
125
138
|
|
|
126
139
|
void set_query(const float* x) override {
|
|
127
140
|
q = x;
|
|
@@ -159,6 +172,50 @@ struct FlatL2Dis : FlatCodesDistanceComputer {
|
|
|
159
172
|
dis2 = dp2;
|
|
160
173
|
dis3 = dp3;
|
|
161
174
|
}
|
|
175
|
+
|
|
176
|
+
void partial_dot_product_batch_4(
|
|
177
|
+
const idx_t idx0,
|
|
178
|
+
const idx_t idx1,
|
|
179
|
+
const idx_t idx2,
|
|
180
|
+
const idx_t idx3,
|
|
181
|
+
float& dp0,
|
|
182
|
+
float& dp1,
|
|
183
|
+
float& dp2,
|
|
184
|
+
float& dp3,
|
|
185
|
+
const uint32_t offset,
|
|
186
|
+
const uint32_t num_components) final override {
|
|
187
|
+
npartial_dot_products += 4;
|
|
188
|
+
|
|
189
|
+
// compute first, assign next
|
|
190
|
+
const float* __restrict y0 =
|
|
191
|
+
reinterpret_cast<const float*>(codes + idx0 * code_size);
|
|
192
|
+
const float* __restrict y1 =
|
|
193
|
+
reinterpret_cast<const float*>(codes + idx1 * code_size);
|
|
194
|
+
const float* __restrict y2 =
|
|
195
|
+
reinterpret_cast<const float*>(codes + idx2 * code_size);
|
|
196
|
+
const float* __restrict y3 =
|
|
197
|
+
reinterpret_cast<const float*>(codes + idx3 * code_size);
|
|
198
|
+
|
|
199
|
+
float dp0_ = 0;
|
|
200
|
+
float dp1_ = 0;
|
|
201
|
+
float dp2_ = 0;
|
|
202
|
+
float dp3_ = 0;
|
|
203
|
+
fvec_inner_product_batch_4(
|
|
204
|
+
q + offset,
|
|
205
|
+
y0 + offset,
|
|
206
|
+
y1 + offset,
|
|
207
|
+
y2 + offset,
|
|
208
|
+
y3 + offset,
|
|
209
|
+
num_components,
|
|
210
|
+
dp0_,
|
|
211
|
+
dp1_,
|
|
212
|
+
dp2_,
|
|
213
|
+
dp3_);
|
|
214
|
+
dp0 = dp0_;
|
|
215
|
+
dp1 = dp1_;
|
|
216
|
+
dp2 = dp2_;
|
|
217
|
+
dp3 = dp3_;
|
|
218
|
+
}
|
|
162
219
|
};
|
|
163
220
|
|
|
164
221
|
struct FlatIPDis : FlatCodesDistanceComputer {
|
|
@@ -519,4 +576,317 @@ void IndexFlat1D::search(
|
|
|
519
576
|
done:;
|
|
520
577
|
}
|
|
521
578
|
}
|
|
579
|
+
|
|
580
|
+
/**************************************************************
|
|
581
|
+
* shared flat Panorama search code
|
|
582
|
+
**************************************************************/
|
|
583
|
+
|
|
584
|
+
namespace {
|
|
585
|
+
|
|
586
|
+
template <bool use_radius, typename BlockHandler>
|
|
587
|
+
inline void flat_pano_search_core(
|
|
588
|
+
const IndexFlatPanorama& index,
|
|
589
|
+
BlockHandler& handler,
|
|
590
|
+
idx_t n,
|
|
591
|
+
const float* x,
|
|
592
|
+
float radius,
|
|
593
|
+
const SearchParameters* params) {
|
|
594
|
+
using SingleResultHandler = typename BlockHandler::SingleResultHandler;
|
|
595
|
+
|
|
596
|
+
IDSelector* sel = params ? params->sel : nullptr;
|
|
597
|
+
bool use_sel = sel != nullptr;
|
|
598
|
+
|
|
599
|
+
[[maybe_unused]] int nt = std::min(int(n), omp_get_max_threads());
|
|
600
|
+
size_t n_batches = (index.ntotal + index.batch_size - 1) / index.batch_size;
|
|
601
|
+
|
|
602
|
+
#pragma omp parallel num_threads(nt)
|
|
603
|
+
{
|
|
604
|
+
SingleResultHandler res(handler);
|
|
605
|
+
|
|
606
|
+
std::vector<float> query_cum_norms(index.n_levels + 1);
|
|
607
|
+
std::vector<float> exact_distances(index.batch_size);
|
|
608
|
+
std::vector<uint32_t> active_indices(index.batch_size);
|
|
609
|
+
|
|
610
|
+
#pragma omp for
|
|
611
|
+
for (int64_t i = 0; i < n; i++) {
|
|
612
|
+
const float* xi = x + i * index.d;
|
|
613
|
+
index.pano.compute_query_cum_sums(xi, query_cum_norms.data());
|
|
614
|
+
|
|
615
|
+
PanoramaStats local_stats;
|
|
616
|
+
local_stats.reset();
|
|
617
|
+
|
|
618
|
+
res.begin(i);
|
|
619
|
+
|
|
620
|
+
for (size_t batch_no = 0; batch_no < n_batches; batch_no++) {
|
|
621
|
+
size_t batch_start = batch_no * index.batch_size;
|
|
622
|
+
|
|
623
|
+
float threshold;
|
|
624
|
+
if constexpr (use_radius) {
|
|
625
|
+
threshold = radius;
|
|
626
|
+
} else {
|
|
627
|
+
threshold = res.heap_dis[0];
|
|
628
|
+
}
|
|
629
|
+
|
|
630
|
+
size_t num_active =
|
|
631
|
+
index.pano
|
|
632
|
+
.progressive_filter_batch<CMax<float, int64_t>>(
|
|
633
|
+
index.codes.data(),
|
|
634
|
+
index.cum_sums.data(),
|
|
635
|
+
xi,
|
|
636
|
+
query_cum_norms.data(),
|
|
637
|
+
batch_no,
|
|
638
|
+
index.ntotal,
|
|
639
|
+
sel,
|
|
640
|
+
nullptr,
|
|
641
|
+
use_sel,
|
|
642
|
+
active_indices,
|
|
643
|
+
exact_distances,
|
|
644
|
+
threshold,
|
|
645
|
+
local_stats);
|
|
646
|
+
|
|
647
|
+
for (size_t j = 0; j < num_active; j++) {
|
|
648
|
+
res.add_result(
|
|
649
|
+
exact_distances[active_indices[j]],
|
|
650
|
+
batch_start + active_indices[j]);
|
|
651
|
+
}
|
|
652
|
+
}
|
|
653
|
+
|
|
654
|
+
res.end();
|
|
655
|
+
indexPanorama_stats.add(local_stats);
|
|
656
|
+
}
|
|
657
|
+
}
|
|
658
|
+
}
|
|
659
|
+
|
|
660
|
+
} // anonymous namespace
|
|
661
|
+
|
|
662
|
+
/***************************************************
|
|
663
|
+
* IndexFlatPanorama
|
|
664
|
+
***************************************************/
|
|
665
|
+
|
|
666
|
+
void IndexFlatPanorama::add(idx_t n, const float* x) {
|
|
667
|
+
size_t offset = ntotal;
|
|
668
|
+
ntotal += n;
|
|
669
|
+
size_t num_batches = (ntotal + batch_size - 1) / batch_size;
|
|
670
|
+
|
|
671
|
+
codes.resize(num_batches * batch_size * code_size);
|
|
672
|
+
cum_sums.resize(num_batches * batch_size * (n_levels + 1));
|
|
673
|
+
|
|
674
|
+
const uint8_t* code = reinterpret_cast<const uint8_t*>(x);
|
|
675
|
+
pano.copy_codes_to_level_layout(codes.data(), offset, n, code);
|
|
676
|
+
pano.compute_cumulative_sums(cum_sums.data(), offset, n, x);
|
|
677
|
+
}
|
|
678
|
+
|
|
679
|
+
void IndexFlatPanorama::search(
|
|
680
|
+
idx_t n,
|
|
681
|
+
const float* x,
|
|
682
|
+
idx_t k,
|
|
683
|
+
float* distances,
|
|
684
|
+
idx_t* labels,
|
|
685
|
+
const SearchParameters* params) const {
|
|
686
|
+
FAISS_THROW_IF_NOT(k > 0);
|
|
687
|
+
FAISS_THROW_IF_NOT(batch_size >= k);
|
|
688
|
+
|
|
689
|
+
HeapBlockResultHandler<CMax<float, int64_t>, false> handler(
|
|
690
|
+
size_t(n), distances, labels, size_t(k), nullptr);
|
|
691
|
+
|
|
692
|
+
flat_pano_search_core<false>(*this, handler, n, x, 0.0f, params);
|
|
693
|
+
}
|
|
694
|
+
|
|
695
|
+
void IndexFlatPanorama::range_search(
|
|
696
|
+
idx_t n,
|
|
697
|
+
const float* x,
|
|
698
|
+
float radius,
|
|
699
|
+
RangeSearchResult* result,
|
|
700
|
+
const SearchParameters* params) const {
|
|
701
|
+
RangeSearchBlockResultHandler<CMax<float, int64_t>, false> handler(
|
|
702
|
+
result, radius, nullptr);
|
|
703
|
+
|
|
704
|
+
flat_pano_search_core<true>(*this, handler, n, x, radius, params);
|
|
705
|
+
}
|
|
706
|
+
|
|
707
|
+
void IndexFlatPanorama::reset() {
|
|
708
|
+
IndexFlat::reset();
|
|
709
|
+
cum_sums.clear();
|
|
710
|
+
}
|
|
711
|
+
|
|
712
|
+
void IndexFlatPanorama::reconstruct(idx_t key, float* recons) const {
|
|
713
|
+
pano.reconstruct(key, recons, codes.data());
|
|
714
|
+
}
|
|
715
|
+
|
|
716
|
+
void IndexFlatPanorama::reconstruct_n(idx_t i, idx_t n, float* recons) const {
|
|
717
|
+
Index::reconstruct_n(i, n, recons);
|
|
718
|
+
}
|
|
719
|
+
|
|
720
|
+
size_t IndexFlatPanorama::remove_ids(const IDSelector& sel) {
|
|
721
|
+
idx_t j = 0;
|
|
722
|
+
for (idx_t i = 0; i < ntotal; i++) {
|
|
723
|
+
if (sel.is_member(i)) {
|
|
724
|
+
// should be removed
|
|
725
|
+
} else {
|
|
726
|
+
if (i > j) {
|
|
727
|
+
pano.copy_entry(
|
|
728
|
+
codes.data(),
|
|
729
|
+
codes.data(),
|
|
730
|
+
cum_sums.data(),
|
|
731
|
+
cum_sums.data(),
|
|
732
|
+
j,
|
|
733
|
+
i);
|
|
734
|
+
}
|
|
735
|
+
j++;
|
|
736
|
+
}
|
|
737
|
+
}
|
|
738
|
+
size_t nremove = ntotal - j;
|
|
739
|
+
if (nremove > 0) {
|
|
740
|
+
ntotal = j;
|
|
741
|
+
size_t num_batches = (ntotal + batch_size - 1) / batch_size;
|
|
742
|
+
codes.resize(num_batches * batch_size * code_size);
|
|
743
|
+
cum_sums.resize(num_batches * batch_size * (n_levels + 1));
|
|
744
|
+
}
|
|
745
|
+
return nremove;
|
|
746
|
+
}
|
|
747
|
+
|
|
748
|
+
void IndexFlatPanorama::merge_from(Index& otherIndex, idx_t add_id) {
|
|
749
|
+
FAISS_THROW_IF_NOT_MSG(add_id == 0, "cannot set ids in FlatPanorama index");
|
|
750
|
+
check_compatible_for_merge(otherIndex);
|
|
751
|
+
IndexFlatPanorama* other = static_cast<IndexFlatPanorama*>(&otherIndex);
|
|
752
|
+
|
|
753
|
+
std::vector<float> buffer(other->ntotal * code_size);
|
|
754
|
+
otherIndex.reconstruct_n(0, other->ntotal, buffer.data());
|
|
755
|
+
|
|
756
|
+
add(other->ntotal, buffer.data());
|
|
757
|
+
other->reset();
|
|
758
|
+
}
|
|
759
|
+
|
|
760
|
+
void IndexFlatPanorama::add_sa_codes(
|
|
761
|
+
idx_t /* n */,
|
|
762
|
+
const uint8_t* /* codes_in */,
|
|
763
|
+
const idx_t* /* xids */) {
|
|
764
|
+
FAISS_THROW_MSG("add_sa_codes not implemented for IndexFlatPanorama");
|
|
765
|
+
}
|
|
766
|
+
|
|
767
|
+
void IndexFlatPanorama::permute_entries(const idx_t* perm) {
|
|
768
|
+
MaybeOwnedVector<uint8_t> new_codes(codes.size());
|
|
769
|
+
std::vector<float> new_cum_sums(cum_sums.size());
|
|
770
|
+
|
|
771
|
+
for (idx_t i = 0; i < ntotal; i++) {
|
|
772
|
+
pano.copy_entry(
|
|
773
|
+
new_codes.data(),
|
|
774
|
+
codes.data(),
|
|
775
|
+
new_cum_sums.data(),
|
|
776
|
+
cum_sums.data(),
|
|
777
|
+
i,
|
|
778
|
+
perm[i]);
|
|
779
|
+
}
|
|
780
|
+
|
|
781
|
+
std::swap(codes, new_codes);
|
|
782
|
+
std::swap(cum_sums, new_cum_sums);
|
|
783
|
+
}
|
|
784
|
+
|
|
785
|
+
void IndexFlatPanorama::search_subset(
|
|
786
|
+
idx_t n,
|
|
787
|
+
const float* x,
|
|
788
|
+
idx_t k_base,
|
|
789
|
+
const idx_t* base_labels,
|
|
790
|
+
idx_t k,
|
|
791
|
+
float* distances,
|
|
792
|
+
idx_t* labels) const {
|
|
793
|
+
using SingleResultHandler =
|
|
794
|
+
HeapBlockResultHandler<CMax<float, int64_t>, false>::
|
|
795
|
+
SingleResultHandler;
|
|
796
|
+
HeapBlockResultHandler<CMax<float, int64_t>, false> handler(
|
|
797
|
+
size_t(n), distances, labels, size_t(k), nullptr);
|
|
798
|
+
|
|
799
|
+
FAISS_THROW_IF_NOT(k > 0);
|
|
800
|
+
FAISS_THROW_IF_NOT(batch_size == 1);
|
|
801
|
+
|
|
802
|
+
[[maybe_unused]] int nt = std::min(int(n), omp_get_max_threads());
|
|
803
|
+
|
|
804
|
+
#pragma omp parallel num_threads(nt)
|
|
805
|
+
{
|
|
806
|
+
SingleResultHandler res(handler);
|
|
807
|
+
|
|
808
|
+
std::vector<float> query_cum_norms(n_levels + 1);
|
|
809
|
+
|
|
810
|
+
// Panorama's optimized point-wise refinement (Algorithm 2):
|
|
811
|
+
// Batch-wise Panorama, as implemented in Panorama.h, incurs overhead
|
|
812
|
+
// from maintaining active_indices and exact_distances. This optimized
|
|
813
|
+
// implementation has minimal overhead and is thus preferred for
|
|
814
|
+
// IndexRefine's use case.
|
|
815
|
+
// 1. Initialize exact distance as ||y||^2 + ||x||^2.
|
|
816
|
+
// 2. For each level, refine distance incrementally:
|
|
817
|
+
// - Compute dot product for current level: exact_dist -= 2*<x,y>.
|
|
818
|
+
// - Use Cauchy-Schwarz bound on remaining levels to get lower bound.
|
|
819
|
+
// - If there are less than k points in the heap, add the point to
|
|
820
|
+
// the heap.
|
|
821
|
+
// - Else, prune if lower bound exceeds k-th best distance.
|
|
822
|
+
// 3. After all levels, update heap if the point survived.
|
|
823
|
+
#pragma omp for
|
|
824
|
+
for (idx_t i = 0; i < n; i++) {
|
|
825
|
+
const idx_t* __restrict idsi = base_labels + i * k_base;
|
|
826
|
+
const float* xi = x + i * d;
|
|
827
|
+
|
|
828
|
+
PanoramaStats local_stats;
|
|
829
|
+
local_stats.reset();
|
|
830
|
+
|
|
831
|
+
pano.compute_query_cum_sums(xi, query_cum_norms.data());
|
|
832
|
+
float query_cum_norm = query_cum_norms[0] * query_cum_norms[0];
|
|
833
|
+
|
|
834
|
+
res.begin(i);
|
|
835
|
+
|
|
836
|
+
for (size_t j = 0; j < k_base; j++) {
|
|
837
|
+
idx_t idx = idsi[j];
|
|
838
|
+
|
|
839
|
+
if (idx < 0) {
|
|
840
|
+
continue;
|
|
841
|
+
}
|
|
842
|
+
|
|
843
|
+
size_t cum_sum_offset = (n_levels + 1) * idx;
|
|
844
|
+
float cum_sum = cum_sums[cum_sum_offset];
|
|
845
|
+
float exact_distance = cum_sum * cum_sum + query_cum_norm;
|
|
846
|
+
cum_sum_offset++;
|
|
847
|
+
|
|
848
|
+
const float* x_ptr = xi;
|
|
849
|
+
const float* p_ptr =
|
|
850
|
+
reinterpret_cast<const float*>(codes.data()) + d * idx;
|
|
851
|
+
|
|
852
|
+
local_stats.total_dims += d;
|
|
853
|
+
|
|
854
|
+
bool pruned = false;
|
|
855
|
+
for (size_t level = 0; level < n_levels; level++) {
|
|
856
|
+
local_stats.total_dims_scanned += pano.level_width_floats;
|
|
857
|
+
|
|
858
|
+
// Refine distance
|
|
859
|
+
size_t actual_level_width = std::min(
|
|
860
|
+
pano.level_width_floats,
|
|
861
|
+
d - level * pano.level_width_floats);
|
|
862
|
+
float dot_product = fvec_inner_product(
|
|
863
|
+
x_ptr, p_ptr, actual_level_width);
|
|
864
|
+
exact_distance -= 2 * dot_product;
|
|
865
|
+
|
|
866
|
+
float cum_sum = cum_sums[cum_sum_offset];
|
|
867
|
+
float cauchy_schwarz_bound =
|
|
868
|
+
2.0f * cum_sum * query_cum_norms[level + 1];
|
|
869
|
+
float lower_bound = exact_distance - cauchy_schwarz_bound;
|
|
870
|
+
|
|
871
|
+
// Prune using Cauchy-Schwarz bound
|
|
872
|
+
if (lower_bound > res.heap_dis[0]) {
|
|
873
|
+
pruned = true;
|
|
874
|
+
break;
|
|
875
|
+
}
|
|
876
|
+
|
|
877
|
+
cum_sum_offset++;
|
|
878
|
+
x_ptr += pano.level_width_floats;
|
|
879
|
+
p_ptr += pano.level_width_floats;
|
|
880
|
+
}
|
|
881
|
+
|
|
882
|
+
if (!pruned) {
|
|
883
|
+
res.add_result(exact_distance, idx);
|
|
884
|
+
}
|
|
885
|
+
}
|
|
886
|
+
|
|
887
|
+
res.end();
|
|
888
|
+
indexPanorama_stats.add(local_stats);
|
|
889
|
+
}
|
|
890
|
+
}
|
|
891
|
+
}
|
|
522
892
|
} // namespace faiss
|