faiss 0.3.0 → 0.3.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (216)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +9 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +1 -1
  5. data/ext/faiss/extconf.rb +9 -2
  6. data/ext/faiss/index.cpp +1 -1
  7. data/ext/faiss/index_binary.cpp +2 -2
  8. data/ext/faiss/product_quantizer.cpp +1 -1
  9. data/lib/faiss/version.rb +1 -1
  10. data/vendor/faiss/faiss/AutoTune.cpp +7 -7
  11. data/vendor/faiss/faiss/AutoTune.h +1 -2
  12. data/vendor/faiss/faiss/Clustering.cpp +39 -22
  13. data/vendor/faiss/faiss/Clustering.h +40 -21
  14. data/vendor/faiss/faiss/IVFlib.cpp +26 -12
  15. data/vendor/faiss/faiss/Index.cpp +1 -1
  16. data/vendor/faiss/faiss/Index.h +40 -10
  17. data/vendor/faiss/faiss/Index2Layer.cpp +7 -7
  18. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +176 -166
  19. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +15 -15
  20. data/vendor/faiss/faiss/IndexBinary.cpp +9 -4
  21. data/vendor/faiss/faiss/IndexBinary.h +8 -19
  22. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +2 -1
  23. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +24 -31
  24. data/vendor/faiss/faiss/IndexBinaryHNSW.h +1 -1
  25. data/vendor/faiss/faiss/IndexBinaryHash.cpp +25 -50
  26. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +107 -188
  27. data/vendor/faiss/faiss/IndexFastScan.cpp +95 -146
  28. data/vendor/faiss/faiss/IndexFastScan.h +9 -8
  29. data/vendor/faiss/faiss/IndexFlat.cpp +206 -10
  30. data/vendor/faiss/faiss/IndexFlat.h +20 -1
  31. data/vendor/faiss/faiss/IndexFlatCodes.cpp +170 -5
  32. data/vendor/faiss/faiss/IndexFlatCodes.h +23 -4
  33. data/vendor/faiss/faiss/IndexHNSW.cpp +231 -382
  34. data/vendor/faiss/faiss/IndexHNSW.h +62 -49
  35. data/vendor/faiss/faiss/IndexIDMap.cpp +69 -28
  36. data/vendor/faiss/faiss/IndexIDMap.h +24 -2
  37. data/vendor/faiss/faiss/IndexIVF.cpp +162 -56
  38. data/vendor/faiss/faiss/IndexIVF.h +46 -6
  39. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +33 -26
  40. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +6 -2
  41. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +19 -46
  42. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +4 -3
  43. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +502 -401
  44. data/vendor/faiss/faiss/IndexIVFFastScan.h +63 -26
  45. data/vendor/faiss/faiss/IndexIVFFlat.cpp +15 -5
  46. data/vendor/faiss/faiss/IndexIVFFlat.h +3 -2
  47. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +172 -0
  48. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.h +56 -0
  49. data/vendor/faiss/faiss/IndexIVFPQ.cpp +79 -125
  50. data/vendor/faiss/faiss/IndexIVFPQ.h +6 -7
  51. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +39 -52
  52. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +4 -3
  53. data/vendor/faiss/faiss/IndexIVFPQR.cpp +45 -29
  54. data/vendor/faiss/faiss/IndexIVFPQR.h +5 -2
  55. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +25 -27
  56. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +6 -6
  57. data/vendor/faiss/faiss/IndexLSH.cpp +14 -16
  58. data/vendor/faiss/faiss/IndexLattice.cpp +1 -19
  59. data/vendor/faiss/faiss/IndexLattice.h +3 -22
  60. data/vendor/faiss/faiss/IndexNNDescent.cpp +3 -33
  61. data/vendor/faiss/faiss/IndexNNDescent.h +1 -1
  62. data/vendor/faiss/faiss/IndexNSG.cpp +11 -27
  63. data/vendor/faiss/faiss/IndexNSG.h +11 -11
  64. data/vendor/faiss/faiss/IndexNeuralNetCodec.cpp +56 -0
  65. data/vendor/faiss/faiss/IndexNeuralNetCodec.h +49 -0
  66. data/vendor/faiss/faiss/IndexPQ.cpp +72 -88
  67. data/vendor/faiss/faiss/IndexPQ.h +1 -4
  68. data/vendor/faiss/faiss/IndexPQFastScan.cpp +1 -1
  69. data/vendor/faiss/faiss/IndexPreTransform.cpp +25 -31
  70. data/vendor/faiss/faiss/IndexPreTransform.h +1 -1
  71. data/vendor/faiss/faiss/IndexRefine.cpp +54 -24
  72. data/vendor/faiss/faiss/IndexRefine.h +7 -0
  73. data/vendor/faiss/faiss/IndexReplicas.cpp +23 -26
  74. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +25 -17
  75. data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -4
  76. data/vendor/faiss/faiss/IndexShards.cpp +21 -29
  77. data/vendor/faiss/faiss/IndexShardsIVF.cpp +1 -2
  78. data/vendor/faiss/faiss/MatrixStats.cpp +17 -32
  79. data/vendor/faiss/faiss/MatrixStats.h +21 -9
  80. data/vendor/faiss/faiss/MetaIndexes.cpp +35 -35
  81. data/vendor/faiss/faiss/MetricType.h +7 -2
  82. data/vendor/faiss/faiss/VectorTransform.cpp +13 -26
  83. data/vendor/faiss/faiss/VectorTransform.h +7 -7
  84. data/vendor/faiss/faiss/clone_index.cpp +15 -10
  85. data/vendor/faiss/faiss/clone_index.h +3 -0
  86. data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +95 -17
  87. data/vendor/faiss/faiss/cppcontrib/factory_tools.cpp +152 -0
  88. data/vendor/faiss/faiss/cppcontrib/factory_tools.h +24 -0
  89. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +83 -30
  90. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +123 -8
  91. data/vendor/faiss/faiss/gpu/GpuCloner.h +22 -0
  92. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +13 -0
  93. data/vendor/faiss/faiss/gpu/GpuDistance.h +46 -38
  94. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -1
  95. data/vendor/faiss/faiss/gpu/GpuIndex.h +30 -12
  96. data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +282 -0
  97. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +4 -4
  98. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +14 -9
  99. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +20 -3
  100. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -11
  101. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +1 -3
  102. data/vendor/faiss/faiss/gpu/GpuResources.cpp +24 -3
  103. data/vendor/faiss/faiss/gpu/GpuResources.h +39 -11
  104. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +142 -17
  105. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +57 -3
  106. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +26 -21
  107. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +7 -1
  108. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +8 -5
  109. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +25 -0
  110. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +129 -9
  111. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +332 -40
  112. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +299 -208
  113. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +1 -0
  114. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +1 -1
  115. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +6 -0
  116. data/vendor/faiss/faiss/gpu/utils/RaftUtils.h +75 -0
  117. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +4 -1
  118. data/vendor/faiss/faiss/gpu/utils/Timer.h +1 -1
  119. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +3 -1
  120. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +5 -5
  121. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +26 -1
  122. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +10 -3
  123. data/vendor/faiss/faiss/impl/DistanceComputer.h +70 -1
  124. data/vendor/faiss/faiss/impl/FaissAssert.h +4 -2
  125. data/vendor/faiss/faiss/impl/FaissException.h +13 -34
  126. data/vendor/faiss/faiss/impl/HNSW.cpp +605 -186
  127. data/vendor/faiss/faiss/impl/HNSW.h +52 -30
  128. data/vendor/faiss/faiss/impl/IDSelector.h +4 -4
  129. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +11 -9
  130. data/vendor/faiss/faiss/impl/LookupTableScaler.h +34 -0
  131. data/vendor/faiss/faiss/impl/NNDescent.cpp +42 -27
  132. data/vendor/faiss/faiss/impl/NSG.cpp +0 -29
  133. data/vendor/faiss/faiss/impl/NSG.h +1 -1
  134. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +14 -12
  135. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +1 -1
  136. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +25 -22
  137. data/vendor/faiss/faiss/impl/ProductQuantizer.h +6 -2
  138. data/vendor/faiss/faiss/impl/Quantizer.h +1 -1
  139. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +27 -1015
  140. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +5 -63
  141. data/vendor/faiss/faiss/impl/ResultHandler.h +347 -172
  142. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +1104 -147
  143. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +3 -8
  144. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +285 -42
  145. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx512.h +248 -0
  146. data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +21 -14
  147. data/vendor/faiss/faiss/impl/code_distance/code_distance.h +22 -12
  148. data/vendor/faiss/faiss/impl/index_read.cpp +74 -34
  149. data/vendor/faiss/faiss/impl/index_read_utils.h +37 -0
  150. data/vendor/faiss/faiss/impl/index_write.cpp +88 -51
  151. data/vendor/faiss/faiss/impl/io.cpp +23 -15
  152. data/vendor/faiss/faiss/impl/io.h +4 -4
  153. data/vendor/faiss/faiss/impl/io_macros.h +6 -0
  154. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -1
  155. data/vendor/faiss/faiss/impl/platform_macros.h +40 -1
  156. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +14 -0
  157. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +7 -6
  158. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +52 -38
  159. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +487 -49
  160. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +960 -0
  161. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +176 -0
  162. data/vendor/faiss/faiss/impl/simd_result_handlers.h +481 -225
  163. data/vendor/faiss/faiss/index_factory.cpp +41 -20
  164. data/vendor/faiss/faiss/index_io.h +12 -5
  165. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +28 -8
  166. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +3 -0
  167. data/vendor/faiss/faiss/invlists/DirectMap.cpp +10 -2
  168. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +73 -17
  169. data/vendor/faiss/faiss/invlists/InvertedLists.h +26 -8
  170. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +24 -9
  171. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +2 -1
  172. data/vendor/faiss/faiss/python/python_callbacks.cpp +4 -4
  173. data/vendor/faiss/faiss/utils/Heap.cpp +3 -1
  174. data/vendor/faiss/faiss/utils/Heap.h +105 -0
  175. data/vendor/faiss/faiss/utils/NeuralNet.cpp +342 -0
  176. data/vendor/faiss/faiss/utils/NeuralNet.h +147 -0
  177. data/vendor/faiss/faiss/utils/WorkerThread.h +1 -0
  178. data/vendor/faiss/faiss/utils/bf16.h +36 -0
  179. data/vendor/faiss/faiss/utils/distances.cpp +147 -123
  180. data/vendor/faiss/faiss/utils/distances.h +86 -9
  181. data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +5 -5
  182. data/vendor/faiss/faiss/utils/distances_fused/avx512.h +2 -2
  183. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +2 -2
  184. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +1 -1
  185. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +5 -5
  186. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.h +1 -1
  187. data/vendor/faiss/faiss/utils/distances_simd.cpp +1589 -243
  188. data/vendor/faiss/faiss/utils/extra_distances-inl.h +70 -0
  189. data/vendor/faiss/faiss/utils/extra_distances.cpp +85 -137
  190. data/vendor/faiss/faiss/utils/extra_distances.h +3 -2
  191. data/vendor/faiss/faiss/utils/fp16-arm.h +29 -0
  192. data/vendor/faiss/faiss/utils/fp16.h +2 -0
  193. data/vendor/faiss/faiss/utils/hamming.cpp +163 -111
  194. data/vendor/faiss/faiss/utils/hamming.h +58 -0
  195. data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +16 -89
  196. data/vendor/faiss/faiss/utils/hamming_distance/common.h +1 -0
  197. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +19 -88
  198. data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +58 -0
  199. data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +14 -104
  200. data/vendor/faiss/faiss/utils/partitioning.cpp +3 -4
  201. data/vendor/faiss/faiss/utils/prefetch.h +77 -0
  202. data/vendor/faiss/faiss/utils/quantize_lut.cpp +0 -14
  203. data/vendor/faiss/faiss/utils/random.cpp +43 -0
  204. data/vendor/faiss/faiss/utils/random.h +25 -0
  205. data/vendor/faiss/faiss/utils/simdlib.h +10 -1
  206. data/vendor/faiss/faiss/utils/simdlib_avx2.h +0 -6
  207. data/vendor/faiss/faiss/utils/simdlib_avx512.h +296 -0
  208. data/vendor/faiss/faiss/utils/simdlib_neon.h +77 -79
  209. data/vendor/faiss/faiss/utils/simdlib_ppc64.h +1084 -0
  210. data/vendor/faiss/faiss/utils/sorting.cpp +140 -5
  211. data/vendor/faiss/faiss/utils/sorting.h +27 -0
  212. data/vendor/faiss/faiss/utils/transpose/transpose-avx512-inl.h +176 -0
  213. data/vendor/faiss/faiss/utils/utils.cpp +120 -7
  214. data/vendor/faiss/faiss/utils/utils.h +60 -20
  215. metadata +23 -4
  216. data/vendor/faiss/faiss/impl/code_distance/code_distance_avx512.h +0 -102
@@ -43,6 +43,8 @@ IndexIVFFastScan::IndexIVFFastScan(
43
43
  size_t code_size,
44
44
  MetricType metric)
45
45
  : IndexIVF(quantizer, d, nlist, code_size, metric) {
46
+ // unlike other indexes, we prefer no residuals for performance reasons.
47
+ by_residual = false;
46
48
  FAISS_THROW_IF_NOT(metric == METRIC_L2 || metric == METRIC_INNER_PRODUCT);
47
49
  }
48
50
 
@@ -50,6 +52,7 @@ IndexIVFFastScan::IndexIVFFastScan() {
50
52
  bbs = 0;
51
53
  M2 = 0;
52
54
  is_trained = false;
55
+ by_residual = false;
53
56
  }
54
57
 
55
58
  void IndexIVFFastScan::init_fastscan(
@@ -79,7 +82,7 @@ void IndexIVFFastScan::init_code_packer() {
79
82
  bil->packer = get_CodePacker();
80
83
  }
81
84
 
82
- IndexIVFFastScan::~IndexIVFFastScan() {}
85
+ IndexIVFFastScan::~IndexIVFFastScan() = default;
83
86
 
84
87
  /*********************************************************
85
88
  * Code management functions
@@ -195,7 +198,7 @@ CodePacker* IndexIVFFastScan::get_CodePacker() const {
195
198
 
196
199
  namespace {
197
200
 
198
- template <class C, typename dis_t, class Scaler>
201
+ template <class C, typename dis_t>
199
202
  void estimators_from_tables_generic(
200
203
  const IndexIVFFastScan& index,
201
204
  const uint8_t* codes,
@@ -206,22 +209,26 @@ void estimators_from_tables_generic(
206
209
  size_t k,
207
210
  typename C::T* heap_dis,
208
211
  int64_t* heap_ids,
209
- const Scaler& scaler) {
212
+ const NormTableScaler* scaler) {
210
213
  using accu_t = typename C::T;
214
+ size_t nscale = scaler ? scaler->nscale : 0;
211
215
  for (size_t j = 0; j < ncodes; ++j) {
212
216
  BitstringReader bsr(codes + j * index.code_size, index.code_size);
213
217
  accu_t dis = bias;
214
218
  const dis_t* __restrict dt = dis_table;
215
- for (size_t m = 0; m < index.M - scaler.nscale; m++) {
219
+
220
+ for (size_t m = 0; m < index.M - nscale; m++) {
216
221
  uint64_t c = bsr.read(index.nbits);
217
222
  dis += dt[c];
218
223
  dt += index.ksub;
219
224
  }
220
225
 
221
- for (size_t m = 0; m < scaler.nscale; m++) {
222
- uint64_t c = bsr.read(index.nbits);
223
- dis += scaler.scale_one(dt[c]);
224
- dt += index.ksub;
226
+ if (scaler) {
227
+ for (size_t m = 0; m < nscale; m++) {
228
+ uint64_t c = bsr.read(index.nbits);
229
+ dis += scaler->scale_one(dt[c]);
230
+ dt += index.ksub;
231
+ }
225
232
  }
226
233
 
227
234
  if (C::cmp(heap_dis[0], dis)) {
@@ -242,18 +249,15 @@ using namespace quantize_lut;
242
249
  void IndexIVFFastScan::compute_LUT_uint8(
243
250
  size_t n,
244
251
  const float* x,
245
- const idx_t* coarse_ids,
246
- const float* coarse_dis,
252
+ const CoarseQuantized& cq,
247
253
  AlignedTable<uint8_t>& dis_tables,
248
254
  AlignedTable<uint16_t>& biases,
249
255
  float* normalizers) const {
250
256
  AlignedTable<float> dis_tables_float;
251
257
  AlignedTable<float> biases_float;
252
258
 
253
- uint64_t t0 = get_cy();
254
- compute_LUT(n, x, coarse_ids, coarse_dis, dis_tables_float, biases_float);
255
- IVFFastScan_stats.t_compute_distance_tables += get_cy() - t0;
256
-
259
+ compute_LUT(n, x, cq, dis_tables_float, biases_float);
260
+ size_t nprobe = cq.nprobe;
257
261
  bool lut_is_3d = lookup_table_is_3d();
258
262
  size_t dim123 = ksub * M;
259
263
  size_t dim123_2 = ksub * M2;
@@ -265,8 +269,8 @@ void IndexIVFFastScan::compute_LUT_uint8(
265
269
  if (biases_float.get()) {
266
270
  biases.resize(n * nprobe);
267
271
  }
268
- uint64_t t1 = get_cy();
269
272
 
273
+ // OMP for MSVC requires i to have signed integral type
270
274
  #pragma omp parallel for if (n > 100)
271
275
  for (int64_t i = 0; i < n; i++) {
272
276
  const float* t_in = dis_tables_float.get() + i * dim123;
@@ -291,7 +295,6 @@ void IndexIVFFastScan::compute_LUT_uint8(
291
295
  normalizers + 2 * i,
292
296
  normalizers + 2 * i + 1);
293
297
  }
294
- IVFFastScan_stats.t_round += get_cy() - t1;
295
298
  }
296
299
 
297
300
  /*********************************************************
@@ -304,45 +307,195 @@ void IndexIVFFastScan::search(
304
307
  idx_t k,
305
308
  float* distances,
306
309
  idx_t* labels,
307
- const SearchParameters* params) const {
310
+ const SearchParameters* params_in) const {
311
+ const IVFSearchParameters* params = nullptr;
312
+ if (params_in) {
313
+ params = dynamic_cast<const IVFSearchParameters*>(params_in);
314
+ FAISS_THROW_IF_NOT_MSG(
315
+ params, "IndexIVFFastScan params have incorrect type");
316
+ }
317
+
318
+ search_preassigned(
319
+ n, x, k, nullptr, nullptr, distances, labels, false, params);
320
+ }
321
+
322
+ void IndexIVFFastScan::search_preassigned(
323
+ idx_t n,
324
+ const float* x,
325
+ idx_t k,
326
+ const idx_t* assign,
327
+ const float* centroid_dis,
328
+ float* distances,
329
+ idx_t* labels,
330
+ bool store_pairs,
331
+ const IVFSearchParameters* params,
332
+ IndexIVFStats* stats) const {
333
+ size_t nprobe = this->nprobe;
334
+ if (params) {
335
+ FAISS_THROW_IF_NOT(params->max_codes == 0);
336
+ nprobe = params->nprobe;
337
+ }
338
+
308
339
  FAISS_THROW_IF_NOT_MSG(
309
- !params, "search params not supported for this index");
340
+ !store_pairs, "store_pairs not supported for this index");
341
+ FAISS_THROW_IF_NOT_MSG(!stats, "stats not supported for this index");
310
342
  FAISS_THROW_IF_NOT(k > 0);
311
343
 
312
- DummyScaler scaler;
313
- if (metric_type == METRIC_L2) {
314
- search_dispatch_implem<true>(n, x, k, distances, labels, scaler);
344
+ const CoarseQuantized cq = {nprobe, centroid_dis, assign};
345
+ search_dispatch_implem(n, x, k, distances, labels, cq, nullptr, params);
346
+ }
347
+
348
+ void IndexIVFFastScan::range_search(
349
+ idx_t n,
350
+ const float* x,
351
+ float radius,
352
+ RangeSearchResult* result,
353
+ const SearchParameters* params_in) const {
354
+ size_t nprobe = this->nprobe;
355
+ const IVFSearchParameters* params = nullptr;
356
+ if (params_in) {
357
+ params = dynamic_cast<const IVFSearchParameters*>(params_in);
358
+ FAISS_THROW_IF_NOT_MSG(
359
+ params, "IndexIVFFastScan params have incorrect type");
360
+ nprobe = params->nprobe;
361
+ }
362
+
363
+ const CoarseQuantized cq = {nprobe, nullptr, nullptr};
364
+ range_search_dispatch_implem(n, x, radius, *result, cq, nullptr, params);
365
+ }
366
+
367
+ namespace {
368
+
369
+ template <class C>
370
+ ResultHandlerCompare<C, true>* make_knn_handler_fixC(
371
+ int impl,
372
+ idx_t n,
373
+ idx_t k,
374
+ float* distances,
375
+ idx_t* labels,
376
+ const IDSelector* sel) {
377
+ using HeapHC = HeapHandler<C, true>;
378
+ using ReservoirHC = ReservoirHandler<C, true>;
379
+ using SingleResultHC = SingleResultHandler<C, true>;
380
+
381
+ if (k == 1) {
382
+ return new SingleResultHC(n, 0, distances, labels, sel);
383
+ } else if (impl % 2 == 0) {
384
+ return new HeapHC(n, 0, k, distances, labels, sel);
385
+ } else /* if (impl % 2 == 1) */ {
386
+ return new ReservoirHC(n, 0, k, 2 * k, distances, labels, sel);
387
+ }
388
+ }
389
+
390
+ SIMDResultHandlerToFloat* make_knn_handler(
391
+ bool is_max,
392
+ int impl,
393
+ idx_t n,
394
+ idx_t k,
395
+ float* distances,
396
+ idx_t* labels,
397
+ const IDSelector* sel) {
398
+ if (is_max) {
399
+ return make_knn_handler_fixC<CMax<uint16_t, int64_t>>(
400
+ impl, n, k, distances, labels, sel);
315
401
  } else {
316
- search_dispatch_implem<false>(n, x, k, distances, labels, scaler);
402
+ return make_knn_handler_fixC<CMin<uint16_t, int64_t>>(
403
+ impl, n, k, distances, labels, sel);
317
404
  }
318
405
  }
319
406
 
320
- void IndexIVFFastScan::range_search(
321
- idx_t,
322
- const float*,
323
- float,
324
- RangeSearchResult*,
325
- const SearchParameters*) const {
326
- FAISS_THROW_MSG("not implemented");
407
+ using CoarseQuantized = IndexIVFFastScan::CoarseQuantized;
408
+
409
+ struct CoarseQuantizedWithBuffer : CoarseQuantized {
410
+ explicit CoarseQuantizedWithBuffer(const CoarseQuantized& cq)
411
+ : CoarseQuantized(cq) {}
412
+
413
+ bool done() const {
414
+ return ids != nullptr;
415
+ }
416
+
417
+ std::vector<idx_t> ids_buffer;
418
+ std::vector<float> dis_buffer;
419
+
420
+ void quantize(
421
+ const Index* quantizer,
422
+ idx_t n,
423
+ const float* x,
424
+ const SearchParameters* quantizer_params) {
425
+ dis_buffer.resize(nprobe * n);
426
+ ids_buffer.resize(nprobe * n);
427
+ quantizer->search(
428
+ n,
429
+ x,
430
+ nprobe,
431
+ dis_buffer.data(),
432
+ ids_buffer.data(),
433
+ quantizer_params);
434
+ dis = dis_buffer.data();
435
+ ids = ids_buffer.data();
436
+ }
437
+ };
438
+
439
+ struct CoarseQuantizedSlice : CoarseQuantizedWithBuffer {
440
+ size_t i0, i1;
441
+ CoarseQuantizedSlice(const CoarseQuantized& cq, size_t i0, size_t i1)
442
+ : CoarseQuantizedWithBuffer(cq), i0(i0), i1(i1) {
443
+ if (done()) {
444
+ dis += nprobe * i0;
445
+ ids += nprobe * i0;
446
+ }
447
+ }
448
+
449
+ void quantize_slice(
450
+ const Index* quantizer,
451
+ const float* x,
452
+ const SearchParameters* quantizer_params) {
453
+ quantize(quantizer, i1 - i0, x + quantizer->d * i0, quantizer_params);
454
+ }
455
+ };
456
+
457
+ int compute_search_nslice(
458
+ const IndexIVFFastScan* index,
459
+ size_t n,
460
+ size_t nprobe) {
461
+ int nslice;
462
+ if (n <= omp_get_max_threads()) {
463
+ nslice = n;
464
+ } else if (index->lookup_table_is_3d()) {
465
+ // make sure we don't make too big LUT tables
466
+ size_t lut_size_per_query = index->M * index->ksub * nprobe *
467
+ (sizeof(float) + sizeof(uint8_t));
468
+
469
+ size_t max_lut_size = precomputed_table_max_bytes;
470
+ // how many queries we can handle within mem budget
471
+ size_t nq_ok = std::max(max_lut_size / lut_size_per_query, size_t(1));
472
+ nslice = roundup(
473
+ std::max(size_t(n / nq_ok), size_t(1)), omp_get_max_threads());
474
+ } else {
475
+ // LUTs unlikely to be a limiting factor
476
+ nslice = omp_get_max_threads();
477
+ }
478
+ return nslice;
327
479
  }
328
480
 
329
- template <bool is_max, class Scaler>
481
+ } // namespace
482
+
330
483
  void IndexIVFFastScan::search_dispatch_implem(
331
484
  idx_t n,
332
485
  const float* x,
333
486
  idx_t k,
334
487
  float* distances,
335
488
  idx_t* labels,
336
- const Scaler& scaler) const {
337
- using Cfloat = typename std::conditional<
338
- is_max,
339
- CMax<float, int64_t>,
340
- CMin<float, int64_t>>::type;
489
+ const CoarseQuantized& cq_in,
490
+ const NormTableScaler* scaler,
491
+ const IVFSearchParameters* params) const {
492
+ const idx_t nprobe = params ? params->nprobe : this->nprobe;
493
+ const IDSelector* sel = (params) ? params->sel : nullptr;
494
+ const SearchParameters* quantizer_params =
495
+ params ? params->quantizer_params : nullptr;
341
496
 
342
- using C = typename std::conditional<
343
- is_max,
344
- CMax<uint16_t, int64_t>,
345
- CMin<uint16_t, int64_t>>::type;
497
+ bool is_max = !is_similarity_metric(metric_type);
498
+ using RH = SIMDResultHandlerToFloat;
346
499
 
347
500
  if (n == 0) {
348
501
  return;
@@ -357,70 +510,93 @@ void IndexIVFFastScan::search_dispatch_implem(
357
510
  } else {
358
511
  impl = 10;
359
512
  }
360
- if (k > 20) {
513
+ if (k > 20) { // use reservoir rather than heap
361
514
  impl++;
362
515
  }
363
516
  }
364
517
 
518
+ bool multiple_threads =
519
+ n > 1 && impl >= 10 && impl <= 13 && omp_get_max_threads() > 1;
520
+ if (impl >= 100) {
521
+ multiple_threads = false;
522
+ impl -= 100;
523
+ }
524
+
525
+ CoarseQuantizedWithBuffer cq(cq_in);
526
+ cq.nprobe = nprobe;
527
+
528
+ if (!cq.done() && !multiple_threads) {
529
+ // we do the coarse quantization here execpt when search is
530
+ // sliced over threads (then it is more efficient to have each thread do
531
+ // its own coarse quantization)
532
+ cq.quantize(quantizer, n, x, quantizer_params);
533
+ invlists->prefetch_lists(cq.ids, n * cq.nprobe);
534
+ }
535
+
365
536
  if (impl == 1) {
366
- search_implem_1<Cfloat>(n, x, k, distances, labels, scaler);
537
+ if (is_max) {
538
+ search_implem_1<CMax<float, int64_t>>(
539
+ n, x, k, distances, labels, cq, scaler, params);
540
+ } else {
541
+ search_implem_1<CMin<float, int64_t>>(
542
+ n, x, k, distances, labels, cq, scaler, params);
543
+ }
367
544
  } else if (impl == 2) {
368
- search_implem_2<C>(n, x, k, distances, labels, scaler);
369
-
545
+ if (is_max) {
546
+ search_implem_2<CMax<uint16_t, int64_t>>(
547
+ n, x, k, distances, labels, cq, scaler, params);
548
+ } else {
549
+ search_implem_2<CMin<uint16_t, int64_t>>(
550
+ n, x, k, distances, labels, cq, scaler, params);
551
+ }
370
552
  } else if (impl >= 10 && impl <= 15) {
371
553
  size_t ndis = 0, nlist_visited = 0;
372
554
 
373
- if (n < 2) {
555
+ if (!multiple_threads) {
556
+ // clang-format off
374
557
  if (impl == 12 || impl == 13) {
375
- search_implem_12<C>(
376
- n,
377
- x,
378
- k,
379
- distances,
380
- labels,
381
- impl,
382
- &ndis,
383
- &nlist_visited,
384
- scaler);
558
+ std::unique_ptr<RH> handler(
559
+ make_knn_handler(
560
+ is_max,
561
+ impl,
562
+ n,
563
+ k,
564
+ distances,
565
+ labels, sel
566
+ )
567
+ );
568
+ search_implem_12(
569
+ n, x, *handler.get(),
570
+ cq, &ndis, &nlist_visited, scaler, params);
385
571
  } else if (impl == 14 || impl == 15) {
386
- search_implem_14<C>(n, x, k, distances, labels, impl, scaler);
572
+ search_implem_14(
573
+ n, x, k, distances, labels,
574
+ cq, impl, scaler, params);
387
575
  } else {
388
- search_implem_10<C>(
389
- n,
390
- x,
391
- k,
392
- distances,
576
+ std::unique_ptr<RH> handler(
577
+ make_knn_handler(
578
+ is_max,
579
+ impl,
580
+ n,
581
+ k,
582
+ distances,
393
583
  labels,
394
- impl,
395
- &ndis,
396
- &nlist_visited,
397
- scaler);
584
+ sel
585
+ )
586
+ );
587
+ search_implem_10(
588
+ n, x, *handler.get(), cq,
589
+ &ndis, &nlist_visited, scaler, params);
398
590
  }
591
+ // clang-format on
399
592
  } else {
400
593
  // explicitly slice over threads
401
- int nslice;
402
- if (n <= omp_get_max_threads()) {
403
- nslice = n;
404
- } else if (lookup_table_is_3d()) {
405
- // make sure we don't make too big LUT tables
406
- size_t lut_size_per_query =
407
- M * ksub * nprobe * (sizeof(float) + sizeof(uint8_t));
408
-
409
- size_t max_lut_size = precomputed_table_max_bytes;
410
- // how many queries we can handle within mem budget
411
- size_t nq_ok =
412
- std::max(max_lut_size / lut_size_per_query, size_t(1));
413
- nslice =
414
- roundup(std::max(size_t(n / nq_ok), size_t(1)),
415
- omp_get_max_threads());
416
- } else {
417
- // LUTs unlikely to be a limiting factor
418
- nslice = omp_get_max_threads();
419
- }
420
- if (impl == 14 ||
421
- impl == 15) { // this might require slicing if there are too
422
- // many queries (for now we keep this simple)
423
- search_implem_14<C>(n, x, k, distances, labels, impl, scaler);
594
+ int nslice = compute_search_nslice(this, n, cq.nprobe);
595
+ if (impl == 14 || impl == 15) {
596
+ // this might require slicing if there are too
597
+ // many queries (for now we keep this simple)
598
+ search_implem_14(
599
+ n, x, k, distances, labels, cq, impl, scaler, params);
424
600
  } else {
425
601
  #pragma omp parallel for reduction(+ : ndis, nlist_visited)
426
602
  for (int slice = 0; slice < nslice; slice++) {
@@ -428,29 +604,23 @@ void IndexIVFFastScan::search_dispatch_implem(
428
604
  idx_t i1 = n * (slice + 1) / nslice;
429
605
  float* dis_i = distances + i0 * k;
430
606
  idx_t* lab_i = labels + i0 * k;
607
+ CoarseQuantizedSlice cq_i(cq, i0, i1);
608
+ if (!cq_i.done()) {
609
+ cq_i.quantize_slice(quantizer, x, quantizer_params);
610
+ }
611
+ std::unique_ptr<RH> handler(make_knn_handler(
612
+ is_max, impl, i1 - i0, k, dis_i, lab_i, sel));
613
+ // clang-format off
431
614
  if (impl == 12 || impl == 13) {
432
- search_implem_12<C>(
433
- i1 - i0,
434
- x + i0 * d,
435
- k,
436
- dis_i,
437
- lab_i,
438
- impl,
439
- &ndis,
440
- &nlist_visited,
441
- scaler);
615
+ search_implem_12(
616
+ i1 - i0, x + i0 * d, *handler.get(),
617
+ cq_i, &ndis, &nlist_visited, scaler, params);
442
618
  } else {
443
- search_implem_10<C>(
444
- i1 - i0,
445
- x + i0 * d,
446
- k,
447
- dis_i,
448
- lab_i,
449
- impl,
450
- &ndis,
451
- &nlist_visited,
452
- scaler);
619
+ search_implem_10(
620
+ i1 - i0, x + i0 * d, *handler.get(),
621
+ cq_i, &ndis, &nlist_visited, scaler, params);
453
622
  }
623
+ // clang-format on
454
624
  }
455
625
  }
456
626
  }
@@ -462,31 +632,149 @@ void IndexIVFFastScan::search_dispatch_implem(
462
632
  }
463
633
  }
464
634
 
465
- template <class C, class Scaler>
635
+ void IndexIVFFastScan::range_search_dispatch_implem(
636
+ idx_t n,
637
+ const float* x,
638
+ float radius,
639
+ RangeSearchResult& rres,
640
+ const CoarseQuantized& cq_in,
641
+ const NormTableScaler* scaler,
642
+ const IVFSearchParameters* params) const {
643
+ // const idx_t nprobe = params ? params->nprobe : this->nprobe;
644
+ const IDSelector* sel = (params) ? params->sel : nullptr;
645
+ const SearchParameters* quantizer_params =
646
+ params ? params->quantizer_params : nullptr;
647
+
648
+ bool is_max = !is_similarity_metric(metric_type);
649
+
650
+ if (n == 0) {
651
+ return;
652
+ }
653
+
654
+ // actual implementation used
655
+ int impl = implem;
656
+
657
+ if (impl == 0) {
658
+ if (bbs == 32) {
659
+ impl = 12;
660
+ } else {
661
+ impl = 10;
662
+ }
663
+ }
664
+
665
+ CoarseQuantizedWithBuffer cq(cq_in);
666
+
667
+ bool multiple_threads =
668
+ n > 1 && impl >= 10 && impl <= 13 && omp_get_max_threads() > 1;
669
+ if (impl >= 100) {
670
+ multiple_threads = false;
671
+ impl -= 100;
672
+ }
673
+
674
+ if (!multiple_threads && !cq.done()) {
675
+ cq.quantize(quantizer, n, x, quantizer_params);
676
+ invlists->prefetch_lists(cq.ids, n * cq.nprobe);
677
+ }
678
+
679
+ size_t ndis = 0, nlist_visited = 0;
680
+
681
+ if (!multiple_threads) { // single thread
682
+ std::unique_ptr<SIMDResultHandlerToFloat> handler;
683
+ if (is_max) {
684
+ handler.reset(new RangeHandler<CMax<uint16_t, int64_t>, true>(
685
+ rres, radius, 0, sel));
686
+ } else {
687
+ handler.reset(new RangeHandler<CMin<uint16_t, int64_t>, true>(
688
+ rres, radius, 0, sel));
689
+ }
690
+ if (impl == 12) {
691
+ search_implem_12(
692
+ n, x, *handler.get(), cq, &ndis, &nlist_visited, scaler);
693
+ } else if (impl == 10) {
694
+ search_implem_10(
695
+ n, x, *handler.get(), cq, &ndis, &nlist_visited, scaler);
696
+ } else {
697
+ FAISS_THROW_FMT("Range search implem %d not implemented", impl);
698
+ }
699
+ } else {
700
+ // explicitly slice over threads
701
+ int nslice = compute_search_nslice(this, n, cq.nprobe);
702
+ #pragma omp parallel
703
+ {
704
+ RangeSearchPartialResult pres(&rres);
705
+
706
+ #pragma omp for reduction(+ : ndis, nlist_visited)
707
+ for (int slice = 0; slice < nslice; slice++) {
708
+ idx_t i0 = n * slice / nslice;
709
+ idx_t i1 = n * (slice + 1) / nslice;
710
+ CoarseQuantizedSlice cq_i(cq, i0, i1);
711
+ if (!cq_i.done()) {
712
+ cq_i.quantize_slice(quantizer, x, quantizer_params);
713
+ }
714
+ std::unique_ptr<SIMDResultHandlerToFloat> handler;
715
+ if (is_max) {
716
+ handler.reset(new PartialRangeHandler<
717
+ CMax<uint16_t, int64_t>,
718
+ true>(pres, radius, 0, i0, i1, sel));
719
+ } else {
720
+ handler.reset(new PartialRangeHandler<
721
+ CMin<uint16_t, int64_t>,
722
+ true>(pres, radius, 0, i0, i1, sel));
723
+ }
724
+
725
+ if (impl == 12 || impl == 13) {
726
+ search_implem_12(
727
+ i1 - i0,
728
+ x + i0 * d,
729
+ *handler.get(),
730
+ cq_i,
731
+ &ndis,
732
+ &nlist_visited,
733
+ scaler,
734
+ params);
735
+ } else {
736
+ search_implem_10(
737
+ i1 - i0,
738
+ x + i0 * d,
739
+ *handler.get(),
740
+ cq_i,
741
+ &ndis,
742
+ &nlist_visited,
743
+ scaler,
744
+ params);
745
+ }
746
+ }
747
+ pres.finalize();
748
+ }
749
+ }
750
+
751
+ indexIVF_stats.nq += n;
752
+ indexIVF_stats.ndis += ndis;
753
+ indexIVF_stats.nlist += nlist_visited;
754
+ }
755
+
756
+ template <class C>
466
757
  void IndexIVFFastScan::search_implem_1(
467
758
  idx_t n,
468
759
  const float* x,
469
760
  idx_t k,
470
761
  float* distances,
471
762
  idx_t* labels,
472
- const Scaler& scaler) const {
763
+ const CoarseQuantized& cq,
764
+ const NormTableScaler* scaler,
765
+ const IVFSearchParameters* params) const {
473
766
  FAISS_THROW_IF_NOT(orig_invlists);
474
767
 
475
- std::unique_ptr<idx_t[]> coarse_ids(new idx_t[n * nprobe]);
476
- std::unique_ptr<float[]> coarse_dis(new float[n * nprobe]);
477
-
478
- quantizer->search(n, x, nprobe, coarse_dis.get(), coarse_ids.get());
479
-
480
768
  size_t dim12 = ksub * M;
481
769
  AlignedTable<float> dis_tables;
482
770
  AlignedTable<float> biases;
483
771
 
484
- compute_LUT(n, x, coarse_ids.get(), coarse_dis.get(), dis_tables, biases);
772
+ compute_LUT(n, x, cq, dis_tables, biases);
485
773
 
486
774
  bool single_LUT = !lookup_table_is_3d();
487
775
 
488
776
  size_t ndis = 0, nlist_visited = 0;
489
-
777
+ size_t nprobe = cq.nprobe;
490
778
  #pragma omp parallel for reduction(+ : ndis, nlist_visited)
491
779
  for (idx_t i = 0; i < n; i++) {
492
780
  int64_t* heap_ids = labels + i * k;
@@ -501,7 +789,7 @@ void IndexIVFFastScan::search_implem_1(
501
789
  if (!single_LUT) {
502
790
  LUT = dis_tables.get() + (i * nprobe + j) * dim12;
503
791
  }
504
- idx_t list_no = coarse_ids[i * nprobe + j];
792
+ idx_t list_no = cq.ids[i * nprobe + j];
505
793
  if (list_no < 0)
506
794
  continue;
507
795
  size_t ls = orig_invlists->list_size(list_no);
@@ -533,38 +821,29 @@ void IndexIVFFastScan::search_implem_1(
533
821
  indexIVF_stats.nlist += nlist_visited;
534
822
  }
535
823
 
536
- template <class C, class Scaler>
824
+ template <class C>
537
825
  void IndexIVFFastScan::search_implem_2(
538
826
  idx_t n,
539
827
  const float* x,
540
828
  idx_t k,
541
829
  float* distances,
542
830
  idx_t* labels,
543
- const Scaler& scaler) const {
831
+ const CoarseQuantized& cq,
832
+ const NormTableScaler* scaler,
833
+ const IVFSearchParameters* params) const {
544
834
  FAISS_THROW_IF_NOT(orig_invlists);
545
835
 
546
- std::unique_ptr<idx_t[]> coarse_ids(new idx_t[n * nprobe]);
547
- std::unique_ptr<float[]> coarse_dis(new float[n * nprobe]);
548
-
549
- quantizer->search(n, x, nprobe, coarse_dis.get(), coarse_ids.get());
550
-
551
836
  size_t dim12 = ksub * M2;
552
837
  AlignedTable<uint8_t> dis_tables;
553
838
  AlignedTable<uint16_t> biases;
554
839
  std::unique_ptr<float[]> normalizers(new float[2 * n]);
555
840
 
556
- compute_LUT_uint8(
557
- n,
558
- x,
559
- coarse_ids.get(),
560
- coarse_dis.get(),
561
- dis_tables,
562
- biases,
563
- normalizers.get());
841
+ compute_LUT_uint8(n, x, cq, dis_tables, biases, normalizers.get());
564
842
 
565
843
  bool single_LUT = !lookup_table_is_3d();
566
844
 
567
845
  size_t ndis = 0, nlist_visited = 0;
846
+ size_t nprobe = cq.nprobe;
568
847
 
569
848
  #pragma omp parallel for reduction(+ : ndis, nlist_visited)
570
849
  for (idx_t i = 0; i < n; i++) {
@@ -581,7 +860,7 @@ void IndexIVFFastScan::search_implem_2(
581
860
  if (!single_LUT) {
582
861
  LUT = dis_tables.get() + (i * nprobe + j) * dim12;
583
862
  }
584
- idx_t list_no = coarse_ids[i * nprobe + j];
863
+ idx_t list_no = cq.ids[i * nprobe + j];
585
864
  if (list_no < 0)
586
865
  continue;
587
866
  size_t ls = orig_invlists->list_size(list_no);
@@ -626,171 +905,103 @@ void IndexIVFFastScan::search_implem_2(
626
905
  indexIVF_stats.nlist += nlist_visited;
627
906
  }
628
907
 
629
- template <class C, class Scaler>
630
908
  void IndexIVFFastScan::search_implem_10(
631
909
  idx_t n,
632
910
  const float* x,
633
- idx_t k,
634
- float* distances,
635
- idx_t* labels,
636
- int impl,
911
+ SIMDResultHandlerToFloat& handler,
912
+ const CoarseQuantized& cq,
637
913
  size_t* ndis_out,
638
914
  size_t* nlist_out,
639
- const Scaler& scaler) const {
640
- memset(distances, -1, sizeof(float) * k * n);
641
- memset(labels, -1, sizeof(idx_t) * k * n);
642
-
643
- using HeapHC = HeapHandler<C, true>;
644
- using ReservoirHC = ReservoirHandler<C, true>;
645
- using SingleResultHC = SingleResultHandler<C, true>;
646
-
647
- std::unique_ptr<idx_t[]> coarse_ids(new idx_t[n * nprobe]);
648
- std::unique_ptr<float[]> coarse_dis(new float[n * nprobe]);
649
-
650
- uint64_t times[10];
651
- memset(times, 0, sizeof(times));
652
- int ti = 0;
653
- #define TIC times[ti++] = get_cy()
654
- TIC;
655
-
656
- quantizer->search(n, x, nprobe, coarse_dis.get(), coarse_ids.get());
657
-
658
- TIC;
659
-
915
+ const NormTableScaler* scaler,
916
+ const IVFSearchParameters* params) const {
660
917
  size_t dim12 = ksub * M2;
661
918
  AlignedTable<uint8_t> dis_tables;
662
919
  AlignedTable<uint16_t> biases;
663
920
  std::unique_ptr<float[]> normalizers(new float[2 * n]);
664
921
 
665
- compute_LUT_uint8(
666
- n,
667
- x,
668
- coarse_ids.get(),
669
- coarse_dis.get(),
670
- dis_tables,
671
- biases,
672
- normalizers.get());
673
-
674
- TIC;
922
+ compute_LUT_uint8(n, x, cq, dis_tables, biases, normalizers.get());
675
923
 
676
924
  bool single_LUT = !lookup_table_is_3d();
677
925
 
678
- TIC;
679
- size_t ndis = 0, nlist_visited = 0;
926
+ size_t ndis = 0;
927
+ int qmap1[1];
680
928
 
681
- {
682
- AlignedTable<uint16_t> tmp_distances(k);
683
- for (idx_t i = 0; i < n; i++) {
684
- const uint8_t* LUT = nullptr;
685
- int qmap1[1] = {0};
686
- std::unique_ptr<SIMDResultHandler<C, true>> handler;
687
-
688
- if (k == 1) {
689
- handler.reset(new SingleResultHC(1, 0));
690
- } else if (impl == 10) {
691
- handler.reset(new HeapHC(
692
- 1, tmp_distances.get(), labels + i * k, k, 0));
693
- } else if (impl == 11) {
694
- handler.reset(new ReservoirHC(1, 0, k, 2 * k));
695
- } else {
696
- FAISS_THROW_MSG("invalid");
697
- }
929
+ handler.q_map = qmap1;
930
+ handler.begin(skip & 16 ? nullptr : normalizers.get());
931
+ size_t nprobe = cq.nprobe;
698
932
 
699
- handler->q_map = qmap1;
933
+ for (idx_t i = 0; i < n; i++) {
934
+ const uint8_t* LUT = nullptr;
935
+ qmap1[0] = i;
700
936
 
701
- if (single_LUT) {
702
- LUT = dis_tables.get() + i * dim12;
937
+ if (single_LUT) {
938
+ LUT = dis_tables.get() + i * dim12;
939
+ }
940
+ for (idx_t j = 0; j < nprobe; j++) {
941
+ size_t ij = i * nprobe + j;
942
+ if (!single_LUT) {
943
+ LUT = dis_tables.get() + ij * dim12;
944
+ }
945
+ if (biases.get()) {
946
+ handler.dbias = biases.get() + ij;
703
947
  }
704
- for (idx_t j = 0; j < nprobe; j++) {
705
- size_t ij = i * nprobe + j;
706
- if (!single_LUT) {
707
- LUT = dis_tables.get() + ij * dim12;
708
- }
709
- if (biases.get()) {
710
- handler->dbias = biases.get() + ij;
711
- }
712
-
713
- idx_t list_no = coarse_ids[ij];
714
- if (list_no < 0)
715
- continue;
716
- size_t ls = invlists->list_size(list_no);
717
- if (ls == 0)
718
- continue;
719
948
 
720
- InvertedLists::ScopedCodes codes(invlists, list_no);
721
- InvertedLists::ScopedIds ids(invlists, list_no);
949
+ idx_t list_no = cq.ids[ij];
950
+ if (list_no < 0) {
951
+ continue;
952
+ }
953
+ size_t ls = invlists->list_size(list_no);
954
+ if (ls == 0) {
955
+ continue;
956
+ }
722
957
 
723
- handler->ntotal = ls;
724
- handler->id_map = ids.get();
958
+ InvertedLists::ScopedCodes codes(invlists, list_no);
959
+ InvertedLists::ScopedIds ids(invlists, list_no);
725
960
 
726
- #define DISPATCH(classHC) \
727
- if (dynamic_cast<classHC*>(handler.get())) { \
728
- auto* res = static_cast<classHC*>(handler.get()); \
729
- pq4_accumulate_loop( \
730
- 1, roundup(ls, bbs), bbs, M2, codes.get(), LUT, *res, scaler); \
731
- }
732
- DISPATCH(HeapHC)
733
- else DISPATCH(ReservoirHC) else DISPATCH(SingleResultHC)
734
- #undef DISPATCH
961
+ handler.ntotal = ls;
962
+ handler.id_map = ids.get();
735
963
 
736
- nlist_visited++;
737
- ndis++;
738
- }
964
+ pq4_accumulate_loop(
965
+ 1,
966
+ roundup(ls, bbs),
967
+ bbs,
968
+ M2,
969
+ codes.get(),
970
+ LUT,
971
+ handler,
972
+ scaler);
739
973
 
740
- handler->to_flat_arrays(
741
- distances + i * k,
742
- labels + i * k,
743
- skip & 16 ? nullptr : normalizers.get() + i * 2);
974
+ ndis++;
744
975
  }
745
976
  }
977
+
978
+ handler.end();
746
979
  *ndis_out = ndis;
747
980
  *nlist_out = nlist;
748
981
  }
749
982
 
750
- template <class C, class Scaler>
751
983
  void IndexIVFFastScan::search_implem_12(
752
984
  idx_t n,
753
985
  const float* x,
754
- idx_t k,
755
- float* distances,
756
- idx_t* labels,
757
- int impl,
986
+ SIMDResultHandlerToFloat& handler,
987
+ const CoarseQuantized& cq,
758
988
  size_t* ndis_out,
759
989
  size_t* nlist_out,
760
- const Scaler& scaler) const {
990
+ const NormTableScaler* scaler,
991
+ const IVFSearchParameters* params) const {
761
992
  if (n == 0) { // does not work well with reservoir
762
993
  return;
763
994
  }
764
995
  FAISS_THROW_IF_NOT(bbs == 32);
765
996
 
766
- std::unique_ptr<idx_t[]> coarse_ids(new idx_t[n * nprobe]);
767
- std::unique_ptr<float[]> coarse_dis(new float[n * nprobe]);
768
-
769
- uint64_t times[10];
770
- memset(times, 0, sizeof(times));
771
- int ti = 0;
772
- #define TIC times[ti++] = get_cy()
773
- TIC;
774
-
775
- quantizer->search(n, x, nprobe, coarse_dis.get(), coarse_ids.get());
776
-
777
- TIC;
778
-
779
997
  size_t dim12 = ksub * M2;
780
998
  AlignedTable<uint8_t> dis_tables;
781
999
  AlignedTable<uint16_t> biases;
782
1000
  std::unique_ptr<float[]> normalizers(new float[2 * n]);
783
1001
 
784
- compute_LUT_uint8(
785
- n,
786
- x,
787
- coarse_ids.get(),
788
- coarse_dis.get(),
789
- dis_tables,
790
- biases,
791
- normalizers.get());
1002
+ compute_LUT_uint8(n, x, cq, dis_tables, biases, normalizers.get());
792
1003
 
793
- TIC;
1004
+ handler.begin(skip & 16 ? nullptr : normalizers.get());
794
1005
 
795
1006
  struct QC {
796
1007
  int qno; // sequence number of the query
@@ -798,14 +1009,15 @@ void IndexIVFFastScan::search_implem_12(
798
1009
  int rank; // this is the rank'th result of the coarse quantizer
799
1010
  };
800
1011
  bool single_LUT = !lookup_table_is_3d();
1012
+ size_t nprobe = cq.nprobe;
801
1013
 
802
1014
  std::vector<QC> qcs;
803
1015
  {
804
1016
  int ij = 0;
805
1017
  for (int i = 0; i < n; i++) {
806
1018
  for (int j = 0; j < nprobe; j++) {
807
- if (coarse_ids[ij] >= 0) {
808
- qcs.push_back(QC{i, int(coarse_ids[ij]), int(j)});
1019
+ if (cq.ids[ij] >= 0) {
1020
+ qcs.push_back(QC{i, int(cq.ids[ij]), int(j)});
809
1021
  }
810
1022
  ij++;
811
1023
  }
@@ -814,42 +1026,22 @@ void IndexIVFFastScan::search_implem_12(
814
1026
  return a.list_no < b.list_no;
815
1027
  });
816
1028
  }
817
- TIC;
818
1029
 
819
1030
  // prepare the result handlers
820
1031
 
821
- std::unique_ptr<SIMDResultHandler<C, true>> handler;
822
- AlignedTable<uint16_t> tmp_distances;
823
-
824
- using HeapHC = HeapHandler<C, true>;
825
- using ReservoirHC = ReservoirHandler<C, true>;
826
- using SingleResultHC = SingleResultHandler<C, true>;
827
-
828
- if (k == 1) {
829
- handler.reset(new SingleResultHC(n, 0));
830
- } else if (impl == 12) {
831
- tmp_distances.resize(n * k);
832
- handler.reset(new HeapHC(n, tmp_distances.get(), labels, k, 0));
833
- } else if (impl == 13) {
834
- handler.reset(new ReservoirHC(n, 0, k, 2 * k));
835
- }
836
-
837
1032
  int qbs2 = this->qbs2 ? this->qbs2 : 11;
838
1033
 
839
1034
  std::vector<uint16_t> tmp_bias;
840
1035
  if (biases.get()) {
841
1036
  tmp_bias.resize(qbs2);
842
- handler->dbias = tmp_bias.data();
1037
+ handler.dbias = tmp_bias.data();
843
1038
  }
844
- TIC;
845
1039
 
846
1040
  size_t ndis = 0;
847
1041
 
848
1042
  size_t i0 = 0;
849
1043
  uint64_t t_copy_pack = 0, t_scan = 0;
850
1044
  while (i0 < qcs.size()) {
851
- uint64_t tt0 = get_cy();
852
-
853
1045
  // find all queries that access this inverted list
854
1046
  int list_no = qcs[i0].list_no;
855
1047
  size_t i1 = i0 + 1;
@@ -897,93 +1089,50 @@ void IndexIVFFastScan::search_implem_12(
897
1089
 
898
1090
  // prepare the handler
899
1091
 
900
- handler->ntotal = list_size;
901
- handler->q_map = q_map.data();
902
- handler->id_map = ids.get();
903
- uint64_t tt1 = get_cy();
1092
+ handler.ntotal = list_size;
1093
+ handler.q_map = q_map.data();
1094
+ handler.id_map = ids.get();
904
1095
 
905
- #define DISPATCH(classHC) \
906
- if (dynamic_cast<classHC*>(handler.get())) { \
907
- auto* res = static_cast<classHC*>(handler.get()); \
908
- pq4_accumulate_loop_qbs( \
909
- qbs, list_size, M2, codes.get(), LUT.get(), *res, scaler); \
910
- }
911
- DISPATCH(HeapHC)
912
- else DISPATCH(ReservoirHC) else DISPATCH(SingleResultHC)
913
-
914
- // prepare for next loop
915
- i0 = i1;
916
-
917
- uint64_t tt2 = get_cy();
918
- t_copy_pack += tt1 - tt0;
919
- t_scan += tt2 - tt1;
1096
+ pq4_accumulate_loop_qbs(
1097
+ qbs, list_size, M2, codes.get(), LUT.get(), handler, scaler);
1098
+ // prepare for next loop
1099
+ i0 = i1;
920
1100
  }
921
- TIC;
922
1101
 
923
- // labels is in-place for HeapHC
924
- handler->to_flat_arrays(
925
- distances, labels, skip & 16 ? nullptr : normalizers.get());
926
-
927
- TIC;
1102
+ handler.end();
928
1103
 
929
1104
  // these stats are not thread-safe
930
1105
 
931
- for (int i = 1; i < ti; i++) {
932
- IVFFastScan_stats.times[i] += times[i] - times[i - 1];
933
- }
934
1106
  IVFFastScan_stats.t_copy_pack += t_copy_pack;
935
1107
  IVFFastScan_stats.t_scan += t_scan;
936
1108
 
937
- if (auto* rh = dynamic_cast<ReservoirHC*>(handler.get())) {
938
- for (int i = 0; i < 4; i++) {
939
- IVFFastScan_stats.reservoir_times[i] += rh->times[i];
940
- }
941
- }
942
-
943
1109
  *ndis_out = ndis;
944
1110
  *nlist_out = nlist;
945
1111
  }
946
1112
 
947
- template <class C, class Scaler>
948
1113
  void IndexIVFFastScan::search_implem_14(
949
1114
  idx_t n,
950
1115
  const float* x,
951
1116
  idx_t k,
952
1117
  float* distances,
953
1118
  idx_t* labels,
1119
+ const CoarseQuantized& cq,
954
1120
  int impl,
955
- const Scaler& scaler) const {
1121
+ const NormTableScaler* scaler,
1122
+ const IVFSearchParameters* params) const {
956
1123
  if (n == 0) { // does not work well with reservoir
957
1124
  return;
958
1125
  }
959
1126
  FAISS_THROW_IF_NOT(bbs == 32);
960
1127
 
961
- std::unique_ptr<idx_t[]> coarse_ids(new idx_t[n * nprobe]);
962
- std::unique_ptr<float[]> coarse_dis(new float[n * nprobe]);
963
-
964
- uint64_t ttg0 = get_cy();
965
-
966
- quantizer->search(n, x, nprobe, coarse_dis.get(), coarse_ids.get());
967
-
968
- uint64_t ttg1 = get_cy();
969
- uint64_t coarse_search_tt = ttg1 - ttg0;
1128
+ const IDSelector* sel = params ? params->sel : nullptr;
970
1129
 
971
1130
  size_t dim12 = ksub * M2;
972
1131
  AlignedTable<uint8_t> dis_tables;
973
1132
  AlignedTable<uint16_t> biases;
974
1133
  std::unique_ptr<float[]> normalizers(new float[2 * n]);
975
1134
 
976
- compute_LUT_uint8(
977
- n,
978
- x,
979
- coarse_ids.get(),
980
- coarse_dis.get(),
981
- dis_tables,
982
- biases,
983
- normalizers.get());
984
-
985
- uint64_t ttg2 = get_cy();
986
- uint64_t lut_compute_tt = ttg2 - ttg1;
1135
+ compute_LUT_uint8(n, x, cq, dis_tables, biases, normalizers.get());
987
1136
 
988
1137
  struct QC {
989
1138
  int qno; // sequence number of the query
@@ -991,14 +1140,15 @@ void IndexIVFFastScan::search_implem_14(
991
1140
  int rank; // this is the rank'th result of the coarse quantizer
992
1141
  };
993
1142
  bool single_LUT = !lookup_table_is_3d();
1143
+ size_t nprobe = cq.nprobe;
994
1144
 
995
1145
  std::vector<QC> qcs;
996
1146
  {
997
1147
  int ij = 0;
998
1148
  for (int i = 0; i < n; i++) {
999
1149
  for (int j = 0; j < nprobe; j++) {
1000
- if (coarse_ids[ij] >= 0) {
1001
- qcs.push_back(QC{i, int(coarse_ids[ij]), int(j)});
1150
+ if (cq.ids[ij] >= 0) {
1151
+ qcs.push_back(QC{i, int(cq.ids[ij]), int(j)});
1002
1152
  }
1003
1153
  ij++;
1004
1154
  }
@@ -1036,14 +1186,13 @@ void IndexIVFFastScan::search_implem_14(
1036
1186
  ses.push_back(SE{i0_l, i1, list_size});
1037
1187
  i0_l = i1;
1038
1188
  }
1039
- uint64_t ttg3 = get_cy();
1040
- uint64_t compute_clusters_tt = ttg3 - ttg2;
1041
1189
 
1042
1190
  // function to handle the global heap
1191
+ bool is_max = !is_similarity_metric(metric_type);
1043
1192
  using HeapForIP = CMin<float, idx_t>;
1044
1193
  using HeapForL2 = CMax<float, idx_t>;
1045
1194
  auto init_result = [&](float* simi, idx_t* idxi) {
1046
- if (metric_type == METRIC_INNER_PRODUCT) {
1195
+ if (!is_max) {
1047
1196
  heap_heapify<HeapForIP>(k, simi, idxi);
1048
1197
  } else {
1049
1198
  heap_heapify<HeapForL2>(k, simi, idxi);
@@ -1054,7 +1203,7 @@ void IndexIVFFastScan::search_implem_14(
1054
1203
  const idx_t* local_idx,
1055
1204
  float* simi,
1056
1205
  idx_t* idxi) {
1057
- if (metric_type == METRIC_INNER_PRODUCT) {
1206
+ if (!is_max) {
1058
1207
  heap_addn<HeapForIP>(k, simi, idxi, local_dis, local_idx, k);
1059
1208
  } else {
1060
1209
  heap_addn<HeapForL2>(k, simi, idxi, local_dis, local_idx, k);
@@ -1062,14 +1211,12 @@ void IndexIVFFastScan::search_implem_14(
1062
1211
  };
1063
1212
 
1064
1213
  auto reorder_result = [&](float* simi, idx_t* idxi) {
1065
- if (metric_type == METRIC_INNER_PRODUCT) {
1214
+ if (!is_max) {
1066
1215
  heap_reorder<HeapForIP>(k, simi, idxi);
1067
1216
  } else {
1068
1217
  heap_reorder<HeapForL2>(k, simi, idxi);
1069
1218
  }
1070
1219
  };
1071
- uint64_t ttg4 = get_cy();
1072
- uint64_t fn_tt = ttg4 - ttg3;
1073
1220
 
1074
1221
  size_t ndis = 0;
1075
1222
  size_t nlist_visited = 0;
@@ -1081,22 +1228,9 @@ void IndexIVFFastScan::search_implem_14(
1081
1228
  std::vector<float> local_dis(k * n);
1082
1229
 
1083
1230
  // prepare the result handlers
1084
- std::unique_ptr<SIMDResultHandler<C, true>> handler;
1085
- AlignedTable<uint16_t> tmp_distances;
1086
-
1087
- using HeapHC = HeapHandler<C, true>;
1088
- using ReservoirHC = ReservoirHandler<C, true>;
1089
- using SingleResultHC = SingleResultHandler<C, true>;
1090
-
1091
- if (k == 1) {
1092
- handler.reset(new SingleResultHC(n, 0));
1093
- } else if (impl == 14) {
1094
- tmp_distances.resize(n * k);
1095
- handler.reset(
1096
- new HeapHC(n, tmp_distances.get(), local_idx.data(), k, 0));
1097
- } else if (impl == 15) {
1098
- handler.reset(new ReservoirHC(n, 0, k, 2 * k));
1099
- }
1231
+ std::unique_ptr<SIMDResultHandlerToFloat> handler(make_knn_handler(
1232
+ is_max, impl, n, k, local_dis.data(), local_idx.data(), sel));
1233
+ handler->begin(normalizers.get());
1100
1234
 
1101
1235
  int qbs2 = this->qbs2 ? this->qbs2 : 11;
1102
1236
 
@@ -1106,14 +1240,10 @@ void IndexIVFFastScan::search_implem_14(
1106
1240
  handler->dbias = tmp_bias.data();
1107
1241
  }
1108
1242
 
1109
- uint64_t ttg5 = get_cy();
1110
- uint64_t handler_tt = ttg5 - ttg4;
1111
-
1112
1243
  std::set<int> q_set;
1113
1244
  uint64_t t_copy_pack = 0, t_scan = 0;
1114
1245
  #pragma omp for schedule(dynamic)
1115
1246
  for (idx_t cluster = 0; cluster < ses.size(); cluster++) {
1116
- uint64_t tt0 = get_cy();
1117
1247
  size_t i0 = ses[cluster].start;
1118
1248
  size_t i1 = ses[cluster].end;
1119
1249
  size_t list_size = ses[cluster].list_size;
@@ -1153,28 +1283,21 @@ void IndexIVFFastScan::search_implem_14(
1153
1283
  handler->ntotal = list_size;
1154
1284
  handler->q_map = q_map.data();
1155
1285
  handler->id_map = ids.get();
1156
- uint64_t tt1 = get_cy();
1157
-
1158
- #define DISPATCH(classHC) \
1159
- if (dynamic_cast<classHC*>(handler.get())) { \
1160
- auto* res = static_cast<classHC*>(handler.get()); \
1161
- pq4_accumulate_loop_qbs( \
1162
- qbs, list_size, M2, codes.get(), LUT.get(), *res, scaler); \
1163
- }
1164
- DISPATCH(HeapHC)
1165
- else DISPATCH(ReservoirHC) else DISPATCH(SingleResultHC)
1166
1286
 
1167
- uint64_t tt2 = get_cy();
1168
- t_copy_pack += tt1 - tt0;
1169
- t_scan += tt2 - tt1;
1287
+ pq4_accumulate_loop_qbs(
1288
+ qbs,
1289
+ list_size,
1290
+ M2,
1291
+ codes.get(),
1292
+ LUT.get(),
1293
+ *handler.get(),
1294
+ scaler);
1170
1295
  }
1171
1296
 
1172
1297
  // labels is in-place for HeapHC
1173
- handler->to_flat_arrays(
1174
- local_dis.data(),
1175
- local_idx.data(),
1176
- skip & 16 ? nullptr : normalizers.get());
1298
+ handler->end();
1177
1299
 
1300
+ // merge per-thread results
1178
1301
  #pragma omp single
1179
1302
  {
1180
1303
  // we init the results as a heap
@@ -1197,12 +1320,6 @@ void IndexIVFFastScan::search_implem_14(
1197
1320
 
1198
1321
  IVFFastScan_stats.t_copy_pack += t_copy_pack;
1199
1322
  IVFFastScan_stats.t_scan += t_scan;
1200
-
1201
- if (auto* rh = dynamic_cast<ReservoirHC*>(handler.get())) {
1202
- for (int i = 0; i < 4; i++) {
1203
- IVFFastScan_stats.reservoir_times[i] += rh->times[i];
1204
- }
1205
- }
1206
1323
  }
1207
1324
  #pragma omp barrier
1208
1325
  #pragma omp single
@@ -1272,20 +1389,4 @@ void IndexIVFFastScan::reconstruct_orig_invlists() {
1272
1389
 
1273
1390
  IVFFastScanStats IVFFastScan_stats;
1274
1391
 
1275
- template void IndexIVFFastScan::search_dispatch_implem<true, NormTableScaler>(
1276
- idx_t n,
1277
- const float* x,
1278
- idx_t k,
1279
- float* distances,
1280
- idx_t* labels,
1281
- const NormTableScaler& scaler) const;
1282
-
1283
- template void IndexIVFFastScan::search_dispatch_implem<false, NormTableScaler>(
1284
- idx_t n,
1285
- const float* x,
1286
- idx_t k,
1287
- float* distances,
1288
- idx_t* labels,
1289
- const NormTableScaler& scaler) const;
1290
-
1291
1392
  } // namespace faiss