faiss 0.2.4 → 0.2.6

Sign up to get free protection for your applications and to get access to all the features.
Files changed (178) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +10 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +23 -21
  5. data/ext/faiss/extconf.rb +11 -0
  6. data/ext/faiss/index.cpp +17 -4
  7. data/ext/faiss/index_binary.cpp +6 -6
  8. data/ext/faiss/product_quantizer.cpp +4 -4
  9. data/lib/faiss/version.rb +1 -1
  10. data/vendor/faiss/faiss/AutoTune.cpp +13 -0
  11. data/vendor/faiss/faiss/IVFlib.cpp +101 -2
  12. data/vendor/faiss/faiss/IVFlib.h +26 -2
  13. data/vendor/faiss/faiss/Index.cpp +36 -3
  14. data/vendor/faiss/faiss/Index.h +43 -6
  15. data/vendor/faiss/faiss/Index2Layer.cpp +6 -2
  16. data/vendor/faiss/faiss/Index2Layer.h +6 -1
  17. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +219 -16
  18. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +63 -5
  19. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +299 -0
  20. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +199 -0
  21. data/vendor/faiss/faiss/IndexBinary.cpp +20 -4
  22. data/vendor/faiss/faiss/IndexBinary.h +18 -3
  23. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +9 -2
  24. data/vendor/faiss/faiss/IndexBinaryFlat.h +4 -2
  25. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +4 -1
  26. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +2 -1
  27. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +5 -1
  28. data/vendor/faiss/faiss/IndexBinaryHNSW.h +2 -1
  29. data/vendor/faiss/faiss/IndexBinaryHash.cpp +17 -4
  30. data/vendor/faiss/faiss/IndexBinaryHash.h +8 -4
  31. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +28 -13
  32. data/vendor/faiss/faiss/IndexBinaryIVF.h +10 -7
  33. data/vendor/faiss/faiss/IndexFastScan.cpp +626 -0
  34. data/vendor/faiss/faiss/IndexFastScan.h +145 -0
  35. data/vendor/faiss/faiss/IndexFlat.cpp +34 -21
  36. data/vendor/faiss/faiss/IndexFlat.h +7 -4
  37. data/vendor/faiss/faiss/IndexFlatCodes.cpp +35 -1
  38. data/vendor/faiss/faiss/IndexFlatCodes.h +12 -0
  39. data/vendor/faiss/faiss/IndexHNSW.cpp +66 -138
  40. data/vendor/faiss/faiss/IndexHNSW.h +4 -2
  41. data/vendor/faiss/faiss/IndexIDMap.cpp +247 -0
  42. data/vendor/faiss/faiss/IndexIDMap.h +107 -0
  43. data/vendor/faiss/faiss/IndexIVF.cpp +121 -33
  44. data/vendor/faiss/faiss/IndexIVF.h +35 -16
  45. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +84 -7
  46. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +63 -1
  47. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +590 -0
  48. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +171 -0
  49. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +1290 -0
  50. data/vendor/faiss/faiss/IndexIVFFastScan.h +213 -0
  51. data/vendor/faiss/faiss/IndexIVFFlat.cpp +37 -17
  52. data/vendor/faiss/faiss/IndexIVFFlat.h +4 -2
  53. data/vendor/faiss/faiss/IndexIVFPQ.cpp +234 -50
  54. data/vendor/faiss/faiss/IndexIVFPQ.h +5 -1
  55. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +23 -852
  56. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +7 -112
  57. data/vendor/faiss/faiss/IndexIVFPQR.cpp +3 -3
  58. data/vendor/faiss/faiss/IndexIVFPQR.h +1 -1
  59. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +3 -1
  60. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +2 -1
  61. data/vendor/faiss/faiss/IndexLSH.cpp +4 -2
  62. data/vendor/faiss/faiss/IndexLSH.h +2 -1
  63. data/vendor/faiss/faiss/IndexLattice.cpp +7 -1
  64. data/vendor/faiss/faiss/IndexLattice.h +3 -1
  65. data/vendor/faiss/faiss/IndexNNDescent.cpp +4 -3
  66. data/vendor/faiss/faiss/IndexNNDescent.h +2 -1
  67. data/vendor/faiss/faiss/IndexNSG.cpp +37 -3
  68. data/vendor/faiss/faiss/IndexNSG.h +25 -1
  69. data/vendor/faiss/faiss/IndexPQ.cpp +106 -69
  70. data/vendor/faiss/faiss/IndexPQ.h +19 -5
  71. data/vendor/faiss/faiss/IndexPQFastScan.cpp +15 -450
  72. data/vendor/faiss/faiss/IndexPQFastScan.h +15 -78
  73. data/vendor/faiss/faiss/IndexPreTransform.cpp +47 -8
  74. data/vendor/faiss/faiss/IndexPreTransform.h +15 -3
  75. data/vendor/faiss/faiss/IndexRefine.cpp +8 -4
  76. data/vendor/faiss/faiss/IndexRefine.h +4 -2
  77. data/vendor/faiss/faiss/IndexReplicas.cpp +4 -2
  78. data/vendor/faiss/faiss/IndexReplicas.h +2 -1
  79. data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +438 -0
  80. data/vendor/faiss/faiss/IndexRowwiseMinMax.h +92 -0
  81. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +26 -15
  82. data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -7
  83. data/vendor/faiss/faiss/IndexShards.cpp +4 -1
  84. data/vendor/faiss/faiss/IndexShards.h +2 -1
  85. data/vendor/faiss/faiss/MetaIndexes.cpp +5 -178
  86. data/vendor/faiss/faiss/MetaIndexes.h +3 -81
  87. data/vendor/faiss/faiss/VectorTransform.cpp +43 -0
  88. data/vendor/faiss/faiss/VectorTransform.h +22 -4
  89. data/vendor/faiss/faiss/clone_index.cpp +23 -1
  90. data/vendor/faiss/faiss/clone_index.h +3 -0
  91. data/vendor/faiss/faiss/cppcontrib/SaDecodeKernels.h +300 -0
  92. data/vendor/faiss/faiss/cppcontrib/detail/CoarseBitType.h +24 -0
  93. data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +195 -0
  94. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +2058 -0
  95. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +408 -0
  96. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-neon-inl.h +2147 -0
  97. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMax-inl.h +460 -0
  98. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMaxFP16-inl.h +465 -0
  99. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +1618 -0
  100. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +251 -0
  101. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-neon-inl.h +1452 -0
  102. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +1 -0
  103. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +0 -4
  104. data/vendor/faiss/faiss/gpu/GpuIndex.h +28 -4
  105. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +2 -1
  106. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +10 -8
  107. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +75 -14
  108. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +19 -32
  109. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -31
  110. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +22 -28
  111. data/vendor/faiss/faiss/gpu/GpuResources.cpp +14 -0
  112. data/vendor/faiss/faiss/gpu/GpuResources.h +16 -3
  113. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +3 -3
  114. data/vendor/faiss/faiss/gpu/impl/IndexUtils.h +32 -0
  115. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +1 -0
  116. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +311 -75
  117. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +10 -0
  118. data/vendor/faiss/faiss/gpu/test/TestUtils.h +3 -0
  119. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +2 -2
  120. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +5 -4
  121. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +116 -47
  122. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +44 -13
  123. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +0 -54
  124. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +0 -76
  125. data/vendor/faiss/faiss/impl/DistanceComputer.h +64 -0
  126. data/vendor/faiss/faiss/impl/HNSW.cpp +123 -27
  127. data/vendor/faiss/faiss/impl/HNSW.h +19 -16
  128. data/vendor/faiss/faiss/impl/IDSelector.cpp +125 -0
  129. data/vendor/faiss/faiss/impl/IDSelector.h +135 -0
  130. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +6 -28
  131. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +6 -1
  132. data/vendor/faiss/faiss/impl/LookupTableScaler.h +77 -0
  133. data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -0
  134. data/vendor/faiss/faiss/impl/NSG.cpp +1 -1
  135. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +383 -0
  136. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +154 -0
  137. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +225 -145
  138. data/vendor/faiss/faiss/impl/ProductQuantizer.h +29 -10
  139. data/vendor/faiss/faiss/impl/Quantizer.h +43 -0
  140. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +192 -36
  141. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +40 -20
  142. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -0
  143. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +97 -173
  144. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +18 -18
  145. data/vendor/faiss/faiss/impl/index_read.cpp +240 -9
  146. data/vendor/faiss/faiss/impl/index_write.cpp +237 -5
  147. data/vendor/faiss/faiss/impl/kmeans1d.cpp +6 -4
  148. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +56 -16
  149. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +25 -8
  150. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +66 -25
  151. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +75 -27
  152. data/vendor/faiss/faiss/index_factory.cpp +196 -7
  153. data/vendor/faiss/faiss/index_io.h +5 -0
  154. data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -0
  155. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +4 -1
  156. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +2 -1
  157. data/vendor/faiss/faiss/python/python_callbacks.cpp +27 -0
  158. data/vendor/faiss/faiss/python/python_callbacks.h +15 -0
  159. data/vendor/faiss/faiss/utils/Heap.h +31 -15
  160. data/vendor/faiss/faiss/utils/distances.cpp +380 -56
  161. data/vendor/faiss/faiss/utils/distances.h +113 -15
  162. data/vendor/faiss/faiss/utils/distances_simd.cpp +726 -6
  163. data/vendor/faiss/faiss/utils/extra_distances.cpp +12 -7
  164. data/vendor/faiss/faiss/utils/extra_distances.h +3 -1
  165. data/vendor/faiss/faiss/utils/fp16-fp16c.h +21 -0
  166. data/vendor/faiss/faiss/utils/fp16-inl.h +101 -0
  167. data/vendor/faiss/faiss/utils/fp16.h +11 -0
  168. data/vendor/faiss/faiss/utils/hamming-inl.h +54 -0
  169. data/vendor/faiss/faiss/utils/hamming.cpp +0 -48
  170. data/vendor/faiss/faiss/utils/ordered_key_value.h +10 -0
  171. data/vendor/faiss/faiss/utils/quantize_lut.cpp +62 -0
  172. data/vendor/faiss/faiss/utils/quantize_lut.h +20 -0
  173. data/vendor/faiss/faiss/utils/random.cpp +53 -0
  174. data/vendor/faiss/faiss/utils/random.h +5 -0
  175. data/vendor/faiss/faiss/utils/simdlib_avx2.h +4 -0
  176. data/vendor/faiss/faiss/utils/simdlib_emulated.h +6 -1
  177. data/vendor/faiss/faiss/utils/simdlib_neon.h +7 -2
  178. metadata +37 -3
@@ -111,7 +111,8 @@ void Index2Layer::search(
111
111
  const float* /*x*/,
112
112
  idx_t /*k*/,
113
113
  float* /*distances*/,
114
- idx_t* /*labels*/) const {
114
+ idx_t* /*labels*/,
115
+ const SearchParameters* /* params */) const {
115
116
  FAISS_THROW_MSG("not implemented");
116
117
  }
117
118
 
@@ -282,10 +283,13 @@ DistanceComputer* Index2Layer::get_distance_computer() const {
282
283
 
283
284
  /* The standalone codec interface */
284
285
 
286
+ // block size used in Index2Layer::sa_encode
287
+ int index2layer_sa_encode_bs = 32768;
288
+
285
289
  void Index2Layer::sa_encode(idx_t n, const float* x, uint8_t* bytes) const {
286
290
  FAISS_THROW_IF_NOT(is_trained);
287
291
 
288
- idx_t bs = 32768;
292
+ idx_t bs = index2layer_sa_encode_bs;
289
293
  if (n > bs) {
290
294
  for (idx_t i0 = 0; i0 < n; i0 += bs) {
291
295
  idx_t i1 = std::min(i0 + bs, n);
@@ -14,6 +14,7 @@
14
14
  #include <faiss/IndexFlatCodes.h>
15
15
  #include <faiss/IndexIVF.h>
16
16
  #include <faiss/IndexPQ.h>
17
+ #include <faiss/impl/platform_macros.h>
17
18
 
18
19
  namespace faiss {
19
20
 
@@ -56,7 +57,8 @@ struct Index2Layer : IndexFlatCodes {
56
57
  const float* x,
57
58
  idx_t k,
58
59
  float* distances,
59
- idx_t* labels) const override;
60
+ idx_t* labels,
61
+ const SearchParameters* params = nullptr) const override;
60
62
 
61
63
  DistanceComputer* get_distance_computer() const override;
62
64
 
@@ -68,4 +70,7 @@ struct Index2Layer : IndexFlatCodes {
68
70
  void sa_decode(idx_t n, const uint8_t* bytes, float* x) const override;
69
71
  };
70
72
 
73
+ // block size used in Index2Layer::sa_encode
74
+ FAISS_API extern int index2layer_sa_encode_bs;
75
+
71
76
  } // namespace faiss
@@ -40,6 +40,90 @@ IndexAdditiveQuantizer::IndexAdditiveQuantizer(
40
40
 
41
41
  namespace {
42
42
 
43
+ /************************************************************
44
+ * DistanceComputer implementation
45
+ ************************************************************/
46
+
47
+ template <class VectorDistance>
48
+ struct AQDistanceComputerDecompress: FlatCodesDistanceComputer {
49
+ std::vector<float> tmp;
50
+ const AdditiveQuantizer & aq;
51
+ VectorDistance vd;
52
+ size_t d;
53
+
54
+ AQDistanceComputerDecompress(const IndexAdditiveQuantizer &iaq, VectorDistance vd):
55
+ FlatCodesDistanceComputer(iaq.codes.data(), iaq.code_size),
56
+ tmp(iaq.d * 2),
57
+ aq(*iaq.aq),
58
+ vd(vd),
59
+ d(iaq.d)
60
+ {}
61
+
62
+ const float *q;
63
+ void set_query(const float* x) final {
64
+ q = x;
65
+ }
66
+
67
+ float symmetric_dis(idx_t i, idx_t j) final {
68
+ aq.decode(codes + i * d, tmp.data(), 1);
69
+ aq.decode(codes + j * d, tmp.data() + d, 1);
70
+ return vd(tmp.data(), tmp.data() + d);
71
+ }
72
+
73
+ float distance_to_code(const uint8_t *code) final {
74
+ aq.decode(code, tmp.data(), 1);
75
+ return vd(q, tmp.data());
76
+ }
77
+
78
+ virtual ~AQDistanceComputerDecompress() {}
79
+ };
80
+
81
+
82
+ template<bool is_IP, AdditiveQuantizer::Search_type_t st>
83
+ struct AQDistanceComputerLUT: FlatCodesDistanceComputer {
84
+ std::vector<float> LUT;
85
+ const AdditiveQuantizer & aq;
86
+ size_t d;
87
+
88
+ explicit AQDistanceComputerLUT(const IndexAdditiveQuantizer &iaq):
89
+ FlatCodesDistanceComputer(iaq.codes.data(), iaq.code_size),
90
+ LUT(iaq.aq->total_codebook_size + iaq.d * 2),
91
+ aq(*iaq.aq),
92
+ d(iaq.d)
93
+ {}
94
+
95
+ float bias;
96
+ void set_query(const float* x) final {
97
+ // this is quite sub-optimal for multiple queries
98
+ aq.compute_LUT(1, x, LUT.data());
99
+ if (is_IP) {
100
+ bias = 0;
101
+ } else {
102
+ bias = fvec_norm_L2sqr(x, d);
103
+ }
104
+ }
105
+
106
+ float symmetric_dis(idx_t i, idx_t j) final {
107
+ float *tmp = LUT.data();
108
+ aq.decode(codes + i * d, tmp, 1);
109
+ aq.decode(codes + j * d, tmp + d, 1);
110
+ return fvec_L2sqr(tmp, tmp + d, d);
111
+ }
112
+
113
+ float distance_to_code(const uint8_t *code) final {
114
+ return bias + aq.compute_1_distance_LUT<is_IP, st>(code, LUT.data());
115
+ }
116
+
117
+ virtual ~AQDistanceComputerLUT() {}
118
+ };
119
+
120
+
121
+
122
+ /************************************************************
123
+ * scanning implementation for search
124
+ ************************************************************/
125
+
126
+
43
127
  template <class VectorDistance, class ResultHandler>
44
128
  void search_with_decompress(
45
129
  const IndexAdditiveQuantizer& ir,
@@ -111,12 +195,61 @@ void search_with_LUT(
111
195
 
112
196
  } // anonymous namespace
113
197
 
198
+
199
+ FlatCodesDistanceComputer * IndexAdditiveQuantizer::get_FlatCodesDistanceComputer() const {
200
+
201
+ if (aq->search_type == AdditiveQuantizer::ST_decompress) {
202
+ if (metric_type == METRIC_L2) {
203
+ using VD = VectorDistance<METRIC_L2>;
204
+ VD vd = {size_t(d), metric_arg};
205
+ return new AQDistanceComputerDecompress<VD>(*this, vd);
206
+ } else if (metric_type == METRIC_INNER_PRODUCT) {
207
+ using VD = VectorDistance<METRIC_INNER_PRODUCT>;
208
+ VD vd = {size_t(d), metric_arg};
209
+ return new AQDistanceComputerDecompress<VD>(*this, vd);
210
+ } else {
211
+ FAISS_THROW_MSG("unsupported metric");
212
+ }
213
+ } else {
214
+ if (metric_type == METRIC_INNER_PRODUCT) {
215
+ return new AQDistanceComputerLUT<true, AdditiveQuantizer::ST_LUT_nonorm>(*this);
216
+ } else {
217
+ switch(aq->search_type) {
218
+ #define DISPATCH(st) \
219
+ case AdditiveQuantizer::st: \
220
+ return new AQDistanceComputerLUT<false, AdditiveQuantizer::st> (*this);\
221
+ break;
222
+ DISPATCH(ST_norm_float)
223
+ DISPATCH(ST_LUT_nonorm)
224
+ DISPATCH(ST_norm_qint8)
225
+ DISPATCH(ST_norm_qint4)
226
+ DISPATCH(ST_norm_cqint4)
227
+ case AdditiveQuantizer::ST_norm_cqint8:
228
+ case AdditiveQuantizer::ST_norm_lsq2x4:
229
+ case AdditiveQuantizer::ST_norm_rq2x4:
230
+ return new AQDistanceComputerLUT<false, AdditiveQuantizer::ST_norm_cqint8> (*this);\
231
+ break;
232
+ #undef DISPATCH
233
+ default:
234
+ FAISS_THROW_FMT("search type %d not supported", aq->search_type);
235
+ }
236
+ }
237
+ }
238
+ }
239
+
240
+
241
+
242
+
114
243
  void IndexAdditiveQuantizer::search(
115
244
  idx_t n,
116
245
  const float* x,
117
246
  idx_t k,
118
247
  float* distances,
119
- idx_t* labels) const {
248
+ idx_t* labels,
249
+ const SearchParameters* params) const {
250
+
251
+ FAISS_THROW_IF_NOT_MSG(!params, "search params not supported for this index");
252
+
120
253
  if (aq->search_type == AdditiveQuantizer::ST_decompress) {
121
254
  if (metric_type == METRIC_L2) {
122
255
  using VD = VectorDistance<METRIC_L2>;
@@ -135,20 +268,23 @@ void IndexAdditiveQuantizer::search(
135
268
  search_with_LUT<true, AdditiveQuantizer::ST_LUT_nonorm> (*this, x, rh);
136
269
  } else {
137
270
  HeapResultHandler<CMax<float, idx_t> > rh(n, distances, labels, k);
138
-
139
- if (aq->search_type == AdditiveQuantizer::ST_norm_float) {
140
- search_with_LUT<false, AdditiveQuantizer::ST_norm_float> (*this, x, rh);
141
- } else if (aq->search_type == AdditiveQuantizer::ST_LUT_nonorm) {
142
- search_with_LUT<false, AdditiveQuantizer::ST_norm_float> (*this, x, rh);
143
- } else if (aq->search_type == AdditiveQuantizer::ST_norm_qint8) {
144
- search_with_LUT<false, AdditiveQuantizer::ST_norm_qint8> (*this, x, rh);
145
- } else if (aq->search_type == AdditiveQuantizer::ST_norm_qint4) {
146
- search_with_LUT<false, AdditiveQuantizer::ST_norm_qint4> (*this, x, rh);
147
- } else if (aq->search_type == AdditiveQuantizer::ST_norm_cqint8) {
271
+ switch(aq->search_type) {
272
+ #define DISPATCH(st) \
273
+ case AdditiveQuantizer::st: \
274
+ search_with_LUT<false, AdditiveQuantizer::st> (*this, x, rh);\
275
+ break;
276
+ DISPATCH(ST_norm_float)
277
+ DISPATCH(ST_LUT_nonorm)
278
+ DISPATCH(ST_norm_qint8)
279
+ DISPATCH(ST_norm_qint4)
280
+ DISPATCH(ST_norm_cqint4)
281
+ case AdditiveQuantizer::ST_norm_cqint8:
282
+ case AdditiveQuantizer::ST_norm_lsq2x4:
283
+ case AdditiveQuantizer::ST_norm_rq2x4:
148
284
  search_with_LUT<false, AdditiveQuantizer::ST_norm_cqint8> (*this, x, rh);
149
- } else if (aq->search_type == AdditiveQuantizer::ST_norm_cqint4) {
150
- search_with_LUT<false, AdditiveQuantizer::ST_norm_cqint4> (*this, x, rh);
151
- } else {
285
+ break;
286
+ #undef DISPATCH
287
+ default:
152
288
  FAISS_THROW_FMT("search type %d not supported", aq->search_type);
153
289
  }
154
290
  }
@@ -220,6 +356,57 @@ void IndexLocalSearchQuantizer::train(idx_t n, const float* x) {
220
356
  is_trained = true;
221
357
  }
222
358
 
359
+
360
+ /**************************************************************************************
361
+ * IndexProductResidualQuantizer
362
+ **************************************************************************************/
363
+
364
+ IndexProductResidualQuantizer::IndexProductResidualQuantizer(
365
+ int d, ///< dimensionality of the input vectors
366
+ size_t nsplits, ///< number of residual quantizers
367
+ size_t Msub, ///< number of subquantizers per RQ
368
+ size_t nbits, ///< number of bit per subvector index
369
+ MetricType metric,
370
+ Search_type_t search_type)
371
+ : IndexAdditiveQuantizer(d, &prq, metric), prq(d, nsplits, Msub, nbits, search_type) {
372
+ code_size = prq.code_size;
373
+ is_trained = false;
374
+ }
375
+
376
+ IndexProductResidualQuantizer::IndexProductResidualQuantizer()
377
+ : IndexProductResidualQuantizer(0, 0, 0, 0) {}
378
+
379
+ void IndexProductResidualQuantizer::train(idx_t n, const float* x) {
380
+ prq.train(n, x);
381
+ is_trained = true;
382
+ }
383
+
384
+
385
+ /**************************************************************************************
386
+ * IndexProductLocalSearchQuantizer
387
+ **************************************************************************************/
388
+
389
+ IndexProductLocalSearchQuantizer::IndexProductLocalSearchQuantizer(
390
+ int d, ///< dimensionality of the input vectors
391
+ size_t nsplits, ///< number of local search quantizers
392
+ size_t Msub, ///< number of subquantizers per LSQ
393
+ size_t nbits, ///< number of bit per subvector index
394
+ MetricType metric,
395
+ Search_type_t search_type)
396
+ : IndexAdditiveQuantizer(d, &plsq, metric), plsq(d, nsplits, Msub, nbits, search_type) {
397
+ code_size = plsq.code_size;
398
+ is_trained = false;
399
+ }
400
+
401
+ IndexProductLocalSearchQuantizer::IndexProductLocalSearchQuantizer()
402
+ : IndexProductLocalSearchQuantizer(0, 0, 0, 0) {}
403
+
404
+ void IndexProductLocalSearchQuantizer::train(idx_t n, const float* x) {
405
+ plsq.train(n, x);
406
+ is_trained = true;
407
+ }
408
+
409
+
223
410
  /**************************************************************************************
224
411
  * AdditiveCoarseQuantizer
225
412
  **************************************************************************************/
@@ -248,6 +435,13 @@ void AdditiveCoarseQuantizer::train(idx_t n, const float* x) {
248
435
  if (verbose) {
249
436
  printf("AdditiveCoarseQuantizer::train: training on %zd vectors\n", size_t(n));
250
437
  }
438
+ size_t norms_size = sizeof(float) << aq->tot_bits;
439
+
440
+ FAISS_THROW_IF_NOT_MSG (
441
+ norms_size <= aq->max_mem_distances,
442
+ "the RCQ norms matrix will become too large, please reduce the number of quantization steps"
443
+ );
444
+
251
445
  aq->train(n, x);
252
446
  is_trained = true;
253
447
  ntotal = (idx_t)1 << aq->tot_bits;
@@ -268,7 +462,11 @@ void AdditiveCoarseQuantizer::search(
268
462
  const float* x,
269
463
  idx_t k,
270
464
  float* distances,
271
- idx_t* labels) const {
465
+ idx_t* labels,
466
+ const SearchParameters * params) const {
467
+
468
+ FAISS_THROW_IF_NOT_MSG(!params, "search params not supported for this index");
469
+
272
470
  if (metric_type == METRIC_INNER_PRODUCT) {
273
471
  aq->knn_centroids_inner_product(n, x, k, distances, labels);
274
472
  } else if (metric_type == METRIC_L2) {
@@ -321,7 +519,12 @@ void ResidualCoarseQuantizer::search(
321
519
  const float* x,
322
520
  idx_t k,
323
521
  float* distances,
324
- idx_t* labels) const {
522
+ idx_t* labels,
523
+ const SearchParameters * params
524
+ ) const {
525
+
526
+ FAISS_THROW_IF_NOT_MSG(!params, "search params not supported for this index");
527
+
325
528
  if (beam_factor < 0) {
326
529
  AdditiveCoarseQuantizer::search(n, x, k, distances, labels);
327
530
  return;
@@ -15,6 +15,7 @@
15
15
 
16
16
  #include <faiss/IndexFlatCodes.h>
17
17
  #include <faiss/impl/LocalSearchQuantizer.h>
18
+ #include <faiss/impl/ProductAdditiveQuantizer.h>
18
19
  #include <faiss/impl/ResidualQuantizer.h>
19
20
  #include <faiss/impl/platform_macros.h>
20
21
 
@@ -28,8 +29,8 @@ struct IndexAdditiveQuantizer : IndexFlatCodes {
28
29
  using Search_type_t = AdditiveQuantizer::Search_type_t;
29
30
 
30
31
  explicit IndexAdditiveQuantizer(
31
- idx_t d = 0,
32
- AdditiveQuantizer* aq = nullptr,
32
+ idx_t d,
33
+ AdditiveQuantizer* aq,
33
34
  MetricType metric = METRIC_L2);
34
35
 
35
36
  void search(
@@ -37,12 +38,15 @@ struct IndexAdditiveQuantizer : IndexFlatCodes {
37
38
  const float* x,
38
39
  idx_t k,
39
40
  float* distances,
40
- idx_t* labels) const override;
41
+ idx_t* labels,
42
+ const SearchParameters* params = nullptr) const override;
41
43
 
42
44
  /* The standalone codec interface */
43
45
  void sa_encode(idx_t n, const float* x, uint8_t* bytes) const override;
44
46
 
45
47
  void sa_decode(idx_t n, const uint8_t* bytes, float* x) const override;
48
+
49
+ FlatCodesDistanceComputer* get_FlatCodesDistanceComputer() const override;
46
50
  };
47
51
 
48
52
  /** Index based on a residual quantizer. Stored vectors are
@@ -98,6 +102,58 @@ struct IndexLocalSearchQuantizer : IndexAdditiveQuantizer {
98
102
  void train(idx_t n, const float* x) override;
99
103
  };
100
104
 
105
+ /** Index based on a product residual quantizer.
106
+ */
107
+ struct IndexProductResidualQuantizer : IndexAdditiveQuantizer {
108
+ /// The product residual quantizer used to encode the vectors
109
+ ProductResidualQuantizer prq;
110
+
111
+ /** Constructor.
112
+ *
113
+ * @param d dimensionality of the input vectors
114
+ * @param nsplits number of residual quantizers
115
+ * @param Msub number of subquantizers per RQ
116
+ * @param nbits number of bit per subvector index
117
+ */
118
+ IndexProductResidualQuantizer(
119
+ int d, ///< dimensionality of the input vectors
120
+ size_t nsplits, ///< number of residual quantizers
121
+ size_t Msub, ///< number of subquantizers per RQ
122
+ size_t nbits, ///< number of bit per subvector index
123
+ MetricType metric = METRIC_L2,
124
+ Search_type_t search_type = AdditiveQuantizer::ST_decompress);
125
+
126
+ IndexProductResidualQuantizer();
127
+
128
+ void train(idx_t n, const float* x) override;
129
+ };
130
+
131
+ /** Index based on a product local search quantizer.
132
+ */
133
+ struct IndexProductLocalSearchQuantizer : IndexAdditiveQuantizer {
134
+ /// The product local search quantizer used to encode the vectors
135
+ ProductLocalSearchQuantizer plsq;
136
+
137
+ /** Constructor.
138
+ *
139
+ * @param d dimensionality of the input vectors
140
+ * @param nsplits number of local search quantizers
141
+ * @param Msub number of subquantizers per LSQ
142
+ * @param nbits number of bit per subvector index
143
+ */
144
+ IndexProductLocalSearchQuantizer(
145
+ int d, ///< dimensionality of the input vectors
146
+ size_t nsplits, ///< number of local search quantizers
147
+ size_t Msub, ///< number of subquantizers per LSQ
148
+ size_t nbits, ///< number of bit per subvector index
149
+ MetricType metric = METRIC_L2,
150
+ Search_type_t search_type = AdditiveQuantizer::ST_decompress);
151
+
152
+ IndexProductLocalSearchQuantizer();
153
+
154
+ void train(idx_t n, const float* x) override;
155
+ };
156
+
101
157
  /** A "virtual" index where the elements are the residual quantizer centroids.
102
158
  *
103
159
  * Intended for use as a coarse quantizer in an IndexIVF.
@@ -121,7 +177,8 @@ struct AdditiveCoarseQuantizer : Index {
121
177
  const float* x,
122
178
  idx_t k,
123
179
  float* distances,
124
- idx_t* labels) const override;
180
+ idx_t* labels,
181
+ const SearchParameters* params = nullptr) const override;
125
182
 
126
183
  void reconstruct(idx_t key, float* recons) const override;
127
184
  void train(idx_t n, const float* x) override;
@@ -166,7 +223,8 @@ struct ResidualCoarseQuantizer : AdditiveCoarseQuantizer {
166
223
  const float* x,
167
224
  idx_t k,
168
225
  float* distances,
169
- idx_t* labels) const override;
226
+ idx_t* labels,
227
+ const SearchParameters* params = nullptr) const override;
170
228
 
171
229
  ResidualCoarseQuantizer();
172
230
  };