faiss 0.2.4 → 0.2.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (177) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +5 -0
  3. data/README.md +23 -21
  4. data/ext/faiss/extconf.rb +11 -0
  5. data/ext/faiss/index.cpp +4 -4
  6. data/ext/faiss/index_binary.cpp +6 -6
  7. data/ext/faiss/product_quantizer.cpp +4 -4
  8. data/lib/faiss/version.rb +1 -1
  9. data/vendor/faiss/faiss/AutoTune.cpp +13 -0
  10. data/vendor/faiss/faiss/IVFlib.cpp +101 -2
  11. data/vendor/faiss/faiss/IVFlib.h +26 -2
  12. data/vendor/faiss/faiss/Index.cpp +36 -3
  13. data/vendor/faiss/faiss/Index.h +43 -6
  14. data/vendor/faiss/faiss/Index2Layer.cpp +6 -2
  15. data/vendor/faiss/faiss/Index2Layer.h +6 -1
  16. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +219 -16
  17. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +63 -5
  18. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +299 -0
  19. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +199 -0
  20. data/vendor/faiss/faiss/IndexBinary.cpp +20 -4
  21. data/vendor/faiss/faiss/IndexBinary.h +18 -3
  22. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +9 -2
  23. data/vendor/faiss/faiss/IndexBinaryFlat.h +4 -2
  24. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +4 -1
  25. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +2 -1
  26. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +5 -1
  27. data/vendor/faiss/faiss/IndexBinaryHNSW.h +2 -1
  28. data/vendor/faiss/faiss/IndexBinaryHash.cpp +17 -4
  29. data/vendor/faiss/faiss/IndexBinaryHash.h +8 -4
  30. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +28 -13
  31. data/vendor/faiss/faiss/IndexBinaryIVF.h +10 -7
  32. data/vendor/faiss/faiss/IndexFastScan.cpp +626 -0
  33. data/vendor/faiss/faiss/IndexFastScan.h +145 -0
  34. data/vendor/faiss/faiss/IndexFlat.cpp +34 -21
  35. data/vendor/faiss/faiss/IndexFlat.h +7 -4
  36. data/vendor/faiss/faiss/IndexFlatCodes.cpp +35 -1
  37. data/vendor/faiss/faiss/IndexFlatCodes.h +12 -0
  38. data/vendor/faiss/faiss/IndexHNSW.cpp +66 -138
  39. data/vendor/faiss/faiss/IndexHNSW.h +4 -2
  40. data/vendor/faiss/faiss/IndexIDMap.cpp +247 -0
  41. data/vendor/faiss/faiss/IndexIDMap.h +107 -0
  42. data/vendor/faiss/faiss/IndexIVF.cpp +121 -33
  43. data/vendor/faiss/faiss/IndexIVF.h +35 -16
  44. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +84 -7
  45. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +63 -1
  46. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +590 -0
  47. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +171 -0
  48. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +1290 -0
  49. data/vendor/faiss/faiss/IndexIVFFastScan.h +213 -0
  50. data/vendor/faiss/faiss/IndexIVFFlat.cpp +37 -17
  51. data/vendor/faiss/faiss/IndexIVFFlat.h +4 -2
  52. data/vendor/faiss/faiss/IndexIVFPQ.cpp +234 -50
  53. data/vendor/faiss/faiss/IndexIVFPQ.h +5 -1
  54. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +23 -852
  55. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +7 -112
  56. data/vendor/faiss/faiss/IndexIVFPQR.cpp +3 -3
  57. data/vendor/faiss/faiss/IndexIVFPQR.h +1 -1
  58. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +3 -1
  59. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +2 -1
  60. data/vendor/faiss/faiss/IndexLSH.cpp +4 -2
  61. data/vendor/faiss/faiss/IndexLSH.h +2 -1
  62. data/vendor/faiss/faiss/IndexLattice.cpp +7 -1
  63. data/vendor/faiss/faiss/IndexLattice.h +3 -1
  64. data/vendor/faiss/faiss/IndexNNDescent.cpp +4 -3
  65. data/vendor/faiss/faiss/IndexNNDescent.h +2 -1
  66. data/vendor/faiss/faiss/IndexNSG.cpp +37 -3
  67. data/vendor/faiss/faiss/IndexNSG.h +25 -1
  68. data/vendor/faiss/faiss/IndexPQ.cpp +106 -69
  69. data/vendor/faiss/faiss/IndexPQ.h +19 -5
  70. data/vendor/faiss/faiss/IndexPQFastScan.cpp +15 -450
  71. data/vendor/faiss/faiss/IndexPQFastScan.h +15 -78
  72. data/vendor/faiss/faiss/IndexPreTransform.cpp +47 -8
  73. data/vendor/faiss/faiss/IndexPreTransform.h +15 -3
  74. data/vendor/faiss/faiss/IndexRefine.cpp +8 -4
  75. data/vendor/faiss/faiss/IndexRefine.h +4 -2
  76. data/vendor/faiss/faiss/IndexReplicas.cpp +4 -2
  77. data/vendor/faiss/faiss/IndexReplicas.h +2 -1
  78. data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +438 -0
  79. data/vendor/faiss/faiss/IndexRowwiseMinMax.h +92 -0
  80. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +26 -15
  81. data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -7
  82. data/vendor/faiss/faiss/IndexShards.cpp +4 -1
  83. data/vendor/faiss/faiss/IndexShards.h +2 -1
  84. data/vendor/faiss/faiss/MetaIndexes.cpp +5 -178
  85. data/vendor/faiss/faiss/MetaIndexes.h +3 -81
  86. data/vendor/faiss/faiss/VectorTransform.cpp +43 -0
  87. data/vendor/faiss/faiss/VectorTransform.h +22 -4
  88. data/vendor/faiss/faiss/clone_index.cpp +23 -1
  89. data/vendor/faiss/faiss/clone_index.h +3 -0
  90. data/vendor/faiss/faiss/cppcontrib/SaDecodeKernels.h +300 -0
  91. data/vendor/faiss/faiss/cppcontrib/detail/CoarseBitType.h +24 -0
  92. data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +195 -0
  93. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +2058 -0
  94. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +408 -0
  95. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-neon-inl.h +2147 -0
  96. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMax-inl.h +460 -0
  97. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMaxFP16-inl.h +465 -0
  98. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +1618 -0
  99. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +251 -0
  100. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-neon-inl.h +1452 -0
  101. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +1 -0
  102. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +0 -4
  103. data/vendor/faiss/faiss/gpu/GpuIndex.h +28 -4
  104. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +2 -1
  105. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +10 -8
  106. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +75 -14
  107. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +19 -32
  108. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -31
  109. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +22 -28
  110. data/vendor/faiss/faiss/gpu/GpuResources.cpp +14 -0
  111. data/vendor/faiss/faiss/gpu/GpuResources.h +16 -3
  112. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +3 -3
  113. data/vendor/faiss/faiss/gpu/impl/IndexUtils.h +32 -0
  114. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +1 -0
  115. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +311 -75
  116. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +10 -0
  117. data/vendor/faiss/faiss/gpu/test/TestUtils.h +3 -0
  118. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +2 -2
  119. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +5 -4
  120. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +116 -47
  121. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +44 -13
  122. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +0 -54
  123. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +0 -76
  124. data/vendor/faiss/faiss/impl/DistanceComputer.h +64 -0
  125. data/vendor/faiss/faiss/impl/HNSW.cpp +123 -27
  126. data/vendor/faiss/faiss/impl/HNSW.h +19 -16
  127. data/vendor/faiss/faiss/impl/IDSelector.cpp +125 -0
  128. data/vendor/faiss/faiss/impl/IDSelector.h +135 -0
  129. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +6 -28
  130. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +6 -1
  131. data/vendor/faiss/faiss/impl/LookupTableScaler.h +77 -0
  132. data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -0
  133. data/vendor/faiss/faiss/impl/NSG.cpp +1 -1
  134. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +383 -0
  135. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +154 -0
  136. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +225 -145
  137. data/vendor/faiss/faiss/impl/ProductQuantizer.h +29 -10
  138. data/vendor/faiss/faiss/impl/Quantizer.h +43 -0
  139. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +192 -36
  140. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +40 -20
  141. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -0
  142. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +97 -173
  143. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +18 -18
  144. data/vendor/faiss/faiss/impl/index_read.cpp +240 -9
  145. data/vendor/faiss/faiss/impl/index_write.cpp +237 -5
  146. data/vendor/faiss/faiss/impl/kmeans1d.cpp +6 -4
  147. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +56 -16
  148. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +25 -8
  149. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +66 -25
  150. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +75 -27
  151. data/vendor/faiss/faiss/index_factory.cpp +196 -7
  152. data/vendor/faiss/faiss/index_io.h +5 -0
  153. data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -0
  154. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +4 -1
  155. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +2 -1
  156. data/vendor/faiss/faiss/python/python_callbacks.cpp +27 -0
  157. data/vendor/faiss/faiss/python/python_callbacks.h +15 -0
  158. data/vendor/faiss/faiss/utils/Heap.h +31 -15
  159. data/vendor/faiss/faiss/utils/distances.cpp +380 -56
  160. data/vendor/faiss/faiss/utils/distances.h +113 -15
  161. data/vendor/faiss/faiss/utils/distances_simd.cpp +726 -6
  162. data/vendor/faiss/faiss/utils/extra_distances.cpp +12 -7
  163. data/vendor/faiss/faiss/utils/extra_distances.h +3 -1
  164. data/vendor/faiss/faiss/utils/fp16-fp16c.h +21 -0
  165. data/vendor/faiss/faiss/utils/fp16-inl.h +101 -0
  166. data/vendor/faiss/faiss/utils/fp16.h +11 -0
  167. data/vendor/faiss/faiss/utils/hamming-inl.h +54 -0
  168. data/vendor/faiss/faiss/utils/hamming.cpp +0 -48
  169. data/vendor/faiss/faiss/utils/ordered_key_value.h +10 -0
  170. data/vendor/faiss/faiss/utils/quantize_lut.cpp +62 -0
  171. data/vendor/faiss/faiss/utils/quantize_lut.h +20 -0
  172. data/vendor/faiss/faiss/utils/random.cpp +53 -0
  173. data/vendor/faiss/faiss/utils/random.h +5 -0
  174. data/vendor/faiss/faiss/utils/simdlib_avx2.h +4 -0
  175. data/vendor/faiss/faiss/utils/simdlib_emulated.h +6 -1
  176. data/vendor/faiss/faiss/utils/simdlib_neon.h +7 -2
  177. metadata +37 -3
@@ -5,8 +5,6 @@
5
5
  * LICENSE file in the root directory of this source tree.
6
6
  */
7
7
 
8
- // -*- c++ -*-
9
-
10
8
  #include <faiss/IndexPQ.h>
11
9
 
12
10
  #include <cinttypes>
@@ -17,7 +15,7 @@
17
15
 
18
16
  #include <algorithm>
19
17
 
20
- #include <faiss/impl/AuxIndexStructures.h>
18
+ #include <faiss/impl/DistanceComputer.h>
21
19
  #include <faiss/impl/FaissAssert.h>
22
20
  #include <faiss/utils/hamming.h>
23
21
 
@@ -73,19 +71,16 @@ void IndexPQ::train(idx_t n, const float* x) {
73
71
  namespace {
74
72
 
75
73
  template <class PQDecoder>
76
- struct PQDistanceComputer : DistanceComputer {
74
+ struct PQDistanceComputer : FlatCodesDistanceComputer {
77
75
  size_t d;
78
76
  MetricType metric;
79
77
  Index::idx_t nb;
80
- const uint8_t* codes;
81
- size_t code_size;
82
78
  const ProductQuantizer& pq;
83
79
  const float* sdc;
84
80
  std::vector<float> precomputed_table;
85
81
  size_t ndis;
86
82
 
87
- float operator()(idx_t i) override {
88
- const uint8_t* code = codes + i * code_size;
83
+ float distance_to_code(const uint8_t* code) final {
89
84
  const float* dt = precomputed_table.data();
90
85
  PQDecoder decoder(code, pq.nbits);
91
86
  float accu = 0;
@@ -112,13 +107,15 @@ struct PQDistanceComputer : DistanceComputer {
112
107
  return accu;
113
108
  }
114
109
 
115
- explicit PQDistanceComputer(const IndexPQ& storage) : pq(storage.pq) {
110
+ explicit PQDistanceComputer(const IndexPQ& storage)
111
+ : FlatCodesDistanceComputer(
112
+ storage.codes.data(),
113
+ storage.code_size),
114
+ pq(storage.pq) {
116
115
  precomputed_table.resize(pq.M * pq.ksub);
117
116
  nb = storage.ntotal;
118
117
  d = storage.d;
119
118
  metric = storage.metric_type;
120
- codes = storage.codes.data();
121
- code_size = pq.code_size;
122
119
  if (pq.sdc_table.size() == pq.ksub * pq.ksub * pq.M) {
123
120
  sdc = pq.sdc_table.data();
124
121
  } else {
@@ -138,7 +135,7 @@ struct PQDistanceComputer : DistanceComputer {
138
135
 
139
136
  } // namespace
140
137
 
141
- DistanceComputer* IndexPQ::get_distance_computer() const {
138
+ FlatCodesDistanceComputer* IndexPQ::get_FlatCodesDistanceComputer() const {
142
139
  if (pq.nbits == 8) {
143
140
  return new PQDistanceComputer<PQDecoder8>(*this);
144
141
  } else if (pq.nbits == 16) {
@@ -157,10 +154,21 @@ void IndexPQ::search(
157
154
  const float* x,
158
155
  idx_t k,
159
156
  float* distances,
160
- idx_t* labels) const {
157
+ idx_t* labels,
158
+ const SearchParameters* iparams) const {
161
159
  FAISS_THROW_IF_NOT(k > 0);
162
-
163
160
  FAISS_THROW_IF_NOT(is_trained);
161
+
162
+ const SearchParametersPQ* params = nullptr;
163
+ Search_type_t search_type = this->search_type;
164
+
165
+ if (iparams) {
166
+ params = dynamic_cast<const SearchParametersPQ*>(iparams);
167
+ FAISS_THROW_IF_NOT_MSG(params, "invalid search params");
168
+ FAISS_THROW_IF_NOT_MSG(!params->sel, "selector not supported");
169
+ search_type = params->search_type;
170
+ }
171
+
164
172
  if (search_type == ST_PQ) { // Simple PQ search
165
173
 
166
174
  if (metric_type == METRIC_L2) {
@@ -179,8 +187,16 @@ void IndexPQ::search(
179
187
  search_type == ST_polysemous ||
180
188
  search_type == ST_polysemous_generalize) {
181
189
  FAISS_THROW_IF_NOT(metric_type == METRIC_L2);
182
-
183
- search_core_polysemous(n, x, k, distances, labels);
190
+ int polysemous_ht =
191
+ params ? params->polysemous_ht : this->polysemous_ht;
192
+ search_core_polysemous(
193
+ n,
194
+ x,
195
+ k,
196
+ distances,
197
+ labels,
198
+ polysemous_ht,
199
+ search_type == ST_polysemous_generalize);
184
200
 
185
201
  } else { // code-to-code distances
186
202
 
@@ -256,12 +272,12 @@ static size_t polysemous_inner_loop(
256
272
  const uint8_t* q_code,
257
273
  size_t k,
258
274
  float* heap_dis,
259
- int64_t* heap_ids) {
275
+ int64_t* heap_ids,
276
+ int ht) {
260
277
  int M = index.pq.M;
261
278
  int code_size = index.pq.code_size;
262
279
  int ksub = index.pq.ksub;
263
280
  size_t ntotal = index.ntotal;
264
- int ht = index.polysemous_ht;
265
281
 
266
282
  const uint8_t* b_code = index.codes.data();
267
283
 
@@ -296,11 +312,16 @@ void IndexPQ::search_core_polysemous(
296
312
  const float* x,
297
313
  idx_t k,
298
314
  float* distances,
299
- idx_t* labels) const {
315
+ idx_t* labels,
316
+ int polysemous_ht,
317
+ bool generalized_hamming) const {
300
318
  FAISS_THROW_IF_NOT(k > 0);
301
-
302
319
  FAISS_THROW_IF_NOT(pq.nbits == 8);
303
320
 
321
+ if (polysemous_ht == 0) {
322
+ polysemous_ht = pq.nbits * pq.M + 1;
323
+ }
324
+
304
325
  // PQ distance tables
305
326
  float* dis_tables = new float[n * pq.ksub * pq.M];
306
327
  ScopeDeleter<float> del(dis_tables);
@@ -323,7 +344,9 @@ void IndexPQ::search_core_polysemous(
323
344
 
324
345
  size_t n_pass = 0;
325
346
 
326
- #pragma omp parallel for reduction(+ : n_pass)
347
+ int bad_code_size = 0;
348
+
349
+ #pragma omp parallel for reduction(+ : n_pass, bad_code_size)
327
350
  for (idx_t qi = 0; qi < n; qi++) {
328
351
  const uint8_t* q_code = q_codes + qi * pq.code_size;
329
352
 
@@ -333,28 +356,24 @@ void IndexPQ::search_core_polysemous(
333
356
  float* heap_dis = distances + qi * k;
334
357
  maxheap_heapify(k, heap_dis, heap_ids);
335
358
 
336
- if (search_type == ST_polysemous) {
359
+ if (!generalized_hamming) {
337
360
  switch (pq.code_size) {
338
- case 4:
339
- n_pass += polysemous_inner_loop<HammingComputer4>(
340
- *this, dis_table_qi, q_code, k, heap_dis, heap_ids);
341
- break;
342
- case 8:
343
- n_pass += polysemous_inner_loop<HammingComputer8>(
344
- *this, dis_table_qi, q_code, k, heap_dis, heap_ids);
345
- break;
346
- case 16:
347
- n_pass += polysemous_inner_loop<HammingComputer16>(
348
- *this, dis_table_qi, q_code, k, heap_dis, heap_ids);
349
- break;
350
- case 32:
351
- n_pass += polysemous_inner_loop<HammingComputer32>(
352
- *this, dis_table_qi, q_code, k, heap_dis, heap_ids);
353
- break;
354
- case 20:
355
- n_pass += polysemous_inner_loop<HammingComputer20>(
356
- *this, dis_table_qi, q_code, k, heap_dis, heap_ids);
357
- break;
361
+ #define DISPATCH(cs) \
362
+ case cs: \
363
+ n_pass += polysemous_inner_loop<HammingComputer##cs>( \
364
+ *this, \
365
+ dis_table_qi, \
366
+ q_code, \
367
+ k, \
368
+ heap_dis, \
369
+ heap_ids, \
370
+ polysemous_ht); \
371
+ break;
372
+ DISPATCH(4)
373
+ DISPATCH(8)
374
+ DISPATCH(16)
375
+ DISPATCH(32)
376
+ DISPATCH(20)
358
377
  default:
359
378
  if (pq.code_size % 4 == 0) {
360
379
  n_pass += polysemous_inner_loop<HammingComputerDefault>(
@@ -363,28 +382,30 @@ void IndexPQ::search_core_polysemous(
363
382
  q_code,
364
383
  k,
365
384
  heap_dis,
366
- heap_ids);
385
+ heap_ids,
386
+ polysemous_ht);
367
387
  } else {
368
- FAISS_THROW_FMT(
369
- "code size %zd not supported for polysemous",
370
- pq.code_size);
388
+ bad_code_size++;
371
389
  }
372
390
  break;
373
391
  }
374
- } else {
392
+ #undef DISPATCH
393
+ } else { // generalized hamming
375
394
  switch (pq.code_size) {
376
- case 8:
377
- n_pass += polysemous_inner_loop<GenHammingComputer8>(
378
- *this, dis_table_qi, q_code, k, heap_dis, heap_ids);
379
- break;
380
- case 16:
381
- n_pass += polysemous_inner_loop<GenHammingComputer16>(
382
- *this, dis_table_qi, q_code, k, heap_dis, heap_ids);
383
- break;
384
- case 32:
385
- n_pass += polysemous_inner_loop<GenHammingComputer32>(
386
- *this, dis_table_qi, q_code, k, heap_dis, heap_ids);
387
- break;
395
+ #define DISPATCH(cs) \
396
+ case cs: \
397
+ n_pass += polysemous_inner_loop<GenHammingComputer##cs>( \
398
+ *this, \
399
+ dis_table_qi, \
400
+ q_code, \
401
+ k, \
402
+ heap_dis, \
403
+ heap_ids, \
404
+ polysemous_ht); \
405
+ break;
406
+ DISPATCH(8)
407
+ DISPATCH(16)
408
+ DISPATCH(32)
388
409
  default:
389
410
  if (pq.code_size % 8 == 0) {
390
411
  n_pass += polysemous_inner_loop<GenHammingComputerM8>(
@@ -393,18 +414,23 @@ void IndexPQ::search_core_polysemous(
393
414
  q_code,
394
415
  k,
395
416
  heap_dis,
396
- heap_ids);
417
+ heap_ids,
418
+ polysemous_ht);
397
419
  } else {
398
- FAISS_THROW_FMT(
399
- "code size %zd not supported for polysemous",
400
- pq.code_size);
420
+ bad_code_size++;
401
421
  }
402
422
  break;
423
+ #undef DISPATCH
403
424
  }
404
425
  }
405
426
  maxheap_reorder(k, heap_dis, heap_ids);
406
427
  }
407
428
 
429
+ if (bad_code_size) {
430
+ FAISS_THROW_FMT(
431
+ "code size %zd not supported for polysemous", pq.code_size);
432
+ }
433
+
408
434
  indexPQ_stats.nq += n;
409
435
  indexPQ_stats.ncode += n * ntotal;
410
436
  indexPQ_stats.n_hamming_pass += n_pass;
@@ -865,19 +891,25 @@ void MultiIndexQuantizer::train(idx_t n, const float* x) {
865
891
  ntotal *= pq.ksub;
866
892
  }
867
893
 
894
+ // block size used in MultiIndexQuantizer::search
895
+ int multi_index_quantizer_search_bs = 32768;
896
+
868
897
  void MultiIndexQuantizer::search(
869
898
  idx_t n,
870
899
  const float* x,
871
900
  idx_t k,
872
901
  float* distances,
873
- idx_t* labels) const {
874
- if (n == 0)
902
+ idx_t* labels,
903
+ const SearchParameters* params) const {
904
+ FAISS_THROW_IF_NOT_MSG(
905
+ !params, "search params not supported for this index");
906
+ if (n == 0) {
875
907
  return;
876
-
908
+ }
877
909
  FAISS_THROW_IF_NOT(k > 0);
878
910
 
879
911
  // the allocation just below can be severe...
880
- idx_t bs = 32768;
912
+ idx_t bs = multi_index_quantizer_search_bs;
881
913
  if (n > bs) {
882
914
  for (idx_t i0 = 0; i0 < n; i0 += bs) {
883
915
  idx_t i1 = std::min(i0 + bs, n);
@@ -1012,9 +1044,14 @@ void MultiIndexQuantizer2::search(
1012
1044
  const float* x,
1013
1045
  idx_t K,
1014
1046
  float* distances,
1015
- idx_t* labels) const {
1016
- if (n == 0)
1047
+ idx_t* labels,
1048
+ const SearchParameters* params) const {
1049
+ FAISS_THROW_IF_NOT_MSG(
1050
+ !params, "search params not supported for this index");
1051
+
1052
+ if (n == 0) {
1017
1053
  return;
1054
+ }
1018
1055
 
1019
1056
  int k2 = std::min(K, int64_t(pq.ksub));
1020
1057
  FAISS_THROW_IF_NOT(k2);
@@ -45,14 +45,15 @@ struct IndexPQ : IndexFlatCodes {
45
45
  const float* x,
46
46
  idx_t k,
47
47
  float* distances,
48
- idx_t* labels) const override;
48
+ idx_t* labels,
49
+ const SearchParameters* params = nullptr) const override;
49
50
 
50
51
  /* The standalone codec interface */
51
52
  void sa_encode(idx_t n, const float* x, uint8_t* bytes) const override;
52
53
 
53
54
  void sa_decode(idx_t n, const uint8_t* bytes, float* x) const override;
54
55
 
55
- DistanceComputer* get_distance_computer() const override;
56
+ FlatCodesDistanceComputer* get_FlatCodesDistanceComputer() const override;
56
57
 
57
58
  /******************************************************
58
59
  * Polysemous codes implementation
@@ -87,7 +88,9 @@ struct IndexPQ : IndexFlatCodes {
87
88
  const float* x,
88
89
  idx_t k,
89
90
  float* distances,
90
- idx_t* labels) const;
91
+ idx_t* labels,
92
+ int polysemous_ht,
93
+ bool generalized_hamming) const;
91
94
 
92
95
  /// prepare query for a polysemous search, but instead of
93
96
  /// computing the result, just get the histogram of Hamming
@@ -109,6 +112,12 @@ struct IndexPQ : IndexFlatCodes {
109
112
  void hamming_distance_table(idx_t n, const float* x, int32_t* dis) const;
110
113
  };
111
114
 
115
+ /// override search parameters from the class
116
+ struct SearchParametersPQ : SearchParameters {
117
+ IndexPQ::Search_type_t search_type;
118
+ int polysemous_ht;
119
+ };
120
+
112
121
  /// statistics are robust to internal threading, but not if
113
122
  /// IndexPQ::search is called by multiple threads
114
123
  struct IndexPQStats {
@@ -142,7 +151,8 @@ struct MultiIndexQuantizer : Index {
142
151
  const float* x,
143
152
  idx_t k,
144
153
  float* distances,
145
- idx_t* labels) const override;
154
+ idx_t* labels,
155
+ const SearchParameters* params = nullptr) const override;
146
156
 
147
157
  /// add and reset will crash at runtime
148
158
  void add(idx_t n, const float* x) override;
@@ -153,6 +163,9 @@ struct MultiIndexQuantizer : Index {
153
163
  void reconstruct(idx_t key, float* recons) const override;
154
164
  };
155
165
 
166
+ // block size used in MultiIndexQuantizer::search
167
+ FAISS_API extern int multi_index_quantizer_search_bs;
168
+
156
169
  /** MultiIndexQuantizer where the PQ assignmnet is performed by sub-indexes
157
170
  */
158
171
  struct MultiIndexQuantizer2 : MultiIndexQuantizer {
@@ -175,7 +188,8 @@ struct MultiIndexQuantizer2 : MultiIndexQuantizer {
175
188
  const float* x,
176
189
  idx_t k,
177
190
  float* distances,
178
- idx_t* labels) const override;
191
+ idx_t* labels,
192
+ const SearchParameters* params = nullptr) const override;
179
193
  };
180
194
 
181
195
  } // namespace faiss