faiss 0.2.3 → 0.2.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (189)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +9 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +23 -21
  5. data/ext/faiss/extconf.rb +11 -0
  6. data/ext/faiss/index.cpp +4 -4
  7. data/ext/faiss/index_binary.cpp +6 -6
  8. data/ext/faiss/product_quantizer.cpp +4 -4
  9. data/lib/faiss/version.rb +1 -1
  10. data/vendor/faiss/faiss/AutoTune.cpp +13 -0
  11. data/vendor/faiss/faiss/Clustering.cpp +32 -0
  12. data/vendor/faiss/faiss/Clustering.h +14 -0
  13. data/vendor/faiss/faiss/IVFlib.cpp +101 -2
  14. data/vendor/faiss/faiss/IVFlib.h +26 -2
  15. data/vendor/faiss/faiss/Index.cpp +36 -3
  16. data/vendor/faiss/faiss/Index.h +43 -6
  17. data/vendor/faiss/faiss/Index2Layer.cpp +24 -93
  18. data/vendor/faiss/faiss/Index2Layer.h +8 -17
  19. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +610 -0
  20. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +253 -0
  21. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +299 -0
  22. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +199 -0
  23. data/vendor/faiss/faiss/IndexBinary.cpp +20 -4
  24. data/vendor/faiss/faiss/IndexBinary.h +18 -3
  25. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +9 -2
  26. data/vendor/faiss/faiss/IndexBinaryFlat.h +4 -2
  27. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +4 -1
  28. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +2 -1
  29. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +5 -1
  30. data/vendor/faiss/faiss/IndexBinaryHNSW.h +2 -1
  31. data/vendor/faiss/faiss/IndexBinaryHash.cpp +17 -4
  32. data/vendor/faiss/faiss/IndexBinaryHash.h +8 -4
  33. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +28 -13
  34. data/vendor/faiss/faiss/IndexBinaryIVF.h +10 -7
  35. data/vendor/faiss/faiss/IndexFastScan.cpp +626 -0
  36. data/vendor/faiss/faiss/IndexFastScan.h +145 -0
  37. data/vendor/faiss/faiss/IndexFlat.cpp +52 -69
  38. data/vendor/faiss/faiss/IndexFlat.h +16 -19
  39. data/vendor/faiss/faiss/IndexFlatCodes.cpp +101 -0
  40. data/vendor/faiss/faiss/IndexFlatCodes.h +59 -0
  41. data/vendor/faiss/faiss/IndexHNSW.cpp +66 -138
  42. data/vendor/faiss/faiss/IndexHNSW.h +4 -2
  43. data/vendor/faiss/faiss/IndexIDMap.cpp +247 -0
  44. data/vendor/faiss/faiss/IndexIDMap.h +107 -0
  45. data/vendor/faiss/faiss/IndexIVF.cpp +200 -40
  46. data/vendor/faiss/faiss/IndexIVF.h +59 -22
  47. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +393 -0
  48. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +183 -0
  49. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +590 -0
  50. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +171 -0
  51. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +1290 -0
  52. data/vendor/faiss/faiss/IndexIVFFastScan.h +213 -0
  53. data/vendor/faiss/faiss/IndexIVFFlat.cpp +43 -26
  54. data/vendor/faiss/faiss/IndexIVFFlat.h +4 -2
  55. data/vendor/faiss/faiss/IndexIVFPQ.cpp +238 -53
  56. data/vendor/faiss/faiss/IndexIVFPQ.h +6 -2
  57. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +23 -852
  58. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +7 -112
  59. data/vendor/faiss/faiss/IndexIVFPQR.cpp +3 -3
  60. data/vendor/faiss/faiss/IndexIVFPQR.h +1 -1
  61. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +63 -40
  62. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +23 -7
  63. data/vendor/faiss/faiss/IndexLSH.cpp +8 -32
  64. data/vendor/faiss/faiss/IndexLSH.h +4 -16
  65. data/vendor/faiss/faiss/IndexLattice.cpp +7 -1
  66. data/vendor/faiss/faiss/IndexLattice.h +3 -1
  67. data/vendor/faiss/faiss/IndexNNDescent.cpp +4 -5
  68. data/vendor/faiss/faiss/IndexNNDescent.h +2 -1
  69. data/vendor/faiss/faiss/IndexNSG.cpp +37 -5
  70. data/vendor/faiss/faiss/IndexNSG.h +25 -1
  71. data/vendor/faiss/faiss/IndexPQ.cpp +108 -120
  72. data/vendor/faiss/faiss/IndexPQ.h +21 -22
  73. data/vendor/faiss/faiss/IndexPQFastScan.cpp +15 -450
  74. data/vendor/faiss/faiss/IndexPQFastScan.h +15 -78
  75. data/vendor/faiss/faiss/IndexPreTransform.cpp +47 -8
  76. data/vendor/faiss/faiss/IndexPreTransform.h +15 -3
  77. data/vendor/faiss/faiss/IndexRefine.cpp +36 -4
  78. data/vendor/faiss/faiss/IndexRefine.h +14 -2
  79. data/vendor/faiss/faiss/IndexReplicas.cpp +4 -2
  80. data/vendor/faiss/faiss/IndexReplicas.h +2 -1
  81. data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +438 -0
  82. data/vendor/faiss/faiss/IndexRowwiseMinMax.h +92 -0
  83. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +28 -43
  84. data/vendor/faiss/faiss/IndexScalarQuantizer.h +8 -23
  85. data/vendor/faiss/faiss/IndexShards.cpp +4 -1
  86. data/vendor/faiss/faiss/IndexShards.h +2 -1
  87. data/vendor/faiss/faiss/MetaIndexes.cpp +5 -178
  88. data/vendor/faiss/faiss/MetaIndexes.h +3 -81
  89. data/vendor/faiss/faiss/VectorTransform.cpp +45 -1
  90. data/vendor/faiss/faiss/VectorTransform.h +25 -4
  91. data/vendor/faiss/faiss/clone_index.cpp +26 -3
  92. data/vendor/faiss/faiss/clone_index.h +3 -0
  93. data/vendor/faiss/faiss/cppcontrib/SaDecodeKernels.h +300 -0
  94. data/vendor/faiss/faiss/cppcontrib/detail/CoarseBitType.h +24 -0
  95. data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +195 -0
  96. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +2058 -0
  97. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +408 -0
  98. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-neon-inl.h +2147 -0
  99. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMax-inl.h +460 -0
  100. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMaxFP16-inl.h +465 -0
  101. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +1618 -0
  102. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +251 -0
  103. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-neon-inl.h +1452 -0
  104. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +1 -0
  105. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +2 -6
  106. data/vendor/faiss/faiss/gpu/GpuIcmEncoder.h +60 -0
  107. data/vendor/faiss/faiss/gpu/GpuIndex.h +28 -4
  108. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +2 -1
  109. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +10 -8
  110. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +75 -14
  111. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +19 -32
  112. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -31
  113. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +22 -28
  114. data/vendor/faiss/faiss/gpu/GpuResources.cpp +14 -0
  115. data/vendor/faiss/faiss/gpu/GpuResources.h +16 -3
  116. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +3 -3
  117. data/vendor/faiss/faiss/gpu/impl/IndexUtils.h +32 -0
  118. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +1 -0
  119. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +311 -75
  120. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +10 -0
  121. data/vendor/faiss/faiss/gpu/test/TestUtils.h +3 -0
  122. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +2 -2
  123. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +5 -4
  124. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +331 -29
  125. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +110 -19
  126. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +0 -54
  127. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +0 -76
  128. data/vendor/faiss/faiss/impl/DistanceComputer.h +64 -0
  129. data/vendor/faiss/faiss/impl/HNSW.cpp +133 -32
  130. data/vendor/faiss/faiss/impl/HNSW.h +19 -16
  131. data/vendor/faiss/faiss/impl/IDSelector.cpp +125 -0
  132. data/vendor/faiss/faiss/impl/IDSelector.h +135 -0
  133. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +378 -217
  134. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +106 -29
  135. data/vendor/faiss/faiss/impl/LookupTableScaler.h +77 -0
  136. data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -0
  137. data/vendor/faiss/faiss/impl/NSG.cpp +1 -4
  138. data/vendor/faiss/faiss/impl/NSG.h +1 -1
  139. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +383 -0
  140. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +154 -0
  141. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +225 -145
  142. data/vendor/faiss/faiss/impl/ProductQuantizer.h +29 -10
  143. data/vendor/faiss/faiss/impl/Quantizer.h +43 -0
  144. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +521 -55
  145. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +94 -16
  146. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -0
  147. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +108 -191
  148. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +18 -18
  149. data/vendor/faiss/faiss/impl/index_read.cpp +338 -24
  150. data/vendor/faiss/faiss/impl/index_write.cpp +300 -18
  151. data/vendor/faiss/faiss/impl/io.cpp +1 -1
  152. data/vendor/faiss/faiss/impl/io_macros.h +20 -0
  153. data/vendor/faiss/faiss/impl/kmeans1d.cpp +303 -0
  154. data/vendor/faiss/faiss/impl/kmeans1d.h +48 -0
  155. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +56 -16
  156. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +25 -8
  157. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +66 -25
  158. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +75 -27
  159. data/vendor/faiss/faiss/index_factory.cpp +772 -412
  160. data/vendor/faiss/faiss/index_factory.h +3 -0
  161. data/vendor/faiss/faiss/index_io.h +5 -0
  162. data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -0
  163. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +4 -1
  164. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +2 -1
  165. data/vendor/faiss/faiss/python/python_callbacks.cpp +27 -0
  166. data/vendor/faiss/faiss/python/python_callbacks.h +15 -0
  167. data/vendor/faiss/faiss/utils/Heap.h +31 -15
  168. data/vendor/faiss/faiss/utils/distances.cpp +384 -58
  169. data/vendor/faiss/faiss/utils/distances.h +149 -18
  170. data/vendor/faiss/faiss/utils/distances_simd.cpp +776 -6
  171. data/vendor/faiss/faiss/utils/extra_distances.cpp +12 -7
  172. data/vendor/faiss/faiss/utils/extra_distances.h +3 -1
  173. data/vendor/faiss/faiss/utils/fp16-fp16c.h +21 -0
  174. data/vendor/faiss/faiss/utils/fp16-inl.h +101 -0
  175. data/vendor/faiss/faiss/utils/fp16.h +11 -0
  176. data/vendor/faiss/faiss/utils/hamming-inl.h +54 -0
  177. data/vendor/faiss/faiss/utils/hamming.cpp +0 -48
  178. data/vendor/faiss/faiss/utils/ordered_key_value.h +10 -0
  179. data/vendor/faiss/faiss/utils/quantize_lut.cpp +62 -0
  180. data/vendor/faiss/faiss/utils/quantize_lut.h +20 -0
  181. data/vendor/faiss/faiss/utils/random.cpp +53 -0
  182. data/vendor/faiss/faiss/utils/random.h +5 -0
  183. data/vendor/faiss/faiss/utils/simdlib_avx2.h +4 -0
  184. data/vendor/faiss/faiss/utils/simdlib_emulated.h +6 -1
  185. data/vendor/faiss/faiss/utils/simdlib_neon.h +7 -2
  186. data/vendor/faiss/faiss/utils/utils.h +1 -1
  187. metadata +46 -5
  188. data/vendor/faiss/faiss/IndexResidual.cpp +0 -291
  189. data/vendor/faiss/faiss/IndexResidual.h +0 -152
@@ -19,8 +19,11 @@
19
19
  #include <immintrin.h>
20
20
  #endif
21
21
 
22
+ #include <faiss/IndexIVF.h>
22
23
  #include <faiss/impl/AuxIndexStructures.h>
23
24
  #include <faiss/impl/FaissAssert.h>
25
+ #include <faiss/impl/IDSelector.h>
26
+ #include <faiss/utils/fp16.h>
24
27
  #include <faiss/utils/utils.h>
25
28
 
26
29
  namespace faiss {
@@ -201,114 +204,6 @@ struct Codec6bit {
201
204
  #endif
202
205
  };
203
206
 
204
- #ifdef USE_F16C
205
-
206
- uint16_t encode_fp16(float x) {
207
- __m128 xf = _mm_set1_ps(x);
208
- __m128i xi =
209
- _mm_cvtps_ph(xf, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
210
- return _mm_cvtsi128_si32(xi) & 0xffff;
211
- }
212
-
213
- float decode_fp16(uint16_t x) {
214
- __m128i xi = _mm_set1_epi16(x);
215
- __m128 xf = _mm_cvtph_ps(xi);
216
- return _mm_cvtss_f32(xf);
217
- }
218
-
219
- #else
220
-
221
- // non-intrinsic FP16 <-> FP32 code adapted from
222
- // https://github.com/ispc/ispc/blob/master/stdlib.ispc
223
-
224
- float floatbits(uint32_t x) {
225
- void* xptr = &x;
226
- return *(float*)xptr;
227
- }
228
-
229
- uint32_t intbits(float f) {
230
- void* fptr = &f;
231
- return *(uint32_t*)fptr;
232
- }
233
-
234
- uint16_t encode_fp16(float f) {
235
- // via Fabian "ryg" Giesen.
236
- // https://gist.github.com/2156668
237
- uint32_t sign_mask = 0x80000000u;
238
- int32_t o;
239
-
240
- uint32_t fint = intbits(f);
241
- uint32_t sign = fint & sign_mask;
242
- fint ^= sign;
243
-
244
- // NOTE all the integer compares in this function can be safely
245
- // compiled into signed compares since all operands are below
246
- // 0x80000000. Important if you want fast straight SSE2 code (since
247
- // there's no unsigned PCMPGTD).
248
-
249
- // Inf or NaN (all exponent bits set)
250
- // NaN->qNaN and Inf->Inf
251
- // unconditional assignment here, will override with right value for
252
- // the regular case below.
253
- uint32_t f32infty = 255u << 23;
254
- o = (fint > f32infty) ? 0x7e00u : 0x7c00u;
255
-
256
- // (De)normalized number or zero
257
- // update fint unconditionally to save the blending; we don't need it
258
- // anymore for the Inf/NaN case anyway.
259
-
260
- const uint32_t round_mask = ~0xfffu;
261
- const uint32_t magic = 15u << 23;
262
-
263
- // Shift exponent down, denormalize if necessary.
264
- // NOTE This represents half-float denormals using single
265
- // precision denormals. The main reason to do this is that
266
- // there's no shift with per-lane variable shifts in SSE*, which
267
- // we'd otherwise need. It has some funky side effects though:
268
- // - This conversion will actually respect the FTZ (Flush To Zero)
269
- // flag in MXCSR - if it's set, no half-float denormals will be
270
- // generated. I'm honestly not sure whether this is good or
271
- // bad. It's definitely interesting.
272
- // - If the underlying HW doesn't support denormals (not an issue
273
- // with Intel CPUs, but might be a problem on GPUs or PS3 SPUs),
274
- // you will always get flush-to-zero behavior. This is bad,
275
- // unless you're on a CPU where you don't care.
276
- // - Denormals tend to be slow. FP32 denormals are rare in
277
- // practice outside of things like recursive filters in DSP -
278
- // not a typical half-float application. Whether FP16 denormals
279
- // are rare in practice, I don't know. Whatever slow path your
280
- // HW may or may not have for denormals, this may well hit it.
281
- float fscale = floatbits(fint & round_mask) * floatbits(magic);
282
- fscale = std::min(fscale, floatbits((31u << 23) - 0x1000u));
283
- int32_t fint2 = intbits(fscale) - round_mask;
284
-
285
- if (fint < f32infty)
286
- o = fint2 >> 13; // Take the bits!
287
-
288
- return (o | (sign >> 16));
289
- }
290
-
291
- float decode_fp16(uint16_t h) {
292
- // https://gist.github.com/2144712
293
- // Fabian "ryg" Giesen.
294
-
295
- const uint32_t shifted_exp = 0x7c00u << 13; // exponent mask after shift
296
-
297
- int32_t o = ((int32_t)(h & 0x7fffu)) << 13; // exponent/mantissa bits
298
- int32_t exp = shifted_exp & o; // just the exponent
299
- o += (int32_t)(127 - 15) << 23; // exponent adjust
300
-
301
- int32_t infnan_val = o + ((int32_t)(128 - 16) << 23);
302
- int32_t zerodenorm_val =
303
- intbits(floatbits(o + (1u << 23)) - floatbits(113u << 23));
304
- int32_t reg_val = (exp == 0) ? zerodenorm_val : o;
305
-
306
- int32_t sign_bit = ((int32_t)(h & 0x8000u)) << 16;
307
- return floatbits(((exp == shifted_exp) ? infnan_val : reg_val) | sign_bit);
308
- }
309
-
310
- #endif
311
-
312
207
  /*******************************************************************
313
208
  * Quantizer: normalizes scalar vector components, then passes them
314
209
  * through a codec
@@ -318,7 +213,7 @@ template <class Codec, bool uniform, int SIMD>
318
213
  struct QuantizerTemplate {};
319
214
 
320
215
  template <class Codec>
321
- struct QuantizerTemplate<Codec, true, 1> : ScalarQuantizer::Quantizer {
216
+ struct QuantizerTemplate<Codec, true, 1> : ScalarQuantizer::SQuantizer {
322
217
  const size_t d;
323
218
  const float vmin, vdiff;
324
219
 
@@ -372,7 +267,7 @@ struct QuantizerTemplate<Codec, true, 8> : QuantizerTemplate<Codec, true, 1> {
372
267
  #endif
373
268
 
374
269
  template <class Codec>
375
- struct QuantizerTemplate<Codec, false, 1> : ScalarQuantizer::Quantizer {
270
+ struct QuantizerTemplate<Codec, false, 1> : ScalarQuantizer::SQuantizer {
376
271
  const size_t d;
377
272
  const float *vmin, *vdiff;
378
273
 
@@ -433,7 +328,7 @@ template <int SIMDWIDTH>
433
328
  struct QuantizerFP16 {};
434
329
 
435
330
  template <>
436
- struct QuantizerFP16<1> : ScalarQuantizer::Quantizer {
331
+ struct QuantizerFP16<1> : ScalarQuantizer::SQuantizer {
437
332
  const size_t d;
438
333
 
439
334
  QuantizerFP16(size_t d, const std::vector<float>& /* unused */) : d(d) {}
@@ -478,7 +373,7 @@ template <int SIMDWIDTH>
478
373
  struct Quantizer8bitDirect {};
479
374
 
480
375
  template <>
481
- struct Quantizer8bitDirect<1> : ScalarQuantizer::Quantizer {
376
+ struct Quantizer8bitDirect<1> : ScalarQuantizer::SQuantizer {
482
377
  const size_t d;
483
378
 
484
379
  Quantizer8bitDirect(size_t d, const std::vector<float>& /* unused */)
@@ -518,7 +413,7 @@ struct Quantizer8bitDirect<8> : Quantizer8bitDirect<1> {
518
413
  #endif
519
414
 
520
415
  template <int SIMDWIDTH>
521
- ScalarQuantizer::Quantizer* select_quantizer_1(
416
+ ScalarQuantizer::SQuantizer* select_quantizer_1(
522
417
  QuantizerType qtype,
523
418
  size_t d,
524
419
  const std::vector<float>& trained) {
@@ -911,11 +806,6 @@ struct DCTemplate<Quantizer, Similarity, 1> : SQDistanceComputer {
911
806
  q = x;
912
807
  }
913
808
 
914
- /// compute distance of vector i to current query
915
- float operator()(idx_t i) final {
916
- return query_to_code(codes + i * code_size);
917
- }
918
-
919
809
  float symmetric_dis(idx_t i, idx_t j) override {
920
810
  return compute_code_distance(
921
811
  codes + i * code_size, codes + j * code_size);
@@ -963,11 +853,6 @@ struct DCTemplate<Quantizer, Similarity, 8> : SQDistanceComputer {
963
853
  q = x;
964
854
  }
965
855
 
966
- /// compute distance of vector i to current query
967
- float operator()(idx_t i) final {
968
- return query_to_code(codes + i * code_size);
969
- }
970
-
971
856
  float symmetric_dis(idx_t i, idx_t j) override {
972
857
  return compute_code_distance(
973
858
  codes + i * code_size, codes + j * code_size);
@@ -1021,11 +906,6 @@ struct DistanceComputerByte<Similarity, 1> : SQDistanceComputer {
1021
906
  return compute_code_distance(tmp.data(), code);
1022
907
  }
1023
908
 
1024
- /// compute distance of vector i to current query
1025
- float operator()(idx_t i) final {
1026
- return query_to_code(codes + i * code_size);
1027
- }
1028
-
1029
909
  float symmetric_dis(idx_t i, idx_t j) override {
1030
910
  return compute_code_distance(
1031
911
  codes + i * code_size, codes + j * code_size);
@@ -1089,11 +969,6 @@ struct DistanceComputerByte<Similarity, 8> : SQDistanceComputer {
1089
969
  return compute_code_distance(tmp.data(), code);
1090
970
  }
1091
971
 
1092
- /// compute distance of vector i to current query
1093
- float operator()(idx_t i) final {
1094
- return query_to_code(codes + i * code_size);
1095
- }
1096
-
1097
972
  float symmetric_dis(idx_t i, idx_t j) override {
1098
973
  return compute_code_distance(
1099
974
  codes + i * code_size, codes + j * code_size);
@@ -1173,17 +1048,12 @@ SQDistanceComputer* select_distance_computer(
1173
1048
  ********************************************************************/
1174
1049
 
1175
1050
  ScalarQuantizer::ScalarQuantizer(size_t d, QuantizerType qtype)
1176
- : qtype(qtype), rangestat(RS_minmax), rangestat_arg(0), d(d) {
1051
+ : Quantizer(d), qtype(qtype), rangestat(RS_minmax), rangestat_arg(0) {
1177
1052
  set_derived_sizes();
1178
1053
  }
1179
1054
 
1180
1055
  ScalarQuantizer::ScalarQuantizer()
1181
- : qtype(QT_8bit),
1182
- rangestat(RS_minmax),
1183
- rangestat_arg(0),
1184
- d(0),
1185
- bits(0),
1186
- code_size(0) {}
1056
+ : qtype(QT_8bit), rangestat(RS_minmax), rangestat_arg(0), bits(0) {}
1187
1057
 
1188
1058
  void ScalarQuantizer::set_derived_sizes() {
1189
1059
  switch (qtype) {
@@ -1273,7 +1143,7 @@ void ScalarQuantizer::train_residual(
1273
1143
  }
1274
1144
  }
1275
1145
 
1276
- ScalarQuantizer::Quantizer* ScalarQuantizer::select_quantizer() const {
1146
+ ScalarQuantizer::SQuantizer* ScalarQuantizer::select_quantizer() const {
1277
1147
  #ifdef USE_F16C
1278
1148
  if (d % 8 == 0) {
1279
1149
  return select_quantizer_1<8>(qtype, d, trained);
@@ -1286,7 +1156,7 @@ ScalarQuantizer::Quantizer* ScalarQuantizer::select_quantizer() const {
1286
1156
 
1287
1157
  void ScalarQuantizer::compute_codes(const float* x, uint8_t* codes, size_t n)
1288
1158
  const {
1289
- std::unique_ptr<Quantizer> squant(select_quantizer());
1159
+ std::unique_ptr<SQuantizer> squant(select_quantizer());
1290
1160
 
1291
1161
  memset(codes, 0, code_size * n);
1292
1162
  #pragma omp parallel for
@@ -1295,7 +1165,7 @@ void ScalarQuantizer::compute_codes(const float* x, uint8_t* codes, size_t n)
1295
1165
  }
1296
1166
 
1297
1167
  void ScalarQuantizer::decode(const uint8_t* codes, float* x, size_t n) const {
1298
- std::unique_ptr<Quantizer> squant(select_quantizer());
1168
+ std::unique_ptr<SQuantizer> squant(select_quantizer());
1299
1169
 
1300
1170
  #pragma omp parallel for
1301
1171
  for (int64_t i = 0; i < n; i++)
@@ -1332,28 +1202,25 @@ SQDistanceComputer* ScalarQuantizer::get_distance_computer(
1332
1202
 
1333
1203
  namespace {
1334
1204
 
1335
- template <class DCClass>
1205
+ template <class DCClass, int use_sel>
1336
1206
  struct IVFSQScannerIP : InvertedListScanner {
1337
1207
  DCClass dc;
1338
- bool store_pairs, by_residual;
1339
-
1340
- size_t code_size;
1208
+ bool by_residual;
1209
+ const IDSelector* sel;
1341
1210
 
1342
- idx_t list_no; /// current list (set to 0 for Flat index
1343
- float accu0; /// added to all distances
1211
+ float accu0; /// added to all distances
1344
1212
 
1345
1213
  IVFSQScannerIP(
1346
1214
  int d,
1347
1215
  const std::vector<float>& trained,
1348
1216
  size_t code_size,
1349
1217
  bool store_pairs,
1218
+ const IDSelector* sel,
1350
1219
  bool by_residual)
1351
- : dc(d, trained),
1352
- store_pairs(store_pairs),
1353
- by_residual(by_residual),
1354
- code_size(code_size),
1355
- list_no(0),
1356
- accu0(0) {}
1220
+ : dc(d, trained), by_residual(by_residual), sel(sel), accu0(0) {
1221
+ this->store_pairs = store_pairs;
1222
+ this->code_size = code_size;
1223
+ }
1357
1224
 
1358
1225
  void set_query(const float* query) override {
1359
1226
  dc.set_query(query);
@@ -1377,7 +1244,11 @@ struct IVFSQScannerIP : InvertedListScanner {
1377
1244
  size_t k) const override {
1378
1245
  size_t nup = 0;
1379
1246
 
1380
- for (size_t j = 0; j < list_size; j++) {
1247
+ for (size_t j = 0; j < list_size; j++, codes += code_size) {
1248
+ if (use_sel && !sel->is_member(use_sel == 1 ? ids[j] : j)) {
1249
+ continue;
1250
+ }
1251
+
1381
1252
  float accu = accu0 + dc.query_to_code(codes);
1382
1253
 
1383
1254
  if (accu > simi[0]) {
@@ -1385,7 +1256,6 @@ struct IVFSQScannerIP : InvertedListScanner {
1385
1256
  minheap_replace_top(k, simi, idxi, accu, id);
1386
1257
  nup++;
1387
1258
  }
1388
- codes += code_size;
1389
1259
  }
1390
1260
  return nup;
1391
1261
  }
@@ -1396,25 +1266,31 @@ struct IVFSQScannerIP : InvertedListScanner {
1396
1266
  const idx_t* ids,
1397
1267
  float radius,
1398
1268
  RangeQueryResult& res) const override {
1399
- for (size_t j = 0; j < list_size; j++) {
1269
+ for (size_t j = 0; j < list_size; j++, codes += code_size) {
1270
+ if (use_sel && !sel->is_member(use_sel == 1 ? ids[j] : j)) {
1271
+ continue;
1272
+ }
1273
+
1400
1274
  float accu = accu0 + dc.query_to_code(codes);
1401
1275
  if (accu > radius) {
1402
1276
  int64_t id = store_pairs ? (list_no << 32 | j) : ids[j];
1403
1277
  res.add(accu, id);
1404
1278
  }
1405
- codes += code_size;
1406
1279
  }
1407
1280
  }
1408
1281
  };
1409
1282
 
1410
- template <class DCClass>
1283
+ /* use_sel = 0: don't check selector
1284
+ * = 1: check on ids[j]
1285
+ * = 2: check in j directly (normally ids is nullptr and store_pairs)
1286
+ */
1287
+ template <class DCClass, int use_sel>
1411
1288
  struct IVFSQScannerL2 : InvertedListScanner {
1412
1289
  DCClass dc;
1413
1290
 
1414
- bool store_pairs, by_residual;
1415
- size_t code_size;
1291
+ bool by_residual;
1416
1292
  const Index* quantizer;
1417
- idx_t list_no; /// current inverted list
1293
+ const IDSelector* sel;
1418
1294
  const float* x; /// current query
1419
1295
 
1420
1296
  std::vector<float> tmp;
@@ -1425,15 +1301,17 @@ struct IVFSQScannerL2 : InvertedListScanner {
1425
1301
  size_t code_size,
1426
1302
  const Index* quantizer,
1427
1303
  bool store_pairs,
1304
+ const IDSelector* sel,
1428
1305
  bool by_residual)
1429
1306
  : dc(d, trained),
1430
- store_pairs(store_pairs),
1431
1307
  by_residual(by_residual),
1432
- code_size(code_size),
1433
1308
  quantizer(quantizer),
1434
- list_no(0),
1309
+ sel(sel),
1435
1310
  x(nullptr),
1436
- tmp(d) {}
1311
+ tmp(d) {
1312
+ this->store_pairs = store_pairs;
1313
+ this->code_size = code_size;
1314
+ }
1437
1315
 
1438
1316
  void set_query(const float* query) override {
1439
1317
  x = query;
@@ -1443,8 +1321,8 @@ struct IVFSQScannerL2 : InvertedListScanner {
1443
1321
  }
1444
1322
 
1445
1323
  void set_list(idx_t list_no, float /*coarse_dis*/) override {
1324
+ this->list_no = list_no;
1446
1325
  if (by_residual) {
1447
- this->list_no = list_no;
1448
1326
  // shift of x_in wrt centroid
1449
1327
  quantizer->compute_residual(x, tmp.data(), list_no);
1450
1328
  dc.set_query(tmp.data());
@@ -1465,7 +1343,11 @@ struct IVFSQScannerL2 : InvertedListScanner {
1465
1343
  idx_t* idxi,
1466
1344
  size_t k) const override {
1467
1345
  size_t nup = 0;
1468
- for (size_t j = 0; j < list_size; j++) {
1346
+ for (size_t j = 0; j < list_size; j++, codes += code_size) {
1347
+ if (use_sel && !sel->is_member(use_sel == 1 ? ids[j] : j)) {
1348
+ continue;
1349
+ }
1350
+
1469
1351
  float dis = dc.query_to_code(codes);
1470
1352
 
1471
1353
  if (dis < simi[0]) {
@@ -1473,7 +1355,6 @@ struct IVFSQScannerL2 : InvertedListScanner {
1473
1355
  maxheap_replace_top(k, simi, idxi, dis, id);
1474
1356
  nup++;
1475
1357
  }
1476
- codes += code_size;
1477
1358
  }
1478
1359
  return nup;
1479
1360
  }
@@ -1484,44 +1365,77 @@ struct IVFSQScannerL2 : InvertedListScanner {
1484
1365
  const idx_t* ids,
1485
1366
  float radius,
1486
1367
  RangeQueryResult& res) const override {
1487
- for (size_t j = 0; j < list_size; j++) {
1368
+ for (size_t j = 0; j < list_size; j++, codes += code_size) {
1369
+ if (use_sel && !sel->is_member(use_sel == 1 ? ids[j] : j)) {
1370
+ continue;
1371
+ }
1372
+
1488
1373
  float dis = dc.query_to_code(codes);
1489
1374
  if (dis < radius) {
1490
1375
  int64_t id = store_pairs ? (list_no << 32 | j) : ids[j];
1491
1376
  res.add(dis, id);
1492
1377
  }
1493
- codes += code_size;
1494
1378
  }
1495
1379
  }
1496
1380
  };
1497
1381
 
1498
- template <class DCClass>
1499
- InvertedListScanner* sel2_InvertedListScanner(
1382
+ template <class DCClass, int use_sel>
1383
+ InvertedListScanner* sel3_InvertedListScanner(
1500
1384
  const ScalarQuantizer* sq,
1501
1385
  const Index* quantizer,
1502
1386
  bool store_pairs,
1387
+ const IDSelector* sel,
1503
1388
  bool r) {
1504
1389
  if (DCClass::Sim::metric_type == METRIC_L2) {
1505
- return new IVFSQScannerL2<DCClass>(
1506
- sq->d, sq->trained, sq->code_size, quantizer, store_pairs, r);
1390
+ return new IVFSQScannerL2<DCClass, use_sel>(
1391
+ sq->d,
1392
+ sq->trained,
1393
+ sq->code_size,
1394
+ quantizer,
1395
+ store_pairs,
1396
+ sel,
1397
+ r);
1507
1398
  } else if (DCClass::Sim::metric_type == METRIC_INNER_PRODUCT) {
1508
- return new IVFSQScannerIP<DCClass>(
1509
- sq->d, sq->trained, sq->code_size, store_pairs, r);
1399
+ return new IVFSQScannerIP<DCClass, use_sel>(
1400
+ sq->d, sq->trained, sq->code_size, store_pairs, sel, r);
1510
1401
  } else {
1511
1402
  FAISS_THROW_MSG("unsupported metric type");
1512
1403
  }
1513
1404
  }
1514
1405
 
1406
+ template <class DCClass>
1407
+ InvertedListScanner* sel2_InvertedListScanner(
1408
+ const ScalarQuantizer* sq,
1409
+ const Index* quantizer,
1410
+ bool store_pairs,
1411
+ const IDSelector* sel,
1412
+ bool r) {
1413
+ if (sel) {
1414
+ if (store_pairs) {
1415
+ return sel3_InvertedListScanner<DCClass, 2>(
1416
+ sq, quantizer, store_pairs, sel, r);
1417
+ } else {
1418
+ return sel3_InvertedListScanner<DCClass, 1>(
1419
+ sq, quantizer, store_pairs, sel, r);
1420
+ }
1421
+ } else {
1422
+ return sel3_InvertedListScanner<DCClass, 0>(
1423
+ sq, quantizer, store_pairs, sel, r);
1424
+ }
1425
+ }
1426
+
1515
1427
  template <class Similarity, class Codec, bool uniform>
1516
1428
  InvertedListScanner* sel12_InvertedListScanner(
1517
1429
  const ScalarQuantizer* sq,
1518
1430
  const Index* quantizer,
1519
1431
  bool store_pairs,
1432
+ const IDSelector* sel,
1520
1433
  bool r) {
1521
1434
  constexpr int SIMDWIDTH = Similarity::simdwidth;
1522
1435
  using QuantizerClass = QuantizerTemplate<Codec, uniform, SIMDWIDTH>;
1523
1436
  using DCClass = DCTemplate<QuantizerClass, Similarity, SIMDWIDTH>;
1524
- return sel2_InvertedListScanner<DCClass>(sq, quantizer, store_pairs, r);
1437
+ return sel2_InvertedListScanner<DCClass>(
1438
+ sq, quantizer, store_pairs, sel, r);
1525
1439
  }
1526
1440
 
1527
1441
  template <class Similarity>
@@ -1529,39 +1443,40 @@ InvertedListScanner* sel1_InvertedListScanner(
1529
1443
  const ScalarQuantizer* sq,
1530
1444
  const Index* quantizer,
1531
1445
  bool store_pairs,
1446
+ const IDSelector* sel,
1532
1447
  bool r) {
1533
1448
  constexpr int SIMDWIDTH = Similarity::simdwidth;
1534
1449
  switch (sq->qtype) {
1535
1450
  case ScalarQuantizer::QT_8bit_uniform:
1536
1451
  return sel12_InvertedListScanner<Similarity, Codec8bit, true>(
1537
- sq, quantizer, store_pairs, r);
1452
+ sq, quantizer, store_pairs, sel, r);
1538
1453
  case ScalarQuantizer::QT_4bit_uniform:
1539
1454
  return sel12_InvertedListScanner<Similarity, Codec4bit, true>(
1540
- sq, quantizer, store_pairs, r);
1455
+ sq, quantizer, store_pairs, sel, r);
1541
1456
  case ScalarQuantizer::QT_8bit:
1542
1457
  return sel12_InvertedListScanner<Similarity, Codec8bit, false>(
1543
- sq, quantizer, store_pairs, r);
1458
+ sq, quantizer, store_pairs, sel, r);
1544
1459
  case ScalarQuantizer::QT_4bit:
1545
1460
  return sel12_InvertedListScanner<Similarity, Codec4bit, false>(
1546
- sq, quantizer, store_pairs, r);
1461
+ sq, quantizer, store_pairs, sel, r);
1547
1462
  case ScalarQuantizer::QT_6bit:
1548
1463
  return sel12_InvertedListScanner<Similarity, Codec6bit, false>(
1549
- sq, quantizer, store_pairs, r);
1464
+ sq, quantizer, store_pairs, sel, r);
1550
1465
  case ScalarQuantizer::QT_fp16:
1551
1466
  return sel2_InvertedListScanner<DCTemplate<
1552
1467
  QuantizerFP16<SIMDWIDTH>,
1553
1468
  Similarity,
1554
- SIMDWIDTH>>(sq, quantizer, store_pairs, r);
1469
+ SIMDWIDTH>>(sq, quantizer, store_pairs, sel, r);
1555
1470
  case ScalarQuantizer::QT_8bit_direct:
1556
1471
  if (sq->d % 16 == 0) {
1557
1472
  return sel2_InvertedListScanner<
1558
1473
  DistanceComputerByte<Similarity, SIMDWIDTH>>(
1559
- sq, quantizer, store_pairs, r);
1474
+ sq, quantizer, store_pairs, sel, r);
1560
1475
  } else {
1561
1476
  return sel2_InvertedListScanner<DCTemplate<
1562
1477
  Quantizer8bitDirect<SIMDWIDTH>,
1563
1478
  Similarity,
1564
- SIMDWIDTH>>(sq, quantizer, store_pairs, r);
1479
+ SIMDWIDTH>>(sq, quantizer, store_pairs, sel, r);
1565
1480
  }
1566
1481
  }
1567
1482
 
@@ -1575,13 +1490,14 @@ InvertedListScanner* sel0_InvertedListScanner(
1575
1490
  const ScalarQuantizer* sq,
1576
1491
  const Index* quantizer,
1577
1492
  bool store_pairs,
1493
+ const IDSelector* sel,
1578
1494
  bool by_residual) {
1579
1495
  if (mt == METRIC_L2) {
1580
1496
  return sel1_InvertedListScanner<SimilarityL2<SIMDWIDTH>>(
1581
- sq, quantizer, store_pairs, by_residual);
1497
+ sq, quantizer, store_pairs, sel, by_residual);
1582
1498
  } else if (mt == METRIC_INNER_PRODUCT) {
1583
1499
  return sel1_InvertedListScanner<SimilarityIP<SIMDWIDTH>>(
1584
- sq, quantizer, store_pairs, by_residual);
1500
+ sq, quantizer, store_pairs, sel, by_residual);
1585
1501
  } else {
1586
1502
  FAISS_THROW_MSG("unsupported metric type");
1587
1503
  }
@@ -1593,16 +1509,17 @@ InvertedListScanner* ScalarQuantizer::select_InvertedListScanner(
1593
1509
  MetricType mt,
1594
1510
  const Index* quantizer,
1595
1511
  bool store_pairs,
1512
+ const IDSelector* sel,
1596
1513
  bool by_residual) const {
1597
1514
  #ifdef USE_F16C
1598
1515
  if (d % 8 == 0) {
1599
1516
  return sel0_InvertedListScanner<8>(
1600
- mt, this, quantizer, store_pairs, by_residual);
1517
+ mt, this, quantizer, store_pairs, sel, by_residual);
1601
1518
  } else
1602
1519
  #endif
1603
1520
  {
1604
1521
  return sel0_InvertedListScanner<1>(
1605
- mt, this, quantizer, store_pairs, by_residual);
1522
+ mt, this, quantizer, store_pairs, sel, by_residual);
1606
1523
  }
1607
1524
  }
1608
1525
 
@@ -9,18 +9,21 @@
9
9
 
10
10
  #pragma once
11
11
 
12
- #include <faiss/IndexIVF.h>
13
12
  #include <faiss/impl/AuxIndexStructures.h>
13
+ #include <faiss/impl/DistanceComputer.h>
14
+ #include <faiss/impl/Quantizer.h>
14
15
 
15
16
  namespace faiss {
16
17
 
18
+ struct InvertedListScanner;
19
+
17
20
  /**
18
21
  * The uniform quantizer has a range [vmin, vmax]. The range can be
19
22
  * the same for all dimensions (uniform) or specific per dimension
20
23
  * (default).
21
24
  */
22
25
 
23
- struct ScalarQuantizer {
26
+ struct ScalarQuantizer : Quantizer {
24
27
  enum QuantizerType {
25
28
  QT_8bit, ///< 8 bits per component
26
29
  QT_4bit, ///< 4 bits per component
@@ -48,15 +51,9 @@ struct ScalarQuantizer {
48
51
  RangeStat rangestat;
49
52
  float rangestat_arg;
50
53
 
51
- /// dimension of input vectors
52
- size_t d;
53
-
54
54
  /// bits per scalar code
55
55
  size_t bits;
56
56
 
57
- /// bytes per vector
58
- size_t code_size;
59
-
60
57
  /// trained values (including the range)
61
58
  std::vector<float> trained;
62
59
 
@@ -66,7 +63,7 @@ struct ScalarQuantizer {
66
63
  /// updates internal values based on qtype and d
67
64
  void set_derived_sizes();
68
65
 
69
- void train(size_t n, const float* x);
66
+ void train(size_t n, const float* x) override;
70
67
 
71
68
  /// Used by an IVF index to train based on the residuals
72
69
  void train_residual(
@@ -81,38 +78,40 @@ struct ScalarQuantizer {
81
78
  * @param x vectors to encode, size n * d
82
79
  * @param codes output codes, size n * code_size
83
80
  */
84
- void compute_codes(const float* x, uint8_t* codes, size_t n) const;
81
+ void compute_codes(const float* x, uint8_t* codes, size_t n) const override;
85
82
 
86
83
  /** Decode a set of vectors
87
84
  *
88
85
  * @param codes codes to decode, size n * code_size
89
86
  * @param x output vectors, size n * d
90
87
  */
91
- void decode(const uint8_t* code, float* x, size_t n) const;
88
+ void decode(const uint8_t* code, float* x, size_t n) const override;
92
89
 
93
90
  /*****************************************************
94
91
  * Objects that provide methods for encoding/decoding, distance
95
92
  * computation and inverted list scanning
96
93
  *****************************************************/
97
94
 
98
- struct Quantizer {
95
+ struct SQuantizer {
99
96
  // encodes one vector. Assumes code is filled with 0s on input!
100
97
  virtual void encode_vector(const float* x, uint8_t* code) const = 0;
101
98
  virtual void decode_vector(const uint8_t* code, float* x) const = 0;
102
99
 
103
- virtual ~Quantizer() {}
100
+ virtual ~SQuantizer() {}
104
101
  };
105
102
 
106
- Quantizer* select_quantizer() const;
103
+ SQuantizer* select_quantizer() const;
107
104
 
108
- struct SQDistanceComputer : DistanceComputer {
105
+ struct SQDistanceComputer : FlatCodesDistanceComputer {
109
106
  const float* q;
110
- const uint8_t* codes;
111
- size_t code_size;
112
107
 
113
- SQDistanceComputer() : q(nullptr), codes(nullptr), code_size(0) {}
108
+ SQDistanceComputer() : q(nullptr) {}
114
109
 
115
110
  virtual float query_to_code(const uint8_t* code) const = 0;
111
+
112
+ float distance_to_code(const uint8_t* code) final {
113
+ return query_to_code(code);
114
+ }
116
115
  };
117
116
 
118
117
  SQDistanceComputer* get_distance_computer(
@@ -122,6 +121,7 @@ struct ScalarQuantizer {
122
121
  MetricType mt,
123
122
  const Index* quantizer,
124
123
  bool store_pairs,
124
+ const IDSelector* sel,
125
125
  bool by_residual = false) const;
126
126
  };
127
127