faiss 0.5.2 → 0.6.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (169) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +16 -0
  3. data/LICENSE.txt +1 -1
  4. data/ext/faiss/ext.cpp +1 -1
  5. data/ext/faiss/extconf.rb +5 -6
  6. data/ext/faiss/index_binary.cpp +76 -17
  7. data/ext/faiss/{index.cpp → index_rb.cpp} +108 -35
  8. data/ext/faiss/kmeans.cpp +12 -9
  9. data/ext/faiss/numo.hpp +11 -9
  10. data/ext/faiss/pca_matrix.cpp +10 -8
  11. data/ext/faiss/product_quantizer.cpp +14 -12
  12. data/ext/faiss/{utils.cpp → utils_rb.cpp} +10 -3
  13. data/ext/faiss/{utils.h → utils_rb.h} +6 -0
  14. data/lib/faiss/version.rb +1 -1
  15. data/lib/faiss.rb +1 -1
  16. data/vendor/faiss/faiss/AutoTune.cpp +130 -11
  17. data/vendor/faiss/faiss/AutoTune.h +14 -1
  18. data/vendor/faiss/faiss/Clustering.cpp +59 -10
  19. data/vendor/faiss/faiss/Clustering.h +12 -0
  20. data/vendor/faiss/faiss/IVFlib.cpp +31 -28
  21. data/vendor/faiss/faiss/Index.cpp +20 -8
  22. data/vendor/faiss/faiss/Index.h +25 -3
  23. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +19 -24
  24. data/vendor/faiss/faiss/IndexBinary.cpp +1 -0
  25. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +9 -4
  26. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +45 -11
  27. data/vendor/faiss/faiss/IndexFastScan.cpp +35 -22
  28. data/vendor/faiss/faiss/IndexFastScan.h +10 -1
  29. data/vendor/faiss/faiss/IndexFlat.cpp +193 -136
  30. data/vendor/faiss/faiss/IndexFlat.h +16 -1
  31. data/vendor/faiss/faiss/IndexFlatCodes.cpp +46 -22
  32. data/vendor/faiss/faiss/IndexFlatCodes.h +7 -1
  33. data/vendor/faiss/faiss/IndexHNSW.cpp +24 -50
  34. data/vendor/faiss/faiss/IndexHNSW.h +14 -12
  35. data/vendor/faiss/faiss/IndexIDMap.cpp +1 -1
  36. data/vendor/faiss/faiss/IndexIVF.cpp +76 -49
  37. data/vendor/faiss/faiss/IndexIVF.h +14 -4
  38. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +11 -8
  39. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +2 -2
  40. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +25 -14
  41. data/vendor/faiss/faiss/IndexIVFFastScan.h +26 -22
  42. data/vendor/faiss/faiss/IndexIVFFlat.cpp +10 -61
  43. data/vendor/faiss/faiss/IndexIVFFlatPanorama.cpp +39 -111
  44. data/vendor/faiss/faiss/IndexIVFPQ.cpp +89 -147
  45. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +37 -5
  46. data/vendor/faiss/faiss/IndexIVFPQR.cpp +2 -1
  47. data/vendor/faiss/faiss/IndexIVFRaBitQ.cpp +42 -30
  48. data/vendor/faiss/faiss/IndexIVFRaBitQ.h +2 -2
  49. data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.cpp +246 -97
  50. data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.h +32 -29
  51. data/vendor/faiss/faiss/IndexLSH.cpp +8 -6
  52. data/vendor/faiss/faiss/IndexLattice.cpp +29 -24
  53. data/vendor/faiss/faiss/IndexNNDescent.cpp +1 -0
  54. data/vendor/faiss/faiss/IndexNSG.cpp +2 -1
  55. data/vendor/faiss/faiss/IndexNSG.h +0 -2
  56. data/vendor/faiss/faiss/IndexNeuralNetCodec.cpp +1 -1
  57. data/vendor/faiss/faiss/IndexPQ.cpp +19 -10
  58. data/vendor/faiss/faiss/IndexRaBitQ.cpp +26 -13
  59. data/vendor/faiss/faiss/IndexRaBitQ.h +2 -2
  60. data/vendor/faiss/faiss/IndexRaBitQFastScan.cpp +132 -78
  61. data/vendor/faiss/faiss/IndexRaBitQFastScan.h +14 -12
  62. data/vendor/faiss/faiss/IndexRefine.cpp +0 -30
  63. data/vendor/faiss/faiss/IndexShards.cpp +3 -4
  64. data/vendor/faiss/faiss/MetricType.h +16 -0
  65. data/vendor/faiss/faiss/VectorTransform.cpp +120 -0
  66. data/vendor/faiss/faiss/VectorTransform.h +23 -0
  67. data/vendor/faiss/faiss/clone_index.cpp +7 -4
  68. data/vendor/faiss/faiss/{cppcontrib/factory_tools.cpp → factory_tools.cpp} +1 -1
  69. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +1 -1
  70. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +37 -11
  71. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +0 -28
  72. data/vendor/faiss/faiss/impl/ClusteringInitialization.cpp +367 -0
  73. data/vendor/faiss/faiss/impl/ClusteringInitialization.h +107 -0
  74. data/vendor/faiss/faiss/impl/CodePacker.cpp +4 -0
  75. data/vendor/faiss/faiss/impl/CodePacker.h +11 -3
  76. data/vendor/faiss/faiss/impl/CodePackerRaBitQ.cpp +83 -0
  77. data/vendor/faiss/faiss/impl/CodePackerRaBitQ.h +47 -0
  78. data/vendor/faiss/faiss/impl/FaissAssert.h +60 -2
  79. data/vendor/faiss/faiss/impl/HNSW.cpp +25 -34
  80. data/vendor/faiss/faiss/impl/HNSW.h +8 -6
  81. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +34 -27
  82. data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -1
  83. data/vendor/faiss/faiss/impl/NSG.cpp +6 -5
  84. data/vendor/faiss/faiss/impl/NSG.h +17 -7
  85. data/vendor/faiss/faiss/impl/Panorama.cpp +53 -46
  86. data/vendor/faiss/faiss/impl/Panorama.h +22 -6
  87. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +16 -5
  88. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +70 -58
  89. data/vendor/faiss/faiss/impl/RaBitQUtils.cpp +92 -0
  90. data/vendor/faiss/faiss/impl/RaBitQUtils.h +93 -31
  91. data/vendor/faiss/faiss/impl/RaBitQuantizer.cpp +12 -28
  92. data/vendor/faiss/faiss/impl/RaBitQuantizer.h +3 -10
  93. data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.cpp +15 -41
  94. data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.h +0 -4
  95. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +14 -9
  96. data/vendor/faiss/faiss/impl/ResultHandler.h +131 -50
  97. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +67 -2358
  98. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +0 -2
  99. data/vendor/faiss/faiss/impl/VisitedTable.cpp +42 -0
  100. data/vendor/faiss/faiss/impl/VisitedTable.h +69 -0
  101. data/vendor/faiss/faiss/impl/expanded_scanners.h +158 -0
  102. data/vendor/faiss/faiss/impl/index_read.cpp +829 -471
  103. data/vendor/faiss/faiss/impl/index_read_utils.h +0 -1
  104. data/vendor/faiss/faiss/impl/index_write.cpp +17 -8
  105. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +47 -20
  106. data/vendor/faiss/faiss/impl/mapped_io.cpp +9 -2
  107. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +7 -2
  108. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +11 -3
  109. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +19 -13
  110. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +29 -21
  111. data/vendor/faiss/faiss/impl/{code_distance/code_distance-avx2.h → pq_code_distance/pq_code_distance-avx2.cpp} +42 -215
  112. data/vendor/faiss/faiss/impl/{code_distance/code_distance-avx512.h → pq_code_distance/pq_code_distance-avx512.cpp} +68 -107
  113. data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-generic.cpp +141 -0
  114. data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-inl.h +23 -0
  115. data/vendor/faiss/faiss/impl/{code_distance/code_distance-sve.h → pq_code_distance/pq_code_distance-sve.cpp} +57 -144
  116. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +9 -6
  117. data/vendor/faiss/faiss/impl/scalar_quantizer/codecs.h +121 -0
  118. data/vendor/faiss/faiss/impl/scalar_quantizer/distance_computers.h +136 -0
  119. data/vendor/faiss/faiss/impl/scalar_quantizer/quantizers.h +280 -0
  120. data/vendor/faiss/faiss/impl/scalar_quantizer/scanners.h +164 -0
  121. data/vendor/faiss/faiss/impl/scalar_quantizer/similarities.h +94 -0
  122. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx2.cpp +455 -0
  123. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx512.cpp +430 -0
  124. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-dispatch.h +329 -0
  125. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-neon.cpp +467 -0
  126. data/vendor/faiss/faiss/impl/scalar_quantizer/training.cpp +203 -0
  127. data/vendor/faiss/faiss/impl/scalar_quantizer/training.h +42 -0
  128. data/vendor/faiss/faiss/impl/simd_dispatch.h +139 -0
  129. data/vendor/faiss/faiss/impl/simd_result_handlers.h +18 -18
  130. data/vendor/faiss/faiss/index_factory.cpp +35 -16
  131. data/vendor/faiss/faiss/index_io.h +29 -3
  132. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +7 -4
  133. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +1 -1
  134. data/vendor/faiss/faiss/svs/IndexSVSFaissUtils.h +9 -19
  135. data/vendor/faiss/faiss/svs/IndexSVSFlat.h +2 -0
  136. data/vendor/faiss/faiss/svs/IndexSVSVamana.h +2 -1
  137. data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.cpp +9 -1
  138. data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.h +9 -0
  139. data/vendor/faiss/faiss/utils/Heap.cpp +46 -0
  140. data/vendor/faiss/faiss/utils/Heap.h +21 -0
  141. data/vendor/faiss/faiss/utils/NeuralNet.cpp +10 -7
  142. data/vendor/faiss/faiss/utils/distances.cpp +141 -23
  143. data/vendor/faiss/faiss/utils/distances.h +98 -0
  144. data/vendor/faiss/faiss/utils/distances_dispatch.h +170 -0
  145. data/vendor/faiss/faiss/utils/distances_simd.cpp +74 -3511
  146. data/vendor/faiss/faiss/utils/extra_distances-inl.h +164 -157
  147. data/vendor/faiss/faiss/utils/extra_distances.cpp +52 -95
  148. data/vendor/faiss/faiss/utils/extra_distances.h +47 -1
  149. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +0 -1
  150. data/vendor/faiss/faiss/utils/partitioning.cpp +1 -1
  151. data/vendor/faiss/faiss/utils/pq_code_distance.h +251 -0
  152. data/vendor/faiss/faiss/utils/rabitq_simd.h +260 -0
  153. data/vendor/faiss/faiss/utils/simd_impl/distances_aarch64.cpp +150 -0
  154. data/vendor/faiss/faiss/utils/simd_impl/distances_arm_sve.cpp +568 -0
  155. data/vendor/faiss/faiss/utils/simd_impl/distances_autovec-inl.h +153 -0
  156. data/vendor/faiss/faiss/utils/simd_impl/distances_avx2.cpp +1185 -0
  157. data/vendor/faiss/faiss/utils/simd_impl/distances_avx512.cpp +1092 -0
  158. data/vendor/faiss/faiss/utils/simd_impl/distances_sse-inl.h +391 -0
  159. data/vendor/faiss/faiss/utils/simd_levels.cpp +322 -0
  160. data/vendor/faiss/faiss/utils/simd_levels.h +91 -0
  161. data/vendor/faiss/faiss/utils/simdlib_avx2.h +12 -1
  162. data/vendor/faiss/faiss/utils/simdlib_avx512.h +69 -0
  163. data/vendor/faiss/faiss/utils/simdlib_neon.h +6 -0
  164. data/vendor/faiss/faiss/utils/sorting.cpp +4 -4
  165. data/vendor/faiss/faiss/utils/utils.cpp +16 -9
  166. metadata +47 -18
  167. data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +0 -81
  168. data/vendor/faiss/faiss/impl/code_distance/code_distance.h +0 -186
  169. /data/vendor/faiss/faiss/{cppcontrib/factory_tools.h → factory_tools.h} +0 -0
@@ -19,7 +19,6 @@ namespace faiss {
19
19
  struct ProductQuantizer;
20
20
  struct ScalarQuantizer;
21
21
 
22
- void read_index_header(Index* idx, IOReader* f);
23
22
  void read_direct_map(DirectMap* dm, IOReader* f);
24
23
  void read_ivf_header(
25
24
  IndexIVF* ivf,
@@ -12,10 +12,14 @@
12
12
 
13
13
  #include <cstdio>
14
14
  #include <cstdlib>
15
+ #include <cstring>
15
16
 
16
17
  #include <faiss/invlists/InvertedListsIOHook.h>
17
18
 
19
+ #include <faiss/invlists/BlockInvertedLists.h>
20
+
18
21
  #include <faiss/impl/FaissAssert.h>
22
+ #include <faiss/impl/RaBitQUtils.h>
19
23
  #include <faiss/utils/hamming.h>
20
24
 
21
25
  #include <faiss/Index2Layer.h>
@@ -101,7 +105,14 @@ static void write_index_header(const Index* idx, IOWriter* f) {
101
105
  }
102
106
 
103
107
  void write_VectorTransform(const VectorTransform* vt, IOWriter* f) {
104
- if (const LinearTransform* lt = dynamic_cast<const LinearTransform*>(vt)) {
108
+ if (const HadamardRotation* hr =
109
+ dynamic_cast<const HadamardRotation*>(vt)) {
110
+ uint32_t h = fourcc("HRot");
111
+ WRITE1(h);
112
+ WRITE1(hr->seed);
113
+ } else if (
114
+ const LinearTransform* lt =
115
+ dynamic_cast<const LinearTransform*>(vt)) {
105
116
  if (dynamic_cast<const RandomRotationMatrix*>(lt)) {
106
117
  uint32_t h = fourcc("rrot");
107
118
  WRITE1(h);
@@ -446,9 +457,9 @@ void write_index(const Index* idx, IOWriter* f, int io_flags) {
446
457
  uint32_t h = fourcc("null");
447
458
  WRITE1(h);
448
459
  } else if (
449
- const IndexFlatL2Panorama* idxpan =
450
- dynamic_cast<const IndexFlatL2Panorama*>(idx)) {
451
- uint32_t h = fourcc("IxFP");
460
+ const IndexFlatPanorama* idxpan =
461
+ dynamic_cast<const IndexFlatPanorama*>(idx)) {
462
+ uint32_t h = fourcc(idxpan->metric_type == METRIC_L2 ? "IxFP" : "IxFp");
452
463
  WRITE1(h);
453
464
  WRITE1(idxpan->d);
454
465
  WRITE1(idxpan->n_levels);
@@ -937,13 +948,12 @@ void write_index(const Index* idx, IOWriter* f, int io_flags) {
937
948
  } else if (
938
949
  const IndexRaBitQFastScan* idxqfs =
939
950
  dynamic_cast<const IndexRaBitQFastScan*>(idx)) {
940
- uint32_t h = fourcc("Irfs");
951
+ uint32_t h = fourcc("Irfn");
941
952
  WRITE1(h);
942
953
  write_index_header(idx, f);
943
954
  write_RaBitQuantizer(&idxqfs->rabitq, f);
944
955
  WRITEVECTOR(idxqfs->center);
945
956
  WRITE1(idxqfs->qb);
946
- WRITEVECTOR(idxqfs->flat_storage);
947
957
  WRITE1(idxqfs->bbs);
948
958
  WRITE1(idxqfs->ntotal2);
949
959
  WRITE1(idxqfs->M2);
@@ -1060,7 +1070,7 @@ void write_index(const Index* idx, IOWriter* f, int io_flags) {
1060
1070
  else if (
1061
1071
  const IndexIVFRaBitQFastScan* ivrqfs =
1062
1072
  dynamic_cast<const IndexIVFRaBitQFastScan*>(idx)) {
1063
- uint32_t h = fourcc("Iwrf");
1073
+ uint32_t h = fourcc("Iwrn");
1064
1074
  WRITE1(h);
1065
1075
  write_ivf_header(ivrqfs, f);
1066
1076
  write_RaBitQuantizer(&ivrqfs->rabitq, f);
@@ -1072,7 +1082,6 @@ void write_index(const Index* idx, IOWriter* f, int io_flags) {
1072
1082
  WRITE1(ivrqfs->implem);
1073
1083
  WRITE1(ivrqfs->qb);
1074
1084
  WRITE1(ivrqfs->centered);
1075
- WRITEVECTOR(ivrqfs->flat_storage);
1076
1085
  write_InvertedLists(ivrqfs->invlists, f);
1077
1086
  } else {
1078
1087
  FAISS_THROW_MSG("don't know how to serialize this type of index");
@@ -18,6 +18,9 @@
18
18
  #include <queue>
19
19
  #include <unordered_set>
20
20
 
21
+ #include <faiss/impl/FaissAssert.h>
22
+
23
+ #include <faiss/impl/simd_dispatch.h>
21
24
  #include <faiss/utils/distances.h>
22
25
 
23
26
  namespace faiss {
@@ -302,18 +305,20 @@ void EnumeratedVectors::find_nn(
302
305
  }
303
306
 
304
307
  std::vector<float> c(dim);
305
- for (size_t i = 0; i < nc; i++) {
306
- uint64_t code = codes[nc];
307
- decode(code, c.data());
308
- for (size_t j = 0; j < nq; j++) {
309
- const float* x = xq + j * dim;
310
- float dis = fvec_inner_product(x, c.data(), dim);
311
- if (dis > distances[j]) {
312
- distances[j] = dis;
313
- labels[j] = i;
308
+ with_simd_level([&]<SIMDLevel SL>() {
309
+ for (size_t i = 0; i < nc; i++) {
310
+ uint64_t code = codes[nc];
311
+ decode(code, c.data());
312
+ for (size_t j = 0; j < nq; j++) {
313
+ const float* x = xq + j * dim;
314
+ float dis = fvec_inner_product<SL>(x, c.data(), dim);
315
+ if (dis > distances[j]) {
316
+ distances[j] = dis;
317
+ labels[j] = i;
318
+ }
314
319
  }
315
320
  }
316
- }
321
+ });
317
322
  }
318
323
 
319
324
  /**********************************************************
@@ -321,6 +326,12 @@ void EnumeratedVectors::find_nn(
321
326
  **********************************************************/
322
327
 
323
328
  ZnSphereSearch::ZnSphereSearch(int dim, int r2) : dimS(dim), r2(r2) {
329
+ FAISS_THROW_IF_NOT_MSG(
330
+ dim > 0 && dim <= 64, "ZnSphereSearch: dim must be in [1, 64]");
331
+ FAISS_THROW_IF_NOT_MSG(
332
+ r2 >= 0 && r2 <= 512,
333
+ "ZnSphereSearch: r2 must be in [0, 512] to avoid"
334
+ " excessive computation in sum_of_sq");
324
335
  voc = sum_of_sq(r2, int(ceil(sqrt(r2)) + 1), dim);
325
336
  natom = voc.size() / dim;
326
337
  }
@@ -355,13 +366,15 @@ float ZnSphereSearch::search(
355
366
  // find best
356
367
  int ibest = -1;
357
368
  float dpbest = -100;
358
- for (int i = 0; i < natom; i++) {
359
- float dp = fvec_inner_product(voc.data() + i * dim, xperm, dim);
360
- if (dp > dpbest) {
361
- dpbest = dp;
362
- ibest = i;
369
+ with_simd_level([&]<SIMDLevel SL>() {
370
+ for (int i = 0; i < natom; i++) {
371
+ float dp = fvec_inner_product<SL>(voc.data() + i * dim, xperm, dim);
372
+ if (dp > dpbest) {
373
+ dpbest = dp;
374
+ ibest = i;
375
+ }
363
376
  }
364
- }
377
+ });
365
378
  // revert sort
366
379
  const float* cin = voc.data() + ibest * dim;
367
380
  for (int i = 0; i < dim; i++) {
@@ -486,14 +499,28 @@ void ZnSphereCodecRec::set_nv_cum(int ld, int r2t, int r2a, uint64_t cum) {
486
499
 
487
500
  ZnSphereCodecRec::ZnSphereCodecRec(int dim, int r2)
488
501
  : EnumeratedVectors(dim), r2(r2) {
502
+ FAISS_THROW_IF_NOT_MSG(
503
+ dim > 0 && r2 >= 0, "invalid ZnSphereCodecRec parameters");
489
504
  log2_dim = 0;
490
505
  while (dim > (1 << log2_dim)) {
491
506
  log2_dim++;
492
507
  }
493
- assert(dim == (1 << log2_dim) || !"dimension must be a power of 2");
494
-
495
- all_nv.resize((log2_dim + 1) * (r2 + 1));
496
- all_nv_cum.resize((log2_dim + 1) * (r2 + 1) * (r2 + 1));
508
+ assert(dim == (1 << log2_dim) && "dimension must be a power of 2");
509
+
510
+ // Validate allocation sizes to avoid null pointer dereference on
511
+ // allocation failure. The cumulative table has O(r2^2) entries.
512
+ size_t nv_size = (size_t)(log2_dim + 1) * (r2 + 1);
513
+ size_t nv_cum_size = nv_size * (r2 + 1);
514
+ FAISS_THROW_IF_NOT_MSG(
515
+ nv_cum_size / (r2 + 1) == nv_size,
516
+ "ZnSphereCodecRec: allocation size overflow");
517
+ // Cap at ~1GB worth of uint64_t entries
518
+ FAISS_THROW_IF_NOT_MSG(
519
+ nv_cum_size <= (size_t(1) << 27),
520
+ "ZnSphereCodecRec: r2 too large, would require excessive memory");
521
+
522
+ all_nv.resize(nv_size);
523
+ all_nv_cum.resize(nv_cum_size);
497
524
 
498
525
  for (int r2a = 0; r2a <= r2; r2a++) {
499
526
  int r = int(sqrt(r2a));
@@ -34,8 +34,15 @@ struct MmappedFileMappingOwner::PImpl {
34
34
  size_t ptr_size = 0;
35
35
 
36
36
  explicit PImpl(const std::string& filename) {
37
- auto f = std::unique_ptr<FILE, decltype(&fclose)>(
38
- fopen(filename.c_str(), "r"), &fclose);
37
+ struct FileDeleter {
38
+ void operator()(FILE* f) const {
39
+ if (f)
40
+ fclose(f);
41
+ }
42
+ };
43
+
44
+ auto f = std::unique_ptr<FILE, FileDeleter>(
45
+ fopen(filename.c_str(), "r"), FileDeleter{});
39
46
  FAISS_THROW_IF_NOT_FMT(
40
47
  f.get(),
41
48
  "could not open %s for reading: %s",
@@ -111,7 +111,8 @@ void pq4_pack_codes_range(
111
111
  size_t bbs,
112
112
  size_t nsq,
113
113
  uint8_t* blocks,
114
- size_t code_stride) {
114
+ size_t code_stride,
115
+ size_t block_stride) {
115
116
  // Determine stride: use custom if provided, otherwise use legacy
116
117
  // calculation
117
118
  size_t actual_stride = (code_stride == 0) ? (M + 1) / 2 : code_stride;
@@ -136,7 +137,7 @@ void pq4_pack_codes_range(
136
137
  size_t block1 = ((i1 - 1) / bbs) + 1;
137
138
 
138
139
  for (size_t b = block0; b < block1; b++) {
139
- uint8_t* codes2 = blocks + b * bbs * nsq / 2;
140
+ uint8_t* codes2 = blocks + b * block_stride;
140
141
  int64_t i_base = b * bbs - i0;
141
142
  for (int sq = 0; sq < nsq; sq += 2) {
142
143
  for (size_t i = 0; i < bbs; i += 32) {
@@ -272,6 +273,10 @@ void CodePackerPQ4::unpack_1(
272
273
  }
273
274
  }
274
275
 
276
+ CodePacker* CodePackerPQ4::clone() const {
277
+ return new CodePackerPQ4(*this);
278
+ }
279
+
275
280
  /***************************************************************
276
281
  * Packing functions for Look-Up Tables (LUT)
277
282
  ***************************************************************/
@@ -59,6 +59,7 @@ void pq4_pack_codes(
59
59
  * @param blocks output array, size at least ceil(i1 / bbs) * bbs * nsq / 2
60
60
  * @param code_stride optional stride between consecutive codes (0 = use
61
61
  * default (M + 1) / 2)
62
+ * @param block_stride stride in bytes between consecutive blocks.
62
63
  */
63
64
  void pq4_pack_codes_range(
64
65
  const uint8_t* codes,
@@ -68,7 +69,8 @@ void pq4_pack_codes_range(
68
69
  size_t bbs,
69
70
  size_t nsq,
70
71
  uint8_t* blocks,
71
- size_t code_stride = 0);
72
+ size_t code_stride,
73
+ size_t block_stride);
72
74
 
73
75
  /** get a single element from a packed codes table
74
76
  *
@@ -101,6 +103,8 @@ struct CodePackerPQ4 : CodePacker {
101
103
 
102
104
  CodePackerPQ4(size_t nsq, size_t bbs);
103
105
 
106
+ CodePacker* clone() const final;
107
+
104
108
  void pack_1(const uint8_t* flat_code, size_t offset, uint8_t* block)
105
109
  const final;
106
110
  void unpack_1(const uint8_t* block, size_t offset, uint8_t* flat_code)
@@ -125,6 +129,7 @@ void pq4_pack_LUT(int nq, int nsq, const uint8_t* src, uint8_t* dest);
125
129
  * @param codes packed codes array
126
130
  * @param LUT packed look-up table
127
131
  * @param scaler scaler to scale the encoded norm
132
+ * @param block_stride stride in bytes between consecutive blocks.
128
133
  */
129
134
  void pq4_accumulate_loop(
130
135
  int nq,
@@ -134,7 +139,8 @@ void pq4_accumulate_loop(
134
139
  const uint8_t* codes,
135
140
  const uint8_t* LUT,
136
141
  SIMDResultHandler& res,
137
- const NormTableScaler* scaler);
142
+ const NormTableScaler* scaler,
143
+ size_t block_stride);
138
144
 
139
145
  /* qbs versions, supported only for bbs=32.
140
146
  *
@@ -185,6 +191,7 @@ int pq4_pack_LUT_qbs_q_map(
185
191
  * @param LUT look-up table (packed)
186
192
  * @param res call-back for the results
187
193
  * @param scaler scaler to scale the encoded norm
194
+ * @param block_stride stride in bytes between consecutive blocks.
188
195
  */
189
196
  void pq4_accumulate_loop_qbs(
190
197
  int qbs,
@@ -193,7 +200,8 @@ void pq4_accumulate_loop_qbs(
193
200
  const uint8_t* codes,
194
201
  const uint8_t* LUT,
195
202
  SIMDResultHandler& res,
196
- const NormTableScaler* scaler = nullptr);
203
+ const NormTableScaler* scaler,
204
+ size_t block_stride);
197
205
 
198
206
  /** Wrapper of pq4_accumulate_loop_qbs using simple StoreResultHandler
199
207
  * and DummyScaler
@@ -123,14 +123,15 @@ void accumulate_fixed_blocks(
123
123
  const uint8_t* codes,
124
124
  const uint8_t* LUT,
125
125
  ResultHandler& res,
126
- const Scaler& scaler) {
126
+ const Scaler& scaler,
127
+ size_t block_stride) {
127
128
  constexpr int bbs = 32 * BB;
128
129
  for (size_t j0 = 0; j0 < nb; j0 += bbs) {
129
130
  FixedStorageHandler<NQ, 2 * BB> res2;
130
131
  kernel_accumulate_block<NQ, BB>(nsq, codes, LUT, res2, scaler);
131
132
  res.set_block_origin(0, j0);
132
133
  res2.to_other_handler(res);
133
- codes += bbs * nsq / 2;
134
+ codes += block_stride;
134
135
  }
135
136
  }
136
137
 
@@ -143,15 +144,17 @@ void pq4_accumulate_loop_fixed_scaler(
143
144
  const uint8_t* codes,
144
145
  const uint8_t* LUT,
145
146
  ResultHandler& res,
146
- const Scaler& scaler) {
147
+ const Scaler& scaler,
148
+ size_t block_stride) {
147
149
  FAISS_THROW_IF_NOT(is_aligned_pointer(codes));
148
150
  FAISS_THROW_IF_NOT(is_aligned_pointer(LUT));
149
151
  FAISS_THROW_IF_NOT(bbs % 32 == 0);
150
152
  FAISS_THROW_IF_NOT(nb % bbs == 0);
151
153
 
152
- #define DISPATCH(NQ, BB) \
153
- case NQ * 1000 + BB: \
154
- accumulate_fixed_blocks<NQ, BB>(nb, nsq, codes, LUT, res, scaler); \
154
+ #define DISPATCH(NQ, BB) \
155
+ case NQ * 1000 + BB: \
156
+ accumulate_fixed_blocks<NQ, BB>( \
157
+ nb, nsq, codes, LUT, res, scaler, block_stride); \
155
158
  break
156
159
 
157
160
  switch (nq * 1000 + bbs / 32) {
@@ -179,14 +182,15 @@ void pq4_accumulate_loop_fixed_handler(
179
182
  const uint8_t* codes,
180
183
  const uint8_t* LUT,
181
184
  ResultHandler& res,
182
- const NormTableScaler* scaler) {
185
+ const NormTableScaler* scaler,
186
+ size_t block_stride) {
183
187
  if (scaler) {
184
188
  pq4_accumulate_loop_fixed_scaler(
185
- nq, nb, bbs, nsq, codes, LUT, res, *scaler);
189
+ nq, nb, bbs, nsq, codes, LUT, res, *scaler, block_stride);
186
190
  } else {
187
191
  DummyScaler dscaler;
188
192
  pq4_accumulate_loop_fixed_scaler(
189
- nq, nb, bbs, nsq, codes, LUT, res, dscaler);
193
+ nq, nb, bbs, nsq, codes, LUT, res, dscaler, block_stride);
190
194
  }
191
195
  }
192
196
 
@@ -199,9 +203,10 @@ struct Run_pq4_accumulate_loop {
199
203
  int nsq,
200
204
  const uint8_t* codes,
201
205
  const uint8_t* LUT,
202
- const NormTableScaler* scaler) {
206
+ const NormTableScaler* scaler,
207
+ size_t block_stride) {
203
208
  pq4_accumulate_loop_fixed_handler(
204
- nq, nb, bbs, nsq, codes, LUT, res, scaler);
209
+ nq, nb, bbs, nsq, codes, LUT, res, scaler, block_stride);
205
210
  }
206
211
  };
207
212
 
@@ -215,10 +220,11 @@ void pq4_accumulate_loop(
215
220
  const uint8_t* codes,
216
221
  const uint8_t* LUT,
217
222
  SIMDResultHandler& res,
218
- const NormTableScaler* scaler) {
223
+ const NormTableScaler* scaler,
224
+ size_t block_stride) {
219
225
  Run_pq4_accumulate_loop consumer;
220
226
  dispatch_SIMDResultHandler(
221
- res, consumer, nq, nb, bbs, nsq, codes, LUT, scaler);
227
+ res, consumer, nq, nb, bbs, nsq, codes, LUT, scaler, block_stride);
222
228
  }
223
229
 
224
230
  } // namespace faiss
@@ -565,7 +565,8 @@ void accumulate_q_4step(
565
565
  const uint8_t* codes,
566
566
  const uint8_t* LUT0,
567
567
  ResultHandler& res,
568
- const Scaler& scaler) {
568
+ const Scaler& scaler,
569
+ size_t block_stride) {
569
570
  constexpr int Q1 = QBS & 15;
570
571
  constexpr int Q2 = (QBS >> 4) & 15;
571
572
  constexpr int Q3 = (QBS >> 8) & 15;
@@ -593,7 +594,7 @@ void accumulate_q_4step(
593
594
  }
594
595
  res.set_block_origin(0, j0);
595
596
  res2.to_other_handler(res);
596
- codes += 32 * nsq / 2;
597
+ codes += block_stride;
597
598
  }
598
599
  }
599
600
 
@@ -604,11 +605,13 @@ void kernel_accumulate_block_loop(
604
605
  const uint8_t* codes,
605
606
  const uint8_t* LUT,
606
607
  ResultHandler& res,
607
- const Scaler& scaler) {
608
+ const Scaler& scaler,
609
+ size_t block_stride) {
608
610
  for (size_t j0 = 0; j0 < ntotal2; j0 += 32) {
609
611
  res.set_block_origin(0, j0);
610
612
  kernel_accumulate_block<NQ, ResultHandler>(
611
- nsq, codes + j0 * nsq / 2, LUT, res, scaler);
613
+ nsq, codes, LUT, res, scaler);
614
+ codes += block_stride;
612
615
  }
613
616
  }
614
617
 
@@ -621,14 +624,15 @@ void accumulate(
621
624
  const uint8_t* codes,
622
625
  const uint8_t* LUT,
623
626
  ResultHandler& res,
624
- const Scaler& scaler) {
627
+ const Scaler& scaler,
628
+ size_t block_stride) {
625
629
  assert(nsq % 2 == 0);
626
630
  assert(is_aligned_pointer(LUT));
627
631
 
628
- #define DISPATCH(NQ) \
629
- case NQ: \
630
- kernel_accumulate_block_loop<NQ, ResultHandler>( \
631
- ntotal2, nsq, codes, LUT, res, scaler); \
632
+ #define DISPATCH(NQ) \
633
+ case NQ: \
634
+ kernel_accumulate_block_loop<NQ, ResultHandler>( \
635
+ ntotal2, nsq, codes, LUT, res, scaler, block_stride); \
632
636
  return
633
637
 
634
638
  switch (nq) {
@@ -650,16 +654,18 @@ void pq4_accumulate_loop_qbs_fixed_scaler(
650
654
  const uint8_t* codes,
651
655
  const uint8_t* LUT0,
652
656
  ResultHandler& res,
653
- const Scaler& scaler) {
657
+ const Scaler& scaler,
658
+ size_t block_stride = 0) {
654
659
  assert(nsq % 2 == 0);
655
660
  assert(is_aligned_pointer(codes));
656
661
  assert(is_aligned_pointer(LUT0));
657
662
 
658
663
  // try out optimized versions
659
664
  switch (qbs) {
660
- #define DISPATCH(QBS) \
661
- case QBS: \
662
- accumulate_q_4step<QBS>(ntotal2, nsq, codes, LUT0, res, scaler); \
665
+ #define DISPATCH(QBS) \
666
+ case QBS: \
667
+ accumulate_q_4step<QBS>( \
668
+ ntotal2, nsq, codes, LUT0, res, scaler, block_stride); \
663
669
  return;
664
670
  DISPATCH(0x3333); // 12
665
671
  DISPATCH(0x2333); // 11
@@ -688,7 +694,6 @@ void pq4_accumulate_loop_qbs_fixed_scaler(
688
694
  }
689
695
 
690
696
  // default implementation where qbs is not known at compile time
691
-
692
697
  for (size_t j0 = 0; j0 < ntotal2; j0 += 32) {
693
698
  const uint8_t* LUT = LUT0;
694
699
  int qi = qbs;
@@ -714,7 +719,7 @@ void pq4_accumulate_loop_qbs_fixed_scaler(
714
719
  i0 += nq;
715
720
  LUT += nq * nsq * 16;
716
721
  }
717
- codes += 32 * nsq / 2;
722
+ codes += block_stride;
718
723
  }
719
724
  }
720
725
 
@@ -726,14 +731,15 @@ struct Run_pq4_accumulate_loop_qbs {
726
731
  int nsq,
727
732
  const uint8_t* codes,
728
733
  const uint8_t* LUT,
729
- const NormTableScaler* scaler) {
734
+ const NormTableScaler* scaler,
735
+ size_t block_stride) {
730
736
  if (scaler) {
731
737
  pq4_accumulate_loop_qbs_fixed_scaler(
732
- qbs, nb, nsq, codes, LUT, res, *scaler);
738
+ qbs, nb, nsq, codes, LUT, res, *scaler, block_stride);
733
739
  } else {
734
740
  DummyScaler dummy;
735
741
  pq4_accumulate_loop_qbs_fixed_scaler(
736
- qbs, nb, nsq, codes, LUT, res, dummy);
742
+ qbs, nb, nsq, codes, LUT, res, dummy, block_stride);
737
743
  }
738
744
  }
739
745
  };
@@ -747,9 +753,11 @@ void pq4_accumulate_loop_qbs(
747
753
  const uint8_t* codes,
748
754
  const uint8_t* LUT,
749
755
  SIMDResultHandler& res,
750
- const NormTableScaler* scaler) {
756
+ const NormTableScaler* scaler,
757
+ size_t block_stride) {
751
758
  Run_pq4_accumulate_loop_qbs consumer;
752
- dispatch_SIMDResultHandler(res, consumer, qbs, nb, nsq, codes, LUT, scaler);
759
+ dispatch_SIMDResultHandler(
760
+ res, consumer, qbs, nb, nsq, codes, LUT, scaler, block_stride);
753
761
  }
754
762
 
755
763
  /***************************************************************
@@ -777,7 +785,7 @@ void accumulate_to_mem(
777
785
  FAISS_THROW_IF_NOT(ntotal2 % 32 == 0);
778
786
  StoreResultHandler handler(accu, ntotal2);
779
787
  DummyScaler scaler;
780
- accumulate(nq, ntotal2, nsq, codes, LUT, handler, scaler);
788
+ accumulate(nq, ntotal2, nsq, codes, LUT, handler, scaler, 32 * nsq / 2);
781
789
  }
782
790
 
783
791
  int pq4_preferred_qbs(int n) {