faiss 0.2.0 → 0.2.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +16 -0
- data/LICENSE.txt +1 -1
- data/README.md +7 -7
- data/ext/faiss/extconf.rb +6 -3
- data/ext/faiss/numo.hpp +4 -4
- data/ext/faiss/utils.cpp +1 -1
- data/ext/faiss/utils.h +1 -1
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/faiss/AutoTune.cpp +292 -291
- data/vendor/faiss/faiss/AutoTune.h +55 -56
- data/vendor/faiss/faiss/Clustering.cpp +365 -194
- data/vendor/faiss/faiss/Clustering.h +102 -35
- data/vendor/faiss/faiss/IVFlib.cpp +171 -195
- data/vendor/faiss/faiss/IVFlib.h +48 -51
- data/vendor/faiss/faiss/Index.cpp +85 -103
- data/vendor/faiss/faiss/Index.h +54 -48
- data/vendor/faiss/faiss/Index2Layer.cpp +126 -224
- data/vendor/faiss/faiss/Index2Layer.h +22 -36
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +407 -0
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +195 -0
- data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
- data/vendor/faiss/faiss/IndexBinary.h +140 -132
- data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
- data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
- data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
- data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
- data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
- data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
- data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
- data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
- data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
- data/vendor/faiss/faiss/IndexFlat.cpp +115 -176
- data/vendor/faiss/faiss/IndexFlat.h +42 -59
- data/vendor/faiss/faiss/IndexFlatCodes.cpp +67 -0
- data/vendor/faiss/faiss/IndexFlatCodes.h +47 -0
- data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
- data/vendor/faiss/faiss/IndexHNSW.h +57 -41
- data/vendor/faiss/faiss/IndexIVF.cpp +545 -453
- data/vendor/faiss/faiss/IndexIVF.h +169 -118
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +316 -0
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +121 -0
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +247 -252
- data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +459 -517
- data/vendor/faiss/faiss/IndexIVFPQ.h +75 -67
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
- data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
- data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
- data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
- data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +163 -150
- data/vendor/faiss/faiss/IndexIVFSpectralHash.h +38 -25
- data/vendor/faiss/faiss/IndexLSH.cpp +66 -113
- data/vendor/faiss/faiss/IndexLSH.h +20 -38
- data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
- data/vendor/faiss/faiss/IndexLattice.h +11 -16
- data/vendor/faiss/faiss/IndexNNDescent.cpp +229 -0
- data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
- data/vendor/faiss/faiss/IndexNSG.cpp +301 -0
- data/vendor/faiss/faiss/IndexNSG.h +85 -0
- data/vendor/faiss/faiss/IndexPQ.cpp +387 -495
- data/vendor/faiss/faiss/IndexPQ.h +64 -82
- data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
- data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
- data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
- data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
- data/vendor/faiss/faiss/IndexRefine.cpp +139 -127
- data/vendor/faiss/faiss/IndexRefine.h +32 -23
- data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
- data/vendor/faiss/faiss/IndexReplicas.h +62 -56
- data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +111 -172
- data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -59
- data/vendor/faiss/faiss/IndexShards.cpp +256 -240
- data/vendor/faiss/faiss/IndexShards.h +85 -73
- data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
- data/vendor/faiss/faiss/MatrixStats.h +7 -10
- data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
- data/vendor/faiss/faiss/MetaIndexes.h +40 -34
- data/vendor/faiss/faiss/MetricType.h +7 -7
- data/vendor/faiss/faiss/VectorTransform.cpp +654 -475
- data/vendor/faiss/faiss/VectorTransform.h +64 -89
- data/vendor/faiss/faiss/clone_index.cpp +78 -73
- data/vendor/faiss/faiss/clone_index.h +4 -9
- data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
- data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
- data/vendor/faiss/faiss/gpu/GpuCloner.cpp +198 -171
- data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
- data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
- data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
- data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
- data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
- data/vendor/faiss/faiss/gpu/GpuIcmEncoder.h +60 -0
- data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
- data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
- data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
- data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
- data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
- data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
- data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
- data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
- data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
- data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
- data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
- data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
- data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
- data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
- data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
- data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
- data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
- data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
- data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
- data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
- data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
- data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
- data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
- data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
- data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
- data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
- data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
- data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
- data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
- data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
- data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
- data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +503 -0
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +175 -0
- data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
- data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
- data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
- data/vendor/faiss/faiss/impl/FaissException.h +41 -29
- data/vendor/faiss/faiss/impl/HNSW.cpp +606 -617
- data/vendor/faiss/faiss/impl/HNSW.h +179 -200
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +855 -0
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +244 -0
- data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
- data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
- data/vendor/faiss/faiss/impl/NSG.cpp +679 -0
- data/vendor/faiss/faiss/impl/NSG.h +199 -0
- data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
- data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
- data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
- data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
- data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
- data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +758 -0
- data/vendor/faiss/faiss/impl/ResidualQuantizer.h +188 -0
- data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +647 -707
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
- data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
- data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
- data/vendor/faiss/faiss/impl/index_read.cpp +631 -480
- data/vendor/faiss/faiss/impl/index_write.cpp +547 -407
- data/vendor/faiss/faiss/impl/io.cpp +76 -95
- data/vendor/faiss/faiss/impl/io.h +31 -41
- data/vendor/faiss/faiss/impl/io_macros.h +60 -29
- data/vendor/faiss/faiss/impl/kmeans1d.cpp +301 -0
- data/vendor/faiss/faiss/impl/kmeans1d.h +48 -0
- data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
- data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
- data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
- data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
- data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
- data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
- data/vendor/faiss/faiss/index_factory.cpp +619 -397
- data/vendor/faiss/faiss/index_factory.h +8 -6
- data/vendor/faiss/faiss/index_io.h +23 -26
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
- data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
- data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
- data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
- data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
- data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
- data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
- data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
- data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
- data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
- data/vendor/faiss/faiss/utils/Heap.h +186 -209
- data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
- data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
- data/vendor/faiss/faiss/utils/distances.cpp +305 -312
- data/vendor/faiss/faiss/utils/distances.h +170 -122
- data/vendor/faiss/faiss/utils/distances_simd.cpp +498 -508
- data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
- data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
- data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
- data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
- data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
- data/vendor/faiss/faiss/utils/hamming.h +62 -85
- data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
- data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
- data/vendor/faiss/faiss/utils/partitioning.h +26 -21
- data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
- data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
- data/vendor/faiss/faiss/utils/random.cpp +39 -63
- data/vendor/faiss/faiss/utils/random.h +13 -16
- data/vendor/faiss/faiss/utils/simdlib.h +4 -2
- data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
- data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
- data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
- data/vendor/faiss/faiss/utils/utils.cpp +304 -287
- data/vendor/faiss/faiss/utils/utils.h +54 -49
- metadata +29 -4
|
@@ -8,70 +8,68 @@
|
|
|
8
8
|
#include <faiss/IndexIVFPQFastScan.h>
|
|
9
9
|
|
|
10
10
|
#include <cassert>
|
|
11
|
+
#include <cinttypes>
|
|
11
12
|
#include <cstdio>
|
|
12
|
-
#include <inttypes.h>
|
|
13
13
|
|
|
14
14
|
#include <omp.h>
|
|
15
15
|
|
|
16
16
|
#include <memory>
|
|
17
17
|
|
|
18
|
+
#include <faiss/impl/AuxIndexStructures.h>
|
|
18
19
|
#include <faiss/impl/FaissAssert.h>
|
|
19
|
-
#include <faiss/utils/utils.h>
|
|
20
20
|
#include <faiss/utils/distances.h>
|
|
21
21
|
#include <faiss/utils/simdlib.h>
|
|
22
|
-
#include <faiss/
|
|
22
|
+
#include <faiss/utils/utils.h>
|
|
23
23
|
|
|
24
24
|
#include <faiss/invlists/BlockInvertedLists.h>
|
|
25
25
|
|
|
26
|
+
#include <faiss/impl/pq4_fast_scan.h>
|
|
26
27
|
#include <faiss/impl/simd_result_handlers.h>
|
|
27
28
|
#include <faiss/utils/quantize_lut.h>
|
|
28
|
-
#include <faiss/impl/pq4_fast_scan.h>
|
|
29
29
|
|
|
30
30
|
namespace faiss {
|
|
31
31
|
|
|
32
32
|
using namespace simd_result_handlers;
|
|
33
33
|
|
|
34
|
-
|
|
35
34
|
inline size_t roundup(size_t a, size_t b) {
|
|
36
35
|
return (a + b - 1) / b * b;
|
|
37
36
|
}
|
|
38
37
|
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
38
|
+
IndexIVFPQFastScan::IndexIVFPQFastScan(
|
|
39
|
+
Index* quantizer,
|
|
40
|
+
size_t d,
|
|
41
|
+
size_t nlist,
|
|
42
|
+
size_t M,
|
|
43
|
+
size_t nbits_per_idx,
|
|
44
|
+
MetricType metric,
|
|
45
|
+
int bbs)
|
|
46
|
+
: IndexIVF(quantizer, d, nlist, 0, metric),
|
|
47
|
+
pq(d, M, nbits_per_idx),
|
|
48
|
+
bbs(bbs) {
|
|
48
49
|
FAISS_THROW_IF_NOT(nbits_per_idx == 4);
|
|
49
50
|
M2 = roundup(pq.M, 2);
|
|
50
51
|
by_residual = false; // set to false by default because it's much faster
|
|
51
52
|
is_trained = false;
|
|
52
53
|
code_size = pq.code_size;
|
|
53
54
|
|
|
54
|
-
replace_invlists(
|
|
55
|
-
new BlockInvertedLists(nlist, bbs, bbs * M2 / 2),
|
|
56
|
-
true
|
|
57
|
-
);
|
|
55
|
+
replace_invlists(new BlockInvertedLists(nlist, bbs, bbs * M2 / 2), true);
|
|
58
56
|
}
|
|
59
57
|
|
|
60
|
-
IndexIVFPQFastScan::IndexIVFPQFastScan
|
|
61
|
-
{
|
|
58
|
+
IndexIVFPQFastScan::IndexIVFPQFastScan() {
|
|
62
59
|
by_residual = false;
|
|
63
60
|
bbs = 0;
|
|
64
61
|
M2 = 0;
|
|
65
62
|
}
|
|
66
63
|
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
64
|
+
IndexIVFPQFastScan::IndexIVFPQFastScan(const IndexIVFPQ& orig, int bbs)
|
|
65
|
+
: IndexIVF(
|
|
66
|
+
orig.quantizer,
|
|
67
|
+
orig.d,
|
|
68
|
+
orig.nlist,
|
|
69
|
+
orig.pq.code_size,
|
|
70
|
+
orig.metric_type),
|
|
71
|
+
pq(orig.pq),
|
|
72
|
+
bbs(bbs) {
|
|
75
73
|
FAISS_THROW_IF_NOT(orig.pq.nbits == 4);
|
|
76
74
|
|
|
77
75
|
by_residual = orig.by_residual;
|
|
@@ -83,69 +81,68 @@ IndexIVFPQFastScan::IndexIVFPQFastScan(const IndexIVFPQ & orig, int bbs):
|
|
|
83
81
|
M2 = roundup(M, 2);
|
|
84
82
|
|
|
85
83
|
replace_invlists(
|
|
86
|
-
|
|
87
|
-
true
|
|
88
|
-
);
|
|
84
|
+
new BlockInvertedLists(orig.nlist, bbs, bbs * M2 / 2), true);
|
|
89
85
|
|
|
90
86
|
precomputed_table.resize(orig.precomputed_table.size());
|
|
91
87
|
|
|
92
88
|
if (precomputed_table.nbytes() > 0) {
|
|
93
|
-
memcpy(precomputed_table.get(),
|
|
94
|
-
precomputed_table.
|
|
95
|
-
|
|
89
|
+
memcpy(precomputed_table.get(),
|
|
90
|
+
orig.precomputed_table.data(),
|
|
91
|
+
precomputed_table.nbytes());
|
|
96
92
|
}
|
|
97
93
|
|
|
98
|
-
for(size_t i = 0; i < nlist; i++) {
|
|
94
|
+
for (size_t i = 0; i < nlist; i++) {
|
|
99
95
|
size_t nb = orig.invlists->list_size(i);
|
|
100
96
|
size_t nb2 = roundup(nb, bbs);
|
|
101
97
|
AlignedTable<uint8_t> tmp(nb2 * M2 / 2);
|
|
102
98
|
pq4_pack_codes(
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
99
|
+
InvertedLists::ScopedCodes(orig.invlists, i).get(),
|
|
100
|
+
nb,
|
|
101
|
+
M,
|
|
102
|
+
nb2,
|
|
103
|
+
bbs,
|
|
104
|
+
M2,
|
|
105
|
+
tmp.get());
|
|
107
106
|
invlists->add_entries(
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
107
|
+
i,
|
|
108
|
+
nb,
|
|
109
|
+
InvertedLists::ScopedIds(orig.invlists, i).get(),
|
|
110
|
+
tmp.get());
|
|
112
111
|
}
|
|
113
112
|
|
|
114
113
|
orig_invlists = orig.invlists;
|
|
115
114
|
}
|
|
116
115
|
|
|
117
|
-
|
|
118
|
-
|
|
119
116
|
/*********************************************************
|
|
120
117
|
* Training
|
|
121
118
|
*********************************************************/
|
|
122
119
|
|
|
123
|
-
void IndexIVFPQFastScan::train_residual
|
|
124
|
-
|
|
120
|
+
void IndexIVFPQFastScan::train_residual(idx_t n, const float* x_in) {
|
|
121
|
+
const float* x = fvecs_maybe_subsample(
|
|
122
|
+
d,
|
|
123
|
+
(size_t*)&n,
|
|
124
|
+
pq.cp.max_points_per_centroid * pq.ksub,
|
|
125
|
+
x_in,
|
|
126
|
+
verbose,
|
|
127
|
+
pq.cp.seed);
|
|
125
128
|
|
|
126
|
-
|
|
127
|
-
d, (size_t*)&n, pq.cp.max_points_per_centroid * pq.ksub,
|
|
128
|
-
x_in, verbose, pq.cp.seed);
|
|
129
|
-
|
|
130
|
-
std::unique_ptr<float []> del_x;
|
|
129
|
+
std::unique_ptr<float[]> del_x;
|
|
131
130
|
if (x != x_in) {
|
|
132
131
|
del_x.reset((float*)x);
|
|
133
132
|
}
|
|
134
133
|
|
|
135
|
-
const float
|
|
134
|
+
const float* trainset;
|
|
136
135
|
AlignedTable<float> residuals;
|
|
137
136
|
|
|
138
137
|
if (by_residual) {
|
|
139
|
-
if(verbose)
|
|
138
|
+
if (verbose)
|
|
139
|
+
printf("computing residuals\n");
|
|
140
140
|
std::vector<idx_t> assign(n);
|
|
141
|
-
quantizer->assign
|
|
141
|
+
quantizer->assign(n, x, assign.data());
|
|
142
142
|
residuals.resize(n * d);
|
|
143
143
|
for (idx_t i = 0; i < n; i++) {
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
residuals.data() + i * d,
|
|
147
|
-
assign[i]
|
|
148
|
-
);
|
|
144
|
+
quantizer->compute_residual(
|
|
145
|
+
x + i * d, residuals.data() + i * d, assign[i]);
|
|
149
146
|
}
|
|
150
147
|
trainset = residuals.data();
|
|
151
148
|
} else {
|
|
@@ -153,82 +150,78 @@ void IndexIVFPQFastScan::train_residual (idx_t n, const float *x_in)
|
|
|
153
150
|
}
|
|
154
151
|
|
|
155
152
|
if (verbose) {
|
|
156
|
-
printf
|
|
157
|
-
|
|
153
|
+
printf("training %zdx%zd product quantizer on "
|
|
154
|
+
"%" PRId64 " vectors in %dD\n",
|
|
155
|
+
pq.M,
|
|
156
|
+
pq.ksub,
|
|
157
|
+
n,
|
|
158
|
+
d);
|
|
158
159
|
}
|
|
159
160
|
pq.verbose = verbose;
|
|
160
|
-
pq.train
|
|
161
|
+
pq.train(n, trainset);
|
|
161
162
|
|
|
162
163
|
if (by_residual && metric_type == METRIC_L2) {
|
|
163
164
|
precompute_table();
|
|
164
165
|
}
|
|
165
|
-
|
|
166
166
|
}
|
|
167
167
|
|
|
168
|
-
void IndexIVFPQFastScan::precompute_table
|
|
169
|
-
{
|
|
168
|
+
void IndexIVFPQFastScan::precompute_table() {
|
|
170
169
|
initialize_IVFPQ_precomputed_table(
|
|
171
|
-
|
|
172
|
-
quantizer, pq, precomputed_table, verbose
|
|
173
|
-
);
|
|
170
|
+
use_precomputed_table, quantizer, pq, precomputed_table, verbose);
|
|
174
171
|
}
|
|
175
172
|
|
|
176
|
-
|
|
177
173
|
/*********************************************************
|
|
178
174
|
* Code management functions
|
|
179
175
|
*********************************************************/
|
|
180
176
|
|
|
181
|
-
|
|
182
|
-
|
|
183
177
|
void IndexIVFPQFastScan::encode_vectors(
|
|
184
|
-
idx_t n,
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
178
|
+
idx_t n,
|
|
179
|
+
const float* x,
|
|
180
|
+
const idx_t* list_nos,
|
|
181
|
+
uint8_t* codes,
|
|
182
|
+
bool include_listnos) const {
|
|
188
183
|
if (by_residual) {
|
|
189
|
-
AlignedTable<float> residuals
|
|
184
|
+
AlignedTable<float> residuals(n * d);
|
|
190
185
|
for (size_t i = 0; i < n; i++) {
|
|
191
186
|
if (list_nos[i] < 0) {
|
|
192
|
-
memset
|
|
187
|
+
memset(residuals.data() + i * d, 0, sizeof(residuals[0]) * d);
|
|
193
188
|
} else {
|
|
194
|
-
quantizer->compute_residual
|
|
195
|
-
|
|
189
|
+
quantizer->compute_residual(
|
|
190
|
+
x + i * d, residuals.data() + i * d, list_nos[i]);
|
|
196
191
|
}
|
|
197
192
|
}
|
|
198
|
-
pq.compute_codes
|
|
193
|
+
pq.compute_codes(residuals.data(), codes, n);
|
|
199
194
|
} else {
|
|
200
|
-
pq.compute_codes
|
|
195
|
+
pq.compute_codes(x, codes, n);
|
|
201
196
|
}
|
|
202
197
|
|
|
203
198
|
if (include_listnos) {
|
|
204
199
|
size_t coarse_size = coarse_code_size();
|
|
205
200
|
for (idx_t i = n - 1; i >= 0; i--) {
|
|
206
|
-
uint8_t
|
|
207
|
-
memmove
|
|
208
|
-
|
|
209
|
-
encode_listno (list_nos[i], code);
|
|
201
|
+
uint8_t* code = codes + i * (coarse_size + code_size);
|
|
202
|
+
memmove(code + coarse_size, codes + i * code_size, code_size);
|
|
203
|
+
encode_listno(list_nos[i], code);
|
|
210
204
|
}
|
|
211
205
|
}
|
|
212
206
|
}
|
|
213
207
|
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
208
|
+
void IndexIVFPQFastScan::add_with_ids(
|
|
209
|
+
idx_t n,
|
|
210
|
+
const float* x,
|
|
211
|
+
const idx_t* xids) {
|
|
219
212
|
// copied from IndexIVF::add_with_ids --->
|
|
220
213
|
|
|
221
214
|
// do some blocking to avoid excessive allocs
|
|
222
215
|
idx_t bs = 65536;
|
|
223
216
|
if (n > bs) {
|
|
224
217
|
for (idx_t i0 = 0; i0 < n; i0 += bs) {
|
|
225
|
-
idx_t i1 = std::min
|
|
218
|
+
idx_t i1 = std::min(n, i0 + bs);
|
|
226
219
|
if (verbose) {
|
|
227
220
|
printf(" IndexIVFPQFastScan::add_with_ids %zd: %zd",
|
|
228
|
-
size_t(i0),
|
|
221
|
+
size_t(i0),
|
|
222
|
+
size_t(i1));
|
|
229
223
|
}
|
|
230
|
-
add_with_ids
|
|
231
|
-
xids ? xids + i0 : nullptr);
|
|
224
|
+
add_with_ids(i1 - i0, x + i0 * d, xids ? xids + i0 : nullptr);
|
|
232
225
|
}
|
|
233
226
|
return;
|
|
234
227
|
}
|
|
@@ -236,37 +229,38 @@ void IndexIVFPQFastScan::add_with_ids (
|
|
|
236
229
|
|
|
237
230
|
AlignedTable<uint8_t> codes(n * code_size);
|
|
238
231
|
|
|
239
|
-
FAISS_THROW_IF_NOT
|
|
240
|
-
direct_map.check_can_add
|
|
232
|
+
FAISS_THROW_IF_NOT(is_trained);
|
|
233
|
+
direct_map.check_can_add(xids);
|
|
241
234
|
|
|
242
|
-
std::unique_ptr<idx_t
|
|
243
|
-
quantizer->assign
|
|
235
|
+
std::unique_ptr<idx_t[]> idx(new idx_t[n]);
|
|
236
|
+
quantizer->assign(n, x, idx.get());
|
|
244
237
|
size_t nadd = 0, nminus1 = 0;
|
|
245
238
|
|
|
246
239
|
for (size_t i = 0; i < n; i++) {
|
|
247
|
-
if (idx[i] < 0)
|
|
240
|
+
if (idx[i] < 0)
|
|
241
|
+
nminus1++;
|
|
248
242
|
}
|
|
249
243
|
|
|
250
244
|
AlignedTable<uint8_t> flat_codes(n * code_size);
|
|
251
|
-
encode_vectors
|
|
245
|
+
encode_vectors(n, x, idx.get(), flat_codes.get());
|
|
252
246
|
|
|
253
247
|
DirectMapAdd dm_adder(direct_map, n, xids);
|
|
254
248
|
|
|
255
249
|
// <---
|
|
256
250
|
|
|
257
|
-
BlockInvertedLists
|
|
258
|
-
FAISS_THROW_IF_NOT_MSG
|
|
251
|
+
BlockInvertedLists* bil = dynamic_cast<BlockInvertedLists*>(invlists);
|
|
252
|
+
FAISS_THROW_IF_NOT_MSG(bil, "only block inverted lists supported");
|
|
259
253
|
|
|
260
254
|
// prepare batches
|
|
261
255
|
std::vector<idx_t> order(n);
|
|
262
|
-
for(idx_t i = 0; i < n
|
|
256
|
+
for (idx_t i = 0; i < n; i++) {
|
|
257
|
+
order[i] = i;
|
|
258
|
+
}
|
|
263
259
|
|
|
264
260
|
// TODO should not need stable
|
|
265
|
-
std::stable_sort(order.begin(), order.end(),
|
|
266
|
-
[
|
|
267
|
-
|
|
268
|
-
}
|
|
269
|
-
);
|
|
261
|
+
std::stable_sort(order.begin(), order.end(), [&idx](idx_t a, idx_t b) {
|
|
262
|
+
return idx[a] < idx[b];
|
|
263
|
+
});
|
|
270
264
|
|
|
271
265
|
// TODO parallelize
|
|
272
266
|
idx_t i0 = 0;
|
|
@@ -274,7 +268,7 @@ void IndexIVFPQFastScan::add_with_ids (
|
|
|
274
268
|
idx_t list_no = idx[order[i0]];
|
|
275
269
|
idx_t i1 = i0 + 1;
|
|
276
270
|
while (i1 < n && idx[order[i1]] == list_no) {
|
|
277
|
-
i1
|
|
271
|
+
i1++;
|
|
278
272
|
}
|
|
279
273
|
|
|
280
274
|
if (list_no == -1) {
|
|
@@ -288,58 +282,57 @@ void IndexIVFPQFastScan::add_with_ids (
|
|
|
288
282
|
|
|
289
283
|
bil->resize(list_no, list_size + i1 - i0);
|
|
290
284
|
|
|
291
|
-
for(idx_t i = i0; i < i1; i++) {
|
|
285
|
+
for (idx_t i = i0; i < i1; i++) {
|
|
292
286
|
size_t ofs = list_size + i - i0;
|
|
293
287
|
idx_t id = xids ? xids[order[i]] : ntotal + order[i];
|
|
294
|
-
dm_adder.add
|
|
288
|
+
dm_adder.add(order[i], list_no, ofs);
|
|
295
289
|
bil->ids[list_no][ofs] = id;
|
|
296
|
-
memcpy(
|
|
297
|
-
|
|
298
|
-
|
|
299
|
-
code_size
|
|
300
|
-
);
|
|
290
|
+
memcpy(list_codes.data() + (i - i0) * code_size,
|
|
291
|
+
flat_codes.data() + order[i] * code_size,
|
|
292
|
+
code_size);
|
|
301
293
|
nadd++;
|
|
302
294
|
}
|
|
303
295
|
pq4_pack_codes_range(
|
|
304
|
-
|
|
305
|
-
|
|
306
|
-
|
|
307
|
-
|
|
296
|
+
list_codes.data(),
|
|
297
|
+
pq.M,
|
|
298
|
+
list_size,
|
|
299
|
+
list_size + i1 - i0,
|
|
300
|
+
bbs,
|
|
301
|
+
M2,
|
|
302
|
+
bil->codes[list_no].data());
|
|
308
303
|
|
|
309
304
|
i0 = i1;
|
|
310
305
|
}
|
|
311
306
|
|
|
312
307
|
ntotal += n;
|
|
313
|
-
|
|
314
308
|
}
|
|
315
309
|
|
|
316
|
-
|
|
317
|
-
|
|
318
310
|
/*********************************************************
|
|
319
311
|
* search
|
|
320
312
|
*********************************************************/
|
|
321
313
|
|
|
322
|
-
|
|
323
314
|
namespace {
|
|
324
315
|
|
|
325
316
|
// from impl/ProductQuantizer.cpp
|
|
326
317
|
template <class C, typename dis_t>
|
|
327
318
|
void pq_estimators_from_tables_generic(
|
|
328
|
-
const ProductQuantizer& pq,
|
|
329
|
-
|
|
330
|
-
const
|
|
319
|
+
const ProductQuantizer& pq,
|
|
320
|
+
size_t nbits,
|
|
321
|
+
const uint8_t* codes,
|
|
322
|
+
size_t ncodes,
|
|
323
|
+
const dis_t* dis_table,
|
|
324
|
+
const int64_t* ids,
|
|
331
325
|
float dis0,
|
|
332
|
-
size_t k,
|
|
333
|
-
|
|
326
|
+
size_t k,
|
|
327
|
+
typename C::T* heap_dis,
|
|
328
|
+
int64_t* heap_ids) {
|
|
334
329
|
using accu_t = typename C::T;
|
|
335
330
|
const size_t M = pq.M;
|
|
336
331
|
const size_t ksub = pq.ksub;
|
|
337
332
|
for (size_t j = 0; j < ncodes; ++j) {
|
|
338
|
-
PQDecoderGeneric decoder(
|
|
339
|
-
codes + j * pq.code_size, nbits
|
|
340
|
-
);
|
|
333
|
+
PQDecoderGeneric decoder(codes + j * pq.code_size, nbits);
|
|
341
334
|
accu_t dis = dis0;
|
|
342
|
-
const dis_t
|
|
335
|
+
const dis_t* dt = dis_table;
|
|
343
336
|
for (size_t m = 0; m < M; m++) {
|
|
344
337
|
uint64_t c = decoder.decode();
|
|
345
338
|
dis += dt[c];
|
|
@@ -356,17 +349,19 @@ void pq_estimators_from_tables_generic(
|
|
|
356
349
|
using idx_t = Index::idx_t;
|
|
357
350
|
using namespace quantize_lut;
|
|
358
351
|
|
|
359
|
-
void fvec_madd_avx
|
|
360
|
-
size_t n,
|
|
361
|
-
|
|
362
|
-
|
|
352
|
+
void fvec_madd_avx(
|
|
353
|
+
size_t n,
|
|
354
|
+
const float* a,
|
|
355
|
+
float bf,
|
|
356
|
+
const float* b,
|
|
357
|
+
float* c) {
|
|
363
358
|
assert(is_aligned_pointer(a));
|
|
364
359
|
assert(is_aligned_pointer(b));
|
|
365
360
|
assert(is_aligned_pointer(c));
|
|
366
361
|
assert(n % 8 == 0);
|
|
367
362
|
simd8float32 bf8(bf);
|
|
368
363
|
n /= 8;
|
|
369
|
-
for(size_t i = 0; i < n; i++) {
|
|
364
|
+
for (size_t i = 0; i < n; i++) {
|
|
370
365
|
simd8float32 ai(a);
|
|
371
366
|
simd8float32 bi(b);
|
|
372
367
|
|
|
@@ -376,7 +371,6 @@ void fvec_madd_avx (
|
|
|
376
371
|
a += 8;
|
|
377
372
|
b += 8;
|
|
378
373
|
}
|
|
379
|
-
|
|
380
374
|
}
|
|
381
375
|
|
|
382
376
|
} // anonymous namespace
|
|
@@ -385,23 +379,20 @@ void fvec_madd_avx (
|
|
|
385
379
|
* Look-Up Table functions
|
|
386
380
|
*********************************************************/
|
|
387
381
|
|
|
388
|
-
|
|
389
382
|
void IndexIVFPQFastScan::compute_LUT(
|
|
390
|
-
|
|
391
|
-
|
|
392
|
-
|
|
393
|
-
|
|
394
|
-
|
|
395
|
-
{
|
|
396
|
-
const IndexIVFPQFastScan
|
|
383
|
+
size_t n,
|
|
384
|
+
const float* x,
|
|
385
|
+
const idx_t* coarse_ids,
|
|
386
|
+
const float* coarse_dis,
|
|
387
|
+
AlignedTable<float>& dis_tables,
|
|
388
|
+
AlignedTable<float>& biases) const {
|
|
389
|
+
const IndexIVFPQFastScan& ivfpq = *this;
|
|
397
390
|
size_t dim12 = pq.ksub * pq.M;
|
|
398
391
|
size_t d = pq.d;
|
|
399
392
|
size_t nprobe = ivfpq.nprobe;
|
|
400
393
|
|
|
401
394
|
if (ivfpq.by_residual) {
|
|
402
|
-
|
|
403
395
|
if (ivfpq.metric_type == METRIC_L2) {
|
|
404
|
-
|
|
405
396
|
dis_tables.resize(n * nprobe * dim12);
|
|
406
397
|
|
|
407
398
|
if (ivfpq.use_precomputed_table == 1) {
|
|
@@ -409,57 +400,54 @@ void IndexIVFPQFastScan::compute_LUT(
|
|
|
409
400
|
memcpy(biases.get(), coarse_dis, sizeof(float) * n * nprobe);
|
|
410
401
|
|
|
411
402
|
AlignedTable<float> ip_table(n * dim12);
|
|
412
|
-
pq.compute_inner_prod_tables
|
|
403
|
+
pq.compute_inner_prod_tables(n, x, ip_table.get());
|
|
413
404
|
|
|
414
405
|
#pragma omp parallel for if (n * nprobe > 8000)
|
|
415
|
-
for(idx_t ij = 0; ij < n * nprobe; ij++) {
|
|
406
|
+
for (idx_t ij = 0; ij < n * nprobe; ij++) {
|
|
416
407
|
idx_t i = ij / nprobe;
|
|
417
|
-
float
|
|
408
|
+
float* tab = dis_tables.get() + ij * dim12;
|
|
418
409
|
idx_t cij = coarse_ids[ij];
|
|
419
410
|
|
|
420
411
|
if (cij >= 0) {
|
|
421
|
-
fvec_madd_avx
|
|
422
|
-
|
|
423
|
-
|
|
424
|
-
|
|
425
|
-
|
|
426
|
-
|
|
412
|
+
fvec_madd_avx(
|
|
413
|
+
dim12,
|
|
414
|
+
precomputed_table.get() + cij * dim12,
|
|
415
|
+
-2,
|
|
416
|
+
ip_table.get() + i * dim12,
|
|
417
|
+
tab);
|
|
427
418
|
} else {
|
|
428
419
|
// fill with NaNs so that they are ignored during
|
|
429
420
|
// LUT quantization
|
|
430
|
-
memset
|
|
421
|
+
memset(tab, -1, sizeof(float) * dim12);
|
|
431
422
|
}
|
|
432
423
|
}
|
|
433
424
|
|
|
434
425
|
} else {
|
|
435
|
-
|
|
436
426
|
std::unique_ptr<float[]> xrel(new float[n * nprobe * d]);
|
|
437
427
|
biases.resize(n * nprobe);
|
|
438
428
|
memset(biases.get(), 0, sizeof(float) * n * nprobe);
|
|
439
429
|
|
|
440
430
|
#pragma omp parallel for if (n * nprobe > 8000)
|
|
441
|
-
for(idx_t ij = 0; ij < n * nprobe; ij++) {
|
|
431
|
+
for (idx_t ij = 0; ij < n * nprobe; ij++) {
|
|
442
432
|
idx_t i = ij / nprobe;
|
|
443
|
-
float
|
|
433
|
+
float* xij = &xrel[ij * d];
|
|
444
434
|
idx_t cij = coarse_ids[ij];
|
|
445
435
|
|
|
446
436
|
if (cij >= 0) {
|
|
447
|
-
ivfpq.quantizer->compute_residual(
|
|
448
|
-
x + i * d, xij, cij);
|
|
437
|
+
ivfpq.quantizer->compute_residual(x + i * d, xij, cij);
|
|
449
438
|
} else {
|
|
450
439
|
// will fill with NaNs
|
|
451
440
|
memset(xij, -1, sizeof(float) * d);
|
|
452
441
|
}
|
|
453
442
|
}
|
|
454
443
|
|
|
455
|
-
pq.compute_distance_tables
|
|
444
|
+
pq.compute_distance_tables(
|
|
456
445
|
n * nprobe, xrel.get(), dis_tables.get());
|
|
457
|
-
|
|
458
446
|
}
|
|
459
447
|
|
|
460
448
|
} else if (ivfpq.metric_type == METRIC_INNER_PRODUCT) {
|
|
461
449
|
dis_tables.resize(n * dim12);
|
|
462
|
-
pq.compute_inner_prod_tables
|
|
450
|
+
pq.compute_inner_prod_tables(n, x, dis_tables.get());
|
|
463
451
|
// compute_inner_prod_tables(pq, n, x, dis_tables.get());
|
|
464
452
|
|
|
465
453
|
biases.resize(n * nprobe);
|
|
@@ -471,33 +459,29 @@ void IndexIVFPQFastScan::compute_LUT(
|
|
|
471
459
|
} else {
|
|
472
460
|
dis_tables.resize(n * dim12);
|
|
473
461
|
if (ivfpq.metric_type == METRIC_L2) {
|
|
474
|
-
pq.compute_distance_tables
|
|
462
|
+
pq.compute_distance_tables(n, x, dis_tables.get());
|
|
475
463
|
} else if (ivfpq.metric_type == METRIC_INNER_PRODUCT) {
|
|
476
|
-
pq.compute_inner_prod_tables
|
|
464
|
+
pq.compute_inner_prod_tables(n, x, dis_tables.get());
|
|
477
465
|
} else {
|
|
478
466
|
FAISS_THROW_FMT("metric %d not supported", ivfpq.metric_type);
|
|
479
467
|
}
|
|
480
468
|
}
|
|
481
|
-
|
|
482
469
|
}
|
|
483
470
|
|
|
484
471
|
void IndexIVFPQFastScan::compute_LUT_uint8(
|
|
485
|
-
|
|
486
|
-
|
|
487
|
-
|
|
488
|
-
|
|
489
|
-
|
|
490
|
-
|
|
491
|
-
|
|
472
|
+
size_t n,
|
|
473
|
+
const float* x,
|
|
474
|
+
const idx_t* coarse_ids,
|
|
475
|
+
const float* coarse_dis,
|
|
476
|
+
AlignedTable<uint8_t>& dis_tables,
|
|
477
|
+
AlignedTable<uint16_t>& biases,
|
|
478
|
+
float* normalizers) const {
|
|
479
|
+
const IndexIVFPQFastScan& ivfpq = *this;
|
|
492
480
|
AlignedTable<float> dis_tables_float;
|
|
493
481
|
AlignedTable<float> biases_float;
|
|
494
482
|
|
|
495
483
|
uint64_t t0 = get_cy();
|
|
496
|
-
compute_LUT(
|
|
497
|
-
n, x,
|
|
498
|
-
coarse_ids, coarse_dis,
|
|
499
|
-
dis_tables_float, biases_float
|
|
500
|
-
);
|
|
484
|
+
compute_LUT(n, x, coarse_ids, coarse_dis, dis_tables_float, biases_float);
|
|
501
485
|
IVFFastScan_stats.t_compute_distance_tables += get_cy() - t0;
|
|
502
486
|
|
|
503
487
|
bool lut_is_3d = ivfpq.by_residual && ivfpq.metric_type == METRIC_L2;
|
|
@@ -514,45 +498,52 @@ void IndexIVFPQFastScan::compute_LUT_uint8(
|
|
|
514
498
|
uint64_t t1 = get_cy();
|
|
515
499
|
|
|
516
500
|
#pragma omp parallel for if (n > 100)
|
|
517
|
-
for(int64_t i = 0; i < n; i++) {
|
|
518
|
-
const float
|
|
519
|
-
const float
|
|
520
|
-
uint8_t
|
|
521
|
-
uint16_t
|
|
501
|
+
for (int64_t i = 0; i < n; i++) {
|
|
502
|
+
const float* t_in = dis_tables_float.get() + i * dim123;
|
|
503
|
+
const float* b_in = nullptr;
|
|
504
|
+
uint8_t* t_out = dis_tables.get() + i * dim123_2;
|
|
505
|
+
uint16_t* b_out = nullptr;
|
|
522
506
|
if (biases_float.get()) {
|
|
523
507
|
b_in = biases_float.get() + i * nprobe;
|
|
524
508
|
b_out = biases.get() + i * nprobe;
|
|
525
509
|
}
|
|
526
510
|
|
|
527
511
|
quantize_LUT_and_bias(
|
|
528
|
-
|
|
529
|
-
|
|
530
|
-
|
|
531
|
-
|
|
532
|
-
|
|
512
|
+
nprobe,
|
|
513
|
+
pq.M,
|
|
514
|
+
pq.ksub,
|
|
515
|
+
lut_is_3d,
|
|
516
|
+
t_in,
|
|
517
|
+
b_in,
|
|
518
|
+
t_out,
|
|
519
|
+
M2,
|
|
520
|
+
b_out,
|
|
521
|
+
normalizers + 2 * i,
|
|
522
|
+
normalizers + 2 * i + 1);
|
|
533
523
|
}
|
|
534
524
|
IVFFastScan_stats.t_round += get_cy() - t1;
|
|
535
|
-
|
|
536
525
|
}
|
|
537
526
|
|
|
538
|
-
|
|
539
527
|
/*********************************************************
|
|
540
528
|
* Search functions
|
|
541
529
|
*********************************************************/
|
|
542
530
|
|
|
543
|
-
template<bool is_max>
|
|
531
|
+
template <bool is_max>
|
|
544
532
|
void IndexIVFPQFastScan::search_dispatch_implem(
|
|
545
|
-
|
|
546
|
-
|
|
547
|
-
|
|
548
|
-
|
|
549
|
-
|
|
550
|
-
|
|
551
|
-
|
|
552
|
-
|
|
553
|
-
|
|
554
|
-
|
|
555
|
-
|
|
533
|
+
idx_t n,
|
|
534
|
+
const float* x,
|
|
535
|
+
idx_t k,
|
|
536
|
+
float* distances,
|
|
537
|
+
idx_t* labels) const {
|
|
538
|
+
using Cfloat = typename std::conditional<
|
|
539
|
+
is_max,
|
|
540
|
+
CMax<float, int64_t>,
|
|
541
|
+
CMin<float, int64_t>>::type;
|
|
542
|
+
|
|
543
|
+
using C = typename std::conditional<
|
|
544
|
+
is_max,
|
|
545
|
+
CMax<uint16_t, int64_t>,
|
|
546
|
+
CMin<uint16_t, int64_t>>::type;
|
|
556
547
|
|
|
557
548
|
if (n == 0) {
|
|
558
549
|
return;
|
|
@@ -568,7 +559,7 @@ void IndexIVFPQFastScan::search_dispatch_implem(
|
|
|
568
559
|
impl = 10;
|
|
569
560
|
}
|
|
570
561
|
if (k > 20) {
|
|
571
|
-
impl
|
|
562
|
+
impl++;
|
|
572
563
|
}
|
|
573
564
|
}
|
|
574
565
|
|
|
@@ -582,11 +573,25 @@ void IndexIVFPQFastScan::search_dispatch_implem(
|
|
|
582
573
|
|
|
583
574
|
if (n < 2) {
|
|
584
575
|
if (impl == 12 || impl == 13) {
|
|
585
|
-
search_implem_12<C>
|
|
586
|
-
|
|
576
|
+
search_implem_12<C>(
|
|
577
|
+
n,
|
|
578
|
+
x,
|
|
579
|
+
k,
|
|
580
|
+
distances,
|
|
581
|
+
labels,
|
|
582
|
+
impl,
|
|
583
|
+
&ndis,
|
|
584
|
+
&nlist_visited);
|
|
587
585
|
} else {
|
|
588
|
-
search_implem_10<C>
|
|
589
|
-
|
|
586
|
+
search_implem_10<C>(
|
|
587
|
+
n,
|
|
588
|
+
x,
|
|
589
|
+
k,
|
|
590
|
+
distances,
|
|
591
|
+
labels,
|
|
592
|
+
impl,
|
|
593
|
+
&ndis,
|
|
594
|
+
&nlist_visited);
|
|
590
595
|
}
|
|
591
596
|
} else {
|
|
592
597
|
// explicitly slice over threads
|
|
@@ -595,34 +600,47 @@ void IndexIVFPQFastScan::search_dispatch_implem(
|
|
|
595
600
|
nslice = n;
|
|
596
601
|
} else if (by_residual && metric_type == METRIC_L2) {
|
|
597
602
|
// make sure we don't make too big LUT tables
|
|
598
|
-
size_t lut_size_per_query =
|
|
599
|
-
|
|
603
|
+
size_t lut_size_per_query = pq.M * pq.ksub * nprobe *
|
|
604
|
+
(sizeof(float) + sizeof(uint8_t));
|
|
600
605
|
|
|
601
606
|
size_t max_lut_size = precomputed_table_max_bytes;
|
|
602
607
|
// how many queries we can handle within mem budget
|
|
603
|
-
size_t nq_ok =
|
|
604
|
-
|
|
608
|
+
size_t nq_ok =
|
|
609
|
+
std::max(max_lut_size / lut_size_per_query, size_t(1));
|
|
610
|
+
nslice =
|
|
611
|
+
roundup(std::max(size_t(n / nq_ok), size_t(1)),
|
|
612
|
+
omp_get_max_threads());
|
|
605
613
|
} else {
|
|
606
614
|
// LUTs unlikely to be a limiting factor
|
|
607
615
|
nslice = omp_get_max_threads();
|
|
608
616
|
}
|
|
609
617
|
|
|
610
|
-
#pragma omp parallel for reduction(
|
|
618
|
+
#pragma omp parallel for reduction(+ : ndis, nlist_visited)
|
|
611
619
|
for (int slice = 0; slice < nslice; slice++) {
|
|
612
620
|
idx_t i0 = n * slice / nslice;
|
|
613
621
|
idx_t i1 = n * (slice + 1) / nslice;
|
|
614
|
-
float
|
|
615
|
-
idx_t
|
|
622
|
+
float* dis_i = distances + i0 * k;
|
|
623
|
+
idx_t* lab_i = labels + i0 * k;
|
|
616
624
|
if (impl == 12 || impl == 13) {
|
|
617
625
|
search_implem_12<C>(
|
|
618
|
-
|
|
619
|
-
|
|
620
|
-
|
|
626
|
+
i1 - i0,
|
|
627
|
+
x + i0 * d,
|
|
628
|
+
k,
|
|
629
|
+
dis_i,
|
|
630
|
+
lab_i,
|
|
631
|
+
impl,
|
|
632
|
+
&ndis,
|
|
633
|
+
&nlist_visited);
|
|
621
634
|
} else {
|
|
622
635
|
search_implem_10<C>(
|
|
623
|
-
|
|
624
|
-
|
|
625
|
-
|
|
636
|
+
i1 - i0,
|
|
637
|
+
x + i0 * d,
|
|
638
|
+
k,
|
|
639
|
+
dis_i,
|
|
640
|
+
lab_i,
|
|
641
|
+
impl,
|
|
642
|
+
&ndis,
|
|
643
|
+
&nlist_visited);
|
|
626
644
|
}
|
|
627
645
|
}
|
|
628
646
|
}
|
|
@@ -632,14 +650,16 @@ void IndexIVFPQFastScan::search_dispatch_implem(
|
|
|
632
650
|
} else {
|
|
633
651
|
FAISS_THROW_FMT("implem %d does not exist", implem);
|
|
634
652
|
}
|
|
635
|
-
|
|
636
653
|
}
|
|
637
654
|
|
|
638
|
-
|
|
639
655
|
void IndexIVFPQFastScan::search(
|
|
640
|
-
|
|
641
|
-
|
|
642
|
-
|
|
656
|
+
idx_t n,
|
|
657
|
+
const float* x,
|
|
658
|
+
idx_t k,
|
|
659
|
+
float* distances,
|
|
660
|
+
idx_t* labels) const {
|
|
661
|
+
FAISS_THROW_IF_NOT(k > 0);
|
|
662
|
+
|
|
643
663
|
if (metric_type == METRIC_L2) {
|
|
644
664
|
search_dispatch_implem<true>(n, x, k, distances, labels);
|
|
645
665
|
} else {
|
|
@@ -647,133 +667,150 @@ void IndexIVFPQFastScan::search(
|
|
|
647
667
|
}
|
|
648
668
|
}
|
|
649
669
|
|
|
650
|
-
template<class C>
|
|
670
|
+
template <class C>
|
|
651
671
|
void IndexIVFPQFastScan::search_implem_1(
|
|
652
|
-
|
|
653
|
-
|
|
654
|
-
|
|
672
|
+
idx_t n,
|
|
673
|
+
const float* x,
|
|
674
|
+
idx_t k,
|
|
675
|
+
float* distances,
|
|
676
|
+
idx_t* labels) const {
|
|
655
677
|
FAISS_THROW_IF_NOT(orig_invlists);
|
|
656
678
|
|
|
657
679
|
std::unique_ptr<idx_t[]> coarse_ids(new idx_t[n * nprobe]);
|
|
658
680
|
std::unique_ptr<float[]> coarse_dis(new float[n * nprobe]);
|
|
659
681
|
|
|
660
|
-
quantizer->search
|
|
682
|
+
quantizer->search(n, x, nprobe, coarse_dis.get(), coarse_ids.get());
|
|
661
683
|
|
|
662
684
|
size_t dim12 = pq.ksub * pq.M;
|
|
663
685
|
AlignedTable<float> dis_tables;
|
|
664
686
|
AlignedTable<float> biases;
|
|
665
687
|
|
|
666
|
-
compute_LUT (
|
|
667
|
-
n, x,
|
|
668
|
-
coarse_ids.get(), coarse_dis.get(),
|
|
669
|
-
dis_tables, biases
|
|
670
|
-
);
|
|
688
|
+
compute_LUT(n, x, coarse_ids.get(), coarse_dis.get(), dis_tables, biases);
|
|
671
689
|
|
|
672
690
|
bool single_LUT = !(by_residual && metric_type == METRIC_L2);
|
|
673
691
|
|
|
674
692
|
size_t ndis = 0, nlist_visited = 0;
|
|
675
693
|
|
|
676
|
-
#pragma omp parallel for reduction(
|
|
677
|
-
for(idx_t i = 0; i < n; i++) {
|
|
678
|
-
int64_t
|
|
679
|
-
float
|
|
680
|
-
heap_heapify<C>
|
|
681
|
-
float
|
|
694
|
+
#pragma omp parallel for reduction(+ : ndis, nlist_visited)
|
|
695
|
+
for (idx_t i = 0; i < n; i++) {
|
|
696
|
+
int64_t* heap_ids = labels + i * k;
|
|
697
|
+
float* heap_dis = distances + i * k;
|
|
698
|
+
heap_heapify<C>(k, heap_dis, heap_ids);
|
|
699
|
+
float* LUT = nullptr;
|
|
682
700
|
|
|
683
701
|
if (single_LUT) {
|
|
684
702
|
LUT = dis_tables.get() + i * dim12;
|
|
685
703
|
}
|
|
686
|
-
for(idx_t j = 0; j < nprobe; j++) {
|
|
704
|
+
for (idx_t j = 0; j < nprobe; j++) {
|
|
687
705
|
if (!single_LUT) {
|
|
688
706
|
LUT = dis_tables.get() + (i * nprobe + j) * dim12;
|
|
689
707
|
}
|
|
690
708
|
idx_t list_no = coarse_ids[i * nprobe + j];
|
|
691
|
-
if (list_no < 0)
|
|
709
|
+
if (list_no < 0)
|
|
710
|
+
continue;
|
|
692
711
|
size_t ls = orig_invlists->list_size(list_no);
|
|
693
|
-
if (ls == 0)
|
|
712
|
+
if (ls == 0)
|
|
713
|
+
continue;
|
|
694
714
|
InvertedLists::ScopedCodes codes(orig_invlists, list_no);
|
|
695
715
|
InvertedLists::ScopedIds ids(orig_invlists, list_no);
|
|
696
716
|
|
|
697
717
|
float bias = biases.get() ? biases[i * nprobe + j] : 0;
|
|
698
718
|
|
|
699
719
|
pq_estimators_from_tables_generic<C>(
|
|
700
|
-
|
|
701
|
-
|
|
702
|
-
|
|
703
|
-
|
|
704
|
-
|
|
705
|
-
|
|
720
|
+
pq,
|
|
721
|
+
pq.nbits,
|
|
722
|
+
codes.get(),
|
|
723
|
+
ls,
|
|
724
|
+
LUT,
|
|
725
|
+
ids.get(),
|
|
726
|
+
bias,
|
|
727
|
+
k,
|
|
728
|
+
heap_dis,
|
|
729
|
+
heap_ids);
|
|
730
|
+
nlist_visited++;
|
|
731
|
+
ndis++;
|
|
706
732
|
}
|
|
707
|
-
heap_reorder<C>
|
|
733
|
+
heap_reorder<C>(k, heap_dis, heap_ids);
|
|
708
734
|
}
|
|
709
735
|
indexIVF_stats.nq += n;
|
|
710
736
|
indexIVF_stats.ndis += ndis;
|
|
711
737
|
indexIVF_stats.nlist += nlist_visited;
|
|
712
738
|
}
|
|
713
739
|
|
|
714
|
-
template<class C>
|
|
740
|
+
template <class C>
|
|
715
741
|
void IndexIVFPQFastScan::search_implem_2(
|
|
716
|
-
|
|
717
|
-
|
|
718
|
-
|
|
742
|
+
idx_t n,
|
|
743
|
+
const float* x,
|
|
744
|
+
idx_t k,
|
|
745
|
+
float* distances,
|
|
746
|
+
idx_t* labels) const {
|
|
719
747
|
FAISS_THROW_IF_NOT(orig_invlists);
|
|
720
748
|
|
|
721
749
|
std::unique_ptr<idx_t[]> coarse_ids(new idx_t[n * nprobe]);
|
|
722
750
|
std::unique_ptr<float[]> coarse_dis(new float[n * nprobe]);
|
|
723
751
|
|
|
724
|
-
quantizer->search
|
|
752
|
+
quantizer->search(n, x, nprobe, coarse_dis.get(), coarse_ids.get());
|
|
725
753
|
|
|
726
754
|
size_t dim12 = pq.ksub * M2;
|
|
727
755
|
AlignedTable<uint8_t> dis_tables;
|
|
728
756
|
AlignedTable<uint16_t> biases;
|
|
729
757
|
std::unique_ptr<float[]> normalizers(new float[2 * n]);
|
|
730
758
|
|
|
731
|
-
compute_LUT_uint8
|
|
732
|
-
|
|
733
|
-
|
|
734
|
-
|
|
735
|
-
|
|
736
|
-
|
|
737
|
-
|
|
759
|
+
compute_LUT_uint8(
|
|
760
|
+
n,
|
|
761
|
+
x,
|
|
762
|
+
coarse_ids.get(),
|
|
763
|
+
coarse_dis.get(),
|
|
764
|
+
dis_tables,
|
|
765
|
+
biases,
|
|
766
|
+
normalizers.get());
|
|
738
767
|
|
|
739
768
|
bool single_LUT = !(by_residual && metric_type == METRIC_L2);
|
|
740
769
|
|
|
741
770
|
size_t ndis = 0, nlist_visited = 0;
|
|
742
771
|
|
|
743
|
-
#pragma omp parallel for reduction(
|
|
744
|
-
for(idx_t i = 0; i < n; i++) {
|
|
772
|
+
#pragma omp parallel for reduction(+ : ndis, nlist_visited)
|
|
773
|
+
for (idx_t i = 0; i < n; i++) {
|
|
745
774
|
std::vector<uint16_t> tmp_dis(k);
|
|
746
|
-
int64_t
|
|
747
|
-
uint16_t
|
|
748
|
-
heap_heapify<C>
|
|
749
|
-
const uint8_t
|
|
775
|
+
int64_t* heap_ids = labels + i * k;
|
|
776
|
+
uint16_t* heap_dis = tmp_dis.data();
|
|
777
|
+
heap_heapify<C>(k, heap_dis, heap_ids);
|
|
778
|
+
const uint8_t* LUT = nullptr;
|
|
750
779
|
|
|
751
780
|
if (single_LUT) {
|
|
752
781
|
LUT = dis_tables.get() + i * dim12;
|
|
753
782
|
}
|
|
754
|
-
for(idx_t j = 0; j < nprobe; j++) {
|
|
783
|
+
for (idx_t j = 0; j < nprobe; j++) {
|
|
755
784
|
if (!single_LUT) {
|
|
756
785
|
LUT = dis_tables.get() + (i * nprobe + j) * dim12;
|
|
757
786
|
}
|
|
758
787
|
idx_t list_no = coarse_ids[i * nprobe + j];
|
|
759
|
-
if (list_no < 0)
|
|
788
|
+
if (list_no < 0)
|
|
789
|
+
continue;
|
|
760
790
|
size_t ls = orig_invlists->list_size(list_no);
|
|
761
|
-
if (ls == 0)
|
|
791
|
+
if (ls == 0)
|
|
792
|
+
continue;
|
|
762
793
|
InvertedLists::ScopedCodes codes(orig_invlists, list_no);
|
|
763
794
|
InvertedLists::ScopedIds ids(orig_invlists, list_no);
|
|
764
795
|
|
|
765
796
|
uint16_t bias = biases.get() ? biases[i * nprobe + j] : 0;
|
|
766
797
|
|
|
767
798
|
pq_estimators_from_tables_generic<C>(
|
|
768
|
-
|
|
769
|
-
|
|
770
|
-
|
|
771
|
-
|
|
799
|
+
pq,
|
|
800
|
+
pq.nbits,
|
|
801
|
+
codes.get(),
|
|
802
|
+
ls,
|
|
803
|
+
LUT,
|
|
804
|
+
ids.get(),
|
|
805
|
+
bias,
|
|
806
|
+
k,
|
|
807
|
+
heap_dis,
|
|
808
|
+
heap_ids);
|
|
772
809
|
|
|
773
810
|
nlist_visited++;
|
|
774
811
|
ndis += ls;
|
|
775
812
|
}
|
|
776
|
-
heap_reorder<C>
|
|
813
|
+
heap_reorder<C>(k, heap_dis, heap_ids);
|
|
777
814
|
// convert distances to float
|
|
778
815
|
{
|
|
779
816
|
float one_a = 1 / normalizers[2 * i], b = normalizers[2 * i + 1];
|
|
@@ -781,7 +818,7 @@ void IndexIVFPQFastScan::search_implem_2(
|
|
|
781
818
|
one_a = 1;
|
|
782
819
|
b = 0;
|
|
783
820
|
}
|
|
784
|
-
float
|
|
821
|
+
float* heap_dis_float = distances + i * k;
|
|
785
822
|
for (int j = 0; j < k; j++) {
|
|
786
823
|
heap_dis_float[j] = b + heap_dis[j] * one_a;
|
|
787
824
|
}
|
|
@@ -792,14 +829,16 @@ void IndexIVFPQFastScan::search_implem_2(
|
|
|
792
829
|
indexIVF_stats.nlist += nlist_visited;
|
|
793
830
|
}
|
|
794
831
|
|
|
795
|
-
|
|
796
|
-
|
|
797
|
-
template<class C>
|
|
832
|
+
template <class C>
|
|
798
833
|
void IndexIVFPQFastScan::search_implem_10(
|
|
799
|
-
|
|
800
|
-
|
|
801
|
-
|
|
802
|
-
|
|
834
|
+
idx_t n,
|
|
835
|
+
const float* x,
|
|
836
|
+
idx_t k,
|
|
837
|
+
float* distances,
|
|
838
|
+
idx_t* labels,
|
|
839
|
+
int impl,
|
|
840
|
+
size_t* ndis_out,
|
|
841
|
+
size_t* nlist_out) const {
|
|
803
842
|
memset(distances, -1, sizeof(float) * k * n);
|
|
804
843
|
memset(labels, -1, sizeof(idx_t) * k * n);
|
|
805
844
|
|
|
@@ -807,7 +846,6 @@ void IndexIVFPQFastScan::search_implem_10(
|
|
|
807
846
|
using ReservoirHC = ReservoirHandler<C, true>;
|
|
808
847
|
using SingleResultHC = SingleResultHandler<C, true>;
|
|
809
848
|
|
|
810
|
-
|
|
811
849
|
std::unique_ptr<idx_t[]> coarse_ids(new idx_t[n * nprobe]);
|
|
812
850
|
std::unique_ptr<float[]> coarse_dis(new float[n * nprobe]);
|
|
813
851
|
|
|
@@ -817,20 +855,23 @@ void IndexIVFPQFastScan::search_implem_10(
|
|
|
817
855
|
#define TIC times[ti++] = get_cy()
|
|
818
856
|
TIC;
|
|
819
857
|
|
|
820
|
-
quantizer->search
|
|
858
|
+
quantizer->search(n, x, nprobe, coarse_dis.get(), coarse_ids.get());
|
|
821
859
|
|
|
822
860
|
TIC;
|
|
823
861
|
|
|
824
862
|
size_t dim12 = pq.ksub * M2;
|
|
825
863
|
AlignedTable<uint8_t> dis_tables;
|
|
826
864
|
AlignedTable<uint16_t> biases;
|
|
827
|
-
std::unique_ptr<float[]> normalizers
|
|
865
|
+
std::unique_ptr<float[]> normalizers(new float[2 * n]);
|
|
828
866
|
|
|
829
|
-
compute_LUT_uint8
|
|
830
|
-
n,
|
|
831
|
-
|
|
832
|
-
|
|
833
|
-
|
|
867
|
+
compute_LUT_uint8(
|
|
868
|
+
n,
|
|
869
|
+
x,
|
|
870
|
+
coarse_ids.get(),
|
|
871
|
+
coarse_dis.get(),
|
|
872
|
+
dis_tables,
|
|
873
|
+
biases,
|
|
874
|
+
normalizers.get());
|
|
834
875
|
|
|
835
876
|
TIC;
|
|
836
877
|
|
|
@@ -841,15 +882,16 @@ void IndexIVFPQFastScan::search_implem_10(
|
|
|
841
882
|
|
|
842
883
|
{
|
|
843
884
|
AlignedTable<uint16_t> tmp_distances(k);
|
|
844
|
-
for(idx_t i = 0; i < n; i++) {
|
|
845
|
-
const uint8_t
|
|
885
|
+
for (idx_t i = 0; i < n; i++) {
|
|
886
|
+
const uint8_t* LUT = nullptr;
|
|
846
887
|
int qmap1[1] = {0};
|
|
847
|
-
std::unique_ptr<SIMDResultHandler<C, true
|
|
888
|
+
std::unique_ptr<SIMDResultHandler<C, true>> handler;
|
|
848
889
|
|
|
849
890
|
if (k == 1) {
|
|
850
891
|
handler.reset(new SingleResultHC(1, 0));
|
|
851
892
|
} else if (impl == 10) {
|
|
852
|
-
handler.reset(new HeapHC(
|
|
893
|
+
handler.reset(new HeapHC(
|
|
894
|
+
1, tmp_distances.get(), labels + i * k, k, 0));
|
|
853
895
|
} else if (impl == 11) {
|
|
854
896
|
handler.reset(new ReservoirHC(1, 0, k, 2 * k));
|
|
855
897
|
} else {
|
|
@@ -861,7 +903,7 @@ void IndexIVFPQFastScan::search_implem_10(
|
|
|
861
903
|
if (single_LUT) {
|
|
862
904
|
LUT = dis_tables.get() + i * dim12;
|
|
863
905
|
}
|
|
864
|
-
for(idx_t j = 0; j < nprobe; j++) {
|
|
906
|
+
for (idx_t j = 0; j < nprobe; j++) {
|
|
865
907
|
size_t ij = i * nprobe + j;
|
|
866
908
|
if (!single_LUT) {
|
|
867
909
|
LUT = dis_tables.get() + ij * dim12;
|
|
@@ -871,9 +913,11 @@ void IndexIVFPQFastScan::search_implem_10(
|
|
|
871
913
|
}
|
|
872
914
|
|
|
873
915
|
idx_t list_no = coarse_ids[ij];
|
|
874
|
-
if (list_no < 0)
|
|
916
|
+
if (list_no < 0)
|
|
917
|
+
continue;
|
|
875
918
|
size_t ls = invlists->list_size(list_no);
|
|
876
|
-
if (ls == 0)
|
|
919
|
+
if (ls == 0)
|
|
920
|
+
continue;
|
|
877
921
|
|
|
878
922
|
InvertedLists::ScopedCodes codes(invlists, list_no);
|
|
879
923
|
InvertedLists::ScopedIds ids(invlists, list_no);
|
|
@@ -881,41 +925,40 @@ void IndexIVFPQFastScan::search_implem_10(
|
|
|
881
925
|
handler->ntotal = ls;
|
|
882
926
|
handler->id_map = ids.get();
|
|
883
927
|
|
|
884
|
-
#define DISPATCH(classHC)
|
|
885
|
-
|
|
886
|
-
|
|
887
|
-
|
|
888
|
-
|
|
889
|
-
|
|
890
|
-
); \
|
|
891
|
-
}
|
|
928
|
+
#define DISPATCH(classHC) \
|
|
929
|
+
if (dynamic_cast<classHC*>(handler.get())) { \
|
|
930
|
+
auto* res = static_cast<classHC*>(handler.get()); \
|
|
931
|
+
pq4_accumulate_loop( \
|
|
932
|
+
1, roundup(ls, bbs), bbs, M2, codes.get(), LUT, *res); \
|
|
933
|
+
}
|
|
892
934
|
DISPATCH(HeapHC)
|
|
893
|
-
else DISPATCH(ReservoirHC)
|
|
894
|
-
else DISPATCH(SingleResultHC)
|
|
935
|
+
else DISPATCH(ReservoirHC) else DISPATCH(SingleResultHC)
|
|
895
936
|
#undef DISPATCH
|
|
896
937
|
|
|
897
|
-
|
|
898
|
-
ndis
|
|
938
|
+
nlist_visited++;
|
|
939
|
+
ndis++;
|
|
899
940
|
}
|
|
900
941
|
|
|
901
942
|
handler->to_flat_arrays(
|
|
902
|
-
|
|
903
|
-
|
|
904
|
-
|
|
943
|
+
distances + i * k,
|
|
944
|
+
labels + i * k,
|
|
945
|
+
skip & 16 ? nullptr : normalizers.get() + i * 2);
|
|
905
946
|
}
|
|
906
947
|
}
|
|
907
948
|
*ndis_out = ndis;
|
|
908
949
|
*nlist_out = nlist;
|
|
909
950
|
}
|
|
910
951
|
|
|
911
|
-
|
|
912
|
-
|
|
913
|
-
template<class C>
|
|
952
|
+
template <class C>
|
|
914
953
|
void IndexIVFPQFastScan::search_implem_12(
|
|
915
|
-
|
|
916
|
-
|
|
917
|
-
|
|
918
|
-
|
|
954
|
+
idx_t n,
|
|
955
|
+
const float* x,
|
|
956
|
+
idx_t k,
|
|
957
|
+
float* distances,
|
|
958
|
+
idx_t* labels,
|
|
959
|
+
int impl,
|
|
960
|
+
size_t* ndis_out,
|
|
961
|
+
size_t* nlist_out) const {
|
|
919
962
|
if (n == 0) { // does not work well with reservoir
|
|
920
963
|
return;
|
|
921
964
|
}
|
|
@@ -930,53 +973,53 @@ void IndexIVFPQFastScan::search_implem_12(
|
|
|
930
973
|
#define TIC times[ti++] = get_cy()
|
|
931
974
|
TIC;
|
|
932
975
|
|
|
933
|
-
quantizer->search
|
|
976
|
+
quantizer->search(n, x, nprobe, coarse_dis.get(), coarse_ids.get());
|
|
934
977
|
|
|
935
978
|
TIC;
|
|
936
979
|
|
|
937
980
|
size_t dim12 = pq.ksub * M2;
|
|
938
981
|
AlignedTable<uint8_t> dis_tables;
|
|
939
982
|
AlignedTable<uint16_t> biases;
|
|
940
|
-
std::unique_ptr<float[]> normalizers
|
|
983
|
+
std::unique_ptr<float[]> normalizers(new float[2 * n]);
|
|
941
984
|
|
|
942
|
-
compute_LUT_uint8
|
|
943
|
-
n,
|
|
944
|
-
|
|
945
|
-
|
|
946
|
-
|
|
985
|
+
compute_LUT_uint8(
|
|
986
|
+
n,
|
|
987
|
+
x,
|
|
988
|
+
coarse_ids.get(),
|
|
989
|
+
coarse_dis.get(),
|
|
990
|
+
dis_tables,
|
|
991
|
+
biases,
|
|
992
|
+
normalizers.get());
|
|
947
993
|
|
|
948
994
|
TIC;
|
|
949
995
|
|
|
950
996
|
struct QC {
|
|
951
|
-
int qno;
|
|
952
|
-
int list_no;
|
|
953
|
-
int rank;
|
|
997
|
+
int qno; // sequence number of the query
|
|
998
|
+
int list_no; // list to visit
|
|
999
|
+
int rank; // this is the rank'th result of the coarse quantizer
|
|
954
1000
|
};
|
|
955
1001
|
bool single_LUT = !(by_residual && metric_type == METRIC_L2);
|
|
956
1002
|
|
|
957
1003
|
std::vector<QC> qcs;
|
|
958
1004
|
{
|
|
959
1005
|
int ij = 0;
|
|
960
|
-
for(int i = 0; i < n; i++) {
|
|
961
|
-
for(int j = 0; j < nprobe; j++) {
|
|
1006
|
+
for (int i = 0; i < n; i++) {
|
|
1007
|
+
for (int j = 0; j < nprobe; j++) {
|
|
962
1008
|
if (coarse_ids[ij] >= 0) {
|
|
963
1009
|
qcs.push_back(QC{i, int(coarse_ids[ij]), int(j)});
|
|
964
1010
|
}
|
|
965
1011
|
ij++;
|
|
966
1012
|
}
|
|
967
1013
|
}
|
|
968
|
-
std::sort(
|
|
969
|
-
|
|
970
|
-
|
|
971
|
-
return a.list_no < b.list_no;
|
|
972
|
-
}
|
|
973
|
-
);
|
|
1014
|
+
std::sort(qcs.begin(), qcs.end(), [](const QC& a, const QC& b) {
|
|
1015
|
+
return a.list_no < b.list_no;
|
|
1016
|
+
});
|
|
974
1017
|
}
|
|
975
1018
|
TIC;
|
|
976
1019
|
|
|
977
1020
|
// prepare the result handlers
|
|
978
1021
|
|
|
979
|
-
std::unique_ptr<SIMDResultHandler<C, true
|
|
1022
|
+
std::unique_ptr<SIMDResultHandler<C, true>> handler;
|
|
980
1023
|
AlignedTable<uint16_t> tmp_distances;
|
|
981
1024
|
|
|
982
1025
|
using HeapHC = HeapHandler<C, true>;
|
|
@@ -1012,7 +1055,7 @@ void IndexIVFPQFastScan::search_implem_12(
|
|
|
1012
1055
|
int list_no = qcs[i0].list_no;
|
|
1013
1056
|
size_t i1 = i0 + 1;
|
|
1014
1057
|
|
|
1015
|
-
while(i1 < qcs.size() && i1 < i0 + qbs2) {
|
|
1058
|
+
while (i1 < qcs.size() && i1 < i0 + qbs2) {
|
|
1016
1059
|
if (qcs[i1].list_no != list_no) {
|
|
1017
1060
|
break;
|
|
1018
1061
|
}
|
|
@@ -1034,8 +1077,8 @@ void IndexIVFPQFastScan::search_implem_12(
|
|
|
1034
1077
|
memset(LUT.get(), -1, nc * dim12);
|
|
1035
1078
|
int qbs = pq4_preferred_qbs(nc);
|
|
1036
1079
|
|
|
1037
|
-
for(size_t i = i0; i < i1; i++) {
|
|
1038
|
-
const QC
|
|
1080
|
+
for (size_t i = i0; i < i1; i++) {
|
|
1081
|
+
const QC& qc = qcs[i];
|
|
1039
1082
|
q_map[i - i0] = qc.qno;
|
|
1040
1083
|
int ij = qc.qno * nprobe + qc.rank;
|
|
1041
1084
|
lut_entries[i - i0] = single_LUT ? qc.qno : ij;
|
|
@@ -1044,9 +1087,7 @@ void IndexIVFPQFastScan::search_implem_12(
|
|
|
1044
1087
|
}
|
|
1045
1088
|
}
|
|
1046
1089
|
pq4_pack_LUT_qbs_q_map(
|
|
1047
|
-
|
|
1048
|
-
LUT.get()
|
|
1049
|
-
);
|
|
1090
|
+
qbs, M2, dis_tables.get(), lut_entries.data(), LUT.get());
|
|
1050
1091
|
|
|
1051
1092
|
// access the inverted list
|
|
1052
1093
|
|
|
@@ -1062,20 +1103,17 @@ void IndexIVFPQFastScan::search_implem_12(
|
|
|
1062
1103
|
handler->id_map = ids.get();
|
|
1063
1104
|
uint64_t tt1 = get_cy();
|
|
1064
1105
|
|
|
1065
|
-
#define DISPATCH(classHC)
|
|
1066
|
-
|
|
1067
|
-
|
|
1068
|
-
|
|
1069
|
-
|
|
1070
|
-
|
|
1071
|
-
); \
|
|
1072
|
-
}
|
|
1106
|
+
#define DISPATCH(classHC) \
|
|
1107
|
+
if (dynamic_cast<classHC*>(handler.get())) { \
|
|
1108
|
+
auto* res = static_cast<classHC*>(handler.get()); \
|
|
1109
|
+
pq4_accumulate_loop_qbs( \
|
|
1110
|
+
qbs, list_size, M2, codes.get(), LUT.get(), *res); \
|
|
1111
|
+
}
|
|
1073
1112
|
DISPATCH(HeapHC)
|
|
1074
|
-
else DISPATCH(ReservoirHC)
|
|
1075
|
-
else DISPATCH(SingleResultHC)
|
|
1113
|
+
else DISPATCH(ReservoirHC) else DISPATCH(SingleResultHC)
|
|
1076
1114
|
|
|
1077
|
-
|
|
1078
|
-
|
|
1115
|
+
// prepare for next loop
|
|
1116
|
+
i0 = i1;
|
|
1079
1117
|
|
|
1080
1118
|
uint64_t tt2 = get_cy();
|
|
1081
1119
|
t_copy_pack += tt1 - tt0;
|
|
@@ -1085,21 +1123,19 @@ void IndexIVFPQFastScan::search_implem_12(
|
|
|
1085
1123
|
|
|
1086
1124
|
// labels is in-place for HeapHC
|
|
1087
1125
|
handler->to_flat_arrays(
|
|
1088
|
-
distances, labels,
|
|
1089
|
-
skip & 16 ? nullptr : normalizers.get()
|
|
1090
|
-
);
|
|
1126
|
+
distances, labels, skip & 16 ? nullptr : normalizers.get());
|
|
1091
1127
|
|
|
1092
1128
|
TIC;
|
|
1093
1129
|
|
|
1094
1130
|
// these stats are not thread-safe
|
|
1095
1131
|
|
|
1096
|
-
for(int i = 1; i < ti; i++) {
|
|
1097
|
-
IVFFastScan_stats.times[i] += times[i] - times[i-1];
|
|
1132
|
+
for (int i = 1; i < ti; i++) {
|
|
1133
|
+
IVFFastScan_stats.times[i] += times[i] - times[i - 1];
|
|
1098
1134
|
}
|
|
1099
1135
|
IVFFastScan_stats.t_copy_pack += t_copy_pack;
|
|
1100
1136
|
IVFFastScan_stats.t_scan += t_scan;
|
|
1101
1137
|
|
|
1102
|
-
if (auto
|
|
1138
|
+
if (auto* rh = dynamic_cast<ReservoirHC*>(handler.get())) {
|
|
1103
1139
|
for (int i = 0; i < 4; i++) {
|
|
1104
1140
|
IVFFastScan_stats.reservoir_times[i] += rh->times[i];
|
|
1105
1141
|
}
|
|
@@ -1107,10 +1143,8 @@ void IndexIVFPQFastScan::search_implem_12(
|
|
|
1107
1143
|
|
|
1108
1144
|
*ndis_out = ndis;
|
|
1109
1145
|
*nlist_out = nlist;
|
|
1110
|
-
|
|
1111
1146
|
}
|
|
1112
1147
|
|
|
1113
|
-
|
|
1114
1148
|
IVFFastScanStats IVFFastScan_stats;
|
|
1115
1149
|
|
|
1116
1150
|
} // namespace faiss
|