faiss 0.2.4 → 0.2.5

Sign up to get free protection for your applications and to get access to all the features.
Files changed (177) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +5 -0
  3. data/README.md +23 -21
  4. data/ext/faiss/extconf.rb +11 -0
  5. data/ext/faiss/index.cpp +4 -4
  6. data/ext/faiss/index_binary.cpp +6 -6
  7. data/ext/faiss/product_quantizer.cpp +4 -4
  8. data/lib/faiss/version.rb +1 -1
  9. data/vendor/faiss/faiss/AutoTune.cpp +13 -0
  10. data/vendor/faiss/faiss/IVFlib.cpp +101 -2
  11. data/vendor/faiss/faiss/IVFlib.h +26 -2
  12. data/vendor/faiss/faiss/Index.cpp +36 -3
  13. data/vendor/faiss/faiss/Index.h +43 -6
  14. data/vendor/faiss/faiss/Index2Layer.cpp +6 -2
  15. data/vendor/faiss/faiss/Index2Layer.h +6 -1
  16. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +219 -16
  17. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +63 -5
  18. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +299 -0
  19. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +199 -0
  20. data/vendor/faiss/faiss/IndexBinary.cpp +20 -4
  21. data/vendor/faiss/faiss/IndexBinary.h +18 -3
  22. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +9 -2
  23. data/vendor/faiss/faiss/IndexBinaryFlat.h +4 -2
  24. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +4 -1
  25. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +2 -1
  26. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +5 -1
  27. data/vendor/faiss/faiss/IndexBinaryHNSW.h +2 -1
  28. data/vendor/faiss/faiss/IndexBinaryHash.cpp +17 -4
  29. data/vendor/faiss/faiss/IndexBinaryHash.h +8 -4
  30. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +28 -13
  31. data/vendor/faiss/faiss/IndexBinaryIVF.h +10 -7
  32. data/vendor/faiss/faiss/IndexFastScan.cpp +626 -0
  33. data/vendor/faiss/faiss/IndexFastScan.h +145 -0
  34. data/vendor/faiss/faiss/IndexFlat.cpp +34 -21
  35. data/vendor/faiss/faiss/IndexFlat.h +7 -4
  36. data/vendor/faiss/faiss/IndexFlatCodes.cpp +35 -1
  37. data/vendor/faiss/faiss/IndexFlatCodes.h +12 -0
  38. data/vendor/faiss/faiss/IndexHNSW.cpp +66 -138
  39. data/vendor/faiss/faiss/IndexHNSW.h +4 -2
  40. data/vendor/faiss/faiss/IndexIDMap.cpp +247 -0
  41. data/vendor/faiss/faiss/IndexIDMap.h +107 -0
  42. data/vendor/faiss/faiss/IndexIVF.cpp +121 -33
  43. data/vendor/faiss/faiss/IndexIVF.h +35 -16
  44. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +84 -7
  45. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +63 -1
  46. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +590 -0
  47. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +171 -0
  48. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +1290 -0
  49. data/vendor/faiss/faiss/IndexIVFFastScan.h +213 -0
  50. data/vendor/faiss/faiss/IndexIVFFlat.cpp +37 -17
  51. data/vendor/faiss/faiss/IndexIVFFlat.h +4 -2
  52. data/vendor/faiss/faiss/IndexIVFPQ.cpp +234 -50
  53. data/vendor/faiss/faiss/IndexIVFPQ.h +5 -1
  54. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +23 -852
  55. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +7 -112
  56. data/vendor/faiss/faiss/IndexIVFPQR.cpp +3 -3
  57. data/vendor/faiss/faiss/IndexIVFPQR.h +1 -1
  58. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +3 -1
  59. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +2 -1
  60. data/vendor/faiss/faiss/IndexLSH.cpp +4 -2
  61. data/vendor/faiss/faiss/IndexLSH.h +2 -1
  62. data/vendor/faiss/faiss/IndexLattice.cpp +7 -1
  63. data/vendor/faiss/faiss/IndexLattice.h +3 -1
  64. data/vendor/faiss/faiss/IndexNNDescent.cpp +4 -3
  65. data/vendor/faiss/faiss/IndexNNDescent.h +2 -1
  66. data/vendor/faiss/faiss/IndexNSG.cpp +37 -3
  67. data/vendor/faiss/faiss/IndexNSG.h +25 -1
  68. data/vendor/faiss/faiss/IndexPQ.cpp +106 -69
  69. data/vendor/faiss/faiss/IndexPQ.h +19 -5
  70. data/vendor/faiss/faiss/IndexPQFastScan.cpp +15 -450
  71. data/vendor/faiss/faiss/IndexPQFastScan.h +15 -78
  72. data/vendor/faiss/faiss/IndexPreTransform.cpp +47 -8
  73. data/vendor/faiss/faiss/IndexPreTransform.h +15 -3
  74. data/vendor/faiss/faiss/IndexRefine.cpp +8 -4
  75. data/vendor/faiss/faiss/IndexRefine.h +4 -2
  76. data/vendor/faiss/faiss/IndexReplicas.cpp +4 -2
  77. data/vendor/faiss/faiss/IndexReplicas.h +2 -1
  78. data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +438 -0
  79. data/vendor/faiss/faiss/IndexRowwiseMinMax.h +92 -0
  80. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +26 -15
  81. data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -7
  82. data/vendor/faiss/faiss/IndexShards.cpp +4 -1
  83. data/vendor/faiss/faiss/IndexShards.h +2 -1
  84. data/vendor/faiss/faiss/MetaIndexes.cpp +5 -178
  85. data/vendor/faiss/faiss/MetaIndexes.h +3 -81
  86. data/vendor/faiss/faiss/VectorTransform.cpp +43 -0
  87. data/vendor/faiss/faiss/VectorTransform.h +22 -4
  88. data/vendor/faiss/faiss/clone_index.cpp +23 -1
  89. data/vendor/faiss/faiss/clone_index.h +3 -0
  90. data/vendor/faiss/faiss/cppcontrib/SaDecodeKernels.h +300 -0
  91. data/vendor/faiss/faiss/cppcontrib/detail/CoarseBitType.h +24 -0
  92. data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +195 -0
  93. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +2058 -0
  94. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +408 -0
  95. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-neon-inl.h +2147 -0
  96. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMax-inl.h +460 -0
  97. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMaxFP16-inl.h +465 -0
  98. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +1618 -0
  99. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +251 -0
  100. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-neon-inl.h +1452 -0
  101. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +1 -0
  102. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +0 -4
  103. data/vendor/faiss/faiss/gpu/GpuIndex.h +28 -4
  104. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +2 -1
  105. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +10 -8
  106. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +75 -14
  107. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +19 -32
  108. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -31
  109. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +22 -28
  110. data/vendor/faiss/faiss/gpu/GpuResources.cpp +14 -0
  111. data/vendor/faiss/faiss/gpu/GpuResources.h +16 -3
  112. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +3 -3
  113. data/vendor/faiss/faiss/gpu/impl/IndexUtils.h +32 -0
  114. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +1 -0
  115. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +311 -75
  116. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +10 -0
  117. data/vendor/faiss/faiss/gpu/test/TestUtils.h +3 -0
  118. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +2 -2
  119. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +5 -4
  120. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +116 -47
  121. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +44 -13
  122. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +0 -54
  123. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +0 -76
  124. data/vendor/faiss/faiss/impl/DistanceComputer.h +64 -0
  125. data/vendor/faiss/faiss/impl/HNSW.cpp +123 -27
  126. data/vendor/faiss/faiss/impl/HNSW.h +19 -16
  127. data/vendor/faiss/faiss/impl/IDSelector.cpp +125 -0
  128. data/vendor/faiss/faiss/impl/IDSelector.h +135 -0
  129. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +6 -28
  130. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +6 -1
  131. data/vendor/faiss/faiss/impl/LookupTableScaler.h +77 -0
  132. data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -0
  133. data/vendor/faiss/faiss/impl/NSG.cpp +1 -1
  134. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +383 -0
  135. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +154 -0
  136. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +225 -145
  137. data/vendor/faiss/faiss/impl/ProductQuantizer.h +29 -10
  138. data/vendor/faiss/faiss/impl/Quantizer.h +43 -0
  139. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +192 -36
  140. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +40 -20
  141. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -0
  142. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +97 -173
  143. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +18 -18
  144. data/vendor/faiss/faiss/impl/index_read.cpp +240 -9
  145. data/vendor/faiss/faiss/impl/index_write.cpp +237 -5
  146. data/vendor/faiss/faiss/impl/kmeans1d.cpp +6 -4
  147. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +56 -16
  148. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +25 -8
  149. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +66 -25
  150. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +75 -27
  151. data/vendor/faiss/faiss/index_factory.cpp +196 -7
  152. data/vendor/faiss/faiss/index_io.h +5 -0
  153. data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -0
  154. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +4 -1
  155. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +2 -1
  156. data/vendor/faiss/faiss/python/python_callbacks.cpp +27 -0
  157. data/vendor/faiss/faiss/python/python_callbacks.h +15 -0
  158. data/vendor/faiss/faiss/utils/Heap.h +31 -15
  159. data/vendor/faiss/faiss/utils/distances.cpp +380 -56
  160. data/vendor/faiss/faiss/utils/distances.h +113 -15
  161. data/vendor/faiss/faiss/utils/distances_simd.cpp +726 -6
  162. data/vendor/faiss/faiss/utils/extra_distances.cpp +12 -7
  163. data/vendor/faiss/faiss/utils/extra_distances.h +3 -1
  164. data/vendor/faiss/faiss/utils/fp16-fp16c.h +21 -0
  165. data/vendor/faiss/faiss/utils/fp16-inl.h +101 -0
  166. data/vendor/faiss/faiss/utils/fp16.h +11 -0
  167. data/vendor/faiss/faiss/utils/hamming-inl.h +54 -0
  168. data/vendor/faiss/faiss/utils/hamming.cpp +0 -48
  169. data/vendor/faiss/faiss/utils/ordered_key_value.h +10 -0
  170. data/vendor/faiss/faiss/utils/quantize_lut.cpp +62 -0
  171. data/vendor/faiss/faiss/utils/quantize_lut.h +20 -0
  172. data/vendor/faiss/faiss/utils/random.cpp +53 -0
  173. data/vendor/faiss/faiss/utils/random.h +5 -0
  174. data/vendor/faiss/faiss/utils/simdlib_avx2.h +4 -0
  175. data/vendor/faiss/faiss/utils/simdlib_emulated.h +6 -1
  176. data/vendor/faiss/faiss/utils/simdlib_neon.h +7 -2
  177. metadata +37 -3
@@ -5,8 +5,6 @@
5
5
  * LICENSE file in the root directory of this source tree.
6
6
  */
7
7
 
8
- // -*- c++ -*-
9
-
10
8
  #include <faiss/IndexPQ.h>
11
9
 
12
10
  #include <cinttypes>
@@ -17,7 +15,7 @@
17
15
 
18
16
  #include <algorithm>
19
17
 
20
- #include <faiss/impl/AuxIndexStructures.h>
18
+ #include <faiss/impl/DistanceComputer.h>
21
19
  #include <faiss/impl/FaissAssert.h>
22
20
  #include <faiss/utils/hamming.h>
23
21
 
@@ -73,19 +71,16 @@ void IndexPQ::train(idx_t n, const float* x) {
73
71
  namespace {
74
72
 
75
73
  template <class PQDecoder>
76
- struct PQDistanceComputer : DistanceComputer {
74
+ struct PQDistanceComputer : FlatCodesDistanceComputer {
77
75
  size_t d;
78
76
  MetricType metric;
79
77
  Index::idx_t nb;
80
- const uint8_t* codes;
81
- size_t code_size;
82
78
  const ProductQuantizer& pq;
83
79
  const float* sdc;
84
80
  std::vector<float> precomputed_table;
85
81
  size_t ndis;
86
82
 
87
- float operator()(idx_t i) override {
88
- const uint8_t* code = codes + i * code_size;
83
+ float distance_to_code(const uint8_t* code) final {
89
84
  const float* dt = precomputed_table.data();
90
85
  PQDecoder decoder(code, pq.nbits);
91
86
  float accu = 0;
@@ -112,13 +107,15 @@ struct PQDistanceComputer : DistanceComputer {
112
107
  return accu;
113
108
  }
114
109
 
115
- explicit PQDistanceComputer(const IndexPQ& storage) : pq(storage.pq) {
110
+ explicit PQDistanceComputer(const IndexPQ& storage)
111
+ : FlatCodesDistanceComputer(
112
+ storage.codes.data(),
113
+ storage.code_size),
114
+ pq(storage.pq) {
116
115
  precomputed_table.resize(pq.M * pq.ksub);
117
116
  nb = storage.ntotal;
118
117
  d = storage.d;
119
118
  metric = storage.metric_type;
120
- codes = storage.codes.data();
121
- code_size = pq.code_size;
122
119
  if (pq.sdc_table.size() == pq.ksub * pq.ksub * pq.M) {
123
120
  sdc = pq.sdc_table.data();
124
121
  } else {
@@ -138,7 +135,7 @@ struct PQDistanceComputer : DistanceComputer {
138
135
 
139
136
  } // namespace
140
137
 
141
- DistanceComputer* IndexPQ::get_distance_computer() const {
138
+ FlatCodesDistanceComputer* IndexPQ::get_FlatCodesDistanceComputer() const {
142
139
  if (pq.nbits == 8) {
143
140
  return new PQDistanceComputer<PQDecoder8>(*this);
144
141
  } else if (pq.nbits == 16) {
@@ -157,10 +154,21 @@ void IndexPQ::search(
157
154
  const float* x,
158
155
  idx_t k,
159
156
  float* distances,
160
- idx_t* labels) const {
157
+ idx_t* labels,
158
+ const SearchParameters* iparams) const {
161
159
  FAISS_THROW_IF_NOT(k > 0);
162
-
163
160
  FAISS_THROW_IF_NOT(is_trained);
161
+
162
+ const SearchParametersPQ* params = nullptr;
163
+ Search_type_t search_type = this->search_type;
164
+
165
+ if (iparams) {
166
+ params = dynamic_cast<const SearchParametersPQ*>(iparams);
167
+ FAISS_THROW_IF_NOT_MSG(params, "invalid search params");
168
+ FAISS_THROW_IF_NOT_MSG(!params->sel, "selector not supported");
169
+ search_type = params->search_type;
170
+ }
171
+
164
172
  if (search_type == ST_PQ) { // Simple PQ search
165
173
 
166
174
  if (metric_type == METRIC_L2) {
@@ -179,8 +187,16 @@ void IndexPQ::search(
179
187
  search_type == ST_polysemous ||
180
188
  search_type == ST_polysemous_generalize) {
181
189
  FAISS_THROW_IF_NOT(metric_type == METRIC_L2);
182
-
183
- search_core_polysemous(n, x, k, distances, labels);
190
+ int polysemous_ht =
191
+ params ? params->polysemous_ht : this->polysemous_ht;
192
+ search_core_polysemous(
193
+ n,
194
+ x,
195
+ k,
196
+ distances,
197
+ labels,
198
+ polysemous_ht,
199
+ search_type == ST_polysemous_generalize);
184
200
 
185
201
  } else { // code-to-code distances
186
202
 
@@ -256,12 +272,12 @@ static size_t polysemous_inner_loop(
256
272
  const uint8_t* q_code,
257
273
  size_t k,
258
274
  float* heap_dis,
259
- int64_t* heap_ids) {
275
+ int64_t* heap_ids,
276
+ int ht) {
260
277
  int M = index.pq.M;
261
278
  int code_size = index.pq.code_size;
262
279
  int ksub = index.pq.ksub;
263
280
  size_t ntotal = index.ntotal;
264
- int ht = index.polysemous_ht;
265
281
 
266
282
  const uint8_t* b_code = index.codes.data();
267
283
 
@@ -296,11 +312,16 @@ void IndexPQ::search_core_polysemous(
296
312
  const float* x,
297
313
  idx_t k,
298
314
  float* distances,
299
- idx_t* labels) const {
315
+ idx_t* labels,
316
+ int polysemous_ht,
317
+ bool generalized_hamming) const {
300
318
  FAISS_THROW_IF_NOT(k > 0);
301
-
302
319
  FAISS_THROW_IF_NOT(pq.nbits == 8);
303
320
 
321
+ if (polysemous_ht == 0) {
322
+ polysemous_ht = pq.nbits * pq.M + 1;
323
+ }
324
+
304
325
  // PQ distance tables
305
326
  float* dis_tables = new float[n * pq.ksub * pq.M];
306
327
  ScopeDeleter<float> del(dis_tables);
@@ -323,7 +344,9 @@ void IndexPQ::search_core_polysemous(
323
344
 
324
345
  size_t n_pass = 0;
325
346
 
326
- #pragma omp parallel for reduction(+ : n_pass)
347
+ int bad_code_size = 0;
348
+
349
+ #pragma omp parallel for reduction(+ : n_pass, bad_code_size)
327
350
  for (idx_t qi = 0; qi < n; qi++) {
328
351
  const uint8_t* q_code = q_codes + qi * pq.code_size;
329
352
 
@@ -333,28 +356,24 @@ void IndexPQ::search_core_polysemous(
333
356
  float* heap_dis = distances + qi * k;
334
357
  maxheap_heapify(k, heap_dis, heap_ids);
335
358
 
336
- if (search_type == ST_polysemous) {
359
+ if (!generalized_hamming) {
337
360
  switch (pq.code_size) {
338
- case 4:
339
- n_pass += polysemous_inner_loop<HammingComputer4>(
340
- *this, dis_table_qi, q_code, k, heap_dis, heap_ids);
341
- break;
342
- case 8:
343
- n_pass += polysemous_inner_loop<HammingComputer8>(
344
- *this, dis_table_qi, q_code, k, heap_dis, heap_ids);
345
- break;
346
- case 16:
347
- n_pass += polysemous_inner_loop<HammingComputer16>(
348
- *this, dis_table_qi, q_code, k, heap_dis, heap_ids);
349
- break;
350
- case 32:
351
- n_pass += polysemous_inner_loop<HammingComputer32>(
352
- *this, dis_table_qi, q_code, k, heap_dis, heap_ids);
353
- break;
354
- case 20:
355
- n_pass += polysemous_inner_loop<HammingComputer20>(
356
- *this, dis_table_qi, q_code, k, heap_dis, heap_ids);
357
- break;
361
+ #define DISPATCH(cs) \
362
+ case cs: \
363
+ n_pass += polysemous_inner_loop<HammingComputer##cs>( \
364
+ *this, \
365
+ dis_table_qi, \
366
+ q_code, \
367
+ k, \
368
+ heap_dis, \
369
+ heap_ids, \
370
+ polysemous_ht); \
371
+ break;
372
+ DISPATCH(4)
373
+ DISPATCH(8)
374
+ DISPATCH(16)
375
+ DISPATCH(32)
376
+ DISPATCH(20)
358
377
  default:
359
378
  if (pq.code_size % 4 == 0) {
360
379
  n_pass += polysemous_inner_loop<HammingComputerDefault>(
@@ -363,28 +382,30 @@ void IndexPQ::search_core_polysemous(
363
382
  q_code,
364
383
  k,
365
384
  heap_dis,
366
- heap_ids);
385
+ heap_ids,
386
+ polysemous_ht);
367
387
  } else {
368
- FAISS_THROW_FMT(
369
- "code size %zd not supported for polysemous",
370
- pq.code_size);
388
+ bad_code_size++;
371
389
  }
372
390
  break;
373
391
  }
374
- } else {
392
+ #undef DISPATCH
393
+ } else { // generalized hamming
375
394
  switch (pq.code_size) {
376
- case 8:
377
- n_pass += polysemous_inner_loop<GenHammingComputer8>(
378
- *this, dis_table_qi, q_code, k, heap_dis, heap_ids);
379
- break;
380
- case 16:
381
- n_pass += polysemous_inner_loop<GenHammingComputer16>(
382
- *this, dis_table_qi, q_code, k, heap_dis, heap_ids);
383
- break;
384
- case 32:
385
- n_pass += polysemous_inner_loop<GenHammingComputer32>(
386
- *this, dis_table_qi, q_code, k, heap_dis, heap_ids);
387
- break;
395
+ #define DISPATCH(cs) \
396
+ case cs: \
397
+ n_pass += polysemous_inner_loop<GenHammingComputer##cs>( \
398
+ *this, \
399
+ dis_table_qi, \
400
+ q_code, \
401
+ k, \
402
+ heap_dis, \
403
+ heap_ids, \
404
+ polysemous_ht); \
405
+ break;
406
+ DISPATCH(8)
407
+ DISPATCH(16)
408
+ DISPATCH(32)
388
409
  default:
389
410
  if (pq.code_size % 8 == 0) {
390
411
  n_pass += polysemous_inner_loop<GenHammingComputerM8>(
@@ -393,18 +414,23 @@ void IndexPQ::search_core_polysemous(
393
414
  q_code,
394
415
  k,
395
416
  heap_dis,
396
- heap_ids);
417
+ heap_ids,
418
+ polysemous_ht);
397
419
  } else {
398
- FAISS_THROW_FMT(
399
- "code size %zd not supported for polysemous",
400
- pq.code_size);
420
+ bad_code_size++;
401
421
  }
402
422
  break;
423
+ #undef DISPATCH
403
424
  }
404
425
  }
405
426
  maxheap_reorder(k, heap_dis, heap_ids);
406
427
  }
407
428
 
429
+ if (bad_code_size) {
430
+ FAISS_THROW_FMT(
431
+ "code size %zd not supported for polysemous", pq.code_size);
432
+ }
433
+
408
434
  indexPQ_stats.nq += n;
409
435
  indexPQ_stats.ncode += n * ntotal;
410
436
  indexPQ_stats.n_hamming_pass += n_pass;
@@ -865,19 +891,25 @@ void MultiIndexQuantizer::train(idx_t n, const float* x) {
865
891
  ntotal *= pq.ksub;
866
892
  }
867
893
 
894
+ // block size used in MultiIndexQuantizer::search
895
+ int multi_index_quantizer_search_bs = 32768;
896
+
868
897
  void MultiIndexQuantizer::search(
869
898
  idx_t n,
870
899
  const float* x,
871
900
  idx_t k,
872
901
  float* distances,
873
- idx_t* labels) const {
874
- if (n == 0)
902
+ idx_t* labels,
903
+ const SearchParameters* params) const {
904
+ FAISS_THROW_IF_NOT_MSG(
905
+ !params, "search params not supported for this index");
906
+ if (n == 0) {
875
907
  return;
876
-
908
+ }
877
909
  FAISS_THROW_IF_NOT(k > 0);
878
910
 
879
911
  // the allocation just below can be severe...
880
- idx_t bs = 32768;
912
+ idx_t bs = multi_index_quantizer_search_bs;
881
913
  if (n > bs) {
882
914
  for (idx_t i0 = 0; i0 < n; i0 += bs) {
883
915
  idx_t i1 = std::min(i0 + bs, n);
@@ -1012,9 +1044,14 @@ void MultiIndexQuantizer2::search(
1012
1044
  const float* x,
1013
1045
  idx_t K,
1014
1046
  float* distances,
1015
- idx_t* labels) const {
1016
- if (n == 0)
1047
+ idx_t* labels,
1048
+ const SearchParameters* params) const {
1049
+ FAISS_THROW_IF_NOT_MSG(
1050
+ !params, "search params not supported for this index");
1051
+
1052
+ if (n == 0) {
1017
1053
  return;
1054
+ }
1018
1055
 
1019
1056
  int k2 = std::min(K, int64_t(pq.ksub));
1020
1057
  FAISS_THROW_IF_NOT(k2);
@@ -45,14 +45,15 @@ struct IndexPQ : IndexFlatCodes {
45
45
  const float* x,
46
46
  idx_t k,
47
47
  float* distances,
48
- idx_t* labels) const override;
48
+ idx_t* labels,
49
+ const SearchParameters* params = nullptr) const override;
49
50
 
50
51
  /* The standalone codec interface */
51
52
  void sa_encode(idx_t n, const float* x, uint8_t* bytes) const override;
52
53
 
53
54
  void sa_decode(idx_t n, const uint8_t* bytes, float* x) const override;
54
55
 
55
- DistanceComputer* get_distance_computer() const override;
56
+ FlatCodesDistanceComputer* get_FlatCodesDistanceComputer() const override;
56
57
 
57
58
  /******************************************************
58
59
  * Polysemous codes implementation
@@ -87,7 +88,9 @@ struct IndexPQ : IndexFlatCodes {
87
88
  const float* x,
88
89
  idx_t k,
89
90
  float* distances,
90
- idx_t* labels) const;
91
+ idx_t* labels,
92
+ int polysemous_ht,
93
+ bool generalized_hamming) const;
91
94
 
92
95
  /// prepare query for a polysemous search, but instead of
93
96
  /// computing the result, just get the histogram of Hamming
@@ -109,6 +112,12 @@ struct IndexPQ : IndexFlatCodes {
109
112
  void hamming_distance_table(idx_t n, const float* x, int32_t* dis) const;
110
113
  };
111
114
 
115
+ /// override search parameters from the class
116
+ struct SearchParametersPQ : SearchParameters {
117
+ IndexPQ::Search_type_t search_type;
118
+ int polysemous_ht;
119
+ };
120
+
112
121
  /// statistics are robust to internal threading, but not if
113
122
  /// IndexPQ::search is called by multiple threads
114
123
  struct IndexPQStats {
@@ -142,7 +151,8 @@ struct MultiIndexQuantizer : Index {
142
151
  const float* x,
143
152
  idx_t k,
144
153
  float* distances,
145
- idx_t* labels) const override;
154
+ idx_t* labels,
155
+ const SearchParameters* params = nullptr) const override;
146
156
 
147
157
  /// add and reset will crash at runtime
148
158
  void add(idx_t n, const float* x) override;
@@ -153,6 +163,9 @@ struct MultiIndexQuantizer : Index {
153
163
  void reconstruct(idx_t key, float* recons) const override;
154
164
  };
155
165
 
166
+ // block size used in MultiIndexQuantizer::search
167
+ FAISS_API extern int multi_index_quantizer_search_bs;
168
+
156
169
  /** MultiIndexQuantizer where the PQ assignmnet is performed by sub-indexes
157
170
  */
158
171
  struct MultiIndexQuantizer2 : MultiIndexQuantizer {
@@ -175,7 +188,8 @@ struct MultiIndexQuantizer2 : MultiIndexQuantizer {
175
188
  const float* x,
176
189
  idx_t k,
177
190
  float* distances,
178
- idx_t* labels) const override;
191
+ idx_t* labels,
192
+ const SearchParameters* params = nullptr) const override;
179
193
  };
180
194
 
181
195
  } // namespace faiss