faiss 0.3.1 → 0.3.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +4 -0
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/faiss/AutoTune.h +1 -1
- data/vendor/faiss/faiss/Clustering.cpp +35 -4
- data/vendor/faiss/faiss/Clustering.h +10 -1
- data/vendor/faiss/faiss/IVFlib.cpp +4 -1
- data/vendor/faiss/faiss/Index.h +21 -6
- data/vendor/faiss/faiss/IndexBinaryHNSW.h +1 -1
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +1 -1
- data/vendor/faiss/faiss/IndexFastScan.cpp +22 -4
- data/vendor/faiss/faiss/IndexFlat.cpp +11 -7
- data/vendor/faiss/faiss/IndexFlatCodes.cpp +159 -5
- data/vendor/faiss/faiss/IndexFlatCodes.h +20 -3
- data/vendor/faiss/faiss/IndexHNSW.cpp +143 -90
- data/vendor/faiss/faiss/IndexHNSW.h +52 -3
- data/vendor/faiss/faiss/IndexIVF.cpp +3 -3
- data/vendor/faiss/faiss/IndexIVF.h +9 -1
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +15 -0
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +3 -0
- data/vendor/faiss/faiss/IndexIVFFastScan.cpp +130 -57
- data/vendor/faiss/faiss/IndexIVFFastScan.h +14 -7
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +1 -3
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +21 -2
- data/vendor/faiss/faiss/IndexLattice.cpp +1 -19
- data/vendor/faiss/faiss/IndexLattice.h +3 -22
- data/vendor/faiss/faiss/IndexNNDescent.cpp +0 -29
- data/vendor/faiss/faiss/IndexNNDescent.h +1 -1
- data/vendor/faiss/faiss/IndexNSG.h +1 -1
- data/vendor/faiss/faiss/IndexNeuralNetCodec.cpp +56 -0
- data/vendor/faiss/faiss/IndexNeuralNetCodec.h +49 -0
- data/vendor/faiss/faiss/IndexPreTransform.h +1 -1
- data/vendor/faiss/faiss/IndexRefine.cpp +5 -5
- data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +3 -1
- data/vendor/faiss/faiss/MetricType.h +7 -2
- data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +95 -17
- data/vendor/faiss/faiss/cppcontrib/factory_tools.cpp +152 -0
- data/vendor/faiss/faiss/cppcontrib/factory_tools.h +24 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +83 -30
- data/vendor/faiss/faiss/gpu/GpuCloner.cpp +36 -4
- data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +6 -0
- data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -1
- data/vendor/faiss/faiss/gpu/GpuIndex.h +2 -8
- data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +282 -0
- data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +6 -0
- data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +2 -0
- data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +25 -0
- data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +26 -21
- data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +6 -0
- data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +8 -5
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +65 -0
- data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +1 -1
- data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +6 -0
- data/vendor/faiss/faiss/gpu/utils/Timer.cpp +4 -1
- data/vendor/faiss/faiss/gpu/utils/Timer.h +1 -1
- data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +25 -0
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +9 -1
- data/vendor/faiss/faiss/impl/DistanceComputer.h +46 -0
- data/vendor/faiss/faiss/impl/FaissAssert.h +4 -2
- data/vendor/faiss/faiss/impl/HNSW.cpp +358 -190
- data/vendor/faiss/faiss/impl/HNSW.h +43 -22
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +8 -8
- data/vendor/faiss/faiss/impl/LookupTableScaler.h +34 -0
- data/vendor/faiss/faiss/impl/NNDescent.cpp +13 -8
- data/vendor/faiss/faiss/impl/NSG.cpp +0 -29
- data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +1 -0
- data/vendor/faiss/faiss/impl/ProductQuantizer.h +5 -1
- data/vendor/faiss/faiss/impl/ResultHandler.h +151 -32
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +719 -102
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +3 -0
- data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +5 -0
- data/vendor/faiss/faiss/impl/code_distance/code_distance-avx512.h +248 -0
- data/vendor/faiss/faiss/impl/index_read.cpp +29 -15
- data/vendor/faiss/faiss/impl/index_read_utils.h +37 -0
- data/vendor/faiss/faiss/impl/index_write.cpp +28 -10
- data/vendor/faiss/faiss/impl/io.cpp +13 -5
- data/vendor/faiss/faiss/impl/io.h +4 -4
- data/vendor/faiss/faiss/impl/io_macros.h +6 -0
- data/vendor/faiss/faiss/impl/platform_macros.h +22 -0
- data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +11 -0
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +1 -1
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +448 -1
- data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +5 -5
- data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +1 -1
- data/vendor/faiss/faiss/impl/simd_result_handlers.h +143 -59
- data/vendor/faiss/faiss/index_factory.cpp +31 -13
- data/vendor/faiss/faiss/index_io.h +12 -5
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +28 -8
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +3 -0
- data/vendor/faiss/faiss/invlists/DirectMap.cpp +9 -1
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +55 -17
- data/vendor/faiss/faiss/invlists/InvertedLists.h +18 -9
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +21 -6
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +2 -1
- data/vendor/faiss/faiss/python/python_callbacks.cpp +3 -3
- data/vendor/faiss/faiss/utils/Heap.h +105 -0
- data/vendor/faiss/faiss/utils/NeuralNet.cpp +342 -0
- data/vendor/faiss/faiss/utils/NeuralNet.h +147 -0
- data/vendor/faiss/faiss/utils/bf16.h +36 -0
- data/vendor/faiss/faiss/utils/distances.cpp +58 -88
- data/vendor/faiss/faiss/utils/distances.h +5 -5
- data/vendor/faiss/faiss/utils/distances_simd.cpp +997 -9
- data/vendor/faiss/faiss/utils/extra_distances-inl.h +70 -0
- data/vendor/faiss/faiss/utils/extra_distances.cpp +85 -137
- data/vendor/faiss/faiss/utils/extra_distances.h +3 -2
- data/vendor/faiss/faiss/utils/hamming.cpp +1 -1
- data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +4 -1
- data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +2 -1
- data/vendor/faiss/faiss/utils/random.cpp +43 -0
- data/vendor/faiss/faiss/utils/random.h +25 -0
- data/vendor/faiss/faiss/utils/simdlib.h +10 -1
- data/vendor/faiss/faiss/utils/simdlib_avx512.h +296 -0
- data/vendor/faiss/faiss/utils/simdlib_neon.h +5 -2
- data/vendor/faiss/faiss/utils/simdlib_ppc64.h +1084 -0
- data/vendor/faiss/faiss/utils/transpose/transpose-avx512-inl.h +176 -0
- data/vendor/faiss/faiss/utils/utils.cpp +10 -3
- data/vendor/faiss/faiss/utils/utils.h +3 -0
- metadata +16 -4
- data/vendor/faiss/faiss/impl/code_distance/code_distance_avx512.h +0 -102
|
@@ -46,6 +46,7 @@ struct ResultHandler;
|
|
|
46
46
|
struct SearchParametersHNSW : SearchParameters {
|
|
47
47
|
int efSearch = 16;
|
|
48
48
|
bool check_relative_distance = true;
|
|
49
|
+
bool bounded_queue = true;
|
|
49
50
|
|
|
50
51
|
~SearchParametersHNSW() {}
|
|
51
52
|
};
|
|
@@ -141,9 +142,6 @@ struct HNSW {
|
|
|
141
142
|
/// enough?
|
|
142
143
|
bool check_relative_distance = true;
|
|
143
144
|
|
|
144
|
-
/// number of entry points in levels > 0.
|
|
145
|
-
int upper_beam = 1;
|
|
146
|
-
|
|
147
145
|
/// use bounded queue during exploration
|
|
148
146
|
bool search_bounded_queue = true;
|
|
149
147
|
|
|
@@ -184,7 +182,8 @@ struct HNSW {
|
|
|
184
182
|
float d_nearest,
|
|
185
183
|
int level,
|
|
186
184
|
omp_lock_t* locks,
|
|
187
|
-
VisitedTable& vt
|
|
185
|
+
VisitedTable& vt,
|
|
186
|
+
bool keep_max_size_level0 = false);
|
|
188
187
|
|
|
189
188
|
/** add point pt_id on all levels <= pt_level and build the link
|
|
190
189
|
* structure for them. */
|
|
@@ -193,7 +192,8 @@ struct HNSW {
|
|
|
193
192
|
int pt_level,
|
|
194
193
|
int pt_id,
|
|
195
194
|
std::vector<omp_lock_t>& locks,
|
|
196
|
-
VisitedTable& vt
|
|
195
|
+
VisitedTable& vt,
|
|
196
|
+
bool keep_max_size_level0 = false);
|
|
197
197
|
|
|
198
198
|
/// search interface for 1 point, single thread
|
|
199
199
|
HNSWStats search(
|
|
@@ -211,7 +211,8 @@ struct HNSW {
|
|
|
211
211
|
const float* nearest_d,
|
|
212
212
|
int search_type,
|
|
213
213
|
HNSWStats& search_stats,
|
|
214
|
-
VisitedTable& vt
|
|
214
|
+
VisitedTable& vt,
|
|
215
|
+
const SearchParametersHNSW* params = nullptr) const;
|
|
215
216
|
|
|
216
217
|
void reset();
|
|
217
218
|
|
|
@@ -224,40 +225,60 @@ struct HNSW {
|
|
|
224
225
|
DistanceComputer& qdis,
|
|
225
226
|
std::priority_queue<NodeDistFarther>& input,
|
|
226
227
|
std::vector<NodeDistFarther>& output,
|
|
227
|
-
int max_size
|
|
228
|
+
int max_size,
|
|
229
|
+
bool keep_max_size_level0 = false);
|
|
228
230
|
|
|
229
231
|
void permute_entries(const idx_t* map);
|
|
230
232
|
};
|
|
231
233
|
|
|
232
234
|
struct HNSWStats {
|
|
233
|
-
size_t n1
|
|
234
|
-
size_t
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
size_t n1 = 0,
|
|
239
|
-
size_t n2 = 0,
|
|
240
|
-
size_t n3 = 0,
|
|
241
|
-
size_t ndis = 0,
|
|
242
|
-
size_t nreorder = 0)
|
|
243
|
-
: n1(n1), n2(n2), n3(n3), ndis(ndis), nreorder(nreorder) {}
|
|
235
|
+
size_t n1 = 0; /// number of vectors searched
|
|
236
|
+
size_t n2 =
|
|
237
|
+
0; /// number of queries for which the candidate list is exhausted
|
|
238
|
+
size_t ndis = 0; /// number of distances computed
|
|
239
|
+
size_t nhops = 0; /// number of hops aka number of edges traversed
|
|
244
240
|
|
|
245
241
|
void reset() {
|
|
246
|
-
n1 = n2 =
|
|
242
|
+
n1 = n2 = 0;
|
|
247
243
|
ndis = 0;
|
|
248
|
-
|
|
244
|
+
nhops = 0;
|
|
249
245
|
}
|
|
250
246
|
|
|
251
247
|
void combine(const HNSWStats& other) {
|
|
252
248
|
n1 += other.n1;
|
|
253
249
|
n2 += other.n2;
|
|
254
|
-
n3 += other.n3;
|
|
255
250
|
ndis += other.ndis;
|
|
256
|
-
|
|
251
|
+
nhops += other.nhops;
|
|
257
252
|
}
|
|
258
253
|
};
|
|
259
254
|
|
|
260
255
|
// global var that collects them all
|
|
261
256
|
FAISS_API extern HNSWStats hnsw_stats;
|
|
262
257
|
|
|
258
|
+
int search_from_candidates(
|
|
259
|
+
const HNSW& hnsw,
|
|
260
|
+
DistanceComputer& qdis,
|
|
261
|
+
ResultHandler<HNSW::C>& res,
|
|
262
|
+
HNSW::MinimaxHeap& candidates,
|
|
263
|
+
VisitedTable& vt,
|
|
264
|
+
HNSWStats& stats,
|
|
265
|
+
int level,
|
|
266
|
+
int nres_in = 0,
|
|
267
|
+
const SearchParametersHNSW* params = nullptr);
|
|
268
|
+
|
|
269
|
+
HNSWStats greedy_update_nearest(
|
|
270
|
+
const HNSW& hnsw,
|
|
271
|
+
DistanceComputer& qdis,
|
|
272
|
+
int level,
|
|
273
|
+
HNSW::storage_idx_t& nearest,
|
|
274
|
+
float& d_nearest);
|
|
275
|
+
|
|
276
|
+
std::priority_queue<HNSW::Node> search_from_candidate_unbounded(
|
|
277
|
+
const HNSW& hnsw,
|
|
278
|
+
const HNSW::Node& node,
|
|
279
|
+
DistanceComputer& qdis,
|
|
280
|
+
int ef,
|
|
281
|
+
VisitedTable* vt,
|
|
282
|
+
HNSWStats& stats);
|
|
283
|
+
|
|
263
284
|
} // namespace faiss
|
|
@@ -104,10 +104,10 @@ int dgemm_(
|
|
|
104
104
|
|
|
105
105
|
namespace {
|
|
106
106
|
|
|
107
|
-
void fmat_inverse(float* a,
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
std::vector<
|
|
107
|
+
void fmat_inverse(float* a, FINTEGER n) {
|
|
108
|
+
FINTEGER info;
|
|
109
|
+
FINTEGER lwork = n * n;
|
|
110
|
+
std::vector<FINTEGER> ipiv(n);
|
|
111
111
|
std::vector<float> workspace(lwork);
|
|
112
112
|
|
|
113
113
|
sgetrf_(&n, &n, a, &n, ipiv.data(), &info);
|
|
@@ -123,10 +123,10 @@ void dfvec_add(size_t d, const double* a, const float* b, double* c) {
|
|
|
123
123
|
}
|
|
124
124
|
}
|
|
125
125
|
|
|
126
|
-
void dmat_inverse(double* a,
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
std::vector<
|
|
126
|
+
void dmat_inverse(double* a, FINTEGER n) {
|
|
127
|
+
FINTEGER info;
|
|
128
|
+
FINTEGER lwork = n * n;
|
|
129
|
+
std::vector<FINTEGER> ipiv(n);
|
|
130
130
|
std::vector<double> workspace(lwork);
|
|
131
131
|
|
|
132
132
|
dgetrf_(&n, &n, a, &n, ipiv.data(), &info);
|
|
@@ -38,6 +38,23 @@ struct DummyScaler {
|
|
|
38
38
|
return simd16uint16(0);
|
|
39
39
|
}
|
|
40
40
|
|
|
41
|
+
#ifdef __AVX512F__
|
|
42
|
+
inline simd64uint8 lookup(const simd64uint8&, const simd64uint8&) const {
|
|
43
|
+
FAISS_THROW_MSG("DummyScaler::lookup should not be called.");
|
|
44
|
+
return simd64uint8(0);
|
|
45
|
+
}
|
|
46
|
+
|
|
47
|
+
inline simd32uint16 scale_lo(const simd64uint8&) const {
|
|
48
|
+
FAISS_THROW_MSG("DummyScaler::scale_lo should not be called.");
|
|
49
|
+
return simd32uint16(0);
|
|
50
|
+
}
|
|
51
|
+
|
|
52
|
+
inline simd32uint16 scale_hi(const simd64uint8&) const {
|
|
53
|
+
FAISS_THROW_MSG("DummyScaler::scale_hi should not be called.");
|
|
54
|
+
return simd32uint16(0);
|
|
55
|
+
}
|
|
56
|
+
#endif
|
|
57
|
+
|
|
41
58
|
template <class dist_t>
|
|
42
59
|
inline dist_t scale_one(const dist_t&) const {
|
|
43
60
|
FAISS_THROW_MSG("DummyScaler::scale_one should not be called.");
|
|
@@ -67,6 +84,23 @@ struct NormTableScaler {
|
|
|
67
84
|
return (simd16uint16(res) >> 8) * scale_simd;
|
|
68
85
|
}
|
|
69
86
|
|
|
87
|
+
#ifdef __AVX512F__
|
|
88
|
+
inline simd64uint8 lookup(const simd64uint8& lut, const simd64uint8& c)
|
|
89
|
+
const {
|
|
90
|
+
return lut.lookup_4_lanes(c);
|
|
91
|
+
}
|
|
92
|
+
|
|
93
|
+
inline simd32uint16 scale_lo(const simd64uint8& res) const {
|
|
94
|
+
auto scale_simd_wide = simd32uint16(scale_simd, scale_simd);
|
|
95
|
+
return simd32uint16(res) * scale_simd_wide;
|
|
96
|
+
}
|
|
97
|
+
|
|
98
|
+
inline simd32uint16 scale_hi(const simd64uint8& res) const {
|
|
99
|
+
auto scale_simd_wide = simd32uint16(scale_simd, scale_simd);
|
|
100
|
+
return (simd32uint16(res) >> 8) * scale_simd_wide;
|
|
101
|
+
}
|
|
102
|
+
#endif
|
|
103
|
+
|
|
70
104
|
// for non-SIMD implem 2, 3, 4
|
|
71
105
|
template <class dist_t>
|
|
72
106
|
inline dist_t scale_one(const dist_t& x) const {
|
|
@@ -154,15 +154,20 @@ NNDescent::NNDescent(const int d, const int K) : K(K), d(d) {
|
|
|
154
154
|
NNDescent::~NNDescent() {}
|
|
155
155
|
|
|
156
156
|
void NNDescent::join(DistanceComputer& qdis) {
|
|
157
|
+
idx_t check_period = InterruptCallback::get_period_hint(d * search_L);
|
|
158
|
+
for (idx_t i0 = 0; i0 < (idx_t)ntotal; i0 += check_period) {
|
|
159
|
+
idx_t i1 = std::min(i0 + check_period, (idx_t)ntotal);
|
|
157
160
|
#pragma omp parallel for default(shared) schedule(dynamic, 100)
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
161
|
+
for (idx_t n = i0; n < i1; n++) {
|
|
162
|
+
graph[n].join([&](int i, int j) {
|
|
163
|
+
if (i != j) {
|
|
164
|
+
float dist = qdis.symmetric_dis(i, j);
|
|
165
|
+
graph[i].insert(j, dist);
|
|
166
|
+
graph[j].insert(i, dist);
|
|
167
|
+
}
|
|
168
|
+
});
|
|
169
|
+
}
|
|
170
|
+
InterruptCallback::check();
|
|
166
171
|
}
|
|
167
172
|
}
|
|
168
173
|
|
|
@@ -25,35 +25,6 @@ namespace {
|
|
|
25
25
|
// It needs to be smaller than 0
|
|
26
26
|
constexpr int EMPTY_ID = -1;
|
|
27
27
|
|
|
28
|
-
/* Wrap the distance computer into one that negates the
|
|
29
|
-
distances. This makes supporting INNER_PRODUCE search easier */
|
|
30
|
-
|
|
31
|
-
struct NegativeDistanceComputer : DistanceComputer {
|
|
32
|
-
/// owned by this
|
|
33
|
-
DistanceComputer* basedis;
|
|
34
|
-
|
|
35
|
-
explicit NegativeDistanceComputer(DistanceComputer* basedis)
|
|
36
|
-
: basedis(basedis) {}
|
|
37
|
-
|
|
38
|
-
void set_query(const float* x) override {
|
|
39
|
-
basedis->set_query(x);
|
|
40
|
-
}
|
|
41
|
-
|
|
42
|
-
/// compute distance of vector i to current query
|
|
43
|
-
float operator()(idx_t i) override {
|
|
44
|
-
return -(*basedis)(i);
|
|
45
|
-
}
|
|
46
|
-
|
|
47
|
-
/// compute distance between two stored vectors
|
|
48
|
-
float symmetric_dis(idx_t i, idx_t j) override {
|
|
49
|
-
return -basedis->symmetric_dis(i, j);
|
|
50
|
-
}
|
|
51
|
-
|
|
52
|
-
~NegativeDistanceComputer() override {
|
|
53
|
-
delete basedis;
|
|
54
|
-
}
|
|
55
|
-
};
|
|
56
|
-
|
|
57
28
|
} // namespace
|
|
58
29
|
|
|
59
30
|
DistanceComputer* storage_distance_computer(const Index* storage) {
|
|
@@ -61,6 +61,7 @@ void ProductQuantizer::set_derived_values() {
|
|
|
61
61
|
"The dimension of the vector (d) should be a multiple of the number of subquantizers (M)");
|
|
62
62
|
dsub = d / M;
|
|
63
63
|
code_size = (nbits * M + 7) / 8;
|
|
64
|
+
FAISS_THROW_IF_MSG(nbits > 24, "nbits larger than 24 is not practical.");
|
|
64
65
|
ksub = 1 << nbits;
|
|
65
66
|
centroids.resize(d * ksub);
|
|
66
67
|
verbose = false;
|
|
@@ -21,7 +21,11 @@
|
|
|
21
21
|
|
|
22
22
|
namespace faiss {
|
|
23
23
|
|
|
24
|
-
/** Product Quantizer.
|
|
24
|
+
/** Product Quantizer.
|
|
25
|
+
* PQ is trained using k-means, minimizing the L2 distance to centroids.
|
|
26
|
+
* PQ supports L2 and Inner Product search, however the quantization error is
|
|
27
|
+
* biased towards L2 distance.
|
|
28
|
+
*/
|
|
25
29
|
struct ProductQuantizer : Quantizer {
|
|
26
30
|
size_t M; ///< number of subquantizers
|
|
27
31
|
size_t nbits; ///< number of bits per quantization index
|
|
@@ -12,9 +12,14 @@
|
|
|
12
12
|
#pragma once
|
|
13
13
|
|
|
14
14
|
#include <faiss/impl/AuxIndexStructures.h>
|
|
15
|
+
#include <faiss/impl/FaissException.h>
|
|
16
|
+
#include <faiss/impl/IDSelector.h>
|
|
15
17
|
#include <faiss/utils/Heap.h>
|
|
16
18
|
#include <faiss/utils/partitioning.h>
|
|
17
19
|
|
|
20
|
+
#include <algorithm>
|
|
21
|
+
#include <iostream>
|
|
22
|
+
|
|
18
23
|
namespace faiss {
|
|
19
24
|
|
|
20
25
|
/*****************************************************************
|
|
@@ -24,16 +29,21 @@ namespace faiss {
|
|
|
24
29
|
* - by instanciating a SingleResultHandler that tracks results for a single
|
|
25
30
|
* query
|
|
26
31
|
* - with begin_multiple/add_results/end_multiple calls where a whole block of
|
|
27
|
-
*
|
|
32
|
+
* results is submitted
|
|
28
33
|
* All classes are templated on C which to define wheter the min or the max of
|
|
29
|
-
* results is to be kept
|
|
34
|
+
* results is to be kept, and on sel, so that the codepaths for with / without
|
|
35
|
+
* selector can be separated at compile time.
|
|
30
36
|
*****************************************************************/
|
|
31
37
|
|
|
32
|
-
template <class C>
|
|
38
|
+
template <class C, bool use_sel = false>
|
|
33
39
|
struct BlockResultHandler {
|
|
34
40
|
size_t nq; // number of queries for which we search
|
|
41
|
+
const IDSelector* sel;
|
|
35
42
|
|
|
36
|
-
explicit BlockResultHandler(size_t nq
|
|
43
|
+
explicit BlockResultHandler(size_t nq, const IDSelector* sel = nullptr)
|
|
44
|
+
: nq(nq), sel(sel) {
|
|
45
|
+
assert(!use_sel || sel);
|
|
46
|
+
}
|
|
37
47
|
|
|
38
48
|
// currently handled query range
|
|
39
49
|
size_t i0 = 0, i1 = 0;
|
|
@@ -51,13 +61,17 @@ struct BlockResultHandler {
|
|
|
51
61
|
virtual void end_multiple() {}
|
|
52
62
|
|
|
53
63
|
virtual ~BlockResultHandler() {}
|
|
64
|
+
|
|
65
|
+
bool is_in_selection(idx_t i) const {
|
|
66
|
+
return !use_sel || sel->is_member(i);
|
|
67
|
+
}
|
|
54
68
|
};
|
|
55
69
|
|
|
56
70
|
// handler for a single query
|
|
57
71
|
template <class C>
|
|
58
72
|
struct ResultHandler {
|
|
59
73
|
// if not better than threshold, then not necessary to call add_result
|
|
60
|
-
typename C::T threshold =
|
|
74
|
+
typename C::T threshold = C::neutral();
|
|
61
75
|
|
|
62
76
|
// return whether threshold was updated
|
|
63
77
|
virtual bool add_result(typename C::T dis, typename C::TI idx) = 0;
|
|
@@ -71,20 +85,26 @@ struct ResultHandler {
|
|
|
71
85
|
* some temporary data in memory.
|
|
72
86
|
*****************************************************************/
|
|
73
87
|
|
|
74
|
-
template <class C>
|
|
75
|
-
struct Top1BlockResultHandler : BlockResultHandler<C> {
|
|
88
|
+
template <class C, bool use_sel = false>
|
|
89
|
+
struct Top1BlockResultHandler : BlockResultHandler<C, use_sel> {
|
|
76
90
|
using T = typename C::T;
|
|
77
91
|
using TI = typename C::TI;
|
|
78
|
-
using BlockResultHandler<C>::i0;
|
|
79
|
-
using BlockResultHandler<C>::i1;
|
|
92
|
+
using BlockResultHandler<C, use_sel>::i0;
|
|
93
|
+
using BlockResultHandler<C, use_sel>::i1;
|
|
80
94
|
|
|
81
95
|
// contains exactly nq elements
|
|
82
96
|
T* dis_tab;
|
|
83
97
|
// contains exactly nq elements
|
|
84
98
|
TI* ids_tab;
|
|
85
99
|
|
|
86
|
-
Top1BlockResultHandler(
|
|
87
|
-
|
|
100
|
+
Top1BlockResultHandler(
|
|
101
|
+
size_t nq,
|
|
102
|
+
T* dis_tab,
|
|
103
|
+
TI* ids_tab,
|
|
104
|
+
const IDSelector* sel = nullptr)
|
|
105
|
+
: BlockResultHandler<C, use_sel>(nq, sel),
|
|
106
|
+
dis_tab(dis_tab),
|
|
107
|
+
ids_tab(ids_tab) {}
|
|
88
108
|
|
|
89
109
|
struct SingleResultHandler : ResultHandler<C> {
|
|
90
110
|
Top1BlockResultHandler& hr;
|
|
@@ -163,12 +183,12 @@ struct Top1BlockResultHandler : BlockResultHandler<C> {
|
|
|
163
183
|
* Heap based result handler
|
|
164
184
|
*****************************************************************/
|
|
165
185
|
|
|
166
|
-
template <class C>
|
|
167
|
-
struct HeapBlockResultHandler : BlockResultHandler<C> {
|
|
186
|
+
template <class C, bool use_sel = false>
|
|
187
|
+
struct HeapBlockResultHandler : BlockResultHandler<C, use_sel> {
|
|
168
188
|
using T = typename C::T;
|
|
169
189
|
using TI = typename C::TI;
|
|
170
|
-
using BlockResultHandler<C>::i0;
|
|
171
|
-
using BlockResultHandler<C>::i1;
|
|
190
|
+
using BlockResultHandler<C, use_sel>::i0;
|
|
191
|
+
using BlockResultHandler<C, use_sel>::i1;
|
|
172
192
|
|
|
173
193
|
T* heap_dis_tab;
|
|
174
194
|
TI* heap_ids_tab;
|
|
@@ -179,8 +199,9 @@ struct HeapBlockResultHandler : BlockResultHandler<C> {
|
|
|
179
199
|
size_t nq,
|
|
180
200
|
T* heap_dis_tab,
|
|
181
201
|
TI* heap_ids_tab,
|
|
182
|
-
size_t k
|
|
183
|
-
|
|
202
|
+
size_t k,
|
|
203
|
+
const IDSelector* sel = nullptr)
|
|
204
|
+
: BlockResultHandler<C, use_sel>(nq, sel),
|
|
184
205
|
heap_dis_tab(heap_dis_tab),
|
|
185
206
|
heap_ids_tab(heap_ids_tab),
|
|
186
207
|
k(k) {}
|
|
@@ -345,12 +366,12 @@ struct ReservoirTopN : ResultHandler<C> {
|
|
|
345
366
|
}
|
|
346
367
|
};
|
|
347
368
|
|
|
348
|
-
template <class C>
|
|
349
|
-
struct ReservoirBlockResultHandler : BlockResultHandler<C> {
|
|
369
|
+
template <class C, bool use_sel = false>
|
|
370
|
+
struct ReservoirBlockResultHandler : BlockResultHandler<C, use_sel> {
|
|
350
371
|
using T = typename C::T;
|
|
351
372
|
using TI = typename C::TI;
|
|
352
|
-
using BlockResultHandler<C>::i0;
|
|
353
|
-
using BlockResultHandler<C>::i1;
|
|
373
|
+
using BlockResultHandler<C, use_sel>::i0;
|
|
374
|
+
using BlockResultHandler<C, use_sel>::i1;
|
|
354
375
|
|
|
355
376
|
T* heap_dis_tab;
|
|
356
377
|
TI* heap_ids_tab;
|
|
@@ -362,8 +383,9 @@ struct ReservoirBlockResultHandler : BlockResultHandler<C> {
|
|
|
362
383
|
size_t nq,
|
|
363
384
|
T* heap_dis_tab,
|
|
364
385
|
TI* heap_ids_tab,
|
|
365
|
-
size_t k
|
|
366
|
-
|
|
386
|
+
size_t k,
|
|
387
|
+
const IDSelector* sel = nullptr)
|
|
388
|
+
: BlockResultHandler<C, use_sel>(nq, sel),
|
|
367
389
|
heap_dis_tab(heap_dis_tab),
|
|
368
390
|
heap_ids_tab(heap_ids_tab),
|
|
369
391
|
k(k) {
|
|
@@ -458,18 +480,23 @@ struct ReservoirBlockResultHandler : BlockResultHandler<C> {
|
|
|
458
480
|
* Result handler for range searches
|
|
459
481
|
*****************************************************************/
|
|
460
482
|
|
|
461
|
-
template <class C>
|
|
462
|
-
struct RangeSearchBlockResultHandler : BlockResultHandler<C> {
|
|
483
|
+
template <class C, bool use_sel = false>
|
|
484
|
+
struct RangeSearchBlockResultHandler : BlockResultHandler<C, use_sel> {
|
|
463
485
|
using T = typename C::T;
|
|
464
486
|
using TI = typename C::TI;
|
|
465
|
-
using BlockResultHandler<C>::i0;
|
|
466
|
-
using BlockResultHandler<C>::i1;
|
|
487
|
+
using BlockResultHandler<C, use_sel>::i0;
|
|
488
|
+
using BlockResultHandler<C, use_sel>::i1;
|
|
467
489
|
|
|
468
490
|
RangeSearchResult* res;
|
|
469
491
|
T radius;
|
|
470
492
|
|
|
471
|
-
RangeSearchBlockResultHandler(
|
|
472
|
-
|
|
493
|
+
RangeSearchBlockResultHandler(
|
|
494
|
+
RangeSearchResult* res,
|
|
495
|
+
float radius,
|
|
496
|
+
const IDSelector* sel = nullptr)
|
|
497
|
+
: BlockResultHandler<C, use_sel>(res->nq, sel),
|
|
498
|
+
res(res),
|
|
499
|
+
radius(radius) {}
|
|
473
500
|
|
|
474
501
|
/******************************************************
|
|
475
502
|
* API for 1 result at a time (each SingleResultHandler is
|
|
@@ -504,7 +531,15 @@ struct RangeSearchBlockResultHandler : BlockResultHandler<C> {
|
|
|
504
531
|
void end() {}
|
|
505
532
|
|
|
506
533
|
~SingleResultHandler() {
|
|
507
|
-
|
|
534
|
+
try {
|
|
535
|
+
// finalize the partial result
|
|
536
|
+
pres.finalize();
|
|
537
|
+
} catch (const faiss::FaissException& e) {
|
|
538
|
+
// Do nothing if allocation fails in finalizing partial results.
|
|
539
|
+
#ifndef NDEBUG
|
|
540
|
+
std::cerr << e.what() << std::endl;
|
|
541
|
+
#endif
|
|
542
|
+
}
|
|
508
543
|
}
|
|
509
544
|
};
|
|
510
545
|
|
|
@@ -559,10 +594,94 @@ struct RangeSearchBlockResultHandler : BlockResultHandler<C> {
|
|
|
559
594
|
}
|
|
560
595
|
|
|
561
596
|
~RangeSearchBlockResultHandler() {
|
|
562
|
-
|
|
563
|
-
|
|
597
|
+
try {
|
|
598
|
+
if (partial_results.size() > 0) {
|
|
599
|
+
RangeSearchPartialResult::merge(partial_results);
|
|
600
|
+
}
|
|
601
|
+
} catch (const faiss::FaissException& e) {
|
|
602
|
+
// Do nothing if allocation fails in merge.
|
|
603
|
+
#ifndef NDEBUG
|
|
604
|
+
std::cerr << e.what() << std::endl;
|
|
605
|
+
#endif
|
|
564
606
|
}
|
|
565
607
|
}
|
|
566
608
|
};
|
|
567
609
|
|
|
610
|
+
/*****************************************************************
|
|
611
|
+
* Dispatcher function to choose the right knn result handler depending on k
|
|
612
|
+
*****************************************************************/
|
|
613
|
+
|
|
614
|
+
// declared in distances.cpp
|
|
615
|
+
FAISS_API extern int distance_compute_min_k_reservoir;
|
|
616
|
+
|
|
617
|
+
template <class Consumer, class... Types>
|
|
618
|
+
typename Consumer::T dispatch_knn_ResultHandler(
|
|
619
|
+
size_t nx,
|
|
620
|
+
float* vals,
|
|
621
|
+
int64_t* ids,
|
|
622
|
+
size_t k,
|
|
623
|
+
MetricType metric,
|
|
624
|
+
const IDSelector* sel,
|
|
625
|
+
Consumer& consumer,
|
|
626
|
+
Types... args) {
|
|
627
|
+
#define DISPATCH_C_SEL(C, use_sel) \
|
|
628
|
+
if (k == 1) { \
|
|
629
|
+
Top1BlockResultHandler<C, use_sel> res(nx, vals, ids, sel); \
|
|
630
|
+
return consumer.template f<>(res, args...); \
|
|
631
|
+
} else if (k < distance_compute_min_k_reservoir) { \
|
|
632
|
+
HeapBlockResultHandler<C, use_sel> res(nx, vals, ids, k, sel); \
|
|
633
|
+
return consumer.template f<>(res, args...); \
|
|
634
|
+
} else { \
|
|
635
|
+
ReservoirBlockResultHandler<C, use_sel> res(nx, vals, ids, k, sel); \
|
|
636
|
+
return consumer.template f<>(res, args...); \
|
|
637
|
+
}
|
|
638
|
+
|
|
639
|
+
if (is_similarity_metric(metric)) {
|
|
640
|
+
using C = CMin<float, int64_t>;
|
|
641
|
+
if (sel) {
|
|
642
|
+
DISPATCH_C_SEL(C, true);
|
|
643
|
+
} else {
|
|
644
|
+
DISPATCH_C_SEL(C, false);
|
|
645
|
+
}
|
|
646
|
+
} else {
|
|
647
|
+
using C = CMax<float, int64_t>;
|
|
648
|
+
if (sel) {
|
|
649
|
+
DISPATCH_C_SEL(C, true);
|
|
650
|
+
} else {
|
|
651
|
+
DISPATCH_C_SEL(C, false);
|
|
652
|
+
}
|
|
653
|
+
}
|
|
654
|
+
#undef DISPATCH_C_SEL
|
|
655
|
+
}
|
|
656
|
+
|
|
657
|
+
template <class Consumer, class... Types>
|
|
658
|
+
typename Consumer::T dispatch_range_ResultHandler(
|
|
659
|
+
RangeSearchResult* res,
|
|
660
|
+
float radius,
|
|
661
|
+
MetricType metric,
|
|
662
|
+
const IDSelector* sel,
|
|
663
|
+
Consumer& consumer,
|
|
664
|
+
Types... args) {
|
|
665
|
+
#define DISPATCH_C_SEL(C, use_sel) \
|
|
666
|
+
RangeSearchBlockResultHandler<C, use_sel> resb(res, radius, sel); \
|
|
667
|
+
return consumer.template f<>(resb, args...);
|
|
668
|
+
|
|
669
|
+
if (is_similarity_metric(metric)) {
|
|
670
|
+
using C = CMin<float, int64_t>;
|
|
671
|
+
if (sel) {
|
|
672
|
+
DISPATCH_C_SEL(C, true);
|
|
673
|
+
} else {
|
|
674
|
+
DISPATCH_C_SEL(C, false);
|
|
675
|
+
}
|
|
676
|
+
} else {
|
|
677
|
+
using C = CMax<float, int64_t>;
|
|
678
|
+
if (sel) {
|
|
679
|
+
DISPATCH_C_SEL(C, true);
|
|
680
|
+
} else {
|
|
681
|
+
DISPATCH_C_SEL(C, false);
|
|
682
|
+
}
|
|
683
|
+
}
|
|
684
|
+
#undef DISPATCH_C_SEL
|
|
685
|
+
}
|
|
686
|
+
|
|
568
687
|
} // namespace faiss
|