faiss 0.2.4 → 0.2.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +5 -0
- data/README.md +23 -21
- data/ext/faiss/extconf.rb +11 -0
- data/ext/faiss/index.cpp +4 -4
- data/ext/faiss/index_binary.cpp +6 -6
- data/ext/faiss/product_quantizer.cpp +4 -4
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/faiss/AutoTune.cpp +13 -0
- data/vendor/faiss/faiss/IVFlib.cpp +101 -2
- data/vendor/faiss/faiss/IVFlib.h +26 -2
- data/vendor/faiss/faiss/Index.cpp +36 -3
- data/vendor/faiss/faiss/Index.h +43 -6
- data/vendor/faiss/faiss/Index2Layer.cpp +6 -2
- data/vendor/faiss/faiss/Index2Layer.h +6 -1
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +219 -16
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +63 -5
- data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +299 -0
- data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +199 -0
- data/vendor/faiss/faiss/IndexBinary.cpp +20 -4
- data/vendor/faiss/faiss/IndexBinary.h +18 -3
- data/vendor/faiss/faiss/IndexBinaryFlat.cpp +9 -2
- data/vendor/faiss/faiss/IndexBinaryFlat.h +4 -2
- data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +4 -1
- data/vendor/faiss/faiss/IndexBinaryFromFloat.h +2 -1
- data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +5 -1
- data/vendor/faiss/faiss/IndexBinaryHNSW.h +2 -1
- data/vendor/faiss/faiss/IndexBinaryHash.cpp +17 -4
- data/vendor/faiss/faiss/IndexBinaryHash.h +8 -4
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +28 -13
- data/vendor/faiss/faiss/IndexBinaryIVF.h +10 -7
- data/vendor/faiss/faiss/IndexFastScan.cpp +626 -0
- data/vendor/faiss/faiss/IndexFastScan.h +145 -0
- data/vendor/faiss/faiss/IndexFlat.cpp +34 -21
- data/vendor/faiss/faiss/IndexFlat.h +7 -4
- data/vendor/faiss/faiss/IndexFlatCodes.cpp +35 -1
- data/vendor/faiss/faiss/IndexFlatCodes.h +12 -0
- data/vendor/faiss/faiss/IndexHNSW.cpp +66 -138
- data/vendor/faiss/faiss/IndexHNSW.h +4 -2
- data/vendor/faiss/faiss/IndexIDMap.cpp +247 -0
- data/vendor/faiss/faiss/IndexIDMap.h +107 -0
- data/vendor/faiss/faiss/IndexIVF.cpp +121 -33
- data/vendor/faiss/faiss/IndexIVF.h +35 -16
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +84 -7
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +63 -1
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +590 -0
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +171 -0
- data/vendor/faiss/faiss/IndexIVFFastScan.cpp +1290 -0
- data/vendor/faiss/faiss/IndexIVFFastScan.h +213 -0
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +37 -17
- data/vendor/faiss/faiss/IndexIVFFlat.h +4 -2
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +234 -50
- data/vendor/faiss/faiss/IndexIVFPQ.h +5 -1
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +23 -852
- data/vendor/faiss/faiss/IndexIVFPQFastScan.h +7 -112
- data/vendor/faiss/faiss/IndexIVFPQR.cpp +3 -3
- data/vendor/faiss/faiss/IndexIVFPQR.h +1 -1
- data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +3 -1
- data/vendor/faiss/faiss/IndexIVFSpectralHash.h +2 -1
- data/vendor/faiss/faiss/IndexLSH.cpp +4 -2
- data/vendor/faiss/faiss/IndexLSH.h +2 -1
- data/vendor/faiss/faiss/IndexLattice.cpp +7 -1
- data/vendor/faiss/faiss/IndexLattice.h +3 -1
- data/vendor/faiss/faiss/IndexNNDescent.cpp +4 -3
- data/vendor/faiss/faiss/IndexNNDescent.h +2 -1
- data/vendor/faiss/faiss/IndexNSG.cpp +37 -3
- data/vendor/faiss/faiss/IndexNSG.h +25 -1
- data/vendor/faiss/faiss/IndexPQ.cpp +106 -69
- data/vendor/faiss/faiss/IndexPQ.h +19 -5
- data/vendor/faiss/faiss/IndexPQFastScan.cpp +15 -450
- data/vendor/faiss/faiss/IndexPQFastScan.h +15 -78
- data/vendor/faiss/faiss/IndexPreTransform.cpp +47 -8
- data/vendor/faiss/faiss/IndexPreTransform.h +15 -3
- data/vendor/faiss/faiss/IndexRefine.cpp +8 -4
- data/vendor/faiss/faiss/IndexRefine.h +4 -2
- data/vendor/faiss/faiss/IndexReplicas.cpp +4 -2
- data/vendor/faiss/faiss/IndexReplicas.h +2 -1
- data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +438 -0
- data/vendor/faiss/faiss/IndexRowwiseMinMax.h +92 -0
- data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +26 -15
- data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -7
- data/vendor/faiss/faiss/IndexShards.cpp +4 -1
- data/vendor/faiss/faiss/IndexShards.h +2 -1
- data/vendor/faiss/faiss/MetaIndexes.cpp +5 -178
- data/vendor/faiss/faiss/MetaIndexes.h +3 -81
- data/vendor/faiss/faiss/VectorTransform.cpp +43 -0
- data/vendor/faiss/faiss/VectorTransform.h +22 -4
- data/vendor/faiss/faiss/clone_index.cpp +23 -1
- data/vendor/faiss/faiss/clone_index.h +3 -0
- data/vendor/faiss/faiss/cppcontrib/SaDecodeKernels.h +300 -0
- data/vendor/faiss/faiss/cppcontrib/detail/CoarseBitType.h +24 -0
- data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +195 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +2058 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +408 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-neon-inl.h +2147 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMax-inl.h +460 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMaxFP16-inl.h +465 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +1618 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +251 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-neon-inl.h +1452 -0
- data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +1 -0
- data/vendor/faiss/faiss/gpu/GpuCloner.cpp +0 -4
- data/vendor/faiss/faiss/gpu/GpuIndex.h +28 -4
- data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +2 -1
- data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +10 -8
- data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +75 -14
- data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +19 -32
- data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -31
- data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +22 -28
- data/vendor/faiss/faiss/gpu/GpuResources.cpp +14 -0
- data/vendor/faiss/faiss/gpu/GpuResources.h +16 -3
- data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +3 -3
- data/vendor/faiss/faiss/gpu/impl/IndexUtils.h +32 -0
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +1 -0
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +311 -75
- data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +10 -0
- data/vendor/faiss/faiss/gpu/test/TestUtils.h +3 -0
- data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +2 -2
- data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +5 -4
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +116 -47
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +44 -13
- data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +0 -54
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +0 -76
- data/vendor/faiss/faiss/impl/DistanceComputer.h +64 -0
- data/vendor/faiss/faiss/impl/HNSW.cpp +123 -27
- data/vendor/faiss/faiss/impl/HNSW.h +19 -16
- data/vendor/faiss/faiss/impl/IDSelector.cpp +125 -0
- data/vendor/faiss/faiss/impl/IDSelector.h +135 -0
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +6 -28
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +6 -1
- data/vendor/faiss/faiss/impl/LookupTableScaler.h +77 -0
- data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -0
- data/vendor/faiss/faiss/impl/NSG.cpp +1 -1
- data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +383 -0
- data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +154 -0
- data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +225 -145
- data/vendor/faiss/faiss/impl/ProductQuantizer.h +29 -10
- data/vendor/faiss/faiss/impl/Quantizer.h +43 -0
- data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +192 -36
- data/vendor/faiss/faiss/impl/ResidualQuantizer.h +40 -20
- data/vendor/faiss/faiss/impl/ResultHandler.h +96 -0
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +97 -173
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +18 -18
- data/vendor/faiss/faiss/impl/index_read.cpp +240 -9
- data/vendor/faiss/faiss/impl/index_write.cpp +237 -5
- data/vendor/faiss/faiss/impl/kmeans1d.cpp +6 -4
- data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +56 -16
- data/vendor/faiss/faiss/impl/pq4_fast_scan.h +25 -8
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +66 -25
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +75 -27
- data/vendor/faiss/faiss/index_factory.cpp +196 -7
- data/vendor/faiss/faiss/index_io.h +5 -0
- data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -0
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +4 -1
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +2 -1
- data/vendor/faiss/faiss/python/python_callbacks.cpp +27 -0
- data/vendor/faiss/faiss/python/python_callbacks.h +15 -0
- data/vendor/faiss/faiss/utils/Heap.h +31 -15
- data/vendor/faiss/faiss/utils/distances.cpp +380 -56
- data/vendor/faiss/faiss/utils/distances.h +113 -15
- data/vendor/faiss/faiss/utils/distances_simd.cpp +726 -6
- data/vendor/faiss/faiss/utils/extra_distances.cpp +12 -7
- data/vendor/faiss/faiss/utils/extra_distances.h +3 -1
- data/vendor/faiss/faiss/utils/fp16-fp16c.h +21 -0
- data/vendor/faiss/faiss/utils/fp16-inl.h +101 -0
- data/vendor/faiss/faiss/utils/fp16.h +11 -0
- data/vendor/faiss/faiss/utils/hamming-inl.h +54 -0
- data/vendor/faiss/faiss/utils/hamming.cpp +0 -48
- data/vendor/faiss/faiss/utils/ordered_key_value.h +10 -0
- data/vendor/faiss/faiss/utils/quantize_lut.cpp +62 -0
- data/vendor/faiss/faiss/utils/quantize_lut.h +20 -0
- data/vendor/faiss/faiss/utils/random.cpp +53 -0
- data/vendor/faiss/faiss/utils/random.h +5 -0
- data/vendor/faiss/faiss/utils/simdlib_avx2.h +4 -0
- data/vendor/faiss/faiss/utils/simdlib_emulated.h +6 -1
- data/vendor/faiss/faiss/utils/simdlib_neon.h +7 -2
- metadata +37 -3
|
@@ -14,17 +14,11 @@
|
|
|
14
14
|
#include <omp.h>
|
|
15
15
|
|
|
16
16
|
#include <faiss/impl/FaissAssert.h>
|
|
17
|
-
#include <faiss/utils/random.h>
|
|
18
|
-
#include <faiss/utils/utils.h>
|
|
19
|
-
|
|
20
17
|
#include <faiss/impl/pq4_fast_scan.h>
|
|
21
|
-
#include <faiss/
|
|
22
|
-
#include <faiss/utils/quantize_lut.h>
|
|
18
|
+
#include <faiss/utils/utils.h>
|
|
23
19
|
|
|
24
20
|
namespace faiss {
|
|
25
21
|
|
|
26
|
-
using namespace simd_result_handlers;
|
|
27
|
-
|
|
28
22
|
inline size_t roundup(size_t a, size_t b) {
|
|
29
23
|
return (a + b - 1) / b * b;
|
|
30
24
|
}
|
|
@@ -35,37 +29,19 @@ IndexPQFastScan::IndexPQFastScan(
|
|
|
35
29
|
size_t nbits,
|
|
36
30
|
MetricType metric,
|
|
37
31
|
int bbs)
|
|
38
|
-
:
|
|
39
|
-
|
|
40
|
-
bbs(bbs),
|
|
41
|
-
ntotal2(0),
|
|
42
|
-
M2(roundup(M, 2)) {
|
|
43
|
-
FAISS_THROW_IF_NOT(nbits == 4);
|
|
44
|
-
is_trained = false;
|
|
32
|
+
: pq(d, M, nbits) {
|
|
33
|
+
init_fastscan(d, M, nbits, metric, bbs);
|
|
45
34
|
}
|
|
46
35
|
|
|
47
|
-
IndexPQFastScan::IndexPQFastScan(
|
|
48
|
-
|
|
49
|
-
IndexPQFastScan::IndexPQFastScan(const IndexPQ& orig, int bbs)
|
|
50
|
-
: Index(orig.d, orig.metric_type), pq(orig.pq), bbs(bbs) {
|
|
51
|
-
FAISS_THROW_IF_NOT(orig.pq.nbits == 4);
|
|
36
|
+
IndexPQFastScan::IndexPQFastScan(const IndexPQ& orig, int bbs) : pq(orig.pq) {
|
|
37
|
+
init_fastscan(orig.d, pq.M, pq.nbits, orig.metric_type, bbs);
|
|
52
38
|
ntotal = orig.ntotal;
|
|
39
|
+
ntotal2 = roundup(ntotal, bbs);
|
|
53
40
|
is_trained = orig.is_trained;
|
|
54
41
|
orig_codes = orig.codes.data();
|
|
55
42
|
|
|
56
|
-
qbs = 0; // means use default
|
|
57
|
-
|
|
58
43
|
// pack the codes
|
|
59
|
-
|
|
60
|
-
size_t M = pq.M;
|
|
61
|
-
|
|
62
|
-
FAISS_THROW_IF_NOT(bbs % 32 == 0);
|
|
63
|
-
M2 = roundup(M, 2);
|
|
64
|
-
ntotal2 = roundup(ntotal, bbs);
|
|
65
|
-
|
|
66
44
|
codes.resize(ntotal2 * M2 / 2);
|
|
67
|
-
|
|
68
|
-
// printf("M=%d M2=%d code_size=%d\n", M, M2, pq.code_size);
|
|
69
45
|
pq4_pack_codes(orig.codes.data(), ntotal, M, ntotal2, bbs, M2, codes.get());
|
|
70
46
|
}
|
|
71
47
|
|
|
@@ -77,433 +53,22 @@ void IndexPQFastScan::train(idx_t n, const float* x) {
|
|
|
77
53
|
is_trained = true;
|
|
78
54
|
}
|
|
79
55
|
|
|
80
|
-
void IndexPQFastScan::add(idx_t n, const float* x) {
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
pq.compute_codes(x, tmp_codes.get(), n);
|
|
84
|
-
ntotal2 = roundup(ntotal + n, bbs);
|
|
85
|
-
size_t new_size = ntotal2 * M2 / 2;
|
|
86
|
-
size_t old_size = codes.size();
|
|
87
|
-
if (new_size > old_size) {
|
|
88
|
-
codes.resize(new_size);
|
|
89
|
-
memset(codes.get() + old_size, 0, new_size - old_size);
|
|
90
|
-
}
|
|
91
|
-
pq4_pack_codes_range(
|
|
92
|
-
tmp_codes.get(), pq.M, ntotal, ntotal + n, bbs, M2, codes.get());
|
|
93
|
-
ntotal += n;
|
|
94
|
-
}
|
|
95
|
-
|
|
96
|
-
void IndexPQFastScan::reset() {
|
|
97
|
-
codes.resize(0);
|
|
98
|
-
ntotal = 0;
|
|
99
|
-
}
|
|
100
|
-
|
|
101
|
-
namespace {
|
|
102
|
-
|
|
103
|
-
// from impl/ProductQuantizer.cpp
|
|
104
|
-
template <class C, typename dis_t>
|
|
105
|
-
void pq_estimators_from_tables_generic(
|
|
106
|
-
const ProductQuantizer& pq,
|
|
107
|
-
size_t nbits,
|
|
108
|
-
const uint8_t* codes,
|
|
109
|
-
size_t ncodes,
|
|
110
|
-
const dis_t* dis_table,
|
|
111
|
-
size_t k,
|
|
112
|
-
typename C::T* heap_dis,
|
|
113
|
-
int64_t* heap_ids) {
|
|
114
|
-
using accu_t = typename C::T;
|
|
115
|
-
const size_t M = pq.M;
|
|
116
|
-
const size_t ksub = pq.ksub;
|
|
117
|
-
for (size_t j = 0; j < ncodes; ++j) {
|
|
118
|
-
PQDecoderGeneric decoder(codes + j * pq.code_size, nbits);
|
|
119
|
-
accu_t dis = 0;
|
|
120
|
-
const dis_t* __restrict dt = dis_table;
|
|
121
|
-
for (size_t m = 0; m < M; m++) {
|
|
122
|
-
uint64_t c = decoder.decode();
|
|
123
|
-
dis += dt[c];
|
|
124
|
-
dt += ksub;
|
|
125
|
-
}
|
|
126
|
-
|
|
127
|
-
if (C::cmp(heap_dis[0], dis)) {
|
|
128
|
-
heap_pop<C>(k, heap_dis, heap_ids);
|
|
129
|
-
heap_push<C>(k, heap_dis, heap_ids, dis, j);
|
|
130
|
-
}
|
|
131
|
-
}
|
|
132
|
-
}
|
|
133
|
-
|
|
134
|
-
} // anonymous namespace
|
|
135
|
-
|
|
136
|
-
using namespace quantize_lut;
|
|
137
|
-
|
|
138
|
-
void IndexPQFastScan::compute_quantized_LUT(
|
|
139
|
-
idx_t n,
|
|
140
|
-
const float* x,
|
|
141
|
-
uint8_t* lut,
|
|
142
|
-
float* normalizers) const {
|
|
143
|
-
size_t dim12 = pq.ksub * pq.M;
|
|
144
|
-
std::unique_ptr<float[]> dis_tables(new float[n * dim12]);
|
|
145
|
-
if (metric_type == METRIC_L2) {
|
|
146
|
-
pq.compute_distance_tables(n, x, dis_tables.get());
|
|
147
|
-
} else {
|
|
148
|
-
pq.compute_inner_prod_tables(n, x, dis_tables.get());
|
|
149
|
-
}
|
|
150
|
-
|
|
151
|
-
for (uint64_t i = 0; i < n; i++) {
|
|
152
|
-
round_uint8_per_column(
|
|
153
|
-
dis_tables.get() + i * dim12,
|
|
154
|
-
pq.M,
|
|
155
|
-
pq.ksub,
|
|
156
|
-
&normalizers[2 * i],
|
|
157
|
-
&normalizers[2 * i + 1]);
|
|
158
|
-
}
|
|
159
|
-
|
|
160
|
-
for (uint64_t i = 0; i < n; i++) {
|
|
161
|
-
const float* t_in = dis_tables.get() + i * dim12;
|
|
162
|
-
uint8_t* t_out = lut + i * M2 * pq.ksub;
|
|
163
|
-
|
|
164
|
-
for (int j = 0; j < dim12; j++) {
|
|
165
|
-
t_out[j] = int(t_in[j]);
|
|
166
|
-
}
|
|
167
|
-
memset(t_out + dim12, 0, (M2 - pq.M) * pq.ksub);
|
|
168
|
-
}
|
|
56
|
+
void IndexPQFastScan::compute_codes(uint8_t* codes, idx_t n, const float* x)
|
|
57
|
+
const {
|
|
58
|
+
pq.compute_codes(x, codes, n);
|
|
169
59
|
}
|
|
170
60
|
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
******************************************************************************/
|
|
174
|
-
|
|
175
|
-
void IndexPQFastScan::search(
|
|
176
|
-
idx_t n,
|
|
177
|
-
const float* x,
|
|
178
|
-
idx_t k,
|
|
179
|
-
float* distances,
|
|
180
|
-
idx_t* labels) const {
|
|
181
|
-
FAISS_THROW_IF_NOT(k > 0);
|
|
182
|
-
|
|
61
|
+
void IndexPQFastScan::compute_float_LUT(float* lut, idx_t n, const float* x)
|
|
62
|
+
const {
|
|
183
63
|
if (metric_type == METRIC_L2) {
|
|
184
|
-
|
|
64
|
+
pq.compute_distance_tables(n, x, lut);
|
|
185
65
|
} else {
|
|
186
|
-
|
|
66
|
+
pq.compute_inner_prod_tables(n, x, lut);
|
|
187
67
|
}
|
|
188
68
|
}
|
|
189
69
|
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
idx_t n,
|
|
193
|
-
const float* x,
|
|
194
|
-
idx_t k,
|
|
195
|
-
float* distances,
|
|
196
|
-
idx_t* labels) const {
|
|
197
|
-
using Cfloat = typename std::conditional<
|
|
198
|
-
is_max,
|
|
199
|
-
CMax<float, int64_t>,
|
|
200
|
-
CMin<float, int64_t>>::type;
|
|
201
|
-
|
|
202
|
-
using C = typename std::
|
|
203
|
-
conditional<is_max, CMax<uint16_t, int>, CMin<uint16_t, int>>::type;
|
|
204
|
-
|
|
205
|
-
if (n == 0) {
|
|
206
|
-
return;
|
|
207
|
-
}
|
|
208
|
-
|
|
209
|
-
// actual implementation used
|
|
210
|
-
int impl = implem;
|
|
211
|
-
|
|
212
|
-
if (impl == 0) {
|
|
213
|
-
if (bbs == 32) {
|
|
214
|
-
impl = 12;
|
|
215
|
-
} else {
|
|
216
|
-
impl = 14;
|
|
217
|
-
}
|
|
218
|
-
if (k > 20) {
|
|
219
|
-
impl++;
|
|
220
|
-
}
|
|
221
|
-
}
|
|
222
|
-
|
|
223
|
-
if (implem == 1) {
|
|
224
|
-
FAISS_THROW_IF_NOT(orig_codes);
|
|
225
|
-
FAISS_THROW_IF_NOT(is_max);
|
|
226
|
-
float_maxheap_array_t res = {size_t(n), size_t(k), labels, distances};
|
|
227
|
-
pq.search(x, n, orig_codes, ntotal, &res, true);
|
|
228
|
-
} else if (implem == 2 || implem == 3 || implem == 4) {
|
|
229
|
-
FAISS_THROW_IF_NOT(orig_codes);
|
|
230
|
-
|
|
231
|
-
size_t dim12 = pq.ksub * pq.M;
|
|
232
|
-
std::unique_ptr<float[]> dis_tables(new float[n * dim12]);
|
|
233
|
-
if (is_max) {
|
|
234
|
-
pq.compute_distance_tables(n, x, dis_tables.get());
|
|
235
|
-
} else {
|
|
236
|
-
pq.compute_inner_prod_tables(n, x, dis_tables.get());
|
|
237
|
-
}
|
|
238
|
-
|
|
239
|
-
std::vector<float> normalizers(n * 2);
|
|
240
|
-
|
|
241
|
-
if (implem == 2) {
|
|
242
|
-
// default float
|
|
243
|
-
} else if (implem == 3 || implem == 4) {
|
|
244
|
-
for (uint64_t i = 0; i < n; i++) {
|
|
245
|
-
round_uint8_per_column(
|
|
246
|
-
dis_tables.get() + i * dim12,
|
|
247
|
-
pq.M,
|
|
248
|
-
pq.ksub,
|
|
249
|
-
&normalizers[2 * i],
|
|
250
|
-
&normalizers[2 * i + 1]);
|
|
251
|
-
}
|
|
252
|
-
}
|
|
253
|
-
|
|
254
|
-
for (int64_t i = 0; i < n; i++) {
|
|
255
|
-
int64_t* heap_ids = labels + i * k;
|
|
256
|
-
float* heap_dis = distances + i * k;
|
|
257
|
-
|
|
258
|
-
heap_heapify<Cfloat>(k, heap_dis, heap_ids);
|
|
259
|
-
|
|
260
|
-
pq_estimators_from_tables_generic<Cfloat>(
|
|
261
|
-
pq,
|
|
262
|
-
pq.nbits,
|
|
263
|
-
orig_codes,
|
|
264
|
-
ntotal,
|
|
265
|
-
dis_tables.get() + i * dim12,
|
|
266
|
-
k,
|
|
267
|
-
heap_dis,
|
|
268
|
-
heap_ids);
|
|
269
|
-
|
|
270
|
-
heap_reorder<Cfloat>(k, heap_dis, heap_ids);
|
|
271
|
-
|
|
272
|
-
if (implem == 4) {
|
|
273
|
-
float a = normalizers[2 * i];
|
|
274
|
-
float b = normalizers[2 * i + 1];
|
|
275
|
-
|
|
276
|
-
for (int j = 0; j < k; j++) {
|
|
277
|
-
heap_dis[j] = heap_dis[j] / a + b;
|
|
278
|
-
}
|
|
279
|
-
}
|
|
280
|
-
}
|
|
281
|
-
} else if (impl >= 12 && impl <= 15) {
|
|
282
|
-
FAISS_THROW_IF_NOT(ntotal < INT_MAX);
|
|
283
|
-
int nt = std::min(omp_get_max_threads(), int(n));
|
|
284
|
-
if (nt < 2) {
|
|
285
|
-
if (impl == 12 || impl == 13) {
|
|
286
|
-
search_implem_12<C>(n, x, k, distances, labels, impl);
|
|
287
|
-
} else {
|
|
288
|
-
search_implem_14<C>(n, x, k, distances, labels, impl);
|
|
289
|
-
}
|
|
290
|
-
} else {
|
|
291
|
-
// explicitly slice over threads
|
|
292
|
-
#pragma omp parallel for num_threads(nt)
|
|
293
|
-
for (int slice = 0; slice < nt; slice++) {
|
|
294
|
-
idx_t i0 = n * slice / nt;
|
|
295
|
-
idx_t i1 = n * (slice + 1) / nt;
|
|
296
|
-
float* dis_i = distances + i0 * k;
|
|
297
|
-
idx_t* lab_i = labels + i0 * k;
|
|
298
|
-
if (impl == 12 || impl == 13) {
|
|
299
|
-
search_implem_12<C>(
|
|
300
|
-
i1 - i0, x + i0 * d, k, dis_i, lab_i, impl);
|
|
301
|
-
} else {
|
|
302
|
-
search_implem_14<C>(
|
|
303
|
-
i1 - i0, x + i0 * d, k, dis_i, lab_i, impl);
|
|
304
|
-
}
|
|
305
|
-
}
|
|
306
|
-
}
|
|
307
|
-
} else {
|
|
308
|
-
FAISS_THROW_FMT("invalid implem %d impl=%d", implem, impl);
|
|
309
|
-
}
|
|
310
|
-
}
|
|
311
|
-
|
|
312
|
-
template <class C>
|
|
313
|
-
void IndexPQFastScan::search_implem_12(
|
|
314
|
-
idx_t n,
|
|
315
|
-
const float* x,
|
|
316
|
-
idx_t k,
|
|
317
|
-
float* distances,
|
|
318
|
-
idx_t* labels,
|
|
319
|
-
int impl) const {
|
|
320
|
-
FAISS_THROW_IF_NOT(bbs == 32);
|
|
321
|
-
|
|
322
|
-
// handle qbs2 blocking by recursive call
|
|
323
|
-
int64_t qbs2 = this->qbs == 0 ? 11 : pq4_qbs_to_nq(this->qbs);
|
|
324
|
-
if (n > qbs2) {
|
|
325
|
-
for (int64_t i0 = 0; i0 < n; i0 += qbs2) {
|
|
326
|
-
int64_t i1 = std::min(i0 + qbs2, n);
|
|
327
|
-
search_implem_12<C>(
|
|
328
|
-
i1 - i0,
|
|
329
|
-
x + d * i0,
|
|
330
|
-
k,
|
|
331
|
-
distances + i0 * k,
|
|
332
|
-
labels + i0 * k,
|
|
333
|
-
impl);
|
|
334
|
-
}
|
|
335
|
-
return;
|
|
336
|
-
}
|
|
337
|
-
|
|
338
|
-
size_t dim12 = pq.ksub * M2;
|
|
339
|
-
AlignedTable<uint8_t> quantized_dis_tables(n * dim12);
|
|
340
|
-
std::unique_ptr<float[]> normalizers(new float[2 * n]);
|
|
341
|
-
|
|
342
|
-
if (skip & 1) {
|
|
343
|
-
quantized_dis_tables.clear();
|
|
344
|
-
} else {
|
|
345
|
-
compute_quantized_LUT(
|
|
346
|
-
n, x, quantized_dis_tables.get(), normalizers.get());
|
|
347
|
-
}
|
|
348
|
-
|
|
349
|
-
AlignedTable<uint8_t> LUT(n * dim12);
|
|
350
|
-
|
|
351
|
-
// block sizes are encoded in qbs, 4 bits at a time
|
|
352
|
-
|
|
353
|
-
// caution: we override an object field
|
|
354
|
-
int qbs = this->qbs;
|
|
355
|
-
|
|
356
|
-
if (n != pq4_qbs_to_nq(qbs)) {
|
|
357
|
-
qbs = pq4_preferred_qbs(n);
|
|
358
|
-
}
|
|
359
|
-
|
|
360
|
-
int LUT_nq =
|
|
361
|
-
pq4_pack_LUT_qbs(qbs, M2, quantized_dis_tables.get(), LUT.get());
|
|
362
|
-
FAISS_THROW_IF_NOT(LUT_nq == n);
|
|
363
|
-
|
|
364
|
-
if (k == 1) {
|
|
365
|
-
SingleResultHandler<C> handler(n, ntotal);
|
|
366
|
-
if (skip & 4) {
|
|
367
|
-
// pass
|
|
368
|
-
} else {
|
|
369
|
-
handler.disable = bool(skip & 2);
|
|
370
|
-
pq4_accumulate_loop_qbs(
|
|
371
|
-
qbs, ntotal2, M2, codes.get(), LUT.get(), handler);
|
|
372
|
-
}
|
|
373
|
-
|
|
374
|
-
handler.to_flat_arrays(distances, labels, normalizers.get());
|
|
375
|
-
|
|
376
|
-
} else if (impl == 12) {
|
|
377
|
-
std::vector<uint16_t> tmp_dis(n * k);
|
|
378
|
-
std::vector<int32_t> tmp_ids(n * k);
|
|
379
|
-
|
|
380
|
-
if (skip & 4) {
|
|
381
|
-
// skip
|
|
382
|
-
} else {
|
|
383
|
-
HeapHandler<C> handler(
|
|
384
|
-
n, tmp_dis.data(), tmp_ids.data(), k, ntotal);
|
|
385
|
-
handler.disable = bool(skip & 2);
|
|
386
|
-
|
|
387
|
-
pq4_accumulate_loop_qbs(
|
|
388
|
-
qbs, ntotal2, M2, codes.get(), LUT.get(), handler);
|
|
389
|
-
|
|
390
|
-
if (!(skip & 8)) {
|
|
391
|
-
handler.to_flat_arrays(distances, labels, normalizers.get());
|
|
392
|
-
}
|
|
393
|
-
}
|
|
394
|
-
|
|
395
|
-
} else { // impl == 13
|
|
396
|
-
|
|
397
|
-
ReservoirHandler<C> handler(n, ntotal, k, 2 * k);
|
|
398
|
-
handler.disable = bool(skip & 2);
|
|
399
|
-
|
|
400
|
-
if (skip & 4) {
|
|
401
|
-
// skip
|
|
402
|
-
} else {
|
|
403
|
-
pq4_accumulate_loop_qbs(
|
|
404
|
-
qbs, ntotal2, M2, codes.get(), LUT.get(), handler);
|
|
405
|
-
}
|
|
406
|
-
|
|
407
|
-
if (!(skip & 8)) {
|
|
408
|
-
handler.to_flat_arrays(distances, labels, normalizers.get());
|
|
409
|
-
}
|
|
410
|
-
|
|
411
|
-
FastScan_stats.t0 += handler.times[0];
|
|
412
|
-
FastScan_stats.t1 += handler.times[1];
|
|
413
|
-
FastScan_stats.t2 += handler.times[2];
|
|
414
|
-
FastScan_stats.t3 += handler.times[3];
|
|
415
|
-
}
|
|
416
|
-
}
|
|
417
|
-
|
|
418
|
-
FastScanStats FastScan_stats;
|
|
419
|
-
|
|
420
|
-
template <class C>
|
|
421
|
-
void IndexPQFastScan::search_implem_14(
|
|
422
|
-
idx_t n,
|
|
423
|
-
const float* x,
|
|
424
|
-
idx_t k,
|
|
425
|
-
float* distances,
|
|
426
|
-
idx_t* labels,
|
|
427
|
-
int impl) const {
|
|
428
|
-
FAISS_THROW_IF_NOT(bbs % 32 == 0);
|
|
429
|
-
|
|
430
|
-
int qbs2 = qbs == 0 ? 4 : qbs;
|
|
431
|
-
|
|
432
|
-
// handle qbs2 blocking by recursive call
|
|
433
|
-
if (n > qbs2) {
|
|
434
|
-
for (int64_t i0 = 0; i0 < n; i0 += qbs2) {
|
|
435
|
-
int64_t i1 = std::min(i0 + qbs2, n);
|
|
436
|
-
search_implem_14<C>(
|
|
437
|
-
i1 - i0,
|
|
438
|
-
x + d * i0,
|
|
439
|
-
k,
|
|
440
|
-
distances + i0 * k,
|
|
441
|
-
labels + i0 * k,
|
|
442
|
-
impl);
|
|
443
|
-
}
|
|
444
|
-
return;
|
|
445
|
-
}
|
|
446
|
-
|
|
447
|
-
size_t dim12 = pq.ksub * M2;
|
|
448
|
-
AlignedTable<uint8_t> quantized_dis_tables(n * dim12);
|
|
449
|
-
std::unique_ptr<float[]> normalizers(new float[2 * n]);
|
|
450
|
-
|
|
451
|
-
if (skip & 1) {
|
|
452
|
-
quantized_dis_tables.clear();
|
|
453
|
-
} else {
|
|
454
|
-
compute_quantized_LUT(
|
|
455
|
-
n, x, quantized_dis_tables.get(), normalizers.get());
|
|
456
|
-
}
|
|
457
|
-
|
|
458
|
-
AlignedTable<uint8_t> LUT(n * dim12);
|
|
459
|
-
pq4_pack_LUT(n, M2, quantized_dis_tables.get(), LUT.get());
|
|
460
|
-
|
|
461
|
-
if (k == 1) {
|
|
462
|
-
SingleResultHandler<C> handler(n, ntotal);
|
|
463
|
-
if (skip & 4) {
|
|
464
|
-
// pass
|
|
465
|
-
} else {
|
|
466
|
-
handler.disable = bool(skip & 2);
|
|
467
|
-
pq4_accumulate_loop(
|
|
468
|
-
n, ntotal2, bbs, M2, codes.get(), LUT.get(), handler);
|
|
469
|
-
}
|
|
470
|
-
handler.to_flat_arrays(distances, labels, normalizers.get());
|
|
471
|
-
|
|
472
|
-
} else if (impl == 14) {
|
|
473
|
-
std::vector<uint16_t> tmp_dis(n * k);
|
|
474
|
-
std::vector<int32_t> tmp_ids(n * k);
|
|
475
|
-
|
|
476
|
-
if (skip & 4) {
|
|
477
|
-
// skip
|
|
478
|
-
} else if (k > 1) {
|
|
479
|
-
HeapHandler<C> handler(
|
|
480
|
-
n, tmp_dis.data(), tmp_ids.data(), k, ntotal);
|
|
481
|
-
handler.disable = bool(skip & 2);
|
|
482
|
-
|
|
483
|
-
pq4_accumulate_loop(
|
|
484
|
-
n, ntotal2, bbs, M2, codes.get(), LUT.get(), handler);
|
|
485
|
-
|
|
486
|
-
if (!(skip & 8)) {
|
|
487
|
-
handler.to_flat_arrays(distances, labels, normalizers.get());
|
|
488
|
-
}
|
|
489
|
-
}
|
|
490
|
-
|
|
491
|
-
} else { // impl == 15
|
|
492
|
-
|
|
493
|
-
ReservoirHandler<C> handler(n, ntotal, k, 2 * k);
|
|
494
|
-
handler.disable = bool(skip & 2);
|
|
495
|
-
|
|
496
|
-
if (skip & 4) {
|
|
497
|
-
// skip
|
|
498
|
-
} else {
|
|
499
|
-
pq4_accumulate_loop(
|
|
500
|
-
n, ntotal2, bbs, M2, codes.get(), LUT.get(), handler);
|
|
501
|
-
}
|
|
502
|
-
|
|
503
|
-
if (!(skip & 8)) {
|
|
504
|
-
handler.to_flat_arrays(distances, labels, normalizers.get());
|
|
505
|
-
}
|
|
506
|
-
}
|
|
70
|
+
void IndexPQFastScan::sa_decode(idx_t n, const uint8_t* bytes, float* x) const {
|
|
71
|
+
pq.decode(bytes, x, n);
|
|
507
72
|
}
|
|
508
73
|
|
|
509
74
|
} // namespace faiss
|
|
@@ -7,6 +7,7 @@
|
|
|
7
7
|
|
|
8
8
|
#pragma once
|
|
9
9
|
|
|
10
|
+
#include <faiss/IndexFastScan.h>
|
|
10
11
|
#include <faiss/IndexPQ.h>
|
|
11
12
|
#include <faiss/impl/ProductQuantizer.h>
|
|
12
13
|
#include <faiss/utils/AlignedTable.h>
|
|
@@ -25,27 +26,9 @@ namespace faiss {
|
|
|
25
26
|
* 15: no qbs with reservoir accumulator
|
|
26
27
|
*/
|
|
27
28
|
|
|
28
|
-
struct IndexPQFastScan : Index {
|
|
29
|
+
struct IndexPQFastScan : IndexFastScan {
|
|
29
30
|
ProductQuantizer pq;
|
|
30
31
|
|
|
31
|
-
// implementation to select
|
|
32
|
-
int implem = 0;
|
|
33
|
-
// skip some parts of the computation (for timing)
|
|
34
|
-
int skip = 0;
|
|
35
|
-
|
|
36
|
-
// size of the kernel
|
|
37
|
-
int bbs; // set at build time
|
|
38
|
-
int qbs = 0; // query block size 0 = use default
|
|
39
|
-
|
|
40
|
-
// packed version of the codes
|
|
41
|
-
size_t ntotal2;
|
|
42
|
-
size_t M2;
|
|
43
|
-
|
|
44
|
-
AlignedTable<uint8_t> codes;
|
|
45
|
-
|
|
46
|
-
// this is for testing purposes only (set when initialized by IndexPQ)
|
|
47
|
-
const uint8_t* orig_codes = nullptr;
|
|
48
|
-
|
|
49
32
|
IndexPQFastScan(
|
|
50
33
|
int d,
|
|
51
34
|
size_t M,
|
|
@@ -53,73 +36,27 @@ struct IndexPQFastScan : Index {
|
|
|
53
36
|
MetricType metric = METRIC_L2,
|
|
54
37
|
int bbs = 32);
|
|
55
38
|
|
|
56
|
-
IndexPQFastScan();
|
|
39
|
+
IndexPQFastScan() = default;
|
|
57
40
|
|
|
58
41
|
/// build from an existing IndexPQ
|
|
59
42
|
explicit IndexPQFastScan(const IndexPQ& orig, int bbs = 32);
|
|
60
43
|
|
|
61
44
|
void train(idx_t n, const float* x) override;
|
|
62
|
-
void add(idx_t n, const float* x) override;
|
|
63
|
-
void reset() override;
|
|
64
|
-
void search(
|
|
65
|
-
idx_t n,
|
|
66
|
-
const float* x,
|
|
67
|
-
idx_t k,
|
|
68
|
-
float* distances,
|
|
69
|
-
idx_t* labels) const override;
|
|
70
45
|
|
|
71
|
-
|
|
72
|
-
void compute_quantized_LUT(
|
|
73
|
-
idx_t n,
|
|
74
|
-
const float* x,
|
|
75
|
-
uint8_t* lut,
|
|
76
|
-
float* normalizers) const;
|
|
46
|
+
void compute_codes(uint8_t* codes, idx_t n, const float* x) const override;
|
|
77
47
|
|
|
78
|
-
|
|
79
|
-
void search_dispatch_implem(
|
|
80
|
-
idx_t n,
|
|
81
|
-
const float* x,
|
|
82
|
-
idx_t k,
|
|
83
|
-
float* distances,
|
|
84
|
-
idx_t* labels) const;
|
|
48
|
+
void compute_float_LUT(float* lut, idx_t n, const float* x) const override;
|
|
85
49
|
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
void search_implem_12(
|
|
96
|
-
idx_t n,
|
|
97
|
-
const float* x,
|
|
98
|
-
idx_t k,
|
|
99
|
-
float* distances,
|
|
100
|
-
idx_t* labels,
|
|
101
|
-
int impl) const;
|
|
102
|
-
|
|
103
|
-
template <class C>
|
|
104
|
-
void search_implem_14(
|
|
105
|
-
idx_t n,
|
|
106
|
-
const float* x,
|
|
107
|
-
idx_t k,
|
|
108
|
-
float* distances,
|
|
109
|
-
idx_t* labels,
|
|
110
|
-
int impl) const;
|
|
50
|
+
/** Decode a set of vectors.
|
|
51
|
+
*
|
|
52
|
+
* NOTE: The codes in the IndexPQFastScan object are non-contiguous.
|
|
53
|
+
* But this method requires a contiguous representation.
|
|
54
|
+
*
|
|
55
|
+
* @param n number of vectors
|
|
56
|
+
* @param bytes input encoded vectors, size n * code_size
|
|
57
|
+
* @param x output vectors, size n * d
|
|
58
|
+
*/
|
|
59
|
+
void sa_decode(idx_t n, const uint8_t* bytes, float* x) const override;
|
|
111
60
|
};
|
|
112
61
|
|
|
113
|
-
struct FastScanStats {
|
|
114
|
-
uint64_t t0, t1, t2, t3;
|
|
115
|
-
FastScanStats() {
|
|
116
|
-
reset();
|
|
117
|
-
}
|
|
118
|
-
void reset() {
|
|
119
|
-
memset(this, 0, sizeof(*this));
|
|
120
|
-
}
|
|
121
|
-
};
|
|
122
|
-
|
|
123
|
-
FAISS_API extern FastScanStats FastScan_stats;
|
|
124
|
-
|
|
125
62
|
} // namespace faiss
|