faiss 0.5.2 → 0.6.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (169)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +16 -0
  3. data/LICENSE.txt +1 -1
  4. data/ext/faiss/ext.cpp +1 -1
  5. data/ext/faiss/extconf.rb +5 -6
  6. data/ext/faiss/index_binary.cpp +76 -17
  7. data/ext/faiss/{index.cpp → index_rb.cpp} +108 -35
  8. data/ext/faiss/kmeans.cpp +12 -9
  9. data/ext/faiss/numo.hpp +11 -9
  10. data/ext/faiss/pca_matrix.cpp +10 -8
  11. data/ext/faiss/product_quantizer.cpp +14 -12
  12. data/ext/faiss/{utils.cpp → utils_rb.cpp} +10 -3
  13. data/ext/faiss/{utils.h → utils_rb.h} +6 -0
  14. data/lib/faiss/version.rb +1 -1
  15. data/lib/faiss.rb +1 -1
  16. data/vendor/faiss/faiss/AutoTune.cpp +130 -11
  17. data/vendor/faiss/faiss/AutoTune.h +14 -1
  18. data/vendor/faiss/faiss/Clustering.cpp +59 -10
  19. data/vendor/faiss/faiss/Clustering.h +12 -0
  20. data/vendor/faiss/faiss/IVFlib.cpp +31 -28
  21. data/vendor/faiss/faiss/Index.cpp +20 -8
  22. data/vendor/faiss/faiss/Index.h +25 -3
  23. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +19 -24
  24. data/vendor/faiss/faiss/IndexBinary.cpp +1 -0
  25. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +9 -4
  26. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +45 -11
  27. data/vendor/faiss/faiss/IndexFastScan.cpp +35 -22
  28. data/vendor/faiss/faiss/IndexFastScan.h +10 -1
  29. data/vendor/faiss/faiss/IndexFlat.cpp +193 -136
  30. data/vendor/faiss/faiss/IndexFlat.h +16 -1
  31. data/vendor/faiss/faiss/IndexFlatCodes.cpp +46 -22
  32. data/vendor/faiss/faiss/IndexFlatCodes.h +7 -1
  33. data/vendor/faiss/faiss/IndexHNSW.cpp +24 -50
  34. data/vendor/faiss/faiss/IndexHNSW.h +14 -12
  35. data/vendor/faiss/faiss/IndexIDMap.cpp +1 -1
  36. data/vendor/faiss/faiss/IndexIVF.cpp +76 -49
  37. data/vendor/faiss/faiss/IndexIVF.h +14 -4
  38. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +11 -8
  39. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +2 -2
  40. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +25 -14
  41. data/vendor/faiss/faiss/IndexIVFFastScan.h +26 -22
  42. data/vendor/faiss/faiss/IndexIVFFlat.cpp +10 -61
  43. data/vendor/faiss/faiss/IndexIVFFlatPanorama.cpp +39 -111
  44. data/vendor/faiss/faiss/IndexIVFPQ.cpp +89 -147
  45. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +37 -5
  46. data/vendor/faiss/faiss/IndexIVFPQR.cpp +2 -1
  47. data/vendor/faiss/faiss/IndexIVFRaBitQ.cpp +42 -30
  48. data/vendor/faiss/faiss/IndexIVFRaBitQ.h +2 -2
  49. data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.cpp +246 -97
  50. data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.h +32 -29
  51. data/vendor/faiss/faiss/IndexLSH.cpp +8 -6
  52. data/vendor/faiss/faiss/IndexLattice.cpp +29 -24
  53. data/vendor/faiss/faiss/IndexNNDescent.cpp +1 -0
  54. data/vendor/faiss/faiss/IndexNSG.cpp +2 -1
  55. data/vendor/faiss/faiss/IndexNSG.h +0 -2
  56. data/vendor/faiss/faiss/IndexNeuralNetCodec.cpp +1 -1
  57. data/vendor/faiss/faiss/IndexPQ.cpp +19 -10
  58. data/vendor/faiss/faiss/IndexRaBitQ.cpp +26 -13
  59. data/vendor/faiss/faiss/IndexRaBitQ.h +2 -2
  60. data/vendor/faiss/faiss/IndexRaBitQFastScan.cpp +132 -78
  61. data/vendor/faiss/faiss/IndexRaBitQFastScan.h +14 -12
  62. data/vendor/faiss/faiss/IndexRefine.cpp +0 -30
  63. data/vendor/faiss/faiss/IndexShards.cpp +3 -4
  64. data/vendor/faiss/faiss/MetricType.h +16 -0
  65. data/vendor/faiss/faiss/VectorTransform.cpp +120 -0
  66. data/vendor/faiss/faiss/VectorTransform.h +23 -0
  67. data/vendor/faiss/faiss/clone_index.cpp +7 -4
  68. data/vendor/faiss/faiss/{cppcontrib/factory_tools.cpp → factory_tools.cpp} +1 -1
  69. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +1 -1
  70. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +37 -11
  71. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +0 -28
  72. data/vendor/faiss/faiss/impl/ClusteringInitialization.cpp +367 -0
  73. data/vendor/faiss/faiss/impl/ClusteringInitialization.h +107 -0
  74. data/vendor/faiss/faiss/impl/CodePacker.cpp +4 -0
  75. data/vendor/faiss/faiss/impl/CodePacker.h +11 -3
  76. data/vendor/faiss/faiss/impl/CodePackerRaBitQ.cpp +83 -0
  77. data/vendor/faiss/faiss/impl/CodePackerRaBitQ.h +47 -0
  78. data/vendor/faiss/faiss/impl/FaissAssert.h +60 -2
  79. data/vendor/faiss/faiss/impl/HNSW.cpp +25 -34
  80. data/vendor/faiss/faiss/impl/HNSW.h +8 -6
  81. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +34 -27
  82. data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -1
  83. data/vendor/faiss/faiss/impl/NSG.cpp +6 -5
  84. data/vendor/faiss/faiss/impl/NSG.h +17 -7
  85. data/vendor/faiss/faiss/impl/Panorama.cpp +53 -46
  86. data/vendor/faiss/faiss/impl/Panorama.h +22 -6
  87. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +16 -5
  88. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +70 -58
  89. data/vendor/faiss/faiss/impl/RaBitQUtils.cpp +92 -0
  90. data/vendor/faiss/faiss/impl/RaBitQUtils.h +93 -31
  91. data/vendor/faiss/faiss/impl/RaBitQuantizer.cpp +12 -28
  92. data/vendor/faiss/faiss/impl/RaBitQuantizer.h +3 -10
  93. data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.cpp +15 -41
  94. data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.h +0 -4
  95. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +14 -9
  96. data/vendor/faiss/faiss/impl/ResultHandler.h +131 -50
  97. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +67 -2358
  98. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +0 -2
  99. data/vendor/faiss/faiss/impl/VisitedTable.cpp +42 -0
  100. data/vendor/faiss/faiss/impl/VisitedTable.h +69 -0
  101. data/vendor/faiss/faiss/impl/expanded_scanners.h +158 -0
  102. data/vendor/faiss/faiss/impl/index_read.cpp +829 -471
  103. data/vendor/faiss/faiss/impl/index_read_utils.h +0 -1
  104. data/vendor/faiss/faiss/impl/index_write.cpp +17 -8
  105. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +47 -20
  106. data/vendor/faiss/faiss/impl/mapped_io.cpp +9 -2
  107. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +7 -2
  108. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +11 -3
  109. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +19 -13
  110. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +29 -21
  111. data/vendor/faiss/faiss/impl/{code_distance/code_distance-avx2.h → pq_code_distance/pq_code_distance-avx2.cpp} +42 -215
  112. data/vendor/faiss/faiss/impl/{code_distance/code_distance-avx512.h → pq_code_distance/pq_code_distance-avx512.cpp} +68 -107
  113. data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-generic.cpp +141 -0
  114. data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-inl.h +23 -0
  115. data/vendor/faiss/faiss/impl/{code_distance/code_distance-sve.h → pq_code_distance/pq_code_distance-sve.cpp} +57 -144
  116. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +9 -6
  117. data/vendor/faiss/faiss/impl/scalar_quantizer/codecs.h +121 -0
  118. data/vendor/faiss/faiss/impl/scalar_quantizer/distance_computers.h +136 -0
  119. data/vendor/faiss/faiss/impl/scalar_quantizer/quantizers.h +280 -0
  120. data/vendor/faiss/faiss/impl/scalar_quantizer/scanners.h +164 -0
  121. data/vendor/faiss/faiss/impl/scalar_quantizer/similarities.h +94 -0
  122. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx2.cpp +455 -0
  123. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx512.cpp +430 -0
  124. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-dispatch.h +329 -0
  125. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-neon.cpp +467 -0
  126. data/vendor/faiss/faiss/impl/scalar_quantizer/training.cpp +203 -0
  127. data/vendor/faiss/faiss/impl/scalar_quantizer/training.h +42 -0
  128. data/vendor/faiss/faiss/impl/simd_dispatch.h +139 -0
  129. data/vendor/faiss/faiss/impl/simd_result_handlers.h +18 -18
  130. data/vendor/faiss/faiss/index_factory.cpp +35 -16
  131. data/vendor/faiss/faiss/index_io.h +29 -3
  132. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +7 -4
  133. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +1 -1
  134. data/vendor/faiss/faiss/svs/IndexSVSFaissUtils.h +9 -19
  135. data/vendor/faiss/faiss/svs/IndexSVSFlat.h +2 -0
  136. data/vendor/faiss/faiss/svs/IndexSVSVamana.h +2 -1
  137. data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.cpp +9 -1
  138. data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.h +9 -0
  139. data/vendor/faiss/faiss/utils/Heap.cpp +46 -0
  140. data/vendor/faiss/faiss/utils/Heap.h +21 -0
  141. data/vendor/faiss/faiss/utils/NeuralNet.cpp +10 -7
  142. data/vendor/faiss/faiss/utils/distances.cpp +141 -23
  143. data/vendor/faiss/faiss/utils/distances.h +98 -0
  144. data/vendor/faiss/faiss/utils/distances_dispatch.h +170 -0
  145. data/vendor/faiss/faiss/utils/distances_simd.cpp +74 -3511
  146. data/vendor/faiss/faiss/utils/extra_distances-inl.h +164 -157
  147. data/vendor/faiss/faiss/utils/extra_distances.cpp +52 -95
  148. data/vendor/faiss/faiss/utils/extra_distances.h +47 -1
  149. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +0 -1
  150. data/vendor/faiss/faiss/utils/partitioning.cpp +1 -1
  151. data/vendor/faiss/faiss/utils/pq_code_distance.h +251 -0
  152. data/vendor/faiss/faiss/utils/rabitq_simd.h +260 -0
  153. data/vendor/faiss/faiss/utils/simd_impl/distances_aarch64.cpp +150 -0
  154. data/vendor/faiss/faiss/utils/simd_impl/distances_arm_sve.cpp +568 -0
  155. data/vendor/faiss/faiss/utils/simd_impl/distances_autovec-inl.h +153 -0
  156. data/vendor/faiss/faiss/utils/simd_impl/distances_avx2.cpp +1185 -0
  157. data/vendor/faiss/faiss/utils/simd_impl/distances_avx512.cpp +1092 -0
  158. data/vendor/faiss/faiss/utils/simd_impl/distances_sse-inl.h +391 -0
  159. data/vendor/faiss/faiss/utils/simd_levels.cpp +322 -0
  160. data/vendor/faiss/faiss/utils/simd_levels.h +91 -0
  161. data/vendor/faiss/faiss/utils/simdlib_avx2.h +12 -1
  162. data/vendor/faiss/faiss/utils/simdlib_avx512.h +69 -0
  163. data/vendor/faiss/faiss/utils/simdlib_neon.h +6 -0
  164. data/vendor/faiss/faiss/utils/sorting.cpp +4 -4
  165. data/vendor/faiss/faiss/utils/utils.cpp +16 -9
  166. metadata +47 -18
  167. data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +0 -81
  168. data/vendor/faiss/faiss/impl/code_distance/code_distance.h +0 -186
  169. data/vendor/faiss/faiss/{cppcontrib/factory_tools.h → factory_tools.h} +0 -0
@@ -5,16 +5,11 @@
5
5
  * LICENSE file in the root directory of this source tree.
6
6
  */
7
7
 
8
- #pragma once
9
-
10
- #ifdef __AVX2__
8
+ #ifdef COMPILE_SIMD_AVX2
11
9
 
12
10
  #include <immintrin.h>
13
11
 
14
- #include <type_traits>
15
-
16
- #include <faiss/impl/ProductQuantizer.h>
17
- #include <faiss/impl/code_distance/code_distance-generic.h>
12
+ #include <faiss/impl/pq_code_distance/pq_code_distance-inl.h>
18
13
 
19
14
  // https://gcc.gnu.org/bugzilla/show_bug.cgi?id=78782
20
15
  #if defined(__GNUC__) && __GNUC__ < 9
@@ -31,20 +26,17 @@ inline float horizontal_sum(const __m128 v) {
31
26
  return _mm_cvtss_f32(v3);
32
27
  }
33
28
 
34
- // Computes a horizontal sum over an __m256 register
29
+ // Computes a horizontal sum over an __m256 register.
35
30
  inline float horizontal_sum(const __m256 v) {
36
31
  const __m128 v0 =
37
32
  _mm_add_ps(_mm256_castps256_ps128(v), _mm256_extractf128_ps(v, 1));
38
33
  return horizontal_sum(v0);
39
34
  }
40
35
 
41
- // processes a single code for M=4, ksub=256, nbits=8
36
+ // Processes a single code for M=4, ksub=256, nbits=8.
42
37
  float inline distance_single_code_avx2_pqdecoder8_m4(
43
- // precomputed distances, layout (4, 256)
44
38
  const float* sim_table,
45
39
  const uint8_t* code) {
46
- float result = 0;
47
-
48
40
  const float* tab = sim_table;
49
41
  constexpr size_t ksub = 1 << 8;
50
42
 
@@ -52,39 +44,19 @@ float inline distance_single_code_avx2_pqdecoder8_m4(
52
44
  __m128i offsets_0 = _mm_setr_epi32(0, 1, 2, 3);
53
45
  offsets_0 = _mm_mullo_epi32(offsets_0, vksub);
54
46
 
55
- // accumulators of partial sums
56
- __m128 partialSum;
57
-
58
- // load 4 uint8 values
59
47
  const __m128i mm1 = _mm_cvtsi32_si128(*((const int32_t*)code));
60
- {
61
- // convert uint8 values (low part of __m128i) to int32
62
- // values
63
- const __m128i idx1 = _mm_cvtepu8_epi32(mm1);
64
-
65
- // add offsets
66
- const __m128i indices_to_read_from = _mm_add_epi32(idx1, offsets_0);
48
+ const __m128i idx1 = _mm_cvtepu8_epi32(mm1);
49
+ const __m128i indices_to_read_from = _mm_add_epi32(idx1, offsets_0);
50
+ __m128 collected =
51
+ _mm_i32gather_ps(tab, indices_to_read_from, sizeof(float));
67
52
 
68
- // gather 8 values, similar to 8 operations of tab[idx]
69
- __m128 collected =
70
- _mm_i32gather_ps(tab, indices_to_read_from, sizeof(float));
71
-
72
- // collect partial sums
73
- partialSum = collected;
74
- }
75
-
76
- // horizontal sum for partialSum
77
- result = horizontal_sum(partialSum);
78
- return result;
53
+ return horizontal_sum(collected);
79
54
  }
80
55
 
81
- // processes a single code for M=8, ksub=256, nbits=8
56
+ // Processes a single code for M=8, ksub=256, nbits=8.
82
57
  float inline distance_single_code_avx2_pqdecoder8_m8(
83
- // precomputed distances, layout (8, 256)
84
58
  const float* sim_table,
85
59
  const uint8_t* code) {
86
- float result = 0;
87
-
88
60
  const float* tab = sim_table;
89
61
  constexpr size_t ksub = 1 << 8;
90
62
 
@@ -92,42 +64,21 @@ float inline distance_single_code_avx2_pqdecoder8_m8(
92
64
  __m256i offsets_0 = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
93
65
  offsets_0 = _mm256_mullo_epi32(offsets_0, vksub);
94
66
 
95
- // accumulators of partial sums
96
- __m256 partialSum;
97
-
98
- // load 8 uint8 values
99
67
  const __m128i mm1 = _mm_loadu_si64((const __m128i_u*)code);
100
- {
101
- // convert uint8 values (low part of __m128i) to int32
102
- // values
103
- const __m256i idx1 = _mm256_cvtepu8_epi32(mm1);
68
+ const __m256i idx1 = _mm256_cvtepu8_epi32(mm1);
69
+ const __m256i indices_to_read_from = _mm256_add_epi32(idx1, offsets_0);
70
+ __m256 collected =
71
+ _mm256_i32gather_ps(tab, indices_to_read_from, sizeof(float));
104
72
 
105
- // add offsets
106
- const __m256i indices_to_read_from = _mm256_add_epi32(idx1, offsets_0);
107
-
108
- // gather 8 values, similar to 8 operations of tab[idx]
109
- __m256 collected =
110
- _mm256_i32gather_ps(tab, indices_to_read_from, sizeof(float));
111
-
112
- // collect partial sums
113
- partialSum = collected;
114
- }
115
-
116
- // horizontal sum for partialSum
117
- result = horizontal_sum(partialSum);
118
- return result;
73
+ return horizontal_sum(collected);
119
74
  }
120
75
 
121
- // processes four codes for M=4, ksub=256, nbits=8
122
76
  inline void distance_four_codes_avx2_pqdecoder8_m4(
123
- // precomputed distances, layout (4, 256)
124
77
  const float* sim_table,
125
- // codes
126
78
  const uint8_t* __restrict code0,
127
79
  const uint8_t* __restrict code1,
128
80
  const uint8_t* __restrict code2,
129
81
  const uint8_t* __restrict code3,
130
- // computed distances
131
82
  float& result0,
132
83
  float& result1,
133
84
  float& result2,
@@ -137,15 +88,12 @@ inline void distance_four_codes_avx2_pqdecoder8_m4(
137
88
  const float* tab = sim_table;
138
89
  constexpr size_t ksub = 1 << 8;
139
90
 
140
- // process 8 values
141
91
  const __m128i vksub = _mm_set1_epi32(ksub);
142
92
  __m128i offsets_0 = _mm_setr_epi32(0, 1, 2, 3);
143
93
  offsets_0 = _mm_mullo_epi32(offsets_0, vksub);
144
94
 
145
- // accumulators of partial sums
146
95
  __m128 partialSums[N];
147
96
 
148
- // load 4 uint8 values
149
97
  __m128i mm1[N];
150
98
  mm1[0] = _mm_cvtsi32_si128(*((const int32_t*)code0));
151
99
  mm1[1] = _mm_cvtsi32_si128(*((const int32_t*)code1));
@@ -153,38 +101,25 @@ inline void distance_four_codes_avx2_pqdecoder8_m4(
153
101
  mm1[3] = _mm_cvtsi32_si128(*((const int32_t*)code3));
154
102
 
155
103
  for (intptr_t j = 0; j < N; j++) {
156
- // convert uint8 values (low part of __m128i) to int32
157
- // values
158
104
  const __m128i idx1 = _mm_cvtepu8_epi32(mm1[j]);
159
-
160
- // add offsets
161
105
  const __m128i indices_to_read_from = _mm_add_epi32(idx1, offsets_0);
162
-
163
- // gather 4 values, similar to 4 operations of tab[idx]
164
106
  __m128 collected =
165
107
  _mm_i32gather_ps(tab, indices_to_read_from, sizeof(float));
166
-
167
- // collect partial sums
168
108
  partialSums[j] = collected;
169
109
  }
170
110
 
171
- // horizontal sum for partialSum
172
111
  result0 = horizontal_sum(partialSums[0]);
173
112
  result1 = horizontal_sum(partialSums[1]);
174
113
  result2 = horizontal_sum(partialSums[2]);
175
114
  result3 = horizontal_sum(partialSums[3]);
176
115
  }
177
116
 
178
- // processes four codes for M=8, ksub=256, nbits=8
179
117
  inline void distance_four_codes_avx2_pqdecoder8_m8(
180
- // precomputed distances, layout (8, 256)
181
118
  const float* sim_table,
182
- // codes
183
119
  const uint8_t* __restrict code0,
184
120
  const uint8_t* __restrict code1,
185
121
  const uint8_t* __restrict code2,
186
122
  const uint8_t* __restrict code3,
187
- // computed distances
188
123
  float& result0,
189
124
  float& result1,
190
125
  float& result2,
@@ -194,15 +129,12 @@ inline void distance_four_codes_avx2_pqdecoder8_m8(
194
129
  const float* tab = sim_table;
195
130
  constexpr size_t ksub = 1 << 8;
196
131
 
197
- // process 8 values
198
132
  const __m256i vksub = _mm256_set1_epi32(ksub);
199
133
  __m256i offsets_0 = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
200
134
  offsets_0 = _mm256_mullo_epi32(offsets_0, vksub);
201
135
 
202
- // accumulators of partial sums
203
136
  __m256 partialSums[N];
204
137
 
205
- // load 8 uint8 values
206
138
  __m128i mm1[N];
207
139
  mm1[0] = _mm_loadu_si64((const __m128i_u*)code0);
208
140
  mm1[1] = _mm_loadu_si64((const __m128i_u*)code1);
@@ -210,22 +142,13 @@ inline void distance_four_codes_avx2_pqdecoder8_m8(
210
142
  mm1[3] = _mm_loadu_si64((const __m128i_u*)code3);
211
143
 
212
144
  for (intptr_t j = 0; j < N; j++) {
213
- // convert uint8 values (low part of __m128i) to int32
214
- // values
215
145
  const __m256i idx1 = _mm256_cvtepu8_epi32(mm1[j]);
216
-
217
- // add offsets
218
146
  const __m256i indices_to_read_from = _mm256_add_epi32(idx1, offsets_0);
219
-
220
- // gather 8 values, similar to 8 operations of tab[idx]
221
147
  __m256 collected =
222
148
  _mm256_i32gather_ps(tab, indices_to_read_from, sizeof(float));
223
-
224
- // collect partial sums
225
149
  partialSums[j] = collected;
226
150
  }
227
151
 
228
- // horizontal sum for partialSum
229
152
  result0 = horizontal_sum(partialSums[0]);
230
153
  result1 = horizontal_sum(partialSums[1]);
231
154
  result2 = horizontal_sum(partialSums[2]);
@@ -235,31 +158,15 @@ inline void distance_four_codes_avx2_pqdecoder8_m8(
235
158
  } // namespace
236
159
 
237
160
  namespace faiss {
161
+ namespace pq_code_distance {
238
162
 
239
- template <typename PQDecoderT>
240
- typename std::enable_if<!std::is_same<PQDecoderT, PQDecoder8>::value, float>::
241
- type inline distance_single_code_avx2(
242
- // number of subquantizers
243
- const size_t M,
244
- // number of bits per quantization index
245
- const size_t nbits,
246
- // precomputed distances, layout (M, ksub)
247
- const float* sim_table,
248
- const uint8_t* code) {
249
- // default implementation
250
- return distance_single_code_generic<PQDecoderT>(M, nbits, sim_table, code);
251
- }
252
-
253
- template <typename PQDecoderT>
254
- typename std::enable_if<std::is_same<PQDecoderT, PQDecoder8>::value, float>::
255
- type inline distance_single_code_avx2(
256
- // number of subquantizers
257
- const size_t M,
258
- // number of bits per quantization index
259
- const size_t nbits,
260
- // precomputed distances, layout (M, ksub)
261
- const float* sim_table,
262
- const uint8_t* code) {
163
+ // NOLINTNEXTLINE(facebook-hte-MisplacedTemplateSpecialization)
164
+ template <>
165
+ float pq_code_distance_single_impl<SIMDLevel::AVX2>(
166
+ size_t M,
167
+ size_t nbits,
168
+ const float* sim_table,
169
+ const uint8_t* code) {
263
170
  if (M == 4) {
264
171
  return distance_single_code_avx2_pqdecoder8_m4(sim_table, code);
265
172
  }
@@ -267,6 +174,7 @@ typename std::enable_if<std::is_same<PQDecoderT, PQDecoder8>::value, float>::
267
174
  return distance_single_code_avx2_pqdecoder8_m8(sim_table, code);
268
175
  }
269
176
 
177
+ // Precomputed distances, layout (M, ksub).
270
178
  float result = 0;
271
179
  constexpr size_t ksub = 1 << 8;
272
180
 
@@ -276,67 +184,46 @@ typename std::enable_if<std::is_same<PQDecoderT, PQDecoder8>::value, float>::
276
184
  const float* tab = sim_table;
277
185
 
278
186
  if (pqM16 > 0) {
279
- // process 16 values per loop
280
-
281
187
  const __m256i vksub = _mm256_set1_epi32(ksub);
282
188
  __m256i offsets_0 = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
283
189
  offsets_0 = _mm256_mullo_epi32(offsets_0, vksub);
284
190
 
285
- // accumulators of partial sums
286
191
  __m256 partialSum = _mm256_setzero_ps();
287
192
 
288
- // loop
193
+ // Process 16 values per loop iteration.
289
194
  for (m = 0; m < pqM16 * 16; m += 16) {
290
- // load 16 uint8 values
291
195
  const __m128i mm1 = _mm_loadu_si128((const __m128i_u*)(code + m));
196
+ // Process first 8 codes.
292
197
  {
293
- // convert uint8 values (low part of __m128i) to int32
294
- // values
295
198
  const __m256i idx1 = _mm256_cvtepu8_epi32(mm1);
296
-
297
- // add offsets
298
199
  const __m256i indices_to_read_from =
299
200
  _mm256_add_epi32(idx1, offsets_0);
300
-
301
- // gather 8 values, similar to 8 operations of tab[idx]
302
201
  __m256 collected = _mm256_i32gather_ps(
303
202
  tab, indices_to_read_from, sizeof(float));
304
203
  tab += ksub * 8;
305
-
306
- // collect partial sums
307
204
  partialSum = _mm256_add_ps(partialSum, collected);
308
205
  }
309
206
 
310
- // move high 8 uint8 to low ones
207
+ // Process next 8 codes.
311
208
  const __m128i mm2 = _mm_unpackhi_epi64(mm1, _mm_setzero_si128());
312
209
  {
313
- // convert uint8 values (low part of __m128i) to int32
314
- // values
315
210
  const __m256i idx1 = _mm256_cvtepu8_epi32(mm2);
316
-
317
- // add offsets
318
211
  const __m256i indices_to_read_from =
319
212
  _mm256_add_epi32(idx1, offsets_0);
320
-
321
- // gather 8 values, similar to 8 operations of tab[idx]
322
213
  __m256 collected = _mm256_i32gather_ps(
323
214
  tab, indices_to_read_from, sizeof(float));
324
215
  tab += ksub * 8;
325
-
326
- // collect partial sums
327
216
  partialSum = _mm256_add_ps(partialSum, collected);
328
217
  }
329
218
  }
330
219
 
331
- // horizontal sum for partialSum
220
+ // Horizontal sum for partialSum.
332
221
  result += horizontal_sum(partialSum);
333
222
  }
334
223
 
335
- //
224
+ // Process leftovers.
336
225
  if (m < M) {
337
- // process leftovers
338
226
  PQDecoder8 decoder(code + m, nbits);
339
-
340
227
  for (; m < M; m++) {
341
228
  result += tab[decoder.decode()];
342
229
  tab += ksub;
@@ -346,56 +233,17 @@ typename std::enable_if<std::is_same<PQDecoderT, PQDecoder8>::value, float>::
346
233
  return result;
347
234
  }
348
235
 
349
- template <typename PQDecoderT>
350
- typename std::enable_if<!std::is_same<PQDecoderT, PQDecoder8>::value, void>::
351
- type
352
- distance_four_codes_avx2(
353
- // number of subquantizers
354
- const size_t M,
355
- // number of bits per quantization index
356
- const size_t nbits,
357
- // precomputed distances, layout (M, ksub)
358
- const float* sim_table,
359
- // codes
360
- const uint8_t* __restrict code0,
361
- const uint8_t* __restrict code1,
362
- const uint8_t* __restrict code2,
363
- const uint8_t* __restrict code3,
364
- // computed distances
365
- float& result0,
366
- float& result1,
367
- float& result2,
368
- float& result3) {
369
- distance_four_codes_generic<PQDecoderT>(
370
- M,
371
- nbits,
372
- sim_table,
373
- code0,
374
- code1,
375
- code2,
376
- code3,
377
- result0,
378
- result1,
379
- result2,
380
- result3);
381
- }
382
-
383
- // Combines 4 operations of distance_single_code()
384
- template <typename PQDecoderT>
385
- typename std::enable_if<std::is_same<PQDecoderT, PQDecoder8>::value, void>::type
386
- distance_four_codes_avx2(
387
- // number of subquantizers
388
- const size_t M,
389
- // number of bits per quantization index
390
- const size_t nbits,
391
- // precomputed distances, layout (M, ksub)
236
+ // Combines 4 operations of pq_code_distance_single_impl().
237
+ // NOLINTNEXTLINE(facebook-hte-MisplacedTemplateSpecialization)
238
+ template <>
239
+ void pq_code_distance_four_impl<SIMDLevel::AVX2>(
240
+ size_t M,
241
+ size_t nbits,
392
242
  const float* sim_table,
393
- // codes
394
243
  const uint8_t* __restrict code0,
395
244
  const uint8_t* __restrict code1,
396
245
  const uint8_t* __restrict code2,
397
246
  const uint8_t* __restrict code3,
398
- // computed distances
399
247
  float& result0,
400
248
  float& result1,
401
249
  float& result2,
@@ -427,6 +275,7 @@ distance_four_codes_avx2(
427
275
  return;
428
276
  }
429
277
 
278
+ // Precomputed distances, layout (M, ksub).
430
279
  result0 = 0;
431
280
  result1 = 0;
432
281
  result2 = 0;
@@ -441,80 +290,57 @@ distance_four_codes_avx2(
441
290
  const float* tab = sim_table;
442
291
 
443
292
  if (pqM16 > 0) {
444
- // process 16 values per loop
445
293
  const __m256i vksub = _mm256_set1_epi32(ksub);
446
294
  __m256i offsets_0 = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
447
295
  offsets_0 = _mm256_mullo_epi32(offsets_0, vksub);
448
296
 
449
- // accumulators of partial sums
450
297
  __m256 partialSums[N];
451
298
  for (intptr_t j = 0; j < N; j++) {
452
299
  partialSums[j] = _mm256_setzero_ps();
453
300
  }
454
301
 
455
- // loop
302
+ // Process 16 values per loop iteration.
456
303
  for (m = 0; m < pqM16 * 16; m += 16) {
457
- // load 16 uint8 values
458
304
  __m128i mm1[N];
459
305
  mm1[0] = _mm_loadu_si128((const __m128i_u*)(code0 + m));
460
306
  mm1[1] = _mm_loadu_si128((const __m128i_u*)(code1 + m));
461
307
  mm1[2] = _mm_loadu_si128((const __m128i_u*)(code2 + m));
462
308
  mm1[3] = _mm_loadu_si128((const __m128i_u*)(code3 + m));
463
309
 
464
- // process first 8 codes
310
+ // Process first 8 codes.
465
311
  for (intptr_t j = 0; j < N; j++) {
466
- // convert uint8 values (low part of __m128i) to int32
467
- // values
468
312
  const __m256i idx1 = _mm256_cvtepu8_epi32(mm1[j]);
469
-
470
- // add offsets
471
313
  const __m256i indices_to_read_from =
472
314
  _mm256_add_epi32(idx1, offsets_0);
473
-
474
- // gather 8 values, similar to 8 operations of tab[idx]
475
315
  __m256 collected = _mm256_i32gather_ps(
476
316
  tab, indices_to_read_from, sizeof(float));
477
-
478
- // collect partial sums
479
317
  partialSums[j] = _mm256_add_ps(partialSums[j], collected);
480
318
  }
481
319
  tab += ksub * 8;
482
320
 
483
- // process next 8 codes
321
+ // Process next 8 codes.
484
322
  for (intptr_t j = 0; j < N; j++) {
485
- // move high 8 uint8 to low ones
486
323
  const __m128i mm2 =
487
324
  _mm_unpackhi_epi64(mm1[j], _mm_setzero_si128());
488
-
489
- // convert uint8 values (low part of __m128i) to int32
490
- // values
491
325
  const __m256i idx1 = _mm256_cvtepu8_epi32(mm2);
492
-
493
- // add offsets
494
326
  const __m256i indices_to_read_from =
495
327
  _mm256_add_epi32(idx1, offsets_0);
496
-
497
- // gather 8 values, similar to 8 operations of tab[idx]
498
328
  __m256 collected = _mm256_i32gather_ps(
499
329
  tab, indices_to_read_from, sizeof(float));
500
-
501
- // collect partial sums
502
330
  partialSums[j] = _mm256_add_ps(partialSums[j], collected);
503
331
  }
504
332
 
505
333
  tab += ksub * 8;
506
334
  }
507
335
 
508
- // horizontal sum for partialSum
509
336
  result0 += horizontal_sum(partialSums[0]);
510
337
  result1 += horizontal_sum(partialSums[1]);
511
338
  result2 += horizontal_sum(partialSums[2]);
512
339
  result3 += horizontal_sum(partialSums[3]);
513
340
  }
514
341
 
515
- //
342
+ // Process leftovers.
516
343
  if (m < M) {
517
- // process leftovers
518
344
  PQDecoder8 decoder0(code0 + m, nbits);
519
345
  PQDecoder8 decoder1(code1 + m, nbits);
520
346
  PQDecoder8 decoder2(code2 + m, nbits);
@@ -529,6 +355,7 @@ distance_four_codes_avx2(
529
355
  }
530
356
  }
531
357
 
358
+ } // namespace pq_code_distance
532
359
  } // namespace faiss
533
360
 
534
- #endif
361
+ #endif // COMPILE_SIMD_AVX2