faiss 0.4.1 → 0.4.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (106)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +4 -0
  3. data/lib/faiss/version.rb +1 -1
  4. data/vendor/faiss/faiss/AutoTune.cpp +39 -29
  5. data/vendor/faiss/faiss/Clustering.cpp +4 -2
  6. data/vendor/faiss/faiss/IVFlib.cpp +14 -7
  7. data/vendor/faiss/faiss/Index.h +72 -3
  8. data/vendor/faiss/faiss/Index2Layer.cpp +2 -4
  9. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +0 -1
  10. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +1 -0
  11. data/vendor/faiss/faiss/IndexBinary.h +46 -3
  12. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +118 -4
  13. data/vendor/faiss/faiss/IndexBinaryHNSW.h +41 -0
  14. data/vendor/faiss/faiss/IndexBinaryHash.cpp +0 -1
  15. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +18 -7
  16. data/vendor/faiss/faiss/IndexBinaryIVF.h +5 -1
  17. data/vendor/faiss/faiss/IndexFlat.cpp +6 -4
  18. data/vendor/faiss/faiss/IndexHNSW.cpp +65 -24
  19. data/vendor/faiss/faiss/IndexHNSW.h +10 -1
  20. data/vendor/faiss/faiss/IndexIDMap.cpp +96 -18
  21. data/vendor/faiss/faiss/IndexIDMap.h +20 -0
  22. data/vendor/faiss/faiss/IndexIVF.cpp +28 -10
  23. data/vendor/faiss/faiss/IndexIVF.h +16 -1
  24. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +84 -16
  25. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +18 -6
  26. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +33 -21
  27. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +16 -6
  28. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +24 -15
  29. data/vendor/faiss/faiss/IndexIVFFastScan.h +4 -2
  30. data/vendor/faiss/faiss/IndexIVFFlat.cpp +59 -43
  31. data/vendor/faiss/faiss/IndexIVFFlat.h +10 -2
  32. data/vendor/faiss/faiss/IndexIVFPQ.cpp +16 -3
  33. data/vendor/faiss/faiss/IndexIVFPQ.h +8 -1
  34. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +14 -6
  35. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +2 -1
  36. data/vendor/faiss/faiss/IndexIVFPQR.cpp +14 -4
  37. data/vendor/faiss/faiss/IndexIVFPQR.h +2 -1
  38. data/vendor/faiss/faiss/IndexIVFRaBitQ.cpp +28 -3
  39. data/vendor/faiss/faiss/IndexIVFRaBitQ.h +8 -1
  40. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +9 -2
  41. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +2 -1
  42. data/vendor/faiss/faiss/IndexLattice.cpp +8 -4
  43. data/vendor/faiss/faiss/IndexNNDescent.cpp +0 -7
  44. data/vendor/faiss/faiss/IndexNSG.cpp +3 -3
  45. data/vendor/faiss/faiss/IndexPQ.cpp +0 -1
  46. data/vendor/faiss/faiss/IndexPQ.h +1 -0
  47. data/vendor/faiss/faiss/IndexPQFastScan.cpp +0 -2
  48. data/vendor/faiss/faiss/IndexPreTransform.cpp +4 -2
  49. data/vendor/faiss/faiss/IndexRefine.cpp +11 -6
  50. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +16 -4
  51. data/vendor/faiss/faiss/IndexScalarQuantizer.h +10 -3
  52. data/vendor/faiss/faiss/IndexShards.cpp +7 -6
  53. data/vendor/faiss/faiss/MatrixStats.cpp +16 -8
  54. data/vendor/faiss/faiss/MetaIndexes.cpp +12 -6
  55. data/vendor/faiss/faiss/MetricType.h +5 -3
  56. data/vendor/faiss/faiss/clone_index.cpp +2 -4
  57. data/vendor/faiss/faiss/cppcontrib/factory_tools.cpp +6 -0
  58. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +9 -4
  59. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +32 -10
  60. data/vendor/faiss/faiss/gpu/GpuIndex.h +88 -0
  61. data/vendor/faiss/faiss/gpu/GpuIndexBinaryCagra.h +125 -0
  62. data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +39 -4
  63. data/vendor/faiss/faiss/gpu/impl/IndexUtils.h +3 -3
  64. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +1 -1
  65. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +3 -2
  66. data/vendor/faiss/faiss/gpu/utils/CuvsFilterConvert.h +41 -0
  67. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +6 -3
  68. data/vendor/faiss/faiss/impl/HNSW.cpp +34 -19
  69. data/vendor/faiss/faiss/impl/IDSelector.cpp +2 -1
  70. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +2 -3
  71. data/vendor/faiss/faiss/impl/NNDescent.cpp +17 -9
  72. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +42 -21
  73. data/vendor/faiss/faiss/impl/RaBitQuantizer.cpp +6 -24
  74. data/vendor/faiss/faiss/impl/ResultHandler.h +56 -47
  75. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +28 -15
  76. data/vendor/faiss/faiss/impl/index_read.cpp +36 -11
  77. data/vendor/faiss/faiss/impl/index_write.cpp +19 -6
  78. data/vendor/faiss/faiss/impl/io.cpp +9 -5
  79. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +18 -11
  80. data/vendor/faiss/faiss/impl/mapped_io.cpp +4 -7
  81. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +0 -1
  82. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +0 -1
  83. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +6 -6
  84. data/vendor/faiss/faiss/impl/zerocopy_io.cpp +1 -1
  85. data/vendor/faiss/faiss/impl/zerocopy_io.h +2 -2
  86. data/vendor/faiss/faiss/index_factory.cpp +49 -33
  87. data/vendor/faiss/faiss/index_factory.h +8 -2
  88. data/vendor/faiss/faiss/index_io.h +0 -3
  89. data/vendor/faiss/faiss/invlists/DirectMap.cpp +2 -1
  90. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +12 -6
  91. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +8 -4
  92. data/vendor/faiss/faiss/utils/Heap.cpp +15 -8
  93. data/vendor/faiss/faiss/utils/Heap.h +23 -12
  94. data/vendor/faiss/faiss/utils/distances.cpp +42 -21
  95. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +2 -2
  96. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +1 -1
  97. data/vendor/faiss/faiss/utils/distances_simd.cpp +5 -3
  98. data/vendor/faiss/faiss/utils/extra_distances-inl.h +27 -4
  99. data/vendor/faiss/faiss/utils/extra_distances.cpp +8 -4
  100. data/vendor/faiss/faiss/utils/hamming.cpp +20 -10
  101. data/vendor/faiss/faiss/utils/partitioning.cpp +8 -4
  102. data/vendor/faiss/faiss/utils/quantize_lut.cpp +17 -9
  103. data/vendor/faiss/faiss/utils/rabitq_simd.h +539 -0
  104. data/vendor/faiss/faiss/utils/random.cpp +14 -7
  105. data/vendor/faiss/faiss/utils/utils.cpp +0 -3
  106. metadata +5 -2

data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp
@@ -18,7 +18,6 @@
  #include <faiss/impl/pq4_fast_scan.h>
  #include <faiss/invlists/BlockInvertedLists.h>
  #include <faiss/utils/distances.h>
- #include <faiss/utils/hamming.h>
  #include <faiss/utils/quantize_lut.h>
  #include <faiss/utils/utils.h>

@@ -34,10 +33,11 @@ IndexIVFAdditiveQuantizerFastScan::IndexIVFAdditiveQuantizerFastScan(
  size_t d,
  size_t nlist,
  MetricType metric,
- int bbs)
- : IndexIVFFastScan(quantizer, d, nlist, 0, metric) {
+ int bbs,
+ bool own_invlists)
+ : IndexIVFFastScan(quantizer, d, nlist, 0, metric, own_invlists) {
  if (aq != nullptr) {
- init(aq, nlist, metric, bbs);
+ init(aq, nlist, metric, bbs, own_invlists);
  }
  }

@@ -45,7 +45,8 @@ void IndexIVFAdditiveQuantizerFastScan::init(
  AdditiveQuantizer* aq,
  size_t nlist,
  MetricType metric,
- int bbs) {
+ int bbs,
+ bool own_invlists) {
  FAISS_THROW_IF_NOT(aq != nullptr);
  FAISS_THROW_IF_NOT(!aq->nbits.empty());
  FAISS_THROW_IF_NOT(aq->nbits[0] == 4);
@@ -66,7 +67,7 @@ void IndexIVFAdditiveQuantizerFastScan::init(
  } else {
  M = aq->M;
  }
- init_fastscan(aq, M, 4, nlist, metric, bbs);
+ init_fastscan(aq, M, 4, nlist, metric, bbs, own_invlists);

  max_train_points = 1024 * ksub * M;
  by_residual = true;
@@ -80,17 +81,20 @@ IndexIVFAdditiveQuantizerFastScan::IndexIVFAdditiveQuantizerFastScan(
  orig.d,
  orig.nlist,
  0,
- orig.metric_type),
+ orig.metric_type,
+ orig.own_invlists),
  aq(orig.aq) {
  FAISS_THROW_IF_NOT(
  metric_type == METRIC_INNER_PRODUCT || !orig.by_residual);

- init(aq, nlist, metric_type, bbs);
+ init(aq, nlist, metric_type, bbs, own_invlists);

  is_trained = orig.is_trained;
  ntotal = orig.ntotal;
  nprobe = orig.nprobe;
-
+ if (!orig.own_invlists) {
+ return; // skip packing codes below
+ }
  for (size_t i = 0; i < nlist; i++) {
  size_t nb = orig.invlists->list_size(i);
  size_t nb2 = roundup(nb, bbs);
@@ -448,17 +452,19 @@ IndexIVFLocalSearchQuantizerFastScan::IndexIVFLocalSearchQuantizerFastScan(
  size_t nbits,
  MetricType metric,
  Search_type_t search_type,
- int bbs)
+ int bbs,
+ bool own_invlists)
  : IndexIVFAdditiveQuantizerFastScan(
  quantizer,
  nullptr,
  d,
  nlist,
  metric,
- bbs),
+ bbs,
+ own_invlists),
  lsq(d, M, nbits, search_type) {
  FAISS_THROW_IF_NOT(nbits == 4);
- init(&lsq, nlist, metric, bbs);
+ init(&lsq, nlist, metric, bbs, own_invlists);
  }

  IndexIVFLocalSearchQuantizerFastScan::IndexIVFLocalSearchQuantizerFastScan() {
@@ -474,17 +480,19 @@ IndexIVFResidualQuantizerFastScan::IndexIVFResidualQuantizerFastScan(
  size_t nbits,
  MetricType metric,
  Search_type_t search_type,
- int bbs)
+ int bbs,
+ bool own_invlists)
  : IndexIVFAdditiveQuantizerFastScan(
  quantizer,
  nullptr,
  d,
  nlist,
  metric,
- bbs),
+ bbs,
+ own_invlists),
  rq(d, M, nbits, search_type) {
  FAISS_THROW_IF_NOT(nbits == 4);
- init(&rq, nlist, metric, bbs);
+ init(&rq, nlist, metric, bbs, own_invlists);
  }

  IndexIVFResidualQuantizerFastScan::IndexIVFResidualQuantizerFastScan() {
@@ -502,17 +510,19 @@ IndexIVFProductLocalSearchQuantizerFastScan::
  size_t nbits,
  MetricType metric,
  Search_type_t search_type,
- int bbs)
+ int bbs,
+ bool own_invlists)
  : IndexIVFAdditiveQuantizerFastScan(
  quantizer,
  nullptr,
  d,
  nlist,
  metric,
- bbs),
+ bbs,
+ own_invlists),
  plsq(d, nsplits, Msub, nbits, search_type) {
  FAISS_THROW_IF_NOT(nbits == 4);
- init(&plsq, nlist, metric, bbs);
+ init(&plsq, nlist, metric, bbs, own_invlists);
  }

  IndexIVFProductLocalSearchQuantizerFastScan::
@@ -531,17 +541,19 @@ IndexIVFProductResidualQuantizerFastScan::
  size_t nbits,
  MetricType metric,
  Search_type_t search_type,
- int bbs)
+ int bbs,
+ bool own_invlists)
  : IndexIVFAdditiveQuantizerFastScan(
  quantizer,
  nullptr,
  d,
  nlist,
  metric,
- bbs),
+ bbs,
+ own_invlists),
  prq(d, nsplits, Msub, nbits, search_type) {
  FAISS_THROW_IF_NOT(nbits == 4);
- init(&prq, nlist, metric, bbs);
+ init(&prq, nlist, metric, bbs, own_invlists);
  }

  IndexIVFProductResidualQuantizerFastScan::

data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h
@@ -50,9 +50,15 @@ struct IndexIVFAdditiveQuantizerFastScan : IndexIVFFastScan {
  size_t d,
  size_t nlist,
  MetricType metric = METRIC_L2,
- int bbs = 32);
+ int bbs = 32,
+ bool own_invlists = true);

- void init(AdditiveQuantizer* aq, size_t nlist, MetricType metric, int bbs);
+ void init(
+ AdditiveQuantizer* aq,
+ size_t nlist,
+ MetricType metric,
+ int bbs,
+ bool own_invlists);

  IndexIVFAdditiveQuantizerFastScan();

@@ -110,7 +116,8 @@ struct IndexIVFLocalSearchQuantizerFastScan
  size_t nbits,
  MetricType metric = METRIC_L2,
  Search_type_t search_type = AdditiveQuantizer::ST_norm_lsq2x4,
- int bbs = 32);
+ int bbs = 32,
+ bool own_invlists = true);

  IndexIVFLocalSearchQuantizerFastScan();
  };
@@ -126,7 +133,8 @@ struct IndexIVFResidualQuantizerFastScan : IndexIVFAdditiveQuantizerFastScan {
  size_t nbits,
  MetricType metric = METRIC_L2,
  Search_type_t search_type = AdditiveQuantizer::ST_norm_lsq2x4,
- int bbs = 32);
+ int bbs = 32,
+ bool own_invlists = true);

  IndexIVFResidualQuantizerFastScan();
  };
@@ -144,7 +152,8 @@ struct IndexIVFProductLocalSearchQuantizerFastScan
  size_t nbits,
  MetricType metric = METRIC_L2,
  Search_type_t search_type = AdditiveQuantizer::ST_norm_lsq2x4,
- int bbs = 32);
+ int bbs = 32,
+ bool own_invlists = true);

  IndexIVFProductLocalSearchQuantizerFastScan();
  };
@@ -162,7 +171,8 @@ struct IndexIVFProductResidualQuantizerFastScan
  size_t nbits,
  MetricType metric = METRIC_L2,
  Search_type_t search_type = AdditiveQuantizer::ST_norm_lsq2x4,
- int bbs = 32);
+ int bbs = 32,
+ bool own_invlists = true);

  IndexIVFProductResidualQuantizerFastScan();
  };

data/vendor/faiss/faiss/IndexIVFFastScan.cpp
@@ -8,7 +8,6 @@
  #include <faiss/IndexIVFFastScan.h>

  #include <cassert>
- #include <cinttypes>
  #include <cstdio>
  #include <set>

@@ -40,8 +39,9 @@ IndexIVFFastScan::IndexIVFFastScan(
  size_t d,
  size_t nlist,
  size_t code_size,
- MetricType metric)
- : IndexIVF(quantizer, d, nlist, code_size, metric) {
+ MetricType metric,
+ bool own_invlists)
+ : IndexIVF(quantizer, d, nlist, code_size, metric, own_invlists) {
  // unlike other indexes, we prefer no residuals for performance reasons.
  by_residual = false;
  FAISS_THROW_IF_NOT(metric == METRIC_L2 || metric == METRIC_INNER_PRODUCT);
@@ -60,7 +60,8 @@ void IndexIVFFastScan::init_fastscan(
  size_t nbits_init,
  size_t nlist,
  MetricType /* metric */,
- int bbs_2) {
+ int bbs_2,
+ bool own_invlists) {
  FAISS_THROW_IF_NOT(bbs_2 % 32 == 0);
  FAISS_THROW_IF_NOT(nbits_init == 4);
  FAISS_THROW_IF_NOT(fine_quantizer->d == d);
@@ -75,7 +76,9 @@ void IndexIVFFastScan::init_fastscan(
  FAISS_THROW_IF_NOT(code_size == fine_quantizer->code_size);

  is_trained = false;
- replace_invlists(new BlockInvertedLists(nlist, get_CodePacker()), true);
+ if (own_invlists) {
+ replace_invlists(new BlockInvertedLists(nlist, get_CodePacker()), true);
+ }
  }

  void IndexIVFFastScan::init_code_packer() {
@@ -793,11 +796,13 @@ void IndexIVFFastScan::search_implem_1(
  LUT = dis_tables.get() + (i * nprobe + j) * dim12;
  }
  idx_t list_no = cq.ids[i * nprobe + j];
- if (list_no < 0)
+ if (list_no < 0) {
  continue;
+ }
  size_t ls = orig_invlists->list_size(list_no);
- if (ls == 0)
+ if (ls == 0) {
  continue;
+ }
  InvertedLists::ScopedCodes codes(orig_invlists, list_no);
  InvertedLists::ScopedIds ids(orig_invlists, list_no);

@@ -815,7 +820,7 @@ void IndexIVFFastScan::search_implem_1(
  heap_ids,
  scaler);
  nlist_visited++;
- ndis++;
+ ndis += ls;
  }
  heap_reorder<C>(k, heap_dis, heap_ids);
  }
@@ -864,11 +869,13 @@ void IndexIVFFastScan::search_implem_2(
  LUT = dis_tables.get() + (i * nprobe + j) * dim12;
  }
  idx_t list_no = cq.ids[i * nprobe + j];
- if (list_no < 0)
+ if (list_no < 0) {
  continue;
+ }
  size_t ls = orig_invlists->list_size(list_no);
- if (ls == 0)
+ if (ls == 0) {
  continue;
+ }
  InvertedLists::ScopedCodes codes(orig_invlists, list_no);
  InvertedLists::ScopedIds ids(orig_invlists, list_no);

@@ -926,7 +933,7 @@ void IndexIVFFastScan::search_implem_10(

  bool single_LUT = !lookup_table_is_3d();

- size_t ndis = 0;
+ size_t ndis = 0, nlist_visited = 0;
  int qmap1[1];

  handler.q_map = qmap1;
@@ -974,13 +981,14 @@ void IndexIVFFastScan::search_implem_10(
  handler,
  scaler);

- ndis++;
+ ndis += ls;
+ nlist_visited++;
  }
  }

  handler.end();
  *ndis_out = ndis;
- *nlist_out = nlist;
+ *nlist_out = nlist_visited;
  }

  void IndexIVFFastScan::search_implem_12(
@@ -1040,7 +1048,7 @@ void IndexIVFFastScan::search_implem_12(
  handler.dbias = tmp_bias.data();
  }

- size_t ndis = 0;
+ size_t ndis = 0, nlist_visited = 0;

  size_t i0 = 0;
  uint64_t t_copy_pack = 0, t_scan = 0;
@@ -1062,6 +1070,7 @@ void IndexIVFFastScan::search_implem_12(
  i0 = i1;
  continue;
  }
+ nlist_visited++;

  // re-organize LUTs and biases into the right order
  int nc = i1 - i0;
@@ -1120,7 +1129,7 @@ void IndexIVFFastScan::search_implem_12(
  IVFFastScan_stats.t_scan += t_scan;

  *ndis_out = ndis;
- *nlist_out = nlist;
+ *nlist_out = nlist_visited;
  }

  void IndexIVFFastScan::search_implem_14(

data/vendor/faiss/faiss/IndexIVFFastScan.h
@@ -68,7 +68,8 @@ struct IndexIVFFastScan : IndexIVF {
  size_t d,
  size_t nlist,
  size_t code_size,
- MetricType metric = METRIC_L2);
+ MetricType metric = METRIC_L2,
+ bool own_invlists = true);

  IndexIVFFastScan();

@@ -79,7 +80,8 @@ struct IndexIVFFastScan : IndexIVF {
  size_t nbits,
  size_t nlist,
  MetricType metric,
- int bbs);
+ int bbs,
+ bool own_invlists);

  // initialize the CodePacker in the InvertedLists
  void init_code_packer();
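
Note: the new own_invlists flag threaded through the fast-scan constructors above controls whether init_fastscan() allocates the index's BlockInvertedLists itself. A minimal usage sketch, not part of this diff (dimensions and quantizer parameters below are made up for illustration), assuming the 0.4.2 vendored headers:

    #include <faiss/IndexFlat.h>
    #include <faiss/IndexIVFPQFastScan.h>
    #include <faiss/invlists/BlockInvertedLists.h>

    int main() {
        size_t d = 64, nlist = 128;
        faiss::IndexFlatL2 quantizer(d);

        // With own_invlists = false, init_fastscan() skips
        // replace_invlists(new BlockInvertedLists(...), true), so the index
        // starts without inverted-list storage of its own.
        faiss::IndexIVFPQFastScan index(
                &quantizer, d, nlist, /*M=*/8, /*nbits=*/4,
                faiss::METRIC_L2, /*bbs=*/32, /*own_invlists=*/false);

        // The caller attaches storage explicitly and lets the index own it,
        // mirroring what init_fastscan() does when own_invlists is true.
        index.replace_invlists(
                new faiss::BlockInvertedLists(nlist, index.get_CodePacker()),
                true);
        return 0;
    }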

data/vendor/faiss/faiss/IndexIVFFlat.cpp
@@ -21,6 +21,7 @@

  #include <faiss/impl/FaissAssert.h>
  #include <faiss/utils/distances.h>
+ #include <faiss/utils/extra_distances.h>
  #include <faiss/utils/utils.h>

  namespace faiss {
@@ -33,8 +34,15 @@ IndexIVFFlat::IndexIVFFlat(
  Index* quantizer,
  size_t d,
  size_t nlist,
- MetricType metric)
- : IndexIVF(quantizer, d, nlist, sizeof(float) * d, metric) {
+ MetricType metric,
+ bool own_invlists)
+ : IndexIVF(
+ quantizer,
+ d,
+ nlist,
+ sizeof(float) * d,
+ metric,
+ own_invlists) {
  code_size = sizeof(float) * d;
  by_residual = false;
  }
@@ -115,6 +123,18 @@ void IndexIVFFlat::encode_vectors(
  }
  }

+ void IndexIVFFlat::decode_vectors(
+ idx_t n,
+ const uint8_t* codes,
+ const idx_t* /*listnos*/,
+ float* x) const {
+ for (size_t i = 0; i < n; i++) {
+ const uint8_t* code = codes + i * code_size;
+ float* xi = x + i * d;
+ memcpy(xi, code, code_size);
+ }
+ }
+
  void IndexIVFFlat::sa_decode(idx_t n, const uint8_t* bytes, float* x) const {
  size_t coarse_size = coarse_code_size();
  for (size_t i = 0; i < n; i++) {
@@ -126,13 +146,18 @@ void IndexIVFFlat::sa_decode(idx_t n, const uint8_t* bytes, float* x) const {

  namespace {

- template <MetricType metric, class C, bool use_sel>
+ template <typename VectorDistance, bool use_sel>
  struct IVFFlatScanner : InvertedListScanner {
- size_t d;
-
- IVFFlatScanner(size_t d, bool store_pairs, const IDSelector* sel)
- : InvertedListScanner(store_pairs, sel), d(d) {
- keep_max = is_similarity_metric(metric);
+ VectorDistance vd;
+ using C = typename VectorDistance::C;
+
+ IVFFlatScanner(
+ const VectorDistance& vd,
+ bool store_pairs,
+ const IDSelector* sel)
+ : InvertedListScanner(store_pairs, sel), vd(vd) {
+ keep_max = vd.is_similarity;
+ code_size = vd.d * sizeof(float);
  }

  const float* xi;
@@ -146,10 +171,7 @@ struct IVFFlatScanner : InvertedListScanner {

  float distance_to_code(const uint8_t* code) const override {
  const float* yj = (float*)code;
- float dis = metric == METRIC_INNER_PRODUCT
- ? fvec_inner_product(xi, yj, d)
- : fvec_L2sqr(xi, yj, d);
- return dis;
+ return vd(xi, yj);
  }

  size_t scan_codes(
@@ -162,13 +184,11 @@ struct IVFFlatScanner : InvertedListScanner {
  const float* list_vecs = (const float*)codes;
  size_t nup = 0;
  for (size_t j = 0; j < list_size; j++) {
- const float* yj = list_vecs + d * j;
+ const float* yj = list_vecs + vd.d * j;
  if (use_sel && !sel->is_member(ids[j])) {
  continue;
  }
- float dis = metric == METRIC_INNER_PRODUCT
- ? fvec_inner_product(xi, yj, d)
- : fvec_L2sqr(xi, yj, d);
+ float dis = vd(xi, yj);
  if (C::cmp(simi[0], dis)) {
  int64_t id = store_pairs ? lo_build(list_no, j) : ids[j];
  heap_replace_top<C>(k, simi, idxi, dis, id);
@@ -186,13 +206,11 @@ struct IVFFlatScanner : InvertedListScanner {
  RangeQueryResult& res) const override {
  const float* list_vecs = (const float*)codes;
  for (size_t j = 0; j < list_size; j++) {
- const float* yj = list_vecs + d * j;
+ const float* yj = list_vecs + vd.d * j;
  if (use_sel && !sel->is_member(ids[j])) {
  continue;
  }
- float dis = metric == METRIC_INNER_PRODUCT
- ? fvec_inner_product(xi, yj, d)
- : fvec_L2sqr(xi, yj, d);
+ float dis = vd(xi, yj);
  if (C::cmp(radius, dis)) {
  int64_t id = store_pairs ? lo_build(list_no, j) : ids[j];
  res.add(dis, id);
@@ -201,23 +219,22 @@ struct IVFFlatScanner : InvertedListScanner {
  }
  };

- template <bool use_sel>
- InvertedListScanner* get_InvertedListScanner1(
- const IndexIVFFlat* ivf,
- bool store_pairs,
- const IDSelector* sel) {
- if (ivf->metric_type == METRIC_INNER_PRODUCT) {
- return new IVFFlatScanner<
- METRIC_INNER_PRODUCT,
- CMin<float, int64_t>,
- use_sel>(ivf->d, store_pairs, sel);
- } else if (ivf->metric_type == METRIC_L2) {
- return new IVFFlatScanner<METRIC_L2, CMax<float, int64_t>, use_sel>(
- ivf->d, store_pairs, sel);
- } else {
- FAISS_THROW_MSG("metric type not supported");
+ struct Run_get_InvertedListScanner {
+ using T = InvertedListScanner*;
+
+ template <class VD>
+ InvertedListScanner* f(
+ VD& vd,
+ const IndexIVFFlat* ivf,
+ bool store_pairs,
+ const IDSelector* sel) {
+ if (sel) {
+ return new IVFFlatScanner<VD, true>(vd, store_pairs, sel);
+ } else {
+ return new IVFFlatScanner<VD, false>(vd, store_pairs, sel);
+ }
  }
- }
+ };

  } // anonymous namespace

@@ -225,11 +242,9 @@ InvertedListScanner* IndexIVFFlat::get_InvertedListScanner(
  bool store_pairs,
  const IDSelector* sel,
  const IVFSearchParameters*) const {
- if (sel) {
- return get_InvertedListScanner1<true>(this, store_pairs, sel);
- } else {
- return get_InvertedListScanner1<false>(this, store_pairs, sel);
- }
+ Run_get_InvertedListScanner run;
+ return dispatch_VectorDistance(
+ d, metric_type, metric_arg, run, this, store_pairs, sel);
  }

  void IndexIVFFlat::reconstruct_from_offset(
@@ -247,8 +262,9 @@ IndexIVFFlatDedup::IndexIVFFlatDedup(
  Index* quantizer,
  size_t d,
  size_t nlist_,
- MetricType metric_type)
- : IndexIVFFlat(quantizer, d, nlist_, metric_type) {}
+ MetricType metric_type,
+ bool own_invlists)
+ : IndexIVFFlat(quantizer, d, nlist_, metric_type, own_invlists) {}

  void IndexIVFFlatDedup::train(idx_t n, const float* x) {
  std::unordered_map<uint64_t, idx_t> map;

data/vendor/faiss/faiss/IndexIVFFlat.h
@@ -26,7 +26,8 @@ struct IndexIVFFlat : IndexIVF {
  Index* quantizer,
  size_t d,
  size_t nlist_,
- MetricType = METRIC_L2);
+ MetricType = METRIC_L2,
+ bool own_invlists = true);

  void add_core(
  idx_t n,
@@ -42,6 +43,12 @@ struct IndexIVFFlat : IndexIVF {
  uint8_t* codes,
  bool include_listnos = false) const override;

+ void decode_vectors(
+ idx_t n,
+ const uint8_t* codes,
+ const idx_t* list_nos,
+ float* x) const override;
+
  InvertedListScanner* get_InvertedListScanner(
  bool store_pairs,
  const IDSelector* sel,
@@ -65,7 +72,8 @@ struct IndexIVFFlatDedup : IndexIVFFlat {
  Index* quantizer,
  size_t d,
  size_t nlist_,
- MetricType = METRIC_L2);
+ MetricType = METRIC_L2,
+ bool own_invlists = true);

  /// also dedups the training set
  void train(idx_t n, const float* x) override;
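
Note: decode_vectors() is a new override on IndexIVFFlat (IndexIVFPQ gains the same override further down); for the flat codec it simply copies each code back into a float vector and ignores the list numbers. A hedged sketch of calling it on one inverted list follows; the decode_list helper is hypothetical and not part of FAISS:

    #include <vector>
    #include <faiss/IndexIVFFlat.h>
    #include <faiss/invlists/InvertedLists.h>

    // Hypothetical helper: decode every stored vector of one inverted list.
    std::vector<float> decode_list(
            const faiss::IndexIVFFlat& index,
            size_t list_no) {
        size_t n = index.invlists->list_size(list_no);
        faiss::InvertedLists::ScopedCodes codes(index.invlists, list_no);
        // IndexIVFFlat ignores the list numbers, but they are part of the
        // decode_vectors() signature declared above.
        std::vector<faiss::idx_t> list_nos(n, (faiss::idx_t)list_no);
        std::vector<float> x(n * index.d);
        index.decode_vectors(n, codes.get(), list_nos.data(), x.data());
        return x;
    }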

data/vendor/faiss/faiss/IndexIVFPQ.cpp
@@ -46,10 +46,14 @@ IndexIVFPQ::IndexIVFPQ(
  size_t nlist,
  size_t M,
  size_t nbits_per_idx,
- MetricType metric)
- : IndexIVF(quantizer, d, nlist, 0, metric), pq(d, M, nbits_per_idx) {
+ MetricType metric,
+ bool own_invlists)
+ : IndexIVF(quantizer, d, nlist, 0, metric, own_invlists),
+ pq(d, M, nbits_per_idx) {
  code_size = pq.code_size;
- invlists->code_size = code_size;
+ if (own_invlists) {
+ invlists->code_size = code_size;
+ }
  is_trained = false;
  by_residual = true;
  use_precomputed_table = 0;
@@ -181,6 +185,14 @@ void IndexIVFPQ::encode_vectors(
  }
  }

+ void IndexIVFPQ::decode_vectors(
+ idx_t n,
+ const uint8_t* codes,
+ const idx_t* listnos,
+ float* x) const {
+ return decode_multiple(n, listnos, codes, x);
+ }
+
  void IndexIVFPQ::sa_decode(idx_t n, const uint8_t* codes, float* x) const {
  size_t coarse_size = coarse_code_size();

@@ -1201,6 +1213,7 @@ struct IVFPQScanner : IVFPQScannerT<idx_t, METRIC_TYPE, PQDecoder>,
  sel(sel) {
  this->store_pairs = store_pairs;
  this->keep_max = is_similarity_metric(METRIC_TYPE);
+ this->code_size = this->pq.code_size;
  }

  void set_query(const float* query) override {

data/vendor/faiss/faiss/IndexIVFPQ.h
@@ -56,7 +56,8 @@ struct IndexIVFPQ : IndexIVF {
  size_t nlist,
  size_t M,
  size_t nbits_per_idx,
- MetricType metric = METRIC_L2);
+ MetricType metric = METRIC_L2,
+ bool own_invlists = true);

  void encode_vectors(
  idx_t n,
@@ -65,6 +66,12 @@ struct IndexIVFPQ : IndexIVF {
  uint8_t* codes,
  bool include_listnos = false) const override;

+ void decode_vectors(
+ idx_t n,
+ const uint8_t* codes,
+ const idx_t* listnos,
+ float* x) const override;
+
  void sa_decode(idx_t n, const uint8_t* bytes, float* x) const override;

  void add_core(

data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp
@@ -8,7 +8,6 @@
  #include <faiss/IndexIVFPQFastScan.h>

  #include <cassert>
- #include <cinttypes>
  #include <cstdio>

  #include <memory>
@@ -38,11 +37,13 @@ IndexIVFPQFastScan::IndexIVFPQFastScan(
  size_t M,
  size_t nbits,
  MetricType metric,
- int bbs)
- : IndexIVFFastScan(quantizer, d, nlist, 0, metric), pq(d, M, nbits) {
+ int bbs,
+ bool own_invlists)
+ : IndexIVFFastScan(quantizer, d, nlist, 0, metric, own_invlists),
+ pq(d, M, nbits) {
  by_residual = false; // set to false by default because it's faster

- init_fastscan(&pq, M, nbits, nlist, metric, bbs);
+ init_fastscan(&pq, M, nbits, nlist, metric, bbs, own_invlists);
  }

  IndexIVFPQFastScan::IndexIVFPQFastScan() {
@@ -57,12 +58,19 @@ IndexIVFPQFastScan::IndexIVFPQFastScan(const IndexIVFPQ& orig, int bbs)
  orig.d,
  orig.nlist,
  orig.pq.code_size,
- orig.metric_type),
+ orig.metric_type,
+ orig.own_invlists),
  pq(orig.pq) {
  FAISS_THROW_IF_NOT(orig.pq.nbits == 4);

  init_fastscan(
- &pq, orig.pq.M, orig.pq.nbits, orig.nlist, orig.metric_type, bbs);
+ &pq,
+ orig.pq.M,
+ orig.pq.nbits,
+ orig.nlist,
+ orig.metric_type,
+ bbs,
+ orig.own_invlists);

  by_residual = orig.by_residual;
  ntotal = orig.ntotal;