faiss 0.2.3 → 0.2.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (189) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +9 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +23 -21
  5. data/ext/faiss/extconf.rb +11 -0
  6. data/ext/faiss/index.cpp +4 -4
  7. data/ext/faiss/index_binary.cpp +6 -6
  8. data/ext/faiss/product_quantizer.cpp +4 -4
  9. data/lib/faiss/version.rb +1 -1
  10. data/vendor/faiss/faiss/AutoTune.cpp +13 -0
  11. data/vendor/faiss/faiss/Clustering.cpp +32 -0
  12. data/vendor/faiss/faiss/Clustering.h +14 -0
  13. data/vendor/faiss/faiss/IVFlib.cpp +101 -2
  14. data/vendor/faiss/faiss/IVFlib.h +26 -2
  15. data/vendor/faiss/faiss/Index.cpp +36 -3
  16. data/vendor/faiss/faiss/Index.h +43 -6
  17. data/vendor/faiss/faiss/Index2Layer.cpp +24 -93
  18. data/vendor/faiss/faiss/Index2Layer.h +8 -17
  19. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +610 -0
  20. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +253 -0
  21. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +299 -0
  22. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +199 -0
  23. data/vendor/faiss/faiss/IndexBinary.cpp +20 -4
  24. data/vendor/faiss/faiss/IndexBinary.h +18 -3
  25. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +9 -2
  26. data/vendor/faiss/faiss/IndexBinaryFlat.h +4 -2
  27. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +4 -1
  28. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +2 -1
  29. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +5 -1
  30. data/vendor/faiss/faiss/IndexBinaryHNSW.h +2 -1
  31. data/vendor/faiss/faiss/IndexBinaryHash.cpp +17 -4
  32. data/vendor/faiss/faiss/IndexBinaryHash.h +8 -4
  33. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +28 -13
  34. data/vendor/faiss/faiss/IndexBinaryIVF.h +10 -7
  35. data/vendor/faiss/faiss/IndexFastScan.cpp +626 -0
  36. data/vendor/faiss/faiss/IndexFastScan.h +145 -0
  37. data/vendor/faiss/faiss/IndexFlat.cpp +52 -69
  38. data/vendor/faiss/faiss/IndexFlat.h +16 -19
  39. data/vendor/faiss/faiss/IndexFlatCodes.cpp +101 -0
  40. data/vendor/faiss/faiss/IndexFlatCodes.h +59 -0
  41. data/vendor/faiss/faiss/IndexHNSW.cpp +66 -138
  42. data/vendor/faiss/faiss/IndexHNSW.h +4 -2
  43. data/vendor/faiss/faiss/IndexIDMap.cpp +247 -0
  44. data/vendor/faiss/faiss/IndexIDMap.h +107 -0
  45. data/vendor/faiss/faiss/IndexIVF.cpp +200 -40
  46. data/vendor/faiss/faiss/IndexIVF.h +59 -22
  47. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +393 -0
  48. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +183 -0
  49. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +590 -0
  50. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +171 -0
  51. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +1290 -0
  52. data/vendor/faiss/faiss/IndexIVFFastScan.h +213 -0
  53. data/vendor/faiss/faiss/IndexIVFFlat.cpp +43 -26
  54. data/vendor/faiss/faiss/IndexIVFFlat.h +4 -2
  55. data/vendor/faiss/faiss/IndexIVFPQ.cpp +238 -53
  56. data/vendor/faiss/faiss/IndexIVFPQ.h +6 -2
  57. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +23 -852
  58. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +7 -112
  59. data/vendor/faiss/faiss/IndexIVFPQR.cpp +3 -3
  60. data/vendor/faiss/faiss/IndexIVFPQR.h +1 -1
  61. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +63 -40
  62. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +23 -7
  63. data/vendor/faiss/faiss/IndexLSH.cpp +8 -32
  64. data/vendor/faiss/faiss/IndexLSH.h +4 -16
  65. data/vendor/faiss/faiss/IndexLattice.cpp +7 -1
  66. data/vendor/faiss/faiss/IndexLattice.h +3 -1
  67. data/vendor/faiss/faiss/IndexNNDescent.cpp +4 -5
  68. data/vendor/faiss/faiss/IndexNNDescent.h +2 -1
  69. data/vendor/faiss/faiss/IndexNSG.cpp +37 -5
  70. data/vendor/faiss/faiss/IndexNSG.h +25 -1
  71. data/vendor/faiss/faiss/IndexPQ.cpp +108 -120
  72. data/vendor/faiss/faiss/IndexPQ.h +21 -22
  73. data/vendor/faiss/faiss/IndexPQFastScan.cpp +15 -450
  74. data/vendor/faiss/faiss/IndexPQFastScan.h +15 -78
  75. data/vendor/faiss/faiss/IndexPreTransform.cpp +47 -8
  76. data/vendor/faiss/faiss/IndexPreTransform.h +15 -3
  77. data/vendor/faiss/faiss/IndexRefine.cpp +36 -4
  78. data/vendor/faiss/faiss/IndexRefine.h +14 -2
  79. data/vendor/faiss/faiss/IndexReplicas.cpp +4 -2
  80. data/vendor/faiss/faiss/IndexReplicas.h +2 -1
  81. data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +438 -0
  82. data/vendor/faiss/faiss/IndexRowwiseMinMax.h +92 -0
  83. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +28 -43
  84. data/vendor/faiss/faiss/IndexScalarQuantizer.h +8 -23
  85. data/vendor/faiss/faiss/IndexShards.cpp +4 -1
  86. data/vendor/faiss/faiss/IndexShards.h +2 -1
  87. data/vendor/faiss/faiss/MetaIndexes.cpp +5 -178
  88. data/vendor/faiss/faiss/MetaIndexes.h +3 -81
  89. data/vendor/faiss/faiss/VectorTransform.cpp +45 -1
  90. data/vendor/faiss/faiss/VectorTransform.h +25 -4
  91. data/vendor/faiss/faiss/clone_index.cpp +26 -3
  92. data/vendor/faiss/faiss/clone_index.h +3 -0
  93. data/vendor/faiss/faiss/cppcontrib/SaDecodeKernels.h +300 -0
  94. data/vendor/faiss/faiss/cppcontrib/detail/CoarseBitType.h +24 -0
  95. data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +195 -0
  96. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +2058 -0
  97. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +408 -0
  98. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-neon-inl.h +2147 -0
  99. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMax-inl.h +460 -0
  100. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMaxFP16-inl.h +465 -0
  101. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +1618 -0
  102. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +251 -0
  103. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-neon-inl.h +1452 -0
  104. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +1 -0
  105. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +2 -6
  106. data/vendor/faiss/faiss/gpu/GpuIcmEncoder.h +60 -0
  107. data/vendor/faiss/faiss/gpu/GpuIndex.h +28 -4
  108. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +2 -1
  109. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +10 -8
  110. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +75 -14
  111. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +19 -32
  112. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -31
  113. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +22 -28
  114. data/vendor/faiss/faiss/gpu/GpuResources.cpp +14 -0
  115. data/vendor/faiss/faiss/gpu/GpuResources.h +16 -3
  116. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +3 -3
  117. data/vendor/faiss/faiss/gpu/impl/IndexUtils.h +32 -0
  118. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +1 -0
  119. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +311 -75
  120. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +10 -0
  121. data/vendor/faiss/faiss/gpu/test/TestUtils.h +3 -0
  122. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +2 -2
  123. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +5 -4
  124. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +331 -29
  125. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +110 -19
  126. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +0 -54
  127. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +0 -76
  128. data/vendor/faiss/faiss/impl/DistanceComputer.h +64 -0
  129. data/vendor/faiss/faiss/impl/HNSW.cpp +133 -32
  130. data/vendor/faiss/faiss/impl/HNSW.h +19 -16
  131. data/vendor/faiss/faiss/impl/IDSelector.cpp +125 -0
  132. data/vendor/faiss/faiss/impl/IDSelector.h +135 -0
  133. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +378 -217
  134. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +106 -29
  135. data/vendor/faiss/faiss/impl/LookupTableScaler.h +77 -0
  136. data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -0
  137. data/vendor/faiss/faiss/impl/NSG.cpp +1 -4
  138. data/vendor/faiss/faiss/impl/NSG.h +1 -1
  139. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +383 -0
  140. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +154 -0
  141. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +225 -145
  142. data/vendor/faiss/faiss/impl/ProductQuantizer.h +29 -10
  143. data/vendor/faiss/faiss/impl/Quantizer.h +43 -0
  144. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +521 -55
  145. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +94 -16
  146. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -0
  147. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +108 -191
  148. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +18 -18
  149. data/vendor/faiss/faiss/impl/index_read.cpp +338 -24
  150. data/vendor/faiss/faiss/impl/index_write.cpp +300 -18
  151. data/vendor/faiss/faiss/impl/io.cpp +1 -1
  152. data/vendor/faiss/faiss/impl/io_macros.h +20 -0
  153. data/vendor/faiss/faiss/impl/kmeans1d.cpp +303 -0
  154. data/vendor/faiss/faiss/impl/kmeans1d.h +48 -0
  155. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +56 -16
  156. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +25 -8
  157. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +66 -25
  158. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +75 -27
  159. data/vendor/faiss/faiss/index_factory.cpp +772 -412
  160. data/vendor/faiss/faiss/index_factory.h +3 -0
  161. data/vendor/faiss/faiss/index_io.h +5 -0
  162. data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -0
  163. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +4 -1
  164. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +2 -1
  165. data/vendor/faiss/faiss/python/python_callbacks.cpp +27 -0
  166. data/vendor/faiss/faiss/python/python_callbacks.h +15 -0
  167. data/vendor/faiss/faiss/utils/Heap.h +31 -15
  168. data/vendor/faiss/faiss/utils/distances.cpp +384 -58
  169. data/vendor/faiss/faiss/utils/distances.h +149 -18
  170. data/vendor/faiss/faiss/utils/distances_simd.cpp +776 -6
  171. data/vendor/faiss/faiss/utils/extra_distances.cpp +12 -7
  172. data/vendor/faiss/faiss/utils/extra_distances.h +3 -1
  173. data/vendor/faiss/faiss/utils/fp16-fp16c.h +21 -0
  174. data/vendor/faiss/faiss/utils/fp16-inl.h +101 -0
  175. data/vendor/faiss/faiss/utils/fp16.h +11 -0
  176. data/vendor/faiss/faiss/utils/hamming-inl.h +54 -0
  177. data/vendor/faiss/faiss/utils/hamming.cpp +0 -48
  178. data/vendor/faiss/faiss/utils/ordered_key_value.h +10 -0
  179. data/vendor/faiss/faiss/utils/quantize_lut.cpp +62 -0
  180. data/vendor/faiss/faiss/utils/quantize_lut.h +20 -0
  181. data/vendor/faiss/faiss/utils/random.cpp +53 -0
  182. data/vendor/faiss/faiss/utils/random.h +5 -0
  183. data/vendor/faiss/faiss/utils/simdlib_avx2.h +4 -0
  184. data/vendor/faiss/faiss/utils/simdlib_emulated.h +6 -1
  185. data/vendor/faiss/faiss/utils/simdlib_neon.h +7 -2
  186. data/vendor/faiss/faiss/utils/utils.h +1 -1
  187. metadata +46 -5
  188. data/vendor/faiss/faiss/IndexResidual.cpp +0 -291
  189. data/vendor/faiss/faiss/IndexResidual.h +0 -152
@@ -78,9 +78,10 @@ void IndexNSG::search(
78
78
  const float* x,
79
79
  idx_t k,
80
80
  float* distances,
81
- idx_t* labels) const
82
-
83
- {
81
+ idx_t* labels,
82
+ const SearchParameters* params) const {
83
+ FAISS_THROW_IF_NOT_MSG(
84
+ !params, "search params not supported for this index");
84
85
  FAISS_THROW_IF_NOT_MSG(
85
86
  storage,
86
87
  "Please use IndexNSGFlat (or variants) instead of IndexNSG directly");
@@ -104,9 +105,7 @@ void IndexNSG::search(
104
105
  float* simi = distances + i * k;
105
106
  dis->set_query(x + i * d);
106
107
 
107
- maxheap_heapify(k, simi, idxi);
108
108
  nsg.search(*dis, k, idxi, simi, vt);
109
- maxheap_reorder(k, simi, idxi);
110
109
 
111
110
  vt.advance();
112
111
  }
@@ -300,4 +299,37 @@ IndexNSGFlat::IndexNSGFlat(int d, int R, MetricType metric)
300
299
  is_trained = true;
301
300
  }
302
301
 
302
+ /**************************************************************
303
+ * IndexNSGPQ implementation
304
+ **************************************************************/
305
+
306
+ IndexNSGPQ::IndexNSGPQ() {}
307
+
308
+ IndexNSGPQ::IndexNSGPQ(int d, int pq_m, int M)
309
+ : IndexNSG(new IndexPQ(d, pq_m, 8), M) {
310
+ own_fields = true;
311
+ is_trained = false;
312
+ }
313
+
314
+ void IndexNSGPQ::train(idx_t n, const float* x) {
315
+ IndexNSG::train(n, x);
316
+ (dynamic_cast<IndexPQ*>(storage))->pq.compute_sdc_table();
317
+ }
318
+
319
+ /**************************************************************
320
+ * IndexNSGSQ implementation
321
+ **************************************************************/
322
+
323
+ IndexNSGSQ::IndexNSGSQ(
324
+ int d,
325
+ ScalarQuantizer::QuantizerType qtype,
326
+ int M,
327
+ MetricType metric)
328
+ : IndexNSG(new IndexScalarQuantizer(d, qtype, metric), M) {
329
+ is_trained = false;
330
+ own_fields = true;
331
+ }
332
+
333
+ IndexNSGSQ::IndexNSGSQ() {}
334
+
303
335
  } // namespace faiss
@@ -13,6 +13,8 @@
13
13
 
14
14
  #include <faiss/IndexFlat.h>
15
15
  #include <faiss/IndexNNDescent.h>
16
+ #include <faiss/IndexPQ.h>
17
+ #include <faiss/IndexScalarQuantizer.h>
16
18
  #include <faiss/impl/NSG.h>
17
19
  #include <faiss/utils/utils.h>
18
20
 
@@ -64,7 +66,8 @@ struct IndexNSG : Index {
64
66
  const float* x,
65
67
  idx_t k,
66
68
  float* distances,
67
- idx_t* labels) const override;
69
+ idx_t* labels,
70
+ const SearchParameters* params = nullptr) const override;
68
71
 
69
72
  void reconstruct(idx_t key, float* recons) const override;
70
73
 
@@ -82,4 +85,25 @@ struct IndexNSGFlat : IndexNSG {
82
85
  IndexNSGFlat(int d, int R, MetricType metric = METRIC_L2);
83
86
  };
84
87
 
88
+ /** PQ index topped with with a NSG structure to access elements
89
+ * more efficiently.
90
+ */
91
+ struct IndexNSGPQ : IndexNSG {
92
+ IndexNSGPQ();
93
+ IndexNSGPQ(int d, int pq_m, int M);
94
+ void train(idx_t n, const float* x) override;
95
+ };
96
+
97
+ /** SQ index topped with with a NSG structure to access elements
98
+ * more efficiently.
99
+ */
100
+ struct IndexNSGSQ : IndexNSG {
101
+ IndexNSGSQ();
102
+ IndexNSGSQ(
103
+ int d,
104
+ ScalarQuantizer::QuantizerType qtype,
105
+ int M,
106
+ MetricType metric = METRIC_L2);
107
+ };
108
+
85
109
  } // namespace faiss
@@ -5,8 +5,6 @@
5
5
  * LICENSE file in the root directory of this source tree.
6
6
  */
7
7
 
8
- // -*- c++ -*-
9
-
10
8
  #include <faiss/IndexPQ.h>
11
9
 
12
10
  #include <cinttypes>
@@ -17,7 +15,7 @@
17
15
 
18
16
  #include <algorithm>
19
17
 
20
- #include <faiss/impl/AuxIndexStructures.h>
18
+ #include <faiss/impl/DistanceComputer.h>
21
19
  #include <faiss/impl/FaissAssert.h>
22
20
  #include <faiss/utils/hamming.h>
23
21
 
@@ -28,12 +26,13 @@ namespace faiss {
28
26
  ********************************************************/
29
27
 
30
28
  IndexPQ::IndexPQ(int d, size_t M, size_t nbits, MetricType metric)
31
- : Index(d, metric), pq(d, M, nbits) {
29
+ : IndexFlatCodes(0, d, metric), pq(d, M, nbits) {
32
30
  is_trained = false;
33
31
  do_polysemous_training = false;
34
32
  polysemous_ht = nbits * M + 1;
35
33
  search_type = ST_PQ;
36
34
  encode_signs = false;
35
+ code_size = pq.code_size;
37
36
  }
38
37
 
39
38
  IndexPQ::IndexPQ() {
@@ -69,69 +68,19 @@ void IndexPQ::train(idx_t n, const float* x) {
69
68
  is_trained = true;
70
69
  }
71
70
 
72
- void IndexPQ::add(idx_t n, const float* x) {
73
- FAISS_THROW_IF_NOT(is_trained);
74
- codes.resize((n + ntotal) * pq.code_size);
75
- pq.compute_codes(x, &codes[ntotal * pq.code_size], n);
76
- ntotal += n;
77
- }
78
-
79
- size_t IndexPQ::remove_ids(const IDSelector& sel) {
80
- idx_t j = 0;
81
- for (idx_t i = 0; i < ntotal; i++) {
82
- if (sel.is_member(i)) {
83
- // should be removed
84
- } else {
85
- if (i > j) {
86
- memmove(&codes[pq.code_size * j],
87
- &codes[pq.code_size * i],
88
- pq.code_size);
89
- }
90
- j++;
91
- }
92
- }
93
- size_t nremove = ntotal - j;
94
- if (nremove > 0) {
95
- ntotal = j;
96
- codes.resize(ntotal * pq.code_size);
97
- }
98
- return nremove;
99
- }
100
-
101
- void IndexPQ::reset() {
102
- codes.clear();
103
- ntotal = 0;
104
- }
105
-
106
- void IndexPQ::reconstruct_n(idx_t i0, idx_t ni, float* recons) const {
107
- FAISS_THROW_IF_NOT(ni == 0 || (i0 >= 0 && i0 + ni <= ntotal));
108
- for (idx_t i = 0; i < ni; i++) {
109
- const uint8_t* code = &codes[(i0 + i) * pq.code_size];
110
- pq.decode(code, recons + i * d);
111
- }
112
- }
113
-
114
- void IndexPQ::reconstruct(idx_t key, float* recons) const {
115
- FAISS_THROW_IF_NOT(key >= 0 && key < ntotal);
116
- pq.decode(&codes[key * pq.code_size], recons);
117
- }
118
-
119
71
  namespace {
120
72
 
121
73
  template <class PQDecoder>
122
- struct PQDistanceComputer : DistanceComputer {
74
+ struct PQDistanceComputer : FlatCodesDistanceComputer {
123
75
  size_t d;
124
76
  MetricType metric;
125
77
  Index::idx_t nb;
126
- const uint8_t* codes;
127
- size_t code_size;
128
78
  const ProductQuantizer& pq;
129
79
  const float* sdc;
130
80
  std::vector<float> precomputed_table;
131
81
  size_t ndis;
132
82
 
133
- float operator()(idx_t i) override {
134
- const uint8_t* code = codes + i * code_size;
83
+ float distance_to_code(const uint8_t* code) final {
135
84
  const float* dt = precomputed_table.data();
136
85
  PQDecoder decoder(code, pq.nbits);
137
86
  float accu = 0;
@@ -158,13 +107,15 @@ struct PQDistanceComputer : DistanceComputer {
158
107
  return accu;
159
108
  }
160
109
 
161
- explicit PQDistanceComputer(const IndexPQ& storage) : pq(storage.pq) {
110
+ explicit PQDistanceComputer(const IndexPQ& storage)
111
+ : FlatCodesDistanceComputer(
112
+ storage.codes.data(),
113
+ storage.code_size),
114
+ pq(storage.pq) {
162
115
  precomputed_table.resize(pq.M * pq.ksub);
163
116
  nb = storage.ntotal;
164
117
  d = storage.d;
165
118
  metric = storage.metric_type;
166
- codes = storage.codes.data();
167
- code_size = pq.code_size;
168
119
  if (pq.sdc_table.size() == pq.ksub * pq.ksub * pq.M) {
169
120
  sdc = pq.sdc_table.data();
170
121
  } else {
@@ -184,7 +135,7 @@ struct PQDistanceComputer : DistanceComputer {
184
135
 
185
136
  } // namespace
186
137
 
187
- DistanceComputer* IndexPQ::get_distance_computer() const {
138
+ FlatCodesDistanceComputer* IndexPQ::get_FlatCodesDistanceComputer() const {
188
139
  if (pq.nbits == 8) {
189
140
  return new PQDistanceComputer<PQDecoder8>(*this);
190
141
  } else if (pq.nbits == 16) {
@@ -203,10 +154,21 @@ void IndexPQ::search(
203
154
  const float* x,
204
155
  idx_t k,
205
156
  float* distances,
206
- idx_t* labels) const {
157
+ idx_t* labels,
158
+ const SearchParameters* iparams) const {
207
159
  FAISS_THROW_IF_NOT(k > 0);
208
-
209
160
  FAISS_THROW_IF_NOT(is_trained);
161
+
162
+ const SearchParametersPQ* params = nullptr;
163
+ Search_type_t search_type = this->search_type;
164
+
165
+ if (iparams) {
166
+ params = dynamic_cast<const SearchParametersPQ*>(iparams);
167
+ FAISS_THROW_IF_NOT_MSG(params, "invalid search params");
168
+ FAISS_THROW_IF_NOT_MSG(!params->sel, "selector not supported");
169
+ search_type = params->search_type;
170
+ }
171
+
210
172
  if (search_type == ST_PQ) { // Simple PQ search
211
173
 
212
174
  if (metric_type == METRIC_L2) {
@@ -225,8 +187,16 @@ void IndexPQ::search(
225
187
  search_type == ST_polysemous ||
226
188
  search_type == ST_polysemous_generalize) {
227
189
  FAISS_THROW_IF_NOT(metric_type == METRIC_L2);
228
-
229
- search_core_polysemous(n, x, k, distances, labels);
190
+ int polysemous_ht =
191
+ params ? params->polysemous_ht : this->polysemous_ht;
192
+ search_core_polysemous(
193
+ n,
194
+ x,
195
+ k,
196
+ distances,
197
+ labels,
198
+ polysemous_ht,
199
+ search_type == ST_polysemous_generalize);
230
200
 
231
201
  } else { // code-to-code distances
232
202
 
@@ -302,12 +272,12 @@ static size_t polysemous_inner_loop(
302
272
  const uint8_t* q_code,
303
273
  size_t k,
304
274
  float* heap_dis,
305
- int64_t* heap_ids) {
275
+ int64_t* heap_ids,
276
+ int ht) {
306
277
  int M = index.pq.M;
307
278
  int code_size = index.pq.code_size;
308
279
  int ksub = index.pq.ksub;
309
280
  size_t ntotal = index.ntotal;
310
- int ht = index.polysemous_ht;
311
281
 
312
282
  const uint8_t* b_code = index.codes.data();
313
283
 
@@ -342,11 +312,16 @@ void IndexPQ::search_core_polysemous(
342
312
  const float* x,
343
313
  idx_t k,
344
314
  float* distances,
345
- idx_t* labels) const {
315
+ idx_t* labels,
316
+ int polysemous_ht,
317
+ bool generalized_hamming) const {
346
318
  FAISS_THROW_IF_NOT(k > 0);
347
-
348
319
  FAISS_THROW_IF_NOT(pq.nbits == 8);
349
320
 
321
+ if (polysemous_ht == 0) {
322
+ polysemous_ht = pq.nbits * pq.M + 1;
323
+ }
324
+
350
325
  // PQ distance tables
351
326
  float* dis_tables = new float[n * pq.ksub * pq.M];
352
327
  ScopeDeleter<float> del(dis_tables);
@@ -369,7 +344,9 @@ void IndexPQ::search_core_polysemous(
369
344
 
370
345
  size_t n_pass = 0;
371
346
 
372
- #pragma omp parallel for reduction(+ : n_pass)
347
+ int bad_code_size = 0;
348
+
349
+ #pragma omp parallel for reduction(+ : n_pass, bad_code_size)
373
350
  for (idx_t qi = 0; qi < n; qi++) {
374
351
  const uint8_t* q_code = q_codes + qi * pq.code_size;
375
352
 
@@ -379,28 +356,24 @@ void IndexPQ::search_core_polysemous(
379
356
  float* heap_dis = distances + qi * k;
380
357
  maxheap_heapify(k, heap_dis, heap_ids);
381
358
 
382
- if (search_type == ST_polysemous) {
359
+ if (!generalized_hamming) {
383
360
  switch (pq.code_size) {
384
- case 4:
385
- n_pass += polysemous_inner_loop<HammingComputer4>(
386
- *this, dis_table_qi, q_code, k, heap_dis, heap_ids);
387
- break;
388
- case 8:
389
- n_pass += polysemous_inner_loop<HammingComputer8>(
390
- *this, dis_table_qi, q_code, k, heap_dis, heap_ids);
391
- break;
392
- case 16:
393
- n_pass += polysemous_inner_loop<HammingComputer16>(
394
- *this, dis_table_qi, q_code, k, heap_dis, heap_ids);
395
- break;
396
- case 32:
397
- n_pass += polysemous_inner_loop<HammingComputer32>(
398
- *this, dis_table_qi, q_code, k, heap_dis, heap_ids);
399
- break;
400
- case 20:
401
- n_pass += polysemous_inner_loop<HammingComputer20>(
402
- *this, dis_table_qi, q_code, k, heap_dis, heap_ids);
403
- break;
361
+ #define DISPATCH(cs) \
362
+ case cs: \
363
+ n_pass += polysemous_inner_loop<HammingComputer##cs>( \
364
+ *this, \
365
+ dis_table_qi, \
366
+ q_code, \
367
+ k, \
368
+ heap_dis, \
369
+ heap_ids, \
370
+ polysemous_ht); \
371
+ break;
372
+ DISPATCH(4)
373
+ DISPATCH(8)
374
+ DISPATCH(16)
375
+ DISPATCH(32)
376
+ DISPATCH(20)
404
377
  default:
405
378
  if (pq.code_size % 4 == 0) {
406
379
  n_pass += polysemous_inner_loop<HammingComputerDefault>(
@@ -409,28 +382,30 @@ void IndexPQ::search_core_polysemous(
409
382
  q_code,
410
383
  k,
411
384
  heap_dis,
412
- heap_ids);
385
+ heap_ids,
386
+ polysemous_ht);
413
387
  } else {
414
- FAISS_THROW_FMT(
415
- "code size %zd not supported for polysemous",
416
- pq.code_size);
388
+ bad_code_size++;
417
389
  }
418
390
  break;
419
391
  }
420
- } else {
392
+ #undef DISPATCH
393
+ } else { // generalized hamming
421
394
  switch (pq.code_size) {
422
- case 8:
423
- n_pass += polysemous_inner_loop<GenHammingComputer8>(
424
- *this, dis_table_qi, q_code, k, heap_dis, heap_ids);
425
- break;
426
- case 16:
427
- n_pass += polysemous_inner_loop<GenHammingComputer16>(
428
- *this, dis_table_qi, q_code, k, heap_dis, heap_ids);
429
- break;
430
- case 32:
431
- n_pass += polysemous_inner_loop<GenHammingComputer32>(
432
- *this, dis_table_qi, q_code, k, heap_dis, heap_ids);
433
- break;
395
+ #define DISPATCH(cs) \
396
+ case cs: \
397
+ n_pass += polysemous_inner_loop<GenHammingComputer##cs>( \
398
+ *this, \
399
+ dis_table_qi, \
400
+ q_code, \
401
+ k, \
402
+ heap_dis, \
403
+ heap_ids, \
404
+ polysemous_ht); \
405
+ break;
406
+ DISPATCH(8)
407
+ DISPATCH(16)
408
+ DISPATCH(32)
434
409
  default:
435
410
  if (pq.code_size % 8 == 0) {
436
411
  n_pass += polysemous_inner_loop<GenHammingComputerM8>(
@@ -439,27 +414,29 @@ void IndexPQ::search_core_polysemous(
439
414
  q_code,
440
415
  k,
441
416
  heap_dis,
442
- heap_ids);
417
+ heap_ids,
418
+ polysemous_ht);
443
419
  } else {
444
- FAISS_THROW_FMT(
445
- "code size %zd not supported for polysemous",
446
- pq.code_size);
420
+ bad_code_size++;
447
421
  }
448
422
  break;
423
+ #undef DISPATCH
449
424
  }
450
425
  }
451
426
  maxheap_reorder(k, heap_dis, heap_ids);
452
427
  }
453
428
 
429
+ if (bad_code_size) {
430
+ FAISS_THROW_FMT(
431
+ "code size %zd not supported for polysemous", pq.code_size);
432
+ }
433
+
454
434
  indexPQ_stats.nq += n;
455
435
  indexPQ_stats.ncode += n * ntotal;
456
436
  indexPQ_stats.n_hamming_pass += n_pass;
457
437
  }
458
438
 
459
439
  /* The standalone codec interface (just remaps to the PQ functions) */
460
- size_t IndexPQ::sa_code_size() const {
461
- return pq.code_size;
462
- }
463
440
 
464
441
  void IndexPQ::sa_encode(idx_t n, const float* x, uint8_t* bytes) const {
465
442
  pq.compute_codes(x, bytes, n);
@@ -914,19 +891,25 @@ void MultiIndexQuantizer::train(idx_t n, const float* x) {
914
891
  ntotal *= pq.ksub;
915
892
  }
916
893
 
894
+ // block size used in MultiIndexQuantizer::search
895
+ int multi_index_quantizer_search_bs = 32768;
896
+
917
897
  void MultiIndexQuantizer::search(
918
898
  idx_t n,
919
899
  const float* x,
920
900
  idx_t k,
921
901
  float* distances,
922
- idx_t* labels) const {
923
- if (n == 0)
902
+ idx_t* labels,
903
+ const SearchParameters* params) const {
904
+ FAISS_THROW_IF_NOT_MSG(
905
+ !params, "search params not supported for this index");
906
+ if (n == 0) {
924
907
  return;
925
-
908
+ }
926
909
  FAISS_THROW_IF_NOT(k > 0);
927
910
 
928
911
  // the allocation just below can be severe...
929
- idx_t bs = 32768;
912
+ idx_t bs = multi_index_quantizer_search_bs;
930
913
  if (n > bs) {
931
914
  for (idx_t i0 = 0; i0 < n; i0 += bs) {
932
915
  idx_t i1 = std::min(i0 + bs, n);
@@ -1061,9 +1044,14 @@ void MultiIndexQuantizer2::search(
1061
1044
  const float* x,
1062
1045
  idx_t K,
1063
1046
  float* distances,
1064
- idx_t* labels) const {
1065
- if (n == 0)
1047
+ idx_t* labels,
1048
+ const SearchParameters* params) const {
1049
+ FAISS_THROW_IF_NOT_MSG(
1050
+ !params, "search params not supported for this index");
1051
+
1052
+ if (n == 0) {
1066
1053
  return;
1054
+ }
1067
1055
 
1068
1056
  int k2 = std::min(K, int64_t(pq.ksub));
1069
1057
  FAISS_THROW_IF_NOT(k2);
@@ -12,7 +12,7 @@
12
12
 
13
13
  #include <vector>
14
14
 
15
- #include <faiss/Index.h>
15
+ #include <faiss/IndexFlatCodes.h>
16
16
  #include <faiss/impl/PolysemousTraining.h>
17
17
  #include <faiss/impl/ProductQuantizer.h>
18
18
  #include <faiss/impl/platform_macros.h>
@@ -21,13 +21,10 @@ namespace faiss {
21
21
 
22
22
  /** Index based on a product quantizer. Stored vectors are
23
23
  * approximated by PQ codes. */
24
- struct IndexPQ : Index {
24
+ struct IndexPQ : IndexFlatCodes {
25
25
  /// The product quantizer used to encode the vectors
26
26
  ProductQuantizer pq;
27
27
 
28
- /// Codes. Size ntotal * pq.code_size
29
- std::vector<uint8_t> codes;
30
-
31
28
  /** Constructor.
32
29
  *
33
30
  * @param d dimensionality of the input vectors
@@ -43,31 +40,20 @@ struct IndexPQ : Index {
43
40
 
44
41
  void train(idx_t n, const float* x) override;
45
42
 
46
- void add(idx_t n, const float* x) override;
47
-
48
43
  void search(
49
44
  idx_t n,
50
45
  const float* x,
51
46
  idx_t k,
52
47
  float* distances,
53
- idx_t* labels) const override;
54
-
55
- void reset() override;
56
-
57
- void reconstruct_n(idx_t i0, idx_t ni, float* recons) const override;
58
-
59
- void reconstruct(idx_t key, float* recons) const override;
60
-
61
- size_t remove_ids(const IDSelector& sel) override;
48
+ idx_t* labels,
49
+ const SearchParameters* params = nullptr) const override;
62
50
 
63
51
  /* The standalone codec interface */
64
- size_t sa_code_size() const override;
65
-
66
52
  void sa_encode(idx_t n, const float* x, uint8_t* bytes) const override;
67
53
 
68
54
  void sa_decode(idx_t n, const uint8_t* bytes, float* x) const override;
69
55
 
70
- DistanceComputer* get_distance_computer() const override;
56
+ FlatCodesDistanceComputer* get_FlatCodesDistanceComputer() const override;
71
57
 
72
58
  /******************************************************
73
59
  * Polysemous codes implementation
@@ -102,7 +88,9 @@ struct IndexPQ : Index {
102
88
  const float* x,
103
89
  idx_t k,
104
90
  float* distances,
105
- idx_t* labels) const;
91
+ idx_t* labels,
92
+ int polysemous_ht,
93
+ bool generalized_hamming) const;
106
94
 
107
95
  /// prepare query for a polysemous search, but instead of
108
96
  /// computing the result, just get the histogram of Hamming
@@ -124,6 +112,12 @@ struct IndexPQ : Index {
124
112
  void hamming_distance_table(idx_t n, const float* x, int32_t* dis) const;
125
113
  };
126
114
 
115
+ /// override search parameters from the class
116
+ struct SearchParametersPQ : SearchParameters {
117
+ IndexPQ::Search_type_t search_type;
118
+ int polysemous_ht;
119
+ };
120
+
127
121
  /// statistics are robust to internal threading, but not if
128
122
  /// IndexPQ::search is called by multiple threads
129
123
  struct IndexPQStats {
@@ -157,7 +151,8 @@ struct MultiIndexQuantizer : Index {
157
151
  const float* x,
158
152
  idx_t k,
159
153
  float* distances,
160
- idx_t* labels) const override;
154
+ idx_t* labels,
155
+ const SearchParameters* params = nullptr) const override;
161
156
 
162
157
  /// add and reset will crash at runtime
163
158
  void add(idx_t n, const float* x) override;
@@ -168,6 +163,9 @@ struct MultiIndexQuantizer : Index {
168
163
  void reconstruct(idx_t key, float* recons) const override;
169
164
  };
170
165
 
166
+ // block size used in MultiIndexQuantizer::search
167
+ FAISS_API extern int multi_index_quantizer_search_bs;
168
+
171
169
  /** MultiIndexQuantizer where the PQ assignmnet is performed by sub-indexes
172
170
  */
173
171
  struct MultiIndexQuantizer2 : MultiIndexQuantizer {
@@ -190,7 +188,8 @@ struct MultiIndexQuantizer2 : MultiIndexQuantizer {
190
188
  const float* x,
191
189
  idx_t k,
192
190
  float* distances,
193
- idx_t* labels) const override;
191
+ idx_t* labels,
192
+ const SearchParameters* params = nullptr) const override;
194
193
  };
195
194
 
196
195
  } // namespace faiss