numkong 7.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (294)
  1. package/LICENSE +201 -0
  2. package/README.md +495 -0
  3. package/binding.gyp +540 -0
  4. package/c/dispatch.h +512 -0
  5. package/c/dispatch_bf16.c +389 -0
  6. package/c/dispatch_bf16c.c +52 -0
  7. package/c/dispatch_e2m3.c +263 -0
  8. package/c/dispatch_e3m2.c +243 -0
  9. package/c/dispatch_e4m3.c +276 -0
  10. package/c/dispatch_e5m2.c +272 -0
  11. package/c/dispatch_f16.c +376 -0
  12. package/c/dispatch_f16c.c +58 -0
  13. package/c/dispatch_f32.c +378 -0
  14. package/c/dispatch_f32c.c +99 -0
  15. package/c/dispatch_f64.c +296 -0
  16. package/c/dispatch_f64c.c +98 -0
  17. package/c/dispatch_i16.c +96 -0
  18. package/c/dispatch_i32.c +89 -0
  19. package/c/dispatch_i4.c +150 -0
  20. package/c/dispatch_i64.c +86 -0
  21. package/c/dispatch_i8.c +289 -0
  22. package/c/dispatch_other.c +330 -0
  23. package/c/dispatch_u1.c +148 -0
  24. package/c/dispatch_u16.c +124 -0
  25. package/c/dispatch_u32.c +118 -0
  26. package/c/dispatch_u4.c +150 -0
  27. package/c/dispatch_u64.c +102 -0
  28. package/c/dispatch_u8.c +303 -0
  29. package/c/numkong.c +950 -0
  30. package/include/README.md +573 -0
  31. package/include/module.modulemap +129 -0
  32. package/include/numkong/attention/sapphireamx.h +1361 -0
  33. package/include/numkong/attention/sme.h +2066 -0
  34. package/include/numkong/attention.h +49 -0
  35. package/include/numkong/capabilities.h +748 -0
  36. package/include/numkong/cast/README.md +262 -0
  37. package/include/numkong/cast/haswell.h +975 -0
  38. package/include/numkong/cast/icelake.h +470 -0
  39. package/include/numkong/cast/neon.h +1192 -0
  40. package/include/numkong/cast/rvv.h +1021 -0
  41. package/include/numkong/cast/sapphire.h +262 -0
  42. package/include/numkong/cast/serial.h +2262 -0
  43. package/include/numkong/cast/skylake.h +856 -0
  44. package/include/numkong/cast/v128relaxed.h +180 -0
  45. package/include/numkong/cast.h +230 -0
  46. package/include/numkong/curved/README.md +223 -0
  47. package/include/numkong/curved/genoa.h +182 -0
  48. package/include/numkong/curved/haswell.h +276 -0
  49. package/include/numkong/curved/neon.h +205 -0
  50. package/include/numkong/curved/neonbfdot.h +212 -0
  51. package/include/numkong/curved/neonhalf.h +212 -0
  52. package/include/numkong/curved/rvv.h +305 -0
  53. package/include/numkong/curved/serial.h +207 -0
  54. package/include/numkong/curved/skylake.h +457 -0
  55. package/include/numkong/curved/smef64.h +506 -0
  56. package/include/numkong/curved.h +517 -0
  57. package/include/numkong/curved.hpp +144 -0
  58. package/include/numkong/dot/README.md +425 -0
  59. package/include/numkong/dot/alder.h +563 -0
  60. package/include/numkong/dot/genoa.h +315 -0
  61. package/include/numkong/dot/haswell.h +1688 -0
  62. package/include/numkong/dot/icelake.h +883 -0
  63. package/include/numkong/dot/neon.h +818 -0
  64. package/include/numkong/dot/neonbfdot.h +244 -0
  65. package/include/numkong/dot/neonfhm.h +360 -0
  66. package/include/numkong/dot/neonhalf.h +198 -0
  67. package/include/numkong/dot/neonsdot.h +508 -0
  68. package/include/numkong/dot/rvv.h +714 -0
  69. package/include/numkong/dot/rvvbb.h +72 -0
  70. package/include/numkong/dot/rvvbf16.h +123 -0
  71. package/include/numkong/dot/rvvhalf.h +129 -0
  72. package/include/numkong/dot/sapphire.h +141 -0
  73. package/include/numkong/dot/serial.h +838 -0
  74. package/include/numkong/dot/sierra.h +405 -0
  75. package/include/numkong/dot/skylake.h +1084 -0
  76. package/include/numkong/dot/sve.h +379 -0
  77. package/include/numkong/dot/svebfdot.h +74 -0
  78. package/include/numkong/dot/svehalf.h +123 -0
  79. package/include/numkong/dot/v128relaxed.h +1258 -0
  80. package/include/numkong/dot.h +1070 -0
  81. package/include/numkong/dot.hpp +94 -0
  82. package/include/numkong/dots/README.md +496 -0
  83. package/include/numkong/dots/alder.h +114 -0
  84. package/include/numkong/dots/genoa.h +94 -0
  85. package/include/numkong/dots/haswell.h +295 -0
  86. package/include/numkong/dots/icelake.h +171 -0
  87. package/include/numkong/dots/neon.h +120 -0
  88. package/include/numkong/dots/neonbfdot.h +58 -0
  89. package/include/numkong/dots/neonfhm.h +94 -0
  90. package/include/numkong/dots/neonhalf.h +57 -0
  91. package/include/numkong/dots/neonsdot.h +108 -0
  92. package/include/numkong/dots/rvv.h +2486 -0
  93. package/include/numkong/dots/sapphireamx.h +3973 -0
  94. package/include/numkong/dots/serial.h +2844 -0
  95. package/include/numkong/dots/sierra.h +97 -0
  96. package/include/numkong/dots/skylake.h +196 -0
  97. package/include/numkong/dots/sme.h +5372 -0
  98. package/include/numkong/dots/smebi32.h +461 -0
  99. package/include/numkong/dots/smef64.h +1318 -0
  100. package/include/numkong/dots/smehalf.h +47 -0
  101. package/include/numkong/dots/v128relaxed.h +294 -0
  102. package/include/numkong/dots.h +2804 -0
  103. package/include/numkong/dots.hpp +639 -0
  104. package/include/numkong/each/README.md +469 -0
  105. package/include/numkong/each/haswell.h +1658 -0
  106. package/include/numkong/each/icelake.h +272 -0
  107. package/include/numkong/each/neon.h +1104 -0
  108. package/include/numkong/each/neonbfdot.h +212 -0
  109. package/include/numkong/each/neonhalf.h +410 -0
  110. package/include/numkong/each/rvv.h +1121 -0
  111. package/include/numkong/each/sapphire.h +477 -0
  112. package/include/numkong/each/serial.h +260 -0
  113. package/include/numkong/each/skylake.h +1562 -0
  114. package/include/numkong/each.h +2146 -0
  115. package/include/numkong/each.hpp +434 -0
  116. package/include/numkong/geospatial/README.md +147 -0
  117. package/include/numkong/geospatial/haswell.h +593 -0
  118. package/include/numkong/geospatial/neon.h +571 -0
  119. package/include/numkong/geospatial/rvv.h +701 -0
  120. package/include/numkong/geospatial/serial.h +309 -0
  121. package/include/numkong/geospatial/skylake.h +577 -0
  122. package/include/numkong/geospatial/v128relaxed.h +613 -0
  123. package/include/numkong/geospatial.h +453 -0
  124. package/include/numkong/geospatial.hpp +235 -0
  125. package/include/numkong/matrix.hpp +336 -0
  126. package/include/numkong/maxsim/README.md +187 -0
  127. package/include/numkong/maxsim/alder.h +511 -0
  128. package/include/numkong/maxsim/genoa.h +115 -0
  129. package/include/numkong/maxsim/haswell.h +553 -0
  130. package/include/numkong/maxsim/icelake.h +480 -0
  131. package/include/numkong/maxsim/neonsdot.h +394 -0
  132. package/include/numkong/maxsim/sapphireamx.h +877 -0
  133. package/include/numkong/maxsim/serial.h +490 -0
  134. package/include/numkong/maxsim/sme.h +929 -0
  135. package/include/numkong/maxsim/v128relaxed.h +280 -0
  136. package/include/numkong/maxsim.h +571 -0
  137. package/include/numkong/maxsim.hpp +133 -0
  138. package/include/numkong/mesh/README.md +227 -0
  139. package/include/numkong/mesh/haswell.h +2235 -0
  140. package/include/numkong/mesh/neon.h +1329 -0
  141. package/include/numkong/mesh/neonbfdot.h +842 -0
  142. package/include/numkong/mesh/neonhalf.h +616 -0
  143. package/include/numkong/mesh/rvv.h +916 -0
  144. package/include/numkong/mesh/serial.h +742 -0
  145. package/include/numkong/mesh/skylake.h +1135 -0
  146. package/include/numkong/mesh/v128relaxed.h +1052 -0
  147. package/include/numkong/mesh.h +652 -0
  148. package/include/numkong/mesh.hpp +762 -0
  149. package/include/numkong/numkong.h +78 -0
  150. package/include/numkong/numkong.hpp +57 -0
  151. package/include/numkong/probability/README.md +173 -0
  152. package/include/numkong/probability/haswell.h +267 -0
  153. package/include/numkong/probability/neon.h +225 -0
  154. package/include/numkong/probability/rvv.h +409 -0
  155. package/include/numkong/probability/serial.h +169 -0
  156. package/include/numkong/probability/skylake.h +324 -0
  157. package/include/numkong/probability.h +383 -0
  158. package/include/numkong/probability.hpp +120 -0
  159. package/include/numkong/random.h +50 -0
  160. package/include/numkong/random.hpp +285 -0
  161. package/include/numkong/reduce/README.md +547 -0
  162. package/include/numkong/reduce/alder.h +632 -0
  163. package/include/numkong/reduce/genoa.h +201 -0
  164. package/include/numkong/reduce/haswell.h +3783 -0
  165. package/include/numkong/reduce/icelake.h +549 -0
  166. package/include/numkong/reduce/neon.h +3841 -0
  167. package/include/numkong/reduce/neonbfdot.h +353 -0
  168. package/include/numkong/reduce/neonfhm.h +665 -0
  169. package/include/numkong/reduce/neonhalf.h +157 -0
  170. package/include/numkong/reduce/neonsdot.h +357 -0
  171. package/include/numkong/reduce/rvv.h +3407 -0
  172. package/include/numkong/reduce/serial.h +757 -0
  173. package/include/numkong/reduce/sierra.h +338 -0
  174. package/include/numkong/reduce/skylake.h +3792 -0
  175. package/include/numkong/reduce/v128relaxed.h +2302 -0
  176. package/include/numkong/reduce.h +1597 -0
  177. package/include/numkong/reduce.hpp +633 -0
  178. package/include/numkong/scalar/README.md +89 -0
  179. package/include/numkong/scalar/haswell.h +113 -0
  180. package/include/numkong/scalar/neon.h +122 -0
  181. package/include/numkong/scalar/neonhalf.h +70 -0
  182. package/include/numkong/scalar/rvv.h +211 -0
  183. package/include/numkong/scalar/sapphire.h +63 -0
  184. package/include/numkong/scalar/serial.h +332 -0
  185. package/include/numkong/scalar/v128relaxed.h +56 -0
  186. package/include/numkong/scalar.h +683 -0
  187. package/include/numkong/set/README.md +179 -0
  188. package/include/numkong/set/haswell.h +334 -0
  189. package/include/numkong/set/icelake.h +485 -0
  190. package/include/numkong/set/neon.h +364 -0
  191. package/include/numkong/set/rvv.h +226 -0
  192. package/include/numkong/set/rvvbb.h +117 -0
  193. package/include/numkong/set/serial.h +174 -0
  194. package/include/numkong/set/sve.h +185 -0
  195. package/include/numkong/set/v128relaxed.h +240 -0
  196. package/include/numkong/set.h +457 -0
  197. package/include/numkong/set.hpp +114 -0
  198. package/include/numkong/sets/README.md +149 -0
  199. package/include/numkong/sets/haswell.h +63 -0
  200. package/include/numkong/sets/icelake.h +66 -0
  201. package/include/numkong/sets/neon.h +61 -0
  202. package/include/numkong/sets/serial.h +43 -0
  203. package/include/numkong/sets/smebi32.h +1099 -0
  204. package/include/numkong/sets/v128relaxed.h +58 -0
  205. package/include/numkong/sets.h +339 -0
  206. package/include/numkong/sparse/README.md +156 -0
  207. package/include/numkong/sparse/icelake.h +463 -0
  208. package/include/numkong/sparse/neon.h +288 -0
  209. package/include/numkong/sparse/serial.h +117 -0
  210. package/include/numkong/sparse/sve2.h +507 -0
  211. package/include/numkong/sparse/turin.h +322 -0
  212. package/include/numkong/sparse.h +363 -0
  213. package/include/numkong/sparse.hpp +113 -0
  214. package/include/numkong/spatial/README.md +435 -0
  215. package/include/numkong/spatial/alder.h +607 -0
  216. package/include/numkong/spatial/genoa.h +290 -0
  217. package/include/numkong/spatial/haswell.h +960 -0
  218. package/include/numkong/spatial/icelake.h +586 -0
  219. package/include/numkong/spatial/neon.h +773 -0
  220. package/include/numkong/spatial/neonbfdot.h +165 -0
  221. package/include/numkong/spatial/neonhalf.h +118 -0
  222. package/include/numkong/spatial/neonsdot.h +261 -0
  223. package/include/numkong/spatial/rvv.h +984 -0
  224. package/include/numkong/spatial/rvvbf16.h +123 -0
  225. package/include/numkong/spatial/rvvhalf.h +117 -0
  226. package/include/numkong/spatial/sapphire.h +343 -0
  227. package/include/numkong/spatial/serial.h +346 -0
  228. package/include/numkong/spatial/sierra.h +323 -0
  229. package/include/numkong/spatial/skylake.h +606 -0
  230. package/include/numkong/spatial/sve.h +224 -0
  231. package/include/numkong/spatial/svebfdot.h +122 -0
  232. package/include/numkong/spatial/svehalf.h +109 -0
  233. package/include/numkong/spatial/v128relaxed.h +717 -0
  234. package/include/numkong/spatial.h +1425 -0
  235. package/include/numkong/spatial.hpp +183 -0
  236. package/include/numkong/spatials/README.md +580 -0
  237. package/include/numkong/spatials/alder.h +94 -0
  238. package/include/numkong/spatials/genoa.h +94 -0
  239. package/include/numkong/spatials/haswell.h +219 -0
  240. package/include/numkong/spatials/icelake.h +113 -0
  241. package/include/numkong/spatials/neon.h +109 -0
  242. package/include/numkong/spatials/neonbfdot.h +60 -0
  243. package/include/numkong/spatials/neonfhm.h +92 -0
  244. package/include/numkong/spatials/neonhalf.h +58 -0
  245. package/include/numkong/spatials/neonsdot.h +109 -0
  246. package/include/numkong/spatials/rvv.h +1960 -0
  247. package/include/numkong/spatials/sapphireamx.h +1149 -0
  248. package/include/numkong/spatials/serial.h +226 -0
  249. package/include/numkong/spatials/sierra.h +96 -0
  250. package/include/numkong/spatials/skylake.h +184 -0
  251. package/include/numkong/spatials/sme.h +1901 -0
  252. package/include/numkong/spatials/smef64.h +465 -0
  253. package/include/numkong/spatials/v128relaxed.h +240 -0
  254. package/include/numkong/spatials.h +3021 -0
  255. package/include/numkong/spatials.hpp +508 -0
  256. package/include/numkong/tensor.hpp +1592 -0
  257. package/include/numkong/trigonometry/README.md +184 -0
  258. package/include/numkong/trigonometry/haswell.h +652 -0
  259. package/include/numkong/trigonometry/neon.h +639 -0
  260. package/include/numkong/trigonometry/rvv.h +699 -0
  261. package/include/numkong/trigonometry/serial.h +703 -0
  262. package/include/numkong/trigonometry/skylake.h +721 -0
  263. package/include/numkong/trigonometry/v128relaxed.h +666 -0
  264. package/include/numkong/trigonometry.h +467 -0
  265. package/include/numkong/trigonometry.hpp +166 -0
  266. package/include/numkong/types.h +1384 -0
  267. package/include/numkong/types.hpp +5603 -0
  268. package/include/numkong/vector.hpp +698 -0
  269. package/javascript/README.md +246 -0
  270. package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
  271. package/javascript/dist/cjs/numkong-wasm.js +617 -0
  272. package/javascript/dist/cjs/numkong.d.ts +343 -0
  273. package/javascript/dist/cjs/numkong.js +523 -0
  274. package/javascript/dist/cjs/package.json +3 -0
  275. package/javascript/dist/cjs/types.d.ts +284 -0
  276. package/javascript/dist/cjs/types.js +653 -0
  277. package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
  278. package/javascript/dist/esm/numkong-wasm.js +595 -0
  279. package/javascript/dist/esm/numkong.d.ts +343 -0
  280. package/javascript/dist/esm/numkong.js +452 -0
  281. package/javascript/dist/esm/package.json +3 -0
  282. package/javascript/dist/esm/types.d.ts +284 -0
  283. package/javascript/dist/esm/types.js +630 -0
  284. package/javascript/dist-package-cjs.json +3 -0
  285. package/javascript/dist-package-esm.json +3 -0
  286. package/javascript/node-gyp-build.d.ts +1 -0
  287. package/javascript/numkong-wasm.ts +756 -0
  288. package/javascript/numkong.c +689 -0
  289. package/javascript/numkong.ts +575 -0
  290. package/javascript/tsconfig-base.json +39 -0
  291. package/javascript/tsconfig-cjs.json +8 -0
  292. package/javascript/tsconfig-esm.json +8 -0
  293. package/javascript/types.ts +674 -0
  294. package/package.json +87 -0
@@ -0,0 +1,457 @@
1
+ /**
2
+ * @brief SIMD-accelerated Curved Space Similarity for Skylake.
3
+ * @file include/numkong/curved/skylake.h
4
+ * @author Ash Vardanian
5
+ * @date January 14, 2026
6
+ *
7
+ * @sa include/numkong/curved.h
8
+ *
9
+ * Implements f32/f64 bilinear forms, f32c/f64c complex bilinear forms, and f32/f64 Mahalanobis distance using AVX-512:
10
+ * - f32 inputs accumulate in f64 to avoid catastrophic cancellation
11
+ * - f64 inputs use Dot2 algorithm (Ogita-Rump-Oishi 2005) for error compensation
12
+ */
13
+ #ifndef NK_CURVED_SKYLAKE_H
14
+ #define NK_CURVED_SKYLAKE_H
15
+
16
+ #if NK_TARGET_X86_
17
+ #if NK_TARGET_SKYLAKE
18
+
19
+ #include "numkong/types.h"
20
+ #include "numkong/spatial/haswell.h" // `nk_f64_sqrt_haswell`
21
+
22
+ #if defined(__cplusplus)
23
+ extern "C" {
24
+ #endif
25
+
26
+ #if defined(__clang__)
27
+ #pragma clang attribute push(__attribute__((target("avx2,avx512f,avx512vl,avx512bw,avx512dq,f16c,fma,bmi,bmi2"))), \
28
+ apply_to = function)
29
+ #elif defined(__GNUC__)
30
+ #pragma GCC push_options
31
+ #pragma GCC target("avx2", "avx512f", "avx512vl", "avx512bw", "avx512dq", "f16c", "fma", "bmi", "bmi2")
32
+ #endif
33
+
34
+ NK_PUBLIC void nk_bilinear_f32_skylake(nk_f32_t const *a, nk_f32_t const *b, nk_f32_t const *c, nk_size_t n,
35
+ nk_f64_t *result) {
36
+
37
+ // Default case for arbitrary size `n`
38
+ nk_size_t const tail_length = n % 8;
39
+ nk_size_t const tail_start = n - tail_length;
40
+ __m512d sum_f64x8 = _mm512_setzero_pd();
41
+ __mmask8 const tail_mask = (__mmask8)_bzhi_u32(0xFFFFFFFF, tail_length);
42
+
43
+ for (nk_size_t i = 0; i != n; ++i) {
44
+ __m512d a_f64x8 = _mm512_set1_pd((nk_f64_t)a[i]);
45
+ __m512d cb_j_f64x8 = _mm512_setzero_pd();
46
+ __m256 b_f32x8, c_f32x8;
47
+ nk_size_t j = 0;
48
+
49
+ nk_bilinear_f32_skylake_cycle:
50
+ if (j + 8 <= n) {
51
+ b_f32x8 = _mm256_loadu_ps(b + j);
52
+ c_f32x8 = _mm256_loadu_ps(c + i * n + j);
53
+ }
54
+ else {
55
+ b_f32x8 = _mm256_maskz_loadu_ps(tail_mask, b + tail_start);
56
+ c_f32x8 = _mm256_maskz_loadu_ps(tail_mask, c + i * n + tail_start);
57
+ }
58
+ cb_j_f64x8 = _mm512_fmadd_pd(_mm512_cvtps_pd(b_f32x8), _mm512_cvtps_pd(c_f32x8), cb_j_f64x8);
59
+ j += 8;
60
+ if (j < n) goto nk_bilinear_f32_skylake_cycle;
61
+ sum_f64x8 = _mm512_fmadd_pd(a_f64x8, cb_j_f64x8, sum_f64x8);
62
+ }
63
+
64
+ *result = _mm512_reduce_add_pd(sum_f64x8);
65
+ }
66
+
67
+ NK_PUBLIC void nk_mahalanobis_f32_skylake(nk_f32_t const *a, nk_f32_t const *b, nk_f32_t const *c, nk_size_t n,
68
+ nk_f64_t *result) {
69
+ // We use f64 accumulators to prevent catastrophic cancellation.
70
+ nk_size_t const tail_length = n % 8;
71
+ nk_size_t const tail_start = n - tail_length;
72
+ __m512d sum_f64x8 = _mm512_setzero_pd();
73
+ __mmask8 const tail_mask = (__mmask8)_bzhi_u32(0xFFFFFFFF, tail_length);
74
+
75
+ for (nk_size_t i = 0; i != n; ++i) {
76
+ __m512d diff_i_f64x8 = _mm512_set1_pd((nk_f64_t)a[i] - (nk_f64_t)b[i]);
77
+ __m512d cdiff_j_f64x8 = _mm512_setzero_pd();
78
+ __m256 a_j_f32x8, b_j_f32x8, c_f32x8;
79
+ nk_size_t j = 0;
80
+
81
+ // The nested loop is cleaner to implement with a `goto` in this case:
82
+ nk_mahalanobis_f32_skylake_cycle:
83
+ if (j + 8 <= n) {
84
+ a_j_f32x8 = _mm256_loadu_ps(a + j);
85
+ b_j_f32x8 = _mm256_loadu_ps(b + j);
86
+ c_f32x8 = _mm256_loadu_ps(c + i * n + j);
87
+ }
88
+ else {
89
+ a_j_f32x8 = _mm256_maskz_loadu_ps(tail_mask, a + tail_start);
90
+ b_j_f32x8 = _mm256_maskz_loadu_ps(tail_mask, b + tail_start);
91
+ c_f32x8 = _mm256_maskz_loadu_ps(tail_mask, c + i * n + tail_start);
92
+ }
93
+ __m512d diff_j_f64x8 = _mm512_sub_pd(_mm512_cvtps_pd(a_j_f32x8), _mm512_cvtps_pd(b_j_f32x8));
94
+ cdiff_j_f64x8 = _mm512_fmadd_pd(diff_j_f64x8, _mm512_cvtps_pd(c_f32x8), cdiff_j_f64x8);
95
+ j += 8;
96
+ if (j < n) goto nk_mahalanobis_f32_skylake_cycle;
97
+ sum_f64x8 = _mm512_fmadd_pd(diff_i_f64x8, cdiff_j_f64x8, sum_f64x8);
98
+ }
99
+
100
+ nk_f64_t quadratic = _mm512_reduce_add_pd(sum_f64x8);
101
+ *result = nk_f64_sqrt_haswell(quadratic > 0 ? quadratic : 0);
102
+ }
103
+
104
+ NK_PUBLIC void nk_bilinear_f32c_skylake(nk_f32c_t const *a, nk_f32c_t const *b, nk_f32c_t const *c, nk_size_t n,
105
+ nk_f64c_t *results) {
106
+
107
+ // We take into account, that FMS is the same as FMA with a negative multiplier.
108
+ // To multiply a floating-point value by -1, we can use the `XOR` instruction to flip the sign bit.
109
+ // This way we can avoid the shuffling and the need for separate real and imaginary parts.
110
+ // For the imaginary part of the product, we would need to swap the real and imaginary parts of
111
+ // one of the vectors. We use f64 accumulators to prevent catastrophic cancellation.
112
+ __m512i const sign_flip_i64x8 = _mm512_set_epi64( //
113
+ 0x8000000000000000, 0x0000000000000000, 0x8000000000000000, 0x0000000000000000, //
114
+ 0x8000000000000000, 0x0000000000000000, 0x8000000000000000, 0x0000000000000000 //
115
+ );
116
+
117
+ // Default case for arbitrary size `n`
118
+ nk_size_t const tail_length = n % 4;
119
+ nk_size_t const tail_start = n - tail_length;
120
+ __mmask8 const tail_mask = (__mmask8)_bzhi_u32(0xFFFFFFFF, tail_length * 2);
121
+ nk_f64_t sum_real = 0;
122
+ nk_f64_t sum_imag = 0;
123
+
124
+ for (nk_size_t i = 0; i != n; ++i) {
125
+ nk_f64_t const a_i_real = (nk_f64_t)a[i].real;
126
+ nk_f64_t const a_i_imag = (nk_f64_t)a[i].imag;
127
+ __m512d cb_j_real_f64x8 = _mm512_setzero_pd();
128
+ __m512d cb_j_imag_f64x8 = _mm512_setzero_pd();
129
+ __m256 b_f32x8, c_f32x8;
130
+ nk_size_t j = 0;
131
+
132
+ nk_bilinear_f32c_skylake_cycle:
133
+ if (j + 4 <= n) {
134
+ b_f32x8 = _mm256_loadu_ps((nk_f32_t const *)(b + j));
135
+ c_f32x8 = _mm256_loadu_ps((nk_f32_t const *)(c + i * n + j));
136
+ }
137
+ else {
138
+ b_f32x8 = _mm256_maskz_loadu_ps(tail_mask, (nk_f32_t const *)(b + tail_start));
139
+ c_f32x8 = _mm256_maskz_loadu_ps(tail_mask, (nk_f32_t const *)(c + i * n + tail_start));
140
+ }
141
+ __m512d b_f64x8 = _mm512_cvtps_pd(b_f32x8);
142
+ __m512d c_f64x8 = _mm512_cvtps_pd(c_f32x8);
143
+ // The real part of the product: b.real * c.real - b.imag * c.imag.
144
+ // The subtraction will be performed later with a sign flip.
145
+ cb_j_real_f64x8 = _mm512_fmadd_pd(c_f64x8, b_f64x8, cb_j_real_f64x8);
146
+ // The imaginary part of the product: b.real * c.imag + b.imag * c.real.
147
+ // Swap the imaginary and real parts of `c` before multiplication:
148
+ c_f64x8 = _mm512_permute_pd(c_f64x8, 0x55); //? Same as 0b01010101. Swap adjacent entries within each pair
149
+ cb_j_imag_f64x8 = _mm512_fmadd_pd(c_f64x8, b_f64x8, cb_j_imag_f64x8);
150
+ j += 4;
151
+ if (j < n) goto nk_bilinear_f32c_skylake_cycle;
152
+ // Flip the sign bit in every second scalar before accumulation:
153
+ cb_j_real_f64x8 = _mm512_castsi512_pd(_mm512_xor_si512(_mm512_castpd_si512(cb_j_real_f64x8), sign_flip_i64x8));
154
+ // Horizontal sums are the expensive part of the computation:
155
+ nk_f64_t const cb_j_real = _mm512_reduce_add_pd(cb_j_real_f64x8);
156
+ nk_f64_t const cb_j_imag = _mm512_reduce_add_pd(cb_j_imag_f64x8);
157
+ sum_real += a_i_real * cb_j_real - a_i_imag * cb_j_imag;
158
+ sum_imag += a_i_real * cb_j_imag + a_i_imag * cb_j_real;
159
+ }
160
+
161
+ // Reduce horizontal sums:
162
+ results->real = sum_real;
163
+ results->imag = sum_imag;
164
+ }
165
+
166
+ NK_PUBLIC void nk_bilinear_f64_skylake(nk_f64_t const *a, nk_f64_t const *b, nk_f64_t const *c, nk_size_t n,
167
+ nk_f64_t *result) {
168
+
169
+ // Default case for arbitrary size `n`
170
+ // Using Dot2 algorithm (Ogita-Rump-Oishi 2005) for compensated summation.
171
+ nk_size_t const tail_length = n % 8;
172
+ nk_size_t const tail_start = n - tail_length;
173
+ __m512d sum_f64x8 = _mm512_setzero_pd();
174
+ __m512d compensation_f64x8 = _mm512_setzero_pd();
175
+ __mmask8 const tail_mask = (__mmask8)_bzhi_u32(0xFFFFFFFF, tail_length);
176
+
177
+ for (nk_size_t i = 0; i != n; ++i) {
178
+ __m512d a_f64x8 = _mm512_set1_pd(a[i]);
179
+ __m512d cb_j_f64x8 = _mm512_setzero_pd();
180
+ __m512d inner_compensation_f64x8 = _mm512_setzero_pd();
181
+ __m512d b_f64x8, c_f64x8;
182
+ nk_size_t j = 0;
183
+
184
+ nk_bilinear_f64_skylake_cycle:
185
+ if (j + 8 <= n) {
186
+ b_f64x8 = _mm512_loadu_pd(b + j);
187
+ c_f64x8 = _mm512_loadu_pd(c + i * n + j);
188
+ }
189
+ else {
190
+ b_f64x8 = _mm512_maskz_loadu_pd(tail_mask, b + tail_start);
191
+ c_f64x8 = _mm512_maskz_loadu_pd(tail_mask, c + i * n + tail_start);
192
+ }
193
+ // Inner loop Dot2: accumulate cb_j = sum(b[j] * c[i,j])
194
+ // TwoProd: product = b * c, product_error = fma(b, c, -product)
195
+ {
196
+ __m512d product_f64x8 = _mm512_mul_pd(b_f64x8, c_f64x8);
197
+ __m512d product_error_f64x8 = _mm512_fmsub_pd(b_f64x8, c_f64x8, product_f64x8);
198
+ // TwoSum: t = cb_j + product
199
+ __m512d tentative_sum_f64x8 = _mm512_add_pd(cb_j_f64x8, product_f64x8);
200
+ __m512d virtual_addend_f64x8 = _mm512_sub_pd(tentative_sum_f64x8, cb_j_f64x8);
201
+ __m512d sum_error_f64x8 = _mm512_add_pd(
202
+ _mm512_sub_pd(cb_j_f64x8, _mm512_sub_pd(tentative_sum_f64x8, virtual_addend_f64x8)),
203
+ _mm512_sub_pd(product_f64x8, virtual_addend_f64x8));
204
+ cb_j_f64x8 = tentative_sum_f64x8;
205
+ inner_compensation_f64x8 = _mm512_add_pd(inner_compensation_f64x8,
206
+ _mm512_add_pd(sum_error_f64x8, product_error_f64x8));
207
+ }
208
+ j += 8;
209
+ if (j < n) goto nk_bilinear_f64_skylake_cycle;
210
+
211
+ // Combine inner sum with compensation before outer accumulation
212
+ cb_j_f64x8 = _mm512_add_pd(cb_j_f64x8, inner_compensation_f64x8);
213
+
214
+ // Outer loop Dot2: accumulate sum += a[i] * cb_j
215
+ // TwoProd: product = a * cb_j, product_error = fma(a, cb_j, -product)
216
+ {
217
+ __m512d product_f64x8 = _mm512_mul_pd(a_f64x8, cb_j_f64x8);
218
+ __m512d product_error_f64x8 = _mm512_fmsub_pd(a_f64x8, cb_j_f64x8, product_f64x8);
219
+ // TwoSum: t = sum + product
220
+ __m512d tentative_sum_f64x8 = _mm512_add_pd(sum_f64x8, product_f64x8);
221
+ __m512d virtual_addend_f64x8 = _mm512_sub_pd(tentative_sum_f64x8, sum_f64x8);
222
+ __m512d sum_error_f64x8 = _mm512_add_pd(
223
+ _mm512_sub_pd(sum_f64x8, _mm512_sub_pd(tentative_sum_f64x8, virtual_addend_f64x8)),
224
+ _mm512_sub_pd(product_f64x8, virtual_addend_f64x8));
225
+ sum_f64x8 = tentative_sum_f64x8;
226
+ compensation_f64x8 = _mm512_add_pd(compensation_f64x8, _mm512_add_pd(sum_error_f64x8, product_error_f64x8));
227
+ }
228
+ }
229
+
230
+ // Final: combine sum + compensation before reduce
231
+ *result = _mm512_reduce_add_pd(_mm512_add_pd(sum_f64x8, compensation_f64x8));
232
+ }
233
+
234
+ NK_PUBLIC void nk_mahalanobis_f64_skylake(nk_f64_t const *a, nk_f64_t const *b, nk_f64_t const *c, nk_size_t n,
235
+ nk_f64_t *result) {
236
+ // Using Dot2 algorithm (Ogita-Rump-Oishi 2005) for compensated summation.
237
+ nk_size_t const tail_length = n % 8;
238
+ nk_size_t const tail_start = n - tail_length;
239
+ __mmask8 const tail_mask = (__mmask8)_bzhi_u32(0xFFFFFFFF, tail_length);
240
+ __m512d sum_f64x8 = _mm512_setzero_pd();
241
+ __m512d compensation_f64x8 = _mm512_setzero_pd();
242
+
243
+ for (nk_size_t i = 0; i != n; ++i) {
244
+ __m512d diff_i_f64x8 = _mm512_set1_pd(a[i] - b[i]);
245
+ __m512d cdiff_j_f64x8 = _mm512_setzero_pd();
246
+ __m512d inner_compensation_f64x8 = _mm512_setzero_pd();
247
+ __m512d a_j_f64x8, b_j_f64x8, diff_j_f64x8, c_f64x8;
248
+ nk_size_t j = 0;
249
+
250
+ // The nested loop is cleaner to implement with a `goto` in this case:
251
+ nk_mahalanobis_f64_skylake_cycle:
252
+ if (j + 8 <= n) {
253
+ a_j_f64x8 = _mm512_loadu_pd(a + j);
254
+ b_j_f64x8 = _mm512_loadu_pd(b + j);
255
+ c_f64x8 = _mm512_loadu_pd(c + i * n + j);
256
+ }
257
+ else {
258
+ a_j_f64x8 = _mm512_maskz_loadu_pd(tail_mask, a + tail_start);
259
+ b_j_f64x8 = _mm512_maskz_loadu_pd(tail_mask, b + tail_start);
260
+ c_f64x8 = _mm512_maskz_loadu_pd(tail_mask, c + i * n + tail_start);
261
+ }
262
+ diff_j_f64x8 = _mm512_sub_pd(a_j_f64x8, b_j_f64x8);
263
+
264
+ // Inner loop Dot2: accumulate cdiff_j = sum(diff_j * c[i,j])
265
+ // TwoProd: product = diff_j * c, product_error = fma(diff_j, c, -product)
266
+ {
267
+ __m512d product_f64x8 = _mm512_mul_pd(diff_j_f64x8, c_f64x8);
268
+ __m512d product_error_f64x8 = _mm512_fmsub_pd(diff_j_f64x8, c_f64x8, product_f64x8);
269
+ // TwoSum: t = cdiff_j + product
270
+ __m512d tentative_sum_f64x8 = _mm512_add_pd(cdiff_j_f64x8, product_f64x8);
271
+ __m512d virtual_addend_f64x8 = _mm512_sub_pd(tentative_sum_f64x8, cdiff_j_f64x8);
272
+ __m512d sum_error_f64x8 = _mm512_add_pd(
273
+ _mm512_sub_pd(cdiff_j_f64x8, _mm512_sub_pd(tentative_sum_f64x8, virtual_addend_f64x8)),
274
+ _mm512_sub_pd(product_f64x8, virtual_addend_f64x8));
275
+ cdiff_j_f64x8 = tentative_sum_f64x8;
276
+ inner_compensation_f64x8 = _mm512_add_pd(inner_compensation_f64x8,
277
+ _mm512_add_pd(sum_error_f64x8, product_error_f64x8));
278
+ }
279
+ j += 8;
280
+ if (j < n) goto nk_mahalanobis_f64_skylake_cycle;
281
+
282
+ // Combine inner sum with compensation before outer accumulation
283
+ cdiff_j_f64x8 = _mm512_add_pd(cdiff_j_f64x8, inner_compensation_f64x8);
284
+
285
+ // Outer loop Dot2: accumulate sum += diff_i * cdiff_j
286
+ // TwoProd: product = diff_i * cdiff_j, product_error = fma(diff_i, cdiff_j, -product)
287
+ {
288
+ __m512d product_f64x8 = _mm512_mul_pd(diff_i_f64x8, cdiff_j_f64x8);
289
+ __m512d product_error_f64x8 = _mm512_fmsub_pd(diff_i_f64x8, cdiff_j_f64x8, product_f64x8);
290
+ // TwoSum: t = sum + product
291
+ __m512d tentative_sum_f64x8 = _mm512_add_pd(sum_f64x8, product_f64x8);
292
+ __m512d virtual_addend_f64x8 = _mm512_sub_pd(tentative_sum_f64x8, sum_f64x8);
293
+ __m512d sum_error_f64x8 = _mm512_add_pd(
294
+ _mm512_sub_pd(sum_f64x8, _mm512_sub_pd(tentative_sum_f64x8, virtual_addend_f64x8)),
295
+ _mm512_sub_pd(product_f64x8, virtual_addend_f64x8));
296
+ sum_f64x8 = tentative_sum_f64x8;
297
+ compensation_f64x8 = _mm512_add_pd(compensation_f64x8, _mm512_add_pd(sum_error_f64x8, product_error_f64x8));
298
+ }
299
+ }
300
+
301
+ // Final: combine sum + compensation before reduce
302
+ nk_f64_t quadratic = _mm512_reduce_add_pd(_mm512_add_pd(sum_f64x8, compensation_f64x8));
303
+ *result = nk_f64_sqrt_haswell(quadratic > 0 ? quadratic : 0);
304
+ }
305
+
306
+ NK_PUBLIC void nk_bilinear_f64c_skylake(nk_f64c_t const *a, nk_f64c_t const *b, nk_f64c_t const *c, nk_size_t n,
307
+ nk_f64c_t *results) {
308
+
309
+ // We take into account, that FMS is the same as FMA with a negative multiplier.
310
+ // To multiply a floating-point value by -1, we can use the `XOR` instruction to flip the sign bit.
311
+ // This way we can avoid the shuffling and the need for separate real and imaginary parts.
312
+ // For the imaginary part of the product, we would need to swap the real and imaginary parts of
313
+ // one of the vectors.
314
+ // Using Dot2 algorithm (Ogita-Rump-Oishi 2005) for compensated summation.
315
+ __m512i const sign_flip_i64x8 = _mm512_set_epi64( //
316
+ 0x8000000000000000, 0x0000000000000000, 0x8000000000000000, 0x0000000000000000, //
317
+ 0x8000000000000000, 0x0000000000000000, 0x8000000000000000, 0x0000000000000000 //
318
+ );
319
+
320
+ // Default case for arbitrary size `n`
321
+ nk_size_t const tail_length = n % 4;
322
+ nk_size_t const tail_start = n - tail_length;
323
+ __mmask8 const tail_mask = (__mmask8)_bzhi_u32(0xFFFFFFFF, tail_length * 2);
324
+ nk_f64_t sum_real = 0;
325
+ nk_f64_t sum_imag = 0;
326
+ nk_f64_t compensation_real = 0;
327
+ nk_f64_t compensation_imag = 0;
328
+
329
+ for (nk_size_t i = 0; i != n; ++i) {
330
+ nk_f64_t const a_i_real = a[i].real;
331
+ nk_f64_t const a_i_imag = a[i].imag;
332
+ __m512d cb_j_real_f64x8 = _mm512_setzero_pd();
333
+ __m512d cb_j_imag_f64x8 = _mm512_setzero_pd();
334
+ __m512d compensation_real_f64x8 = _mm512_setzero_pd();
335
+ __m512d compensation_imag_f64x8 = _mm512_setzero_pd();
336
+ __m512d b_f64x8, c_f64x8;
337
+ nk_size_t j = 0;
338
+
339
+ nk_bilinear_f64c_skylake_cycle:
340
+ if (j + 4 <= n) {
341
+ b_f64x8 = _mm512_loadu_pd((nk_f64_t const *)(b + j));
342
+ c_f64x8 = _mm512_loadu_pd((nk_f64_t const *)(c + i * n + j));
343
+ }
344
+ else {
345
+ b_f64x8 = _mm512_maskz_loadu_pd(tail_mask, (nk_f64_t const *)(b + tail_start));
346
+ c_f64x8 = _mm512_maskz_loadu_pd(tail_mask, (nk_f64_t const *)(c + i * n + tail_start));
347
+ }
348
+ // The real part of the product: b.real * c.real - b.imag * c.imag.
349
+ // The subtraction will be performed later with a sign flip.
350
+ // Inner loop Dot2 for real accumulator
351
+ {
352
+ __m512d product_f64x8 = _mm512_mul_pd(c_f64x8, b_f64x8);
353
+ __m512d product_error_f64x8 = _mm512_fmsub_pd(c_f64x8, b_f64x8, product_f64x8);
354
+ __m512d tentative_sum_f64x8 = _mm512_add_pd(cb_j_real_f64x8, product_f64x8);
355
+ __m512d virtual_addend_f64x8 = _mm512_sub_pd(tentative_sum_f64x8, cb_j_real_f64x8);
356
+ __m512d sum_error_f64x8 = _mm512_add_pd(
357
+ _mm512_sub_pd(cb_j_real_f64x8, _mm512_sub_pd(tentative_sum_f64x8, virtual_addend_f64x8)),
358
+ _mm512_sub_pd(product_f64x8, virtual_addend_f64x8));
359
+ cb_j_real_f64x8 = tentative_sum_f64x8;
360
+ compensation_real_f64x8 = _mm512_add_pd(compensation_real_f64x8,
361
+ _mm512_add_pd(sum_error_f64x8, product_error_f64x8));
362
+ }
363
+ // The imaginary part of the product: b.real * c.imag + b.imag * c.real.
364
+ // Swap the imaginary and real parts of `c` before multiplication:
365
+ c_f64x8 = _mm512_permute_pd(c_f64x8, 0x55); //? Same as 0b01010101.
366
+ // Inner loop Dot2 for imaginary accumulator
367
+ {
368
+ __m512d product_f64x8 = _mm512_mul_pd(c_f64x8, b_f64x8);
369
+ __m512d product_error_f64x8 = _mm512_fmsub_pd(c_f64x8, b_f64x8, product_f64x8);
370
+ __m512d tentative_sum_f64x8 = _mm512_add_pd(cb_j_imag_f64x8, product_f64x8);
371
+ __m512d virtual_addend_f64x8 = _mm512_sub_pd(tentative_sum_f64x8, cb_j_imag_f64x8);
372
+ __m512d sum_error_f64x8 = _mm512_add_pd(
373
+ _mm512_sub_pd(cb_j_imag_f64x8, _mm512_sub_pd(tentative_sum_f64x8, virtual_addend_f64x8)),
374
+ _mm512_sub_pd(product_f64x8, virtual_addend_f64x8));
375
+ cb_j_imag_f64x8 = tentative_sum_f64x8;
376
+ compensation_imag_f64x8 = _mm512_add_pd(compensation_imag_f64x8,
377
+ _mm512_add_pd(sum_error_f64x8, product_error_f64x8));
378
+ }
379
+ j += 4;
380
+ if (j < n) goto nk_bilinear_f64c_skylake_cycle;
381
+
382
+ // Flip the sign bit in every second scalar before accumulation:
383
+ cb_j_real_f64x8 = _mm512_castsi512_pd(_mm512_xor_si512(_mm512_castpd_si512(cb_j_real_f64x8), sign_flip_i64x8));
384
+ compensation_real_f64x8 = _mm512_castsi512_pd(
385
+ _mm512_xor_si512(_mm512_castpd_si512(compensation_real_f64x8), sign_flip_i64x8));
386
+
387
+ // Combine inner sums with compensation before horizontal reduce
388
+ cb_j_real_f64x8 = _mm512_add_pd(cb_j_real_f64x8, compensation_real_f64x8);
389
+ cb_j_imag_f64x8 = _mm512_add_pd(cb_j_imag_f64x8, compensation_imag_f64x8);
390
+
391
+ // Horizontal sums are the expensive part of the computation:
392
+ nk_f64_t const cb_j_real = _mm512_reduce_add_pd(cb_j_real_f64x8);
393
+ nk_f64_t const cb_j_imag = _mm512_reduce_add_pd(cb_j_imag_f64x8);
394
+
395
+ // Outer loop Dot2 for real part: sum_real += a_i_real * cb_j_real - a_i_imag * cb_j_imag
396
+ {
397
+ // First term: a_i_real * cb_j_real
398
+ nk_f64_t product1 = a_i_real * cb_j_real;
399
+ nk_f64_t product_error1 = (a_i_real * cb_j_real) - product1;
400
+ // Second term: -a_i_imag * cb_j_imag
401
+ nk_f64_t product2 = a_i_imag * cb_j_imag;
402
+ nk_f64_t product_error2 = (a_i_imag * cb_j_imag) - product2;
403
+ // TwoSum for first addition: t = sum_real + product1
404
+ nk_f64_t t1 = sum_real + product1;
405
+ nk_f64_t z1 = t1 - sum_real;
406
+ nk_f64_t sum_error1 = (sum_real - (t1 - z1)) + (product1 - z1);
407
+ sum_real = t1;
408
+ compensation_real += sum_error1 + product_error1;
409
+ // TwoSum for subtraction: t = sum_real - product2
410
+ nk_f64_t t2 = sum_real - product2;
411
+ nk_f64_t z2 = t2 - sum_real;
412
+ nk_f64_t sum_error2 = (sum_real - (t2 - z2)) + (-product2 - z2);
413
+ sum_real = t2;
414
+ compensation_real += sum_error2 - product_error2;
415
+ }
416
+
417
+ // Outer loop Dot2 for imaginary part: sum_imag += a_i_real * cb_j_imag + a_i_imag * cb_j_real
418
+ {
419
+ // First term: a_i_real * cb_j_imag
420
+ nk_f64_t product1 = a_i_real * cb_j_imag;
421
+ nk_f64_t product_error1 = (a_i_real * cb_j_imag) - product1;
422
+ // Second term: a_i_imag * cb_j_real
423
+ nk_f64_t product2 = a_i_imag * cb_j_real;
424
+ nk_f64_t product_error2 = (a_i_imag * cb_j_real) - product2;
425
+ // TwoSum for first addition: t = sum_imag + product1
426
+ nk_f64_t t1 = sum_imag + product1;
427
+ nk_f64_t z1 = t1 - sum_imag;
428
+ nk_f64_t sum_error1 = (sum_imag - (t1 - z1)) + (product1 - z1);
429
+ sum_imag = t1;
430
+ compensation_imag += sum_error1 + product_error1;
431
+ // TwoSum for second addition: t = sum_imag + product2
432
+ nk_f64_t t2 = sum_imag + product2;
433
+ nk_f64_t z2 = t2 - sum_imag;
434
+ nk_f64_t sum_error2 = (sum_imag - (t2 - z2)) + (product2 - z2);
435
+ sum_imag = t2;
436
+ compensation_imag += sum_error2 + product_error2;
437
+ }
438
+ }
439
+
440
+ // Final: combine sum + compensation
441
+ results->real = sum_real + compensation_real;
442
+ results->imag = sum_imag + compensation_imag;
443
+ }
444
+
445
+ #if defined(__clang__)
446
+ #pragma clang attribute pop
447
+ #elif defined(__GNUC__)
448
+ #pragma GCC pop_options
449
+ #endif
450
+
451
+ #if defined(__cplusplus)
452
+ } // extern "C"
453
+ #endif
454
+
455
+ #endif // NK_TARGET_SKYLAKE
456
+ #endif // NK_TARGET_X86_
457
+ #endif // NK_CURVED_SKYLAKE_H