faiss 0.3.1 → 0.3.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (119) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +4 -0
  3. data/lib/faiss/version.rb +1 -1
  4. data/vendor/faiss/faiss/AutoTune.h +1 -1
  5. data/vendor/faiss/faiss/Clustering.cpp +35 -4
  6. data/vendor/faiss/faiss/Clustering.h +10 -1
  7. data/vendor/faiss/faiss/IVFlib.cpp +4 -1
  8. data/vendor/faiss/faiss/Index.h +21 -6
  9. data/vendor/faiss/faiss/IndexBinaryHNSW.h +1 -1
  10. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +1 -1
  11. data/vendor/faiss/faiss/IndexFastScan.cpp +22 -4
  12. data/vendor/faiss/faiss/IndexFlat.cpp +11 -7
  13. data/vendor/faiss/faiss/IndexFlatCodes.cpp +159 -5
  14. data/vendor/faiss/faiss/IndexFlatCodes.h +20 -3
  15. data/vendor/faiss/faiss/IndexHNSW.cpp +143 -90
  16. data/vendor/faiss/faiss/IndexHNSW.h +52 -3
  17. data/vendor/faiss/faiss/IndexIVF.cpp +3 -3
  18. data/vendor/faiss/faiss/IndexIVF.h +9 -1
  19. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +15 -0
  20. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +3 -0
  21. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +130 -57
  22. data/vendor/faiss/faiss/IndexIVFFastScan.h +14 -7
  23. data/vendor/faiss/faiss/IndexIVFPQ.cpp +1 -3
  24. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +21 -2
  25. data/vendor/faiss/faiss/IndexLattice.cpp +1 -19
  26. data/vendor/faiss/faiss/IndexLattice.h +3 -22
  27. data/vendor/faiss/faiss/IndexNNDescent.cpp +0 -29
  28. data/vendor/faiss/faiss/IndexNNDescent.h +1 -1
  29. data/vendor/faiss/faiss/IndexNSG.h +1 -1
  30. data/vendor/faiss/faiss/IndexNeuralNetCodec.cpp +56 -0
  31. data/vendor/faiss/faiss/IndexNeuralNetCodec.h +49 -0
  32. data/vendor/faiss/faiss/IndexPreTransform.h +1 -1
  33. data/vendor/faiss/faiss/IndexRefine.cpp +5 -5
  34. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +3 -1
  35. data/vendor/faiss/faiss/MetricType.h +7 -2
  36. data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +95 -17
  37. data/vendor/faiss/faiss/cppcontrib/factory_tools.cpp +152 -0
  38. data/vendor/faiss/faiss/cppcontrib/factory_tools.h +24 -0
  39. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +83 -30
  40. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +36 -4
  41. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +6 -0
  42. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -1
  43. data/vendor/faiss/faiss/gpu/GpuIndex.h +2 -8
  44. data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +282 -0
  45. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +6 -0
  46. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +2 -0
  47. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +25 -0
  48. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +26 -21
  49. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +6 -0
  50. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +8 -5
  51. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +65 -0
  52. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +1 -1
  53. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +6 -0
  54. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +4 -1
  55. data/vendor/faiss/faiss/gpu/utils/Timer.h +1 -1
  56. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +25 -0
  57. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +9 -1
  58. data/vendor/faiss/faiss/impl/DistanceComputer.h +46 -0
  59. data/vendor/faiss/faiss/impl/FaissAssert.h +4 -2
  60. data/vendor/faiss/faiss/impl/HNSW.cpp +358 -190
  61. data/vendor/faiss/faiss/impl/HNSW.h +43 -22
  62. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +8 -8
  63. data/vendor/faiss/faiss/impl/LookupTableScaler.h +34 -0
  64. data/vendor/faiss/faiss/impl/NNDescent.cpp +13 -8
  65. data/vendor/faiss/faiss/impl/NSG.cpp +0 -29
  66. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +1 -0
  67. data/vendor/faiss/faiss/impl/ProductQuantizer.h +5 -1
  68. data/vendor/faiss/faiss/impl/ResultHandler.h +151 -32
  69. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +719 -102
  70. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +3 -0
  71. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +5 -0
  72. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx512.h +248 -0
  73. data/vendor/faiss/faiss/impl/index_read.cpp +29 -15
  74. data/vendor/faiss/faiss/impl/index_read_utils.h +37 -0
  75. data/vendor/faiss/faiss/impl/index_write.cpp +28 -10
  76. data/vendor/faiss/faiss/impl/io.cpp +13 -5
  77. data/vendor/faiss/faiss/impl/io.h +4 -4
  78. data/vendor/faiss/faiss/impl/io_macros.h +6 -0
  79. data/vendor/faiss/faiss/impl/platform_macros.h +22 -0
  80. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +11 -0
  81. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +1 -1
  82. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +448 -1
  83. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +5 -5
  84. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +1 -1
  85. data/vendor/faiss/faiss/impl/simd_result_handlers.h +143 -59
  86. data/vendor/faiss/faiss/index_factory.cpp +31 -13
  87. data/vendor/faiss/faiss/index_io.h +12 -5
  88. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +28 -8
  89. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +3 -0
  90. data/vendor/faiss/faiss/invlists/DirectMap.cpp +9 -1
  91. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +55 -17
  92. data/vendor/faiss/faiss/invlists/InvertedLists.h +18 -9
  93. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +21 -6
  94. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +2 -1
  95. data/vendor/faiss/faiss/python/python_callbacks.cpp +3 -3
  96. data/vendor/faiss/faiss/utils/Heap.h +105 -0
  97. data/vendor/faiss/faiss/utils/NeuralNet.cpp +342 -0
  98. data/vendor/faiss/faiss/utils/NeuralNet.h +147 -0
  99. data/vendor/faiss/faiss/utils/bf16.h +36 -0
  100. data/vendor/faiss/faiss/utils/distances.cpp +58 -88
  101. data/vendor/faiss/faiss/utils/distances.h +5 -5
  102. data/vendor/faiss/faiss/utils/distances_simd.cpp +997 -9
  103. data/vendor/faiss/faiss/utils/extra_distances-inl.h +70 -0
  104. data/vendor/faiss/faiss/utils/extra_distances.cpp +85 -137
  105. data/vendor/faiss/faiss/utils/extra_distances.h +3 -2
  106. data/vendor/faiss/faiss/utils/hamming.cpp +1 -1
  107. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +4 -1
  108. data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +2 -1
  109. data/vendor/faiss/faiss/utils/random.cpp +43 -0
  110. data/vendor/faiss/faiss/utils/random.h +25 -0
  111. data/vendor/faiss/faiss/utils/simdlib.h +10 -1
  112. data/vendor/faiss/faiss/utils/simdlib_avx512.h +296 -0
  113. data/vendor/faiss/faiss/utils/simdlib_neon.h +5 -2
  114. data/vendor/faiss/faiss/utils/simdlib_ppc64.h +1084 -0
  115. data/vendor/faiss/faiss/utils/transpose/transpose-avx512-inl.h +176 -0
  116. data/vendor/faiss/faiss/utils/utils.cpp +10 -3
  117. data/vendor/faiss/faiss/utils/utils.h +3 -0
  118. metadata +16 -4
  119. data/vendor/faiss/faiss/impl/code_distance/code_distance_avx512.h +0 -102
@@ -211,7 +211,7 @@ void estimators_from_tables_generic(
211
211
  int64_t* heap_ids,
212
212
  const NormTableScaler* scaler) {
213
213
  using accu_t = typename C::T;
214
- int nscale = scaler ? scaler->nscale : 0;
214
+ size_t nscale = scaler ? scaler->nscale : 0;
215
215
  for (size_t j = 0; j < ncodes; ++j) {
216
216
  BitstringReader bsr(codes + j * index.code_size, index.code_size);
217
217
  accu_t dis = bias;
@@ -270,6 +270,7 @@ void IndexIVFFastScan::compute_LUT_uint8(
270
270
  biases.resize(n * nprobe);
271
271
  }
272
272
 
273
+ // OMP for MSVC requires i to have signed integral type
273
274
  #pragma omp parallel for if (n > 100)
274
275
  for (int64_t i = 0; i < n; i++) {
275
276
  const float* t_in = dis_tables_float.get() + i * dim123;
@@ -306,11 +307,16 @@ void IndexIVFFastScan::search(
306
307
  idx_t k,
307
308
  float* distances,
308
309
  idx_t* labels,
309
- const SearchParameters* params) const {
310
- auto paramsi = dynamic_cast<const SearchParametersIVF*>(params);
311
- FAISS_THROW_IF_NOT_MSG(!params || paramsi, "need IVFSearchParameters");
310
+ const SearchParameters* params_in) const {
311
+ const IVFSearchParameters* params = nullptr;
312
+ if (params_in) {
313
+ params = dynamic_cast<const IVFSearchParameters*>(params_in);
314
+ FAISS_THROW_IF_NOT_MSG(
315
+ params, "IndexIVFFastScan params have incorrect type");
316
+ }
317
+
312
318
  search_preassigned(
313
- n, x, k, nullptr, nullptr, distances, labels, false, paramsi);
319
+ n, x, k, nullptr, nullptr, distances, labels, false, params);
314
320
  }
315
321
 
316
322
  void IndexIVFFastScan::search_preassigned(
@@ -326,18 +332,17 @@ void IndexIVFFastScan::search_preassigned(
326
332
  IndexIVFStats* stats) const {
327
333
  size_t nprobe = this->nprobe;
328
334
  if (params) {
329
- FAISS_THROW_IF_NOT_MSG(
330
- !params->quantizer_params, "quantizer params not supported");
331
335
  FAISS_THROW_IF_NOT(params->max_codes == 0);
332
336
  nprobe = params->nprobe;
333
337
  }
338
+
334
339
  FAISS_THROW_IF_NOT_MSG(
335
340
  !store_pairs, "store_pairs not supported for this index");
336
341
  FAISS_THROW_IF_NOT_MSG(!stats, "stats not supported for this index");
337
342
  FAISS_THROW_IF_NOT(k > 0);
338
343
 
339
344
  const CoarseQuantized cq = {nprobe, centroid_dis, assign};
340
- search_dispatch_implem(n, x, k, distances, labels, cq, nullptr);
345
+ search_dispatch_implem(n, x, k, distances, labels, cq, nullptr, params);
341
346
  }
342
347
 
343
348
  void IndexIVFFastScan::range_search(
@@ -345,10 +350,18 @@ void IndexIVFFastScan::range_search(
345
350
  const float* x,
346
351
  float radius,
347
352
  RangeSearchResult* result,
348
- const SearchParameters* params) const {
349
- FAISS_THROW_IF_NOT(!params);
353
+ const SearchParameters* params_in) const {
354
+ size_t nprobe = this->nprobe;
355
+ const IVFSearchParameters* params = nullptr;
356
+ if (params_in) {
357
+ params = dynamic_cast<const IVFSearchParameters*>(params_in);
358
+ FAISS_THROW_IF_NOT_MSG(
359
+ params, "IndexIVFFastScan params have incorrect type");
360
+ nprobe = params->nprobe;
361
+ }
362
+
350
363
  const CoarseQuantized cq = {nprobe, nullptr, nullptr};
351
- range_search_dispatch_implem(n, x, radius, *result, cq, nullptr);
364
+ range_search_dispatch_implem(n, x, radius, *result, cq, nullptr, params);
352
365
  }
353
366
 
354
367
  namespace {
@@ -359,17 +372,18 @@ ResultHandlerCompare<C, true>* make_knn_handler_fixC(
359
372
  idx_t n,
360
373
  idx_t k,
361
374
  float* distances,
362
- idx_t* labels) {
375
+ idx_t* labels,
376
+ const IDSelector* sel) {
363
377
  using HeapHC = HeapHandler<C, true>;
364
378
  using ReservoirHC = ReservoirHandler<C, true>;
365
379
  using SingleResultHC = SingleResultHandler<C, true>;
366
380
 
367
381
  if (k == 1) {
368
- return new SingleResultHC(n, 0, distances, labels);
382
+ return new SingleResultHC(n, 0, distances, labels, sel);
369
383
  } else if (impl % 2 == 0) {
370
- return new HeapHC(n, 0, k, distances, labels);
384
+ return new HeapHC(n, 0, k, distances, labels, sel);
371
385
  } else /* if (impl % 2 == 1) */ {
372
- return new ReservoirHC(n, 0, k, 2 * k, distances, labels);
386
+ return new ReservoirHC(n, 0, k, 2 * k, distances, labels, sel);
373
387
  }
374
388
  }
375
389
 
@@ -379,13 +393,14 @@ SIMDResultHandlerToFloat* make_knn_handler(
379
393
  idx_t n,
380
394
  idx_t k,
381
395
  float* distances,
382
- idx_t* labels) {
396
+ idx_t* labels,
397
+ const IDSelector* sel) {
383
398
  if (is_max) {
384
399
  return make_knn_handler_fixC<CMax<uint16_t, int64_t>>(
385
- impl, n, k, distances, labels);
400
+ impl, n, k, distances, labels, sel);
386
401
  } else {
387
402
  return make_knn_handler_fixC<CMin<uint16_t, int64_t>>(
388
- impl, n, k, distances, labels);
403
+ impl, n, k, distances, labels, sel);
389
404
  }
390
405
  }
391
406
 
@@ -402,10 +417,20 @@ struct CoarseQuantizedWithBuffer : CoarseQuantized {
402
417
  std::vector<idx_t> ids_buffer;
403
418
  std::vector<float> dis_buffer;
404
419
 
405
- void quantize(const Index* quantizer, idx_t n, const float* x) {
420
+ void quantize(
421
+ const Index* quantizer,
422
+ idx_t n,
423
+ const float* x,
424
+ const SearchParameters* quantizer_params) {
406
425
  dis_buffer.resize(nprobe * n);
407
426
  ids_buffer.resize(nprobe * n);
408
- quantizer->search(n, x, nprobe, dis_buffer.data(), ids_buffer.data());
427
+ quantizer->search(
428
+ n,
429
+ x,
430
+ nprobe,
431
+ dis_buffer.data(),
432
+ ids_buffer.data(),
433
+ quantizer_params);
409
434
  dis = dis_buffer.data();
410
435
  ids = ids_buffer.data();
411
436
  }
@@ -421,8 +446,11 @@ struct CoarseQuantizedSlice : CoarseQuantizedWithBuffer {
421
446
  }
422
447
  }
423
448
 
424
- void quantize_slice(const Index* quantizer, const float* x) {
425
- quantize(quantizer, i1 - i0, x + quantizer->d * i0);
449
+ void quantize_slice(
450
+ const Index* quantizer,
451
+ const float* x,
452
+ const SearchParameters* quantizer_params) {
453
+ quantize(quantizer, i1 - i0, x + quantizer->d * i0, quantizer_params);
426
454
  }
427
455
  };
428
456
 
@@ -459,7 +487,13 @@ void IndexIVFFastScan::search_dispatch_implem(
459
487
  float* distances,
460
488
  idx_t* labels,
461
489
  const CoarseQuantized& cq_in,
462
- const NormTableScaler* scaler) const {
490
+ const NormTableScaler* scaler,
491
+ const IVFSearchParameters* params) const {
492
+ const idx_t nprobe = params ? params->nprobe : this->nprobe;
493
+ const IDSelector* sel = (params) ? params->sel : nullptr;
494
+ const SearchParameters* quantizer_params =
495
+ params ? params->quantizer_params : nullptr;
496
+
463
497
  bool is_max = !is_similarity_metric(metric_type);
464
498
  using RH = SIMDResultHandlerToFloat;
465
499
 
@@ -489,52 +523,70 @@ void IndexIVFFastScan::search_dispatch_implem(
489
523
  }
490
524
 
491
525
  CoarseQuantizedWithBuffer cq(cq_in);
526
+ cq.nprobe = nprobe;
492
527
 
493
528
  if (!cq.done() && !multiple_threads) {
494
529
  // we do the coarse quantization here except when search is
495
530
  // sliced over threads (then it is more efficient to have each thread do
496
531
  // its own coarse quantization)
497
- cq.quantize(quantizer, n, x);
532
+ cq.quantize(quantizer, n, x, quantizer_params);
533
+ invlists->prefetch_lists(cq.ids, n * cq.nprobe);
498
534
  }
499
535
 
500
536
  if (impl == 1) {
501
537
  if (is_max) {
502
538
  search_implem_1<CMax<float, int64_t>>(
503
- n, x, k, distances, labels, cq, scaler);
539
+ n, x, k, distances, labels, cq, scaler, params);
504
540
  } else {
505
541
  search_implem_1<CMin<float, int64_t>>(
506
- n, x, k, distances, labels, cq, scaler);
542
+ n, x, k, distances, labels, cq, scaler, params);
507
543
  }
508
544
  } else if (impl == 2) {
509
545
  if (is_max) {
510
546
  search_implem_2<CMax<uint16_t, int64_t>>(
511
- n, x, k, distances, labels, cq, scaler);
547
+ n, x, k, distances, labels, cq, scaler, params);
512
548
  } else {
513
549
  search_implem_2<CMin<uint16_t, int64_t>>(
514
- n, x, k, distances, labels, cq, scaler);
550
+ n, x, k, distances, labels, cq, scaler, params);
515
551
  }
516
-
517
552
  } else if (impl >= 10 && impl <= 15) {
518
553
  size_t ndis = 0, nlist_visited = 0;
519
554
 
520
555
  if (!multiple_threads) {
521
556
  // clang-format off
522
557
  if (impl == 12 || impl == 13) {
523
- std::unique_ptr<RH> handler(make_knn_handler(is_max, impl, n, k, distances, labels));
558
+ std::unique_ptr<RH> handler(
559
+ make_knn_handler(
560
+ is_max,
561
+ impl,
562
+ n,
563
+ k,
564
+ distances,
565
+ labels, sel
566
+ )
567
+ );
524
568
  search_implem_12(
525
569
  n, x, *handler.get(),
526
- cq, &ndis, &nlist_visited, scaler);
527
-
570
+ cq, &ndis, &nlist_visited, scaler, params);
528
571
  } else if (impl == 14 || impl == 15) {
529
-
530
572
  search_implem_14(
531
573
  n, x, k, distances, labels,
532
- cq, impl, scaler);
574
+ cq, impl, scaler, params);
533
575
  } else {
534
- std::unique_ptr<RH> handler(make_knn_handler(is_max, impl, n, k, distances, labels));
576
+ std::unique_ptr<RH> handler(
577
+ make_knn_handler(
578
+ is_max,
579
+ impl,
580
+ n,
581
+ k,
582
+ distances,
583
+ labels,
584
+ sel
585
+ )
586
+ );
535
587
  search_implem_10(
536
588
  n, x, *handler.get(), cq,
537
- &ndis, &nlist_visited, scaler);
589
+ &ndis, &nlist_visited, scaler, params);
538
590
  }
539
591
  // clang-format on
540
592
  } else {
@@ -543,7 +595,8 @@ void IndexIVFFastScan::search_dispatch_implem(
543
595
  if (impl == 14 || impl == 15) {
544
596
  // this might require slicing if there are too
545
597
  // many queries (for now we keep this simple)
546
- search_implem_14(n, x, k, distances, labels, cq, impl, scaler);
598
+ search_implem_14(
599
+ n, x, k, distances, labels, cq, impl, scaler, params);
547
600
  } else {
548
601
  #pragma omp parallel for reduction(+ : ndis, nlist_visited)
549
602
  for (int slice = 0; slice < nslice; slice++) {
@@ -553,19 +606,19 @@ void IndexIVFFastScan::search_dispatch_implem(
553
606
  idx_t* lab_i = labels + i0 * k;
554
607
  CoarseQuantizedSlice cq_i(cq, i0, i1);
555
608
  if (!cq_i.done()) {
556
- cq_i.quantize_slice(quantizer, x);
609
+ cq_i.quantize_slice(quantizer, x, quantizer_params);
557
610
  }
558
611
  std::unique_ptr<RH> handler(make_knn_handler(
559
- is_max, impl, i1 - i0, k, dis_i, lab_i));
612
+ is_max, impl, i1 - i0, k, dis_i, lab_i, sel));
560
613
  // clang-format off
561
614
  if (impl == 12 || impl == 13) {
562
615
  search_implem_12(
563
616
  i1 - i0, x + i0 * d, *handler.get(),
564
- cq_i, &ndis, &nlist_visited, scaler);
617
+ cq_i, &ndis, &nlist_visited, scaler, params);
565
618
  } else {
566
619
  search_implem_10(
567
620
  i1 - i0, x + i0 * d, *handler.get(),
568
- cq_i, &ndis, &nlist_visited, scaler);
621
+ cq_i, &ndis, &nlist_visited, scaler, params);
569
622
  }
570
623
  // clang-format on
571
624
  }
@@ -585,7 +638,13 @@ void IndexIVFFastScan::range_search_dispatch_implem(
585
638
  float radius,
586
639
  RangeSearchResult& rres,
587
640
  const CoarseQuantized& cq_in,
588
- const NormTableScaler* scaler) const {
641
+ const NormTableScaler* scaler,
642
+ const IVFSearchParameters* params) const {
643
+ // const idx_t nprobe = params ? params->nprobe : this->nprobe;
644
+ const IDSelector* sel = (params) ? params->sel : nullptr;
645
+ const SearchParameters* quantizer_params =
646
+ params ? params->quantizer_params : nullptr;
647
+
589
648
  bool is_max = !is_similarity_metric(metric_type);
590
649
 
591
650
  if (n == 0) {
@@ -613,7 +672,8 @@ void IndexIVFFastScan::range_search_dispatch_implem(
613
672
  }
614
673
 
615
674
  if (!multiple_threads && !cq.done()) {
616
- cq.quantize(quantizer, n, x);
675
+ cq.quantize(quantizer, n, x, quantizer_params);
676
+ invlists->prefetch_lists(cq.ids, n * cq.nprobe);
617
677
  }
618
678
 
619
679
  size_t ndis = 0, nlist_visited = 0;
@@ -622,10 +682,10 @@ void IndexIVFFastScan::range_search_dispatch_implem(
622
682
  std::unique_ptr<SIMDResultHandlerToFloat> handler;
623
683
  if (is_max) {
624
684
  handler.reset(new RangeHandler<CMax<uint16_t, int64_t>, true>(
625
- rres, radius, 0));
685
+ rres, radius, 0, sel));
626
686
  } else {
627
687
  handler.reset(new RangeHandler<CMin<uint16_t, int64_t>, true>(
628
- rres, radius, 0));
688
+ rres, radius, 0, sel));
629
689
  }
630
690
  if (impl == 12) {
631
691
  search_implem_12(
@@ -634,7 +694,7 @@ void IndexIVFFastScan::range_search_dispatch_implem(
634
694
  search_implem_10(
635
695
  n, x, *handler.get(), cq, &ndis, &nlist_visited, scaler);
636
696
  } else {
637
- FAISS_THROW_FMT("Range search implem %d not impemented", impl);
697
+ FAISS_THROW_FMT("Range search implem %d not implemented", impl);
638
698
  }
639
699
  } else {
640
700
  // explicitly slice over threads
@@ -649,17 +709,17 @@ void IndexIVFFastScan::range_search_dispatch_implem(
649
709
  idx_t i1 = n * (slice + 1) / nslice;
650
710
  CoarseQuantizedSlice cq_i(cq, i0, i1);
651
711
  if (!cq_i.done()) {
652
- cq_i.quantize_slice(quantizer, x);
712
+ cq_i.quantize_slice(quantizer, x, quantizer_params);
653
713
  }
654
714
  std::unique_ptr<SIMDResultHandlerToFloat> handler;
655
715
  if (is_max) {
656
716
  handler.reset(new PartialRangeHandler<
657
717
  CMax<uint16_t, int64_t>,
658
- true>(pres, radius, 0, i0, i1));
718
+ true>(pres, radius, 0, i0, i1, sel));
659
719
  } else {
660
720
  handler.reset(new PartialRangeHandler<
661
721
  CMin<uint16_t, int64_t>,
662
- true>(pres, radius, 0, i0, i1));
722
+ true>(pres, radius, 0, i0, i1, sel));
663
723
  }
664
724
 
665
725
  if (impl == 12 || impl == 13) {
@@ -670,7 +730,8 @@ void IndexIVFFastScan::range_search_dispatch_implem(
670
730
  cq_i,
671
731
  &ndis,
672
732
  &nlist_visited,
673
- scaler);
733
+ scaler,
734
+ params);
674
735
  } else {
675
736
  search_implem_10(
676
737
  i1 - i0,
@@ -679,7 +740,8 @@ void IndexIVFFastScan::range_search_dispatch_implem(
679
740
  cq_i,
680
741
  &ndis,
681
742
  &nlist_visited,
682
- scaler);
743
+ scaler,
744
+ params);
683
745
  }
684
746
  }
685
747
  pres.finalize();
@@ -699,7 +761,8 @@ void IndexIVFFastScan::search_implem_1(
699
761
  float* distances,
700
762
  idx_t* labels,
701
763
  const CoarseQuantized& cq,
702
- const NormTableScaler* scaler) const {
764
+ const NormTableScaler* scaler,
765
+ const IVFSearchParameters* params) const {
703
766
  FAISS_THROW_IF_NOT(orig_invlists);
704
767
 
705
768
  size_t dim12 = ksub * M;
@@ -766,7 +829,8 @@ void IndexIVFFastScan::search_implem_2(
766
829
  float* distances,
767
830
  idx_t* labels,
768
831
  const CoarseQuantized& cq,
769
- const NormTableScaler* scaler) const {
832
+ const NormTableScaler* scaler,
833
+ const IVFSearchParameters* params) const {
770
834
  FAISS_THROW_IF_NOT(orig_invlists);
771
835
 
772
836
  size_t dim12 = ksub * M2;
@@ -848,7 +912,8 @@ void IndexIVFFastScan::search_implem_10(
848
912
  const CoarseQuantized& cq,
849
913
  size_t* ndis_out,
850
914
  size_t* nlist_out,
851
- const NormTableScaler* scaler) const {
915
+ const NormTableScaler* scaler,
916
+ const IVFSearchParameters* params) const {
852
917
  size_t dim12 = ksub * M2;
853
918
  AlignedTable<uint8_t> dis_tables;
854
919
  AlignedTable<uint16_t> biases;
@@ -909,6 +974,7 @@ void IndexIVFFastScan::search_implem_10(
909
974
  ndis++;
910
975
  }
911
976
  }
977
+
912
978
  handler.end();
913
979
  *ndis_out = ndis;
914
980
  *nlist_out = nlist;
@@ -921,7 +987,8 @@ void IndexIVFFastScan::search_implem_12(
921
987
  const CoarseQuantized& cq,
922
988
  size_t* ndis_out,
923
989
  size_t* nlist_out,
924
- const NormTableScaler* scaler) const {
990
+ const NormTableScaler* scaler,
991
+ const IVFSearchParameters* params) const {
925
992
  if (n == 0) { // does not work well with reservoir
926
993
  return;
927
994
  }
@@ -933,6 +1000,7 @@ void IndexIVFFastScan::search_implem_12(
933
1000
  std::unique_ptr<float[]> normalizers(new float[2 * n]);
934
1001
 
935
1002
  compute_LUT_uint8(n, x, cq, dis_tables, biases, normalizers.get());
1003
+
936
1004
  handler.begin(skip & 16 ? nullptr : normalizers.get());
937
1005
 
938
1006
  struct QC {
@@ -958,6 +1026,7 @@ void IndexIVFFastScan::search_implem_12(
958
1026
  return a.list_no < b.list_no;
959
1027
  });
960
1028
  }
1029
+
961
1030
  // prepare the result handlers
962
1031
 
963
1032
  int qbs2 = this->qbs2 ? this->qbs2 : 11;
@@ -1049,12 +1118,15 @@ void IndexIVFFastScan::search_implem_14(
1049
1118
  idx_t* labels,
1050
1119
  const CoarseQuantized& cq,
1051
1120
  int impl,
1052
- const NormTableScaler* scaler) const {
1121
+ const NormTableScaler* scaler,
1122
+ const IVFSearchParameters* params) const {
1053
1123
  if (n == 0) { // does not work well with reservoir
1054
1124
  return;
1055
1125
  }
1056
1126
  FAISS_THROW_IF_NOT(bbs == 32);
1057
1127
 
1128
+ const IDSelector* sel = params ? params->sel : nullptr;
1129
+
1058
1130
  size_t dim12 = ksub * M2;
1059
1131
  AlignedTable<uint8_t> dis_tables;
1060
1132
  AlignedTable<uint16_t> biases;
@@ -1157,7 +1229,7 @@ void IndexIVFFastScan::search_implem_14(
1157
1229
 
1158
1230
  // prepare the result handlers
1159
1231
  std::unique_ptr<SIMDResultHandlerToFloat> handler(make_knn_handler(
1160
- is_max, impl, n, k, local_dis.data(), local_idx.data()));
1232
+ is_max, impl, n, k, local_dis.data(), local_idx.data(), sel));
1161
1233
  handler->begin(normalizers.get());
1162
1234
 
1163
1235
  int qbs2 = this->qbs2 ? this->qbs2 : 11;
@@ -1167,6 +1239,7 @@ void IndexIVFFastScan::search_implem_14(
1167
1239
  tmp_bias.resize(qbs2);
1168
1240
  handler->dbias = tmp_bias.data();
1169
1241
  }
1242
+
1170
1243
  std::set<int> q_set;
1171
1244
  uint64_t t_copy_pack = 0, t_scan = 0;
1172
1245
  #pragma omp for schedule(dynamic)
@@ -148,7 +148,8 @@ struct IndexIVFFastScan : IndexIVF {
148
148
  float* distances,
149
149
  idx_t* labels,
150
150
  const CoarseQuantized& cq,
151
- const NormTableScaler* scaler) const;
151
+ const NormTableScaler* scaler,
152
+ const IVFSearchParameters* params = nullptr) const;
152
153
 
153
154
  void range_search_dispatch_implem(
154
155
  idx_t n,
@@ -156,7 +157,8 @@ struct IndexIVFFastScan : IndexIVF {
156
157
  float radius,
157
158
  RangeSearchResult& rres,
158
159
  const CoarseQuantized& cq_in,
159
- const NormTableScaler* scaler) const;
160
+ const NormTableScaler* scaler,
161
+ const IVFSearchParameters* params = nullptr) const;
160
162
 
161
163
  // impl 1 and 2 are just for verification
162
164
  template <class C>
@@ -167,7 +169,8 @@ struct IndexIVFFastScan : IndexIVF {
167
169
  float* distances,
168
170
  idx_t* labels,
169
171
  const CoarseQuantized& cq,
170
- const NormTableScaler* scaler) const;
172
+ const NormTableScaler* scaler,
173
+ const IVFSearchParameters* params = nullptr) const;
171
174
 
172
175
  template <class C>
173
176
  void search_implem_2(
@@ -177,7 +180,8 @@ struct IndexIVFFastScan : IndexIVF {
177
180
  float* distances,
178
181
  idx_t* labels,
179
182
  const CoarseQuantized& cq,
180
- const NormTableScaler* scaler) const;
183
+ const NormTableScaler* scaler,
184
+ const IVFSearchParameters* params = nullptr) const;
181
185
 
182
186
  // implem 10 and 12 are not multithreaded internally, so
183
187
  // export search stats
@@ -188,7 +192,8 @@ struct IndexIVFFastScan : IndexIVF {
188
192
  const CoarseQuantized& cq,
189
193
  size_t* ndis_out,
190
194
  size_t* nlist_out,
191
- const NormTableScaler* scaler) const;
195
+ const NormTableScaler* scaler,
196
+ const IVFSearchParameters* params = nullptr) const;
192
197
 
193
198
  void search_implem_12(
194
199
  idx_t n,
@@ -197,7 +202,8 @@ struct IndexIVFFastScan : IndexIVF {
197
202
  const CoarseQuantized& cq,
198
203
  size_t* ndis_out,
199
204
  size_t* nlist_out,
200
- const NormTableScaler* scaler) const;
205
+ const NormTableScaler* scaler,
206
+ const IVFSearchParameters* params = nullptr) const;
201
207
 
202
208
  // implem 14 is multithreaded internally across nprobes and queries
203
209
  void search_implem_14(
@@ -208,7 +214,8 @@ struct IndexIVFFastScan : IndexIVF {
208
214
  idx_t* labels,
209
215
  const CoarseQuantized& cq,
210
216
  int impl,
211
- const NormTableScaler* scaler) const;
217
+ const NormTableScaler* scaler,
218
+ const IVFSearchParameters* params = nullptr) const;
212
219
 
213
220
  // reconstruct vectors from packed invlists
214
221
  void reconstruct_from_offset(int64_t list_no, int64_t offset, float* recons)
@@ -318,16 +318,14 @@ void IndexIVFPQ::reconstruct_from_offset(
318
318
  float* recons) const {
319
319
  const uint8_t* code = invlists->get_single_code(list_no, offset);
320
320
 
321
+ pq.decode(code, recons);
321
322
  if (by_residual) {
322
323
  std::vector<float> centroid(d);
323
324
  quantizer->reconstruct(list_no, centroid.data());
324
325
 
325
- pq.decode(code, recons);
326
326
  for (int i = 0; i < d; ++i) {
327
327
  recons[i] += centroid[i];
328
328
  }
329
- } else {
330
- pq.decode(code, recons);
331
329
  }
332
330
  }
333
331
 
@@ -286,9 +286,28 @@ void IndexIVFPQFastScan::compute_LUT(
286
286
  }
287
287
  }
288
288
 
289
- void IndexIVFPQFastScan::sa_decode(idx_t n, const uint8_t* bytes, float* x)
289
+ void IndexIVFPQFastScan::sa_decode(idx_t n, const uint8_t* codes, float* x)
290
290
  const {
291
- pq.decode(bytes, x, n);
291
+ size_t coarse_size = coarse_code_size();
292
+
293
+ #pragma omp parallel if (n > 1)
294
+ {
295
+ std::vector<float> residual(d);
296
+
297
+ #pragma omp for
298
+ for (idx_t i = 0; i < n; i++) {
299
+ const uint8_t* code = codes + i * (code_size + coarse_size);
300
+ int64_t list_no = decode_listno(code);
301
+ float* xi = x + i * d;
302
+ pq.decode(code + coarse_size, xi);
303
+ if (by_residual) {
304
+ quantizer->reconstruct(list_no, residual.data());
305
+ for (size_t j = 0; j < d; j++) {
306
+ xi[j] += residual[j];
307
+ }
308
+ }
309
+ }
310
+ }
292
311
  }
293
312
 
294
313
  } // namespace faiss
@@ -15,7 +15,7 @@
15
15
  namespace faiss {
16
16
 
17
17
  IndexLattice::IndexLattice(idx_t d, int nsq, int scale_nbit, int r2)
18
- : Index(d),
18
+ : IndexFlatCodes(0, d, METRIC_L2),
19
19
  nsq(nsq),
20
20
  dsq(d / nsq),
21
21
  zn_sphere_codec(dsq, r2),
@@ -114,22 +114,4 @@ void IndexLattice::sa_decode(idx_t n, const uint8_t* codes, float* x) const {
114
114
  }
115
115
  }
116
116
 
117
- void IndexLattice::add(idx_t, const float*) {
118
- FAISS_THROW_MSG("not implemented");
119
- }
120
-
121
- void IndexLattice::search(
122
- idx_t,
123
- const float*,
124
- idx_t,
125
- float*,
126
- idx_t*,
127
- const SearchParameters*) const {
128
- FAISS_THROW_MSG("not implemented");
129
- }
130
-
131
- void IndexLattice::reset() {
132
- FAISS_THROW_MSG("not implemented");
133
- }
134
-
135
117
  } // namespace faiss
@@ -5,21 +5,18 @@
5
5
  * LICENSE file in the root directory of this source tree.
6
6
  */
7
7
 
8
- // -*- c++ -*-
9
-
10
- #ifndef FAISS_INDEX_LATTICE_H
11
- #define FAISS_INDEX_LATTICE_H
8
+ #pragma once
12
9
 
13
10
  #include <vector>
14
11
 
15
- #include <faiss/IndexIVF.h>
12
+ #include <faiss/IndexFlatCodes.h>
16
13
  #include <faiss/impl/lattice_Zn.h>
17
14
 
18
15
  namespace faiss {
19
16
 
20
17
  /** Index that encodes a vector with a series of Zn lattice quantizers
21
18
  */
22
- struct IndexLattice : Index {
19
+ struct IndexLattice : IndexFlatCodes {
23
20
  /// number of sub-vectors
24
21
  int nsq;
25
22
  /// dimension of sub-vectors
@@ -30,8 +27,6 @@ struct IndexLattice : Index {
30
27
 
31
28
  /// nb bits used to encode the scale, per subvector
32
29
  int scale_nbit, lattice_nbit;
33
- /// total, in bytes
34
- size_t code_size;
35
30
 
36
31
  /// mins and maxes of the vector norms, per subquantizer
37
32
  std::vector<float> trained;
@@ -46,20 +41,6 @@ struct IndexLattice : Index {
46
41
  void sa_encode(idx_t n, const float* x, uint8_t* bytes) const override;
47
42
 
48
43
  void sa_decode(idx_t n, const uint8_t* bytes, float* x) const override;
49
-
50
- /// not implemented
51
- void add(idx_t n, const float* x) override;
52
- void search(
53
- idx_t n,
54
- const float* x,
55
- idx_t k,
56
- float* distances,
57
- idx_t* labels,
58
- const SearchParameters* params = nullptr) const override;
59
-
60
- void reset() override;
61
44
  };
62
45
 
63
46
  } // namespace faiss
64
-
65
- #endif