faiss 0.5.3 → 0.6.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (167)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +8 -0
  3. data/ext/faiss/ext.cpp +1 -1
  4. data/ext/faiss/extconf.rb +5 -6
  5. data/ext/faiss/index_binary.cpp +38 -28
  6. data/ext/faiss/{index.cpp → index_rb.cpp} +64 -46
  7. data/ext/faiss/kmeans.cpp +10 -9
  8. data/ext/faiss/pca_matrix.cpp +10 -8
  9. data/ext/faiss/product_quantizer.cpp +14 -12
  10. data/ext/faiss/{utils.cpp → utils_rb.cpp} +5 -3
  11. data/ext/faiss/{utils.h → utils_rb.h} +4 -0
  12. data/lib/faiss/version.rb +1 -1
  13. data/lib/faiss.rb +1 -1
  14. data/vendor/faiss/faiss/AutoTune.cpp +130 -11
  15. data/vendor/faiss/faiss/AutoTune.h +14 -1
  16. data/vendor/faiss/faiss/Clustering.cpp +59 -10
  17. data/vendor/faiss/faiss/Clustering.h +12 -0
  18. data/vendor/faiss/faiss/IVFlib.cpp +31 -28
  19. data/vendor/faiss/faiss/Index.cpp +20 -8
  20. data/vendor/faiss/faiss/Index.h +25 -3
  21. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +19 -24
  22. data/vendor/faiss/faiss/IndexBinary.cpp +1 -0
  23. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +9 -4
  24. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +45 -11
  25. data/vendor/faiss/faiss/IndexFastScan.cpp +35 -22
  26. data/vendor/faiss/faiss/IndexFastScan.h +10 -1
  27. data/vendor/faiss/faiss/IndexFlat.cpp +193 -136
  28. data/vendor/faiss/faiss/IndexFlat.h +16 -1
  29. data/vendor/faiss/faiss/IndexFlatCodes.cpp +46 -22
  30. data/vendor/faiss/faiss/IndexFlatCodes.h +7 -1
  31. data/vendor/faiss/faiss/IndexHNSW.cpp +24 -50
  32. data/vendor/faiss/faiss/IndexHNSW.h +14 -12
  33. data/vendor/faiss/faiss/IndexIDMap.cpp +1 -1
  34. data/vendor/faiss/faiss/IndexIVF.cpp +76 -49
  35. data/vendor/faiss/faiss/IndexIVF.h +14 -4
  36. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +11 -8
  37. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +2 -2
  38. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +25 -14
  39. data/vendor/faiss/faiss/IndexIVFFastScan.h +26 -22
  40. data/vendor/faiss/faiss/IndexIVFFlat.cpp +10 -61
  41. data/vendor/faiss/faiss/IndexIVFFlatPanorama.cpp +39 -111
  42. data/vendor/faiss/faiss/IndexIVFPQ.cpp +89 -147
  43. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +37 -5
  44. data/vendor/faiss/faiss/IndexIVFPQR.cpp +2 -1
  45. data/vendor/faiss/faiss/IndexIVFRaBitQ.cpp +42 -30
  46. data/vendor/faiss/faiss/IndexIVFRaBitQ.h +2 -2
  47. data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.cpp +246 -97
  48. data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.h +32 -29
  49. data/vendor/faiss/faiss/IndexLSH.cpp +8 -6
  50. data/vendor/faiss/faiss/IndexLattice.cpp +29 -24
  51. data/vendor/faiss/faiss/IndexNNDescent.cpp +1 -0
  52. data/vendor/faiss/faiss/IndexNSG.cpp +2 -1
  53. data/vendor/faiss/faiss/IndexNSG.h +0 -2
  54. data/vendor/faiss/faiss/IndexNeuralNetCodec.cpp +1 -1
  55. data/vendor/faiss/faiss/IndexPQ.cpp +19 -10
  56. data/vendor/faiss/faiss/IndexRaBitQ.cpp +26 -13
  57. data/vendor/faiss/faiss/IndexRaBitQ.h +2 -2
  58. data/vendor/faiss/faiss/IndexRaBitQFastScan.cpp +132 -78
  59. data/vendor/faiss/faiss/IndexRaBitQFastScan.h +14 -12
  60. data/vendor/faiss/faiss/IndexRefine.cpp +0 -30
  61. data/vendor/faiss/faiss/IndexShards.cpp +3 -4
  62. data/vendor/faiss/faiss/MetricType.h +16 -0
  63. data/vendor/faiss/faiss/VectorTransform.cpp +120 -0
  64. data/vendor/faiss/faiss/VectorTransform.h +23 -0
  65. data/vendor/faiss/faiss/clone_index.cpp +7 -4
  66. data/vendor/faiss/faiss/{cppcontrib/factory_tools.cpp → factory_tools.cpp} +1 -1
  67. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +1 -1
  68. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +37 -11
  69. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +0 -28
  70. data/vendor/faiss/faiss/impl/ClusteringInitialization.cpp +367 -0
  71. data/vendor/faiss/faiss/impl/ClusteringInitialization.h +107 -0
  72. data/vendor/faiss/faiss/impl/CodePacker.cpp +4 -0
  73. data/vendor/faiss/faiss/impl/CodePacker.h +11 -3
  74. data/vendor/faiss/faiss/impl/CodePackerRaBitQ.cpp +83 -0
  75. data/vendor/faiss/faiss/impl/CodePackerRaBitQ.h +47 -0
  76. data/vendor/faiss/faiss/impl/FaissAssert.h +60 -2
  77. data/vendor/faiss/faiss/impl/HNSW.cpp +25 -34
  78. data/vendor/faiss/faiss/impl/HNSW.h +8 -6
  79. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +34 -27
  80. data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -1
  81. data/vendor/faiss/faiss/impl/NSG.cpp +6 -5
  82. data/vendor/faiss/faiss/impl/NSG.h +17 -7
  83. data/vendor/faiss/faiss/impl/Panorama.cpp +53 -46
  84. data/vendor/faiss/faiss/impl/Panorama.h +22 -6
  85. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +16 -5
  86. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +70 -58
  87. data/vendor/faiss/faiss/impl/RaBitQUtils.cpp +92 -0
  88. data/vendor/faiss/faiss/impl/RaBitQUtils.h +93 -31
  89. data/vendor/faiss/faiss/impl/RaBitQuantizer.cpp +12 -28
  90. data/vendor/faiss/faiss/impl/RaBitQuantizer.h +3 -10
  91. data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.cpp +15 -41
  92. data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.h +0 -4
  93. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +14 -9
  94. data/vendor/faiss/faiss/impl/ResultHandler.h +131 -50
  95. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +67 -2358
  96. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +0 -2
  97. data/vendor/faiss/faiss/impl/VisitedTable.cpp +42 -0
  98. data/vendor/faiss/faiss/impl/VisitedTable.h +69 -0
  99. data/vendor/faiss/faiss/impl/expanded_scanners.h +158 -0
  100. data/vendor/faiss/faiss/impl/index_read.cpp +829 -471
  101. data/vendor/faiss/faiss/impl/index_read_utils.h +0 -1
  102. data/vendor/faiss/faiss/impl/index_write.cpp +17 -8
  103. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +47 -20
  104. data/vendor/faiss/faiss/impl/mapped_io.cpp +9 -2
  105. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +7 -2
  106. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +11 -3
  107. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +19 -13
  108. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +29 -21
  109. data/vendor/faiss/faiss/impl/{code_distance/code_distance-avx2.h → pq_code_distance/pq_code_distance-avx2.cpp} +42 -215
  110. data/vendor/faiss/faiss/impl/{code_distance/code_distance-avx512.h → pq_code_distance/pq_code_distance-avx512.cpp} +68 -107
  111. data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-generic.cpp +141 -0
  112. data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-inl.h +23 -0
  113. data/vendor/faiss/faiss/impl/{code_distance/code_distance-sve.h → pq_code_distance/pq_code_distance-sve.cpp} +57 -144
  114. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +9 -6
  115. data/vendor/faiss/faiss/impl/scalar_quantizer/codecs.h +121 -0
  116. data/vendor/faiss/faiss/impl/scalar_quantizer/distance_computers.h +136 -0
  117. data/vendor/faiss/faiss/impl/scalar_quantizer/quantizers.h +280 -0
  118. data/vendor/faiss/faiss/impl/scalar_quantizer/scanners.h +164 -0
  119. data/vendor/faiss/faiss/impl/scalar_quantizer/similarities.h +94 -0
  120. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx2.cpp +455 -0
  121. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx512.cpp +430 -0
  122. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-dispatch.h +329 -0
  123. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-neon.cpp +467 -0
  124. data/vendor/faiss/faiss/impl/scalar_quantizer/training.cpp +203 -0
  125. data/vendor/faiss/faiss/impl/scalar_quantizer/training.h +42 -0
  126. data/vendor/faiss/faiss/impl/simd_dispatch.h +139 -0
  127. data/vendor/faiss/faiss/impl/simd_result_handlers.h +18 -18
  128. data/vendor/faiss/faiss/index_factory.cpp +35 -16
  129. data/vendor/faiss/faiss/index_io.h +29 -3
  130. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +7 -4
  131. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +1 -1
  132. data/vendor/faiss/faiss/svs/IndexSVSFaissUtils.h +9 -19
  133. data/vendor/faiss/faiss/svs/IndexSVSFlat.h +2 -0
  134. data/vendor/faiss/faiss/svs/IndexSVSVamana.h +2 -1
  135. data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.cpp +9 -1
  136. data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.h +9 -0
  137. data/vendor/faiss/faiss/utils/Heap.cpp +46 -0
  138. data/vendor/faiss/faiss/utils/Heap.h +21 -0
  139. data/vendor/faiss/faiss/utils/NeuralNet.cpp +10 -7
  140. data/vendor/faiss/faiss/utils/distances.cpp +141 -23
  141. data/vendor/faiss/faiss/utils/distances.h +98 -0
  142. data/vendor/faiss/faiss/utils/distances_dispatch.h +170 -0
  143. data/vendor/faiss/faiss/utils/distances_simd.cpp +74 -3511
  144. data/vendor/faiss/faiss/utils/extra_distances-inl.h +164 -157
  145. data/vendor/faiss/faiss/utils/extra_distances.cpp +52 -95
  146. data/vendor/faiss/faiss/utils/extra_distances.h +47 -1
  147. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +0 -1
  148. data/vendor/faiss/faiss/utils/partitioning.cpp +1 -1
  149. data/vendor/faiss/faiss/utils/pq_code_distance.h +251 -0
  150. data/vendor/faiss/faiss/utils/rabitq_simd.h +260 -0
  151. data/vendor/faiss/faiss/utils/simd_impl/distances_aarch64.cpp +150 -0
  152. data/vendor/faiss/faiss/utils/simd_impl/distances_arm_sve.cpp +568 -0
  153. data/vendor/faiss/faiss/utils/simd_impl/distances_autovec-inl.h +153 -0
  154. data/vendor/faiss/faiss/utils/simd_impl/distances_avx2.cpp +1185 -0
  155. data/vendor/faiss/faiss/utils/simd_impl/distances_avx512.cpp +1092 -0
  156. data/vendor/faiss/faiss/utils/simd_impl/distances_sse-inl.h +391 -0
  157. data/vendor/faiss/faiss/utils/simd_levels.cpp +322 -0
  158. data/vendor/faiss/faiss/utils/simd_levels.h +91 -0
  159. data/vendor/faiss/faiss/utils/simdlib_avx2.h +12 -1
  160. data/vendor/faiss/faiss/utils/simdlib_avx512.h +69 -0
  161. data/vendor/faiss/faiss/utils/simdlib_neon.h +6 -0
  162. data/vendor/faiss/faiss/utils/sorting.cpp +4 -4
  163. data/vendor/faiss/faiss/utils/utils.cpp +16 -9
  164. metadata +47 -18
  165. data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +0 -81
  166. data/vendor/faiss/faiss/impl/code_distance/code_distance.h +0 -186
  167. data/vendor/faiss/faiss/{cppcontrib/factory_tools.h → factory_tools.h} +0 -0
@@ -642,6 +642,27 @@ void merge_knn_results(
642
642
  typename C::T* distances,
643
643
  idx_t* labels);
644
644
 
645
/** Refinement helper: for each of the n vectors, reduce its k_base input
 * (label, distance) pairs to k output pairs.
 *
 * NOTE(review): presumably requires k <= k_base — confirm against the
 * definition.
 *
 * @tparam C              comparator class; supplies the label type C::TI
 * @param n               number of vectors (rows) to process
 * @param k               number of output nearest neighbors per row
 * @param labels          output labels, size (n, k)
 * @param distances       output distances, size (n, k)
 * @param k_base          number of input nearest neighbors per row
 * @param base_labels     input labels, size (n, k_base)
 * @param base_distances  input distances, size (n, k_base)
 */
template <class C>
void reorder_2_heaps(
        int64_t n,
        int64_t k,
        typename C::TI* __restrict labels,
        float* __restrict distances,
        int64_t k_base,
        const typename C::TI* __restrict base_labels,
        const float* __restrict base_distances);
665
+
645
666
  } // namespace faiss
646
667
 
647
668
  #endif /* FAISS_Heap_h */
@@ -12,6 +12,7 @@
12
12
  #include <cstring>
13
13
 
14
14
  #include <faiss/impl/FaissAssert.h>
15
+ #include <faiss/impl/simd_dispatch.h>
15
16
  #include <faiss/utils/distances.h>
16
17
 
17
18
  /* declare BLAS functions, see http://www.netlib.org/clapack/cblas/ */
@@ -265,14 +266,16 @@ nn::Int32Tensor2D QINCoStep::encode(
265
266
  const float* db = zqs_r.data() + i * K * d;
266
267
  float dis_min = HUGE_VALF;
267
268
  int64_t idx = -1;
268
- for (size_t j = 0; j < K; j++) {
269
- float dis = fvec_L2sqr(q, db, d);
270
- if (dis < dis_min) {
271
- dis_min = dis;
272
- idx = j;
269
+ with_simd_level([&]<SIMDLevel SL>() {
270
+ for (size_t j = 0; j < K; j++) {
271
+ float dis = fvec_L2sqr<SL>(q, db, d);
272
+ if (dis < dis_min) {
273
+ dis_min = dis;
274
+ idx = j;
275
+ }
276
+ db += d;
273
277
  }
274
- db += d;
275
- }
278
+ });
276
279
  codes.v[i] = idx;
277
280
  if (res) {
278
281
  const float* xhat_row = xhat.data() + i * d;
@@ -27,6 +27,7 @@
27
27
  #include <faiss/impl/IDSelector.h>
28
28
  #include <faiss/impl/ResultHandler.h>
29
29
 
30
+ #include <faiss/utils/distances_dispatch.h>
30
31
  #include <faiss/utils/distances_fused/distances_fused.h>
31
32
 
32
33
  #ifndef FINTEGER
@@ -55,6 +56,122 @@ int sgemm_(
55
56
 
56
57
  namespace faiss {
57
58
 
59
+ /***************************************************************************
60
+ * Public API dispatch wrappers
61
+ ***************************************************************************/
62
+
63
+ float fvec_L1(const float* x, const float* y, size_t d) {
64
+ return fvec_L1_dispatch(x, y, d);
65
+ }
66
+
67
+ float fvec_Linf(const float* x, const float* y, size_t d) {
68
+ return fvec_Linf_dispatch(x, y, d);
69
+ }
70
+
71
+ float fvec_norm_L2sqr(const float* x, size_t d) {
72
+ return fvec_norm_L2sqr_dispatch(x, d);
73
+ }
74
+
75
+ float fvec_L2sqr(const float* x, const float* y, size_t d) {
76
+ return fvec_L2sqr_dispatch(x, y, d);
77
+ }
78
+
79
+ float fvec_inner_product(const float* x, const float* y, size_t d) {
80
+ return fvec_inner_product_dispatch(x, y, d);
81
+ }
82
+
83
+ void fvec_inner_product_batch_4(
84
+ const float* x,
85
+ const float* y0,
86
+ const float* y1,
87
+ const float* y2,
88
+ const float* y3,
89
+ const size_t d,
90
+ float& dis0,
91
+ float& dis1,
92
+ float& dis2,
93
+ float& dis3) {
94
+ fvec_inner_product_batch_4_dispatch(
95
+ x, y0, y1, y2, y3, d, dis0, dis1, dis2, dis3);
96
+ }
97
+
98
+ void fvec_L2sqr_batch_4(
99
+ const float* x,
100
+ const float* y0,
101
+ const float* y1,
102
+ const float* y2,
103
+ const float* y3,
104
+ const size_t d,
105
+ float& dis0,
106
+ float& dis1,
107
+ float& dis2,
108
+ float& dis3) {
109
+ fvec_L2sqr_batch_4_dispatch(x, y0, y1, y2, y3, d, dis0, dis1, dis2, dis3);
110
+ }
111
+
112
+ void fvec_L2sqr_ny_transposed(
113
+ float* dis,
114
+ const float* x,
115
+ const float* y,
116
+ const float* y_sqlen,
117
+ size_t d,
118
+ size_t d_offset,
119
+ size_t ny) {
120
+ fvec_L2sqr_ny_transposed_dispatch(dis, x, y, y_sqlen, d, d_offset, ny);
121
+ }
122
+
123
+ void fvec_inner_products_ny(
124
+ float* ip,
125
+ const float* x,
126
+ const float* y,
127
+ size_t d,
128
+ size_t ny) {
129
+ fvec_inner_products_ny_dispatch(ip, x, y, d, ny);
130
+ }
131
+
132
+ void fvec_L2sqr_ny(
133
+ float* dis,
134
+ const float* x,
135
+ const float* y,
136
+ size_t d,
137
+ size_t ny) {
138
+ fvec_L2sqr_ny_dispatch(dis, x, y, d, ny);
139
+ }
140
+
141
+ size_t fvec_L2sqr_ny_nearest(
142
+ float* distances_tmp_buffer,
143
+ const float* x,
144
+ const float* y,
145
+ size_t d,
146
+ size_t ny) {
147
+ return fvec_L2sqr_ny_nearest_dispatch(distances_tmp_buffer, x, y, d, ny);
148
+ }
149
+
150
+ size_t fvec_L2sqr_ny_nearest_y_transposed(
151
+ float* distances_tmp_buffer,
152
+ const float* x,
153
+ const float* y,
154
+ const float* y_sqlen,
155
+ size_t d,
156
+ size_t d_offset,
157
+ size_t ny) {
158
+ return fvec_L2sqr_ny_nearest_y_transposed_dispatch(
159
+ distances_tmp_buffer, x, y, y_sqlen, d, d_offset, ny);
160
+ }
161
+
162
+ void fvec_madd(size_t n, const float* a, float bf, const float* b, float* c) {
163
+ fvec_madd_dispatch(n, a, bf, b, c);
164
+ }
165
+
166
+ int fvec_madd_and_argmin(
167
+ size_t n,
168
+ const float* a,
169
+ float bf,
170
+ const float* b,
171
+ float* c) {
172
+ return fvec_madd_and_argmin_dispatch(n, a, bf, b, c);
173
+ }
174
+
58
175
  /***************************************************************************
59
176
  * Matrix/vector ops
60
177
  ***************************************************************************/
@@ -67,7 +184,7 @@ void fvec_norms_L2(
67
184
  size_t nx) {
68
185
  #pragma omp parallel for if (nx > 10000)
69
186
  for (int64_t i = 0; i < nx; i++) {
70
- nr[i] = sqrtf(fvec_norm_L2sqr(x + i * d, d));
187
+ nr[i] = sqrtf(fvec_norm_L2sqr_dispatch(x + i * d, d));
71
188
  }
72
189
  }
73
190
 
@@ -78,7 +195,7 @@ void fvec_norms_L2sqr(
78
195
  size_t nx) {
79
196
  #pragma omp parallel for if (nx > 10000)
80
197
  for (int64_t i = 0; i < nx; i++) {
81
- nr[i] = fvec_norm_L2sqr(x + i * d, d);
198
+ nr[i] = fvec_norm_L2sqr_dispatch(x + i * d, d);
82
199
  }
83
200
  }
84
201
 
@@ -93,16 +210,16 @@ void fvec_norms_L2sqr(
93
210
  // The workaround below is explicitly branching
94
211
  // off to a codepath without omp.
95
212
 
96
- #define FVEC_RENORM_L2_IMPL \
97
- float* __restrict xi = x + i * d; \
98
- \
99
- float nr = fvec_norm_L2sqr(xi, d); \
100
- \
101
- if (nr > 0) { \
102
- size_t j; \
103
- const float inv_nr = 1.0 / sqrtf(nr); \
104
- for (j = 0; j < d; j++) \
105
- xi[j] *= inv_nr; \
213
// Per-row body for the L2 renormalisation loops (presumably expanded inside
// fvec_renorm_L2_noomp and its OpenMP sibling — confirm at use sites).
// Rescales row `i` of `x` (row length `d`, i.e. the row at x + i * d) to
// unit L2 norm, in place. `x`, `d` and `i` must be in scope where the macro
// is expanded. Rows whose squared norm is not > 0 are left unchanged.
#define FVEC_RENORM_L2_IMPL                                          \
    float* __restrict xi = x + i * d;                                \
                                                                     \
    float nr = fvec_norm_L2sqr_dispatch(xi, d);                      \
                                                                     \
    if (nr > 0) {                                                    \
        /* 1.0f keeps the reciprocal in single precision; rounding */ \
        /* gives the same float as the double-precision 1.0 / ... */ \
        const float inv_nr = 1.0f / sqrtf(nr);                       \
        for (size_t j = 0; j < d; j++) {                             \
            xi[j] *= inv_nr;                                         \
        }                                                            \
    }
107
224
 
108
225
  void fvec_renorm_L2_noomp(size_t d, size_t nx, float* __restrict x) {
@@ -159,7 +276,7 @@ void exhaustive_inner_product_seq(
159
276
  if (!res.is_in_selection(j)) {
160
277
  continue;
161
278
  }
162
- float ip = fvec_inner_product(x_i, y_j, d);
279
+ float ip = fvec_inner_product_dispatch(x_i, y_j, d);
163
280
  resi.add_result(ip, j);
164
281
  }
165
282
  resi.end();
@@ -191,7 +308,7 @@ void exhaustive_L2sqr_seq(
191
308
  if (!res.is_in_selection(j)) {
192
309
  continue;
193
310
  }
194
- float disij = fvec_L2sqr(x_i, y_j, d);
311
+ float disij = fvec_L2sqr_dispatch(x_i, y_j, d);
195
312
  resi.add_result(disij, j);
196
313
  }
197
314
  resi.end();
@@ -998,7 +1115,7 @@ void fvec_inner_products_by_idx(
998
1115
  if (idsj[i] < 0) {
999
1116
  ipj[i] = -INFINITY;
1000
1117
  } else {
1001
- ipj[i] = fvec_inner_product(xj, y + d * idsj[i], d);
1118
+ ipj[i] = fvec_inner_product_dispatch(xj, y + d * idsj[i], d);
1002
1119
  }
1003
1120
  }
1004
1121
  }
@@ -1023,7 +1140,7 @@ void fvec_L2sqr_by_idx(
1023
1140
  if (idsj[i] < 0) {
1024
1141
  disj[i] = INFINITY;
1025
1142
  } else {
1026
- disj[i] = fvec_L2sqr(xj, y + d * idsj[i], d);
1143
+ disj[i] = fvec_L2sqr_dispatch(xj, y + d * idsj[i], d);
1027
1144
  }
1028
1145
  }
1029
1146
  }
@@ -1040,7 +1157,7 @@ void pairwise_indexed_L2sqr(
1040
1157
  #pragma omp parallel for if (n > 1)
1041
1158
  for (int64_t j = 0; j < n; j++) {
1042
1159
  if (ix[j] >= 0 && iy[j] >= 0) {
1043
- dis[j] = fvec_L2sqr(x + d * ix[j], y + d * iy[j], d);
1160
+ dis[j] = fvec_L2sqr_dispatch(x + d * ix[j], y + d * iy[j], d);
1044
1161
  } else {
1045
1162
  dis[j] = INFINITY;
1046
1163
  }
@@ -1058,7 +1175,8 @@ void pairwise_indexed_inner_product(
1058
1175
  #pragma omp parallel for if (n > 1)
1059
1176
  for (int64_t j = 0; j < n; j++) {
1060
1177
  if (ix[j] >= 0 && iy[j] >= 0) {
1061
- dis[j] = fvec_inner_product(x + d * ix[j], y + d * iy[j], d);
1178
+ dis[j] = fvec_inner_product_dispatch(
1179
+ x + d * ix[j], y + d * iy[j], d);
1062
1180
  } else {
1063
1181
  dis[j] = -INFINITY;
1064
1182
  }
@@ -1096,7 +1214,7 @@ void knn_inner_products_by_idx(
1096
1214
  if (idsi[j] < 0 || idsi[j] >= ny) {
1097
1215
  break;
1098
1216
  }
1099
- float ip = fvec_inner_product(x_, y + d * idsi[j], d);
1217
+ float ip = fvec_inner_product_dispatch(x_, y + d * idsi[j], d);
1100
1218
 
1101
1219
  if (ip > simi[0]) {
1102
1220
  minheap_replace_top(k, simi, idxi, ip, idsi[j]);
@@ -1132,7 +1250,7 @@ void knn_L2sqr_by_idx(
1132
1250
  if (idsi[j] < 0 || idsi[j] >= ny) {
1133
1251
  break;
1134
1252
  }
1135
- float disij = fvec_L2sqr(x_, y + d * idsi[j], d);
1253
+ float disij = fvec_L2sqr_dispatch(x_, y + d * idsi[j], d);
1136
1254
 
1137
1255
  if (disij < simi[0]) {
1138
1256
  maxheap_replace_top(k, simi, idxi, disij, idsi[j]);
@@ -1170,19 +1288,19 @@ void pairwise_L2sqr(
1170
1288
 
1171
1289
  #pragma omp parallel for if (nb > 1)
1172
1290
  for (int64_t i = 0; i < nb; i++) {
1173
- b_norms[i] = fvec_norm_L2sqr(xb + i * ldb, d);
1291
+ b_norms[i] = fvec_norm_L2sqr_dispatch(xb + i * ldb, d);
1174
1292
  }
1175
1293
 
1176
1294
  #pragma omp parallel for
1177
1295
  for (int64_t i = 1; i < nq; i++) {
1178
- float q_norm = fvec_norm_L2sqr(xq + i * ldq, d);
1296
+ float q_norm = fvec_norm_L2sqr_dispatch(xq + i * ldq, d);
1179
1297
  for (int64_t j = 0; j < nb; j++) {
1180
1298
  dis[i * ldd + j] = q_norm + b_norms[j];
1181
1299
  }
1182
1300
  }
1183
1301
 
1184
1302
  {
1185
- float q_norm = fvec_norm_L2sqr(xq, d);
1303
+ float q_norm = fvec_norm_L2sqr_dispatch(xq, d);
1186
1304
  for (int64_t j = 0; j < nb; j++) {
1187
1305
  dis[j] += q_norm;
1188
1306
  }
@@ -15,6 +15,7 @@
15
15
 
16
16
  #include <faiss/impl/platform_macros.h>
17
17
  #include <faiss/utils/Heap.h>
18
+ #include <faiss/utils/simd_levels.h>
18
19
 
19
20
  namespace faiss {
20
21
 
@@ -27,15 +28,27 @@ struct IDSelector;
27
28
  /// Squared L2 distance between two vectors
28
29
  float fvec_L2sqr(const float* x, const float* y, size_t d);
29
30
 
31
+ template <SIMDLevel>
32
+ float fvec_L2sqr(const float* x, const float* y, size_t d);
33
+
30
34
  /// inner product
31
35
  float fvec_inner_product(const float* x, const float* y, size_t d);
32
36
 
37
+ template <SIMDLevel>
38
+ float fvec_inner_product(const float* x, const float* y, size_t d);
39
+
33
40
  /// L1 distance
34
41
  float fvec_L1(const float* x, const float* y, size_t d);
35
42
 
43
+ template <SIMDLevel>
44
+ float fvec_L1(const float* x, const float* y, size_t d);
45
+
36
46
  /// infinity distance
37
47
  float fvec_Linf(const float* x, const float* y, size_t d);
38
48
 
49
+ template <SIMDLevel>
50
+ float fvec_Linf(const float* x, const float* y, size_t d);
51
+
39
52
  /// Special version of inner product that computes 4 distances
40
53
  /// between x and yi, which is performance oriented.
41
54
  void fvec_inner_product_batch_4(
@@ -50,6 +63,19 @@ void fvec_inner_product_batch_4(
50
63
  float& dis2,
51
64
  float& dis3);
52
65
 
66
+ template <SIMDLevel>
67
+ void fvec_inner_product_batch_4(
68
+ const float* x,
69
+ const float* y0,
70
+ const float* y1,
71
+ const float* y2,
72
+ const float* y3,
73
+ const size_t d,
74
+ float& dis0,
75
+ float& dis1,
76
+ float& dis2,
77
+ float& dis3);
78
+
53
79
  /// Special version of L2sqr that computes 4 distances
54
80
  /// between x and yi, which is performance oriented.
55
81
  void fvec_L2sqr_batch_4(
@@ -64,6 +90,19 @@ void fvec_L2sqr_batch_4(
64
90
  float& dis2,
65
91
  float& dis3);
66
92
 
93
+ template <SIMDLevel>
94
+ void fvec_L2sqr_batch_4(
95
+ const float* x,
96
+ const float* y0,
97
+ const float* y1,
98
+ const float* y2,
99
+ const float* y3,
100
+ const size_t d,
101
+ float& dis0,
102
+ float& dis1,
103
+ float& dis2,
104
+ float& dis3);
105
+
67
106
  /** Compute pairwise distances between sets of vectors
68
107
  *
69
108
  * @param d dimension of the vectors
@@ -93,6 +132,14 @@ void fvec_inner_products_ny(
93
132
  size_t d,
94
133
  size_t ny);
95
134
 
135
+ template <SIMDLevel>
136
+ void fvec_inner_products_ny(
137
+ float* ip, /* output inner product */
138
+ const float* x,
139
+ const float* y,
140
+ size_t d,
141
+ size_t ny);
142
+
96
143
  /* compute ny square L2 distance between x and a set of contiguous y vectors */
97
144
  void fvec_L2sqr_ny(
98
145
  float* dis,
@@ -101,6 +148,14 @@ void fvec_L2sqr_ny(
101
148
  size_t d,
102
149
  size_t ny);
103
150
 
151
+ template <SIMDLevel>
152
+ void fvec_L2sqr_ny(
153
+ float* dis,
154
+ const float* x,
155
+ const float* y,
156
+ size_t d,
157
+ size_t ny);
158
+
104
159
  /* compute ny square L2 distance between x and a set of transposed contiguous
105
160
  y vectors. squared lengths of y should be provided as well */
106
161
  void fvec_L2sqr_ny_transposed(
@@ -112,6 +167,16 @@ void fvec_L2sqr_ny_transposed(
112
167
  size_t d_offset,
113
168
  size_t ny);
114
169
 
170
+ template <SIMDLevel>
171
+ void fvec_L2sqr_ny_transposed(
172
+ float* dis,
173
+ const float* x,
174
+ const float* y,
175
+ const float* y_sqlen,
176
+ size_t d,
177
+ size_t d_offset,
178
+ size_t ny);
179
+
115
180
  /* compute ny square L2 distance between x and a set of contiguous y vectors
116
181
  and return the index of the nearest vector.
117
182
  return 0 if ny == 0. */
@@ -122,6 +187,14 @@ size_t fvec_L2sqr_ny_nearest(
122
187
  size_t d,
123
188
  size_t ny);
124
189
 
190
+ template <SIMDLevel>
191
+ size_t fvec_L2sqr_ny_nearest(
192
+ float* distances_tmp_buffer,
193
+ const float* x,
194
+ const float* y,
195
+ size_t d,
196
+ size_t ny);
197
+
125
198
  /* compute ny square L2 distance between x and a set of transposed contiguous
126
199
  y vectors and return the index of the nearest vector.
127
200
  squared lengths of y should be provided as well
@@ -135,9 +208,22 @@ size_t fvec_L2sqr_ny_nearest_y_transposed(
135
208
  size_t d_offset,
136
209
  size_t ny);
137
210
 
211
+ template <SIMDLevel>
212
+ size_t fvec_L2sqr_ny_nearest_y_transposed(
213
+ float* distances_tmp_buffer,
214
+ const float* x,
215
+ const float* y,
216
+ const float* y_sqlen,
217
+ size_t d,
218
+ size_t d_offset,
219
+ size_t ny);
220
+
138
221
  /** squared norm of a vector */
139
222
  float fvec_norm_L2sqr(const float* x, size_t d);
140
223
 
224
+ template <SIMDLevel>
225
+ float fvec_norm_L2sqr(const float* x, size_t d);
226
+
141
227
  /** compute the L2 norms for a set of vectors
142
228
  *
143
229
  * @param norms output norms, size nx
@@ -473,6 +559,10 @@ void compute_PQ_dis_tables_dsub2(
473
559
  */
474
560
  void fvec_madd(size_t n, const float* a, float bf, const float* b, float* c);
475
561
 
562
+ /* same statically */
563
+ template <SIMDLevel>
564
+ void fvec_madd(size_t n, const float* a, float bf, const float* b, float* c);
565
+
476
566
  /** same as fvec_madd, also return index of the min of the result table
477
567
  * @return index of the min of table c
478
568
  */
@@ -483,4 +573,12 @@ int fvec_madd_and_argmin(
483
573
  const float* b,
484
574
  float* c);
485
575
 
576
+ template <SIMDLevel>
577
+ int fvec_madd_and_argmin(
578
+ size_t n,
579
+ const float* a,
580
+ float bf,
581
+ const float* b,
582
+ float* c);
583
+
486
584
  } // namespace faiss
@@ -0,0 +1,170 @@
1
+ /*
2
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ #pragma once
9
+
10
+ /**
11
+ * @file distances_dispatch.h
12
+ * @brief Inlineable dispatch wrappers for distance functions.
13
+ *
14
+ * This is a PRIVATE header. Do not include in public APIs or user code.
15
+ *
16
+ * These wrappers call DISPATCH_SIMDLevel to route to the correct SIMD
17
+ * implementation. They are plain inline functions with a _dispatch suffix
18
+ * (e.g. fvec_L2sqr_dispatch). Internal callers that want inlining include
19
+ * this header and call the _dispatch variants directly.
20
+ *
21
+ * The public API functions (fvec_L2sqr, etc.) are defined as regular extern
22
+ * functions in distances.cpp and simply delegate to these _dispatch variants.
23
+ */
24
+
25
+ #include <faiss/impl/simd_dispatch.h>
26
+ #include <faiss/utils/distances.h>
27
+
28
+ namespace faiss {
29
+
30
+ inline float fvec_L1_dispatch(const float* x, const float* y, size_t d) {
31
+ DISPATCH_SIMDLevel(fvec_L1, x, y, d);
32
+ }
33
+
34
+ inline float fvec_Linf_dispatch(const float* x, const float* y, size_t d) {
35
+ DISPATCH_SIMDLevel(fvec_Linf, x, y, d);
36
+ }
37
+
38
+ inline float fvec_norm_L2sqr_dispatch(const float* x, size_t d) {
39
+ DISPATCH_SIMDLevel(fvec_norm_L2sqr, x, d);
40
+ }
41
+
42
+ inline float fvec_L2sqr_dispatch(const float* x, const float* y, size_t d) {
43
+ DISPATCH_SIMDLevel(fvec_L2sqr, x, y, d);
44
+ }
45
+
46
+ inline float fvec_inner_product_dispatch(
47
+ const float* x,
48
+ const float* y,
49
+ size_t d) {
50
+ DISPATCH_SIMDLevel(fvec_inner_product, x, y, d);
51
+ }
52
+
53
+ inline void fvec_inner_product_batch_4_dispatch(
54
+ const float* x,
55
+ const float* y0,
56
+ const float* y1,
57
+ const float* y2,
58
+ const float* y3,
59
+ const size_t d,
60
+ float& dis0,
61
+ float& dis1,
62
+ float& dis2,
63
+ float& dis3) {
64
+ DISPATCH_SIMDLevel(
65
+ fvec_inner_product_batch_4,
66
+ x,
67
+ y0,
68
+ y1,
69
+ y2,
70
+ y3,
71
+ d,
72
+ dis0,
73
+ dis1,
74
+ dis2,
75
+ dis3);
76
+ }
77
+
78
+ inline void fvec_L2sqr_batch_4_dispatch(
79
+ const float* x,
80
+ const float* y0,
81
+ const float* y1,
82
+ const float* y2,
83
+ const float* y3,
84
+ const size_t d,
85
+ float& dis0,
86
+ float& dis1,
87
+ float& dis2,
88
+ float& dis3) {
89
+ DISPATCH_SIMDLevel(
90
+ fvec_L2sqr_batch_4, x, y0, y1, y2, y3, d, dis0, dis1, dis2, dis3);
91
+ }
92
+
93
+ inline void fvec_L2sqr_ny_transposed_dispatch(
94
+ float* dis,
95
+ const float* x,
96
+ const float* y,
97
+ const float* y_sqlen,
98
+ size_t d,
99
+ size_t d_offset,
100
+ size_t ny) {
101
+ DISPATCH_SIMDLevel(
102
+ fvec_L2sqr_ny_transposed, dis, x, y, y_sqlen, d, d_offset, ny);
103
+ }
104
+
105
+ inline void fvec_inner_products_ny_dispatch(
106
+ float* ip,
107
+ const float* x,
108
+ const float* y,
109
+ size_t d,
110
+ size_t ny) {
111
+ DISPATCH_SIMDLevel(fvec_inner_products_ny, ip, x, y, d, ny);
112
+ }
113
+
114
+ inline void fvec_L2sqr_ny_dispatch(
115
+ float* dis,
116
+ const float* x,
117
+ const float* y,
118
+ size_t d,
119
+ size_t ny) {
120
+ DISPATCH_SIMDLevel(fvec_L2sqr_ny, dis, x, y, d, ny);
121
+ }
122
+
123
+ inline size_t fvec_L2sqr_ny_nearest_dispatch(
124
+ float* distances_tmp_buffer,
125
+ const float* x,
126
+ const float* y,
127
+ size_t d,
128
+ size_t ny) {
129
+ DISPATCH_SIMDLevel(
130
+ fvec_L2sqr_ny_nearest, distances_tmp_buffer, x, y, d, ny);
131
+ }
132
+
133
+ inline size_t fvec_L2sqr_ny_nearest_y_transposed_dispatch(
134
+ float* distances_tmp_buffer,
135
+ const float* x,
136
+ const float* y,
137
+ const float* y_sqlen,
138
+ size_t d,
139
+ size_t d_offset,
140
+ size_t ny) {
141
+ DISPATCH_SIMDLevel(
142
+ fvec_L2sqr_ny_nearest_y_transposed,
143
+ distances_tmp_buffer,
144
+ x,
145
+ y,
146
+ y_sqlen,
147
+ d,
148
+ d_offset,
149
+ ny);
150
+ }
151
+
152
+ inline void fvec_madd_dispatch(
153
+ size_t n,
154
+ const float* a,
155
+ float bf,
156
+ const float* b,
157
+ float* c) {
158
+ DISPATCH_SIMDLevel(fvec_madd, n, a, bf, b, c);
159
+ }
160
+
161
+ inline int fvec_madd_and_argmin_dispatch(
162
+ size_t n,
163
+ const float* a,
164
+ float bf,
165
+ const float* b,
166
+ float* c) {
167
+ DISPATCH_SIMDLevel(fvec_madd_and_argmin, n, a, bf, b, c);
168
+ }
169
+
170
+ } // namespace faiss