faiss 0.2.4 → 0.2.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +5 -0
- data/README.md +23 -21
- data/ext/faiss/extconf.rb +11 -0
- data/ext/faiss/index.cpp +4 -4
- data/ext/faiss/index_binary.cpp +6 -6
- data/ext/faiss/product_quantizer.cpp +4 -4
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/faiss/AutoTune.cpp +13 -0
- data/vendor/faiss/faiss/IVFlib.cpp +101 -2
- data/vendor/faiss/faiss/IVFlib.h +26 -2
- data/vendor/faiss/faiss/Index.cpp +36 -3
- data/vendor/faiss/faiss/Index.h +43 -6
- data/vendor/faiss/faiss/Index2Layer.cpp +6 -2
- data/vendor/faiss/faiss/Index2Layer.h +6 -1
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +219 -16
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +63 -5
- data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +299 -0
- data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +199 -0
- data/vendor/faiss/faiss/IndexBinary.cpp +20 -4
- data/vendor/faiss/faiss/IndexBinary.h +18 -3
- data/vendor/faiss/faiss/IndexBinaryFlat.cpp +9 -2
- data/vendor/faiss/faiss/IndexBinaryFlat.h +4 -2
- data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +4 -1
- data/vendor/faiss/faiss/IndexBinaryFromFloat.h +2 -1
- data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +5 -1
- data/vendor/faiss/faiss/IndexBinaryHNSW.h +2 -1
- data/vendor/faiss/faiss/IndexBinaryHash.cpp +17 -4
- data/vendor/faiss/faiss/IndexBinaryHash.h +8 -4
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +28 -13
- data/vendor/faiss/faiss/IndexBinaryIVF.h +10 -7
- data/vendor/faiss/faiss/IndexFastScan.cpp +626 -0
- data/vendor/faiss/faiss/IndexFastScan.h +145 -0
- data/vendor/faiss/faiss/IndexFlat.cpp +34 -21
- data/vendor/faiss/faiss/IndexFlat.h +7 -4
- data/vendor/faiss/faiss/IndexFlatCodes.cpp +35 -1
- data/vendor/faiss/faiss/IndexFlatCodes.h +12 -0
- data/vendor/faiss/faiss/IndexHNSW.cpp +66 -138
- data/vendor/faiss/faiss/IndexHNSW.h +4 -2
- data/vendor/faiss/faiss/IndexIDMap.cpp +247 -0
- data/vendor/faiss/faiss/IndexIDMap.h +107 -0
- data/vendor/faiss/faiss/IndexIVF.cpp +121 -33
- data/vendor/faiss/faiss/IndexIVF.h +35 -16
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +84 -7
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +63 -1
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +590 -0
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +171 -0
- data/vendor/faiss/faiss/IndexIVFFastScan.cpp +1290 -0
- data/vendor/faiss/faiss/IndexIVFFastScan.h +213 -0
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +37 -17
- data/vendor/faiss/faiss/IndexIVFFlat.h +4 -2
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +234 -50
- data/vendor/faiss/faiss/IndexIVFPQ.h +5 -1
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +23 -852
- data/vendor/faiss/faiss/IndexIVFPQFastScan.h +7 -112
- data/vendor/faiss/faiss/IndexIVFPQR.cpp +3 -3
- data/vendor/faiss/faiss/IndexIVFPQR.h +1 -1
- data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +3 -1
- data/vendor/faiss/faiss/IndexIVFSpectralHash.h +2 -1
- data/vendor/faiss/faiss/IndexLSH.cpp +4 -2
- data/vendor/faiss/faiss/IndexLSH.h +2 -1
- data/vendor/faiss/faiss/IndexLattice.cpp +7 -1
- data/vendor/faiss/faiss/IndexLattice.h +3 -1
- data/vendor/faiss/faiss/IndexNNDescent.cpp +4 -3
- data/vendor/faiss/faiss/IndexNNDescent.h +2 -1
- data/vendor/faiss/faiss/IndexNSG.cpp +37 -3
- data/vendor/faiss/faiss/IndexNSG.h +25 -1
- data/vendor/faiss/faiss/IndexPQ.cpp +106 -69
- data/vendor/faiss/faiss/IndexPQ.h +19 -5
- data/vendor/faiss/faiss/IndexPQFastScan.cpp +15 -450
- data/vendor/faiss/faiss/IndexPQFastScan.h +15 -78
- data/vendor/faiss/faiss/IndexPreTransform.cpp +47 -8
- data/vendor/faiss/faiss/IndexPreTransform.h +15 -3
- data/vendor/faiss/faiss/IndexRefine.cpp +8 -4
- data/vendor/faiss/faiss/IndexRefine.h +4 -2
- data/vendor/faiss/faiss/IndexReplicas.cpp +4 -2
- data/vendor/faiss/faiss/IndexReplicas.h +2 -1
- data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +438 -0
- data/vendor/faiss/faiss/IndexRowwiseMinMax.h +92 -0
- data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +26 -15
- data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -7
- data/vendor/faiss/faiss/IndexShards.cpp +4 -1
- data/vendor/faiss/faiss/IndexShards.h +2 -1
- data/vendor/faiss/faiss/MetaIndexes.cpp +5 -178
- data/vendor/faiss/faiss/MetaIndexes.h +3 -81
- data/vendor/faiss/faiss/VectorTransform.cpp +43 -0
- data/vendor/faiss/faiss/VectorTransform.h +22 -4
- data/vendor/faiss/faiss/clone_index.cpp +23 -1
- data/vendor/faiss/faiss/clone_index.h +3 -0
- data/vendor/faiss/faiss/cppcontrib/SaDecodeKernels.h +300 -0
- data/vendor/faiss/faiss/cppcontrib/detail/CoarseBitType.h +24 -0
- data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +195 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +2058 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +408 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-neon-inl.h +2147 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMax-inl.h +460 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMaxFP16-inl.h +465 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +1618 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +251 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-neon-inl.h +1452 -0
- data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +1 -0
- data/vendor/faiss/faiss/gpu/GpuCloner.cpp +0 -4
- data/vendor/faiss/faiss/gpu/GpuIndex.h +28 -4
- data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +2 -1
- data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +10 -8
- data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +75 -14
- data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +19 -32
- data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -31
- data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +22 -28
- data/vendor/faiss/faiss/gpu/GpuResources.cpp +14 -0
- data/vendor/faiss/faiss/gpu/GpuResources.h +16 -3
- data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +3 -3
- data/vendor/faiss/faiss/gpu/impl/IndexUtils.h +32 -0
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +1 -0
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +311 -75
- data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +10 -0
- data/vendor/faiss/faiss/gpu/test/TestUtils.h +3 -0
- data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +2 -2
- data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +5 -4
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +116 -47
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +44 -13
- data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +0 -54
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +0 -76
- data/vendor/faiss/faiss/impl/DistanceComputer.h +64 -0
- data/vendor/faiss/faiss/impl/HNSW.cpp +123 -27
- data/vendor/faiss/faiss/impl/HNSW.h +19 -16
- data/vendor/faiss/faiss/impl/IDSelector.cpp +125 -0
- data/vendor/faiss/faiss/impl/IDSelector.h +135 -0
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +6 -28
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +6 -1
- data/vendor/faiss/faiss/impl/LookupTableScaler.h +77 -0
- data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -0
- data/vendor/faiss/faiss/impl/NSG.cpp +1 -1
- data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +383 -0
- data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +154 -0
- data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +225 -145
- data/vendor/faiss/faiss/impl/ProductQuantizer.h +29 -10
- data/vendor/faiss/faiss/impl/Quantizer.h +43 -0
- data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +192 -36
- data/vendor/faiss/faiss/impl/ResidualQuantizer.h +40 -20
- data/vendor/faiss/faiss/impl/ResultHandler.h +96 -0
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +97 -173
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +18 -18
- data/vendor/faiss/faiss/impl/index_read.cpp +240 -9
- data/vendor/faiss/faiss/impl/index_write.cpp +237 -5
- data/vendor/faiss/faiss/impl/kmeans1d.cpp +6 -4
- data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +56 -16
- data/vendor/faiss/faiss/impl/pq4_fast_scan.h +25 -8
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +66 -25
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +75 -27
- data/vendor/faiss/faiss/index_factory.cpp +196 -7
- data/vendor/faiss/faiss/index_io.h +5 -0
- data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -0
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +4 -1
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +2 -1
- data/vendor/faiss/faiss/python/python_callbacks.cpp +27 -0
- data/vendor/faiss/faiss/python/python_callbacks.h +15 -0
- data/vendor/faiss/faiss/utils/Heap.h +31 -15
- data/vendor/faiss/faiss/utils/distances.cpp +380 -56
- data/vendor/faiss/faiss/utils/distances.h +113 -15
- data/vendor/faiss/faiss/utils/distances_simd.cpp +726 -6
- data/vendor/faiss/faiss/utils/extra_distances.cpp +12 -7
- data/vendor/faiss/faiss/utils/extra_distances.h +3 -1
- data/vendor/faiss/faiss/utils/fp16-fp16c.h +21 -0
- data/vendor/faiss/faiss/utils/fp16-inl.h +101 -0
- data/vendor/faiss/faiss/utils/fp16.h +11 -0
- data/vendor/faiss/faiss/utils/hamming-inl.h +54 -0
- data/vendor/faiss/faiss/utils/hamming.cpp +0 -48
- data/vendor/faiss/faiss/utils/ordered_key_value.h +10 -0
- data/vendor/faiss/faiss/utils/quantize_lut.cpp +62 -0
- data/vendor/faiss/faiss/utils/quantize_lut.h +20 -0
- data/vendor/faiss/faiss/utils/random.cpp +53 -0
- data/vendor/faiss/faiss/utils/random.h +5 -0
- data/vendor/faiss/faiss/utils/simdlib_avx2.h +4 -0
- data/vendor/faiss/faiss/utils/simdlib_emulated.h +6 -1
- data/vendor/faiss/faiss/utils/simdlib_neon.h +7 -2
- metadata +37 -3
|
@@ -40,19 +40,13 @@ IndexIVFPQFastScan::IndexIVFPQFastScan(
|
|
|
40
40
|
size_t d,
|
|
41
41
|
size_t nlist,
|
|
42
42
|
size_t M,
|
|
43
|
-
size_t nbits_per_idx,
|
|
43
|
+
size_t nbits,
|
|
44
44
|
MetricType metric,
|
|
45
45
|
int bbs)
|
|
46
|
-
: IndexIVF(quantizer, d, nlist, 0, metric),
|
|
47
|
-
pq(d, M, nbits_per_idx),
|
|
48
|
-
bbs(bbs) {
|
|
49
|
-
FAISS_THROW_IF_NOT(nbits_per_idx == 4);
|
|
50
|
-
M2 = roundup(pq.M, 2);
|
|
46
|
+
: IndexIVFFastScan(quantizer, d, nlist, 0, metric), pq(d, M, nbits) {
|
|
51
47
|
by_residual = false; // set to false by default because it's much faster
|
|
52
|
-
is_trained = false;
|
|
53
|
-
code_size = pq.code_size;
|
|
54
48
|
|
|
55
|
-
|
|
49
|
+
init_fastscan(M, nbits, nlist, metric, bbs);
|
|
56
50
|
}
|
|
57
51
|
|
|
58
52
|
IndexIVFPQFastScan::IndexIVFPQFastScan() {
|
|
@@ -62,26 +56,21 @@ IndexIVFPQFastScan::IndexIVFPQFastScan() {
|
|
|
62
56
|
}
|
|
63
57
|
|
|
64
58
|
IndexIVFPQFastScan::IndexIVFPQFastScan(const IndexIVFPQ& orig, int bbs)
|
|
65
|
-
: IndexIVF(
|
|
59
|
+
: IndexIVFFastScan(
|
|
66
60
|
orig.quantizer,
|
|
67
61
|
orig.d,
|
|
68
62
|
orig.nlist,
|
|
69
63
|
orig.pq.code_size,
|
|
70
64
|
orig.metric_type),
|
|
71
|
-
pq(orig.pq),
|
|
72
|
-
bbs(bbs) {
|
|
65
|
+
pq(orig.pq) {
|
|
73
66
|
FAISS_THROW_IF_NOT(orig.pq.nbits == 4);
|
|
74
67
|
|
|
68
|
+
init_fastscan(orig.pq.M, orig.pq.nbits, orig.nlist, orig.metric_type, bbs);
|
|
69
|
+
|
|
75
70
|
by_residual = orig.by_residual;
|
|
76
71
|
ntotal = orig.ntotal;
|
|
77
72
|
is_trained = orig.is_trained;
|
|
78
73
|
nprobe = orig.nprobe;
|
|
79
|
-
size_t M = pq.M;
|
|
80
|
-
|
|
81
|
-
M2 = roundup(M, 2);
|
|
82
|
-
|
|
83
|
-
replace_invlists(
|
|
84
|
-
new BlockInvertedLists(orig.nlist, bbs, bbs * M2 / 2), true);
|
|
85
74
|
|
|
86
75
|
precomputed_table.resize(orig.precomputed_table.size());
|
|
87
76
|
|
|
@@ -205,150 +194,10 @@ void IndexIVFPQFastScan::encode_vectors(
|
|
|
205
194
|
}
|
|
206
195
|
}
|
|
207
196
|
|
|
208
|
-
void IndexIVFPQFastScan::add_with_ids(
|
|
209
|
-
idx_t n,
|
|
210
|
-
const float* x,
|
|
211
|
-
const idx_t* xids) {
|
|
212
|
-
// copied from IndexIVF::add_with_ids --->
|
|
213
|
-
|
|
214
|
-
// do some blocking to avoid excessive allocs
|
|
215
|
-
idx_t bs = 65536;
|
|
216
|
-
if (n > bs) {
|
|
217
|
-
for (idx_t i0 = 0; i0 < n; i0 += bs) {
|
|
218
|
-
idx_t i1 = std::min(n, i0 + bs);
|
|
219
|
-
if (verbose) {
|
|
220
|
-
printf(" IndexIVFPQFastScan::add_with_ids %zd: %zd",
|
|
221
|
-
size_t(i0),
|
|
222
|
-
size_t(i1));
|
|
223
|
-
}
|
|
224
|
-
add_with_ids(i1 - i0, x + i0 * d, xids ? xids + i0 : nullptr);
|
|
225
|
-
}
|
|
226
|
-
return;
|
|
227
|
-
}
|
|
228
|
-
InterruptCallback::check();
|
|
229
|
-
|
|
230
|
-
AlignedTable<uint8_t> codes(n * code_size);
|
|
231
|
-
|
|
232
|
-
FAISS_THROW_IF_NOT(is_trained);
|
|
233
|
-
direct_map.check_can_add(xids);
|
|
234
|
-
|
|
235
|
-
std::unique_ptr<idx_t[]> idx(new idx_t[n]);
|
|
236
|
-
quantizer->assign(n, x, idx.get());
|
|
237
|
-
size_t nadd = 0, nminus1 = 0;
|
|
238
|
-
|
|
239
|
-
for (size_t i = 0; i < n; i++) {
|
|
240
|
-
if (idx[i] < 0)
|
|
241
|
-
nminus1++;
|
|
242
|
-
}
|
|
243
|
-
|
|
244
|
-
AlignedTable<uint8_t> flat_codes(n * code_size);
|
|
245
|
-
encode_vectors(n, x, idx.get(), flat_codes.get());
|
|
246
|
-
|
|
247
|
-
DirectMapAdd dm_adder(direct_map, n, xids);
|
|
248
|
-
|
|
249
|
-
// <---
|
|
250
|
-
|
|
251
|
-
BlockInvertedLists* bil = dynamic_cast<BlockInvertedLists*>(invlists);
|
|
252
|
-
FAISS_THROW_IF_NOT_MSG(bil, "only block inverted lists supported");
|
|
253
|
-
|
|
254
|
-
// prepare batches
|
|
255
|
-
std::vector<idx_t> order(n);
|
|
256
|
-
for (idx_t i = 0; i < n; i++) {
|
|
257
|
-
order[i] = i;
|
|
258
|
-
}
|
|
259
|
-
|
|
260
|
-
// TODO should not need stable
|
|
261
|
-
std::stable_sort(order.begin(), order.end(), [&idx](idx_t a, idx_t b) {
|
|
262
|
-
return idx[a] < idx[b];
|
|
263
|
-
});
|
|
264
|
-
|
|
265
|
-
// TODO parallelize
|
|
266
|
-
idx_t i0 = 0;
|
|
267
|
-
while (i0 < n) {
|
|
268
|
-
idx_t list_no = idx[order[i0]];
|
|
269
|
-
idx_t i1 = i0 + 1;
|
|
270
|
-
while (i1 < n && idx[order[i1]] == list_no) {
|
|
271
|
-
i1++;
|
|
272
|
-
}
|
|
273
|
-
|
|
274
|
-
if (list_no == -1) {
|
|
275
|
-
i0 = i1;
|
|
276
|
-
continue;
|
|
277
|
-
}
|
|
278
|
-
|
|
279
|
-
// make linear array
|
|
280
|
-
AlignedTable<uint8_t> list_codes((i1 - i0) * code_size);
|
|
281
|
-
size_t list_size = bil->list_size(list_no);
|
|
282
|
-
|
|
283
|
-
bil->resize(list_no, list_size + i1 - i0);
|
|
284
|
-
|
|
285
|
-
for (idx_t i = i0; i < i1; i++) {
|
|
286
|
-
size_t ofs = list_size + i - i0;
|
|
287
|
-
idx_t id = xids ? xids[order[i]] : ntotal + order[i];
|
|
288
|
-
dm_adder.add(order[i], list_no, ofs);
|
|
289
|
-
bil->ids[list_no][ofs] = id;
|
|
290
|
-
memcpy(list_codes.data() + (i - i0) * code_size,
|
|
291
|
-
flat_codes.data() + order[i] * code_size,
|
|
292
|
-
code_size);
|
|
293
|
-
nadd++;
|
|
294
|
-
}
|
|
295
|
-
pq4_pack_codes_range(
|
|
296
|
-
list_codes.data(),
|
|
297
|
-
pq.M,
|
|
298
|
-
list_size,
|
|
299
|
-
list_size + i1 - i0,
|
|
300
|
-
bbs,
|
|
301
|
-
M2,
|
|
302
|
-
bil->codes[list_no].data());
|
|
303
|
-
|
|
304
|
-
i0 = i1;
|
|
305
|
-
}
|
|
306
|
-
|
|
307
|
-
ntotal += n;
|
|
308
|
-
}
|
|
309
|
-
|
|
310
197
|
/*********************************************************
|
|
311
|
-
*
|
|
198
|
+
* Look-Up Table functions
|
|
312
199
|
*********************************************************/
|
|
313
200
|
|
|
314
|
-
namespace {
|
|
315
|
-
|
|
316
|
-
// from impl/ProductQuantizer.cpp
|
|
317
|
-
template <class C, typename dis_t>
|
|
318
|
-
void pq_estimators_from_tables_generic(
|
|
319
|
-
const ProductQuantizer& pq,
|
|
320
|
-
size_t nbits,
|
|
321
|
-
const uint8_t* codes,
|
|
322
|
-
size_t ncodes,
|
|
323
|
-
const dis_t* dis_table,
|
|
324
|
-
const int64_t* ids,
|
|
325
|
-
float dis0,
|
|
326
|
-
size_t k,
|
|
327
|
-
typename C::T* heap_dis,
|
|
328
|
-
int64_t* heap_ids) {
|
|
329
|
-
using accu_t = typename C::T;
|
|
330
|
-
const size_t M = pq.M;
|
|
331
|
-
const size_t ksub = pq.ksub;
|
|
332
|
-
for (size_t j = 0; j < ncodes; ++j) {
|
|
333
|
-
PQDecoderGeneric decoder(codes + j * pq.code_size, nbits);
|
|
334
|
-
accu_t dis = dis0;
|
|
335
|
-
const dis_t* dt = dis_table;
|
|
336
|
-
for (size_t m = 0; m < M; m++) {
|
|
337
|
-
uint64_t c = decoder.decode();
|
|
338
|
-
dis += dt[c];
|
|
339
|
-
dt += ksub;
|
|
340
|
-
}
|
|
341
|
-
|
|
342
|
-
if (C::cmp(heap_dis[0], dis)) {
|
|
343
|
-
heap_pop<C>(k, heap_dis, heap_ids);
|
|
344
|
-
heap_push<C>(k, heap_dis, heap_ids, dis, ids[j]);
|
|
345
|
-
}
|
|
346
|
-
}
|
|
347
|
-
}
|
|
348
|
-
|
|
349
|
-
using idx_t = Index::idx_t;
|
|
350
|
-
using namespace quantize_lut;
|
|
351
|
-
|
|
352
201
|
void fvec_madd_avx(
|
|
353
202
|
size_t n,
|
|
354
203
|
const float* a,
|
|
@@ -373,11 +222,9 @@ void fvec_madd_avx(
|
|
|
373
222
|
}
|
|
374
223
|
}
|
|
375
224
|
|
|
376
|
-
|
|
377
|
-
|
|
378
|
-
|
|
379
|
-
* Look-Up Table functions
|
|
380
|
-
*********************************************************/
|
|
225
|
+
bool IndexIVFPQFastScan::lookup_table_is_3d() const {
|
|
226
|
+
return by_residual && metric_type == METRIC_L2;
|
|
227
|
+
}
|
|
381
228
|
|
|
382
229
|
void IndexIVFPQFastScan::compute_LUT(
|
|
383
230
|
size_t n,
|
|
@@ -386,16 +233,14 @@ void IndexIVFPQFastScan::compute_LUT(
|
|
|
386
233
|
const float* coarse_dis,
|
|
387
234
|
AlignedTable<float>& dis_tables,
|
|
388
235
|
AlignedTable<float>& biases) const {
|
|
389
|
-
const IndexIVFPQFastScan& ivfpq = *this;
|
|
390
236
|
size_t dim12 = pq.ksub * pq.M;
|
|
391
237
|
size_t d = pq.d;
|
|
392
|
-
size_t nprobe = ivfpq.nprobe;
|
|
393
238
|
|
|
394
|
-
if (ivfpq.by_residual) {
|
|
395
|
-
if (ivfpq.metric_type == METRIC_L2) {
|
|
239
|
+
if (by_residual) {
|
|
240
|
+
if (metric_type == METRIC_L2) {
|
|
396
241
|
dis_tables.resize(n * nprobe * dim12);
|
|
397
242
|
|
|
398
|
-
if (ivfpq.use_precomputed_table == 1) {
|
|
243
|
+
if (use_precomputed_table == 1) {
|
|
399
244
|
biases.resize(n * nprobe);
|
|
400
245
|
memcpy(biases.get(), coarse_dis, sizeof(float) * n * nprobe);
|
|
401
246
|
|
|
@@ -434,7 +279,7 @@ void IndexIVFPQFastScan::compute_LUT(
|
|
|
434
279
|
idx_t cij = coarse_ids[ij];
|
|
435
280
|
|
|
436
281
|
if (cij >= 0) {
|
|
437
|
-
|
|
282
|
+
quantizer->compute_residual(x + i * d, xij, cij);
|
|
438
283
|
} else {
|
|
439
284
|
// will fill with NaNs
|
|
440
285
|
memset(xij, -1, sizeof(float) * d);
|
|
@@ -445,7 +290,7 @@ void IndexIVFPQFastScan::compute_LUT(
|
|
|
445
290
|
n * nprobe, xrel.get(), dis_tables.get());
|
|
446
291
|
}
|
|
447
292
|
|
|
448
|
-
} else if (ivfpq.metric_type == METRIC_INNER_PRODUCT) {
|
|
293
|
+
} else if (metric_type == METRIC_INNER_PRODUCT) {
|
|
449
294
|
dis_tables.resize(n * dim12);
|
|
450
295
|
pq.compute_inner_prod_tables(n, x, dis_tables.get());
|
|
451
296
|
// compute_inner_prod_tables(pq, n, x, dis_tables.get());
|
|
@@ -453,698 +298,24 @@ void IndexIVFPQFastScan::compute_LUT(
|
|
|
453
298
|
biases.resize(n * nprobe);
|
|
454
299
|
memcpy(biases.get(), coarse_dis, sizeof(float) * n * nprobe);
|
|
455
300
|
} else {
|
|
456
|
-
FAISS_THROW_FMT("metric %d not supported", ivfpq.metric_type);
|
|
301
|
+
FAISS_THROW_FMT("metric %d not supported", metric_type);
|
|
457
302
|
}
|
|
458
303
|
|
|
459
304
|
} else {
|
|
460
305
|
dis_tables.resize(n * dim12);
|
|
461
|
-
if (ivfpq.metric_type == METRIC_L2) {
|
|
306
|
+
if (metric_type == METRIC_L2) {
|
|
462
307
|
pq.compute_distance_tables(n, x, dis_tables.get());
|
|
463
|
-
} else if (ivfpq.metric_type == METRIC_INNER_PRODUCT) {
|
|
308
|
+
} else if (metric_type == METRIC_INNER_PRODUCT) {
|
|
464
309
|
pq.compute_inner_prod_tables(n, x, dis_tables.get());
|
|
465
310
|
} else {
|
|
466
|
-
FAISS_THROW_FMT("metric %d not supported", ivfpq.metric_type);
|
|
467
|
-
}
|
|
468
|
-
}
|
|
469
|
-
}
|
|
470
|
-
|
|
471
|
-
void IndexIVFPQFastScan::compute_LUT_uint8(
|
|
472
|
-
size_t n,
|
|
473
|
-
const float* x,
|
|
474
|
-
const idx_t* coarse_ids,
|
|
475
|
-
const float* coarse_dis,
|
|
476
|
-
AlignedTable<uint8_t>& dis_tables,
|
|
477
|
-
AlignedTable<uint16_t>& biases,
|
|
478
|
-
float* normalizers) const {
|
|
479
|
-
const IndexIVFPQFastScan& ivfpq = *this;
|
|
480
|
-
AlignedTable<float> dis_tables_float;
|
|
481
|
-
AlignedTable<float> biases_float;
|
|
482
|
-
|
|
483
|
-
uint64_t t0 = get_cy();
|
|
484
|
-
compute_LUT(n, x, coarse_ids, coarse_dis, dis_tables_float, biases_float);
|
|
485
|
-
IVFFastScan_stats.t_compute_distance_tables += get_cy() - t0;
|
|
486
|
-
|
|
487
|
-
bool lut_is_3d = ivfpq.by_residual && ivfpq.metric_type == METRIC_L2;
|
|
488
|
-
size_t dim123 = pq.ksub * pq.M;
|
|
489
|
-
size_t dim123_2 = pq.ksub * M2;
|
|
490
|
-
if (lut_is_3d) {
|
|
491
|
-
dim123 *= nprobe;
|
|
492
|
-
dim123_2 *= nprobe;
|
|
493
|
-
}
|
|
494
|
-
dis_tables.resize(n * dim123_2);
|
|
495
|
-
if (biases_float.get()) {
|
|
496
|
-
biases.resize(n * nprobe);
|
|
497
|
-
}
|
|
498
|
-
uint64_t t1 = get_cy();
|
|
499
|
-
|
|
500
|
-
#pragma omp parallel for if (n > 100)
|
|
501
|
-
for (int64_t i = 0; i < n; i++) {
|
|
502
|
-
const float* t_in = dis_tables_float.get() + i * dim123;
|
|
503
|
-
const float* b_in = nullptr;
|
|
504
|
-
uint8_t* t_out = dis_tables.get() + i * dim123_2;
|
|
505
|
-
uint16_t* b_out = nullptr;
|
|
506
|
-
if (biases_float.get()) {
|
|
507
|
-
b_in = biases_float.get() + i * nprobe;
|
|
508
|
-
b_out = biases.get() + i * nprobe;
|
|
509
|
-
}
|
|
510
|
-
|
|
511
|
-
quantize_LUT_and_bias(
|
|
512
|
-
nprobe,
|
|
513
|
-
pq.M,
|
|
514
|
-
pq.ksub,
|
|
515
|
-
lut_is_3d,
|
|
516
|
-
t_in,
|
|
517
|
-
b_in,
|
|
518
|
-
t_out,
|
|
519
|
-
M2,
|
|
520
|
-
b_out,
|
|
521
|
-
normalizers + 2 * i,
|
|
522
|
-
normalizers + 2 * i + 1);
|
|
523
|
-
}
|
|
524
|
-
IVFFastScan_stats.t_round += get_cy() - t1;
|
|
525
|
-
}
|
|
526
|
-
|
|
527
|
-
/*********************************************************
|
|
528
|
-
* Search functions
|
|
529
|
-
*********************************************************/
|
|
530
|
-
|
|
531
|
-
template <bool is_max>
|
|
532
|
-
void IndexIVFPQFastScan::search_dispatch_implem(
|
|
533
|
-
idx_t n,
|
|
534
|
-
const float* x,
|
|
535
|
-
idx_t k,
|
|
536
|
-
float* distances,
|
|
537
|
-
idx_t* labels) const {
|
|
538
|
-
using Cfloat = typename std::conditional<
|
|
539
|
-
is_max,
|
|
540
|
-
CMax<float, int64_t>,
|
|
541
|
-
CMin<float, int64_t>>::type;
|
|
542
|
-
|
|
543
|
-
using C = typename std::conditional<
|
|
544
|
-
is_max,
|
|
545
|
-
CMax<uint16_t, int64_t>,
|
|
546
|
-
CMin<uint16_t, int64_t>>::type;
|
|
547
|
-
|
|
548
|
-
if (n == 0) {
|
|
549
|
-
return;
|
|
550
|
-
}
|
|
551
|
-
|
|
552
|
-
// actual implementation used
|
|
553
|
-
int impl = implem;
|
|
554
|
-
|
|
555
|
-
if (impl == 0) {
|
|
556
|
-
if (bbs == 32) {
|
|
557
|
-
impl = 12;
|
|
558
|
-
} else {
|
|
559
|
-
impl = 10;
|
|
560
|
-
}
|
|
561
|
-
if (k > 20) {
|
|
562
|
-
impl++;
|
|
563
|
-
}
|
|
564
|
-
}
|
|
565
|
-
|
|
566
|
-
if (impl == 1) {
|
|
567
|
-
search_implem_1<Cfloat>(n, x, k, distances, labels);
|
|
568
|
-
} else if (impl == 2) {
|
|
569
|
-
search_implem_2<C>(n, x, k, distances, labels);
|
|
570
|
-
|
|
571
|
-
} else if (impl >= 10 && impl <= 13) {
|
|
572
|
-
size_t ndis = 0, nlist_visited = 0;
|
|
573
|
-
|
|
574
|
-
if (n < 2) {
|
|
575
|
-
if (impl == 12 || impl == 13) {
|
|
576
|
-
search_implem_12<C>(
|
|
577
|
-
n,
|
|
578
|
-
x,
|
|
579
|
-
k,
|
|
580
|
-
distances,
|
|
581
|
-
labels,
|
|
582
|
-
impl,
|
|
583
|
-
&ndis,
|
|
584
|
-
&nlist_visited);
|
|
585
|
-
} else {
|
|
586
|
-
search_implem_10<C>(
|
|
587
|
-
n,
|
|
588
|
-
x,
|
|
589
|
-
k,
|
|
590
|
-
distances,
|
|
591
|
-
labels,
|
|
592
|
-
impl,
|
|
593
|
-
&ndis,
|
|
594
|
-
&nlist_visited);
|
|
595
|
-
}
|
|
596
|
-
} else {
|
|
597
|
-
// explicitly slice over threads
|
|
598
|
-
int nslice;
|
|
599
|
-
if (n <= omp_get_max_threads()) {
|
|
600
|
-
nslice = n;
|
|
601
|
-
} else if (by_residual && metric_type == METRIC_L2) {
|
|
602
|
-
// make sure we don't make too big LUT tables
|
|
603
|
-
size_t lut_size_per_query = pq.M * pq.ksub * nprobe *
|
|
604
|
-
(sizeof(float) + sizeof(uint8_t));
|
|
605
|
-
|
|
606
|
-
size_t max_lut_size = precomputed_table_max_bytes;
|
|
607
|
-
// how many queries we can handle within mem budget
|
|
608
|
-
size_t nq_ok =
|
|
609
|
-
std::max(max_lut_size / lut_size_per_query, size_t(1));
|
|
610
|
-
nslice =
|
|
611
|
-
roundup(std::max(size_t(n / nq_ok), size_t(1)),
|
|
612
|
-
omp_get_max_threads());
|
|
613
|
-
} else {
|
|
614
|
-
// LUTs unlikely to be a limiting factor
|
|
615
|
-
nslice = omp_get_max_threads();
|
|
616
|
-
}
|
|
617
|
-
|
|
618
|
-
#pragma omp parallel for reduction(+ : ndis, nlist_visited)
|
|
619
|
-
for (int slice = 0; slice < nslice; slice++) {
|
|
620
|
-
idx_t i0 = n * slice / nslice;
|
|
621
|
-
idx_t i1 = n * (slice + 1) / nslice;
|
|
622
|
-
float* dis_i = distances + i0 * k;
|
|
623
|
-
idx_t* lab_i = labels + i0 * k;
|
|
624
|
-
if (impl == 12 || impl == 13) {
|
|
625
|
-
search_implem_12<C>(
|
|
626
|
-
i1 - i0,
|
|
627
|
-
x + i0 * d,
|
|
628
|
-
k,
|
|
629
|
-
dis_i,
|
|
630
|
-
lab_i,
|
|
631
|
-
impl,
|
|
632
|
-
&ndis,
|
|
633
|
-
&nlist_visited);
|
|
634
|
-
} else {
|
|
635
|
-
search_implem_10<C>(
|
|
636
|
-
i1 - i0,
|
|
637
|
-
x + i0 * d,
|
|
638
|
-
k,
|
|
639
|
-
dis_i,
|
|
640
|
-
lab_i,
|
|
641
|
-
impl,
|
|
642
|
-
&ndis,
|
|
643
|
-
&nlist_visited);
|
|
644
|
-
}
|
|
645
|
-
}
|
|
646
|
-
}
|
|
647
|
-
indexIVF_stats.nq += n;
|
|
648
|
-
indexIVF_stats.ndis += ndis;
|
|
649
|
-
indexIVF_stats.nlist += nlist_visited;
|
|
650
|
-
} else {
|
|
651
|
-
FAISS_THROW_FMT("implem %d does not exist", implem);
|
|
652
|
-
}
|
|
653
|
-
}
|
|
654
|
-
|
|
655
|
-
void IndexIVFPQFastScan::search(
|
|
656
|
-
idx_t n,
|
|
657
|
-
const float* x,
|
|
658
|
-
idx_t k,
|
|
659
|
-
float* distances,
|
|
660
|
-
idx_t* labels) const {
|
|
661
|
-
FAISS_THROW_IF_NOT(k > 0);
|
|
662
|
-
|
|
663
|
-
if (metric_type == METRIC_L2) {
|
|
664
|
-
search_dispatch_implem<true>(n, x, k, distances, labels);
|
|
665
|
-
} else {
|
|
666
|
-
search_dispatch_implem<false>(n, x, k, distances, labels);
|
|
667
|
-
}
|
|
668
|
-
}
|
|
669
|
-
|
|
670
|
-
template <class C>
|
|
671
|
-
void IndexIVFPQFastScan::search_implem_1(
|
|
672
|
-
idx_t n,
|
|
673
|
-
const float* x,
|
|
674
|
-
idx_t k,
|
|
675
|
-
float* distances,
|
|
676
|
-
idx_t* labels) const {
|
|
677
|
-
FAISS_THROW_IF_NOT(orig_invlists);
|
|
678
|
-
|
|
679
|
-
std::unique_ptr<idx_t[]> coarse_ids(new idx_t[n * nprobe]);
|
|
680
|
-
std::unique_ptr<float[]> coarse_dis(new float[n * nprobe]);
|
|
681
|
-
|
|
682
|
-
quantizer->search(n, x, nprobe, coarse_dis.get(), coarse_ids.get());
|
|
683
|
-
|
|
684
|
-
size_t dim12 = pq.ksub * pq.M;
|
|
685
|
-
AlignedTable<float> dis_tables;
|
|
686
|
-
AlignedTable<float> biases;
|
|
687
|
-
|
|
688
|
-
compute_LUT(n, x, coarse_ids.get(), coarse_dis.get(), dis_tables, biases);
|
|
689
|
-
|
|
690
|
-
bool single_LUT = !(by_residual && metric_type == METRIC_L2);
|
|
691
|
-
|
|
692
|
-
size_t ndis = 0, nlist_visited = 0;
|
|
693
|
-
|
|
694
|
-
#pragma omp parallel for reduction(+ : ndis, nlist_visited)
|
|
695
|
-
for (idx_t i = 0; i < n; i++) {
|
|
696
|
-
int64_t* heap_ids = labels + i * k;
|
|
697
|
-
float* heap_dis = distances + i * k;
|
|
698
|
-
heap_heapify<C>(k, heap_dis, heap_ids);
|
|
699
|
-
float* LUT = nullptr;
|
|
700
|
-
|
|
701
|
-
if (single_LUT) {
|
|
702
|
-
LUT = dis_tables.get() + i * dim12;
|
|
703
|
-
}
|
|
704
|
-
for (idx_t j = 0; j < nprobe; j++) {
|
|
705
|
-
if (!single_LUT) {
|
|
706
|
-
LUT = dis_tables.get() + (i * nprobe + j) * dim12;
|
|
707
|
-
}
|
|
708
|
-
idx_t list_no = coarse_ids[i * nprobe + j];
|
|
709
|
-
if (list_no < 0)
|
|
710
|
-
continue;
|
|
711
|
-
size_t ls = orig_invlists->list_size(list_no);
|
|
712
|
-
if (ls == 0)
|
|
713
|
-
continue;
|
|
714
|
-
InvertedLists::ScopedCodes codes(orig_invlists, list_no);
|
|
715
|
-
InvertedLists::ScopedIds ids(orig_invlists, list_no);
|
|
716
|
-
|
|
717
|
-
float bias = biases.get() ? biases[i * nprobe + j] : 0;
|
|
718
|
-
|
|
719
|
-
pq_estimators_from_tables_generic<C>(
|
|
720
|
-
pq,
|
|
721
|
-
pq.nbits,
|
|
722
|
-
codes.get(),
|
|
723
|
-
ls,
|
|
724
|
-
LUT,
|
|
725
|
-
ids.get(),
|
|
726
|
-
bias,
|
|
727
|
-
k,
|
|
728
|
-
heap_dis,
|
|
729
|
-
heap_ids);
|
|
730
|
-
nlist_visited++;
|
|
731
|
-
ndis++;
|
|
732
|
-
}
|
|
733
|
-
heap_reorder<C>(k, heap_dis, heap_ids);
|
|
734
|
-
}
|
|
735
|
-
indexIVF_stats.nq += n;
|
|
736
|
-
indexIVF_stats.ndis += ndis;
|
|
737
|
-
indexIVF_stats.nlist += nlist_visited;
|
|
738
|
-
}
|
|
739
|
-
|
|
740
|
-
template <class C>
|
|
741
|
-
void IndexIVFPQFastScan::search_implem_2(
|
|
742
|
-
idx_t n,
|
|
743
|
-
const float* x,
|
|
744
|
-
idx_t k,
|
|
745
|
-
float* distances,
|
|
746
|
-
idx_t* labels) const {
|
|
747
|
-
FAISS_THROW_IF_NOT(orig_invlists);
|
|
748
|
-
|
|
749
|
-
std::unique_ptr<idx_t[]> coarse_ids(new idx_t[n * nprobe]);
|
|
750
|
-
std::unique_ptr<float[]> coarse_dis(new float[n * nprobe]);
|
|
751
|
-
|
|
752
|
-
quantizer->search(n, x, nprobe, coarse_dis.get(), coarse_ids.get());
|
|
753
|
-
|
|
754
|
-
size_t dim12 = pq.ksub * M2;
|
|
755
|
-
AlignedTable<uint8_t> dis_tables;
|
|
756
|
-
AlignedTable<uint16_t> biases;
|
|
757
|
-
std::unique_ptr<float[]> normalizers(new float[2 * n]);
|
|
758
|
-
|
|
759
|
-
compute_LUT_uint8(
|
|
760
|
-
n,
|
|
761
|
-
x,
|
|
762
|
-
coarse_ids.get(),
|
|
763
|
-
coarse_dis.get(),
|
|
764
|
-
dis_tables,
|
|
765
|
-
biases,
|
|
766
|
-
normalizers.get());
|
|
767
|
-
|
|
768
|
-
bool single_LUT = !(by_residual && metric_type == METRIC_L2);
|
|
769
|
-
|
|
770
|
-
size_t ndis = 0, nlist_visited = 0;
|
|
771
|
-
|
|
772
|
-
#pragma omp parallel for reduction(+ : ndis, nlist_visited)
|
|
773
|
-
for (idx_t i = 0; i < n; i++) {
|
|
774
|
-
std::vector<uint16_t> tmp_dis(k);
|
|
775
|
-
int64_t* heap_ids = labels + i * k;
|
|
776
|
-
uint16_t* heap_dis = tmp_dis.data();
|
|
777
|
-
heap_heapify<C>(k, heap_dis, heap_ids);
|
|
778
|
-
const uint8_t* LUT = nullptr;
|
|
779
|
-
|
|
780
|
-
if (single_LUT) {
|
|
781
|
-
LUT = dis_tables.get() + i * dim12;
|
|
782
|
-
}
|
|
783
|
-
for (idx_t j = 0; j < nprobe; j++) {
|
|
784
|
-
if (!single_LUT) {
|
|
785
|
-
LUT = dis_tables.get() + (i * nprobe + j) * dim12;
|
|
786
|
-
}
|
|
787
|
-
idx_t list_no = coarse_ids[i * nprobe + j];
|
|
788
|
-
if (list_no < 0)
|
|
789
|
-
continue;
|
|
790
|
-
size_t ls = orig_invlists->list_size(list_no);
|
|
791
|
-
if (ls == 0)
|
|
792
|
-
continue;
|
|
793
|
-
InvertedLists::ScopedCodes codes(orig_invlists, list_no);
|
|
794
|
-
InvertedLists::ScopedIds ids(orig_invlists, list_no);
|
|
795
|
-
|
|
796
|
-
uint16_t bias = biases.get() ? biases[i * nprobe + j] : 0;
|
|
797
|
-
|
|
798
|
-
pq_estimators_from_tables_generic<C>(
|
|
799
|
-
pq,
|
|
800
|
-
pq.nbits,
|
|
801
|
-
codes.get(),
|
|
802
|
-
ls,
|
|
803
|
-
LUT,
|
|
804
|
-
ids.get(),
|
|
805
|
-
bias,
|
|
806
|
-
k,
|
|
807
|
-
heap_dis,
|
|
808
|
-
heap_ids);
|
|
809
|
-
|
|
810
|
-
nlist_visited++;
|
|
811
|
-
ndis += ls;
|
|
812
|
-
}
|
|
813
|
-
heap_reorder<C>(k, heap_dis, heap_ids);
|
|
814
|
-
// convert distances to float
|
|
815
|
-
{
|
|
816
|
-
float one_a = 1 / normalizers[2 * i], b = normalizers[2 * i + 1];
|
|
817
|
-
if (skip & 16) {
|
|
818
|
-
one_a = 1;
|
|
819
|
-
b = 0;
|
|
820
|
-
}
|
|
821
|
-
float* heap_dis_float = distances + i * k;
|
|
822
|
-
for (int j = 0; j < k; j++) {
|
|
823
|
-
heap_dis_float[j] = b + heap_dis[j] * one_a;
|
|
824
|
-
}
|
|
825
|
-
}
|
|
826
|
-
}
|
|
827
|
-
indexIVF_stats.nq += n;
|
|
828
|
-
indexIVF_stats.ndis += ndis;
|
|
829
|
-
indexIVF_stats.nlist += nlist_visited;
|
|
830
|
-
}
|
|
831
|
-
|
|
832
|
-
template <class C>
|
|
833
|
-
void IndexIVFPQFastScan::search_implem_10(
|
|
834
|
-
idx_t n,
|
|
835
|
-
const float* x,
|
|
836
|
-
idx_t k,
|
|
837
|
-
float* distances,
|
|
838
|
-
idx_t* labels,
|
|
839
|
-
int impl,
|
|
840
|
-
size_t* ndis_out,
|
|
841
|
-
size_t* nlist_out) const {
|
|
842
|
-
memset(distances, -1, sizeof(float) * k * n);
|
|
843
|
-
memset(labels, -1, sizeof(idx_t) * k * n);
|
|
844
|
-
|
|
845
|
-
using HeapHC = HeapHandler<C, true>;
|
|
846
|
-
using ReservoirHC = ReservoirHandler<C, true>;
|
|
847
|
-
using SingleResultHC = SingleResultHandler<C, true>;
|
|
848
|
-
|
|
849
|
-
std::unique_ptr<idx_t[]> coarse_ids(new idx_t[n * nprobe]);
|
|
850
|
-
std::unique_ptr<float[]> coarse_dis(new float[n * nprobe]);
|
|
851
|
-
|
|
852
|
-
uint64_t times[10];
|
|
853
|
-
memset(times, 0, sizeof(times));
|
|
854
|
-
int ti = 0;
|
|
855
|
-
#define TIC times[ti++] = get_cy()
|
|
856
|
-
TIC;
|
|
857
|
-
|
|
858
|
-
quantizer->search(n, x, nprobe, coarse_dis.get(), coarse_ids.get());
|
|
859
|
-
|
|
860
|
-
TIC;
|
|
861
|
-
|
|
862
|
-
size_t dim12 = pq.ksub * M2;
|
|
863
|
-
AlignedTable<uint8_t> dis_tables;
|
|
864
|
-
AlignedTable<uint16_t> biases;
|
|
865
|
-
std::unique_ptr<float[]> normalizers(new float[2 * n]);
|
|
866
|
-
|
|
867
|
-
compute_LUT_uint8(
|
|
868
|
-
n,
|
|
869
|
-
x,
|
|
870
|
-
coarse_ids.get(),
|
|
871
|
-
coarse_dis.get(),
|
|
872
|
-
dis_tables,
|
|
873
|
-
biases,
|
|
874
|
-
normalizers.get());
|
|
875
|
-
|
|
876
|
-
TIC;
|
|
877
|
-
|
|
878
|
-
bool single_LUT = !(by_residual && metric_type == METRIC_L2);
|
|
879
|
-
|
|
880
|
-
TIC;
|
|
881
|
-
size_t ndis = 0, nlist_visited = 0;
|
|
882
|
-
|
|
883
|
-
{
|
|
884
|
-
AlignedTable<uint16_t> tmp_distances(k);
|
|
885
|
-
for (idx_t i = 0; i < n; i++) {
|
|
886
|
-
const uint8_t* LUT = nullptr;
|
|
887
|
-
int qmap1[1] = {0};
|
|
888
|
-
std::unique_ptr<SIMDResultHandler<C, true>> handler;
|
|
889
|
-
|
|
890
|
-
if (k == 1) {
|
|
891
|
-
handler.reset(new SingleResultHC(1, 0));
|
|
892
|
-
} else if (impl == 10) {
|
|
893
|
-
handler.reset(new HeapHC(
|
|
894
|
-
1, tmp_distances.get(), labels + i * k, k, 0));
|
|
895
|
-
} else if (impl == 11) {
|
|
896
|
-
handler.reset(new ReservoirHC(1, 0, k, 2 * k));
|
|
897
|
-
} else {
|
|
898
|
-
FAISS_THROW_MSG("invalid");
|
|
899
|
-
}
|
|
900
|
-
|
|
901
|
-
handler->q_map = qmap1;
|
|
902
|
-
|
|
903
|
-
if (single_LUT) {
|
|
904
|
-
LUT = dis_tables.get() + i * dim12;
|
|
905
|
-
}
|
|
906
|
-
for (idx_t j = 0; j < nprobe; j++) {
|
|
907
|
-
size_t ij = i * nprobe + j;
|
|
908
|
-
if (!single_LUT) {
|
|
909
|
-
LUT = dis_tables.get() + ij * dim12;
|
|
910
|
-
}
|
|
911
|
-
if (biases.get()) {
|
|
912
|
-
handler->dbias = biases.get() + ij;
|
|
913
|
-
}
|
|
914
|
-
|
|
915
|
-
idx_t list_no = coarse_ids[ij];
|
|
916
|
-
if (list_no < 0)
|
|
917
|
-
continue;
|
|
918
|
-
size_t ls = invlists->list_size(list_no);
|
|
919
|
-
if (ls == 0)
|
|
920
|
-
continue;
|
|
921
|
-
|
|
922
|
-
InvertedLists::ScopedCodes codes(invlists, list_no);
|
|
923
|
-
InvertedLists::ScopedIds ids(invlists, list_no);
|
|
924
|
-
|
|
925
|
-
handler->ntotal = ls;
|
|
926
|
-
handler->id_map = ids.get();
|
|
927
|
-
|
|
928
|
-
#define DISPATCH(classHC) \
|
|
929
|
-
if (dynamic_cast<classHC*>(handler.get())) { \
|
|
930
|
-
auto* res = static_cast<classHC*>(handler.get()); \
|
|
931
|
-
pq4_accumulate_loop( \
|
|
932
|
-
1, roundup(ls, bbs), bbs, M2, codes.get(), LUT, *res); \
|
|
933
|
-
}
|
|
934
|
-
DISPATCH(HeapHC)
|
|
935
|
-
else DISPATCH(ReservoirHC) else DISPATCH(SingleResultHC)
|
|
936
|
-
#undef DISPATCH
|
|
937
|
-
|
|
938
|
-
nlist_visited++;
|
|
939
|
-
ndis++;
|
|
940
|
-
}
|
|
941
|
-
|
|
942
|
-
handler->to_flat_arrays(
|
|
943
|
-
distances + i * k,
|
|
944
|
-
labels + i * k,
|
|
945
|
-
skip & 16 ? nullptr : normalizers.get() + i * 2);
|
|
311
|
+
FAISS_THROW_FMT("metric %d not supported", metric_type);
|
|
946
312
|
}
|
|
947
313
|
}
|
|
948
|
-
*ndis_out = ndis;
|
|
949
|
-
*nlist_out = nlist;
|
|
950
314
|
}
|
|
951
315
|
|
|
952
|
-
|
|
953
|
-
|
|
954
|
-
|
|
955
|
-
const float* x,
|
|
956
|
-
idx_t k,
|
|
957
|
-
float* distances,
|
|
958
|
-
idx_t* labels,
|
|
959
|
-
int impl,
|
|
960
|
-
size_t* ndis_out,
|
|
961
|
-
size_t* nlist_out) const {
|
|
962
|
-
if (n == 0) { // does not work well with reservoir
|
|
963
|
-
return;
|
|
964
|
-
}
|
|
965
|
-
FAISS_THROW_IF_NOT(bbs == 32);
|
|
966
|
-
|
|
967
|
-
std::unique_ptr<idx_t[]> coarse_ids(new idx_t[n * nprobe]);
|
|
968
|
-
std::unique_ptr<float[]> coarse_dis(new float[n * nprobe]);
|
|
969
|
-
|
|
970
|
-
uint64_t times[10];
|
|
971
|
-
memset(times, 0, sizeof(times));
|
|
972
|
-
int ti = 0;
|
|
973
|
-
#define TIC times[ti++] = get_cy()
|
|
974
|
-
TIC;
|
|
975
|
-
|
|
976
|
-
quantizer->search(n, x, nprobe, coarse_dis.get(), coarse_ids.get());
|
|
977
|
-
|
|
978
|
-
TIC;
|
|
979
|
-
|
|
980
|
-
size_t dim12 = pq.ksub * M2;
|
|
981
|
-
AlignedTable<uint8_t> dis_tables;
|
|
982
|
-
AlignedTable<uint16_t> biases;
|
|
983
|
-
std::unique_ptr<float[]> normalizers(new float[2 * n]);
|
|
984
|
-
|
|
985
|
-
compute_LUT_uint8(
|
|
986
|
-
n,
|
|
987
|
-
x,
|
|
988
|
-
coarse_ids.get(),
|
|
989
|
-
coarse_dis.get(),
|
|
990
|
-
dis_tables,
|
|
991
|
-
biases,
|
|
992
|
-
normalizers.get());
|
|
993
|
-
|
|
994
|
-
TIC;
|
|
995
|
-
|
|
996
|
-
struct QC {
|
|
997
|
-
int qno; // sequence number of the query
|
|
998
|
-
int list_no; // list to visit
|
|
999
|
-
int rank; // this is the rank'th result of the coarse quantizer
|
|
1000
|
-
};
|
|
1001
|
-
bool single_LUT = !(by_residual && metric_type == METRIC_L2);
|
|
1002
|
-
|
|
1003
|
-
std::vector<QC> qcs;
|
|
1004
|
-
{
|
|
1005
|
-
int ij = 0;
|
|
1006
|
-
for (int i = 0; i < n; i++) {
|
|
1007
|
-
for (int j = 0; j < nprobe; j++) {
|
|
1008
|
-
if (coarse_ids[ij] >= 0) {
|
|
1009
|
-
qcs.push_back(QC{i, int(coarse_ids[ij]), int(j)});
|
|
1010
|
-
}
|
|
1011
|
-
ij++;
|
|
1012
|
-
}
|
|
1013
|
-
}
|
|
1014
|
-
std::sort(qcs.begin(), qcs.end(), [](const QC& a, const QC& b) {
|
|
1015
|
-
return a.list_no < b.list_no;
|
|
1016
|
-
});
|
|
1017
|
-
}
|
|
1018
|
-
TIC;
|
|
1019
|
-
|
|
1020
|
-
// prepare the result handlers
|
|
1021
|
-
|
|
1022
|
-
std::unique_ptr<SIMDResultHandler<C, true>> handler;
|
|
1023
|
-
AlignedTable<uint16_t> tmp_distances;
|
|
1024
|
-
|
|
1025
|
-
using HeapHC = HeapHandler<C, true>;
|
|
1026
|
-
using ReservoirHC = ReservoirHandler<C, true>;
|
|
1027
|
-
using SingleResultHC = SingleResultHandler<C, true>;
|
|
1028
|
-
|
|
1029
|
-
if (k == 1) {
|
|
1030
|
-
handler.reset(new SingleResultHC(n, 0));
|
|
1031
|
-
} else if (impl == 12) {
|
|
1032
|
-
tmp_distances.resize(n * k);
|
|
1033
|
-
handler.reset(new HeapHC(n, tmp_distances.get(), labels, k, 0));
|
|
1034
|
-
} else if (impl == 13) {
|
|
1035
|
-
handler.reset(new ReservoirHC(n, 0, k, 2 * k));
|
|
1036
|
-
}
|
|
1037
|
-
|
|
1038
|
-
int qbs2 = this->qbs2 ? this->qbs2 : 11;
|
|
1039
|
-
|
|
1040
|
-
std::vector<uint16_t> tmp_bias;
|
|
1041
|
-
if (biases.get()) {
|
|
1042
|
-
tmp_bias.resize(qbs2);
|
|
1043
|
-
handler->dbias = tmp_bias.data();
|
|
1044
|
-
}
|
|
1045
|
-
TIC;
|
|
1046
|
-
|
|
1047
|
-
size_t ndis = 0;
|
|
1048
|
-
|
|
1049
|
-
size_t i0 = 0;
|
|
1050
|
-
uint64_t t_copy_pack = 0, t_scan = 0;
|
|
1051
|
-
while (i0 < qcs.size()) {
|
|
1052
|
-
uint64_t tt0 = get_cy();
|
|
1053
|
-
|
|
1054
|
-
// find all queries that access this inverted list
|
|
1055
|
-
int list_no = qcs[i0].list_no;
|
|
1056
|
-
size_t i1 = i0 + 1;
|
|
1057
|
-
|
|
1058
|
-
while (i1 < qcs.size() && i1 < i0 + qbs2) {
|
|
1059
|
-
if (qcs[i1].list_no != list_no) {
|
|
1060
|
-
break;
|
|
1061
|
-
}
|
|
1062
|
-
i1++;
|
|
1063
|
-
}
|
|
1064
|
-
|
|
1065
|
-
size_t list_size = invlists->list_size(list_no);
|
|
1066
|
-
|
|
1067
|
-
if (list_size == 0) {
|
|
1068
|
-
i0 = i1;
|
|
1069
|
-
continue;
|
|
1070
|
-
}
|
|
1071
|
-
|
|
1072
|
-
// re-organize LUTs and biases into the right order
|
|
1073
|
-
int nc = i1 - i0;
|
|
1074
|
-
|
|
1075
|
-
std::vector<int> q_map(nc), lut_entries(nc);
|
|
1076
|
-
AlignedTable<uint8_t> LUT(nc * dim12);
|
|
1077
|
-
memset(LUT.get(), -1, nc * dim12);
|
|
1078
|
-
int qbs = pq4_preferred_qbs(nc);
|
|
1079
|
-
|
|
1080
|
-
for (size_t i = i0; i < i1; i++) {
|
|
1081
|
-
const QC& qc = qcs[i];
|
|
1082
|
-
q_map[i - i0] = qc.qno;
|
|
1083
|
-
int ij = qc.qno * nprobe + qc.rank;
|
|
1084
|
-
lut_entries[i - i0] = single_LUT ? qc.qno : ij;
|
|
1085
|
-
if (biases.get()) {
|
|
1086
|
-
tmp_bias[i - i0] = biases[ij];
|
|
1087
|
-
}
|
|
1088
|
-
}
|
|
1089
|
-
pq4_pack_LUT_qbs_q_map(
|
|
1090
|
-
qbs, M2, dis_tables.get(), lut_entries.data(), LUT.get());
|
|
1091
|
-
|
|
1092
|
-
// access the inverted list
|
|
1093
|
-
|
|
1094
|
-
ndis += (i1 - i0) * list_size;
|
|
1095
|
-
|
|
1096
|
-
InvertedLists::ScopedCodes codes(invlists, list_no);
|
|
1097
|
-
InvertedLists::ScopedIds ids(invlists, list_no);
|
|
1098
|
-
|
|
1099
|
-
// prepare the handler
|
|
1100
|
-
|
|
1101
|
-
handler->ntotal = list_size;
|
|
1102
|
-
handler->q_map = q_map.data();
|
|
1103
|
-
handler->id_map = ids.get();
|
|
1104
|
-
uint64_t tt1 = get_cy();
|
|
1105
|
-
|
|
1106
|
-
#define DISPATCH(classHC) \
|
|
1107
|
-
if (dynamic_cast<classHC*>(handler.get())) { \
|
|
1108
|
-
auto* res = static_cast<classHC*>(handler.get()); \
|
|
1109
|
-
pq4_accumulate_loop_qbs( \
|
|
1110
|
-
qbs, list_size, M2, codes.get(), LUT.get(), *res); \
|
|
1111
|
-
}
|
|
1112
|
-
DISPATCH(HeapHC)
|
|
1113
|
-
else DISPATCH(ReservoirHC) else DISPATCH(SingleResultHC)
|
|
1114
|
-
|
|
1115
|
-
// prepare for next loop
|
|
1116
|
-
i0 = i1;
|
|
1117
|
-
|
|
1118
|
-
uint64_t tt2 = get_cy();
|
|
1119
|
-
t_copy_pack += tt1 - tt0;
|
|
1120
|
-
t_scan += tt2 - tt1;
|
|
1121
|
-
}
|
|
1122
|
-
TIC;
|
|
1123
|
-
|
|
1124
|
-
// labels is in-place for HeapHC
|
|
1125
|
-
handler->to_flat_arrays(
|
|
1126
|
-
distances, labels, skip & 16 ? nullptr : normalizers.get());
|
|
1127
|
-
|
|
1128
|
-
TIC;
|
|
1129
|
-
|
|
1130
|
-
// these stats are not thread-safe
|
|
1131
|
-
|
|
1132
|
-
for (int i = 1; i < ti; i++) {
|
|
1133
|
-
IVFFastScan_stats.times[i] += times[i] - times[i - 1];
|
|
1134
|
-
}
|
|
1135
|
-
IVFFastScan_stats.t_copy_pack += t_copy_pack;
|
|
1136
|
-
IVFFastScan_stats.t_scan += t_scan;
|
|
1137
|
-
|
|
1138
|
-
if (auto* rh = dynamic_cast<ReservoirHC*>(handler.get())) {
|
|
1139
|
-
for (int i = 0; i < 4; i++) {
|
|
1140
|
-
IVFFastScan_stats.reservoir_times[i] += rh->times[i];
|
|
1141
|
-
}
|
|
1142
|
-
}
|
|
1143
|
-
|
|
1144
|
-
*ndis_out = ndis;
|
|
1145
|
-
*nlist_out = nlist;
|
|
316
|
+
void IndexIVFPQFastScan::sa_decode(idx_t n, const uint8_t* bytes, float* x)
|
|
317
|
+
const {
|
|
318
|
+
pq.decode(bytes, x, n);
|
|
1146
319
|
}
|
|
1147
320
|
|
|
1148
|
-
IVFFastScanStats IVFFastScan_stats;
|
|
1149
|
-
|
|
1150
321
|
} // namespace faiss
|