faiss 0.2.4 → 0.2.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +5 -0
- data/README.md +23 -21
- data/ext/faiss/extconf.rb +11 -0
- data/ext/faiss/index.cpp +4 -4
- data/ext/faiss/index_binary.cpp +6 -6
- data/ext/faiss/product_quantizer.cpp +4 -4
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/faiss/AutoTune.cpp +13 -0
- data/vendor/faiss/faiss/IVFlib.cpp +101 -2
- data/vendor/faiss/faiss/IVFlib.h +26 -2
- data/vendor/faiss/faiss/Index.cpp +36 -3
- data/vendor/faiss/faiss/Index.h +43 -6
- data/vendor/faiss/faiss/Index2Layer.cpp +6 -2
- data/vendor/faiss/faiss/Index2Layer.h +6 -1
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +219 -16
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +63 -5
- data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +299 -0
- data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +199 -0
- data/vendor/faiss/faiss/IndexBinary.cpp +20 -4
- data/vendor/faiss/faiss/IndexBinary.h +18 -3
- data/vendor/faiss/faiss/IndexBinaryFlat.cpp +9 -2
- data/vendor/faiss/faiss/IndexBinaryFlat.h +4 -2
- data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +4 -1
- data/vendor/faiss/faiss/IndexBinaryFromFloat.h +2 -1
- data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +5 -1
- data/vendor/faiss/faiss/IndexBinaryHNSW.h +2 -1
- data/vendor/faiss/faiss/IndexBinaryHash.cpp +17 -4
- data/vendor/faiss/faiss/IndexBinaryHash.h +8 -4
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +28 -13
- data/vendor/faiss/faiss/IndexBinaryIVF.h +10 -7
- data/vendor/faiss/faiss/IndexFastScan.cpp +626 -0
- data/vendor/faiss/faiss/IndexFastScan.h +145 -0
- data/vendor/faiss/faiss/IndexFlat.cpp +34 -21
- data/vendor/faiss/faiss/IndexFlat.h +7 -4
- data/vendor/faiss/faiss/IndexFlatCodes.cpp +35 -1
- data/vendor/faiss/faiss/IndexFlatCodes.h +12 -0
- data/vendor/faiss/faiss/IndexHNSW.cpp +66 -138
- data/vendor/faiss/faiss/IndexHNSW.h +4 -2
- data/vendor/faiss/faiss/IndexIDMap.cpp +247 -0
- data/vendor/faiss/faiss/IndexIDMap.h +107 -0
- data/vendor/faiss/faiss/IndexIVF.cpp +121 -33
- data/vendor/faiss/faiss/IndexIVF.h +35 -16
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +84 -7
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +63 -1
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +590 -0
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +171 -0
- data/vendor/faiss/faiss/IndexIVFFastScan.cpp +1290 -0
- data/vendor/faiss/faiss/IndexIVFFastScan.h +213 -0
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +37 -17
- data/vendor/faiss/faiss/IndexIVFFlat.h +4 -2
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +234 -50
- data/vendor/faiss/faiss/IndexIVFPQ.h +5 -1
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +23 -852
- data/vendor/faiss/faiss/IndexIVFPQFastScan.h +7 -112
- data/vendor/faiss/faiss/IndexIVFPQR.cpp +3 -3
- data/vendor/faiss/faiss/IndexIVFPQR.h +1 -1
- data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +3 -1
- data/vendor/faiss/faiss/IndexIVFSpectralHash.h +2 -1
- data/vendor/faiss/faiss/IndexLSH.cpp +4 -2
- data/vendor/faiss/faiss/IndexLSH.h +2 -1
- data/vendor/faiss/faiss/IndexLattice.cpp +7 -1
- data/vendor/faiss/faiss/IndexLattice.h +3 -1
- data/vendor/faiss/faiss/IndexNNDescent.cpp +4 -3
- data/vendor/faiss/faiss/IndexNNDescent.h +2 -1
- data/vendor/faiss/faiss/IndexNSG.cpp +37 -3
- data/vendor/faiss/faiss/IndexNSG.h +25 -1
- data/vendor/faiss/faiss/IndexPQ.cpp +106 -69
- data/vendor/faiss/faiss/IndexPQ.h +19 -5
- data/vendor/faiss/faiss/IndexPQFastScan.cpp +15 -450
- data/vendor/faiss/faiss/IndexPQFastScan.h +15 -78
- data/vendor/faiss/faiss/IndexPreTransform.cpp +47 -8
- data/vendor/faiss/faiss/IndexPreTransform.h +15 -3
- data/vendor/faiss/faiss/IndexRefine.cpp +8 -4
- data/vendor/faiss/faiss/IndexRefine.h +4 -2
- data/vendor/faiss/faiss/IndexReplicas.cpp +4 -2
- data/vendor/faiss/faiss/IndexReplicas.h +2 -1
- data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +438 -0
- data/vendor/faiss/faiss/IndexRowwiseMinMax.h +92 -0
- data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +26 -15
- data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -7
- data/vendor/faiss/faiss/IndexShards.cpp +4 -1
- data/vendor/faiss/faiss/IndexShards.h +2 -1
- data/vendor/faiss/faiss/MetaIndexes.cpp +5 -178
- data/vendor/faiss/faiss/MetaIndexes.h +3 -81
- data/vendor/faiss/faiss/VectorTransform.cpp +43 -0
- data/vendor/faiss/faiss/VectorTransform.h +22 -4
- data/vendor/faiss/faiss/clone_index.cpp +23 -1
- data/vendor/faiss/faiss/clone_index.h +3 -0
- data/vendor/faiss/faiss/cppcontrib/SaDecodeKernels.h +300 -0
- data/vendor/faiss/faiss/cppcontrib/detail/CoarseBitType.h +24 -0
- data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +195 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +2058 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +408 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-neon-inl.h +2147 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMax-inl.h +460 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMaxFP16-inl.h +465 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +1618 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +251 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-neon-inl.h +1452 -0
- data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +1 -0
- data/vendor/faiss/faiss/gpu/GpuCloner.cpp +0 -4
- data/vendor/faiss/faiss/gpu/GpuIndex.h +28 -4
- data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +2 -1
- data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +10 -8
- data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +75 -14
- data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +19 -32
- data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -31
- data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +22 -28
- data/vendor/faiss/faiss/gpu/GpuResources.cpp +14 -0
- data/vendor/faiss/faiss/gpu/GpuResources.h +16 -3
- data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +3 -3
- data/vendor/faiss/faiss/gpu/impl/IndexUtils.h +32 -0
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +1 -0
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +311 -75
- data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +10 -0
- data/vendor/faiss/faiss/gpu/test/TestUtils.h +3 -0
- data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +2 -2
- data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +5 -4
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +116 -47
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +44 -13
- data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +0 -54
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +0 -76
- data/vendor/faiss/faiss/impl/DistanceComputer.h +64 -0
- data/vendor/faiss/faiss/impl/HNSW.cpp +123 -27
- data/vendor/faiss/faiss/impl/HNSW.h +19 -16
- data/vendor/faiss/faiss/impl/IDSelector.cpp +125 -0
- data/vendor/faiss/faiss/impl/IDSelector.h +135 -0
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +6 -28
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +6 -1
- data/vendor/faiss/faiss/impl/LookupTableScaler.h +77 -0
- data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -0
- data/vendor/faiss/faiss/impl/NSG.cpp +1 -1
- data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +383 -0
- data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +154 -0
- data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +225 -145
- data/vendor/faiss/faiss/impl/ProductQuantizer.h +29 -10
- data/vendor/faiss/faiss/impl/Quantizer.h +43 -0
- data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +192 -36
- data/vendor/faiss/faiss/impl/ResidualQuantizer.h +40 -20
- data/vendor/faiss/faiss/impl/ResultHandler.h +96 -0
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +97 -173
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +18 -18
- data/vendor/faiss/faiss/impl/index_read.cpp +240 -9
- data/vendor/faiss/faiss/impl/index_write.cpp +237 -5
- data/vendor/faiss/faiss/impl/kmeans1d.cpp +6 -4
- data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +56 -16
- data/vendor/faiss/faiss/impl/pq4_fast_scan.h +25 -8
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +66 -25
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +75 -27
- data/vendor/faiss/faiss/index_factory.cpp +196 -7
- data/vendor/faiss/faiss/index_io.h +5 -0
- data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -0
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +4 -1
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +2 -1
- data/vendor/faiss/faiss/python/python_callbacks.cpp +27 -0
- data/vendor/faiss/faiss/python/python_callbacks.h +15 -0
- data/vendor/faiss/faiss/utils/Heap.h +31 -15
- data/vendor/faiss/faiss/utils/distances.cpp +380 -56
- data/vendor/faiss/faiss/utils/distances.h +113 -15
- data/vendor/faiss/faiss/utils/distances_simd.cpp +726 -6
- data/vendor/faiss/faiss/utils/extra_distances.cpp +12 -7
- data/vendor/faiss/faiss/utils/extra_distances.h +3 -1
- data/vendor/faiss/faiss/utils/fp16-fp16c.h +21 -0
- data/vendor/faiss/faiss/utils/fp16-inl.h +101 -0
- data/vendor/faiss/faiss/utils/fp16.h +11 -0
- data/vendor/faiss/faiss/utils/hamming-inl.h +54 -0
- data/vendor/faiss/faiss/utils/hamming.cpp +0 -48
- data/vendor/faiss/faiss/utils/ordered_key_value.h +10 -0
- data/vendor/faiss/faiss/utils/quantize_lut.cpp +62 -0
- data/vendor/faiss/faiss/utils/quantize_lut.h +20 -0
- data/vendor/faiss/faiss/utils/random.cpp +53 -0
- data/vendor/faiss/faiss/utils/random.h +5 -0
- data/vendor/faiss/faiss/utils/simdlib_avx2.h +4 -0
- data/vendor/faiss/faiss/utils/simdlib_emulated.h +6 -1
- data/vendor/faiss/faiss/utils/simdlib_neon.h +7 -2
- metadata +37 -3
|
@@ -19,8 +19,11 @@
|
|
|
19
19
|
#include <immintrin.h>
|
|
20
20
|
#endif
|
|
21
21
|
|
|
22
|
+
#include <faiss/IndexIVF.h>
|
|
22
23
|
#include <faiss/impl/AuxIndexStructures.h>
|
|
23
24
|
#include <faiss/impl/FaissAssert.h>
|
|
25
|
+
#include <faiss/impl/IDSelector.h>
|
|
26
|
+
#include <faiss/utils/fp16.h>
|
|
24
27
|
#include <faiss/utils/utils.h>
|
|
25
28
|
|
|
26
29
|
namespace faiss {
|
|
@@ -201,114 +204,6 @@ struct Codec6bit {
|
|
|
201
204
|
#endif
|
|
202
205
|
};
|
|
203
206
|
|
|
204
|
-
#ifdef USE_F16C
|
|
205
|
-
|
|
206
|
-
uint16_t encode_fp16(float x) {
|
|
207
|
-
__m128 xf = _mm_set1_ps(x);
|
|
208
|
-
__m128i xi =
|
|
209
|
-
_mm_cvtps_ph(xf, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
|
|
210
|
-
return _mm_cvtsi128_si32(xi) & 0xffff;
|
|
211
|
-
}
|
|
212
|
-
|
|
213
|
-
float decode_fp16(uint16_t x) {
|
|
214
|
-
__m128i xi = _mm_set1_epi16(x);
|
|
215
|
-
__m128 xf = _mm_cvtph_ps(xi);
|
|
216
|
-
return _mm_cvtss_f32(xf);
|
|
217
|
-
}
|
|
218
|
-
|
|
219
|
-
#else
|
|
220
|
-
|
|
221
|
-
// non-intrinsic FP16 <-> FP32 code adapted from
|
|
222
|
-
// https://github.com/ispc/ispc/blob/master/stdlib.ispc
|
|
223
|
-
|
|
224
|
-
float floatbits(uint32_t x) {
|
|
225
|
-
void* xptr = &x;
|
|
226
|
-
return *(float*)xptr;
|
|
227
|
-
}
|
|
228
|
-
|
|
229
|
-
uint32_t intbits(float f) {
|
|
230
|
-
void* fptr = &f;
|
|
231
|
-
return *(uint32_t*)fptr;
|
|
232
|
-
}
|
|
233
|
-
|
|
234
|
-
uint16_t encode_fp16(float f) {
|
|
235
|
-
// via Fabian "ryg" Giesen.
|
|
236
|
-
// https://gist.github.com/2156668
|
|
237
|
-
uint32_t sign_mask = 0x80000000u;
|
|
238
|
-
int32_t o;
|
|
239
|
-
|
|
240
|
-
uint32_t fint = intbits(f);
|
|
241
|
-
uint32_t sign = fint & sign_mask;
|
|
242
|
-
fint ^= sign;
|
|
243
|
-
|
|
244
|
-
// NOTE all the integer compares in this function can be safely
|
|
245
|
-
// compiled into signed compares since all operands are below
|
|
246
|
-
// 0x80000000. Important if you want fast straight SSE2 code (since
|
|
247
|
-
// there's no unsigned PCMPGTD).
|
|
248
|
-
|
|
249
|
-
// Inf or NaN (all exponent bits set)
|
|
250
|
-
// NaN->qNaN and Inf->Inf
|
|
251
|
-
// unconditional assignment here, will override with right value for
|
|
252
|
-
// the regular case below.
|
|
253
|
-
uint32_t f32infty = 255u << 23;
|
|
254
|
-
o = (fint > f32infty) ? 0x7e00u : 0x7c00u;
|
|
255
|
-
|
|
256
|
-
// (De)normalized number or zero
|
|
257
|
-
// update fint unconditionally to save the blending; we don't need it
|
|
258
|
-
// anymore for the Inf/NaN case anyway.
|
|
259
|
-
|
|
260
|
-
const uint32_t round_mask = ~0xfffu;
|
|
261
|
-
const uint32_t magic = 15u << 23;
|
|
262
|
-
|
|
263
|
-
// Shift exponent down, denormalize if necessary.
|
|
264
|
-
// NOTE This represents half-float denormals using single
|
|
265
|
-
// precision denormals. The main reason to do this is that
|
|
266
|
-
// there's no shift with per-lane variable shifts in SSE*, which
|
|
267
|
-
// we'd otherwise need. It has some funky side effects though:
|
|
268
|
-
// - This conversion will actually respect the FTZ (Flush To Zero)
|
|
269
|
-
// flag in MXCSR - if it's set, no half-float denormals will be
|
|
270
|
-
// generated. I'm honestly not sure whether this is good or
|
|
271
|
-
// bad. It's definitely interesting.
|
|
272
|
-
// - If the underlying HW doesn't support denormals (not an issue
|
|
273
|
-
// with Intel CPUs, but might be a problem on GPUs or PS3 SPUs),
|
|
274
|
-
// you will always get flush-to-zero behavior. This is bad,
|
|
275
|
-
// unless you're on a CPU where you don't care.
|
|
276
|
-
// - Denormals tend to be slow. FP32 denormals are rare in
|
|
277
|
-
// practice outside of things like recursive filters in DSP -
|
|
278
|
-
// not a typical half-float application. Whether FP16 denormals
|
|
279
|
-
// are rare in practice, I don't know. Whatever slow path your
|
|
280
|
-
// HW may or may not have for denormals, this may well hit it.
|
|
281
|
-
float fscale = floatbits(fint & round_mask) * floatbits(magic);
|
|
282
|
-
fscale = std::min(fscale, floatbits((31u << 23) - 0x1000u));
|
|
283
|
-
int32_t fint2 = intbits(fscale) - round_mask;
|
|
284
|
-
|
|
285
|
-
if (fint < f32infty)
|
|
286
|
-
o = fint2 >> 13; // Take the bits!
|
|
287
|
-
|
|
288
|
-
return (o | (sign >> 16));
|
|
289
|
-
}
|
|
290
|
-
|
|
291
|
-
float decode_fp16(uint16_t h) {
|
|
292
|
-
// https://gist.github.com/2144712
|
|
293
|
-
// Fabian "ryg" Giesen.
|
|
294
|
-
|
|
295
|
-
const uint32_t shifted_exp = 0x7c00u << 13; // exponent mask after shift
|
|
296
|
-
|
|
297
|
-
int32_t o = ((int32_t)(h & 0x7fffu)) << 13; // exponent/mantissa bits
|
|
298
|
-
int32_t exp = shifted_exp & o; // just the exponent
|
|
299
|
-
o += (int32_t)(127 - 15) << 23; // exponent adjust
|
|
300
|
-
|
|
301
|
-
int32_t infnan_val = o + ((int32_t)(128 - 16) << 23);
|
|
302
|
-
int32_t zerodenorm_val =
|
|
303
|
-
intbits(floatbits(o + (1u << 23)) - floatbits(113u << 23));
|
|
304
|
-
int32_t reg_val = (exp == 0) ? zerodenorm_val : o;
|
|
305
|
-
|
|
306
|
-
int32_t sign_bit = ((int32_t)(h & 0x8000u)) << 16;
|
|
307
|
-
return floatbits(((exp == shifted_exp) ? infnan_val : reg_val) | sign_bit);
|
|
308
|
-
}
|
|
309
|
-
|
|
310
|
-
#endif
|
|
311
|
-
|
|
312
207
|
/*******************************************************************
|
|
313
208
|
* Quantizer: normalizes scalar vector components, then passes them
|
|
314
209
|
* through a codec
|
|
@@ -318,7 +213,7 @@ template <class Codec, bool uniform, int SIMD>
|
|
|
318
213
|
struct QuantizerTemplate {};
|
|
319
214
|
|
|
320
215
|
template <class Codec>
|
|
321
|
-
struct QuantizerTemplate<Codec, true, 1> : ScalarQuantizer::
|
|
216
|
+
struct QuantizerTemplate<Codec, true, 1> : ScalarQuantizer::SQuantizer {
|
|
322
217
|
const size_t d;
|
|
323
218
|
const float vmin, vdiff;
|
|
324
219
|
|
|
@@ -372,7 +267,7 @@ struct QuantizerTemplate<Codec, true, 8> : QuantizerTemplate<Codec, true, 1> {
|
|
|
372
267
|
#endif
|
|
373
268
|
|
|
374
269
|
template <class Codec>
|
|
375
|
-
struct QuantizerTemplate<Codec, false, 1> : ScalarQuantizer::
|
|
270
|
+
struct QuantizerTemplate<Codec, false, 1> : ScalarQuantizer::SQuantizer {
|
|
376
271
|
const size_t d;
|
|
377
272
|
const float *vmin, *vdiff;
|
|
378
273
|
|
|
@@ -433,7 +328,7 @@ template <int SIMDWIDTH>
|
|
|
433
328
|
struct QuantizerFP16 {};
|
|
434
329
|
|
|
435
330
|
template <>
|
|
436
|
-
struct QuantizerFP16<1> : ScalarQuantizer::
|
|
331
|
+
struct QuantizerFP16<1> : ScalarQuantizer::SQuantizer {
|
|
437
332
|
const size_t d;
|
|
438
333
|
|
|
439
334
|
QuantizerFP16(size_t d, const std::vector<float>& /* unused */) : d(d) {}
|
|
@@ -478,7 +373,7 @@ template <int SIMDWIDTH>
|
|
|
478
373
|
struct Quantizer8bitDirect {};
|
|
479
374
|
|
|
480
375
|
template <>
|
|
481
|
-
struct Quantizer8bitDirect<1> : ScalarQuantizer::
|
|
376
|
+
struct Quantizer8bitDirect<1> : ScalarQuantizer::SQuantizer {
|
|
482
377
|
const size_t d;
|
|
483
378
|
|
|
484
379
|
Quantizer8bitDirect(size_t d, const std::vector<float>& /* unused */)
|
|
@@ -518,7 +413,7 @@ struct Quantizer8bitDirect<8> : Quantizer8bitDirect<1> {
|
|
|
518
413
|
#endif
|
|
519
414
|
|
|
520
415
|
template <int SIMDWIDTH>
|
|
521
|
-
ScalarQuantizer::
|
|
416
|
+
ScalarQuantizer::SQuantizer* select_quantizer_1(
|
|
522
417
|
QuantizerType qtype,
|
|
523
418
|
size_t d,
|
|
524
419
|
const std::vector<float>& trained) {
|
|
@@ -911,11 +806,6 @@ struct DCTemplate<Quantizer, Similarity, 1> : SQDistanceComputer {
|
|
|
911
806
|
q = x;
|
|
912
807
|
}
|
|
913
808
|
|
|
914
|
-
/// compute distance of vector i to current query
|
|
915
|
-
float operator()(idx_t i) final {
|
|
916
|
-
return query_to_code(codes + i * code_size);
|
|
917
|
-
}
|
|
918
|
-
|
|
919
809
|
float symmetric_dis(idx_t i, idx_t j) override {
|
|
920
810
|
return compute_code_distance(
|
|
921
811
|
codes + i * code_size, codes + j * code_size);
|
|
@@ -963,11 +853,6 @@ struct DCTemplate<Quantizer, Similarity, 8> : SQDistanceComputer {
|
|
|
963
853
|
q = x;
|
|
964
854
|
}
|
|
965
855
|
|
|
966
|
-
/// compute distance of vector i to current query
|
|
967
|
-
float operator()(idx_t i) final {
|
|
968
|
-
return query_to_code(codes + i * code_size);
|
|
969
|
-
}
|
|
970
|
-
|
|
971
856
|
float symmetric_dis(idx_t i, idx_t j) override {
|
|
972
857
|
return compute_code_distance(
|
|
973
858
|
codes + i * code_size, codes + j * code_size);
|
|
@@ -1021,11 +906,6 @@ struct DistanceComputerByte<Similarity, 1> : SQDistanceComputer {
|
|
|
1021
906
|
return compute_code_distance(tmp.data(), code);
|
|
1022
907
|
}
|
|
1023
908
|
|
|
1024
|
-
/// compute distance of vector i to current query
|
|
1025
|
-
float operator()(idx_t i) final {
|
|
1026
|
-
return query_to_code(codes + i * code_size);
|
|
1027
|
-
}
|
|
1028
|
-
|
|
1029
909
|
float symmetric_dis(idx_t i, idx_t j) override {
|
|
1030
910
|
return compute_code_distance(
|
|
1031
911
|
codes + i * code_size, codes + j * code_size);
|
|
@@ -1089,11 +969,6 @@ struct DistanceComputerByte<Similarity, 8> : SQDistanceComputer {
|
|
|
1089
969
|
return compute_code_distance(tmp.data(), code);
|
|
1090
970
|
}
|
|
1091
971
|
|
|
1092
|
-
/// compute distance of vector i to current query
|
|
1093
|
-
float operator()(idx_t i) final {
|
|
1094
|
-
return query_to_code(codes + i * code_size);
|
|
1095
|
-
}
|
|
1096
|
-
|
|
1097
972
|
float symmetric_dis(idx_t i, idx_t j) override {
|
|
1098
973
|
return compute_code_distance(
|
|
1099
974
|
codes + i * code_size, codes + j * code_size);
|
|
@@ -1173,17 +1048,12 @@ SQDistanceComputer* select_distance_computer(
|
|
|
1173
1048
|
********************************************************************/
|
|
1174
1049
|
|
|
1175
1050
|
ScalarQuantizer::ScalarQuantizer(size_t d, QuantizerType qtype)
|
|
1176
|
-
: qtype(qtype), rangestat(RS_minmax), rangestat_arg(0)
|
|
1051
|
+
: Quantizer(d), qtype(qtype), rangestat(RS_minmax), rangestat_arg(0) {
|
|
1177
1052
|
set_derived_sizes();
|
|
1178
1053
|
}
|
|
1179
1054
|
|
|
1180
1055
|
ScalarQuantizer::ScalarQuantizer()
|
|
1181
|
-
: qtype(QT_8bit),
|
|
1182
|
-
rangestat(RS_minmax),
|
|
1183
|
-
rangestat_arg(0),
|
|
1184
|
-
d(0),
|
|
1185
|
-
bits(0),
|
|
1186
|
-
code_size(0) {}
|
|
1056
|
+
: qtype(QT_8bit), rangestat(RS_minmax), rangestat_arg(0), bits(0) {}
|
|
1187
1057
|
|
|
1188
1058
|
void ScalarQuantizer::set_derived_sizes() {
|
|
1189
1059
|
switch (qtype) {
|
|
@@ -1273,7 +1143,7 @@ void ScalarQuantizer::train_residual(
|
|
|
1273
1143
|
}
|
|
1274
1144
|
}
|
|
1275
1145
|
|
|
1276
|
-
ScalarQuantizer::
|
|
1146
|
+
ScalarQuantizer::SQuantizer* ScalarQuantizer::select_quantizer() const {
|
|
1277
1147
|
#ifdef USE_F16C
|
|
1278
1148
|
if (d % 8 == 0) {
|
|
1279
1149
|
return select_quantizer_1<8>(qtype, d, trained);
|
|
@@ -1286,7 +1156,7 @@ ScalarQuantizer::Quantizer* ScalarQuantizer::select_quantizer() const {
|
|
|
1286
1156
|
|
|
1287
1157
|
void ScalarQuantizer::compute_codes(const float* x, uint8_t* codes, size_t n)
|
|
1288
1158
|
const {
|
|
1289
|
-
std::unique_ptr<
|
|
1159
|
+
std::unique_ptr<SQuantizer> squant(select_quantizer());
|
|
1290
1160
|
|
|
1291
1161
|
memset(codes, 0, code_size * n);
|
|
1292
1162
|
#pragma omp parallel for
|
|
@@ -1295,7 +1165,7 @@ void ScalarQuantizer::compute_codes(const float* x, uint8_t* codes, size_t n)
|
|
|
1295
1165
|
}
|
|
1296
1166
|
|
|
1297
1167
|
void ScalarQuantizer::decode(const uint8_t* codes, float* x, size_t n) const {
|
|
1298
|
-
std::unique_ptr<
|
|
1168
|
+
std::unique_ptr<SQuantizer> squant(select_quantizer());
|
|
1299
1169
|
|
|
1300
1170
|
#pragma omp parallel for
|
|
1301
1171
|
for (int64_t i = 0; i < n; i++)
|
|
@@ -1332,10 +1202,11 @@ SQDistanceComputer* ScalarQuantizer::get_distance_computer(
|
|
|
1332
1202
|
|
|
1333
1203
|
namespace {
|
|
1334
1204
|
|
|
1335
|
-
template <class DCClass>
|
|
1205
|
+
template <class DCClass, int use_sel>
|
|
1336
1206
|
struct IVFSQScannerIP : InvertedListScanner {
|
|
1337
1207
|
DCClass dc;
|
|
1338
1208
|
bool by_residual;
|
|
1209
|
+
const IDSelector* sel;
|
|
1339
1210
|
|
|
1340
1211
|
float accu0; /// added to all distances
|
|
1341
1212
|
|
|
@@ -1344,8 +1215,9 @@ struct IVFSQScannerIP : InvertedListScanner {
|
|
|
1344
1215
|
const std::vector<float>& trained,
|
|
1345
1216
|
size_t code_size,
|
|
1346
1217
|
bool store_pairs,
|
|
1218
|
+
const IDSelector* sel,
|
|
1347
1219
|
bool by_residual)
|
|
1348
|
-
: dc(d, trained), by_residual(by_residual), accu0(0) {
|
|
1220
|
+
: dc(d, trained), by_residual(by_residual), sel(sel), accu0(0) {
|
|
1349
1221
|
this->store_pairs = store_pairs;
|
|
1350
1222
|
this->code_size = code_size;
|
|
1351
1223
|
}
|
|
@@ -1372,7 +1244,11 @@ struct IVFSQScannerIP : InvertedListScanner {
|
|
|
1372
1244
|
size_t k) const override {
|
|
1373
1245
|
size_t nup = 0;
|
|
1374
1246
|
|
|
1375
|
-
for (size_t j = 0; j < list_size; j
|
|
1247
|
+
for (size_t j = 0; j < list_size; j++, codes += code_size) {
|
|
1248
|
+
if (use_sel && !sel->is_member(use_sel == 1 ? ids[j] : j)) {
|
|
1249
|
+
continue;
|
|
1250
|
+
}
|
|
1251
|
+
|
|
1376
1252
|
float accu = accu0 + dc.query_to_code(codes);
|
|
1377
1253
|
|
|
1378
1254
|
if (accu > simi[0]) {
|
|
@@ -1380,7 +1256,6 @@ struct IVFSQScannerIP : InvertedListScanner {
|
|
|
1380
1256
|
minheap_replace_top(k, simi, idxi, accu, id);
|
|
1381
1257
|
nup++;
|
|
1382
1258
|
}
|
|
1383
|
-
codes += code_size;
|
|
1384
1259
|
}
|
|
1385
1260
|
return nup;
|
|
1386
1261
|
}
|
|
@@ -1391,23 +1266,31 @@ struct IVFSQScannerIP : InvertedListScanner {
|
|
|
1391
1266
|
const idx_t* ids,
|
|
1392
1267
|
float radius,
|
|
1393
1268
|
RangeQueryResult& res) const override {
|
|
1394
|
-
for (size_t j = 0; j < list_size; j
|
|
1269
|
+
for (size_t j = 0; j < list_size; j++, codes += code_size) {
|
|
1270
|
+
if (use_sel && !sel->is_member(use_sel == 1 ? ids[j] : j)) {
|
|
1271
|
+
continue;
|
|
1272
|
+
}
|
|
1273
|
+
|
|
1395
1274
|
float accu = accu0 + dc.query_to_code(codes);
|
|
1396
1275
|
if (accu > radius) {
|
|
1397
1276
|
int64_t id = store_pairs ? (list_no << 32 | j) : ids[j];
|
|
1398
1277
|
res.add(accu, id);
|
|
1399
1278
|
}
|
|
1400
|
-
codes += code_size;
|
|
1401
1279
|
}
|
|
1402
1280
|
}
|
|
1403
1281
|
};
|
|
1404
1282
|
|
|
1405
|
-
|
|
1283
|
+
/* use_sel = 0: don't check selector
|
|
1284
|
+
* = 1: check on ids[j]
|
|
1285
|
+
* = 2: check in j directly (normally ids is nullptr and store_pairs)
|
|
1286
|
+
*/
|
|
1287
|
+
template <class DCClass, int use_sel>
|
|
1406
1288
|
struct IVFSQScannerL2 : InvertedListScanner {
|
|
1407
1289
|
DCClass dc;
|
|
1408
1290
|
|
|
1409
1291
|
bool by_residual;
|
|
1410
1292
|
const Index* quantizer;
|
|
1293
|
+
const IDSelector* sel;
|
|
1411
1294
|
const float* x; /// current query
|
|
1412
1295
|
|
|
1413
1296
|
std::vector<float> tmp;
|
|
@@ -1418,10 +1301,12 @@ struct IVFSQScannerL2 : InvertedListScanner {
|
|
|
1418
1301
|
size_t code_size,
|
|
1419
1302
|
const Index* quantizer,
|
|
1420
1303
|
bool store_pairs,
|
|
1304
|
+
const IDSelector* sel,
|
|
1421
1305
|
bool by_residual)
|
|
1422
1306
|
: dc(d, trained),
|
|
1423
1307
|
by_residual(by_residual),
|
|
1424
1308
|
quantizer(quantizer),
|
|
1309
|
+
sel(sel),
|
|
1425
1310
|
x(nullptr),
|
|
1426
1311
|
tmp(d) {
|
|
1427
1312
|
this->store_pairs = store_pairs;
|
|
@@ -1458,7 +1343,11 @@ struct IVFSQScannerL2 : InvertedListScanner {
|
|
|
1458
1343
|
idx_t* idxi,
|
|
1459
1344
|
size_t k) const override {
|
|
1460
1345
|
size_t nup = 0;
|
|
1461
|
-
for (size_t j = 0; j < list_size; j
|
|
1346
|
+
for (size_t j = 0; j < list_size; j++, codes += code_size) {
|
|
1347
|
+
if (use_sel && !sel->is_member(use_sel == 1 ? ids[j] : j)) {
|
|
1348
|
+
continue;
|
|
1349
|
+
}
|
|
1350
|
+
|
|
1462
1351
|
float dis = dc.query_to_code(codes);
|
|
1463
1352
|
|
|
1464
1353
|
if (dis < simi[0]) {
|
|
@@ -1466,7 +1355,6 @@ struct IVFSQScannerL2 : InvertedListScanner {
|
|
|
1466
1355
|
maxheap_replace_top(k, simi, idxi, dis, id);
|
|
1467
1356
|
nup++;
|
|
1468
1357
|
}
|
|
1469
|
-
codes += code_size;
|
|
1470
1358
|
}
|
|
1471
1359
|
return nup;
|
|
1472
1360
|
}
|
|
@@ -1477,44 +1365,77 @@ struct IVFSQScannerL2 : InvertedListScanner {
|
|
|
1477
1365
|
const idx_t* ids,
|
|
1478
1366
|
float radius,
|
|
1479
1367
|
RangeQueryResult& res) const override {
|
|
1480
|
-
for (size_t j = 0; j < list_size; j
|
|
1368
|
+
for (size_t j = 0; j < list_size; j++, codes += code_size) {
|
|
1369
|
+
if (use_sel && !sel->is_member(use_sel == 1 ? ids[j] : j)) {
|
|
1370
|
+
continue;
|
|
1371
|
+
}
|
|
1372
|
+
|
|
1481
1373
|
float dis = dc.query_to_code(codes);
|
|
1482
1374
|
if (dis < radius) {
|
|
1483
1375
|
int64_t id = store_pairs ? (list_no << 32 | j) : ids[j];
|
|
1484
1376
|
res.add(dis, id);
|
|
1485
1377
|
}
|
|
1486
|
-
codes += code_size;
|
|
1487
1378
|
}
|
|
1488
1379
|
}
|
|
1489
1380
|
};
|
|
1490
1381
|
|
|
1491
|
-
template <class DCClass>
|
|
1492
|
-
InvertedListScanner*
|
|
1382
|
+
template <class DCClass, int use_sel>
|
|
1383
|
+
InvertedListScanner* sel3_InvertedListScanner(
|
|
1493
1384
|
const ScalarQuantizer* sq,
|
|
1494
1385
|
const Index* quantizer,
|
|
1495
1386
|
bool store_pairs,
|
|
1387
|
+
const IDSelector* sel,
|
|
1496
1388
|
bool r) {
|
|
1497
1389
|
if (DCClass::Sim::metric_type == METRIC_L2) {
|
|
1498
|
-
return new IVFSQScannerL2<DCClass>(
|
|
1499
|
-
sq->d,
|
|
1390
|
+
return new IVFSQScannerL2<DCClass, use_sel>(
|
|
1391
|
+
sq->d,
|
|
1392
|
+
sq->trained,
|
|
1393
|
+
sq->code_size,
|
|
1394
|
+
quantizer,
|
|
1395
|
+
store_pairs,
|
|
1396
|
+
sel,
|
|
1397
|
+
r);
|
|
1500
1398
|
} else if (DCClass::Sim::metric_type == METRIC_INNER_PRODUCT) {
|
|
1501
|
-
return new IVFSQScannerIP<DCClass>(
|
|
1502
|
-
sq->d, sq->trained, sq->code_size, store_pairs, r);
|
|
1399
|
+
return new IVFSQScannerIP<DCClass, use_sel>(
|
|
1400
|
+
sq->d, sq->trained, sq->code_size, store_pairs, sel, r);
|
|
1503
1401
|
} else {
|
|
1504
1402
|
FAISS_THROW_MSG("unsupported metric type");
|
|
1505
1403
|
}
|
|
1506
1404
|
}
|
|
1507
1405
|
|
|
1406
|
+
template <class DCClass>
|
|
1407
|
+
InvertedListScanner* sel2_InvertedListScanner(
|
|
1408
|
+
const ScalarQuantizer* sq,
|
|
1409
|
+
const Index* quantizer,
|
|
1410
|
+
bool store_pairs,
|
|
1411
|
+
const IDSelector* sel,
|
|
1412
|
+
bool r) {
|
|
1413
|
+
if (sel) {
|
|
1414
|
+
if (store_pairs) {
|
|
1415
|
+
return sel3_InvertedListScanner<DCClass, 2>(
|
|
1416
|
+
sq, quantizer, store_pairs, sel, r);
|
|
1417
|
+
} else {
|
|
1418
|
+
return sel3_InvertedListScanner<DCClass, 1>(
|
|
1419
|
+
sq, quantizer, store_pairs, sel, r);
|
|
1420
|
+
}
|
|
1421
|
+
} else {
|
|
1422
|
+
return sel3_InvertedListScanner<DCClass, 0>(
|
|
1423
|
+
sq, quantizer, store_pairs, sel, r);
|
|
1424
|
+
}
|
|
1425
|
+
}
|
|
1426
|
+
|
|
1508
1427
|
template <class Similarity, class Codec, bool uniform>
|
|
1509
1428
|
InvertedListScanner* sel12_InvertedListScanner(
|
|
1510
1429
|
const ScalarQuantizer* sq,
|
|
1511
1430
|
const Index* quantizer,
|
|
1512
1431
|
bool store_pairs,
|
|
1432
|
+
const IDSelector* sel,
|
|
1513
1433
|
bool r) {
|
|
1514
1434
|
constexpr int SIMDWIDTH = Similarity::simdwidth;
|
|
1515
1435
|
using QuantizerClass = QuantizerTemplate<Codec, uniform, SIMDWIDTH>;
|
|
1516
1436
|
using DCClass = DCTemplate<QuantizerClass, Similarity, SIMDWIDTH>;
|
|
1517
|
-
return sel2_InvertedListScanner<DCClass>(
|
|
1437
|
+
return sel2_InvertedListScanner<DCClass>(
|
|
1438
|
+
sq, quantizer, store_pairs, sel, r);
|
|
1518
1439
|
}
|
|
1519
1440
|
|
|
1520
1441
|
template <class Similarity>
|
|
@@ -1522,39 +1443,40 @@ InvertedListScanner* sel1_InvertedListScanner(
|
|
|
1522
1443
|
const ScalarQuantizer* sq,
|
|
1523
1444
|
const Index* quantizer,
|
|
1524
1445
|
bool store_pairs,
|
|
1446
|
+
const IDSelector* sel,
|
|
1525
1447
|
bool r) {
|
|
1526
1448
|
constexpr int SIMDWIDTH = Similarity::simdwidth;
|
|
1527
1449
|
switch (sq->qtype) {
|
|
1528
1450
|
case ScalarQuantizer::QT_8bit_uniform:
|
|
1529
1451
|
return sel12_InvertedListScanner<Similarity, Codec8bit, true>(
|
|
1530
|
-
sq, quantizer, store_pairs, r);
|
|
1452
|
+
sq, quantizer, store_pairs, sel, r);
|
|
1531
1453
|
case ScalarQuantizer::QT_4bit_uniform:
|
|
1532
1454
|
return sel12_InvertedListScanner<Similarity, Codec4bit, true>(
|
|
1533
|
-
sq, quantizer, store_pairs, r);
|
|
1455
|
+
sq, quantizer, store_pairs, sel, r);
|
|
1534
1456
|
case ScalarQuantizer::QT_8bit:
|
|
1535
1457
|
return sel12_InvertedListScanner<Similarity, Codec8bit, false>(
|
|
1536
|
-
sq, quantizer, store_pairs, r);
|
|
1458
|
+
sq, quantizer, store_pairs, sel, r);
|
|
1537
1459
|
case ScalarQuantizer::QT_4bit:
|
|
1538
1460
|
return sel12_InvertedListScanner<Similarity, Codec4bit, false>(
|
|
1539
|
-
sq, quantizer, store_pairs, r);
|
|
1461
|
+
sq, quantizer, store_pairs, sel, r);
|
|
1540
1462
|
case ScalarQuantizer::QT_6bit:
|
|
1541
1463
|
return sel12_InvertedListScanner<Similarity, Codec6bit, false>(
|
|
1542
|
-
sq, quantizer, store_pairs, r);
|
|
1464
|
+
sq, quantizer, store_pairs, sel, r);
|
|
1543
1465
|
case ScalarQuantizer::QT_fp16:
|
|
1544
1466
|
return sel2_InvertedListScanner<DCTemplate<
|
|
1545
1467
|
QuantizerFP16<SIMDWIDTH>,
|
|
1546
1468
|
Similarity,
|
|
1547
|
-
SIMDWIDTH>>(sq, quantizer, store_pairs, r);
|
|
1469
|
+
SIMDWIDTH>>(sq, quantizer, store_pairs, sel, r);
|
|
1548
1470
|
case ScalarQuantizer::QT_8bit_direct:
|
|
1549
1471
|
if (sq->d % 16 == 0) {
|
|
1550
1472
|
return sel2_InvertedListScanner<
|
|
1551
1473
|
DistanceComputerByte<Similarity, SIMDWIDTH>>(
|
|
1552
|
-
sq, quantizer, store_pairs, r);
|
|
1474
|
+
sq, quantizer, store_pairs, sel, r);
|
|
1553
1475
|
} else {
|
|
1554
1476
|
return sel2_InvertedListScanner<DCTemplate<
|
|
1555
1477
|
Quantizer8bitDirect<SIMDWIDTH>,
|
|
1556
1478
|
Similarity,
|
|
1557
|
-
SIMDWIDTH>>(sq, quantizer, store_pairs, r);
|
|
1479
|
+
SIMDWIDTH>>(sq, quantizer, store_pairs, sel, r);
|
|
1558
1480
|
}
|
|
1559
1481
|
}
|
|
1560
1482
|
|
|
@@ -1568,13 +1490,14 @@ InvertedListScanner* sel0_InvertedListScanner(
|
|
|
1568
1490
|
const ScalarQuantizer* sq,
|
|
1569
1491
|
const Index* quantizer,
|
|
1570
1492
|
bool store_pairs,
|
|
1493
|
+
const IDSelector* sel,
|
|
1571
1494
|
bool by_residual) {
|
|
1572
1495
|
if (mt == METRIC_L2) {
|
|
1573
1496
|
return sel1_InvertedListScanner<SimilarityL2<SIMDWIDTH>>(
|
|
1574
|
-
sq, quantizer, store_pairs, by_residual);
|
|
1497
|
+
sq, quantizer, store_pairs, sel, by_residual);
|
|
1575
1498
|
} else if (mt == METRIC_INNER_PRODUCT) {
|
|
1576
1499
|
return sel1_InvertedListScanner<SimilarityIP<SIMDWIDTH>>(
|
|
1577
|
-
sq, quantizer, store_pairs, by_residual);
|
|
1500
|
+
sq, quantizer, store_pairs, sel, by_residual);
|
|
1578
1501
|
} else {
|
|
1579
1502
|
FAISS_THROW_MSG("unsupported metric type");
|
|
1580
1503
|
}
|
|
@@ -1586,16 +1509,17 @@ InvertedListScanner* ScalarQuantizer::select_InvertedListScanner(
|
|
|
1586
1509
|
MetricType mt,
|
|
1587
1510
|
const Index* quantizer,
|
|
1588
1511
|
bool store_pairs,
|
|
1512
|
+
const IDSelector* sel,
|
|
1589
1513
|
bool by_residual) const {
|
|
1590
1514
|
#ifdef USE_F16C
|
|
1591
1515
|
if (d % 8 == 0) {
|
|
1592
1516
|
return sel0_InvertedListScanner<8>(
|
|
1593
|
-
mt, this, quantizer, store_pairs, by_residual);
|
|
1517
|
+
mt, this, quantizer, store_pairs, sel, by_residual);
|
|
1594
1518
|
} else
|
|
1595
1519
|
#endif
|
|
1596
1520
|
{
|
|
1597
1521
|
return sel0_InvertedListScanner<1>(
|
|
1598
|
-
mt, this, quantizer, store_pairs, by_residual);
|
|
1522
|
+
mt, this, quantizer, store_pairs, sel, by_residual);
|
|
1599
1523
|
}
|
|
1600
1524
|
}
|
|
1601
1525
|
|
|
@@ -9,18 +9,21 @@
|
|
|
9
9
|
|
|
10
10
|
#pragma once
|
|
11
11
|
|
|
12
|
-
#include <faiss/IndexIVF.h>
|
|
13
12
|
#include <faiss/impl/AuxIndexStructures.h>
|
|
13
|
+
#include <faiss/impl/DistanceComputer.h>
|
|
14
|
+
#include <faiss/impl/Quantizer.h>
|
|
14
15
|
|
|
15
16
|
namespace faiss {
|
|
16
17
|
|
|
18
|
+
struct InvertedListScanner;
|
|
19
|
+
|
|
17
20
|
/**
|
|
18
21
|
* The uniform quantizer has a range [vmin, vmax]. The range can be
|
|
19
22
|
* the same for all dimensions (uniform) or specific per dimension
|
|
20
23
|
* (default).
|
|
21
24
|
*/
|
|
22
25
|
|
|
23
|
-
struct ScalarQuantizer {
|
|
26
|
+
struct ScalarQuantizer : Quantizer {
|
|
24
27
|
enum QuantizerType {
|
|
25
28
|
QT_8bit, ///< 8 bits per component
|
|
26
29
|
QT_4bit, ///< 4 bits per component
|
|
@@ -48,15 +51,9 @@ struct ScalarQuantizer {
|
|
|
48
51
|
RangeStat rangestat;
|
|
49
52
|
float rangestat_arg;
|
|
50
53
|
|
|
51
|
-
/// dimension of input vectors
|
|
52
|
-
size_t d;
|
|
53
|
-
|
|
54
54
|
/// bits per scalar code
|
|
55
55
|
size_t bits;
|
|
56
56
|
|
|
57
|
-
/// bytes per vector
|
|
58
|
-
size_t code_size;
|
|
59
|
-
|
|
60
57
|
/// trained values (including the range)
|
|
61
58
|
std::vector<float> trained;
|
|
62
59
|
|
|
@@ -66,7 +63,7 @@ struct ScalarQuantizer {
|
|
|
66
63
|
/// updates internal values based on qtype and d
|
|
67
64
|
void set_derived_sizes();
|
|
68
65
|
|
|
69
|
-
void train(size_t n, const float* x);
|
|
66
|
+
void train(size_t n, const float* x) override;
|
|
70
67
|
|
|
71
68
|
/// Used by an IVF index to train based on the residuals
|
|
72
69
|
void train_residual(
|
|
@@ -81,38 +78,40 @@ struct ScalarQuantizer {
|
|
|
81
78
|
* @param x vectors to encode, size n * d
|
|
82
79
|
* @param codes output codes, size n * code_size
|
|
83
80
|
*/
|
|
84
|
-
void compute_codes(const float* x, uint8_t* codes, size_t n) const;
|
|
81
|
+
void compute_codes(const float* x, uint8_t* codes, size_t n) const override;
|
|
85
82
|
|
|
86
83
|
/** Decode a set of vectors
|
|
87
84
|
*
|
|
88
85
|
* @param codes codes to decode, size n * code_size
|
|
89
86
|
* @param x output vectors, size n * d
|
|
90
87
|
*/
|
|
91
|
-
void decode(const uint8_t* code, float* x, size_t n) const;
|
|
88
|
+
void decode(const uint8_t* code, float* x, size_t n) const override;
|
|
92
89
|
|
|
93
90
|
/*****************************************************
|
|
94
91
|
* Objects that provide methods for encoding/decoding, distance
|
|
95
92
|
* computation and inverted list scanning
|
|
96
93
|
*****************************************************/
|
|
97
94
|
|
|
98
|
-
struct Quantizer {
|
|
95
|
+
struct SQuantizer {
|
|
99
96
|
// encodes one vector. Assumes code is filled with 0s on input!
|
|
100
97
|
virtual void encode_vector(const float* x, uint8_t* code) const = 0;
|
|
101
98
|
virtual void decode_vector(const uint8_t* code, float* x) const = 0;
|
|
102
99
|
|
|
103
|
-
virtual ~Quantizer() {}
|
|
100
|
+
virtual ~SQuantizer() {}
|
|
104
101
|
};
|
|
105
102
|
|
|
106
|
-
Quantizer* select_quantizer() const;
|
|
103
|
+
SQuantizer* select_quantizer() const;
|
|
107
104
|
|
|
108
|
-
struct SQDistanceComputer : DistanceComputer {
|
|
105
|
+
struct SQDistanceComputer : FlatCodesDistanceComputer {
|
|
109
106
|
const float* q;
|
|
110
|
-
const uint8_t* codes;
|
|
111
|
-
size_t code_size;
|
|
112
107
|
|
|
113
|
-
SQDistanceComputer() : q(nullptr), codes(nullptr), code_size(0) {}
|
|
108
|
+
SQDistanceComputer() : q(nullptr) {}
|
|
114
109
|
|
|
115
110
|
virtual float query_to_code(const uint8_t* code) const = 0;
|
|
111
|
+
|
|
112
|
+
float distance_to_code(const uint8_t* code) final {
|
|
113
|
+
return query_to_code(code);
|
|
114
|
+
}
|
|
116
115
|
};
|
|
117
116
|
|
|
118
117
|
SQDistanceComputer* get_distance_computer(
|
|
@@ -122,6 +121,7 @@ struct ScalarQuantizer {
|
|
|
122
121
|
MetricType mt,
|
|
123
122
|
const Index* quantizer,
|
|
124
123
|
bool store_pairs,
|
|
124
|
+
const IDSelector* sel,
|
|
125
125
|
bool by_residual = false) const;
|
|
126
126
|
};
|
|
127
127
|
|