faiss 0.5.2 → 0.6.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (169)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +16 -0
  3. data/LICENSE.txt +1 -1
  4. data/ext/faiss/ext.cpp +1 -1
  5. data/ext/faiss/extconf.rb +5 -6
  6. data/ext/faiss/index_binary.cpp +76 -17
  7. data/ext/faiss/{index.cpp → index_rb.cpp} +108 -35
  8. data/ext/faiss/kmeans.cpp +12 -9
  9. data/ext/faiss/numo.hpp +11 -9
  10. data/ext/faiss/pca_matrix.cpp +10 -8
  11. data/ext/faiss/product_quantizer.cpp +14 -12
  12. data/ext/faiss/{utils.cpp → utils_rb.cpp} +10 -3
  13. data/ext/faiss/{utils.h → utils_rb.h} +6 -0
  14. data/lib/faiss/version.rb +1 -1
  15. data/lib/faiss.rb +1 -1
  16. data/vendor/faiss/faiss/AutoTune.cpp +130 -11
  17. data/vendor/faiss/faiss/AutoTune.h +14 -1
  18. data/vendor/faiss/faiss/Clustering.cpp +59 -10
  19. data/vendor/faiss/faiss/Clustering.h +12 -0
  20. data/vendor/faiss/faiss/IVFlib.cpp +31 -28
  21. data/vendor/faiss/faiss/Index.cpp +20 -8
  22. data/vendor/faiss/faiss/Index.h +25 -3
  23. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +19 -24
  24. data/vendor/faiss/faiss/IndexBinary.cpp +1 -0
  25. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +9 -4
  26. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +45 -11
  27. data/vendor/faiss/faiss/IndexFastScan.cpp +35 -22
  28. data/vendor/faiss/faiss/IndexFastScan.h +10 -1
  29. data/vendor/faiss/faiss/IndexFlat.cpp +193 -136
  30. data/vendor/faiss/faiss/IndexFlat.h +16 -1
  31. data/vendor/faiss/faiss/IndexFlatCodes.cpp +46 -22
  32. data/vendor/faiss/faiss/IndexFlatCodes.h +7 -1
  33. data/vendor/faiss/faiss/IndexHNSW.cpp +24 -50
  34. data/vendor/faiss/faiss/IndexHNSW.h +14 -12
  35. data/vendor/faiss/faiss/IndexIDMap.cpp +1 -1
  36. data/vendor/faiss/faiss/IndexIVF.cpp +76 -49
  37. data/vendor/faiss/faiss/IndexIVF.h +14 -4
  38. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +11 -8
  39. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +2 -2
  40. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +25 -14
  41. data/vendor/faiss/faiss/IndexIVFFastScan.h +26 -22
  42. data/vendor/faiss/faiss/IndexIVFFlat.cpp +10 -61
  43. data/vendor/faiss/faiss/IndexIVFFlatPanorama.cpp +39 -111
  44. data/vendor/faiss/faiss/IndexIVFPQ.cpp +89 -147
  45. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +37 -5
  46. data/vendor/faiss/faiss/IndexIVFPQR.cpp +2 -1
  47. data/vendor/faiss/faiss/IndexIVFRaBitQ.cpp +42 -30
  48. data/vendor/faiss/faiss/IndexIVFRaBitQ.h +2 -2
  49. data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.cpp +246 -97
  50. data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.h +32 -29
  51. data/vendor/faiss/faiss/IndexLSH.cpp +8 -6
  52. data/vendor/faiss/faiss/IndexLattice.cpp +29 -24
  53. data/vendor/faiss/faiss/IndexNNDescent.cpp +1 -0
  54. data/vendor/faiss/faiss/IndexNSG.cpp +2 -1
  55. data/vendor/faiss/faiss/IndexNSG.h +0 -2
  56. data/vendor/faiss/faiss/IndexNeuralNetCodec.cpp +1 -1
  57. data/vendor/faiss/faiss/IndexPQ.cpp +19 -10
  58. data/vendor/faiss/faiss/IndexRaBitQ.cpp +26 -13
  59. data/vendor/faiss/faiss/IndexRaBitQ.h +2 -2
  60. data/vendor/faiss/faiss/IndexRaBitQFastScan.cpp +132 -78
  61. data/vendor/faiss/faiss/IndexRaBitQFastScan.h +14 -12
  62. data/vendor/faiss/faiss/IndexRefine.cpp +0 -30
  63. data/vendor/faiss/faiss/IndexShards.cpp +3 -4
  64. data/vendor/faiss/faiss/MetricType.h +16 -0
  65. data/vendor/faiss/faiss/VectorTransform.cpp +120 -0
  66. data/vendor/faiss/faiss/VectorTransform.h +23 -0
  67. data/vendor/faiss/faiss/clone_index.cpp +7 -4
  68. data/vendor/faiss/faiss/{cppcontrib/factory_tools.cpp → factory_tools.cpp} +1 -1
  69. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +1 -1
  70. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +37 -11
  71. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +0 -28
  72. data/vendor/faiss/faiss/impl/ClusteringInitialization.cpp +367 -0
  73. data/vendor/faiss/faiss/impl/ClusteringInitialization.h +107 -0
  74. data/vendor/faiss/faiss/impl/CodePacker.cpp +4 -0
  75. data/vendor/faiss/faiss/impl/CodePacker.h +11 -3
  76. data/vendor/faiss/faiss/impl/CodePackerRaBitQ.cpp +83 -0
  77. data/vendor/faiss/faiss/impl/CodePackerRaBitQ.h +47 -0
  78. data/vendor/faiss/faiss/impl/FaissAssert.h +60 -2
  79. data/vendor/faiss/faiss/impl/HNSW.cpp +25 -34
  80. data/vendor/faiss/faiss/impl/HNSW.h +8 -6
  81. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +34 -27
  82. data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -1
  83. data/vendor/faiss/faiss/impl/NSG.cpp +6 -5
  84. data/vendor/faiss/faiss/impl/NSG.h +17 -7
  85. data/vendor/faiss/faiss/impl/Panorama.cpp +53 -46
  86. data/vendor/faiss/faiss/impl/Panorama.h +22 -6
  87. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +16 -5
  88. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +70 -58
  89. data/vendor/faiss/faiss/impl/RaBitQUtils.cpp +92 -0
  90. data/vendor/faiss/faiss/impl/RaBitQUtils.h +93 -31
  91. data/vendor/faiss/faiss/impl/RaBitQuantizer.cpp +12 -28
  92. data/vendor/faiss/faiss/impl/RaBitQuantizer.h +3 -10
  93. data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.cpp +15 -41
  94. data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.h +0 -4
  95. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +14 -9
  96. data/vendor/faiss/faiss/impl/ResultHandler.h +131 -50
  97. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +67 -2358
  98. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +0 -2
  99. data/vendor/faiss/faiss/impl/VisitedTable.cpp +42 -0
  100. data/vendor/faiss/faiss/impl/VisitedTable.h +69 -0
  101. data/vendor/faiss/faiss/impl/expanded_scanners.h +158 -0
  102. data/vendor/faiss/faiss/impl/index_read.cpp +829 -471
  103. data/vendor/faiss/faiss/impl/index_read_utils.h +0 -1
  104. data/vendor/faiss/faiss/impl/index_write.cpp +17 -8
  105. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +47 -20
  106. data/vendor/faiss/faiss/impl/mapped_io.cpp +9 -2
  107. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +7 -2
  108. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +11 -3
  109. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +19 -13
  110. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +29 -21
  111. data/vendor/faiss/faiss/impl/{code_distance/code_distance-avx2.h → pq_code_distance/pq_code_distance-avx2.cpp} +42 -215
  112. data/vendor/faiss/faiss/impl/{code_distance/code_distance-avx512.h → pq_code_distance/pq_code_distance-avx512.cpp} +68 -107
  113. data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-generic.cpp +141 -0
  114. data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-inl.h +23 -0
  115. data/vendor/faiss/faiss/impl/{code_distance/code_distance-sve.h → pq_code_distance/pq_code_distance-sve.cpp} +57 -144
  116. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +9 -6
  117. data/vendor/faiss/faiss/impl/scalar_quantizer/codecs.h +121 -0
  118. data/vendor/faiss/faiss/impl/scalar_quantizer/distance_computers.h +136 -0
  119. data/vendor/faiss/faiss/impl/scalar_quantizer/quantizers.h +280 -0
  120. data/vendor/faiss/faiss/impl/scalar_quantizer/scanners.h +164 -0
  121. data/vendor/faiss/faiss/impl/scalar_quantizer/similarities.h +94 -0
  122. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx2.cpp +455 -0
  123. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx512.cpp +430 -0
  124. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-dispatch.h +329 -0
  125. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-neon.cpp +467 -0
  126. data/vendor/faiss/faiss/impl/scalar_quantizer/training.cpp +203 -0
  127. data/vendor/faiss/faiss/impl/scalar_quantizer/training.h +42 -0
  128. data/vendor/faiss/faiss/impl/simd_dispatch.h +139 -0
  129. data/vendor/faiss/faiss/impl/simd_result_handlers.h +18 -18
  130. data/vendor/faiss/faiss/index_factory.cpp +35 -16
  131. data/vendor/faiss/faiss/index_io.h +29 -3
  132. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +7 -4
  133. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +1 -1
  134. data/vendor/faiss/faiss/svs/IndexSVSFaissUtils.h +9 -19
  135. data/vendor/faiss/faiss/svs/IndexSVSFlat.h +2 -0
  136. data/vendor/faiss/faiss/svs/IndexSVSVamana.h +2 -1
  137. data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.cpp +9 -1
  138. data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.h +9 -0
  139. data/vendor/faiss/faiss/utils/Heap.cpp +46 -0
  140. data/vendor/faiss/faiss/utils/Heap.h +21 -0
  141. data/vendor/faiss/faiss/utils/NeuralNet.cpp +10 -7
  142. data/vendor/faiss/faiss/utils/distances.cpp +141 -23
  143. data/vendor/faiss/faiss/utils/distances.h +98 -0
  144. data/vendor/faiss/faiss/utils/distances_dispatch.h +170 -0
  145. data/vendor/faiss/faiss/utils/distances_simd.cpp +74 -3511
  146. data/vendor/faiss/faiss/utils/extra_distances-inl.h +164 -157
  147. data/vendor/faiss/faiss/utils/extra_distances.cpp +52 -95
  148. data/vendor/faiss/faiss/utils/extra_distances.h +47 -1
  149. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +0 -1
  150. data/vendor/faiss/faiss/utils/partitioning.cpp +1 -1
  151. data/vendor/faiss/faiss/utils/pq_code_distance.h +251 -0
  152. data/vendor/faiss/faiss/utils/rabitq_simd.h +260 -0
  153. data/vendor/faiss/faiss/utils/simd_impl/distances_aarch64.cpp +150 -0
  154. data/vendor/faiss/faiss/utils/simd_impl/distances_arm_sve.cpp +568 -0
  155. data/vendor/faiss/faiss/utils/simd_impl/distances_autovec-inl.h +153 -0
  156. data/vendor/faiss/faiss/utils/simd_impl/distances_avx2.cpp +1185 -0
  157. data/vendor/faiss/faiss/utils/simd_impl/distances_avx512.cpp +1092 -0
  158. data/vendor/faiss/faiss/utils/simd_impl/distances_sse-inl.h +391 -0
  159. data/vendor/faiss/faiss/utils/simd_levels.cpp +322 -0
  160. data/vendor/faiss/faiss/utils/simd_levels.h +91 -0
  161. data/vendor/faiss/faiss/utils/simdlib_avx2.h +12 -1
  162. data/vendor/faiss/faiss/utils/simdlib_avx512.h +69 -0
  163. data/vendor/faiss/faiss/utils/simdlib_neon.h +6 -0
  164. data/vendor/faiss/faiss/utils/sorting.cpp +4 -4
  165. data/vendor/faiss/faiss/utils/utils.cpp +16 -9
  166. metadata +47 -18
  167. data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +0 -81
  168. data/vendor/faiss/faiss/impl/code_distance/code_distance.h +0 -186
  169. /data/vendor/faiss/faiss/{cppcontrib/factory_tools.h → factory_tools.h} +0 -0
@@ -5,49 +5,28 @@
5
5
  * LICENSE file in the root directory of this source tree.
6
6
  */
7
7
 
8
- #pragma once
9
-
10
- #ifdef __AVX512F__
8
+ #ifdef COMPILE_SIMD_AVX512
11
9
 
12
10
  #include <immintrin.h>
13
11
 
14
- #include <type_traits>
15
-
16
- #include <faiss/impl/ProductQuantizer.h>
17
- #include <faiss/impl/code_distance/code_distance-generic.h>
12
+ #include <faiss/impl/pq_code_distance/pq_code_distance-inl.h>
18
13
 
19
14
  namespace faiss {
15
+ namespace pq_code_distance {
20
16
 
21
17
  // According to experiments, the AVX-512 version may be SLOWER than
22
- // the AVX2 version, which is somewhat unexpected.
23
- // This version is not used for now, but it may be used later.
18
+ // the AVX2 version, which is somewhat unexpected.
19
+ // This version is kept for completeness.
24
20
  //
25
21
  // TODO: test for AMD CPUs.
26
22
 
27
- template <typename PQDecoderT>
28
- typename std::enable_if<!std::is_same<PQDecoderT, PQDecoder8>::value, float>::
29
- type inline distance_single_code_avx512(
30
- // number of subquantizers
31
- const size_t M,
32
- // number of bits per quantization index
33
- const size_t nbits,
34
- // precomputed distances, layout (M, ksub)
35
- const float* sim_table,
36
- const uint8_t* code) {
37
- // default implementation
38
- return distance_single_code_generic<PQDecoderT>(M, nbits, sim_table, code);
39
- }
40
-
41
- template <typename PQDecoderT>
42
- typename std::enable_if<std::is_same<PQDecoderT, PQDecoder8>::value, float>::
43
- type inline distance_single_code_avx512(
44
- // number of subquantizers
45
- const size_t M,
46
- // number of bits per quantization index
47
- const size_t nbits,
48
- // precomputed distances, layout (M, ksub)
49
- const float* sim_table,
50
- const uint8_t* code0) {
23
+ // NOLINTNEXTLINE(facebook-hte-MisplacedTemplateSpecialization)
24
+ template <>
25
+ float pq_code_distance_single_impl<SIMDLevel::AVX512>(
26
+ size_t M,
27
+ size_t nbits,
28
+ const float* sim_table,
29
+ const uint8_t* code0) {
51
30
  float result0 = 0;
52
31
  constexpr size_t ksub = 1 << 8;
53
32
 
@@ -59,49 +38,37 @@ typename std::enable_if<std::is_same<PQDecoderT, PQDecoder8>::value, float>::
59
38
  const float* tab = sim_table;
60
39
 
61
40
  if (pqM16 > 0) {
62
- // process 16 values per loop
63
41
  const __m512i vksub = _mm512_set1_epi32(ksub);
64
42
  __m512i offsets_0 = _mm512_setr_epi32(
65
43
  0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15);
66
44
  offsets_0 = _mm512_mullo_epi32(offsets_0, vksub);
67
45
 
68
- // accumulators of partial sums
69
46
  __m512 partialSums[N];
70
47
  for (intptr_t j = 0; j < N; j++) {
71
48
  partialSums[j] = _mm512_setzero_ps();
72
49
  }
73
50
 
74
- // loop
51
+ // Process 16 values per loop iteration.
75
52
  for (m = 0; m < pqM16 * 16; m += 16) {
76
- // load 16 uint8 values
77
53
  __m128i mm1[N];
78
54
  mm1[0] = _mm_loadu_si128((const __m128i_u*)(code0 + m));
79
55
 
80
- // process first 8 codes
81
56
  for (intptr_t j = 0; j < N; j++) {
82
57
  const __m512i idx1 = _mm512_cvtepu8_epi32(mm1[j]);
83
-
84
- // add offsets
85
58
  const __m512i indices_to_read_from =
86
59
  _mm512_add_epi32(idx1, offsets_0);
87
-
88
- // gather 16 values, similar to 16 operations of tab[idx]
89
60
  __m512 collected = _mm512_i32gather_ps(
90
61
  indices_to_read_from, tab, sizeof(float));
91
-
92
- // collect partial sums
93
62
  partialSums[j] = _mm512_add_ps(partialSums[j], collected);
94
63
  }
95
64
  tab += ksub * 16;
96
65
  }
97
66
 
98
- // horizontal sum for partialSum
99
67
  result0 += _mm512_reduce_add_ps(partialSums[0]);
100
68
  }
101
69
 
102
- //
70
+ // Process leftovers.
103
71
  if (m < M) {
104
- // process leftovers
105
72
  PQDecoder8 decoder0(code0 + m, nbits);
106
73
  for (; m < M; m++) {
107
74
  result0 += tab[decoder0.decode()];
@@ -112,56 +79,17 @@ typename std::enable_if<std::is_same<PQDecoderT, PQDecoder8>::value, float>::
112
79
  return result0;
113
80
  }
114
81
 
115
- template <typename PQDecoderT>
116
- typename std::enable_if<!std::is_same<PQDecoderT, PQDecoder8>::value, void>::
117
- type
118
- distance_four_codes_avx512(
119
- // number of subquantizers
120
- const size_t M,
121
- // number of bits per quantization index
122
- const size_t nbits,
123
- // precomputed distances, layout (M, ksub)
124
- const float* sim_table,
125
- // codes
126
- const uint8_t* __restrict code0,
127
- const uint8_t* __restrict code1,
128
- const uint8_t* __restrict code2,
129
- const uint8_t* __restrict code3,
130
- // computed distances
131
- float& result0,
132
- float& result1,
133
- float& result2,
134
- float& result3) {
135
- distance_four_codes_generic<PQDecoderT>(
136
- M,
137
- nbits,
138
- sim_table,
139
- code0,
140
- code1,
141
- code2,
142
- code3,
143
- result0,
144
- result1,
145
- result2,
146
- result3);
147
- }
148
-
149
- // Combines 4 operations of distance_single_code()
150
- template <typename PQDecoderT>
151
- typename std::enable_if<std::is_same<PQDecoderT, PQDecoder8>::value, void>::type
152
- distance_four_codes_avx512(
153
- // number of subquantizers
154
- const size_t M,
155
- // number of bits per quantization index
156
- const size_t nbits,
157
- // precomputed distances, layout (M, ksub)
82
+ // Combines 4 operations of pq_code_distance_single_impl().
83
+ // NOLINTNEXTLINE(facebook-hte-MisplacedTemplateSpecialization)
84
+ template <>
85
+ void pq_code_distance_four_impl<SIMDLevel::AVX512>(
86
+ size_t M,
87
+ size_t nbits,
158
88
  const float* sim_table,
159
- // codes
160
89
  const uint8_t* __restrict code0,
161
90
  const uint8_t* __restrict code1,
162
91
  const uint8_t* __restrict code2,
163
92
  const uint8_t* __restrict code3,
164
- // computed distances
165
93
  float& result0,
166
94
  float& result1,
167
95
  float& result2,
@@ -180,55 +108,43 @@ distance_four_codes_avx512(
180
108
  const float* tab = sim_table;
181
109
 
182
110
  if (pqM16 > 0) {
183
- // process 16 values per loop
184
111
  const __m512i vksub = _mm512_set1_epi32(ksub);
185
112
  __m512i offsets_0 = _mm512_setr_epi32(
186
113
  0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15);
187
114
  offsets_0 = _mm512_mullo_epi32(offsets_0, vksub);
188
115
 
189
- // accumulators of partial sums
190
116
  __m512 partialSums[N];
191
117
  for (intptr_t j = 0; j < N; j++) {
192
118
  partialSums[j] = _mm512_setzero_ps();
193
119
  }
194
120
 
195
- // loop
121
+ // Process 16 values per loop iteration.
196
122
  for (m = 0; m < pqM16 * 16; m += 16) {
197
- // load 16 uint8 values
198
123
  __m128i mm1[N];
199
124
  mm1[0] = _mm_loadu_si128((const __m128i_u*)(code0 + m));
200
125
  mm1[1] = _mm_loadu_si128((const __m128i_u*)(code1 + m));
201
126
  mm1[2] = _mm_loadu_si128((const __m128i_u*)(code2 + m));
202
127
  mm1[3] = _mm_loadu_si128((const __m128i_u*)(code3 + m));
203
128
 
204
- // process first 8 codes
205
129
  for (intptr_t j = 0; j < N; j++) {
206
130
  const __m512i idx1 = _mm512_cvtepu8_epi32(mm1[j]);
207
-
208
- // add offsets
209
131
  const __m512i indices_to_read_from =
210
132
  _mm512_add_epi32(idx1, offsets_0);
211
-
212
- // gather 16 values, similar to 16 operations of tab[idx]
213
133
  __m512 collected = _mm512_i32gather_ps(
214
134
  indices_to_read_from, tab, sizeof(float));
215
-
216
- // collect partial sums
217
135
  partialSums[j] = _mm512_add_ps(partialSums[j], collected);
218
136
  }
219
137
  tab += ksub * 16;
220
138
  }
221
139
 
222
- // horizontal sum for partialSum
223
140
  result0 += _mm512_reduce_add_ps(partialSums[0]);
224
141
  result1 += _mm512_reduce_add_ps(partialSums[1]);
225
142
  result2 += _mm512_reduce_add_ps(partialSums[2]);
226
143
  result3 += _mm512_reduce_add_ps(partialSums[3]);
227
144
  }
228
145
 
229
- //
146
+ // Process leftovers.
230
147
  if (m < M) {
231
- // process leftovers
232
148
  PQDecoder8 decoder0(code0 + m, nbits);
233
149
  PQDecoder8 decoder1(code1 + m, nbits);
234
150
  PQDecoder8 decoder2(code2 + m, nbits);
@@ -243,6 +159,51 @@ distance_four_codes_avx512(
243
159
  }
244
160
  }
245
161
 
162
+ #ifdef COMPILE_SIMD_AVX512_SPR
163
+ // AVX512_SPR: Sapphire Rapids is a superset of AVX512. Reuse the
164
+ // AVX512 implementation until a dedicated SPR specialization is written.
165
+
166
+ // NOLINTNEXTLINE(facebook-hte-MisplacedTemplateSpecialization)
167
+ template <>
168
+ float pq_code_distance_single_impl<SIMDLevel::AVX512_SPR>(
169
+ size_t M,
170
+ size_t nbits,
171
+ const float* sim_table,
172
+ const uint8_t* code) {
173
+ return pq_code_distance_single_impl<SIMDLevel::AVX512>(
174
+ M, nbits, sim_table, code);
175
+ }
176
+
177
+ // NOLINTNEXTLINE(facebook-hte-MisplacedTemplateSpecialization)
178
+ template <>
179
+ void pq_code_distance_four_impl<SIMDLevel::AVX512_SPR>(
180
+ size_t M,
181
+ size_t nbits,
182
+ const float* sim_table,
183
+ const uint8_t* __restrict code0,
184
+ const uint8_t* __restrict code1,
185
+ const uint8_t* __restrict code2,
186
+ const uint8_t* __restrict code3,
187
+ float& result0,
188
+ float& result1,
189
+ float& result2,
190
+ float& result3) {
191
+ pq_code_distance_four_impl<SIMDLevel::AVX512>(
192
+ M,
193
+ nbits,
194
+ sim_table,
195
+ code0,
196
+ code1,
197
+ code2,
198
+ code3,
199
+ result0,
200
+ result1,
201
+ result2,
202
+ result3);
203
+ }
204
+ #endif // COMPILE_SIMD_AVX512_SPR
205
+
206
+ } // namespace pq_code_distance
246
207
  } // namespace faiss
247
208
 
248
- #endif
209
+ #endif // COMPILE_SIMD_AVX512
@@ -0,0 +1,141 @@
1
+ /*
2
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ // This TU provides:
9
+ // 1. _impl specializations for NONE (and ARM_NEON), using scalar code.
10
+ // 2. Non-templated PQ code distance dispatch wrappers
11
+ // (pq_code_distance_single, pq_code_distance_four) declared in
12
+ // pq_code_distance.h. These use DISPATCH_SIMDLevel to route to the
13
+ // best available SIMD implementation via pq_code_distance_*_impl
14
+ // function template specializations defined in the per-SIMD .cpp files.
15
+
16
+ #include <faiss/impl/pq_code_distance/pq_code_distance-inl.h>
17
+
18
+ namespace faiss {
19
+ namespace pq_code_distance {
20
+
21
+ // NONE: use scalar directly.
22
+
23
+ // NOLINTNEXTLINE(facebook-hte-MisplacedTemplateSpecialization)
24
+ template <>
25
+ float pq_code_distance_single_impl<SIMDLevel::NONE>(
26
+ size_t M,
27
+ size_t nbits,
28
+ const float* sim_table,
29
+ const uint8_t* code) {
30
+ return PQCodeDistanceScalar<PQDecoder8>::distance_single_code(
31
+ M, nbits, sim_table, code);
32
+ }
33
+
34
+ // NOLINTNEXTLINE(facebook-hte-MisplacedTemplateSpecialization)
35
+ template <>
36
+ void pq_code_distance_four_impl<SIMDLevel::NONE>(
37
+ size_t M,
38
+ size_t nbits,
39
+ const float* sim_table,
40
+ const uint8_t* __restrict code0,
41
+ const uint8_t* __restrict code1,
42
+ const uint8_t* __restrict code2,
43
+ const uint8_t* __restrict code3,
44
+ float& result0,
45
+ float& result1,
46
+ float& result2,
47
+ float& result3) {
48
+ PQCodeDistanceScalar<PQDecoder8>::distance_four_codes(
49
+ M,
50
+ nbits,
51
+ sim_table,
52
+ code0,
53
+ code1,
54
+ code2,
55
+ code3,
56
+ result0,
57
+ result1,
58
+ result2,
59
+ result3);
60
+ }
61
+
62
+ #ifdef COMPILE_SIMD_ARM_NEON
63
+ // ARM_NEON: No NEON-optimized PQ code distance exists. Use scalar.
64
+
65
+ // NOLINTNEXTLINE(facebook-hte-MisplacedTemplateSpecialization)
66
+ template <>
67
+ float pq_code_distance_single_impl<SIMDLevel::ARM_NEON>(
68
+ size_t M,
69
+ size_t nbits,
70
+ const float* sim_table,
71
+ const uint8_t* code) {
72
+ return PQCodeDistanceScalar<PQDecoder8>::distance_single_code(
73
+ M, nbits, sim_table, code);
74
+ }
75
+
76
+ // NOLINTNEXTLINE(facebook-hte-MisplacedTemplateSpecialization)
77
+ template <>
78
+ void pq_code_distance_four_impl<SIMDLevel::ARM_NEON>(
79
+ size_t M,
80
+ size_t nbits,
81
+ const float* sim_table,
82
+ const uint8_t* __restrict code0,
83
+ const uint8_t* __restrict code1,
84
+ const uint8_t* __restrict code2,
85
+ const uint8_t* __restrict code3,
86
+ float& result0,
87
+ float& result1,
88
+ float& result2,
89
+ float& result3) {
90
+ PQCodeDistanceScalar<PQDecoder8>::distance_four_codes(
91
+ M,
92
+ nbits,
93
+ sim_table,
94
+ code0,
95
+ code1,
96
+ code2,
97
+ code3,
98
+ result0,
99
+ result1,
100
+ result2,
101
+ result3);
102
+ }
103
+ #endif // COMPILE_SIMD_ARM_NEON
104
+
105
+ float pq_code_distance_single(
106
+ size_t M,
107
+ size_t nbits,
108
+ const float* sim_table,
109
+ const uint8_t* code) {
110
+ DISPATCH_SIMDLevel(pq_code_distance_single_impl, M, nbits, sim_table, code);
111
+ }
112
+
113
+ void pq_code_distance_four(
114
+ size_t M,
115
+ size_t nbits,
116
+ const float* sim_table,
117
+ const uint8_t* __restrict code0,
118
+ const uint8_t* __restrict code1,
119
+ const uint8_t* __restrict code2,
120
+ const uint8_t* __restrict code3,
121
+ float& result0,
122
+ float& result1,
123
+ float& result2,
124
+ float& result3) {
125
+ DISPATCH_SIMDLevel(
126
+ pq_code_distance_four_impl,
127
+ M,
128
+ nbits,
129
+ sim_table,
130
+ code0,
131
+ code1,
132
+ code2,
133
+ code3,
134
+ result0,
135
+ result1,
136
+ result2,
137
+ result3);
138
+ }
139
+
140
+ } // namespace pq_code_distance
141
+ } // namespace faiss
@@ -0,0 +1,23 @@
1
+ /*
2
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ #pragma once
9
+
10
+ /**
11
+ * @file pq_code_distance-inl.h
12
+ * @brief Private header for PQ code distance SIMD implementations.
13
+ *
14
+ * This is a PRIVATE header — do not include in public APIs or user code.
15
+ * Only faiss internal .cpp files (the per-SIMD implementation files and
16
+ * pq_code_distance-generic.cpp) should include this header.
17
+ *
18
+ * This header re-exports the public API (pq_code_distance.h) plus the
19
+ * simd_dispatch.h machinery needed by the implementation files.
20
+ */
21
+
22
+ #include <faiss/impl/simd_dispatch.h>
23
+ #include <faiss/utils/pq_code_distance.h>