faiss 0.2.0 → 0.2.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +16 -0
- data/LICENSE.txt +1 -1
- data/README.md +7 -7
- data/ext/faiss/extconf.rb +6 -3
- data/ext/faiss/numo.hpp +4 -4
- data/ext/faiss/utils.cpp +1 -1
- data/ext/faiss/utils.h +1 -1
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/faiss/AutoTune.cpp +292 -291
- data/vendor/faiss/faiss/AutoTune.h +55 -56
- data/vendor/faiss/faiss/Clustering.cpp +365 -194
- data/vendor/faiss/faiss/Clustering.h +102 -35
- data/vendor/faiss/faiss/IVFlib.cpp +171 -195
- data/vendor/faiss/faiss/IVFlib.h +48 -51
- data/vendor/faiss/faiss/Index.cpp +85 -103
- data/vendor/faiss/faiss/Index.h +54 -48
- data/vendor/faiss/faiss/Index2Layer.cpp +126 -224
- data/vendor/faiss/faiss/Index2Layer.h +22 -36
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +407 -0
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +195 -0
- data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
- data/vendor/faiss/faiss/IndexBinary.h +140 -132
- data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
- data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
- data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
- data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
- data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
- data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
- data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
- data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
- data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
- data/vendor/faiss/faiss/IndexFlat.cpp +115 -176
- data/vendor/faiss/faiss/IndexFlat.h +42 -59
- data/vendor/faiss/faiss/IndexFlatCodes.cpp +67 -0
- data/vendor/faiss/faiss/IndexFlatCodes.h +47 -0
- data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
- data/vendor/faiss/faiss/IndexHNSW.h +57 -41
- data/vendor/faiss/faiss/IndexIVF.cpp +545 -453
- data/vendor/faiss/faiss/IndexIVF.h +169 -118
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +316 -0
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +121 -0
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +247 -252
- data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +459 -517
- data/vendor/faiss/faiss/IndexIVFPQ.h +75 -67
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
- data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
- data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
- data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
- data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +163 -150
- data/vendor/faiss/faiss/IndexIVFSpectralHash.h +38 -25
- data/vendor/faiss/faiss/IndexLSH.cpp +66 -113
- data/vendor/faiss/faiss/IndexLSH.h +20 -38
- data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
- data/vendor/faiss/faiss/IndexLattice.h +11 -16
- data/vendor/faiss/faiss/IndexNNDescent.cpp +229 -0
- data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
- data/vendor/faiss/faiss/IndexNSG.cpp +301 -0
- data/vendor/faiss/faiss/IndexNSG.h +85 -0
- data/vendor/faiss/faiss/IndexPQ.cpp +387 -495
- data/vendor/faiss/faiss/IndexPQ.h +64 -82
- data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
- data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
- data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
- data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
- data/vendor/faiss/faiss/IndexRefine.cpp +139 -127
- data/vendor/faiss/faiss/IndexRefine.h +32 -23
- data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
- data/vendor/faiss/faiss/IndexReplicas.h +62 -56
- data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +111 -172
- data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -59
- data/vendor/faiss/faiss/IndexShards.cpp +256 -240
- data/vendor/faiss/faiss/IndexShards.h +85 -73
- data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
- data/vendor/faiss/faiss/MatrixStats.h +7 -10
- data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
- data/vendor/faiss/faiss/MetaIndexes.h +40 -34
- data/vendor/faiss/faiss/MetricType.h +7 -7
- data/vendor/faiss/faiss/VectorTransform.cpp +654 -475
- data/vendor/faiss/faiss/VectorTransform.h +64 -89
- data/vendor/faiss/faiss/clone_index.cpp +78 -73
- data/vendor/faiss/faiss/clone_index.h +4 -9
- data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
- data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
- data/vendor/faiss/faiss/gpu/GpuCloner.cpp +198 -171
- data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
- data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
- data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
- data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
- data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
- data/vendor/faiss/faiss/gpu/GpuIcmEncoder.h +60 -0
- data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
- data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
- data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
- data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
- data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
- data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
- data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
- data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
- data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
- data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
- data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
- data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
- data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
- data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
- data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
- data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
- data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
- data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
- data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
- data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
- data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
- data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
- data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
- data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
- data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
- data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
- data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
- data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
- data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
- data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
- data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
- data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +503 -0
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +175 -0
- data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
- data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
- data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
- data/vendor/faiss/faiss/impl/FaissException.h +41 -29
- data/vendor/faiss/faiss/impl/HNSW.cpp +606 -617
- data/vendor/faiss/faiss/impl/HNSW.h +179 -200
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +855 -0
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +244 -0
- data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
- data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
- data/vendor/faiss/faiss/impl/NSG.cpp +679 -0
- data/vendor/faiss/faiss/impl/NSG.h +199 -0
- data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
- data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
- data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
- data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
- data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
- data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +758 -0
- data/vendor/faiss/faiss/impl/ResidualQuantizer.h +188 -0
- data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +647 -707
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
- data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
- data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
- data/vendor/faiss/faiss/impl/index_read.cpp +631 -480
- data/vendor/faiss/faiss/impl/index_write.cpp +547 -407
- data/vendor/faiss/faiss/impl/io.cpp +76 -95
- data/vendor/faiss/faiss/impl/io.h +31 -41
- data/vendor/faiss/faiss/impl/io_macros.h +60 -29
- data/vendor/faiss/faiss/impl/kmeans1d.cpp +301 -0
- data/vendor/faiss/faiss/impl/kmeans1d.h +48 -0
- data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
- data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
- data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
- data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
- data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
- data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
- data/vendor/faiss/faiss/index_factory.cpp +619 -397
- data/vendor/faiss/faiss/index_factory.h +8 -6
- data/vendor/faiss/faiss/index_io.h +23 -26
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
- data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
- data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
- data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
- data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
- data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
- data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
- data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
- data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
- data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
- data/vendor/faiss/faiss/utils/Heap.h +186 -209
- data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
- data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
- data/vendor/faiss/faiss/utils/distances.cpp +305 -312
- data/vendor/faiss/faiss/utils/distances.h +170 -122
- data/vendor/faiss/faiss/utils/distances_simd.cpp +498 -508
- data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
- data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
- data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
- data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
- data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
- data/vendor/faiss/faiss/utils/hamming.h +62 -85
- data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
- data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
- data/vendor/faiss/faiss/utils/partitioning.h +26 -21
- data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
- data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
- data/vendor/faiss/faiss/utils/random.cpp +39 -63
- data/vendor/faiss/faiss/utils/random.h +13 -16
- data/vendor/faiss/faiss/utils/simdlib.h +4 -2
- data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
- data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
- data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
- data/vendor/faiss/faiss/utils/utils.cpp +304 -287
- data/vendor/faiss/faiss/utils/utils.h +54 -49
- metadata +29 -4
|
@@ -9,17 +9,17 @@
|
|
|
9
9
|
|
|
10
10
|
#include <faiss/IndexIVFPQ.h>
|
|
11
11
|
|
|
12
|
+
#include <stdint.h>
|
|
13
|
+
#include <cassert>
|
|
12
14
|
#include <cinttypes>
|
|
13
15
|
#include <cmath>
|
|
14
16
|
#include <cstdio>
|
|
15
|
-
#include <cassert>
|
|
16
|
-
#include <stdint.h>
|
|
17
17
|
|
|
18
18
|
#include <algorithm>
|
|
19
19
|
|
|
20
20
|
#include <faiss/utils/Heap.h>
|
|
21
|
-
#include <faiss/utils/utils.h>
|
|
22
21
|
#include <faiss/utils/distances.h>
|
|
22
|
+
#include <faiss/utils/utils.h>
|
|
23
23
|
|
|
24
24
|
#include <faiss/Clustering.h>
|
|
25
25
|
#include <faiss/IndexFlat.h>
|
|
@@ -36,12 +36,15 @@ namespace faiss {
|
|
|
36
36
|
* IndexIVFPQ implementation
|
|
37
37
|
******************************************/
|
|
38
38
|
|
|
39
|
-
IndexIVFPQ::IndexIVFPQ
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
39
|
+
IndexIVFPQ::IndexIVFPQ(
|
|
40
|
+
Index* quantizer,
|
|
41
|
+
size_t d,
|
|
42
|
+
size_t nlist,
|
|
43
|
+
size_t M,
|
|
44
|
+
size_t nbits_per_idx,
|
|
45
|
+
MetricType metric)
|
|
46
|
+
: IndexIVF(quantizer, d, nlist, 0, metric), pq(d, M, nbits_per_idx) {
|
|
47
|
+
FAISS_THROW_IF_NOT(nbits_per_idx <= 8);
|
|
45
48
|
code_size = pq.code_size;
|
|
46
49
|
invlists->code_size = code_size;
|
|
47
50
|
is_trained = false;
|
|
@@ -52,202 +55,197 @@ IndexIVFPQ::IndexIVFPQ (Index * quantizer, size_t d, size_t nlist,
|
|
|
52
55
|
polysemous_training = nullptr;
|
|
53
56
|
do_polysemous_training = false;
|
|
54
57
|
polysemous_ht = 0;
|
|
55
|
-
|
|
56
58
|
}
|
|
57
59
|
|
|
58
|
-
|
|
59
60
|
/****************************************************************
|
|
60
61
|
* training */
|
|
61
62
|
|
|
62
|
-
void IndexIVFPQ::train_residual
|
|
63
|
-
|
|
64
|
-
train_residual_o (n, x, nullptr);
|
|
63
|
+
void IndexIVFPQ::train_residual(idx_t n, const float* x) {
|
|
64
|
+
train_residual_o(n, x, nullptr);
|
|
65
65
|
}
|
|
66
66
|
|
|
67
|
+
void IndexIVFPQ::train_residual_o(idx_t n, const float* x, float* residuals_2) {
|
|
68
|
+
const float* x_in = x;
|
|
67
69
|
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
70
|
+
x = fvecs_maybe_subsample(
|
|
71
|
+
d,
|
|
72
|
+
(size_t*)&n,
|
|
73
|
+
pq.cp.max_points_per_centroid * pq.ksub,
|
|
74
|
+
x,
|
|
75
|
+
verbose,
|
|
76
|
+
pq.cp.seed);
|
|
75
77
|
|
|
76
|
-
ScopeDeleter<float> del_x
|
|
78
|
+
ScopeDeleter<float> del_x(x_in == x ? nullptr : x);
|
|
77
79
|
|
|
78
|
-
const float
|
|
80
|
+
const float* trainset;
|
|
79
81
|
ScopeDeleter<float> del_residuals;
|
|
80
82
|
if (by_residual) {
|
|
81
|
-
if(verbose)
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
83
|
+
if (verbose)
|
|
84
|
+
printf("computing residuals\n");
|
|
85
|
+
idx_t* assign = new idx_t[n]; // assignment to coarse centroids
|
|
86
|
+
ScopeDeleter<idx_t> del(assign);
|
|
87
|
+
quantizer->assign(n, x, assign);
|
|
88
|
+
float* residuals = new float[n * d];
|
|
89
|
+
del_residuals.set(residuals);
|
|
87
90
|
for (idx_t i = 0; i < n; i++)
|
|
88
|
-
|
|
91
|
+
quantizer->compute_residual(
|
|
92
|
+
x + i * d, residuals + i * d, assign[i]);
|
|
89
93
|
|
|
90
94
|
trainset = residuals;
|
|
91
95
|
} else {
|
|
92
96
|
trainset = x;
|
|
93
97
|
}
|
|
94
98
|
if (verbose)
|
|
95
|
-
printf
|
|
96
|
-
|
|
99
|
+
printf("training %zdx%zd product quantizer on %" PRId64
|
|
100
|
+
" vectors in %dD\n",
|
|
101
|
+
pq.M,
|
|
102
|
+
pq.ksub,
|
|
103
|
+
n,
|
|
104
|
+
d);
|
|
97
105
|
pq.verbose = verbose;
|
|
98
|
-
pq.train
|
|
106
|
+
pq.train(n, trainset);
|
|
99
107
|
|
|
100
108
|
if (do_polysemous_training) {
|
|
101
109
|
if (verbose)
|
|
102
110
|
printf("doing polysemous training for PQ\n");
|
|
103
111
|
PolysemousTraining default_pt;
|
|
104
|
-
PolysemousTraining
|
|
105
|
-
if (!pt)
|
|
106
|
-
|
|
112
|
+
PolysemousTraining* pt = polysemous_training;
|
|
113
|
+
if (!pt)
|
|
114
|
+
pt = &default_pt;
|
|
115
|
+
pt->optimize_pq_for_hamming(pq, n, trainset);
|
|
107
116
|
}
|
|
108
117
|
|
|
109
118
|
// prepare second-level residuals for refine PQ
|
|
110
119
|
if (residuals_2) {
|
|
111
|
-
uint8_t
|
|
112
|
-
ScopeDeleter<uint8_t> del
|
|
113
|
-
pq.compute_codes
|
|
120
|
+
uint8_t* train_codes = new uint8_t[pq.code_size * n];
|
|
121
|
+
ScopeDeleter<uint8_t> del(train_codes);
|
|
122
|
+
pq.compute_codes(trainset, train_codes, n);
|
|
114
123
|
|
|
115
124
|
for (idx_t i = 0; i < n; i++) {
|
|
116
|
-
const float
|
|
117
|
-
float
|
|
118
|
-
pq.decode
|
|
125
|
+
const float* xx = trainset + i * d;
|
|
126
|
+
float* res = residuals_2 + i * d;
|
|
127
|
+
pq.decode(train_codes + i * pq.code_size, res);
|
|
119
128
|
for (int j = 0; j < d; j++)
|
|
120
129
|
res[j] = xx[j] - res[j];
|
|
121
130
|
}
|
|
122
|
-
|
|
123
131
|
}
|
|
124
132
|
|
|
125
133
|
if (by_residual) {
|
|
126
|
-
precompute_table
|
|
134
|
+
precompute_table();
|
|
127
135
|
}
|
|
128
|
-
|
|
129
136
|
}
|
|
130
137
|
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
138
|
/****************************************************************
|
|
137
139
|
* IVFPQ as codec */
|
|
138
140
|
|
|
139
|
-
|
|
140
141
|
/* produce a binary signature based on the residual vector */
|
|
141
|
-
void IndexIVFPQ::encode
|
|
142
|
-
{
|
|
142
|
+
void IndexIVFPQ::encode(idx_t key, const float* x, uint8_t* code) const {
|
|
143
143
|
if (by_residual) {
|
|
144
144
|
std::vector<float> residual_vec(d);
|
|
145
|
-
quantizer->compute_residual
|
|
146
|
-
pq.compute_code
|
|
147
|
-
}
|
|
148
|
-
|
|
145
|
+
quantizer->compute_residual(x, residual_vec.data(), key);
|
|
146
|
+
pq.compute_code(residual_vec.data(), code);
|
|
147
|
+
} else
|
|
148
|
+
pq.compute_code(x, code);
|
|
149
149
|
}
|
|
150
150
|
|
|
151
|
-
void IndexIVFPQ::encode_multiple
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
151
|
+
void IndexIVFPQ::encode_multiple(
|
|
152
|
+
size_t n,
|
|
153
|
+
idx_t* keys,
|
|
154
|
+
const float* x,
|
|
155
|
+
uint8_t* xcodes,
|
|
156
|
+
bool compute_keys) const {
|
|
155
157
|
if (compute_keys)
|
|
156
|
-
quantizer->assign
|
|
158
|
+
quantizer->assign(n, x, keys);
|
|
157
159
|
|
|
158
|
-
encode_vectors
|
|
160
|
+
encode_vectors(n, x, keys, xcodes);
|
|
159
161
|
}
|
|
160
162
|
|
|
161
|
-
void IndexIVFPQ::decode_multiple
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
163
|
+
void IndexIVFPQ::decode_multiple(
|
|
164
|
+
size_t n,
|
|
165
|
+
const idx_t* keys,
|
|
166
|
+
const uint8_t* xcodes,
|
|
167
|
+
float* x) const {
|
|
168
|
+
pq.decode(xcodes, x, n);
|
|
165
169
|
if (by_residual) {
|
|
166
|
-
std::vector<float> centroid
|
|
170
|
+
std::vector<float> centroid(d);
|
|
167
171
|
for (size_t i = 0; i < n; i++) {
|
|
168
|
-
quantizer->reconstruct
|
|
169
|
-
float
|
|
172
|
+
quantizer->reconstruct(keys[i], centroid.data());
|
|
173
|
+
float* xi = x + i * d;
|
|
170
174
|
for (size_t j = 0; j < d; j++) {
|
|
171
|
-
xi
|
|
175
|
+
xi[j] += centroid[j];
|
|
172
176
|
}
|
|
173
177
|
}
|
|
174
178
|
}
|
|
175
179
|
}
|
|
176
180
|
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
181
|
/****************************************************************
|
|
181
182
|
* add */
|
|
182
183
|
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
184
|
+
void IndexIVFPQ::add_core(
|
|
185
|
+
idx_t n,
|
|
186
|
+
const float* x,
|
|
187
|
+
const idx_t* xids,
|
|
188
|
+
const idx_t* coarse_idx) {
|
|
189
|
+
add_core_o(n, x, xids, nullptr, coarse_idx);
|
|
187
190
|
}
|
|
188
191
|
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
const Index::idx_t
|
|
194
|
-
{
|
|
192
|
+
static float* compute_residuals(
|
|
193
|
+
const Index* quantizer,
|
|
194
|
+
Index::idx_t n,
|
|
195
|
+
const float* x,
|
|
196
|
+
const Index::idx_t* list_nos) {
|
|
195
197
|
size_t d = quantizer->d;
|
|
196
|
-
float
|
|
198
|
+
float* residuals = new float[n * d];
|
|
197
199
|
// TODO: parallelize?
|
|
198
200
|
for (size_t i = 0; i < n; i++) {
|
|
199
201
|
if (list_nos[i] < 0)
|
|
200
|
-
memset
|
|
202
|
+
memset(residuals + i * d, 0, sizeof(*residuals) * d);
|
|
201
203
|
else
|
|
202
|
-
quantizer->compute_residual
|
|
203
|
-
|
|
204
|
+
quantizer->compute_residual(
|
|
205
|
+
x + i * d, residuals + i * d, list_nos[i]);
|
|
204
206
|
}
|
|
205
207
|
return residuals;
|
|
206
208
|
}
|
|
207
209
|
|
|
208
|
-
void IndexIVFPQ::encode_vectors(
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
210
|
+
void IndexIVFPQ::encode_vectors(
|
|
211
|
+
idx_t n,
|
|
212
|
+
const float* x,
|
|
213
|
+
const idx_t* list_nos,
|
|
214
|
+
uint8_t* codes,
|
|
215
|
+
bool include_listnos) const {
|
|
213
216
|
if (by_residual) {
|
|
214
|
-
float
|
|
215
|
-
ScopeDeleter<float> del
|
|
216
|
-
pq.compute_codes
|
|
217
|
+
float* to_encode = compute_residuals(quantizer, n, x, list_nos);
|
|
218
|
+
ScopeDeleter<float> del(to_encode);
|
|
219
|
+
pq.compute_codes(to_encode, codes, n);
|
|
217
220
|
} else {
|
|
218
|
-
pq.compute_codes
|
|
221
|
+
pq.compute_codes(x, codes, n);
|
|
219
222
|
}
|
|
220
223
|
|
|
221
224
|
if (include_listnos) {
|
|
222
225
|
size_t coarse_size = coarse_code_size();
|
|
223
226
|
for (idx_t i = n - 1; i >= 0; i--) {
|
|
224
|
-
uint8_t
|
|
225
|
-
memmove
|
|
226
|
-
|
|
227
|
-
encode_listno (list_nos[i], code);
|
|
227
|
+
uint8_t* code = codes + i * (coarse_size + code_size);
|
|
228
|
+
memmove(code + coarse_size, codes + i * code_size, code_size);
|
|
229
|
+
encode_listno(list_nos[i], code);
|
|
228
230
|
}
|
|
229
231
|
}
|
|
230
232
|
}
|
|
231
233
|
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
void IndexIVFPQ::sa_decode (idx_t n, const uint8_t *codes,
|
|
235
|
-
float *x) const
|
|
236
|
-
{
|
|
237
|
-
size_t coarse_size = coarse_code_size ();
|
|
234
|
+
void IndexIVFPQ::sa_decode(idx_t n, const uint8_t* codes, float* x) const {
|
|
235
|
+
size_t coarse_size = coarse_code_size();
|
|
238
236
|
|
|
239
237
|
#pragma omp parallel
|
|
240
238
|
{
|
|
241
|
-
std::vector<float> residual
|
|
239
|
+
std::vector<float> residual(d);
|
|
242
240
|
|
|
243
241
|
#pragma omp for
|
|
244
242
|
for (idx_t i = 0; i < n; i++) {
|
|
245
|
-
const uint8_t
|
|
246
|
-
int64_t list_no = decode_listno
|
|
247
|
-
float
|
|
248
|
-
pq.decode
|
|
243
|
+
const uint8_t* code = codes + i * (code_size + coarse_size);
|
|
244
|
+
int64_t list_no = decode_listno(code);
|
|
245
|
+
float* xi = x + i * d;
|
|
246
|
+
pq.decode(code + coarse_size, xi);
|
|
249
247
|
if (by_residual) {
|
|
250
|
-
quantizer->reconstruct
|
|
248
|
+
quantizer->reconstruct(list_no, residual.data());
|
|
251
249
|
for (size_t j = 0; j < d; j++) {
|
|
252
250
|
xi[j] += residual[j];
|
|
253
251
|
}
|
|
@@ -256,120 +254,127 @@ void IndexIVFPQ::sa_decode (idx_t n, const uint8_t *codes,
|
|
|
256
254
|
}
|
|
257
255
|
}
|
|
258
256
|
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
|
|
257
|
+
void IndexIVFPQ::add_core_o(
|
|
258
|
+
idx_t n,
|
|
259
|
+
const float* x,
|
|
260
|
+
const idx_t* xids,
|
|
261
|
+
float* residuals_2,
|
|
262
|
+
const idx_t* precomputed_idx) {
|
|
264
263
|
idx_t bs = 32768;
|
|
265
264
|
if (n > bs) {
|
|
266
265
|
for (idx_t i0 = 0; i0 < n; i0 += bs) {
|
|
267
266
|
idx_t i1 = std::min(i0 + bs, n);
|
|
268
267
|
if (verbose) {
|
|
269
|
-
printf("IndexIVFPQ::add_core_o: adding %" PRId64 ":%" PRId64
|
|
270
|
-
|
|
268
|
+
printf("IndexIVFPQ::add_core_o: adding %" PRId64 ":%" PRId64
|
|
269
|
+
" / %" PRId64 "\n",
|
|
270
|
+
i0,
|
|
271
|
+
i1,
|
|
272
|
+
n);
|
|
271
273
|
}
|
|
272
|
-
add_core_o
|
|
273
|
-
|
|
274
|
-
|
|
275
|
-
|
|
274
|
+
add_core_o(
|
|
275
|
+
i1 - i0,
|
|
276
|
+
x + i0 * d,
|
|
277
|
+
xids ? xids + i0 : nullptr,
|
|
278
|
+
residuals_2 ? residuals_2 + i0 * d : nullptr,
|
|
279
|
+
precomputed_idx ? precomputed_idx + i0 : nullptr);
|
|
276
280
|
}
|
|
277
281
|
return;
|
|
278
282
|
}
|
|
279
283
|
|
|
280
284
|
InterruptCallback::check();
|
|
281
285
|
|
|
282
|
-
direct_map.check_can_add
|
|
286
|
+
direct_map.check_can_add(xids);
|
|
283
287
|
|
|
284
|
-
FAISS_THROW_IF_NOT
|
|
285
|
-
double t0 = getmillisecs
|
|
286
|
-
const idx_t
|
|
288
|
+
FAISS_THROW_IF_NOT(is_trained);
|
|
289
|
+
double t0 = getmillisecs();
|
|
290
|
+
const idx_t* idx;
|
|
287
291
|
ScopeDeleter<idx_t> del_idx;
|
|
288
292
|
|
|
289
293
|
if (precomputed_idx) {
|
|
290
294
|
idx = precomputed_idx;
|
|
291
295
|
} else {
|
|
292
|
-
idx_t
|
|
293
|
-
del_idx.set
|
|
294
|
-
quantizer->assign
|
|
296
|
+
idx_t* idx0 = new idx_t[n];
|
|
297
|
+
del_idx.set(idx0);
|
|
298
|
+
quantizer->assign(n, x, idx0);
|
|
295
299
|
idx = idx0;
|
|
296
300
|
}
|
|
297
301
|
|
|
298
|
-
double t1 = getmillisecs
|
|
299
|
-
uint8_t
|
|
300
|
-
ScopeDeleter<uint8_t> del_xcodes
|
|
302
|
+
double t1 = getmillisecs();
|
|
303
|
+
uint8_t* xcodes = new uint8_t[n * code_size];
|
|
304
|
+
ScopeDeleter<uint8_t> del_xcodes(xcodes);
|
|
301
305
|
|
|
302
|
-
const float
|
|
306
|
+
const float* to_encode = nullptr;
|
|
303
307
|
ScopeDeleter<float> del_to_encode;
|
|
304
308
|
|
|
305
309
|
if (by_residual) {
|
|
306
|
-
to_encode = compute_residuals
|
|
307
|
-
del_to_encode.set
|
|
310
|
+
to_encode = compute_residuals(quantizer, n, x, idx);
|
|
311
|
+
del_to_encode.set(to_encode);
|
|
308
312
|
} else {
|
|
309
313
|
to_encode = x;
|
|
310
314
|
}
|
|
311
|
-
pq.compute_codes
|
|
315
|
+
pq.compute_codes(to_encode, xcodes, n);
|
|
312
316
|
|
|
313
|
-
double t2 = getmillisecs
|
|
317
|
+
double t2 = getmillisecs();
|
|
314
318
|
// TODO: parallelize?
|
|
315
319
|
size_t n_ignore = 0;
|
|
316
320
|
for (size_t i = 0; i < n; i++) {
|
|
317
321
|
idx_t key = idx[i];
|
|
318
322
|
idx_t id = xids ? xids[i] : ntotal + i;
|
|
319
323
|
if (key < 0) {
|
|
320
|
-
direct_map.add_single_id
|
|
321
|
-
n_ignore
|
|
324
|
+
direct_map.add_single_id(id, -1, 0);
|
|
325
|
+
n_ignore++;
|
|
322
326
|
if (residuals_2)
|
|
323
|
-
memset
|
|
327
|
+
memset(residuals_2, 0, sizeof(*residuals_2) * d);
|
|
324
328
|
continue;
|
|
325
329
|
}
|
|
326
330
|
|
|
327
|
-
uint8_t
|
|
328
|
-
size_t offset = invlists->add_entry
|
|
331
|
+
uint8_t* code = xcodes + i * code_size;
|
|
332
|
+
size_t offset = invlists->add_entry(key, id, code);
|
|
329
333
|
|
|
330
334
|
if (residuals_2) {
|
|
331
|
-
float
|
|
332
|
-
const float
|
|
333
|
-
pq.decode
|
|
335
|
+
float* res2 = residuals_2 + i * d;
|
|
336
|
+
const float* xi = to_encode + i * d;
|
|
337
|
+
pq.decode(code, res2);
|
|
334
338
|
for (int j = 0; j < d; j++)
|
|
335
339
|
res2[j] = xi[j] - res2[j];
|
|
336
340
|
}
|
|
337
341
|
|
|
338
|
-
direct_map.add_single_id
|
|
342
|
+
direct_map.add_single_id(id, key, offset);
|
|
339
343
|
}
|
|
340
344
|
|
|
341
|
-
double t3 = getmillisecs
|
|
342
|
-
if(verbose) {
|
|
345
|
+
double t3 = getmillisecs();
|
|
346
|
+
if (verbose) {
|
|
343
347
|
char comment[100] = {0};
|
|
344
348
|
if (n_ignore > 0)
|
|
345
|
-
snprintf
|
|
349
|
+
snprintf(comment, 100, "(%zd vectors ignored)", n_ignore);
|
|
346
350
|
printf(" add_core times: %.3f %.3f %.3f %s\n",
|
|
347
|
-
t1 - t0,
|
|
351
|
+
t1 - t0,
|
|
352
|
+
t2 - t1,
|
|
353
|
+
t3 - t2,
|
|
354
|
+
comment);
|
|
348
355
|
}
|
|
349
356
|
ntotal += n;
|
|
350
357
|
}
|
|
351
358
|
|
|
352
|
-
|
|
353
|
-
|
|
354
|
-
|
|
355
|
-
{
|
|
356
|
-
const uint8_t* code = invlists->get_single_code
|
|
359
|
+
void IndexIVFPQ::reconstruct_from_offset(
|
|
360
|
+
int64_t list_no,
|
|
361
|
+
int64_t offset,
|
|
362
|
+
float* recons) const {
|
|
363
|
+
const uint8_t* code = invlists->get_single_code(list_no, offset);
|
|
357
364
|
|
|
358
365
|
if (by_residual) {
|
|
359
366
|
std::vector<float> centroid(d);
|
|
360
|
-
quantizer->reconstruct
|
|
367
|
+
quantizer->reconstruct(list_no, centroid.data());
|
|
361
368
|
|
|
362
|
-
pq.decode
|
|
369
|
+
pq.decode(code, recons);
|
|
363
370
|
for (int i = 0; i < d; ++i) {
|
|
364
371
|
recons[i] += centroid[i];
|
|
365
372
|
}
|
|
366
373
|
} else {
|
|
367
|
-
pq.decode
|
|
374
|
+
pq.decode(code, recons);
|
|
368
375
|
}
|
|
369
376
|
}
|
|
370
377
|
|
|
371
|
-
|
|
372
|
-
|
|
373
378
|
/// 2G by default, accommodates tables up to PQ32 w/ 65536 centroids
|
|
374
379
|
size_t precomputed_table_max_bytes = ((size_t)1) << 31;
|
|
375
380
|
|
|
@@ -403,20 +408,18 @@ size_t precomputed_table_max_bytes = ((size_t)1) << 31;
|
|
|
403
408
|
* is faster when the length of the lists is > ksub * M.
|
|
404
409
|
*/
|
|
405
410
|
|
|
406
|
-
void initialize_IVFPQ_precomputed_table
|
|
407
|
-
|
|
408
|
-
|
|
409
|
-
|
|
410
|
-
|
|
411
|
-
|
|
412
|
-
)
|
|
413
|
-
{
|
|
411
|
+
void initialize_IVFPQ_precomputed_table(
|
|
412
|
+
int& use_precomputed_table,
|
|
413
|
+
const Index* quantizer,
|
|
414
|
+
const ProductQuantizer& pq,
|
|
415
|
+
AlignedTable<float>& precomputed_table,
|
|
416
|
+
bool verbose) {
|
|
414
417
|
size_t nlist = quantizer->ntotal;
|
|
415
418
|
size_t d = quantizer->d;
|
|
416
419
|
FAISS_THROW_IF_NOT(d == pq.d);
|
|
417
420
|
|
|
418
421
|
if (use_precomputed_table == -1) {
|
|
419
|
-
precomputed_table.resize
|
|
422
|
+
precomputed_table.resize(0);
|
|
420
423
|
return;
|
|
421
424
|
}
|
|
422
425
|
|
|
@@ -424,23 +427,23 @@ void initialize_IVFPQ_precomputed_table (
|
|
|
424
427
|
if (quantizer->metric_type == METRIC_INNER_PRODUCT) {
|
|
425
428
|
if (verbose) {
|
|
426
429
|
printf("IndexIVFPQ::precompute_table: precomputed "
|
|
427
|
-
|
|
430
|
+
"tables not needed for inner product quantizers\n");
|
|
428
431
|
}
|
|
429
|
-
precomputed_table.resize
|
|
432
|
+
precomputed_table.resize(0);
|
|
430
433
|
return;
|
|
431
434
|
}
|
|
432
|
-
const MultiIndexQuantizer
|
|
433
|
-
|
|
435
|
+
const MultiIndexQuantizer* miq =
|
|
436
|
+
dynamic_cast<const MultiIndexQuantizer*>(quantizer);
|
|
434
437
|
if (miq && pq.M % miq->pq.M == 0)
|
|
435
438
|
use_precomputed_table = 2;
|
|
436
439
|
else {
|
|
437
440
|
size_t table_size = pq.M * pq.ksub * nlist * sizeof(float);
|
|
438
441
|
if (table_size > precomputed_table_max_bytes) {
|
|
439
442
|
if (verbose) {
|
|
440
|
-
printf(
|
|
441
|
-
|
|
442
|
-
|
|
443
|
-
|
|
443
|
+
printf("IndexIVFPQ::precompute_table: not precomputing table, "
|
|
444
|
+
"it would be too big: %zd bytes (max %zd)\n",
|
|
445
|
+
table_size,
|
|
446
|
+
precomputed_table_max_bytes);
|
|
444
447
|
use_precomputed_table = 0;
|
|
445
448
|
}
|
|
446
449
|
return;
|
|
@@ -450,80 +453,68 @@ void initialize_IVFPQ_precomputed_table (
|
|
|
450
453
|
} // otherwise assume user has set appropriate flag on input
|
|
451
454
|
|
|
452
455
|
if (verbose) {
|
|
453
|
-
printf
|
|
454
|
-
use_precomputed_table);
|
|
456
|
+
printf("precomputing IVFPQ tables type %d\n", use_precomputed_table);
|
|
455
457
|
}
|
|
456
458
|
|
|
457
459
|
// squared norms of the PQ centroids
|
|
458
|
-
std::vector<float> r_norms
|
|
460
|
+
std::vector<float> r_norms(pq.M * pq.ksub, NAN);
|
|
459
461
|
for (int m = 0; m < pq.M; m++)
|
|
460
462
|
for (int j = 0; j < pq.ksub; j++)
|
|
461
|
-
r_norms
|
|
462
|
-
|
|
463
|
+
r_norms[m * pq.ksub + j] =
|
|
464
|
+
fvec_norm_L2sqr(pq.get_centroids(m, j), pq.dsub);
|
|
463
465
|
|
|
464
466
|
if (use_precomputed_table == 1) {
|
|
465
|
-
|
|
466
|
-
|
|
467
|
-
std::vector<float> centroid (d);
|
|
467
|
+
precomputed_table.resize(nlist * pq.M * pq.ksub);
|
|
468
|
+
std::vector<float> centroid(d);
|
|
468
469
|
|
|
469
470
|
for (size_t i = 0; i < nlist; i++) {
|
|
470
|
-
quantizer->reconstruct
|
|
471
|
+
quantizer->reconstruct(i, centroid.data());
|
|
471
472
|
|
|
472
|
-
float
|
|
473
|
-
pq.compute_inner_prod_table
|
|
474
|
-
fvec_madd
|
|
473
|
+
float* tab = &precomputed_table[i * pq.M * pq.ksub];
|
|
474
|
+
pq.compute_inner_prod_table(centroid.data(), tab);
|
|
475
|
+
fvec_madd(pq.M * pq.ksub, r_norms.data(), 2.0, tab, tab);
|
|
475
476
|
}
|
|
476
477
|
} else if (use_precomputed_table == 2) {
|
|
477
|
-
const MultiIndexQuantizer
|
|
478
|
-
|
|
479
|
-
FAISS_THROW_IF_NOT
|
|
480
|
-
const ProductQuantizer
|
|
481
|
-
FAISS_THROW_IF_NOT
|
|
478
|
+
const MultiIndexQuantizer* miq =
|
|
479
|
+
dynamic_cast<const MultiIndexQuantizer*>(quantizer);
|
|
480
|
+
FAISS_THROW_IF_NOT(miq);
|
|
481
|
+
const ProductQuantizer& cpq = miq->pq;
|
|
482
|
+
FAISS_THROW_IF_NOT(pq.M % cpq.M == 0);
|
|
482
483
|
|
|
483
484
|
precomputed_table.resize(cpq.ksub * pq.M * pq.ksub);
|
|
484
485
|
|
|
485
486
|
// reorder PQ centroid table
|
|
486
|
-
std::vector<float> centroids
|
|
487
|
+
std::vector<float> centroids(d * cpq.ksub, NAN);
|
|
487
488
|
|
|
488
489
|
for (int m = 0; m < cpq.M; m++) {
|
|
489
490
|
for (size_t i = 0; i < cpq.ksub; i++) {
|
|
490
|
-
memcpy
|
|
491
|
-
|
|
492
|
-
|
|
491
|
+
memcpy(centroids.data() + i * d + m * cpq.dsub,
|
|
492
|
+
cpq.get_centroids(m, i),
|
|
493
|
+
sizeof(*centroids.data()) * cpq.dsub);
|
|
493
494
|
}
|
|
494
495
|
}
|
|
495
496
|
|
|
496
|
-
pq.compute_inner_prod_tables
|
|
497
|
-
|
|
497
|
+
pq.compute_inner_prod_tables(
|
|
498
|
+
cpq.ksub, centroids.data(), precomputed_table.data());
|
|
498
499
|
|
|
499
500
|
for (size_t i = 0; i < cpq.ksub; i++) {
|
|
500
|
-
float
|
|
501
|
-
fvec_madd
|
|
501
|
+
float* tab = &precomputed_table[i * pq.M * pq.ksub];
|
|
502
|
+
fvec_madd(pq.M * pq.ksub, r_norms.data(), 2.0, tab, tab);
|
|
502
503
|
}
|
|
503
|
-
|
|
504
504
|
}
|
|
505
|
-
|
|
506
505
|
}
|
|
507
506
|
|
|
508
|
-
void IndexIVFPQ::precompute_table
|
|
509
|
-
|
|
510
|
-
|
|
511
|
-
use_precomputed_table, quantizer, pq, precomputed_table,
|
|
512
|
-
verbose
|
|
513
|
-
);
|
|
507
|
+
void IndexIVFPQ::precompute_table() {
|
|
508
|
+
initialize_IVFPQ_precomputed_table(
|
|
509
|
+
use_precomputed_table, quantizer, pq, precomputed_table, verbose);
|
|
514
510
|
}
|
|
515
511
|
|
|
516
|
-
|
|
517
|
-
|
|
518
512
|
namespace {
|
|
519
513
|
|
|
520
514
|
using idx_t = Index::idx_t;
|
|
521
515
|
|
|
522
|
-
|
|
523
516
|
#define TIC t0 = get_cycles()
|
|
524
|
-
#define TOC get_cycles
|
|
525
|
-
|
|
526
|
-
|
|
517
|
+
#define TOC get_cycles() - t0
|
|
527
518
|
|
|
528
519
|
/** QueryTables manages the various ways of searching an
|
|
529
520
|
* IndexIVFPQ. The code contains a lot of branches, depending on:
|
|
@@ -533,43 +524,42 @@ using idx_t = Index::idx_t;
|
|
|
533
524
|
* - polysemous_ht: are we filtering with polysemous codes?
|
|
534
525
|
*/
|
|
535
526
|
struct QueryTables {
|
|
536
|
-
|
|
537
527
|
/*****************************************************
|
|
538
528
|
* General data from the IVFPQ
|
|
539
529
|
*****************************************************/
|
|
540
530
|
|
|
541
|
-
const IndexIVFPQ
|
|
542
|
-
const IVFSearchParameters
|
|
531
|
+
const IndexIVFPQ& ivfpq;
|
|
532
|
+
const IVFSearchParameters* params;
|
|
543
533
|
|
|
544
534
|
// copied from IndexIVFPQ for easier access
|
|
545
535
|
int d;
|
|
546
|
-
const ProductQuantizer
|
|
536
|
+
const ProductQuantizer& pq;
|
|
547
537
|
MetricType metric_type;
|
|
548
538
|
bool by_residual;
|
|
549
539
|
int use_precomputed_table;
|
|
550
540
|
int polysemous_ht;
|
|
551
541
|
|
|
552
542
|
// pre-allocated data buffers
|
|
553
|
-
float *
|
|
554
|
-
float *
|
|
543
|
+
float *sim_table, *sim_table_2;
|
|
544
|
+
float *residual_vec, *decoded_vec;
|
|
555
545
|
|
|
556
546
|
// single data buffer
|
|
557
547
|
std::vector<float> mem;
|
|
558
548
|
|
|
559
549
|
// for table pointers
|
|
560
|
-
std::vector<const float
|
|
561
|
-
|
|
562
|
-
explicit QueryTables
|
|
563
|
-
|
|
564
|
-
|
|
565
|
-
|
|
566
|
-
|
|
567
|
-
|
|
568
|
-
|
|
569
|
-
|
|
570
|
-
|
|
571
|
-
mem.resize
|
|
572
|
-
sim_table = mem.data
|
|
550
|
+
std::vector<const float*> sim_table_ptrs;
|
|
551
|
+
|
|
552
|
+
explicit QueryTables(
|
|
553
|
+
const IndexIVFPQ& ivfpq,
|
|
554
|
+
const IVFSearchParameters* params)
|
|
555
|
+
: ivfpq(ivfpq),
|
|
556
|
+
d(ivfpq.d),
|
|
557
|
+
pq(ivfpq.pq),
|
|
558
|
+
metric_type(ivfpq.metric_type),
|
|
559
|
+
by_residual(ivfpq.by_residual),
|
|
560
|
+
use_precomputed_table(ivfpq.use_precomputed_table) {
|
|
561
|
+
mem.resize(pq.ksub * pq.M * 2 + d * 2);
|
|
562
|
+
sim_table = mem.data();
|
|
573
563
|
sim_table_2 = sim_table + pq.ksub * pq.M;
|
|
574
564
|
residual_vec = sim_table_2 + pq.ksub * pq.M;
|
|
575
565
|
decoded_vec = residual_vec + d;
|
|
@@ -577,14 +567,14 @@ struct QueryTables {
|
|
|
577
567
|
// for polysemous
|
|
578
568
|
polysemous_ht = ivfpq.polysemous_ht;
|
|
579
569
|
if (auto ivfpq_params =
|
|
580
|
-
|
|
570
|
+
dynamic_cast<const IVFPQSearchParameters*>(params)) {
|
|
581
571
|
polysemous_ht = ivfpq_params->polysemous_ht;
|
|
582
572
|
}
|
|
583
|
-
if (polysemous_ht != 0)
|
|
584
|
-
q_code.resize
|
|
573
|
+
if (polysemous_ht != 0) {
|
|
574
|
+
q_code.resize(pq.code_size);
|
|
585
575
|
}
|
|
586
576
|
init_list_cycles = 0;
|
|
587
|
-
sim_table_ptrs.resize
|
|
577
|
+
sim_table_ptrs.resize(pq.M);
|
|
588
578
|
}
|
|
589
579
|
|
|
590
580
|
/*****************************************************
|
|
@@ -592,29 +582,29 @@ struct QueryTables {
|
|
|
592
582
|
*****************************************************/
|
|
593
583
|
|
|
594
584
|
// field specific to query
|
|
595
|
-
const float
|
|
585
|
+
const float* qi;
|
|
596
586
|
|
|
597
|
-
// query-specific
|
|
598
|
-
void init_query
|
|
587
|
+
// query-specific initialization
|
|
588
|
+
void init_query(const float* qi) {
|
|
599
589
|
this->qi = qi;
|
|
600
590
|
if (metric_type == METRIC_INNER_PRODUCT)
|
|
601
|
-
init_query_IP
|
|
591
|
+
init_query_IP();
|
|
602
592
|
else
|
|
603
|
-
init_query_L2
|
|
593
|
+
init_query_L2();
|
|
604
594
|
if (!by_residual && polysemous_ht != 0)
|
|
605
|
-
pq.compute_code
|
|
595
|
+
pq.compute_code(qi, q_code.data());
|
|
606
596
|
}
|
|
607
597
|
|
|
608
|
-
void init_query_IP
|
|
598
|
+
void init_query_IP() {
|
|
609
599
|
// precompute some tables specific to the query qi
|
|
610
|
-
pq.compute_inner_prod_table
|
|
600
|
+
pq.compute_inner_prod_table(qi, sim_table);
|
|
611
601
|
}
|
|
612
602
|
|
|
613
|
-
void init_query_L2
|
|
603
|
+
void init_query_L2() {
|
|
614
604
|
if (!by_residual) {
|
|
615
|
-
pq.compute_distance_table
|
|
605
|
+
pq.compute_distance_table(qi, sim_table);
|
|
616
606
|
} else if (use_precomputed_table) {
|
|
617
|
-
pq.compute_inner_prod_table
|
|
607
|
+
pq.compute_inner_prod_table(qi, sim_table_2);
|
|
618
608
|
}
|
|
619
609
|
}
|
|
620
610
|
|
|
@@ -632,96 +622,95 @@ struct QueryTables {
|
|
|
632
622
|
/// once we know the query and the centroid, we can prepare the
|
|
633
623
|
/// sim_table that will be used for accumulation
|
|
634
624
|
/// and dis0, the initial value
|
|
635
|
-
float precompute_list_tables
|
|
625
|
+
float precompute_list_tables() {
|
|
636
626
|
float dis0 = 0;
|
|
637
|
-
uint64_t t0;
|
|
627
|
+
uint64_t t0;
|
|
628
|
+
TIC;
|
|
638
629
|
if (by_residual) {
|
|
639
630
|
if (metric_type == METRIC_INNER_PRODUCT)
|
|
640
|
-
dis0 = precompute_list_tables_IP
|
|
631
|
+
dis0 = precompute_list_tables_IP();
|
|
641
632
|
else
|
|
642
|
-
dis0 = precompute_list_tables_L2
|
|
633
|
+
dis0 = precompute_list_tables_L2();
|
|
643
634
|
}
|
|
644
635
|
init_list_cycles += TOC;
|
|
645
636
|
return dis0;
|
|
646
|
-
|
|
637
|
+
}
|
|
647
638
|
|
|
648
|
-
float precompute_list_table_pointers
|
|
639
|
+
float precompute_list_table_pointers() {
|
|
649
640
|
float dis0 = 0;
|
|
650
|
-
uint64_t t0;
|
|
641
|
+
uint64_t t0;
|
|
642
|
+
TIC;
|
|
651
643
|
if (by_residual) {
|
|
652
644
|
if (metric_type == METRIC_INNER_PRODUCT)
|
|
653
|
-
|
|
645
|
+
FAISS_THROW_MSG("not implemented");
|
|
654
646
|
else
|
|
655
|
-
|
|
647
|
+
dis0 = precompute_list_table_pointers_L2();
|
|
656
648
|
}
|
|
657
649
|
init_list_cycles += TOC;
|
|
658
650
|
return dis0;
|
|
659
|
-
|
|
651
|
+
}
|
|
660
652
|
|
|
661
653
|
/*****************************************************
|
|
662
654
|
* compute tables for inner prod
|
|
663
655
|
*****************************************************/
|
|
664
656
|
|
|
665
|
-
float precompute_list_tables_IP
|
|
666
|
-
{
|
|
657
|
+
float precompute_list_tables_IP() {
|
|
667
658
|
// prepare the sim_table that will be used for accumulation
|
|
668
659
|
// and dis0, the initial value
|
|
669
|
-
ivfpq.quantizer->reconstruct
|
|
660
|
+
ivfpq.quantizer->reconstruct(key, decoded_vec);
|
|
670
661
|
// decoded_vec = centroid
|
|
671
|
-
float dis0 = fvec_inner_product
|
|
662
|
+
float dis0 = fvec_inner_product(qi, decoded_vec, d);
|
|
672
663
|
|
|
673
664
|
if (polysemous_ht) {
|
|
674
665
|
for (int i = 0; i < d; i++) {
|
|
675
|
-
residual_vec
|
|
666
|
+
residual_vec[i] = qi[i] - decoded_vec[i];
|
|
676
667
|
}
|
|
677
|
-
pq.compute_code
|
|
668
|
+
pq.compute_code(residual_vec, q_code.data());
|
|
678
669
|
}
|
|
679
670
|
return dis0;
|
|
680
671
|
}
|
|
681
672
|
|
|
682
|
-
|
|
683
673
|
/*****************************************************
|
|
684
674
|
* compute tables for L2 distance
|
|
685
675
|
*****************************************************/
|
|
686
676
|
|
|
687
|
-
float precompute_list_tables_L2
|
|
688
|
-
{
|
|
677
|
+
float precompute_list_tables_L2() {
|
|
689
678
|
float dis0 = 0;
|
|
690
679
|
|
|
691
680
|
if (use_precomputed_table == 0 || use_precomputed_table == -1) {
|
|
692
|
-
ivfpq.quantizer->compute_residual
|
|
693
|
-
pq.compute_distance_table
|
|
681
|
+
ivfpq.quantizer->compute_residual(qi, residual_vec, key);
|
|
682
|
+
pq.compute_distance_table(residual_vec, sim_table);
|
|
694
683
|
|
|
695
684
|
if (polysemous_ht != 0) {
|
|
696
|
-
pq.compute_code
|
|
685
|
+
pq.compute_code(residual_vec, q_code.data());
|
|
697
686
|
}
|
|
698
687
|
|
|
699
688
|
} else if (use_precomputed_table == 1) {
|
|
700
689
|
dis0 = coarse_dis;
|
|
701
690
|
|
|
702
|
-
fvec_madd
|
|
691
|
+
fvec_madd(
|
|
703
692
|
pq.M * pq.ksub,
|
|
704
693
|
ivfpq.precomputed_table.data() + key * pq.ksub * pq.M,
|
|
705
|
-
-2.0,
|
|
706
|
-
|
|
707
|
-
|
|
694
|
+
-2.0,
|
|
695
|
+
sim_table_2,
|
|
696
|
+
sim_table);
|
|
708
697
|
|
|
709
698
|
if (polysemous_ht != 0) {
|
|
710
|
-
ivfpq.quantizer->compute_residual
|
|
711
|
-
pq.compute_code
|
|
699
|
+
ivfpq.quantizer->compute_residual(qi, residual_vec, key);
|
|
700
|
+
pq.compute_code(residual_vec, q_code.data());
|
|
712
701
|
}
|
|
713
702
|
|
|
714
703
|
} else if (use_precomputed_table == 2) {
|
|
715
704
|
dis0 = coarse_dis;
|
|
716
705
|
|
|
717
|
-
const MultiIndexQuantizer
|
|
718
|
-
|
|
719
|
-
FAISS_THROW_IF_NOT
|
|
720
|
-
const ProductQuantizer
|
|
706
|
+
const MultiIndexQuantizer* miq =
|
|
707
|
+
dynamic_cast<const MultiIndexQuantizer*>(ivfpq.quantizer);
|
|
708
|
+
FAISS_THROW_IF_NOT(miq);
|
|
709
|
+
const ProductQuantizer& cpq = miq->pq;
|
|
721
710
|
int Mf = pq.M / cpq.M;
|
|
722
711
|
|
|
723
|
-
const float
|
|
724
|
-
float
|
|
712
|
+
const float* qtab = sim_table_2; // query-specific table
|
|
713
|
+
float* ltab = sim_table; // (output) list-specific table
|
|
725
714
|
|
|
726
715
|
long k = key;
|
|
727
716
|
for (int cm = 0; cm < cpq.M; cm++) {
|
|
@@ -730,54 +719,48 @@ struct QueryTables {
|
|
|
730
719
|
k >>= cpq.nbits;
|
|
731
720
|
|
|
732
721
|
// get corresponding table
|
|
733
|
-
const float
|
|
734
|
-
|
|
722
|
+
const float* pc = ivfpq.precomputed_table.data() +
|
|
723
|
+
(ki * pq.M + cm * Mf) * pq.ksub;
|
|
735
724
|
|
|
736
725
|
if (polysemous_ht == 0) {
|
|
737
|
-
|
|
738
726
|
// sum up with query-specific table
|
|
739
|
-
fvec_madd
|
|
740
|
-
pc,
|
|
741
|
-
-2.0, qtab,
|
|
742
|
-
ltab);
|
|
727
|
+
fvec_madd(Mf * pq.ksub, pc, -2.0, qtab, ltab);
|
|
743
728
|
ltab += Mf * pq.ksub;
|
|
744
729
|
qtab += Mf * pq.ksub;
|
|
745
730
|
} else {
|
|
746
731
|
for (int m = cm * Mf; m < (cm + 1) * Mf; m++) {
|
|
747
|
-
q_code[m] = fvec_madd_and_argmin
|
|
748
|
-
|
|
732
|
+
q_code[m] = fvec_madd_and_argmin(
|
|
733
|
+
pq.ksub, pc, -2, qtab, ltab);
|
|
749
734
|
pc += pq.ksub;
|
|
750
735
|
ltab += pq.ksub;
|
|
751
736
|
qtab += pq.ksub;
|
|
752
737
|
}
|
|
753
738
|
}
|
|
754
|
-
|
|
755
739
|
}
|
|
756
740
|
}
|
|
757
741
|
|
|
758
742
|
return dis0;
|
|
759
743
|
}
|
|
760
744
|
|
|
761
|
-
float precompute_list_table_pointers_L2
|
|
762
|
-
{
|
|
745
|
+
float precompute_list_table_pointers_L2() {
|
|
763
746
|
float dis0 = 0;
|
|
764
747
|
|
|
765
748
|
if (use_precomputed_table == 1) {
|
|
766
749
|
dis0 = coarse_dis;
|
|
767
750
|
|
|
768
|
-
const float
|
|
769
|
-
key * pq.ksub * pq.M;
|
|
751
|
+
const float* s =
|
|
752
|
+
ivfpq.precomputed_table.data() + key * pq.ksub * pq.M;
|
|
770
753
|
for (int m = 0; m < pq.M; m++) {
|
|
771
|
-
sim_table_ptrs
|
|
754
|
+
sim_table_ptrs[m] = s;
|
|
772
755
|
s += pq.ksub;
|
|
773
756
|
}
|
|
774
757
|
} else if (use_precomputed_table == 2) {
|
|
775
758
|
dis0 = coarse_dis;
|
|
776
759
|
|
|
777
|
-
const MultiIndexQuantizer
|
|
778
|
-
|
|
779
|
-
FAISS_THROW_IF_NOT
|
|
780
|
-
const ProductQuantizer
|
|
760
|
+
const MultiIndexQuantizer* miq =
|
|
761
|
+
dynamic_cast<const MultiIndexQuantizer*>(ivfpq.quantizer);
|
|
762
|
+
FAISS_THROW_IF_NOT(miq);
|
|
763
|
+
const ProductQuantizer& cpq = miq->pq;
|
|
781
764
|
int Mf = pq.M / cpq.M;
|
|
782
765
|
|
|
783
766
|
long k = key;
|
|
@@ -786,21 +769,21 @@ struct QueryTables {
|
|
|
786
769
|
int ki = k & ((uint64_t(1) << cpq.nbits) - 1);
|
|
787
770
|
k >>= cpq.nbits;
|
|
788
771
|
|
|
789
|
-
const float
|
|
790
|
-
|
|
772
|
+
const float* pc = ivfpq.precomputed_table.data() +
|
|
773
|
+
(ki * pq.M + cm * Mf) * pq.ksub;
|
|
791
774
|
|
|
792
775
|
for (int m = m0; m < m0 + Mf; m++) {
|
|
793
|
-
sim_table_ptrs
|
|
776
|
+
sim_table_ptrs[m] = pc;
|
|
794
777
|
pc += pq.ksub;
|
|
795
778
|
}
|
|
796
779
|
m0 += Mf;
|
|
797
780
|
}
|
|
798
781
|
} else {
|
|
799
|
-
|
|
782
|
+
FAISS_THROW_MSG("need precomputed tables");
|
|
800
783
|
}
|
|
801
784
|
|
|
802
785
|
if (polysemous_ht) {
|
|
803
|
-
FAISS_THROW_MSG
|
|
786
|
+
FAISS_THROW_MSG("not implemented");
|
|
804
787
|
// Not clear that it makes sense to implemente this,
|
|
805
788
|
// because it costs M * ksub, which is what we wanted to
|
|
806
789
|
// avoid with the tables pointers.
|
|
@@ -808,82 +791,72 @@ struct QueryTables {
|
|
|
808
791
|
|
|
809
792
|
return dis0;
|
|
810
793
|
}
|
|
811
|
-
|
|
812
|
-
|
|
813
794
|
};
|
|
814
795
|
|
|
815
|
-
|
|
816
|
-
|
|
817
|
-
template<class C>
|
|
796
|
+
template <class C>
|
|
818
797
|
struct KnnSearchResults {
|
|
819
798
|
idx_t key;
|
|
820
|
-
const idx_t
|
|
799
|
+
const idx_t* ids;
|
|
821
800
|
|
|
822
801
|
// heap params
|
|
823
802
|
size_t k;
|
|
824
|
-
float
|
|
825
|
-
idx_t
|
|
803
|
+
float* heap_sim;
|
|
804
|
+
idx_t* heap_ids;
|
|
826
805
|
|
|
827
806
|
size_t nup;
|
|
828
807
|
|
|
829
|
-
inline void add
|
|
830
|
-
if (C::cmp
|
|
831
|
-
idx_t id = ids ? ids[j] : lo_build
|
|
832
|
-
heap_replace_top<C>
|
|
808
|
+
inline void add(idx_t j, float dis) {
|
|
809
|
+
if (C::cmp(heap_sim[0], dis)) {
|
|
810
|
+
idx_t id = ids ? ids[j] : lo_build(key, j);
|
|
811
|
+
heap_replace_top<C>(k, heap_sim, heap_ids, dis, id);
|
|
833
812
|
nup++;
|
|
834
813
|
}
|
|
835
814
|
}
|
|
836
|
-
|
|
837
815
|
};
|
|
838
816
|
|
|
839
|
-
template<class C>
|
|
817
|
+
template <class C>
|
|
840
818
|
struct RangeSearchResults {
|
|
841
819
|
idx_t key;
|
|
842
|
-
const idx_t
|
|
820
|
+
const idx_t* ids;
|
|
843
821
|
|
|
844
822
|
// wrapped result structure
|
|
845
823
|
float radius;
|
|
846
|
-
RangeQueryResult
|
|
824
|
+
RangeQueryResult& rres;
|
|
847
825
|
|
|
848
|
-
inline void add
|
|
849
|
-
if (C::cmp
|
|
850
|
-
idx_t id = ids ? ids[j] : lo_build
|
|
851
|
-
rres.add
|
|
826
|
+
inline void add(idx_t j, float dis) {
|
|
827
|
+
if (C::cmp(radius, dis)) {
|
|
828
|
+
idx_t id = ids ? ids[j] : lo_build(key, j);
|
|
829
|
+
rres.add(dis, id);
|
|
852
830
|
}
|
|
853
831
|
}
|
|
854
832
|
};
|
|
855
833
|
|
|
856
|
-
|
|
857
|
-
|
|
858
834
|
/*****************************************************
|
|
859
835
|
* Scaning the codes.
|
|
860
836
|
* The scanning functions call their favorite precompute_*
|
|
861
837
|
* function to precompute the tables they need.
|
|
862
838
|
*****************************************************/
|
|
863
839
|
template <typename IDType, MetricType METRIC_TYPE, class PQDecoder>
|
|
864
|
-
struct IVFPQScannerT: QueryTables {
|
|
865
|
-
|
|
866
|
-
const
|
|
867
|
-
const IDType * list_ids;
|
|
840
|
+
struct IVFPQScannerT : QueryTables {
|
|
841
|
+
const uint8_t* list_codes;
|
|
842
|
+
const IDType* list_ids;
|
|
868
843
|
size_t list_size;
|
|
869
844
|
|
|
870
|
-
IVFPQScannerT
|
|
871
|
-
|
|
872
|
-
{
|
|
845
|
+
IVFPQScannerT(const IndexIVFPQ& ivfpq, const IVFSearchParameters* params)
|
|
846
|
+
: QueryTables(ivfpq, params) {
|
|
873
847
|
assert(METRIC_TYPE == metric_type);
|
|
874
848
|
}
|
|
875
849
|
|
|
876
850
|
float dis0;
|
|
877
851
|
|
|
878
|
-
void init_list
|
|
879
|
-
int mode) {
|
|
852
|
+
void init_list(idx_t list_no, float coarse_dis, int mode) {
|
|
880
853
|
this->key = list_no;
|
|
881
854
|
this->coarse_dis = coarse_dis;
|
|
882
855
|
|
|
883
856
|
if (mode == 2) {
|
|
884
|
-
dis0 = precompute_list_tables
|
|
857
|
+
dis0 = precompute_list_tables();
|
|
885
858
|
} else if (mode == 1) {
|
|
886
|
-
dis0 = precompute_list_table_pointers
|
|
859
|
+
dis0 = precompute_list_table_pointers();
|
|
887
860
|
}
|
|
888
861
|
}
|
|
889
862
|
|
|
@@ -892,15 +865,16 @@ struct IVFPQScannerT: QueryTables {
|
|
|
892
865
|
*****************************************************/
|
|
893
866
|
|
|
894
867
|
/// version of the scan where we use precomputed tables
|
|
895
|
-
template<class SearchResultType>
|
|
896
|
-
void scan_list_with_table
|
|
897
|
-
|
|
898
|
-
|
|
868
|
+
template <class SearchResultType>
|
|
869
|
+
void scan_list_with_table(
|
|
870
|
+
size_t ncode,
|
|
871
|
+
const uint8_t* codes,
|
|
872
|
+
SearchResultType& res) const {
|
|
899
873
|
for (size_t j = 0; j < ncode; j++) {
|
|
900
874
|
PQDecoder decoder(codes, pq.nbits);
|
|
901
875
|
codes += pq.code_size;
|
|
902
876
|
float dis = dis0;
|
|
903
|
-
const float
|
|
877
|
+
const float* tab = sim_table;
|
|
904
878
|
|
|
905
879
|
for (size_t m = 0; m < pq.M; m++) {
|
|
906
880
|
dis += tab[decoder.decode()];
|
|
@@ -911,43 +885,43 @@ struct IVFPQScannerT: QueryTables {
|
|
|
911
885
|
}
|
|
912
886
|
}
|
|
913
887
|
|
|
914
|
-
|
|
915
888
|
/// tables are not precomputed, but pointers are provided to the
|
|
916
889
|
/// relevant X_c|x_r tables
|
|
917
|
-
template<class SearchResultType>
|
|
918
|
-
void scan_list_with_pointer
|
|
919
|
-
|
|
920
|
-
|
|
890
|
+
template <class SearchResultType>
|
|
891
|
+
void scan_list_with_pointer(
|
|
892
|
+
size_t ncode,
|
|
893
|
+
const uint8_t* codes,
|
|
894
|
+
SearchResultType& res) const {
|
|
921
895
|
for (size_t j = 0; j < ncode; j++) {
|
|
922
896
|
PQDecoder decoder(codes, pq.nbits);
|
|
923
897
|
codes += pq.code_size;
|
|
924
898
|
|
|
925
899
|
float dis = dis0;
|
|
926
|
-
const float
|
|
900
|
+
const float* tab = sim_table_2;
|
|
927
901
|
|
|
928
902
|
for (size_t m = 0; m < pq.M; m++) {
|
|
929
903
|
int ci = decoder.decode();
|
|
930
|
-
dis += sim_table_ptrs
|
|
904
|
+
dis += sim_table_ptrs[m][ci] - 2 * tab[ci];
|
|
931
905
|
tab += pq.ksub;
|
|
932
906
|
}
|
|
933
|
-
res.add
|
|
907
|
+
res.add(j, dis);
|
|
934
908
|
}
|
|
935
909
|
}
|
|
936
910
|
|
|
937
|
-
|
|
938
911
|
/// nothing is precomputed: access residuals on-the-fly
|
|
939
|
-
template<class SearchResultType>
|
|
940
|
-
void scan_on_the_fly_dist
|
|
941
|
-
|
|
942
|
-
|
|
943
|
-
|
|
912
|
+
template <class SearchResultType>
|
|
913
|
+
void scan_on_the_fly_dist(
|
|
914
|
+
size_t ncode,
|
|
915
|
+
const uint8_t* codes,
|
|
916
|
+
SearchResultType& res) const {
|
|
917
|
+
const float* dvec;
|
|
944
918
|
float dis0 = 0;
|
|
945
919
|
if (by_residual) {
|
|
946
920
|
if (METRIC_TYPE == METRIC_INNER_PRODUCT) {
|
|
947
|
-
ivfpq.quantizer->reconstruct
|
|
948
|
-
dis0 = fvec_inner_product
|
|
921
|
+
ivfpq.quantizer->reconstruct(key, residual_vec);
|
|
922
|
+
dis0 = fvec_inner_product(residual_vec, qi, d);
|
|
949
923
|
} else {
|
|
950
|
-
ivfpq.quantizer->compute_residual
|
|
924
|
+
ivfpq.quantizer->compute_residual(qi, residual_vec, key);
|
|
951
925
|
}
|
|
952
926
|
dvec = residual_vec;
|
|
953
927
|
} else {
|
|
@@ -956,17 +930,16 @@ struct IVFPQScannerT: QueryTables {
|
|
|
956
930
|
}
|
|
957
931
|
|
|
958
932
|
for (size_t j = 0; j < ncode; j++) {
|
|
959
|
-
|
|
960
|
-
pq.decode (codes, decoded_vec);
|
|
933
|
+
pq.decode(codes, decoded_vec);
|
|
961
934
|
codes += pq.code_size;
|
|
962
935
|
|
|
963
936
|
float dis;
|
|
964
937
|
if (METRIC_TYPE == METRIC_INNER_PRODUCT) {
|
|
965
|
-
dis = dis0 + fvec_inner_product
|
|
938
|
+
dis = dis0 + fvec_inner_product(decoded_vec, qi, d);
|
|
966
939
|
} else {
|
|
967
|
-
dis = fvec_L2sqr
|
|
940
|
+
dis = fvec_L2sqr(decoded_vec, dvec, d);
|
|
968
941
|
}
|
|
969
|
-
res.add
|
|
942
|
+
res.add(j, dis);
|
|
970
943
|
}
|
|
971
944
|
}
|
|
972
945
|
|
|
@@ -975,110 +948,99 @@ struct IVFPQScannerT: QueryTables {
|
|
|
975
948
|
*****************************************************/
|
|
976
949
|
|
|
977
950
|
template <class HammingComputer, class SearchResultType>
|
|
978
|
-
void scan_list_polysemous_hc
|
|
979
|
-
|
|
980
|
-
|
|
981
|
-
|
|
951
|
+
void scan_list_polysemous_hc(
|
|
952
|
+
size_t ncode,
|
|
953
|
+
const uint8_t* codes,
|
|
954
|
+
SearchResultType& res) const {
|
|
982
955
|
int ht = ivfpq.polysemous_ht;
|
|
983
956
|
size_t n_hamming_pass = 0, nup = 0;
|
|
984
957
|
|
|
985
958
|
int code_size = pq.code_size;
|
|
986
959
|
|
|
987
|
-
HammingComputer hc
|
|
960
|
+
HammingComputer hc(q_code.data(), code_size);
|
|
988
961
|
|
|
989
962
|
for (size_t j = 0; j < ncode; j++) {
|
|
990
|
-
const uint8_t
|
|
991
|
-
int hd = hc.hamming
|
|
963
|
+
const uint8_t* b_code = codes;
|
|
964
|
+
int hd = hc.hamming(b_code);
|
|
992
965
|
if (hd < ht) {
|
|
993
|
-
n_hamming_pass
|
|
966
|
+
n_hamming_pass++;
|
|
994
967
|
PQDecoder decoder(codes, pq.nbits);
|
|
995
968
|
|
|
996
969
|
float dis = dis0;
|
|
997
|
-
const float
|
|
970
|
+
const float* tab = sim_table;
|
|
998
971
|
|
|
999
972
|
for (size_t m = 0; m < pq.M; m++) {
|
|
1000
973
|
dis += tab[decoder.decode()];
|
|
1001
974
|
tab += pq.ksub;
|
|
1002
975
|
}
|
|
1003
976
|
|
|
1004
|
-
res.add
|
|
977
|
+
res.add(j, dis);
|
|
1005
978
|
}
|
|
1006
979
|
codes += code_size;
|
|
1007
980
|
}
|
|
1008
981
|
#pragma omp critical
|
|
1009
|
-
{
|
|
1010
|
-
indexIVFPQ_stats.n_hamming_pass += n_hamming_pass;
|
|
1011
|
-
}
|
|
982
|
+
{ indexIVFPQ_stats.n_hamming_pass += n_hamming_pass; }
|
|
1012
983
|
}
|
|
1013
984
|
|
|
1014
|
-
template<class SearchResultType>
|
|
1015
|
-
void scan_list_polysemous
|
|
1016
|
-
|
|
1017
|
-
|
|
1018
|
-
|
|
985
|
+
template <class SearchResultType>
|
|
986
|
+
void scan_list_polysemous(
|
|
987
|
+
size_t ncode,
|
|
988
|
+
const uint8_t* codes,
|
|
989
|
+
SearchResultType& res) const {
|
|
1019
990
|
switch (pq.code_size) {
|
|
1020
991
|
#define HANDLE_CODE_SIZE(cs) \
|
|
1021
|
-
|
|
1022
|
-
|
|
1023
|
-
|
|
1024
|
-
|
|
1025
|
-
|
|
1026
|
-
|
|
1027
|
-
|
|
1028
|
-
|
|
1029
|
-
|
|
1030
|
-
|
|
1031
|
-
HANDLE_CODE_SIZE(64);
|
|
992
|
+
case cs: \
|
|
993
|
+
scan_list_polysemous_hc<HammingComputer##cs, SearchResultType>( \
|
|
994
|
+
ncode, codes, res); \
|
|
995
|
+
break
|
|
996
|
+
HANDLE_CODE_SIZE(4);
|
|
997
|
+
HANDLE_CODE_SIZE(8);
|
|
998
|
+
HANDLE_CODE_SIZE(16);
|
|
999
|
+
HANDLE_CODE_SIZE(20);
|
|
1000
|
+
HANDLE_CODE_SIZE(32);
|
|
1001
|
+
HANDLE_CODE_SIZE(64);
|
|
1032
1002
|
#undef HANDLE_CODE_SIZE
|
|
1033
|
-
|
|
1034
|
-
|
|
1035
|
-
|
|
1036
|
-
|
|
1037
|
-
|
|
1038
|
-
else
|
|
1039
|
-
scan_list_polysemous_hc
|
|
1040
|
-
<HammingComputerM4, SearchResultType>
|
|
1041
|
-
(ncode, codes, res);
|
|
1042
|
-
break;
|
|
1003
|
+
default:
|
|
1004
|
+
scan_list_polysemous_hc<
|
|
1005
|
+
HammingComputerDefault,
|
|
1006
|
+
SearchResultType>(ncode, codes, res);
|
|
1007
|
+
break;
|
|
1043
1008
|
}
|
|
1044
1009
|
}
|
|
1045
|
-
|
|
1046
1010
|
};
|
|
1047
1011
|
|
|
1048
|
-
|
|
1049
1012
|
/* We put as many parameters as possible in template. Hopefully the
|
|
1050
1013
|
* gain in runtime is worth the code bloat. C is the comparator < or
|
|
1051
1014
|
* >, it is directly related to METRIC_TYPE. precompute_mode is how
|
|
1052
1015
|
* much we precompute (2 = precompute distance tables, 1 = precompute
|
|
1053
1016
|
* pointers to distances, 0 = compute distances one by one).
|
|
1054
1017
|
* Currently only 2 is supported */
|
|
1055
|
-
template<MetricType METRIC_TYPE, class C, class PQDecoder>
|
|
1056
|
-
struct IVFPQScanner:
|
|
1057
|
-
|
|
1058
|
-
InvertedListScanner
|
|
1059
|
-
{
|
|
1060
|
-
bool store_pairs;
|
|
1018
|
+
template <MetricType METRIC_TYPE, class C, class PQDecoder>
|
|
1019
|
+
struct IVFPQScanner : IVFPQScannerT<Index::idx_t, METRIC_TYPE, PQDecoder>,
|
|
1020
|
+
InvertedListScanner {
|
|
1061
1021
|
int precompute_mode;
|
|
1062
1022
|
|
|
1063
|
-
IVFPQScanner(const IndexIVFPQ
|
|
1064
|
-
|
|
1065
|
-
|
|
1066
|
-
|
|
1067
|
-
|
|
1023
|
+
IVFPQScanner(const IndexIVFPQ& ivfpq, bool store_pairs, int precompute_mode)
|
|
1024
|
+
: IVFPQScannerT<Index::idx_t, METRIC_TYPE, PQDecoder>(
|
|
1025
|
+
ivfpq,
|
|
1026
|
+
nullptr),
|
|
1027
|
+
precompute_mode(precompute_mode) {
|
|
1028
|
+
this->store_pairs = store_pairs;
|
|
1068
1029
|
}
|
|
1069
1030
|
|
|
1070
|
-
void set_query
|
|
1071
|
-
this->init_query
|
|
1031
|
+
void set_query(const float* query) override {
|
|
1032
|
+
this->init_query(query);
|
|
1072
1033
|
}
|
|
1073
1034
|
|
|
1074
|
-
void set_list
|
|
1075
|
-
this->
|
|
1035
|
+
void set_list(idx_t list_no, float coarse_dis) override {
|
|
1036
|
+
this->list_no = list_no;
|
|
1037
|
+
this->init_list(list_no, coarse_dis, precompute_mode);
|
|
1076
1038
|
}
|
|
1077
1039
|
|
|
1078
|
-
float distance_to_code
|
|
1040
|
+
float distance_to_code(const uint8_t* code) const override {
|
|
1079
1041
|
assert(precompute_mode == 2);
|
|
1080
1042
|
float dis = this->dis0;
|
|
1081
|
-
const float
|
|
1043
|
+
const float* tab = this->sim_table;
|
|
1082
1044
|
PQDecoder decoder(code, this->pq.nbits);
|
|
1083
1045
|
|
|
1084
1046
|
for (size_t m = 0; m < this->pq.M; m++) {
|
|
@@ -1088,112 +1050,100 @@ struct IVFPQScanner:
|
|
|
1088
1050
|
return dis;
|
|
1089
1051
|
}
|
|
1090
1052
|
|
|
1091
|
-
size_t scan_codes
|
|
1092
|
-
|
|
1093
|
-
|
|
1094
|
-
|
|
1095
|
-
|
|
1096
|
-
|
|
1053
|
+
size_t scan_codes(
|
|
1054
|
+
size_t ncode,
|
|
1055
|
+
const uint8_t* codes,
|
|
1056
|
+
const idx_t* ids,
|
|
1057
|
+
float* heap_sim,
|
|
1058
|
+
idx_t* heap_ids,
|
|
1059
|
+
size_t k) const override {
|
|
1097
1060
|
KnnSearchResults<C> res = {
|
|
1098
|
-
|
|
1099
|
-
|
|
1100
|
-
|
|
1101
|
-
|
|
1102
|
-
|
|
1103
|
-
|
|
1104
|
-
};
|
|
1061
|
+
/* key */ this->key,
|
|
1062
|
+
/* ids */ this->store_pairs ? nullptr : ids,
|
|
1063
|
+
/* k */ k,
|
|
1064
|
+
/* heap_sim */ heap_sim,
|
|
1065
|
+
/* heap_ids */ heap_ids,
|
|
1066
|
+
/* nup */ 0};
|
|
1105
1067
|
|
|
1106
1068
|
if (this->polysemous_ht > 0) {
|
|
1107
1069
|
assert(precompute_mode == 2);
|
|
1108
|
-
this->scan_list_polysemous
|
|
1070
|
+
this->scan_list_polysemous(ncode, codes, res);
|
|
1109
1071
|
} else if (precompute_mode == 2) {
|
|
1110
|
-
this->scan_list_with_table
|
|
1072
|
+
this->scan_list_with_table(ncode, codes, res);
|
|
1111
1073
|
} else if (precompute_mode == 1) {
|
|
1112
|
-
this->scan_list_with_pointer
|
|
1074
|
+
this->scan_list_with_pointer(ncode, codes, res);
|
|
1113
1075
|
} else if (precompute_mode == 0) {
|
|
1114
|
-
this->scan_on_the_fly_dist
|
|
1076
|
+
this->scan_on_the_fly_dist(ncode, codes, res);
|
|
1115
1077
|
} else {
|
|
1116
1078
|
FAISS_THROW_MSG("bad precomp mode");
|
|
1117
1079
|
}
|
|
1118
1080
|
return res.nup;
|
|
1119
1081
|
}
|
|
1120
1082
|
|
|
1121
|
-
void scan_codes_range
|
|
1122
|
-
|
|
1123
|
-
|
|
1124
|
-
|
|
1125
|
-
|
|
1126
|
-
|
|
1083
|
+
void scan_codes_range(
|
|
1084
|
+
size_t ncode,
|
|
1085
|
+
const uint8_t* codes,
|
|
1086
|
+
const idx_t* ids,
|
|
1087
|
+
float radius,
|
|
1088
|
+
RangeQueryResult& rres) const override {
|
|
1127
1089
|
RangeSearchResults<C> res = {
|
|
1128
|
-
|
|
1129
|
-
|
|
1130
|
-
|
|
1131
|
-
|
|
1132
|
-
};
|
|
1090
|
+
/* key */ this->key,
|
|
1091
|
+
/* ids */ this->store_pairs ? nullptr : ids,
|
|
1092
|
+
/* radius */ radius,
|
|
1093
|
+
/* rres */ rres};
|
|
1133
1094
|
|
|
1134
1095
|
if (this->polysemous_ht > 0) {
|
|
1135
1096
|
assert(precompute_mode == 2);
|
|
1136
|
-
this->scan_list_polysemous
|
|
1097
|
+
this->scan_list_polysemous(ncode, codes, res);
|
|
1137
1098
|
} else if (precompute_mode == 2) {
|
|
1138
|
-
this->scan_list_with_table
|
|
1099
|
+
this->scan_list_with_table(ncode, codes, res);
|
|
1139
1100
|
} else if (precompute_mode == 1) {
|
|
1140
|
-
this->scan_list_with_pointer
|
|
1101
|
+
this->scan_list_with_pointer(ncode, codes, res);
|
|
1141
1102
|
} else if (precompute_mode == 0) {
|
|
1142
|
-
this->scan_on_the_fly_dist
|
|
1103
|
+
this->scan_on_the_fly_dist(ncode, codes, res);
|
|
1143
1104
|
} else {
|
|
1144
1105
|
FAISS_THROW_MSG("bad precomp mode");
|
|
1145
1106
|
}
|
|
1146
|
-
|
|
1147
1107
|
}
|
|
1148
1108
|
};
|
|
1149
1109
|
|
|
1150
|
-
template<class PQDecoder>
|
|
1151
|
-
InvertedListScanner
|
|
1152
|
-
|
|
1153
|
-
{
|
|
1154
|
-
|
|
1155
|
-
|
|
1156
|
-
|
|
1157
|
-
|
|
1158
|
-
|
|
1110
|
+
template <class PQDecoder>
|
|
1111
|
+
InvertedListScanner* get_InvertedListScanner1(
|
|
1112
|
+
const IndexIVFPQ& index,
|
|
1113
|
+
bool store_pairs) {
|
|
1114
|
+
if (index.metric_type == METRIC_INNER_PRODUCT) {
|
|
1115
|
+
return new IVFPQScanner<
|
|
1116
|
+
METRIC_INNER_PRODUCT,
|
|
1117
|
+
CMin<float, idx_t>,
|
|
1118
|
+
PQDecoder>(index, store_pairs, 2);
|
|
1159
1119
|
} else if (index.metric_type == METRIC_L2) {
|
|
1160
|
-
return new IVFPQScanner
|
|
1161
|
-
|
|
1162
|
-
(index, store_pairs, 2);
|
|
1120
|
+
return new IVFPQScanner<METRIC_L2, CMax<float, idx_t>, PQDecoder>(
|
|
1121
|
+
index, store_pairs, 2);
|
|
1163
1122
|
}
|
|
1164
1123
|
return nullptr;
|
|
1165
1124
|
}
|
|
1166
1125
|
|
|
1167
|
-
|
|
1168
1126
|
} // anonymous namespace
|
|
1169
1127
|
|
|
1170
|
-
InvertedListScanner
|
|
1171
|
-
|
|
1172
|
-
{
|
|
1173
|
-
|
|
1128
|
+
InvertedListScanner* IndexIVFPQ::get_InvertedListScanner(
|
|
1129
|
+
bool store_pairs) const {
|
|
1174
1130
|
if (pq.nbits == 8) {
|
|
1175
|
-
return get_InvertedListScanner1<PQDecoder8>
|
|
1131
|
+
return get_InvertedListScanner1<PQDecoder8>(*this, store_pairs);
|
|
1176
1132
|
} else if (pq.nbits == 16) {
|
|
1177
|
-
return get_InvertedListScanner1<PQDecoder16>
|
|
1133
|
+
return get_InvertedListScanner1<PQDecoder16>(*this, store_pairs);
|
|
1178
1134
|
} else {
|
|
1179
|
-
return get_InvertedListScanner1<PQDecoderGeneric>
|
|
1135
|
+
return get_InvertedListScanner1<PQDecoderGeneric>(*this, store_pairs);
|
|
1180
1136
|
}
|
|
1181
1137
|
return nullptr;
|
|
1182
|
-
|
|
1183
1138
|
}
|
|
1184
1139
|
|
|
1185
|
-
|
|
1186
|
-
|
|
1187
1140
|
IndexIVFPQStats indexIVFPQ_stats;
|
|
1188
1141
|
|
|
1189
|
-
void IndexIVFPQStats::reset
|
|
1190
|
-
memset
|
|
1142
|
+
void IndexIVFPQStats::reset() {
|
|
1143
|
+
memset(this, 0, sizeof(*this));
|
|
1191
1144
|
}
|
|
1192
1145
|
|
|
1193
|
-
|
|
1194
|
-
|
|
1195
|
-
IndexIVFPQ::IndexIVFPQ ()
|
|
1196
|
-
{
|
|
1146
|
+
IndexIVFPQ::IndexIVFPQ() {
|
|
1197
1147
|
// initialize some runtime values
|
|
1198
1148
|
use_precomputed_table = 0;
|
|
1199
1149
|
scan_table_threshold = 0;
|
|
@@ -1202,43 +1152,40 @@ IndexIVFPQ::IndexIVFPQ ()
|
|
|
1202
1152
|
polysemous_training = nullptr;
|
|
1203
1153
|
}
|
|
1204
1154
|
|
|
1205
|
-
|
|
1206
1155
|
struct CodeCmp {
|
|
1207
|
-
const uint8_t
|
|
1156
|
+
const uint8_t* tab;
|
|
1208
1157
|
size_t code_size;
|
|
1209
|
-
bool operator
|
|
1210
|
-
return cmp
|
|
1158
|
+
bool operator()(int a, int b) const {
|
|
1159
|
+
return cmp(a, b) > 0;
|
|
1211
1160
|
}
|
|
1212
|
-
int cmp
|
|
1213
|
-
return memcmp
|
|
1214
|
-
code_size);
|
|
1161
|
+
int cmp(int a, int b) const {
|
|
1162
|
+
return memcmp(tab + a * code_size, tab + b * code_size, code_size);
|
|
1215
1163
|
}
|
|
1216
1164
|
};
|
|
1217
1165
|
|
|
1218
|
-
|
|
1219
|
-
size_t IndexIVFPQ::find_duplicates (idx_t *dup_ids, size_t *lims) const
|
|
1220
|
-
{
|
|
1166
|
+
size_t IndexIVFPQ::find_duplicates(idx_t* dup_ids, size_t* lims) const {
|
|
1221
1167
|
size_t ngroup = 0;
|
|
1222
1168
|
lims[0] = 0;
|
|
1223
1169
|
for (size_t list_no = 0; list_no < nlist; list_no++) {
|
|
1224
|
-
size_t n = invlists->list_size
|
|
1225
|
-
std::vector<int> ord
|
|
1226
|
-
for (int i = 0; i < n; i++)
|
|
1227
|
-
|
|
1228
|
-
|
|
1229
|
-
|
|
1230
|
-
|
|
1231
|
-
|
|
1232
|
-
|
|
1170
|
+
size_t n = invlists->list_size(list_no);
|
|
1171
|
+
std::vector<int> ord(n);
|
|
1172
|
+
for (int i = 0; i < n; i++)
|
|
1173
|
+
ord[i] = i;
|
|
1174
|
+
InvertedLists::ScopedCodes codes(invlists, list_no);
|
|
1175
|
+
CodeCmp cs = {codes.get(), code_size};
|
|
1176
|
+
std::sort(ord.begin(), ord.end(), cs);
|
|
1177
|
+
|
|
1178
|
+
InvertedLists::ScopedIds list_ids(invlists, list_no);
|
|
1179
|
+
int prev = -1; // all elements from prev to i-1 are equal
|
|
1233
1180
|
for (int i = 0; i < n; i++) {
|
|
1234
|
-
if (prev >= 0 && cs.cmp
|
|
1181
|
+
if (prev >= 0 && cs.cmp(ord[prev], ord[i]) == 0) {
|
|
1235
1182
|
// same as previous => remember
|
|
1236
1183
|
if (prev + 1 == i) { // start new group
|
|
1237
1184
|
ngroup++;
|
|
1238
1185
|
lims[ngroup] = lims[ngroup - 1];
|
|
1239
|
-
dup_ids
|
|
1186
|
+
dup_ids[lims[ngroup]++] = list_ids[ord[prev]];
|
|
1240
1187
|
}
|
|
1241
|
-
dup_ids
|
|
1188
|
+
dup_ids[lims[ngroup]++] = list_ids[ord[i]];
|
|
1242
1189
|
} else { // not same as previous.
|
|
1243
1190
|
prev = i;
|
|
1244
1191
|
}
|
|
@@ -1247,9 +1194,4 @@ size_t IndexIVFPQ::find_duplicates (idx_t *dup_ids, size_t *lims) const
|
|
|
1247
1194
|
return ngroup;
|
|
1248
1195
|
}
|
|
1249
1196
|
|
|
1250
|
-
|
|
1251
|
-
|
|
1252
|
-
|
|
1253
|
-
|
|
1254
|
-
|
|
1255
1197
|
} // namespace faiss
|