faiss 0.2.4 → 0.2.5

Sign up to get free protection for your applications and to get access to all the features.
Files changed (177) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +5 -0
  3. data/README.md +23 -21
  4. data/ext/faiss/extconf.rb +11 -0
  5. data/ext/faiss/index.cpp +4 -4
  6. data/ext/faiss/index_binary.cpp +6 -6
  7. data/ext/faiss/product_quantizer.cpp +4 -4
  8. data/lib/faiss/version.rb +1 -1
  9. data/vendor/faiss/faiss/AutoTune.cpp +13 -0
  10. data/vendor/faiss/faiss/IVFlib.cpp +101 -2
  11. data/vendor/faiss/faiss/IVFlib.h +26 -2
  12. data/vendor/faiss/faiss/Index.cpp +36 -3
  13. data/vendor/faiss/faiss/Index.h +43 -6
  14. data/vendor/faiss/faiss/Index2Layer.cpp +6 -2
  15. data/vendor/faiss/faiss/Index2Layer.h +6 -1
  16. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +219 -16
  17. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +63 -5
  18. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +299 -0
  19. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +199 -0
  20. data/vendor/faiss/faiss/IndexBinary.cpp +20 -4
  21. data/vendor/faiss/faiss/IndexBinary.h +18 -3
  22. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +9 -2
  23. data/vendor/faiss/faiss/IndexBinaryFlat.h +4 -2
  24. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +4 -1
  25. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +2 -1
  26. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +5 -1
  27. data/vendor/faiss/faiss/IndexBinaryHNSW.h +2 -1
  28. data/vendor/faiss/faiss/IndexBinaryHash.cpp +17 -4
  29. data/vendor/faiss/faiss/IndexBinaryHash.h +8 -4
  30. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +28 -13
  31. data/vendor/faiss/faiss/IndexBinaryIVF.h +10 -7
  32. data/vendor/faiss/faiss/IndexFastScan.cpp +626 -0
  33. data/vendor/faiss/faiss/IndexFastScan.h +145 -0
  34. data/vendor/faiss/faiss/IndexFlat.cpp +34 -21
  35. data/vendor/faiss/faiss/IndexFlat.h +7 -4
  36. data/vendor/faiss/faiss/IndexFlatCodes.cpp +35 -1
  37. data/vendor/faiss/faiss/IndexFlatCodes.h +12 -0
  38. data/vendor/faiss/faiss/IndexHNSW.cpp +66 -138
  39. data/vendor/faiss/faiss/IndexHNSW.h +4 -2
  40. data/vendor/faiss/faiss/IndexIDMap.cpp +247 -0
  41. data/vendor/faiss/faiss/IndexIDMap.h +107 -0
  42. data/vendor/faiss/faiss/IndexIVF.cpp +121 -33
  43. data/vendor/faiss/faiss/IndexIVF.h +35 -16
  44. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +84 -7
  45. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +63 -1
  46. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +590 -0
  47. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +171 -0
  48. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +1290 -0
  49. data/vendor/faiss/faiss/IndexIVFFastScan.h +213 -0
  50. data/vendor/faiss/faiss/IndexIVFFlat.cpp +37 -17
  51. data/vendor/faiss/faiss/IndexIVFFlat.h +4 -2
  52. data/vendor/faiss/faiss/IndexIVFPQ.cpp +234 -50
  53. data/vendor/faiss/faiss/IndexIVFPQ.h +5 -1
  54. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +23 -852
  55. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +7 -112
  56. data/vendor/faiss/faiss/IndexIVFPQR.cpp +3 -3
  57. data/vendor/faiss/faiss/IndexIVFPQR.h +1 -1
  58. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +3 -1
  59. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +2 -1
  60. data/vendor/faiss/faiss/IndexLSH.cpp +4 -2
  61. data/vendor/faiss/faiss/IndexLSH.h +2 -1
  62. data/vendor/faiss/faiss/IndexLattice.cpp +7 -1
  63. data/vendor/faiss/faiss/IndexLattice.h +3 -1
  64. data/vendor/faiss/faiss/IndexNNDescent.cpp +4 -3
  65. data/vendor/faiss/faiss/IndexNNDescent.h +2 -1
  66. data/vendor/faiss/faiss/IndexNSG.cpp +37 -3
  67. data/vendor/faiss/faiss/IndexNSG.h +25 -1
  68. data/vendor/faiss/faiss/IndexPQ.cpp +106 -69
  69. data/vendor/faiss/faiss/IndexPQ.h +19 -5
  70. data/vendor/faiss/faiss/IndexPQFastScan.cpp +15 -450
  71. data/vendor/faiss/faiss/IndexPQFastScan.h +15 -78
  72. data/vendor/faiss/faiss/IndexPreTransform.cpp +47 -8
  73. data/vendor/faiss/faiss/IndexPreTransform.h +15 -3
  74. data/vendor/faiss/faiss/IndexRefine.cpp +8 -4
  75. data/vendor/faiss/faiss/IndexRefine.h +4 -2
  76. data/vendor/faiss/faiss/IndexReplicas.cpp +4 -2
  77. data/vendor/faiss/faiss/IndexReplicas.h +2 -1
  78. data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +438 -0
  79. data/vendor/faiss/faiss/IndexRowwiseMinMax.h +92 -0
  80. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +26 -15
  81. data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -7
  82. data/vendor/faiss/faiss/IndexShards.cpp +4 -1
  83. data/vendor/faiss/faiss/IndexShards.h +2 -1
  84. data/vendor/faiss/faiss/MetaIndexes.cpp +5 -178
  85. data/vendor/faiss/faiss/MetaIndexes.h +3 -81
  86. data/vendor/faiss/faiss/VectorTransform.cpp +43 -0
  87. data/vendor/faiss/faiss/VectorTransform.h +22 -4
  88. data/vendor/faiss/faiss/clone_index.cpp +23 -1
  89. data/vendor/faiss/faiss/clone_index.h +3 -0
  90. data/vendor/faiss/faiss/cppcontrib/SaDecodeKernels.h +300 -0
  91. data/vendor/faiss/faiss/cppcontrib/detail/CoarseBitType.h +24 -0
  92. data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +195 -0
  93. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +2058 -0
  94. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +408 -0
  95. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-neon-inl.h +2147 -0
  96. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMax-inl.h +460 -0
  97. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMaxFP16-inl.h +465 -0
  98. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +1618 -0
  99. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +251 -0
  100. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-neon-inl.h +1452 -0
  101. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +1 -0
  102. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +0 -4
  103. data/vendor/faiss/faiss/gpu/GpuIndex.h +28 -4
  104. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +2 -1
  105. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +10 -8
  106. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +75 -14
  107. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +19 -32
  108. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -31
  109. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +22 -28
  110. data/vendor/faiss/faiss/gpu/GpuResources.cpp +14 -0
  111. data/vendor/faiss/faiss/gpu/GpuResources.h +16 -3
  112. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +3 -3
  113. data/vendor/faiss/faiss/gpu/impl/IndexUtils.h +32 -0
  114. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +1 -0
  115. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +311 -75
  116. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +10 -0
  117. data/vendor/faiss/faiss/gpu/test/TestUtils.h +3 -0
  118. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +2 -2
  119. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +5 -4
  120. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +116 -47
  121. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +44 -13
  122. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +0 -54
  123. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +0 -76
  124. data/vendor/faiss/faiss/impl/DistanceComputer.h +64 -0
  125. data/vendor/faiss/faiss/impl/HNSW.cpp +123 -27
  126. data/vendor/faiss/faiss/impl/HNSW.h +19 -16
  127. data/vendor/faiss/faiss/impl/IDSelector.cpp +125 -0
  128. data/vendor/faiss/faiss/impl/IDSelector.h +135 -0
  129. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +6 -28
  130. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +6 -1
  131. data/vendor/faiss/faiss/impl/LookupTableScaler.h +77 -0
  132. data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -0
  133. data/vendor/faiss/faiss/impl/NSG.cpp +1 -1
  134. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +383 -0
  135. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +154 -0
  136. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +225 -145
  137. data/vendor/faiss/faiss/impl/ProductQuantizer.h +29 -10
  138. data/vendor/faiss/faiss/impl/Quantizer.h +43 -0
  139. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +192 -36
  140. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +40 -20
  141. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -0
  142. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +97 -173
  143. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +18 -18
  144. data/vendor/faiss/faiss/impl/index_read.cpp +240 -9
  145. data/vendor/faiss/faiss/impl/index_write.cpp +237 -5
  146. data/vendor/faiss/faiss/impl/kmeans1d.cpp +6 -4
  147. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +56 -16
  148. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +25 -8
  149. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +66 -25
  150. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +75 -27
  151. data/vendor/faiss/faiss/index_factory.cpp +196 -7
  152. data/vendor/faiss/faiss/index_io.h +5 -0
  153. data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -0
  154. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +4 -1
  155. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +2 -1
  156. data/vendor/faiss/faiss/python/python_callbacks.cpp +27 -0
  157. data/vendor/faiss/faiss/python/python_callbacks.h +15 -0
  158. data/vendor/faiss/faiss/utils/Heap.h +31 -15
  159. data/vendor/faiss/faiss/utils/distances.cpp +380 -56
  160. data/vendor/faiss/faiss/utils/distances.h +113 -15
  161. data/vendor/faiss/faiss/utils/distances_simd.cpp +726 -6
  162. data/vendor/faiss/faiss/utils/extra_distances.cpp +12 -7
  163. data/vendor/faiss/faiss/utils/extra_distances.h +3 -1
  164. data/vendor/faiss/faiss/utils/fp16-fp16c.h +21 -0
  165. data/vendor/faiss/faiss/utils/fp16-inl.h +101 -0
  166. data/vendor/faiss/faiss/utils/fp16.h +11 -0
  167. data/vendor/faiss/faiss/utils/hamming-inl.h +54 -0
  168. data/vendor/faiss/faiss/utils/hamming.cpp +0 -48
  169. data/vendor/faiss/faiss/utils/ordered_key_value.h +10 -0
  170. data/vendor/faiss/faiss/utils/quantize_lut.cpp +62 -0
  171. data/vendor/faiss/faiss/utils/quantize_lut.h +20 -0
  172. data/vendor/faiss/faiss/utils/random.cpp +53 -0
  173. data/vendor/faiss/faiss/utils/random.h +5 -0
  174. data/vendor/faiss/faiss/utils/simdlib_avx2.h +4 -0
  175. data/vendor/faiss/faiss/utils/simdlib_emulated.h +6 -1
  176. data/vendor/faiss/faiss/utils/simdlib_neon.h +7 -2
  177. metadata +37 -3
@@ -19,8 +19,11 @@
19
19
  #include <immintrin.h>
20
20
  #endif
21
21
 
22
+ #include <faiss/IndexIVF.h>
22
23
  #include <faiss/impl/AuxIndexStructures.h>
23
24
  #include <faiss/impl/FaissAssert.h>
25
+ #include <faiss/impl/IDSelector.h>
26
+ #include <faiss/utils/fp16.h>
24
27
  #include <faiss/utils/utils.h>
25
28
 
26
29
  namespace faiss {
@@ -201,114 +204,6 @@ struct Codec6bit {
201
204
  #endif
202
205
  };
203
206
 
204
- #ifdef USE_F16C
205
-
206
- uint16_t encode_fp16(float x) {
207
- __m128 xf = _mm_set1_ps(x);
208
- __m128i xi =
209
- _mm_cvtps_ph(xf, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
210
- return _mm_cvtsi128_si32(xi) & 0xffff;
211
- }
212
-
213
- float decode_fp16(uint16_t x) {
214
- __m128i xi = _mm_set1_epi16(x);
215
- __m128 xf = _mm_cvtph_ps(xi);
216
- return _mm_cvtss_f32(xf);
217
- }
218
-
219
- #else
220
-
221
- // non-intrinsic FP16 <-> FP32 code adapted from
222
- // https://github.com/ispc/ispc/blob/master/stdlib.ispc
223
-
224
- float floatbits(uint32_t x) {
225
- void* xptr = &x;
226
- return *(float*)xptr;
227
- }
228
-
229
- uint32_t intbits(float f) {
230
- void* fptr = &f;
231
- return *(uint32_t*)fptr;
232
- }
233
-
234
- uint16_t encode_fp16(float f) {
235
- // via Fabian "ryg" Giesen.
236
- // https://gist.github.com/2156668
237
- uint32_t sign_mask = 0x80000000u;
238
- int32_t o;
239
-
240
- uint32_t fint = intbits(f);
241
- uint32_t sign = fint & sign_mask;
242
- fint ^= sign;
243
-
244
- // NOTE all the integer compares in this function can be safely
245
- // compiled into signed compares since all operands are below
246
- // 0x80000000. Important if you want fast straight SSE2 code (since
247
- // there's no unsigned PCMPGTD).
248
-
249
- // Inf or NaN (all exponent bits set)
250
- // NaN->qNaN and Inf->Inf
251
- // unconditional assignment here, will override with right value for
252
- // the regular case below.
253
- uint32_t f32infty = 255u << 23;
254
- o = (fint > f32infty) ? 0x7e00u : 0x7c00u;
255
-
256
- // (De)normalized number or zero
257
- // update fint unconditionally to save the blending; we don't need it
258
- // anymore for the Inf/NaN case anyway.
259
-
260
- const uint32_t round_mask = ~0xfffu;
261
- const uint32_t magic = 15u << 23;
262
-
263
- // Shift exponent down, denormalize if necessary.
264
- // NOTE This represents half-float denormals using single
265
- // precision denormals. The main reason to do this is that
266
- // there's no shift with per-lane variable shifts in SSE*, which
267
- // we'd otherwise need. It has some funky side effects though:
268
- // - This conversion will actually respect the FTZ (Flush To Zero)
269
- // flag in MXCSR - if it's set, no half-float denormals will be
270
- // generated. I'm honestly not sure whether this is good or
271
- // bad. It's definitely interesting.
272
- // - If the underlying HW doesn't support denormals (not an issue
273
- // with Intel CPUs, but might be a problem on GPUs or PS3 SPUs),
274
- // you will always get flush-to-zero behavior. This is bad,
275
- // unless you're on a CPU where you don't care.
276
- // - Denormals tend to be slow. FP32 denormals are rare in
277
- // practice outside of things like recursive filters in DSP -
278
- // not a typical half-float application. Whether FP16 denormals
279
- // are rare in practice, I don't know. Whatever slow path your
280
- // HW may or may not have for denormals, this may well hit it.
281
- float fscale = floatbits(fint & round_mask) * floatbits(magic);
282
- fscale = std::min(fscale, floatbits((31u << 23) - 0x1000u));
283
- int32_t fint2 = intbits(fscale) - round_mask;
284
-
285
- if (fint < f32infty)
286
- o = fint2 >> 13; // Take the bits!
287
-
288
- return (o | (sign >> 16));
289
- }
290
-
291
- float decode_fp16(uint16_t h) {
292
- // https://gist.github.com/2144712
293
- // Fabian "ryg" Giesen.
294
-
295
- const uint32_t shifted_exp = 0x7c00u << 13; // exponent mask after shift
296
-
297
- int32_t o = ((int32_t)(h & 0x7fffu)) << 13; // exponent/mantissa bits
298
- int32_t exp = shifted_exp & o; // just the exponent
299
- o += (int32_t)(127 - 15) << 23; // exponent adjust
300
-
301
- int32_t infnan_val = o + ((int32_t)(128 - 16) << 23);
302
- int32_t zerodenorm_val =
303
- intbits(floatbits(o + (1u << 23)) - floatbits(113u << 23));
304
- int32_t reg_val = (exp == 0) ? zerodenorm_val : o;
305
-
306
- int32_t sign_bit = ((int32_t)(h & 0x8000u)) << 16;
307
- return floatbits(((exp == shifted_exp) ? infnan_val : reg_val) | sign_bit);
308
- }
309
-
310
- #endif
311
-
312
207
  /*******************************************************************
313
208
  * Quantizer: normalizes scalar vector components, then passes them
314
209
  * through a codec
@@ -318,7 +213,7 @@ template <class Codec, bool uniform, int SIMD>
318
213
  struct QuantizerTemplate {};
319
214
 
320
215
  template <class Codec>
321
- struct QuantizerTemplate<Codec, true, 1> : ScalarQuantizer::Quantizer {
216
+ struct QuantizerTemplate<Codec, true, 1> : ScalarQuantizer::SQuantizer {
322
217
  const size_t d;
323
218
  const float vmin, vdiff;
324
219
 
@@ -372,7 +267,7 @@ struct QuantizerTemplate<Codec, true, 8> : QuantizerTemplate<Codec, true, 1> {
372
267
  #endif
373
268
 
374
269
  template <class Codec>
375
- struct QuantizerTemplate<Codec, false, 1> : ScalarQuantizer::Quantizer {
270
+ struct QuantizerTemplate<Codec, false, 1> : ScalarQuantizer::SQuantizer {
376
271
  const size_t d;
377
272
  const float *vmin, *vdiff;
378
273
 
@@ -433,7 +328,7 @@ template <int SIMDWIDTH>
433
328
  struct QuantizerFP16 {};
434
329
 
435
330
  template <>
436
- struct QuantizerFP16<1> : ScalarQuantizer::Quantizer {
331
+ struct QuantizerFP16<1> : ScalarQuantizer::SQuantizer {
437
332
  const size_t d;
438
333
 
439
334
  QuantizerFP16(size_t d, const std::vector<float>& /* unused */) : d(d) {}
@@ -478,7 +373,7 @@ template <int SIMDWIDTH>
478
373
  struct Quantizer8bitDirect {};
479
374
 
480
375
  template <>
481
- struct Quantizer8bitDirect<1> : ScalarQuantizer::Quantizer {
376
+ struct Quantizer8bitDirect<1> : ScalarQuantizer::SQuantizer {
482
377
  const size_t d;
483
378
 
484
379
  Quantizer8bitDirect(size_t d, const std::vector<float>& /* unused */)
@@ -518,7 +413,7 @@ struct Quantizer8bitDirect<8> : Quantizer8bitDirect<1> {
518
413
  #endif
519
414
 
520
415
  template <int SIMDWIDTH>
521
- ScalarQuantizer::Quantizer* select_quantizer_1(
416
+ ScalarQuantizer::SQuantizer* select_quantizer_1(
522
417
  QuantizerType qtype,
523
418
  size_t d,
524
419
  const std::vector<float>& trained) {
@@ -911,11 +806,6 @@ struct DCTemplate<Quantizer, Similarity, 1> : SQDistanceComputer {
911
806
  q = x;
912
807
  }
913
808
 
914
- /// compute distance of vector i to current query
915
- float operator()(idx_t i) final {
916
- return query_to_code(codes + i * code_size);
917
- }
918
-
919
809
  float symmetric_dis(idx_t i, idx_t j) override {
920
810
  return compute_code_distance(
921
811
  codes + i * code_size, codes + j * code_size);
@@ -963,11 +853,6 @@ struct DCTemplate<Quantizer, Similarity, 8> : SQDistanceComputer {
963
853
  q = x;
964
854
  }
965
855
 
966
- /// compute distance of vector i to current query
967
- float operator()(idx_t i) final {
968
- return query_to_code(codes + i * code_size);
969
- }
970
-
971
856
  float symmetric_dis(idx_t i, idx_t j) override {
972
857
  return compute_code_distance(
973
858
  codes + i * code_size, codes + j * code_size);
@@ -1021,11 +906,6 @@ struct DistanceComputerByte<Similarity, 1> : SQDistanceComputer {
1021
906
  return compute_code_distance(tmp.data(), code);
1022
907
  }
1023
908
 
1024
- /// compute distance of vector i to current query
1025
- float operator()(idx_t i) final {
1026
- return query_to_code(codes + i * code_size);
1027
- }
1028
-
1029
909
  float symmetric_dis(idx_t i, idx_t j) override {
1030
910
  return compute_code_distance(
1031
911
  codes + i * code_size, codes + j * code_size);
@@ -1089,11 +969,6 @@ struct DistanceComputerByte<Similarity, 8> : SQDistanceComputer {
1089
969
  return compute_code_distance(tmp.data(), code);
1090
970
  }
1091
971
 
1092
- /// compute distance of vector i to current query
1093
- float operator()(idx_t i) final {
1094
- return query_to_code(codes + i * code_size);
1095
- }
1096
-
1097
972
  float symmetric_dis(idx_t i, idx_t j) override {
1098
973
  return compute_code_distance(
1099
974
  codes + i * code_size, codes + j * code_size);
@@ -1173,17 +1048,12 @@ SQDistanceComputer* select_distance_computer(
1173
1048
  ********************************************************************/
1174
1049
 
1175
1050
  ScalarQuantizer::ScalarQuantizer(size_t d, QuantizerType qtype)
1176
- : qtype(qtype), rangestat(RS_minmax), rangestat_arg(0), d(d) {
1051
+ : Quantizer(d), qtype(qtype), rangestat(RS_minmax), rangestat_arg(0) {
1177
1052
  set_derived_sizes();
1178
1053
  }
1179
1054
 
1180
1055
  ScalarQuantizer::ScalarQuantizer()
1181
- : qtype(QT_8bit),
1182
- rangestat(RS_minmax),
1183
- rangestat_arg(0),
1184
- d(0),
1185
- bits(0),
1186
- code_size(0) {}
1056
+ : qtype(QT_8bit), rangestat(RS_minmax), rangestat_arg(0), bits(0) {}
1187
1057
 
1188
1058
  void ScalarQuantizer::set_derived_sizes() {
1189
1059
  switch (qtype) {
@@ -1273,7 +1143,7 @@ void ScalarQuantizer::train_residual(
1273
1143
  }
1274
1144
  }
1275
1145
 
1276
- ScalarQuantizer::Quantizer* ScalarQuantizer::select_quantizer() const {
1146
+ ScalarQuantizer::SQuantizer* ScalarQuantizer::select_quantizer() const {
1277
1147
  #ifdef USE_F16C
1278
1148
  if (d % 8 == 0) {
1279
1149
  return select_quantizer_1<8>(qtype, d, trained);
@@ -1286,7 +1156,7 @@ ScalarQuantizer::Quantizer* ScalarQuantizer::select_quantizer() const {
1286
1156
 
1287
1157
  void ScalarQuantizer::compute_codes(const float* x, uint8_t* codes, size_t n)
1288
1158
  const {
1289
- std::unique_ptr<Quantizer> squant(select_quantizer());
1159
+ std::unique_ptr<SQuantizer> squant(select_quantizer());
1290
1160
 
1291
1161
  memset(codes, 0, code_size * n);
1292
1162
  #pragma omp parallel for
@@ -1295,7 +1165,7 @@ void ScalarQuantizer::compute_codes(const float* x, uint8_t* codes, size_t n)
1295
1165
  }
1296
1166
 
1297
1167
  void ScalarQuantizer::decode(const uint8_t* codes, float* x, size_t n) const {
1298
- std::unique_ptr<Quantizer> squant(select_quantizer());
1168
+ std::unique_ptr<SQuantizer> squant(select_quantizer());
1299
1169
 
1300
1170
  #pragma omp parallel for
1301
1171
  for (int64_t i = 0; i < n; i++)
@@ -1332,10 +1202,11 @@ SQDistanceComputer* ScalarQuantizer::get_distance_computer(
1332
1202
 
1333
1203
  namespace {
1334
1204
 
1335
- template <class DCClass>
1205
+ template <class DCClass, int use_sel>
1336
1206
  struct IVFSQScannerIP : InvertedListScanner {
1337
1207
  DCClass dc;
1338
1208
  bool by_residual;
1209
+ const IDSelector* sel;
1339
1210
 
1340
1211
  float accu0; /// added to all distances
1341
1212
 
@@ -1344,8 +1215,9 @@ struct IVFSQScannerIP : InvertedListScanner {
1344
1215
  const std::vector<float>& trained,
1345
1216
  size_t code_size,
1346
1217
  bool store_pairs,
1218
+ const IDSelector* sel,
1347
1219
  bool by_residual)
1348
- : dc(d, trained), by_residual(by_residual), accu0(0) {
1220
+ : dc(d, trained), by_residual(by_residual), sel(sel), accu0(0) {
1349
1221
  this->store_pairs = store_pairs;
1350
1222
  this->code_size = code_size;
1351
1223
  }
@@ -1372,7 +1244,11 @@ struct IVFSQScannerIP : InvertedListScanner {
1372
1244
  size_t k) const override {
1373
1245
  size_t nup = 0;
1374
1246
 
1375
- for (size_t j = 0; j < list_size; j++) {
1247
+ for (size_t j = 0; j < list_size; j++, codes += code_size) {
1248
+ if (use_sel && !sel->is_member(use_sel == 1 ? ids[j] : j)) {
1249
+ continue;
1250
+ }
1251
+
1376
1252
  float accu = accu0 + dc.query_to_code(codes);
1377
1253
 
1378
1254
  if (accu > simi[0]) {
@@ -1380,7 +1256,6 @@ struct IVFSQScannerIP : InvertedListScanner {
1380
1256
  minheap_replace_top(k, simi, idxi, accu, id);
1381
1257
  nup++;
1382
1258
  }
1383
- codes += code_size;
1384
1259
  }
1385
1260
  return nup;
1386
1261
  }
@@ -1391,23 +1266,31 @@ struct IVFSQScannerIP : InvertedListScanner {
1391
1266
  const idx_t* ids,
1392
1267
  float radius,
1393
1268
  RangeQueryResult& res) const override {
1394
- for (size_t j = 0; j < list_size; j++) {
1269
+ for (size_t j = 0; j < list_size; j++, codes += code_size) {
1270
+ if (use_sel && !sel->is_member(use_sel == 1 ? ids[j] : j)) {
1271
+ continue;
1272
+ }
1273
+
1395
1274
  float accu = accu0 + dc.query_to_code(codes);
1396
1275
  if (accu > radius) {
1397
1276
  int64_t id = store_pairs ? (list_no << 32 | j) : ids[j];
1398
1277
  res.add(accu, id);
1399
1278
  }
1400
- codes += code_size;
1401
1279
  }
1402
1280
  }
1403
1281
  };
1404
1282
 
1405
- template <class DCClass>
1283
+ /* use_sel = 0: don't check selector
1284
+ * = 1: check on ids[j]
1285
+ * = 2: check in j directly (normally ids is nullptr and store_pairs)
1286
+ */
1287
+ template <class DCClass, int use_sel>
1406
1288
  struct IVFSQScannerL2 : InvertedListScanner {
1407
1289
  DCClass dc;
1408
1290
 
1409
1291
  bool by_residual;
1410
1292
  const Index* quantizer;
1293
+ const IDSelector* sel;
1411
1294
  const float* x; /// current query
1412
1295
 
1413
1296
  std::vector<float> tmp;
@@ -1418,10 +1301,12 @@ struct IVFSQScannerL2 : InvertedListScanner {
1418
1301
  size_t code_size,
1419
1302
  const Index* quantizer,
1420
1303
  bool store_pairs,
1304
+ const IDSelector* sel,
1421
1305
  bool by_residual)
1422
1306
  : dc(d, trained),
1423
1307
  by_residual(by_residual),
1424
1308
  quantizer(quantizer),
1309
+ sel(sel),
1425
1310
  x(nullptr),
1426
1311
  tmp(d) {
1427
1312
  this->store_pairs = store_pairs;
@@ -1458,7 +1343,11 @@ struct IVFSQScannerL2 : InvertedListScanner {
1458
1343
  idx_t* idxi,
1459
1344
  size_t k) const override {
1460
1345
  size_t nup = 0;
1461
- for (size_t j = 0; j < list_size; j++) {
1346
+ for (size_t j = 0; j < list_size; j++, codes += code_size) {
1347
+ if (use_sel && !sel->is_member(use_sel == 1 ? ids[j] : j)) {
1348
+ continue;
1349
+ }
1350
+
1462
1351
  float dis = dc.query_to_code(codes);
1463
1352
 
1464
1353
  if (dis < simi[0]) {
@@ -1466,7 +1355,6 @@ struct IVFSQScannerL2 : InvertedListScanner {
1466
1355
  maxheap_replace_top(k, simi, idxi, dis, id);
1467
1356
  nup++;
1468
1357
  }
1469
- codes += code_size;
1470
1358
  }
1471
1359
  return nup;
1472
1360
  }
@@ -1477,44 +1365,77 @@ struct IVFSQScannerL2 : InvertedListScanner {
1477
1365
  const idx_t* ids,
1478
1366
  float radius,
1479
1367
  RangeQueryResult& res) const override {
1480
- for (size_t j = 0; j < list_size; j++) {
1368
+ for (size_t j = 0; j < list_size; j++, codes += code_size) {
1369
+ if (use_sel && !sel->is_member(use_sel == 1 ? ids[j] : j)) {
1370
+ continue;
1371
+ }
1372
+
1481
1373
  float dis = dc.query_to_code(codes);
1482
1374
  if (dis < radius) {
1483
1375
  int64_t id = store_pairs ? (list_no << 32 | j) : ids[j];
1484
1376
  res.add(dis, id);
1485
1377
  }
1486
- codes += code_size;
1487
1378
  }
1488
1379
  }
1489
1380
  };
1490
1381
 
1491
- template <class DCClass>
1492
- InvertedListScanner* sel2_InvertedListScanner(
1382
+ template <class DCClass, int use_sel>
1383
+ InvertedListScanner* sel3_InvertedListScanner(
1493
1384
  const ScalarQuantizer* sq,
1494
1385
  const Index* quantizer,
1495
1386
  bool store_pairs,
1387
+ const IDSelector* sel,
1496
1388
  bool r) {
1497
1389
  if (DCClass::Sim::metric_type == METRIC_L2) {
1498
- return new IVFSQScannerL2<DCClass>(
1499
- sq->d, sq->trained, sq->code_size, quantizer, store_pairs, r);
1390
+ return new IVFSQScannerL2<DCClass, use_sel>(
1391
+ sq->d,
1392
+ sq->trained,
1393
+ sq->code_size,
1394
+ quantizer,
1395
+ store_pairs,
1396
+ sel,
1397
+ r);
1500
1398
  } else if (DCClass::Sim::metric_type == METRIC_INNER_PRODUCT) {
1501
- return new IVFSQScannerIP<DCClass>(
1502
- sq->d, sq->trained, sq->code_size, store_pairs, r);
1399
+ return new IVFSQScannerIP<DCClass, use_sel>(
1400
+ sq->d, sq->trained, sq->code_size, store_pairs, sel, r);
1503
1401
  } else {
1504
1402
  FAISS_THROW_MSG("unsupported metric type");
1505
1403
  }
1506
1404
  }
1507
1405
 
1406
+ template <class DCClass>
1407
+ InvertedListScanner* sel2_InvertedListScanner(
1408
+ const ScalarQuantizer* sq,
1409
+ const Index* quantizer,
1410
+ bool store_pairs,
1411
+ const IDSelector* sel,
1412
+ bool r) {
1413
+ if (sel) {
1414
+ if (store_pairs) {
1415
+ return sel3_InvertedListScanner<DCClass, 2>(
1416
+ sq, quantizer, store_pairs, sel, r);
1417
+ } else {
1418
+ return sel3_InvertedListScanner<DCClass, 1>(
1419
+ sq, quantizer, store_pairs, sel, r);
1420
+ }
1421
+ } else {
1422
+ return sel3_InvertedListScanner<DCClass, 0>(
1423
+ sq, quantizer, store_pairs, sel, r);
1424
+ }
1425
+ }
1426
+
1508
1427
  template <class Similarity, class Codec, bool uniform>
1509
1428
  InvertedListScanner* sel12_InvertedListScanner(
1510
1429
  const ScalarQuantizer* sq,
1511
1430
  const Index* quantizer,
1512
1431
  bool store_pairs,
1432
+ const IDSelector* sel,
1513
1433
  bool r) {
1514
1434
  constexpr int SIMDWIDTH = Similarity::simdwidth;
1515
1435
  using QuantizerClass = QuantizerTemplate<Codec, uniform, SIMDWIDTH>;
1516
1436
  using DCClass = DCTemplate<QuantizerClass, Similarity, SIMDWIDTH>;
1517
- return sel2_InvertedListScanner<DCClass>(sq, quantizer, store_pairs, r);
1437
+ return sel2_InvertedListScanner<DCClass>(
1438
+ sq, quantizer, store_pairs, sel, r);
1518
1439
  }
1519
1440
 
1520
1441
  template <class Similarity>
@@ -1522,39 +1443,40 @@ InvertedListScanner* sel1_InvertedListScanner(
1522
1443
  const ScalarQuantizer* sq,
1523
1444
  const Index* quantizer,
1524
1445
  bool store_pairs,
1446
+ const IDSelector* sel,
1525
1447
  bool r) {
1526
1448
  constexpr int SIMDWIDTH = Similarity::simdwidth;
1527
1449
  switch (sq->qtype) {
1528
1450
  case ScalarQuantizer::QT_8bit_uniform:
1529
1451
  return sel12_InvertedListScanner<Similarity, Codec8bit, true>(
1530
- sq, quantizer, store_pairs, r);
1452
+ sq, quantizer, store_pairs, sel, r);
1531
1453
  case ScalarQuantizer::QT_4bit_uniform:
1532
1454
  return sel12_InvertedListScanner<Similarity, Codec4bit, true>(
1533
- sq, quantizer, store_pairs, r);
1455
+ sq, quantizer, store_pairs, sel, r);
1534
1456
  case ScalarQuantizer::QT_8bit:
1535
1457
  return sel12_InvertedListScanner<Similarity, Codec8bit, false>(
1536
- sq, quantizer, store_pairs, r);
1458
+ sq, quantizer, store_pairs, sel, r);
1537
1459
  case ScalarQuantizer::QT_4bit:
1538
1460
  return sel12_InvertedListScanner<Similarity, Codec4bit, false>(
1539
- sq, quantizer, store_pairs, r);
1461
+ sq, quantizer, store_pairs, sel, r);
1540
1462
  case ScalarQuantizer::QT_6bit:
1541
1463
  return sel12_InvertedListScanner<Similarity, Codec6bit, false>(
1542
- sq, quantizer, store_pairs, r);
1464
+ sq, quantizer, store_pairs, sel, r);
1543
1465
  case ScalarQuantizer::QT_fp16:
1544
1466
  return sel2_InvertedListScanner<DCTemplate<
1545
1467
  QuantizerFP16<SIMDWIDTH>,
1546
1468
  Similarity,
1547
- SIMDWIDTH>>(sq, quantizer, store_pairs, r);
1469
+ SIMDWIDTH>>(sq, quantizer, store_pairs, sel, r);
1548
1470
  case ScalarQuantizer::QT_8bit_direct:
1549
1471
  if (sq->d % 16 == 0) {
1550
1472
  return sel2_InvertedListScanner<
1551
1473
  DistanceComputerByte<Similarity, SIMDWIDTH>>(
1552
- sq, quantizer, store_pairs, r);
1474
+ sq, quantizer, store_pairs, sel, r);
1553
1475
  } else {
1554
1476
  return sel2_InvertedListScanner<DCTemplate<
1555
1477
  Quantizer8bitDirect<SIMDWIDTH>,
1556
1478
  Similarity,
1557
- SIMDWIDTH>>(sq, quantizer, store_pairs, r);
1479
+ SIMDWIDTH>>(sq, quantizer, store_pairs, sel, r);
1558
1480
  }
1559
1481
  }
1560
1482
 
@@ -1568,13 +1490,14 @@ InvertedListScanner* sel0_InvertedListScanner(
1568
1490
  const ScalarQuantizer* sq,
1569
1491
  const Index* quantizer,
1570
1492
  bool store_pairs,
1493
+ const IDSelector* sel,
1571
1494
  bool by_residual) {
1572
1495
  if (mt == METRIC_L2) {
1573
1496
  return sel1_InvertedListScanner<SimilarityL2<SIMDWIDTH>>(
1574
- sq, quantizer, store_pairs, by_residual);
1497
+ sq, quantizer, store_pairs, sel, by_residual);
1575
1498
  } else if (mt == METRIC_INNER_PRODUCT) {
1576
1499
  return sel1_InvertedListScanner<SimilarityIP<SIMDWIDTH>>(
1577
- sq, quantizer, store_pairs, by_residual);
1500
+ sq, quantizer, store_pairs, sel, by_residual);
1578
1501
  } else {
1579
1502
  FAISS_THROW_MSG("unsupported metric type");
1580
1503
  }
@@ -1586,16 +1509,17 @@ InvertedListScanner* ScalarQuantizer::select_InvertedListScanner(
1586
1509
  MetricType mt,
1587
1510
  const Index* quantizer,
1588
1511
  bool store_pairs,
1512
+ const IDSelector* sel,
1589
1513
  bool by_residual) const {
1590
1514
  #ifdef USE_F16C
1591
1515
  if (d % 8 == 0) {
1592
1516
  return sel0_InvertedListScanner<8>(
1593
- mt, this, quantizer, store_pairs, by_residual);
1517
+ mt, this, quantizer, store_pairs, sel, by_residual);
1594
1518
  } else
1595
1519
  #endif
1596
1520
  {
1597
1521
  return sel0_InvertedListScanner<1>(
1598
- mt, this, quantizer, store_pairs, by_residual);
1522
+ mt, this, quantizer, store_pairs, sel, by_residual);
1599
1523
  }
1600
1524
  }
1601
1525
 
@@ -9,18 +9,21 @@
9
9
 
10
10
  #pragma once
11
11
 
12
- #include <faiss/IndexIVF.h>
13
12
  #include <faiss/impl/AuxIndexStructures.h>
13
+ #include <faiss/impl/DistanceComputer.h>
14
+ #include <faiss/impl/Quantizer.h>
14
15
 
15
16
  namespace faiss {
16
17
 
18
+ struct InvertedListScanner;
19
+
17
20
  /**
18
21
  * The uniform quantizer has a range [vmin, vmax]. The range can be
19
22
  * the same for all dimensions (uniform) or specific per dimension
20
23
  * (default).
21
24
  */
22
25
 
23
- struct ScalarQuantizer {
26
+ struct ScalarQuantizer : Quantizer {
24
27
  enum QuantizerType {
25
28
  QT_8bit, ///< 8 bits per component
26
29
  QT_4bit, ///< 4 bits per component
@@ -48,15 +51,9 @@ struct ScalarQuantizer {
48
51
  RangeStat rangestat;
49
52
  float rangestat_arg;
50
53
 
51
- /// dimension of input vectors
52
- size_t d;
53
-
54
54
  /// bits per scalar code
55
55
  size_t bits;
56
56
 
57
- /// bytes per vector
58
- size_t code_size;
59
-
60
57
  /// trained values (including the range)
61
58
  std::vector<float> trained;
62
59
 
@@ -66,7 +63,7 @@ struct ScalarQuantizer {
66
63
  /// updates internal values based on qtype and d
67
64
  void set_derived_sizes();
68
65
 
69
- void train(size_t n, const float* x);
66
+ void train(size_t n, const float* x) override;
70
67
 
71
68
  /// Used by an IVF index to train based on the residuals
72
69
  void train_residual(
@@ -81,38 +78,40 @@ struct ScalarQuantizer {
81
78
  * @param x vectors to encode, size n * d
82
79
  * @param codes output codes, size n * code_size
83
80
  */
84
- void compute_codes(const float* x, uint8_t* codes, size_t n) const;
81
+ void compute_codes(const float* x, uint8_t* codes, size_t n) const override;
85
82
 
86
83
  /** Decode a set of vectors
87
84
  *
88
85
  * @param codes codes to decode, size n * code_size
89
86
  * @param x output vectors, size n * d
90
87
  */
91
- void decode(const uint8_t* code, float* x, size_t n) const;
88
+ void decode(const uint8_t* code, float* x, size_t n) const override;
92
89
 
93
90
  /*****************************************************
94
91
  * Objects that provide methods for encoding/decoding, distance
95
92
  * computation and inverted list scanning
96
93
  *****************************************************/
97
94
 
98
- struct Quantizer {
95
+ struct SQuantizer {
99
96
  // encodes one vector. Assumes code is filled with 0s on input!
100
97
  virtual void encode_vector(const float* x, uint8_t* code) const = 0;
101
98
  virtual void decode_vector(const uint8_t* code, float* x) const = 0;
102
99
 
103
- virtual ~Quantizer() {}
100
+ virtual ~SQuantizer() {}
104
101
  };
105
102
 
106
- Quantizer* select_quantizer() const;
103
+ SQuantizer* select_quantizer() const;
107
104
 
108
- struct SQDistanceComputer : DistanceComputer {
105
+ struct SQDistanceComputer : FlatCodesDistanceComputer {
109
106
  const float* q;
110
- const uint8_t* codes;
111
- size_t code_size;
112
107
 
113
- SQDistanceComputer() : q(nullptr), codes(nullptr), code_size(0) {}
108
+ SQDistanceComputer() : q(nullptr) {}
114
109
 
115
110
  virtual float query_to_code(const uint8_t* code) const = 0;
111
+
112
+ float distance_to_code(const uint8_t* code) final {
113
+ return query_to_code(code);
114
+ }
116
115
  };
117
116
 
118
117
  SQDistanceComputer* get_distance_computer(
@@ -122,6 +121,7 @@ struct ScalarQuantizer {
122
121
  MetricType mt,
123
122
  const Index* quantizer,
124
123
  bool store_pairs,
124
+ const IDSelector* sel,
125
125
  bool by_residual = false) const;
126
126
  };
127
127