faiss 0.2.4 → 0.2.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +5 -0
- data/README.md +23 -21
- data/ext/faiss/extconf.rb +11 -0
- data/ext/faiss/index.cpp +4 -4
- data/ext/faiss/index_binary.cpp +6 -6
- data/ext/faiss/product_quantizer.cpp +4 -4
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/faiss/AutoTune.cpp +13 -0
- data/vendor/faiss/faiss/IVFlib.cpp +101 -2
- data/vendor/faiss/faiss/IVFlib.h +26 -2
- data/vendor/faiss/faiss/Index.cpp +36 -3
- data/vendor/faiss/faiss/Index.h +43 -6
- data/vendor/faiss/faiss/Index2Layer.cpp +6 -2
- data/vendor/faiss/faiss/Index2Layer.h +6 -1
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +219 -16
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +63 -5
- data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +299 -0
- data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +199 -0
- data/vendor/faiss/faiss/IndexBinary.cpp +20 -4
- data/vendor/faiss/faiss/IndexBinary.h +18 -3
- data/vendor/faiss/faiss/IndexBinaryFlat.cpp +9 -2
- data/vendor/faiss/faiss/IndexBinaryFlat.h +4 -2
- data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +4 -1
- data/vendor/faiss/faiss/IndexBinaryFromFloat.h +2 -1
- data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +5 -1
- data/vendor/faiss/faiss/IndexBinaryHNSW.h +2 -1
- data/vendor/faiss/faiss/IndexBinaryHash.cpp +17 -4
- data/vendor/faiss/faiss/IndexBinaryHash.h +8 -4
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +28 -13
- data/vendor/faiss/faiss/IndexBinaryIVF.h +10 -7
- data/vendor/faiss/faiss/IndexFastScan.cpp +626 -0
- data/vendor/faiss/faiss/IndexFastScan.h +145 -0
- data/vendor/faiss/faiss/IndexFlat.cpp +34 -21
- data/vendor/faiss/faiss/IndexFlat.h +7 -4
- data/vendor/faiss/faiss/IndexFlatCodes.cpp +35 -1
- data/vendor/faiss/faiss/IndexFlatCodes.h +12 -0
- data/vendor/faiss/faiss/IndexHNSW.cpp +66 -138
- data/vendor/faiss/faiss/IndexHNSW.h +4 -2
- data/vendor/faiss/faiss/IndexIDMap.cpp +247 -0
- data/vendor/faiss/faiss/IndexIDMap.h +107 -0
- data/vendor/faiss/faiss/IndexIVF.cpp +121 -33
- data/vendor/faiss/faiss/IndexIVF.h +35 -16
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +84 -7
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +63 -1
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +590 -0
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +171 -0
- data/vendor/faiss/faiss/IndexIVFFastScan.cpp +1290 -0
- data/vendor/faiss/faiss/IndexIVFFastScan.h +213 -0
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +37 -17
- data/vendor/faiss/faiss/IndexIVFFlat.h +4 -2
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +234 -50
- data/vendor/faiss/faiss/IndexIVFPQ.h +5 -1
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +23 -852
- data/vendor/faiss/faiss/IndexIVFPQFastScan.h +7 -112
- data/vendor/faiss/faiss/IndexIVFPQR.cpp +3 -3
- data/vendor/faiss/faiss/IndexIVFPQR.h +1 -1
- data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +3 -1
- data/vendor/faiss/faiss/IndexIVFSpectralHash.h +2 -1
- data/vendor/faiss/faiss/IndexLSH.cpp +4 -2
- data/vendor/faiss/faiss/IndexLSH.h +2 -1
- data/vendor/faiss/faiss/IndexLattice.cpp +7 -1
- data/vendor/faiss/faiss/IndexLattice.h +3 -1
- data/vendor/faiss/faiss/IndexNNDescent.cpp +4 -3
- data/vendor/faiss/faiss/IndexNNDescent.h +2 -1
- data/vendor/faiss/faiss/IndexNSG.cpp +37 -3
- data/vendor/faiss/faiss/IndexNSG.h +25 -1
- data/vendor/faiss/faiss/IndexPQ.cpp +106 -69
- data/vendor/faiss/faiss/IndexPQ.h +19 -5
- data/vendor/faiss/faiss/IndexPQFastScan.cpp +15 -450
- data/vendor/faiss/faiss/IndexPQFastScan.h +15 -78
- data/vendor/faiss/faiss/IndexPreTransform.cpp +47 -8
- data/vendor/faiss/faiss/IndexPreTransform.h +15 -3
- data/vendor/faiss/faiss/IndexRefine.cpp +8 -4
- data/vendor/faiss/faiss/IndexRefine.h +4 -2
- data/vendor/faiss/faiss/IndexReplicas.cpp +4 -2
- data/vendor/faiss/faiss/IndexReplicas.h +2 -1
- data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +438 -0
- data/vendor/faiss/faiss/IndexRowwiseMinMax.h +92 -0
- data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +26 -15
- data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -7
- data/vendor/faiss/faiss/IndexShards.cpp +4 -1
- data/vendor/faiss/faiss/IndexShards.h +2 -1
- data/vendor/faiss/faiss/MetaIndexes.cpp +5 -178
- data/vendor/faiss/faiss/MetaIndexes.h +3 -81
- data/vendor/faiss/faiss/VectorTransform.cpp +43 -0
- data/vendor/faiss/faiss/VectorTransform.h +22 -4
- data/vendor/faiss/faiss/clone_index.cpp +23 -1
- data/vendor/faiss/faiss/clone_index.h +3 -0
- data/vendor/faiss/faiss/cppcontrib/SaDecodeKernels.h +300 -0
- data/vendor/faiss/faiss/cppcontrib/detail/CoarseBitType.h +24 -0
- data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +195 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +2058 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +408 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-neon-inl.h +2147 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMax-inl.h +460 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMaxFP16-inl.h +465 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +1618 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +251 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-neon-inl.h +1452 -0
- data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +1 -0
- data/vendor/faiss/faiss/gpu/GpuCloner.cpp +0 -4
- data/vendor/faiss/faiss/gpu/GpuIndex.h +28 -4
- data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +2 -1
- data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +10 -8
- data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +75 -14
- data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +19 -32
- data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -31
- data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +22 -28
- data/vendor/faiss/faiss/gpu/GpuResources.cpp +14 -0
- data/vendor/faiss/faiss/gpu/GpuResources.h +16 -3
- data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +3 -3
- data/vendor/faiss/faiss/gpu/impl/IndexUtils.h +32 -0
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +1 -0
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +311 -75
- data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +10 -0
- data/vendor/faiss/faiss/gpu/test/TestUtils.h +3 -0
- data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +2 -2
- data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +5 -4
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +116 -47
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +44 -13
- data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +0 -54
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +0 -76
- data/vendor/faiss/faiss/impl/DistanceComputer.h +64 -0
- data/vendor/faiss/faiss/impl/HNSW.cpp +123 -27
- data/vendor/faiss/faiss/impl/HNSW.h +19 -16
- data/vendor/faiss/faiss/impl/IDSelector.cpp +125 -0
- data/vendor/faiss/faiss/impl/IDSelector.h +135 -0
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +6 -28
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +6 -1
- data/vendor/faiss/faiss/impl/LookupTableScaler.h +77 -0
- data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -0
- data/vendor/faiss/faiss/impl/NSG.cpp +1 -1
- data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +383 -0
- data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +154 -0
- data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +225 -145
- data/vendor/faiss/faiss/impl/ProductQuantizer.h +29 -10
- data/vendor/faiss/faiss/impl/Quantizer.h +43 -0
- data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +192 -36
- data/vendor/faiss/faiss/impl/ResidualQuantizer.h +40 -20
- data/vendor/faiss/faiss/impl/ResultHandler.h +96 -0
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +97 -173
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +18 -18
- data/vendor/faiss/faiss/impl/index_read.cpp +240 -9
- data/vendor/faiss/faiss/impl/index_write.cpp +237 -5
- data/vendor/faiss/faiss/impl/kmeans1d.cpp +6 -4
- data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +56 -16
- data/vendor/faiss/faiss/impl/pq4_fast_scan.h +25 -8
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +66 -25
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +75 -27
- data/vendor/faiss/faiss/index_factory.cpp +196 -7
- data/vendor/faiss/faiss/index_io.h +5 -0
- data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -0
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +4 -1
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +2 -1
- data/vendor/faiss/faiss/python/python_callbacks.cpp +27 -0
- data/vendor/faiss/faiss/python/python_callbacks.h +15 -0
- data/vendor/faiss/faiss/utils/Heap.h +31 -15
- data/vendor/faiss/faiss/utils/distances.cpp +380 -56
- data/vendor/faiss/faiss/utils/distances.h +113 -15
- data/vendor/faiss/faiss/utils/distances_simd.cpp +726 -6
- data/vendor/faiss/faiss/utils/extra_distances.cpp +12 -7
- data/vendor/faiss/faiss/utils/extra_distances.h +3 -1
- data/vendor/faiss/faiss/utils/fp16-fp16c.h +21 -0
- data/vendor/faiss/faiss/utils/fp16-inl.h +101 -0
- data/vendor/faiss/faiss/utils/fp16.h +11 -0
- data/vendor/faiss/faiss/utils/hamming-inl.h +54 -0
- data/vendor/faiss/faiss/utils/hamming.cpp +0 -48
- data/vendor/faiss/faiss/utils/ordered_key_value.h +10 -0
- data/vendor/faiss/faiss/utils/quantize_lut.cpp +62 -0
- data/vendor/faiss/faiss/utils/quantize_lut.h +20 -0
- data/vendor/faiss/faiss/utils/random.cpp +53 -0
- data/vendor/faiss/faiss/utils/random.h +5 -0
- data/vendor/faiss/faiss/utils/simdlib_avx2.h +4 -0
- data/vendor/faiss/faiss/utils/simdlib_emulated.h +6 -1
- data/vendor/faiss/faiss/utils/simdlib_neon.h +7 -2
- metadata +37 -3
|
@@ -0,0 +1,213 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Copyright (c) Facebook, Inc. and its affiliates.
|
|
3
|
+
*
|
|
4
|
+
* This source code is licensed under the MIT license found in the
|
|
5
|
+
* LICENSE file in the root directory of this source tree.
|
|
6
|
+
*/
|
|
7
|
+
|
|
8
|
+
#pragma once
|
|
9
|
+
|
|
10
|
+
#include <memory>
|
|
11
|
+
|
|
12
|
+
#include <faiss/IndexIVF.h>
|
|
13
|
+
#include <faiss/utils/AlignedTable.h>
|
|
14
|
+
|
|
15
|
+
namespace faiss {
|
|
16
|
+
|
|
17
|
+
/** Fast scan version of IVFPQ and IVFAQ. Works for 4-bit PQ/AQ for now.
|
|
18
|
+
*
|
|
19
|
+
* The codes in the inverted lists are not stored sequentially but
|
|
20
|
+
* grouped in blocks of size bbs. This makes it possible to very quickly
|
|
21
|
+
* compute distances with SIMD instructions.
|
|
22
|
+
*
|
|
23
|
+
* Implementations (implem):
|
|
24
|
+
* 0: auto-select implementation (default)
|
|
25
|
+
* 1: orig's search, re-implemented
|
|
26
|
+
* 2: orig's search, re-ordered by invlist
|
|
27
|
+
* 10: optimizer int16 search, collect results in heap, no qbs
|
|
28
|
+
* 11: idem, collect results in reservoir
|
|
29
|
+
* 12: optimizer int16 search, collect results in heap, uses qbs
|
|
30
|
+
* 13: idem, collect results in reservoir
|
|
31
|
+
*/
|
|
32
|
+
|
|
33
|
+
struct IndexIVFFastScan : IndexIVF {
|
|
34
|
+
// size of the kernel
|
|
35
|
+
int bbs; // set at build time
|
|
36
|
+
|
|
37
|
+
size_t M;
|
|
38
|
+
size_t nbits;
|
|
39
|
+
size_t ksub;
|
|
40
|
+
|
|
41
|
+
// M rounded up to a multiple of 2
|
|
42
|
+
size_t M2;
|
|
43
|
+
|
|
44
|
+
// search-time implementation
|
|
45
|
+
int implem = 0;
|
|
46
|
+
// skip some parts of the computation (for timing)
|
|
47
|
+
int skip = 0;
|
|
48
|
+
bool by_residual = false;
|
|
49
|
+
|
|
50
|
+
// batching factors at search time (0 = default)
|
|
51
|
+
int qbs = 0;
|
|
52
|
+
size_t qbs2 = 0;
|
|
53
|
+
|
|
54
|
+
IndexIVFFastScan(
|
|
55
|
+
Index* quantizer,
|
|
56
|
+
size_t d,
|
|
57
|
+
size_t nlist,
|
|
58
|
+
size_t code_size,
|
|
59
|
+
MetricType metric = METRIC_L2);
|
|
60
|
+
|
|
61
|
+
IndexIVFFastScan();
|
|
62
|
+
|
|
63
|
+
void init_fastscan(
|
|
64
|
+
size_t M,
|
|
65
|
+
size_t nbits,
|
|
66
|
+
size_t nlist,
|
|
67
|
+
MetricType metric,
|
|
68
|
+
int bbs);
|
|
69
|
+
|
|
70
|
+
~IndexIVFFastScan() override;
|
|
71
|
+
|
|
72
|
+
/// orig's inverted lists (for debugging)
|
|
73
|
+
InvertedLists* orig_invlists = nullptr;
|
|
74
|
+
|
|
75
|
+
void add_with_ids(idx_t n, const float* x, const idx_t* xids) override;
|
|
76
|
+
|
|
77
|
+
// prepare look-up tables
|
|
78
|
+
|
|
79
|
+
virtual bool lookup_table_is_3d() const = 0;
|
|
80
|
+
|
|
81
|
+
virtual void compute_LUT(
|
|
82
|
+
size_t n,
|
|
83
|
+
const float* x,
|
|
84
|
+
const idx_t* coarse_ids,
|
|
85
|
+
const float* coarse_dis,
|
|
86
|
+
AlignedTable<float>& dis_tables,
|
|
87
|
+
AlignedTable<float>& biases) const = 0;
|
|
88
|
+
|
|
89
|
+
void compute_LUT_uint8(
|
|
90
|
+
size_t n,
|
|
91
|
+
const float* x,
|
|
92
|
+
const idx_t* coarse_ids,
|
|
93
|
+
const float* coarse_dis,
|
|
94
|
+
AlignedTable<uint8_t>& dis_tables,
|
|
95
|
+
AlignedTable<uint16_t>& biases,
|
|
96
|
+
float* normalizers) const;
|
|
97
|
+
|
|
98
|
+
void search(
|
|
99
|
+
idx_t n,
|
|
100
|
+
const float* x,
|
|
101
|
+
idx_t k,
|
|
102
|
+
float* distances,
|
|
103
|
+
idx_t* labels,
|
|
104
|
+
const SearchParameters* params = nullptr) const override;
|
|
105
|
+
|
|
106
|
+
/// will just fail
|
|
107
|
+
void range_search(
|
|
108
|
+
idx_t n,
|
|
109
|
+
const float* x,
|
|
110
|
+
float radius,
|
|
111
|
+
RangeSearchResult* result,
|
|
112
|
+
const SearchParameters* params = nullptr) const override;
|
|
113
|
+
|
|
114
|
+
// internal search funcs
|
|
115
|
+
|
|
116
|
+
template <bool is_max, class Scaler>
|
|
117
|
+
void search_dispatch_implem(
|
|
118
|
+
idx_t n,
|
|
119
|
+
const float* x,
|
|
120
|
+
idx_t k,
|
|
121
|
+
float* distances,
|
|
122
|
+
idx_t* labels,
|
|
123
|
+
const Scaler& scaler) const;
|
|
124
|
+
|
|
125
|
+
template <class C, class Scaler>
|
|
126
|
+
void search_implem_1(
|
|
127
|
+
idx_t n,
|
|
128
|
+
const float* x,
|
|
129
|
+
idx_t k,
|
|
130
|
+
float* distances,
|
|
131
|
+
idx_t* labels,
|
|
132
|
+
const Scaler& scaler) const;
|
|
133
|
+
|
|
134
|
+
template <class C, class Scaler>
|
|
135
|
+
void search_implem_2(
|
|
136
|
+
idx_t n,
|
|
137
|
+
const float* x,
|
|
138
|
+
idx_t k,
|
|
139
|
+
float* distances,
|
|
140
|
+
idx_t* labels,
|
|
141
|
+
const Scaler& scaler) const;
|
|
142
|
+
|
|
143
|
+
// implem 10 and 12 are not multithreaded internally, so
|
|
144
|
+
// export search stats
|
|
145
|
+
template <class C, class Scaler>
|
|
146
|
+
void search_implem_10(
|
|
147
|
+
idx_t n,
|
|
148
|
+
const float* x,
|
|
149
|
+
idx_t k,
|
|
150
|
+
float* distances,
|
|
151
|
+
idx_t* labels,
|
|
152
|
+
int impl,
|
|
153
|
+
size_t* ndis_out,
|
|
154
|
+
size_t* nlist_out,
|
|
155
|
+
const Scaler& scaler) const;
|
|
156
|
+
|
|
157
|
+
template <class C, class Scaler>
|
|
158
|
+
void search_implem_12(
|
|
159
|
+
idx_t n,
|
|
160
|
+
const float* x,
|
|
161
|
+
idx_t k,
|
|
162
|
+
float* distances,
|
|
163
|
+
idx_t* labels,
|
|
164
|
+
int impl,
|
|
165
|
+
size_t* ndis_out,
|
|
166
|
+
size_t* nlist_out,
|
|
167
|
+
const Scaler& scaler) const;
|
|
168
|
+
|
|
169
|
+
// implem 14 is mukltithreaded internally across nprobes and queries
|
|
170
|
+
template <class C, class Scaler>
|
|
171
|
+
void search_implem_14(
|
|
172
|
+
idx_t n,
|
|
173
|
+
const float* x,
|
|
174
|
+
idx_t k,
|
|
175
|
+
float* distances,
|
|
176
|
+
idx_t* labels,
|
|
177
|
+
int impl,
|
|
178
|
+
const Scaler& scaler) const;
|
|
179
|
+
|
|
180
|
+
// reconstruct vectors from packed invlists
|
|
181
|
+
void reconstruct_from_offset(int64_t list_no, int64_t offset, float* recons)
|
|
182
|
+
const override;
|
|
183
|
+
|
|
184
|
+
// reconstruct orig invlists (for debugging)
|
|
185
|
+
void reconstruct_orig_invlists();
|
|
186
|
+
};
|
|
187
|
+
|
|
188
|
+
struct IVFFastScanStats {
|
|
189
|
+
uint64_t times[10];
|
|
190
|
+
uint64_t t_compute_distance_tables, t_round;
|
|
191
|
+
uint64_t t_copy_pack, t_scan, t_to_flat;
|
|
192
|
+
uint64_t reservoir_times[4];
|
|
193
|
+
double t_aq_encode;
|
|
194
|
+
double t_aq_norm_encode;
|
|
195
|
+
|
|
196
|
+
double Mcy_at(int i) {
|
|
197
|
+
return times[i] / (1000 * 1000.0);
|
|
198
|
+
}
|
|
199
|
+
|
|
200
|
+
double Mcy_reservoir_at(int i) {
|
|
201
|
+
return reservoir_times[i] / (1000 * 1000.0);
|
|
202
|
+
}
|
|
203
|
+
IVFFastScanStats() {
|
|
204
|
+
reset();
|
|
205
|
+
}
|
|
206
|
+
void reset() {
|
|
207
|
+
memset(this, 0, sizeof(*this));
|
|
208
|
+
}
|
|
209
|
+
};
|
|
210
|
+
|
|
211
|
+
FAISS_API extern IVFFastScanStats IVFFastScan_stats;
|
|
212
|
+
|
|
213
|
+
} // namespace faiss
|
|
@@ -17,6 +17,8 @@
|
|
|
17
17
|
#include <faiss/IndexFlat.h>
|
|
18
18
|
|
|
19
19
|
#include <faiss/impl/AuxIndexStructures.h>
|
|
20
|
+
#include <faiss/impl/IDSelector.h>
|
|
21
|
+
|
|
20
22
|
#include <faiss/impl/FaissAssert.h>
|
|
21
23
|
#include <faiss/utils/distances.h>
|
|
22
24
|
#include <faiss/utils/utils.h>
|
|
@@ -40,9 +42,7 @@ void IndexIVFFlat::add_core(
|
|
|
40
42
|
idx_t n,
|
|
41
43
|
const float* x,
|
|
42
44
|
const int64_t* xids,
|
|
43
|
-
const int64_t* coarse_idx)
|
|
44
|
-
|
|
45
|
-
{
|
|
45
|
+
const int64_t* coarse_idx) {
|
|
46
46
|
FAISS_THROW_IF_NOT(is_trained);
|
|
47
47
|
FAISS_THROW_IF_NOT(coarse_idx);
|
|
48
48
|
assert(invlists);
|
|
@@ -118,13 +118,12 @@ void IndexIVFFlat::sa_decode(idx_t n, const uint8_t* bytes, float* x) const {
|
|
|
118
118
|
|
|
119
119
|
namespace {
|
|
120
120
|
|
|
121
|
-
template <MetricType metric, class C>
|
|
121
|
+
template <MetricType metric, class C, bool use_sel>
|
|
122
122
|
struct IVFFlatScanner : InvertedListScanner {
|
|
123
123
|
size_t d;
|
|
124
124
|
|
|
125
|
-
IVFFlatScanner(size_t d, bool store_pairs
|
|
126
|
-
|
|
127
|
-
}
|
|
125
|
+
IVFFlatScanner(size_t d, bool store_pairs, const IDSelector* sel)
|
|
126
|
+
: InvertedListScanner(store_pairs, sel), d(d) {}
|
|
128
127
|
|
|
129
128
|
const float* xi;
|
|
130
129
|
void set_query(const float* query) override {
|
|
@@ -154,6 +153,9 @@ struct IVFFlatScanner : InvertedListScanner {
|
|
|
154
153
|
size_t nup = 0;
|
|
155
154
|
for (size_t j = 0; j < list_size; j++) {
|
|
156
155
|
const float* yj = list_vecs + d * j;
|
|
156
|
+
if (use_sel && !sel->is_member(ids[j])) {
|
|
157
|
+
continue;
|
|
158
|
+
}
|
|
157
159
|
float dis = metric == METRIC_INNER_PRODUCT
|
|
158
160
|
? fvec_inner_product(xi, yj, d)
|
|
159
161
|
: fvec_L2sqr(xi, yj, d);
|
|
@@ -175,6 +177,9 @@ struct IVFFlatScanner : InvertedListScanner {
|
|
|
175
177
|
const float* list_vecs = (const float*)codes;
|
|
176
178
|
for (size_t j = 0; j < list_size; j++) {
|
|
177
179
|
const float* yj = list_vecs + d * j;
|
|
180
|
+
if (use_sel && !sel->is_member(ids[j])) {
|
|
181
|
+
continue;
|
|
182
|
+
}
|
|
178
183
|
float dis = metric == METRIC_INNER_PRODUCT
|
|
179
184
|
? fvec_inner_product(xi, yj, d)
|
|
180
185
|
: fvec_L2sqr(xi, yj, d);
|
|
@@ -186,20 +191,34 @@ struct IVFFlatScanner : InvertedListScanner {
|
|
|
186
191
|
}
|
|
187
192
|
};
|
|
188
193
|
|
|
194
|
+
template <bool use_sel>
|
|
195
|
+
InvertedListScanner* get_InvertedListScanner1(
|
|
196
|
+
const IndexIVFFlat* ivf,
|
|
197
|
+
bool store_pairs,
|
|
198
|
+
const IDSelector* sel) {
|
|
199
|
+
if (ivf->metric_type == METRIC_INNER_PRODUCT) {
|
|
200
|
+
return new IVFFlatScanner<
|
|
201
|
+
METRIC_INNER_PRODUCT,
|
|
202
|
+
CMin<float, int64_t>,
|
|
203
|
+
use_sel>(ivf->d, store_pairs, sel);
|
|
204
|
+
} else if (ivf->metric_type == METRIC_L2) {
|
|
205
|
+
return new IVFFlatScanner<METRIC_L2, CMax<float, int64_t>, use_sel>(
|
|
206
|
+
ivf->d, store_pairs, sel);
|
|
207
|
+
} else {
|
|
208
|
+
FAISS_THROW_MSG("metric type not supported");
|
|
209
|
+
}
|
|
210
|
+
}
|
|
211
|
+
|
|
189
212
|
} // anonymous namespace
|
|
190
213
|
|
|
191
214
|
InvertedListScanner* IndexIVFFlat::get_InvertedListScanner(
|
|
192
|
-
bool store_pairs
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
} else if (metric_type == METRIC_L2) {
|
|
197
|
-
return new IVFFlatScanner<METRIC_L2, CMax<float, int64_t>>(
|
|
198
|
-
d, store_pairs);
|
|
215
|
+
bool store_pairs,
|
|
216
|
+
const IDSelector* sel) const {
|
|
217
|
+
if (sel) {
|
|
218
|
+
return get_InvertedListScanner1<true>(this, store_pairs, sel);
|
|
199
219
|
} else {
|
|
200
|
-
|
|
220
|
+
return get_InvertedListScanner1<false>(this, store_pairs, sel);
|
|
201
221
|
}
|
|
202
|
-
return nullptr;
|
|
203
222
|
}
|
|
204
223
|
|
|
205
224
|
void IndexIVFFlat::reconstruct_from_offset(
|
|
@@ -447,7 +466,8 @@ void IndexIVFFlatDedup::range_search(
|
|
|
447
466
|
idx_t,
|
|
448
467
|
const float*,
|
|
449
468
|
float,
|
|
450
|
-
RangeSearchResult
|
|
469
|
+
RangeSearchResult*,
|
|
470
|
+
const SearchParameters*) const {
|
|
451
471
|
FAISS_THROW_MSG("not implemented");
|
|
452
472
|
}
|
|
453
473
|
|
|
@@ -42,7 +42,8 @@ struct IndexIVFFlat : IndexIVF {
|
|
|
42
42
|
bool include_listnos = false) const override;
|
|
43
43
|
|
|
44
44
|
InvertedListScanner* get_InvertedListScanner(
|
|
45
|
-
bool store_pairs
|
|
45
|
+
bool store_pairs,
|
|
46
|
+
const IDSelector* sel) const override;
|
|
46
47
|
|
|
47
48
|
void reconstruct_from_offset(int64_t list_no, int64_t offset, float* recons)
|
|
48
49
|
const override;
|
|
@@ -89,7 +90,8 @@ struct IndexIVFFlatDedup : IndexIVFFlat {
|
|
|
89
90
|
idx_t n,
|
|
90
91
|
const float* x,
|
|
91
92
|
float radius,
|
|
92
|
-
RangeSearchResult* result
|
|
93
|
+
RangeSearchResult* result,
|
|
94
|
+
const SearchParameters* params = nullptr) const override;
|
|
93
95
|
|
|
94
96
|
/// not implemented
|
|
95
97
|
void update_vectors(int nv, const idx_t* idx, const float* v) override;
|