faiss 0.2.4 → 0.2.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +5 -0
- data/README.md +23 -21
- data/ext/faiss/extconf.rb +11 -0
- data/ext/faiss/index.cpp +4 -4
- data/ext/faiss/index_binary.cpp +6 -6
- data/ext/faiss/product_quantizer.cpp +4 -4
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/faiss/AutoTune.cpp +13 -0
- data/vendor/faiss/faiss/IVFlib.cpp +101 -2
- data/vendor/faiss/faiss/IVFlib.h +26 -2
- data/vendor/faiss/faiss/Index.cpp +36 -3
- data/vendor/faiss/faiss/Index.h +43 -6
- data/vendor/faiss/faiss/Index2Layer.cpp +6 -2
- data/vendor/faiss/faiss/Index2Layer.h +6 -1
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +219 -16
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +63 -5
- data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +299 -0
- data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +199 -0
- data/vendor/faiss/faiss/IndexBinary.cpp +20 -4
- data/vendor/faiss/faiss/IndexBinary.h +18 -3
- data/vendor/faiss/faiss/IndexBinaryFlat.cpp +9 -2
- data/vendor/faiss/faiss/IndexBinaryFlat.h +4 -2
- data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +4 -1
- data/vendor/faiss/faiss/IndexBinaryFromFloat.h +2 -1
- data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +5 -1
- data/vendor/faiss/faiss/IndexBinaryHNSW.h +2 -1
- data/vendor/faiss/faiss/IndexBinaryHash.cpp +17 -4
- data/vendor/faiss/faiss/IndexBinaryHash.h +8 -4
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +28 -13
- data/vendor/faiss/faiss/IndexBinaryIVF.h +10 -7
- data/vendor/faiss/faiss/IndexFastScan.cpp +626 -0
- data/vendor/faiss/faiss/IndexFastScan.h +145 -0
- data/vendor/faiss/faiss/IndexFlat.cpp +34 -21
- data/vendor/faiss/faiss/IndexFlat.h +7 -4
- data/vendor/faiss/faiss/IndexFlatCodes.cpp +35 -1
- data/vendor/faiss/faiss/IndexFlatCodes.h +12 -0
- data/vendor/faiss/faiss/IndexHNSW.cpp +66 -138
- data/vendor/faiss/faiss/IndexHNSW.h +4 -2
- data/vendor/faiss/faiss/IndexIDMap.cpp +247 -0
- data/vendor/faiss/faiss/IndexIDMap.h +107 -0
- data/vendor/faiss/faiss/IndexIVF.cpp +121 -33
- data/vendor/faiss/faiss/IndexIVF.h +35 -16
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +84 -7
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +63 -1
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +590 -0
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +171 -0
- data/vendor/faiss/faiss/IndexIVFFastScan.cpp +1290 -0
- data/vendor/faiss/faiss/IndexIVFFastScan.h +213 -0
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +37 -17
- data/vendor/faiss/faiss/IndexIVFFlat.h +4 -2
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +234 -50
- data/vendor/faiss/faiss/IndexIVFPQ.h +5 -1
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +23 -852
- data/vendor/faiss/faiss/IndexIVFPQFastScan.h +7 -112
- data/vendor/faiss/faiss/IndexIVFPQR.cpp +3 -3
- data/vendor/faiss/faiss/IndexIVFPQR.h +1 -1
- data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +3 -1
- data/vendor/faiss/faiss/IndexIVFSpectralHash.h +2 -1
- data/vendor/faiss/faiss/IndexLSH.cpp +4 -2
- data/vendor/faiss/faiss/IndexLSH.h +2 -1
- data/vendor/faiss/faiss/IndexLattice.cpp +7 -1
- data/vendor/faiss/faiss/IndexLattice.h +3 -1
- data/vendor/faiss/faiss/IndexNNDescent.cpp +4 -3
- data/vendor/faiss/faiss/IndexNNDescent.h +2 -1
- data/vendor/faiss/faiss/IndexNSG.cpp +37 -3
- data/vendor/faiss/faiss/IndexNSG.h +25 -1
- data/vendor/faiss/faiss/IndexPQ.cpp +106 -69
- data/vendor/faiss/faiss/IndexPQ.h +19 -5
- data/vendor/faiss/faiss/IndexPQFastScan.cpp +15 -450
- data/vendor/faiss/faiss/IndexPQFastScan.h +15 -78
- data/vendor/faiss/faiss/IndexPreTransform.cpp +47 -8
- data/vendor/faiss/faiss/IndexPreTransform.h +15 -3
- data/vendor/faiss/faiss/IndexRefine.cpp +8 -4
- data/vendor/faiss/faiss/IndexRefine.h +4 -2
- data/vendor/faiss/faiss/IndexReplicas.cpp +4 -2
- data/vendor/faiss/faiss/IndexReplicas.h +2 -1
- data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +438 -0
- data/vendor/faiss/faiss/IndexRowwiseMinMax.h +92 -0
- data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +26 -15
- data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -7
- data/vendor/faiss/faiss/IndexShards.cpp +4 -1
- data/vendor/faiss/faiss/IndexShards.h +2 -1
- data/vendor/faiss/faiss/MetaIndexes.cpp +5 -178
- data/vendor/faiss/faiss/MetaIndexes.h +3 -81
- data/vendor/faiss/faiss/VectorTransform.cpp +43 -0
- data/vendor/faiss/faiss/VectorTransform.h +22 -4
- data/vendor/faiss/faiss/clone_index.cpp +23 -1
- data/vendor/faiss/faiss/clone_index.h +3 -0
- data/vendor/faiss/faiss/cppcontrib/SaDecodeKernels.h +300 -0
- data/vendor/faiss/faiss/cppcontrib/detail/CoarseBitType.h +24 -0
- data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +195 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +2058 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +408 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-neon-inl.h +2147 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMax-inl.h +460 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMaxFP16-inl.h +465 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +1618 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +251 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-neon-inl.h +1452 -0
- data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +1 -0
- data/vendor/faiss/faiss/gpu/GpuCloner.cpp +0 -4
- data/vendor/faiss/faiss/gpu/GpuIndex.h +28 -4
- data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +2 -1
- data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +10 -8
- data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +75 -14
- data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +19 -32
- data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -31
- data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +22 -28
- data/vendor/faiss/faiss/gpu/GpuResources.cpp +14 -0
- data/vendor/faiss/faiss/gpu/GpuResources.h +16 -3
- data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +3 -3
- data/vendor/faiss/faiss/gpu/impl/IndexUtils.h +32 -0
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +1 -0
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +311 -75
- data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +10 -0
- data/vendor/faiss/faiss/gpu/test/TestUtils.h +3 -0
- data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +2 -2
- data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +5 -4
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +116 -47
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +44 -13
- data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +0 -54
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +0 -76
- data/vendor/faiss/faiss/impl/DistanceComputer.h +64 -0
- data/vendor/faiss/faiss/impl/HNSW.cpp +123 -27
- data/vendor/faiss/faiss/impl/HNSW.h +19 -16
- data/vendor/faiss/faiss/impl/IDSelector.cpp +125 -0
- data/vendor/faiss/faiss/impl/IDSelector.h +135 -0
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +6 -28
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +6 -1
- data/vendor/faiss/faiss/impl/LookupTableScaler.h +77 -0
- data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -0
- data/vendor/faiss/faiss/impl/NSG.cpp +1 -1
- data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +383 -0
- data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +154 -0
- data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +225 -145
- data/vendor/faiss/faiss/impl/ProductQuantizer.h +29 -10
- data/vendor/faiss/faiss/impl/Quantizer.h +43 -0
- data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +192 -36
- data/vendor/faiss/faiss/impl/ResidualQuantizer.h +40 -20
- data/vendor/faiss/faiss/impl/ResultHandler.h +96 -0
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +97 -173
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +18 -18
- data/vendor/faiss/faiss/impl/index_read.cpp +240 -9
- data/vendor/faiss/faiss/impl/index_write.cpp +237 -5
- data/vendor/faiss/faiss/impl/kmeans1d.cpp +6 -4
- data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +56 -16
- data/vendor/faiss/faiss/impl/pq4_fast_scan.h +25 -8
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +66 -25
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +75 -27
- data/vendor/faiss/faiss/index_factory.cpp +196 -7
- data/vendor/faiss/faiss/index_io.h +5 -0
- data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -0
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +4 -1
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +2 -1
- data/vendor/faiss/faiss/python/python_callbacks.cpp +27 -0
- data/vendor/faiss/faiss/python/python_callbacks.h +15 -0
- data/vendor/faiss/faiss/utils/Heap.h +31 -15
- data/vendor/faiss/faiss/utils/distances.cpp +380 -56
- data/vendor/faiss/faiss/utils/distances.h +113 -15
- data/vendor/faiss/faiss/utils/distances_simd.cpp +726 -6
- data/vendor/faiss/faiss/utils/extra_distances.cpp +12 -7
- data/vendor/faiss/faiss/utils/extra_distances.h +3 -1
- data/vendor/faiss/faiss/utils/fp16-fp16c.h +21 -0
- data/vendor/faiss/faiss/utils/fp16-inl.h +101 -0
- data/vendor/faiss/faiss/utils/fp16.h +11 -0
- data/vendor/faiss/faiss/utils/hamming-inl.h +54 -0
- data/vendor/faiss/faiss/utils/hamming.cpp +0 -48
- data/vendor/faiss/faiss/utils/ordered_key_value.h +10 -0
- data/vendor/faiss/faiss/utils/quantize_lut.cpp +62 -0
- data/vendor/faiss/faiss/utils/quantize_lut.h +20 -0
- data/vendor/faiss/faiss/utils/random.cpp +53 -0
- data/vendor/faiss/faiss/utils/random.h +5 -0
- data/vendor/faiss/faiss/utils/simdlib_avx2.h +4 -0
- data/vendor/faiss/faiss/utils/simdlib_emulated.h +6 -1
- data/vendor/faiss/faiss/utils/simdlib_neon.h +7 -2
- metadata +37 -3
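The largest addition in this release is the vendored FastScan refactor: the IVF fast-scan search path is hoisted out of `IndexIVFPQFastScan` into a new `IndexIVFFastScan` base class (full diff below), with additive-quantizer fast-scan variants built on top of it. For orientation only — this is not taken from the diff — a fast-scan IVF index is usually built through the index factory. The sketch below assumes the standard FAISS C++ API; the `"IVF256,PQ32x4fs"` factory string and all sizes are illustrative.

```cpp
// Minimal sketch (editor-added, not part of the package): driving the
// vendored fast-scan code from C++ via the index factory.
#include <faiss/Index.h>
#include <faiss/IndexIVF.h>
#include <faiss/index_factory.h>
#include <memory>
#include <random>
#include <vector>

int main() {
    int d = 64;        // vector dimensionality
    size_t nb = 10000; // database size
    size_t nq = 5;     // number of queries
    std::mt19937 rng(123);
    std::uniform_real_distribution<float> u(0, 1);
    std::vector<float> xb(nb * d), xq(nq * d);
    for (auto& v : xb) v = u(rng);
    for (auto& v : xq) v = u(rng);

    // "PQ32x4fs" selects the 4-bit fast-scan PQ codec added in this release
    std::unique_ptr<faiss::Index> index(
            faiss::index_factory(d, "IVF256,PQ32x4fs"));
    index->train(nb, xb.data());
    index->add(nb, xb.data());

    // the fast-scan index is an IndexIVF, so nprobe is set the usual way
    dynamic_cast<faiss::IndexIVF*>(index.get())->nprobe = 16;

    int k = 10;
    std::vector<float> D(nq * k);
    std::vector<faiss::Index::idx_t> I(nq * k);
    index->search(nq, xq.data(), k, D.data(), I.data());
}
```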
data/vendor/faiss/faiss/IndexIVFFastScan.cpp (new file, +1290 lines)

@@ -0,0 +1,1290 @@
+/**
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ *
+ * This source code is licensed under the MIT license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <faiss/IndexIVFFastScan.h>
+
+#include <cassert>
+#include <cinttypes>
+#include <cstdio>
+#include <set>
+
+#include <omp.h>
+
+#include <memory>
+
+#include <faiss/IndexIVFPQ.h>
+#include <faiss/impl/AuxIndexStructures.h>
+#include <faiss/impl/FaissAssert.h>
+#include <faiss/impl/LookupTableScaler.h>
+#include <faiss/impl/pq4_fast_scan.h>
+#include <faiss/impl/simd_result_handlers.h>
+#include <faiss/invlists/BlockInvertedLists.h>
+#include <faiss/utils/distances.h>
+#include <faiss/utils/hamming.h>
+#include <faiss/utils/quantize_lut.h>
+#include <faiss/utils/utils.h>
+
+namespace faiss {
+
+using namespace simd_result_handlers;
+
+inline size_t roundup(size_t a, size_t b) {
+    return (a + b - 1) / b * b;
+}
+
+IndexIVFFastScan::IndexIVFFastScan(
+        Index* quantizer,
+        size_t d,
+        size_t nlist,
+        size_t code_size,
+        MetricType metric)
+        : IndexIVF(quantizer, d, nlist, code_size, metric) {
+    FAISS_THROW_IF_NOT(metric == METRIC_L2 || metric == METRIC_INNER_PRODUCT);
+}
+
+IndexIVFFastScan::IndexIVFFastScan() {
+    bbs = 0;
+    M2 = 0;
+    is_trained = false;
+}
+
+void IndexIVFFastScan::init_fastscan(
+        size_t M,
+        size_t nbits,
+        size_t nlist,
+        MetricType /* metric */,
+        int bbs) {
+    FAISS_THROW_IF_NOT(bbs % 32 == 0);
+    FAISS_THROW_IF_NOT(nbits == 4);
+
+    this->M = M;
+    this->nbits = nbits;
+    this->bbs = bbs;
+    ksub = (1 << nbits);
+    M2 = roundup(M, 2);
+    code_size = M2 / 2;
+
+    is_trained = false;
+    replace_invlists(new BlockInvertedLists(nlist, bbs, bbs * M2 / 2), true);
+}
+
+IndexIVFFastScan::~IndexIVFFastScan() {}
+
+/*********************************************************
+ * Code management functions
+ *********************************************************/
+
+void IndexIVFFastScan::add_with_ids(
+        idx_t n,
+        const float* x,
+        const idx_t* xids) {
+    FAISS_THROW_IF_NOT(is_trained);
+
+    // do some blocking to avoid excessive allocs
+    constexpr idx_t bs = 65536;
+    if (n > bs) {
+        double t0 = getmillisecs();
+        for (idx_t i0 = 0; i0 < n; i0 += bs) {
+            idx_t i1 = std::min(n, i0 + bs);
+            if (verbose) {
+                double t1 = getmillisecs();
+                double elapsed_time = (t1 - t0) / 1000;
+                double total_time = 0;
+                if (i0 != 0) {
+                    total_time = elapsed_time / i0 * n;
+                }
+                size_t mem = get_mem_usage_kb() / (1 << 10);
+
+                printf("IndexIVFFastScan::add_with_ids %zd/%zd, time %.2f/%.2f, RSS %zdMB\n",
+                       size_t(i1),
+                       size_t(n),
+                       elapsed_time,
+                       total_time,
+                       mem);
+            }
+            add_with_ids(i1 - i0, x + i0 * d, xids ? xids + i0 : nullptr);
+        }
+        return;
+    }
+    InterruptCallback::check();
+
+    AlignedTable<uint8_t> codes(n * code_size);
+    direct_map.check_can_add(xids);
+    std::unique_ptr<idx_t[]> idx(new idx_t[n]);
+    quantizer->assign(n, x, idx.get());
+    size_t nadd = 0, nminus1 = 0;
+
+    for (size_t i = 0; i < n; i++) {
+        if (idx[i] < 0) {
+            nminus1++;
+        }
+    }
+
+    AlignedTable<uint8_t> flat_codes(n * code_size);
+    encode_vectors(n, x, idx.get(), flat_codes.get());
+
+    DirectMapAdd dm_adder(direct_map, n, xids);
+    BlockInvertedLists* bil = dynamic_cast<BlockInvertedLists*>(invlists);
+    FAISS_THROW_IF_NOT_MSG(bil, "only block inverted lists supported");
+
+    // prepare batches
+    std::vector<idx_t> order(n);
+    for (idx_t i = 0; i < n; i++) {
+        order[i] = i;
+    }
+
+    // TODO should not need stable
+    std::stable_sort(order.begin(), order.end(), [&idx](idx_t a, idx_t b) {
+        return idx[a] < idx[b];
+    });
+
+    // TODO parallelize
+    idx_t i0 = 0;
+    while (i0 < n) {
+        idx_t list_no = idx[order[i0]];
+        idx_t i1 = i0 + 1;
+        while (i1 < n && idx[order[i1]] == list_no) {
+            i1++;
+        }
+
+        if (list_no == -1) {
+            i0 = i1;
+            continue;
+        }
+
+        // make linear array
+        AlignedTable<uint8_t> list_codes((i1 - i0) * code_size);
+        size_t list_size = bil->list_size(list_no);
+
+        bil->resize(list_no, list_size + i1 - i0);
+
+        for (idx_t i = i0; i < i1; i++) {
+            size_t ofs = list_size + i - i0;
+            idx_t id = xids ? xids[order[i]] : ntotal + order[i];
+            dm_adder.add(order[i], list_no, ofs);
+            bil->ids[list_no][ofs] = id;
+            memcpy(list_codes.data() + (i - i0) * code_size,
+                   flat_codes.data() + order[i] * code_size,
+                   code_size);
+            nadd++;
+        }
+        pq4_pack_codes_range(
+                list_codes.data(),
+                M,
+                list_size,
+                list_size + i1 - i0,
+                bbs,
+                M2,
+                bil->codes[list_no].data());
+
+        i0 = i1;
+    }
+
+    ntotal += n;
+}
+
+/*********************************************************
+ * search
+ *********************************************************/
+
+namespace {
+
+template <class C, typename dis_t, class Scaler>
+void estimators_from_tables_generic(
+        const IndexIVFFastScan& index,
+        const uint8_t* codes,
+        size_t ncodes,
+        const dis_t* dis_table,
+        const int64_t* ids,
+        float bias,
+        size_t k,
+        typename C::T* heap_dis,
+        int64_t* heap_ids,
+        const Scaler& scaler) {
+    using accu_t = typename C::T;
+    for (size_t j = 0; j < ncodes; ++j) {
+        BitstringReader bsr(codes + j * index.code_size, index.code_size);
+        accu_t dis = bias;
+        const dis_t* __restrict dt = dis_table;
+        for (size_t m = 0; m < index.M - scaler.nscale; m++) {
+            uint64_t c = bsr.read(index.nbits);
+            dis += dt[c];
+            dt += index.ksub;
+        }
+
+        for (size_t m = 0; m < scaler.nscale; m++) {
+            uint64_t c = bsr.read(index.nbits);
+            dis += scaler.scale_one(dt[c]);
+            dt += index.ksub;
+        }
+
+        if (C::cmp(heap_dis[0], dis)) {
+            heap_pop<C>(k, heap_dis, heap_ids);
+            heap_push<C>(k, heap_dis, heap_ids, dis, ids[j]);
+        }
+    }
+}
+
+using idx_t = Index::idx_t;
+using namespace quantize_lut;
+
+} // anonymous namespace
+
+/*********************************************************
+ * Look-Up Table functions
+ *********************************************************/
+
+void IndexIVFFastScan::compute_LUT_uint8(
+        size_t n,
+        const float* x,
+        const idx_t* coarse_ids,
+        const float* coarse_dis,
+        AlignedTable<uint8_t>& dis_tables,
+        AlignedTable<uint16_t>& biases,
+        float* normalizers) const {
+    AlignedTable<float> dis_tables_float;
+    AlignedTable<float> biases_float;
+
+    uint64_t t0 = get_cy();
+    compute_LUT(n, x, coarse_ids, coarse_dis, dis_tables_float, biases_float);
+    IVFFastScan_stats.t_compute_distance_tables += get_cy() - t0;
+
+    bool lut_is_3d = lookup_table_is_3d();
+    size_t dim123 = ksub * M;
+    size_t dim123_2 = ksub * M2;
+    if (lut_is_3d) {
+        dim123 *= nprobe;
+        dim123_2 *= nprobe;
+    }
+    dis_tables.resize(n * dim123_2);
+    if (biases_float.get()) {
+        biases.resize(n * nprobe);
+    }
+    uint64_t t1 = get_cy();
+
+#pragma omp parallel for if (n > 100)
+    for (int64_t i = 0; i < n; i++) {
+        const float* t_in = dis_tables_float.get() + i * dim123;
+        const float* b_in = nullptr;
+        uint8_t* t_out = dis_tables.get() + i * dim123_2;
+        uint16_t* b_out = nullptr;
+        if (biases_float.get()) {
+            b_in = biases_float.get() + i * nprobe;
+            b_out = biases.get() + i * nprobe;
+        }
+
+        quantize_LUT_and_bias(
+                nprobe,
+                M,
+                ksub,
+                lut_is_3d,
+                t_in,
+                b_in,
+                t_out,
+                M2,
+                b_out,
+                normalizers + 2 * i,
+                normalizers + 2 * i + 1);
+    }
+    IVFFastScan_stats.t_round += get_cy() - t1;
+}
+
+/*********************************************************
+ * Search functions
+ *********************************************************/
+
+void IndexIVFFastScan::search(
+        idx_t n,
+        const float* x,
+        idx_t k,
+        float* distances,
+        idx_t* labels,
+        const SearchParameters* params) const {
+    FAISS_THROW_IF_NOT_MSG(
+            !params, "search params not supported for this index");
+    FAISS_THROW_IF_NOT(k > 0);
+
+    DummyScaler scaler;
+    if (metric_type == METRIC_L2) {
+        search_dispatch_implem<true>(n, x, k, distances, labels, scaler);
+    } else {
+        search_dispatch_implem<false>(n, x, k, distances, labels, scaler);
+    }
+}
+
+void IndexIVFFastScan::range_search(
+        idx_t,
+        const float*,
+        float,
+        RangeSearchResult*,
+        const SearchParameters*) const {
+    FAISS_THROW_MSG("not implemented");
+}
+
+template <bool is_max, class Scaler>
+void IndexIVFFastScan::search_dispatch_implem(
+        idx_t n,
+        const float* x,
+        idx_t k,
+        float* distances,
+        idx_t* labels,
+        const Scaler& scaler) const {
+    using Cfloat = typename std::conditional<
+            is_max,
+            CMax<float, int64_t>,
+            CMin<float, int64_t>>::type;
+
+    using C = typename std::conditional<
+            is_max,
+            CMax<uint16_t, int64_t>,
+            CMin<uint16_t, int64_t>>::type;
+
+    if (n == 0) {
+        return;
+    }
+
+    // actual implementation used
+    int impl = implem;
+
+    if (impl == 0) {
+        if (bbs == 32) {
+            impl = 12;
+        } else {
+            impl = 10;
+        }
+        if (k > 20) {
+            impl++;
+        }
+    }
+
+    if (impl == 1) {
+        search_implem_1<Cfloat>(n, x, k, distances, labels, scaler);
+    } else if (impl == 2) {
+        search_implem_2<C>(n, x, k, distances, labels, scaler);
+
+    } else if (impl >= 10 && impl <= 15) {
+        size_t ndis = 0, nlist_visited = 0;
+
+        if (n < 2) {
+            if (impl == 12 || impl == 13) {
+                search_implem_12<C>(
+                        n,
+                        x,
+                        k,
+                        distances,
+                        labels,
+                        impl,
+                        &ndis,
+                        &nlist_visited,
+                        scaler);
+            } else if (impl == 14 || impl == 15) {
+                search_implem_14<C>(n, x, k, distances, labels, impl, scaler);
+            } else {
+                search_implem_10<C>(
+                        n,
+                        x,
+                        k,
+                        distances,
+                        labels,
+                        impl,
+                        &ndis,
+                        &nlist_visited,
+                        scaler);
+            }
+        } else {
+            // explicitly slice over threads
+            int nslice;
+            if (n <= omp_get_max_threads()) {
+                nslice = n;
+            } else if (lookup_table_is_3d()) {
+                // make sure we don't make too big LUT tables
+                size_t lut_size_per_query =
+                        M * ksub * nprobe * (sizeof(float) + sizeof(uint8_t));
+
+                size_t max_lut_size = precomputed_table_max_bytes;
+                // how many queries we can handle within mem budget
+                size_t nq_ok =
+                        std::max(max_lut_size / lut_size_per_query, size_t(1));
+                nslice =
+                        roundup(std::max(size_t(n / nq_ok), size_t(1)),
+                                omp_get_max_threads());
+            } else {
+                // LUTs unlikely to be a limiting factor
+                nslice = omp_get_max_threads();
+            }
+            if (impl == 14 ||
+                impl == 15) { // this might require slicing if there are too
+                              // many queries (for now we keep this simple)
+                search_implem_14<C>(n, x, k, distances, labels, impl, scaler);
+            } else {
+#pragma omp parallel for reduction(+ : ndis, nlist_visited)
+                for (int slice = 0; slice < nslice; slice++) {
+                    idx_t i0 = n * slice / nslice;
+                    idx_t i1 = n * (slice + 1) / nslice;
+                    float* dis_i = distances + i0 * k;
+                    idx_t* lab_i = labels + i0 * k;
+                    if (impl == 12 || impl == 13) {
+                        search_implem_12<C>(
+                                i1 - i0,
+                                x + i0 * d,
+                                k,
+                                dis_i,
+                                lab_i,
+                                impl,
+                                &ndis,
+                                &nlist_visited,
+                                scaler);
+                    } else {
+                        search_implem_10<C>(
+                                i1 - i0,
+                                x + i0 * d,
+                                k,
+                                dis_i,
+                                lab_i,
+                                impl,
+                                &ndis,
+                                &nlist_visited,
+                                scaler);
+                    }
+                }
+            }
+        }
+        indexIVF_stats.nq += n;
+        indexIVF_stats.ndis += ndis;
+        indexIVF_stats.nlist += nlist_visited;
+    } else {
+        FAISS_THROW_FMT("implem %d does not exist", implem);
+    }
+}
+
+template <class C, class Scaler>
+void IndexIVFFastScan::search_implem_1(
+        idx_t n,
+        const float* x,
+        idx_t k,
+        float* distances,
+        idx_t* labels,
+        const Scaler& scaler) const {
+    FAISS_THROW_IF_NOT(orig_invlists);
+
+    std::unique_ptr<idx_t[]> coarse_ids(new idx_t[n * nprobe]);
+    std::unique_ptr<float[]> coarse_dis(new float[n * nprobe]);
+
+    quantizer->search(n, x, nprobe, coarse_dis.get(), coarse_ids.get());
+
+    size_t dim12 = ksub * M;
+    AlignedTable<float> dis_tables;
+    AlignedTable<float> biases;
+
+    compute_LUT(n, x, coarse_ids.get(), coarse_dis.get(), dis_tables, biases);
+
+    bool single_LUT = !lookup_table_is_3d();
+
+    size_t ndis = 0, nlist_visited = 0;
+
+#pragma omp parallel for reduction(+ : ndis, nlist_visited)
+    for (idx_t i = 0; i < n; i++) {
+        int64_t* heap_ids = labels + i * k;
+        float* heap_dis = distances + i * k;
+        heap_heapify<C>(k, heap_dis, heap_ids);
+        float* LUT = nullptr;
+
+        if (single_LUT) {
+            LUT = dis_tables.get() + i * dim12;
+        }
+        for (idx_t j = 0; j < nprobe; j++) {
+            if (!single_LUT) {
+                LUT = dis_tables.get() + (i * nprobe + j) * dim12;
+            }
+            idx_t list_no = coarse_ids[i * nprobe + j];
+            if (list_no < 0)
+                continue;
+            size_t ls = orig_invlists->list_size(list_no);
+            if (ls == 0)
+                continue;
+            InvertedLists::ScopedCodes codes(orig_invlists, list_no);
+            InvertedLists::ScopedIds ids(orig_invlists, list_no);
+
+            float bias = biases.get() ? biases[i * nprobe + j] : 0;
+
+            estimators_from_tables_generic<C>(
+                    *this,
+                    codes.get(),
+                    ls,
+                    LUT,
+                    ids.get(),
+                    bias,
+                    k,
+                    heap_dis,
+                    heap_ids,
+                    scaler);
+            nlist_visited++;
+            ndis++;
+        }
+        heap_reorder<C>(k, heap_dis, heap_ids);
+    }
+    indexIVF_stats.nq += n;
+    indexIVF_stats.ndis += ndis;
+    indexIVF_stats.nlist += nlist_visited;
+}
+
+template <class C, class Scaler>
+void IndexIVFFastScan::search_implem_2(
+        idx_t n,
+        const float* x,
+        idx_t k,
+        float* distances,
+        idx_t* labels,
+        const Scaler& scaler) const {
+    FAISS_THROW_IF_NOT(orig_invlists);
+
+    std::unique_ptr<idx_t[]> coarse_ids(new idx_t[n * nprobe]);
+    std::unique_ptr<float[]> coarse_dis(new float[n * nprobe]);
+
+    quantizer->search(n, x, nprobe, coarse_dis.get(), coarse_ids.get());
+
+    size_t dim12 = ksub * M2;
+    AlignedTable<uint8_t> dis_tables;
+    AlignedTable<uint16_t> biases;
+    std::unique_ptr<float[]> normalizers(new float[2 * n]);
+
+    compute_LUT_uint8(
+            n,
+            x,
+            coarse_ids.get(),
+            coarse_dis.get(),
+            dis_tables,
+            biases,
+            normalizers.get());
+
+    bool single_LUT = !lookup_table_is_3d();
+
+    size_t ndis = 0, nlist_visited = 0;
+
+#pragma omp parallel for reduction(+ : ndis, nlist_visited)
+    for (idx_t i = 0; i < n; i++) {
+        std::vector<uint16_t> tmp_dis(k);
+        int64_t* heap_ids = labels + i * k;
+        uint16_t* heap_dis = tmp_dis.data();
+        heap_heapify<C>(k, heap_dis, heap_ids);
+        const uint8_t* LUT = nullptr;
+
+        if (single_LUT) {
+            LUT = dis_tables.get() + i * dim12;
+        }
+        for (idx_t j = 0; j < nprobe; j++) {
+            if (!single_LUT) {
+                LUT = dis_tables.get() + (i * nprobe + j) * dim12;
+            }
+            idx_t list_no = coarse_ids[i * nprobe + j];
+            if (list_no < 0)
+                continue;
+            size_t ls = orig_invlists->list_size(list_no);
+            if (ls == 0)
+                continue;
+            InvertedLists::ScopedCodes codes(orig_invlists, list_no);
+            InvertedLists::ScopedIds ids(orig_invlists, list_no);
+
+            uint16_t bias = biases.get() ? biases[i * nprobe + j] : 0;
+
+            estimators_from_tables_generic<C>(
+                    *this,
+                    codes.get(),
+                    ls,
+                    LUT,
+                    ids.get(),
+                    bias,
+                    k,
+                    heap_dis,
+                    heap_ids,
+                    scaler);
+
+            nlist_visited++;
+            ndis += ls;
+        }
+        heap_reorder<C>(k, heap_dis, heap_ids);
+        // convert distances to float
+        {
+            float one_a = 1 / normalizers[2 * i], b = normalizers[2 * i + 1];
+            if (skip & 16) {
+                one_a = 1;
+                b = 0;
+            }
+            float* heap_dis_float = distances + i * k;
+            for (int j = 0; j < k; j++) {
+                heap_dis_float[j] = b + heap_dis[j] * one_a;
+            }
+        }
+    }
+    indexIVF_stats.nq += n;
+    indexIVF_stats.ndis += ndis;
+    indexIVF_stats.nlist += nlist_visited;
+}
+
+template <class C, class Scaler>
+void IndexIVFFastScan::search_implem_10(
+        idx_t n,
+        const float* x,
+        idx_t k,
+        float* distances,
+        idx_t* labels,
+        int impl,
+        size_t* ndis_out,
+        size_t* nlist_out,
+        const Scaler& scaler) const {
+    memset(distances, -1, sizeof(float) * k * n);
+    memset(labels, -1, sizeof(idx_t) * k * n);
+
+    using HeapHC = HeapHandler<C, true>;
+    using ReservoirHC = ReservoirHandler<C, true>;
+    using SingleResultHC = SingleResultHandler<C, true>;
+
+    std::unique_ptr<idx_t[]> coarse_ids(new idx_t[n * nprobe]);
+    std::unique_ptr<float[]> coarse_dis(new float[n * nprobe]);
+
+    uint64_t times[10];
+    memset(times, 0, sizeof(times));
+    int ti = 0;
+#define TIC times[ti++] = get_cy()
+    TIC;
+
+    quantizer->search(n, x, nprobe, coarse_dis.get(), coarse_ids.get());
+
+    TIC;
+
+    size_t dim12 = ksub * M2;
+    AlignedTable<uint8_t> dis_tables;
+    AlignedTable<uint16_t> biases;
+    std::unique_ptr<float[]> normalizers(new float[2 * n]);
+
+    compute_LUT_uint8(
+            n,
+            x,
+            coarse_ids.get(),
+            coarse_dis.get(),
+            dis_tables,
+            biases,
+            normalizers.get());
+
+    TIC;
+
+    bool single_LUT = !lookup_table_is_3d();
+
+    TIC;
+    size_t ndis = 0, nlist_visited = 0;
+
+    {
+        AlignedTable<uint16_t> tmp_distances(k);
+        for (idx_t i = 0; i < n; i++) {
+            const uint8_t* LUT = nullptr;
+            int qmap1[1] = {0};
+            std::unique_ptr<SIMDResultHandler<C, true>> handler;
+
+            if (k == 1) {
+                handler.reset(new SingleResultHC(1, 0));
+            } else if (impl == 10) {
+                handler.reset(new HeapHC(
+                        1, tmp_distances.get(), labels + i * k, k, 0));
+            } else if (impl == 11) {
+                handler.reset(new ReservoirHC(1, 0, k, 2 * k));
+            } else {
+                FAISS_THROW_MSG("invalid");
+            }
+
+            handler->q_map = qmap1;
+
+            if (single_LUT) {
+                LUT = dis_tables.get() + i * dim12;
+            }
+            for (idx_t j = 0; j < nprobe; j++) {
+                size_t ij = i * nprobe + j;
+                if (!single_LUT) {
+                    LUT = dis_tables.get() + ij * dim12;
+                }
+                if (biases.get()) {
+                    handler->dbias = biases.get() + ij;
+                }
+
+                idx_t list_no = coarse_ids[ij];
+                if (list_no < 0)
+                    continue;
+                size_t ls = invlists->list_size(list_no);
+                if (ls == 0)
+                    continue;
+
+                InvertedLists::ScopedCodes codes(invlists, list_no);
+                InvertedLists::ScopedIds ids(invlists, list_no);
+
+                handler->ntotal = ls;
+                handler->id_map = ids.get();
+
+#define DISPATCH(classHC)                                                      \
+    if (dynamic_cast<classHC*>(handler.get())) {                               \
+        auto* res = static_cast<classHC*>(handler.get());                      \
+        pq4_accumulate_loop(                                                   \
+                1, roundup(ls, bbs), bbs, M2, codes.get(), LUT, *res, scaler); \
+    }
+                DISPATCH(HeapHC)
+                else DISPATCH(ReservoirHC) else DISPATCH(SingleResultHC)
+#undef DISPATCH
+
+                nlist_visited++;
+                ndis++;
+            }
+
+            handler->to_flat_arrays(
+                    distances + i * k,
+                    labels + i * k,
+                    skip & 16 ? nullptr : normalizers.get() + i * 2);
+        }
+    }
+    *ndis_out = ndis;
+    *nlist_out = nlist;
+}
+
+template <class C, class Scaler>
+void IndexIVFFastScan::search_implem_12(
+        idx_t n,
+        const float* x,
+        idx_t k,
+        float* distances,
+        idx_t* labels,
+        int impl,
+        size_t* ndis_out,
+        size_t* nlist_out,
+        const Scaler& scaler) const {
+    if (n == 0) { // does not work well with reservoir
+        return;
+    }
+    FAISS_THROW_IF_NOT(bbs == 32);
+
+    std::unique_ptr<idx_t[]> coarse_ids(new idx_t[n * nprobe]);
+    std::unique_ptr<float[]> coarse_dis(new float[n * nprobe]);
+
+    uint64_t times[10];
+    memset(times, 0, sizeof(times));
+    int ti = 0;
+#define TIC times[ti++] = get_cy()
+    TIC;
+
+    quantizer->search(n, x, nprobe, coarse_dis.get(), coarse_ids.get());
+
+    TIC;
+
+    size_t dim12 = ksub * M2;
+    AlignedTable<uint8_t> dis_tables;
+    AlignedTable<uint16_t> biases;
+    std::unique_ptr<float[]> normalizers(new float[2 * n]);
+
+    compute_LUT_uint8(
+            n,
+            x,
+            coarse_ids.get(),
+            coarse_dis.get(),
+            dis_tables,
+            biases,
+            normalizers.get());
+
+    TIC;
+
+    struct QC {
+        int qno;     // sequence number of the query
+        int list_no; // list to visit
+        int rank;    // this is the rank'th result of the coarse quantizer
+    };
+    bool single_LUT = !lookup_table_is_3d();
+
+    std::vector<QC> qcs;
+    {
+        int ij = 0;
+        for (int i = 0; i < n; i++) {
+            for (int j = 0; j < nprobe; j++) {
+                if (coarse_ids[ij] >= 0) {
+                    qcs.push_back(QC{i, int(coarse_ids[ij]), int(j)});
+                }
+                ij++;
+            }
+        }
+        std::sort(qcs.begin(), qcs.end(), [](const QC& a, const QC& b) {
+            return a.list_no < b.list_no;
+        });
+    }
+    TIC;
+
+    // prepare the result handlers
+
+    std::unique_ptr<SIMDResultHandler<C, true>> handler;
+    AlignedTable<uint16_t> tmp_distances;
+
+    using HeapHC = HeapHandler<C, true>;
+    using ReservoirHC = ReservoirHandler<C, true>;
+    using SingleResultHC = SingleResultHandler<C, true>;
+
+    if (k == 1) {
+        handler.reset(new SingleResultHC(n, 0));
+    } else if (impl == 12) {
+        tmp_distances.resize(n * k);
+        handler.reset(new HeapHC(n, tmp_distances.get(), labels, k, 0));
+    } else if (impl == 13) {
+        handler.reset(new ReservoirHC(n, 0, k, 2 * k));
+    }
+
+    int qbs2 = this->qbs2 ? this->qbs2 : 11;
+
+    std::vector<uint16_t> tmp_bias;
+    if (biases.get()) {
+        tmp_bias.resize(qbs2);
+        handler->dbias = tmp_bias.data();
+    }
+    TIC;
+
+    size_t ndis = 0;
+
+    size_t i0 = 0;
+    uint64_t t_copy_pack = 0, t_scan = 0;
+    while (i0 < qcs.size()) {
+        uint64_t tt0 = get_cy();
+
+        // find all queries that access this inverted list
+        int list_no = qcs[i0].list_no;
+        size_t i1 = i0 + 1;
+
+        while (i1 < qcs.size() && i1 < i0 + qbs2) {
+            if (qcs[i1].list_no != list_no) {
+                break;
+            }
+            i1++;
+        }
+
+        size_t list_size = invlists->list_size(list_no);
+
+        if (list_size == 0) {
+            i0 = i1;
+            continue;
+        }
+
+        // re-organize LUTs and biases into the right order
+        int nc = i1 - i0;
+
+        std::vector<int> q_map(nc), lut_entries(nc);
+        AlignedTable<uint8_t> LUT(nc * dim12);
+        memset(LUT.get(), -1, nc * dim12);
+        int qbs = pq4_preferred_qbs(nc);
+
+        for (size_t i = i0; i < i1; i++) {
+            const QC& qc = qcs[i];
+            q_map[i - i0] = qc.qno;
+            int ij = qc.qno * nprobe + qc.rank;
+            lut_entries[i - i0] = single_LUT ? qc.qno : ij;
+            if (biases.get()) {
+                tmp_bias[i - i0] = biases[ij];
+            }
+        }
+        pq4_pack_LUT_qbs_q_map(
+                qbs, M2, dis_tables.get(), lut_entries.data(), LUT.get());
+
+        // access the inverted list
+
+        ndis += (i1 - i0) * list_size;
+
+        InvertedLists::ScopedCodes codes(invlists, list_no);
+        InvertedLists::ScopedIds ids(invlists, list_no);
+
+        // prepare the handler
+
+        handler->ntotal = list_size;
+        handler->q_map = q_map.data();
+        handler->id_map = ids.get();
+        uint64_t tt1 = get_cy();
+
+#define DISPATCH(classHC)                                                  \
+    if (dynamic_cast<classHC*>(handler.get())) {                           \
+        auto* res = static_cast<classHC*>(handler.get());                  \
+        pq4_accumulate_loop_qbs(                                           \
+                qbs, list_size, M2, codes.get(), LUT.get(), *res, scaler); \
+    }
+        DISPATCH(HeapHC)
+        else DISPATCH(ReservoirHC) else DISPATCH(SingleResultHC)
+
+        // prepare for next loop
+        i0 = i1;
+
+        uint64_t tt2 = get_cy();
+        t_copy_pack += tt1 - tt0;
+        t_scan += tt2 - tt1;
+    }
+    TIC;
+
+    // labels is in-place for HeapHC
+    handler->to_flat_arrays(
+            distances, labels, skip & 16 ? nullptr : normalizers.get());
+
+    TIC;
+
+    // these stats are not thread-safe
+
+    for (int i = 1; i < ti; i++) {
+        IVFFastScan_stats.times[i] += times[i] - times[i - 1];
+    }
+    IVFFastScan_stats.t_copy_pack += t_copy_pack;
+    IVFFastScan_stats.t_scan += t_scan;
+
+    if (auto* rh = dynamic_cast<ReservoirHC*>(handler.get())) {
+        for (int i = 0; i < 4; i++) {
+            IVFFastScan_stats.reservoir_times[i] += rh->times[i];
+        }
+    }
+
+    *ndis_out = ndis;
+    *nlist_out = nlist;
+}
+
+template <class C, class Scaler>
+void IndexIVFFastScan::search_implem_14(
+        idx_t n,
+        const float* x,
+        idx_t k,
+        float* distances,
+        idx_t* labels,
+        int impl,
+        const Scaler& scaler) const {
+    if (n == 0) { // does not work well with reservoir
+        return;
+    }
+    FAISS_THROW_IF_NOT(bbs == 32);
+
+    std::unique_ptr<idx_t[]> coarse_ids(new idx_t[n * nprobe]);
+    std::unique_ptr<float[]> coarse_dis(new float[n * nprobe]);
+
+    uint64_t ttg0 = get_cy();
+
+    quantizer->search(n, x, nprobe, coarse_dis.get(), coarse_ids.get());
+
+    uint64_t ttg1 = get_cy();
+    uint64_t coarse_search_tt = ttg1 - ttg0;
+
+    size_t dim12 = ksub * M2;
+    AlignedTable<uint8_t> dis_tables;
+    AlignedTable<uint16_t> biases;
+    std::unique_ptr<float[]> normalizers(new float[2 * n]);
+
+    compute_LUT_uint8(
+            n,
+            x,
+            coarse_ids.get(),
+            coarse_dis.get(),
+            dis_tables,
+            biases,
+            normalizers.get());
+
+    uint64_t ttg2 = get_cy();
+    uint64_t lut_compute_tt = ttg2 - ttg1;
+
+    struct QC {
+        int qno;     // sequence number of the query
+        int list_no; // list to visit
+        int rank;    // this is the rank'th result of the coarse quantizer
+    };
+    bool single_LUT = !lookup_table_is_3d();
+
+    std::vector<QC> qcs;
+    {
+        int ij = 0;
+        for (int i = 0; i < n; i++) {
+            for (int j = 0; j < nprobe; j++) {
+                if (coarse_ids[ij] >= 0) {
+                    qcs.push_back(QC{i, int(coarse_ids[ij]), int(j)});
+                }
+                ij++;
+            }
+        }
+        std::sort(qcs.begin(), qcs.end(), [](const QC& a, const QC& b) {
+            return a.list_no < b.list_no;
+        });
+    }
+
+    struct SE {
+        size_t start; // start in the QC vector
+        size_t end;   // end in the QC vector
+        size_t list_size;
+    };
+    std::vector<SE> ses;
+    size_t i0_l = 0;
+    while (i0_l < qcs.size()) {
+        // find all queries that access this inverted list
+        int list_no = qcs[i0_l].list_no;
+        size_t i1 = i0_l + 1;
+
+        while (i1 < qcs.size() && i1 < i0_l + qbs2) {
+            if (qcs[i1].list_no != list_no) {
+                break;
+            }
+            i1++;
+        }
+
+        size_t list_size = invlists->list_size(list_no);
+
+        if (list_size == 0) {
+            i0_l = i1;
+            continue;
+        }
+        ses.push_back(SE{i0_l, i1, list_size});
+        i0_l = i1;
+    }
+    uint64_t ttg3 = get_cy();
+    uint64_t compute_clusters_tt = ttg3 - ttg2;
+
+    // function to handle the global heap
+    using HeapForIP = CMin<float, idx_t>;
+    using HeapForL2 = CMax<float, idx_t>;
+    auto init_result = [&](float* simi, idx_t* idxi) {
+        if (metric_type == METRIC_INNER_PRODUCT) {
+            heap_heapify<HeapForIP>(k, simi, idxi);
+        } else {
+            heap_heapify<HeapForL2>(k, simi, idxi);
+        }
+    };
+
+    auto add_local_results = [&](const float* local_dis,
+                                 const idx_t* local_idx,
+                                 float* simi,
+                                 idx_t* idxi) {
+        if (metric_type == METRIC_INNER_PRODUCT) {
+            heap_addn<HeapForIP>(k, simi, idxi, local_dis, local_idx, k);
+        } else {
+            heap_addn<HeapForL2>(k, simi, idxi, local_dis, local_idx, k);
+        }
+    };
+
+    auto reorder_result = [&](float* simi, idx_t* idxi) {
+        if (metric_type == METRIC_INNER_PRODUCT) {
+            heap_reorder<HeapForIP>(k, simi, idxi);
+        } else {
+            heap_reorder<HeapForL2>(k, simi, idxi);
+        }
+    };
+    uint64_t ttg4 = get_cy();
+    uint64_t fn_tt = ttg4 - ttg3;
+
+    size_t ndis = 0;
+    size_t nlist_visited = 0;
+
+#pragma omp parallel reduction(+ : ndis, nlist_visited)
+    {
+        // storage for each thread
+        std::vector<idx_t> local_idx(k * n);
+        std::vector<float> local_dis(k * n);
+
+        // prepare the result handlers
+        std::unique_ptr<SIMDResultHandler<C, true>> handler;
+        AlignedTable<uint16_t> tmp_distances;
+
+        using HeapHC = HeapHandler<C, true>;
+        using ReservoirHC = ReservoirHandler<C, true>;
+        using SingleResultHC = SingleResultHandler<C, true>;
+
+        if (k == 1) {
+            handler.reset(new SingleResultHC(n, 0));
+        } else if (impl == 14) {
+            tmp_distances.resize(n * k);
+            handler.reset(
+                    new HeapHC(n, tmp_distances.get(), local_idx.data(), k, 0));
+        } else if (impl == 15) {
+            handler.reset(new ReservoirHC(n, 0, k, 2 * k));
+        }
+
+        int qbs2 = this->qbs2 ? this->qbs2 : 11;
+
+        std::vector<uint16_t> tmp_bias;
+        if (biases.get()) {
+            tmp_bias.resize(qbs2);
+            handler->dbias = tmp_bias.data();
+        }
+
+        uint64_t ttg5 = get_cy();
+        uint64_t handler_tt = ttg5 - ttg4;
+
+        std::set<int> q_set;
+        uint64_t t_copy_pack = 0, t_scan = 0;
+#pragma omp for schedule(dynamic)
+        for (idx_t cluster = 0; cluster < ses.size(); cluster++) {
+            uint64_t tt0 = get_cy();
+            size_t i0 = ses[cluster].start;
+            size_t i1 = ses[cluster].end;
+            size_t list_size = ses[cluster].list_size;
+            nlist_visited++;
+            int list_no = qcs[i0].list_no;
+
+            // re-organize LUTs and biases into the right order
+            int nc = i1 - i0;
+
+            std::vector<int> q_map(nc), lut_entries(nc);
+            AlignedTable<uint8_t> LUT(nc * dim12);
+            memset(LUT.get(), -1, nc * dim12);
+            int qbs = pq4_preferred_qbs(nc);
+
+            for (size_t i = i0; i < i1; i++) {
+                const QC& qc = qcs[i];
+                q_map[i - i0] = qc.qno;
+                q_set.insert(qc.qno);
+                int ij = qc.qno * nprobe + qc.rank;
+                lut_entries[i - i0] = single_LUT ? qc.qno : ij;
+                if (biases.get()) {
+                    tmp_bias[i - i0] = biases[ij];
+                }
+            }
+            pq4_pack_LUT_qbs_q_map(
+                    qbs, M2, dis_tables.get(), lut_entries.data(), LUT.get());
+
+            // access the inverted list
+
+            ndis += (i1 - i0) * list_size;
+
+            InvertedLists::ScopedCodes codes(invlists, list_no);
+            InvertedLists::ScopedIds ids(invlists, list_no);
+
+            // prepare the handler
+
+            handler->ntotal = list_size;
+            handler->q_map = q_map.data();
+            handler->id_map = ids.get();
+            uint64_t tt1 = get_cy();
+
+#define DISPATCH(classHC)                                                  \
+    if (dynamic_cast<classHC*>(handler.get())) {                           \
+        auto* res = static_cast<classHC*>(handler.get());                  \
+        pq4_accumulate_loop_qbs(                                           \
+                qbs, list_size, M2, codes.get(), LUT.get(), *res, scaler); \
+    }
+            DISPATCH(HeapHC)
+            else DISPATCH(ReservoirHC) else DISPATCH(SingleResultHC)
+
+            uint64_t tt2 = get_cy();
+            t_copy_pack += tt1 - tt0;
+            t_scan += tt2 - tt1;
+        }
+
+        // labels is in-place for HeapHC
+        handler->to_flat_arrays(
+                local_dis.data(),
+                local_idx.data(),
+                skip & 16 ? nullptr : normalizers.get());
+
+#pragma omp single
+        {
+            // we init the results as a heap
+            for (idx_t i = 0; i < n; i++) {
+                init_result(distances + i * k, labels + i * k);
+            }
+        }
+#pragma omp barrier
+#pragma omp critical
+        {
+            // write to global heap #go over only the queries
+            for (std::set<int>::iterator it = q_set.begin(); it != q_set.end();
+                 ++it) {
+                add_local_results(
+                        local_dis.data() + *it * k,
+                        local_idx.data() + *it * k,
+                        distances + *it * k,
+                        labels + *it * k);
+            }
+
+            IVFFastScan_stats.t_copy_pack += t_copy_pack;
+            IVFFastScan_stats.t_scan += t_scan;
+
+            if (auto* rh = dynamic_cast<ReservoirHC*>(handler.get())) {
+                for (int i = 0; i < 4; i++) {
+                    IVFFastScan_stats.reservoir_times[i] += rh->times[i];
+                }
+            }
+        }
+#pragma omp barrier
+#pragma omp single
+        {
+            for (idx_t i = 0; i < n; i++) {
+                reorder_result(distances + i * k, labels + i * k);
+            }
+        }
+    }
+
+    indexIVF_stats.nq += n;
+    indexIVF_stats.ndis += ndis;
+    indexIVF_stats.nlist += nlist_visited;
+}
+
+void IndexIVFFastScan::reconstruct_from_offset(
+        int64_t list_no,
+        int64_t offset,
+        float* recons) const {
+    // unpack codes
+    InvertedLists::ScopedCodes list_codes(invlists, list_no);
+    std::vector<uint8_t> code(code_size, 0);
+    BitstringWriter bsw(code.data(), code_size);
+    for (size_t m = 0; m < M; m++) {
+        uint8_t c =
+                pq4_get_packed_element(list_codes.get(), bbs, M2, offset, m);
+        bsw.write(c, nbits);
+    }
+    sa_decode(1, code.data(), recons);
+
+    // add centroid to it
+    if (by_residual) {
+        std::vector<float> centroid(d);
+        quantizer->reconstruct(list_no, centroid.data());
+        for (int i = 0; i < d; ++i) {
+            recons[i] += centroid[i];
+        }
+    }
+}
+
+void IndexIVFFastScan::reconstruct_orig_invlists() {
+    FAISS_THROW_IF_NOT(orig_invlists != nullptr);
+    FAISS_THROW_IF_NOT(orig_invlists->list_size(0) == 0);
+
+    for (size_t list_no = 0; list_no < nlist; list_no++) {
+        InvertedLists::ScopedCodes codes(invlists, list_no);
+        InvertedLists::ScopedIds ids(invlists, list_no);
+        size_t list_size = orig_invlists->list_size(list_no);
+        std::vector<uint8_t> code(code_size, 0);
+
+        for (size_t offset = 0; offset < list_size; offset++) {
+            // unpack codes
+            BitstringWriter bsw(code.data(), code_size);
+            for (size_t m = 0; m < M; m++) {
+                uint8_t c =
+                        pq4_get_packed_element(codes.get(), bbs, M2, offset, m);
+                bsw.write(c, nbits);
+            }
+
+            // get id
+            idx_t id = ids.get()[offset];
+
+            orig_invlists->add_entry(list_no, id, code.data());
+        }
+    }
+}
+
+IVFFastScanStats IVFFastScan_stats;
+
+template void IndexIVFFastScan::search_dispatch_implem<true, NormTableScaler>(
+        idx_t n,
+        const float* x,
+        idx_t k,
+        float* distances,
+        idx_t* labels,
+        const NormTableScaler& scaler) const;
+
+template void IndexIVFFastScan::search_dispatch_implem<false, NormTableScaler>(
+        idx_t n,
+        const float* x,
+        idx_t k,
+        float* distances,
+        idx_t* labels,
+        const NormTableScaler& scaler) const;
+
+} // namespace faiss
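As a closing note on the sizing rules in `init_fastscan` above: with `nbits == 4` there are 16 centroids per sub-quantizer, the sub-quantizer count is padded to an even `M2`, two 4-bit codes pack into each byte, and each `BlockInvertedLists` block stores `bbs` vectors in `bbs * M2 / 2` bytes. The standalone sketch below (editor-added; the concrete values are illustrative) just replays that arithmetic.

```cpp
// Worked example of the fast-scan sizing arithmetic from init_fastscan.
#include <cstddef>
#include <cstdio>

inline size_t roundup(size_t a, size_t b) {
    return (a + b - 1) / b * b;
}

int main() {
    size_t M = 21, nbits = 4, bbs = 32; // bbs must be a multiple of 32
    size_t ksub = 1 << nbits;           // 16 centroids per sub-quantizer
    size_t M2 = roundup(M, 2);          // 22: pad to an even sub-code count
    size_t code_size = M2 / 2;          // 11 bytes per packed vector
    size_t block_bytes = bbs * M2 / 2;  // 352 bytes per 32-vector block
    printf("ksub=%zu M2=%zu code_size=%zuB block=%zuB\n",
           ksub, M2, code_size, block_bytes);
}
```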