faiss 0.2.3 → 0.2.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +9 -0
- data/LICENSE.txt +1 -1
- data/README.md +23 -21
- data/ext/faiss/extconf.rb +11 -0
- data/ext/faiss/index.cpp +4 -4
- data/ext/faiss/index_binary.cpp +6 -6
- data/ext/faiss/product_quantizer.cpp +4 -4
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/faiss/AutoTune.cpp +13 -0
- data/vendor/faiss/faiss/Clustering.cpp +32 -0
- data/vendor/faiss/faiss/Clustering.h +14 -0
- data/vendor/faiss/faiss/IVFlib.cpp +101 -2
- data/vendor/faiss/faiss/IVFlib.h +26 -2
- data/vendor/faiss/faiss/Index.cpp +36 -3
- data/vendor/faiss/faiss/Index.h +43 -6
- data/vendor/faiss/faiss/Index2Layer.cpp +24 -93
- data/vendor/faiss/faiss/Index2Layer.h +8 -17
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +610 -0
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +253 -0
- data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +299 -0
- data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +199 -0
- data/vendor/faiss/faiss/IndexBinary.cpp +20 -4
- data/vendor/faiss/faiss/IndexBinary.h +18 -3
- data/vendor/faiss/faiss/IndexBinaryFlat.cpp +9 -2
- data/vendor/faiss/faiss/IndexBinaryFlat.h +4 -2
- data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +4 -1
- data/vendor/faiss/faiss/IndexBinaryFromFloat.h +2 -1
- data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +5 -1
- data/vendor/faiss/faiss/IndexBinaryHNSW.h +2 -1
- data/vendor/faiss/faiss/IndexBinaryHash.cpp +17 -4
- data/vendor/faiss/faiss/IndexBinaryHash.h +8 -4
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +28 -13
- data/vendor/faiss/faiss/IndexBinaryIVF.h +10 -7
- data/vendor/faiss/faiss/IndexFastScan.cpp +626 -0
- data/vendor/faiss/faiss/IndexFastScan.h +145 -0
- data/vendor/faiss/faiss/IndexFlat.cpp +52 -69
- data/vendor/faiss/faiss/IndexFlat.h +16 -19
- data/vendor/faiss/faiss/IndexFlatCodes.cpp +101 -0
- data/vendor/faiss/faiss/IndexFlatCodes.h +59 -0
- data/vendor/faiss/faiss/IndexHNSW.cpp +66 -138
- data/vendor/faiss/faiss/IndexHNSW.h +4 -2
- data/vendor/faiss/faiss/IndexIDMap.cpp +247 -0
- data/vendor/faiss/faiss/IndexIDMap.h +107 -0
- data/vendor/faiss/faiss/IndexIVF.cpp +200 -40
- data/vendor/faiss/faiss/IndexIVF.h +59 -22
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +393 -0
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +183 -0
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +590 -0
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +171 -0
- data/vendor/faiss/faiss/IndexIVFFastScan.cpp +1290 -0
- data/vendor/faiss/faiss/IndexIVFFastScan.h +213 -0
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +43 -26
- data/vendor/faiss/faiss/IndexIVFFlat.h +4 -2
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +238 -53
- data/vendor/faiss/faiss/IndexIVFPQ.h +6 -2
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +23 -852
- data/vendor/faiss/faiss/IndexIVFPQFastScan.h +7 -112
- data/vendor/faiss/faiss/IndexIVFPQR.cpp +3 -3
- data/vendor/faiss/faiss/IndexIVFPQR.h +1 -1
- data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +63 -40
- data/vendor/faiss/faiss/IndexIVFSpectralHash.h +23 -7
- data/vendor/faiss/faiss/IndexLSH.cpp +8 -32
- data/vendor/faiss/faiss/IndexLSH.h +4 -16
- data/vendor/faiss/faiss/IndexLattice.cpp +7 -1
- data/vendor/faiss/faiss/IndexLattice.h +3 -1
- data/vendor/faiss/faiss/IndexNNDescent.cpp +4 -5
- data/vendor/faiss/faiss/IndexNNDescent.h +2 -1
- data/vendor/faiss/faiss/IndexNSG.cpp +37 -5
- data/vendor/faiss/faiss/IndexNSG.h +25 -1
- data/vendor/faiss/faiss/IndexPQ.cpp +108 -120
- data/vendor/faiss/faiss/IndexPQ.h +21 -22
- data/vendor/faiss/faiss/IndexPQFastScan.cpp +15 -450
- data/vendor/faiss/faiss/IndexPQFastScan.h +15 -78
- data/vendor/faiss/faiss/IndexPreTransform.cpp +47 -8
- data/vendor/faiss/faiss/IndexPreTransform.h +15 -3
- data/vendor/faiss/faiss/IndexRefine.cpp +36 -4
- data/vendor/faiss/faiss/IndexRefine.h +14 -2
- data/vendor/faiss/faiss/IndexReplicas.cpp +4 -2
- data/vendor/faiss/faiss/IndexReplicas.h +2 -1
- data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +438 -0
- data/vendor/faiss/faiss/IndexRowwiseMinMax.h +92 -0
- data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +28 -43
- data/vendor/faiss/faiss/IndexScalarQuantizer.h +8 -23
- data/vendor/faiss/faiss/IndexShards.cpp +4 -1
- data/vendor/faiss/faiss/IndexShards.h +2 -1
- data/vendor/faiss/faiss/MetaIndexes.cpp +5 -178
- data/vendor/faiss/faiss/MetaIndexes.h +3 -81
- data/vendor/faiss/faiss/VectorTransform.cpp +45 -1
- data/vendor/faiss/faiss/VectorTransform.h +25 -4
- data/vendor/faiss/faiss/clone_index.cpp +26 -3
- data/vendor/faiss/faiss/clone_index.h +3 -0
- data/vendor/faiss/faiss/cppcontrib/SaDecodeKernels.h +300 -0
- data/vendor/faiss/faiss/cppcontrib/detail/CoarseBitType.h +24 -0
- data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +195 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +2058 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +408 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-neon-inl.h +2147 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMax-inl.h +460 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMaxFP16-inl.h +465 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +1618 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +251 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-neon-inl.h +1452 -0
- data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +1 -0
- data/vendor/faiss/faiss/gpu/GpuCloner.cpp +2 -6
- data/vendor/faiss/faiss/gpu/GpuIcmEncoder.h +60 -0
- data/vendor/faiss/faiss/gpu/GpuIndex.h +28 -4
- data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +2 -1
- data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +10 -8
- data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +75 -14
- data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +19 -32
- data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -31
- data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +22 -28
- data/vendor/faiss/faiss/gpu/GpuResources.cpp +14 -0
- data/vendor/faiss/faiss/gpu/GpuResources.h +16 -3
- data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +3 -3
- data/vendor/faiss/faiss/gpu/impl/IndexUtils.h +32 -0
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +1 -0
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +311 -75
- data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +10 -0
- data/vendor/faiss/faiss/gpu/test/TestUtils.h +3 -0
- data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +2 -2
- data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +5 -4
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +331 -29
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +110 -19
- data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +0 -54
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +0 -76
- data/vendor/faiss/faiss/impl/DistanceComputer.h +64 -0
- data/vendor/faiss/faiss/impl/HNSW.cpp +133 -32
- data/vendor/faiss/faiss/impl/HNSW.h +19 -16
- data/vendor/faiss/faiss/impl/IDSelector.cpp +125 -0
- data/vendor/faiss/faiss/impl/IDSelector.h +135 -0
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +378 -217
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +106 -29
- data/vendor/faiss/faiss/impl/LookupTableScaler.h +77 -0
- data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -0
- data/vendor/faiss/faiss/impl/NSG.cpp +1 -4
- data/vendor/faiss/faiss/impl/NSG.h +1 -1
- data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +383 -0
- data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +154 -0
- data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +225 -145
- data/vendor/faiss/faiss/impl/ProductQuantizer.h +29 -10
- data/vendor/faiss/faiss/impl/Quantizer.h +43 -0
- data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +521 -55
- data/vendor/faiss/faiss/impl/ResidualQuantizer.h +94 -16
- data/vendor/faiss/faiss/impl/ResultHandler.h +96 -0
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +108 -191
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +18 -18
- data/vendor/faiss/faiss/impl/index_read.cpp +338 -24
- data/vendor/faiss/faiss/impl/index_write.cpp +300 -18
- data/vendor/faiss/faiss/impl/io.cpp +1 -1
- data/vendor/faiss/faiss/impl/io_macros.h +20 -0
- data/vendor/faiss/faiss/impl/kmeans1d.cpp +303 -0
- data/vendor/faiss/faiss/impl/kmeans1d.h +48 -0
- data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +56 -16
- data/vendor/faiss/faiss/impl/pq4_fast_scan.h +25 -8
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +66 -25
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +75 -27
- data/vendor/faiss/faiss/index_factory.cpp +772 -412
- data/vendor/faiss/faiss/index_factory.h +3 -0
- data/vendor/faiss/faiss/index_io.h +5 -0
- data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -0
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +4 -1
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +2 -1
- data/vendor/faiss/faiss/python/python_callbacks.cpp +27 -0
- data/vendor/faiss/faiss/python/python_callbacks.h +15 -0
- data/vendor/faiss/faiss/utils/Heap.h +31 -15
- data/vendor/faiss/faiss/utils/distances.cpp +384 -58
- data/vendor/faiss/faiss/utils/distances.h +149 -18
- data/vendor/faiss/faiss/utils/distances_simd.cpp +776 -6
- data/vendor/faiss/faiss/utils/extra_distances.cpp +12 -7
- data/vendor/faiss/faiss/utils/extra_distances.h +3 -1
- data/vendor/faiss/faiss/utils/fp16-fp16c.h +21 -0
- data/vendor/faiss/faiss/utils/fp16-inl.h +101 -0
- data/vendor/faiss/faiss/utils/fp16.h +11 -0
- data/vendor/faiss/faiss/utils/hamming-inl.h +54 -0
- data/vendor/faiss/faiss/utils/hamming.cpp +0 -48
- data/vendor/faiss/faiss/utils/ordered_key_value.h +10 -0
- data/vendor/faiss/faiss/utils/quantize_lut.cpp +62 -0
- data/vendor/faiss/faiss/utils/quantize_lut.h +20 -0
- data/vendor/faiss/faiss/utils/random.cpp +53 -0
- data/vendor/faiss/faiss/utils/random.h +5 -0
- data/vendor/faiss/faiss/utils/simdlib_avx2.h +4 -0
- data/vendor/faiss/faiss/utils/simdlib_emulated.h +6 -1
- data/vendor/faiss/faiss/utils/simdlib_neon.h +7 -2
- data/vendor/faiss/faiss/utils/utils.h +1 -1
- metadata +46 -5
- data/vendor/faiss/faiss/IndexResidual.cpp +0 -291
- data/vendor/faiss/faiss/IndexResidual.h +0 -152
|
@@ -19,8 +19,11 @@
|
|
|
19
19
|
#include <immintrin.h>
|
|
20
20
|
#endif
|
|
21
21
|
|
|
22
|
+
#include <faiss/IndexIVF.h>
|
|
22
23
|
#include <faiss/impl/AuxIndexStructures.h>
|
|
23
24
|
#include <faiss/impl/FaissAssert.h>
|
|
25
|
+
#include <faiss/impl/IDSelector.h>
|
|
26
|
+
#include <faiss/utils/fp16.h>
|
|
24
27
|
#include <faiss/utils/utils.h>
|
|
25
28
|
|
|
26
29
|
namespace faiss {
|
|
@@ -201,114 +204,6 @@ struct Codec6bit {
|
|
|
201
204
|
#endif
|
|
202
205
|
};
|
|
203
206
|
|
|
204
|
-
#ifdef USE_F16C
|
|
205
|
-
|
|
206
|
-
uint16_t encode_fp16(float x) {
|
|
207
|
-
__m128 xf = _mm_set1_ps(x);
|
|
208
|
-
__m128i xi =
|
|
209
|
-
_mm_cvtps_ph(xf, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
|
|
210
|
-
return _mm_cvtsi128_si32(xi) & 0xffff;
|
|
211
|
-
}
|
|
212
|
-
|
|
213
|
-
float decode_fp16(uint16_t x) {
|
|
214
|
-
__m128i xi = _mm_set1_epi16(x);
|
|
215
|
-
__m128 xf = _mm_cvtph_ps(xi);
|
|
216
|
-
return _mm_cvtss_f32(xf);
|
|
217
|
-
}
|
|
218
|
-
|
|
219
|
-
#else
|
|
220
|
-
|
|
221
|
-
// non-intrinsic FP16 <-> FP32 code adapted from
|
|
222
|
-
// https://github.com/ispc/ispc/blob/master/stdlib.ispc
|
|
223
|
-
|
|
224
|
-
float floatbits(uint32_t x) {
|
|
225
|
-
void* xptr = &x;
|
|
226
|
-
return *(float*)xptr;
|
|
227
|
-
}
|
|
228
|
-
|
|
229
|
-
uint32_t intbits(float f) {
|
|
230
|
-
void* fptr = &f;
|
|
231
|
-
return *(uint32_t*)fptr;
|
|
232
|
-
}
|
|
233
|
-
|
|
234
|
-
uint16_t encode_fp16(float f) {
|
|
235
|
-
// via Fabian "ryg" Giesen.
|
|
236
|
-
// https://gist.github.com/2156668
|
|
237
|
-
uint32_t sign_mask = 0x80000000u;
|
|
238
|
-
int32_t o;
|
|
239
|
-
|
|
240
|
-
uint32_t fint = intbits(f);
|
|
241
|
-
uint32_t sign = fint & sign_mask;
|
|
242
|
-
fint ^= sign;
|
|
243
|
-
|
|
244
|
-
// NOTE all the integer compares in this function can be safely
|
|
245
|
-
// compiled into signed compares since all operands are below
|
|
246
|
-
// 0x80000000. Important if you want fast straight SSE2 code (since
|
|
247
|
-
// there's no unsigned PCMPGTD).
|
|
248
|
-
|
|
249
|
-
// Inf or NaN (all exponent bits set)
|
|
250
|
-
// NaN->qNaN and Inf->Inf
|
|
251
|
-
// unconditional assignment here, will override with right value for
|
|
252
|
-
// the regular case below.
|
|
253
|
-
uint32_t f32infty = 255u << 23;
|
|
254
|
-
o = (fint > f32infty) ? 0x7e00u : 0x7c00u;
|
|
255
|
-
|
|
256
|
-
// (De)normalized number or zero
|
|
257
|
-
// update fint unconditionally to save the blending; we don't need it
|
|
258
|
-
// anymore for the Inf/NaN case anyway.
|
|
259
|
-
|
|
260
|
-
const uint32_t round_mask = ~0xfffu;
|
|
261
|
-
const uint32_t magic = 15u << 23;
|
|
262
|
-
|
|
263
|
-
// Shift exponent down, denormalize if necessary.
|
|
264
|
-
// NOTE This represents half-float denormals using single
|
|
265
|
-
// precision denormals. The main reason to do this is that
|
|
266
|
-
// there's no shift with per-lane variable shifts in SSE*, which
|
|
267
|
-
// we'd otherwise need. It has some funky side effects though:
|
|
268
|
-
// - This conversion will actually respect the FTZ (Flush To Zero)
|
|
269
|
-
// flag in MXCSR - if it's set, no half-float denormals will be
|
|
270
|
-
// generated. I'm honestly not sure whether this is good or
|
|
271
|
-
// bad. It's definitely interesting.
|
|
272
|
-
// - If the underlying HW doesn't support denormals (not an issue
|
|
273
|
-
// with Intel CPUs, but might be a problem on GPUs or PS3 SPUs),
|
|
274
|
-
// you will always get flush-to-zero behavior. This is bad,
|
|
275
|
-
// unless you're on a CPU where you don't care.
|
|
276
|
-
// - Denormals tend to be slow. FP32 denormals are rare in
|
|
277
|
-
// practice outside of things like recursive filters in DSP -
|
|
278
|
-
// not a typical half-float application. Whether FP16 denormals
|
|
279
|
-
// are rare in practice, I don't know. Whatever slow path your
|
|
280
|
-
// HW may or may not have for denormals, this may well hit it.
|
|
281
|
-
float fscale = floatbits(fint & round_mask) * floatbits(magic);
|
|
282
|
-
fscale = std::min(fscale, floatbits((31u << 23) - 0x1000u));
|
|
283
|
-
int32_t fint2 = intbits(fscale) - round_mask;
|
|
284
|
-
|
|
285
|
-
if (fint < f32infty)
|
|
286
|
-
o = fint2 >> 13; // Take the bits!
|
|
287
|
-
|
|
288
|
-
return (o | (sign >> 16));
|
|
289
|
-
}
|
|
290
|
-
|
|
291
|
-
float decode_fp16(uint16_t h) {
|
|
292
|
-
// https://gist.github.com/2144712
|
|
293
|
-
// Fabian "ryg" Giesen.
|
|
294
|
-
|
|
295
|
-
const uint32_t shifted_exp = 0x7c00u << 13; // exponent mask after shift
|
|
296
|
-
|
|
297
|
-
int32_t o = ((int32_t)(h & 0x7fffu)) << 13; // exponent/mantissa bits
|
|
298
|
-
int32_t exp = shifted_exp & o; // just the exponent
|
|
299
|
-
o += (int32_t)(127 - 15) << 23; // exponent adjust
|
|
300
|
-
|
|
301
|
-
int32_t infnan_val = o + ((int32_t)(128 - 16) << 23);
|
|
302
|
-
int32_t zerodenorm_val =
|
|
303
|
-
intbits(floatbits(o + (1u << 23)) - floatbits(113u << 23));
|
|
304
|
-
int32_t reg_val = (exp == 0) ? zerodenorm_val : o;
|
|
305
|
-
|
|
306
|
-
int32_t sign_bit = ((int32_t)(h & 0x8000u)) << 16;
|
|
307
|
-
return floatbits(((exp == shifted_exp) ? infnan_val : reg_val) | sign_bit);
|
|
308
|
-
}
|
|
309
|
-
|
|
310
|
-
#endif
|
|
311
|
-
|
|
312
207
|
/*******************************************************************
|
|
313
208
|
* Quantizer: normalizes scalar vector components, then passes them
|
|
314
209
|
* through a codec
|
|
@@ -318,7 +213,7 @@ template <class Codec, bool uniform, int SIMD>
|
|
|
318
213
|
struct QuantizerTemplate {};
|
|
319
214
|
|
|
320
215
|
template <class Codec>
|
|
321
|
-
struct QuantizerTemplate<Codec, true, 1> : ScalarQuantizer::Quantizer {
|
|
216
|
+
struct QuantizerTemplate<Codec, true, 1> : ScalarQuantizer::SQuantizer {
|
|
322
217
|
const size_t d;
|
|
323
218
|
const float vmin, vdiff;
|
|
324
219
|
|
|
@@ -372,7 +267,7 @@ struct QuantizerTemplate<Codec, true, 8> : QuantizerTemplate<Codec, true, 1> {
|
|
|
372
267
|
#endif
|
|
373
268
|
|
|
374
269
|
template <class Codec>
|
|
375
|
-
struct QuantizerTemplate<Codec, false, 1> : ScalarQuantizer::Quantizer {
|
|
270
|
+
struct QuantizerTemplate<Codec, false, 1> : ScalarQuantizer::SQuantizer {
|
|
376
271
|
const size_t d;
|
|
377
272
|
const float *vmin, *vdiff;
|
|
378
273
|
|
|
@@ -433,7 +328,7 @@ template <int SIMDWIDTH>
|
|
|
433
328
|
struct QuantizerFP16 {};
|
|
434
329
|
|
|
435
330
|
template <>
|
|
436
|
-
struct QuantizerFP16<1> : ScalarQuantizer::Quantizer {
|
|
331
|
+
struct QuantizerFP16<1> : ScalarQuantizer::SQuantizer {
|
|
437
332
|
const size_t d;
|
|
438
333
|
|
|
439
334
|
QuantizerFP16(size_t d, const std::vector<float>& /* unused */) : d(d) {}
|
|
@@ -478,7 +373,7 @@ template <int SIMDWIDTH>
|
|
|
478
373
|
struct Quantizer8bitDirect {};
|
|
479
374
|
|
|
480
375
|
template <>
|
|
481
|
-
struct Quantizer8bitDirect<1> : ScalarQuantizer::Quantizer {
|
|
376
|
+
struct Quantizer8bitDirect<1> : ScalarQuantizer::SQuantizer {
|
|
482
377
|
const size_t d;
|
|
483
378
|
|
|
484
379
|
Quantizer8bitDirect(size_t d, const std::vector<float>& /* unused */)
|
|
@@ -518,7 +413,7 @@ struct Quantizer8bitDirect<8> : Quantizer8bitDirect<1> {
|
|
|
518
413
|
#endif
|
|
519
414
|
|
|
520
415
|
template <int SIMDWIDTH>
|
|
521
|
-
ScalarQuantizer::Quantizer* select_quantizer_1(
|
|
416
|
+
ScalarQuantizer::SQuantizer* select_quantizer_1(
|
|
522
417
|
QuantizerType qtype,
|
|
523
418
|
size_t d,
|
|
524
419
|
const std::vector<float>& trained) {
|
|
@@ -911,11 +806,6 @@ struct DCTemplate<Quantizer, Similarity, 1> : SQDistanceComputer {
|
|
|
911
806
|
q = x;
|
|
912
807
|
}
|
|
913
808
|
|
|
914
|
-
/// compute distance of vector i to current query
|
|
915
|
-
float operator()(idx_t i) final {
|
|
916
|
-
return query_to_code(codes + i * code_size);
|
|
917
|
-
}
|
|
918
|
-
|
|
919
809
|
float symmetric_dis(idx_t i, idx_t j) override {
|
|
920
810
|
return compute_code_distance(
|
|
921
811
|
codes + i * code_size, codes + j * code_size);
|
|
@@ -963,11 +853,6 @@ struct DCTemplate<Quantizer, Similarity, 8> : SQDistanceComputer {
|
|
|
963
853
|
q = x;
|
|
964
854
|
}
|
|
965
855
|
|
|
966
|
-
/// compute distance of vector i to current query
|
|
967
|
-
float operator()(idx_t i) final {
|
|
968
|
-
return query_to_code(codes + i * code_size);
|
|
969
|
-
}
|
|
970
|
-
|
|
971
856
|
float symmetric_dis(idx_t i, idx_t j) override {
|
|
972
857
|
return compute_code_distance(
|
|
973
858
|
codes + i * code_size, codes + j * code_size);
|
|
@@ -1021,11 +906,6 @@ struct DistanceComputerByte<Similarity, 1> : SQDistanceComputer {
|
|
|
1021
906
|
return compute_code_distance(tmp.data(), code);
|
|
1022
907
|
}
|
|
1023
908
|
|
|
1024
|
-
/// compute distance of vector i to current query
|
|
1025
|
-
float operator()(idx_t i) final {
|
|
1026
|
-
return query_to_code(codes + i * code_size);
|
|
1027
|
-
}
|
|
1028
|
-
|
|
1029
909
|
float symmetric_dis(idx_t i, idx_t j) override {
|
|
1030
910
|
return compute_code_distance(
|
|
1031
911
|
codes + i * code_size, codes + j * code_size);
|
|
@@ -1089,11 +969,6 @@ struct DistanceComputerByte<Similarity, 8> : SQDistanceComputer {
|
|
|
1089
969
|
return compute_code_distance(tmp.data(), code);
|
|
1090
970
|
}
|
|
1091
971
|
|
|
1092
|
-
/// compute distance of vector i to current query
|
|
1093
|
-
float operator()(idx_t i) final {
|
|
1094
|
-
return query_to_code(codes + i * code_size);
|
|
1095
|
-
}
|
|
1096
|
-
|
|
1097
972
|
float symmetric_dis(idx_t i, idx_t j) override {
|
|
1098
973
|
return compute_code_distance(
|
|
1099
974
|
codes + i * code_size, codes + j * code_size);
|
|
@@ -1173,17 +1048,12 @@ SQDistanceComputer* select_distance_computer(
|
|
|
1173
1048
|
********************************************************************/
|
|
1174
1049
|
|
|
1175
1050
|
ScalarQuantizer::ScalarQuantizer(size_t d, QuantizerType qtype)
|
|
1176
|
-
: qtype(qtype), rangestat(RS_minmax), rangestat_arg(0), d(d) {
|
|
1051
|
+
: Quantizer(d), qtype(qtype), rangestat(RS_minmax), rangestat_arg(0) {
|
|
1177
1052
|
set_derived_sizes();
|
|
1178
1053
|
}
|
|
1179
1054
|
|
|
1180
1055
|
ScalarQuantizer::ScalarQuantizer()
|
|
1181
|
-
: qtype(QT_8bit),
|
|
1182
|
-
rangestat(RS_minmax),
|
|
1183
|
-
rangestat_arg(0),
|
|
1184
|
-
d(0),
|
|
1185
|
-
bits(0),
|
|
1186
|
-
code_size(0) {}
|
|
1056
|
+
: qtype(QT_8bit), rangestat(RS_minmax), rangestat_arg(0), bits(0) {}
|
|
1187
1057
|
|
|
1188
1058
|
void ScalarQuantizer::set_derived_sizes() {
|
|
1189
1059
|
switch (qtype) {
|
|
@@ -1273,7 +1143,7 @@ void ScalarQuantizer::train_residual(
|
|
|
1273
1143
|
}
|
|
1274
1144
|
}
|
|
1275
1145
|
|
|
1276
|
-
ScalarQuantizer::Quantizer* ScalarQuantizer::select_quantizer() const {
|
|
1146
|
+
ScalarQuantizer::SQuantizer* ScalarQuantizer::select_quantizer() const {
|
|
1277
1147
|
#ifdef USE_F16C
|
|
1278
1148
|
if (d % 8 == 0) {
|
|
1279
1149
|
return select_quantizer_1<8>(qtype, d, trained);
|
|
@@ -1286,7 +1156,7 @@ ScalarQuantizer::Quantizer* ScalarQuantizer::select_quantizer() const {
|
|
|
1286
1156
|
|
|
1287
1157
|
void ScalarQuantizer::compute_codes(const float* x, uint8_t* codes, size_t n)
|
|
1288
1158
|
const {
|
|
1289
|
-
std::unique_ptr<Quantizer> squant(select_quantizer());
|
|
1159
|
+
std::unique_ptr<SQuantizer> squant(select_quantizer());
|
|
1290
1160
|
|
|
1291
1161
|
memset(codes, 0, code_size * n);
|
|
1292
1162
|
#pragma omp parallel for
|
|
@@ -1295,7 +1165,7 @@ void ScalarQuantizer::compute_codes(const float* x, uint8_t* codes, size_t n)
|
|
|
1295
1165
|
}
|
|
1296
1166
|
|
|
1297
1167
|
void ScalarQuantizer::decode(const uint8_t* codes, float* x, size_t n) const {
|
|
1298
|
-
std::unique_ptr<Quantizer> squant(select_quantizer());
|
|
1168
|
+
std::unique_ptr<SQuantizer> squant(select_quantizer());
|
|
1299
1169
|
|
|
1300
1170
|
#pragma omp parallel for
|
|
1301
1171
|
for (int64_t i = 0; i < n; i++)
|
|
@@ -1332,28 +1202,25 @@ SQDistanceComputer* ScalarQuantizer::get_distance_computer(
|
|
|
1332
1202
|
|
|
1333
1203
|
namespace {
|
|
1334
1204
|
|
|
1335
|
-
template <class DCClass>
|
|
1205
|
+
template <class DCClass, int use_sel>
|
|
1336
1206
|
struct IVFSQScannerIP : InvertedListScanner {
|
|
1337
1207
|
DCClass dc;
|
|
1338
|
-
bool
|
|
1339
|
-
|
|
1340
|
-
size_t code_size;
|
|
1208
|
+
bool by_residual;
|
|
1209
|
+
const IDSelector* sel;
|
|
1341
1210
|
|
|
1342
|
-
|
|
1343
|
-
float accu0; /// added to all distances
|
|
1211
|
+
float accu0; /// added to all distances
|
|
1344
1212
|
|
|
1345
1213
|
IVFSQScannerIP(
|
|
1346
1214
|
int d,
|
|
1347
1215
|
const std::vector<float>& trained,
|
|
1348
1216
|
size_t code_size,
|
|
1349
1217
|
bool store_pairs,
|
|
1218
|
+
const IDSelector* sel,
|
|
1350
1219
|
bool by_residual)
|
|
1351
|
-
: dc(d, trained),
|
|
1352
|
-
|
|
1353
|
-
|
|
1354
|
-
|
|
1355
|
-
list_no(0),
|
|
1356
|
-
accu0(0) {}
|
|
1220
|
+
: dc(d, trained), by_residual(by_residual), sel(sel), accu0(0) {
|
|
1221
|
+
this->store_pairs = store_pairs;
|
|
1222
|
+
this->code_size = code_size;
|
|
1223
|
+
}
|
|
1357
1224
|
|
|
1358
1225
|
void set_query(const float* query) override {
|
|
1359
1226
|
dc.set_query(query);
|
|
@@ -1377,7 +1244,11 @@ struct IVFSQScannerIP : InvertedListScanner {
|
|
|
1377
1244
|
size_t k) const override {
|
|
1378
1245
|
size_t nup = 0;
|
|
1379
1246
|
|
|
1380
|
-
for (size_t j = 0; j < list_size; j++) {
|
|
1247
|
+
for (size_t j = 0; j < list_size; j++, codes += code_size) {
|
|
1248
|
+
if (use_sel && !sel->is_member(use_sel == 1 ? ids[j] : j)) {
|
|
1249
|
+
continue;
|
|
1250
|
+
}
|
|
1251
|
+
|
|
1381
1252
|
float accu = accu0 + dc.query_to_code(codes);
|
|
1382
1253
|
|
|
1383
1254
|
if (accu > simi[0]) {
|
|
@@ -1385,7 +1256,6 @@ struct IVFSQScannerIP : InvertedListScanner {
|
|
|
1385
1256
|
minheap_replace_top(k, simi, idxi, accu, id);
|
|
1386
1257
|
nup++;
|
|
1387
1258
|
}
|
|
1388
|
-
codes += code_size;
|
|
1389
1259
|
}
|
|
1390
1260
|
return nup;
|
|
1391
1261
|
}
|
|
@@ -1396,25 +1266,31 @@ struct IVFSQScannerIP : InvertedListScanner {
|
|
|
1396
1266
|
const idx_t* ids,
|
|
1397
1267
|
float radius,
|
|
1398
1268
|
RangeQueryResult& res) const override {
|
|
1399
|
-
for (size_t j = 0; j < list_size; j++) {
|
|
1269
|
+
for (size_t j = 0; j < list_size; j++, codes += code_size) {
|
|
1270
|
+
if (use_sel && !sel->is_member(use_sel == 1 ? ids[j] : j)) {
|
|
1271
|
+
continue;
|
|
1272
|
+
}
|
|
1273
|
+
|
|
1400
1274
|
float accu = accu0 + dc.query_to_code(codes);
|
|
1401
1275
|
if (accu > radius) {
|
|
1402
1276
|
int64_t id = store_pairs ? (list_no << 32 | j) : ids[j];
|
|
1403
1277
|
res.add(accu, id);
|
|
1404
1278
|
}
|
|
1405
|
-
codes += code_size;
|
|
1406
1279
|
}
|
|
1407
1280
|
}
|
|
1408
1281
|
};
|
|
1409
1282
|
|
|
1410
|
-
|
|
1283
|
+
/* use_sel = 0: don't check selector
|
|
1284
|
+
* = 1: check on ids[j]
|
|
1285
|
+
* = 2: check in j directly (normally ids is nullptr and store_pairs)
|
|
1286
|
+
*/
|
|
1287
|
+
template <class DCClass, int use_sel>
|
|
1411
1288
|
struct IVFSQScannerL2 : InvertedListScanner {
|
|
1412
1289
|
DCClass dc;
|
|
1413
1290
|
|
|
1414
|
-
bool
|
|
1415
|
-
size_t code_size;
|
|
1291
|
+
bool by_residual;
|
|
1416
1292
|
const Index* quantizer;
|
|
1417
|
-
|
|
1293
|
+
const IDSelector* sel;
|
|
1418
1294
|
const float* x; /// current query
|
|
1419
1295
|
|
|
1420
1296
|
std::vector<float> tmp;
|
|
@@ -1425,15 +1301,17 @@ struct IVFSQScannerL2 : InvertedListScanner {
|
|
|
1425
1301
|
size_t code_size,
|
|
1426
1302
|
const Index* quantizer,
|
|
1427
1303
|
bool store_pairs,
|
|
1304
|
+
const IDSelector* sel,
|
|
1428
1305
|
bool by_residual)
|
|
1429
1306
|
: dc(d, trained),
|
|
1430
|
-
store_pairs(store_pairs),
|
|
1431
1307
|
by_residual(by_residual),
|
|
1432
|
-
code_size(code_size),
|
|
1433
1308
|
quantizer(quantizer),
|
|
1434
|
-
|
|
1309
|
+
sel(sel),
|
|
1435
1310
|
x(nullptr),
|
|
1436
|
-
tmp(d) {}
|
|
1311
|
+
tmp(d) {
|
|
1312
|
+
this->store_pairs = store_pairs;
|
|
1313
|
+
this->code_size = code_size;
|
|
1314
|
+
}
|
|
1437
1315
|
|
|
1438
1316
|
void set_query(const float* query) override {
|
|
1439
1317
|
x = query;
|
|
@@ -1443,8 +1321,8 @@ struct IVFSQScannerL2 : InvertedListScanner {
|
|
|
1443
1321
|
}
|
|
1444
1322
|
|
|
1445
1323
|
void set_list(idx_t list_no, float /*coarse_dis*/) override {
|
|
1324
|
+
this->list_no = list_no;
|
|
1446
1325
|
if (by_residual) {
|
|
1447
|
-
this->list_no = list_no;
|
|
1448
1326
|
// shift of x_in wrt centroid
|
|
1449
1327
|
quantizer->compute_residual(x, tmp.data(), list_no);
|
|
1450
1328
|
dc.set_query(tmp.data());
|
|
@@ -1465,7 +1343,11 @@ struct IVFSQScannerL2 : InvertedListScanner {
|
|
|
1465
1343
|
idx_t* idxi,
|
|
1466
1344
|
size_t k) const override {
|
|
1467
1345
|
size_t nup = 0;
|
|
1468
|
-
for (size_t j = 0; j < list_size; j++) {
|
|
1346
|
+
for (size_t j = 0; j < list_size; j++, codes += code_size) {
|
|
1347
|
+
if (use_sel && !sel->is_member(use_sel == 1 ? ids[j] : j)) {
|
|
1348
|
+
continue;
|
|
1349
|
+
}
|
|
1350
|
+
|
|
1469
1351
|
float dis = dc.query_to_code(codes);
|
|
1470
1352
|
|
|
1471
1353
|
if (dis < simi[0]) {
|
|
@@ -1473,7 +1355,6 @@ struct IVFSQScannerL2 : InvertedListScanner {
|
|
|
1473
1355
|
maxheap_replace_top(k, simi, idxi, dis, id);
|
|
1474
1356
|
nup++;
|
|
1475
1357
|
}
|
|
1476
|
-
codes += code_size;
|
|
1477
1358
|
}
|
|
1478
1359
|
return nup;
|
|
1479
1360
|
}
|
|
@@ -1484,44 +1365,77 @@ struct IVFSQScannerL2 : InvertedListScanner {
|
|
|
1484
1365
|
const idx_t* ids,
|
|
1485
1366
|
float radius,
|
|
1486
1367
|
RangeQueryResult& res) const override {
|
|
1487
|
-
for (size_t j = 0; j < list_size; j++) {
|
|
1368
|
+
for (size_t j = 0; j < list_size; j++, codes += code_size) {
|
|
1369
|
+
if (use_sel && !sel->is_member(use_sel == 1 ? ids[j] : j)) {
|
|
1370
|
+
continue;
|
|
1371
|
+
}
|
|
1372
|
+
|
|
1488
1373
|
float dis = dc.query_to_code(codes);
|
|
1489
1374
|
if (dis < radius) {
|
|
1490
1375
|
int64_t id = store_pairs ? (list_no << 32 | j) : ids[j];
|
|
1491
1376
|
res.add(dis, id);
|
|
1492
1377
|
}
|
|
1493
|
-
codes += code_size;
|
|
1494
1378
|
}
|
|
1495
1379
|
}
|
|
1496
1380
|
};
|
|
1497
1381
|
|
|
1498
|
-
template <class DCClass>
|
|
1499
|
-
InvertedListScanner* sel2_InvertedListScanner(
|
|
1382
|
+
template <class DCClass, int use_sel>
|
|
1383
|
+
InvertedListScanner* sel3_InvertedListScanner(
|
|
1500
1384
|
const ScalarQuantizer* sq,
|
|
1501
1385
|
const Index* quantizer,
|
|
1502
1386
|
bool store_pairs,
|
|
1387
|
+
const IDSelector* sel,
|
|
1503
1388
|
bool r) {
|
|
1504
1389
|
if (DCClass::Sim::metric_type == METRIC_L2) {
|
|
1505
|
-
return new IVFSQScannerL2<DCClass>(
|
|
1506
|
-
sq->d, sq->trained, sq->code_size, quantizer, store_pairs, r);
|
|
1390
|
+
return new IVFSQScannerL2<DCClass, use_sel>(
|
|
1391
|
+
sq->d,
|
|
1392
|
+
sq->trained,
|
|
1393
|
+
sq->code_size,
|
|
1394
|
+
quantizer,
|
|
1395
|
+
store_pairs,
|
|
1396
|
+
sel,
|
|
1397
|
+
r);
|
|
1507
1398
|
} else if (DCClass::Sim::metric_type == METRIC_INNER_PRODUCT) {
|
|
1508
|
-
return new IVFSQScannerIP<DCClass>(
|
|
1509
|
-
sq->d, sq->trained, sq->code_size, store_pairs, r);
|
|
1399
|
+
return new IVFSQScannerIP<DCClass, use_sel>(
|
|
1400
|
+
sq->d, sq->trained, sq->code_size, store_pairs, sel, r);
|
|
1510
1401
|
} else {
|
|
1511
1402
|
FAISS_THROW_MSG("unsupported metric type");
|
|
1512
1403
|
}
|
|
1513
1404
|
}
|
|
1514
1405
|
|
|
1406
|
+
template <class DCClass>
|
|
1407
|
+
InvertedListScanner* sel2_InvertedListScanner(
|
|
1408
|
+
const ScalarQuantizer* sq,
|
|
1409
|
+
const Index* quantizer,
|
|
1410
|
+
bool store_pairs,
|
|
1411
|
+
const IDSelector* sel,
|
|
1412
|
+
bool r) {
|
|
1413
|
+
if (sel) {
|
|
1414
|
+
if (store_pairs) {
|
|
1415
|
+
return sel3_InvertedListScanner<DCClass, 2>(
|
|
1416
|
+
sq, quantizer, store_pairs, sel, r);
|
|
1417
|
+
} else {
|
|
1418
|
+
return sel3_InvertedListScanner<DCClass, 1>(
|
|
1419
|
+
sq, quantizer, store_pairs, sel, r);
|
|
1420
|
+
}
|
|
1421
|
+
} else {
|
|
1422
|
+
return sel3_InvertedListScanner<DCClass, 0>(
|
|
1423
|
+
sq, quantizer, store_pairs, sel, r);
|
|
1424
|
+
}
|
|
1425
|
+
}
|
|
1426
|
+
|
|
1515
1427
|
template <class Similarity, class Codec, bool uniform>
|
|
1516
1428
|
InvertedListScanner* sel12_InvertedListScanner(
|
|
1517
1429
|
const ScalarQuantizer* sq,
|
|
1518
1430
|
const Index* quantizer,
|
|
1519
1431
|
bool store_pairs,
|
|
1432
|
+
const IDSelector* sel,
|
|
1520
1433
|
bool r) {
|
|
1521
1434
|
constexpr int SIMDWIDTH = Similarity::simdwidth;
|
|
1522
1435
|
using QuantizerClass = QuantizerTemplate<Codec, uniform, SIMDWIDTH>;
|
|
1523
1436
|
using DCClass = DCTemplate<QuantizerClass, Similarity, SIMDWIDTH>;
|
|
1524
|
-
return sel2_InvertedListScanner<DCClass>(sq, quantizer, store_pairs, r);
|
|
1437
|
+
return sel2_InvertedListScanner<DCClass>(
|
|
1438
|
+
sq, quantizer, store_pairs, sel, r);
|
|
1525
1439
|
}
|
|
1526
1440
|
|
|
1527
1441
|
template <class Similarity>
|
|
@@ -1529,39 +1443,40 @@ InvertedListScanner* sel1_InvertedListScanner(
|
|
|
1529
1443
|
const ScalarQuantizer* sq,
|
|
1530
1444
|
const Index* quantizer,
|
|
1531
1445
|
bool store_pairs,
|
|
1446
|
+
const IDSelector* sel,
|
|
1532
1447
|
bool r) {
|
|
1533
1448
|
constexpr int SIMDWIDTH = Similarity::simdwidth;
|
|
1534
1449
|
switch (sq->qtype) {
|
|
1535
1450
|
case ScalarQuantizer::QT_8bit_uniform:
|
|
1536
1451
|
return sel12_InvertedListScanner<Similarity, Codec8bit, true>(
|
|
1537
|
-
sq, quantizer, store_pairs, r);
|
|
1452
|
+
sq, quantizer, store_pairs, sel, r);
|
|
1538
1453
|
case ScalarQuantizer::QT_4bit_uniform:
|
|
1539
1454
|
return sel12_InvertedListScanner<Similarity, Codec4bit, true>(
|
|
1540
|
-
sq, quantizer, store_pairs, r);
|
|
1455
|
+
sq, quantizer, store_pairs, sel, r);
|
|
1541
1456
|
case ScalarQuantizer::QT_8bit:
|
|
1542
1457
|
return sel12_InvertedListScanner<Similarity, Codec8bit, false>(
|
|
1543
|
-
sq, quantizer, store_pairs, r);
|
|
1458
|
+
sq, quantizer, store_pairs, sel, r);
|
|
1544
1459
|
case ScalarQuantizer::QT_4bit:
|
|
1545
1460
|
return sel12_InvertedListScanner<Similarity, Codec4bit, false>(
|
|
1546
|
-
sq, quantizer, store_pairs, r);
|
|
1461
|
+
sq, quantizer, store_pairs, sel, r);
|
|
1547
1462
|
case ScalarQuantizer::QT_6bit:
|
|
1548
1463
|
return sel12_InvertedListScanner<Similarity, Codec6bit, false>(
|
|
1549
|
-
sq, quantizer, store_pairs, r);
|
|
1464
|
+
sq, quantizer, store_pairs, sel, r);
|
|
1550
1465
|
case ScalarQuantizer::QT_fp16:
|
|
1551
1466
|
return sel2_InvertedListScanner<DCTemplate<
|
|
1552
1467
|
QuantizerFP16<SIMDWIDTH>,
|
|
1553
1468
|
Similarity,
|
|
1554
|
-
SIMDWIDTH>>(sq, quantizer, store_pairs, r);
|
|
1469
|
+
SIMDWIDTH>>(sq, quantizer, store_pairs, sel, r);
|
|
1555
1470
|
case ScalarQuantizer::QT_8bit_direct:
|
|
1556
1471
|
if (sq->d % 16 == 0) {
|
|
1557
1472
|
return sel2_InvertedListScanner<
|
|
1558
1473
|
DistanceComputerByte<Similarity, SIMDWIDTH>>(
|
|
1559
|
-
sq, quantizer, store_pairs, r);
|
|
1474
|
+
sq, quantizer, store_pairs, sel, r);
|
|
1560
1475
|
} else {
|
|
1561
1476
|
return sel2_InvertedListScanner<DCTemplate<
|
|
1562
1477
|
Quantizer8bitDirect<SIMDWIDTH>,
|
|
1563
1478
|
Similarity,
|
|
1564
|
-
SIMDWIDTH>>(sq, quantizer, store_pairs, r);
|
|
1479
|
+
SIMDWIDTH>>(sq, quantizer, store_pairs, sel, r);
|
|
1565
1480
|
}
|
|
1566
1481
|
}
|
|
1567
1482
|
|
|
@@ -1575,13 +1490,14 @@ InvertedListScanner* sel0_InvertedListScanner(
|
|
|
1575
1490
|
const ScalarQuantizer* sq,
|
|
1576
1491
|
const Index* quantizer,
|
|
1577
1492
|
bool store_pairs,
|
|
1493
|
+
const IDSelector* sel,
|
|
1578
1494
|
bool by_residual) {
|
|
1579
1495
|
if (mt == METRIC_L2) {
|
|
1580
1496
|
return sel1_InvertedListScanner<SimilarityL2<SIMDWIDTH>>(
|
|
1581
|
-
sq, quantizer, store_pairs, by_residual);
|
|
1497
|
+
sq, quantizer, store_pairs, sel, by_residual);
|
|
1582
1498
|
} else if (mt == METRIC_INNER_PRODUCT) {
|
|
1583
1499
|
return sel1_InvertedListScanner<SimilarityIP<SIMDWIDTH>>(
|
|
1584
|
-
sq, quantizer, store_pairs, by_residual);
|
|
1500
|
+
sq, quantizer, store_pairs, sel, by_residual);
|
|
1585
1501
|
} else {
|
|
1586
1502
|
FAISS_THROW_MSG("unsupported metric type");
|
|
1587
1503
|
}
|
|
@@ -1593,16 +1509,17 @@ InvertedListScanner* ScalarQuantizer::select_InvertedListScanner(
|
|
|
1593
1509
|
MetricType mt,
|
|
1594
1510
|
const Index* quantizer,
|
|
1595
1511
|
bool store_pairs,
|
|
1512
|
+
const IDSelector* sel,
|
|
1596
1513
|
bool by_residual) const {
|
|
1597
1514
|
#ifdef USE_F16C
|
|
1598
1515
|
if (d % 8 == 0) {
|
|
1599
1516
|
return sel0_InvertedListScanner<8>(
|
|
1600
|
-
mt, this, quantizer, store_pairs, by_residual);
|
|
1517
|
+
mt, this, quantizer, store_pairs, sel, by_residual);
|
|
1601
1518
|
} else
|
|
1602
1519
|
#endif
|
|
1603
1520
|
{
|
|
1604
1521
|
return sel0_InvertedListScanner<1>(
|
|
1605
|
-
mt, this, quantizer, store_pairs, by_residual);
|
|
1522
|
+
mt, this, quantizer, store_pairs, sel, by_residual);
|
|
1606
1523
|
}
|
|
1607
1524
|
}
|
|
1608
1525
|
|
|
@@ -9,18 +9,21 @@
|
|
|
9
9
|
|
|
10
10
|
#pragma once
|
|
11
11
|
|
|
12
|
-
#include <faiss/IndexIVF.h>
|
|
13
12
|
#include <faiss/impl/AuxIndexStructures.h>
|
|
13
|
+
#include <faiss/impl/DistanceComputer.h>
|
|
14
|
+
#include <faiss/impl/Quantizer.h>
|
|
14
15
|
|
|
15
16
|
namespace faiss {
|
|
16
17
|
|
|
18
|
+
struct InvertedListScanner;
|
|
19
|
+
|
|
17
20
|
/**
|
|
18
21
|
* The uniform quantizer has a range [vmin, vmax]. The range can be
|
|
19
22
|
* the same for all dimensions (uniform) or specific per dimension
|
|
20
23
|
* (default).
|
|
21
24
|
*/
|
|
22
25
|
|
|
23
|
-
struct ScalarQuantizer {
|
|
26
|
+
struct ScalarQuantizer : Quantizer {
|
|
24
27
|
enum QuantizerType {
|
|
25
28
|
QT_8bit, ///< 8 bits per component
|
|
26
29
|
QT_4bit, ///< 4 bits per component
|
|
@@ -48,15 +51,9 @@ struct ScalarQuantizer {
|
|
|
48
51
|
RangeStat rangestat;
|
|
49
52
|
float rangestat_arg;
|
|
50
53
|
|
|
51
|
-
/// dimension of input vectors
|
|
52
|
-
size_t d;
|
|
53
|
-
|
|
54
54
|
/// bits per scalar code
|
|
55
55
|
size_t bits;
|
|
56
56
|
|
|
57
|
-
/// bytes per vector
|
|
58
|
-
size_t code_size;
|
|
59
|
-
|
|
60
57
|
/// trained values (including the range)
|
|
61
58
|
std::vector<float> trained;
|
|
62
59
|
|
|
@@ -66,7 +63,7 @@ struct ScalarQuantizer {
|
|
|
66
63
|
/// updates internal values based on qtype and d
|
|
67
64
|
void set_derived_sizes();
|
|
68
65
|
|
|
69
|
-
void train(size_t n, const float* x);
|
|
66
|
+
void train(size_t n, const float* x) override;
|
|
70
67
|
|
|
71
68
|
/// Used by an IVF index to train based on the residuals
|
|
72
69
|
void train_residual(
|
|
@@ -81,38 +78,40 @@ struct ScalarQuantizer {
|
|
|
81
78
|
* @param x vectors to encode, size n * d
|
|
82
79
|
* @param codes output codes, size n * code_size
|
|
83
80
|
*/
|
|
84
|
-
void compute_codes(const float* x, uint8_t* codes, size_t n) const;
|
|
81
|
+
void compute_codes(const float* x, uint8_t* codes, size_t n) const override;
|
|
85
82
|
|
|
86
83
|
/** Decode a set of vectors
|
|
87
84
|
*
|
|
88
85
|
* @param codes codes to decode, size n * code_size
|
|
89
86
|
* @param x output vectors, size n * d
|
|
90
87
|
*/
|
|
91
|
-
void decode(const uint8_t* code, float* x, size_t n) const;
|
|
88
|
+
void decode(const uint8_t* code, float* x, size_t n) const override;
|
|
92
89
|
|
|
93
90
|
/*****************************************************
|
|
94
91
|
* Objects that provide methods for encoding/decoding, distance
|
|
95
92
|
* computation and inverted list scanning
|
|
96
93
|
*****************************************************/
|
|
97
94
|
|
|
98
|
-
struct
|
|
95
|
+
struct SQuantizer {
|
|
99
96
|
// encodes one vector. Assumes code is filled with 0s on input!
|
|
100
97
|
virtual void encode_vector(const float* x, uint8_t* code) const = 0;
|
|
101
98
|
virtual void decode_vector(const uint8_t* code, float* x) const = 0;
|
|
102
99
|
|
|
103
|
-
virtual ~
|
|
100
|
+
virtual ~SQuantizer() {}
|
|
104
101
|
};
|
|
105
102
|
|
|
106
|
-
|
|
103
|
+
SQuantizer* select_quantizer() const;
|
|
107
104
|
|
|
108
|
-
struct SQDistanceComputer :
|
|
105
|
+
struct SQDistanceComputer : FlatCodesDistanceComputer {
|
|
109
106
|
const float* q;
|
|
110
|
-
const uint8_t* codes;
|
|
111
|
-
size_t code_size;
|
|
112
107
|
|
|
113
|
-
SQDistanceComputer() : q(nullptr)
|
|
108
|
+
SQDistanceComputer() : q(nullptr) {}
|
|
114
109
|
|
|
115
110
|
virtual float query_to_code(const uint8_t* code) const = 0;
|
|
111
|
+
|
|
112
|
+
float distance_to_code(const uint8_t* code) final {
|
|
113
|
+
return query_to_code(code);
|
|
114
|
+
}
|
|
116
115
|
};
|
|
117
116
|
|
|
118
117
|
SQDistanceComputer* get_distance_computer(
|
|
@@ -122,6 +121,7 @@ struct ScalarQuantizer {
|
|
|
122
121
|
MetricType mt,
|
|
123
122
|
const Index* quantizer,
|
|
124
123
|
bool store_pairs,
|
|
124
|
+
const IDSelector* sel,
|
|
125
125
|
bool by_residual = false) const;
|
|
126
126
|
};
|
|
127
127
|
|