faiss 0.3.0 → 0.3.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (216)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +9 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +1 -1
  5. data/ext/faiss/extconf.rb +9 -2
  6. data/ext/faiss/index.cpp +1 -1
  7. data/ext/faiss/index_binary.cpp +2 -2
  8. data/ext/faiss/product_quantizer.cpp +1 -1
  9. data/lib/faiss/version.rb +1 -1
  10. data/vendor/faiss/faiss/AutoTune.cpp +7 -7
  11. data/vendor/faiss/faiss/AutoTune.h +1 -2
  12. data/vendor/faiss/faiss/Clustering.cpp +39 -22
  13. data/vendor/faiss/faiss/Clustering.h +40 -21
  14. data/vendor/faiss/faiss/IVFlib.cpp +26 -12
  15. data/vendor/faiss/faiss/Index.cpp +1 -1
  16. data/vendor/faiss/faiss/Index.h +40 -10
  17. data/vendor/faiss/faiss/Index2Layer.cpp +7 -7
  18. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +176 -166
  19. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +15 -15
  20. data/vendor/faiss/faiss/IndexBinary.cpp +9 -4
  21. data/vendor/faiss/faiss/IndexBinary.h +8 -19
  22. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +2 -1
  23. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +24 -31
  24. data/vendor/faiss/faiss/IndexBinaryHNSW.h +1 -1
  25. data/vendor/faiss/faiss/IndexBinaryHash.cpp +25 -50
  26. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +107 -188
  27. data/vendor/faiss/faiss/IndexFastScan.cpp +95 -146
  28. data/vendor/faiss/faiss/IndexFastScan.h +9 -8
  29. data/vendor/faiss/faiss/IndexFlat.cpp +206 -10
  30. data/vendor/faiss/faiss/IndexFlat.h +20 -1
  31. data/vendor/faiss/faiss/IndexFlatCodes.cpp +170 -5
  32. data/vendor/faiss/faiss/IndexFlatCodes.h +23 -4
  33. data/vendor/faiss/faiss/IndexHNSW.cpp +231 -382
  34. data/vendor/faiss/faiss/IndexHNSW.h +62 -49
  35. data/vendor/faiss/faiss/IndexIDMap.cpp +69 -28
  36. data/vendor/faiss/faiss/IndexIDMap.h +24 -2
  37. data/vendor/faiss/faiss/IndexIVF.cpp +162 -56
  38. data/vendor/faiss/faiss/IndexIVF.h +46 -6
  39. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +33 -26
  40. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +6 -2
  41. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +19 -46
  42. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +4 -3
  43. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +502 -401
  44. data/vendor/faiss/faiss/IndexIVFFastScan.h +63 -26
  45. data/vendor/faiss/faiss/IndexIVFFlat.cpp +15 -5
  46. data/vendor/faiss/faiss/IndexIVFFlat.h +3 -2
  47. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +172 -0
  48. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.h +56 -0
  49. data/vendor/faiss/faiss/IndexIVFPQ.cpp +79 -125
  50. data/vendor/faiss/faiss/IndexIVFPQ.h +6 -7
  51. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +39 -52
  52. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +4 -3
  53. data/vendor/faiss/faiss/IndexIVFPQR.cpp +45 -29
  54. data/vendor/faiss/faiss/IndexIVFPQR.h +5 -2
  55. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +25 -27
  56. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +6 -6
  57. data/vendor/faiss/faiss/IndexLSH.cpp +14 -16
  58. data/vendor/faiss/faiss/IndexLattice.cpp +1 -19
  59. data/vendor/faiss/faiss/IndexLattice.h +3 -22
  60. data/vendor/faiss/faiss/IndexNNDescent.cpp +3 -33
  61. data/vendor/faiss/faiss/IndexNNDescent.h +1 -1
  62. data/vendor/faiss/faiss/IndexNSG.cpp +11 -27
  63. data/vendor/faiss/faiss/IndexNSG.h +11 -11
  64. data/vendor/faiss/faiss/IndexNeuralNetCodec.cpp +56 -0
  65. data/vendor/faiss/faiss/IndexNeuralNetCodec.h +49 -0
  66. data/vendor/faiss/faiss/IndexPQ.cpp +72 -88
  67. data/vendor/faiss/faiss/IndexPQ.h +1 -4
  68. data/vendor/faiss/faiss/IndexPQFastScan.cpp +1 -1
  69. data/vendor/faiss/faiss/IndexPreTransform.cpp +25 -31
  70. data/vendor/faiss/faiss/IndexPreTransform.h +1 -1
  71. data/vendor/faiss/faiss/IndexRefine.cpp +54 -24
  72. data/vendor/faiss/faiss/IndexRefine.h +7 -0
  73. data/vendor/faiss/faiss/IndexReplicas.cpp +23 -26
  74. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +25 -17
  75. data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -4
  76. data/vendor/faiss/faiss/IndexShards.cpp +21 -29
  77. data/vendor/faiss/faiss/IndexShardsIVF.cpp +1 -2
  78. data/vendor/faiss/faiss/MatrixStats.cpp +17 -32
  79. data/vendor/faiss/faiss/MatrixStats.h +21 -9
  80. data/vendor/faiss/faiss/MetaIndexes.cpp +35 -35
  81. data/vendor/faiss/faiss/MetricType.h +7 -2
  82. data/vendor/faiss/faiss/VectorTransform.cpp +13 -26
  83. data/vendor/faiss/faiss/VectorTransform.h +7 -7
  84. data/vendor/faiss/faiss/clone_index.cpp +15 -10
  85. data/vendor/faiss/faiss/clone_index.h +3 -0
  86. data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +95 -17
  87. data/vendor/faiss/faiss/cppcontrib/factory_tools.cpp +152 -0
  88. data/vendor/faiss/faiss/cppcontrib/factory_tools.h +24 -0
  89. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +83 -30
  90. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +123 -8
  91. data/vendor/faiss/faiss/gpu/GpuCloner.h +22 -0
  92. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +13 -0
  93. data/vendor/faiss/faiss/gpu/GpuDistance.h +46 -38
  94. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -1
  95. data/vendor/faiss/faiss/gpu/GpuIndex.h +30 -12
  96. data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +282 -0
  97. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +4 -4
  98. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +14 -9
  99. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +20 -3
  100. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -11
  101. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +1 -3
  102. data/vendor/faiss/faiss/gpu/GpuResources.cpp +24 -3
  103. data/vendor/faiss/faiss/gpu/GpuResources.h +39 -11
  104. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +142 -17
  105. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +57 -3
  106. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +26 -21
  107. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +7 -1
  108. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +8 -5
  109. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +25 -0
  110. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +129 -9
  111. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +332 -40
  112. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +299 -208
  113. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +1 -0
  114. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +1 -1
  115. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +6 -0
  116. data/vendor/faiss/faiss/gpu/utils/RaftUtils.h +75 -0
  117. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +4 -1
  118. data/vendor/faiss/faiss/gpu/utils/Timer.h +1 -1
  119. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +3 -1
  120. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +5 -5
  121. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +26 -1
  122. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +10 -3
  123. data/vendor/faiss/faiss/impl/DistanceComputer.h +70 -1
  124. data/vendor/faiss/faiss/impl/FaissAssert.h +4 -2
  125. data/vendor/faiss/faiss/impl/FaissException.h +13 -34
  126. data/vendor/faiss/faiss/impl/HNSW.cpp +605 -186
  127. data/vendor/faiss/faiss/impl/HNSW.h +52 -30
  128. data/vendor/faiss/faiss/impl/IDSelector.h +4 -4
  129. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +11 -9
  130. data/vendor/faiss/faiss/impl/LookupTableScaler.h +34 -0
  131. data/vendor/faiss/faiss/impl/NNDescent.cpp +42 -27
  132. data/vendor/faiss/faiss/impl/NSG.cpp +0 -29
  133. data/vendor/faiss/faiss/impl/NSG.h +1 -1
  134. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +14 -12
  135. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +1 -1
  136. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +25 -22
  137. data/vendor/faiss/faiss/impl/ProductQuantizer.h +6 -2
  138. data/vendor/faiss/faiss/impl/Quantizer.h +1 -1
  139. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +27 -1015
  140. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +5 -63
  141. data/vendor/faiss/faiss/impl/ResultHandler.h +347 -172
  142. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +1104 -147
  143. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +3 -8
  144. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +285 -42
  145. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx512.h +248 -0
  146. data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +21 -14
  147. data/vendor/faiss/faiss/impl/code_distance/code_distance.h +22 -12
  148. data/vendor/faiss/faiss/impl/index_read.cpp +74 -34
  149. data/vendor/faiss/faiss/impl/index_read_utils.h +37 -0
  150. data/vendor/faiss/faiss/impl/index_write.cpp +88 -51
  151. data/vendor/faiss/faiss/impl/io.cpp +23 -15
  152. data/vendor/faiss/faiss/impl/io.h +4 -4
  153. data/vendor/faiss/faiss/impl/io_macros.h +6 -0
  154. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -1
  155. data/vendor/faiss/faiss/impl/platform_macros.h +40 -1
  156. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +14 -0
  157. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +7 -6
  158. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +52 -38
  159. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +487 -49
  160. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +960 -0
  161. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +176 -0
  162. data/vendor/faiss/faiss/impl/simd_result_handlers.h +481 -225
  163. data/vendor/faiss/faiss/index_factory.cpp +41 -20
  164. data/vendor/faiss/faiss/index_io.h +12 -5
  165. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +28 -8
  166. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +3 -0
  167. data/vendor/faiss/faiss/invlists/DirectMap.cpp +10 -2
  168. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +73 -17
  169. data/vendor/faiss/faiss/invlists/InvertedLists.h +26 -8
  170. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +24 -9
  171. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +2 -1
  172. data/vendor/faiss/faiss/python/python_callbacks.cpp +4 -4
  173. data/vendor/faiss/faiss/utils/Heap.cpp +3 -1
  174. data/vendor/faiss/faiss/utils/Heap.h +105 -0
  175. data/vendor/faiss/faiss/utils/NeuralNet.cpp +342 -0
  176. data/vendor/faiss/faiss/utils/NeuralNet.h +147 -0
  177. data/vendor/faiss/faiss/utils/WorkerThread.h +1 -0
  178. data/vendor/faiss/faiss/utils/bf16.h +36 -0
  179. data/vendor/faiss/faiss/utils/distances.cpp +147 -123
  180. data/vendor/faiss/faiss/utils/distances.h +86 -9
  181. data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +5 -5
  182. data/vendor/faiss/faiss/utils/distances_fused/avx512.h +2 -2
  183. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +2 -2
  184. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +1 -1
  185. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +5 -5
  186. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.h +1 -1
  187. data/vendor/faiss/faiss/utils/distances_simd.cpp +1589 -243
  188. data/vendor/faiss/faiss/utils/extra_distances-inl.h +70 -0
  189. data/vendor/faiss/faiss/utils/extra_distances.cpp +85 -137
  190. data/vendor/faiss/faiss/utils/extra_distances.h +3 -2
  191. data/vendor/faiss/faiss/utils/fp16-arm.h +29 -0
  192. data/vendor/faiss/faiss/utils/fp16.h +2 -0
  193. data/vendor/faiss/faiss/utils/hamming.cpp +163 -111
  194. data/vendor/faiss/faiss/utils/hamming.h +58 -0
  195. data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +16 -89
  196. data/vendor/faiss/faiss/utils/hamming_distance/common.h +1 -0
  197. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +19 -88
  198. data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +58 -0
  199. data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +14 -104
  200. data/vendor/faiss/faiss/utils/partitioning.cpp +3 -4
  201. data/vendor/faiss/faiss/utils/prefetch.h +77 -0
  202. data/vendor/faiss/faiss/utils/quantize_lut.cpp +0 -14
  203. data/vendor/faiss/faiss/utils/random.cpp +43 -0
  204. data/vendor/faiss/faiss/utils/random.h +25 -0
  205. data/vendor/faiss/faiss/utils/simdlib.h +10 -1
  206. data/vendor/faiss/faiss/utils/simdlib_avx2.h +0 -6
  207. data/vendor/faiss/faiss/utils/simdlib_avx512.h +296 -0
  208. data/vendor/faiss/faiss/utils/simdlib_neon.h +77 -79
  209. data/vendor/faiss/faiss/utils/simdlib_ppc64.h +1084 -0
  210. data/vendor/faiss/faiss/utils/sorting.cpp +140 -5
  211. data/vendor/faiss/faiss/utils/sorting.h +27 -0
  212. data/vendor/faiss/faiss/utils/transpose/transpose-avx512-inl.h +176 -0
  213. data/vendor/faiss/faiss/utils/utils.cpp +120 -7
  214. data/vendor/faiss/faiss/utils/utils.h +60 -20
  215. metadata +23 -4
  216. data/vendor/faiss/faiss/impl/code_distance/code_distance_avx512.h +0 -102
data/vendor/faiss/faiss/utils/distances_simd.cpp
@@ -23,7 +23,9 @@
 #include <immintrin.h>
 #endif
 
-#ifdef __AVX2__
+#if defined(__AVX512F__)
+#include <faiss/utils/transpose/transpose-avx512-inl.h>
+#elif defined(__AVX2__)
 #include <faiss/utils/transpose/transpose-avx2-inl.h>
 #endif
 
@@ -223,6 +225,76 @@ float fvec_L2sqr(const float* x, const float* y, size_t d) {
 }
 FAISS_PRAGMA_IMPRECISE_FUNCTION_END
 
+/// Special version of inner product that computes 4 distances
+/// between x and yi
+FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN
+void fvec_inner_product_batch_4(
+        const float* __restrict x,
+        const float* __restrict y0,
+        const float* __restrict y1,
+        const float* __restrict y2,
+        const float* __restrict y3,
+        const size_t d,
+        float& dis0,
+        float& dis1,
+        float& dis2,
+        float& dis3) {
+    float d0 = 0;
+    float d1 = 0;
+    float d2 = 0;
+    float d3 = 0;
+    FAISS_PRAGMA_IMPRECISE_LOOP
+    for (size_t i = 0; i < d; ++i) {
+        d0 += x[i] * y0[i];
+        d1 += x[i] * y1[i];
+        d2 += x[i] * y2[i];
+        d3 += x[i] * y3[i];
+    }
+
+    dis0 = d0;
+    dis1 = d1;
+    dis2 = d2;
+    dis3 = d3;
+}
+FAISS_PRAGMA_IMPRECISE_FUNCTION_END
+
+/// Special version of L2sqr that computes 4 distances
+/// between x and yi, which is performance oriented.
+FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN
+void fvec_L2sqr_batch_4(
+        const float* x,
+        const float* y0,
+        const float* y1,
+        const float* y2,
+        const float* y3,
+        const size_t d,
+        float& dis0,
+        float& dis1,
+        float& dis2,
+        float& dis3) {
+    float d0 = 0;
+    float d1 = 0;
+    float d2 = 0;
+    float d3 = 0;
+    FAISS_PRAGMA_IMPRECISE_LOOP
+    for (size_t i = 0; i < d; ++i) {
+        const float q0 = x[i] - y0[i];
+        const float q1 = x[i] - y1[i];
+        const float q2 = x[i] - y2[i];
+        const float q3 = x[i] - y3[i];
+        d0 += q0 * q0;
+        d1 += q1 * q1;
+        d2 += q2 * q2;
+        d3 += q3 * q3;
+    }
+
+    dis0 = d0;
+    dis1 = d1;
+    dis2 = d2;
+    dis3 = d3;
+}
+FAISS_PRAGMA_IMPRECISE_FUNCTION_END
+
 /*********************************************************
  * SSE and AVX implementations
  */
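The two batch-4 kernels added above amortize the cost of streaming x across four candidate vectors y0..y3: x[i] is read once per iteration instead of four times, and the four independent accumulators keep the loop free of cross-iteration dependencies. A minimal stand-alone sketch of the same reduction (plain C++ with a hypothetical driver; the FAISS_PRAGMA_* markers, which only relax floating-point reassociation, are omitted):

    #include <cstddef>
    #include <cstdio>
    #include <vector>

    // Same arithmetic as fvec_L2sqr_batch_4 in the hunk above.
    void l2sqr_batch_4(
            const float* x,
            const float* y0,
            const float* y1,
            const float* y2,
            const float* y3,
            size_t d,
            float& dis0,
            float& dis1,
            float& dis2,
            float& dis3) {
        float d0 = 0, d1 = 0, d2 = 0, d3 = 0;
        for (size_t i = 0; i < d; ++i) {
            // x[i] is loaded once and reused against all four candidates.
            const float q0 = x[i] - y0[i];
            const float q1 = x[i] - y1[i];
            const float q2 = x[i] - y2[i];
            const float q3 = x[i] - y3[i];
            d0 += q0 * q0;
            d1 += q1 * q1;
            d2 += q2 * q2;
            d3 += q3 * q3;
        }
        dis0 = d0;
        dis1 = d1;
        dis2 = d2;
        dis3 = d3;
    }

    int main() {
        const size_t d = 8;
        std::vector<float> x(d, 1.0f), y(4 * d);
        for (size_t i = 0; i < y.size(); ++i)
            y[i] = float(i % 3);
        float dis[4];
        l2sqr_batch_4(x.data(), &y[0], &y[d], &y[2 * d], &y[3 * d], d,
                      dis[0], dis[1], dis[2], dis[3]);
        for (float v : dis)
            std::printf("%f\n", v);
        return 0;
    }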
@@ -236,8 +308,10 @@ static inline __m128 masked_read(int d, const float* x) {
     switch (d) {
         case 3:
             buf[2] = x[2];
+            [[fallthrough]];
         case 2:
             buf[1] = x[1];
+            [[fallthrough]];
         case 1:
             buf[0] = x[0];
     }
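The [[fallthrough]] attributes change nothing at runtime: the cascade in masked_read is deliberate (it copies the trailing d floats into a zeroed buffer), and the attribute merely silences -Wimplicit-fallthrough. A stand-alone sketch of the same tail-copy pattern, with hypothetical names:

    #include <cstdio>
    #include <cstring>

    // Copy the first d (0..3) floats into a zero-padded 4-float buffer;
    // each case intentionally falls through to the ones below it.
    void read_tail(int d, const float* x, float buf[4]) {
        std::memset(buf, 0, 4 * sizeof(float));
        switch (d) {
            case 3:
                buf[2] = x[2];
                [[fallthrough]];
            case 2:
                buf[1] = x[1];
                [[fallthrough]];
            case 1:
                buf[0] = x[0];
        }
    }

    int main() {
        const float x[3] = {1.0f, 2.0f, 3.0f};
        float buf[4];
        read_tail(2, x, buf);
        std::printf("%f %f %f %f\n", buf[0], buf[1], buf[2], buf[3]);
        return 0;
    }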
@@ -247,6 +321,41 @@ static inline __m128 masked_read(int d, const float* x) {
 
 namespace {
 
+/// helper function
+inline float horizontal_sum(const __m128 v) {
+    // say, v is [x0, x1, x2, x3]
+
+    // v0 is [x2, x3, ..., ...]
+    const __m128 v0 = _mm_shuffle_ps(v, v, _MM_SHUFFLE(0, 0, 3, 2));
+    // v1 is [x0 + x2, x1 + x3, ..., ...]
+    const __m128 v1 = _mm_add_ps(v, v0);
+    // v2 is [x1 + x3, ..., ..., ...]
+    __m128 v2 = _mm_shuffle_ps(v1, v1, _MM_SHUFFLE(0, 0, 0, 1));
+    // v3 is [x0 + x1 + x2 + x3, ..., ..., ...]
+    const __m128 v3 = _mm_add_ps(v1, v2);
+    // return v3[0]
+    return _mm_cvtss_f32(v3);
+}
+
+#ifdef __AVX2__
+/// helper function for AVX2
+inline float horizontal_sum(const __m256 v) {
+    // add high and low parts
+    const __m128 v0 =
+            _mm_add_ps(_mm256_castps256_ps128(v), _mm256_extractf128_ps(v, 1));
+    // perform horizontal sum on v0
+    return horizontal_sum(v0);
+}
+#endif
+
+#ifdef __AVX512F__
+/// helper function for AVX512
+inline float horizontal_sum(const __m512 v) {
+    // performs better than adding the high and low parts
+    return _mm512_reduce_add_ps(v);
+}
+#endif
+
 /// Function that does a component-wise operation between x and y
 /// to compute L2 distances. ElementOp can then be used in the fvec_op_ny
 /// functions below
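The three horizontal_sum overloads give the kernels below a single spelling for "reduce a register to a scalar" at 128, 256, and 512 bits. What the SSE shuffle/add ladder computes can be checked in isolation (a small hypothetical test harness assuming an x86 compiler; the function body is copied from the hunk above):

    #include <cstdio>
    #include <immintrin.h>

    // SSE horizontal sum, same shuffle/add ladder as in the diff above.
    float horizontal_sum(const __m128 v) {
        const __m128 v0 = _mm_shuffle_ps(v, v, _MM_SHUFFLE(0, 0, 3, 2));
        const __m128 v1 = _mm_add_ps(v, v0);
        const __m128 v2 = _mm_shuffle_ps(v1, v1, _MM_SHUFFLE(0, 0, 0, 1));
        return _mm_cvtss_f32(_mm_add_ps(v1, v2));
    }

    int main() {
        const __m128 v = _mm_setr_ps(1.0f, 2.0f, 3.0f, 4.0f);
        std::printf("%f\n", horizontal_sum(v)); // prints 10.000000
        return 0;
    }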
@@ -260,6 +369,20 @@ struct ElementOpL2 {
         __m128 tmp = _mm_sub_ps(x, y);
         return _mm_mul_ps(tmp, tmp);
     }
+
+#ifdef __AVX2__
+    static __m256 op(__m256 x, __m256 y) {
+        __m256 tmp = _mm256_sub_ps(x, y);
+        return _mm256_mul_ps(tmp, tmp);
+    }
+#endif
+
+#ifdef __AVX512F__
+    static __m512 op(__m512 x, __m512 y) {
+        __m512 tmp = _mm512_sub_ps(x, y);
+        return _mm512_mul_ps(tmp, tmp);
+    }
+#endif
 };
 
 /// Function that does a component-wise operation between x and y
@@ -272,6 +395,18 @@ struct ElementOpIP {
     static __m128 op(__m128 x, __m128 y) {
         return _mm_mul_ps(x, y);
     }
+
+#ifdef __AVX2__
+    static __m256 op(__m256 x, __m256 y) {
+        return _mm256_mul_ps(x, y);
+    }
+#endif
+
+#ifdef __AVX512F__
+    static __m512 op(__m512 x, __m512 y) {
+        return _mm512_mul_ps(x, y);
+    }
+#endif
 };
 
 template <class ElementOp>
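With the AVX2 and AVX-512 overloads added, ElementOpL2 and ElementOpIP expose the same op name at every register width, so one fvec_op_ny_D* template serves both metrics and overload resolution picks the width. A toy model of that policy pattern (hypothetical types, not the faiss code; float vs. double stands in for the register widths):

    #include <cstdio>

    // One policy struct per metric; each op exists at several "widths"
    // and the right one is chosen by overload resolution.
    struct OpL2 {
        static float op(float x, float y) { float t = x - y; return t * t; }
        static double op(double x, double y) { double t = x - y; return t * t; }
    };
    struct OpIP {
        static float op(float x, float y) { return x * y; }
        static double op(double x, double y) { return x * y; }
    };

    template <class ElementOp, class T>
    T accumulate(const T* x, const T* y, int d) {
        T acc = 0;
        for (int i = 0; i < d; ++i)
            acc += ElementOp::op(x[i], y[i]); // width follows T
        return acc;
    }

    int main() {
        const float x[2] = {1, 2}, y[2] = {3, 5};
        std::printf("L2 %f  IP %f\n",
                    accumulate<OpL2>(x, y, 2), accumulate<OpIP>(x, y, 2));
        return 0;
    }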
@@ -314,26 +449,133 @@ void fvec_op_ny_D2(float* dis, const float* x, const float* y, size_t ny) {
     }
 }
 
-template <class ElementOp>
-void fvec_op_ny_D4(float* dis, const float* x, const float* y, size_t ny) {
-    __m128 x0 = _mm_loadu_ps(x);
+#if defined(__AVX512F__)
 
-    for (size_t i = 0; i < ny; i++) {
-        __m128 accu = ElementOp::op(x0, _mm_loadu_ps(y));
-        y += 4;
-        accu = _mm_hadd_ps(accu, accu);
-        accu = _mm_hadd_ps(accu, accu);
-        dis[i] = _mm_cvtss_f32(accu);
+template <>
+void fvec_op_ny_D2<ElementOpIP>(
+        float* dis,
+        const float* x,
+        const float* y,
+        size_t ny) {
+    const size_t ny16 = ny / 16;
+    size_t i = 0;
+
+    if (ny16 > 0) {
+        // process 16 D2-vectors per loop.
+        _mm_prefetch((const char*)y, _MM_HINT_T0);
+        _mm_prefetch((const char*)(y + 32), _MM_HINT_T0);
+
+        const __m512 m0 = _mm512_set1_ps(x[0]);
+        const __m512 m1 = _mm512_set1_ps(x[1]);
+
+        for (i = 0; i < ny16 * 16; i += 16) {
+            _mm_prefetch((const char*)(y + 64), _MM_HINT_T0);
+
+            // load 16x2 matrix and transpose it in registers.
+            // the typical bottleneck is memory access, so
+            // let's trade instructions for the bandwidth.
+
+            __m512 v0;
+            __m512 v1;
+
+            transpose_16x2(
+                    _mm512_loadu_ps(y + 0 * 16),
+                    _mm512_loadu_ps(y + 1 * 16),
+                    v0,
+                    v1);
+
+            // compute distances (dot product)
+            __m512 distances = _mm512_mul_ps(m0, v0);
+            distances = _mm512_fmadd_ps(m1, v1, distances);
+
+            // store
+            _mm512_storeu_ps(dis + i, distances);
+
+            y += 32; // move to the next set of 16x2 elements
+        }
+    }
+
+    if (i < ny) {
+        // process leftovers
+        float x0 = x[0];
+        float x1 = x[1];
+
+        for (; i < ny; i++) {
+            float distance = x0 * y[0] + x1 * y[1];
+            y += 2;
+            dis[i] = distance;
+        }
     }
 }
 
-#ifdef __AVX2__
+template <>
+void fvec_op_ny_D2<ElementOpL2>(
+        float* dis,
+        const float* x,
+        const float* y,
+        size_t ny) {
+    const size_t ny16 = ny / 16;
+    size_t i = 0;
+
+    if (ny16 > 0) {
+        // process 16 D2-vectors per loop.
+        _mm_prefetch((const char*)y, _MM_HINT_T0);
+        _mm_prefetch((const char*)(y + 32), _MM_HINT_T0);
+
+        const __m512 m0 = _mm512_set1_ps(x[0]);
+        const __m512 m1 = _mm512_set1_ps(x[1]);
+
+        for (i = 0; i < ny16 * 16; i += 16) {
+            _mm_prefetch((const char*)(y + 64), _MM_HINT_T0);
+
+            // load 16x2 matrix and transpose it in registers.
+            // the typical bottleneck is memory access, so
+            // let's trade instructions for the bandwidth.
+
+            __m512 v0;
+            __m512 v1;
+
+            transpose_16x2(
+                    _mm512_loadu_ps(y + 0 * 16),
+                    _mm512_loadu_ps(y + 1 * 16),
+                    v0,
+                    v1);
+
+            // compute differences
+            const __m512 d0 = _mm512_sub_ps(m0, v0);
+            const __m512 d1 = _mm512_sub_ps(m1, v1);
+
+            // compute squares of differences
+            __m512 distances = _mm512_mul_ps(d0, d0);
+            distances = _mm512_fmadd_ps(d1, d1, distances);
+
+            // store
+            _mm512_storeu_ps(dis + i, distances);
+
+            y += 32; // move to the next set of 16x2 elements
+        }
+    }
+
+    if (i < ny) {
+        // process leftovers
+        float x0 = x[0];
+        float x1 = x[1];
+
+        for (; i < ny; i++) {
+            float sub0 = x0 - y[0];
+            float sub1 = x1 - y[1];
+            float distance = sub0 * sub0 + sub1 * sub1;
+
+            y += 2;
+            dis[i] = distance;
+        }
+    }
+}
 
-// Specialized versions for AVX2 for any CPUs that support gather/scatter.
-// Todo: implement fvec_op_ny_Dxxx in the same way.
+#elif defined(__AVX2__)
 
 template <>
-void fvec_op_ny_D4<ElementOpIP>(
+void fvec_op_ny_D2<ElementOpIP>(
         float* dis,
         const float* x,
         const float* y,
@@ -342,68 +584,55 @@ void fvec_op_ny_D4<ElementOpIP>(
     size_t i = 0;
 
     if (ny8 > 0) {
-        // process 8 D4-vectors per loop.
-        _mm_prefetch(y, _MM_HINT_NTA);
-        _mm_prefetch(y + 16, _MM_HINT_NTA);
+        // process 8 D2-vectors per loop.
+        _mm_prefetch((const char*)y, _MM_HINT_T0);
+        _mm_prefetch((const char*)(y + 16), _MM_HINT_T0);
 
-        // m0 = (x[0], x[0], x[0], x[0], x[0], x[0], x[0], x[0])
         const __m256 m0 = _mm256_set1_ps(x[0]);
-        // m1 = (x[1], x[1], x[1], x[1], x[1], x[1], x[1], x[1])
         const __m256 m1 = _mm256_set1_ps(x[1]);
-        // m2 = (x[2], x[2], x[2], x[2], x[2], x[2], x[2], x[2])
-        const __m256 m2 = _mm256_set1_ps(x[2]);
-        // m3 = (x[3], x[3], x[3], x[3], x[3], x[3], x[3], x[3])
-        const __m256 m3 = _mm256_set1_ps(x[3]);
 
         for (i = 0; i < ny8 * 8; i += 8) {
-            // load 8x4 matrix and transpose it in registers.
+            _mm_prefetch((const char*)(y + 32), _MM_HINT_T0);
+
+            // load 8x2 matrix and transpose it in registers.
             // the typical bottleneck is memory access, so
             // let's trade instructions for the bandwidth.
 
             __m256 v0;
             __m256 v1;
-            __m256 v2;
-            __m256 v3;
 
-            transpose_8x4(
+            transpose_8x2(
                     _mm256_loadu_ps(y + 0 * 8),
                     _mm256_loadu_ps(y + 1 * 8),
-                    _mm256_loadu_ps(y + 2 * 8),
-                    _mm256_loadu_ps(y + 3 * 8),
                     v0,
-                    v1,
-                    v2,
-                    v3);
+                    v1);
 
             // compute distances
             __m256 distances = _mm256_mul_ps(m0, v0);
             distances = _mm256_fmadd_ps(m1, v1, distances);
-            distances = _mm256_fmadd_ps(m2, v2, distances);
-            distances = _mm256_fmadd_ps(m3, v3, distances);
 
             // store
             _mm256_storeu_ps(dis + i, distances);
 
-            y += 32;
+            y += 16;
         }
     }
 
     if (i < ny) {
         // process leftovers
-        __m128 x0 = _mm_loadu_ps(x);
+        float x0 = x[0];
+        float x1 = x[1];
 
         for (; i < ny; i++) {
-            __m128 accu = ElementOpIP::op(x0, _mm_loadu_ps(y));
-            y += 4;
-            accu = _mm_hadd_ps(accu, accu);
-            accu = _mm_hadd_ps(accu, accu);
-            dis[i] = _mm_cvtss_f32(accu);
+            float distance = x0 * y[0] + x1 * y[1];
+            y += 2;
+            dis[i] = distance;
         }
     }
 }
 
 template <>
-void fvec_op_ny_D4<ElementOpL2>(
+void fvec_op_ny_D2<ElementOpL2>(
         float* dis,
         const float* x,
         const float* y,
@@ -412,68 +641,56 @@ void fvec_op_ny_D4<ElementOpL2>(
     size_t i = 0;
 
     if (ny8 > 0) {
-        // process 8 D4-vectors per loop.
-        _mm_prefetch(y, _MM_HINT_NTA);
-        _mm_prefetch(y + 16, _MM_HINT_NTA);
+        // process 8 D2-vectors per loop.
+        _mm_prefetch((const char*)y, _MM_HINT_T0);
+        _mm_prefetch((const char*)(y + 16), _MM_HINT_T0);
 
-        // m0 = (x[0], x[0], x[0], x[0], x[0], x[0], x[0], x[0])
         const __m256 m0 = _mm256_set1_ps(x[0]);
-        // m1 = (x[1], x[1], x[1], x[1], x[1], x[1], x[1], x[1])
         const __m256 m1 = _mm256_set1_ps(x[1]);
-        // m2 = (x[2], x[2], x[2], x[2], x[2], x[2], x[2], x[2])
-        const __m256 m2 = _mm256_set1_ps(x[2]);
-        // m3 = (x[3], x[3], x[3], x[3], x[3], x[3], x[3], x[3])
-        const __m256 m3 = _mm256_set1_ps(x[3]);
 
         for (i = 0; i < ny8 * 8; i += 8) {
-            // load 8x4 matrix and transpose it in registers.
+            _mm_prefetch((const char*)(y + 32), _MM_HINT_T0);
+
+            // load 8x2 matrix and transpose it in registers.
             // the typical bottleneck is memory access, so
             // let's trade instructions for the bandwidth.
 
             __m256 v0;
             __m256 v1;
-            __m256 v2;
-            __m256 v3;
 
-            transpose_8x4(
+            transpose_8x2(
                     _mm256_loadu_ps(y + 0 * 8),
                     _mm256_loadu_ps(y + 1 * 8),
-                    _mm256_loadu_ps(y + 2 * 8),
-                    _mm256_loadu_ps(y + 3 * 8),
                     v0,
-                    v1,
-                    v2,
-                    v3);
+                    v1);
 
             // compute differences
             const __m256 d0 = _mm256_sub_ps(m0, v0);
             const __m256 d1 = _mm256_sub_ps(m1, v1);
-            const __m256 d2 = _mm256_sub_ps(m2, v2);
-            const __m256 d3 = _mm256_sub_ps(m3, v3);
 
             // compute squares of differences
             __m256 distances = _mm256_mul_ps(d0, d0);
             distances = _mm256_fmadd_ps(d1, d1, distances);
-            distances = _mm256_fmadd_ps(d2, d2, distances);
-            distances = _mm256_fmadd_ps(d3, d3, distances);
 
             // store
             _mm256_storeu_ps(dis + i, distances);
 
-            y += 32;
+            y += 16;
         }
     }
 
     if (i < ny) {
         // process leftovers
-        __m128 x0 = _mm_loadu_ps(x);
+        float x0 = x[0];
+        float x1 = x[1];
 
         for (; i < ny; i++) {
-            __m128 accu = ElementOpL2::op(x0, _mm_loadu_ps(y));
-            y += 4;
-            accu = _mm_hadd_ps(accu, accu);
-            accu = _mm_hadd_ps(accu, accu);
-            dis[i] = _mm_cvtss_f32(accu);
+            float sub0 = x0 - y[0];
+            float sub1 = x1 - y[1];
+            float distance = sub0 * sub0 + sub1 * sub1;
+
+            y += 2;
+            dis[i] = distance;
         }
     }
 }
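All of the rewritten kernels lean on the same trick: load 8 (or 16) interleaved vectors, transpose them in registers so each SIMD lane holds one vector's component, then accumulate with FMAs. In scalar terms the transpose is just a strided regrouping; a plain-C++ model of an 8x2 transpose (a hypothetical helper for illustration; the real transpose_8x2 operates on __m256 registers):

    #include <array>
    #include <cstdio>

    // Scalar model of transpose_8x2: 8 two-dimensional vectors stored
    // interleaved as [x0 y0 x1 y1 ... x7 y7] are regrouped into one array
    // per dimension, so lane i would see vector i's component.
    void transpose_8x2_model(const std::array<float, 16>& in,
                             std::array<float, 8>& dim0,
                             std::array<float, 8>& dim1) {
        for (int i = 0; i < 8; ++i) {
            dim0[i] = in[2 * i + 0];
            dim1[i] = in[2 * i + 1];
        }
    }

    int main() {
        std::array<float, 16> in{};
        for (int i = 0; i < 16; ++i)
            in[i] = float(i);
        std::array<float, 8> d0, d1;
        transpose_8x2_model(in, d0, d1);
        std::printf("%f %f\n", d0[3], d1[3]); // prints 6.000000 7.000000
        return 0;
    }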
@@ -481,77 +698,698 @@ void fvec_op_ny_D4<ElementOpL2>(
 #endif
 
 template <class ElementOp>
-void fvec_op_ny_D8(float* dis, const float* x, const float* y, size_t ny) {
-    __m128 x0 = _mm_loadu_ps(x);
-    __m128 x1 = _mm_loadu_ps(x + 4);
-
-    for (size_t i = 0; i < ny; i++) {
-        __m128 accu = ElementOp::op(x0, _mm_loadu_ps(y));
-        y += 4;
-        accu = _mm_add_ps(accu, ElementOp::op(x1, _mm_loadu_ps(y)));
-        y += 4;
-        accu = _mm_hadd_ps(accu, accu);
-        accu = _mm_hadd_ps(accu, accu);
-        dis[i] = _mm_cvtss_f32(accu);
-    }
-}
-
-template <class ElementOp>
-void fvec_op_ny_D12(float* dis, const float* x, const float* y, size_t ny) {
+void fvec_op_ny_D4(float* dis, const float* x, const float* y, size_t ny) {
     __m128 x0 = _mm_loadu_ps(x);
-    __m128 x1 = _mm_loadu_ps(x + 4);
-    __m128 x2 = _mm_loadu_ps(x + 8);
 
     for (size_t i = 0; i < ny; i++) {
         __m128 accu = ElementOp::op(x0, _mm_loadu_ps(y));
         y += 4;
-        accu = _mm_add_ps(accu, ElementOp::op(x1, _mm_loadu_ps(y)));
-        y += 4;
-        accu = _mm_add_ps(accu, ElementOp::op(x2, _mm_loadu_ps(y)));
-        y += 4;
-        accu = _mm_hadd_ps(accu, accu);
-        accu = _mm_hadd_ps(accu, accu);
-        dis[i] = _mm_cvtss_f32(accu);
+        dis[i] = horizontal_sum(accu);
     }
 }
 
-} // anonymous namespace
+#if defined(__AVX512F__)
 
-void fvec_L2sqr_ny(
+template <>
+void fvec_op_ny_D4<ElementOpIP>(
         float* dis,
         const float* x,
         const float* y,
-        size_t d,
         size_t ny) {
-    // optimized for a few special cases
+    const size_t ny16 = ny / 16;
+    size_t i = 0;
 
-#define DISPATCH(dval) \
-    case dval: \
-        fvec_op_ny_D##dval<ElementOpL2>(dis, x, y, ny); \
-        return;
+    if (ny16 > 0) {
+        // process 16 D4-vectors per loop.
+        const __m512 m0 = _mm512_set1_ps(x[0]);
+        const __m512 m1 = _mm512_set1_ps(x[1]);
+        const __m512 m2 = _mm512_set1_ps(x[2]);
+        const __m512 m3 = _mm512_set1_ps(x[3]);
 
-    switch (d) {
-        DISPATCH(1)
-        DISPATCH(2)
-        DISPATCH(4)
-        DISPATCH(8)
-        DISPATCH(12)
-        default:
-            fvec_L2sqr_ny_ref(dis, x, y, d, ny);
-            return;
+        for (i = 0; i < ny16 * 16; i += 16) {
+            // load 16x4 matrix and transpose it in registers.
+            // the typical bottleneck is memory access, so
+            // let's trade instructions for the bandwidth.
+
+            __m512 v0;
+            __m512 v1;
+            __m512 v2;
+            __m512 v3;
+
+            transpose_16x4(
+                    _mm512_loadu_ps(y + 0 * 16),
+                    _mm512_loadu_ps(y + 1 * 16),
+                    _mm512_loadu_ps(y + 2 * 16),
+                    _mm512_loadu_ps(y + 3 * 16),
+                    v0,
+                    v1,
+                    v2,
+                    v3);
+
+            // compute distances
+            __m512 distances = _mm512_mul_ps(m0, v0);
+            distances = _mm512_fmadd_ps(m1, v1, distances);
+            distances = _mm512_fmadd_ps(m2, v2, distances);
+            distances = _mm512_fmadd_ps(m3, v3, distances);
+
+            // store
+            _mm512_storeu_ps(dis + i, distances);
+
+            y += 64; // move to the next set of 16x4 elements
+        }
+    }
+
+    if (i < ny) {
+        // process leftovers
+        __m128 x0 = _mm_loadu_ps(x);
+
+        for (; i < ny; i++) {
+            __m128 accu = ElementOpIP::op(x0, _mm_loadu_ps(y));
+            y += 4;
+            dis[i] = horizontal_sum(accu);
+        }
     }
-#undef DISPATCH
 }
 
-void fvec_inner_products_ny(
+template <>
+void fvec_op_ny_D4<ElementOpL2>(
         float* dis,
         const float* x,
         const float* y,
-        size_t d,
         size_t ny) {
-#define DISPATCH(dval) \
-    case dval: \
-        fvec_op_ny_D##dval<ElementOpIP>(dis, x, y, ny); \
+    const size_t ny16 = ny / 16;
+    size_t i = 0;
+
+    if (ny16 > 0) {
+        // process 16 D4-vectors per loop.
+        const __m512 m0 = _mm512_set1_ps(x[0]);
+        const __m512 m1 = _mm512_set1_ps(x[1]);
+        const __m512 m2 = _mm512_set1_ps(x[2]);
+        const __m512 m3 = _mm512_set1_ps(x[3]);
+
+        for (i = 0; i < ny16 * 16; i += 16) {
+            // load 16x4 matrix and transpose it in registers.
+            // the typical bottleneck is memory access, so
+            // let's trade instructions for the bandwidth.
+
+            __m512 v0;
+            __m512 v1;
+            __m512 v2;
+            __m512 v3;
+
+            transpose_16x4(
+                    _mm512_loadu_ps(y + 0 * 16),
+                    _mm512_loadu_ps(y + 1 * 16),
+                    _mm512_loadu_ps(y + 2 * 16),
+                    _mm512_loadu_ps(y + 3 * 16),
+                    v0,
+                    v1,
+                    v2,
+                    v3);
+
+            // compute differences
+            const __m512 d0 = _mm512_sub_ps(m0, v0);
+            const __m512 d1 = _mm512_sub_ps(m1, v1);
+            const __m512 d2 = _mm512_sub_ps(m2, v2);
+            const __m512 d3 = _mm512_sub_ps(m3, v3);
+
+            // compute squares of differences
+            __m512 distances = _mm512_mul_ps(d0, d0);
+            distances = _mm512_fmadd_ps(d1, d1, distances);
+            distances = _mm512_fmadd_ps(d2, d2, distances);
+            distances = _mm512_fmadd_ps(d3, d3, distances);
+
+            // store
+            _mm512_storeu_ps(dis + i, distances);
+
+            y += 64; // move to the next set of 16x4 elements
+        }
+    }
+
+    if (i < ny) {
+        // process leftovers
+        __m128 x0 = _mm_loadu_ps(x);
+
+        for (; i < ny; i++) {
+            __m128 accu = ElementOpL2::op(x0, _mm_loadu_ps(y));
+            y += 4;
+            dis[i] = horizontal_sum(accu);
+        }
+    }
+}
+
+#elif defined(__AVX2__)
+
+template <>
+void fvec_op_ny_D4<ElementOpIP>(
+        float* dis,
+        const float* x,
+        const float* y,
+        size_t ny) {
+    const size_t ny8 = ny / 8;
+    size_t i = 0;
+
+    if (ny8 > 0) {
+        // process 8 D4-vectors per loop.
+        const __m256 m0 = _mm256_set1_ps(x[0]);
+        const __m256 m1 = _mm256_set1_ps(x[1]);
+        const __m256 m2 = _mm256_set1_ps(x[2]);
+        const __m256 m3 = _mm256_set1_ps(x[3]);
+
+        for (i = 0; i < ny8 * 8; i += 8) {
+            // load 8x4 matrix and transpose it in registers.
+            // the typical bottleneck is memory access, so
+            // let's trade instructions for the bandwidth.
+
+            __m256 v0;
+            __m256 v1;
+            __m256 v2;
+            __m256 v3;
+
+            transpose_8x4(
+                    _mm256_loadu_ps(y + 0 * 8),
+                    _mm256_loadu_ps(y + 1 * 8),
+                    _mm256_loadu_ps(y + 2 * 8),
+                    _mm256_loadu_ps(y + 3 * 8),
+                    v0,
+                    v1,
+                    v2,
+                    v3);
+
+            // compute distances
+            __m256 distances = _mm256_mul_ps(m0, v0);
+            distances = _mm256_fmadd_ps(m1, v1, distances);
+            distances = _mm256_fmadd_ps(m2, v2, distances);
+            distances = _mm256_fmadd_ps(m3, v3, distances);
+
+            // store
+            _mm256_storeu_ps(dis + i, distances);
+
+            y += 32;
+        }
+    }
+
+    if (i < ny) {
+        // process leftovers
+        __m128 x0 = _mm_loadu_ps(x);
+
+        for (; i < ny; i++) {
+            __m128 accu = ElementOpIP::op(x0, _mm_loadu_ps(y));
+            y += 4;
+            dis[i] = horizontal_sum(accu);
+        }
+    }
+}
+
+template <>
+void fvec_op_ny_D4<ElementOpL2>(
+        float* dis,
+        const float* x,
+        const float* y,
+        size_t ny) {
+    const size_t ny8 = ny / 8;
+    size_t i = 0;
+
+    if (ny8 > 0) {
+        // process 8 D4-vectors per loop.
+        const __m256 m0 = _mm256_set1_ps(x[0]);
+        const __m256 m1 = _mm256_set1_ps(x[1]);
+        const __m256 m2 = _mm256_set1_ps(x[2]);
+        const __m256 m3 = _mm256_set1_ps(x[3]);
+
+        for (i = 0; i < ny8 * 8; i += 8) {
+            // load 8x4 matrix and transpose it in registers.
+            // the typical bottleneck is memory access, so
+            // let's trade instructions for the bandwidth.
+
+            __m256 v0;
+            __m256 v1;
+            __m256 v2;
+            __m256 v3;
+
+            transpose_8x4(
+                    _mm256_loadu_ps(y + 0 * 8),
+                    _mm256_loadu_ps(y + 1 * 8),
+                    _mm256_loadu_ps(y + 2 * 8),
+                    _mm256_loadu_ps(y + 3 * 8),
+                    v0,
+                    v1,
+                    v2,
+                    v3);
+
+            // compute differences
+            const __m256 d0 = _mm256_sub_ps(m0, v0);
+            const __m256 d1 = _mm256_sub_ps(m1, v1);
+            const __m256 d2 = _mm256_sub_ps(m2, v2);
+            const __m256 d3 = _mm256_sub_ps(m3, v3);
+
+            // compute squares of differences
+            __m256 distances = _mm256_mul_ps(d0, d0);
+            distances = _mm256_fmadd_ps(d1, d1, distances);
+            distances = _mm256_fmadd_ps(d2, d2, distances);
+            distances = _mm256_fmadd_ps(d3, d3, distances);
+
+            // store
+            _mm256_storeu_ps(dis + i, distances);
+
+            y += 32;
+        }
+    }
+
+    if (i < ny) {
+        // process leftovers
+        __m128 x0 = _mm_loadu_ps(x);
+
+        for (; i < ny; i++) {
+            __m128 accu = ElementOpL2::op(x0, _mm_loadu_ps(y));
+            y += 4;
+            dis[i] = horizontal_sum(accu);
+        }
+    }
+}
+
+#endif
+
+template <class ElementOp>
+void fvec_op_ny_D8(float* dis, const float* x, const float* y, size_t ny) {
+    __m128 x0 = _mm_loadu_ps(x);
+    __m128 x1 = _mm_loadu_ps(x + 4);
+
+    for (size_t i = 0; i < ny; i++) {
+        __m128 accu = ElementOp::op(x0, _mm_loadu_ps(y));
+        y += 4;
+        accu = _mm_add_ps(accu, ElementOp::op(x1, _mm_loadu_ps(y)));
+        y += 4;
+        accu = _mm_hadd_ps(accu, accu);
+        accu = _mm_hadd_ps(accu, accu);
+        dis[i] = _mm_cvtss_f32(accu);
+    }
+}
+
+#if defined(__AVX512F__)
+
+template <>
+void fvec_op_ny_D8<ElementOpIP>(
+        float* dis,
+        const float* x,
+        const float* y,
+        size_t ny) {
+    const size_t ny16 = ny / 16;
+    size_t i = 0;
+
+    if (ny16 > 0) {
+        // process 16 D16-vectors per loop.
+        const __m512 m0 = _mm512_set1_ps(x[0]);
+        const __m512 m1 = _mm512_set1_ps(x[1]);
+        const __m512 m2 = _mm512_set1_ps(x[2]);
+        const __m512 m3 = _mm512_set1_ps(x[3]);
+        const __m512 m4 = _mm512_set1_ps(x[4]);
+        const __m512 m5 = _mm512_set1_ps(x[5]);
+        const __m512 m6 = _mm512_set1_ps(x[6]);
+        const __m512 m7 = _mm512_set1_ps(x[7]);
+
+        for (i = 0; i < ny16 * 16; i += 16) {
+            // load 16x8 matrix and transpose it in registers.
+            // the typical bottleneck is memory access, so
+            // let's trade instructions for the bandwidth.
+
+            __m512 v0;
+            __m512 v1;
+            __m512 v2;
+            __m512 v3;
+            __m512 v4;
+            __m512 v5;
+            __m512 v6;
+            __m512 v7;
+
+            transpose_16x8(
+                    _mm512_loadu_ps(y + 0 * 16),
+                    _mm512_loadu_ps(y + 1 * 16),
+                    _mm512_loadu_ps(y + 2 * 16),
+                    _mm512_loadu_ps(y + 3 * 16),
+                    _mm512_loadu_ps(y + 4 * 16),
+                    _mm512_loadu_ps(y + 5 * 16),
+                    _mm512_loadu_ps(y + 6 * 16),
+                    _mm512_loadu_ps(y + 7 * 16),
+                    v0,
+                    v1,
+                    v2,
+                    v3,
+                    v4,
+                    v5,
+                    v6,
+                    v7);
+
+            // compute distances
+            __m512 distances = _mm512_mul_ps(m0, v0);
+            distances = _mm512_fmadd_ps(m1, v1, distances);
+            distances = _mm512_fmadd_ps(m2, v2, distances);
+            distances = _mm512_fmadd_ps(m3, v3, distances);
+            distances = _mm512_fmadd_ps(m4, v4, distances);
+            distances = _mm512_fmadd_ps(m5, v5, distances);
+            distances = _mm512_fmadd_ps(m6, v6, distances);
+            distances = _mm512_fmadd_ps(m7, v7, distances);
+
+            // store
+            _mm512_storeu_ps(dis + i, distances);
+
+            y += 128; // 16 floats * 8 rows
+        }
+    }
+
+    if (i < ny) {
+        // process leftovers
+        __m256 x0 = _mm256_loadu_ps(x);
+
+        for (; i < ny; i++) {
+            __m256 accu = ElementOpIP::op(x0, _mm256_loadu_ps(y));
+            y += 8;
+            dis[i] = horizontal_sum(accu);
+        }
+    }
+}
+
+template <>
+void fvec_op_ny_D8<ElementOpL2>(
+        float* dis,
+        const float* x,
+        const float* y,
+        size_t ny) {
+    const size_t ny16 = ny / 16;
+    size_t i = 0;
+
+    if (ny16 > 0) {
+        // process 16 D16-vectors per loop.
+        const __m512 m0 = _mm512_set1_ps(x[0]);
+        const __m512 m1 = _mm512_set1_ps(x[1]);
+        const __m512 m2 = _mm512_set1_ps(x[2]);
+        const __m512 m3 = _mm512_set1_ps(x[3]);
+        const __m512 m4 = _mm512_set1_ps(x[4]);
+        const __m512 m5 = _mm512_set1_ps(x[5]);
+        const __m512 m6 = _mm512_set1_ps(x[6]);
+        const __m512 m7 = _mm512_set1_ps(x[7]);
+
+        for (i = 0; i < ny16 * 16; i += 16) {
+            // load 16x8 matrix and transpose it in registers.
+            // the typical bottleneck is memory access, so
+            // let's trade instructions for the bandwidth.
+
+            __m512 v0;
+            __m512 v1;
+            __m512 v2;
+            __m512 v3;
+            __m512 v4;
+            __m512 v5;
+            __m512 v6;
+            __m512 v7;
+
+            transpose_16x8(
+                    _mm512_loadu_ps(y + 0 * 16),
+                    _mm512_loadu_ps(y + 1 * 16),
+                    _mm512_loadu_ps(y + 2 * 16),
+                    _mm512_loadu_ps(y + 3 * 16),
+                    _mm512_loadu_ps(y + 4 * 16),
+                    _mm512_loadu_ps(y + 5 * 16),
+                    _mm512_loadu_ps(y + 6 * 16),
+                    _mm512_loadu_ps(y + 7 * 16),
+                    v0,
+                    v1,
+                    v2,
+                    v3,
+                    v4,
+                    v5,
+                    v6,
+                    v7);
+
+            // compute differences
+            const __m512 d0 = _mm512_sub_ps(m0, v0);
+            const __m512 d1 = _mm512_sub_ps(m1, v1);
+            const __m512 d2 = _mm512_sub_ps(m2, v2);
+            const __m512 d3 = _mm512_sub_ps(m3, v3);
+            const __m512 d4 = _mm512_sub_ps(m4, v4);
+            const __m512 d5 = _mm512_sub_ps(m5, v5);
+            const __m512 d6 = _mm512_sub_ps(m6, v6);
+            const __m512 d7 = _mm512_sub_ps(m7, v7);
+
+            // compute squares of differences
+            __m512 distances = _mm512_mul_ps(d0, d0);
+            distances = _mm512_fmadd_ps(d1, d1, distances);
+            distances = _mm512_fmadd_ps(d2, d2, distances);
+            distances = _mm512_fmadd_ps(d3, d3, distances);
+            distances = _mm512_fmadd_ps(d4, d4, distances);
+            distances = _mm512_fmadd_ps(d5, d5, distances);
+            distances = _mm512_fmadd_ps(d6, d6, distances);
+            distances = _mm512_fmadd_ps(d7, d7, distances);
+
+            // store
+            _mm512_storeu_ps(dis + i, distances);
+
+            y += 128; // 16 floats * 8 rows
+        }
+    }
+
+    if (i < ny) {
+        // process leftovers
+        __m256 x0 = _mm256_loadu_ps(x);
+
+        for (; i < ny; i++) {
+            __m256 accu = ElementOpL2::op(x0, _mm256_loadu_ps(y));
+            y += 8;
+            dis[i] = horizontal_sum(accu);
+        }
+    }
+}
+
+#elif defined(__AVX2__)
+
+template <>
+void fvec_op_ny_D8<ElementOpIP>(
+        float* dis,
+        const float* x,
+        const float* y,
+        size_t ny) {
+    const size_t ny8 = ny / 8;
+    size_t i = 0;
+
+    if (ny8 > 0) {
+        // process 8 D8-vectors per loop.
+        const __m256 m0 = _mm256_set1_ps(x[0]);
+        const __m256 m1 = _mm256_set1_ps(x[1]);
+        const __m256 m2 = _mm256_set1_ps(x[2]);
+        const __m256 m3 = _mm256_set1_ps(x[3]);
+        const __m256 m4 = _mm256_set1_ps(x[4]);
+        const __m256 m5 = _mm256_set1_ps(x[5]);
+        const __m256 m6 = _mm256_set1_ps(x[6]);
+        const __m256 m7 = _mm256_set1_ps(x[7]);
+
+        for (i = 0; i < ny8 * 8; i += 8) {
+            // load 8x8 matrix and transpose it in registers.
+            // the typical bottleneck is memory access, so
+            // let's trade instructions for the bandwidth.
+
+            __m256 v0;
+            __m256 v1;
+            __m256 v2;
+            __m256 v3;
+            __m256 v4;
+            __m256 v5;
+            __m256 v6;
+            __m256 v7;
+
+            transpose_8x8(
+                    _mm256_loadu_ps(y + 0 * 8),
+                    _mm256_loadu_ps(y + 1 * 8),
+                    _mm256_loadu_ps(y + 2 * 8),
+                    _mm256_loadu_ps(y + 3 * 8),
+                    _mm256_loadu_ps(y + 4 * 8),
+                    _mm256_loadu_ps(y + 5 * 8),
+                    _mm256_loadu_ps(y + 6 * 8),
+                    _mm256_loadu_ps(y + 7 * 8),
+                    v0,
+                    v1,
+                    v2,
+                    v3,
+                    v4,
+                    v5,
+                    v6,
+                    v7);
+
+            // compute distances
+            __m256 distances = _mm256_mul_ps(m0, v0);
+            distances = _mm256_fmadd_ps(m1, v1, distances);
+            distances = _mm256_fmadd_ps(m2, v2, distances);
+            distances = _mm256_fmadd_ps(m3, v3, distances);
+            distances = _mm256_fmadd_ps(m4, v4, distances);
+            distances = _mm256_fmadd_ps(m5, v5, distances);
+            distances = _mm256_fmadd_ps(m6, v6, distances);
+            distances = _mm256_fmadd_ps(m7, v7, distances);
+
+            // store
+            _mm256_storeu_ps(dis + i, distances);
+
+            y += 64;
+        }
+    }
+
+    if (i < ny) {
+        // process leftovers
+        __m256 x0 = _mm256_loadu_ps(x);
+
+        for (; i < ny; i++) {
+            __m256 accu = ElementOpIP::op(x0, _mm256_loadu_ps(y));
+            y += 8;
+            dis[i] = horizontal_sum(accu);
+        }
+    }
+}
+
+template <>
+void fvec_op_ny_D8<ElementOpL2>(
+        float* dis,
+        const float* x,
+        const float* y,
+        size_t ny) {
+    const size_t ny8 = ny / 8;
+    size_t i = 0;
+
+    if (ny8 > 0) {
+        // process 8 D8-vectors per loop.
+        const __m256 m0 = _mm256_set1_ps(x[0]);
+        const __m256 m1 = _mm256_set1_ps(x[1]);
+        const __m256 m2 = _mm256_set1_ps(x[2]);
+        const __m256 m3 = _mm256_set1_ps(x[3]);
+        const __m256 m4 = _mm256_set1_ps(x[4]);
+        const __m256 m5 = _mm256_set1_ps(x[5]);
+        const __m256 m6 = _mm256_set1_ps(x[6]);
+        const __m256 m7 = _mm256_set1_ps(x[7]);
+
+        for (i = 0; i < ny8 * 8; i += 8) {
+            // load 8x8 matrix and transpose it in registers.
+            // the typical bottleneck is memory access, so
+            // let's trade instructions for the bandwidth.
+
+            __m256 v0;
+            __m256 v1;
+            __m256 v2;
+            __m256 v3;
+            __m256 v4;
+            __m256 v5;
+            __m256 v6;
+            __m256 v7;
+
+            transpose_8x8(
+                    _mm256_loadu_ps(y + 0 * 8),
+                    _mm256_loadu_ps(y + 1 * 8),
+                    _mm256_loadu_ps(y + 2 * 8),
+                    _mm256_loadu_ps(y + 3 * 8),
+                    _mm256_loadu_ps(y + 4 * 8),
+                    _mm256_loadu_ps(y + 5 * 8),
+                    _mm256_loadu_ps(y + 6 * 8),
+                    _mm256_loadu_ps(y + 7 * 8),
+                    v0,
+                    v1,
+                    v2,
+                    v3,
+                    v4,
+                    v5,
+                    v6,
+                    v7);
+
+            // compute differences
+            const __m256 d0 = _mm256_sub_ps(m0, v0);
+            const __m256 d1 = _mm256_sub_ps(m1, v1);
+            const __m256 d2 = _mm256_sub_ps(m2, v2);
+            const __m256 d3 = _mm256_sub_ps(m3, v3);
+            const __m256 d4 = _mm256_sub_ps(m4, v4);
+            const __m256 d5 = _mm256_sub_ps(m5, v5);
+            const __m256 d6 = _mm256_sub_ps(m6, v6);
+            const __m256 d7 = _mm256_sub_ps(m7, v7);
+
+            // compute squares of differences
+            __m256 distances = _mm256_mul_ps(d0, d0);
+            distances = _mm256_fmadd_ps(d1, d1, distances);
+            distances = _mm256_fmadd_ps(d2, d2, distances);
+            distances = _mm256_fmadd_ps(d3, d3, distances);
+            distances = _mm256_fmadd_ps(d4, d4, distances);
+            distances = _mm256_fmadd_ps(d5, d5, distances);
+            distances = _mm256_fmadd_ps(d6, d6, distances);
+            distances = _mm256_fmadd_ps(d7, d7, distances);
+
+            // store
+            _mm256_storeu_ps(dis + i, distances);
+
+            y += 64;
+        }
+    }
+
+    if (i < ny) {
+        // process leftovers
+        __m256 x0 = _mm256_loadu_ps(x);
+
+        for (; i < ny; i++) {
+            __m256 accu = ElementOpL2::op(x0, _mm256_loadu_ps(y));
+            y += 8;
+            dis[i] = horizontal_sum(accu);
+        }
+    }
+}
+
+#endif
+
+template <class ElementOp>
+void fvec_op_ny_D12(float* dis, const float* x, const float* y, size_t ny) {
+    __m128 x0 = _mm_loadu_ps(x);
+    __m128 x1 = _mm_loadu_ps(x + 4);
+    __m128 x2 = _mm_loadu_ps(x + 8);
+
+    for (size_t i = 0; i < ny; i++) {
+        __m128 accu = ElementOp::op(x0, _mm_loadu_ps(y));
+        y += 4;
+        accu = _mm_add_ps(accu, ElementOp::op(x1, _mm_loadu_ps(y)));
+        y += 4;
+        accu = _mm_add_ps(accu, ElementOp::op(x2, _mm_loadu_ps(y)));
+        y += 4;
+        dis[i] = horizontal_sum(accu);
+    }
+}
+
+} // anonymous namespace
+
+void fvec_L2sqr_ny(
+        float* dis,
+        const float* x,
+        const float* y,
+        size_t d,
+        size_t ny) {
+    // optimized for a few special cases
+
+#define DISPATCH(dval) \
+    case dval: \
+        fvec_op_ny_D##dval<ElementOpL2>(dis, x, y, ny); \
+        return;
+
+    switch (d) {
+        DISPATCH(1)
+        DISPATCH(2)
+        DISPATCH(4)
+        DISPATCH(8)
+        DISPATCH(12)
+        default:
+            fvec_L2sqr_ny_ref(dis, x, y, d, ny);
+            return;
+    }
+#undef DISPATCH
+}
+
+void fvec_inner_products_ny(
+        float* dis,
+        const float* x,
+        const float* y,
+        size_t d,
+        size_t ny) {
+#define DISPATCH(dval) \
+    case dval: \
+        fvec_op_ny_D##dval<ElementOpIP>(dis, x, y, ny); \
         return;
 
     switch (d) {
@@ -564,121 +1402,506 @@ void fvec_inner_products_ny(
             fvec_inner_products_ny_ref(dis, x, y, d, ny);
             return;
     }
-#undef DISPATCH
+#undef DISPATCH
+}
+
+#if defined(__AVX512F__)
+
+template <size_t DIM>
+void fvec_L2sqr_ny_y_transposed_D(
+        float* distances,
+        const float* x,
+        const float* y,
+        const float* y_sqlen,
+        const size_t d_offset,
+        size_t ny) {
+    // current index being processed
+    size_t i = 0;
+
+    // squared length of x
+    float x_sqlen = 0;
+    for (size_t j = 0; j < DIM; j++) {
+        x_sqlen += x[j] * x[j];
+    }
+
+    // process 16 vectors per loop
+    const size_t ny16 = ny / 16;
+
+    if (ny16 > 0) {
+        // m[i] = (2 * x[i], ... 2 * x[i])
+        __m512 m[DIM];
+        for (size_t j = 0; j < DIM; j++) {
+            m[j] = _mm512_set1_ps(x[j]);
+            m[j] = _mm512_add_ps(m[j], m[j]); // m[j] = 2 * x[j]
+        }
+
+        __m512 x_sqlen_ymm = _mm512_set1_ps(x_sqlen);
+
+        for (; i < ny16 * 16; i += 16) {
+            // Load vectors for 16 dimensions
+            __m512 v[DIM];
+            for (size_t j = 0; j < DIM; j++) {
+                v[j] = _mm512_loadu_ps(y + j * d_offset);
+            }
+
+            // Compute dot products
+            __m512 dp = _mm512_fnmadd_ps(m[0], v[0], x_sqlen_ymm);
+            for (size_t j = 1; j < DIM; j++) {
+                dp = _mm512_fnmadd_ps(m[j], v[j], dp);
+            }
+
+            // Compute y^2 - (2 * x, y) + x^2
+            __m512 distances_v = _mm512_add_ps(_mm512_loadu_ps(y_sqlen), dp);
+
+            _mm512_storeu_ps(distances + i, distances_v);
+
+            // Scroll y and y_sqlen forward
+            y += 16;
+            y_sqlen += 16;
+        }
+    }
+
+    if (i < ny) {
+        // Process leftovers
+        for (; i < ny; i++) {
+            float dp = 0;
+            for (size_t j = 0; j < DIM; j++) {
+                dp += x[j] * y[j * d_offset];
+            }
+
+            // Compute y^2 - 2 * (x, y), which is sufficient for looking for the
+            // lowest distance.
+            const float distance = y_sqlen[0] - 2 * dp + x_sqlen;
+            distances[i] = distance;
+
+            y += 1;
+            y_sqlen += 1;
+        }
+    }
+}
+
+#elif defined(__AVX2__)
+
+template <size_t DIM>
+void fvec_L2sqr_ny_y_transposed_D(
+        float* distances,
+        const float* x,
+        const float* y,
+        const float* y_sqlen,
+        const size_t d_offset,
+        size_t ny) {
+    // current index being processed
+    size_t i = 0;
+
+    // squared length of x
+    float x_sqlen = 0;
+    for (size_t j = 0; j < DIM; j++) {
+        x_sqlen += x[j] * x[j];
+    }
+
+    // process 8 vectors per loop.
+    const size_t ny8 = ny / 8;
+
+    if (ny8 > 0) {
+        // m[i] = (2 * x[i], ... 2 * x[i])
+        __m256 m[DIM];
+        for (size_t j = 0; j < DIM; j++) {
+            m[j] = _mm256_set1_ps(x[j]);
+            m[j] = _mm256_add_ps(m[j], m[j]);
+        }
+
+        __m256 x_sqlen_ymm = _mm256_set1_ps(x_sqlen);
+
+        for (; i < ny8 * 8; i += 8) {
+            // collect dim 0 for 8 D4-vectors.
+            const __m256 v0 = _mm256_loadu_ps(y + 0 * d_offset);
+
+            // compute dot products
+            // this is x^2 - 2x[0]*y[0]
+            __m256 dp = _mm256_fnmadd_ps(m[0], v0, x_sqlen_ymm);
+
+            for (size_t j = 1; j < DIM; j++) {
+                // collect dim j for 8 D4-vectors.
+                const __m256 vj = _mm256_loadu_ps(y + j * d_offset);
+                dp = _mm256_fnmadd_ps(m[j], vj, dp);
+            }
+
+            // we've got x^2 - (2x, y) at this point
+
+            // y^2 - (2x, y) + x^2
+            __m256 distances_v = _mm256_add_ps(_mm256_loadu_ps(y_sqlen), dp);
+
+            _mm256_storeu_ps(distances + i, distances_v);
+
+            // scroll y and y_sqlen forward.
+            y += 8;
+            y_sqlen += 8;
+        }
+    }
+
+    if (i < ny) {
+        // process leftovers
+        for (; i < ny; i++) {
+            float dp = 0;
+            for (size_t j = 0; j < DIM; j++) {
+                dp += x[j] * y[j * d_offset];
+            }
+
+            // compute y^2 - 2 * (x, y), which is sufficient for looking for the
+            // lowest distance.
+            const float distance = y_sqlen[0] - 2 * dp + x_sqlen;
+            distances[i] = distance;
+
+            y += 1;
+            y_sqlen += 1;
+        }
+    }
+}
+
+#endif
+
+void fvec_L2sqr_ny_transposed(
+        float* dis,
+        const float* x,
+        const float* y,
+        const float* y_sqlen,
+        size_t d,
+        size_t d_offset,
+        size_t ny) {
+    // optimized for a few special cases
+
+#ifdef __AVX2__
+#define DISPATCH(dval) \
+    case dval: \
+        return fvec_L2sqr_ny_y_transposed_D<dval>( \
+                dis, x, y, y_sqlen, d_offset, ny);
+
+    switch (d) {
+        DISPATCH(1)
+        DISPATCH(2)
+        DISPATCH(4)
+        DISPATCH(8)
+        default:
+            return fvec_L2sqr_ny_y_transposed_ref(
+                    dis, x, y, y_sqlen, d, d_offset, ny);
+    }
+#undef DISPATCH
+#else
+    // non-AVX2 case
+    return fvec_L2sqr_ny_y_transposed_ref(dis, x, y, y_sqlen, d, d_offset, ny);
+#endif
+}
+
+#if defined(__AVX512F__)
+
+size_t fvec_L2sqr_ny_nearest_D2(
+        float* distances_tmp_buffer,
+        const float* x,
+        const float* y,
+        size_t ny) {
+    // this implementation does not use distances_tmp_buffer.
+
+    size_t i = 0;
+    float current_min_distance = HUGE_VALF;
+    size_t current_min_index = 0;
+
+    const size_t ny16 = ny / 16;
+    if (ny16 > 0) {
+        _mm_prefetch((const char*)y, _MM_HINT_T0);
+        _mm_prefetch((const char*)(y + 32), _MM_HINT_T0);
+
+        __m512 min_distances = _mm512_set1_ps(HUGE_VALF);
+        __m512i min_indices = _mm512_set1_epi32(0);
+
+        __m512i current_indices = _mm512_setr_epi32(
+                0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15);
+        const __m512i indices_increment = _mm512_set1_epi32(16);
+
+        const __m512 m0 = _mm512_set1_ps(x[0]);
+        const __m512 m1 = _mm512_set1_ps(x[1]);
+
+        for (; i < ny16 * 16; i += 16) {
+            _mm_prefetch((const char*)(y + 64), _MM_HINT_T0);
+
+            __m512 v0;
+            __m512 v1;
+
+            transpose_16x2(
+                    _mm512_loadu_ps(y + 0 * 16),
+                    _mm512_loadu_ps(y + 1 * 16),
+                    v0,
+                    v1);
+
+            const __m512 d0 = _mm512_sub_ps(m0, v0);
+            const __m512 d1 = _mm512_sub_ps(m1, v1);
+
+            __m512 distances = _mm512_mul_ps(d0, d0);
+            distances = _mm512_fmadd_ps(d1, d1, distances);
+
+            __mmask16 comparison =
+                    _mm512_cmp_ps_mask(distances, min_distances, _CMP_LT_OS);
+
+            min_distances = _mm512_min_ps(distances, min_distances);
+            min_indices = _mm512_mask_blend_epi32(
+                    comparison, min_indices, current_indices);
+
+            current_indices =
+                    _mm512_add_epi32(current_indices, indices_increment);
+
+            y += 32;
+        }
+
+        alignas(64) float min_distances_scalar[16];
+        alignas(64) uint32_t min_indices_scalar[16];
+        _mm512_store_ps(min_distances_scalar, min_distances);
+        _mm512_store_epi32(min_indices_scalar, min_indices);
+
+        for (size_t j = 0; j < 16; j++) {
+            if (current_min_distance > min_distances_scalar[j]) {
+                current_min_distance = min_distances_scalar[j];
+                current_min_index = min_indices_scalar[j];
+            }
+        }
+    }
+
+    if (i < ny) {
+        float x0 = x[0];
+        float x1 = x[1];
+
+        for (; i < ny; i++) {
+            float sub0 = x0 - y[0];
+            float sub1 = x1 - y[1];
+            float distance = sub0 * sub0 + sub1 * sub1;
+
+            y += 2;
+
+            if (current_min_distance > distance) {
+                current_min_distance = distance;
+                current_min_index = i;
+            }
+        }
+    }
+
+    return current_min_index;
 }
 
-#ifdef __AVX2__
-template <size_t DIM>
-void fvec_L2sqr_ny_y_transposed_D(
-        float* distances,
+size_t fvec_L2sqr_ny_nearest_D4(
+        float* distances_tmp_buffer,
         const float* x,
         const float* y,
-        const float* y_sqlen,
-        const size_t d_offset,
         size_t ny) {
-    // current index being processed
+    // this implementation does not use distances_tmp_buffer.
+
     size_t i = 0;
+    float current_min_distance = HUGE_VALF;
+    size_t current_min_index = 0;
 
-    // squared length of x
-    float x_sqlen = 0;
-    ;
-    for (size_t j = 0; j < DIM; j++) {
-        x_sqlen += x[j] * x[j];
-    }
+    const size_t ny16 = ny / 16;
 
-    // process 8 vectors per loop.
-    const size_t ny8 = ny / 8;
+    if (ny16 > 0) {
+        __m512 min_distances = _mm512_set1_ps(HUGE_VALF);
+        __m512i min_indices = _mm512_set1_epi32(0);
 
-    if (ny8 > 0) {
-        // m[i] = (2 * x[i], ... 2 * x[i])
-        __m256 m[DIM];
-        for (size_t j = 0; j < DIM; j++) {
-            m[j] = _mm256_set1_ps(x[j]);
-            m[j] = _mm256_add_ps(m[j], m[j]);
-        }
+        __m512i current_indices = _mm512_setr_epi32(
+                0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15);
+        const __m512i indices_increment = _mm512_set1_epi32(16);
 
-        __m256 x_sqlen_ymm = _mm256_set1_ps(x_sqlen);
+        const __m512 m0 = _mm512_set1_ps(x[0]);
+        const __m512 m1 = _mm512_set1_ps(x[1]);
+        const __m512 m2 = _mm512_set1_ps(x[2]);
+        const __m512 m3 = _mm512_set1_ps(x[3]);
 
-        for (; i < ny8 * 8; i += 8) {
-            // collect dim 0 for 8 D4-vectors.
-            const __m256 v0 = _mm256_loadu_ps(y + 0 * d_offset);
+        for (; i < ny16 * 16; i += 16) {
+            __m512 v0;
+            __m512 v1;
+            __m512 v2;
+            __m512 v3;
 
-            // compute dot products
-            // this is x^2 - 2x[0]*y[0]
-            __m256 dp = _mm256_fnmadd_ps(m[0], v0, x_sqlen_ymm);
+            transpose_16x4(
+                    _mm512_loadu_ps(y + 0 * 16),
+                    _mm512_loadu_ps(y + 1 * 16),
+                    _mm512_loadu_ps(y + 2 * 16),
+                    _mm512_loadu_ps(y + 3 * 16),
+                    v0,
+                    v1,
+                    v2,
+                    v3);
 
-            for (size_t j = 1; j < DIM; j++) {
-                // collect dim j for 8 D4-vectors.
-                const __m256 vj = _mm256_loadu_ps(y + j * d_offset);
-                dp = _mm256_fnmadd_ps(m[j], vj, dp);
-            }
+            const __m512 d0 = _mm512_sub_ps(m0, v0);
+            const __m512 d1 = _mm512_sub_ps(m1, v1);
+            const __m512 d2 = _mm512_sub_ps(m2, v2);
+            const __m512 d3 = _mm512_sub_ps(m3, v3);
 
-            // we've got x^2 - (2x, y) at this point
+            __m512 distances = _mm512_mul_ps(d0, d0);
+            distances = _mm512_fmadd_ps(d1, d1, distances);
+            distances = _mm512_fmadd_ps(d2, d2, distances);
+            distances = _mm512_fmadd_ps(d3, d3, distances);
 
-            // y^2 - (2x, y) + x^2
-            __m256 distances_v = _mm256_add_ps(_mm256_loadu_ps(y_sqlen), dp);
+            __mmask16 comparison =
+                    _mm512_cmp_ps_mask(distances, min_distances, _CMP_LT_OS);
 
-            _mm256_storeu_ps(distances + i, distances_v);
+            min_distances = _mm512_min_ps(distances, min_distances);
+            min_indices = _mm512_mask_blend_epi32(
+                    comparison, min_indices, current_indices);
 
-            // scroll y and y_sqlen forward.
-            y += 8;
-            y_sqlen += 8;
+            current_indices =
+                    _mm512_add_epi32(current_indices, indices_increment);
+
+            y += 64;
+        }
+
+        alignas(64) float min_distances_scalar[16];
+        alignas(64) uint32_t min_indices_scalar[16];
+        _mm512_store_ps(min_distances_scalar, min_distances);
+        _mm512_store_epi32(min_indices_scalar, min_indices);
+
+        for (size_t j = 0; j < 16; j++) {
+            if (current_min_distance > min_distances_scalar[j]) {
+                current_min_distance = min_distances_scalar[j];
+                current_min_index = min_indices_scalar[j];
+            }
         }
     }
 
     if (i < ny) {
630
- // process leftovers
631
- for (; i < ny; i++) {
632
- float dp = 0;
633
- for (size_t j = 0; j < DIM; j++) {
634
- dp += x[j] * y[j * d_offset];
635
- }
1767
+ __m128 x0 = _mm_loadu_ps(x);
636
1768
 
637
- // compute y^2 - 2 * (x, y), which is sufficient for looking for the
638
- // lowest distance.
639
- const float distance = y_sqlen[0] - 2 * dp + x_sqlen;
640
- distances[i] = distance;
1769
+ for (; i < ny; i++) {
1770
+ __m128 accu = ElementOpL2::op(x0, _mm_loadu_ps(y));
1771
+ y += 4;
1772
+ const float distance = horizontal_sum(accu);
641
1773
 
642
- y += 1;
643
- y_sqlen += 1;
1774
+ if (current_min_distance > distance) {
1775
+ current_min_distance = distance;
1776
+ current_min_index = i;
1777
+ }
644
1778
  }
645
1779
  }
1780
+
1781
+ return current_min_index;
646
1782
  }
647
- #endif
648
1783
 
649
- void fvec_L2sqr_ny_transposed(
650
- float* dis,
1784
+ size_t fvec_L2sqr_ny_nearest_D8(
1785
+ float* distances_tmp_buffer,
651
1786
  const float* x,
652
1787
  const float* y,
653
- const float* y_sqlen,
654
- size_t d,
655
- size_t d_offset,
656
1788
  size_t ny) {
657
- // optimized for a few special cases
1789
+ // this implementation does not use distances_tmp_buffer.
658
1790
 
659
- #ifdef __AVX2__
660
- #define DISPATCH(dval) \
661
- case dval: \
662
- return fvec_L2sqr_ny_y_transposed_D<dval>( \
663
- dis, x, y, y_sqlen, d_offset, ny);
1791
+ size_t i = 0;
1792
+ float current_min_distance = HUGE_VALF;
1793
+ size_t current_min_index = 0;
664
1794
 
665
- switch (d) {
666
- DISPATCH(1)
667
- DISPATCH(2)
668
- DISPATCH(4)
669
- DISPATCH(8)
670
- default:
671
- return fvec_L2sqr_ny_y_transposed_ref(
672
- dis, x, y, y_sqlen, d, d_offset, ny);
1795
+ const size_t ny16 = ny / 16;
1796
+ if (ny16 > 0) {
1797
+ __m512 min_distances = _mm512_set1_ps(HUGE_VALF);
1798
+ __m512i min_indices = _mm512_set1_epi32(0);
1799
+
1800
+ __m512i current_indices = _mm512_setr_epi32(
1801
+ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15);
1802
+ const __m512i indices_increment = _mm512_set1_epi32(16);
1803
+
1804
+ const __m512 m0 = _mm512_set1_ps(x[0]);
1805
+ const __m512 m1 = _mm512_set1_ps(x[1]);
1806
+ const __m512 m2 = _mm512_set1_ps(x[2]);
1807
+ const __m512 m3 = _mm512_set1_ps(x[3]);
1808
+
1809
+ const __m512 m4 = _mm512_set1_ps(x[4]);
1810
+ const __m512 m5 = _mm512_set1_ps(x[5]);
1811
+ const __m512 m6 = _mm512_set1_ps(x[6]);
1812
+ const __m512 m7 = _mm512_set1_ps(x[7]);
1813
+
1814
+ for (; i < ny16 * 16; i += 16) {
1815
+ __m512 v0;
1816
+ __m512 v1;
1817
+ __m512 v2;
1818
+ __m512 v3;
1819
+ __m512 v4;
1820
+ __m512 v5;
1821
+ __m512 v6;
1822
+ __m512 v7;
1823
+
1824
+ transpose_16x8(
1825
+ _mm512_loadu_ps(y + 0 * 16),
1826
+ _mm512_loadu_ps(y + 1 * 16),
1827
+ _mm512_loadu_ps(y + 2 * 16),
1828
+ _mm512_loadu_ps(y + 3 * 16),
1829
+ _mm512_loadu_ps(y + 4 * 16),
1830
+ _mm512_loadu_ps(y + 5 * 16),
1831
+ _mm512_loadu_ps(y + 6 * 16),
1832
+ _mm512_loadu_ps(y + 7 * 16),
1833
+ v0,
1834
+ v1,
1835
+ v2,
1836
+ v3,
1837
+ v4,
1838
+ v5,
1839
+ v6,
1840
+ v7);
1841
+
1842
+ const __m512 d0 = _mm512_sub_ps(m0, v0);
1843
+ const __m512 d1 = _mm512_sub_ps(m1, v1);
1844
+ const __m512 d2 = _mm512_sub_ps(m2, v2);
1845
+ const __m512 d3 = _mm512_sub_ps(m3, v3);
1846
+ const __m512 d4 = _mm512_sub_ps(m4, v4);
1847
+ const __m512 d5 = _mm512_sub_ps(m5, v5);
1848
+ const __m512 d6 = _mm512_sub_ps(m6, v6);
1849
+ const __m512 d7 = _mm512_sub_ps(m7, v7);
1850
+
1851
+ __m512 distances = _mm512_mul_ps(d0, d0);
1852
+ distances = _mm512_fmadd_ps(d1, d1, distances);
1853
+ distances = _mm512_fmadd_ps(d2, d2, distances);
1854
+ distances = _mm512_fmadd_ps(d3, d3, distances);
1855
+ distances = _mm512_fmadd_ps(d4, d4, distances);
1856
+ distances = _mm512_fmadd_ps(d5, d5, distances);
1857
+ distances = _mm512_fmadd_ps(d6, d6, distances);
1858
+ distances = _mm512_fmadd_ps(d7, d7, distances);
1859
+
1860
+ __mmask16 comparison =
1861
+ _mm512_cmp_ps_mask(distances, min_distances, _CMP_LT_OS);
1862
+
1863
+ min_distances = _mm512_min_ps(distances, min_distances);
1864
+ min_indices = _mm512_mask_blend_epi32(
1865
+ comparison, min_indices, current_indices);
1866
+
1867
+ current_indices =
1868
+ _mm512_add_epi32(current_indices, indices_increment);
1869
+
1870
+ y += 128;
1871
+ }
1872
+
1873
+ alignas(64) float min_distances_scalar[16];
1874
+ alignas(64) uint32_t min_indices_scalar[16];
1875
+ _mm512_store_ps(min_distances_scalar, min_distances);
1876
+ _mm512_store_epi32(min_indices_scalar, min_indices);
1877
+
1878
+ for (size_t j = 0; j < 16; j++) {
1879
+ if (current_min_distance > min_distances_scalar[j]) {
1880
+ current_min_distance = min_distances_scalar[j];
1881
+ current_min_index = min_indices_scalar[j];
1882
+ }
1883
+ }
673
1884
  }
674
- #undef DISPATCH
675
- #else
676
- // non-AVX2 case
677
- return fvec_L2sqr_ny_y_transposed_ref(dis, x, y, y_sqlen, d, d_offset, ny);
678
- #endif
1885
+
1886
+ if (i < ny) {
1887
+ __m256 x0 = _mm256_loadu_ps(x);
1888
+
1889
+ for (; i < ny; i++) {
1890
+ __m256 accu = ElementOpL2::op(x0, _mm256_loadu_ps(y));
1891
+ y += 8;
1892
+ const float distance = horizontal_sum(accu);
1893
+
1894
+ if (current_min_distance > distance) {
1895
+ current_min_distance = distance;
1896
+ current_min_index = i;
1897
+ }
1898
+ }
1899
+ }
1900
+
1901
+ return current_min_index;
679
1902
  }
680
1903
 
681
- #ifdef __AVX2__
1904
+ #elif defined(__AVX2__)
682
1905
 
683
1906
  size_t fvec_L2sqr_ny_nearest_D2(
684
1907
  float* distances_tmp_buffer,
@@ -697,8 +1920,8 @@ size_t fvec_L2sqr_ny_nearest_D2(
      // process 8 D2-vectors per loop.
      const size_t ny8 = ny / 8;
      if (ny8 > 0) {
-         _mm_prefetch(y, _MM_HINT_T0);
-         _mm_prefetch(y + 16, _MM_HINT_T0);
+         _mm_prefetch((const char*)y, _MM_HINT_T0);
+         _mm_prefetch((const char*)(y + 16), _MM_HINT_T0);
 
          // track min distance and the closest vector independently
          // for each of 8 AVX2 components.
@@ -713,7 +1936,7 @@ size_t fvec_L2sqr_ny_nearest_D2(
          const __m256 m1 = _mm256_set1_ps(x[1]);
 
          for (; i < ny8 * 8; i += 8) {
-             _mm_prefetch(y + 32, _MM_HINT_T0);
+             _mm_prefetch((const char*)(y + 32), _MM_HINT_T0);
 
              __m256 v0;
              __m256 v1;
@@ -892,10 +2115,7 @@ size_t fvec_L2sqr_ny_nearest_D4(
          for (; i < ny; i++) {
              __m128 accu = ElementOpL2::op(x0, _mm_loadu_ps(y));
              y += 4;
-             accu = _mm_hadd_ps(accu, accu);
-             accu = _mm_hadd_ps(accu, accu);
-
-             const auto distance = _mm_cvtss_f32(accu);
+             const float distance = horizontal_sum(accu);
 
              if (current_min_distance > distance) {
                  current_min_distance = distance;
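Note: the two-step _mm_hadd_ps sequence deleted here is now factored into a horizontal_sum helper defined elsewhere in this file. A plausible SSE3 rendering of such a helper, matching the removed code (sketch only; the library's own overloads may differ in detail):

    #include <immintrin.h>

    // Sum the four float lanes of an __m128 into one scalar (requires SSE3).
    static inline float horizontal_sum_sketch(__m128 v) {
        __m128 h = _mm_hadd_ps(v, v); // (a+b, c+d, a+b, c+d)
        h = _mm_hadd_ps(h, h);        // (a+b+c+d, ...) in every lane
        return _mm_cvtss_f32(h);      // extract lane 0
    }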
@@ -1031,23 +2251,9 @@ size_t fvec_L2sqr_ny_nearest_D8(
          __m256 x0 = _mm256_loadu_ps(x);
 
          for (; i < ny; i++) {
-             __m256 sub = _mm256_sub_ps(x0, _mm256_loadu_ps(y));
-             __m256 accu = _mm256_mul_ps(sub, sub);
+             __m256 accu = ElementOpL2::op(x0, _mm256_loadu_ps(y));
              y += 8;
-
-             // horitontal sum
-             const __m256 h0 = _mm256_hadd_ps(accu, accu);
-             const __m256 h1 = _mm256_hadd_ps(h0, h0);
-
-             // extract high and low __m128 regs from __m256
-             const __m128 h2 = _mm256_extractf128_ps(h1, 1);
-             const __m128 h3 = _mm256_castps256_ps128(h1);
-
-             // get a final hsum into all 4 regs
-             const __m128 h4 = _mm_add_ss(h2, h3);
-
-             // extract f[0] from __m128
-             const float distance = _mm_cvtss_f32(h4);
+             const float distance = horizontal_sum(accu);
 
              if (current_min_distance > distance) {
                  current_min_distance = distance;
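Note: ElementOpL2::op now replaces the hand-written subtract-and-square in these leftover loops. From its use here, the operator computes the element-wise squared difference; a sketch of what such an op plausibly looks like for the __m256 overload (the actual definition lives elsewhere in this file and may differ):

    #include <immintrin.h>

    // Element-wise squared difference, the building block of L2 distance.
    struct ElementOpL2Sketch {
        static __m256 op(__m256 x, __m256 y) {
            const __m256 d = _mm256_sub_ps(x, y);
            return _mm256_mul_ps(d, d); // (x - y)^2 per lane
        }
    };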
@@ -1106,7 +2312,123 @@ size_t fvec_L2sqr_ny_nearest(
  #undef DISPATCH
  }
 
- #ifdef __AVX2__
+ #if defined(__AVX512F__)
+
+ template <size_t DIM>
+ size_t fvec_L2sqr_ny_nearest_y_transposed_D(
+         float* distances_tmp_buffer,
+         const float* x,
+         const float* y,
+         const float* y_sqlen,
+         const size_t d_offset,
+         size_t ny) {
+     // This implementation does not use distances_tmp_buffer.
+
+     // Current index being processed
+     size_t i = 0;
+
+     // Min distance and the index of the closest vector so far
+     float current_min_distance = HUGE_VALF;
+     size_t current_min_index = 0;
+
+     // Process 16 vectors per loop
+     const size_t ny16 = ny / 16;
+
+     if (ny16 > 0) {
+         // Track min distance and the closest vector independently
+         // for each of 16 AVX-512 components.
+         __m512 min_distances = _mm512_set1_ps(HUGE_VALF);
+         __m512i min_indices = _mm512_set1_epi32(0);
+
+         __m512i current_indices = _mm512_setr_epi32(
+                 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15);
+         const __m512i indices_increment = _mm512_set1_epi32(16);
+
+         // m[i] = (2 * x[i], ... 2 * x[i])
+         __m512 m[DIM];
+         for (size_t j = 0; j < DIM; j++) {
+             m[j] = _mm512_set1_ps(x[j]);
+             m[j] = _mm512_add_ps(m[j], m[j]);
+         }
+
+         for (; i < ny16 * 16; i += 16) {
+             // Compute dot products
+             const __m512 v0 = _mm512_loadu_ps(y + 0 * d_offset);
+             __m512 dp = _mm512_mul_ps(m[0], v0);
+             for (size_t j = 1; j < DIM; j++) {
+                 const __m512 vj = _mm512_loadu_ps(y + j * d_offset);
+                 dp = _mm512_fmadd_ps(m[j], vj, dp);
+             }
+
+             // Compute y^2 - (2 * x, y), which is sufficient for looking
+             // for the lowest distance.
+             // x^2 is the constant that can be avoided.
+             const __m512 distances =
+                     _mm512_sub_ps(_mm512_loadu_ps(y_sqlen), dp);
+
+             // Compare the new distances to the min distances
+             __mmask16 comparison =
+                     _mm512_cmp_ps_mask(min_distances, distances, _CMP_LT_OS);
+
+             // Update min distances and indices with closest vectors if needed
+             min_distances =
+                     _mm512_mask_blend_ps(comparison, distances, min_distances);
+             min_indices = _mm512_castps_si512(_mm512_mask_blend_ps(
+                     comparison,
+                     _mm512_castsi512_ps(current_indices),
+                     _mm512_castsi512_ps(min_indices)));
+
+             // Update current indices values. Basically, +16 to each of the 16
+             // AVX-512 components.
+             current_indices =
+                     _mm512_add_epi32(current_indices, indices_increment);
+
+             // Scroll y and y_sqlen forward.
+             y += 16;
+             y_sqlen += 16;
+         }
+
+         // Dump values and find the minimum distance / minimum index
+         float min_distances_scalar[16];
+         uint32_t min_indices_scalar[16];
+         _mm512_storeu_ps(min_distances_scalar, min_distances);
+         _mm512_storeu_si512((__m512i*)(min_indices_scalar), min_indices);
+
+         for (size_t j = 0; j < 16; j++) {
+             if (current_min_distance > min_distances_scalar[j]) {
+                 current_min_distance = min_distances_scalar[j];
+                 current_min_index = min_indices_scalar[j];
+             }
+         }
+     }
+
+     if (i < ny) {
+         // Process leftovers
+         for (; i < ny; i++) {
+             float dp = 0;
+             for (size_t j = 0; j < DIM; j++) {
+                 dp += x[j] * y[j * d_offset];
+             }
+
+             // Compute y^2 - 2 * (x, y), which is sufficient for looking
+             // for the lowest distance.
+             const float distance = y_sqlen[0] - 2 * dp;
+
+             if (current_min_distance > distance) {
+                 current_min_distance = distance;
+                 current_min_index = i;
+             }
+
+             y += 1;
+             y_sqlen += 1;
+         }
+     }
+
+     return current_min_index;
+ }
+
+ #elif defined(__AVX2__)
+
  template <size_t DIM>
  size_t fvec_L2sqr_ny_nearest_y_transposed_D(
          float* distances_tmp_buffer,
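Note: the transposed kernels rank candidates by y^2 - 2 * (x, y) rather than the full squared distance. Since ||x - y||^2 = ||x||^2 - 2 * (x, y) + ||y||^2 and the ||x||^2 term is identical for every candidate, dropping it cannot change which index wins. A scalar illustration of the same ranking with a plain row-major layout (the real kernels read y with a d_offset stride; names here are hypothetical):

    #include <cmath>
    #include <cstddef>

    // Nearest neighbor by ||y_i||^2 - 2 * <x, y_i>; the constant ||x||^2
    // term is omitted because it does not affect the argmin.
    size_t nearest_with_precomputed_norms(
            const float* x,
            const float* y,       // ny vectors of dimension d, row-major
            const float* y_sqlen, // precomputed ||y_i||^2 per vector
            size_t d,
            size_t ny) {
        float best = HUGE_VALF;
        size_t best_i = 0;
        for (size_t i = 0; i < ny; i++) {
            float dp = 0;
            for (size_t j = 0; j < d; j++) {
                dp += x[j] * y[i * d + j];
            }
            const float score = y_sqlen[i] - 2 * dp;
            if (score < best) {
                best = score;
                best_i = i;
            }
        }
        return best_i;
    }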
@@ -1222,6 +2544,7 @@ size_t fvec_L2sqr_ny_nearest_y_transposed_D(
 
      return current_min_index;
  }
+
  #endif
 
  size_t fvec_L2sqr_ny_nearest_y_transposed(
@@ -1260,21 +2583,6 @@ size_t fvec_L2sqr_ny_nearest_y_transposed(
 
  #ifdef USE_AVX
 
- // reads 0 <= d < 8 floats as __m256
- static inline __m256 masked_read_8(int d, const float* x) {
-     assert(0 <= d && d < 8);
-     if (d < 4) {
-         __m256 res = _mm256_setzero_ps();
-         res = _mm256_insertf128_ps(res, masked_read(d, x), 0);
-         return res;
-     } else {
-         __m256 res = _mm256_setzero_ps();
-         res = _mm256_insertf128_ps(res, _mm_loadu_ps(x), 0);
-         res = _mm256_insertf128_ps(res, masked_read(d - 4, x + 4), 1);
-         return res;
-     }
- }
-
  float fvec_L1(const float* x, const float* y, size_t d) {
      __m256 msum1 = _mm256_setzero_ps();
      __m256 signmask = _mm256_castsi256_ps(_mm256_set1_epi32(0x7fffffffUL));
@@ -1493,7 +2801,7 @@ void fvec_inner_products_ny(
   * heavily optimized table computations
   ***************************************************************************/
 
- static inline void fvec_madd_ref(
+ [[maybe_unused]] static inline void fvec_madd_ref(
          size_t n,
          const float* a,
          float bf,
@@ -1503,7 +2811,39 @@ static inline void fvec_madd_ref(
          c[i] = a[i] + bf * b[i];
  }
 
- #ifdef __AVX2__
+ #if defined(__AVX512F__)
+
+ static inline void fvec_madd_avx512(
+         const size_t n,
+         const float* __restrict a,
+         const float bf,
+         const float* __restrict b,
+         float* __restrict c) {
+     const size_t n16 = n / 16;
+     const size_t n_for_masking = n % 16;
+
+     const __m512 bfmm = _mm512_set1_ps(bf);
+
+     size_t idx = 0;
+     for (idx = 0; idx < n16 * 16; idx += 16) {
+         const __m512 ax = _mm512_loadu_ps(a + idx);
+         const __m512 bx = _mm512_loadu_ps(b + idx);
+         const __m512 abmul = _mm512_fmadd_ps(bfmm, bx, ax);
+         _mm512_storeu_ps(c + idx, abmul);
+     }
+
+     if (n_for_masking > 0) {
+         const __mmask16 mask = (1 << n_for_masking) - 1;
+
+         const __m512 ax = _mm512_maskz_loadu_ps(mask, a + idx);
+         const __m512 bx = _mm512_maskz_loadu_ps(mask, b + idx);
+         const __m512 abmul = _mm512_fmadd_ps(bfmm, bx, ax);
+         _mm512_mask_storeu_ps(c + idx, mask, abmul);
+     }
+ }
+
+ #elif defined(__AVX2__)
+
  static inline void fvec_madd_avx2(
          const size_t n,
          const float* __restrict a,
@@ -1556,11 +2896,12 @@ static inline void fvec_madd_avx2(
          _mm256_maskstore_ps(c + idx, mask, abmul);
      }
  }
+
  #endif
 
  #ifdef __SSE3__
 
- static inline void fvec_madd_sse(
+ [[maybe_unused]] static inline void fvec_madd_sse(
          size_t n,
          const float* a,
          float bf,
@@ -1581,7 +2922,9 @@ static inline void fvec_madd_sse(
  }
 
  void fvec_madd(size_t n, const float* a, float bf, const float* b, float* c) {
- #ifdef __AVX2__
+ #ifdef __AVX512F__
+     fvec_madd_avx512(n, a, bf, b, c);
+ #elif __AVX2__
      fvec_madd_avx2(n, a, bf, b, c);
  #else
      if ((n & 3) == 0 && ((((long)a) | ((long)b) | ((long)c)) & 15) == 0)
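Note: fvec_madd keeps its usual contract, c[i] = a[i] + bf * b[i], but on AVX-512 builds the new fvec_madd_avx512 covers the n % 16 tail with a lane mask ((1 << (n % 16)) - 1; e.g. n = 37 leaves a 5-element tail and the mask 0b11111) instead of a scalar epilogue. A small usage sketch, assuming the usual declaration from faiss/utils/distances.h:

    #include <cstddef>
    #include <cstdio>
    #include <vector>

    namespace faiss {
    // Declaration mirrored from the FAISS distances header.
    void fvec_madd(size_t n, const float* a, float bf, const float* b, float* c);
    } // namespace faiss

    int main() {
        std::vector<float> a(37, 1.0f), b(37, 2.0f), c(37);
        // c[i] = 1.0f + 0.5f * 2.0f = 2.0f for all 37 elements; 32 go through
        // the full-width loop, the last 5 through the masked tail.
        faiss::fvec_madd(a.size(), a.data(), 0.5f, b.data(), c.data());
        std::printf("%f\n", c[36]); // expect 2.000000
        return 0;
    }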
@@ -1807,10 +3150,13 @@ void pq2_8cents_table(
      switch (nout) {
          case 4:
              ip3.storeu(out + 3 * ldo);
+             [[fallthrough]];
          case 3:
              ip2.storeu(out + 2 * ldo);
+             [[fallthrough]];
          case 2:
              ip1.storeu(out + 1 * ldo);
+             [[fallthrough]];
          case 1:
              ip0.storeu(out);
      }
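Note: the [[fallthrough]] attributes added here document that the case cascade in pq2_8cents_table is intentional: case 4 stores output row 3 and then deliberately continues into case 3, and so on, so exactly nout rows are written. The same pattern in miniature (illustrative names):

    #include <cstddef>
    #include <cstring>

    // Write the first nout (1..4) rows of src into dst; each case
    // intentionally falls through to the cases below it.
    void store_rows(float* dst, const float* src, int nout, size_t row_len) {
        switch (nout) {
            case 4:
                std::memcpy(dst + 3 * row_len, src + 3 * row_len,
                            row_len * sizeof(float));
                [[fallthrough]];
            case 3:
                std::memcpy(dst + 2 * row_len, src + 2 * row_len,
                            row_len * sizeof(float));
                [[fallthrough]];
            case 2:
                std::memcpy(dst + 1 * row_len, src + 1 * row_len,
                            row_len * sizeof(float));
                [[fallthrough]];
            case 1:
                std::memcpy(dst, src, row_len * sizeof(float));
        }
    }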