faiss 0.5.2 → 0.6.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (169) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +16 -0
  3. data/LICENSE.txt +1 -1
  4. data/ext/faiss/ext.cpp +1 -1
  5. data/ext/faiss/extconf.rb +5 -6
  6. data/ext/faiss/index_binary.cpp +76 -17
  7. data/ext/faiss/{index.cpp → index_rb.cpp} +108 -35
  8. data/ext/faiss/kmeans.cpp +12 -9
  9. data/ext/faiss/numo.hpp +11 -9
  10. data/ext/faiss/pca_matrix.cpp +10 -8
  11. data/ext/faiss/product_quantizer.cpp +14 -12
  12. data/ext/faiss/{utils.cpp → utils_rb.cpp} +10 -3
  13. data/ext/faiss/{utils.h → utils_rb.h} +6 -0
  14. data/lib/faiss/version.rb +1 -1
  15. data/lib/faiss.rb +1 -1
  16. data/vendor/faiss/faiss/AutoTune.cpp +130 -11
  17. data/vendor/faiss/faiss/AutoTune.h +14 -1
  18. data/vendor/faiss/faiss/Clustering.cpp +59 -10
  19. data/vendor/faiss/faiss/Clustering.h +12 -0
  20. data/vendor/faiss/faiss/IVFlib.cpp +31 -28
  21. data/vendor/faiss/faiss/Index.cpp +20 -8
  22. data/vendor/faiss/faiss/Index.h +25 -3
  23. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +19 -24
  24. data/vendor/faiss/faiss/IndexBinary.cpp +1 -0
  25. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +9 -4
  26. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +45 -11
  27. data/vendor/faiss/faiss/IndexFastScan.cpp +35 -22
  28. data/vendor/faiss/faiss/IndexFastScan.h +10 -1
  29. data/vendor/faiss/faiss/IndexFlat.cpp +193 -136
  30. data/vendor/faiss/faiss/IndexFlat.h +16 -1
  31. data/vendor/faiss/faiss/IndexFlatCodes.cpp +46 -22
  32. data/vendor/faiss/faiss/IndexFlatCodes.h +7 -1
  33. data/vendor/faiss/faiss/IndexHNSW.cpp +24 -50
  34. data/vendor/faiss/faiss/IndexHNSW.h +14 -12
  35. data/vendor/faiss/faiss/IndexIDMap.cpp +1 -1
  36. data/vendor/faiss/faiss/IndexIVF.cpp +76 -49
  37. data/vendor/faiss/faiss/IndexIVF.h +14 -4
  38. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +11 -8
  39. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +2 -2
  40. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +25 -14
  41. data/vendor/faiss/faiss/IndexIVFFastScan.h +26 -22
  42. data/vendor/faiss/faiss/IndexIVFFlat.cpp +10 -61
  43. data/vendor/faiss/faiss/IndexIVFFlatPanorama.cpp +39 -111
  44. data/vendor/faiss/faiss/IndexIVFPQ.cpp +89 -147
  45. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +37 -5
  46. data/vendor/faiss/faiss/IndexIVFPQR.cpp +2 -1
  47. data/vendor/faiss/faiss/IndexIVFRaBitQ.cpp +42 -30
  48. data/vendor/faiss/faiss/IndexIVFRaBitQ.h +2 -2
  49. data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.cpp +246 -97
  50. data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.h +32 -29
  51. data/vendor/faiss/faiss/IndexLSH.cpp +8 -6
  52. data/vendor/faiss/faiss/IndexLattice.cpp +29 -24
  53. data/vendor/faiss/faiss/IndexNNDescent.cpp +1 -0
  54. data/vendor/faiss/faiss/IndexNSG.cpp +2 -1
  55. data/vendor/faiss/faiss/IndexNSG.h +0 -2
  56. data/vendor/faiss/faiss/IndexNeuralNetCodec.cpp +1 -1
  57. data/vendor/faiss/faiss/IndexPQ.cpp +19 -10
  58. data/vendor/faiss/faiss/IndexRaBitQ.cpp +26 -13
  59. data/vendor/faiss/faiss/IndexRaBitQ.h +2 -2
  60. data/vendor/faiss/faiss/IndexRaBitQFastScan.cpp +132 -78
  61. data/vendor/faiss/faiss/IndexRaBitQFastScan.h +14 -12
  62. data/vendor/faiss/faiss/IndexRefine.cpp +0 -30
  63. data/vendor/faiss/faiss/IndexShards.cpp +3 -4
  64. data/vendor/faiss/faiss/MetricType.h +16 -0
  65. data/vendor/faiss/faiss/VectorTransform.cpp +120 -0
  66. data/vendor/faiss/faiss/VectorTransform.h +23 -0
  67. data/vendor/faiss/faiss/clone_index.cpp +7 -4
  68. data/vendor/faiss/faiss/{cppcontrib/factory_tools.cpp → factory_tools.cpp} +1 -1
  69. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +1 -1
  70. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +37 -11
  71. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +0 -28
  72. data/vendor/faiss/faiss/impl/ClusteringInitialization.cpp +367 -0
  73. data/vendor/faiss/faiss/impl/ClusteringInitialization.h +107 -0
  74. data/vendor/faiss/faiss/impl/CodePacker.cpp +4 -0
  75. data/vendor/faiss/faiss/impl/CodePacker.h +11 -3
  76. data/vendor/faiss/faiss/impl/CodePackerRaBitQ.cpp +83 -0
  77. data/vendor/faiss/faiss/impl/CodePackerRaBitQ.h +47 -0
  78. data/vendor/faiss/faiss/impl/FaissAssert.h +60 -2
  79. data/vendor/faiss/faiss/impl/HNSW.cpp +25 -34
  80. data/vendor/faiss/faiss/impl/HNSW.h +8 -6
  81. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +34 -27
  82. data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -1
  83. data/vendor/faiss/faiss/impl/NSG.cpp +6 -5
  84. data/vendor/faiss/faiss/impl/NSG.h +17 -7
  85. data/vendor/faiss/faiss/impl/Panorama.cpp +53 -46
  86. data/vendor/faiss/faiss/impl/Panorama.h +22 -6
  87. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +16 -5
  88. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +70 -58
  89. data/vendor/faiss/faiss/impl/RaBitQUtils.cpp +92 -0
  90. data/vendor/faiss/faiss/impl/RaBitQUtils.h +93 -31
  91. data/vendor/faiss/faiss/impl/RaBitQuantizer.cpp +12 -28
  92. data/vendor/faiss/faiss/impl/RaBitQuantizer.h +3 -10
  93. data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.cpp +15 -41
  94. data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.h +0 -4
  95. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +14 -9
  96. data/vendor/faiss/faiss/impl/ResultHandler.h +131 -50
  97. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +67 -2358
  98. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +0 -2
  99. data/vendor/faiss/faiss/impl/VisitedTable.cpp +42 -0
  100. data/vendor/faiss/faiss/impl/VisitedTable.h +69 -0
  101. data/vendor/faiss/faiss/impl/expanded_scanners.h +158 -0
  102. data/vendor/faiss/faiss/impl/index_read.cpp +829 -471
  103. data/vendor/faiss/faiss/impl/index_read_utils.h +0 -1
  104. data/vendor/faiss/faiss/impl/index_write.cpp +17 -8
  105. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +47 -20
  106. data/vendor/faiss/faiss/impl/mapped_io.cpp +9 -2
  107. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +7 -2
  108. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +11 -3
  109. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +19 -13
  110. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +29 -21
  111. data/vendor/faiss/faiss/impl/{code_distance/code_distance-avx2.h → pq_code_distance/pq_code_distance-avx2.cpp} +42 -215
  112. data/vendor/faiss/faiss/impl/{code_distance/code_distance-avx512.h → pq_code_distance/pq_code_distance-avx512.cpp} +68 -107
  113. data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-generic.cpp +141 -0
  114. data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-inl.h +23 -0
  115. data/vendor/faiss/faiss/impl/{code_distance/code_distance-sve.h → pq_code_distance/pq_code_distance-sve.cpp} +57 -144
  116. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +9 -6
  117. data/vendor/faiss/faiss/impl/scalar_quantizer/codecs.h +121 -0
  118. data/vendor/faiss/faiss/impl/scalar_quantizer/distance_computers.h +136 -0
  119. data/vendor/faiss/faiss/impl/scalar_quantizer/quantizers.h +280 -0
  120. data/vendor/faiss/faiss/impl/scalar_quantizer/scanners.h +164 -0
  121. data/vendor/faiss/faiss/impl/scalar_quantizer/similarities.h +94 -0
  122. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx2.cpp +455 -0
  123. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx512.cpp +430 -0
  124. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-dispatch.h +329 -0
  125. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-neon.cpp +467 -0
  126. data/vendor/faiss/faiss/impl/scalar_quantizer/training.cpp +203 -0
  127. data/vendor/faiss/faiss/impl/scalar_quantizer/training.h +42 -0
  128. data/vendor/faiss/faiss/impl/simd_dispatch.h +139 -0
  129. data/vendor/faiss/faiss/impl/simd_result_handlers.h +18 -18
  130. data/vendor/faiss/faiss/index_factory.cpp +35 -16
  131. data/vendor/faiss/faiss/index_io.h +29 -3
  132. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +7 -4
  133. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +1 -1
  134. data/vendor/faiss/faiss/svs/IndexSVSFaissUtils.h +9 -19
  135. data/vendor/faiss/faiss/svs/IndexSVSFlat.h +2 -0
  136. data/vendor/faiss/faiss/svs/IndexSVSVamana.h +2 -1
  137. data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.cpp +9 -1
  138. data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.h +9 -0
  139. data/vendor/faiss/faiss/utils/Heap.cpp +46 -0
  140. data/vendor/faiss/faiss/utils/Heap.h +21 -0
  141. data/vendor/faiss/faiss/utils/NeuralNet.cpp +10 -7
  142. data/vendor/faiss/faiss/utils/distances.cpp +141 -23
  143. data/vendor/faiss/faiss/utils/distances.h +98 -0
  144. data/vendor/faiss/faiss/utils/distances_dispatch.h +170 -0
  145. data/vendor/faiss/faiss/utils/distances_simd.cpp +74 -3511
  146. data/vendor/faiss/faiss/utils/extra_distances-inl.h +164 -157
  147. data/vendor/faiss/faiss/utils/extra_distances.cpp +52 -95
  148. data/vendor/faiss/faiss/utils/extra_distances.h +47 -1
  149. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +0 -1
  150. data/vendor/faiss/faiss/utils/partitioning.cpp +1 -1
  151. data/vendor/faiss/faiss/utils/pq_code_distance.h +251 -0
  152. data/vendor/faiss/faiss/utils/rabitq_simd.h +260 -0
  153. data/vendor/faiss/faiss/utils/simd_impl/distances_aarch64.cpp +150 -0
  154. data/vendor/faiss/faiss/utils/simd_impl/distances_arm_sve.cpp +568 -0
  155. data/vendor/faiss/faiss/utils/simd_impl/distances_autovec-inl.h +153 -0
  156. data/vendor/faiss/faiss/utils/simd_impl/distances_avx2.cpp +1185 -0
  157. data/vendor/faiss/faiss/utils/simd_impl/distances_avx512.cpp +1092 -0
  158. data/vendor/faiss/faiss/utils/simd_impl/distances_sse-inl.h +391 -0
  159. data/vendor/faiss/faiss/utils/simd_levels.cpp +322 -0
  160. data/vendor/faiss/faiss/utils/simd_levels.h +91 -0
  161. data/vendor/faiss/faiss/utils/simdlib_avx2.h +12 -1
  162. data/vendor/faiss/faiss/utils/simdlib_avx512.h +69 -0
  163. data/vendor/faiss/faiss/utils/simdlib_neon.h +6 -0
  164. data/vendor/faiss/faiss/utils/sorting.cpp +4 -4
  165. data/vendor/faiss/faiss/utils/utils.cpp +16 -9
  166. metadata +47 -18
  167. data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +0 -81
  168. data/vendor/faiss/faiss/impl/code_distance/code_distance.h +0 -186
  169. /data/vendor/faiss/faiss/{cppcontrib/factory_tools.h → factory_tools.h} +0 -0
@@ -8,13 +8,11 @@
8
8
  #include <faiss/IndexAdditiveQuantizer.h>
9
9
 
10
10
  #include <algorithm>
11
- #include <cmath>
12
11
  #include <cstring>
13
12
 
14
13
  #include <faiss/impl/FaissAssert.h>
15
14
  #include <faiss/impl/ResidualQuantizer.h>
16
15
  #include <faiss/impl/ResultHandler.h>
17
- #include <faiss/utils/distances.h>
18
16
  #include <faiss/utils/extra_distances.h>
19
17
 
20
18
  namespace faiss {
@@ -189,17 +187,14 @@ void search_with_LUT(
189
187
  FlatCodesDistanceComputer* IndexAdditiveQuantizer::
190
188
  get_FlatCodesDistanceComputer() const {
191
189
  if (aq->search_type == AdditiveQuantizer::ST_decompress) {
192
- if (metric_type == METRIC_L2) {
193
- using VD = VectorDistance<METRIC_L2>;
194
- VD vd = {size_t(d), metric_arg};
195
- return new AQDistanceComputerDecompress<VD>(*this, vd);
196
- } else if (metric_type == METRIC_INNER_PRODUCT) {
197
- using VD = VectorDistance<METRIC_INNER_PRODUCT>;
198
- VD vd = {size_t(d), metric_arg};
199
- return new AQDistanceComputerDecompress<VD>(*this, vd);
200
- } else {
201
- FAISS_THROW_MSG("unsupported metric");
202
- }
190
+ return with_VectorDistance(
191
+ d,
192
+ metric_type,
193
+ metric_arg,
194
+ [&](auto vd) -> FlatCodesDistanceComputer* {
195
+ return new AQDistanceComputerDecompress<decltype(vd)>(
196
+ *this, vd);
197
+ });
203
198
  } else {
204
199
  if (metric_type == METRIC_INNER_PRODUCT) {
205
200
  return new AQDistanceComputerLUT<
@@ -242,17 +237,17 @@ void IndexAdditiveQuantizer::search(
242
237
  !params, "search params not supported for this index");
243
238
 
244
239
  if (aq->search_type == AdditiveQuantizer::ST_decompress) {
245
- if (metric_type == METRIC_L2) {
246
- using VD = VectorDistance<METRIC_L2>;
247
- VD vd = {size_t(d), metric_arg};
248
- HeapBlockResultHandler<VD::C> rh(n, distances, labels, k);
249
- search_with_decompress(*this, x, vd, rh);
250
- } else if (metric_type == METRIC_INNER_PRODUCT) {
251
- using VD = VectorDistance<METRIC_INNER_PRODUCT>;
252
- VD vd = {size_t(d), metric_arg};
253
- HeapBlockResultHandler<VD::C> rh(n, distances, labels, k);
254
- search_with_decompress(*this, x, vd, rh);
255
- }
240
+ with_VectorDistance(d, metric_type, metric_arg, [&](auto vd) {
241
+ if constexpr (decltype(vd)::is_similarity) {
242
+ HeapBlockResultHandler<CMin<float, idx_t>> rh(
243
+ n, distances, labels, k);
244
+ search_with_decompress(*this, x, vd, rh);
245
+ } else {
246
+ HeapBlockResultHandler<CMax<float, idx_t>> rh(
247
+ n, distances, labels, k);
248
+ search_with_decompress(*this, x, vd, rh);
249
+ }
250
+ });
256
251
  } else {
257
252
  if (metric_type == METRIC_INNER_PRODUCT) {
258
253
  HeapBlockResultHandler<CMin<float, idx_t>> rh(
@@ -12,6 +12,7 @@
12
12
 
13
13
  #include <cinttypes>
14
14
  #include <cstring>
15
+ #include <typeinfo>
15
16
 
16
17
  namespace faiss {
17
18
 
@@ -22,6 +22,7 @@
22
22
  #include <faiss/impl/DistanceComputer.h>
23
23
  #include <faiss/impl/FaissAssert.h>
24
24
  #include <faiss/impl/ResultHandler.h>
25
+ #include <faiss/impl/VisitedTable.h>
25
26
  #include <faiss/utils/Heap.h>
26
27
  #include <faiss/utils/hamming.h>
27
28
  #include <faiss/utils/random.h>
@@ -205,10 +206,14 @@ void IndexBinaryHNSW::search(
205
206
  idx_t k,
206
207
  int32_t* distances,
207
208
  idx_t* labels,
208
- const SearchParameters* params) const {
209
- FAISS_THROW_IF_NOT_MSG(
210
- !params, "search params not supported for this index");
209
+ const SearchParameters* params_in) const {
211
210
  FAISS_THROW_IF_NOT(k > 0);
211
+ const SearchParametersHNSW* params = nullptr;
212
+ if (params_in) {
213
+ params = dynamic_cast<const SearchParametersHNSW*>(params_in);
214
+ FAISS_THROW_IF_NOT_MSG(
215
+ params, "IndexBinaryHNSW params have incorrect type");
216
+ }
212
217
 
213
218
  // we use the buffer for distances as float but convert them back
214
219
  // to int in the end
@@ -231,7 +236,7 @@ void IndexBinaryHNSW::search(
231
236
  // as the index parameter. This state does not get used in the
232
237
  // search function, as it is merely there to to enable Panorama
233
238
  // execution for IndexHNSWFlatPanorama.
234
- hnsw.search(*dis, nullptr, res, vt);
239
+ hnsw.search(*dis, nullptr, res, vt, params_in);
235
240
  res.end();
236
241
  }
237
242
  }
@@ -14,6 +14,7 @@
14
14
  #include <cstdio>
15
15
 
16
16
  #include <algorithm>
17
+ #include <limits>
17
18
  #include <memory>
18
19
 
19
20
  #include <faiss/IndexFlat.h>
@@ -120,25 +121,46 @@ void IndexBinaryIVF::search(
120
121
  idx_t k,
121
122
  int32_t* distances,
122
123
  idx_t* labels,
123
- const SearchParameters* params) const {
124
- FAISS_THROW_IF_NOT_MSG(
125
- !params, "search params not supported for this index");
124
+ const SearchParameters* params_in) const {
126
125
  FAISS_THROW_IF_NOT(k > 0);
126
+ const IVFSearchParameters* params = nullptr;
127
+ if (params_in) {
128
+ params = dynamic_cast<const IVFSearchParameters*>(params_in);
129
+ FAISS_THROW_IF_NOT_MSG(
130
+ params, "IndexBinaryIVF params have incorrect type");
131
+ FAISS_THROW_IF_MSG(
132
+ params->sel, "IDSelector is not supported for IndexBinaryIVF");
133
+ }
134
+ const size_t nprobe =
135
+ std::min(nlist, params ? params->nprobe : this->nprobe);
127
136
  FAISS_THROW_IF_NOT(nprobe > 0);
128
137
 
129
- const size_t nprobe_2 = std::min(nlist, this->nprobe);
130
- std::unique_ptr<idx_t[]> idx(new idx_t[n * nprobe_2]);
131
- std::unique_ptr<int32_t[]> coarse_dis(new int32_t[n * nprobe_2]);
138
+ std::unique_ptr<idx_t[]> idx(new idx_t[n * nprobe]);
139
+ std::unique_ptr<int32_t[]> coarse_dis(new int32_t[n * nprobe]);
132
140
 
133
141
  double t0 = getmillisecs();
134
- quantizer->search(n, x, nprobe_2, coarse_dis.get(), idx.get());
142
+ quantizer->search(
143
+ n,
144
+ x,
145
+ nprobe,
146
+ coarse_dis.get(),
147
+ idx.get(),
148
+ params ? params->quantizer_params : nullptr);
135
149
  indexIVF_stats.quantization_time += getmillisecs() - t0;
136
150
 
137
151
  t0 = getmillisecs();
138
- invlists->prefetch_lists(idx.get(), n * nprobe_2);
152
+ invlists->prefetch_lists(idx.get(), n * nprobe);
139
153
 
140
154
  search_preassigned(
141
- n, x, k, idx.get(), coarse_dis.get(), distances, labels, false);
155
+ n,
156
+ x,
157
+ k,
158
+ idx.get(),
159
+ coarse_dis.get(),
160
+ distances,
161
+ labels,
162
+ false,
163
+ params);
142
164
  indexIVF_stats.search_time += getmillisecs() - t0;
143
165
  }
144
166
 
@@ -389,6 +411,10 @@ void search_knn_hamming_heap(
389
411
  idx_t nprobe = params ? params->nprobe : ivf->nprobe;
390
412
  nprobe = std::min((idx_t)ivf->nlist, nprobe);
391
413
  idx_t max_codes = params ? params->max_codes : ivf->max_codes;
414
+ const idx_t unlimited_list_size = std::numeric_limits<idx_t>::max();
415
+ if (max_codes == 0) {
416
+ max_codes = unlimited_list_size;
417
+ }
392
418
  MetricType metric_type = ivf->metric_type;
393
419
 
394
420
  // almost verbatim copy from IndexIVF::search_preassigned
@@ -437,6 +463,10 @@ void search_knn_hamming_heap(
437
463
  nlistv++;
438
464
 
439
465
  size_t list_size = ivf->invlists->list_size(key);
466
+ size_t list_size_max = max_codes - nscan;
467
+ if (list_size > list_size_max) {
468
+ list_size = list_size_max;
469
+ }
440
470
  InvertedLists::ScopedCodes scodes(ivf->invlists, key);
441
471
  std::unique_ptr<InvertedLists::ScopedIds> sids;
442
472
  const idx_t* ids = nullptr;
@@ -451,7 +481,7 @@ void search_knn_hamming_heap(
451
481
  list_size, scodes.get(), ids, simi, idxi, k);
452
482
 
453
483
  nscan += list_size;
454
- if (max_codes && nscan >= max_codes) {
484
+ if (nscan >= max_codes) {
455
485
  break;
456
486
  }
457
487
  }
@@ -525,6 +555,10 @@ void search_knn_hamming_count(
525
555
 
526
556
  nlistv++;
527
557
  size_t list_size = ivf->invlists->list_size(key);
558
+ size_t list_size_max = max_codes - nscan;
559
+ if (list_size > list_size_max) {
560
+ list_size = list_size_max;
561
+ }
528
562
  InvertedLists::ScopedCodes scodes(ivf->invlists, key);
529
563
  const uint8_t* list_vecs = scodes.get();
530
564
  const idx_t* ids =
@@ -541,7 +575,7 @@ void search_knn_hamming_count(
541
575
  }
542
576
 
543
577
  nscan += list_size;
544
- if (max_codes && nscan >= max_codes) {
578
+ if (nscan >= max_codes) {
545
579
  break;
546
580
  }
547
581
  }
@@ -20,11 +20,8 @@
20
20
  #include <faiss/impl/pq4_fast_scan.h>
21
21
  #include <faiss/impl/simd_result_handlers.h>
22
22
  #include <faiss/utils/hamming.h>
23
- #include <faiss/utils/utils.h>
24
-
25
- #include <faiss/impl/pq4_fast_scan.h>
26
- #include <faiss/impl/simd_result_handlers.h>
27
23
  #include <faiss/utils/quantize_lut.h>
24
+ #include <faiss/utils/utils.h>
28
25
 
29
26
  namespace faiss {
30
27
 
@@ -84,7 +81,8 @@ void IndexFastScan::add(idx_t n, const float* x) {
84
81
  compute_codes(tmp_codes.get(), n, x);
85
82
 
86
83
  ntotal2 = roundup(ntotal + n, bbs);
87
- size_t new_size = ntotal2 * M2 / 2; // assume nbits = 4
84
+ size_t n_blocks = ntotal2 / bbs;
85
+ size_t new_size = n_blocks * get_block_stride();
88
86
  size_t old_size = codes.size();
89
87
  if (new_size > old_size) {
90
88
  codes.resize(new_size);
@@ -92,7 +90,15 @@ void IndexFastScan::add(idx_t n, const float* x) {
92
90
  }
93
91
 
94
92
  pq4_pack_codes_range(
95
- tmp_codes.get(), M, ntotal, ntotal + n, bbs, M2, codes.get());
93
+ tmp_codes.get(),
94
+ M,
95
+ ntotal,
96
+ ntotal + n,
97
+ bbs,
98
+ M2,
99
+ codes.get(),
100
+ 0,
101
+ get_block_stride());
96
102
 
97
103
  ntotal += n;
98
104
  }
@@ -101,17 +107,25 @@ CodePacker* IndexFastScan::get_CodePacker() const {
101
107
  return new CodePackerPQ4(M, bbs);
102
108
  }
103
109
 
110
+ size_t IndexFastScan::get_block_stride() const {
111
+ std::unique_ptr<CodePacker> packer(get_CodePacker());
112
+ FAISS_THROW_IF_NOT_MSG(
113
+ packer->nvec == static_cast<size_t>(bbs),
114
+ "CodePacker must pack bbs vectors per block for fast-scan");
115
+ return packer->block_size;
116
+ }
117
+
104
118
  size_t IndexFastScan::remove_ids(const IDSelector& sel) {
105
119
  idx_t j = 0;
106
120
  std::vector<uint8_t> buffer(code_size);
107
- CodePackerPQ4 packer(M, bbs);
121
+ std::unique_ptr<CodePacker> packer(get_CodePacker());
108
122
  for (idx_t i = 0; i < ntotal; i++) {
109
123
  if (sel.is_member(i)) {
110
124
  // should be removed
111
125
  } else {
112
126
  if (i > j) {
113
- packer.unpack_1(codes.data(), i, buffer.data());
114
- packer.pack_1(buffer.data(), j, codes.data());
127
+ packer->unpack_1(codes.data(), i, buffer.data());
128
+ packer->pack_1(buffer.data(), j, codes.data());
115
129
  }
116
130
  j++;
117
131
  }
@@ -120,8 +134,7 @@ size_t IndexFastScan::remove_ids(const IDSelector& sel) {
120
134
  if (nremove > 0) {
121
135
  ntotal = j;
122
136
  ntotal2 = roundup(ntotal, bbs);
123
- size_t new_size = ntotal2 * M2 / 2;
124
- codes.resize(new_size);
137
+ codes.resize(ntotal2 / bbs * get_block_stride());
125
138
  }
126
139
  return nremove;
127
140
  }
@@ -143,13 +156,14 @@ void IndexFastScan::merge_from(Index& otherIndex, idx_t add_id) {
143
156
  check_compatible_for_merge(otherIndex);
144
157
  IndexFastScan* other = static_cast<IndexFastScan*>(&otherIndex);
145
158
  ntotal2 = roundup(ntotal + other->ntotal, bbs);
146
- codes.resize(ntotal2 * M2 / 2);
159
+ codes.resize(ntotal2 / bbs * get_block_stride());
147
160
  std::vector<uint8_t> buffer(code_size);
148
- CodePackerPQ4 packer(M, bbs);
161
+ std::unique_ptr<CodePacker> packer(get_CodePacker());
162
+ std::unique_ptr<CodePacker> other_packer(other->get_CodePacker());
149
163
 
150
164
  for (int i = 0; i < other->ntotal; i++) {
151
- packer.unpack_1(other->codes.data(), i, buffer.data());
152
- packer.pack_1(buffer.data(), ntotal + i, codes.data());
165
+ other_packer->unpack_1(other->codes.data(), i, buffer.data());
166
+ packer->pack_1(buffer.data(), ntotal + i, codes.data());
153
167
  }
154
168
  ntotal += other->ntotal;
155
169
  other->reset();
@@ -531,7 +545,8 @@ void IndexFastScan::search_implem_12(
531
545
  codes.get(),
532
546
  LUT.get(),
533
547
  *handler.get(),
534
- context.norm_scaler);
548
+ context.norm_scaler,
549
+ get_block_stride());
535
550
  }
536
551
  if (!(skip & 8)) {
537
552
  handler->end();
@@ -614,7 +629,8 @@ void IndexFastScan::search_implem_14(
614
629
  codes.get(),
615
630
  LUT.get(),
616
631
  *handler.get(),
617
- context.norm_scaler);
632
+ context.norm_scaler,
633
+ get_block_stride());
618
634
  }
619
635
  if (!(skip & 8)) {
620
636
  handler->end();
@@ -639,11 +655,8 @@ template void IndexFastScan::search_dispatch_implem<false>(
639
655
 
640
656
  void IndexFastScan::reconstruct(idx_t key, float* recons) const {
641
657
  std::vector<uint8_t> code(code_size, 0);
642
- BitstringWriter bsw(code.data(), code_size);
643
- for (size_t m = 0; m < M; m++) {
644
- uint8_t c = pq4_get_packed_element(codes.data(), bbs, M2, key, m);
645
- bsw.write(c, nbits);
646
- }
658
+ std::unique_ptr<CodePacker> packer(get_CodePacker());
659
+ packer->unpack_1(codes.data(), key, code.data());
647
660
  sa_decode(1, code.data(), recons);
648
661
  }
649
662
 
@@ -214,7 +214,16 @@ struct IndexFastScan : Index {
214
214
  *
215
215
  * @return pointer to the code packer
216
216
  */
217
- CodePacker* get_CodePacker() const;
217
+ virtual CodePacker* get_CodePacker() const;
218
+
219
+ /** Get stride in bytes between consecutive SIMD blocks.
220
+ *
221
+ * Derived from get_CodePacker()->block_size so that there is a
222
+ * single source of truth for the block layout.
223
+ *
224
+ * @return stride in bytes
225
+ */
226
+ size_t get_block_stride() const;
218
227
 
219
228
  /** Merge another index into this one
220
229
  *