faiss 0.2.4 → 0.2.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (177)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +5 -0
  3. data/README.md +23 -21
  4. data/ext/faiss/extconf.rb +11 -0
  5. data/ext/faiss/index.cpp +4 -4
  6. data/ext/faiss/index_binary.cpp +6 -6
  7. data/ext/faiss/product_quantizer.cpp +4 -4
  8. data/lib/faiss/version.rb +1 -1
  9. data/vendor/faiss/faiss/AutoTune.cpp +13 -0
  10. data/vendor/faiss/faiss/IVFlib.cpp +101 -2
  11. data/vendor/faiss/faiss/IVFlib.h +26 -2
  12. data/vendor/faiss/faiss/Index.cpp +36 -3
  13. data/vendor/faiss/faiss/Index.h +43 -6
  14. data/vendor/faiss/faiss/Index2Layer.cpp +6 -2
  15. data/vendor/faiss/faiss/Index2Layer.h +6 -1
  16. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +219 -16
  17. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +63 -5
  18. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +299 -0
  19. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +199 -0
  20. data/vendor/faiss/faiss/IndexBinary.cpp +20 -4
  21. data/vendor/faiss/faiss/IndexBinary.h +18 -3
  22. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +9 -2
  23. data/vendor/faiss/faiss/IndexBinaryFlat.h +4 -2
  24. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +4 -1
  25. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +2 -1
  26. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +5 -1
  27. data/vendor/faiss/faiss/IndexBinaryHNSW.h +2 -1
  28. data/vendor/faiss/faiss/IndexBinaryHash.cpp +17 -4
  29. data/vendor/faiss/faiss/IndexBinaryHash.h +8 -4
  30. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +28 -13
  31. data/vendor/faiss/faiss/IndexBinaryIVF.h +10 -7
  32. data/vendor/faiss/faiss/IndexFastScan.cpp +626 -0
  33. data/vendor/faiss/faiss/IndexFastScan.h +145 -0
  34. data/vendor/faiss/faiss/IndexFlat.cpp +34 -21
  35. data/vendor/faiss/faiss/IndexFlat.h +7 -4
  36. data/vendor/faiss/faiss/IndexFlatCodes.cpp +35 -1
  37. data/vendor/faiss/faiss/IndexFlatCodes.h +12 -0
  38. data/vendor/faiss/faiss/IndexHNSW.cpp +66 -138
  39. data/vendor/faiss/faiss/IndexHNSW.h +4 -2
  40. data/vendor/faiss/faiss/IndexIDMap.cpp +247 -0
  41. data/vendor/faiss/faiss/IndexIDMap.h +107 -0
  42. data/vendor/faiss/faiss/IndexIVF.cpp +121 -33
  43. data/vendor/faiss/faiss/IndexIVF.h +35 -16
  44. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +84 -7
  45. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +63 -1
  46. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +590 -0
  47. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +171 -0
  48. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +1290 -0
  49. data/vendor/faiss/faiss/IndexIVFFastScan.h +213 -0
  50. data/vendor/faiss/faiss/IndexIVFFlat.cpp +37 -17
  51. data/vendor/faiss/faiss/IndexIVFFlat.h +4 -2
  52. data/vendor/faiss/faiss/IndexIVFPQ.cpp +234 -50
  53. data/vendor/faiss/faiss/IndexIVFPQ.h +5 -1
  54. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +23 -852
  55. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +7 -112
  56. data/vendor/faiss/faiss/IndexIVFPQR.cpp +3 -3
  57. data/vendor/faiss/faiss/IndexIVFPQR.h +1 -1
  58. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +3 -1
  59. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +2 -1
  60. data/vendor/faiss/faiss/IndexLSH.cpp +4 -2
  61. data/vendor/faiss/faiss/IndexLSH.h +2 -1
  62. data/vendor/faiss/faiss/IndexLattice.cpp +7 -1
  63. data/vendor/faiss/faiss/IndexLattice.h +3 -1
  64. data/vendor/faiss/faiss/IndexNNDescent.cpp +4 -3
  65. data/vendor/faiss/faiss/IndexNNDescent.h +2 -1
  66. data/vendor/faiss/faiss/IndexNSG.cpp +37 -3
  67. data/vendor/faiss/faiss/IndexNSG.h +25 -1
  68. data/vendor/faiss/faiss/IndexPQ.cpp +106 -69
  69. data/vendor/faiss/faiss/IndexPQ.h +19 -5
  70. data/vendor/faiss/faiss/IndexPQFastScan.cpp +15 -450
  71. data/vendor/faiss/faiss/IndexPQFastScan.h +15 -78
  72. data/vendor/faiss/faiss/IndexPreTransform.cpp +47 -8
  73. data/vendor/faiss/faiss/IndexPreTransform.h +15 -3
  74. data/vendor/faiss/faiss/IndexRefine.cpp +8 -4
  75. data/vendor/faiss/faiss/IndexRefine.h +4 -2
  76. data/vendor/faiss/faiss/IndexReplicas.cpp +4 -2
  77. data/vendor/faiss/faiss/IndexReplicas.h +2 -1
  78. data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +438 -0
  79. data/vendor/faiss/faiss/IndexRowwiseMinMax.h +92 -0
  80. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +26 -15
  81. data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -7
  82. data/vendor/faiss/faiss/IndexShards.cpp +4 -1
  83. data/vendor/faiss/faiss/IndexShards.h +2 -1
  84. data/vendor/faiss/faiss/MetaIndexes.cpp +5 -178
  85. data/vendor/faiss/faiss/MetaIndexes.h +3 -81
  86. data/vendor/faiss/faiss/VectorTransform.cpp +43 -0
  87. data/vendor/faiss/faiss/VectorTransform.h +22 -4
  88. data/vendor/faiss/faiss/clone_index.cpp +23 -1
  89. data/vendor/faiss/faiss/clone_index.h +3 -0
  90. data/vendor/faiss/faiss/cppcontrib/SaDecodeKernels.h +300 -0
  91. data/vendor/faiss/faiss/cppcontrib/detail/CoarseBitType.h +24 -0
  92. data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +195 -0
  93. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +2058 -0
  94. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +408 -0
  95. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-neon-inl.h +2147 -0
  96. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMax-inl.h +460 -0
  97. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMaxFP16-inl.h +465 -0
  98. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +1618 -0
  99. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +251 -0
  100. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-neon-inl.h +1452 -0
  101. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +1 -0
  102. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +0 -4
  103. data/vendor/faiss/faiss/gpu/GpuIndex.h +28 -4
  104. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +2 -1
  105. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +10 -8
  106. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +75 -14
  107. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +19 -32
  108. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -31
  109. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +22 -28
  110. data/vendor/faiss/faiss/gpu/GpuResources.cpp +14 -0
  111. data/vendor/faiss/faiss/gpu/GpuResources.h +16 -3
  112. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +3 -3
  113. data/vendor/faiss/faiss/gpu/impl/IndexUtils.h +32 -0
  114. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +1 -0
  115. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +311 -75
  116. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +10 -0
  117. data/vendor/faiss/faiss/gpu/test/TestUtils.h +3 -0
  118. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +2 -2
  119. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +5 -4
  120. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +116 -47
  121. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +44 -13
  122. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +0 -54
  123. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +0 -76
  124. data/vendor/faiss/faiss/impl/DistanceComputer.h +64 -0
  125. data/vendor/faiss/faiss/impl/HNSW.cpp +123 -27
  126. data/vendor/faiss/faiss/impl/HNSW.h +19 -16
  127. data/vendor/faiss/faiss/impl/IDSelector.cpp +125 -0
  128. data/vendor/faiss/faiss/impl/IDSelector.h +135 -0
  129. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +6 -28
  130. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +6 -1
  131. data/vendor/faiss/faiss/impl/LookupTableScaler.h +77 -0
  132. data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -0
  133. data/vendor/faiss/faiss/impl/NSG.cpp +1 -1
  134. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +383 -0
  135. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +154 -0
  136. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +225 -145
  137. data/vendor/faiss/faiss/impl/ProductQuantizer.h +29 -10
  138. data/vendor/faiss/faiss/impl/Quantizer.h +43 -0
  139. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +192 -36
  140. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +40 -20
  141. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -0
  142. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +97 -173
  143. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +18 -18
  144. data/vendor/faiss/faiss/impl/index_read.cpp +240 -9
  145. data/vendor/faiss/faiss/impl/index_write.cpp +237 -5
  146. data/vendor/faiss/faiss/impl/kmeans1d.cpp +6 -4
  147. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +56 -16
  148. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +25 -8
  149. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +66 -25
  150. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +75 -27
  151. data/vendor/faiss/faiss/index_factory.cpp +196 -7
  152. data/vendor/faiss/faiss/index_io.h +5 -0
  153. data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -0
  154. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +4 -1
  155. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +2 -1
  156. data/vendor/faiss/faiss/python/python_callbacks.cpp +27 -0
  157. data/vendor/faiss/faiss/python/python_callbacks.h +15 -0
  158. data/vendor/faiss/faiss/utils/Heap.h +31 -15
  159. data/vendor/faiss/faiss/utils/distances.cpp +380 -56
  160. data/vendor/faiss/faiss/utils/distances.h +113 -15
  161. data/vendor/faiss/faiss/utils/distances_simd.cpp +726 -6
  162. data/vendor/faiss/faiss/utils/extra_distances.cpp +12 -7
  163. data/vendor/faiss/faiss/utils/extra_distances.h +3 -1
  164. data/vendor/faiss/faiss/utils/fp16-fp16c.h +21 -0
  165. data/vendor/faiss/faiss/utils/fp16-inl.h +101 -0
  166. data/vendor/faiss/faiss/utils/fp16.h +11 -0
  167. data/vendor/faiss/faiss/utils/hamming-inl.h +54 -0
  168. data/vendor/faiss/faiss/utils/hamming.cpp +0 -48
  169. data/vendor/faiss/faiss/utils/ordered_key_value.h +10 -0
  170. data/vendor/faiss/faiss/utils/quantize_lut.cpp +62 -0
  171. data/vendor/faiss/faiss/utils/quantize_lut.h +20 -0
  172. data/vendor/faiss/faiss/utils/random.cpp +53 -0
  173. data/vendor/faiss/faiss/utils/random.h +5 -0
  174. data/vendor/faiss/faiss/utils/simdlib_avx2.h +4 -0
  175. data/vendor/faiss/faiss/utils/simdlib_emulated.h +6 -1
  176. data/vendor/faiss/faiss/utils/simdlib_neon.h +7 -2
  177. metadata +37 -3
@@ -19,8 +19,11 @@
19
19
  #include <immintrin.h>
20
20
  #endif
21
21
 
22
+ #include <faiss/IndexIVF.h>
22
23
  #include <faiss/impl/AuxIndexStructures.h>
23
24
  #include <faiss/impl/FaissAssert.h>
25
+ #include <faiss/impl/IDSelector.h>
26
+ #include <faiss/utils/fp16.h>
24
27
  #include <faiss/utils/utils.h>
25
28
 
26
29
  namespace faiss {
@@ -201,114 +204,6 @@ struct Codec6bit {
201
204
  #endif
202
205
  };
203
206
 
204
- #ifdef USE_F16C
205
-
206
- uint16_t encode_fp16(float x) {
207
- __m128 xf = _mm_set1_ps(x);
208
- __m128i xi =
209
- _mm_cvtps_ph(xf, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
210
- return _mm_cvtsi128_si32(xi) & 0xffff;
211
- }
212
-
213
- float decode_fp16(uint16_t x) {
214
- __m128i xi = _mm_set1_epi16(x);
215
- __m128 xf = _mm_cvtph_ps(xi);
216
- return _mm_cvtss_f32(xf);
217
- }
218
-
219
- #else
220
-
221
- // non-intrinsic FP16 <-> FP32 code adapted from
222
- // https://github.com/ispc/ispc/blob/master/stdlib.ispc
223
-
224
- float floatbits(uint32_t x) {
225
- void* xptr = &x;
226
- return *(float*)xptr;
227
- }
228
-
229
- uint32_t intbits(float f) {
230
- void* fptr = &f;
231
- return *(uint32_t*)fptr;
232
- }
233
-
234
- uint16_t encode_fp16(float f) {
235
- // via Fabian "ryg" Giesen.
236
- // https://gist.github.com/2156668
237
- uint32_t sign_mask = 0x80000000u;
238
- int32_t o;
239
-
240
- uint32_t fint = intbits(f);
241
- uint32_t sign = fint & sign_mask;
242
- fint ^= sign;
243
-
244
- // NOTE all the integer compares in this function can be safely
245
- // compiled into signed compares since all operands are below
246
- // 0x80000000. Important if you want fast straight SSE2 code (since
247
- // there's no unsigned PCMPGTD).
248
-
249
- // Inf or NaN (all exponent bits set)
250
- // NaN->qNaN and Inf->Inf
251
- // unconditional assignment here, will override with right value for
252
- // the regular case below.
253
- uint32_t f32infty = 255u << 23;
254
- o = (fint > f32infty) ? 0x7e00u : 0x7c00u;
255
-
256
- // (De)normalized number or zero
257
- // update fint unconditionally to save the blending; we don't need it
258
- // anymore for the Inf/NaN case anyway.
259
-
260
- const uint32_t round_mask = ~0xfffu;
261
- const uint32_t magic = 15u << 23;
262
-
263
- // Shift exponent down, denormalize if necessary.
264
- // NOTE This represents half-float denormals using single
265
- // precision denormals. The main reason to do this is that
266
- // there's no shift with per-lane variable shifts in SSE*, which
267
- // we'd otherwise need. It has some funky side effects though:
268
- // - This conversion will actually respect the FTZ (Flush To Zero)
269
- // flag in MXCSR - if it's set, no half-float denormals will be
270
- // generated. I'm honestly not sure whether this is good or
271
- // bad. It's definitely interesting.
272
- // - If the underlying HW doesn't support denormals (not an issue
273
- // with Intel CPUs, but might be a problem on GPUs or PS3 SPUs),
274
- // you will always get flush-to-zero behavior. This is bad,
275
- // unless you're on a CPU where you don't care.
276
- // - Denormals tend to be slow. FP32 denormals are rare in
277
- // practice outside of things like recursive filters in DSP -
278
- // not a typical half-float application. Whether FP16 denormals
279
- // are rare in practice, I don't know. Whatever slow path your
280
- // HW may or may not have for denormals, this may well hit it.
281
- float fscale = floatbits(fint & round_mask) * floatbits(magic);
282
- fscale = std::min(fscale, floatbits((31u << 23) - 0x1000u));
283
- int32_t fint2 = intbits(fscale) - round_mask;
284
-
285
- if (fint < f32infty)
286
- o = fint2 >> 13; // Take the bits!
287
-
288
- return (o | (sign >> 16));
289
- }
290
-
291
- float decode_fp16(uint16_t h) {
292
- // https://gist.github.com/2144712
293
- // Fabian "ryg" Giesen.
294
-
295
- const uint32_t shifted_exp = 0x7c00u << 13; // exponent mask after shift
296
-
297
- int32_t o = ((int32_t)(h & 0x7fffu)) << 13; // exponent/mantissa bits
298
- int32_t exp = shifted_exp & o; // just the exponent
299
- o += (int32_t)(127 - 15) << 23; // exponent adjust
300
-
301
- int32_t infnan_val = o + ((int32_t)(128 - 16) << 23);
302
- int32_t zerodenorm_val =
303
- intbits(floatbits(o + (1u << 23)) - floatbits(113u << 23));
304
- int32_t reg_val = (exp == 0) ? zerodenorm_val : o;
305
-
306
- int32_t sign_bit = ((int32_t)(h & 0x8000u)) << 16;
307
- return floatbits(((exp == shifted_exp) ? infnan_val : reg_val) | sign_bit);
308
- }
309
-
310
- #endif
311
-
312
207
  /*******************************************************************
313
208
  * Quantizer: normalizes scalar vector components, then passes them
314
209
  * through a codec
@@ -318,7 +213,7 @@ template <class Codec, bool uniform, int SIMD>
318
213
  struct QuantizerTemplate {};
319
214
 
320
215
  template <class Codec>
321
- struct QuantizerTemplate<Codec, true, 1> : ScalarQuantizer::Quantizer {
216
+ struct QuantizerTemplate<Codec, true, 1> : ScalarQuantizer::SQuantizer {
322
217
  const size_t d;
323
218
  const float vmin, vdiff;
324
219
 
@@ -372,7 +267,7 @@ struct QuantizerTemplate<Codec, true, 8> : QuantizerTemplate<Codec, true, 1> {
372
267
  #endif
373
268
 
374
269
  template <class Codec>
375
- struct QuantizerTemplate<Codec, false, 1> : ScalarQuantizer::Quantizer {
270
+ struct QuantizerTemplate<Codec, false, 1> : ScalarQuantizer::SQuantizer {
376
271
  const size_t d;
377
272
  const float *vmin, *vdiff;
378
273
 
@@ -433,7 +328,7 @@ template <int SIMDWIDTH>
433
328
  struct QuantizerFP16 {};
434
329
 
435
330
  template <>
436
- struct QuantizerFP16<1> : ScalarQuantizer::Quantizer {
331
+ struct QuantizerFP16<1> : ScalarQuantizer::SQuantizer {
437
332
  const size_t d;
438
333
 
439
334
  QuantizerFP16(size_t d, const std::vector<float>& /* unused */) : d(d) {}
@@ -478,7 +373,7 @@ template <int SIMDWIDTH>
478
373
  struct Quantizer8bitDirect {};
479
374
 
480
375
  template <>
481
- struct Quantizer8bitDirect<1> : ScalarQuantizer::Quantizer {
376
+ struct Quantizer8bitDirect<1> : ScalarQuantizer::SQuantizer {
482
377
  const size_t d;
483
378
 
484
379
  Quantizer8bitDirect(size_t d, const std::vector<float>& /* unused */)
@@ -518,7 +413,7 @@ struct Quantizer8bitDirect<8> : Quantizer8bitDirect<1> {
518
413
  #endif
519
414
 
520
415
  template <int SIMDWIDTH>
521
- ScalarQuantizer::Quantizer* select_quantizer_1(
416
+ ScalarQuantizer::SQuantizer* select_quantizer_1(
522
417
  QuantizerType qtype,
523
418
  size_t d,
524
419
  const std::vector<float>& trained) {
@@ -911,11 +806,6 @@ struct DCTemplate<Quantizer, Similarity, 1> : SQDistanceComputer {
911
806
  q = x;
912
807
  }
913
808
 
914
- /// compute distance of vector i to current query
915
- float operator()(idx_t i) final {
916
- return query_to_code(codes + i * code_size);
917
- }
918
-
919
809
  float symmetric_dis(idx_t i, idx_t j) override {
920
810
  return compute_code_distance(
921
811
  codes + i * code_size, codes + j * code_size);
@@ -963,11 +853,6 @@ struct DCTemplate<Quantizer, Similarity, 8> : SQDistanceComputer {
963
853
  q = x;
964
854
  }
965
855
 
966
- /// compute distance of vector i to current query
967
- float operator()(idx_t i) final {
968
- return query_to_code(codes + i * code_size);
969
- }
970
-
971
856
  float symmetric_dis(idx_t i, idx_t j) override {
972
857
  return compute_code_distance(
973
858
  codes + i * code_size, codes + j * code_size);
@@ -1021,11 +906,6 @@ struct DistanceComputerByte<Similarity, 1> : SQDistanceComputer {
1021
906
  return compute_code_distance(tmp.data(), code);
1022
907
  }
1023
908
 
1024
- /// compute distance of vector i to current query
1025
- float operator()(idx_t i) final {
1026
- return query_to_code(codes + i * code_size);
1027
- }
1028
-
1029
909
  float symmetric_dis(idx_t i, idx_t j) override {
1030
910
  return compute_code_distance(
1031
911
  codes + i * code_size, codes + j * code_size);
@@ -1089,11 +969,6 @@ struct DistanceComputerByte<Similarity, 8> : SQDistanceComputer {
1089
969
  return compute_code_distance(tmp.data(), code);
1090
970
  }
1091
971
 
1092
- /// compute distance of vector i to current query
1093
- float operator()(idx_t i) final {
1094
- return query_to_code(codes + i * code_size);
1095
- }
1096
-
1097
972
  float symmetric_dis(idx_t i, idx_t j) override {
1098
973
  return compute_code_distance(
1099
974
  codes + i * code_size, codes + j * code_size);
@@ -1173,17 +1048,12 @@ SQDistanceComputer* select_distance_computer(
1173
1048
  ********************************************************************/
1174
1049
 
1175
1050
  ScalarQuantizer::ScalarQuantizer(size_t d, QuantizerType qtype)
1176
- : qtype(qtype), rangestat(RS_minmax), rangestat_arg(0), d(d) {
1051
+ : Quantizer(d), qtype(qtype), rangestat(RS_minmax), rangestat_arg(0) {
1177
1052
  set_derived_sizes();
1178
1053
  }
1179
1054
 
1180
1055
  ScalarQuantizer::ScalarQuantizer()
1181
- : qtype(QT_8bit),
1182
- rangestat(RS_minmax),
1183
- rangestat_arg(0),
1184
- d(0),
1185
- bits(0),
1186
- code_size(0) {}
1056
+ : qtype(QT_8bit), rangestat(RS_minmax), rangestat_arg(0), bits(0) {}
1187
1057
 
1188
1058
  void ScalarQuantizer::set_derived_sizes() {
1189
1059
  switch (qtype) {
@@ -1273,7 +1143,7 @@ void ScalarQuantizer::train_residual(
1273
1143
  }
1274
1144
  }
1275
1145
 
1276
- ScalarQuantizer::Quantizer* ScalarQuantizer::select_quantizer() const {
1146
+ ScalarQuantizer::SQuantizer* ScalarQuantizer::select_quantizer() const {
1277
1147
  #ifdef USE_F16C
1278
1148
  if (d % 8 == 0) {
1279
1149
  return select_quantizer_1<8>(qtype, d, trained);
@@ -1286,7 +1156,7 @@ ScalarQuantizer::Quantizer* ScalarQuantizer::select_quantizer() const {
1286
1156
 
1287
1157
  void ScalarQuantizer::compute_codes(const float* x, uint8_t* codes, size_t n)
1288
1158
  const {
1289
- std::unique_ptr<Quantizer> squant(select_quantizer());
1159
+ std::unique_ptr<SQuantizer> squant(select_quantizer());
1290
1160
 
1291
1161
  memset(codes, 0, code_size * n);
1292
1162
  #pragma omp parallel for
@@ -1295,7 +1165,7 @@ void ScalarQuantizer::compute_codes(const float* x, uint8_t* codes, size_t n)
1295
1165
  }
1296
1166
 
1297
1167
  void ScalarQuantizer::decode(const uint8_t* codes, float* x, size_t n) const {
1298
- std::unique_ptr<Quantizer> squant(select_quantizer());
1168
+ std::unique_ptr<SQuantizer> squant(select_quantizer());
1299
1169
 
1300
1170
  #pragma omp parallel for
1301
1171
  for (int64_t i = 0; i < n; i++)
@@ -1332,10 +1202,11 @@ SQDistanceComputer* ScalarQuantizer::get_distance_computer(
1332
1202
 
1333
1203
  namespace {
1334
1204
 
1335
- template <class DCClass>
1205
+ template <class DCClass, int use_sel>
1336
1206
  struct IVFSQScannerIP : InvertedListScanner {
1337
1207
  DCClass dc;
1338
1208
  bool by_residual;
1209
+ const IDSelector* sel;
1339
1210
 
1340
1211
  float accu0; /// added to all distances
1341
1212
 
@@ -1344,8 +1215,9 @@ struct IVFSQScannerIP : InvertedListScanner {
1344
1215
  const std::vector<float>& trained,
1345
1216
  size_t code_size,
1346
1217
  bool store_pairs,
1218
+ const IDSelector* sel,
1347
1219
  bool by_residual)
1348
- : dc(d, trained), by_residual(by_residual), accu0(0) {
1220
+ : dc(d, trained), by_residual(by_residual), sel(sel), accu0(0) {
1349
1221
  this->store_pairs = store_pairs;
1350
1222
  this->code_size = code_size;
1351
1223
  }
@@ -1372,7 +1244,11 @@ struct IVFSQScannerIP : InvertedListScanner {
1372
1244
  size_t k) const override {
1373
1245
  size_t nup = 0;
1374
1246
 
1375
- for (size_t j = 0; j < list_size; j++) {
1247
+ for (size_t j = 0; j < list_size; j++, codes += code_size) {
1248
+ if (use_sel && !sel->is_member(use_sel == 1 ? ids[j] : j)) {
1249
+ continue;
1250
+ }
1251
+
1376
1252
  float accu = accu0 + dc.query_to_code(codes);
1377
1253
 
1378
1254
  if (accu > simi[0]) {
@@ -1380,7 +1256,6 @@ struct IVFSQScannerIP : InvertedListScanner {
1380
1256
  minheap_replace_top(k, simi, idxi, accu, id);
1381
1257
  nup++;
1382
1258
  }
1383
- codes += code_size;
1384
1259
  }
1385
1260
  return nup;
1386
1261
  }
@@ -1391,23 +1266,31 @@ struct IVFSQScannerIP : InvertedListScanner {
1391
1266
  const idx_t* ids,
1392
1267
  float radius,
1393
1268
  RangeQueryResult& res) const override {
1394
- for (size_t j = 0; j < list_size; j++) {
1269
+ for (size_t j = 0; j < list_size; j++, codes += code_size) {
1270
+ if (use_sel && !sel->is_member(use_sel == 1 ? ids[j] : j)) {
1271
+ continue;
1272
+ }
1273
+
1395
1274
  float accu = accu0 + dc.query_to_code(codes);
1396
1275
  if (accu > radius) {
1397
1276
  int64_t id = store_pairs ? (list_no << 32 | j) : ids[j];
1398
1277
  res.add(accu, id);
1399
1278
  }
1400
- codes += code_size;
1401
1279
  }
1402
1280
  }
1403
1281
  };
1404
1282
 
1405
- template <class DCClass>
1283
+ /* use_sel = 0: don't check selector
1284
+ * = 1: check on ids[j]
1285
+ * = 2: check in j directly (normally ids is nullptr and store_pairs)
1286
+ */
1287
+ template <class DCClass, int use_sel>
1406
1288
  struct IVFSQScannerL2 : InvertedListScanner {
1407
1289
  DCClass dc;
1408
1290
 
1409
1291
  bool by_residual;
1410
1292
  const Index* quantizer;
1293
+ const IDSelector* sel;
1411
1294
  const float* x; /// current query
1412
1295
 
1413
1296
  std::vector<float> tmp;
@@ -1418,10 +1301,12 @@ struct IVFSQScannerL2 : InvertedListScanner {
1418
1301
  size_t code_size,
1419
1302
  const Index* quantizer,
1420
1303
  bool store_pairs,
1304
+ const IDSelector* sel,
1421
1305
  bool by_residual)
1422
1306
  : dc(d, trained),
1423
1307
  by_residual(by_residual),
1424
1308
  quantizer(quantizer),
1309
+ sel(sel),
1425
1310
  x(nullptr),
1426
1311
  tmp(d) {
1427
1312
  this->store_pairs = store_pairs;
@@ -1458,7 +1343,11 @@ struct IVFSQScannerL2 : InvertedListScanner {
1458
1343
  idx_t* idxi,
1459
1344
  size_t k) const override {
1460
1345
  size_t nup = 0;
1461
- for (size_t j = 0; j < list_size; j++) {
1346
+ for (size_t j = 0; j < list_size; j++, codes += code_size) {
1347
+ if (use_sel && !sel->is_member(use_sel == 1 ? ids[j] : j)) {
1348
+ continue;
1349
+ }
1350
+
1462
1351
  float dis = dc.query_to_code(codes);
1463
1352
 
1464
1353
  if (dis < simi[0]) {
@@ -1466,7 +1355,6 @@ struct IVFSQScannerL2 : InvertedListScanner {
1466
1355
  maxheap_replace_top(k, simi, idxi, dis, id);
1467
1356
  nup++;
1468
1357
  }
1469
- codes += code_size;
1470
1358
  }
1471
1359
  return nup;
1472
1360
  }
@@ -1477,44 +1365,77 @@ struct IVFSQScannerL2 : InvertedListScanner {
1477
1365
  const idx_t* ids,
1478
1366
  float radius,
1479
1367
  RangeQueryResult& res) const override {
1480
- for (size_t j = 0; j < list_size; j++) {
1368
+ for (size_t j = 0; j < list_size; j++, codes += code_size) {
1369
+ if (use_sel && !sel->is_member(use_sel == 1 ? ids[j] : j)) {
1370
+ continue;
1371
+ }
1372
+
1481
1373
  float dis = dc.query_to_code(codes);
1482
1374
  if (dis < radius) {
1483
1375
  int64_t id = store_pairs ? (list_no << 32 | j) : ids[j];
1484
1376
  res.add(dis, id);
1485
1377
  }
1486
- codes += code_size;
1487
1378
  }
1488
1379
  }
1489
1380
  };
1490
1381
 
1491
- template <class DCClass>
1492
- InvertedListScanner* sel2_InvertedListScanner(
1382
+ template <class DCClass, int use_sel>
1383
+ InvertedListScanner* sel3_InvertedListScanner(
1493
1384
  const ScalarQuantizer* sq,
1494
1385
  const Index* quantizer,
1495
1386
  bool store_pairs,
1387
+ const IDSelector* sel,
1496
1388
  bool r) {
1497
1389
  if (DCClass::Sim::metric_type == METRIC_L2) {
1498
- return new IVFSQScannerL2<DCClass>(
1499
- sq->d, sq->trained, sq->code_size, quantizer, store_pairs, r);
1390
+ return new IVFSQScannerL2<DCClass, use_sel>(
1391
+ sq->d,
1392
+ sq->trained,
1393
+ sq->code_size,
1394
+ quantizer,
1395
+ store_pairs,
1396
+ sel,
1397
+ r);
1500
1398
  } else if (DCClass::Sim::metric_type == METRIC_INNER_PRODUCT) {
1501
- return new IVFSQScannerIP<DCClass>(
1502
- sq->d, sq->trained, sq->code_size, store_pairs, r);
1399
+ return new IVFSQScannerIP<DCClass, use_sel>(
1400
+ sq->d, sq->trained, sq->code_size, store_pairs, sel, r);
1503
1401
  } else {
1504
1402
  FAISS_THROW_MSG("unsupported metric type");
1505
1403
  }
1506
1404
  }
1507
1405
 
1406
+ template <class DCClass>
1407
+ InvertedListScanner* sel2_InvertedListScanner(
1408
+ const ScalarQuantizer* sq,
1409
+ const Index* quantizer,
1410
+ bool store_pairs,
1411
+ const IDSelector* sel,
1412
+ bool r) {
1413
+ if (sel) {
1414
+ if (store_pairs) {
1415
+ return sel3_InvertedListScanner<DCClass, 2>(
1416
+ sq, quantizer, store_pairs, sel, r);
1417
+ } else {
1418
+ return sel3_InvertedListScanner<DCClass, 1>(
1419
+ sq, quantizer, store_pairs, sel, r);
1420
+ }
1421
+ } else {
1422
+ return sel3_InvertedListScanner<DCClass, 0>(
1423
+ sq, quantizer, store_pairs, sel, r);
1424
+ }
1425
+ }
1426
+
1508
1427
  template <class Similarity, class Codec, bool uniform>
1509
1428
  InvertedListScanner* sel12_InvertedListScanner(
1510
1429
  const ScalarQuantizer* sq,
1511
1430
  const Index* quantizer,
1512
1431
  bool store_pairs,
1432
+ const IDSelector* sel,
1513
1433
  bool r) {
1514
1434
  constexpr int SIMDWIDTH = Similarity::simdwidth;
1515
1435
  using QuantizerClass = QuantizerTemplate<Codec, uniform, SIMDWIDTH>;
1516
1436
  using DCClass = DCTemplate<QuantizerClass, Similarity, SIMDWIDTH>;
1517
- return sel2_InvertedListScanner<DCClass>(sq, quantizer, store_pairs, r);
1437
+ return sel2_InvertedListScanner<DCClass>(
1438
+ sq, quantizer, store_pairs, sel, r);
1518
1439
  }
1519
1440
 
1520
1441
  template <class Similarity>
@@ -1522,39 +1443,40 @@ InvertedListScanner* sel1_InvertedListScanner(
1522
1443
  const ScalarQuantizer* sq,
1523
1444
  const Index* quantizer,
1524
1445
  bool store_pairs,
1446
+ const IDSelector* sel,
1525
1447
  bool r) {
1526
1448
  constexpr int SIMDWIDTH = Similarity::simdwidth;
1527
1449
  switch (sq->qtype) {
1528
1450
  case ScalarQuantizer::QT_8bit_uniform:
1529
1451
  return sel12_InvertedListScanner<Similarity, Codec8bit, true>(
1530
- sq, quantizer, store_pairs, r);
1452
+ sq, quantizer, store_pairs, sel, r);
1531
1453
  case ScalarQuantizer::QT_4bit_uniform:
1532
1454
  return sel12_InvertedListScanner<Similarity, Codec4bit, true>(
1533
- sq, quantizer, store_pairs, r);
1455
+ sq, quantizer, store_pairs, sel, r);
1534
1456
  case ScalarQuantizer::QT_8bit:
1535
1457
  return sel12_InvertedListScanner<Similarity, Codec8bit, false>(
1536
- sq, quantizer, store_pairs, r);
1458
+ sq, quantizer, store_pairs, sel, r);
1537
1459
  case ScalarQuantizer::QT_4bit:
1538
1460
  return sel12_InvertedListScanner<Similarity, Codec4bit, false>(
1539
- sq, quantizer, store_pairs, r);
1461
+ sq, quantizer, store_pairs, sel, r);
1540
1462
  case ScalarQuantizer::QT_6bit:
1541
1463
  return sel12_InvertedListScanner<Similarity, Codec6bit, false>(
1542
- sq, quantizer, store_pairs, r);
1464
+ sq, quantizer, store_pairs, sel, r);
1543
1465
  case ScalarQuantizer::QT_fp16:
1544
1466
  return sel2_InvertedListScanner<DCTemplate<
1545
1467
  QuantizerFP16<SIMDWIDTH>,
1546
1468
  Similarity,
1547
- SIMDWIDTH>>(sq, quantizer, store_pairs, r);
1469
+ SIMDWIDTH>>(sq, quantizer, store_pairs, sel, r);
1548
1470
  case ScalarQuantizer::QT_8bit_direct:
1549
1471
  if (sq->d % 16 == 0) {
1550
1472
  return sel2_InvertedListScanner<
1551
1473
  DistanceComputerByte<Similarity, SIMDWIDTH>>(
1552
- sq, quantizer, store_pairs, r);
1474
+ sq, quantizer, store_pairs, sel, r);
1553
1475
  } else {
1554
1476
  return sel2_InvertedListScanner<DCTemplate<
1555
1477
  Quantizer8bitDirect<SIMDWIDTH>,
1556
1478
  Similarity,
1557
- SIMDWIDTH>>(sq, quantizer, store_pairs, r);
1479
+ SIMDWIDTH>>(sq, quantizer, store_pairs, sel, r);
1558
1480
  }
1559
1481
  }
1560
1482
 
@@ -1568,13 +1490,14 @@ InvertedListScanner* sel0_InvertedListScanner(
1568
1490
  const ScalarQuantizer* sq,
1569
1491
  const Index* quantizer,
1570
1492
  bool store_pairs,
1493
+ const IDSelector* sel,
1571
1494
  bool by_residual) {
1572
1495
  if (mt == METRIC_L2) {
1573
1496
  return sel1_InvertedListScanner<SimilarityL2<SIMDWIDTH>>(
1574
- sq, quantizer, store_pairs, by_residual);
1497
+ sq, quantizer, store_pairs, sel, by_residual);
1575
1498
  } else if (mt == METRIC_INNER_PRODUCT) {
1576
1499
  return sel1_InvertedListScanner<SimilarityIP<SIMDWIDTH>>(
1577
- sq, quantizer, store_pairs, by_residual);
1500
+ sq, quantizer, store_pairs, sel, by_residual);
1578
1501
  } else {
1579
1502
  FAISS_THROW_MSG("unsupported metric type");
1580
1503
  }
@@ -1586,16 +1509,17 @@ InvertedListScanner* ScalarQuantizer::select_InvertedListScanner(
1586
1509
  MetricType mt,
1587
1510
  const Index* quantizer,
1588
1511
  bool store_pairs,
1512
+ const IDSelector* sel,
1589
1513
  bool by_residual) const {
1590
1514
  #ifdef USE_F16C
1591
1515
  if (d % 8 == 0) {
1592
1516
  return sel0_InvertedListScanner<8>(
1593
- mt, this, quantizer, store_pairs, by_residual);
1517
+ mt, this, quantizer, store_pairs, sel, by_residual);
1594
1518
  } else
1595
1519
  #endif
1596
1520
  {
1597
1521
  return sel0_InvertedListScanner<1>(
1598
- mt, this, quantizer, store_pairs, by_residual);
1522
+ mt, this, quantizer, store_pairs, sel, by_residual);
1599
1523
  }
1600
1524
  }
1601
1525
 
@@ -9,18 +9,21 @@
9
9
 
10
10
  #pragma once
11
11
 
12
- #include <faiss/IndexIVF.h>
13
12
  #include <faiss/impl/AuxIndexStructures.h>
13
+ #include <faiss/impl/DistanceComputer.h>
14
+ #include <faiss/impl/Quantizer.h>
14
15
 
15
16
  namespace faiss {
16
17
 
18
+ struct InvertedListScanner;
19
+
17
20
  /**
18
21
  * The uniform quantizer has a range [vmin, vmax]. The range can be
19
22
  * the same for all dimensions (uniform) or specific per dimension
20
23
  * (default).
21
24
  */
22
25
 
23
- struct ScalarQuantizer {
26
+ struct ScalarQuantizer : Quantizer {
24
27
  enum QuantizerType {
25
28
  QT_8bit, ///< 8 bits per component
26
29
  QT_4bit, ///< 4 bits per component
@@ -48,15 +51,9 @@ struct ScalarQuantizer {
48
51
  RangeStat rangestat;
49
52
  float rangestat_arg;
50
53
 
51
- /// dimension of input vectors
52
- size_t d;
53
-
54
54
  /// bits per scalar code
55
55
  size_t bits;
56
56
 
57
- /// bytes per vector
58
- size_t code_size;
59
-
60
57
  /// trained values (including the range)
61
58
  std::vector<float> trained;
62
59
 
@@ -66,7 +63,7 @@ struct ScalarQuantizer {
66
63
  /// updates internal values based on qtype and d
67
64
  void set_derived_sizes();
68
65
 
69
- void train(size_t n, const float* x);
66
+ void train(size_t n, const float* x) override;
70
67
 
71
68
  /// Used by an IVF index to train based on the residuals
72
69
  void train_residual(
@@ -81,38 +78,40 @@ struct ScalarQuantizer {
81
78
  * @param x vectors to encode, size n * d
82
79
  * @param codes output codes, size n * code_size
83
80
  */
84
- void compute_codes(const float* x, uint8_t* codes, size_t n) const;
81
+ void compute_codes(const float* x, uint8_t* codes, size_t n) const override;
85
82
 
86
83
  /** Decode a set of vectors
87
84
  *
88
85
  * @param codes codes to decode, size n * code_size
89
86
  * @param x output vectors, size n * d
90
87
  */
91
- void decode(const uint8_t* code, float* x, size_t n) const;
88
+ void decode(const uint8_t* code, float* x, size_t n) const override;
92
89
 
93
90
  /*****************************************************
94
91
  * Objects that provide methods for encoding/decoding, distance
95
92
  * computation and inverted list scanning
96
93
  *****************************************************/
97
94
 
98
- struct Quantizer {
95
+ struct SQuantizer {
99
96
  // encodes one vector. Assumes code is filled with 0s on input!
100
97
  virtual void encode_vector(const float* x, uint8_t* code) const = 0;
101
98
  virtual void decode_vector(const uint8_t* code, float* x) const = 0;
102
99
 
103
- virtual ~Quantizer() {}
100
+ virtual ~SQuantizer() {}
104
101
  };
105
102
 
106
- Quantizer* select_quantizer() const;
103
+ SQuantizer* select_quantizer() const;
107
104
 
108
- struct SQDistanceComputer : DistanceComputer {
105
+ struct SQDistanceComputer : FlatCodesDistanceComputer {
109
106
  const float* q;
110
- const uint8_t* codes;
111
- size_t code_size;
112
107
 
113
- SQDistanceComputer() : q(nullptr), codes(nullptr), code_size(0) {}
108
+ SQDistanceComputer() : q(nullptr) {}
114
109
 
115
110
  virtual float query_to_code(const uint8_t* code) const = 0;
111
+
112
+ float distance_to_code(const uint8_t* code) final {
113
+ return query_to_code(code);
114
+ }
116
115
  };
117
116
 
118
117
  SQDistanceComputer* get_distance_computer(
@@ -122,6 +121,7 @@ struct ScalarQuantizer {
122
121
  MetricType mt,
123
122
  const Index* quantizer,
124
123
  bool store_pairs,
124
+ const IDSelector* sel,
125
125
  bool by_residual = false) const;
126
126
  };
127
127