faiss 0.2.4 → 0.2.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +5 -0
- data/README.md +23 -21
- data/ext/faiss/extconf.rb +11 -0
- data/ext/faiss/index.cpp +4 -4
- data/ext/faiss/index_binary.cpp +6 -6
- data/ext/faiss/product_quantizer.cpp +4 -4
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/faiss/AutoTune.cpp +13 -0
- data/vendor/faiss/faiss/IVFlib.cpp +101 -2
- data/vendor/faiss/faiss/IVFlib.h +26 -2
- data/vendor/faiss/faiss/Index.cpp +36 -3
- data/vendor/faiss/faiss/Index.h +43 -6
- data/vendor/faiss/faiss/Index2Layer.cpp +6 -2
- data/vendor/faiss/faiss/Index2Layer.h +6 -1
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +219 -16
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +63 -5
- data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +299 -0
- data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +199 -0
- data/vendor/faiss/faiss/IndexBinary.cpp +20 -4
- data/vendor/faiss/faiss/IndexBinary.h +18 -3
- data/vendor/faiss/faiss/IndexBinaryFlat.cpp +9 -2
- data/vendor/faiss/faiss/IndexBinaryFlat.h +4 -2
- data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +4 -1
- data/vendor/faiss/faiss/IndexBinaryFromFloat.h +2 -1
- data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +5 -1
- data/vendor/faiss/faiss/IndexBinaryHNSW.h +2 -1
- data/vendor/faiss/faiss/IndexBinaryHash.cpp +17 -4
- data/vendor/faiss/faiss/IndexBinaryHash.h +8 -4
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +28 -13
- data/vendor/faiss/faiss/IndexBinaryIVF.h +10 -7
- data/vendor/faiss/faiss/IndexFastScan.cpp +626 -0
- data/vendor/faiss/faiss/IndexFastScan.h +145 -0
- data/vendor/faiss/faiss/IndexFlat.cpp +34 -21
- data/vendor/faiss/faiss/IndexFlat.h +7 -4
- data/vendor/faiss/faiss/IndexFlatCodes.cpp +35 -1
- data/vendor/faiss/faiss/IndexFlatCodes.h +12 -0
- data/vendor/faiss/faiss/IndexHNSW.cpp +66 -138
- data/vendor/faiss/faiss/IndexHNSW.h +4 -2
- data/vendor/faiss/faiss/IndexIDMap.cpp +247 -0
- data/vendor/faiss/faiss/IndexIDMap.h +107 -0
- data/vendor/faiss/faiss/IndexIVF.cpp +121 -33
- data/vendor/faiss/faiss/IndexIVF.h +35 -16
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +84 -7
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +63 -1
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +590 -0
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +171 -0
- data/vendor/faiss/faiss/IndexIVFFastScan.cpp +1290 -0
- data/vendor/faiss/faiss/IndexIVFFastScan.h +213 -0
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +37 -17
- data/vendor/faiss/faiss/IndexIVFFlat.h +4 -2
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +234 -50
- data/vendor/faiss/faiss/IndexIVFPQ.h +5 -1
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +23 -852
- data/vendor/faiss/faiss/IndexIVFPQFastScan.h +7 -112
- data/vendor/faiss/faiss/IndexIVFPQR.cpp +3 -3
- data/vendor/faiss/faiss/IndexIVFPQR.h +1 -1
- data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +3 -1
- data/vendor/faiss/faiss/IndexIVFSpectralHash.h +2 -1
- data/vendor/faiss/faiss/IndexLSH.cpp +4 -2
- data/vendor/faiss/faiss/IndexLSH.h +2 -1
- data/vendor/faiss/faiss/IndexLattice.cpp +7 -1
- data/vendor/faiss/faiss/IndexLattice.h +3 -1
- data/vendor/faiss/faiss/IndexNNDescent.cpp +4 -3
- data/vendor/faiss/faiss/IndexNNDescent.h +2 -1
- data/vendor/faiss/faiss/IndexNSG.cpp +37 -3
- data/vendor/faiss/faiss/IndexNSG.h +25 -1
- data/vendor/faiss/faiss/IndexPQ.cpp +106 -69
- data/vendor/faiss/faiss/IndexPQ.h +19 -5
- data/vendor/faiss/faiss/IndexPQFastScan.cpp +15 -450
- data/vendor/faiss/faiss/IndexPQFastScan.h +15 -78
- data/vendor/faiss/faiss/IndexPreTransform.cpp +47 -8
- data/vendor/faiss/faiss/IndexPreTransform.h +15 -3
- data/vendor/faiss/faiss/IndexRefine.cpp +8 -4
- data/vendor/faiss/faiss/IndexRefine.h +4 -2
- data/vendor/faiss/faiss/IndexReplicas.cpp +4 -2
- data/vendor/faiss/faiss/IndexReplicas.h +2 -1
- data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +438 -0
- data/vendor/faiss/faiss/IndexRowwiseMinMax.h +92 -0
- data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +26 -15
- data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -7
- data/vendor/faiss/faiss/IndexShards.cpp +4 -1
- data/vendor/faiss/faiss/IndexShards.h +2 -1
- data/vendor/faiss/faiss/MetaIndexes.cpp +5 -178
- data/vendor/faiss/faiss/MetaIndexes.h +3 -81
- data/vendor/faiss/faiss/VectorTransform.cpp +43 -0
- data/vendor/faiss/faiss/VectorTransform.h +22 -4
- data/vendor/faiss/faiss/clone_index.cpp +23 -1
- data/vendor/faiss/faiss/clone_index.h +3 -0
- data/vendor/faiss/faiss/cppcontrib/SaDecodeKernels.h +300 -0
- data/vendor/faiss/faiss/cppcontrib/detail/CoarseBitType.h +24 -0
- data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +195 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +2058 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +408 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-neon-inl.h +2147 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMax-inl.h +460 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMaxFP16-inl.h +465 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +1618 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +251 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-neon-inl.h +1452 -0
- data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +1 -0
- data/vendor/faiss/faiss/gpu/GpuCloner.cpp +0 -4
- data/vendor/faiss/faiss/gpu/GpuIndex.h +28 -4
- data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +2 -1
- data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +10 -8
- data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +75 -14
- data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +19 -32
- data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -31
- data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +22 -28
- data/vendor/faiss/faiss/gpu/GpuResources.cpp +14 -0
- data/vendor/faiss/faiss/gpu/GpuResources.h +16 -3
- data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +3 -3
- data/vendor/faiss/faiss/gpu/impl/IndexUtils.h +32 -0
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +1 -0
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +311 -75
- data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +10 -0
- data/vendor/faiss/faiss/gpu/test/TestUtils.h +3 -0
- data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +2 -2
- data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +5 -4
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +116 -47
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +44 -13
- data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +0 -54
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +0 -76
- data/vendor/faiss/faiss/impl/DistanceComputer.h +64 -0
- data/vendor/faiss/faiss/impl/HNSW.cpp +123 -27
- data/vendor/faiss/faiss/impl/HNSW.h +19 -16
- data/vendor/faiss/faiss/impl/IDSelector.cpp +125 -0
- data/vendor/faiss/faiss/impl/IDSelector.h +135 -0
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +6 -28
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +6 -1
- data/vendor/faiss/faiss/impl/LookupTableScaler.h +77 -0
- data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -0
- data/vendor/faiss/faiss/impl/NSG.cpp +1 -1
- data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +383 -0
- data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +154 -0
- data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +225 -145
- data/vendor/faiss/faiss/impl/ProductQuantizer.h +29 -10
- data/vendor/faiss/faiss/impl/Quantizer.h +43 -0
- data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +192 -36
- data/vendor/faiss/faiss/impl/ResidualQuantizer.h +40 -20
- data/vendor/faiss/faiss/impl/ResultHandler.h +96 -0
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +97 -173
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +18 -18
- data/vendor/faiss/faiss/impl/index_read.cpp +240 -9
- data/vendor/faiss/faiss/impl/index_write.cpp +237 -5
- data/vendor/faiss/faiss/impl/kmeans1d.cpp +6 -4
- data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +56 -16
- data/vendor/faiss/faiss/impl/pq4_fast_scan.h +25 -8
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +66 -25
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +75 -27
- data/vendor/faiss/faiss/index_factory.cpp +196 -7
- data/vendor/faiss/faiss/index_io.h +5 -0
- data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -0
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +4 -1
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +2 -1
- data/vendor/faiss/faiss/python/python_callbacks.cpp +27 -0
- data/vendor/faiss/faiss/python/python_callbacks.h +15 -0
- data/vendor/faiss/faiss/utils/Heap.h +31 -15
- data/vendor/faiss/faiss/utils/distances.cpp +380 -56
- data/vendor/faiss/faiss/utils/distances.h +113 -15
- data/vendor/faiss/faiss/utils/distances_simd.cpp +726 -6
- data/vendor/faiss/faiss/utils/extra_distances.cpp +12 -7
- data/vendor/faiss/faiss/utils/extra_distances.h +3 -1
- data/vendor/faiss/faiss/utils/fp16-fp16c.h +21 -0
- data/vendor/faiss/faiss/utils/fp16-inl.h +101 -0
- data/vendor/faiss/faiss/utils/fp16.h +11 -0
- data/vendor/faiss/faiss/utils/hamming-inl.h +54 -0
- data/vendor/faiss/faiss/utils/hamming.cpp +0 -48
- data/vendor/faiss/faiss/utils/ordered_key_value.h +10 -0
- data/vendor/faiss/faiss/utils/quantize_lut.cpp +62 -0
- data/vendor/faiss/faiss/utils/quantize_lut.h +20 -0
- data/vendor/faiss/faiss/utils/random.cpp +53 -0
- data/vendor/faiss/faiss/utils/random.h +5 -0
- data/vendor/faiss/faiss/utils/simdlib_avx2.h +4 -0
- data/vendor/faiss/faiss/utils/simdlib_emulated.h +6 -1
- data/vendor/faiss/faiss/utils/simdlib_neon.h +7 -2
- metadata +37 -3
|
@@ -17,7 +17,10 @@
|
|
|
17
17
|
|
|
18
18
|
#include <algorithm>
|
|
19
19
|
|
|
20
|
+
#include <faiss/Clustering.h>
|
|
20
21
|
#include <faiss/impl/FaissAssert.h>
|
|
22
|
+
#include <faiss/impl/LocalSearchQuantizer.h>
|
|
23
|
+
#include <faiss/impl/ResidualQuantizer.h>
|
|
21
24
|
#include <faiss/utils/Heap.h>
|
|
22
25
|
#include <faiss/utils/distances.h>
|
|
23
26
|
#include <faiss/utils/hamming.h>
|
|
@@ -48,14 +51,14 @@ AdditiveQuantizer::AdditiveQuantizer(
|
|
|
48
51
|
size_t d,
|
|
49
52
|
const std::vector<size_t>& nbits,
|
|
50
53
|
Search_type_t search_type)
|
|
51
|
-
:
|
|
54
|
+
: Quantizer(d),
|
|
52
55
|
M(nbits.size()),
|
|
53
56
|
nbits(nbits),
|
|
54
57
|
verbose(false),
|
|
55
58
|
is_trained(false),
|
|
59
|
+
max_mem_distances(5 * (size_t(1) << 30)), // 5 GiB
|
|
56
60
|
search_type(search_type) {
|
|
57
61
|
norm_max = norm_min = NAN;
|
|
58
|
-
code_size = 0;
|
|
59
62
|
tot_bits = 0;
|
|
60
63
|
total_codebook_size = 0;
|
|
61
64
|
only_8bit = false;
|
|
@@ -80,27 +83,82 @@ void AdditiveQuantizer::set_derived_values() {
|
|
|
80
83
|
}
|
|
81
84
|
total_codebook_size = codebook_offsets[M];
|
|
82
85
|
switch (search_type) {
|
|
83
|
-
case ST_decompress:
|
|
84
|
-
case ST_LUT_nonorm:
|
|
85
|
-
case ST_norm_from_LUT:
|
|
86
|
-
break; // nothing to add
|
|
87
86
|
case ST_norm_float:
|
|
88
|
-
|
|
87
|
+
norm_bits = 32;
|
|
89
88
|
break;
|
|
90
89
|
case ST_norm_qint8:
|
|
91
90
|
case ST_norm_cqint8:
|
|
92
|
-
|
|
91
|
+
case ST_norm_lsq2x4:
|
|
92
|
+
case ST_norm_rq2x4:
|
|
93
|
+
norm_bits = 8;
|
|
93
94
|
break;
|
|
94
95
|
case ST_norm_qint4:
|
|
95
96
|
case ST_norm_cqint4:
|
|
96
|
-
|
|
97
|
+
norm_bits = 4;
|
|
98
|
+
break;
|
|
99
|
+
case ST_decompress:
|
|
100
|
+
case ST_LUT_nonorm:
|
|
101
|
+
case ST_norm_from_LUT:
|
|
102
|
+
default:
|
|
103
|
+
norm_bits = 0;
|
|
97
104
|
break;
|
|
98
105
|
}
|
|
106
|
+
tot_bits += norm_bits;
|
|
99
107
|
|
|
100
108
|
// convert bits to bytes
|
|
101
109
|
code_size = (tot_bits + 7) / 8;
|
|
102
110
|
}
|
|
103
111
|
|
|
112
|
+
void AdditiveQuantizer::train_norm(size_t n, const float* norms) {
|
|
113
|
+
norm_min = HUGE_VALF;
|
|
114
|
+
norm_max = -HUGE_VALF;
|
|
115
|
+
for (idx_t i = 0; i < n; i++) {
|
|
116
|
+
if (norms[i] < norm_min) {
|
|
117
|
+
norm_min = norms[i];
|
|
118
|
+
}
|
|
119
|
+
if (norms[i] > norm_max) {
|
|
120
|
+
norm_max = norms[i];
|
|
121
|
+
}
|
|
122
|
+
}
|
|
123
|
+
|
|
124
|
+
if (search_type == ST_norm_cqint8 || search_type == ST_norm_cqint4) {
|
|
125
|
+
size_t k = (1 << 8);
|
|
126
|
+
if (search_type == ST_norm_cqint4) {
|
|
127
|
+
k = (1 << 4);
|
|
128
|
+
}
|
|
129
|
+
Clustering1D clus(k);
|
|
130
|
+
clus.train_exact(n, norms);
|
|
131
|
+
qnorm.add(clus.k, clus.centroids.data());
|
|
132
|
+
} else if (search_type == ST_norm_lsq2x4 || search_type == ST_norm_rq2x4) {
|
|
133
|
+
std::unique_ptr<AdditiveQuantizer> aq;
|
|
134
|
+
if (search_type == ST_norm_lsq2x4) {
|
|
135
|
+
aq.reset(new LocalSearchQuantizer(1, 2, 4));
|
|
136
|
+
} else {
|
|
137
|
+
aq.reset(new ResidualQuantizer(1, 2, 4));
|
|
138
|
+
}
|
|
139
|
+
|
|
140
|
+
aq->train(n, norms);
|
|
141
|
+
// flatten aq codebooks
|
|
142
|
+
std::vector<float> flat_codebooks(1 << 8);
|
|
143
|
+
FAISS_THROW_IF_NOT(aq->codebooks.size() == 32);
|
|
144
|
+
|
|
145
|
+
// save norm tables for 4-bit fastscan search
|
|
146
|
+
norm_tabs = aq->codebooks;
|
|
147
|
+
|
|
148
|
+
// assume big endian
|
|
149
|
+
const float* c = norm_tabs.data();
|
|
150
|
+
for (size_t i = 0; i < 16; i++) {
|
|
151
|
+
for (size_t j = 0; j < 16; j++) {
|
|
152
|
+
flat_codebooks[i * 16 + j] = c[j] + c[16 + i];
|
|
153
|
+
}
|
|
154
|
+
}
|
|
155
|
+
|
|
156
|
+
qnorm.reset();
|
|
157
|
+
qnorm.add(1 << 8, flat_codebooks.data());
|
|
158
|
+
FAISS_THROW_IF_NOT(qnorm.ntotal == (1 << 8));
|
|
159
|
+
}
|
|
160
|
+
}
|
|
161
|
+
|
|
104
162
|
namespace {
|
|
105
163
|
|
|
106
164
|
// TODO
|
|
@@ -132,7 +190,7 @@ float decode_qint4(uint8_t i, float amin, float amax) {
|
|
|
132
190
|
|
|
133
191
|
uint32_t AdditiveQuantizer::encode_qcint(float x) const {
|
|
134
192
|
idx_t id;
|
|
135
|
-
qnorm.assign(
|
|
193
|
+
qnorm.assign(1, &x, &id, 1);
|
|
136
194
|
return uint32_t(id);
|
|
137
195
|
}
|
|
138
196
|
|
|
@@ -140,23 +198,54 @@ float AdditiveQuantizer::decode_qcint(uint32_t c) const {
|
|
|
140
198
|
return qnorm.get_xb()[c];
|
|
141
199
|
}
|
|
142
200
|
|
|
201
|
+
uint64_t AdditiveQuantizer::encode_norm(float norm) const {
|
|
202
|
+
switch (search_type) {
|
|
203
|
+
case ST_norm_float:
|
|
204
|
+
uint32_t inorm;
|
|
205
|
+
memcpy(&inorm, &norm, 4);
|
|
206
|
+
return inorm;
|
|
207
|
+
case ST_norm_qint8:
|
|
208
|
+
return encode_qint8(norm, norm_min, norm_max);
|
|
209
|
+
case ST_norm_qint4:
|
|
210
|
+
return encode_qint4(norm, norm_min, norm_max);
|
|
211
|
+
case ST_norm_lsq2x4:
|
|
212
|
+
case ST_norm_rq2x4:
|
|
213
|
+
case ST_norm_cqint8:
|
|
214
|
+
return encode_qcint(norm);
|
|
215
|
+
case ST_norm_cqint4:
|
|
216
|
+
return encode_qcint(norm);
|
|
217
|
+
case ST_decompress:
|
|
218
|
+
case ST_LUT_nonorm:
|
|
219
|
+
case ST_norm_from_LUT:
|
|
220
|
+
default:
|
|
221
|
+
return 0;
|
|
222
|
+
}
|
|
223
|
+
}
|
|
224
|
+
|
|
143
225
|
void AdditiveQuantizer::pack_codes(
|
|
144
226
|
size_t n,
|
|
145
227
|
const int32_t* codes,
|
|
146
228
|
uint8_t* packed_codes,
|
|
147
229
|
int64_t ld_codes,
|
|
148
|
-
const float* norms
|
|
230
|
+
const float* norms,
|
|
231
|
+
const float* centroids) const {
|
|
149
232
|
if (ld_codes == -1) {
|
|
150
233
|
ld_codes = M;
|
|
151
234
|
}
|
|
152
235
|
std::vector<float> norm_buf;
|
|
153
236
|
if (search_type == ST_norm_float || search_type == ST_norm_qint4 ||
|
|
154
237
|
search_type == ST_norm_qint8 || search_type == ST_norm_cqint8 ||
|
|
155
|
-
search_type == ST_norm_cqint4
|
|
156
|
-
|
|
238
|
+
search_type == ST_norm_cqint4 || search_type == ST_norm_lsq2x4 ||
|
|
239
|
+
search_type == ST_norm_rq2x4) {
|
|
240
|
+
if (centroids != nullptr || !norms) {
|
|
157
241
|
norm_buf.resize(n);
|
|
158
242
|
std::vector<float> x_recons(n * d);
|
|
159
243
|
decode_unpacked(codes, x_recons.data(), n, ld_codes);
|
|
244
|
+
|
|
245
|
+
if (centroids != nullptr) {
|
|
246
|
+
// x = x + c
|
|
247
|
+
fvec_add(n * d, x_recons.data(), centroids, x_recons.data());
|
|
248
|
+
}
|
|
160
249
|
fvec_norms_L2sqr(norm_buf.data(), x_recons.data(), d, n);
|
|
161
250
|
norms = norm_buf.data();
|
|
162
251
|
}
|
|
@@ -168,34 +257,8 @@ void AdditiveQuantizer::pack_codes(
|
|
|
168
257
|
for (int m = 0; m < M; m++) {
|
|
169
258
|
bsw.write(codes1[m], nbits[m]);
|
|
170
259
|
}
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
case ST_LUT_nonorm:
|
|
174
|
-
case ST_norm_from_LUT:
|
|
175
|
-
break;
|
|
176
|
-
case ST_norm_float:
|
|
177
|
-
bsw.write(*(uint32_t*)&norms[i], 32);
|
|
178
|
-
break;
|
|
179
|
-
case ST_norm_qint8: {
|
|
180
|
-
uint8_t b = encode_qint8(norms[i], norm_min, norm_max);
|
|
181
|
-
bsw.write(b, 8);
|
|
182
|
-
break;
|
|
183
|
-
}
|
|
184
|
-
case ST_norm_qint4: {
|
|
185
|
-
uint8_t b = encode_qint4(norms[i], norm_min, norm_max);
|
|
186
|
-
bsw.write(b, 4);
|
|
187
|
-
break;
|
|
188
|
-
}
|
|
189
|
-
case ST_norm_cqint8: {
|
|
190
|
-
uint32_t b = encode_qcint(norms[i]);
|
|
191
|
-
bsw.write(b, 8);
|
|
192
|
-
break;
|
|
193
|
-
}
|
|
194
|
-
case ST_norm_cqint4: {
|
|
195
|
-
uint32_t b = encode_qcint(norms[i]);
|
|
196
|
-
bsw.write(b, 4);
|
|
197
|
-
break;
|
|
198
|
-
}
|
|
260
|
+
if (norm_bits != 0) {
|
|
261
|
+
bsw.write(encode_norm(norms[i]), norm_bits);
|
|
199
262
|
}
|
|
200
263
|
}
|
|
201
264
|
}
|
|
@@ -283,28 +346,33 @@ void AdditiveQuantizer::decode_64bit(idx_t bits, float* xi) const {
|
|
|
283
346
|
}
|
|
284
347
|
}
|
|
285
348
|
|
|
286
|
-
void AdditiveQuantizer::compute_LUT(
|
|
287
|
-
|
|
349
|
+
void AdditiveQuantizer::compute_LUT(
|
|
350
|
+
size_t n,
|
|
351
|
+
const float* xq,
|
|
352
|
+
float* LUT,
|
|
353
|
+
float alpha,
|
|
354
|
+
long ld_lut) const {
|
|
288
355
|
// in all cases, it is large matrix multiplication
|
|
289
356
|
|
|
290
357
|
FINTEGER ncenti = total_codebook_size;
|
|
291
358
|
FINTEGER di = d;
|
|
292
359
|
FINTEGER nqi = n;
|
|
293
|
-
|
|
360
|
+
FINTEGER ldc = ld_lut > 0 ? ld_lut : ncenti;
|
|
361
|
+
float zero = 0;
|
|
294
362
|
|
|
295
363
|
sgemm_("Transposed",
|
|
296
364
|
"Not transposed",
|
|
297
365
|
&ncenti,
|
|
298
366
|
&nqi,
|
|
299
367
|
&di,
|
|
300
|
-
&
|
|
368
|
+
&alpha,
|
|
301
369
|
codebooks.data(),
|
|
302
370
|
&di,
|
|
303
371
|
xq,
|
|
304
372
|
&di,
|
|
305
373
|
&zero,
|
|
306
374
|
LUT,
|
|
307
|
-
&
|
|
375
|
+
&ldc);
|
|
308
376
|
}
|
|
309
377
|
|
|
310
378
|
namespace {
|
|
@@ -448,7 +516,8 @@ float AdditiveQuantizer::
|
|
|
448
516
|
BitstringReader bs(codes, code_size);
|
|
449
517
|
float accu = accumulate_IPs(*this, bs, codes, LUT);
|
|
450
518
|
uint32_t norm_i = bs.read(32);
|
|
451
|
-
float norm2
|
|
519
|
+
float norm2;
|
|
520
|
+
memcpy(&norm2, &norm_i, 4);
|
|
452
521
|
return norm2 - 2 * accu;
|
|
453
522
|
}
|
|
454
523
|
|
|
@@ -12,6 +12,7 @@
|
|
|
12
12
|
|
|
13
13
|
#include <faiss/Index.h>
|
|
14
14
|
#include <faiss/IndexFlat.h>
|
|
15
|
+
#include <faiss/impl/Quantizer.h>
|
|
15
16
|
|
|
16
17
|
namespace faiss {
|
|
17
18
|
|
|
@@ -21,23 +22,31 @@ namespace faiss {
|
|
|
21
22
|
* concatenation of M sub-vectors, additive quantizers sum M sub-vectors
|
|
22
23
|
* to get the decoded vector.
|
|
23
24
|
*/
|
|
24
|
-
struct AdditiveQuantizer {
|
|
25
|
-
size_t d; ///< size of the input vectors
|
|
25
|
+
struct AdditiveQuantizer : Quantizer {
|
|
26
26
|
size_t M; ///< number of codebooks
|
|
27
27
|
std::vector<size_t> nbits; ///< bits for each step
|
|
28
28
|
std::vector<float> codebooks; ///< codebooks
|
|
29
29
|
|
|
30
30
|
// derived values
|
|
31
31
|
std::vector<uint64_t> codebook_offsets;
|
|
32
|
-
size_t
|
|
33
|
-
size_t
|
|
32
|
+
size_t tot_bits; ///< total number of bits (indexes + norms)
|
|
33
|
+
size_t norm_bits; ///< bits allocated for the norms
|
|
34
34
|
size_t total_codebook_size; ///< size of the codebook in vectors
|
|
35
35
|
bool only_8bit; ///< are all nbits = 8 (use faster decoder)
|
|
36
36
|
|
|
37
37
|
bool verbose; ///< verbose during training?
|
|
38
38
|
bool is_trained; ///< is trained or not
|
|
39
39
|
|
|
40
|
-
IndexFlat1D qnorm;
|
|
40
|
+
IndexFlat1D qnorm; ///< store and search norms
|
|
41
|
+
std::vector<float> norm_tabs; ///< store norms of codebook entries for 4-bit
|
|
42
|
+
///< fastscan search
|
|
43
|
+
|
|
44
|
+
/// norms and distance matrixes with beam search can get large, so use this
|
|
45
|
+
/// to control for the amount of memory that can be allocated
|
|
46
|
+
size_t max_mem_distances;
|
|
47
|
+
|
|
48
|
+
/// encode a norm into norm_bits bits
|
|
49
|
+
uint64_t encode_norm(float norm) const;
|
|
41
50
|
|
|
42
51
|
uint32_t encode_qcint(
|
|
43
52
|
float x) const; ///< encode norm by non-uniform scalar quantization
|
|
@@ -57,6 +66,10 @@ struct AdditiveQuantizer {
|
|
|
57
66
|
ST_norm_qint4,
|
|
58
67
|
ST_norm_cqint8, ///< use a LUT, and store non-uniform quantized norm
|
|
59
68
|
ST_norm_cqint4,
|
|
69
|
+
|
|
70
|
+
ST_norm_lsq2x4, ///< use a 2x4 bits lsq as norm quantizer (for fast
|
|
71
|
+
///< scan)
|
|
72
|
+
ST_norm_rq2x4, ///< use a 2x4 bits rq as norm quantizer (for fast scan)
|
|
60
73
|
};
|
|
61
74
|
|
|
62
75
|
AdditiveQuantizer(
|
|
@@ -69,16 +82,25 @@ struct AdditiveQuantizer {
|
|
|
69
82
|
///< compute derived values when d, M and nbits have been set
|
|
70
83
|
void set_derived_values();
|
|
71
84
|
|
|
72
|
-
///< Train the
|
|
73
|
-
|
|
85
|
+
///< Train the norm quantizer
|
|
86
|
+
void train_norm(size_t n, const float* norms);
|
|
87
|
+
|
|
88
|
+
void compute_codes(const float* x, uint8_t* codes, size_t n)
|
|
89
|
+
const override {
|
|
90
|
+
compute_codes_add_centroids(x, codes, n);
|
|
91
|
+
}
|
|
74
92
|
|
|
75
93
|
/** Encode a set of vectors
|
|
76
94
|
*
|
|
77
95
|
* @param x vectors to encode, size n * d
|
|
78
96
|
* @param codes output codes, size n * code_size
|
|
97
|
+
* @param centroids centroids to be added to x, size n * d
|
|
79
98
|
*/
|
|
80
|
-
virtual void
|
|
81
|
-
const
|
|
99
|
+
virtual void compute_codes_add_centroids(
|
|
100
|
+
const float* x,
|
|
101
|
+
uint8_t* codes,
|
|
102
|
+
size_t n,
|
|
103
|
+
const float* centroids = nullptr) const = 0;
|
|
82
104
|
|
|
83
105
|
/** pack a series of code to bit-compact format
|
|
84
106
|
*
|
|
@@ -87,27 +109,29 @@ struct AdditiveQuantizer {
|
|
|
87
109
|
* @param ld_codes leading dimension of codes
|
|
88
110
|
* @param norms norms of the vectors (size n). Will be computed if
|
|
89
111
|
* needed but not provided
|
|
112
|
+
* @param centroids centroids to be added to x, size n * d
|
|
90
113
|
*/
|
|
91
114
|
void pack_codes(
|
|
92
115
|
size_t n,
|
|
93
116
|
const int32_t* codes,
|
|
94
117
|
uint8_t* packed_codes,
|
|
95
118
|
int64_t ld_codes = -1,
|
|
96
|
-
const float* norms = nullptr
|
|
119
|
+
const float* norms = nullptr,
|
|
120
|
+
const float* centroids = nullptr) const;
|
|
97
121
|
|
|
98
122
|
/** Decode a set of vectors
|
|
99
123
|
*
|
|
100
124
|
* @param codes codes to decode, size n * code_size
|
|
101
125
|
* @param x output vectors, size n * d
|
|
102
126
|
*/
|
|
103
|
-
void decode(const uint8_t* codes, float* x, size_t n) const;
|
|
127
|
+
void decode(const uint8_t* codes, float* x, size_t n) const override;
|
|
104
128
|
|
|
105
129
|
/** Decode a set of vectors in non-packed format
|
|
106
130
|
*
|
|
107
131
|
* @param codes codes to decode, size n * ld_codes
|
|
108
132
|
* @param x output vectors, size n * d
|
|
109
133
|
*/
|
|
110
|
-
void decode_unpacked(
|
|
134
|
+
virtual void decode_unpacked(
|
|
111
135
|
const int32_t* codes,
|
|
112
136
|
float* x,
|
|
113
137
|
size_t n,
|
|
@@ -143,8 +167,15 @@ struct AdditiveQuantizer {
|
|
|
143
167
|
*
|
|
144
168
|
* @param xq query vector, size (n, d)
|
|
145
169
|
* @param LUT look-up table, size (n, total_codebook_size)
|
|
170
|
+
* @param alpha compute alpha * inner-product
|
|
171
|
+
* @param ld_lut leading dimension of LUT
|
|
146
172
|
*/
|
|
147
|
-
void compute_LUT(
|
|
173
|
+
virtual void compute_LUT(
|
|
174
|
+
size_t n,
|
|
175
|
+
const float* xq,
|
|
176
|
+
float* LUT,
|
|
177
|
+
float alpha = 1.0f,
|
|
178
|
+
long ld_lut = -1) const;
|
|
148
179
|
|
|
149
180
|
/// exact IP search
|
|
150
181
|
void knn_centroids_inner_product(
|
|
@@ -199,60 +199,6 @@ void RangeSearchPartialResult::merge(
|
|
|
199
199
|
result->lims[0] = 0;
|
|
200
200
|
}
|
|
201
201
|
|
|
202
|
-
/***********************************************************************
|
|
203
|
-
* IDSelectorRange
|
|
204
|
-
***********************************************************************/
|
|
205
|
-
|
|
206
|
-
IDSelectorRange::IDSelectorRange(idx_t imin, idx_t imax)
|
|
207
|
-
: imin(imin), imax(imax) {}
|
|
208
|
-
|
|
209
|
-
bool IDSelectorRange::is_member(idx_t id) const {
|
|
210
|
-
return id >= imin && id < imax;
|
|
211
|
-
}
|
|
212
|
-
|
|
213
|
-
/***********************************************************************
|
|
214
|
-
* IDSelectorArray
|
|
215
|
-
***********************************************************************/
|
|
216
|
-
|
|
217
|
-
IDSelectorArray::IDSelectorArray(size_t n, const idx_t* ids) : n(n), ids(ids) {}
|
|
218
|
-
|
|
219
|
-
bool IDSelectorArray::is_member(idx_t id) const {
|
|
220
|
-
for (idx_t i = 0; i < n; i++) {
|
|
221
|
-
if (ids[i] == id)
|
|
222
|
-
return true;
|
|
223
|
-
}
|
|
224
|
-
return false;
|
|
225
|
-
}
|
|
226
|
-
|
|
227
|
-
/***********************************************************************
|
|
228
|
-
* IDSelectorBatch
|
|
229
|
-
***********************************************************************/
|
|
230
|
-
|
|
231
|
-
IDSelectorBatch::IDSelectorBatch(size_t n, const idx_t* indices) {
|
|
232
|
-
nbits = 0;
|
|
233
|
-
while (n > (1L << nbits))
|
|
234
|
-
nbits++;
|
|
235
|
-
nbits += 5;
|
|
236
|
-
// for n = 1M, nbits = 25 is optimal, see P56659518
|
|
237
|
-
|
|
238
|
-
mask = (1L << nbits) - 1;
|
|
239
|
-
bloom.resize(1UL << (nbits - 3), 0);
|
|
240
|
-
for (long i = 0; i < n; i++) {
|
|
241
|
-
Index::idx_t id = indices[i];
|
|
242
|
-
set.insert(id);
|
|
243
|
-
id &= mask;
|
|
244
|
-
bloom[id >> 3] |= 1 << (id & 7);
|
|
245
|
-
}
|
|
246
|
-
}
|
|
247
|
-
|
|
248
|
-
bool IDSelectorBatch::is_member(idx_t i) const {
|
|
249
|
-
long im = i & mask;
|
|
250
|
-
if (!(bloom[im >> 3] & (1 << (im & 7)))) {
|
|
251
|
-
return 0;
|
|
252
|
-
}
|
|
253
|
-
return set.count(i);
|
|
254
|
-
}
|
|
255
|
-
|
|
256
202
|
/***********************************************************
|
|
257
203
|
* Interrupt callback
|
|
258
204
|
***********************************************************/
|
|
@@ -5,8 +5,6 @@
|
|
|
5
5
|
* LICENSE file in the root directory of this source tree.
|
|
6
6
|
*/
|
|
7
7
|
|
|
8
|
-
// -*- c++ -*-
|
|
9
|
-
|
|
10
8
|
// Auxiliary index structures, that are used in indexes but that can
|
|
11
9
|
// be forward-declared
|
|
12
10
|
|
|
@@ -18,7 +16,6 @@
|
|
|
18
16
|
#include <cstring>
|
|
19
17
|
#include <memory>
|
|
20
18
|
#include <mutex>
|
|
21
|
-
#include <unordered_set>
|
|
22
19
|
#include <vector>
|
|
23
20
|
|
|
24
21
|
#include <faiss/Index.h>
|
|
@@ -52,55 +49,6 @@ struct RangeSearchResult {
|
|
|
52
49
|
virtual ~RangeSearchResult();
|
|
53
50
|
};
|
|
54
51
|
|
|
55
|
-
/** Encapsulates a set of ids to remove. */
|
|
56
|
-
struct IDSelector {
|
|
57
|
-
typedef Index::idx_t idx_t;
|
|
58
|
-
virtual bool is_member(idx_t id) const = 0;
|
|
59
|
-
virtual ~IDSelector() {}
|
|
60
|
-
};
|
|
61
|
-
|
|
62
|
-
/** remove ids between [imni, imax) */
|
|
63
|
-
struct IDSelectorRange : IDSelector {
|
|
64
|
-
idx_t imin, imax;
|
|
65
|
-
|
|
66
|
-
IDSelectorRange(idx_t imin, idx_t imax);
|
|
67
|
-
bool is_member(idx_t id) const override;
|
|
68
|
-
~IDSelectorRange() override {}
|
|
69
|
-
};
|
|
70
|
-
|
|
71
|
-
/** simple list of elements to remove
|
|
72
|
-
*
|
|
73
|
-
* this is inefficient in most cases, except for IndexIVF with
|
|
74
|
-
* maintain_direct_map
|
|
75
|
-
*/
|
|
76
|
-
struct IDSelectorArray : IDSelector {
|
|
77
|
-
size_t n;
|
|
78
|
-
const idx_t* ids;
|
|
79
|
-
|
|
80
|
-
IDSelectorArray(size_t n, const idx_t* ids);
|
|
81
|
-
bool is_member(idx_t id) const override;
|
|
82
|
-
~IDSelectorArray() override {}
|
|
83
|
-
};
|
|
84
|
-
|
|
85
|
-
/** Remove ids from a set. Repetitions of ids in the indices set
|
|
86
|
-
* passed to the constructor does not hurt performance. The hash
|
|
87
|
-
* function used for the bloom filter and GCC's implementation of
|
|
88
|
-
* unordered_set are just the least significant bits of the id. This
|
|
89
|
-
* works fine for random ids or ids in sequences but will produce many
|
|
90
|
-
* hash collisions if lsb's are always the same */
|
|
91
|
-
struct IDSelectorBatch : IDSelector {
|
|
92
|
-
std::unordered_set<idx_t> set;
|
|
93
|
-
|
|
94
|
-
typedef unsigned char uint8_t;
|
|
95
|
-
std::vector<uint8_t> bloom; // assumes low bits of id are a good hash value
|
|
96
|
-
int nbits;
|
|
97
|
-
idx_t mask;
|
|
98
|
-
|
|
99
|
-
IDSelectorBatch(size_t n, const idx_t* indices);
|
|
100
|
-
bool is_member(idx_t id) const override;
|
|
101
|
-
~IDSelectorBatch() override {}
|
|
102
|
-
};
|
|
103
|
-
|
|
104
52
|
/****************************************************************
|
|
105
53
|
* Result structures for range search.
|
|
106
54
|
*
|
|
@@ -186,30 +134,6 @@ struct RangeSearchPartialResult : BufferList {
|
|
|
186
134
|
bool do_delete = true);
|
|
187
135
|
};
|
|
188
136
|
|
|
189
|
-
/***********************************************************
|
|
190
|
-
* The distance computer maintains a current query and computes
|
|
191
|
-
* distances to elements in an index that supports random access.
|
|
192
|
-
*
|
|
193
|
-
* The DistanceComputer is not intended to be thread-safe (eg. because
|
|
194
|
-
* it maintains counters) so the distance functions are not const,
|
|
195
|
-
* instantiate one from each thread if needed.
|
|
196
|
-
***********************************************************/
|
|
197
|
-
struct DistanceComputer {
|
|
198
|
-
using idx_t = Index::idx_t;
|
|
199
|
-
|
|
200
|
-
/// called before computing distances. Pointer x should remain valid
|
|
201
|
-
/// while operator () is called
|
|
202
|
-
virtual void set_query(const float* x) = 0;
|
|
203
|
-
|
|
204
|
-
/// compute distance of vector i to current query
|
|
205
|
-
virtual float operator()(idx_t i) = 0;
|
|
206
|
-
|
|
207
|
-
/// compute distance between two stored vectors
|
|
208
|
-
virtual float symmetric_dis(idx_t i, idx_t j) = 0;
|
|
209
|
-
|
|
210
|
-
virtual ~DistanceComputer() {}
|
|
211
|
-
};
|
|
212
|
-
|
|
213
137
|
/***********************************************************
|
|
214
138
|
* Interrupt callback
|
|
215
139
|
***********************************************************/
|
|
@@ -0,0 +1,64 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Copyright (c) Facebook, Inc. and its affiliates.
|
|
3
|
+
*
|
|
4
|
+
* This source code is licensed under the MIT license found in the
|
|
5
|
+
* LICENSE file in the root directory of this source tree.
|
|
6
|
+
*/
|
|
7
|
+
|
|
8
|
+
#pragma once
|
|
9
|
+
|
|
10
|
+
#include <faiss/Index.h>
|
|
11
|
+
|
|
12
|
+
namespace faiss {
|
|
13
|
+
|
|
14
|
+
/***********************************************************
|
|
15
|
+
* The distance computer maintains a current query and computes
|
|
16
|
+
* distances to elements in an index that supports random access.
|
|
17
|
+
*
|
|
18
|
+
* The DistanceComputer is not intended to be thread-safe (eg. because
|
|
19
|
+
* it maintains counters) so the distance functions are not const,
|
|
20
|
+
* instantiate one from each thread if needed.
|
|
21
|
+
*
|
|
22
|
+
* Note that the equivalent for IVF indexes is the InvertedListScanner,
|
|
23
|
+
* that has additional methods to handle the inverted list context.
|
|
24
|
+
***********************************************************/
|
|
25
|
+
struct DistanceComputer {
|
|
26
|
+
using idx_t = Index::idx_t;
|
|
27
|
+
|
|
28
|
+
/// called before computing distances. Pointer x should remain valid
|
|
29
|
+
/// while operator () is called
|
|
30
|
+
virtual void set_query(const float* x) = 0;
|
|
31
|
+
|
|
32
|
+
/// compute distance of vector i to current query
|
|
33
|
+
virtual float operator()(idx_t i) = 0;
|
|
34
|
+
|
|
35
|
+
/// compute distance between two stored vectors
|
|
36
|
+
virtual float symmetric_dis(idx_t i, idx_t j) = 0;
|
|
37
|
+
|
|
38
|
+
virtual ~DistanceComputer() {}
|
|
39
|
+
};
|
|
40
|
+
|
|
41
|
+
/*************************************************************
|
|
42
|
+
* Specialized version of the DistanceComputer when we know that codes are
|
|
43
|
+
* laid out in a flat index.
|
|
44
|
+
*/
|
|
45
|
+
struct FlatCodesDistanceComputer : DistanceComputer {
|
|
46
|
+
const uint8_t* codes;
|
|
47
|
+
size_t code_size;
|
|
48
|
+
|
|
49
|
+
FlatCodesDistanceComputer(const uint8_t* codes, size_t code_size)
|
|
50
|
+
: codes(codes), code_size(code_size) {}
|
|
51
|
+
|
|
52
|
+
FlatCodesDistanceComputer() : codes(nullptr), code_size(0) {}
|
|
53
|
+
|
|
54
|
+
float operator()(idx_t i) final {
|
|
55
|
+
return distance_to_code(codes + i * code_size);
|
|
56
|
+
}
|
|
57
|
+
|
|
58
|
+
/// compute distance of current query to an encoded vector
|
|
59
|
+
virtual float distance_to_code(const uint8_t* code) = 0;
|
|
60
|
+
|
|
61
|
+
virtual ~FlatCodesDistanceComputer() {}
|
|
62
|
+
};
|
|
63
|
+
|
|
64
|
+
} // namespace faiss
|