faiss 0.3.0 → 0.3.1

Sign up to get free protection for your applications and to get access to all the features.
Files changed (171) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +5 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +1 -1
  5. data/ext/faiss/extconf.rb +9 -2
  6. data/ext/faiss/index.cpp +1 -1
  7. data/ext/faiss/index_binary.cpp +2 -2
  8. data/ext/faiss/product_quantizer.cpp +1 -1
  9. data/lib/faiss/version.rb +1 -1
  10. data/vendor/faiss/faiss/AutoTune.cpp +7 -7
  11. data/vendor/faiss/faiss/AutoTune.h +0 -1
  12. data/vendor/faiss/faiss/Clustering.cpp +4 -18
  13. data/vendor/faiss/faiss/Clustering.h +31 -21
  14. data/vendor/faiss/faiss/IVFlib.cpp +22 -11
  15. data/vendor/faiss/faiss/Index.cpp +1 -1
  16. data/vendor/faiss/faiss/Index.h +20 -5
  17. data/vendor/faiss/faiss/Index2Layer.cpp +7 -7
  18. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +176 -166
  19. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +15 -15
  20. data/vendor/faiss/faiss/IndexBinary.cpp +9 -4
  21. data/vendor/faiss/faiss/IndexBinary.h +8 -19
  22. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +2 -1
  23. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +24 -31
  24. data/vendor/faiss/faiss/IndexBinaryHash.cpp +25 -50
  25. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +106 -187
  26. data/vendor/faiss/faiss/IndexFastScan.cpp +90 -159
  27. data/vendor/faiss/faiss/IndexFastScan.h +9 -8
  28. data/vendor/faiss/faiss/IndexFlat.cpp +195 -3
  29. data/vendor/faiss/faiss/IndexFlat.h +20 -1
  30. data/vendor/faiss/faiss/IndexFlatCodes.cpp +11 -0
  31. data/vendor/faiss/faiss/IndexFlatCodes.h +3 -1
  32. data/vendor/faiss/faiss/IndexHNSW.cpp +112 -316
  33. data/vendor/faiss/faiss/IndexHNSW.h +12 -48
  34. data/vendor/faiss/faiss/IndexIDMap.cpp +69 -28
  35. data/vendor/faiss/faiss/IndexIDMap.h +24 -2
  36. data/vendor/faiss/faiss/IndexIVF.cpp +159 -53
  37. data/vendor/faiss/faiss/IndexIVF.h +37 -5
  38. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +18 -26
  39. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +3 -2
  40. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +19 -46
  41. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +4 -3
  42. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +433 -405
  43. data/vendor/faiss/faiss/IndexIVFFastScan.h +56 -26
  44. data/vendor/faiss/faiss/IndexIVFFlat.cpp +15 -5
  45. data/vendor/faiss/faiss/IndexIVFFlat.h +3 -2
  46. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +172 -0
  47. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.h +56 -0
  48. data/vendor/faiss/faiss/IndexIVFPQ.cpp +78 -122
  49. data/vendor/faiss/faiss/IndexIVFPQ.h +6 -7
  50. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +18 -50
  51. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +4 -3
  52. data/vendor/faiss/faiss/IndexIVFPQR.cpp +45 -29
  53. data/vendor/faiss/faiss/IndexIVFPQR.h +5 -2
  54. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +25 -27
  55. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +6 -6
  56. data/vendor/faiss/faiss/IndexLSH.cpp +14 -16
  57. data/vendor/faiss/faiss/IndexNNDescent.cpp +3 -4
  58. data/vendor/faiss/faiss/IndexNSG.cpp +11 -27
  59. data/vendor/faiss/faiss/IndexNSG.h +10 -10
  60. data/vendor/faiss/faiss/IndexPQ.cpp +72 -88
  61. data/vendor/faiss/faiss/IndexPQ.h +1 -4
  62. data/vendor/faiss/faiss/IndexPQFastScan.cpp +1 -1
  63. data/vendor/faiss/faiss/IndexPreTransform.cpp +25 -31
  64. data/vendor/faiss/faiss/IndexRefine.cpp +49 -19
  65. data/vendor/faiss/faiss/IndexRefine.h +7 -0
  66. data/vendor/faiss/faiss/IndexReplicas.cpp +23 -26
  67. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +22 -16
  68. data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -4
  69. data/vendor/faiss/faiss/IndexShards.cpp +21 -29
  70. data/vendor/faiss/faiss/IndexShardsIVF.cpp +1 -2
  71. data/vendor/faiss/faiss/MatrixStats.cpp +17 -32
  72. data/vendor/faiss/faiss/MatrixStats.h +21 -9
  73. data/vendor/faiss/faiss/MetaIndexes.cpp +35 -35
  74. data/vendor/faiss/faiss/VectorTransform.cpp +13 -26
  75. data/vendor/faiss/faiss/VectorTransform.h +7 -7
  76. data/vendor/faiss/faiss/clone_index.cpp +15 -10
  77. data/vendor/faiss/faiss/clone_index.h +3 -0
  78. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +87 -4
  79. data/vendor/faiss/faiss/gpu/GpuCloner.h +22 -0
  80. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +7 -0
  81. data/vendor/faiss/faiss/gpu/GpuDistance.h +46 -38
  82. data/vendor/faiss/faiss/gpu/GpuIndex.h +28 -4
  83. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +4 -4
  84. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +8 -9
  85. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +18 -3
  86. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -11
  87. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +1 -3
  88. data/vendor/faiss/faiss/gpu/GpuResources.cpp +24 -3
  89. data/vendor/faiss/faiss/gpu/GpuResources.h +39 -11
  90. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +117 -17
  91. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +57 -3
  92. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +1 -1
  93. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +25 -0
  94. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +129 -9
  95. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +267 -40
  96. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +299 -208
  97. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +1 -0
  98. data/vendor/faiss/faiss/gpu/utils/RaftUtils.h +75 -0
  99. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +3 -1
  100. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +5 -5
  101. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +1 -1
  102. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +1 -2
  103. data/vendor/faiss/faiss/impl/DistanceComputer.h +24 -1
  104. data/vendor/faiss/faiss/impl/FaissException.h +13 -34
  105. data/vendor/faiss/faiss/impl/HNSW.cpp +321 -70
  106. data/vendor/faiss/faiss/impl/HNSW.h +9 -8
  107. data/vendor/faiss/faiss/impl/IDSelector.h +4 -4
  108. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +3 -1
  109. data/vendor/faiss/faiss/impl/NNDescent.cpp +29 -19
  110. data/vendor/faiss/faiss/impl/NSG.h +1 -1
  111. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +14 -12
  112. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +1 -1
  113. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +24 -22
  114. data/vendor/faiss/faiss/impl/ProductQuantizer.h +1 -1
  115. data/vendor/faiss/faiss/impl/Quantizer.h +1 -1
  116. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +27 -1015
  117. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +5 -63
  118. data/vendor/faiss/faiss/impl/ResultHandler.h +232 -176
  119. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +444 -104
  120. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +0 -8
  121. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +280 -42
  122. data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +21 -14
  123. data/vendor/faiss/faiss/impl/code_distance/code_distance.h +22 -12
  124. data/vendor/faiss/faiss/impl/index_read.cpp +45 -19
  125. data/vendor/faiss/faiss/impl/index_write.cpp +60 -41
  126. data/vendor/faiss/faiss/impl/io.cpp +10 -10
  127. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -1
  128. data/vendor/faiss/faiss/impl/platform_macros.h +18 -1
  129. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +3 -0
  130. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +7 -6
  131. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +52 -38
  132. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +40 -49
  133. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +960 -0
  134. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +176 -0
  135. data/vendor/faiss/faiss/impl/simd_result_handlers.h +374 -202
  136. data/vendor/faiss/faiss/index_factory.cpp +10 -7
  137. data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -1
  138. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +27 -9
  139. data/vendor/faiss/faiss/invlists/InvertedLists.h +12 -3
  140. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +3 -3
  141. data/vendor/faiss/faiss/python/python_callbacks.cpp +1 -1
  142. data/vendor/faiss/faiss/utils/Heap.cpp +3 -1
  143. data/vendor/faiss/faiss/utils/WorkerThread.h +1 -0
  144. data/vendor/faiss/faiss/utils/distances.cpp +128 -74
  145. data/vendor/faiss/faiss/utils/distances.h +81 -4
  146. data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +5 -5
  147. data/vendor/faiss/faiss/utils/distances_fused/avx512.h +2 -2
  148. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +2 -2
  149. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +1 -1
  150. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +5 -5
  151. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.h +1 -1
  152. data/vendor/faiss/faiss/utils/distances_simd.cpp +428 -70
  153. data/vendor/faiss/faiss/utils/fp16-arm.h +29 -0
  154. data/vendor/faiss/faiss/utils/fp16.h +2 -0
  155. data/vendor/faiss/faiss/utils/hamming.cpp +162 -110
  156. data/vendor/faiss/faiss/utils/hamming.h +58 -0
  157. data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +16 -89
  158. data/vendor/faiss/faiss/utils/hamming_distance/common.h +1 -0
  159. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +15 -87
  160. data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +57 -0
  161. data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +14 -104
  162. data/vendor/faiss/faiss/utils/partitioning.cpp +3 -4
  163. data/vendor/faiss/faiss/utils/prefetch.h +77 -0
  164. data/vendor/faiss/faiss/utils/quantize_lut.cpp +0 -14
  165. data/vendor/faiss/faiss/utils/simdlib_avx2.h +0 -6
  166. data/vendor/faiss/faiss/utils/simdlib_neon.h +72 -77
  167. data/vendor/faiss/faiss/utils/sorting.cpp +140 -5
  168. data/vendor/faiss/faiss/utils/sorting.h +27 -0
  169. data/vendor/faiss/faiss/utils/utils.cpp +112 -6
  170. data/vendor/faiss/faiss/utils/utils.h +57 -20
  171. metadata +10 -3
@@ -9,10 +9,10 @@
9
9
 
10
10
  #include <faiss/IndexIVFPQ.h>
11
11
 
12
- #include <stdint.h>
13
12
  #include <cassert>
14
13
  #include <cinttypes>
15
14
  #include <cmath>
15
+ #include <cstdint>
16
16
  #include <cstdio>
17
17
 
18
18
  #include <algorithm>
@@ -64,74 +64,16 @@ IndexIVFPQ::IndexIVFPQ(
64
64
  /****************************************************************
65
65
  * training */
66
66
 
67
- void IndexIVFPQ::train_residual(idx_t n, const float* x) {
68
- train_residual_o(n, x, nullptr);
69
- }
70
-
71
- void IndexIVFPQ::train_residual_o(idx_t n, const float* x, float* residuals_2) {
72
- const float* x_in = x;
73
-
74
- x = fvecs_maybe_subsample(
75
- d,
76
- (size_t*)&n,
77
- pq.cp.max_points_per_centroid * pq.ksub,
78
- x,
79
- verbose,
80
- pq.cp.seed);
81
-
82
- ScopeDeleter<float> del_x(x_in == x ? nullptr : x);
83
-
84
- const float* trainset;
85
- ScopeDeleter<float> del_residuals;
86
- if (by_residual) {
87
- if (verbose)
88
- printf("computing residuals\n");
89
- idx_t* assign = new idx_t[n]; // assignement to coarse centroids
90
- ScopeDeleter<idx_t> del(assign);
91
- quantizer->assign(n, x, assign);
92
- float* residuals = new float[n * d];
93
- del_residuals.set(residuals);
94
- for (idx_t i = 0; i < n; i++)
95
- quantizer->compute_residual(
96
- x + i * d, residuals + i * d, assign[i]);
97
-
98
- trainset = residuals;
99
- } else {
100
- trainset = x;
101
- }
102
- if (verbose)
103
- printf("training %zdx%zd product quantizer on %" PRId64
104
- " vectors in %dD\n",
105
- pq.M,
106
- pq.ksub,
107
- n,
108
- d);
109
- pq.verbose = verbose;
110
- pq.train(n, trainset);
67
+ void IndexIVFPQ::train_encoder(idx_t n, const float* x, const idx_t* assign) {
68
+ pq.train(n, x);
111
69
 
112
70
  if (do_polysemous_training) {
113
71
  if (verbose)
114
72
  printf("doing polysemous training for PQ\n");
115
73
  PolysemousTraining default_pt;
116
- PolysemousTraining* pt = polysemous_training;
117
- if (!pt)
118
- pt = &default_pt;
119
- pt->optimize_pq_for_hamming(pq, n, trainset);
120
- }
121
-
122
- // prepare second-level residuals for refine PQ
123
- if (residuals_2) {
124
- uint8_t* train_codes = new uint8_t[pq.code_size * n];
125
- ScopeDeleter<uint8_t> del(train_codes);
126
- pq.compute_codes(trainset, train_codes, n);
127
-
128
- for (idx_t i = 0; i < n; i++) {
129
- const float* xx = trainset + i * d;
130
- float* res = residuals_2 + i * d;
131
- pq.decode(train_codes + i * pq.code_size, res);
132
- for (int j = 0; j < d; j++)
133
- res[j] = xx[j] - res[j];
134
- }
74
+ PolysemousTraining* pt =
75
+ polysemous_training ? polysemous_training : &default_pt;
76
+ pt->optimize_pq_for_hamming(pq, n, x);
135
77
  }
136
78
 
137
79
  if (by_residual) {
@@ -139,6 +81,10 @@ void IndexIVFPQ::train_residual_o(idx_t n, const float* x, float* residuals_2) {
139
81
  }
140
82
  }
141
83
 
84
+ idx_t IndexIVFPQ::train_encoder_num_vectors() const {
85
+ return pq.cp.max_points_per_centroid * pq.ksub;
86
+ }
87
+
142
88
  /****************************************************************
143
89
  * IVFPQ as codec */
144
90
 
@@ -189,24 +135,25 @@ void IndexIVFPQ::add_core(
189
135
  idx_t n,
190
136
  const float* x,
191
137
  const idx_t* xids,
192
- const idx_t* coarse_idx) {
193
- add_core_o(n, x, xids, nullptr, coarse_idx);
138
+ const idx_t* coarse_idx,
139
+ void* inverted_list_context) {
140
+ add_core_o(n, x, xids, nullptr, coarse_idx, inverted_list_context);
194
141
  }
195
142
 
196
- static float* compute_residuals(
143
+ static std::unique_ptr<float[]> compute_residuals(
197
144
  const Index* quantizer,
198
145
  idx_t n,
199
146
  const float* x,
200
147
  const idx_t* list_nos) {
201
148
  size_t d = quantizer->d;
202
- float* residuals = new float[n * d];
149
+ std::unique_ptr<float[]> residuals(new float[n * d]);
203
150
  // TODO: parallelize?
204
151
  for (size_t i = 0; i < n; i++) {
205
152
  if (list_nos[i] < 0)
206
- memset(residuals + i * d, 0, sizeof(*residuals) * d);
153
+ memset(residuals.get() + i * d, 0, sizeof(float) * d);
207
154
  else
208
155
  quantizer->compute_residual(
209
- x + i * d, residuals + i * d, list_nos[i]);
156
+ x + i * d, residuals.get() + i * d, list_nos[i]);
210
157
  }
211
158
  return residuals;
212
159
  }
@@ -218,9 +165,9 @@ void IndexIVFPQ::encode_vectors(
218
165
  uint8_t* codes,
219
166
  bool include_listnos) const {
220
167
  if (by_residual) {
221
- float* to_encode = compute_residuals(quantizer, n, x, list_nos);
222
- ScopeDeleter<float> del(to_encode);
223
- pq.compute_codes(to_encode, codes, n);
168
+ std::unique_ptr<float[]> to_encode =
169
+ compute_residuals(quantizer, n, x, list_nos);
170
+ pq.compute_codes(to_encode.get(), codes, n);
224
171
  } else {
225
172
  pq.compute_codes(x, codes, n);
226
173
  }
@@ -266,7 +213,8 @@ void IndexIVFPQ::add_core_o(
266
213
  const float* x,
267
214
  const idx_t* xids,
268
215
  float* residuals_2,
269
- const idx_t* precomputed_idx) {
216
+ const idx_t* precomputed_idx,
217
+ void* inverted_list_context) {
270
218
  idx_t bs = index_ivfpq_add_core_o_bs;
271
219
  if (n > bs) {
272
220
  for (idx_t i0 = 0; i0 < n; i0 += bs) {
@@ -283,7 +231,8 @@ void IndexIVFPQ::add_core_o(
283
231
  x + i0 * d,
284
232
  xids ? xids + i0 : nullptr,
285
233
  residuals_2 ? residuals_2 + i0 * d : nullptr,
286
- precomputed_idx ? precomputed_idx + i0 : nullptr);
234
+ precomputed_idx ? precomputed_idx + i0 : nullptr,
235
+ inverted_list_context);
287
236
  }
288
237
  return;
289
238
  }
@@ -295,31 +244,30 @@ void IndexIVFPQ::add_core_o(
295
244
  FAISS_THROW_IF_NOT(is_trained);
296
245
  double t0 = getmillisecs();
297
246
  const idx_t* idx;
298
- ScopeDeleter<idx_t> del_idx;
247
+ std::unique_ptr<idx_t[]> del_idx;
299
248
 
300
249
  if (precomputed_idx) {
301
250
  idx = precomputed_idx;
302
251
  } else {
303
252
  idx_t* idx0 = new idx_t[n];
304
- del_idx.set(idx0);
253
+ del_idx.reset(idx0);
305
254
  quantizer->assign(n, x, idx0);
306
255
  idx = idx0;
307
256
  }
308
257
 
309
258
  double t1 = getmillisecs();
310
- uint8_t* xcodes = new uint8_t[n * code_size];
311
- ScopeDeleter<uint8_t> del_xcodes(xcodes);
259
+ std::unique_ptr<uint8_t[]> xcodes(new uint8_t[n * code_size]);
312
260
 
313
261
  const float* to_encode = nullptr;
314
- ScopeDeleter<float> del_to_encode;
262
+ std::unique_ptr<const float[]> del_to_encode;
315
263
 
316
264
  if (by_residual) {
317
- to_encode = compute_residuals(quantizer, n, x, idx);
318
- del_to_encode.set(to_encode);
265
+ del_to_encode = compute_residuals(quantizer, n, x, idx);
266
+ to_encode = del_to_encode.get();
319
267
  } else {
320
268
  to_encode = x;
321
269
  }
322
- pq.compute_codes(to_encode, xcodes, n);
270
+ pq.compute_codes(to_encode, xcodes.get(), n);
323
271
 
324
272
  double t2 = getmillisecs();
325
273
  // TODO: parallelize?
@@ -335,8 +283,9 @@ void IndexIVFPQ::add_core_o(
335
283
  continue;
336
284
  }
337
285
 
338
- uint8_t* code = xcodes + i * code_size;
339
- size_t offset = invlists->add_entry(key, id, code);
286
+ uint8_t* code = xcodes.get() + i * code_size;
287
+ size_t offset =
288
+ invlists->add_entry(key, id, code, inverted_list_context);
340
289
 
341
290
  if (residuals_2) {
342
291
  float* res2 = residuals_2 + i * d;
@@ -804,7 +753,7 @@ struct QueryTables {
804
753
  }
805
754
  };
806
755
 
807
- // This way of handling the sleector is not optimal since all distances
756
+ // This way of handling the selector is not optimal since all distances
808
757
  // are computed even if the id would filter it out.
809
758
  template <class C, bool use_sel>
810
759
  struct KnnSearchResults {
@@ -937,7 +886,8 @@ struct IVFPQScannerT : QueryTables {
937
886
  float distance_2 = 0;
938
887
  float distance_3 = 0;
939
888
  distance_four_codes<PQDecoder>(
940
- pq,
889
+ pq.M,
890
+ pq.nbits,
941
891
  sim_table,
942
892
  codes + saved_j[0] * pq.code_size,
943
893
  codes + saved_j[1] * pq.code_size,
@@ -957,24 +907,30 @@ struct IVFPQScannerT : QueryTables {
957
907
  }
958
908
 
959
909
  if (counter >= 1) {
960
- float dis =
961
- dis0 +
910
+ float dis = dis0 +
962
911
  distance_single_code<PQDecoder>(
963
- pq, sim_table, codes + saved_j[0] * pq.code_size);
912
+ pq.M,
913
+ pq.nbits,
914
+ sim_table,
915
+ codes + saved_j[0] * pq.code_size);
964
916
  res.add(saved_j[0], dis);
965
917
  }
966
918
  if (counter >= 2) {
967
- float dis =
968
- dis0 +
919
+ float dis = dis0 +
969
920
  distance_single_code<PQDecoder>(
970
- pq, sim_table, codes + saved_j[1] * pq.code_size);
921
+ pq.M,
922
+ pq.nbits,
923
+ sim_table,
924
+ codes + saved_j[1] * pq.code_size);
971
925
  res.add(saved_j[1], dis);
972
926
  }
973
927
  if (counter >= 3) {
974
- float dis =
975
- dis0 +
928
+ float dis = dis0 +
976
929
  distance_single_code<PQDecoder>(
977
- pq, sim_table, codes + saved_j[2] * pq.code_size);
930
+ pq.M,
931
+ pq.nbits,
932
+ sim_table,
933
+ codes + saved_j[2] * pq.code_size);
978
934
  res.add(saved_j[2], dis);
979
935
  }
980
936
  }
@@ -1090,7 +1046,7 @@ struct IVFPQScannerT : QueryTables {
1090
1046
  const uint8_t* codes,
1091
1047
  SearchResultType& res) const {
1092
1048
  int ht = ivfpq.polysemous_ht;
1093
- size_t n_hamming_pass = 0, nup = 0;
1049
+ size_t n_hamming_pass = 0;
1094
1050
 
1095
1051
  int code_size = pq.code_size;
1096
1052
 
@@ -1137,7 +1093,8 @@ struct IVFPQScannerT : QueryTables {
1137
1093
  float distance_2 = dis0;
1138
1094
  float distance_3 = dis0;
1139
1095
  distance_four_codes<PQDecoder>(
1140
- pq,
1096
+ pq.M,
1097
+ pq.nbits,
1141
1098
  sim_table,
1142
1099
  codes + saved_j[0] * pq.code_size,
1143
1100
  codes + saved_j[1] * pq.code_size,
@@ -1165,10 +1122,12 @@ struct IVFPQScannerT : QueryTables {
1165
1122
  for (size_t kk = 0; kk < counter; kk++) {
1166
1123
  n_hamming_pass++;
1167
1124
 
1168
- float dis =
1169
- dis0 +
1125
+ float dis = dis0 +
1170
1126
  distance_single_code<PQDecoder>(
1171
- pq, sim_table, codes + saved_j[kk] * pq.code_size);
1127
+ pq.M,
1128
+ pq.nbits,
1129
+ sim_table,
1130
+ codes + saved_j[kk] * pq.code_size);
1172
1131
 
1173
1132
  res.add(saved_j[kk], dis);
1174
1133
  }
@@ -1185,7 +1144,10 @@ struct IVFPQScannerT : QueryTables {
1185
1144
 
1186
1145
  float dis = dis0 +
1187
1146
  distance_single_code<PQDecoder>(
1188
- pq, sim_table, codes + j * code_size);
1147
+ pq.M,
1148
+ pq.nbits,
1149
+ sim_table,
1150
+ codes + j * code_size);
1189
1151
 
1190
1152
  res.add(j, dis);
1191
1153
  }
@@ -1195,30 +1157,23 @@ struct IVFPQScannerT : QueryTables {
1195
1157
  { indexIVFPQ_stats.n_hamming_pass += n_hamming_pass; }
1196
1158
  }
1197
1159
 
1160
+ template <class SearchResultType>
1161
+ struct Run_scan_list_polysemous_hc {
1162
+ using T = void;
1163
+ template <class HammingComputer, class... Types>
1164
+ void f(const IVFPQScannerT* scanner, Types... args) {
1165
+ scanner->scan_list_polysemous_hc<HammingComputer, SearchResultType>(
1166
+ args...);
1167
+ }
1168
+ };
1169
+
1198
1170
  template <class SearchResultType>
1199
1171
  void scan_list_polysemous(
1200
1172
  size_t ncode,
1201
1173
  const uint8_t* codes,
1202
1174
  SearchResultType& res) const {
1203
- switch (pq.code_size) {
1204
- #define HANDLE_CODE_SIZE(cs) \
1205
- case cs: \
1206
- scan_list_polysemous_hc<HammingComputer##cs, SearchResultType>( \
1207
- ncode, codes, res); \
1208
- break
1209
- HANDLE_CODE_SIZE(4);
1210
- HANDLE_CODE_SIZE(8);
1211
- HANDLE_CODE_SIZE(16);
1212
- HANDLE_CODE_SIZE(20);
1213
- HANDLE_CODE_SIZE(32);
1214
- HANDLE_CODE_SIZE(64);
1215
- #undef HANDLE_CODE_SIZE
1216
- default:
1217
- scan_list_polysemous_hc<
1218
- HammingComputerDefault,
1219
- SearchResultType>(ncode, codes, res);
1220
- break;
1221
- }
1175
+ Run_scan_list_polysemous_hc<SearchResultType> r;
1176
+ dispatch_HammingComputer(pq.code_size, r, this, ncode, codes, res);
1222
1177
  }
1223
1178
  };
1224
1179
 
@@ -1248,6 +1203,7 @@ struct IVFPQScanner : IVFPQScannerT<idx_t, METRIC_TYPE, PQDecoder>,
1248
1203
  precompute_mode(precompute_mode),
1249
1204
  sel(sel) {
1250
1205
  this->store_pairs = store_pairs;
1206
+ this->keep_max = is_similarity_metric(METRIC_TYPE);
1251
1207
  }
1252
1208
 
1253
1209
  void set_query(const float* query) override {
@@ -1263,7 +1219,7 @@ struct IVFPQScanner : IVFPQScannerT<idx_t, METRIC_TYPE, PQDecoder>,
1263
1219
  assert(precompute_mode == 2);
1264
1220
  float dis = this->dis0 +
1265
1221
  distance_single_code<PQDecoder>(
1266
- this->pq, this->sim_table, code);
1222
+ this->pq.M, this->pq.nbits, this->sim_table, code);
1267
1223
  return dis;
1268
1224
  }
1269
1225
 
@@ -32,8 +32,6 @@ FAISS_API extern size_t precomputed_table_max_bytes;
32
32
  * vector is encoded as a product quantizer code.
33
33
  */
34
34
  struct IndexIVFPQ : IndexIVF {
35
- bool by_residual; ///< Encode residual or plain vector?
36
-
37
35
  ProductQuantizer pq; ///< produces the codes
38
36
 
39
37
  bool do_polysemous_training; ///< reorder PQ centroids after training?
@@ -73,7 +71,8 @@ struct IndexIVFPQ : IndexIVF {
73
71
  idx_t n,
74
72
  const float* x,
75
73
  const idx_t* xids,
76
- const idx_t* precomputed_idx) override;
74
+ const idx_t* precomputed_idx,
75
+ void* inverted_list_context = nullptr) override;
77
76
 
78
77
  /// same as add_core, also:
79
78
  /// - output 2nd level residuals if residuals_2 != NULL
@@ -83,13 +82,13 @@ struct IndexIVFPQ : IndexIVF {
83
82
  const float* x,
84
83
  const idx_t* xids,
85
84
  float* residuals_2,
86
- const idx_t* precomputed_idx = nullptr);
85
+ const idx_t* precomputed_idx = nullptr,
86
+ void* inverted_list_context = nullptr);
87
87
 
88
88
  /// trains the product quantizer
89
- void train_residual(idx_t n, const float* x) override;
89
+ void train_encoder(idx_t n, const float* x, const idx_t* assign) override;
90
90
 
91
- /// same as train_residual, also output 2nd level residuals
92
- void train_residual_o(idx_t n, const float* x, float* residuals_2);
91
+ idx_t train_encoder_num_vectors() const override;
93
92
 
94
93
  void reconstruct_from_offset(int64_t list_no, int64_t offset, float* recons)
95
94
  const override;
@@ -44,7 +44,7 @@ IndexIVFPQFastScan::IndexIVFPQFastScan(
44
44
  MetricType metric,
45
45
  int bbs)
46
46
  : IndexIVFFastScan(quantizer, d, nlist, 0, metric), pq(d, M, nbits) {
47
- by_residual = false; // set to false by default because it's much faster
47
+ by_residual = false; // set to false by default because it's faster
48
48
 
49
49
  init_fastscan(M, nbits, nlist, metric, bbs);
50
50
  }
@@ -106,54 +106,22 @@ IndexIVFPQFastScan::IndexIVFPQFastScan(const IndexIVFPQ& orig, int bbs)
106
106
  * Training
107
107
  *********************************************************/
108
108
 
109
- void IndexIVFPQFastScan::train_residual(idx_t n, const float* x_in) {
110
- const float* x = fvecs_maybe_subsample(
111
- d,
112
- (size_t*)&n,
113
- pq.cp.max_points_per_centroid * pq.ksub,
114
- x_in,
115
- verbose,
116
- pq.cp.seed);
117
-
118
- std::unique_ptr<float[]> del_x;
119
- if (x != x_in) {
120
- del_x.reset((float*)x);
121
- }
122
-
123
- const float* trainset;
124
- AlignedTable<float> residuals;
125
-
126
- if (by_residual) {
127
- if (verbose)
128
- printf("computing residuals\n");
129
- std::vector<idx_t> assign(n);
130
- quantizer->assign(n, x, assign.data());
131
- residuals.resize(n * d);
132
- for (idx_t i = 0; i < n; i++) {
133
- quantizer->compute_residual(
134
- x + i * d, residuals.data() + i * d, assign[i]);
135
- }
136
- trainset = residuals.data();
137
- } else {
138
- trainset = x;
139
- }
140
-
141
- if (verbose) {
142
- printf("training %zdx%zd product quantizer on "
143
- "%" PRId64 " vectors in %dD\n",
144
- pq.M,
145
- pq.ksub,
146
- n,
147
- d);
148
- }
109
+ void IndexIVFPQFastScan::train_encoder(
110
+ idx_t n,
111
+ const float* x,
112
+ const idx_t* assign) {
149
113
  pq.verbose = verbose;
150
- pq.train(n, trainset);
114
+ pq.train(n, x);
151
115
 
152
116
  if (by_residual && metric_type == METRIC_L2) {
153
117
  precompute_table();
154
118
  }
155
119
  }
156
120
 
121
+ idx_t IndexIVFPQFastScan::train_encoder_num_vectors() const {
122
+ return pq.cp.max_points_per_centroid * pq.ksub;
123
+ }
124
+
157
125
  void IndexIVFPQFastScan::precompute_table() {
158
126
  initialize_IVFPQ_precomputed_table(
159
127
  use_precomputed_table,
@@ -203,7 +171,7 @@ void IndexIVFPQFastScan::encode_vectors(
203
171
  * Look-Up Table functions
204
172
  *********************************************************/
205
173
 
206
- void fvec_madd_avx(
174
+ void fvec_madd_simd(
207
175
  size_t n,
208
176
  const float* a,
209
177
  float bf,
@@ -234,12 +202,12 @@ bool IndexIVFPQFastScan::lookup_table_is_3d() const {
234
202
  void IndexIVFPQFastScan::compute_LUT(
235
203
  size_t n,
236
204
  const float* x,
237
- const idx_t* coarse_ids,
238
- const float* coarse_dis,
205
+ const CoarseQuantized& cq,
239
206
  AlignedTable<float>& dis_tables,
240
207
  AlignedTable<float>& biases) const {
241
208
  size_t dim12 = pq.ksub * pq.M;
242
209
  size_t d = pq.d;
210
+ size_t nprobe = this->nprobe;
243
211
 
244
212
  if (by_residual) {
245
213
  if (metric_type == METRIC_L2) {
@@ -247,7 +215,7 @@ void IndexIVFPQFastScan::compute_LUT(
247
215
 
248
216
  if (use_precomputed_table == 1) {
249
217
  biases.resize(n * nprobe);
250
- memcpy(biases.get(), coarse_dis, sizeof(float) * n * nprobe);
218
+ memcpy(biases.get(), cq.dis, sizeof(float) * n * nprobe);
251
219
 
252
220
  AlignedTable<float> ip_table(n * dim12);
253
221
  pq.compute_inner_prod_tables(n, x, ip_table.get());
@@ -256,10 +224,10 @@ void IndexIVFPQFastScan::compute_LUT(
256
224
  for (idx_t ij = 0; ij < n * nprobe; ij++) {
257
225
  idx_t i = ij / nprobe;
258
226
  float* tab = dis_tables.get() + ij * dim12;
259
- idx_t cij = coarse_ids[ij];
227
+ idx_t cij = cq.ids[ij];
260
228
 
261
229
  if (cij >= 0) {
262
- fvec_madd_avx(
230
+ fvec_madd_simd(
263
231
  dim12,
264
232
  precomputed_table.get() + cij * dim12,
265
233
  -2,
@@ -281,7 +249,7 @@ void IndexIVFPQFastScan::compute_LUT(
281
249
  for (idx_t ij = 0; ij < n * nprobe; ij++) {
282
250
  idx_t i = ij / nprobe;
283
251
  float* xij = &xrel[ij * d];
284
- idx_t cij = coarse_ids[ij];
252
+ idx_t cij = cq.ids[ij];
285
253
 
286
254
  if (cij >= 0) {
287
255
  quantizer->compute_residual(x + i * d, xij, cij);
@@ -301,7 +269,7 @@ void IndexIVFPQFastScan::compute_LUT(
301
269
  // compute_inner_prod_tables(pq, n, x, dis_tables.get());
302
270
 
303
271
  biases.resize(n * nprobe);
304
- memcpy(biases.get(), coarse_dis, sizeof(float) * n * nprobe);
272
+ memcpy(biases.get(), cq.dis, sizeof(float) * n * nprobe);
305
273
  } else {
306
274
  FAISS_THROW_FMT("metric %d not supported", metric_type);
307
275
  }
@@ -54,7 +54,9 @@ struct IndexIVFPQFastScan : IndexIVFFastScan {
54
54
  // built from an IndexIVFPQ
55
55
  explicit IndexIVFPQFastScan(const IndexIVFPQ& orig, int bbs = 32);
56
56
 
57
- void train_residual(idx_t n, const float* x) override;
57
+ void train_encoder(idx_t n, const float* x, const idx_t* assign) override;
58
+
59
+ idx_t train_encoder_num_vectors() const override;
58
60
 
59
61
  /// build precomputed table, possibly updating use_precomputed_table
60
62
  void precompute_table();
@@ -75,8 +77,7 @@ struct IndexIVFPQFastScan : IndexIVFFastScan {
75
77
  void compute_LUT(
76
78
  size_t n,
77
79
  const float* x,
78
- const idx_t* coarse_ids,
79
- const float* coarse_dis,
80
+ const CoarseQuantized& cq,
80
81
  AlignedTable<float>& dis_tables,
81
82
  AlignedTable<float>& biases) const override;
82
83