faiss 0.3.0 → 0.3.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (216) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +9 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +1 -1
  5. data/ext/faiss/extconf.rb +9 -2
  6. data/ext/faiss/index.cpp +1 -1
  7. data/ext/faiss/index_binary.cpp +2 -2
  8. data/ext/faiss/product_quantizer.cpp +1 -1
  9. data/lib/faiss/version.rb +1 -1
  10. data/vendor/faiss/faiss/AutoTune.cpp +7 -7
  11. data/vendor/faiss/faiss/AutoTune.h +1 -2
  12. data/vendor/faiss/faiss/Clustering.cpp +39 -22
  13. data/vendor/faiss/faiss/Clustering.h +40 -21
  14. data/vendor/faiss/faiss/IVFlib.cpp +26 -12
  15. data/vendor/faiss/faiss/Index.cpp +1 -1
  16. data/vendor/faiss/faiss/Index.h +40 -10
  17. data/vendor/faiss/faiss/Index2Layer.cpp +7 -7
  18. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +176 -166
  19. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +15 -15
  20. data/vendor/faiss/faiss/IndexBinary.cpp +9 -4
  21. data/vendor/faiss/faiss/IndexBinary.h +8 -19
  22. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +2 -1
  23. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +24 -31
  24. data/vendor/faiss/faiss/IndexBinaryHNSW.h +1 -1
  25. data/vendor/faiss/faiss/IndexBinaryHash.cpp +25 -50
  26. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +107 -188
  27. data/vendor/faiss/faiss/IndexFastScan.cpp +95 -146
  28. data/vendor/faiss/faiss/IndexFastScan.h +9 -8
  29. data/vendor/faiss/faiss/IndexFlat.cpp +206 -10
  30. data/vendor/faiss/faiss/IndexFlat.h +20 -1
  31. data/vendor/faiss/faiss/IndexFlatCodes.cpp +170 -5
  32. data/vendor/faiss/faiss/IndexFlatCodes.h +23 -4
  33. data/vendor/faiss/faiss/IndexHNSW.cpp +231 -382
  34. data/vendor/faiss/faiss/IndexHNSW.h +62 -49
  35. data/vendor/faiss/faiss/IndexIDMap.cpp +69 -28
  36. data/vendor/faiss/faiss/IndexIDMap.h +24 -2
  37. data/vendor/faiss/faiss/IndexIVF.cpp +162 -56
  38. data/vendor/faiss/faiss/IndexIVF.h +46 -6
  39. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +33 -26
  40. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +6 -2
  41. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +19 -46
  42. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +4 -3
  43. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +502 -401
  44. data/vendor/faiss/faiss/IndexIVFFastScan.h +63 -26
  45. data/vendor/faiss/faiss/IndexIVFFlat.cpp +15 -5
  46. data/vendor/faiss/faiss/IndexIVFFlat.h +3 -2
  47. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +172 -0
  48. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.h +56 -0
  49. data/vendor/faiss/faiss/IndexIVFPQ.cpp +79 -125
  50. data/vendor/faiss/faiss/IndexIVFPQ.h +6 -7
  51. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +39 -52
  52. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +4 -3
  53. data/vendor/faiss/faiss/IndexIVFPQR.cpp +45 -29
  54. data/vendor/faiss/faiss/IndexIVFPQR.h +5 -2
  55. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +25 -27
  56. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +6 -6
  57. data/vendor/faiss/faiss/IndexLSH.cpp +14 -16
  58. data/vendor/faiss/faiss/IndexLattice.cpp +1 -19
  59. data/vendor/faiss/faiss/IndexLattice.h +3 -22
  60. data/vendor/faiss/faiss/IndexNNDescent.cpp +3 -33
  61. data/vendor/faiss/faiss/IndexNNDescent.h +1 -1
  62. data/vendor/faiss/faiss/IndexNSG.cpp +11 -27
  63. data/vendor/faiss/faiss/IndexNSG.h +11 -11
  64. data/vendor/faiss/faiss/IndexNeuralNetCodec.cpp +56 -0
  65. data/vendor/faiss/faiss/IndexNeuralNetCodec.h +49 -0
  66. data/vendor/faiss/faiss/IndexPQ.cpp +72 -88
  67. data/vendor/faiss/faiss/IndexPQ.h +1 -4
  68. data/vendor/faiss/faiss/IndexPQFastScan.cpp +1 -1
  69. data/vendor/faiss/faiss/IndexPreTransform.cpp +25 -31
  70. data/vendor/faiss/faiss/IndexPreTransform.h +1 -1
  71. data/vendor/faiss/faiss/IndexRefine.cpp +54 -24
  72. data/vendor/faiss/faiss/IndexRefine.h +7 -0
  73. data/vendor/faiss/faiss/IndexReplicas.cpp +23 -26
  74. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +25 -17
  75. data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -4
  76. data/vendor/faiss/faiss/IndexShards.cpp +21 -29
  77. data/vendor/faiss/faiss/IndexShardsIVF.cpp +1 -2
  78. data/vendor/faiss/faiss/MatrixStats.cpp +17 -32
  79. data/vendor/faiss/faiss/MatrixStats.h +21 -9
  80. data/vendor/faiss/faiss/MetaIndexes.cpp +35 -35
  81. data/vendor/faiss/faiss/MetricType.h +7 -2
  82. data/vendor/faiss/faiss/VectorTransform.cpp +13 -26
  83. data/vendor/faiss/faiss/VectorTransform.h +7 -7
  84. data/vendor/faiss/faiss/clone_index.cpp +15 -10
  85. data/vendor/faiss/faiss/clone_index.h +3 -0
  86. data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +95 -17
  87. data/vendor/faiss/faiss/cppcontrib/factory_tools.cpp +152 -0
  88. data/vendor/faiss/faiss/cppcontrib/factory_tools.h +24 -0
  89. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +83 -30
  90. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +123 -8
  91. data/vendor/faiss/faiss/gpu/GpuCloner.h +22 -0
  92. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +13 -0
  93. data/vendor/faiss/faiss/gpu/GpuDistance.h +46 -38
  94. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -1
  95. data/vendor/faiss/faiss/gpu/GpuIndex.h +30 -12
  96. data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +282 -0
  97. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +4 -4
  98. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +14 -9
  99. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +20 -3
  100. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -11
  101. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +1 -3
  102. data/vendor/faiss/faiss/gpu/GpuResources.cpp +24 -3
  103. data/vendor/faiss/faiss/gpu/GpuResources.h +39 -11
  104. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +142 -17
  105. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +57 -3
  106. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +26 -21
  107. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +7 -1
  108. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +8 -5
  109. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +25 -0
  110. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +129 -9
  111. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +332 -40
  112. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +299 -208
  113. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +1 -0
  114. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +1 -1
  115. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +6 -0
  116. data/vendor/faiss/faiss/gpu/utils/RaftUtils.h +75 -0
  117. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +4 -1
  118. data/vendor/faiss/faiss/gpu/utils/Timer.h +1 -1
  119. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +3 -1
  120. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +5 -5
  121. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +26 -1
  122. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +10 -3
  123. data/vendor/faiss/faiss/impl/DistanceComputer.h +70 -1
  124. data/vendor/faiss/faiss/impl/FaissAssert.h +4 -2
  125. data/vendor/faiss/faiss/impl/FaissException.h +13 -34
  126. data/vendor/faiss/faiss/impl/HNSW.cpp +605 -186
  127. data/vendor/faiss/faiss/impl/HNSW.h +52 -30
  128. data/vendor/faiss/faiss/impl/IDSelector.h +4 -4
  129. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +11 -9
  130. data/vendor/faiss/faiss/impl/LookupTableScaler.h +34 -0
  131. data/vendor/faiss/faiss/impl/NNDescent.cpp +42 -27
  132. data/vendor/faiss/faiss/impl/NSG.cpp +0 -29
  133. data/vendor/faiss/faiss/impl/NSG.h +1 -1
  134. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +14 -12
  135. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +1 -1
  136. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +25 -22
  137. data/vendor/faiss/faiss/impl/ProductQuantizer.h +6 -2
  138. data/vendor/faiss/faiss/impl/Quantizer.h +1 -1
  139. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +27 -1015
  140. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +5 -63
  141. data/vendor/faiss/faiss/impl/ResultHandler.h +347 -172
  142. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +1104 -147
  143. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +3 -8
  144. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +285 -42
  145. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx512.h +248 -0
  146. data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +21 -14
  147. data/vendor/faiss/faiss/impl/code_distance/code_distance.h +22 -12
  148. data/vendor/faiss/faiss/impl/index_read.cpp +74 -34
  149. data/vendor/faiss/faiss/impl/index_read_utils.h +37 -0
  150. data/vendor/faiss/faiss/impl/index_write.cpp +88 -51
  151. data/vendor/faiss/faiss/impl/io.cpp +23 -15
  152. data/vendor/faiss/faiss/impl/io.h +4 -4
  153. data/vendor/faiss/faiss/impl/io_macros.h +6 -0
  154. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -1
  155. data/vendor/faiss/faiss/impl/platform_macros.h +40 -1
  156. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +14 -0
  157. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +7 -6
  158. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +52 -38
  159. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +487 -49
  160. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +960 -0
  161. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +176 -0
  162. data/vendor/faiss/faiss/impl/simd_result_handlers.h +481 -225
  163. data/vendor/faiss/faiss/index_factory.cpp +41 -20
  164. data/vendor/faiss/faiss/index_io.h +12 -5
  165. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +28 -8
  166. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +3 -0
  167. data/vendor/faiss/faiss/invlists/DirectMap.cpp +10 -2
  168. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +73 -17
  169. data/vendor/faiss/faiss/invlists/InvertedLists.h +26 -8
  170. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +24 -9
  171. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +2 -1
  172. data/vendor/faiss/faiss/python/python_callbacks.cpp +4 -4
  173. data/vendor/faiss/faiss/utils/Heap.cpp +3 -1
  174. data/vendor/faiss/faiss/utils/Heap.h +105 -0
  175. data/vendor/faiss/faiss/utils/NeuralNet.cpp +342 -0
  176. data/vendor/faiss/faiss/utils/NeuralNet.h +147 -0
  177. data/vendor/faiss/faiss/utils/WorkerThread.h +1 -0
  178. data/vendor/faiss/faiss/utils/bf16.h +36 -0
  179. data/vendor/faiss/faiss/utils/distances.cpp +147 -123
  180. data/vendor/faiss/faiss/utils/distances.h +86 -9
  181. data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +5 -5
  182. data/vendor/faiss/faiss/utils/distances_fused/avx512.h +2 -2
  183. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +2 -2
  184. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +1 -1
  185. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +5 -5
  186. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.h +1 -1
  187. data/vendor/faiss/faiss/utils/distances_simd.cpp +1589 -243
  188. data/vendor/faiss/faiss/utils/extra_distances-inl.h +70 -0
  189. data/vendor/faiss/faiss/utils/extra_distances.cpp +85 -137
  190. data/vendor/faiss/faiss/utils/extra_distances.h +3 -2
  191. data/vendor/faiss/faiss/utils/fp16-arm.h +29 -0
  192. data/vendor/faiss/faiss/utils/fp16.h +2 -0
  193. data/vendor/faiss/faiss/utils/hamming.cpp +163 -111
  194. data/vendor/faiss/faiss/utils/hamming.h +58 -0
  195. data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +16 -89
  196. data/vendor/faiss/faiss/utils/hamming_distance/common.h +1 -0
  197. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +19 -88
  198. data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +58 -0
  199. data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +14 -104
  200. data/vendor/faiss/faiss/utils/partitioning.cpp +3 -4
  201. data/vendor/faiss/faiss/utils/prefetch.h +77 -0
  202. data/vendor/faiss/faiss/utils/quantize_lut.cpp +0 -14
  203. data/vendor/faiss/faiss/utils/random.cpp +43 -0
  204. data/vendor/faiss/faiss/utils/random.h +25 -0
  205. data/vendor/faiss/faiss/utils/simdlib.h +10 -1
  206. data/vendor/faiss/faiss/utils/simdlib_avx2.h +0 -6
  207. data/vendor/faiss/faiss/utils/simdlib_avx512.h +296 -0
  208. data/vendor/faiss/faiss/utils/simdlib_neon.h +77 -79
  209. data/vendor/faiss/faiss/utils/simdlib_ppc64.h +1084 -0
  210. data/vendor/faiss/faiss/utils/sorting.cpp +140 -5
  211. data/vendor/faiss/faiss/utils/sorting.h +27 -0
  212. data/vendor/faiss/faiss/utils/transpose/transpose-avx512-inl.h +176 -0
  213. data/vendor/faiss/faiss/utils/utils.cpp +120 -7
  214. data/vendor/faiss/faiss/utils/utils.h +60 -20
  215. metadata +23 -4
  216. data/vendor/faiss/faiss/impl/code_distance/code_distance_avx512.h +0 -102
@@ -32,6 +32,9 @@ struct ScalarQuantizer : Quantizer {
32
32
  QT_fp16,
33
33
  QT_8bit_direct, ///< fast indexing of uint8s
34
34
  QT_6bit, ///< 6 bits per component
35
+ QT_bf16,
36
+ QT_8bit_direct_signed, ///< fast indexing of signed int8s ranging from
37
+ ///< [-128 to 127]
35
38
  };
36
39
 
37
40
  QuantizerType qtype = QT_8bit;
@@ -65,14 +68,6 @@ struct ScalarQuantizer : Quantizer {
65
68
 
66
69
  void train(size_t n, const float* x) override;
67
70
 
68
- /// Used by an IVF index to train based on the residuals
69
- void train_residual(
70
- size_t n,
71
- const float* x,
72
- Index* quantizer,
73
- bool by_residual,
74
- bool verbose);
75
-
76
71
  /** Encode a set of vectors
77
72
  *
78
73
  * @param x vectors to encode, size n * d
@@ -13,25 +13,223 @@
13
13
 
14
14
  #include <type_traits>
15
15
 
16
+ #include <faiss/impl/ProductQuantizer.h>
16
17
  #include <faiss/impl/code_distance/code_distance-generic.h>
17
18
 
19
+ // https://gcc.gnu.org/bugzilla/show_bug.cgi?id=78782
20
+ #if defined(__GNUC__) && __GNUC__ < 9
21
+ #define _mm_loadu_si64(x) (_mm_loadl_epi64((__m128i_u*)x))
22
+ #endif
23
+
18
24
  namespace {
19
25
 
26
+ inline float horizontal_sum(const __m128 v) {
27
+ const __m128 v0 = _mm_shuffle_ps(v, v, _MM_SHUFFLE(0, 0, 3, 2));
28
+ const __m128 v1 = _mm_add_ps(v, v0);
29
+ __m128 v2 = _mm_shuffle_ps(v1, v1, _MM_SHUFFLE(0, 0, 0, 1));
30
+ const __m128 v3 = _mm_add_ps(v1, v2);
31
+ return _mm_cvtss_f32(v3);
32
+ }
33
+
20
34
  // Computes a horizontal sum over an __m256 register
21
- inline float horizontal_sum(const __m256 reg) {
22
- const __m256 h0 = _mm256_hadd_ps(reg, reg);
23
- const __m256 h1 = _mm256_hadd_ps(h0, h0);
35
+ inline float horizontal_sum(const __m256 v) {
36
+ const __m128 v0 =
37
+ _mm_add_ps(_mm256_castps256_ps128(v), _mm256_extractf128_ps(v, 1));
38
+ return horizontal_sum(v0);
39
+ }
40
+
41
+ // processes a single code for M=4, ksub=256, nbits=8
42
+ float inline distance_single_code_avx2_pqdecoder8_m4(
43
+ // precomputed distances, layout (4, 256)
44
+ const float* sim_table,
45
+ const uint8_t* code) {
46
+ float result = 0;
47
+
48
+ const float* tab = sim_table;
49
+ constexpr size_t ksub = 1 << 8;
50
+
51
+ const __m128i vksub = _mm_set1_epi32(ksub);
52
+ __m128i offsets_0 = _mm_setr_epi32(0, 1, 2, 3);
53
+ offsets_0 = _mm_mullo_epi32(offsets_0, vksub);
54
+
55
+ // accumulators of partial sums
56
+ __m128 partialSum;
57
+
58
+ // load 4 uint8 values
59
+ const __m128i mm1 = _mm_cvtsi32_si128(*((const int32_t*)code));
60
+ {
61
+ // convert uint8 values (low part of __m128i) to int32
62
+ // values
63
+ const __m128i idx1 = _mm_cvtepu8_epi32(mm1);
64
+
65
+ // add offsets
66
+ const __m128i indices_to_read_from = _mm_add_epi32(idx1, offsets_0);
67
+
68
+ // gather 8 values, similar to 8 operations of tab[idx]
69
+ __m128 collected =
70
+ _mm_i32gather_ps(tab, indices_to_read_from, sizeof(float));
71
+
72
+ // collect partial sums
73
+ partialSum = collected;
74
+ }
75
+
76
+ // horizontal sum for partialSum
77
+ result = horizontal_sum(partialSum);
78
+ return result;
79
+ }
80
+
81
+ // processes a single code for M=8, ksub=256, nbits=8
82
+ float inline distance_single_code_avx2_pqdecoder8_m8(
83
+ // precomputed distances, layout (8, 256)
84
+ const float* sim_table,
85
+ const uint8_t* code) {
86
+ float result = 0;
87
+
88
+ const float* tab = sim_table;
89
+ constexpr size_t ksub = 1 << 8;
90
+
91
+ const __m256i vksub = _mm256_set1_epi32(ksub);
92
+ __m256i offsets_0 = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
93
+ offsets_0 = _mm256_mullo_epi32(offsets_0, vksub);
94
+
95
+ // accumulators of partial sums
96
+ __m256 partialSum;
97
+
98
+ // load 8 uint8 values
99
+ const __m128i mm1 = _mm_loadu_si64((const __m128i_u*)code);
100
+ {
101
+ // convert uint8 values (low part of __m128i) to int32
102
+ // values
103
+ const __m256i idx1 = _mm256_cvtepu8_epi32(mm1);
24
104
 
25
- // extract high and low __m128 regs from __m256
26
- const __m128 h2 = _mm256_extractf128_ps(h1, 1);
27
- const __m128 h3 = _mm256_castps256_ps128(h1);
105
+ // add offsets
106
+ const __m256i indices_to_read_from = _mm256_add_epi32(idx1, offsets_0);
28
107
 
29
- // get a final hsum into all 4 regs
30
- const __m128 h4 = _mm_add_ss(h2, h3);
108
+ // gather 8 values, similar to 8 operations of tab[idx]
109
+ __m256 collected =
110
+ _mm256_i32gather_ps(tab, indices_to_read_from, sizeof(float));
31
111
 
32
- // extract f[0] from __m128
33
- const float hsum = _mm_cvtss_f32(h4);
34
- return hsum;
112
+ // collect partial sums
113
+ partialSum = collected;
114
+ }
115
+
116
+ // horizontal sum for partialSum
117
+ result = horizontal_sum(partialSum);
118
+ return result;
119
+ }
120
+
121
+ // processes four codes for M=4, ksub=256, nbits=8
122
+ inline void distance_four_codes_avx2_pqdecoder8_m4(
123
+ // precomputed distances, layout (4, 256)
124
+ const float* sim_table,
125
+ // codes
126
+ const uint8_t* __restrict code0,
127
+ const uint8_t* __restrict code1,
128
+ const uint8_t* __restrict code2,
129
+ const uint8_t* __restrict code3,
130
+ // computed distances
131
+ float& result0,
132
+ float& result1,
133
+ float& result2,
134
+ float& result3) {
135
+ constexpr intptr_t N = 4;
136
+
137
+ const float* tab = sim_table;
138
+ constexpr size_t ksub = 1 << 8;
139
+
140
+ // process 8 values
141
+ const __m128i vksub = _mm_set1_epi32(ksub);
142
+ __m128i offsets_0 = _mm_setr_epi32(0, 1, 2, 3);
143
+ offsets_0 = _mm_mullo_epi32(offsets_0, vksub);
144
+
145
+ // accumulators of partial sums
146
+ __m128 partialSums[N];
147
+
148
+ // load 4 uint8 values
149
+ __m128i mm1[N];
150
+ mm1[0] = _mm_cvtsi32_si128(*((const int32_t*)code0));
151
+ mm1[1] = _mm_cvtsi32_si128(*((const int32_t*)code1));
152
+ mm1[2] = _mm_cvtsi32_si128(*((const int32_t*)code2));
153
+ mm1[3] = _mm_cvtsi32_si128(*((const int32_t*)code3));
154
+
155
+ for (intptr_t j = 0; j < N; j++) {
156
+ // convert uint8 values (low part of __m128i) to int32
157
+ // values
158
+ const __m128i idx1 = _mm_cvtepu8_epi32(mm1[j]);
159
+
160
+ // add offsets
161
+ const __m128i indices_to_read_from = _mm_add_epi32(idx1, offsets_0);
162
+
163
+ // gather 4 values, similar to 4 operations of tab[idx]
164
+ __m128 collected =
165
+ _mm_i32gather_ps(tab, indices_to_read_from, sizeof(float));
166
+
167
+ // collect partial sums
168
+ partialSums[j] = collected;
169
+ }
170
+
171
+ // horizontal sum for partialSum
172
+ result0 = horizontal_sum(partialSums[0]);
173
+ result1 = horizontal_sum(partialSums[1]);
174
+ result2 = horizontal_sum(partialSums[2]);
175
+ result3 = horizontal_sum(partialSums[3]);
176
+ }
177
+
178
+ // processes four codes for M=8, ksub=256, nbits=8
179
+ inline void distance_four_codes_avx2_pqdecoder8_m8(
180
+ // precomputed distances, layout (8, 256)
181
+ const float* sim_table,
182
+ // codes
183
+ const uint8_t* __restrict code0,
184
+ const uint8_t* __restrict code1,
185
+ const uint8_t* __restrict code2,
186
+ const uint8_t* __restrict code3,
187
+ // computed distances
188
+ float& result0,
189
+ float& result1,
190
+ float& result2,
191
+ float& result3) {
192
+ constexpr intptr_t N = 4;
193
+
194
+ const float* tab = sim_table;
195
+ constexpr size_t ksub = 1 << 8;
196
+
197
+ // process 8 values
198
+ const __m256i vksub = _mm256_set1_epi32(ksub);
199
+ __m256i offsets_0 = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
200
+ offsets_0 = _mm256_mullo_epi32(offsets_0, vksub);
201
+
202
+ // accumulators of partial sums
203
+ __m256 partialSums[N];
204
+
205
+ // load 8 uint8 values
206
+ __m128i mm1[N];
207
+ mm1[0] = _mm_loadu_si64((const __m128i_u*)code0);
208
+ mm1[1] = _mm_loadu_si64((const __m128i_u*)code1);
209
+ mm1[2] = _mm_loadu_si64((const __m128i_u*)code2);
210
+ mm1[3] = _mm_loadu_si64((const __m128i_u*)code3);
211
+
212
+ for (intptr_t j = 0; j < N; j++) {
213
+ // convert uint8 values (low part of __m128i) to int32
214
+ // values
215
+ const __m256i idx1 = _mm256_cvtepu8_epi32(mm1[j]);
216
+
217
+ // add offsets
218
+ const __m256i indices_to_read_from = _mm256_add_epi32(idx1, offsets_0);
219
+
220
+ // gather 8 values, similar to 8 operations of tab[idx]
221
+ __m256 collected =
222
+ _mm256_i32gather_ps(tab, indices_to_read_from, sizeof(float));
223
+
224
+ // collect partial sums
225
+ partialSums[j] = collected;
226
+ }
227
+
228
+ // horizontal sum for partialSum
229
+ result0 = horizontal_sum(partialSums[0]);
230
+ result1 = horizontal_sum(partialSums[1]);
231
+ result2 = horizontal_sum(partialSums[2]);
232
+ result3 = horizontal_sum(partialSums[3]);
35
233
  }
36
234
 
37
235
  } // namespace
@@ -41,36 +239,48 @@ namespace faiss {
41
239
  template <typename PQDecoderT>
42
240
  typename std::enable_if<!std::is_same<PQDecoderT, PQDecoder8>::value, float>::
43
241
  type inline distance_single_code_avx2(
44
- // the product quantizer
45
- const ProductQuantizer& pq,
242
+ // number of subquantizers
243
+ const size_t M,
244
+ // number of bits per quantization index
245
+ const size_t nbits,
46
246
  // precomputed distances, layout (M, ksub)
47
247
  const float* sim_table,
48
248
  const uint8_t* code) {
49
249
  // default implementation
50
- return distance_single_code_generic<PQDecoderT>(pq, sim_table, code);
250
+ return distance_single_code_generic<PQDecoderT>(M, nbits, sim_table, code);
51
251
  }
52
252
 
53
253
  template <typename PQDecoderT>
54
254
  typename std::enable_if<std::is_same<PQDecoderT, PQDecoder8>::value, float>::
55
255
  type inline distance_single_code_avx2(
56
- // the product quantizer
57
- const ProductQuantizer& pq,
256
+ // number of subquantizers
257
+ const size_t M,
258
+ // number of bits per quantization index
259
+ const size_t nbits,
58
260
  // precomputed distances, layout (M, ksub)
59
261
  const float* sim_table,
60
262
  const uint8_t* code) {
263
+ if (M == 4) {
264
+ return distance_single_code_avx2_pqdecoder8_m4(sim_table, code);
265
+ }
266
+ if (M == 8) {
267
+ return distance_single_code_avx2_pqdecoder8_m8(sim_table, code);
268
+ }
269
+
61
270
  float result = 0;
271
+ constexpr size_t ksub = 1 << 8;
62
272
 
63
273
  size_t m = 0;
64
- const size_t pqM16 = pq.M / 16;
274
+ const size_t pqM16 = M / 16;
65
275
 
66
276
  const float* tab = sim_table;
67
277
 
68
278
  if (pqM16 > 0) {
69
279
  // process 16 values per loop
70
280
 
71
- const __m256i ksub = _mm256_set1_epi32(pq.ksub);
281
+ const __m256i vksub = _mm256_set1_epi32(ksub);
72
282
  __m256i offsets_0 = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
73
- offsets_0 = _mm256_mullo_epi32(offsets_0, ksub);
283
+ offsets_0 = _mm256_mullo_epi32(offsets_0, vksub);
74
284
 
75
285
  // accumulators of partial sums
76
286
  __m256 partialSum = _mm256_setzero_ps();
@@ -91,7 +301,7 @@ typename std::enable_if<std::is_same<PQDecoderT, PQDecoder8>::value, float>::
91
301
  // gather 8 values, similar to 8 operations of tab[idx]
92
302
  __m256 collected = _mm256_i32gather_ps(
93
303
  tab, indices_to_read_from, sizeof(float));
94
- tab += pq.ksub * 8;
304
+ tab += ksub * 8;
95
305
 
96
306
  // collect partial sums
97
307
  partialSum = _mm256_add_ps(partialSum, collected);
@@ -111,7 +321,7 @@ typename std::enable_if<std::is_same<PQDecoderT, PQDecoder8>::value, float>::
111
321
  // gather 8 values, similar to 8 operations of tab[idx]
112
322
  __m256 collected = _mm256_i32gather_ps(
113
323
  tab, indices_to_read_from, sizeof(float));
114
- tab += pq.ksub * 8;
324
+ tab += ksub * 8;
115
325
 
116
326
  // collect partial sums
117
327
  partialSum = _mm256_add_ps(partialSum, collected);
@@ -123,13 +333,13 @@ typename std::enable_if<std::is_same<PQDecoderT, PQDecoder8>::value, float>::
123
333
  }
124
334
 
125
335
  //
126
- if (m < pq.M) {
336
+ if (m < M) {
127
337
  // process leftovers
128
- PQDecoder8 decoder(code + m, pq.nbits);
338
+ PQDecoder8 decoder(code + m, nbits);
129
339
 
130
- for (; m < pq.M; m++) {
340
+ for (; m < M; m++) {
131
341
  result += tab[decoder.decode()];
132
- tab += pq.ksub;
342
+ tab += ksub;
133
343
  }
134
344
  }
135
345
 
@@ -140,8 +350,10 @@ template <typename PQDecoderT>
140
350
  typename std::enable_if<!std::is_same<PQDecoderT, PQDecoder8>::value, void>::
141
351
  type
142
352
  distance_four_codes_avx2(
143
- // the product quantizer
144
- const ProductQuantizer& pq,
353
+ // number of subquantizers
354
+ const size_t M,
355
+ // number of bits per quantization index
356
+ const size_t nbits,
145
357
  // precomputed distances, layout (M, ksub)
146
358
  const float* sim_table,
147
359
  // codes
@@ -155,7 +367,8 @@ typename std::enable_if<!std::is_same<PQDecoderT, PQDecoder8>::value, void>::
155
367
  float& result2,
156
368
  float& result3) {
157
369
  distance_four_codes_generic<PQDecoderT>(
158
- pq,
370
+ M,
371
+ nbits,
159
372
  sim_table,
160
373
  code0,
161
374
  code1,
@@ -171,8 +384,10 @@ typename std::enable_if<!std::is_same<PQDecoderT, PQDecoder8>::value, void>::
171
384
  template <typename PQDecoderT>
172
385
  typename std::enable_if<std::is_same<PQDecoderT, PQDecoder8>::value, void>::type
173
386
  distance_four_codes_avx2(
174
- // the product quantizer
175
- const ProductQuantizer& pq,
387
+ // number of subquantizers
388
+ const size_t M,
389
+ // number of bits per quantization index
390
+ const size_t nbits,
176
391
  // precomputed distances, layout (M, ksub)
177
392
  const float* sim_table,
178
393
  // codes
@@ -185,13 +400,41 @@ distance_four_codes_avx2(
185
400
  float& result1,
186
401
  float& result2,
187
402
  float& result3) {
403
+ if (M == 4) {
404
+ distance_four_codes_avx2_pqdecoder8_m4(
405
+ sim_table,
406
+ code0,
407
+ code1,
408
+ code2,
409
+ code3,
410
+ result0,
411
+ result1,
412
+ result2,
413
+ result3);
414
+ return;
415
+ }
416
+ if (M == 8) {
417
+ distance_four_codes_avx2_pqdecoder8_m8(
418
+ sim_table,
419
+ code0,
420
+ code1,
421
+ code2,
422
+ code3,
423
+ result0,
424
+ result1,
425
+ result2,
426
+ result3);
427
+ return;
428
+ }
429
+
188
430
  result0 = 0;
189
431
  result1 = 0;
190
432
  result2 = 0;
191
433
  result3 = 0;
434
+ constexpr size_t ksub = 1 << 8;
192
435
 
193
436
  size_t m = 0;
194
- const size_t pqM16 = pq.M / 16;
437
+ const size_t pqM16 = M / 16;
195
438
 
196
439
  constexpr intptr_t N = 4;
197
440
 
@@ -199,9 +442,9 @@ distance_four_codes_avx2(
199
442
 
200
443
  if (pqM16 > 0) {
201
444
  // process 16 values per loop
202
- const __m256i ksub = _mm256_set1_epi32(pq.ksub);
445
+ const __m256i vksub = _mm256_set1_epi32(ksub);
203
446
  __m256i offsets_0 = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
204
- offsets_0 = _mm256_mullo_epi32(offsets_0, ksub);
447
+ offsets_0 = _mm256_mullo_epi32(offsets_0, vksub);
205
448
 
206
449
  // accumulators of partial sums
207
450
  __m256 partialSums[N];
@@ -235,7 +478,7 @@ distance_four_codes_avx2(
235
478
  // collect partial sums
236
479
  partialSums[j] = _mm256_add_ps(partialSums[j], collected);
237
480
  }
238
- tab += pq.ksub * 8;
481
+ tab += ksub * 8;
239
482
 
240
483
  // process next 8 codes
241
484
  for (intptr_t j = 0; j < N; j++) {
@@ -259,7 +502,7 @@ distance_four_codes_avx2(
259
502
  partialSums[j] = _mm256_add_ps(partialSums[j], collected);
260
503
  }
261
504
 
262
- tab += pq.ksub * 8;
505
+ tab += ksub * 8;
263
506
  }
264
507
 
265
508
  // horizontal sum for partialSum
@@ -270,18 +513,18 @@ distance_four_codes_avx2(
270
513
  }
271
514
 
272
515
  //
273
- if (m < pq.M) {
516
+ if (m < M) {
274
517
  // process leftovers
275
- PQDecoder8 decoder0(code0 + m, pq.nbits);
276
- PQDecoder8 decoder1(code1 + m, pq.nbits);
277
- PQDecoder8 decoder2(code2 + m, pq.nbits);
278
- PQDecoder8 decoder3(code3 + m, pq.nbits);
279
- for (; m < pq.M; m++) {
518
+ PQDecoder8 decoder0(code0 + m, nbits);
519
+ PQDecoder8 decoder1(code1 + m, nbits);
520
+ PQDecoder8 decoder2(code2 + m, nbits);
521
+ PQDecoder8 decoder3(code3 + m, nbits);
522
+ for (; m < M; m++) {
280
523
  result0 += tab[decoder0.decode()];
281
524
  result1 += tab[decoder1.decode()];
282
525
  result2 += tab[decoder2.decode()];
283
526
  result3 += tab[decoder3.decode()];
284
- tab += pq.ksub;
527
+ tab += ksub;
285
528
  }
286
529
  }
287
530
  }