faiss 0.2.0 → 0.2.4
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +16 -0
- data/LICENSE.txt +1 -1
- data/README.md +7 -7
- data/ext/faiss/extconf.rb +6 -3
- data/ext/faiss/numo.hpp +4 -4
- data/ext/faiss/utils.cpp +1 -1
- data/ext/faiss/utils.h +1 -1
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/faiss/AutoTune.cpp +292 -291
- data/vendor/faiss/faiss/AutoTune.h +55 -56
- data/vendor/faiss/faiss/Clustering.cpp +365 -194
- data/vendor/faiss/faiss/Clustering.h +102 -35
- data/vendor/faiss/faiss/IVFlib.cpp +171 -195
- data/vendor/faiss/faiss/IVFlib.h +48 -51
- data/vendor/faiss/faiss/Index.cpp +85 -103
- data/vendor/faiss/faiss/Index.h +54 -48
- data/vendor/faiss/faiss/Index2Layer.cpp +126 -224
- data/vendor/faiss/faiss/Index2Layer.h +22 -36
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +407 -0
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +195 -0
- data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
- data/vendor/faiss/faiss/IndexBinary.h +140 -132
- data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
- data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
- data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
- data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
- data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
- data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
- data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
- data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
- data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
- data/vendor/faiss/faiss/IndexFlat.cpp +115 -176
- data/vendor/faiss/faiss/IndexFlat.h +42 -59
- data/vendor/faiss/faiss/IndexFlatCodes.cpp +67 -0
- data/vendor/faiss/faiss/IndexFlatCodes.h +47 -0
- data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
- data/vendor/faiss/faiss/IndexHNSW.h +57 -41
- data/vendor/faiss/faiss/IndexIVF.cpp +545 -453
- data/vendor/faiss/faiss/IndexIVF.h +169 -118
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +316 -0
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +121 -0
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +247 -252
- data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +459 -517
- data/vendor/faiss/faiss/IndexIVFPQ.h +75 -67
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
- data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
- data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
- data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
- data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +163 -150
- data/vendor/faiss/faiss/IndexIVFSpectralHash.h +38 -25
- data/vendor/faiss/faiss/IndexLSH.cpp +66 -113
- data/vendor/faiss/faiss/IndexLSH.h +20 -38
- data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
- data/vendor/faiss/faiss/IndexLattice.h +11 -16
- data/vendor/faiss/faiss/IndexNNDescent.cpp +229 -0
- data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
- data/vendor/faiss/faiss/IndexNSG.cpp +301 -0
- data/vendor/faiss/faiss/IndexNSG.h +85 -0
- data/vendor/faiss/faiss/IndexPQ.cpp +387 -495
- data/vendor/faiss/faiss/IndexPQ.h +64 -82
- data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
- data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
- data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
- data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
- data/vendor/faiss/faiss/IndexRefine.cpp +139 -127
- data/vendor/faiss/faiss/IndexRefine.h +32 -23
- data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
- data/vendor/faiss/faiss/IndexReplicas.h +62 -56
- data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +111 -172
- data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -59
- data/vendor/faiss/faiss/IndexShards.cpp +256 -240
- data/vendor/faiss/faiss/IndexShards.h +85 -73
- data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
- data/vendor/faiss/faiss/MatrixStats.h +7 -10
- data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
- data/vendor/faiss/faiss/MetaIndexes.h +40 -34
- data/vendor/faiss/faiss/MetricType.h +7 -7
- data/vendor/faiss/faiss/VectorTransform.cpp +654 -475
- data/vendor/faiss/faiss/VectorTransform.h +64 -89
- data/vendor/faiss/faiss/clone_index.cpp +78 -73
- data/vendor/faiss/faiss/clone_index.h +4 -9
- data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
- data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
- data/vendor/faiss/faiss/gpu/GpuCloner.cpp +198 -171
- data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
- data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
- data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
- data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
- data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
- data/vendor/faiss/faiss/gpu/GpuIcmEncoder.h +60 -0
- data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
- data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
- data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
- data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
- data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
- data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
- data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
- data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
- data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
- data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
- data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
- data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
- data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
- data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
- data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
- data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
- data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
- data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
- data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
- data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
- data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
- data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
- data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
- data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
- data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
- data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
- data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
- data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
- data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
- data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
- data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
- data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +503 -0
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +175 -0
- data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
- data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
- data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
- data/vendor/faiss/faiss/impl/FaissException.h +41 -29
- data/vendor/faiss/faiss/impl/HNSW.cpp +606 -617
- data/vendor/faiss/faiss/impl/HNSW.h +179 -200
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +855 -0
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +244 -0
- data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
- data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
- data/vendor/faiss/faiss/impl/NSG.cpp +679 -0
- data/vendor/faiss/faiss/impl/NSG.h +199 -0
- data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
- data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
- data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
- data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
- data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
- data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +758 -0
- data/vendor/faiss/faiss/impl/ResidualQuantizer.h +188 -0
- data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +647 -707
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
- data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
- data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
- data/vendor/faiss/faiss/impl/index_read.cpp +631 -480
- data/vendor/faiss/faiss/impl/index_write.cpp +547 -407
- data/vendor/faiss/faiss/impl/io.cpp +76 -95
- data/vendor/faiss/faiss/impl/io.h +31 -41
- data/vendor/faiss/faiss/impl/io_macros.h +60 -29
- data/vendor/faiss/faiss/impl/kmeans1d.cpp +301 -0
- data/vendor/faiss/faiss/impl/kmeans1d.h +48 -0
- data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
- data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
- data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
- data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
- data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
- data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
- data/vendor/faiss/faiss/index_factory.cpp +619 -397
- data/vendor/faiss/faiss/index_factory.h +8 -6
- data/vendor/faiss/faiss/index_io.h +23 -26
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
- data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
- data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
- data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
- data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
- data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
- data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
- data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
- data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
- data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
- data/vendor/faiss/faiss/utils/Heap.h +186 -209
- data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
- data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
- data/vendor/faiss/faiss/utils/distances.cpp +305 -312
- data/vendor/faiss/faiss/utils/distances.h +170 -122
- data/vendor/faiss/faiss/utils/distances_simd.cpp +498 -508
- data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
- data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
- data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
- data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
- data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
- data/vendor/faiss/faiss/utils/hamming.h +62 -85
- data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
- data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
- data/vendor/faiss/faiss/utils/partitioning.h +26 -21
- data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
- data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
- data/vendor/faiss/faiss/utils/random.cpp +39 -63
- data/vendor/faiss/faiss/utils/random.h +13 -16
- data/vendor/faiss/faiss/utils/simdlib.h +4 -2
- data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
- data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
- data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
- data/vendor/faiss/faiss/utils/utils.cpp +304 -287
- data/vendor/faiss/faiss/utils/utils.h +54 -49
- metadata +29 -4
@@ -9,17 +9,17 @@
|
|
9
9
|
|
10
10
|
#include <faiss/IndexIVFPQ.h>
|
11
11
|
|
12
|
+
#include <stdint.h>
|
13
|
+
#include <cassert>
|
12
14
|
#include <cinttypes>
|
13
15
|
#include <cmath>
|
14
16
|
#include <cstdio>
|
15
|
-
#include <cassert>
|
16
|
-
#include <stdint.h>
|
17
17
|
|
18
18
|
#include <algorithm>
|
19
19
|
|
20
20
|
#include <faiss/utils/Heap.h>
|
21
|
-
#include <faiss/utils/utils.h>
|
22
21
|
#include <faiss/utils/distances.h>
|
22
|
+
#include <faiss/utils/utils.h>
|
23
23
|
|
24
24
|
#include <faiss/Clustering.h>
|
25
25
|
#include <faiss/IndexFlat.h>
|
@@ -36,12 +36,15 @@ namespace faiss {
|
|
36
36
|
* IndexIVFPQ implementation
|
37
37
|
******************************************/
|
38
38
|
|
39
|
-
IndexIVFPQ::IndexIVFPQ
|
40
|
-
|
41
|
-
|
42
|
-
|
43
|
-
|
44
|
-
|
39
|
+
IndexIVFPQ::IndexIVFPQ(
|
40
|
+
Index* quantizer,
|
41
|
+
size_t d,
|
42
|
+
size_t nlist,
|
43
|
+
size_t M,
|
44
|
+
size_t nbits_per_idx,
|
45
|
+
MetricType metric)
|
46
|
+
: IndexIVF(quantizer, d, nlist, 0, metric), pq(d, M, nbits_per_idx) {
|
47
|
+
FAISS_THROW_IF_NOT(nbits_per_idx <= 8);
|
45
48
|
code_size = pq.code_size;
|
46
49
|
invlists->code_size = code_size;
|
47
50
|
is_trained = false;
|
@@ -52,202 +55,197 @@ IndexIVFPQ::IndexIVFPQ (Index * quantizer, size_t d, size_t nlist,
|
|
52
55
|
polysemous_training = nullptr;
|
53
56
|
do_polysemous_training = false;
|
54
57
|
polysemous_ht = 0;
|
55
|
-
|
56
58
|
}
|
57
59
|
|
58
|
-
|
59
60
|
/****************************************************************
|
60
61
|
* training */
|
61
62
|
|
62
|
-
void IndexIVFPQ::train_residual
|
63
|
-
|
64
|
-
train_residual_o (n, x, nullptr);
|
63
|
+
void IndexIVFPQ::train_residual(idx_t n, const float* x) {
|
64
|
+
train_residual_o(n, x, nullptr);
|
65
65
|
}
|
66
66
|
|
67
|
+
void IndexIVFPQ::train_residual_o(idx_t n, const float* x, float* residuals_2) {
|
68
|
+
const float* x_in = x;
|
67
69
|
|
68
|
-
|
69
|
-
|
70
|
-
|
71
|
-
|
72
|
-
|
73
|
-
|
74
|
-
|
70
|
+
x = fvecs_maybe_subsample(
|
71
|
+
d,
|
72
|
+
(size_t*)&n,
|
73
|
+
pq.cp.max_points_per_centroid * pq.ksub,
|
74
|
+
x,
|
75
|
+
verbose,
|
76
|
+
pq.cp.seed);
|
75
77
|
|
76
|
-
ScopeDeleter<float> del_x
|
78
|
+
ScopeDeleter<float> del_x(x_in == x ? nullptr : x);
|
77
79
|
|
78
|
-
const float
|
80
|
+
const float* trainset;
|
79
81
|
ScopeDeleter<float> del_residuals;
|
80
82
|
if (by_residual) {
|
81
|
-
if(verbose)
|
82
|
-
|
83
|
-
|
84
|
-
|
85
|
-
|
86
|
-
|
83
|
+
if (verbose)
|
84
|
+
printf("computing residuals\n");
|
85
|
+
idx_t* assign = new idx_t[n]; // assignement to coarse centroids
|
86
|
+
ScopeDeleter<idx_t> del(assign);
|
87
|
+
quantizer->assign(n, x, assign);
|
88
|
+
float* residuals = new float[n * d];
|
89
|
+
del_residuals.set(residuals);
|
87
90
|
for (idx_t i = 0; i < n; i++)
|
88
|
-
|
91
|
+
quantizer->compute_residual(
|
92
|
+
x + i * d, residuals + i * d, assign[i]);
|
89
93
|
|
90
94
|
trainset = residuals;
|
91
95
|
} else {
|
92
96
|
trainset = x;
|
93
97
|
}
|
94
98
|
if (verbose)
|
95
|
-
printf
|
96
|
-
|
99
|
+
printf("training %zdx%zd product quantizer on %" PRId64
|
100
|
+
" vectors in %dD\n",
|
101
|
+
pq.M,
|
102
|
+
pq.ksub,
|
103
|
+
n,
|
104
|
+
d);
|
97
105
|
pq.verbose = verbose;
|
98
|
-
pq.train
|
106
|
+
pq.train(n, trainset);
|
99
107
|
|
100
108
|
if (do_polysemous_training) {
|
101
109
|
if (verbose)
|
102
110
|
printf("doing polysemous training for PQ\n");
|
103
111
|
PolysemousTraining default_pt;
|
104
|
-
PolysemousTraining
|
105
|
-
if (!pt)
|
106
|
-
|
112
|
+
PolysemousTraining* pt = polysemous_training;
|
113
|
+
if (!pt)
|
114
|
+
pt = &default_pt;
|
115
|
+
pt->optimize_pq_for_hamming(pq, n, trainset);
|
107
116
|
}
|
108
117
|
|
109
118
|
// prepare second-level residuals for refine PQ
|
110
119
|
if (residuals_2) {
|
111
|
-
uint8_t
|
112
|
-
ScopeDeleter<uint8_t> del
|
113
|
-
pq.compute_codes
|
120
|
+
uint8_t* train_codes = new uint8_t[pq.code_size * n];
|
121
|
+
ScopeDeleter<uint8_t> del(train_codes);
|
122
|
+
pq.compute_codes(trainset, train_codes, n);
|
114
123
|
|
115
124
|
for (idx_t i = 0; i < n; i++) {
|
116
|
-
const float
|
117
|
-
float
|
118
|
-
pq.decode
|
125
|
+
const float* xx = trainset + i * d;
|
126
|
+
float* res = residuals_2 + i * d;
|
127
|
+
pq.decode(train_codes + i * pq.code_size, res);
|
119
128
|
for (int j = 0; j < d; j++)
|
120
129
|
res[j] = xx[j] - res[j];
|
121
130
|
}
|
122
|
-
|
123
131
|
}
|
124
132
|
|
125
133
|
if (by_residual) {
|
126
|
-
precompute_table
|
134
|
+
precompute_table();
|
127
135
|
}
|
128
|
-
|
129
136
|
}
|
130
137
|
|
131
|
-
|
132
|
-
|
133
|
-
|
134
|
-
|
135
|
-
|
136
138
|
/****************************************************************
|
137
139
|
* IVFPQ as codec */
|
138
140
|
|
139
|
-
|
140
141
|
/* produce a binary signature based on the residual vector */
|
141
|
-
void IndexIVFPQ::encode
|
142
|
-
{
|
142
|
+
void IndexIVFPQ::encode(idx_t key, const float* x, uint8_t* code) const {
|
143
143
|
if (by_residual) {
|
144
144
|
std::vector<float> residual_vec(d);
|
145
|
-
quantizer->compute_residual
|
146
|
-
pq.compute_code
|
147
|
-
}
|
148
|
-
|
145
|
+
quantizer->compute_residual(x, residual_vec.data(), key);
|
146
|
+
pq.compute_code(residual_vec.data(), code);
|
147
|
+
} else
|
148
|
+
pq.compute_code(x, code);
|
149
149
|
}
|
150
150
|
|
151
|
-
void IndexIVFPQ::encode_multiple
|
152
|
-
|
153
|
-
|
154
|
-
|
151
|
+
void IndexIVFPQ::encode_multiple(
|
152
|
+
size_t n,
|
153
|
+
idx_t* keys,
|
154
|
+
const float* x,
|
155
|
+
uint8_t* xcodes,
|
156
|
+
bool compute_keys) const {
|
155
157
|
if (compute_keys)
|
156
|
-
quantizer->assign
|
158
|
+
quantizer->assign(n, x, keys);
|
157
159
|
|
158
|
-
encode_vectors
|
160
|
+
encode_vectors(n, x, keys, xcodes);
|
159
161
|
}
|
160
162
|
|
161
|
-
void IndexIVFPQ::decode_multiple
|
162
|
-
|
163
|
-
|
164
|
-
|
163
|
+
void IndexIVFPQ::decode_multiple(
|
164
|
+
size_t n,
|
165
|
+
const idx_t* keys,
|
166
|
+
const uint8_t* xcodes,
|
167
|
+
float* x) const {
|
168
|
+
pq.decode(xcodes, x, n);
|
165
169
|
if (by_residual) {
|
166
|
-
std::vector<float> centroid
|
170
|
+
std::vector<float> centroid(d);
|
167
171
|
for (size_t i = 0; i < n; i++) {
|
168
|
-
quantizer->reconstruct
|
169
|
-
float
|
172
|
+
quantizer->reconstruct(keys[i], centroid.data());
|
173
|
+
float* xi = x + i * d;
|
170
174
|
for (size_t j = 0; j < d; j++) {
|
171
|
-
xi
|
175
|
+
xi[j] += centroid[j];
|
172
176
|
}
|
173
177
|
}
|
174
178
|
}
|
175
179
|
}
|
176
180
|
|
177
|
-
|
178
|
-
|
179
|
-
|
180
181
|
/****************************************************************
|
181
182
|
* add */
|
182
183
|
|
183
|
-
|
184
|
-
|
185
|
-
|
186
|
-
|
184
|
+
void IndexIVFPQ::add_core(
|
185
|
+
idx_t n,
|
186
|
+
const float* x,
|
187
|
+
const idx_t* xids,
|
188
|
+
const idx_t* coarse_idx) {
|
189
|
+
add_core_o(n, x, xids, nullptr, coarse_idx);
|
187
190
|
}
|
188
191
|
|
189
|
-
|
190
|
-
|
191
|
-
|
192
|
-
|
193
|
-
const Index::idx_t
|
194
|
-
{
|
192
|
+
static float* compute_residuals(
|
193
|
+
const Index* quantizer,
|
194
|
+
Index::idx_t n,
|
195
|
+
const float* x,
|
196
|
+
const Index::idx_t* list_nos) {
|
195
197
|
size_t d = quantizer->d;
|
196
|
-
float
|
198
|
+
float* residuals = new float[n * d];
|
197
199
|
// TODO: parallelize?
|
198
200
|
for (size_t i = 0; i < n; i++) {
|
199
201
|
if (list_nos[i] < 0)
|
200
|
-
memset
|
202
|
+
memset(residuals + i * d, 0, sizeof(*residuals) * d);
|
201
203
|
else
|
202
|
-
quantizer->compute_residual
|
203
|
-
|
204
|
+
quantizer->compute_residual(
|
205
|
+
x + i * d, residuals + i * d, list_nos[i]);
|
204
206
|
}
|
205
207
|
return residuals;
|
206
208
|
}
|
207
209
|
|
208
|
-
void IndexIVFPQ::encode_vectors(
|
209
|
-
|
210
|
-
|
211
|
-
|
212
|
-
|
210
|
+
void IndexIVFPQ::encode_vectors(
|
211
|
+
idx_t n,
|
212
|
+
const float* x,
|
213
|
+
const idx_t* list_nos,
|
214
|
+
uint8_t* codes,
|
215
|
+
bool include_listnos) const {
|
213
216
|
if (by_residual) {
|
214
|
-
float
|
215
|
-
ScopeDeleter<float> del
|
216
|
-
pq.compute_codes
|
217
|
+
float* to_encode = compute_residuals(quantizer, n, x, list_nos);
|
218
|
+
ScopeDeleter<float> del(to_encode);
|
219
|
+
pq.compute_codes(to_encode, codes, n);
|
217
220
|
} else {
|
218
|
-
pq.compute_codes
|
221
|
+
pq.compute_codes(x, codes, n);
|
219
222
|
}
|
220
223
|
|
221
224
|
if (include_listnos) {
|
222
225
|
size_t coarse_size = coarse_code_size();
|
223
226
|
for (idx_t i = n - 1; i >= 0; i--) {
|
224
|
-
uint8_t
|
225
|
-
memmove
|
226
|
-
|
227
|
-
encode_listno (list_nos[i], code);
|
227
|
+
uint8_t* code = codes + i * (coarse_size + code_size);
|
228
|
+
memmove(code + coarse_size, codes + i * code_size, code_size);
|
229
|
+
encode_listno(list_nos[i], code);
|
228
230
|
}
|
229
231
|
}
|
230
232
|
}
|
231
233
|
|
232
|
-
|
233
|
-
|
234
|
-
void IndexIVFPQ::sa_decode (idx_t n, const uint8_t *codes,
|
235
|
-
float *x) const
|
236
|
-
{
|
237
|
-
size_t coarse_size = coarse_code_size ();
|
234
|
+
void IndexIVFPQ::sa_decode(idx_t n, const uint8_t* codes, float* x) const {
|
235
|
+
size_t coarse_size = coarse_code_size();
|
238
236
|
|
239
237
|
#pragma omp parallel
|
240
238
|
{
|
241
|
-
std::vector<float> residual
|
239
|
+
std::vector<float> residual(d);
|
242
240
|
|
243
241
|
#pragma omp for
|
244
242
|
for (idx_t i = 0; i < n; i++) {
|
245
|
-
const uint8_t
|
246
|
-
int64_t list_no = decode_listno
|
247
|
-
float
|
248
|
-
pq.decode
|
243
|
+
const uint8_t* code = codes + i * (code_size + coarse_size);
|
244
|
+
int64_t list_no = decode_listno(code);
|
245
|
+
float* xi = x + i * d;
|
246
|
+
pq.decode(code + coarse_size, xi);
|
249
247
|
if (by_residual) {
|
250
|
-
quantizer->reconstruct
|
248
|
+
quantizer->reconstruct(list_no, residual.data());
|
251
249
|
for (size_t j = 0; j < d; j++) {
|
252
250
|
xi[j] += residual[j];
|
253
251
|
}
|
@@ -256,120 +254,127 @@ void IndexIVFPQ::sa_decode (idx_t n, const uint8_t *codes,
|
|
256
254
|
}
|
257
255
|
}
|
258
256
|
|
259
|
-
|
260
|
-
|
261
|
-
|
262
|
-
|
263
|
-
|
257
|
+
void IndexIVFPQ::add_core_o(
|
258
|
+
idx_t n,
|
259
|
+
const float* x,
|
260
|
+
const idx_t* xids,
|
261
|
+
float* residuals_2,
|
262
|
+
const idx_t* precomputed_idx) {
|
264
263
|
idx_t bs = 32768;
|
265
264
|
if (n > bs) {
|
266
265
|
for (idx_t i0 = 0; i0 < n; i0 += bs) {
|
267
266
|
idx_t i1 = std::min(i0 + bs, n);
|
268
267
|
if (verbose) {
|
269
|
-
printf("IndexIVFPQ::add_core_o: adding %" PRId64 ":%" PRId64
|
270
|
-
|
268
|
+
printf("IndexIVFPQ::add_core_o: adding %" PRId64 ":%" PRId64
|
269
|
+
" / %" PRId64 "\n",
|
270
|
+
i0,
|
271
|
+
i1,
|
272
|
+
n);
|
271
273
|
}
|
272
|
-
add_core_o
|
273
|
-
|
274
|
-
|
275
|
-
|
274
|
+
add_core_o(
|
275
|
+
i1 - i0,
|
276
|
+
x + i0 * d,
|
277
|
+
xids ? xids + i0 : nullptr,
|
278
|
+
residuals_2 ? residuals_2 + i0 * d : nullptr,
|
279
|
+
precomputed_idx ? precomputed_idx + i0 : nullptr);
|
276
280
|
}
|
277
281
|
return;
|
278
282
|
}
|
279
283
|
|
280
284
|
InterruptCallback::check();
|
281
285
|
|
282
|
-
direct_map.check_can_add
|
286
|
+
direct_map.check_can_add(xids);
|
283
287
|
|
284
|
-
FAISS_THROW_IF_NOT
|
285
|
-
double t0 = getmillisecs
|
286
|
-
const idx_t
|
288
|
+
FAISS_THROW_IF_NOT(is_trained);
|
289
|
+
double t0 = getmillisecs();
|
290
|
+
const idx_t* idx;
|
287
291
|
ScopeDeleter<idx_t> del_idx;
|
288
292
|
|
289
293
|
if (precomputed_idx) {
|
290
294
|
idx = precomputed_idx;
|
291
295
|
} else {
|
292
|
-
idx_t
|
293
|
-
del_idx.set
|
294
|
-
quantizer->assign
|
296
|
+
idx_t* idx0 = new idx_t[n];
|
297
|
+
del_idx.set(idx0);
|
298
|
+
quantizer->assign(n, x, idx0);
|
295
299
|
idx = idx0;
|
296
300
|
}
|
297
301
|
|
298
|
-
double t1 = getmillisecs
|
299
|
-
uint8_t
|
300
|
-
ScopeDeleter<uint8_t> del_xcodes
|
302
|
+
double t1 = getmillisecs();
|
303
|
+
uint8_t* xcodes = new uint8_t[n * code_size];
|
304
|
+
ScopeDeleter<uint8_t> del_xcodes(xcodes);
|
301
305
|
|
302
|
-
const float
|
306
|
+
const float* to_encode = nullptr;
|
303
307
|
ScopeDeleter<float> del_to_encode;
|
304
308
|
|
305
309
|
if (by_residual) {
|
306
|
-
to_encode = compute_residuals
|
307
|
-
del_to_encode.set
|
310
|
+
to_encode = compute_residuals(quantizer, n, x, idx);
|
311
|
+
del_to_encode.set(to_encode);
|
308
312
|
} else {
|
309
313
|
to_encode = x;
|
310
314
|
}
|
311
|
-
pq.compute_codes
|
315
|
+
pq.compute_codes(to_encode, xcodes, n);
|
312
316
|
|
313
|
-
double t2 = getmillisecs
|
317
|
+
double t2 = getmillisecs();
|
314
318
|
// TODO: parallelize?
|
315
319
|
size_t n_ignore = 0;
|
316
320
|
for (size_t i = 0; i < n; i++) {
|
317
321
|
idx_t key = idx[i];
|
318
322
|
idx_t id = xids ? xids[i] : ntotal + i;
|
319
323
|
if (key < 0) {
|
320
|
-
direct_map.add_single_id
|
321
|
-
n_ignore
|
324
|
+
direct_map.add_single_id(id, -1, 0);
|
325
|
+
n_ignore++;
|
322
326
|
if (residuals_2)
|
323
|
-
memset
|
327
|
+
memset(residuals_2, 0, sizeof(*residuals_2) * d);
|
324
328
|
continue;
|
325
329
|
}
|
326
330
|
|
327
|
-
uint8_t
|
328
|
-
size_t offset = invlists->add_entry
|
331
|
+
uint8_t* code = xcodes + i * code_size;
|
332
|
+
size_t offset = invlists->add_entry(key, id, code);
|
329
333
|
|
330
334
|
if (residuals_2) {
|
331
|
-
float
|
332
|
-
const float
|
333
|
-
pq.decode
|
335
|
+
float* res2 = residuals_2 + i * d;
|
336
|
+
const float* xi = to_encode + i * d;
|
337
|
+
pq.decode(code, res2);
|
334
338
|
for (int j = 0; j < d; j++)
|
335
339
|
res2[j] = xi[j] - res2[j];
|
336
340
|
}
|
337
341
|
|
338
|
-
direct_map.add_single_id
|
342
|
+
direct_map.add_single_id(id, key, offset);
|
339
343
|
}
|
340
344
|
|
341
|
-
double t3 = getmillisecs
|
342
|
-
if(verbose) {
|
345
|
+
double t3 = getmillisecs();
|
346
|
+
if (verbose) {
|
343
347
|
char comment[100] = {0};
|
344
348
|
if (n_ignore > 0)
|
345
|
-
snprintf
|
349
|
+
snprintf(comment, 100, "(%zd vectors ignored)", n_ignore);
|
346
350
|
printf(" add_core times: %.3f %.3f %.3f %s\n",
|
347
|
-
t1 - t0,
|
351
|
+
t1 - t0,
|
352
|
+
t2 - t1,
|
353
|
+
t3 - t2,
|
354
|
+
comment);
|
348
355
|
}
|
349
356
|
ntotal += n;
|
350
357
|
}
|
351
358
|
|
352
|
-
|
353
|
-
|
354
|
-
|
355
|
-
{
|
356
|
-
const uint8_t* code = invlists->get_single_code
|
359
|
+
void IndexIVFPQ::reconstruct_from_offset(
|
360
|
+
int64_t list_no,
|
361
|
+
int64_t offset,
|
362
|
+
float* recons) const {
|
363
|
+
const uint8_t* code = invlists->get_single_code(list_no, offset);
|
357
364
|
|
358
365
|
if (by_residual) {
|
359
366
|
std::vector<float> centroid(d);
|
360
|
-
quantizer->reconstruct
|
367
|
+
quantizer->reconstruct(list_no, centroid.data());
|
361
368
|
|
362
|
-
pq.decode
|
369
|
+
pq.decode(code, recons);
|
363
370
|
for (int i = 0; i < d; ++i) {
|
364
371
|
recons[i] += centroid[i];
|
365
372
|
}
|
366
373
|
} else {
|
367
|
-
pq.decode
|
374
|
+
pq.decode(code, recons);
|
368
375
|
}
|
369
376
|
}
|
370
377
|
|
371
|
-
|
372
|
-
|
373
378
|
/// 2G by default, accommodates tables up to PQ32 w/ 65536 centroids
|
374
379
|
size_t precomputed_table_max_bytes = ((size_t)1) << 31;
|
375
380
|
|
@@ -403,20 +408,18 @@ size_t precomputed_table_max_bytes = ((size_t)1) << 31;
|
|
403
408
|
* is faster when the length of the lists is > ksub * M.
|
404
409
|
*/
|
405
410
|
|
406
|
-
void initialize_IVFPQ_precomputed_table
|
407
|
-
|
408
|
-
|
409
|
-
|
410
|
-
|
411
|
-
|
412
|
-
)
|
413
|
-
{
|
411
|
+
void initialize_IVFPQ_precomputed_table(
|
412
|
+
int& use_precomputed_table,
|
413
|
+
const Index* quantizer,
|
414
|
+
const ProductQuantizer& pq,
|
415
|
+
AlignedTable<float>& precomputed_table,
|
416
|
+
bool verbose) {
|
414
417
|
size_t nlist = quantizer->ntotal;
|
415
418
|
size_t d = quantizer->d;
|
416
419
|
FAISS_THROW_IF_NOT(d == pq.d);
|
417
420
|
|
418
421
|
if (use_precomputed_table == -1) {
|
419
|
-
precomputed_table.resize
|
422
|
+
precomputed_table.resize(0);
|
420
423
|
return;
|
421
424
|
}
|
422
425
|
|
@@ -424,23 +427,23 @@ void initialize_IVFPQ_precomputed_table (
|
|
424
427
|
if (quantizer->metric_type == METRIC_INNER_PRODUCT) {
|
425
428
|
if (verbose) {
|
426
429
|
printf("IndexIVFPQ::precompute_table: precomputed "
|
427
|
-
|
430
|
+
"tables not needed for inner product quantizers\n");
|
428
431
|
}
|
429
|
-
precomputed_table.resize
|
432
|
+
precomputed_table.resize(0);
|
430
433
|
return;
|
431
434
|
}
|
432
|
-
const MultiIndexQuantizer
|
433
|
-
|
435
|
+
const MultiIndexQuantizer* miq =
|
436
|
+
dynamic_cast<const MultiIndexQuantizer*>(quantizer);
|
434
437
|
if (miq && pq.M % miq->pq.M == 0)
|
435
438
|
use_precomputed_table = 2;
|
436
439
|
else {
|
437
440
|
size_t table_size = pq.M * pq.ksub * nlist * sizeof(float);
|
438
441
|
if (table_size > precomputed_table_max_bytes) {
|
439
442
|
if (verbose) {
|
440
|
-
printf(
|
441
|
-
|
442
|
-
|
443
|
-
|
443
|
+
printf("IndexIVFPQ::precompute_table: not precomputing table, "
|
444
|
+
"it would be too big: %zd bytes (max %zd)\n",
|
445
|
+
table_size,
|
446
|
+
precomputed_table_max_bytes);
|
444
447
|
use_precomputed_table = 0;
|
445
448
|
}
|
446
449
|
return;
|
@@ -450,80 +453,68 @@ void initialize_IVFPQ_precomputed_table (
|
|
450
453
|
} // otherwise assume user has set appropriate flag on input
|
451
454
|
|
452
455
|
if (verbose) {
|
453
|
-
printf
|
454
|
-
use_precomputed_table);
|
456
|
+
printf("precomputing IVFPQ tables type %d\n", use_precomputed_table);
|
455
457
|
}
|
456
458
|
|
457
459
|
// squared norms of the PQ centroids
|
458
|
-
std::vector<float> r_norms
|
460
|
+
std::vector<float> r_norms(pq.M * pq.ksub, NAN);
|
459
461
|
for (int m = 0; m < pq.M; m++)
|
460
462
|
for (int j = 0; j < pq.ksub; j++)
|
461
|
-
r_norms
|
462
|
-
|
463
|
+
r_norms[m * pq.ksub + j] =
|
464
|
+
fvec_norm_L2sqr(pq.get_centroids(m, j), pq.dsub);
|
463
465
|
|
464
466
|
if (use_precomputed_table == 1) {
|
465
|
-
|
466
|
-
|
467
|
-
std::vector<float> centroid (d);
|
467
|
+
precomputed_table.resize(nlist * pq.M * pq.ksub);
|
468
|
+
std::vector<float> centroid(d);
|
468
469
|
|
469
470
|
for (size_t i = 0; i < nlist; i++) {
|
470
|
-
quantizer->reconstruct
|
471
|
+
quantizer->reconstruct(i, centroid.data());
|
471
472
|
|
472
|
-
float
|
473
|
-
pq.compute_inner_prod_table
|
474
|
-
fvec_madd
|
473
|
+
float* tab = &precomputed_table[i * pq.M * pq.ksub];
|
474
|
+
pq.compute_inner_prod_table(centroid.data(), tab);
|
475
|
+
fvec_madd(pq.M * pq.ksub, r_norms.data(), 2.0, tab, tab);
|
475
476
|
}
|
476
477
|
} else if (use_precomputed_table == 2) {
|
477
|
-
const MultiIndexQuantizer
|
478
|
-
|
479
|
-
FAISS_THROW_IF_NOT
|
480
|
-
const ProductQuantizer
|
481
|
-
FAISS_THROW_IF_NOT
|
478
|
+
const MultiIndexQuantizer* miq =
|
479
|
+
dynamic_cast<const MultiIndexQuantizer*>(quantizer);
|
480
|
+
FAISS_THROW_IF_NOT(miq);
|
481
|
+
const ProductQuantizer& cpq = miq->pq;
|
482
|
+
FAISS_THROW_IF_NOT(pq.M % cpq.M == 0);
|
482
483
|
|
483
484
|
precomputed_table.resize(cpq.ksub * pq.M * pq.ksub);
|
484
485
|
|
485
486
|
// reorder PQ centroid table
|
486
|
-
std::vector<float> centroids
|
487
|
+
std::vector<float> centroids(d * cpq.ksub, NAN);
|
487
488
|
|
488
489
|
for (int m = 0; m < cpq.M; m++) {
|
489
490
|
for (size_t i = 0; i < cpq.ksub; i++) {
|
490
|
-
memcpy
|
491
|
-
|
492
|
-
|
491
|
+
memcpy(centroids.data() + i * d + m * cpq.dsub,
|
492
|
+
cpq.get_centroids(m, i),
|
493
|
+
sizeof(*centroids.data()) * cpq.dsub);
|
493
494
|
}
|
494
495
|
}
|
495
496
|
|
496
|
-
pq.compute_inner_prod_tables
|
497
|
-
|
497
|
+
pq.compute_inner_prod_tables(
|
498
|
+
cpq.ksub, centroids.data(), precomputed_table.data());
|
498
499
|
|
499
500
|
for (size_t i = 0; i < cpq.ksub; i++) {
|
500
|
-
float
|
501
|
-
fvec_madd
|
501
|
+
float* tab = &precomputed_table[i * pq.M * pq.ksub];
|
502
|
+
fvec_madd(pq.M * pq.ksub, r_norms.data(), 2.0, tab, tab);
|
502
503
|
}
|
503
|
-
|
504
504
|
}
|
505
|
-
|
506
505
|
}
|
507
506
|
|
508
|
-
void IndexIVFPQ::precompute_table
|
509
|
-
|
510
|
-
|
511
|
-
use_precomputed_table, quantizer, pq, precomputed_table,
|
512
|
-
verbose
|
513
|
-
);
|
507
|
+
void IndexIVFPQ::precompute_table() {
|
508
|
+
initialize_IVFPQ_precomputed_table(
|
509
|
+
use_precomputed_table, quantizer, pq, precomputed_table, verbose);
|
514
510
|
}
|
515
511
|
|
516
|
-
|
517
|
-
|
518
512
|
namespace {
|
519
513
|
|
520
514
|
using idx_t = Index::idx_t;
|
521
515
|
|
522
|
-
|
523
516
|
#define TIC t0 = get_cycles()
|
524
|
-
#define TOC get_cycles
|
525
|
-
|
526
|
-
|
517
|
+
#define TOC get_cycles() - t0
|
527
518
|
|
528
519
|
/** QueryTables manages the various ways of searching an
|
529
520
|
* IndexIVFPQ. The code contains a lot of branches, depending on:
|
@@ -533,43 +524,42 @@ using idx_t = Index::idx_t;
|
|
533
524
|
* - polysemous_ht: are we filtering with polysemous codes?
|
534
525
|
*/
|
535
526
|
struct QueryTables {
|
536
|
-
|
537
527
|
/*****************************************************
|
538
528
|
* General data from the IVFPQ
|
539
529
|
*****************************************************/
|
540
530
|
|
541
|
-
const IndexIVFPQ
|
542
|
-
const IVFSearchParameters
|
531
|
+
const IndexIVFPQ& ivfpq;
|
532
|
+
const IVFSearchParameters* params;
|
543
533
|
|
544
534
|
// copied from IndexIVFPQ for easier access
|
545
535
|
int d;
|
546
|
-
const ProductQuantizer
|
536
|
+
const ProductQuantizer& pq;
|
547
537
|
MetricType metric_type;
|
548
538
|
bool by_residual;
|
549
539
|
int use_precomputed_table;
|
550
540
|
int polysemous_ht;
|
551
541
|
|
552
542
|
// pre-allocated data buffers
|
553
|
-
float *
|
554
|
-
float *
|
543
|
+
float *sim_table, *sim_table_2;
|
544
|
+
float *residual_vec, *decoded_vec;
|
555
545
|
|
556
546
|
// single data buffer
|
557
547
|
std::vector<float> mem;
|
558
548
|
|
559
549
|
// for table pointers
|
560
|
-
std::vector<const float
|
561
|
-
|
562
|
-
explicit QueryTables
|
563
|
-
|
564
|
-
|
565
|
-
|
566
|
-
|
567
|
-
|
568
|
-
|
569
|
-
|
570
|
-
|
571
|
-
mem.resize
|
572
|
-
sim_table = mem.data
|
550
|
+
std::vector<const float*> sim_table_ptrs;
|
551
|
+
|
552
|
+
explicit QueryTables(
|
553
|
+
const IndexIVFPQ& ivfpq,
|
554
|
+
const IVFSearchParameters* params)
|
555
|
+
: ivfpq(ivfpq),
|
556
|
+
d(ivfpq.d),
|
557
|
+
pq(ivfpq.pq),
|
558
|
+
metric_type(ivfpq.metric_type),
|
559
|
+
by_residual(ivfpq.by_residual),
|
560
|
+
use_precomputed_table(ivfpq.use_precomputed_table) {
|
561
|
+
mem.resize(pq.ksub * pq.M * 2 + d * 2);
|
562
|
+
sim_table = mem.data();
|
573
563
|
sim_table_2 = sim_table + pq.ksub * pq.M;
|
574
564
|
residual_vec = sim_table_2 + pq.ksub * pq.M;
|
575
565
|
decoded_vec = residual_vec + d;
|
@@ -577,14 +567,14 @@ struct QueryTables {
|
|
577
567
|
// for polysemous
|
578
568
|
polysemous_ht = ivfpq.polysemous_ht;
|
579
569
|
if (auto ivfpq_params =
|
580
|
-
|
570
|
+
dynamic_cast<const IVFPQSearchParameters*>(params)) {
|
581
571
|
polysemous_ht = ivfpq_params->polysemous_ht;
|
582
572
|
}
|
583
|
-
if (polysemous_ht != 0)
|
584
|
-
q_code.resize
|
573
|
+
if (polysemous_ht != 0) {
|
574
|
+
q_code.resize(pq.code_size);
|
585
575
|
}
|
586
576
|
init_list_cycles = 0;
|
587
|
-
sim_table_ptrs.resize
|
577
|
+
sim_table_ptrs.resize(pq.M);
|
588
578
|
}
|
589
579
|
|
590
580
|
/*****************************************************
|
@@ -592,29 +582,29 @@ struct QueryTables {
|
|
592
582
|
*****************************************************/
|
593
583
|
|
594
584
|
// field specific to query
|
595
|
-
const float
|
585
|
+
const float* qi;
|
596
586
|
|
597
|
-
// query-specific
|
598
|
-
void init_query
|
587
|
+
// query-specific initialization
|
588
|
+
void init_query(const float* qi) {
|
599
589
|
this->qi = qi;
|
600
590
|
if (metric_type == METRIC_INNER_PRODUCT)
|
601
|
-
init_query_IP
|
591
|
+
init_query_IP();
|
602
592
|
else
|
603
|
-
init_query_L2
|
593
|
+
init_query_L2();
|
604
594
|
if (!by_residual && polysemous_ht != 0)
|
605
|
-
pq.compute_code
|
595
|
+
pq.compute_code(qi, q_code.data());
|
606
596
|
}
|
607
597
|
|
608
|
-
void init_query_IP
|
598
|
+
void init_query_IP() {
|
609
599
|
// precompute some tables specific to the query qi
|
610
|
-
pq.compute_inner_prod_table
|
600
|
+
pq.compute_inner_prod_table(qi, sim_table);
|
611
601
|
}
|
612
602
|
|
613
|
-
void init_query_L2
|
603
|
+
void init_query_L2() {
|
614
604
|
if (!by_residual) {
|
615
|
-
pq.compute_distance_table
|
605
|
+
pq.compute_distance_table(qi, sim_table);
|
616
606
|
} else if (use_precomputed_table) {
|
617
|
-
pq.compute_inner_prod_table
|
607
|
+
pq.compute_inner_prod_table(qi, sim_table_2);
|
618
608
|
}
|
619
609
|
}
|
620
610
|
|
@@ -632,96 +622,95 @@ struct QueryTables {
|
|
632
622
|
/// once we know the query and the centroid, we can prepare the
|
633
623
|
/// sim_table that will be used for accumulation
|
634
624
|
/// and dis0, the initial value
|
635
|
-
float precompute_list_tables
|
625
|
+
float precompute_list_tables() {
|
636
626
|
float dis0 = 0;
|
637
|
-
uint64_t t0;
|
627
|
+
uint64_t t0;
|
628
|
+
TIC;
|
638
629
|
if (by_residual) {
|
639
630
|
if (metric_type == METRIC_INNER_PRODUCT)
|
640
|
-
dis0 = precompute_list_tables_IP
|
631
|
+
dis0 = precompute_list_tables_IP();
|
641
632
|
else
|
642
|
-
dis0 = precompute_list_tables_L2
|
633
|
+
dis0 = precompute_list_tables_L2();
|
643
634
|
}
|
644
635
|
init_list_cycles += TOC;
|
645
636
|
return dis0;
|
646
|
-
|
637
|
+
}
|
647
638
|
|
648
|
-
float precompute_list_table_pointers
|
639
|
+
float precompute_list_table_pointers() {
|
649
640
|
float dis0 = 0;
|
650
|
-
uint64_t t0;
|
641
|
+
uint64_t t0;
|
642
|
+
TIC;
|
651
643
|
if (by_residual) {
|
652
644
|
if (metric_type == METRIC_INNER_PRODUCT)
|
653
|
-
|
645
|
+
FAISS_THROW_MSG("not implemented");
|
654
646
|
else
|
655
|
-
|
647
|
+
dis0 = precompute_list_table_pointers_L2();
|
656
648
|
}
|
657
649
|
init_list_cycles += TOC;
|
658
650
|
return dis0;
|
659
|
-
|
651
|
+
}
|
660
652
|
|
661
653
|
/*****************************************************
|
662
654
|
* compute tables for inner prod
|
663
655
|
*****************************************************/
|
664
656
|
|
665
|
-
float precompute_list_tables_IP
|
666
|
-
{
|
657
|
+
float precompute_list_tables_IP() {
|
667
658
|
// prepare the sim_table that will be used for accumulation
|
668
659
|
// and dis0, the initial value
|
669
|
-
ivfpq.quantizer->reconstruct
|
660
|
+
ivfpq.quantizer->reconstruct(key, decoded_vec);
|
670
661
|
// decoded_vec = centroid
|
671
|
-
float dis0 = fvec_inner_product
|
662
|
+
float dis0 = fvec_inner_product(qi, decoded_vec, d);
|
672
663
|
|
673
664
|
if (polysemous_ht) {
|
674
665
|
for (int i = 0; i < d; i++) {
|
675
|
-
residual_vec
|
666
|
+
residual_vec[i] = qi[i] - decoded_vec[i];
|
676
667
|
}
|
677
|
-
pq.compute_code
|
668
|
+
pq.compute_code(residual_vec, q_code.data());
|
678
669
|
}
|
679
670
|
return dis0;
|
680
671
|
}
|
681
672
|
|
682
|
-
|
683
673
|
/*****************************************************
|
684
674
|
* compute tables for L2 distance
|
685
675
|
*****************************************************/
|
686
676
|
|
687
|
-
float precompute_list_tables_L2
|
688
|
-
{
|
677
|
+
float precompute_list_tables_L2() {
|
689
678
|
float dis0 = 0;
|
690
679
|
|
691
680
|
if (use_precomputed_table == 0 || use_precomputed_table == -1) {
|
692
|
-
ivfpq.quantizer->compute_residual
|
693
|
-
pq.compute_distance_table
|
681
|
+
ivfpq.quantizer->compute_residual(qi, residual_vec, key);
|
682
|
+
pq.compute_distance_table(residual_vec, sim_table);
|
694
683
|
|
695
684
|
if (polysemous_ht != 0) {
|
696
|
-
pq.compute_code
|
685
|
+
pq.compute_code(residual_vec, q_code.data());
|
697
686
|
}
|
698
687
|
|
699
688
|
} else if (use_precomputed_table == 1) {
|
700
689
|
dis0 = coarse_dis;
|
701
690
|
|
702
|
-
fvec_madd
|
691
|
+
fvec_madd(
|
703
692
|
pq.M * pq.ksub,
|
704
693
|
ivfpq.precomputed_table.data() + key * pq.ksub * pq.M,
|
705
|
-
-2.0,
|
706
|
-
|
707
|
-
|
694
|
+
-2.0,
|
695
|
+
sim_table_2,
|
696
|
+
sim_table);
|
708
697
|
|
709
698
|
if (polysemous_ht != 0) {
|
710
|
-
ivfpq.quantizer->compute_residual
|
711
|
-
pq.compute_code
|
699
|
+
ivfpq.quantizer->compute_residual(qi, residual_vec, key);
|
700
|
+
pq.compute_code(residual_vec, q_code.data());
|
712
701
|
}
|
713
702
|
|
714
703
|
} else if (use_precomputed_table == 2) {
|
715
704
|
dis0 = coarse_dis;
|
716
705
|
|
717
|
-
const MultiIndexQuantizer
|
718
|
-
|
719
|
-
FAISS_THROW_IF_NOT
|
720
|
-
const ProductQuantizer
|
706
|
+
const MultiIndexQuantizer* miq =
|
707
|
+
dynamic_cast<const MultiIndexQuantizer*>(ivfpq.quantizer);
|
708
|
+
FAISS_THROW_IF_NOT(miq);
|
709
|
+
const ProductQuantizer& cpq = miq->pq;
|
721
710
|
int Mf = pq.M / cpq.M;
|
722
711
|
|
723
|
-
const float
|
724
|
-
float
|
712
|
+
const float* qtab = sim_table_2; // query-specific table
|
713
|
+
float* ltab = sim_table; // (output) list-specific table
|
725
714
|
|
726
715
|
long k = key;
|
727
716
|
for (int cm = 0; cm < cpq.M; cm++) {
|
@@ -730,54 +719,48 @@ struct QueryTables {
|
|
730
719
|
k >>= cpq.nbits;
|
731
720
|
|
732
721
|
// get corresponding table
|
733
|
-
const float
|
734
|
-
|
722
|
+
const float* pc = ivfpq.precomputed_table.data() +
|
723
|
+
(ki * pq.M + cm * Mf) * pq.ksub;
|
735
724
|
|
736
725
|
if (polysemous_ht == 0) {
|
737
|
-
|
738
726
|
// sum up with query-specific table
|
739
|
-
fvec_madd
|
740
|
-
pc,
|
741
|
-
-2.0, qtab,
|
742
|
-
ltab);
|
727
|
+
fvec_madd(Mf * pq.ksub, pc, -2.0, qtab, ltab);
|
743
728
|
ltab += Mf * pq.ksub;
|
744
729
|
qtab += Mf * pq.ksub;
|
745
730
|
} else {
|
746
731
|
for (int m = cm * Mf; m < (cm + 1) * Mf; m++) {
|
747
|
-
q_code[m] = fvec_madd_and_argmin
|
748
|
-
|
732
|
+
q_code[m] = fvec_madd_and_argmin(
|
733
|
+
pq.ksub, pc, -2, qtab, ltab);
|
749
734
|
pc += pq.ksub;
|
750
735
|
ltab += pq.ksub;
|
751
736
|
qtab += pq.ksub;
|
752
737
|
}
|
753
738
|
}
|
754
|
-
|
755
739
|
}
|
756
740
|
}
|
757
741
|
|
758
742
|
return dis0;
|
759
743
|
}
|
760
744
|
|
761
|
-
float precompute_list_table_pointers_L2
|
762
|
-
{
|
745
|
+
float precompute_list_table_pointers_L2() {
|
763
746
|
float dis0 = 0;
|
764
747
|
|
765
748
|
if (use_precomputed_table == 1) {
|
766
749
|
dis0 = coarse_dis;
|
767
750
|
|
768
|
-
const float
|
769
|
-
key * pq.ksub * pq.M;
|
751
|
+
const float* s =
|
752
|
+
ivfpq.precomputed_table.data() + key * pq.ksub * pq.M;
|
770
753
|
for (int m = 0; m < pq.M; m++) {
|
771
|
-
sim_table_ptrs
|
754
|
+
sim_table_ptrs[m] = s;
|
772
755
|
s += pq.ksub;
|
773
756
|
}
|
774
757
|
} else if (use_precomputed_table == 2) {
|
775
758
|
dis0 = coarse_dis;
|
776
759
|
|
777
|
-
const MultiIndexQuantizer
|
778
|
-
|
779
|
-
FAISS_THROW_IF_NOT
|
780
|
-
const ProductQuantizer
|
760
|
+
const MultiIndexQuantizer* miq =
|
761
|
+
dynamic_cast<const MultiIndexQuantizer*>(ivfpq.quantizer);
|
762
|
+
FAISS_THROW_IF_NOT(miq);
|
763
|
+
const ProductQuantizer& cpq = miq->pq;
|
781
764
|
int Mf = pq.M / cpq.M;
|
782
765
|
|
783
766
|
long k = key;
|
@@ -786,21 +769,21 @@ struct QueryTables {
|
|
786
769
|
int ki = k & ((uint64_t(1) << cpq.nbits) - 1);
|
787
770
|
k >>= cpq.nbits;
|
788
771
|
|
789
|
-
const float
|
790
|
-
|
772
|
+
const float* pc = ivfpq.precomputed_table.data() +
|
773
|
+
(ki * pq.M + cm * Mf) * pq.ksub;
|
791
774
|
|
792
775
|
for (int m = m0; m < m0 + Mf; m++) {
|
793
|
-
sim_table_ptrs
|
776
|
+
sim_table_ptrs[m] = pc;
|
794
777
|
pc += pq.ksub;
|
795
778
|
}
|
796
779
|
m0 += Mf;
|
797
780
|
}
|
798
781
|
} else {
|
799
|
-
|
782
|
+
FAISS_THROW_MSG("need precomputed tables");
|
800
783
|
}
|
801
784
|
|
802
785
|
if (polysemous_ht) {
|
803
|
-
FAISS_THROW_MSG
|
786
|
+
FAISS_THROW_MSG("not implemented");
|
804
787
|
// Not clear that it makes sense to implemente this,
|
805
788
|
// because it costs M * ksub, which is what we wanted to
|
806
789
|
// avoid with the tables pointers.
|
@@ -808,82 +791,72 @@ struct QueryTables {
|
|
808
791
|
|
809
792
|
return dis0;
|
810
793
|
}
|
811
|
-
|
812
|
-
|
813
794
|
};
|
814
795
|
|
815
|
-
|
816
|
-
|
817
|
-
template<class C>
|
796
|
+
template <class C>
|
818
797
|
struct KnnSearchResults {
|
819
798
|
idx_t key;
|
820
|
-
const idx_t
|
799
|
+
const idx_t* ids;
|
821
800
|
|
822
801
|
// heap params
|
823
802
|
size_t k;
|
824
|
-
float
|
825
|
-
idx_t
|
803
|
+
float* heap_sim;
|
804
|
+
idx_t* heap_ids;
|
826
805
|
|
827
806
|
size_t nup;
|
828
807
|
|
829
|
-
inline void add
|
830
|
-
if (C::cmp
|
831
|
-
idx_t id = ids ? ids[j] : lo_build
|
832
|
-
heap_replace_top<C>
|
808
|
+
inline void add(idx_t j, float dis) {
|
809
|
+
if (C::cmp(heap_sim[0], dis)) {
|
810
|
+
idx_t id = ids ? ids[j] : lo_build(key, j);
|
811
|
+
heap_replace_top<C>(k, heap_sim, heap_ids, dis, id);
|
833
812
|
nup++;
|
834
813
|
}
|
835
814
|
}
|
836
|
-
|
837
815
|
};
|
838
816
|
|
839
|
-
template<class C>
|
817
|
+
template <class C>
|
840
818
|
struct RangeSearchResults {
|
841
819
|
idx_t key;
|
842
|
-
const idx_t
|
820
|
+
const idx_t* ids;
|
843
821
|
|
844
822
|
// wrapped result structure
|
845
823
|
float radius;
|
846
|
-
RangeQueryResult
|
824
|
+
RangeQueryResult& rres;
|
847
825
|
|
848
|
-
inline void add
|
849
|
-
if (C::cmp
|
850
|
-
idx_t id = ids ? ids[j] : lo_build
|
851
|
-
rres.add
|
826
|
+
inline void add(idx_t j, float dis) {
|
827
|
+
if (C::cmp(radius, dis)) {
|
828
|
+
idx_t id = ids ? ids[j] : lo_build(key, j);
|
829
|
+
rres.add(dis, id);
|
852
830
|
}
|
853
831
|
}
|
854
832
|
};
|
855
833
|
|
856
|
-
|
857
|
-
|
858
834
|
/*****************************************************
|
859
835
|
* Scaning the codes.
|
860
836
|
* The scanning functions call their favorite precompute_*
|
861
837
|
* function to precompute the tables they need.
|
862
838
|
*****************************************************/
|
863
839
|
template <typename IDType, MetricType METRIC_TYPE, class PQDecoder>
|
864
|
-
struct IVFPQScannerT: QueryTables {
|
865
|
-
|
866
|
-
const
|
867
|
-
const IDType * list_ids;
|
840
|
+
struct IVFPQScannerT : QueryTables {
|
841
|
+
const uint8_t* list_codes;
|
842
|
+
const IDType* list_ids;
|
868
843
|
size_t list_size;
|
869
844
|
|
870
|
-
IVFPQScannerT
|
871
|
-
|
872
|
-
{
|
845
|
+
IVFPQScannerT(const IndexIVFPQ& ivfpq, const IVFSearchParameters* params)
|
846
|
+
: QueryTables(ivfpq, params) {
|
873
847
|
assert(METRIC_TYPE == metric_type);
|
874
848
|
}
|
875
849
|
|
876
850
|
float dis0;
|
877
851
|
|
878
|
-
void init_list
|
879
|
-
int mode) {
|
852
|
+
void init_list(idx_t list_no, float coarse_dis, int mode) {
|
880
853
|
this->key = list_no;
|
881
854
|
this->coarse_dis = coarse_dis;
|
882
855
|
|
883
856
|
if (mode == 2) {
|
884
|
-
dis0 = precompute_list_tables
|
857
|
+
dis0 = precompute_list_tables();
|
885
858
|
} else if (mode == 1) {
|
886
|
-
dis0 = precompute_list_table_pointers
|
859
|
+
dis0 = precompute_list_table_pointers();
|
887
860
|
}
|
888
861
|
}
|
889
862
|
|
@@ -892,15 +865,16 @@ struct IVFPQScannerT: QueryTables {
|
|
892
865
|
*****************************************************/
|
893
866
|
|
894
867
|
/// version of the scan where we use precomputed tables
|
895
|
-
template<class SearchResultType>
|
896
|
-
void scan_list_with_table
|
897
|
-
|
898
|
-
|
868
|
+
template <class SearchResultType>
|
869
|
+
void scan_list_with_table(
|
870
|
+
size_t ncode,
|
871
|
+
const uint8_t* codes,
|
872
|
+
SearchResultType& res) const {
|
899
873
|
for (size_t j = 0; j < ncode; j++) {
|
900
874
|
PQDecoder decoder(codes, pq.nbits);
|
901
875
|
codes += pq.code_size;
|
902
876
|
float dis = dis0;
|
903
|
-
const float
|
877
|
+
const float* tab = sim_table;
|
904
878
|
|
905
879
|
for (size_t m = 0; m < pq.M; m++) {
|
906
880
|
dis += tab[decoder.decode()];
|
@@ -911,43 +885,43 @@ struct IVFPQScannerT: QueryTables {
|
|
911
885
|
}
|
912
886
|
}
|
913
887
|
|
914
|
-
|
915
888
|
/// tables are not precomputed, but pointers are provided to the
|
916
889
|
/// relevant X_c|x_r tables
|
917
|
-
template<class SearchResultType>
|
918
|
-
void scan_list_with_pointer
|
919
|
-
|
920
|
-
|
890
|
+
template <class SearchResultType>
|
891
|
+
void scan_list_with_pointer(
|
892
|
+
size_t ncode,
|
893
|
+
const uint8_t* codes,
|
894
|
+
SearchResultType& res) const {
|
921
895
|
for (size_t j = 0; j < ncode; j++) {
|
922
896
|
PQDecoder decoder(codes, pq.nbits);
|
923
897
|
codes += pq.code_size;
|
924
898
|
|
925
899
|
float dis = dis0;
|
926
|
-
const float
|
900
|
+
const float* tab = sim_table_2;
|
927
901
|
|
928
902
|
for (size_t m = 0; m < pq.M; m++) {
|
929
903
|
int ci = decoder.decode();
|
930
|
-
dis += sim_table_ptrs
|
904
|
+
dis += sim_table_ptrs[m][ci] - 2 * tab[ci];
|
931
905
|
tab += pq.ksub;
|
932
906
|
}
|
933
|
-
res.add
|
907
|
+
res.add(j, dis);
|
934
908
|
}
|
935
909
|
}
|
936
910
|
|
937
|
-
|
938
911
|
/// nothing is precomputed: access residuals on-the-fly
|
939
|
-
template<class SearchResultType>
|
940
|
-
void scan_on_the_fly_dist
|
941
|
-
|
942
|
-
|
943
|
-
|
912
|
+
template <class SearchResultType>
|
913
|
+
void scan_on_the_fly_dist(
|
914
|
+
size_t ncode,
|
915
|
+
const uint8_t* codes,
|
916
|
+
SearchResultType& res) const {
|
917
|
+
const float* dvec;
|
944
918
|
float dis0 = 0;
|
945
919
|
if (by_residual) {
|
946
920
|
if (METRIC_TYPE == METRIC_INNER_PRODUCT) {
|
947
|
-
ivfpq.quantizer->reconstruct
|
948
|
-
dis0 = fvec_inner_product
|
921
|
+
ivfpq.quantizer->reconstruct(key, residual_vec);
|
922
|
+
dis0 = fvec_inner_product(residual_vec, qi, d);
|
949
923
|
} else {
|
950
|
-
ivfpq.quantizer->compute_residual
|
924
|
+
ivfpq.quantizer->compute_residual(qi, residual_vec, key);
|
951
925
|
}
|
952
926
|
dvec = residual_vec;
|
953
927
|
} else {
|
@@ -956,17 +930,16 @@ struct IVFPQScannerT: QueryTables {
|
|
956
930
|
}
|
957
931
|
|
958
932
|
for (size_t j = 0; j < ncode; j++) {
|
959
|
-
|
960
|
-
pq.decode (codes, decoded_vec);
|
933
|
+
pq.decode(codes, decoded_vec);
|
961
934
|
codes += pq.code_size;
|
962
935
|
|
963
936
|
float dis;
|
964
937
|
if (METRIC_TYPE == METRIC_INNER_PRODUCT) {
|
965
|
-
dis = dis0 + fvec_inner_product
|
938
|
+
dis = dis0 + fvec_inner_product(decoded_vec, qi, d);
|
966
939
|
} else {
|
967
|
-
dis = fvec_L2sqr
|
940
|
+
dis = fvec_L2sqr(decoded_vec, dvec, d);
|
968
941
|
}
|
969
|
-
res.add
|
942
|
+
res.add(j, dis);
|
970
943
|
}
|
971
944
|
}
|
972
945
|
|
@@ -975,110 +948,99 @@ struct IVFPQScannerT: QueryTables {
|
|
975
948
|
*****************************************************/
|
976
949
|
|
977
950
|
template <class HammingComputer, class SearchResultType>
|
978
|
-
void scan_list_polysemous_hc
|
979
|
-
|
980
|
-
|
981
|
-
|
951
|
+
void scan_list_polysemous_hc(
|
952
|
+
size_t ncode,
|
953
|
+
const uint8_t* codes,
|
954
|
+
SearchResultType& res) const {
|
982
955
|
int ht = ivfpq.polysemous_ht;
|
983
956
|
size_t n_hamming_pass = 0, nup = 0;
|
984
957
|
|
985
958
|
int code_size = pq.code_size;
|
986
959
|
|
987
|
-
HammingComputer hc
|
960
|
+
HammingComputer hc(q_code.data(), code_size);
|
988
961
|
|
989
962
|
for (size_t j = 0; j < ncode; j++) {
|
990
|
-
const uint8_t
|
991
|
-
int hd = hc.hamming
|
963
|
+
const uint8_t* b_code = codes;
|
964
|
+
int hd = hc.hamming(b_code);
|
992
965
|
if (hd < ht) {
|
993
|
-
n_hamming_pass
|
966
|
+
n_hamming_pass++;
|
994
967
|
PQDecoder decoder(codes, pq.nbits);
|
995
968
|
|
996
969
|
float dis = dis0;
|
997
|
-
const float
|
970
|
+
const float* tab = sim_table;
|
998
971
|
|
999
972
|
for (size_t m = 0; m < pq.M; m++) {
|
1000
973
|
dis += tab[decoder.decode()];
|
1001
974
|
tab += pq.ksub;
|
1002
975
|
}
|
1003
976
|
|
1004
|
-
res.add
|
977
|
+
res.add(j, dis);
|
1005
978
|
}
|
1006
979
|
codes += code_size;
|
1007
980
|
}
|
1008
981
|
#pragma omp critical
|
1009
|
-
{
|
1010
|
-
indexIVFPQ_stats.n_hamming_pass += n_hamming_pass;
|
1011
|
-
}
|
982
|
+
{ indexIVFPQ_stats.n_hamming_pass += n_hamming_pass; }
|
1012
983
|
}
|
1013
984
|
|
1014
|
-
template<class SearchResultType>
|
1015
|
-
void scan_list_polysemous
|
1016
|
-
|
1017
|
-
|
1018
|
-
|
985
|
+
template <class SearchResultType>
|
986
|
+
void scan_list_polysemous(
|
987
|
+
size_t ncode,
|
988
|
+
const uint8_t* codes,
|
989
|
+
SearchResultType& res) const {
|
1019
990
|
switch (pq.code_size) {
|
1020
991
|
#define HANDLE_CODE_SIZE(cs) \
|
1021
|
-
|
1022
|
-
|
1023
|
-
|
1024
|
-
|
1025
|
-
|
1026
|
-
|
1027
|
-
|
1028
|
-
|
1029
|
-
|
1030
|
-
|
1031
|
-
HANDLE_CODE_SIZE(64);
|
992
|
+
case cs: \
|
993
|
+
scan_list_polysemous_hc<HammingComputer##cs, SearchResultType>( \
|
994
|
+
ncode, codes, res); \
|
995
|
+
break
|
996
|
+
HANDLE_CODE_SIZE(4);
|
997
|
+
HANDLE_CODE_SIZE(8);
|
998
|
+
HANDLE_CODE_SIZE(16);
|
999
|
+
HANDLE_CODE_SIZE(20);
|
1000
|
+
HANDLE_CODE_SIZE(32);
|
1001
|
+
HANDLE_CODE_SIZE(64);
|
1032
1002
|
#undef HANDLE_CODE_SIZE
|
1033
|
-
|
1034
|
-
|
1035
|
-
|
1036
|
-
|
1037
|
-
|
1038
|
-
else
|
1039
|
-
scan_list_polysemous_hc
|
1040
|
-
<HammingComputerM4, SearchResultType>
|
1041
|
-
(ncode, codes, res);
|
1042
|
-
break;
|
1003
|
+
default:
|
1004
|
+
scan_list_polysemous_hc<
|
1005
|
+
HammingComputerDefault,
|
1006
|
+
SearchResultType>(ncode, codes, res);
|
1007
|
+
break;
|
1043
1008
|
}
|
1044
1009
|
}
|
1045
|
-
|
1046
1010
|
};
|
1047
1011
|
|
1048
|
-
|
1049
1012
|
/* We put as many parameters as possible in template. Hopefully the
|
1050
1013
|
* gain in runtime is worth the code bloat. C is the comparator < or
|
1051
1014
|
* >, it is directly related to METRIC_TYPE. precompute_mode is how
|
1052
1015
|
* much we precompute (2 = precompute distance tables, 1 = precompute
|
1053
1016
|
* pointers to distances, 0 = compute distances one by one).
|
1054
1017
|
* Currently only 2 is supported */
|
1055
|
-
template<MetricType METRIC_TYPE, class C, class PQDecoder>
|
1056
|
-
struct IVFPQScanner:
|
1057
|
-
|
1058
|
-
InvertedListScanner
|
1059
|
-
{
|
1060
|
-
bool store_pairs;
|
1018
|
+
template <MetricType METRIC_TYPE, class C, class PQDecoder>
|
1019
|
+
struct IVFPQScanner : IVFPQScannerT<Index::idx_t, METRIC_TYPE, PQDecoder>,
|
1020
|
+
InvertedListScanner {
|
1061
1021
|
int precompute_mode;
|
1062
1022
|
|
1063
|
-
IVFPQScanner(const IndexIVFPQ
|
1064
|
-
|
1065
|
-
|
1066
|
-
|
1067
|
-
|
1023
|
+
IVFPQScanner(const IndexIVFPQ& ivfpq, bool store_pairs, int precompute_mode)
|
1024
|
+
: IVFPQScannerT<Index::idx_t, METRIC_TYPE, PQDecoder>(
|
1025
|
+
ivfpq,
|
1026
|
+
nullptr),
|
1027
|
+
precompute_mode(precompute_mode) {
|
1028
|
+
this->store_pairs = store_pairs;
|
1068
1029
|
}
|
1069
1030
|
|
1070
|
-
void set_query
|
1071
|
-
this->init_query
|
1031
|
+
void set_query(const float* query) override {
|
1032
|
+
this->init_query(query);
|
1072
1033
|
}
|
1073
1034
|
|
1074
|
-
void set_list
|
1075
|
-
this->
|
1035
|
+
void set_list(idx_t list_no, float coarse_dis) override {
|
1036
|
+
this->list_no = list_no;
|
1037
|
+
this->init_list(list_no, coarse_dis, precompute_mode);
|
1076
1038
|
}
|
1077
1039
|
|
1078
|
-
float distance_to_code
|
1040
|
+
float distance_to_code(const uint8_t* code) const override {
|
1079
1041
|
assert(precompute_mode == 2);
|
1080
1042
|
float dis = this->dis0;
|
1081
|
-
const float
|
1043
|
+
const float* tab = this->sim_table;
|
1082
1044
|
PQDecoder decoder(code, this->pq.nbits);
|
1083
1045
|
|
1084
1046
|
for (size_t m = 0; m < this->pq.M; m++) {
|
@@ -1088,112 +1050,100 @@ struct IVFPQScanner:
|
|
1088
1050
|
return dis;
|
1089
1051
|
}
|
1090
1052
|
|
1091
|
-
size_t scan_codes
|
1092
|
-
|
1093
|
-
|
1094
|
-
|
1095
|
-
|
1096
|
-
|
1053
|
+
size_t scan_codes(
|
1054
|
+
size_t ncode,
|
1055
|
+
const uint8_t* codes,
|
1056
|
+
const idx_t* ids,
|
1057
|
+
float* heap_sim,
|
1058
|
+
idx_t* heap_ids,
|
1059
|
+
size_t k) const override {
|
1097
1060
|
KnnSearchResults<C> res = {
|
1098
|
-
|
1099
|
-
|
1100
|
-
|
1101
|
-
|
1102
|
-
|
1103
|
-
|
1104
|
-
};
|
1061
|
+
/* key */ this->key,
|
1062
|
+
/* ids */ this->store_pairs ? nullptr : ids,
|
1063
|
+
/* k */ k,
|
1064
|
+
/* heap_sim */ heap_sim,
|
1065
|
+
/* heap_ids */ heap_ids,
|
1066
|
+
/* nup */ 0};
|
1105
1067
|
|
1106
1068
|
if (this->polysemous_ht > 0) {
|
1107
1069
|
assert(precompute_mode == 2);
|
1108
|
-
this->scan_list_polysemous
|
1070
|
+
this->scan_list_polysemous(ncode, codes, res);
|
1109
1071
|
} else if (precompute_mode == 2) {
|
1110
|
-
this->scan_list_with_table
|
1072
|
+
this->scan_list_with_table(ncode, codes, res);
|
1111
1073
|
} else if (precompute_mode == 1) {
|
1112
|
-
this->scan_list_with_pointer
|
1074
|
+
this->scan_list_with_pointer(ncode, codes, res);
|
1113
1075
|
} else if (precompute_mode == 0) {
|
1114
|
-
this->scan_on_the_fly_dist
|
1076
|
+
this->scan_on_the_fly_dist(ncode, codes, res);
|
1115
1077
|
} else {
|
1116
1078
|
FAISS_THROW_MSG("bad precomp mode");
|
1117
1079
|
}
|
1118
1080
|
return res.nup;
|
1119
1081
|
}
|
1120
1082
|
|
1121
|
-
void scan_codes_range
|
1122
|
-
|
1123
|
-
|
1124
|
-
|
1125
|
-
|
1126
|
-
|
1083
|
+
void scan_codes_range(
|
1084
|
+
size_t ncode,
|
1085
|
+
const uint8_t* codes,
|
1086
|
+
const idx_t* ids,
|
1087
|
+
float radius,
|
1088
|
+
RangeQueryResult& rres) const override {
|
1127
1089
|
RangeSearchResults<C> res = {
|
1128
|
-
|
1129
|
-
|
1130
|
-
|
1131
|
-
|
1132
|
-
};
|
1090
|
+
/* key */ this->key,
|
1091
|
+
/* ids */ this->store_pairs ? nullptr : ids,
|
1092
|
+
/* radius */ radius,
|
1093
|
+
/* rres */ rres};
|
1133
1094
|
|
1134
1095
|
if (this->polysemous_ht > 0) {
|
1135
1096
|
assert(precompute_mode == 2);
|
1136
|
-
this->scan_list_polysemous
|
1097
|
+
this->scan_list_polysemous(ncode, codes, res);
|
1137
1098
|
} else if (precompute_mode == 2) {
|
1138
|
-
this->scan_list_with_table
|
1099
|
+
this->scan_list_with_table(ncode, codes, res);
|
1139
1100
|
} else if (precompute_mode == 1) {
|
1140
|
-
this->scan_list_with_pointer
|
1101
|
+
this->scan_list_with_pointer(ncode, codes, res);
|
1141
1102
|
} else if (precompute_mode == 0) {
|
1142
|
-
this->scan_on_the_fly_dist
|
1103
|
+
this->scan_on_the_fly_dist(ncode, codes, res);
|
1143
1104
|
} else {
|
1144
1105
|
FAISS_THROW_MSG("bad precomp mode");
|
1145
1106
|
}
|
1146
|
-
|
1147
1107
|
}
|
1148
1108
|
};
|
1149
1109
|
|
1150
|
-
template<class PQDecoder>
|
1151
|
-
InvertedListScanner
|
1152
|
-
|
1153
|
-
{
|
1154
|
-
|
1155
|
-
|
1156
|
-
|
1157
|
-
|
1158
|
-
|
1110
|
+
template <class PQDecoder>
|
1111
|
+
InvertedListScanner* get_InvertedListScanner1(
|
1112
|
+
const IndexIVFPQ& index,
|
1113
|
+
bool store_pairs) {
|
1114
|
+
if (index.metric_type == METRIC_INNER_PRODUCT) {
|
1115
|
+
return new IVFPQScanner<
|
1116
|
+
METRIC_INNER_PRODUCT,
|
1117
|
+
CMin<float, idx_t>,
|
1118
|
+
PQDecoder>(index, store_pairs, 2);
|
1159
1119
|
} else if (index.metric_type == METRIC_L2) {
|
1160
|
-
return new IVFPQScanner
|
1161
|
-
|
1162
|
-
(index, store_pairs, 2);
|
1120
|
+
return new IVFPQScanner<METRIC_L2, CMax<float, idx_t>, PQDecoder>(
|
1121
|
+
index, store_pairs, 2);
|
1163
1122
|
}
|
1164
1123
|
return nullptr;
|
1165
1124
|
}
|
1166
1125
|
|
1167
|
-
|
1168
1126
|
} // anonymous namespace
|
1169
1127
|
|
1170
|
-
InvertedListScanner
|
1171
|
-
|
1172
|
-
{
|
1173
|
-
|
1128
|
+
InvertedListScanner* IndexIVFPQ::get_InvertedListScanner(
|
1129
|
+
bool store_pairs) const {
|
1174
1130
|
if (pq.nbits == 8) {
|
1175
|
-
return get_InvertedListScanner1<PQDecoder8>
|
1131
|
+
return get_InvertedListScanner1<PQDecoder8>(*this, store_pairs);
|
1176
1132
|
} else if (pq.nbits == 16) {
|
1177
|
-
return get_InvertedListScanner1<PQDecoder16>
|
1133
|
+
return get_InvertedListScanner1<PQDecoder16>(*this, store_pairs);
|
1178
1134
|
} else {
|
1179
|
-
return get_InvertedListScanner1<PQDecoderGeneric>
|
1135
|
+
return get_InvertedListScanner1<PQDecoderGeneric>(*this, store_pairs);
|
1180
1136
|
}
|
1181
1137
|
return nullptr;
|
1182
|
-
|
1183
1138
|
}
|
1184
1139
|
|
1185
|
-
|
1186
|
-
|
1187
1140
|
IndexIVFPQStats indexIVFPQ_stats;
|
1188
1141
|
|
1189
|
-
void IndexIVFPQStats::reset
|
1190
|
-
memset
|
1142
|
+
void IndexIVFPQStats::reset() {
|
1143
|
+
memset(this, 0, sizeof(*this));
|
1191
1144
|
}
|
1192
1145
|
|
1193
|
-
|
1194
|
-
|
1195
|
-
IndexIVFPQ::IndexIVFPQ ()
|
1196
|
-
{
|
1146
|
+
IndexIVFPQ::IndexIVFPQ() {
|
1197
1147
|
// initialize some runtime values
|
1198
1148
|
use_precomputed_table = 0;
|
1199
1149
|
scan_table_threshold = 0;
|
@@ -1202,43 +1152,40 @@ IndexIVFPQ::IndexIVFPQ ()
|
|
1202
1152
|
polysemous_training = nullptr;
|
1203
1153
|
}
|
1204
1154
|
|
1205
|
-
|
1206
1155
|
struct CodeCmp {
|
1207
|
-
const uint8_t
|
1156
|
+
const uint8_t* tab;
|
1208
1157
|
size_t code_size;
|
1209
|
-
bool operator
|
1210
|
-
return cmp
|
1158
|
+
bool operator()(int a, int b) const {
|
1159
|
+
return cmp(a, b) > 0;
|
1211
1160
|
}
|
1212
|
-
int cmp
|
1213
|
-
return memcmp
|
1214
|
-
code_size);
|
1161
|
+
int cmp(int a, int b) const {
|
1162
|
+
return memcmp(tab + a * code_size, tab + b * code_size, code_size);
|
1215
1163
|
}
|
1216
1164
|
};
|
1217
1165
|
|
1218
|
-
|
1219
|
-
size_t IndexIVFPQ::find_duplicates (idx_t *dup_ids, size_t *lims) const
|
1220
|
-
{
|
1166
|
+
size_t IndexIVFPQ::find_duplicates(idx_t* dup_ids, size_t* lims) const {
|
1221
1167
|
size_t ngroup = 0;
|
1222
1168
|
lims[0] = 0;
|
1223
1169
|
for (size_t list_no = 0; list_no < nlist; list_no++) {
|
1224
|
-
size_t n = invlists->list_size
|
1225
|
-
std::vector<int> ord
|
1226
|
-
for (int i = 0; i < n; i++)
|
1227
|
-
|
1228
|
-
|
1229
|
-
|
1230
|
-
|
1231
|
-
|
1232
|
-
|
1170
|
+
size_t n = invlists->list_size(list_no);
|
1171
|
+
std::vector<int> ord(n);
|
1172
|
+
for (int i = 0; i < n; i++)
|
1173
|
+
ord[i] = i;
|
1174
|
+
InvertedLists::ScopedCodes codes(invlists, list_no);
|
1175
|
+
CodeCmp cs = {codes.get(), code_size};
|
1176
|
+
std::sort(ord.begin(), ord.end(), cs);
|
1177
|
+
|
1178
|
+
InvertedLists::ScopedIds list_ids(invlists, list_no);
|
1179
|
+
int prev = -1; // all elements from prev to i-1 are equal
|
1233
1180
|
for (int i = 0; i < n; i++) {
|
1234
|
-
if (prev >= 0 && cs.cmp
|
1181
|
+
if (prev >= 0 && cs.cmp(ord[prev], ord[i]) == 0) {
|
1235
1182
|
// same as previous => remember
|
1236
1183
|
if (prev + 1 == i) { // start new group
|
1237
1184
|
ngroup++;
|
1238
1185
|
lims[ngroup] = lims[ngroup - 1];
|
1239
|
-
dup_ids
|
1186
|
+
dup_ids[lims[ngroup]++] = list_ids[ord[prev]];
|
1240
1187
|
}
|
1241
|
-
dup_ids
|
1188
|
+
dup_ids[lims[ngroup]++] = list_ids[ord[i]];
|
1242
1189
|
} else { // not same as previous.
|
1243
1190
|
prev = i;
|
1244
1191
|
}
|
@@ -1247,9 +1194,4 @@ size_t IndexIVFPQ::find_duplicates (idx_t *dup_ids, size_t *lims) const
|
|
1247
1194
|
return ngroup;
|
1248
1195
|
}
|
1249
1196
|
|
1250
|
-
|
1251
|
-
|
1252
|
-
|
1253
|
-
|
1254
|
-
|
1255
1197
|
} // namespace faiss
|