faiss 0.5.2 → 0.6.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (169) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +16 -0
  3. data/LICENSE.txt +1 -1
  4. data/ext/faiss/ext.cpp +1 -1
  5. data/ext/faiss/extconf.rb +5 -6
  6. data/ext/faiss/index_binary.cpp +76 -17
  7. data/ext/faiss/{index.cpp → index_rb.cpp} +108 -35
  8. data/ext/faiss/kmeans.cpp +12 -9
  9. data/ext/faiss/numo.hpp +11 -9
  10. data/ext/faiss/pca_matrix.cpp +10 -8
  11. data/ext/faiss/product_quantizer.cpp +14 -12
  12. data/ext/faiss/{utils.cpp → utils_rb.cpp} +10 -3
  13. data/ext/faiss/{utils.h → utils_rb.h} +6 -0
  14. data/lib/faiss/version.rb +1 -1
  15. data/lib/faiss.rb +1 -1
  16. data/vendor/faiss/faiss/AutoTune.cpp +130 -11
  17. data/vendor/faiss/faiss/AutoTune.h +14 -1
  18. data/vendor/faiss/faiss/Clustering.cpp +59 -10
  19. data/vendor/faiss/faiss/Clustering.h +12 -0
  20. data/vendor/faiss/faiss/IVFlib.cpp +31 -28
  21. data/vendor/faiss/faiss/Index.cpp +20 -8
  22. data/vendor/faiss/faiss/Index.h +25 -3
  23. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +19 -24
  24. data/vendor/faiss/faiss/IndexBinary.cpp +1 -0
  25. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +9 -4
  26. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +45 -11
  27. data/vendor/faiss/faiss/IndexFastScan.cpp +35 -22
  28. data/vendor/faiss/faiss/IndexFastScan.h +10 -1
  29. data/vendor/faiss/faiss/IndexFlat.cpp +193 -136
  30. data/vendor/faiss/faiss/IndexFlat.h +16 -1
  31. data/vendor/faiss/faiss/IndexFlatCodes.cpp +46 -22
  32. data/vendor/faiss/faiss/IndexFlatCodes.h +7 -1
  33. data/vendor/faiss/faiss/IndexHNSW.cpp +24 -50
  34. data/vendor/faiss/faiss/IndexHNSW.h +14 -12
  35. data/vendor/faiss/faiss/IndexIDMap.cpp +1 -1
  36. data/vendor/faiss/faiss/IndexIVF.cpp +76 -49
  37. data/vendor/faiss/faiss/IndexIVF.h +14 -4
  38. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +11 -8
  39. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +2 -2
  40. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +25 -14
  41. data/vendor/faiss/faiss/IndexIVFFastScan.h +26 -22
  42. data/vendor/faiss/faiss/IndexIVFFlat.cpp +10 -61
  43. data/vendor/faiss/faiss/IndexIVFFlatPanorama.cpp +39 -111
  44. data/vendor/faiss/faiss/IndexIVFPQ.cpp +89 -147
  45. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +37 -5
  46. data/vendor/faiss/faiss/IndexIVFPQR.cpp +2 -1
  47. data/vendor/faiss/faiss/IndexIVFRaBitQ.cpp +42 -30
  48. data/vendor/faiss/faiss/IndexIVFRaBitQ.h +2 -2
  49. data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.cpp +246 -97
  50. data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.h +32 -29
  51. data/vendor/faiss/faiss/IndexLSH.cpp +8 -6
  52. data/vendor/faiss/faiss/IndexLattice.cpp +29 -24
  53. data/vendor/faiss/faiss/IndexNNDescent.cpp +1 -0
  54. data/vendor/faiss/faiss/IndexNSG.cpp +2 -1
  55. data/vendor/faiss/faiss/IndexNSG.h +0 -2
  56. data/vendor/faiss/faiss/IndexNeuralNetCodec.cpp +1 -1
  57. data/vendor/faiss/faiss/IndexPQ.cpp +19 -10
  58. data/vendor/faiss/faiss/IndexRaBitQ.cpp +26 -13
  59. data/vendor/faiss/faiss/IndexRaBitQ.h +2 -2
  60. data/vendor/faiss/faiss/IndexRaBitQFastScan.cpp +132 -78
  61. data/vendor/faiss/faiss/IndexRaBitQFastScan.h +14 -12
  62. data/vendor/faiss/faiss/IndexRefine.cpp +0 -30
  63. data/vendor/faiss/faiss/IndexShards.cpp +3 -4
  64. data/vendor/faiss/faiss/MetricType.h +16 -0
  65. data/vendor/faiss/faiss/VectorTransform.cpp +120 -0
  66. data/vendor/faiss/faiss/VectorTransform.h +23 -0
  67. data/vendor/faiss/faiss/clone_index.cpp +7 -4
  68. data/vendor/faiss/faiss/{cppcontrib/factory_tools.cpp → factory_tools.cpp} +1 -1
  69. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +1 -1
  70. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +37 -11
  71. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +0 -28
  72. data/vendor/faiss/faiss/impl/ClusteringInitialization.cpp +367 -0
  73. data/vendor/faiss/faiss/impl/ClusteringInitialization.h +107 -0
  74. data/vendor/faiss/faiss/impl/CodePacker.cpp +4 -0
  75. data/vendor/faiss/faiss/impl/CodePacker.h +11 -3
  76. data/vendor/faiss/faiss/impl/CodePackerRaBitQ.cpp +83 -0
  77. data/vendor/faiss/faiss/impl/CodePackerRaBitQ.h +47 -0
  78. data/vendor/faiss/faiss/impl/FaissAssert.h +60 -2
  79. data/vendor/faiss/faiss/impl/HNSW.cpp +25 -34
  80. data/vendor/faiss/faiss/impl/HNSW.h +8 -6
  81. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +34 -27
  82. data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -1
  83. data/vendor/faiss/faiss/impl/NSG.cpp +6 -5
  84. data/vendor/faiss/faiss/impl/NSG.h +17 -7
  85. data/vendor/faiss/faiss/impl/Panorama.cpp +53 -46
  86. data/vendor/faiss/faiss/impl/Panorama.h +22 -6
  87. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +16 -5
  88. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +70 -58
  89. data/vendor/faiss/faiss/impl/RaBitQUtils.cpp +92 -0
  90. data/vendor/faiss/faiss/impl/RaBitQUtils.h +93 -31
  91. data/vendor/faiss/faiss/impl/RaBitQuantizer.cpp +12 -28
  92. data/vendor/faiss/faiss/impl/RaBitQuantizer.h +3 -10
  93. data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.cpp +15 -41
  94. data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.h +0 -4
  95. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +14 -9
  96. data/vendor/faiss/faiss/impl/ResultHandler.h +131 -50
  97. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +67 -2358
  98. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +0 -2
  99. data/vendor/faiss/faiss/impl/VisitedTable.cpp +42 -0
  100. data/vendor/faiss/faiss/impl/VisitedTable.h +69 -0
  101. data/vendor/faiss/faiss/impl/expanded_scanners.h +158 -0
  102. data/vendor/faiss/faiss/impl/index_read.cpp +829 -471
  103. data/vendor/faiss/faiss/impl/index_read_utils.h +0 -1
  104. data/vendor/faiss/faiss/impl/index_write.cpp +17 -8
  105. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +47 -20
  106. data/vendor/faiss/faiss/impl/mapped_io.cpp +9 -2
  107. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +7 -2
  108. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +11 -3
  109. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +19 -13
  110. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +29 -21
  111. data/vendor/faiss/faiss/impl/{code_distance/code_distance-avx2.h → pq_code_distance/pq_code_distance-avx2.cpp} +42 -215
  112. data/vendor/faiss/faiss/impl/{code_distance/code_distance-avx512.h → pq_code_distance/pq_code_distance-avx512.cpp} +68 -107
  113. data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-generic.cpp +141 -0
  114. data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-inl.h +23 -0
  115. data/vendor/faiss/faiss/impl/{code_distance/code_distance-sve.h → pq_code_distance/pq_code_distance-sve.cpp} +57 -144
  116. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +9 -6
  117. data/vendor/faiss/faiss/impl/scalar_quantizer/codecs.h +121 -0
  118. data/vendor/faiss/faiss/impl/scalar_quantizer/distance_computers.h +136 -0
  119. data/vendor/faiss/faiss/impl/scalar_quantizer/quantizers.h +280 -0
  120. data/vendor/faiss/faiss/impl/scalar_quantizer/scanners.h +164 -0
  121. data/vendor/faiss/faiss/impl/scalar_quantizer/similarities.h +94 -0
  122. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx2.cpp +455 -0
  123. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx512.cpp +430 -0
  124. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-dispatch.h +329 -0
  125. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-neon.cpp +467 -0
  126. data/vendor/faiss/faiss/impl/scalar_quantizer/training.cpp +203 -0
  127. data/vendor/faiss/faiss/impl/scalar_quantizer/training.h +42 -0
  128. data/vendor/faiss/faiss/impl/simd_dispatch.h +139 -0
  129. data/vendor/faiss/faiss/impl/simd_result_handlers.h +18 -18
  130. data/vendor/faiss/faiss/index_factory.cpp +35 -16
  131. data/vendor/faiss/faiss/index_io.h +29 -3
  132. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +7 -4
  133. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +1 -1
  134. data/vendor/faiss/faiss/svs/IndexSVSFaissUtils.h +9 -19
  135. data/vendor/faiss/faiss/svs/IndexSVSFlat.h +2 -0
  136. data/vendor/faiss/faiss/svs/IndexSVSVamana.h +2 -1
  137. data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.cpp +9 -1
  138. data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.h +9 -0
  139. data/vendor/faiss/faiss/utils/Heap.cpp +46 -0
  140. data/vendor/faiss/faiss/utils/Heap.h +21 -0
  141. data/vendor/faiss/faiss/utils/NeuralNet.cpp +10 -7
  142. data/vendor/faiss/faiss/utils/distances.cpp +141 -23
  143. data/vendor/faiss/faiss/utils/distances.h +98 -0
  144. data/vendor/faiss/faiss/utils/distances_dispatch.h +170 -0
  145. data/vendor/faiss/faiss/utils/distances_simd.cpp +74 -3511
  146. data/vendor/faiss/faiss/utils/extra_distances-inl.h +164 -157
  147. data/vendor/faiss/faiss/utils/extra_distances.cpp +52 -95
  148. data/vendor/faiss/faiss/utils/extra_distances.h +47 -1
  149. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +0 -1
  150. data/vendor/faiss/faiss/utils/partitioning.cpp +1 -1
  151. data/vendor/faiss/faiss/utils/pq_code_distance.h +251 -0
  152. data/vendor/faiss/faiss/utils/rabitq_simd.h +260 -0
  153. data/vendor/faiss/faiss/utils/simd_impl/distances_aarch64.cpp +150 -0
  154. data/vendor/faiss/faiss/utils/simd_impl/distances_arm_sve.cpp +568 -0
  155. data/vendor/faiss/faiss/utils/simd_impl/distances_autovec-inl.h +153 -0
  156. data/vendor/faiss/faiss/utils/simd_impl/distances_avx2.cpp +1185 -0
  157. data/vendor/faiss/faiss/utils/simd_impl/distances_avx512.cpp +1092 -0
  158. data/vendor/faiss/faiss/utils/simd_impl/distances_sse-inl.h +391 -0
  159. data/vendor/faiss/faiss/utils/simd_levels.cpp +322 -0
  160. data/vendor/faiss/faiss/utils/simd_levels.h +91 -0
  161. data/vendor/faiss/faiss/utils/simdlib_avx2.h +12 -1
  162. data/vendor/faiss/faiss/utils/simdlib_avx512.h +69 -0
  163. data/vendor/faiss/faiss/utils/simdlib_neon.h +6 -0
  164. data/vendor/faiss/faiss/utils/sorting.cpp +4 -4
  165. data/vendor/faiss/faiss/utils/utils.cpp +16 -9
  166. metadata +47 -18
  167. data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +0 -81
  168. data/vendor/faiss/faiss/impl/code_distance/code_distance.h +0 -186
  169. /data/vendor/faiss/faiss/{cppcontrib/factory_tools.h → factory_tools.h} +0 -0
@@ -0,0 +1,391 @@
1
+ /*
2
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ #pragma once
9
+
10
+ #include <faiss/utils/distances.h>
11
+
12
+ #include <immintrin.h>
13
+
14
+ namespace faiss {
15
+
/// Computes c[i] = a[i] + bf * b[i], four floats per step.
///
/// a, b and c are dereferenced as __m128, so they must be 16-byte aligned.
/// Only the first (n / 4) * 4 elements are processed; any remainder of
/// n % 4 trailing elements is left untouched.
[[maybe_unused]] inline void fvec_madd_sse(
        size_t n,
        const float* a,
        float bf,
        const float* b,
        float* c) {
    const __m128 scale = _mm_set_ps1(bf);
    const __m128* src_a = (const __m128*)a;
    const __m128* src_b = (const __m128*)b;
    __m128* dst = (__m128*)c;

    for (size_t blocks = n >> 2; blocks > 0; --blocks) {
        *dst = _mm_add_ps(*src_a, _mm_mul_ps(scale, *src_b));
        ++src_a;
        ++src_b;
        ++dst;
    }
}
35
+
/// Sums the four lanes of v = [x0, x1, x2, x3].
/// The additions are grouped as (x0 + x2) + (x1 + x3).
inline float horizontal_sum(const __m128 v) {
    // upper is [x2, x3, x2, x3]
    const __m128 upper = _mm_movehl_ps(v, v);
    // pairs is [x0 + x2, x1 + x3, ..., ...]
    const __m128 pairs = _mm_add_ps(v, upper);
    // odd is [x1 + x3, ..., ..., ...]
    const __m128 odd = _mm_shuffle_ps(pairs, pairs, _MM_SHUFFLE(0, 0, 0, 1));
    // lane 0 of the sum is (x0 + x2) + (x1 + x3)
    return _mm_cvtss_f32(_mm_add_ss(pairs, odd));
}
51
+
/// Per-component operation for inner products: op(x, y) = x * y.
/// Provided for a single float and for four packed floats.
struct ElementOpIP {
    /// Packed variant: lane-wise product of x and y.
    static __m128 op(__m128 x, __m128 y) {
        return _mm_mul_ps(x, y);
    }

    /// Scalar variant.
    static float op(float x, float y) {
        const float product = x * y;
        return product;
    }
};
63
+
/// Per-component operation for squared L2 distances:
/// op(x, y) = (x - y) * (x - y). Like ElementOpIP, it is meant to be passed
/// as the ElementOp parameter of the fvec_op_ny_D* templates.
struct ElementOpL2 {
    /// Packed variant: lane-wise squared difference.
    static __m128 op(__m128 x, __m128 y) {
        const __m128 diff = _mm_sub_ps(x, y);
        return _mm_mul_ps(diff, diff);
    }

    /// Scalar variant.
    static float op(float x, float y) {
        const float diff = x - y;
        return diff * diff;
    }
};
78
+
/// Dimension-1 kernel: dis[i] = ElementOp::op(x[0], y[i]) for i in [0, ny).
/// Four y values are handled per SIMD step; the tail uses the scalar op.
template <class ElementOp>
void fvec_op_ny_D1(float* dis, const float* x, const float* y, size_t ny) {
    const float x_scalar = x[0];
    const __m128 x_packed = _mm_set1_ps(x_scalar);

    size_t i = 0;
    for (; i + 3 < ny; i += 4, y += 4) {
        const __m128 res = ElementOp::op(x_packed, _mm_loadu_ps(y));
        // lanes 0..3 go to dis[i..i+3]
        _mm_storeu_ps(dis + i, res);
    }
    // handle non-multiple-of-4 case
    for (; i < ny; ++i, ++y) {
        dis[i] = ElementOp::op(x_scalar, *y);
    }
}
100
+
/// Dimension-2 kernel: dis[i] = op(x[0], y[2i]) + op(x[1], y[2i + 1]).
/// Two y vectors (four floats) are handled per SIMD step; an odd last
/// vector is handled with the scalar op.
template <class ElementOp>
void fvec_op_ny_D2(float* dis, const float* x, const float* y, size_t ny) {
    // x duplicated so that one register lines up with two y vectors
    const __m128 x_twice = _mm_set_ps(x[1], x[0], x[1], x[0]);

    const size_t n_pairs = ny / 2;
    for (size_t p = 0; p < n_pairs; ++p, y += 4) {
        __m128 sums = ElementOp::op(x_twice, _mm_loadu_ps(y));
        // after hadd: lane 0 and lane 3 hold the two per-vector sums
        sums = _mm_hadd_ps(sums, sums);
        dis[2 * p] = _mm_cvtss_f32(sums);
        sums = _mm_shuffle_ps(sums, sums, 3);
        dis[2 * p + 1] = _mm_cvtss_f32(sums);
    }
    if ((ny & 1) != 0) { // handle odd case
        dis[ny - 1] = ElementOp::op(x[0], y[0]) + ElementOp::op(x[1], y[1]);
    }
}
118
+
119
+ template <class ElementOp>
120
+ void fvec_op_ny_D4(float* dis, const float* x, const float* y, size_t ny) {
121
+ __m128 x0 = _mm_loadu_ps(x);
122
+
123
+ for (size_t i = 0; i < ny; i++) {
124
+ __m128 accu = ElementOp::op(x0, _mm_loadu_ps(y));
125
+ y += 4;
126
+ dis[i] = horizontal_sum(accu);
127
+ }
128
+ }
129
+
/// Dimension-8 kernel: dis[i] is the sum over 8 components of
/// ElementOp::op(x[j], y[8i + j]). The two 4-float halves are added
/// lane-wise first, then reduced with two horizontal adds.
template <class ElementOp>
void fvec_op_ny_D8(float* dis, const float* x, const float* y, size_t ny) {
    const __m128 x_lo = _mm_loadu_ps(x);
    const __m128 x_hi = _mm_loadu_ps(x + 4);

    for (size_t i = 0; i < ny; ++i, y += 8) {
        const __m128 lo_terms = ElementOp::op(x_lo, _mm_loadu_ps(y));
        const __m128 hi_terms = ElementOp::op(x_hi, _mm_loadu_ps(y + 4));
        __m128 total = _mm_add_ps(lo_terms, hi_terms);
        total = _mm_hadd_ps(total, total);
        total = _mm_hadd_ps(total, total);
        dis[i] = _mm_cvtss_f32(total);
    }
}
145
+
146
+ template <class ElementOp>
147
+ void fvec_op_ny_D12(float* dis, const float* x, const float* y, size_t ny) {
148
+ __m128 x0 = _mm_loadu_ps(x);
149
+ __m128 x1 = _mm_loadu_ps(x + 4);
150
+ __m128 x2 = _mm_loadu_ps(x + 8);
151
+
152
+ for (size_t i = 0; i < ny; i++) {
153
+ __m128 accu = ElementOp::op(x0, _mm_loadu_ps(y));
154
+ y += 4;
155
+ accu = _mm_add_ps(accu, ElementOp::op(x1, _mm_loadu_ps(y)));
156
+ y += 4;
157
+ accu = _mm_add_ps(accu, ElementOp::op(x2, _mm_loadu_ps(y)));
158
+ y += 4;
159
+ dis[i] = horizontal_sum(accu);
160
+ }
161
+ }
162
+
163
+ template <class ElementOpIP>
164
+ void fvec_inner_products_ny_ref(
165
+ float* dis,
166
+ const float* x,
167
+ const float* y,
168
+ size_t d,
169
+ size_t ny) {
170
+ #define DISPATCH(dval) \
171
+ case dval: \
172
+ fvec_op_ny_D##dval<ElementOpIP>(dis, x, y, ny); \
173
+ return;
174
+
175
+ switch (d) {
176
+ DISPATCH(1)
177
+ DISPATCH(2)
178
+ DISPATCH(4)
179
+ DISPATCH(8)
180
+ DISPATCH(12)
181
+ default:
182
+ fvec_inner_products_ny<SIMDLevel::NONE>(dis, x, y, d, ny);
183
+ return;
184
+ }
185
+ #undef DISPATCH
186
+ }
187
+
188
+ template <class ElementOpL2>
189
+ void fvec_L2sqr_ny_ref(
190
+ float* dis,
191
+ const float* x,
192
+ const float* y,
193
+ size_t d,
194
+ size_t ny) {
195
+ // optimized for a few special cases
196
+
197
+ #define DISPATCH(dval) \
198
+ case dval: \
199
+ fvec_op_ny_D##dval<ElementOpL2>(dis, x, y, ny); \
200
+ return;
201
+
202
+ switch (d) {
203
+ DISPATCH(1)
204
+ DISPATCH(2)
205
+ DISPATCH(4)
206
+ DISPATCH(8)
207
+ DISPATCH(12)
208
+ default:
209
+ fvec_L2sqr_ny<SIMDLevel::NONE>(dis, x, y, d, ny);
210
+ return;
211
+ }
212
+ #undef DISPATCH
213
+ }
214
+
215
+ template <SIMDLevel>
216
+ size_t fvec_L2sqr_ny_nearest_D2(
217
+ float* distances_tmp_buffer,
218
+ const float* x,
219
+ const float* y,
220
+ size_t ny);
221
+
222
+ template <SIMDLevel>
223
+ size_t fvec_L2sqr_ny_nearest_D4(
224
+ float* distances_tmp_buffer,
225
+ const float* x,
226
+ const float* y,
227
+ size_t ny);
228
+
229
+ template <SIMDLevel>
230
+ size_t fvec_L2sqr_ny_nearest_D8(
231
+ float* distances_tmp_buffer,
232
+ const float* x,
233
+ const float* y,
234
+ size_t ny);
235
+
236
+ template <SIMDLevel SIMD>
237
+ size_t fvec_L2sqr_ny_nearest_x86(
238
+ float* distances_tmp_buffer,
239
+ const float* x,
240
+ const float* y,
241
+ size_t d,
242
+ size_t ny,
243
+ size_t (*fvec_L2sqr_ny_nearest_D2_func)(
244
+ float*,
245
+ const float*,
246
+ const float*,
247
+ size_t) = &fvec_L2sqr_ny_nearest_D2<SIMD>,
248
+ size_t (*fvec_L2sqr_ny_nearest_D4_func)(
249
+ float*,
250
+ const float*,
251
+ const float*,
252
+ size_t) = &fvec_L2sqr_ny_nearest_D4<SIMD>,
253
+ size_t (*fvec_L2sqr_ny_nearest_D8_func)(
254
+ float*,
255
+ const float*,
256
+ const float*,
257
+ size_t) = &fvec_L2sqr_ny_nearest_D8<SIMD>);
258
+
259
+ template <SIMDLevel SIMD>
260
+ size_t fvec_L2sqr_ny_nearest_x86(
261
+ float* distances_tmp_buffer,
262
+ const float* x,
263
+ const float* y,
264
+ size_t d,
265
+ size_t ny,
266
+ size_t (*fvec_L2sqr_ny_nearest_D2_func)(
267
+ float*,
268
+ const float*,
269
+ const float*,
270
+ size_t),
271
+ size_t (*fvec_L2sqr_ny_nearest_D4_func)(
272
+ float*,
273
+ const float*,
274
+ const float*,
275
+ size_t),
276
+ size_t (*fvec_L2sqr_ny_nearest_D8_func)(
277
+ float*,
278
+ const float*,
279
+ const float*,
280
+ size_t)) {
281
+ switch (d) {
282
+ case 2:
283
+ return fvec_L2sqr_ny_nearest_D2_func(
284
+ distances_tmp_buffer, x, y, ny);
285
+ case 4:
286
+ return fvec_L2sqr_ny_nearest_D4_func(
287
+ distances_tmp_buffer, x, y, ny);
288
+ case 8:
289
+ return fvec_L2sqr_ny_nearest_D8_func(
290
+ distances_tmp_buffer, x, y, ny);
291
+ }
292
+
293
+ return fvec_L2sqr_ny_nearest<SIMDLevel::NONE>(
294
+ distances_tmp_buffer, x, y, d, ny);
295
+ }
296
+
297
+ template <SIMDLevel SIMD>
298
+ inline size_t fvec_L2sqr_ny_nearest(
299
+ float* distances_tmp_buffer,
300
+ const float* x,
301
+ const float* y,
302
+ size_t d,
303
+ size_t ny);
304
+
/// Computes c[i] = a[i] + bf * b[i] four floats at a time and returns the
/// index of the smallest c[i].
///
/// a, b and c are dereferenced as __m128, so they must be 16-byte aligned,
/// and only the first (n / 4) * 4 elements are processed. The running
/// minimum starts at 1e20 with index -1, so -1 is returned when no computed
/// value is below 1e20 (including n < 4).
inline int fvec_madd_and_argmin_sse_ref(
        size_t n,
        const float* a,
        float bf,
        const float* b,
        float* c) {
    const __m128 scale = _mm_set_ps1(bf);
    const __m128i lane_step = _mm_set1_epi32(4);
    const __m128* src_a = (const __m128*)a;
    const __m128* src_b = (const __m128*)b;
    __m128* dst = (__m128*)c;

    // per-lane running minimum and the index at which it was seen
    __m128 best_val = _mm_set_ps1(1e20);
    __m128i best_idx = _mm_set1_epi32(-1);
    // element indices of the four lanes in the current step
    __m128i cur_idx = _mm_set_epi32(3, 2, 1, 0);

    for (size_t blocks = n >> 2; blocks > 0; --blocks) {
        const __m128 val = _mm_add_ps(*src_a, _mm_mul_ps(scale, *src_b));
        *dst = val;
        // lanes where the new value is strictly smaller
        const __m128i take = _mm_castps_si128(_mm_cmpgt_ps(best_val, val));
        // and/andnot/or select; the original notes _mm_blendv_epi8 is slower
        best_idx = _mm_or_si128(
                _mm_and_si128(take, cur_idx),
                _mm_andnot_si128(take, best_idx));
        best_val = _mm_min_ps(best_val, val);
        ++src_b;
        ++src_a;
        ++dst;
        cur_idx = _mm_add_epi32(cur_idx, lane_step);
    }

    // 4 values -> 2: compare lanes {0, 1} against lanes {2, 3}
    {
        const __m128i other_idx = _mm_shuffle_epi32(best_idx, 3 << 2 | 2);
        const __m128 other_val = _mm_shuffle_ps(best_val, best_val, 3 << 2 | 2);
        const __m128i take =
                _mm_castps_si128(_mm_cmpgt_ps(best_val, other_val));
        best_idx = _mm_or_si128(
                _mm_and_si128(take, other_idx),
                _mm_andnot_si128(take, best_idx));
        best_val = _mm_min_ps(best_val, other_val);
    }
    // 2 values -> 1: compare lane 0 against lane 1
    {
        const __m128i other_idx = _mm_shuffle_epi32(best_idx, 1);
        const __m128 other_val = _mm_shuffle_ps(best_val, best_val, 1);
        const __m128i take =
                _mm_castps_si128(_mm_cmpgt_ps(best_val, other_val));
        best_idx = _mm_or_si128(
                _mm_and_si128(take, other_idx),
                _mm_andnot_si128(take, best_idx));
        // best_val is not needed any more, only the index is returned
    }
    return _mm_cvtsi128_si32(best_idx);
}
356
+
357
+ inline int fvec_madd_and_argmin_sse(
358
+ size_t n,
359
+ const float* a,
360
+ float bf,
361
+ const float* b,
362
+ float* c) {
363
+ if ((n & 3) == 0 && ((((long)a) | ((long)b) | ((long)c)) & 15) == 0) {
364
+ return fvec_madd_and_argmin_sse_ref(n, a, bf, b, c);
365
+ } else {
366
+ return fvec_madd_and_argmin<SIMDLevel::NONE>(n, a, bf, b, c);
367
+ }
368
+ }
369
+
/// Reads the first d floats of x (0 <= d < 4) into lanes 0 .. d-1 of an
/// __m128; the remaining lanes are zero. Elements x[d] and beyond are never
/// accessed.
inline __m128 masked_read(int d, const float* x) {
    assert(0 <= d && d < 4);
    // zero-filled, 16-byte aligned staging buffer for the aligned load below
    alignas(16) float staged[4] = {0, 0, 0, 0};
    switch (d) {
        case 3:
            staged[2] = x[2];
            [[fallthrough]];
        case 2:
            staged[1] = x[1];
            [[fallthrough]];
        case 1:
            staged[0] = x[0];
            break;
        default:
            // d == 0: nothing to copy
            break;
    }
    // cannot use AVX2 _mm_mask_set1_epi32
    return _mm_load_ps(staged);
}
390
+
391
+ } // namespace faiss
@@ -0,0 +1,322 @@
1
+ /*
2
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ #include <faiss/utils/simd_levels.h>
9
+
10
+ #include <cstdlib>
11
+
12
+ #include <faiss/impl/FaissAssert.h>
13
+
14
+ namespace faiss {
15
+
16
+ // Static member definitions - used in both DD and static modes
17
+ SIMDLevel SIMDConfig::level = SIMDLevel::NONE;
18
+
19
+ // Bitmask of supported SIMD levels (1 << SIMDLevel)
20
+ uint64_t SIMDConfig::supported_simd_levels = 0;
21
+
// ARM SVE runtime detection.
//
// Only aarch64 Linux performs a real query (HWCAP_SVE in the AT_HWCAP
// auxiliary vector). Apple Silicon does NOT support SVE, other aarch64
// platforms conservatively report no SVE, and non-ARM64 targets always
// report false.
#if (defined(__aarch64__) || defined(_M_ARM64)) && defined(__linux__)
#include <sys/auxv.h>
#ifndef HWCAP_SVE
#define HWCAP_SVE (1 << 22)
#endif
#endif

static bool has_sve() {
#if (defined(__aarch64__) || defined(_M_ARM64)) && defined(__linux__)
    return (getauxval(AT_HWCAP) & HWCAP_SVE) != 0;
#else
    return false;
#endif
}
54
+
55
+ #ifdef FAISS_ENABLE_DD
56
+
57
+ // =============================================================================
58
+ // Dynamic Dispatch (DD) mode implementation
59
+ // =============================================================================
60
+
61
+ // Static initializer to run constructor at load time
62
+ // NOLINTNEXTLINE(facebook-avoid-non-const-global-variables)
63
+ static SIMDConfig simd_config_initializer;
64
+
65
+ SIMDConfig::SIMDConfig(const char** faiss_simd_level_env) {
66
+ // Support dependency injection for testing
67
+ const char* env_var = faiss_simd_level_env ? *faiss_simd_level_env
68
+ : getenv("FAISS_SIMD_LEVEL");
69
+
70
+ if (!env_var) {
71
+ level = auto_detect_simd_level();
72
+ } else {
73
+ level = to_simd_level(env_var);
74
+ supported_simd_levels = (1 << static_cast<int>(level));
75
+ }
76
+ supported_simd_levels |= (1 << static_cast<int>(SIMDLevel::NONE));
77
+ }
78
+
79
+ void SIMDConfig::set_level(SIMDLevel l) {
80
+ if (!is_simd_level_available(l)) {
81
+ FAISS_THROW_FMT(
82
+ "SIMDConfig::set_level: level %s is not available",
83
+ to_string(l).c_str());
84
+ }
85
+ level = l;
86
+ }
87
+
88
+ SIMDLevel SIMDConfig::get_level() {
89
+ return level;
90
+ }
91
+
92
+ std::string SIMDConfig::get_level_name() {
93
+ return to_string(level);
94
+ }
95
+
96
+ bool SIMDConfig::is_simd_level_available(SIMDLevel l) {
97
+ return (supported_simd_levels & (1 << static_cast<int>(l))) != 0;
98
+ }
99
+
100
+ SIMDLevel SIMDConfig::auto_detect_simd_level() {
101
+ SIMDLevel detected_level = SIMDLevel::NONE;
102
+
103
+ #if defined(__x86_64__) && \
104
+ (defined(COMPILE_SIMD_AVX2) || defined(COMPILE_SIMD_AVX512))
105
+ unsigned int eax, ebx, ecx, edx;
106
+
107
+ eax = 1;
108
+ ecx = 0;
109
+ asm volatile("cpuid"
110
+ : "=a"(eax), "=b"(ebx), "=c"(ecx), "=d"(edx)
111
+ : "a"(eax), "c"(ecx));
112
+
113
+ bool has_avx = (ecx & (1 << 28)) != 0;
114
+
115
+ bool has_xsave_osxsave =
116
+ (ecx & ((1 << 26) | (1 << 27))) == ((1 << 26) | (1 << 27));
117
+
118
+ bool avx_supported = false;
119
+ if (has_avx && has_xsave_osxsave) {
120
+ unsigned int xcr0;
121
+ asm volatile("xgetbv" : "=a"(xcr0), "=d"(edx) : "c"(0));
122
+ avx_supported = (xcr0 & 6) == 6;
123
+ }
124
+
125
+ if (avx_supported) {
126
+ eax = 7;
127
+ ecx = 0;
128
+ asm volatile("cpuid"
129
+ : "=a"(eax), "=b"(ebx), "=c"(ecx), "=d"(edx)
130
+ : "a"(eax), "c"(ecx));
131
+
132
+ unsigned int xcr0;
133
+ asm volatile("xgetbv" : "=a"(xcr0), "=d"(edx) : "c"(0));
134
+
135
+ #if defined(COMPILE_SIMD_AVX2) || defined(COMPILE_SIMD_AVX512)
136
+ bool has_avx2 = (ebx & (1 << 5)) != 0;
137
+ if (has_avx2) {
138
+ supported_simd_levels |= (1 << static_cast<int>(SIMDLevel::AVX2));
139
+ detected_level = SIMDLevel::AVX2;
140
+ }
141
+
142
+ #if defined(COMPILE_SIMD_AVX512)
143
+ bool cpu_has_avx512f = (ebx & (1 << 16)) != 0;
144
+ bool os_supports_avx512 = (xcr0 & 0xE0) == 0xE0;
145
+ bool has_avx512f = cpu_has_avx512f && os_supports_avx512;
146
+ if (has_avx512f) {
147
+ bool has_avx512cd = (ebx & (1 << 28)) != 0;
148
+ bool has_avx512vl = (ebx & (1 << 31)) != 0;
149
+ bool has_avx512dq = (ebx & (1 << 17)) != 0;
150
+ bool has_avx512bw = (ebx & (1 << 30)) != 0;
151
+ if (has_avx512bw && has_avx512cd && has_avx512vl && has_avx512dq) {
152
+ detected_level = SIMDLevel::AVX512;
153
+ supported_simd_levels |=
154
+ (1 << static_cast<int>(SIMDLevel::AVX512));
155
+
156
+ #if defined(COMPILE_SIMD_AVX512_SPR)
157
+ // Check for Sapphire Rapids features (AVX512_BF16)
158
+ // CPUID EAX=7, ECX=1: EAX bit 5 = AVX512_BF16
159
+ unsigned int eax1, ebx1, ecx1, edx1;
160
+ eax1 = 7;
161
+ ecx1 = 1;
162
+ asm volatile("cpuid"
163
+ : "=a"(eax1), "=b"(ebx1), "=c"(ecx1), "=d"(edx1)
164
+ : "a"(eax1), "c"(ecx1));
165
+ bool has_avx512_bf16 = (eax1 & (1 << 5)) != 0;
166
+ if (has_avx512_bf16) {
167
+ detected_level = SIMDLevel::AVX512_SPR;
168
+ supported_simd_levels |=
169
+ (1 << static_cast<int>(SIMDLevel::AVX512_SPR));
170
+ }
171
+ #endif // defined(COMPILE_SIMD_AVX512_SPR)
172
+ }
173
+ }
174
+ #endif // defined(COMPILE_SIMD_AVX512)
175
+ #endif // defined(COMPILE_SIMD_AVX2) || defined(COMPILE_SIMD_AVX512)
176
+ }
177
+ #endif // defined(__x86_64__) && ...
178
+
179
+ #ifdef COMPILE_SIMD_ARM_NEON
180
+ // ARM NEON is standard on aarch64
181
+ supported_simd_levels |= (1 << static_cast<int>(SIMDLevel::ARM_NEON));
182
+ detected_level = SIMDLevel::ARM_NEON;
183
+ #endif
184
+
185
+ #ifdef COMPILE_SIMD_ARM_SVE
186
+ if (has_sve()) {
187
+ supported_simd_levels |= (1 << static_cast<int>(SIMDLevel::ARM_SVE));
188
+ detected_level = SIMDLevel::ARM_SVE;
189
+ }
190
+ #endif
191
+
192
+ return detected_level;
193
+ }
194
+
195
+ // Include private header for DISPATCH_SIMDLevel macro
196
+ #include <faiss/impl/simd_dispatch.h>
197
+
198
+ namespace {
199
+
200
+ template <SIMDLevel Level>
201
+ SIMDLevel get_dispatched_level_impl() {
202
+ return Level;
203
+ }
204
+
205
+ } // namespace
206
+
207
+ SIMDLevel SIMDConfig::get_dispatched_level() {
208
+ DISPATCH_SIMDLevel(get_dispatched_level_impl);
209
+ }
210
+
211
+ #else // Static mode
212
+
213
+ // =============================================================================
214
+ // Static mode implementation
215
+ // =============================================================================
216
+
217
+ // Static initializer to set up the single supported level
218
+ // NOLINTNEXTLINE(facebook-avoid-non-const-global-variables)
219
+ static SIMDConfig simd_config_initializer;
220
+
221
+ SIMDConfig::SIMDConfig(const char** /* faiss_simd_level_env */) {
222
+ // In static mode, the level is fixed at compile time
223
+ level = auto_detect_simd_level();
224
+ supported_simd_levels = (1 << static_cast<int>(level));
225
+ }
226
+
227
+ void SIMDConfig::set_level(SIMDLevel l) {
228
+ if (!is_simd_level_available(l)) {
229
+ FAISS_THROW_FMT(
230
+ "SIMDConfig::set_level: level %s is not available "
231
+ "(static build only supports %s)",
232
+ to_string(l).c_str(),
233
+ to_string(level).c_str());
234
+ }
235
+ // In static mode, setting to the same level is a no-op
236
+ level = l;
237
+ }
238
+
239
+ SIMDLevel SIMDConfig::get_level() {
240
+ return level;
241
+ }
242
+
243
+ std::string SIMDConfig::get_level_name() {
244
+ return to_string(level);
245
+ }
246
+
247
+ bool SIMDConfig::is_simd_level_available(SIMDLevel l) {
248
+ return (supported_simd_levels & (1 << static_cast<int>(l))) != 0;
249
+ }
250
+
251
+ SIMDLevel SIMDConfig::auto_detect_simd_level() {
252
+ // In static mode, return the compiled-in level
253
+ #if defined(COMPILE_SIMD_AVX512_SPR)
254
+ return SIMDLevel::AVX512_SPR;
255
+ #elif defined(COMPILE_SIMD_AVX512)
256
+ return SIMDLevel::AVX512;
257
+ #elif defined(COMPILE_SIMD_AVX2)
258
+ return SIMDLevel::AVX2;
259
+ #elif defined(COMPILE_SIMD_ARM_SVE)
260
+ return SIMDLevel::ARM_SVE;
261
+ #elif defined(COMPILE_SIMD_ARM_NEON)
262
+ return SIMDLevel::ARM_NEON;
263
+ #else
264
+ return SIMDLevel::NONE;
265
+ #endif
266
+ }
267
+
268
+ SIMDLevel SIMDConfig::get_dispatched_level() {
269
+ // In static mode, just return the current level (no dispatch)
270
+ return get_level();
271
+ }
272
+
273
+ #endif // FAISS_ENABLE_DD
274
+
275
+ // =============================================================================
276
+ // Common functions (both modes)
277
+ // =============================================================================
278
+
279
+ std::string to_string(SIMDLevel level) {
280
+ switch (level) {
281
+ case SIMDLevel::NONE:
282
+ return "NONE";
283
+ case SIMDLevel::AVX2:
284
+ return "AVX2";
285
+ case SIMDLevel::AVX512:
286
+ return "AVX512";
287
+ case SIMDLevel::AVX512_SPR:
288
+ return "AVX512_SPR";
289
+ case SIMDLevel::ARM_NEON:
290
+ return "ARM_NEON";
291
+ case SIMDLevel::ARM_SVE:
292
+ return "ARM_SVE";
293
+ case SIMDLevel::COUNT:
294
+ default:
295
+ throw FaissException("Invalid SIMDLevel");
296
+ }
297
+ }
298
+
299
+ SIMDLevel to_simd_level(const std::string& level_str) {
300
+ if (level_str == "NONE") {
301
+ return SIMDLevel::NONE;
302
+ }
303
+ if (level_str == "AVX2") {
304
+ return SIMDLevel::AVX2;
305
+ }
306
+ if (level_str == "AVX512") {
307
+ return SIMDLevel::AVX512;
308
+ }
309
+ if (level_str == "AVX512_SPR") {
310
+ return SIMDLevel::AVX512_SPR;
311
+ }
312
+ if (level_str == "ARM_NEON") {
313
+ return SIMDLevel::ARM_NEON;
314
+ }
315
+ if (level_str == "ARM_SVE") {
316
+ return SIMDLevel::ARM_SVE;
317
+ }
318
+
319
+ throw FaissException("Invalid SIMD level string: " + level_str);
320
+ }
321
+
322
+ } // namespace faiss